From e0c090f227e9b64e595b47d4d1f96f8a2fff5bf7 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 2 Jan 2018 09:19:18 +0800 Subject: [PATCH 0001/2461] [SPARK-22932][SQL] Refactor AnalysisContext ## What changes were proposed in this pull request? Add a `reset` function to ensure the state in `AnalysisContext ` is per-query. ## How was this patch tested? The existing test cases Author: gatorsmile Closes #20127 from gatorsmile/refactorAnalysisContext. --- .../sql/catalyst/analysis/Analyzer.scala | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6d294d48c0ee7..35b35110e491f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -52,6 +52,7 @@ object SimpleAnalyzer extends Analyzer( /** * Provides a way to keep state during the analysis, this enables us to decouple the concerns * of analysis environment from the catalog. + * The state that is kept here is per-query. * * Note this is thread local. * @@ -70,6 +71,8 @@ object AnalysisContext { } def get: AnalysisContext = value.get() + def reset(): Unit = value.remove() + private def set(context: AnalysisContext): Unit = value.set(context) def withAnalysisContext[A](database: Option[String])(f: => A): A = { @@ -95,6 +98,17 @@ class Analyzer( this(catalog, conf, conf.optimizerMaxIterations) } + override def execute(plan: LogicalPlan): LogicalPlan = { + AnalysisContext.reset() + try { + executeSameContext(plan) + } finally { + AnalysisContext.reset() + } + } + + private def executeSameContext(plan: LogicalPlan): LogicalPlan = super.execute(plan) + def resolver: Resolver = conf.resolver protected val fixedPoint = FixedPoint(maxIterations) @@ -176,7 +190,7 @@ class Analyzer( case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => - resolved :+ name -> execute(substituteCTE(relation, resolved)) + resolved :+ name -> executeSameContext(substituteCTE(relation, resolved)) }) case other => other } @@ -600,7 +614,7 @@ class Analyzer( "avoid errors. Increase the value of spark.sql.view.maxNestedViewDepth to work " + "aroud this.") } - execute(child) + executeSameContext(child) } view.copy(child = newChild) case p @ SubqueryAlias(_, view: View) => @@ -1269,7 +1283,7 @@ class Analyzer( do { // Try to resolve the subquery plan using the regular analyzer. previous = current - current = execute(current) + current = executeSameContext(current) // Use the outer references to resolve the subquery plan if it isn't resolved yet. val i = plans.iterator @@ -1392,7 +1406,7 @@ class Analyzer( grouping, Alias(cond, "havingCondition")() :: Nil, child) - val resolvedOperator = execute(aggregatedCondition) + val resolvedOperator = executeSameContext(aggregatedCondition) def resolvedAggregateFilter = resolvedOperator .asInstanceOf[Aggregate] @@ -1450,7 +1464,8 @@ class Analyzer( val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) - val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] + val resolvedAggregate: Aggregate = + executeSameContext(aggregatedOrdering).asInstanceOf[Aggregate] val resolvedAliasedOrdering: Seq[Alias] = resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]] From a6fc300e91273230e7134ac6db95ccb4436c6f8f Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Tue, 2 Jan 2018 23:30:38 +0800 Subject: [PATCH 0002/2461] [SPARK-22897][CORE] Expose stageAttemptId in TaskContext ## What changes were proposed in this pull request? stageAttemptId added in TaskContext and corresponding construction modification ## How was this patch tested? Added a new test in TaskContextSuite, two cases are tested: 1. Normal case without failure 2. Exception case with resubmitted stages Link to [SPARK-22897](https://issues.apache.org/jira/browse/SPARK-22897) Author: Xianjin YE Closes #20082 from advancedxy/SPARK-22897. --- .../scala/org/apache/spark/TaskContext.scala | 9 +++++- .../org/apache/spark/TaskContextImpl.scala | 5 ++-- .../org/apache/spark/scheduler/Task.scala | 1 + .../spark/JavaTaskContextCompileCheck.java | 2 ++ .../scala/org/apache/spark/ShuffleSuite.scala | 6 ++-- .../spark/memory/MemoryTestingUtils.scala | 1 + .../spark/scheduler/TaskContextSuite.scala | 29 +++++++++++++++++-- .../spark/storage/BlockInfoManagerSuite.scala | 2 +- project/MimaExcludes.scala | 3 ++ .../UnsafeFixedWidthAggregationMapSuite.scala | 1 + .../UnsafeKVExternalSorterSuite.scala | 1 + .../execution/UnsafeRowSerializerSuite.scala | 2 +- .../SortBasedAggregationStoreSuite.scala | 3 +- 13 files changed, 54 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 0b87cd503d4fa..69739745aa6cf 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -66,7 +66,7 @@ object TaskContext { * An empty task context that does not represent an actual task. This is only used in tests. */ private[spark] def empty(): TaskContextImpl = { - new TaskContextImpl(0, 0, 0, 0, null, new Properties, null) + new TaskContextImpl(0, 0, 0, 0, 0, null, new Properties, null) } } @@ -150,6 +150,13 @@ abstract class TaskContext extends Serializable { */ def stageId(): Int + /** + * How many times the stage that this task belongs to has been attempted. The first stage attempt + * will be assigned stageAttemptNumber = 0, and subsequent attempts will have increasing attempt + * numbers. + */ + def stageAttemptNumber(): Int + /** * The ID of the RDD partition that is computed by this task. */ diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 01d8973e1bb06..cccd3ea457ba4 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -41,8 +41,9 @@ import org.apache.spark.util._ * `TaskMetrics` & `MetricsSystem` objects are not thread safe. */ private[spark] class TaskContextImpl( - val stageId: Int, - val partitionId: Int, + override val stageId: Int, + override val stageAttemptNumber: Int, + override val partitionId: Int, override val taskAttemptId: Long, override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 7767ef1803a06..f536fc2a5f0a1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -79,6 +79,7 @@ private[spark] abstract class Task[T]( SparkEnv.get.blockManager.registerTask(taskAttemptId) context = new TaskContextImpl( stageId, + stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal partitionId, taskAttemptId, attemptNumber, diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java index 94f5805853e1e..f8e233a05a447 100644 --- a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java +++ b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java @@ -38,6 +38,7 @@ public static void test() { tc.attemptNumber(); tc.partitionId(); tc.stageId(); + tc.stageAttemptNumber(); tc.taskAttemptId(); } @@ -51,6 +52,7 @@ public void onTaskCompletion(TaskContext context) { context.isCompleted(); context.isInterrupted(); context.stageId(); + context.stageAttemptNumber(); context.partitionId(); context.addTaskCompletionListener(this); } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 3931d53b4ae0a..ced5a06516f75 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -363,14 +363,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // first attempt -- its successful val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)) val data1 = (1 to 10).map { x => x -> x} // second attempt -- also successful. We'll write out different data, // just to simulate the fact that the records may get written differently // depending on what gets spilled, what gets combined, etc. val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)) val data2 = (11 to 20).map { x => x -> x} // interleave writes of both attempts -- we want to test that both attempts can occur @@ -398,7 +398,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC } val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, - new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)) val readData = reader.read().toIndexedSeq assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala index 362cd861cc248..dcf89e4f75acf 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala @@ -29,6 +29,7 @@ object MemoryTestingUtils { val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0) new TaskContextImpl( stageId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = 0, attemptNumber = 0, diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index a1d9085fa085d..aa9c36c0aaacb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.JvmSource import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util._ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { @@ -158,6 +159,30 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(attemptIdsWithFailedTask.toSet === Set(0, 1)) } + test("TaskContext.stageAttemptNumber getter") { + sc = new SparkContext("local[1,2]", "test") + + // Check stageAttemptNumbers are 0 for initial stage + val stageAttemptNumbers = sc.parallelize(Seq(1, 2), 2).mapPartitions { _ => + Seq(TaskContext.get().stageAttemptNumber()).iterator + }.collect() + assert(stageAttemptNumbers.toSet === Set(0)) + + // Check stageAttemptNumbers that are resubmitted when tasks have FetchFailedException + val stageAttemptNumbersWithFailedStage = + sc.parallelize(Seq(1, 2, 3, 4), 4).repartition(1).mapPartitions { _ => + val stageAttemptNumber = TaskContext.get().stageAttemptNumber() + if (stageAttemptNumber < 2) { + // Throw FetchFailedException to explicitly trigger stage resubmission. A normal exception + // will only trigger task resubmission in the same stage. + throw new FetchFailedException(null, 0, 0, 0, "Fake") + } + Seq(stageAttemptNumber).iterator + }.collect() + + assert(stageAttemptNumbersWithFailedStage.toSet === Set(2)) + } + test("accumulators are updated on exception failures") { // This means use 1 core and 4 max task failures sc = new SparkContext("local[1,4]", "test") @@ -190,7 +215,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark // accumulator updates from it. val taskMetrics = TaskMetrics.empty val task = new Task[Int](0, 0, 0) { - context = new TaskContextImpl(0, 0, 0L, 0, + context = new TaskContextImpl(0, 0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), new Properties, SparkEnv.get.metricsSystem, @@ -213,7 +238,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark // accumulator updates from it. val taskMetrics = TaskMetrics.registered val task = new Task[Int](0, 0, 0) { - context = new TaskContextImpl(0, 0, 0L, 0, + context = new TaskContextImpl(0, 0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), new Properties, SparkEnv.get.metricsSystem, diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala index 917db766f7f11..9c0699bc981f8 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala @@ -62,7 +62,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { private def withTaskId[T](taskAttemptId: Long)(block: => T): T = { try { TaskContext.setTaskContext( - new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null)) + new TaskContextImpl(0, 0, 0, taskAttemptId, 0, null, new Properties, null)) block } finally { TaskContext.unset() diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 81584af6813ea..3b452f35c5ec1 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,9 @@ object MimaExcludes { // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( + // [SPARK-22897] Expose stageAttemptId in TaskContext + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.stageAttemptNumber"), + // SPARK-22789: Map-only continuous processing execution ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$8"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$6"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 232c1beae7998..3e31d22e15c0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -70,6 +70,7 @@ class UnsafeFixedWidthAggregationMapSuite TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = Random.nextInt(10000), attemptNumber = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 604502f2a57d0..6af9f8b77f8d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -116,6 +116,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { val taskMemMgr = new TaskMemoryManager(memoryManager, 0) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = 98456, attemptNumber = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index dff88ce7f1b9a..a3ae93810aa3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -114,7 +114,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { (i, converter(Row(i))) } val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0) - val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null) + val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null) val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( taskContext, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala index 10f1ee279bedf..3fad7dfddadcc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala @@ -35,7 +35,8 @@ class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkConte val conf = new SparkConf() sc = new SparkContext("local[2, 4]", "test", conf) val taskManager = new TaskMemoryManager(new TestMemoryManager(conf), 0) - TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, taskManager, new Properties, null)) + TaskContext.setTaskContext( + new TaskContextImpl(0, 0, 0, 0, 0, taskManager, new Properties, null)) } override def afterAll(): Unit = TaskContext.unset() From 247a08939d58405aef39b2a4e7773aa45474ad12 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Wed, 3 Jan 2018 21:40:51 +0800 Subject: [PATCH 0003/2461] [SPARK-22938] Assert that SQLConf.get is accessed only on the driver. ## What changes were proposed in this pull request? Assert if code tries to access SQLConf.get on executor. This can lead to hard to detect bugs, where the executor will read fallbackConf, falling back to default config values, ignoring potentially changed non-default configs. If a config is to be passed to executor code, it needs to be read on the driver, and passed explicitly. ## How was this patch tested? Check in existing tests. Author: Juliusz Sompolski Closes #20136 from juliuszsompolski/SPARK-22938. --- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4f77c54a7af57..80cdc61484c0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,11 +27,13 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator +import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -70,7 +72,7 @@ object SQLConf { * Default config. Only used when there is no active SparkSession for the thread. * See [[get]] for more information. */ - private val fallbackConf = new ThreadLocal[SQLConf] { + private lazy val fallbackConf = new ThreadLocal[SQLConf] { override def initialValue: SQLConf = new SQLConf } @@ -1087,6 +1089,12 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ + if (Utils.isTesting && SparkEnv.get != null) { + // assert that we're only accessing it on the driver. + assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, + "SQLConf should only be created and accessed on the driver.") + } + /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) From 1a87a1609c4d2c9027a2cf669ea3337b89f61fb6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 3 Jan 2018 22:09:30 +0800 Subject: [PATCH 0004/2461] [SPARK-22934][SQL] Make optional clauses order insensitive for CREATE TABLE SQL statement ## What changes were proposed in this pull request? Currently, our CREATE TABLE syntax require the EXACT order of clauses. It is pretty hard to remember the exact order. Thus, this PR is to make optional clauses order insensitive for `CREATE TABLE` SQL statement. ``` CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name [(col_name1 col_type1 [COMMENT col_comment1], ...)] USING datasource [OPTIONS (key1=val1, key2=val2, ...)] [PARTITIONED BY (col_name1, col_name2, ...)] [CLUSTERED BY (col_name3, col_name4, ...) INTO num_buckets BUCKETS] [LOCATION path] [COMMENT table_comment] [TBLPROPERTIES (key1=val1, key2=val2, ...)] [AS select_statement] ``` The proposal is to make the following clauses order insensitive. ``` [OPTIONS (key1=val1, key2=val2, ...)] [PARTITIONED BY (col_name1, col_name2, ...)] [CLUSTERED BY (col_name3, col_name4, ...) INTO num_buckets BUCKETS] [LOCATION path] [COMMENT table_comment] [TBLPROPERTIES (key1=val1, key2=val2, ...)] ``` The same idea is also applicable to Create Hive Table. ``` CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name [(col_name1[:] col_type1 [COMMENT col_comment1], ...)] [COMMENT table_comment] [PARTITIONED BY (col_name2[:] col_type2 [COMMENT col_comment2], ...)] [ROW FORMAT row_format] [STORED AS file_format] [LOCATION path] [TBLPROPERTIES (key1=val1, key2=val2, ...)] [AS select_statement] ``` The proposal is to make the following clauses order insensitive. ``` [COMMENT table_comment] [PARTITIONED BY (col_name2[:] col_type2 [COMMENT col_comment2], ...)] [ROW FORMAT row_format] [STORED AS file_format] [LOCATION path] [TBLPROPERTIES (key1=val1, key2=val2, ...)] ``` ## How was this patch tested? Added test cases Author: gatorsmile Closes #20133 from gatorsmile/createDataSourceTableDDL. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 24 +- .../sql/catalyst/parser/ParserUtils.scala | 9 + .../spark/sql/execution/SparkSqlParser.scala | 81 +++++-- .../execution/command/DDLParserSuite.scala | 220 ++++++++++++++---- .../sql/execution/command/DDLSuite.scala | 2 +- .../sql/hive/execution/HiveDDLSuite.scala | 13 +- .../sql/hive/execution/SQLQuerySuite.scala | 124 +++++----- 7 files changed, 335 insertions(+), 138 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 6fe995f650d55..6daf01d98426c 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -73,18 +73,22 @@ statement | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase | createTableHeader ('(' colTypeList ')')? tableProvider - (OPTIONS options=tablePropertyList)? - (PARTITIONED BY partitionColumnNames=identifierList)? - bucketSpec? locationSpec? - (COMMENT comment=STRING)? - (TBLPROPERTIES tableProps=tablePropertyList)? + ((OPTIONS options=tablePropertyList) | + (PARTITIONED BY partitionColumnNames=identifierList) | + bucketSpec | + locationSpec | + (COMMENT comment=STRING) | + (TBLPROPERTIES tableProps=tablePropertyList))* (AS? query)? #createTable | createTableHeader ('(' columns=colTypeList ')')? - (COMMENT comment=STRING)? - (PARTITIONED BY '(' partitionColumns=colTypeList ')')? - bucketSpec? skewSpec? - rowFormat? createFileFormat? locationSpec? - (TBLPROPERTIES tablePropertyList)? + ((COMMENT comment=STRING) | + (PARTITIONED BY '(' partitionColumns=colTypeList ')') | + bucketSpec | + skewSpec | + rowFormat | + createFileFormat | + locationSpec | + (TBLPROPERTIES tableProps=tablePropertyList))* (AS? query)? #createHiveTable | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier LIKE source=tableIdentifier locationSpec? #createTableLike diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 9b127f91648e6..89347f4b1f7bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.catalyst.parser +import java.util + import scala.collection.mutable.StringBuilder import org.antlr.v4.runtime.{ParserRuleContext, Token} @@ -39,6 +41,13 @@ object ParserUtils { throw new ParseException(s"Operation not allowed: $message", ctx) } + def checkDuplicateClauses[T]( + nodes: util.List[T], clauseName: String, ctx: ParserRuleContext): Unit = { + if (nodes.size() > 1) { + throw new ParseException(s"Found duplicate clauses: $clauseName", ctx) + } + } + /** Check if duplicate keys exist in a set of key-value pairs. */ def checkDuplicateKeys[T](keyPairs: Seq[(String, T)], ctx: ParserRuleContext): Unit = { keyPairs.groupBy(_._1).filter(_._2.size > 1).foreach { case (key, _) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 29b584b55972c..d3cfd2a1ffbf2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -383,16 +383,19 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * {{{ * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name * USING table_provider - * [OPTIONS table_property_list] - * [PARTITIONED BY (col_name, col_name, ...)] - * [CLUSTERED BY (col_name, col_name, ...) - * [SORTED BY (col_name [ASC|DESC], ...)] - * INTO num_buckets BUCKETS - * ] - * [LOCATION path] - * [COMMENT table_comment] - * [TBLPROPERTIES (property_name=property_value, ...)] + * create_table_clauses * [[AS] select_statement]; + * + * create_table_clauses (order insensitive): + * [OPTIONS table_property_list] + * [PARTITIONED BY (col_name, col_name, ...)] + * [CLUSTERED BY (col_name, col_name, ...) + * [SORTED BY (col_name [ASC|DESC], ...)] + * INTO num_buckets BUCKETS + * ] + * [LOCATION path] + * [COMMENT table_comment] + * [TBLPROPERTIES (property_name=property_value, ...)] * }}} */ override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { @@ -400,6 +403,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { if (external) { operationNotAllowed("CREATE EXTERNAL TABLE ... USING", ctx) } + + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) val provider = ctx.tableProvider.qualifiedName.getText val schema = Option(ctx.colTypeList()).map(createSchema) @@ -408,9 +419,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { .map(visitIdentifierList(_).toArray) .getOrElse(Array.empty[String]) val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) - val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec) + val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) - val location = Option(ctx.locationSpec).map(visitLocationSpec) + val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec) val storage = DataSource.buildStorageFormatFromOptions(options) if (location.isDefined && storage.locationUri.isDefined) { @@ -1087,13 +1098,16 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * {{{ * CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name * [(col1[:] data_type [COMMENT col_comment], ...)] - * [COMMENT table_comment] - * [PARTITIONED BY (col2[:] data_type [COMMENT col_comment], ...)] - * [ROW FORMAT row_format] - * [STORED AS file_format] - * [LOCATION path] - * [TBLPROPERTIES (property_name=property_value, ...)] + * create_table_clauses * [AS select_statement]; + * + * create_table_clauses (order insensitive): + * [COMMENT table_comment] + * [PARTITIONED BY (col2[:] data_type [COMMENT col_comment], ...)] + * [ROW FORMAT row_format] + * [STORED AS file_format] + * [LOCATION path] + * [TBLPROPERTIES (property_name=property_value, ...)] * }}} */ override def visitCreateHiveTable(ctx: CreateHiveTableContext): LogicalPlan = withOrigin(ctx) { @@ -1104,15 +1118,23 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { "CREATE TEMPORARY TABLE is not supported yet. " + "Please use CREATE TEMPORARY VIEW as an alternative.", ctx) } - if (ctx.skewSpec != null) { + if (ctx.skewSpec.size > 0) { operationNotAllowed("CREATE TABLE ... SKEWED BY", ctx) } + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.createFileFormat, "STORED AS/BY", ctx) + checkDuplicateClauses(ctx.rowFormat, "ROW FORMAT", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + val dataCols = Option(ctx.columns).map(visitColTypeList).getOrElse(Nil) val partitionCols = Option(ctx.partitionColumns).map(visitColTypeList).getOrElse(Nil) - val properties = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty) + val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) val selectQuery = Option(ctx.query).map(plan) - val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec) + val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) // Note: Hive requires partition columns to be distinct from the schema, so we need // to include the partition columns here explicitly @@ -1120,12 +1142,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { // Storage format val defaultStorage = HiveSerDe.getDefaultStorage(conf) - validateRowFormatFileFormat(ctx.rowFormat, ctx.createFileFormat, ctx) - val fileStorage = Option(ctx.createFileFormat).map(visitCreateFileFormat) + validateRowFormatFileFormat(ctx.rowFormat.asScala, ctx.createFileFormat.asScala, ctx) + val fileStorage = ctx.createFileFormat.asScala.headOption.map(visitCreateFileFormat) .getOrElse(CatalogStorageFormat.empty) - val rowStorage = Option(ctx.rowFormat).map(visitRowFormat) + val rowStorage = ctx.rowFormat.asScala.headOption.map(visitRowFormat) .getOrElse(CatalogStorageFormat.empty) - val location = Option(ctx.locationSpec).map(visitLocationSpec) + val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec) // If we are creating an EXTERNAL table, then the LOCATION field is required if (external && location.isEmpty) { operationNotAllowed("CREATE EXTERNAL TABLE must be accompanied by LOCATION", ctx) @@ -1180,7 +1202,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { ctx) } - val hasStorageProperties = (ctx.createFileFormat != null) || (ctx.rowFormat != null) + val hasStorageProperties = (ctx.createFileFormat.size != 0) || (ctx.rowFormat.size != 0) if (conf.convertCTAS && !hasStorageProperties) { // At here, both rowStorage.serdeProperties and fileStorage.serdeProperties // are empty Maps. @@ -1366,6 +1388,15 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } } + private def validateRowFormatFileFormat( + rowFormatCtx: Seq[RowFormatContext], + createFileFormatCtx: Seq[CreateFileFormatContext], + parentCtx: ParserRuleContext): Unit = { + if (rowFormatCtx.size == 1 && createFileFormatCtx.size == 1) { + validateRowFormatFileFormat(rowFormatCtx.head, createFileFormatCtx.head, parentCtx) + } + } + /** * Create or replace a view. This creates a [[CreateViewCommand]] command. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index eb7c33590b602..2b1aea08b1223 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -54,6 +54,13 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + private def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parser.parsePlan(sqlCommand)).getMessage + messages.foreach { message => + assert(e.contains(message)) + } + } + private def parseAs[T: ClassTag](query: String): T = { parser.parsePlan(query) match { case t: T => t @@ -494,6 +501,37 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + test("Duplicate clauses - create table") { + def createTableHeader(duplicateClause: String, isNative: Boolean): String = { + val fileFormat = if (isNative) "USING parquet" else "STORED AS parquet" + s"CREATE TABLE my_tab(a INT, b STRING) $fileFormat $duplicateClause $duplicateClause" + } + + Seq(true, false).foreach { isNative => + intercept(createTableHeader("TBLPROPERTIES('test' = 'test2')", isNative), + "Found duplicate clauses: TBLPROPERTIES") + intercept(createTableHeader("LOCATION '/tmp/file'", isNative), + "Found duplicate clauses: LOCATION") + intercept(createTableHeader("COMMENT 'a table'", isNative), + "Found duplicate clauses: COMMENT") + intercept(createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS", isNative), + "Found duplicate clauses: CLUSTERED BY") + } + + // Only for native data source tables + intercept(createTableHeader("PARTITIONED BY (b)", isNative = true), + "Found duplicate clauses: PARTITIONED BY") + + // Only for Hive serde tables + intercept(createTableHeader("PARTITIONED BY (k int)", isNative = false), + "Found duplicate clauses: PARTITIONED BY") + intercept(createTableHeader("STORED AS parquet", isNative = false), + "Found duplicate clauses: STORED AS/BY") + intercept( + createTableHeader("ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'", isNative = false), + "Found duplicate clauses: ROW FORMAT") + } + test("create table - with location") { val v1 = "CREATE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'" @@ -1153,38 +1191,119 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + test("Test CTAS against data source tables") { + val s1 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s2 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |LOCATION '/user/external/page_view' + |COMMENT 'This is the staging page view table' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s3 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + checkParsing(s3) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.provider == Some("parquet")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } + } + test("Test CTAS #1") { val s1 = - """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view |COMMENT 'This is the staging page view table' |STORED AS RCFILE |LOCATION '/user/external/page_view' |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin + |AS SELECT * FROM src + """.stripMargin - val (desc, exists) = extractTableDesc(s1) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - assert(desc.comment == Some("This is the staging page view table")) - // TODO will be SQLText - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) - assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - assert(desc.storage.serde == - Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + val s2 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |STORED AS RCFILE + |COMMENT 'This is the staging page view table' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |LOCATION '/user/external/page_view' + |AS SELECT * FROM src + """.stripMargin + + val s3 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |LOCATION '/user/external/page_view' + |STORED AS RCFILE + |COMMENT 'This is the staging page view table' + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + checkParsing(s3) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + assert(desc.comment == Some("This is the staging page view table")) + // TODO will be SQLText + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(desc.storage.serde == + Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } } test("Test CTAS #2") { - val s2 = - """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + val s1 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view |COMMENT 'This is the staging page view table' |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' | STORED AS @@ -1192,26 +1311,45 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' |LOCATION '/user/external/page_view' |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin + |AS SELECT * FROM src + """.stripMargin - val (desc, exists) = extractTableDesc(s2) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - // TODO will be SQLText - assert(desc.comment == Some("This is the staging page view table")) - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.storage.properties == Map()) - assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) - assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) - assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + val s2 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' + | STORED AS + | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat' + | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' + |COMMENT 'This is the staging page view table' + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + // TODO will be SQLText + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.storage.properties == Map()) + assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) + assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) + assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } } test("Test CTAS #3") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index fdb9b2f51f9cb..591510c1d8283 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1971,8 +1971,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { s""" |CREATE TABLE t(a int, b int, c int, d int) |USING parquet - |PARTITIONED BY(a, b) |LOCATION "${dir.toURI}" + |PARTITIONED BY(a, b) """.stripMargin) spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4") checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index f2e0c695ca38b..65be244418670 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -875,12 +875,13 @@ class HiveDDLSuite test("desc table for Hive table - bucketed + sorted table") { withTable("tbl") { - sql(s""" - CREATE TABLE tbl (id int, name string) - PARTITIONED BY (ds string) - CLUSTERED BY(id) - SORTED BY(id, name) INTO 1024 BUCKETS - """) + sql( + s""" + |CREATE TABLE tbl (id int, name string) + |CLUSTERED BY(id) + |SORTED BY(id, name) INTO 1024 BUCKETS + |PARTITIONED BY (ds string) + """.stripMargin) val x = sql("DESC FORMATTED tbl").collect() assert(x.containsSlice( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 07ae3ae945848..47adc77a52d51 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -461,51 +461,55 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("CTAS without serde without location") { - val originalConf = sessionState.conf.convertCTAS - - setConf(SQLConf.CONVERT_CTAS, true) - - val defaultDataSource = sessionState.conf.defaultDataSourceName - try { - sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - val message = intercept[AnalysisException] { + withSQLConf(SQLConf.CONVERT_CTAS.key -> "true") { + val defaultDataSource = sessionState.conf.defaultDataSourceName + withTable("ctas1") { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - }.getMessage - assert(message.contains("already exists")) - checkRelation("ctas1", true, defaultDataSource) - sql("DROP TABLE ctas1") + sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + val message = intercept[AnalysisException] { + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + }.getMessage + assert(message.contains("already exists")) + checkRelation("ctas1", isDataSourceTable = true, defaultDataSource) + } // Specifying database name for query can be converted to data source write path // is not allowed right now. - sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = true, defaultDataSource) + } - sql("CREATE TABLE ctas1 stored as textfile" + + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as textfile" + " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "text") - sql("DROP TABLE ctas1") + checkRelation("ctas1", isDataSourceTable = false, "text") + } - sql("CREATE TABLE ctas1 stored as sequencefile" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "sequence") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as sequencefile" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "sequence") + } - sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "rcfile") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "rcfile") + } - sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "orc") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "orc") + } - sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "parquet") - sql("DROP TABLE ctas1") - } finally { - setConf(SQLConf.CONVERT_CTAS, originalConf) - sql("DROP TABLE IF EXISTS ctas1") + withTable("ctas1") { + sql( + """ + |CREATE TABLE ctas1 stored as parquet + |AS SELECT key k, value FROM src ORDER BY k, value + """.stripMargin) + checkRelation("ctas1", isDataSourceTable = false, "parquet") + } } } @@ -539,30 +543,40 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { val defaultDataSource = sessionState.conf.defaultDataSourceName val tempLocation = dir.toURI.getPath.stripSuffix("/") - sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c1'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource, Some(s"file:$tempLocation/c1")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c1'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = true, defaultDataSource, Some(s"file:$tempLocation/c1")) + } - sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c2'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource, Some(s"file:$tempLocation/c2")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c2'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = true, defaultDataSource, Some(s"file:$tempLocation/c2")) + } - sql(s"CREATE TABLE ctas1 stored as textfile LOCATION 'file:$tempLocation/c3'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "text", Some(s"file:$tempLocation/c3")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as textfile LOCATION 'file:$tempLocation/c3'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "text", Some(s"file:$tempLocation/c3")) + } - sql(s"CREATE TABLE ctas1 stored as sequenceFile LOCATION 'file:$tempLocation/c4'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "sequence", Some(s"file:$tempLocation/c4")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as sequenceFile LOCATION 'file:$tempLocation/c4'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "sequence", Some(s"file:$tempLocation/c4")) + } - sql(s"CREATE TABLE ctas1 stored as rcfile LOCATION 'file:$tempLocation/c5'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "rcfile", Some(s"file:$tempLocation/c5")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as rcfile LOCATION 'file:$tempLocation/c5'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "rcfile", Some(s"file:$tempLocation/c5")) + } } } } From a66fe36cee9363b01ee70e469f1c968f633c5713 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 3 Jan 2018 22:18:13 +0800 Subject: [PATCH 0005/2461] [SPARK-20236][SQL] dynamic partition overwrite ## What changes were proposed in this pull request? When overwriting a partitioned table with dynamic partition columns, the behavior is different between data source and hive tables. data source table: delete all partition directories that match the static partition values provided in the insert statement. hive table: only delete partition directories which have data written into it This PR adds a new config to make users be able to choose hive's behavior. ## How was this patch tested? new tests Author: Wenchen Fan Closes #18714 from cloud-fan/overwrite-partition. --- .../internal/io/FileCommitProtocol.scala | 25 ++++-- .../io/HadoopMapReduceCommitProtocol.scala | 75 ++++++++++++++---- .../apache/spark/sql/internal/SQLConf.scala | 21 +++++ .../InsertIntoHadoopFsRelationCommand.scala | 20 ++++- .../SQLHadoopMapReduceCommitProtocol.scala | 10 ++- .../spark/sql/sources/InsertSuite.scala | 78 +++++++++++++++++++ 6 files changed, 200 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index 50f51e1af4530..6d0059b6a0272 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -28,8 +28,9 @@ import org.apache.spark.util.Utils * * 1. Implementations must be serializable, as the committer instance instantiated on the driver * will be used for tasks on executors. - * 2. Implementations should have a constructor with 2 arguments: - * (jobId: String, path: String) + * 2. Implementations should have a constructor with 2 or 3 arguments: + * (jobId: String, path: String) or + * (jobId: String, path: String, dynamicPartitionOverwrite: Boolean) * 3. A committer should not be reused across multiple Spark jobs. * * The proper call sequence is: @@ -139,10 +140,22 @@ object FileCommitProtocol { /** * Instantiates a FileCommitProtocol using the given className. */ - def instantiate(className: String, jobId: String, outputPath: String) - : FileCommitProtocol = { + def instantiate( + className: String, + jobId: String, + outputPath: String, + dynamicPartitionOverwrite: Boolean = false): FileCommitProtocol = { val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]] - val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) - ctor.newInstance(jobId, outputPath) + // First try the constructor with arguments (jobId: String, outputPath: String, + // dynamicPartitionOverwrite: Boolean). + // If that doesn't exist, try the one with (jobId: string, outputPath: String). + try { + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean]) + ctor.newInstance(jobId, outputPath, dynamicPartitionOverwrite.asInstanceOf[java.lang.Boolean]) + } catch { + case _: NoSuchMethodException => + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) + ctor.newInstance(jobId, outputPath) + } } } diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 95c99d29c3a9c..6d20ef1f98a3c 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -39,8 +39,19 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil * * @param jobId the job's or stage's id * @param path the job's output path, or null if committer acts as a noop + * @param dynamicPartitionOverwrite If true, Spark will overwrite partition directories at runtime + * dynamically, i.e., we first write files under a staging + * directory with partition path, e.g. + * /path/to/staging/a=1/b=1/xxx.parquet. When committing the job, + * we first clean up the corresponding partition directories at + * destination path, e.g. /path/to/destination/a=1/b=1, and move + * files from staging directory to the corresponding partition + * directories under destination path. */ -class HadoopMapReduceCommitProtocol(jobId: String, path: String) +class HadoopMapReduceCommitProtocol( + jobId: String, + path: String, + dynamicPartitionOverwrite: Boolean = false) extends FileCommitProtocol with Serializable with Logging { import FileCommitProtocol._ @@ -67,9 +78,17 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) @transient private var addedAbsPathFiles: mutable.Map[String, String] = null /** - * The staging directory for all files committed with absolute output paths. + * Tracks partitions with default path that have new files written into them by this task, + * e.g. a=1/b=2. Files under these partitions will be saved into staging directory and moved to + * destination directory at the end, if `dynamicPartitionOverwrite` is true. */ - private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId) + @transient private var partitionPaths: mutable.Set[String] = null + + /** + * The staging directory of this write job. Spark uses it to deal with files with absolute output + * path, or writing data into partitioned directory with dynamicPartitionOverwrite=true. + */ + private def stagingDir = new Path(path, ".spark-staging-" + jobId) protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { val format = context.getOutputFormatClass.newInstance() @@ -85,11 +104,16 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { val filename = getFilename(taskContext, ext) - val stagingDir: String = committer match { + val stagingDir: Path = committer match { + case _ if dynamicPartitionOverwrite => + assert(dir.isDefined, + "The dataset to be written must be partitioned when dynamicPartitionOverwrite is true.") + partitionPaths += dir.get + this.stagingDir // For FileOutputCommitter it has its own staging path called "work path". case f: FileOutputCommitter => - Option(f.getWorkPath).map(_.toString).getOrElse(path) - case _ => path + new Path(Option(f.getWorkPath).map(_.toString).getOrElse(path)) + case _ => new Path(path) } dir.map { d => @@ -106,8 +130,7 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) // Include a UUID here to prevent file collisions for one task writing to different dirs. // In principle we could include hash(absoluteDir) instead but this is simpler. - val tmpOutputPath = new Path( - absPathStagingDir, UUID.randomUUID().toString() + "-" + filename).toString + val tmpOutputPath = new Path(stagingDir, UUID.randomUUID().toString() + "-" + filename).toString addedAbsPathFiles(tmpOutputPath) = absOutputPath tmpOutputPath @@ -141,23 +164,42 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = { committer.commitJob(jobContext) - val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]]) - .foldLeft(Map[String, String]())(_ ++ _) - logDebug(s"Committing files staged for absolute locations $filesToMove") + if (hasValidPath) { - val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + val (allAbsPathFiles, allPartitionPaths) = + taskCommits.map(_.obj.asInstanceOf[(Map[String, String], Set[String])]).unzip + val fs = stagingDir.getFileSystem(jobContext.getConfiguration) + + val filesToMove = allAbsPathFiles.foldLeft(Map[String, String]())(_ ++ _) + logDebug(s"Committing files staged for absolute locations $filesToMove") + if (dynamicPartitionOverwrite) { + val absPartitionPaths = filesToMove.values.map(new Path(_).getParent).toSet + logDebug(s"Clean up absolute partition directories for overwriting: $absPartitionPaths") + absPartitionPaths.foreach(fs.delete(_, true)) + } for ((src, dst) <- filesToMove) { fs.rename(new Path(src), new Path(dst)) } - fs.delete(absPathStagingDir, true) + + if (dynamicPartitionOverwrite) { + val partitionPaths = allPartitionPaths.foldLeft(Set[String]())(_ ++ _) + logDebug(s"Clean up default partition directories for overwriting: $partitionPaths") + for (part <- partitionPaths) { + val finalPartPath = new Path(path, part) + fs.delete(finalPartPath, true) + fs.rename(new Path(stagingDir, part), finalPartPath) + } + } + + fs.delete(stagingDir, true) } } override def abortJob(jobContext: JobContext): Unit = { committer.abortJob(jobContext, JobStatus.State.FAILED) if (hasValidPath) { - val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) - fs.delete(absPathStagingDir, true) + val fs = stagingDir.getFileSystem(jobContext.getConfiguration) + fs.delete(stagingDir, true) } } @@ -165,13 +207,14 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) committer = setupCommitter(taskContext) committer.setupTask(taskContext) addedAbsPathFiles = mutable.Map[String, String]() + partitionPaths = mutable.Set[String]() } override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { val attemptId = taskContext.getTaskAttemptID SparkHadoopMapRedUtil.commitTask( committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId) - new TaskCommitMessage(addedAbsPathFiles.toMap) + new TaskCommitMessage(addedAbsPathFiles.toMap -> partitionPaths.toSet) } override def abortTask(taskContext: TaskAttemptContext): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 80cdc61484c0f..5d6edf6b8abec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1068,6 +1068,24 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(100) + object PartitionOverwriteMode extends Enumeration { + val STATIC, DYNAMIC = Value + } + + val PARTITION_OVERWRITE_MODE = + buildConf("spark.sql.sources.partitionOverwriteMode") + .doc("When INSERT OVERWRITE a partitioned data source table, we currently support 2 modes: " + + "static and dynamic. In static mode, Spark deletes all the partitions that match the " + + "partition specification(e.g. PARTITION(a=1,b)) in the INSERT statement, before " + + "overwriting. In dynamic mode, Spark doesn't delete partitions ahead, and only overwrite " + + "those partitions that have data written into it at runtime. By default we use static " + + "mode to keep the same behavior of Spark prior to 2.3. Note that this config doesn't " + + "affect Hive serde tables, as they are always overwritten with dynamic mode.") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(PartitionOverwriteMode.values.map(_.toString)) + .createWithDefault(PartitionOverwriteMode.STATIC.toString) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1394,6 +1412,9 @@ class SQLConf extends Serializable with Logging { def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) + def partitionOverwriteMode: PartitionOverwriteMode.Value = + PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index ad24e280d942a..dd7ef0d15c140 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.util.SchemaUtils /** @@ -89,13 +90,19 @@ case class InsertIntoHadoopFsRelationCommand( } val pathExists = fs.exists(qualifiedOutputPath) - // If we are appending data to an existing dir. - val isAppend = pathExists && (mode == SaveMode.Append) + + val enableDynamicOverwrite = + sparkSession.sessionState.conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC + // This config only makes sense when we are overwriting a partitioned dataset with dynamic + // partition columns. + val dynamicPartitionOverwrite = enableDynamicOverwrite && mode == SaveMode.Overwrite && + staticPartitions.size < partitionColumns.length val committer = FileCommitProtocol.instantiate( sparkSession.sessionState.conf.fileCommitProtocolClass, jobId = java.util.UUID.randomUUID().toString, - outputPath = outputPath.toString) + outputPath = outputPath.toString, + dynamicPartitionOverwrite = dynamicPartitionOverwrite) val doInsertion = (mode, pathExists) match { case (SaveMode.ErrorIfExists, true) => @@ -103,6 +110,9 @@ case class InsertIntoHadoopFsRelationCommand( case (SaveMode.Overwrite, true) => if (ifPartitionNotExists && matchingPartitions.nonEmpty) { false + } else if (dynamicPartitionOverwrite) { + // For dynamic partition overwrite, do not delete partition directories ahead. + true } else { deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer) true @@ -126,7 +136,9 @@ case class InsertIntoHadoopFsRelationCommand( catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)), ifNotExists = true).run(sparkSession) } - if (mode == SaveMode.Overwrite) { + // For dynamic partition overwrite, we never remove partitions but only update existing + // ones. + if (mode == SaveMode.Overwrite && !dynamicPartitionOverwrite) { val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions if (deletedPartitions.nonEmpty) { AlterTableDropPartitionCommand( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala index 40825a1f724b1..39c594a9bc618 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala @@ -29,11 +29,15 @@ import org.apache.spark.sql.internal.SQLConf * A variant of [[HadoopMapReduceCommitProtocol]] that allows specifying the actual * Hadoop output committer using an option specified in SQLConf. */ -class SQLHadoopMapReduceCommitProtocol(jobId: String, path: String) - extends HadoopMapReduceCommitProtocol(jobId, path) with Serializable with Logging { +class SQLHadoopMapReduceCommitProtocol( + jobId: String, + path: String, + dynamicPartitionOverwrite: Boolean = false) + extends HadoopMapReduceCommitProtocol(jobId, path, dynamicPartitionOverwrite) + with Serializable with Logging { override protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { - var committer = context.getOutputFormatClass.newInstance().getOutputCommitter(context) + var committer = super.setupCommitter(context) val configuration = context.getConfiguration val clazz = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 8b7e2e5f45946..fef01c860db6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -21,6 +21,8 @@ import java.io.File import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -442,4 +444,80 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { assert(e.contains("Only Data Sources providing FileFormat are supported")) } } + + test("SPARK-20236: dynamic partition overwrite without catalog table") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTempPath { path => + Seq((1, 1, 1)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(1, 1, 1)) + + Seq((2, 1, 1)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").mode("overwrite").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(2, 1, 1)) + + Seq((2, 2, 2)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").mode("overwrite").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + } + } + } + + test("SPARK-20236: dynamic partition overwrite") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTable("t") { + sql( + """ + |create table t(i int, part1 int, part2 int) using parquet + |partitioned by (part1, part2) + """.stripMargin) + + sql("insert into t partition(part1=1, part2=1) select 1") + checkAnswer(spark.table("t"), Row(1, 1, 1)) + + sql("insert overwrite table t partition(part1=1, part2=1) select 2") + checkAnswer(spark.table("t"), Row(2, 1, 1)) + + sql("insert overwrite table t partition(part1=2, part2) select 2, 2") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2=2) select 3") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2) select 4, 1") + checkAnswer(spark.table("t"), Row(4, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + } + } + } + + test("SPARK-20236: dynamic partition overwrite with customer partition path") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTable("t") { + sql( + """ + |create table t(i int, part1 int, part2 int) using parquet + |partitioned by (part1, part2) + """.stripMargin) + + val path1 = Utils.createTempDir() + sql(s"alter table t add partition(part1=1, part2=1) location '$path1'") + sql(s"insert into t partition(part1=1, part2=1) select 1") + checkAnswer(spark.table("t"), Row(1, 1, 1)) + + sql("insert overwrite table t partition(part1=1, part2=1) select 2") + checkAnswer(spark.table("t"), Row(2, 1, 1)) + + sql("insert overwrite table t partition(part1=2, part2) select 2, 2") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + + val path2 = Utils.createTempDir() + sql(s"alter table t add partition(part1=1, part2=2) location '$path2'") + sql("insert overwrite table t partition(part1=1, part2=2) select 3") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2) select 4, 1") + checkAnswer(spark.table("t"), Row(4, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + } + } + } } From 9a2b65a3c0c36316aae0a53aa0f61c5044c2ceff Mon Sep 17 00:00:00 2001 From: chetkhatri Date: Wed, 3 Jan 2018 11:31:32 -0600 Subject: [PATCH 0006/2461] [SPARK-22896] Improvement in String interpolation ## What changes were proposed in this pull request? * String interpolation in ml pipeline example has been corrected as per scala standard. ## How was this patch tested? * manually tested. Author: chetkhatri Closes #20070 from chetkhatri/mllib-chetan-contrib. --- .../spark/examples/ml/JavaQuantileDiscretizerExample.java | 2 +- .../apache/spark/examples/SimpleSkewedGroupByTest.scala | 4 ---- .../org/apache/spark/examples/graphx/Analytics.scala | 6 ++++-- .../org/apache/spark/examples/graphx/SynthBenchmark.scala | 6 +++--- .../apache/spark/examples/ml/ChiSquareTestExample.scala | 6 +++--- .../org/apache/spark/examples/ml/CorrelationExample.scala | 4 ++-- .../org/apache/spark/examples/ml/DataFrameExample.scala | 4 ++-- .../examples/ml/DecisionTreeClassificationExample.scala | 4 ++-- .../spark/examples/ml/DecisionTreeRegressionExample.scala | 4 ++-- .../apache/spark/examples/ml/DeveloperApiExample.scala | 6 +++--- .../examples/ml/EstimatorTransformerParamExample.scala | 6 +++--- .../ml/GradientBoostedTreeClassifierExample.scala | 4 ++-- .../examples/ml/GradientBoostedTreeRegressorExample.scala | 4 ++-- ...ulticlassLogisticRegressionWithElasticNetExample.scala | 2 +- .../ml/MultilayerPerceptronClassifierExample.scala | 2 +- .../org/apache/spark/examples/ml/NaiveBayesExample.scala | 2 +- .../spark/examples/ml/QuantileDiscretizerExample.scala | 4 ++-- .../spark/examples/ml/RandomForestClassifierExample.scala | 4 ++-- .../spark/examples/ml/RandomForestRegressorExample.scala | 4 ++-- .../apache/spark/examples/ml/VectorIndexerExample.scala | 4 ++-- .../spark/examples/mllib/AssociationRulesExample.scala | 6 +++--- .../mllib/BinaryClassificationMetricsExample.scala | 4 ++-- .../mllib/DecisionTreeClassificationExample.scala | 4 ++-- .../examples/mllib/DecisionTreeRegressionExample.scala | 4 ++-- .../org/apache/spark/examples/mllib/FPGrowthExample.scala | 2 +- .../mllib/GradientBoostingClassificationExample.scala | 4 ++-- .../mllib/GradientBoostingRegressionExample.scala | 4 ++-- .../spark/examples/mllib/HypothesisTestingExample.scala | 2 +- .../spark/examples/mllib/IsotonicRegressionExample.scala | 2 +- .../org/apache/spark/examples/mllib/KMeansExample.scala | 2 +- .../org/apache/spark/examples/mllib/LBFGSExample.scala | 2 +- .../examples/mllib/LatentDirichletAllocationExample.scala | 8 +++++--- .../examples/mllib/LinearRegressionWithSGDExample.scala | 2 +- .../org/apache/spark/examples/mllib/PCAExample.scala | 4 ++-- .../spark/examples/mllib/PMMLModelExportExample.scala | 2 +- .../apache/spark/examples/mllib/PrefixSpanExample.scala | 4 ++-- .../mllib/RandomForestClassificationExample.scala | 4 ++-- .../examples/mllib/RandomForestRegressionExample.scala | 4 ++-- .../spark/examples/mllib/RecommendationExample.scala | 2 +- .../apache/spark/examples/mllib/SVMWithSGDExample.scala | 2 +- .../org/apache/spark/examples/mllib/SimpleFPGrowth.scala | 8 +++----- .../spark/examples/mllib/StratifiedSamplingExample.scala | 4 ++-- .../org/apache/spark/examples/mllib/TallSkinnyPCA.scala | 2 +- .../org/apache/spark/examples/mllib/TallSkinnySVD.scala | 2 +- .../apache/spark/examples/streaming/CustomReceiver.scala | 6 +++--- .../apache/spark/examples/streaming/RawNetworkGrep.scala | 2 +- .../examples/streaming/RecoverableNetworkWordCount.scala | 8 ++++---- .../streaming/clickstream/PageViewGenerator.scala | 4 ++-- .../examples/streaming/clickstream/PageViewStream.scala | 4 ++-- 49 files changed, 94 insertions(+), 96 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java index dd20cac621102..43cc30c1a899b 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java @@ -66,7 +66,7 @@ public static void main(String[] args) { .setNumBuckets(3); Dataset result = discretizer.fit(df).transform(df); - result.show(); + result.show(false); // $example off$ spark.stop(); } diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index e64dcbd182d94..2332a661f26a0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -60,10 +60,6 @@ object SimpleSkewedGroupByTest { pairs1.count println(s"RESULT: ${pairs1.groupByKey(numReducers).count}") - // Print how many keys each reducer got (for debugging) - // println("RESULT: " + pairs1.groupByKey(numReducers) - // .map{case (k,v) => (k, v.size)} - // .collectAsMap) spark.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index 92936bd30dbc0..815404d1218b7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -145,9 +145,11 @@ object Analytics extends Logging { // TriangleCount requires the graph to be partitioned .partitionBy(partitionStrategy.getOrElse(RandomVertexCut)).cache() val triangles = TriangleCount.run(graph) - println("Triangles: " + triangles.vertices.map { + val triangleTypes = triangles.vertices.map { case (vid, data) => data.toLong - }.reduce(_ + _) / 3) + }.reduce(_ + _) / 3 + + println(s"Triangles: ${triangleTypes}") sc.stop() case _ => diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala index 6d2228c8742aa..57b2edf992208 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala @@ -52,7 +52,7 @@ object SynthBenchmark { arg => arg.dropWhile(_ == '-').split('=') match { case Array(opt, v) => (opt -> v) - case _ => throw new IllegalArgumentException("Invalid argument: " + arg) + case _ => throw new IllegalArgumentException(s"Invalid argument: $arg") } } @@ -76,7 +76,7 @@ object SynthBenchmark { case ("sigma", v) => sigma = v.toDouble case ("degFile", v) => degFile = v case ("seed", v) => seed = v.toInt - case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt) + case (opt, _) => throw new IllegalArgumentException(s"Invalid option: $opt") } val conf = new SparkConf() @@ -86,7 +86,7 @@ object SynthBenchmark { val sc = new SparkContext(conf) // Create the graph - println(s"Creating graph...") + println("Creating graph...") val unpartitionedGraph = GraphGenerators.logNormalGraph(sc, numVertices, numEPart.getOrElse(sc.defaultParallelism), mu, sigma, seed) // Repartition the graph diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala index dcee1e427ce58..5146fd0316467 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala @@ -52,9 +52,9 @@ object ChiSquareTestExample { val df = data.toDF("label", "features") val chi = ChiSquareTest.test(df, "features", "label").head - println("pValues = " + chi.getAs[Vector](0)) - println("degreesOfFreedom = " + chi.getSeq[Int](1).mkString("[", ",", "]")) - println("statistics = " + chi.getAs[Vector](2)) + println(s"pValues = ${chi.getAs[Vector](0)}") + println(s"degreesOfFreedom ${chi.getSeq[Int](1).mkString("[", ",", "]")}") + println(s"statistics ${chi.getAs[Vector](2)}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala index 3f57dc342eb00..d7f1fc8ed74d7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala @@ -51,10 +51,10 @@ object CorrelationExample { val df = data.map(Tuple1.apply).toDF("features") val Row(coeff1: Matrix) = Correlation.corr(df, "features").head - println("Pearson correlation matrix:\n" + coeff1.toString) + println(s"Pearson correlation matrix:\n $coeff1") val Row(coeff2: Matrix) = Correlation.corr(df, "features", "spearman").head - println("Spearman correlation matrix:\n" + coeff2.toString) + println(s"Spearman correlation matrix:\n $coeff2") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index 0658bddf16961..ee4469faab3a0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -47,7 +47,7 @@ object DataFrameExample { val parser = new OptionParser[Params]("DataFrameExample") { head("DataFrameExample: an example app using DataFrame for ML.") opt[String]("input") - .text(s"input path to dataframe") + .text("input path to dataframe") .action((x, c) => c.copy(input = x)) checkConfig { params => success @@ -93,7 +93,7 @@ object DataFrameExample { // Load the records back. println(s"Loading Parquet file with UDT from $outputDir.") val newDF = spark.read.parquet(outputDir) - println(s"Schema from Parquet:") + println("Schema from Parquet:") newDF.printSchema() spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala index bc6d3275933ea..276cedab11abc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala @@ -83,10 +83,10 @@ object DecisionTreeClassificationExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test Error = " + (1.0 - accuracy)) + println(s"Test Error = ${(1.0 - accuracy)}") val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] - println("Learned classification tree model:\n" + treeModel.toDebugString) + println(s"Learned classification tree model:\n ${treeModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala index ee61200ad1d0c..aaaecaea47081 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala @@ -73,10 +73,10 @@ object DecisionTreeRegressionExample { .setPredictionCol("prediction") .setMetricName("rmse") val rmse = evaluator.evaluate(predictions) - println("Root Mean Squared Error (RMSE) on test data = " + rmse) + println(s"Root Mean Squared Error (RMSE) on test data = $rmse") val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel] - println("Learned regression tree model:\n" + treeModel.toDebugString) + println(s"Learned regression tree model:\n ${treeModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index d94d837d10e96..2dc11b07d88ef 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -53,7 +53,7 @@ object DeveloperApiExample { // Create a LogisticRegression instance. This instance is an Estimator. val lr = new MyLogisticRegression() // Print out the parameters, documentation, and any default values. - println("MyLogisticRegression parameters:\n" + lr.explainParams() + "\n") + println(s"MyLogisticRegression parameters:\n ${lr.explainParams()}") // We may set parameters using setter methods. lr.setMaxIter(10) @@ -169,10 +169,10 @@ private class MyLogisticRegressionModel( Vectors.dense(-margin, margin) } - /** Number of classes the label can take. 2 indicates binary classification. */ + // Number of classes the label can take. 2 indicates binary classification. override val numClasses: Int = 2 - /** Number of features the model was trained on. */ + // Number of features the model was trained on. override val numFeatures: Int = coefficients.size /** diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala index f18d86e1a6921..e5d91f132a3f2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala @@ -46,7 +46,7 @@ object EstimatorTransformerParamExample { // Create a LogisticRegression instance. This instance is an Estimator. val lr = new LogisticRegression() // Print out the parameters, documentation, and any default values. - println("LogisticRegression parameters:\n" + lr.explainParams() + "\n") + println(s"LogisticRegression parameters:\n ${lr.explainParams()}\n") // We may set parameters using setter methods. lr.setMaxIter(10) @@ -58,7 +58,7 @@ object EstimatorTransformerParamExample { // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this // LogisticRegression instance. - println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) + println(s"Model 1 was fit using parameters: ${model1.parent.extractParamMap}") // We may alternatively specify parameters using a ParamMap, // which supports several methods for specifying parameters. @@ -73,7 +73,7 @@ object EstimatorTransformerParamExample { // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. val model2 = lr.fit(training, paramMapCombined) - println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) + println(s"Model 2 was fit using parameters: ${model2.parent.extractParamMap}") // Prepare test data. val test = spark.createDataFrame(Seq( diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala index 3656773c8b817..ef78c0a1145ef 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala @@ -86,10 +86,10 @@ object GradientBoostedTreeClassifierExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test Error = " + (1.0 - accuracy)) + println(s"Test Error = ${1.0 - accuracy}") val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel] - println("Learned classification GBT model:\n" + gbtModel.toDebugString) + println(s"Learned classification GBT model:\n ${gbtModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala index e53aab7f326d3..3feb2343f6a85 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala @@ -73,10 +73,10 @@ object GradientBoostedTreeRegressorExample { .setPredictionCol("prediction") .setMetricName("rmse") val rmse = evaluator.evaluate(predictions) - println("Root Mean Squared Error (RMSE) on test data = " + rmse) + println(s"Root Mean Squared Error (RMSE) on test data = $rmse") val gbtModel = model.stages(1).asInstanceOf[GBTRegressionModel] - println("Learned regression GBT model:\n" + gbtModel.toDebugString) + println(s"Learned regression GBT model:\n ${gbtModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala index 42f0ace7a353d..3e61dbe628c20 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala @@ -48,7 +48,7 @@ object MulticlassLogisticRegressionWithElasticNetExample { // Print the coefficients and intercept for multinomial logistic regression println(s"Coefficients: \n${lrModel.coefficientMatrix}") - println(s"Intercepts: ${lrModel.interceptVector}") + println(s"Intercepts: \n${lrModel.interceptVector}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala index 6fce82d294f8d..646f46a925062 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala @@ -66,7 +66,7 @@ object MultilayerPerceptronClassifierExample { val evaluator = new MulticlassClassificationEvaluator() .setMetricName("accuracy") - println("Test set accuracy = " + evaluator.evaluate(predictionAndLabels)) + println(s"Test set accuracy = ${evaluator.evaluate(predictionAndLabels)}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala index bd9fcc420a66c..50c70c626b128 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala @@ -52,7 +52,7 @@ object NaiveBayesExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test set accuracy = " + accuracy) + println(s"Test set accuracy = $accuracy") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala index aedb9e7d3bb70..0fe16fb6dfa9f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala @@ -36,7 +36,7 @@ object QuantileDiscretizerExample { // Output of QuantileDiscretizer for such small datasets can depend on the number of // partitions. Here we force a single partition to ensure consistent results. // Note this is not necessary for normal use cases - .repartition(1) + .repartition(1) // $example on$ val discretizer = new QuantileDiscretizer() @@ -45,7 +45,7 @@ object QuantileDiscretizerExample { .setNumBuckets(3) val result = discretizer.fit(df).transform(df) - result.show() + result.show(false) // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala index 5eafda8ce4285..6265f83902528 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala @@ -85,10 +85,10 @@ object RandomForestClassifierExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test Error = " + (1.0 - accuracy)) + println(s"Test Error = ${(1.0 - accuracy)}") val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel] - println("Learned classification forest model:\n" + rfModel.toDebugString) + println(s"Learned classification forest model:\n ${rfModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala index 9a0a001c26ef5..2679fcb353a8a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala @@ -72,10 +72,10 @@ object RandomForestRegressorExample { .setPredictionCol("prediction") .setMetricName("rmse") val rmse = evaluator.evaluate(predictions) - println("Root Mean Squared Error (RMSE) on test data = " + rmse) + println(s"Root Mean Squared Error (RMSE) on test data = $rmse") val rfModel = model.stages(1).asInstanceOf[RandomForestRegressionModel] - println("Learned regression forest model:\n" + rfModel.toDebugString) + println(s"Learned regression forest model:\n ${rfModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala index afa761aee0b98..96bb8ea2338af 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala @@ -41,8 +41,8 @@ object VectorIndexerExample { val indexerModel = indexer.fit(data) val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet - println(s"Chose ${categoricalFeatures.size} categorical features: " + - categoricalFeatures.mkString(", ")) + println(s"Chose ${categoricalFeatures.size} " + + s"categorical features: ${categoricalFeatures.mkString(", ")}") // Create new column "indexed" with categorical values transformed to indices val indexedData = indexerModel.transform(data) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala index ff44de56839e5..a07535bb5a38d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala @@ -42,9 +42,8 @@ object AssociationRulesExample { val results = ar.run(freqItemsets) results.collect().foreach { rule => - println("[" + rule.antecedent.mkString(",") - + "=>" - + rule.consequent.mkString(",") + "]," + rule.confidence) + println(s"[${rule.antecedent.mkString(",")}=>${rule.consequent.mkString(",")} ]" + + s" ${rule.confidence}") } // $example off$ @@ -53,3 +52,4 @@ object AssociationRulesExample { } // scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala index b9263ac6fcff6..c6312d71cc912 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala @@ -86,7 +86,7 @@ object BinaryClassificationMetricsExample { // AUPRC val auPRC = metrics.areaUnderPR - println("Area under precision-recall curve = " + auPRC) + println(s"Area under precision-recall curve = $auPRC") // Compute thresholds used in ROC and PR curves val thresholds = precision.map(_._1) @@ -96,7 +96,7 @@ object BinaryClassificationMetricsExample { // AUROC val auROC = metrics.areaUnderROC - println("Area under ROC = " + auROC) + println(s"Area under ROC = $auROC") // $example off$ sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala index b50b4592777ce..c2f89b72c9a2e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala @@ -55,8 +55,8 @@ object DecisionTreeClassificationExample { (point.label, prediction) } val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count() - println("Test Error = " + testErr) - println("Learned classification tree model:\n" + model.toDebugString) + println(s"Test Error = $testErr") + println(s"Learned classification tree model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myDecisionTreeClassificationModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala index 2af45afae3d5b..1ecf6426e1f95 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala @@ -54,8 +54,8 @@ object DecisionTreeRegressionExample { (point.label, prediction) } val testMSE = labelsAndPredictions.map{ case (v, p) => math.pow(v - p, 2) }.mean() - println("Test Mean Squared Error = " + testMSE) - println("Learned regression tree model:\n" + model.toDebugString) + println(s"Test Mean Squared Error = $testMSE") + println(s"Learned regression tree model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myDecisionTreeRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala index 6435abc127752..f724ee1030f04 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala @@ -74,7 +74,7 @@ object FPGrowthExample { println(s"Number of frequent itemsets: ${model.freqItemsets.count()}") model.freqItemsets.collect().foreach { itemset => - println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) + println(s"${itemset.items.mkString("[", ",", "]")}, ${itemset.freq}") } sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala index 00bb3348d2a36..3c56e1941aeca 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala @@ -54,8 +54,8 @@ object GradientBoostingClassificationExample { (point.label, prediction) } val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() - println("Test Error = " + testErr) - println("Learned classification GBT model:\n" + model.toDebugString) + println(s"Test Error = $testErr") + println(s"Learned classification GBT model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myGradientBoostingClassificationModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala index d8c263460839b..c288bf29bf255 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala @@ -53,8 +53,8 @@ object GradientBoostingRegressionExample { (point.label, prediction) } val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() - println("Test Mean Squared Error = " + testMSE) - println("Learned regression GBT model:\n" + model.toDebugString) + println(s"Test Mean Squared Error = $testMSE") + println(s"Learned regression GBT model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myGradientBoostingRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala index 0d391a3637c07..add1719739539 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala @@ -68,7 +68,7 @@ object HypothesisTestingExample { // against the label. val featureTestResults: Array[ChiSqTestResult] = Statistics.chiSqTest(obs) featureTestResults.zipWithIndex.foreach { case (k, v) => - println("Column " + (v + 1).toString + ":") + println(s"Column ${(v + 1)} :") println(k) } // summary of the test // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala index 4aee951f5b04c..a10d6f0dda880 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala @@ -56,7 +56,7 @@ object IsotonicRegressionExample { // Calculate mean squared error between predicted and real labels. val meanSquaredError = predictionAndLabel.map { case (p, l) => math.pow((p - l), 2) }.mean() - println("Mean Squared Error = " + meanSquaredError) + println(s"Mean Squared Error = $meanSquaredError") // Save and load model model.save(sc, "target/tmp/myIsotonicRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala index c4d71d862f375..b0a6f1671a898 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala @@ -43,7 +43,7 @@ object KMeansExample { // Evaluate clustering by computing Within Set Sum of Squared Errors val WSSSE = clusters.computeCost(parsedData) - println("Within Set Sum of Squared Errors = " + WSSSE) + println(s"Within Set Sum of Squared Errors = $WSSSE") // Save and load model clusters.save(sc, "target/org/apache/spark/KMeansExample/KMeansModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala index fedcefa098381..123782fa6b9cf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala @@ -82,7 +82,7 @@ object LBFGSExample { println("Loss of each step in training process") loss.foreach(println) - println("Area under ROC = " + auROC) + println(s"Area under ROC = $auROC") // $example off$ sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala index f2c8ec01439f1..d25962c5500ed 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala @@ -42,11 +42,13 @@ object LatentDirichletAllocationExample { val ldaModel = new LDA().setK(3).run(corpus) // Output topics. Each is a distribution over words (matching word count vectors) - println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize + " words):") + println(s"Learned topics (as distributions over vocab of ${ldaModel.vocabSize} words):") val topics = ldaModel.topicsMatrix for (topic <- Range(0, 3)) { - print("Topic " + topic + ":") - for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); } + print(s"Topic $topic :") + for (word <- Range(0, ldaModel.vocabSize)) { + print(s"${topics(word, topic)}") + } println() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala index d399618094487..449b725d1d173 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala @@ -52,7 +52,7 @@ object LinearRegressionWithSGDExample { (point.label, prediction) } val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2) }.mean() - println("training Mean Squared Error = " + MSE) + println(s"training Mean Squared Error $MSE") // Save and load model model.save(sc, "target/tmp/scalaLinearRegressionWithSGDModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala index eb36697d94ba1..eff2393cc3abe 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala @@ -65,8 +65,8 @@ object PCAExample { val MSE = valuesAndPreds.map { case (v, p) => math.pow((v - p), 2) }.mean() val MSE_pca = valuesAndPreds_pca.map { case (v, p) => math.pow((v - p), 2) }.mean() - println("Mean Squared Error = " + MSE) - println("PCA Mean Squared Error = " + MSE_pca) + println(s"Mean Squared Error = $MSE") + println(s"PCA Mean Squared Error = $MSE_pca") // $example off$ sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala index d74d74a37fb11..96deafd469bc7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala @@ -41,7 +41,7 @@ object PMMLModelExportExample { val clusters = KMeans.train(parsedData, numClusters, numIterations) // Export to PMML to a String in PMML format - println("PMML Model:\n" + clusters.toPMML) + println(s"PMML Model:\n ${clusters.toPMML}") // Export the model to a local file in PMML format clusters.toPMML("/tmp/kmeans.xml") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala index 69c72c4336576..8b789277774af 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala @@ -42,8 +42,8 @@ object PrefixSpanExample { val model = prefixSpan.run(sequences) model.freqSequences.collect().foreach { freqSequence => println( - freqSequence.sequence.map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]") + - ", " + freqSequence.freq) + s"${freqSequence.sequence.map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")}," + + s" ${freqSequence.freq}") } // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala index f1ebdf1a733ed..246e71de25615 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala @@ -55,8 +55,8 @@ object RandomForestClassificationExample { (point.label, prediction) } val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() - println("Test Error = " + testErr) - println("Learned classification forest model:\n" + model.toDebugString) + println(s"Test Error = $testErr") + println(s"Learned classification forest model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myRandomForestClassificationModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala index 11d612e651b4b..770e30276bc30 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala @@ -55,8 +55,8 @@ object RandomForestRegressionExample { (point.label, prediction) } val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() - println("Test Mean Squared Error = " + testMSE) - println("Learned regression forest model:\n" + model.toDebugString) + println(s"Test Mean Squared Error = $testMSE") + println(s"Learned regression forest model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myRandomForestRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala index 6df742d737e70..0bb2b8c8c2b43 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala @@ -56,7 +56,7 @@ object RecommendationExample { val err = (r1 - r2) err * err }.mean() - println("Mean Squared Error = " + MSE) + println(s"Mean Squared Error = $MSE") // Save and load model model.save(sc, "target/tmp/myCollaborativeFilter") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala index b73fe9b2b3faa..285e2ce512639 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala @@ -57,7 +57,7 @@ object SVMWithSGDExample { val metrics = new BinaryClassificationMetrics(scoreAndLabels) val auROC = metrics.areaUnderROC() - println("Area under ROC = " + auROC) + println(s"Area under ROC = $auROC") // Save and load model model.save(sc, "target/tmp/scalaSVMWithSGDModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala index b5c3033bcba09..694c3bb18b045 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala @@ -42,15 +42,13 @@ object SimpleFPGrowth { val model = fpg.run(transactions) model.freqItemsets.collect().foreach { itemset => - println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) + println(s"${itemset.items.mkString("[", ",", "]")},${itemset.freq}") } val minConfidence = 0.8 model.generateAssociationRules(minConfidence).collect().foreach { rule => - println( - rule.antecedent.mkString("[", ",", "]") - + " => " + rule.consequent .mkString("[", ",", "]") - + ", " + rule.confidence) + println(s"${rule.antecedent.mkString("[", ",", "]")}=> " + + s"${rule.consequent .mkString("[", ",", "]")},${rule.confidence}") } // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala index 16b074ef60699..3d41bef0af88c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala @@ -41,10 +41,10 @@ object StratifiedSamplingExample { val exactSample = data.sampleByKeyExact(withReplacement = false, fractions = fractions) // $example off$ - println("approxSample size is " + approxSample.collect().size.toString) + println(s"approxSample size is ${approxSample.collect().size}") approxSample.collect().foreach(println) - println("exactSample its size is " + exactSample.collect().size.toString) + println(s"exactSample its size is ${exactSample.collect().size}") exactSample.collect().foreach(println) sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala index 03bc675299c5a..071d341b81614 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala @@ -54,7 +54,7 @@ object TallSkinnyPCA { // Compute principal components. val pc = mat.computePrincipalComponents(mat.numCols().toInt) - println("Principal components are:\n" + pc) + println(s"Principal components are:\n $pc") sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala index 067e49b9599e7..8ae6de16d80e7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala @@ -54,7 +54,7 @@ object TallSkinnySVD { // Compute SVD. val svd = mat.computeSVD(mat.numCols().toInt) - println("Singular values are " + svd.s) + println(s"Singular values are ${svd.s}") sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala index 43044d01b1204..25c7bf2871972 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala @@ -82,9 +82,9 @@ class CustomReceiver(host: String, port: Int) var socket: Socket = null var userInput: String = null try { - logInfo("Connecting to " + host + ":" + port) + logInfo(s"Connecting to $host : $port") socket = new Socket(host, port) - logInfo("Connected to " + host + ":" + port) + logInfo(s"Connected to $host : $port") val reader = new BufferedReader( new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8)) userInput = reader.readLine() @@ -98,7 +98,7 @@ class CustomReceiver(host: String, port: Int) restart("Trying to connect again") } catch { case e: java.net.ConnectException => - restart("Error connecting to " + host + ":" + port, e) + restart(s"Error connecting to $host : $port", e) case t: Throwable => restart("Error receiving data", t) } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala index 5322929d177b4..437ccf0898d7c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala @@ -54,7 +54,7 @@ object RawNetworkGrep { ssc.rawSocketStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray val union = ssc.union(rawStreams) union.filter(_.contains("the")).count().foreachRDD(r => - println("Grep count: " + r.collect().mkString)) + println(s"Grep count: ${r.collect().mkString}")) ssc.start() ssc.awaitTermination() } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 49c0427321133..f018f3a26d2e9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -130,10 +130,10 @@ object RecoverableNetworkWordCount { true } }.collect().mkString("[", ", ", "]") - val output = "Counts at time " + time + " " + counts + val output = s"Counts at time $time $counts" println(output) - println("Dropped " + droppedWordsCounter.value + " word(s) totally") - println("Appending to " + outputFile.getAbsolutePath) + println(s"Dropped ${droppedWordsCounter.value} word(s) totally") + println(s"Appending to ${outputFile.getAbsolutePath}") Files.append(output + "\n", outputFile, Charset.defaultCharset()) } ssc @@ -141,7 +141,7 @@ object RecoverableNetworkWordCount { def main(args: Array[String]) { if (args.length != 4) { - System.err.println("Your arguments were " + args.mkString("[", ", ", "]")) + System.err.println(s"Your arguments were ${args.mkString("[", ", ", "]")}") System.err.println( """ |Usage: RecoverableNetworkWordCount diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 0ddd065f0db2b..2108bc63edea2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -90,13 +90,13 @@ object PageViewGenerator { val viewsPerSecond = args(1).toFloat val sleepDelayMs = (1000.0 / viewsPerSecond).toInt val listener = new ServerSocket(port) - println("Listening on port: " + port) + println(s"Listening on port: $port") while (true) { val socket = listener.accept() new Thread() { override def run(): Unit = { - println("Got client connected from: " + socket.getInetAddress) + println(s"Got client connected from: ${socket.getInetAddress}") val out = new PrintWriter(socket.getOutputStream(), true) while (true) { diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index 1ba093f57b32c..b8e7c7e9e9152 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -104,8 +104,8 @@ object PageViewStream { .foreachRDD((rdd, time) => rdd.join(userList) .map(_._2._2) .take(10) - .foreach(u => println("Saw user %s at time %s".format(u, time)))) - case _ => println("Invalid metric entered: " + metric) + .foreach(u => println(s"Saw user $u at time $time"))) + case _ => println(s"Invalid metric entered: $metric") } ssc.start() From b297029130735316e1ac1144dee44761a12bfba7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 4 Jan 2018 07:28:53 +0800 Subject: [PATCH 0007/2461] [SPARK-20960][SQL] make ColumnVector public ## What changes were proposed in this pull request? move `ColumnVector` and related classes to `org.apache.spark.sql.vectorized`, and improve the document. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #20116 from cloud-fan/column-vector. --- .../VectorizedParquetRecordReader.java | 7 ++- .../vectorized/ColumnVectorUtils.java | 2 + .../vectorized/MutableColumnarRow.java | 4 ++ .../vectorized/WritableColumnVector.java | 7 ++- .../vectorized/ArrowColumnVector.java | 62 +------------------ .../vectorized/ColumnVector.java | 31 ++++++---- .../vectorized/ColumnarArray.java | 7 +-- .../vectorized/ColumnarBatch.java | 34 +++------- .../vectorized/ColumnarRow.java | 7 +-- .../sql/execution/ColumnarBatchScan.scala | 4 +- .../aggregate/HashAggregateExec.scala | 2 +- .../VectorizedHashMapGenerator.scala | 3 +- .../sql/execution/arrow/ArrowConverters.scala | 2 +- .../columnar/InMemoryTableScanExec.scala | 1 + .../execution/datasources/FileScanRDD.scala | 2 +- .../execution/python/ArrowPythonRunner.scala | 2 +- .../execution/arrow/ArrowWriterSuite.scala | 2 +- .../vectorized/ArrowColumnVectorSuite.scala | 1 + .../vectorized/ColumnVectorSuite.scala | 2 +- .../vectorized/ColumnarBatchSuite.scala | 6 +- 20 files changed, 63 insertions(+), 125 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ArrowColumnVector.java (94%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnVector.java (79%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnarArray.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnarBatch.java (73%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnarRow.java (96%) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 6c157e85d411f..cd745b1f0e4e3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -31,10 +31,10 @@ import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; -import org.apache.spark.sql.execution.vectorized.ColumnarBatch; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector; import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -248,7 +248,10 @@ public void enableReturningBatches() { * Advances to the next batch of rows. Returns false if there are no more. */ public boolean nextBatch() throws IOException { - columnarBatch.reset(); + for (WritableColumnVector vector : columnVectors) { + vector.reset(); + } + columnarBatch.setNumRows(0); if (rowsReturned >= totalRowCount) return false; checkEndOfRowGroup(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index bc62bc43484e5..b5cbe8e2839ba 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -28,6 +28,8 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarBatch; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 06602c147dfe9..70057a9def6c0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -23,6 +23,10 @@ import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.ColumnarRow; +import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 5f6f125976e12..d2ae32b06f83b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.UTF8String; @@ -585,11 +586,11 @@ public final int appendArray(int length) { public final int appendStruct(boolean isNull) { if (isNull) { appendNull(); - for (ColumnVector c: childColumns) { + for (WritableColumnVector c: childColumns) { if (c.type instanceof StructType) { - ((WritableColumnVector) c).appendStruct(true); + c.appendStruct(true); } else { - ((WritableColumnVector) c).appendNull(); + c.appendNull(); } } } else { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java similarity index 94% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index af5673e26a501..708333213f3f1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.arrow.vector.*; import org.apache.arrow.vector.complex.*; @@ -34,11 +34,7 @@ public final class ArrowColumnVector extends ColumnVector { private ArrowColumnVector[] childColumns; private void ensureAccessible(int index) { - int valueCount = accessor.getValueCount(); - if (index < 0 || index >= valueCount) { - throw new IndexOutOfBoundsException( - String.format("index: %d, valueCount: %d", index, valueCount)); - } + ensureAccessible(index, 1); } private void ensureAccessible(int index, int count) { @@ -64,20 +60,12 @@ public void close() { accessor.close(); } - // - // APIs dealing with nulls - // - @Override public boolean isNullAt(int rowId) { ensureAccessible(rowId); return accessor.isNullAt(rowId); } - // - // APIs dealing with Booleans - // - @Override public boolean getBoolean(int rowId) { ensureAccessible(rowId); @@ -94,10 +82,6 @@ public boolean[] getBooleans(int rowId, int count) { return array; } - // - // APIs dealing with Bytes - // - @Override public byte getByte(int rowId) { ensureAccessible(rowId); @@ -114,10 +98,6 @@ public byte[] getBytes(int rowId, int count) { return array; } - // - // APIs dealing with Shorts - // - @Override public short getShort(int rowId) { ensureAccessible(rowId); @@ -134,10 +114,6 @@ public short[] getShorts(int rowId, int count) { return array; } - // - // APIs dealing with Ints - // - @Override public int getInt(int rowId) { ensureAccessible(rowId); @@ -154,10 +130,6 @@ public int[] getInts(int rowId, int count) { return array; } - // - // APIs dealing with Longs - // - @Override public long getLong(int rowId) { ensureAccessible(rowId); @@ -174,10 +146,6 @@ public long[] getLongs(int rowId, int count) { return array; } - // - // APIs dealing with floats - // - @Override public float getFloat(int rowId) { ensureAccessible(rowId); @@ -194,10 +162,6 @@ public float[] getFloats(int rowId, int count) { return array; } - // - // APIs dealing with doubles - // - @Override public double getDouble(int rowId) { ensureAccessible(rowId); @@ -214,10 +178,6 @@ public double[] getDoubles(int rowId, int count) { return array; } - // - // APIs dealing with Arrays - // - @Override public int getArrayLength(int rowId) { ensureAccessible(rowId); @@ -230,45 +190,27 @@ public int getArrayOffset(int rowId) { return accessor.getArrayOffset(rowId); } - // - // APIs dealing with Decimals - // - @Override public Decimal getDecimal(int rowId, int precision, int scale) { ensureAccessible(rowId); return accessor.getDecimal(rowId, precision, scale); } - // - // APIs dealing with UTF8Strings - // - @Override public UTF8String getUTF8String(int rowId) { ensureAccessible(rowId); return accessor.getUTF8String(rowId); } - // - // APIs dealing with Binaries - // - @Override public byte[] getBinary(int rowId) { ensureAccessible(rowId); return accessor.getBinary(rowId); } - /** - * Returns the data for the underlying array. - */ @Override public ArrowColumnVector arrayData() { return childColumns[0]; } - /** - * Returns the ordinal's child data column. - */ @Override public ArrowColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java similarity index 79% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index dc7c1269bedd9..d1196e1299fee 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.DataType; @@ -22,24 +22,31 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * This class represents in-memory values of a column and provides the main APIs to access the data. - * It supports all the types and contains get APIs as well as their batched versions. The batched - * versions are considered to be faster and preferable whenever possible. + * An interface representing in-memory columnar data in Spark. This interface defines the main APIs + * to access the data, as well as their batched versions. The batched versions are considered to be + * faster and preferable whenever possible. * - * To handle nested schemas, ColumnVector has two types: Arrays and Structs. In both cases these - * columns have child columns. All of the data are stored in the child columns and the parent column - * only contains nullability. In the case of Arrays, the lengths and offsets are saved in the child - * column and are encoded identically to INTs. + * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values + * in this ColumnVector. * - * Maps are just a special case of a two field struct. + * ColumnVector supports all the data types including nested types. To handle nested types, + * ColumnVector can have children and is a tree structure. For struct type, it stores the actual + * data of each field in the corresponding child ColumnVector, and only stores null information in + * the parent ColumnVector. For array type, it stores the actual array elements in the child + * ColumnVector, and stores null information, array offsets and lengths in the parent ColumnVector. * - * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values - * in the current batch. + * ColumnVector is expected to be reused during the entire data loading process, to avoid allocating + * memory again and again. + * + * ColumnVector is meant to maximize CPU efficiency but not to minimize storage footprint. + * Implementations should prefer computing efficiency over storage efficiency when design the + * format. Since it is expected to reuse the ColumnVector instance while loading data, the storage + * footprint is negligible. */ public abstract class ColumnVector implements AutoCloseable { /** - * Returns the data type of this column. + * Returns the data type of this column vector. */ public final DataType dataType() { return type; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index cbc39d1d0aec2..0d89a52e7a4fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; @@ -23,8 +23,7 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * Array abstraction in {@link ColumnVector}. The instance of this class is intended - * to be reused, callers should copy the data out if it needs to be stored. + * Array abstraction in {@link ColumnVector}. */ public final class ColumnarArray extends ArrayData { // The data for this array. This array contains elements from @@ -33,7 +32,7 @@ public final class ColumnarArray extends ArrayData { private final int offset; private final int length; - ColumnarArray(ColumnVector data, int offset, int length) { + public ColumnarArray(ColumnVector data, int offset, int length) { this.data = data; this.offset = offset; this.length = length; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java similarity index 73% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java index a9d09aa679726..9ae1c6d9993f0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java @@ -14,26 +14,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import java.util.*; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.vectorized.MutableColumnarRow; import org.apache.spark.sql.types.StructType; /** - * This class is the in memory representation of rows as they are streamed through operators. It - * is designed to maximize CPU efficiency and not storage footprint. Since it is expected that - * each operator allocates one of these objects, the storage footprint on the task is negligible. - * - * The layout is a columnar with values encoded in their native format. Each RowBatch contains - * a horizontal partitioning of the data, split into columns. - * - * The ColumnarBatch supports either on heap or offheap modes with (mostly) the identical API. - * - * TODO: - * - There are many TODOs for the existing APIs. They should throw a not implemented exception. - * - Compaction: The batch and columns should be able to compact based on a selection vector. + * This class wraps multiple ColumnVectors as a row-wise table. It provides a row view of this + * batch so that Spark can access the data row by row. Instance of it is meant to be reused during + * the entire data loading process. */ public final class ColumnarBatch { public static final int DEFAULT_BATCH_SIZE = 4 * 1024; @@ -57,7 +49,7 @@ public void close() { } /** - * Returns an iterator over the rows in this batch. This skips rows that are filtered out. + * Returns an iterator over the rows in this batch. */ public Iterator rowIterator() { final int maxRows = numRows; @@ -87,19 +79,7 @@ public void remove() { } /** - * Resets the batch for writing. - */ - public void reset() { - for (int i = 0; i < numCols(); ++i) { - if (columns[i] instanceof WritableColumnVector) { - ((WritableColumnVector) columns[i]).reset(); - } - } - this.numRows = 0; - } - - /** - * Sets the number of rows that are valid. + * Sets the number of rows in this batch. */ public void setNumRows(int numRows) { assert(numRows <= this.capacity); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java similarity index 96% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 8bb33ed5b78c0..3c6656dec77cd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; @@ -24,8 +24,7 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * Row abstraction in {@link ColumnVector}. The instance of this class is intended - * to be reused, callers should copy the data out if it needs to be stored. + * Row abstraction in {@link ColumnVector}. */ public final class ColumnarRow extends InternalRow { // The data for this row. @@ -34,7 +33,7 @@ public final class ColumnarRow extends InternalRow { private final int rowId; private final int numFields; - ColumnarRow(ColumnVector data, int rowId) { + public ColumnarRow(ColumnVector data, int rowId) { assert (data.dataType() instanceof StructType); this.data = data; this.rowId = rowId; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 782cec5e292ba..5617046e1396e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} /** * Helper trait for abstracting scan functionality using - * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]es. + * [[ColumnarBatch]]es. */ private[sql] trait ColumnarBatchScan extends CodegenSupport { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 9a6f1c6dfa6a9..ce3c68810f3b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.vectorized.{ColumnarRow, MutableColumnarRow} +import org.apache.spark.sql.execution.vectorized.MutableColumnarRow import org.apache.spark.sql.types.{DecimalType, StringType, StructType} import org.apache.spark.unsafe.KVIterator import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 0380ee8b09d63..0cf9b53ce1d5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext -import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, MutableColumnarRow, OnHeapColumnVector} +import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, OnHeapColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch /** * This is a helper class to generate an append-only vectorized hash map that can act as a 'cache' diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index bcfc412430263..bcd1aa0890ba3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -32,8 +32,8 @@ import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 3e73393b12850..933b9753faa61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partition import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.vectorized._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} case class InMemoryTableScanExec( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 8731ee88f87f2..835ce98462477 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -26,7 +26,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{InputFileBlockHolder, RDD} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.vectorized.ColumnarBatch +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.NextIterator /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 5cc8ed3535654..dc5ba96e69aec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -30,8 +30,8 @@ import org.apache.spark._ import org.apache.spark.api.python._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.{ArrowUtils, ArrowWriter} -import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.Utils /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index 508c116aae92e..c42bc60a59d67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.arrow import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.execution.vectorized.ArrowColumnVector import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ArrowColumnVector import org.apache.spark.unsafe.types.UTF8String class ArrowWriterSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index 03490ad15a655..7304803a092c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -23,6 +23,7 @@ import org.apache.arrow.vector.complex._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ArrowColumnVector import org.apache.spark.unsafe.types.UTF8String class ArrowColumnVectorSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 54b31cee031f6..944240f3bade5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -21,10 +21,10 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow -import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.execution.columnar.ColumnAccessor import org.apache.spark.sql.execution.columnar.compression.ColumnBuilderHelper import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarArray import org.apache.spark.unsafe.types.UTF8String class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 7848ebdcab6d0..675f06b31b970 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types.CalendarInterval @@ -918,10 +919,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(it.hasNext == false) // Reset and add 3 rows - batch.reset() - assert(batch.numRows() == 0) - assert(batch.rowIterator().hasNext == false) - + columns.foreach(_.reset()) // Add rows [NULL, 2.2, 2, "abc"], [3, NULL, 3, ""], [4, 4.4, 4, "world] columns(0).putNull(0) columns(1).putDouble(0, 2.2) From 7d045c5f00e2c7c67011830e2169a4e130c3ace8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 4 Jan 2018 13:14:52 +0800 Subject: [PATCH 0008/2461] [SPARK-22944][SQL] improve FoldablePropagation ## What changes were proposed in this pull request? `FoldablePropagation` is a little tricky as it needs to handle attributes that are miss-derived from children, e.g. outer join outputs. This rule does a kind of stop-able tree transform, to skip to apply this rule when hit a node which may have miss-derived attributes. Logically we should be able to apply this rule above the unsupported nodes, by just treating the unsupported nodes as leaf nodes. This PR improves this rule to not stop the tree transformation, but reduce the foldable expressions that we want to propagate. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #20139 from cloud-fan/foldable. --- .../sql/catalyst/optimizer/expressions.scala | 65 +++++++++++-------- .../optimizer/FoldablePropagationSuite.scala | 23 ++++++- 2 files changed, 58 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 7d830bbb7dc32..1c0b7bd806801 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -506,18 +506,21 @@ object NullPropagation extends Rule[LogicalPlan] { /** - * Propagate foldable expressions: * Replace attributes with aliases of the original foldable expressions if possible. - * Other optimizations will take advantage of the propagated foldable expressions. - * + * Other optimizations will take advantage of the propagated foldable expressions. For example, + * this rule can optimize * {{{ * SELECT 1.0 x, 'abc' y, Now() z ORDER BY x, y, 3 - * ==> SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now() * }}} + * to + * {{{ + * SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now() + * }}} + * and other rules can further optimize it and remove the ORDER BY operator. */ object FoldablePropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - val foldableMap = AttributeMap(plan.flatMap { + var foldableMap = AttributeMap(plan.flatMap { case Project(projectList, _) => projectList.collect { case a: Alias if a.child.foldable => (a.toAttribute, a) } @@ -530,38 +533,44 @@ object FoldablePropagation extends Rule[LogicalPlan] { if (foldableMap.isEmpty) { plan } else { - var stop = false CleanupAliases(plan.transformUp { - // A leaf node should not stop the folding process (note that we are traversing up the - // tree, starting at the leaf nodes); so we are allowing it. - case l: LeafNode => - l - // We can only propagate foldables for a subset of unary nodes. - case u: UnaryNode if !stop && canPropagateFoldables(u) => + case u: UnaryNode if foldableMap.nonEmpty && canPropagateFoldables(u) => u.transformExpressions(replaceFoldable) - // Allow inner joins. We do not allow outer join, although its output attributes are - // derived from its children, they are actually different attributes: the output of outer - // join is not always picked from its children, but can also be null. + // Join derives the output attributes from its child while they are actually not the + // same attributes. For example, the output of outer join is not always picked from its + // children, but can also be null. We should exclude these miss-derived attributes when + // propagating the foldable expressions. // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes // of outer join. - case j @ Join(_, _, Inner, _) if !stop => - j.transformExpressions(replaceFoldable) - - // We can fold the projections an expand holds. However expand changes the output columns - // and often reuses the underlying attributes; so we cannot assume that a column is still - // foldable after the expand has been applied. - // TODO(hvanhovell): Expand should use new attributes as the output attributes. - case expand: Expand if !stop => - val newExpand = expand.copy(projections = expand.projections.map { projection => + case j @ Join(left, right, joinType, _) if foldableMap.nonEmpty => + val newJoin = j.transformExpressions(replaceFoldable) + val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match { + case _: InnerLike | LeftExistence(_) => Nil + case LeftOuter => right.output + case RightOuter => left.output + case FullOuter => left.output ++ right.output + }) + foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot { + case (attr, _) => missDerivedAttrsSet.contains(attr) + }.toSeq) + newJoin + + // We can not replace the attributes in `Expand.output`. If there are other non-leaf + // operators that have the `output` field, we should put them here too. + case expand: Expand if foldableMap.nonEmpty => + expand.copy(projections = expand.projections.map { projection => projection.map(_.transform(replaceFoldable)) }) - stop = true - newExpand - case other => - stop = true + // For other plans, they are not safe to apply foldable propagation, and they should not + // propagate foldable expressions from children. + case other if foldableMap.nonEmpty => + val childrenOutputSet = AttributeSet(other.children.flatMap(_.output)) + foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot { + case (attr, _) => childrenOutputSet.contains(attr) + }.toSeq) other }) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala index dccb32f0379a8..c28844642aed0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala @@ -147,8 +147,8 @@ class FoldablePropagationSuite extends PlanTest { test("Propagate in expand") { val c1 = Literal(1).as('a) val c2 = Literal(2).as('b) - val a1 = c1.toAttribute.withNullability(true) - val a2 = c2.toAttribute.withNullability(true) + val a1 = c1.toAttribute.newInstance().withNullability(true) + val a2 = c2.toAttribute.newInstance().withNullability(true) val expand = Expand( Seq(Seq(Literal(null), 'b), Seq('a, Literal(null))), Seq(a1, a2), @@ -161,4 +161,23 @@ class FoldablePropagationSuite extends PlanTest { val correctAnswer = correctExpand.where(a1.isNotNull).select(a1, a2).analyze comparePlans(optimized, correctAnswer) } + + test("Propagate above outer join") { + val left = LocalRelation('a.int).select('a, Literal(1).as('b)) + val right = LocalRelation('c.int).select('c, Literal(1).as('d)) + + val join = left.join( + right, + joinType = LeftOuter, + condition = Some('a === 'c && 'b === 'd)) + val query = join.select(('b + 3).as('res)).analyze + val optimized = Optimize.execute(query) + + val correctAnswer = left.join( + right, + joinType = LeftOuter, + condition = Some('a === 'c && Literal(1) === Literal(1))) + .select((Literal(1) + 3).as('res)).analyze + comparePlans(optimized, correctAnswer) + } } From df95a908baf78800556636a76d58bba9b3dd943f Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Wed, 3 Jan 2018 21:43:14 -0800 Subject: [PATCH 0009/2461] [SPARK-22933][SPARKR] R Structured Streaming API for withWatermark, trigger, partitionBy ## What changes were proposed in this pull request? R Structured Streaming API for withWatermark, trigger, partitionBy ## How was this patch tested? manual, unit tests Author: Felix Cheung Closes #20129 from felixcheung/rwater. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 96 +++++++++++++++- R/pkg/R/SQLContext.R | 4 +- R/pkg/R/generics.R | 6 + R/pkg/tests/fulltests/test_streaming.R | 107 ++++++++++++++++++ python/pyspark/sql/streaming.py | 4 + .../sql/execution/streaming/Triggers.scala | 2 +- 7 files changed, 214 insertions(+), 6 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 3219c6f0cc47b..c51eb0f39c4b1 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -179,6 +179,7 @@ exportMethods("arrange", "with", "withColumn", "withColumnRenamed", + "withWatermark", "write.df", "write.jdbc", "write.json", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index fe238f6dd4eb0..9956f7eda91e6 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3661,7 +3661,8 @@ setMethod("getNumPartitions", #' isStreaming #' #' Returns TRUE if this SparkDataFrame contains one or more sources that continuously return data -#' as it arrives. +#' as it arrives. A dataset that reads data from a streaming source must be executed as a +#' \code{StreamingQuery} using \code{write.stream}. #' #' @param x A SparkDataFrame #' @return TRUE if this SparkDataFrame is from a streaming source @@ -3707,7 +3708,17 @@ setMethod("isStreaming", #' @param df a streaming SparkDataFrame. #' @param source a name for external data source. #' @param outputMode one of 'append', 'complete', 'update'. -#' @param ... additional argument(s) passed to the method. +#' @param partitionBy a name or a list of names of columns to partition the output by on the file +#' system. If specified, the output is laid out on the file system similar to Hive's +#' partitioning scheme. +#' @param trigger.processingTime a processing time interval as a string, e.g. '5 seconds', +#' '1 minute'. This is a trigger that runs a query periodically based on the processing +#' time. If value is '0 seconds', the query will run as fast as possible, this is the +#' default. Only one trigger can be set. +#' @param trigger.once a logical, must be set to \code{TRUE}. This is a trigger that processes only +#' one batch of data in a streaming query then terminates the query. Only one trigger can be +#' set. +#' @param ... additional external data source specific named options. #' #' @family SparkDataFrame functions #' @seealso \link{read.stream} @@ -3725,7 +3736,8 @@ setMethod("isStreaming", #' # console #' q <- write.stream(wordCounts, "console", outputMode = "complete") #' # text stream -#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp") +#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp" +#' partitionBy = c("year", "month"), trigger.processingTime = "30 seconds") #' # memory stream #' q <- write.stream(wordCounts, "memory", queryName = "outs", outputMode = "complete") #' head(sql("SELECT * from outs")) @@ -3737,7 +3749,8 @@ setMethod("isStreaming", #' @note experimental setMethod("write.stream", signature(df = "SparkDataFrame"), - function(df, source = NULL, outputMode = NULL, ...) { + function(df, source = NULL, outputMode = NULL, partitionBy = NULL, + trigger.processingTime = NULL, trigger.once = NULL, ...) { if (!is.null(source) && !is.character(source)) { stop("source should be character, NULL or omitted. It is the data source specified ", "in 'spark.sql.sources.default' configuration by default.") @@ -3748,12 +3761,43 @@ setMethod("write.stream", if (is.null(source)) { source <- getDefaultSqlSource() } + cols <- NULL + if (!is.null(partitionBy)) { + if (!all(sapply(partitionBy, function(c) { is.character(c) }))) { + stop("All partitionBy column names should be characters.") + } + cols <- as.list(partitionBy) + } + jtrigger <- NULL + if (!is.null(trigger.processingTime) && !is.na(trigger.processingTime)) { + if (!is.null(trigger.once)) { + stop("Multiple triggers not allowed.") + } + interval <- as.character(trigger.processingTime) + if (nchar(interval) == 0) { + stop("Value for trigger.processingTime must be a non-empty string.") + } + jtrigger <- handledCallJStatic("org.apache.spark.sql.streaming.Trigger", + "ProcessingTime", + interval) + } else if (!is.null(trigger.once) && !is.na(trigger.once)) { + if (!is.logical(trigger.once) || !trigger.once) { + stop("Value for trigger.once must be TRUE.") + } + jtrigger <- callJStatic("org.apache.spark.sql.streaming.Trigger", "Once") + } options <- varargsToStrEnv(...) write <- handledCallJMethod(df@sdf, "writeStream") write <- callJMethod(write, "format", source) if (!is.null(outputMode)) { write <- callJMethod(write, "outputMode", outputMode) } + if (!is.null(cols)) { + write <- callJMethod(write, "partitionBy", cols) + } + if (!is.null(jtrigger)) { + write <- callJMethod(write, "trigger", jtrigger) + } write <- callJMethod(write, "options", options) ssq <- handledCallJMethod(write, "start") streamingQuery(ssq) @@ -3967,3 +4011,47 @@ setMethod("broadcast", sdf <- callJStatic("org.apache.spark.sql.functions", "broadcast", x@sdf) dataFrame(sdf) }) + +#' withWatermark +#' +#' Defines an event time watermark for this streaming SparkDataFrame. A watermark tracks a point in +#' time before which we assume no more late data is going to arrive. +#' +#' Spark will use this watermark for several purposes: +#' \itemize{ +#' \item{-} To know when a given time window aggregation can be finalized and thus can be emitted +#' when using output modes that do not allow updates. +#' \item{-} To minimize the amount of state that we need to keep for on-going aggregations. +#' } +#' The current watermark is computed by looking at the \code{MAX(eventTime)} seen across +#' all of the partitions in the query minus a user specified \code{delayThreshold}. Due to the cost +#' of coordinating this value across partitions, the actual watermark used is only guaranteed +#' to be at least \code{delayThreshold} behind the actual event time. In some cases we may still +#' process records that arrive more than \code{delayThreshold} late. +#' +#' @param x a streaming SparkDataFrame +#' @param eventTime a string specifying the name of the Column that contains the event time of the +#' row. +#' @param delayThreshold a string specifying the minimum delay to wait to data to arrive late, +#' relative to the latest record that has been processed in the form of an +#' interval (e.g. "1 minute" or "5 hours"). NOTE: This should not be negative. +#' @return a SparkDataFrame. +#' @aliases withWatermark,SparkDataFrame,character,character-method +#' @family SparkDataFrame functions +#' @rdname withWatermark +#' @name withWatermark +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' schema <- structType(structField("time", "timestamp"), structField("value", "double")) +#' df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) +#' df <- withWatermark(df, "time", "10 minutes") +#' } +#' @note withWatermark since 2.3.0 +setMethod("withWatermark", + signature(x = "SparkDataFrame", eventTime = "character", delayThreshold = "character"), + function(x, eventTime, delayThreshold) { + sdf <- callJMethod(x@sdf, "withWatermark", eventTime, delayThreshold) + dataFrame(sdf) + }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 3b7f71bbbffb8..9d0a2d5e074e4 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -727,7 +727,9 @@ read.jdbc <- function(url, tableName, #' @param schema The data schema defined in structType or a DDL-formatted string, this is #' required for file-based streaming data source #' @param ... additional external data source specific named options, for instance \code{path} for -#' file-based streaming data source +#' file-based streaming data source. \code{timeZone} to indicate a timezone to be used to +#' parse timestamps in the JSON/CSV data sources or partition values; If it isn't set, it +#' uses the default value, session local timezone. #' @return SparkDataFrame #' @rdname read.stream #' @name read.stream diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 5369c32544e5e..e0dde3339fabc 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -799,6 +799,12 @@ setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn setGeneric("withColumnRenamed", function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) +#' @rdname withWatermark +#' @export +setGeneric("withWatermark", function(x, eventTime, delayThreshold) { + standardGeneric("withWatermark") +}) + #' @rdname write.df #' @export setGeneric("write.df", function(df, path = NULL, ...) { standardGeneric("write.df") }) diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R index 54f40bbd5f517..a354d50c6b54e 100644 --- a/R/pkg/tests/fulltests/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -172,6 +172,113 @@ test_that("Terminated by error", { stopQuery(q) }) +test_that("PartitionBy", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + checkpointPath <- tempfile(pattern = "sparkr-test", fileext = ".checkpoint") + textPath <- tempfile(pattern = "sparkr-test", fileext = ".text") + df <- read.df(jsonPath, "json", stringSchema) + write.df(df, parquetPath, "parquet", "overwrite") + + df <- read.stream(path = parquetPath, schema = stringSchema) + + expect_error(write.stream(df, "json", path = textPath, checkpointLocation = "append", + partitionBy = c(1, 2)), + "All partitionBy column names should be characters") + + q <- write.stream(df, "json", path = textPath, checkpointLocation = "append", + partitionBy = "name") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + dirs <- list.files(textPath) + expect_equal(length(dirs[substring(dirs, 1, nchar("name=")) == "name="]), 3) + + unlink(checkpointPath) + unlink(textPath) + unlink(parquetPath) +}) + +test_that("Watermark", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + schema <- structType(structField("value", "string")) + t <- Sys.time() + df <- as.DataFrame(lapply(list(t), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + df <- read.stream(path = parquetPath, schema = "value STRING") + df <- withColumn(df, "eventTime", cast(df$value, "timestamp")) + df <- withWatermark(df, "eventTime", "10 seconds") + counts <- count(group_by(df, "eventTime")) + q <- write.stream(counts, "memory", queryName = "times", outputMode = "append") + + # first events + df <- as.DataFrame(lapply(list(t + 1, t, t + 2), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # advance watermark to 15 + df <- as.DataFrame(lapply(list(t + 25), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # old events, should be dropped + df <- as.DataFrame(lapply(list(t), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # evict events less than previous watermark + df <- as.DataFrame(lapply(list(t + 25), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + times <- collect(sql("SELECT * FROM times")) + # looks like write timing can affect the first bucket; but it should be t + expect_equal(times[order(times$eventTime),][1, 2], 2) + + stopQuery(q) + unlink(parquetPath) +}) + +test_that("Trigger", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + schema <- structType(structField("value", "string")) + df <- as.DataFrame(lapply(list(Sys.time()), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + df <- read.stream(path = parquetPath, schema = "value STRING") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.processingTime = "", trigger.once = ""), "Multiple triggers not allowed.") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.processingTime = ""), + "Value for trigger.processingTime must be a non-empty string.") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.processingTime = "invalid"), "illegal argument") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.once = ""), "Value for trigger.once must be TRUE.") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.once = FALSE), "Value for trigger.once must be TRUE.") + + q <- write.stream(df, "memory", queryName = "times", outputMode = "append", trigger.once = TRUE) + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + df <- as.DataFrame(lapply(list(Sys.time()), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + expect_equal(nrow(collect(sql("SELECT * FROM times"))), 1) + + stopQuery(q) + unlink(parquetPath) +}) + unlink(jsonPath) unlink(jsonPathNa) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index fb228f99ba7ab..24ae3776a217b 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -793,6 +793,10 @@ def trigger(self, processingTime=None, once=None): .. note:: Evolving. :param processingTime: a processing time interval as a string, e.g. '5 seconds', '1 minute'. + Set a trigger that runs a query periodically based on the processing + time. Only one trigger can be set. + :param once: if set to True, set a trigger that processes only one batch of data in a + streaming query then terminates the query. Only one trigger can be set. >>> # trigger the query for execution every 5 seconds >>> writer = sdf.writeStream.trigger(processingTime='5 seconds') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala index 271bc4da99c08..19e3e55cb2829 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.streaming.Trigger /** - * A [[Trigger]] that process only one batch of data in a streaming query then terminates + * A [[Trigger]] that processes only one batch of data in a streaming query then terminates * the query. */ @Experimental From 9fa703e89318922393bae03c0db4575f4f4b4c56 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 4 Jan 2018 19:10:10 +0800 Subject: [PATCH 0010/2461] [SPARK-22950][SQL] Handle ChildFirstURLClassLoader's parent ## What changes were proposed in this pull request? ChildFirstClassLoader's parent is set to null, so we can't get jars from its parent. This will cause ClassNotFoundException during HiveClient initialization with builtin hive jars, where we may should use spark context loader instead. ## How was this patch tested? add new ut cc cloud-fan gatorsmile Author: Kent Yao Closes #20145 from yaooqinn/SPARK-22950. --- .../org/apache/spark/sql/hive/HiveUtils.scala | 4 +++- .../spark/sql/hive/HiveUtilsSuite.scala | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index c489690af8cd1..c7717d70c996f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -47,7 +47,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.internal.StaticSQLConf.{CATALOG_IMPLEMENTATION, WAREHOUSE_PATH} import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{ChildFirstURLClassLoader, Utils} private[spark] object HiveUtils extends Logging { @@ -312,6 +312,8 @@ private[spark] object HiveUtils extends Logging { // starting from the given classLoader. def allJars(classLoader: ClassLoader): Array[URL] = classLoader match { case null => Array.empty[URL] + case childFirst: ChildFirstURLClassLoader => + childFirst.getURLs() ++ allJars(Utils.getSparkClassLoader) case urlClassLoader: URLClassLoader => urlClassLoader.getURLs ++ allJars(urlClassLoader.getParent) case other => allJars(other.getParent) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala index fdbfcf1a68440..8697d47e89e89 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala @@ -17,11 +17,16 @@ package org.apache.spark.sql.hive +import java.net.URL + import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.QueryTest import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader} class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { @@ -42,4 +47,19 @@ class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton assert(hiveConf("foo") === "bar") } } + + test("ChildFirstURLClassLoader's parent is null, get spark classloader instead") { + val conf = new SparkConf + val contextClassLoader = Thread.currentThread().getContextClassLoader + val loader = new ChildFirstURLClassLoader(Array(), contextClassLoader) + try { + Thread.currentThread().setContextClassLoader(loader) + HiveUtils.newClientForMetadata( + conf, + SparkHadoopUtil.newConfiguration(conf), + HiveUtils.newTemporaryConfiguration(useInMemoryDerby = true)) + } finally { + Thread.currentThread().setContextClassLoader(contextClassLoader) + } + } } From d5861aba9d80ca15ad3f22793b79822e470d6913 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 4 Jan 2018 19:17:22 +0800 Subject: [PATCH 0011/2461] [SPARK-22945][SQL] add java UDF APIs in the functions object ## What changes were proposed in this pull request? Currently Scala users can use UDF like ``` val foo = udf((i: Int) => Math.random() + i).asNondeterministic df.select(foo('a)) ``` Python users can also do it with similar APIs. However Java users can't do it, we should add Java UDF APIs in the functions object. ## How was this patch tested? new tests Author: Wenchen Fan Closes #20141 from cloud-fan/udf. --- .../apache/spark/sql/UDFRegistration.scala | 90 ++--- .../sql/expressions/UserDefinedFunction.scala | 1 + .../org/apache/spark/sql/functions.scala | 313 ++++++++++++++---- .../apache/spark/sql/JavaDataFrameSuite.java | 11 + .../scala/org/apache/spark/sql/UDFSuite.scala | 12 +- 5 files changed, 315 insertions(+), 112 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index dc2468a721e41..f94baef39dfad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.lang.reflect.{ParameterizedType, Type} +import java.lang.reflect.ParameterizedType import scala.reflect.runtime.universe.TypeTag import scala.util.Try @@ -110,29 +110,29 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends /* register 0-22 were generated by this script - (0 to 22).map { x => + (0 to 22).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) - val typeTags = (1 to x).map(i => s"A${i}: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) + val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor[A$i].dataType :: $s"}) println(s""" - /** - * Registers a deterministic Scala closure of ${x} arguments as user-defined function (UDF). - * @tparam RT return type of UDF. - * @since 1.3.0 - */ - def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try($inputTypes).toOption - def builder(e: Seq[Expression]) = if (e.length == $x) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) - } else { - throw new AnalysisException("Invalid number of arguments for function " + name + - ". Expected: $x; Found: " + e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) - if (nullable) udf else udf.asNonNullable() - }""") + |/** + | * Registers a deterministic Scala closure of $x arguments as user-defined function (UDF). + | * @tparam RT return type of UDF. + | * @since 1.3.0 + | */ + |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { + | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + | val inputTypes = Try($inputTypes).toOption + | def builder(e: Seq[Expression]) = if (e.length == $x) { + | ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + | } else { + | throw new AnalysisException("Invalid number of arguments for function " + name + + | ". Expected: $x; Found: " + e.length) + | } + | functionRegistry.createOrReplaceTempFunction(name, builder) + | val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + | if (nullable) udf else udf.asNonNullable() + |}""".stripMargin) } (0 to 22).foreach { i => @@ -144,7 +144,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val funcCall = if (i == 0) "() => func" else "func" println(s""" |/** - | * Register a user-defined function with ${i} arguments. + | * Register a deterministic Java UDF$i instance as user-defined function (UDF). | * @since $version | */ |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = { @@ -689,7 +689,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 0 arguments. + * Register a deterministic Java UDF0 instance as user-defined function (UDF). * @since 2.3.0 */ def register(name: String, f: UDF0[_], returnType: DataType): Unit = { @@ -704,7 +704,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 1 arguments. + * Register a deterministic Java UDF1 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { @@ -719,7 +719,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 2 arguments. + * Register a deterministic Java UDF2 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { @@ -734,7 +734,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 3 arguments. + * Register a deterministic Java UDF3 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { @@ -749,7 +749,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 4 arguments. + * Register a deterministic Java UDF4 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { @@ -764,7 +764,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 5 arguments. + * Register a deterministic Java UDF5 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { @@ -779,7 +779,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 6 arguments. + * Register a deterministic Java UDF6 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -794,7 +794,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 7 arguments. + * Register a deterministic Java UDF7 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -809,7 +809,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 8 arguments. + * Register a deterministic Java UDF8 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -824,7 +824,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 9 arguments. + * Register a deterministic Java UDF9 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -839,7 +839,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 10 arguments. + * Register a deterministic Java UDF10 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -854,7 +854,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 11 arguments. + * Register a deterministic Java UDF11 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -869,7 +869,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 12 arguments. + * Register a deterministic Java UDF12 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -884,7 +884,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 13 arguments. + * Register a deterministic Java UDF13 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -899,7 +899,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 14 arguments. + * Register a deterministic Java UDF14 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -914,7 +914,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 15 arguments. + * Register a deterministic Java UDF15 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -929,7 +929,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 16 arguments. + * Register a deterministic Java UDF16 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -944,7 +944,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 17 arguments. + * Register a deterministic Java UDF17 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -959,7 +959,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 18 arguments. + * Register a deterministic Java UDF18 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -974,7 +974,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 19 arguments. + * Register a deterministic Java UDF19 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -989,7 +989,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 20 arguments. + * Register a deterministic Java UDF20 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -1004,7 +1004,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 21 arguments. + * Register a deterministic Java UDF21 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -1019,7 +1019,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 22 arguments. + * Register a deterministic Java UDF22 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 03b654f830520..40a058d2cadd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -66,6 +66,7 @@ case class UserDefinedFunction protected[sql] ( * * @since 1.3.0 */ + @scala.annotation.varargs def apply(exprs: Column*): Column = { Column(ScalaUDF( f, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 530a525a01dec..0d11682d80a3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -24,6 +24,7 @@ import scala.util.Try import scala.util.control.NonFatal import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -32,7 +33,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.UserDefinedFunction -import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -3254,42 +3254,66 @@ object functions { */ def map_values(e: Column): Column = withExpr { MapValues(e.expr) } - ////////////////////////////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////////////////////////////// - // scalastyle:off line.size.limit // scalastyle:off parameter.number /* Use the following code to generate: - (0 to 10).map { x => + + (0 to 10).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]).dataType :: $s"}) println(s""" - /** - * Defines a deterministic user-defined function of ${x} arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. - * - * @group udf_funcs - * @since 1.3.0 - */ - def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try($inputTypes).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) - if (nullable) udf else udf.asNonNullable() - }""") + |/** + | * Defines a Scala closure of $x arguments as user-defined function (UDF). + | * The data types are automatically inferred based on the Scala closure's + | * signature. By default the returned UDF is deterministic. To change it to + | * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + | * + | * @group udf_funcs + | * @since 1.3.0 + | */ + |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { + | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + | val inputTypes = Try($inputTypes).toOption + | val udf = UserDefinedFunction(f, dataType, inputTypes) + | if (nullable) udf else udf.asNonNullable() + |}""".stripMargin) + } + + (0 to 10).foreach { i => + val extTypeArgs = (0 to i).map(_ => "_").mkString(", ") + val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ") + val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]" + val anyParams = (1 to i).map(_ => "_: Any").mkString(", ") + val funcCall = if (i == 0) "() => func" else "func" + println(s""" + |/** + | * Defines a Java UDF$i instance as user-defined function (UDF). + | * The caller must specify the output data type, and there is no automatic input type coercion. + | * By default the returned UDF is deterministic. To change it to nondeterministic, call the + | * API `UserDefinedFunction.asNondeterministic()`. + | * + | * @group udf_funcs + | * @since 2.3.0 + | */ + |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = { + | val func = f$anyCast.call($anyParams) + | UserDefinedFunction($funcCall, returnType, inputTypes = None) + |}""".stripMargin) } */ + ////////////////////////////////////////////////////////////////////////////////////////////// + // Scala UDF functions + ////////////////////////////////////////////////////////////////////////////////////////////// + /** - * Defines a deterministic user-defined function of 0 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 0 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3302,10 +3326,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 1 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 1 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3318,10 +3342,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 2 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 2 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3334,10 +3358,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 3 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 3 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3350,10 +3374,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 4 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 4 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3366,10 +3390,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 5 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 5 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3382,10 +3406,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 6 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 6 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3398,10 +3422,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 7 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 7 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3414,10 +3438,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 8 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 8 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3430,10 +3454,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 9 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 9 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3446,10 +3470,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 10 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 10 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3461,13 +3485,172 @@ object functions { if (nullable) udf else udf.asNonNullable() } + ////////////////////////////////////////////////////////////////////////////////////////////// + // Java UDF functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Defines a Java UDF0 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF0[Any]].call() + UserDefinedFunction(() => func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF1 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF2 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF3 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF4 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF5 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF6 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF7 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF8 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF9 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF10 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + // scalastyle:on parameter.number // scalastyle:on line.size.limit /** * Defines a deterministic user-defined function (UDF) using a Scala closure. For this variant, * the caller must specify the output data type, and there is no automatic input type coercion. - * To change a UDF to nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. * * @param f A closure in Scala * @param dataType The output data type of the UDF diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index b007093dad84b..4f8a31f185724 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -36,6 +36,7 @@ import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.expressions.UserDefinedFunction; import org.apache.spark.sql.test.TestSparkSession; import org.apache.spark.sql.types.*; import org.apache.spark.util.sketch.BloomFilter; @@ -455,4 +456,14 @@ public void testCircularReferenceBean() { CircularReference1Bean bean = new CircularReference1Bean(); spark.createDataFrame(Arrays.asList(bean), CircularReference1Bean.class); } + + @Test + public void testUDF() { + UserDefinedFunction foo = udf((Integer i, String s) -> i.toString() + s, DataTypes.StringType); + Dataset df = spark.table("testData").select(foo.apply(col("key"), col("value"))); + String[] result = df.collectAsList().stream().map(row -> row.getString(0)).toArray(String[]::new); + String[] expected = spark.table("testData").collectAsList().stream() + .map(row -> row.get(0).toString() + row.getString(1)).toArray(String[]::new); + Assert.assertArrayEquals(expected, result); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 7f1c009ca6e7a..db37be68e42e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql +import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.command.ExplainCommand -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ -import org.apache.spark.sql.types.DataTypes +import org.apache.spark.sql.types.{DataTypes, DoubleType} private case class FunctionResult(f1: String, f2: String) @@ -128,6 +129,13 @@ class UDFSuite extends QueryTest with SharedSQLContext { val df2 = testData.select(bar()) assert(df2.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic)) assert(df2.head().getDouble(0) >= 0.0) + + val javaUdf = udf(new UDF0[Double] { + override def call(): Double = Math.random() + }, DoubleType).asNondeterministic() + val df3 = testData.select(javaUdf()) + assert(df3.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic)) + assert(df3.head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { From 5aadbc929cb194e06dbd3bab054a161569289af5 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 4 Jan 2018 21:07:31 +0800 Subject: [PATCH 0012/2461] [SPARK-22939][PYSPARK] Support Spark UDF in registerFunction ## What changes were proposed in this pull request? ```Python import random from pyspark.sql.functions import udf from pyspark.sql.types import IntegerType, StringType random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic() spark.catalog.registerFunction("random_udf", random_udf, StringType()) spark.sql("SELECT random_udf()").collect() ``` We will get the following error. ``` Py4JError: An error occurred while calling o29.__getnewargs__. Trace: py4j.Py4JException: Method __getnewargs__([]) does not exist at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318) at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326) at py4j.Gateway.invoke(Gateway.java:274) at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132) at py4j.commands.CallCommand.execute(CallCommand.java:79) at py4j.GatewayConnection.run(GatewayConnection.java:214) at java.lang.Thread.run(Thread.java:745) ``` This PR is to support it. ## How was this patch tested? WIP Author: gatorsmile Closes #20137 from gatorsmile/registerFunction. --- python/pyspark/sql/catalog.py | 27 +++++++++++++++---- python/pyspark/sql/context.py | 16 +++++++++--- python/pyspark/sql/tests.py | 49 +++++++++++++++++++++++++---------- python/pyspark/sql/udf.py | 21 ++++++++++----- 4 files changed, 84 insertions(+), 29 deletions(-) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 659bc65701a0c..156603128d063 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -227,15 +227,15 @@ def dropGlobalTempView(self, viewName): @ignore_unicode_prefix @since(2.0) def registerFunction(self, name, f, returnType=StringType()): - """Registers a python function (including lambda function) as a UDF - so it can be used in SQL statements. + """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` + as a UDF. The registered UDF can be used in SQL statement. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. :param name: name of the UDF - :param f: python function + :param f: a Python function, or a wrapped/native UserDefinedFunction :param returnType: a :class:`pyspark.sql.types.DataType` object :return: a wrapped :class:`UserDefinedFunction` @@ -255,9 +255,26 @@ def registerFunction(self, name, f, returnType=StringType()): >>> _ = spark.udf.register("stringLengthInt", len, IntegerType()) >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] + + >>> import random + >>> from pyspark.sql.functions import udf + >>> from pyspark.sql.types import IntegerType, StringType + >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() + >>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType()) + >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP + [Row(random_udf()=u'82')] + >>> spark.range(1).select(newRandom_udf()).collect() # doctest: +SKIP + [Row(random_udf()=u'62')] """ - udf = UserDefinedFunction(f, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF) + + # This is to check whether the input function is a wrapped/native UserDefinedFunction + if hasattr(f, 'asNondeterministic'): + udf = UserDefinedFunction(f.func, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF, + deterministic=f.deterministic) + else: + udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) self._jsparkSession.udf().registerPython(name, udf._judf) return udf._wrapped() diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index b1e723cdecef3..b8d86cc098e94 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -175,15 +175,15 @@ def range(self, start, end=None, step=1, numPartitions=None): @ignore_unicode_prefix @since(1.2) def registerFunction(self, name, f, returnType=StringType()): - """Registers a python function (including lambda function) as a UDF - so it can be used in SQL statements. + """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` + as a UDF. The registered UDF can be used in SQL statement. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. :param name: name of the UDF - :param f: python function + :param f: a Python function, or a wrapped/native UserDefinedFunction :param returnType: a :class:`pyspark.sql.types.DataType` object :return: a wrapped :class:`UserDefinedFunction` @@ -203,6 +203,16 @@ def registerFunction(self, name, f, returnType=StringType()): >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] + + >>> import random + >>> from pyspark.sql.functions import udf + >>> from pyspark.sql.types import IntegerType, StringType + >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() + >>> newRandom_udf = sqlContext.registerFunction("random_udf", random_udf, StringType()) + >>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP + [Row(random_udf()=u'82')] + >>> sqlContext.range(1).select(newRandom_udf()).collect() # doctest: +SKIP + [Row(random_udf()=u'62')] """ return self.sparkSession.catalog.registerFunction(name, f, returnType) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 67bdb3d72d93b..6dc767f9ec46e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -378,6 +378,41 @@ def test_udf2(self): [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() self.assertEqual(4, res[0]) + def test_udf3(self): + twoargs = self.spark.catalog.registerFunction( + "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType()) + self.assertEqual(twoargs.deterministic, True) + [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_nondeterministic_udf(self): + from pyspark.sql.functions import udf + import random + udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() + self.assertEqual(udf_random_col.deterministic, False) + df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) + udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) + [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() + self.assertEqual(row[0] + 10, row[1]) + + def test_nondeterministic_udf2(self): + import random + from pyspark.sql.functions import udf + random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic() + self.assertEqual(random_udf.deterministic, False) + random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType()) + self.assertEqual(random_udf1.deterministic, False) + [row] = self.spark.sql("SELECT randInt()").collect() + self.assertEqual(row[0], "6") + [row] = self.spark.range(1).select(random_udf1()).collect() + self.assertEqual(row[0], "6") + [row] = self.spark.range(1).select(random_udf()).collect() + self.assertEqual(row[0], 6) + # render_doc() reproduces the help() exception without printing output + pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType())) + pydoc.render_doc(random_udf) + pydoc.render_doc(random_udf1) + def test_chained_udf(self): self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType()) [row] = self.spark.sql("SELECT double(1)").collect() @@ -435,15 +470,6 @@ def test_udf_with_array_type(self): self.assertEqual(list(range(3)), l1) self.assertEqual(1, l2) - def test_nondeterministic_udf(self): - from pyspark.sql.functions import udf - import random - udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() - df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) - udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) - [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() - self.assertEqual(row[0] + 10, row[1]) - def test_broadcast_in_udf(self): bar = {"a": "aa", "b": "bb", "c": "abc"} foo = self.sc.broadcast(bar) @@ -567,7 +593,6 @@ def test_read_multiple_orc_file(self): def test_udf_with_input_file_name(self): from pyspark.sql.functions import udf, input_file_name - from pyspark.sql.types import StringType sourceFile = udf(lambda path: path, StringType()) filePath = "python/test_support/sql/people1.json" row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first() @@ -575,7 +600,6 @@ def test_udf_with_input_file_name(self): def test_udf_with_input_file_name_for_hadooprdd(self): from pyspark.sql.functions import udf, input_file_name - from pyspark.sql.types import StringType def filename(path): return path @@ -635,7 +659,6 @@ def test_udf_with_string_return_type(self): def test_udf_shouldnt_accept_noncallable_object(self): from pyspark.sql.functions import UserDefinedFunction - from pyspark.sql.types import StringType non_callable = None self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType()) @@ -1299,7 +1322,6 @@ def test_between_function(self): df.filter(df.a.between(df.b, df.c)).collect()) def test_struct_type(self): - from pyspark.sql.types import StructType, StringType, StructField struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) struct2 = StructType([StructField("f1", StringType(), True), StructField("f2", StringType(), True, None)]) @@ -1368,7 +1390,6 @@ def test_parse_datatype_string(self): _parse_datatype_string("a INT, c DOUBLE")) def test_metadata_null(self): - from pyspark.sql.types import StructType, StringType, StructField schema = StructType([StructField("f1", StringType(), True, None), StructField("f2", StringType(), True, {'a': None})]) rdd = self.sc.parallelize([["a", "b"], ["c", "d"]]) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 54b5a8656e1c8..5e75eb6545333 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -56,7 +56,8 @@ def _create_udf(f, returnType, evalType): ) # Set the name of the UserDefinedFunction object to be the name of function f - udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, evalType=evalType) + udf_obj = UserDefinedFunction( + f, returnType=returnType, name=None, evalType=evalType, deterministic=True) return udf_obj._wrapped() @@ -67,8 +68,10 @@ class UserDefinedFunction(object): .. versionadded:: 1.3 """ def __init__(self, func, - returnType=StringType(), name=None, - evalType=PythonEvalType.SQL_BATCHED_UDF): + returnType=StringType(), + name=None, + evalType=PythonEvalType.SQL_BATCHED_UDF, + deterministic=True): if not callable(func): raise TypeError( "Invalid function: not a function or callable (__call__ is not defined): " @@ -92,7 +95,7 @@ def __init__(self, func, func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) self.evalType = evalType - self._deterministic = True + self.deterministic = deterministic @property def returnType(self): @@ -130,7 +133,7 @@ def _create_judf(self): wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self.evalType, self._deterministic) + self._name, wrapped_func, jdt, self.evalType, self.deterministic) return judf def __call__(self, *cols): @@ -138,6 +141,9 @@ def __call__(self, *cols): sc = SparkContext._active_spark_context return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) + # This function is for improving the online help system in the interactive interpreter. + # For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and + # argument annotation. (See: SPARK-19161) def _wrapped(self): """ Wrap this udf with a function and attach docstring from func @@ -162,7 +168,8 @@ def wrapper(*args): wrapper.func = self.func wrapper.returnType = self.returnType wrapper.evalType = self.evalType - wrapper.asNondeterministic = self.asNondeterministic + wrapper.deterministic = self.deterministic + wrapper.asNondeterministic = lambda: self.asNondeterministic()._wrapped() return wrapper @@ -172,5 +179,5 @@ def asNondeterministic(self): .. versionadded:: 2.3 """ - self._deterministic = False + self.deterministic = False return self From 6f68316e98fad72b171df422566e1fc9a7bbfcde Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 4 Jan 2018 21:15:10 +0800 Subject: [PATCH 0013/2461] [SPARK-22771][SQL] Add a missing return statement in Concat.checkInputDataTypes ## What changes were proposed in this pull request? This pr is a follow-up to fix a bug left in #19977. ## How was this patch tested? Added tests in `StringExpressionsSuite`. Author: Takeshi Yamamuro Closes #20149 from maropu/SPARK-22771-FOLLOWUP. --- .../sql/catalyst/expressions/stringExpressions.scala | 2 +- .../expressions/StringExpressionsSuite.scala | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index b0da55a4a961b..41dc762154a4c 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -58,7 +58,7 @@ case class Concat(children: Seq[Expression]) extends Expression { } else { val childTypes = children.map(_.dataType) if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { - TypeCheckResult.TypeCheckFailure( + return TypeCheckResult.TypeCheckFailure( s"input to function $prettyName should have StringType or BinaryType, but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 54cde77176e27..97ddbeba2c5ca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -51,6 +51,18 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Concat(strs.map(Literal.create(_, StringType))), strs.mkString, EmptyRow) } + test("SPARK-22771 Check Concat.checkInputDataTypes results") { + assert(Concat(Seq.empty[Expression]).checkInputDataTypes().isSuccess) + assert(Concat(Literal.create("a") :: Literal.create("b") :: Nil) + .checkInputDataTypes().isSuccess) + assert(Concat(Literal.create("a".getBytes) :: Literal.create("b".getBytes) :: Nil) + .checkInputDataTypes().isSuccess) + assert(Concat(Literal.create(1) :: Literal.create(2) :: Nil) + .checkInputDataTypes().isFailure) + assert(Concat(Literal.create("a") :: Literal.create("b".getBytes) :: Nil) + .checkInputDataTypes().isFailure) + } + test("concat_ws") { def testConcatWs(expected: String, sep: String, inputs: Any*): Unit = { val inputExprs = inputs.map { From 93f92c0ed7442a4382e97254307309977ff676f8 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 4 Jan 2018 11:39:42 -0800 Subject: [PATCH 0014/2461] [SPARK-21475][CORE][2ND ATTEMPT] Change to use NIO's Files API for external shuffle service ## What changes were proposed in this pull request? This PR is the second attempt of #18684 , NIO's Files API doesn't override `skip` method for `InputStream`, so it will bring in performance issue (mentioned in #20119). But using `FileInputStream`/`FileOutputStream` will also bring in memory issue (https://dzone.com/articles/fileinputstream-fileoutputstream-considered-harmful), which is severe for long running external shuffle service. So here in this proposal, only fixing the external shuffle service related code. ## How was this patch tested? Existing tests. Author: jerryshao Closes #20144 from jerryshao/SPARK-21475-v2. --- .../apache/spark/network/buffer/FileSegmentManagedBuffer.java | 3 ++- .../apache/spark/network/shuffle/ShuffleIndexInformation.java | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index c20fab83c3460..8b8f9892847c3 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -24,6 +24,7 @@ import java.io.RandomAccessFile; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; +import java.nio.file.StandardOpenOption; import com.google.common.base.Objects; import com.google.common.io.ByteStreams; @@ -132,7 +133,7 @@ public Object convertToNetty() throws IOException { if (conf.lazyFileDescriptor()) { return new DefaultFileRegion(file, offset, length); } else { - FileChannel fileChannel = new FileInputStream(file).getChannel(); + FileChannel fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ); return new DefaultFileRegion(fileChannel, offset, length); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java index eacf485344b76..386738ece51a6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java @@ -19,10 +19,10 @@ import java.io.DataInputStream; import java.io.File; -import java.io.FileInputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.LongBuffer; +import java.nio.file.Files; /** * Keeps the index information for a particular map output @@ -39,7 +39,7 @@ public ShuffleIndexInformation(File indexFile) throws IOException { offsets = buffer.asLongBuffer(); DataInputStream dis = null; try { - dis = new DataInputStream(new FileInputStream(indexFile)); + dis = new DataInputStream(Files.newInputStream(indexFile.toPath())); dis.readFully(buffer.array()); } finally { if (dis != null) { From d2cddc88eac32f26b18ec26bb59e85c6f09a8c88 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 4 Jan 2018 16:19:00 -0600 Subject: [PATCH 0015/2461] [SPARK-22850][CORE] Ensure queued events are delivered to all event queues. The code in LiveListenerBus was queueing events before start in the queues themselves; so in situations like the following: bus.post(someEvent) bus.addToEventLogQueue(listener) bus.start() "someEvent" would not be delivered to "listener" if that was the first listener in the queue, because the queue wouldn't exist when the event was posted. This change buffers the events before starting the bus in the bus itself, so that they can be delivered to all registered queues when the bus is started. Also tweaked the unit tests to cover the behavior above. Author: Marcelo Vanzin Closes #20039 from vanzin/SPARK-22850. --- .../spark/scheduler/LiveListenerBus.scala | 45 ++++++++++++++++--- .../spark/scheduler/SparkListenerSuite.scala | 21 +++++---- 2 files changed, 52 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 23121402b1025..ba6387a8f08ad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -62,6 +62,9 @@ private[spark] class LiveListenerBus(conf: SparkConf) { private val queues = new CopyOnWriteArrayList[AsyncEventQueue]() + // Visible for testing. + @volatile private[scheduler] var queuedEvents = new mutable.ListBuffer[SparkListenerEvent]() + /** Add a listener to queue shared by all non-internal listeners. */ def addToSharedQueue(listener: SparkListenerInterface): Unit = { addToQueue(listener, SHARED_QUEUE) @@ -125,13 +128,39 @@ private[spark] class LiveListenerBus(conf: SparkConf) { /** Post an event to all queues. */ def post(event: SparkListenerEvent): Unit = { - if (!stopped.get()) { - metrics.numEventsPosted.inc() - val it = queues.iterator() - while (it.hasNext()) { - it.next().post(event) + if (stopped.get()) { + return + } + + metrics.numEventsPosted.inc() + + // If the event buffer is null, it means the bus has been started and we can avoid + // synchronization and post events directly to the queues. This should be the most + // common case during the life of the bus. + if (queuedEvents == null) { + postToQueues(event) + return + } + + // Otherwise, need to synchronize to check whether the bus is started, to make sure the thread + // calling start() picks up the new event. + synchronized { + if (!started.get()) { + queuedEvents += event + return } } + + // If the bus was already started when the check above was made, just post directly to the + // queues. + postToQueues(event) + } + + private def postToQueues(event: SparkListenerEvent): Unit = { + val it = queues.iterator() + while (it.hasNext()) { + it.next().post(event) + } } /** @@ -149,7 +178,11 @@ private[spark] class LiveListenerBus(conf: SparkConf) { } this.sparkContext = sc - queues.asScala.foreach(_.start(sc)) + queues.asScala.foreach { q => + q.start(sc) + queuedEvents.foreach(q.post) + } + queuedEvents = null metricsSystem.registerSource(metrics) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 1beb36afa95f0..da6ecb82c7e42 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -48,7 +48,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match bus.metrics.metricRegistry.counter(s"queue.$SHARED_QUEUE.numDroppedEvents").getCount } - private def queueSize(bus: LiveListenerBus): Int = { + private def sharedQueueSize(bus: LiveListenerBus): Int = { bus.metrics.metricRegistry.getGauges().get(s"queue.$SHARED_QUEUE.size").getValue() .asInstanceOf[Int] } @@ -73,12 +73,11 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val conf = new SparkConf() val counter = new BasicJobCounter val bus = new LiveListenerBus(conf) - bus.addToSharedQueue(counter) // Metrics are initially empty. assert(bus.metrics.numEventsPosted.getCount === 0) assert(numDroppedEvents(bus) === 0) - assert(queueSize(bus) === 0) + assert(bus.queuedEvents.size === 0) assert(eventProcessingTimeCount(bus) === 0) // Post five events: @@ -87,7 +86,10 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Five messages should be marked as received and queued, but no messages should be posted to // listeners yet because the the listener bus hasn't been started. assert(bus.metrics.numEventsPosted.getCount === 5) - assert(queueSize(bus) === 5) + assert(bus.queuedEvents.size === 5) + + // Add the counter to the bus after messages have been queued for later delivery. + bus.addToSharedQueue(counter) assert(counter.count === 0) // Starting listener bus should flush all buffered events @@ -95,9 +97,12 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match Mockito.verify(mockMetricsSystem).registerSource(bus.metrics) bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(counter.count === 5) - assert(queueSize(bus) === 0) + assert(sharedQueueSize(bus) === 0) assert(eventProcessingTimeCount(bus) === 5) + // After the bus is started, there should be no more queued events. + assert(bus.queuedEvents === null) + // After listener bus has stopped, posting events should not increment counter bus.stop() (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } @@ -188,18 +193,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Post a message to the listener bus and wait for processing to begin: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) listenerStarted.acquire() - assert(queueSize(bus) === 0) + assert(sharedQueueSize(bus) === 0) assert(numDroppedEvents(bus) === 0) // If we post an additional message then it should remain in the queue because the listener is // busy processing the first event: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - assert(queueSize(bus) === 1) + assert(sharedQueueSize(bus) === 1) assert(numDroppedEvents(bus) === 0) // The queue is now full, so any additional events posted to the listener will be dropped: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - assert(queueSize(bus) === 1) + assert(sharedQueueSize(bus) === 1) assert(numDroppedEvents(bus) === 1) // Allow the the remaining events to be processed so we can stop the listener bus: From 95f9659abe8845f9f3f42fd7ababd79e55c52489 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 4 Jan 2018 15:00:09 -0800 Subject: [PATCH 0016/2461] [SPARK-22948][K8S] Move SparkPodInitContainer to correct package. Author: Marcelo Vanzin Closes #20156 from vanzin/SPARK-22948. --- dev/sparktestsupport/modules.py | 2 +- .../spark/deploy/{rest => }/k8s/SparkPodInitContainer.scala | 2 +- .../deploy/{rest => }/k8s/SparkPodInitContainerSuite.scala | 2 +- .../docker/src/main/dockerfiles/init-container/Dockerfile | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) rename resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/{rest => }/k8s/SparkPodInitContainer.scala (99%) rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/{rest => }/k8s/SparkPodInitContainerSuite.scala (98%) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index f834563da9dda..7164180a6a7b0 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -539,7 +539,7 @@ def __hash__(self): kubernetes = Module( name="kubernetes", dependencies=[], - source_file_regexes=["resource-managers/kubernetes/core"], + source_file_regexes=["resource-managers/kubernetes"], build_profile_flags=["-Pkubernetes"], sbt_test_goals=["kubernetes/test"] ) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala similarity index 99% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala index 4a4b628aedbbf..c0f08786b76a1 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.rest.k8s +package org.apache.spark.deploy.k8s import java.io.File import java.util.concurrent.TimeUnit diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala similarity index 98% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala index 6c557ec4a7c9a..e0f29ecd0fb53 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.rest.k8s +package org.apache.spark.deploy.k8s import java.io.File import java.util.UUID diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile index 055493188fcb7..047056ab2633b 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile @@ -21,4 +21,4 @@ FROM spark-base # command should be invoked from the top level directory of the Spark distribution. E.g.: # docker build -t spark-init:latest -f kubernetes/dockerfiles/init-container/Dockerfile . -ENTRYPOINT [ "/opt/entrypoint.sh", "/opt/spark/bin/spark-class", "org.apache.spark.deploy.rest.k8s.SparkPodInitContainer" ] +ENTRYPOINT [ "/opt/entrypoint.sh", "/opt/spark/bin/spark-class", "org.apache.spark.deploy.k8s.SparkPodInitContainer" ] From e288fc87a027ec1e1a21401d1f151df20dbfecf3 Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Thu, 4 Jan 2018 15:35:20 -0800 Subject: [PATCH 0017/2461] [SPARK-22953][K8S] Avoids adding duplicated secret volumes when init-container is used ## What changes were proposed in this pull request? User-specified secrets are mounted into both the main container and init-container (when it is used) in a Spark driver/executor pod, using the `MountSecretsBootstrap`. Because `MountSecretsBootstrap` always adds new secret volumes for the secrets to the pod, the same secret volumes get added twice, one when mounting the secrets to the main container, and the other when mounting the secrets to the init-container. This PR fixes the issue by separating `MountSecretsBootstrap.mountSecrets` out into two methods: `addSecretVolumes` for adding secret volumes to a pod and `mountSecrets` for mounting secret volumes to a container, respectively. `addSecretVolumes` is only called once for each pod, whereas `mountSecrets` is called individually for the main container and the init-container (if it is used). Ref: https://github.com/apache-spark-on-k8s/spark/issues/594. ## How was this patch tested? Unit tested and manually tested. vanzin This replaces https://github.com/apache/spark/pull/20148. hex108 foxish kimoonkim Author: Yinan Li Closes #20159 from liyinan926/master. --- .../deploy/k8s/MountSecretsBootstrap.scala | 30 ++++++++++++------- .../k8s/submit/DriverConfigOrchestrator.scala | 16 +++++----- .../steps/BasicDriverConfigurationStep.scala | 2 +- .../submit/steps/DriverMountSecretsStep.scala | 4 +-- .../InitContainerMountSecretsStep.scala | 11 +++---- .../cluster/k8s/ExecutorPodFactory.scala | 6 ++-- .../k8s/{submit => }/SecretVolumeUtils.scala | 18 +++++------ .../BasicDriverConfigurationStepSuite.scala | 4 +-- .../steps/DriverMountSecretsStepSuite.scala | 4 +-- .../InitContainerMountSecretsStepSuite.scala | 7 +---- .../cluster/k8s/ExecutorPodFactorySuite.scala | 14 +++++---- 11 files changed, 61 insertions(+), 55 deletions(-) rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/{submit => }/SecretVolumeUtils.scala (71%) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala index 8286546ce0641..c35e7db51d407 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala @@ -24,26 +24,36 @@ import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBui private[spark] class MountSecretsBootstrap(secretNamesToMountPaths: Map[String, String]) { /** - * Mounts Kubernetes secrets as secret volumes into the given container in the given pod. + * Add new secret volumes for the secrets specified in secretNamesToMountPaths into the given pod. * * @param pod the pod into which the secret volumes are being added. - * @param container the container into which the secret volumes are being mounted. - * @return the updated pod and container with the secrets mounted. + * @return the updated pod with the secret volumes added. */ - def mountSecrets(pod: Pod, container: Container): (Pod, Container) = { + def addSecretVolumes(pod: Pod): Pod = { var podBuilder = new PodBuilder(pod) secretNamesToMountPaths.keys.foreach { name => podBuilder = podBuilder .editOrNewSpec() .addNewVolume() - .withName(secretVolumeName(name)) - .withNewSecret() - .withSecretName(name) - .endSecret() - .endVolume() + .withName(secretVolumeName(name)) + .withNewSecret() + .withSecretName(name) + .endSecret() + .endVolume() .endSpec() } + podBuilder.build() + } + + /** + * Mounts Kubernetes secret volumes of the secrets specified in secretNamesToMountPaths into the + * given container. + * + * @param container the container into which the secret volumes are being mounted. + * @return the updated container with the secrets mounted. + */ + def mountSecrets(container: Container): Container = { var containerBuilder = new ContainerBuilder(container) secretNamesToMountPaths.foreach { case (name, path) => containerBuilder = containerBuilder @@ -53,7 +63,7 @@ private[spark] class MountSecretsBootstrap(secretNamesToMountPaths: Map[String, .endVolumeMount() } - (podBuilder.build(), containerBuilder.build()) + containerBuilder.build() } private def secretVolumeName(secretName: String): String = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala index 00c9c4ee49177..c9cc300d65569 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala @@ -127,6 +127,12 @@ private[spark] class DriverConfigOrchestrator( Nil } + val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { + Seq(new DriverMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) + } else { + Nil + } + val initContainerBootstrapStep = if (existNonContainerLocalFiles(sparkJars ++ sparkFiles)) { val orchestrator = new InitContainerConfigOrchestrator( sparkJars, @@ -147,19 +153,13 @@ private[spark] class DriverConfigOrchestrator( Nil } - val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { - Seq(new DriverMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) - } else { - Nil - } - Seq( initialSubmissionStep, serviceBootstrapStep, kubernetesCredentialsStep) ++ dependencyResolutionStep ++ - initContainerBootstrapStep ++ - mountSecretsStep + mountSecretsStep ++ + initContainerBootstrapStep } private def existNonContainerLocalFiles(files: Seq[String]): Boolean = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index b7a69a7dfd472..eca46b84c6066 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -119,7 +119,7 @@ private[spark] class BasicDriverConfigurationStep( .endEnv() .addNewEnv() .withName(ENV_DRIVER_ARGS) - .withValue(appArgs.map(arg => "\"" + arg + "\"").mkString(" ")) + .withValue(appArgs.mkString(" ")) .endEnv() .addNewEnv() .withName(ENV_DRIVER_BIND_ADDRESS) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala index f872e0f4b65d1..91e9a9f211335 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala @@ -28,8 +28,8 @@ private[spark] class DriverMountSecretsStep( bootstrap: MountSecretsBootstrap) extends DriverConfigurationStep { override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val (pod, container) = bootstrap.mountSecrets( - driverSpec.driverPod, driverSpec.driverContainer) + val pod = bootstrap.addSecretVolumes(driverSpec.driverPod) + val container = bootstrap.mountSecrets(driverSpec.driverContainer) driverSpec.copy( driverPod = pod, driverContainer = container diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala index c0e7bb20cce8c..0daa7b95e8aae 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala @@ -28,12 +28,9 @@ private[spark] class InitContainerMountSecretsStep( bootstrap: MountSecretsBootstrap) extends InitContainerConfigurationStep { override def configureInitContainer(spec: InitContainerSpec) : InitContainerSpec = { - val (driverPod, initContainer) = bootstrap.mountSecrets( - spec.driverPod, - spec.initContainer) - spec.copy( - driverPod = driverPod, - initContainer = initContainer - ) + // Mount the secret volumes given that the volumes have already been added to the driver pod + // when mounting the secrets into the main driver container. + val initContainer = bootstrap.mountSecrets(spec.initContainer) + spec.copy(initContainer = initContainer) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index ba5d891f4c77e..066d7e9f70ca5 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -214,7 +214,7 @@ private[spark] class ExecutorPodFactory( val (maybeSecretsMountedPod, maybeSecretsMountedContainer) = mountSecretsBootstrap.map { bootstrap => - bootstrap.mountSecrets(executorPod, containerWithLimitCores) + (bootstrap.addSecretVolumes(executorPod), bootstrap.mountSecrets(containerWithLimitCores)) }.getOrElse((executorPod, containerWithLimitCores)) val (bootstrappedPod, bootstrappedContainer) = @@ -227,7 +227,9 @@ private[spark] class ExecutorPodFactory( val (pod, mayBeSecretsMountedInitContainer) = initContainerMountSecretsBootstrap.map { bootstrap => - bootstrap.mountSecrets(podWithInitContainer.pod, podWithInitContainer.initContainer) + // Mount the secret volumes given that the volumes have already been added to the + // executor pod when mounting the secrets into the main executor container. + (podWithInitContainer.pod, bootstrap.mountSecrets(podWithInitContainer.initContainer)) }.getOrElse((podWithInitContainer.pod, podWithInitContainer.initContainer)) val bootstrappedPod = KubernetesUtils.appendInitContainer( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala similarity index 71% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala index 8388c16ded268..16780584a674a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit +package org.apache.spark.deploy.k8s import scala.collection.JavaConverters._ @@ -22,15 +22,15 @@ import io.fabric8.kubernetes.api.model.{Container, Pod} private[spark] object SecretVolumeUtils { - def podHasVolume(driverPod: Pod, volumeName: String): Boolean = { - driverPod.getSpec.getVolumes.asScala.exists(volume => volume.getName == volumeName) + def podHasVolume(pod: Pod, volumeName: String): Boolean = { + pod.getSpec.getVolumes.asScala.exists { volume => + volume.getName == volumeName + } } - def containerHasVolume( - driverContainer: Container, - volumeName: String, - mountPath: String): Boolean = { - driverContainer.getVolumeMounts.asScala.exists(volumeMount => - volumeMount.getName == volumeName && volumeMount.getMountPath == mountPath) + def containerHasVolume(container: Container, volumeName: String, mountPath: String): Boolean = { + container.getVolumeMounts.asScala.exists { volumeMount => + volumeMount.getName == volumeName && volumeMount.getMountPath == mountPath + } } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index e864c6a16eeb1..8ee629ac8ddc1 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -33,7 +33,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" private val APP_NAME = "spark-test" private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" - private val APP_ARGS = Array("arg1", "arg2", "arg 3") + private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") private val CUSTOM_ANNOTATION_KEY = "customAnnotation" private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" private val DRIVER_CUSTOM_ENV_KEY1 = "customDriverEnv1" @@ -82,7 +82,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { assert(envs(ENV_SUBMIT_EXTRA_CLASSPATH) === "/opt/spark/spark-examples.jar") assert(envs(ENV_DRIVER_MEMORY) === "256M") assert(envs(ENV_DRIVER_MAIN_CLASS) === MAIN_CLASS) - assert(envs(ENV_DRIVER_ARGS) === "\"arg1\" \"arg2\" \"arg 3\"") + assert(envs(ENV_DRIVER_ARGS) === "arg1 arg2 \"arg 3\"") assert(envs(DRIVER_CUSTOM_ENV_KEY1) === "customDriverEnv1") assert(envs(DRIVER_CUSTOM_ENV_KEY2) === "customDriverEnv2") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala index 9ec0cb55de5aa..960d0bda1d011 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.deploy.k8s.submit.steps import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.MountSecretsBootstrap -import org.apache.spark.deploy.k8s.submit.{KubernetesDriverSpec, SecretVolumeUtils} +import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec class DriverMountSecretsStepSuite extends SparkFunSuite { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala index eab4e17659456..7ac0bde80dfe6 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.deploy.k8s.submit.steps.initcontainer import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder} import org.apache.spark.SparkFunSuite -import org.apache.spark.deploy.k8s.MountSecretsBootstrap -import org.apache.spark.deploy.k8s.submit.SecretVolumeUtils +import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} class InitContainerMountSecretsStepSuite extends SparkFunSuite { @@ -44,12 +43,8 @@ class InitContainerMountSecretsStepSuite extends SparkFunSuite { val initContainerMountSecretsStep = new InitContainerMountSecretsStep(mountSecretsBootstrap) val configuredInitContainerSpec = initContainerMountSecretsStep.configureInitContainer( baseInitContainerSpec) - - val podWithSecretsMounted = configuredInitContainerSpec.driverPod val initContainerWithSecretsMounted = configuredInitContainerSpec.initContainer - Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => - assert(SecretVolumeUtils.podHasVolume(podWithSecretsMounted, volumeName))) Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => assert(SecretVolumeUtils.containerHasVolume( initContainerWithSecretsMounted, volumeName, SECRET_MOUNT_PATH))) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 7121a802c69c1..884da8aabd880 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -25,7 +25,7 @@ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, MountSecretsBootstrap, PodWithDetachedInitContainer} +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, MountSecretsBootstrap, PodWithDetachedInitContainer, SecretVolumeUtils} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ @@ -165,17 +165,19 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef val factory = new ExecutorPodFactory( conf, - None, + Some(secretsBootstrap), Some(initContainerBootstrap), Some(secretsBootstrap)) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + assert(executor.getSpec.getVolumes.size() === 1) + assert(SecretVolumeUtils.podHasVolume(executor, "secret1-volume")) + assert(SecretVolumeUtils.containerHasVolume( + executor.getSpec.getContainers.get(0), "secret1-volume", "/var/secret1")) assert(executor.getSpec.getInitContainers.size() === 1) - assert(executor.getSpec.getInitContainers.get(0).getVolumeMounts.get(0).getName - === "secret1-volume") - assert(executor.getSpec.getInitContainers.get(0).getVolumeMounts.get(0) - .getMountPath === "/var/secret1") + assert(SecretVolumeUtils.containerHasVolume( + executor.getSpec.getInitContainers.get(0), "secret1-volume", "/var/secret1")) checkOwnerReferences(executor, driverPodUid) } From 0428368c2c5e135f99f62be20877bbbda43be310 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 4 Jan 2018 16:34:56 -0800 Subject: [PATCH 0018/2461] [SPARK-22960][K8S] Make build-push-docker-images.sh more dev-friendly. - Make it possible to build images from a git clone. - Make it easy to use minikube to test things. Also fixed what seemed like a bug: the base image wasn't getting the tag provided in the command line. Adding the tag allows users to use multiple Spark builds in the same kubernetes cluster. Tested by deploying images on minikube and running spark-submit from a dev environment; also by building the images with different tags and verifying "docker images" in minikube. Author: Marcelo Vanzin Closes #20154 from vanzin/SPARK-22960. --- docs/running-on-kubernetes.md | 9 +- .../src/main/dockerfiles/driver/Dockerfile | 3 +- .../src/main/dockerfiles/executor/Dockerfile | 3 +- .../dockerfiles/init-container/Dockerfile | 3 +- .../main/dockerfiles/spark-base/Dockerfile | 7 +- sbin/build-push-docker-images.sh | 120 +++++++++++++++--- 6 files changed, 117 insertions(+), 28 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index e491329136a3c..2d69f636472ae 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -16,6 +16,9 @@ Kubernetes scheduler that has been added to Spark. you may setup a test cluster on your local machine using [minikube](https://kubernetes.io/docs/getting-started-guides/minikube/). * We recommend using the latest release of minikube with the DNS addon enabled. + * Be aware that the default minikube configuration is not enough for running Spark applications. + We recommend 3 CPUs and 4g of memory to be able to start a simple Spark application with a single + executor. * You must have appropriate permissions to list, create, edit and delete [pods](https://kubernetes.io/docs/user-guide/pods/) in your cluster. You can verify that you can list these resources by running `kubectl auth can-i pods`. @@ -197,7 +200,7 @@ kubectl port-forward 4040:4040 Then, the Spark driver UI can be accessed on `http://localhost:4040`. -### Debugging +### Debugging There may be several kinds of failures. If the Kubernetes API server rejects the request made from spark-submit, or the connection is refused for a different reason, the submission logic should indicate the error encountered. However, if there @@ -215,8 +218,8 @@ If the pod has encountered a runtime error, the status can be probed further usi kubectl logs ``` -Status and logs of failed executor pods can be checked in similar ways. Finally, deleting the driver pod will clean up the entire spark -application, includling all executors, associated service, etc. The driver pod can be thought of as the Kubernetes representation of +Status and logs of failed executor pods can be checked in similar ways. Finally, deleting the driver pod will clean up the entire spark +application, including all executors, associated service, etc. The driver pod can be thought of as the Kubernetes representation of the Spark application. ## Kubernetes Features diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile index 45fbcd9cd0deb..ff5289e10c21e 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile @@ -15,7 +15,8 @@ # limitations under the License. # -FROM spark-base +ARG base_image +FROM ${base_image} # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile index 0f806cf7e148e..3eabb42d4d852 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile @@ -15,7 +15,8 @@ # limitations under the License. # -FROM spark-base +ARG base_image +FROM ${base_image} # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile index 047056ab2633b..e0a249e0ac71f 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile @@ -15,7 +15,8 @@ # limitations under the License. # -FROM spark-base +ARG base_image +FROM ${base_image} # If this docker file is being used in the context of building your images from a Spark distribution, the docker build # command should be invoked from the top level directory of the Spark distribution. E.g.: diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile index 222e777db3a82..da1d6b9e161cc 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile @@ -17,6 +17,9 @@ FROM openjdk:8-alpine +ARG spark_jars +ARG img_path + # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. # If this docker file is being used in the context of building your images from a Spark @@ -34,11 +37,11 @@ RUN set -ex && \ ln -sv /bin/bash /bin/sh && \ chgrp root /etc/passwd && chmod ug+rw /etc/passwd -COPY jars /opt/spark/jars +COPY ${spark_jars} /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin COPY conf /opt/spark/conf -COPY kubernetes/dockerfiles/spark-base/entrypoint.sh /opt/ +COPY ${img_path}/spark-base/entrypoint.sh /opt/ ENV SPARK_HOME /opt/spark diff --git a/sbin/build-push-docker-images.sh b/sbin/build-push-docker-images.sh index b3137598692d8..bb8806dd33f37 100755 --- a/sbin/build-push-docker-images.sh +++ b/sbin/build-push-docker-images.sh @@ -19,29 +19,94 @@ # This script builds and pushes docker images when run from a release of Spark # with Kubernetes support. -declare -A path=( [spark-driver]=kubernetes/dockerfiles/driver/Dockerfile \ - [spark-executor]=kubernetes/dockerfiles/executor/Dockerfile \ - [spark-init]=kubernetes/dockerfiles/init-container/Dockerfile ) +function error { + echo "$@" 1>&2 + exit 1 +} + +# Detect whether this is a git clone or a Spark distribution and adjust paths +# accordingly. +if [ -z "${SPARK_HOME}" ]; then + SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi +. "${SPARK_HOME}/bin/load-spark-env.sh" + +if [ -f "$SPARK_HOME/RELEASE" ]; then + IMG_PATH="kubernetes/dockerfiles" + SPARK_JARS="jars" +else + IMG_PATH="resource-managers/kubernetes/docker/src/main/dockerfiles" + SPARK_JARS="assembly/target/scala-$SPARK_SCALA_VERSION/jars" +fi + +if [ ! -d "$IMG_PATH" ]; then + error "Cannot find docker images. This script must be run from a runnable distribution of Apache Spark." +fi + +declare -A path=( [spark-driver]="$IMG_PATH/driver/Dockerfile" \ + [spark-executor]="$IMG_PATH/executor/Dockerfile" \ + [spark-init]="$IMG_PATH/init-container/Dockerfile" ) + +function image_ref { + local image="$1" + local add_repo="${2:-1}" + if [ $add_repo = 1 ] && [ -n "$REPO" ]; then + image="$REPO/$image" + fi + if [ -n "$TAG" ]; then + image="$image:$TAG" + fi + echo "$image" +} function build { - docker build -t spark-base -f kubernetes/dockerfiles/spark-base/Dockerfile . + local base_image="$(image_ref spark-base 0)" + docker build --build-arg "spark_jars=$SPARK_JARS" \ + --build-arg "img_path=$IMG_PATH" \ + -t "$base_image" \ + -f "$IMG_PATH/spark-base/Dockerfile" . for image in "${!path[@]}"; do - docker build -t ${REPO}/$image:${TAG} -f ${path[$image]} . + docker build --build-arg "base_image=$base_image" -t "$(image_ref $image)" -f ${path[$image]} . done } - function push { for image in "${!path[@]}"; do - docker push ${REPO}/$image:${TAG} + docker push "$(image_ref $image)" done } function usage { - echo "This script must be run from a runnable distribution of Apache Spark." - echo "Usage: ./sbin/build-push-docker-images.sh -r -t build" - echo " ./sbin/build-push-docker-images.sh -r -t push" - echo "for example: ./sbin/build-push-docker-images.sh -r docker.io/myrepo -t v2.3.0 push" + cat </dev/null; then + error "Cannot find minikube." + fi + eval $(minikube docker-env) + ;; esac done -if [ -z "$REPO" ] || [ -z "$TAG" ]; then +case "${@: -1}" in + build) + build + ;; + push) + if [ -z "$REPO" ]; then + usage + exit 1 + fi + push + ;; + *) usage -else - case "${@: -1}" in - build) build;; - push) push;; - *) usage;; - esac -fi + exit 1 + ;; +esac From df7fc3ef3899cadd252d2837092bebe3442d6523 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 5 Jan 2018 10:16:34 +0800 Subject: [PATCH 0019/2461] [SPARK-22957] ApproxQuantile breaks if the number of rows exceeds MaxInt ## What changes were proposed in this pull request? 32bit Int was used for row rank. That overflowed in a dataframe with more than 2B rows. ## How was this patch tested? Added test, but ignored, as it takes 4 minutes. Author: Juliusz Sompolski Closes #20152 from juliuszsompolski/SPARK-22957. --- .../aggregate/ApproximatePercentile.scala | 12 ++++++------ .../spark/sql/catalyst/util/QuantileSummaries.scala | 8 ++++---- .../org/apache/spark/sql/DataFrameStatSuite.scala | 8 ++++++++ 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 149ac265e6ed5..a45854a3b5146 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -296,8 +296,8 @@ object ApproximatePercentile { Ints.BYTES + Doubles.BYTES + Longs.BYTES + // length of summary.sampled Ints.BYTES + - // summary.sampled, Array[Stat(value: Double, g: Int, delta: Int)] - summaries.sampled.length * (Doubles.BYTES + Ints.BYTES + Ints.BYTES) + // summary.sampled, Array[Stat(value: Double, g: Long, delta: Long)] + summaries.sampled.length * (Doubles.BYTES + Longs.BYTES + Longs.BYTES) } final def serialize(obj: PercentileDigest): Array[Byte] = { @@ -312,8 +312,8 @@ object ApproximatePercentile { while (i < summary.sampled.length) { val stat = summary.sampled(i) buffer.putDouble(stat.value) - buffer.putInt(stat.g) - buffer.putInt(stat.delta) + buffer.putLong(stat.g) + buffer.putLong(stat.delta) i += 1 } buffer.array() @@ -330,8 +330,8 @@ object ApproximatePercentile { var i = 0 while (i < sampledLength) { val value = buffer.getDouble() - val g = buffer.getInt() - val delta = buffer.getInt() + val g = buffer.getLong() + val delta = buffer.getLong() sampled(i) = Stats(value, g, delta) i += 1 } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index eb7941cf9e6af..b013add9c9778 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -105,7 +105,7 @@ class QuantileSummaries( if (newSamples.isEmpty || (sampleIdx == sampled.length && opsIdx == sorted.length - 1)) { 0 } else { - math.floor(2 * relativeError * currentCount).toInt + math.floor(2 * relativeError * currentCount).toLong } val tuple = Stats(currentSample, 1, delta) @@ -192,10 +192,10 @@ class QuantileSummaries( } // Target rank - val rank = math.ceil(quantile * count).toInt + val rank = math.ceil(quantile * count).toLong val targetError = relativeError * count // Minimum rank at current sample - var minRank = 0 + var minRank = 0L var i = 0 while (i < sampled.length - 1) { val curSample = sampled(i) @@ -235,7 +235,7 @@ object QuantileSummaries { * @param g the minimum rank jump from the previous value's minimum rank * @param delta the maximum span of the rank. */ - case class Stats(value: Double, g: Int, delta: Int) + case class Stats(value: Double, g: Long, delta: Long) private def compressImmut( currentSamples: IndexedSeq[Stats], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 46b21c3b64a2e..5169d2b5fc6b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -260,6 +260,14 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(res2(1).isEmpty) } + // SPARK-22957: check for 32bit overflow when computing rank. + // ignored - takes 4 minutes to run. + ignore("approx quantile 4: test for Int overflow") { + val res = spark.range(3000000000L).stat.approxQuantile("id", Array(0.8, 0.9), 0.05) + assert(res(0) > 2200000000.0) + assert(res(1) > 2200000000.0) + } + test("crosstab") { withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { val rng = new Random() From 52fc5c17d9d784b846149771b398e741621c0b5c Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 5 Jan 2018 14:02:21 +0800 Subject: [PATCH 0020/2461] [SPARK-22825][SQL] Fix incorrect results of Casting Array to String ## What changes were proposed in this pull request? This pr fixed the issue when casting arrays into strings; ``` scala> val df = spark.range(10).select('id.cast("integer")).agg(collect_list('id).as('ids)) scala> df.write.saveAsTable("t") scala> sql("SELECT cast(ids as String) FROM t").show(false) +------------------------------------------------------------------+ |ids | +------------------------------------------------------------------+ |org.apache.spark.sql.catalyst.expressions.UnsafeArrayData8bc285df| +------------------------------------------------------------------+ ``` This pr modified the result into; ``` +------------------------------+ |ids | +------------------------------+ |[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]| +------------------------------+ ``` ## How was this patch tested? Added tests in `CastSuite` and `SQLQuerySuite`. Author: Takeshi Yamamuro Closes #20024 from maropu/SPARK-22825. --- .../codegen/UTF8StringBuilder.java | 78 +++++++++++++++++++ .../spark/sql/catalyst/expressions/Cast.scala | 68 ++++++++++++++++ .../sql/catalyst/expressions/CastSuite.scala | 25 ++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 2 - 4 files changed, 171 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java new file mode 100644 index 0000000000000..f0f66bae245fd --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A helper class to write {@link UTF8String}s to an internal buffer and build the concatenated + * {@link UTF8String} at the end. + */ +public class UTF8StringBuilder { + + private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; + + private byte[] buffer; + private int cursor = Platform.BYTE_ARRAY_OFFSET; + + public UTF8StringBuilder() { + // Since initial buffer size is 16 in `StringBuilder`, we set the same size here + this.buffer = new byte[16]; + } + + // Grows the buffer by at least `neededSize` + private void grow(int neededSize) { + if (neededSize > ARRAY_MAX - totalSize()) { + throw new UnsupportedOperationException( + "Cannot grow internal buffer by size " + neededSize + " because the size after growing " + + "exceeds size limitation " + ARRAY_MAX); + } + final int length = totalSize() + neededSize; + if (buffer.length < length) { + int newLength = length < ARRAY_MAX / 2 ? length * 2 : ARRAY_MAX; + final byte[] tmp = new byte[newLength]; + Platform.copyMemory( + buffer, + Platform.BYTE_ARRAY_OFFSET, + tmp, + Platform.BYTE_ARRAY_OFFSET, + totalSize()); + buffer = tmp; + } + } + + private int totalSize() { + return cursor - Platform.BYTE_ARRAY_OFFSET; + } + + public void append(UTF8String value) { + grow(value.numBytes()); + value.writeToMemory(buffer, cursor); + cursor += value.numBytes(); + } + + public void append(String value) { + append(UTF8String.fromString(value)); + } + + public UTF8String build() { + return UTF8String.fromBytes(buffer, 0, totalSize()); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 274d8813f16db..d4fc5e0f168a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -206,6 +206,28 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d))) case TimestampType => buildCast[Long](_, t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone))) + case ArrayType(et, _) => + buildCast[ArrayData](_, array => { + val builder = new UTF8StringBuilder + builder.append("[") + if (array.numElements > 0) { + val toUTF8String = castToString(et) + if (!array.isNullAt(0)) { + builder.append(toUTF8String(array.get(0, et)).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < array.numElements) { + builder.append(",") + if (!array.isNullAt(i)) { + builder.append(" ") + builder.append(toUTF8String(array.get(i, et)).asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append("]") + builder.build() + }) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -597,6 +619,41 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """ } + private def writeArrayToStringBuilder( + et: DataType, + array: String, + buffer: String, + ctx: CodegenContext): String = { + val elementToStringCode = castToStringCode(et, ctx) + val funcName = ctx.freshName("elementToString") + val elementToStringFunc = ctx.addNewFunction(funcName, + s""" + |private UTF8String $funcName(${ctx.javaType(et)} element) { + | UTF8String elementStr = null; + | ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)} + | return elementStr; + |} + """.stripMargin) + + val loopIndex = ctx.freshName("loopIndex") + s""" + |$buffer.append("["); + |if ($array.numElements() > 0) { + | if (!$array.isNullAt(0)) { + | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")})); + | } + | for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) { + | $buffer.append(","); + | if (!$array.isNullAt($loopIndex)) { + | $buffer.append(" "); + | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)})); + | } + | } + |} + |$buffer.append("]"); + """.stripMargin + } + private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => @@ -608,6 +665,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val tz = ctx.addReferenceObj("timeZone", timeZone) (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));""" + case ArrayType(et, _) => + (c, evPrim, evNull) => { + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx) + s""" + |$bufferClass $buffer = new $bufferClass(); + |$writeArrayElemCode; + |$evPrim = $buffer.build(); + """.stripMargin + } case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 1dd040e4696a1..e3ed7171defd8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -853,4 +853,29 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { cast("2", LongType).genCode(ctx) assert(ctx.inlinedMutableStates.length == 0) } + + test("SPARK-22825 Cast array to string") { + val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType) + checkEvaluation(ret1, "[1, 2, 3, 4, 5]") + val ret2 = cast(Literal.create(Array("ab", "cde", "f")), StringType) + checkEvaluation(ret2, "[ab, cde, f]") + val ret3 = cast(Literal.create(Array("ab", null, "c")), StringType) + checkEvaluation(ret3, "[ab,, c]") + val ret4 = cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType) + checkEvaluation(ret4, "[ab, cde, f]") + val ret5 = cast( + Literal.create(Array("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf)), + StringType) + checkEvaluation(ret5, "[2014-12-03, 2014-12-04, 2014-12-06]") + val ret6 = cast( + Literal.create(Array("2014-12-03 13:01:00", "2014-12-04 15:05:00").map(Timestamp.valueOf)), + StringType) + checkEvaluation(ret6, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]") + val ret7 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType) + checkEvaluation(ret7, "[[1, 2, 3], [4, 5]]") + val ret8 = cast( + Literal.create(Array(Array(Array("a"), Array("b", "c")), Array(Array("d")))), + StringType) + checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5e077285ade55..96bf65fce9c4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -28,8 +28,6 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} -import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} -import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf From cf0aa65576acbe0209c67f04c029058fd73555c1 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Thu, 4 Jan 2018 22:45:15 -0800 Subject: [PATCH 0021/2461] [SPARK-22949][ML] Apply CrossValidator approach to Driver/Distributed memory tradeoff for TrainValidationSplit ## What changes were proposed in this pull request? Avoid holding all models in memory for `TrainValidationSplit`. ## How was this patch tested? Existing tests. Author: Bago Amirbekian Closes #20143 from MrBago/trainValidMemoryFix. --- .../spark/ml/tuning/CrossValidator.scala | 4 +++- .../spark/ml/tuning/TrainValidationSplit.scala | 18 ++++-------------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 095b54c0fe83f..a0b507d2e718c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -160,8 +160,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) } (executionContext) } - // Wait for metrics to be calculated before unpersisting validation dataset + // Wait for metrics to be calculated val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) + + // Unpersist training & validation set once all metrics have been produced trainingDataset.unpersist() validationDataset.unpersist() foldMetrics diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index c73bd18475475..8826ef3271bc1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -143,24 +143,13 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St // Fit models in a Future for training in parallel logDebug(s"Train split with multiple sets of parameters.") - val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => - Future[Model[_]] { + val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => + Future[Double] { val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] if (collectSubModelsParam) { subModels.get(paramIndex) = model } - model - } (executionContext) - } - - // Unpersist training data only when all models have trained - Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext) - .onComplete { _ => trainingDataset.unpersist() } (executionContext) - - // Evaluate models in a Future that will calulate a metric and allow model to be cleaned up - val metricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) => - modelFuture.map { model => // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) logDebug(s"Got metric $metric for model trained with $paramMap.") @@ -171,7 +160,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St // Wait for all metrics to be calculated val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) - // Unpersist validation set once all metrics have been produced + // Unpersist training & validation set once all metrics have been produced + trainingDataset.unpersist() validationDataset.unpersist() logInfo(s"Train validation split metrics: ${metrics.toSeq}") From 6cff7d19f6a905fe425bd6892fe7ca014c0e696b Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Thu, 4 Jan 2018 23:23:41 -0800 Subject: [PATCH 0022/2461] [SPARK-22757][K8S] Enable spark.jars and spark.files in KUBERNETES mode ## What changes were proposed in this pull request? We missed enabling `spark.files` and `spark.jars` in https://github.com/apache/spark/pull/19954. The result is that remote dependencies specified through `spark.files` or `spark.jars` are not included in the list of remote dependencies to be downloaded by the init-container. This PR fixes it. ## How was this patch tested? Manual tests. vanzin This replaces https://github.com/apache/spark/pull/20157. foxish Author: Yinan Li Closes #20160 from liyinan926/SPARK-22757. --- .../src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index cbe1f2c3e08a1..1e381965c52ba 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -584,10 +584,11 @@ object SparkSubmit extends CommandLineUtils with Logging { confKey = "spark.executor.memory"), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, confKey = "spark.cores.max"), - OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, + OptionAssigner(args.files, LOCAL | STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, confKey = "spark.files"), OptionAssigner(args.jars, LOCAL, CLIENT, confKey = "spark.jars"), - OptionAssigner(args.jars, STANDALONE | MESOS, ALL_DEPLOY_MODES, confKey = "spark.jars"), + OptionAssigner(args.jars, STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, + confKey = "spark.jars"), OptionAssigner(args.driverMemory, STANDALONE | MESOS | YARN | KUBERNETES, CLUSTER, confKey = "spark.driver.memory"), OptionAssigner(args.driverCores, STANDALONE | MESOS | YARN | KUBERNETES, CLUSTER, From 51c33bd0d402af9e0284c6cbc0111f926446bfba Mon Sep 17 00:00:00 2001 From: Adrian Ionescu Date: Fri, 5 Jan 2018 21:32:39 +0800 Subject: [PATCH 0023/2461] [SPARK-22961][REGRESSION] Constant columns should generate QueryPlanConstraints ## What changes were proposed in this pull request? #19201 introduced the following regression: given something like `df.withColumn("c", lit(2))`, we're no longer picking up `c === 2` as a constraint and infer filters from it when joins are involved, which may lead to noticeable performance degradation. This patch re-enables this optimization by picking up Aliases of Literals in Projection lists as constraints and making sure they're not treated as aliased columns. ## How was this patch tested? Unit test was added. Author: Adrian Ionescu Closes #20155 from adrian-ionescu/constant_constraints. --- .../sql/catalyst/plans/logical/LogicalPlan.scala | 2 ++ .../plans/logical/QueryPlanConstraints.scala | 2 +- .../InferFiltersFromConstraintsSuite.scala | 13 +++++++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index a38458add7b5e..ff2a0ec588567 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -247,6 +247,8 @@ abstract class UnaryNode extends LogicalPlan { protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = { var allConstraints = child.constraints.asInstanceOf[Set[Expression]] projectList.foreach { + case a @ Alias(l: Literal, _) => + allConstraints += EqualTo(a.toAttribute, l) case a @ Alias(e, _) => // For every alias in `projectList`, replace the reference in constraints by its attribute. allConstraints ++= allConstraints.map(_ transform { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index b0f611fd38dea..9c0a30a47f839 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -98,7 +98,7 @@ trait QueryPlanConstraints { self: LogicalPlan => // we may avoid producing recursive constraints. private lazy val aliasMap: AttributeMap[Expression] = AttributeMap( expressions.collect { - case a: Alias => (a.toAttribute, a.child) + case a: Alias if !a.child.isInstanceOf[Literal] => (a.toAttribute, a.child) } ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints].aliasMap)) // Note: the explicit cast is necessary, since Scala compiler fails to infer the type. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 5580f8604ec72..a0708bf7eee9a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -236,4 +236,17 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(optimized, originalQuery) } } + + test("constraints should be inferred from aliased literals") { + val originalLeft = testRelation.subquery('left).as("left") + val optimizedLeft = testRelation.subquery('left).where(IsNotNull('a) && 'a === 2).as("left") + + val right = Project(Seq(Literal(2).as("two")), testRelation.subquery('right)).as("right") + val condition = Some("left.a".attr === "right.two".attr) + + val original = originalLeft.join(right, Inner, condition) + val correct = optimizedLeft.join(right, Inner, condition) + + comparePlans(Optimize.execute(original.analyze), correct.analyze) + } } From c0b7424ecacb56d3e7a18acc11ba3d5e7be57c43 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Fri, 5 Jan 2018 09:58:28 -0800 Subject: [PATCH 0024/2461] [SPARK-22940][SQL] HiveExternalCatalogVersionsSuite should succeed on platforms that don't have wget ## What changes were proposed in this pull request? Modified HiveExternalCatalogVersionsSuite.scala to use Utils.doFetchFile to download different versions of Spark binaries rather than launching wget as an external process. On platforms that don't have wget installed, this suite fails with an error. cloud-fan : would you like to check this change? ## How was this patch tested? 1) test-only of HiveExternalCatalogVersionsSuite on several platforms. Tested bad mirror, read timeout, and redirects. 2) ./dev/run-tests Author: Bruce Robbins Closes #20147 from bersprockets/SPARK-22940-alt. --- .../HiveExternalCatalogVersionsSuite.scala | 48 ++++++++++++++++--- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index a3d5b941a6761..ae4aeb7b4ce4a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -18,11 +18,14 @@ package org.apache.spark.sql.hive import java.io.File -import java.nio.file.Files +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Paths} import scala.sys.process._ -import org.apache.spark.TestUtils +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SecurityManager, SparkConf, TestUtils} import org.apache.spark.sql.{QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTableType @@ -55,14 +58,19 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private def tryDownloadSpark(version: String, path: String): Unit = { // Try mirrors a few times until one succeeds for (i <- 0 until 3) { + // we don't retry on a failure to get mirror url. If we can't get a mirror url, + // the test fails (getStringFromUrl will throw an exception) val preferredMirror = - Seq("wget", "https://www.apache.org/dyn/closer.lua?preferred=true", "-q", "-O", "-").!!.trim - val url = s"$preferredMirror/spark/spark-$version/spark-$version-bin-hadoop2.7.tgz" + getStringFromUrl("https://www.apache.org/dyn/closer.lua?preferred=true") + val filename = s"spark-$version-bin-hadoop2.7.tgz" + val url = s"$preferredMirror/spark/spark-$version/$filename" logInfo(s"Downloading Spark $version from $url") - if (Seq("wget", url, "-q", "-P", path).! == 0) { + try { + getFileFromUrl(url, path, filename) return + } catch { + case ex: Exception => logWarning(s"Failed to download Spark $version from $url", ex) } - logWarning(s"Failed to download Spark $version from $url") } fail(s"Unable to download Spark $version") } @@ -85,6 +93,34 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { new File(tmpDataDir, name).getCanonicalPath } + private def getFileFromUrl(urlString: String, targetDir: String, filename: String): Unit = { + val conf = new SparkConf + // if the caller passes the name of an existing file, we want doFetchFile to write over it with + // the contents from the specified url. + conf.set("spark.files.overwrite", "true") + val securityManager = new SecurityManager(conf) + val hadoopConf = new Configuration + + val outDir = new File(targetDir) + if (!outDir.exists()) { + outDir.mkdirs() + } + + // propagate exceptions up to the caller of getFileFromUrl + Utils.doFetchFile(urlString, outDir, filename, conf, securityManager, hadoopConf) + } + + private def getStringFromUrl(urlString: String): String = { + val contentFile = File.createTempFile("string-", ".txt") + contentFile.deleteOnExit() + + // exceptions will propagate to the caller of getStringFromUrl + getFileFromUrl(urlString, contentFile.getParent, contentFile.getName) + + val contentPath = Paths.get(contentFile.toURI) + new String(Files.readAllBytes(contentPath), StandardCharsets.UTF_8) + } + override def beforeAll(): Unit = { super.beforeAll() From 930b90a84871e2504b57ed50efa7b8bb52d3ba44 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 5 Jan 2018 11:51:25 -0800 Subject: [PATCH 0025/2461] [SPARK-13030][ML] Follow-up cleanups for OneHotEncoderEstimator ## What changes were proposed in this pull request? Follow-up cleanups for the OneHotEncoderEstimator PR. See some discussion in the original PR: https://github.com/apache/spark/pull/19527 or read below for what this PR includes: * configedCategorySize: I reverted this to return an Array. I realized the original setup (which I had recommended in the original PR) caused the whole model to be serialized in the UDF. * encoder: I reorganized the logic to show what I meant in the comment in the previous PR. I think it's simpler but am open to suggestions. I also made some small style cleanups based on IntelliJ warnings. ## How was this patch tested? Existing unit tests Author: Joseph K. Bradley Closes #20132 from jkbradley/viirya-SPARK-13030. --- .../ml/feature/OneHotEncoderEstimator.scala | 92 ++++++++++--------- 1 file changed, 49 insertions(+), 43 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index 074622d41e28d..bd1e3426c8780 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -30,24 +30,27 @@ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.functions.{col, lit, udf} -import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType} +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid with HasInputCols with HasOutputCols { /** - * Param for how to handle invalid data. + * Param for how to handle invalid data during transform(). * Options are 'keep' (invalid data presented as an extra categorical feature) or * 'error' (throw an error). + * Note that this Param is only used during transform; during fitting, invalid data + * will result in an error. * Default: "error" * @group param */ @Since("2.3.0") override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", - "How to handle invalid data " + + "How to handle invalid data during transform(). " + "Options are 'keep' (invalid data presented as an extra categorical feature) " + - "or error (throw an error).", + "or error (throw an error). Note that this Param is only used during transform; " + + "during fitting, invalid data will result in an error.", ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids)) setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID) @@ -66,10 +69,11 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid def getDropLast: Boolean = $(dropLast) protected def validateAndTransformSchema( - schema: StructType, dropLast: Boolean, keepInvalid: Boolean): StructType = { + schema: StructType, + dropLast: Boolean, + keepInvalid: Boolean): StructType = { val inputColNames = $(inputCols) val outputColNames = $(outputCols) - val existingFields = schema.fields require(inputColNames.length == outputColNames.length, s"The number of input columns ${inputColNames.length} must be the same as the number of " + @@ -197,6 +201,10 @@ object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimat override def load(path: String): OneHotEncoderEstimator = super.load(path) } +/** + * @param categorySizes Original number of categories for each feature being encoded. + * The array contains one value for each input column, in order. + */ @Since("2.3.0") class OneHotEncoderModel private[ml] ( @Since("2.3.0") override val uid: String, @@ -205,60 +213,58 @@ class OneHotEncoderModel private[ml] ( import OneHotEncoderModel._ - // Returns the category size for a given index with `dropLast` and `handleInvalid` + // Returns the category size for each index with `dropLast` and `handleInvalid` // taken into account. - private def configedCategorySize(orgCategorySize: Int, idx: Int): Int = { + private def getConfigedCategorySizes: Array[Int] = { val dropLast = getDropLast val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID if (!dropLast && keepInvalid) { // When `handleInvalid` is "keep", an extra category is added as last category // for invalid data. - orgCategorySize + 1 + categorySizes.map(_ + 1) } else if (dropLast && !keepInvalid) { // When `dropLast` is true, the last category is removed. - orgCategorySize - 1 + categorySizes.map(_ - 1) } else { // When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid // data is removed. Thus, it is the same as the plain number of categories. - orgCategorySize + categorySizes } } private def encoder: UserDefinedFunction = { - val oneValue = Array(1.0) - val emptyValues = Array.empty[Double] - val emptyIndices = Array.empty[Int] - val dropLast = getDropLast - val handleInvalid = getHandleInvalid - val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID + val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID + val configedSizes = getConfigedCategorySizes + val localCategorySizes = categorySizes // The udf performed on input data. The first parameter is the input value. The second - // parameter is the index of input. - udf { (label: Double, idx: Int) => - val plainNumCategories = categorySizes(idx) - val size = configedCategorySize(plainNumCategories, idx) - - if (label < 0) { - throw new SparkException(s"Negative value: $label. Input can't be negative.") - } else if (label == size && dropLast && !keepInvalid) { - // When `dropLast` is true and `handleInvalid` is not "keep", - // the last category is removed. - Vectors.sparse(size, emptyIndices, emptyValues) - } else if (label >= plainNumCategories && keepInvalid) { - // When `handleInvalid` is "keep", encodes invalid data to last category (and removed - // if `dropLast` is true) - if (dropLast) { - Vectors.sparse(size, emptyIndices, emptyValues) + // parameter is the index in inputCols of the column being encoded. + udf { (label: Double, colIdx: Int) => + val origCategorySize = localCategorySizes(colIdx) + // idx: index in vector of the single 1-valued element + val idx = if (label >= 0 && label < origCategorySize) { + label + } else { + if (keepInvalid) { + origCategorySize } else { - Vectors.sparse(size, Array(size - 1), oneValue) + if (label < 0) { + throw new SparkException(s"Negative value: $label. Input can't be negative. " + + s"To handle invalid values, set Param handleInvalid to " + + s"${OneHotEncoderEstimator.KEEP_INVALID}") + } else { + throw new SparkException(s"Unseen value: $label. To handle unseen values, " + + s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.") + } } - } else if (label < plainNumCategories) { - Vectors.sparse(size, Array(label.toInt), oneValue) + } + + val size = configedSizes(colIdx) + if (idx < size) { + Vectors.sparse(size, Array(idx.toInt), Array(1.0)) } else { - assert(handleInvalid == OneHotEncoderEstimator.ERROR_INVALID) - throw new SparkException(s"Unseen value: $label. To handle unseen values, " + - s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.") + Vectors.sparse(size, Array.empty[Int], Array.empty[Double]) } } } @@ -282,7 +288,6 @@ class OneHotEncoderModel private[ml] ( @Since("2.3.0") override def transformSchema(schema: StructType): StructType = { val inputColNames = $(inputCols) - val outputColNames = $(outputCols) require(inputColNames.length == categorySizes.length, s"The number of input columns ${inputColNames.length} must be the same as the number of " + @@ -300,6 +305,7 @@ class OneHotEncoderModel private[ml] ( * account. Mismatched numbers will cause exception. */ private def verifyNumOfValues(schema: StructType): StructType = { + val configedSizes = getConfigedCategorySizes $(outputCols).zipWithIndex.foreach { case (outputColName, idx) => val inputColName = $(inputCols)(idx) val attrGroup = AttributeGroup.fromStructField(schema(outputColName)) @@ -308,9 +314,9 @@ class OneHotEncoderModel private[ml] ( // comparing with expected category number with `handleInvalid` and // `dropLast` taken into account. if (attrGroup.attributes.nonEmpty) { - val numCategories = configedCategorySize(categorySizes(idx), idx) + val numCategories = configedSizes(idx) require(attrGroup.size == numCategories, "OneHotEncoderModel expected " + - s"$numCategories categorical values for input column ${inputColName}, " + + s"$numCategories categorical values for input column $inputColName, " + s"but the input column had metadata specifying ${attrGroup.size} values.") } } @@ -322,7 +328,7 @@ class OneHotEncoderModel private[ml] ( val transformedSchema = transformSchema(dataset.schema, logging = true) val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID - val encodedColumns = (0 until $(inputCols).length).map { idx => + val encodedColumns = $(inputCols).indices.map { idx => val inputColName = $(inputCols)(idx) val outputColName = $(outputCols)(idx) From ea956833017fcbd8ed2288368bfa2e417a2251c5 Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Fri, 5 Jan 2018 17:25:28 -0800 Subject: [PATCH 0026/2461] [SPARK-22914][DEPLOY] Register history.ui.port ## What changes were proposed in this pull request? Register spark.history.ui.port as a known spark conf to be used in substitution expressions even if it's not set explicitly. ## How was this patch tested? Added unit test to demonstrate the issue Author: Gera Shegalov Author: Gera Shegalov Closes #20098 from gerashegalov/gera/register-SHS-port-conf. --- .../spark/deploy/history/HistoryServer.scala | 3 +- .../apache/spark/deploy/history/config.scala | 5 +++ .../spark/deploy/yarn/ApplicationMaster.scala | 17 +++++--- .../deploy/yarn/ApplicationMasterSuite.scala | 43 +++++++++++++++++++ 4 files changed, 62 insertions(+), 6 deletions(-) create mode 100644 resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 75484f5c9f30f..0ec4afad0308c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -28,6 +28,7 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.history.config.HISTORY_SERVER_UI_PORT import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, UIRoot} @@ -276,7 +277,7 @@ object HistoryServer extends Logging { .newInstance(conf) .asInstanceOf[ApplicationHistoryProvider] - val port = conf.getInt("spark.history.ui.port", 18080) + val port = conf.get(HISTORY_SERVER_UI_PORT) val server = new HistoryServer(conf, provider, securityManager, port) server.bind() diff --git a/core/src/main/scala/org/apache/spark/deploy/history/config.scala b/core/src/main/scala/org/apache/spark/deploy/history/config.scala index 22b6d49d8e2a4..efdbf672bb52f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/config.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/config.scala @@ -44,4 +44,9 @@ private[spark] object config { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("10g") + val HISTORY_SERVER_UI_PORT = ConfigBuilder("spark.history.ui.port") + .doc("Web UI port to bind Spark History Server") + .intConf + .createWithDefault(18080) + } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index b2576b0d72633..4d5e3bb043671 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -427,11 +427,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends uiAddress: Option[String]) = { val appId = client.getAttemptId().getApplicationId().toString() val attemptId = client.getAttemptId().getAttemptId().toString() - val historyAddress = - _sparkConf.get(HISTORY_SERVER_ADDRESS) - .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) } - .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } - .getOrElse("") + val historyAddress = ApplicationMaster + .getHistoryServerAddress(_sparkConf, yarnConf, appId, attemptId) val driverUrl = RpcEndpointAddress( _sparkConf.get("spark.driver.host"), @@ -834,6 +831,16 @@ object ApplicationMaster extends Logging { master.getAttemptId } + private[spark] def getHistoryServerAddress( + sparkConf: SparkConf, + yarnConf: YarnConfiguration, + appId: String, + attemptId: String): String = { + sparkConf.get(HISTORY_SERVER_ADDRESS) + .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) } + .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } + .getOrElse("") + } } /** diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala new file mode 100644 index 0000000000000..695a82f3583e6 --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import org.apache.hadoop.yarn.conf.YarnConfiguration + +import org.apache.spark.{SparkConf, SparkFunSuite} + +class ApplicationMasterSuite extends SparkFunSuite { + + test("history url with hadoop and spark substitutions") { + val host = "rm.host.com" + val port = 18080 + val sparkConf = new SparkConf() + + sparkConf.set("spark.yarn.historyServer.address", + "http://${hadoopconf-yarn.resourcemanager.hostname}:${spark.history.ui.port}") + val yarnConf = new YarnConfiguration() + yarnConf.set("yarn.resourcemanager.hostname", host) + val appId = "application_123_1" + val attemptId = appId + "_1" + + val shsAddr = ApplicationMaster + .getHistoryServerAddress(sparkConf, yarnConf, appId, attemptId) + + assert(shsAddr === s"http://${host}:${port}/history/${appId}/${attemptId}") + } +} From e8af7e8aeca15a6107248f358d9514521ffdc6d3 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 6 Jan 2018 09:26:03 +0800 Subject: [PATCH 0027/2461] [SPARK-22937][SQL] SQL elt output binary for binary inputs ## What changes were proposed in this pull request? This pr modified `elt` to output binary for binary inputs. `elt` in the current master always output data as a string. But, in some databases (e.g., MySQL), if all inputs are binary, `elt` also outputs binary (Also, this might be a small surprise). This pr is related to #19977. ## How was this patch tested? Added tests in `SQLQueryTestSuite` and `TypeCoercionSuite`. Author: Takeshi Yamamuro Closes #20135 from maropu/SPARK-22937. --- docs/sql-programming-guide.md | 2 + .../sql/catalyst/analysis/TypeCoercion.scala | 29 +++++ .../expressions/stringExpressions.scala | 46 ++++--- .../apache/spark/sql/internal/SQLConf.scala | 8 ++ .../catalyst/analysis/TypeCoercionSuite.scala | 54 ++++++++ .../inputs/typeCoercion/native/elt.sql | 44 +++++++ .../results/typeCoercion/native/elt.sql.out | 115 ++++++++++++++++++ 7 files changed, 281 insertions(+), 17 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index dc3e384008d27..b50f9360b866c 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1783,6 +1783,8 @@ options. - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. + - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. + ## Upgrading From Spark SQL 2.1 to 2.2 - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e9436367c7e2e..e8669c4637d06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -54,6 +54,7 @@ object TypeCoercion { BooleanEquality :: FunctionArgumentConversion :: ConcatCoercion(conf) :: + EltCoercion(conf) :: CaseWhenCoercion :: IfCoercion :: StackCoercion :: @@ -684,6 +685,34 @@ object TypeCoercion { } } + /** + * Coerces the types of [[Elt]] children to expected ones. + * + * If `spark.sql.function.eltOutputAsString` is false and all children types are binary, + * the expected types are binary. Otherwise, the expected ones are strings. + */ + case class EltCoercion(conf: SQLConf) extends TypeCoercionRule { + + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => + p transformExpressionsUp { + // Skip nodes if unresolved or not enough children + case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c + case c @ Elt(children) => + val index = children.head + val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index) + val newInputs = if (conf.eltOutputAsString || + !children.tail.map(_.dataType).forall(_ == BinaryType)) { + children.tail.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + } + } else { + children.tail + } + c.copy(children = newIndex +: newInputs) + } + } + } + /** * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType * to TimeAdd/TimeSub diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 41dc762154a4c..e004bfc6af473 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -271,33 +271,45 @@ case class ConcatWs(children: Seq[Expression]) } } +/** + * An expression that returns the `n`-th input in given inputs. + * If all inputs are binary, `elt` returns an output as binary. Otherwise, it returns as string. + * If any input is null, `elt` returns null. + */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(n, str1, str2, ...) - Returns the `n`-th string, e.g., returns `str2` when `n` is 2.", + usage = "_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.", examples = """ Examples: > SELECT _FUNC_(1, 'scala', 'java'); scala """) // scalastyle:on line.size.limit -case class Elt(children: Seq[Expression]) - extends Expression with ImplicitCastInputTypes { +case class Elt(children: Seq[Expression]) extends Expression { private lazy val indexExpr = children.head - private lazy val stringExprs = children.tail.toArray + private lazy val inputExprs = children.tail.toArray /** This expression is always nullable because it returns null if index is out of range. */ override def nullable: Boolean = true - override def dataType: DataType = StringType - - override def inputTypes: Seq[DataType] = IntegerType +: Seq.fill(children.size - 1)(StringType) + override def dataType: DataType = inputExprs.map(_.dataType).headOption.getOrElse(StringType) override def checkInputDataTypes(): TypeCheckResult = { if (children.size < 2) { TypeCheckResult.TypeCheckFailure("elt function requires at least two arguments") } else { - super[ImplicitCastInputTypes].checkInputDataTypes() + val (indexType, inputTypes) = (indexExpr.dataType, inputExprs.map(_.dataType)) + if (indexType != IntegerType) { + return TypeCheckResult.TypeCheckFailure(s"first input to function $prettyName should " + + s"have IntegerType, but it's $indexType") + } + if (inputTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { + return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have StringType or BinaryType, but it's " + + inputTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName") } } @@ -307,27 +319,27 @@ case class Elt(children: Seq[Expression]) null } else { val index = indexObj.asInstanceOf[Int] - if (index <= 0 || index > stringExprs.length) { + if (index <= 0 || index > inputExprs.length) { null } else { - stringExprs(index - 1).eval(input) + inputExprs(index - 1).eval(input) } } } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val index = indexExpr.genCode(ctx) - val strings = stringExprs.map(_.genCode(ctx)) + val inputs = inputExprs.map(_.genCode(ctx)) val indexVal = ctx.freshName("index") val indexMatched = ctx.freshName("eltIndexMatched") - val stringVal = ctx.addMutableState(ctx.javaType(dataType), "stringVal") + val inputVal = ctx.addMutableState(ctx.javaType(dataType), "inputVal") - val assignStringValue = strings.zipWithIndex.map { case (eval, index) => + val assignInputValue = inputs.zipWithIndex.map { case (eval, index) => s""" |if ($indexVal == ${index + 1}) { | ${eval.code} - | $stringVal = ${eval.isNull} ? null : ${eval.value}; + | $inputVal = ${eval.isNull} ? null : ${eval.value}; | $indexMatched = true; | continue; |} @@ -335,7 +347,7 @@ case class Elt(children: Seq[Expression]) } val codes = ctx.splitExpressionsWithCurrentInputs( - expressions = assignStringValue, + expressions = assignInputValue, funcName = "eltFunc", extraArguments = ("int", indexVal) :: Nil, returnType = ctx.JAVA_BOOLEAN, @@ -361,11 +373,11 @@ case class Elt(children: Seq[Expression]) |${index.code} |final int $indexVal = ${index.value}; |${ctx.JAVA_BOOLEAN} $indexMatched = false; - |$stringVal = null; + |$inputVal = null; |do { | $codes |} while (false); - |final UTF8String ${ev.value} = $stringVal; + |final ${ctx.javaType(dataType)} ${ev.value} = $inputVal; |final boolean ${ev.isNull} = ${ev.value} == null; """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5d6edf6b8abec..80b8965e084a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1052,6 +1052,12 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ELT_OUTPUT_AS_STRING = buildConf("spark.sql.function.eltOutputAsString") + .doc("When this option is set to false and all inputs are binary, `elt` returns " + + "an output as binary. Otherwise, it returns as a string. ") + .booleanConf + .createWithDefault(false) + val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE = buildConf("spark.sql.streaming.continuous.executorQueueSize") .internal() @@ -1412,6 +1418,8 @@ class SQLConf extends Serializable with Logging { def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) + def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) + def partitionOverwriteMode: PartitionOverwriteMode.Value = PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 3661530cd622b..52a7ebdafd7c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -923,6 +923,60 @@ class TypeCoercionSuite extends AnalysisTest { } } + test("type coercion for Elt") { + val rule = TypeCoercion.EltCoercion(conf) + + ruleTest(rule, + Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))), + Elt(Seq(Literal(1), Literal("ab"), Literal("cde")))) + ruleTest(rule, + Elt(Seq(Literal(1.toShort), Literal("ab"), Literal("cde"))), + Elt(Seq(Cast(Literal(1.toShort), IntegerType), Literal("ab"), Literal("cde")))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(null), Literal("abc"))), + Elt(Seq(Literal(2), Cast(Literal(null), StringType), Literal("abc")))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(1), Literal("234"))), + Elt(Seq(Literal(2), Cast(Literal(1), StringType), Literal("234")))) + ruleTest(rule, + Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1))), + Elt(Seq(Literal(3), Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType), + Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort))), + Elt(Seq(Literal(2), Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType), + Cast(Literal(3.toShort), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(1L), Literal(0.1))), + Elt(Seq(Literal(1), Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(Decimal(10)))), + Elt(Seq(Literal(1), Cast(Literal(Decimal(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10)))), + Elt(Seq(Literal(1), Cast(Literal(BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10)))), + Elt(Seq(Literal(1), Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))), + Elt(Seq(Literal(2), Cast(Literal(new java.sql.Date(0)), StringType), + Cast(Literal(new Timestamp(0)), StringType)))) + + withSQLConf("spark.sql.function.eltOutputAsString" -> "true") { + ruleTest(rule, + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), + Elt(Seq(Literal(1), Cast(Literal("123".getBytes), StringType), + Cast(Literal("456".getBytes), StringType)))) + } + + withSQLConf("spark.sql.function.eltOutputAsString" -> "false") { + ruleTest(rule, + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes)))) + } + } + test("BooleanEquality type cast") { val be = TypeCoercion.BooleanEquality // Use something more than a literal to avoid triggering the simplification rules. diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql new file mode 100644 index 0000000000000..717616f91db05 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql @@ -0,0 +1,44 @@ +-- Mixed inputs (output type is string) +SELECT elt(2, col1, col2, col3, col4, col5) col +FROM ( + SELECT + 'prefix_' col1, + id col2, + string(id + 1) col3, + encode(string(id + 2), 'utf-8') col4, + CAST(id AS DOUBLE) col5 + FROM range(10) +); + +SELECT elt(3, col1, col2, col3, col4) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +); + +-- turn on eltOutputAsString +set spark.sql.function.eltOutputAsString=true; + +SELECT elt(1, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +); + +-- turn off eltOutputAsString +set spark.sql.function.eltOutputAsString=false; + +-- Elt binary inputs (output type is binary) +SELECT elt(2, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +); diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out new file mode 100644 index 0000000000000..b62e1b6826045 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out @@ -0,0 +1,115 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +SELECT elt(2, col1, col2, col3, col4, col5) col +FROM ( + SELECT + 'prefix_' col1, + id col2, + string(id + 1) col3, + encode(string(id + 2), 'utf-8') col4, + CAST(id AS DOUBLE) col5 + FROM range(10) +) +-- !query 0 schema +struct +-- !query 0 output +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 1 +SELECT elt(3, col1, col2, col3, col4) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) +-- !query 1 schema +struct +-- !query 1 output +10 +11 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 2 +set spark.sql.function.eltOutputAsString=true +-- !query 2 schema +struct +-- !query 2 output +spark.sql.function.eltOutputAsString true + + +-- !query 3 +SELECT elt(1, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +) +-- !query 3 schema +struct +-- !query 3 output +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 4 +set spark.sql.function.eltOutputAsString=false +-- !query 4 schema +struct +-- !query 4 output +spark.sql.function.eltOutputAsString false + + +-- !query 5 +SELECT elt(2, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +) +-- !query 5 schema +struct +-- !query 5 output +1 +10 +2 +3 +4 +5 +6 +7 +8 +9 From bf65cd3cda46d5480bfcd13110975c46ca631972 Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Fri, 5 Jan 2018 17:29:27 -0800 Subject: [PATCH 0028/2461] [SPARK-22960][K8S] Revert use of ARG base_image in images ## What changes were proposed in this pull request? This PR reverts the `ARG base_image` before `FROM` in the images of driver, executor, and init-container, introduced in https://github.com/apache/spark/pull/20154. The reason is Docker versions before 17.06 do not support this use (`ARG` before `FROM`). ## How was this patch tested? Tested manually. vanzin foxish kimoonkim Author: Yinan Li Closes #20170 from liyinan926/master. --- .../docker/src/main/dockerfiles/driver/Dockerfile | 3 +-- .../docker/src/main/dockerfiles/executor/Dockerfile | 3 +-- .../docker/src/main/dockerfiles/init-container/Dockerfile | 3 +-- sbin/build-push-docker-images.sh | 8 ++++---- 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile index ff5289e10c21e..45fbcd9cd0deb 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile @@ -15,8 +15,7 @@ # limitations under the License. # -ARG base_image -FROM ${base_image} +FROM spark-base # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile index 3eabb42d4d852..0f806cf7e148e 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile @@ -15,8 +15,7 @@ # limitations under the License. # -ARG base_image -FROM ${base_image} +FROM spark-base # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile index e0a249e0ac71f..047056ab2633b 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile @@ -15,8 +15,7 @@ # limitations under the License. # -ARG base_image -FROM ${base_image} +FROM spark-base # If this docker file is being used in the context of building your images from a Spark distribution, the docker build # command should be invoked from the top level directory of the Spark distribution. E.g.: diff --git a/sbin/build-push-docker-images.sh b/sbin/build-push-docker-images.sh index bb8806dd33f37..b9532597419a5 100755 --- a/sbin/build-push-docker-images.sh +++ b/sbin/build-push-docker-images.sh @@ -60,13 +60,13 @@ function image_ref { } function build { - local base_image="$(image_ref spark-base 0)" - docker build --build-arg "spark_jars=$SPARK_JARS" \ + docker build \ + --build-arg "spark_jars=$SPARK_JARS" \ --build-arg "img_path=$IMG_PATH" \ - -t "$base_image" \ + -t spark-base \ -f "$IMG_PATH/spark-base/Dockerfile" . for image in "${!path[@]}"; do - docker build --build-arg "base_image=$base_image" -t "$(image_ref $image)" -f ${path[$image]} . + docker build -t "$(image_ref $image)" -f ${path[$image]} . done } From f2dd8b923759e8771b0e5f59bfa7ae4ad7e6a339 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Sat, 6 Jan 2018 16:11:20 +0800 Subject: [PATCH 0029/2461] [SPARK-22930][PYTHON][SQL] Improve the description of Vectorized UDFs for non-deterministic cases ## What changes were proposed in this pull request? Add tests for using non deterministic UDFs in aggregate. Update pandas_udf docstring w.r.t to determinism. ## How was this patch tested? test_nondeterministic_udf_in_aggregate Author: Li Jin Closes #20142 from icexelloss/SPARK-22930-pandas-udf-deterministic. --- python/pyspark/sql/functions.py | 12 +++++++- python/pyspark/sql/tests.py | 52 +++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a4ed562ad48b4..733e32bd825b0 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2214,7 +2214,17 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. seealso:: :meth:`pyspark.sql.GroupedData.apply` - .. note:: The user-defined function must be deterministic. + .. note:: The user-defined functions are considered deterministic by default. Due to + optimization, duplicate invocations may be eliminated or the function may even be invoked + more times than it is present in the query. If your function is not deterministic, call + `asNondeterministic` on the user defined function. E.g.: + + >>> @pandas_udf('double', PandasUDFType.SCALAR) # doctest: +SKIP + ... def random(v): + ... import numpy as np + ... import pandas as pd + ... return pd.Series(np.random.randn(len(v)) + >>> random = random.asNondeterministic() # doctest: +SKIP .. note:: The user-defined functions do not support conditional expressions or short curcuiting in boolean expressions and it ends up with being executed all internally. If the functions diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6dc767f9ec46e..689736d8e6456 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -386,6 +386,7 @@ def test_udf3(self): self.assertEqual(row[0], 5) def test_nondeterministic_udf(self): + # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations from pyspark.sql.functions import udf import random udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() @@ -413,6 +414,18 @@ def test_nondeterministic_udf2(self): pydoc.render_doc(random_udf) pydoc.render_doc(random_udf1) + def test_nondeterministic_udf_in_aggregate(self): + from pyspark.sql.functions import udf, sum + import random + udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic() + df = self.spark.range(10) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): + df.groupby('id').agg(sum(udf_random_col())).collect() + with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): + df.agg(sum(udf_random_col())).collect() + def test_chained_udf(self): self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType()) [row] = self.spark.sql("SELECT double(1)").collect() @@ -3567,6 +3580,18 @@ def tearDownClass(cls): time.tzset() ReusedSQLTestCase.tearDownClass() + @property + def random_udf(self): + from pyspark.sql.functions import pandas_udf + + @pandas_udf('double') + def random_udf(v): + import pandas as pd + import numpy as np + return pd.Series(np.random.random(len(v))) + random_udf = random_udf.asNondeterministic() + return random_udf + def test_vectorized_udf_basic(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10).select( @@ -3950,6 +3975,33 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self): finally: self.spark.conf.set("spark.sql.session.timeZone", orig_tz) + def test_nondeterministic_udf(self): + # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations + from pyspark.sql.functions import udf, pandas_udf, col + + @pandas_udf('double') + def plus_ten(v): + return v + 10 + random_udf = self.random_udf + + df = self.spark.range(10).withColumn('rand', random_udf(col('id'))) + result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas() + + self.assertEqual(random_udf.deterministic, False) + self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10)) + + def test_nondeterministic_udf_in_aggregate(self): + from pyspark.sql.functions import pandas_udf, sum + + df = self.spark.range(10) + random_udf = self.random_udf + + with QuietTest(self.sc): + with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): + df.groupby(df.id).agg(sum(random_udf(df.id))).collect() + with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): + df.agg(sum(random_udf(df.id))).collect() + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedSQLTestCase): From be9a804f2ef77a5044d3da7d9374976daf59fc16 Mon Sep 17 00:00:00 2001 From: zuotingbing Date: Sat, 6 Jan 2018 18:07:45 +0800 Subject: [PATCH 0030/2461] [SPARK-22793][SQL] Memory leak in Spark Thrift Server # What changes were proposed in this pull request? 1. Start HiveThriftServer2. 2. Connect to thriftserver through beeline. 3. Close the beeline. 4. repeat step2 and step 3 for many times. we found there are many directories never be dropped under the path `hive.exec.local.scratchdir` and `hive.exec.scratchdir`, as we know the scratchdir has been added to deleteOnExit when it be created. So it means that the cache size of FileSystem `deleteOnExit` will keep increasing until JVM terminated. In addition, we use `jmap -histo:live [PID]` to printout the size of objects in HiveThriftServer2 Process, we can find the object `org.apache.spark.sql.hive.client.HiveClientImpl` and `org.apache.hadoop.hive.ql.session.SessionState` keep increasing even though we closed all the beeline connections, which may caused the leak of Memory. # How was this patch tested? manual tests This PR follw-up the https://github.com/apache/spark/pull/19989 Author: zuotingbing Closes #20029 from zuotingbing/SPARK-22793. --- .../org/apache/spark/sql/hive/HiveSessionStateBuilder.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 92cb4ef11c9e3..dc92ad3b0c1ac 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -42,7 +42,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session * Create a Hive aware resource loader. */ override protected lazy val resourceLoader: HiveSessionResourceLoader = { - val client: HiveClient = externalCatalog.client.newSession() + val client: HiveClient = externalCatalog.client new HiveSessionResourceLoader(session, client) } From 7b78041423b6ee330def2336dfd1ff9ae8469c59 Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Sat, 6 Jan 2018 18:19:57 +0800 Subject: [PATCH 0031/2461] [SPARK-21786][SQL] When acquiring 'compressionCodecClassName' in 'ParquetOptions', `parquet.compression` needs to be considered. [SPARK-21786][SQL] When acquiring 'compressionCodecClassName' in 'ParquetOptions', `parquet.compression` needs to be considered. ## What changes were proposed in this pull request? Since Hive 1.1, Hive allows users to set parquet compression codec via table-level properties parquet.compression. See the JIRA: https://issues.apache.org/jira/browse/HIVE-7858 . We do support orc.compression for ORC. Thus, for external users, it is more straightforward to support both. See the stackflow question: https://stackoverflow.com/questions/36941122/spark-sql-ignores-parquet-compression-propertie-specified-in-tblproperties In Spark side, our table-level compression conf compression was added by #11464 since Spark 2.0. We need to support both table-level conf. Users might also use session-level conf spark.sql.parquet.compression.codec. The priority rule will be like If other compression codec configuration was found through hive or parquet, the precedence would be compression, parquet.compression, spark.sql.parquet.compression.codec. Acceptable values include: none, uncompressed, snappy, gzip, lzo. The rule for Parquet is consistent with the ORC after the change. Changes: 1.Increased acquiring 'compressionCodecClassName' from `parquet.compression`,and the precedence order is `compression`,`parquet.compression`,`spark.sql.parquet.compression.codec`, just like what we do in `OrcOptions`. 2.Change `spark.sql.parquet.compression.codec` to support "none".Actually in `ParquetOptions`,we do support "none" as equivalent to "uncompressed", but it does not allowed to configured to "none". 3.Change `compressionCode` to `compressionCodecClassName`. ## How was this patch tested? Add test. Author: fjh100456 Closes #20076 from fjh100456/ParquetOptionIssue. --- docs/sql-programming-guide.md | 6 +- .../apache/spark/sql/internal/SQLConf.scala | 14 +- .../datasources/parquet/ParquetOptions.scala | 12 +- ...rquetCompressionCodecPrecedenceSuite.scala | 122 ++++++++++++++++++ 4 files changed, 145 insertions(+), 9 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index b50f9360b866c..3ccaaf4d5b1fa 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -953,8 +953,10 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession spark.sql.parquet.compression.codec snappy - Sets the compression codec use when writing Parquet files. Acceptable values include: - uncompressed, snappy, gzip, lzo. + Sets the compression codec used when writing Parquet files. If either `compression` or + `parquet.compression` is specified in the table-specific options/properties, the precedence would be + `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`. Acceptable values include: + none, uncompressed, snappy, gzip, lzo. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 80b8965e084a2..7d1217de254a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -325,11 +325,13 @@ object SQLConf { .createWithDefault(false) val PARQUET_COMPRESSION = buildConf("spark.sql.parquet.compression.codec") - .doc("Sets the compression codec use when writing Parquet files. Acceptable values include: " + - "uncompressed, snappy, gzip, lzo.") + .doc("Sets the compression codec used when writing Parquet files. If either `compression` or" + + "`parquet.compression` is specified in the table-specific options/properties, the precedence" + + "would be `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`." + + "Acceptable values include: none, uncompressed, snappy, gzip, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) - .checkValues(Set("uncompressed", "snappy", "gzip", "lzo")) + .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo")) .createWithDefault("snappy") val PARQUET_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.filterPushdown") @@ -366,8 +368,10 @@ object SQLConf { .createWithDefault(true) val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec") - .doc("Sets the compression codec use when writing ORC files. Acceptable values include: " + - "none, uncompressed, snappy, zlib, lzo.") + .doc("Sets the compression codec used when writing ORC files. If either `compression` or" + + "`orc.compress` is specified in the table-specific options/properties, the precedence" + + "would be `compression`, `orc.compress`, `spark.sql.orc.compression.codec`." + + "Acceptable values include: none, uncompressed, snappy, zlib, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) .checkValues(Set("none", "uncompressed", "snappy", "zlib", "lzo")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index 772d4565de548..ef67ea7d17cea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.util.Locale +import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -42,8 +43,15 @@ private[parquet] class ParquetOptions( * Acceptable values are defined in [[shortParquetCompressionCodecNames]]. */ val compressionCodecClassName: String = { - val codecName = parameters.getOrElse("compression", - sqlConf.parquetCompressionCodec).toLowerCase(Locale.ROOT) + // `compression`, `parquet.compression`(i.e., ParquetOutputFormat.COMPRESSION), and + // `spark.sql.parquet.compression.codec` + // are in order of precedence from highest to lowest. + val parquetCompressionConf = parameters.get(ParquetOutputFormat.COMPRESSION) + val codecName = parameters + .get("compression") + .orElse(parquetCompressionConf) + .getOrElse(sqlConf.parquetCompressionCodec) + .toLowerCase(Locale.ROOT) if (!shortParquetCompressionCodecNames.contains(codecName)) { val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala new file mode 100644 index 0000000000000..ed8fd2b453456 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetOutputFormat + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetCompressionCodecPrecedenceSuite extends ParquetTest with SharedSQLContext { + test("Test `spark.sql.parquet.compression.codec` config") { + Seq("NONE", "UNCOMPRESSED", "SNAPPY", "GZIP", "LZO").foreach { c => + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> c) { + val expected = if (c == "NONE") "UNCOMPRESSED" else c + val option = new ParquetOptions(Map.empty[String, String], spark.sessionState.conf) + assert(option.compressionCodecClassName == expected) + } + } + } + + test("[SPARK-21786] Test Acquiring 'compressionCodecClassName' for parquet in right order.") { + // When "compression" is configured, it should be the first choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map("compression" -> "uncompressed", ParquetOutputFormat.COMPRESSION -> "gzip") + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "UNCOMPRESSED") + } + + // When "compression" is not configured, "parquet.compression" should be the preferred choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map(ParquetOutputFormat.COMPRESSION -> "gzip") + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "GZIP") + } + + // When both "compression" and "parquet.compression" are not configured, + // spark.sql.parquet.compression.codec should be the right choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map.empty[String, String] + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "SNAPPY") + } + } + + private def getTableCompressionCodec(path: String): Seq[String] = { + val hadoopConf = spark.sessionState.newHadoopConf() + val codecs = for { + footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConf) + block <- footer.getParquetMetadata.getBlocks.asScala + column <- block.getColumns.asScala + } yield column.getCodec.name() + codecs.distinct + } + + private def createTableWithCompression( + tableName: String, + isPartitioned: Boolean, + compressionCodec: String, + rootDir: File): Unit = { + val options = + s""" + |OPTIONS('path'='${rootDir.toURI.toString.stripSuffix("/")}/$tableName', + |'parquet.compression'='$compressionCodec') + """.stripMargin + val partitionCreate = if (isPartitioned) "PARTITIONED BY (p)" else "" + sql( + s""" + |CREATE TABLE $tableName USING Parquet $options $partitionCreate + |AS SELECT 1 AS col1, 2 AS p + """.stripMargin) + } + + private def checkCompressionCodec(compressionCodec: String, isPartitioned: Boolean): Unit = { + withTempDir { tmpDir => + val tempTableName = "TempParquetTable" + withTable(tempTableName) { + createTableWithCompression(tempTableName, isPartitioned, compressionCodec, tmpDir) + val partitionPath = if (isPartitioned) "p=2" else "" + val path = s"${tmpDir.getPath.stripSuffix("/")}/$tempTableName/$partitionPath" + val realCompressionCodecs = getTableCompressionCodec(path) + assert(realCompressionCodecs.forall(_ == compressionCodec)) + } + } + } + + test("Create parquet table with compression") { + Seq(true, false).foreach { isPartitioned => + Seq("UNCOMPRESSED", "SNAPPY", "GZIP").foreach { compressionCodec => + checkCompressionCodec(compressionCodec, isPartitioned) + } + } + } + + test("Create table with unknown compression") { + Seq(true, false).foreach { isPartitioned => + val exception = intercept[IllegalArgumentException] { + checkCompressionCodec("aa", isPartitioned) + } + assert(exception.getMessage.contains("Codec [aa] is not available")) + } + } +} From 993f21567a1dd33e43ef9a626e0ddfbe46f83f93 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 6 Jan 2018 23:08:26 +0800 Subject: [PATCH 0032/2461] [SPARK-22901][PYTHON][FOLLOWUP] Adds the doc for asNondeterministic for wrapped UDF function ## What changes were proposed in this pull request? This PR wraps the `asNondeterministic` attribute in the wrapped UDF function to set the docstring properly. ```python from pyspark.sql.functions import udf help(udf(lambda x: x).asNondeterministic) ``` Before: ``` Help on function in module pyspark.sql.udf: lambda (END ``` After: ``` Help on function asNondeterministic in module pyspark.sql.udf: asNondeterministic() Updates UserDefinedFunction to nondeterministic. .. versionadded:: 2.3 (END) ``` ## How was this patch tested? Manually tested and a simple test was added. Author: hyukjinkwon Closes #20173 from HyukjinKwon/SPARK-22901-followup. --- python/pyspark/sql/tests.py | 1 + python/pyspark/sql/udf.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 689736d8e6456..122a65b83aef9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -413,6 +413,7 @@ def test_nondeterministic_udf2(self): pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType())) pydoc.render_doc(random_udf) pydoc.render_doc(random_udf1) + pydoc.render_doc(udf(lambda x: x).asNondeterministic) def test_nondeterministic_udf_in_aggregate(self): from pyspark.sql.functions import udf, sum diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 5e75eb6545333..5e80ab9165867 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -169,8 +169,8 @@ def wrapper(*args): wrapper.returnType = self.returnType wrapper.evalType = self.evalType wrapper.deterministic = self.deterministic - wrapper.asNondeterministic = lambda: self.asNondeterministic()._wrapped() - + wrapper.asNondeterministic = functools.wraps( + self.asNondeterministic)(lambda: self.asNondeterministic()._wrapped()) return wrapper def asNondeterministic(self): From 9a7048b2889bd0fd66e68a0ce3e07e466315a051 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 7 Jan 2018 00:19:21 +0800 Subject: [PATCH 0033/2461] [HOTFIX] Fix style checking failure ## What changes were proposed in this pull request? This PR is to fix the style checking failure. ## How was this patch tested? N/A Author: gatorsmile Closes #20175 from gatorsmile/stylefix. --- .../org/apache/spark/sql/internal/SQLConf.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7d1217de254a2..5c61f10bb71ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -325,10 +325,11 @@ object SQLConf { .createWithDefault(false) val PARQUET_COMPRESSION = buildConf("spark.sql.parquet.compression.codec") - .doc("Sets the compression codec used when writing Parquet files. If either `compression` or" + - "`parquet.compression` is specified in the table-specific options/properties, the precedence" + - "would be `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`." + - "Acceptable values include: none, uncompressed, snappy, gzip, lzo.") + .doc("Sets the compression codec used when writing Parquet files. If either `compression` or " + + "`parquet.compression` is specified in the table-specific options/properties, the " + + "precedence would be `compression`, `parquet.compression`, " + + "`spark.sql.parquet.compression.codec`. Acceptable values include: none, uncompressed, " + + "snappy, gzip, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo")) @@ -368,8 +369,8 @@ object SQLConf { .createWithDefault(true) val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec") - .doc("Sets the compression codec used when writing ORC files. If either `compression` or" + - "`orc.compress` is specified in the table-specific options/properties, the precedence" + + .doc("Sets the compression codec used when writing ORC files. If either `compression` or " + + "`orc.compress` is specified in the table-specific options/properties, the precedence " + "would be `compression`, `orc.compress`, `spark.sql.orc.compression.codec`." + "Acceptable values include: none, uncompressed, snappy, zlib, lzo.") .stringConf From 18e94149992618a2b4e6f0fd3b3f4594e1745224 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 7 Jan 2018 13:42:01 +0800 Subject: [PATCH 0034/2461] [SPARK-22973][SQL] Fix incorrect results of Casting Map to String ## What changes were proposed in this pull request? This pr fixed the issue when casting maps into strings; ``` scala> Seq(Map(1 -> "a", 2 -> "b")).toDF("a").write.saveAsTable("t") scala> sql("SELECT cast(a as String) FROM t").show(false) +----------------------------------------------------------------+ |a | +----------------------------------------------------------------+ |org.apache.spark.sql.catalyst.expressions.UnsafeMapData38bdd75d| +----------------------------------------------------------------+ ``` This pr modified the result into; ``` +----------------+ |a | +----------------+ |[1 -> a, 2 -> b]| +----------------+ ``` ## How was this patch tested? Added tests in `CastSuite`. Author: Takeshi Yamamuro Closes #20166 from maropu/SPARK-22973. --- .../spark/sql/catalyst/expressions/Cast.scala | 89 +++++++++++++++++++ .../sql/catalyst/expressions/CastSuite.scala | 28 ++++++ 2 files changed, 117 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d4fc5e0f168a7..f2de4c8e30bec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -228,6 +228,37 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String builder.append("]") builder.build() }) + case MapType(kt, vt, _) => + buildCast[MapData](_, map => { + val builder = new UTF8StringBuilder + builder.append("[") + if (map.numElements > 0) { + val keyArray = map.keyArray() + val valueArray = map.valueArray() + val keyToUTF8String = castToString(kt) + val valueToUTF8String = castToString(vt) + builder.append(keyToUTF8String(keyArray.get(0, kt)).asInstanceOf[UTF8String]) + builder.append(" ->") + if (!valueArray.isNullAt(0)) { + builder.append(" ") + builder.append(valueToUTF8String(valueArray.get(0, vt)).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < map.numElements) { + builder.append(", ") + builder.append(keyToUTF8String(keyArray.get(i, kt)).asInstanceOf[UTF8String]) + builder.append(" ->") + if (!valueArray.isNullAt(i)) { + builder.append(" ") + builder.append(valueToUTF8String(valueArray.get(i, vt)) + .asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append("]") + builder.build() + }) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -654,6 +685,53 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """.stripMargin } + private def writeMapToStringBuilder( + kt: DataType, + vt: DataType, + map: String, + buffer: String, + ctx: CodegenContext): String = { + + def dataToStringFunc(func: String, dataType: DataType) = { + val funcName = ctx.freshName(func) + val dataToStringCode = castToStringCode(dataType, ctx) + ctx.addNewFunction(funcName, + s""" + |private UTF8String $funcName(${ctx.javaType(dataType)} data) { + | UTF8String dataStr = null; + | ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)} + | return dataStr; + |} + """.stripMargin) + } + + val keyToStringFunc = dataToStringFunc("keyToString", kt) + val valueToStringFunc = dataToStringFunc("valueToString", vt) + val loopIndex = ctx.freshName("loopIndex") + s""" + |$buffer.append("["); + |if ($map.numElements() > 0) { + | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, "0")})); + | $buffer.append(" ->"); + | if (!$map.valueArray().isNullAt(0)) { + | $buffer.append(" "); + | $buffer.append($valueToStringFunc(${ctx.getValue(s"$map.valueArray()", vt, "0")})); + | } + | for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) { + | $buffer.append(", "); + | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, loopIndex)})); + | $buffer.append(" ->"); + | if (!$map.valueArray().isNullAt($loopIndex)) { + | $buffer.append(" "); + | $buffer.append($valueToStringFunc( + | ${ctx.getValue(s"$map.valueArray()", vt, loopIndex)})); + | } + | } + |} + |$buffer.append("]"); + """.stripMargin + } + private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => @@ -676,6 +754,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$evPrim = $buffer.build(); """.stripMargin } + case MapType(kt, vt, _) => + (c, evPrim, evNull) => { + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val writeMapElemCode = writeMapToStringBuilder(kt, vt, c, buffer, ctx) + s""" + |$bufferClass $buffer = new $bufferClass(); + |$writeMapElemCode; + |$evPrim = $buffer.build(); + """.stripMargin + } case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index e3ed7171defd8..1445bb8a97d40 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -878,4 +878,32 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StringType) checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]") } + + test("SPARK-22973 Cast map to string") { + val ret1 = cast(Literal.create(Map(1 -> "a", 2 -> "b", 3 -> "c")), StringType) + checkEvaluation(ret1, "[1 -> a, 2 -> b, 3 -> c]") + val ret2 = cast( + Literal.create(Map("1" -> "a".getBytes, "2" -> null, "3" -> "c".getBytes)), + StringType) + checkEvaluation(ret2, "[1 -> a, 2 ->, 3 -> c]") + val ret3 = cast( + Literal.create(Map( + 1 -> Date.valueOf("2014-12-03"), + 2 -> Date.valueOf("2014-12-04"), + 3 -> Date.valueOf("2014-12-05"))), + StringType) + checkEvaluation(ret3, "[1 -> 2014-12-03, 2 -> 2014-12-04, 3 -> 2014-12-05]") + val ret4 = cast( + Literal.create(Map( + 1 -> Timestamp.valueOf("2014-12-03 13:01:00"), + 2 -> Timestamp.valueOf("2014-12-04 15:05:00"))), + StringType) + checkEvaluation(ret4, "[1 -> 2014-12-03 13:01:00, 2 -> 2014-12-04 15:05:00]") + val ret5 = cast( + Literal.create(Map( + 1 -> Array(1, 2, 3), + 2 -> Array(4, 5, 6))), + StringType) + checkEvaluation(ret5, "[1 -> [1, 2, 3], 2 -> [4, 5, 6]]") + } } From 71d65a32158a55285be197bec4e41fedc9225b94 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 8 Jan 2018 11:39:45 +0800 Subject: [PATCH 0035/2461] [SPARK-22985] Fix argument escaping bug in from_utc_timestamp / to_utc_timestamp codegen ## What changes were proposed in this pull request? This patch adds additional escaping in `from_utc_timestamp` / `to_utc_timestamp` expression codegen in order to a bug where invalid timezones which contain special characters could cause generated code to fail to compile. ## How was this patch tested? New regression tests in `DateExpressionsSuite`. Author: Josh Rosen Closes #20182 from JoshRosen/SPARK-22985-fix-utc-timezone-function-escaping-bugs. --- .../catalyst/expressions/datetimeExpressions.scala | 12 ++++++++---- .../catalyst/expressions/DateExpressionsSuite.scala | 6 ++++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 7a674ea7f4d76..424871f2047e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -23,6 +23,8 @@ import java.util.{Calendar, TimeZone} import scala.util.control.NonFatal +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -1008,7 +1010,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (right.foldable) { - val tz = right.eval() + val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { ev.copy(code = s""" |boolean ${ev.isNull} = true; @@ -1017,8 +1019,9 @@ case class FromUTCTimestamp(left: Expression, right: Expression) } else { val tzClass = classOf[TimeZone].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val escapedTz = StringEscapeUtils.escapeJava(tz.toString) val tzTerm = ctx.addMutableState(tzClass, "tz", - v => s"""$v = $dtu.getTimeZone("$tz");""") + v => s"""$v = $dtu.getTimeZone("$escapedTz");""") val utcTerm = "tzUTC" ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") @@ -1185,7 +1188,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (right.foldable) { - val tz = right.eval() + val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { ev.copy(code = s""" |boolean ${ev.isNull} = true; @@ -1194,8 +1197,9 @@ case class ToUTCTimestamp(left: Expression, right: Expression) } else { val tzClass = classOf[TimeZone].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val escapedTz = StringEscapeUtils.escapeJava(tz.toString) val tzTerm = ctx.addMutableState(tzClass, "tz", - v => s"""$v = $dtu.getTimeZone("$tz");""") + v => s"""$v = $dtu.getTimeZone("$escapedTz");""") val utcTerm = "tzUTC" ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 63f6ceeb21b96..786266a2c13c0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -22,6 +22,7 @@ import java.text.SimpleDateFormat import java.util.{Calendar, Locale, TimeZone} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT @@ -791,6 +792,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test(null, "UTC", null) test("2015-07-24 00:00:00", null, null) test(null, null, null) + // Test escaping of timezone + GenerateUnsafeProjection.generate( + ToUTCTimestamp(Literal(Timestamp.valueOf("2015-07-24 00:00:00")), Literal("\"quote")) :: Nil) } test("from_utc_timestamp") { @@ -811,5 +815,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test(null, "UTC", null) test("2015-07-24 00:00:00", null, null) test(null, null, null) + // Test escaping of timezone + GenerateUnsafeProjection.generate(FromUTCTimestamp(Literal(0), Literal("\"quote")) :: Nil) } } From 3e40eb3f1ffac3d2f49459a801e3ce171ed34091 Mon Sep 17 00:00:00 2001 From: Guilherme Berger Date: Mon, 8 Jan 2018 14:32:05 +0900 Subject: [PATCH 0036/2461] [SPARK-22566][PYTHON] Better error message for `_merge_type` in Pandas to Spark DF conversion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? It provides a better error message when doing `spark_session.createDataFrame(pandas_df)` with no schema and an error occurs in the schema inference due to incompatible types. The Pandas column names are propagated down and the error message mentions which column had the merging error. https://issues.apache.org/jira/browse/SPARK-22566 ## How was this patch tested? Manually in the `./bin/pyspark` console, and with new tests: `./python/run-tests` screen shot 2017-11-21 at 13 29 49 I state that the contribution is my original work and that I license the work to the Apache Spark project under the project’s open source license. Author: Guilherme Berger Closes #19792 from gberger/master. --- python/pyspark/sql/session.py | 17 +++--- python/pyspark/sql/tests.py | 100 ++++++++++++++++++++++++++++++++++ python/pyspark/sql/types.py | 28 +++++++--- 3 files changed, 129 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 6e5eec48e8aca..6052fa9e84096 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -325,11 +325,12 @@ def range(self, start, end=None, step=1, numPartitions=None): return DataFrame(jdf, self._wrapped) - def _inferSchemaFromList(self, data): + def _inferSchemaFromList(self, data, names=None): """ Infer schema from list of Row or tuple. :param data: list of Row or tuple + :param names: list of column names :return: :class:`pyspark.sql.types.StructType` """ if not data: @@ -338,12 +339,12 @@ def _inferSchemaFromList(self, data): if type(first) is dict: warnings.warn("inferring schema from dict is deprecated," "please use pyspark.sql.Row instead") - schema = reduce(_merge_type, map(_infer_schema, data)) + schema = reduce(_merge_type, (_infer_schema(row, names) for row in data)) if _has_nulltype(schema): raise ValueError("Some of types cannot be determined after inferring") return schema - def _inferSchema(self, rdd, samplingRatio=None): + def _inferSchema(self, rdd, samplingRatio=None, names=None): """ Infer schema from an RDD of Row or tuple. @@ -360,10 +361,10 @@ def _inferSchema(self, rdd, samplingRatio=None): "Use pyspark.sql.Row instead") if samplingRatio is None: - schema = _infer_schema(first) + schema = _infer_schema(first, names=names) if _has_nulltype(schema): for row in rdd.take(100)[1:]: - schema = _merge_type(schema, _infer_schema(row)) + schema = _merge_type(schema, _infer_schema(row, names=names)) if not _has_nulltype(schema): break else: @@ -372,7 +373,7 @@ def _inferSchema(self, rdd, samplingRatio=None): else: if samplingRatio < 0.99: rdd = rdd.sample(False, float(samplingRatio)) - schema = rdd.map(_infer_schema).reduce(_merge_type) + schema = rdd.map(lambda row: _infer_schema(row, names)).reduce(_merge_type) return schema def _createFromRDD(self, rdd, schema, samplingRatio): @@ -380,7 +381,7 @@ def _createFromRDD(self, rdd, schema, samplingRatio): Create an RDD for DataFrame from an existing RDD, returns the RDD and schema. """ if schema is None or isinstance(schema, (list, tuple)): - struct = self._inferSchema(rdd, samplingRatio) + struct = self._inferSchema(rdd, samplingRatio, names=schema) converter = _create_converter(struct) rdd = rdd.map(converter) if isinstance(schema, (list, tuple)): @@ -406,7 +407,7 @@ def _createFromLocal(self, data, schema): data = list(data) if schema is None or isinstance(schema, (list, tuple)): - struct = self._inferSchemaFromList(data) + struct = self._inferSchemaFromList(data, names=schema) converter = _create_converter(struct) data = map(converter, data) if isinstance(schema, (list, tuple)): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 122a65b83aef9..13576ff57001b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -68,6 +68,7 @@ from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings +from pyspark.sql.types import _merge_type from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window @@ -898,6 +899,15 @@ def test_infer_schema(self): result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'") self.assertEqual(1, result.head()[0]) + def test_infer_schema_not_enough_names(self): + df = self.spark.createDataFrame([["a", "b"]], ["col1"]) + self.assertEqual(df.columns, ['col1', '_2']) + + def test_infer_schema_fails(self): + with self.assertRaisesRegexp(TypeError, 'field a'): + self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]), + schema=["a", "b"], samplingRatio=0.99) + def test_infer_nested_schema(self): NestedRow = Row("f1", "f2") nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}), @@ -918,6 +928,10 @@ def test_infer_nested_schema(self): df = self.spark.createDataFrame(rdd) self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) + def test_create_dataframe_from_dict_respects_schema(self): + df = self.spark.createDataFrame([{'a': 1}], ["b"]) + self.assertEqual(df.columns, ['b']) + def test_create_dataframe_from_objects(self): data = [MyObject(1, "1"), MyObject(2, "2")] df = self.spark.createDataFrame(data) @@ -1772,6 +1786,92 @@ def test_infer_long_type(self): self.assertEqual(_infer_type(2**61), LongType()) self.assertEqual(_infer_type(2**71), LongType()) + def test_merge_type(self): + self.assertEqual(_merge_type(LongType(), NullType()), LongType()) + self.assertEqual(_merge_type(NullType(), LongType()), LongType()) + + self.assertEqual(_merge_type(LongType(), LongType()), LongType()) + + self.assertEqual(_merge_type( + ArrayType(LongType()), + ArrayType(LongType()) + ), ArrayType(LongType())) + with self.assertRaisesRegexp(TypeError, 'element in array'): + _merge_type(ArrayType(LongType()), ArrayType(DoubleType())) + + self.assertEqual(_merge_type( + MapType(StringType(), LongType()), + MapType(StringType(), LongType()) + ), MapType(StringType(), LongType())) + with self.assertRaisesRegexp(TypeError, 'key of map'): + _merge_type( + MapType(StringType(), LongType()), + MapType(DoubleType(), LongType())) + with self.assertRaisesRegexp(TypeError, 'value of map'): + _merge_type( + MapType(StringType(), LongType()), + MapType(StringType(), DoubleType())) + + self.assertEqual(_merge_type( + StructType([StructField("f1", LongType()), StructField("f2", StringType())]), + StructType([StructField("f1", LongType()), StructField("f2", StringType())]) + ), StructType([StructField("f1", LongType()), StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, 'field f1'): + _merge_type( + StructType([StructField("f1", LongType()), StructField("f2", StringType())]), + StructType([StructField("f1", DoubleType()), StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]), + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]) + ), StructType([StructField("f1", StructType([StructField("f2", LongType())]))])) + with self.assertRaisesRegexp(TypeError, 'field f2 in field f1'): + _merge_type( + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]), + StructType([StructField("f1", StructType([StructField("f2", StringType())]))])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]), + StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]) + ), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, 'element in array field f1'): + _merge_type( + StructType([ + StructField("f1", ArrayType(LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", ArrayType(DoubleType())), + StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]) + ), StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, 'value of map field f1'): + _merge_type( + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", MapType(StringType(), DoubleType())), + StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]) + ), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])) + with self.assertRaisesRegexp(TypeError, 'key of map element in array field f1'): + _merge_type( + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), + StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))]) + ) + def test_filter_with_datetime(self): time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000) date = time.date() diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 146e673ae9756..0dc5823f72a3c 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1073,7 +1073,7 @@ def _infer_type(obj): raise TypeError("not supported type: %s" % type(obj)) -def _infer_schema(row): +def _infer_schema(row, names=None): """Infer the schema from dict/namedtuple/object""" if isinstance(row, dict): items = sorted(row.items()) @@ -1084,7 +1084,10 @@ def _infer_schema(row): elif hasattr(row, "_fields"): # namedtuple items = zip(row._fields, tuple(row)) else: - names = ['_%d' % i for i in range(1, len(row) + 1)] + if names is None: + names = ['_%d' % i for i in range(1, len(row) + 1)] + elif len(names) < len(row): + names.extend('_%d' % i for i in range(len(names) + 1, len(row) + 1)) items = zip(names, row) elif hasattr(row, "__dict__"): # object @@ -1109,19 +1112,27 @@ def _has_nulltype(dt): return isinstance(dt, NullType) -def _merge_type(a, b): +def _merge_type(a, b, name=None): + if name is None: + new_msg = lambda msg: msg + new_name = lambda n: "field %s" % n + else: + new_msg = lambda msg: "%s: %s" % (name, msg) + new_name = lambda n: "field %s in %s" % (n, name) + if isinstance(a, NullType): return b elif isinstance(b, NullType): return a elif type(a) is not type(b): # TODO: type cast (such as int -> long) - raise TypeError("Can not merge type %s and %s" % (type(a), type(b))) + raise TypeError(new_msg("Can not merge type %s and %s" % (type(a), type(b)))) # same type if isinstance(a, StructType): nfs = dict((f.name, f.dataType) for f in b.fields) - fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) + fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()), + name=new_name(f.name))) for f in a.fields] names = set([f.name for f in fields]) for n in nfs: @@ -1130,11 +1141,12 @@ def _merge_type(a, b): return StructType(fields) elif isinstance(a, ArrayType): - return ArrayType(_merge_type(a.elementType, b.elementType), True) + return ArrayType(_merge_type(a.elementType, b.elementType, + name='element in array %s' % name), True) elif isinstance(a, MapType): - return MapType(_merge_type(a.keyType, b.keyType), - _merge_type(a.valueType, b.valueType), + return MapType(_merge_type(a.keyType, b.keyType, name='key of map %s' % name), + _merge_type(a.valueType, b.valueType, name='value of map %s' % name), True) else: return a From 8fdeb4b9946bd9be045abb919da2e531708b3bd4 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 8 Jan 2018 13:59:08 +0800 Subject: [PATCH 0037/2461] [SPARK-22979][PYTHON][SQL] Avoid per-record type dispatch in Python data conversion (EvaluatePython.fromJava) ## What changes were proposed in this pull request? Seems we can avoid type dispatch for each value when Java objection (from Pyrolite) -> Spark's internal data format because we know the schema ahead. I manually performed the benchmark as below: ```scala test("EvaluatePython.fromJava / EvaluatePython.makeFromJava") { val numRows = 1000 * 1000 val numFields = 30 val random = new Random(System.nanoTime()) val types = Array( BooleanType, ByteType, FloatType, DoubleType, IntegerType, LongType, ShortType, DecimalType.ShortDecimal, DecimalType.IntDecimal, DecimalType.ByteDecimal, DecimalType.FloatDecimal, DecimalType.LongDecimal, new DecimalType(5, 2), new DecimalType(12, 2), new DecimalType(30, 10), CalendarIntervalType) val schema = RandomDataGenerator.randomSchema(random, numFields, types) val rows = mutable.ArrayBuffer.empty[Array[Any]] var i = 0 while (i < numRows) { val row = RandomDataGenerator.randomRow(random, schema) rows += row.toSeq.toArray i += 1 } val benchmark = new Benchmark("EvaluatePython.fromJava / EvaluatePython.makeFromJava", numRows) benchmark.addCase("Before - EvaluatePython.fromJava", 3) { _ => var i = 0 while (i < numRows) { EvaluatePython.fromJava(rows(i), schema) i += 1 } } benchmark.addCase("After - EvaluatePython.makeFromJava", 3) { _ => val fromJava = EvaluatePython.makeFromJava(schema) var i = 0 while (i < numRows) { fromJava(rows(i)) i += 1 } } benchmark.run() } ``` ``` EvaluatePython.fromJava / EvaluatePython.makeFromJava: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Before - EvaluatePython.fromJava 1265 / 1346 0.8 1264.8 1.0X After - EvaluatePython.makeFromJava 571 / 649 1.8 570.8 2.2X ``` If the structure is nested, I think the advantage should be larger than this. ## How was this patch tested? Existing tests should cover this. Also, I manually checked if the values from before / after are actually same via `assert` when performing the benchmarks. Author: hyukjinkwon Closes #20172 from HyukjinKwon/type-dispatch-python-eval. --- .../org/apache/spark/sql/SparkSession.scala | 5 +- .../python/BatchEvalPythonExec.scala | 7 +- .../sql/execution/python/EvaluatePython.scala | 166 ++++++++++++------ 3 files changed, 118 insertions(+), 60 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 272eb844226d4..734573ba31f71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -742,7 +742,10 @@ class SparkSession private( private[sql] def applySchemaToPythonRDD( rdd: RDD[Array[Any]], schema: StructType): DataFrame = { - val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow]) + val rowRdd = rdd.mapPartitions { iter => + val fromJava = python.EvaluatePython.makeFromJava(schema) + iter.map(r => fromJava(r).asInstanceOf[InternalRow]) + } internalCreateDataFrame(rowRdd, schema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 26ee25f633ea4..f4d83e8dc7c2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -79,16 +79,19 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi } else { StructType(udfs.map(u => StructField("", u.dataType, u.nullable))) } + + val fromJava = EvaluatePython.makeFromJava(resultType) + outputIterator.flatMap { pickedResult => val unpickledBatch = unpickle.loads(pickedResult) unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala }.map { result => if (udfs.length == 1) { // fast path for single UDF - mutableRow(0) = EvaluatePython.fromJava(result, resultType) + mutableRow(0) = fromJava(result) mutableRow } else { - EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow] + fromJava(result).asInstanceOf[InternalRow] } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 9bbfa6018ba77..520afad287648 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -83,82 +83,134 @@ object EvaluatePython { } /** - * Converts `obj` to the type specified by the data type, or returns null if the type of obj is - * unexpected. Because Python doesn't enforce the type. + * Make a converter that converts `obj` to the type specified by the data type, or returns + * null if the type of obj is unexpected. Because Python doesn't enforce the type. */ - def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { - case (null, _) => null - - case (c: Boolean, BooleanType) => c + def makeFromJava(dataType: DataType): Any => Any = dataType match { + case BooleanType => (obj: Any) => nullSafeConvert(obj) { + case b: Boolean => b + } - case (c: Byte, ByteType) => c - case (c: Short, ByteType) => c.toByte - case (c: Int, ByteType) => c.toByte - case (c: Long, ByteType) => c.toByte + case ByteType => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c + case c: Short => c.toByte + case c: Int => c.toByte + case c: Long => c.toByte + } - case (c: Byte, ShortType) => c.toShort - case (c: Short, ShortType) => c - case (c: Int, ShortType) => c.toShort - case (c: Long, ShortType) => c.toShort + case ShortType => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c.toShort + case c: Short => c + case c: Int => c.toShort + case c: Long => c.toShort + } - case (c: Byte, IntegerType) => c.toInt - case (c: Short, IntegerType) => c.toInt - case (c: Int, IntegerType) => c - case (c: Long, IntegerType) => c.toInt + case IntegerType => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c.toInt + case c: Short => c.toInt + case c: Int => c + case c: Long => c.toInt + } - case (c: Byte, LongType) => c.toLong - case (c: Short, LongType) => c.toLong - case (c: Int, LongType) => c.toLong - case (c: Long, LongType) => c + case LongType => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c.toLong + case c: Short => c.toLong + case c: Int => c.toLong + case c: Long => c + } - case (c: Float, FloatType) => c - case (c: Double, FloatType) => c.toFloat + case FloatType => (obj: Any) => nullSafeConvert(obj) { + case c: Float => c + case c: Double => c.toFloat + } - case (c: Float, DoubleType) => c.toDouble - case (c: Double, DoubleType) => c + case DoubleType => (obj: Any) => nullSafeConvert(obj) { + case c: Float => c.toDouble + case c: Double => c + } - case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale) + case dt: DecimalType => (obj: Any) => nullSafeConvert(obj) { + case c: java.math.BigDecimal => Decimal(c, dt.precision, dt.scale) + } - case (c: Int, DateType) => c + case DateType => (obj: Any) => nullSafeConvert(obj) { + case c: Int => c + } - case (c: Long, TimestampType) => c - // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs - case (c: Int, TimestampType) => c.toLong + case TimestampType => (obj: Any) => nullSafeConvert(obj) { + case c: Long => c + // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs + case c: Int => c.toLong + } - case (c, StringType) => UTF8String.fromString(c.toString) + case StringType => (obj: Any) => nullSafeConvert(obj) { + case _ => UTF8String.fromString(obj.toString) + } - case (c: String, BinaryType) => c.getBytes(StandardCharsets.UTF_8) - case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c + case BinaryType => (obj: Any) => nullSafeConvert(obj) { + case c: String => c.getBytes(StandardCharsets.UTF_8) + case c if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c + } - case (c: java.util.List[_], ArrayType(elementType, _)) => - new GenericArrayData(c.asScala.map { e => fromJava(e, elementType)}.toArray) + case ArrayType(elementType, _) => + val elementFromJava = makeFromJava(elementType) - case (c, ArrayType(elementType, _)) if c.getClass.isArray => - new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType))) + (obj: Any) => nullSafeConvert(obj) { + case c: java.util.List[_] => + new GenericArrayData(c.asScala.map { e => elementFromJava(e) }.toArray) + case c if c.getClass.isArray => + new GenericArrayData(c.asInstanceOf[Array[_]].map(e => elementFromJava(e))) + } - case (javaMap: java.util.Map[_, _], MapType(keyType, valueType, _)) => - ArrayBasedMapData( - javaMap, - (key: Any) => fromJava(key, keyType), - (value: Any) => fromJava(value, valueType)) + case MapType(keyType, valueType, _) => + val keyFromJava = makeFromJava(keyType) + val valueFromJava = makeFromJava(valueType) + + (obj: Any) => nullSafeConvert(obj) { + case javaMap: java.util.Map[_, _] => + ArrayBasedMapData( + javaMap, + (key: Any) => keyFromJava(key), + (value: Any) => valueFromJava(value)) + } - case (c, StructType(fields)) if c.getClass.isArray => - val array = c.asInstanceOf[Array[_]] - if (array.length != fields.length) { - throw new IllegalStateException( - s"Input row doesn't have expected number of values required by the schema. " + - s"${fields.length} fields are required while ${array.length} values are provided." - ) + case StructType(fields) => + val fieldsFromJava = fields.map(f => makeFromJava(f.dataType)).toArray + + (obj: Any) => nullSafeConvert(obj) { + case c if c.getClass.isArray => + val array = c.asInstanceOf[Array[_]] + if (array.length != fields.length) { + throw new IllegalStateException( + s"Input row doesn't have expected number of values required by the schema. " + + s"${fields.length} fields are required while ${array.length} values are provided." + ) + } + + val row = new GenericInternalRow(fields.length) + var i = 0 + while (i < fields.length) { + row(i) = fieldsFromJava(i)(array(i)) + i += 1 + } + row } - new GenericInternalRow(array.zip(fields).map { - case (e, f) => fromJava(e, f.dataType) - }) - case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType) + case udt: UserDefinedType[_] => makeFromJava(udt.sqlType) + + case other => (obj: Any) => nullSafeConvert(other)(PartialFunction.empty) + } - // all other unexpected type should be null, or we will have runtime exception - // TODO(davies): we could improve this by try to cast the object to expected type - case (c, _) => null + private def nullSafeConvert(input: Any)(f: PartialFunction[Any, Any]): Any = { + if (input == null) { + null + } else { + f.applyOrElse(input, { + // all other unexpected type should be null, or we will have runtime exception + // TODO(davies): we could improve this by try to cast the object to expected type + _: Any => null + }) + } } private val module = "pyspark.sql.types" From 2c73d2a948bdde798aaf0f87c18846281deb05fd Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 8 Jan 2018 16:04:03 +0800 Subject: [PATCH 0038/2461] [SPARK-22983] Don't push filters beneath aggregates with empty grouping expressions ## What changes were proposed in this pull request? The following SQL query should return zero rows, but in Spark it actually returns one row: ``` SELECT 1 from ( SELECT 1 AS z, MIN(a.x) FROM (select 1 as x) a WHERE false ) b where b.z != b.z ``` The problem stems from the `PushDownPredicate` rule: when this rule encounters a filter on top of an Aggregate operator, e.g. `Filter(Agg(...))`, it removes the original filter and adds a new filter onto Aggregate's child, e.g. `Agg(Filter(...))`. This is sometimes okay, but the case above is a counterexample: because there is no explicit `GROUP BY`, we are implicitly computing a global aggregate over the entire table so the original filter was not acting like a `HAVING` clause filtering the number of groups: if we push this filter then it fails to actually reduce the cardinality of the Aggregate output, leading to the wrong answer. In 2016 I fixed a similar problem involving invalid pushdowns of data-independent filters (filters which reference no columns of the filtered relation). There was additional discussion after my fix was merged which pointed out that my patch was an incomplete fix (see #15289), but it looks I must have either misunderstood the comment or forgot to follow up on the additional points raised there. This patch fixes the problem by choosing to never push down filters in cases where there are no grouping expressions. Since there are no grouping keys, the only columns are aggregate columns and we can't push filters defined over aggregate results, so this change won't cause us to miss out on any legitimate pushdown opportunities. ## How was this patch tested? New regression tests in `SQLQueryTestSuite` and `FilterPushdownSuite`. Author: Josh Rosen Closes #20180 from JoshRosen/SPARK-22983-dont-push-filters-beneath-aggs-with-empty-grouping-expressions. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 3 ++- .../catalyst/optimizer/FilterPushdownSuite.scala | 13 +++++++++++++ .../test/resources/sql-tests/inputs/group-by.sql | 9 +++++++++ .../resources/sql-tests/results/group-by.sql.out | 16 +++++++++++++++- 4 files changed, 39 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0d4b02c6e7d8a..df0af8264a329 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -795,7 +795,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) case filter @ Filter(condition, aggregate: Aggregate) - if aggregate.aggregateExpressions.forall(_.deterministic) => + if aggregate.aggregateExpressions.forall(_.deterministic) + && aggregate.groupingExpressions.nonEmpty => // Find all the aliased expressions in the aggregate list that don't include any actual // AggregateExpression, and create a map from the alias to the expression val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 85a5e979f6021..82a10254d846d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -809,6 +809,19 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("aggregate: don't push filters if the aggregate has no grouping expressions") { + val originalQuery = LocalRelation.apply(testRelation.output, Seq.empty) + .select('a, 'b) + .groupBy()(count(1)) + .where(false) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = originalQuery.analyze + + comparePlans(optimized, correctAnswer) + } + test("broadcast hint") { val originalQuery = ResolvedHint(testRelation) .where('a === 2L && 'b + Rand(10).as("rnd") === 3) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 1e1384549a410..c5070b734d521 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -60,3 +60,12 @@ SELECT a, COUNT(1) FROM testData WHERE false GROUP BY a; -- Aggregate with empty input and empty GroupBy expressions. SELECT COUNT(1) FROM testData WHERE false; SELECT 1 FROM (SELECT COUNT(1) FROM testData WHERE false) t; + +-- Aggregate with empty GroupBy expressions and filter on top +SELECT 1 from ( + SELECT 1 AS z, + MIN(a.x) + FROM (select 1 as x) a + WHERE false +) b +where b.z != b.z diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 986bb01c13fe4..c1abc6dff754b 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 25 +-- Number of queries: 26 -- !query 0 @@ -227,3 +227,17 @@ SELECT 1 FROM (SELECT COUNT(1) FROM testData WHERE false) t struct<1:int> -- !query 24 output 1 + + +-- !query 25 +SELECT 1 from ( + SELECT 1 AS z, + MIN(a.x) + FROM (select 1 as x) a + WHERE false +) b +where b.z != b.z +-- !query 25 schema +struct<1:int> +-- !query 25 output + From eb45b52e826ea9cea48629760db35ef87f91fea0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 8 Jan 2018 19:41:41 +0800 Subject: [PATCH 0039/2461] [SPARK-21865][SQL] simplify the distribution semantic of Spark SQL ## What changes were proposed in this pull request? **The current shuffle planning logic** 1. Each operator specifies the distribution requirements for its children, via the `Distribution` interface. 2. Each operator specifies its output partitioning, via the `Partitioning` interface. 3. `Partitioning.satisfy` determines whether a `Partitioning` can satisfy a `Distribution`. 4. For each operator, check each child of it, add a shuffle node above the child if the child partitioning can not satisfy the required distribution. 5. For each operator, check if its children's output partitionings are compatible with each other, via the `Partitioning.compatibleWith`. 6. If the check in 5 failed, add a shuffle above each child. 7. try to eliminate the shuffles added in 6, via `Partitioning.guarantees`. This design has a major problem with the definition of "compatible". `Partitioning.compatibleWith` is not well defined, ideally a `Partitioning` can't know if it's compatible with other `Partitioning`, without more information from the operator. For example, `t1 join t2 on t1.a = t2.b`, `HashPartitioning(a, 10)` should be compatible with `HashPartitioning(b, 10)` under this case, but the partitioning itself doesn't know it. As a result, currently `Partitioning.compatibleWith` always return false except for literals, which make it almost useless. This also means, if an operator has distribution requirements for multiple children, Spark always add shuffle nodes to all the children(although some of them can be eliminated). However, there is no guarantee that the children's output partitionings are compatible with each other after adding these shuffles, we just assume that the operator will only specify `ClusteredDistribution` for multiple children. I think it's very hard to guarantee children co-partition for all kinds of operators, and we can not even give a clear definition about co-partition between distributions like `ClusteredDistribution(a,b)` and `ClusteredDistribution(c)`. I think we should drop the "compatible" concept in the distribution model, and let the operator achieve the co-partition requirement by special distribution requirements. **Proposed shuffle planning logic after this PR** (The first 4 are same as before) 1. Each operator specifies the distribution requirements for its children, via the `Distribution` interface. 2. Each operator specifies its output partitioning, via the `Partitioning` interface. 3. `Partitioning.satisfy` determines whether a `Partitioning` can satisfy a `Distribution`. 4. For each operator, check each child of it, add a shuffle node above the child if the child partitioning can not satisfy the required distribution. 5. For each operator, check if its children's output partitionings have the same number of partitions. 6. If the check in 5 failed, pick the max number of partitions from children's output partitionings, and add shuffle to child whose number of partitions doesn't equal to the max one. The new distribution model is very simple, we only have one kind of relationship, which is `Partitioning.satisfy`. For multiple children, Spark only guarantees they have the same number of partitions, and it's the operator's responsibility to leverage this guarantee to achieve more complicated requirements. For example, non-broadcast joins can use the newly added `HashPartitionedDistribution` to achieve co-partition. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #19080 from cloud-fan/exchange. --- .../plans/physical/partitioning.scala | 286 +++++++----------- .../sql/catalyst/PartitioningSuite.scala | 55 ---- .../spark/sql/execution/SparkPlan.scala | 16 +- .../exchange/EnsureRequirements.scala | 120 +++----- .../joins/ShuffledHashJoinExec.scala | 2 +- .../execution/joins/SortMergeJoinExec.scala | 2 +- .../apache/spark/sql/execution/objects.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 81 ++--- 8 files changed, 194 insertions(+), 370 deletions(-) delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index e57c842ce2a36..0189bd73c56bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -30,18 +30,43 @@ import org.apache.spark.sql.types.{DataType, IntegerType} * - Intra-partition ordering of data: In this case the distribution describes guarantees made * about how tuples are distributed within a single partition. */ -sealed trait Distribution +sealed trait Distribution { + /** + * The required number of partitions for this distribution. If it's None, then any number of + * partitions is allowed for this distribution. + */ + def requiredNumPartitions: Option[Int] + + /** + * Creates a default partitioning for this distribution, which can satisfy this distribution while + * matching the given number of partitions. + */ + def createPartitioning(numPartitions: Int): Partitioning +} /** * Represents a distribution where no promises are made about co-location of data. */ -case object UnspecifiedDistribution extends Distribution +case object UnspecifiedDistribution extends Distribution { + override def requiredNumPartitions: Option[Int] = None + + override def createPartitioning(numPartitions: Int): Partitioning = { + throw new IllegalStateException("UnspecifiedDistribution does not have default partitioning.") + } +} /** * Represents a distribution that only has a single partition and all tuples of the dataset * are co-located. */ -case object AllTuples extends Distribution +case object AllTuples extends Distribution { + override def requiredNumPartitions: Option[Int] = Some(1) + + override def createPartitioning(numPartitions: Int): Partitioning = { + assert(numPartitions == 1, "The default partitioning of AllTuples can only have 1 partition.") + SinglePartition + } +} /** * Represents data where tuples that share the same values for the `clustering` @@ -51,12 +76,41 @@ case object AllTuples extends Distribution */ case class ClusteredDistribution( clustering: Seq[Expression], - numPartitions: Option[Int] = None) extends Distribution { + requiredNumPartitions: Option[Int] = None) extends Distribution { require( clustering != Nil, "The clustering expressions of a ClusteredDistribution should not be Nil. " + "An AllTuples should be used to represent a distribution that only has " + "a single partition.") + + override def createPartitioning(numPartitions: Int): Partitioning = { + assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions, + s"This ClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " + + s"the actual number of partitions is $numPartitions.") + HashPartitioning(clustering, numPartitions) + } +} + +/** + * Represents data where tuples have been clustered according to the hash of the given + * `expressions`. The hash function is defined as `HashPartitioning.partitionIdExpression`, so only + * [[HashPartitioning]] can satisfy this distribution. + * + * This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the + * number of partitions, this distribution strictly requires which partition the tuple should be in. + */ +case class HashClusteredDistribution(expressions: Seq[Expression]) extends Distribution { + require( + expressions != Nil, + "The expressions for hash of a HashPartitionedDistribution should not be Nil. " + + "An AllTuples should be used to represent a distribution that only has " + + "a single partition.") + + override def requiredNumPartitions: Option[Int] = None + + override def createPartitioning(numPartitions: Int): Partitioning = { + HashPartitioning(expressions, numPartitions) + } } /** @@ -73,46 +127,31 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { "An AllTuples should be used to represent a distribution that only has " + "a single partition.") - // TODO: This is not really valid... - def clustering: Set[Expression] = ordering.map(_.child).toSet + override def requiredNumPartitions: Option[Int] = None + + override def createPartitioning(numPartitions: Int): Partitioning = { + RangePartitioning(ordering, numPartitions) + } } /** * Represents data where tuples are broadcasted to every node. It is quite common that the * entire set of tuples is transformed into different data structure. */ -case class BroadcastDistribution(mode: BroadcastMode) extends Distribution +case class BroadcastDistribution(mode: BroadcastMode) extends Distribution { + override def requiredNumPartitions: Option[Int] = Some(1) + + override def createPartitioning(numPartitions: Int): Partitioning = { + assert(numPartitions == 1, + "The default partitioning of BroadcastDistribution can only have 1 partition.") + BroadcastPartitioning(mode) + } +} /** - * Describes how an operator's output is split across partitions. The `compatibleWith`, - * `guarantees`, and `satisfies` methods describe relationships between child partitionings, - * target partitionings, and [[Distribution]]s. These relations are described more precisely in - * their individual method docs, but at a high level: - * - * - `satisfies` is a relationship between partitionings and distributions. - * - `compatibleWith` is relationships between an operator's child output partitionings. - * - `guarantees` is a relationship between a child's existing output partitioning and a target - * output partitioning. - * - * Diagrammatically: - * - * +--------------+ - * | Distribution | - * +--------------+ - * ^ - * | - * satisfies - * | - * +--------------+ +--------------+ - * | Child | | Target | - * +----| Partitioning |----guarantees--->| Partitioning | - * | +--------------+ +--------------+ - * | ^ - * | | - * | compatibleWith - * | | - * +------------+ - * + * Describes how an operator's output is split across partitions. It has 2 major properties: + * 1. number of partitions. + * 2. if it can satisfy a given distribution. */ sealed trait Partitioning { /** Returns the number of partitions that the data is split across */ @@ -123,113 +162,35 @@ sealed trait Partitioning { * to satisfy the partitioning scheme mandated by the `required` [[Distribution]], * i.e. the current dataset does not need to be re-partitioned for the `required` * Distribution (it is possible that tuples within a partition need to be reorganized). - */ - def satisfies(required: Distribution): Boolean - - /** - * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] - * guarantees the same partitioning scheme described by `other`. - * - * Compatibility of partitionings is only checked for operators that have multiple children - * and that require a specific child output [[Distribution]], such as joins. - * - * Intuitively, partitionings are compatible if they route the same partitioning key to the same - * partition. For instance, two hash partitionings are only compatible if they produce the same - * number of output partitionings and hash records according to the same hash function and - * same partitioning key schema. - * - * Put another way, two partitionings are compatible with each other if they satisfy all of the - * same distribution guarantees. - */ - def compatibleWith(other: Partitioning): Boolean - - /** - * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] guarantees - * the same partitioning scheme described by `other`. If a `A.guarantees(B)`, then repartitioning - * the child's output according to `B` will be unnecessary. `guarantees` is used as a performance - * optimization to allow the exchange planner to avoid redundant repartitionings. By default, - * a partitioning only guarantees partitionings that are equal to itself (i.e. the same number - * of partitions, same strategy (range or hash), etc). - * - * In order to enable more aggressive optimization, this strict equality check can be relaxed. - * For example, say that the planner needs to repartition all of an operator's children so that - * they satisfy the [[AllTuples]] distribution. One way to do this is to repartition all children - * to have the [[SinglePartition]] partitioning. If one of the operator's children already happens - * to be hash-partitioned with a single partition then we do not need to re-shuffle this child; - * this repartitioning can be avoided if a single-partition [[HashPartitioning]] `guarantees` - * [[SinglePartition]]. - * - * The SinglePartition example given above is not particularly interesting; guarantees' real - * value occurs for more advanced partitioning strategies. SPARK-7871 will introduce a notion - * of null-safe partitionings, under which partitionings can specify whether rows whose - * partitioning keys contain null values will be grouped into the same partition or whether they - * will have an unknown / random distribution. If a partitioning does not require nulls to be - * clustered then a partitioning which _does_ cluster nulls will guarantee the null clustered - * partitioning. The converse is not true, however: a partitioning which clusters nulls cannot - * be guaranteed by one which does not cluster them. Thus, in general `guarantees` is not a - * symmetric relation. * - * Another way to think about `guarantees`: if `A.guarantees(B)`, then any partitioning of rows - * produced by `A` could have also been produced by `B`. + * By default a [[Partitioning]] can satisfy [[UnspecifiedDistribution]], and [[AllTuples]] if + * the [[Partitioning]] only have one partition. Implementations can overwrite this method with + * special logic. */ - def guarantees(other: Partitioning): Boolean = this == other -} - -object Partitioning { - def allCompatible(partitionings: Seq[Partitioning]): Boolean = { - // Note: this assumes transitivity - partitionings.sliding(2).map { - case Seq(a) => true - case Seq(a, b) => - if (a.numPartitions != b.numPartitions) { - assert(!a.compatibleWith(b) && !b.compatibleWith(a)) - false - } else { - a.compatibleWith(b) && b.compatibleWith(a) - } - }.forall(_ == true) - } -} - -case class UnknownPartitioning(numPartitions: Int) extends Partitioning { - override def satisfies(required: Distribution): Boolean = required match { + def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true + case AllTuples => numPartitions == 1 case _ => false } - - override def compatibleWith(other: Partitioning): Boolean = false - - override def guarantees(other: Partitioning): Boolean = false } +case class UnknownPartitioning(numPartitions: Int) extends Partitioning + /** * Represents a partitioning where rows are distributed evenly across output partitions * by starting from a random target partition number and distributing rows in a round-robin * fashion. This partitioning is used when implementing the DataFrame.repartition() operator. */ -case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning { - override def satisfies(required: Distribution): Boolean = required match { - case UnspecifiedDistribution => true - case _ => false - } - - override def compatibleWith(other: Partitioning): Boolean = false - - override def guarantees(other: Partitioning): Boolean = false -} +case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning case object SinglePartition extends Partitioning { val numPartitions = 1 override def satisfies(required: Distribution): Boolean = required match { case _: BroadcastDistribution => false - case ClusteredDistribution(_, desiredPartitions) => desiredPartitions.forall(_ == 1) + case ClusteredDistribution(_, Some(requiredNumPartitions)) => requiredNumPartitions == 1 case _ => true } - - override def compatibleWith(other: Partitioning): Boolean = other.numPartitions == 1 - - override def guarantees(other: Partitioning): Boolean = other.numPartitions == 1 } /** @@ -244,22 +205,19 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = required match { - case UnspecifiedDistribution => true - case ClusteredDistribution(requiredClustering, desiredPartitions) => - expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) && - desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true - case _ => false - } - - override def compatibleWith(other: Partitioning): Boolean = other match { - case o: HashPartitioning => this.semanticEquals(o) - case _ => false - } - - override def guarantees(other: Partitioning): Boolean = other match { - case o: HashPartitioning => this.semanticEquals(o) - case _ => false + override def satisfies(required: Distribution): Boolean = { + super.satisfies(required) || { + required match { + case h: HashClusteredDistribution => + expressions.length == h.expressions.length && expressions.zip(h.expressions).forall { + case (l, r) => l.semanticEquals(r) + } + case ClusteredDistribution(requiredClustering, requiredNumPartitions) => + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) && + (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) + case _ => false + } + } } /** @@ -288,25 +246,18 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = required match { - case UnspecifiedDistribution => true - case OrderedDistribution(requiredOrdering) => - val minSize = Seq(requiredOrdering.size, ordering.size).min - requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering, desiredPartitions) => - ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) && - desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true - case _ => false - } - - override def compatibleWith(other: Partitioning): Boolean = other match { - case o: RangePartitioning => this.semanticEquals(o) - case _ => false - } - - override def guarantees(other: Partitioning): Boolean = other match { - case o: RangePartitioning => this.semanticEquals(o) - case _ => false + override def satisfies(required: Distribution): Boolean = { + super.satisfies(required) || { + required match { + case OrderedDistribution(requiredOrdering) => + val minSize = Seq(requiredOrdering.size, ordering.size).min + requiredOrdering.take(minSize) == ordering.take(minSize) + case ClusteredDistribution(requiredClustering, requiredNumPartitions) => + ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) && + (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) + case _ => false + } + } } } @@ -347,20 +298,6 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) override def satisfies(required: Distribution): Boolean = partitionings.exists(_.satisfies(required)) - /** - * Returns true if any `partitioning` of this collection is compatible with - * the given [[Partitioning]]. - */ - override def compatibleWith(other: Partitioning): Boolean = - partitionings.exists(_.compatibleWith(other)) - - /** - * Returns true if any `partitioning` of this collection guarantees - * the given [[Partitioning]]. - */ - override def guarantees(other: Partitioning): Boolean = - partitionings.exists(_.guarantees(other)) - override def toString: String = { partitionings.map(_.toString).mkString("(", " or ", ")") } @@ -377,9 +314,4 @@ case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning { case BroadcastDistribution(m) if m == mode => true case _ => false } - - override def compatibleWith(other: Partitioning): Boolean = other match { - case BroadcastPartitioning(m) if m == mode => true - case _ => false - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala deleted file mode 100644 index 5b802ccc637dd..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal} -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning} - -class PartitioningSuite extends SparkFunSuite { - test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") { - val expressions = Seq(Literal(2), Literal(3)) - // Consider two HashPartitionings that have the same _set_ of hash expressions but which are - // created with different orderings of those expressions: - val partitioningA = HashPartitioning(expressions, 100) - val partitioningB = HashPartitioning(expressions.reverse, 100) - // These partitionings are not considered equal: - assert(partitioningA != partitioningB) - // However, they both satisfy the same clustered distribution: - val distribution = ClusteredDistribution(expressions) - assert(partitioningA.satisfies(distribution)) - assert(partitioningB.satisfies(distribution)) - // These partitionings compute different hashcodes for the same input row: - def computeHashCode(partitioning: HashPartitioning): Int = { - val hashExprProj = new InterpretedMutableProjection(partitioning.expressions, Seq.empty) - hashExprProj.apply(InternalRow.empty).hashCode() - } - assert(computeHashCode(partitioningA) != computeHashCode(partitioningB)) - // Thus, these partitionings are incompatible: - assert(!partitioningA.compatibleWith(partitioningB)) - assert(!partitioningB.compatibleWith(partitioningA)) - assert(!partitioningA.guarantees(partitioningB)) - assert(!partitioningB.guarantees(partitioningA)) - - // Just to be sure that we haven't cheated by having these methods always return false, - // check that identical partitionings are still compatible with and guarantee each other: - assert(partitioningA === partitioningA) - assert(partitioningA.guarantees(partitioningA)) - assert(partitioningA.compatibleWith(partitioningA)) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 787c1cfbfb3d8..82300efc01632 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -94,7 +94,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Specifies how data is partitioned across different nodes in the cluster. */ def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH! - /** Specifies any partition requirements on the input data for this operator. */ + /** + * Specifies the data distribution requirements of all the children for this operator. By default + * it's [[UnspecifiedDistribution]] for each child, which means each child can have any + * distribution. + * + * If an operator overwrites this method, and specifies distribution requirements(excluding + * [[UnspecifiedDistribution]] and [[BroadcastDistribution]]) for more than one child, Spark + * guarantees that the outputs of these children will have same number of partitions, so that the + * operator can safely zip partitions of these children's result RDDs. Some operators can leverage + * this guarantee to satisfy some interesting requirement, e.g., non-broadcast joins can specify + * HashClusteredDistribution(a,b) for its left child, and specify HashClusteredDistribution(c,d) + * for its right child, then it's guaranteed that left and right child are co-partitioned by + * a,b/c,d, which means tuples of same value are in the partitions of same index, e.g., + * (a=1,b=2) and (c=1,d=2) are both in the second partition of left and right child. + */ def requiredChildDistribution: Seq[Distribution] = Seq.fill(children.size)(UnspecifiedDistribution) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index c8e236be28b42..e3d28388c5470 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -46,23 +46,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None } - /** - * Given a required distribution, returns a partitioning that satisfies that distribution. - * @param requiredDistribution The distribution that is required by the operator - * @param numPartitions Used when the distribution doesn't require a specific number of partitions - */ - private def createPartitioning( - requiredDistribution: Distribution, - numPartitions: Int): Partitioning = { - requiredDistribution match { - case AllTuples => SinglePartition - case ClusteredDistribution(clustering, desiredPartitions) => - HashPartitioning(clustering, desiredPartitions.getOrElse(numPartitions)) - case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions) - case dist => sys.error(s"Do not know how to satisfy distribution $dist") - } - } - /** * Adds [[ExchangeCoordinator]] to [[ShuffleExchangeExec]]s if adaptive query execution is enabled * and partitioning schemes of these [[ShuffleExchangeExec]]s support [[ExchangeCoordinator]]. @@ -88,8 +71,9 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { // shuffle data when we have more than one children because data generated by // these children may not be partitioned in the same way. // Please see the comment in withCoordinator for more details. - val supportsDistribution = - requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution]) + val supportsDistribution = requiredChildDistributions.forall { dist => + dist.isInstanceOf[ClusteredDistribution] || dist.isInstanceOf[HashClusteredDistribution] + } children.length > 1 && supportsDistribution } @@ -142,8 +126,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { // // It will be great to introduce a new Partitioning to represent the post-shuffle // partitions when one post-shuffle partition includes multiple pre-shuffle partitions. - val targetPartitioning = - createPartitioning(distribution, defaultNumPreShufflePartitions) + val targetPartitioning = distribution.createPartitioning(defaultNumPreShufflePartitions) assert(targetPartitioning.isInstanceOf[HashPartitioning]) ShuffleExchangeExec(targetPartitioning, child, Some(coordinator)) } @@ -162,71 +145,56 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { assert(requiredChildDistributions.length == children.length) assert(requiredChildOrderings.length == children.length) - // Ensure that the operator's children satisfy their output distribution requirements: + // Ensure that the operator's children satisfy their output distribution requirements. children = children.zip(requiredChildDistributions).map { case (child, distribution) if child.outputPartitioning.satisfies(distribution) => child case (child, BroadcastDistribution(mode)) => BroadcastExchangeExec(mode, child) case (child, distribution) => - ShuffleExchangeExec(createPartitioning(distribution, defaultNumPreShufflePartitions), child) + val numPartitions = distribution.requiredNumPartitions + .getOrElse(defaultNumPreShufflePartitions) + ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child) } - // If the operator has multiple children and specifies child output distributions (e.g. join), - // then the children's output partitionings must be compatible: - def requireCompatiblePartitioning(distribution: Distribution): Boolean = distribution match { - case UnspecifiedDistribution => false - case BroadcastDistribution(_) => false + // Get the indexes of children which have specified distribution requirements and need to have + // same number of partitions. + val childrenIndexes = requiredChildDistributions.zipWithIndex.filter { + case (UnspecifiedDistribution, _) => false + case (_: BroadcastDistribution, _) => false case _ => true - } - if (children.length > 1 - && requiredChildDistributions.exists(requireCompatiblePartitioning) - && !Partitioning.allCompatible(children.map(_.outputPartitioning))) { + }.map(_._2) - // First check if the existing partitions of the children all match. This means they are - // partitioned by the same partitioning into the same number of partitions. In that case, - // don't try to make them match `defaultPartitions`, just use the existing partitioning. - val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max - val useExistingPartitioning = children.zip(requiredChildDistributions).forall { - case (child, distribution) => - child.outputPartitioning.guarantees( - createPartitioning(distribution, maxChildrenNumPartitions)) + val childrenNumPartitions = + childrenIndexes.map(children(_).outputPartitioning.numPartitions).toSet + + if (childrenNumPartitions.size > 1) { + // Get the number of partitions which is explicitly required by the distributions. + val requiredNumPartitions = { + val numPartitionsSet = childrenIndexes.flatMap { + index => requiredChildDistributions(index).requiredNumPartitions + }.toSet + assert(numPartitionsSet.size <= 1, + s"$operator have incompatible requirements of the number of partitions for its children") + numPartitionsSet.headOption } - children = if (useExistingPartitioning) { - // We do not need to shuffle any child's output. - children - } else { - // We need to shuffle at least one child's output. - // Now, we will determine the number of partitions that will be used by created - // partitioning schemes. - val numPartitions = { - // Let's see if we need to shuffle all child's outputs when we use - // maxChildrenNumPartitions. - val shufflesAllChildren = children.zip(requiredChildDistributions).forall { - case (child, distribution) => - !child.outputPartitioning.guarantees( - createPartitioning(distribution, maxChildrenNumPartitions)) - } - // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the - // number of partitions. Otherwise, we use maxChildrenNumPartitions. - if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions - } + val targetNumPartitions = requiredNumPartitions.getOrElse(childrenNumPartitions.max) - children.zip(requiredChildDistributions).map { - case (child, distribution) => - val targetPartitioning = createPartitioning(distribution, numPartitions) - if (child.outputPartitioning.guarantees(targetPartitioning)) { - child - } else { - child match { - // If child is an exchange, we replace it with - // a new one having targetPartitioning. - case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(targetPartitioning, c) - case _ => ShuffleExchangeExec(targetPartitioning, child) - } + children = children.zip(requiredChildDistributions).zipWithIndex.map { + case ((child, distribution), index) if childrenIndexes.contains(index) => + if (child.outputPartitioning.numPartitions == targetNumPartitions) { + child + } else { + val defaultPartitioning = distribution.createPartitioning(targetNumPartitions) + child match { + // If child is an exchange, we replace it with a new one having defaultPartitioning. + case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(defaultPartitioning, c) + case _ => ShuffleExchangeExec(defaultPartitioning, child) + } } - } + + case ((child, _), _) => child } } @@ -324,10 +292,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - case operator @ ShuffleExchangeExec(partitioning, child, _) => - child.children match { - case ShuffleExchangeExec(childPartitioning, baseChild, _)::Nil => - if (childPartitioning.guarantees(partitioning)) child else operator + // TODO: remove this after we create a physical operator for `RepartitionByExpression`. + case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) => + child.outputPartitioning match { + case lower: HashPartitioning if upper.semanticEquals(lower) => child case _ => operator } case operator: SparkPlan => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 66e8031bb5191..897a4dae39f32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -46,7 +46,7 @@ case class ShuffledHashJoinExec( "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe")) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = { val buildDataSize = longMetric("buildDataSize") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 94405410cce90..2de2f30eb05d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -78,7 +78,7 @@ case class SortMergeJoinExec( } override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil override def outputOrdering: Seq[SortOrder] = joinType match { // For inner join, orders of both sides keys should be kept. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index d1bd8a7076863..03d1bbf2ab882 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -456,7 +456,7 @@ case class CoGroupExec( right: SparkPlan) extends BinaryExecNode with ObjectProducerExec { override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil + HashClusteredDistribution(leftGroup) :: HashClusteredDistribution(rightGroup) :: Nil override def requiredChildOrdering: Seq[Seq[SortOrder]] = leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index b50642d275ba8..f8b26f5b28cc7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -260,11 +260,16 @@ class PlannerSuite extends SharedSQLContext { // do they satisfy the distribution requirements? As a result, we need at least four test cases. private def assertDistributionRequirementsAreSatisfied(outputPlan: SparkPlan): Unit = { - if (outputPlan.children.length > 1 - && outputPlan.requiredChildDistribution.toSet != Set(UnspecifiedDistribution)) { - val childPartitionings = outputPlan.children.map(_.outputPartitioning) - if (!Partitioning.allCompatible(childPartitionings)) { - fail(s"Partitionings are not compatible: $childPartitionings") + if (outputPlan.children.length > 1) { + val childPartitionings = outputPlan.children.zip(outputPlan.requiredChildDistribution) + .filter { + case (_, UnspecifiedDistribution) => false + case (_, _: BroadcastDistribution) => false + case _ => true + }.map(_._1.outputPartitioning) + + if (childPartitionings.map(_.numPartitions).toSet.size > 1) { + fail(s"Partitionings doesn't have same number of partitions: $childPartitionings") } } outputPlan.children.zip(outputPlan.requiredChildDistribution).foreach { @@ -274,40 +279,7 @@ class PlannerSuite extends SharedSQLContext { } } - test("EnsureRequirements with incompatible child partitionings which satisfy distribution") { - // Consider an operator that requires inputs that are clustered by two expressions (e.g. - // sort merge join where there are multiple columns in the equi-join condition) - val clusteringA = Literal(1) :: Nil - val clusteringB = Literal(2) :: Nil - val distribution = ClusteredDistribution(clusteringA ++ clusteringB) - // Say that the left and right inputs are each partitioned by _one_ of the two join columns: - val leftPartitioning = HashPartitioning(clusteringA, 1) - val rightPartitioning = HashPartitioning(clusteringB, 1) - // Individually, each input's partitioning satisfies the clustering distribution: - assert(leftPartitioning.satisfies(distribution)) - assert(rightPartitioning.satisfies(distribution)) - // However, these partitionings are not compatible with each other, so we still need to - // repartition both inputs prior to performing the join: - assert(!leftPartitioning.compatibleWith(rightPartitioning)) - assert(!rightPartitioning.compatibleWith(leftPartitioning)) - val inputPlan = DummySparkPlan( - children = Seq( - DummySparkPlan(outputPartitioning = leftPartitioning), - DummySparkPlan(outputPartitioning = rightPartitioning) - ), - requiredChildDistribution = Seq(distribution, distribution), - requiredChildOrdering = Seq(Seq.empty, Seq.empty) - ) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) - assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchangeExec => true }.isEmpty) { - fail(s"Exchange should have been added:\n$outputPlan") - } - } - test("EnsureRequirements with child partitionings with different numbers of output partitions") { - // This is similar to the previous test, except it checks that partitionings are not compatible - // unless they produce the same number of partitions. val clustering = Literal(1) :: Nil val distribution = ClusteredDistribution(clustering) val inputPlan = DummySparkPlan( @@ -386,18 +358,15 @@ class PlannerSuite extends SharedSQLContext { } } - test("EnsureRequirements eliminates Exchange if child has Exchange with same partitioning") { + test("EnsureRequirements eliminates Exchange if child has same partitioning") { val distribution = ClusteredDistribution(Literal(1) :: Nil) - val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) - val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) - assert(!childPartitioning.satisfies(distribution)) - val inputPlan = ShuffleExchangeExec(finalPartitioning, - DummySparkPlan( - children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, - requiredChildDistribution = Seq(distribution), - requiredChildOrdering = Seq(Seq.empty)), - None) + val partitioning = HashPartitioning(Literal(1) :: Nil, 5) + assert(partitioning.satisfies(distribution)) + val inputPlan = ShuffleExchangeExec( + partitioning, + DummySparkPlan(outputPartitioning = partitioning), + None) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 2) { @@ -407,17 +376,13 @@ class PlannerSuite extends SharedSQLContext { test("EnsureRequirements does not eliminate Exchange with different partitioning") { val distribution = ClusteredDistribution(Literal(1) :: Nil) - // Number of partitions differ - val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 8) - val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) - assert(!childPartitioning.satisfies(distribution)) - val inputPlan = ShuffleExchangeExec(finalPartitioning, - DummySparkPlan( - children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, - requiredChildDistribution = Seq(distribution), - requiredChildOrdering = Seq(Seq.empty)), - None) + val partitioning = HashPartitioning(Literal(2) :: Nil, 5) + assert(!partitioning.satisfies(distribution)) + val inputPlan = ShuffleExchangeExec( + partitioning, + DummySparkPlan(outputPartitioning = partitioning), + None) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) { From 40b983c3b44b6771f07302ce87987fa4716b5ebf Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Mon, 8 Jan 2018 23:49:07 +0800 Subject: [PATCH 0040/2461] [SPARK-22952][CORE] Deprecate stageAttemptId in favour of stageAttemptNumber ## What changes were proposed in this pull request? 1. Deprecate attemptId in StageInfo and add `def attemptNumber() = attemptId` 2. Replace usage of stageAttemptId with stageAttemptNumber ## How was this patch tested? I manually checked the compiler warning info Author: Xianjin YE Closes #20178 from advancedxy/SPARK-22952. --- .../apache/spark/scheduler/DAGScheduler.scala | 15 +++--- .../apache/spark/scheduler/StageInfo.scala | 4 +- .../spark/scheduler/StatsReportListener.scala | 2 +- .../spark/status/AppStatusListener.scala | 7 +-- .../org/apache/spark/status/LiveEntity.scala | 4 +- .../spark/ui/scope/RDDOperationGraph.scala | 2 +- .../org/apache/spark/util/JsonProtocol.scala | 2 +- .../spark/status/AppStatusListenerSuite.scala | 54 ++++++++++--------- .../execution/ui/SQLAppStatusListener.scala | 2 +- 9 files changed, 51 insertions(+), 41 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c2498d4808e91..199937b8c27af 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -815,7 +815,8 @@ class DAGScheduler( private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo) { // Note that there is a chance that this task is launched after the stage is cancelled. // In that case, we wouldn't have the stage anymore in stageIdToStage. - val stageAttemptId = stageIdToStage.get(task.stageId).map(_.latestInfo.attemptId).getOrElse(-1) + val stageAttemptId = + stageIdToStage.get(task.stageId).map(_.latestInfo.attemptNumber).getOrElse(-1) listenerBus.post(SparkListenerTaskStart(task.stageId, stageAttemptId, taskInfo)) } @@ -1050,7 +1051,7 @@ class DAGScheduler( val locs = taskIdToLocations(id) val part = stage.rdd.partitions(id) stage.pendingPartitions += id - new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, + new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId), Option(sc.applicationId), sc.applicationAttemptId) } @@ -1060,7 +1061,7 @@ class DAGScheduler( val p: Int = stage.partitions(id) val part = stage.rdd.partitions(p) val locs = taskIdToLocations(id) - new ResultTask(stage.id, stage.latestInfo.attemptId, + new ResultTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, id, properties, serializedTaskMetrics, Option(jobId), Option(sc.applicationId), sc.applicationAttemptId) } @@ -1076,7 +1077,7 @@ class DAGScheduler( logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " + s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})") taskScheduler.submitTasks(new TaskSet( - tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties)) + tasks.toArray, stage.id, stage.latestInfo.attemptNumber, jobId, properties)) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark // the stage as completed here in case there are no tasks to run @@ -1245,7 +1246,7 @@ class DAGScheduler( val status = event.result.asInstanceOf[MapStatus] val execId = status.location.executorId logDebug("ShuffleMapTask finished on " + execId) - if (stageIdToStage(task.stageId).latestInfo.attemptId == task.stageAttemptId) { + if (stageIdToStage(task.stageId).latestInfo.attemptNumber == task.stageAttemptId) { // This task was for the currently running attempt of the stage. Since the task // completed successfully from the perspective of the TaskSetManager, mark it as // no longer pending (the TaskSetManager may consider the task complete even @@ -1324,10 +1325,10 @@ class DAGScheduler( val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleIdToMapStage(shuffleId) - if (failedStage.latestInfo.attemptId != task.stageAttemptId) { + if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) { logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" + s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + - s"(attempt ID ${failedStage.latestInfo.attemptId}) running") + s"(attempt ${failedStage.latestInfo.attemptNumber}) running") } else { // It is likely that we receive multiple FetchFailed for a single stage (because we have // multiple tasks running concurrently on different executors). In that case, it is diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index c513ed36d1680..903e25b7986f2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -30,7 +30,7 @@ import org.apache.spark.storage.RDDInfo @DeveloperApi class StageInfo( val stageId: Int, - val attemptId: Int, + @deprecated("Use attemptNumber instead", "2.3.0") val attemptId: Int, val name: String, val numTasks: Int, val rddInfos: Seq[RDDInfo], @@ -56,6 +56,8 @@ class StageInfo( completionTime = Some(System.currentTimeMillis) } + def attemptNumber(): Int = attemptId + private[spark] def getStatusString: String = { if (completionTime.isDefined) { if (failureReason.isDefined) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala index 3c8cab7504c17..3c7af4f6146fa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala @@ -79,7 +79,7 @@ class StatsReportListener extends SparkListener with Logging { x => info.completionTime.getOrElse(System.currentTimeMillis()) - x ).getOrElse("-") - s"Stage(${info.stageId}, ${info.attemptId}); Name: '${info.name}'; " + + s"Stage(${info.stageId}, ${info.attemptNumber}); Name: '${info.name}'; " + s"Status: ${info.getStatusString}$failureReason; numTasks: ${info.numTasks}; " + s"Took: $timeTaken msec" } diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 487a782e865e8..88b75ddd5993a 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -529,7 +529,8 @@ private[spark] class AppStatusListener( } override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { - val maybeStage = Option(liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptId))) + val maybeStage = + Option(liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptNumber))) maybeStage.foreach { stage => val now = System.nanoTime() stage.info = event.stageInfo @@ -785,7 +786,7 @@ private[spark] class AppStatusListener( } private def getOrCreateStage(info: StageInfo): LiveStage = { - val stage = liveStages.computeIfAbsent((info.stageId, info.attemptId), + val stage = liveStages.computeIfAbsent((info.stageId, info.attemptNumber), new Function[(Int, Int), LiveStage]() { override def apply(key: (Int, Int)): LiveStage = new LiveStage() }) @@ -912,7 +913,7 @@ private[spark] class AppStatusListener( private def cleanupTasks(stage: LiveStage): Unit = { val countToDelete = calculateNumberToRemove(stage.savedTasks.get(), maxTasksPerStage).toInt if (countToDelete > 0) { - val stageKey = Array(stage.info.stageId, stage.info.attemptId) + val stageKey = Array(stage.info.stageId, stage.info.attemptNumber) val view = kvstore.view(classOf[TaskDataWrapper]).index("stage").first(stageKey) .last(stageKey) diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 52e83f250d34e..305c2fafa6aac 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -412,14 +412,14 @@ private class LiveStage extends LiveEntity { def executorSummary(executorId: String): LiveExecutorStageSummary = { executorSummaries.getOrElseUpdate(executorId, - new LiveExecutorStageSummary(info.stageId, info.attemptId, executorId)) + new LiveExecutorStageSummary(info.stageId, info.attemptNumber, executorId)) } def toApi(): v1.StageData = { new v1.StageData( status, info.stageId, - info.attemptId, + info.attemptNumber, info.numTasks, activeTasks, diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index 827a8637b9bd2..948858224d724 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -116,7 +116,7 @@ private[spark] object RDDOperationGraph extends Logging { // Use a special prefix here to differentiate this cluster from other operation clusters val stageClusterId = STAGE_CLUSTER_PREFIX + stage.stageId val stageClusterName = s"Stage ${stage.stageId}" + - { if (stage.attemptId == 0) "" else s" (attempt ${stage.attemptId})" } + { if (stage.attemptNumber == 0) "" else s" (attempt ${stage.attemptNumber})" } val rootCluster = new RDDOperationCluster(stageClusterId, stageClusterName) var rootNodeCount = 0 diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 5e60218c5740b..ff83301d631c4 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -263,7 +263,7 @@ private[spark] object JsonProtocol { val completionTime = stageInfo.completionTime.map(JInt(_)).getOrElse(JNothing) val failureReason = stageInfo.failureReason.map(JString(_)).getOrElse(JNothing) ("Stage ID" -> stageInfo.stageId) ~ - ("Stage Attempt ID" -> stageInfo.attemptId) ~ + ("Stage Attempt ID" -> stageInfo.attemptNumber) ~ ("Stage Name" -> stageInfo.name) ~ ("Number of Tasks" -> stageInfo.numTasks) ~ ("RDD Info" -> rddInfo) ~ diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 997c7de8dd02b..b8c84e24c2c3f 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -195,7 +195,9 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val s1Tasks = createTasks(4, execIds) s1Tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, + stages.head.attemptNumber, + task)) } assert(store.count(classOf[TaskDataWrapper]) === s1Tasks.size) @@ -213,10 +215,11 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { check[TaskDataWrapper](task.taskId) { wrapper => assert(wrapper.info.taskId === task.taskId) assert(wrapper.stageId === stages.head.stageId) - assert(wrapper.stageAttemptId === stages.head.attemptId) - assert(Arrays.equals(wrapper.stage, Array(stages.head.stageId, stages.head.attemptId))) + assert(wrapper.stageAttemptId === stages.head.attemptNumber) + assert(Arrays.equals(wrapper.stage, Array(stages.head.stageId, stages.head.attemptNumber))) - val runtime = Array[AnyRef](stages.head.stageId: JInteger, stages.head.attemptId: JInteger, + val runtime = Array[AnyRef](stages.head.stageId: JInteger, + stages.head.attemptNumber: JInteger, -1L: JLong) assert(Arrays.equals(wrapper.runtime, runtime)) @@ -237,7 +240,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { Some(1L), None, true, false, None) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate( task.executorId, - Seq((task.taskId, stages.head.stageId, stages.head.attemptId, Seq(accum))))) + Seq((task.taskId, stages.head.stageId, stages.head.attemptNumber, Seq(accum))))) } check[StageDataWrapper](key(stages.head)) { stage => @@ -254,12 +257,12 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Fail one of the tasks, re-start it. time += 1 s1Tasks.head.markFinished(TaskState.FAILED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber, "taskType", TaskResultLost, s1Tasks.head, null)) time += 1 val reattempt = newAttempt(s1Tasks.head, nextTaskId()) - listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptNumber, reattempt)) assert(store.count(classOf[TaskDataWrapper]) === s1Tasks.size + 1) @@ -289,7 +292,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val killed = s1Tasks.drop(1).head killed.finishTime = time killed.failed = true - listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber, "taskType", TaskKilled("killed"), killed, null)) check[JobDataWrapper](1) { job => @@ -311,13 +314,13 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 val denied = newAttempt(killed, nextTaskId()) val denyReason = TaskCommitDenied(1, 1, 1) - listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptNumber, denied)) time += 1 denied.finishTime = time denied.failed = true - listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber, "taskType", denyReason, denied, null)) check[JobDataWrapper](1) { job => @@ -337,7 +340,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Start a new attempt. val reattempt2 = newAttempt(denied, nextTaskId()) - listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptNumber, reattempt2)) // Succeed all tasks in stage 1. @@ -350,7 +353,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 pending.foreach { task => task.markFinished(TaskState.FINISHED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber, "taskType", Success, task, s1Metrics)) } @@ -414,13 +417,15 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 val s2Tasks = createTasks(4, execIds) s2Tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(stages.last.stageId, stages.last.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(stages.last.stageId, + stages.last.attemptNumber, + task)) } time += 1 s2Tasks.foreach { task => task.markFinished(TaskState.FAILED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stages.last.stageId, stages.last.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.last.stageId, stages.last.attemptNumber, "taskType", TaskResultLost, task, null)) } @@ -455,7 +460,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // - Re-submit stage 2, all tasks, and succeed them and the stage. val oldS2 = stages.last - val newS2 = new StageInfo(oldS2.stageId, oldS2.attemptId + 1, oldS2.name, oldS2.numTasks, + val newS2 = new StageInfo(oldS2.stageId, oldS2.attemptNumber + 1, oldS2.name, oldS2.numTasks, oldS2.rddInfos, oldS2.parentIds, oldS2.details, oldS2.taskMetrics) time += 1 @@ -466,14 +471,14 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val newS2Tasks = createTasks(4, execIds) newS2Tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(newS2.stageId, newS2.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(newS2.stageId, newS2.attemptNumber, task)) } time += 1 newS2Tasks.foreach { task => task.markFinished(TaskState.FINISHED, time) - listener.onTaskEnd(SparkListenerTaskEnd(newS2.stageId, newS2.attemptId, "taskType", Success, - task, null)) + listener.onTaskEnd(SparkListenerTaskEnd(newS2.stageId, newS2.attemptNumber, "taskType", + Success, task, null)) } time += 1 @@ -522,14 +527,15 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val j2s2Tasks = createTasks(4, execIds) j2s2Tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(j2Stages.last.stageId, j2Stages.last.attemptId, + listener.onTaskStart(SparkListenerTaskStart(j2Stages.last.stageId, + j2Stages.last.attemptNumber, task)) } time += 1 j2s2Tasks.foreach { task => task.markFinished(TaskState.FINISHED, time) - listener.onTaskEnd(SparkListenerTaskEnd(j2Stages.last.stageId, j2Stages.last.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(j2Stages.last.stageId, j2Stages.last.attemptNumber, "taskType", Success, task, null)) } @@ -919,13 +925,13 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 val tasks = createTasks(2, Array("1")) tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptNumber, task)) } assert(store.count(classOf[TaskDataWrapper]) === 2) // Start a 3rd task. The finished tasks should be deleted. createTasks(1, Array("1")).foreach { task => - listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptNumber, task)) } assert(store.count(classOf[TaskDataWrapper]) === 2) intercept[NoSuchElementException] { @@ -934,7 +940,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Start a 4th task. The first task should be deleted, even if it's still running. createTasks(1, Array("1")).foreach { task => - listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptNumber, task)) } assert(store.count(classOf[TaskDataWrapper]) === 2) intercept[NoSuchElementException] { @@ -960,7 +966,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } - private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptId) + private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptNumber) private def check[T: ClassTag](key: Any)(fn: T => Unit): Unit = { val value = store.read(classTag[T].runtimeClass, key).asInstanceOf[T] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index d8adbe7bee13e..73a105266e1c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -99,7 +99,7 @@ class SQLAppStatusListener( // Reset the metrics tracking object for the new attempt. Option(stageMetrics.get(event.stageInfo.stageId)).foreach { metrics => metrics.taskMetrics.clear() - metrics.attemptId = event.stageInfo.attemptId + metrics.attemptId = event.stageInfo.attemptNumber } } From eed82a0b211352215316ec70dc48aefc013ad0b2 Mon Sep 17 00:00:00 2001 From: foxish Date: Mon, 8 Jan 2018 13:01:45 -0800 Subject: [PATCH 0041/2461] [SPARK-22992][K8S] Remove assumption of the DNS domain ## What changes were proposed in this pull request? Remove the use of FQDN to access the driver because it assumes that it's set up in a DNS zone - `cluster.local` which is common but not ubiquitous Note that we already access the in-cluster API server through `kubernetes.default.svc`, so, by extension, this should work as well. The alternative is to introduce DNS zones for both of those addresses. ## How was this patch tested? Unit tests cc vanzin liyinan926 mridulm mccheah Author: foxish Closes #20187 from foxish/cluster.local. --- .../deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala | 2 +- .../k8s/submit/steps/DriverServiceBootstrapStepSuite.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala index eb594e4f16ec0..34af7cde6c1a9 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala @@ -83,7 +83,7 @@ private[spark] class DriverServiceBootstrapStep( .build() val namespace = sparkConf.get(KUBERNETES_NAMESPACE) - val driverHostname = s"${driverService.getMetadata.getName}.$namespace.svc.cluster.local" + val driverHostname = s"${driverService.getMetadata.getName}.$namespace.svc" val resolvedSparkConf = driverSpec.driverSparkConf.clone() .set(DRIVER_HOST_KEY, driverHostname) .set("spark.driver.port", driverPort.toString) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala index 006ce2668f8a0..78c8c3ba1afbd 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala @@ -85,7 +85,7 @@ class DriverServiceBootstrapStepSuite extends SparkFunSuite with BeforeAndAfter val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX - val expectedHostName = s"$expectedServiceName.my-namespace.svc.cluster.local" + val expectedHostName = s"$expectedServiceName.my-namespace.svc" verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName) } @@ -120,7 +120,7 @@ class DriverServiceBootstrapStepSuite extends SparkFunSuite with BeforeAndAfter val driverService = resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service] val expectedServiceName = s"spark-10000${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}" assert(driverService.getMetadata.getName === expectedServiceName) - val expectedHostName = s"$expectedServiceName.my-namespace.svc.cluster.local" + val expectedHostName = s"$expectedServiceName.my-namespace.svc" verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName) } From 4f7e75883436069c2d9028c4cd5daa78e8d59560 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 8 Jan 2018 13:24:08 -0800 Subject: [PATCH 0042/2461] [SPARK-22912] v2 data source support in MicroBatchExecution ## What changes were proposed in this pull request? Support for v2 data sources in microbatch streaming. ## How was this patch tested? A very basic new unit test on the toy v2 implementation of rate source. Once we have a v1 source fully migrated to v2, we'll need to do more detailed compatibility testing. Author: Jose Torres Closes #20097 from jose-torres/v2-impl. --- ...pache.spark.sql.sources.DataSourceRegister | 1 + .../datasources/v2/DataSourceV2Relation.scala | 10 ++ .../streaming/MicroBatchExecution.scala | 112 ++++++++++++++---- .../streaming/ProgressReporter.scala | 6 +- .../streaming/RateSourceProvider.scala | 10 +- .../execution/streaming/StreamExecution.scala | 4 +- .../streaming/StreamingRelation.scala | 4 +- .../continuous/ContinuousExecution.scala | 4 +- .../ContinuousRateStreamSource.scala | 17 +-- .../sources/RateStreamSourceV2.scala | 31 ++++- .../sql/streaming/DataStreamReader.scala | 25 +++- .../sql/streaming/StreamingQueryManager.scala | 24 ++-- .../streaming/RateSourceV2Suite.scala | 68 +++++++++-- .../spark/sql/streaming/StreamTest.scala | 2 +- .../continuous/ContinuousSuite.scala | 2 +- 15 files changed, 241 insertions(+), 79 deletions(-) diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 6cdfe2fae5642..0259c774bbf4a 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -7,3 +7,4 @@ org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider org.apache.spark.sql.execution.streaming.TextSocketSourceProvider org.apache.spark.sql.execution.streaming.RateSourceProvider +org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 7eb99a645001a..cba20dd902007 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -35,6 +35,16 @@ case class DataSourceV2Relation( } } +/** + * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical + * to the non-streaming relation. + */ +class StreamingDataSourceV2Relation( + fullOutput: Seq[AttributeReference], + reader: DataSourceV2Reader) extends DataSourceV2Relation(fullOutput, reader) { + override def isStreaming: Boolean = true +} + object DataSourceV2Relation { def apply(reader: DataSourceV2Reader): DataSourceV2Relation = { new DataSourceV2Relation(reader.readSchema().toAttributes, reader) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 9a7a13fcc5806..42240eeb58d4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.streaming +import java.util.Optional + +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} import org.apache.spark.sql.{Dataset, SparkSession} @@ -24,7 +27,10 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport +import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.streaming.{MicroBatchReadSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} @@ -33,10 +39,11 @@ class MicroBatchExecution( name: String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: Sink, + sink: BaseStreamingSink, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, + extraOptions: Map[String, String], deleteCheckpointOnStop: Boolean) extends StreamExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, @@ -57,6 +64,13 @@ class MicroBatchExecution( var nextSourceId = 0L val toExecutionRelationMap = MutableMap[StreamingRelation, StreamingExecutionRelation]() val v2ToExecutionRelationMap = MutableMap[StreamingRelationV2, StreamingExecutionRelation]() + // We transform each distinct streaming relation into a StreamingExecutionRelation, keeping a + // map as we go to ensure each identical relation gets the same StreamingExecutionRelation + // object. For each microbatch, the StreamingExecutionRelation will be replaced with a logical + // plan for the data within that batch. + // Note that we have to use the previous `output` as attributes in StreamingExecutionRelation, + // since the existing logical plan has already used those attributes. The per-microbatch + // transformation is responsible for replacing attributes with their final values. val _logicalPlan = analyzedPlan.transform { case streamingRelation@StreamingRelation(dataSource, _, output) => toExecutionRelationMap.getOrElseUpdate(streamingRelation, { @@ -64,19 +78,26 @@ class MicroBatchExecution( val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" val source = dataSource.createSource(metadataPath) nextSourceId += 1 - // We still need to use the previous `output` instead of `source.schema` as attributes in - // "df.logicalPlan" has already used attributes of the previous `output`. StreamingExecutionRelation(source, output)(sparkSession) }) - case s @ StreamingRelationV2(v2DataSource, _, _, output, v1DataSource) - if !v2DataSource.isInstanceOf[MicroBatchReadSupport] => + case s @ StreamingRelationV2(source: MicroBatchReadSupport, _, options, output, _) => + v2ToExecutionRelationMap.getOrElseUpdate(s, { + // Materialize source to avoid creating it in every batch + val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" + val reader = source.createMicroBatchReader( + Optional.empty(), // user specified schema + metadataPath, + new DataSourceV2Options(options.asJava)) + nextSourceId += 1 + StreamingExecutionRelation(reader, output)(sparkSession) + }) + case s @ StreamingRelationV2(_, _, _, output, v1Relation) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - val source = v1DataSource.createSource(metadataPath) + assert(v1Relation.isDefined, "v2 execution didn't match but v1 was unavailable") + val source = v1Relation.get.dataSource.createSource(metadataPath) nextSourceId += 1 - // We still need to use the previous `output` instead of `source.schema` as attributes in - // "df.logicalPlan" has already used attributes of the previous `output`. StreamingExecutionRelation(source, output)(sparkSession) }) } @@ -192,7 +213,8 @@ class MicroBatchExecution( source.getBatch(start, end) } case nonV1Tuple => - throw new IllegalStateException(s"Unexpected V2 source in $nonV1Tuple") + // The V2 API does not have the same edge case requiring getBatch to be called + // here, so we do nothing here. } currentBatchId = latestCommittedBatchId + 1 committedOffsets ++= availableOffsets @@ -236,14 +258,27 @@ class MicroBatchExecution( val hasNewData = { awaitProgressLock.lock() try { - val latestOffsets: Map[Source, Option[Offset]] = uniqueSources.map { + // Generate a map from each unique source to the next available offset. + val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map { case s: Source => updateStatusMessage(s"Getting offsets from $s") reportTimeTaken("getOffset") { (s, s.getOffset) } + case s: MicroBatchReader => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("getOffset") { + // Once v1 streaming source execution is gone, we can refactor this away. + // For now, we set the range here to get the source to infer the available end offset, + // get that offset, and then set the range again when we later execute. + s.setOffsetRange( + toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), + Optional.empty()) + + (s, Some(s.getEndOffset)) + } }.toMap - availableOffsets ++= latestOffsets.filter { case (s, o) => o.nonEmpty }.mapValues(_.get) + availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) if (dataAvailable) { true @@ -317,6 +352,8 @@ class MicroBatchExecution( if (prevBatchOff.isDefined) { prevBatchOff.get.toStreamProgress(sources).foreach { case (src: Source, off) => src.commit(off) + case (reader: MicroBatchReader, off) => + reader.commit(reader.deserializeOffset(off.json)) } } else { throw new IllegalStateException(s"batch $currentBatchId doesn't exist") @@ -357,7 +394,16 @@ class MicroBatchExecution( s"DataFrame returned by getBatch from $source did not have isStreaming=true\n" + s"${batch.queryExecution.logical}") logDebug(s"Retrieving data from $source: $current -> $available") - Some(source -> batch) + Some(source -> batch.logicalPlan) + case (reader: MicroBatchReader, available) + if committedOffsets.get(reader).map(_ != available).getOrElse(true) => + val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json)) + reader.setOffsetRange( + toJava(current), + Optional.of(available.asInstanceOf[OffsetV2])) + logDebug(s"Retrieving data from $reader: $current -> $available") + Some(reader -> + new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader)) case _ => None } } @@ -365,15 +411,14 @@ class MicroBatchExecution( // A list of attributes that will need to be updated. val replacements = new ArrayBuffer[(Attribute, Attribute)] // Replace sources in the logical plan with data that has arrived since the last batch. - val withNewSources = logicalPlan transform { + val newBatchesPlan = logicalPlan transform { case StreamingExecutionRelation(source, output) => - newData.get(source).map { data => - val newPlan = data.logicalPlan - assert(output.size == newPlan.output.size, + newData.get(source).map { dataPlan => + assert(output.size == dataPlan.output.size, s"Invalid batch: ${Utils.truncatedString(output, ",")} != " + - s"${Utils.truncatedString(newPlan.output, ",")}") - replacements ++= output.zip(newPlan.output) - newPlan + s"${Utils.truncatedString(dataPlan.output, ",")}") + replacements ++= output.zip(dataPlan.output) + dataPlan }.getOrElse { LocalRelation(output, isStreaming = true) } @@ -381,7 +426,7 @@ class MicroBatchExecution( // Rewire the plan to use the new attributes that were returned by the source. val replacementMap = AttributeMap(replacements) - val triggerLogicalPlan = withNewSources transformAllExpressions { + val newAttributePlan = newBatchesPlan transformAllExpressions { case a: Attribute if replacementMap.contains(a) => replacementMap(a).withMetadata(a.metadata) case ct: CurrentTimestamp => @@ -392,6 +437,20 @@ class MicroBatchExecution( cd.dataType, cd.timeZoneId) } + val triggerLogicalPlan = sink match { + case _: Sink => newAttributePlan + case s: MicroBatchWriteSupport => + val writer = s.createMicroBatchWriter( + s"$runId", + currentBatchId, + newAttributePlan.schema, + outputMode, + new DataSourceV2Options(extraOptions.asJava)) + assert(writer.isPresent, "microbatch writer must always be present") + WriteToDataSourceV2(writer.get, newAttributePlan) + case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") + } + reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( sparkSessionToRunBatch, @@ -409,7 +468,12 @@ class MicroBatchExecution( reportTimeTaken("addBatch") { SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { - sink.addBatch(currentBatchId, nextBatch) + sink match { + case s: Sink => s.addBatch(currentBatchId, nextBatch) + case s: MicroBatchWriteSupport => + // This doesn't accumulate any data - it just forces execution of the microbatch writer. + nextBatch.collect() + } } } @@ -421,4 +485,8 @@ class MicroBatchExecution( awaitProgressLock.unlock() } } + + private def toJava(scalaOption: Option[OffsetV2]): Optional[OffsetV2] = { + Optional.ofNullable(scalaOption.orNull) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 1c9043613cb69..d1e5be9c12762 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -53,7 +53,7 @@ trait ProgressReporter extends Logging { protected def triggerClock: Clock protected def logicalPlan: LogicalPlan protected def lastExecution: QueryExecution - protected def newData: Map[BaseStreamingSource, DataFrame] + protected def newData: Map[BaseStreamingSource, LogicalPlan] protected def availableOffsets: StreamProgress protected def committedOffsets: StreamProgress protected def sources: Seq[BaseStreamingSource] @@ -225,8 +225,8 @@ trait ProgressReporter extends Logging { // // 3. For each source, we sum the metrics of the associated execution plan leaves. // - val logicalPlanLeafToSource = newData.flatMap { case (source, df) => - df.logicalPlan.collectLeaves().map { leaf => leaf -> source } + val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) => + logicalPlan.collectLeaves().map { leaf => leaf -> source } } val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index d02cf882b61ac..66eb0169ac1ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -29,12 +29,12 @@ import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.execution.streaming.continuous.ContinuousRateStreamReader -import org.apache.spark.sql.execution.streaming.sources.RateStreamV2Reader +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.execution.streaming.sources.RateStreamMicroBatchReader import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport -import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader} import org.apache.spark.sql.types._ import org.apache.spark.util.{ManualClock, SystemClock} @@ -112,7 +112,7 @@ class RateSourceProvider extends StreamSourceProvider with DataSourceRegister schema: Optional[StructType], checkpointLocation: String, options: DataSourceV2Options): ContinuousReader = { - new ContinuousRateStreamReader(options) + new RateStreamContinuousReader(options) } override def shortName(): String = "rate" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 3e76bf7b7ca8f..24a8b000df0c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -163,7 +163,7 @@ abstract class StreamExecution( var lastExecution: IncrementalExecution = _ /** Holds the most recent input data for each source. */ - protected var newData: Map[BaseStreamingSource, DataFrame] = _ + protected var newData: Map[BaseStreamingSource, LogicalPlan] = _ @volatile protected var streamDeathCause: StreamingQueryException = null @@ -418,7 +418,7 @@ abstract class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. */ - private[sql] def awaitOffset(source: Source, newOffset: Offset): Unit = { + private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = { assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index a9d50e3a112e7..a0ee683a895d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -61,7 +61,7 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. */ case class StreamingExecutionRelation( - source: Source, + source: BaseStreamingSource, output: Seq[Attribute])(session: SparkSession) extends LeafNode { @@ -92,7 +92,7 @@ case class StreamingRelationV2( sourceName: String, extraOptions: Map[String, String], output: Seq[Attribute], - v1DataSource: DataSource)(session: SparkSession) + v1Relation: Option[StreamingRelation])(session: SparkSession) extends LeafNode { override def isStreaming: Boolean = true override def toString: String = sourceName diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 2843ab13bde2b..9657b5e26d770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.DataSourceV2Options import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} @@ -174,7 +174,7 @@ class ContinuousExecution( val loggedOffset = offsets.offsets(0) val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) reader.setOffset(java.util.Optional.ofNullable(realOffset.orNull)) - DataSourceV2Relation(newOutput, reader) + new StreamingDataSourceV2Relation(newOutput, reader) } // Rewire the plan to use the new attributes that were returned by the source. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index c9aa78a5a2e28..b4b21e7d2052f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -32,10 +32,10 @@ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} -case class ContinuousRateStreamPartitionOffset( +case class RateStreamPartitionOffset( partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset -class ContinuousRateStreamReader(options: DataSourceV2Options) +class RateStreamContinuousReader(options: DataSourceV2Options) extends ContinuousReader { implicit val defaultFormats: DefaultFormats = DefaultFormats @@ -48,7 +48,7 @@ class ContinuousRateStreamReader(options: DataSourceV2Options) override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { assert(offsets.length == numPartitions) val tuples = offsets.map { - case ContinuousRateStreamPartitionOffset(i, currVal, nextRead) => + case RateStreamPartitionOffset(i, currVal, nextRead) => (i, ValueRunTimeMsPair(currVal, nextRead)) } RateStreamOffset(Map(tuples: _*)) @@ -86,7 +86,7 @@ class ContinuousRateStreamReader(options: DataSourceV2Options) val start = partitionStartMap(i) // Have each partition advance by numPartitions each row, with starting points staggered // by their partition index. - RateStreamReadTask( + RateStreamContinuousReadTask( start.value, start.runTimeMs, i, @@ -101,7 +101,7 @@ class ContinuousRateStreamReader(options: DataSourceV2Options) } -case class RateStreamReadTask( +case class RateStreamContinuousReadTask( startValue: Long, startTimeMs: Long, partitionIndex: Int, @@ -109,10 +109,11 @@ case class RateStreamReadTask( rowsPerSecond: Double) extends ReadTask[Row] { override def createDataReader(): DataReader[Row] = - new RateStreamDataReader(startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) + new RateStreamContinuousDataReader( + startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) } -class RateStreamDataReader( +class RateStreamContinuousDataReader( startValue: Long, startTimeMs: Long, partitionIndex: Int, @@ -151,5 +152,5 @@ class RateStreamDataReader( override def close(): Unit = {} override def getOffset(): PartitionOffset = - ContinuousRateStreamPartitionOffset(partitionIndex, currentValue, nextReadTime) + RateStreamPartitionOffset(partitionIndex, currentValue, nextReadTime) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index 97bada08bcd2b..c0ed12cec25ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -28,17 +28,38 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} -import org.apache.spark.util.SystemClock +import org.apache.spark.util.{ManualClock, SystemClock} -class RateStreamV2Reader(options: DataSourceV2Options) +/** + * This is a temporary register as we build out v2 migration. Microbatch read support should + * be implemented in the same register as v1. + */ +class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with DataSourceRegister { + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceV2Options): MicroBatchReader = { + new RateStreamMicroBatchReader(options) + } + + override def shortName(): String = "ratev2" +} + +class RateStreamMicroBatchReader(options: DataSourceV2Options) extends MicroBatchReader { implicit val defaultFormats: DefaultFormats = DefaultFormats - val clock = new SystemClock + val clock = { + // The option to use a manual clock is provided only for unit testing purposes. + if (options.get("useManualClock").orElse("false").toBoolean) new ManualClock + else new SystemClock + } private val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt @@ -111,7 +132,7 @@ class RateStreamV2Reader(options: DataSourceV2Options) val packedRows = mutable.ListBuffer[(Long, Long)]() var outVal = startVal + numPartitions - var outTimeMs = startTimeMs + msPerPartitionBetweenRows + var outTimeMs = startTimeMs while (outVal <= endVal) { packedRows.append((outTimeMs, outVal)) outVal += numPartitions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 2e92beecf2c17..52f2e2639cd86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import java.util.Locale +import java.util.{Locale, Optional} import scala.collection.JavaConverters._ @@ -27,8 +27,9 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} +import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -166,19 +167,31 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo userSpecifiedSchema = userSpecifiedSchema, className = source, options = extraOptions.toMap) + val v1Relation = ds match { + case _: StreamSourceProvider => Some(StreamingRelation(v1DataSource)) + case _ => None + } ds match { + case s: MicroBatchReadSupport => + val tempReader = s.createMicroBatchReader( + Optional.ofNullable(userSpecifiedSchema.orNull), + Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, + options) + Dataset.ofRows( + sparkSession, + StreamingRelationV2( + s, source, extraOptions.toMap, + tempReader.readSchema().toAttributes, v1Relation)(sparkSession)) case s: ContinuousReadSupport => val tempReader = s.createContinuousReader( - java.util.Optional.ofNullable(userSpecifiedSchema.orNull), + Optional.ofNullable(userSpecifiedSchema.orNull), Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, options) - // Generate the V1 node to catch errors thrown within generation. - StreamingRelation(v1DataSource) Dataset.ofRows( sparkSession, StreamingRelationV2( s, source, extraOptions.toMap, - tempReader.readSchema().toAttributes, v1DataSource)(sparkSession)) + tempReader.readSchema().toAttributes, v1Relation)(sparkSession)) case _ => // Code path for data source v1. Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index b508f4406138f..4b27e0d4ef47b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -29,10 +29,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger} import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport +import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -240,31 +240,35 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo "is not supported in streaming DataFrames/Datasets and will be disabled.") } - sink match { - case v1Sink: Sink => - new StreamingQueryWrapper(new MicroBatchExecution( + (sink, trigger) match { + case (v2Sink: ContinuousWriteSupport, trigger: ContinuousTrigger) => + UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) + new StreamingQueryWrapper(new ContinuousExecution( sparkSession, userSpecifiedName.orNull, checkpointLocation, analyzedPlan, - v1Sink, + v2Sink, trigger, triggerClock, outputMode, + extraOptions, deleteCheckpointOnStop)) - case v2Sink: ContinuousWriteSupport => - UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) - new StreamingQueryWrapper(new ContinuousExecution( + case (_: MicroBatchWriteSupport, _) | (_: Sink, _) => + new StreamingQueryWrapper(new MicroBatchExecution( sparkSession, userSpecifiedName.orNull, checkpointLocation, analyzedPlan, - v2Sink, + sink, trigger, triggerClock, outputMode, extraOptions, deleteCheckpointOnStop)) + case (_: ContinuousWriteSupport, t) if !t.isInstanceOf[ContinuousTrigger] => + throw new AnalysisException( + "Sink only supports continuous writes, but a continuous trigger was not specified.") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index e11705a227f48..85085d43061bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -18,20 +18,64 @@ package org.apache.spark.sql.execution.streaming import java.util.Optional +import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import org.apache.spark.sql.Row import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamSourceV2, RateStreamV2Reader} +import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.util.ManualClock class RateSourceV2Suite extends StreamTest { + import testImplicits._ + + case class AdvanceRateManualClock(seconds: Long) extends AddData { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + assert(query.nonEmpty) + val rateSource = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source + }.head + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) + rateSource.setOffsetRange(Optional.empty(), Optional.empty()) + (rateSource, rateSource.getEndOffset()) + } + } + + test("microbatch in registry") { + DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceV2Options.empty()) + assert(reader.isInstanceOf[RateStreamMicroBatchReader]) + case _ => + throw new IllegalStateException("Could not find v2 read support for rate") + } + } + + test("basic microbatch execution") { + val input = spark.readStream + .format("rateV2") + .option("numPartitions", "1") + .option("rowsPerSecond", "10") + .option("useManualClock", "true") + .load() + testStream(input, useV2Sink = true)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + StopStream, + StartStream(), + // Advance 2 seconds because creating a new RateSource will also create a new ManualClock + AdvanceRateManualClock(seconds = 2), + CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + test("microbatch - numPartitions propagated") { - val reader = new RateStreamV2Reader( + val reader = new RateStreamMicroBatchReader( new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) reader.setOffsetRange(Optional.empty(), Optional.empty()) val tasks = reader.createReadTasks() @@ -39,7 +83,7 @@ class RateSourceV2Suite extends StreamTest { } test("microbatch - set offset") { - val reader = new RateStreamV2Reader(DataSourceV2Options.empty()) + val reader = new RateStreamMicroBatchReader(DataSourceV2Options.empty()) val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000)))) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) @@ -48,7 +92,7 @@ class RateSourceV2Suite extends StreamTest { } test("microbatch - infer offsets") { - val reader = new RateStreamV2Reader( + val reader = new RateStreamMicroBatchReader( new DataSourceV2Options(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100) reader.setOffsetRange(Optional.empty(), Optional.empty()) @@ -69,7 +113,7 @@ class RateSourceV2Suite extends StreamTest { } test("microbatch - predetermined batch size") { - val reader = new RateStreamV2Reader( + val reader = new RateStreamMicroBatchReader( new DataSourceV2Options(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) @@ -80,7 +124,7 @@ class RateSourceV2Suite extends StreamTest { } test("microbatch - data read") { - val reader = new RateStreamV2Reader( + val reader = new RateStreamMicroBatchReader( new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { @@ -107,14 +151,14 @@ class RateSourceV2Suite extends StreamTest { DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { case ds: ContinuousReadSupport => val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceV2Options.empty()) - assert(reader.isInstanceOf[ContinuousRateStreamReader]) + assert(reader.isInstanceOf[RateStreamContinuousReader]) case _ => throw new IllegalStateException("Could not find v2 read support for rate") } } test("continuous data") { - val reader = new ContinuousRateStreamReader( + val reader = new RateStreamContinuousReader( new DataSourceV2Options(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) reader.setOffset(Optional.empty()) val tasks = reader.createReadTasks() @@ -122,17 +166,17 @@ class RateSourceV2Suite extends StreamTest { val data = scala.collection.mutable.ListBuffer[Row]() tasks.asScala.foreach { - case t: RateStreamReadTask => + case t: RateStreamContinuousReadTask => val startTimeMs = reader.getStartOffset() .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamDataReader] + val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] for (rowIndex <- 0 to 9) { r.next() data.append(r.get()) assert(r.getOffset() == - ContinuousRateStreamPartitionOffset( + RateStreamPartitionOffset( t.partitionIndex, t.partitionIndex + rowIndex * 2, startTimeMs + (rowIndex + 1) * 100)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 4b7f0fbe97d4e..d46461fa9bf6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -105,7 +105,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be * the active query, and then return the source object the data was added, as well as the * offset of added data. */ - def addData(query: Option[StreamExecution]): (Source, Offset) + def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) } /** A trait that can be extended when testing a source. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index eda0d8ad48313..9562c10feafe9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -61,7 +61,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, r: ContinuousRateStreamReader) => r + case DataSourceV2ScanExec(_, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 From 68ce792b5857f0291154f524ac651036db868bb9 Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Tue, 9 Jan 2018 10:15:01 +0800 Subject: [PATCH 0043/2461] [SPARK-22972] Couldn't find corresponding Hive SerDe for data source provider org.apache.spark.sql.hive.orc ## What changes were proposed in this pull request? Fix the warning: Couldn't find corresponding Hive SerDe for data source provider org.apache.spark.sql.hive.orc. ## How was this patch tested? test("SPARK-22972: hive orc source") assert(HiveSerDe.sourceToSerDe("org.apache.spark.sql.hive.orc") .equals(HiveSerDe.sourceToSerDe("orc"))) Author: xubo245 <601450868@qq.com> Closes #20165 from xubo245/HiveSerDe. --- .../apache/spark/sql/internal/HiveSerDe.scala | 1 + .../sql/hive/orc/HiveOrcSourceSuite.scala | 29 +++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala index b9515ec7bca2a..dac463641cfab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala @@ -73,6 +73,7 @@ object HiveSerDe { val key = source.toLowerCase(Locale.ROOT) match { case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet" case s if s.startsWith("org.apache.spark.sql.orc") => "orc" + case s if s.startsWith("org.apache.spark.sql.hive.orc") => "orc" case s if s.equals("orcfile") => "orc" case s if s.equals("parquetfile") => "parquet" case s if s.equals("avrofile") => "avro" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala index 17b7d8cfe127e..d556a030e2186 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.hive.orc import java.io.File import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.orc.OrcSuite import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.util.Utils class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { @@ -62,6 +64,33 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { """.stripMargin) } + test("SPARK-22972: hive orc source") { + val tableName = "normal_orc_as_source_hive" + withTable(tableName) { + sql( + s""" + |CREATE TABLE $tableName + |USING org.apache.spark.sql.hive.orc + |OPTIONS ( + | PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}' + |) + """.stripMargin) + + val tableMetadata = spark.sessionState.catalog.getTableMetadata( + TableIdentifier(tableName)) + assert(tableMetadata.storage.inputFormat == + Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) + assert(tableMetadata.storage.outputFormat == + Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + assert(tableMetadata.storage.serde == + Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + assert(HiveSerDe.sourceToSerDe("org.apache.spark.sql.hive.orc") + .equals(HiveSerDe.sourceToSerDe("orc"))) + assert(HiveSerDe.sourceToSerDe("org.apache.spark.sql.orc") + .equals(HiveSerDe.sourceToSerDe("orc"))) + } + } + test("SPARK-19459/SPARK-18220: read char/varchar column written by Hive") { val location = Utils.createTempDir() val uri = location.toURI From 849043ce1d28a976659278d29368da0799329db8 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Tue, 9 Jan 2018 10:44:21 +0800 Subject: [PATCH 0044/2461] [SPARK-22990][CORE] Fix method isFairScheduler in JobsTab and StagesTab ## What changes were proposed in this pull request? In current implementation, the function `isFairScheduler` is always false, since it is comparing String with `SchedulingMode` Author: Wang Gengliang Closes #20186 from gengliangwang/isFairScheduler. --- .../src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala | 8 ++++---- .../main/scala/org/apache/spark/ui/jobs/StagesTab.scala | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index 99eab1b2a27d8..ff1b75e5c5065 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -34,10 +34,10 @@ private[ui] class JobsTab(parent: SparkUI, store: AppStatusStore) val killEnabled = parent.killEnabled def isFairScheduler: Boolean = { - store.environmentInfo().sparkProperties.toMap - .get("spark.scheduler.mode") - .map { mode => mode == SchedulingMode.FAIR } - .getOrElse(false) + store + .environmentInfo() + .sparkProperties + .contains(("spark.scheduler.mode", SchedulingMode.FAIR.toString)) } def getSparkUser: String = parent.getSparkUser diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index be05a963f0e68..10b032084ce4f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -37,10 +37,10 @@ private[ui] class StagesTab(val parent: SparkUI, val store: AppStatusStore) attachPage(new PoolPage(this)) def isFairScheduler: Boolean = { - store.environmentInfo().sparkProperties.toMap - .get("spark.scheduler.mode") - .map { mode => mode == SchedulingMode.FAIR } - .getOrElse(false) + store + .environmentInfo() + .sparkProperties + .contains(("spark.scheduler.mode", SchedulingMode.FAIR.toString)) } def handleKillRequest(request: HttpServletRequest): Unit = { From f20131dd35939734fe16b0005a086aa72400893b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 9 Jan 2018 11:49:10 +0800 Subject: [PATCH 0045/2461] [SPARK-22984] Fix incorrect bitmap copying and offset adjustment in GenerateUnsafeRowJoiner ## What changes were proposed in this pull request? This PR fixes a longstanding correctness bug in `GenerateUnsafeRowJoiner`. This class was introduced in https://github.com/apache/spark/pull/7821 (July 2015 / Spark 1.5.0+) and is used to combine pairs of UnsafeRows in TungstenAggregationIterator, CartesianProductExec, and AppendColumns. ### Bugs fixed by this patch 1. **Incorrect combining of null-tracking bitmaps**: when concatenating two UnsafeRows, the implementation "Concatenate the two bitsets together into a single one, taking padding into account". If one row has no columns then it has a bitset size of 0, but the code was incorrectly assuming that if the left row had a non-zero number of fields then the right row would also have at least one field, so it was copying invalid bytes and and treating them as part of the bitset. I'm not sure whether this bug was also present in the original implementation or whether it was introduced in https://github.com/apache/spark/pull/7892 (which fixed another bug in this code). 2. **Incorrect updating of data offsets for null variable-length fields**: after updating the bitsets and copying fixed-length and variable-length data, we need to perform adjustments to the offsets pointing the start of variable length fields's data. The existing code was _conditionally_ adding a fixed offset to correct for the new length of the combined row, but it is unsafe to do this if the variable-length field has a null value: we always represent nulls by storing `0` in the fixed-length slot, but this code was incorrectly incrementing those values. This bug was present since the original version of `GenerateUnsafeRowJoiner`. ### Why this bug remained latent for so long The PR which introduced `GenerateUnsafeRowJoiner` features several randomized tests, including tests of the cases where one side of the join has no fields and where string-valued fields are null. However, the existing assertions were too weak to uncover this bug: - If a null field has a non-zero value in its fixed-length data slot then this will not cause problems for field accesses because the null-tracking bitmap should still be correct and we will not try to use the incorrect offset for anything. - If the null tracking bitmap is corrupted by joining against a row with no fields then the corruption occurs in field numbers past the actual field numbers contained in the row. Thus valid `isNullAt()` calls will not read the incorrectly-set bits. The existing `GenerateUnsafeRowJoinerSuite` tests only exercised `.get()` and `isNullAt()`, but didn't actually check the UnsafeRows for bit-for-bit equality, preventing these bugs from failing assertions. It turns out that there was even a [GenerateUnsafeRowJoinerBitsetSuite](https://github.com/apache/spark/blob/03377d2522776267a07b7d6ae9bddf79a4e0f516/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala) but it looks like it also didn't catch this problem because it only tested the bitsets in an end-to-end fashion by accessing them through the `UnsafeRow` interface instead of actually comparing the bitsets' bytes. ### Impact of these bugs - This bug will cause `equals()` and `hashCode()` to be incorrect for these rows, which will be problematic in case`GenerateUnsafeRowJoiner`'s results are used as join or grouping keys. - Chained / repeated invocations of `GenerateUnsafeRowJoiner` may result in reads from invalid null bitmap positions causing fields to incorrectly become NULL (see the end-to-end example below). - It looks like this generally only happens in `CartesianProductExec`, which our query optimizer often avoids executing (usually we try to plan a `BroadcastNestedLoopJoin` instead). ### End-to-end test case demonstrating the problem The following query demonstrates how this bug may result in incorrect query results: ```sql set spark.sql.autoBroadcastJoinThreshold=-1; -- Needed to trigger CartesianProductExec create table a as select * from values 1; create table b as select * from values 2; SELECT t3.col1, t1.col1 FROM a t1 CROSS JOIN b t2 CROSS JOIN b t3 ``` This should return `(2, 1)` but instead was returning `(null, 1)`. Column pruning ends up trimming off all columns from `t2`, so when `t2` joins with another table this triggers the bitmap-copying bug. This incorrect bitmap is subsequently copied again when performing the final join, causing the final output to have an incorrectly-set null bit for the first field. ## How was this patch tested? Strengthened the assertions in existing tests in GenerateUnsafeRowJoinerSuite. Also verified that the end-to-end test case which uncovered this now passes. Author: Josh Rosen Closes #20181 from JoshRosen/SPARK-22984-fix-generate-unsaferow-joiner-bitmap-bugs. --- .../codegen/GenerateUnsafeRowJoiner.scala | 52 +++++++++- .../GenerateUnsafeRowJoinerSuite.scala | 95 ++++++++++++++++++- 2 files changed, 138 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index be5f5a73b5d47..febf7b0c96c2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -70,7 +70,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U // --------------------- copy bitset from row 1 and row 2 --------------------------- // val copyBitset = Seq.tabulate(outputBitsetWords) { i => - val bits = if (bitset1Remainder > 0) { + val bits = if (bitset1Remainder > 0 && bitset2Words != 0) { if (i < bitset1Words - 1) { s"$getLong(obj1, offset1 + ${i * 8})" } else if (i == bitset1Words - 1) { @@ -152,7 +152,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U } else { // Number of bytes to increase for the offset. Note that since in UnsafeRow we store the // offset in the upper 32 bit of the words, we can just shift the offset to the left by - // 32 and increment that amount in place. + // 32 and increment that amount in place. However, we need to handle the important special + // case of a null field, in which case the offset should be zero and should not have a + // shift added to it. val shift = if (i < schema1.size) { s"${(outputBitsetWords - bitset1Words + schema2.size) * 8}L" @@ -160,14 +162,55 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)" } val cursor = offset + outputBitsetWords * 8 + i * 8 - s"$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32));\n" + // UnsafeRow is a little underspecified, so in what follows we'll treat UnsafeRowWriter's + // output as a de-facto specification for the internal layout of data. + // + // Null-valued fields will always have a data offset of 0 because + // UnsafeRowWriter.setNullAt(ordinal) sets the null bit and stores 0 to in field's + // position in the fixed-length section of the row. As a result, we must NOT add + // `shift` to the offset for null fields. + // + // We could perform a null-check here by inspecting the null-tracking bitmap, but doing + // so could be expensive and will add significant bloat to the generated code. Instead, + // we'll rely on the invariant "stored offset == 0 for variable-length data type implies + // that the field's value is null." + // + // To establish that this invariant holds, we'll prove that a non-null field can never + // have a stored offset of 0. There are two cases to consider: + // + // 1. The non-null field's data is of non-zero length: reading this field's value + // must read data from the variable-length section of the row, so the stored offset + // will actually be used in address calculation and must be correct. The offsets + // count bytes from the start of the UnsafeRow so these offsets will always be + // non-zero because the storage of the offsets themselves takes up space at the + // start of the row. + // 2. The non-null field's data is of zero length (i.e. its data is empty). In this + // case, we have to worry about the possibility that an arbitrary offset value was + // stored because we never actually read any bytes using this offset and therefore + // would not crash if it was incorrect. The variable-sized data writing paths in + // UnsafeRowWriter unconditionally calls setOffsetAndSize(ordinal, numBytes) with + // no special handling for the case where `numBytes == 0`. Internally, + // setOffsetAndSize computes the offset without taking the size into account. Thus + // the stored offset is the same non-zero offset that would be used if the field's + // dataSize was non-zero (and in (1) above we've shown that case behaves as we + // expect). + // + // Thus it is safe to perform `existingOffset != 0` checks here in the place of + // more expensive null-bit checks. + s""" + |existingOffset = $getLong(buf, $cursor); + |if (existingOffset != 0) { + | $putLong(buf, $cursor, existingOffset + ($shift << 32)); + |} + """.stripMargin } } val updateOffsets = ctx.splitExpressions( expressions = updateOffset, funcName = "copyBitsetFunc", - arguments = ("long", "numBytesVariableRow1") :: Nil) + arguments = ("long", "numBytesVariableRow1") :: Nil, + makeSplitFunction = (s: String) => "long existingOffset;\n" + s) // ------------------------ Finally, put everything together --------------------------- // val codeBody = s""" @@ -200,6 +243,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | $copyFixedLengthRow2 | $copyVariableLengthRow1 | $copyVariableLengthRow2 + | long existingOffset; | $updateOffsets | | out.pointTo(buf, sizeInBytes); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala index f203f25ad10d4..75c6beeb32150 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala @@ -22,8 +22,10 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Test suite for [[GenerateUnsafeRowJoiner]]. @@ -45,6 +47,32 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite { testConcat(64, 64, fixed) } + test("rows with all empty strings") { + val schema = StructType(Seq( + StructField("f1", StringType), StructField("f2", StringType))) + val row: UnsafeRow = UnsafeProjection.create(schema).apply( + InternalRow(UTF8String.EMPTY_UTF8, UTF8String.EMPTY_UTF8)) + testConcat(schema, row, schema, row) + } + + test("rows with all empty int arrays") { + val schema = StructType(Seq( + StructField("f1", ArrayType(IntegerType)), StructField("f2", ArrayType(IntegerType)))) + val emptyIntArray = + ExpressionEncoder[Array[Int]]().resolveAndBind().toRow(Array.emptyIntArray).getArray(0) + val row: UnsafeRow = UnsafeProjection.create(schema).apply( + InternalRow(emptyIntArray, emptyIntArray)) + testConcat(schema, row, schema, row) + } + + test("alternating empty and non-empty strings") { + val schema = StructType(Seq( + StructField("f1", StringType), StructField("f2", StringType))) + val row: UnsafeRow = UnsafeProjection.create(schema).apply( + InternalRow(UTF8String.EMPTY_UTF8, UTF8String.fromString("foo"))) + testConcat(schema, row, schema, row) + } + test("randomized fix width types") { for (i <- 0 until 20) { testConcatOnce(Random.nextInt(100), Random.nextInt(100), fixed) @@ -94,27 +122,84 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite { val extRow2 = RandomDataGenerator.forType(schema2, nullable = false).get.apply() val row1 = converter1.apply(internalConverter1.apply(extRow1).asInstanceOf[InternalRow]) val row2 = converter2.apply(internalConverter2.apply(extRow2).asInstanceOf[InternalRow]) + testConcat(schema1, row1, schema2, row2) + } + + private def testConcat( + schema1: StructType, + row1: UnsafeRow, + schema2: StructType, + row2: UnsafeRow) { // Run the joiner. val mergedSchema = StructType(schema1 ++ schema2) val concater = GenerateUnsafeRowJoiner.create(schema1, schema2) - val output = concater.join(row1, row2) + val output: UnsafeRow = concater.join(row1, row2) + + // We'll also compare to an UnsafeRow produced with JoinedRow + UnsafeProjection. This ensures + // that unused space in the row (e.g. leftover bits in the null-tracking bitmap) is written + // correctly. + val expectedOutput: UnsafeRow = { + val joinedRowProjection = UnsafeProjection.create(mergedSchema) + val joined = new JoinedRow() + joinedRowProjection.apply(joined.apply(row1, row2)) + } // Test everything equals ... for (i <- mergedSchema.indices) { + val dataType = mergedSchema(i).dataType if (i < schema1.size) { assert(output.isNullAt(i) === row1.isNullAt(i)) if (!output.isNullAt(i)) { - assert(output.get(i, mergedSchema(i).dataType) === row1.get(i, mergedSchema(i).dataType)) + assert(output.get(i, dataType) === row1.get(i, dataType)) + assert(output.get(i, dataType) === expectedOutput.get(i, dataType)) } } else { assert(output.isNullAt(i) === row2.isNullAt(i - schema1.size)) if (!output.isNullAt(i)) { - assert(output.get(i, mergedSchema(i).dataType) === - row2.get(i - schema1.size, mergedSchema(i).dataType)) + assert(output.get(i, dataType) === row2.get(i - schema1.size, dataType)) + assert(output.get(i, dataType) === expectedOutput.get(i, dataType)) } } } + + + assert( + expectedOutput.getSizeInBytes == output.getSizeInBytes, + "output isn't same size in bytes as slow path") + + // Compare the UnsafeRows byte-by-byte so that we can print more useful debug information in + // case this assertion fails: + val actualBytes = output.getBaseObject.asInstanceOf[Array[Byte]] + .take(output.getSizeInBytes) + val expectedBytes = expectedOutput.getBaseObject.asInstanceOf[Array[Byte]] + .take(expectedOutput.getSizeInBytes) + + val bitsetWidth = UnsafeRow.calculateBitSetWidthInBytes(expectedOutput.numFields()) + val actualBitset = actualBytes.take(bitsetWidth) + val expectedBitset = expectedBytes.take(bitsetWidth) + assert(actualBitset === expectedBitset, "bitsets were not equal") + + val fixedLengthSize = expectedOutput.numFields() * 8 + val actualFixedLength = actualBytes.slice(bitsetWidth, bitsetWidth + fixedLengthSize) + val expectedFixedLength = expectedBytes.slice(bitsetWidth, bitsetWidth + fixedLengthSize) + if (actualFixedLength !== expectedFixedLength) { + actualFixedLength.grouped(8) + .zip(expectedFixedLength.grouped(8)) + .zip(mergedSchema.fields.toIterator) + .foreach { + case ((actual, expected), field) => + assert(actual === expected, s"Fixed length sections are not equal for field $field") + } + fail("Fixed length sections were not equal") + } + + val variableLengthStart = bitsetWidth + fixedLengthSize + val actualVariableLength = actualBytes.drop(variableLengthStart) + val expectedVariableLength = expectedBytes.drop(variableLengthStart) + assert(actualVariableLength === expectedVariableLength, "fixed length sections were not equal") + + assert(output.hashCode() == expectedOutput.hashCode(), "hash codes were not equal") } } From 8486ad419d8f1779e277ec71c39e1516673a83ab Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 8 Jan 2018 21:58:26 -0800 Subject: [PATCH 0046/2461] [SPARK-21292][DOCS] refreshtable example ## What changes were proposed in this pull request? doc update Author: Felix Cheung Closes #20198 from felixcheung/rrefreshdoc. --- docs/sql-programming-guide.md | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3ccaaf4d5b1fa..72f79d6909ecc 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -915,6 +915,14 @@ spark.catalog.refreshTable("my_table") +
+ +{% highlight r %} +refreshTable("my_table") +{% endhighlight %} + +
+
{% highlight sql %} @@ -1498,10 +1506,10 @@ that these options will be deprecated in future release as more optimizations ar ## Broadcast Hint for SQL Queries The `BROADCAST` hint guides Spark to broadcast each specified table when joining them with another table or view. -When Spark deciding the join methods, the broadcast hash join (i.e., BHJ) is preferred, +When Spark deciding the join methods, the broadcast hash join (i.e., BHJ) is preferred, even if the statistics is above the configuration `spark.sql.autoBroadcastJoinThreshold`. When both sides of a join are specified, Spark broadcasts the one having the lower statistics. -Note Spark does not guarantee BHJ is always chosen, since not all cases (e.g. full outer join) +Note Spark does not guarantee BHJ is always chosen, since not all cases (e.g. full outer join) support BHJ. When the broadcast nested loop join is selected, we still respect the hint.
@@ -1780,7 +1788,7 @@ options. Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other combinations of scales and precisions because currently we only infer decimal type like `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type. - In PySpark, now we need Pandas 0.19.2 or upper if you want to use Pandas related functionalities, such as `toPandas`, `createDataFrame` from Pandas DataFrame, etc. - In PySpark, the behavior of timestamp values for Pandas related functionalities was changed to respect session timezone. If you want to use the old behavior, you need to set a configuration `spark.sql.execution.pandas.respectSessionTimeZone` to `False`. See [SPARK-22395](https://issues.apache.org/jira/browse/SPARK-22395) for details. - + - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489). - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. @@ -2167,7 +2175,7 @@ Not all the APIs of the Hive UDF/UDTF/UDAF are supported by Spark SQL. Below are Spark SQL currently does not support the reuse of aggregation. * `getWindowingEvaluator` (`GenericUDAFEvaluator`) is a function to optimize aggregation by evaluating an aggregate over a fixed window. - + ### Incompatible Hive UDF Below are the scenarios in which Hive and Spark generate different results: From 02214b094390e913f52e71d55c9bb8a81c9e7ef9 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 8 Jan 2018 22:08:19 -0800 Subject: [PATCH 0047/2461] [SPARK-21293][SPARKR][DOCS] structured streaming doc update ## What changes were proposed in this pull request? doc update Author: Felix Cheung Closes #20197 from felixcheung/rwadoc. --- R/pkg/vignettes/sparkr-vignettes.Rmd | 2 +- docs/sparkr.md | 2 +- .../structured-streaming-programming-guide.md | 32 +++++++++++++++++-- 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 2e662424b25f2..feca617c2554c 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -1042,7 +1042,7 @@ unlink(modelPath) ## Structured Streaming -SparkR supports the Structured Streaming API (experimental). +SparkR supports the Structured Streaming API. You can check the Structured Streaming Programming Guide for [an introduction](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#programming-model) to its programming model and basic concepts. diff --git a/docs/sparkr.md b/docs/sparkr.md index 997ea60fb6cf0..6685b585a393a 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -596,7 +596,7 @@ The following example shows how to save/load a MLlib model by SparkR. # Structured Streaming -SparkR supports the Structured Streaming API (experimental). Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. For more information see the R API on the [Structured Streaming Programming Guide](structured-streaming-programming-guide.html) +SparkR supports the Structured Streaming API. Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. For more information see the R API on the [Structured Streaming Programming Guide](structured-streaming-programming-guide.html) # R Function Name Conflicts diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 31fcfabb9cacc..de13e281916db 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -827,8 +827,8 @@ df.isStreaming() {% endhighlight %}
-{% highlight bash %} -Not available. +{% highlight r %} +isStreaming(df) {% endhighlight %}
@@ -885,6 +885,19 @@ windowedCounts = words.groupBy( ).count() {% endhighlight %} + +
+{% highlight r %} +words <- ... # streaming DataFrame of schema { timestamp: Timestamp, word: String } + +# Group the data by window and word and compute the count of each group +windowedCounts <- count( + groupBy( + words, + window(words$timestamp, "10 minutes", "5 minutes"), + words$word)) +{% endhighlight %} +
@@ -959,6 +972,21 @@ windowedCounts = words \ .count() {% endhighlight %} + +
+{% highlight r %} +words <- ... # streaming DataFrame of schema { timestamp: Timestamp, word: String } + +# Group the data by window and word and compute the count of each group + +words <- withWatermark(words, "timestamp", "10 minutes") +windowedCounts <- count( + groupBy( + words, + window(words$timestamp, "10 minutes", "5 minutes"), + words$word)) +{% endhighlight %} +
From 0959aa581a399279be3f94214bcdffc6a1b6d60a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 9 Jan 2018 16:31:20 +0800 Subject: [PATCH 0048/2461] [SPARK-23000] Fix Flaky test suite DataSourceWithHiveMetastoreCatalogSuite in Spark 2.3 ## What changes were proposed in this pull request? https://amplab.cs.berkeley.edu/jenkins/job/spark-branch-2.3-test-sbt-hadoop-2.6/ The test suite DataSourceWithHiveMetastoreCatalogSuite of Branch 2.3 always failed in hadoop 2.6 The table `t` exists in `default`, but `runSQLHive` reported the table does not exist. Obviously, Hive client's default database is different. The fix is to clean the environment and use `DEFAULT` as the database. ``` org.apache.spark.sql.execution.QueryExecutionException: FAILED: SemanticException [Error 10001]: Line 1:14 Table not found 't' Stacktrace sbt.ForkMain$ForkError: org.apache.spark.sql.execution.QueryExecutionException: FAILED: SemanticException [Error 10001]: Line 1:14 Table not found 't' at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$runHive$1.apply(HiveClientImpl.scala:699) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$runHive$1.apply(HiveClientImpl.scala:683) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$withHiveState$1.apply(HiveClientImpl.scala:272) at org.apache.spark.sql.hive.client.HiveClientImpl.liftedTree1$1(HiveClientImpl.scala:210) at org.apache.spark.sql.hive.client.HiveClientImpl.retryLocked(HiveClientImpl.scala:209) at org.apache.spark.sql.hive.client.HiveClientImpl.withHiveState(HiveClientImpl.scala:255) at org.apache.spark.sql.hive.client.HiveClientImpl.runHive(HiveClientImpl.scala:683) at org.apache.spark.sql.hive.client.HiveClientImpl.runSqlHive(HiveClientImpl.scala:673) ``` ## How was this patch tested? N/A Author: gatorsmile Closes #20196 from gatorsmile/testFix. --- .../org/apache/spark/sql/hive/client/HiveClientImpl.scala | 6 +++++- .../apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala | 5 +++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 7b7f4e0f10210..102f40bacc985 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -823,7 +823,8 @@ private[hive] class HiveClientImpl( } def reset(): Unit = withHiveState { - client.getAllTables("default").asScala.foreach { t => + try { + client.getAllTables("default").asScala.foreach { t => logDebug(s"Deleting table $t") val table = client.getTable("default", t) client.getIndexes("default", t, 255).asScala.foreach { index => @@ -837,6 +838,9 @@ private[hive] class HiveClientImpl( logDebug(s"Dropping Database: $db") client.dropDatabase(db, true, false, true) } + } finally { + runSqlHive("USE default") + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 18137e7ea1d63..cf4ce83124d88 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -146,6 +146,11 @@ class DataSourceWithHiveMetastoreCatalogSuite 'id cast StringType as 'd2 ).coalesce(1) + override def beforeAll(): Unit = { + super.beforeAll() + sparkSession.metadataHive.reset() + } + Seq( "parquet" -> (( "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", From 6a4206ff04746481d7c8e307dfd0d31ff1402555 Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Tue, 9 Jan 2018 01:32:48 -0800 Subject: [PATCH 0049/2461] [SPARK-22998][K8S] Set missing value for SPARK_MOUNTED_CLASSPATH in the executors ## What changes were proposed in this pull request? The environment variable `SPARK_MOUNTED_CLASSPATH` is referenced in the executor's Dockerfile, where its value is added to the classpath of the executor. However, the scheduler backend code missed setting it when creating the executor pods. This PR fixes it. ## How was this patch tested? Unit tested. vanzin Can you help take a look? Thanks! foxish Author: Yinan Li Closes #20193 from liyinan926/master. --- .../spark/scheduler/cluster/k8s/ExecutorPodFactory.scala | 5 ++++- .../scheduler/cluster/k8s/ExecutorPodFactorySuite.scala | 3 ++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index 066d7e9f70ca5..bcacb3934d36a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -94,6 +94,8 @@ private[spark] class ExecutorPodFactory( private val executorCores = sparkConf.getDouble("spark.executor.cores", 1) private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) + private val executorJarsDownloadDir = sparkConf.get(JARS_DOWNLOAD_LOCATION) + /** * Configure and construct an executor pod with the given parameters. */ @@ -145,7 +147,8 @@ private[spark] class ExecutorPodFactory( (ENV_EXECUTOR_CORES, math.ceil(executorCores).toInt.toString), (ENV_EXECUTOR_MEMORY, executorMemoryString), (ENV_APPLICATION_ID, applicationId), - (ENV_EXECUTOR_ID, executorId)) ++ executorEnvs) + (ENV_EXECUTOR_ID, executorId), + (ENV_MOUNTED_CLASSPATH, s"$executorJarsDownloadDir/*")) ++ executorEnvs) .map(env => new EnvVarBuilder() .withName(env._1) .withValue(env._2) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 884da8aabd880..7cfbe54c95390 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -197,7 +197,8 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef ENV_EXECUTOR_CORES -> "1", ENV_EXECUTOR_MEMORY -> "1g", ENV_APPLICATION_ID -> "dummy", - ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars + ENV_EXECUTOR_POD_IP -> null, + ENV_MOUNTED_CLASSPATH -> "/var/spark-data/spark-jars/*") ++ additionalEnvVars assert(executor.getSpec.getContainers.size() === 1) assert(executor.getSpec.getContainers.get(0).getEnv.size() === defaultEnvs.size) From f44ba910f58083458e1133502e193a9d6f2bf766 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 9 Jan 2018 21:48:14 +0800 Subject: [PATCH 0050/2461] [SPARK-16060][SQL] Support Vectorized ORC Reader ## What changes were proposed in this pull request? This PR adds an ORC columnar-batch reader to native `OrcFileFormat`. Since both Spark `ColumnarBatch` and ORC `RowBatch` are used together, it is faster than the current Spark implementation. This replaces the prior PR, #17924. Also, this PR adds `OrcReadBenchmark` to show the performance improvement. ## How was this patch tested? Pass the existing test cases. Author: Dongjoon Hyun Closes #19943 from dongjoon-hyun/SPARK-16060. --- .../apache/spark/sql/internal/SQLConf.scala | 7 + .../orc/OrcColumnarBatchReader.java | 523 ++++++++++++++++++ .../datasources/orc/OrcFileFormat.scala | 75 ++- .../execution/datasources/orc/OrcUtils.scala | 7 +- .../spark/sql/hive/orc/OrcReadBenchmark.scala | 435 +++++++++++++++ 5 files changed, 1022 insertions(+), 25 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5c61f10bb71ad..74949db883f7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -386,6 +386,11 @@ object SQLConf { .checkValues(Set("hive", "native")) .createWithDefault("native") + val ORC_VECTORIZED_READER_ENABLED = buildConf("spark.sql.orc.enableVectorizedReader") + .doc("Enables vectorized orc decoding.") + .booleanConf + .createWithDefault(true) + val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf @@ -1183,6 +1188,8 @@ class SQLConf extends Serializable with Logging { def orcCompressionCodec: String = getConf(ORC_COMPRESSION) + def orcVectorizedReaderEnabled: Boolean = getConf(ORC_VECTORIZED_READER_ENABLED) + def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java new file mode 100644 index 0000000000000..5c28d0e6e507a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -0,0 +1,523 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import java.io.IOException; +import java.util.stream.IntStream; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.RecordReader; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.hadoop.mapreduce.lib.input.FileSplit; +import org.apache.orc.OrcConf; +import org.apache.orc.OrcFile; +import org.apache.orc.Reader; +import org.apache.orc.TypeDescription; +import org.apache.orc.mapred.OrcInputFormat; +import org.apache.orc.storage.common.type.HiveDecimal; +import org.apache.orc.storage.ql.exec.vector.*; +import org.apache.orc.storage.serde2.io.HiveDecimalWritable; + +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; +import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector; +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.execution.vectorized.WritableColumnVector; +import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarBatch; + + +/** + * To support vectorization in WholeStageCodeGen, this reader returns ColumnarBatch. + * After creating, `initialize` and `initBatch` should be called sequentially. + */ +public class OrcColumnarBatchReader extends RecordReader { + + /** + * The default size of batch. We use this value for both ORC and Spark consistently + * because they have different default values like the following. + * + * - ORC's VectorizedRowBatch.DEFAULT_SIZE = 1024 + * - Spark's ColumnarBatch.DEFAULT_BATCH_SIZE = 4 * 1024 + */ + public static final int DEFAULT_SIZE = 4 * 1024; + + // ORC File Reader + private Reader reader; + + // Vectorized ORC Row Batch + private VectorizedRowBatch batch; + + /** + * The column IDs of the physical ORC file schema which are required by this reader. + * -1 means this required column doesn't exist in the ORC file. + */ + private int[] requestedColIds; + + // Record reader from ORC row batch. + private org.apache.orc.RecordReader recordReader; + + private StructField[] requiredFields; + + // The result columnar batch for vectorized execution by whole-stage codegen. + private ColumnarBatch columnarBatch; + + // Writable column vectors of the result columnar batch. + private WritableColumnVector[] columnVectors; + + /** + * The memory mode of the columnarBatch + */ + private final MemoryMode MEMORY_MODE; + + public OrcColumnarBatchReader(boolean useOffHeap) { + MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; + } + + + @Override + public Void getCurrentKey() throws IOException, InterruptedException { + return null; + } + + @Override + public ColumnarBatch getCurrentValue() throws IOException, InterruptedException { + return columnarBatch; + } + + @Override + public float getProgress() throws IOException, InterruptedException { + return recordReader.getProgress(); + } + + @Override + public boolean nextKeyValue() throws IOException, InterruptedException { + return nextBatch(); + } + + @Override + public void close() throws IOException { + if (columnarBatch != null) { + columnarBatch.close(); + columnarBatch = null; + } + if (recordReader != null) { + recordReader.close(); + recordReader = null; + } + } + + /** + * Initialize ORC file reader and batch record reader. + * Please note that `initBatch` is needed to be called after this. + */ + @Override + public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) + throws IOException, InterruptedException { + FileSplit fileSplit = (FileSplit)inputSplit; + Configuration conf = taskAttemptContext.getConfiguration(); + reader = OrcFile.createReader( + fileSplit.getPath(), + OrcFile.readerOptions(conf) + .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf)) + .filesystem(fileSplit.getPath().getFileSystem(conf))); + + Reader.Options options = + OrcInputFormat.buildOptions(conf, reader, fileSplit.getStart(), fileSplit.getLength()); + recordReader = reader.rows(options); + } + + /** + * Initialize columnar batch by setting required schema and partition information. + * With this information, this creates ColumnarBatch with the full schema. + */ + public void initBatch( + TypeDescription orcSchema, + int[] requestedColIds, + StructField[] requiredFields, + StructType partitionSchema, + InternalRow partitionValues) { + batch = orcSchema.createRowBatch(DEFAULT_SIZE); + assert(!batch.selectedInUse); // `selectedInUse` should be initialized with `false`. + + this.requiredFields = requiredFields; + this.requestedColIds = requestedColIds; + assert(requiredFields.length == requestedColIds.length); + + StructType resultSchema = new StructType(requiredFields); + for (StructField f : partitionSchema.fields()) { + resultSchema = resultSchema.add(f); + } + + int capacity = DEFAULT_SIZE; + if (MEMORY_MODE == MemoryMode.OFF_HEAP) { + columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); + } else { + columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); + } + columnarBatch = new ColumnarBatch(resultSchema, columnVectors, capacity); + + if (partitionValues.numFields() > 0) { + int partitionIdx = requiredFields.length; + for (int i = 0; i < partitionValues.numFields(); i++) { + ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i); + columnVectors[i + partitionIdx].setIsConstant(); + } + } + + // Initialize the missing columns once. + for (int i = 0; i < requiredFields.length; i++) { + if (requestedColIds[i] == -1) { + columnVectors[i].putNulls(0, columnarBatch.capacity()); + columnVectors[i].setIsConstant(); + } + } + } + + /** + * Return true if there exists more data in the next batch. If exists, prepare the next batch + * by copying from ORC VectorizedRowBatch columns to Spark ColumnarBatch columns. + */ + private boolean nextBatch() throws IOException { + for (WritableColumnVector vector : columnVectors) { + vector.reset(); + } + columnarBatch.setNumRows(0); + + recordReader.nextBatch(batch); + int batchSize = batch.size; + if (batchSize == 0) { + return false; + } + columnarBatch.setNumRows(batchSize); + for (int i = 0; i < requiredFields.length; i++) { + StructField field = requiredFields[i]; + WritableColumnVector toColumn = columnVectors[i]; + + if (requestedColIds[i] >= 0) { + ColumnVector fromColumn = batch.cols[requestedColIds[i]]; + + if (fromColumn.isRepeating) { + putRepeatingValues(batchSize, field, fromColumn, toColumn); + } else if (fromColumn.noNulls) { + putNonNullValues(batchSize, field, fromColumn, toColumn); + } else { + putValues(batchSize, field, fromColumn, toColumn); + } + } + } + return true; + } + + private void putRepeatingValues( + int batchSize, + StructField field, + ColumnVector fromColumn, + WritableColumnVector toColumn) { + if (fromColumn.isNull[0]) { + toColumn.putNulls(0, batchSize); + } else { + DataType type = field.dataType(); + if (type instanceof BooleanType) { + toColumn.putBooleans(0, batchSize, ((LongColumnVector)fromColumn).vector[0] == 1); + } else if (type instanceof ByteType) { + toColumn.putBytes(0, batchSize, (byte)((LongColumnVector)fromColumn).vector[0]); + } else if (type instanceof ShortType) { + toColumn.putShorts(0, batchSize, (short)((LongColumnVector)fromColumn).vector[0]); + } else if (type instanceof IntegerType || type instanceof DateType) { + toColumn.putInts(0, batchSize, (int)((LongColumnVector)fromColumn).vector[0]); + } else if (type instanceof LongType) { + toColumn.putLongs(0, batchSize, ((LongColumnVector)fromColumn).vector[0]); + } else if (type instanceof TimestampType) { + toColumn.putLongs(0, batchSize, + fromTimestampColumnVector((TimestampColumnVector)fromColumn, 0)); + } else if (type instanceof FloatType) { + toColumn.putFloats(0, batchSize, (float)((DoubleColumnVector)fromColumn).vector[0]); + } else if (type instanceof DoubleType) { + toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector[0]); + } else if (type instanceof StringType || type instanceof BinaryType) { + BytesColumnVector data = (BytesColumnVector)fromColumn; + WritableColumnVector arrayData = toColumn.getChildColumn(0); + int size = data.vector[0].length; + arrayData.reserve(size); + arrayData.putBytes(0, size, data.vector[0], 0); + for (int index = 0; index < batchSize; index++) { + toColumn.putArray(index, 0, size); + } + } else if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType)type; + putDecimalWritables( + toColumn, + batchSize, + decimalType.precision(), + decimalType.scale(), + ((DecimalColumnVector)fromColumn).vector[0]); + } else { + throw new UnsupportedOperationException("Unsupported Data Type: " + type); + } + } + } + + private void putNonNullValues( + int batchSize, + StructField field, + ColumnVector fromColumn, + WritableColumnVector toColumn) { + DataType type = field.dataType(); + if (type instanceof BooleanType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putBoolean(index, data[index] == 1); + } + } else if (type instanceof ByteType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putByte(index, (byte)data[index]); + } + } else if (type instanceof ShortType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putShort(index, (short)data[index]); + } + } else if (type instanceof IntegerType || type instanceof DateType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putInt(index, (int)data[index]); + } + } else if (type instanceof LongType) { + toColumn.putLongs(0, batchSize, ((LongColumnVector)fromColumn).vector, 0); + } else if (type instanceof TimestampType) { + TimestampColumnVector data = ((TimestampColumnVector)fromColumn); + for (int index = 0; index < batchSize; index++) { + toColumn.putLong(index, fromTimestampColumnVector(data, index)); + } + } else if (type instanceof FloatType) { + double[] data = ((DoubleColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putFloat(index, (float)data[index]); + } + } else if (type instanceof DoubleType) { + toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector, 0); + } else if (type instanceof StringType || type instanceof BinaryType) { + BytesColumnVector data = ((BytesColumnVector)fromColumn); + WritableColumnVector arrayData = toColumn.getChildColumn(0); + int totalNumBytes = IntStream.of(data.length).sum(); + arrayData.reserve(totalNumBytes); + for (int index = 0, pos = 0; index < batchSize; pos += data.length[index], index++) { + arrayData.putBytes(pos, data.length[index], data.vector[index], data.start[index]); + toColumn.putArray(index, pos, data.length[index]); + } + } else if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType)type; + DecimalColumnVector data = ((DecimalColumnVector)fromColumn); + if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) { + WritableColumnVector arrayData = toColumn.getChildColumn(0); + arrayData.reserve(batchSize * 16); + } + for (int index = 0; index < batchSize; index++) { + putDecimalWritable( + toColumn, + index, + decimalType.precision(), + decimalType.scale(), + data.vector[index]); + } + } else { + throw new UnsupportedOperationException("Unsupported Data Type: " + type); + } + } + + private void putValues( + int batchSize, + StructField field, + ColumnVector fromColumn, + WritableColumnVector toColumn) { + DataType type = field.dataType(); + if (type instanceof BooleanType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putBoolean(index, vector[index] == 1); + } + } + } else if (type instanceof ByteType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putByte(index, (byte)vector[index]); + } + } + } else if (type instanceof ShortType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putShort(index, (short)vector[index]); + } + } + } else if (type instanceof IntegerType || type instanceof DateType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putInt(index, (int)vector[index]); + } + } + } else if (type instanceof LongType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putLong(index, vector[index]); + } + } + } else if (type instanceof TimestampType) { + TimestampColumnVector vector = ((TimestampColumnVector)fromColumn); + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putLong(index, fromTimestampColumnVector(vector, index)); + } + } + } else if (type instanceof FloatType) { + double[] vector = ((DoubleColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putFloat(index, (float)vector[index]); + } + } + } else if (type instanceof DoubleType) { + double[] vector = ((DoubleColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putDouble(index, vector[index]); + } + } + } else if (type instanceof StringType || type instanceof BinaryType) { + BytesColumnVector vector = (BytesColumnVector)fromColumn; + WritableColumnVector arrayData = toColumn.getChildColumn(0); + int totalNumBytes = IntStream.of(vector.length).sum(); + arrayData.reserve(totalNumBytes); + for (int index = 0, pos = 0; index < batchSize; pos += vector.length[index], index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + arrayData.putBytes(pos, vector.length[index], vector.vector[index], vector.start[index]); + toColumn.putArray(index, pos, vector.length[index]); + } + } + } else if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType)type; + HiveDecimalWritable[] vector = ((DecimalColumnVector)fromColumn).vector; + if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) { + WritableColumnVector arrayData = toColumn.getChildColumn(0); + arrayData.reserve(batchSize * 16); + } + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + putDecimalWritable( + toColumn, + index, + decimalType.precision(), + decimalType.scale(), + vector[index]); + } + } + } else { + throw new UnsupportedOperationException("Unsupported Data Type: " + type); + } + } + + /** + * Returns the number of micros since epoch from an element of TimestampColumnVector. + */ + private static long fromTimestampColumnVector(TimestampColumnVector vector, int index) { + return vector.time[index] * 1000L + vector.nanos[index] / 1000L; + } + + /** + * Put a `HiveDecimalWritable` to a `WritableColumnVector`. + */ + private static void putDecimalWritable( + WritableColumnVector toColumn, + int index, + int precision, + int scale, + HiveDecimalWritable decimalWritable) { + HiveDecimal decimal = decimalWritable.getHiveDecimal(); + Decimal value = + Decimal.apply(decimal.bigDecimalValue(), decimal.precision(), decimal.scale()); + value.changePrecision(precision, scale); + + if (precision <= Decimal.MAX_INT_DIGITS()) { + toColumn.putInt(index, (int) value.toUnscaledLong()); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + toColumn.putLong(index, value.toUnscaledLong()); + } else { + byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray(); + WritableColumnVector arrayData = toColumn.getChildColumn(0); + arrayData.putBytes(index * 16, bytes.length, bytes, 0); + toColumn.putArray(index, index * 16, bytes.length); + } + } + + /** + * Put `HiveDecimalWritable`s to a `WritableColumnVector`. + */ + private static void putDecimalWritables( + WritableColumnVector toColumn, + int size, + int precision, + int scale, + HiveDecimalWritable decimalWritable) { + HiveDecimal decimal = decimalWritable.getHiveDecimal(); + Decimal value = + Decimal.apply(decimal.bigDecimalValue(), decimal.precision(), decimal.scale()); + value.changePrecision(precision, scale); + + if (precision <= Decimal.MAX_INT_DIGITS()) { + toColumn.putInts(0, size, (int) value.toUnscaledLong()); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + toColumn.putLongs(0, size, value.toUnscaledLong()); + } else { + byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray(); + WritableColumnVector arrayData = toColumn.getChildColumn(0); + arrayData.reserve(bytes.length); + arrayData.putBytes(0, bytes.length, bytes, 0); + for (int index = 0; index < size; index++) { + toColumn.putArray(index, 0, bytes.length); + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index f7471cd7debce..b8bacfa1838ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -118,6 +118,13 @@ class OrcFileFormat } } + override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = { + val conf = sparkSession.sessionState.conf + conf.orcVectorizedReaderEnabled && conf.wholeStageEnabled && + schema.length <= conf.wholeStageMaxNumFields && + schema.forall(_.dataType.isInstanceOf[AtomicType]) + } + override def isSplitable( sparkSession: SparkSession, options: Map[String, String], @@ -139,6 +146,11 @@ class OrcFileFormat } } + val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields) + val sqlConf = sparkSession.sessionState.conf + val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled + val enableVectorizedReader = supportBatch(sparkSession, resultSchema) + val broadcastedConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis @@ -146,8 +158,14 @@ class OrcFileFormat (file: PartitionedFile) => { val conf = broadcastedConf.value.value + val filePath = new Path(new URI(file.filePath)) + + val fs = filePath.getFileSystem(conf) + val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) + val reader = OrcFile.createReader(filePath, readerOptions) + val requestedColIdsOrEmptyFile = OrcUtils.requestedColumnIds( - isCaseSensitive, dataSchema, requiredSchema, new Path(new URI(file.filePath)), conf) + isCaseSensitive, dataSchema, requiredSchema, reader, conf) if (requestedColIdsOrEmptyFile.isEmpty) { Iterator.empty @@ -155,29 +173,46 @@ class OrcFileFormat val requestedColIds = requestedColIdsOrEmptyFile.get assert(requestedColIds.length == requiredSchema.length, "[BUG] requested column IDs do not match required schema") - conf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, + val taskConf = new Configuration(conf) + taskConf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, requestedColIds.filter(_ != -1).sorted.mkString(",")) - val fileSplit = - new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty) + val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty) val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) - val taskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) - - val orcRecordReader = new OrcInputFormat[OrcStruct] - .createRecordReader(fileSplit, taskAttemptContext) - val iter = new RecordReaderIterator[OrcStruct](orcRecordReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) - - val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes - val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) - val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds) - - if (partitionSchema.length == 0) { - iter.map(value => unsafeProjection(deserializer.deserialize(value))) + val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId) + + val taskContext = Option(TaskContext.get()) + if (enableVectorizedReader) { + val batchReader = + new OrcColumnarBatchReader(enableOffHeapColumnVector && taskContext.isDefined) + batchReader.initialize(fileSplit, taskAttemptContext) + batchReader.initBatch( + reader.getSchema, + requestedColIds, + requiredSchema.fields, + partitionSchema, + file.partitionValues) + + val iter = new RecordReaderIterator(batchReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + iter.asInstanceOf[Iterator[InternalRow]] } else { - val joinedRow = new JoinedRow() - iter.map(value => - unsafeProjection(joinedRow(deserializer.deserialize(value), file.partitionValues))) + val orcRecordReader = new OrcInputFormat[OrcStruct] + .createRecordReader(fileSplit, taskAttemptContext) + val iter = new RecordReaderIterator[OrcStruct](orcRecordReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + + val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes + val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) + val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds) + + if (partitionSchema.length == 0) { + iter.map(value => unsafeProjection(deserializer.deserialize(value))) + } else { + val joinedRow = new JoinedRow() + iter.map(value => + unsafeProjection(joinedRow(deserializer.deserialize(value), file.partitionValues))) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index b03ee06d04a16..13a23996f4ade 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.orc.{OrcFile, TypeDescription} +import org.apache.orc.{OrcFile, Reader, TypeDescription} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging @@ -80,11 +80,8 @@ object OrcUtils extends Logging { isCaseSensitive: Boolean, dataSchema: StructType, requiredSchema: StructType, - file: Path, + reader: Reader, conf: Configuration): Option[Array[Int]] = { - val fs = file.getFileSystem(conf) - val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) - val reader = OrcFile.createReader(file, readerOptions) val orcFieldNames = reader.getSchema.getFieldNames.asScala if (orcFieldNames.isEmpty) { // SPARK-8501: Some old empty ORC files always have an empty schema stored in their footer. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala new file mode 100644 index 0000000000000..37ed846acd1eb --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala @@ -0,0 +1,435 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File + +import scala.util.{Random, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.util.{Benchmark, Utils} + + +/** + * Benchmark to measure ORC read performance. + * + * This is in `sql/hive` module in order to compare `sql/core` and `sql/hive` ORC data sources. + */ +// scalastyle:off line.size.limit +object OrcReadBenchmark { + val conf = new SparkConf() + conf.set("orc.compression", "snappy") + + private val spark = SparkSession.builder() + .master("local[1]") + .appName("OrcReadBenchmark") + .config(conf) + .getOrCreate() + + // Set default configs. Individual cases will change them if necessary. + spark.conf.set(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key, "true") + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + + private val NATIVE_ORC_FORMAT = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName + private val HIVE_ORC_FORMAT = classOf[org.apache.spark.sql.hive.orc.OrcFileFormat].getCanonicalName + + private def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = { + val dirORC = dir.getCanonicalPath + + if (partition.isDefined) { + df.write.partitionBy(partition.get).orc(dirORC) + } else { + df.write.orc(dirORC) + } + + spark.read.format(NATIVE_ORC_FORMAT).load(dirORC).createOrReplaceTempView("nativeOrcTable") + spark.read.format(HIVE_ORC_FORMAT).load(dirORC).createOrReplaceTempView("hiveOrcTable") + } + + def numericScanBenchmark(values: Int, dataType: DataType): Unit = { + val sqlBenchmark = new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM t1")) + + sqlBenchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + } + + sqlBenchmark.addCase("Native ORC Vectorized") { _ => + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + + sqlBenchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT sum(id) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1192 / 1221 13.2 75.8 1.0X + Native ORC Vectorized 161 / 170 97.5 10.3 7.4X + Hive built-in ORC 1399 / 1413 11.2 89.0 0.9X + + SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1287 / 1333 12.2 81.8 1.0X + Native ORC Vectorized 164 / 172 95.6 10.5 7.8X + Hive built-in ORC 1629 / 1650 9.7 103.6 0.8X + + SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1304 / 1388 12.1 82.9 1.0X + Native ORC Vectorized 227 / 240 69.3 14.4 5.7X + Hive built-in ORC 1866 / 1867 8.4 118.6 0.7X + + SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1331 / 1357 11.8 84.6 1.0X + Native ORC Vectorized 289 / 297 54.4 18.4 4.6X + Hive built-in ORC 1922 / 1929 8.2 122.2 0.7X + + SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1410 / 1428 11.2 89.7 1.0X + Native ORC Vectorized 328 / 335 48.0 20.8 4.3X + Hive built-in ORC 1929 / 2012 8.2 122.6 0.7X + + SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1467 / 1485 10.7 93.3 1.0X + Native ORC Vectorized 402 / 411 39.1 25.6 3.6X + Hive built-in ORC 2023 / 2042 7.8 128.6 0.7X + */ + sqlBenchmark.run() + } + } + } + + def intStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Int and String Scan", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql("SELECT CAST(value AS INT) AS c1, CAST(value as STRING) AS c2 FROM t1")) + + benchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(c1), sum(length(c2)) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Native ORC Vectorized") { _ => + spark.sql("SELECT sum(c1), sum(length(c2)) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT sum(c1), sum(length(c2)) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 2729 / 2744 3.8 260.2 1.0X + Native ORC Vectorized 1318 / 1344 8.0 125.7 2.1X + Hive built-in ORC 3731 / 3782 2.8 355.8 0.7X + */ + benchmark.run() + } + } + } + + def partitionTableScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Partitioned Table", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT value % 2 AS p, value AS id FROM t1"), Some("p")) + + benchmark.addCase("Read data column - Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Read data column - Native ORC Vectorized") { _ => + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Read data column - Hive built-in ORC") { _ => + spark.sql("SELECT sum(id) FROM hiveOrcTable").collect() + } + + benchmark.addCase("Read partition column - Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Read partition column - Native ORC Vectorized") { _ => + spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Read partition column - Hive built-in ORC") { _ => + spark.sql("SELECT sum(p) FROM hiveOrcTable").collect() + } + + benchmark.addCase("Read both columns - Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Read both columns - Native ORC Vectorized") { _ => + spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Read both columns - Hive built-in ORC") { _ => + spark.sql("SELECT sum(p), sum(id) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Read data column - Native ORC MR 1531 / 1536 10.3 97.4 1.0X + Read data column - Native ORC Vectorized 295 / 298 53.3 18.8 5.2X + Read data column - Hive built-in ORC 2125 / 2126 7.4 135.1 0.7X + Read partition column - Native ORC MR 1049 / 1062 15.0 66.7 1.5X + Read partition column - Native ORC Vectorized 54 / 57 290.1 3.4 28.2X + Read partition column - Hive built-in ORC 1282 / 1291 12.3 81.5 1.2X + Read both columns - Native ORC MR 1594 / 1598 9.9 101.3 1.0X + Read both columns - Native ORC Vectorized 332 / 336 47.4 21.1 4.6X + Read both columns - Hive built-in ORC 2145 / 2187 7.3 136.4 0.7X + */ + benchmark.run() + } + } + } + + def repeatedStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Repeated String", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + spark.range(values).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT CAST((id % 200) + 10000 as STRING) AS c1 FROM t1")) + + benchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(length(c1)) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Native ORC Vectorized") { _ => + spark.sql("SELECT sum(length(c1)) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT sum(length(c1)) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1325 / 1328 7.9 126.4 1.0X + Native ORC Vectorized 320 / 330 32.8 30.5 4.1X + Hive built-in ORC 1971 / 1972 5.3 188.0 0.7X + */ + benchmark.run() + } + } + } + + def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + spark.range(values).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql( + s"SELECT IF(RAND(1) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c1, " + + s"IF(RAND(2) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c2 FROM t1")) + + val benchmark = new Benchmark(s"String with Nulls Scan ($fractionOfNulls%)", values) + + benchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT SUM(LENGTH(c2)) FROM nativeOrcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + } + + benchmark.addCase("Native ORC Vectorized") { _ => + spark.sql("SELECT SUM(LENGTH(c2)) FROM nativeOrcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT SUM(LENGTH(c2)) FROM hiveOrcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 2553 / 2554 4.1 243.4 1.0X + Native ORC Vectorized 953 / 954 11.0 90.9 2.7X + Hive built-in ORC 3875 / 3898 2.7 369.6 0.7X + + String with Nulls Scan (0.5%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 2389 / 2408 4.4 227.8 1.0X + Native ORC Vectorized 1208 / 1209 8.7 115.2 2.0X + Hive built-in ORC 2940 / 2952 3.6 280.4 0.8X + + String with Nulls Scan (0.95%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1295 / 1311 8.1 123.5 1.0X + Native ORC Vectorized 449 / 457 23.4 42.8 2.9X + Hive built-in ORC 1649 / 1660 6.4 157.3 0.8X + */ + benchmark.run() + } + } + } + + def columnsBenchmark(values: Int, width: Int): Unit = { + val sqlBenchmark = new Benchmark(s"SQL Single Column Scan from $width columns", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + import spark.implicits._ + val middle = width / 2 + val selectExpr = (1 to width).map(i => s"value as c$i") + spark.range(values).map(_ => Random.nextLong).toDF() + .selectExpr(selectExpr: _*).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT * FROM t1")) + + sqlBenchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() + } + } + + sqlBenchmark.addCase("Native ORC Vectorized") { _ => + spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() + } + + sqlBenchmark.addCase("Hive built-in ORC") { _ => + spark.sql(s"SELECT sum(c$middle) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + SQL Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1103 / 1124 1.0 1052.0 1.0X + Native ORC Vectorized 92 / 100 11.4 87.9 12.0X + Hive built-in ORC 383 / 390 2.7 365.4 2.9X + + SQL Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 2245 / 2250 0.5 2141.0 1.0X + Native ORC Vectorized 157 / 165 6.7 150.2 14.3X + Hive built-in ORC 587 / 593 1.8 559.4 3.8X + + SQL Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 3343 / 3350 0.3 3188.3 1.0X + Native ORC Vectorized 265 / 280 3.9 253.2 12.6X + Hive built-in ORC 828 / 842 1.3 789.8 4.0X + */ + sqlBenchmark.run() + } + } + } + + def main(args: Array[String]): Unit = { + Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { dataType => + numericScanBenchmark(1024 * 1024 * 15, dataType) + } + intStringScanBenchmark(1024 * 1024 * 10) + partitionTableScanBenchmark(1024 * 1024 * 15) + repeatedStringScanBenchmark(1024 * 1024 * 10) + for (fractionOfNulls <- List(0.0, 0.50, 0.95)) { + stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls) + } + columnsBenchmark(1024 * 1024 * 1, 100) + columnsBenchmark(1024 * 1024 * 1, 200) + columnsBenchmark(1024 * 1024 * 1, 300) + } +} +// scalastyle:on line.size.limit From 2250cb75b99d257e698fe5418a51d8cddb4d5104 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 9 Jan 2018 21:58:55 +0800 Subject: [PATCH 0051/2461] [SPARK-22981][SQL] Fix incorrect results of Casting Struct to String ## What changes were proposed in this pull request? This pr fixed the issue when casting structs into strings; ``` scala> val df = Seq(((1, "a"), 0), ((2, "b"), 0)).toDF("a", "b") scala> df.write.saveAsTable("t") scala> sql("SELECT CAST(a AS STRING) FROM t").show +-------------------+ | a| +-------------------+ |[0,1,1800000001,61]| |[0,2,1800000001,62]| +-------------------+ ``` This pr modified the result into; ``` +------+ | a| +------+ |[1, a]| |[2, b]| +------+ ``` ## How was this patch tested? Added tests in `CastSuite`. Author: Takeshi Yamamuro Closes #20176 from maropu/SPARK-22981. --- .../spark/sql/catalyst/expressions/Cast.scala | 71 +++++++++++++++++++ .../sql/catalyst/expressions/CastSuite.scala | 16 +++++ 2 files changed, 87 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index f2de4c8e30bec..f21aa1e9e3135 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -259,6 +259,29 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String builder.append("]") builder.build() }) + case StructType(fields) => + buildCast[InternalRow](_, row => { + val builder = new UTF8StringBuilder + builder.append("[") + if (row.numFields > 0) { + val st = fields.map(_.dataType) + val toUTF8StringFuncs = st.map(castToString) + if (!row.isNullAt(0)) { + builder.append(toUTF8StringFuncs(0)(row.get(0, st(0))).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < row.numFields) { + builder.append(",") + if (!row.isNullAt(i)) { + builder.append(" ") + builder.append(toUTF8StringFuncs(i)(row.get(i, st(i))).asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append("]") + builder.build() + }) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -732,6 +755,41 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """.stripMargin } + private def writeStructToStringBuilder( + st: Seq[DataType], + row: String, + buffer: String, + ctx: CodegenContext): String = { + val structToStringCode = st.zipWithIndex.map { case (ft, i) => + val fieldToStringCode = castToStringCode(ft, ctx) + val field = ctx.freshName("field") + val fieldStr = ctx.freshName("fieldStr") + s""" + |${if (i != 0) s"""$buffer.append(",");""" else ""} + |if (!$row.isNullAt($i)) { + | ${if (i != 0) s"""$buffer.append(" ");""" else ""} + | + | // Append $i field into the string buffer + | ${ctx.javaType(ft)} $field = ${ctx.getValue(row, ft, s"$i")}; + | UTF8String $fieldStr = null; + | ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)} + | $buffer.append($fieldStr); + |} + """.stripMargin + } + + val writeStructCode = ctx.splitExpressions( + expressions = structToStringCode, + funcName = "fieldToString", + arguments = ("InternalRow", row) :: (classOf[UTF8StringBuilder].getName, buffer) :: Nil) + + s""" + |$buffer.append("["); + |$writeStructCode + |$buffer.append("]"); + """.stripMargin + } + private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => @@ -765,6 +823,19 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$evPrim = $buffer.build(); """.stripMargin } + case StructType(fields) => + (c, evPrim, evNull) => { + val row = ctx.freshName("row") + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val writeStructCode = writeStructToStringBuilder(fields.map(_.dataType), row, buffer, ctx) + s""" + |InternalRow $row = $c; + |$bufferClass $buffer = new $bufferClass(); + |$writeStructCode + |$evPrim = $buffer.build(); + """.stripMargin + } case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 1445bb8a97d40..5b25bdf907c3a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -906,4 +906,20 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StringType) checkEvaluation(ret5, "[1 -> [1, 2, 3], 2 -> [4, 5, 6]]") } + + test("SPARK-22981 Cast struct to string") { + val ret1 = cast(Literal.create((1, "a", 0.1)), StringType) + checkEvaluation(ret1, "[1, a, 0.1]") + val ret2 = cast(Literal.create(Tuple3[Int, String, String](1, null, "a")), StringType) + checkEvaluation(ret2, "[1,, a]") + val ret3 = cast(Literal.create( + (Date.valueOf("2014-12-03"), Timestamp.valueOf("2014-12-03 15:05:00"))), StringType) + checkEvaluation(ret3, "[2014-12-03, 2014-12-03 15:05:00]") + val ret4 = cast(Literal.create(((1, "a"), 5, 0.1)), StringType) + checkEvaluation(ret4, "[[1, a], 5, 0.1]") + val ret5 = cast(Literal.create((Seq(1, 2, 3), "a", 0.1)), StringType) + checkEvaluation(ret5, "[[1, 2, 3], a, 0.1]") + val ret6 = cast(Literal.create((1, Map(1 -> "a", 2 -> "b", 3 -> "c"))), StringType) + checkEvaluation(ret6, "[1, [1 -> a, 2 -> b, 3 -> c]]") + } } From 96ba217a06fbe1dad703447d7058cb7841653861 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Wed, 10 Jan 2018 10:15:27 +0800 Subject: [PATCH 0052/2461] [SPARK-23005][CORE] Improve RDD.take on small number of partitions ## What changes were proposed in this pull request? In current implementation of RDD.take, we overestimate the number of partitions we need to try by 50%: `(1.5 * num * partsScanned / buf.size).toInt` However, when the number is small, the result of `.toInt` is not what we want. E.g, 2.9 will become 2, which should be 3. Use Math.ceil to fix the problem. Also clean up the code in RDD.scala. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #20200 from gengliangwang/Take. --- .../main/scala/org/apache/spark/rdd/RDD.scala | 27 +++++++++---------- .../spark/sql/execution/SparkPlan.scala | 5 ++-- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 8798dfc925362..7859781e98223 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -150,7 +150,7 @@ abstract class RDD[T: ClassTag]( val id: Int = sc.newRddId() /** A friendly name for this RDD */ - @transient var name: String = null + @transient var name: String = _ /** Assign a name to this RDD */ def setName(_name: String): this.type = { @@ -224,8 +224,8 @@ abstract class RDD[T: ClassTag]( // Our dependencies and partitions will be gotten by calling subclass's methods below, and will // be overwritten when we're checkpointed - private var dependencies_ : Seq[Dependency[_]] = null - @transient private var partitions_ : Array[Partition] = null + private var dependencies_ : Seq[Dependency[_]] = _ + @transient private var partitions_ : Array[Partition] = _ /** An Option holding our checkpoint RDD, if we are checkpointed */ private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD) @@ -297,7 +297,7 @@ abstract class RDD[T: ClassTag]( private[spark] def getNarrowAncestors: Seq[RDD[_]] = { val ancestors = new mutable.HashSet[RDD[_]] - def visit(rdd: RDD[_]) { + def visit(rdd: RDD[_]): Unit = { val narrowDependencies = rdd.dependencies.filter(_.isInstanceOf[NarrowDependency[_]]) val narrowParents = narrowDependencies.map(_.rdd) val narrowParentsNotVisited = narrowParents.filterNot(ancestors.contains) @@ -449,7 +449,7 @@ abstract class RDD[T: ClassTag]( if (shuffle) { /** Distributes elements evenly across output partitions, starting from a random partition. */ val distributePartition = (index: Int, items: Iterator[T]) => { - var position = (new Random(hashing.byteswap32(index))).nextInt(numPartitions) + var position = new Random(hashing.byteswap32(index)).nextInt(numPartitions) items.map { t => // Note that the hash code of the key will just be the key itself. The HashPartitioner // will mod it with the number of total partitions. @@ -951,7 +951,7 @@ abstract class RDD[T: ClassTag]( def collectPartition(p: Int): Array[T] = { sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p)).head } - (0 until partitions.length).iterator.flatMap(i => collectPartition(i)) + partitions.indices.iterator.flatMap(i => collectPartition(i)) } /** @@ -1338,6 +1338,7 @@ abstract class RDD[T: ClassTag]( // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. var numPartsToTry = 1L + val left = num - buf.size if (partsScanned > 0) { // If we didn't find any rows after the previous iteration, quadruple and retry. // Otherwise, interpolate the number of partitions we need to try, but overestimate @@ -1345,13 +1346,12 @@ abstract class RDD[T: ClassTag]( if (buf.isEmpty) { numPartsToTry = partsScanned * scaleUpFactor } else { - // the left side of max is >=1 whenever partsScanned >= 2 - numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1) + // As left > 0, numPartsToTry is always >= 1 + numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt numPartsToTry = Math.min(numPartsToTry, partsScanned * scaleUpFactor) } } - val left = num - buf.size val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p) @@ -1677,8 +1677,7 @@ abstract class RDD[T: ClassTag]( // an RDD and its parent in every batch, in which case the parent may never be checkpointed // and its lineage never truncated, leading to OOMs in the long run (SPARK-6847). private val checkpointAllMarkedAncestors = - Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)) - .map(_.toBoolean).getOrElse(false) + Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)).exists(_.toBoolean) /** Returns the first parent RDD */ protected[spark] def firstParent[U: ClassTag]: RDD[U] = { @@ -1686,7 +1685,7 @@ abstract class RDD[T: ClassTag]( } /** Returns the jth parent RDD: e.g. rdd.parent[T](0) is equivalent to rdd.firstParent[T] */ - protected[spark] def parent[U: ClassTag](j: Int) = { + protected[spark] def parent[U: ClassTag](j: Int): RDD[U] = { dependencies(j).rdd.asInstanceOf[RDD[U]] } @@ -1754,7 +1753,7 @@ abstract class RDD[T: ClassTag]( * collected. Subclasses of RDD may override this method for implementing their own cleaning * logic. See [[org.apache.spark.rdd.UnionRDD]] for an example. */ - protected def clearDependencies() { + protected def clearDependencies(): Unit = { dependencies_ = null } @@ -1790,7 +1789,7 @@ abstract class RDD[T: ClassTag]( val lastDepStrings = debugString(lastDep.rdd, prefix, lastDep.isInstanceOf[ShuffleDependency[_, _, _]], true) - (frontDepStrings ++ lastDepStrings) + frontDepStrings ++ lastDepStrings } } // The first RDD in the dependency stack has no parents, so no need for a +- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 82300efc01632..398758a3331b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -351,8 +351,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ if (buf.isEmpty) { numPartsToTry = partsScanned * limitScaleUpFactor } else { - // the left side of max is >=1 whenever partsScanned >= 2 - numPartsToTry = Math.max((1.5 * n * partsScanned / buf.size).toInt - partsScanned, 1) + val left = n - buf.size + // As left > 0, numPartsToTry is always >= 1 + numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor) } } From 6f169ca9e1444fe8fd1ab6f3fbf0a8be1670f1b5 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 10 Jan 2018 10:20:34 +0800 Subject: [PATCH 0053/2461] [MINOR] fix a typo in BroadcastJoinSuite ## What changes were proposed in this pull request? `BroadcastNestedLoopJoinExec` should be `BroadcastHashJoinExec` ## How was this patch tested? N/A Author: Wenchen Fan Closes #20202 from cloud-fan/typo. --- .../apache/spark/sql/execution/joins/BroadcastJoinSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 6da46ea3480b3..0bcd54e1fceab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -318,7 +318,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { case b: BroadcastNestedLoopJoinExec => assert(b.getClass.getSimpleName === joinMethod) assert(b.buildSide === buildSide) - case b: BroadcastNestedLoopJoinExec => + case b: BroadcastHashJoinExec => assert(b.getClass.getSimpleName === joinMethod) assert(b.buildSide === buildSide) case w: WholeStageCodegenExec => From 7bcc2666810cefc85dfa0d6679ac7a0de9e23154 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 10 Jan 2018 14:00:07 +0900 Subject: [PATCH 0054/2461] [SPARK-23018][PYTHON] Fix createDataFrame from Pandas timestamp series assignment ## What changes were proposed in this pull request? This fixes createDataFrame from Pandas to only assign modified timestamp series back to a copied version of the Pandas DataFrame. Previously, if the Pandas DataFrame was only a reference (e.g. a slice of another) each series will still get assigned back to the reference even if it is not a modified timestamp column. This caused the following warning "SettingWithCopyWarning: A value is trying to be set on a copy of a slice from a DataFrame." ## How was this patch tested? existing tests Author: Bryan Cutler Closes #20213 from BryanCutler/pyspark-createDataFrame-copy-slice-warn-SPARK-23018. --- python/pyspark/sql/session.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 6052fa9e84096..3e4574729a631 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -459,21 +459,23 @@ def _convert_from_pandas(self, pdf, schema, timezone): # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if isinstance(field.dataType, TimestampType): s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone) - if not copied and s is not pdf[field.name]: - # Copy once if the series is modified to prevent the original Pandas - # DataFrame from being updated - pdf = pdf.copy() - copied = True - pdf[field.name] = s + if s is not pdf[field.name]: + if not copied: + # Copy once if the series is modified to prevent the original + # Pandas DataFrame from being updated + pdf = pdf.copy() + copied = True + pdf[field.name] = s else: for column, series in pdf.iteritems(): - s = _check_series_convert_timestamps_tz_local(pdf[column], timezone) - if not copied and s is not pdf[column]: - # Copy once if the series is modified to prevent the original Pandas - # DataFrame from being updated - pdf = pdf.copy() - copied = True - pdf[column] = s + s = _check_series_convert_timestamps_tz_local(series, timezone) + if s is not series: + if not copied: + # Copy once if the series is modified to prevent the original + # Pandas DataFrame from being updated + pdf = pdf.copy() + copied = True + pdf[column] = s # Convert pandas.DataFrame to list of numpy records np_records = pdf.to_records(index=False) From e5998372487af20114e160264a594957344ff433 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 10 Jan 2018 14:55:24 +0900 Subject: [PATCH 0055/2461] [SPARK-23009][PYTHON] Fix for non-str col names to createDataFrame from Pandas ## What changes were proposed in this pull request? This the case when calling `SparkSession.createDataFrame` using a Pandas DataFrame that has non-str column labels. The column name conversion logic to handle non-string or unicode in python2 is: ``` if column is not any type of string: name = str(column) else if column is unicode in Python 2: name = column.encode('utf-8') ``` ## How was this patch tested? Added a new test with a Pandas DataFrame that has int column labels Author: Bryan Cutler Closes #20210 from BryanCutler/python-createDataFrame-int-col-error-SPARK-23009. --- python/pyspark/sql/session.py | 4 +++- python/pyspark/sql/tests.py | 9 +++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 3e4574729a631..604021c1f45cc 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -648,7 +648,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr # If no schema supplied by user then get the names of columns only if schema is None: - schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in data.columns] + schema = [str(x) if not isinstance(x, basestring) else + (x.encode('utf-8') if not isinstance(x, str) else x) + for x in data.columns] if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \ and len(data) > 0: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 13576ff57001b..80a94a91a87b3 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3532,6 +3532,15 @@ def test_toPandas_with_array_type(self): self.assertTrue(expected[r][e] == result_arrow[r][e] and result[r][e] == result_arrow[r][e]) + def test_createDataFrame_with_int_col_names(self): + import numpy as np + import pandas as pd + pdf = pd.DataFrame(np.random.rand(4, 2)) + df, df_arrow = self._createDataFrame_toggle(pdf) + pdf_col_names = [str(c) for c in pdf.columns] + self.assertEqual(pdf_col_names, df.columns) + self.assertEqual(pdf_col_names, df_arrow.columns) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class PandasUDFTests(ReusedSQLTestCase): From edf0a48c2ec696b92ed6a96dcee6eeb1a046b20b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 10 Jan 2018 15:01:11 +0800 Subject: [PATCH 0056/2461] [SPARK-22982] Remove unsafe asynchronous close() call from FileDownloadChannel ## What changes were proposed in this pull request? This patch fixes a severe asynchronous IO bug in Spark's Netty-based file transfer code. At a high-level, the problem is that an unsafe asynchronous `close()` of a pipe's source channel creates a race condition where file transfer code closes a file descriptor then attempts to read from it. If the closed file descriptor's number has been reused by an `open()` call then this invalid read may cause unrelated file operations to return incorrect results. **One manifestation of this problem is incorrect query results.** For a high-level overview of how file download works, take a look at the control flow in `NettyRpcEnv.openChannel()`: this code creates a pipe to buffer results, then submits an asynchronous stream request to a lower-level TransportClient. The callback passes received data to the sink end of the pipe. The source end of the pipe is passed back to the caller of `openChannel()`. Thus `openChannel()` returns immediately and callers interact with the returned pipe source channel. Because the underlying stream request is asynchronous, errors may occur after `openChannel()` has returned and after that method's caller has started to `read()` from the returned channel. For example, if a client requests an invalid stream from a remote server then the "stream does not exist" error may not be received from the remote server until after `openChannel()` has returned. In order to be able to propagate the "stream does not exist" error to the file-fetching application thread, this code wraps the pipe's source channel in a special `FileDownloadChannel` which adds an `setError(t: Throwable)` method, then calls this `setError()` method in the FileDownloadCallback's `onFailure` method. It is possible for `FileDownloadChannel`'s `read()` and `setError()` methods to be called concurrently from different threads: the `setError()` method is called from within the Netty RPC system's stream callback handlers, while the `read()` methods are called from higher-level application code performing remote stream reads. The problem lies in `setError()`: the existing code closed the wrapped pipe source channel. Because `read()` and `setError()` occur in different threads, this means it is possible for one thread to be calling `source.read()` while another asynchronously calls `source.close()`. Java's IO libraries do not guarantee that this will be safe and, in fact, it's possible for these operations to interleave in such a way that a lower-level `read()` system call occurs right after a `close()` call. In the best-case, this fails as a read of a closed file descriptor; in the worst-case, the file descriptor number has been re-used by an intervening `open()` operation and the read corrupts the result of an unrelated file IO operation being performed by a different thread. The solution here is to remove the `stream.close()` call in `onError()`: the thread that is performing the `read()` calls is responsible for closing the stream in a `finally` block, so there's no need to close it here. If that thread is blocked in a `read()` then it will become unblocked when the sink end of the pipe is closed in `FileDownloadCallback.onFailure()`. After making this change, we also need to refine the `read()` method to always check for a `setError()` result, even if the underlying channel `read()` call has succeeded. This patch also makes a slight cleanup to a dodgy-looking `catch e: Exception` block to use a safer `try-finally` error handling idiom. This bug was introduced in SPARK-11956 / #9941 and is present in Spark 1.6.0+. ## How was this patch tested? This fix was tested manually against a workload which non-deterministically hit this bug. Author: Josh Rosen Closes #20179 from JoshRosen/SPARK-22982-fix-unsafe-async-io-in-file-download-channel. --- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 37 +++++++++++-------- .../shuffle/IndexShuffleBlockResolver.scala | 21 +++++++++-- 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index f951591e02a5c..a2936d6ad539c 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -332,16 +332,14 @@ private[netty] class NettyRpcEnv( val pipe = Pipe.open() val source = new FileDownloadChannel(pipe.source()) - try { + Utils.tryWithSafeFinallyAndFailureCallbacks(block = { val client = downloadClient(parsedUri.getHost(), parsedUri.getPort()) val callback = new FileDownloadCallback(pipe.sink(), source, client) client.stream(parsedUri.getPath(), callback) - } catch { - case e: Exception => - pipe.sink().close() - source.close() - throw e - } + })(catchBlock = { + pipe.sink().close() + source.close() + }) source } @@ -370,24 +368,33 @@ private[netty] class NettyRpcEnv( fileDownloadFactory.createClient(host, port) } - private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel { + private class FileDownloadChannel(source: Pipe.SourceChannel) extends ReadableByteChannel { @volatile private var error: Throwable = _ def setError(e: Throwable): Unit = { + // This setError callback is invoked by internal RPC threads in order to propagate remote + // exceptions to application-level threads which are reading from this channel. When an + // RPC error occurs, the RPC system will call setError() and then will close the + // Pipe.SinkChannel corresponding to the other end of the `source` pipe. Closing of the pipe + // sink will cause `source.read()` operations to return EOF, unblocking the application-level + // reading thread. Thus there is no need to actually call `source.close()` here in the + // onError() callback and, in fact, calling it here would be dangerous because the close() + // would be asynchronous with respect to the read() call and could trigger race-conditions + // that lead to data corruption. See the PR for SPARK-22982 for more details on this topic. error = e - source.close() } override def read(dst: ByteBuffer): Int = { Try(source.read(dst)) match { + // See the documentation above in setError(): if an RPC error has occurred then setError() + // will be called to propagate the RPC error and then `source`'s corresponding + // Pipe.SinkChannel will be closed, unblocking this read. In that case, we want to propagate + // the remote RPC exception (and not any exceptions triggered by the pipe close, such as + // ChannelClosedException), hence this `error != null` check: + case _ if error != null => throw error case Success(bytesRead) => bytesRead - case Failure(readErr) => - if (error != null) { - throw error - } else { - throw readErr - } + case Failure(readErr) => throw readErr } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 15540485170d0..266ee42e39cca 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -18,8 +18,8 @@ package org.apache.spark.shuffle import java.io._ - -import com.google.common.io.ByteStreams +import java.nio.channels.Channels +import java.nio.file.Files import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.internal.Logging @@ -196,11 +196,24 @@ private[spark] class IndexShuffleBlockResolver( // find out the consolidated file, then the offset within that from our index val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId) - val in = new DataInputStream(new FileInputStream(indexFile)) + // SPARK-22982: if this FileInputStream's position is seeked forward by another piece of code + // which is incorrectly using our file descriptor then this code will fetch the wrong offsets + // (which may cause a reducer to be sent a different reducer's data). The explicit position + // checks added here were a useful debugging aid during SPARK-22982 and may help prevent this + // class of issue from re-occurring in the future which is why they are left here even though + // SPARK-22982 is fixed. + val channel = Files.newByteChannel(indexFile.toPath) + channel.position(blockId.reduceId * 8) + val in = new DataInputStream(Channels.newInputStream(channel)) try { - ByteStreams.skipFully(in, blockId.reduceId * 8) val offset = in.readLong() val nextOffset = in.readLong() + val actualPosition = channel.position() + val expectedPosition = blockId.reduceId * 8 + 16 + if (actualPosition != expectedPosition) { + throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " + + s"expected $expectedPosition but actual position was $actualPosition.") + } new FileSegmentManagedBuffer( transportConf, getDataFile(blockId.shuffleId, blockId.mapId), From eaac60a1e20e29084b7151ffca964cfaa5ba99d1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 10 Jan 2018 15:16:27 +0800 Subject: [PATCH 0057/2461] [SPARK-16060][SQL][FOLLOW-UP] add a wrapper solution for vectorized orc reader ## What changes were proposed in this pull request? This is mostly from https://github.com/apache/spark/pull/13775 The wrapper solution is pretty good for string/binary type, as the ORC column vector doesn't keep bytes in a continuous memory region, and has a significant overhead when copying the data to Spark columnar batch. For other cases, the wrapper solution is almost same with the current solution. I think we can treat the wrapper solution as a baseline and keep improving the writing to Spark solution. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #20205 from cloud-fan/orc. --- .../apache/spark/sql/internal/SQLConf.scala | 7 + .../datasources/orc/OrcColumnVector.java | 251 ++++++++++++++++++ .../orc/OrcColumnarBatchReader.java | 106 ++++++-- .../datasources/orc/OrcFileFormat.scala | 6 +- .../spark/sql/hive/orc/OrcReadBenchmark.scala | 236 ++++++++++------ 5 files changed, 490 insertions(+), 116 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 74949db883f7a..36e802a9faa6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -391,6 +391,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val ORC_COPY_BATCH_TO_SPARK = buildConf("spark.sql.orc.copyBatchToSpark") + .doc("Whether or not to copy the ORC columnar batch to Spark columnar batch in the " + + "vectorized ORC reader.") + .internal() + .booleanConf + .createWithDefault(false) + val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java new file mode 100644 index 0000000000000..f94c55d860304 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import java.math.BigDecimal; + +import org.apache.orc.storage.ql.exec.vector.*; + +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A column vector class wrapping Hive's ColumnVector. Because Spark ColumnarBatch only accepts + * Spark's vectorized.ColumnVector, this column vector is used to adapt Hive ColumnVector with + * Spark ColumnarVector. + */ +public class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVector { + private ColumnVector baseData; + private LongColumnVector longData; + private DoubleColumnVector doubleData; + private BytesColumnVector bytesData; + private DecimalColumnVector decimalData; + private TimestampColumnVector timestampData; + final private boolean isTimestamp; + + private int batchSize; + + OrcColumnVector(DataType type, ColumnVector vector) { + super(type); + + if (type instanceof TimestampType) { + isTimestamp = true; + } else { + isTimestamp = false; + } + + baseData = vector; + if (vector instanceof LongColumnVector) { + longData = (LongColumnVector) vector; + } else if (vector instanceof DoubleColumnVector) { + doubleData = (DoubleColumnVector) vector; + } else if (vector instanceof BytesColumnVector) { + bytesData = (BytesColumnVector) vector; + } else if (vector instanceof DecimalColumnVector) { + decimalData = (DecimalColumnVector) vector; + } else if (vector instanceof TimestampColumnVector) { + timestampData = (TimestampColumnVector) vector; + } else { + throw new UnsupportedOperationException(); + } + } + + public void setBatchSize(int batchSize) { + this.batchSize = batchSize; + } + + @Override + public void close() { + + } + + @Override + public int numNulls() { + if (baseData.isRepeating) { + if (baseData.isNull[0]) { + return batchSize; + } else { + return 0; + } + } else if (baseData.noNulls) { + return 0; + } else { + int count = 0; + for (int i = 0; i < batchSize; i++) { + if (baseData.isNull[i]) count++; + } + return count; + } + } + + /* A helper method to get the row index in a column. */ + private int getRowIndex(int rowId) { + return baseData.isRepeating ? 0 : rowId; + } + + @Override + public boolean isNullAt(int rowId) { + return baseData.isNull[getRowIndex(rowId)]; + } + + @Override + public boolean getBoolean(int rowId) { + return longData.vector[getRowIndex(rowId)] == 1; + } + + @Override + public boolean[] getBooleans(int rowId, int count) { + boolean[] res = new boolean[count]; + for (int i = 0; i < count; i++) { + res[i] = getBoolean(rowId + i); + } + return res; + } + + @Override + public byte getByte(int rowId) { + return (byte) longData.vector[getRowIndex(rowId)]; + } + + @Override + public byte[] getBytes(int rowId, int count) { + byte[] res = new byte[count]; + for (int i = 0; i < count; i++) { + res[i] = getByte(rowId + i); + } + return res; + } + + @Override + public short getShort(int rowId) { + return (short) longData.vector[getRowIndex(rowId)]; + } + + @Override + public short[] getShorts(int rowId, int count) { + short[] res = new short[count]; + for (int i = 0; i < count; i++) { + res[i] = getShort(rowId + i); + } + return res; + } + + @Override + public int getInt(int rowId) { + return (int) longData.vector[getRowIndex(rowId)]; + } + + @Override + public int[] getInts(int rowId, int count) { + int[] res = new int[count]; + for (int i = 0; i < count; i++) { + res[i] = getInt(rowId + i); + } + return res; + } + + @Override + public long getLong(int rowId) { + int index = getRowIndex(rowId); + if (isTimestamp) { + return timestampData.time[index] * 1000 + timestampData.nanos[index] / 1000; + } else { + return longData.vector[index]; + } + } + + @Override + public long[] getLongs(int rowId, int count) { + long[] res = new long[count]; + for (int i = 0; i < count; i++) { + res[i] = getLong(rowId + i); + } + return res; + } + + @Override + public float getFloat(int rowId) { + return (float) doubleData.vector[getRowIndex(rowId)]; + } + + @Override + public float[] getFloats(int rowId, int count) { + float[] res = new float[count]; + for (int i = 0; i < count; i++) { + res[i] = getFloat(rowId + i); + } + return res; + } + + @Override + public double getDouble(int rowId) { + return doubleData.vector[getRowIndex(rowId)]; + } + + @Override + public double[] getDoubles(int rowId, int count) { + double[] res = new double[count]; + for (int i = 0; i < count; i++) { + res[i] = getDouble(rowId + i); + } + return res; + } + + @Override + public int getArrayLength(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public int getArrayOffset(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + BigDecimal data = decimalData.vector[getRowIndex(rowId)].getHiveDecimal().bigDecimalValue(); + return Decimal.apply(data, precision, scale); + } + + @Override + public UTF8String getUTF8String(int rowId) { + int index = getRowIndex(rowId); + BytesColumnVector col = bytesData; + return UTF8String.fromBytes(col.vector[index], col.start[index], col.length[index]); + } + + @Override + public byte[] getBinary(int rowId) { + int index = getRowIndex(rowId); + byte[] binary = new byte[bytesData.length[index]]; + System.arraycopy(bytesData.vector[index], bytesData.start[index], binary, 0, binary.length); + return binary; + } + + @Override + public org.apache.spark.sql.vectorized.ColumnVector arrayData() { + throw new UnsupportedOperationException(); + } + + @Override + public org.apache.spark.sql.vectorized.ColumnVector getChildColumn(int ordinal) { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index 5c28d0e6e507a..36fdf2bdf84d2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -51,13 +51,13 @@ public class OrcColumnarBatchReader extends RecordReader { /** - * The default size of batch. We use this value for both ORC and Spark consistently - * because they have different default values like the following. + * The default size of batch. We use this value for ORC reader to make it consistent with Spark's + * columnar batch, because their default batch sizes are different like the following: * * - ORC's VectorizedRowBatch.DEFAULT_SIZE = 1024 * - Spark's ColumnarBatch.DEFAULT_BATCH_SIZE = 4 * 1024 */ - public static final int DEFAULT_SIZE = 4 * 1024; + private static final int DEFAULT_SIZE = 4 * 1024; // ORC File Reader private Reader reader; @@ -82,13 +82,18 @@ public class OrcColumnarBatchReader extends RecordReader { // Writable column vectors of the result columnar batch. private WritableColumnVector[] columnVectors; - /** - * The memory mode of the columnarBatch - */ + // The wrapped ORC column vectors. It should be null if `copyToSpark` is true. + private org.apache.spark.sql.vectorized.ColumnVector[] orcVectorWrappers; + + // The memory mode of the columnarBatch private final MemoryMode MEMORY_MODE; - public OrcColumnarBatchReader(boolean useOffHeap) { + // Whether or not to copy the ORC columnar batch to Spark columnar batch. + private final boolean copyToSpark; + + public OrcColumnarBatchReader(boolean useOffHeap, boolean copyToSpark) { MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; + this.copyToSpark = copyToSpark; } @@ -167,27 +172,61 @@ public void initBatch( } int capacity = DEFAULT_SIZE; - if (MEMORY_MODE == MemoryMode.OFF_HEAP) { - columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); - } else { - columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); - } - columnarBatch = new ColumnarBatch(resultSchema, columnVectors, capacity); - if (partitionValues.numFields() > 0) { - int partitionIdx = requiredFields.length; - for (int i = 0; i < partitionValues.numFields(); i++) { - ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i); - columnVectors[i + partitionIdx].setIsConstant(); + if (copyToSpark) { + if (MEMORY_MODE == MemoryMode.OFF_HEAP) { + columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); + } else { + columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); } - } - // Initialize the missing columns once. - for (int i = 0; i < requiredFields.length; i++) { - if (requestedColIds[i] == -1) { - columnVectors[i].putNulls(0, columnarBatch.capacity()); - columnVectors[i].setIsConstant(); + // Initialize the missing columns once. + for (int i = 0; i < requiredFields.length; i++) { + if (requestedColIds[i] == -1) { + columnVectors[i].putNulls(0, capacity); + columnVectors[i].setIsConstant(); + } + } + + if (partitionValues.numFields() > 0) { + int partitionIdx = requiredFields.length; + for (int i = 0; i < partitionValues.numFields(); i++) { + ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i); + columnVectors[i + partitionIdx].setIsConstant(); + } + } + + columnarBatch = new ColumnarBatch(resultSchema, columnVectors, capacity); + } else { + // Just wrap the ORC column vector instead of copying it to Spark column vector. + orcVectorWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()]; + + for (int i = 0; i < requiredFields.length; i++) { + DataType dt = requiredFields[i].dataType(); + int colId = requestedColIds[i]; + // Initialize the missing columns once. + if (colId == -1) { + OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); + missingCol.putNulls(0, capacity); + missingCol.setIsConstant(); + orcVectorWrappers[i] = missingCol; + } else { + orcVectorWrappers[i] = new OrcColumnVector(dt, batch.cols[colId]); + } } + + if (partitionValues.numFields() > 0) { + int partitionIdx = requiredFields.length; + for (int i = 0; i < partitionValues.numFields(); i++) { + DataType dt = partitionSchema.fields()[i].dataType(); + OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt); + ColumnVectorUtils.populate(partitionCol, partitionValues, i); + partitionCol.setIsConstant(); + orcVectorWrappers[partitionIdx + i] = partitionCol; + } + } + + columnarBatch = new ColumnarBatch(resultSchema, orcVectorWrappers, capacity); } } @@ -196,17 +235,26 @@ public void initBatch( * by copying from ORC VectorizedRowBatch columns to Spark ColumnarBatch columns. */ private boolean nextBatch() throws IOException { - for (WritableColumnVector vector : columnVectors) { - vector.reset(); - } - columnarBatch.setNumRows(0); - recordReader.nextBatch(batch); int batchSize = batch.size; if (batchSize == 0) { return false; } columnarBatch.setNumRows(batchSize); + + if (!copyToSpark) { + for (int i = 0; i < requiredFields.length; i++) { + if (requestedColIds[i] != -1) { + ((OrcColumnVector) orcVectorWrappers[i]).setBatchSize(batchSize); + } + } + return true; + } + + for (WritableColumnVector vector : columnVectors) { + vector.reset(); + } + for (int i = 0; i < requiredFields.length; i++) { StructField field = requiredFields[i]; WritableColumnVector toColumn = columnVectors[i]; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index b8bacfa1838ae..2dd314d165348 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration @@ -150,6 +151,7 @@ class OrcFileFormat val sqlConf = sparkSession.sessionState.conf val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled val enableVectorizedReader = supportBatch(sparkSession, resultSchema) + val copyToSpark = sparkSession.sessionState.conf.getConf(SQLConf.ORC_COPY_BATCH_TO_SPARK) val broadcastedConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) @@ -183,8 +185,8 @@ class OrcFileFormat val taskContext = Option(TaskContext.get()) if (enableVectorizedReader) { - val batchReader = - new OrcColumnarBatchReader(enableOffHeapColumnVector && taskContext.isDefined) + val batchReader = new OrcColumnarBatchReader( + enableOffHeapColumnVector && taskContext.isDefined, copyToSpark) batchReader.initialize(fileSplit, taskAttemptContext) batchReader.initBatch( reader.getSchema, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala index 37ed846acd1eb..bf6efa7c4c08c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala @@ -86,7 +86,7 @@ object OrcReadBenchmark { } def numericScanBenchmark(values: Int, dataType: DataType): Unit = { - val sqlBenchmark = new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values) + val benchmark = new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values) withTempPath { dir => withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { @@ -95,61 +95,73 @@ object OrcReadBenchmark { prepareTable(dir, spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM t1")) - sqlBenchmark.addCase("Native ORC MR") { _ => + benchmark.addCase("Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() } } - sqlBenchmark.addCase("Native ORC Vectorized") { _ => + benchmark.addCase("Native ORC Vectorized") { _ => spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() } - sqlBenchmark.addCase("Hive built-in ORC") { _ => + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Hive built-in ORC") { _ => spark.sql("SELECT sum(id) FROM hiveOrcTable").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1192 / 1221 13.2 75.8 1.0X - Native ORC Vectorized 161 / 170 97.5 10.3 7.4X - Hive built-in ORC 1399 / 1413 11.2 89.0 0.9X + Native ORC MR 1135 / 1171 13.9 72.2 1.0X + Native ORC Vectorized 152 / 163 103.4 9.7 7.5X + Native ORC Vectorized with copy 149 / 162 105.4 9.5 7.6X + Hive built-in ORC 1380 / 1384 11.4 87.7 0.8X SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1287 / 1333 12.2 81.8 1.0X - Native ORC Vectorized 164 / 172 95.6 10.5 7.8X - Hive built-in ORC 1629 / 1650 9.7 103.6 0.8X + Native ORC MR 1182 / 1244 13.3 75.2 1.0X + Native ORC Vectorized 145 / 156 108.7 9.2 8.2X + Native ORC Vectorized with copy 148 / 158 106.4 9.4 8.0X + Hive built-in ORC 1591 / 1636 9.9 101.2 0.7X SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1304 / 1388 12.1 82.9 1.0X - Native ORC Vectorized 227 / 240 69.3 14.4 5.7X - Hive built-in ORC 1866 / 1867 8.4 118.6 0.7X + Native ORC MR 1271 / 1271 12.4 80.8 1.0X + Native ORC Vectorized 206 / 212 76.3 13.1 6.2X + Native ORC Vectorized with copy 200 / 213 78.8 12.7 6.4X + Hive built-in ORC 1776 / 1787 8.9 112.9 0.7X SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1331 / 1357 11.8 84.6 1.0X - Native ORC Vectorized 289 / 297 54.4 18.4 4.6X - Hive built-in ORC 1922 / 1929 8.2 122.2 0.7X + Native ORC MR 1344 / 1355 11.7 85.4 1.0X + Native ORC Vectorized 258 / 268 61.0 16.4 5.2X + Native ORC Vectorized with copy 252 / 257 62.4 16.0 5.3X + Hive built-in ORC 1818 / 1823 8.7 115.6 0.7X SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1410 / 1428 11.2 89.7 1.0X - Native ORC Vectorized 328 / 335 48.0 20.8 4.3X - Hive built-in ORC 1929 / 2012 8.2 122.6 0.7X + Native ORC MR 1333 / 1352 11.8 84.8 1.0X + Native ORC Vectorized 310 / 324 50.7 19.7 4.3X + Native ORC Vectorized with copy 312 / 320 50.4 19.9 4.3X + Hive built-in ORC 1904 / 1918 8.3 121.0 0.7X SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1467 / 1485 10.7 93.3 1.0X - Native ORC Vectorized 402 / 411 39.1 25.6 3.6X - Hive built-in ORC 2023 / 2042 7.8 128.6 0.7X + Native ORC MR 1408 / 1585 11.2 89.5 1.0X + Native ORC Vectorized 359 / 368 43.8 22.8 3.9X + Native ORC Vectorized with copy 364 / 371 43.2 23.2 3.9X + Hive built-in ORC 1881 / 1954 8.4 119.6 0.7X */ - sqlBenchmark.run() + benchmark.run() } } } @@ -176,19 +188,26 @@ object OrcReadBenchmark { spark.sql("SELECT sum(c1), sum(length(c2)) FROM nativeOrcTable").collect() } + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(c1), sum(length(c2)) FROM nativeOrcTable").collect() + } + } + benchmark.addCase("Hive built-in ORC") { _ => spark.sql("SELECT sum(c1), sum(length(c2)) FROM hiveOrcTable").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 2729 / 2744 3.8 260.2 1.0X - Native ORC Vectorized 1318 / 1344 8.0 125.7 2.1X - Hive built-in ORC 3731 / 3782 2.8 355.8 0.7X + Native ORC MR 2566 / 2592 4.1 244.7 1.0X + Native ORC Vectorized 1098 / 1113 9.6 104.7 2.3X + Native ORC Vectorized with copy 1527 / 1593 6.9 145.6 1.7X + Hive built-in ORC 3561 / 3705 2.9 339.6 0.7X */ benchmark.run() } @@ -205,63 +224,84 @@ object OrcReadBenchmark { prepareTable(dir, spark.sql("SELECT value % 2 AS p, value AS id FROM t1"), Some("p")) - benchmark.addCase("Read data column - Native ORC MR") { _ => + benchmark.addCase("Data column - Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() } } - benchmark.addCase("Read data column - Native ORC Vectorized") { _ => + benchmark.addCase("Data column - Native ORC Vectorized") { _ => spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() } - benchmark.addCase("Read data column - Hive built-in ORC") { _ => + benchmark.addCase("Data column - Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Data column - Hive built-in ORC") { _ => spark.sql("SELECT sum(id) FROM hiveOrcTable").collect() } - benchmark.addCase("Read partition column - Native ORC MR") { _ => + benchmark.addCase("Partition column - Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() } } - benchmark.addCase("Read partition column - Native ORC Vectorized") { _ => + benchmark.addCase("Partition column - Native ORC Vectorized") { _ => spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() } - benchmark.addCase("Read partition column - Hive built-in ORC") { _ => + benchmark.addCase("Partition column - Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Partition column - Hive built-in ORC") { _ => spark.sql("SELECT sum(p) FROM hiveOrcTable").collect() } - benchmark.addCase("Read both columns - Native ORC MR") { _ => + benchmark.addCase("Both columns - Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() } } - benchmark.addCase("Read both columns - Native ORC Vectorized") { _ => + benchmark.addCase("Both columns - Native ORC Vectorized") { _ => spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() } - benchmark.addCase("Read both columns - Hive built-in ORC") { _ => + benchmark.addCase("Both column - Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Both columns - Hive built-in ORC") { _ => spark.sql("SELECT sum(p), sum(id) FROM hiveOrcTable").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Read data column - Native ORC MR 1531 / 1536 10.3 97.4 1.0X - Read data column - Native ORC Vectorized 295 / 298 53.3 18.8 5.2X - Read data column - Hive built-in ORC 2125 / 2126 7.4 135.1 0.7X - Read partition column - Native ORC MR 1049 / 1062 15.0 66.7 1.5X - Read partition column - Native ORC Vectorized 54 / 57 290.1 3.4 28.2X - Read partition column - Hive built-in ORC 1282 / 1291 12.3 81.5 1.2X - Read both columns - Native ORC MR 1594 / 1598 9.9 101.3 1.0X - Read both columns - Native ORC Vectorized 332 / 336 47.4 21.1 4.6X - Read both columns - Hive built-in ORC 2145 / 2187 7.3 136.4 0.7X + Data only - Native ORC MR 1447 / 1457 10.9 92.0 1.0X + Data only - Native ORC Vectorized 256 / 266 61.4 16.3 5.6X + Data only - Native ORC Vectorized with copy 263 / 273 59.8 16.7 5.5X + Data only - Hive built-in ORC 1960 / 1988 8.0 124.6 0.7X + Partition only - Native ORC MR 1039 / 1043 15.1 66.0 1.4X + Partition only - Native ORC Vectorized 48 / 53 326.6 3.1 30.1X + Partition only - Native ORC Vectorized with copy 48 / 53 328.4 3.0 30.2X + Partition only - Hive built-in ORC 1234 / 1242 12.7 78.4 1.2X + Both columns - Native ORC MR 1465 / 1475 10.7 93.1 1.0X + Both columns - Native ORC Vectorized 292 / 301 53.9 18.6 5.0X + Both column - Native ORC Vectorized with copy 348 / 354 45.1 22.2 4.2X + Both columns - Hive built-in ORC 2051 / 2060 7.7 130.4 0.7X */ benchmark.run() } @@ -287,19 +327,26 @@ object OrcReadBenchmark { spark.sql("SELECT sum(length(c1)) FROM nativeOrcTable").collect() } + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(length(c1)) FROM nativeOrcTable").collect() + } + } + benchmark.addCase("Hive built-in ORC") { _ => spark.sql("SELECT sum(length(c1)) FROM hiveOrcTable").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1325 / 1328 7.9 126.4 1.0X - Native ORC Vectorized 320 / 330 32.8 30.5 4.1X - Hive built-in ORC 1971 / 1972 5.3 188.0 0.7X + Native ORC MR 1271 / 1278 8.3 121.2 1.0X + Native ORC Vectorized 200 / 212 52.4 19.1 6.4X + Native ORC Vectorized with copy 342 / 347 30.7 32.6 3.7X + Hive built-in ORC 1874 / 2105 5.6 178.7 0.7X */ benchmark.run() } @@ -331,32 +378,42 @@ object OrcReadBenchmark { "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() } + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT SUM(LENGTH(c2)) FROM nativeOrcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + } + benchmark.addCase("Hive built-in ORC") { _ => spark.sql("SELECT SUM(LENGTH(c2)) FROM hiveOrcTable " + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 2553 / 2554 4.1 243.4 1.0X - Native ORC Vectorized 953 / 954 11.0 90.9 2.7X - Hive built-in ORC 3875 / 3898 2.7 369.6 0.7X + Native ORC MR 2394 / 2886 4.4 228.3 1.0X + Native ORC Vectorized 699 / 729 15.0 66.7 3.4X + Native ORC Vectorized with copy 959 / 1025 10.9 91.5 2.5X + Hive built-in ORC 3899 / 3901 2.7 371.9 0.6X String with Nulls Scan (0.5%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 2389 / 2408 4.4 227.8 1.0X - Native ORC Vectorized 1208 / 1209 8.7 115.2 2.0X - Hive built-in ORC 2940 / 2952 3.6 280.4 0.8X + Native ORC MR 2234 / 2255 4.7 213.1 1.0X + Native ORC Vectorized 854 / 869 12.3 81.4 2.6X + Native ORC Vectorized with copy 1099 / 1128 9.5 104.8 2.0X + Hive built-in ORC 2767 / 2793 3.8 263.9 0.8X String with Nulls Scan (0.95%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1295 / 1311 8.1 123.5 1.0X - Native ORC Vectorized 449 / 457 23.4 42.8 2.9X - Hive built-in ORC 1649 / 1660 6.4 157.3 0.8X + Native ORC MR 1166 / 1202 9.0 111.2 1.0X + Native ORC Vectorized 338 / 345 31.1 32.2 3.5X + Native ORC Vectorized with copy 418 / 428 25.1 39.9 2.8X + Hive built-in ORC 1730 / 1761 6.1 164.9 0.7X */ benchmark.run() } @@ -364,7 +421,7 @@ object OrcReadBenchmark { } def columnsBenchmark(values: Int, width: Int): Unit = { - val sqlBenchmark = new Benchmark(s"SQL Single Column Scan from $width columns", values) + val benchmark = new Benchmark(s"Single Column Scan from $width columns", values) withTempPath { dir => withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { @@ -376,43 +433,52 @@ object OrcReadBenchmark { prepareTable(dir, spark.sql("SELECT * FROM t1")) - sqlBenchmark.addCase("Native ORC MR") { _ => + benchmark.addCase("Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() } } - sqlBenchmark.addCase("Native ORC Vectorized") { _ => + benchmark.addCase("Native ORC Vectorized") { _ => spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() } - sqlBenchmark.addCase("Hive built-in ORC") { _ => + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Hive built-in ORC") { _ => spark.sql(s"SELECT sum(c$middle) FROM hiveOrcTable").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - SQL Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1103 / 1124 1.0 1052.0 1.0X - Native ORC Vectorized 92 / 100 11.4 87.9 12.0X - Hive built-in ORC 383 / 390 2.7 365.4 2.9X + Native ORC MR 1050 / 1053 1.0 1001.1 1.0X + Native ORC Vectorized 95 / 101 11.0 90.9 11.0X + Native ORC Vectorized with copy 95 / 102 11.0 90.9 11.0X + Hive built-in ORC 348 / 358 3.0 331.8 3.0X - SQL Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 2245 / 2250 0.5 2141.0 1.0X - Native ORC Vectorized 157 / 165 6.7 150.2 14.3X - Hive built-in ORC 587 / 593 1.8 559.4 3.8X + Native ORC MR 2099 / 2108 0.5 2002.1 1.0X + Native ORC Vectorized 179 / 187 5.8 171.1 11.7X + Native ORC Vectorized with copy 176 / 188 6.0 167.6 11.9X + Hive built-in ORC 562 / 581 1.9 535.9 3.7X - SQL Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 3343 / 3350 0.3 3188.3 1.0X - Native ORC Vectorized 265 / 280 3.9 253.2 12.6X - Hive built-in ORC 828 / 842 1.3 789.8 4.0X + Native ORC MR 3221 / 3246 0.3 3071.4 1.0X + Native ORC Vectorized 312 / 322 3.4 298.0 10.3X + Native ORC Vectorized with copy 306 / 320 3.4 291.6 10.5X + Hive built-in ORC 815 / 824 1.3 777.3 4.0X */ - sqlBenchmark.run() + benchmark.run() } } } From 70bcc9d5ae33d6669bb5c97db29087ccead770fb Mon Sep 17 00:00:00 2001 From: sethah Date: Tue, 9 Jan 2018 23:32:47 -0800 Subject: [PATCH 0058/2461] [SPARK-22993][ML] Clarify HasCheckpointInterval param doc ## What changes were proposed in this pull request? Add a note to the `HasCheckpointInterval` parameter doc that clarifies that this setting is ignored when no checkpoint directory has been set on the spark context. ## How was this patch tested? No tests necessary, just a doc update. Author: sethah Closes #20188 from sethah/als_checkpoint_doc. --- R/pkg/R/mllib_recommendation.R | 2 ++ R/pkg/R/mllib_tree.R | 6 ++++++ .../apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 4 +++- .../org/apache/spark/ml/param/shared/sharedParams.scala | 4 ++-- python/pyspark/ml/param/_shared_params_code_gen.py | 5 +++-- python/pyspark/ml/param/shared.py | 4 ++-- 6 files changed, 18 insertions(+), 7 deletions(-) diff --git a/R/pkg/R/mllib_recommendation.R b/R/pkg/R/mllib_recommendation.R index fa794249085d7..5441c4a4022a9 100644 --- a/R/pkg/R/mllib_recommendation.R +++ b/R/pkg/R/mllib_recommendation.R @@ -48,6 +48,8 @@ setClass("ALSModel", representation(jobj = "jobj")) #' @param numUserBlocks number of user blocks used to parallelize computation (> 0). #' @param numItemBlocks number of item blocks used to parallelize computation (> 0). #' @param checkpointInterval number of checkpoint intervals (>= 1) or disable checkpoint (-1). +#' Note: this setting will be ignored if the checkpoint directory is not +#' set. #' @param ... additional argument(s) passed to the method. #' @return \code{spark.als} returns a fitted ALS model. #' @rdname spark.als diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 89a58bf0aadae..4e5ddf22ee16d 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -161,6 +161,8 @@ print.summary.decisionTree <- function(x) { #' >= 1. #' @param minInfoGain Minimum information gain for a split to be considered at a tree node. #' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' Note: this setting will be ignored if the checkpoint directory is not +#' set. #' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. #' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching @@ -382,6 +384,8 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' @param minInstancesPerNode Minimum number of instances each child must have after split. #' @param minInfoGain Minimum information gain for a split to be considered at a tree node. #' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' Note: this setting will be ignored if the checkpoint directory is not +#' set. #' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. #' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching @@ -595,6 +599,8 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path #' @param minInstancesPerNode Minimum number of instances each child must have after split. #' @param minInfoGain Minimum information gain for a split to be considered at a tree node. #' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' Note: this setting will be ignored if the checkpoint directory is not +#' set. #' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. #' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index a5d57a15317e6..6ad44af9ef7eb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -63,7 +63,9 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Array[String]]("outputCols", "output column names"), ParamDesc[Int]("checkpointInterval", "set checkpoint interval (>= 1) or " + "disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed " + - "every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"), + "every 10 iterations. Note: this setting will be ignored if the checkpoint directory " + + "is not set in the SparkContext", + isValid = "(interval: Int) => interval == -1 || interval >= 1"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " + "will filter out rows with bad values), or error (which will throw an error). More " + diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 13425dacc9f18..be8b2f273164b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -282,10 +282,10 @@ trait HasOutputCols extends Params { trait HasCheckpointInterval extends Params { /** - * Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. + * Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext. * @group param */ - final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations", (interval: Int) => interval == -1 || interval >= 1) + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext", (interval: Int) => interval == -1 || interval >= 1) /** @group getParam */ final def getCheckpointInterval: Int = $(checkpointInterval) diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index d55d209d09398..1d0f60acc6983 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -121,8 +121,9 @@ def get$Name(self): ("outputCol", "output column name.", "self.uid + '__output'", "TypeConverters.toString"), ("numFeatures", "number of features.", None, "TypeConverters.toInt"), ("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " + - "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None, - "TypeConverters.toInt"), + "E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: " + + "this setting will be ignored if the checkpoint directory is not set in the SparkContext.", + None, "TypeConverters.toInt"), ("seed", "random seed.", "hash(type(self).__name__)", "TypeConverters.toInt"), ("tol", "the convergence tolerance for iterative algorithms (>= 0).", None, "TypeConverters.toFloat"), diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index e5c5ddfba6c1f..813f7a59f3fd1 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -281,10 +281,10 @@ def getNumFeatures(self): class HasCheckpointInterval(Params): """ - Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. + Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext. """ - checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", typeConverter=TypeConverters.toInt) + checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.", typeConverter=TypeConverters.toInt) def __init__(self): super(HasCheckpointInterval, self).__init__() From f340b6b3066033d40b7e163fd5fb68e9820adfb1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 10 Jan 2018 00:45:47 -0800 Subject: [PATCH 0059/2461] [SPARK-22997] Add additional defenses against use of freed MemoryBlocks ## What changes were proposed in this pull request? This patch modifies Spark's `MemoryAllocator` implementations so that `free(MemoryBlock)` mutates the passed block to clear pointers (in the off-heap case) or null out references to backing `long[]` arrays (in the on-heap case). The goal of this change is to add an extra layer of defense against use-after-free bugs because currently it's hard to detect corruption caused by blind writes to freed memory blocks. ## How was this patch tested? New unit tests in `PlatformSuite`, including new tests for existing functionality because we did not have sufficient mutation coverage of the on-heap memory allocator's pooling logic. Author: Josh Rosen Closes #20191 from JoshRosen/SPARK-22997-add-defenses-against-use-after-free-bugs-in-memory-allocator. --- .../unsafe/memory/HeapMemoryAllocator.java | 35 +++++++++---- .../spark/unsafe/memory/MemoryBlock.java | 21 +++++++- .../unsafe/memory/UnsafeMemoryAllocator.java | 11 ++++ .../spark/unsafe/PlatformUtilSuite.java | 50 ++++++++++++++++++- .../spark/memory/TaskMemoryManager.java | 13 ++++- .../spark/memory/TaskMemoryManagerSuite.java | 29 +++++++++++ 6 files changed, 146 insertions(+), 13 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index cc9cc429643ad..3acfe3696cb1e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -31,8 +31,7 @@ public class HeapMemoryAllocator implements MemoryAllocator { @GuardedBy("this") - private final Map>> bufferPoolsBySize = - new HashMap<>(); + private final Map>> bufferPoolsBySize = new HashMap<>(); private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024; @@ -49,13 +48,14 @@ private boolean shouldPool(long size) { public MemoryBlock allocate(long size) throws OutOfMemoryError { if (shouldPool(size)) { synchronized (this) { - final LinkedList> pool = bufferPoolsBySize.get(size); + final LinkedList> pool = bufferPoolsBySize.get(size); if (pool != null) { while (!pool.isEmpty()) { - final WeakReference blockReference = pool.pop(); - final MemoryBlock memory = blockReference.get(); - if (memory != null) { - assert (memory.size() == size); + final WeakReference arrayReference = pool.pop(); + final long[] array = arrayReference.get(); + if (array != null) { + assert (array.length * 8L >= size); + MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -76,18 +76,35 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { @Override public void free(MemoryBlock memory) { + assert (memory.obj != null) : + "baseObject was null; are you trying to use the on-heap allocator to free off-heap memory?"; + assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + "page has already been freed"; + assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) + || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator free()"; + final long size = memory.size(); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); } + + // Mark the page as freed (so we can detect double-frees). + memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; + + // As an additional layer of defense against use-after-free bugs, we mutate the + // MemoryBlock to null out its reference to the long[] array. + long[] array = (long[]) memory.obj; + memory.setObjAndOffset(null, 0); + if (shouldPool(size)) { synchronized (this) { - LinkedList> pool = bufferPoolsBySize.get(size); + LinkedList> pool = bufferPoolsBySize.get(size); if (pool == null) { pool = new LinkedList<>(); bufferPoolsBySize.put(size, pool); } - pool.add(new WeakReference<>(memory)); + pool.add(new WeakReference<>(array)); } } else { // Do nothing diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index cd1d378bc1470..c333857358d30 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -26,6 +26,25 @@ */ public class MemoryBlock extends MemoryLocation { + /** Special `pageNumber` value for pages which were not allocated by TaskMemoryManagers */ + public static final int NO_PAGE_NUMBER = -1; + + /** + * Special `pageNumber` value for marking pages that have been freed in the TaskMemoryManager. + * We set `pageNumber` to this value in TaskMemoryManager.freePage() so that MemoryAllocator + * can detect if pages which were allocated by TaskMemoryManager have been freed in the TMM + * before being passed to MemoryAllocator.free() (it is an error to allocate a page in + * TaskMemoryManager and then directly free it in a MemoryAllocator without going through + * the TMM freePage() call). + */ + public static final int FREED_IN_TMM_PAGE_NUMBER = -2; + + /** + * Special `pageNumber` value for pages that have been freed by the MemoryAllocator. This allows + * us to detect double-frees. + */ + public static final int FREED_IN_ALLOCATOR_PAGE_NUMBER = -3; + private final long length; /** @@ -33,7 +52,7 @@ public class MemoryBlock extends MemoryLocation { * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager, * which lives in a different package. */ - public int pageNumber = -1; + public int pageNumber = NO_PAGE_NUMBER; public MemoryBlock(@Nullable Object obj, long offset, long length) { super(obj, offset); diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java index 55bcdf1ed7b06..4368fb615ba1e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -38,9 +38,20 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { public void free(MemoryBlock memory) { assert (memory.obj == null) : "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; + assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + "page has already been freed"; + assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) + || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + "TMM-allocated pages must be freed via TMM.freePage(), not directly in allocator free()"; + if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); } Platform.freeMemory(memory.offset); + // As an additional layer of defense against use-after-free bugs, we mutate the + // MemoryBlock to reset its pointer. + memory.offset = 0; + // Mark the page as freed (so we can detect double-frees). + memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 4b141339ec816..62854837b05ed 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -62,6 +62,52 @@ public void overlappingCopyMemory() { } } + @Test + public void onHeapMemoryAllocatorPoolingReUsesLongArrays() { + MemoryBlock block1 = MemoryAllocator.HEAP.allocate(1024 * 1024); + Object baseObject1 = block1.getBaseObject(); + MemoryAllocator.HEAP.free(block1); + MemoryBlock block2 = MemoryAllocator.HEAP.allocate(1024 * 1024); + Object baseObject2 = block2.getBaseObject(); + Assert.assertSame(baseObject1, baseObject2); + MemoryAllocator.HEAP.free(block2); + } + + @Test + public void freeingOnHeapMemoryBlockResetsBaseObjectAndOffset() { + MemoryBlock block = MemoryAllocator.HEAP.allocate(1024); + Assert.assertNotNull(block.getBaseObject()); + MemoryAllocator.HEAP.free(block); + Assert.assertNull(block.getBaseObject()); + Assert.assertEquals(0, block.getBaseOffset()); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber); + } + + @Test + public void freeingOffHeapMemoryBlockResetsOffset() { + MemoryBlock block = MemoryAllocator.UNSAFE.allocate(1024); + Assert.assertNull(block.getBaseObject()); + Assert.assertNotEquals(0, block.getBaseOffset()); + MemoryAllocator.UNSAFE.free(block); + Assert.assertNull(block.getBaseObject()); + Assert.assertEquals(0, block.getBaseOffset()); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber); + } + + @Test(expected = AssertionError.class) + public void onHeapMemoryAllocatorThrowsAssertionErrorOnDoubleFree() { + MemoryBlock block = MemoryAllocator.HEAP.allocate(1024); + MemoryAllocator.HEAP.free(block); + MemoryAllocator.HEAP.free(block); + } + + @Test(expected = AssertionError.class) + public void offHeapMemoryAllocatorThrowsAssertionErrorOnDoubleFree() { + MemoryBlock block = MemoryAllocator.UNSAFE.allocate(1024); + MemoryAllocator.UNSAFE.free(block); + MemoryAllocator.UNSAFE.free(block); + } + @Test public void memoryDebugFillEnabledInTest() { Assert.assertTrue(MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED); @@ -71,9 +117,11 @@ public void memoryDebugFillEnabledInTest() { MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); MemoryBlock onheap1 = MemoryAllocator.HEAP.allocate(1024 * 1024); + Object onheap1BaseObject = onheap1.getBaseObject(); + long onheap1BaseOffset = onheap1.getBaseOffset(); MemoryAllocator.HEAP.free(onheap1); Assert.assertEquals( - Platform.getByte(onheap1.getBaseObject(), onheap1.getBaseOffset()), + Platform.getByte(onheap1BaseObject, onheap1BaseOffset), MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); MemoryBlock onheap2 = MemoryAllocator.HEAP.allocate(1024 * 1024); Assert.assertEquals( diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index e8d3730daa7a4..632d718062212 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -321,8 +321,12 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}. */ public void freePage(MemoryBlock page, MemoryConsumer consumer) { - assert (page.pageNumber != -1) : + assert (page.pageNumber != MemoryBlock.NO_PAGE_NUMBER) : "Called freePage() on memory that wasn't allocated with allocatePage()"; + assert (page.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + "Called freePage() on a memory block that has already been freed"; + assert (page.pageNumber != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) : + "Called freePage() on a memory block that has already been freed"; assert(allocatedPages.get(page.pageNumber)); pageTable[page.pageNumber] = null; synchronized (this) { @@ -332,6 +336,10 @@ public void freePage(MemoryBlock page, MemoryConsumer consumer) { logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); } long pageSize = page.size(); + // Clear the page number before passing the block to the MemoryAllocator's free(). + // Doing this allows the MemoryAllocator to detect when a TaskMemoryManager-managed + // page has been inappropriately directly freed without calling TMM.freePage(). + page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; memoryManager.tungstenMemoryAllocator().free(page); releaseExecutionMemory(pageSize, consumer); } @@ -358,7 +366,7 @@ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { @VisibleForTesting public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) { - assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; + assert (pageNumber >= 0) : "encodePageNumberAndOffset called with invalid page"; return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS); } @@ -424,6 +432,7 @@ public long cleanUpAllAllocatedMemory() { for (MemoryBlock page : pageTable) { if (page != null) { logger.debug("unreleased page: " + page + " in task " + taskAttemptId); + page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; memoryManager.tungstenMemoryAllocator().free(page); } } diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index 46b0516e36141..a0664b30d6cc2 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -21,6 +21,7 @@ import org.junit.Test; import org.apache.spark.SparkConf; +import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; public class TaskMemoryManagerSuite { @@ -68,6 +69,34 @@ public void encodePageNumberAndOffsetOnHeap() { Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress)); } + @Test + public void freeingPageSetsPageNumberToSpecialConstant() { + final TaskMemoryManager manager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); + final MemoryBlock dataPage = manager.allocatePage(256, c); + c.freePage(dataPage); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.pageNumber); + } + + @Test(expected = AssertionError.class) + public void freeingPageDirectlyInAllocatorTriggersAssertionError() { + final TaskMemoryManager manager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); + final MemoryBlock dataPage = manager.allocatePage(256, c); + MemoryAllocator.HEAP.free(dataPage); + } + + @Test(expected = AssertionError.class) + public void callingFreePageOnDirectlyAllocatedPageTriggersAssertionError() { + final TaskMemoryManager manager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); + final MemoryBlock dataPage = MemoryAllocator.HEAP.allocate(256); + manager.freePage(dataPage, c); + } + @Test public void cooperativeSpilling() { final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); From 344e3aab87178e45957333479a07e07f202ca1fd Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Wed, 10 Jan 2018 09:44:30 -0800 Subject: [PATCH 0060/2461] [SPARK-23019][CORE] Wait until SparkContext.stop() finished in SparkLauncherSuite ## What changes were proposed in this pull request? In current code ,the function `waitFor` call https://github.com/apache/spark/blob/cfcd746689c2b84824745fa6d327ffb584c7a17d/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java#L155 only wait until DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet. https://github.com/apache/spark/blob/1c9f95cb771ac78775a77edd1abfeb2d8ae2a124/core/src/main/scala/org/apache/spark/SparkContext.scala#L1924 Thus, in the Jenkins test https://amplab.cs.berkeley.edu/jenkins/job/spark-branch-2.3-test-maven-hadoop-2.6/ , `JdbcRDDSuite` failed because the previous test `SparkLauncherSuite` exit before SparkContext.stop() is finished. To repo: ``` $ build/sbt > project core > testOnly *SparkLauncherSuite *JavaJdbcRDDSuite ``` To Fix: Wait for a reasonable amount of time to avoid creating two active SparkContext in JVM in SparkLauncherSuite. Can' come up with any better solution for now. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #20221 from gengliangwang/SPARK-23019. --- .../java/org/apache/spark/launcher/SparkLauncherSuite.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index c2261c204cd45..9d2f563b2e367 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Map; import java.util.Properties; +import java.util.concurrent.TimeUnit; import org.junit.Test; import static org.junit.Assert.*; @@ -133,6 +134,10 @@ public void testInProcessLauncher() throws Exception { p.put(e.getKey(), e.getValue()); } System.setProperties(p); + // Here DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet. + // Wait for a reasonable amount of time to avoid creating two active SparkContext in JVM. + // See SPARK-23019 and SparkContext.stop() for details. + TimeUnit.MILLISECONDS.sleep(500); } } From 9b33dfc408de986f4203bb0ac0c3f5c56effd69d Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Wed, 10 Jan 2018 14:25:04 -0800 Subject: [PATCH 0061/2461] [SPARK-22951][SQL] fix aggregation after dropDuplicates on empty data frames ## What changes were proposed in this pull request? (courtesy of liancheng) Spark SQL supports both global aggregation and grouping aggregation. Global aggregation always return a single row with the initial aggregation state as the output, even there are zero input rows. Spark implements this by simply checking the number of grouping keys and treats an aggregation as a global aggregation if it has zero grouping keys. However, this simple principle drops the ball in the following case: ```scala spark.emptyDataFrame.dropDuplicates().agg(count($"*") as "c").show() // +---+ // | c | // +---+ // | 1 | // +---+ ``` The reason is that: 1. `df.dropDuplicates()` is roughly translated into something equivalent to: ```scala val allColumns = df.columns.map { col } df.groupBy(allColumns: _*).agg(allColumns.head, allColumns.tail: _*) ``` This translation is implemented in the rule `ReplaceDeduplicateWithAggregate`. 2. `spark.emptyDataFrame` contains zero columns and zero rows. Therefore, rule `ReplaceDeduplicateWithAggregate` makes a confusing transformation roughly equivalent to the following one: ```scala spark.emptyDataFrame.dropDuplicates() => spark.emptyDataFrame.groupBy().agg(Map.empty[String, String]) ``` The above transformation is confusing because the resulting aggregate operator contains no grouping keys (because `emptyDataFrame` contains no columns), and gets recognized as a global aggregation. As a result, Spark SQL allocates a single row filled by the initial aggregation state and uses it as the output, and returns a wrong result. To fix this issue, this PR tweaks `ReplaceDeduplicateWithAggregate` by appending a literal `1` to the grouping key list of the resulting `Aggregate` operator when the input plan contains zero output columns. In this way, `spark.emptyDataFrame.dropDuplicates()` is now translated into a grouping aggregation, roughly depicted as: ```scala spark.emptyDataFrame.dropDuplicates() => spark.emptyDataFrame.groupBy(lit(1)).agg(Map.empty[String, String]) ``` Which is now properly treated as a grouping aggregation and returns the correct answer. ## How was this patch tested? New unit tests added Author: Feng Liu Closes #20174 from liufengdb/fix-duplicate. --- .../sql/catalyst/optimizer/Optimizer.scala | 8 ++++++- .../optimizer/ReplaceOperatorSuite.scala | 10 +++++++- .../spark/sql/DataFrameAggregateSuite.scala | 24 +++++++++++++++++-- 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index df0af8264a329..c794ba8619322 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1222,7 +1222,13 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] { Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId) } } - Aggregate(keys, aggCols, child) + // SPARK-22951: Physical aggregate operators distinguishes global aggregation and grouping + // aggregations by checking the number of grouping keys. The key difference here is that a + // global aggregation always returns at least one row even if there are no input rows. Here + // we append a literal when the grouping key list is empty so that the result aggregate + // operator is properly treated as a grouping aggregation. + val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys + Aggregate(nonemptyKeys, aggCols, child) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index 0fa1aaeb9e164..e9701ffd2c54b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Not} +import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, Not} import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ @@ -198,6 +198,14 @@ class ReplaceOperatorSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("add one grouping key if necessary when replace Deduplicate with Aggregate") { + val input = LocalRelation() + val query = Deduplicate(Seq.empty, input) // dropDuplicates() + val optimized = Optimize.execute(query.analyze) + val correctAnswer = Aggregate(Seq(Literal(1)), input.output, input) + comparePlans(optimized, correctAnswer) + } + test("don't replace streaming Deduplicate") { val input = LocalRelation(Seq('a.int, 'b.int), isStreaming = true) val attrA = input.output(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 06848e4d2b297..e7776e36702ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import scala.util.Random +import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -27,7 +29,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData.DecimalData -import org.apache.spark.sql.types.{Decimal, DecimalType} +import org.apache.spark.sql.types.DecimalType case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double) @@ -456,7 +458,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("null moments") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") - checkAnswer( emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), Row(null, null, null, null, null)) @@ -666,4 +667,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { assert(exchangePlans.length == 1) } } + + Seq(true, false).foreach { codegen => + test("SPARK-22951: dropDuplicates on empty dataFrames should produce correct aggregate " + + s"results when codegen is enabled: $codegen") { + withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, codegen.toString)) { + // explicit global aggregations + val emptyAgg = Map.empty[String, String] + checkAnswer(spark.emptyDataFrame.agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.groupBy().agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.groupBy().agg(count("*")), Seq(Row(0))) + checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(count("*")), Seq(Row(0))) + + // global aggregation is converted to grouping aggregation: + assert(spark.emptyDataFrame.dropDuplicates().count() == 0) + } + } + } } From a6647ffbf7a312a3e119a9beef90880cc915aa60 Mon Sep 17 00:00:00 2001 From: Mingjie Tang Date: Thu, 11 Jan 2018 11:51:03 +0800 Subject: [PATCH 0062/2461] [SPARK-22587] Spark job fails if fs.defaultFS and application jar are different url ## What changes were proposed in this pull request? Two filesystems comparing does not consider the authority of URI. This is specific for WASB file storage system, where userInfo is honored to differentiate filesystems. For example: wasbs://user1xyz.net, wasbs://user2xyz.net would consider as two filesystem. Therefore, we have to add the authority to compare two filesystem, and two filesystem with different authority can not be the same FS. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Mingjie Tang Closes #19885 from merlintang/EAR-7377. --- .../org/apache/spark/deploy/yarn/Client.scala | 24 +++++++++++--- .../spark/deploy/yarn/ClientSuite.scala | 33 +++++++++++++++++++ 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 15328d08b3b5c..8cd3cd9746a3a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1421,15 +1421,20 @@ private object Client extends Logging { } /** - * Return whether the two file systems are the same. + * Return whether two URI represent file system are the same */ - private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = { - val srcUri = srcFs.getUri() - val dstUri = destFs.getUri() + private[spark] def compareUri(srcUri: URI, dstUri: URI): Boolean = { + if (srcUri.getScheme() == null || srcUri.getScheme() != dstUri.getScheme()) { return false } + val srcAuthority = srcUri.getAuthority() + val dstAuthority = dstUri.getAuthority() + if (srcAuthority != null && !srcAuthority.equalsIgnoreCase(dstAuthority)) { + return false + } + var srcHost = srcUri.getHost() var dstHost = dstUri.getHost() @@ -1447,6 +1452,17 @@ private object Client extends Logging { } Objects.equal(srcHost, dstHost) && srcUri.getPort() == dstUri.getPort() + + } + + /** + * Return whether the two file systems are the same. + */ + protected def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = { + val srcUri = srcFs.getUri() + val dstUri = destFs.getUri() + + compareUri(srcUri, dstUri) } /** diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 9d5f5eb621118..7fa597167f3f0 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -357,6 +357,39 @@ class ClientSuite extends SparkFunSuite with Matchers { sparkConf.get(SECONDARY_JARS) should be (Some(Seq(new File(jar2.toURI).getName))) } + private val matching = Seq( + ("files URI match test1", "file:///file1", "file:///file2"), + ("files URI match test2", "file:///c:file1", "file://c:file2"), + ("files URI match test3", "file://host/file1", "file://host/file2"), + ("wasb URI match test", "wasb://bucket1@user", "wasb://bucket1@user/"), + ("hdfs URI match test", "hdfs:/path1", "hdfs:/path1") + ) + + matching.foreach { t => + test(t._1) { + assert(Client.compareUri(new URI(t._2), new URI(t._3)), + s"No match between ${t._2} and ${t._3}") + } + } + + private val unmatching = Seq( + ("files URI unmatch test1", "file:///file1", "file://host/file2"), + ("files URI unmatch test2", "file://host/file1", "file:///file2"), + ("files URI unmatch test3", "file://host/file1", "file://host2/file2"), + ("wasb URI unmatch test1", "wasb://bucket1@user", "wasb://bucket2@user/"), + ("wasb URI unmatch test2", "wasb://bucket1@user", "wasb://bucket1@user2/"), + ("s3 URI unmatch test", "s3a://user@pass:bucket1/", "s3a://user2@pass2:bucket1/"), + ("hdfs URI unmatch test1", "hdfs://namenode1/path1", "hdfs://namenode1:8080/path2"), + ("hdfs URI unmatch test2", "hdfs://namenode1:8020/path1", "hdfs://namenode1:8080/path2") + ) + + unmatching.foreach { t => + test(t._1) { + assert(!Client.compareUri(new URI(t._2), new URI(t._3)), + s"match between ${t._2} and ${t._3}") + } + } + object Fixtures { val knownDefYarnAppCP: Seq[String] = From 87c98de8b23f0e978958fc83677fdc4c339b7e6a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 11 Jan 2018 18:17:34 +0800 Subject: [PATCH 0063/2461] [SPARK-23001][SQL] Fix NullPointerException when DESC a database with NULL description ## What changes were proposed in this pull request? When users' DB description is NULL, users might hit `NullPointerException`. This PR is to fix the issue. ## How was this patch tested? Added test cases Author: gatorsmile Closes #20215 from gatorsmile/SPARK-23001. --- .../apache/spark/sql/hive/client/HiveClientImpl.scala | 2 +- .../apache/spark/sql/hive/HiveExternalCatalogSuite.scala | 6 ++++++ .../org/apache/spark/sql/hive/client/VersionsSuite.scala | 9 +++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 102f40bacc985..4b923f5235a90 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -330,7 +330,7 @@ private[hive] class HiveClientImpl( Option(client.getDatabase(dbName)).map { d => CatalogDatabase( name = d.getName, - description = d.getDescription, + description = Option(d.getDescription).getOrElse(""), locationUri = CatalogUtils.stringToURI(d.getLocationUri), properties = Option(d.getParameters).map(_.asScala.toMap).orNull) }.getOrElse(throw new NoSuchDatabaseException(dbName)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index 2e35fdeba464d..0a522b6a11c80 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -107,4 +107,10 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { .filter(_.contains("Num Buckets")).head assert(bucketString.contains("10")) } + + test("SPARK-23001: NullPointerException when running desc database") { + val catalog = newBasicCatalog() + catalog.createDatabase(newDb("dbWithNullDesc").copy(description = null), ignoreIfExists = false) + assert(catalog.getDatabase("dbWithNullDesc").description == "") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 94473a08dd317..ff90e9dda5f7c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -163,6 +163,15 @@ class VersionsSuite extends SparkFunSuite with Logging { client.createDatabase(tempDB, ignoreIfExists = true) } + test(s"$version: createDatabase with null description") { + withTempDir { tmpDir => + val dbWithNullDesc = + CatalogDatabase("dbWithNullDesc", description = null, tmpDir.toURI, Map()) + client.createDatabase(dbWithNullDesc, ignoreIfExists = true) + assert(client.getDatabase("dbWithNullDesc").description == "") + } + } + test(s"$version: setCurrentDatabase") { client.setCurrentDatabase("default") } From 1c70da3bfbb4016e394de2c73eb0db7cdd9a6968 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 11 Jan 2018 19:41:48 +0800 Subject: [PATCH 0064/2461] [SPARK-20657][CORE] Speed up rendering of the stages page. There are two main changes to speed up rendering of the tasks list when rendering the stage page. The first one makes the code only load the tasks being shown in the current page of the tasks table, and information related to only those tasks. One side-effect of this change is that the graph that shows task-related events now only shows events for the tasks in the current page, instead of the previously hardcoded limit of "events for the first 1000 tasks". That ends up helping with readability, though. To make sorting efficient when using a disk store, the task wrapper was extended to include many new indices, one for each of the sortable columns in the UI, and metrics for which quantiles are calculated. The second changes the way metric quantiles are calculated for stages. Instead of using the "Distribution" class to process data for all task metrics, which requires scanning all tasks of a stage, the code now uses the KVStore "skip()" functionality to only read tasks that contain interesting information for the quantiles that are desired. This is still not cheap; because there are many metrics that the UI and API track, the code needs to scan the index for each metric to gather the information. Savings come mainly from skipping deserialization when using the disk store, but the in-memory code also seems to be faster than before (most probably because of other changes in this patch). To make subsequent calls faster, some quantiles are cached in the status store. This makes UIs much faster after the first time a stage has been loaded. With the above changes, a lot of code in the UI layer could be simplified. Author: Marcelo Vanzin Closes #20013 from vanzin/SPARK-20657. --- .../apache/spark/util/kvstore/LevelDB.java | 1 + .../spark/status/AppStatusListener.scala | 57 +- .../apache/spark/status/AppStatusStore.scala | 389 +++++--- .../apache/spark/status/AppStatusUtils.scala | 68 ++ .../org/apache/spark/status/LiveEntity.scala | 344 ++++--- .../spark/status/api/v1/StagesResource.scala | 3 +- .../org/apache/spark/status/api/v1/api.scala | 3 + .../org/apache/spark/status/storeTypes.scala | 327 ++++++- .../apache/spark/ui/jobs/ExecutorTable.scala | 4 +- .../org/apache/spark/ui/jobs/JobPage.scala | 2 +- .../org/apache/spark/ui/jobs/StagePage.scala | 919 ++++++------------ ...mmary_w__custom_quantiles_expectation.json | 3 + ...sk_summary_w_shuffle_read_expectation.json | 3 + ...k_summary_w_shuffle_write_expectation.json | 3 + .../spark/status/AppStatusListenerSuite.scala | 105 +- .../spark/status/AppStatusStoreSuite.scala | 104 ++ .../org/apache/spark/ui/StagePageSuite.scala | 10 +- scalastyle-config.xml | 2 +- 18 files changed, 1361 insertions(+), 986 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala create mode 100644 core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java index 4f9e10ca20066..0e491efac9181 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java @@ -83,6 +83,7 @@ public LevelDB(File path, KVStoreSerializer serializer) throws Exception { if (versionData != null) { long version = serializer.deserializeLong(versionData); if (version != STORE_VERSION) { + close(); throw new UnsupportedStoreVersionException(); } } else { diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 88b75ddd5993a..b4edcf23abc09 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -377,6 +377,10 @@ private[spark] class AppStatusListener( Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => stage.activeTasks += 1 stage.firstLaunchTime = math.min(stage.firstLaunchTime, event.taskInfo.launchTime) + + val locality = event.taskInfo.taskLocality.toString() + val count = stage.localitySummary.getOrElse(locality, 0L) + 1L + stage.localitySummary = stage.localitySummary ++ Map(locality -> count) maybeUpdate(stage, now) stage.jobs.foreach { job => @@ -433,7 +437,7 @@ private[spark] class AppStatusListener( } task.errorMessage = errorMessage val delta = task.updateMetrics(event.taskMetrics) - update(task, now) + update(task, now, last = true) delta }.orNull @@ -450,7 +454,7 @@ private[spark] class AppStatusListener( Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => if (metricsDelta != null) { - stage.metrics.update(metricsDelta) + stage.metrics = LiveEntityHelpers.addMetrics(stage.metrics, metricsDelta) } stage.activeTasks -= 1 stage.completedTasks += completedDelta @@ -486,7 +490,7 @@ private[spark] class AppStatusListener( esummary.failedTasks += failedDelta esummary.killedTasks += killedDelta if (metricsDelta != null) { - esummary.metrics.update(metricsDelta) + esummary.metrics = LiveEntityHelpers.addMetrics(esummary.metrics, metricsDelta) } maybeUpdate(esummary, now) @@ -604,11 +608,11 @@ private[spark] class AppStatusListener( maybeUpdate(task, now) Option(liveStages.get((sid, sAttempt))).foreach { stage => - stage.metrics.update(delta) + stage.metrics = LiveEntityHelpers.addMetrics(stage.metrics, delta) maybeUpdate(stage, now) val esummary = stage.executorSummary(event.execId) - esummary.metrics.update(delta) + esummary.metrics = LiveEntityHelpers.addMetrics(esummary.metrics, delta) maybeUpdate(esummary, now) } } @@ -690,7 +694,7 @@ private[spark] class AppStatusListener( // can update the executor information too. liveRDDs.get(block.rddId).foreach { rdd => if (updatedStorageLevel.isDefined) { - rdd.storageLevel = updatedStorageLevel.get + rdd.setStorageLevel(updatedStorageLevel.get) } val partition = rdd.partition(block.name) @@ -814,7 +818,7 @@ private[spark] class AppStatusListener( /** Update a live entity only if it hasn't been updated in the last configured period. */ private def maybeUpdate(entity: LiveEntity, now: Long): Unit = { - if (liveUpdatePeriodNs >= 0 && now - entity.lastWriteTime > liveUpdatePeriodNs) { + if (live && liveUpdatePeriodNs >= 0 && now - entity.lastWriteTime > liveUpdatePeriodNs) { update(entity, now) } } @@ -865,7 +869,7 @@ private[spark] class AppStatusListener( } stages.foreach { s => - val key = s.id + val key = Array(s.info.stageId, s.info.attemptId) kvstore.delete(s.getClass(), key) val execSummaries = kvstore.view(classOf[ExecutorStageSummaryWrapper]) @@ -885,15 +889,15 @@ private[spark] class AppStatusListener( .asScala tasks.foreach { t => - kvstore.delete(t.getClass(), t.info.taskId) + kvstore.delete(t.getClass(), t.taskId) } // Check whether there are remaining attempts for the same stage. If there aren't, then // also delete the RDD graph data. val remainingAttempts = kvstore.view(classOf[StageDataWrapper]) .index("stageId") - .first(s.stageId) - .last(s.stageId) + .first(s.info.stageId) + .last(s.info.stageId) .closeableIterator() val hasMoreAttempts = try { @@ -905,8 +909,10 @@ private[spark] class AppStatusListener( } if (!hasMoreAttempts) { - kvstore.delete(classOf[RDDOperationGraphWrapper], s.stageId) + kvstore.delete(classOf[RDDOperationGraphWrapper], s.info.stageId) } + + cleanupCachedQuantiles(key) } } @@ -919,9 +925,9 @@ private[spark] class AppStatusListener( // Try to delete finished tasks only. val toDelete = KVUtils.viewToSeq(view, countToDelete) { t => - !live || t.info.status != TaskState.RUNNING.toString() + !live || t.status != TaskState.RUNNING.toString() } - toDelete.foreach { t => kvstore.delete(t.getClass(), t.info.taskId) } + toDelete.foreach { t => kvstore.delete(t.getClass(), t.taskId) } stage.savedTasks.addAndGet(-toDelete.size) // If there are more running tasks than the configured limit, delete running tasks. This @@ -930,13 +936,34 @@ private[spark] class AppStatusListener( val remaining = countToDelete - toDelete.size if (remaining > 0) { val runningTasksToDelete = view.max(remaining).iterator().asScala.toList - runningTasksToDelete.foreach { t => kvstore.delete(t.getClass(), t.info.taskId) } + runningTasksToDelete.foreach { t => kvstore.delete(t.getClass(), t.taskId) } stage.savedTasks.addAndGet(-remaining) } + + // On live applications, cleanup any cached quantiles for the stage. This makes sure that + // quantiles will be recalculated after tasks are replaced with newer ones. + // + // This is not needed in the SHS since caching only happens after the event logs are + // completely processed. + if (live) { + cleanupCachedQuantiles(stageKey) + } } stage.cleaning = false } + private def cleanupCachedQuantiles(stageKey: Array[Int]): Unit = { + val cachedQuantiles = kvstore.view(classOf[CachedQuantile]) + .index("stage") + .first(stageKey) + .last(stageKey) + .asScala + .toList + cachedQuantiles.foreach { q => + kvstore.delete(q.getClass(), q.id) + } + } + /** * Remove at least (retainedSize / 10) items to reduce friction. Because tracking may be done * asynchronously, this method may return 0 in case enough items have been deleted already. diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 5a942f5284018..efc28538a33db 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.{JobExecutionStatus, SparkConf} import org.apache.spark.status.api.v1 import org.apache.spark.ui.scope._ -import org.apache.spark.util.Distribution +import org.apache.spark.util.{Distribution, Utils} import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} /** @@ -98,7 +98,11 @@ private[spark] class AppStatusStore( val it = store.view(classOf[StageDataWrapper]).index("stageId").reverse().first(stageId) .closeableIterator() try { - it.next().info + if (it.hasNext()) { + it.next().info + } else { + throw new NoSuchElementException(s"No stage with id $stageId") + } } finally { it.close() } @@ -110,107 +114,238 @@ private[spark] class AppStatusStore( if (details) stageWithDetails(stage) else stage } + def taskCount(stageId: Int, stageAttemptId: Int): Long = { + store.count(classOf[TaskDataWrapper], "stage", Array(stageId, stageAttemptId)) + } + + def localitySummary(stageId: Int, stageAttemptId: Int): Map[String, Long] = { + store.read(classOf[StageDataWrapper], Array(stageId, stageAttemptId)).locality + } + + /** + * Calculates a summary of the task metrics for the given stage attempt, returning the + * requested quantiles for the recorded metrics. + * + * This method can be expensive if the requested quantiles are not cached; the method + * will only cache certain quantiles (every 0.05 step), so it's recommended to stick to + * those to avoid expensive scans of all task data. + */ def taskSummary( stageId: Int, stageAttemptId: Int, - quantiles: Array[Double]): v1.TaskMetricDistributions = { - - val stage = Array(stageId, stageAttemptId) - - val rawMetrics = store.view(classOf[TaskDataWrapper]) - .index("stage") - .first(stage) - .last(stage) - .asScala - .flatMap(_.info.taskMetrics) - .toList - .view - - def metricQuantiles(f: v1.TaskMetrics => Double): IndexedSeq[Double] = - Distribution(rawMetrics.map { d => f(d) }).get.getQuantiles(quantiles) - - // We need to do a lot of similar munging to nested metrics here. For each one, - // we want (a) extract the values for nested metrics (b) make a distribution for each metric - // (c) shove the distribution into the right field in our return type and (d) only return - // a result if the option is defined for any of the tasks. MetricHelper is a little util - // to make it a little easier to deal w/ all of the nested options. Mostly it lets us just - // implement one "build" method, which just builds the quantiles for each field. - - val inputMetrics = - new MetricHelper[v1.InputMetrics, v1.InputMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw: v1.TaskMetrics): v1.InputMetrics = raw.inputMetrics - - def build: v1.InputMetricDistributions = new v1.InputMetricDistributions( - bytesRead = submetricQuantiles(_.bytesRead), - recordsRead = submetricQuantiles(_.recordsRead) - ) - }.build - - val outputMetrics = - new MetricHelper[v1.OutputMetrics, v1.OutputMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw: v1.TaskMetrics): v1.OutputMetrics = raw.outputMetrics - - def build: v1.OutputMetricDistributions = new v1.OutputMetricDistributions( - bytesWritten = submetricQuantiles(_.bytesWritten), - recordsWritten = submetricQuantiles(_.recordsWritten) - ) - }.build - - val shuffleReadMetrics = - new MetricHelper[v1.ShuffleReadMetrics, v1.ShuffleReadMetricDistributions](rawMetrics, - quantiles) { - def getSubmetrics(raw: v1.TaskMetrics): v1.ShuffleReadMetrics = - raw.shuffleReadMetrics - - def build: v1.ShuffleReadMetricDistributions = new v1.ShuffleReadMetricDistributions( - readBytes = submetricQuantiles { s => s.localBytesRead + s.remoteBytesRead }, - readRecords = submetricQuantiles(_.recordsRead), - remoteBytesRead = submetricQuantiles(_.remoteBytesRead), - remoteBytesReadToDisk = submetricQuantiles(_.remoteBytesReadToDisk), - remoteBlocksFetched = submetricQuantiles(_.remoteBlocksFetched), - localBlocksFetched = submetricQuantiles(_.localBlocksFetched), - totalBlocksFetched = submetricQuantiles { s => - s.localBlocksFetched + s.remoteBlocksFetched - }, - fetchWaitTime = submetricQuantiles(_.fetchWaitTime) - ) - }.build - - val shuffleWriteMetrics = - new MetricHelper[v1.ShuffleWriteMetrics, v1.ShuffleWriteMetricDistributions](rawMetrics, - quantiles) { - def getSubmetrics(raw: v1.TaskMetrics): v1.ShuffleWriteMetrics = - raw.shuffleWriteMetrics - - def build: v1.ShuffleWriteMetricDistributions = new v1.ShuffleWriteMetricDistributions( - writeBytes = submetricQuantiles(_.bytesWritten), - writeRecords = submetricQuantiles(_.recordsWritten), - writeTime = submetricQuantiles(_.writeTime) - ) - }.build - - new v1.TaskMetricDistributions( + unsortedQuantiles: Array[Double]): Option[v1.TaskMetricDistributions] = { + val stageKey = Array(stageId, stageAttemptId) + val quantiles = unsortedQuantiles.sorted + + // We don't know how many tasks remain in the store that actually have metrics. So scan one + // metric and count how many valid tasks there are. Use skip() instead of next() since it's + // cheaper for disk stores (avoids deserialization). + val count = { + Utils.tryWithResource( + store.view(classOf[TaskDataWrapper]) + .parent(stageKey) + .index(TaskIndexNames.EXEC_RUN_TIME) + .first(0L) + .closeableIterator() + ) { it => + var _count = 0L + while (it.hasNext()) { + _count += 1 + it.skip(1) + } + _count + } + } + + if (count <= 0) { + return None + } + + // Find out which quantiles are already cached. The data in the store must match the expected + // task count to be considered, otherwise it will be re-scanned and overwritten. + val cachedQuantiles = quantiles.filter(shouldCacheQuantile).flatMap { q => + val qkey = Array(stageId, stageAttemptId, quantileToString(q)) + asOption(store.read(classOf[CachedQuantile], qkey)).filter(_.taskCount == count) + } + + // If there are no missing quantiles, return the data. Otherwise, just compute everything + // to make the code simpler. + if (cachedQuantiles.size == quantiles.size) { + def toValues(fn: CachedQuantile => Double): IndexedSeq[Double] = cachedQuantiles.map(fn) + + val distributions = new v1.TaskMetricDistributions( + quantiles = quantiles, + executorDeserializeTime = toValues(_.executorDeserializeTime), + executorDeserializeCpuTime = toValues(_.executorDeserializeCpuTime), + executorRunTime = toValues(_.executorRunTime), + executorCpuTime = toValues(_.executorCpuTime), + resultSize = toValues(_.resultSize), + jvmGcTime = toValues(_.jvmGcTime), + resultSerializationTime = toValues(_.resultSerializationTime), + gettingResultTime = toValues(_.gettingResultTime), + schedulerDelay = toValues(_.schedulerDelay), + peakExecutionMemory = toValues(_.peakExecutionMemory), + memoryBytesSpilled = toValues(_.memoryBytesSpilled), + diskBytesSpilled = toValues(_.diskBytesSpilled), + inputMetrics = new v1.InputMetricDistributions( + toValues(_.bytesRead), + toValues(_.recordsRead)), + outputMetrics = new v1.OutputMetricDistributions( + toValues(_.bytesWritten), + toValues(_.recordsWritten)), + shuffleReadMetrics = new v1.ShuffleReadMetricDistributions( + toValues(_.shuffleReadBytes), + toValues(_.shuffleRecordsRead), + toValues(_.shuffleRemoteBlocksFetched), + toValues(_.shuffleLocalBlocksFetched), + toValues(_.shuffleFetchWaitTime), + toValues(_.shuffleRemoteBytesRead), + toValues(_.shuffleRemoteBytesReadToDisk), + toValues(_.shuffleTotalBlocksFetched)), + shuffleWriteMetrics = new v1.ShuffleWriteMetricDistributions( + toValues(_.shuffleWriteBytes), + toValues(_.shuffleWriteRecords), + toValues(_.shuffleWriteTime))) + + return Some(distributions) + } + + // Compute quantiles by scanning the tasks in the store. This is not really stable for live + // stages (e.g. the number of recorded tasks may change while this code is running), but should + // stabilize once the stage finishes. It's also slow, especially with disk stores. + val indices = quantiles.map { q => math.min((q * count).toLong, count - 1) } + + def scanTasks(index: String)(fn: TaskDataWrapper => Long): IndexedSeq[Double] = { + Utils.tryWithResource( + store.view(classOf[TaskDataWrapper]) + .parent(stageKey) + .index(index) + .first(0L) + .closeableIterator() + ) { it => + var last = Double.NaN + var currentIdx = -1L + indices.map { idx => + if (idx == currentIdx) { + last + } else { + val diff = idx - currentIdx + currentIdx = idx + if (it.skip(diff - 1)) { + last = fn(it.next()).toDouble + last + } else { + Double.NaN + } + } + }.toIndexedSeq + } + } + + val computedQuantiles = new v1.TaskMetricDistributions( quantiles = quantiles, - executorDeserializeTime = metricQuantiles(_.executorDeserializeTime), - executorDeserializeCpuTime = metricQuantiles(_.executorDeserializeCpuTime), - executorRunTime = metricQuantiles(_.executorRunTime), - executorCpuTime = metricQuantiles(_.executorCpuTime), - resultSize = metricQuantiles(_.resultSize), - jvmGcTime = metricQuantiles(_.jvmGcTime), - resultSerializationTime = metricQuantiles(_.resultSerializationTime), - memoryBytesSpilled = metricQuantiles(_.memoryBytesSpilled), - diskBytesSpilled = metricQuantiles(_.diskBytesSpilled), - inputMetrics = inputMetrics, - outputMetrics = outputMetrics, - shuffleReadMetrics = shuffleReadMetrics, - shuffleWriteMetrics = shuffleWriteMetrics - ) + executorDeserializeTime = scanTasks(TaskIndexNames.DESER_TIME) { t => + t.executorDeserializeTime + }, + executorDeserializeCpuTime = scanTasks(TaskIndexNames.DESER_CPU_TIME) { t => + t.executorDeserializeCpuTime + }, + executorRunTime = scanTasks(TaskIndexNames.EXEC_RUN_TIME) { t => t.executorRunTime }, + executorCpuTime = scanTasks(TaskIndexNames.EXEC_CPU_TIME) { t => t.executorCpuTime }, + resultSize = scanTasks(TaskIndexNames.RESULT_SIZE) { t => t.resultSize }, + jvmGcTime = scanTasks(TaskIndexNames.GC_TIME) { t => t.jvmGcTime }, + resultSerializationTime = scanTasks(TaskIndexNames.SER_TIME) { t => + t.resultSerializationTime + }, + gettingResultTime = scanTasks(TaskIndexNames.GETTING_RESULT_TIME) { t => + t.gettingResultTime + }, + schedulerDelay = scanTasks(TaskIndexNames.SCHEDULER_DELAY) { t => t.schedulerDelay }, + peakExecutionMemory = scanTasks(TaskIndexNames.PEAK_MEM) { t => t.peakExecutionMemory }, + memoryBytesSpilled = scanTasks(TaskIndexNames.MEM_SPILL) { t => t.memoryBytesSpilled }, + diskBytesSpilled = scanTasks(TaskIndexNames.DISK_SPILL) { t => t.diskBytesSpilled }, + inputMetrics = new v1.InputMetricDistributions( + scanTasks(TaskIndexNames.INPUT_SIZE) { t => t.inputBytesRead }, + scanTasks(TaskIndexNames.INPUT_RECORDS) { t => t.inputRecordsRead }), + outputMetrics = new v1.OutputMetricDistributions( + scanTasks(TaskIndexNames.OUTPUT_SIZE) { t => t.outputBytesWritten }, + scanTasks(TaskIndexNames.OUTPUT_RECORDS) { t => t.outputRecordsWritten }), + shuffleReadMetrics = new v1.ShuffleReadMetricDistributions( + scanTasks(TaskIndexNames.SHUFFLE_TOTAL_READS) { m => + m.shuffleLocalBytesRead + m.shuffleRemoteBytesRead + }, + scanTasks(TaskIndexNames.SHUFFLE_READ_RECORDS) { t => t.shuffleRecordsRead }, + scanTasks(TaskIndexNames.SHUFFLE_REMOTE_BLOCKS) { t => t.shuffleRemoteBlocksFetched }, + scanTasks(TaskIndexNames.SHUFFLE_LOCAL_BLOCKS) { t => t.shuffleLocalBlocksFetched }, + scanTasks(TaskIndexNames.SHUFFLE_READ_TIME) { t => t.shuffleFetchWaitTime }, + scanTasks(TaskIndexNames.SHUFFLE_REMOTE_READS) { t => t.shuffleRemoteBytesRead }, + scanTasks(TaskIndexNames.SHUFFLE_REMOTE_READS_TO_DISK) { t => + t.shuffleRemoteBytesReadToDisk + }, + scanTasks(TaskIndexNames.SHUFFLE_TOTAL_BLOCKS) { m => + m.shuffleLocalBlocksFetched + m.shuffleRemoteBlocksFetched + }), + shuffleWriteMetrics = new v1.ShuffleWriteMetricDistributions( + scanTasks(TaskIndexNames.SHUFFLE_WRITE_SIZE) { t => t.shuffleBytesWritten }, + scanTasks(TaskIndexNames.SHUFFLE_WRITE_RECORDS) { t => t.shuffleRecordsWritten }, + scanTasks(TaskIndexNames.SHUFFLE_WRITE_TIME) { t => t.shuffleWriteTime })) + + // Go through the computed quantiles and cache the values that match the caching criteria. + computedQuantiles.quantiles.zipWithIndex + .filter { case (q, _) => quantiles.contains(q) && shouldCacheQuantile(q) } + .foreach { case (q, idx) => + val cached = new CachedQuantile(stageId, stageAttemptId, quantileToString(q), count, + executorDeserializeTime = computedQuantiles.executorDeserializeTime(idx), + executorDeserializeCpuTime = computedQuantiles.executorDeserializeCpuTime(idx), + executorRunTime = computedQuantiles.executorRunTime(idx), + executorCpuTime = computedQuantiles.executorCpuTime(idx), + resultSize = computedQuantiles.resultSize(idx), + jvmGcTime = computedQuantiles.jvmGcTime(idx), + resultSerializationTime = computedQuantiles.resultSerializationTime(idx), + gettingResultTime = computedQuantiles.gettingResultTime(idx), + schedulerDelay = computedQuantiles.schedulerDelay(idx), + peakExecutionMemory = computedQuantiles.peakExecutionMemory(idx), + memoryBytesSpilled = computedQuantiles.memoryBytesSpilled(idx), + diskBytesSpilled = computedQuantiles.diskBytesSpilled(idx), + + bytesRead = computedQuantiles.inputMetrics.bytesRead(idx), + recordsRead = computedQuantiles.inputMetrics.recordsRead(idx), + + bytesWritten = computedQuantiles.outputMetrics.bytesWritten(idx), + recordsWritten = computedQuantiles.outputMetrics.recordsWritten(idx), + + shuffleReadBytes = computedQuantiles.shuffleReadMetrics.readBytes(idx), + shuffleRecordsRead = computedQuantiles.shuffleReadMetrics.readRecords(idx), + shuffleRemoteBlocksFetched = + computedQuantiles.shuffleReadMetrics.remoteBlocksFetched(idx), + shuffleLocalBlocksFetched = computedQuantiles.shuffleReadMetrics.localBlocksFetched(idx), + shuffleFetchWaitTime = computedQuantiles.shuffleReadMetrics.fetchWaitTime(idx), + shuffleRemoteBytesRead = computedQuantiles.shuffleReadMetrics.remoteBytesRead(idx), + shuffleRemoteBytesReadToDisk = + computedQuantiles.shuffleReadMetrics.remoteBytesReadToDisk(idx), + shuffleTotalBlocksFetched = computedQuantiles.shuffleReadMetrics.totalBlocksFetched(idx), + + shuffleWriteBytes = computedQuantiles.shuffleWriteMetrics.writeBytes(idx), + shuffleWriteRecords = computedQuantiles.shuffleWriteMetrics.writeRecords(idx), + shuffleWriteTime = computedQuantiles.shuffleWriteMetrics.writeTime(idx)) + store.write(cached) + } + + Some(computedQuantiles) } + /** + * Whether to cache information about a specific metric quantile. We cache quantiles at every 0.05 + * step, which covers the default values used both in the API and in the stages page. + */ + private def shouldCacheQuantile(q: Double): Boolean = (math.round(q * 100) % 5) == 0 + + private def quantileToString(q: Double): String = math.round(q * 100).toString + def taskList(stageId: Int, stageAttemptId: Int, maxTasks: Int): Seq[v1.TaskData] = { val stageKey = Array(stageId, stageAttemptId) store.view(classOf[TaskDataWrapper]).index("stage").first(stageKey).last(stageKey).reverse() - .max(maxTasks).asScala.map(_.info).toSeq.reverse + .max(maxTasks).asScala.map(_.toApi).toSeq.reverse } def taskList( @@ -219,18 +354,43 @@ private[spark] class AppStatusStore( offset: Int, length: Int, sortBy: v1.TaskSorting): Seq[v1.TaskData] = { + val (indexName, ascending) = sortBy match { + case v1.TaskSorting.ID => + (None, true) + case v1.TaskSorting.INCREASING_RUNTIME => + (Some(TaskIndexNames.EXEC_RUN_TIME), true) + case v1.TaskSorting.DECREASING_RUNTIME => + (Some(TaskIndexNames.EXEC_RUN_TIME), false) + } + taskList(stageId, stageAttemptId, offset, length, indexName, ascending) + } + + def taskList( + stageId: Int, + stageAttemptId: Int, + offset: Int, + length: Int, + sortBy: Option[String], + ascending: Boolean): Seq[v1.TaskData] = { val stageKey = Array(stageId, stageAttemptId) val base = store.view(classOf[TaskDataWrapper]) val indexed = sortBy match { - case v1.TaskSorting.ID => + case Some(index) => + base.index(index).parent(stageKey) + + case _ => + // Sort by ID, which is the "stage" index. base.index("stage").first(stageKey).last(stageKey) - case v1.TaskSorting.INCREASING_RUNTIME => - base.index("runtime").first(stageKey ++ Array(-1L)).last(stageKey ++ Array(Long.MaxValue)) - case v1.TaskSorting.DECREASING_RUNTIME => - base.index("runtime").first(stageKey ++ Array(Long.MaxValue)).last(stageKey ++ Array(-1L)) - .reverse() } - indexed.skip(offset).max(length).asScala.map(_.info).toSeq + + val ordered = if (ascending) indexed else indexed.reverse() + ordered.skip(offset).max(length).asScala.map(_.toApi).toSeq + } + + def executorSummary(stageId: Int, attemptId: Int): Map[String, v1.ExecutorStageSummary] = { + val stageKey = Array(stageId, attemptId) + store.view(classOf[ExecutorStageSummaryWrapper]).index("stage").first(stageKey).last(stageKey) + .asScala.map { exec => (exec.executorId -> exec.info) }.toMap } def rddList(cachedOnly: Boolean = true): Seq[v1.RDDStorageInfo] = { @@ -256,12 +416,6 @@ private[spark] class AppStatusStore( .map { t => (t.taskId, t) } .toMap - val stageKey = Array(stage.stageId, stage.attemptId) - val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage").first(stageKey) - .last(stageKey).closeableIterator().asScala - .map { exec => (exec.executorId -> exec.info) } - .toMap - new v1.StageData( stage.status, stage.stageId, @@ -295,7 +449,7 @@ private[spark] class AppStatusStore( stage.rddIds, stage.accumulatorUpdates, Some(tasks), - Some(execs), + Some(executorSummary(stage.stageId, stage.attemptId)), stage.killedTasksSummary) } @@ -352,22 +506,3 @@ private[spark] object AppStatusStore { } } - -/** - * Helper for getting distributions from nested metric types. - */ -private abstract class MetricHelper[I, O]( - rawMetrics: Seq[v1.TaskMetrics], - quantiles: Array[Double]) { - - def getSubmetrics(raw: v1.TaskMetrics): I - - def build: O - - val data: Seq[I] = rawMetrics.map(getSubmetrics) - - /** applies the given function to all input metrics, and returns the quantiles */ - def submetricQuantiles(f: I => Double): IndexedSeq[Double] = { - Distribution(data.map { d => f(d) }).get.getQuantiles(quantiles) - } -} diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala b/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala new file mode 100644 index 0000000000000..341bd4e0cd016 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import org.apache.spark.status.api.v1.{TaskData, TaskMetrics} + +private[spark] object AppStatusUtils { + + def schedulerDelay(task: TaskData): Long = { + if (task.taskMetrics.isDefined && task.duration.isDefined) { + val m = task.taskMetrics.get + schedulerDelay(task.launchTime.getTime(), fetchStart(task), task.duration.get, + m.executorDeserializeTime, m.resultSerializationTime, m.executorRunTime) + } else { + 0L + } + } + + def gettingResultTime(task: TaskData): Long = { + gettingResultTime(task.launchTime.getTime(), fetchStart(task), task.duration.getOrElse(-1L)) + } + + def schedulerDelay( + launchTime: Long, + fetchStart: Long, + duration: Long, + deserializeTime: Long, + serializeTime: Long, + runTime: Long): Long = { + math.max(0, duration - runTime - deserializeTime - serializeTime - + gettingResultTime(launchTime, fetchStart, duration)) + } + + def gettingResultTime(launchTime: Long, fetchStart: Long, duration: Long): Long = { + if (fetchStart > 0) { + if (duration > 0) { + launchTime + duration - fetchStart + } else { + System.currentTimeMillis() - fetchStart + } + } else { + 0L + } + } + + private def fetchStart(task: TaskData): Long = { + if (task.resultFetchStart.isDefined) { + task.resultFetchStart.get.getTime() + } else { + -1 + } + } +} diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 305c2fafa6aac..4295e664e131c 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -22,6 +22,8 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.HashMap +import com.google.common.collect.Interners + import org.apache.spark.JobExecutionStatus import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, StageInfo, TaskInfo} @@ -119,7 +121,9 @@ private class LiveTask( import LiveEntityHelpers._ - private var recordedMetrics: v1.TaskMetrics = null + // The task metrics use a special value when no metrics have been reported. The special value is + // checked when calculating indexed values when writing to the store (see [[TaskDataWrapper]]). + private var metrics: v1.TaskMetrics = createMetrics(default = -1L) var errorMessage: Option[String] = None @@ -129,8 +133,8 @@ private class LiveTask( */ def updateMetrics(metrics: TaskMetrics): v1.TaskMetrics = { if (metrics != null) { - val old = recordedMetrics - recordedMetrics = new v1.TaskMetrics( + val old = this.metrics + val newMetrics = createMetrics( metrics.executorDeserializeTime, metrics.executorDeserializeCpuTime, metrics.executorRunTime, @@ -141,73 +145,35 @@ private class LiveTask( metrics.memoryBytesSpilled, metrics.diskBytesSpilled, metrics.peakExecutionMemory, - new v1.InputMetrics( - metrics.inputMetrics.bytesRead, - metrics.inputMetrics.recordsRead), - new v1.OutputMetrics( - metrics.outputMetrics.bytesWritten, - metrics.outputMetrics.recordsWritten), - new v1.ShuffleReadMetrics( - metrics.shuffleReadMetrics.remoteBlocksFetched, - metrics.shuffleReadMetrics.localBlocksFetched, - metrics.shuffleReadMetrics.fetchWaitTime, - metrics.shuffleReadMetrics.remoteBytesRead, - metrics.shuffleReadMetrics.remoteBytesReadToDisk, - metrics.shuffleReadMetrics.localBytesRead, - metrics.shuffleReadMetrics.recordsRead), - new v1.ShuffleWriteMetrics( - metrics.shuffleWriteMetrics.bytesWritten, - metrics.shuffleWriteMetrics.writeTime, - metrics.shuffleWriteMetrics.recordsWritten)) - if (old != null) calculateMetricsDelta(recordedMetrics, old) else recordedMetrics + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead, + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten, + metrics.shuffleReadMetrics.remoteBlocksFetched, + metrics.shuffleReadMetrics.localBlocksFetched, + metrics.shuffleReadMetrics.fetchWaitTime, + metrics.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.remoteBytesReadToDisk, + metrics.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead, + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.writeTime, + metrics.shuffleWriteMetrics.recordsWritten) + + this.metrics = newMetrics + + // Only calculate the delta if the old metrics contain valid information, otherwise + // the new metrics are the delta. + if (old.executorDeserializeTime >= 0L) { + subtractMetrics(newMetrics, old) + } else { + newMetrics + } } else { null } } - /** - * Return a new TaskMetrics object containing the delta of the various fields of the given - * metrics objects. This is currently targeted at updating stage data, so it does not - * necessarily calculate deltas for all the fields. - */ - private def calculateMetricsDelta( - metrics: v1.TaskMetrics, - old: v1.TaskMetrics): v1.TaskMetrics = { - val shuffleWriteDelta = new v1.ShuffleWriteMetrics( - metrics.shuffleWriteMetrics.bytesWritten - old.shuffleWriteMetrics.bytesWritten, - 0L, - metrics.shuffleWriteMetrics.recordsWritten - old.shuffleWriteMetrics.recordsWritten) - - val shuffleReadDelta = new v1.ShuffleReadMetrics( - 0L, 0L, 0L, - metrics.shuffleReadMetrics.remoteBytesRead - old.shuffleReadMetrics.remoteBytesRead, - metrics.shuffleReadMetrics.remoteBytesReadToDisk - - old.shuffleReadMetrics.remoteBytesReadToDisk, - metrics.shuffleReadMetrics.localBytesRead - old.shuffleReadMetrics.localBytesRead, - metrics.shuffleReadMetrics.recordsRead - old.shuffleReadMetrics.recordsRead) - - val inputDelta = new v1.InputMetrics( - metrics.inputMetrics.bytesRead - old.inputMetrics.bytesRead, - metrics.inputMetrics.recordsRead - old.inputMetrics.recordsRead) - - val outputDelta = new v1.OutputMetrics( - metrics.outputMetrics.bytesWritten - old.outputMetrics.bytesWritten, - metrics.outputMetrics.recordsWritten - old.outputMetrics.recordsWritten) - - new v1.TaskMetrics( - 0L, 0L, - metrics.executorRunTime - old.executorRunTime, - metrics.executorCpuTime - old.executorCpuTime, - 0L, 0L, 0L, - metrics.memoryBytesSpilled - old.memoryBytesSpilled, - metrics.diskBytesSpilled - old.diskBytesSpilled, - 0L, - inputDelta, - outputDelta, - shuffleReadDelta, - shuffleWriteDelta) - } - override protected def doUpdate(): Any = { val duration = if (info.finished) { info.duration @@ -215,22 +181,48 @@ private class LiveTask( info.timeRunning(lastUpdateTime.getOrElse(System.currentTimeMillis())) } - val task = new v1.TaskData( + new TaskDataWrapper( info.taskId, info.index, info.attemptNumber, - new Date(info.launchTime), - if (info.gettingResult) Some(new Date(info.gettingResultTime)) else None, - Some(duration), - info.executorId, - info.host, - info.status, - info.taskLocality.toString(), + info.launchTime, + if (info.gettingResult) info.gettingResultTime else -1L, + duration, + weakIntern(info.executorId), + weakIntern(info.host), + weakIntern(info.status), + weakIntern(info.taskLocality.toString()), info.speculative, newAccumulatorInfos(info.accumulables), errorMessage, - Option(recordedMetrics)) - new TaskDataWrapper(task, stageId, stageAttemptId) + + metrics.executorDeserializeTime, + metrics.executorDeserializeCpuTime, + metrics.executorRunTime, + metrics.executorCpuTime, + metrics.resultSize, + metrics.jvmGcTime, + metrics.resultSerializationTime, + metrics.memoryBytesSpilled, + metrics.diskBytesSpilled, + metrics.peakExecutionMemory, + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead, + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten, + metrics.shuffleReadMetrics.remoteBlocksFetched, + metrics.shuffleReadMetrics.localBlocksFetched, + metrics.shuffleReadMetrics.fetchWaitTime, + metrics.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.remoteBytesReadToDisk, + metrics.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead, + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.writeTime, + metrics.shuffleWriteMetrics.recordsWritten, + + stageId, + stageAttemptId) } } @@ -313,50 +305,19 @@ private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveE } -/** Metrics tracked per stage (both total and per executor). */ -private class MetricsTracker { - var executorRunTime = 0L - var executorCpuTime = 0L - var inputBytes = 0L - var inputRecords = 0L - var outputBytes = 0L - var outputRecords = 0L - var shuffleReadBytes = 0L - var shuffleReadRecords = 0L - var shuffleWriteBytes = 0L - var shuffleWriteRecords = 0L - var memoryBytesSpilled = 0L - var diskBytesSpilled = 0L - - def update(delta: v1.TaskMetrics): Unit = { - executorRunTime += delta.executorRunTime - executorCpuTime += delta.executorCpuTime - inputBytes += delta.inputMetrics.bytesRead - inputRecords += delta.inputMetrics.recordsRead - outputBytes += delta.outputMetrics.bytesWritten - outputRecords += delta.outputMetrics.recordsWritten - shuffleReadBytes += delta.shuffleReadMetrics.localBytesRead + - delta.shuffleReadMetrics.remoteBytesRead - shuffleReadRecords += delta.shuffleReadMetrics.recordsRead - shuffleWriteBytes += delta.shuffleWriteMetrics.bytesWritten - shuffleWriteRecords += delta.shuffleWriteMetrics.recordsWritten - memoryBytesSpilled += delta.memoryBytesSpilled - diskBytesSpilled += delta.diskBytesSpilled - } - -} - private class LiveExecutorStageSummary( stageId: Int, attemptId: Int, executorId: String) extends LiveEntity { + import LiveEntityHelpers._ + var taskTime = 0L var succeededTasks = 0 var failedTasks = 0 var killedTasks = 0 - val metrics = new MetricsTracker() + var metrics = createMetrics(default = 0L) override protected def doUpdate(): Any = { val info = new v1.ExecutorStageSummary( @@ -364,14 +325,14 @@ private class LiveExecutorStageSummary( failedTasks, succeededTasks, killedTasks, - metrics.inputBytes, - metrics.inputRecords, - metrics.outputBytes, - metrics.outputRecords, - metrics.shuffleReadBytes, - metrics.shuffleReadRecords, - metrics.shuffleWriteBytes, - metrics.shuffleWriteRecords, + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead, + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten, + metrics.shuffleReadMetrics.remoteBytesRead + metrics.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead, + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.recordsWritten, metrics.memoryBytesSpilled, metrics.diskBytesSpilled) new ExecutorStageSummaryWrapper(stageId, attemptId, executorId, info) @@ -402,7 +363,9 @@ private class LiveStage extends LiveEntity { var firstLaunchTime = Long.MaxValue - val metrics = new MetricsTracker() + var localitySummary: Map[String, Long] = Map() + + var metrics = createMetrics(default = 0L) val executorSummaries = new HashMap[String, LiveExecutorStageSummary]() @@ -435,14 +398,14 @@ private class LiveStage extends LiveEntity { info.completionTime.map(new Date(_)), info.failureReason, - metrics.inputBytes, - metrics.inputRecords, - metrics.outputBytes, - metrics.outputRecords, - metrics.shuffleReadBytes, - metrics.shuffleReadRecords, - metrics.shuffleWriteBytes, - metrics.shuffleWriteRecords, + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead, + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten, + metrics.shuffleReadMetrics.localBytesRead + metrics.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.recordsRead, + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.recordsWritten, metrics.memoryBytesSpilled, metrics.diskBytesSpilled, @@ -459,13 +422,15 @@ private class LiveStage extends LiveEntity { } override protected def doUpdate(): Any = { - new StageDataWrapper(toApi(), jobIds) + new StageDataWrapper(toApi(), jobIds, localitySummary) } } private class LiveRDDPartition(val blockName: String) { + import LiveEntityHelpers._ + // Pointers used by RDDPartitionSeq. @volatile var prev: LiveRDDPartition = null @volatile var next: LiveRDDPartition = null @@ -485,7 +450,7 @@ private class LiveRDDPartition(val blockName: String) { diskUsed: Long): Unit = { value = new v1.RDDPartitionInfo( blockName, - storageLevel, + weakIntern(storageLevel), memoryUsed, diskUsed, executors) @@ -495,6 +460,8 @@ private class LiveRDDPartition(val blockName: String) { private class LiveRDDDistribution(exec: LiveExecutor) { + import LiveEntityHelpers._ + val executorId = exec.executorId var memoryUsed = 0L var diskUsed = 0L @@ -508,7 +475,7 @@ private class LiveRDDDistribution(exec: LiveExecutor) { def toApi(): v1.RDDDataDistribution = { if (lastUpdate == null) { lastUpdate = new v1.RDDDataDistribution( - exec.hostPort, + weakIntern(exec.hostPort), memoryUsed, exec.maxMemory - exec.memoryUsed, diskUsed, @@ -524,7 +491,9 @@ private class LiveRDDDistribution(exec: LiveExecutor) { private class LiveRDD(val info: RDDInfo) extends LiveEntity { - var storageLevel: String = info.storageLevel.description + import LiveEntityHelpers._ + + var storageLevel: String = weakIntern(info.storageLevel.description) var memoryUsed = 0L var diskUsed = 0L @@ -533,6 +502,10 @@ private class LiveRDD(val info: RDDInfo) extends LiveEntity { private val distributions = new HashMap[String, LiveRDDDistribution]() + def setStorageLevel(level: String): Unit = { + this.storageLevel = weakIntern(level) + } + def partition(blockName: String): LiveRDDPartition = { partitions.getOrElseUpdate(blockName, { val part = new LiveRDDPartition(blockName) @@ -593,6 +566,9 @@ private class SchedulerPool(name: String) extends LiveEntity { private object LiveEntityHelpers { + private val stringInterner = Interners.newWeakInterner[String]() + + def newAccumulatorInfos(accums: Iterable[AccumulableInfo]): Seq[v1.AccumulableInfo] = { accums .filter { acc => @@ -604,13 +580,119 @@ private object LiveEntityHelpers { .map { acc => new v1.AccumulableInfo( acc.id, - acc.name.orNull, + acc.name.map(weakIntern).orNull, acc.update.map(_.toString()), acc.value.map(_.toString()).orNull) } .toSeq } + /** String interning to reduce the memory usage. */ + def weakIntern(s: String): String = { + stringInterner.intern(s) + } + + // scalastyle:off argcount + def createMetrics( + executorDeserializeTime: Long, + executorDeserializeCpuTime: Long, + executorRunTime: Long, + executorCpuTime: Long, + resultSize: Long, + jvmGcTime: Long, + resultSerializationTime: Long, + memoryBytesSpilled: Long, + diskBytesSpilled: Long, + peakExecutionMemory: Long, + inputBytesRead: Long, + inputRecordsRead: Long, + outputBytesWritten: Long, + outputRecordsWritten: Long, + shuffleRemoteBlocksFetched: Long, + shuffleLocalBlocksFetched: Long, + shuffleFetchWaitTime: Long, + shuffleRemoteBytesRead: Long, + shuffleRemoteBytesReadToDisk: Long, + shuffleLocalBytesRead: Long, + shuffleRecordsRead: Long, + shuffleBytesWritten: Long, + shuffleWriteTime: Long, + shuffleRecordsWritten: Long): v1.TaskMetrics = { + new v1.TaskMetrics( + executorDeserializeTime, + executorDeserializeCpuTime, + executorRunTime, + executorCpuTime, + resultSize, + jvmGcTime, + resultSerializationTime, + memoryBytesSpilled, + diskBytesSpilled, + peakExecutionMemory, + new v1.InputMetrics( + inputBytesRead, + inputRecordsRead), + new v1.OutputMetrics( + outputBytesWritten, + outputRecordsWritten), + new v1.ShuffleReadMetrics( + shuffleRemoteBlocksFetched, + shuffleLocalBlocksFetched, + shuffleFetchWaitTime, + shuffleRemoteBytesRead, + shuffleRemoteBytesReadToDisk, + shuffleLocalBytesRead, + shuffleRecordsRead), + new v1.ShuffleWriteMetrics( + shuffleBytesWritten, + shuffleWriteTime, + shuffleRecordsWritten)) + } + // scalastyle:on argcount + + def createMetrics(default: Long): v1.TaskMetrics = { + createMetrics(default, default, default, default, default, default, default, default, + default, default, default, default, default, default, default, default, + default, default, default, default, default, default, default, default) + } + + /** Add m2 values to m1. */ + def addMetrics(m1: v1.TaskMetrics, m2: v1.TaskMetrics): v1.TaskMetrics = addMetrics(m1, m2, 1) + + /** Subtract m2 values from m1. */ + def subtractMetrics(m1: v1.TaskMetrics, m2: v1.TaskMetrics): v1.TaskMetrics = { + addMetrics(m1, m2, -1) + } + + private def addMetrics(m1: v1.TaskMetrics, m2: v1.TaskMetrics, mult: Int): v1.TaskMetrics = { + createMetrics( + m1.executorDeserializeTime + m2.executorDeserializeTime * mult, + m1.executorDeserializeCpuTime + m2.executorDeserializeCpuTime * mult, + m1.executorRunTime + m2.executorRunTime * mult, + m1.executorCpuTime + m2.executorCpuTime * mult, + m1.resultSize + m2.resultSize * mult, + m1.jvmGcTime + m2.jvmGcTime * mult, + m1.resultSerializationTime + m2.resultSerializationTime * mult, + m1.memoryBytesSpilled + m2.memoryBytesSpilled * mult, + m1.diskBytesSpilled + m2.diskBytesSpilled * mult, + m1.peakExecutionMemory + m2.peakExecutionMemory * mult, + m1.inputMetrics.bytesRead + m2.inputMetrics.bytesRead * mult, + m1.inputMetrics.recordsRead + m2.inputMetrics.recordsRead * mult, + m1.outputMetrics.bytesWritten + m2.outputMetrics.bytesWritten * mult, + m1.outputMetrics.recordsWritten + m2.outputMetrics.recordsWritten * mult, + m1.shuffleReadMetrics.remoteBlocksFetched + m2.shuffleReadMetrics.remoteBlocksFetched * mult, + m1.shuffleReadMetrics.localBlocksFetched + m2.shuffleReadMetrics.localBlocksFetched * mult, + m1.shuffleReadMetrics.fetchWaitTime + m2.shuffleReadMetrics.fetchWaitTime * mult, + m1.shuffleReadMetrics.remoteBytesRead + m2.shuffleReadMetrics.remoteBytesRead * mult, + m1.shuffleReadMetrics.remoteBytesReadToDisk + + m2.shuffleReadMetrics.remoteBytesReadToDisk * mult, + m1.shuffleReadMetrics.localBytesRead + m2.shuffleReadMetrics.localBytesRead * mult, + m1.shuffleReadMetrics.recordsRead + m2.shuffleReadMetrics.recordsRead * mult, + m1.shuffleWriteMetrics.bytesWritten + m2.shuffleWriteMetrics.bytesWritten * mult, + m1.shuffleWriteMetrics.writeTime + m2.shuffleWriteMetrics.writeTime * mult, + m1.shuffleWriteMetrics.recordsWritten + m2.shuffleWriteMetrics.recordsWritten * mult) + } + } /** diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala index 3b879545b3d2e..96249e4bfd5fa 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala @@ -87,7 +87,8 @@ private[v1] class StagesResource extends BaseAppResource { } } - ui.store.taskSummary(stageId, stageAttemptId, quantiles) + ui.store.taskSummary(stageId, stageAttemptId, quantiles).getOrElse( + throw new NotFoundException(s"No tasks reported metrics for $stageId / $stageAttemptId yet.")) } @GET diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 45eaf935fb083..7d8e4de3c8efb 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -261,6 +261,9 @@ class TaskMetricDistributions private[spark]( val resultSize: IndexedSeq[Double], val jvmGcTime: IndexedSeq[Double], val resultSerializationTime: IndexedSeq[Double], + val gettingResultTime: IndexedSeq[Double], + val schedulerDelay: IndexedSeq[Double], + val peakExecutionMemory: IndexedSeq[Double], val memoryBytesSpilled: IndexedSeq[Double], val diskBytesSpilled: IndexedSeq[Double], diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index 1cfd30df49091..c9cb996a55fcc 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -17,9 +17,11 @@ package org.apache.spark.status -import java.lang.{Integer => JInteger, Long => JLong} +import java.lang.{Long => JLong} +import java.util.Date import com.fasterxml.jackson.annotation.JsonIgnore +import com.fasterxml.jackson.databind.annotation.JsonDeserialize import org.apache.spark.status.KVUtils._ import org.apache.spark.status.api.v1._ @@ -49,10 +51,10 @@ private[spark] class ApplicationEnvironmentInfoWrapper(val info: ApplicationEnvi private[spark] class ExecutorSummaryWrapper(val info: ExecutorSummary) { @JsonIgnore @KVIndex - private[this] val id: String = info.id + private def id: String = info.id @JsonIgnore @KVIndex("active") - private[this] val active: Boolean = info.isActive + private def active: Boolean = info.isActive @JsonIgnore @KVIndex("host") val host: String = info.hostPort.split(":")(0) @@ -69,51 +71,271 @@ private[spark] class JobDataWrapper( val skippedStages: Set[Int]) { @JsonIgnore @KVIndex - private[this] val id: Int = info.jobId + private def id: Int = info.jobId } private[spark] class StageDataWrapper( val info: StageData, - val jobIds: Set[Int]) { + val jobIds: Set[Int], + @JsonDeserialize(contentAs = classOf[JLong]) + val locality: Map[String, Long]) { @JsonIgnore @KVIndex - def id: Array[Int] = Array(info.stageId, info.attemptId) + private[this] val id: Array[Int] = Array(info.stageId, info.attemptId) @JsonIgnore @KVIndex("stageId") - def stageId: Int = info.stageId + private def stageId: Int = info.stageId + @JsonIgnore @KVIndex("active") + private def active: Boolean = info.status == StageStatus.ACTIVE + +} + +/** + * Tasks have a lot of indices that are used in a few different places. This object keeps logical + * names for these indices, mapped to short strings to save space when using a disk store. + */ +private[spark] object TaskIndexNames { + final val ACCUMULATORS = "acc" + final val ATTEMPT = "att" + final val DESER_CPU_TIME = "dct" + final val DESER_TIME = "des" + final val DISK_SPILL = "dbs" + final val DURATION = "dur" + final val ERROR = "err" + final val EXECUTOR = "exe" + final val EXEC_CPU_TIME = "ect" + final val EXEC_RUN_TIME = "ert" + final val GC_TIME = "gc" + final val GETTING_RESULT_TIME = "grt" + final val INPUT_RECORDS = "ir" + final val INPUT_SIZE = "is" + final val LAUNCH_TIME = "lt" + final val LOCALITY = "loc" + final val MEM_SPILL = "mbs" + final val OUTPUT_RECORDS = "or" + final val OUTPUT_SIZE = "os" + final val PEAK_MEM = "pem" + final val RESULT_SIZE = "rs" + final val SCHEDULER_DELAY = "dly" + final val SER_TIME = "rst" + final val SHUFFLE_LOCAL_BLOCKS = "slbl" + final val SHUFFLE_READ_RECORDS = "srr" + final val SHUFFLE_READ_TIME = "srt" + final val SHUFFLE_REMOTE_BLOCKS = "srbl" + final val SHUFFLE_REMOTE_READS = "srby" + final val SHUFFLE_REMOTE_READS_TO_DISK = "srbd" + final val SHUFFLE_TOTAL_READS = "stby" + final val SHUFFLE_TOTAL_BLOCKS = "stbl" + final val SHUFFLE_WRITE_RECORDS = "swr" + final val SHUFFLE_WRITE_SIZE = "sws" + final val SHUFFLE_WRITE_TIME = "swt" + final val STAGE = "stage" + final val STATUS = "sta" + final val TASK_INDEX = "idx" } /** - * The task information is always indexed with the stage ID, since that is how the UI and API - * consume it. That means every indexed value has the stage ID and attempt ID included, aside - * from the actual data being indexed. + * Unlike other data types, the task data wrapper does not keep a reference to the API's TaskData. + * That is to save memory, since for large applications there can be a large number of these + * elements (by default up to 100,000 per stage), and every bit of wasted memory adds up. + * + * It also contains many secondary indices, which are used to sort data efficiently in the UI at the + * expense of storage space (and slower write times). */ private[spark] class TaskDataWrapper( - val info: TaskData, + // Storing this as an object actually saves memory; it's also used as the key in the in-memory + // store, so in that case you'd save the extra copy of the value here. + @KVIndexParam + val taskId: JLong, + @KVIndexParam(value = TaskIndexNames.TASK_INDEX, parent = TaskIndexNames.STAGE) + val index: Int, + @KVIndexParam(value = TaskIndexNames.ATTEMPT, parent = TaskIndexNames.STAGE) + val attempt: Int, + @KVIndexParam(value = TaskIndexNames.LAUNCH_TIME, parent = TaskIndexNames.STAGE) + val launchTime: Long, + val resultFetchStart: Long, + @KVIndexParam(value = TaskIndexNames.DURATION, parent = TaskIndexNames.STAGE) + val duration: Long, + @KVIndexParam(value = TaskIndexNames.EXECUTOR, parent = TaskIndexNames.STAGE) + val executorId: String, + val host: String, + @KVIndexParam(value = TaskIndexNames.STATUS, parent = TaskIndexNames.STAGE) + val status: String, + @KVIndexParam(value = TaskIndexNames.LOCALITY, parent = TaskIndexNames.STAGE) + val taskLocality: String, + val speculative: Boolean, + val accumulatorUpdates: Seq[AccumulableInfo], + val errorMessage: Option[String], + + // The following is an exploded view of a TaskMetrics API object. This saves 5 objects + // (= 80 bytes of Java object overhead) per instance of this wrapper. If the first value + // (executorDeserializeTime) is -1L, it means the metrics for this task have not been + // recorded. + @KVIndexParam(value = TaskIndexNames.DESER_TIME, parent = TaskIndexNames.STAGE) + val executorDeserializeTime: Long, + @KVIndexParam(value = TaskIndexNames.DESER_CPU_TIME, parent = TaskIndexNames.STAGE) + val executorDeserializeCpuTime: Long, + @KVIndexParam(value = TaskIndexNames.EXEC_RUN_TIME, parent = TaskIndexNames.STAGE) + val executorRunTime: Long, + @KVIndexParam(value = TaskIndexNames.EXEC_CPU_TIME, parent = TaskIndexNames.STAGE) + val executorCpuTime: Long, + @KVIndexParam(value = TaskIndexNames.RESULT_SIZE, parent = TaskIndexNames.STAGE) + val resultSize: Long, + @KVIndexParam(value = TaskIndexNames.GC_TIME, parent = TaskIndexNames.STAGE) + val jvmGcTime: Long, + @KVIndexParam(value = TaskIndexNames.SER_TIME, parent = TaskIndexNames.STAGE) + val resultSerializationTime: Long, + @KVIndexParam(value = TaskIndexNames.MEM_SPILL, parent = TaskIndexNames.STAGE) + val memoryBytesSpilled: Long, + @KVIndexParam(value = TaskIndexNames.DISK_SPILL, parent = TaskIndexNames.STAGE) + val diskBytesSpilled: Long, + @KVIndexParam(value = TaskIndexNames.PEAK_MEM, parent = TaskIndexNames.STAGE) + val peakExecutionMemory: Long, + @KVIndexParam(value = TaskIndexNames.INPUT_SIZE, parent = TaskIndexNames.STAGE) + val inputBytesRead: Long, + @KVIndexParam(value = TaskIndexNames.INPUT_RECORDS, parent = TaskIndexNames.STAGE) + val inputRecordsRead: Long, + @KVIndexParam(value = TaskIndexNames.OUTPUT_SIZE, parent = TaskIndexNames.STAGE) + val outputBytesWritten: Long, + @KVIndexParam(value = TaskIndexNames.OUTPUT_RECORDS, parent = TaskIndexNames.STAGE) + val outputRecordsWritten: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_REMOTE_BLOCKS, parent = TaskIndexNames.STAGE) + val shuffleRemoteBlocksFetched: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_LOCAL_BLOCKS, parent = TaskIndexNames.STAGE) + val shuffleLocalBlocksFetched: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_READ_TIME, parent = TaskIndexNames.STAGE) + val shuffleFetchWaitTime: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_REMOTE_READS, parent = TaskIndexNames.STAGE) + val shuffleRemoteBytesRead: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_REMOTE_READS_TO_DISK, + parent = TaskIndexNames.STAGE) + val shuffleRemoteBytesReadToDisk: Long, + val shuffleLocalBytesRead: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_READ_RECORDS, parent = TaskIndexNames.STAGE) + val shuffleRecordsRead: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_WRITE_SIZE, parent = TaskIndexNames.STAGE) + val shuffleBytesWritten: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_WRITE_TIME, parent = TaskIndexNames.STAGE) + val shuffleWriteTime: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_WRITE_RECORDS, parent = TaskIndexNames.STAGE) + val shuffleRecordsWritten: Long, + val stageId: Int, val stageAttemptId: Int) { - @JsonIgnore @KVIndex - def id: Long = info.taskId + def hasMetrics: Boolean = executorDeserializeTime >= 0 + + def toApi: TaskData = { + val metrics = if (hasMetrics) { + Some(new TaskMetrics( + executorDeserializeTime, + executorDeserializeCpuTime, + executorRunTime, + executorCpuTime, + resultSize, + jvmGcTime, + resultSerializationTime, + memoryBytesSpilled, + diskBytesSpilled, + peakExecutionMemory, + new InputMetrics( + inputBytesRead, + inputRecordsRead), + new OutputMetrics( + outputBytesWritten, + outputRecordsWritten), + new ShuffleReadMetrics( + shuffleRemoteBlocksFetched, + shuffleLocalBlocksFetched, + shuffleFetchWaitTime, + shuffleRemoteBytesRead, + shuffleRemoteBytesReadToDisk, + shuffleLocalBytesRead, + shuffleRecordsRead), + new ShuffleWriteMetrics( + shuffleBytesWritten, + shuffleWriteTime, + shuffleRecordsWritten))) + } else { + None + } - @JsonIgnore @KVIndex("stage") - def stage: Array[Int] = Array(stageId, stageAttemptId) + new TaskData( + taskId, + index, + attempt, + new Date(launchTime), + if (resultFetchStart > 0L) Some(new Date(resultFetchStart)) else None, + if (duration > 0L) Some(duration) else None, + executorId, + host, + status, + taskLocality, + speculative, + accumulatorUpdates, + errorMessage, + metrics) + } + + @JsonIgnore @KVIndex(TaskIndexNames.STAGE) + private def stage: Array[Int] = Array(stageId, stageAttemptId) - @JsonIgnore @KVIndex("runtime") - def runtime: Array[AnyRef] = { - val _runtime = info.taskMetrics.map(_.executorRunTime).getOrElse(-1L) - Array(stageId: JInteger, stageAttemptId: JInteger, _runtime: JLong) + @JsonIgnore @KVIndex(value = TaskIndexNames.SCHEDULER_DELAY, parent = TaskIndexNames.STAGE) + def schedulerDelay: Long = { + if (hasMetrics) { + AppStatusUtils.schedulerDelay(launchTime, resultFetchStart, duration, executorDeserializeTime, + resultSerializationTime, executorRunTime) + } else { + -1L + } } - @JsonIgnore @KVIndex("startTime") - def startTime: Array[AnyRef] = { - Array(stageId: JInteger, stageAttemptId: JInteger, info.launchTime.getTime(): JLong) + @JsonIgnore @KVIndex(value = TaskIndexNames.GETTING_RESULT_TIME, parent = TaskIndexNames.STAGE) + def gettingResultTime: Long = { + if (hasMetrics) { + AppStatusUtils.gettingResultTime(launchTime, resultFetchStart, duration) + } else { + -1L + } } - @JsonIgnore @KVIndex("active") - def active: Boolean = info.duration.isEmpty + /** + * Sorting by accumulators is a little weird, and the previous behavior would generate + * insanely long keys in the index. So this implementation just considers the first + * accumulator and its String representation. + */ + @JsonIgnore @KVIndex(value = TaskIndexNames.ACCUMULATORS, parent = TaskIndexNames.STAGE) + private def accumulators: String = { + if (accumulatorUpdates.nonEmpty) { + val acc = accumulatorUpdates.head + s"${acc.name}:${acc.value}" + } else { + "" + } + } + + @JsonIgnore @KVIndex(value = TaskIndexNames.SHUFFLE_TOTAL_READS, parent = TaskIndexNames.STAGE) + private def shuffleTotalReads: Long = { + if (hasMetrics) { + shuffleLocalBytesRead + shuffleRemoteBytesRead + } else { + -1L + } + } + + @JsonIgnore @KVIndex(value = TaskIndexNames.SHUFFLE_TOTAL_BLOCKS, parent = TaskIndexNames.STAGE) + private def shuffleTotalBlocks: Long = { + if (hasMetrics) { + shuffleLocalBlocksFetched + shuffleRemoteBlocksFetched + } else { + -1L + } + } + + @JsonIgnore @KVIndex(value = TaskIndexNames.ERROR, parent = TaskIndexNames.STAGE) + private def error: String = if (errorMessage.isDefined) errorMessage.get else "" } @@ -134,10 +356,13 @@ private[spark] class ExecutorStageSummaryWrapper( val info: ExecutorStageSummary) { @JsonIgnore @KVIndex - val id: Array[Any] = Array(stageId, stageAttemptId, executorId) + private val _id: Array[Any] = Array(stageId, stageAttemptId, executorId) @JsonIgnore @KVIndex("stage") - private[this] val stage: Array[Int] = Array(stageId, stageAttemptId) + private def stage: Array[Int] = Array(stageId, stageAttemptId) + + @JsonIgnore + def id: Array[Any] = _id } @@ -203,3 +428,53 @@ private[spark] class AppSummary( def id: String = classOf[AppSummary].getName() } + +/** + * A cached view of a specific quantile for one stage attempt's metrics. + */ +private[spark] class CachedQuantile( + val stageId: Int, + val stageAttemptId: Int, + val quantile: String, + val taskCount: Long, + + // The following fields are an exploded view of a single entry for TaskMetricDistributions. + val executorDeserializeTime: Double, + val executorDeserializeCpuTime: Double, + val executorRunTime: Double, + val executorCpuTime: Double, + val resultSize: Double, + val jvmGcTime: Double, + val resultSerializationTime: Double, + val gettingResultTime: Double, + val schedulerDelay: Double, + val peakExecutionMemory: Double, + val memoryBytesSpilled: Double, + val diskBytesSpilled: Double, + + val bytesRead: Double, + val recordsRead: Double, + + val bytesWritten: Double, + val recordsWritten: Double, + + val shuffleReadBytes: Double, + val shuffleRecordsRead: Double, + val shuffleRemoteBlocksFetched: Double, + val shuffleLocalBlocksFetched: Double, + val shuffleFetchWaitTime: Double, + val shuffleRemoteBytesRead: Double, + val shuffleRemoteBytesReadToDisk: Double, + val shuffleTotalBlocksFetched: Double, + + val shuffleWriteBytes: Double, + val shuffleWriteRecords: Double, + val shuffleWriteTime: Double) { + + @KVIndex @JsonIgnore + def id: Array[Any] = Array(stageId, stageAttemptId, quantile) + + @KVIndex("stage") @JsonIgnore + def stage: Array[Int] = Array(stageId, stageAttemptId) + +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 41d42b52430a5..95c12b1e73653 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -87,7 +87,9 @@ private[ui] class ExecutorTable(stage: StageData, store: AppStatusStore) { } private def createExecutorTable(stage: StageData) : Seq[Node] = { - stage.executorSummary.getOrElse(Map.empty).toSeq.sortBy(_._1).map { case (k, v) => + val executorSummary = store.executorSummary(stage.stageId, stage.attemptId) + + executorSummary.toSeq.sortBy(_._1).map { case (k, v) => val executor = store.asOption(store.executorSummary(k)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 740f12e7d13d4..bf59152c8c0cd 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -201,7 +201,7 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP val stages = jobData.stageIds.map { stageId => // This could be empty if the listener hasn't received information about the // stage or if the stage information has been garbage collected - store.stageData(stageId).lastOption.getOrElse { + store.asOption(store.lastStageAttempt(stageId)).getOrElse { new v1.StageData( v1.StageStatus.PENDING, stageId, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 11a6a34344976..7c6e06cf183ba 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -19,6 +19,7 @@ package org.apache.spark.ui.jobs import java.net.URLEncoder import java.util.Date +import java.util.concurrent.TimeUnit import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{HashMap, HashSet} @@ -29,15 +30,14 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.SparkConf import org.apache.spark.internal.config._ import org.apache.spark.scheduler.TaskLocality -import org.apache.spark.status.AppStatusStore +import org.apache.spark.status._ import org.apache.spark.status.api.v1._ import org.apache.spark.ui._ -import org.apache.spark.util.{Distribution, Utils} +import org.apache.spark.util.Utils /** Page showing statistics and task list for a given stage */ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends WebUIPage("stage") { import ApiHelper._ - import StagePage._ private val TIMELINE_LEGEND = {
@@ -67,17 +67,17 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We // if we find that it's okay. private val MAX_TIMELINE_TASKS = parent.conf.getInt("spark.ui.timeline.tasks.maximum", 1000) - private def getLocalitySummaryString(stageData: StageData, taskList: Seq[TaskData]): String = { - val localities = taskList.map(_.taskLocality) - val localityCounts = localities.groupBy(identity).mapValues(_.size) + private def getLocalitySummaryString(localitySummary: Map[String, Long]): String = { val names = Map( TaskLocality.PROCESS_LOCAL.toString() -> "Process local", TaskLocality.NODE_LOCAL.toString() -> "Node local", TaskLocality.RACK_LOCAL.toString() -> "Rack local", TaskLocality.ANY.toString() -> "Any") - val localityNamesAndCounts = localityCounts.toSeq.map { case (locality, count) => - s"${names(locality)}: $count" - } + val localityNamesAndCounts = names.flatMap { case (key, name) => + localitySummary.get(key).map { count => + s"$name: $count" + } + }.toSeq localityNamesAndCounts.sorted.mkString("; ") } @@ -108,7 +108,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val stageHeader = s"Details for Stage $stageId (Attempt $stageAttemptId)" val stageData = parent.store - .asOption(parent.store.stageAttempt(stageId, stageAttemptId, details = true)) + .asOption(parent.store.stageAttempt(stageId, stageAttemptId, details = false)) .getOrElse { val content =
@@ -117,8 +117,11 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We return UIUtils.headerSparkPage(stageHeader, content, parent) } - val tasks = stageData.tasks.getOrElse(Map.empty).values.toSeq - if (tasks.isEmpty) { + val localitySummary = store.localitySummary(stageData.stageId, stageData.attemptId) + + val totalTasks = stageData.numActiveTasks + stageData.numCompleteTasks + + stageData.numFailedTasks + stageData.numKilledTasks + if (totalTasks == 0) { val content =

Summary Metrics

No tasks have started yet @@ -127,18 +130,14 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We return UIUtils.headerSparkPage(stageHeader, content, parent) } + val storedTasks = store.taskCount(stageData.stageId, stageData.attemptId) val numCompleted = stageData.numCompleteTasks - val totalTasks = stageData.numActiveTasks + stageData.numCompleteTasks + - stageData.numFailedTasks + stageData.numKilledTasks - val totalTasksNumStr = if (totalTasks == tasks.size) { + val totalTasksNumStr = if (totalTasks == storedTasks) { s"$totalTasks" } else { - s"$totalTasks, showing ${tasks.size}" + s"$totalTasks, showing ${storedTasks}" } - val externalAccumulables = stageData.accumulatorUpdates - val hasAccumulators = externalAccumulables.size > 0 - val summary =
    @@ -148,7 +147,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
  • Locality Level Summary: - {getLocalitySummaryString(stageData, tasks)} + {getLocalitySummaryString(localitySummary)}
  • {if (hasInput(stageData)) {
  • @@ -266,7 +265,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val accumulableTable = UIUtils.listingTable( accumulableHeaders, accumulableRow, - externalAccumulables.toSeq) + stageData.accumulatorUpdates.toSeq) val page: Int = { // If the user has changed to a larger page size, then go to page 1 in order to avoid @@ -280,16 +279,9 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val currentTime = System.currentTimeMillis() val (taskTable, taskTableHTML) = try { val _taskTable = new TaskPagedTable( - parent.conf, + stageData, UIUtils.prependBaseUri(parent.basePath) + s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", - tasks, - hasAccumulators, - hasInput(stageData), - hasOutput(stageData), - hasShuffleRead(stageData), - hasShuffleWrite(stageData), - hasBytesSpilled(stageData), currentTime, pageSize = taskPageSize, sortColumn = taskSortColumn, @@ -320,217 +312,155 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We | } |}); """.stripMargin - } + } } - val taskIdsInPage = if (taskTable == null) Set.empty[Long] - else taskTable.dataSource.slicedTaskIds + val metricsSummary = store.taskSummary(stageData.stageId, stageData.attemptId, + Array(0, 0.25, 0.5, 0.75, 1.0)) - // Excludes tasks which failed and have incomplete metrics - val validTasks = tasks.filter(t => t.status == "SUCCESS" && t.taskMetrics.isDefined) - - val summaryTable: Option[Seq[Node]] = - if (validTasks.size == 0) { - None - } else { - def getDistributionQuantiles(data: Seq[Double]): IndexedSeq[Double] = { - Distribution(data).get.getQuantiles() - } - def getFormattedTimeQuantiles(times: Seq[Double]): Seq[Node] = { - getDistributionQuantiles(times).map { millis => - {UIUtils.formatDuration(millis.toLong)} - } - } - def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = { - getDistributionQuantiles(data).map(d => {Utils.bytesToString(d.toLong)}) + val summaryTable = metricsSummary.map { metrics => + def timeQuantiles(data: IndexedSeq[Double]): Seq[Node] = { + data.map { millis => + {UIUtils.formatDuration(millis.toLong)} } + } - val deserializationTimes = validTasks.map { task => - task.taskMetrics.get.executorDeserializeTime.toDouble - } - val deserializationQuantiles = - - - Task Deserialization Time - - +: getFormattedTimeQuantiles(deserializationTimes) - - val serviceTimes = validTasks.map(_.taskMetrics.get.executorRunTime.toDouble) - val serviceQuantiles = Duration +: getFormattedTimeQuantiles(serviceTimes) - - val gcTimes = validTasks.map(_.taskMetrics.get.jvmGcTime.toDouble) - val gcQuantiles = - - GC Time - - +: getFormattedTimeQuantiles(gcTimes) - - val serializationTimes = validTasks.map(_.taskMetrics.get.resultSerializationTime.toDouble) - val serializationQuantiles = - - - Result Serialization Time - - +: getFormattedTimeQuantiles(serializationTimes) - - val gettingResultTimes = validTasks.map(getGettingResultTime(_, currentTime).toDouble) - val gettingResultQuantiles = - - - Getting Result Time - - +: - getFormattedTimeQuantiles(gettingResultTimes) - - val peakExecutionMemory = validTasks.map(_.taskMetrics.get.peakExecutionMemory.toDouble) - val peakExecutionMemoryQuantiles = { - - - Peak Execution Memory - - +: getFormattedSizeQuantiles(peakExecutionMemory) + def sizeQuantiles(data: IndexedSeq[Double]): Seq[Node] = { + data.map { size => + {Utils.bytesToString(size.toLong)} } + } - // The scheduler delay includes the network delay to send the task to the worker - // machine and to send back the result (but not the time to fetch the task result, - // if it needed to be fetched from the block manager on the worker). - val schedulerDelays = validTasks.map { task => - getSchedulerDelay(task, task.taskMetrics.get, currentTime).toDouble - } - val schedulerDelayTitle = Scheduler Delay - val schedulerDelayQuantiles = schedulerDelayTitle +: - getFormattedTimeQuantiles(schedulerDelays) - def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double]) - : Seq[Elem] = { - val recordDist = getDistributionQuantiles(records).iterator - getDistributionQuantiles(data).map(d => - {s"${Utils.bytesToString(d.toLong)} / ${recordDist.next().toLong}"} - ) + def sizeQuantilesWithRecords( + data: IndexedSeq[Double], + records: IndexedSeq[Double]) : Seq[Node] = { + data.zip(records).map { case (d, r) => + {s"${Utils.bytesToString(d.toLong)} / ${r.toLong}"} } + } - val inputSizes = validTasks.map(_.taskMetrics.get.inputMetrics.bytesRead.toDouble) - val inputRecords = validTasks.map(_.taskMetrics.get.inputMetrics.recordsRead.toDouble) - val inputQuantiles = Input Size / Records +: - getFormattedSizeQuantilesWithRecords(inputSizes, inputRecords) + def titleCell(title: String, tooltip: String): Seq[Node] = { + + + {title} + + + } - val outputSizes = validTasks.map(_.taskMetrics.get.outputMetrics.bytesWritten.toDouble) - val outputRecords = validTasks.map(_.taskMetrics.get.outputMetrics.recordsWritten.toDouble) - val outputQuantiles = Output Size / Records +: - getFormattedSizeQuantilesWithRecords(outputSizes, outputRecords) + def simpleTitleCell(title: String): Seq[Node] = {title} - val shuffleReadBlockedTimes = validTasks.map { task => - task.taskMetrics.get.shuffleReadMetrics.fetchWaitTime.toDouble - } - val shuffleReadBlockedQuantiles = - - - Shuffle Read Blocked Time - - +: - getFormattedTimeQuantiles(shuffleReadBlockedTimes) - - val shuffleReadTotalSizes = validTasks.map { task => - totalBytesRead(task.taskMetrics.get.shuffleReadMetrics).toDouble - } - val shuffleReadTotalRecords = validTasks.map { task => - task.taskMetrics.get.shuffleReadMetrics.recordsRead.toDouble - } - val shuffleReadTotalQuantiles = - - - Shuffle Read Size / Records - - +: - getFormattedSizeQuantilesWithRecords(shuffleReadTotalSizes, shuffleReadTotalRecords) - - val shuffleReadRemoteSizes = validTasks.map { task => - task.taskMetrics.get.shuffleReadMetrics.remoteBytesRead.toDouble - } - val shuffleReadRemoteQuantiles = - - - Shuffle Remote Reads - - +: - getFormattedSizeQuantiles(shuffleReadRemoteSizes) - - val shuffleWriteSizes = validTasks.map { task => - task.taskMetrics.get.shuffleWriteMetrics.bytesWritten.toDouble - } + val deserializationQuantiles = titleCell("Task Deserialization Time", + ToolTips.TASK_DESERIALIZATION_TIME) ++ timeQuantiles(metrics.executorDeserializeTime) - val shuffleWriteRecords = validTasks.map { task => - task.taskMetrics.get.shuffleWriteMetrics.recordsWritten.toDouble - } + val serviceQuantiles = simpleTitleCell("Duration") ++ timeQuantiles(metrics.executorRunTime) - val shuffleWriteQuantiles = Shuffle Write Size / Records +: - getFormattedSizeQuantilesWithRecords(shuffleWriteSizes, shuffleWriteRecords) + val gcQuantiles = titleCell("GC Time", ToolTips.GC_TIME) ++ timeQuantiles(metrics.jvmGcTime) - val memoryBytesSpilledSizes = validTasks.map(_.taskMetrics.get.memoryBytesSpilled.toDouble) - val memoryBytesSpilledQuantiles = Shuffle spill (memory) +: - getFormattedSizeQuantiles(memoryBytesSpilledSizes) + val serializationQuantiles = titleCell("Result Serialization Time", + ToolTips.RESULT_SERIALIZATION_TIME) ++ timeQuantiles(metrics.resultSerializationTime) - val diskBytesSpilledSizes = validTasks.map(_.taskMetrics.get.diskBytesSpilled.toDouble) - val diskBytesSpilledQuantiles = Shuffle spill (disk) +: - getFormattedSizeQuantiles(diskBytesSpilledSizes) + val gettingResultQuantiles = titleCell("Getting Result Time", ToolTips.GETTING_RESULT_TIME) ++ + timeQuantiles(metrics.gettingResultTime) - val listings: Seq[Seq[Node]] = Seq( - {serviceQuantiles}, - {schedulerDelayQuantiles}, - - {deserializationQuantiles} - - {gcQuantiles}, - - {serializationQuantiles} - , - {gettingResultQuantiles}, - - {peakExecutionMemoryQuantiles} - , - if (hasInput(stageData)) {inputQuantiles} else Nil, - if (hasOutput(stageData)) {outputQuantiles} else Nil, - if (hasShuffleRead(stageData)) { - - {shuffleReadBlockedQuantiles} - - {shuffleReadTotalQuantiles} - - {shuffleReadRemoteQuantiles} - - } else { - Nil - }, - if (hasShuffleWrite(stageData)) {shuffleWriteQuantiles} else Nil, - if (hasBytesSpilled(stageData)) {memoryBytesSpilledQuantiles} else Nil, - if (hasBytesSpilled(stageData)) {diskBytesSpilledQuantiles} else Nil) - - val quantileHeaders = Seq("Metric", "Min", "25th percentile", - "Median", "75th percentile", "Max") - // The summary table does not use CSS to stripe rows, which doesn't work with hidden - // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows). - Some(UIUtils.listingTable( - quantileHeaders, - identity[Seq[Node]], - listings, - fixedWidth = true, - id = Some("task-summary-table"), - stripeRowsWithCss = false)) + val peakExecutionMemoryQuantiles = titleCell("Peak Execution Memory", + ToolTips.PEAK_EXECUTION_MEMORY) ++ sizeQuantiles(metrics.peakExecutionMemory) + + // The scheduler delay includes the network delay to send the task to the worker + // machine and to send back the result (but not the time to fetch the task result, + // if it needed to be fetched from the block manager on the worker). + val schedulerDelayQuantiles = titleCell("Scheduler Delay", ToolTips.SCHEDULER_DELAY) ++ + timeQuantiles(metrics.schedulerDelay) + + def inputQuantiles: Seq[Node] = { + simpleTitleCell("Input Size / Records") ++ + sizeQuantilesWithRecords(metrics.inputMetrics.bytesRead, metrics.inputMetrics.recordsRead) + } + + def outputQuantiles: Seq[Node] = { + simpleTitleCell("Output Size / Records") ++ + sizeQuantilesWithRecords(metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten) } + def shuffleReadBlockedQuantiles: Seq[Node] = { + titleCell("Shuffle Read Blocked Time", ToolTips.SHUFFLE_READ_BLOCKED_TIME) ++ + timeQuantiles(metrics.shuffleReadMetrics.fetchWaitTime) + } + + def shuffleReadTotalQuantiles: Seq[Node] = { + titleCell("Shuffle Read Size / Records", ToolTips.SHUFFLE_READ) ++ + sizeQuantilesWithRecords(metrics.shuffleReadMetrics.readBytes, + metrics.shuffleReadMetrics.readRecords) + } + + def shuffleReadRemoteQuantiles: Seq[Node] = { + titleCell("Shuffle Remote Reads", ToolTips.SHUFFLE_READ_REMOTE_SIZE) ++ + sizeQuantiles(metrics.shuffleReadMetrics.remoteBytesRead) + } + + def shuffleWriteQuantiles: Seq[Node] = { + simpleTitleCell("Shuffle Write Size / Records") ++ + sizeQuantilesWithRecords(metrics.shuffleWriteMetrics.writeBytes, + metrics.shuffleWriteMetrics.writeRecords) + } + + def memoryBytesSpilledQuantiles: Seq[Node] = { + simpleTitleCell("Shuffle spill (memory)") ++ sizeQuantiles(metrics.memoryBytesSpilled) + } + + def diskBytesSpilledQuantiles: Seq[Node] = { + simpleTitleCell("Shuffle spill (disk)") ++ sizeQuantiles(metrics.diskBytesSpilled) + } + + val listings: Seq[Seq[Node]] = Seq( + {serviceQuantiles}, + {schedulerDelayQuantiles}, + + {deserializationQuantiles} + + {gcQuantiles}, + + {serializationQuantiles} + , + {gettingResultQuantiles}, + + {peakExecutionMemoryQuantiles} + , + if (hasInput(stageData)) {inputQuantiles} else Nil, + if (hasOutput(stageData)) {outputQuantiles} else Nil, + if (hasShuffleRead(stageData)) { + + {shuffleReadBlockedQuantiles} + + {shuffleReadTotalQuantiles} + + {shuffleReadRemoteQuantiles} + + } else { + Nil + }, + if (hasShuffleWrite(stageData)) {shuffleWriteQuantiles} else Nil, + if (hasBytesSpilled(stageData)) {memoryBytesSpilledQuantiles} else Nil, + if (hasBytesSpilled(stageData)) {diskBytesSpilledQuantiles} else Nil) + + val quantileHeaders = Seq("Metric", "Min", "25th percentile", "Median", "75th percentile", + "Max") + // The summary table does not use CSS to stripe rows, which doesn't work with hidden + // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows). + UIUtils.listingTable( + quantileHeaders, + identity[Seq[Node]], + listings, + fixedWidth = true, + id = Some("task-summary-table"), + stripeRowsWithCss = false) + } + val executorTable = new ExecutorTable(stageData, parent.store) val maybeAccumulableTable: Seq[Node] = - if (hasAccumulators) {

    Accumulators

    ++ accumulableTable } else Seq() + if (hasAccumulators(stageData)) {

    Accumulators

    ++ accumulableTable } else Seq() val aggMetrics = taskIdsInPage.contains(t.taskId) }, + Option(taskTable).map(_.dataSource.tasks).getOrElse(Nil), currentTime) ++

    Summary Metrics for {numCompleted} Completed Tasks

    ++
    {summaryTable.getOrElse("No tasks have reported metrics yet.")}
    ++ @@ -593,10 +523,9 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val serializationTimeProportion = toProportion(serializationTime) val deserializationTime = metricsOpt.map(_.executorDeserializeTime).getOrElse(0L) val deserializationTimeProportion = toProportion(deserializationTime) - val gettingResultTime = getGettingResultTime(taskInfo, currentTime) + val gettingResultTime = AppStatusUtils.gettingResultTime(taskInfo) val gettingResultTimeProportion = toProportion(gettingResultTime) - val schedulerDelay = - metricsOpt.map(getSchedulerDelay(taskInfo, _, currentTime)).getOrElse(0L) + val schedulerDelay = AppStatusUtils.schedulerDelay(taskInfo) val schedulerDelayProportion = toProportion(schedulerDelay) val executorOverhead = serializationTime + deserializationTime @@ -708,7 +637,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We { if (MAX_TIMELINE_TASKS < tasks.size) { - This stage has more than the maximum number of tasks that can be shown in the + This page has more than the maximum number of tasks that can be shown in the visualization! Only the most recent {MAX_TIMELINE_TASKS} tasks (of {tasks.size} total) are shown. @@ -733,402 +662,49 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We } -private[ui] object StagePage { - private[ui] def getGettingResultTime(info: TaskData, currentTime: Long): Long = { - info.resultFetchStart match { - case Some(start) => - info.duration match { - case Some(duration) => - info.launchTime.getTime() + duration - start.getTime() - - case _ => - currentTime - start.getTime() - } - - case _ => - 0L - } - } - - private[ui] def getSchedulerDelay( - info: TaskData, - metrics: TaskMetrics, - currentTime: Long): Long = { - info.duration match { - case Some(duration) => - val executorOverhead = metrics.executorDeserializeTime + metrics.resultSerializationTime - math.max( - 0, - duration - metrics.executorRunTime - executorOverhead - - getGettingResultTime(info, currentTime)) - - case _ => - // The task is still running and the metrics like executorRunTime are not available. - 0L - } - } - -} - -private[ui] case class TaskTableRowInputData(inputSortable: Long, inputReadable: String) - -private[ui] case class TaskTableRowOutputData(outputSortable: Long, outputReadable: String) - -private[ui] case class TaskTableRowShuffleReadData( - shuffleReadBlockedTimeSortable: Long, - shuffleReadBlockedTimeReadable: String, - shuffleReadSortable: Long, - shuffleReadReadable: String, - shuffleReadRemoteSortable: Long, - shuffleReadRemoteReadable: String) - -private[ui] case class TaskTableRowShuffleWriteData( - writeTimeSortable: Long, - writeTimeReadable: String, - shuffleWriteSortable: Long, - shuffleWriteReadable: String) - -private[ui] case class TaskTableRowBytesSpilledData( - memoryBytesSpilledSortable: Long, - memoryBytesSpilledReadable: String, - diskBytesSpilledSortable: Long, - diskBytesSpilledReadable: String) - -/** - * Contains all data that needs for sorting and generating HTML. Using this one rather than - * TaskData to avoid creating duplicate contents during sorting the data. - */ -private[ui] class TaskTableRowData( - val index: Int, - val taskId: Long, - val attempt: Int, - val speculative: Boolean, - val status: String, - val taskLocality: String, - val executorId: String, - val host: String, - val launchTime: Long, - val duration: Long, - val formatDuration: String, - val schedulerDelay: Long, - val taskDeserializationTime: Long, - val gcTime: Long, - val serializationTime: Long, - val gettingResultTime: Long, - val peakExecutionMemoryUsed: Long, - val accumulators: Option[String], // HTML - val input: Option[TaskTableRowInputData], - val output: Option[TaskTableRowOutputData], - val shuffleRead: Option[TaskTableRowShuffleReadData], - val shuffleWrite: Option[TaskTableRowShuffleWriteData], - val bytesSpilled: Option[TaskTableRowBytesSpilledData], - val error: String, - val logs: Map[String, String]) - private[ui] class TaskDataSource( - tasks: Seq[TaskData], - hasAccumulators: Boolean, - hasInput: Boolean, - hasOutput: Boolean, - hasShuffleRead: Boolean, - hasShuffleWrite: Boolean, - hasBytesSpilled: Boolean, + stage: StageData, currentTime: Long, pageSize: Int, sortColumn: String, desc: Boolean, - store: AppStatusStore) extends PagedDataSource[TaskTableRowData](pageSize) { - import StagePage._ + store: AppStatusStore) extends PagedDataSource[TaskData](pageSize) { + import ApiHelper._ // Keep an internal cache of executor log maps so that long task lists render faster. private val executorIdToLogs = new HashMap[String, Map[String, String]]() - // Convert TaskData to TaskTableRowData which contains the final contents to show in the table - // so that we can avoid creating duplicate contents during sorting the data - private val data = tasks.map(taskRow).sorted(ordering(sortColumn, desc)) - - private var _slicedTaskIds: Set[Long] = _ + private var _tasksToShow: Seq[TaskData] = null - override def dataSize: Int = data.size + override def dataSize: Int = stage.numCompleteTasks + stage.numFailedTasks + stage.numKilledTasks - override def sliceData(from: Int, to: Int): Seq[TaskTableRowData] = { - val r = data.slice(from, to) - _slicedTaskIds = r.map(_.taskId).toSet - r - } - - def slicedTaskIds: Set[Long] = _slicedTaskIds - - private def taskRow(info: TaskData): TaskTableRowData = { - val metrics = info.taskMetrics - val duration = info.duration.getOrElse(1L) - val formatDuration = info.duration.map(d => UIUtils.formatDuration(d)).getOrElse("") - val schedulerDelay = metrics.map(getSchedulerDelay(info, _, currentTime)).getOrElse(0L) - val gcTime = metrics.map(_.jvmGcTime).getOrElse(0L) - val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) - val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) - val gettingResultTime = getGettingResultTime(info, currentTime) - - val externalAccumulableReadable = info.accumulatorUpdates.map { acc => - StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update}") + override def sliceData(from: Int, to: Int): Seq[TaskData] = { + if (_tasksToShow == null) { + _tasksToShow = store.taskList(stage.stageId, stage.attemptId, from, to - from, + indexName(sortColumn), !desc) } - val peakExecutionMemoryUsed = metrics.map(_.peakExecutionMemory).getOrElse(0L) - - val maybeInput = metrics.map(_.inputMetrics) - val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L) - val inputReadable = maybeInput - .map(m => s"${Utils.bytesToString(m.bytesRead)}") - .getOrElse("") - val inputRecords = maybeInput.map(_.recordsRead.toString).getOrElse("") - - val maybeOutput = metrics.map(_.outputMetrics) - val outputSortable = maybeOutput.map(_.bytesWritten).getOrElse(0L) - val outputReadable = maybeOutput - .map(m => s"${Utils.bytesToString(m.bytesWritten)}") - .getOrElse("") - val outputRecords = maybeOutput.map(_.recordsWritten.toString).getOrElse("") - - val maybeShuffleRead = metrics.map(_.shuffleReadMetrics) - val shuffleReadBlockedTimeSortable = maybeShuffleRead.map(_.fetchWaitTime).getOrElse(0L) - val shuffleReadBlockedTimeReadable = - maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("") - - val totalShuffleBytes = maybeShuffleRead.map(ApiHelper.totalBytesRead) - val shuffleReadSortable = totalShuffleBytes.getOrElse(0L) - val shuffleReadReadable = totalShuffleBytes.map(Utils.bytesToString).getOrElse("") - val shuffleReadRecords = maybeShuffleRead.map(_.recordsRead.toString).getOrElse("") - - val remoteShuffleBytes = maybeShuffleRead.map(_.remoteBytesRead) - val shuffleReadRemoteSortable = remoteShuffleBytes.getOrElse(0L) - val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("") - - val maybeShuffleWrite = metrics.map(_.shuffleWriteMetrics) - val shuffleWriteSortable = maybeShuffleWrite.map(_.bytesWritten).getOrElse(0L) - val shuffleWriteReadable = maybeShuffleWrite - .map(m => s"${Utils.bytesToString(m.bytesWritten)}").getOrElse("") - val shuffleWriteRecords = maybeShuffleWrite - .map(_.recordsWritten.toString).getOrElse("") - - val maybeWriteTime = metrics.map(_.shuffleWriteMetrics.writeTime) - val writeTimeSortable = maybeWriteTime.getOrElse(0L) - val writeTimeReadable = maybeWriteTime.map(t => t / (1000 * 1000)).map { ms => - if (ms == 0) "" else UIUtils.formatDuration(ms) - }.getOrElse("") - - val maybeMemoryBytesSpilled = metrics.map(_.memoryBytesSpilled) - val memoryBytesSpilledSortable = maybeMemoryBytesSpilled.getOrElse(0L) - val memoryBytesSpilledReadable = - maybeMemoryBytesSpilled.map(Utils.bytesToString).getOrElse("") - - val maybeDiskBytesSpilled = metrics.map(_.diskBytesSpilled) - val diskBytesSpilledSortable = maybeDiskBytesSpilled.getOrElse(0L) - val diskBytesSpilledReadable = maybeDiskBytesSpilled.map(Utils.bytesToString).getOrElse("") - - val input = - if (hasInput) { - Some(TaskTableRowInputData(inputSortable, s"$inputReadable / $inputRecords")) - } else { - None - } - - val output = - if (hasOutput) { - Some(TaskTableRowOutputData(outputSortable, s"$outputReadable / $outputRecords")) - } else { - None - } - - val shuffleRead = - if (hasShuffleRead) { - Some(TaskTableRowShuffleReadData( - shuffleReadBlockedTimeSortable, - shuffleReadBlockedTimeReadable, - shuffleReadSortable, - s"$shuffleReadReadable / $shuffleReadRecords", - shuffleReadRemoteSortable, - shuffleReadRemoteReadable - )) - } else { - None - } - - val shuffleWrite = - if (hasShuffleWrite) { - Some(TaskTableRowShuffleWriteData( - writeTimeSortable, - writeTimeReadable, - shuffleWriteSortable, - s"$shuffleWriteReadable / $shuffleWriteRecords" - )) - } else { - None - } - - val bytesSpilled = - if (hasBytesSpilled) { - Some(TaskTableRowBytesSpilledData( - memoryBytesSpilledSortable, - memoryBytesSpilledReadable, - diskBytesSpilledSortable, - diskBytesSpilledReadable - )) - } else { - None - } - - new TaskTableRowData( - info.index, - info.taskId, - info.attempt, - info.speculative, - info.status, - info.taskLocality.toString, - info.executorId, - info.host, - info.launchTime.getTime(), - duration, - formatDuration, - schedulerDelay, - taskDeserializationTime, - gcTime, - serializationTime, - gettingResultTime, - peakExecutionMemoryUsed, - if (hasAccumulators) Some(externalAccumulableReadable.mkString("
    ")) else None, - input, - output, - shuffleRead, - shuffleWrite, - bytesSpilled, - info.errorMessage.getOrElse(""), - executorLogs(info.executorId)) + _tasksToShow } - private def executorLogs(id: String): Map[String, String] = { + def tasks: Seq[TaskData] = _tasksToShow + + def executorLogs(id: String): Map[String, String] = { executorIdToLogs.getOrElseUpdate(id, store.asOption(store.executorSummary(id)).map(_.executorLogs).getOrElse(Map.empty)) } - /** - * Return Ordering according to sortColumn and desc - */ - private def ordering(sortColumn: String, desc: Boolean): Ordering[TaskTableRowData] = { - val ordering: Ordering[TaskTableRowData] = sortColumn match { - case "Index" => Ordering.by(_.index) - case "ID" => Ordering.by(_.taskId) - case "Attempt" => Ordering.by(_.attempt) - case "Status" => Ordering.by(_.status) - case "Locality Level" => Ordering.by(_.taskLocality) - case "Executor ID" => Ordering.by(_.executorId) - case "Host" => Ordering.by(_.host) - case "Launch Time" => Ordering.by(_.launchTime) - case "Duration" => Ordering.by(_.duration) - case "Scheduler Delay" => Ordering.by(_.schedulerDelay) - case "Task Deserialization Time" => Ordering.by(_.taskDeserializationTime) - case "GC Time" => Ordering.by(_.gcTime) - case "Result Serialization Time" => Ordering.by(_.serializationTime) - case "Getting Result Time" => Ordering.by(_.gettingResultTime) - case "Peak Execution Memory" => Ordering.by(_.peakExecutionMemoryUsed) - case "Accumulators" => - if (hasAccumulators) { - Ordering.by(_.accumulators.get) - } else { - throw new IllegalArgumentException( - "Cannot sort by Accumulators because of no accumulators") - } - case "Input Size / Records" => - if (hasInput) { - Ordering.by(_.input.get.inputSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Input Size / Records because of no inputs") - } - case "Output Size / Records" => - if (hasOutput) { - Ordering.by(_.output.get.outputSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Output Size / Records because of no outputs") - } - // ShuffleRead - case "Shuffle Read Blocked Time" => - if (hasShuffleRead) { - Ordering.by(_.shuffleRead.get.shuffleReadBlockedTimeSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Read Blocked Time because of no shuffle reads") - } - case "Shuffle Read Size / Records" => - if (hasShuffleRead) { - Ordering.by(_.shuffleRead.get.shuffleReadSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Read Size / Records because of no shuffle reads") - } - case "Shuffle Remote Reads" => - if (hasShuffleRead) { - Ordering.by(_.shuffleRead.get.shuffleReadRemoteSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Remote Reads because of no shuffle reads") - } - // ShuffleWrite - case "Write Time" => - if (hasShuffleWrite) { - Ordering.by(_.shuffleWrite.get.writeTimeSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Write Time because of no shuffle writes") - } - case "Shuffle Write Size / Records" => - if (hasShuffleWrite) { - Ordering.by(_.shuffleWrite.get.shuffleWriteSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Write Size / Records because of no shuffle writes") - } - // BytesSpilled - case "Shuffle Spill (Memory)" => - if (hasBytesSpilled) { - Ordering.by(_.bytesSpilled.get.memoryBytesSpilledSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Spill (Memory) because of no spills") - } - case "Shuffle Spill (Disk)" => - if (hasBytesSpilled) { - Ordering.by(_.bytesSpilled.get.diskBytesSpilledSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Spill (Disk) because of no spills") - } - case "Errors" => Ordering.by(_.error) - case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") - } - if (desc) { - ordering.reverse - } else { - ordering - } - } - } private[ui] class TaskPagedTable( - conf: SparkConf, + stage: StageData, basePath: String, - data: Seq[TaskData], - hasAccumulators: Boolean, - hasInput: Boolean, - hasOutput: Boolean, - hasShuffleRead: Boolean, - hasShuffleWrite: Boolean, - hasBytesSpilled: Boolean, currentTime: Long, pageSize: Int, sortColumn: String, desc: Boolean, - store: AppStatusStore) extends PagedTable[TaskTableRowData] { + store: AppStatusStore) extends PagedTable[TaskData] { + + import ApiHelper._ override def tableId: String = "task-table" @@ -1142,13 +718,7 @@ private[ui] class TaskPagedTable( override def pageNumberFormField: String = "task.page" override val dataSource: TaskDataSource = new TaskDataSource( - data, - hasAccumulators, - hasInput, - hasOutput, - hasShuffleRead, - hasShuffleWrite, - hasBytesSpilled, + stage, currentTime, pageSize, sortColumn, @@ -1180,22 +750,22 @@ private[ui] class TaskPagedTable( ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME), ("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) ++ - {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ - {if (hasInput) Seq(("Input Size / Records", "")) else Nil} ++ - {if (hasOutput) Seq(("Output Size / Records", "")) else Nil} ++ - {if (hasShuffleRead) { + {if (hasAccumulators(stage)) Seq(("Accumulators", "")) else Nil} ++ + {if (hasInput(stage)) Seq(("Input Size / Records", "")) else Nil} ++ + {if (hasOutput(stage)) Seq(("Output Size / Records", "")) else Nil} ++ + {if (hasShuffleRead(stage)) { Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), ("Shuffle Read Size / Records", ""), ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) } else { Nil }} ++ - {if (hasShuffleWrite) { + {if (hasShuffleWrite(stage)) { Seq(("Write Time", ""), ("Shuffle Write Size / Records", "")) } else { Nil }} ++ - {if (hasBytesSpilled) { + {if (hasBytesSpilled(stage)) { Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) } else { Nil @@ -1237,7 +807,17 @@ private[ui] class TaskPagedTable( {headerRow} } - def row(task: TaskTableRowData): Seq[Node] = { + def row(task: TaskData): Seq[Node] = { + def formatDuration(value: Option[Long], hideZero: Boolean = false): String = { + value.map { v => + if (v > 0 || !hideZero) UIUtils.formatDuration(v) else "" + }.getOrElse("") + } + + def formatBytes(value: Option[Long]): String = { + Utils.bytesToString(value.getOrElse(0L)) + } + {task.index} {task.taskId} @@ -1249,62 +829,98 @@ private[ui] class TaskPagedTable(
    {task.host}
    { - task.logs.map { + dataSource.executorLogs(task.executorId).map { case (logName, logUrl) => } }
    - {UIUtils.formatDate(new Date(task.launchTime))} - {task.formatDuration} + {UIUtils.formatDate(task.launchTime)} + {formatDuration(task.duration)} - {UIUtils.formatDuration(task.schedulerDelay)} + {UIUtils.formatDuration(AppStatusUtils.schedulerDelay(task))} - {UIUtils.formatDuration(task.taskDeserializationTime)} + {formatDuration(task.taskMetrics.map(_.executorDeserializeTime))} - {if (task.gcTime > 0) UIUtils.formatDuration(task.gcTime) else ""} + {formatDuration(task.taskMetrics.map(_.jvmGcTime), hideZero = true)} - {UIUtils.formatDuration(task.serializationTime)} + {formatDuration(task.taskMetrics.map(_.resultSerializationTime))} - {UIUtils.formatDuration(task.gettingResultTime)} + {UIUtils.formatDuration(AppStatusUtils.gettingResultTime(task))} - {Utils.bytesToString(task.peakExecutionMemoryUsed)} + {formatBytes(task.taskMetrics.map(_.peakExecutionMemory))} - {if (task.accumulators.nonEmpty) { - {Unparsed(task.accumulators.get)} + {if (hasAccumulators(stage)) { + accumulatorsInfo(task) }} - {if (task.input.nonEmpty) { - {task.input.get.inputReadable} + {if (hasInput(stage)) { + metricInfo(task) { m => + val bytesRead = Utils.bytesToString(m.inputMetrics.bytesRead) + val records = m.inputMetrics.recordsRead + {bytesRead} / {records} + } }} - {if (task.output.nonEmpty) { - {task.output.get.outputReadable} + {if (hasOutput(stage)) { + metricInfo(task) { m => + val bytesWritten = Utils.bytesToString(m.outputMetrics.bytesWritten) + val records = m.outputMetrics.recordsWritten + {bytesWritten} / {records} + } }} - {if (task.shuffleRead.nonEmpty) { + {if (hasShuffleRead(stage)) { - {task.shuffleRead.get.shuffleReadBlockedTimeReadable} + {formatDuration(task.taskMetrics.map(_.shuffleReadMetrics.fetchWaitTime))} - {task.shuffleRead.get.shuffleReadReadable} + { + metricInfo(task) { m => + val bytesRead = Utils.bytesToString(totalBytesRead(m.shuffleReadMetrics)) + val records = m.shuffleReadMetrics.recordsRead + Unparsed(s"$bytesRead / $records") + } + } - {task.shuffleRead.get.shuffleReadRemoteReadable} + {formatBytes(task.taskMetrics.map(_.shuffleReadMetrics.remoteBytesRead))} }} - {if (task.shuffleWrite.nonEmpty) { - {task.shuffleWrite.get.writeTimeReadable} - {task.shuffleWrite.get.shuffleWriteReadable} + {if (hasShuffleWrite(stage)) { + { + formatDuration( + task.taskMetrics.map { m => + TimeUnit.NANOSECONDS.toMillis(m.shuffleWriteMetrics.writeTime) + }, + hideZero = true) + } + { + metricInfo(task) { m => + val bytesWritten = Utils.bytesToString(m.shuffleWriteMetrics.bytesWritten) + val records = m.shuffleWriteMetrics.recordsWritten + Unparsed(s"$bytesWritten / $records") + } + } }} - {if (task.bytesSpilled.nonEmpty) { - {task.bytesSpilled.get.memoryBytesSpilledReadable} - {task.bytesSpilled.get.diskBytesSpilledReadable} + {if (hasBytesSpilled(stage)) { + {formatBytes(task.taskMetrics.map(_.memoryBytesSpilled))} + {formatBytes(task.taskMetrics.map(_.diskBytesSpilled))} }} - {errorMessageCell(task.error)} + {errorMessageCell(task.errorMessage.getOrElse(""))} } + private def accumulatorsInfo(task: TaskData): Seq[Node] = { + task.accumulatorUpdates.map { acc => + Unparsed(StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update}")) + } + } + + private def metricInfo(task: TaskData)(fn: TaskMetrics => Seq[Node]): Seq[Node] = { + task.taskMetrics.map(fn).getOrElse(Nil) + } + private def errorMessageCell(error: String): Seq[Node] = { val isMultiline = error.indexOf('\n') >= 0 // Display the first line by default @@ -1333,6 +949,36 @@ private[ui] class TaskPagedTable( private object ApiHelper { + + private val COLUMN_TO_INDEX = Map( + "ID" -> null.asInstanceOf[String], + "Index" -> TaskIndexNames.TASK_INDEX, + "Attempt" -> TaskIndexNames.ATTEMPT, + "Status" -> TaskIndexNames.STATUS, + "Locality Level" -> TaskIndexNames.LOCALITY, + "Executor ID / Host" -> TaskIndexNames.EXECUTOR, + "Launch Time" -> TaskIndexNames.LAUNCH_TIME, + "Duration" -> TaskIndexNames.DURATION, + "Scheduler Delay" -> TaskIndexNames.SCHEDULER_DELAY, + "Task Deserialization Time" -> TaskIndexNames.DESER_TIME, + "GC Time" -> TaskIndexNames.GC_TIME, + "Result Serialization Time" -> TaskIndexNames.SER_TIME, + "Getting Result Time" -> TaskIndexNames.GETTING_RESULT_TIME, + "Peak Execution Memory" -> TaskIndexNames.PEAK_MEM, + "Accumulators" -> TaskIndexNames.ACCUMULATORS, + "Input Size / Records" -> TaskIndexNames.INPUT_SIZE, + "Output Size / Records" -> TaskIndexNames.OUTPUT_SIZE, + "Shuffle Read Blocked Time" -> TaskIndexNames.SHUFFLE_READ_TIME, + "Shuffle Read Size / Records" -> TaskIndexNames.SHUFFLE_TOTAL_READS, + "Shuffle Remote Reads" -> TaskIndexNames.SHUFFLE_REMOTE_READS, + "Write Time" -> TaskIndexNames.SHUFFLE_WRITE_TIME, + "Shuffle Write Size / Records" -> TaskIndexNames.SHUFFLE_WRITE_SIZE, + "Shuffle Spill (Memory)" -> TaskIndexNames.MEM_SPILL, + "Shuffle Spill (Disk)" -> TaskIndexNames.DISK_SPILL, + "Errors" -> TaskIndexNames.ERROR) + + def hasAccumulators(stageData: StageData): Boolean = stageData.accumulatorUpdates.size > 0 + def hasInput(stageData: StageData): Boolean = stageData.inputBytes > 0 def hasOutput(stageData: StageData): Boolean = stageData.outputBytes > 0 @@ -1349,4 +995,11 @@ private object ApiHelper { metrics.localBytesRead + metrics.remoteBytesRead } + def indexName(sortColumn: String): Option[String] = { + COLUMN_TO_INDEX.get(sortColumn) match { + case Some(v) => Option(v) + case _ => throw new IllegalArgumentException(s"Invalid sort column: $sortColumn") + } + } + } diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json index f8e27703c0def..5c42ac1d87f4c 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json @@ -7,6 +7,9 @@ "resultSize" : [ 2010.0, 2065.0, 2065.0 ], "jvmGcTime" : [ 0.0, 0.0, 7.0 ], "resultSerializationTime" : [ 0.0, 0.0, 2.0 ], + "gettingResultTime" : [ 0.0, 0.0, 0.0 ], + "schedulerDelay" : [ 2.0, 6.0, 53.0 ], + "peakExecutionMemory" : [ 0.0, 0.0, 0.0 ], "memoryBytesSpilled" : [ 0.0, 0.0, 0.0 ], "diskBytesSpilled" : [ 0.0, 0.0, 0.0 ], "inputMetrics" : { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json index a28bda16a956e..e6b705989cc97 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json @@ -7,6 +7,9 @@ "resultSize" : [ 1034.0, 1034.0, 1034.0, 1034.0, 1034.0 ], "jvmGcTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "resultSerializationTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "gettingResultTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "schedulerDelay" : [ 4.0, 4.0, 6.0, 7.0, 9.0 ], + "peakExecutionMemory" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "memoryBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "diskBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "inputMetrics" : { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json index ede3eaed1d1d2..788f28cf7b365 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json @@ -7,6 +7,9 @@ "resultSize" : [ 2010.0, 2065.0, 2065.0, 2065.0, 2065.0 ], "jvmGcTime" : [ 0.0, 0.0, 0.0, 5.0, 7.0 ], "resultSerializationTime" : [ 0.0, 0.0, 0.0, 0.0, 1.0 ], + "gettingResultTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "schedulerDelay" : [ 2.0, 4.0, 6.0, 13.0, 40.0 ], + "peakExecutionMemory" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "memoryBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "diskBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "inputMetrics" : { diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index b8c84e24c2c3f..ca66b6b9db890 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -213,45 +213,42 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { s1Tasks.foreach { task => check[TaskDataWrapper](task.taskId) { wrapper => - assert(wrapper.info.taskId === task.taskId) + assert(wrapper.taskId === task.taskId) assert(wrapper.stageId === stages.head.stageId) - assert(wrapper.stageAttemptId === stages.head.attemptNumber) - assert(Arrays.equals(wrapper.stage, Array(stages.head.stageId, stages.head.attemptNumber))) - - val runtime = Array[AnyRef](stages.head.stageId: JInteger, - stages.head.attemptNumber: JInteger, - -1L: JLong) - assert(Arrays.equals(wrapper.runtime, runtime)) - - assert(wrapper.info.index === task.index) - assert(wrapper.info.attempt === task.attemptNumber) - assert(wrapper.info.launchTime === new Date(task.launchTime)) - assert(wrapper.info.executorId === task.executorId) - assert(wrapper.info.host === task.host) - assert(wrapper.info.status === task.status) - assert(wrapper.info.taskLocality === task.taskLocality.toString()) - assert(wrapper.info.speculative === task.speculative) + assert(wrapper.stageAttemptId === stages.head.attemptId) + assert(wrapper.index === task.index) + assert(wrapper.attempt === task.attemptNumber) + assert(wrapper.launchTime === task.launchTime) + assert(wrapper.executorId === task.executorId) + assert(wrapper.host === task.host) + assert(wrapper.status === task.status) + assert(wrapper.taskLocality === task.taskLocality.toString()) + assert(wrapper.speculative === task.speculative) } } - // Send executor metrics update. Only update one metric to avoid a lot of boilerplate code. - s1Tasks.foreach { task => - val accum = new AccumulableInfo(1L, Some(InternalAccumulator.MEMORY_BYTES_SPILLED), - Some(1L), None, true, false, None) - listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate( - task.executorId, - Seq((task.taskId, stages.head.stageId, stages.head.attemptNumber, Seq(accum))))) - } + // Send two executor metrics update. Only update one metric to avoid a lot of boilerplate code. + // The tasks are distributed among the two executors, so the executor-level metrics should + // hold half of the cummulative value of the metric being updated. + Seq(1L, 2L).foreach { value => + s1Tasks.foreach { task => + val accum = new AccumulableInfo(1L, Some(InternalAccumulator.MEMORY_BYTES_SPILLED), + Some(value), None, true, false, None) + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate( + task.executorId, + Seq((task.taskId, stages.head.stageId, stages.head.attemptNumber, Seq(accum))))) + } - check[StageDataWrapper](key(stages.head)) { stage => - assert(stage.info.memoryBytesSpilled === s1Tasks.size) - } + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.memoryBytesSpilled === s1Tasks.size * value) + } - val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") - .first(key(stages.head)).last(key(stages.head)).asScala.toSeq - assert(execs.size > 0) - execs.foreach { exec => - assert(exec.info.memoryBytesSpilled === s1Tasks.size / 2) + val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") + .first(key(stages.head)).last(key(stages.head)).asScala.toSeq + assert(execs.size > 0) + execs.foreach { exec => + assert(exec.info.memoryBytesSpilled === s1Tasks.size * value / 2) + } } // Fail one of the tasks, re-start it. @@ -278,13 +275,13 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } check[TaskDataWrapper](s1Tasks.head.taskId) { task => - assert(task.info.status === s1Tasks.head.status) - assert(task.info.errorMessage == Some(TaskResultLost.toErrorString)) + assert(task.status === s1Tasks.head.status) + assert(task.errorMessage == Some(TaskResultLost.toErrorString)) } check[TaskDataWrapper](reattempt.taskId) { task => - assert(task.info.index === s1Tasks.head.index) - assert(task.info.attempt === reattempt.attemptNumber) + assert(task.index === s1Tasks.head.index) + assert(task.attempt === reattempt.attemptNumber) } // Kill one task, restart it. @@ -306,8 +303,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } check[TaskDataWrapper](killed.taskId) { task => - assert(task.info.index === killed.index) - assert(task.info.errorMessage === Some("killed")) + assert(task.index === killed.index) + assert(task.errorMessage === Some("killed")) } // Start a new attempt and finish it with TaskCommitDenied, make sure it's handled like a kill. @@ -334,8 +331,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } check[TaskDataWrapper](denied.taskId) { task => - assert(task.info.index === killed.index) - assert(task.info.errorMessage === Some(denyReason.toErrorString)) + assert(task.index === killed.index) + assert(task.errorMessage === Some(denyReason.toErrorString)) } // Start a new attempt. @@ -373,10 +370,10 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { pending.foreach { task => check[TaskDataWrapper](task.taskId) { wrapper => - assert(wrapper.info.errorMessage === None) - assert(wrapper.info.taskMetrics.get.executorCpuTime === 2L) - assert(wrapper.info.taskMetrics.get.executorRunTime === 4L) - assert(wrapper.info.duration === Some(task.duration)) + assert(wrapper.errorMessage === None) + assert(wrapper.executorCpuTime === 2L) + assert(wrapper.executorRunTime === 4L) + assert(wrapper.duration === task.duration) } } @@ -894,6 +891,23 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(store.count(classOf[StageDataWrapper]) === 3) assert(store.count(classOf[RDDOperationGraphWrapper]) === 3) + val dropped = stages.drop(1).head + + // Cache some quantiles by calling AppStatusStore.taskSummary(). For quantiles to be + // calculcated, we need at least one finished task. + time += 1 + val task = createTasks(1, Array("1")).head + listener.onTaskStart(SparkListenerTaskStart(dropped.stageId, dropped.attemptId, task)) + + time += 1 + task.markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(dropped.stageId, dropped.attemptId, + "taskType", Success, task, null)) + + new AppStatusStore(store) + .taskSummary(dropped.stageId, dropped.attemptId, Array(0.25d, 0.50d, 0.75d)) + assert(store.count(classOf[CachedQuantile], "stage", key(dropped)) === 3) + stages.drop(1).foreach { s => time += 1 s.completionTime = Some(time) @@ -905,6 +919,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { intercept[NoSuchElementException] { store.read(classOf[StageDataWrapper], Array(2, 0)) } + assert(store.count(classOf[CachedQuantile], "stage", key(dropped)) === 0) val attempt2 = new StageInfo(3, 1, "stage3", 4, Nil, Nil, "details3") time += 1 diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala new file mode 100644 index 0000000000000..92f90f3d96ddf --- /dev/null +++ b/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import org.apache.spark.SparkFunSuite +import org.apache.spark.status.api.v1.TaskMetricDistributions +import org.apache.spark.util.Distribution +import org.apache.spark.util.kvstore._ + +class AppStatusStoreSuite extends SparkFunSuite { + + private val uiQuantiles = Array(0.0, 0.25, 0.5, 0.75, 1.0) + private val stageId = 1 + private val attemptId = 1 + + test("quantile calculation: 1 task") { + compareQuantiles(1, uiQuantiles) + } + + test("quantile calculation: few tasks") { + compareQuantiles(4, uiQuantiles) + } + + test("quantile calculation: more tasks") { + compareQuantiles(100, uiQuantiles) + } + + test("quantile calculation: lots of tasks") { + compareQuantiles(4096, uiQuantiles) + } + + test("quantile calculation: custom quantiles") { + compareQuantiles(4096, Array(0.01, 0.33, 0.5, 0.42, 0.69, 0.99)) + } + + test("quantile cache") { + val store = new InMemoryStore() + (0 until 4096).foreach { i => store.write(newTaskData(i)) } + + val appStore = new AppStatusStore(store) + + appStore.taskSummary(stageId, attemptId, Array(0.13d)) + intercept[NoSuchElementException] { + store.read(classOf[CachedQuantile], Array(stageId, attemptId, "13")) + } + + appStore.taskSummary(stageId, attemptId, Array(0.25d)) + val d1 = store.read(classOf[CachedQuantile], Array(stageId, attemptId, "25")) + + // Add a new task to force the cached quantile to be evicted, and make sure it's updated. + store.write(newTaskData(4096)) + appStore.taskSummary(stageId, attemptId, Array(0.25d, 0.50d, 0.73d)) + + val d2 = store.read(classOf[CachedQuantile], Array(stageId, attemptId, "25")) + assert(d1.taskCount != d2.taskCount) + + store.read(classOf[CachedQuantile], Array(stageId, attemptId, "50")) + intercept[NoSuchElementException] { + store.read(classOf[CachedQuantile], Array(stageId, attemptId, "73")) + } + + assert(store.count(classOf[CachedQuantile]) === 2) + } + + private def compareQuantiles(count: Int, quantiles: Array[Double]): Unit = { + val store = new InMemoryStore() + val values = (0 until count).map { i => + val task = newTaskData(i) + store.write(task) + i.toDouble + }.toArray + + val summary = new AppStatusStore(store).taskSummary(stageId, attemptId, quantiles).get + val dist = new Distribution(values, 0, values.length).getQuantiles(quantiles.sorted) + + dist.zip(summary.executorRunTime).foreach { case (expected, actual) => + assert(expected === actual) + } + } + + private def newTaskData(i: Int): TaskDataWrapper = { + new TaskDataWrapper( + i, i, i, i, i, i, i.toString, i.toString, i.toString, i.toString, false, Nil, None, + i, i, i, i, i, i, i, i, i, i, + i, i, i, i, i, i, i, i, i, i, + i, i, i, i, stageId, attemptId) + } + +} diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 661d0d48d2f37..0aeddf730cd35 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.config._ import org.apache.spark.ui.jobs.{StagePage, StagesTab} class StagePageSuite extends SparkFunSuite with LocalSparkContext { @@ -35,15 +36,13 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { private val peakExecutionMemory = 10 test("peak execution memory should displayed") { - val conf = new SparkConf(false) - val html = renderStagePage(conf).toString().toLowerCase(Locale.ROOT) + val html = renderStagePage().toString().toLowerCase(Locale.ROOT) val targetString = "peak execution memory" assert(html.contains(targetString)) } test("SPARK-10543: peak execution memory should be per-task rather than cumulative") { - val conf = new SparkConf(false) - val html = renderStagePage(conf).toString().toLowerCase(Locale.ROOT) + val html = renderStagePage().toString().toLowerCase(Locale.ROOT) // verify min/25/50/75/max show task value not cumulative values assert(html.contains(s"$peakExecutionMemory.0 b" * 5)) } @@ -52,7 +51,8 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { * Render a stage page started with the given conf and return the HTML. * This also runs a dummy stage to populate the page with useful content. */ - private def renderStagePage(conf: SparkConf): Seq[Node] = { + private def renderStagePage(): Seq[Node] = { + val conf = new SparkConf(false).set(LIVE_ENTITY_UPDATE_PERIOD, 0L) val statusStore = AppStatusStore.createLiveStore(conf) val listener = statusStore.listener.get diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 7bdd3fac773a3..e2fa5754afaee 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -93,7 +93,7 @@ This file is divided into 3 sections: - + From 0552c36e02434c60dad82024334d291f6008b822 Mon Sep 17 00:00:00 2001 From: wuyi5 Date: Thu, 11 Jan 2018 22:17:15 +0900 Subject: [PATCH 0065/2461] [SPARK-22967][TESTS] Fix VersionSuite's unit tests by change Windows path into URI path ## What changes were proposed in this pull request? Two unit test will fail due to Windows format path: 1.test(s"$version: read avro file containing decimal") ``` org.apache.hadoop.hive.ql.metadata.HiveException: MetaException(message:java.lang.IllegalArgumentException: Can not create a Path from an empty string); ``` 2.test(s"$version: SPARK-17920: Insert into/overwrite avro table") ``` Unable to infer the schema. The schema specification is required to create the table `default`.`tab2`.; org.apache.spark.sql.AnalysisException: Unable to infer the schema. The schema specification is required to create the table `default`.`tab2`.; ``` This pr fix these two unit test by change Windows path into URI path. ## How was this patch tested? Existed. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: wuyi5 Closes #20199 from Ngone51/SPARK-22967. --- .../org/apache/spark/sql/hive/client/VersionsSuite.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index ff90e9dda5f7c..e64389e56b5a1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -811,7 +811,7 @@ class VersionsSuite extends SparkFunSuite with Logging { test(s"$version: read avro file containing decimal") { val url = Thread.currentThread().getContextClassLoader.getResource("avroDecimal") - val location = new File(url.getFile) + val location = new File(url.getFile).toURI.toString val tableName = "tab1" val avroSchema = @@ -851,6 +851,8 @@ class VersionsSuite extends SparkFunSuite with Logging { } test(s"$version: SPARK-17920: Insert into/overwrite avro table") { + // skipped because it's failed in the condition on Windows + assume(!(Utils.isWindows && version == "0.12")) withTempDir { dir => val avroSchema = """ @@ -875,10 +877,10 @@ class VersionsSuite extends SparkFunSuite with Logging { val writer = new PrintWriter(schemaFile) writer.write(avroSchema) writer.close() - val schemaPath = schemaFile.getCanonicalPath + val schemaPath = schemaFile.toURI.toString val url = Thread.currentThread().getContextClassLoader.getResource("avroDecimal") - val srcLocation = new File(url.getFile).getCanonicalPath + val srcLocation = new File(url.getFile).toURI.toString val destTableName = "tab1" val srcTableName = "tab2" From 76892bcf2c08efd7e9c5b16d377e623d82fe695e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 11 Jan 2018 21:32:36 +0800 Subject: [PATCH 0066/2461] [SPARK-23000][TEST-HADOOP2.6] Fix Flaky test suite DataSourceWithHiveMetastoreCatalogSuite ## What changes were proposed in this pull request? The Spark 2.3 branch still failed due to the flaky test suite `DataSourceWithHiveMetastoreCatalogSuite `. https://amplab.cs.berkeley.edu/jenkins/job/spark-branch-2.3-test-sbt-hadoop-2.6/ Although https://github.com/apache/spark/pull/20207 is unable to reproduce it in Spark 2.3, it sounds like the current DB of Spark's Catalog is changed based on the following stacktrace. Thus, we just need to reset it. ``` [info] DataSourceWithHiveMetastoreCatalogSuite: 02:40:39.486 ERROR org.apache.hadoop.hive.ql.parse.CalcitePlanner: org.apache.hadoop.hive.ql.parse.SemanticException: Line 1:14 Table not found 't' at org.apache.hadoop.hive.ql.parse.SemanticAnalyzer.getMetaData(SemanticAnalyzer.java:1594) at org.apache.hadoop.hive.ql.parse.SemanticAnalyzer.getMetaData(SemanticAnalyzer.java:1545) at org.apache.hadoop.hive.ql.parse.SemanticAnalyzer.genResolvedParseTree(SemanticAnalyzer.java:10077) at org.apache.hadoop.hive.ql.parse.SemanticAnalyzer.analyzeInternal(SemanticAnalyzer.java:10128) at org.apache.hadoop.hive.ql.parse.CalcitePlanner.analyzeInternal(CalcitePlanner.java:209) at org.apache.hadoop.hive.ql.parse.BaseSemanticAnalyzer.analyze(BaseSemanticAnalyzer.java:227) at org.apache.hadoop.hive.ql.Driver.compile(Driver.java:424) at org.apache.hadoop.hive.ql.Driver.compile(Driver.java:308) at org.apache.hadoop.hive.ql.Driver.compileInternal(Driver.java:1122) at org.apache.hadoop.hive.ql.Driver.runInternal(Driver.java:1170) at org.apache.hadoop.hive.ql.Driver.run(Driver.java:1059) at org.apache.hadoop.hive.ql.Driver.run(Driver.java:1049) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$runHive$1.apply(HiveClientImpl.scala:694) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$runHive$1.apply(HiveClientImpl.scala:683) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$withHiveState$1.apply(HiveClientImpl.scala:272) at org.apache.spark.sql.hive.client.HiveClientImpl.liftedTree1$1(HiveClientImpl.scala:210) at org.apache.spark.sql.hive.client.HiveClientImpl.retryLocked(HiveClientImpl.scala:209) at org.apache.spark.sql.hive.client.HiveClientImpl.withHiveState(HiveClientImpl.scala:255) at org.apache.spark.sql.hive.client.HiveClientImpl.runHive(HiveClientImpl.scala:683) at org.apache.spark.sql.hive.client.HiveClientImpl.runSqlHive(HiveClientImpl.scala:673) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$9$$anonfun$apply$1$$anonfun$apply$mcV$sp$3.apply$mcV$sp(HiveMetastoreCatalogSuite.scala:185) at org.apache.spark.sql.test.SQLTestUtilsBase$class.withTable(SQLTestUtils.scala:273) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.withTable(HiveMetastoreCatalogSuite.scala:139) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$9$$anonfun$apply$1.apply$mcV$sp(HiveMetastoreCatalogSuite.scala:163) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$9$$anonfun$apply$1.apply(HiveMetastoreCatalogSuite.scala:163) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$9$$anonfun$apply$1.apply(HiveMetastoreCatalogSuite.scala:163) at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104) at org.scalatest.Transformer.apply(Transformer.scala:22) at org.scalatest.Transformer.apply(Transformer.scala:20) at org.scalatest.FunSuiteLike$$anon$1.apply(FunSuiteLike.scala:186) at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:68) at org.scalatest.FunSuiteLike$class.invokeWithFixture$1(FunSuiteLike.scala:183) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:196) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:196) at org.scalatest.SuperEngine.runTestImpl(Engine.scala:289) at org.scalatest.FunSuiteLike$class.runTest(FunSuiteLike.scala:196) at org.scalatest.FunSuite.runTest(FunSuite.scala:1560) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:229) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:229) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:396) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:384) at scala.collection.immutable.List.foreach(List.scala:381) at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:384) at org.scalatest.SuperEngine.org$scalatest$SuperEngine$$runTestsInBranch(Engine.scala:379) at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:461) at org.scalatest.FunSuiteLike$class.runTests(FunSuiteLike.scala:229) at org.scalatest.FunSuite.runTests(FunSuite.scala:1560) at org.scalatest.Suite$class.run(Suite.scala:1147) at org.scalatest.FunSuite.org$scalatest$FunSuiteLike$$super$run(FunSuite.scala:1560) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:233) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:233) at org.scalatest.SuperEngine.runImpl(Engine.scala:521) at org.scalatest.FunSuiteLike$class.run(FunSuiteLike.scala:233) at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterAll$$super$run(SparkFunSuite.scala:31) at org.scalatest.BeforeAndAfterAll$class.liftedTree1$1(BeforeAndAfterAll.scala:213) at org.scalatest.BeforeAndAfterAll$class.run(BeforeAndAfterAll.scala:210) at org.apache.spark.SparkFunSuite.run(SparkFunSuite.scala:31) at org.scalatest.tools.Framework.org$scalatest$tools$Framework$$runSuite(Framework.scala:314) at org.scalatest.tools.Framework$ScalaTestTask.execute(Framework.scala:480) at sbt.ForkMain$Run$2.call(ForkMain.java:296) at sbt.ForkMain$Run$2.call(ForkMain.java:286) at java.util.concurrent.FutureTask.run(FutureTask.java:266) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ``` ## How was this patch tested? N/A Author: gatorsmile Closes #20218 from gatorsmile/testFixAgain. --- .../org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index cf4ce83124d88..ba9b944e4a055 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -148,6 +148,7 @@ class DataSourceWithHiveMetastoreCatalogSuite override def beforeAll(): Unit = { super.beforeAll() + sparkSession.sessionState.catalog.reset() sparkSession.metadataHive.reset() } From b46e58b74c82dac37b7b92284ea3714919c5a886 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 11 Jan 2018 22:33:42 +0900 Subject: [PATCH 0067/2461] [SPARK-19732][FOLLOW-UP] Document behavior changes made in na.fill and fillna ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/18164 introduces the behavior changes. We need to document it. ## How was this patch tested? N/A Author: gatorsmile Closes #20234 from gatorsmile/docBehaviorChange. --- docs/sql-programming-guide.md | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 72f79d6909ecc..258c769ff593b 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1788,12 +1788,10 @@ options. Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other combinations of scales and precisions because currently we only infer decimal type like `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type. - In PySpark, now we need Pandas 0.19.2 or upper if you want to use Pandas related functionalities, such as `toPandas`, `createDataFrame` from Pandas DataFrame, etc. - In PySpark, the behavior of timestamp values for Pandas related functionalities was changed to respect session timezone. If you want to use the old behavior, you need to set a configuration `spark.sql.execution.pandas.respectSessionTimeZone` to `False`. See [SPARK-22395](https://issues.apache.org/jira/browse/SPARK-22395) for details. - - - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489). - - - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. - - - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. + - In PySpark, `na.fill()` or `fillna` also accepts boolean and replaces nulls with booleans. In prior Spark versions, PySpark just ignores it and returns the original Dataset/DataFrame. + - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489). + - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. + - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. ## Upgrading From Spark SQL 2.1 to 2.2 From 6d230dccf65300651f989392159d84bfaf08f18f Mon Sep 17 00:00:00 2001 From: FanDonglai Date: Thu, 11 Jan 2018 09:06:40 -0600 Subject: [PATCH 0068/2461] Update PageRank.scala ## What changes were proposed in this pull request? Hi, acording to code below, "if (id == src) (0.0, Double.NegativeInfinity) else (0.0, 0.0)" I think the comment can be wrong ## How was this patch tested? Please review http://spark.apache.org/contributing.html before opening a pull request. Author: FanDonglai Closes #20220 from ddna1021/master. --- .../src/main/scala/org/apache/spark/graphx/lib/PageRank.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index fd7b7f7c1c487..ebd65e8320e5c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -303,7 +303,7 @@ object PageRank extends Logging { val src: VertexId = srcId.getOrElse(-1L) // Initialize the pagerankGraph with each edge attribute - // having weight 1/outDegree and each vertex with attribute 1.0. + // having weight 1/outDegree and each vertex with attribute 0. val pagerankGraph: Graph[(Double, Double), Double] = graph // Associate the degree with each vertex .outerJoinVertices(graph.outDegrees) { From 0b2eefb674151a0af64806728b38d9410da552ec Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 11 Jan 2018 10:37:35 -0800 Subject: [PATCH 0069/2461] [SPARK-22994][K8S] Use a single image for all Spark containers. This change allows a user to submit a Spark application on kubernetes having to provide a single image, instead of one image for each type of container. The image's entry point now takes an extra argument that identifies the process that is being started. The configuration still allows the user to provide different images for each container type if they so desire. On top of that, the entry point was simplified a bit to share more code; mainly, the same env variable is used to propagate the user-defined classpath to the different containers. Aside from being modified to match the new behavior, the 'build-push-docker-images.sh' script was renamed to 'docker-image-tool.sh' to more closely match its purpose; the old name was a little awkward and now also not entirely correct, since there is a single image. It was also moved to 'bin' since it's not necessarily an admin tool. Docs have been updated to match the new behavior. Tested locally with minikube. Author: Marcelo Vanzin Closes #20192 from vanzin/SPARK-22994. --- .../docker-image-tool.sh | 68 ++++++------- docs/running-on-kubernetes.md | 58 +++++------ .../org/apache/spark/deploy/k8s/Config.scala | 17 ++-- .../apache/spark/deploy/k8s/Constants.scala | 3 +- .../deploy/k8s/InitContainerBootstrap.scala | 1 + .../steps/BasicDriverConfigurationStep.scala | 3 +- .../cluster/k8s/ExecutorPodFactory.scala | 3 +- .../DriverConfigOrchestratorSuite.scala | 12 +-- .../BasicDriverConfigurationStepSuite.scala | 4 +- ...InitContainerConfigOrchestratorSuite.scala | 4 +- .../cluster/k8s/ExecutorPodFactorySuite.scala | 4 +- .../src/main/dockerfiles/driver/Dockerfile | 35 ------- .../src/main/dockerfiles/executor/Dockerfile | 35 ------- .../dockerfiles/init-container/Dockerfile | 24 ----- .../main/dockerfiles/spark-base/entrypoint.sh | 37 ------- .../{spark-base => spark}/Dockerfile | 10 +- .../src/main/dockerfiles/spark/entrypoint.sh | 97 +++++++++++++++++++ 17 files changed, 189 insertions(+), 226 deletions(-) rename sbin/build-push-docker-images.sh => bin/docker-image-tool.sh (63%) delete mode 100644 resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile delete mode 100644 resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile delete mode 100644 resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile delete mode 100755 resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh rename resource-managers/kubernetes/docker/src/main/dockerfiles/{spark-base => spark}/Dockerfile (87%) create mode 100755 resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh diff --git a/sbin/build-push-docker-images.sh b/bin/docker-image-tool.sh similarity index 63% rename from sbin/build-push-docker-images.sh rename to bin/docker-image-tool.sh index b9532597419a5..071406336d1b1 100755 --- a/sbin/build-push-docker-images.sh +++ b/bin/docker-image-tool.sh @@ -24,29 +24,11 @@ function error { exit 1 } -# Detect whether this is a git clone or a Spark distribution and adjust paths -# accordingly. if [ -z "${SPARK_HOME}" ]; then SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" fi . "${SPARK_HOME}/bin/load-spark-env.sh" -if [ -f "$SPARK_HOME/RELEASE" ]; then - IMG_PATH="kubernetes/dockerfiles" - SPARK_JARS="jars" -else - IMG_PATH="resource-managers/kubernetes/docker/src/main/dockerfiles" - SPARK_JARS="assembly/target/scala-$SPARK_SCALA_VERSION/jars" -fi - -if [ ! -d "$IMG_PATH" ]; then - error "Cannot find docker images. This script must be run from a runnable distribution of Apache Spark." -fi - -declare -A path=( [spark-driver]="$IMG_PATH/driver/Dockerfile" \ - [spark-executor]="$IMG_PATH/executor/Dockerfile" \ - [spark-init]="$IMG_PATH/init-container/Dockerfile" ) - function image_ref { local image="$1" local add_repo="${2:-1}" @@ -60,35 +42,49 @@ function image_ref { } function build { - docker build \ - --build-arg "spark_jars=$SPARK_JARS" \ - --build-arg "img_path=$IMG_PATH" \ - -t spark-base \ - -f "$IMG_PATH/spark-base/Dockerfile" . - for image in "${!path[@]}"; do - docker build -t "$(image_ref $image)" -f ${path[$image]} . - done + local BUILD_ARGS + local IMG_PATH + + if [ ! -f "$SPARK_HOME/RELEASE" ]; then + # Set image build arguments accordingly if this is a source repo and not a distribution archive. + IMG_PATH=resource-managers/kubernetes/docker/src/main/dockerfiles + BUILD_ARGS=( + --build-arg + img_path=$IMG_PATH + --build-arg + spark_jars=assembly/target/scala-$SPARK_SCALA_VERSION/jars + ) + else + # Not passed as an argument to docker, but used to validate the Spark directory. + IMG_PATH="kubernetes/dockerfiles" + fi + + if [ ! -d "$IMG_PATH" ]; then + error "Cannot find docker image. This script must be run from a runnable distribution of Apache Spark." + fi + + docker build "${BUILD_ARGS[@]}" \ + -t $(image_ref spark) \ + -f "$IMG_PATH/spark/Dockerfile" . } function push { - for image in "${!path[@]}"; do - docker push "$(image_ref $image)" - done + docker push "$(image_ref spark)" } function usage { cat < -t my-tag build - ./sbin/build-push-docker-images.sh -r -t my-tag push - -Docker files are under the `kubernetes/dockerfiles/` directory and can be customized further before -building using the supplied script, or manually. + ./bin/docker-image-tool.sh -r -t my-tag build + ./bin/docker-image-tool.sh -r -t my-tag push ## Cluster Mode @@ -79,8 +76,7 @@ $ bin/spark-submit \ --name spark-pi \ --class org.apache.spark.examples.SparkPi \ --conf spark.executor.instances=5 \ - --conf spark.kubernetes.driver.container.image= \ - --conf spark.kubernetes.executor.container.image= \ + --conf spark.kubernetes.container.image= \ local:///path/to/examples.jar ``` @@ -126,13 +122,7 @@ Those dependencies can be added to the classpath by referencing them with `local ### Using Remote Dependencies When there are application dependencies hosted in remote locations like HDFS or HTTP servers, the driver and executor pods need a Kubernetes [init-container](https://kubernetes.io/docs/concepts/workloads/pods/init-containers/) for downloading -the dependencies so the driver and executor containers can use them locally. This requires users to specify the container -image for the init-container using the configuration property `spark.kubernetes.initContainer.image`. For example, users -simply add the following option to the `spark-submit` command to specify the init-container image: - -``` ---conf spark.kubernetes.initContainer.image= -``` +the dependencies so the driver and executor containers can use them locally. The init-container handles remote dependencies specified in `spark.jars` (or the `--jars` option of `spark-submit`) and `spark.files` (or the `--files` option of `spark-submit`). It also handles remotely hosted main application resources, e.g., @@ -147,9 +137,7 @@ $ bin/spark-submit \ --jars https://path/to/dependency1.jar,https://path/to/dependency2.jar --files hdfs://host:port/path/to/file1,hdfs://host:port/path/to/file2 --conf spark.executor.instances=5 \ - --conf spark.kubernetes.driver.container.image= \ - --conf spark.kubernetes.executor.container.image= \ - --conf spark.kubernetes.initContainer.image= + --conf spark.kubernetes.container.image= \ https://path/to/examples.jar ``` @@ -322,21 +310,27 @@ specific to Spark on Kubernetes. - spark.kubernetes.driver.container.image + spark.kubernetes.container.image (none) - Container image to use for the driver. - This is usually of the form example.com/repo/spark-driver:v1.0.0. - This configuration is required and must be provided by the user. + Container image to use for the Spark application. + This is usually of the form example.com/repo/spark:v1.0.0. + This configuration is required and must be provided by the user, unless explicit + images are provided for each different container type. + + + + spark.kubernetes.driver.container.image + (value of spark.kubernetes.container.image) + + Custom container image to use for the driver. spark.kubernetes.executor.container.image - (none) + (value of spark.kubernetes.container.image) - Container image to use for the executors. - This is usually of the form example.com/repo/spark-executor:v1.0.0. - This configuration is required and must be provided by the user. + Custom container image to use for executors. @@ -643,9 +637,9 @@ specific to Spark on Kubernetes. spark.kubernetes.initContainer.image - (none) + (value of spark.kubernetes.container.image) - Container image for the init-container of the driver and executors for downloading dependencies. This is usually of the form example.com/repo/spark-init:v1.0.0. This configuration is optional and must be provided by the user if any non-container local dependency is used and must be downloaded remotely. + Custom container image for the init container of both driver and executors. diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index e5d79d9a9d9da..471196ac0e3f6 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -29,17 +29,23 @@ private[spark] object Config extends Logging { .stringConf .createWithDefault("default") + val CONTAINER_IMAGE = + ConfigBuilder("spark.kubernetes.container.image") + .doc("Container image to use for Spark containers. Individual container types " + + "(e.g. driver or executor) can also be configured to use different images if desired, " + + "by setting the container type-specific image name.") + .stringConf + .createOptional + val DRIVER_CONTAINER_IMAGE = ConfigBuilder("spark.kubernetes.driver.container.image") .doc("Container image to use for the driver.") - .stringConf - .createOptional + .fallbackConf(CONTAINER_IMAGE) val EXECUTOR_CONTAINER_IMAGE = ConfigBuilder("spark.kubernetes.executor.container.image") .doc("Container image to use for the executors.") - .stringConf - .createOptional + .fallbackConf(CONTAINER_IMAGE) val CONTAINER_IMAGE_PULL_POLICY = ConfigBuilder("spark.kubernetes.container.image.pullPolicy") @@ -148,8 +154,7 @@ private[spark] object Config extends Logging { val INIT_CONTAINER_IMAGE = ConfigBuilder("spark.kubernetes.initContainer.image") .doc("Image for the driver and executor's init-container for downloading dependencies.") - .stringConf - .createOptional + .fallbackConf(CONTAINER_IMAGE) val INIT_CONTAINER_MOUNT_TIMEOUT = ConfigBuilder("spark.kubernetes.mountDependencies.timeout") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 111cb2a3b75e5..9411956996843 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -60,10 +60,9 @@ private[spark] object Constants { val ENV_APPLICATION_ID = "SPARK_APPLICATION_ID" val ENV_EXECUTOR_ID = "SPARK_EXECUTOR_ID" val ENV_EXECUTOR_POD_IP = "SPARK_EXECUTOR_POD_IP" - val ENV_EXECUTOR_EXTRA_CLASSPATH = "SPARK_EXECUTOR_EXTRA_CLASSPATH" val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH" val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_" - val ENV_SUBMIT_EXTRA_CLASSPATH = "SPARK_SUBMIT_EXTRA_CLASSPATH" + val ENV_CLASSPATH = "SPARK_CLASSPATH" val ENV_DRIVER_MAIN_CLASS = "SPARK_DRIVER_CLASS" val ENV_DRIVER_ARGS = "SPARK_DRIVER_ARGS" val ENV_DRIVER_JAVA_OPTS = "SPARK_DRIVER_JAVA_OPTS" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala index dfeccf9e2bd1c..f6a57dfe00171 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala @@ -77,6 +77,7 @@ private[spark] class InitContainerBootstrap( .withMountPath(INIT_CONTAINER_PROPERTIES_FILE_DIR) .endVolumeMount() .addToVolumeMounts(sharedVolumeMounts: _*) + .addToArgs("init") .addToArgs(INIT_CONTAINER_PROPERTIES_FILE_PATH) .build() diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index eca46b84c6066..164e2e5594778 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -66,7 +66,7 @@ private[spark] class BasicDriverConfigurationStep( override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { val driverExtraClasspathEnv = driverExtraClasspath.map { classPath => new EnvVarBuilder() - .withName(ENV_SUBMIT_EXTRA_CLASSPATH) + .withName(ENV_CLASSPATH) .withValue(classPath) .build() } @@ -133,6 +133,7 @@ private[spark] class BasicDriverConfigurationStep( .addToLimits("memory", driverMemoryLimitQuantity) .addToLimits(maybeCpuLimitQuantity.toMap.asJava) .endResources() + .addToArgs("driver") .build() val baseDriverPod = new PodBuilder(driverSpec.driverPod) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index bcacb3934d36a..141bd2827e7c5 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -128,7 +128,7 @@ private[spark] class ExecutorPodFactory( .build() val executorExtraClasspathEnv = executorExtraClasspath.map { cp => new EnvVarBuilder() - .withName(ENV_EXECUTOR_EXTRA_CLASSPATH) + .withName(ENV_CLASSPATH) .withValue(cp) .build() } @@ -181,6 +181,7 @@ private[spark] class ExecutorPodFactory( .endResources() .addAllToEnv(executorEnv.asJava) .withPorts(requiredPorts.asJava) + .addToArgs("executor") .build() val executorPod = new PodBuilder() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala index f193b1f4d3664..65274c6f50e01 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala @@ -34,8 +34,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { private val SECRET_MOUNT_PATH = "/etc/secrets/driver" test("Base submission steps with a main app resource.") { - val sparkConf = new SparkConf(false) - .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) + val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE) val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") val orchestrator = new DriverConfigOrchestrator( APP_ID, @@ -55,8 +54,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { } test("Base submission steps without a main app resource.") { - val sparkConf = new SparkConf(false) - .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) + val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE) val orchestrator = new DriverConfigOrchestrator( APP_ID, LAUNCH_TIME, @@ -75,8 +73,8 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { test("Submission steps with an init-container.") { val sparkConf = new SparkConf(false) - .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) - .set(INIT_CONTAINER_IMAGE, IC_IMAGE) + .set(CONTAINER_IMAGE, DRIVER_IMAGE) + .set(INIT_CONTAINER_IMAGE.key, IC_IMAGE) .set("spark.jars", "hdfs://localhost:9000/var/apps/jars/jar1.jar") val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") val orchestrator = new DriverConfigOrchestrator( @@ -98,7 +96,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { test("Submission steps with driver secrets to mount") { val sparkConf = new SparkConf(false) - .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) + .set(CONTAINER_IMAGE, DRIVER_IMAGE) .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index 8ee629ac8ddc1..b136f2c02ffba 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -47,7 +47,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { .set(KUBERNETES_DRIVER_LIMIT_CORES, "4") .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "256M") .set(org.apache.spark.internal.config.DRIVER_MEMORY_OVERHEAD, 200L) - .set(DRIVER_CONTAINER_IMAGE, "spark-driver:latest") + .set(CONTAINER_IMAGE, "spark-driver:latest") .set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$CUSTOM_ANNOTATION_KEY", CUSTOM_ANNOTATION_VALUE) .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY1", "customDriverEnv1") .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY2", "customDriverEnv2") @@ -79,7 +79,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { .asScala .map(env => (env.getName, env.getValue)) .toMap - assert(envs(ENV_SUBMIT_EXTRA_CLASSPATH) === "/opt/spark/spark-examples.jar") + assert(envs(ENV_CLASSPATH) === "/opt/spark/spark-examples.jar") assert(envs(ENV_DRIVER_MEMORY) === "256M") assert(envs(ENV_DRIVER_MAIN_CLASS) === MAIN_CLASS) assert(envs(ENV_DRIVER_ARGS) === "arg1 arg2 \"arg 3\"") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala index 20f2e5bc15df3..09b42e4484d86 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala @@ -40,7 +40,7 @@ class InitContainerConfigOrchestratorSuite extends SparkFunSuite { test("including basic configuration step") { val sparkConf = new SparkConf(true) - .set(INIT_CONTAINER_IMAGE, DOCKER_IMAGE) + .set(CONTAINER_IMAGE, DOCKER_IMAGE) .set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$CUSTOM_LABEL_KEY", CUSTOM_LABEL_VALUE) val orchestrator = new InitContainerConfigOrchestrator( @@ -59,7 +59,7 @@ class InitContainerConfigOrchestratorSuite extends SparkFunSuite { test("including step to mount user-specified secrets") { val sparkConf = new SparkConf(false) - .set(INIT_CONTAINER_IMAGE, DOCKER_IMAGE) + .set(CONTAINER_IMAGE, DOCKER_IMAGE) .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 7cfbe54c95390..a3c615be031d2 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -54,7 +54,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef baseConf = new SparkConf() .set(KUBERNETES_DRIVER_POD_NAME, driverPodName) .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) - .set(EXECUTOR_CONTAINER_IMAGE, executorImage) + .set(CONTAINER_IMAGE, executorImage) } test("basic executor pod has reasonable defaults") { @@ -107,7 +107,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef checkEnv(executor, Map("SPARK_JAVA_OPT_0" -> "foo=bar", - "SPARK_EXECUTOR_EXTRA_CLASSPATH" -> "bar=baz", + ENV_CLASSPATH -> "bar=baz", "qux" -> "quux")) checkOwnerReferences(executor, driverPodUid) } diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile deleted file mode 100644 index 45fbcd9cd0deb..0000000000000 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile +++ /dev/null @@ -1,35 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -FROM spark-base - -# Before building the docker image, first build and make a Spark distribution following -# the instructions in http://spark.apache.org/docs/latest/building-spark.html. -# If this docker file is being used in the context of building your images from a Spark -# distribution, the docker build command should be invoked from the top level directory -# of the Spark distribution. E.g.: -# docker build -t spark-driver:latest -f kubernetes/dockerfiles/driver/Dockerfile . - -COPY examples /opt/spark/examples - -CMD SPARK_CLASSPATH="${SPARK_HOME}/jars/*" && \ - env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt && \ - readarray -t SPARK_DRIVER_JAVA_OPTS < /tmp/java_opts.txt && \ - if ! [ -z ${SPARK_MOUNTED_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_MOUNTED_CLASSPATH:$SPARK_CLASSPATH"; fi && \ - if ! [ -z ${SPARK_SUBMIT_EXTRA_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_SUBMIT_EXTRA_CLASSPATH:$SPARK_CLASSPATH"; fi && \ - if ! [ -z ${SPARK_MOUNTED_FILES_DIR+x} ]; then cp -R "$SPARK_MOUNTED_FILES_DIR/." .; fi && \ - ${JAVA_HOME}/bin/java "${SPARK_DRIVER_JAVA_OPTS[@]}" -cp "$SPARK_CLASSPATH" -Xms$SPARK_DRIVER_MEMORY -Xmx$SPARK_DRIVER_MEMORY -Dspark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS $SPARK_DRIVER_CLASS $SPARK_DRIVER_ARGS diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile deleted file mode 100644 index 0f806cf7e148e..0000000000000 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile +++ /dev/null @@ -1,35 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -FROM spark-base - -# Before building the docker image, first build and make a Spark distribution following -# the instructions in http://spark.apache.org/docs/latest/building-spark.html. -# If this docker file is being used in the context of building your images from a Spark -# distribution, the docker build command should be invoked from the top level directory -# of the Spark distribution. E.g.: -# docker build -t spark-executor:latest -f kubernetes/dockerfiles/executor/Dockerfile . - -COPY examples /opt/spark/examples - -CMD SPARK_CLASSPATH="${SPARK_HOME}/jars/*" && \ - env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt && \ - readarray -t SPARK_EXECUTOR_JAVA_OPTS < /tmp/java_opts.txt && \ - if ! [ -z ${SPARK_MOUNTED_CLASSPATH}+x} ]; then SPARK_CLASSPATH="$SPARK_MOUNTED_CLASSPATH:$SPARK_CLASSPATH"; fi && \ - if ! [ -z ${SPARK_EXECUTOR_EXTRA_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_EXECUTOR_EXTRA_CLASSPATH:$SPARK_CLASSPATH"; fi && \ - if ! [ -z ${SPARK_MOUNTED_FILES_DIR+x} ]; then cp -R "$SPARK_MOUNTED_FILES_DIR/." .; fi && \ - ${JAVA_HOME}/bin/java "${SPARK_EXECUTOR_JAVA_OPTS[@]}" -Xms$SPARK_EXECUTOR_MEMORY -Xmx$SPARK_EXECUTOR_MEMORY -cp "$SPARK_CLASSPATH" org.apache.spark.executor.CoarseGrainedExecutorBackend --driver-url $SPARK_DRIVER_URL --executor-id $SPARK_EXECUTOR_ID --cores $SPARK_EXECUTOR_CORES --app-id $SPARK_APPLICATION_ID --hostname $SPARK_EXECUTOR_POD_IP diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile deleted file mode 100644 index 047056ab2633b..0000000000000 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ /dev/null @@ -1,24 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -FROM spark-base - -# If this docker file is being used in the context of building your images from a Spark distribution, the docker build -# command should be invoked from the top level directory of the Spark distribution. E.g.: -# docker build -t spark-init:latest -f kubernetes/dockerfiles/init-container/Dockerfile . - -ENTRYPOINT [ "/opt/entrypoint.sh", "/opt/spark/bin/spark-class", "org.apache.spark.deploy.k8s.SparkPodInitContainer" ] diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh deleted file mode 100755 index 82559889f4beb..0000000000000 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh +++ /dev/null @@ -1,37 +0,0 @@ -#!/bin/bash -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# echo commands to the terminal output -set -ex - -# Check whether there is a passwd entry for the container UID -myuid=$(id -u) -mygid=$(id -g) -uidentry=$(getent passwd $myuid) - -# If there is no passwd entry for the container UID, attempt to create one -if [ -z "$uidentry" ] ; then - if [ -w /etc/passwd ] ; then - echo "$myuid:x:$myuid:$mygid:anonymous uid:$SPARK_HOME:/bin/false" >> /etc/passwd - else - echo "Container ENTRYPOINT failed to add passwd entry for anonymous UID" - fi -fi - -# Execute the container CMD under tini for better hygiene -/sbin/tini -s -- "$@" diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile similarity index 87% rename from resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile rename to resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index da1d6b9e161cc..491b7cf692478 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -17,15 +17,15 @@ FROM openjdk:8-alpine -ARG spark_jars -ARG img_path +ARG spark_jars=jars +ARG img_path=kubernetes/dockerfiles # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. # If this docker file is being used in the context of building your images from a Spark # distribution, the docker build command should be invoked from the top level directory # of the Spark distribution. E.g.: -# docker build -t spark-base:latest -f kubernetes/dockerfiles/spark-base/Dockerfile . +# docker build -t spark:latest -f kubernetes/dockerfiles/spark/Dockerfile . RUN set -ex && \ apk upgrade --no-cache && \ @@ -41,7 +41,9 @@ COPY ${spark_jars} /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin COPY conf /opt/spark/conf -COPY ${img_path}/spark-base/entrypoint.sh /opt/ +COPY ${img_path}/spark/entrypoint.sh /opt/ +COPY examples /opt/spark/examples +COPY data /opt/spark/data ENV SPARK_HOME /opt/spark diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh new file mode 100755 index 0000000000000..0c28c75857871 --- /dev/null +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -0,0 +1,97 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# echo commands to the terminal output +set -ex + +# Check whether there is a passwd entry for the container UID +myuid=$(id -u) +mygid=$(id -g) +uidentry=$(getent passwd $myuid) + +# If there is no passwd entry for the container UID, attempt to create one +if [ -z "$uidentry" ] ; then + if [ -w /etc/passwd ] ; then + echo "$myuid:x:$myuid:$mygid:anonymous uid:$SPARK_HOME:/bin/false" >> /etc/passwd + else + echo "Container ENTRYPOINT failed to add passwd entry for anonymous UID" + fi +fi + +SPARK_K8S_CMD="$1" +if [ -z "$SPARK_K8S_CMD" ]; then + echo "No command to execute has been provided." 1>&2 + exit 1 +fi +shift 1 + +SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" +env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt +readarray -t SPARK_DRIVER_JAVA_OPTS < /tmp/java_opts.txt +if [ -n "$SPARK_MOUNTED_CLASSPATH" ]; then + SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_MOUNTED_CLASSPATH" +fi +if [ -n "$SPARK_MOUNTED_FILES_DIR" ]; then + cp -R "$SPARK_MOUNTED_FILES_DIR/." . +fi + +case "$SPARK_K8S_CMD" in + driver) + CMD=( + ${JAVA_HOME}/bin/java + "${SPARK_DRIVER_JAVA_OPTS[@]}" + -cp "$SPARK_CLASSPATH" + -Xms$SPARK_DRIVER_MEMORY + -Xmx$SPARK_DRIVER_MEMORY + -Dspark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS + $SPARK_DRIVER_CLASS + $SPARK_DRIVER_ARGS + ) + ;; + + executor) + CMD=( + ${JAVA_HOME}/bin/java + "${SPARK_EXECUTOR_JAVA_OPTS[@]}" + -Xms$SPARK_EXECUTOR_MEMORY + -Xmx$SPARK_EXECUTOR_MEMORY + -cp "$SPARK_CLASSPATH" + org.apache.spark.executor.CoarseGrainedExecutorBackend + --driver-url $SPARK_DRIVER_URL + --executor-id $SPARK_EXECUTOR_ID + --cores $SPARK_EXECUTOR_CORES + --app-id $SPARK_APPLICATION_ID + --hostname $SPARK_EXECUTOR_POD_IP + ) + ;; + + init) + CMD=( + "$SPARK_HOME/bin/spark-class" + "org.apache.spark.deploy.k8s.SparkPodInitContainer" + "$@" + ) + ;; + + *) + echo "Unknown command: $SPARK_K8S_CMD" 1>&2 + exit 1 +esac + +# Execute the container CMD under tini for better hygiene +exec /sbin/tini -s -- "${CMD[@]}" From 6f7aaed805070d29dcba32e04ca7a1f581fa54b9 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 11 Jan 2018 10:52:12 -0800 Subject: [PATCH 0070/2461] [SPARK-22908] Add kafka source and sink for continuous processing. ## What changes were proposed in this pull request? Add kafka source and sink for continuous processing. This involves two small changes to the execution engine: * Bring data reader close() into the normal data reader thread to avoid thread safety issues. * Fix up the semantics of the RECONFIGURING StreamExecution state. State updates are now atomic, and we don't have to deal with swallowing an exception. ## How was this patch tested? new unit tests Author: Jose Torres Closes #20096 from jose-torres/continuous-kafka. --- .../sql/kafka010/KafkaContinuousReader.scala | 232 +++++++++ .../sql/kafka010/KafkaContinuousWriter.scala | 119 +++++ .../sql/kafka010/KafkaOffsetReader.scala | 21 +- .../spark/sql/kafka010/KafkaSource.scala | 17 +- .../sql/kafka010/KafkaSourceOffset.scala | 7 +- .../sql/kafka010/KafkaSourceProvider.scala | 105 +++- .../spark/sql/kafka010/KafkaWriteTask.scala | 71 ++- .../spark/sql/kafka010/KafkaWriter.scala | 5 +- .../kafka010/KafkaContinuousSinkSuite.scala | 474 ++++++++++++++++++ .../kafka010/KafkaContinuousSourceSuite.scala | 96 ++++ .../sql/kafka010/KafkaContinuousTest.scala | 64 +++ .../spark/sql/kafka010/KafkaSourceSuite.scala | 470 +++++++++-------- .../apache/spark/sql/DataFrameReader.scala | 32 +- .../apache/spark/sql/DataFrameWriter.scala | 25 +- .../datasources/v2/WriteToDataSourceV2.scala | 8 +- .../execution/streaming/StreamExecution.scala | 15 +- .../ContinuousDataSourceRDDIter.scala | 3 +- .../continuous/ContinuousExecution.scala | 67 ++- .../continuous/EpochCoordinator.scala | 21 +- .../sql/streaming/DataStreamWriter.scala | 26 +- .../spark/sql/streaming/StreamTest.scala | 36 +- 21 files changed, 1531 insertions(+), 383 deletions(-) create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala new file mode 100644 index 0000000000000..928379544758c --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import org.apache.kafka.clients.consumer.ConsumerRecord +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +/** + * A [[ContinuousReader]] for data from kafka. + * + * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be + * read by per-task consumers generated later. + * @param kafkaParams String params for per-task Kafka consumers. + * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceV2Options]] params which + * are not Kafka consumer params. + * @param metadataPath Path to a directory this reader can use for writing metadata. + * @param initialOffsets The Kafka offsets to start reading data at. + * @param failOnDataLoss Flag indicating whether reading should fail in data loss + * scenarios, where some offsets after the specified initial ones can't be + * properly read. + */ +class KafkaContinuousReader( + offsetReader: KafkaOffsetReader, + kafkaParams: ju.Map[String, Object], + sourceOptions: Map[String, String], + metadataPath: String, + initialOffsets: KafkaOffsetRangeLimit, + failOnDataLoss: Boolean) + extends ContinuousReader with SupportsScanUnsafeRow with Logging { + + private lazy val session = SparkSession.getActiveSession.get + private lazy val sc = session.sparkContext + + // Initialized when creating read tasks. If this diverges from the partitions at the latest + // offsets, we need to reconfigure. + // Exposed outside this object only for unit tests. + private[sql] var knownPartitions: Set[TopicPartition] = _ + + override def readSchema: StructType = KafkaOffsetReader.kafkaSchema + + private var offset: Offset = _ + override def setOffset(start: ju.Optional[Offset]): Unit = { + offset = start.orElse { + val offsets = initialOffsets match { + case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) + case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) + case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) + } + logInfo(s"Initial offsets: $offsets") + offsets + } + } + + override def getStartOffset(): Offset = offset + + override def deserializeOffset(json: String): Offset = { + KafkaSourceOffset(JsonUtils.partitionOffsets(json)) + } + + override def createUnsafeRowReadTasks(): ju.List[ReadTask[UnsafeRow]] = { + import scala.collection.JavaConverters._ + + val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) + + val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet + val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) + val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) + + val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"Some partitions were deleted: $deletedPartitions") + } + + val startOffsets = newPartitionOffsets ++ + oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) + knownPartitions = startOffsets.keySet + + startOffsets.toSeq.map { + case (topicPartition, start) => + KafkaContinuousReadTask( + topicPartition, start, kafkaParams, failOnDataLoss) + .asInstanceOf[ReadTask[UnsafeRow]] + }.asJava + } + + /** Stop this source and free any resources it has allocated. */ + def stop(): Unit = synchronized { + offsetReader.close() + } + + override def commit(end: Offset): Unit = {} + + override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { + val mergedMap = offsets.map { + case KafkaSourcePartitionOffset(p, o) => Map(p -> o) + }.reduce(_ ++ _) + KafkaSourceOffset(mergedMap) + } + + override def needsReconfiguration(): Boolean = { + knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions + } + + override def toString(): String = s"KafkaSource[$offsetReader]" + + /** + * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. + * Otherwise, just log a warning. + */ + private def reportDataLoss(message: String): Unit = { + if (failOnDataLoss) { + throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE") + } else { + logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") + } + } +} + +/** + * A read task for continuous Kafka processing. This will be serialized and transformed into a + * full reader on executors. + * + * @param topicPartition The (topic, partition) pair this task is responsible for. + * @param startOffset The offset to start reading from within the partition. + * @param kafkaParams Kafka consumer params to use. + * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets + * are skipped. + */ +case class KafkaContinuousReadTask( + topicPartition: TopicPartition, + startOffset: Long, + kafkaParams: ju.Map[String, Object], + failOnDataLoss: Boolean) extends ReadTask[UnsafeRow] { + override def createDataReader(): KafkaContinuousDataReader = { + new KafkaContinuousDataReader(topicPartition, startOffset, kafkaParams, failOnDataLoss) + } +} + +/** + * A per-task data reader for continuous Kafka processing. + * + * @param topicPartition The (topic, partition) pair this data reader is responsible for. + * @param startOffset The offset to start reading from within the partition. + * @param kafkaParams Kafka consumer params to use. + * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets + * are skipped. + */ +class KafkaContinuousDataReader( + topicPartition: TopicPartition, + startOffset: Long, + kafkaParams: ju.Map[String, Object], + failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { + private val topic = topicPartition.topic + private val kafkaPartition = topicPartition.partition + private val consumer = CachedKafkaConsumer.createUncached(topic, kafkaPartition, kafkaParams) + + private val sharedRow = new UnsafeRow(7) + private val bufferHolder = new BufferHolder(sharedRow) + private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) + + private var nextKafkaOffset = startOffset + private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _ + + override def next(): Boolean = { + var r: ConsumerRecord[Array[Byte], Array[Byte]] = null + while (r == null) { + r = consumer.get( + nextKafkaOffset, + untilOffset = Long.MaxValue, + pollTimeoutMs = Long.MaxValue, + failOnDataLoss) + } + nextKafkaOffset = r.offset + 1 + currentRecord = r + true + } + + override def get(): UnsafeRow = { + bufferHolder.reset() + + if (currentRecord.key == null) { + rowWriter.setNullAt(0) + } else { + rowWriter.write(0, currentRecord.key) + } + rowWriter.write(1, currentRecord.value) + rowWriter.write(2, UTF8String.fromString(currentRecord.topic)) + rowWriter.write(3, currentRecord.partition) + rowWriter.write(4, currentRecord.offset) + rowWriter.write(5, + DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(currentRecord.timestamp))) + rowWriter.write(6, currentRecord.timestampType.id) + sharedRow.setTotalSize(bufferHolder.totalSize) + sharedRow + } + + override def getOffset(): KafkaSourcePartitionOffset = { + KafkaSourcePartitionOffset(topicPartition, nextKafkaOffset) + } + + override def close(): Unit = { + consumer.close() + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala new file mode 100644 index 0000000000000..9843f469c5b25 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.clients.producer.{Callback, ProducerRecord, RecordMetadata} +import scala.collection.JavaConverters._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{kafkaParamsForProducer, TOPIC_OPTION_KEY} +import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.{BinaryType, StringType, StructType} + +/** + * Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we + * don't need to really send one. + */ +case object KafkaWriterCommitMessage extends WriterCommitMessage + +/** + * A [[ContinuousWriter]] for Kafka writing. Responsible for generating the writer factory. + * @param topic The topic this writer is responsible for. If None, topic will be inferred from + * a `topic` field in the incoming data. + * @param producerParams Parameters for Kafka producers in each task. + * @param schema The schema of the input data. + */ +class KafkaContinuousWriter( + topic: Option[String], producerParams: Map[String, String], schema: StructType) + extends ContinuousWriter with SupportsWriteInternalRow { + + validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) + + override def createInternalRowWriterFactory(): KafkaContinuousWriterFactory = + KafkaContinuousWriterFactory(topic, producerParams, schema) + + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(messages: Array[WriterCommitMessage]): Unit = {} +} + +/** + * A [[DataWriterFactory]] for Kafka writing. Will be serialized and sent to executors to generate + * the per-task data writers. + * @param topic The topic that should be written to. If None, topic will be inferred from + * a `topic` field in the incoming data. + * @param producerParams Parameters for Kafka producers in each task. + * @param schema The schema of the input data. + */ +case class KafkaContinuousWriterFactory( + topic: Option[String], producerParams: Map[String, String], schema: StructType) + extends DataWriterFactory[InternalRow] { + + override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + new KafkaContinuousDataWriter(topic, producerParams, schema.toAttributes) + } +} + +/** + * A [[DataWriter]] for Kafka writing. One data writer will be created in each partition to + * process incoming rows. + * + * @param targetTopic The topic that this data writer is targeting. If None, topic will be inferred + * from a `topic` field in the incoming data. + * @param producerParams Parameters to use for the Kafka producer. + * @param inputSchema The attributes in the input data. + */ +class KafkaContinuousDataWriter( + targetTopic: Option[String], producerParams: Map[String, String], inputSchema: Seq[Attribute]) + extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] { + import scala.collection.JavaConverters._ + + private lazy val producer = CachedKafkaProducer.getOrCreate( + new java.util.HashMap[String, Object](producerParams.asJava)) + + def write(row: InternalRow): Unit = { + checkForErrors() + sendRow(row, producer) + } + + def commit(): WriterCommitMessage = { + // Send is asynchronous, but we can't commit until all rows are actually in Kafka. + // This requires flushing and then checking that no callbacks produced errors. + // We also check for errors before to fail as soon as possible - the check is cheap. + checkForErrors() + producer.flush() + checkForErrors() + KafkaWriterCommitMessage + } + + def abort(): Unit = {} + + def close(): Unit = { + checkForErrors() + if (producer != null) { + producer.flush() + checkForErrors() + CachedKafkaProducer.close(new java.util.HashMap[String, Object](producerParams.asJava)) + } + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 3e65949a6fd1b..551641cfdbca8 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -117,10 +117,14 @@ private[kafka010] class KafkaOffsetReader( * Resolves the specific offsets based on Kafka seek positions. * This method resolves offset value -1 to the latest and -2 to the * earliest Kafka seek position. + * + * @param partitionOffsets the specific offsets to resolve + * @param reportDataLoss callback to either report or log data loss depending on setting */ def fetchSpecificOffsets( - partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = - runUninterruptibly { + partitionOffsets: Map[TopicPartition, Long], + reportDataLoss: String => Unit): KafkaSourceOffset = { + val fetched = runUninterruptibly { withRetriesWithoutInterrupt { // Poll to get the latest assigned partitions consumer.poll(0) @@ -145,6 +149,19 @@ private[kafka010] class KafkaOffsetReader( } } + partitionOffsets.foreach { + case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && + off != KafkaOffsetRangeLimit.EARLIEST => + if (fetched(tp) != off) { + reportDataLoss( + s"startingOffsets for $tp was $off but consumer reset to ${fetched(tp)}") + } + case _ => + // no real way to check that beginning or end is reasonable + } + KafkaSourceOffset(fetched) + } + /** * Fetch the earliest offsets for the topic partitions that are indicated * in the [[ConsumerStrategy]]. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index e9cff04ba5f2e..27da76068a66f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -130,7 +130,7 @@ private[kafka010] class KafkaSource( val offsets = startingOffsets match { case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchEarliestOffsets()) case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets()) - case SpecificOffsetRangeLimit(p) => fetchAndVerify(p) + case SpecificOffsetRangeLimit(p) => kafkaReader.fetchSpecificOffsets(p, reportDataLoss) } metadataLog.add(0, offsets) logInfo(s"Initial offsets: $offsets") @@ -138,21 +138,6 @@ private[kafka010] class KafkaSource( }.partitionToOffsets } - private def fetchAndVerify(specificOffsets: Map[TopicPartition, Long]) = { - val result = kafkaReader.fetchSpecificOffsets(specificOffsets) - specificOffsets.foreach { - case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && - off != KafkaOffsetRangeLimit.EARLIEST => - if (result(tp) != off) { - reportDataLoss( - s"startingOffsets for $tp was $off but consumer reset to ${result(tp)}") - } - case _ => - // no real way to check that beginning or end is reasonable - } - KafkaSourceOffset(result) - } - private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None override def schema: StructType = KafkaOffsetReader.kafkaSchema diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index b5da415b3097e..c82154cfbad7f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -20,17 +20,22 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} +import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2, PartitionOffset} /** * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and * their offsets. */ private[kafka010] -case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset { +case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends OffsetV2 { override val json = JsonUtils.partitionOffsets(partitionToOffsets) } +private[kafka010] +case class KafkaSourcePartitionOffset(topicPartition: TopicPartition, partitionOffset: Long) + extends PartitionOffset + /** Companion object of the [[KafkaSourceOffset]] */ private[kafka010] object KafkaSourceOffset { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 3cb4d8cad12cc..3914370a96595 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import java.util.{Locale, UUID} +import java.util.{Locale, Optional, UUID} import scala.collection.JavaConverters._ @@ -27,9 +27,12 @@ import org.apache.kafka.clients.producer.ProducerConfig import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} -import org.apache.spark.sql.execution.streaming.{Sink, Source} +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.execution.streaming.{Offset, Sink, Source} import org.apache.spark.sql.sources._ +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -43,6 +46,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider + with ContinuousWriteSupport + with ContinuousReadSupport with Logging { import KafkaSourceProvider._ @@ -101,6 +106,43 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParams)) } + override def createContinuousReader( + schema: Optional[StructType], + metadataPath: String, + options: DataSourceV2Options): KafkaContinuousReader = { + val parameters = options.asMap().asScala.toMap + validateStreamOptions(parameters) + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. + val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + val specifiedKafkaParams = + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + + val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, + STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + + val kafkaOffsetReader = new KafkaOffsetReader( + strategy(caseInsensitiveParams), + kafkaParamsForDriver(specifiedKafkaParams), + parameters, + driverGroupIdPrefix = s"$uniqueGroupId-driver") + + new KafkaContinuousReader( + kafkaOffsetReader, + kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), + parameters, + metadataPath, + startingStreamOffsets, + failOnDataLoss(caseInsensitiveParams)) + } + /** * Returns a new base relation with the given parameters. * @@ -181,26 +223,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = { - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " - + "are serialized with ByteArraySerializer.") - } + override def createContinuousWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[ContinuousWriter] = { + import scala.collection.JavaConverters._ - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) - { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " - + "value are serialized with ByteArraySerializer.") - } - parameters - .keySet - .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } - .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, - ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + val spark = SparkSession.getActiveSession.get + val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim) + // We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable. + val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) + + KafkaWriter.validateQuery( + schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic) + + Optional.of(new KafkaContinuousWriter(topic, producerParams, schema)) } private def strategy(caseInsensitiveParams: Map[String, String]) = @@ -450,4 +488,27 @@ private[kafka010] object KafkaSourceProvider extends Logging { def build(): ju.Map[String, Object] = map } + + private[kafka010] def kafkaParamsForProducer( + parameters: Map[String, String]): Map[String, String] = { + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " + + "are serialized with ByteArraySerializer.") + } + + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) + { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " + + "value are serialized with ByteArraySerializer.") + } + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, + ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index 6fd333e2f43ba..baa60febf661d 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -33,10 +33,8 @@ import org.apache.spark.sql.types.{BinaryType, StringType} private[kafka010] class KafkaWriteTask( producerConfiguration: ju.Map[String, Object], inputSchema: Seq[Attribute], - topic: Option[String]) { + topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) { // used to synchronize with Kafka callbacks - @volatile private var failedWrite: Exception = null - private val projection = createProjection private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _ /** @@ -46,23 +44,7 @@ private[kafka010] class KafkaWriteTask( producer = CachedKafkaProducer.getOrCreate(producerConfiguration) while (iterator.hasNext && failedWrite == null) { val currentRow = iterator.next() - val projectedRow = projection(currentRow) - val topic = projectedRow.getUTF8String(0) - val key = projectedRow.getBinary(1) - val value = projectedRow.getBinary(2) - if (topic == null) { - throw new NullPointerException(s"null topic present in the data. Use the " + - s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") - } - val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) - val callback = new Callback() { - override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { - if (failedWrite == null && e != null) { - failedWrite = e - } - } - } - producer.send(record, callback) + sendRow(currentRow, producer) } } @@ -74,8 +56,49 @@ private[kafka010] class KafkaWriteTask( producer = null } } +} + +private[kafka010] abstract class KafkaRowWriter( + inputSchema: Seq[Attribute], topic: Option[String]) { + + // used to synchronize with Kafka callbacks + @volatile protected var failedWrite: Exception = _ + protected val projection = createProjection + + private val callback = new Callback() { + override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { + if (failedWrite == null && e != null) { + failedWrite = e + } + } + } - private def createProjection: UnsafeProjection = { + /** + * Send the specified row to the producer, with a callback that will save any exception + * to failedWrite. Note that send is asynchronous; subclasses must flush() their producer before + * assuming the row is in Kafka. + */ + protected def sendRow( + row: InternalRow, producer: KafkaProducer[Array[Byte], Array[Byte]]): Unit = { + val projectedRow = projection(row) + val topic = projectedRow.getUTF8String(0) + val key = projectedRow.getBinary(1) + val value = projectedRow.getBinary(2) + if (topic == null) { + throw new NullPointerException(s"null topic present in the data. Use the " + + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") + } + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) + producer.send(record, callback) + } + + protected def checkForErrors(): Unit = { + if (failedWrite != null) { + throw failedWrite + } + } + + private def createProjection = { val topicExpression = topic.map(Literal(_)).orElse { inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME) }.getOrElse { @@ -112,11 +135,5 @@ private[kafka010] class KafkaWriteTask( Seq(topicExpression, Cast(keyExpression, BinaryType), Cast(valueExpression, BinaryType)), inputSchema) } - - private def checkForErrors(): Unit = { - if (failedWrite != null) { - throw failedWrite - } - } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 5e9ae35b3f008..15cd44812cb0c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -43,10 +43,9 @@ private[kafka010] object KafkaWriter extends Logging { override def toString: String = "KafkaWriter" def validateQuery( - queryExecution: QueryExecution, + schema: Seq[Attribute], kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { - val schema = queryExecution.analyzed.output schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse( if (topic.isEmpty) { throw new AnalysisException(s"topic option required when no " + @@ -84,7 +83,7 @@ private[kafka010] object KafkaWriter extends Logging { kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { val schema = queryExecution.analyzed.output - validateQuery(queryExecution, kafkaParameters, topic) + validateQuery(schema, kafkaParameters, topic) queryExecution.toRdd.foreachPartition { iter => val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic) Utils.tryWithSafeFinally(block = writeTask.execute(iter))( diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala new file mode 100644 index 0000000000000..dfc97b1c38bb5 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -0,0 +1,474 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.Locale +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.kafka.clients.producer.ProducerConfig +import org.apache.kafka.common.serialization.ByteArraySerializer +import org.scalatest.time.SpanSugar._ +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.types.{BinaryType, DataType} +import org.apache.spark.util.Utils + +/** + * This is a temporary port of KafkaSinkSuite, since we do not yet have a V2 memory stream. + * Once we have one, this will be changed to a specialization of KafkaSinkSuite and we won't have + * to duplicate all the code. + */ +class KafkaContinuousSinkSuite extends KafkaContinuousTest { + import testImplicits._ + + override val streamingTimeout = 30.seconds + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils( + withBrokerProps = Map("auto.create.topics.enable" -> "false")) + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + super.afterAll() + } + + test("streaming - write to kafka with topic field") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = None, + withOutputMode = Some(OutputMode.Append))( + withSelectExpr = s"'$topic' as topic", "value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } finally { + writer.stop() + } + } + + test("streaming - write w/o topic field, with topic option") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Append()))() + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } finally { + writer.stop() + } + } + + test("streaming - topic field and topic option") { + /* The purpose of this test is to ensure that the topic option + * overrides the topic field. We begin by writing some data that + * includes a topic field and value (e.g., 'foo') along with a topic + * option. Then when we read from the topic specified in the option + * we should see the data i.e., the data was written to the topic + * option, and not to the topic in the data e.g., foo + */ + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Append()))( + withSelectExpr = "'foo' as topic", "CAST(value as STRING) value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .selectExpr("CAST(key AS INT)", "CAST(value AS INT)") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } finally { + writer.stop() + } + } + + test("null topic attribute") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + testUtils.createTopic(topic) + + /* No topic field or topic option */ + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = "CAST(null as STRING) as topic", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getCause.getCause.getMessage + .toLowerCase(Locale.ROOT) + .contains("null topic present in the data.")) + } + + test("streaming - write data with bad schema") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + testUtils.createTopic(topic) + + /* No topic field or topic option */ + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = "value as key", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage + .toLowerCase(Locale.ROOT) + .contains("topic option required when no 'topic' attribute is present")) + + try { + /* No value field */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "value as key" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "required attribute 'value' not found")) + } + + test("streaming - write data with valid schema but wrong types") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + .selectExpr("CAST(value as STRING) value") + val topic = newTopic() + testUtils.createTopic(topic) + + var writer: StreamingQuery = null + var ex: Exception = null + try { + /* topic field wrong type */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"CAST('1' as INT) as topic", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string")) + + try { + /* value field wrong type */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "value attribute type must be a string or binarytype")) + + try { + ex = intercept[StreamingQueryException] { + /* key field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "key attribute type must be a string or binarytype")) + } + + test("streaming - write to non-existing topic") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))() + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + } + throw writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) + } + + test("streaming - exception on config serializer") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + testUtils.sendMessages(inputTopic, Array("0")) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .load() + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.key.serializer" -> "foo"))() + writer.processAllAvailable() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'key.serializer' is not supported")) + } finally { + writer.stop() + } + + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.value.serializer" -> "foo"))() + writer.processAllAvailable() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'value.serializer' is not supported")) + } finally { + writer.stop() + } + } + + test("generic - write big data with small producer buffer") { + /* This test ensures that we understand the semantics of Kafka when + * is comes to blocking on a call to send when the send buffer is full. + * This test will configure the smallest possible producer buffer and + * indicate that we should block when it is full. Thus, no exception should + * be thrown in the case of a full buffer. + */ + val topic = newTopic() + testUtils.createTopic(topic, 1) + val options = new java.util.HashMap[String, String] + options.put("bootstrap.servers", testUtils.brokerAddress) + options.put("buffer.memory", "16384") // min buffer size + options.put("block.on.buffer.full", "true") + options.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + val inputSchema = Seq(AttributeReference("value", BinaryType)()) + val data = new Array[Byte](15000) // large value + val writeTask = new KafkaContinuousDataWriter(Some(topic), options.asScala.toMap, inputSchema) + try { + val fieldTypes: Array[DataType] = Array(BinaryType) + val converter = UnsafeProjection.create(fieldTypes) + val row = new SpecificInternalRow(fieldTypes) + row.update(0, data) + val iter = Seq.fill(1000)(converter.apply(row)).iterator + iter.foreach(writeTask.write(_)) + writeTask.commit() + } finally { + writeTask.close() + } + } + + private def createKafkaReader(topic: String): DataFrame = { + spark.read + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("startingOffsets", "earliest") + .option("endingOffsets", "latest") + .option("subscribe", topic) + .load() + } + + private def createKafkaWriter( + input: DataFrame, + withTopic: Option[String] = None, + withOutputMode: Option[OutputMode] = None, + withOptions: Map[String, String] = Map[String, String]()) + (withSelectExpr: String*): StreamingQuery = { + var stream: DataStreamWriter[Row] = null + val checkpointDir = Utils.createTempDir() + var df = input.toDF() + if (withSelectExpr.length > 0) { + df = df.selectExpr(withSelectExpr: _*) + } + stream = df.writeStream + .format("kafka") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + // We need to reduce blocking time to efficiently test non-existent partition behavior. + .option("kafka.max.block.ms", "1000") + .trigger(Trigger.Continuous(1000)) + .queryName("kafkaStream") + withTopic.foreach(stream.option("topic", _)) + withOutputMode.foreach(stream.outputMode(_)) + withOptions.foreach(opt => stream.option(opt._1, opt._2)) + stream.start() + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala new file mode 100644 index 0000000000000..b3dade414f625 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.Properties +import java.util.concurrent.atomic.AtomicInteger + +import org.scalatest.time.SpanSugar._ +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.SparkContext +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.streaming.{StreamTest, Trigger} +import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} + +// Run tests in KafkaSourceSuiteBase in continuous execution mode. +class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest + +class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { + import testImplicits._ + + override val brokerProps = Map("auto.create.topics.enable" -> "false") + + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Execute { query => + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + eventually(timeout(streamingTimeout)) { + assert( + query.lastExecution.logical.collectFirst { + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + }.exists { r => + // Ensure the new topic is present and the old topic is gone. + r.knownPartitions.exists(_.topic == topic2) + }, + s"query never reconfigured to new topic $topic2") + } + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } +} + +class KafkaContinuousSourceStressForDontFailOnDataLossSuite + extends KafkaSourceStressForDontFailOnDataLossSuite { + override protected def startStream(ds: Dataset[Int]) = { + ds.writeStream + .format("memory") + .queryName("memory") + .start() + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala new file mode 100644 index 0000000000000..e713e6695d2bd --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.spark.SparkContext +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.streaming.Trigger +import org.apache.spark.sql.test.TestSparkSession + +// Trait to configure StreamTest for kafka continuous execution tests. +trait KafkaContinuousTest extends KafkaSourceTest { + override val defaultTrigger = Trigger.Continuous(1000) + override val defaultUseV2Sink = true + + // We need more than the default local[2] to be able to schedule all partitions simultaneously. + override protected def createSparkSession = new TestSparkSession( + new SparkContext( + "local[10]", + "continuous-stream-test-sql-context", + sparkConf.set("spark.sql.testkey", "true"))) + + // In addition to setting the partitions in Kafka, we have to wait until the query has + // reconfigured to the new count so the test framework can hook in properly. + override protected def setTopicPartitions( + topic: String, newCount: Int, query: StreamExecution) = { + testUtils.addPartitions(topic, newCount) + eventually(timeout(streamingTimeout)) { + assert( + query.lastExecution.logical.collectFirst { + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + }.exists(_.knownPartitions.size == newCount), + s"query never reconfigured to $newCount partitions") + } + } + + test("ensure continuous stream is being used") { + val query = spark.readStream + .format("rate") + .option("numPartitions", "1") + .option("rowsPerSecond", "1") + .load() + + testStream(query)( + Execute(q => assert(q.isInstanceOf[ContinuousExecution])) + ) + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 2034b9be07f24..d66908f86ccc7 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -34,11 +34,14 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext -import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryWriter import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ -import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} +import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest, Trigger} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} import org.apache.spark.util.Utils @@ -49,9 +52,11 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override val streamingTimeout = 30.seconds + protected val brokerProps = Map[String, Object]() + override def beforeAll(): Unit = { super.beforeAll() - testUtils = new KafkaTestUtils + testUtils = new KafkaTestUtils(brokerProps) testUtils.setup() } @@ -59,18 +64,25 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { if (testUtils != null) { testUtils.teardown() testUtils = null - super.afterAll() } + super.afterAll() } protected def makeSureGetOffsetCalled = AssertOnQuery { q => // Because KafkaSource's initialPartitionOffsets is set lazily, we need to make sure - // its "getOffset" is called before pushing any data. Otherwise, because of the race contion, + // its "getOffset" is called before pushing any data. Otherwise, because of the race condition, // we don't know which data should be fetched when `startingOffsets` is latest. - q.processAllAvailable() + q match { + case c: ContinuousExecution => c.awaitEpoch(0) + case m: MicroBatchExecution => m.processAllAvailable() + } true } + protected def setTopicPartitions(topic: String, newCount: Int, query: StreamExecution) : Unit = { + testUtils.addPartitions(topic, newCount) + } + /** * Add data to Kafka. * @@ -82,7 +94,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { message: String = "", topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData { - override def addData(query: Option[StreamExecution]): (Source, Offset) = { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { if (query.get.isActive) { // Make sure no Spark job is running when deleting a topic query.get.processAllAvailable() @@ -97,16 +109,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2)) } - // Read all topics again in case some topics are delete. - val allTopics = testUtils.getAllTopicsAndPartitionSize().toMap.keys require( query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[KafkaSource] => - source.asInstanceOf[KafkaSource] - } + case StreamingExecutionRelation(source: KafkaSource, _) => source + } ++ (query.get.lastExecution match { + case null => Seq() + case e => e.logical.collect { + case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + } + }) if (sources.isEmpty) { throw new Exception( "Could not find Kafka source in the StreamExecution logical plan to add data to") @@ -137,14 +151,158 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override def toString: String = s"AddKafkaData(topics = $topics, data = $data, message = $message)" } -} + private val topicId = new AtomicInteger(0) + protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}" +} -class KafkaSourceSuite extends KafkaSourceTest { +class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { import testImplicits._ - private val topicId = new AtomicInteger(0) + test("(de)serialization of initial offsets") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + + testStream(reader.load)( + makeSureGetOffsetCalled, + StopStream, + StartStream(), + StopStream) + } + + test("maxOffsetsPerTrigger") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) + testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) + testUtils.sendMessages(topic, Array("1"), Some(2)) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("maxOffsetsPerTrigger", 10) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + val clock = new StreamManualClock + + val waitUntilBatchProcessed = AssertOnQuery { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + } + + testStream(mapped)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // 1 from smallest, 1 from middle, 8 from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 + ), + StopStream, + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 + ), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, + 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 + ) + ) + } + + test("input row metrics") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + val mapped = kafka.map(kv => kv._2.toInt + 1) + testStream(mapped)( + StartStream(trigger = ProcessingTime(1)), + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + AssertOnQuery { query => + val recordsRead = query.recentProgress.map(_.numInputRows).sum + recordsRead == 3 + } + ) + } + + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Assert { + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + true + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } testWithUninterruptibleThread( "deserialization of initial offset with Spark 2.1.0") { @@ -237,86 +395,51 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - test("(de)serialization of initial offsets") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 64) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", topic) - - testStream(reader.load)( - makeSureGetOffsetCalled, - StopStream, - StartStream(), - StopStream) - } - - test("maxOffsetsPerTrigger") { + test("KafkaSource with watermark") { + val now = System.currentTimeMillis() val topic = newTopic() - testUtils.createTopic(topic, partitions = 3) - testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) - testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) - testUtils.sendMessages(topic, Array("1"), Some(2)) + testUtils.createTopic(newTopic(), partitions = 1) + testUtils.sendMessages(topic, Array(1).map(_.toString)) - val reader = spark + val kafka = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") - .option("maxOffsetsPerTrigger", 10) + .option("startingOffsets", s"earliest") .option("subscribe", topic) - .option("startingOffsets", "earliest") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) - - val clock = new StreamManualClock + .load() - val waitUntilBatchProcessed = AssertOnQuery { q => - eventually(Timeout(streamingTimeout)) { - if (!q.exception.isDefined) { - assert(clock.isStreamWaitingAt(clock.getTimeMillis())) - } - } - if (q.exception.isDefined) { - throw q.exception.get - } - true - } + val windowedAggregation = kafka + .withWatermark("timestamp", "10 seconds") + .groupBy(window($"timestamp", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start") as 'window, $"count") - testStream(mapped)( - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // 1 from smallest, 1 from middle, 8 from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 - ), - StopStream, - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 - ), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, - 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 - ) - ) + val query = windowedAggregation + .writeStream + .format("memory") + .outputMode("complete") + .queryName("kafkaWatermark") + .start() + query.processAllAvailable() + val rows = spark.table("kafkaWatermark").collect() + assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + val row = rows(0) + // We cannot check the exact window start time as it depands on the time that messages were + // inserted by the producer. So here we just use a low bound to make sure the internal + // conversion works. + assert( + row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, + s"Unexpected results: $row") + assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") + query.stop() } +} + +class KafkaSourceSuiteBase extends KafkaSourceTest { + + import testImplicits._ test("cannot stop Kafka stream") { val topic = newTopic() @@ -328,7 +451,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"topic-.*") + .option("subscribePattern", s"$topic.*") val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") @@ -422,65 +545,6 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - test("subscribing topic by pattern with topic deletions") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-seems" - val topic2 = topicPrefix + "-bad" - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"$topicPrefix-.*") - .option("failOnDataLoss", "false") - - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped = kafka.map(kv => kv._2.toInt + 1) - - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - Assert { - testUtils.deleteTopic(topic) - testUtils.createTopic(topic2, partitions = 5) - true - }, - AddKafkaData(Set(topic2), 4, 5, 6), - CheckAnswer(2, 3, 4, 5, 6, 7) - ) - } - - test("starting offset is latest by default") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("0")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", topic) - - val kafka = reader.load() - .selectExpr("CAST(value AS STRING)") - .as[String] - val mapped = kafka.map(_.toInt) - - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(1, 2, 3) // should not have 0 - ) - } - test("bad source options") { def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = { val ex = intercept[IllegalArgumentException] { @@ -540,34 +604,6 @@ class KafkaSourceSuite extends KafkaSourceTest { testUnsupportedConfig("kafka.auto.offset.reset", "latest") } - test("input row metrics") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val kafka = spark - .readStream - .format("kafka") - .option("subscribe", topic) - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - - val mapped = kafka.map(kv => kv._2.toInt + 1) - testStream(mapped)( - StartStream(trigger = ProcessingTime(1)), - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - AssertOnQuery { query => - val recordsRead = query.recentProgress.map(_.numInputRows).sum - recordsRead == 3 - } - ) - } - test("delete a topic when a Spark job is running") { KafkaSourceSuite.collectedData.clear() @@ -629,8 +665,6 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" - private def assignString(topic: String, partitions: Iterable[Int]): String = { JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p))) } @@ -676,6 +710,10 @@ class KafkaSourceSuite extends KafkaSourceTest { testStream(mapped)( makeSureGetOffsetCalled, + Execute { q => + // wait to reach the last offset in every partition + q.awaitOffset(0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L))) + }, CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22), StopStream, StartStream(), @@ -706,6 +744,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .format("memory") .outputMode("append") .queryName("kafkaColumnTypes") + .trigger(defaultTrigger) .start() query.processAllAvailable() val rows = spark.table("kafkaColumnTypes").collect() @@ -723,47 +762,6 @@ class KafkaSourceSuite extends KafkaSourceTest { query.stop() } - test("KafkaSource with watermark") { - val now = System.currentTimeMillis() - val topic = newTopic() - testUtils.createTopic(newTopic(), partitions = 1) - testUtils.sendMessages(topic, Array(1).map(_.toString)) - - val kafka = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("startingOffsets", s"earliest") - .option("subscribe", topic) - .load() - - val windowedAggregation = kafka - .withWatermark("timestamp", "10 seconds") - .groupBy(window($"timestamp", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"window".getField("start") as 'window, $"count") - - val query = windowedAggregation - .writeStream - .format("memory") - .outputMode("complete") - .queryName("kafkaWatermark") - .start() - query.processAllAvailable() - val rows = spark.table("kafkaWatermark").collect() - assert(rows.length === 1, s"Unexpected results: ${rows.toList}") - val row = rows(0) - // We cannot check the exact window start time as it depands on the time that messages were - // inserted by the producer. So here we just use a low bound to make sure the internal - // conversion works. - assert( - row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, - s"Unexpected results: $row") - assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") - query.stop() - } - private def testFromLatestOffsets( topic: String, addPartitions: Boolean, @@ -800,9 +798,7 @@ class KafkaSourceSuite extends KafkaSourceTest { AddKafkaData(Set(topic), 7, 8), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) { - testUtils.addPartitions(topic, 10) - } + if (addPartitions) setTopicPartitions(topic, 10, query) true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -843,9 +839,7 @@ class KafkaSourceSuite extends KafkaSourceTest { StartStream(), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) { - testUtils.addPartitions(topic, 10) - } + if (addPartitions) setTopicPartitions(topic, 10, query) true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -977,20 +971,8 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared } } - test("stress test for failOnDataLoss=false") { - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", "failOnDataLoss.*") - .option("startingOffsets", "earliest") - .option("failOnDataLoss", "false") - .option("fetchOffset.retryIntervalMs", "3000") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { + protected def startStream(ds: Dataset[Int]) = { + ds.writeStream.foreach(new ForeachWriter[Int] { override def open(partitionId: Long, version: Long): Boolean = { true @@ -1004,6 +986,22 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared override def close(errorOrNull: Throwable): Unit = { } }).start() + } + + test("stress test for failOnDataLoss=false") { + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", "failOnDataLoss.*") + .option("startingOffsets", "earliest") + .option("failOnDataLoss", "false") + .option("fetchOffset.retryIntervalMs", "3000") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val query = startStream(kafka.map(kv => kv._2.toInt)) val testTime = 1.minutes val startTime = System.currentTimeMillis() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index e8d683a578f35..b714a46b5f786 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -191,6 +191,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { ds = ds.asInstanceOf[DataSourceV2], conf = sparkSession.sessionState.conf)).asJava) + // Streaming also uses the data source V2 API. So it may be that the data source implements + // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading + // the dataframe as a v1 source. val reader = (ds, userSpecifiedSchema) match { case (ds: ReadSupportWithSchema, Some(schema)) => ds.createReader(schema, options) @@ -208,23 +211,30 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } reader - case _ => - throw new AnalysisException(s"$cls does not support data reading.") + case _ => null // fall back to v1 } - Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + if (reader == null) { + loadV1Source(paths: _*) + } else { + Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + } } else { - // Code path for data source v1. - sparkSession.baseRelationToDataFrame( - DataSource.apply( - sparkSession, - paths = paths, - userSpecifiedSchema = userSpecifiedSchema, - className = source, - options = extraOptions.toMap).resolveRelation()) + loadV1Source(paths: _*) } } + private def loadV1Source(paths: String*) = { + // Code path for data source v1. + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + userSpecifiedSchema = userSpecifiedSchema, + className = source, + options = extraOptions.toMap).resolveRelation()) + } + /** * Construct a `DataFrame` representing the database table accessible via JDBC URL * url named table and connection properties. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 3304f368e1050..97f12ff625c42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -255,17 +255,24 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } } - case _ => throw new AnalysisException(s"$cls does not support data writing.") + // Streaming also uses the data source V2 API. So it may be that the data source implements + // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving + // as though it's a V1 source. + case _ => saveToV1Source() } } else { - // Code path for data source v1. - runCommand(df.sparkSession, "save") { - DataSource( - sparkSession = df.sparkSession, - className = source, - partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) - } + saveToV1Source() + } + } + + private def saveToV1Source(): Unit = { + // Code path for data source v1. + runCommand(df.sparkSession, "save") { + DataSource( + sparkSession = df.sparkSession, + className = source, + partitionColumns = partitioningColumns.getOrElse(Nil), + options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index f0bdf84bb7a84..a4a857f2d4d9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -81,9 +81,11 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) (index, message: WriterCommitMessage) => messages(index) = message ) - logInfo(s"Data source writer $writer is committing.") - writer.commit(messages) - logInfo(s"Data source writer $writer committed.") + if (!writer.isInstanceOf[ContinuousWriter]) { + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") + } } catch { case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] => // Interruption is how continuous queries are ended, so accept and ignore the exception. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 24a8b000df0c1..cf27e1a70650a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -142,7 +142,8 @@ abstract class StreamExecution( override val id: UUID = UUID.fromString(streamMetadata.id) - override val runId: UUID = UUID.randomUUID + override def runId: UUID = currentRunId + protected var currentRunId = UUID.randomUUID /** * Pretty identified string of printing in logs. Format is @@ -418,11 +419,17 @@ abstract class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. */ - private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = { + private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = { assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets - !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset + if (sources == null) { + // sources might not be initialized yet + false + } else { + val source = sources(sourceIndex) + !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset + } } while (notDone) { @@ -436,7 +443,7 @@ abstract class StreamExecution( awaitProgressLock.unlock() } } - logDebug(s"Unblocked at $newOffset for $source") + logDebug(s"Unblocked at $newOffset for ${sources(sourceIndex)}") } /** A flag to indicate that a batch has completed with no new data available. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index d79e4bd65f563..e700aa4f9aea7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -77,7 +77,6 @@ class ContinuousDataSourceRDD( dataReaderThread.start() context.addTaskCompletionListener(_ => { - reader.close() dataReaderThread.interrupt() epochPollExecutor.shutdown() }) @@ -201,6 +200,8 @@ class DataReaderThread( failedFlag.set(true) // Don't rethrow the exception in this thread. It's not needed, and the default Spark // exception handler will kill the executor. + } finally { + reader.close() } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 9657b5e26d770..667410ef9f1c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution.streaming.continuous +import java.util.UUID import java.util.concurrent.TimeUnit +import java.util.function.UnaryOperator import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} @@ -52,7 +54,7 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReader] = Seq.empty + @volatile protected var continuousSources: Seq[ContinuousReader] = _ override protected def sources: Seq[BaseStreamingSource] = continuousSources override lazy val logicalPlan: LogicalPlan = { @@ -78,15 +80,17 @@ class ContinuousExecution( } override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { - do { - try { - runContinuous(sparkSessionForStream) - } catch { - case _: InterruptedException if state.get().equals(RECONFIGURING) => - // swallow exception and run again - state.set(ACTIVE) + val stateUpdate = new UnaryOperator[State] { + override def apply(s: State) = s match { + // If we ended the query to reconfigure, reset the state to active. + case RECONFIGURING => ACTIVE + case _ => s } - } while (state.get() == ACTIVE) + } + + do { + runContinuous(sparkSessionForStream) + } while (state.updateAndGet(stateUpdate) == ACTIVE) } /** @@ -120,12 +124,16 @@ class ContinuousExecution( } committedOffsets = nextOffsets.toStreamProgress(sources) - // Forcibly align commit and offset logs by slicing off any spurious offset logs from - // a previous run. We can't allow commits to an epoch that a previous run reached but - // this run has not. - offsetLog.purgeAfter(latestEpochId) + // Get to an epoch ID that has definitely never been sent to a sink before. Since sink + // commit happens between offset log write and commit log write, this means an epoch ID + // which is not in the offset log. + val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse { + throw new IllegalStateException( + s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" + + s"an element.") + } + currentBatchId = latestOffsetEpoch + 1 - currentBatchId = latestEpochId + 1 logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets") nextOffsets case None => @@ -141,6 +149,7 @@ class ContinuousExecution( * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with. */ private def runContinuous(sparkSessionForQuery: SparkSession): Unit = { + currentRunId = UUID.randomUUID // A list of attributes that will need to be updated. val replacements = new ArrayBuffer[(Attribute, Attribute)] // Translate from continuous relation to the underlying data source. @@ -225,13 +234,11 @@ class ContinuousExecution( triggerExecutor.execute(() => { startTrigger() - if (reader.needsReconfiguration()) { - state.set(RECONFIGURING) + if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { stopSources() if (queryExecutionThread.isAlive) { sparkSession.sparkContext.cancelJobGroup(runId.toString) queryExecutionThread.interrupt() - // No need to join - this thread is about to end anyway. } false } else if (isActive) { @@ -259,6 +266,7 @@ class ContinuousExecution( sparkSessionForQuery, lastExecution)(lastExecution.toRdd) } } finally { + epochEndpoint.askSync[Unit](StopContinuousExecutionWrites) SparkEnv.get.rpcEnv.stop(epochEndpoint) epochUpdateThread.interrupt() @@ -273,17 +281,22 @@ class ContinuousExecution( epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") - if (partitionOffsets.contains(null)) { - // If any offset is null, that means the corresponding partition hasn't seen any data yet, so - // there's nothing meaningful to add to the offset log. - } val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) - synchronized { - if (queryExecutionThread.isAlive) { - offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) - } else { - return - } + val oldOffset = synchronized { + offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) + offsetLog.get(epoch - 1) + } + + // If offset hasn't changed since last epoch, there's been no new data. + if (oldOffset.contains(OffsetSeq.fill(globalOffset))) { + noNewData = true + } + + awaitProgressLock.lock() + try { + awaitProgressLockCondition.signalAll() + } finally { + awaitProgressLock.unlock() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 98017c3ac6a33..40dcbecade814 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -39,6 +39,15 @@ private[continuous] sealed trait EpochCoordinatorMessage extends Serializable */ private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage +/** + * The RpcEndpoint stop() will wait to clear out the message queue before terminating the + * object. This can lead to a race condition where the query restarts at epoch n, a new + * EpochCoordinator starts at epoch n, and then the old epoch coordinator commits epoch n + 1. + * The framework doesn't provide a handle to wait on the message queue, so we use a synchronous + * message to stop any writes to the ContinuousExecution object. + */ +private[sql] case object StopContinuousExecutionWrites extends EpochCoordinatorMessage + // Init messages /** * Set the reader and writer partition counts. Tasks may not be started until the coordinator @@ -116,6 +125,8 @@ private[continuous] class EpochCoordinator( override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + private var queryWritesStopped: Boolean = false + private var numReaderPartitions: Int = _ private var numWriterPartitions: Int = _ @@ -147,12 +158,16 @@ private[continuous] class EpochCoordinator( partitionCommits.remove(k) } for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) { - partitionCommits.remove(k) + partitionOffsets.remove(k) } } } override def receive: PartialFunction[Any, Unit] = { + // If we just drop these messages, we won't do any writes to the query. The lame duck tasks + // won't shed errors or anything. + case _ if queryWritesStopped => () + case CommitPartitionEpoch(partitionId, epoch, message) => logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message") if (!partitionCommits.isDefinedAt((epoch, partitionId))) { @@ -188,5 +203,9 @@ private[continuous] class EpochCoordinator( case SetWriterPartitions(numPartitions) => numWriterPartitions = numPartitions context.reply(()) + + case StopContinuousExecutionWrites => + queryWritesStopped = true + context.reply(()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index db588ae282f38..b5b4a05ab4973 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} +import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -279,18 +280,29 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val dataSource = - DataSource( - df.sparkSession, - className = source, - options = extraOptions.toMap, - partitionColumns = normalizedParCols.getOrElse(Nil)) + val sink = trigger match { + case _: ContinuousTrigger => + val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) + ds.newInstance() match { + case w: ContinuousWriteSupport => w + case _ => throw new AnalysisException( + s"Data source $source does not support continuous writing") + } + case _ => + val ds = DataSource( + df.sparkSession, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) + ds.createSink(outputMode) + } + df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), df, extraOptions.toMap, - dataSource.createSink(outputMode), + sink, outputMode, useTempCheckpointLocation = source == "console", recoverFromCheckpointLocation = true, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index d46461fa9bf6d..0762895fdc620 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -38,8 +38,9 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ @@ -80,6 +81,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be StateStore.stop() // stop the state store maintenance thread and unload store providers } + protected val defaultTrigger = Trigger.ProcessingTime(0) + protected val defaultUseV2Sink = false + /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds @@ -189,7 +193,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be /** Starts the stream, resuming if data has already been processed. It must not be running. */ case class StartStream( - trigger: Trigger = Trigger.ProcessingTime(0), + trigger: Trigger = defaultTrigger, triggerClock: Clock = new SystemClock, additionalConfs: Map[String, String] = Map.empty, checkpointLocation: String = null) @@ -276,7 +280,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def testStream( _stream: Dataset[_], outputMode: OutputMode = OutputMode.Append, - useV2Sink: Boolean = false)(actions: StreamAction*): Unit = synchronized { + useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized { import org.apache.spark.sql.streaming.util.StreamManualClock // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently @@ -403,18 +407,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = { verify(currentStream != null, "stream not running") - // Get the map of source index to the current source objects - val indexToSource = currentStream - .logicalPlan - .collect { case StreamingExecutionRelation(s, _) => s } - .zipWithIndex - .map(_.swap) - .toMap // Block until all data added has been processed for all the source awaiting.foreach { case (sourceIndex, offset) => failAfter(streamingTimeout) { - currentStream.awaitOffset(indexToSource(sourceIndex), offset) + currentStream.awaitOffset(sourceIndex, offset) } } @@ -473,6 +470,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be // after starting the query. try { currentStream.awaitInitialization(streamingTimeout.toMillis) + currentStream match { + case s: ContinuousExecution => eventually("IncrementalExecution was not created") { + s.lastExecution.executedPlan // will fail if lastExecution is null + } + case _ => + } } catch { case _: StreamingQueryException => // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well. @@ -600,7 +603,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def findSourceIndex(plan: LogicalPlan): Option[Int] = { plan - .collect { case StreamingExecutionRelation(s, _) => s } + .collect { + case StreamingExecutionRelation(s, _) => s + case DataSourceV2Relation(_, r) => r + } .zipWithIndex .find(_._1 == source) .map(_._2) @@ -613,9 +619,13 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be findSourceIndex(query.logicalPlan) }.orElse { findSourceIndex(stream.logicalPlan) + }.orElse { + queryToUse.flatMap { q => + findSourceIndex(q.lastExecution.logical) + } }.getOrElse { throw new IllegalArgumentException( - "Could find index of the source to which data was added") + "Could not find index of the source to which data was added") } // Store the expected offset of added data to wait for it later From 186bf8fb2e9ff8a80f3f6bcb5f2a0327fa79a1c9 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Thu, 11 Jan 2018 13:57:15 -0800 Subject: [PATCH 0071/2461] [SPARK-23046][ML][SPARKR] Have RFormula include VectorSizeHint in pipeline ## What changes were proposed in this pull request? Including VectorSizeHint in RFormula piplelines will allow them to be applied to streaming dataframes. ## How was this patch tested? Unit tests. Author: Bago Amirbekian Closes #20238 from MrBago/rFormulaVectorSize. --- R/pkg/R/mllib_utils.R | 1 + .../apache/spark/ml/feature/RFormula.scala | 18 +++++++-- .../spark/ml/feature/RFormulaSuite.scala | 37 ++++++++++++++++--- 3 files changed, 48 insertions(+), 8 deletions(-) diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R index a53c92c2c4815..23dda42c325be 100644 --- a/R/pkg/R/mllib_utils.R +++ b/R/pkg/R/mllib_utils.R @@ -130,3 +130,4 @@ read.ml <- function(path) { stop("Unsupported model: ", jobj) } } + diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 7da3339f8b487..f384ffbf578bc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} import org.apache.spark.ml.attribute.AttributeGroup -import org.apache.spark.ml.linalg.VectorUDT +import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol} import org.apache.spark.ml.util._ @@ -210,8 +210,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) // First we index each string column referenced by the input terms. val indexed: Map[String, String] = resolvedFormula.terms.flatten.distinct.map { term => - dataset.schema(term) match { - case column if column.dataType == StringType => + dataset.schema(term).dataType match { + case _: StringType => val indexCol = tmpColumn("stridx") encoderStages += new StringIndexer() .setInputCol(term) @@ -220,6 +220,18 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) .setHandleInvalid($(handleInvalid)) prefixesToRewrite(indexCol + "_") = term + "_" (term, indexCol) + case _: VectorUDT => + val group = AttributeGroup.fromStructField(dataset.schema(term)) + val size = if (group.size < 0) { + dataset.select(term).first().getAs[Vector](0).size + } else { + group.size + } + encoderStages += new VectorSizeHint(uid) + .setHandleInvalid("optimistic") + .setInputCol(term) + .setSize(size) + (term, term) case _ => (term, term) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 5d09c90ec6dfa..f3f4b5a3d0233 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -17,15 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.sql.{DataFrame, Encoder, Row} import org.apache.spark.sql.types.DoubleType -class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class RFormulaSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -548,4 +548,31 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(result3.collect() === expected3.collect()) assert(result4.collect() === expected4.collect()) } + + test("Use Vectors as inputs to formula.") { + val original = Seq( + (1, 4, Vectors.dense(0.0, 0.0, 4.0)), + (2, 4, Vectors.dense(1.0, 0.0, 4.0)), + (3, 5, Vectors.dense(1.0, 0.0, 5.0)), + (4, 5, Vectors.dense(0.0, 1.0, 5.0)) + ).toDF("id", "a", "b") + val formula = new RFormula().setFormula("id ~ a + b") + val (first +: rest) = Seq("id", "a", "b", "features", "label") + testTransformer[(Int, Int, Vector)](original, formula.fit(original), first, rest: _*) { + case Row(id: Int, a: Int, b: Vector, features: Vector, label: Double) => + assert(label === id) + assert(features.toArray === a +: b.toArray) + } + + val group = new AttributeGroup("b", 3) + val vectorColWithMetadata = original("b").as("b", group.toMetadata()) + val dfWithMetadata = original.withColumn("b", vectorColWithMetadata) + val model = formula.fit(dfWithMetadata) + // model should work even when applied to dataframe without metadata. + testTransformer[(Int, Int, Vector)](original, model, first, rest: _*) { + case Row(id: Int, a: Int, b: Vector, features: Vector, label: Double) => + assert(label === id) + assert(features.toArray === a +: b.toArray) + } + } } From b5042d75c2faa5f15bc1e160d75f06dfdd6eea37 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 11 Jan 2018 16:20:30 -0800 Subject: [PATCH 0072/2461] [SPARK-23008][ML] OnehotEncoderEstimator python API ## What changes were proposed in this pull request? OnehotEncoderEstimator python API. ## How was this patch tested? doctest Author: WeichenXu Closes #20209 from WeichenXu123/ohe_py. --- python/pyspark/ml/feature.py | 113 ++++++++++++++++++ .../ml/param/_shared_params_code_gen.py | 1 + python/pyspark/ml/param/shared.py | 23 ++++ 3 files changed, 137 insertions(+) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 13bf95cce40be..b963e45dd7cff 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -45,6 +45,7 @@ 'NGram', 'Normalizer', 'OneHotEncoder', + 'OneHotEncoderEstimator', 'OneHotEncoderModel', 'PCA', 'PCAModel', 'PolynomialExpansion', 'QuantileDiscretizer', @@ -1641,6 +1642,118 @@ def getDropLast(self): return self.getOrDefault(self.dropLast) +@inherit_doc +class OneHotEncoderEstimator(JavaEstimator, HasInputCols, HasOutputCols, HasHandleInvalid, + JavaMLReadable, JavaMLWritable): + """ + A one-hot encoder that maps a column of category indices to a column of binary vectors, with + at most a single one-value per row that indicates the input category index. + For example with 5 categories, an input value of 2.0 would map to an output vector of + `[0.0, 0.0, 1.0, 0.0]`. + The last category is not included by default (configurable via `dropLast`), + because it makes the vector entries sum up to one, and hence linearly dependent. + So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. + + Note: This is different from scikit-learn's OneHotEncoder, which keeps all categories. + The output vectors are sparse. + + When `handleInvalid` is configured to 'keep', an extra "category" indicating invalid values is + added as last category. So when `dropLast` is true, invalid values are encoded as all-zeros + vector. + + Note: When encoding multi-column by using `inputCols` and `outputCols` params, input/output + cols come in pairs, specified by the order in the arrays, and each pair is treated + independently. + + See `StringIndexer` for converting categorical values into category indices + + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame([(0.0,), (1.0,), (2.0,)], ["input"]) + >>> ohe = OneHotEncoderEstimator(inputCols=["input"], outputCols=["output"]) + >>> model = ohe.fit(df) + >>> model.transform(df).head().output + SparseVector(2, {0: 1.0}) + >>> ohePath = temp_path + "/oheEstimator" + >>> ohe.save(ohePath) + >>> loadedOHE = OneHotEncoderEstimator.load(ohePath) + >>> loadedOHE.getInputCols() == ohe.getInputCols() + True + >>> modelPath = temp_path + "/ohe-model" + >>> model.save(modelPath) + >>> loadedModel = OneHotEncoderModel.load(modelPath) + >>> loadedModel.categorySizes == model.categorySizes + True + + .. versionadded:: 2.3.0 + """ + + handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data during " + + "transform(). Options are 'keep' (invalid data presented as an extra " + + "categorical feature) or error (throw an error). Note that this Param " + + "is only used during transform; during fitting, invalid data will " + + "result in an error.", + typeConverter=TypeConverters.toString) + + dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category", + typeConverter=TypeConverters.toBoolean) + + @keyword_only + def __init__(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True): + """ + __init__(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True) + """ + super(OneHotEncoderEstimator, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.feature.OneHotEncoderEstimator", self.uid) + self._setDefault(handleInvalid="error", dropLast=True) + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.3.0") + def setParams(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True): + """ + setParams(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True) + Sets params for this OneHotEncoderEstimator. + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + @since("2.3.0") + def setDropLast(self, value): + """ + Sets the value of :py:attr:`dropLast`. + """ + return self._set(dropLast=value) + + @since("2.3.0") + def getDropLast(self): + """ + Gets the value of dropLast or its default value. + """ + return self.getOrDefault(self.dropLast) + + def _create_model(self, java_model): + return OneHotEncoderModel(java_model) + + +class OneHotEncoderModel(JavaModel, JavaMLReadable, JavaMLWritable): + """ + Model fitted by :py:class:`OneHotEncoderEstimator`. + + .. versionadded:: 2.3.0 + """ + + @property + @since("2.3.0") + def categorySizes(self): + """ + Original number of categories for each feature being encoded. + The array contains one value for each input column, in order. + """ + return self._call_java("categorySizes") + + @inherit_doc class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 1d0f60acc6983..db951d81de1e7 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -119,6 +119,7 @@ def get$Name(self): ("inputCol", "input column name.", None, "TypeConverters.toString"), ("inputCols", "input column names.", None, "TypeConverters.toListString"), ("outputCol", "output column name.", "self.uid + '__output'", "TypeConverters.toString"), + ("outputCols", "output column names.", None, "TypeConverters.toListString"), ("numFeatures", "number of features.", None, "TypeConverters.toInt"), ("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " + "E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: " + diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 813f7a59f3fd1..474c38764e5a1 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -256,6 +256,29 @@ def getOutputCol(self): return self.getOrDefault(self.outputCol) +class HasOutputCols(Params): + """ + Mixin for param outputCols: output column names. + """ + + outputCols = Param(Params._dummy(), "outputCols", "output column names.", typeConverter=TypeConverters.toListString) + + def __init__(self): + super(HasOutputCols, self).__init__() + + def setOutputCols(self, value): + """ + Sets the value of :py:attr:`outputCols`. + """ + return self._set(outputCols=value) + + def getOutputCols(self): + """ + Gets the value of outputCols or its default value. + """ + return self.getOrDefault(self.outputCols) + + class HasNumFeatures(Params): """ Mixin for param numFeatures: number of features. From cbe7c6fbf9dc2fc422b93b3644c40d449a869eea Mon Sep 17 00:00:00 2001 From: ho3rexqj Date: Fri, 12 Jan 2018 15:27:00 +0800 Subject: [PATCH 0073/2461] [SPARK-22986][CORE] Use a cache to avoid instantiating multiple instances of broadcast variable values When resources happen to be constrained on an executor the first time a broadcast variable is instantiated it is persisted to disk by the BlockManager. Consequently, every subsequent call to TorrentBroadcast::readBroadcastBlock from other instances of that broadcast variable spawns another instance of the underlying value. That is, broadcast variables are spawned once per executor **unless** memory is constrained, in which case every instance of a broadcast variable is provided with a unique copy of the underlying value. This patch fixes the above by explicitly caching the underlying values using weak references in a ReferenceMap. Author: ho3rexqj Closes #20183 from ho3rexqj/fix/cache-broadcast-values. --- .../spark/broadcast/BroadcastManager.scala | 6 ++ .../spark/broadcast/TorrentBroadcast.scala | 72 +++++++++++-------- .../spark/broadcast/BroadcastSuite.scala | 34 +++++++++ 3 files changed, 83 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index e88988fe03b2e..8d7a4a353a792 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -21,6 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import scala.reflect.ClassTag +import org.apache.commons.collections.map.{AbstractReferenceMap, ReferenceMap} + import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.Logging @@ -52,6 +54,10 @@ private[spark] class BroadcastManager( private val nextBroadcastId = new AtomicLong(0) + private[broadcast] val cachedValues = { + new ReferenceMap(AbstractReferenceMap.HARD, AbstractReferenceMap.WEAK) + } + def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = { broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 7aecd3c9668ea..e125095cf4777 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -206,36 +206,50 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) private def readBroadcastBlock(): T = Utils.tryOrIOException { TorrentBroadcast.synchronized { - setConf(SparkEnv.get.conf) - val blockManager = SparkEnv.get.blockManager - blockManager.getLocalValues(broadcastId) match { - case Some(blockResult) => - if (blockResult.data.hasNext) { - val x = blockResult.data.next().asInstanceOf[T] - releaseLock(broadcastId) - x - } else { - throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId") - } - case None => - logInfo("Started reading broadcast variable " + id) - val startTimeMs = System.currentTimeMillis() - val blocks = readBlocks() - logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs)) - - try { - val obj = TorrentBroadcast.unBlockifyObject[T]( - blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec) - // Store the merged copy in BlockManager so other tasks on this executor don't - // need to re-fetch it. - val storageLevel = StorageLevel.MEMORY_AND_DISK - if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { - throw new SparkException(s"Failed to store $broadcastId in BlockManager") + val broadcastCache = SparkEnv.get.broadcastManager.cachedValues + + Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse { + setConf(SparkEnv.get.conf) + val blockManager = SparkEnv.get.blockManager + blockManager.getLocalValues(broadcastId) match { + case Some(blockResult) => + if (blockResult.data.hasNext) { + val x = blockResult.data.next().asInstanceOf[T] + releaseLock(broadcastId) + + if (x != null) { + broadcastCache.put(broadcastId, x) + } + + x + } else { + throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId") } - obj - } finally { - blocks.foreach(_.dispose()) - } + case None => + logInfo("Started reading broadcast variable " + id) + val startTimeMs = System.currentTimeMillis() + val blocks = readBlocks() + logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs)) + + try { + val obj = TorrentBroadcast.unBlockifyObject[T]( + blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec) + // Store the merged copy in BlockManager so other tasks on this executor don't + // need to re-fetch it. + val storageLevel = StorageLevel.MEMORY_AND_DISK + if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { + throw new SparkException(s"Failed to store $broadcastId in BlockManager") + } + + if (obj != null) { + broadcastCache.put(broadcastId, obj) + } + + obj + } finally { + blocks.foreach(_.dispose()) + } + } } } } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 159629825c677..9ad2e9a5e74ac 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -153,6 +153,40 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio assert(broadcast.value.sum === 10) } + test("One broadcast value instance per executor") { + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName("test") + + sc = new SparkContext(conf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val instances = sc.parallelize(1 to 10) + .map(x => System.identityHashCode(broadcast.value)) + .collect() + .toSet + + assert(instances.size === 1) + } + + test("One broadcast value instance per executor when memory is constrained") { + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName("test") + .set("spark.memory.useLegacyMode", "true") + .set("spark.storage.memoryFraction", "0.0") + + sc = new SparkContext(conf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val instances = sc.parallelize(1 to 10) + .map(x => System.identityHashCode(broadcast.value)) + .collect() + .toSet + + assert(instances.size === 1) + } + /** * Verify the persistence of state associated with a TorrentBroadcast in a local-cluster. * From a7d98d53ceaf69cabaecc6c9113f17438c4e61f6 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 12 Jan 2018 11:27:02 +0200 Subject: [PATCH 0074/2461] [SPARK-23008][ML][FOLLOW-UP] mark OneHotEncoder python API deprecated ## What changes were proposed in this pull request? mark OneHotEncoder python API deprecated ## How was this patch tested? N/A Author: WeichenXu Closes #20241 from WeichenXu123/mark_ohe_deprecated. --- python/pyspark/ml/feature.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index b963e45dd7cff..eb79b193103e2 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1578,6 +1578,9 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, .. note:: This is different from scikit-learn's OneHotEncoder, which keeps all categories. The output vectors are sparse. + .. note:: Deprecated in 2.3.0. :py:class:`OneHotEncoderEstimator` will be renamed to + :py:class:`OneHotEncoder` and this :py:class:`OneHotEncoder` will be removed in 3.0.0. + .. seealso:: :py:class:`StringIndexer` for converting categorical values into From 505086806997b4331d4a8c2fc5e08345d869a23c Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 12 Jan 2018 18:04:44 +0800 Subject: [PATCH 0075/2461] [SPARK-23025][SQL] Support Null type in scala reflection ## What changes were proposed in this pull request? Add support for `Null` type in the `schemaFor` method for Scala reflection. ## How was this patch tested? Added UT Author: Marco Gaido Closes #20219 from mgaido91/SPARK-23025. --- .../org/apache/spark/sql/catalyst/ScalaReflection.scala | 4 ++++ .../apache/spark/sql/catalyst/ScalaReflectionSuite.scala | 9 +++++++++ .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 5 +++++ 3 files changed, 18 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 65040f1af4b04..9a4bf0075a178 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -63,6 +63,7 @@ object ScalaReflection extends ScalaReflection { private def dataTypeFor(tpe: `Type`): DataType = cleanUpReflectionObjects { tpe.dealias match { + case t if t <:< definitions.NullTpe => NullType case t if t <:< definitions.IntTpe => IntegerType case t if t <:< definitions.LongTpe => LongType case t if t <:< definitions.DoubleTpe => DoubleType @@ -712,6 +713,9 @@ object ScalaReflection extends ScalaReflection { /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = cleanUpReflectionObjects { tpe.dealias match { + // this must be the first case, since all objects in scala are instances of Null, therefore + // Null type would wrongly match the first of them, which is Option as of now + case t if t <:< definitions.NullTpe => Schema(NullType, nullable = true) case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() Schema(udt, nullable = true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 23e866cdf4917..8c3db48a01f12 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -356,4 +356,13 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(deserializerFor[Int].isInstanceOf[AssertNotNull]) assert(!deserializerFor[String].isInstanceOf[AssertNotNull]) } + + test("SPARK-23025: schemaFor should support Null type") { + val schema = schemaFor[(Int, Null)] + assert(schema === Schema( + StructType(Seq( + StructField("_1", IntegerType, nullable = false), + StructField("_2", NullType, nullable = true))), + nullable = true)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d535896723bd5..54893c184642b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1441,6 +1441,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(e.getCause.isInstanceOf[NullPointerException]) } } + + test("SPARK-23025: Add support for null type in scala reflection") { + val data = Seq(("a", null)) + checkDataset(data.toDS(), data: _*) + } } case class SingleData(id: Int) From f5300fbbe370af3741560f67bfb5ae6f0b0f7bb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20Beaup=C3=A8re?= Date: Fri, 12 Jan 2018 08:29:46 -0600 Subject: [PATCH 0076/2461] Update rdd-programming-guide.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Small typing correction - double word ## How was this patch tested? Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Matthias Beaupère Closes #20212 from matthiasbe/patch-1. --- docs/rdd-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index 29af159510e46..2e29aef7f21a2 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -91,7 +91,7 @@ so C libraries like NumPy can be used. It also works with PyPy 2.3+. Python 2.6 support was removed in Spark 2.2.0. -Spark applications in Python can either be run with the `bin/spark-submit` script which includes Spark at runtime, or by including including it in your setup.py as: +Spark applications in Python can either be run with the `bin/spark-submit` script which includes Spark at runtime, or by including it in your setup.py as: {% highlight python %} install_requires=[ From 651f76153f5e9b185aaf593161d40cabe7994fea Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 13 Jan 2018 00:37:59 +0800 Subject: [PATCH 0077/2461] [SPARK-23028] Bump master branch version to 2.4.0-SNAPSHOT ## What changes were proposed in this pull request? This patch bumps the master branch version to `2.4.0-SNAPSHOT`. ## How was this patch tested? N/A Author: gatorsmile Closes #20222 from gatorsmile/bump24. --- R/pkg/DESCRIPTION | 2 +- assembly/pom.xml | 2 +- common/kvstore/pom.xml | 2 +- common/network-common/pom.xml | 2 +- common/network-shuffle/pom.xml | 2 +- common/network-yarn/pom.xml | 2 +- common/sketch/pom.xml | 2 +- common/tags/pom.xml | 2 +- common/unsafe/pom.xml | 2 +- core/pom.xml | 2 +- dev/run-tests-jenkins.py | 4 ++-- docs/_config.yml | 4 ++-- examples/pom.xml | 2 +- external/docker-integration-tests/pom.xml | 2 +- external/flume-assembly/pom.xml | 2 +- external/flume-sink/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka-0-10-assembly/pom.xml | 2 +- external/kafka-0-10-sql/pom.xml | 2 +- external/kafka-0-10/pom.xml | 2 +- external/kafka-0-8-assembly/pom.xml | 2 +- external/kafka-0-8/pom.xml | 2 +- external/kinesis-asl-assembly/pom.xml | 2 +- external/kinesis-asl/pom.xml | 2 +- external/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- hadoop-cloud/pom.xml | 2 +- launcher/pom.xml | 2 +- mllib-local/pom.xml | 2 +- mllib/pom.xml | 2 +- pom.xml | 2 +- project/MimaExcludes.scala | 5 +++++ python/pyspark/version.py | 2 +- repl/pom.xml | 2 +- resource-managers/kubernetes/core/pom.xml | 2 +- resource-managers/mesos/pom.xml | 2 +- resource-managers/yarn/pom.xml | 2 +- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- sql/hive-thriftserver/pom.xml | 2 +- sql/hive/pom.xml | 2 +- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- 43 files changed, 49 insertions(+), 44 deletions(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 6d46c31906260..855eb5bf77f16 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,6 +1,6 @@ Package: SparkR Type: Package -Version: 2.3.0 +Version: 2.4.0 Title: R Frontend for Apache Spark Description: Provides an R Frontend for Apache Spark. Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), diff --git a/assembly/pom.xml b/assembly/pom.xml index b3b4239771bc3..a207dae5a74ff 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml index cf93d41cd77cf..8c148359c3029 100644 --- a/common/kvstore/pom.xml +++ b/common/kvstore/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 18cbdadd224ab..8ca7733507f1b 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 9968480ab7658..05335df61a664 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index ec2db6e5bb88c..564e6583c909e 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 2d59c71cc3757..2f04abe8c7e88 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/tags/pom.xml b/common/tags/pom.xml index f7e586ee777e1..ba127408e1c59 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index a3772a2620088..1527854730394 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 0a5bd958fc9c5..9258a856028a0 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 914eb93622d51..3960a0de62530 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -181,8 +181,8 @@ def main(): short_commit_hash = ghprb_actual_commit[0:7] # format: http://linux.die.net/man/1/timeout - # must be less than the timeout configured on Jenkins (currently 300m) - tests_timeout = "250m" + # must be less than the timeout configured on Jenkins (currently 350m) + tests_timeout = "300m" # Array to capture all test names to run on the pull request. These tests are represented # by their file equivalents in the dev/tests/ directory. diff --git a/docs/_config.yml b/docs/_config.yml index dcc211204d766..095fadb93fe5d 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 2.3.0-SNAPSHOT -SPARK_VERSION_SHORT: 2.3.0 +SPARK_VERSION: 2.4.0-SNAPSHOT +SPARK_VERSION_SHORT: 2.4.0 SCALA_BINARY_VERSION: "2.11" SCALA_VERSION: "2.11.8" MESOS_VERSION: 1.0.0 diff --git a/examples/pom.xml b/examples/pom.xml index 1791dbaad775e..868110b8e35ef 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 485b562dce990..431339d412194 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index 71016bc645ca7..7cd1ec4c9c09a 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 12630840e79dc..f810aa80e8780 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 87a09642405a7..498e88f665eb5 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index d6f97316b326a..a742b8d6dbddb 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 0c9f0aa765a39..16bbc6db641ca 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index 6eb7ba5f0092d..3b124b2a69d50 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml index 786349474389b..41bc8b3e3ee1f 100644 --- a/external/kafka-0-8-assembly/pom.xml +++ b/external/kafka-0-8-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8/pom.xml b/external/kafka-0-8/pom.xml index 849c8b465f99e..6d1c4789f382d 100644 --- a/external/kafka-0-8/pom.xml +++ b/external/kafka-0-8/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index 48783d65826aa..37c7d1e604ec5 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index 40a751a652fa9..4915893965595 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index 36d555066b181..027157e53d511 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index cb30e4a4af4bc..fbe77fcb958d5 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index aa36dd4774d86..8e424b1c50236 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/launcher/pom.xml b/launcher/pom.xml index e9b46c4cf0ffa..912eb6b6d2a08 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index 043d13609fd26..53286fe93478d 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index a906c9e02cd4c..f07d7f24fd312 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/pom.xml b/pom.xml index 1b37164376460..d14594aa4ccb0 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 3b452f35c5ec1..32eb31f495979 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,6 +34,10 @@ import com.typesafe.tools.mima.core.ProblemFilters._ */ object MimaExcludes { + // Exclude rules for 2.4.x + lazy val v24excludes = v23excludes ++ Seq( + ) + // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( // [SPARK-22897] Expose stageAttemptId in TaskContext @@ -1082,6 +1086,7 @@ object MimaExcludes { } def excludes(version: String) = version match { + case v if v.startsWith("2.4") => v24excludes case v if v.startsWith("2.3") => v23excludes case v if v.startsWith("2.2") => v22excludes case v if v.startsWith("2.1") => v21excludes diff --git a/python/pyspark/version.py b/python/pyspark/version.py index 12dd53b9d2902..b9c2c4ced71d5 100644 --- a/python/pyspark/version.py +++ b/python/pyspark/version.py @@ -16,4 +16,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.3.0.dev0" +__version__ = "2.4.0.dev0" diff --git a/repl/pom.xml b/repl/pom.xml index 1cb0098d0eca3..6f4a863c48bc7 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index 7d35aea8a4142..a62f271273465 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../../pom.xml diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index 70d0c1750b14e..3995d0afeb5f4 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index 43a7ce95bd3de..37e25ceecb883 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 9e2ced30407d4..839b929abd3cb 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 93010c606cf45..744daa6079779 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 3135a8a275dae..9f247f9224c75 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 66fad85ea0263..c55ba32fa458c 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index fea882ad11230..4497e53b65984 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 37427e8da62d8..242219e29f50f 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml From 7bd14cfd40500a0b6462cda647bdbb686a430328 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 12 Jan 2018 10:18:42 -0800 Subject: [PATCH 0078/2461] [MINOR][BUILD] Fix Java linter errors ## What changes were proposed in this pull request? This PR cleans up the java-lint errors (for v2.3.0-rc1 tag). Hopefully, this will be the final one. ``` $ dev/lint-java Using `mvn` from path: /usr/local/bin/mvn Checkstyle checks failed at following occurrences: [ERROR] src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java:[85] (sizes) LineLength: Line is longer than 100 characters (found 101). [ERROR] src/main/java/org/apache/spark/launcher/InProcessAppHandle.java:[20,8] (imports) UnusedImports: Unused import - java.io.IOException. [ERROR] src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java:[41,9] (modifier) ModifierOrder: 'private' modifier out of order with the JLS suggestions. [ERROR] src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java:[464] (sizes) LineLength: Line is longer than 100 characters (found 102). ``` ## How was this patch tested? Manual. ``` $ dev/lint-java Using `mvn` from path: /usr/local/bin/mvn Checkstyle checks passed. ``` Author: Dongjoon Hyun Closes #20242 from dongjoon-hyun/fix_lint_java_2.3_rc1. --- .../org/apache/spark/unsafe/memory/HeapMemoryAllocator.java | 3 ++- .../java/org/apache/spark/launcher/InProcessAppHandle.java | 1 - .../spark/sql/execution/datasources/orc/OrcColumnVector.java | 2 +- .../java/test/org/apache/spark/sql/JavaDataFrameSuite.java | 3 ++- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index 3acfe3696cb1e..a9603c1aba051 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -82,7 +82,8 @@ public void free(MemoryBlock memory) { "page has already been freed"; assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : - "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator free()"; + "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator " + + "free()"; final long size = memory.size(); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index 0d6a73a3da3ed..acd64c962604f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -17,7 +17,6 @@ package org.apache.spark.launcher; -import java.io.IOException; import java.lang.reflect.Method; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index f94c55d860304..b6e792274da11 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -38,7 +38,7 @@ public class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVecto private BytesColumnVector bytesData; private DecimalColumnVector decimalData; private TimestampColumnVector timestampData; - final private boolean isTimestamp; + private final boolean isTimestamp; private int batchSize; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 4f8a31f185724..69a2904f5f3fe 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -461,7 +461,8 @@ public void testCircularReferenceBean() { public void testUDF() { UserDefinedFunction foo = udf((Integer i, String s) -> i.toString() + s, DataTypes.StringType); Dataset df = spark.table("testData").select(foo.apply(col("key"), col("value"))); - String[] result = df.collectAsList().stream().map(row -> row.getString(0)).toArray(String[]::new); + String[] result = df.collectAsList().stream().map(row -> row.getString(0)) + .toArray(String[]::new); String[] expected = spark.table("testData").collectAsList().stream() .map(row -> row.get(0).toString() + row.getString(1)).toArray(String[]::new); Assert.assertArrayEquals(expected, result); From 54277398afbde92a38ba2802f4a7a3e5910533de Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 12 Jan 2018 11:25:37 -0800 Subject: [PATCH 0079/2461] [SPARK-22975][SS] MetricsReporter should not throw exception when there was no progress reported ## What changes were proposed in this pull request? `MetricsReporter ` assumes that there has been some progress for the query, ie. `lastProgress` is not null. If this is not true, as it might happen in particular conditions, a `NullPointerException` can be thrown. The PR checks whether there is a `lastProgress` and if this is not true, it returns a default value for the metrics. ## How was this patch tested? added UT Author: Marco Gaido Closes #20189 from mgaido91/SPARK-22975. --- .../execution/streaming/MetricsReporter.scala | 21 ++++++++--------- .../sql/streaming/StreamingQuerySuite.scala | 23 +++++++++++++++++++ 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala index b84e6ce64c611..66b11ecddf233 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala @@ -17,15 +17,11 @@ package org.apache.spark.sql.execution.streaming -import java.{util => ju} - -import scala.collection.mutable - import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.{Source => CodahaleSource} -import org.apache.spark.util.Clock +import org.apache.spark.sql.streaming.StreamingQueryProgress /** * Serves metrics from a [[org.apache.spark.sql.streaming.StreamingQuery]] to @@ -39,14 +35,17 @@ class MetricsReporter( // Metric names should not have . in them, so that all the metrics of a query are identified // together in Ganglia as a single metric group - registerGauge("inputRate-total", () => stream.lastProgress.inputRowsPerSecond) - registerGauge("processingRate-total", () => stream.lastProgress.processedRowsPerSecond) - registerGauge("latency", () => stream.lastProgress.durationMs.get("triggerExecution").longValue()) - - private def registerGauge[T](name: String, f: () => T)(implicit num: Numeric[T]): Unit = { + registerGauge("inputRate-total", _.inputRowsPerSecond, 0.0) + registerGauge("processingRate-total", _.processedRowsPerSecond, 0.0) + registerGauge("latency", _.durationMs.get("triggerExecution").longValue(), 0L) + + private def registerGauge[T]( + name: String, + f: StreamingQueryProgress => T, + default: T): Unit = { synchronized { metricRegistry.register(name, new Gauge[T] { - override def getValue: T = f() + override def getValue: T = Option(stream.lastProgress).map(f).getOrElse(default) }) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 2fa4595dab376..76201c63a2701 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -424,6 +424,29 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } + test("SPARK-22975: MetricsReporter defaults when there was no progress reported") { + withSQLConf("spark.sql.streaming.metricsEnabled" -> "true") { + BlockingSource.latch = new CountDownLatch(1) + withTempDir { tempDir => + val sq = spark.readStream + .format("org.apache.spark.sql.streaming.util.BlockingSource") + .load() + .writeStream + .format("org.apache.spark.sql.streaming.util.BlockingSource") + .option("checkpointLocation", tempDir.toString) + .start() + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery + + val gauges = sq.streamMetrics.metricRegistry.getGauges + assert(gauges.get("latency").getValue.asInstanceOf[Long] == 0) + assert(gauges.get("processingRate-total").getValue.asInstanceOf[Double] == 0.0) + assert(gauges.get("inputRate-total").getValue.asInstanceOf[Double] == 0.0) + sq.stop() + } + } + } + test("input row calculation with mixed batch and streaming sources") { val streamingTriggerDF = spark.createDataset(1 to 10).toDF val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") From 55dbfbca37ce4c05f83180777ba3d4fe2d96a02e Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Fri, 12 Jan 2018 15:00:00 -0800 Subject: [PATCH 0080/2461] Revert "[SPARK-22908] Add kafka source and sink for continuous processing." This reverts commit 6f7aaed805070d29dcba32e04ca7a1f581fa54b9. --- .../sql/kafka010/KafkaContinuousReader.scala | 232 --------- .../sql/kafka010/KafkaContinuousWriter.scala | 119 ----- .../sql/kafka010/KafkaOffsetReader.scala | 21 +- .../spark/sql/kafka010/KafkaSource.scala | 17 +- .../sql/kafka010/KafkaSourceOffset.scala | 7 +- .../sql/kafka010/KafkaSourceProvider.scala | 105 +--- .../spark/sql/kafka010/KafkaWriteTask.scala | 71 +-- .../spark/sql/kafka010/KafkaWriter.scala | 5 +- .../kafka010/KafkaContinuousSinkSuite.scala | 474 ------------------ .../kafka010/KafkaContinuousSourceSuite.scala | 96 ---- .../sql/kafka010/KafkaContinuousTest.scala | 64 --- .../spark/sql/kafka010/KafkaSourceSuite.scala | 470 ++++++++--------- .../apache/spark/sql/DataFrameReader.scala | 32 +- .../apache/spark/sql/DataFrameWriter.scala | 25 +- .../datasources/v2/WriteToDataSourceV2.scala | 8 +- .../execution/streaming/StreamExecution.scala | 15 +- .../ContinuousDataSourceRDDIter.scala | 3 +- .../continuous/ContinuousExecution.scala | 67 +-- .../continuous/EpochCoordinator.scala | 21 +- .../sql/streaming/DataStreamWriter.scala | 26 +- .../spark/sql/streaming/StreamTest.scala | 36 +- 21 files changed, 383 insertions(+), 1531 deletions(-) delete mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala delete mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala delete mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala delete mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala delete mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala deleted file mode 100644 index 928379544758c..0000000000000 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ /dev/null @@ -1,232 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import java.{util => ju} - -import org.apache.kafka.clients.consumer.ConsumerRecord -import org.apache.kafka.common.TopicPartition - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.types.UTF8String - -/** - * A [[ContinuousReader]] for data from kafka. - * - * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be - * read by per-task consumers generated later. - * @param kafkaParams String params for per-task Kafka consumers. - * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceV2Options]] params which - * are not Kafka consumer params. - * @param metadataPath Path to a directory this reader can use for writing metadata. - * @param initialOffsets The Kafka offsets to start reading data at. - * @param failOnDataLoss Flag indicating whether reading should fail in data loss - * scenarios, where some offsets after the specified initial ones can't be - * properly read. - */ -class KafkaContinuousReader( - offsetReader: KafkaOffsetReader, - kafkaParams: ju.Map[String, Object], - sourceOptions: Map[String, String], - metadataPath: String, - initialOffsets: KafkaOffsetRangeLimit, - failOnDataLoss: Boolean) - extends ContinuousReader with SupportsScanUnsafeRow with Logging { - - private lazy val session = SparkSession.getActiveSession.get - private lazy val sc = session.sparkContext - - // Initialized when creating read tasks. If this diverges from the partitions at the latest - // offsets, we need to reconfigure. - // Exposed outside this object only for unit tests. - private[sql] var knownPartitions: Set[TopicPartition] = _ - - override def readSchema: StructType = KafkaOffsetReader.kafkaSchema - - private var offset: Offset = _ - override def setOffset(start: ju.Optional[Offset]): Unit = { - offset = start.orElse { - val offsets = initialOffsets match { - case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) - case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) - case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) - } - logInfo(s"Initial offsets: $offsets") - offsets - } - } - - override def getStartOffset(): Offset = offset - - override def deserializeOffset(json: String): Offset = { - KafkaSourceOffset(JsonUtils.partitionOffsets(json)) - } - - override def createUnsafeRowReadTasks(): ju.List[ReadTask[UnsafeRow]] = { - import scala.collection.JavaConverters._ - - val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) - - val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet - val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) - val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) - - val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) - if (deletedPartitions.nonEmpty) { - reportDataLoss(s"Some partitions were deleted: $deletedPartitions") - } - - val startOffsets = newPartitionOffsets ++ - oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) - knownPartitions = startOffsets.keySet - - startOffsets.toSeq.map { - case (topicPartition, start) => - KafkaContinuousReadTask( - topicPartition, start, kafkaParams, failOnDataLoss) - .asInstanceOf[ReadTask[UnsafeRow]] - }.asJava - } - - /** Stop this source and free any resources it has allocated. */ - def stop(): Unit = synchronized { - offsetReader.close() - } - - override def commit(end: Offset): Unit = {} - - override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { - val mergedMap = offsets.map { - case KafkaSourcePartitionOffset(p, o) => Map(p -> o) - }.reduce(_ ++ _) - KafkaSourceOffset(mergedMap) - } - - override def needsReconfiguration(): Boolean = { - knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions - } - - override def toString(): String = s"KafkaSource[$offsetReader]" - - /** - * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. - * Otherwise, just log a warning. - */ - private def reportDataLoss(message: String): Unit = { - if (failOnDataLoss) { - throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE") - } else { - logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") - } - } -} - -/** - * A read task for continuous Kafka processing. This will be serialized and transformed into a - * full reader on executors. - * - * @param topicPartition The (topic, partition) pair this task is responsible for. - * @param startOffset The offset to start reading from within the partition. - * @param kafkaParams Kafka consumer params to use. - * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets - * are skipped. - */ -case class KafkaContinuousReadTask( - topicPartition: TopicPartition, - startOffset: Long, - kafkaParams: ju.Map[String, Object], - failOnDataLoss: Boolean) extends ReadTask[UnsafeRow] { - override def createDataReader(): KafkaContinuousDataReader = { - new KafkaContinuousDataReader(topicPartition, startOffset, kafkaParams, failOnDataLoss) - } -} - -/** - * A per-task data reader for continuous Kafka processing. - * - * @param topicPartition The (topic, partition) pair this data reader is responsible for. - * @param startOffset The offset to start reading from within the partition. - * @param kafkaParams Kafka consumer params to use. - * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets - * are skipped. - */ -class KafkaContinuousDataReader( - topicPartition: TopicPartition, - startOffset: Long, - kafkaParams: ju.Map[String, Object], - failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { - private val topic = topicPartition.topic - private val kafkaPartition = topicPartition.partition - private val consumer = CachedKafkaConsumer.createUncached(topic, kafkaPartition, kafkaParams) - - private val sharedRow = new UnsafeRow(7) - private val bufferHolder = new BufferHolder(sharedRow) - private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) - - private var nextKafkaOffset = startOffset - private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _ - - override def next(): Boolean = { - var r: ConsumerRecord[Array[Byte], Array[Byte]] = null - while (r == null) { - r = consumer.get( - nextKafkaOffset, - untilOffset = Long.MaxValue, - pollTimeoutMs = Long.MaxValue, - failOnDataLoss) - } - nextKafkaOffset = r.offset + 1 - currentRecord = r - true - } - - override def get(): UnsafeRow = { - bufferHolder.reset() - - if (currentRecord.key == null) { - rowWriter.setNullAt(0) - } else { - rowWriter.write(0, currentRecord.key) - } - rowWriter.write(1, currentRecord.value) - rowWriter.write(2, UTF8String.fromString(currentRecord.topic)) - rowWriter.write(3, currentRecord.partition) - rowWriter.write(4, currentRecord.offset) - rowWriter.write(5, - DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(currentRecord.timestamp))) - rowWriter.write(6, currentRecord.timestampType.id) - sharedRow.setTotalSize(bufferHolder.totalSize) - sharedRow - } - - override def getOffset(): KafkaSourcePartitionOffset = { - KafkaSourcePartitionOffset(topicPartition, nextKafkaOffset) - } - - override def close(): Unit = { - consumer.close() - } -} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala deleted file mode 100644 index 9843f469c5b25..0000000000000 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import org.apache.kafka.clients.producer.{Callback, ProducerRecord, RecordMetadata} -import scala.collection.JavaConverters._ - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} -import org.apache.spark.sql.kafka010.KafkaSourceProvider.{kafkaParamsForProducer, TOPIC_OPTION_KEY} -import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter -import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{BinaryType, StringType, StructType} - -/** - * Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we - * don't need to really send one. - */ -case object KafkaWriterCommitMessage extends WriterCommitMessage - -/** - * A [[ContinuousWriter]] for Kafka writing. Responsible for generating the writer factory. - * @param topic The topic this writer is responsible for. If None, topic will be inferred from - * a `topic` field in the incoming data. - * @param producerParams Parameters for Kafka producers in each task. - * @param schema The schema of the input data. - */ -class KafkaContinuousWriter( - topic: Option[String], producerParams: Map[String, String], schema: StructType) - extends ContinuousWriter with SupportsWriteInternalRow { - - validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) - - override def createInternalRowWriterFactory(): KafkaContinuousWriterFactory = - KafkaContinuousWriterFactory(topic, producerParams, schema) - - override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - override def abort(messages: Array[WriterCommitMessage]): Unit = {} -} - -/** - * A [[DataWriterFactory]] for Kafka writing. Will be serialized and sent to executors to generate - * the per-task data writers. - * @param topic The topic that should be written to. If None, topic will be inferred from - * a `topic` field in the incoming data. - * @param producerParams Parameters for Kafka producers in each task. - * @param schema The schema of the input data. - */ -case class KafkaContinuousWriterFactory( - topic: Option[String], producerParams: Map[String, String], schema: StructType) - extends DataWriterFactory[InternalRow] { - - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { - new KafkaContinuousDataWriter(topic, producerParams, schema.toAttributes) - } -} - -/** - * A [[DataWriter]] for Kafka writing. One data writer will be created in each partition to - * process incoming rows. - * - * @param targetTopic The topic that this data writer is targeting. If None, topic will be inferred - * from a `topic` field in the incoming data. - * @param producerParams Parameters to use for the Kafka producer. - * @param inputSchema The attributes in the input data. - */ -class KafkaContinuousDataWriter( - targetTopic: Option[String], producerParams: Map[String, String], inputSchema: Seq[Attribute]) - extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] { - import scala.collection.JavaConverters._ - - private lazy val producer = CachedKafkaProducer.getOrCreate( - new java.util.HashMap[String, Object](producerParams.asJava)) - - def write(row: InternalRow): Unit = { - checkForErrors() - sendRow(row, producer) - } - - def commit(): WriterCommitMessage = { - // Send is asynchronous, but we can't commit until all rows are actually in Kafka. - // This requires flushing and then checking that no callbacks produced errors. - // We also check for errors before to fail as soon as possible - the check is cheap. - checkForErrors() - producer.flush() - checkForErrors() - KafkaWriterCommitMessage - } - - def abort(): Unit = {} - - def close(): Unit = { - checkForErrors() - if (producer != null) { - producer.flush() - checkForErrors() - CachedKafkaProducer.close(new java.util.HashMap[String, Object](producerParams.asJava)) - } - } -} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 551641cfdbca8..3e65949a6fd1b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -117,14 +117,10 @@ private[kafka010] class KafkaOffsetReader( * Resolves the specific offsets based on Kafka seek positions. * This method resolves offset value -1 to the latest and -2 to the * earliest Kafka seek position. - * - * @param partitionOffsets the specific offsets to resolve - * @param reportDataLoss callback to either report or log data loss depending on setting */ def fetchSpecificOffsets( - partitionOffsets: Map[TopicPartition, Long], - reportDataLoss: String => Unit): KafkaSourceOffset = { - val fetched = runUninterruptibly { + partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = + runUninterruptibly { withRetriesWithoutInterrupt { // Poll to get the latest assigned partitions consumer.poll(0) @@ -149,19 +145,6 @@ private[kafka010] class KafkaOffsetReader( } } - partitionOffsets.foreach { - case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && - off != KafkaOffsetRangeLimit.EARLIEST => - if (fetched(tp) != off) { - reportDataLoss( - s"startingOffsets for $tp was $off but consumer reset to ${fetched(tp)}") - } - case _ => - // no real way to check that beginning or end is reasonable - } - KafkaSourceOffset(fetched) - } - /** * Fetch the earliest offsets for the topic partitions that are indicated * in the [[ConsumerStrategy]]. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 27da76068a66f..e9cff04ba5f2e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -130,7 +130,7 @@ private[kafka010] class KafkaSource( val offsets = startingOffsets match { case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchEarliestOffsets()) case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets()) - case SpecificOffsetRangeLimit(p) => kafkaReader.fetchSpecificOffsets(p, reportDataLoss) + case SpecificOffsetRangeLimit(p) => fetchAndVerify(p) } metadataLog.add(0, offsets) logInfo(s"Initial offsets: $offsets") @@ -138,6 +138,21 @@ private[kafka010] class KafkaSource( }.partitionToOffsets } + private def fetchAndVerify(specificOffsets: Map[TopicPartition, Long]) = { + val result = kafkaReader.fetchSpecificOffsets(specificOffsets) + specificOffsets.foreach { + case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && + off != KafkaOffsetRangeLimit.EARLIEST => + if (result(tp) != off) { + reportDataLoss( + s"startingOffsets for $tp was $off but consumer reset to ${result(tp)}") + } + case _ => + // no real way to check that beginning or end is reasonable + } + KafkaSourceOffset(result) + } + private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None override def schema: StructType = KafkaOffsetReader.kafkaSchema diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index c82154cfbad7f..b5da415b3097e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -20,22 +20,17 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} -import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2, PartitionOffset} /** * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and * their offsets. */ private[kafka010] -case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends OffsetV2 { +case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset { override val json = JsonUtils.partitionOffsets(partitionToOffsets) } -private[kafka010] -case class KafkaSourcePartitionOffset(topicPartition: TopicPartition, partitionOffset: Long) - extends PartitionOffset - /** Companion object of the [[KafkaSourceOffset]] */ private[kafka010] object KafkaSourceOffset { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 3914370a96595..3cb4d8cad12cc 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import java.util.{Locale, Optional, UUID} +import java.util.{Locale, UUID} import scala.collection.JavaConverters._ @@ -27,12 +27,9 @@ import org.apache.kafka.clients.producer.ProducerConfig import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} -import org.apache.spark.sql.execution.streaming.{Offset, Sink, Source} +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} +import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -46,8 +43,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider - with ContinuousWriteSupport - with ContinuousReadSupport with Logging { import KafkaSourceProvider._ @@ -106,43 +101,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParams)) } - override def createContinuousReader( - schema: Optional[StructType], - metadataPath: String, - options: DataSourceV2Options): KafkaContinuousReader = { - val parameters = options.asMap().asScala.toMap - validateStreamOptions(parameters) - // Each running query should use its own group id. Otherwise, the query may be only assigned - // partial data since Kafka will assign partitions to multiple consumers having the same group - // id. Hence, we should generate a unique id for each query. - val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" - - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - val specifiedKafkaParams = - parameters - .keySet - .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } - .toMap - - val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, - STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) - - val kafkaOffsetReader = new KafkaOffsetReader( - strategy(caseInsensitiveParams), - kafkaParamsForDriver(specifiedKafkaParams), - parameters, - driverGroupIdPrefix = s"$uniqueGroupId-driver") - - new KafkaContinuousReader( - kafkaOffsetReader, - kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), - parameters, - metadataPath, - startingStreamOffsets, - failOnDataLoss(caseInsensitiveParams)) - } - /** * Returns a new base relation with the given parameters. * @@ -223,22 +181,26 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - override def createContinuousWriter( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceV2Options): Optional[ContinuousWriter] = { - import scala.collection.JavaConverters._ - - val spark = SparkSession.getActiveSession.get - val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim) - // We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable. - val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) - - KafkaWriter.validateQuery( - schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic) + private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = { + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " + + "are serialized with ByteArraySerializer.") + } - Optional.of(new KafkaContinuousWriter(topic, producerParams, schema)) + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) + { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " + + "value are serialized with ByteArraySerializer.") + } + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, + ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) } private def strategy(caseInsensitiveParams: Map[String, String]) = @@ -488,27 +450,4 @@ private[kafka010] object KafkaSourceProvider extends Logging { def build(): ju.Map[String, Object] = map } - - private[kafka010] def kafkaParamsForProducer( - parameters: Map[String, String]): Map[String, String] = { - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " - + "are serialized with ByteArraySerializer.") - } - - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) - { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " - + "value are serialized with ByteArraySerializer.") - } - parameters - .keySet - .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } - .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, - ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) - } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index baa60febf661d..6fd333e2f43ba 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -33,8 +33,10 @@ import org.apache.spark.sql.types.{BinaryType, StringType} private[kafka010] class KafkaWriteTask( producerConfiguration: ju.Map[String, Object], inputSchema: Seq[Attribute], - topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) { + topic: Option[String]) { // used to synchronize with Kafka callbacks + @volatile private var failedWrite: Exception = null + private val projection = createProjection private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _ /** @@ -44,7 +46,23 @@ private[kafka010] class KafkaWriteTask( producer = CachedKafkaProducer.getOrCreate(producerConfiguration) while (iterator.hasNext && failedWrite == null) { val currentRow = iterator.next() - sendRow(currentRow, producer) + val projectedRow = projection(currentRow) + val topic = projectedRow.getUTF8String(0) + val key = projectedRow.getBinary(1) + val value = projectedRow.getBinary(2) + if (topic == null) { + throw new NullPointerException(s"null topic present in the data. Use the " + + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") + } + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) + val callback = new Callback() { + override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { + if (failedWrite == null && e != null) { + failedWrite = e + } + } + } + producer.send(record, callback) } } @@ -56,49 +74,8 @@ private[kafka010] class KafkaWriteTask( producer = null } } -} - -private[kafka010] abstract class KafkaRowWriter( - inputSchema: Seq[Attribute], topic: Option[String]) { - - // used to synchronize with Kafka callbacks - @volatile protected var failedWrite: Exception = _ - protected val projection = createProjection - - private val callback = new Callback() { - override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { - if (failedWrite == null && e != null) { - failedWrite = e - } - } - } - /** - * Send the specified row to the producer, with a callback that will save any exception - * to failedWrite. Note that send is asynchronous; subclasses must flush() their producer before - * assuming the row is in Kafka. - */ - protected def sendRow( - row: InternalRow, producer: KafkaProducer[Array[Byte], Array[Byte]]): Unit = { - val projectedRow = projection(row) - val topic = projectedRow.getUTF8String(0) - val key = projectedRow.getBinary(1) - val value = projectedRow.getBinary(2) - if (topic == null) { - throw new NullPointerException(s"null topic present in the data. Use the " + - s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") - } - val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) - producer.send(record, callback) - } - - protected def checkForErrors(): Unit = { - if (failedWrite != null) { - throw failedWrite - } - } - - private def createProjection = { + private def createProjection: UnsafeProjection = { val topicExpression = topic.map(Literal(_)).orElse { inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME) }.getOrElse { @@ -135,5 +112,11 @@ private[kafka010] abstract class KafkaRowWriter( Seq(topicExpression, Cast(keyExpression, BinaryType), Cast(valueExpression, BinaryType)), inputSchema) } + + private def checkForErrors(): Unit = { + if (failedWrite != null) { + throw failedWrite + } + } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 15cd44812cb0c..5e9ae35b3f008 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -43,9 +43,10 @@ private[kafka010] object KafkaWriter extends Logging { override def toString: String = "KafkaWriter" def validateQuery( - schema: Seq[Attribute], + queryExecution: QueryExecution, kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { + val schema = queryExecution.analyzed.output schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse( if (topic.isEmpty) { throw new AnalysisException(s"topic option required when no " + @@ -83,7 +84,7 @@ private[kafka010] object KafkaWriter extends Logging { kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { val schema = queryExecution.analyzed.output - validateQuery(schema, kafkaParameters, topic) + validateQuery(queryExecution, kafkaParameters, topic) queryExecution.toRdd.foreachPartition { iter => val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic) Utils.tryWithSafeFinally(block = writeTask.execute(iter))( diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala deleted file mode 100644 index dfc97b1c38bb5..0000000000000 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ /dev/null @@ -1,474 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import java.util.Locale -import java.util.concurrent.atomic.AtomicInteger - -import org.apache.kafka.clients.producer.ProducerConfig -import org.apache.kafka.common.serialization.ByteArraySerializer -import org.scalatest.time.SpanSugar._ -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} -import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.streaming._ -import org.apache.spark.sql.types.{BinaryType, DataType} -import org.apache.spark.util.Utils - -/** - * This is a temporary port of KafkaSinkSuite, since we do not yet have a V2 memory stream. - * Once we have one, this will be changed to a specialization of KafkaSinkSuite and we won't have - * to duplicate all the code. - */ -class KafkaContinuousSinkSuite extends KafkaContinuousTest { - import testImplicits._ - - override val streamingTimeout = 30.seconds - - override def beforeAll(): Unit = { - super.beforeAll() - testUtils = new KafkaTestUtils( - withBrokerProps = Map("auto.create.topics.enable" -> "false")) - testUtils.setup() - } - - override def afterAll(): Unit = { - if (testUtils != null) { - testUtils.teardown() - testUtils = null - } - super.afterAll() - } - - test("streaming - write to kafka with topic field") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - - val topic = newTopic() - testUtils.createTopic(topic) - - val writer = createKafkaWriter( - input.toDF(), - withTopic = None, - withOutputMode = Some(OutputMode.Append))( - withSelectExpr = s"'$topic' as topic", "value") - - val reader = createKafkaReader(topic) - .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") - .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") - .as[(Int, Int)] - .map(_._2) - - try { - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) - testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - } finally { - writer.stop() - } - } - - test("streaming - write w/o topic field, with topic option") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - - val topic = newTopic() - testUtils.createTopic(topic) - - val writer = createKafkaWriter( - input.toDF(), - withTopic = Some(topic), - withOutputMode = Some(OutputMode.Append()))() - - val reader = createKafkaReader(topic) - .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") - .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") - .as[(Int, Int)] - .map(_._2) - - try { - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) - testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - } finally { - writer.stop() - } - } - - test("streaming - topic field and topic option") { - /* The purpose of this test is to ensure that the topic option - * overrides the topic field. We begin by writing some data that - * includes a topic field and value (e.g., 'foo') along with a topic - * option. Then when we read from the topic specified in the option - * we should see the data i.e., the data was written to the topic - * option, and not to the topic in the data e.g., foo - */ - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - - val topic = newTopic() - testUtils.createTopic(topic) - - val writer = createKafkaWriter( - input.toDF(), - withTopic = Some(topic), - withOutputMode = Some(OutputMode.Append()))( - withSelectExpr = "'foo' as topic", "CAST(value as STRING) value") - - val reader = createKafkaReader(topic) - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .selectExpr("CAST(key AS INT)", "CAST(value AS INT)") - .as[(Int, Int)] - .map(_._2) - - try { - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) - testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - } finally { - writer.stop() - } - } - - test("null topic attribute") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - val topic = newTopic() - testUtils.createTopic(topic) - - /* No topic field or topic option */ - var writer: StreamingQuery = null - var ex: Exception = null - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = "CAST(null as STRING) as topic", "value" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getCause.getCause.getMessage - .toLowerCase(Locale.ROOT) - .contains("null topic present in the data.")) - } - - test("streaming - write data with bad schema") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - val topic = newTopic() - testUtils.createTopic(topic) - - /* No topic field or topic option */ - var writer: StreamingQuery = null - var ex: Exception = null - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = "value as key", "value" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage - .toLowerCase(Locale.ROOT) - .contains("topic option required when no 'topic' attribute is present")) - - try { - /* No value field */ - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = s"'$topic' as topic", "value as key" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "required attribute 'value' not found")) - } - - test("streaming - write data with valid schema but wrong types") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - .selectExpr("CAST(value as STRING) value") - val topic = newTopic() - testUtils.createTopic(topic) - - var writer: StreamingQuery = null - var ex: Exception = null - try { - /* topic field wrong type */ - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = s"CAST('1' as INT) as topic", "value" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string")) - - try { - /* value field wrong type */ - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "value attribute type must be a string or binarytype")) - - try { - ex = intercept[StreamingQueryException] { - /* key field wrong type */ - writer = createKafkaWriter(input.toDF())( - withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "key attribute type must be a string or binarytype")) - } - - test("streaming - write to non-existing topic") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - val topic = newTopic() - - var writer: StreamingQuery = null - var ex: Exception = null - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))() - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - eventually(timeout(streamingTimeout)) { - assert(writer.exception.isDefined) - } - throw writer.exception.get - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) - } - - test("streaming - exception on config serializer") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - testUtils.sendMessages(inputTopic, Array("0")) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .load() - var writer: StreamingQuery = null - var ex: Exception = null - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter( - input.toDF(), - withOptions = Map("kafka.key.serializer" -> "foo"))() - writer.processAllAvailable() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "kafka option 'key.serializer' is not supported")) - } finally { - writer.stop() - } - - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter( - input.toDF(), - withOptions = Map("kafka.value.serializer" -> "foo"))() - writer.processAllAvailable() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "kafka option 'value.serializer' is not supported")) - } finally { - writer.stop() - } - } - - test("generic - write big data with small producer buffer") { - /* This test ensures that we understand the semantics of Kafka when - * is comes to blocking on a call to send when the send buffer is full. - * This test will configure the smallest possible producer buffer and - * indicate that we should block when it is full. Thus, no exception should - * be thrown in the case of a full buffer. - */ - val topic = newTopic() - testUtils.createTopic(topic, 1) - val options = new java.util.HashMap[String, String] - options.put("bootstrap.servers", testUtils.brokerAddress) - options.put("buffer.memory", "16384") // min buffer size - options.put("block.on.buffer.full", "true") - options.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) - options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) - val inputSchema = Seq(AttributeReference("value", BinaryType)()) - val data = new Array[Byte](15000) // large value - val writeTask = new KafkaContinuousDataWriter(Some(topic), options.asScala.toMap, inputSchema) - try { - val fieldTypes: Array[DataType] = Array(BinaryType) - val converter = UnsafeProjection.create(fieldTypes) - val row = new SpecificInternalRow(fieldTypes) - row.update(0, data) - val iter = Seq.fill(1000)(converter.apply(row)).iterator - iter.foreach(writeTask.write(_)) - writeTask.commit() - } finally { - writeTask.close() - } - } - - private def createKafkaReader(topic: String): DataFrame = { - spark.read - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("startingOffsets", "earliest") - .option("endingOffsets", "latest") - .option("subscribe", topic) - .load() - } - - private def createKafkaWriter( - input: DataFrame, - withTopic: Option[String] = None, - withOutputMode: Option[OutputMode] = None, - withOptions: Map[String, String] = Map[String, String]()) - (withSelectExpr: String*): StreamingQuery = { - var stream: DataStreamWriter[Row] = null - val checkpointDir = Utils.createTempDir() - var df = input.toDF() - if (withSelectExpr.length > 0) { - df = df.selectExpr(withSelectExpr: _*) - } - stream = df.writeStream - .format("kafka") - .option("checkpointLocation", checkpointDir.getCanonicalPath) - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - // We need to reduce blocking time to efficiently test non-existent partition behavior. - .option("kafka.max.block.ms", "1000") - .trigger(Trigger.Continuous(1000)) - .queryName("kafkaStream") - withTopic.foreach(stream.option("topic", _)) - withOutputMode.foreach(stream.outputMode(_)) - withOptions.foreach(opt => stream.option(opt._1, opt._2)) - stream.start() - } -} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala deleted file mode 100644 index b3dade414f625..0000000000000 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import java.util.Properties -import java.util.concurrent.atomic.AtomicInteger - -import org.scalatest.time.SpanSugar._ -import scala.collection.mutable -import scala.util.Random - -import org.apache.spark.SparkContext -import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.streaming.{StreamTest, Trigger} -import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} - -// Run tests in KafkaSourceSuiteBase in continuous execution mode. -class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest - -class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { - import testImplicits._ - - override val brokerProps = Map("auto.create.topics.enable" -> "false") - - test("subscribing topic by pattern with topic deletions") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-seems" - val topic2 = topicPrefix + "-bad" - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"$topicPrefix-.*") - .option("failOnDataLoss", "false") - - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped = kafka.map(kv => kv._2.toInt + 1) - - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - Execute { query => - testUtils.deleteTopic(topic) - testUtils.createTopic(topic2, partitions = 5) - eventually(timeout(streamingTimeout)) { - assert( - query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r - }.exists { r => - // Ensure the new topic is present and the old topic is gone. - r.knownPartitions.exists(_.topic == topic2) - }, - s"query never reconfigured to new topic $topic2") - } - }, - AddKafkaData(Set(topic2), 4, 5, 6), - CheckAnswer(2, 3, 4, 5, 6, 7) - ) - } -} - -class KafkaContinuousSourceStressForDontFailOnDataLossSuite - extends KafkaSourceStressForDontFailOnDataLossSuite { - override protected def startStream(ds: Dataset[Int]) = { - ds.writeStream - .format("memory") - .queryName("memory") - .start() - } -} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala deleted file mode 100644 index e713e6695d2bd..0000000000000 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import org.apache.spark.SparkContext -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.streaming.Trigger -import org.apache.spark.sql.test.TestSparkSession - -// Trait to configure StreamTest for kafka continuous execution tests. -trait KafkaContinuousTest extends KafkaSourceTest { - override val defaultTrigger = Trigger.Continuous(1000) - override val defaultUseV2Sink = true - - // We need more than the default local[2] to be able to schedule all partitions simultaneously. - override protected def createSparkSession = new TestSparkSession( - new SparkContext( - "local[10]", - "continuous-stream-test-sql-context", - sparkConf.set("spark.sql.testkey", "true"))) - - // In addition to setting the partitions in Kafka, we have to wait until the query has - // reconfigured to the new count so the test framework can hook in properly. - override protected def setTopicPartitions( - topic: String, newCount: Int, query: StreamExecution) = { - testUtils.addPartitions(topic, newCount) - eventually(timeout(streamingTimeout)) { - assert( - query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r - }.exists(_.knownPartitions.size == newCount), - s"query never reconfigured to $newCount partitions") - } - } - - test("ensure continuous stream is being used") { - val query = spark.readStream - .format("rate") - .option("numPartitions", "1") - .option("rowsPerSecond", "1") - .load() - - testStream(query)( - Execute(q => assert(q.isInstanceOf[ContinuousExecution])) - ) - } -} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index d66908f86ccc7..2034b9be07f24 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -34,14 +34,11 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext -import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2Exec} +import org.apache.spark.sql.ForeachWriter import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryWriter import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ -import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest, Trigger} +import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} import org.apache.spark.util.Utils @@ -52,11 +49,9 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override val streamingTimeout = 30.seconds - protected val brokerProps = Map[String, Object]() - override def beforeAll(): Unit = { super.beforeAll() - testUtils = new KafkaTestUtils(brokerProps) + testUtils = new KafkaTestUtils testUtils.setup() } @@ -64,25 +59,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { if (testUtils != null) { testUtils.teardown() testUtils = null + super.afterAll() } - super.afterAll() } protected def makeSureGetOffsetCalled = AssertOnQuery { q => // Because KafkaSource's initialPartitionOffsets is set lazily, we need to make sure - // its "getOffset" is called before pushing any data. Otherwise, because of the race condition, + // its "getOffset" is called before pushing any data. Otherwise, because of the race contion, // we don't know which data should be fetched when `startingOffsets` is latest. - q match { - case c: ContinuousExecution => c.awaitEpoch(0) - case m: MicroBatchExecution => m.processAllAvailable() - } + q.processAllAvailable() true } - protected def setTopicPartitions(topic: String, newCount: Int, query: StreamExecution) : Unit = { - testUtils.addPartitions(topic, newCount) - } - /** * Add data to Kafka. * @@ -94,7 +82,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { message: String = "", topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData { - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + override def addData(query: Option[StreamExecution]): (Source, Offset) = { if (query.get.isActive) { // Make sure no Spark job is running when deleting a topic query.get.processAllAvailable() @@ -109,18 +97,16 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2)) } + // Read all topics again in case some topics are delete. + val allTopics = testUtils.getAllTopicsAndPartitionSize().toMap.keys require( query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: KafkaSource, _) => source - } ++ (query.get.lastExecution match { - case null => Seq() - case e => e.logical.collect { - case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader - } - }) + case StreamingExecutionRelation(source, _) if source.isInstanceOf[KafkaSource] => + source.asInstanceOf[KafkaSource] + } if (sources.isEmpty) { throw new Exception( "Could not find Kafka source in the StreamExecution logical plan to add data to") @@ -151,158 +137,14 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override def toString: String = s"AddKafkaData(topics = $topics, data = $data, message = $message)" } - - private val topicId = new AtomicInteger(0) - protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}" } -class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { - - import testImplicits._ - - test("(de)serialization of initial offsets") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", topic) - - testStream(reader.load)( - makeSureGetOffsetCalled, - StopStream, - StartStream(), - StopStream) - } - - test("maxOffsetsPerTrigger") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 3) - testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) - testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) - testUtils.sendMessages(topic, Array("1"), Some(2)) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("maxOffsetsPerTrigger", 10) - .option("subscribe", topic) - .option("startingOffsets", "earliest") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) - - val clock = new StreamManualClock - - val waitUntilBatchProcessed = AssertOnQuery { q => - eventually(Timeout(streamingTimeout)) { - if (!q.exception.isDefined) { - assert(clock.isStreamWaitingAt(clock.getTimeMillis())) - } - } - if (q.exception.isDefined) { - throw q.exception.get - } - true - } - - testStream(mapped)( - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // 1 from smallest, 1 from middle, 8 from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 - ), - StopStream, - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 - ), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, - 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 - ) - ) - } - - test("input row metrics") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val kafka = spark - .readStream - .format("kafka") - .option("subscribe", topic) - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - - val mapped = kafka.map(kv => kv._2.toInt + 1) - testStream(mapped)( - StartStream(trigger = ProcessingTime(1)), - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - AssertOnQuery { query => - val recordsRead = query.recentProgress.map(_.numInputRows).sum - recordsRead == 3 - } - ) - } - - test("subscribing topic by pattern with topic deletions") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-seems" - val topic2 = topicPrefix + "-bad" - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"$topicPrefix-.*") - .option("failOnDataLoss", "false") +class KafkaSourceSuite extends KafkaSourceTest { - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped = kafka.map(kv => kv._2.toInt + 1) + import testImplicits._ - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - Assert { - testUtils.deleteTopic(topic) - testUtils.createTopic(topic2, partitions = 5) - true - }, - AddKafkaData(Set(topic2), 4, 5, 6), - CheckAnswer(2, 3, 4, 5, 6, 7) - ) - } + private val topicId = new AtomicInteger(0) testWithUninterruptibleThread( "deserialization of initial offset with Spark 2.1.0") { @@ -395,51 +237,86 @@ class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { } } - test("KafkaSource with watermark") { - val now = System.currentTimeMillis() + test("(de)serialization of initial offsets") { val topic = newTopic() - testUtils.createTopic(newTopic(), partitions = 1) - testUtils.sendMessages(topic, Array(1).map(_.toString)) + testUtils.createTopic(topic, partitions = 64) - val kafka = spark + val reader = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("startingOffsets", s"earliest") .option("subscribe", topic) - .load() - - val windowedAggregation = kafka - .withWatermark("timestamp", "10 seconds") - .groupBy(window($"timestamp", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"window".getField("start") as 'window, $"count") - val query = windowedAggregation - .writeStream - .format("memory") - .outputMode("complete") - .queryName("kafkaWatermark") - .start() - query.processAllAvailable() - val rows = spark.table("kafkaWatermark").collect() - assert(rows.length === 1, s"Unexpected results: ${rows.toList}") - val row = rows(0) - // We cannot check the exact window start time as it depands on the time that messages were - // inserted by the producer. So here we just use a low bound to make sure the internal - // conversion works. - assert( - row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, - s"Unexpected results: $row") - assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") - query.stop() + testStream(reader.load)( + makeSureGetOffsetCalled, + StopStream, + StartStream(), + StopStream) } -} -class KafkaSourceSuiteBase extends KafkaSourceTest { + test("maxOffsetsPerTrigger") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) + testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) + testUtils.sendMessages(topic, Array("1"), Some(2)) - import testImplicits._ + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("maxOffsetsPerTrigger", 10) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + val clock = new StreamManualClock + + val waitUntilBatchProcessed = AssertOnQuery { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + } + + testStream(mapped)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // 1 from smallest, 1 from middle, 8 from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 + ), + StopStream, + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 + ), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, + 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 + ) + ) + } test("cannot stop Kafka stream") { val topic = newTopic() @@ -451,7 +328,7 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"$topic.*") + .option("subscribePattern", s"topic-.*") val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") @@ -545,6 +422,65 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { } } + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Assert { + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + true + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } + + test("starting offset is latest by default") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("0")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + + val kafka = reader.load() + .selectExpr("CAST(value AS STRING)") + .as[String] + val mapped = kafka.map(_.toInt) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(1, 2, 3) // should not have 0 + ) + } + test("bad source options") { def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = { val ex = intercept[IllegalArgumentException] { @@ -604,6 +540,34 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { testUnsupportedConfig("kafka.auto.offset.reset", "latest") } + test("input row metrics") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + val mapped = kafka.map(kv => kv._2.toInt + 1) + testStream(mapped)( + StartStream(trigger = ProcessingTime(1)), + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + AssertOnQuery { query => + val recordsRead = query.recentProgress.map(_.numInputRows).sum + recordsRead == 3 + } + ) + } + test("delete a topic when a Spark job is running") { KafkaSourceSuite.collectedData.clear() @@ -665,6 +629,8 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { } } + private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" + private def assignString(topic: String, partitions: Iterable[Int]): String = { JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p))) } @@ -710,10 +676,6 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { testStream(mapped)( makeSureGetOffsetCalled, - Execute { q => - // wait to reach the last offset in every partition - q.awaitOffset(0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L))) - }, CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22), StopStream, StartStream(), @@ -744,7 +706,6 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { .format("memory") .outputMode("append") .queryName("kafkaColumnTypes") - .trigger(defaultTrigger) .start() query.processAllAvailable() val rows = spark.table("kafkaColumnTypes").collect() @@ -762,6 +723,47 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { query.stop() } + test("KafkaSource with watermark") { + val now = System.currentTimeMillis() + val topic = newTopic() + testUtils.createTopic(newTopic(), partitions = 1) + testUtils.sendMessages(topic, Array(1).map(_.toString)) + + val kafka = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("startingOffsets", s"earliest") + .option("subscribe", topic) + .load() + + val windowedAggregation = kafka + .withWatermark("timestamp", "10 seconds") + .groupBy(window($"timestamp", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start") as 'window, $"count") + + val query = windowedAggregation + .writeStream + .format("memory") + .outputMode("complete") + .queryName("kafkaWatermark") + .start() + query.processAllAvailable() + val rows = spark.table("kafkaWatermark").collect() + assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + val row = rows(0) + // We cannot check the exact window start time as it depands on the time that messages were + // inserted by the producer. So here we just use a low bound to make sure the internal + // conversion works. + assert( + row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, + s"Unexpected results: $row") + assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") + query.stop() + } + private def testFromLatestOffsets( topic: String, addPartitions: Boolean, @@ -798,7 +800,9 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { AddKafkaData(Set(topic), 7, 8), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) setTopicPartitions(topic, 10, query) + if (addPartitions) { + testUtils.addPartitions(topic, 10) + } true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -839,7 +843,9 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { StartStream(), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) setTopicPartitions(topic, 10, query) + if (addPartitions) { + testUtils.addPartitions(topic, 10) + } true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -971,23 +977,6 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared } } - protected def startStream(ds: Dataset[Int]) = { - ds.writeStream.foreach(new ForeachWriter[Int] { - - override def open(partitionId: Long, version: Long): Boolean = { - true - } - - override def process(value: Int): Unit = { - // Slow down the processing speed so that messages may be aged out. - Thread.sleep(Random.nextInt(500)) - } - - override def close(errorOrNull: Throwable): Unit = { - } - }).start() - } - test("stress test for failOnDataLoss=false") { val reader = spark .readStream @@ -1001,7 +990,20 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] - val query = startStream(kafka.map(kv => kv._2.toInt)) + val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { + + override def open(partitionId: Long, version: Long): Boolean = { + true + } + + override def process(value: Int): Unit = { + // Slow down the processing speed so that messages may be aged out. + Thread.sleep(Random.nextInt(500)) + } + + override def close(errorOrNull: Throwable): Unit = { + } + }).start() val testTime = 1.minutes val startTime = System.currentTimeMillis() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b714a46b5f786..e8d683a578f35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -191,9 +191,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { ds = ds.asInstanceOf[DataSourceV2], conf = sparkSession.sessionState.conf)).asJava) - // Streaming also uses the data source V2 API. So it may be that the data source implements - // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading - // the dataframe as a v1 source. val reader = (ds, userSpecifiedSchema) match { case (ds: ReadSupportWithSchema, Some(schema)) => ds.createReader(schema, options) @@ -211,30 +208,23 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } reader - case _ => null // fall back to v1 + case _ => + throw new AnalysisException(s"$cls does not support data reading.") } - if (reader == null) { - loadV1Source(paths: _*) - } else { - Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) - } + Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) } else { - loadV1Source(paths: _*) + // Code path for data source v1. + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + userSpecifiedSchema = userSpecifiedSchema, + className = source, + options = extraOptions.toMap).resolveRelation()) } } - private def loadV1Source(paths: String*) = { - // Code path for data source v1. - sparkSession.baseRelationToDataFrame( - DataSource.apply( - sparkSession, - paths = paths, - userSpecifiedSchema = userSpecifiedSchema, - className = source, - options = extraOptions.toMap).resolveRelation()) - } - /** * Construct a `DataFrame` representing the database table accessible via JDBC URL * url named table and connection properties. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 97f12ff625c42..3304f368e1050 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -255,24 +255,17 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } } - // Streaming also uses the data source V2 API. So it may be that the data source implements - // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving - // as though it's a V1 source. - case _ => saveToV1Source() + case _ => throw new AnalysisException(s"$cls does not support data writing.") } } else { - saveToV1Source() - } - } - - private def saveToV1Source(): Unit = { - // Code path for data source v1. - runCommand(df.sparkSession, "save") { - DataSource( - sparkSession = df.sparkSession, - className = source, - partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) + // Code path for data source v1. + runCommand(df.sparkSession, "save") { + DataSource( + sparkSession = df.sparkSession, + className = source, + partitionColumns = partitioningColumns.getOrElse(Nil), + options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index a4a857f2d4d9b..f0bdf84bb7a84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -81,11 +81,9 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) (index, message: WriterCommitMessage) => messages(index) = message ) - if (!writer.isInstanceOf[ContinuousWriter]) { - logInfo(s"Data source writer $writer is committing.") - writer.commit(messages) - logInfo(s"Data source writer $writer committed.") - } + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") } catch { case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] => // Interruption is how continuous queries are ended, so accept and ignore the exception. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index cf27e1a70650a..24a8b000df0c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -142,8 +142,7 @@ abstract class StreamExecution( override val id: UUID = UUID.fromString(streamMetadata.id) - override def runId: UUID = currentRunId - protected var currentRunId = UUID.randomUUID + override val runId: UUID = UUID.randomUUID /** * Pretty identified string of printing in logs. Format is @@ -419,17 +418,11 @@ abstract class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. */ - private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = { + private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = { assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets - if (sources == null) { - // sources might not be initialized yet - false - } else { - val source = sources(sourceIndex) - !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset - } + !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset } while (notDone) { @@ -443,7 +436,7 @@ abstract class StreamExecution( awaitProgressLock.unlock() } } - logDebug(s"Unblocked at $newOffset for ${sources(sourceIndex)}") + logDebug(s"Unblocked at $newOffset for $source") } /** A flag to indicate that a batch has completed with no new data available. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index e700aa4f9aea7..d79e4bd65f563 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -77,6 +77,7 @@ class ContinuousDataSourceRDD( dataReaderThread.start() context.addTaskCompletionListener(_ => { + reader.close() dataReaderThread.interrupt() epochPollExecutor.shutdown() }) @@ -200,8 +201,6 @@ class DataReaderThread( failedFlag.set(true) // Don't rethrow the exception in this thread. It's not needed, and the default Spark // exception handler will kill the executor. - } finally { - reader.close() } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 667410ef9f1c6..9657b5e26d770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.execution.streaming.continuous -import java.util.UUID import java.util.concurrent.TimeUnit -import java.util.function.UnaryOperator import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} @@ -54,7 +52,7 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReader] = _ + @volatile protected var continuousSources: Seq[ContinuousReader] = Seq.empty override protected def sources: Seq[BaseStreamingSource] = continuousSources override lazy val logicalPlan: LogicalPlan = { @@ -80,17 +78,15 @@ class ContinuousExecution( } override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { - val stateUpdate = new UnaryOperator[State] { - override def apply(s: State) = s match { - // If we ended the query to reconfigure, reset the state to active. - case RECONFIGURING => ACTIVE - case _ => s - } - } - do { - runContinuous(sparkSessionForStream) - } while (state.updateAndGet(stateUpdate) == ACTIVE) + try { + runContinuous(sparkSessionForStream) + } catch { + case _: InterruptedException if state.get().equals(RECONFIGURING) => + // swallow exception and run again + state.set(ACTIVE) + } + } while (state.get() == ACTIVE) } /** @@ -124,16 +120,12 @@ class ContinuousExecution( } committedOffsets = nextOffsets.toStreamProgress(sources) - // Get to an epoch ID that has definitely never been sent to a sink before. Since sink - // commit happens between offset log write and commit log write, this means an epoch ID - // which is not in the offset log. - val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse { - throw new IllegalStateException( - s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" + - s"an element.") - } - currentBatchId = latestOffsetEpoch + 1 + // Forcibly align commit and offset logs by slicing off any spurious offset logs from + // a previous run. We can't allow commits to an epoch that a previous run reached but + // this run has not. + offsetLog.purgeAfter(latestEpochId) + currentBatchId = latestEpochId + 1 logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets") nextOffsets case None => @@ -149,7 +141,6 @@ class ContinuousExecution( * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with. */ private def runContinuous(sparkSessionForQuery: SparkSession): Unit = { - currentRunId = UUID.randomUUID // A list of attributes that will need to be updated. val replacements = new ArrayBuffer[(Attribute, Attribute)] // Translate from continuous relation to the underlying data source. @@ -234,11 +225,13 @@ class ContinuousExecution( triggerExecutor.execute(() => { startTrigger() - if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { + if (reader.needsReconfiguration()) { + state.set(RECONFIGURING) stopSources() if (queryExecutionThread.isAlive) { sparkSession.sparkContext.cancelJobGroup(runId.toString) queryExecutionThread.interrupt() + // No need to join - this thread is about to end anyway. } false } else if (isActive) { @@ -266,7 +259,6 @@ class ContinuousExecution( sparkSessionForQuery, lastExecution)(lastExecution.toRdd) } } finally { - epochEndpoint.askSync[Unit](StopContinuousExecutionWrites) SparkEnv.get.rpcEnv.stop(epochEndpoint) epochUpdateThread.interrupt() @@ -281,22 +273,17 @@ class ContinuousExecution( epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") - val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) - val oldOffset = synchronized { - offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) - offsetLog.get(epoch - 1) + if (partitionOffsets.contains(null)) { + // If any offset is null, that means the corresponding partition hasn't seen any data yet, so + // there's nothing meaningful to add to the offset log. } - - // If offset hasn't changed since last epoch, there's been no new data. - if (oldOffset.contains(OffsetSeq.fill(globalOffset))) { - noNewData = true - } - - awaitProgressLock.lock() - try { - awaitProgressLockCondition.signalAll() - } finally { - awaitProgressLock.unlock() + val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) + synchronized { + if (queryExecutionThread.isAlive) { + offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) + } else { + return + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 40dcbecade814..98017c3ac6a33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -39,15 +39,6 @@ private[continuous] sealed trait EpochCoordinatorMessage extends Serializable */ private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage -/** - * The RpcEndpoint stop() will wait to clear out the message queue before terminating the - * object. This can lead to a race condition where the query restarts at epoch n, a new - * EpochCoordinator starts at epoch n, and then the old epoch coordinator commits epoch n + 1. - * The framework doesn't provide a handle to wait on the message queue, so we use a synchronous - * message to stop any writes to the ContinuousExecution object. - */ -private[sql] case object StopContinuousExecutionWrites extends EpochCoordinatorMessage - // Init messages /** * Set the reader and writer partition counts. Tasks may not be started until the coordinator @@ -125,8 +116,6 @@ private[continuous] class EpochCoordinator( override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { - private var queryWritesStopped: Boolean = false - private var numReaderPartitions: Int = _ private var numWriterPartitions: Int = _ @@ -158,16 +147,12 @@ private[continuous] class EpochCoordinator( partitionCommits.remove(k) } for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) { - partitionOffsets.remove(k) + partitionCommits.remove(k) } } } override def receive: PartialFunction[Any, Unit] = { - // If we just drop these messages, we won't do any writes to the query. The lame duck tasks - // won't shed errors or anything. - case _ if queryWritesStopped => () - case CommitPartitionEpoch(partitionId, epoch, message) => logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message") if (!partitionCommits.isDefinedAt((epoch, partitionId))) { @@ -203,9 +188,5 @@ private[continuous] class EpochCoordinator( case SetWriterPartitions(numPartitions) => numWriterPartitions = numPartitions context.reply(()) - - case StopContinuousExecutionWrites => - queryWritesStopped = true - context.reply(()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index b5b4a05ab4973..db588ae282f38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -280,29 +279,18 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val sink = trigger match { - case _: ContinuousTrigger => - val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) - ds.newInstance() match { - case w: ContinuousWriteSupport => w - case _ => throw new AnalysisException( - s"Data source $source does not support continuous writing") - } - case _ => - val ds = DataSource( - df.sparkSession, - className = source, - options = extraOptions.toMap, - partitionColumns = normalizedParCols.getOrElse(Nil)) - ds.createSink(outputMode) - } - + val dataSource = + DataSource( + df.sparkSession, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), df, extraOptions.toMap, - sink, + dataSource.createSink(outputMode), outputMode, useTempCheckpointLocation = source == "console", recoverFromCheckpointLocation = true, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 0762895fdc620..d46461fa9bf6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -38,9 +38,8 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch} +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ @@ -81,9 +80,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be StateStore.stop() // stop the state store maintenance thread and unload store providers } - protected val defaultTrigger = Trigger.ProcessingTime(0) - protected val defaultUseV2Sink = false - /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds @@ -193,7 +189,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be /** Starts the stream, resuming if data has already been processed. It must not be running. */ case class StartStream( - trigger: Trigger = defaultTrigger, + trigger: Trigger = Trigger.ProcessingTime(0), triggerClock: Clock = new SystemClock, additionalConfs: Map[String, String] = Map.empty, checkpointLocation: String = null) @@ -280,7 +276,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def testStream( _stream: Dataset[_], outputMode: OutputMode = OutputMode.Append, - useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized { + useV2Sink: Boolean = false)(actions: StreamAction*): Unit = synchronized { import org.apache.spark.sql.streaming.util.StreamManualClock // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently @@ -407,11 +403,18 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = { verify(currentStream != null, "stream not running") + // Get the map of source index to the current source objects + val indexToSource = currentStream + .logicalPlan + .collect { case StreamingExecutionRelation(s, _) => s } + .zipWithIndex + .map(_.swap) + .toMap // Block until all data added has been processed for all the source awaiting.foreach { case (sourceIndex, offset) => failAfter(streamingTimeout) { - currentStream.awaitOffset(sourceIndex, offset) + currentStream.awaitOffset(indexToSource(sourceIndex), offset) } } @@ -470,12 +473,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be // after starting the query. try { currentStream.awaitInitialization(streamingTimeout.toMillis) - currentStream match { - case s: ContinuousExecution => eventually("IncrementalExecution was not created") { - s.lastExecution.executedPlan // will fail if lastExecution is null - } - case _ => - } } catch { case _: StreamingQueryException => // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well. @@ -603,10 +600,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def findSourceIndex(plan: LogicalPlan): Option[Int] = { plan - .collect { - case StreamingExecutionRelation(s, _) => s - case DataSourceV2Relation(_, r) => r - } + .collect { case StreamingExecutionRelation(s, _) => s } .zipWithIndex .find(_._1 == source) .map(_._2) @@ -619,13 +613,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be findSourceIndex(query.logicalPlan) }.orElse { findSourceIndex(stream.logicalPlan) - }.orElse { - queryToUse.flatMap { q => - findSourceIndex(q.lastExecution.logical) - } }.getOrElse { throw new IllegalArgumentException( - "Could not find index of the source to which data was added") + "Could find index of the source to which data was added") } // Store the expected offset of added data to wait for it later From cd9f49a2aed3799964976ead06080a0f7044a0c3 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 13 Jan 2018 16:13:44 +0900 Subject: [PATCH 0081/2461] [SPARK-22980][PYTHON][SQL] Clarify the length of each series is of each batch within scalar Pandas UDF ## What changes were proposed in this pull request? This PR proposes to add a note that saying the length of a scalar Pandas UDF's `Series` is not of the whole input column but of the batch. We are fine for a group map UDF because the usage is different from our typical UDF but scalar UDFs might cause confusion with the normal UDF. For example, please consider this example: ```python from pyspark.sql.functions import pandas_udf, col, lit df = spark.range(1) f = pandas_udf(lambda x, y: len(x) + y, LongType()) df.select(f(lit('text'), col('id'))).show() ``` ``` +------------------+ |(text, id)| +------------------+ | 1| +------------------+ ``` ```python from pyspark.sql.functions import udf, col, lit df = spark.range(1) f = udf(lambda x, y: len(x) + y, "long") df.select(f(lit('text'), col('id'))).show() ``` ``` +------------------+ |(text, id)| +------------------+ | 4| +------------------+ ``` ## How was this patch tested? Manually built the doc and checked the output. Author: hyukjinkwon Closes #20237 from HyukjinKwon/SPARK-22980. --- python/pyspark/sql/functions.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 733e32bd825b0..e1ad6590554cf 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2184,6 +2184,11 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 8| JOHN DOE| 22| +----------+--------------+------------+ + .. note:: The length of `pandas.Series` within a scalar UDF is not that of the whole input + column, but is the length of an internal batch used for each call to the function. + Therefore, this can be used, for example, to ensure the length of each returned + `pandas.Series`, and can not be used as the column length. + 2. GROUP_MAP A group map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` From 628a1ca5a4d14397a90e9e96a7e03e8f63531b20 Mon Sep 17 00:00:00 2001 From: shimamoto Date: Sat, 13 Jan 2018 09:40:00 -0600 Subject: [PATCH 0082/2461] [SPARK-23043][BUILD] Upgrade json4s to 3.5.3 ## What changes were proposed in this pull request? Spark still use a few years old version 3.2.11. This change is to upgrade json4s to 3.5.3. Note that this change does not include the Jackson update because the Jackson version referenced in json4s 3.5.3 is 2.8.4, which has a security vulnerability ([see](https://issues.apache.org/jira/browse/SPARK-20433)). ## How was this patch tested? Existing unit tests and build. Author: shimamoto Closes #20233 from shimamoto/upgrade-json4s. --- .../deploy/history/HistoryServerSuite.scala | 2 +- .../org/apache/spark/ui/UISeleniumSuite.scala | 19 ++++++++++--------- dev/deps/spark-deps-hadoop-2.6 | 8 ++++---- dev/deps/spark-deps-hadoop-2.7 | 8 ++++---- pom.xml | 13 +++++++------ 5 files changed, 26 insertions(+), 24 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 3738f85da5831..87778dda0e2c8 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -486,7 +486,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers json match { case JNothing => Seq() case apps: JArray => - apps.filter(app => { + apps.children.filter(app => { (app \ "attempts") match { case attempts: JArray => val state = (attempts.children.head \ "completed").asInstanceOf[JBool] diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 326546787ab6c..ed51fc445fdfb 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -131,7 +131,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B val storageJson = getJson(ui, "storage/rdd") storageJson.children.length should be (1) - (storageJson \ "storageLevel").extract[String] should be (StorageLevels.DISK_ONLY.description) + (storageJson.children.head \ "storageLevel").extract[String] should be ( + StorageLevels.DISK_ONLY.description) val rddJson = getJson(ui, "storage/rdd/0") (rddJson \ "storageLevel").extract[String] should be (StorageLevels.DISK_ONLY.description) @@ -150,7 +151,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B val updatedStorageJson = getJson(ui, "storage/rdd") updatedStorageJson.children.length should be (1) - (updatedStorageJson \ "storageLevel").extract[String] should be ( + (updatedStorageJson.children.head \ "storageLevel").extract[String] should be ( StorageLevels.MEMORY_ONLY.description) val updatedRddJson = getJson(ui, "storage/rdd/0") (updatedRddJson \ "storageLevel").extract[String] should be ( @@ -204,7 +205,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } val stageJson = getJson(sc.ui.get, "stages") stageJson.children.length should be (1) - (stageJson \ "status").extract[String] should be (StageStatus.FAILED.name()) + (stageJson.children.head \ "status").extract[String] should be (StageStatus.FAILED.name()) // Regression test for SPARK-2105 class NotSerializable @@ -325,11 +326,11 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B find(cssSelector(".progress-cell .progress")).get.text should be ("2/2 (1 failed)") } val jobJson = getJson(sc.ui.get, "jobs") - (jobJson \ "numTasks").extract[Int]should be (2) - (jobJson \ "numCompletedTasks").extract[Int] should be (3) - (jobJson \ "numFailedTasks").extract[Int] should be (1) - (jobJson \ "numCompletedStages").extract[Int] should be (2) - (jobJson \ "numFailedStages").extract[Int] should be (1) + (jobJson \\ "numTasks").extract[Int]should be (2) + (jobJson \\ "numCompletedTasks").extract[Int] should be (3) + (jobJson \\ "numFailedTasks").extract[Int] should be (1) + (jobJson \\ "numCompletedStages").extract[Int] should be (2) + (jobJson \\ "numFailedStages").extract[Int] should be (1) val stageJson = getJson(sc.ui.get, "stages") for { @@ -656,7 +657,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.ui.get.webUrl + "/api/v1/applications")) val appListJsonAst = JsonMethods.parse(appListRawJson) appListJsonAst.children.length should be (1) - val attempts = (appListJsonAst \ "attempts").children + val attempts = (appListJsonAst.children.head \ "attempts").children attempts.size should be (1) (attempts(0) \ "completed").extract[Boolean] should be (false) parseDate(attempts(0) \ "startTime") should be (sc.startTime) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index a7fce2ede0ea5..2a298769be44c 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -122,9 +122,10 @@ jline-2.12.1.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar -json4s-ast_2.11-3.2.11.jar -json4s-core_2.11-3.2.11.jar -json4s-jackson_2.11-3.2.11.jar +json4s-ast_2.11-3.5.3.jar +json4s-core_2.11-3.5.3.jar +json4s-jackson_2.11-3.5.3.jar +json4s-scalap_2.11-3.5.3.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar @@ -167,7 +168,6 @@ scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.5.jar -scalap-2.11.8.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 94b2e98d85e74..abee326f283ab 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -122,9 +122,10 @@ jline-2.12.1.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar -json4s-ast_2.11-3.2.11.jar -json4s-core_2.11-3.2.11.jar -json4s-jackson_2.11-3.2.11.jar +json4s-ast_2.11-3.5.3.jar +json4s-core_2.11-3.5.3.jar +json4s-jackson_2.11-3.5.3.jar +json4s-scalap_2.11-3.5.3.jar jsp-api-2.1.jar jsr305-1.3.9.jar jta-1.1.jar @@ -168,7 +169,6 @@ scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.5.jar -scalap-2.11.8.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar diff --git a/pom.xml b/pom.xml index d14594aa4ccb0..666d5d7169a15 100644 --- a/pom.xml +++ b/pom.xml @@ -705,7 +705,13 @@ org.json4s json4s-jackson_${scala.binary.version} - 3.2.11 + 3.5.3 + + + com.fasterxml.jackson.core + * + + org.scala-lang @@ -732,11 +738,6 @@ scala-parser-combinators_${scala.binary.version} 1.0.4 - - org.scala-lang - scalap - ${scala.version} - jline From fc6fe8a1d0f161c4788f3db94de49a8669ba3bcc Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 13 Jan 2018 10:01:44 -0600 Subject: [PATCH 0083/2461] [SPARK-22870][CORE] Dynamic allocation should allow 0 idle time ## What changes were proposed in this pull request? This pr to make `0` as a valid value for `spark.dynamicAllocation.executorIdleTimeout`. For details, see the jira description: https://issues.apache.org/jira/browse/SPARK-22870. ## How was this patch tested? N/A Author: Yuming Wang Author: Yuming Wang Closes #20080 from wangyum/SPARK-22870. --- .../scala/org/apache/spark/ExecutorAllocationManager.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 2e00dc8b49dd5..6c59038f2a6c1 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -195,8 +195,11 @@ private[spark] class ExecutorAllocationManager( throw new SparkException( "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout must be > 0!") } - if (executorIdleTimeoutS <= 0) { - throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be > 0!") + if (executorIdleTimeoutS < 0) { + throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be >= 0!") + } + if (cachedExecutorIdleTimeoutS < 0) { + throw new SparkException("spark.dynamicAllocation.cachedExecutorIdleTimeout must be >= 0!") } // Require external shuffle service for dynamic allocation // Otherwise, we may lose shuffle files when killing executors From bd4a21b4820c4ebaf750131574a6b2eeea36907e Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Sun, 14 Jan 2018 02:28:57 +0800 Subject: [PATCH 0084/2461] [SPARK-23036][SQL][TEST] Add withGlobalTempView for testing ## What changes were proposed in this pull request? Add withGlobalTempView when create global temp view, like withTempView and withView. And correct some improper usage. Please see jira. There are other similar place like that. I will fix it if community need. Please confirm it. ## How was this patch tested? no new test. Author: xubo245 <601450868@qq.com> Closes #20228 from xubo245/DropTempView. --- .../sql/execution/GlobalTempViewSuite.scala | 55 ++++++++----------- .../spark/sql/execution/SQLViewSuite.scala | 34 +++++++----- .../apache/spark/sql/test/SQLTestUtils.scala | 21 +++++-- 3 files changed, 59 insertions(+), 51 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala index cc943e0356f2a..dcc6fa6403f31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -36,7 +36,7 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { test("basic semantic") { val expectedErrorMsg = "not found" - try { + withGlobalTempView("src") { sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 'a'") // If there is no database in table name, we should try local temp view first, if not found, @@ -79,19 +79,15 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { // We can also use Dataset API to replace global temp view Seq(2 -> "b").toDF("i", "j").createOrReplaceGlobalTempView("src") checkAnswer(spark.table(s"$globalTempDB.src"), Row(2, "b")) - } finally { - spark.catalog.dropGlobalTempView("src") } } test("global temp view is shared among all sessions") { - try { + withGlobalTempView("src") { sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 2") checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, 2)) val newSession = spark.newSession() checkAnswer(newSession.table(s"$globalTempDB.src"), Row(1, 2)) - } finally { - spark.catalog.dropGlobalTempView("src") } } @@ -105,27 +101,25 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { test("CREATE GLOBAL TEMP VIEW USING") { withTempPath { path => - try { + withGlobalTempView("src") { Seq(1 -> "a").toDF("i", "j").write.parquet(path.getAbsolutePath) sql(s"CREATE GLOBAL TEMP VIEW src USING parquet OPTIONS (PATH '${path.toURI}')") checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) sql(s"INSERT INTO $globalTempDB.src SELECT 2, 'b'") checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a") :: Row(2, "b") :: Nil) - } finally { - spark.catalog.dropGlobalTempView("src") } } } test("CREATE TABLE LIKE should work for global temp view") { - try { - sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1 AS a, '2' AS b") - sql(s"CREATE TABLE cloned LIKE $globalTempDB.src") - val tableMeta = spark.sessionState.catalog.getTableMetadata(TableIdentifier("cloned")) - assert(tableMeta.schema == new StructType().add("a", "int", false).add("b", "string", false)) - } finally { - spark.catalog.dropGlobalTempView("src") - sql("DROP TABLE default.cloned") + withTable("cloned") { + withGlobalTempView("src") { + sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1 AS a, '2' AS b") + sql(s"CREATE TABLE cloned LIKE $globalTempDB.src") + val tableMeta = spark.sessionState.catalog.getTableMetadata(TableIdentifier("cloned")) + assert(tableMeta.schema == new StructType() + .add("a", "int", false).add("b", "string", false)) + } } } @@ -146,26 +140,25 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { } test("should lookup global temp view if and only if global temp db is specified") { - try { - sql("CREATE GLOBAL TEMP VIEW same_name AS SELECT 3, 4") - sql("CREATE TEMP VIEW same_name AS SELECT 1, 2") + withTempView("same_name") { + withGlobalTempView("same_name") { + sql("CREATE GLOBAL TEMP VIEW same_name AS SELECT 3, 4") + sql("CREATE TEMP VIEW same_name AS SELECT 1, 2") - checkAnswer(sql("SELECT * FROM same_name"), Row(1, 2)) + checkAnswer(sql("SELECT * FROM same_name"), Row(1, 2)) - // we never lookup global temp views if database is not specified in table name - spark.catalog.dropTempView("same_name") - intercept[AnalysisException](sql("SELECT * FROM same_name")) + // we never lookup global temp views if database is not specified in table name + spark.catalog.dropTempView("same_name") + intercept[AnalysisException](sql("SELECT * FROM same_name")) - // Use qualified name to lookup a global temp view. - checkAnswer(sql(s"SELECT * FROM $globalTempDB.same_name"), Row(3, 4)) - } finally { - spark.catalog.dropTempView("same_name") - spark.catalog.dropGlobalTempView("same_name") + // Use qualified name to lookup a global temp view. + checkAnswer(sql(s"SELECT * FROM $globalTempDB.same_name"), Row(3, 4)) + } } } test("public Catalog should recognize global temp view") { - try { + withGlobalTempView("src") { sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 2") assert(spark.catalog.tableExists(globalTempDB, "src")) @@ -175,8 +168,6 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { description = null, tableType = "TEMPORARY", isTemporary = true).toString) - } finally { - spark.catalog.dropGlobalTempView("src") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 08a4a21b20f61..8c55758cfe38d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -69,21 +69,25 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } test("create a permanent view on a temp view") { - withView("jtv1", "temp_jtv1", "global_temp_jtv1") { - sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jt WHERE id > 3") - var e = intercept[AnalysisException] { - sql("CREATE VIEW jtv1 AS SELECT * FROM temp_jtv1 WHERE id < 6") - }.getMessage - assert(e.contains("Not allowed to create a permanent view `jtv1` by " + - "referencing a temporary view `temp_jtv1`")) - - val globalTempDB = spark.sharedState.globalTempViewManager.database - sql("CREATE GLOBAL TEMP VIEW global_temp_jtv1 AS SELECT * FROM jt WHERE id > 0") - e = intercept[AnalysisException] { - sql(s"CREATE VIEW jtv1 AS SELECT * FROM $globalTempDB.global_temp_jtv1 WHERE id < 6") - }.getMessage - assert(e.contains(s"Not allowed to create a permanent view `jtv1` by referencing " + - s"a temporary view `global_temp`.`global_temp_jtv1`")) + withView("jtv1") { + withTempView("temp_jtv1") { + withGlobalTempView("global_temp_jtv1") { + sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jt WHERE id > 3") + var e = intercept[AnalysisException] { + sql("CREATE VIEW jtv1 AS SELECT * FROM temp_jtv1 WHERE id < 6") + }.getMessage + assert(e.contains("Not allowed to create a permanent view `jtv1` by " + + "referencing a temporary view `temp_jtv1`")) + + val globalTempDB = spark.sharedState.globalTempViewManager.database + sql("CREATE GLOBAL TEMP VIEW global_temp_jtv1 AS SELECT * FROM jt WHERE id > 0") + e = intercept[AnalysisException] { + sql(s"CREATE VIEW jtv1 AS SELECT * FROM $globalTempDB.global_temp_jtv1 WHERE id < 6") + }.getMessage + assert(e.contains(s"Not allowed to create a permanent view `jtv1` by referencing " + + s"a temporary view `global_temp`.`global_temp_jtv1`")) + } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 904f9f2ad0b22..bc4a120f7042f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -254,13 +254,26 @@ private[sql] trait SQLTestUtilsBase } /** - * Drops temporary table `tableName` after calling `f`. + * Drops temporary view `viewNames` after calling `f`. */ - protected def withTempView(tableNames: String*)(f: => Unit): Unit = { + protected def withTempView(viewNames: String*)(f: => Unit): Unit = { try f finally { // If the test failed part way, we don't want to mask the failure by failing to remove - // temp tables that never got created. - try tableNames.foreach(spark.catalog.dropTempView) catch { + // temp views that never got created. + try viewNames.foreach(spark.catalog.dropTempView) catch { + case _: NoSuchTableException => + } + } + } + + /** + * Drops global temporary view `viewNames` after calling `f`. + */ + protected def withGlobalTempView(viewNames: String*)(f: => Unit): Unit = { + try f finally { + // If the test failed part way, we don't want to mask the failure by failing to remove + // global temp views that never got created. + try viewNames.foreach(spark.catalog.dropGlobalTempView) catch { case _: NoSuchTableException => } } From ba891ec993c616dc4249fc786c56ea82ed04a827 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Sun, 14 Jan 2018 02:36:32 +0800 Subject: [PATCH 0085/2461] [SPARK-22790][SQL] add a configurable factor to describe HadoopFsRelation's size ## What changes were proposed in this pull request? as per discussion in https://github.com/apache/spark/pull/19864#discussion_r156847927 the current HadoopFsRelation is purely based on the underlying file size which is not accurate and makes the execution vulnerable to errors like OOM Users can enable CBO with the functionalities in https://github.com/apache/spark/pull/19864 to avoid this issue This JIRA proposes to add a configurable factor to sizeInBytes method in HadoopFsRelation class so that users can mitigate this problem without CBO ## How was this patch tested? Existing tests Author: CodingCat Author: Nan Zhu Closes #20072 from CodingCat/SPARK-22790. --- .../apache/spark/sql/internal/SQLConf.scala | 13 +++++- .../datasources/HadoopFsRelation.scala | 6 ++- .../datasources/HadoopFsRelationSuite.scala | 41 +++++++++++++++++++ 3 files changed, 58 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 36e802a9faa6f..6746fbcaf2483 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -249,7 +249,7 @@ object SQLConf { val CONSTRAINT_PROPAGATION_ENABLED = buildConf("spark.sql.constraintPropagation.enabled") .internal() .doc("When true, the query optimizer will infer and propagate data constraints in the query " + - "plan to optimize them. Constraint propagation can sometimes be computationally expensive" + + "plan to optimize them. Constraint propagation can sometimes be computationally expensive " + "for certain kinds of query plans (such as those with a large number of predicates and " + "aliases) which might negatively impact overall runtime.") .booleanConf @@ -263,6 +263,15 @@ object SQLConf { .booleanConf .createWithDefault(false) + val FILE_COMRESSION_FACTOR = buildConf("spark.sql.sources.fileCompressionFactor") + .internal() + .doc("When estimating the output data size of a table scan, multiply the file size with this " + + "factor as the estimated data size, in case the data is compressed in the file and lead to" + + " a heavily underestimated result.") + .doubleConf + .checkValue(_ > 0, "the value of fileDataSizeFactor must be larger than 0") + .createWithDefault(1.0) + val PARQUET_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.parquet.mergeSchema") .doc("When true, the Parquet data source merges schemas collected from all data files, " + "otherwise the schema is picked from the summary file or a random data file " + @@ -1255,6 +1264,8 @@ class SQLConf extends Serializable with Logging { def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS) + def fileCompressionFactor: Double = getConf(FILE_COMRESSION_FACTOR) + def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader) /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index 89d8a85a9cbd2..6b34638529770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -82,7 +82,11 @@ case class HadoopFsRelation( } } - override def sizeInBytes: Long = location.sizeInBytes + override def sizeInBytes: Long = { + val compressionFactor = sqlContext.conf.fileCompressionFactor + (location.sizeInBytes * compressionFactor).toLong + } + override def inputFiles: Array[String] = location.inputFiles } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala index caf03885e3873..c1f2c18d1417d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources import java.io.{File, FilenameFilter} import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.test.SharedSQLContext class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { @@ -39,4 +40,44 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { assert(df.queryExecution.logical.stats.sizeInBytes === BigInt(totalSize)) } } + + test("SPARK-22790: spark.sql.sources.compressionFactor takes effect") { + import testImplicits._ + Seq(1.0, 0.5).foreach { compressionFactor => + withSQLConf("spark.sql.sources.fileCompressionFactor" -> compressionFactor.toString, + "spark.sql.autoBroadcastJoinThreshold" -> "400") { + withTempPath { workDir => + // the file size is 740 bytes + val workDirPath = workDir.getAbsolutePath + val data1 = Seq(100, 200, 300, 400).toDF("count") + data1.write.parquet(workDirPath + "/data1") + val df1FromFile = spark.read.parquet(workDirPath + "/data1") + val data2 = Seq(100, 200, 300, 400).toDF("count") + data2.write.parquet(workDirPath + "/data2") + val df2FromFile = spark.read.parquet(workDirPath + "/data2") + val joinedDF = df1FromFile.join(df2FromFile, Seq("count")) + if (compressionFactor == 0.5) { + val bJoinExec = joinedDF.queryExecution.executedPlan.collect { + case bJoin: BroadcastHashJoinExec => bJoin + } + assert(bJoinExec.nonEmpty) + val smJoinExec = joinedDF.queryExecution.executedPlan.collect { + case smJoin: SortMergeJoinExec => smJoin + } + assert(smJoinExec.isEmpty) + } else { + // compressionFactor is 1.0 + val bJoinExec = joinedDF.queryExecution.executedPlan.collect { + case bJoin: BroadcastHashJoinExec => bJoin + } + assert(bJoinExec.isEmpty) + val smJoinExec = joinedDF.queryExecution.executedPlan.collect { + case smJoin: SortMergeJoinExec => smJoin + } + assert(smJoinExec.nonEmpty) + } + } + } + } + } } From 0066d6f6fa604817468471832968d4339f71c5cb Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 14 Jan 2018 05:39:38 +0800 Subject: [PATCH 0086/2461] [SPARK-21213][SQL][FOLLOWUP] Use compatible types for comparisons in compareAndGetNewStats ## What changes were proposed in this pull request? This pr fixed code to compare values in `compareAndGetNewStats`. The test below fails in the current master; ``` val oldStats2 = CatalogStatistics(sizeInBytes = BigInt(Long.MaxValue) * 2) val newStats5 = CommandUtils.compareAndGetNewStats( Some(oldStats2), newTotalSize = BigInt(Long.MaxValue) * 2, None) assert(newStats5.isEmpty) ``` ## How was this patch tested? Added some tests in `CommandUtilsSuite`. Author: Takeshi Yamamuro Closes #20245 from maropu/SPARK-21213-FOLLOWUP. --- .../sql/execution/command/CommandUtils.scala | 4 +- .../execution/command/CommandUtilsSuite.scala | 56 +++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/CommandUtilsSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index 1a0d67fc71fbc..c27048626c8eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -116,8 +116,8 @@ object CommandUtils extends Logging { oldStats: Option[CatalogStatistics], newTotalSize: BigInt, newRowCount: Option[BigInt]): Option[CatalogStatistics] = { - val oldTotalSize = oldStats.map(_.sizeInBytes.toLong).getOrElse(-1L) - val oldRowCount = oldStats.flatMap(_.rowCount.map(_.toLong)).getOrElse(-1L) + val oldTotalSize = oldStats.map(_.sizeInBytes).getOrElse(BigInt(-1)) + val oldRowCount = oldStats.flatMap(_.rowCount).getOrElse(BigInt(-1)) var newStats: Option[CatalogStatistics] = None if (newTotalSize >= 0 && newTotalSize != oldTotalSize) { newStats = Some(CatalogStatistics(sizeInBytes = newTotalSize)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CommandUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CommandUtilsSuite.scala new file mode 100644 index 0000000000000..f3e15189a6418 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CommandUtilsSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.catalog.CatalogStatistics + +class CommandUtilsSuite extends SparkFunSuite { + + test("Check if compareAndGetNewStats returns correct results") { + val oldStats1 = CatalogStatistics(sizeInBytes = 10, rowCount = Some(100)) + val newStats1 = CommandUtils.compareAndGetNewStats( + Some(oldStats1), newTotalSize = 10, newRowCount = Some(100)) + assert(newStats1.isEmpty) + val newStats2 = CommandUtils.compareAndGetNewStats( + Some(oldStats1), newTotalSize = -1, newRowCount = None) + assert(newStats2.isEmpty) + val newStats3 = CommandUtils.compareAndGetNewStats( + Some(oldStats1), newTotalSize = 20, newRowCount = Some(-1)) + assert(newStats3.isDefined) + newStats3.foreach { stat => + assert(stat.sizeInBytes === 20) + assert(stat.rowCount.isEmpty) + } + val newStats4 = CommandUtils.compareAndGetNewStats( + Some(oldStats1), newTotalSize = -1, newRowCount = Some(200)) + assert(newStats4.isDefined) + newStats4.foreach { stat => + assert(stat.sizeInBytes === 10) + assert(stat.rowCount.isDefined && stat.rowCount.get === 200) + } + } + + test("Check if compareAndGetNewStats can handle large values") { + // Tests for large values + val oldStats2 = CatalogStatistics(sizeInBytes = BigInt(Long.MaxValue) * 2) + val newStats5 = CommandUtils.compareAndGetNewStats( + Some(oldStats2), newTotalSize = BigInt(Long.MaxValue) * 2, None) + assert(newStats5.isEmpty) + } +} From afae8f2bc82597593595af68d1aa2d802210ea8b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 14 Jan 2018 11:26:49 +0900 Subject: [PATCH 0087/2461] [SPARK-22959][PYTHON] Configuration to select the modules for daemon and worker in PySpark ## What changes were proposed in this pull request? We are now forced to use `pyspark/daemon.py` and `pyspark/worker.py` in PySpark. This doesn't allow a custom modification for it (well, maybe we can still do this in a super hacky way though, for example, setting Python executable that has the custom modification). Because of this, for example, it's sometimes hard to debug what happens inside Python worker processes. This is actually related with [SPARK-7721](https://issues.apache.org/jira/browse/SPARK-7721) too as somehow Coverage is unable to detect the coverage from `os.fork`. If we have some custom fixes to force the coverage, it works fine. This is also related with [SPARK-20368](https://issues.apache.org/jira/browse/SPARK-20368). This JIRA describes Sentry support which (roughly) needs some changes within worker side. With this configuration advanced users will be able to do a lot of pluggable workarounds and we can meet such potential needs in the future. As an example, let's say if I configure the module `coverage_daemon` and had `coverage_daemon.py` in the python path: ```python import os from pyspark import daemon if "COVERAGE_PROCESS_START" in os.environ: from pyspark.worker import main def _cov_wrapped(*args, **kwargs): import coverage cov = coverage.coverage( config_file=os.environ["COVERAGE_PROCESS_START"]) cov.start() try: main(*args, **kwargs) finally: cov.stop() cov.save() daemon.worker_main = _cov_wrapped if __name__ == '__main__': daemon.manager() ``` I can track the coverages in worker side too. More importantly, we can leave the main code intact but allow some workarounds. ## How was this patch tested? Manually tested. Author: hyukjinkwon Closes #20151 from HyukjinKwon/configuration-daemon-worker. --- .../api/python/PythonWorkerFactory.scala | 41 +++++++++++++++---- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index f53c6178047f5..30976ac752a8a 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -34,10 +34,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String import PythonWorkerFactory._ - // Because forking processes from Java is expensive, we prefer to launch a single Python daemon - // (pyspark/daemon.py) and tell it to fork new workers for our tasks. This daemon currently - // only works on UNIX-based systems now because it uses signals for child management, so we can - // also fall back to launching workers (pyspark/worker.py) directly. + // Because forking processes from Java is expensive, we prefer to launch a single Python daemon, + // pyspark/daemon.py (by default) and tell it to fork new workers for our tasks. This daemon + // currently only works on UNIX-based systems now because it uses signals for child management, + // so we can also fall back to launching workers, pyspark/worker.py (by default) directly. val useDaemon = { val useDaemonEnabled = SparkEnv.get.conf.getBoolean("spark.python.use.daemon", true) @@ -45,6 +45,28 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String !System.getProperty("os.name").startsWith("Windows") && useDaemonEnabled } + // WARN: Both configurations, 'spark.python.daemon.module' and 'spark.python.worker.module' are + // for very advanced users and they are experimental. This should be considered + // as expert-only option, and shouldn't be used before knowing what it means exactly. + + // This configuration indicates the module to run the daemon to execute its Python workers. + val daemonModule = SparkEnv.get.conf.getOption("spark.python.daemon.module").map { value => + logInfo( + s"Python daemon module in PySpark is set to [$value] in 'spark.python.daemon.module', " + + "using this to start the daemon up. Note that this configuration only has an effect when " + + "'spark.python.use.daemon' is enabled and the platform is not Windows.") + value + }.getOrElse("pyspark.daemon") + + // This configuration indicates the module to run each Python worker. + val workerModule = SparkEnv.get.conf.getOption("spark.python.worker.module").map { value => + logInfo( + s"Python worker module in PySpark is set to [$value] in 'spark.python.worker.module', " + + "using this to start the worker up. Note that this configuration only has an effect when " + + "'spark.python.use.daemon' is disabled or the platform is Windows.") + value + }.getOrElse("pyspark.worker") + var daemon: Process = null val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) var daemonPort: Int = 0 @@ -74,8 +96,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } /** - * Connect to a worker launched through pyspark/daemon.py, which forks python processes itself - * to avoid the high cost of forking from Java. This currently only works on UNIX-based systems. + * Connect to a worker launched through pyspark/daemon.py (by default), which forks python + * processes itself to avoid the high cost of forking from Java. This currently only works + * on UNIX-based systems. */ private def createThroughDaemon(): Socket = { @@ -108,7 +131,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } /** - * Launch a worker by executing worker.py directly and telling it to connect to us. + * Launch a worker by executing worker.py (by default) directly and telling it to connect to us. */ private def createSimpleWorker(): Socket = { var serverSocket: ServerSocket = null @@ -116,7 +139,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) // Create and start the worker - val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.worker")) + val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", workerModule)) val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) @@ -159,7 +182,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String try { // Create and start the daemon - val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.daemon")) + val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", daemonModule)) val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) From c3548d11c3c57e8f2c6ebd9d2d6a3924ddcd3cba Mon Sep 17 00:00:00 2001 From: foxish Date: Sat, 13 Jan 2018 21:34:28 -0800 Subject: [PATCH 0088/2461] [SPARK-23063][K8S] K8s changes for publishing scripts (and a couple of other misses) ## What changes were proposed in this pull request? Including the `-Pkubernetes` flag in a few places it was missed. ## How was this patch tested? checkstyle, mima through manual tests. Author: foxish Closes #20256 from foxish/SPARK-23063. --- dev/create-release/release-build.sh | 4 ++-- dev/create-release/releaseutils.py | 2 ++ dev/deps/spark-deps-hadoop-2.6 | 11 +++++++++++ dev/deps/spark-deps-hadoop-2.7 | 11 +++++++++++ dev/lint-java | 2 +- dev/mima | 2 +- dev/scalastyle | 1 + dev/test-dependencies.sh | 2 +- 8 files changed, 30 insertions(+), 5 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index c71137468054f..a3579f21fc539 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -92,9 +92,9 @@ MVN="build/mvn --force" # Hive-specific profiles for some builds HIVE_PROFILES="-Phive -Phive-thriftserver" # Profiles for publishing snapshots and release to Maven Central -PUBLISH_PROFILES="-Pmesos -Pyarn -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" +PUBLISH_PROFILES="-Pmesos -Pyarn -Pkubernetes -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" # Profiles for building binary releases -BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pflume -Psparkr" +BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pkubernetes -Pflume -Psparkr" # Scala 2.11 only profiles for some builds SCALA_2_11_PROFILES="-Pkafka-0-8" # Scala 2.12 only profiles for some builds diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index 730138195e5fe..32f6cbb29f0be 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -185,6 +185,8 @@ def get_commits(tag): "graphx": "GraphX", "input/output": CORE_COMPONENT, "java api": "Java API", + "k8s": "Kubernetes", + "kubernetes": "Kubernetes", "mesos": "Mesos", "ml": "MLlib", "mllib": "MLlib", diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 2a298769be44c..48e54568e6fc6 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -17,6 +17,7 @@ arpack_combined_all-0.1.jar arrow-format-0.8.0.jar arrow-memory-0.8.0.jar arrow-vector-0.8.0.jar +automaton-1.11-8.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -60,6 +61,7 @@ datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar flatbuffers-1.2.0-3f79e055.jar +generex-1.0.1.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -91,8 +93,10 @@ jackson-annotations-2.6.7.jar jackson-core-2.6.7.jar jackson-core-asl-1.9.13.jar jackson-databind-2.6.7.1.jar +jackson-dataformat-yaml-2.6.7.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar +jackson-module-jaxb-annotations-2.6.7.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar @@ -131,10 +135,13 @@ jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar kryo-shaded-3.0.3.jar +kubernetes-client-3.0.0.jar +kubernetes-model-2.0.0.jar leveldbjni-all-1.8.jar libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar +logging-interceptor-3.8.1.jar lz4-java-1.4.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar @@ -147,6 +154,8 @@ minlog-1.3.0.jar netty-3.9.9.Final.jar netty-all-4.1.17.Final.jar objenesis-2.1.jar +okhttp-3.8.1.jar +okio-1.13.0.jar opencsv-2.3.jar orc-core-1.4.1-nohive.jar orc-mapreduce-1.4.1-nohive.jar @@ -171,6 +180,7 @@ scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar +snakeyaml-1.15.jar snappy-0.2.jar snappy-java-1.1.2.6.jar spire-macros_2.11-0.13.0.jar @@ -186,5 +196,6 @@ xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.0.jar +zjsonpatch-0.3.0.jar zookeeper-3.4.6.jar zstd-jni-1.3.2-2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index abee326f283ab..1807a77900e52 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -17,6 +17,7 @@ arpack_combined_all-0.1.jar arrow-format-0.8.0.jar arrow-memory-0.8.0.jar arrow-vector-0.8.0.jar +automaton-1.11-8.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -60,6 +61,7 @@ datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar flatbuffers-1.2.0-3f79e055.jar +generex-1.0.1.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -91,8 +93,10 @@ jackson-annotations-2.6.7.jar jackson-core-2.6.7.jar jackson-core-asl-1.9.13.jar jackson-databind-2.6.7.1.jar +jackson-dataformat-yaml-2.6.7.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar +jackson-module-jaxb-annotations-2.6.7.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar @@ -132,10 +136,13 @@ jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar kryo-shaded-3.0.3.jar +kubernetes-client-3.0.0.jar +kubernetes-model-2.0.0.jar leveldbjni-all-1.8.jar libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar +logging-interceptor-3.8.1.jar lz4-java-1.4.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar @@ -148,6 +155,8 @@ minlog-1.3.0.jar netty-3.9.9.Final.jar netty-all-4.1.17.Final.jar objenesis-2.1.jar +okhttp-3.8.1.jar +okio-1.13.0.jar opencsv-2.3.jar orc-core-1.4.1-nohive.jar orc-mapreduce-1.4.1-nohive.jar @@ -172,6 +181,7 @@ scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar +snakeyaml-1.15.jar snappy-0.2.jar snappy-java-1.1.2.6.jar spire-macros_2.11-0.13.0.jar @@ -187,5 +197,6 @@ xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.0.jar +zjsonpatch-0.3.0.jar zookeeper-3.4.6.jar zstd-jni-1.3.2-2.jar diff --git a/dev/lint-java b/dev/lint-java index c2e80538ef2a5..1f0b0c8379ed0 100755 --- a/dev/lint-java +++ b/dev/lint-java @@ -20,7 +20,7 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" -ERRORS=$($SCRIPT_DIR/../build/mvn -Pkinesis-asl -Pmesos -Pyarn -Phive -Phive-thriftserver checkstyle:check | grep ERROR) +ERRORS=$($SCRIPT_DIR/../build/mvn -Pkinesis-asl -Pmesos -Pkubernetes -Pyarn -Phive -Phive-thriftserver checkstyle:check | grep ERROR) if test ! -z "$ERRORS"; then echo -e "Checkstyle checks failed at following occurrences:\n$ERRORS" diff --git a/dev/mima b/dev/mima index 1e3ca9700bc07..cd2694ff4d3de 100755 --- a/dev/mima +++ b/dev/mima @@ -24,7 +24,7 @@ set -e FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" -SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pyarn -Pflume -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" +SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pkubernetes -Pyarn -Pflume -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | tail -n1)" OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)" diff --git a/dev/scalastyle b/dev/scalastyle index 89ecc8abd6f8c..b8053df05fa2b 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -24,6 +24,7 @@ ERRORS=$(echo -e "q\n" \ -Pkinesis-asl \ -Pmesos \ -Pkafka-0-8 \ + -Pkubernetes \ -Pyarn \ -Pflume \ -Phive \ diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 58b295d4f6e00..3bf7618e1ea96 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -29,7 +29,7 @@ export LC_ALL=C # TODO: This would be much nicer to do in SBT, once SBT supports Maven-style resolution. # NOTE: These should match those in the release publishing script -HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pyarn -Pflume -Phive" +HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pkubernetes -Pyarn -Pflume -Phive" MVN="build/mvn" HADOOP_PROFILES=( hadoop-2.6 From 7a3d0aad2b89aef54f7dd580397302e9ff984d9d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 13 Jan 2018 23:26:12 -0800 Subject: [PATCH 0089/2461] [SPARK-23038][TEST] Update docker/spark-test (JDK/OS) ## What changes were proposed in this pull request? This PR aims to update the followings in `docker/spark-test`. - JDK7 -> JDK8 Spark 2.2+ supports JDK8 only. - Ubuntu 12.04.5 LTS(precise) -> Ubuntu 16.04.3 LTS(xeniel) The end of life of `precise` was April 28, 2017. ## How was this patch tested? Manual. * Master ``` $ cd external/docker $ ./build $ export SPARK_HOME=... $ docker run -v $SPARK_HOME:/opt/spark spark-test-master CONTAINER_IP=172.17.0.3 ... 18/01/11 06:50:25 INFO MasterWebUI: Bound MasterWebUI to 172.17.0.3, and started at http://172.17.0.3:8080 18/01/11 06:50:25 INFO Utils: Successfully started service on port 6066. 18/01/11 06:50:25 INFO StandaloneRestServer: Started REST server for submitting applications on port 6066 18/01/11 06:50:25 INFO Master: I have been elected leader! New state: ALIVE ``` * Slave ``` $ docker run -v $SPARK_HOME:/opt/spark spark-test-worker spark://172.17.0.3:7077 CONTAINER_IP=172.17.0.4 ... 18/01/11 06:51:54 INFO Worker: Successfully registered with master spark://172.17.0.3:7077 ``` After slave starts, master will show ``` 18/01/11 06:51:54 INFO Master: Registering worker 172.17.0.4:8888 with 4 cores, 1024.0 MB RAM ``` Author: Dongjoon Hyun Closes #20230 from dongjoon-hyun/SPARK-23038. --- external/docker/spark-test/base/Dockerfile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/external/docker/spark-test/base/Dockerfile b/external/docker/spark-test/base/Dockerfile index 5a95a9387c310..c70cd71367679 100644 --- a/external/docker/spark-test/base/Dockerfile +++ b/external/docker/spark-test/base/Dockerfile @@ -15,14 +15,14 @@ # limitations under the License. # -FROM ubuntu:precise +FROM ubuntu:xenial # Upgrade package index -# install a few other useful packages plus Open Jdk 7 +# install a few other useful packages plus Open Jdk 8 # Remove unneeded /var/lib/apt/lists/* after install to reduce the # docker image size (by ~30MB) RUN apt-get update && \ - apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server && \ + apt-get install -y less openjdk-8-jre-headless iproute2 vim-tiny sudo openssh-server && \ rm -rf /var/lib/apt/lists/* ENV SCALA_VERSION 2.11.8 From 66738d29c59871b29d26fc3756772b95ef536248 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 14 Jan 2018 19:43:10 +0900 Subject: [PATCH 0090/2461] [SPARK-23069][DOCS][SPARKR] fix R doc for describe missing text ## What changes were proposed in this pull request? fix doc truncated ## How was this patch tested? manually Author: Felix Cheung Closes #20263 from felixcheung/r23docfix. --- R/pkg/R/DataFrame.R | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 9956f7eda91e6..6caa125e1e14a 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3054,10 +3054,10 @@ setMethod("describe", #' \item stddev #' \item min #' \item max -#' \item arbitrary approximate percentiles specified as a percentage (eg, "75%") +#' \item arbitrary approximate percentiles specified as a percentage (eg, "75\%") #' } #' If no statistics are given, this function computes count, mean, stddev, min, -#' approximate quartiles (percentiles at 25%, 50%, and 75%), and max. +#' approximate quartiles (percentiles at 25\%, 50\%, and 75\%), and max. #' This function is meant for exploratory data analysis, as we make no guarantee about the #' backward compatibility of the schema of the resulting Dataset. If you want to #' programmatically compute summary statistics, use the \code{agg} function instead. @@ -4019,9 +4019,9 @@ setMethod("broadcast", #' #' Spark will use this watermark for several purposes: #' \itemize{ -#' \item{-} To know when a given time window aggregation can be finalized and thus can be emitted +#' \item To know when a given time window aggregation can be finalized and thus can be emitted #' when using output modes that do not allow updates. -#' \item{-} To minimize the amount of state that we need to keep for on-going aggregations. +#' \item To minimize the amount of state that we need to keep for on-going aggregations. #' } #' The current watermark is computed by looking at the \code{MAX(eventTime)} seen across #' all of the partitions in the query minus a user specified \code{delayThreshold}. Due to the cost From 990f05c80347c6eec2ee06823cff587c9ea90b49 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 14 Jan 2018 22:26:21 +0800 Subject: [PATCH 0091/2461] [SPARK-23021][SQL] AnalysisBarrier should override innerChildren to print correct explain output ## What changes were proposed in this pull request? `AnalysisBarrier` in the current master cuts off explain results for parsed logical plans; ``` scala> Seq((1, 1)).toDF("a", "b").groupBy("a").count().sample(0.1).explain(true) == Parsed Logical Plan == Sample 0.0, 0.1, false, -7661439431999668039 +- AnalysisBarrier Aggregate [a#5], [a#5, count(1) AS count#14L] ``` To fix this, `AnalysisBarrier` needs to override `innerChildren` and this pr changed the output to; ``` == Parsed Logical Plan == Sample 0.0, 0.1, false, -5086223488015741426 +- AnalysisBarrier +- Aggregate [a#5], [a#5, count(1) AS count#14L] +- Project [_1#2 AS a#5, _2#3 AS b#6] +- LocalRelation [_1#2, _2#3] ``` ## How was this patch tested? Added tests in `DataFrameSuite`. Author: Takeshi Yamamuro Closes #20247 from maropu/SPARK-23021-2. --- .../plans/logical/basicLogicalOperators.scala | 1 + .../sql/hive/execution/HiveExplainSuite.scala | 17 +++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 95e099c340af1..a4fca790dd086 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -903,6 +903,7 @@ case class Deduplicate( * This analysis barrier will be removed at the end of analysis stage. */ case class AnalysisBarrier(child: LogicalPlan) extends LeafNode { + override protected def innerChildren: Seq[LogicalPlan] = Seq(child) override def output: Seq[Attribute] = child.output override def isStreaming: Boolean = child.isStreaming override def doCanonicalize(): LogicalPlan = child.canonicalized diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index dfabf1ec2a22a..a4273de5fe260 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -171,4 +171,21 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto sql("EXPLAIN EXTENDED CODEGEN SELECT 1") } } + + test("SPARK-23021 AnalysisBarrier should not cut off explain output for parsed logical plans") { + val df = Seq((1, 1)).toDF("a", "b").groupBy("a").count().limit(1) + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + df.explain(true) + } + assert(outputStream.toString.replaceAll("""#\d+""", "#0").contains( + s"""== Parsed Logical Plan == + |GlobalLimit 1 + |+- LocalLimit 1 + | +- AnalysisBarrier + | +- Aggregate [a#0], [a#0, count(1) AS count#0L] + | +- Project [_1#0 AS a#0, _2#0 AS b#0] + | +- LocalRelation [_1#0, _2#0] + |""".stripMargin)) + } } From 60eeecd7760aee6ce2fd207c83ae40054eadaf83 Mon Sep 17 00:00:00 2001 From: Sandor Murakozi Date: Sun, 14 Jan 2018 08:32:35 -0600 Subject: [PATCH 0092/2461] [SPARK-23051][CORE] Fix for broken job description in Spark UI ## What changes were proposed in this pull request? In 2.2, Spark UI displayed the stage description if the job description was not set. This functionality was broken, the GUI has shown no description in this case. In addition, the code uses jobName and jobDescription instead of stageName and stageDescription when JobTableRowData is created. In this PR the logic producing values for the job rows was modified to find the latest stage attempt for the job and use that as a fallback if job description was missing. StageName and stageDescription are also set using values from stage and jobName/description is used only as a fallback. ## How was this patch tested? Manual testing of the UI, using the code in the bug report. Author: Sandor Murakozi Closes #20251 from smurakozi/SPARK-23051. --- .../apache/spark/ui/jobs/AllJobsPage.scala | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 37e3b3b304a63..ff916bb6a5759 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -65,12 +65,10 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We }.map { job => val jobId = job.jobId val status = job.status - val displayJobDescription = - if (job.description.isEmpty) { - job.name - } else { - UIUtils.makeDescription(job.description.get, "", plainText = true).text - } + val jobDescription = store.lastStageAttempt(job.stageIds.max).description + val displayJobDescription = jobDescription + .map(UIUtils.makeDescription(_, "", plainText = true).text) + .getOrElse("") val submissionTime = job.submissionTime.get.getTime() val completionTime = job.completionTime.map(_.getTime()).getOrElse(System.currentTimeMillis()) val classNameByStatus = status match { @@ -429,20 +427,23 @@ private[ui] class JobDataSource( val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") val submissionTime = jobData.submissionTime val formattedSubmissionTime = submissionTime.map(UIUtils.formatDate).getOrElse("Unknown") - val jobDescription = UIUtils.makeDescription(jobData.description.getOrElse(""), - basePath, plainText = false) + val lastStageAttempt = store.lastStageAttempt(jobData.stageIds.max) + val lastStageDescription = lastStageAttempt.description.getOrElse("") + + val formattedJobDescription = + UIUtils.makeDescription(lastStageDescription, basePath, plainText = false) val detailUrl = "%s/jobs/job?id=%s".format(basePath, jobData.jobId) new JobTableRowData( jobData, - jobData.name, - jobData.description.getOrElse(jobData.name), + lastStageAttempt.name, + lastStageDescription, duration.getOrElse(-1), formattedDuration, submissionTime.map(_.getTime()).getOrElse(-1L), formattedSubmissionTime, - jobDescription, + formattedJobDescription, detailUrl ) } From 42a1a15d739890bdfbb367ef94198b19e98ffcb7 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Mon, 15 Jan 2018 02:02:49 +0800 Subject: [PATCH 0093/2461] [SPARK-22999][SQL] show databases like command' can remove the like keyword ## What changes were proposed in this pull request? SHOW DATABASES (LIKE pattern = STRING)? Can be like the back increase? When using this command, LIKE keyword can be removed. You can refer to the SHOW TABLES command, SHOW TABLES 'test *' and SHOW TABELS like 'test *' can be used. Similarly SHOW DATABASES 'test *' and SHOW DATABASES like 'test *' can be used. ## How was this patch tested? unit tests manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #20194 from guoxiaolongzte/SPARK-22999. --- .../antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../org/apache/spark/sql/execution/command/DDLSuite.scala | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 6daf01d98426c..39d5e4ed56628 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -141,7 +141,7 @@ statement (LIKE? pattern=STRING)? #showTables | SHOW TABLE EXTENDED ((FROM | IN) db=identifier)? LIKE pattern=STRING partitionSpec? #showTable - | SHOW DATABASES (LIKE pattern=STRING)? #showDatabases + | SHOW DATABASES (LIKE? pattern=STRING)? #showDatabases | SHOW TBLPROPERTIES table=tableIdentifier ('(' key=tablePropertyKey ')')? #showTblProperties | SHOW COLUMNS (FROM | IN) tableIdentifier diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 591510c1d8283..2b4b7c137428a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -991,6 +991,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql("SHOW DATABASES LIKE '*db1A'"), Row("showdb1a") :: Nil) + checkAnswer( + sql("SHOW DATABASES '*db1A'"), + Row("showdb1a") :: Nil) + checkAnswer( sql("SHOW DATABASES LIKE 'showdb1A'"), Row("showdb1a") :: Nil) From b98ffa4d6dabaf787177d3f14b200fc4b118c7ce Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 15 Jan 2018 10:55:21 +0800 Subject: [PATCH 0094/2461] [SPARK-23054][SQL] Fix incorrect results of casting UserDefinedType to String ## What changes were proposed in this pull request? This pr fixed the issue when casting `UserDefinedType`s into strings; ``` >>> from pyspark.ml.classification import MultilayerPerceptronClassifier >>> from pyspark.ml.linalg import Vectors >>> df = spark.createDataFrame([(0.0, Vectors.dense([0.0, 0.0])), (1.0, Vectors.dense([0.0, 1.0]))], ["label", "features"]) >>> df.selectExpr("CAST(features AS STRING)").show(truncate = False) +-------------------------------------------+ |features | +-------------------------------------------+ |[6,1,0,0,2800000020,2,0,0,0] | |[6,1,0,0,2800000020,2,0,0,3ff0000000000000]| +-------------------------------------------+ ``` The root cause is that `Cast` handles input data as `UserDefinedType.sqlType`(this is underlying storage type), so we should pass data into `UserDefinedType.deserialize` then `toString`. This pr modified the result into; ``` +---------+ |features | +---------+ |[0.0,0.0]| |[0.0,1.0]| +---------+ ``` ## How was this patch tested? Added tests in `UserDefinedTypeSuite `. Author: Takeshi Yamamuro Closes #20246 from maropu/SPARK-23054. --- .../spark/sql/catalyst/expressions/Cast.scala | 7 +++++++ .../apache/spark/sql/UserDefinedTypeSuite.scala | 15 +++++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index f21aa1e9e3135..a95ebe301b9d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -282,6 +282,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String builder.append("]") builder.build() }) + case udt: UserDefinedType[_] => + buildCast[Any](_, o => UTF8String.fromString(udt.deserialize(o).toString)) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -836,6 +838,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$evPrim = $buffer.build(); """.stripMargin } + case udt: UserDefinedType[_] => + val udtRef = ctx.addReferenceObj("udt", udt) + (c, evPrim, evNull) => { + s"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());" + } case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index a08433ba794d9..cc8b600efa46a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -21,7 +21,7 @@ import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.expressions.{Cast, ExpressionEvalHelper, GenericInternalRow, Literal} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.functions._ @@ -44,6 +44,8 @@ object UDT { case v: MyDenseVector => java.util.Arrays.equals(this.data, v.data) case _ => false } + + override def toString: String = data.mkString("(", ", ", ")") } private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { @@ -143,7 +145,8 @@ private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType] override def userClass: Class[IExampleSubType] = classOf[IExampleSubType] } -class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest { +class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest + with ExpressionEvalHelper { import testImplicits._ private lazy val pointsRDD = Seq( @@ -304,4 +307,12 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT pointsRDD.except(pointsRDD2), Seq(Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0))))) } + + test("SPARK-23054 Cast UserDefinedType to string") { + val udt = new UDT.MyDenseVectorUDT() + val vector = new UDT.MyDenseVector(Array(1.0, 3.0, 5.0, 7.0, 9.0)) + val data = udt.serialize(vector) + val ret = Cast(Literal(data, udt), StringType, None) + checkEvaluation(ret, "(1.0, 3.0, 5.0, 7.0, 9.0)") + } } From 9a96bfc8bf021cb4b6c62fac6ce1bcf87affcd43 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 15 Jan 2018 12:06:56 +0800 Subject: [PATCH 0095/2461] [SPARK-23049][SQL] `spark.sql.files.ignoreCorruptFiles` should work for ORC files ## What changes were proposed in this pull request? When `spark.sql.files.ignoreCorruptFiles=true`, we should ignore corrupted ORC files. ## How was this patch tested? Pass the Jenkins with a newly added test case. Author: Dongjoon Hyun Closes #20240 from dongjoon-hyun/SPARK-23049. --- .../execution/datasources/orc/OrcUtils.scala | 29 ++++++++---- .../datasources/orc/OrcQuerySuite.scala | 47 +++++++++++++++++++ .../parquet/ParquetQuerySuite.scala | 23 +++++++-- .../spark/sql/hive/orc/OrcFileFormat.scala | 8 +++- .../spark/sql/hive/orc/OrcFileOperator.scala | 28 +++++++++-- 5 files changed, 117 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 13a23996f4ade..460194ba61c8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.orc.{OrcFile, Reader, TypeDescription} +import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession @@ -50,23 +51,35 @@ object OrcUtils extends Logging { paths } - def readSchema(file: Path, conf: Configuration): Option[TypeDescription] = { + def readSchema(file: Path, conf: Configuration, ignoreCorruptFiles: Boolean) + : Option[TypeDescription] = { val fs = file.getFileSystem(conf) val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) - val reader = OrcFile.createReader(file, readerOptions) - val schema = reader.getSchema - if (schema.getFieldNames.size == 0) { - None - } else { - Some(schema) + try { + val reader = OrcFile.createReader(file, readerOptions) + val schema = reader.getSchema + if (schema.getFieldNames.size == 0) { + None + } else { + Some(schema) + } + } catch { + case e: org.apache.orc.FileFormatException => + if (ignoreCorruptFiles) { + logWarning(s"Skipped the footer in the corrupted file: $file", e) + None + } else { + throw new SparkException(s"Could not read footer for file: $file", e) + } } } def readSchema(sparkSession: SparkSession, files: Seq[FileStatus]) : Option[StructType] = { + val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles val conf = sparkSession.sessionState.newHadoopConf() // TODO: We need to support merge schema. Please see SPARK-11412. - files.map(_.getPath).flatMap(readSchema(_, conf)).headOption.map { schema => + files.map(_.getPath).flatMap(readSchema(_, conf, ignoreCorruptFiles)).headOption.map { schema => logDebug(s"Reading schema from file $files, got Hive schema string: $schema") CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index e00e057a18cc6..f58c331f33ca8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -31,6 +31,7 @@ import org.apache.orc.OrcConf.COMPRESS import org.apache.orc.mapred.OrcStruct import org.apache.orc.mapreduce.OrcInputFormat +import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, RecordReaderIterator} @@ -531,6 +532,52 @@ abstract class OrcQueryTest extends OrcTest { val df = spark.read.orc(path1.getCanonicalPath, path2.getCanonicalPath) assert(df.count() == 20) } + + test("Enabling/disabling ignoreCorruptFiles") { + def testIgnoreCorruptFiles(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.orc(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.orc(new Path(basePath, "second").toString) + spark.range(2, 3).toDF("a").write.json(new Path(basePath, "third").toString) + val df = spark.read.orc( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + checkAnswer(df, Seq(Row(0), Row(1))) + } + } + + def testIgnoreCorruptFilesWithoutSchemaInfer(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.orc(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.orc(new Path(basePath, "second").toString) + spark.range(2, 3).toDF("a").write.json(new Path(basePath, "third").toString) + val df = spark.read.schema("a long").orc( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + checkAnswer(df, Seq(Row(0), Row(1))) + } + } + + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "true") { + testIgnoreCorruptFiles() + testIgnoreCorruptFilesWithoutSchemaInfer() + } + + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { + val m1 = intercept[SparkException] { + testIgnoreCorruptFiles() + }.getMessage + assert(m1.contains("Could not read footer for file")) + val m2 = intercept[SparkException] { + testIgnoreCorruptFilesWithoutSchemaInfer() + }.getMessage + assert(m2.contains("Malformed ORC file")) + } + } } class OrcQuerySuite extends OrcQueryTest with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 4c8c9ef6e0432..6ad88ed997ce7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -320,14 +320,27 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext new Path(basePath, "first").toString, new Path(basePath, "second").toString, new Path(basePath, "third").toString) - checkAnswer( - df, - Seq(Row(0), Row(1))) + checkAnswer(df, Seq(Row(0), Row(1))) + } + } + + def testIgnoreCorruptFilesWithoutSchemaInfer(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.parquet(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.parquet(new Path(basePath, "second").toString) + spark.range(2, 3).toDF("a").write.json(new Path(basePath, "third").toString) + val df = spark.read.schema("a long").parquet( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + checkAnswer(df, Seq(Row(0), Row(1))) } } withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "true") { testIgnoreCorruptFiles() + testIgnoreCorruptFilesWithoutSchemaInfer() } withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { @@ -335,6 +348,10 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext testIgnoreCorruptFiles() } assert(exception.getMessage().contains("is not a Parquet file")) + val exception2 = intercept[SparkException] { + testIgnoreCorruptFilesWithoutSchemaInfer() + } + assert(exception2.getMessage().contains("is not a Parquet file")) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 95741c7b30289..237ed9bc05988 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -59,9 +59,11 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { + val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles OrcFileOperator.readSchema( files.map(_.getPath.toString), - Some(sparkSession.sessionState.newHadoopConf()) + Some(sparkSession.sessionState.newHadoopConf()), + ignoreCorruptFiles ) } @@ -129,6 +131,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value @@ -138,7 +141,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable // SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this // case, `OrcFileOperator.readSchema` returns `None`, and we can't read the underlying file // using the given physical schema. Instead, we simply return an empty iterator. - val isEmptyFile = OrcFileOperator.readSchema(Seq(filePath.toString), Some(conf)).isEmpty + val isEmptyFile = + OrcFileOperator.readSchema(Seq(filePath.toString), Some(conf), ignoreCorruptFiles).isEmpty if (isEmptyFile) { Iterator.empty } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 5a3fcd7a759c0..80e44ca504356 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.hive.orc +import java.io.IOException + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.io.orc.{OrcFile, Reader} import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector +import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -46,7 +49,10 @@ private[hive] object OrcFileOperator extends Logging { * create the result reader from that file. If no such file is found, it returns `None`. * @todo Needs to consider all files when schema evolution is taken into account. */ - def getFileReader(basePath: String, config: Option[Configuration] = None): Option[Reader] = { + def getFileReader(basePath: String, + config: Option[Configuration] = None, + ignoreCorruptFiles: Boolean = false) + : Option[Reader] = { def isWithNonEmptySchema(path: Path, reader: Reader): Boolean = { reader.getObjectInspector match { case oi: StructObjectInspector if oi.getAllStructFieldRefs.size() == 0 => @@ -65,16 +71,28 @@ private[hive] object OrcFileOperator extends Logging { } listOrcFiles(basePath, conf).iterator.map { path => - path -> OrcFile.createReader(fs, path) + val reader = try { + Some(OrcFile.createReader(fs, path)) + } catch { + case e: IOException => + if (ignoreCorruptFiles) { + logWarning(s"Skipped the footer in the corrupted file: $path", e) + None + } else { + throw new SparkException(s"Could not read footer for file: $path", e) + } + } + path -> reader }.collectFirst { - case (path, reader) if isWithNonEmptySchema(path, reader) => reader + case (path, Some(reader)) if isWithNonEmptySchema(path, reader) => reader } } - def readSchema(paths: Seq[String], conf: Option[Configuration]): Option[StructType] = { + def readSchema(paths: Seq[String], conf: Option[Configuration], ignoreCorruptFiles: Boolean) + : Option[StructType] = { // Take the first file where we can open a valid reader if we can find one. Otherwise just // return None to indicate we can't infer the schema. - paths.flatMap(getFileReader(_, conf)).headOption.map { reader => + paths.flatMap(getFileReader(_, conf, ignoreCorruptFiles)).headOption.map { reader => val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] val schema = readerInspector.getTypeName logDebug(s"Reading schema from file $paths, got Hive schema string: $schema") From b59808385cfe24ce768e5b3098b9034e64b99a5a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 15 Jan 2018 16:26:52 +0800 Subject: [PATCH 0096/2461] [SPARK-23023][SQL] Cast field data to strings in showString ## What changes were proposed in this pull request? The current `Datset.showString` prints rows thru `RowEncoder` deserializers like; ``` scala> Seq(Seq(Seq(1, 2), Seq(3), Seq(4, 5, 6))).toDF("a").show(false) +------------------------------------------------------------+ |a | +------------------------------------------------------------+ |[WrappedArray(1, 2), WrappedArray(3), WrappedArray(4, 5, 6)]| +------------------------------------------------------------+ ``` This result is incorrect because the correct one is; ``` scala> Seq(Seq(Seq(1, 2), Seq(3), Seq(4, 5, 6))).toDF("a").show(false) +------------------------+ |a | +------------------------+ |[[1, 2], [3], [4, 5, 6]]| +------------------------+ ``` So, this pr fixed code in `showString` to cast field data to strings before printing. ## How was this patch tested? Added tests in `DataFrameSuite`. Author: Takeshi Yamamuro Closes #20214 from maropu/SPARK-23023. --- python/pyspark/sql/functions.py | 32 +++++++++---------- .../scala/org/apache/spark/sql/Dataset.scala | 21 ++++++------ .../org/apache/spark/sql/DataFrameSuite.scala | 28 ++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 12 +++---- 4 files changed, 61 insertions(+), 32 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e1ad6590554cf..f7b3f29764040 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1849,14 +1849,14 @@ def explode_outer(col): +---+----------+----+-----+ >>> df.select("id", "a_map", explode_outer("an_array")).show() - +---+-------------+----+ - | id| a_map| col| - +---+-------------+----+ - | 1|Map(x -> 1.0)| foo| - | 1|Map(x -> 1.0)| bar| - | 2| Map()|null| - | 3| null|null| - +---+-------------+----+ + +---+----------+----+ + | id| a_map| col| + +---+----------+----+ + | 1|[x -> 1.0]| foo| + | 1|[x -> 1.0]| bar| + | 2| []|null| + | 3| null|null| + +---+----------+----+ """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.explode_outer(_to_java_column(col)) @@ -1881,14 +1881,14 @@ def posexplode_outer(col): | 3| null|null|null| null| +---+----------+----+----+-----+ >>> df.select("id", "a_map", posexplode_outer("an_array")).show() - +---+-------------+----+----+ - | id| a_map| pos| col| - +---+-------------+----+----+ - | 1|Map(x -> 1.0)| 0| foo| - | 1|Map(x -> 1.0)| 1| bar| - | 2| Map()|null|null| - | 3| null|null|null| - +---+-------------+----+----+ + +---+----------+----+----+ + | id| a_map| pos| col| + +---+----------+----+----+ + | 1|[x -> 1.0]| 0| foo| + | 1|[x -> 1.0]| 1| bar| + | 2| []|null|null| + | 3| null|null|null| + +---+----------+----+----+ """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.posexplode_outer(_to_java_column(col)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 77e571272920a..34f0ab5aa6699 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -237,13 +237,20 @@ class Dataset[T] private[sql]( private[sql] def showString( _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { val numRows = _numRows.max(0).min(Int.MaxValue - 1) - val takeResult = toDF().take(numRows + 1) + val newDf = toDF() + val castCols = newDf.logicalPlan.output.map { col => + // Since binary types in top-level schema fields have a specific format to print, + // so we do not cast them to strings here. + if (col.dataType == BinaryType) { + Column(col) + } else { + Column(col).cast(StringType) + } + } + val takeResult = newDf.select(castCols: _*).take(numRows + 1) val hasMoreData = takeResult.length > numRows val data = takeResult.take(numRows) - lazy val timeZone = - DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone) - // For array values, replace Seq and Array with square brackets // For cells that are beyond `truncate` characters, replace it with the // first `truncate-3` and "..." @@ -252,12 +259,6 @@ class Dataset[T] private[sql]( val str = cell match { case null => "null" case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]") - case array: Array[_] => array.mkString("[", ", ", "]") - case seq: Seq[_] => seq.mkString("[", ", ", "]") - case d: Date => - DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d)) - case ts: Timestamp => - DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(ts), timeZone) case _ => cell.toString } if (truncate > 0 && str.length > truncate) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5e4c1a6a484fb..33707080c1301 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1255,6 +1255,34 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(testData.select($"*").showString(1, vertical = true) === expectedAnswer) } + test("SPARK-23023 Cast rows to strings in showString") { + val df1 = Seq(Seq(1, 2, 3, 4)).toDF("a") + assert(df1.showString(10) === + s"""+------------+ + || a| + |+------------+ + ||[1, 2, 3, 4]| + |+------------+ + |""".stripMargin) + val df2 = Seq(Map(1 -> "a", 2 -> "b")).toDF("a") + assert(df2.showString(10) === + s"""+----------------+ + || a| + |+----------------+ + ||[1 -> a, 2 -> b]| + |+----------------+ + |""".stripMargin) + val df3 = Seq(((1, "a"), 0), ((2, "b"), 0)).toDF("a", "b") + assert(df3.showString(10) === + s"""+------+---+ + || a| b| + |+------+---+ + ||[1, a]| 0| + ||[2, b]| 0| + |+------+---+ + |""".stripMargin) + } + test("SPARK-7327 show with empty dataFrame") { val expectedAnswer = """+---+-----+ ||key|value| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 54893c184642b..49c59cf695dc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -958,12 +958,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ).toDS() val expected = - """+-------+ - || f| - |+-------+ - ||[foo,1]| - ||[bar,2]| - |+-------+ + """+--------+ + || f| + |+--------+ + ||[foo, 1]| + ||[bar, 2]| + |+--------+ |""".stripMargin checkShowString(ds, expected) From a38c887ac093d7cf343d807515147d87ca931ce7 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 15 Jan 2018 07:49:34 -0600 Subject: [PATCH 0097/2461] [SPARK-19550][BUILD][FOLLOW-UP] Remove MaxPermSize for sql module ## What changes were proposed in this pull request? Remove `MaxPermSize` for `sql` module ## How was this patch tested? Manually tested. Author: Yuming Wang Closes #20268 from wangyum/SPARK-19550-MaxPermSize. --- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 839b929abd3cb..7d23637e28342 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -134,7 +134,7 @@ org.scalatest scalatest-maven-plugin - -ea -Xmx4g -Xss4m -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + -ea -Xmx4g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize} diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 744daa6079779..ef41837f89d68 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -195,7 +195,7 @@ org.scalatest scalatest-maven-plugin - -ea -Xmx4g -Xss4m -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + -ea -Xmx4g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize} From bd08a9e7af4137bddca638e627ad2ae531bce20f Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 15 Jan 2018 22:32:38 +0800 Subject: [PATCH 0098/2461] [SPARK-23070] Bump previousSparkVersion in MimaBuild.scala to be 2.2.0 ## What changes were proposed in this pull request? Bump previousSparkVersion in MimaBuild.scala to be 2.2.0 and add the missing exclusions to `v23excludes` in `MimaExcludes`. No item can be un-excluded in `v23excludes`. ## How was this patch tested? The existing tests. Author: gatorsmile Closes #20264 from gatorsmile/bump22. --- project/MimaBuild.scala | 2 +- project/MimaExcludes.scala | 35 ++++++++++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 2ef0e7b40d940..adde213e361f0 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -88,7 +88,7 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "2.0.0" + val previousSparkVersion = "2.2.0" val project = projectRef.project val fullId = "spark-" + project + "_2.11" mimaDefaultSettings ++ diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 32eb31f495979..d35c50e1d00fe 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -102,7 +102,40 @@ object MimaExcludes { // [SPARK-21087] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelWriter"), - ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter") + ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter"), + + // [SPARK-21728][CORE] Allow SparkSubmit to use Logging + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.downloadFileList"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.downloadFile"), + + // [SPARK-21714][CORE][YARN] Avoiding re-uploading remote resources in yarn client mode + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.prepareSubmitEnvironment"), + + // [SPARK-22324][SQL][PYTHON] Upgrade Arrow to 0.8.0 + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.network.util.AbstractFileRegion.transfered"), + + // [SPARK-20643][CORE] Add listener implementation to collect app state + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$5"), + + // [SPARK-20648][CORE] Port JobsTab and StageTab to the new UI backend + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$12"), + + // [SPARK-21462][SS] Added batchId to StreamingQueryProgress.json + // [SPARK-21409][SS] Expose state store memory usage in SQL metrics and progress updates + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StateOperatorProgress.this"), + + // [SPARK-22278][SS] Expose current event time watermark and current processing time in GroupState + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.GroupState.getCurrentWatermarkMs"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.GroupState.getCurrentProcessingTimeMs"), + + // [SPARK-20542][ML][SQL] Add an API to Bucketizer that can bin multiple columns + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.org$apache$spark$ml$param$shared$HasOutputCols$_setter_$outputCols_="), + + // [SPARK-18619][ML] Make QuantileDiscretizer/Bucketizer/StringIndexer/RFormula inherit from HasHandleInvalid + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.Bucketizer.getHandleInvalid"), + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.StringIndexer.getHandleInvalid"), + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.QuantileDiscretizer.getHandleInvalid"), + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.getHandleInvalid") ) // Exclude rules for 2.2.x From 6c81fe227a6233f5d9665d2efadf8a1cf09f700d Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Mon, 15 Jan 2018 23:13:15 +0800 Subject: [PATCH 0099/2461] [SPARK-23035][SQL] Fix improper information of TempTableAlreadyExistsException ## What changes were proposed in this pull request? Problem: it throw TempTableAlreadyExistsException and output "Temporary table '$table' already exists" when we create temp view by using org.apache.spark.sql.catalyst.catalog.GlobalTempViewManager#create, it's improper. So fix improper information about TempTableAlreadyExistsException when create temp view: change "Temporary table" to "Temporary view" ## How was this patch tested? test("rename temporary view - destination table already exists, with: CREATE TEMPORARY view") test("rename temporary view - destination table with database name,with:CREATE TEMPORARY view") Author: xubo245 <601450868@qq.com> Closes #20227 from xubo245/fixDeprecated. --- .../analysis/AlreadyExistException.scala | 2 +- .../catalog/SessionCatalogSuite.scala | 6 +- .../spark/sql/execution/SQLViewSuite.scala | 2 +- .../sql/execution/command/DDLSuite.scala | 75 ++++++++++++++++++- 4 files changed, 78 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala index 57f7a80bedc6c..6d587abd8fd4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala @@ -31,7 +31,7 @@ class TableAlreadyExistsException(db: String, table: String) extends AnalysisException(s"Table or view '$table' already exists in database '$db'") class TempTableAlreadyExistsException(table: String) - extends AnalysisException(s"Temporary table '$table' already exists") + extends AnalysisException(s"Temporary view '$table' already exists") class PartitionAlreadyExistsException(db: String, table: String, spec: TablePartitionSpec) extends AnalysisException( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 95c87ffa20cb7..6abab0073cca3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -279,7 +279,7 @@ abstract class SessionCatalogSuite extends AnalysisTest { } } - test("create temp table") { + test("create temp view") { withBasicCatalog { catalog => val tempTable1 = Range(1, 10, 1, 10) val tempTable2 = Range(1, 20, 2, 10) @@ -288,11 +288,11 @@ abstract class SessionCatalogSuite extends AnalysisTest { assert(catalog.getTempView("tbl1") == Option(tempTable1)) assert(catalog.getTempView("tbl2") == Option(tempTable2)) assert(catalog.getTempView("tbl3").isEmpty) - // Temporary table already exists + // Temporary view already exists intercept[TempTableAlreadyExistsException] { catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) } - // Temporary table already exists but we override it + // Temporary view already exists but we override it catalog.createTempView("tbl1", tempTable2, overrideIfExists = true) assert(catalog.getTempView("tbl1") == Option(tempTable2)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 8c55758cfe38d..14082197ba0bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -293,7 +293,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") } - assert(e.message.contains("Temporary table") && e.message.contains("already exists")) + assert(e.message.contains("Temporary view") && e.message.contains("already exists")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 2b4b7c137428a..6ca21b5aa1595 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -835,6 +835,31 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("rename temporary view - destination table with database name,with:CREATE TEMPORARY view") { + withTempView("view1") { + sql( + """ + |CREATE TEMPORARY VIEW view1 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + + val e = intercept[AnalysisException] { + sql("ALTER TABLE view1 RENAME TO default.tab2") + } + assert(e.getMessage.contains( + "RENAME TEMPORARY VIEW from '`view1`' to '`default`.`tab2`': " + + "cannot specify database name 'default' in the destination table")) + + val catalog = spark.sessionState.catalog + assert(catalog.listTables("default") == Seq(TableIdentifier("view1"))) + } + } + test("rename temporary view") { withTempView("tab1", "tab2") { spark.range(10).createOrReplaceTempView("tab1") @@ -883,6 +908,42 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("rename temporary view - destination table already exists, with: CREATE TEMPORARY view") { + withTempView("view1", "view2") { + sql( + """ + |CREATE TEMPORARY VIEW view1 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + + sql( + """ + |CREATE TEMPORARY VIEW view2 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + + val e = intercept[AnalysisException] { + sql("ALTER TABLE view1 RENAME TO view2") + } + assert(e.getMessage.contains( + "RENAME TEMPORARY VIEW from '`view1`' to '`view2`': destination table already exists")) + + val catalog = spark.sessionState.catalog + assert(catalog.listTables("default") == + Seq(TableIdentifier("view1"), TableIdentifier("view2"))) + } + } + test("alter table: bucketing is not supported") { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) @@ -1728,12 +1789,22 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("block creating duplicate temp table") { - withView("t_temp") { + withTempView("t_temp") { sql("CREATE TEMPORARY VIEW t_temp AS SELECT 1, 2") val e = intercept[TempTableAlreadyExistsException] { sql("CREATE TEMPORARY TABLE t_temp (c3 int, c4 string) USING JSON") }.getMessage - assert(e.contains("Temporary table 't_temp' already exists")) + assert(e.contains("Temporary view 't_temp' already exists")) + } + } + + test("block creating duplicate temp view") { + withTempView("t_temp") { + sql("CREATE TEMPORARY VIEW t_temp AS SELECT 1, 2") + val e = intercept[TempTableAlreadyExistsException] { + sql("CREATE TEMPORARY VIEW t_temp (c3 int, c4 string) USING JSON") + }.getMessage + assert(e.contains("Temporary view 't_temp' already exists")) } } From 8ab2d7ea99b2cff8b54b2cb3a1dbf7580845986a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 16 Jan 2018 11:47:42 +0900 Subject: [PATCH 0100/2461] [SPARK-23080][SQL] Improve error message for built-in functions ## What changes were proposed in this pull request? When a user puts the wrong number of parameters in a function, an AnalysisException is thrown. If the function is a UDF, he user is told how many parameters the function expected and how many he/she put. If the function, instead, is a built-in one, no information about the number of parameters expected and the actual one is provided. This can help in some cases, to debug the errors (eg. bad quotes escaping may lead to a different number of parameters than expected, etc. etc.) The PR adds the information about the number of parameters passed and the expected one, analogously to what happens for UDF. ## How was this patch tested? modified existing UT + manual test Author: Marco Gaido Closes #20271 from mgaido91/SPARK-23080. --- .../spark/sql/catalyst/analysis/FunctionRegistry.scala | 10 +++++++++- .../resources/sql-tests/results/json-functions.sql.out | 4 ++-- .../src/test/scala/org/apache/spark/sql/UDFSuite.scala | 4 ++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 5ddb39822617d..747016beb06e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -526,7 +526,15 @@ object FunctionRegistry { // Otherwise, find a constructor method that matches the number of arguments, and use that. val params = Seq.fill(expressions.size)(classOf[Expression]) val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse { - throw new AnalysisException(s"Invalid number of arguments for function $name") + val validParametersCount = constructors.map(_.getParameterCount).distinct.sorted + val expectedNumberOfParameters = if (validParametersCount.length == 1) { + validParametersCount.head.toString + } else { + validParametersCount.init.mkString("one of ", ", ", " and ") + + validParametersCount.last + } + throw new AnalysisException(s"Invalid number of arguments for function $name. " + + s"Expected: $expectedNumberOfParameters; Found: ${params.length}") } Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match { case Success(e) => e diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index d9dc728a18e8d..581dddc89d0bb 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -129,7 +129,7 @@ select to_json() struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function to_json; line 1 pos 7 +Invalid number of arguments for function to_json. Expected: one of 1, 2 and 3; Found: 0; line 1 pos 7 -- !query 13 @@ -225,7 +225,7 @@ select from_json() struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function from_json; line 1 pos 7 +Invalid number of arguments for function from_json. Expected: one of 2, 3 and 4; Found: 0; line 1 pos 7 -- !query 22 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index db37be68e42e6..af6a10b425b9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -80,7 +80,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { val e = intercept[AnalysisException] { df.selectExpr("substr('abcd', 2, 3, 4)") } - assert(e.getMessage.contains("Invalid number of arguments for function substr")) + assert(e.getMessage.contains("Invalid number of arguments for function substr. Expected:")) } test("error reporting for incorrect number of arguments - udf") { @@ -89,7 +89,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { spark.udf.register("foo", (_: String).length) df.selectExpr("foo(2, 3, 4)") } - assert(e.getMessage.contains("Invalid number of arguments for function foo")) + assert(e.getMessage.contains("Invalid number of arguments for function foo. Expected:")) } test("error reporting for undefined functions") { From c7572b79da0a29e502890d7618eaf805a1c9f474 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 16 Jan 2018 11:20:18 +0800 Subject: [PATCH 0101/2461] [SPARK-23000] Use fully qualified table names in HiveMetastoreCatalogSuite ## What changes were proposed in this pull request? In another attempt to fix DataSourceWithHiveMetastoreCatalogSuite, this patch uses qualified table names (`default.t`) in the individual tests. ## How was this patch tested? N/A (Test Only Change) Author: Sameer Agarwal Closes #20273 from sameeragarwal/flaky-test. --- .../sql/hive/HiveMetastoreCatalogSuite.scala | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index ba9b944e4a055..83b4c862e2546 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -166,13 +166,13 @@ class DataSourceWithHiveMetastoreCatalogSuite )) ).foreach { case (provider, (inputFormat, outputFormat, serde)) => test(s"Persist non-partitioned $provider relation into metastore as managed table") { - withTable("t") { + withTable("default.t") { withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { testDF .write .mode(SaveMode.Overwrite) .format(provider) - .saveAsTable("t") + .saveAsTable("default.t") } val hiveTable = sessionState.catalog.getTableMetadata(TableIdentifier("t", Some("default"))) @@ -187,14 +187,15 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.dataType) === Seq(DecimalType(10, 3), StringType)) - checkAnswer(table("t"), testDF) - assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) + checkAnswer(table("default.t"), testDF) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM default.t") === + Seq("1.1\t1", "2.1\t2")) } } test(s"Persist non-partitioned $provider relation into metastore as external table") { withTempPath { dir => - withTable("t") { + withTable("default.t") { val path = dir.getCanonicalFile withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { @@ -203,7 +204,7 @@ class DataSourceWithHiveMetastoreCatalogSuite .mode(SaveMode.Overwrite) .format(provider) .option("path", path.toString) - .saveAsTable("t") + .saveAsTable("default.t") } val hiveTable = @@ -219,8 +220,8 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.dataType) === Seq(DecimalType(10, 3), StringType)) - checkAnswer(table("t"), testDF) - assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === + checkAnswer(table("default.t"), testDF) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM default.t") === Seq("1.1\t1", "2.1\t2")) } } @@ -228,9 +229,9 @@ class DataSourceWithHiveMetastoreCatalogSuite test(s"Persist non-partitioned $provider relation into metastore as managed table using CTAS") { withTempPath { dir => - withTable("t") { + withTable("default.t") { sql( - s"""CREATE TABLE t USING $provider + s"""CREATE TABLE default.t USING $provider |OPTIONS (path '${dir.toURI}') |AS SELECT 1 AS d1, "val_1" AS d2 """.stripMargin) @@ -248,8 +249,9 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.dataType) === Seq(IntegerType, StringType)) - checkAnswer(table("t"), Row(1, "val_1")) - assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) + checkAnswer(table("default.t"), Row(1, "val_1")) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM default.t") === + Seq("1\tval_1")) } } } From 07ae39d0ec1f03b1c73259373a8bb599694c7860 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Mon, 15 Jan 2018 22:01:14 -0800 Subject: [PATCH 0102/2461] [SPARK-22956][SS] Bug fix for 2 streams union failover scenario ## What changes were proposed in this pull request? This problem reported by yanlin-Lynn ivoson and LiangchangZ. Thanks! When we union 2 streams from kafka or other sources, while one of them have no continues data coming and in the same time task restart, this will cause an `IllegalStateException`. This mainly cause because the code in [MicroBatchExecution](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala#L190) , while one stream has no continues data, its comittedOffset same with availableOffset during `populateStartOffsets`, and `currentPartitionOffsets` not properly handled in KafkaSource. Also, maybe we should also consider this scenario in other Source. ## How was this patch tested? Add a UT in KafkaSourceSuite.scala Author: Yuanjian Li Closes #20150 from xuanyuanking/SPARK-22956. --- .../spark/sql/kafka010/KafkaSource.scala | 13 ++-- .../spark/sql/kafka010/KafkaSourceSuite.scala | 65 +++++++++++++++++++ .../streaming/MicroBatchExecution.scala | 6 +- .../sql/execution/streaming/memory.scala | 6 ++ 4 files changed, 81 insertions(+), 9 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index e9cff04ba5f2e..864a92b8f813f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -223,6 +223,14 @@ private[kafka010] class KafkaSource( logInfo(s"GetBatch called with start = $start, end = $end") val untilPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(end) + // On recovery, getBatch will get called before getOffset + if (currentPartitionOffsets.isEmpty) { + currentPartitionOffsets = Some(untilPartitionOffsets) + } + if (start.isDefined && start.get == end) { + return sqlContext.internalCreateDataFrame( + sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) + } val fromPartitionOffsets = start match { case Some(prevBatchEndOffset) => KafkaSourceOffset.getPartitionOffsets(prevBatchEndOffset) @@ -305,11 +313,6 @@ private[kafka010] class KafkaSource( logInfo("GetBatch generating RDD of offset range: " + offsetRanges.sortBy(_.topicPartition.toString).mkString(", ")) - // On recovery, getBatch will get called before getOffset - if (currentPartitionOffsets.isEmpty) { - currentPartitionOffsets = Some(untilPartitionOffsets) - } - sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 2034b9be07f24..a0f5695fc485c 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -318,6 +318,71 @@ class KafkaSourceSuite extends KafkaSourceTest { ) } + test("SPARK-22956: currentPartitionOffsets should be set when no new data comes in") { + def getSpecificDF(range: Range.Inclusive): org.apache.spark.sql.Dataset[Int] = { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 1) + testUtils.sendMessages(topic, range.map(_.toString).toArray, Some(0)) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("maxOffsetsPerTrigger", 5) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + + reader.load() + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(k => k.toInt) + } + + val df1 = getSpecificDF(0 to 9) + val df2 = getSpecificDF(100 to 199) + + val kafka = df1.union(df2) + + val clock = new StreamManualClock + + val waitUntilBatchProcessed = AssertOnQuery { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + } + + testStream(kafka)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // 5 from smaller topic, 5 from bigger one + CheckLastBatch((0 to 4) ++ (100 to 104): _*), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // 5 from smaller topic, 5 from bigger one + CheckLastBatch((5 to 9) ++ (105 to 109): _*), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smaller topic empty, 5 from bigger one + CheckLastBatch(110 to 114: _*), + StopStream, + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // smallest now empty, 5 from bigger one + CheckLastBatch(115 to 119: _*), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 5 from bigger one + CheckLastBatch(120 to 124: _*) + ) + } + test("cannot stop Kafka stream") { val topic = newTopic() testUtils.createTopic(topic, partitions = 5) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 42240eeb58d4b..70407f0580f97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -208,10 +208,8 @@ class MicroBatchExecution( * batch will be executed before getOffset is called again. */ availableOffsets.foreach { case (source: Source, end: Offset) => - if (committedOffsets.get(source).map(_ != end).getOrElse(true)) { - val start = committedOffsets.get(source) - source.getBatch(start, end) - } + val start = committedOffsets.get(source) + source.getBatch(start, end) case nonV1Tuple => // The V2 API does not have the same edge case requiring getBatch to be called // here, so we do nothing here. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 3041d4d703cb4..509a69dd922fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -119,9 +119,15 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) val newBlocks = synchronized { val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 + assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd") batches.slice(sliceStart, sliceEnd) } + if (newBlocks.isEmpty) { + return sqlContext.internalCreateDataFrame( + sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) + } + logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal)) newBlocks From 66217dac4f8952a9923625908ad3dcb030763c81 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 15 Jan 2018 22:40:44 -0800 Subject: [PATCH 0103/2461] [SPARK-23020][CORE] Fix races in launcher code, test. The race in the code is because the handle might update its state to the wrong state if the connection handling thread is still processing incoming data; so the handle needs to wait for the connection to finish up before checking the final state. The race in the test is because when waiting for a handle to reach a final state, the waitFor() method needs to wait until all handle state is updated (which also includes waiting for the connection thread above to finish). Otherwise, waitFor() may return too early, which would cause a bunch of different races (like the listener not being yet notified of the state change, or being in the middle of being notified, or the handle not being properly disposed and causing postChecks() to assert). On top of that I found, by code inspection, a couple of potential races that could make a handle end up in the wrong state when being killed. Tested by running the existing unit tests a lot (and not seeing the errors I was seeing before). Author: Marcelo Vanzin Closes #20223 from vanzin/SPARK-23020. --- .../spark/launcher/SparkLauncherSuite.java | 49 ++++++++++++------- .../spark/launcher/AbstractAppHandle.java | 22 +++++++-- .../spark/launcher/ChildProcAppHandle.java | 18 ++++--- .../spark/launcher/InProcessAppHandle.java | 17 ++++--- .../spark/launcher/LauncherConnection.java | 14 +++--- .../apache/spark/launcher/LauncherServer.java | 46 ++++++++++++++--- .../org/apache/spark/launcher/BaseSuite.java | 42 +++++++++++++--- .../spark/launcher/LauncherServerSuite.java | 20 +++----- 8 files changed, 156 insertions(+), 72 deletions(-) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 9d2f563b2e367..a042375c6ae91 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.launcher; +import java.time.Duration; import java.util.Arrays; import java.util.ArrayList; import java.util.HashMap; @@ -31,6 +32,7 @@ import static org.mockito.Mockito.*; import org.apache.spark.SparkContext; +import org.apache.spark.SparkContext$; import org.apache.spark.internal.config.package$; import org.apache.spark.util.Utils; @@ -137,7 +139,9 @@ public void testInProcessLauncher() throws Exception { // Here DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet. // Wait for a reasonable amount of time to avoid creating two active SparkContext in JVM. // See SPARK-23019 and SparkContext.stop() for details. - TimeUnit.MILLISECONDS.sleep(500); + eventually(Duration.ofSeconds(5), Duration.ofMillis(10), () -> { + assertTrue("SparkContext is still alive.", SparkContext$.MODULE$.getActive().isEmpty()); + }); } } @@ -146,26 +150,35 @@ private void inProcessLauncherTestImpl() throws Exception { SparkAppHandle.Listener listener = mock(SparkAppHandle.Listener.class); doAnswer(invocation -> { SparkAppHandle h = (SparkAppHandle) invocation.getArguments()[0]; - transitions.add(h.getState()); + synchronized (transitions) { + transitions.add(h.getState()); + } return null; }).when(listener).stateChanged(any(SparkAppHandle.class)); - SparkAppHandle handle = new InProcessLauncher() - .setMaster("local") - .setAppResource(SparkLauncher.NO_RESOURCE) - .setMainClass(InProcessTestApp.class.getName()) - .addAppArgs("hello") - .startApplication(listener); - - waitFor(handle); - assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); - - // Matches the behavior of LocalSchedulerBackend. - List expected = Arrays.asList( - SparkAppHandle.State.CONNECTED, - SparkAppHandle.State.RUNNING, - SparkAppHandle.State.FINISHED); - assertEquals(expected, transitions); + SparkAppHandle handle = null; + try { + handle = new InProcessLauncher() + .setMaster("local") + .setAppResource(SparkLauncher.NO_RESOURCE) + .setMainClass(InProcessTestApp.class.getName()) + .addAppArgs("hello") + .startApplication(listener); + + waitFor(handle); + assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); + + // Matches the behavior of LocalSchedulerBackend. + List expected = Arrays.asList( + SparkAppHandle.State.CONNECTED, + SparkAppHandle.State.RUNNING, + SparkAppHandle.State.FINISHED); + assertEquals(expected, transitions); + } finally { + if (handle != null) { + handle.kill(); + } + } } public static class SparkLauncherTestApp { diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java index df1e7316861d4..daf0972f824dd 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java @@ -33,7 +33,7 @@ abstract class AbstractAppHandle implements SparkAppHandle { private List listeners; private State state; private String appId; - private boolean disposed; + private volatile boolean disposed; protected AbstractAppHandle(LauncherServer server) { this.server = server; @@ -70,8 +70,7 @@ public void stop() { @Override public synchronized void disconnect() { - if (!disposed) { - disposed = true; + if (!isDisposed()) { if (connection != null) { try { connection.close(); @@ -79,7 +78,7 @@ public synchronized void disconnect() { // no-op. } } - server.unregister(this); + dispose(); } } @@ -95,6 +94,21 @@ boolean isDisposed() { return disposed; } + /** + * Mark the handle as disposed, and set it as LOST in case the current state is not final. + */ + synchronized void dispose() { + if (!isDisposed()) { + // Unregister first to make sure that the connection with the app has been really + // terminated. + server.unregister(this); + if (!getState().isFinal()) { + setState(State.LOST); + } + this.disposed = true; + } + } + void setState(State s) { setState(s, false); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index 8b3f427b7750e..2b99461652e1f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -48,14 +48,16 @@ public synchronized void disconnect() { @Override public synchronized void kill() { - disconnect(); - if (childProc != null) { - if (childProc.isAlive()) { - childProc.destroyForcibly(); + if (!isDisposed()) { + setState(State.KILLED); + disconnect(); + if (childProc != null) { + if (childProc.isAlive()) { + childProc.destroyForcibly(); + } + childProc = null; } - childProc = null; } - setState(State.KILLED); } void setChildProc(Process childProc, String loggerName, InputStream logStream) { @@ -94,8 +96,6 @@ void monitorChild() { return; } - disconnect(); - int ec; try { ec = proc.exitValue(); @@ -118,6 +118,8 @@ void monitorChild() { if (newState != null) { setState(newState, true); } + + disconnect(); } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index acd64c962604f..f04263cb74a58 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -39,15 +39,16 @@ class InProcessAppHandle extends AbstractAppHandle { @Override public synchronized void kill() { - LOG.warning("kill() may leave the underlying app running in in-process mode."); - disconnect(); - - // Interrupt the thread. This is not guaranteed to kill the app, though. - if (app != null) { - app.interrupt(); + if (!isDisposed()) { + LOG.warning("kill() may leave the underlying app running in in-process mode."); + setState(State.KILLED); + disconnect(); + + // Interrupt the thread. This is not guaranteed to kill the app, though. + if (app != null) { + app.interrupt(); + } } - - setState(State.KILLED); } synchronized void start(String appName, Method main, String[] args) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java index b4a8719e26053..fd6f229b2349c 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java @@ -95,15 +95,15 @@ protected synchronized void send(Message msg) throws IOException { } @Override - public void close() throws IOException { + public synchronized void close() throws IOException { if (!closed) { - synchronized (this) { - if (!closed) { - closed = true; - socket.close(); - } - } + closed = true; + socket.close(); } } + boolean isOpen() { + return !closed; + } + } diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index b8999a1d7a4f4..660c4443b20b9 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -217,6 +217,33 @@ void unregister(AbstractAppHandle handle) { break; } } + + // If there is a live connection for this handle, we need to wait for it to finish before + // returning, otherwise there might be a race between the connection thread processing + // buffered data and the handle cleaning up after itself, leading to potentially the wrong + // state being reported for the handle. + ServerConnection conn = null; + synchronized (clients) { + for (ServerConnection c : clients) { + if (c.handle == handle) { + conn = c; + break; + } + } + } + + if (conn != null) { + synchronized (conn) { + if (conn.isOpen()) { + try { + conn.wait(); + } catch (InterruptedException ie) { + // Ignore. + } + } + } + } + unref(); } @@ -288,7 +315,7 @@ private String createSecret() { private class ServerConnection extends LauncherConnection { private TimerTask timeout; - private AbstractAppHandle handle; + volatile AbstractAppHandle handle; ServerConnection(Socket socket, TimerTask timeout) throws IOException { super(socket); @@ -338,16 +365,21 @@ protected void handle(Message msg) throws IOException { @Override public void close() throws IOException { + if (!isOpen()) { + return; + } + synchronized (clients) { clients.remove(this); } - super.close(); + + synchronized (this) { + super.close(); + notifyAll(); + } + if (handle != null) { - if (!handle.getState().isFinal()) { - LOG.log(Level.WARNING, "Lost connection to spark application."); - handle.setState(SparkAppHandle.State.LOST); - } - handle.disconnect(); + handle.dispose(); } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java index 3e1a90eae98d4..3722a59d9438e 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.launcher; +import java.time.Duration; import java.util.concurrent.TimeUnit; import org.junit.After; @@ -47,19 +48,46 @@ public void postChecks() { assertNull(server); } - protected void waitFor(SparkAppHandle handle) throws Exception { - long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + protected void waitFor(final SparkAppHandle handle) throws Exception { try { - while (!handle.getState().isFinal()) { - assertTrue("Timed out waiting for handle to transition to final state.", - System.nanoTime() < deadline); - TimeUnit.MILLISECONDS.sleep(10); - } + eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { + assertTrue("Handle is not in final state.", handle.getState().isFinal()); + }); } finally { if (!handle.getState().isFinal()) { handle.kill(); } } + + // Wait until the handle has been marked as disposed, to make sure all cleanup tasks + // have been performed. + AbstractAppHandle ahandle = (AbstractAppHandle) handle; + eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { + assertTrue("Handle is still not marked as disposed.", ahandle.isDisposed()); + }); + } + + /** + * Call a closure that performs a check every "period" until it succeeds, or the timeout + * elapses. + */ + protected void eventually(Duration timeout, Duration period, Runnable check) throws Exception { + assertTrue("Timeout needs to be larger than period.", timeout.compareTo(period) > 0); + long deadline = System.nanoTime() + timeout.toNanos(); + int count = 0; + while (true) { + try { + count++; + check.run(); + return; + } catch (Throwable t) { + if (System.nanoTime() >= deadline) { + String msg = String.format("Failed check after %d tries: %s.", count, t.getMessage()); + throw new IllegalStateException(msg, t); + } + Thread.sleep(period.toMillis()); + } + } } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 7e2b09ce25c9b..75c1af0c71e2a 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -23,12 +23,14 @@ import java.net.InetAddress; import java.net.Socket; import java.net.SocketException; +import java.time.Duration; import java.util.Arrays; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import static org.junit.Assert.*; @@ -197,28 +199,20 @@ private void close(Closeable c) { * server-side close immediately. */ private void waitForError(TestClient client, String secret) throws Exception { - boolean helloSent = false; - int maxTries = 10; - for (int i = 0; i < maxTries; i++) { + final AtomicBoolean helloSent = new AtomicBoolean(); + eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> { try { - if (!helloSent) { + if (!helloSent.get()) { client.send(new Hello(secret, "1.4.0")); - helloSent = true; + helloSent.set(true); } else { client.send(new SetAppId("appId")); } fail("Expected error but message went through."); } catch (IllegalStateException | IOException e) { // Expected. - break; - } catch (AssertionError e) { - if (i < maxTries - 1) { - Thread.sleep(100); - } else { - throw new AssertionError("Test failed after " + maxTries + " attempts.", e); - } } - } + }); } private static class TestClient extends LauncherConnection { From b85eb946ac298e711dad25db0d04eee41d7fd236 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 16 Jan 2018 20:20:33 +0900 Subject: [PATCH 0104/2461] [SPARK-22978][PYSPARK] Register Vectorized UDFs for SQL Statement ## What changes were proposed in this pull request? Register Vectorized UDFs for SQL Statement. For example, ```Python >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> pandas_udf("integer", PandasUDFType.SCALAR) ... def add_one(x): ... return x + 1 ... >>> _ = spark.udf.register("add_one", add_one) >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] ``` ## How was this patch tested? Added test cases Author: gatorsmile Closes #20171 from gatorsmile/supportVectorizedUDF. --- python/pyspark/sql/catalog.py | 75 ++++++++++++++++++++++++---------- python/pyspark/sql/context.py | 51 ++++++++++++++++------- python/pyspark/sql/tests.py | 76 ++++++++++++++++++++++++++++++----- 3 files changed, 155 insertions(+), 47 deletions(-) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 156603128d063..35fbe9e669adb 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -226,18 +226,23 @@ def dropGlobalTempView(self, viewName): @ignore_unicode_prefix @since(2.0) - def registerFunction(self, name, f, returnType=StringType()): + def registerFunction(self, name, f, returnType=None): """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` - as a UDF. The registered UDF can be used in SQL statement. + as a UDF. The registered UDF can be used in SQL statements. - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not given it default to a string and conversion will automatically - be done. For any other return type, the produced object must match the specified type. + :func:`spark.udf.register` is an alias for :func:`spark.catalog.registerFunction`. - :param name: name of the UDF - :param f: a Python function, or a wrapped/native UserDefinedFunction - :param returnType: a :class:`pyspark.sql.types.DataType` object - :return: a wrapped :class:`UserDefinedFunction` + In addition to a name and the function itself, `returnType` can be optionally specified. + 1) When f is a Python function, `returnType` defaults to a string. The produced object must + match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return + type of the given UDF as the return type of the registered UDF. The input parameter + `returnType` is None by default. If given by users, the value must be None. + + :param name: name of the UDF in SQL statements. + :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either + row-at-a-time or vectorized. + :param returnType: the return type of the registered UDF. + :return: a wrapped/native :class:`UserDefinedFunction` >>> strlen = spark.catalog.registerFunction("stringLengthString", len) >>> spark.sql("SELECT stringLengthString('test')").collect() @@ -256,27 +261,55 @@ def registerFunction(self, name, f, returnType=StringType()): >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] + >>> from pyspark.sql.types import IntegerType + >>> from pyspark.sql.functions import udf + >>> slen = udf(lambda s: len(s), IntegerType()) + >>> _ = spark.udf.register("slen", slen) + >>> spark.sql("SELECT slen('test')").collect() + [Row(slen(test)=4)] + >>> import random >>> from pyspark.sql.functions import udf - >>> from pyspark.sql.types import IntegerType, StringType + >>> from pyspark.sql.types import IntegerType >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() - >>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType()) + >>> new_random_udf = spark.catalog.registerFunction("random_udf", random_udf) >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP - [Row(random_udf()=u'82')] - >>> spark.range(1).select(newRandom_udf()).collect() # doctest: +SKIP - [Row(random_udf()=u'62')] + [Row(random_udf()=82)] + >>> spark.range(1).select(new_random_udf()).collect() # doctest: +SKIP + [Row(()=26)] + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP + ... def add_one(x): + ... return x + 1 + ... + >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP + >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP + [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] """ # This is to check whether the input function is a wrapped/native UserDefinedFunction if hasattr(f, 'asNondeterministic'): - udf = UserDefinedFunction(f.func, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF, - deterministic=f.deterministic) + if returnType is not None: + raise TypeError( + "Invalid returnType: None is expected when f is a UserDefinedFunction, " + "but got %s." % returnType) + if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_PANDAS_SCALAR_UDF]: + raise ValueError( + "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF") + register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, + evalType=f.evalType, + deterministic=f.deterministic) + return_udf = f else: - udf = UserDefinedFunction(f, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF) - self._jsparkSession.udf().registerPython(name, udf._judf) - return udf._wrapped() + if returnType is None: + returnType = StringType() + register_udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) + return_udf = register_udf._wrapped() + self._jsparkSession.udf().registerPython(name, register_udf._judf) + return return_udf @since(2.0) def isCached(self, tableName): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index b8d86cc098e94..85479095af594 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -174,18 +174,23 @@ def range(self, start, end=None, step=1, numPartitions=None): @ignore_unicode_prefix @since(1.2) - def registerFunction(self, name, f, returnType=StringType()): + def registerFunction(self, name, f, returnType=None): """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` - as a UDF. The registered UDF can be used in SQL statement. + as a UDF. The registered UDF can be used in SQL statements. - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not given it default to a string and conversion will automatically - be done. For any other return type, the produced object must match the specified type. + :func:`spark.udf.register` is an alias for :func:`sqlContext.registerFunction`. - :param name: name of the UDF - :param f: a Python function, or a wrapped/native UserDefinedFunction - :param returnType: a :class:`pyspark.sql.types.DataType` object - :return: a wrapped :class:`UserDefinedFunction` + In addition to a name and the function itself, `returnType` can be optionally specified. + 1) When f is a Python function, `returnType` defaults to a string. The produced object must + match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return + type of the given UDF as the return type of the registered UDF. The input parameter + `returnType` is None by default. If given by users, the value must be None. + + :param name: name of the UDF in SQL statements. + :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either + row-at-a-time or vectorized. + :param returnType: the return type of the registered UDF. + :return: a wrapped/native :class:`UserDefinedFunction` >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x)) >>> sqlContext.sql("SELECT stringLengthString('test')").collect() @@ -204,15 +209,31 @@ def registerFunction(self, name, f, returnType=StringType()): >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] + >>> from pyspark.sql.types import IntegerType + >>> from pyspark.sql.functions import udf + >>> slen = udf(lambda s: len(s), IntegerType()) + >>> _ = sqlContext.udf.register("slen", slen) + >>> sqlContext.sql("SELECT slen('test')").collect() + [Row(slen(test)=4)] + >>> import random >>> from pyspark.sql.functions import udf - >>> from pyspark.sql.types import IntegerType, StringType + >>> from pyspark.sql.types import IntegerType >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() - >>> newRandom_udf = sqlContext.registerFunction("random_udf", random_udf, StringType()) + >>> new_random_udf = sqlContext.registerFunction("random_udf", random_udf) >>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP - [Row(random_udf()=u'82')] - >>> sqlContext.range(1).select(newRandom_udf()).collect() # doctest: +SKIP - [Row(random_udf()=u'62')] + [Row(random_udf()=82)] + >>> sqlContext.range(1).select(new_random_udf()).collect() # doctest: +SKIP + [Row(()=26)] + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP + ... def add_one(x): + ... return x + 1 + ... + >>> _ = sqlContext.udf.register("add_one", add_one) # doctest: +SKIP + >>> sqlContext.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP + [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] """ return self.sparkSession.catalog.registerFunction(name, f, returnType) @@ -575,7 +596,7 @@ class UDFRegistration(object): def __init__(self, sqlContext): self.sqlContext = sqlContext - def register(self, name, f, returnType=StringType()): + def register(self, name, f, returnType=None): return self.sqlContext.registerFunction(name, f, returnType) def registerJavaFunction(self, name, javaClassName, returnType=None): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 80a94a91a87b3..8906618666b14 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -380,12 +380,25 @@ def test_udf2(self): self.assertEqual(4, res[0]) def test_udf3(self): - twoargs = self.spark.catalog.registerFunction( - "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType()) - self.assertEqual(twoargs.deterministic, True) + two_args = self.spark.catalog.registerFunction( + "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y)) + self.assertEqual(two_args.deterministic, True) + [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], u'5') + + def test_udf_registration_return_type_none(self): + two_args = self.spark.catalog.registerFunction( + "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, "integer"), None) + self.assertEqual(two_args.deterministic, True) [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() self.assertEqual(row[0], 5) + def test_udf_registration_return_type_not_none(self): + with QuietTest(self.sc): + with self.assertRaisesRegexp(TypeError, "Invalid returnType"): + self.spark.catalog.registerFunction( + "f", UserDefinedFunction(lambda x, y: len(x) + y, StringType()), StringType()) + def test_nondeterministic_udf(self): # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations from pyspark.sql.functions import udf @@ -402,12 +415,12 @@ def test_nondeterministic_udf2(self): from pyspark.sql.functions import udf random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic() self.assertEqual(random_udf.deterministic, False) - random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType()) + random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf) self.assertEqual(random_udf1.deterministic, False) [row] = self.spark.sql("SELECT randInt()").collect() - self.assertEqual(row[0], "6") + self.assertEqual(row[0], 6) [row] = self.spark.range(1).select(random_udf1()).collect() - self.assertEqual(row[0], "6") + self.assertEqual(row[0], 6) [row] = self.spark.range(1).select(random_udf()).collect() self.assertEqual(row[0], 6) # render_doc() reproduces the help() exception without printing output @@ -3691,7 +3704,7 @@ def tearDownClass(cls): ReusedSQLTestCase.tearDownClass() @property - def random_udf(self): + def nondeterministic_vectorized_udf(self): from pyspark.sql.functions import pandas_udf @pandas_udf('double') @@ -3726,6 +3739,21 @@ def test_vectorized_udf_basic(self): bool_f(col('bool'))) self.assertEquals(df.collect(), res.collect()) + def test_register_nondeterministic_vectorized_udf_basic(self): + from pyspark.sql.functions import pandas_udf + from pyspark.rdd import PythonEvalType + import random + random_pandas_udf = pandas_udf( + lambda x: random.randint(6, 6) + x, IntegerType()).asNondeterministic() + self.assertEqual(random_pandas_udf.deterministic, False) + self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + nondeterministic_pandas_udf = self.spark.catalog.registerFunction( + "randomPandasUDF", random_pandas_udf) + self.assertEqual(nondeterministic_pandas_udf.deterministic, False) + self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + [row] = self.spark.sql("SELECT randomPandasUDF(1)").collect() + self.assertEqual(row[0], 7) + def test_vectorized_udf_null_boolean(self): from pyspark.sql.functions import pandas_udf, col data = [(True,), (True,), (None,), (False,)] @@ -4085,14 +4113,14 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self): finally: self.spark.conf.set("spark.sql.session.timeZone", orig_tz) - def test_nondeterministic_udf(self): + def test_nondeterministic_vectorized_udf(self): # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations from pyspark.sql.functions import udf, pandas_udf, col @pandas_udf('double') def plus_ten(v): return v + 10 - random_udf = self.random_udf + random_udf = self.nondeterministic_vectorized_udf df = self.spark.range(10).withColumn('rand', random_udf(col('id'))) result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas() @@ -4100,11 +4128,11 @@ def plus_ten(v): self.assertEqual(random_udf.deterministic, False) self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10)) - def test_nondeterministic_udf_in_aggregate(self): + def test_nondeterministic_vectorized_udf_in_aggregate(self): from pyspark.sql.functions import pandas_udf, sum df = self.spark.range(10) - random_udf = self.random_udf + random_udf = self.nondeterministic_vectorized_udf with QuietTest(self.sc): with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): @@ -4112,6 +4140,23 @@ def test_nondeterministic_udf_in_aggregate(self): with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): df.agg(sum(random_udf(df.id))).collect() + def test_register_vectorized_udf_basic(self): + from pyspark.rdd import PythonEvalType + from pyspark.sql.functions import pandas_udf, col, expr + df = self.spark.range(10).select( + col('id').cast('int').alias('a'), + col('id').cast('int').alias('b')) + original_add = pandas_udf(lambda x, y: x + y, IntegerType()) + self.assertEqual(original_add.deterministic, True) + self.assertEqual(original_add.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + new_add = self.spark.catalog.registerFunction("add1", original_add) + res1 = df.select(new_add(col('a'), col('b'))) + res2 = self.spark.sql( + "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t") + expected = df.select(expr('a + b')) + self.assertEquals(expected.collect(), res1.collect()) + self.assertEquals(expected.collect(), res2.collect()) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedSQLTestCase): @@ -4147,6 +4192,15 @@ def test_simple(self): expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) self.assertFramesEqual(expected, result) + def test_register_group_map_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUP_MAP) + with QuietTest(self.sc): + with self.assertRaisesRegexp(ValueError, 'f must be either SQL_BATCHED_UDF or ' + 'SQL_PANDAS_SCALAR_UDF'): + self.spark.catalog.registerFunction("foo_udf", foo_udf) + def test_decorator(self): from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data From 75db14864d2bd9b8e13154226e94d466e3a7e0a0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 16 Jan 2018 22:41:30 +0800 Subject: [PATCH 0105/2461] [SPARK-22392][SQL] data source v2 columnar batch reader ## What changes were proposed in this pull request? a new Data Source V2 interface to allow the data source to return `ColumnarBatch` during the scan. ## How was this patch tested? new tests Author: Wenchen Fan Closes #20153 from cloud-fan/columnar-reader. --- .../sources/v2/reader/DataSourceV2Reader.java | 5 +- .../v2/reader/SupportsScanColumnarBatch.java | 52 ++++++++ .../v2/reader/SupportsScanUnsafeRow.java | 2 +- .../sql/execution/ColumnarBatchScan.scala | 37 +++++- .../sql/execution/DataSourceScanExec.scala | 39 ++---- .../columnar/InMemoryTableScanExec.scala | 101 +++++++++------- .../datasources/v2/DataSourceRDD.scala | 20 ++-- .../datasources/v2/DataSourceV2ScanExec.scala | 72 ++++++----- .../ContinuousDataSourceRDDIter.scala | 4 +- .../sql/sources/v2/JavaBatchDataSourceV2.java | 112 ++++++++++++++++++ .../execution/WholeStageCodegenSuite.scala | 28 ++--- .../sql/sources/v2/DataSourceV2Suite.scala | 72 ++++++++++- 12 files changed, 400 insertions(+), 144 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java index 95ee4a8278322..f23c3842bf1b1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java @@ -38,7 +38,10 @@ * 2. Information Reporting. E.g., statistics reporting, ordering reporting, etc. * Names of these interfaces start with `SupportsReporting`. * 3. Special scans. E.g, columnar scan, unsafe row scan, etc. - * Names of these interfaces start with `SupportsScan`. + * Names of these interfaces start with `SupportsScan`. Note that a reader should only + * implement at most one of the special scans, if more than one special scans are implemented, + * only one of them would be respected, according to the priority list from high to low: + * {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}. * * If an exception was throw when applying any of these query optimizations, the action would fail * and no Spark job was submitted. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java new file mode 100644 index 0000000000000..27cf3a77724f0 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import java.util.List; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +/** + * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * interface to output {@link ColumnarBatch} and make the scan faster. + */ +@InterfaceStability.Evolving +public interface SupportsScanColumnarBatch extends DataSourceV2Reader { + @Override + default List> createReadTasks() { + throw new IllegalStateException( + "createReadTasks not supported by default within SupportsScanColumnarBatch."); + } + + /** + * Similar to {@link DataSourceV2Reader#createReadTasks()}, but returns columnar data in batches. + */ + List> createBatchReadTasks(); + + /** + * Returns true if the concrete data source reader can read data in batch according to the scan + * properties like required columns, pushes filters, etc. It's possible that the implementation + * can only support some certain columns with certain types. Users can overwrite this method and + * {@link #createReadTasks()} to fallback to normal read path under some conditions. + */ + default boolean enableBatchRead() { + return true; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java index b90ec880dc85e..2d3ad0eee65ff 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java @@ -35,7 +35,7 @@ public interface SupportsScanUnsafeRow extends DataSourceV2Reader { @Override default List> createReadTasks() { throw new IllegalStateException( - "createReadTasks should not be called with SupportsScanUnsafeRow."); + "createReadTasks not supported by default within SupportsScanUnsafeRow"); } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 5617046e1396e..dd68df9686691 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType @@ -25,13 +25,16 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} /** - * Helper trait for abstracting scan functionality using - * [[ColumnarBatch]]es. + * Helper trait for abstracting scan functionality using [[ColumnarBatch]]es. */ private[sql] trait ColumnarBatchScan extends CodegenSupport { def vectorTypes: Option[Seq[String]] = None + protected def supportsBatch: Boolean = true + + protected def needsUnsafeRowConversion: Boolean = true + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) @@ -71,7 +74,14 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { // PhysicalRDD always just has one input val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];") + if (supportsBatch) { + produceBatches(ctx, input) + } else { + produceRows(ctx, input) + } + } + private def produceBatches(ctx: CodegenContext, input: String): String = { // metrics val numOutputRows = metricTerm(ctx, "numOutputRows") val scanTimeMetric = metricTerm(ctx, "scanTime") @@ -137,4 +147,25 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { """.stripMargin } + private def produceRows(ctx: CodegenContext, input: String): String = { + val numOutputRows = metricTerm(ctx, "numOutputRows") + val row = ctx.freshName("row") + + ctx.INPUT_ROW = row + ctx.currentVars = null + // Always provide `outputVars`, so that the framework can help us build unsafe row if the input + // row is not unsafe row, i.e. `needsUnsafeRowConversion` is true. + val outputVars = output.zipWithIndex.map { case (a, i) => + BoundReference(i, a.dataType, a.nullable).genCode(ctx) + } + val inputRow = if (needsUnsafeRowConversion) null else row + s""" + |while ($input.hasNext()) { + | InternalRow $row = (InternalRow) $input.next(); + | $numOutputRows.add(1); + | ${consume(ctx, outputVars, inputRow).trim} + | if (shouldStop()) return; + |} + """.stripMargin + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index d1ff82c7c06bc..7c7d79c2bbd7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -164,13 +164,15 @@ case class FileSourceScanExec( override val tableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with ColumnarBatchScan { - val supportsBatch: Boolean = relation.fileFormat.supportBatch( + override val supportsBatch: Boolean = relation.fileFormat.supportBatch( relation.sparkSession, StructType.fromAttributes(output)) - val needsUnsafeRowConversion: Boolean = if (relation.fileFormat.isInstanceOf[ParquetSource]) { - SparkSession.getActiveSession.get.sessionState.conf.parquetVectorizedReaderEnabled - } else { - false + override val needsUnsafeRowConversion: Boolean = { + if (relation.fileFormat.isInstanceOf[ParquetSource]) { + SparkSession.getActiveSession.get.sessionState.conf.parquetVectorizedReaderEnabled + } else { + false + } } override def vectorTypes: Option[Seq[String]] = @@ -346,33 +348,6 @@ case class FileSourceScanExec( override val nodeNamePrefix: String = "File" - override protected def doProduce(ctx: CodegenContext): String = { - if (supportsBatch) { - return super.doProduce(ctx) - } - val numOutputRows = metricTerm(ctx, "numOutputRows") - // PhysicalRDD always just has one input - val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];") - val row = ctx.freshName("row") - - ctx.INPUT_ROW = row - ctx.currentVars = null - // Always provide `outputVars`, so that the framework can help us build unsafe row if the input - // row is not unsafe row, i.e. `needsUnsafeRowConversion` is true. - val outputVars = output.zipWithIndex.map{ case (a, i) => - BoundReference(i, a.dataType, a.nullable).genCode(ctx) - } - val inputRow = if (needsUnsafeRowConversion) null else row - s""" - |while ($input.hasNext()) { - | InternalRow $row = (InternalRow) $input.next(); - | $numOutputRows.add(1); - | ${consume(ctx, outputVars, inputRow).trim} - | if (shouldStop()) return; - |} - """.stripMargin - } - /** * Create an RDD for bucketed reads. * The non-bucketed variant of this function is [[createNonBucketedReadRDD]]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 933b9753faa61..3565ee3af1b9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -49,9 +49,9 @@ case class InMemoryTableScanExec( /** * If true, get data from ColumnVector in ColumnarBatch, which are generally faster. - * If false, get data from UnsafeRow build from ColumnVector + * If false, get data from UnsafeRow build from CachedBatch */ - override val supportCodegen: Boolean = { + override val supportsBatch: Boolean = { // In the initial implementation, for ease of review // support only primitive data types and # of fields is less than wholeStageMaxNumFields relation.schema.fields.forall(f => f.dataType match { @@ -61,6 +61,8 @@ case class InMemoryTableScanExec( }) && !WholeStageCodegenExec.isTooManyFields(conf, relation.schema) } + override protected def needsUnsafeRowConversion: Boolean = false + private val columnIndices = attributes.map(a => relation.output.map(o => o.exprId).indexOf(a.exprId)).toArray @@ -90,14 +92,56 @@ case class InMemoryTableScanExec( columnarBatch } - override def inputRDDs(): Seq[RDD[InternalRow]] = { - assert(supportCodegen) + private lazy val inputRDD: RDD[InternalRow] = { val buffers = filteredCachedBatches() - // HACK ALERT: This is actually an RDD[ColumnarBatch]. - // We're taking advantage of Scala's type erasure here to pass these batches along. - Seq(buffers.map(createAndDecompressColumn(_)).asInstanceOf[RDD[InternalRow]]) + if (supportsBatch) { + // HACK ALERT: This is actually an RDD[ColumnarBatch]. + // We're taking advantage of Scala's type erasure here to pass these batches along. + buffers.map(createAndDecompressColumn).asInstanceOf[RDD[InternalRow]] + } else { + val numOutputRows = longMetric("numOutputRows") + + if (enableAccumulatorsForTest) { + readPartitions.setValue(0) + readBatches.setValue(0) + } + + // Using these variables here to avoid serialization of entire objects (if referenced + // directly) within the map Partitions closure. + val relOutput: AttributeSeq = relation.output + + filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator => + // Find the ordinals and data types of the requested columns. + val (requestedColumnIndices, requestedColumnDataTypes) = + attributes.map { a => + relOutput.indexOf(a.exprId) -> a.dataType + }.unzip + + // update SQL metrics + val withMetrics = cachedBatchIterator.map { batch => + if (enableAccumulatorsForTest) { + readBatches.add(1) + } + numOutputRows += batch.numRows + batch + } + + val columnTypes = requestedColumnDataTypes.map { + case udt: UserDefinedType[_] => udt.sqlType + case other => other + }.toArray + val columnarIterator = GenerateColumnAccessor.generate(columnTypes) + columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray) + if (enableAccumulatorsForTest && columnarIterator.hasNext) { + readPartitions.add(1) + } + columnarIterator + } + } } + override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) + override def output: Seq[Attribute] = attributes private def updateAttribute(expr: Expression): Expression = { @@ -185,7 +229,7 @@ case class InMemoryTableScanExec( } } - lazy val enableAccumulators: Boolean = + lazy val enableAccumulatorsForTest: Boolean = sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean // Accumulators used for testing purposes @@ -230,43 +274,10 @@ case class InMemoryTableScanExec( } protected override def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - - if (enableAccumulators) { - readPartitions.setValue(0) - readBatches.setValue(0) - } - - // Using these variables here to avoid serialization of entire objects (if referenced directly) - // within the map Partitions closure. - val relOutput: AttributeSeq = relation.output - - filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator => - // Find the ordinals and data types of the requested columns. - val (requestedColumnIndices, requestedColumnDataTypes) = - attributes.map { a => - relOutput.indexOf(a.exprId) -> a.dataType - }.unzip - - // update SQL metrics - val withMetrics = cachedBatchIterator.map { batch => - if (enableAccumulators) { - readBatches.add(1) - } - numOutputRows += batch.numRows - batch - } - - val columnTypes = requestedColumnDataTypes.map { - case udt: UserDefinedType[_] => udt.sqlType - case other => other - }.toArray - val columnarIterator = GenerateColumnAccessor.generate(columnTypes) - columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray) - if (enableAccumulators && columnarIterator.hasNext) { - readPartitions.add(1) - } - columnarIterator + if (supportsBatch) { + WholeStageCodegenExec(this).execute() + } else { + inputRDD } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 5f30be5ed4af1..ac104d7cd0cb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -18,19 +18,19 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.JavaConverters._ +import scala.reflect.ClassTag import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.sources.v2.reader.ReadTask -class DataSourceRDDPartition(val index: Int, val readTask: ReadTask[UnsafeRow]) +class DataSourceRDDPartition[T : ClassTag](val index: Int, val readTask: ReadTask[T]) extends Partition with Serializable -class DataSourceRDD( +class DataSourceRDD[T: ClassTag]( sc: SparkContext, - @transient private val readTasks: java.util.List[ReadTask[UnsafeRow]]) - extends RDD[UnsafeRow](sc, Nil) { + @transient private val readTasks: java.util.List[ReadTask[T]]) + extends RDD[T](sc, Nil) { override protected def getPartitions: Array[Partition] = { readTasks.asScala.zipWithIndex.map { @@ -38,10 +38,10 @@ class DataSourceRDD( }.toArray } - override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createDataReader() + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readTask.createDataReader() context.addTaskCompletionListener(_ => reader.close()) - val iter = new Iterator[UnsafeRow] { + val iter = new Iterator[T] { private[this] var valuePrepared = false override def hasNext: Boolean = { @@ -51,7 +51,7 @@ class DataSourceRDD( valuePrepared } - override def next(): UnsafeRow = { + override def next(): T = { if (!hasNext) { throw new java.util.NoSuchElementException("End of stream") } @@ -63,6 +63,6 @@ class DataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition].readTask.preferredLocations() + split.asInstanceOf[DataSourceRDDPartition[T]].readTask.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 49c506bc560cf..8c64df080242f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -24,10 +24,8 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.LeafExecNode -import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousDataSourceRDD, ContinuousExecution, EpochCoordinatorRef, SetReaderPartitions} +import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} +import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader import org.apache.spark.sql.types.StructType @@ -37,40 +35,56 @@ import org.apache.spark.sql.types.StructType */ case class DataSourceV2ScanExec( fullOutput: Seq[AttributeReference], - @transient reader: DataSourceV2Reader) extends LeafExecNode with DataSourceReaderHolder { + @transient reader: DataSourceV2Reader) + extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] - override def references: AttributeSet = AttributeSet.empty + override def producedAttributes: AttributeSet = AttributeSet(fullOutput) - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + private lazy val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match { + case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks() + case _ => + reader.createReadTasks().asScala.map { + new RowToUnsafeRowReadTask(_, reader.readSchema()): ReadTask[UnsafeRow] + }.asJava + } - override protected def doExecute(): RDD[InternalRow] = { - val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match { - case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks() - case _ => - reader.createReadTasks().asScala.map { - new RowToUnsafeRowReadTask(_, reader.readSchema()): ReadTask[UnsafeRow] - }.asJava - } + private lazy val inputRDD: RDD[InternalRow] = reader match { + case r: SupportsScanColumnarBatch if r.enableBatchRead() => + assert(!reader.isInstanceOf[ContinuousReader], + "continuous stream reader does not support columnar read yet.") + new DataSourceRDD(sparkContext, r.createBatchReadTasks()).asInstanceOf[RDD[InternalRow]] + + case _: ContinuousReader => + EpochCoordinatorRef.get( + sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) + .askSync[Unit](SetReaderPartitions(readTasks.size())) + new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks) + .asInstanceOf[RDD[InternalRow]] + + case _ => + new DataSourceRDD(sparkContext, readTasks).asInstanceOf[RDD[InternalRow]] + } - val inputRDD = reader match { - case _: ContinuousReader => - EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) - .askSync[Unit](SetReaderPartitions(readTasks.size())) + override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) - new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks) + override val supportsBatch: Boolean = reader match { + case r: SupportsScanColumnarBatch if r.enableBatchRead() => true + case _ => false + } - case _ => - new DataSourceRDD(sparkContext, readTasks) - } + override protected def needsUnsafeRowConversion: Boolean = false - val numOutputRows = longMetric("numOutputRows") - inputRDD.asInstanceOf[RDD[InternalRow]].map { r => - numOutputRows += 1 - r + override protected def doExecute(): RDD[InternalRow] = { + if (supportsBatch) { + WholeStageCodegenExec(this).execute() + } else { + val numOutputRows = longMetric("numOutputRows") + inputRDD.map { r => + numOutputRows += 1 + r + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index d79e4bd65f563..b3f1a1a1aaab3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -52,7 +52,7 @@ class ContinuousDataSourceRDD( } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createDataReader() + val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader() val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY) @@ -132,7 +132,7 @@ class ContinuousDataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition].readTask.preferredLocations() + split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.preferredLocations() } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java new file mode 100644 index 0000000000000..44e5146d7c553 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; +import java.util.List; + +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + + +public class JavaBatchDataSourceV2 implements DataSourceV2, ReadSupport { + + class Reader implements DataSourceV2Reader, SupportsScanColumnarBatch { + private final StructType schema = new StructType().add("i", "int").add("j", "int"); + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public List> createBatchReadTasks() { + return java.util.Arrays.asList(new JavaBatchReadTask(0, 50), new JavaBatchReadTask(50, 90)); + } + } + + static class JavaBatchReadTask implements ReadTask, DataReader { + private int start; + private int end; + + private static final int BATCH_SIZE = 20; + + private OnHeapColumnVector i; + private OnHeapColumnVector j; + private ColumnarBatch batch; + + JavaBatchReadTask(int start, int end) { + this.start = start; + this.end = end; + } + + @Override + public DataReader createDataReader() { + this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + ColumnVector[] vectors = new ColumnVector[2]; + vectors[0] = i; + vectors[1] = j; + this.batch = new ColumnarBatch(new StructType().add("i", "int").add("j", "int"), vectors, BATCH_SIZE); + return this; + } + + @Override + public boolean next() { + i.reset(); + j.reset(); + int count = 0; + while (start < end && count < BATCH_SIZE) { + i.putInt(count, start); + j.putInt(count, -start); + start += 1; + count += 1; + } + + if (count == 0) { + return false; + } else { + batch.setNumRows(count); + return true; + } + } + + @Override + public ColumnarBatch get() { + return batch; + } + + @Override + public void close() throws IOException { + batch.close(); + } + } + + + @Override + public DataSourceV2Reader createReader(DataSourceV2Options options) { + return new Reader(); + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index bc05dca578c47..22ca128c27768 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -121,31 +121,23 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { test("cache for primitive type should be in WholeStageCodegen with InMemoryTableScanExec") { import testImplicits._ - val dsInt = spark.range(3).cache - dsInt.count + val dsInt = spark.range(3).cache() + dsInt.count() val dsIntFilter = dsInt.filter(_ > 0) val planInt = dsIntFilter.queryExecution.executedPlan - assert(planInt.find(p => - p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] && - p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child - .isInstanceOf[InMemoryTableScanExec] && - p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child - .asInstanceOf[InMemoryTableScanExec].supportCodegen).isDefined - ) + assert(planInt.collect { + case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec)) if i.supportsBatch => () + }.length == 1) assert(dsIntFilter.collect() === Array(1, 2)) // cache for string type is not supported for InMemoryTableScanExec - val dsString = spark.range(3).map(_.toString).cache - dsString.count + val dsString = spark.range(3).map(_.toString).cache() + dsString.count() val dsStringFilter = dsString.filter(_ == "1") val planString = dsStringFilter.queryExecution.executedPlan - assert(planString.find(p => - p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] && - !p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child - .isInstanceOf[InMemoryTableScanExec]).isDefined - ) + assert(planString.collect { + case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec)) if !i.supportsBatch => () + }.length == 1) assert(dsStringFilter.collect() === Array("1")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index ab37e4984bd1f..a89f7c55bf4f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -24,10 +24,12 @@ import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatch class DataSourceV2Suite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -56,7 +58,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } - test("unsafe row implementation") { + test("unsafe row scan implementation") { Seq(classOf[UnsafeRowDataSourceV2], classOf[JavaUnsafeRowDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() @@ -67,6 +69,17 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } + test("columnar batch scan implementation") { + Seq(classOf[BatchDataSourceV2], classOf[JavaBatchDataSourceV2]).foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, (0 until 90).map(i => Row(i, -i))) + checkAnswer(df.select('j), (0 until 90).map(i => Row(-i))) + checkAnswer(df.filter('i > 50), (51 until 90).map(i => Row(i, -i))) + } + } + } + test("schema required data source") { Seq(classOf[SchemaRequiredDataSource], classOf[JavaSchemaRequiredDataSource]).foreach { cls => withClue(cls.getName) { @@ -275,7 +288,7 @@ class UnsafeRowReadTask(start: Int, end: Int) private var current = start - 1 - override def createDataReader(): DataReader[UnsafeRow] = new UnsafeRowReadTask(start, end) + override def createDataReader(): DataReader[UnsafeRow] = this override def next(): Boolean = { current += 1 @@ -300,3 +313,56 @@ class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { override def createReader(schema: StructType, options: DataSourceV2Options): DataSourceV2Reader = new Reader(schema) } + +class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceV2Reader with SupportsScanColumnarBatch { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + + override def createBatchReadTasks(): JList[ReadTask[ColumnarBatch]] = { + java.util.Arrays.asList(new BatchReadTask(0, 50), new BatchReadTask(50, 90)) + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader +} + +class BatchReadTask(start: Int, end: Int) + extends ReadTask[ColumnarBatch] with DataReader[ColumnarBatch] { + + private final val BATCH_SIZE = 20 + private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val batch = new ColumnarBatch( + new StructType().add("i", "int").add("j", "int"), Array(i, j), BATCH_SIZE) + + private var current = start + + override def createDataReader(): DataReader[ColumnarBatch] = this + + override def next(): Boolean = { + i.reset() + j.reset() + + var count = 0 + while (current < end && count < BATCH_SIZE) { + i.putInt(count, current) + j.putInt(count, -current) + current += 1 + count += 1 + } + + if (count == 0) { + false + } else { + batch.setNumRows(count) + true + } + } + + override def get(): ColumnarBatch = { + batch + } + + override def close(): Unit = batch.close() +} From 12db365b4faf7a185708648d246fc4a2aae0c2c0 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Tue, 16 Jan 2018 11:41:08 -0800 Subject: [PATCH 0106/2461] [SPARK-16139][TEST] Add logging functionality for leaked threads in tests ## What changes were proposed in this pull request? Lots of our tests don't properly shutdown everything they create, and end up leaking lots of threads. For example, `TaskSetManagerSuite` doesn't stop the extra `TaskScheduler` and `DAGScheduler` it creates. There are a couple more instances, eg. in `DAGSchedulerSuite`. This PR adds the possibility to print out the not properly stopped thread list after a test suite executed. The format is the following: ``` ===== FINISHED o.a.s.scheduler.DAGSchedulerSuite: 'task end event should have updated accumulators (SPARK-20342)' ===== ... ===== Global thread whitelist loaded with name /thread_whitelist from classpath: rpc-client.*, rpc-server.*, shuffle-client.*, shuffle-server.*' ===== ScalaTest-run: ===== THREADS NOT STOPPED PROPERLY ===== ScalaTest-run: dag-scheduler-event-loop ScalaTest-run: globalEventExecutor-2-5 ScalaTest-run: ===== END OF THREAD DUMP ===== ScalaTest-run: ===== EITHER PUT THREAD NAME INTO THE WHITELIST FILE OR SHUT IT DOWN PROPERLY ===== ``` With the help of this leaking threads has been identified in TaskSetManagerSuite. My intention is to hunt down and fix such bugs in later PRs. ## How was this patch tested? Manual: TaskSetManagerSuite test executed and found out where are the leaking threads. Automated: Pass the Jenkins. Author: Gabor Somogyi Closes #19893 from gaborgsomogyi/SPARK-16139. --- .../org/apache/spark/SparkFunSuite.scala | 34 +++++++ .../scala/org/apache/spark/ThreadAudit.scala | 99 +++++++++++++++++++ .../spark/scheduler/TaskSetManagerSuite.scala | 7 +- .../apache/spark/sql/SessionStateSuite.scala | 1 + .../sql/sources/DataSourceAnalysisSuite.scala | 1 + .../spark/sql/test/SharedSQLContext.scala | 23 ++++- .../hive/HiveContextCompatibilitySuite.scala | 1 + .../sql/hive/HiveSessionStateSuite.scala | 1 + .../spark/sql/hive/HiveSparkSubmitSuite.scala | 2 + .../sql/hive/client/HiveClientSuite.scala | 1 + .../sql/hive/client/HiveVersionSuite.scala | 1 + .../spark/sql/hive/client/VersionsSuite.scala | 2 + .../hive/execution/HiveComparisonTest.scala | 2 + .../hive/orc/OrcHadoopFsRelationSuite.scala | 1 + .../sql/hive/test/TestHiveSingleton.scala | 1 + 15 files changed, 174 insertions(+), 3 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/ThreadAudit.scala diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 18077c08c9dcc..3af9d82393bc4 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -27,19 +27,53 @@ import org.apache.spark.util.AccumulatorContext /** * Base abstract class for all unit tests in Spark for handling common functionality. + * + * Thread audit happens normally here automatically when a new test suite created. + * The only prerequisite for that is that the test class must extend [[SparkFunSuite]]. + * + * It is possible to override the default thread audit behavior by setting enableAutoThreadAudit + * to false and manually calling the audit methods, if desired. For example: + * + * class MyTestSuite extends SparkFunSuite { + * + * override val enableAutoThreadAudit = false + * + * protected override def beforeAll(): Unit = { + * doThreadPreAudit() + * super.beforeAll() + * } + * + * protected override def afterAll(): Unit = { + * super.afterAll() + * doThreadPostAudit() + * } + * } */ abstract class SparkFunSuite extends FunSuite with BeforeAndAfterAll + with ThreadAudit with Logging { // scalastyle:on + protected val enableAutoThreadAudit = true + + protected override def beforeAll(): Unit = { + if (enableAutoThreadAudit) { + doThreadPreAudit() + } + super.beforeAll() + } + protected override def afterAll(): Unit = { try { // Avoid leaking map entries in tests that use accumulators without SparkContext AccumulatorContext.clear() } finally { super.afterAll() + if (enableAutoThreadAudit) { + doThreadPostAudit() + } } } diff --git a/core/src/test/scala/org/apache/spark/ThreadAudit.scala b/core/src/test/scala/org/apache/spark/ThreadAudit.scala new file mode 100644 index 0000000000000..b3cea9de8f304 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ThreadAudit.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.JavaConverters._ + +import org.apache.spark.internal.Logging + +/** + * Thread audit for test suites. + */ +trait ThreadAudit extends Logging { + + val threadWhiteList = Set( + /** + * Netty related internal threads. + * These are excluded because their lifecycle is handled by the netty itself + * and spark has no explicit effect on them. + */ + "netty.*", + + /** + * Netty related internal threads. + * A Single-thread singleton EventExecutor inside netty which creates such threads. + * These are excluded because their lifecycle is handled by the netty itself + * and spark has no explicit effect on them. + */ + "globalEventExecutor.*", + + /** + * Netty related internal threads. + * Checks if a thread is alive periodically and runs a task when a thread dies. + * These are excluded because their lifecycle is handled by the netty itself + * and spark has no explicit effect on them. + */ + "threadDeathWatcher.*", + + /** + * During [[SparkContext]] creation [[org.apache.spark.rpc.netty.NettyRpcEnv]] + * creates event loops. One is wrapped inside + * [[org.apache.spark.network.server.TransportServer]] + * the other one is inside [[org.apache.spark.network.client.TransportClient]]. + * The thread pools behind shut down asynchronously triggered by [[SparkContext#stop]]. + * Manually checked and all of them stopped properly. + */ + "rpc-client.*", + "rpc-server.*", + + /** + * During [[SparkContext]] creation BlockManager creates event loops. One is wrapped inside + * [[org.apache.spark.network.server.TransportServer]] + * the other one is inside [[org.apache.spark.network.client.TransportClient]]. + * The thread pools behind shut down asynchronously triggered by [[SparkContext#stop]]. + * Manually checked and all of them stopped properly. + */ + "shuffle-client.*", + "shuffle-server.*" + ) + private var threadNamesSnapshot: Set[String] = Set.empty + + protected def doThreadPreAudit(): Unit = { + threadNamesSnapshot = runningThreadNames() + } + + protected def doThreadPostAudit(): Unit = { + val shortSuiteName = this.getClass.getName.replaceAll("org.apache.spark", "o.a.s") + + if (threadNamesSnapshot.nonEmpty) { + val remainingThreadNames = runningThreadNames().diff(threadNamesSnapshot) + .filterNot { s => threadWhiteList.exists(s.matches(_)) } + if (remainingThreadNames.nonEmpty) { + logWarning(s"\n\n===== POSSIBLE THREAD LEAK IN SUITE $shortSuiteName, " + + s"thread names: ${remainingThreadNames.mkString(", ")} =====\n") + } + } else { + logWarning("\n\n===== THREAD AUDIT POST ACTION CALLED " + + s"WITHOUT PRE ACTION IN SUITE $shortSuiteName =====\n") + } + } + + private def runningThreadNames(): Set[String] = { + Thread.getAllStackTraces.keySet().asScala.map(_.getName).toSet + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 2ce81ae27daf6..ca6a7e5db3b17 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -683,7 +683,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val conf = new SparkConf().set("spark.speculation", "true") sc = new SparkContext("local", "test", conf) - val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) + sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) sched.initialize(new FakeSchedulerBackend() { override def killTask( taskId: Long, @@ -709,6 +709,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg } } } + sched.dagScheduler.stop() sched.setDAGScheduler(dagScheduler) val singleTask = new ShuffleMapTask(0, 0, null, new Partition { @@ -754,7 +755,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sc.conf.set("spark.speculation", "true") var killTaskCalled = false - val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3")) sched.initialize(new FakeSchedulerBackend() { override def killTask( @@ -789,6 +790,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg } } } + sched.dagScheduler.stop() sched.setDAGScheduler(dagScheduler) val taskSet = FakeTask.createShuffleMapTaskSet(4, 0, 0, @@ -1183,6 +1185,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sc = new SparkContext("local", "test") sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val mockDAGScheduler = mock(classOf[DAGScheduler]) + sched.dagScheduler.stop() sched.dagScheduler = mockDAGScheduler val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = new ManualClock(1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index c01666770720c..5d75f5835bf9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -39,6 +39,7 @@ class SessionStateSuite extends SparkFunSuite protected var activeSession: SparkSession = _ override def beforeAll(): Unit = { + super.beforeAll() activeSession = SparkSession.builder().master("local").getOrCreate() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala index 735e07c21373a..e1022e377132c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala @@ -33,6 +33,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { private var targetPartitionSchema: StructType = _ override def beforeAll(): Unit = { + super.beforeAll() targetAttributes = Seq('a.int, 'd.int, 'b.int, 'c.int) targetPartitionSchema = new StructType() .add("b", IntegerType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 4d578e21f5494..e6c7648c986ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,4 +17,25 @@ package org.apache.spark.sql.test -trait SharedSQLContext extends SQLTestUtils with SharedSparkSession +trait SharedSQLContext extends SQLTestUtils with SharedSparkSession { + + /** + * Suites extending [[SharedSQLContext]] are sharing resources (eg. SparkSession) in their tests. + * That trait initializes the spark session in its [[beforeAll()]] implementation before the + * automatic thread snapshot is performed, so the audit code could fail to report threads leaked + * by that shared session. + * + * The behavior is overridden here to take the snapshot before the spark session is initialized. + */ + override protected val enableAutoThreadAudit = false + + protected override def beforeAll(): Unit = { + doThreadPreAudit() + super.beforeAll() + } + + protected override def afterAll(): Unit = { + super.afterAll() + doThreadPostAudit() + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala index 8a7423663f28d..a80db765846e9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} class HiveContextCompatibilitySuite extends SparkFunSuite with BeforeAndAfterEach { + override protected val enableAutoThreadAudit = false private var sc: SparkContext = null private var hc: HiveContext = null diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala index 958ad3e1c3ce8..f7da3c4cbb0aa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala @@ -30,6 +30,7 @@ class HiveSessionStateSuite extends SessionStateSuite override def beforeAll(): Unit = { // Reuse the singleton session + super.beforeAll() activeSession = spark } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 21b3e281490cf..10204f4694663 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -44,6 +44,8 @@ class HiveSparkSubmitSuite with BeforeAndAfterEach with ResetSystemProperties { + override protected val enableAutoThreadAudit = false + // TODO: rewrite these or mark them as slow tests to be run sparingly override def beforeEach() { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index ce53acef51503..a5dfd89b3a574 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -67,6 +67,7 @@ class HiveClientSuite(version: String) } override def beforeAll() { + super.beforeAll() client = init(true) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala index 951ebfad4590e..bb8a4697b0a13 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.hive.HiveUtils private[client] abstract class HiveVersionSuite(version: String) extends SparkFunSuite { + override protected val enableAutoThreadAudit = false protected var client: HiveClient = null protected def buildClient(hadoopConf: Configuration): HiveClient = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index e64389e56b5a1..72536b833481a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -50,6 +50,8 @@ import org.apache.spark.util.{MutableURLClassLoader, Utils} @ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { + override protected val enableAutoThreadAudit = false + import HiveClientBuilder.buildClient /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index cee82cda4628a..272e6f51f5002 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -48,6 +48,8 @@ import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution} abstract class HiveComparisonTest extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen { + override protected val enableAutoThreadAudit = false + /** * Path to the test datasets. We find this by looking up "hive-test-path-helper.txt" file. * diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index f87162f94c01a..a1f054b8e3f44 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.types._ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { import testImplicits._ + override protected val enableAutoThreadAudit = false override val dataSourceName: String = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala index df7988f542b71..d3fff37c3424d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.hive.client.HiveClient trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll { + override protected val enableAutoThreadAudit = false protected val spark: SparkSession = TestHive.sparkSession protected val hiveContext: TestHiveContext = TestHive protected val hiveClient: HiveClient = From 4371466b3f06ca171b10568e776c9446f7bae6dd Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Tue, 16 Jan 2018 12:56:57 -0800 Subject: [PATCH 0107/2461] [SPARK-23045][ML][SPARKR] Update RFormula to use OneHotEncoderEstimator. ## What changes were proposed in this pull request? RFormula should use VectorSizeHint & OneHotEncoderEstimator in its pipeline to avoid using the deprecated OneHotEncoder & to ensure the model produced can be used in streaming. ## How was this patch tested? Unit tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Bago Amirbekian Closes #20229 from MrBago/rFormula. --- R/pkg/R/mllib_utils.R | 1 - .../apache/spark/ml/feature/RFormula.scala | 20 +++++-- .../spark/ml/feature/RFormulaSuite.scala | 53 +++++++++++-------- 3 files changed, 46 insertions(+), 28 deletions(-) diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R index 23dda42c325be..a53c92c2c4815 100644 --- a/R/pkg/R/mllib_utils.R +++ b/R/pkg/R/mllib_utils.R @@ -130,4 +130,3 @@ read.ml <- function(path) { stop("Unsupported model: ", jobj) } } - diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index f384ffbf578bc..1155ea5fdd85b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -199,6 +199,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) val parsedFormula = RFormulaParser.parse($(formula)) val resolvedFormula = parsedFormula.resolve(dataset.schema) val encoderStages = ArrayBuffer[PipelineStage]() + val oneHotEncodeColumns = ArrayBuffer[(String, String)]() val prefixesToRewrite = mutable.Map[String, String]() val tempColumns = ArrayBuffer[String]() @@ -242,16 +243,17 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) val encodedTerms = resolvedFormula.terms.map { case Seq(term) if dataset.schema(term).dataType == StringType => val encodedCol = tmpColumn("onehot") - var encoder = new OneHotEncoder() - .setInputCol(indexed(term)) - .setOutputCol(encodedCol) // Formula w/o intercept, one of the categories in the first category feature is // being used as reference category, we will not drop any category for that feature. if (!hasIntercept && !keepReferenceCategory) { - encoder = encoder.setDropLast(false) + encoderStages += new OneHotEncoderEstimator(uid) + .setInputCols(Array(indexed(term))) + .setOutputCols(Array(encodedCol)) + .setDropLast(false) keepReferenceCategory = true + } else { + oneHotEncodeColumns += indexed(term) -> encodedCol } - encoderStages += encoder prefixesToRewrite(encodedCol + "_") = term + "_" encodedCol case Seq(term) => @@ -265,6 +267,14 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) interactionCol } + if (oneHotEncodeColumns.nonEmpty) { + val (inputCols, outputCols) = oneHotEncodeColumns.toArray.unzip + encoderStages += new OneHotEncoderEstimator(uid) + .setInputCols(inputCols) + .setOutputCols(outputCols) + .setDropLast(true) + } + encoderStages += new VectorAssembler(uid) .setInputCols(encodedTerms.toArray) .setOutputCol($(featuresCol)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index f3f4b5a3d0233..bfe38d32dd77d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -29,6 +29,17 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ + def testRFormulaTransform[A: Encoder]( + dataframe: DataFrame, + formulaModel: RFormulaModel, + expected: DataFrame): Unit = { + val (first +: rest) = expected.schema.fieldNames.toSeq + val expectedRows = expected.collect() + testTransformerByGlobalCheckFunc[A](dataframe, formulaModel, first, rest: _*) { rows => + assert(rows === expectedRows) + } + } + test("params") { ParamsSuite.checkParams(new RFormula()) } @@ -47,7 +58,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString assert(result.schema.toString == resultSchema.toString) assert(resultSchema == expected.schema) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, Double, Double)](original, model, expected) } test("features column already exists") { @@ -109,7 +120,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (7, 8.0, 9.0, Vectors.dense(8.0, 9.0)) ).toDF("id", "a", "b", "features") assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, Double, Double)](original, model, expected) } test("encodes string terms") { @@ -126,7 +137,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0) ).toDF("id", "a", "b", "features", "label") assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, String, Int)](original, model, expected) } test("encodes string terms with string indexer order type") { @@ -167,7 +178,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected(idx).collect()) + testRFormulaTransform[(Int, String, Int)](original, model, expected(idx)) idx += 1 } } @@ -210,7 +221,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, String, Int)](original, model, expected) } test("formula w/o intercept, we should output reference category when encoding string terms") { @@ -253,7 +264,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0) ).toDF("id", "a", "b", "c", "features", "label") assert(result1.schema.toString == resultSchema1.toString) - assert(result1.collect() === expected1.collect()) + testRFormulaTransform[(Int, String, String, Int)](original, model1, expected1) val attrs1 = AttributeGroup.fromStructField(result1.schema("features")) val expectedAttrs1 = new AttributeGroup( @@ -280,7 +291,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (4, "baz", "zz", 5, Vectors.sparse(7, Array(2, 6), Array(1.0, 5.0)), 4.0) ).toDF("id", "a", "b", "c", "features", "label") assert(result2.schema.toString == resultSchema2.toString) - assert(result2.collect() === expected2.collect()) + testRFormulaTransform[(Int, String, String, Int)](original, model2, expected2) val attrs2 = AttributeGroup.fromStructField(result2.schema("features")) val expectedAttrs2 = new AttributeGroup( @@ -302,7 +313,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) val expected = Seq( ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), @@ -310,7 +320,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0) ).toDF("id", "a", "b", "features", "label") // assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(String, String, Int)](original, model, expected) } test("force to index label even it is numeric type") { @@ -319,7 +329,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5)) ).toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) val expected = spark.createDataFrame( Seq( (1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0), @@ -327,7 +336,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0), (1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0)) ).toDF("id", "a", "b", "features", "label") - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Double, String, Int)](original, model, expected) } test("attribute generation") { @@ -391,7 +400,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (1, 2, 4, 2, Vectors.dense(16.0), 1.0), (2, 3, 4, 1, Vectors.dense(12.0), 2.0) ).toDF("a", "b", "c", "d", "features", "label") - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, Int, Int, Int)](original, model, expected) val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", @@ -414,7 +423,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0) ).toDF("id", "a", "b", "features", "label") - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, String, Int)](original, model, expected) val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", @@ -436,7 +445,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0), (3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0) ).toDF("id", "a", "b", "features", "label") - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, String, String)](original, model, expected) val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", @@ -511,8 +520,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { intercept[SparkException] { formula1.fit(df1).transform(df2).collect() } - val result1 = formula1.setHandleInvalid("skip").fit(df1).transform(df2) - val result2 = formula1.setHandleInvalid("keep").fit(df1).transform(df2) + val model1 = formula1.setHandleInvalid("skip").fit(df1) + val model2 = formula1.setHandleInvalid("keep").fit(df1) val expected1 = Seq( (1, "foo", "zq", Vectors.dense(0.0, 1.0), 1.0), @@ -524,16 +533,16 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (3, "bar", "zy", Vectors.dense(1.0, 0.0, 0.0, 0.0), 3.0) ).toDF("id", "a", "b", "features", "label") - assert(result1.collect() === expected1.collect()) - assert(result2.collect() === expected2.collect()) + testRFormulaTransform[(Int, String, String)](df2, model1, expected1) + testRFormulaTransform[(Int, String, String)](df2, model2, expected2) // Handle unseen labels. val formula2 = new RFormula().setFormula("b ~ a + id") intercept[SparkException] { formula2.fit(df1).transform(df2).collect() } - val result3 = formula2.setHandleInvalid("skip").fit(df1).transform(df2) - val result4 = formula2.setHandleInvalid("keep").fit(df1).transform(df2) + val model3 = formula2.setHandleInvalid("skip").fit(df1) + val model4 = formula2.setHandleInvalid("keep").fit(df1) val expected3 = Seq( (1, "foo", "zq", Vectors.dense(0.0, 1.0), 0.0), @@ -545,8 +554,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (3, "bar", "zy", Vectors.dense(1.0, 0.0, 3.0), 2.0) ).toDF("id", "a", "b", "features", "label") - assert(result3.collect() === expected3.collect()) - assert(result4.collect() === expected4.collect()) + testRFormulaTransform[(Int, String, String)](df2, model3, expected3) + testRFormulaTransform[(Int, String, String)](df2, model4, expected4) } test("Use Vectors as inputs to formula.") { From 5ae333391bd73331b5b90af71a3de52cdbb24109 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 16 Jan 2018 16:25:10 -0800 Subject: [PATCH 0108/2461] [SPARK-23044] Error handling for jira assignment ## What changes were proposed in this pull request? * If there is any error while trying to assign the jira, prompt again * Filter out the "Apache Spark" choice * allow arbitrary user ids to be entered ## How was this patch tested? Couldn't really test the error case, just some testing of similar-ish code in python shell. Haven't run a merge yet. Author: Imran Rashid Closes #20236 from squito/SPARK-23044. --- dev/merge_spark_pr.py | 50 +++++++++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 57ca8400b6f3d..6b244d8184b2c 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -30,6 +30,7 @@ import re import subprocess import sys +import traceback import urllib2 try: @@ -298,24 +299,37 @@ def choose_jira_assignee(issue, asf_jira): Prompt the user to choose who to assign the issue to in jira, given a list of candidates, including the original reporter and all commentors """ - reporter = issue.fields.reporter - commentors = map(lambda x: x.author, issue.fields.comment.comments) - candidates = set(commentors) - candidates.add(reporter) - candidates = list(candidates) - print("JIRA is unassigned, choose assignee") - for idx, author in enumerate(candidates): - annotations = ["Reporter"] if author == reporter else [] - if author in commentors: - annotations.append("Commentor") - print("[%d] %s (%s)" % (idx, author.displayName, ",".join(annotations))) - assignee = raw_input("Enter number of user to assign to (blank to leave unassigned):") - if assignee == "": - return None - else: - assignee = candidates[int(assignee)] - asf_jira.assign_issue(issue.key, assignee.key) - return assignee + while True: + try: + reporter = issue.fields.reporter + commentors = map(lambda x: x.author, issue.fields.comment.comments) + candidates = set(commentors) + candidates.add(reporter) + candidates = list(candidates) + print("JIRA is unassigned, choose assignee") + for idx, author in enumerate(candidates): + if author.key == "apachespark": + continue + annotations = ["Reporter"] if author == reporter else [] + if author in commentors: + annotations.append("Commentor") + print("[%d] %s (%s)" % (idx, author.displayName, ",".join(annotations))) + raw_assignee = raw_input( + "Enter number of user, or userid, to assign to (blank to leave unassigned):") + if raw_assignee == "": + return None + else: + try: + id = int(raw_assignee) + assignee = candidates[id] + except: + # assume it's a user id, and try to assign (might fail, we just prompt again) + assignee = asf_jira.user(raw_assignee) + asf_jira.assign_issue(issue.key, assignee.key) + return assignee + except: + traceback.print_exc() + print("Error assigning JIRA, try again (or leave blank and fix manually)") def resolve_jira_issues(title, merge_branches, comment): From 0c2ba427bc7323729e6ffb34f1f06a97f0bf0c1d Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 17 Jan 2018 09:57:30 +0800 Subject: [PATCH 0109/2461] [SPARK-23095][SQL] Decorrelation of scalar subquery fails with java.util.NoSuchElementException ## What changes were proposed in this pull request? The following SQL involving scalar correlated query returns a map exception. ``` SQL SELECT t1a FROM t1 WHERE t1a = (SELECT count(*) FROM t2 WHERE t2c = t1c HAVING count(*) >= 1) ``` ``` SQL key not found: ExprId(278,786682bb-41f9-4bd5-a397-928272cc8e4e) java.util.NoSuchElementException: key not found: ExprId(278,786682bb-41f9-4bd5-a397-928272cc8e4e) at scala.collection.MapLike$class.default(MapLike.scala:228) at scala.collection.AbstractMap.default(Map.scala:59) at scala.collection.MapLike$class.apply(MapLike.scala:141) at scala.collection.AbstractMap.apply(Map.scala:59) at org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery$.org$apache$spark$sql$catalyst$optimizer$RewriteCorrelatedScalarSubquery$$evalSubqueryOnZeroTups(subquery.scala:378) at org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery$$anonfun$org$apache$spark$sql$catalyst$optimizer$RewriteCorrelatedScalarSubquery$$constructLeftJoins$1.apply(subquery.scala:430) at org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery$$anonfun$org$apache$spark$sql$catalyst$optimizer$RewriteCorrelatedScalarSubquery$$constructLeftJoins$1.apply(subquery.scala:426) ``` In this case, after evaluating the HAVING clause "count(*) > 1" statically against the binding of aggregtation result on empty input, we determine that this query will not have a the count bug. We should simply return the evalSubqueryOnZeroTups with empty value. (Please fill in changes proposed in this fix) ## How was this patch tested? A new test was added in the Subquery bucket. Author: Dilip Biswal Closes #20283 from dilipbiswal/scalar-count-defect. --- .../sql/catalyst/optimizer/subquery.scala | 5 +- .../scalar-subquery-predicate.sql | 10 ++++ .../scalar-subquery-predicate.sql.out | 57 ++++++++++++------- 3 files changed, 49 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 2673bea648d09..709db6d8bec7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -369,13 +369,14 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { case ne => (ne.exprId, evalAggOnZeroTups(ne)) }.toMap - case _ => sys.error(s"Unexpected operator in scalar subquery: $lp") + case _ => + sys.error(s"Unexpected operator in scalar subquery: $lp") } val resultMap = evalPlan(plan) // By convention, the scalar subquery result is the leftmost field. - resultMap(plan.output.head.exprId) + resultMap.getOrElse(plan.output.head.exprId, None) } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql index fb0d07fbdace7..1661209093fc4 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql @@ -173,6 +173,16 @@ WHERE t1a = (SELECT max(t2a) HAVING count(*) >= 0) OR t1i > '2014-12-31'; +-- TC 02.03.01 +SELECT t1a +FROM t1 +WHERE t1a = (SELECT max(t2a) + FROM t2 + WHERE t2c = t1c + GROUP BY t2c + HAVING count(*) >= 1) +OR t1i > '2014-12-31'; + -- TC 02.04 -- t1 on the right of an outer join -- can be reduced to inner join diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out index 8b29300e71f90..a2b86db3e4f4c 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 29 -- !query 0 @@ -293,6 +293,21 @@ val1d -- !query 19 +SELECT t1a +FROM t1 +WHERE t1a = (SELECT max(t2a) + FROM t2 + WHERE t2c = t1c + GROUP BY t2c + HAVING count(*) >= 1) +OR t1i > '2014-12-31' +-- !query 19 schema +struct +-- !query 19 output +val1c +val1d + +-- !query 22 SELECT count(t1a) FROM t1 RIGHT JOIN t2 ON t1d = t2d @@ -300,13 +315,13 @@ WHERE t1a < (SELECT max(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 19 schema +-- !query 22 schema struct --- !query 19 output +-- !query 22 output 7 --- !query 20 +-- !query 23 SELECT t1a FROM t1 WHERE t1b <= (SELECT max(t2b) @@ -317,14 +332,14 @@ AND t1b >= (SELECT min(t2b) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 20 schema +-- !query 23 schema struct --- !query 20 output +-- !query 23 output val1b val1c --- !query 21 +-- !query 24 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -338,14 +353,14 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 21 schema +-- !query 24 schema struct --- !query 21 output +-- !query 24 output val1b val1c --- !query 22 +-- !query 25 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -359,9 +374,9 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 22 schema +-- !query 25 schema struct --- !query 22 output +-- !query 25 output val1a val1a val1b @@ -372,7 +387,7 @@ val1d val1d --- !query 23 +-- !query 26 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -386,16 +401,16 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 23 schema +-- !query 26 schema struct --- !query 23 output +-- !query 26 output val1a val1b val1c val1d --- !query 24 +-- !query 27 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -409,13 +424,13 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 24 schema +-- !query 27 schema struct --- !query 24 output +-- !query 27 output val1a --- !query 25 +-- !query 28 SELECT t1a FROM t1 GROUP BY t1a, t1c @@ -423,8 +438,8 @@ HAVING max(t1b) <= (SELECT max(t2b) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 25 schema +-- !query 28 schema struct --- !query 25 output +-- !query 28 output val1b val1c From a9b845ebb5b51eb619cfa7d73b6153024a6a420d Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Wed, 17 Jan 2018 10:03:25 +0800 Subject: [PATCH 0110/2461] [SPARK-22361][SQL][TEST] Add unit test for Window Frames ## What changes were proposed in this pull request? There are already quite a few integration tests using window frames, but the unit tests coverage is not ideal. In this PR the already existing tests are reorganized, extended and where gaps found additional cases added. ## How was this patch tested? Automated: Pass the Jenkins. Author: Gabor Somogyi Closes #20019 from gaborgsomogyi/SPARK-22361. --- .../parser/ExpressionParserSuite.scala | 57 ++- .../sql/DataFrameWindowFramesSuite.scala | 405 ++++++++++++++++++ .../sql/DataFrameWindowFunctionsSuite.scala | 243 ----------- 3 files changed, 454 insertions(+), 251 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 2b9783a3295c6..cb8a1fecb80a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -249,8 +249,8 @@ class ExpressionParserSuite extends PlanTest { assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b))) assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b))) assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b))) - assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) - assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) + assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc))) + assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc))) assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc))) assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc))) @@ -263,21 +263,62 @@ class ExpressionParserSuite extends PlanTest { "sum(product + 1) over (partition by ((product / 2) + 1) order by 2)", WindowExpression('sum.function('product + 1), WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + } + + test("range/rows window function expressions") { + val func = 'foo.function(star()) + def windowed( + partitioning: Seq[Expression] = Seq.empty, + ordering: Seq[SortOrder] = Seq.empty, + frame: WindowFrame = UnspecifiedFrame): Expression = { + WindowExpression(func, WindowSpecDefinition(partitioning, ordering, frame)) + } - // Range/Row val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame)) val boundaries = Seq( - ("10 preceding", -Literal(10), CurrentRow), + // No between combinations + ("unbounded preceding", UnboundedPreceding, CurrentRow), ("2147483648 preceding", -Literal(2147483648L), CurrentRow), + ("10 preceding", -Literal(10), CurrentRow), + ("3 + 1 preceding", -Add(Literal(3), Literal(1)), CurrentRow), + ("0 preceding", -Literal(0), CurrentRow), + ("current row", CurrentRow, CurrentRow), + ("0 following", Literal(0), CurrentRow), ("3 + 1 following", Add(Literal(3), Literal(1)), CurrentRow), - ("unbounded preceding", UnboundedPreceding, CurrentRow), + ("10 following", Literal(10), CurrentRow), + ("2147483649 following", Literal(2147483649L), CurrentRow), ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis + + // Between combinations + ("between unbounded preceding and 5 following", + UnboundedPreceding, Literal(5)), + ("between unbounded preceding and 3 + 1 following", + UnboundedPreceding, Add(Literal(3), Literal(1))), + ("between unbounded preceding and 2147483649 following", + UnboundedPreceding, Literal(2147483649L)), ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow), - ("between unbounded preceding and unbounded following", - UnboundedPreceding, UnboundedFollowing), + ("between 2147483648 preceding and current row", -Literal(2147483648L), CurrentRow), ("between 10 preceding and current row", -Literal(10), CurrentRow), + ("between 3 + 1 preceding and current row", -Add(Literal(3), Literal(1)), CurrentRow), + ("between 0 preceding and current row", -Literal(0), CurrentRow), + ("between current row and current row", CurrentRow, CurrentRow), + ("between current row and 0 following", CurrentRow, Literal(0)), ("between current row and 5 following", CurrentRow, Literal(5)), - ("between 10 preceding and 5 following", -Literal(10), Literal(5)) + ("between current row and 3 + 1 following", CurrentRow, Add(Literal(3), Literal(1))), + ("between current row and 2147483649 following", CurrentRow, Literal(2147483649L)), + ("between current row and unbounded following", CurrentRow, UnboundedFollowing), + ("between 2147483648 preceding and unbounded following", + -Literal(2147483648L), UnboundedFollowing), + ("between 10 preceding and unbounded following", + -Literal(10), UnboundedFollowing), + ("between 3 + 1 preceding and unbounded following", + -Add(Literal(3), Literal(1)), UnboundedFollowing), + ("between 0 preceding and unbounded following", -Literal(0), UnboundedFollowing), + + // Between partial and full range + ("between 10 preceding and 5 following", -Literal(10), Literal(5)), + ("between unbounded preceding and unbounded following", + UnboundedPreceding, UnboundedFollowing) ) frameTypes.foreach { case (frameTypeSql, frameType) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala new file mode 100644 index 0000000000000..0ee9b0edc02b2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala @@ -0,0 +1,405 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * Window frame testing for DataFrame API. + */ +class DataFrameWindowFramesSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("lead/lag with empty data frame") { + val df = Seq.empty[(Int, String)].toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + lead("value", 1).over(window), + lag("value", 1).over(window)), + Nil) + } + + test("lead/lag with positive offset") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + $"key", + lead("value", 1).over(window), + lag("value", 1).over(window)), + Row(1, "3", null) :: Row(1, null, "1") :: Row(2, "4", null) :: Row(2, null, "2") :: Nil) + } + + test("reverse lead/lag with positive offset") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value".desc) + + checkAnswer( + df.select( + $"key", + lead("value", 1).over(window), + lag("value", 1).over(window)), + Row(1, "1", null) :: Row(1, null, "3") :: Row(2, "2", null) :: Row(2, null, "4") :: Nil) + } + + test("lead/lag with negative offset") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + $"key", + lead("value", -1).over(window), + lag("value", -1).over(window)), + Row(1, null, "3") :: Row(1, "1", null) :: Row(2, null, "4") :: Row(2, "2", null) :: Nil) + } + + test("reverse lead/lag with negative offset") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value".desc) + + checkAnswer( + df.select( + $"key", + lead("value", -1).over(window), + lag("value", -1).over(window)), + Row(1, null, "1") :: Row(1, "3", null) :: Row(2, null, "2") :: Row(2, "4", null) :: Nil) + } + + test("lead/lag with default value") { + val default = "n/a" + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4"), (2, "5")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + $"key", + lead("value", 2, default).over(window), + lag("value", 2, default).over(window), + lead("value", -2, default).over(window), + lag("value", -2, default).over(window)), + Row(1, default, default, default, default) :: Row(1, default, default, default, default) :: + Row(2, "5", default, default, "5") :: Row(2, default, "2", "2", default) :: + Row(2, default, default, default, default) :: Nil) + } + + test("rows/range between with empty data frame") { + val df = Seq.empty[(String, Int)].toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + 'key, + first("value").over( + window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + first("value").over( + window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Nil) + } + + test("rows between should accept int/long values as boundary") { + val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), (3L, "2"), (2L, "1"), (2147483650L, "2")) + .toDF("key", "value") + + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483647))), + Seq(Row(1, 3), Row(1, 4), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) + ) + + val e = intercept[AnalysisException]( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483648L)))) + assert(e.message.contains("Boundary end is not a valid integer: 2147483648")) + } + + test("range between should accept at most one ORDER BY expression when unbounded") { + val df = Seq((1, 1)).toDF("key", "value") + val window = Window.orderBy($"key", $"value") + + checkAnswer( + df.select( + $"key", + min("key").over( + window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Seq(Row(1, 1)) + ) + + val e1 = intercept[AnalysisException]( + df.select( + min("key").over(window.rangeBetween(Window.unboundedPreceding, 1)))) + assert(e1.message.contains("A range window frame with value boundaries cannot be used in a " + + "window specification with multiple order by expressions")) + + val e2 = intercept[AnalysisException]( + df.select( + min("key").over(window.rangeBetween(-1, Window.unboundedFollowing)))) + assert(e2.message.contains("A range window frame with value boundaries cannot be used in a " + + "window specification with multiple order by expressions")) + + val e3 = intercept[AnalysisException]( + df.select( + min("key").over(window.rangeBetween(-1, 1)))) + assert(e3.message.contains("A range window frame with value boundaries cannot be used in a " + + "window specification with multiple order by expressions")) + } + + test("range between should accept numeric values only when bounded") { + val df = Seq("non_numeric").toDF("value") + val window = Window.orderBy($"value") + + checkAnswer( + df.select( + $"value", + min("value").over( + window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Row("non_numeric", "non_numeric") :: Nil) + + val e1 = intercept[AnalysisException]( + df.select( + min("value").over(window.rangeBetween(Window.unboundedPreceding, 1)))) + assert(e1.message.contains("The data type of the upper bound 'string' " + + "does not match the expected data type")) + + val e2 = intercept[AnalysisException]( + df.select( + min("value").over(window.rangeBetween(-1, Window.unboundedFollowing)))) + assert(e2.message.contains("The data type of the lower bound 'string' " + + "does not match the expected data type")) + + val e3 = intercept[AnalysisException]( + df.select( + min("value").over(window.rangeBetween(-1, 1)))) + assert(e3.message.contains("The data type of the lower bound 'string' " + + "does not match the expected data type")) + } + + test("range between should accept int/long values as boundary") { + val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), (3L, "2"), (2L, "1"), (2147483650L, "2")) + .toDF("key", "value") + + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))), + Seq(Row(1, 3), Row(1, 3), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) + ) + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))), + Seq(Row(1, 2), Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1)) + ) + + def dt(date: String): Date = Date.valueOf(date) + + val df2 = Seq((dt("2017-08-01"), "1"), (dt("2017-08-01"), "1"), (dt("2020-12-31"), "1"), + (dt("2017-08-03"), "2"), (dt("2017-08-02"), "1"), (dt("2020-12-31"), "2")) + .toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key").rangeBetween(lit(0), lit(2)) + + checkAnswer( + df2.select( + $"key", + count("key").over(window)), + Seq(Row(dt("2017-08-01"), 3), Row(dt("2017-08-01"), 3), Row(dt("2020-12-31"), 1), + Row(dt("2017-08-03"), 1), Row(dt("2017-08-02"), 1), Row(dt("2020-12-31"), 1)) + ) + } + + test("range between should accept double values as boundary") { + val df = Seq((1.0D, "1"), (1.0D, "1"), (100.001D, "1"), (3.3D, "2"), (2.02D, "1"), + (100.001D, "2")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key").rangeBetween(currentRow, lit(2.5D)) + + checkAnswer( + df.select( + $"key", + count("key").over(window)), + Seq(Row(1.0, 3), Row(1.0, 3), Row(100.001, 1), Row(3.3, 1), Row(2.02, 1), Row(100.001, 1)) + ) + } + + test("range between should accept interval values as boundary") { + def ts(timestamp: Long): Timestamp = new Timestamp(timestamp * 1000) + + val df = Seq((ts(1501545600), "1"), (ts(1501545600), "1"), (ts(1609372800), "1"), + (ts(1503000000), "2"), (ts(1502000000), "1"), (ts(1609372800), "2")) + .toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key") + .rangeBetween(currentRow, lit(CalendarInterval.fromString("interval 23 days 4 hours"))) + + checkAnswer( + df.select( + $"key", + count("key").over(window)), + Seq(Row(ts(1501545600), 3), Row(ts(1501545600), 3), Row(ts(1609372800), 1), + Row(ts(1503000000), 1), Row(ts(1502000000), 1), Row(ts(1609372800), 1)) + ) + } + + test("unbounded rows/range between with aggregation") { + val df = Seq(("one", 1), ("two", 2), ("one", 3), ("two", 4)).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + 'key, + sum("value").over(window. + rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + sum("value").over(window. + rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Row("one", 4, 4) :: Row("one", 4, 4) :: Row("two", 6, 6) :: Row("two", 6, 6) :: Nil) + } + + test("unbounded preceding/following rows between with aggregation") { + val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key") + + checkAnswer( + df.select( + $"key", + last("key").over( + window.rowsBetween(Window.currentRow, Window.unboundedFollowing)), + last("key").over( + window.rowsBetween(Window.unboundedPreceding, Window.currentRow))), + Row(1, 1, 1) :: Row(2, 3, 2) :: Row(3, 3, 3) :: Row(1, 4, 1) :: Row(2, 4, 2) :: + Row(4, 4, 4) :: Nil) + } + + test("reverse unbounded preceding/following rows between with aggregation") { + val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key".desc) + + checkAnswer( + df.select( + $"key", + last("key").over( + window.rowsBetween(Window.currentRow, Window.unboundedFollowing)), + last("key").over( + window.rowsBetween(Window.unboundedPreceding, Window.currentRow))), + Row(1, 1, 1) :: Row(3, 2, 3) :: Row(2, 2, 2) :: Row(4, 1, 4) :: Row(2, 1, 2) :: + Row(1, 1, 1) :: Nil) + } + + test("unbounded preceding/following range between with aggregation") { + val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") + val window = Window.partitionBy("value").orderBy("key") + + checkAnswer( + df.select( + $"key", + avg("key").over(window.rangeBetween(Window.unboundedPreceding, 1)) + .as("avg_key1"), + avg("key").over(window.rangeBetween(Window.currentRow, Window.unboundedFollowing)) + .as("avg_key2")), + Row(3, 3.0d, 4.0d) :: Row(5, 4.0d, 5.0d) :: Row(2, 2.0d, 17.0d / 4.0d) :: + Row(4, 11.0d / 3.0d, 5.0d) :: Row(5, 17.0d / 4.0d, 11.0d / 2.0d) :: + Row(6, 17.0d / 4.0d, 6.0d) :: Nil) + } + + // This is here to illustrate the fact that reverse order also reverses offsets. + test("reverse preceding/following range between with aggregation") { + val df = Seq(1, 2, 4, 3, 2, 1).toDF("value") + val window = Window.orderBy($"value".desc) + + checkAnswer( + df.select( + $"value", + sum($"value").over(window.rangeBetween(Window.unboundedPreceding, 1)), + sum($"value").over(window.rangeBetween(1, Window.unboundedFollowing))), + Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: Row(3, 11, 6) :: + Row(2, 13, 2) :: Row(1, 13, null) :: Nil) + } + + test("sliding rows between with aggregation") { + val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2) + + checkAnswer( + df.select( + $"key", + avg("key").over(window)), + Row(1, 4.0d / 3.0d) :: Row(1, 4.0d / 3.0d) :: Row(2, 3.0d / 2.0d) :: Row(2, 2.0d) :: + Row(2, 2.0d) :: Nil) + } + + test("reverse sliding rows between with aggregation") { + val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key".desc).rowsBetween(-1, 2) + + checkAnswer( + df.select( + $"key", + avg("key").over(window)), + Row(1, 1.0d) :: Row(1, 4.0d / 3.0d) :: Row(2, 4.0d / 3.0d) :: Row(2, 2.0d) :: + Row(2, 2.0d) :: Nil) + } + + test("sliding range between with aggregation") { + val df = Seq((1, "1"), (1, "1"), (3, "1"), (2, "2"), (2, "1"), (2, "2")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1) + + checkAnswer( + df.select( + $"key", + avg("key").over(window)), + Row(1, 4.0d / 3.0d) :: Row(1, 4.0d / 3.0d) :: Row(2, 7.0d / 4.0d) :: Row(3, 5.0d / 2.0d) :: + Row(2, 2.0d) :: Row(2, 2.0d) :: Nil) + } + + test("reverse sliding range between with aggregation") { + val df = Seq( + (1, "Thin", "Cell Phone", 6000), + (2, "Normal", "Tablet", 1500), + (3, "Mini", "Tablet", 5500), + (4, "Ultra thin", "Cell Phone", 5500), + (5, "Very thin", "Cell Phone", 6000), + (6, "Big", "Tablet", 2500), + (7, "Bendable", "Cell Phone", 3000), + (8, "Foldable", "Cell Phone", 3000), + (9, "Pro", "Tablet", 4500), + (10, "Pro2", "Tablet", 6500)). + toDF("id", "product", "category", "revenue") + val window = Window.partitionBy($"category").orderBy($"revenue".desc). + rangeBetween(-2000L, 1000L) + + checkAnswer( + df.select( + $"id", + avg($"revenue").over(window).cast("int")), + Row(1, 5833) :: Row(2, 2000) :: Row(3, 5500) :: + Row(4, 5833) :: Row(5, 5833) :: Row(6, 2833) :: + Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: + Row(10, 6000) :: Nil) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 01c988ecc3726..281147835abde 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -55,56 +55,6 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) } - test("Window.rowsBetween") { - val df = Seq(("one", 1), ("two", 2)).toDF("key", "value") - // Running (cumulative) sum - checkAnswer( - df.select('key, sum("value").over( - Window.rowsBetween(Window.unboundedPreceding, Window.currentRow))), - Row("one", 1) :: Row("two", 3) :: Nil - ) - } - - test("lead") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - - checkAnswer( - df.select( - lead("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - Row("1") :: Row(null) :: Row("2") :: Row(null) :: Nil) - } - - test("lag") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - - checkAnswer( - df.select( - lag("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - Row(null) :: Row("1") :: Row(null) :: Row("2") :: Nil) - } - - test("lead with default value") { - val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), - (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))), - Seq(Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2"), Row("n/a"), Row("n/a"))) - } - - test("lag with default value") { - val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), - (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))), - Seq(Row("n/a"), Row("n/a"), Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2"))) - } - test("rank functions in unspecific window") { val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value") df.createOrReplaceTempView("window_table") @@ -136,199 +86,6 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { assert(e.message.contains("requires window to be ordered")) } - test("aggregation and rows between") { - val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))), - Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(3.0d / 2.0d), Row(2.0d), Row(2.0d))) - } - - test("aggregation and range between") { - val df = Seq((1, "1"), (1, "1"), (3, "1"), (2, "2"), (2, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))), - Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(7.0d / 4.0d), Row(5.0d / 2.0d), - Row(2.0d), Row(2.0d))) - } - - test("row between should accept integer values as boundary") { - val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), - (3L, "2"), (2L, "1"), (2147483650L, "2")) - .toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483647))), - Seq(Row(1, 3), Row(1, 4), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) - ) - - val e = intercept[AnalysisException]( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483648L)))) - assert(e.message.contains("Boundary end is not a valid integer: 2147483648")) - } - - test("range between should accept int/long values as boundary") { - val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), - (3L, "2"), (2L, "1"), (2147483650L, "2")) - .toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))), - Seq(Row(1, 3), Row(1, 3), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) - ) - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))), - Seq(Row(1, 2), Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1)) - ) - - def dt(date: String): Date = Date.valueOf(date) - - val df2 = Seq((dt("2017-08-01"), "1"), (dt("2017-08-01"), "1"), (dt("2020-12-31"), "1"), - (dt("2017-08-03"), "2"), (dt("2017-08-02"), "1"), (dt("2020-12-31"), "2")) - .toDF("key", "value") - checkAnswer( - df2.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(lit(0), lit(2)))), - Seq(Row(dt("2017-08-01"), 3), Row(dt("2017-08-01"), 3), Row(dt("2020-12-31"), 1), - Row(dt("2017-08-03"), 1), Row(dt("2017-08-02"), 1), Row(dt("2020-12-31"), 1)) - ) - } - - test("range between should accept double values as boundary") { - val df = Seq((1.0D, "1"), (1.0D, "1"), (100.001D, "1"), - (3.3D, "2"), (2.02D, "1"), (100.001D, "2")) - .toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key") - .rangeBetween(currentRow, lit(2.5D)))), - Seq(Row(1.0, 3), Row(1.0, 3), Row(100.001, 1), Row(3.3, 1), Row(2.02, 1), Row(100.001, 1)) - ) - } - - test("range between should accept interval values as boundary") { - def ts(timestamp: Long): Timestamp = new Timestamp(timestamp * 1000) - - val df = Seq((ts(1501545600), "1"), (ts(1501545600), "1"), (ts(1609372800), "1"), - (ts(1503000000), "2"), (ts(1502000000), "1"), (ts(1609372800), "2")) - .toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key") - .rangeBetween(currentRow, - lit(CalendarInterval.fromString("interval 23 days 4 hours"))))), - Seq(Row(ts(1501545600), 3), Row(ts(1501545600), 3), Row(ts(1609372800), 1), - Row(ts(1503000000), 1), Row(ts(1502000000), 1), Row(ts(1609372800), 1)) - ) - } - - test("aggregation and rows between with unbounded") { - val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - last("key").over( - Window.partitionBy($"value").orderBy($"key") - .rowsBetween(Window.currentRow, Window.unboundedFollowing)), - last("key").over( - Window.partitionBy($"value").orderBy($"key") - .rowsBetween(Window.unboundedPreceding, Window.currentRow)), - last("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))), - Seq(Row(1, 1, 1, 1), Row(2, 3, 2, 3), Row(3, 3, 3, 3), Row(1, 4, 1, 2), Row(2, 4, 2, 4), - Row(4, 4, 4, 4))) - } - - test("aggregation and range between with unbounded") { - val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - last("value").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(-2, -1)) - .equalTo("2") - .as("last_v"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1)) - .as("avg_key1"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, Long.MaxValue)) - .as("avg_key2"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0)) - .as("avg_key3") - ), - Seq(Row(3, null, 3.0d, 4.0d, 3.0d), - Row(5, false, 4.0d, 5.0d, 5.0d), - Row(2, null, 2.0d, 17.0d / 4.0d, 2.0d), - Row(4, true, 11.0d / 3.0d, 5.0d, 4.0d), - Row(5, true, 17.0d / 4.0d, 11.0d / 2.0d, 4.5d), - Row(6, true, 17.0d / 4.0d, 6.0d, 11.0d / 2.0d))) - } - - test("reverse sliding range frame") { - val df = Seq( - (1, "Thin", "Cell Phone", 6000), - (2, "Normal", "Tablet", 1500), - (3, "Mini", "Tablet", 5500), - (4, "Ultra thin", "Cell Phone", 5500), - (5, "Very thin", "Cell Phone", 6000), - (6, "Big", "Tablet", 2500), - (7, "Bendable", "Cell Phone", 3000), - (8, "Foldable", "Cell Phone", 3000), - (9, "Pro", "Tablet", 4500), - (10, "Pro2", "Tablet", 6500)). - toDF("id", "product", "category", "revenue") - val window = Window. - partitionBy($"category"). - orderBy($"revenue".desc). - rangeBetween(-2000L, 1000L) - checkAnswer( - df.select( - $"id", - avg($"revenue").over(window).cast("int")), - Row(1, 5833) :: Row(2, 2000) :: Row(3, 5500) :: - Row(4, 5833) :: Row(5, 5833) :: Row(6, 2833) :: - Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: - Row(10, 6000) :: Nil) - } - - // This is here to illustrate the fact that reverse order also reverses offsets. - test("reverse unbounded range frame") { - val df = Seq(1, 2, 4, 3, 2, 1). - map(Tuple1.apply). - toDF("value") - val window = Window.orderBy($"value".desc) - checkAnswer( - df.select( - $"value", - sum($"value").over(window.rangeBetween(Long.MinValue, 1)), - sum($"value").over(window.rangeBetween(1, Long.MaxValue))), - Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: - Row(3, 11, 6) :: Row(2, 13, 2) :: Row(1, 13, null) :: Nil) - } - test("statistical functions") { val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)). toDF("key", "value") From 16670578519a7b787b0c63888b7d2873af12d5b9 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 16 Jan 2018 18:11:27 -0800 Subject: [PATCH 0111/2461] [SPARK-22908][SS] Roll forward continuous processing Kafka support with fix to continuous Kafka data reader ## What changes were proposed in this pull request? The Kafka reader is now interruptible and can close itself. ## How was this patch tested? I locally ran one of the ContinuousKafkaSourceSuite tests in a tight loop. Before the fix, my machine ran out of open file descriptors a few iterations in; now it works fine. Author: Jose Torres Closes #20253 from jose-torres/fix-data-reader. --- .../sql/kafka010/KafkaContinuousReader.scala | 260 +++++++++ .../sql/kafka010/KafkaContinuousWriter.scala | 119 ++++ .../sql/kafka010/KafkaOffsetReader.scala | 21 +- .../spark/sql/kafka010/KafkaSource.scala | 17 +- .../sql/kafka010/KafkaSourceOffset.scala | 7 +- .../sql/kafka010/KafkaSourceProvider.scala | 105 +++- .../spark/sql/kafka010/KafkaWriteTask.scala | 71 ++- .../spark/sql/kafka010/KafkaWriter.scala | 5 +- .../kafka010/KafkaContinuousSinkSuite.scala | 476 ++++++++++++++++ .../kafka010/KafkaContinuousSourceSuite.scala | 96 ++++ .../sql/kafka010/KafkaContinuousTest.scala | 94 +++ .../spark/sql/kafka010/KafkaSourceSuite.scala | 539 +++++++++--------- .../apache/spark/sql/DataFrameReader.scala | 32 +- .../apache/spark/sql/DataFrameWriter.scala | 25 +- .../datasources/v2/WriteToDataSourceV2.scala | 8 +- .../execution/streaming/StreamExecution.scala | 15 +- .../ContinuousDataSourceRDDIter.scala | 4 +- .../continuous/ContinuousExecution.scala | 67 ++- .../continuous/EpochCoordinator.scala | 21 +- .../sql/streaming/DataStreamWriter.scala | 26 +- .../spark/sql/streaming/StreamTest.scala | 36 +- 21 files changed, 1628 insertions(+), 416 deletions(-) create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala new file mode 100644 index 0000000000000..fc977977504f7 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} +import java.util.concurrent.TimeoutException + +import org.apache.kafka.clients.consumer.{ConsumerRecord, OffsetOutOfRangeException} +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +/** + * A [[ContinuousReader]] for data from kafka. + * + * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be + * read by per-task consumers generated later. + * @param kafkaParams String params for per-task Kafka consumers. + * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceV2Options]] params which + * are not Kafka consumer params. + * @param metadataPath Path to a directory this reader can use for writing metadata. + * @param initialOffsets The Kafka offsets to start reading data at. + * @param failOnDataLoss Flag indicating whether reading should fail in data loss + * scenarios, where some offsets after the specified initial ones can't be + * properly read. + */ +class KafkaContinuousReader( + offsetReader: KafkaOffsetReader, + kafkaParams: ju.Map[String, Object], + sourceOptions: Map[String, String], + metadataPath: String, + initialOffsets: KafkaOffsetRangeLimit, + failOnDataLoss: Boolean) + extends ContinuousReader with SupportsScanUnsafeRow with Logging { + + private lazy val session = SparkSession.getActiveSession.get + private lazy val sc = session.sparkContext + + private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong + + // Initialized when creating read tasks. If this diverges from the partitions at the latest + // offsets, we need to reconfigure. + // Exposed outside this object only for unit tests. + private[sql] var knownPartitions: Set[TopicPartition] = _ + + override def readSchema: StructType = KafkaOffsetReader.kafkaSchema + + private var offset: Offset = _ + override def setOffset(start: ju.Optional[Offset]): Unit = { + offset = start.orElse { + val offsets = initialOffsets match { + case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) + case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) + case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) + } + logInfo(s"Initial offsets: $offsets") + offsets + } + } + + override def getStartOffset(): Offset = offset + + override def deserializeOffset(json: String): Offset = { + KafkaSourceOffset(JsonUtils.partitionOffsets(json)) + } + + override def createUnsafeRowReadTasks(): ju.List[ReadTask[UnsafeRow]] = { + import scala.collection.JavaConverters._ + + val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) + + val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet + val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) + val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) + + val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"Some partitions were deleted: $deletedPartitions") + } + + val startOffsets = newPartitionOffsets ++ + oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) + knownPartitions = startOffsets.keySet + + startOffsets.toSeq.map { + case (topicPartition, start) => + KafkaContinuousReadTask( + topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) + .asInstanceOf[ReadTask[UnsafeRow]] + }.asJava + } + + /** Stop this source and free any resources it has allocated. */ + def stop(): Unit = synchronized { + offsetReader.close() + } + + override def commit(end: Offset): Unit = {} + + override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { + val mergedMap = offsets.map { + case KafkaSourcePartitionOffset(p, o) => Map(p -> o) + }.reduce(_ ++ _) + KafkaSourceOffset(mergedMap) + } + + override def needsReconfiguration(): Boolean = { + knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions + } + + override def toString(): String = s"KafkaSource[$offsetReader]" + + /** + * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. + * Otherwise, just log a warning. + */ + private def reportDataLoss(message: String): Unit = { + if (failOnDataLoss) { + throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE") + } else { + logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") + } + } +} + +/** + * A read task for continuous Kafka processing. This will be serialized and transformed into a + * full reader on executors. + * + * @param topicPartition The (topic, partition) pair this task is responsible for. + * @param startOffset The offset to start reading from within the partition. + * @param kafkaParams Kafka consumer params to use. + * @param pollTimeoutMs The timeout for Kafka consumer polling. + * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets + * are skipped. + */ +case class KafkaContinuousReadTask( + topicPartition: TopicPartition, + startOffset: Long, + kafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean) extends ReadTask[UnsafeRow] { + override def createDataReader(): KafkaContinuousDataReader = { + new KafkaContinuousDataReader( + topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) + } +} + +/** + * A per-task data reader for continuous Kafka processing. + * + * @param topicPartition The (topic, partition) pair this data reader is responsible for. + * @param startOffset The offset to start reading from within the partition. + * @param kafkaParams Kafka consumer params to use. + * @param pollTimeoutMs The timeout for Kafka consumer polling. + * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets + * are skipped. + */ +class KafkaContinuousDataReader( + topicPartition: TopicPartition, + startOffset: Long, + kafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { + private val topic = topicPartition.topic + private val kafkaPartition = topicPartition.partition + private val consumer = CachedKafkaConsumer.createUncached(topic, kafkaPartition, kafkaParams) + + private val sharedRow = new UnsafeRow(7) + private val bufferHolder = new BufferHolder(sharedRow) + private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) + + private var nextKafkaOffset = startOffset + private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _ + + override def next(): Boolean = { + var r: ConsumerRecord[Array[Byte], Array[Byte]] = null + while (r == null) { + if (TaskContext.get().isInterrupted() || TaskContext.get().isCompleted()) return false + // Our consumer.get is not interruptible, so we have to set a low poll timeout, leaving + // interrupt points to end the query rather than waiting for new data that might never come. + try { + r = consumer.get( + nextKafkaOffset, + untilOffset = Long.MaxValue, + pollTimeoutMs, + failOnDataLoss) + } catch { + // We didn't read within the timeout. We're supposed to block indefinitely for new data, so + // swallow and ignore this. + case _: TimeoutException => + + // This is a failOnDataLoss exception. Retry if nextKafkaOffset is within the data range, + // or if it's the endpoint of the data range (i.e. the "true" next offset). + case e: IllegalStateException if e.getCause.isInstanceOf[OffsetOutOfRangeException] => + val range = consumer.getAvailableOffsetRange() + if (range.latest >= nextKafkaOffset && range.earliest <= nextKafkaOffset) { + // retry + } else { + throw e + } + } + } + nextKafkaOffset = r.offset + 1 + currentRecord = r + true + } + + override def get(): UnsafeRow = { + bufferHolder.reset() + + if (currentRecord.key == null) { + rowWriter.setNullAt(0) + } else { + rowWriter.write(0, currentRecord.key) + } + rowWriter.write(1, currentRecord.value) + rowWriter.write(2, UTF8String.fromString(currentRecord.topic)) + rowWriter.write(3, currentRecord.partition) + rowWriter.write(4, currentRecord.offset) + rowWriter.write(5, + DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(currentRecord.timestamp))) + rowWriter.write(6, currentRecord.timestampType.id) + sharedRow.setTotalSize(bufferHolder.totalSize) + sharedRow + } + + override def getOffset(): KafkaSourcePartitionOffset = { + KafkaSourcePartitionOffset(topicPartition, nextKafkaOffset) + } + + override def close(): Unit = { + consumer.close() + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala new file mode 100644 index 0000000000000..9843f469c5b25 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.clients.producer.{Callback, ProducerRecord, RecordMetadata} +import scala.collection.JavaConverters._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{kafkaParamsForProducer, TOPIC_OPTION_KEY} +import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.{BinaryType, StringType, StructType} + +/** + * Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we + * don't need to really send one. + */ +case object KafkaWriterCommitMessage extends WriterCommitMessage + +/** + * A [[ContinuousWriter]] for Kafka writing. Responsible for generating the writer factory. + * @param topic The topic this writer is responsible for. If None, topic will be inferred from + * a `topic` field in the incoming data. + * @param producerParams Parameters for Kafka producers in each task. + * @param schema The schema of the input data. + */ +class KafkaContinuousWriter( + topic: Option[String], producerParams: Map[String, String], schema: StructType) + extends ContinuousWriter with SupportsWriteInternalRow { + + validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) + + override def createInternalRowWriterFactory(): KafkaContinuousWriterFactory = + KafkaContinuousWriterFactory(topic, producerParams, schema) + + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(messages: Array[WriterCommitMessage]): Unit = {} +} + +/** + * A [[DataWriterFactory]] for Kafka writing. Will be serialized and sent to executors to generate + * the per-task data writers. + * @param topic The topic that should be written to. If None, topic will be inferred from + * a `topic` field in the incoming data. + * @param producerParams Parameters for Kafka producers in each task. + * @param schema The schema of the input data. + */ +case class KafkaContinuousWriterFactory( + topic: Option[String], producerParams: Map[String, String], schema: StructType) + extends DataWriterFactory[InternalRow] { + + override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + new KafkaContinuousDataWriter(topic, producerParams, schema.toAttributes) + } +} + +/** + * A [[DataWriter]] for Kafka writing. One data writer will be created in each partition to + * process incoming rows. + * + * @param targetTopic The topic that this data writer is targeting. If None, topic will be inferred + * from a `topic` field in the incoming data. + * @param producerParams Parameters to use for the Kafka producer. + * @param inputSchema The attributes in the input data. + */ +class KafkaContinuousDataWriter( + targetTopic: Option[String], producerParams: Map[String, String], inputSchema: Seq[Attribute]) + extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] { + import scala.collection.JavaConverters._ + + private lazy val producer = CachedKafkaProducer.getOrCreate( + new java.util.HashMap[String, Object](producerParams.asJava)) + + def write(row: InternalRow): Unit = { + checkForErrors() + sendRow(row, producer) + } + + def commit(): WriterCommitMessage = { + // Send is asynchronous, but we can't commit until all rows are actually in Kafka. + // This requires flushing and then checking that no callbacks produced errors. + // We also check for errors before to fail as soon as possible - the check is cheap. + checkForErrors() + producer.flush() + checkForErrors() + KafkaWriterCommitMessage + } + + def abort(): Unit = {} + + def close(): Unit = { + checkForErrors() + if (producer != null) { + producer.flush() + checkForErrors() + CachedKafkaProducer.close(new java.util.HashMap[String, Object](producerParams.asJava)) + } + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 3e65949a6fd1b..551641cfdbca8 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -117,10 +117,14 @@ private[kafka010] class KafkaOffsetReader( * Resolves the specific offsets based on Kafka seek positions. * This method resolves offset value -1 to the latest and -2 to the * earliest Kafka seek position. + * + * @param partitionOffsets the specific offsets to resolve + * @param reportDataLoss callback to either report or log data loss depending on setting */ def fetchSpecificOffsets( - partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = - runUninterruptibly { + partitionOffsets: Map[TopicPartition, Long], + reportDataLoss: String => Unit): KafkaSourceOffset = { + val fetched = runUninterruptibly { withRetriesWithoutInterrupt { // Poll to get the latest assigned partitions consumer.poll(0) @@ -145,6 +149,19 @@ private[kafka010] class KafkaOffsetReader( } } + partitionOffsets.foreach { + case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && + off != KafkaOffsetRangeLimit.EARLIEST => + if (fetched(tp) != off) { + reportDataLoss( + s"startingOffsets for $tp was $off but consumer reset to ${fetched(tp)}") + } + case _ => + // no real way to check that beginning or end is reasonable + } + KafkaSourceOffset(fetched) + } + /** * Fetch the earliest offsets for the topic partitions that are indicated * in the [[ConsumerStrategy]]. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 864a92b8f813f..169a5d006fb04 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -130,7 +130,7 @@ private[kafka010] class KafkaSource( val offsets = startingOffsets match { case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchEarliestOffsets()) case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets()) - case SpecificOffsetRangeLimit(p) => fetchAndVerify(p) + case SpecificOffsetRangeLimit(p) => kafkaReader.fetchSpecificOffsets(p, reportDataLoss) } metadataLog.add(0, offsets) logInfo(s"Initial offsets: $offsets") @@ -138,21 +138,6 @@ private[kafka010] class KafkaSource( }.partitionToOffsets } - private def fetchAndVerify(specificOffsets: Map[TopicPartition, Long]) = { - val result = kafkaReader.fetchSpecificOffsets(specificOffsets) - specificOffsets.foreach { - case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && - off != KafkaOffsetRangeLimit.EARLIEST => - if (result(tp) != off) { - reportDataLoss( - s"startingOffsets for $tp was $off but consumer reset to ${result(tp)}") - } - case _ => - // no real way to check that beginning or end is reasonable - } - KafkaSourceOffset(result) - } - private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None override def schema: StructType = KafkaOffsetReader.kafkaSchema diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index b5da415b3097e..c82154cfbad7f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -20,17 +20,22 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} +import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2, PartitionOffset} /** * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and * their offsets. */ private[kafka010] -case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset { +case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends OffsetV2 { override val json = JsonUtils.partitionOffsets(partitionToOffsets) } +private[kafka010] +case class KafkaSourcePartitionOffset(topicPartition: TopicPartition, partitionOffset: Long) + extends PartitionOffset + /** Companion object of the [[KafkaSourceOffset]] */ private[kafka010] object KafkaSourceOffset { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 3cb4d8cad12cc..3914370a96595 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import java.util.{Locale, UUID} +import java.util.{Locale, Optional, UUID} import scala.collection.JavaConverters._ @@ -27,9 +27,12 @@ import org.apache.kafka.clients.producer.ProducerConfig import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} -import org.apache.spark.sql.execution.streaming.{Sink, Source} +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.execution.streaming.{Offset, Sink, Source} import org.apache.spark.sql.sources._ +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -43,6 +46,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider + with ContinuousWriteSupport + with ContinuousReadSupport with Logging { import KafkaSourceProvider._ @@ -101,6 +106,43 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParams)) } + override def createContinuousReader( + schema: Optional[StructType], + metadataPath: String, + options: DataSourceV2Options): KafkaContinuousReader = { + val parameters = options.asMap().asScala.toMap + validateStreamOptions(parameters) + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. + val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + val specifiedKafkaParams = + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + + val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, + STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + + val kafkaOffsetReader = new KafkaOffsetReader( + strategy(caseInsensitiveParams), + kafkaParamsForDriver(specifiedKafkaParams), + parameters, + driverGroupIdPrefix = s"$uniqueGroupId-driver") + + new KafkaContinuousReader( + kafkaOffsetReader, + kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), + parameters, + metadataPath, + startingStreamOffsets, + failOnDataLoss(caseInsensitiveParams)) + } + /** * Returns a new base relation with the given parameters. * @@ -181,26 +223,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = { - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " - + "are serialized with ByteArraySerializer.") - } + override def createContinuousWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[ContinuousWriter] = { + import scala.collection.JavaConverters._ - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) - { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " - + "value are serialized with ByteArraySerializer.") - } - parameters - .keySet - .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } - .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, - ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + val spark = SparkSession.getActiveSession.get + val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim) + // We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable. + val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) + + KafkaWriter.validateQuery( + schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic) + + Optional.of(new KafkaContinuousWriter(topic, producerParams, schema)) } private def strategy(caseInsensitiveParams: Map[String, String]) = @@ -450,4 +488,27 @@ private[kafka010] object KafkaSourceProvider extends Logging { def build(): ju.Map[String, Object] = map } + + private[kafka010] def kafkaParamsForProducer( + parameters: Map[String, String]): Map[String, String] = { + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " + + "are serialized with ByteArraySerializer.") + } + + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) + { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " + + "value are serialized with ByteArraySerializer.") + } + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, + ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index 6fd333e2f43ba..baa60febf661d 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -33,10 +33,8 @@ import org.apache.spark.sql.types.{BinaryType, StringType} private[kafka010] class KafkaWriteTask( producerConfiguration: ju.Map[String, Object], inputSchema: Seq[Attribute], - topic: Option[String]) { + topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) { // used to synchronize with Kafka callbacks - @volatile private var failedWrite: Exception = null - private val projection = createProjection private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _ /** @@ -46,23 +44,7 @@ private[kafka010] class KafkaWriteTask( producer = CachedKafkaProducer.getOrCreate(producerConfiguration) while (iterator.hasNext && failedWrite == null) { val currentRow = iterator.next() - val projectedRow = projection(currentRow) - val topic = projectedRow.getUTF8String(0) - val key = projectedRow.getBinary(1) - val value = projectedRow.getBinary(2) - if (topic == null) { - throw new NullPointerException(s"null topic present in the data. Use the " + - s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") - } - val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) - val callback = new Callback() { - override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { - if (failedWrite == null && e != null) { - failedWrite = e - } - } - } - producer.send(record, callback) + sendRow(currentRow, producer) } } @@ -74,8 +56,49 @@ private[kafka010] class KafkaWriteTask( producer = null } } +} + +private[kafka010] abstract class KafkaRowWriter( + inputSchema: Seq[Attribute], topic: Option[String]) { + + // used to synchronize with Kafka callbacks + @volatile protected var failedWrite: Exception = _ + protected val projection = createProjection + + private val callback = new Callback() { + override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { + if (failedWrite == null && e != null) { + failedWrite = e + } + } + } - private def createProjection: UnsafeProjection = { + /** + * Send the specified row to the producer, with a callback that will save any exception + * to failedWrite. Note that send is asynchronous; subclasses must flush() their producer before + * assuming the row is in Kafka. + */ + protected def sendRow( + row: InternalRow, producer: KafkaProducer[Array[Byte], Array[Byte]]): Unit = { + val projectedRow = projection(row) + val topic = projectedRow.getUTF8String(0) + val key = projectedRow.getBinary(1) + val value = projectedRow.getBinary(2) + if (topic == null) { + throw new NullPointerException(s"null topic present in the data. Use the " + + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") + } + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) + producer.send(record, callback) + } + + protected def checkForErrors(): Unit = { + if (failedWrite != null) { + throw failedWrite + } + } + + private def createProjection = { val topicExpression = topic.map(Literal(_)).orElse { inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME) }.getOrElse { @@ -112,11 +135,5 @@ private[kafka010] class KafkaWriteTask( Seq(topicExpression, Cast(keyExpression, BinaryType), Cast(valueExpression, BinaryType)), inputSchema) } - - private def checkForErrors(): Unit = { - if (failedWrite != null) { - throw failedWrite - } - } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 5e9ae35b3f008..15cd44812cb0c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -43,10 +43,9 @@ private[kafka010] object KafkaWriter extends Logging { override def toString: String = "KafkaWriter" def validateQuery( - queryExecution: QueryExecution, + schema: Seq[Attribute], kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { - val schema = queryExecution.analyzed.output schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse( if (topic.isEmpty) { throw new AnalysisException(s"topic option required when no " + @@ -84,7 +83,7 @@ private[kafka010] object KafkaWriter extends Logging { kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { val schema = queryExecution.analyzed.output - validateQuery(queryExecution, kafkaParameters, topic) + validateQuery(schema, kafkaParameters, topic) queryExecution.toRdd.foreachPartition { iter => val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic) Utils.tryWithSafeFinally(block = writeTask.execute(iter))( diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala new file mode 100644 index 0000000000000..8487a69851237 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -0,0 +1,476 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.Locale +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.kafka.clients.producer.ProducerConfig +import org.apache.kafka.common.serialization.ByteArraySerializer +import org.scalatest.time.SpanSugar._ +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.types.{BinaryType, DataType} +import org.apache.spark.util.Utils + +/** + * This is a temporary port of KafkaSinkSuite, since we do not yet have a V2 memory stream. + * Once we have one, this will be changed to a specialization of KafkaSinkSuite and we won't have + * to duplicate all the code. + */ +class KafkaContinuousSinkSuite extends KafkaContinuousTest { + import testImplicits._ + + override val streamingTimeout = 30.seconds + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils( + withBrokerProps = Map("auto.create.topics.enable" -> "false")) + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + super.afterAll() + } + + test("streaming - write to kafka with topic field") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = None, + withOutputMode = Some(OutputMode.Append))( + withSelectExpr = s"'$topic' as topic", "value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + } + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } + } finally { + writer.stop() + } + } + + test("streaming - write w/o topic field, with topic option") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Append()))() + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + } + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } + } finally { + writer.stop() + } + } + + test("streaming - topic field and topic option") { + /* The purpose of this test is to ensure that the topic option + * overrides the topic field. We begin by writing some data that + * includes a topic field and value (e.g., 'foo') along with a topic + * option. Then when we read from the topic specified in the option + * we should see the data i.e., the data was written to the topic + * option, and not to the topic in the data e.g., foo + */ + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Append()))( + withSelectExpr = "'foo' as topic", "CAST(value as STRING) value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .selectExpr("CAST(key AS INT)", "CAST(value AS INT)") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + } + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } + } finally { + writer.stop() + } + } + + test("null topic attribute") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + testUtils.createTopic(topic) + + /* No topic field or topic option */ + var writer: StreamingQuery = null + var ex: Exception = null + try { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = "CAST(null as STRING) as topic", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getCause.getCause.getMessage + .toLowerCase(Locale.ROOT) + .contains("null topic present in the data.")) + } + + test("streaming - write data with bad schema") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + testUtils.createTopic(topic) + + /* No topic field or topic option */ + var writer: StreamingQuery = null + var ex: Exception = null + try { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = "value as key", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage + .toLowerCase(Locale.ROOT) + .contains("topic option required when no 'topic' attribute is present")) + + try { + /* No value field */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "value as key" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "required attribute 'value' not found")) + } + + test("streaming - write data with valid schema but wrong types") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + .selectExpr("CAST(value as STRING) value") + val topic = newTopic() + testUtils.createTopic(topic) + + var writer: StreamingQuery = null + var ex: Exception = null + try { + /* topic field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"CAST('1' as INT) as topic", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string")) + + try { + /* value field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "value attribute type must be a string or binarytype")) + + try { + /* key field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "key attribute type must be a string or binarytype")) + } + + test("streaming - write to non-existing topic") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))() + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + } + throw writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) + } + + test("streaming - exception on config serializer") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + testUtils.sendMessages(inputTopic, Array("0")) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .load() + var writer: StreamingQuery = null + var ex: Exception = null + try { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.key.serializer" -> "foo"))() + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'key.serializer' is not supported")) + } finally { + writer.stop() + } + + try { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.value.serializer" -> "foo"))() + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'value.serializer' is not supported")) + } finally { + writer.stop() + } + } + + test("generic - write big data with small producer buffer") { + /* This test ensures that we understand the semantics of Kafka when + * is comes to blocking on a call to send when the send buffer is full. + * This test will configure the smallest possible producer buffer and + * indicate that we should block when it is full. Thus, no exception should + * be thrown in the case of a full buffer. + */ + val topic = newTopic() + testUtils.createTopic(topic, 1) + val options = new java.util.HashMap[String, String] + options.put("bootstrap.servers", testUtils.brokerAddress) + options.put("buffer.memory", "16384") // min buffer size + options.put("block.on.buffer.full", "true") + options.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + val inputSchema = Seq(AttributeReference("value", BinaryType)()) + val data = new Array[Byte](15000) // large value + val writeTask = new KafkaContinuousDataWriter(Some(topic), options.asScala.toMap, inputSchema) + try { + val fieldTypes: Array[DataType] = Array(BinaryType) + val converter = UnsafeProjection.create(fieldTypes) + val row = new SpecificInternalRow(fieldTypes) + row.update(0, data) + val iter = Seq.fill(1000)(converter.apply(row)).iterator + iter.foreach(writeTask.write(_)) + writeTask.commit() + } finally { + writeTask.close() + } + } + + private def createKafkaReader(topic: String): DataFrame = { + spark.read + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("startingOffsets", "earliest") + .option("endingOffsets", "latest") + .option("subscribe", topic) + .load() + } + + private def createKafkaWriter( + input: DataFrame, + withTopic: Option[String] = None, + withOutputMode: Option[OutputMode] = None, + withOptions: Map[String, String] = Map[String, String]()) + (withSelectExpr: String*): StreamingQuery = { + var stream: DataStreamWriter[Row] = null + val checkpointDir = Utils.createTempDir() + var df = input.toDF() + if (withSelectExpr.length > 0) { + df = df.selectExpr(withSelectExpr: _*) + } + stream = df.writeStream + .format("kafka") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + // We need to reduce blocking time to efficiently test non-existent partition behavior. + .option("kafka.max.block.ms", "1000") + .trigger(Trigger.Continuous(1000)) + .queryName("kafkaStream") + withTopic.foreach(stream.option("topic", _)) + withOutputMode.foreach(stream.outputMode(_)) + withOptions.foreach(opt => stream.option(opt._1, opt._2)) + stream.start() + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala new file mode 100644 index 0000000000000..b3dade414f625 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.Properties +import java.util.concurrent.atomic.AtomicInteger + +import org.scalatest.time.SpanSugar._ +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.SparkContext +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.streaming.{StreamTest, Trigger} +import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} + +// Run tests in KafkaSourceSuiteBase in continuous execution mode. +class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest + +class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { + import testImplicits._ + + override val brokerProps = Map("auto.create.topics.enable" -> "false") + + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Execute { query => + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + eventually(timeout(streamingTimeout)) { + assert( + query.lastExecution.logical.collectFirst { + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + }.exists { r => + // Ensure the new topic is present and the old topic is gone. + r.knownPartitions.exists(_.topic == topic2) + }, + s"query never reconfigured to new topic $topic2") + } + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } +} + +class KafkaContinuousSourceStressForDontFailOnDataLossSuite + extends KafkaSourceStressForDontFailOnDataLossSuite { + override protected def startStream(ds: Dataset[Int]) = { + ds.writeStream + .format("memory") + .queryName("memory") + .start() + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala new file mode 100644 index 0000000000000..5a1a14f7a307a --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark.SparkContext +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.streaming.Trigger +import org.apache.spark.sql.test.TestSparkSession + +// Trait to configure StreamTest for kafka continuous execution tests. +trait KafkaContinuousTest extends KafkaSourceTest { + override val defaultTrigger = Trigger.Continuous(1000) + override val defaultUseV2Sink = true + + // We need more than the default local[2] to be able to schedule all partitions simultaneously. + override protected def createSparkSession = new TestSparkSession( + new SparkContext( + "local[10]", + "continuous-stream-test-sql-context", + sparkConf.set("spark.sql.testkey", "true"))) + + // In addition to setting the partitions in Kafka, we have to wait until the query has + // reconfigured to the new count so the test framework can hook in properly. + override protected def setTopicPartitions( + topic: String, newCount: Int, query: StreamExecution) = { + testUtils.addPartitions(topic, newCount) + eventually(timeout(streamingTimeout)) { + assert( + query.lastExecution.logical.collectFirst { + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + }.exists(_.knownPartitions.size == newCount), + s"query never reconfigured to $newCount partitions") + } + } + + // Continuous processing tasks end asynchronously, so test that they actually end. + private val tasksEndedListener = new SparkListener() { + val activeTaskIdCount = new AtomicInteger(0) + + override def onTaskStart(start: SparkListenerTaskStart): Unit = { + activeTaskIdCount.incrementAndGet() + } + + override def onTaskEnd(end: SparkListenerTaskEnd): Unit = { + activeTaskIdCount.decrementAndGet() + } + } + + override def beforeEach(): Unit = { + super.beforeEach() + spark.sparkContext.addSparkListener(tasksEndedListener) + } + + override def afterEach(): Unit = { + eventually(timeout(streamingTimeout)) { + assert(tasksEndedListener.activeTaskIdCount.get() == 0) + } + spark.sparkContext.removeSparkListener(tasksEndedListener) + super.afterEach() + } + + + test("ensure continuous stream is being used") { + val query = spark.readStream + .format("rate") + .option("numPartitions", "1") + .option("rowsPerSecond", "1") + .load() + + testStream(query)( + Execute(q => assert(q.isInstanceOf[ContinuousExecution])) + ) + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index a0f5695fc485c..1acff61e11d2a 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -34,11 +34,14 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext -import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryWriter import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ -import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} +import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest, Trigger} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} import org.apache.spark.util.Utils @@ -49,9 +52,11 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override val streamingTimeout = 30.seconds + protected val brokerProps = Map[String, Object]() + override def beforeAll(): Unit = { super.beforeAll() - testUtils = new KafkaTestUtils + testUtils = new KafkaTestUtils(brokerProps) testUtils.setup() } @@ -59,18 +64,25 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { if (testUtils != null) { testUtils.teardown() testUtils = null - super.afterAll() } + super.afterAll() } protected def makeSureGetOffsetCalled = AssertOnQuery { q => // Because KafkaSource's initialPartitionOffsets is set lazily, we need to make sure - // its "getOffset" is called before pushing any data. Otherwise, because of the race contion, + // its "getOffset" is called before pushing any data. Otherwise, because of the race condition, // we don't know which data should be fetched when `startingOffsets` is latest. - q.processAllAvailable() + q match { + case c: ContinuousExecution => c.awaitEpoch(0) + case m: MicroBatchExecution => m.processAllAvailable() + } true } + protected def setTopicPartitions(topic: String, newCount: Int, query: StreamExecution) : Unit = { + testUtils.addPartitions(topic, newCount) + } + /** * Add data to Kafka. * @@ -82,10 +94,11 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { message: String = "", topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData { - override def addData(query: Option[StreamExecution]): (Source, Offset) = { - if (query.get.isActive) { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + query match { // Make sure no Spark job is running when deleting a topic - query.get.processAllAvailable() + case Some(m: MicroBatchExecution) => m.processAllAvailable() + case _ => } val existingTopics = testUtils.getAllTopicsAndPartitionSize().toMap @@ -97,16 +110,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2)) } - // Read all topics again in case some topics are delete. - val allTopics = testUtils.getAllTopicsAndPartitionSize().toMap.keys require( query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[KafkaSource] => - source.asInstanceOf[KafkaSource] - } + case StreamingExecutionRelation(source: KafkaSource, _) => source + } ++ (query.get.lastExecution match { + case null => Seq() + case e => e.logical.collect { + case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + } + }) if (sources.isEmpty) { throw new Exception( "Could not find Kafka source in the StreamExecution logical plan to add data to") @@ -137,14 +152,158 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override def toString: String = s"AddKafkaData(topics = $topics, data = $data, message = $message)" } -} + private val topicId = new AtomicInteger(0) + protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}" +} -class KafkaSourceSuite extends KafkaSourceTest { +class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { import testImplicits._ - private val topicId = new AtomicInteger(0) + test("(de)serialization of initial offsets") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + + testStream(reader.load)( + makeSureGetOffsetCalled, + StopStream, + StartStream(), + StopStream) + } + + test("maxOffsetsPerTrigger") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) + testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) + testUtils.sendMessages(topic, Array("1"), Some(2)) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("maxOffsetsPerTrigger", 10) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + val clock = new StreamManualClock + + val waitUntilBatchProcessed = AssertOnQuery { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + } + + testStream(mapped)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // 1 from smallest, 1 from middle, 8 from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 + ), + StopStream, + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 + ), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, + 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 + ) + ) + } + + test("input row metrics") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + val mapped = kafka.map(kv => kv._2.toInt + 1) + testStream(mapped)( + StartStream(trigger = ProcessingTime(1)), + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + AssertOnQuery { query => + val recordsRead = query.recentProgress.map(_.numInputRows).sum + recordsRead == 3 + } + ) + } + + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Assert { + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + true + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } testWithUninterruptibleThread( "deserialization of initial offset with Spark 2.1.0") { @@ -237,86 +396,94 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - test("(de)serialization of initial offsets") { + test("KafkaSource with watermark") { + val now = System.currentTimeMillis() val topic = newTopic() - testUtils.createTopic(topic, partitions = 64) + testUtils.createTopic(newTopic(), partitions = 1) + testUtils.sendMessages(topic, Array(1).map(_.toString)) - val reader = spark + val kafka = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("startingOffsets", s"earliest") .option("subscribe", topic) + .load() - testStream(reader.load)( - makeSureGetOffsetCalled, - StopStream, - StartStream(), - StopStream) + val windowedAggregation = kafka + .withWatermark("timestamp", "10 seconds") + .groupBy(window($"timestamp", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start") as 'window, $"count") + + val query = windowedAggregation + .writeStream + .format("memory") + .outputMode("complete") + .queryName("kafkaWatermark") + .start() + query.processAllAvailable() + val rows = spark.table("kafkaWatermark").collect() + assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + val row = rows(0) + // We cannot check the exact window start time as it depands on the time that messages were + // inserted by the producer. So here we just use a low bound to make sure the internal + // conversion works. + assert( + row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, + s"Unexpected results: $row") + assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") + query.stop() } - test("maxOffsetsPerTrigger") { + test("delete a topic when a Spark job is running") { + KafkaSourceSuite.collectedData.clear() + val topic = newTopic() - testUtils.createTopic(topic, partitions = 3) - testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) - testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) - testUtils.sendMessages(topic, Array("1"), Some(2)) + testUtils.createTopic(topic, partitions = 1) + testUtils.sendMessages(topic, (1 to 10).map(_.toString).toArray) val reader = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") - .option("maxOffsetsPerTrigger", 10) .option("subscribe", topic) + // If a topic is deleted and we try to poll data starting from offset 0, + // the Kafka consumer will just block until timeout and return an empty result. + // So set the timeout to 1 second to make this test fast. + .option("kafkaConsumer.pollTimeoutMs", "1000") .option("startingOffsets", "earliest") + .option("failOnDataLoss", "false") val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] - val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) - - val clock = new StreamManualClock - - val waitUntilBatchProcessed = AssertOnQuery { q => - eventually(Timeout(streamingTimeout)) { - if (!q.exception.isDefined) { - assert(clock.isStreamWaitingAt(clock.getTimeMillis())) - } + KafkaSourceSuite.globalTestUtils = testUtils + // The following ForeachWriter will delete the topic before fetching data from Kafka + // in executors. + val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { + override def open(partitionId: Long, version: Long): Boolean = { + KafkaSourceSuite.globalTestUtils.deleteTopic(topic) + true } - if (q.exception.isDefined) { - throw q.exception.get + + override def process(value: Int): Unit = { + KafkaSourceSuite.collectedData.add(value) } - true - } - testStream(mapped)( - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // 1 from smallest, 1 from middle, 8 from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 - ), - StopStream, - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 - ), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, - 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 - ) - ) + override def close(errorOrNull: Throwable): Unit = {} + }).start() + query.processAllAvailable() + query.stop() + // `failOnDataLoss` is `false`, we should not fail the query + assert(query.exception.isEmpty) } +} + +class KafkaSourceSuiteBase extends KafkaSourceTest { + + import testImplicits._ test("SPARK-22956: currentPartitionOffsets should be set when no new data comes in") { def getSpecificDF(range: Range.Inclusive): org.apache.spark.sql.Dataset[Int] = { @@ -393,7 +560,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"topic-.*") + .option("subscribePattern", s"$topic.*") val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") @@ -487,65 +654,6 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - test("subscribing topic by pattern with topic deletions") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-seems" - val topic2 = topicPrefix + "-bad" - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"$topicPrefix-.*") - .option("failOnDataLoss", "false") - - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped = kafka.map(kv => kv._2.toInt + 1) - - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - Assert { - testUtils.deleteTopic(topic) - testUtils.createTopic(topic2, partitions = 5) - true - }, - AddKafkaData(Set(topic2), 4, 5, 6), - CheckAnswer(2, 3, 4, 5, 6, 7) - ) - } - - test("starting offset is latest by default") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("0")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", topic) - - val kafka = reader.load() - .selectExpr("CAST(value AS STRING)") - .as[String] - val mapped = kafka.map(_.toInt) - - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(1, 2, 3) // should not have 0 - ) - } - test("bad source options") { def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = { val ex = intercept[IllegalArgumentException] { @@ -605,77 +713,6 @@ class KafkaSourceSuite extends KafkaSourceTest { testUnsupportedConfig("kafka.auto.offset.reset", "latest") } - test("input row metrics") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val kafka = spark - .readStream - .format("kafka") - .option("subscribe", topic) - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - - val mapped = kafka.map(kv => kv._2.toInt + 1) - testStream(mapped)( - StartStream(trigger = ProcessingTime(1)), - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - AssertOnQuery { query => - val recordsRead = query.recentProgress.map(_.numInputRows).sum - recordsRead == 3 - } - ) - } - - test("delete a topic when a Spark job is running") { - KafkaSourceSuite.collectedData.clear() - - val topic = newTopic() - testUtils.createTopic(topic, partitions = 1) - testUtils.sendMessages(topic, (1 to 10).map(_.toString).toArray) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribe", topic) - // If a topic is deleted and we try to poll data starting from offset 0, - // the Kafka consumer will just block until timeout and return an empty result. - // So set the timeout to 1 second to make this test fast. - .option("kafkaConsumer.pollTimeoutMs", "1000") - .option("startingOffsets", "earliest") - .option("failOnDataLoss", "false") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - KafkaSourceSuite.globalTestUtils = testUtils - // The following ForeachWriter will delete the topic before fetching data from Kafka - // in executors. - val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { - override def open(partitionId: Long, version: Long): Boolean = { - KafkaSourceSuite.globalTestUtils.deleteTopic(topic) - true - } - - override def process(value: Int): Unit = { - KafkaSourceSuite.collectedData.add(value) - } - - override def close(errorOrNull: Throwable): Unit = {} - }).start() - query.processAllAvailable() - query.stop() - // `failOnDataLoss` is `false`, we should not fail the query - assert(query.exception.isEmpty) - } - test("get offsets from case insensitive parameters") { for ((optionKey, optionValue, answer) <- Seq( (STARTING_OFFSETS_OPTION_KEY, "earLiEst", EarliestOffsetRangeLimit), @@ -694,8 +731,6 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" - private def assignString(topic: String, partitions: Iterable[Int]): String = { JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p))) } @@ -741,6 +776,10 @@ class KafkaSourceSuite extends KafkaSourceTest { testStream(mapped)( makeSureGetOffsetCalled, + Execute { q => + // wait to reach the last offset in every partition + q.awaitOffset(0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L))) + }, CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22), StopStream, StartStream(), @@ -771,10 +810,13 @@ class KafkaSourceSuite extends KafkaSourceTest { .format("memory") .outputMode("append") .queryName("kafkaColumnTypes") + .trigger(defaultTrigger) .start() - query.processAllAvailable() - val rows = spark.table("kafkaColumnTypes").collect() - assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + var rows: Array[Row] = Array() + eventually(timeout(streamingTimeout)) { + rows = spark.table("kafkaColumnTypes").collect() + assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + } val row = rows(0) assert(row.getAs[Array[Byte]]("key") === null, s"Unexpected results: $row") assert(row.getAs[Array[Byte]]("value") === "1".getBytes(UTF_8), s"Unexpected results: $row") @@ -788,47 +830,6 @@ class KafkaSourceSuite extends KafkaSourceTest { query.stop() } - test("KafkaSource with watermark") { - val now = System.currentTimeMillis() - val topic = newTopic() - testUtils.createTopic(newTopic(), partitions = 1) - testUtils.sendMessages(topic, Array(1).map(_.toString)) - - val kafka = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("startingOffsets", s"earliest") - .option("subscribe", topic) - .load() - - val windowedAggregation = kafka - .withWatermark("timestamp", "10 seconds") - .groupBy(window($"timestamp", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"window".getField("start") as 'window, $"count") - - val query = windowedAggregation - .writeStream - .format("memory") - .outputMode("complete") - .queryName("kafkaWatermark") - .start() - query.processAllAvailable() - val rows = spark.table("kafkaWatermark").collect() - assert(rows.length === 1, s"Unexpected results: ${rows.toList}") - val row = rows(0) - // We cannot check the exact window start time as it depands on the time that messages were - // inserted by the producer. So here we just use a low bound to make sure the internal - // conversion works. - assert( - row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, - s"Unexpected results: $row") - assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") - query.stop() - } - private def testFromLatestOffsets( topic: String, addPartitions: Boolean, @@ -865,9 +866,7 @@ class KafkaSourceSuite extends KafkaSourceTest { AddKafkaData(Set(topic), 7, 8), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) { - testUtils.addPartitions(topic, 10) - } + if (addPartitions) setTopicPartitions(topic, 10, query) true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -908,9 +907,7 @@ class KafkaSourceSuite extends KafkaSourceTest { StartStream(), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) { - testUtils.addPartitions(topic, 10) - } + if (addPartitions) setTopicPartitions(topic, 10, query) true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -1042,20 +1039,8 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared } } - test("stress test for failOnDataLoss=false") { - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", "failOnDataLoss.*") - .option("startingOffsets", "earliest") - .option("failOnDataLoss", "false") - .option("fetchOffset.retryIntervalMs", "3000") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { + protected def startStream(ds: Dataset[Int]) = { + ds.writeStream.foreach(new ForeachWriter[Int] { override def open(partitionId: Long, version: Long): Boolean = { true @@ -1069,6 +1054,22 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared override def close(errorOrNull: Throwable): Unit = { } }).start() + } + + test("stress test for failOnDataLoss=false") { + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", "failOnDataLoss.*") + .option("startingOffsets", "earliest") + .option("failOnDataLoss", "false") + .option("fetchOffset.retryIntervalMs", "3000") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val query = startStream(kafka.map(kv => kv._2.toInt)) val testTime = 1.minutes val startTime = System.currentTimeMillis() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index e8d683a578f35..b714a46b5f786 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -191,6 +191,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { ds = ds.asInstanceOf[DataSourceV2], conf = sparkSession.sessionState.conf)).asJava) + // Streaming also uses the data source V2 API. So it may be that the data source implements + // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading + // the dataframe as a v1 source. val reader = (ds, userSpecifiedSchema) match { case (ds: ReadSupportWithSchema, Some(schema)) => ds.createReader(schema, options) @@ -208,23 +211,30 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } reader - case _ => - throw new AnalysisException(s"$cls does not support data reading.") + case _ => null // fall back to v1 } - Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + if (reader == null) { + loadV1Source(paths: _*) + } else { + Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + } } else { - // Code path for data source v1. - sparkSession.baseRelationToDataFrame( - DataSource.apply( - sparkSession, - paths = paths, - userSpecifiedSchema = userSpecifiedSchema, - className = source, - options = extraOptions.toMap).resolveRelation()) + loadV1Source(paths: _*) } } + private def loadV1Source(paths: String*) = { + // Code path for data source v1. + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + userSpecifiedSchema = userSpecifiedSchema, + className = source, + options = extraOptions.toMap).resolveRelation()) + } + /** * Construct a `DataFrame` representing the database table accessible via JDBC URL * url named table and connection properties. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 3304f368e1050..97f12ff625c42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -255,17 +255,24 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } } - case _ => throw new AnalysisException(s"$cls does not support data writing.") + // Streaming also uses the data source V2 API. So it may be that the data source implements + // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving + // as though it's a V1 source. + case _ => saveToV1Source() } } else { - // Code path for data source v1. - runCommand(df.sparkSession, "save") { - DataSource( - sparkSession = df.sparkSession, - className = source, - partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) - } + saveToV1Source() + } + } + + private def saveToV1Source(): Unit = { + // Code path for data source v1. + runCommand(df.sparkSession, "save") { + DataSource( + sparkSession = df.sparkSession, + className = source, + partitionColumns = partitioningColumns.getOrElse(Nil), + options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index f0bdf84bb7a84..a4a857f2d4d9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -81,9 +81,11 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) (index, message: WriterCommitMessage) => messages(index) = message ) - logInfo(s"Data source writer $writer is committing.") - writer.commit(messages) - logInfo(s"Data source writer $writer committed.") + if (!writer.isInstanceOf[ContinuousWriter]) { + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") + } } catch { case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] => // Interruption is how continuous queries are ended, so accept and ignore the exception. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 24a8b000df0c1..cf27e1a70650a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -142,7 +142,8 @@ abstract class StreamExecution( override val id: UUID = UUID.fromString(streamMetadata.id) - override val runId: UUID = UUID.randomUUID + override def runId: UUID = currentRunId + protected var currentRunId = UUID.randomUUID /** * Pretty identified string of printing in logs. Format is @@ -418,11 +419,17 @@ abstract class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. */ - private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = { + private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = { assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets - !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset + if (sources == null) { + // sources might not be initialized yet + false + } else { + val source = sources(sourceIndex) + !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset + } } while (notDone) { @@ -436,7 +443,7 @@ abstract class StreamExecution( awaitProgressLock.unlock() } } - logDebug(s"Unblocked at $newOffset for $source") + logDebug(s"Unblocked at $newOffset for ${sources(sourceIndex)}") } /** A flag to indicate that a batch has completed with no new data available. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index b3f1a1a1aaab3..66eb42d4658f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -77,7 +77,6 @@ class ContinuousDataSourceRDD( dataReaderThread.start() context.addTaskCompletionListener(_ => { - reader.close() dataReaderThread.interrupt() epochPollExecutor.shutdown() }) @@ -177,6 +176,7 @@ class DataReaderThread( private[continuous] var failureReason: Throwable = _ override def run(): Unit = { + TaskContext.setTaskContext(context) val baseReader = ContinuousDataSourceRDD.getBaseReader(reader) try { while (!context.isInterrupted && !context.isCompleted()) { @@ -201,6 +201,8 @@ class DataReaderThread( failedFlag.set(true) // Don't rethrow the exception in this thread. It's not needed, and the default Spark // exception handler will kill the executor. + } finally { + reader.close() } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 9657b5e26d770..667410ef9f1c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution.streaming.continuous +import java.util.UUID import java.util.concurrent.TimeUnit +import java.util.function.UnaryOperator import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} @@ -52,7 +54,7 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReader] = Seq.empty + @volatile protected var continuousSources: Seq[ContinuousReader] = _ override protected def sources: Seq[BaseStreamingSource] = continuousSources override lazy val logicalPlan: LogicalPlan = { @@ -78,15 +80,17 @@ class ContinuousExecution( } override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { - do { - try { - runContinuous(sparkSessionForStream) - } catch { - case _: InterruptedException if state.get().equals(RECONFIGURING) => - // swallow exception and run again - state.set(ACTIVE) + val stateUpdate = new UnaryOperator[State] { + override def apply(s: State) = s match { + // If we ended the query to reconfigure, reset the state to active. + case RECONFIGURING => ACTIVE + case _ => s } - } while (state.get() == ACTIVE) + } + + do { + runContinuous(sparkSessionForStream) + } while (state.updateAndGet(stateUpdate) == ACTIVE) } /** @@ -120,12 +124,16 @@ class ContinuousExecution( } committedOffsets = nextOffsets.toStreamProgress(sources) - // Forcibly align commit and offset logs by slicing off any spurious offset logs from - // a previous run. We can't allow commits to an epoch that a previous run reached but - // this run has not. - offsetLog.purgeAfter(latestEpochId) + // Get to an epoch ID that has definitely never been sent to a sink before. Since sink + // commit happens between offset log write and commit log write, this means an epoch ID + // which is not in the offset log. + val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse { + throw new IllegalStateException( + s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" + + s"an element.") + } + currentBatchId = latestOffsetEpoch + 1 - currentBatchId = latestEpochId + 1 logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets") nextOffsets case None => @@ -141,6 +149,7 @@ class ContinuousExecution( * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with. */ private def runContinuous(sparkSessionForQuery: SparkSession): Unit = { + currentRunId = UUID.randomUUID // A list of attributes that will need to be updated. val replacements = new ArrayBuffer[(Attribute, Attribute)] // Translate from continuous relation to the underlying data source. @@ -225,13 +234,11 @@ class ContinuousExecution( triggerExecutor.execute(() => { startTrigger() - if (reader.needsReconfiguration()) { - state.set(RECONFIGURING) + if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { stopSources() if (queryExecutionThread.isAlive) { sparkSession.sparkContext.cancelJobGroup(runId.toString) queryExecutionThread.interrupt() - // No need to join - this thread is about to end anyway. } false } else if (isActive) { @@ -259,6 +266,7 @@ class ContinuousExecution( sparkSessionForQuery, lastExecution)(lastExecution.toRdd) } } finally { + epochEndpoint.askSync[Unit](StopContinuousExecutionWrites) SparkEnv.get.rpcEnv.stop(epochEndpoint) epochUpdateThread.interrupt() @@ -273,17 +281,22 @@ class ContinuousExecution( epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") - if (partitionOffsets.contains(null)) { - // If any offset is null, that means the corresponding partition hasn't seen any data yet, so - // there's nothing meaningful to add to the offset log. - } val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) - synchronized { - if (queryExecutionThread.isAlive) { - offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) - } else { - return - } + val oldOffset = synchronized { + offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) + offsetLog.get(epoch - 1) + } + + // If offset hasn't changed since last epoch, there's been no new data. + if (oldOffset.contains(OffsetSeq.fill(globalOffset))) { + noNewData = true + } + + awaitProgressLock.lock() + try { + awaitProgressLockCondition.signalAll() + } finally { + awaitProgressLock.unlock() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 98017c3ac6a33..40dcbecade814 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -39,6 +39,15 @@ private[continuous] sealed trait EpochCoordinatorMessage extends Serializable */ private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage +/** + * The RpcEndpoint stop() will wait to clear out the message queue before terminating the + * object. This can lead to a race condition where the query restarts at epoch n, a new + * EpochCoordinator starts at epoch n, and then the old epoch coordinator commits epoch n + 1. + * The framework doesn't provide a handle to wait on the message queue, so we use a synchronous + * message to stop any writes to the ContinuousExecution object. + */ +private[sql] case object StopContinuousExecutionWrites extends EpochCoordinatorMessage + // Init messages /** * Set the reader and writer partition counts. Tasks may not be started until the coordinator @@ -116,6 +125,8 @@ private[continuous] class EpochCoordinator( override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + private var queryWritesStopped: Boolean = false + private var numReaderPartitions: Int = _ private var numWriterPartitions: Int = _ @@ -147,12 +158,16 @@ private[continuous] class EpochCoordinator( partitionCommits.remove(k) } for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) { - partitionCommits.remove(k) + partitionOffsets.remove(k) } } } override def receive: PartialFunction[Any, Unit] = { + // If we just drop these messages, we won't do any writes to the query. The lame duck tasks + // won't shed errors or anything. + case _ if queryWritesStopped => () + case CommitPartitionEpoch(partitionId, epoch, message) => logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message") if (!partitionCommits.isDefinedAt((epoch, partitionId))) { @@ -188,5 +203,9 @@ private[continuous] class EpochCoordinator( case SetWriterPartitions(numPartitions) => numWriterPartitions = numPartitions context.reply(()) + + case StopContinuousExecutionWrites => + queryWritesStopped = true + context.reply(()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index db588ae282f38..b5b4a05ab4973 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} +import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -279,18 +280,29 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val dataSource = - DataSource( - df.sparkSession, - className = source, - options = extraOptions.toMap, - partitionColumns = normalizedParCols.getOrElse(Nil)) + val sink = trigger match { + case _: ContinuousTrigger => + val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) + ds.newInstance() match { + case w: ContinuousWriteSupport => w + case _ => throw new AnalysisException( + s"Data source $source does not support continuous writing") + } + case _ => + val ds = DataSource( + df.sparkSession, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) + ds.createSink(outputMode) + } + df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), df, extraOptions.toMap, - dataSource.createSink(outputMode), + sink, outputMode, useTempCheckpointLocation = source == "console", recoverFromCheckpointLocation = true, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index d46461fa9bf6d..0762895fdc620 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -38,8 +38,9 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ @@ -80,6 +81,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be StateStore.stop() // stop the state store maintenance thread and unload store providers } + protected val defaultTrigger = Trigger.ProcessingTime(0) + protected val defaultUseV2Sink = false + /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds @@ -189,7 +193,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be /** Starts the stream, resuming if data has already been processed. It must not be running. */ case class StartStream( - trigger: Trigger = Trigger.ProcessingTime(0), + trigger: Trigger = defaultTrigger, triggerClock: Clock = new SystemClock, additionalConfs: Map[String, String] = Map.empty, checkpointLocation: String = null) @@ -276,7 +280,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def testStream( _stream: Dataset[_], outputMode: OutputMode = OutputMode.Append, - useV2Sink: Boolean = false)(actions: StreamAction*): Unit = synchronized { + useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized { import org.apache.spark.sql.streaming.util.StreamManualClock // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently @@ -403,18 +407,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = { verify(currentStream != null, "stream not running") - // Get the map of source index to the current source objects - val indexToSource = currentStream - .logicalPlan - .collect { case StreamingExecutionRelation(s, _) => s } - .zipWithIndex - .map(_.swap) - .toMap // Block until all data added has been processed for all the source awaiting.foreach { case (sourceIndex, offset) => failAfter(streamingTimeout) { - currentStream.awaitOffset(indexToSource(sourceIndex), offset) + currentStream.awaitOffset(sourceIndex, offset) } } @@ -473,6 +470,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be // after starting the query. try { currentStream.awaitInitialization(streamingTimeout.toMillis) + currentStream match { + case s: ContinuousExecution => eventually("IncrementalExecution was not created") { + s.lastExecution.executedPlan // will fail if lastExecution is null + } + case _ => + } } catch { case _: StreamingQueryException => // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well. @@ -600,7 +603,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def findSourceIndex(plan: LogicalPlan): Option[Int] = { plan - .collect { case StreamingExecutionRelation(s, _) => s } + .collect { + case StreamingExecutionRelation(s, _) => s + case DataSourceV2Relation(_, r) => r + } .zipWithIndex .find(_._1 == source) .map(_._2) @@ -613,9 +619,13 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be findSourceIndex(query.logicalPlan) }.orElse { findSourceIndex(stream.logicalPlan) + }.orElse { + queryToUse.flatMap { q => + findSourceIndex(q.lastExecution.logical) + } }.getOrElse { throw new IllegalArgumentException( - "Could find index of the source to which data was added") + "Could not find index of the source to which data was added") } // Store the expected offset of added data to wait for it later From 50345a2aa59741c511d555edbbad2da9611e7d16 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 16 Jan 2018 22:14:47 -0800 Subject: [PATCH 0112/2461] Revert "[SPARK-23020][CORE] Fix races in launcher code, test." This reverts commit 66217dac4f8952a9923625908ad3dcb030763c81. --- .../spark/launcher/SparkLauncherSuite.java | 49 +++++++------------ .../spark/launcher/AbstractAppHandle.java | 22 ++------- .../spark/launcher/ChildProcAppHandle.java | 18 +++---- .../spark/launcher/InProcessAppHandle.java | 17 +++---- .../spark/launcher/LauncherConnection.java | 14 +++--- .../apache/spark/launcher/LauncherServer.java | 46 +++-------------- .../org/apache/spark/launcher/BaseSuite.java | 42 +++------------- .../spark/launcher/LauncherServerSuite.java | 20 +++++--- 8 files changed, 72 insertions(+), 156 deletions(-) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index a042375c6ae91..9d2f563b2e367 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.launcher; -import java.time.Duration; import java.util.Arrays; import java.util.ArrayList; import java.util.HashMap; @@ -32,7 +31,6 @@ import static org.mockito.Mockito.*; import org.apache.spark.SparkContext; -import org.apache.spark.SparkContext$; import org.apache.spark.internal.config.package$; import org.apache.spark.util.Utils; @@ -139,9 +137,7 @@ public void testInProcessLauncher() throws Exception { // Here DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet. // Wait for a reasonable amount of time to avoid creating two active SparkContext in JVM. // See SPARK-23019 and SparkContext.stop() for details. - eventually(Duration.ofSeconds(5), Duration.ofMillis(10), () -> { - assertTrue("SparkContext is still alive.", SparkContext$.MODULE$.getActive().isEmpty()); - }); + TimeUnit.MILLISECONDS.sleep(500); } } @@ -150,35 +146,26 @@ private void inProcessLauncherTestImpl() throws Exception { SparkAppHandle.Listener listener = mock(SparkAppHandle.Listener.class); doAnswer(invocation -> { SparkAppHandle h = (SparkAppHandle) invocation.getArguments()[0]; - synchronized (transitions) { - transitions.add(h.getState()); - } + transitions.add(h.getState()); return null; }).when(listener).stateChanged(any(SparkAppHandle.class)); - SparkAppHandle handle = null; - try { - handle = new InProcessLauncher() - .setMaster("local") - .setAppResource(SparkLauncher.NO_RESOURCE) - .setMainClass(InProcessTestApp.class.getName()) - .addAppArgs("hello") - .startApplication(listener); - - waitFor(handle); - assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); - - // Matches the behavior of LocalSchedulerBackend. - List expected = Arrays.asList( - SparkAppHandle.State.CONNECTED, - SparkAppHandle.State.RUNNING, - SparkAppHandle.State.FINISHED); - assertEquals(expected, transitions); - } finally { - if (handle != null) { - handle.kill(); - } - } + SparkAppHandle handle = new InProcessLauncher() + .setMaster("local") + .setAppResource(SparkLauncher.NO_RESOURCE) + .setMainClass(InProcessTestApp.class.getName()) + .addAppArgs("hello") + .startApplication(listener); + + waitFor(handle); + assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); + + // Matches the behavior of LocalSchedulerBackend. + List expected = Arrays.asList( + SparkAppHandle.State.CONNECTED, + SparkAppHandle.State.RUNNING, + SparkAppHandle.State.FINISHED); + assertEquals(expected, transitions); } public static class SparkLauncherTestApp { diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java index daf0972f824dd..df1e7316861d4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java @@ -33,7 +33,7 @@ abstract class AbstractAppHandle implements SparkAppHandle { private List listeners; private State state; private String appId; - private volatile boolean disposed; + private boolean disposed; protected AbstractAppHandle(LauncherServer server) { this.server = server; @@ -70,7 +70,8 @@ public void stop() { @Override public synchronized void disconnect() { - if (!isDisposed()) { + if (!disposed) { + disposed = true; if (connection != null) { try { connection.close(); @@ -78,7 +79,7 @@ public synchronized void disconnect() { // no-op. } } - dispose(); + server.unregister(this); } } @@ -94,21 +95,6 @@ boolean isDisposed() { return disposed; } - /** - * Mark the handle as disposed, and set it as LOST in case the current state is not final. - */ - synchronized void dispose() { - if (!isDisposed()) { - // Unregister first to make sure that the connection with the app has been really - // terminated. - server.unregister(this); - if (!getState().isFinal()) { - setState(State.LOST); - } - this.disposed = true; - } - } - void setState(State s) { setState(s, false); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index 2b99461652e1f..8b3f427b7750e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -48,16 +48,14 @@ public synchronized void disconnect() { @Override public synchronized void kill() { - if (!isDisposed()) { - setState(State.KILLED); - disconnect(); - if (childProc != null) { - if (childProc.isAlive()) { - childProc.destroyForcibly(); - } - childProc = null; + disconnect(); + if (childProc != null) { + if (childProc.isAlive()) { + childProc.destroyForcibly(); } + childProc = null; } + setState(State.KILLED); } void setChildProc(Process childProc, String loggerName, InputStream logStream) { @@ -96,6 +94,8 @@ void monitorChild() { return; } + disconnect(); + int ec; try { ec = proc.exitValue(); @@ -118,8 +118,6 @@ void monitorChild() { if (newState != null) { setState(newState, true); } - - disconnect(); } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index f04263cb74a58..acd64c962604f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -39,16 +39,15 @@ class InProcessAppHandle extends AbstractAppHandle { @Override public synchronized void kill() { - if (!isDisposed()) { - LOG.warning("kill() may leave the underlying app running in in-process mode."); - setState(State.KILLED); - disconnect(); - - // Interrupt the thread. This is not guaranteed to kill the app, though. - if (app != null) { - app.interrupt(); - } + LOG.warning("kill() may leave the underlying app running in in-process mode."); + disconnect(); + + // Interrupt the thread. This is not guaranteed to kill the app, though. + if (app != null) { + app.interrupt(); } + + setState(State.KILLED); } synchronized void start(String appName, Method main, String[] args) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java index fd6f229b2349c..b4a8719e26053 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java @@ -95,15 +95,15 @@ protected synchronized void send(Message msg) throws IOException { } @Override - public synchronized void close() throws IOException { + public void close() throws IOException { if (!closed) { - closed = true; - socket.close(); + synchronized (this) { + if (!closed) { + closed = true; + socket.close(); + } + } } } - boolean isOpen() { - return !closed; - } - } diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index 660c4443b20b9..b8999a1d7a4f4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -217,33 +217,6 @@ void unregister(AbstractAppHandle handle) { break; } } - - // If there is a live connection for this handle, we need to wait for it to finish before - // returning, otherwise there might be a race between the connection thread processing - // buffered data and the handle cleaning up after itself, leading to potentially the wrong - // state being reported for the handle. - ServerConnection conn = null; - synchronized (clients) { - for (ServerConnection c : clients) { - if (c.handle == handle) { - conn = c; - break; - } - } - } - - if (conn != null) { - synchronized (conn) { - if (conn.isOpen()) { - try { - conn.wait(); - } catch (InterruptedException ie) { - // Ignore. - } - } - } - } - unref(); } @@ -315,7 +288,7 @@ private String createSecret() { private class ServerConnection extends LauncherConnection { private TimerTask timeout; - volatile AbstractAppHandle handle; + private AbstractAppHandle handle; ServerConnection(Socket socket, TimerTask timeout) throws IOException { super(socket); @@ -365,21 +338,16 @@ protected void handle(Message msg) throws IOException { @Override public void close() throws IOException { - if (!isOpen()) { - return; - } - synchronized (clients) { clients.remove(this); } - - synchronized (this) { - super.close(); - notifyAll(); - } - + super.close(); if (handle != null) { - handle.dispose(); + if (!handle.getState().isFinal()) { + LOG.log(Level.WARNING, "Lost connection to spark application."); + handle.setState(SparkAppHandle.State.LOST); + } + handle.disconnect(); } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java index 3722a59d9438e..3e1a90eae98d4 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.launcher; -import java.time.Duration; import java.util.concurrent.TimeUnit; import org.junit.After; @@ -48,46 +47,19 @@ public void postChecks() { assertNull(server); } - protected void waitFor(final SparkAppHandle handle) throws Exception { + protected void waitFor(SparkAppHandle handle) throws Exception { + long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); try { - eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { - assertTrue("Handle is not in final state.", handle.getState().isFinal()); - }); + while (!handle.getState().isFinal()) { + assertTrue("Timed out waiting for handle to transition to final state.", + System.nanoTime() < deadline); + TimeUnit.MILLISECONDS.sleep(10); + } } finally { if (!handle.getState().isFinal()) { handle.kill(); } } - - // Wait until the handle has been marked as disposed, to make sure all cleanup tasks - // have been performed. - AbstractAppHandle ahandle = (AbstractAppHandle) handle; - eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { - assertTrue("Handle is still not marked as disposed.", ahandle.isDisposed()); - }); - } - - /** - * Call a closure that performs a check every "period" until it succeeds, or the timeout - * elapses. - */ - protected void eventually(Duration timeout, Duration period, Runnable check) throws Exception { - assertTrue("Timeout needs to be larger than period.", timeout.compareTo(period) > 0); - long deadline = System.nanoTime() + timeout.toNanos(); - int count = 0; - while (true) { - try { - count++; - check.run(); - return; - } catch (Throwable t) { - if (System.nanoTime() >= deadline) { - String msg = String.format("Failed check after %d tries: %s.", count, t.getMessage()); - throw new IllegalStateException(msg, t); - } - Thread.sleep(period.toMillis()); - } - } } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 75c1af0c71e2a..7e2b09ce25c9b 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -23,14 +23,12 @@ import java.net.InetAddress; import java.net.Socket; import java.net.SocketException; -import java.time.Duration; import java.util.Arrays; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import static org.junit.Assert.*; @@ -199,20 +197,28 @@ private void close(Closeable c) { * server-side close immediately. */ private void waitForError(TestClient client, String secret) throws Exception { - final AtomicBoolean helloSent = new AtomicBoolean(); - eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> { + boolean helloSent = false; + int maxTries = 10; + for (int i = 0; i < maxTries; i++) { try { - if (!helloSent.get()) { + if (!helloSent) { client.send(new Hello(secret, "1.4.0")); - helloSent.set(true); + helloSent = true; } else { client.send(new SetAppId("appId")); } fail("Expected error but message went through."); } catch (IllegalStateException | IOException e) { // Expected. + break; + } catch (AssertionError e) { + if (i < maxTries - 1) { + Thread.sleep(100); + } else { + throw new AssertionError("Test failed after " + maxTries + " attempts.", e); + } } - }); + } } private static class TestClient extends LauncherConnection { From a963980a6d2b4bef2c546aa33acf0aa501d2507b Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 16 Jan 2018 22:27:28 -0800 Subject: [PATCH 0113/2461] Fix merge between 07ae39d0ec and 1667057851 ## What changes were proposed in this pull request? The first commit added a new test, and the second refactored the class the test was in. The automatic merge put the test in the wrong place. ## How was this patch tested? - Author: Jose Torres Closes #20289 from jose-torres/fix. --- .../apache/spark/sql/kafka010/KafkaSourceSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 1acff61e11d2a..62f6a34a6b67a 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -479,11 +479,6 @@ class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { // `failOnDataLoss` is `false`, we should not fail the query assert(query.exception.isEmpty) } -} - -class KafkaSourceSuiteBase extends KafkaSourceTest { - - import testImplicits._ test("SPARK-22956: currentPartitionOffsets should be set when no new data comes in") { def getSpecificDF(range: Range.Inclusive): org.apache.spark.sql.Dataset[Int] = { @@ -549,6 +544,11 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { CheckLastBatch(120 to 124: _*) ) } +} + +class KafkaSourceSuiteBase extends KafkaSourceTest { + + import testImplicits._ test("cannot stop Kafka stream") { val topic = newTopic() From a0aedb0ded4183cc33b27e369df1cbf862779e26 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 17 Jan 2018 14:32:18 +0800 Subject: [PATCH 0114/2461] [SPARK-23072][SQL][TEST] Add a Unicode schema test for file-based data sources ## What changes were proposed in this pull request? After [SPARK-20682](https://github.com/apache/spark/pull/19651), Apache Spark 2.3 is able to read ORC files with Unicode schema. Previously, it raises `org.apache.spark.sql.catalyst.parser.ParseException`. This PR adds a Unicode schema test for CSV/JSON/ORC/Parquet file-based data sources. Note that TEXT data source only has [a single column with a fixed name 'value'](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala#L71). ## How was this patch tested? Pass the newly added test case. Author: Dongjoon Hyun Closes #20266 from dongjoon-hyun/SPARK-23072. --- .../spark/sql/FileBasedDataSourceSuite.scala | 81 +++++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 16 ---- .../sql/hive/MetastoreDataSourcesSuite.scala | 14 ---- .../sql/hive/execution/SQLQuerySuite.scala | 8 -- 4 files changed, 81 insertions(+), 38 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala new file mode 100644 index 0000000000000..22fb496bc838e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext + +class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + private val allFileBasedDataSources = Seq("orc", "parquet", "csv", "json", "text") + + allFileBasedDataSources.foreach { format => + test(s"Writing empty datasets should not fail - $format") { + withTempPath { dir => + Seq("str").toDS().limit(0).write.format(format).save(dir.getCanonicalPath) + } + } + } + + // `TEXT` data source always has a single column whose name is `value`. + allFileBasedDataSources.filterNot(_ == "text").foreach { format => + test(s"SPARK-23072 Write and read back unicode column names - $format") { + withTempPath { path => + val dir = path.getCanonicalPath + + // scalastyle:off nonascii + val df = Seq("a").toDF("한글") + // scalastyle:on nonascii + + df.write.format(format).option("header", "true").save(dir) + val answerDf = spark.read.format(format).option("header", "true").load(dir) + + assert(df.schema.sameType(answerDf.schema)) + checkAnswer(df, answerDf) + } + } + } + + // Only ORC/Parquet support this. `CSV` and `JSON` returns an empty schema. + // `TEXT` data source always has a single column whose name is `value`. + Seq("orc", "parquet").foreach { format => + test(s"SPARK-15474 Write and read back non-emtpy schema with empty dataframe - $format") { + withTempPath { file => + val path = file.getCanonicalPath + val emptyDf = Seq((true, 1, "str")).toDF().limit(0) + emptyDf.write.format(format).save(path) + + val df = spark.read.format(format).load(path) + assert(df.schema.sameType(emptyDf.schema)) + checkAnswer(df, emptyDf) + } + } + } + + allFileBasedDataSources.foreach { format => + test(s"SPARK-22146 read files containing special characters using $format") { + val nameWithSpecialChars = s"sp&cial%chars" + withTempDir { dir => + val tmpFile = s"$dir/$nameWithSpecialChars" + spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) + val fileContent = spark.read.format(format).load(tmpFile) + checkAnswer(fileContent, Seq(Row("a"), Row("b"))) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 96bf65fce9c4a..7c9840a34eaa3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2757,20 +2757,4 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } - - // Only New OrcFileFormat supports this - Seq(classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName, - "parquet").foreach { format => - test(s"SPARK-15474 Write and read back non-emtpy schema with empty dataframe - $format") { - withTempPath { file => - val path = file.getCanonicalPath - val emptyDf = Seq((true, 1, "str")).toDF.limit(0) - emptyDf.write.format(format).save(path) - - val df = spark.read.format(format).load(path) - assert(df.schema.sameType(emptyDf.schema)) - checkAnswer(df, emptyDf) - } - } - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index c8caba83bf365..fade143a1755e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -23,14 +23,12 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path -import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.execution.command.CreateTableCommand import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.hive.HiveExternalCatalog._ -import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf._ @@ -1344,18 +1342,6 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } - Seq("orc", "parquet", "csv", "json", "text").foreach { format => - test(s"SPARK-22146: read files containing special characters using $format") { - val nameWithSpecialChars = s"sp&cial%chars" - withTempDir { dir => - val tmpFile = s"$dir/$nameWithSpecialChars" - spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) - val fileContent = spark.read.format(format).load(tmpFile) - checkAnswer(fileContent, Seq(Row("a"), Row("b"))) - } - } - } - private def withDebugMode(f: => Unit): Unit = { val previousValue = sparkSession.sparkContext.conf.get(DEBUG_MODE) try { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 47adc77a52d51..33bcae91fdaf4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2159,12 +2159,4 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } - - Seq("orc", "parquet", "csv", "json", "text").foreach { format => - test(s"Writing empty datasets should not fail - $format") { - withTempDir { dir => - Seq("str").toDS.limit(0).write.format(format).save(dir.getCanonicalPath + "/tmp") - } - } - } } From 1f3d933e0bd2b1e934a233ed699ad39295376e71 Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Wed, 17 Jan 2018 16:01:41 +0800 Subject: [PATCH 0115/2461] [SPARK-23062][SQL] Improve EXCEPT documentation ## What changes were proposed in this pull request? Make the default behavior of EXCEPT (i.e. EXCEPT DISTINCT) more explicit in the documentation, and call out the change in behavior from 1.x. Author: Henry Robinson Closes #20254 from henryr/spark-23062. --- R/pkg/R/DataFrame.R | 2 +- python/pyspark/sql/dataframe.py | 3 ++- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 6caa125e1e14a..29f3e986eaab6 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2853,7 +2853,7 @@ setMethod("intersect", #' except #' #' Return a new SparkDataFrame containing rows in this SparkDataFrame -#' but not in another SparkDataFrame. This is equivalent to \code{EXCEPT} in SQL. +#' but not in another SparkDataFrame. This is equivalent to \code{EXCEPT DISTINCT} in SQL. #' #' @param x a SparkDataFrame. #' @param y a SparkDataFrame. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 95eca76fa9888..2d5e9b91468cf 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1364,7 +1364,8 @@ def subtract(self, other): """ Return a new :class:`DataFrame` containing rows in this frame but not in another frame. - This is equivalent to `EXCEPT` in SQL. + This is equivalent to `EXCEPT DISTINCT` in SQL. + """ return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 34f0ab5aa6699..912f411fa3845 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1903,7 +1903,7 @@ class Dataset[T] private[sql]( /** * Returns a new Dataset containing rows in this Dataset but not in another Dataset. - * This is equivalent to `EXCEPT` in SQL. + * This is equivalent to `EXCEPT DISTINCT` in SQL. * * @note Equality checking is performed directly on the encoded representation of the data * and thus is not affected by a custom `equals` function defined on `T`. From 0f8a28617a0742d5a99debfbae91222c2e3b5cec Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 17 Jan 2018 21:53:36 +0800 Subject: [PATCH 0116/2461] [SPARK-21783][SQL] Turn on ORC filter push-down by default ## What changes were proposed in this pull request? ORC filter push-down is disabled by default from the beginning, [SPARK-2883](https://github.com/apache/spark/commit/aa31e431fc09f0477f1c2351c6275769a31aca90#diff-41ef65b9ef5b518f77e2a03559893f4dR149 ). Now, Apache Spark starts to depend on Apache ORC 1.4.1. For Apache Spark 2.3, this PR turns on ORC filter push-down by default like Parquet ([SPARK-9207](https://issues.apache.org/jira/browse/SPARK-21783)) as a part of [SPARK-20901](https://issues.apache.org/jira/browse/SPARK-20901), "Feature parity for ORC with Parquet". ## How was this patch tested? Pass the existing tests. Author: Dongjoon Hyun Closes #20265 from dongjoon-hyun/SPARK-21783. --- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../spark/sql/FilterPushdownBenchmark.scala | 243 ++++++++++++++++++ 2 files changed, 244 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6746fbcaf2483..16fbb0c3e9e21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -410,7 +410,7 @@ object SQLConf { val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val HIVE_VERIFY_PARTITION_PATH = buildConf("spark.sql.hive.verifyPartitionPath") .doc("When true, check all the partition paths under the table\'s root directory " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala new file mode 100644 index 0000000000000..c6dd7dadc9d93 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.io.File + +import scala.util.{Random, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.functions.monotonically_increasing_id +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.{Benchmark, Utils} + + +/** + * Benchmark to measure read performance with Filter pushdown. + */ +object FilterPushdownBenchmark { + val conf = new SparkConf() + conf.set("orc.compression", "snappy") + conf.set("spark.sql.parquet.compression.codec", "snappy") + + private val spark = SparkSession.builder() + .master("local[1]") + .appName("FilterPushdownBenchmark") + .config(conf) + .getOrCreate() + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + + private def prepareTable(dir: File, numRows: Int, width: Int): Unit = { + import spark.implicits._ + val selectExpr = (1 to width).map(i => s"CAST(value AS STRING) c$i") + val df = spark.range(numRows).map(_ => Random.nextLong).selectExpr(selectExpr: _*) + .withColumn("id", monotonically_increasing_id()) + + val dirORC = dir.getCanonicalPath + "/orc" + val dirParquet = dir.getCanonicalPath + "/parquet" + + df.write.mode("overwrite").orc(dirORC) + df.write.mode("overwrite").parquet(dirParquet) + + spark.read.orc(dirORC).createOrReplaceTempView("orcTable") + spark.read.parquet(dirParquet).createOrReplaceTempView("parquetTable") + } + + def filterPushDownBenchmark( + values: Int, + title: String, + whereExpr: String, + selectExpr: String = "*"): Unit = { + val benchmark = new Benchmark(title, values, minNumIters = 5) + + Seq(false, true).foreach { pushDownEnabled => + val name = s"Parquet Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { + spark.sql(s"SELECT $selectExpr FROM parquetTable WHERE $whereExpr").collect() + } + } + } + + Seq(false, true).foreach { pushDownEnabled => + val name = s"Native ORC Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { + spark.sql(s"SELECT $selectExpr FROM orcTable WHERE $whereExpr").collect() + } + } + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + Select 0 row (id IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7882 / 7957 2.0 501.1 1.0X + Parquet Vectorized (Pushdown) 55 / 60 285.2 3.5 142.9X + Native ORC Vectorized 5592 / 5627 2.8 355.5 1.4X + Native ORC Vectorized (Pushdown) 66 / 70 237.2 4.2 118.9X + + Select 0 row (7864320 < id < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7884 / 7909 2.0 501.2 1.0X + Parquet Vectorized (Pushdown) 739 / 752 21.3 47.0 10.7X + Native ORC Vectorized 5614 / 5646 2.8 356.9 1.4X + Native ORC Vectorized (Pushdown) 81 / 83 195.2 5.1 97.8X + + Select 1 row (id = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7905 / 8027 2.0 502.6 1.0X + Parquet Vectorized (Pushdown) 740 / 766 21.2 47.1 10.7X + Native ORC Vectorized 5684 / 5738 2.8 361.4 1.4X + Native ORC Vectorized (Pushdown) 78 / 81 202.4 4.9 101.7X + + Select 1 row (id <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7928 / 7993 2.0 504.1 1.0X + Parquet Vectorized (Pushdown) 747 / 772 21.0 47.5 10.6X + Native ORC Vectorized 5728 / 5753 2.7 364.2 1.4X + Native ORC Vectorized (Pushdown) 76 / 78 207.9 4.8 104.8X + + Select 1 row (7864320 <= id <= 7864320):Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7939 / 8021 2.0 504.8 1.0X + Parquet Vectorized (Pushdown) 746 / 770 21.1 47.4 10.6X + Native ORC Vectorized 5690 / 5734 2.8 361.7 1.4X + Native ORC Vectorized (Pushdown) 76 / 79 206.7 4.8 104.3X + + Select 1 row (7864319 < id < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7972 / 8019 2.0 506.9 1.0X + Parquet Vectorized (Pushdown) 742 / 764 21.2 47.2 10.7X + Native ORC Vectorized 5704 / 5743 2.8 362.6 1.4X + Native ORC Vectorized (Pushdown) 76 / 78 207.9 4.8 105.4X + + Select 10% rows (id < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 8733 / 8808 1.8 555.2 1.0X + Parquet Vectorized (Pushdown) 2213 / 2267 7.1 140.7 3.9X + Native ORC Vectorized 6420 / 6463 2.4 408.2 1.4X + Native ORC Vectorized (Pushdown) 1313 / 1331 12.0 83.5 6.7X + + Select 50% rows (id < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 11518 / 11591 1.4 732.3 1.0X + Parquet Vectorized (Pushdown) 7962 / 7991 2.0 506.2 1.4X + Native ORC Vectorized 8927 / 8985 1.8 567.6 1.3X + Native ORC Vectorized (Pushdown) 6102 / 6160 2.6 387.9 1.9X + + Select 90% rows (id < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 14255 / 14389 1.1 906.3 1.0X + Parquet Vectorized (Pushdown) 13564 / 13594 1.2 862.4 1.1X + Native ORC Vectorized 11442 / 11608 1.4 727.5 1.2X + Native ORC Vectorized (Pushdown) 10991 / 11029 1.4 698.8 1.3X + + Select all rows (id IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 14917 / 14938 1.1 948.4 1.0X + Parquet Vectorized (Pushdown) 14910 / 14964 1.1 948.0 1.0X + Native ORC Vectorized 11986 / 12069 1.3 762.0 1.2X + Native ORC Vectorized (Pushdown) 12037 / 12123 1.3 765.3 1.2X + + Select all rows (id > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 14951 / 14976 1.1 950.6 1.0X + Parquet Vectorized (Pushdown) 14934 / 15016 1.1 949.5 1.0X + Native ORC Vectorized 12000 / 12156 1.3 763.0 1.2X + Native ORC Vectorized (Pushdown) 12079 / 12113 1.3 767.9 1.2X + + Select all rows (id != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 14930 / 14972 1.1 949.3 1.0X + Parquet Vectorized (Pushdown) 15015 / 15047 1.0 954.6 1.0X + Native ORC Vectorized 12090 / 12259 1.3 768.7 1.2X + Native ORC Vectorized (Pushdown) 12021 / 12096 1.3 764.2 1.2X + */ + benchmark.run() + } + + def main(args: Array[String]): Unit = { + val numRows = 1024 * 1024 * 15 + val width = 5 + val mid = numRows / 2 + + withTempPath { dir => + withTempTable("orcTable", "patquetTable") { + prepareTable(dir, numRows, width) + + Seq("id IS NULL", s"$mid < id AND id < $mid").foreach { whereExpr => + val title = s"Select 0 row ($whereExpr)".replace("id AND id", "id") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + Seq( + s"id = $mid", + s"id <=> $mid", + s"$mid <= id AND id <= $mid", + s"${mid - 1} < id AND id < ${mid + 1}" + ).foreach { whereExpr => + val title = s"Select 1 row ($whereExpr)".replace("id AND id", "id") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(id)") + + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% rows (id < ${numRows * percent / 100})", + s"id < ${numRows * percent / 100}", + selectExpr + ) + } + + Seq("id IS NOT NULL", "id > -1", "id != -1").foreach { whereExpr => + filterPushDownBenchmark( + numRows, + s"Select all rows ($whereExpr)", + whereExpr, + selectExpr) + } + } + } + } +} From 8598a982b4147abe5f1aae005fea0fd5ae395ac4 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 18 Jan 2018 00:05:26 +0800 Subject: [PATCH 0117/2461] [SPARK-23079][SQL] Fix query constraints propagation with aliases ## What changes were proposed in this pull request? Previously, PR #19201 fix the problem of non-converging constraints. After that PR #19149 improve the loop and constraints is inferred only once. So the problem of non-converging constraints is gone. However, the case below will fail. ``` spark.range(5).write.saveAsTable("t") val t = spark.read.table("t") val left = t.withColumn("xid", $"id" + lit(1)).as("x") val right = t.withColumnRenamed("id", "xid").as("y") val df = left.join(right, "xid").filter("id = 3").toDF() checkAnswer(df, Row(4, 3)) ``` Because `aliasMap` replace all the aliased child. See the test case in PR for details. This PR is to fix this bug by removing useless code for preventing non-converging constraints. It can be also fixed with #20270, but this is much simpler and clean up the code. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #20278 from gengliangwang/FixConstraintSimple. --- .../catalyst/plans/logical/LogicalPlan.scala | 1 + .../plans/logical/QueryPlanConstraints.scala | 37 +----------- .../InferFiltersFromConstraintsSuite.scala | 59 +------------------ .../plans/ConstraintPropagationSuite.scala | 2 + .../org/apache/spark/sql/SQLQuerySuite.scala | 11 ++++ 5 files changed, 17 insertions(+), 93 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index ff2a0ec588567..c8ccd9bd03994 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -255,6 +255,7 @@ abstract class UnaryNode extends LogicalPlan { case expr: Expression if expr.semanticEquals(e) => a.toAttribute }) + allConstraints += EqualNullSafe(e, a.toAttribute) case _ => // Don't change. } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 9c0a30a47f839..5c7b8e5b97883 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -94,25 +94,16 @@ trait QueryPlanConstraints { self: LogicalPlan => case _ => Seq.empty[Attribute] } - // Collect aliases from expressions of the whole tree rooted by the current QueryPlan node, so - // we may avoid producing recursive constraints. - private lazy val aliasMap: AttributeMap[Expression] = AttributeMap( - expressions.collect { - case a: Alias if !a.child.isInstanceOf[Literal] => (a.toAttribute, a.child) - } ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints].aliasMap)) - // Note: the explicit cast is necessary, since Scala compiler fails to infer the type. - /** * Infers an additional set of constraints from a given set of equality constraints. * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an * additional constraint of the form `b = 5`. */ private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { - val aliasedConstraints = eliminateAliasedExpressionInConstraints(constraints) var inferredConstraints = Set.empty[Expression] - aliasedConstraints.foreach { + constraints.foreach { case eq @ EqualTo(l: Attribute, r: Attribute) => - val candidateConstraints = aliasedConstraints - eq + val candidateConstraints = constraints - eq inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) case _ => // No inference @@ -120,30 +111,6 @@ trait QueryPlanConstraints { self: LogicalPlan => inferredConstraints -- constraints } - /** - * Replace the aliased expression in [[Alias]] with the alias name if both exist in constraints. - * Thus non-converging inference can be prevented. - * E.g. `Alias(b, f(a)), a = b` infers `f(a) = f(f(a))` without eliminating aliased expressions. - * Also, the size of constraints is reduced without losing any information. - * When the inferred filters are pushed down the operators that generate the alias, - * the alias names used in filters are replaced by the aliased expressions. - */ - private def eliminateAliasedExpressionInConstraints(constraints: Set[Expression]) - : Set[Expression] = { - val attributesInEqualTo = constraints.flatMap { - case EqualTo(l: Attribute, r: Attribute) => l :: r :: Nil - case _ => Nil - } - var aliasedConstraints = constraints - attributesInEqualTo.foreach { a => - if (aliasMap.contains(a)) { - val child = aliasMap.get(a).get - aliasedConstraints = replaceConstraints(aliasedConstraints, child, a) - } - } - aliasedConstraints - } - private def replaceConstraints( constraints: Set[Expression], source: Expression, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index a0708bf7eee9a..178c4b8c270a0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -34,6 +34,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { PushDownPredicate, InferFiltersFromConstraints, CombineFilters, + SimplifyBinaryComparison, BooleanSimplification) :: Nil } @@ -160,64 +161,6 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("inner join with alias: don't generate constraints for recursive functions") { - val t1 = testRelation.subquery('t1) - val t2 = testRelation.subquery('t2) - - // We should prevent `Coalese(a, b)` from recursively creating complicated constraints through - // the constraint inference procedure. - val originalQuery = t1.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)) - // We hide an `Alias` inside the child's child's expressions, to cover the situation reported - // in [SPARK-20700]. - .select('int_col, 'd, 'a).as("t") - .join(t2, Inner, - Some("t.a".attr === "t2.a".attr - && "t.d".attr === "t2.a".attr - && "t.int_col".attr === "t2.a".attr)) - .analyze - val correctAnswer = t1 - .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) && IsNotNull(Coalesce(Seq('b, 'a))) - && IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b))) && IsNotNull(Coalesce(Seq('a, 'b))) - && 'a === 'b && 'a === Coalesce(Seq('a, 'a)) && 'a === Coalesce(Seq('a, 'b)) - && 'a === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('a, 'b)) - && 'b === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('b, 'b))) - .select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)) - .select('int_col, 'd, 'a).as("t") - .join( - t2.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) && - 'a === Coalesce(Seq('a, 'a))), - Inner, - Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr - && "t.int_col".attr === "t2.a".attr)) - .analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) - } - - test("inner join with EqualTo expressions containing part of each other: don't generate " + - "constraints for recursive functions") { - val t1 = testRelation.subquery('t1) - val t2 = testRelation.subquery('t2) - - // We should prevent `c = Coalese(a, b)` and `a = Coalese(b, c)` from recursively creating - // complicated constraints through the constraint inference procedure. - val originalQuery = t1 - .select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e)) - .where('a === 'd && 'c === 'e) - .join(t2, Inner, Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr)) - .analyze - val correctAnswer = t1 - .where(IsNotNull('a) && IsNotNull('c) && 'a === Coalesce(Seq('b, 'c)) && - 'c === Coalesce(Seq('a, 'b))) - .select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e)) - .join(t2.where(IsNotNull('a) && IsNotNull('c)), - Inner, - Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr)) - .analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) - } - test("generate correct filters for alias that don't produce recursive constraints") { val t1 = testRelation.subquery('t1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 866ff0d33cbb2..a37e06d922642 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -134,6 +134,8 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { verifyConstraints(aliasedRelation.analyze.constraints, ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "x")), + resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"), + resolveColumn(aliasedRelation.analyze, "z") <=> resolveColumn(aliasedRelation.analyze, "x"), resolveColumn(aliasedRelation.analyze, "z") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "z"))))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 7c9840a34eaa3..d4d0aa4f5f5eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2717,6 +2717,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("SPARK-23079: constraints should be inferred correctly with aliases") { + withTable("t") { + spark.range(5).write.saveAsTable("t") + val t = spark.read.table("t") + val left = t.withColumn("xid", $"id" + lit(1)).as("x") + val right = t.withColumnRenamed("id", "xid").as("y") + val df = left.join(right, "xid").filter("id = 3").toDF() + checkAnswer(df, Row(4, 3)) + } + } + test("SRARK-22266: the same aggregate function was calculated multiple times") { val query = "SELECT a, max(b+1), max(b+1) + 1 FROM testData2 GROUP BY a" val df = sql(query) From c132538a164cd8b55dbd7e8ffdc0c0782a0b588c Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Wed, 17 Jan 2018 09:27:49 -0800 Subject: [PATCH 0118/2461] [SPARK-23020] Ignore Flaky Test: SparkLauncherSuite.testInProcessLauncher ## What changes were proposed in this pull request? Temporarily ignoring flaky test `SparkLauncherSuite.testInProcessLauncher` to de-flake the builds. This should be re-enabled when SPARK-23020 is merged. ## How was this patch tested? N/A (Test Only Change) Author: Sameer Agarwal Closes #20291 from sameeragarwal/disable-test-2. --- .../java/org/apache/spark/launcher/SparkLauncherSuite.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 9d2f563b2e367..dffa609f1cbdf 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -25,6 +25,7 @@ import java.util.Properties; import java.util.concurrent.TimeUnit; +import org.junit.Ignore; import org.junit.Test; import static org.junit.Assert.*; import static org.junit.Assume.*; @@ -120,7 +121,8 @@ public void testChildProcLauncher() throws Exception { assertEquals(0, app.waitFor()); } - @Test + // TODO: [SPARK-23020] Re-enable this + @Ignore public void testInProcessLauncher() throws Exception { // Because this test runs SparkLauncher in process and in client mode, it pollutes the system // properties, and that can cause test failures down the test pipeline. So restore the original From 86a845031824a5334db6a5299c6f5dcc982bc5b8 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 17 Jan 2018 13:52:51 -0800 Subject: [PATCH 0119/2461] [SPARK-23033][SS] Don't use task level retry for continuous processing ## What changes were proposed in this pull request? Continuous processing tasks will fail on any attempt number greater than 0. ContinuousExecution will catch these failures and restart globally from the last recorded checkpoints. ## How was this patch tested? unit test Author: Jose Torres Closes #20225 from jose-torres/no-retry. --- .../spark/sql/kafka010/KafkaSourceSuite.scala | 8 +-- .../ContinuousDataSourceRDDIter.scala | 5 ++ .../continuous/ContinuousExecution.scala | 2 +- .../ContinuousTaskRetryException.scala | 26 +++++++ .../spark/sql/streaming/StreamTest.scala | 9 ++- .../continuous/ContinuousSuite.scala | 71 +++++++++++-------- 6 files changed, 84 insertions(+), 37 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTaskRetryException.scala diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 62f6a34a6b67a..27dbb3f7a8f31 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -808,16 +808,14 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { val query = kafka .writeStream .format("memory") - .outputMode("append") .queryName("kafkaColumnTypes") .trigger(defaultTrigger) .start() - var rows: Array[Row] = Array() eventually(timeout(streamingTimeout)) { - rows = spark.table("kafkaColumnTypes").collect() - assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + assert(spark.table("kafkaColumnTypes").count == 1, + s"Unexpected results: ${spark.table("kafkaColumnTypes").collectAsList()}") } - val row = rows(0) + val row = spark.table("kafkaColumnTypes").head() assert(row.getAs[Array[Byte]]("key") === null, s"Unexpected results: $row") assert(row.getAs[Array[Byte]]("value") === "1".getBytes(UTF_8), s"Unexpected results: $row") assert(row.getAs[String]("topic") === topic, s"Unexpected results: $row") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index 66eb42d4658f6..dcb3b54c4e160 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -52,6 +52,11 @@ class ContinuousDataSourceRDD( } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + // If attempt number isn't 0, this is a task retry, which we don't support. + if (context.attemptNumber() != 0) { + throw new ContinuousTaskRetryException() + } + val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader() val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 667410ef9f1c6..45b794c70a50a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -24,7 +24,7 @@ import java.util.function.UnaryOperator import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTaskRetryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTaskRetryException.scala new file mode 100644 index 0000000000000..e0a6f6dd50bb3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTaskRetryException.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import org.apache.spark.SparkException + +/** + * An exception thrown when a continuous processing task runs with a nonzero attempt ID. + */ +class ContinuousTaskRetryException + extends SparkException("Continuous execution does not support task retry", null) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 0762895fdc620..c75247e0f6ed8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -472,8 +472,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be currentStream.awaitInitialization(streamingTimeout.toMillis) currentStream match { case s: ContinuousExecution => eventually("IncrementalExecution was not created") { - s.lastExecution.executedPlan // will fail if lastExecution is null - } + s.lastExecution.executedPlan // will fail if lastExecution is null + } case _ => } } catch { @@ -645,7 +645,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } case CheckAnswerRowsContains(expectedAnswer, lastOnly) => - val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) + val sparkAnswer = currentStream match { + case null => fetchStreamAnswer(lastStream, lastOnly) + case s => fetchStreamAnswer(s, lastOnly) + } QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach { error => failTest(error) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 9562c10feafe9..4b4ed82dc6520 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -17,36 +17,18 @@ package org.apache.spark.sql.streaming.continuous -import java.io.{File, InterruptedIOException, IOException, UncheckedIOException} -import java.nio.channels.ClosedByInterruptException -import java.util.concurrent.{CountDownLatch, ExecutionException, TimeoutException, TimeUnit} +import java.util.UUID -import scala.reflect.ClassTag -import scala.util.control.ControlThrowable - -import com.google.common.util.concurrent.UncheckedExecutionException -import org.apache.commons.io.FileUtils -import org.apache.hadoop.conf.Configuration - -import org.apache.spark.{SparkContext, SparkEnv} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} +import org.apache.spark.{SparkContext, SparkEnv, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.plans.logical.Range -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes -import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 -import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.streaming.{StreamTest, Trigger} -import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.TestSparkSession -import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils class ContinuousSuiteBase extends StreamTest { // We need more than the default local[2] to be able to schedule all partitions simultaneously. @@ -219,6 +201,41 @@ class ContinuousSuite extends ContinuousSuiteBase { StopStream) } + test("task failure kills the query") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .select('value) + + // Get an arbitrary task from this query to kill. It doesn't matter which one. + var taskId: Long = -1 + val listener = new SparkListener() { + override def onTaskStart(start: SparkListenerTaskStart): Unit = { + taskId = start.taskInfo.taskId + } + } + spark.sparkContext.addSparkListener(listener) + try { + testStream(df, useV2Sink = true)( + StartStream(Trigger.Continuous(100)), + Execute(waitForRateSourceTriggers(_, 2)), + Execute { _ => + // Wait until a task is started, then kill its first attempt. + eventually(timeout(streamingTimeout)) { + assert(taskId != -1) + } + spark.sparkContext.killTaskAttempt(taskId) + }, + ExpectFailure[SparkException] { e => + e.getCause != null && e.getCause.getCause.isInstanceOf[ContinuousTaskRetryException] + }) + } finally { + spark.sparkContext.removeSparkListener(listener) + } + } + test("query without test harness") { val df = spark.readStream .format("rate") @@ -258,13 +275,9 @@ class ContinuousStressSuite extends ContinuousSuiteBase { AwaitEpoch(0), Execute(waitForRateSourceTriggers(_, 201)), IncrementEpoch(), - Execute { query => - val data = query.sink.asInstanceOf[MemorySinkV2].allData - val vals = data.map(_.getLong(0)).toSet - assert(scala.Range(0, 25000).forall { i => - vals.contains(i) - }) - }) + StopStream, + CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_))) + ) } test("automatic epoch advancement") { @@ -280,6 +293,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase { AwaitEpoch(0), Execute(waitForRateSourceTriggers(_, 201)), IncrementEpoch(), + StopStream, CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_)))) } @@ -311,6 +325,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase { StopStream, StartStream(Trigger.Continuous(2012)), AwaitEpoch(50), + StopStream, CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_)))) } } From e946c63dd56d121cf898084ed7e9b5b0868b226e Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 17 Jan 2018 13:58:44 -0800 Subject: [PATCH 0120/2461] [SPARK-23093][SS] Don't change run id when reconfiguring a continuous processing query. ## What changes were proposed in this pull request? Keep the run ID static, using a different ID for the epoch coordinator to avoid cross-execution message contamination. ## How was this patch tested? new and existing unit tests Author: Jose Torres Closes #20282 from jose-torres/fix-runid. --- .../datasources/v2/DataSourceV2ScanExec.scala | 3 ++- .../datasources/v2/WriteToDataSourceV2.scala | 5 ++-- .../execution/streaming/StreamExecution.scala | 3 +-- .../ContinuousDataSourceRDDIter.scala | 10 ++++---- .../continuous/ContinuousExecution.scala | 18 ++++++++----- .../continuous/EpochCoordinator.scala | 9 ++++--- .../spark/sql/streaming/StreamTest.scala | 2 +- .../StreamingQueryListenerSuite.scala | 25 +++++++++++++++++++ 8 files changed, 54 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 8c64df080242f..beb66738732be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -58,7 +58,8 @@ case class DataSourceV2ScanExec( case _: ContinuousReader => EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + sparkContext.env) .askSync[Unit](SetReaderPartitions(readTasks.size())) new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks) .asInstanceOf[RDD[InternalRow]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index a4a857f2d4d9b..3dbdae7b4df9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -64,7 +64,8 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) val runTask = writer match { case w: ContinuousWriter => EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + sparkContext.env) .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) (context: TaskContext, iter: Iterator[InternalRow]) => @@ -135,7 +136,7 @@ object DataWritingSparkTask extends Logging { iter: Iterator[InternalRow]): WriterCommitMessage = { val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber()) val epochCoordinator = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.RUN_ID_KEY), + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) val currentMsg: WriterCommitMessage = null var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index cf27e1a70650a..e7982d7880ceb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -142,8 +142,7 @@ abstract class StreamExecution( override val id: UUID = UUID.fromString(streamMetadata.id) - override def runId: UUID = currentRunId - protected var currentRunId = UUID.randomUUID + override val runId: UUID = UUID.randomUUID /** * Pretty identified string of printing in logs. Format is diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index dcb3b54c4e160..cd7065f5e6601 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -59,7 +59,7 @@ class ContinuousDataSourceRDD( val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader() - val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY) + val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY) // This queue contains two types of messages: // * (null, null) representing an epoch boundary. @@ -68,7 +68,7 @@ class ContinuousDataSourceRDD( val epochPollFailed = new AtomicBoolean(false) val epochPollExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( - s"epoch-poll--${runId}--${context.partitionId()}") + s"epoch-poll--$coordinatorId--${context.partitionId()}") val epochPollRunnable = new EpochPollRunnable(queue, context, epochPollFailed) epochPollExecutor.scheduleWithFixedDelay( epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) @@ -86,7 +86,7 @@ class ContinuousDataSourceRDD( epochPollExecutor.shutdown() }) - val epochEndpoint = EpochCoordinatorRef.get(runId, SparkEnv.get) + val epochEndpoint = EpochCoordinatorRef.get(coordinatorId, SparkEnv.get) new Iterator[UnsafeRow] { private val POLL_TIMEOUT_MS = 1000 @@ -150,7 +150,7 @@ class EpochPollRunnable( private[continuous] var failureReason: Throwable = _ private val epochEndpoint = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.RUN_ID_KEY), SparkEnv.get) + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong override def run(): Unit = { @@ -177,7 +177,7 @@ class DataReaderThread( failedFlag: AtomicBoolean) extends Thread( s"continuous-reader--${context.partitionId()}--" + - s"${context.getLocalProperty(ContinuousExecution.RUN_ID_KEY)}") { + s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") { private[continuous] var failureReason: Throwable = _ override def run(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 45b794c70a50a..c0507224f9be8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -57,6 +57,9 @@ class ContinuousExecution( @volatile protected var continuousSources: Seq[ContinuousReader] = _ override protected def sources: Seq[BaseStreamingSource] = continuousSources + // For use only in test harnesses. + private[sql] var currentEpochCoordinatorId: String = _ + override lazy val logicalPlan: LogicalPlan = { assert(queryExecutionThread eq Thread.currentThread, "logicalPlan must be initialized in StreamExecutionThread " + @@ -149,7 +152,6 @@ class ContinuousExecution( * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with. */ private def runContinuous(sparkSessionForQuery: SparkSession): Unit = { - currentRunId = UUID.randomUUID // A list of attributes that will need to be updated. val replacements = new ArrayBuffer[(Attribute, Attribute)] // Translate from continuous relation to the underlying data source. @@ -219,15 +221,19 @@ class ContinuousExecution( lastExecution.executedPlan // Force the lazy generation of execution plan } - sparkSession.sparkContext.setLocalProperty( + sparkSessionForQuery.sparkContext.setLocalProperty( ContinuousExecution.START_EPOCH_KEY, currentBatchId.toString) - sparkSession.sparkContext.setLocalProperty( - ContinuousExecution.RUN_ID_KEY, runId.toString) + // Add another random ID on top of the run ID, to distinguish epoch coordinators across + // reconfigurations. + val epochCoordinatorId = s"$runId--${UUID.randomUUID}" + currentEpochCoordinatorId = epochCoordinatorId + sparkSessionForQuery.sparkContext.setLocalProperty( + ContinuousExecution.EPOCH_COORDINATOR_ID_KEY, epochCoordinatorId) // Use the parent Spark session for the endpoint since it's where this query ID is registered. val epochEndpoint = EpochCoordinatorRef.create( - writer.get(), reader, this, currentBatchId, sparkSession, SparkEnv.get) + writer.get(), reader, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) val epochUpdateThread = new Thread(new Runnable { override def run: Unit = { try { @@ -359,5 +365,5 @@ class ContinuousExecution( object ContinuousExecution { val START_EPOCH_KEY = "__continuous_start_epoch" - val RUN_ID_KEY = "__run_id" + val EPOCH_COORDINATOR_ID_KEY = "__epoch_coordinator_id" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 40dcbecade814..90b3584aa0436 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -79,7 +79,7 @@ private[sql] case class ReportPartitionOffset( /** Helper object used to create reference to [[EpochCoordinator]]. */ private[sql] object EpochCoordinatorRef extends Logging { - private def endpointName(runId: String) = s"EpochCoordinator-$runId" + private def endpointName(id: String) = s"EpochCoordinator-$id" /** * Create a reference to a new [[EpochCoordinator]]. @@ -88,18 +88,19 @@ private[sql] object EpochCoordinatorRef extends Logging { writer: ContinuousWriter, reader: ContinuousReader, query: ContinuousExecution, + epochCoordinatorId: String, startEpoch: Long, session: SparkSession, env: SparkEnv): RpcEndpointRef = synchronized { val coordinator = new EpochCoordinator( writer, reader, query, startEpoch, session, env.rpcEnv) - val ref = env.rpcEnv.setupEndpoint(endpointName(query.runId.toString()), coordinator) + val ref = env.rpcEnv.setupEndpoint(endpointName(epochCoordinatorId), coordinator) logInfo("Registered EpochCoordinator endpoint") ref } - def get(runId: String, env: SparkEnv): RpcEndpointRef = synchronized { - val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName(runId), env.conf, env.rpcEnv) + def get(id: String, env: SparkEnv): RpcEndpointRef = synchronized { + val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName(id), env.conf, env.rpcEnv) logDebug("Retrieved existing EpochCoordinator endpoint") rpcEndpointRef } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index c75247e0f6ed8..efdb0e0e7cf1c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -263,7 +263,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def apply(): AssertOnQuery = Execute { case s: ContinuousExecution => - val newEpoch = EpochCoordinatorRef.get(s.runId.toString, SparkEnv.get) + val newEpoch = EpochCoordinatorRef.get(s.currentEpochCoordinatorId, SparkEnv.get) .askSync[Long](IncrementAndGetEpoch) s.awaitEpoch(newEpoch - 1) case _ => throw new IllegalStateException("microbatch cannot increment epoch") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 9ff02dee288fb..79d65192a14aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -174,6 +174,31 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { } } + test("continuous processing listeners should receive QueryTerminatedEvent") { + val df = spark.readStream.format("rate").load() + val listeners = (1 to 5).map(_ => new EventCollector) + try { + listeners.foreach(listener => spark.streams.addListener(listener)) + testStream(df, OutputMode.Append, useV2Sink = true)( + StartStream(Trigger.Continuous(1000)), + StopStream, + AssertOnQuery { query => + eventually(Timeout(streamingTimeout)) { + listeners.foreach(listener => assert(listener.terminationEvent !== null)) + listeners.foreach(listener => assert(listener.terminationEvent.id === query.id)) + listeners.foreach(listener => assert(listener.terminationEvent.runId === query.runId)) + listeners.foreach(listener => assert(listener.terminationEvent.exception === None)) + } + listeners.foreach(listener => listener.checkAsyncErrors()) + listeners.foreach(listener => listener.reset()) + true + } + ) + } finally { + listeners.foreach(spark.streams.removeListener) + } + } + test("adding and removing listener") { def isListenerActive(listener: EventCollector): Boolean = { listener.reset() From 4e6f8fb150ae09c7d1de6beecb2b98e5afa5da19 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 18 Jan 2018 07:26:43 +0900 Subject: [PATCH 0121/2461] [SPARK-23047][PYTHON][SQL] Change MapVector to NullableMapVector in ArrowColumnVector ## What changes were proposed in this pull request? This PR changes usage of `MapVector` in Spark codebase to use `NullableMapVector`. `MapVector` is an internal Arrow class that is not supposed to be used directly. We should use `NullableMapVector` instead. ## How was this patch tested? Existing test. Author: Li Jin Closes #20239 from icexelloss/arrow-map-vector. --- .../sql/vectorized/ArrowColumnVector.java | 13 +++++-- .../vectorized/ArrowColumnVectorSuite.scala | 36 +++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 708333213f3f1..eb69001fe677e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -247,8 +247,8 @@ public ArrowColumnVector(ValueVector vector) { childColumns = new ArrowColumnVector[1]; childColumns[0] = new ArrowColumnVector(listVector.getDataVector()); - } else if (vector instanceof MapVector) { - MapVector mapVector = (MapVector) vector; + } else if (vector instanceof NullableMapVector) { + NullableMapVector mapVector = (NullableMapVector) vector; accessor = new StructAccessor(mapVector); childColumns = new ArrowColumnVector[mapVector.size()]; @@ -553,9 +553,16 @@ final int getArrayOffset(int rowId) { } } + /** + * Any call to "get" method will throw UnsupportedOperationException. + * + * Access struct values in a ArrowColumnVector doesn't use this accessor. Instead, it uses getStruct() method defined + * in the parent class. Any call to "get" method in this class is a bug in the code. + * + */ private static class StructAccessor extends ArrowVectorAccessor { - StructAccessor(MapVector vector) { + StructAccessor(NullableMapVector vector) { super(vector); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index 7304803a092c0..53432669e215d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -322,6 +322,42 @@ class ArrowColumnVectorSuite extends SparkFunSuite { allocator.close() } + test("non nullable struct") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) + val schema = new StructType().add("int", IntegerType).add("long", LongType) + val vector = ArrowUtils.toArrowField("struct", schema, nullable = false, null) + .createVector(allocator).asInstanceOf[NullableMapVector] + + vector.allocateNew() + val intVector = vector.getChildByOrdinal(0).asInstanceOf[IntVector] + val longVector = vector.getChildByOrdinal(1).asInstanceOf[BigIntVector] + + vector.setIndexDefined(0) + intVector.setSafe(0, 1) + longVector.setSafe(0, 1L) + + vector.setIndexDefined(1) + intVector.setSafe(1, 2) + longVector.setNull(1) + + vector.setValueCount(2) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === schema) + assert(columnVector.numNulls === 0) + + val row0 = columnVector.getStruct(0, 2) + assert(row0.getInt(0) === 1) + assert(row0.getLong(1) === 1L) + + val row1 = columnVector.getStruct(1, 2) + assert(row1.getInt(0) === 2) + assert(row1.isNullAt(1)) + + columnVector.close() + allocator.close() + } + test("struct") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) val schema = new StructType().add("int", IntegerType).add("long", LongType) From 45ad97df87c89cb94ce9564e5773897b6d9326f5 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 18 Jan 2018 07:30:54 +0900 Subject: [PATCH 0122/2461] [SPARK-23132][PYTHON][ML] Run doctests in ml.image when testing ## What changes were proposed in this pull request? This PR proposes to actually run the doctests in `ml/image.py`. ## How was this patch tested? doctests in `python/pyspark/ml/image.py`. Author: hyukjinkwon Closes #20294 from HyukjinKwon/trigger-image. --- python/pyspark/ml/image.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index c9b840276f675..2d86c7f03860c 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -194,9 +194,9 @@ def readImages(self, path, recursive=False, numPartitions=-1, :return: a :class:`DataFrame` with a single column of "images", see ImageSchema for details. - >>> df = ImageSchema.readImages('python/test_support/image/kittens', recursive=True) + >>> df = ImageSchema.readImages('data/mllib/images/kittens', recursive=True) >>> df.count() - 4 + 5 .. versionadded:: 2.3.0 """ @@ -216,3 +216,25 @@ def readImages(self, path, recursive=False, numPartitions=-1, def _disallow_instance(_): raise RuntimeError("Creating instance of _ImageSchema class is disallowed.") _ImageSchema.__init__ = _disallow_instance + + +def _test(): + import doctest + import pyspark.ml.image + globs = pyspark.ml.image.__dict__.copy() + spark = SparkSession.builder\ + .master("local[2]")\ + .appName("ml.image tests")\ + .getOrCreate() + globs['spark'] = spark + + (failure_count, test_count) = doctest.testmod( + pyspark.ml.image, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + spark.stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() From 7823d43ec0e9c4b8284bb4529b0e624c43bc9bb7 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 17 Jan 2018 17:16:57 -0600 Subject: [PATCH 0123/2461] [MINOR] Fix typos in ML scaladocs ## What changes were proposed in this pull request? Fixed some typos found in ML scaladocs ## How was this patch tested? NA Author: Bryan Cutler Closes #20300 from BryanCutler/ml-doc-typos-MINOR. --- .../src/main/scala/org/apache/spark/ml/stat/Summarizer.scala | 2 +- .../org/apache/spark/ml/tuning/TrainValidationSplit.scala | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index 9bed74a9f2c05..d40827edb6d64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -75,7 +75,7 @@ sealed abstract class SummaryBuilder { * val Row(meanVec) = meanDF.first() * }}} * - * Note: Currently, the performance of this interface is about 2x~3x slower then using the RDD + * Note: Currently, the performance of this interface is about 2x~3x slower than using the RDD * interface. */ @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 8826ef3271bc1..88ff0dfd75e96 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -93,7 +93,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St def setSeed(value: Long): this.type = set(seed, value) /** - * Set the mamixum level of parallelism to evaluate models in parallel. + * Set the maximum level of parallelism to evaluate models in parallel. * Default is 1 for serial evaluation * * @group expertSetParam @@ -112,7 +112,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St * for more information. * * @group expertSetParam - */@Since("2.3.0") + */ + @Since("2.3.0") def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) @Since("2.0.0") From bac0d661af6092dd26638223156827aceb901229 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 17 Jan 2018 16:40:02 -0800 Subject: [PATCH 0124/2461] [SPARK-23119][SS] Minor fixes to V2 streaming APIs ## What changes were proposed in this pull request? - Added `InterfaceStability.Evolving` annotations - Improved docs. ## How was this patch tested? Existing tests. Author: Tathagata Das Closes #20286 from tdas/SPARK-23119. --- .../v2/streaming/ContinuousReadSupport.java | 2 ++ .../streaming/reader/ContinuousDataReader.java | 2 ++ .../v2/streaming/reader/ContinuousReader.java | 9 +++++++-- .../v2/streaming/reader/MicroBatchReader.java | 5 +++++ .../sources/v2/streaming/reader/Offset.java | 18 +++++++++++++----- .../v2/streaming/reader/PartitionOffset.java | 3 +++ .../sources/v2/writer/DataSourceV2Writer.java | 5 ++++- 7 files changed, 36 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java index 3136cee1f655f..9a93a806b0efc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java @@ -19,6 +19,7 @@ import java.util.Optional; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader; @@ -28,6 +29,7 @@ * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to * provide data reading ability for continuous stream processing. */ +@InterfaceStability.Evolving public interface ContinuousReadSupport extends DataSourceV2 { /** * Creates a {@link ContinuousReader} to scan the data from this data source. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java index ca9a290e97a02..3f13a4dbf5793 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java @@ -17,11 +17,13 @@ package org.apache.spark.sql.sources.v2.streaming.reader; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.reader.DataReader; /** * A variation on {@link DataReader} for use with streaming in continuous processing mode. */ +@InterfaceStability.Evolving public interface ContinuousDataReader extends DataReader { /** * Get the offset of the current record, or the start offset if no records have been read. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java index f0b205869ed6c..745f1ce502443 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.sources.v2.streaming.reader; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; @@ -27,11 +28,15 @@ * interface to allow reading in a continuous processing mode stream. * * Implementations must ensure each read task output is a {@link ContinuousDataReader}. + * + * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with + * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. */ +@InterfaceStability.Evolving public interface ContinuousReader extends BaseStreamingSource, DataSourceV2Reader { /** - * Merge offsets coming from {@link ContinuousDataReader} instances in each partition to - * a single global offset. + * Merge partitioned offsets coming from {@link ContinuousDataReader} instances for each + * partition to a single global offset. */ Offset mergeOffsets(PartitionOffset[] offsets); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java index 70ff756806032..02f37cebc7484 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.sources.v2.streaming.reader; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; @@ -25,7 +26,11 @@ /** * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this * interface to indicate they allow micro-batch streaming reads. + * + * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with + * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. */ +@InterfaceStability.Evolving public interface MicroBatchReader extends DataSourceV2Reader, BaseStreamingSource { /** * Set the desired offset range for read tasks created from this reader. Read tasks will diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java index 60b87f2ac0756..abba3e7188b13 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java @@ -17,12 +17,20 @@ package org.apache.spark.sql.sources.v2.streaming.reader; +import org.apache.spark.annotation.InterfaceStability; + /** - * An abstract representation of progress through a [[MicroBatchReader]] or [[ContinuousReader]]. - * During execution, Offsets provided by the data source implementation will be logged and used as - * restart checkpoints. Sources should provide an Offset implementation which they can use to - * reconstruct the stream position where the offset was taken. + * An abstract representation of progress through a {@link MicroBatchReader} or + * {@link ContinuousReader}. + * During execution, offsets provided by the data source implementation will be logged and used as + * restart checkpoints. Each source should provide an offset implementation which the source can use + * to reconstruct a position in the stream up to which data has been seen/processed. + * + * Note: This class currently extends {@link org.apache.spark.sql.execution.streaming.Offset} to + * maintain compatibility with DataSource V1 APIs. This extension will be removed once we + * get rid of V1 completely. */ +@InterfaceStability.Evolving public abstract class Offset extends org.apache.spark.sql.execution.streaming.Offset { /** * A JSON-serialized representation of an Offset that is @@ -37,7 +45,7 @@ public abstract class Offset extends org.apache.spark.sql.execution.streaming.Of /** * Equality based on JSON string representation. We leverage the * JSON representation for normalization between the Offset's - * in memory and on disk representations. + * in deserialized and serialized representations. */ @Override public boolean equals(Object obj) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java index eca0085c8a8ce..4688b85f49f5f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java @@ -19,11 +19,14 @@ import java.io.Serializable; +import org.apache.spark.annotation.InterfaceStability; + /** * Used for per-partition offsets in continuous processing. ContinuousReader implementations will * provide a method to merge these into a global Offset. * * These offsets must be serializable. */ +@InterfaceStability.Evolving public interface PartitionOffset extends Serializable { } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java index fc37b9a516f82..317ac45bcfd74 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java @@ -22,11 +22,14 @@ import org.apache.spark.sql.SaveMode; import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.WriteSupport; +import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; /** * A data source writer that is returned by - * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}. + * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}/ + * {@link org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport#createMicroBatchWriter(String, long, StructType, OutputMode, DataSourceV2Options)}/ + * {@link org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport#createContinuousWriter(String, StructType, OutputMode, DataSourceV2Options)}. * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. * From 1002bd6b23ff78a010ca259ea76988ef4c478c6e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 17 Jan 2018 16:41:43 -0800 Subject: [PATCH 0125/2461] [SPARK-23064][DOCS][SS] Added documentation for stream-stream joins ## What changes were proposed in this pull request? Added documentation for stream-stream joins ![image](https://user-images.githubusercontent.com/663212/35018744-e999895a-fad7-11e7-9d6a-8c7a73e6eb9c.png) ![image](https://user-images.githubusercontent.com/663212/35018775-157eb464-fad8-11e7-879e-47a2fcbd8690.png) ![image](https://user-images.githubusercontent.com/663212/35018784-27791a24-fad8-11e7-98f4-7ff246f62a74.png) ![image](https://user-images.githubusercontent.com/663212/35018791-36a80334-fad8-11e7-9791-f85efa7c6ba2.png) ## How was this patch tested? N/a Author: Tathagata Das Closes #20255 from tdas/join-docs. --- .../structured-streaming-programming-guide.md | 338 +++++++++++++++++- 1 file changed, 326 insertions(+), 12 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index de13e281916db..1779a4215e085 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1051,7 +1051,19 @@ output mode. ### Join Operations -Streaming DataFrames can be joined with static DataFrames to create new streaming DataFrames. Here are a few examples. +Structured Streaming supports joining a streaming Dataset/DataFrame with a static Dataset/DataFrame +as well as another streaming Dataset/DataFrame. The result of the streaming join is generated +incrementally, similar to the results of streaming aggregations in the previous section. In this +section we will explore what type of joins (i.e. inner, outer, etc.) are supported in the above +cases. Note that in all the supported join types, the result of the join with a streaming +Dataset/DataFrame will be the exactly the same as if it was with a static Dataset/DataFrame +containing the same data in the stream. + + +#### Stream-static joins + +Since the introduction in Spark 2.0, Structured Streaming has supported joins (inner join and some +type of outer joins) between a streaming and a static DataFrame/Dataset. Here is a simple example.
    @@ -1089,6 +1101,300 @@ streamingDf.join(staticDf, "type", "right_join") # right outer join with a stat
    +Note that stream-static joins are not stateful, so no state management is necessary. +However, a few types of stream-static outer joins are not yet supported. +These are listed at the [end of this Join section](#support-matrix-for-joins-in-streaming-queries). + +#### Stream-stream Joins +In Spark 2.3, we have added support for stream-stream joins, that is, you can join two streaming +Datasets/DataFrames. The challenge of generating join results between two data streams is that, +at any point of time, the view of the dataset is incomplete for both sides of the join making +it much harder to find matches between inputs. Any row received from one input stream can match +with any future, yet-to-be-received row from the other input stream. Hence, for both the input +streams, we buffer past input as streaming state, so that we can match every future input with +past input and accordingly generate joined results. Furthermore, similar to streaming aggregations, +we automatically handle late, out-of-order data and can limit the state using watermarks. +Let’s discuss the different types of supported stream-stream joins and how to use them. + +##### Inner Joins with optional Watermarking +Inner joins on any kind of columns along with any kind of join conditions are supported. +However, as the stream runs, the size of streaming state will keep growing indefinitely as +*all* past input must be saved as the any new input can match with any input from the past. +To avoid unbounded state, you have to define additional join conditions such that indefinitely +old inputs cannot match with future inputs and therefore can be cleared from the state. +In other words, you will have to do the following additional steps in the join. + +1. Define watermark delays on both inputs such that the engine knows how delayed the input can be +(similar to streaming aggregations) + +1. Define a constraint on event-time across the two inputs such that the engine can figure out when +old rows of one input is not going to be required (i.e. will not satisfy the time constraint) for +matches with the other input. This constraint can be defined in one of the two ways. + + 1. Time range join conditions (e.g. `...JOIN ON leftTime BETWEN rightTime AND rightTime + INTERVAL 1 HOUR`), + + 1. Join on event-time windows (e.g. `...JOIN ON leftTimeWindow = rightTimeWindow`). + +Let’s understand this with an example. + +Let’s say we want to join a stream of advertisement impressions (when an ad was shown) with +another stream of user clicks on advertisements to correlate when impressions led to +monetizable clicks. To allow the state cleanup in this stream-stream join, you will have to +specify the watermarking delays and the time constraints as follows. + +1. Watermark delays: Say, the impressions and the corresponding clicks can be late/out-of-order +in event-time by at most 2 and 3 hours, respectively. + +1. Event-time range condition: Say, a click can occur within a time range of 0 seconds to 1 hour +after the corresponding impression. + +The code would look like this. + +
    +
    + +{% highlight scala %} +import org.apache.spark.sql.functions.expr + +val impressions = spark.readStream. ... +val clicks = spark.readStream. ... + +// Apply watermarks on event-time columns +val impressionsWithWatermark = impressions.withWatermark("impressionTime", "2 hours") +val clicksWithWatermark = clicks.withWatermark("clickTime", "3 hours") + +// Join with event-time constraints +impressionsWithWatermark.join( + clicksWithWatermark, + expr(""" + clickAdId = impressionAdId AND + clickTime >= impressionTime AND + clickTime <= impressionTime + interval 1 hour + """) +) + +{% endhighlight %} + +
    +
    + +{% highlight java %} +import static org.apache.spark.sql.functions.expr + +Dataset impressions = spark.readStream(). ... +Dataset clicks = spark.readStream(). ... + +// Apply watermarks on event-time columns +Dataset impressionsWithWatermark = impressions.withWatermark("impressionTime", "2 hours"); +Dataset clicksWithWatermark = clicks.withWatermark("clickTime", "3 hours"); + +// Join with event-time constraints +impressionsWithWatermark.join( + clicksWithWatermark, + expr( + "clickAdId = impressionAdId AND " + + "clickTime >= impressionTime AND " + + "clickTime <= impressionTime + interval 1 hour ") +); + +{% endhighlight %} + + +
    +
    + +{% highlight python %} +from pyspark.sql.functions import expr + +impressions = spark.readStream. ... +clicks = spark.readStream. ... + +# Apply watermarks on event-time columns +impressionsWithWatermark = impressions.withWatermark("impressionTime", "2 hours") +clicksWithWatermark = clicks.withWatermark("clickTime", "3 hours") + +# Join with event-time constraints +impressionsWithWatermark.join( + clicksWithWatermark, + expr(""" + clickAdId = impressionAdId AND + clickTime >= impressionTime AND + clickTime <= impressionTime + interval 1 hour + """) +) + +{% endhighlight %} + +
    +
    + +##### Outer Joins with Watermarking +While the watermark + event-time constraints is optional for inner joins, for left and right outer +joins they must be specified. This is because for generating the NULL results in outer join, the +engine must know when an input row is not going to match with anything in future. Hence, the +watermark + event-time constraints must be specified for generating correct results. Therefore, +a query with outer-join will look quite like the ad-monetization example earlier, except that +there will be an additional parameter specifying it to be an outer-join. + +
    +
    + +{% highlight scala %} + +impressionsWithWatermark.join( + clicksWithWatermark, + expr(""" + clickAdId = impressionAdId AND + clickTime >= impressionTime AND + clickTime <= impressionTime + interval 1 hour + """), + joinType = "leftOuter" // can be "inner", "leftOuter", "rightOuter" + ) + +{% endhighlight %} + +
    +
    + +{% highlight java %} +impressionsWithWatermark.join( + clicksWithWatermark, + expr( + "clickAdId = impressionAdId AND " + + "clickTime >= impressionTime AND " + + "clickTime <= impressionTime + interval 1 hour "), + "leftOuter" // can be "inner", "leftOuter", "rightOuter" +); + +{% endhighlight %} + + +
    +
    + +{% highlight python %} +impressionsWithWatermark.join( + clicksWithWatermark, + expr(""" + clickAdId = impressionAdId AND + clickTime >= impressionTime AND + clickTime <= impressionTime + interval 1 hour + """), + "leftOuter" # can be "inner", "leftOuter", "rightOuter" +) + +{% endhighlight %} + +
    +
    + +However, note that the outer NULL results will be generated with a delay (depends on the specified +watermark delay and the time range condition) because the engine has to wait for that long to ensure +there were no matches and there will be no more matches in future. + +##### Support matrix for joins in streaming queries + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Left InputRight InputJoin Type
    StaticStaticAll types + Supported, since its not on streaming data even though it + can be present in a streaming query +
    StreamStaticInnerSupported, not stateful
    Left OuterSupported, not stateful
    Right OuterNot supported
    Full OuterNot supported
    StaticStreamInnerSupported, not stateful
    Left OuterNot supported
    Right OuterSupported, not stateful
    Full OuterNot supported
    StreamStreamInner + Supported, optionally specify watermark on both sides + + time constraints for state cleanup +
    Left Outer + Conditionally supported, must specify watermark on right + time constraints for correct + results, optionally specify watermark on left for all state cleanup +
    Right Outer + Conditionally supported, must specify watermark on left + time constraints for correct + results, optionally specify watermark on right for all state cleanup +
    Full OuterNot supported
    + +Additional details on supported joins: + +- Joins can be cascaded, that is, you can do `df1.join(df2, ...).join(df3, ...).join(df4, ....)`. + +- As of Spark 2.3, you can use joins only when the query is in Append output mode. Other output modes are not yet supported. + +- As of Spark 2.3, you cannot use other non-map-like operations before joins. Here are a few examples of + what cannot be used. + + - Cannot use streaming aggregations before joins. + + - Cannot use mapGroupsWithState and flatMapGroupsWithState in Update mode before joins. + + ### Streaming Deduplication You can deduplicate records in data streams using a unique identifier in the events. This is exactly same as deduplication on static using a unique identifier column. The query will store the necessary amount of data from previous records such that it can filter duplicate records. Similar to aggregations, you can use deduplication with or without watermarking. @@ -1160,15 +1466,9 @@ Some of them are as follows. - Sorting operations are supported on streaming Datasets only after an aggregation and in Complete Output Mode. -- Outer joins between a streaming and a static Datasets are conditionally supported. - - + Full outer join with a streaming Dataset is not supported - - + Left outer join with a streaming Dataset on the right is not supported - - + Right outer join with a streaming Dataset on the left is not supported - -- Any kind of joins between two streaming Datasets is not yet supported. +- Few types of outer joins on streaming Datasets are not supported. See the + support matrix in the Join Operations section + for more details. In addition, there are some Dataset methods that will not work on streaming Datasets. They are actions that will immediately run queries and return results, which does not make sense on a streaming Dataset. Rather, those functionalities can be done by explicitly starting a streaming query (see the next section regarding that). @@ -1276,6 +1576,15 @@ Here is the compatibility matrix. Aggregations not allowed after flatMapGroupsWithState. + + Queries with joins + Append + + Update and Complete mode not supported yet. See the + support matrix in the Join Operations section + for more details on what types of joins are supported. + + Other queries Append, Update @@ -2142,6 +2451,11 @@ write.stream(aggDF, "memory", outputMode = "complete", checkpointLocation = "pat **Talks** -- Spark Summit 2017 Talk - [Easy, Scalable, Fault-tolerant Stream Processing with Structured Streaming in Apache Spark](https://spark-summit.org/2017/events/easy-scalable-fault-tolerant-stream-processing-with-structured-streaming-in-apache-spark/) -- Spark Summit 2016 Talk - [A Deep Dive into Structured Streaming](https://spark-summit.org/2016/events/a-deep-dive-into-structured-streaming/) +- Spark Summit Europe 2017 + - Easy, Scalable, Fault-tolerant Stream Processing with Structured Streaming in Apache Spark - + [Part 1 slides/video](https://databricks.com/session/easy-scalable-fault-tolerant-stream-processing-with-structured-streaming-in-apache-spark), [Part 2 slides/video](https://databricks.com/session/easy-scalable-fault-tolerant-stream-processing-with-structured-streaming-in-apache-spark-continues) + - Deep Dive into Stateful Stream Processing in Structured Streaming - [slides/video](https://databricks.com/session/deep-dive-into-stateful-stream-processing-in-structured-streaming) +- Spark Summit 2016 + - A Deep Dive into Structured Streaming - [slides/video](https://spark-summit.org/2016/events/a-deep-dive-into-structured-streaming/) + From 02194702068291b3af77486d01029fb848c36d7b Mon Sep 17 00:00:00 2001 From: Xiayun Sun Date: Wed, 17 Jan 2018 16:42:38 -0800 Subject: [PATCH 0126/2461] [SPARK-21996][SQL] read files with space in name for streaming ## What changes were proposed in this pull request? Structured streaming is now able to read files with space in file name (previously it would skip the file and output a warning) ## How was this patch tested? Added new unit test. Author: Xiayun Sun Closes #19247 from xysun/SPARK-21996. --- .../streaming/FileStreamSource.scala | 2 +- .../sql/streaming/FileStreamSourceSuite.scala | 50 ++++++++++++++++++- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 0debd7db84757..8c016abc5b643 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -166,7 +166,7 @@ class FileStreamSource( val newDataSource = DataSource( sparkSession, - paths = files.map(_.path), + paths = files.map(f => new Path(new URI(f.path)).toString), userSpecifiedSchema = Some(schema), partitionColumns = partitionColumns, className = fileFormatClassName, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 39bb572740617..5bb0f4d643bbe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -74,11 +74,11 @@ abstract class FileStreamSourceTest protected def addData(source: FileStreamSource): Unit } - case class AddTextFileData(content: String, src: File, tmp: File) + case class AddTextFileData(content: String, src: File, tmp: File, tmpFilePrefix: String = "text") extends AddFileData { override def addData(source: FileStreamSource): Unit = { - val tempFile = Utils.tempFileWith(new File(tmp, "text")) + val tempFile = Utils.tempFileWith(new File(tmp, tmpFilePrefix)) val finalFile = new File(src, tempFile.getName) src.mkdirs() require(stringToFile(tempFile, content).renameTo(finalFile)) @@ -408,6 +408,52 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } + test("SPARK-21996 read from text files -- file name has space") { + withTempDirs { case (src, tmp) => + val textStream = createFileStream("text", src.getCanonicalPath) + val filtered = textStream.filter($"value" contains "keep") + + testStream(filtered)( + AddTextFileData("drop1\nkeep2\nkeep3", src, tmp, "text text"), + CheckAnswer("keep2", "keep3") + ) + } + } + + test("SPARK-21996 read from text files generated by file sink -- file name has space") { + val testTableName = "FileStreamSourceTest" + withTable(testTableName) { + withTempDirs { case (src, checkpoint) => + val output = new File(src, "text text") + val inputData = MemoryStream[String] + val ds = inputData.toDS() + + val query = ds.writeStream + .option("checkpointLocation", checkpoint.getCanonicalPath) + .format("text") + .start(output.getCanonicalPath) + + try { + inputData.addData("foo") + failAfter(streamingTimeout) { + query.processAllAvailable() + } + } finally { + query.stop() + } + + val df2 = spark.readStream.format("text").load(output.getCanonicalPath) + val query2 = df2.writeStream.format("memory").queryName(testTableName).start() + try { + query2.processAllAvailable() + checkDatasetUnorderly(spark.table(testTableName).as[String], "foo") + } finally { + query2.stop() + } + } + } + } + test("read from textfile") { withTempDirs { case (src, tmp) => val textStream = spark.readStream.textFile(src.getCanonicalPath) From 39d244d921d8d2d3ed741e8e8f1175515a74bdbd Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 18 Jan 2018 14:51:05 +0900 Subject: [PATCH 0127/2461] [SPARK-23122][PYTHON][SQL] Deprecate register* for UDFs in SQLContext and Catalog in PySpark ## What changes were proposed in this pull request? This PR proposes to deprecate `register*` for UDFs in `SQLContext` and `Catalog` in Spark 2.3.0. These are inconsistent with Scala / Java APIs and also these basically do the same things with `spark.udf.register*`. Also, this PR moves the logcis from `[sqlContext|spark.catalog].register*` to `spark.udf.register*` and reuse the docstring. This PR also handles minor doc corrections. It also includes https://github.com/apache/spark/pull/20158 ## How was this patch tested? Manually tested, manually checked the API documentation and tests added to check if deprecated APIs call the aliases correctly. Author: hyukjinkwon Closes #20288 from HyukjinKwon/deprecate-udf. --- dev/sparktestsupport/modules.py | 1 + python/pyspark/sql/catalog.py | 91 ++-------------- python/pyspark/sql/context.py | 137 ++++-------------------- python/pyspark/sql/functions.py | 4 +- python/pyspark/sql/group.py | 3 +- python/pyspark/sql/session.py | 6 +- python/pyspark/sql/tests.py | 20 ++++ python/pyspark/sql/udf.py | 182 +++++++++++++++++++++++++++++++- 8 files changed, 234 insertions(+), 210 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 7164180a6a7b0..b900f0bd913c3 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -400,6 +400,7 @@ def __hash__(self): "pyspark.sql.functions", "pyspark.sql.readwriter", "pyspark.sql.streaming", + "pyspark.sql.udf", "pyspark.sql.window", "pyspark.sql.tests", ] diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 35fbe9e669adb..6aef0f22340be 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -224,92 +224,17 @@ def dropGlobalTempView(self, viewName): """ self._jcatalog.dropGlobalTempView(viewName) - @ignore_unicode_prefix @since(2.0) def registerFunction(self, name, f, returnType=None): - """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` - as a UDF. The registered UDF can be used in SQL statements. - - :func:`spark.udf.register` is an alias for :func:`spark.catalog.registerFunction`. - - In addition to a name and the function itself, `returnType` can be optionally specified. - 1) When f is a Python function, `returnType` defaults to a string. The produced object must - match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return - type of the given UDF as the return type of the registered UDF. The input parameter - `returnType` is None by default. If given by users, the value must be None. - - :param name: name of the UDF in SQL statements. - :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either - row-at-a-time or vectorized. - :param returnType: the return type of the registered UDF. - :return: a wrapped/native :class:`UserDefinedFunction` - - >>> strlen = spark.catalog.registerFunction("stringLengthString", len) - >>> spark.sql("SELECT stringLengthString('test')").collect() - [Row(stringLengthString(test)=u'4')] - - >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect() - [Row(stringLengthString(text)=u'3')] - - >>> from pyspark.sql.types import IntegerType - >>> _ = spark.catalog.registerFunction("stringLengthInt", len, IntegerType()) - >>> spark.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> _ = spark.udf.register("stringLengthInt", len, IntegerType()) - >>> spark.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> from pyspark.sql.functions import udf - >>> slen = udf(lambda s: len(s), IntegerType()) - >>> _ = spark.udf.register("slen", slen) - >>> spark.sql("SELECT slen('test')").collect() - [Row(slen(test)=4)] - - >>> import random - >>> from pyspark.sql.functions import udf - >>> from pyspark.sql.types import IntegerType - >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() - >>> new_random_udf = spark.catalog.registerFunction("random_udf", random_udf) - >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP - [Row(random_udf()=82)] - >>> spark.range(1).select(new_random_udf()).collect() # doctest: +SKIP - [Row(()=26)] - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP - ... def add_one(x): - ... return x + 1 - ... - >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP - >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP - [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] - """ + """An alias for :func:`spark.udf.register`. + See :meth:`pyspark.sql.UDFRegistration.register`. - # This is to check whether the input function is a wrapped/native UserDefinedFunction - if hasattr(f, 'asNondeterministic'): - if returnType is not None: - raise TypeError( - "Invalid returnType: None is expected when f is a UserDefinedFunction, " - "but got %s." % returnType) - if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, - PythonEvalType.SQL_PANDAS_SCALAR_UDF]: - raise ValueError( - "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF") - register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, - evalType=f.evalType, - deterministic=f.deterministic) - return_udf = f - else: - if returnType is None: - returnType = StringType() - register_udf = UserDefinedFunction(f, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF) - return_udf = register_udf._wrapped() - self._jsparkSession.udf().registerPython(name, register_udf._judf) - return return_udf + .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead. + """ + warnings.warn( + "Deprecated in 2.3.0. Use spark.udf.register instead.", + DeprecationWarning) + return self._sparkSession.udf.register(name, f, returnType) @since(2.0) def isCached(self, tableName): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 85479095af594..cc1cd1a5842d9 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -29,9 +29,10 @@ from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.streaming import DataStreamReader from pyspark.sql.types import IntegerType, Row, StringType +from pyspark.sql.udf import UDFRegistration from pyspark.sql.utils import install_exception_handler -__all__ = ["SQLContext", "HiveContext", "UDFRegistration"] +__all__ = ["SQLContext", "HiveContext"] class SQLContext(object): @@ -147,7 +148,7 @@ def udf(self): :return: :class:`UDFRegistration` """ - return UDFRegistration(self) + return self.sparkSession.udf @since(1.4) def range(self, start, end=None, step=1, numPartitions=None): @@ -172,113 +173,29 @@ def range(self, start, end=None, step=1, numPartitions=None): """ return self.sparkSession.range(start, end, step, numPartitions) - @ignore_unicode_prefix @since(1.2) def registerFunction(self, name, f, returnType=None): - """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` - as a UDF. The registered UDF can be used in SQL statements. - - :func:`spark.udf.register` is an alias for :func:`sqlContext.registerFunction`. - - In addition to a name and the function itself, `returnType` can be optionally specified. - 1) When f is a Python function, `returnType` defaults to a string. The produced object must - match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return - type of the given UDF as the return type of the registered UDF. The input parameter - `returnType` is None by default. If given by users, the value must be None. - - :param name: name of the UDF in SQL statements. - :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either - row-at-a-time or vectorized. - :param returnType: the return type of the registered UDF. - :return: a wrapped/native :class:`UserDefinedFunction` - - >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x)) - >>> sqlContext.sql("SELECT stringLengthString('test')").collect() - [Row(stringLengthString(test)=u'4')] - - >>> sqlContext.sql("SELECT 'foo' AS text").select(strlen("text")).collect() - [Row(stringLengthString(text)=u'3')] - - >>> from pyspark.sql.types import IntegerType - >>> _ = sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> from pyspark.sql.functions import udf - >>> slen = udf(lambda s: len(s), IntegerType()) - >>> _ = sqlContext.udf.register("slen", slen) - >>> sqlContext.sql("SELECT slen('test')").collect() - [Row(slen(test)=4)] - - >>> import random - >>> from pyspark.sql.functions import udf - >>> from pyspark.sql.types import IntegerType - >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() - >>> new_random_udf = sqlContext.registerFunction("random_udf", random_udf) - >>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP - [Row(random_udf()=82)] - >>> sqlContext.range(1).select(new_random_udf()).collect() # doctest: +SKIP - [Row(()=26)] - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP - ... def add_one(x): - ... return x + 1 - ... - >>> _ = sqlContext.udf.register("add_one", add_one) # doctest: +SKIP - >>> sqlContext.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP - [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] + """An alias for :func:`spark.udf.register`. + See :meth:`pyspark.sql.UDFRegistration.register`. + + .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead. """ - return self.sparkSession.catalog.registerFunction(name, f, returnType) + warnings.warn( + "Deprecated in 2.3.0. Use spark.udf.register instead.", + DeprecationWarning) + return self.sparkSession.udf.register(name, f, returnType) - @ignore_unicode_prefix @since(2.1) def registerJavaFunction(self, name, javaClassName, returnType=None): - """Register a java UDF so it can be used in SQL statements. - - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not specified we would infer it via reflection. - :param name: name of the UDF - :param javaClassName: fully qualified name of java class - :param returnType: a :class:`pyspark.sql.types.DataType` object - - >>> sqlContext.registerJavaFunction("javaStringLength", - ... "test.org.apache.spark.sql.JavaStringLength", IntegerType()) - >>> sqlContext.sql("SELECT javaStringLength('test')").collect() - [Row(UDF:javaStringLength(test)=4)] - >>> sqlContext.registerJavaFunction("javaStringLength2", - ... "test.org.apache.spark.sql.JavaStringLength") - >>> sqlContext.sql("SELECT javaStringLength2('test')").collect() - [Row(UDF:javaStringLength2(test)=4)] + """An alias for :func:`spark.udf.registerJavaFunction`. + See :meth:`pyspark.sql.UDFRegistration.registerJavaFunction`. + .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaFunction` instead. """ - jdt = None - if returnType is not None: - jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) - self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) - - @ignore_unicode_prefix - @since(2.3) - def registerJavaUDAF(self, name, javaClassName): - """Register a java UDAF so it can be used in SQL statements. - - :param name: name of the UDAF - :param javaClassName: fully qualified name of java class - - >>> sqlContext.registerJavaUDAF("javaUDAF", - ... "test.org.apache.spark.sql.MyDoubleAvg") - >>> df = sqlContext.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) - >>> df.registerTempTable("df") - >>> sqlContext.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() - [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] - """ - self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) + warnings.warn( + "Deprecated in 2.3.0. Use spark.udf.registerJavaFunction instead.", + DeprecationWarning) + return self.sparkSession.udf.registerJavaFunction(name, javaClassName, returnType) # TODO(andrew): delete this once we refactor things to take in SparkSession def _inferSchema(self, rdd, samplingRatio=None): @@ -590,24 +507,6 @@ def refreshTable(self, tableName): self._ssql_ctx.refreshTable(tableName) -class UDFRegistration(object): - """Wrapper for user-defined function registration.""" - - def __init__(self, sqlContext): - self.sqlContext = sqlContext - - def register(self, name, f, returnType=None): - return self.sqlContext.registerFunction(name, f, returnType) - - def registerJavaFunction(self, name, javaClassName, returnType=None): - self.sqlContext.registerJavaFunction(name, javaClassName, returnType) - - def registerJavaUDAF(self, name, javaClassName): - self.sqlContext.registerJavaUDAF(name, javaClassName) - - register.__doc__ = SQLContext.registerFunction.__doc__ - - def _test(): import os import doctest diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f7b3f29764040..988c1d25259bc 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2103,7 +2103,7 @@ def udf(f=None, returnType=StringType()): >>> import random >>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic() - .. note:: The user-defined functions do not support conditional expressions or short curcuiting + .. note:: The user-defined functions do not support conditional expressions or short circuiting in boolean expressions and it ends up with being executed all internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions. @@ -2231,7 +2231,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): ... return pd.Series(np.random.randn(len(v)) >>> random = random.asNondeterministic() # doctest: +SKIP - .. note:: The user-defined functions do not support conditional expressions or short curcuiting + .. note:: The user-defined functions do not support conditional expressions or short circuiting in boolean expressions and it ends up with being executed all internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions. """ diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 09fae46adf014..22061b83eb78c 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -212,7 +212,8 @@ def apply(self, udf): This function does not support partial aggregation, and requires shuffling all the data in the :class:`DataFrame`. - :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf` + :param udf: a group map user-defined function returned by + :meth:`pyspark.sql.functions.pandas_udf`. >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 604021c1f45cc..6c84023c43fb6 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -29,7 +29,6 @@ from pyspark import since from pyspark.rdd import RDD, ignore_unicode_prefix -from pyspark.sql.catalog import Catalog from pyspark.sql.conf import RuntimeConfig from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader @@ -280,6 +279,7 @@ def catalog(self): :return: :class:`Catalog` """ + from pyspark.sql.catalog import Catalog if not hasattr(self, "_catalog"): self._catalog = Catalog(self) return self._catalog @@ -291,8 +291,8 @@ def udf(self): :return: :class:`UDFRegistration` """ - from pyspark.sql.context import UDFRegistration - return UDFRegistration(self._wrapped) + from pyspark.sql.udf import UDFRegistration + return UDFRegistration(self) @since(2.0) def range(self, start, end=None, step=1, numPartitions=None): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 8906618666b14..f84aa3d68b808 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -372,6 +372,12 @@ def test_udf(self): [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() self.assertEqual(row[0], 5) + # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias. + sqlContext = self.spark._wrapped + sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType()) + [row] = sqlContext.sql("SELECT oneArg('test')").collect() + self.assertEqual(row[0], 4) + def test_udf2(self): self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType()) self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\ @@ -577,11 +583,25 @@ def test_udf_registration_returns_udf(self): df.select(add_three("id").alias("plus_three")).collect() ) + # This is to check if a 'SQLContext.udf' can call its alias. + sqlContext = self.spark._wrapped + add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType()) + + self.assertListEqual( + df.selectExpr("add_four(id) AS plus_four").collect(), + df.select(add_four("id").alias("plus_four")).collect() + ) + def test_non_existed_udf(self): spark = self.spark self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf")) + # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias. + sqlContext = spark._wrapped + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", + lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf")) + def test_non_existed_udaf(self): spark = self.spark self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 5e80ab9165867..1943bb73f9ac2 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -19,11 +19,13 @@ """ import functools -from pyspark import SparkContext -from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType +from pyspark import SparkContext, since +from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string +__all__ = ["UDFRegistration"] + def _wrap_function(sc, func, returnType): command = (func, returnType) @@ -181,3 +183,179 @@ def asNondeterministic(self): """ self.deterministic = False return self + + +class UDFRegistration(object): + """ + Wrapper for user-defined function registration. This instance can be accessed by + :attr:`spark.udf` or :attr:`sqlContext.udf`. + + .. versionadded:: 1.3.1 + """ + + def __init__(self, sparkSession): + self.sparkSession = sparkSession + + @ignore_unicode_prefix + @since("1.3.1") + def register(self, name, f, returnType=None): + """Registers a Python function (including lambda function) or a user-defined function + in SQL statements. + + :param name: name of the user-defined function in SQL statements. + :param f: a Python function, or a user-defined function. The user-defined function can + be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and + :meth:`pyspark.sql.functions.pandas_udf`. + :param returnType: the return type of the registered user-defined function. + :return: a user-defined function. + + `returnType` can be optionally specified when `f` is a Python function but not + when `f` is a user-defined function. Please see below. + + 1. When `f` is a Python function: + + `returnType` defaults to string type and can be optionally specified. The produced + object must match the specified type. In this case, this API works as if + `register(name, f, returnType=StringType())`. + + >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x)) + >>> spark.sql("SELECT stringLengthString('test')").collect() + [Row(stringLengthString(test)=u'4')] + + >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect() + [Row(stringLengthString(text)=u'3')] + + >>> from pyspark.sql.types import IntegerType + >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> spark.sql("SELECT stringLengthInt('test')").collect() + [Row(stringLengthInt(test)=4)] + + >>> from pyspark.sql.types import IntegerType + >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> spark.sql("SELECT stringLengthInt('test')").collect() + [Row(stringLengthInt(test)=4)] + + 2. When `f` is a user-defined function: + + Spark uses the return type of the given user-defined function as the return type of + the registered user-defined function. `returnType` should not be specified. + In this case, this API works as if `register(name, f)`. + + >>> from pyspark.sql.types import IntegerType + >>> from pyspark.sql.functions import udf + >>> slen = udf(lambda s: len(s), IntegerType()) + >>> _ = spark.udf.register("slen", slen) + >>> spark.sql("SELECT slen('test')").collect() + [Row(slen(test)=4)] + + >>> import random + >>> from pyspark.sql.functions import udf + >>> from pyspark.sql.types import IntegerType + >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() + >>> new_random_udf = spark.udf.register("random_udf", random_udf) + >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP + [Row(random_udf()=82)] + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP + ... def add_one(x): + ... return x + 1 + ... + >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP + >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP + [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] + + .. note:: Registration for a user-defined function (case 2.) was added from + Spark 2.3.0. + """ + + # This is to check whether the input function is from a user-defined function or + # Python function. + if hasattr(f, 'asNondeterministic'): + if returnType is not None: + raise TypeError( + "Invalid returnType: data type can not be specified when f is" + "a user-defined function, but got %s." % returnType) + if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_PANDAS_SCALAR_UDF]: + raise ValueError( + "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF") + register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, + evalType=f.evalType, + deterministic=f.deterministic) + return_udf = f + else: + if returnType is None: + returnType = StringType() + register_udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) + return_udf = register_udf._wrapped() + self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf) + return return_udf + + @ignore_unicode_prefix + @since(2.3) + def registerJavaFunction(self, name, javaClassName, returnType=None): + """Register a Java user-defined function so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not specified we would infer it via reflection. + + :param name: name of the user-defined function + :param javaClassName: fully qualified name of java class + :param returnType: a :class:`pyspark.sql.types.DataType` object + + >>> from pyspark.sql.types import IntegerType + >>> spark.udf.registerJavaFunction( + ... "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType()) + >>> spark.sql("SELECT javaStringLength('test')").collect() + [Row(UDF:javaStringLength(test)=4)] + >>> spark.udf.registerJavaFunction( + ... "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength") + >>> spark.sql("SELECT javaStringLength2('test')").collect() + [Row(UDF:javaStringLength2(test)=4)] + """ + + jdt = None + if returnType is not None: + jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) + self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) + + @ignore_unicode_prefix + @since(2.3) + def registerJavaUDAF(self, name, javaClassName): + """Register a Java user-defined aggregate function so it can be used in SQL statements. + + :param name: name of the user-defined aggregate function + :param javaClassName: fully qualified name of java class + + >>> spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg") + >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) + >>> df.registerTempTable("df") + >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() + [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] + """ + + self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) + + +def _test(): + import doctest + from pyspark.sql import SparkSession + import pyspark.sql.udf + globs = pyspark.sql.udf.__dict__.copy() + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("sql.udf tests")\ + .getOrCreate() + globs['spark'] = spark + (failure_count, test_count) = doctest.testmod( + pyspark.sql.udf, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + spark.stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() From 1c76a91e5fae11dcb66c453889e587b48039fdc9 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 17 Jan 2018 22:36:29 -0800 Subject: [PATCH 0128/2461] [SPARK-23052][SS] Migrate ConsoleSink to data source V2 api. ## What changes were proposed in this pull request? Migrate ConsoleSink to data source V2 api. Note that this includes a missing piece in DataStreamWriter required to specify a data source V2 writer. Note also that I've removed the "Rerun batch" part of the sink, because as far as I can tell this would never have actually happened. A MicroBatchExecution object will only commit each batch once for its lifetime, and a new MicroBatchExecution object would have a new ConsoleSink object which doesn't know it's retrying a batch. So I think this represents an anti-feature rather than a weakness in the V2 API. ## How was this patch tested? new unit test Author: Jose Torres Closes #20243 from jose-torres/console-sink. --- .../streaming/MicroBatchExecution.scala | 7 +- .../sql/execution/streaming/console.scala | 62 ++--- .../continuous/ContinuousExecution.scala | 9 +- .../streaming/sources/ConsoleWriter.scala | 64 +++++ .../sources/PackedRowWriterFactory.scala | 60 +++++ .../sql/streaming/DataStreamWriter.scala | 16 +- ...pache.spark.sql.sources.DataSourceRegister | 8 + .../sources/ConsoleWriterSuite.scala | 135 ++++++++++ .../sources/StreamingDataSourceV2Suite.scala | 249 ++++++++++++++++++ .../test/DataStreamReaderWriterSuite.scala | 25 -- 10 files changed, 551 insertions(+), 84 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 70407f0580f97..7c3804547b736 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -91,11 +91,14 @@ class MicroBatchExecution( nextSourceId += 1 StreamingExecutionRelation(reader, output)(sparkSession) }) - case s @ StreamingRelationV2(_, _, _, output, v1Relation) => + case s @ StreamingRelationV2(_, sourceName, _, output, v1Relation) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - assert(v1Relation.isDefined, "v2 execution didn't match but v1 was unavailable") + if (v1Relation.isEmpty) { + throw new UnsupportedOperationException( + s"Data source $sourceName does not support microbatch processing.") + } val source = v1Relation.get.dataSource.createSource(metadataPath) nextSourceId += 1 StreamingExecutionRelation(source, output)(sparkSession) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 71eaabe273fea..94820376ff7e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -17,58 +17,36 @@ package org.apache.spark.sql.execution.streaming -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} -import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, StreamSinkProvider} -import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.StructType - -class ConsoleSink(options: Map[String, String]) extends Sink with Logging { - // Number of rows to display, by default 20 rows - private val numRowsToShow = options.get("numRows").map(_.toInt).getOrElse(20) - - // Truncate the displayed data if it is too long, by default it is true - private val isTruncated = options.get("truncate").map(_.toBoolean).getOrElse(true) +import java.util.Optional - // Track the batch id - private var lastBatchId = -1L - - override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized { - val batchIdStr = if (batchId <= lastBatchId) { - s"Rerun batch: $batchId" - } else { - lastBatchId = batchId - s"Batch: $batchId" - } - - // scalastyle:off println - println("-------------------------------------------") - println(batchIdStr) - println("-------------------------------------------") - // scalastyle:off println - data.sparkSession.createDataFrame( - data.sparkSession.sparkContext.parallelize(data.collect()), data.schema) - .show(numRowsToShow, isTruncated) - } +import scala.collection.JavaConverters._ - override def toString(): String = s"ConsoleSink[numRows=$numRowsToShow, truncate=$isTruncated]" -} +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter +import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport +import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) extends BaseRelation { override def schema: StructType = data.schema } -class ConsoleSinkProvider extends StreamSinkProvider +class ConsoleSinkProvider extends DataSourceV2 + with MicroBatchWriteSupport with DataSourceRegister with CreatableRelationProvider { - def createSink( - sqlContext: SQLContext, - parameters: Map[String, String], - partitionColumns: Seq[String], - outputMode: OutputMode): Sink = { - new ConsoleSink(parameters) + + override def createMicroBatchWriter( + queryId: String, + epochId: Long, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[DataSourceV2Writer] = { + Optional.of(new ConsoleWriter(epochId, schema, options)) } def createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index c0507224f9be8..462e7d9721d28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -54,16 +54,13 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReader] = _ + @volatile protected var continuousSources: Seq[ContinuousReader] = Seq() override protected def sources: Seq[BaseStreamingSource] = continuousSources // For use only in test harnesses. private[sql] var currentEpochCoordinatorId: String = _ - override lazy val logicalPlan: LogicalPlan = { - assert(queryExecutionThread eq Thread.currentThread, - "logicalPlan must be initialized in StreamExecutionThread " + - s"but the current thread was ${Thread.currentThread}") + override val logicalPlan: LogicalPlan = { val toExecutionRelationMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]() analyzedPlan.transform { case r @ StreamingRelationV2( @@ -72,7 +69,7 @@ class ContinuousExecution( ContinuousExecutionRelation(source, extraReaderOptions, output)(sparkSession) }) case StreamingRelationV2(_, sourceName, _, _, _) => - throw new AnalysisException( + throw new UnsupportedOperationException( s"Data source $sourceName does not support continuous processing.") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala new file mode 100644 index 0000000000000..361979984bbec --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.types.StructType + +/** + * A [[DataSourceV2Writer]] that collects results to the driver and prints them in the console. + * Generated by [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. + * + * This sink should not be used for production, as it requires sending all rows to the driver + * and does not support recovery. + */ +class ConsoleWriter(batchId: Long, schema: StructType, options: DataSourceV2Options) + extends DataSourceV2Writer with Logging { + // Number of rows to display, by default 20 rows + private val numRowsToShow = options.getInt("numRows", 20) + + // Truncate the displayed data if it is too long, by default it is true + private val isTruncated = options.getBoolean("truncate", true) + + assert(SparkSession.getActiveSession.isDefined) + private val spark = SparkSession.getActiveSession.get + + override def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory + + override def commit(messages: Array[WriterCommitMessage]): Unit = synchronized { + val batch = messages.collect { + case PackedRowCommitMessage(rows) => rows + }.flatten + + // scalastyle:off println + println("-------------------------------------------") + println(s"Batch: $batchId") + println("-------------------------------------------") + // scalastyle:off println + spark.createDataFrame( + spark.sparkContext.parallelize(batch), schema) + .show(numRowsToShow, isTruncated) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = {} + + override def toString(): String = s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala new file mode 100644 index 0000000000000..9282ba05bdb7b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Row +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} + +/** + * A simple [[DataWriterFactory]] whose tasks just pack rows into the commit message for delivery + * to a [[org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer]] on the driver. + * + * Note that, because it sends all rows to the driver, this factory will generally be unsuitable + * for production-quality sinks. It's intended for use in tests. + */ +case object PackedRowWriterFactory extends DataWriterFactory[Row] { + def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + new PackedRowDataWriter() + } +} + +/** + * Commit message for a [[PackedRowDataWriter]], containing all the rows written in the most + * recent interval. + */ +case class PackedRowCommitMessage(rows: Array[Row]) extends WriterCommitMessage + +/** + * A simple [[DataWriter]] that just sends all the rows it's received as a commit message. + */ +class PackedRowDataWriter() extends DataWriter[Row] with Logging { + private val data = mutable.Buffer[Row]() + + override def write(row: Row): Unit = data.append(row) + + override def commit(): PackedRowCommitMessage = { + val msg = PackedRowCommitMessage(data.toArray) + data.clear() + msg + } + + override def abort(): Unit = data.clear() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index b5b4a05ab4973..d24f0ddeab4de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport +import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -280,14 +280,12 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val sink = trigger match { - case _: ContinuousTrigger => - val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) - ds.newInstance() match { - case w: ContinuousWriteSupport => w - case _ => throw new AnalysisException( - s"Data source $source does not support continuous writing") - } + val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) + val sink = (ds.newInstance(), trigger) match { + case (w: ContinuousWriteSupport, _: ContinuousTrigger) => w + case (_, _: ContinuousTrigger) => throw new UnsupportedOperationException( + s"Data source $source does not support continuous writing") + case (w: MicroBatchWriteSupport, _) => w case _ => val ds = DataSource( df.sparkSession, diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index c6973bf41d34b..a0b25b4e82364 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,3 +5,11 @@ org.apache.spark.sql.sources.FakeSourceFour org.apache.fakesource.FakeExternalSourceOne org.apache.fakesource.FakeExternalSourceTwo org.apache.fakesource.FakeExternalSourceThree +org.apache.spark.sql.streaming.sources.FakeReadMicroBatchOnly +org.apache.spark.sql.streaming.sources.FakeReadContinuousOnly +org.apache.spark.sql.streaming.sources.FakeReadBothModes +org.apache.spark.sql.streaming.sources.FakeReadNeitherMode +org.apache.spark.sql.streaming.sources.FakeWriteMicroBatchOnly +org.apache.spark.sql.streaming.sources.FakeWriteContinuousOnly +org.apache.spark.sql.streaming.sources.FakeWriteBothModes +org.apache.spark.sql.streaming.sources.FakeWriteNeitherMode diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala new file mode 100644 index 0000000000000..60ffee9b9b42c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io.ByteArrayOutputStream + +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming.StreamTest + +class ConsoleWriterSuite extends StreamTest { + import testImplicits._ + + test("console") { + val input = MemoryStream[Int] + + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val query = input.toDF().writeStream.format("console").start() + try { + input.addData(1, 2, 3) + query.processAllAvailable() + input.addData(4, 5, 6) + query.processAllAvailable() + input.addData() + query.processAllAvailable() + } finally { + query.stop() + } + } + + assert(captured.toString() == + """------------------------------------------- + |Batch: 0 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + || 1| + || 2| + || 3| + |+-----+ + | + |------------------------------------------- + |Batch: 1 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + || 4| + || 5| + || 6| + |+-----+ + | + |------------------------------------------- + |Batch: 2 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + |+-----+ + | + |""".stripMargin) + } + + test("console with numRows") { + val input = MemoryStream[Int] + + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val query = input.toDF().writeStream.format("console").option("NUMROWS", 2).start() + try { + input.addData(1, 2, 3) + query.processAllAvailable() + } finally { + query.stop() + } + } + + assert(captured.toString() == + """------------------------------------------- + |Batch: 0 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + || 1| + || 2| + |+-----+ + |only showing top 2 rows + | + |""".stripMargin) + } + + test("console with truncation") { + val input = MemoryStream[String] + + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val query = input.toDF().writeStream.format("console").option("TRUNCATE", true).start() + try { + input.addData("123456789012345678901234567890") + query.processAllAvailable() + } finally { + query.stop() + } + } + + assert(captured.toString() == + """------------------------------------------- + |Batch: 0 + |------------------------------------------- + |+--------------------+ + || value| + |+--------------------+ + ||12345678901234567...| + |+--------------------+ + | + |""".stripMargin) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala new file mode 100644 index 0000000000000..f152174b0a7f0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.sources + +import java.util.Optional + +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming.{LongOffset, RateStreamOffset} +import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.reader.ReadTask +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport, MicroBatchReadSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer +import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest, Trigger} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils + +case class FakeReader() extends MicroBatchReader with ContinuousReader { + def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = {} + def getStartOffset: Offset = RateStreamOffset(Map()) + def getEndOffset: Offset = RateStreamOffset(Map()) + def deserializeOffset(json: String): Offset = RateStreamOffset(Map()) + def commit(end: Offset): Unit = {} + def readSchema(): StructType = StructType(Seq()) + def stop(): Unit = {} + def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) + def setOffset(start: Optional[Offset]): Unit = {} + + def createReadTasks(): java.util.ArrayList[ReadTask[Row]] = { + throw new IllegalStateException("fake source - cannot actually read") + } +} + +trait FakeMicroBatchReadSupport extends MicroBatchReadSupport { + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceV2Options): MicroBatchReader = FakeReader() +} + +trait FakeContinuousReadSupport extends ContinuousReadSupport { + override def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceV2Options): ContinuousReader = FakeReader() +} + +trait FakeMicroBatchWriteSupport extends MicroBatchWriteSupport { + def createMicroBatchWriter( + queryId: String, + epochId: Long, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[DataSourceV2Writer] = { + throw new IllegalStateException("fake sink - cannot actually write") + } +} + +trait FakeContinuousWriteSupport extends ContinuousWriteSupport { + def createContinuousWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[ContinuousWriter] = { + throw new IllegalStateException("fake sink - cannot actually write") + } +} + +class FakeReadMicroBatchOnly extends DataSourceRegister with FakeMicroBatchReadSupport { + override def shortName(): String = "fake-read-microbatch-only" +} + +class FakeReadContinuousOnly extends DataSourceRegister with FakeContinuousReadSupport { + override def shortName(): String = "fake-read-continuous-only" +} + +class FakeReadBothModes extends DataSourceRegister + with FakeMicroBatchReadSupport with FakeContinuousReadSupport { + override def shortName(): String = "fake-read-microbatch-continuous" +} + +class FakeReadNeitherMode extends DataSourceRegister { + override def shortName(): String = "fake-read-neither-mode" +} + +class FakeWriteMicroBatchOnly extends DataSourceRegister with FakeMicroBatchWriteSupport { + override def shortName(): String = "fake-write-microbatch-only" +} + +class FakeWriteContinuousOnly extends DataSourceRegister with FakeContinuousWriteSupport { + override def shortName(): String = "fake-write-continuous-only" +} + +class FakeWriteBothModes extends DataSourceRegister + with FakeMicroBatchWriteSupport with FakeContinuousWriteSupport { + override def shortName(): String = "fake-write-microbatch-continuous" +} + +class FakeWriteNeitherMode extends DataSourceRegister { + override def shortName(): String = "fake-write-neither-mode" +} + +class StreamingDataSourceV2Suite extends StreamTest { + + override def beforeAll(): Unit = { + super.beforeAll() + val fakeCheckpoint = Utils.createTempDir() + spark.conf.set("spark.sql.streaming.checkpointLocation", fakeCheckpoint.getCanonicalPath) + } + + val readFormats = Seq( + "fake-read-microbatch-only", + "fake-read-continuous-only", + "fake-read-microbatch-continuous", + "fake-read-neither-mode") + val writeFormats = Seq( + "fake-write-microbatch-only", + "fake-write-continuous-only", + "fake-write-microbatch-continuous", + "fake-write-neither-mode") + val triggers = Seq( + Trigger.Once(), + Trigger.ProcessingTime(1000), + Trigger.Continuous(1000)) + + private def testPositiveCase(readFormat: String, writeFormat: String, trigger: Trigger) = { + val query = spark.readStream + .format(readFormat) + .load() + .writeStream + .format(writeFormat) + .trigger(trigger) + .start() + query.stop() + } + + private def testNegativeCase( + readFormat: String, + writeFormat: String, + trigger: Trigger, + errorMsg: String) = { + val ex = intercept[UnsupportedOperationException] { + testPositiveCase(readFormat, writeFormat, trigger) + } + assert(ex.getMessage.contains(errorMsg)) + } + + private def testPostCreationNegativeCase( + readFormat: String, + writeFormat: String, + trigger: Trigger, + errorMsg: String) = { + val query = spark.readStream + .format(readFormat) + .load() + .writeStream + .format(writeFormat) + .trigger(trigger) + .start() + + eventually(timeout(streamingTimeout)) { + assert(query.exception.isDefined) + assert(query.exception.get.cause != null) + assert(query.exception.get.cause.getMessage.contains(errorMsg)) + } + } + + // Get a list of (read, write, trigger) tuples for test cases. + val cases = readFormats.flatMap { read => + writeFormats.flatMap { write => + triggers.map(t => (write, t)) + }.map { + case (write, t) => (read, write, t) + } + } + + for ((read, write, trigger) <- cases) { + testQuietly(s"stream with read format $read, write format $write, trigger $trigger") { + val readSource = DataSource.lookupDataSource(read, spark.sqlContext.conf).newInstance() + val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf).newInstance() + (readSource, writeSource, trigger) match { + // Valid microbatch queries. + case (_: MicroBatchReadSupport, _: MicroBatchWriteSupport, t) + if !t.isInstanceOf[ContinuousTrigger] => + testPositiveCase(read, write, trigger) + + // Valid continuous queries. + case (_: ContinuousReadSupport, _: ContinuousWriteSupport, _: ContinuousTrigger) => + testPositiveCase(read, write, trigger) + + // Invalid - can't read at all + case (r, _, _) + if !r.isInstanceOf[MicroBatchReadSupport] + && !r.isInstanceOf[ContinuousReadSupport] => + testNegativeCase(read, write, trigger, + s"Data source $read does not support streamed reading") + + // Invalid - trigger is continuous but writer is not + case (_, w, _: ContinuousTrigger) if !w.isInstanceOf[ContinuousWriteSupport] => + testNegativeCase(read, write, trigger, + s"Data source $write does not support continuous writing") + + // Invalid - can't write at all + case (_, w, _) + if !w.isInstanceOf[MicroBatchWriteSupport] + && !w.isInstanceOf[ContinuousWriteSupport] => + testNegativeCase(read, write, trigger, + s"Data source $write does not support streamed writing") + + // Invalid - trigger and writer are continuous but reader is not + case (r, _: ContinuousWriteSupport, _: ContinuousTrigger) + if !r.isInstanceOf[ContinuousReadSupport] => + testNegativeCase(read, write, trigger, + s"Data source $read does not support continuous processing") + + // Invalid - trigger is microbatch but writer is not + case (_, w, t) + if !w.isInstanceOf[MicroBatchWriteSupport] && !t.isInstanceOf[ContinuousTrigger] => + testNegativeCase(read, write, trigger, + s"Data source $write does not support streamed writing") + + // Invalid - trigger and writer are microbatch but reader is not + case (r, _, t) + if !r.isInstanceOf[MicroBatchReadSupport] && !t.isInstanceOf[ContinuousTrigger] => + testPostCreationNegativeCase(read, write, trigger, + s"Data source $read does not support microbatch processing") + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index aa163d2211c38..8212fb912ec57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -422,21 +422,6 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { } } - test("ConsoleSink can be correctly loaded") { - LastOptions.clear() - val df = spark.readStream - .format("org.apache.spark.sql.streaming.test") - .load() - - val sq = df.writeStream - .format("console") - .option("checkpointLocation", newMetadataDir) - .trigger(ProcessingTime(2.seconds)) - .start() - - sq.awaitTermination(2000L) - } - test("prevent all column partitioning") { withTempDir { dir => val path = dir.getCanonicalPath @@ -450,16 +435,6 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { } } - test("ConsoleSink should not require checkpointLocation") { - LastOptions.clear() - val df = spark.readStream - .format("org.apache.spark.sql.streaming.test") - .load() - - val sq = df.writeStream.format("console").start() - sq.stop() - } - private def testMemorySinkCheckpointRecovery(chkLoc: String, provideInWriter: Boolean): Unit = { import testImplicits._ val ms = new MemoryStream[Int](0, sqlContext) From 7a2248341396840628eef398aa512cac3e3bd55f Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 18 Jan 2018 19:18:55 +0800 Subject: [PATCH 0129/2461] [SPARK-23140][SQL] Add DataSourceV2Strategy to Hive Session state's planner ## What changes were proposed in this pull request? `DataSourceV2Strategy` is missing in `HiveSessionStateBuilder`'s planner, which will throw exception as described in [SPARK-23140](https://issues.apache.org/jira/browse/SPARK-23140). ## How was this patch tested? Manual test. Author: jerryshao Closes #20305 from jerryshao/SPARK-23140. --- .../sql/hive/HiveSessionStateBuilder.scala | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index dc92ad3b0c1ac..12c74368dd184 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -96,22 +96,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session override val sparkSession: SparkSession = session override def extraPlanningStrategies: Seq[Strategy] = - super.extraPlanningStrategies ++ customPlanningStrategies - - override def strategies: Seq[Strategy] = { - experimentalMethods.extraStrategies ++ - extraPlanningStrategies ++ Seq( - FileSourceStrategy, - DataSourceStrategy(conf), - SpecialLimits, - InMemoryScans, - HiveTableScans, - Scripts, - Aggregation, - JoinSelection, - BasicOperators - ) - } + super.extraPlanningStrategies ++ customPlanningStrategies ++ Seq(HiveTableScans, Scripts) } } From e28eb431146bcdcaf02a6f6c406ca30920592a6a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 18 Jan 2018 21:24:39 +0800 Subject: [PATCH 0130/2461] [SPARK-22036][SQL] Decimal multiplication with high precision/scale often returns NULL ## What changes were proposed in this pull request? When there is an operation between Decimals and the result is a number which is not representable exactly with the result's precision and scale, Spark is returning `NULL`. This was done to reflect Hive's behavior, but it is against SQL ANSI 2011, which states that "If the result cannot be represented exactly in the result type, then whether it is rounded or truncated is implementation-defined". Moreover, Hive now changed its behavior in order to respect the standard, thanks to HIVE-15331. Therefore, the PR propose to: - update the rules to determine the result precision and scale according to the new Hive's ones introduces in HIVE-15331; - round the result of the operations, when it is not representable exactly with the result's precision and scale, instead of returning `NULL` - introduce a new config `spark.sql.decimalOperations.allowPrecisionLoss` which default to `true` (ie. the new behavior) in order to allow users to switch back to the previous one. Hive behavior reflects SQLServer's one. The only difference is that the precision and scale are adjusted for all the arithmetic operations in Hive, while SQL Server is said to do so only for multiplications and divisions in the documentation. This PR follows Hive's behavior. A more detailed explanation is available here: https://mail-archives.apache.org/mod_mbox/spark-dev/201712.mbox/%3CCAEorWNAJ4TxJR9NBcgSFMD_VxTg8qVxusjP%2BAJP-x%2BJV9zH-yA%40mail.gmail.com%3E. ## How was this patch tested? modified and added UTs. Comparisons with results of Hive and SQLServer. Author: Marco Gaido Closes #20023 from mgaido91/SPARK-22036. --- docs/sql-programming-guide.md | 5 + .../catalyst/analysis/DecimalPrecision.scala | 114 +++++--- .../sql/catalyst/expressions/literals.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 12 + .../apache/spark/sql/types/DecimalType.scala | 45 +++- .../sql/catalyst/analysis/AnalysisSuite.scala | 4 +- .../analysis/DecimalPrecisionSuite.scala | 20 +- .../native/decimalArithmeticOperations.sql | 47 ++++ .../decimalArithmeticOperations.sql.out | 245 ++++++++++++++++-- .../native/decimalPrecision.sql.out | 4 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 18 -- 11 files changed, 434 insertions(+), 82 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 258c769ff593b..3e2e48a0ef249 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1793,6 +1793,11 @@ options. - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. + - Since Spark 2.3, by default arithmetic operations between decimals return a rounded value if an exact representation is not possible (instead of returning NULL). This is compliant to SQL ANSI 2011 specification and Hive's new behavior introduced in Hive 2.2 (HIVE-15331). This involves the following changes + - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. All the arithmetic operations are affected by the change, ie. addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`) and positive module (`pmod`). + - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. + - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible. + ## Upgrading From Spark SQL 2.1 to 2.2 - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index a8100b9b24aac..ab63131b07573 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -42,8 +43,10 @@ import org.apache.spark.sql.types._ * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1) * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2) * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2) - * sum(e1) p1 + 10 s1 - * avg(e1) p1 + 4 s1 + 4 + * + * When `spark.sql.decimalOperations.allowPrecisionLoss` is set to true, if the precision / scale + * needed are out of the range of available values, the scale is reduced up to 6, in order to + * prevent the truncation of the integer part of the decimals. * * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited * precision, do the math on unlimited-precision numbers, then introduce casts back to the @@ -56,6 +59,7 @@ import org.apache.spark.sql.types._ * - INT gets turned into DECIMAL(10, 0) * - LONG gets turned into DECIMAL(20, 0) * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE + * - Literals INT and LONG get turned into DECIMAL with the precision strictly needed by the value */ // scalastyle:on object DecimalPrecision extends TypeCoercionRule { @@ -93,41 +97,76 @@ object DecimalPrecision extends TypeCoercionRule { case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - CheckOverflow(Add(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) + val resultScale = max(s1, s2) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, + resultScale) + } else { + DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) + } + CheckOverflow(Add(promotePrecision(e1, resultType), promotePrecision(e2, resultType)), + resultType) case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - CheckOverflow(Subtract(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) + val resultScale = max(s1, s2) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, + resultScale) + } else { + DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) + } + CheckOverflow(Subtract(promotePrecision(e1, resultType), promotePrecision(e2, resultType)), + resultType) case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(p1 + p2 + 1, s1 + s2) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2) + } else { + DecimalType.bounded(p1 + p2 + 1, s1 + s2) + } val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2) - var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1)) - val diff = (intDig + decDig) - DecimalType.MAX_SCALE - if (diff > 0) { - decDig -= diff / 2 + 1 - intDig = DecimalType.MAX_SCALE - decDig + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1) + // Scale: max(6, s1 + p2 + 1) + val intDig = p1 - s1 + s2 + val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1) + val prec = intDig + scale + DecimalType.adjustPrecisionScale(prec, scale) + } else { + var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2) + var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1)) + val diff = (intDig + decDig) - DecimalType.MAX_SCALE + if (diff > 0) { + decDig -= diff / 2 + 1 + intDig = DecimalType.MAX_SCALE - decDig + } + DecimalType.bounded(intDig + decDig, decDig) } - val resultType = DecimalType.bounded(intDig + decDig, decDig) val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } else { + DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } // resultType may have lower precision, so we cast them into wider type first. val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } else { + DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } // resultType may have lower precision, so we cast them into wider type first. val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), @@ -137,9 +176,6 @@ object DecimalPrecision extends TypeCoercionRule { e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => val resultType = widerDecimalType(p1, s1, p2, s2) b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType))) - - // TODO: MaxOf, MinOf, etc might want other rules - // SUM and AVERAGE are handled by the implementations of those expressions } /** @@ -243,17 +279,35 @@ object DecimalPrecision extends TypeCoercionRule { // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - (left.dataType, right.dataType) match { - case (t: IntegralType, DecimalType.Fixed(p, s)) => - b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right)) - case (DecimalType.Fixed(p, s), t: IntegralType) => - b.makeCopy(Array(left, Cast(right, DecimalType.forType(t)))) - case (t, DecimalType.Fixed(p, s)) if isFloat(t) => - b.makeCopy(Array(left, Cast(right, DoubleType))) - case (DecimalType.Fixed(p, s), t) if isFloat(t) => - b.makeCopy(Array(Cast(left, DoubleType), right)) - case _ => - b + (left, right) match { + // Promote literal integers inside a binary expression with fixed-precision decimals to + // decimals. The precision and scale are the ones strictly needed by the integer value. + // Requiring more precision than necessary may lead to a useless loss of precision. + // Consider the following example: multiplying a column which is DECIMAL(38, 18) by 2. + // If we use the default precision and scale for the integer type, 2 is considered a + // DECIMAL(10, 0). According to the rules, the result would be DECIMAL(38 + 10 + 1, 18), + // which is out of range and therefore it will becomes DECIMAL(38, 7), leading to + // potentially loosing 11 digits of the fractional part. Using only the precision needed + // by the Literal, instead, the result would be DECIMAL(38 + 1 + 1, 18), which would + // become DECIMAL(38, 16), safely having a much lower precision loss. + case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] + && l.dataType.isInstanceOf[IntegralType] => + b.makeCopy(Array(Cast(l, DecimalType.fromLiteral(l)), r)) + case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] + && r.dataType.isInstanceOf[IntegralType] => + b.makeCopy(Array(l, Cast(r, DecimalType.fromLiteral(r)))) + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to doubles + case (l @ IntegralType(), r @ DecimalType.Expression(_, _)) => + b.makeCopy(Array(Cast(l, DecimalType.forType(l.dataType)), r)) + case (l @ DecimalType.Expression(_, _), r @ IntegralType()) => + b.makeCopy(Array(l, Cast(r, DecimalType.forType(r.dataType)))) + case (l, r @ DecimalType.Expression(_, _)) if isFloat(l.dataType) => + b.makeCopy(Array(l, Cast(r, DoubleType))) + case (l @ DecimalType.Expression(_, _), r) if isFloat(r.dataType) => + b.makeCopy(Array(Cast(l, DoubleType), r)) + case _ => b } } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 383203a209833..cd176d941819f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -58,7 +58,7 @@ object Literal { case s: Short => Literal(s, ShortType) case s: String => Literal(UTF8String.fromString(s), StringType) case b: Boolean => Literal(b, BooleanType) - case d: BigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale)) + case d: BigDecimal => Literal(Decimal(d), DecimalType.fromBigDecimal(d)) case d: JavaBigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale())) case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 16fbb0c3e9e21..cc4f4bf332459 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1064,6 +1064,16 @@ object SQLConf { .booleanConf .createWithDefault(true) + val DECIMAL_OPERATIONS_ALLOW_PREC_LOSS = + buildConf("spark.sql.decimalOperations.allowPrecisionLoss") + .internal() + .doc("When true (default), establishing the result type of an arithmetic operation " + + "happens according to Hive behavior and SQL ANSI 2011 specification, ie. rounding the " + + "decimal part of the result if an exact representation is not possible. Otherwise, NULL " + + "is returned in those cases, as previously.") + .booleanConf + .createWithDefault(true) + val SQL_STRING_REDACTION_PATTERN = ConfigBuilder("spark.sql.redaction.string.regex") .doc("Regex to decide which parts of strings produced by Spark contain sensitive " + @@ -1441,6 +1451,8 @@ class SQLConf extends Serializable with Logging { def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) + def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) + def continuousStreamingExecutorQueueSize: Int = getConf(CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) def continuousStreamingExecutorPollIntervalMs: Long = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 6e050c18b8acb..ef3b67c0d48d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -23,7 +23,7 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} /** @@ -117,6 +117,7 @@ object DecimalType extends AbstractDataType { val MAX_SCALE = 38 val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18) val USER_DEFAULT: DecimalType = DecimalType(10, 0) + val MINIMUM_ADJUSTED_SCALE = 6 // The decimal types compatible with other numeric types private[sql] val ByteDecimal = DecimalType(3, 0) @@ -136,10 +137,52 @@ object DecimalType extends AbstractDataType { case DoubleType => DoubleDecimal } + private[sql] def fromLiteral(literal: Literal): DecimalType = literal.value match { + case v: Short => fromBigDecimal(BigDecimal(v)) + case v: Int => fromBigDecimal(BigDecimal(v)) + case v: Long => fromBigDecimal(BigDecimal(v)) + case _ => forType(literal.dataType) + } + + private[sql] def fromBigDecimal(d: BigDecimal): DecimalType = { + DecimalType(Math.max(d.precision, d.scale), d.scale) + } + private[sql] def bounded(precision: Int, scale: Int): DecimalType = { DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE)) } + /** + * Scale adjustment implementation is based on Hive's one, which is itself inspired to + * SQLServer's one. In particular, when a result precision is greater than + * {@link #MAX_PRECISION}, the corresponding scale is reduced to prevent the integral part of a + * result from being truncated. + * + * This method is used only when `spark.sql.decimalOperations.allowPrecisionLoss` is set to true. + */ + private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = { + // Assumptions: + assert(precision >= scale) + assert(scale >= 0) + + if (precision <= MAX_PRECISION) { + // Adjustment only needed when we exceed max precision + DecimalType(precision, scale) + } else { + // Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION. + val intDigits = precision - scale + // If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise + // preserve at least MINIMUM_ADJUSTED_SCALE fractional digits + val minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE) + // The resulting scale is the maximum between what is available without causing a loss of + // digits for the integer part of the decimal and the minimum guaranteed scale, which is + // computed above + val adjustedScale = Math.max(MAX_PRECISION - intDigits, minScaleValue) + + DecimalType(MAX_PRECISION, adjustedScale) + } + } + override private[sql] def defaultConcreteType: DataType = SYSTEM_DEFAULT override private[sql] def acceptsType(other: DataType): Boolean = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index f4514205d3ae0..cd8579584eada 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -408,8 +408,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { assertExpressionType(sum(Divide(1.0, 2.0)), DoubleType) assertExpressionType(sum(Divide(1, 2.0f)), DoubleType) assertExpressionType(sum(Divide(1.0f, 2)), DoubleType) - assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(31, 11)) - assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(31, 11)) + assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(22, 11)) + assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(26, 6)) assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType) assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 60e46a9910a8b..c86dc18dfa680 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -136,19 +136,19 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter { test("maximum decimals") { for (expr <- Seq(d1, d2, i, u)) { - checkType(Add(expr, u), DecimalType.SYSTEM_DEFAULT) - checkType(Subtract(expr, u), DecimalType.SYSTEM_DEFAULT) + checkType(Add(expr, u), DecimalType(38, 17)) + checkType(Subtract(expr, u), DecimalType(38, 17)) } - checkType(Multiply(d1, u), DecimalType(38, 19)) - checkType(Multiply(d2, u), DecimalType(38, 20)) - checkType(Multiply(i, u), DecimalType(38, 18)) - checkType(Multiply(u, u), DecimalType(38, 36)) + checkType(Multiply(d1, u), DecimalType(38, 16)) + checkType(Multiply(d2, u), DecimalType(38, 14)) + checkType(Multiply(i, u), DecimalType(38, 7)) + checkType(Multiply(u, u), DecimalType(38, 6)) - checkType(Divide(u, d1), DecimalType(38, 18)) - checkType(Divide(u, d2), DecimalType(38, 19)) - checkType(Divide(u, i), DecimalType(38, 23)) - checkType(Divide(u, u), DecimalType(38, 18)) + checkType(Divide(u, d1), DecimalType(38, 17)) + checkType(Divide(u, d2), DecimalType(38, 16)) + checkType(Divide(u, i), DecimalType(38, 18)) + checkType(Divide(u, u), DecimalType(38, 6)) checkType(Remainder(d1, u), DecimalType(19, 18)) checkType(Remainder(d2, u), DecimalType(21, 18)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql index c8e108ac2c45e..c6d8a49d4b93a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql @@ -22,6 +22,51 @@ select a / b from t; select a % b from t; select pmod(a, b) from t; +-- tests for decimals handling in operations +create table decimals_test(id int, a decimal(38,18), b decimal(38,18)) using parquet; + +insert into decimals_test values(1, 100.0, 999.0), (2, 12345.123, 12345.123), + (3, 0.1234567891011, 1234.1), (4, 123456789123456789.0, 1.123456789123456789); + +-- test decimal operations +select id, a+b, a-b, a*b, a/b from decimals_test order by id; + +-- test operations between decimals and constants +select id, a*10, b/10 from decimals_test order by id; + +-- test operations on constants +select 10.3 * 3.0; +select 10.3000 * 3.0; +select 10.30000 * 30.0; +select 10.300000000000000000 * 3.000000000000000000; +select 10.300000000000000000 * 3.0000000000000000000; + +-- arithmetic operations causing an overflow return NULL +select (5e36 + 0.1) + 5e36; +select (-4e36 - 0.1) - 7e36; +select 12345678901234567890.0 * 12345678901234567890.0; +select 1e35 / 0.1; + +-- arithmetic operations causing a precision loss are truncated +select 123456789123456789.1234567890 * 1.123456789123456789; +select 0.001 / 9876543210987654321098765432109876543.2 + +-- return NULL instead of rounding, according to old Spark versions' behavior +set spark.sql.decimalOperations.allowPrecisionLoss=false; + +-- test decimal operations +select id, a+b, a-b, a*b, a/b from decimals_test order by id; + +-- test operations between decimals and constants +select id, a*10, b/10 from decimals_test order by id; + +-- test operations on constants +select 10.3 * 3.0; +select 10.3000 * 3.0; +select 10.30000 * 30.0; +select 10.300000000000000000 * 3.000000000000000000; +select 10.300000000000000000 * 3.0000000000000000000; + -- arithmetic operations causing an overflow return NULL select (5e36 + 0.1) + 5e36; select (-4e36 - 0.1) - 7e36; @@ -31,3 +76,5 @@ select 1e35 / 0.1; -- arithmetic operations causing a precision loss return NULL select 123456789123456789.1234567890 * 1.123456789123456789; select 0.001 / 9876543210987654321098765432109876543.2 + +drop table decimals_test; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out index ce02f6adc456c..4d70fe19d539f 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 32 -- !query 0 @@ -35,48 +35,257 @@ NULL -- !query 4 -select (5e36 + 0.1) + 5e36 +create table decimals_test(id int, a decimal(38,18), b decimal(38,18)) using parquet -- !query 4 schema -struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<> -- !query 4 output -NULL + -- !query 5 -select (-4e36 - 0.1) - 7e36 +insert into decimals_test values(1, 100.0, 999.0), (2, 12345.123, 12345.123), + (3, 0.1234567891011, 1234.1), (4, 123456789123456789.0, 1.123456789123456789) -- !query 5 schema -struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<> -- !query 5 output -NULL + -- !query 6 -select 12345678901234567890.0 * 12345678901234567890.0 +select id, a+b, a-b, a*b, a/b from decimals_test order by id -- !query 6 schema -struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +struct -- !query 6 output -NULL +1 1099 -899 99900 0.1001 +2 24690.246 0 152402061.885129 1 +3 1234.2234567891011 -1233.9765432108989 152.358023 0.0001 +4 123456789123456790.12345678912345679 123456789123456787.87654321087654321 138698367904130467.515623 109890109097814272.043109 -- !query 7 -select 1e35 / 0.1 +select id, a*10, b/10 from decimals_test order by id -- !query 7 schema -struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)> +struct -- !query 7 output -NULL +1 1000 99.9 +2 123451.23 1234.5123 +3 1.234567891011 123.41 +4 1234567891234567890 0.112345678912345679 -- !query 8 -select 123456789123456789.1234567890 * 1.123456789123456789 +select 10.3 * 3.0 -- !query 8 schema -struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)> +struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> -- !query 8 output -NULL +30.9 -- !query 9 -select 0.001 / 9876543210987654321098765432109876543.2 +select 10.3000 * 3.0 -- !query 9 schema -struct<(CAST(0.001 AS DECIMAL(38,3)) / CAST(9876543210987654321098765432109876543.2 AS DECIMAL(38,3))):decimal(38,37)> +struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> -- !query 9 output +30.9 + + +-- !query 10 +select 10.30000 * 30.0 +-- !query 10 schema +struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> +-- !query 10 output +309 + + +-- !query 11 +select 10.300000000000000000 * 3.000000000000000000 +-- !query 11 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,34)> +-- !query 11 output +30.9 + + +-- !query 12 +select 10.300000000000000000 * 3.0000000000000000000 +-- !query 12 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,34)> +-- !query 12 output +30.9 + + +-- !query 13 +select (5e36 + 0.1) + 5e36 +-- !query 13 schema +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 13 output +NULL + + +-- !query 14 +select (-4e36 - 0.1) - 7e36 +-- !query 14 schema +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 14 output +NULL + + +-- !query 15 +select 12345678901234567890.0 * 12345678901234567890.0 +-- !query 15 schema +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +-- !query 15 output NULL + + +-- !query 16 +select 1e35 / 0.1 +-- !query 16 schema +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> +-- !query 16 output +NULL + + +-- !query 17 +select 123456789123456789.1234567890 * 1.123456789123456789 +-- !query 17 schema +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> +-- !query 17 output +138698367904130467.654320988515622621 + + +-- !query 18 +select 0.001 / 9876543210987654321098765432109876543.2 + +set spark.sql.decimalOperations.allowPrecisionLoss=false +-- !query 18 schema +struct<> +-- !query 18 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'spark' expecting (line 3, pos 4) + +== SQL == +select 0.001 / 9876543210987654321098765432109876543.2 + +set spark.sql.decimalOperations.allowPrecisionLoss=false +----^^^ + + +-- !query 19 +select id, a+b, a-b, a*b, a/b from decimals_test order by id +-- !query 19 schema +struct +-- !query 19 output +1 1099 -899 99900 0.1001 +2 24690.246 0 152402061.885129 1 +3 1234.2234567891011 -1233.9765432108989 152.358023 0.0001 +4 123456789123456790.12345678912345679 123456789123456787.87654321087654321 138698367904130467.515623 109890109097814272.043109 + + +-- !query 20 +select id, a*10, b/10 from decimals_test order by id +-- !query 20 schema +struct +-- !query 20 output +1 1000 99.9 +2 123451.23 1234.5123 +3 1.234567891011 123.41 +4 1234567891234567890 0.112345678912345679 + + +-- !query 21 +select 10.3 * 3.0 +-- !query 21 schema +struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> +-- !query 21 output +30.9 + + +-- !query 22 +select 10.3000 * 3.0 +-- !query 22 schema +struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> +-- !query 22 output +30.9 + + +-- !query 23 +select 10.30000 * 30.0 +-- !query 23 schema +struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> +-- !query 23 output +309 + + +-- !query 24 +select 10.300000000000000000 * 3.000000000000000000 +-- !query 24 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,34)> +-- !query 24 output +30.9 + + +-- !query 25 +select 10.300000000000000000 * 3.0000000000000000000 +-- !query 25 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,34)> +-- !query 25 output +30.9 + + +-- !query 26 +select (5e36 + 0.1) + 5e36 +-- !query 26 schema +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 26 output +NULL + + +-- !query 27 +select (-4e36 - 0.1) - 7e36 +-- !query 27 schema +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 27 output +NULL + + +-- !query 28 +select 12345678901234567890.0 * 12345678901234567890.0 +-- !query 28 schema +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +-- !query 28 output +NULL + + +-- !query 29 +select 1e35 / 0.1 +-- !query 29 schema +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> +-- !query 29 output +NULL + + +-- !query 30 +select 123456789123456789.1234567890 * 1.123456789123456789 +-- !query 30 schema +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> +-- !query 30 output +138698367904130467.654320988515622621 + + +-- !query 31 +select 0.001 / 9876543210987654321098765432109876543.2 + +drop table decimals_test +-- !query 31 schema +struct<> +-- !query 31 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'table' expecting (line 3, pos 5) + +== SQL == +select 0.001 / 9876543210987654321098765432109876543.2 + +drop table decimals_test +-----^^^ diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out index ebc8201ed5a1d..6ee7f59d69877 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out @@ -2329,7 +2329,7 @@ struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) / CAST(C -- !query 280 SELECT cast(1 as bigint) / cast(1 as decimal(20, 0)) FROM t -- !query 280 schema -struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) / CAST(1 AS DECIMAL(20,0))):decimal(38,19)> +struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) / CAST(1 AS DECIMAL(20,0))):decimal(38,18)> -- !query 280 output 1 @@ -2661,7 +2661,7 @@ struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) / CAST(CAST(CAST(1 AS BI -- !query 320 SELECT cast(1 as decimal(20, 0)) / cast(1 as bigint) FROM t -- !query 320 schema -struct<(CAST(1 AS DECIMAL(20,0)) / CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):decimal(38,19)> +struct<(CAST(1 AS DECIMAL(20,0)) / CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):decimal(38,18)> -- !query 320 output 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d4d0aa4f5f5eb..083a0c0b1b9a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1517,24 +1517,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } - test("decimal precision with multiply/division") { - checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90"))) - checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000"))) - checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000"))) - checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"), - Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38)))) - checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"), - Row(null)) - - checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333"))) - checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333"))) - checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333"))) - checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"), - Row(BigDecimal("3.433333333333333333333333333", new MathContext(38)))) - checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"), - Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38)))) - } - test("SPARK-10215 Div of Decimal returns null") { val d = Decimal(1.12321).toBigDecimal val df = Seq((d, 1)).toDF("a", "b") From 5063b7481173ad72bd0dc941b5cf3c9b26a591e4 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 18 Jan 2018 22:33:04 +0900 Subject: [PATCH 0131/2461] [SPARK-23141][SQL][PYSPARK] Support data type string as a returnType for registerJavaFunction. ## What changes were proposed in this pull request? Currently `UDFRegistration.registerJavaFunction` doesn't support data type string as a `returnType` whereas `UDFRegistration.register`, `udf`, or `pandas_udf` does. We can support it for `UDFRegistration.registerJavaFunction` as well. ## How was this patch tested? Added a doctest and existing tests. Author: Takuya UESHIN Closes #20307 from ueshin/issues/SPARK-23141. --- python/pyspark/sql/functions.py | 6 ++++-- python/pyspark/sql/udf.py | 14 ++++++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 988c1d25259bc..961b3267b44cf 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2108,7 +2108,8 @@ def udf(f=None, returnType=StringType()): can fail on special rows, the workaround is to incorporate the condition into the functions. :param f: python function if used as a standalone function - :param returnType: a :class:`pyspark.sql.types.DataType` object + :param returnType: the return type of the user-defined function. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. >>> from pyspark.sql.types import IntegerType >>> slen = udf(lambda s: len(s), IntegerType()) @@ -2148,7 +2149,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): Creates a vectorized user defined function (UDF). :param f: user-defined function. A python function if used as a standalone function - :param returnType: a :class:`pyspark.sql.types.DataType` object + :param returnType: the return type of the user-defined function. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. :param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`. Default: SCALAR. diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 1943bb73f9ac2..c77f19f89a442 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -206,7 +206,8 @@ def register(self, name, f, returnType=None): :param f: a Python function, or a user-defined function. The user-defined function can be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and :meth:`pyspark.sql.functions.pandas_udf`. - :param returnType: the return type of the registered user-defined function. + :param returnType: the return type of the registered user-defined function. The value can + be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. :return: a user-defined function. `returnType` can be optionally specified when `f` is a Python function but not @@ -303,21 +304,30 @@ def registerJavaFunction(self, name, javaClassName, returnType=None): :param name: name of the user-defined function :param javaClassName: fully qualified name of java class - :param returnType: a :class:`pyspark.sql.types.DataType` object + :param returnType: the return type of the registered Java function. The value can be either + a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. >>> from pyspark.sql.types import IntegerType >>> spark.udf.registerJavaFunction( ... "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType()) >>> spark.sql("SELECT javaStringLength('test')").collect() [Row(UDF:javaStringLength(test)=4)] + >>> spark.udf.registerJavaFunction( ... "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength") >>> spark.sql("SELECT javaStringLength2('test')").collect() [Row(UDF:javaStringLength2(test)=4)] + + >>> spark.udf.registerJavaFunction( + ... "javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", "integer") + >>> spark.sql("SELECT javaStringLength3('test')").collect() + [Row(UDF:javaStringLength3(test)=4)] """ jdt = None if returnType is not None: + if not isinstance(returnType, DataType): + returnType = _parse_datatype_string(returnType) jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) From cf7ee1767ddadce08dce050fc3b40c77cdd187da Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 18 Jan 2018 10:19:36 -0800 Subject: [PATCH 0132/2461] [SPARK-23147][UI] Fix task page table IndexOutOfBound Exception ## What changes were proposed in this pull request? Stage's task page table will throw an exception when there's no complete tasks. Furthermore, because the `dataSize` doesn't take running tasks into account, so sometimes UI cannot show the running tasks. Besides table will only be displayed when first task is finished according to the default sortColumn("index"). ![screen shot 2018-01-18 at 8 50 08 pm](https://user-images.githubusercontent.com/850797/35100052-470b4cae-fc95-11e7-96a2-ad9636e732b3.png) To reproduce this issue, user could try `sc.parallelize(1 to 20, 20).map { i => Thread.sleep(10000); i }.collect()` or `sc.parallelize(1 to 20, 20).map { i => Thread.sleep((20 - i) * 1000); i }.collect` to reproduce the above issue. Here propose a solution to fix it. Not sure if it is a right fix, please help to review. ## How was this patch tested? Manual test. Author: jerryshao Closes #20315 from jerryshao/SPARK-23147. --- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 7c6e06cf183ba..af78373ddb4b2 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -676,7 +676,7 @@ private[ui] class TaskDataSource( private var _tasksToShow: Seq[TaskData] = null - override def dataSize: Int = stage.numCompleteTasks + stage.numFailedTasks + stage.numKilledTasks + override def dataSize: Int = stage.numTasks override def sliceData(from: Int, to: Int): Seq[TaskData] = { if (_tasksToShow == null) { From 9678941f54ebc5db935ed8d694e502086e2a31c0 Mon Sep 17 00:00:00 2001 From: Fernando Pereira Date: Thu, 18 Jan 2018 13:02:03 -0600 Subject: [PATCH 0133/2461] [SPARK-23029][DOCS] Specifying default units of configuration entries ## What changes were proposed in this pull request? This PR completes the docs, specifying the default units assumed in configuration entries of type size. This is crucial since unit-less values are accepted and the user might assume the base unit is bytes, which in most cases it is not, leading to hard-to-debug problems. ## How was this patch tested? This patch updates only documentation only. Author: Fernando Pereira Closes #20269 from ferdonline/docs_units. --- .../scala/org/apache/spark/SparkConf.scala | 6 +- .../spark/internal/config/package.scala | 47 ++++---- docs/configuration.md | 100 ++++++++++-------- 3 files changed, 85 insertions(+), 68 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index d77303e6fdf8b..f53b2bed74c6e 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -640,9 +640,9 @@ private[spark] object SparkConf extends Logging { translation = s => s"${s.toLong * 10}s")), "spark.reducer.maxSizeInFlight" -> Seq( AlternateConfig("spark.reducer.maxMbInFlight", "1.4")), - "spark.kryoserializer.buffer" -> - Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", - translation = s => s"${(s.toDouble * 1000).toInt}k")), + "spark.kryoserializer.buffer" -> Seq( + AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", + translation = s => s"${(s.toDouble * 1000).toInt}k")), "spark.kryoserializer.buffer.max" -> Seq( AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")), "spark.shuffle.file.buffer" -> Seq( diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index eb12ddf961314..bbfcfbaa7363c 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -38,10 +38,13 @@ package object config { ConfigBuilder("spark.driver.userClassPathFirst").booleanConf.createWithDefault(false) private[spark] val DRIVER_MEMORY = ConfigBuilder("spark.driver.memory") + .doc("Amount of memory to use for the driver process, in MiB unless otherwise specified.") .bytesConf(ByteUnit.MiB) .createWithDefaultString("1g") private[spark] val DRIVER_MEMORY_OVERHEAD = ConfigBuilder("spark.driver.memoryOverhead") + .doc("The amount of off-heap memory to be allocated per driver in cluster mode, " + + "in MiB unless otherwise specified.") .bytesConf(ByteUnit.MiB) .createOptional @@ -62,6 +65,7 @@ package object config { .createWithDefault(false) private[spark] val EVENT_LOG_OUTPUT_BUFFER_SIZE = ConfigBuilder("spark.eventLog.buffer.kb") + .doc("Buffer size to use when writing to output streams, in KiB unless otherwise specified.") .bytesConf(ByteUnit.KiB) .createWithDefaultString("100k") @@ -81,10 +85,13 @@ package object config { ConfigBuilder("spark.executor.userClassPathFirst").booleanConf.createWithDefault(false) private[spark] val EXECUTOR_MEMORY = ConfigBuilder("spark.executor.memory") + .doc("Amount of memory to use per executor process, in MiB unless otherwise specified.") .bytesConf(ByteUnit.MiB) .createWithDefaultString("1g") private[spark] val EXECUTOR_MEMORY_OVERHEAD = ConfigBuilder("spark.executor.memoryOverhead") + .doc("The amount of off-heap memory to be allocated per executor in cluster mode, " + + "in MiB unless otherwise specified.") .bytesConf(ByteUnit.MiB) .createOptional @@ -353,7 +360,7 @@ package object config { private[spark] val BUFFER_WRITE_CHUNK_SIZE = ConfigBuilder("spark.buffer.write.chunkSize") .internal() - .doc("The chunk size during writing out the bytes of ChunkedByteBuffer.") + .doc("The chunk size in bytes during writing out the bytes of ChunkedByteBuffer.") .bytesConf(ByteUnit.BYTE) .checkValue(_ <= Int.MaxValue, "The chunk size during writing out the bytes of" + " ChunkedByteBuffer should not larger than Int.MaxValue.") @@ -368,9 +375,9 @@ package object config { private[spark] val SHUFFLE_ACCURATE_BLOCK_THRESHOLD = ConfigBuilder("spark.shuffle.accurateBlockThreshold") - .doc("When we compress the size of shuffle blocks in HighlyCompressedMapStatus, we will " + - "record the size accurately if it's above this config. This helps to prevent OOM by " + - "avoiding underestimating shuffle block size when fetch shuffle blocks.") + .doc("Threshold in bytes above which the size of shuffle blocks in " + + "HighlyCompressedMapStatus is accurately recorded. This helps to prevent OOM " + + "by avoiding underestimating shuffle block size when fetch shuffle blocks.") .bytesConf(ByteUnit.BYTE) .createWithDefault(100 * 1024 * 1024) @@ -389,23 +396,23 @@ package object config { private[spark] val REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS = ConfigBuilder("spark.reducer.maxBlocksInFlightPerAddress") - .doc("This configuration limits the number of remote blocks being fetched per reduce task" + - " from a given host port. When a large number of blocks are being requested from a given" + - " address in a single fetch or simultaneously, this could crash the serving executor or" + - " Node Manager. This is especially useful to reduce the load on the Node Manager when" + - " external shuffle is enabled. You can mitigate the issue by setting it to a lower value.") + .doc("This configuration limits the number of remote blocks being fetched per reduce task " + + "from a given host port. When a large number of blocks are being requested from a given " + + "address in a single fetch or simultaneously, this could crash the serving executor or " + + "Node Manager. This is especially useful to reduce the load on the Node Manager when " + + "external shuffle is enabled. You can mitigate the issue by setting it to a lower value.") .intConf .checkValue(_ > 0, "The max no. of blocks in flight cannot be non-positive.") .createWithDefault(Int.MaxValue) private[spark] val MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM = ConfigBuilder("spark.maxRemoteBlockSizeFetchToMem") - .doc("Remote block will be fetched to disk when size of the block is " + - "above this threshold. This is to avoid a giant request takes too much memory. We can " + - "enable this config by setting a specific value(e.g. 200m). Note this configuration will " + - "affect both shuffle fetch and block manager remote block fetch. For users who " + - "enabled external shuffle service, this feature can only be worked when external shuffle" + - " service is newer than Spark 2.2.") + .doc("Remote block will be fetched to disk when size of the block is above this threshold " + + "in bytes. This is to avoid a giant request takes too much memory. We can enable this " + + "config by setting a specific value(e.g. 200m). Note this configuration will affect " + + "both shuffle fetch and block manager remote block fetch. For users who enabled " + + "external shuffle service, this feature can only be worked when external shuffle" + + "service is newer than Spark 2.2.") .bytesConf(ByteUnit.BYTE) .createWithDefault(Long.MaxValue) @@ -419,9 +426,9 @@ package object config { private[spark] val SHUFFLE_FILE_BUFFER_SIZE = ConfigBuilder("spark.shuffle.file.buffer") - .doc("Size of the in-memory buffer for each shuffle file output stream. " + - "These buffers reduce the number of disk seeks and system calls made " + - "in creating intermediate shuffle files.") + .doc("Size of the in-memory buffer for each shuffle file output stream, in KiB unless " + + "otherwise specified. These buffers reduce the number of disk seeks and system calls " + + "made in creating intermediate shuffle files.") .bytesConf(ByteUnit.KiB) .checkValue(v => v > 0 && v <= Int.MaxValue / 1024, s"The file buffer size must be greater than 0 and less than ${Int.MaxValue / 1024}.") @@ -430,7 +437,7 @@ package object config { private[spark] val SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE = ConfigBuilder("spark.shuffle.unsafe.file.output.buffer") .doc("The file system for this buffer size after each partition " + - "is written in unsafe shuffle writer.") + "is written in unsafe shuffle writer. In KiB unless otherwise specified.") .bytesConf(ByteUnit.KiB) .checkValue(v => v > 0 && v <= Int.MaxValue / 1024, s"The buffer size must be greater than 0 and less than ${Int.MaxValue / 1024}.") @@ -438,7 +445,7 @@ package object config { private[spark] val SHUFFLE_DISK_WRITE_BUFFER_SIZE = ConfigBuilder("spark.shuffle.spill.diskWriteBufferSize") - .doc("The buffer size to use when writing the sorted records to an on-disk file.") + .doc("The buffer size, in bytes, to use when writing the sorted records to an on-disk file.") .bytesConf(ByteUnit.BYTE) .checkValue(v => v > 0 && v <= Int.MaxValue, s"The buffer size must be greater than 0 and less than ${Int.MaxValue}.") diff --git a/docs/configuration.md b/docs/configuration.md index 1189aea2aa71f..eecb39dcafc9e 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -58,6 +58,10 @@ The following format is accepted: 1t or 1tb (tebibytes = 1024 gibibytes) 1p or 1pb (pebibytes = 1024 tebibytes) +While numbers without units are generally interpreted as bytes, a few are interpreted as KiB or MiB. +See documentation of individual configuration properties. Specifying units is desirable where +possible. + ## Dynamically Loading Spark Properties In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For @@ -136,9 +140,9 @@ of the most common options to set are: spark.driver.maxResultSize 1g - Limit of total size of serialized results of all partitions for each Spark action (e.g. collect). - Should be at least 1M, or 0 for unlimited. Jobs will be aborted if the total size - is above this limit. + Limit of total size of serialized results of all partitions for each Spark action (e.g. + collect) in bytes. Should be at least 1M, or 0 for unlimited. Jobs will be aborted if the total + size is above this limit. Having a high limit may cause out-of-memory errors in driver (depends on spark.driver.memory and memory overhead of objects in JVM). Setting a proper limit can protect the driver from out-of-memory errors. @@ -148,10 +152,10 @@ of the most common options to set are: spark.driver.memory 1g - Amount of memory to use for the driver process, i.e. where SparkContext is initialized. - (e.g. 1g, 2g). - -
    Note: In client mode, this config must not be set through the SparkConf + Amount of memory to use for the driver process, i.e. where SparkContext is initialized, in MiB + unless otherwise specified (e.g. 1g, 2g). +
    + Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. Instead, please set this through the --driver-memory command line option or in your default properties file. @@ -161,27 +165,28 @@ of the most common options to set are: spark.driver.memoryOverhead driverMemory * 0.10, with minimum of 384 - The amount of off-heap memory (in megabytes) to be allocated per driver in cluster mode. This is - memory that accounts for things like VM overheads, interned strings, other native overheads, etc. - This tends to grow with the container size (typically 6-10%). This option is currently supported - on YARN and Kubernetes. + The amount of off-heap memory to be allocated per driver in cluster mode, in MiB unless + otherwise specified. This is memory that accounts for things like VM overheads, interned strings, + other native overheads, etc. This tends to grow with the container size (typically 6-10%). + This option is currently supported on YARN and Kubernetes. spark.executor.memory 1g - Amount of memory to use per executor process (e.g. 2g, 8g). + Amount of memory to use per executor process, in MiB unless otherwise specified. + (e.g. 2g, 8g). spark.executor.memoryOverhead executorMemory * 0.10, with minimum of 384 - The amount of off-heap memory (in megabytes) to be allocated per executor. This is memory that - accounts for things like VM overheads, interned strings, other native overheads, etc. This tends - to grow with the executor size (typically 6-10%). This option is currently supported on YARN and - Kubernetes. + The amount of off-heap memory to be allocated per executor, in MiB unless otherwise specified. + This is memory that accounts for things like VM overheads, interned strings, other native + overheads, etc. This tends to grow with the executor size (typically 6-10%). + This option is currently supported on YARN and Kubernetes. @@ -431,8 +436,9 @@ Apart from these, the following properties are also available, and may be useful 512m Amount of memory to use per python worker process during aggregation, in the same - format as JVM memory strings (e.g. 512m, 2g). If the memory - used during aggregation goes above this amount, it will spill the data into disks. + format as JVM memory strings with a size unit suffix ("k", "m", "g" or "t") + (e.g. 512m, 2g). + If the memory used during aggregation goes above this amount, it will spill the data into disks. @@ -540,9 +546,10 @@ Apart from these, the following properties are also available, and may be useful spark.reducer.maxSizeInFlight 48m - Maximum size of map outputs to fetch simultaneously from each reduce task. Since - each output requires us to create a buffer to receive it, this represents a fixed memory - overhead per reduce task, so keep it small unless you have a large amount of memory. + Maximum size of map outputs to fetch simultaneously from each reduce task, in MiB unless + otherwise specified. Since each output requires us to create a buffer to receive it, this + represents a fixed memory overhead per reduce task, so keep it small unless you have a + large amount of memory. @@ -570,9 +577,9 @@ Apart from these, the following properties are also available, and may be useful spark.maxRemoteBlockSizeFetchToMem Long.MaxValue - The remote block will be fetched to disk when size of the block is above this threshold. + The remote block will be fetched to disk when size of the block is above this threshold in bytes. This is to avoid a giant request takes too much memory. We can enable this config by setting - a specific value(e.g. 200m). Note this configuration will affect both shuffle fetch + a specific value(e.g. 200m). Note this configuration will affect both shuffle fetch and block manager remote block fetch. For users who enabled external shuffle service, this feature can only be worked when external shuffle service is newer than Spark 2.2. @@ -589,8 +596,9 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.file.buffer 32k - Size of the in-memory buffer for each shuffle file output stream. These buffers - reduce the number of disk seeks and system calls made in creating intermediate shuffle files. + Size of the in-memory buffer for each shuffle file output stream, in KiB unless otherwise + specified. These buffers reduce the number of disk seeks and system calls made in creating + intermediate shuffle files. @@ -651,7 +659,7 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.service.index.cache.size 100m - Cache entries limited to the specified memory footprint. + Cache entries limited to the specified memory footprint in bytes. @@ -685,9 +693,9 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.accurateBlockThreshold 100 * 1024 * 1024 - When we compress the size of shuffle blocks in HighlyCompressedMapStatus, we will record the - size accurately if it's above this config. This helps to prevent OOM by avoiding - underestimating shuffle block size when fetch shuffle blocks. + Threshold in bytes above which the size of shuffle blocks in HighlyCompressedMapStatus is + accurately recorded. This helps to prevent OOM by avoiding underestimating shuffle + block size when fetch shuffle blocks. @@ -779,7 +787,7 @@ Apart from these, the following properties are also available, and may be useful spark.eventLog.buffer.kb 100k - Buffer size in KB to use when writing to output streams. + Buffer size to use when writing to output streams, in KiB unless otherwise specified. @@ -917,7 +925,7 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.lz4.blockSize 32k - Block size used in LZ4 compression, in the case when LZ4 compression codec + Block size in bytes used in LZ4 compression, in the case when LZ4 compression codec is used. Lowering this block size will also lower shuffle memory usage when LZ4 is used. @@ -925,7 +933,7 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.snappy.blockSize 32k - Block size used in Snappy compression, in the case when Snappy compression codec + Block size in bytes used in Snappy compression, in the case when Snappy compression codec is used. Lowering this block size will also lower shuffle memory usage when Snappy is used. @@ -941,7 +949,7 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.zstd.bufferSize 32k - Buffer size used in Zstd compression, in the case when Zstd compression codec + Buffer size in bytes used in Zstd compression, in the case when Zstd compression codec is used. Lowering this size will lower the shuffle memory usage when Zstd is used, but it might increase the compression cost because of excessive JNI call overhead. @@ -1001,8 +1009,8 @@ Apart from these, the following properties are also available, and may be useful spark.kryoserializer.buffer.max 64m - Maximum allowable size of Kryo serialization buffer. This must be larger than any - object you attempt to serialize and must be less than 2048m. + Maximum allowable size of Kryo serialization buffer, in MiB unless otherwise specified. + This must be larger than any object you attempt to serialize and must be less than 2048m. Increase this if you get a "buffer limit exceeded" exception inside Kryo. @@ -1010,9 +1018,9 @@ Apart from these, the following properties are also available, and may be useful spark.kryoserializer.buffer 64k - Initial size of Kryo's serialization buffer. Note that there will be one buffer - per core on each worker. This buffer will grow up to - spark.kryoserializer.buffer.max if needed. + Initial size of Kryo's serialization buffer, in KiB unless otherwise specified. + Note that there will be one buffer per core on each worker. This buffer will grow up to + spark.kryoserializer.buffer.max if needed. @@ -1086,7 +1094,8 @@ Apart from these, the following properties are also available, and may be useful spark.memory.offHeap.enabled false - If true, Spark will attempt to use off-heap memory for certain operations. If off-heap memory use is enabled, then spark.memory.offHeap.size must be positive. + If true, Spark will attempt to use off-heap memory for certain operations. If off-heap memory + use is enabled, then spark.memory.offHeap.size must be positive. @@ -1094,7 +1103,8 @@ Apart from these, the following properties are also available, and may be useful 0 The absolute amount of memory in bytes which can be used for off-heap allocation. - This setting has no impact on heap memory usage, so if your executors' total memory consumption must fit within some hard limit then be sure to shrink your JVM heap size accordingly. + This setting has no impact on heap memory usage, so if your executors' total memory consumption + must fit within some hard limit then be sure to shrink your JVM heap size accordingly. This must be set to a positive value when spark.memory.offHeap.enabled=true. @@ -1202,9 +1212,9 @@ Apart from these, the following properties are also available, and may be useful spark.broadcast.blockSize 4m - Size of each piece of a block for TorrentBroadcastFactory. - Too large a value decreases parallelism during broadcast (makes it slower); however, if it is - too small, BlockManager might take a performance hit. + Size of each piece of a block for TorrentBroadcastFactory, in KiB unless otherwise + specified. Too large a value decreases parallelism during broadcast (makes it slower); however, + if it is too small, BlockManager might take a performance hit. @@ -1312,7 +1322,7 @@ Apart from these, the following properties are also available, and may be useful spark.storage.memoryMapThreshold 2m - Size of a block above which Spark memory maps when reading a block from disk. + Size in bytes of a block above which Spark memory maps when reading a block from disk. This prevents Spark from memory mapping very small blocks. In general, memory mapping has high overhead for blocks close to or below the page size of the operating system. @@ -2490,4 +2500,4 @@ Also, you can modify or add configurations at runtime: --conf "spark.executor.extraJavaOptions=-XX:+PrintGCDetails -XX:+PrintGCTimeStamps" \ --conf spark.hadoop.abc.def=xyz \ myApp.jar -{% endhighlight %} \ No newline at end of file +{% endhighlight %} From 2d41f040a34d6483919fd5d491cf90eee5429290 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 18 Jan 2018 12:25:52 -0800 Subject: [PATCH 0134/2461] [SPARK-23143][SS][PYTHON] Added python API for setting continuous trigger ## What changes were proposed in this pull request? Self-explanatory. ## How was this patch tested? New python tests. Author: Tathagata Das Closes #20309 from tdas/SPARK-23143. --- python/pyspark/sql/streaming.py | 23 +++++++++++++++++++---- python/pyspark/sql/tests.py | 6 ++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 24ae3776a217b..e2a97acb5e2a7 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -786,7 +786,7 @@ def queryName(self, queryName): @keyword_only @since(2.0) - def trigger(self, processingTime=None, once=None): + def trigger(self, processingTime=None, once=None, continuous=None): """Set the trigger for the stream query. If this is not set it will run the query as fast as possible, which is equivalent to setting the trigger to ``processingTime='0 seconds'``. @@ -802,23 +802,38 @@ def trigger(self, processingTime=None, once=None): >>> writer = sdf.writeStream.trigger(processingTime='5 seconds') >>> # trigger the query for just once batch of data >>> writer = sdf.writeStream.trigger(once=True) + >>> # trigger the query for execution every 5 seconds + >>> writer = sdf.writeStream.trigger(continuous='5 seconds') """ + params = [processingTime, once, continuous] + + if params.count(None) == 3: + raise ValueError('No trigger provided') + elif params.count(None) < 2: + raise ValueError('Multiple triggers not allowed.') + jTrigger = None if processingTime is not None: - if once is not None: - raise ValueError('Multiple triggers not allowed.') if type(processingTime) != str or len(processingTime.strip()) == 0: raise ValueError('Value for processingTime must be a non empty string. Got: %s' % processingTime) interval = processingTime.strip() jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.ProcessingTime( interval) + elif once is not None: if once is not True: raise ValueError('Value for once must be True. Got: %s' % once) jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.Once() + else: - raise ValueError('No trigger provided') + if type(continuous) != str or len(continuous.strip()) == 0: + raise ValueError('Value for continuous must be a non empty string. Got: %s' % + continuous) + interval = continuous.strip() + jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.Continuous( + interval) + self._jwrite = self._jwrite.trigger(jTrigger) return self diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f84aa3d68b808..25483594f2725 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1538,6 +1538,12 @@ def test_stream_trigger(self): except ValueError: pass + # Should not take multiple args + try: + df.writeStream.trigger(processingTime='5 seconds', continuous='1 second') + except ValueError: + pass + # Should take only keyword args try: df.writeStream.trigger('5 seconds') From bf34d665b9c865e00fac7001500bf6d521c2dff9 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 18 Jan 2018 12:33:39 -0800 Subject: [PATCH 0135/2461] [SPARK-23144][SS] Added console sink for continuous processing ## What changes were proposed in this pull request? Refactored ConsoleWriter into ConsoleMicrobatchWriter and ConsoleContinuousWriter. ## How was this patch tested? new unit test Author: Tathagata Das Closes #20311 from tdas/SPARK-23144. --- .../sql/execution/streaming/console.scala | 20 +++-- .../streaming/sources/ConsoleWriter.scala | 80 ++++++++++++++----- .../sources/ConsoleWriterSuite.scala | 26 +++++- 3 files changed, 96 insertions(+), 30 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 94820376ff7e7..f2aa3259731d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql.execution.streaming import java.util.Optional -import scala.collection.JavaConverters._ - import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter +import org.apache.spark.sql.execution.streaming.sources.{ConsoleContinuousWriter, ConsoleMicroBatchWriter, ConsoleWriter} import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} -import org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport +import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -37,16 +36,25 @@ case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) class ConsoleSinkProvider extends DataSourceV2 with MicroBatchWriteSupport + with ContinuousWriteSupport with DataSourceRegister with CreatableRelationProvider { override def createMicroBatchWriter( queryId: String, - epochId: Long, + batchId: Long, schema: StructType, mode: OutputMode, options: DataSourceV2Options): Optional[DataSourceV2Writer] = { - Optional.of(new ConsoleWriter(epochId, schema, options)) + Optional.of(new ConsoleMicroBatchWriter(batchId, schema, options)) + } + + override def createContinuousWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[ContinuousWriter] = { + Optional.of(new ConsoleContinuousWriter(schema, options)) } def createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala index 361979984bbec..6fb61dff60045 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -20,45 +20,85 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, WriterCommitMessage} import org.apache.spark.sql.types.StructType -/** - * A [[DataSourceV2Writer]] that collects results to the driver and prints them in the console. - * Generated by [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. - * - * This sink should not be used for production, as it requires sending all rows to the driver - * and does not support recovery. - */ -class ConsoleWriter(batchId: Long, schema: StructType, options: DataSourceV2Options) - extends DataSourceV2Writer with Logging { +/** Common methods used to create writes for the the console sink */ +trait ConsoleWriter extends Logging { + + def options: DataSourceV2Options + // Number of rows to display, by default 20 rows - private val numRowsToShow = options.getInt("numRows", 20) + protected val numRowsToShow = options.getInt("numRows", 20) // Truncate the displayed data if it is too long, by default it is true - private val isTruncated = options.getBoolean("truncate", true) + protected val isTruncated = options.getBoolean("truncate", true) assert(SparkSession.getActiveSession.isDefined) - private val spark = SparkSession.getActiveSession.get + protected val spark = SparkSession.getActiveSession.get + + def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory - override def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory + def abort(messages: Array[WriterCommitMessage]): Unit = {} - override def commit(messages: Array[WriterCommitMessage]): Unit = synchronized { - val batch = messages.collect { + protected def printRows( + commitMessages: Array[WriterCommitMessage], + schema: StructType, + printMessage: String): Unit = { + val rows = commitMessages.collect { case PackedRowCommitMessage(rows) => rows }.flatten // scalastyle:off println println("-------------------------------------------") - println(s"Batch: $batchId") + println(printMessage) println("-------------------------------------------") // scalastyle:off println - spark.createDataFrame( - spark.sparkContext.parallelize(batch), schema) + spark + .createDataFrame(spark.sparkContext.parallelize(rows), schema) .show(numRowsToShow, isTruncated) } +} + + +/** + * A [[DataSourceV2Writer]] that collects results from a micro-batch query to the driver and + * prints them in the console. Created by + * [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. + * + * This sink should not be used for production, as it requires sending all rows to the driver + * and does not support recovery. + */ +class ConsoleMicroBatchWriter(batchId: Long, schema: StructType, val options: DataSourceV2Options) + extends DataSourceV2Writer with ConsoleWriter { + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + printRows(messages, schema, s"Batch: $batchId") + } + + override def toString(): String = { + s"ConsoleMicroBatchWriter[numRows=$numRowsToShow, truncate=$isTruncated]" + } +} - override def abort(messages: Array[WriterCommitMessage]): Unit = {} - override def toString(): String = s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]" +/** + * A [[DataSourceV2Writer]] that collects results from a continuous query to the driver and + * prints them in the console. Created by + * [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. + * + * This sink should not be used for production, as it requires sending all rows to the driver + * and does not support recovery. + */ +class ConsoleContinuousWriter(schema: StructType, val options: DataSourceV2Options) + extends ContinuousWriter with ConsoleWriter { + + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + printRows(messages, schema, s"Continuous processing epoch $epochId") + } + + override def toString(): String = { + s"ConsoleContinuousWriter[numRows=$numRowsToShow, truncate=$isTruncated]" + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala index 60ffee9b9b42c..55acf2ba28d2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.execution.streaming.sources import java.io.ByteArrayOutputStream +import org.scalatest.time.SpanSugar._ + import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.streaming.{StreamTest, Trigger} class ConsoleWriterSuite extends StreamTest { import testImplicits._ - test("console") { + test("microbatch - default") { val input = MemoryStream[Int] val captured = new ByteArrayOutputStream() @@ -77,7 +79,7 @@ class ConsoleWriterSuite extends StreamTest { |""".stripMargin) } - test("console with numRows") { + test("microbatch - with numRows") { val input = MemoryStream[Int] val captured = new ByteArrayOutputStream() @@ -106,7 +108,7 @@ class ConsoleWriterSuite extends StreamTest { |""".stripMargin) } - test("console with truncation") { + test("microbatch - truncation") { val input = MemoryStream[String] val captured = new ByteArrayOutputStream() @@ -132,4 +134,20 @@ class ConsoleWriterSuite extends StreamTest { | |""".stripMargin) } + + test("continuous - default") { + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val input = spark.readStream + .format("rate") + .option("numPartitions", "1") + .option("rowsPerSecond", "5") + .load() + .select('value) + + val query = input.writeStream.format("console").trigger(Trigger.Continuous(200)).start() + assert(query.isActive) + query.stop() + } + } } From f568e9cf76f657d094f1d036ab5a95f2531f5761 Mon Sep 17 00:00:00 2001 From: Andrew Korzhuev Date: Thu, 18 Jan 2018 14:00:12 -0800 Subject: [PATCH 0136/2461] [SPARK-23133][K8S] Fix passing java options to Executor Pass through spark java options to the executor in context of docker image. Closes #20296 andrusha: Deployed two version of containers to local k8s, checked that java options were present in the updated image on the running executor. Manual test Author: Andrew Korzhuev Closes #20322 from foxish/patch-1. --- .../docker/src/main/dockerfiles/spark/entrypoint.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 0c28c75857871..b9090dc2852a5 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -42,7 +42,7 @@ shift 1 SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt -readarray -t SPARK_DRIVER_JAVA_OPTS < /tmp/java_opts.txt +readarray -t SPARK_JAVA_OPTS < /tmp/java_opts.txt if [ -n "$SPARK_MOUNTED_CLASSPATH" ]; then SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_MOUNTED_CLASSPATH" fi @@ -54,7 +54,7 @@ case "$SPARK_K8S_CMD" in driver) CMD=( ${JAVA_HOME}/bin/java - "${SPARK_DRIVER_JAVA_OPTS[@]}" + "${SPARK_JAVA_OPTS[@]}" -cp "$SPARK_CLASSPATH" -Xms$SPARK_DRIVER_MEMORY -Xmx$SPARK_DRIVER_MEMORY @@ -67,7 +67,7 @@ case "$SPARK_K8S_CMD" in executor) CMD=( ${JAVA_HOME}/bin/java - "${SPARK_EXECUTOR_JAVA_OPTS[@]}" + "${SPARK_JAVA_OPTS[@]}" -Xms$SPARK_EXECUTOR_MEMORY -Xmx$SPARK_EXECUTOR_MEMORY -cp "$SPARK_CLASSPATH" From e01919e834d301e13adc8919932796ebae900576 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Fri, 19 Jan 2018 07:36:06 +0900 Subject: [PATCH 0137/2461] [SPARK-23094] Fix invalid character handling in JsonDataSource ## What changes were proposed in this pull request? There were two related fixes regarding `from_json`, `get_json_object` and `json_tuple` ([Fix #1](https://github.com/apache/spark/commit/c8803c06854683c8761fdb3c0e4c55d5a9e22a95), [Fix #2](https://github.com/apache/spark/commit/86174ea89b39a300caaba6baffac70f3dc702788)), but they weren't comprehensive it seems. I wanted to extend those fixes to all the parsers, and add tests for each case. ## How was this patch tested? Regression tests Author: Burak Yavuz Closes #20302 from brkyvz/json-invfix. --- .../catalyst/json/CreateJacksonParser.scala | 5 +-- .../sources/JsonHadoopFsRelationSuite.scala | 34 +++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala index 025a388aacaa5..b1672e7e2fca2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala @@ -40,10 +40,11 @@ private[sql] object CreateJacksonParser extends Serializable { } def text(jsonFactory: JsonFactory, record: Text): JsonParser = { - jsonFactory.createParser(record.getBytes, 0, record.getLength) + val bain = new ByteArrayInputStream(record.getBytes, 0, record.getLength) + jsonFactory.createParser(new InputStreamReader(bain, "UTF-8")) } def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = { - jsonFactory.createParser(record) + jsonFactory.createParser(new InputStreamReader(record, "UTF-8")) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala index 49be30435ad2f..27f398ebf301a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.types._ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = "json" + private val badJson = "\u0000\u0000\u0000A\u0001AAA" + // JSON does not write data of NullType and does not play well with BinaryType. override protected def supportsDataType(dataType: DataType): Boolean = dataType match { case _: NullType => false @@ -105,4 +107,36 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { ) } } + + test("invalid json with leading nulls - from file (multiLine=true)") { + import testImplicits._ + withTempDir { tempDir => + val path = tempDir.getAbsolutePath + Seq(badJson, """{"a":1}""").toDS().write.mode("overwrite").text(path) + val expected = s"""$badJson\n{"a":1}\n""" + val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) + val df = + spark.read.format(dataSourceName).option("multiLine", true).schema(schema).load(path) + checkAnswer(df, Row(null, expected)) + } + } + + test("invalid json with leading nulls - from file (multiLine=false)") { + import testImplicits._ + withTempDir { tempDir => + val path = tempDir.getAbsolutePath + Seq(badJson, """{"a":1}""").toDS().write.mode("overwrite").text(path) + val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) + val df = + spark.read.format(dataSourceName).option("multiLine", false).schema(schema).load(path) + checkAnswer(df, Seq(Row(1, null), Row(null, badJson))) + } + } + + test("invalid json with leading nulls - from dataset") { + import testImplicits._ + checkAnswer( + spark.read.json(Seq(badJson).toDS()), + Row(badJson)) + } } From 5d7c4ba4d73a72f26d591108db3c20b4a6c84f3f Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Thu, 18 Jan 2018 14:44:22 -0800 Subject: [PATCH 0138/2461] [SPARK-22962][K8S] Fail fast if submission client local files are used ## What changes were proposed in this pull request? In the Kubernetes mode, fails fast in the submission process if any submission client local dependencies are used as the use case is not supported yet. ## How was this patch tested? Unit tests, integration tests, and manual tests. vanzin foxish Author: Yinan Li Closes #20320 from liyinan926/master. --- docs/running-on-kubernetes.md | 5 ++- .../k8s/submit/DriverConfigOrchestrator.scala | 14 ++++++++- .../DriverConfigOrchestratorSuite.scala | 31 ++++++++++++++++++- 3 files changed, 47 insertions(+), 3 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 08ec34c63ba3f..d6b1735ce5550 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -117,7 +117,10 @@ This URI is the location of the example jar that is already in the Docker image. If your application's dependencies are all hosted in remote locations like HDFS or HTTP servers, they may be referred to by their appropriate remote URIs. Also, application dependencies can be pre-mounted into custom-built Docker images. Those dependencies can be added to the classpath by referencing them with `local://` URIs and/or setting the -`SPARK_EXTRA_CLASSPATH` environment variable in your Dockerfiles. +`SPARK_EXTRA_CLASSPATH` environment variable in your Dockerfiles. The `local://` scheme is also required when referring to +dependencies in custom-built Docker images in `spark-submit`. Note that using application dependencies from the submission +client's local file system is currently not yet supported. + ### Using Remote Dependencies When there are application dependencies hosted in remote locations like HDFS or HTTP servers, the driver and executor pods diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala index c9cc300d65569..ae70904621184 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala @@ -20,7 +20,7 @@ import java.util.UUID import com.google.common.primitives.Longs -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ @@ -117,6 +117,12 @@ private[spark] class DriverConfigOrchestrator( .map(_.split(",")) .getOrElse(Array.empty[String]) + // TODO(SPARK-23153): remove once submission client local dependencies are supported. + if (existSubmissionLocalFiles(sparkJars) || existSubmissionLocalFiles(sparkFiles)) { + throw new SparkException("The Kubernetes mode does not yet support referencing application " + + "dependencies in the local file system.") + } + val dependencyResolutionStep = if (sparkJars.nonEmpty || sparkFiles.nonEmpty) { Seq(new DependencyResolutionStep( sparkJars, @@ -162,6 +168,12 @@ private[spark] class DriverConfigOrchestrator( initContainerBootstrapStep } + private def existSubmissionLocalFiles(files: Seq[String]): Boolean = { + files.exists { uri => + Utils.resolveURI(uri).getScheme == "file" + } + } + private def existNonContainerLocalFiles(files: Seq[String]): Boolean = { files.exists { uri => Utils.resolveURI(uri).getScheme != "local" diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala index 65274c6f50e01..033d303e946fd 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.deploy.k8s.submit -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.submit.steps._ @@ -117,6 +117,35 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { classOf[DriverMountSecretsStep]) } + test("Submission using client local dependencies") { + val sparkConf = new SparkConf(false) + .set(CONTAINER_IMAGE, DRIVER_IMAGE) + var orchestrator = new DriverConfigOrchestrator( + APP_ID, + LAUNCH_TIME, + Some(JavaMainAppResource("file:///var/apps/jars/main.jar")), + APP_NAME, + MAIN_CLASS, + APP_ARGS, + sparkConf) + assertThrows[SparkException] { + orchestrator.getAllConfigurationSteps + } + + sparkConf.set("spark.files", "/path/to/file1,/path/to/file2") + orchestrator = new DriverConfigOrchestrator( + APP_ID, + LAUNCH_TIME, + Some(JavaMainAppResource("local:///var/apps/jars/main.jar")), + APP_NAME, + MAIN_CLASS, + APP_ARGS, + sparkConf) + assertThrows[SparkException] { + orchestrator.getAllConfigurationSteps + } + } + private def validateStepTypes( orchestrator: DriverConfigOrchestrator, types: Class[_ <: DriverConfigurationStep]*): Unit = { From 4cd2ecc0c7222fef1337e04f1948333296c3be86 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 18 Jan 2018 16:29:45 -0800 Subject: [PATCH 0139/2461] [SPARK-23142][SS][DOCS] Added docs for continuous processing ## What changes were proposed in this pull request? Added documentation for continuous processing. Modified two locations. - Modified the overview to have a mention of Continuous Processing. - Added a new section on Continuous Processing at the end. ![image](https://user-images.githubusercontent.com/663212/35083551-a3dd23f6-fbd4-11e7-9e7e-90866f131ca9.png) ![image](https://user-images.githubusercontent.com/663212/35083618-d844027c-fbd4-11e7-9fde-75992cc517bd.png) ## How was this patch tested? N/A Author: Tathagata Das Closes #20308 from tdas/SPARK-23142. --- .../structured-streaming-programming-guide.md | 98 ++++++++++++++++++- 1 file changed, 97 insertions(+), 1 deletion(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 1779a4215e085..2ddba2f0d942e 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -10,7 +10,9 @@ title: Structured Streaming Programming Guide # Overview Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* -In this guide, we are going to walk you through the programming model and the APIs. First, let's start with a simple example - a streaming word count. +Internally, by default, Structured Streaming queries are processed using a *micro-batch processing* engine, which processes data streams as a series of small batch jobs thereby achieving end-to-end latencies as low as 100 milliseconds and exactly-once fault-tolerance guarantees. However, since Spark 2.3, we have introduced a new low-latency processing mode called **Continuous Processing**, which can achieve end-to-end latencies as low as 1 millisecond with at-least-once guarantees. Without changing the Dataset/DataFrame operations in your queries, you will be able choose the mode based on your application requirements. + +In this guide, we are going to walk you through the programming model and the APIs. We are going to explain the concepts mostly using the default micro-batch processing model, and then [later](#continuous-processing-experimental) discuss Continuous Processing model. First, let's start with a simple example of a Structured Streaming query - a streaming word count. # Quick Example Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. You can see the full code in @@ -2434,6 +2436,100 @@ write.stream(aggDF, "memory", outputMode = "complete", checkpointLocation = "pat
+# Continuous Processing [Experimental] +**Continuous processing** is a new, experimental streaming execution mode introduced in Spark 2.3 that enables low (~1 ms) end-to-end latency with at-least-once fault-tolerance guarantees. Compare this with the default *micro-batch processing* engine which can achieve exactly-once guarantees but achieve latencies of ~100ms at best. For some types of queries (discussed below), you can choose which mode to execute them in without modifying the application logic (i.e. without changing the DataFrame/Dataset operations). + +To run a supported query in continuous processing mode, all you need to do is specify a **continuous trigger** with the desired checkpoint interval as a parameter. For example, + +
+
+{% highlight scala %} +import org.apache.spark.sql.streaming.Trigger + +spark + .readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("") + +spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic1") + .trigger(Trigger.Continuous("1 second")) // only change in query + .start() +{% endhighlight %} +
+
+{% highlight java %} +import org.apache.spark.sql.streaming.Trigger; + +spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic1") + .trigger(Trigger.Continuous("1 second")) // only change in query + .start(); +{% endhighlight %} +
+
+{% highlight python %} +spark \ + .readStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("subscribe", "topic1") \ + .load() \ + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") \ + .writeStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("topic", "topic1") \ + .trigger(continuous="1 second") \ # only change in query + .start() + +{% endhighlight %} +
+
+ +A checkpoint interval of 1 second means that the continuous processing engine will records the progress of the query every second. The resulting checkpoints are in a format compatible with the micro-batch engine, hence any query can be restarted with any trigger. For example, a supported query started with the micro-batch mode can be restarted in continuous mode, and vice versa. Note that any time you switch to continuous mode, you will get at-least-once fault-tolerance guarantees. + +## Supported Queries +As of Spark 2.3, only the following type of queries are supported in the continuous processing mode. + +- *Operations*: Only map-like Dataset/DataFrame operations are supported in continuous mode, that is, only projections (`select`, `map`, `flatMap`, `mapPartitions`, etc.) and selections (`where`, `filter`, etc.). + + All SQL functions are supported except aggregation functions (since aggregations are not yet supported), `current_timestamp()` and `current_date()` (deterministic computations using time is challenging). + +- *Sources*: + + Kafka source: All options are supported. + + Rate source: Good for testing. Only options that are supported in the continuous mode are `numPartitions` and `rowsPerSecond`. + +- *Sinks*: + + Kafka sink: All options are supported. + + Memory sink: Good for debugging. + + Console sink: Good for debugging. All options are supported. Note that the console will print every checkpoint interval that you have specified in the continuous trigger. + +See [Input Sources](#input-sources) and [Output Sinks](#output-sinks) sections for more details on them. While the console sink is good for testing, the end-to-end low-latency processing can be best observed with Kafka as the source and sink, as this allows the engine to process the data and make the results available in the output topic within milliseconds of the input data being available in the input topic. + +## Caveats +- Continuous processing engine launches multiple long-running tasks that continuously read data from sources, process it and continuously write to sinks. The number of tasks required by the query depends on how many partitions the query can read from the sources in parallel. Therefore, before starting a continuous processing query, you must ensure there are enough cores in the cluster to all the tasks in parallel. For example, if you are reading from a Kafka topic that has 10 partitions, then the cluster must have at least 10 cores for the query to make progress. +- Stopping a continuous processing stream may produce spurious task termination warnings. These can be safely ignored. +- There are currently no automatic retries of failed tasks. Any failure will lead to the query being stopped and it needs to be manually restarted from the checkpoint. + # Additional Information **Further Reading** From 6121e91b7f5c9513d68674e4d5edbc3a4a5fd5fd Mon Sep 17 00:00:00 2001 From: brandonJY Date: Thu, 18 Jan 2018 18:57:49 -0600 Subject: [PATCH 0140/2461] [DOCS] change to dataset for java code in structured-streaming-kafka-integration document ## What changes were proposed in this pull request? In latest structured-streaming-kafka-integration document, Java code example for Kafka integration is using `DataFrame`, shouldn't it be changed to `DataSet`? ## How was this patch tested? manual test has been performed to test the updated example Java code in Spark 2.2.1 with Kafka 1.0 Author: brandonJY Closes #20312 from brandonJY/patch-2. --- docs/structured-streaming-kafka-integration.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index bab0be8ddeb9f..461c29ce1ba89 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -61,7 +61,7 @@ df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% highlight java %} // Subscribe to 1 topic -DataFrame df = spark +Dataset df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -70,7 +70,7 @@ DataFrame df = spark df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") // Subscribe to multiple topics -DataFrame df = spark +Dataset df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -79,7 +79,7 @@ DataFrame df = spark df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") // Subscribe to a pattern -DataFrame df = spark +Dataset df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -171,7 +171,7 @@ df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% highlight java %} // Subscribe to 1 topic defaults to the earliest and latest offsets -DataFrame df = spark +Dataset df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -180,7 +180,7 @@ DataFrame df = spark df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); // Subscribe to multiple topics, specifying explicit Kafka offsets -DataFrame df = spark +Dataset df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -191,7 +191,7 @@ DataFrame df = spark df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); // Subscribe to a pattern, at the earliest and latest offsets -DataFrame df = spark +Dataset df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") From 568055da93049c207bb830f244ff9b60c638837c Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 19 Jan 2018 11:37:08 +0800 Subject: [PATCH 0141/2461] [SPARK-23054][SQL][PYSPARK][FOLLOWUP] Use sqlType casting when casting PythonUserDefinedType to String. ## What changes were proposed in this pull request? This is a follow-up of #20246. If a UDT in Python doesn't have its corresponding Scala UDT, cast to string will be the raw string of the internal value, e.g. `"org.apache.spark.sql.catalyst.expressions.UnsafeArrayDataxxxxxxxx"` if the internal type is `ArrayType`. This pr fixes it by using its `sqlType` casting. ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN Closes #20306 from ueshin/issues/SPARK-23054/fup1. --- python/pyspark/sql/tests.py | 11 +++++++++++ .../apache/spark/sql/catalyst/expressions/Cast.scala | 2 ++ .../org/apache/spark/sql/test/ExamplePointUDT.scala | 2 ++ 3 files changed, 15 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 25483594f2725..4fee2ecde391b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1189,6 +1189,17 @@ def test_union_with_udt(self): ] ) + def test_cast_to_string_with_udt(self): + from pyspark.sql.tests import ExamplePointUDT, ExamplePoint + from pyspark.sql.functions import col + row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0)) + schema = StructType([StructField("point", ExamplePointUDT(), False), + StructField("pypoint", PythonOnlyUDT(), False)]) + df = self.spark.createDataFrame([row], schema) + + result = df.select(col('point').cast('string'), col('pypoint').cast('string')).head() + self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]')) + def test_column_operators(self): ci = self.df.key cs = self.df.value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index a95ebe301b9d1..79b051670e9e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -282,6 +282,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String builder.append("]") builder.build() }) + case pudt: PythonUserDefinedType => castToString(pudt.sqlType) case udt: UserDefinedType[_] => buildCast[Any](_, o => UTF8String.fromString(udt.deserialize(o).toString)) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) @@ -838,6 +839,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$evPrim = $buffer.build(); """.stripMargin } + case pudt: PythonUserDefinedType => castToStringCode(pudt.sqlType, ctx) case udt: UserDefinedType[_] => val udtRef = ctx.addReferenceObj("udt", udt) (c, evPrim, evNull) => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index a73e4272950a4..8bab7e1c58762 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -34,6 +34,8 @@ private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializab case that: ExamplePoint => this.x == that.x && this.y == that.y case _ => false } + + override def toString(): String = s"($x, $y)" } /** From 9c4b99861cda3f9ec44ca8c1adc81a293508190c Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Fri, 19 Jan 2018 01:38:08 -0800 Subject: [PATCH 0142/2461] [BUILD][MINOR] Fix java style check issues ## What changes were proposed in this pull request? This patch fixes a few recently introduced java style check errors in master and release branch. As an aside, given that [java linting currently fails](https://github.com/apache/spark/pull/10763 ) on machines with a clean maven cache, it'd be great to find another workaround to [re-enable the java style checks](https://github.com/apache/spark/blob/3a07eff5af601511e97a05e6fea0e3d48f74c4f0/dev/run-tests.py#L577) as part of Spark PRB. /cc zsxwing JoshRosen srowen for any suggestions ## How was this patch tested? Manual Check Author: Sameer Agarwal Closes #20323 from sameeragarwal/java. --- .../spark/sql/sources/v2/writer/DataSourceV2Writer.java | 6 ++++-- .../org/apache/spark/sql/vectorized/ArrowColumnVector.java | 5 +++-- .../apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java | 3 ++- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java index 317ac45bcfd74..f1ef411423162 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java @@ -28,8 +28,10 @@ /** * A data source writer that is returned by * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}/ - * {@link org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport#createMicroBatchWriter(String, long, StructType, OutputMode, DataSourceV2Options)}/ - * {@link org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport#createContinuousWriter(String, StructType, OutputMode, DataSourceV2Options)}. + * {@link org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport#createMicroBatchWriter( + * String, long, StructType, OutputMode, DataSourceV2Options)}/ + * {@link org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport#createContinuousWriter( + * String, StructType, OutputMode, DataSourceV2Options)}. * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. * diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index eb69001fe677e..bfd1b4cb0ef12 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -556,8 +556,9 @@ final int getArrayOffset(int rowId) { /** * Any call to "get" method will throw UnsupportedOperationException. * - * Access struct values in a ArrowColumnVector doesn't use this accessor. Instead, it uses getStruct() method defined - * in the parent class. Any call to "get" method in this class is a bug in the code. + * Access struct values in a ArrowColumnVector doesn't use this accessor. Instead, it uses + * getStruct() method defined in the parent class. Any call to "get" method in this class is a + * bug in the code. * */ private static class StructAccessor extends ArrowVectorAccessor { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java index 44e5146d7c553..98d6a53b54d28 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -69,7 +69,8 @@ public DataReader createDataReader() { ColumnVector[] vectors = new ColumnVector[2]; vectors[0] = i; vectors[1] = j; - this.batch = new ColumnarBatch(new StructType().add("i", "int").add("j", "int"), vectors, BATCH_SIZE); + this.batch = + new ColumnarBatch(new StructType().add("i", "int").add("j", "int"), vectors, BATCH_SIZE); return this; } From 60203fca6a605ad158184e1e0ce5187e144a3ea7 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Fri, 19 Jan 2018 12:43:23 +0200 Subject: [PATCH 0143/2461] [SPARK-23127][DOC] Update FeatureHasher guide for categoricalCols parameter Update user guide entry for `FeatureHasher` to match the Scala / Python doc, to describe the `categoricalCols` parameter. ## How was this patch tested? Doc only Author: Nick Pentreath Closes #20293 from MLnick/SPARK-23127-catCol-userguide. --- docs/ml-features.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 72643137d96b1..10183c3e78c76 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -222,9 +222,9 @@ The `FeatureHasher` transformer operates on multiple columns. Each column may co numeric or categorical features. Behavior and handling of column data types is as follows: - Numeric columns: For numeric features, the hash value of the column name is used to map the -feature value to its index in the feature vector. Numeric features are never treated as -categorical, even when they are integers. You must explicitly convert numeric columns containing -categorical features to strings first. +feature value to its index in the feature vector. By default, numeric features are not treated +as categorical (even when they are integers). To treat them as categorical, specify the relevant +columns using the `categoricalCols` parameter. - String columns: For categorical features, the hash value of the string "column_name=value" is used to map to the vector index, with an indicator value of `1.0`. Thus, categorical features are "one-hot" encoded (similarly to using [OneHotEncoder](ml-features.html#onehotencoder) with From b74366481cc87490adf4e69d26389ec737548c15 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 19 Jan 2018 12:48:42 +0200 Subject: [PATCH 0144/2461] [SPARK-23048][ML] Add OneHotEncoderEstimator document and examples ## What changes were proposed in this pull request? We have `OneHotEncoderEstimator` now and `OneHotEncoder` will be deprecated since 2.3.0. We should add `OneHotEncoderEstimator` into mllib document. We also need to provide corresponding examples for `OneHotEncoderEstimator` which are used in the document too. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #20257 from viirya/SPARK-23048. --- docs/ml-features.md | 28 ++++++++----- ...=> JavaOneHotEncoderEstimatorExample.java} | 41 ++++++++----------- ...py => onehot_encoder_estimator_example.py} | 29 +++++++------ ...la => OneHotEncoderEstimatorExample.scala} | 40 ++++++++---------- 4 files changed, 68 insertions(+), 70 deletions(-) rename examples/src/main/java/org/apache/spark/examples/ml/{JavaOneHotEncoderExample.java => JavaOneHotEncoderEstimatorExample.java} (62%) rename examples/src/main/python/ml/{onehot_encoder_example.py => onehot_encoder_estimator_example.py} (65%) rename examples/src/main/scala/org/apache/spark/examples/ml/{OneHotEncoderExample.scala => OneHotEncoderEstimatorExample.scala} (65%) diff --git a/docs/ml-features.md b/docs/ml-features.md index 10183c3e78c76..466a8fbe99cf6 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -775,35 +775,43 @@ for more details on the API.
-## OneHotEncoder +## OneHotEncoder (Deprecated since 2.3.0) -[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features. +Because this existing `OneHotEncoder` is a stateless transformer, it is not usable on new data where the number of categories may differ from the training data. In order to fix this, a new `OneHotEncoderEstimator` was created that produces an `OneHotEncoderModel` when fitting. For more detail, please see [SPARK-13030](https://issues.apache.org/jira/browse/SPARK-13030). + +`OneHotEncoder` has been deprecated in 2.3.0 and will be removed in 3.0.0. Please use [OneHotEncoderEstimator](ml-features.html#onehotencoderestimator) instead. + +## OneHotEncoderEstimator + +[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a categorical feature, represented as a label index, to a binary vector with at most a single one-value indicating the presence of a specific feature value from among the set of all feature values. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features. For string type input data, it is common to encode categorical features using [StringIndexer](ml-features.html#stringindexer) first. + +`OneHotEncoderEstimator` can transform multiple columns, returning an one-hot-encoded output vector column for each input column. It is common to merge these vectors into a single feature vector using [VectorAssembler](ml-features.html#vectorassembler). + +`OneHotEncoderEstimator` supports the `handleInvalid` parameter to choose how to handle invalid input during transforming data. Available options include 'keep' (any invalid inputs are assigned to an extra categorical index) and 'error' (throw an error). **Examples**
-Refer to the [OneHotEncoder Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoder) -for more details on the API. +Refer to the [OneHotEncoderEstimator Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoderEstimator) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala %} +{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala %}
-Refer to the [OneHotEncoder Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoder.html) +Refer to the [OneHotEncoderEstimator Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoderEstimator.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java %} +{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java %}
-Refer to the [OneHotEncoder Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoder) -for more details on the API. +Refer to the [OneHotEncoderEstimator Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoderEstimator) for more details on the API. -{% include_example python/ml/onehot_encoder_example.py %} +{% include_example python/ml/onehot_encoder_estimator_example.py %}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java similarity index 62% rename from examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java rename to examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java index 99af37676ba98..6f93cff94b725 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java @@ -23,9 +23,8 @@ import java.util.Arrays; import java.util.List; -import org.apache.spark.ml.feature.OneHotEncoder; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.ml.feature.StringIndexerModel; +import org.apache.spark.ml.feature.OneHotEncoderEstimator; +import org.apache.spark.ml.feature.OneHotEncoderModel; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; @@ -35,41 +34,37 @@ import org.apache.spark.sql.types.StructType; // $example off$ -public class JavaOneHotEncoderExample { +public class JavaOneHotEncoderEstimatorExample { public static void main(String[] args) { SparkSession spark = SparkSession .builder() - .appName("JavaOneHotEncoderExample") + .appName("JavaOneHotEncoderEstimatorExample") .getOrCreate(); + // Note: categorical features are usually first encoded with StringIndexer // $example on$ List data = Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") + RowFactory.create(0.0, 1.0), + RowFactory.create(1.0, 0.0), + RowFactory.create(2.0, 1.0), + RowFactory.create(0.0, 2.0), + RowFactory.create(0.0, 1.0), + RowFactory.create(2.0, 0.0) ); StructType schema = new StructType(new StructField[]{ - new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), - new StructField("category", DataTypes.StringType, false, Metadata.empty()) + new StructField("categoryIndex1", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("categoryIndex2", DataTypes.DoubleType, false, Metadata.empty()) }); Dataset df = spark.createDataFrame(data, schema); - StringIndexerModel indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df); - Dataset indexed = indexer.transform(df); + OneHotEncoderEstimator encoder = new OneHotEncoderEstimator() + .setInputCols(new String[] {"categoryIndex1", "categoryIndex2"}) + .setOutputCols(new String[] {"categoryVec1", "categoryVec2"}); - OneHotEncoder encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec"); - - Dataset encoded = encoder.transform(indexed); + OneHotEncoderModel model = encoder.fit(df); + Dataset encoded = model.transform(df); encoded.show(); // $example off$ diff --git a/examples/src/main/python/ml/onehot_encoder_example.py b/examples/src/main/python/ml/onehot_encoder_estimator_example.py similarity index 65% rename from examples/src/main/python/ml/onehot_encoder_example.py rename to examples/src/main/python/ml/onehot_encoder_estimator_example.py index e1996c7f0a55b..2723e681cea7c 100644 --- a/examples/src/main/python/ml/onehot_encoder_example.py +++ b/examples/src/main/python/ml/onehot_encoder_estimator_example.py @@ -18,32 +18,31 @@ from __future__ import print_function # $example on$ -from pyspark.ml.feature import OneHotEncoder, StringIndexer +from pyspark.ml.feature import OneHotEncoderEstimator # $example off$ from pyspark.sql import SparkSession if __name__ == "__main__": spark = SparkSession\ .builder\ - .appName("OneHotEncoderExample")\ + .appName("OneHotEncoderEstimatorExample")\ .getOrCreate() + # Note: categorical features are usually first encoded with StringIndexer # $example on$ df = spark.createDataFrame([ - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") - ], ["id", "category"]) + (0.0, 1.0), + (1.0, 0.0), + (2.0, 1.0), + (0.0, 2.0), + (0.0, 1.0), + (2.0, 0.0) + ], ["categoryIndex1", "categoryIndex2"]) - stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") - model = stringIndexer.fit(df) - indexed = model.transform(df) - - encoder = OneHotEncoder(inputCol="categoryIndex", outputCol="categoryVec") - encoded = encoder.transform(indexed) + encoder = OneHotEncoderEstimator(inputCols=["categoryIndex1", "categoryIndex2"], + outputCols=["categoryVec1", "categoryVec2"]) + model = encoder.fit(df) + encoded = model.transform(df) encoded.show() # $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala similarity index 65% rename from examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala rename to examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala index 274cc1268f4d1..45d816808ed8e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala @@ -19,38 +19,34 @@ package org.apache.spark.examples.ml // $example on$ -import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} +import org.apache.spark.ml.feature.OneHotEncoderEstimator // $example off$ import org.apache.spark.sql.SparkSession -object OneHotEncoderExample { +object OneHotEncoderEstimatorExample { def main(args: Array[String]): Unit = { val spark = SparkSession .builder - .appName("OneHotEncoderExample") + .appName("OneHotEncoderEstimatorExample") .getOrCreate() + // Note: categorical features are usually first encoded with StringIndexer // $example on$ val df = spark.createDataFrame(Seq( - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") - )).toDF("id", "category") - - val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df) - val indexed = indexer.transform(df) - - val encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec") - - val encoded = encoder.transform(indexed) + (0.0, 1.0), + (1.0, 0.0), + (2.0, 1.0), + (0.0, 2.0), + (0.0, 1.0), + (2.0, 0.0) + )).toDF("categoryIndex1", "categoryIndex2") + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("categoryIndex1", "categoryIndex2")) + .setOutputCols(Array("categoryVec1", "categoryVec2")) + val model = encoder.fit(df) + + val encoded = model.transform(df) encoded.show() // $example off$ From e41400c3c8aace9eb72e6134173f222627fb0faf Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 19 Jan 2018 19:46:48 +0800 Subject: [PATCH 0145/2461] [SPARK-23089][STS] Recreate session log directory if it doesn't exist ## What changes were proposed in this pull request? When creating a session directory, Thrift should create the parent directory (i.e. /tmp/base_session_log_dir) if it is not present. It is common that many tools delete empty directories, so the directory may be deleted. This can cause the session log to be disabled. This was fixed in HIVE-12262: this PR brings it in Spark too. ## How was this patch tested? manual tests Author: Marco Gaido Closes #20281 from mgaido91/SPARK-23089. --- .../hive/service/cli/session/HiveSessionImpl.java | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java index 47bfaa86021d6..108074cce3d6d 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java @@ -223,6 +223,18 @@ private void configureSession(Map sessionConfMap) throws HiveSQL @Override public void setOperationLogSessionDir(File operationLogRootDir) { + if (!operationLogRootDir.exists()) { + LOG.warn("The operation log root directory is removed, recreating: " + + operationLogRootDir.getAbsolutePath()); + if (!operationLogRootDir.mkdirs()) { + LOG.warn("Unable to create operation log root directory: " + + operationLogRootDir.getAbsolutePath()); + } + } + if (!operationLogRootDir.canWrite()) { + LOG.warn("The operation log root directory is not writable: " + + operationLogRootDir.getAbsolutePath()); + } sessionLogDir = new File(operationLogRootDir, sessionHandle.getHandleIdentifier().toString()); isOperationLogEnabled = true; if (!sessionLogDir.exists()) { From e1c33b6cd14e4e1123814f4d040e3520db7d1ec9 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Fri, 19 Jan 2018 08:22:24 -0600 Subject: [PATCH 0146/2461] [SPARK-23024][WEB-UI] Spark ui about the contents of the form need to have hidden and show features, when the table records very much. ## What changes were proposed in this pull request? Spark ui about the contents of the form need to have hidden and show features, when the table records very much. Because sometimes you do not care about the record of the table, you just want to see the contents of the next table, but you have to scroll the scroll bar for a long time to see the contents of the next table. Currently we have about 500 workers, but I just wanted to see the logs for the running applications table. I had to scroll through the scroll bars for a long time to see the logs for the running applications table. In order to ensure functional consistency, I modified the Master Page, Worker Page, Job Page, Stage Page, Task Page, Configuration Page, Storage Page, Pool Page. fix before: ![1](https://user-images.githubusercontent.com/26266482/34805936-601ed628-f6bb-11e7-8dd3-d8413573a076.png) fix after: ![2](https://user-images.githubusercontent.com/26266482/34805949-6af8afba-f6bb-11e7-89f4-ab16584916fb.png) ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #20216 from guoxiaolongzte/SPARK-23024. --- .../org/apache/spark/ui/static/webui.js | 30 ++++++++ .../deploy/master/ui/ApplicationPage.scala | 25 +++++-- .../spark/deploy/master/ui/MasterPage.scala | 63 ++++++++++++++--- .../spark/deploy/worker/ui/WorkerPage.scala | 52 +++++++++++--- .../apache/spark/ui/env/EnvironmentPage.scala | 48 +++++++++++-- .../apache/spark/ui/jobs/AllJobsPage.scala | 39 +++++++++-- .../apache/spark/ui/jobs/AllStagesPage.scala | 67 +++++++++++++++--- .../org/apache/spark/ui/jobs/JobPage.scala | 68 ++++++++++++++++--- .../org/apache/spark/ui/jobs/PoolPage.scala | 13 +++- .../org/apache/spark/ui/jobs/StagePage.scala | 12 +++- .../apache/spark/ui/storage/StoragePage.scala | 12 +++- 11 files changed, 373 insertions(+), 56 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index 0fa1fcf25f8b9..e575c4c78970d 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -50,4 +50,34 @@ function collapseTable(thisName, table){ // to remember if it's collapsed on each page reload $(function() { collapseTablePageLoad('collapse-aggregated-metrics','aggregated-metrics'); + collapseTablePageLoad('collapse-aggregated-executors','aggregated-executors'); + collapseTablePageLoad('collapse-aggregated-removedExecutors','aggregated-removedExecutors'); + collapseTablePageLoad('collapse-aggregated-workers','aggregated-workers'); + collapseTablePageLoad('collapse-aggregated-activeApps','aggregated-activeApps'); + collapseTablePageLoad('collapse-aggregated-activeDrivers','aggregated-activeDrivers'); + collapseTablePageLoad('collapse-aggregated-completedApps','aggregated-completedApps'); + collapseTablePageLoad('collapse-aggregated-completedDrivers','aggregated-completedDrivers'); + collapseTablePageLoad('collapse-aggregated-runningExecutors','aggregated-runningExecutors'); + collapseTablePageLoad('collapse-aggregated-runningDrivers','aggregated-runningDrivers'); + collapseTablePageLoad('collapse-aggregated-finishedExecutors','aggregated-finishedExecutors'); + collapseTablePageLoad('collapse-aggregated-finishedDrivers','aggregated-finishedDrivers'); + collapseTablePageLoad('collapse-aggregated-runtimeInformation','aggregated-runtimeInformation'); + collapseTablePageLoad('collapse-aggregated-sparkProperties','aggregated-sparkProperties'); + collapseTablePageLoad('collapse-aggregated-systemProperties','aggregated-systemProperties'); + collapseTablePageLoad('collapse-aggregated-classpathEntries','aggregated-classpathEntries'); + collapseTablePageLoad('collapse-aggregated-activeJobs','aggregated-activeJobs'); + collapseTablePageLoad('collapse-aggregated-completedJobs','aggregated-completedJobs'); + collapseTablePageLoad('collapse-aggregated-failedJobs','aggregated-failedJobs'); + collapseTablePageLoad('collapse-aggregated-poolTable','aggregated-poolTable'); + collapseTablePageLoad('collapse-aggregated-allActiveStages','aggregated-allActiveStages'); + collapseTablePageLoad('collapse-aggregated-allPendingStages','aggregated-allPendingStages'); + collapseTablePageLoad('collapse-aggregated-allCompletedStages','aggregated-allCompletedStages'); + collapseTablePageLoad('collapse-aggregated-allFailedStages','aggregated-allFailedStages'); + collapseTablePageLoad('collapse-aggregated-activeStages','aggregated-activeStages'); + collapseTablePageLoad('collapse-aggregated-pendingOrSkippedStages','aggregated-pendingOrSkippedStages'); + collapseTablePageLoad('collapse-aggregated-completedStages','aggregated-completedStages'); + collapseTablePageLoad('collapse-aggregated-failedStages','aggregated-failedStages'); + collapseTablePageLoad('collapse-aggregated-poolActiveStages','aggregated-poolActiveStages'); + collapseTablePageLoad('collapse-aggregated-tasks','aggregated-tasks'); + collapseTablePageLoad('collapse-aggregated-rdds','aggregated-rdds'); }); \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 68e57b7564ad1..f699c75085fe1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -100,12 +100,29 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
-

Executor Summary ({allExecutors.length})

- {executorsTable} + +

+ + Executor Summary ({allExecutors.length}) +

+
+
+ {executorsTable} +
{ if (removedExecutors.nonEmpty) { -

Removed Executors ({removedExecutors.length})

++ - removedExecutorsTable + +

+ + Removed Executors ({removedExecutors.length}) +

+
++ +
+ {removedExecutorsTable} +
} }
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index bc0bf6a1d9700..c629937606b51 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -128,15 +128,31 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
-

Workers ({workers.length})

- {workerTable} + +

+ + Workers ({workers.length}) +

+
+
+ {workerTable} +
-

Running Applications ({activeApps.length})

- {activeAppsTable} + +

+ + Running Applications ({activeApps.length}) +

+
+
+ {activeAppsTable} +
@@ -144,8 +160,17 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {if (hasDrivers) {
-

Running Drivers ({activeDrivers.length})

- {activeDriversTable} + +

+ + Running Drivers ({activeDrivers.length}) +

+
+
+ {activeDriversTable} +
} @@ -154,8 +179,17 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
-

Completed Applications ({completedApps.length})

- {completedAppsTable} + +

+ + Completed Applications ({completedApps.length}) +

+
+
+ {completedAppsTable} +
@@ -164,8 +198,17 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { if (hasDrivers) {
-

Completed Drivers ({completedDrivers.length})

- {completedDriversTable} + +

+ + Completed Drivers ({completedDrivers.length}) +

+
+
+ {completedDriversTable} +
} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index ce84bc4dae32c..8b98ae56fc108 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -77,24 +77,60 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") {
-

Running Executors ({runningExecutors.size})

- {runningExecutorTable} + +

+ + Running Executors ({runningExecutors.size}) +

+
+
+ {runningExecutorTable} +
{ if (runningDrivers.nonEmpty) { -

Running Drivers ({runningDrivers.size})

++ - runningDriverTable + +

+ + Running Drivers ({runningDrivers.size}) +

+
++ +
+ {runningDriverTable} +
} } { if (finishedExecutors.nonEmpty) { -

Finished Executors ({finishedExecutors.size})

++ - finishedExecutorTable + +

+ + Finished Executors ({finishedExecutors.size}) +

+
++ +
+ {finishedExecutorTable} +
} } { if (finishedDrivers.nonEmpty) { -

Finished Drivers ({finishedDrivers.size})

++ - finishedDriverTable + +

+ + Finished Drivers ({finishedDrivers.size}) +

+
++ +
+ {finishedDriverTable} +
} }
diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala index 43adab7a35d65..902eb92b854f2 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala @@ -48,10 +48,50 @@ private[ui] class EnvironmentPage( classPathHeaders, classPathRow, appEnv.classpathEntries, fixedWidth = true) val content = -

Runtime Information

{runtimeInformationTable} -

Spark Properties

{sparkPropertiesTable} -

System Properties

{systemPropertiesTable} -

Classpath Entries

{classpathEntriesTable} + +

+ + Runtime Information +

+
+
+ {runtimeInformationTable} +
+ +

+ + Spark Properties +

+
+
+ {sparkPropertiesTable} +
+ +

+ + System Properties +

+
+
+ {systemPropertiesTable} +
+ +

+ + Classpath Entries +

+
+
+ {classpathEntriesTable} +
UIUtils.headerSparkPage("Environment", content, parent) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index ff916bb6a5759..e3b72f1f34859 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -363,16 +363,43 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We store.executorList(false), startTime) if (shouldShowActiveJobs) { - content ++=

Active Jobs ({activeJobs.size})

++ - activeJobsTable + content ++= + +

+ + Active Jobs ({activeJobs.size}) +

+
++ +
+ {activeJobsTable} +
} if (shouldShowCompletedJobs) { - content ++=

Completed Jobs ({completedJobNumStr})

++ - completedJobsTable + content ++= + +

+ + Completed Jobs ({completedJobNumStr}) +

+
++ +
+ {completedJobsTable} +
} if (shouldShowFailedJobs) { - content ++=

Failed Jobs ({failedJobs.size})

++ - failedJobsTable + content ++= + +

+ + Failed Jobs ({failedJobs.size}) +

+
++ +
+ {failedJobsTable} +
} val helpText = """A job is triggered by an action, like count() or saveAsTextFile().""" + diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index b1e343451e28e..606dc1e180e5b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -116,26 +116,75 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { var content = summary ++ { if (sc.isDefined && isFairScheduler) { -

Fair Scheduler Pools ({pools.size})

++ poolTable.toNodeSeq + +

+ + Fair Scheduler Pools ({pools.size}) +

+
++ +
+ {poolTable.toNodeSeq} +
} else { Seq.empty[Node] } } if (shouldShowActiveStages) { - content ++=

Active Stages ({activeStages.size})

++ - activeStagesTable.toNodeSeq + content ++= + +

+ + Active Stages ({activeStages.size}) +

+
++ +
+ {activeStagesTable.toNodeSeq} +
} if (shouldShowPendingStages) { - content ++=

Pending Stages ({pendingStages.size})

++ - pendingStagesTable.toNodeSeq + content ++= + +

+ + Pending Stages ({pendingStages.size}) +

+
++ +
+ {pendingStagesTable.toNodeSeq} +
} if (shouldShowCompletedStages) { - content ++=

Completed Stages ({completedStageNumStr})

++ - completedStagesTable.toNodeSeq + content ++= + +

+ + Completed Stages ({completedStageNumStr}) +

+
++ +
+ {completedStagesTable.toNodeSeq} +
} if (shouldShowFailedStages) { - content ++=

Failed Stages ({numFailedStages})

++ - failedStagesTable.toNodeSeq + content ++= + +

+ + Failed Stages ({numFailedStages}) +

+
++ +
+ {failedStagesTable.toNodeSeq} +
} UIUtils.headerSparkPage("Stages for All Jobs", content, parent) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index bf59152c8c0cd..c27f30c21a843 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -340,24 +340,72 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP jobId, store.operationGraphForJob(jobId)) if (shouldShowActiveStages) { - content ++=

Active Stages ({activeStages.size})

++ - activeStagesTable.toNodeSeq + content ++= + +

+ + Active Stages ({activeStages.size}) +

+
++ +
+ {activeStagesTable.toNodeSeq} +
} if (shouldShowPendingStages) { - content ++=

Pending Stages ({pendingOrSkippedStages.size})

++ - pendingOrSkippedStagesTable.toNodeSeq + content ++= + +

+ + Pending Stages ({pendingOrSkippedStages.size}) +

+
++ +
+ {pendingOrSkippedStagesTable.toNodeSeq} +
} if (shouldShowCompletedStages) { - content ++=

Completed Stages ({completedStages.size})

++ - completedStagesTable.toNodeSeq + content ++= + +

+ + Completed Stages ({completedStages.size}) +

+
++ +
+ {completedStagesTable.toNodeSeq} +
} if (shouldShowSkippedStages) { - content ++=

Skipped Stages ({pendingOrSkippedStages.size})

++ - pendingOrSkippedStagesTable.toNodeSeq + content ++= + +

+ + Skipped Stages ({pendingOrSkippedStages.size}) +

+
++ +
+ {pendingOrSkippedStagesTable.toNodeSeq} +
} if (shouldShowFailedStages) { - content ++=

Failed Stages ({failedStages.size})

++ - failedStagesTable.toNodeSeq + content ++= + +

+ + Failed Stages ({failedStages.size}) +

+
++ +
+ {failedStagesTable.toNodeSeq} +
} UIUtils.headerSparkPage(s"Details for Job $jobId", content, parent, showVisualization = true) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 98fbd7aceaa11..a3e1f13782e30 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -51,7 +51,18 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { val poolTable = new PoolTable(Map(pool -> uiPool), parent) var content =

Summary

++ poolTable.toNodeSeq if (activeStages.nonEmpty) { - content ++=

Active Stages ({activeStages.size})

++ activeStagesTable.toNodeSeq + content ++= + +

+ + Active Stages ({activeStages.size}) +

+
++ +
+ {activeStagesTable.toNodeSeq} +
} UIUtils.headerSparkPage("Fair Scheduler Pool: " + poolName, content, parent) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index af78373ddb4b2..25bee33028393 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -486,8 +486,16 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
{summaryTable.getOrElse("No tasks have reported metrics yet.")}
++ aggMetrics ++ maybeAccumulableTable ++ -

Tasks ({totalTasksNumStr})

++ - taskTableHTML ++ jsForScrollingDownToTaskTable + +

+ + Tasks ({totalTasksNumStr}) +

+
++ +
+ {taskTableHTML ++ jsForScrollingDownToTaskTable} +
UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true) } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index b8aec9890247a..68d946574a37b 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -41,8 +41,16 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends Nil } else {
-

RDDs

- {UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))} + +

+ + RDDs +

+
+
+ {UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))} +
} } From 6c39654efcb2aa8cb4d082ab7277a6fa38fb48e4 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 19 Jan 2018 22:47:18 +0800 Subject: [PATCH 0147/2461] [SPARK-23000][TEST] Keep Derby DB Location Unchanged After Session Cloning ## What changes were proposed in this pull request? After session cloning in `TestHive`, the conf of the singleton SparkContext for derby DB location is changed to a new directory. The new directory is created in `HiveUtils.newTemporaryConfiguration(useInMemoryDerby = false)`. This PR is to keep the conf value of `ConfVars.METASTORECONNECTURLKEY.varname` unchanged during the session clone. ## How was this patch tested? The issue can be reproduced by the command: > build/sbt -Phive "hive/test-only org.apache.spark.sql.hive.HiveSessionStateSuite org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite" Also added a test case. Author: gatorsmile Closes #20328 from gatorsmile/fixTestFailure. --- .../org/apache/spark/sql/SessionStateSuite.scala | 5 +---- .../apache/spark/sql/hive/test/TestHive.scala | 8 +++++++- .../spark/sql/hive/HiveSessionStateSuite.scala | 16 +++++++++++++--- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index 5d75f5835bf9e..4efae4c46c2e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import org.scalatest.BeforeAndAfterAll -import org.scalatest.BeforeAndAfterEach import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkFunSuite @@ -28,8 +26,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.util.QueryExecutionListener -class SessionStateSuite extends SparkFunSuite - with BeforeAndAfterEach with BeforeAndAfterAll { +class SessionStateSuite extends SparkFunSuite { /** * A shared SparkSession for all tests in this suite. Make sure you reset any changes to this diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index b6be00dbb3a73..c84131fc3212a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -180,7 +180,13 @@ private[hive] class TestHiveSparkSession( ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true", // scratch directory used by Hive's metastore client ConfVars.SCRATCHDIR.varname -> TestHiveContext.makeScratchDir().toURI.toString, - ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY.varname -> "1") + ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY.varname -> "1") ++ + // After session cloning, the JDBC connect string for a JDBC metastore should not be changed. + existingSharedState.map { state => + val connKey = + state.sparkContext.hadoopConfiguration.get(ConfVars.METASTORECONNECTURLKEY.varname) + ConfVars.METASTORECONNECTURLKEY.varname -> connKey + } metastoreTempConf.foreach { case (k, v) => sc.hadoopConfiguration.set(k, v) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala index f7da3c4cbb0aa..ecc09cdcdbeaf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import org.scalatest.BeforeAndAfterEach +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHiveSingleton @@ -25,8 +25,7 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton /** * Run all tests from `SessionStateSuite` with a Hive based `SessionState`. */ -class HiveSessionStateSuite extends SessionStateSuite - with TestHiveSingleton with BeforeAndAfterEach { +class HiveSessionStateSuite extends SessionStateSuite with TestHiveSingleton { override def beforeAll(): Unit = { // Reuse the singleton session @@ -39,4 +38,15 @@ class HiveSessionStateSuite extends SessionStateSuite activeSession = null super.afterAll() } + + test("Clone then newSession") { + val sparkSession = hiveContext.sparkSession + val conf = sparkSession.sparkContext.hadoopConfiguration + val oldValue = conf.get(ConfVars.METASTORECONNECTURLKEY.varname) + sparkSession.cloneSession() + sparkSession.sharedState.externalCatalog.client.newSession() + val newValue = conf.get(ConfVars.METASTORECONNECTURLKEY.varname) + assert(oldValue == newValue, + "cloneSession and then newSession should not affect the Derby directory") + } } From 606a7485f12c5d5377c50258006c353ba5e49c3f Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Fri, 19 Jan 2018 09:28:35 -0600 Subject: [PATCH 0148/2461] [SPARK-23085][ML] API parity for mllib.linalg.Vectors.sparse ## What changes were proposed in this pull request? `ML.Vectors#sparse(size: Int, elements: Seq[(Int, Double)])` support zero-length ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #20275 from zhengruifeng/SparseVector_size. --- .../scala/org/apache/spark/ml/linalg/Vectors.scala | 2 +- .../org/apache/spark/ml/linalg/VectorsSuite.scala | 14 ++++++++++++++ .../org/apache/spark/mllib/linalg/Vectors.scala | 3 +-- .../apache/spark/mllib/linalg/VectorsSuite.scala | 14 ++++++++++++++ 4 files changed, 30 insertions(+), 3 deletions(-) diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index 941b6eca568d3..5824e463ca1aa 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -565,7 +565,7 @@ class SparseVector @Since("2.0.0") ( // validate the data { - require(size >= 0, "The size of the requested sparse vector must be greater than 0.") + require(size >= 0, "The size of the requested sparse vector must be no less than 0.") require(indices.length == values.length, "Sparse vectors require that the dimension of the" + s" indices match the dimension of the values. You provided ${indices.length} indices and " + s" ${values.length} values.") diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala index 79acef8214d88..0a316f57f811b 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala @@ -366,4 +366,18 @@ class VectorsSuite extends SparkMLFunSuite { assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2))) assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4))) } + + test("sparse vector only support non-negative length") { + val v1 = Vectors.sparse(0, Array.emptyIntArray, Array.emptyDoubleArray) + val v2 = Vectors.sparse(0, Array.empty[(Int, Double)]) + assert(v1.size === 0) + assert(v2.size === 0) + + intercept[IllegalArgumentException] { + Vectors.sparse(-1, Array(1), Array(2.0)) + } + intercept[IllegalArgumentException] { + Vectors.sparse(-1, Array((1, 2.0))) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index fd9605c013625..6e68d9684a672 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -326,8 +326,6 @@ object Vectors { */ @Since("1.0.0") def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = { - require(size > 0, "The size of the requested sparse vector must be greater than 0.") - val (indices, values) = elements.sortBy(_._1).unzip var prev = -1 indices.foreach { i => @@ -758,6 +756,7 @@ class SparseVector @Since("1.0.0") ( @Since("1.0.0") val indices: Array[Int], @Since("1.0.0") val values: Array[Double]) extends Vector { + require(size >= 0, "The size of the requested sparse vector must be no less than 0.") require(indices.length == values.length, "Sparse vectors require that the dimension of the" + s" indices match the dimension of the values. You provided ${indices.length} indices and " + s" ${values.length} values.") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 4074bead421e6..217b4a35438fd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -495,4 +495,18 @@ class VectorsSuite extends SparkFunSuite with Logging { assert(mlDenseVectorToArray(dv) === mlDenseVectorToArray(newDV)) assert(mlSparseVectorToArray(sv) === mlSparseVectorToArray(newSV)) } + + test("sparse vector only support non-negative length") { + val v1 = Vectors.sparse(0, Array.emptyIntArray, Array.emptyDoubleArray) + val v2 = Vectors.sparse(0, Array.empty[(Int, Double)]) + assert(v1.size === 0) + assert(v2.size === 0) + + intercept[IllegalArgumentException] { + Vectors.sparse(-1, Array(1), Array(2.0)) + } + intercept[IllegalArgumentException] { + Vectors.sparse(-1, Array((1, 2.0))) + } + } } From d8aaa771e249b3f54b57ce24763e53fd65a0dbf7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 19 Jan 2018 08:58:21 -0800 Subject: [PATCH 0149/2461] [SPARK-23149][SQL] polish ColumnarBatch ## What changes were proposed in this pull request? Several cleanups in `ColumnarBatch` * remove `schema`. The `ColumnVector`s inside `ColumnarBatch` already have the data type information, we don't need this `schema`. * remove `capacity`. `ColumnarBatch` is just a wrapper of `ColumnVector`s, not builders, it doesn't need a capacity property. * remove `DEFAULT_BATCH_SIZE`. As a wrapper, `ColumnarBatch` can't decide the batch size, it should be decided by the reader, e.g. parquet reader, orc reader, cached table reader. The default batch size should also be defined by the reader. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #20316 from cloud-fan/columnar-batch. --- .../orc/OrcColumnarBatchReader.java | 49 +++++++------------ .../SpecificParquetRecordReaderBase.java | 12 ++--- .../VectorizedParquetRecordReader.java | 24 ++++----- .../vectorized/ColumnVectorUtils.java | 18 +++---- .../spark/sql/vectorized/ColumnarBatch.java | 20 +------- .../VectorizedHashMapGenerator.scala | 2 +- .../sql/execution/arrow/ArrowConverters.scala | 2 +- .../columnar/InMemoryTableScanExec.scala | 5 +- .../python/ArrowEvalPythonExec.scala | 8 +-- .../execution/python/ArrowPythonRunner.scala | 2 +- .../sql/sources/v2/JavaBatchDataSourceV2.java | 3 +- .../vectorized/ColumnarBatchSuite.scala | 7 ++- .../sql/sources/v2/DataSourceV2Suite.scala | 3 +- 13 files changed, 61 insertions(+), 94 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index 36fdf2bdf84d2..89bae4326e93b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -49,18 +49,8 @@ * After creating, `initialize` and `initBatch` should be called sequentially. */ public class OrcColumnarBatchReader extends RecordReader { - - /** - * The default size of batch. We use this value for ORC reader to make it consistent with Spark's - * columnar batch, because their default batch sizes are different like the following: - * - * - ORC's VectorizedRowBatch.DEFAULT_SIZE = 1024 - * - Spark's ColumnarBatch.DEFAULT_BATCH_SIZE = 4 * 1024 - */ - private static final int DEFAULT_SIZE = 4 * 1024; - - // ORC File Reader - private Reader reader; + // TODO: make this configurable. + private static final int CAPACITY = 4 * 1024; // Vectorized ORC Row Batch private VectorizedRowBatch batch; @@ -98,22 +88,22 @@ public OrcColumnarBatchReader(boolean useOffHeap, boolean copyToSpark) { @Override - public Void getCurrentKey() throws IOException, InterruptedException { + public Void getCurrentKey() { return null; } @Override - public ColumnarBatch getCurrentValue() throws IOException, InterruptedException { + public ColumnarBatch getCurrentValue() { return columnarBatch; } @Override - public float getProgress() throws IOException, InterruptedException { + public float getProgress() throws IOException { return recordReader.getProgress(); } @Override - public boolean nextKeyValue() throws IOException, InterruptedException { + public boolean nextKeyValue() throws IOException { return nextBatch(); } @@ -134,16 +124,15 @@ public void close() throws IOException { * Please note that `initBatch` is needed to be called after this. */ @Override - public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) - throws IOException, InterruptedException { + public void initialize( + InputSplit inputSplit, TaskAttemptContext taskAttemptContext) throws IOException { FileSplit fileSplit = (FileSplit)inputSplit; Configuration conf = taskAttemptContext.getConfiguration(); - reader = OrcFile.createReader( + Reader reader = OrcFile.createReader( fileSplit.getPath(), OrcFile.readerOptions(conf) .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf)) .filesystem(fileSplit.getPath().getFileSystem(conf))); - Reader.Options options = OrcInputFormat.buildOptions(conf, reader, fileSplit.getStart(), fileSplit.getLength()); recordReader = reader.rows(options); @@ -159,7 +148,7 @@ public void initBatch( StructField[] requiredFields, StructType partitionSchema, InternalRow partitionValues) { - batch = orcSchema.createRowBatch(DEFAULT_SIZE); + batch = orcSchema.createRowBatch(CAPACITY); assert(!batch.selectedInUse); // `selectedInUse` should be initialized with `false`. this.requiredFields = requiredFields; @@ -171,19 +160,17 @@ public void initBatch( resultSchema = resultSchema.add(f); } - int capacity = DEFAULT_SIZE; - if (copyToSpark) { if (MEMORY_MODE == MemoryMode.OFF_HEAP) { - columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); + columnVectors = OffHeapColumnVector.allocateColumns(CAPACITY, resultSchema); } else { - columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); + columnVectors = OnHeapColumnVector.allocateColumns(CAPACITY, resultSchema); } // Initialize the missing columns once. for (int i = 0; i < requiredFields.length; i++) { if (requestedColIds[i] == -1) { - columnVectors[i].putNulls(0, capacity); + columnVectors[i].putNulls(0, CAPACITY); columnVectors[i].setIsConstant(); } } @@ -196,7 +183,7 @@ public void initBatch( } } - columnarBatch = new ColumnarBatch(resultSchema, columnVectors, capacity); + columnarBatch = new ColumnarBatch(columnVectors); } else { // Just wrap the ORC column vector instead of copying it to Spark column vector. orcVectorWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()]; @@ -206,8 +193,8 @@ public void initBatch( int colId = requestedColIds[i]; // Initialize the missing columns once. if (colId == -1) { - OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); - missingCol.putNulls(0, capacity); + OnHeapColumnVector missingCol = new OnHeapColumnVector(CAPACITY, dt); + missingCol.putNulls(0, CAPACITY); missingCol.setIsConstant(); orcVectorWrappers[i] = missingCol; } else { @@ -219,14 +206,14 @@ public void initBatch( int partitionIdx = requiredFields.length; for (int i = 0; i < partitionValues.numFields(); i++) { DataType dt = partitionSchema.fields()[i].dataType(); - OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt); + OnHeapColumnVector partitionCol = new OnHeapColumnVector(CAPACITY, dt); ColumnVectorUtils.populate(partitionCol, partitionValues, i); partitionCol.setIsConstant(); orcVectorWrappers[partitionIdx + i] = partitionCol; } } - columnarBatch = new ColumnarBatch(resultSchema, orcVectorWrappers, capacity); + columnarBatch = new ColumnarBatch(orcVectorWrappers); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 80c2f491b48ce..e65cd252c3ddf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -170,7 +170,7 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont * Returns the list of files at 'path' recursively. This skips files that are ignored normally * by MapReduce. */ - public static List listDirectory(File path) throws IOException { + public static List listDirectory(File path) { List result = new ArrayList<>(); if (path.isDirectory()) { for (File f: path.listFiles()) { @@ -231,7 +231,7 @@ protected void initialize(String path, List columns) throws IOException } @Override - public Void getCurrentKey() throws IOException, InterruptedException { + public Void getCurrentKey() { return null; } @@ -259,7 +259,7 @@ public ValuesReaderIntIterator(ValuesReader delegate) { } @Override - int nextInt() throws IOException { + int nextInt() { return delegate.readInteger(); } } @@ -279,15 +279,15 @@ int nextInt() throws IOException { protected static final class NullIntIterator extends IntIterator { @Override - int nextInt() throws IOException { return 0; } + int nextInt() { return 0; } } /** * Creates a reader for definition and repetition levels, returning an optimized one if * the levels are not needed. */ - protected static IntIterator createRLEIterator(int maxLevel, BytesInput bytes, - ColumnDescriptor descriptor) throws IOException { + protected static IntIterator createRLEIterator( + int maxLevel, BytesInput bytes, ColumnDescriptor descriptor) throws IOException { try { if (maxLevel == 0) return new NullIntIterator(); return new RLEIntIterator( diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index cd745b1f0e4e3..bb1b23611a7d7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -50,6 +50,9 @@ * TODO: make this always return ColumnarBatches. */ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBase { + // TODO: make this configurable. + private static final int CAPACITY = 4 * 1024; + /** * Batch of rows that we assemble and the current index we've returned. Every time this * batch is used up (batchIdx == numBatched), we populated the batch. @@ -152,7 +155,7 @@ public void close() throws IOException { } @Override - public boolean nextKeyValue() throws IOException, InterruptedException { + public boolean nextKeyValue() throws IOException { resultBatch(); if (returnColumnarBatch) return nextBatch(); @@ -165,13 +168,13 @@ public boolean nextKeyValue() throws IOException, InterruptedException { } @Override - public Object getCurrentValue() throws IOException, InterruptedException { + public Object getCurrentValue() { if (returnColumnarBatch) return columnarBatch; return columnarBatch.getRow(batchIdx - 1); } @Override - public float getProgress() throws IOException, InterruptedException { + public float getProgress() { return (float) rowsReturned / totalRowCount; } @@ -181,7 +184,7 @@ public float getProgress() throws IOException, InterruptedException { // Columns 0,1: data columns // Column 2: partitionValues[0] // Column 3: partitionValues[1] - public void initBatch( + private void initBatch( MemoryMode memMode, StructType partitionColumns, InternalRow partitionValues) { @@ -195,13 +198,12 @@ public void initBatch( } } - int capacity = ColumnarBatch.DEFAULT_BATCH_SIZE; if (memMode == MemoryMode.OFF_HEAP) { - columnVectors = OffHeapColumnVector.allocateColumns(capacity, batchSchema); + columnVectors = OffHeapColumnVector.allocateColumns(CAPACITY, batchSchema); } else { - columnVectors = OnHeapColumnVector.allocateColumns(capacity, batchSchema); + columnVectors = OnHeapColumnVector.allocateColumns(CAPACITY, batchSchema); } - columnarBatch = new ColumnarBatch(batchSchema, columnVectors, capacity); + columnarBatch = new ColumnarBatch(columnVectors); if (partitionColumns != null) { int partitionIdx = sparkSchema.fields().length; for (int i = 0; i < partitionColumns.fields().length; i++) { @@ -213,13 +215,13 @@ public void initBatch( // Initialize missing columns with nulls. for (int i = 0; i < missingColumns.length; i++) { if (missingColumns[i]) { - columnVectors[i].putNulls(0, columnarBatch.capacity()); + columnVectors[i].putNulls(0, CAPACITY); columnVectors[i].setIsConstant(); } } } - public void initBatch() { + private void initBatch() { initBatch(MEMORY_MODE, null, null); } @@ -255,7 +257,7 @@ public boolean nextBatch() throws IOException { if (rowsReturned >= totalRowCount) return false; checkEndOfRowGroup(); - int num = (int) Math.min((long) columnarBatch.capacity(), totalCountLoadedSoFar - rowsReturned); + int num = (int) Math.min((long) CAPACITY, totalCountLoadedSoFar - rowsReturned); for (int i = 0; i < columnReaders.length; ++i) { if (columnReaders[i] == null) continue; columnReaders[i].readBatch(num, columnVectors[i]); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index b5cbe8e2839ba..5ee8cc8da2309 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -118,19 +118,19 @@ private static void appendValue(WritableColumnVector dst, DataType t, Object o) } } else { if (t == DataTypes.BooleanType) { - dst.appendBoolean(((Boolean)o).booleanValue()); + dst.appendBoolean((Boolean) o); } else if (t == DataTypes.ByteType) { - dst.appendByte(((Byte) o).byteValue()); + dst.appendByte((Byte) o); } else if (t == DataTypes.ShortType) { - dst.appendShort(((Short)o).shortValue()); + dst.appendShort((Short) o); } else if (t == DataTypes.IntegerType) { - dst.appendInt(((Integer)o).intValue()); + dst.appendInt((Integer) o); } else if (t == DataTypes.LongType) { - dst.appendLong(((Long)o).longValue()); + dst.appendLong((Long) o); } else if (t == DataTypes.FloatType) { - dst.appendFloat(((Float)o).floatValue()); + dst.appendFloat((Float) o); } else if (t == DataTypes.DoubleType) { - dst.appendDouble(((Double)o).doubleValue()); + dst.appendDouble((Double) o); } else if (t == DataTypes.StringType) { byte[] b =((String)o).getBytes(StandardCharsets.UTF_8); dst.appendByteArray(b, 0, b.length); @@ -192,7 +192,7 @@ private static void appendValue(WritableColumnVector dst, DataType t, Row src, i */ public static ColumnarBatch toBatch( StructType schema, MemoryMode memMode, Iterator row) { - int capacity = ColumnarBatch.DEFAULT_BATCH_SIZE; + int capacity = 4 * 1024; WritableColumnVector[] columnVectors; if (memMode == MemoryMode.OFF_HEAP) { columnVectors = OffHeapColumnVector.allocateColumns(capacity, schema); @@ -208,7 +208,7 @@ public static ColumnarBatch toBatch( } n++; } - ColumnarBatch batch = new ColumnarBatch(schema, columnVectors, capacity); + ColumnarBatch batch = new ColumnarBatch(columnVectors); batch.setNumRows(n); return batch; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java index 9ae1c6d9993f0..4dc826cf60c15 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java @@ -20,7 +20,6 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.MutableColumnarRow; -import org.apache.spark.sql.types.StructType; /** * This class wraps multiple ColumnVectors as a row-wise table. It provides a row view of this @@ -28,10 +27,6 @@ * the entire data loading process. */ public final class ColumnarBatch { - public static final int DEFAULT_BATCH_SIZE = 4 * 1024; - - private final StructType schema; - private final int capacity; private int numRows; private final ColumnVector[] columns; @@ -82,7 +77,6 @@ public void remove() { * Sets the number of rows in this batch. */ public void setNumRows(int numRows) { - assert(numRows <= this.capacity); this.numRows = numRows; } @@ -96,16 +90,6 @@ public void setNumRows(int numRows) { */ public int numRows() { return numRows; } - /** - * Returns the schema that makes up this batch. - */ - public StructType schema() { return schema; } - - /** - * Returns the max capacity (in number of rows) for this batch. - */ - public int capacity() { return capacity; } - /** * Returns the column at `ordinal`. */ @@ -120,10 +104,8 @@ public InternalRow getRow(int rowId) { return row; } - public ColumnarBatch(StructType schema, ColumnVector[] columns, int capacity) { - this.schema = schema; + public ColumnarBatch(ColumnVector[] columns) { this.columns = columns; - this.capacity = capacity; this.row = new MutableColumnarRow(columns); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 0cf9b53ce1d5d..eb48584d0c1ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -94,7 +94,7 @@ class VectorizedHashMapGenerator( | | public $generatedClassName() { | vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, schema); - | batch = new ${classOf[ColumnarBatch].getName}(schema, vectors, capacity); + | batch = new ${classOf[ColumnarBatch].getName}(vectors); | | // Generates a projection to return the aggregate buffer only. | ${classOf[OnHeapColumnVector].getName}[] aggBufferVectors = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index bcd1aa0890ba3..7487564ed64da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -175,7 +175,7 @@ private[sql] object ArrowConverters { new ArrowColumnVector(vector).asInstanceOf[ColumnVector] }.toArray - val batch = new ColumnarBatch(schemaRead, columns, root.getRowCount) + val batch = new ColumnarBatch(columns) batch.setNumRows(root.getRowCount) batch.rowIterator().asScala } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 3565ee3af1b9f..28b3875505cd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -78,11 +78,10 @@ case class InMemoryTableScanExec( } else { OffHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) } - val columnarBatch = new ColumnarBatch( - columnarBatchSchema, columnVectors.asInstanceOf[Array[ColumnVector]], rowCount) + val columnarBatch = new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]]) columnarBatch.setNumRows(rowCount) - for (i <- 0 until attributes.length) { + for (i <- attributes.indices) { ColumnAccessor.decompress( cachedColumnarBatch.buffers(columnIndices(i)), columnarBatch.column(i).asInstanceOf[WritableColumnVector], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index c06bc7b66ff39..47b146f076b62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -74,8 +74,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi schema: StructType, context: TaskContext): Iterator[InternalRow] = { - val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex - .map { case (attr, i) => attr.withName(s"_$i") }) + val outputTypes = output.drop(child.output.length).map(_.dataType) // DO NOT use iter.grouped(). See BatchIterator. val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) @@ -90,8 +89,9 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi private var currentIter = if (columnarBatchIter.hasNext) { val batch = columnarBatchIter.next() - assert(schemaOut.equals(batch.schema), - s"Invalid schema from pandas_udf: expected $schemaOut, got ${batch.schema}") + val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType()) + assert(outputTypes == actualDataTypes, "Invalid schema from pandas_udf: " + + s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}") batch.rowIterator.asScala } else { Iterator.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index dc5ba96e69aec..5fcdcddca7d51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -138,7 +138,7 @@ class ArrowPythonRunner( if (reader != null && batchLoaded) { batchLoaded = reader.loadNextBatch() if (batchLoaded) { - val batch = new ColumnarBatch(schema, vectors, root.getRowCount) + val batch = new ColumnarBatch(vectors) batch.setNumRows(root.getRowCount) batch } else { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java index 98d6a53b54d28..a5d77a90ece42 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -69,8 +69,7 @@ public DataReader createDataReader() { ColumnVector[] vectors = new ColumnVector[2]; vectors[0] = i; vectors[1] = j; - this.batch = - new ColumnarBatch(new StructType().add("i", "int").add("j", "int"), vectors, BATCH_SIZE); + this.batch = new ColumnarBatch(vectors); return this; } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 675f06b31b970..cd90681ecabc6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -875,14 +875,13 @@ class ColumnarBatchSuite extends SparkFunSuite { .add("intCol2", IntegerType) .add("string", BinaryType) - val capacity = ColumnarBatch.DEFAULT_BATCH_SIZE + val capacity = 4 * 1024 val columns = schema.fields.map { field => allocate(capacity, field.dataType, memMode) } - val batch = new ColumnarBatch(schema, columns.toArray, ColumnarBatch.DEFAULT_BATCH_SIZE) + val batch = new ColumnarBatch(columns.toArray) assert(batch.numCols() == 4) assert(batch.numRows() == 0) - assert(batch.capacity() > 0) assert(batch.rowIterator().hasNext == false) // Add a row [1, 1.1, NULL] @@ -1153,7 +1152,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val columnVectors = Seq(new ArrowColumnVector(vector1), new ArrowColumnVector(vector2)) val schema = StructType(Seq(StructField("int1", IntegerType), StructField("int2", IntegerType))) - val batch = new ColumnarBatch(schema, columnVectors.toArray[ColumnVector], 11) + val batch = new ColumnarBatch(columnVectors.toArray) batch.setNumRows(11) assert(batch.numCols() == 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index a89f7c55bf4f7..0ca29524c6d05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -333,8 +333,7 @@ class BatchReadTask(start: Int, end: Int) private final val BATCH_SIZE = 20 private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) - private lazy val batch = new ColumnarBatch( - new StructType().add("i", "int").add("j", "int"), Array(i, j), BATCH_SIZE) + private lazy val batch = new ColumnarBatch(Array(i, j)) private var current = start From 73d3b230f3816a854a181c0912d87b180e347271 Mon Sep 17 00:00:00 2001 From: foxish Date: Fri, 19 Jan 2018 10:23:13 -0800 Subject: [PATCH 0150/2461] [SPARK-23104][K8S][DOCS] Changes to Kubernetes scheduler documentation ## What changes were proposed in this pull request? Docs changes: - Adding a warning that the backend is experimental. - Removing a defunct internal-only option from documentation - Clarifying that node selectors can be used right away, and other minor cosmetic changes ## How was this patch tested? Docs only change Author: foxish Closes #20314 from foxish/ambiguous-docs. --- docs/cluster-overview.md | 4 ++-- docs/running-on-kubernetes.md | 43 ++++++++++++++++------------------- 2 files changed, 22 insertions(+), 25 deletions(-) diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index 658e67f99dd71..7277e2fb2731d 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -52,8 +52,8 @@ The system currently supports three cluster managers: * [Apache Mesos](running-on-mesos.html) -- a general cluster manager that can also run Hadoop MapReduce and service applications. * [Hadoop YARN](running-on-yarn.html) -- the resource manager in Hadoop 2. -* [Kubernetes](running-on-kubernetes.html) -- [Kubernetes](https://kubernetes.io/docs/concepts/overview/what-is-kubernetes/) -is an open-source platform that provides container-centric infrastructure. +* [Kubernetes](running-on-kubernetes.html) -- an open-source system for automating deployment, scaling, + and management of containerized applications. A third-party project (not supported by the Spark project) exists to add support for [Nomad](https://github.com/hashicorp/nomad-spark) as a cluster manager. diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index d6b1735ce5550..3c7586e8544ba 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -8,6 +8,10 @@ title: Running Spark on Kubernetes Spark can run on clusters managed by [Kubernetes](https://kubernetes.io). This feature makes use of native Kubernetes scheduler that has been added to Spark. +**The Kubernetes scheduler is currently experimental. +In future versions, there may be behavioral changes around configuration, +container images and entrypoints.** + # Prerequisites * A runnable distribution of Spark 2.3 or above. @@ -41,11 +45,10 @@ logs and remains in "completed" state in the Kubernetes API until it's eventuall Note that in the completed state, the driver pod does *not* use any computational or memory resources. -The driver and executor pod scheduling is handled by Kubernetes. It will be possible to affect Kubernetes scheduling -decisions for driver and executor pods using advanced primitives like -[node selectors](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#nodeselector) -and [node/pod affinities](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#affinity-and-anti-affinity) -in a future release. +The driver and executor pod scheduling is handled by Kubernetes. It is possible to schedule the +driver and executor pods on a subset of available nodes through a [node selector](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#nodeselector) +using the configuration property for it. It will be possible to use more advanced +scheduling hints like [node/pod affinities](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#affinity-and-anti-affinity) in a future release. # Submitting Applications to Kubernetes @@ -62,8 +65,10 @@ use with the Kubernetes backend. Example usage is: - ./bin/docker-image-tool.sh -r -t my-tag build - ./bin/docker-image-tool.sh -r -t my-tag push +```bash +$ ./bin/docker-image-tool.sh -r -t my-tag build +$ ./bin/docker-image-tool.sh -r -t my-tag push +``` ## Cluster Mode @@ -94,7 +99,7 @@ must consist of lower case alphanumeric characters, `-`, and `.` and must start If you have a Kubernetes cluster setup, one way to discover the apiserver URL is by executing `kubectl cluster-info`. ```bash -kubectl cluster-info +$ kubectl cluster-info Kubernetes master is running at http://127.0.0.1:6443 ``` @@ -105,7 +110,7 @@ authenticating proxy, `kubectl proxy` to communicate to the Kubernetes API. The local proxy can be started by: ```bash -kubectl proxy +$ kubectl proxy ``` If the local proxy is running at localhost:8001, `--master k8s://http://127.0.0.1:8001` can be used as the argument to @@ -173,7 +178,7 @@ Logs can be accessed using the Kubernetes API and the `kubectl` CLI. When a Spar to stream logs from the application using: ```bash -kubectl -n= logs -f +$ kubectl -n= logs -f ``` The same logs can also be accessed through the @@ -186,7 +191,7 @@ The UI associated with any application can be accessed locally using [`kubectl port-forward`](https://kubernetes.io/docs/tasks/access-application-cluster/port-forward-access-application-cluster/#forward-a-local-port-to-a-port-on-the-pod). ```bash -kubectl port-forward 4040:4040 +$ kubectl port-forward 4040:4040 ``` Then, the Spark driver UI can be accessed on `http://localhost:4040`. @@ -200,13 +205,13 @@ are errors during the running of the application, often, the best way to investi To get some basic information about the scheduling decisions made around the driver pod, you can run: ```bash -kubectl describe pod +$ kubectl describe pod ``` If the pod has encountered a runtime error, the status can be probed further using: ```bash -kubectl logs +$ kubectl logs ``` Status and logs of failed executor pods can be checked in similar ways. Finally, deleting the driver pod will clean up the entire spark @@ -254,7 +259,7 @@ To create a custom service account, a user can use the `kubectl create serviceac following command creates a service account named `spark`: ```bash -kubectl create serviceaccount spark +$ kubectl create serviceaccount spark ``` To grant a service account a `Role` or `ClusterRole`, a `RoleBinding` or `ClusterRoleBinding` is needed. To create @@ -263,7 +268,7 @@ for `ClusterRoleBinding`) command. For example, the following command creates an namespace and grants it to the `spark` service account created above: ```bash -kubectl create clusterrolebinding spark-role --clusterrole=edit --serviceaccount=default:spark --namespace=default +$ kubectl create clusterrolebinding spark-role --clusterrole=edit --serviceaccount=default:spark --namespace=default ``` Note that a `Role` can only be used to grant access to resources (like pods) within a single namespace, whereas a @@ -543,14 +548,6 @@ specific to Spark on Kubernetes. to avoid name conflicts. - - spark.kubernetes.executor.podNamePrefix - (none) - - Prefix for naming the executor pods. - If not set, the executor pod name is set to driver pod name suffixed by an integer. - - spark.kubernetes.executor.lostCheck.maxAttempts 10 From 07296a61c29eb074553956b6c0f92810ecf7bab2 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 19 Jan 2018 10:25:18 -0800 Subject: [PATCH 0151/2461] [INFRA] Close stale PR. Closes #20185. From fed2139f053fac4a9a6952ff0ab1cc2a5f657bd0 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 19 Jan 2018 13:26:37 -0600 Subject: [PATCH 0152/2461] [SPARK-20664][CORE] Delete stale application data from SHS. Detect the deletion of event log files from storage, and remove data about the related application attempt in the SHS. Also contains code to fix SPARK-21571 based on code by ericvandenbergfb. Author: Marcelo Vanzin Closes #20138 from vanzin/SPARK-20664. --- .../deploy/history/FsHistoryProvider.scala | 297 +++++++++++------- .../history/FsHistoryProviderSuite.scala | 117 ++++++- .../deploy/history/HistoryServerSuite.scala | 4 +- 3 files changed, 306 insertions(+), 112 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 94c80ebd55e74..f9d0b5ee4e23e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.history import java.io.{File, FileNotFoundException, IOException} import java.util.{Date, ServiceLoader, UUID} -import java.util.concurrent.{Executors, ExecutorService, Future, TimeUnit} +import java.util.concurrent.{ExecutorService, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConverters._ @@ -29,7 +29,7 @@ import scala.xml.Node import com.fasterxml.jackson.annotation.JsonIgnore import com.google.common.io.ByteStreams -import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} +import com.google.common.util.concurrent.MoreExecutors import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.hdfs.DistributedFileSystem @@ -116,8 +116,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // Used by check event thread and clean log thread. // Scheduled thread pool size must be one, otherwise it will have concurrent issues about fs // and applications between check task and clean task. - private val pool = Executors.newScheduledThreadPool(1, new ThreadFactoryBuilder() - .setNameFormat("spark-history-task-%d").setDaemon(true).build()) + private val pool = ThreadUtils.newDaemonSingleThreadScheduledExecutor("spark-history-task-%d") // The modification time of the newest log detected during the last scan. Currently only // used for logging msgs (logs are re-scanned based on file size, rather than modtime) @@ -174,7 +173,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * Fixed size thread pool to fetch and parse log files. */ private val replayExecutor: ExecutorService = { - if (!conf.contains("spark.testing")) { + if (Utils.isTesting) { ThreadUtils.newDaemonFixedThreadPool(NUM_PROCESSING_THREADS, "log-replay-executor") } else { MoreExecutors.sameThreadExecutor() @@ -275,7 +274,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) try { Some(load(appId).toApplicationInfo()) } catch { - case e: NoSuchElementException => + case _: NoSuchElementException => None } } @@ -405,49 +404,70 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) try { val newLastScanTime = getNewLastScanTime() logDebug(s"Scanning $logDir with lastScanTime==$lastScanTime") - // scan for modified applications, replay and merge them - val logInfos = Option(fs.listStatus(new Path(logDir))).map(_.toSeq).getOrElse(Nil) + + val updated = Option(fs.listStatus(new Path(logDir))).map(_.toSeq).getOrElse(Nil) .filter { entry => !entry.isDirectory() && // FsHistoryProvider generates a hidden file which can't be read. Accidentally // reading a garbage file is safe, but we would log an error which can be scary to // the end-user. !entry.getPath().getName().startsWith(".") && - SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) && - recordedFileSize(entry.getPath()) < entry.getLen() + SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) + } + .filter { entry => + try { + val info = listing.read(classOf[LogInfo], entry.getPath().toString()) + if (info.fileSize < entry.getLen()) { + // Log size has changed, it should be parsed. + true + } else { + // If the SHS view has a valid application, update the time the file was last seen so + // that the entry is not deleted from the SHS listing. + if (info.appId.isDefined) { + listing.write(info.copy(lastProcessed = newLastScanTime)) + } + false + } + } catch { + case _: NoSuchElementException => + // If the file is currently not being tracked by the SHS, add an entry for it and try + // to parse it. This will allow the cleaner code to detect the file as stale later on + // if it was not possible to parse it. + listing.write(LogInfo(entry.getPath().toString(), newLastScanTime, None, None, + entry.getLen())) + entry.getLen() > 0 + } } .sortWith { case (entry1, entry2) => entry1.getModificationTime() > entry2.getModificationTime() } - if (logInfos.nonEmpty) { - logDebug(s"New/updated attempts found: ${logInfos.size} ${logInfos.map(_.getPath)}") + if (updated.nonEmpty) { + logDebug(s"New/updated attempts found: ${updated.size} ${updated.map(_.getPath)}") } - var tasks = mutable.ListBuffer[Future[_]]() - - try { - for (file <- logInfos) { - tasks += replayExecutor.submit(new Runnable { - override def run(): Unit = mergeApplicationListing(file) + val tasks = updated.map { entry => + try { + replayExecutor.submit(new Runnable { + override def run(): Unit = mergeApplicationListing(entry, newLastScanTime) }) + } catch { + // let the iteration over the updated entries break, since an exception on + // replayExecutor.submit (..) indicates the ExecutorService is unable + // to take any more submissions at this time + case e: Exception => + logError(s"Exception while submitting event log for replay", e) + null } - } catch { - // let the iteration over logInfos break, since an exception on - // replayExecutor.submit (..) indicates the ExecutorService is unable - // to take any more submissions at this time - - case e: Exception => - logError(s"Exception while submitting event log for replay", e) - } + }.filter(_ != null) pendingReplayTasksCount.addAndGet(tasks.size) + // Wait for all tasks to finish. This makes sure that checkForLogs + // is not scheduled again while some tasks are already running in + // the replayExecutor. tasks.foreach { task => try { - // Wait for all tasks to finish. This makes sure that checkForLogs - // is not scheduled again while some tasks are already running in - // the replayExecutor. task.get() } catch { case e: InterruptedException => @@ -459,13 +479,70 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + // Delete all information about applications whose log files disappeared from storage. + // This is done by identifying the event logs which were not touched by the current + // directory scan. + // + // Only entries with valid applications are cleaned up here. Cleaning up invalid log + // files is done by the periodic cleaner task. + val stale = listing.view(classOf[LogInfo]) + .index("lastProcessed") + .last(newLastScanTime - 1) + .asScala + .toList + stale.foreach { log => + log.appId.foreach { appId => + cleanAppData(appId, log.attemptId, log.logPath) + listing.delete(classOf[LogInfo], log.logPath) + } + } + lastScanTime.set(newLastScanTime) } catch { case e: Exception => logError("Exception in checking for event log updates", e) } } - private def getNewLastScanTime(): Long = { + private def cleanAppData(appId: String, attemptId: Option[String], logPath: String): Unit = { + try { + val app = load(appId) + val (attempt, others) = app.attempts.partition(_.info.attemptId == attemptId) + + assert(attempt.isEmpty || attempt.size == 1) + val isStale = attempt.headOption.exists { a => + if (a.logPath != new Path(logPath).getName()) { + // If the log file name does not match, then probably the old log file was from an + // in progress application. Just return that the app should be left alone. + false + } else { + val maybeUI = synchronized { + activeUIs.remove(appId -> attemptId) + } + + maybeUI.foreach { ui => + ui.invalidate() + ui.ui.store.close() + } + + diskManager.foreach(_.release(appId, attemptId, delete = true)) + true + } + } + + if (isStale) { + if (others.nonEmpty) { + val newAppInfo = new ApplicationInfoWrapper(app.info, others) + listing.write(newAppInfo) + } else { + listing.delete(classOf[ApplicationInfoWrapper], appId) + } + } + } catch { + case _: NoSuchElementException => + } + } + + private[history] def getNewLastScanTime(): Long = { val fileName = "." + UUID.randomUUID().toString val path = new Path(logDir, fileName) val fos = fs.create(path) @@ -530,7 +607,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Replay the given log file, saving the application in the listing db. */ - protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { + protected def mergeApplicationListing(fileStatus: FileStatus, scanTime: Long): Unit = { val eventsFilter: ReplayEventsFilter = { eventString => eventString.startsWith(APPL_START_EVENT_PREFIX) || eventString.startsWith(APPL_END_EVENT_PREFIX) || @@ -544,73 +621,78 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) bus.addListener(listener) replay(fileStatus, bus, eventsFilter = eventsFilter) - listener.applicationInfo.foreach { app => - // Invalidate the existing UI for the reloaded app attempt, if any. See LoadedAppUI for a - // discussion on the UI lifecycle. - synchronized { - activeUIs.get((app.info.id, app.attempts.head.info.attemptId)).foreach { ui => - ui.invalidate() - ui.ui.store.close() + val (appId, attemptId) = listener.applicationInfo match { + case Some(app) => + // Invalidate the existing UI for the reloaded app attempt, if any. See LoadedAppUI for a + // discussion on the UI lifecycle. + synchronized { + activeUIs.get((app.info.id, app.attempts.head.info.attemptId)).foreach { ui => + ui.invalidate() + ui.ui.store.close() + } } - } - addListing(app) + addListing(app) + (Some(app.info.id), app.attempts.head.info.attemptId) + + case _ => + // If the app hasn't written down its app ID to the logs, still record the entry in the + // listing db, with an empty ID. This will make the log eligible for deletion if the app + // does not make progress after the configured max log age. + (None, None) } - listing.write(new LogInfo(logPath.toString(), fileStatus.getLen())) + listing.write(LogInfo(logPath.toString(), scanTime, appId, attemptId, fileStatus.getLen())) } /** * Delete event logs from the log directory according to the clean policy defined by the user. */ - private[history] def cleanLogs(): Unit = { - var iterator: Option[KVStoreIterator[ApplicationInfoWrapper]] = None - try { - val maxTime = clock.getTimeMillis() - conf.get(MAX_LOG_AGE_S) * 1000 - - // Iterate descending over all applications whose oldest attempt happened before maxTime. - iterator = Some(listing.view(classOf[ApplicationInfoWrapper]) - .index("oldestAttempt") - .reverse() - .first(maxTime) - .closeableIterator()) - - iterator.get.asScala.foreach { app => - // Applications may have multiple attempts, some of which may not need to be deleted yet. - val (remaining, toDelete) = app.attempts.partition { attempt => - attempt.info.lastUpdated.getTime() >= maxTime - } + private[history] def cleanLogs(): Unit = Utils.tryLog { + val maxTime = clock.getTimeMillis() - conf.get(MAX_LOG_AGE_S) * 1000 - if (remaining.nonEmpty) { - val newApp = new ApplicationInfoWrapper(app.info, remaining) - listing.write(newApp) - } + val expired = listing.view(classOf[ApplicationInfoWrapper]) + .index("oldestAttempt") + .reverse() + .first(maxTime) + .asScala + .toList + expired.foreach { app => + // Applications may have multiple attempts, some of which may not need to be deleted yet. + val (remaining, toDelete) = app.attempts.partition { attempt => + attempt.info.lastUpdated.getTime() >= maxTime + } - toDelete.foreach { attempt => - val logPath = new Path(logDir, attempt.logPath) - try { - listing.delete(classOf[LogInfo], logPath.toString()) - } catch { - case _: NoSuchElementException => - logDebug(s"Log info entry for $logPath not found.") - } - try { - fs.delete(logPath, true) - } catch { - case e: AccessControlException => - logInfo(s"No permission to delete ${attempt.logPath}, ignoring.") - case t: IOException => - logError(s"IOException in cleaning ${attempt.logPath}", t) - } - } + if (remaining.nonEmpty) { + val newApp = new ApplicationInfoWrapper(app.info, remaining) + listing.write(newApp) + } - if (remaining.isEmpty) { - listing.delete(app.getClass(), app.id) - } + toDelete.foreach { attempt => + logInfo(s"Deleting expired event log for ${attempt.logPath}") + val logPath = new Path(logDir, attempt.logPath) + listing.delete(classOf[LogInfo], logPath.toString()) + cleanAppData(app.id, attempt.info.attemptId, logPath.toString()) + deleteLog(logPath) + } + + if (remaining.isEmpty) { + listing.delete(app.getClass(), app.id) + } + } + + // Delete log files that don't have a valid application and exceed the configured max age. + val stale = listing.view(classOf[LogInfo]) + .index("lastProcessed") + .reverse() + .first(maxTime) + .asScala + .toList + stale.foreach { log => + if (log.appId.isEmpty) { + logInfo(s"Deleting invalid / corrupt event log ${log.logPath}") + deleteLog(new Path(log.logPath)) + listing.delete(classOf[LogInfo], log.logPath) } - } catch { - case t: Exception => logError("Exception while cleaning logs", t) - } finally { - iterator.foreach(_.close()) } } @@ -631,12 +713,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // an error the other way -- if we report a size bigger (ie later) than the file that is // actually read, we may never refresh the app. FileStatus is guaranteed to be static // after it's created, so we get a file size that is no bigger than what is actually read. - val logInput = EventLoggingListener.openEventLog(logPath, fs) - try { - bus.replay(logInput, logPath.toString, !isCompleted, eventsFilter) + Utils.tryWithResource(EventLoggingListener.openEventLog(logPath, fs)) { in => + bus.replay(in, logPath.toString, !isCompleted, eventsFilter) logInfo(s"Finished parsing $logPath") - } finally { - logInput.close() } } @@ -703,18 +782,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) | application count=$count}""".stripMargin } - /** - * Return the last known size of the given event log, recorded the last time the file - * system scanner detected a change in the file. - */ - private def recordedFileSize(log: Path): Long = { - try { - listing.read(classOf[LogInfo], log.toString()).fileSize - } catch { - case _: NoSuchElementException => 0L - } - } - private def load(appId: String): ApplicationInfoWrapper = { listing.read(classOf[ApplicationInfoWrapper], appId) } @@ -773,11 +840,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) logInfo(s"Leasing disk manager space for app $appId / ${attempt.info.attemptId}...") val lease = dm.lease(status.getLen(), isCompressed) val newStorePath = try { - val store = KVUtils.open(lease.tmpPath, metadata) - try { + Utils.tryWithResource(KVUtils.open(lease.tmpPath, metadata)) { store => rebuildAppStore(store, status, attempt.info.lastUpdated.getTime()) - } finally { - store.close() } lease.commit(appId, attempt.info.attemptId) } catch { @@ -806,6 +870,17 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) throw new NoSuchElementException(s"Cannot find attempt $attemptId of $appId.")) } + private def deleteLog(log: Path): Unit = { + try { + fs.delete(log, true) + } catch { + case _: AccessControlException => + logInfo(s"No permission to delete $log, ignoring.") + case ioe: IOException => + logError(s"IOException in cleaning $log", ioe) + } + } + } private[history] object FsHistoryProvider { @@ -832,8 +907,16 @@ private[history] case class FsHistoryProviderMetadata( uiVersion: Long, logDir: String) +/** + * Tracking info for event logs detected in the configured log directory. Tracks both valid and + * invalid logs (e.g. unparseable logs, recorded as logs with no app ID) so that the cleaner + * can know what log files are safe to delete. + */ private[history] case class LogInfo( @KVIndexParam logPath: String, + @KVIndexParam("lastProcessed") lastProcessed: Long, + appId: Option[String], + attemptId: Option[String], fileSize: Long) private[history] class AttemptInfoWrapper( diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 84ee01c7f5aaf..787de59edf465 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hdfs.DistributedFileSystem import org.json4s.jackson.JsonMethods._ import org.mockito.Matchers.any -import org.mockito.Mockito.{mock, spy, verify} +import org.mockito.Mockito.{doReturn, mock, spy, verify} import org.scalatest.BeforeAndAfter import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ @@ -149,8 +149,10 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc class TestFsHistoryProvider extends FsHistoryProvider(createTestConf()) { var mergeApplicationListingCall = 0 - override protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { - super.mergeApplicationListing(fileStatus) + override protected def mergeApplicationListing( + fileStatus: FileStatus, + lastSeen: Long): Unit = { + super.mergeApplicationListing(fileStatus, lastSeen) mergeApplicationListingCall += 1 } } @@ -663,6 +665,115 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc freshUI.get.ui.store.job(0) } + test("clean up stale app information") { + val storeDir = Utils.createTempDir() + val conf = createTestConf().set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) + val provider = spy(new FsHistoryProvider(conf)) + val appId = "new1" + + // Write logs for two app attempts. + doReturn(1L).when(provider).getNewLastScanTime() + val attempt1 = newLogFile(appId, Some("1"), inProgress = false) + writeFile(attempt1, true, None, + SparkListenerApplicationStart(appId, Some(appId), 1L, "test", Some("1")), + SparkListenerJobStart(0, 1L, Nil, null), + SparkListenerApplicationEnd(5L) + ) + val attempt2 = newLogFile(appId, Some("2"), inProgress = false) + writeFile(attempt2, true, None, + SparkListenerApplicationStart(appId, Some(appId), 1L, "test", Some("2")), + SparkListenerJobStart(0, 1L, Nil, null), + SparkListenerApplicationEnd(5L) + ) + updateAndCheck(provider) { list => + assert(list.size === 1) + assert(list(0).id === appId) + assert(list(0).attempts.size === 2) + } + + // Load the app's UI. + val ui = provider.getAppUI(appId, Some("1")) + assert(ui.isDefined) + + // Delete the underlying log file for attempt 1 and rescan. The UI should go away, but since + // attempt 2 still exists, listing data should be there. + doReturn(2L).when(provider).getNewLastScanTime() + attempt1.delete() + updateAndCheck(provider) { list => + assert(list.size === 1) + assert(list(0).id === appId) + assert(list(0).attempts.size === 1) + } + assert(!ui.get.valid) + assert(provider.getAppUI(appId, None) === None) + + // Delete the second attempt's log file. Now everything should go away. + doReturn(3L).when(provider).getNewLastScanTime() + attempt2.delete() + updateAndCheck(provider) { list => + assert(list.isEmpty) + } + } + + test("SPARK-21571: clean up removes invalid history files") { + // TODO: "maxTime" becoming negative in cleanLogs() causes this test to fail, so avoid that + // until we figure out what's causing the problem. + val clock = new ManualClock(TimeUnit.DAYS.toMillis(120)) + val conf = createTestConf().set(MAX_LOG_AGE_S.key, s"2d") + val provider = new FsHistoryProvider(conf, clock) { + override def getNewLastScanTime(): Long = clock.getTimeMillis() + } + + // Create 0-byte size inprogress and complete files + var logCount = 0 + var validLogCount = 0 + + val emptyInProgress = newLogFile("emptyInprogressLogFile", None, inProgress = true) + emptyInProgress.createNewFile() + emptyInProgress.setLastModified(clock.getTimeMillis()) + logCount += 1 + + val slowApp = newLogFile("slowApp", None, inProgress = true) + slowApp.createNewFile() + slowApp.setLastModified(clock.getTimeMillis()) + logCount += 1 + + val emptyFinished = newLogFile("emptyFinishedLogFile", None, inProgress = false) + emptyFinished.createNewFile() + emptyFinished.setLastModified(clock.getTimeMillis()) + logCount += 1 + + // Create an incomplete log file, has an end record but no start record. + val corrupt = newLogFile("nonEmptyCorruptLogFile", None, inProgress = false) + writeFile(corrupt, true, None, SparkListenerApplicationEnd(0)) + corrupt.setLastModified(clock.getTimeMillis()) + logCount += 1 + + provider.checkForLogs() + provider.cleanLogs() + assert(new File(testDir.toURI).listFiles().size === logCount) + + // Move the clock forward 1 day and scan the files again. They should still be there. + clock.advance(TimeUnit.DAYS.toMillis(1)) + provider.checkForLogs() + provider.cleanLogs() + assert(new File(testDir.toURI).listFiles().size === logCount) + + // Update the slow app to contain valid info. Code should detect the change and not clean + // it up. + writeFile(slowApp, true, None, + SparkListenerApplicationStart(slowApp.getName(), Some(slowApp.getName()), 1L, "test", None)) + slowApp.setLastModified(clock.getTimeMillis()) + validLogCount += 1 + + // Move the clock forward another 2 days and scan the files again. This time the cleaner should + // pick up the invalid files and get rid of them. + clock.advance(TimeUnit.DAYS.toMillis(2)) + provider.checkForLogs() + provider.cleanLogs() + assert(new File(testDir.toURI).listFiles().size === validLogCount) + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 87778dda0e2c8..7aa60f2b60796 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -48,7 +48,7 @@ import org.apache.spark.deploy.history.config._ import org.apache.spark.status.api.v1.ApplicationInfo import org.apache.spark.status.api.v1.JobData import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ResetSystemProperties, Utils} +import org.apache.spark.util.{ResetSystemProperties, ShutdownHookManager, Utils} /** * A collection of tests against the historyserver, including comparing responses from the json @@ -564,7 +564,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers assert(jobcount === getNumJobs("/jobs")) // no need to retain the test dir now the tests complete - logDir.deleteOnExit() + ShutdownHookManager.registerShutdownDeleteDir(logDir) } test("ui and api authorization checks") { From aa3a1276f9e23ffbb093d00743e63cd4369f9f57 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 19 Jan 2018 13:32:20 -0600 Subject: [PATCH 0153/2461] [SPARK-23103][CORE] Ensure correct sort order for negative values in LevelDB. The code was sorting "0" as "less than" negative values, which is a little wrong. Fix is simple, most of the changes are the added test and related cleanup. Author: Marcelo Vanzin Closes #20284 from vanzin/SPARK-23103. --- .../spark/util/kvstore/LevelDBTypeInfo.java | 2 +- .../spark/util/kvstore/DBIteratorSuite.java | 7 +- .../spark/util/kvstore/LevelDBSuite.java | 77 ++++++++++--------- .../spark/status/AppStatusListenerSuite.scala | 8 +- 4 files changed, 52 insertions(+), 42 deletions(-) diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java index 232ee41dd0b1f..f4d359234cb9e 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java @@ -493,7 +493,7 @@ byte[] toKey(Object value, byte prefix) { byte[] key = new byte[bytes * 2 + 2]; long longValue = ((Number) value).longValue(); key[0] = prefix; - key[1] = longValue > 0 ? POSITIVE_MARKER : NEGATIVE_MARKER; + key[1] = longValue >= 0 ? POSITIVE_MARKER : NEGATIVE_MARKER; for (int i = 0; i < key.length - 2; i++) { int masked = (int) ((longValue >>> (4 * i)) & 0xF); diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java index 9a81f86812cde..1e062437d1803 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java @@ -73,7 +73,9 @@ default BaseComparator reverse() { private static final BaseComparator NATURAL_ORDER = (t1, t2) -> t1.key.compareTo(t2.key); private static final BaseComparator REF_INDEX_ORDER = (t1, t2) -> t1.id.compareTo(t2.id); private static final BaseComparator COPY_INDEX_ORDER = (t1, t2) -> t1.name.compareTo(t2.name); - private static final BaseComparator NUMERIC_INDEX_ORDER = (t1, t2) -> t1.num - t2.num; + private static final BaseComparator NUMERIC_INDEX_ORDER = (t1, t2) -> { + return Integer.valueOf(t1.num).compareTo(t2.num); + }; private static final BaseComparator CHILD_INDEX_ORDER = (t1, t2) -> t1.child.compareTo(t2.child); /** @@ -112,7 +114,8 @@ public void setup() throws Exception { t.key = "key" + i; t.id = "id" + i; t.name = "name" + RND.nextInt(MAX_ENTRIES); - t.num = RND.nextInt(MAX_ENTRIES); + // Force one item to have an integer value of zero to test the fix for SPARK-23103. + t.num = (i != 0) ? (int) RND.nextLong() : 0; t.child = "child" + (i % MIN_ENTRIES); allEntries.add(t); } diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java index 2b07d249d2022..b8123ac81d29a 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java @@ -21,6 +21,8 @@ import java.util.Arrays; import java.util.List; import java.util.NoSuchElementException; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; import org.apache.commons.io.FileUtils; import org.iq80.leveldb.DBIterator; @@ -74,11 +76,7 @@ public void testReopenAndVersionCheckDb() throws Exception { @Test public void testObjectWriteReadDelete() throws Exception { - CustomType1 t = new CustomType1(); - t.key = "key"; - t.id = "id"; - t.name = "name"; - t.child = "child"; + CustomType1 t = createCustomType1(1); try { db.read(CustomType1.class, t.key); @@ -106,17 +104,9 @@ public void testObjectWriteReadDelete() throws Exception { @Test public void testMultipleObjectWriteReadDelete() throws Exception { - CustomType1 t1 = new CustomType1(); - t1.key = "key1"; - t1.id = "id"; - t1.name = "name1"; - t1.child = "child1"; - - CustomType1 t2 = new CustomType1(); - t2.key = "key2"; - t2.id = "id"; - t2.name = "name2"; - t2.child = "child2"; + CustomType1 t1 = createCustomType1(1); + CustomType1 t2 = createCustomType1(2); + t2.id = t1.id; db.write(t1); db.write(t2); @@ -142,11 +132,7 @@ public void testMultipleObjectWriteReadDelete() throws Exception { @Test public void testMultipleTypesWriteReadDelete() throws Exception { - CustomType1 t1 = new CustomType1(); - t1.key = "1"; - t1.id = "id"; - t1.name = "name1"; - t1.child = "child1"; + CustomType1 t1 = createCustomType1(1); IntKeyType t2 = new IntKeyType(); t2.key = 2; @@ -188,10 +174,7 @@ public void testMultipleTypesWriteReadDelete() throws Exception { public void testMetadata() throws Exception { assertNull(db.getMetadata(CustomType1.class)); - CustomType1 t = new CustomType1(); - t.id = "id"; - t.name = "name"; - t.child = "child"; + CustomType1 t = createCustomType1(1); db.setMetadata(t); assertEquals(t, db.getMetadata(CustomType1.class)); @@ -202,11 +185,7 @@ public void testMetadata() throws Exception { @Test public void testUpdate() throws Exception { - CustomType1 t = new CustomType1(); - t.key = "key"; - t.id = "id"; - t.name = "name"; - t.child = "child"; + CustomType1 t = createCustomType1(1); db.write(t); @@ -222,13 +201,7 @@ public void testUpdate() throws Exception { @Test public void testSkip() throws Exception { for (int i = 0; i < 10; i++) { - CustomType1 t = new CustomType1(); - t.key = "key" + i; - t.id = "id" + i; - t.name = "name" + i; - t.child = "child" + i; - - db.write(t); + db.write(createCustomType1(i)); } KVStoreIterator it = db.view(CustomType1.class).closeableIterator(); @@ -240,6 +213,36 @@ public void testSkip() throws Exception { assertFalse(it.hasNext()); } + @Test + public void testNegativeIndexValues() throws Exception { + List expected = Arrays.asList(-100, -50, 0, 50, 100); + + expected.stream().forEach(i -> { + try { + db.write(createCustomType1(i)); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + + List results = StreamSupport + .stream(db.view(CustomType1.class).index("int").spliterator(), false) + .map(e -> e.num) + .collect(Collectors.toList()); + + assertEquals(expected, results); + } + + private CustomType1 createCustomType1(int i) { + CustomType1 t = new CustomType1(); + t.key = "key" + i; + t.id = "id" + i; + t.name = "name" + i; + t.num = i; + t.child = "child" + i; + return t; + } + private int countKeys(Class type) throws Exception { byte[] prefix = db.getTypeInfo(type).keyPrefix(); int count = 0; diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index ca66b6b9db890..e7981bec6d64b 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -894,15 +894,19 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val dropped = stages.drop(1).head // Cache some quantiles by calling AppStatusStore.taskSummary(). For quantiles to be - // calculcated, we need at least one finished task. + // calculated, we need at least one finished task. The code in AppStatusStore uses + // `executorRunTime` to detect valid tasks, so that metric needs to be updated in the + // task end event. time += 1 val task = createTasks(1, Array("1")).head listener.onTaskStart(SparkListenerTaskStart(dropped.stageId, dropped.attemptId, task)) time += 1 task.markFinished(TaskState.FINISHED, time) + val metrics = TaskMetrics.empty + metrics.setExecutorRunTime(42L) listener.onTaskEnd(SparkListenerTaskEnd(dropped.stageId, dropped.attemptId, - "taskType", Success, task, null)) + "taskType", Success, task, metrics)) new AppStatusStore(store) .taskSummary(dropped.stageId, dropped.attemptId, Array(0.25d, 0.50d, 0.75d)) From f6da41b0150725fe96ccb2ee3b48840b207f47eb Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 19 Jan 2018 13:14:24 -0800 Subject: [PATCH 0154/2461] [SPARK-23135][UI] Fix rendering of accumulators in the stage page. This follows the behavior of 2.2: only named accumulators with a value are rendered. Screenshot: ![accs](https://user-images.githubusercontent.com/1694083/35065700-df409114-fb82-11e7-87c1-550c3f674371.png) Author: Marcelo Vanzin Closes #20299 from vanzin/SPARK-23135. --- .../org/apache/spark/ui/jobs/StagePage.scala | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 25bee33028393..0eb3190205c3e 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -260,7 +260,11 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") def accumulableRow(acc: AccumulableInfo): Seq[Node] = { - {acc.name}{acc.value} + if (acc.name != null && acc.value != null) { + {acc.name}{acc.value} + } else { + Nil + } } val accumulableTable = UIUtils.listingTable( accumulableHeaders, @@ -864,7 +868,7 @@ private[ui] class TaskPagedTable( {formatBytes(task.taskMetrics.map(_.peakExecutionMemory))} {if (hasAccumulators(stage)) { - accumulatorsInfo(task) + {accumulatorsInfo(task)} }} {if (hasInput(stage)) { metricInfo(task) { m => @@ -920,8 +924,12 @@ private[ui] class TaskPagedTable( } private def accumulatorsInfo(task: TaskData): Seq[Node] = { - task.accumulatorUpdates.map { acc => - Unparsed(StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update}")) + task.accumulatorUpdates.flatMap { acc => + if (acc.name != null && acc.update.isDefined) { + Unparsed(StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}")) ++
+ } else { + Nil + } } } @@ -985,7 +993,9 @@ private object ApiHelper { "Shuffle Spill (Disk)" -> TaskIndexNames.DISK_SPILL, "Errors" -> TaskIndexNames.ERROR) - def hasAccumulators(stageData: StageData): Boolean = stageData.accumulatorUpdates.size > 0 + def hasAccumulators(stageData: StageData): Boolean = { + stageData.accumulatorUpdates.exists { acc => acc.name != null && acc.value != null } + } def hasInput(stageData: StageData): Boolean = stageData.inputBytes > 0 From 793841c6b8b98b918dcf241e29f60ef125914db9 Mon Sep 17 00:00:00 2001 From: Kent Yao <11215016@zju.edu.cn> Date: Fri, 19 Jan 2018 15:49:29 -0800 Subject: [PATCH 0155/2461] [SPARK-21771][SQL] remove useless hive client in SparkSQLEnv ## What changes were proposed in this pull request? Once a meta hive client is created, it generates its SessionState which creates a lot of session related directories, some deleteOnExit, some does not. if a hive client is useless we may not create it at the very start. ## How was this patch tested? N/A cc hvanhovell cloud-fan Author: Kent Yao <11215016@zju.edu.cn> Closes #18983 from yaooqinn/patch-1. --- .../org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 6b19f971b73bb..cbd75ad12d430 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -50,8 +50,7 @@ private[hive] object SparkSQLEnv extends Logging { sqlContext = sparkSession.sqlContext val metadataHive = sparkSession - .sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] - .client.newSession() + .sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) From 396cdfbea45232bacbc03bfaf8be4ea85d47d3fd Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 19 Jan 2018 22:46:34 -0800 Subject: [PATCH 0156/2461] [SPARK-23091][ML] Incorrect unit test for approxQuantile ## What changes were proposed in this pull request? Narrow bound on approx quantile test to epsilon from 2*epsilon to match paper ## How was this patch tested? Existing tests. Author: Sean Owen Closes #20324 from srowen/SPARK-23091. --- .../apache/spark/sql/DataFrameStatSuite.scala | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 5169d2b5fc6b2..8eae35325faea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -154,24 +154,24 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { val Array(d1, d2) = df.stat.approxQuantile("doubles", Array(q1, q2), epsilon) val Array(s1, s2) = df.stat.approxQuantile("singles", Array(q1, q2), epsilon) - val error_single = 2 * 1000 * epsilon - val error_double = 2 * 2000 * epsilon + val errorSingle = 1000 * epsilon + val errorDouble = 2.0 * errorSingle - assert(math.abs(single1 - q1 * n) < error_single) - assert(math.abs(double2 - 2 * q2 * n) < error_double) - assert(math.abs(s1 - q1 * n) < error_single) - assert(math.abs(s2 - q2 * n) < error_single) - assert(math.abs(d1 - 2 * q1 * n) < error_double) - assert(math.abs(d2 - 2 * q2 * n) < error_double) + assert(math.abs(single1 - q1 * n) <= errorSingle) + assert(math.abs(double2 - 2 * q2 * n) <= errorDouble) + assert(math.abs(s1 - q1 * n) <= errorSingle) + assert(math.abs(s2 - q2 * n) <= errorSingle) + assert(math.abs(d1 - 2 * q1 * n) <= errorDouble) + assert(math.abs(d2 - 2 * q2 * n) <= errorDouble) // Multiple columns val Array(Array(ms1, ms2), Array(md1, md2)) = df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), epsilon) - assert(math.abs(ms1 - q1 * n) < error_single) - assert(math.abs(ms2 - q2 * n) < error_single) - assert(math.abs(md1 - 2 * q1 * n) < error_double) - assert(math.abs(md2 - 2 * q2 * n) < error_double) + assert(math.abs(ms1 - q1 * n) <= errorSingle) + assert(math.abs(ms2 - q2 * n) <= errorSingle) + assert(math.abs(md1 - 2 * q1 * n) <= errorDouble) + assert(math.abs(md2 - 2 * q2 * n) <= errorDouble) } // quantile should be in the range [0.0, 1.0] From 84a076e0e9a38a26edf7b702c24fdbbcf1e697b9 Mon Sep 17 00:00:00 2001 From: Shashwat Anand Date: Sat, 20 Jan 2018 14:34:37 -0800 Subject: [PATCH 0157/2461] [SPARK-23165][DOC] Spelling mistake fix in quick-start doc. ## What changes were proposed in this pull request? Fix spelling in quick-start doc. ## How was this patch tested? Doc only. Author: Shashwat Anand Closes #20336 from ashashwat/SPARK-23165. --- docs/cloud-integration.md | 4 ++-- docs/configuration.md | 14 +++++++------- docs/graphx-programming-guide.md | 4 ++-- docs/monitoring.md | 8 ++++---- docs/quick-start.md | 6 +++--- docs/running-on-mesos.md | 2 +- docs/running-on-yarn.md | 2 +- docs/security.md | 2 +- docs/sql-programming-guide.md | 8 ++++---- docs/storage-openstack-swift.md | 2 +- docs/streaming-programming-guide.md | 4 ++-- docs/structured-streaming-kafka-integration.md | 4 ++-- docs/structured-streaming-programming-guide.md | 6 +++--- docs/submitting-applications.md | 8 ++++---- 14 files changed, 37 insertions(+), 37 deletions(-) diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md index 751a192da4ffd..c150d9efc06ff 100644 --- a/docs/cloud-integration.md +++ b/docs/cloud-integration.md @@ -180,10 +180,10 @@ under the path, not the number of *new* files, so it can become a slow operation The size of the window needs to be set to handle this. 1. Files only appear in an object store once they are completely written; there -is no need for a worklow of write-then-rename to ensure that files aren't picked up +is no need for a workflow of write-then-rename to ensure that files aren't picked up while they are still being written. Applications can write straight to the monitored directory. -1. Streams should only be checkpointed to an store implementing a fast and +1. Streams should only be checkpointed to a store implementing a fast and atomic `rename()` operation Otherwise the checkpointing may be slow and potentially unreliable. ## Further Reading diff --git a/docs/configuration.md b/docs/configuration.md index eecb39dcafc9e..e7f2419cc2fa4 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -79,7 +79,7 @@ Then, you can supply configuration values at runtime: {% endhighlight %} The Spark shell and [`spark-submit`](submitting-applications.html) -tool support two ways to load configurations dynamically. The first are command line options, +tool support two ways to load configurations dynamically. The first is command line options, such as `--master`, as shown above. `spark-submit` can accept any Spark property using the `--conf` flag, but uses special flags for properties that play a part in launching the Spark application. Running `./bin/spark-submit --help` will show the entire list of these options. @@ -413,7 +413,7 @@ Apart from these, the following properties are also available, and may be useful false Enable profiling in Python worker, the profile result will show up by sc.show_profiles(), - or it will be displayed before the driver exiting. It also can be dumped into disk by + or it will be displayed before the driver exits. It also can be dumped into disk by sc.dump_profiles(path). If some of the profile results had been displayed manually, they will not be displayed automatically before driver exiting. @@ -446,7 +446,7 @@ Apart from these, the following properties are also available, and may be useful true Reuse Python worker or not. If yes, it will use a fixed number of Python workers, - does not need to fork() a Python process for every tasks. It will be very useful + does not need to fork() a Python process for every task. It will be very useful if there is large broadcast, then the broadcast will not be needed to transferred from JVM to Python worker for every task. @@ -1294,7 +1294,7 @@ Apart from these, the following properties are also available, and may be useful spark.files.openCostInBytes 4194304 (4 MB) - The estimated cost to open a file, measured by the number of bytes could be scanned in the same + The estimated cost to open a file, measured by the number of bytes could be scanned at the same time. This is used when putting multiple files into a partition. It is better to over estimate, then the partitions with small files will be faster than partitions with bigger files. @@ -1855,8 +1855,8 @@ Apart from these, the following properties are also available, and may be useful spark.user.groups.mapping org.apache.spark.security.ShellBasedGroupsMappingProvider - The list of groups for a user are determined by a group mapping service defined by the trait - org.apache.spark.security.GroupMappingServiceProvider which can configured by this property. + The list of groups for a user is determined by a group mapping service defined by the trait + org.apache.spark.security.GroupMappingServiceProvider which can be configured by this property. A default unix shell based implementation is provided org.apache.spark.security.ShellBasedGroupsMappingProvider which can be specified to resolve a list of groups for a user. Note: This implementation supports only a Unix/Linux based environment. Windows environment is @@ -2465,7 +2465,7 @@ should be included on Spark's classpath: The location of these configuration files varies across Hadoop versions, but a common location is inside of `/etc/hadoop/conf`. Some tools create -configurations on-the-fly, but offer a mechanisms to download copies of them. +configurations on-the-fly, but offer a mechanism to download copies of them. To make these files visible to Spark, set `HADOOP_CONF_DIR` in `$SPARK_HOME/conf/spark-env.sh` to a location containing the configuration files. diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 46225dc598da8..5c97a248df4bc 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -708,7 +708,7 @@ messages remaining. > messaging function. These constraints allow additional optimization within GraphX. The following is the type signature of the [Pregel operator][GraphOps.pregel] as well as a *sketch* -of its implementation (note: to avoid stackOverflowError due to long lineage chains, pregel support periodcally +of its implementation (note: to avoid stackOverflowError due to long lineage chains, pregel support periodically checkpoint graph and messages by setting "spark.graphx.pregel.checkpointInterval" to a positive number, say 10. And set checkpoint directory as well using SparkContext.setCheckpointDir(directory: String)): @@ -928,7 +928,7 @@ switch to 2D-partitioning or other heuristics included in GraphX.

-Once the edges have be partitioned the key challenge to efficient graph-parallel computation is +Once the edges have been partitioned the key challenge to efficient graph-parallel computation is efficiently joining vertex attributes with the edges. Because real-world graphs typically have more edges than vertices, we move vertex attributes to the edges. Because not all partitions will contain edges adjacent to all vertices we internally maintain a routing table which identifies where diff --git a/docs/monitoring.md b/docs/monitoring.md index f8d3ce91a0691..6f6cfc1288d73 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -118,7 +118,7 @@ The history server can be configured as follows: The number of applications to retain UI data for in the cache. If this cap is exceeded, then the oldest applications will be removed from the cache. If an application is not in the cache, - it will have to be loaded from disk if its accessed from the UI. + it will have to be loaded from disk if it is accessed from the UI. @@ -407,7 +407,7 @@ can be identified by their `[attempt-id]`. In the API listed below, when running -The number of jobs and stages which can retrieved is constrained by the same retention +The number of jobs and stages which can be retrieved is constrained by the same retention mechanism of the standalone Spark UI; `"spark.ui.retainedJobs"` defines the threshold value triggering garbage collection on jobs, and `spark.ui.retainedStages` that for stages. Note that the garbage collection takes place on playback: it is possible to retrieve @@ -422,10 +422,10 @@ These endpoints have been strongly versioned to make it easier to develop applic * Individual fields will never be removed for any given endpoint * New endpoints may be added * New fields may be added to existing endpoints -* New versions of the api may be added in the future at a separate endpoint (eg., `api/v2`). New versions are *not* required to be backwards compatible. +* New versions of the api may be added in the future as a separate endpoint (eg., `api/v2`). New versions are *not* required to be backwards compatible. * Api versions may be dropped, but only after at least one minor release of co-existing with a new api version. -Note that even when examining the UI of a running applications, the `applications/[app-id]` portion is +Note that even when examining the UI of running applications, the `applications/[app-id]` portion is still required, though there is only one application available. Eg. to see the list of jobs for the running app, you would go to `http://localhost:4040/api/v1/applications/[app-id]/jobs`. This is to keep the paths consistent in both modes. diff --git a/docs/quick-start.md b/docs/quick-start.md index 200b97230e866..07c520cbee6be 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -67,7 +67,7 @@ res3: Long = 15 ./bin/pyspark -Or if PySpark is installed with pip in your current enviroment: +Or if PySpark is installed with pip in your current environment: pyspark @@ -156,7 +156,7 @@ One common data flow pattern is MapReduce, as popularized by Hadoop. Spark can i >>> wordCounts = textFile.select(explode(split(textFile.value, "\s+")).alias("word")).groupBy("word").count() {% endhighlight %} -Here, we use the `explode` function in `select`, to transfrom a Dataset of lines to a Dataset of words, and then combine `groupBy` and `count` to compute the per-word counts in the file as a DataFrame of 2 columns: "word" and "count". To collect the word counts in our shell, we can call `collect`: +Here, we use the `explode` function in `select`, to transform a Dataset of lines to a Dataset of words, and then combine `groupBy` and `count` to compute the per-word counts in the file as a DataFrame of 2 columns: "word" and "count". To collect the word counts in our shell, we can call `collect`: {% highlight python %} >>> wordCounts.collect() @@ -422,7 +422,7 @@ $ YOUR_SPARK_HOME/bin/spark-submit \ Lines with a: 46, Lines with b: 23 {% endhighlight %} -If you have PySpark pip installed into your enviroment (e.g., `pip install pyspark`), you can run your application with the regular Python interpreter or use the provided 'spark-submit' as you prefer. +If you have PySpark pip installed into your environment (e.g., `pip install pyspark`), you can run your application with the regular Python interpreter or use the provided 'spark-submit' as you prefer. {% highlight bash %} # Use the Python interpreter to run your application diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 382cbfd5301b0..2bb5ecf1b8509 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -154,7 +154,7 @@ can find the results of the driver from the Mesos Web UI. To use cluster mode, you must start the `MesosClusterDispatcher` in your cluster via the `sbin/start-mesos-dispatcher.sh` script, passing in the Mesos master URL (e.g: mesos://host:5050). This starts the `MesosClusterDispatcher` as a daemon running on the host. -By setting the Mesos proxy config property (requires mesos version >= 1.4), `--conf spark.mesos.proxy.baseURL=http://localhost:5050` when launching the dispacther, the mesos sandbox URI for each driver is added to the mesos dispatcher UI. +By setting the Mesos proxy config property (requires mesos version >= 1.4), `--conf spark.mesos.proxy.baseURL=http://localhost:5050` when launching the dispatcher, the mesos sandbox URI for each driver is added to the mesos dispatcher UI. If you like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e: `bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`). Note that the `MesosClusterDispatcher` not yet supports multiple instances for HA. diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index e7edec5990363..e4f5a0c659e66 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -445,7 +445,7 @@ To use a custom metrics.properties for the application master and executors, upd yarn.nodemanager.log-aggregation.roll-monitoring-interval-seconds should be configured in yarn-site.xml. This feature can only be used with Hadoop 2.6.4+. The Spark log4j appender needs be changed to use - FileAppender or another appender that can handle the files being removed while its running. Based + FileAppender or another appender that can handle the files being removed while it is running. Based on the file name configured in the log4j configuration (like spark.log), the user should set the regex (spark*) to include all the log files that need to be aggregated. diff --git a/docs/security.md b/docs/security.md index 15aadf07cf873..bebc28ddbfb0e 100644 --- a/docs/security.md +++ b/docs/security.md @@ -62,7 +62,7 @@ component-specific configuration namespaces used to override the default setting -The full breakdown of available SSL options can be found on the [configuration page](configuration.html). +The full breakdown of available SSL options can be found on the [configuration page](configuration.html). SSL must be configured on each node and configured for each component involved in communication using the particular protocol. ### YARN mode diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3e2e48a0ef249..502c0a8c37e01 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1253,7 +1253,7 @@ provide a ClassTag. (Note that this is different than the Spark SQL JDBC server, which allows other applications to run queries using Spark SQL). -To get started you will need to include the JDBC driver for you particular database on the +To get started you will need to include the JDBC driver for your particular database on the spark classpath. For example, to connect to postgres from the Spark Shell you would run the following command: @@ -1793,7 +1793,7 @@ options. - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. - - Since Spark 2.3, by default arithmetic operations between decimals return a rounded value if an exact representation is not possible (instead of returning NULL). This is compliant to SQL ANSI 2011 specification and Hive's new behavior introduced in Hive 2.2 (HIVE-15331). This involves the following changes + - Since Spark 2.3, by default arithmetic operations between decimals return a rounded value if an exact representation is not possible (instead of returning NULL). This is compliant with SQL ANSI 2011 specification and Hive's new behavior introduced in Hive 2.2 (HIVE-15331). This involves the following changes - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. All the arithmetic operations are affected by the change, ie. addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`) and positive module (`pmod`). - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible. @@ -1821,7 +1821,7 @@ options. transformations (e.g., `map`, `filter`, and `groupByKey`) and untyped transformations (e.g., `select` and `groupBy`) are available on the Dataset class. Since compile-time type-safety in Python and R is not a language feature, the concept of Dataset does not apply to these languages’ - APIs. Instead, `DataFrame` remains the primary programing abstraction, which is analogous to the + APIs. Instead, `DataFrame` remains the primary programming abstraction, which is analogous to the single-node data frame notion in these languages. - Dataset and DataFrame API `unionAll` has been deprecated and replaced by `union` @@ -1997,7 +1997,7 @@ Java and Python users will need to update their code. Prior to Spark 1.3 there were separate Java compatible classes (`JavaSQLContext` and `JavaSchemaRDD`) that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users -of either language should use `SQLContext` and `DataFrame`. In general theses classes try to +of either language should use `SQLContext` and `DataFrame`. In general these classes try to use types that are usable from both languages (i.e. `Array` instead of language specific collections). In some cases where no common type exists (e.g., for passing in closures or Maps) function overloading is used instead. diff --git a/docs/storage-openstack-swift.md b/docs/storage-openstack-swift.md index f4bb2353e3c49..1dd54719b21aa 100644 --- a/docs/storage-openstack-swift.md +++ b/docs/storage-openstack-swift.md @@ -42,7 +42,7 @@ Create core-site.xml and place it inside Spark's conf The main category of parameters that should be configured are the authentication parameters required by Keystone. -The following table contains a list of Keystone mandatory parameters. PROVIDER can be +The following table contains a list of Keystone mandatory parameters. PROVIDER can be any (alphanumeric) name. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 868acc41226dc..ffda36d64a770 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -74,7 +74,7 @@ import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ // not necessary since Spark 1.3 // Create a local StreamingContext with two working thread and batch interval of 1 second. -// The master requires 2 cores to prevent from a starvation scenario. +// The master requires 2 cores to prevent a starvation scenario. val conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount") val ssc = new StreamingContext(conf, Seconds(1)) @@ -172,7 +172,7 @@ each line will be split into multiple words and the stream of words is represent `words` DStream. Note that we defined the transformation using a [FlatMapFunction](api/scala/index.html#org.apache.spark.api.java.function.FlatMapFunction) object. As we will discover along the way, there are a number of such convenience classes in the Java API -that help define DStream transformations. +that help defines DStream transformations. Next, we want to count these words. diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index 461c29ce1ba89..5647ec6bc5797 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -125,7 +125,7 @@ df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") ### Creating a Kafka Source for Batch Queries If you have a use case that is better suited to batch processing, -you can create an Dataset/DataFrame for a defined range of offsets. +you can create a Dataset/DataFrame for a defined range of offsets.
@@ -597,7 +597,7 @@ Note that the following Kafka params cannot be set and the Kafka source or sink - **key.serializer**: Keys are always serialized with ByteArraySerializer or StringSerializer. Use DataFrame operations to explicitly serialize the keys into either strings or byte arrays. - **value.serializer**: values are always serialized with ByteArraySerializer or StringSerializer. Use -DataFrame oeprations to explicitly serialize the values into either strings or byte arrays. +DataFrame operations to explicitly serialize the values into either strings or byte arrays. - **enable.auto.commit**: Kafka source doesn't commit any offset. - **interceptor.classes**: Kafka source always read keys and values as byte arrays. It's not safe to use ConsumerInterceptor as it may break the query. diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 2ddba2f0d942e..2ef5d3168a87b 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -10,7 +10,7 @@ title: Structured Streaming Programming Guide # Overview Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* -Internally, by default, Structured Streaming queries are processed using a *micro-batch processing* engine, which processes data streams as a series of small batch jobs thereby achieving end-to-end latencies as low as 100 milliseconds and exactly-once fault-tolerance guarantees. However, since Spark 2.3, we have introduced a new low-latency processing mode called **Continuous Processing**, which can achieve end-to-end latencies as low as 1 millisecond with at-least-once guarantees. Without changing the Dataset/DataFrame operations in your queries, you will be able choose the mode based on your application requirements. +Internally, by default, Structured Streaming queries are processed using a *micro-batch processing* engine, which processes data streams as a series of small batch jobs thereby achieving end-to-end latencies as low as 100 milliseconds and exactly-once fault-tolerance guarantees. However, since Spark 2.3, we have introduced a new low-latency processing mode called **Continuous Processing**, which can achieve end-to-end latencies as low as 1 millisecond with at-least-once guarantees. Without changing the Dataset/DataFrame operations in your queries, you will be able to choose the mode based on your application requirements. In this guide, we are going to walk you through the programming model and the APIs. We are going to explain the concepts mostly using the default micro-batch processing model, and then [later](#continuous-processing-experimental) discuss Continuous Processing model. First, let's start with a simple example of a Structured Streaming query - a streaming word count. @@ -1121,7 +1121,7 @@ Let’s discuss the different types of supported stream-stream joins and how to ##### Inner Joins with optional Watermarking Inner joins on any kind of columns along with any kind of join conditions are supported. However, as the stream runs, the size of streaming state will keep growing indefinitely as -*all* past input must be saved as the any new input can match with any input from the past. +*all* past input must be saved as any new input can match with any input from the past. To avoid unbounded state, you have to define additional join conditions such that indefinitely old inputs cannot match with future inputs and therefore can be cleared from the state. In other words, you will have to do the following additional steps in the join. @@ -1839,7 +1839,7 @@ aggDF \ .format("console") \ .start() -# Have all the aggregates in an in memory table. The query name will be the table name +# Have all the aggregates in an in-memory table. The query name will be the table name aggDF \ .writeStream \ .queryName("aggregates") \ diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 0473ab73a5e6c..a3643bf0838a1 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -5,7 +5,7 @@ title: Submitting Applications The `spark-submit` script in Spark's `bin` directory is used to launch applications on a cluster. It can use all of Spark's supported [cluster managers](cluster-overview.html#cluster-manager-types) -through a uniform interface so you don't have to configure your application specially for each one. +through a uniform interface so you don't have to configure your application especially for each one. # Bundling Your Application's Dependencies If your code depends on other projects, you will need to package them alongside @@ -58,7 +58,7 @@ for applications that involve the REPL (e.g. Spark shell). Alternatively, if your application is submitted from a machine far from the worker machines (e.g. locally on your laptop), it is common to use `cluster` mode to minimize network latency between -the drivers and the executors. Currently, standalone mode does not support cluster mode for Python +the drivers and the executors. Currently, the standalone mode does not support cluster mode for Python applications. For Python applications, simply pass a `.py` file in the place of `` instead of a JAR, @@ -68,7 +68,7 @@ There are a few options available that are specific to the [cluster manager](cluster-overview.html#cluster-manager-types) that is being used. For example, with a [Spark standalone cluster](spark-standalone.html) with `cluster` deploy mode, you can also specify `--supervise` to make sure that the driver is automatically restarted if it -fails with non-zero exit code. To enumerate all such options available to `spark-submit`, +fails with a non-zero exit code. To enumerate all such options available to `spark-submit`, run it with `--help`. Here are a few examples of common options: {% highlight bash %} @@ -192,7 +192,7 @@ debugging information by running `spark-submit` with the `--verbose` option. # Advanced Dependency Management When using `spark-submit`, the application jar along with any jars included with the `--jars` option -will be automatically transferred to the cluster. URLs supplied after `--jars` must be separated by commas. That list is included on the driver and executor classpaths. Directory expansion does not work with `--jars`. +will be automatically transferred to the cluster. URLs supplied after `--jars` must be separated by commas. That list is included in the driver and executor classpaths. Directory expansion does not work with `--jars`. Spark uses the following URL scheme to allow different strategies for disseminating jars: From 00d169156d4b1c91d2bcfd788b254b03c509dc41 Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Sat, 20 Jan 2018 14:49:49 -0800 Subject: [PATCH 0158/2461] [SPARK-21786][SQL] The 'spark.sql.parquet.compression.codec' and 'spark.sql.orc.compression.codec' configuration doesn't take effect on hive table writing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [SPARK-21786][SQL] The 'spark.sql.parquet.compression.codec' and 'spark.sql.orc.compression.codec' configuration doesn't take effect on hive table writing What changes were proposed in this pull request? Pass ‘spark.sql.parquet.compression.codec’ value to ‘parquet.compression’. Pass ‘spark.sql.orc.compression.codec’ value to ‘orc.compress’. How was this patch tested? Add test. Note: This is the same issue mentioned in #19218 . That branch was deleted mistakenly, so make a new pr instead. gatorsmile maropu dongjoon-hyun discipleforteen Author: fjh100456 Author: Takeshi Yamamuro Author: Wenchen Fan Author: gatorsmile Author: Yinan Li Author: Marcelo Vanzin Author: Juliusz Sompolski Author: Felix Cheung Author: jerryshao Author: Li Jin Author: Gera Shegalov Author: chetkhatri Author: Joseph K. Bradley Author: Bago Amirbekian Author: Xianjin YE Author: Bruce Robbins Author: zuotingbing Author: Kent Yao Author: hyukjinkwon Author: Adrian Ionescu Closes #20087 from fjh100456/HiveTableWriting. --- .../datasources/orc/OrcOptions.scala | 2 + .../datasources/parquet/ParquetOptions.scala | 6 +- .../sql/hive/execution/HiveOptions.scala | 22 ++ .../sql/hive/execution/SaveAsHiveFile.scala | 20 +- .../sql/hive/CompressionCodecSuite.scala | 353 ++++++++++++++++++ 5 files changed, 397 insertions(+), 6 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala index c866dd834a525..0ad3862f6cf01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala @@ -67,4 +67,6 @@ object OrcOptions { "snappy" -> "SNAPPY", "zlib" -> "ZLIB", "lzo" -> "LZO") + + def getORCCompressionCodecName(name: String): String = shortOrcCompressionCodecNames(name) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index ef67ea7d17cea..f36a89a4c3c5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf /** * Options for the Parquet data source. */ -private[parquet] class ParquetOptions( +class ParquetOptions( @transient private val parameters: CaseInsensitiveMap[String], @transient private val sqlConf: SQLConf) extends Serializable { @@ -82,4 +82,8 @@ object ParquetOptions { "snappy" -> CompressionCodecName.SNAPPY, "gzip" -> CompressionCodecName.GZIP, "lzo" -> CompressionCodecName.LZO) + + def getParquetCompressionCodecName(name: String): String = { + shortParquetCompressionCodecNames(name).name() + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala index 5c515515b9b9c..802ddafdbee4d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala @@ -19,7 +19,16 @@ package org.apache.spark.sql.hive.execution import java.util.Locale +import scala.collection.JavaConverters._ + +import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.orc.OrcConf.COMPRESS +import org.apache.parquet.hadoop.ParquetOutputFormat + import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution.datasources.orc.OrcOptions +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.internal.SQLConf /** * Options for the Hive data source. Note that rule `DetermineHiveSerde` will extract Hive @@ -102,4 +111,17 @@ object HiveOptions { "collectionDelim" -> "colelction.delim", "mapkeyDelim" -> "mapkey.delim", "lineDelim" -> "line.delim").map { case (k, v) => k.toLowerCase(Locale.ROOT) -> v } + + def getHiveWriteCompression(tableInfo: TableDesc, sqlConf: SQLConf): Option[(String, String)] = { + val tableProps = tableInfo.getProperties.asScala.toMap + tableInfo.getOutputFileFormatClassName.toLowerCase(Locale.ROOT) match { + case formatName if formatName.endsWith("parquetoutputformat") => + val compressionCodec = new ParquetOptions(tableProps, sqlConf).compressionCodecClassName + Option((ParquetOutputFormat.COMPRESSION, compressionCodec)) + case formatName if formatName.endsWith("orcoutputformat") => + val compressionCodec = new OrcOptions(tableProps, sqlConf).compressionCodec + Option((COMPRESS.getAttribute, compressionCodec)) + case _ => None + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index 9a6607f2f2c6c..e484356906e87 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -55,18 +55,28 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty, partitionAttributes: Seq[Attribute] = Nil): Set[String] = { - val isCompressed = hadoopConf.get("hive.exec.compress.output", "false").toBoolean + val isCompressed = + fileSinkConf.getTableInfo.getOutputFileFormatClassName.toLowerCase(Locale.ROOT) match { + case formatName if formatName.endsWith("orcoutputformat") => + // For ORC,"mapreduce.output.fileoutputformat.compress", + // "mapreduce.output.fileoutputformat.compress.codec", and + // "mapreduce.output.fileoutputformat.compress.type" + // have no impact because it uses table properties to store compression information. + false + case _ => hadoopConf.get("hive.exec.compress.output", "false").toBoolean + } + if (isCompressed) { - // Please note that isCompressed, "mapreduce.output.fileoutputformat.compress", - // "mapreduce.output.fileoutputformat.compress.codec", and - // "mapreduce.output.fileoutputformat.compress.type" - // have no impact on ORC because it uses table properties to store compression information. hadoopConf.set("mapreduce.output.fileoutputformat.compress", "true") fileSinkConf.setCompressed(true) fileSinkConf.setCompressCodec(hadoopConf .get("mapreduce.output.fileoutputformat.compress.codec")) fileSinkConf.setCompressType(hadoopConf .get("mapreduce.output.fileoutputformat.compress.type")) + } else { + // Set compression by priority + HiveOptions.getHiveWriteCompression(fileSinkConf.getTableInfo, sparkSession.sessionState.conf) + .foreach { case (compression, codec) => hadoopConf.set(compression, codec) } } val committer = FileCommitProtocol.instantiate( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala new file mode 100644 index 0000000000000..d10a6f25c64fc --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.Path +import org.apache.orc.OrcConf.COMPRESS +import org.apache.parquet.hadoop.ParquetOutputFormat +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.execution.datasources.orc.OrcOptions +import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetTest} +import org.apache.spark.sql.hive.orc.OrcFileOperator +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf + +class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with BeforeAndAfterAll { + import spark.implicits._ + + override def beforeAll(): Unit = { + super.beforeAll() + (0 until maxRecordNum).toDF("a").createOrReplaceTempView("table_source") + } + + override def afterAll(): Unit = { + try { + spark.catalog.dropTempView("table_source") + } finally { + super.afterAll() + } + } + + private val maxRecordNum = 50 + + private def getConvertMetastoreConfName(format: String): String = format.toLowerCase match { + case "parquet" => HiveUtils.CONVERT_METASTORE_PARQUET.key + case "orc" => HiveUtils.CONVERT_METASTORE_ORC.key + } + + private def getSparkCompressionConfName(format: String): String = format.toLowerCase match { + case "parquet" => SQLConf.PARQUET_COMPRESSION.key + case "orc" => SQLConf.ORC_COMPRESSION.key + } + + private def getHiveCompressPropName(format: String): String = format.toLowerCase match { + case "parquet" => ParquetOutputFormat.COMPRESSION + case "orc" => COMPRESS.getAttribute + } + + private def normalizeCodecName(format: String, name: String): String = { + format.toLowerCase match { + case "parquet" => ParquetOptions.getParquetCompressionCodecName(name) + case "orc" => OrcOptions.getORCCompressionCodecName(name) + } + } + + private def getTableCompressionCodec(path: String, format: String): Seq[String] = { + val hadoopConf = spark.sessionState.newHadoopConf() + val codecs = format.toLowerCase match { + case "parquet" => for { + footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConf) + block <- footer.getParquetMetadata.getBlocks.asScala + column <- block.getColumns.asScala + } yield column.getCodec.name() + case "orc" => new File(path).listFiles().filter { file => + file.isFile && !file.getName.endsWith(".crc") && file.getName != "_SUCCESS" + }.map { orcFile => + OrcFileOperator.getFileReader(orcFile.toPath.toString).get.getCompression.toString + }.toSeq + } + codecs.distinct + } + + private def createTable( + rootDir: File, + tableName: String, + isPartitioned: Boolean, + format: String, + compressionCodec: Option[String]): Unit = { + val tblProperties = compressionCodec match { + case Some(prop) => s"TBLPROPERTIES('${getHiveCompressPropName(format)}'='$prop')" + case _ => "" + } + val partitionCreate = if (isPartitioned) "PARTITIONED BY (p string)" else "" + sql( + s""" + |CREATE TABLE $tableName(a int) + |$partitionCreate + |STORED AS $format + |LOCATION '${rootDir.toURI.toString.stripSuffix("/")}/$tableName' + |$tblProperties + """.stripMargin) + } + + private def writeDataToTable( + tableName: String, + partitionValue: Option[String]): Unit = { + val partitionInsert = partitionValue.map(p => s"partition (p='$p')").mkString + sql( + s""" + |INSERT INTO TABLE $tableName + |$partitionInsert + |SELECT * FROM table_source + """.stripMargin) + } + + private def writeDateToTableUsingCTAS( + rootDir: File, + tableName: String, + partitionValue: Option[String], + format: String, + compressionCodec: Option[String]): Unit = { + val partitionCreate = partitionValue.map(p => s"PARTITIONED BY (p)").mkString + val compressionOption = compressionCodec.map { codec => + s",'${getHiveCompressPropName(format)}'='$codec'" + }.mkString + val partitionSelect = partitionValue.map(p => s",'$p' AS p").mkString + sql( + s""" + |CREATE TABLE $tableName + |USING $format + |OPTIONS('path'='${rootDir.toURI.toString.stripSuffix("/")}/$tableName' $compressionOption) + |$partitionCreate + |AS SELECT * $partitionSelect FROM table_source + """.stripMargin) + } + + private def getPreparedTablePath( + tmpDir: File, + tableName: String, + isPartitioned: Boolean, + format: String, + compressionCodec: Option[String], + usingCTAS: Boolean): String = { + val partitionValue = if (isPartitioned) Some("test") else None + if (usingCTAS) { + writeDateToTableUsingCTAS(tmpDir, tableName, partitionValue, format, compressionCodec) + } else { + createTable(tmpDir, tableName, isPartitioned, format, compressionCodec) + writeDataToTable(tableName, partitionValue) + } + getTablePartitionPath(tmpDir, tableName, partitionValue) + } + + private def getTableSize(path: String): Long = { + val dir = new File(path) + val files = dir.listFiles().filter(_.getName.startsWith("part-")) + files.map(_.length()).sum + } + + private def getTablePartitionPath( + dir: File, + tableName: String, + partitionValue: Option[String]) = { + val partitionPath = partitionValue.map(p => s"p=$p").mkString + s"${dir.getPath.stripSuffix("/")}/$tableName/$partitionPath" + } + + private def getUncompressedDataSizeByFormat( + format: String, isPartitioned: Boolean, usingCTAS: Boolean): Long = { + var totalSize = 0L + val tableName = s"tbl_$format" + val codecName = normalizeCodecName(format, "uncompressed") + withSQLConf(getSparkCompressionConfName(format) -> codecName) { + withTempDir { tmpDir => + withTable(tableName) { + val compressionCodec = Option(codecName) + val path = getPreparedTablePath( + tmpDir, tableName, isPartitioned, format, compressionCodec, usingCTAS) + totalSize = getTableSize(path) + } + } + } + assert(totalSize > 0L) + totalSize + } + + private def checkCompressionCodecForTable( + format: String, + isPartitioned: Boolean, + compressionCodec: Option[String], + usingCTAS: Boolean) + (assertion: (String, Long) => Unit): Unit = { + val tableName = + if (usingCTAS) s"tbl_$format$isPartitioned" else s"tbl_$format${isPartitioned}_CAST" + withTempDir { tmpDir => + withTable(tableName) { + val path = getPreparedTablePath( + tmpDir, tableName, isPartitioned, format, compressionCodec, usingCTAS) + val relCompressionCodecs = getTableCompressionCodec(path, format) + assert(relCompressionCodecs.length == 1) + val tableSize = getTableSize(path) + assertion(relCompressionCodecs.head, tableSize) + } + } + } + + private def checkTableCompressionCodecForCodecs( + format: String, + isPartitioned: Boolean, + convertMetastore: Boolean, + usingCTAS: Boolean, + compressionCodecs: List[String], + tableCompressionCodecs: List[String]) + (assertionCompressionCodec: (Option[String], String, String, Long) => Unit): Unit = { + withSQLConf(getConvertMetastoreConfName(format) -> convertMetastore.toString) { + tableCompressionCodecs.foreach { tableCompression => + compressionCodecs.foreach { sessionCompressionCodec => + withSQLConf(getSparkCompressionConfName(format) -> sessionCompressionCodec) { + // 'tableCompression = null' means no table-level compression + val compression = Option(tableCompression) + checkCompressionCodecForTable(format, isPartitioned, compression, usingCTAS) { + case (realCompressionCodec, tableSize) => + assertionCompressionCodec( + compression, sessionCompressionCodec, realCompressionCodec, tableSize) + } + } + } + } + } + } + + // When the amount of data is small, compressed data size may be larger than uncompressed one, + // so we just check the difference when compressionCodec is not NONE or UNCOMPRESSED. + private def checkTableSize( + format: String, + compressionCodec: String, + isPartitioned: Boolean, + convertMetastore: Boolean, + usingCTAS: Boolean, + tableSize: Long): Boolean = { + val uncompressedSize = getUncompressedDataSizeByFormat(format, isPartitioned, usingCTAS) + compressionCodec match { + case "UNCOMPRESSED" if format == "parquet" => tableSize == uncompressedSize + case "NONE" if format == "orc" => tableSize == uncompressedSize + case _ => tableSize != uncompressedSize + } + } + + def checkForTableWithCompressProp(format: String, compressCodecs: List[String]): Unit = { + Seq(true, false).foreach { isPartitioned => + Seq(true, false).foreach { convertMetastore => + // TODO: Also verify CTAS(usingCTAS=true) cases when the bug(SPARK-22926) is fixed. + Seq(false).foreach { usingCTAS => + checkTableCompressionCodecForCodecs( + format, + isPartitioned, + convertMetastore, + usingCTAS, + compressionCodecs = compressCodecs, + tableCompressionCodecs = compressCodecs) { + case (tableCodec, sessionCodec, realCodec, tableSize) => + // For non-partitioned table and when convertMetastore is true, Expect session-level + // take effect, and in other cases expect table-level take effect + // TODO: It should always be table-level taking effect when the bug(SPARK-22926) + // is fixed + val expectCodec = + if (convertMetastore && !isPartitioned) sessionCodec else tableCodec.get + assert(expectCodec == realCodec) + assert(checkTableSize( + format, expectCodec, isPartitioned, convertMetastore, usingCTAS, tableSize)) + } + } + } + } + } + + def checkForTableWithoutCompressProp(format: String, compressCodecs: List[String]): Unit = { + Seq(true, false).foreach { isPartitioned => + Seq(true, false).foreach { convertMetastore => + // TODO: Also verify CTAS(usingCTAS=true) cases when the bug(SPARK-22926) is fixed. + Seq(false).foreach { usingCTAS => + checkTableCompressionCodecForCodecs( + format, + isPartitioned, + convertMetastore, + usingCTAS, + compressionCodecs = compressCodecs, + tableCompressionCodecs = List(null)) { + case (tableCodec, sessionCodec, realCodec, tableSize) => + // Always expect session-level take effect + assert(sessionCodec == realCodec) + assert(checkTableSize( + format, sessionCodec, isPartitioned, convertMetastore, usingCTAS, tableSize)) + } + } + } + } + } + + test("both table-level and session-level compression are set") { + checkForTableWithCompressProp("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) + checkForTableWithCompressProp("orc", List("NONE", "SNAPPY", "ZLIB")) + } + + test("table-level compression is not set but session-level compressions is set ") { + checkForTableWithoutCompressProp("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) + checkForTableWithoutCompressProp("orc", List("NONE", "SNAPPY", "ZLIB")) + } + + def checkTableWriteWithCompressionCodecs(format: String, compressCodecs: List[String]): Unit = { + Seq(true, false).foreach { isPartitioned => + Seq(true, false).foreach { convertMetastore => + withTempDir { tmpDir => + val tableName = s"tbl_$format$isPartitioned" + createTable(tmpDir, tableName, isPartitioned, format, None) + withTable(tableName) { + compressCodecs.foreach { compressionCodec => + val partitionValue = if (isPartitioned) Some(compressionCodec) else None + withSQLConf(getConvertMetastoreConfName(format) -> convertMetastore.toString, + getSparkCompressionConfName(format) -> compressionCodec + ) { writeDataToTable(tableName, partitionValue) } + } + val tablePath = getTablePartitionPath(tmpDir, tableName, None) + val realCompressionCodecs = + if (isPartitioned) compressCodecs.flatMap { codec => + getTableCompressionCodec(s"$tablePath/p=$codec", format) + } else { + getTableCompressionCodec(tablePath, format) + } + + assert(realCompressionCodecs.distinct.sorted == compressCodecs.sorted) + val recordsNum = sql(s"SELECT * from $tableName").count() + assert(recordsNum == maxRecordNum * compressCodecs.length) + } + } + } + } + } + + test("test table containing mixed compression codec") { + checkTableWriteWithCompressionCodecs("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) + checkTableWriteWithCompressionCodecs("orc", List("NONE", "SNAPPY", "ZLIB")) + } +} From 121dc96f088a7b157d5b2cffb626b0e22d1fc052 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 20 Jan 2018 22:39:49 -0800 Subject: [PATCH 0159/2461] [SPARK-23087][SQL] CheckCartesianProduct too restrictive when condition is false/null ## What changes were proposed in this pull request? CheckCartesianProduct raises an AnalysisException also when the join condition is always false/null. In this case, we shouldn't raise it, since the result will not be a cartesian product. ## How was this patch tested? added UT Author: Marco Gaido Closes #20333 from mgaido91/SPARK-23087. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 10 +++++++--- .../org/apache/spark/sql/DataFrameJoinSuite.scala | 14 ++++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c794ba8619322..0f9daa5f04c76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1108,15 +1108,19 @@ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper { */ def isCartesianProduct(join: Join): Boolean = { val conditions = join.condition.map(splitConjunctivePredicates).getOrElse(Nil) - !conditions.map(_.references).exists(refs => refs.exists(join.left.outputSet.contains) - && refs.exists(join.right.outputSet.contains)) + + conditions match { + case Seq(Literal.FalseLiteral) | Seq(Literal(null, BooleanType)) => false + case _ => !conditions.map(_.references).exists(refs => + refs.exists(join.left.outputSet.contains) && refs.exists(join.right.outputSet.contains)) + } } def apply(plan: LogicalPlan): LogicalPlan = if (SQLConf.get.crossJoinEnabled) { plan } else plan transform { - case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, condition) + case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _) if isCartesianProduct(j) => throw new AnalysisException( s"""Detected cartesian product for ${j.joinType.sql} join between logical plans diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index aef0d7f3e425b..1656f290ee19c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -274,4 +274,18 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { checkAnswer(innerJoin, Row(1) :: Nil) } + test("SPARK-23087: don't throw Analysis Exception in CheckCartesianProduct when join condition " + + "is false or null") { + val df = spark.range(10) + val dfNull = spark.range(10).select(lit(null).as("b")) + val planNull = df.join(dfNull, $"id" === $"b", "left").queryExecution.analyzed + + spark.sessionState.executePlan(planNull).optimizedPlan + + val dfOne = df.select(lit(1).as("a")) + val dfTwo = spark.range(10).select(lit(2).as("b")) + val planFalse = dfOne.join(dfTwo, $"a" === $"b", "left").queryExecution.analyzed + + spark.sessionState.executePlan(planFalse).optimizedPlan + } } From 4f43d27c9e97be8605b120b3d7c11c7c61e3ca6f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 21 Jan 2018 08:51:12 -0600 Subject: [PATCH 0160/2461] [SPARK-22119][ML] Add cosine distance to KMeans ## What changes were proposed in this pull request? Currently, KMeans assumes the only possible distance measure to be used is the Euclidean. This PR aims to add the cosine distance support to the KMeans algorithm. ## How was this patch tested? existing and added UTs. Author: Marco Gaido Author: Marco Gaido Closes #19340 from mgaido91/SPARK-22119. --- .../apache/spark/ml/clustering/KMeans.scala | 22 +- .../mllib/clustering/BisectingKMeans.scala | 11 +- .../spark/mllib/clustering/KMeans.scala | 216 ++++++++++++++---- .../spark/mllib/clustering/KMeansModel.scala | 74 +++++- .../spark/mllib/clustering/LocalKMeans.scala | 10 +- .../spark/ml/clustering/KMeansSuite.scala | 42 +++- .../spark/mllib/clustering/KMeansSuite.scala | 6 +- 7 files changed, 315 insertions(+), 66 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index f2af7fe082b41..c8145de564cbe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -26,7 +26,7 @@ import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} +import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD @@ -71,6 +71,15 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") def getInitMode: String = $(initMode) + @Since("2.4.0") + final val distanceMeasure = new Param[String](this, "distanceMeasure", "The distance measure. " + + "Supported options: 'euclidean' and 'cosine'.", + (value: String) => MLlibKMeans.validateDistanceMeasure(value)) + + /** @group expertGetParam */ + @Since("2.4.0") + def getDistanceMeasure: String = $(distanceMeasure) + /** * Param for the number of steps for the k-means|| initialization mode. This is an advanced * setting -- the default of 2 is almost always enough. Must be > 0. Default: 2. @@ -260,7 +269,8 @@ class KMeans @Since("1.5.0") ( maxIter -> 20, initMode -> MLlibKMeans.K_MEANS_PARALLEL, initSteps -> 2, - tol -> 1e-4) + tol -> 1e-4, + distanceMeasure -> DistanceMeasure.EUCLIDEAN) @Since("1.5.0") override def copy(extra: ParamMap): KMeans = defaultCopy(extra) @@ -284,6 +294,10 @@ class KMeans @Since("1.5.0") ( @Since("1.5.0") def setInitMode(value: String): this.type = set(initMode, value) + /** @group expertSetParam */ + @Since("2.4.0") + def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) + /** @group expertSetParam */ @Since("1.5.0") def setInitSteps(value: Int): this.type = set(initSteps, value) @@ -314,7 +328,8 @@ class KMeans @Since("1.5.0") ( } val instr = Instrumentation.create(this, instances) - instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol) + instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure, + maxIter, seed, tol) val algo = new MLlibKMeans() .setK($(k)) .setInitializationMode($(initMode)) @@ -322,6 +337,7 @@ class KMeans @Since("1.5.0") ( .setMaxIterations($(maxIter)) .setSeed($(seed)) .setEpsilon($(tol)) + .setDistanceMeasure($(distanceMeasure)) val parentModel = algo.run(instances, Option(instr)) val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala index 9b9c70cfe5109..2221f4c0edc17 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -350,7 +350,7 @@ private object BisectingKMeans extends Serializable { val newClusterChildren = children.filter(newClusterCenters.contains(_)) if (newClusterChildren.nonEmpty) { val selected = newClusterChildren.minBy { child => - KMeans.fastSquaredDistance(newClusterCenters(child), v) + EuclideanDistanceMeasure.fastSquaredDistance(newClusterCenters(child), v) } (selected, v) } else { @@ -387,7 +387,7 @@ private object BisectingKMeans extends Serializable { val rightIndex = rightChildIndex(rawIndex) val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_)) val height = math.sqrt(indexes.map { childIndex => - KMeans.fastSquaredDistance(center, clusters(childIndex).center) + EuclideanDistanceMeasure.fastSquaredDistance(center, clusters(childIndex).center) }.max) val children = indexes.map(buildSubTree(_)).toArray new ClusteringTreeNode(index, size, center, cost, height, children) @@ -457,7 +457,7 @@ private[clustering] class ClusteringTreeNode private[clustering] ( this :: Nil } else { val selected = children.minBy { child => - KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm) + EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm) } selected :: selected.predictPath(pointWithNorm) } @@ -475,7 +475,8 @@ private[clustering] class ClusteringTreeNode private[clustering] ( * Predicts the cluster index and the cost of the input point. */ private def predict(pointWithNorm: VectorWithNorm): (Int, Double) = { - predict(pointWithNorm, KMeans.fastSquaredDistance(centerWithNorm, pointWithNorm)) + predict(pointWithNorm, + EuclideanDistanceMeasure.fastSquaredDistance(centerWithNorm, pointWithNorm)) } /** @@ -490,7 +491,7 @@ private[clustering] class ClusteringTreeNode private[clustering] ( (index, cost) } else { val (selectedChild, minCost) = children.map { child => - (child, KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm)) + (child, EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm)) }.minBy(_._2) selectedChild.predict(pointWithNorm, minCost) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 49043b5acb807..607145cb59fba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -25,7 +25,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.clustering.{KMeans => NewKMeans} import org.apache.spark.ml.util.Instrumentation import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.linalg.BLAS.{axpy, scal} +import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -46,14 +46,23 @@ class KMeans private ( private var initializationMode: String, private var initializationSteps: Int, private var epsilon: Double, - private var seed: Long) extends Serializable with Logging { + private var seed: Long, + private var distanceMeasure: String) extends Serializable with Logging { + + @Since("0.8.0") + private def this(k: Int, maxIterations: Int, initializationMode: String, initializationSteps: Int, + epsilon: Double, seed: Long) = + this(k, maxIterations, initializationMode, initializationSteps, + epsilon, seed, DistanceMeasure.EUCLIDEAN) /** * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, - * initializationMode: "k-means||", initializationSteps: 2, epsilon: 1e-4, seed: random}. + * initializationMode: "k-means||", initializationSteps: 2, epsilon: 1e-4, seed: random, + * distanceMeasure: "euclidean"}. */ @Since("0.8.0") - def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong()) + def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong(), + DistanceMeasure.EUCLIDEAN) /** * Number of clusters to create (k). @@ -184,6 +193,22 @@ class KMeans private ( this } + /** + * The distance suite used by the algorithm. + */ + @Since("2.4.0") + def getDistanceMeasure: String = distanceMeasure + + /** + * Set the distance suite used by the algorithm. + */ + @Since("2.4.0") + def setDistanceMeasure(distanceMeasure: String): this.type = { + KMeans.validateDistanceMeasure(distanceMeasure) + this.distanceMeasure = distanceMeasure + this + } + // Initial cluster centers can be provided as a KMeansModel object rather than using the // random or k-means|| initializationMode private var initialModel: Option[KMeansModel] = None @@ -246,6 +271,8 @@ class KMeans private ( val initStartTime = System.nanoTime() + val distanceMeasureInstance = DistanceMeasure.decodeFromString(this.distanceMeasure) + val centers = initialModel match { case Some(kMeansCenters) => kMeansCenters.clusterCenters.map(new VectorWithNorm(_)) @@ -253,7 +280,7 @@ class KMeans private ( if (initializationMode == KMeans.RANDOM) { initRandom(data) } else { - initKMeansParallel(data) + initKMeansParallel(data, distanceMeasureInstance) } } val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9 @@ -281,7 +308,7 @@ class KMeans private ( val counts = Array.fill(thisCenters.length)(0L) points.foreach { point => - val (bestCenter, cost) = KMeans.findClosest(thisCenters, point) + val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point) costAccum.add(cost) val sum = sums(bestCenter) axpy(1.0, point.vector, sum) @@ -302,7 +329,8 @@ class KMeans private ( // Update the cluster centers and costs converged = true newCenters.foreach { case (j, newCenter) => - if (converged && KMeans.fastSquaredDistance(newCenter, centers(j)) > epsilon * epsilon) { + if (converged && + !distanceMeasureInstance.isCenterConverged(centers(j), newCenter, epsilon)) { converged = false } centers(j) = newCenter @@ -323,7 +351,7 @@ class KMeans private ( logInfo(s"The cost is $cost.") - new KMeansModel(centers.map(_.vector)) + new KMeansModel(centers.map(_.vector), distanceMeasure) } /** @@ -345,7 +373,8 @@ class KMeans private ( * * The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf. */ - private[clustering] def initKMeansParallel(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = { + private[clustering] def initKMeansParallel(data: RDD[VectorWithNorm], + distanceMeasureInstance: DistanceMeasure): Array[VectorWithNorm] = { // Initialize empty centers and point costs. var costs = data.map(_ => Double.PositiveInfinity) @@ -369,7 +398,7 @@ class KMeans private ( bcNewCentersList += bcNewCenters val preCosts = costs costs = data.zip(preCosts).map { case (point, cost) => - math.min(KMeans.pointCost(bcNewCenters.value, point), cost) + math.min(distanceMeasureInstance.pointCost(bcNewCenters.value, point), cost) }.persist(StorageLevel.MEMORY_AND_DISK) val sumCosts = costs.sum() @@ -397,7 +426,9 @@ class KMeans private ( // candidate by the number of points in the dataset mapping to it and run a local k-means++ // on the weighted centers to pick k of them val bcCenters = data.context.broadcast(distinctCenters) - val countMap = data.map(KMeans.findClosest(bcCenters.value, _)._1).countByValue() + val countMap = data + .map(distanceMeasureInstance.findClosest(bcCenters.value, _)._1) + .countByValue() bcCenters.destroy(blocking = false) @@ -546,10 +577,110 @@ object KMeans { .run(data) } + private[spark] def validateInitMode(initMode: String): Boolean = { + initMode match { + case KMeans.RANDOM => true + case KMeans.K_MEANS_PARALLEL => true + case _ => false + } + } + + private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean = { + distanceMeasure match { + case DistanceMeasure.EUCLIDEAN => true + case DistanceMeasure.COSINE => true + case _ => false + } + } +} + +/** + * A vector with its norm for fast distance computation. + */ +private[clustering] class VectorWithNorm(val vector: Vector, val norm: Double) + extends Serializable { + + def this(vector: Vector) = this(vector, Vectors.norm(vector, 2.0)) + + def this(array: Array[Double]) = this(Vectors.dense(array)) + + /** Converts the vector to a dense vector. */ + def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm) +} + + +private[spark] abstract class DistanceMeasure extends Serializable { + + /** + * @return the index of the closest center to the given point, as well as the cost. + */ + def findClosest( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): (Int, Double) = { + var bestDistance = Double.PositiveInfinity + var bestIndex = 0 + var i = 0 + centers.foreach { center => + val currentDistance = distance(center, point) + if (currentDistance < bestDistance) { + bestDistance = currentDistance + bestIndex = i + } + i += 1 + } + (bestIndex, bestDistance) + } + /** - * Returns the index of the closest center to the given point, as well as the squared distance. + * @return the K-means cost of a given point against the given cluster centers. */ - private[mllib] def findClosest( + def pointCost( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): Double = { + findClosest(centers, point)._2 + } + + /** + * @return whether a center converged or not, given the epsilon parameter. + */ + def isCenterConverged( + oldCenter: VectorWithNorm, + newCenter: VectorWithNorm, + epsilon: Double): Boolean = { + distance(oldCenter, newCenter) <= epsilon + } + + /** + * @return the cosine distance between two points. + */ + def distance( + v1: VectorWithNorm, + v2: VectorWithNorm): Double + +} + +@Since("2.4.0") +object DistanceMeasure { + + @Since("2.4.0") + val EUCLIDEAN = "euclidean" + @Since("2.4.0") + val COSINE = "cosine" + + private[spark] def decodeFromString(distanceMeasure: String): DistanceMeasure = + distanceMeasure match { + case EUCLIDEAN => new EuclideanDistanceMeasure + case COSINE => new CosineDistanceMeasure + case _ => throw new IllegalArgumentException(s"distanceMeasure must be one of: " + + s"$EUCLIDEAN, $COSINE. $distanceMeasure provided.") + } +} + +private[spark] class EuclideanDistanceMeasure extends DistanceMeasure { + /** + * @return the index of the closest center to the given point, as well as the squared distance. + */ + override def findClosest( centers: TraversableOnce[VectorWithNorm], point: VectorWithNorm): (Int, Double) = { var bestDistance = Double.PositiveInfinity @@ -561,7 +692,7 @@ object KMeans { var lowerBoundOfSqDist = center.norm - point.norm lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist if (lowerBoundOfSqDist < bestDistance) { - val distance: Double = fastSquaredDistance(center, point) + val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point) if (distance < bestDistance) { bestDistance = distance bestIndex = i @@ -573,15 +704,29 @@ object KMeans { } /** - * Returns the K-means cost of a given point against the given cluster centers. + * @return whether a center converged or not, given the epsilon parameter. */ - private[mllib] def pointCost( - centers: TraversableOnce[VectorWithNorm], - point: VectorWithNorm): Double = - findClosest(centers, point)._2 + override def isCenterConverged( + oldCenter: VectorWithNorm, + newCenter: VectorWithNorm, + epsilon: Double): Boolean = { + EuclideanDistanceMeasure.fastSquaredDistance(newCenter, oldCenter) <= epsilon * epsilon + } + + /** + * @param v1: first vector + * @param v2: second vector + * @return the Euclidean distance between the two input vectors + */ + override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { + Math.sqrt(EuclideanDistanceMeasure.fastSquaredDistance(v1, v2)) + } +} + +private[spark] object EuclideanDistanceMeasure { /** - * Returns the squared Euclidean distance between two vectors computed by + * @return the squared Euclidean distance between two vectors computed by * [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]]. */ private[clustering] def fastSquaredDistance( @@ -589,28 +734,15 @@ object KMeans { v2: VectorWithNorm): Double = { MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm) } - - private[spark] def validateInitMode(initMode: String): Boolean = { - initMode match { - case KMeans.RANDOM => true - case KMeans.K_MEANS_PARALLEL => true - case _ => false - } - } } -/** - * A vector with its norm for fast distance computation. - * - * @see [[org.apache.spark.mllib.clustering.KMeans#fastSquaredDistance]] - */ -private[clustering] -class VectorWithNorm(val vector: Vector, val norm: Double) extends Serializable { - - def this(vector: Vector) = this(vector, Vectors.norm(vector, 2.0)) - - def this(array: Array[Double]) = this(Vectors.dense(array)) - - /** Converts the vector to a dense vector. */ - def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm) +private[spark] class CosineDistanceMeasure extends DistanceMeasure { + /** + * @param v1: first vector + * @param v2: second vector + * @return the cosine distance between the two input vectors + */ + override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { + 1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 3ad08c46d204d..a78c21e838e44 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -36,12 +36,20 @@ import org.apache.spark.sql.{Row, SparkSession} * A clustering model for K-means. Each point belongs to the cluster with the closest center. */ @Since("0.8.0") -class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vector]) +class KMeansModel @Since("2.4.0") (@Since("1.0.0") val clusterCenters: Array[Vector], + @Since("2.4.0") val distanceMeasure: String) extends Saveable with Serializable with PMMLExportable { + private val distanceMeasureInstance: DistanceMeasure = + DistanceMeasure.decodeFromString(distanceMeasure) + private val clusterCentersWithNorm = if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_)) + @Since("1.1.0") + def this(clusterCenters: Array[Vector]) = + this(clusterCenters: Array[Vector], DistanceMeasure.EUCLIDEAN) + /** * A Java-friendly constructor that takes an Iterable of Vectors. */ @@ -59,7 +67,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec */ @Since("0.8.0") def predict(point: Vector): Int = { - KMeans.findClosest(clusterCentersWithNorm, new VectorWithNorm(point))._1 + distanceMeasureInstance.findClosest(clusterCentersWithNorm, new VectorWithNorm(point))._1 } /** @@ -68,7 +76,8 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec @Since("1.0.0") def predict(points: RDD[Vector]): RDD[Int] = { val bcCentersWithNorm = points.context.broadcast(clusterCentersWithNorm) - points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1) + points.map(p => + distanceMeasureInstance.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1) } /** @@ -85,8 +94,9 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec @Since("0.8.0") def computeCost(data: RDD[Vector]): Double = { val bcCentersWithNorm = data.context.broadcast(clusterCentersWithNorm) - val cost = data - .map(p => KMeans.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))).sum() + val cost = data.map(p => + distanceMeasureInstance.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))) + .sum() bcCentersWithNorm.destroy(blocking = false) cost } @@ -94,7 +104,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec @Since("1.4.0") override def save(sc: SparkContext, path: String): Unit = { - KMeansModel.SaveLoadV1_0.save(sc, this, path) + KMeansModel.SaveLoadV2_0.save(sc, this, path) } override protected def formatVersion: String = "1.0" @@ -105,7 +115,20 @@ object KMeansModel extends Loader[KMeansModel] { @Since("1.4.0") override def load(sc: SparkContext, path: String): KMeansModel = { - KMeansModel.SaveLoadV1_0.load(sc, path) + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + val classNameV2_0 = SaveLoadV2_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + SaveLoadV1_0.load(sc, path) + case (className, "2.0") if className == classNameV2_0 => + SaveLoadV2_0.load(sc, path) + case _ => throw new Exception( + s"KMeansModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)\n" + + s" ($classNameV2_0, 2.0)") + } } private case class Cluster(id: Int, point: Vector) @@ -116,8 +139,7 @@ object KMeansModel extends Loader[KMeansModel] { } } - private[clustering] - object SaveLoadV1_0 { + private[clustering] object SaveLoadV1_0 { private val thisFormatVersion = "1.0" @@ -149,4 +171,38 @@ object KMeansModel extends Loader[KMeansModel] { new KMeansModel(localCentroids.sortBy(_.id).map(_.point)) } } + + private[clustering] object SaveLoadV2_0 { + + private val thisFormatVersion = "2.0" + + private[clustering] val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel" + + def save(sc: SparkContext, model: KMeansModel, path: String): Unit = { + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) + ~ ("k" -> model.k) ~ ("distanceMeasure" -> model.distanceMeasure))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex).map { case (p, id) => + Cluster(id, p.vector) + } + spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): KMeansModel = { + implicit val formats = DefaultFormats + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val k = (metadata \ "k").extract[Int] + val centroids = spark.read.parquet(Loader.dataPath(path)) + Loader.checkSchema[Cluster](centroids.schema) + val localCentroids = centroids.rdd.map(Cluster.apply).collect() + assert(k == localCentroids.length) + val distanceMeasure = (metadata \ "distanceMeasure").extract[String] + new KMeansModel(localCentroids.sortBy(_.id).map(_.point), distanceMeasure) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala index 53587670a5db0..4a08c0a55e68f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala @@ -46,7 +46,7 @@ private[mllib] object LocalKMeans extends Logging { // Initialize centers by sampling using the k-means++ procedure. centers(0) = pickWeighted(rand, points, weights).toDense - val costArray = points.map(KMeans.fastSquaredDistance(_, centers(0))) + val costArray = points.map(EuclideanDistanceMeasure.fastSquaredDistance(_, centers(0))) for (i <- 1 until k) { val sum = costArray.zip(weights).map(p => p._1 * p._2).sum @@ -67,11 +67,15 @@ private[mllib] object LocalKMeans extends Logging { // update costArray for (p <- points.indices) { - costArray(p) = math.min(KMeans.fastSquaredDistance(points(p), centers(i)), costArray(p)) + costArray(p) = math.min( + EuclideanDistanceMeasure.fastSquaredDistance(points(p), centers(i)), + costArray(p)) } } + val distanceMeasureInstance = new EuclideanDistanceMeasure + // Run up to maxIterations iterations of Lloyd's algorithm val oldClosest = Array.fill(points.length)(-1) var iteration = 0 @@ -83,7 +87,7 @@ private[mllib] object LocalKMeans extends Logging { var i = 0 while (i < points.length) { val p = points(i) - val index = KMeans.findClosest(centers, p)._1 + val index = distanceMeasureInstance.findClosest(centers, p)._1 axpy(weights(i), p.vector, sums(index)) counts(index) += weights(i) if (index != oldClosest(i)) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 119fe1dead9a9..e4506f23feb31 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} +import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} @@ -50,6 +50,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL) assert(kmeans.getInitSteps === 2) assert(kmeans.getTol === 1e-4) + assert(kmeans.getDistanceMeasure === DistanceMeasure.EUCLIDEAN) val model = kmeans.setMaxIter(1).fit(dataset) MLTestingUtils.checkCopyAndUids(kmeans, model) @@ -68,6 +69,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR .setInitSteps(3) .setSeed(123) .setTol(1e-3) + .setDistanceMeasure(DistanceMeasure.COSINE) assert(kmeans.getK === 9) assert(kmeans.getFeaturesCol === "test_feature") @@ -77,6 +79,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(kmeans.getInitSteps === 3) assert(kmeans.getSeed === 123) assert(kmeans.getTol === 1e-3) + assert(kmeans.getDistanceMeasure === DistanceMeasure.COSINE) } test("parameters validation") { @@ -89,6 +92,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR intercept[IllegalArgumentException] { new KMeans().setInitSteps(0) } + intercept[IllegalArgumentException] { + new KMeans().setDistanceMeasure("no_such_a_measure") + } } test("fit, transform and summary") { @@ -144,6 +150,37 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(model.getPredictionCol == predictionColName) } + test("KMeans using cosine distance") { + val df = spark.createDataFrame(spark.sparkContext.parallelize(Array( + Vectors.dense(1.0, 1.0), + Vectors.dense(10.0, 10.0), + Vectors.dense(1.0, 0.5), + Vectors.dense(10.0, 4.4), + Vectors.dense(-1.0, 1.0), + Vectors.dense(-100.0, 90.0) + )).map(v => TestRow(v))) + + val model = new KMeans() + .setK(3) + .setSeed(1) + .setInitMode(MLlibKMeans.RANDOM) + .setTol(1e-6) + .setDistanceMeasure(DistanceMeasure.COSINE) + .fit(df) + + val predictionDf = model.transform(df) + assert(predictionDf.select("prediction").distinct().count() == 3) + val predictionsMap = predictionDf.collect().map(row => + row.getAs[Vector]("features") -> row.getAs[Int]("prediction")).toMap + assert(predictionsMap(Vectors.dense(1.0, 1.0)) == + predictionsMap(Vectors.dense(10.0, 10.0))) + assert(predictionsMap(Vectors.dense(1.0, 0.5)) == + predictionsMap(Vectors.dense(10.0, 4.4))) + assert(predictionsMap(Vectors.dense(-1.0, 1.0)) == + predictionsMap(Vectors.dense(-100.0, 90.0))) + + } + test("read/write") { def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = { assert(model.clusterCenters === model2.clusterCenters) @@ -182,6 +219,7 @@ object KMeansSuite { "predictionCol" -> "myPrediction", "k" -> 3, "maxIter" -> 2, - "tol" -> 0.01 + "tol" -> 0.01, + "distanceMeasure" -> DistanceMeasure.EUCLIDEAN ) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 00d7e2f2d3864..1b98250061c7a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -89,7 +89,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { .setInitializationMode("k-means||") .setInitializationSteps(10) .setSeed(seed) - val initialCenters = km.initKMeansParallel(normedData).map(_.vector) + + val distanceMeasureInstance = new EuclideanDistanceMeasure + val initialCenters = km.initKMeansParallel(normedData, distanceMeasureInstance).map(_.vector) assert(initialCenters.length === initialCenters.distinct.length) assert(initialCenters.length <= numDistinctPoints) @@ -104,7 +106,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { .setInitializationMode("k-means||") .setInitializationSteps(10) .setSeed(seed) - val initialCenters2 = km2.initKMeansParallel(normedData).map(_.vector) + val initialCenters2 = km2.initKMeansParallel(normedData, distanceMeasureInstance).map(_.vector) assert(initialCenters2.length === initialCenters2.distinct.length) assert(initialCenters2.length === k) From 2239d7a410e906ccd40aa8e84d637e9d06cd7b8a Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 21 Jan 2018 11:23:51 -0800 Subject: [PATCH 0161/2461] [SPARK-21293][SS][SPARKR] Add doc example for streaming join, dedup ## What changes were proposed in this pull request? streaming programming guide changes ## How was this patch tested? manually Author: Felix Cheung Closes #20340 from felixcheung/rstreamdoc. --- .../structured-streaming-programming-guide.md | 74 ++++++++++++++++++- 1 file changed, 72 insertions(+), 2 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 2ef5d3168a87b..62589a62ac4c4 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1100,6 +1100,21 @@ streamingDf.join(staticDf, "type") # inner equi-join with a static DF streamingDf.join(staticDf, "type", "right_join") # right outer join with a static DF {% endhighlight %} +
+ +
+ +{% highlight r %} +staticDf <- read.df(...) +streamingDf <- read.stream(...) +joined <- merge(streamingDf, staticDf, sort = FALSE) # inner equi-join with a static DF +joined <- join( + staticDf, + streamingDf, + streamingDf$value == staticDf$value, + "right_outer") # right outer join with a static DF +{% endhighlight %} +
@@ -1227,6 +1242,30 @@ impressionsWithWatermark.join( {% endhighlight %} + +
+ +{% highlight r %} +impressions <- read.stream(...) +clicks <- read.stream(...) + +# Apply watermarks on event-time columns +impressionsWithWatermark <- withWatermark(impressions, "impressionTime", "2 hours") +clicksWithWatermark <- withWatermark(clicks, "clickTime", "3 hours") + +# Join with event-time constraints +joined <- join( + impressionsWithWatermark, + clicksWithWatermark, + expr( + paste( + "clickAdId = impressionAdId AND", + "clickTime >= impressionTime AND", + "clickTime <= impressionTime + interval 1 hour" +))) + +{% endhighlight %} +
@@ -1287,6 +1326,23 @@ impressionsWithWatermark.join( {% endhighlight %} + +
+ +{% highlight r %} +joined <- join( + impressionsWithWatermark, + clicksWithWatermark, + expr( + paste( + "clickAdId = impressionAdId AND", + "clickTime >= impressionTime AND", + "clickTime <= impressionTime + interval 1 hour"), + "left_outer" # can be "inner", "left_outer", "right_outer" +)) + +{% endhighlight %} +
@@ -1441,15 +1497,29 @@ streamingDf {% highlight python %} streamingDf = spark.readStream. ... -// Without watermark using guid column +# Without watermark using guid column streamingDf.dropDuplicates("guid") -// With watermark using guid and eventTime columns +# With watermark using guid and eventTime columns streamingDf \ .withWatermark("eventTime", "10 seconds") \ .dropDuplicates("guid", "eventTime") {% endhighlight %} + +
+ +{% highlight r %} +streamingDf <- read.stream(...) + +# Without watermark using guid column +streamingDf <- dropDuplicates(streamingDf, "guid") + +# With watermark using guid and eventTime columns +streamingDf <- withWatermark(streamingDf, "eventTime", "10 seconds") +streamingDf <- dropDuplicates(streamingDf, "guid", "eventTime") +{% endhighlight %} +
From 12faae295e42820b99a695ba49826051944244e1 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 22 Jan 2018 09:45:27 +0900 Subject: [PATCH 0162/2461] [SPARK-23169][INFRA][R] Run lintr on the changes of lint-r script and .lintr configuration ## What changes were proposed in this pull request? When running the `run-tests` script, seems we don't run lintr on the changes of `lint-r` script and `.lintr` configuration. ## How was this patch tested? Jenkins builds Author: hyukjinkwon Closes #20339 from HyukjinKwon/check-r-changed. --- dev/run-tests.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 7e6f7ff060351..fb270c4ee0508 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -578,7 +578,10 @@ def main(): pass if not changed_files or any(f.endswith(".py") for f in changed_files): run_python_style_checks() - if not changed_files or any(f.endswith(".R") for f in changed_files): + if not changed_files or any(f.endswith(".R") + or f.endswith("lint-r") + or f.endswith(".lintr") + for f in changed_files): run_sparkr_style_checks() # determine if docs were changed and if we're inside the amplab environment From 602c6d82d893a7f34b37d674642669048eb59b03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=99=93=E5=93=B2?= Date: Mon, 22 Jan 2018 10:43:12 +0900 Subject: [PATCH 0163/2461] [SPARK-20947][PYTHON] Fix encoding/decoding error in pipe action MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Pipe action convert objects into strings using a way that was affected by the default encoding setting of Python environment. This patch fixed the problem. The detailed description is added here: https://issues.apache.org/jira/browse/SPARK-20947 ## How was this patch tested? Run the following statement in pyspark-shell, and it will NOT raise exception if this patch is applied: ```python sc.parallelize([u'\u6d4b\u8bd5']).pipe('cat').collect() ``` Author: 王晓哲 Closes #18277 from chaoslawful/fix_pipe_encoding_error. --- python/pyspark/rdd.py | 2 +- python/pyspark/tests.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 340bc3a6b7470..1b3915548fb14 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -766,7 +766,7 @@ def func(iterator): def pipe_objs(out): for obj in iterator: - s = str(obj).rstrip('\n') + '\n' + s = unicode(obj).rstrip('\n') + '\n' out.write(s.encode('utf-8')) out.close() Thread(target=pipe_objs, args=[pipe.stdin]).start() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index da99872da2f0e..511585763cb01 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1239,6 +1239,13 @@ def test_pipe_functions(self): self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect) self.assertEqual([], rdd.pipe('grep 4').collect()) + def test_pipe_unicode(self): + # Regression test for SPARK-20947 + data = [u'\u6d4b\u8bd5', '1'] + rdd = self.sc.parallelize(data) + result = rdd.pipe('cat').collect() + self.assertEqual(data, result) + class ProfilerTests(PySparkTestCase): From 11daeb833222b1cd349fb1410307d64ab33981db Mon Sep 17 00:00:00 2001 From: Russell Spitzer Date: Mon, 22 Jan 2018 12:27:51 +0800 Subject: [PATCH 0164/2461] [SPARK-22976][CORE] Cluster mode driver dir removed while running ## What changes were proposed in this pull request? The clean up logic on the worker perviously determined the liveness of a particular applicaiton based on whether or not it had running executors. This would fail in the case that a directory was made for a driver running in cluster mode if that driver had no running executors on the same machine. To preserve driver directories we consider both executors and running drivers when checking directory liveness. ## How was this patch tested? Manually started up two node cluster with a single core on each node. Turned on worker directory cleanup and set the interval to 1 second and liveness to one second. Without the patch the driver directory is removed immediately after the app is launched. With the patch it is not ### Without Patch ``` INFO 2018-01-05 23:48:24,693 Logging.scala:54 - Asked to launch driver driver-20180105234824-0000 INFO 2018-01-05 23:48:25,293 Logging.scala:54 - Changing view acls to: cassandra INFO 2018-01-05 23:48:25,293 Logging.scala:54 - Changing modify acls to: cassandra INFO 2018-01-05 23:48:25,294 Logging.scala:54 - Changing view acls groups to: INFO 2018-01-05 23:48:25,294 Logging.scala:54 - Changing modify acls groups to: INFO 2018-01-05 23:48:25,294 Logging.scala:54 - SecurityManager: authentication disabled; ui acls disabled; users with view permissions: Set(cassandra); groups with view permissions: Set(); users with modify permissions: Set(cassandra); groups with modify permissions: Set() INFO 2018-01-05 23:48:25,330 Logging.scala:54 - Copying user jar file:/home/automaton/writeRead-0.1.jar to /var/lib/spark/worker/driver-20180105234824-0000/writeRead-0.1.jar INFO 2018-01-05 23:48:25,332 Logging.scala:54 - Copying /home/automaton/writeRead-0.1.jar to /var/lib/spark/worker/driver-20180105234824-0000/writeRead-0.1.jar INFO 2018-01-05 23:48:25,361 Logging.scala:54 - Launch Command: "/usr/lib/jvm/jdk1.8.0_40//bin/java" .... **** INFO 2018-01-05 23:48:56,577 Logging.scala:54 - Removing directory: /var/lib/spark/worker/driver-20180105234824-0000 ### << Cleaned up **** -- One minute passes while app runs (app has 1 minute sleep built in) -- WARN 2018-01-05 23:49:58,080 ShuffleSecretManager.java:73 - Attempted to unregister application app-20180105234831-0000 when it is not registered INFO 2018-01-05 23:49:58,081 ExternalShuffleBlockResolver.java:163 - Application app-20180105234831-0000 removed, cleanupLocalDirs = false INFO 2018-01-05 23:49:58,081 ExternalShuffleBlockResolver.java:163 - Application app-20180105234831-0000 removed, cleanupLocalDirs = false INFO 2018-01-05 23:49:58,082 ExternalShuffleBlockResolver.java:163 - Application app-20180105234831-0000 removed, cleanupLocalDirs = true INFO 2018-01-05 23:50:00,999 Logging.scala:54 - Driver driver-20180105234824-0000 exited successfully ``` With Patch ``` INFO 2018-01-08 23:19:54,603 Logging.scala:54 - Asked to launch driver driver-20180108231954-0002 INFO 2018-01-08 23:19:54,975 Logging.scala:54 - Changing view acls to: automaton INFO 2018-01-08 23:19:54,976 Logging.scala:54 - Changing modify acls to: automaton INFO 2018-01-08 23:19:54,976 Logging.scala:54 - Changing view acls groups to: INFO 2018-01-08 23:19:54,976 Logging.scala:54 - Changing modify acls groups to: INFO 2018-01-08 23:19:54,976 Logging.scala:54 - SecurityManager: authentication disabled; ui acls disabled; users with view permissions: Set(automaton); groups with view permissions: Set(); users with modify permissions: Set(automaton); groups with modify permissions: Set() INFO 2018-01-08 23:19:55,029 Logging.scala:54 - Copying user jar file:/home/automaton/writeRead-0.1.jar to /var/lib/spark/worker/driver-20180108231954-0002/writeRead-0.1.jar INFO 2018-01-08 23:19:55,031 Logging.scala:54 - Copying /home/automaton/writeRead-0.1.jar to /var/lib/spark/worker/driver-20180108231954-0002/writeRead-0.1.jar INFO 2018-01-08 23:19:55,038 Logging.scala:54 - Launch Command: ...... INFO 2018-01-08 23:21:28,674 ShuffleSecretManager.java:69 - Unregistered shuffle secret for application app-20180108232000-0000 INFO 2018-01-08 23:21:28,675 ExternalShuffleBlockResolver.java:163 - Application app-20180108232000-0000 removed, cleanupLocalDirs = false INFO 2018-01-08 23:21:28,675 ExternalShuffleBlockResolver.java:163 - Application app-20180108232000-0000 removed, cleanupLocalDirs = false INFO 2018-01-08 23:21:28,681 ExternalShuffleBlockResolver.java:163 - Application app-20180108232000-0000 removed, cleanupLocalDirs = true INFO 2018-01-08 23:21:31,703 Logging.scala:54 - Driver driver-20180108231954-0002 exited successfully ***** INFO 2018-01-08 23:21:32,346 Logging.scala:54 - Removing directory: /var/lib/spark/worker/driver-20180108231954-0002 ### < Happening AFTER the Run completes rather than during it ***** ``` Author: Russell Spitzer Closes #20298 from RussellSpitzer/SPARK-22976-master. --- core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 3962d422f81d3..563b84934f264 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -441,7 +441,7 @@ private[deploy] class Worker( // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker // rpcEndpoint. // Copy ids so that it can be used in the cleanup thread. - val appIds = executors.values.map(_.appId).toSet + val appIds = (executors.values.map(_.appId) ++ drivers.values.map(_.driverId)).toSet val cleanupFuture = concurrent.Future { val appDirs = workDir.listFiles() if (appDirs == null) { From 8142a3b883a5fe6fc620a2c5b25b6bde4fda32e5 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 22 Jan 2018 15:18:57 +0900 Subject: [PATCH 0165/2461] [MINOR][SQL] Fix wrong comments on org.apache.spark.sql.parquet.row.attributes ## What changes were proposed in this pull request? This PR fixes the wrong comment on `org.apache.spark.sql.parquet.row.attributes` which is useful for UDTs like Vector/Matrix. Please see [SPARK-22320](https://issues.apache.org/jira/browse/SPARK-22320) for the usage. Originally, [SPARK-19411](https://github.com/apache/spark/commit/bf493686eb17006727b3ec81849b22f3df68fdef#diff-ee26d4c4be21e92e92a02e9f16dbc285L314) left this behind during removing optional column metadatas. In the same PR, the same comment was removed at line 310-311. ## How was this patch tested? N/A (This is about comments). Author: Dongjoon Hyun Closes #20346 from dongjoon-hyun/minor_comment_parquet. --- .../sql/execution/datasources/parquet/ParquetFileFormat.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 45bedf70f975c..f53a97ba45a26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -108,8 +108,7 @@ class ParquetFileFormat ParquetOutputFormat.setWriteSupportClass(job, classOf[ParquetWriteSupport]) - // We want to clear this temporary metadata from saving into Parquet file. - // This metadata is only useful for detecting optional columns when pushdowning filters. + // This metadata is useful for keeping UDTs like Vector/Matrix. ParquetWriteSupport.setSchema(dataSchema, conf) // Sets flags for `ParquetWriteSupport`, which converts Catalyst schema to Parquet From ec228976156619ed8df21a85bceb5fd3bdeb5855 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 22 Jan 2018 14:49:12 +0800 Subject: [PATCH 0166/2461] [SPARK-23020][CORE] Fix races in launcher code, test. The race in the code is because the handle might update its state to the wrong state if the connection handling thread is still processing incoming data; so the handle needs to wait for the connection to finish up before checking the final state. The race in the test is because when waiting for a handle to reach a final state, the waitFor() method needs to wait until all handle state is updated (which also includes waiting for the connection thread above to finish). Otherwise, waitFor() may return too early, which would cause a bunch of different races (like the listener not being yet notified of the state change, or being in the middle of being notified, or the handle not being properly disposed and causing postChecks() to assert). On top of that I found, by code inspection, a couple of potential races that could make a handle end up in the wrong state when being killed. The original version of this fix introduced the flipped version of the first race described above; the connection closing might override the handle state before the handle might have a chance to do cleanup. The fix there is to only dispose of the handle from the connection when there is an error, and let the handle dispose itself in the normal case. The fix also caused a bug in YarnClusterSuite to be surfaced; the code was checking for a file in the classpath that was not expected to be there in client mode. Because of the above issues, the error was not propagating correctly and the (buggy) test was incorrectly passing. Tested by running the existing unit tests a lot (and not seeing the errors I was seeing before). Author: Marcelo Vanzin Closes #20297 from vanzin/SPARK-23020. --- .../spark/launcher/SparkLauncherSuite.java | 53 +++++++++++-------- .../spark/launcher/AbstractAppHandle.java | 22 ++++++-- .../spark/launcher/ChildProcAppHandle.java | 18 ++++--- .../spark/launcher/InProcessAppHandle.java | 17 +++--- .../spark/launcher/LauncherConnection.java | 18 +++---- .../apache/spark/launcher/LauncherServer.java | 49 +++++++++++++---- .../org/apache/spark/launcher/BaseSuite.java | 42 ++++++++++++--- .../spark/launcher/LauncherServerSuite.java | 23 ++++---- .../spark/deploy/yarn/YarnClusterSuite.scala | 4 +- 9 files changed, 165 insertions(+), 81 deletions(-) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index dffa609f1cbdf..a042375c6ae91 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.launcher; +import java.time.Duration; import java.util.Arrays; import java.util.ArrayList; import java.util.HashMap; @@ -25,13 +26,13 @@ import java.util.Properties; import java.util.concurrent.TimeUnit; -import org.junit.Ignore; import org.junit.Test; import static org.junit.Assert.*; import static org.junit.Assume.*; import static org.mockito.Mockito.*; import org.apache.spark.SparkContext; +import org.apache.spark.SparkContext$; import org.apache.spark.internal.config.package$; import org.apache.spark.util.Utils; @@ -121,8 +122,7 @@ public void testChildProcLauncher() throws Exception { assertEquals(0, app.waitFor()); } - // TODO: [SPARK-23020] Re-enable this - @Ignore + @Test public void testInProcessLauncher() throws Exception { // Because this test runs SparkLauncher in process and in client mode, it pollutes the system // properties, and that can cause test failures down the test pipeline. So restore the original @@ -139,7 +139,9 @@ public void testInProcessLauncher() throws Exception { // Here DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet. // Wait for a reasonable amount of time to avoid creating two active SparkContext in JVM. // See SPARK-23019 and SparkContext.stop() for details. - TimeUnit.MILLISECONDS.sleep(500); + eventually(Duration.ofSeconds(5), Duration.ofMillis(10), () -> { + assertTrue("SparkContext is still alive.", SparkContext$.MODULE$.getActive().isEmpty()); + }); } } @@ -148,26 +150,35 @@ private void inProcessLauncherTestImpl() throws Exception { SparkAppHandle.Listener listener = mock(SparkAppHandle.Listener.class); doAnswer(invocation -> { SparkAppHandle h = (SparkAppHandle) invocation.getArguments()[0]; - transitions.add(h.getState()); + synchronized (transitions) { + transitions.add(h.getState()); + } return null; }).when(listener).stateChanged(any(SparkAppHandle.class)); - SparkAppHandle handle = new InProcessLauncher() - .setMaster("local") - .setAppResource(SparkLauncher.NO_RESOURCE) - .setMainClass(InProcessTestApp.class.getName()) - .addAppArgs("hello") - .startApplication(listener); - - waitFor(handle); - assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); - - // Matches the behavior of LocalSchedulerBackend. - List expected = Arrays.asList( - SparkAppHandle.State.CONNECTED, - SparkAppHandle.State.RUNNING, - SparkAppHandle.State.FINISHED); - assertEquals(expected, transitions); + SparkAppHandle handle = null; + try { + handle = new InProcessLauncher() + .setMaster("local") + .setAppResource(SparkLauncher.NO_RESOURCE) + .setMainClass(InProcessTestApp.class.getName()) + .addAppArgs("hello") + .startApplication(listener); + + waitFor(handle); + assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); + + // Matches the behavior of LocalSchedulerBackend. + List expected = Arrays.asList( + SparkAppHandle.State.CONNECTED, + SparkAppHandle.State.RUNNING, + SparkAppHandle.State.FINISHED); + assertEquals(expected, transitions); + } finally { + if (handle != null) { + handle.kill(); + } + } } public static class SparkLauncherTestApp { diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java index df1e7316861d4..daf0972f824dd 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java @@ -33,7 +33,7 @@ abstract class AbstractAppHandle implements SparkAppHandle { private List listeners; private State state; private String appId; - private boolean disposed; + private volatile boolean disposed; protected AbstractAppHandle(LauncherServer server) { this.server = server; @@ -70,8 +70,7 @@ public void stop() { @Override public synchronized void disconnect() { - if (!disposed) { - disposed = true; + if (!isDisposed()) { if (connection != null) { try { connection.close(); @@ -79,7 +78,7 @@ public synchronized void disconnect() { // no-op. } } - server.unregister(this); + dispose(); } } @@ -95,6 +94,21 @@ boolean isDisposed() { return disposed; } + /** + * Mark the handle as disposed, and set it as LOST in case the current state is not final. + */ + synchronized void dispose() { + if (!isDisposed()) { + // Unregister first to make sure that the connection with the app has been really + // terminated. + server.unregister(this); + if (!getState().isFinal()) { + setState(State.LOST); + } + this.disposed = true; + } + } + void setState(State s) { setState(s, false); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index 8b3f427b7750e..2b99461652e1f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -48,14 +48,16 @@ public synchronized void disconnect() { @Override public synchronized void kill() { - disconnect(); - if (childProc != null) { - if (childProc.isAlive()) { - childProc.destroyForcibly(); + if (!isDisposed()) { + setState(State.KILLED); + disconnect(); + if (childProc != null) { + if (childProc.isAlive()) { + childProc.destroyForcibly(); + } + childProc = null; } - childProc = null; } - setState(State.KILLED); } void setChildProc(Process childProc, String loggerName, InputStream logStream) { @@ -94,8 +96,6 @@ void monitorChild() { return; } - disconnect(); - int ec; try { ec = proc.exitValue(); @@ -118,6 +118,8 @@ void monitorChild() { if (newState != null) { setState(newState, true); } + + disconnect(); } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index acd64c962604f..f04263cb74a58 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -39,15 +39,16 @@ class InProcessAppHandle extends AbstractAppHandle { @Override public synchronized void kill() { - LOG.warning("kill() may leave the underlying app running in in-process mode."); - disconnect(); - - // Interrupt the thread. This is not guaranteed to kill the app, though. - if (app != null) { - app.interrupt(); + if (!isDisposed()) { + LOG.warning("kill() may leave the underlying app running in in-process mode."); + setState(State.KILLED); + disconnect(); + + // Interrupt the thread. This is not guaranteed to kill the app, though. + if (app != null) { + app.interrupt(); + } } - - setState(State.KILLED); } synchronized void start(String appName, Method main, String[] args) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java index b4a8719e26053..e8ab3f5e369ab 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java @@ -53,7 +53,7 @@ abstract class LauncherConnection implements Closeable, Runnable { public void run() { try { FilteredObjectInputStream in = new FilteredObjectInputStream(socket.getInputStream()); - while (!closed) { + while (isOpen()) { Message msg = (Message) in.readObject(); handle(msg); } @@ -95,15 +95,15 @@ protected synchronized void send(Message msg) throws IOException { } @Override - public void close() throws IOException { - if (!closed) { - synchronized (this) { - if (!closed) { - closed = true; - socket.close(); - } - } + public synchronized void close() throws IOException { + if (isOpen()) { + closed = true; + socket.close(); } } + boolean isOpen() { + return !closed; + } + } diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index b8999a1d7a4f4..8091885c4f562 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -217,6 +217,33 @@ void unregister(AbstractAppHandle handle) { break; } } + + // If there is a live connection for this handle, we need to wait for it to finish before + // returning, otherwise there might be a race between the connection thread processing + // buffered data and the handle cleaning up after itself, leading to potentially the wrong + // state being reported for the handle. + ServerConnection conn = null; + synchronized (clients) { + for (ServerConnection c : clients) { + if (c.handle == handle) { + conn = c; + break; + } + } + } + + if (conn != null) { + synchronized (conn) { + if (conn.isOpen()) { + try { + conn.wait(); + } catch (InterruptedException ie) { + // Ignore. + } + } + } + } + unref(); } @@ -288,7 +315,7 @@ private String createSecret() { private class ServerConnection extends LauncherConnection { private TimerTask timeout; - private AbstractAppHandle handle; + volatile AbstractAppHandle handle; ServerConnection(Socket socket, TimerTask timeout) throws IOException { super(socket); @@ -313,7 +340,7 @@ protected void handle(Message msg) throws IOException { } else { if (handle == null) { throw new IllegalArgumentException("Expected hello, got: " + - msg != null ? msg.getClass().getName() : null); + msg != null ? msg.getClass().getName() : null); } if (msg instanceof SetAppId) { SetAppId set = (SetAppId) msg; @@ -331,6 +358,9 @@ protected void handle(Message msg) throws IOException { timeout.cancel(); } close(); + if (handle != null) { + handle.dispose(); + } } finally { timeoutTimer.purge(); } @@ -338,16 +368,17 @@ protected void handle(Message msg) throws IOException { @Override public void close() throws IOException { + if (!isOpen()) { + return; + } + synchronized (clients) { clients.remove(this); } - super.close(); - if (handle != null) { - if (!handle.getState().isFinal()) { - LOG.log(Level.WARNING, "Lost connection to spark application."); - handle.setState(SparkAppHandle.State.LOST); - } - handle.disconnect(); + + synchronized (this) { + super.close(); + notifyAll(); } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java index 3e1a90eae98d4..3722a59d9438e 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.launcher; +import java.time.Duration; import java.util.concurrent.TimeUnit; import org.junit.After; @@ -47,19 +48,46 @@ public void postChecks() { assertNull(server); } - protected void waitFor(SparkAppHandle handle) throws Exception { - long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + protected void waitFor(final SparkAppHandle handle) throws Exception { try { - while (!handle.getState().isFinal()) { - assertTrue("Timed out waiting for handle to transition to final state.", - System.nanoTime() < deadline); - TimeUnit.MILLISECONDS.sleep(10); - } + eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { + assertTrue("Handle is not in final state.", handle.getState().isFinal()); + }); } finally { if (!handle.getState().isFinal()) { handle.kill(); } } + + // Wait until the handle has been marked as disposed, to make sure all cleanup tasks + // have been performed. + AbstractAppHandle ahandle = (AbstractAppHandle) handle; + eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { + assertTrue("Handle is still not marked as disposed.", ahandle.isDisposed()); + }); + } + + /** + * Call a closure that performs a check every "period" until it succeeds, or the timeout + * elapses. + */ + protected void eventually(Duration timeout, Duration period, Runnable check) throws Exception { + assertTrue("Timeout needs to be larger than period.", timeout.compareTo(period) > 0); + long deadline = System.nanoTime() + timeout.toNanos(); + int count = 0; + while (true) { + try { + count++; + check.run(); + return; + } catch (Throwable t) { + if (System.nanoTime() >= deadline) { + String msg = String.format("Failed check after %d tries: %s.", count, t.getMessage()); + throw new IllegalStateException(msg, t); + } + Thread.sleep(period.toMillis()); + } + } } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 7e2b09ce25c9b..024efac33c391 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -23,12 +23,14 @@ import java.net.InetAddress; import java.net.Socket; import java.net.SocketException; +import java.time.Duration; import java.util.Arrays; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import static org.junit.Assert.*; @@ -143,7 +145,8 @@ public void infoChanged(SparkAppHandle handle) { assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS)); // Make sure the server matched the client to the handle. assertNotNull(handle.getConnection()); - close(client); + client.close(); + handle.dispose(); assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS)); assertEquals(SparkAppHandle.State.LOST, handle.getState()); } finally { @@ -197,28 +200,20 @@ private void close(Closeable c) { * server-side close immediately. */ private void waitForError(TestClient client, String secret) throws Exception { - boolean helloSent = false; - int maxTries = 10; - for (int i = 0; i < maxTries; i++) { + final AtomicBoolean helloSent = new AtomicBoolean(); + eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> { try { - if (!helloSent) { + if (!helloSent.get()) { client.send(new Hello(secret, "1.4.0")); - helloSent = true; + helloSent.set(true); } else { client.send(new SetAppId("appId")); } fail("Expected error but message went through."); } catch (IllegalStateException | IOException e) { // Expected. - break; - } catch (AssertionError e) { - if (i < maxTries - 1) { - Thread.sleep(100); - } else { - throw new AssertionError("Test failed after " + maxTries + " attempts.", e); - } } - } + }); } private static class TestClient extends LauncherConnection { diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 061f653b97b7a..e9dcfaf6ba4f0 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -381,7 +381,9 @@ private object YarnClusterDriver extends Logging with Matchers { // Verify that the config archive is correctly placed in the classpath of all containers. val confFile = "/" + Client.SPARK_CONF_FILE - assert(getClass().getResource(confFile) != null) + if (conf.getOption(SparkLauncher.DEPLOY_MODE) == Some("cluster")) { + assert(getClass().getResource(confFile) != null) + } val configFromExecutors = sc.parallelize(1 to 4, 4) .map { _ => Option(getClass().getResource(confFile)).map(_.toString).orNull } .collect() From 60175e959f275d2961798fbc5a9150dac9de51ff Mon Sep 17 00:00:00 2001 From: Arseniy Tashoyan Date: Mon, 22 Jan 2018 20:17:05 +0800 Subject: [PATCH 0167/2461] [MINOR][DOC] Fix the path to the examples jar ## What changes were proposed in this pull request? The example jar file is now in ./examples/jars directory of Spark distribution. Author: Arseniy Tashoyan Closes #20349 from tashoyan/patch-1. --- docs/running-on-yarn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index e4f5a0c659e66..c010af35f8d2e 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -35,7 +35,7 @@ For example: --executor-memory 2g \ --executor-cores 1 \ --queue thequeue \ - lib/spark-examples*.jar \ + examples/jars/spark-examples*.jar \ 10 The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. From 73281161fc7fddd645c712986ec376ac2b1bd213 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 22 Jan 2018 04:27:59 -0800 Subject: [PATCH 0168/2461] [SPARK-23122][PYSPARK][FOLLOW-UP] Update the docs for UDF Registration ## What changes were proposed in this pull request? This PR is to update the docs for UDF registration ## How was this patch tested? N/A Author: gatorsmile Closes #20348 from gatorsmile/testUpdateDoc. --- python/pyspark/sql/udf.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index c77f19f89a442..134badb8485f5 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -199,8 +199,8 @@ def __init__(self, sparkSession): @ignore_unicode_prefix @since("1.3.1") def register(self, name, f, returnType=None): - """Registers a Python function (including lambda function) or a user-defined function - in SQL statements. + """Register a Python function (including lambda function) or a user-defined function + as a SQL function. :param name: name of the user-defined function in SQL statements. :param f: a Python function, or a user-defined function. The user-defined function can @@ -210,6 +210,10 @@ def register(self, name, f, returnType=None): be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. :return: a user-defined function. + To register a nondeterministic Python function, users need to first build + a nondeterministic user-defined function for the Python function and then register it + as a SQL function. + `returnType` can be optionally specified when `f` is a Python function but not when `f` is a user-defined function. Please see below. @@ -297,7 +301,7 @@ def register(self, name, f, returnType=None): @ignore_unicode_prefix @since(2.3) def registerJavaFunction(self, name, javaClassName, returnType=None): - """Register a Java user-defined function so it can be used in SQL statements. + """Register a Java user-defined function as a SQL function. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not specified we would infer it via reflection. @@ -334,7 +338,7 @@ def registerJavaFunction(self, name, javaClassName, returnType=None): @ignore_unicode_prefix @since(2.3) def registerJavaUDAF(self, name, javaClassName): - """Register a Java user-defined aggregate function so it can be used in SQL statements. + """Register a Java user-defined aggregate function as a SQL function. :param name: name of the user-defined aggregate function :param javaClassName: fully qualified name of java class From 78801881c405de47f7e53eea3e0420dd69593dbd Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 22 Jan 2018 04:31:24 -0800 Subject: [PATCH 0169/2461] [SPARK-23170][SQL] Dump the statistics of effective runs of analyzer and optimizer rules ## What changes were proposed in this pull request? Dump the statistics of effective runs of analyzer and optimizer rules. ## How was this patch tested? Do a manual run of TPCDSQuerySuite ``` === Metrics of Analyzer/Optimizer Rules === Total number of runs: 175899 Total time: 25.486559948 seconds Rule Effective Time / Total Time Effective Runs / Total Runs org.apache.spark.sql.catalyst.optimizer.ColumnPruning 1603280450 / 2868461549 761 / 1877 org.apache.spark.sql.catalyst.analysis.Analyzer$CTESubstitution 2045860009 / 2056602674 37 / 788 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggregateFunctions 440719059 / 1693110949 38 / 1982 org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries 1429834919 / 1446016225 39 / 285 org.apache.spark.sql.catalyst.optimizer.PruneFilters 33273083 / 1389586938 3 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveReferences 821183615 / 1266668754 616 / 1982 org.apache.spark.sql.catalyst.optimizer.ReorderJoin 775837028 / 866238225 132 / 1592 org.apache.spark.sql.catalyst.analysis.DecimalPrecision 550683593 / 748854507 211 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubquery 513075345 / 634370596 49 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$FixNullability 33475731 / 606406532 12 / 742 org.apache.spark.sql.catalyst.analysis.TypeCoercion$ImplicitTypeCasts 193144298 / 545403925 86 / 1982 org.apache.spark.sql.catalyst.optimizer.BooleanSimplification 18651497 / 495725004 7 / 1592 org.apache.spark.sql.catalyst.optimizer.PushPredicateThroughJoin 369257217 / 489934378 709 / 1592 org.apache.spark.sql.catalyst.optimizer.RemoveRedundantAliases 3707000 / 468291609 9 / 1592 org.apache.spark.sql.catalyst.optimizer.InferFiltersFromConstraints 410155900 / 435254175 192 / 285 org.apache.spark.sql.execution.datasources.FindDataSourceTable 348885539 / 371855866 233 / 1982 org.apache.spark.sql.catalyst.optimizer.NullPropagation 11307645 / 307531225 26 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions 120324545 / 304948785 294 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$FunctionArgumentConversion 92323199 / 286695007 38 / 1982 org.apache.spark.sql.catalyst.optimizer.PushDownPredicate 230084193 / 265845972 785 / 1592 org.apache.spark.sql.catalyst.analysis.TypeCoercion$PromoteStrings 45938401 / 265144009 40 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$InConversion 14888776 / 261499450 1 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$CaseWhenCoercion 113796384 / 244913861 29 / 1982 org.apache.spark.sql.catalyst.optimizer.ConstantFolding 65008069 / 236548480 126 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$ExtractGenerator 0 / 226338929 0 / 1982 org.apache.spark.sql.catalyst.analysis.ResolveTimeZone 98134906 / 221323770 417 / 1982 org.apache.spark.sql.catalyst.optimizer.ReorderAssociativeOperator 0 / 208421703 0 / 1592 org.apache.spark.sql.catalyst.optimizer.OptimizeIn 8762534 / 199351958 16 / 1592 org.apache.spark.sql.catalyst.analysis.TypeCoercion$DateTimeOperations 11980016 / 190779046 27 / 1982 org.apache.spark.sql.catalyst.optimizer.SimplifyBinaryComparison 0 / 188887385 0 / 1592 org.apache.spark.sql.catalyst.optimizer.SimplifyConditionals 0 / 186812106 0 / 1592 org.apache.spark.sql.catalyst.optimizer.SimplifyCaseConversionExpressions 0 / 183885230 0 / 1592 org.apache.spark.sql.catalyst.optimizer.SimplifyCasts 17128295 / 182901910 69 / 1592 org.apache.spark.sql.catalyst.analysis.TypeCoercion$Division 14579110 / 180309340 8 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$BooleanEquality 0 / 176740516 0 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$IfCoercion 0 / 170781986 0 / 1982 org.apache.spark.sql.catalyst.optimizer.LikeSimplification 771605 / 164136736 1 / 1592 org.apache.spark.sql.catalyst.optimizer.RemoveDispensableExpressions 0 / 155958962 0 / 1592 org.apache.spark.sql.catalyst.analysis.ResolveCreateNamedStruct 0 / 151222943 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowOrder 7534632 / 146596355 14 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$EltCoercion 0 / 144488654 0 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$ConcatCoercion 0 / 142403338 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowFrame 12067635 / 141500665 21 / 1982 org.apache.spark.sql.catalyst.analysis.TimeWindowing 0 / 140431958 0 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$WindowFrameCoercion 0 / 125471960 0 / 1982 org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin 14226972 / 124922019 11 / 1592 org.apache.spark.sql.catalyst.analysis.TypeCoercion$StackCoercion 0 / 123613887 0 / 1982 org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery 8491071 / 121179056 7 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGroupingAnalytics 55526073 / 120290529 11 / 1982 org.apache.spark.sql.catalyst.optimizer.ConstantPropagation 0 / 113886790 0 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveDeserializer 52383759 / 107160222 148 / 1982 org.apache.spark.sql.catalyst.analysis.CleanupAliases 52543524 / 102091518 344 / 1086 org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject 40682895 / 94403652 342 / 1877 org.apache.spark.sql.catalyst.analysis.Analyzer$ExtractWindowExpressions 38473816 / 89740578 23 / 1982 org.apache.spark.sql.catalyst.optimizer.CollapseProject 46806090 / 83315506 281 / 1877 org.apache.spark.sql.catalyst.optimizer.FoldablePropagation 0 / 78750087 0 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAliases 13742765 / 77227258 47 / 1982 org.apache.spark.sql.catalyst.optimizer.CombineFilters 53386729 / 76960344 448 / 1592 org.apache.spark.sql.execution.datasources.DataSourceAnalysis 68034341 / 75724186 24 / 742 org.apache.spark.sql.catalyst.analysis.Analyzer$LookupFunctions 0 / 71151084 0 / 750 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveMissingReferences 12139848 / 67599140 8 / 1982 org.apache.spark.sql.catalyst.optimizer.PullupCorrelatedPredicates 45017938 / 65968777 23 / 285 org.apache.spark.sql.execution.datasources.v2.PushDownOperatorsToDataSource 0 / 60937767 0 / 285 org.apache.spark.sql.catalyst.optimizer.CollapseRepartition 0 / 59897237 0 / 1592 org.apache.spark.sql.catalyst.optimizer.PushProjectionThroughUnion 8547262 / 53941370 10 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$HandleNullInputsForUDF 0 / 52735976 0 / 742 org.apache.spark.sql.catalyst.analysis.TypeCoercion$WidenSetOperationTypes 9797713 / 52401665 9 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$PullOutNondeterministic 0 / 51741500 0 / 742 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRelations 28614911 / 51061186 233 / 1990 org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions 0 / 50621510 0 / 285 org.apache.spark.sql.catalyst.optimizer.CombineUnions 2777800 / 50262112 17 / 1877 org.apache.spark.sql.catalyst.analysis.Analyzer$GlobalAggregates 1640641 / 49633909 46 / 1982 org.apache.spark.sql.catalyst.optimizer.DecimalAggregates 20198374 / 48488419 100 / 385 org.apache.spark.sql.catalyst.optimizer.LimitPushDown 0 / 45052523 0 / 1592 org.apache.spark.sql.catalyst.optimizer.CombineLimits 0 / 44719443 0 / 1592 org.apache.spark.sql.catalyst.optimizer.EliminateSorts 0 / 44216930 0 / 1592 org.apache.spark.sql.catalyst.optimizer.RewritePredicateSubquery 36235699 / 44165786 148 / 285 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveNewInstance 0 / 42750307 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUpCast 0 / 41811748 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveOrdinalInOrderByAndGroupBy 3819476 / 41776562 4 / 1982 org.apache.spark.sql.catalyst.optimizer.ComputeCurrentTime 0 / 40527808 0 / 285 org.apache.spark.sql.catalyst.optimizer.CollapseWindow 0 / 36832538 0 / 1592 org.apache.spark.sql.catalyst.optimizer.EliminateSerialization 0 / 36120667 0 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggAliasInGroupBy 0 / 32435826 0 / 1982 org.apache.spark.sql.execution.datasources.PreprocessTableCreation 0 / 32145218 0 / 742 org.apache.spark.sql.execution.datasources.ResolveSQLOnFile 0 / 30295614 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolvePivot 0 / 30111655 0 / 1982 org.apache.spark.sql.catalyst.expressions.codegen.package$ExpressionCanonicalizer$CleanExpressions 59930 / 28038201 26 / 8280 org.apache.spark.sql.catalyst.analysis.ResolveInlineTables 0 / 27808108 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubqueryColumnAliases 0 / 27066690 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGenerate 0 / 26660210 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveNaturalAndUsingJoin 0 / 25255184 0 / 1982 org.apache.spark.sql.catalyst.analysis.ResolveTableValuedFunctions 0 / 24663088 0 / 1990 org.apache.spark.sql.catalyst.analysis.SubstituteUnresolvedOrdinals 9709079 / 24450670 4 / 788 org.apache.spark.sql.catalyst.analysis.ResolveHints$ResolveBroadcastHints 0 / 23776535 0 / 750 org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions 0 / 22697895 0 / 285 org.apache.spark.sql.catalyst.optimizer.CheckCartesianProducts 0 / 22523798 0 / 285 org.apache.spark.sql.catalyst.optimizer.ReplaceDistinctWithAggregate 988593 / 21535410 15 / 300 org.apache.spark.sql.catalyst.optimizer.EliminateMapObjects 0 / 20269996 0 / 285 org.apache.spark.sql.catalyst.optimizer.RewriteDistinctAggregates 0 / 19388592 0 / 285 org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases 17675532 / 18971185 215 / 285 org.apache.spark.sql.catalyst.optimizer.GetCurrentDatabase 0 / 18271152 0 / 285 org.apache.spark.sql.catalyst.optimizer.PropagateEmptyRelation 2077097 / 17190855 3 / 288 org.apache.spark.sql.catalyst.analysis.EliminateBarriers 0 / 16736359 0 / 1086 org.apache.spark.sql.execution.OptimizeMetadataOnlyQuery 0 / 16669341 0 / 285 org.apache.spark.sql.catalyst.analysis.UpdateOuterReferences 0 / 14470235 0 / 742 org.apache.spark.sql.catalyst.optimizer.ReplaceExceptWithAntiJoin 6715625 / 12190561 1 / 300 org.apache.spark.sql.catalyst.optimizer.ReplaceIntersectWithSemiJoin 3451793 / 11431432 7 / 300 org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate 0 / 10810568 0 / 285 org.apache.spark.sql.catalyst.optimizer.RemoveRepetitionFromGroupExpressions 344198 / 10475276 1 / 286 org.apache.spark.sql.catalyst.analysis.Analyzer$WindowsSubstitution 0 / 10386630 0 / 788 org.apache.spark.sql.catalyst.analysis.EliminateUnions 0 / 10096526 0 / 788 org.apache.spark.sql.catalyst.analysis.AliasViewChild 0 / 9991706 0 / 742 org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation 0 / 9649334 0 / 288 org.apache.spark.sql.catalyst.analysis.ResolveHints$RemoveAllHints 0 / 8739109 0 / 750 org.apache.spark.sql.execution.datasources.PreprocessTableInsertion 0 / 8420889 0 / 742 org.apache.spark.sql.catalyst.analysis.EliminateView 0 / 8319134 0 / 285 org.apache.spark.sql.catalyst.optimizer.RemoveLiteralFromGroupExpressions 0 / 7392627 0 / 286 org.apache.spark.sql.catalyst.optimizer.ReplaceExceptWithFilter 0 / 7170516 0 / 300 org.apache.spark.sql.catalyst.optimizer.SimplifyCreateArrayOps 0 / 7109643 0 / 1592 org.apache.spark.sql.catalyst.optimizer.SimplifyCreateStructOps 0 / 6837590 0 / 1592 org.apache.spark.sql.catalyst.optimizer.SimplifyCreateMapOps 0 / 6617848 0 / 1592 org.apache.spark.sql.catalyst.optimizer.CombineConcats 0 / 5768406 0 / 1592 org.apache.spark.sql.catalyst.optimizer.ReplaceDeduplicateWithAggregate 0 / 5349831 0 / 285 org.apache.spark.sql.catalyst.optimizer.CombineTypedFilters 0 / 5186642 0 / 285 org.apache.spark.sql.catalyst.optimizer.EliminateDistinct 0 / 2427686 0 / 285 org.apache.spark.sql.catalyst.optimizer.CostBasedJoinReorder 0 / 2420436 0 / 285 ``` Author: gatorsmile Closes #20342 from gatorsmile/reportExecution. --- .../rules/QueryExecutionMetering.scala | 91 +++++++++++++++++++ .../sql/catalyst/rules/RuleExecutor.scala | 32 +++---- .../apache/spark/sql/BenchmarkQueryTest.scala | 2 +- .../apache/spark/sql/SQLQueryTestSuite.scala | 2 +- .../execution/HiveCompatibilitySuite.scala | 2 +- 5 files changed, 109 insertions(+), 20 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala new file mode 100644 index 0000000000000..62f7541150a6e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.rules + +import scala.collection.JavaConverters._ + +import com.google.common.util.concurrent.AtomicLongMap + +case class QueryExecutionMetering() { + private val timeMap = AtomicLongMap.create[String]() + private val numRunsMap = AtomicLongMap.create[String]() + private val numEffectiveRunsMap = AtomicLongMap.create[String]() + private val timeEffectiveRunsMap = AtomicLongMap.create[String]() + + /** Resets statistics about time spent running specific rules */ + def resetMetrics(): Unit = { + timeMap.clear() + numRunsMap.clear() + numEffectiveRunsMap.clear() + timeEffectiveRunsMap.clear() + } + + def totalTime: Long = { + timeMap.sum() + } + + def totalNumRuns: Long = { + numRunsMap.sum() + } + + def incExecutionTimeBy(ruleName: String, delta: Long): Unit = { + timeMap.addAndGet(ruleName, delta) + } + + def incTimeEffectiveExecutionBy(ruleName: String, delta: Long): Unit = { + timeEffectiveRunsMap.addAndGet(ruleName, delta) + } + + def incNumEffectiveExecution(ruleName: String): Unit = { + numEffectiveRunsMap.incrementAndGet(ruleName) + } + + def incNumExecution(ruleName: String): Unit = { + numRunsMap.incrementAndGet(ruleName) + } + + /** Dump statistics about time spent running specific rules. */ + def dumpTimeSpent(): String = { + val map = timeMap.asMap().asScala + val maxLengthRuleNames = map.keys.map(_.toString.length).max + + val colRuleName = "Rule".padTo(maxLengthRuleNames, " ").mkString + val colRunTime = "Effective Time / Total Time".padTo(len = 47, " ").mkString + val colNumRuns = "Effective Runs / Total Runs".padTo(len = 47, " ").mkString + + val ruleMetrics = map.toSeq.sortBy(_._2).reverseMap { case (name, time) => + val timeEffectiveRun = timeEffectiveRunsMap.get(name) + val numRuns = numRunsMap.get(name) + val numEffectiveRun = numEffectiveRunsMap.get(name) + + val ruleName = name.padTo(maxLengthRuleNames, " ").mkString + val runtimeValue = s"$timeEffectiveRun / $time".padTo(len = 47, " ").mkString + val numRunValue = s"$numEffectiveRun / $numRuns".padTo(len = 47, " ").mkString + s"$ruleName $runtimeValue $numRunValue" + }.mkString("\n", "\n", "") + + s""" + |=== Metrics of Analyzer/Optimizer Rules === + |Total number of runs: $totalNumRuns + |Total time: ${totalTime / 1000000000D} seconds + | + |$colRuleName $colRunTime $colNumRuns + |$ruleMetrics + """.stripMargin + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 7e4b784033bfc..dccb44ddebfa4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql.catalyst.rules -import scala.collection.JavaConverters._ - -import com.google.common.util.concurrent.AtomicLongMap - import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees.TreeNode @@ -28,18 +24,16 @@ import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.util.Utils object RuleExecutor { - protected val timeMap = AtomicLongMap.create[String]() - - /** Resets statistics about time spent running specific rules */ - def resetTime(): Unit = timeMap.clear() + protected val queryExecutionMeter = QueryExecutionMetering() /** Dump statistics about time spent running specific rules. */ def dumpTimeSpent(): String = { - val map = timeMap.asMap().asScala - val maxSize = map.keys.map(_.toString.length).max - map.toSeq.sortBy(_._2).reverseMap { case (k, v) => - s"${k.padTo(maxSize, " ").mkString} $v" - }.mkString("\n", "\n", "") + queryExecutionMeter.dumpTimeSpent() + } + + /** Resets statistics about time spent running specific rules */ + def resetMetrics(): Unit = { + queryExecutionMeter.resetMetrics() } } @@ -77,6 +71,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { */ def execute(plan: TreeType): TreeType = { var curPlan = plan + val queryExecutionMetrics = RuleExecutor.queryExecutionMeter batches.foreach { batch => val batchStartPlan = curPlan @@ -91,15 +86,18 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { val startTime = System.nanoTime() val result = rule(plan) val runTime = System.nanoTime() - startTime - RuleExecutor.timeMap.addAndGet(rule.ruleName, runTime) if (!result.fastEquals(plan)) { + queryExecutionMetrics.incNumEffectiveExecution(rule.ruleName) + queryExecutionMetrics.incTimeEffectiveExecutionBy(rule.ruleName, runTime) logTrace( s""" |=== Applying Rule ${rule.ruleName} === |${sideBySide(plan.treeString, result.treeString).mkString("\n")} """.stripMargin) } + queryExecutionMetrics.incExecutionTimeBy(rule.ruleName, runTime) + queryExecutionMetrics.incNumExecution(rule.ruleName) // Run the structural integrity checker against the plan after each rule. if (!isPlanIntegral(result)) { @@ -135,9 +133,9 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { if (!batchStartPlan.fastEquals(curPlan)) { logDebug( s""" - |=== Result of Batch ${batch.name} === - |${sideBySide(batchStartPlan.treeString, curPlan.treeString).mkString("\n")} - """.stripMargin) + |=== Result of Batch ${batch.name} === + |${sideBySide(batchStartPlan.treeString, curPlan.treeString).mkString("\n")} + """.stripMargin) } else { logTrace(s"Batch ${batch.name} has no effect.") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala index 7037749f14478..e51aad021fcbf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala @@ -46,7 +46,7 @@ abstract class BenchmarkQueryTest extends QueryTest with SharedSQLContext with B override def beforeAll() { super.beforeAll() - RuleExecutor.resetTime() + RuleExecutor.resetMetrics() } protected def checkGeneratedCode(plan: SparkPlan): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index e3901af4b9988..054ada56d99ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -291,7 +291,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting Locale.setDefault(Locale.US) - RuleExecutor.resetTime() + RuleExecutor.resetMetrics() } override def afterAll(): Unit = { diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 45791c69b4cb7..cebaad5b4ad9b 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -62,7 +62,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Fix session local timezone to America/Los_Angeles for those timezone sensitive tests // (timestamp_*) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, "America/Los_Angeles") - RuleExecutor.resetTime() + RuleExecutor.resetMetrics() } override def afterAll() { From 896e45af5fea264683b1d7d20a1711f33908a06f Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 22 Jan 2018 04:32:59 -0800 Subject: [PATCH 0170/2461] [MINOR][SQL][TEST] Test case cleanups for recent PRs ## What changes were proposed in this pull request? Revert the unneeded test case changes we made in SPARK-23000 Also fixes the test suites that do not call `super.afterAll()` in the local `afterAll`. The `afterAll()` of `TestHiveSingleton` actually reset the environments. ## How was this patch tested? N/A Author: gatorsmile Closes #20341 from gatorsmile/testRelated. --- .../apache/spark/sql/DataFrameJoinSuite.scala | 21 ++++++----- .../apache/spark/sql/hive/test/TestHive.scala | 3 +- .../sql/hive/HiveMetastoreCatalogSuite.scala | 26 +++++++------- .../sql/hive/execution/HiveUDAFSuite.scala | 8 +++-- .../hive/execution/Hive_2_1_DDLSuite.scala | 6 +++- .../execution/ObjectHashAggregateSuite.scala | 6 +++- .../apache/spark/sql/hive/parquetSuites.scala | 35 +++++++++++-------- 7 files changed, 60 insertions(+), 45 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 1656f290ee19c..0d9eeabb397a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext class DataFrameJoinSuite extends QueryTest with SharedSQLContext { @@ -276,16 +277,14 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { test("SPARK-23087: don't throw Analysis Exception in CheckCartesianProduct when join condition " + "is false or null") { - val df = spark.range(10) - val dfNull = spark.range(10).select(lit(null).as("b")) - val planNull = df.join(dfNull, $"id" === $"b", "left").queryExecution.analyzed - - spark.sessionState.executePlan(planNull).optimizedPlan - - val dfOne = df.select(lit(1).as("a")) - val dfTwo = spark.range(10).select(lit(2).as("b")) - val planFalse = dfOne.join(dfTwo, $"a" === $"b", "left").queryExecution.analyzed - - spark.sessionState.executePlan(planFalse).optimizedPlan + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") { + val df = spark.range(10) + val dfNull = spark.range(10).select(lit(null).as("b")) + df.join(dfNull, $"id" === $"b", "left").queryExecution.optimizedPlan + + val dfOne = df.select(lit(1).as("a")) + val dfTwo = spark.range(10).select(lit(2).as("b")) + dfOne.join(dfTwo, $"a" === $"b", "left").queryExecution.optimizedPlan + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index c84131fc3212a..7287e20d55bbe 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -492,8 +492,7 @@ private[hive] class TestHiveSparkSession( protected val originalUDFs: JavaSet[String] = FunctionRegistry.getFunctionNames /** - * Resets the test instance by deleting any tables that have been created. - * TODO: also clear out UDFs, views, etc. + * Resets the test instance by deleting any table, view, temp view, and UDF that have been created */ def reset() { try { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 83b4c862e2546..ba9b944e4a055 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -166,13 +166,13 @@ class DataSourceWithHiveMetastoreCatalogSuite )) ).foreach { case (provider, (inputFormat, outputFormat, serde)) => test(s"Persist non-partitioned $provider relation into metastore as managed table") { - withTable("default.t") { + withTable("t") { withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { testDF .write .mode(SaveMode.Overwrite) .format(provider) - .saveAsTable("default.t") + .saveAsTable("t") } val hiveTable = sessionState.catalog.getTableMetadata(TableIdentifier("t", Some("default"))) @@ -187,15 +187,14 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.dataType) === Seq(DecimalType(10, 3), StringType)) - checkAnswer(table("default.t"), testDF) - assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM default.t") === - Seq("1.1\t1", "2.1\t2")) + checkAnswer(table("t"), testDF) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } test(s"Persist non-partitioned $provider relation into metastore as external table") { withTempPath { dir => - withTable("default.t") { + withTable("t") { val path = dir.getCanonicalFile withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { @@ -204,7 +203,7 @@ class DataSourceWithHiveMetastoreCatalogSuite .mode(SaveMode.Overwrite) .format(provider) .option("path", path.toString) - .saveAsTable("default.t") + .saveAsTable("t") } val hiveTable = @@ -220,8 +219,8 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.dataType) === Seq(DecimalType(10, 3), StringType)) - checkAnswer(table("default.t"), testDF) - assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM default.t") === + checkAnswer(table("t"), testDF) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } @@ -229,9 +228,9 @@ class DataSourceWithHiveMetastoreCatalogSuite test(s"Persist non-partitioned $provider relation into metastore as managed table using CTAS") { withTempPath { dir => - withTable("default.t") { + withTable("t") { sql( - s"""CREATE TABLE default.t USING $provider + s"""CREATE TABLE t USING $provider |OPTIONS (path '${dir.toURI}') |AS SELECT 1 AS d1, "val_1" AS d2 """.stripMargin) @@ -249,9 +248,8 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.dataType) === Seq(IntegerType, StringType)) - checkAnswer(table("default.t"), Row(1, "val_1")) - assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM default.t") === - Seq("1\tval_1")) + checkAnswer(table("t"), Row(1, "val_1")) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala index 8986fb58c6460..7402c9626873c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -49,8 +49,12 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } protected override def afterAll(): Unit = { - sql(s"DROP TEMPORARY FUNCTION IF EXISTS mock") - sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") + try { + sql(s"DROP TEMPORARY FUNCTION IF EXISTS mock") + sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") + } finally { + super.afterAll() + } } test("built-in Hive UDAF") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala index bc828877e35ec..eaedac1fa95d8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala @@ -74,7 +74,11 @@ class Hive_2_1_DDLSuite extends SparkFunSuite with TestHiveSingleton with Before } override def afterAll(): Unit = { - catalog = null + try { + catalog = null + } finally { + super.afterAll() + } } test("SPARK-21617: ALTER TABLE for non-compatible DataSource tables") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala index 9eaf44c043c71..8dbcd24cd78de 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala @@ -47,7 +47,11 @@ class ObjectHashAggregateSuite } protected override def afterAll(): Unit = { - sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") + try { + sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") + } finally { + super.afterAll() + } } test("typed_count without grouping keys") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 740e0837350cc..2327d83a1b4f6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -180,15 +180,18 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } override def afterAll(): Unit = { - dropTables("partitioned_parquet", - "partitioned_parquet_with_key", - "partitioned_parquet_with_complextypes", - "partitioned_parquet_with_key_and_complextypes", - "normal_parquet", - "jt", - "jt_array", - "test_parquet") - super.afterAll() + try { + dropTables("partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes", + "normal_parquet", + "jt", + "jt_array", + "test_parquet") + } finally { + super.afterAll() + } } test(s"conversion is working") { @@ -931,11 +934,15 @@ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with } override protected def afterAll(): Unit = { - partitionedTableDir.delete() - normalTableDir.delete() - partitionedTableDirWithKey.delete() - partitionedTableDirWithComplexTypes.delete() - partitionedTableDirWithKeyAndComplexTypes.delete() + try { + partitionedTableDir.delete() + normalTableDir.delete() + partitionedTableDirWithKey.delete() + partitionedTableDirWithComplexTypes.delete() + partitionedTableDirWithKeyAndComplexTypes.delete() + } finally { + super.afterAll() + } } /** From 5d680cae486c77cdb12dbe9e043710e49e8d51e4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 22 Jan 2018 20:56:38 +0800 Subject: [PATCH 0171/2461] [SPARK-23090][SQL] polish ColumnVector ## What changes were proposed in this pull request? Several improvements: * provide a default implementation for the batch get methods * rename `getChildColumn` to `getChild`, which is more concise * remove `getStruct(int, int)`, it's only used to simplify the codegen, which is an internal thing, we should not add a public API for this purpose. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #20277 from cloud-fan/column-vector. --- .../expressions/codegen/CodeGenerator.scala | 18 ++-- .../datasources/orc/OrcColumnVector.java | 65 +----------- .../orc/OrcColumnarBatchReader.java | 23 ++--- .../vectorized/ColumnVectorUtils.java | 10 +- .../vectorized/MutableColumnarRow.java | 4 +- .../vectorized/WritableColumnVector.java | 10 +- .../sql/vectorized/ArrowColumnVector.java | 99 +------------------ .../spark/sql/vectorized/ColumnVector.java | 79 +++++++++++---- .../spark/sql/vectorized/ColumnarArray.java | 4 +- .../spark/sql/vectorized/ColumnarRow.java | 46 ++++----- .../sql/execution/ColumnarBatchScan.scala | 2 +- .../VectorizedHashMapGenerator.scala | 4 +- .../execution/arrow/ArrowWriterSuite.scala | 14 +-- .../vectorized/ArrowColumnVectorSuite.scala | 12 +-- .../vectorized/ColumnVectorSuite.scala | 12 +-- .../vectorized/ColumnarBatchBenchmark.scala | 38 +++---- .../vectorized/ColumnarBatchSuite.scala | 20 ++-- 17 files changed, 164 insertions(+), 296 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2c714c228e6c9..f96ed7628fda1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -688,17 +688,13 @@ class CodegenContext { /** * Returns the specialized code to access a value from a column vector for a given `DataType`. */ - def getValue(vector: String, rowId: String, dataType: DataType): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => - s"$vector.get${primitiveTypeName(jt)}($rowId)" - case t: DecimalType => - s"$vector.getDecimal($rowId, ${t.precision}, ${t.scale})" - case StringType => - s"$vector.getUTF8String($rowId)" - case _ => - throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType") + def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = { + if (dataType.isInstanceOf[StructType]) { + // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an + // `ordinal` parameter. + s"$vector.getStruct($rowId)" + } else { + getValue(vector, dataType, rowId) } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index b6e792274da11..aaf2a380034a9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -110,57 +110,21 @@ public boolean getBoolean(int rowId) { return longData.vector[getRowIndex(rowId)] == 1; } - @Override - public boolean[] getBooleans(int rowId, int count) { - boolean[] res = new boolean[count]; - for (int i = 0; i < count; i++) { - res[i] = getBoolean(rowId + i); - } - return res; - } - @Override public byte getByte(int rowId) { return (byte) longData.vector[getRowIndex(rowId)]; } - @Override - public byte[] getBytes(int rowId, int count) { - byte[] res = new byte[count]; - for (int i = 0; i < count; i++) { - res[i] = getByte(rowId + i); - } - return res; - } - @Override public short getShort(int rowId) { return (short) longData.vector[getRowIndex(rowId)]; } - @Override - public short[] getShorts(int rowId, int count) { - short[] res = new short[count]; - for (int i = 0; i < count; i++) { - res[i] = getShort(rowId + i); - } - return res; - } - @Override public int getInt(int rowId) { return (int) longData.vector[getRowIndex(rowId)]; } - @Override - public int[] getInts(int rowId, int count) { - int[] res = new int[count]; - for (int i = 0; i < count; i++) { - res[i] = getInt(rowId + i); - } - return res; - } - @Override public long getLong(int rowId) { int index = getRowIndex(rowId); @@ -171,43 +135,16 @@ public long getLong(int rowId) { } } - @Override - public long[] getLongs(int rowId, int count) { - long[] res = new long[count]; - for (int i = 0; i < count; i++) { - res[i] = getLong(rowId + i); - } - return res; - } - @Override public float getFloat(int rowId) { return (float) doubleData.vector[getRowIndex(rowId)]; } - @Override - public float[] getFloats(int rowId, int count) { - float[] res = new float[count]; - for (int i = 0; i < count; i++) { - res[i] = getFloat(rowId + i); - } - return res; - } - @Override public double getDouble(int rowId) { return doubleData.vector[getRowIndex(rowId)]; } - @Override - public double[] getDoubles(int rowId, int count) { - double[] res = new double[count]; - for (int i = 0; i < count; i++) { - res[i] = getDouble(rowId + i); - } - return res; - } - @Override public int getArrayLength(int rowId) { throw new UnsupportedOperationException(); @@ -245,7 +182,7 @@ public org.apache.spark.sql.vectorized.ColumnVector arrayData() { } @Override - public org.apache.spark.sql.vectorized.ColumnVector getChildColumn(int ordinal) { + public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) { throw new UnsupportedOperationException(); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index 89bae4326e93b..5e7cad470e1d1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -289,10 +289,9 @@ private void putRepeatingValues( toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector[0]); } else if (type instanceof StringType || type instanceof BinaryType) { BytesColumnVector data = (BytesColumnVector)fromColumn; - WritableColumnVector arrayData = toColumn.getChildColumn(0); int size = data.vector[0].length; - arrayData.reserve(size); - arrayData.putBytes(0, size, data.vector[0], 0); + toColumn.arrayData().reserve(size); + toColumn.arrayData().putBytes(0, size, data.vector[0], 0); for (int index = 0; index < batchSize; index++) { toColumn.putArray(index, 0, size); } @@ -352,7 +351,7 @@ private void putNonNullValues( toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector, 0); } else if (type instanceof StringType || type instanceof BinaryType) { BytesColumnVector data = ((BytesColumnVector)fromColumn); - WritableColumnVector arrayData = toColumn.getChildColumn(0); + WritableColumnVector arrayData = toColumn.arrayData(); int totalNumBytes = IntStream.of(data.length).sum(); arrayData.reserve(totalNumBytes); for (int index = 0, pos = 0; index < batchSize; pos += data.length[index], index++) { @@ -363,8 +362,7 @@ private void putNonNullValues( DecimalType decimalType = (DecimalType)type; DecimalColumnVector data = ((DecimalColumnVector)fromColumn); if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) { - WritableColumnVector arrayData = toColumn.getChildColumn(0); - arrayData.reserve(batchSize * 16); + toColumn.arrayData().reserve(batchSize * 16); } for (int index = 0; index < batchSize; index++) { putDecimalWritable( @@ -459,7 +457,7 @@ private void putValues( } } else if (type instanceof StringType || type instanceof BinaryType) { BytesColumnVector vector = (BytesColumnVector)fromColumn; - WritableColumnVector arrayData = toColumn.getChildColumn(0); + WritableColumnVector arrayData = toColumn.arrayData(); int totalNumBytes = IntStream.of(vector.length).sum(); arrayData.reserve(totalNumBytes); for (int index = 0, pos = 0; index < batchSize; pos += vector.length[index], index++) { @@ -474,8 +472,7 @@ private void putValues( DecimalType decimalType = (DecimalType)type; HiveDecimalWritable[] vector = ((DecimalColumnVector)fromColumn).vector; if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) { - WritableColumnVector arrayData = toColumn.getChildColumn(0); - arrayData.reserve(batchSize * 16); + toColumn.arrayData().reserve(batchSize * 16); } for (int index = 0; index < batchSize; index++) { if (fromColumn.isNull[index]) { @@ -521,8 +518,7 @@ private static void putDecimalWritable( toColumn.putLong(index, value.toUnscaledLong()); } else { byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray(); - WritableColumnVector arrayData = toColumn.getChildColumn(0); - arrayData.putBytes(index * 16, bytes.length, bytes, 0); + toColumn.arrayData().putBytes(index * 16, bytes.length, bytes, 0); toColumn.putArray(index, index * 16, bytes.length); } } @@ -547,9 +543,8 @@ private static void putDecimalWritables( toColumn.putLongs(0, size, value.toUnscaledLong()); } else { byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray(); - WritableColumnVector arrayData = toColumn.getChildColumn(0); - arrayData.reserve(bytes.length); - arrayData.putBytes(0, bytes.length, bytes, 0); + toColumn.arrayData().reserve(bytes.length); + toColumn.arrayData().putBytes(0, bytes.length, bytes, 0); for (int index = 0; index < size; index++) { toColumn.putArray(index, 0, bytes.length); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 5ee8cc8da2309..a2853bbadc92b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -85,8 +85,8 @@ public static void populate(WritableColumnVector col, InternalRow row, int field } } else if (t instanceof CalendarIntervalType) { CalendarInterval c = (CalendarInterval)row.get(fieldIdx, t); - col.getChildColumn(0).putInts(0, capacity, c.months); - col.getChildColumn(1).putLongs(0, capacity, c.microseconds); + col.getChild(0).putInts(0, capacity, c.months); + col.getChild(1).putLongs(0, capacity, c.microseconds); } else if (t instanceof DateType) { col.putInts(0, capacity, row.getInt(fieldIdx)); } else if (t instanceof TimestampType) { @@ -149,8 +149,8 @@ private static void appendValue(WritableColumnVector dst, DataType t, Object o) } else if (t instanceof CalendarIntervalType) { CalendarInterval c = (CalendarInterval)o; dst.appendStruct(false); - dst.getChildColumn(0).appendInt(c.months); - dst.getChildColumn(1).appendLong(c.microseconds); + dst.getChild(0).appendInt(c.months); + dst.getChild(1).appendLong(c.microseconds); } else if (t instanceof DateType) { dst.appendInt(DateTimeUtils.fromJavaDate((Date)o)); } else { @@ -179,7 +179,7 @@ private static void appendValue(WritableColumnVector dst, DataType t, Row src, i dst.appendStruct(false); Row c = src.getStruct(fieldIdx); for (int i = 0; i < st.fields().length; i++) { - appendValue(dst.getChildColumn(i), st.fields()[i].dataType(), c, i); + appendValue(dst.getChild(i), st.fields()[i].dataType(), c, i); } } } else { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 70057a9def6c0..2bab095d4d951 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -146,8 +146,8 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { if (columns[ordinal].isNullAt(rowId)) return null; - final int months = columns[ordinal].getChildColumn(0).getInt(rowId); - final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId); + final int months = columns[ordinal].getChild(0).getInt(rowId); + final long microseconds = columns[ordinal].getChild(1).getLong(rowId); return new CalendarInterval(months, microseconds); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index d2ae32b06f83b..ca4f00985c2a3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -599,17 +599,13 @@ public final int appendStruct(boolean isNull) { return elementsAppended; } - /** - * Returns the data for the underlying array. - */ + // `WritableColumnVector` puts the data of array in the first child column vector, and puts the + // array offsets and lengths in the current column vector. @Override public WritableColumnVector arrayData() { return childColumns[0]; } - /** - * Returns the ordinal's child data column. - */ @Override - public WritableColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } + public WritableColumnVector getChild(int ordinal) { return childColumns[ordinal]; } /** * Returns the elements appended. diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index bfd1b4cb0ef12..ca7a4751450d4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -33,18 +33,6 @@ public final class ArrowColumnVector extends ColumnVector { private final ArrowVectorAccessor accessor; private ArrowColumnVector[] childColumns; - private void ensureAccessible(int index) { - ensureAccessible(index, 1); - } - - private void ensureAccessible(int index, int count) { - int valueCount = accessor.getValueCount(); - if (index < 0 || index + count > valueCount) { - throw new IndexOutOfBoundsException( - String.format("index range: [%d, %d), valueCount: %d", index, index + count, valueCount)); - } - } - @Override public int numNulls() { return accessor.getNullCount(); @@ -55,156 +43,75 @@ public void close() { if (childColumns != null) { for (int i = 0; i < childColumns.length; i++) { childColumns[i].close(); + childColumns[i] = null; } + childColumns = null; } accessor.close(); } @Override public boolean isNullAt(int rowId) { - ensureAccessible(rowId); return accessor.isNullAt(rowId); } @Override public boolean getBoolean(int rowId) { - ensureAccessible(rowId); return accessor.getBoolean(rowId); } - @Override - public boolean[] getBooleans(int rowId, int count) { - ensureAccessible(rowId, count); - boolean[] array = new boolean[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getBoolean(rowId + i); - } - return array; - } - @Override public byte getByte(int rowId) { - ensureAccessible(rowId); return accessor.getByte(rowId); } - @Override - public byte[] getBytes(int rowId, int count) { - ensureAccessible(rowId, count); - byte[] array = new byte[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getByte(rowId + i); - } - return array; - } - @Override public short getShort(int rowId) { - ensureAccessible(rowId); return accessor.getShort(rowId); } - @Override - public short[] getShorts(int rowId, int count) { - ensureAccessible(rowId, count); - short[] array = new short[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getShort(rowId + i); - } - return array; - } - @Override public int getInt(int rowId) { - ensureAccessible(rowId); return accessor.getInt(rowId); } - @Override - public int[] getInts(int rowId, int count) { - ensureAccessible(rowId, count); - int[] array = new int[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getInt(rowId + i); - } - return array; - } - @Override public long getLong(int rowId) { - ensureAccessible(rowId); return accessor.getLong(rowId); } - @Override - public long[] getLongs(int rowId, int count) { - ensureAccessible(rowId, count); - long[] array = new long[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getLong(rowId + i); - } - return array; - } - @Override public float getFloat(int rowId) { - ensureAccessible(rowId); return accessor.getFloat(rowId); } - @Override - public float[] getFloats(int rowId, int count) { - ensureAccessible(rowId, count); - float[] array = new float[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getFloat(rowId + i); - } - return array; - } - @Override public double getDouble(int rowId) { - ensureAccessible(rowId); return accessor.getDouble(rowId); } - @Override - public double[] getDoubles(int rowId, int count) { - ensureAccessible(rowId, count); - double[] array = new double[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getDouble(rowId + i); - } - return array; - } - @Override public int getArrayLength(int rowId) { - ensureAccessible(rowId); return accessor.getArrayLength(rowId); } @Override public int getArrayOffset(int rowId) { - ensureAccessible(rowId); return accessor.getArrayOffset(rowId); } @Override public Decimal getDecimal(int rowId, int precision, int scale) { - ensureAccessible(rowId); return accessor.getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int rowId) { - ensureAccessible(rowId); return accessor.getUTF8String(rowId); } @Override public byte[] getBinary(int rowId) { - ensureAccessible(rowId); return accessor.getBinary(rowId); } @@ -212,7 +119,7 @@ public byte[] getBinary(int rowId) { public ArrowColumnVector arrayData() { return childColumns[0]; } @Override - public ArrowColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } + public ArrowColumnVector getChild(int ordinal) { return childColumns[ordinal]; } public ArrowColumnVector(ValueVector vector) { super(ArrowUtils.fromArrowField(vector.getField())); diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index d1196e1299fee..f9936214035b6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -51,12 +51,16 @@ public abstract class ColumnVector implements AutoCloseable { public final DataType dataType() { return type; } /** - * Cleans up memory for this column. The column is not usable after this. + * Cleans up memory for this column vector. The column vector is not usable after this. + * + * This overwrites `AutoCloseable.close` to remove the `throws` clause, as column vector is + * in-memory and we don't expect any exception to happen during closing. */ + @Override public abstract void close(); /** - * Returns the number of nulls in this column. + * Returns the number of nulls in this column vector. */ public abstract int numNulls(); @@ -73,7 +77,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract boolean[] getBooleans(int rowId, int count); + public boolean[] getBooleans(int rowId, int count) { + boolean[] res = new boolean[count]; + for (int i = 0; i < count; i++) { + res[i] = getBoolean(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -83,7 +93,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract byte[] getBytes(int rowId, int count); + public byte[] getBytes(int rowId, int count) { + byte[] res = new byte[count]; + for (int i = 0; i < count; i++) { + res[i] = getByte(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -93,7 +109,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract short[] getShorts(int rowId, int count); + public short[] getShorts(int rowId, int count) { + short[] res = new short[count]; + for (int i = 0; i < count; i++) { + res[i] = getShort(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -103,7 +125,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract int[] getInts(int rowId, int count); + public int[] getInts(int rowId, int count) { + int[] res = new int[count]; + for (int i = 0; i < count; i++) { + res[i] = getInt(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -113,7 +141,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract long[] getLongs(int rowId, int count); + public long[] getLongs(int rowId, int count) { + long[] res = new long[count]; + for (int i = 0; i < count; i++) { + res[i] = getLong(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -123,7 +157,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract float[] getFloats(int rowId, int count); + public float[] getFloats(int rowId, int count) { + float[] res = new float[count]; + for (int i = 0; i < count; i++) { + res[i] = getFloat(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -133,7 +173,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract double[] getDoubles(int rowId, int count); + public double[] getDoubles(int rowId, int count) { + double[] res = new double[count]; + for (int i = 0; i < count; i++) { + res[i] = getDouble(rowId + i); + } + return res; + } /** * Returns the length of the array for rowId. @@ -152,14 +198,6 @@ public final ColumnarRow getStruct(int rowId) { return new ColumnarRow(this, rowId); } - /** - * A special version of {@link #getStruct(int)}, which is only used as an adapter for Spark - * codegen framework, the second parameter is totally ignored. - */ - public final ColumnarRow getStruct(int rowId, int size) { - return getStruct(rowId); - } - /** * Returns the array for rowId. */ @@ -196,9 +234,9 @@ public MapData getMap(int ordinal) { public abstract ColumnVector arrayData(); /** - * Returns the ordinal's child data column. + * Returns the ordinal's child column vector. */ - public abstract ColumnVector getChildColumn(int ordinal); + public abstract ColumnVector getChild(int ordinal); /** * Data type for this column. @@ -206,8 +244,7 @@ public MapData getMap(int ordinal) { protected DataType type; /** - * Sets up the common state and also handles creating the child columns if this is a nested - * type. + * Sets up the data type of this column vector. */ protected ColumnVector(DataType type) { this.type = type; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 0d89a52e7a4fe..522c39580389f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -133,8 +133,8 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { - int month = data.getChildColumn(0).getInt(offset + ordinal); - long microseconds = data.getChildColumn(1).getLong(offset + ordinal); + int month = data.getChild(0).getInt(offset + ordinal); + long microseconds = data.getChild(1).getLong(offset + ordinal); return new CalendarInterval(month, microseconds); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 3c6656dec77cd..2e59085a82768 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -28,7 +28,7 @@ */ public final class ColumnarRow extends InternalRow { // The data for this row. - // E.g. the value of 3rd int field is `data.getChildColumn(3).getInt(rowId)`. + // E.g. the value of 3rd int field is `data.getChild(3).getInt(rowId)`. private final ColumnVector data; private final int rowId; private final int numFields; @@ -53,7 +53,7 @@ public InternalRow copy() { if (isNullAt(i)) { row.setNullAt(i); } else { - DataType dt = data.getChildColumn(i).dataType(); + DataType dt = data.getChild(i).dataType(); if (dt instanceof BooleanType) { row.setBoolean(i, getBoolean(i)); } else if (dt instanceof ByteType) { @@ -93,65 +93,65 @@ public boolean anyNull() { } @Override - public boolean isNullAt(int ordinal) { return data.getChildColumn(ordinal).isNullAt(rowId); } + public boolean isNullAt(int ordinal) { return data.getChild(ordinal).isNullAt(rowId); } @Override - public boolean getBoolean(int ordinal) { return data.getChildColumn(ordinal).getBoolean(rowId); } + public boolean getBoolean(int ordinal) { return data.getChild(ordinal).getBoolean(rowId); } @Override - public byte getByte(int ordinal) { return data.getChildColumn(ordinal).getByte(rowId); } + public byte getByte(int ordinal) { return data.getChild(ordinal).getByte(rowId); } @Override - public short getShort(int ordinal) { return data.getChildColumn(ordinal).getShort(rowId); } + public short getShort(int ordinal) { return data.getChild(ordinal).getShort(rowId); } @Override - public int getInt(int ordinal) { return data.getChildColumn(ordinal).getInt(rowId); } + public int getInt(int ordinal) { return data.getChild(ordinal).getInt(rowId); } @Override - public long getLong(int ordinal) { return data.getChildColumn(ordinal).getLong(rowId); } + public long getLong(int ordinal) { return data.getChild(ordinal).getLong(rowId); } @Override - public float getFloat(int ordinal) { return data.getChildColumn(ordinal).getFloat(rowId); } + public float getFloat(int ordinal) { return data.getChild(ordinal).getFloat(rowId); } @Override - public double getDouble(int ordinal) { return data.getChildColumn(ordinal).getDouble(rowId); } + public double getDouble(int ordinal) { return data.getChild(ordinal).getDouble(rowId); } @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getDecimal(rowId, precision, scale); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getUTF8String(rowId); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getUTF8String(rowId); } @Override public byte[] getBinary(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getBinary(rowId); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getBinary(rowId); } @Override public CalendarInterval getInterval(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - final int months = data.getChildColumn(ordinal).getChildColumn(0).getInt(rowId); - final long microseconds = data.getChildColumn(ordinal).getChildColumn(1).getLong(rowId); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + final int months = data.getChild(ordinal).getChild(0).getInt(rowId); + final long microseconds = data.getChild(ordinal).getChild(1).getLong(rowId); return new CalendarInterval(months, microseconds); } @Override public ColumnarRow getStruct(int ordinal, int numFields) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getStruct(rowId); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getStruct(rowId); } @Override public ColumnarArray getArray(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getArray(rowId); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getArray(rowId); } @Override diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index dd68df9686691..04f2619ed7541 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -50,7 +50,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { dataType: DataType, nullable: Boolean): ExprCode = { val javaType = ctx.javaType(dataType) - val value = ctx.getValue(columnVar, dataType, ordinal) + val value = ctx.getValueFromVector(columnVar, dataType, ordinal) val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index eb48584d0c1ee..633eeac180974 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -127,8 +127,8 @@ class VectorizedHashMapGenerator( def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"vectors[$ordinal]", "buckets[idx]", - key.dataType), key.name)})""" + val value = ctx.getValueFromVector(s"vectors[$ordinal]", key.dataType, "buckets[idx]") + s"(${ctx.genEqual(key.dataType, value, key.name)})" }.mkString(" && ") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index c42bc60a59d67..92506032ab2e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -217,21 +217,21 @@ class ArrowWriterSuite extends SparkFunSuite { val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) - val struct0 = reader.getStruct(0, 2) + val struct0 = reader.getStruct(0) assert(struct0.getInt(0) === 1) assert(struct0.getUTF8String(1) === UTF8String.fromString("str1")) - val struct1 = reader.getStruct(1, 2) + val struct1 = reader.getStruct(1) assert(struct1.isNullAt(0)) assert(struct1.isNullAt(1)) assert(reader.isNullAt(2)) - val struct3 = reader.getStruct(3, 2) + val struct3 = reader.getStruct(3) assert(struct3.getInt(0) === 4) assert(struct3.isNullAt(1)) - val struct4 = reader.getStruct(4, 2) + val struct4 = reader.getStruct(4) assert(struct4.isNullAt(0)) assert(struct4.getUTF8String(1) === UTF8String.fromString("str5")) @@ -252,15 +252,15 @@ class ArrowWriterSuite extends SparkFunSuite { val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) - val struct00 = reader.getStruct(0, 1).getStruct(0, 2) + val struct00 = reader.getStruct(0).getStruct(0, 2) assert(struct00.getInt(0) === 1) assert(struct00.getUTF8String(1) === UTF8String.fromString("str1")) - val struct10 = reader.getStruct(1, 1).getStruct(0, 2) + val struct10 = reader.getStruct(1).getStruct(0, 2) assert(struct10.isNullAt(0)) assert(struct10.isNullAt(1)) - val struct2 = reader.getStruct(2, 1) + val struct2 = reader.getStruct(2) assert(struct2.isNullAt(0)) assert(reader.isNullAt(3)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index 53432669e215d..e794f50781ff2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -346,11 +346,11 @@ class ArrowColumnVectorSuite extends SparkFunSuite { assert(columnVector.dataType === schema) assert(columnVector.numNulls === 0) - val row0 = columnVector.getStruct(0, 2) + val row0 = columnVector.getStruct(0) assert(row0.getInt(0) === 1) assert(row0.getLong(1) === 1L) - val row1 = columnVector.getStruct(1, 2) + val row1 = columnVector.getStruct(1) assert(row1.getInt(0) === 2) assert(row1.isNullAt(1)) @@ -398,21 +398,21 @@ class ArrowColumnVectorSuite extends SparkFunSuite { assert(columnVector.dataType === schema) assert(columnVector.numNulls === 1) - val row0 = columnVector.getStruct(0, 2) + val row0 = columnVector.getStruct(0) assert(row0.getInt(0) === 1) assert(row0.getLong(1) === 1L) - val row1 = columnVector.getStruct(1, 2) + val row1 = columnVector.getStruct(1) assert(row1.getInt(0) === 2) assert(row1.isNullAt(1)) - val row2 = columnVector.getStruct(2, 2) + val row2 = columnVector.getStruct(2) assert(row2.isNullAt(0)) assert(row2.getLong(1) === 3L) assert(columnVector.isNullAt(3)) - val row4 = columnVector.getStruct(4, 2) + val row4 = columnVector.getStruct(4) assert(row4.getInt(0) === 5) assert(row4.getLong(1) === 5L) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 944240f3bade5..2d1ad4b456783 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -199,17 +199,17 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { val structType: StructType = new StructType().add("int", IntegerType).add("double", DoubleType) testVectors("struct", 10, structType) { testVector => - val c1 = testVector.getChildColumn(0) - val c2 = testVector.getChildColumn(1) + val c1 = testVector.getChild(0) + val c2 = testVector.getChild(1) c1.putInt(0, 123) c2.putDouble(0, 3.45) c1.putInt(1, 456) c2.putDouble(1, 5.67) - assert(testVector.getStruct(0, structType.length).get(0, IntegerType) === 123) - assert(testVector.getStruct(0, structType.length).get(1, DoubleType) === 3.45) - assert(testVector.getStruct(1, structType.length).get(0, IntegerType) === 456) - assert(testVector.getStruct(1, structType.length).get(1, DoubleType) === 5.67) + assert(testVector.getStruct(0).get(0, IntegerType) === 123) + assert(testVector.getStruct(0).get(1, DoubleType) === 3.45) + assert(testVector.getStruct(1).get(0, IntegerType) === 456) + assert(testVector.getStruct(1).get(1, DoubleType) === 5.67) } test("[SPARK-22092] off-heap column vector reallocation corrupts array data") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 38ea2e47fdef8..ad74fb99b0c73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -268,17 +268,17 @@ object ColumnarBatchBenchmark { Int Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Java Array 177 / 181 1856.4 0.5 1.0X - ByteBuffer Unsafe 318 / 322 1032.0 1.0 0.6X - ByteBuffer API 1411 / 1418 232.2 4.3 0.1X - DirectByteBuffer 467 / 474 701.8 1.4 0.4X - Unsafe Buffer 178 / 185 1843.6 0.5 1.0X - Column(on heap) 178 / 184 1840.8 0.5 1.0X - Column(off heap) 341 / 344 961.8 1.0 0.5X - Column(off heap direct) 178 / 184 1845.4 0.5 1.0X - UnsafeRow (on heap) 378 / 389 866.3 1.2 0.5X - UnsafeRow (off heap) 393 / 402 834.0 1.2 0.4X - Column On Heap Append 309 / 318 1059.1 0.9 0.6X + Java Array 177 / 183 1851.1 0.5 1.0X + ByteBuffer Unsafe 314 / 330 1043.7 1.0 0.6X + ByteBuffer API 1298 / 1307 252.4 4.0 0.1X + DirectByteBuffer 465 / 483 704.2 1.4 0.4X + Unsafe Buffer 179 / 183 1835.5 0.5 1.0X + Column(on heap) 181 / 186 1815.2 0.6 1.0X + Column(off heap) 344 / 349 951.7 1.1 0.5X + Column(off heap direct) 178 / 186 1838.6 0.5 1.0X + UnsafeRow (on heap) 388 / 394 844.8 1.2 0.5X + UnsafeRow (off heap) 400 / 403 819.4 1.2 0.4X + Column On Heap Append 315 / 325 1041.8 1.0 0.6X */ val benchmark = new Benchmark("Int Read/Write", count * iters) benchmark.addCase("Java Array")(javaArray) @@ -337,8 +337,8 @@ object ColumnarBatchBenchmark { Boolean Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Bitset 726 / 727 462.4 2.2 1.0X - Byte Array 530 / 542 632.7 1.6 1.4X + Bitset 741 / 747 452.6 2.2 1.0X + Byte Array 531 / 542 631.6 1.6 1.4X */ benchmark.run() } @@ -394,8 +394,8 @@ object ColumnarBatchBenchmark { String Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - On Heap 332 / 338 49.3 20.3 1.0X - Off Heap 466 / 467 35.2 28.4 0.7X + On Heap 351 / 362 46.6 21.4 1.0X + Off Heap 456 / 466 35.9 27.8 0.8X */ val benchmark = new Benchmark("String Read/Write", count * iters) benchmark.addCase("On Heap")(column(MemoryMode.ON_HEAP)) @@ -479,10 +479,10 @@ object ColumnarBatchBenchmark { Array Vector Read: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - On Heap Read Size Only 415 / 422 394.7 2.5 1.0X - Off Heap Read Size Only 394 / 402 415.9 2.4 1.1X - On Heap Read Elements 2558 / 2593 64.0 15.6 0.2X - Off Heap Read Elements 3316 / 3317 49.4 20.2 0.1X + On Heap Read Size Only 416 / 423 393.5 2.5 1.0X + Off Heap Read Size Only 396 / 404 413.6 2.4 1.1X + On Heap Read Elements 2569 / 2590 63.8 15.7 0.2X + Off Heap Read Elements 3302 / 3333 49.6 20.2 0.1X */ benchmark.run } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index cd90681ecabc6..1873c24ab063c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -732,8 +732,8 @@ class ColumnarBatchSuite extends SparkFunSuite { "Struct Column", 10, new StructType().add("int", IntegerType).add("double", DoubleType)) { column => - val c1 = column.getChildColumn(0) - val c2 = column.getChildColumn(1) + val c1 = column.getChild(0) + val c2 = column.getChild(1) assert(c1.dataType() == IntegerType) assert(c2.dataType() == DoubleType) @@ -787,8 +787,8 @@ class ColumnarBatchSuite extends SparkFunSuite { 10, new ArrayType(structType, true)) { column => val data = column.arrayData() - val c0 = data.getChildColumn(0) - val c1 = data.getChildColumn(1) + val c0 = data.getChild(0) + val c1 = data.getChild(1) // Structs in child column: (0, 0), (1, 10), (2, 20), (3, 30), (4, 40), (5, 50) (0 until 6).foreach { i => c0.putInt(i, i) @@ -815,8 +815,8 @@ class ColumnarBatchSuite extends SparkFunSuite { new StructType() .add("int", IntegerType) .add("array", new ArrayType(IntegerType, true))) { column => - val c0 = column.getChildColumn(0) - val c1 = column.getChildColumn(1) + val c0 = column.getChild(0) + val c1 = column.getChild(1) c0.putInt(0, 0) c0.putInt(1, 1) c0.putInt(2, 2) @@ -844,13 +844,13 @@ class ColumnarBatchSuite extends SparkFunSuite { "Nest Struct in Struct", 10, new StructType().add("int", IntegerType).add("struct", subSchema)) { column => - val c0 = column.getChildColumn(0) - val c1 = column.getChildColumn(1) + val c0 = column.getChild(0) + val c1 = column.getChild(1) c0.putInt(0, 0) c0.putInt(1, 1) c0.putInt(2, 2) - val c1c0 = c1.getChildColumn(0) - val c1c1 = c1.getChildColumn(1) + val c1c0 = c1.getChild(0) + val c1c1 = c1.getChild(1) // Structs in c1: (7, 70), (8, 80), (9, 90) c1c0.putInt(0, 7) c1c0.putInt(1, 8) From 87ffe7adddf517541aac0d1e8536b02ad8881606 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 22 Jan 2018 22:12:50 +0900 Subject: [PATCH 0172/2461] [SPARK-7721][PYTHON][TESTS] Adds PySpark coverage generation script MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Note that this PR was made based on the top of https://github.com/apache/spark/pull/20151. So, it almost leaves the main codes intact. This PR proposes to add a script for the preparation of automatic PySpark coverage generation. Now, it's difficult to check the actual coverage in case of PySpark. With this script, it allows to run tests by the way we did via `run-tests` script before. The usage is exactly the same with `run-tests` script as this basically wraps it. This script and PR alone should also be useful. I was asked about how to run this before, and seems some reviewers (including me) need this. It would be also useful to run it manually. It usually requires a small diff in normal Python projects but PySpark cases are a bit different because apparently we are unable to track the coverage after it's forked. So, here, I made a custom worker that forces the coverage, based on the top of https://github.com/apache/spark/pull/20151. I made a simple demo. Please take a look - https://spark-test.github.io/pyspark-coverage-site. To show up the structure, this PR adds the files as below: ``` python ├── .coveragerc # Runtime configuration when we run the script. ├── run-tests-with-coverage # The script that has coverage support and wraps run-tests script. └── test_coverage # Directories that have files required when running coverage. ├── conf │   └── spark-defaults.conf # Having the configuration 'spark.python.daemon.module'. ├── coverage_daemon.py # A daemon having custom fix and wrapping our daemon.py └── sitecustomize.py # Initiate coverage with COVERAGE_PROCESS_START ``` Note that this PR has a minor nit: [This scope](https://github.com/apache/spark/blob/04e44b37cc04f62fbf9e08c7076349e0a4d12ea8/python/pyspark/daemon.py#L148-L169) in `daemon.py` is not in the coverage results as basically I am producing the coverage results in `worker.py` separately and then merging it. I believe it's not a big deal. In a followup, I might have a site that has a single up-to-date PySpark coverage from the master branch as the fallback / default, or have a site that has multiple PySpark coverages and the site link will be left to each pull request. ## How was this patch tested? Manually tested. Usage is the same with the existing Python test script - `./python/run-tests`. For example, ``` sh run-tests-with-coverage --python-executables=python3 --modules=pyspark-sql ``` Running this will generate HTMLs under `./python/test_coverage/htmlcov`. Console output example: ``` sh run-tests-with-coverage --python-executables=python3,python --modules=pyspark-core Running PySpark tests. Output is in /.../spark/python/unit-tests.log Will test against the following Python executables: ['python3', 'python'] Will test the following Python modules: ['pyspark-core'] Starting test(python): pyspark.tests Starting test(python3): pyspark.tests ... Tests passed in 231 seconds Combining collected coverage data under /.../spark/python/test_coverage/coverage_data Reporting the coverage data at /...spark/python/test_coverage/coverage_data/coverage Name Stmts Miss Branch BrPart Cover -------------------------------------------------------------- pyspark/__init__.py 41 0 8 2 96% ... pyspark/profiler.py 74 11 22 5 83% pyspark/rdd.py 871 40 303 32 93% pyspark/rddsampler.py 68 10 32 2 82% ... -------------------------------------------------------------- TOTAL 8521 3077 2748 191 59% Generating HTML files for PySpark coverage under /.../spark/python/test_coverage/htmlcov ``` Author: hyukjinkwon Closes #20204 from HyukjinKwon/python-coverage. --- .gitignore | 2 + python/.coveragerc | 21 ++++++ python/run-tests-with-coverage | 69 +++++++++++++++++++ python/run-tests.py | 5 +- python/test_coverage/conf/spark-defaults.conf | 21 ++++++ python/test_coverage/coverage_daemon.py | 45 ++++++++++++ python/test_coverage/sitecustomize.py | 23 +++++++ 7 files changed, 185 insertions(+), 1 deletion(-) create mode 100644 python/.coveragerc create mode 100755 python/run-tests-with-coverage create mode 100644 python/test_coverage/conf/spark-defaults.conf create mode 100644 python/test_coverage/coverage_daemon.py create mode 100644 python/test_coverage/sitecustomize.py diff --git a/.gitignore b/.gitignore index 903297db96901..39085904e324c 100644 --- a/.gitignore +++ b/.gitignore @@ -62,6 +62,8 @@ project/plugins/src_managed/ project/plugins/target/ python/lib/pyspark.zip python/deps +python/test_coverage/coverage_data +python/test_coverage/htmlcov python/pyspark/python reports/ scalastyle-on-compile.generated.xml diff --git a/python/.coveragerc b/python/.coveragerc new file mode 100644 index 0000000000000..b3339cd356a6e --- /dev/null +++ b/python/.coveragerc @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +[run] +branch = true +parallel = true +data_file = ${COVERAGE_DIR}/coverage_data/coverage diff --git a/python/run-tests-with-coverage b/python/run-tests-with-coverage new file mode 100755 index 0000000000000..6d74b563e9140 --- /dev/null +++ b/python/run-tests-with-coverage @@ -0,0 +1,69 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +set -o pipefail +set -e + +# This variable indicates which coverage executable to run to combine coverages +# and generate HTMLs, for example, 'coverage3' in Python 3. +COV_EXEC="${COV_EXEC:-coverage}" +FWDIR="$(cd "`dirname $0`"; pwd)" +pushd "$FWDIR" > /dev/null + +# Ensure that coverage executable is installed. +if ! hash $COV_EXEC 2>/dev/null; then + echo "Missing coverage executable in your path, skipping PySpark coverage" + exit 1 +fi + +# Set up the directories for coverage results. +export COVERAGE_DIR="$FWDIR/test_coverage" +rm -fr "$COVERAGE_DIR/coverage_data" +rm -fr "$COVERAGE_DIR/htmlcov" +mkdir -p "$COVERAGE_DIR/coverage_data" + +# Current directory are added in the python path so that it doesn't refer our built +# pyspark zip library first. +export PYTHONPATH="$FWDIR:$PYTHONPATH" +# Also, our sitecustomize.py and coverage_daemon.py are included in the path. +export PYTHONPATH="$COVERAGE_DIR:$PYTHONPATH" + +# We use 'spark.python.daemon.module' configuration to insert the coverage supported workers. +export SPARK_CONF_DIR="$COVERAGE_DIR/conf" + +# This environment variable enables the coverage. +export COVERAGE_PROCESS_START="$FWDIR/.coveragerc" + +# If you'd like to run a specific unittest class, you could do such as +# SPARK_TESTING=1 ../bin/pyspark pyspark.sql.tests VectorizedUDFTests +./run-tests "$@" + +# Don't run coverage for the coverage command itself +unset COVERAGE_PROCESS_START + +# Coverage could generate empty coverage data files. Remove it to get rid of warnings when combining. +find $COVERAGE_DIR/coverage_data -size 0 -print0 | xargs -0 rm +echo "Combining collected coverage data under $COVERAGE_DIR/coverage_data" +$COV_EXEC combine +echo "Reporting the coverage data at $COVERAGE_DIR/coverage_data/coverage" +$COV_EXEC report --include "pyspark/*" +echo "Generating HTML files for PySpark coverage under $COVERAGE_DIR/htmlcov" +$COV_EXEC html --ignore-errors --include "pyspark/*" --directory "$COVERAGE_DIR/htmlcov" + +popd diff --git a/python/run-tests.py b/python/run-tests.py index 1341086f02db0..f03284c334285 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -38,7 +38,7 @@ from sparktestsupport import SPARK_HOME # noqa (suppress pep8 warnings) -from sparktestsupport.shellutils import which, subprocess_check_output # noqa +from sparktestsupport.shellutils import which, subprocess_check_output, run_cmd # noqa from sparktestsupport.modules import all_modules # noqa @@ -175,6 +175,9 @@ def main(): task_queue = Queue.PriorityQueue() for python_exec in python_execs: + if "COVERAGE_PROCESS_START" in os.environ: + # Make sure if coverage is installed. + run_cmd([python_exec, "-c", "import coverage"]) python_implementation = subprocess_check_output( [python_exec, "-c", "import platform; print(platform.python_implementation())"], universal_newlines=True).strip() diff --git a/python/test_coverage/conf/spark-defaults.conf b/python/test_coverage/conf/spark-defaults.conf new file mode 100644 index 0000000000000..bf44ea6e7cfec --- /dev/null +++ b/python/test_coverage/conf/spark-defaults.conf @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This is used to generate PySpark coverage results. Seems there's no way to +# add a configuration when SPARK_TESTING environment variable is set because +# we will directly execute modules by python -m. +spark.python.daemon.module coverage_daemon diff --git a/python/test_coverage/coverage_daemon.py b/python/test_coverage/coverage_daemon.py new file mode 100644 index 0000000000000..c87366a1ac23b --- /dev/null +++ b/python/test_coverage/coverage_daemon.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import imp + + +# This is a hack to always refer the main code rather than built zip. +main_code_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +daemon = imp.load_source("daemon", "%s/pyspark/daemon.py" % main_code_dir) + +if "COVERAGE_PROCESS_START" in os.environ: + worker = imp.load_source("worker", "%s/pyspark/worker.py" % main_code_dir) + + def _cov_wrapped(*args, **kwargs): + import coverage + cov = coverage.coverage( + config_file=os.environ["COVERAGE_PROCESS_START"]) + cov.start() + try: + worker.main(*args, **kwargs) + finally: + cov.stop() + cov.save() + daemon.worker_main = _cov_wrapped +else: + raise RuntimeError("COVERAGE_PROCESS_START environment variable is not set, exiting.") + + +if __name__ == '__main__': + daemon.manager() diff --git a/python/test_coverage/sitecustomize.py b/python/test_coverage/sitecustomize.py new file mode 100644 index 0000000000000..630237a518126 --- /dev/null +++ b/python/test_coverage/sitecustomize.py @@ -0,0 +1,23 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Note that this 'sitecustomize' module is a built-in feature in Python. +# If this module is defined, it's executed when the Python session begins. +# `coverage.process_startup()` seeks if COVERAGE_PROCESS_START environment +# variable is set or not. If set, it starts to run the coverage. +import coverage +coverage.process_startup() From 4327ccf289b5a0dc51f6294113d01af6eb52eea0 Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Mon, 22 Jan 2018 08:36:17 -0600 Subject: [PATCH 0173/2461] [SPARK-11630][CORE] ClosureCleaner moved from warning to debug ## What changes were proposed in this pull request? ClosureCleaner moved from warning to debug ## How was this patch tested? Existing tests Author: Rekha Joshi Author: rjoshi2 Closes #20337 from rekhajoshm/SPARK-11630-1. --- core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 40616421b5bca..ad0c0639521f6 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -207,7 +207,7 @@ private[spark] object ClosureCleaner extends Logging { accessedFields: Map[Class[_], Set[String]]): Unit = { if (!isClosure(func.getClass)) { - logWarning("Expected a closure; got " + func.getClass.getName) + logDebug(s"Expected a closure; got ${func.getClass.getName}") return } From 446948af1d8dbc080a26a6eec6f743d338f1d12b Mon Sep 17 00:00:00 2001 From: Sandor Murakozi Date: Mon, 22 Jan 2018 10:36:28 -0800 Subject: [PATCH 0174/2461] [SPARK-23121][CORE] Fix for ui becoming unaccessible for long running streaming apps ## What changes were proposed in this pull request? The allJobs and the job pages attempt to use stage attempt and DAG visualization from the store, but for long running jobs they are not guaranteed to be retained, leading to exceptions when these pages are rendered. To fix it `store.lastStageAttempt(stageId)` and `store.operationGraphForJob(jobId)` are wrapped in `store.asOption` and default values are used if the info is missing. ## How was this patch tested? Manual testing of the UI, also using the test command reported in SPARK-23121: ./bin/spark-submit --class org.apache.spark.examples.streaming.HdfsWordCount ./examples/jars/spark-examples_2.11-2.4.0-SNAPSHOT.jar /spark Closes #20287 Author: Sandor Murakozi Closes #20330 from smurakozi/SPARK-23121. --- .../apache/spark/ui/jobs/AllJobsPage.scala | 24 ++++++++++--------- .../org/apache/spark/ui/jobs/JobPage.scala | 10 ++++++-- .../org/apache/spark/ui/jobs/StagePage.scala | 9 ++++--- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index e3b72f1f34859..2b0f4acbac72a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -36,6 +36,9 @@ import org.apache.spark.util.Utils /** Page showing list of all ongoing and recently finished jobs */ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends WebUIPage("") { + + import ApiHelper._ + private val JOBS_LEGEND =
val jobId = job.jobId val status = job.status - val jobDescription = store.lastStageAttempt(job.stageIds.max).description - val displayJobDescription = jobDescription - .map(UIUtils.makeDescription(_, "", plainText = true).text) - .getOrElse("") + val (_, lastStageDescription) = lastStageNameAndDescription(store, job) + val jobDescription = UIUtils.makeDescription(lastStageDescription, "", plainText = true).text + val submissionTime = job.submissionTime.get.getTime() val completionTime = job.completionTime.map(_.getTime()).getOrElse(System.currentTimeMillis()) val classNameByStatus = status match { @@ -80,7 +82,7 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We // The timeline library treats contents as HTML, so we have to escape them. We need to add // extra layers of escaping in order to embed this in a Javascript string literal. - val escapedDesc = Utility.escape(displayJobDescription) + val escapedDesc = Utility.escape(jobDescription) val jsEscapedDesc = StringEscapeUtils.escapeEcmaScript(escapedDesc) val jobEventJsonAsStr = s""" @@ -430,6 +432,8 @@ private[ui] class JobDataSource( sortColumn: String, desc: Boolean) extends PagedDataSource[JobTableRowData](pageSize) { + import ApiHelper._ + // Convert JobUIData to JobTableRowData which contains the final contents to show in the table // so that we can avoid creating duplicate contents during sorting the data private val data = jobs.map(jobRow).sorted(ordering(sortColumn, desc)) @@ -454,23 +458,21 @@ private[ui] class JobDataSource( val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") val submissionTime = jobData.submissionTime val formattedSubmissionTime = submissionTime.map(UIUtils.formatDate).getOrElse("Unknown") - val lastStageAttempt = store.lastStageAttempt(jobData.stageIds.max) - val lastStageDescription = lastStageAttempt.description.getOrElse("") + val (lastStageName, lastStageDescription) = lastStageNameAndDescription(store, jobData) - val formattedJobDescription = - UIUtils.makeDescription(lastStageDescription, basePath, plainText = false) + val jobDescription = UIUtils.makeDescription(lastStageDescription, basePath, plainText = false) val detailUrl = "%s/jobs/job?id=%s".format(basePath, jobData.jobId) new JobTableRowData( jobData, - lastStageAttempt.name, + lastStageName, lastStageDescription, duration.getOrElse(-1), formattedDuration, submissionTime.map(_.getTime()).getOrElse(-1L), formattedSubmissionTime, - formattedJobDescription, + jobDescription, detailUrl ) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index c27f30c21a843..46f2a76cc651b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -336,8 +336,14 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP content ++= makeTimeline(activeStages ++ completedStages ++ failedStages, store.executorList(false), appStartTime) - content ++= UIUtils.showDagVizForJob( - jobId, store.operationGraphForJob(jobId)) + val operationGraphContent = store.asOption(store.operationGraphForJob(jobId)) match { + case Some(operationGraph) => UIUtils.showDagVizForJob(jobId, operationGraph) + case None => +
+

No DAG visualization information to display for job {jobId}

+
+ } + content ++= operationGraphContent if (shouldShowActiveStages) { content ++= diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 0eb3190205c3e..5c2b0c3a19996 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -23,12 +23,10 @@ import java.util.concurrent.TimeUnit import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{HashMap, HashSet} -import scala.xml.{Elem, Node, Unparsed} +import scala.xml.{Node, Unparsed} import org.apache.commons.lang3.StringEscapeUtils -import org.apache.spark.SparkConf -import org.apache.spark.internal.config._ import org.apache.spark.scheduler.TaskLocality import org.apache.spark.status._ import org.apache.spark.status.api.v1._ @@ -1020,4 +1018,9 @@ private object ApiHelper { } } + def lastStageNameAndDescription(store: AppStatusStore, job: JobData): (String, String) = { + val stage = store.asOption(store.lastStageAttempt(job.stageIds.max)) + (stage.map(_.name).getOrElse(""), stage.flatMap(_.description).getOrElse(job.name)) + } + } From 76b8b840ddc951ee6203f9cccd2c2b9671c1b5e8 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Mon, 22 Jan 2018 13:55:14 -0600 Subject: [PATCH 0175/2461] [MINOR] Typo fixes ## What changes were proposed in this pull request? Typo fixes ## How was this patch tested? Local build / Doc-only changes Author: Jacek Laskowski Closes #20344 from jaceklaskowski/typo-fixes. --- .../main/scala/org/apache/spark/SparkContext.scala | 2 +- .../spark/sql/kafka010/KafkaSourceProvider.scala | 4 ++-- .../apache/spark/sql/kafka010/KafkaWriteTask.scala | 2 +- .../org/apache/spark/sql/streaming/OutputMode.java | 2 +- .../spark/sql/catalyst/analysis/Analyzer.scala | 8 ++++---- .../spark/sql/catalyst/analysis/unresolved.scala | 2 +- .../catalyst/expressions/aggregate/interfaces.scala | 12 +++++------- .../catalyst/plans/logical/LogicalPlanVisitor.scala | 2 +- .../statsEstimation/BasicStatsPlanVisitor.scala | 2 +- .../SizeInBytesOnlyStatsPlanVisitor.scala | 4 ++-- .../org/apache/spark/sql/internal/SQLConf.scala | 2 +- .../apache/spark/sql/catalyst/plans/PlanTest.scala | 2 +- .../scala/org/apache/spark/sql/DataFrameWriter.scala | 2 +- .../apache/spark/sql/execution/SparkSqlParser.scala | 2 +- .../spark/sql/execution/WholeStageCodegenExec.scala | 2 +- .../spark/sql/execution/command/SetCommand.scala | 4 ++-- .../spark/sql/execution/datasources/rules.scala | 2 +- .../sql/execution/streaming/HDFSMetadataLog.scala | 2 +- .../spark/sql/execution/streaming/OffsetSeq.scala | 2 +- .../spark/sql/execution/streaming/OffsetSeqLog.scala | 2 +- .../execution/streaming/StreamingQueryWrapper.scala | 2 +- .../sql/execution/streaming/state/StateStore.scala | 2 +- .../spark/sql/execution/ui/ExecutionPage.scala | 2 +- .../spark/sql/expressions/UserDefinedFunction.scala | 4 ++-- .../spark/sql/internal/BaseSessionStateBuilder.scala | 4 ++-- .../spark/sql/streaming/DataStreamReader.scala | 6 +++--- .../results/columnresolution-negative.sql.out | 2 +- .../sql-tests/results/columnresolution-views.sql.out | 2 +- .../sql-tests/results/columnresolution.sql.out | 6 +++--- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 4 ++-- .../apache/spark/sql/execution/SQLViewSuite.scala | 2 +- .../apache/spark/sql/hive/HiveExternalCatalog.scala | 4 ++-- 32 files changed, 50 insertions(+), 52 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 31f3cb9dfa0ae..3828d4f703247 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2276,7 +2276,7 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Clean a closure to make it ready to be serialized and send to tasks + * Clean a closure to make it ready to be serialized and sent to tasks * (removes unreferenced variables in $outer's, updates REPL variables) * If checkSerializable is set, clean will also proactively * check to see if f is serializable and throw a SparkException diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 3914370a96595..62a998fbfb30b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -307,7 +307,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) { throw new IllegalArgumentException( s"Kafka option '${ConsumerConfig.GROUP_ID_CONFIG}' is not supported as " + - s"user-specified consumer groups is not used to track offsets.") + s"user-specified consumer groups are not used to track offsets.") } if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.AUTO_OFFSET_RESET_CONFIG}")) { @@ -335,7 +335,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister { throw new IllegalArgumentException( s"Kafka option '${ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG}' is not supported as " - + "value are deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame " + + "values are deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame " + "operations to explicitly deserialize the values.") } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index baa60febf661d..d90630a8adc93 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, Unsa import org.apache.spark.sql.types.{BinaryType, StringType} /** - * A simple trait for writing out data in a single Spark task, without any concerns about how + * Writes out data in a single Spark task, without any concerns about how * to commit or abort tasks. Exceptions thrown by the implementation of this class will * automatically trigger task aborts. */ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java index 2800b3068f87b..470c128ee6c3d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes; /** - * OutputMode is used to what data will be written to a streaming sink when there is + * OutputMode describes what data will be written to a streaming sink when there is * new data available in a streaming DataFrame/Dataset. * * @since 2.0.0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 35b35110e491f..2b14c8220d43b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -611,8 +611,8 @@ class Analyzer( if (AnalysisContext.get.nestedViewDepth > conf.maxNestedViewDepth) { view.failAnalysis(s"The depth of view ${view.desc.identifier} exceeds the maximum " + s"view resolution depth (${conf.maxNestedViewDepth}). Analysis is aborted to " + - "avoid errors. Increase the value of spark.sql.view.maxNestedViewDepth to work " + - "aroud this.") + s"avoid errors. Increase the value of ${SQLConf.MAX_NESTED_VIEW_DEPTH.key} to work " + + "around this.") } executeSameContext(child) } @@ -653,7 +653,7 @@ class Analyzer( // Note that if the database is not defined, it is possible we are looking up a temp view. case e: NoSuchDatabaseException => u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}, the " + - s"database ${e.db} doesn't exsits.") + s"database ${e.db} doesn't exist.") } } @@ -1524,7 +1524,7 @@ class Analyzer( } /** - * Extracts [[Generator]] from the projectList of a [[Project]] operator and create [[Generate]] + * Extracts [[Generator]] from the projectList of a [[Project]] operator and creates [[Generate]] * operator under [[Project]]. * * This rule will throw [[AnalysisException]] for following cases: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index d336f801d0770..a65f58fa61ff4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -294,7 +294,7 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu } else { val from = input.inputSet.map(_.name).mkString(", ") val targetString = target.get.mkString(".") - throw new AnalysisException(s"cannot resolve '$targetString.*' give input columns '$from'") + throw new AnalysisException(s"cannot resolve '$targetString.*' given input columns '$from'") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 19abce01a26cf..e1d16a2cd38b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -190,17 +190,15 @@ abstract class AggregateFunction extends Expression { def defaultResult: Option[Literal] = None /** - * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] because - * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, - * and the flag indicating if this aggregation is distinct aggregation or not. - * An [[AggregateFunction]] should not be used without being wrapped in - * an [[AggregateExpression]]. + * Creates [[AggregateExpression]] with `isDistinct` flag disabled. + * + * @see `toAggregateExpression(isDistinct: Boolean)` for detailed description */ def toAggregateExpression(): AggregateExpression = toAggregateExpression(isDistinct = false) /** - * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and set isDistinct - * field of the [[AggregateExpression]] to the given value because + * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and sets `isDistinct` + * flag of the [[AggregateExpression]] to the given value because * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, * and the flag indicating if this aggregation is distinct aggregation or not. * An [[AggregateFunction]] should not be used without being wrapped in diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala index e0748043c46e2..2c248d74869ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical /** - * A visitor pattern for traversing a [[LogicalPlan]] tree and compute some properties. + * A visitor pattern for traversing a [[LogicalPlan]] tree and computing some properties. */ trait LogicalPlanVisitor[T] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala index ca0775a2e8408..b6c16079d1984 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import org.apache.spark.sql.catalyst.plans.logical._ /** - * An [[LogicalPlanVisitor]] that computes a the statistics used in a cost-based optimizer. + * A [[LogicalPlanVisitor]] that computes the statistics for the cost-based optimizer. */ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala index 5e1c4e0bd6069..85f67c7d66075 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -48,8 +48,8 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { } /** - * For leaf nodes, use its computeStats. For other nodes, we assume the size in bytes is the - * sum of all of the children's. + * For leaf nodes, use its `computeStats`. For other nodes, we assume the size in bytes is the + * product of all of the children's `computeStats`. */ override def default(p: LogicalPlan): Statistics = p match { case p: LeafNode => p.computeStats() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index cc4f4bf332459..1cef09a5bf053 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -894,7 +894,7 @@ object SQLConf { .internal() .doc("The number of bins when generating histograms.") .intConf - .checkValue(num => num > 1, "The number of bins must be large than 1.") + .checkValue(num => num > 1, "The number of bins must be larger than 1.") .createWithDefault(254) val PERCENTILE_ACCURACY = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 82c5307d54360..6241d5cbb1d25 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -154,7 +154,7 @@ trait PlanTestBase extends PredicateHelper { self: Suite => } /** - * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL * configurations. */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 97f12ff625c42..5f3d4448e4e54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -311,7 +311,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { if (partitioningColumns.isDefined) { throw new AnalysisException( "insertInto() can't be used together with partitionBy(). " + - "Partition columns have already be defined for the table. " + + "Partition columns have already been defined for the table. " + "It is not necessary to use partitionBy()." ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index d3cfd2a1ffbf2..4828fa60a7b58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -327,7 +327,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } /** - * Create a [[DescribeTableCommand]] logical plan. + * Create a [[DescribeColumnCommand]] or [[DescribeTableCommand]] logical commands. */ override def visitDescribeTable(ctx: DescribeTableContext): LogicalPlan = withOrigin(ctx) { val isExtended = ctx.EXTENDED != null || ctx.FORMATTED != null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 065954559e487..6102937852347 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -58,7 +58,7 @@ trait CodegenSupport extends SparkPlan { } /** - * Whether this SparkPlan support whole stage codegen or not. + * Whether this SparkPlan supports whole stage codegen or not. */ def supportCodegen: Boolean = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala index 7477d025dfe89..3c900be839aa9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala @@ -91,8 +91,8 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm if (sparkSession.conf.get(CATALOG_IMPLEMENTATION.key).equals("hive") && key.startsWith("hive.")) { logWarning(s"'SET $key=$value' might not work, since Spark doesn't support changing " + - "the Hive config dynamically. Please passing the Hive-specific config by adding the " + - s"prefix spark.hadoop (e.g., spark.hadoop.$key) when starting a Spark application. " + + "the Hive config dynamically. Please pass the Hive-specific config by adding the " + + s"prefix spark.hadoop (e.g. spark.hadoop.$key) when starting a Spark application. " + "For details, see the link: https://spark.apache.org/docs/latest/configuration.html#" + "dynamically-loading-spark-properties.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index f64e079539c4f..5dbcf4a915cbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{AtomicType, StructType} import org.apache.spark.sql.util.SchemaUtils /** - * Try to replaces [[UnresolvedRelation]]s if the plan is for direct query on files. + * Replaces [[UnresolvedRelation]]s if the plan is for direct query on files. */ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { private def maybeSQLFile(u: UnresolvedRelation): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 6e8154d58d4c6..00bc215a5dc8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -330,7 +330,7 @@ object HDFSMetadataLog { /** A simple trait to abstract out the file management operations needed by HDFSMetadataLog. */ trait FileManager { - /** List the files in a path that matches a filter. */ + /** List the files in a path that match a filter. */ def list(path: Path, filter: PathFilter): Array[FileStatus] /** Make directory at the give path and all its parent directories as needed. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index a1b63a6de3823..73945b39b8967 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.internal.SQLConf.{SHUFFLE_PARTITIONS, STATE_STORE_PR case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[OffsetSeqMetadata] = None) { /** - * Unpacks an offset into [[StreamProgress]] by associating each offset with the order list of + * Unpacks an offset into [[StreamProgress]] by associating each offset with the ordered list of * sources. * * This method is typically used to associate a serialized offset with actual sources (which diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala index e3f4abcf9f1dc..2c8d7c7b0f3c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.SparkSession /** * This class is used to log offsets to persistent files in HDFS. * Each file corresponds to a specific batch of offsets. The file - * format contain a version string in the first line, followed + * format contains a version string in the first line, followed * by a the JSON string representation of the offsets separated * by a newline character. If a source offset is missing, then * that line will contain a string value defined in the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala index 020c9cb4a7304..3f2cdadfbaeee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamingQueryProgress, StreamingQueryStatus} /** - * Wrap non-serializable StreamExecution to make the query serializable as it's easy to for it to + * Wrap non-serializable StreamExecution to make the query serializable as it's easy for it to * get captured with normal usage. It's safe to capture the query but not use it in executors. * However, if the user tries to call its methods, it will throw `IllegalStateException`. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 6fe632f958ffc..d1d9f95cb0977 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -94,7 +94,7 @@ trait StateStore { def abort(): Unit /** - * Return an iterator containing all the key-value pairs in the SateStore. Implementations must + * Return an iterator containing all the key-value pairs in the StateStore. Implementations must * ensure that updates (puts, removes) can be made while iterating over this iterator. */ def iterator(): Iterator[UnsafeRowPair] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index f29e135ac357f..e0554f0c4d337 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -80,7 +80,7 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging planVisualization(metrics, graph) ++ physicalPlanDescription(executionUIData.physicalPlanDescription) }.getOrElse { -
No information to display for Plan {executionId}
+
No information to display for query {executionId}
} UIUtils.headerSparkPage(s"Details for Query $executionId", content, parent, Some(5000)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 40a058d2cadd2..bdc4bb4422ae7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.types.DataType * * As an example: * {{{ - * // Defined a UDF that returns true or false based on some numeric score. - * val predict = udf((score: Double) => if (score > 0.5) true else false) + * // Define a UDF that returns true or false based on some numeric score. + * val predict = udf((score: Double) => score > 0.5) * * // Projects a column that adds a prediction column based on the score column. * df.select( predict(df("score")) ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 2867b4cd7da5e..007f8760edf82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -206,7 +206,7 @@ abstract class BaseSessionStateBuilder( /** * Logical query plan optimizer. * - * Note: this depends on the `conf`, `catalog` and `experimentalMethods` fields. + * Note: this depends on `catalog` and `experimentalMethods` fields. */ protected def optimizer: Optimizer = { new SparkOptimizer(catalog, experimentalMethods) { @@ -263,7 +263,7 @@ abstract class BaseSessionStateBuilder( * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s * that listen for execution metrics. * - * This gets cloned from parent if available, otherwise is a new instance is created. + * This gets cloned from parent if available, otherwise a new instance is created. */ protected def listenerManager: ExecutionListenerManager = { parentState.map(_.listenerManager.clone()).getOrElse( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 52f2e2639cd86..9f5ca9f914284 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -118,7 +118,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * You can set the following option(s): *
    *
  • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to parse timestamps in the JSON/CSV datasources or partition values.
  • + * to be used to parse timestamps in the JSON/CSV data sources or partition values. *
* * @since 2.0.0 @@ -129,12 +129,12 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } /** - * Adds input options for the underlying data source. + * (Java-specific) Adds input options for the underlying data source. * * You can set the following option(s): *
    *
  • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to parse timestamps in the JSON/CSV datasources or partition values.
  • + * to be used to parse timestamps in the JSON/CSV data sources or partition values. *
* * @since 2.0.0 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out index b5a4f5c2bf654..539f673c9d679 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out @@ -195,7 +195,7 @@ SELECT t1.x.y.* FROM t1 struct<> -- !query 22 output org.apache.spark.sql.AnalysisException -cannot resolve 't1.x.y.*' give input columns 'i1'; +cannot resolve 't1.x.y.*' given input columns 'i1'; -- !query 23 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out index 7c451c2aa5b5c..2092119600954 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out @@ -88,7 +88,7 @@ SELECT global_temp.view1.* FROM global_temp.view1 struct<> -- !query 10 output org.apache.spark.sql.AnalysisException -cannot resolve 'global_temp.view1.*' give input columns 'i1'; +cannot resolve 'global_temp.view1.*' given input columns 'i1'; -- !query 11 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out index d3ca4443cce55..e10f516ad6e5b 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out @@ -179,7 +179,7 @@ SELECT mydb1.t1.* FROM mydb1.t1 struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -cannot resolve 'mydb1.t1.*' give input columns 'i1'; +cannot resolve 'mydb1.t1.*' given input columns 'i1'; -- !query 22 @@ -212,7 +212,7 @@ SELECT mydb1.t1.* FROM mydb1.t1 struct<> -- !query 25 output org.apache.spark.sql.AnalysisException -cannot resolve 'mydb1.t1.*' give input columns 'i1'; +cannot resolve 'mydb1.t1.*' given input columns 'i1'; -- !query 26 @@ -420,7 +420,7 @@ SELECT mydb1.t5.* FROM mydb1.t5 struct<> -- !query 50 output org.apache.spark.sql.AnalysisException -cannot resolve 'mydb1.t5.*' give input columns 'i1, t5'; +cannot resolve 'mydb1.t5.*' given input columns 'i1, t5'; -- !query 51 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 083a0c0b1b9a0..a79ab47f0197e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1896,12 +1896,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { var e = intercept[AnalysisException] { sql("SELECT a.* FROM temp_table_no_cols a") }.getMessage - assert(e.contains("cannot resolve 'a.*' give input columns ''")) + assert(e.contains("cannot resolve 'a.*' given input columns ''")) e = intercept[AnalysisException] { dfNoCols.select($"b.*") }.getMessage - assert(e.contains("cannot resolve 'b.*' give input columns ''")) + assert(e.contains("cannot resolve 'b.*' given input columns ''")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 14082197ba0bd..ce8fde28a941c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -663,7 +663,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { }.getMessage assert(e.contains("The depth of view `default`.`view0` exceeds the maximum view " + "resolution depth (10). Analysis is aborted to avoid errors. Increase the value " + - "of spark.sql.view.maxNestedViewDepth to work aroud this.")) + "of spark.sql.view.maxNestedViewDepth to work around this.")) } val e = intercept[IllegalArgumentException] { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 632e3e0c4c3f9..3b8a8ca301c27 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -109,8 +109,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } /** - * Get the raw table metadata from hive metastore directly. The raw table metadata may contains - * special data source properties and should not be exposed outside of `HiveExternalCatalog`. We + * Get the raw table metadata from hive metastore directly. The raw table metadata may contain + * special data source properties that should not be exposed outside of `HiveExternalCatalog`. We * should interpret these special data source properties and restore the original table metadata * before returning it. */ From 51eb750263dd710434ddb60311571fa3dcec66eb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 22 Jan 2018 15:21:09 -0800 Subject: [PATCH 0176/2461] [SPARK-22389][SQL] data source v2 partitioning reporting interface ## What changes were proposed in this pull request? a new interface which allows data source to report partitioning and avoid shuffle at Spark side. The design is pretty like the internal distribution/partitioing framework. Spark defines a `Distribution` interfaces and several concrete implementations, and ask the data source to report a `Partitioning`, the `Partitioning` should tell Spark if it can satisfy a `Distribution` or not. ## How was this patch tested? new test Author: Wenchen Fan Closes #20201 from cloud-fan/partition-reporting. --- .../plans/physical/partitioning.scala | 2 +- .../v2/reader/ClusteredDistribution.java | 38 ++++++ .../sql/sources/v2/reader/Distribution.java | 39 +++++++ .../sql/sources/v2/reader/Partitioning.java | 46 ++++++++ .../v2/reader/SupportsReportPartitioning.java | 33 ++++++ .../v2/DataSourcePartitioning.scala | 56 +++++++++ .../datasources/v2/DataSourceV2ScanExec.scala | 9 ++ .../v2/JavaPartitionAwareDataSource.java | 110 ++++++++++++++++++ .../sql/sources/v2/DataSourceV2Suite.scala | 79 +++++++++++++ 9 files changed, 411 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 0189bd73c56bf..4d9a9925fe3ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -153,7 +153,7 @@ case class BroadcastDistribution(mode: BroadcastMode) extends Distribution { * 1. number of partitions. * 2. if it can satisfy a given distribution. */ -sealed trait Partitioning { +trait Partitioning { /** Returns the number of partitions that the data is split across */ val numPartitions: Int diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java new file mode 100644 index 0000000000000..7346500de45b6 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A concrete implementation of {@link Distribution}. Represents a distribution where records that + * share the same values for the {@link #clusteredColumns} will be produced by the same + * {@link ReadTask}. + */ +@InterfaceStability.Evolving +public class ClusteredDistribution implements Distribution { + + /** + * The names of the clustered columns. Note that they are order insensitive. + */ + public final String[] clusteredColumns; + + public ClusteredDistribution(String[] clusteredColumns) { + this.clusteredColumns = clusteredColumns; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java new file mode 100644 index 0000000000000..a6201a222f541 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * An interface to represent data distribution requirement, which specifies how the records should + * be distributed among the {@link ReadTask}s that are returned by + * {@link DataSourceV2Reader#createReadTasks()}. Note that this interface has nothing to do with + * the data ordering inside one partition(the output records of a single {@link ReadTask}). + * + * The instance of this interface is created and provided by Spark, then consumed by + * {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to + * implement this interface, but need to catch as more concrete implementations of this interface + * as possible in {@link Partitioning#satisfy(Distribution)}. + * + * Concrete implementations until now: + *
    + *
  • {@link ClusteredDistribution}
  • + *
+ */ +@InterfaceStability.Evolving +public interface Distribution {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java new file mode 100644 index 0000000000000..199e45d4a02ab --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * An interface to represent the output data partitioning for a data source, which is returned by + * {@link SupportsReportPartitioning#outputPartitioning()}. Note that this should work like a + * snapshot. Once created, it should be deterministic and always report the same number of + * partitions and the same "satisfy" result for a certain distribution. + */ +@InterfaceStability.Evolving +public interface Partitioning { + + /** + * Returns the number of partitions(i.e., {@link ReadTask}s) the data source outputs. + */ + int numPartitions(); + + /** + * Returns true if this partitioning can satisfy the given distribution, which means Spark does + * not need to shuffle the output data of this data source for some certain operations. + * + * Note that, Spark may add new concrete implementations of {@link Distribution} in new releases. + * This method should be aware of it and always return false for unrecognized distributions. It's + * recommended to check every Spark new release and support new distributions if possible, to + * avoid shuffle at Spark side for more cases. + */ + boolean satisfy(Distribution distribution); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java new file mode 100644 index 0000000000000..f786472ccf345 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A mix in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * interface to report data partitioning and try to avoid shuffle at Spark side. + */ +@InterfaceStability.Evolving +public interface SupportsReportPartitioning { + + /** + * Returns the output data partitioning that this reader guarantees. + */ + Partitioning outputPartitioning(); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala new file mode 100644 index 0000000000000..943d0100aca56 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression} +import org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.sources.v2.reader.{ClusteredDistribution, Partitioning} + +/** + * An adapter from public data source partitioning to catalyst internal `Partitioning`. + */ +class DataSourcePartitioning( + partitioning: Partitioning, + colNames: AttributeMap[String]) extends physical.Partitioning { + + override val numPartitions: Int = partitioning.numPartitions() + + override def satisfies(required: physical.Distribution): Boolean = { + super.satisfies(required) || { + required match { + case d: physical.ClusteredDistribution if isCandidate(d.clustering) => + val attrs = d.clustering.map(_.asInstanceOf[Attribute]) + partitioning.satisfy( + new ClusteredDistribution(attrs.map { a => + val name = colNames.get(a) + assert(name.isDefined, s"Attribute ${a.name} is not found in the data source output") + name.get + }.toArray)) + + case _ => false + } + } + } + + private def isCandidate(clustering: Seq[Expression]): Boolean = { + clustering.forall { + case a: Attribute => colNames.contains(a) + case _ => false + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index beb66738732be..69d871df3e1dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader._ @@ -42,6 +43,14 @@ case class DataSourceV2ScanExec( override def producedAttributes: AttributeSet = AttributeSet(fullOutput) + override def outputPartitioning: physical.Partitioning = reader match { + case s: SupportsReportPartitioning => + new DataSourcePartitioning( + s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name))) + + case _ => super.outputPartitioning + } + private lazy val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match { case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks() case _ => diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java new file mode 100644 index 0000000000000..806d0bcd93f18 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.StructType; + +public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { + + class Reader implements DataSourceV2Reader, SupportsReportPartitioning { + private final StructType schema = new StructType().add("a", "int").add("b", "int"); + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public List> createReadTasks() { + return java.util.Arrays.asList( + new SpecificReadTask(new int[]{1, 1, 3}, new int[]{4, 4, 6}), + new SpecificReadTask(new int[]{2, 4, 4}, new int[]{6, 2, 2})); + } + + @Override + public Partitioning outputPartitioning() { + return new MyPartitioning(); + } + } + + static class MyPartitioning implements Partitioning { + + @Override + public int numPartitions() { + return 2; + } + + @Override + public boolean satisfy(Distribution distribution) { + if (distribution instanceof ClusteredDistribution) { + String[] clusteredCols = ((ClusteredDistribution) distribution).clusteredColumns; + return Arrays.asList(clusteredCols).contains("a"); + } + + return false; + } + } + + static class SpecificReadTask implements ReadTask, DataReader { + private int[] i; + private int[] j; + private int current = -1; + + SpecificReadTask(int[] i, int[] j) { + assert i.length == j.length; + this.i = i; + this.j = j; + } + + @Override + public boolean next() throws IOException { + current += 1; + return current < i.length; + } + + @Override + public Row get() { + return new GenericRow(new Object[] {i[current], j[current]}); + } + + @Override + public void close() throws IOException { + + } + + @Override + public DataReader createDataReader() { + return this; + } + } + + @Override + public DataSourceV2Reader createReader(DataSourceV2Options options) { + return new Reader(); + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 0ca29524c6d05..0620693b35d16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -24,6 +24,7 @@ import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.sources.v2.reader._ @@ -95,6 +96,40 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } + test("partitioning reporting") { + import org.apache.spark.sql.functions.{count, sum} + Seq(classOf[PartitionAwareDataSource], classOf[JavaPartitionAwareDataSource]).foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2))) + + val groupByColA = df.groupBy('a).agg(sum('b)) + checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4))) + assert(groupByColA.queryExecution.executedPlan.collectFirst { + case e: ShuffleExchangeExec => e + }.isEmpty) + + val groupByColAB = df.groupBy('a, 'b).agg(count("*")) + checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2))) + assert(groupByColAB.queryExecution.executedPlan.collectFirst { + case e: ShuffleExchangeExec => e + }.isEmpty) + + val groupByColB = df.groupBy('b).agg(sum('a)) + checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5))) + assert(groupByColB.queryExecution.executedPlan.collectFirst { + case e: ShuffleExchangeExec => e + }.isDefined) + + val groupByAPlusB = df.groupBy('a + 'b).agg(count("*")) + checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1))) + assert(groupByAPlusB.queryExecution.executedPlan.collectFirst { + case e: ShuffleExchangeExec => e + }.isDefined) + } + } + } + test("simple writable data source") { // TODO: java implementation. Seq(classOf[SimpleWritableDataSource]).foreach { cls => @@ -365,3 +400,47 @@ class BatchReadTask(start: Int, end: Int) override def close(): Unit = batch.close() } + +class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceV2Reader with SupportsReportPartitioning { + override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") + + override def createReadTasks(): JList[ReadTask[Row]] = { + // Note that we don't have same value of column `a` across partitions. + java.util.Arrays.asList( + new SpecificReadTask(Array(1, 1, 3), Array(4, 4, 6)), + new SpecificReadTask(Array(2, 4, 4), Array(6, 2, 2))) + } + + override def outputPartitioning(): Partitioning = new MyPartitioning + } + + class MyPartitioning extends Partitioning { + override def numPartitions(): Int = 2 + + override def satisfy(distribution: Distribution): Boolean = distribution match { + case c: ClusteredDistribution => c.clusteredColumns.contains("a") + case _ => false + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader +} + +class SpecificReadTask(i: Array[Int], j: Array[Int]) extends ReadTask[Row] with DataReader[Row] { + assert(i.length == j.length) + + private var current = -1 + + override def createDataReader(): DataReader[Row] = this + + override def next(): Boolean = { + current += 1 + current < i.length + } + + override def get(): Row = Row(i(current), j(current)) + + override def close(): Unit = {} +} From b2ce17b4c9fea58140a57ca1846b2689b15c0d61 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 23 Jan 2018 14:11:30 +0900 Subject: [PATCH 0177/2461] [SPARK-22274][PYTHON][SQL] User-defined aggregation functions with pandas udf (full shuffle) ## What changes were proposed in this pull request? Add support for using pandas UDFs with groupby().agg(). This PR introduces a new type of pandas UDF - group aggregate pandas UDF. This type of UDF defines a transformation of multiple pandas Series -> a scalar value. Group aggregate pandas UDFs can be used with groupby().agg(). Note group aggregate pandas UDF doesn't support partial aggregation, i.e., a full shuffle is required. This PR doesn't support group aggregate pandas UDFs that return ArrayType, StructType or MapType. Support for these types is left for future PR. ## How was this patch tested? GroupbyAggPandasUDFTests Author: Li Jin Closes #19872 from icexelloss/SPARK-22274-groupby-agg. --- .../spark/api/python/PythonRunner.scala | 2 + python/pyspark/rdd.py | 1 + python/pyspark/sql/functions.py | 36 +- python/pyspark/sql/group.py | 33 +- python/pyspark/sql/tests.py | 486 +++++++++++++++++- python/pyspark/sql/udf.py | 13 +- python/pyspark/worker.py | 22 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 14 +- .../sql/catalyst/expressions}/PythonUDF.scala | 31 +- .../sql/catalyst/planning/patterns.scala | 12 +- .../spark/sql/RelationalGroupedDataset.scala | 1 - .../spark/sql/execution/SparkStrategies.scala | 29 +- .../python/AggregateInPandasExec.scala | 155 ++++++ .../execution/python/ExtractPythonUDFs.scala | 16 +- .../python/UserDefinedPythonFunction.scala | 2 +- 15 files changed, 792 insertions(+), 61 deletions(-) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/python => catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions}/PythonUDF.scala (60%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 1ec0e717fac29..29148a7ee558b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -39,12 +39,14 @@ private[spark] object PythonEvalType { val SQL_PANDAS_SCALAR_UDF = 200 val SQL_PANDAS_GROUP_MAP_UDF = 201 + val SQL_PANDAS_GROUP_AGG_UDF = 202 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" case SQL_BATCHED_UDF => "SQL_BATCHED_UDF" case SQL_PANDAS_SCALAR_UDF => "SQL_PANDAS_SCALAR_UDF" case SQL_PANDAS_GROUP_MAP_UDF => "SQL_PANDAS_GROUP_MAP_UDF" + case SQL_PANDAS_GROUP_AGG_UDF => "SQL_PANDAS_GROUP_AGG_UDF" } } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 1b3915548fb14..6b018c3a38444 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -70,6 +70,7 @@ class PythonEvalType(object): SQL_PANDAS_SCALAR_UDF = 200 SQL_PANDAS_GROUP_MAP_UDF = 201 + SQL_PANDAS_GROUP_AGG_UDF = 202 def portable_hash(x): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 961b3267b44cf..a291c9b71913f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2089,6 +2089,8 @@ class PandasUDFType(object): GROUP_MAP = PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF + GROUP_AGG = PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF + @since(1.3) def udf(f=None, returnType=StringType()): @@ -2159,7 +2161,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): 1. SCALAR A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`. - The returnType should be a primitive data type, e.g., `DoubleType()`. + The returnType should be a primitive data type, e.g., :class:`DoubleType`. The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and @@ -2221,6 +2223,35 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. seealso:: :meth:`pyspark.sql.GroupedData.apply` + 3. GROUP_AGG + + A group aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar + The `returnType` should be a primitive data type, e.g., :class:`DoubleType`. + The returned scalar can be either a python primitive type, e.g., `int` or `float` + or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. + + :class:`ArrayType`, :class:`MapType` and :class:`StructType` are currently not supported as + output types. + + Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_udf("double", PandasUDFType.GROUP_AGG) # doctest: +SKIP + ... def mean_udf(v): + ... return v.mean() + >>> df.groupby("id").agg(mean_udf(df['v'])).show() # doctest: +SKIP + +---+-----------+ + | id|mean_udf(v)| + +---+-----------+ + | 1| 1.5| + | 2| 6.0| + +---+-----------+ + + .. seealso:: :meth:`pyspark.sql.GroupedData.agg` + .. note:: The user-defined functions are considered deterministic by default. Due to optimization, duplicate invocations may be eliminated or the function may even be invoked more times than it is present in the query. If your function is not deterministic, call @@ -2267,7 +2298,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): raise ValueError("Invalid returnType: returnType can not be None") if eval_type not in [PythonEvalType.SQL_PANDAS_SCALAR_UDF, - PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF]: + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, + PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF]: raise ValueError("Invalid functionType: " "functionType must be one the values from PandasUDFType") diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 22061b83eb78c..f90a909d7c2b1 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -65,13 +65,27 @@ def __init__(self, jgd, df): def agg(self, *exprs): """Compute aggregates and returns the result as a :class:`DataFrame`. - The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`. + The available aggregate functions can be: + + 1. built-in aggregation functions, such as `avg`, `max`, `min`, `sum`, `count` + + 2. group aggregate pandas UDFs, created with :func:`pyspark.sql.functions.pandas_udf` + + .. note:: There is no partial aggregation with group aggregate UDFs, i.e., + a full shuffle is required. Also, all the data of a group will be loaded into + memory, so the user should be aware of the potential OOM risk if data is skewed + and certain groups are too large to fit in memory. + + .. seealso:: :func:`pyspark.sql.functions.pandas_udf` If ``exprs`` is a single :class:`dict` mapping from string to string, then the key is the column to perform aggregation on, and the value is the aggregate function. Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions. + .. note:: Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed + in a single call to this function. + :param exprs: a dict mapping from column name (string) to aggregate functions (string), or a list of :class:`Column`. @@ -82,6 +96,13 @@ def agg(self, *exprs): >>> from pyspark.sql import functions as F >>> sorted(gdf.agg(F.min(df.age)).collect()) [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> @pandas_udf('int', PandasUDFType.GROUP_AGG) # doctest: +SKIP + ... def min_udf(v): + ... return v.min() + >>> sorted(gdf.agg(min_udf(df.age)).collect()) # doctest: +SKIP + [Row(name=u'Alice', min_udf(age)=2), Row(name=u'Bob', min_udf(age)=5)] """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): @@ -204,16 +225,18 @@ def apply(self, udf): The user-defined function should take a `pandas.DataFrame` and return another `pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame` - to the user-function and the returned `pandas.DataFrame`s are combined as a + to the user-function and the returned `pandas.DataFrame` are combined as a :class:`DataFrame`. + The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the returnType of the pandas udf. - This function does not support partial aggregation, and requires shuffling all the data in - the :class:`DataFrame`. + .. note:: This function requires a full shuffle. all the data of a group will be loaded + into memory, so the user should be aware of the potential OOM risk if data is skewed + and certain groups are too large to fit in memory. :param udf: a group map user-defined function returned by - :meth:`pyspark.sql.functions.pandas_udf`. + :func:`pyspark.sql.functions.pandas_udf`. >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4fee2ecde391b..84e8eec71dd8a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -197,6 +197,12 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() cls.spark.stop() + def assertPandasEqual(self, expected, result): + msg = ("DataFrames are not equal: " + + "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) + + "\n\nResult:\n%s\n%s" % (result, result.dtypes)) + self.assertTrue(expected.equals(result), msg=msg) + class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 @@ -3371,12 +3377,6 @@ def tearDownClass(cls): time.tzset() ReusedSQLTestCase.tearDownClass() - def assertFramesEqual(self, df_with_arrow, df_without): - msg = ("DataFrame from Arrow is not equal" + - ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) + - ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes))) - self.assertTrue(df_without.equals(df_with_arrow), msg=msg) - def create_pandas_data_frame(self): import pandas as pd import numpy as np @@ -3414,7 +3414,7 @@ def _toPandas_arrow_toggle(self, df): def test_toPandas_arrow_toggle(self): df = self.spark.createDataFrame(self.data, schema=self.schema) pdf, pdf_arrow = self._toPandas_arrow_toggle(df) - self.assertFramesEqual(pdf_arrow, pdf) + self.assertPandasEqual(pdf_arrow, pdf) def test_toPandas_respect_session_timezone(self): df = self.spark.createDataFrame(self.data, schema=self.schema) @@ -3425,11 +3425,11 @@ def test_toPandas_respect_session_timezone(self): self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") try: pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) - self.assertFramesEqual(pdf_arrow_la, pdf_la) + self.assertPandasEqual(pdf_arrow_la, pdf_la) finally: self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df) - self.assertFramesEqual(pdf_arrow_ny, pdf_ny) + self.assertPandasEqual(pdf_arrow_ny, pdf_ny) self.assertFalse(pdf_ny.equals(pdf_la)) @@ -3439,7 +3439,7 @@ def test_toPandas_respect_session_timezone(self): if isinstance(field.dataType, TimestampType): pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz( pdf_la_corrected[field.name], timezone) - self.assertFramesEqual(pdf_ny, pdf_la_corrected) + self.assertPandasEqual(pdf_ny, pdf_la_corrected) finally: self.spark.conf.set("spark.sql.session.timeZone", orig_tz) @@ -3447,7 +3447,7 @@ def test_pandas_round_trip(self): pdf = self.create_pandas_data_frame() df = self.spark.createDataFrame(self.data, schema=self.schema) pdf_arrow = df.toPandas() - self.assertFramesEqual(pdf_arrow, pdf) + self.assertPandasEqual(pdf_arrow, pdf) def test_filtered_frame(self): df = self.spark.range(3).toDF("i") @@ -3505,7 +3505,7 @@ def test_createDataFrame_with_schema(self): df = self.spark.createDataFrame(pdf, schema=self.schema) self.assertEquals(self.schema, df.schema) pdf_arrow = df.toPandas() - self.assertFramesEqual(pdf_arrow, pdf) + self.assertPandasEqual(pdf_arrow, pdf) def test_createDataFrame_with_incorrect_schema(self): pdf = self.create_pandas_data_frame() @@ -3717,7 +3717,7 @@ def foo(k, v): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class VectorizedUDFTests(ReusedSQLTestCase): +class ScalarPandasUDF(ReusedSQLTestCase): @classmethod def setUpClass(cls): @@ -4196,13 +4196,7 @@ def test_register_vectorized_udf_basic(self): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class GroupbyApplyTests(ReusedSQLTestCase): - - def assertFramesEqual(self, expected, result): - msg = ("DataFrames are not equal: " + - ("\n\nExpected:\n%s\n%s" % (expected, expected.dtypes)) + - ("\n\nResult:\n%s\n%s" % (result, result.dtypes))) - self.assertTrue(expected.equals(result), msg=msg) +class GroupbyApplyPandasUDFTests(ReusedSQLTestCase): @property def data(self): @@ -4227,7 +4221,7 @@ def test_simple(self): result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_register_group_map_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4251,7 +4245,7 @@ def foo(pdf): result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_coerce(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4266,7 +4260,7 @@ def test_coerce(self): result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) expected = expected.assign(v=expected.v.astype('float64')) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_complex_groupby(self): from pyspark.sql.functions import pandas_udf, col, PandasUDFType @@ -4285,7 +4279,7 @@ def normalize(pdf): expected = pdf.groupby(pdf['id'] % 2 == 0).apply(normalize.func) expected = expected.sort_values(['id', 'v']).reset_index(drop=True) expected = expected.assign(norm=expected.norm.astype('float64')) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_empty_groupby(self): from pyspark.sql.functions import pandas_udf, col, PandasUDFType @@ -4304,7 +4298,7 @@ def normalize(pdf): expected = normalize.func(pdf) expected = expected.sort_values(['id', 'v']).reset_index(drop=True) expected = expected.assign(norm=expected.norm.astype('float64')) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_datatype_string(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4318,7 +4312,7 @@ def test_datatype_string(self): result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4370,6 +4364,446 @@ def test_unsupported_types(self): df.groupby('id').apply(f).collect() +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +class GroupbyAggPandasUDFTests(ReusedSQLTestCase): + + @property + def data(self): + from pyspark.sql.functions import array, explode, col, lit + return self.spark.range(10).toDF('id') \ + .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ + .withColumn("v", explode(col('vs'))) \ + .drop('vs') \ + .withColumn('w', lit(1.0)) + + @property + def python_plus_one(self): + from pyspark.sql.functions import udf + + @udf('double') + def plus_one(v): + assert isinstance(v, (int, float)) + return v + 1 + return plus_one + + @property + def pandas_scalar_plus_two(self): + import pandas as pd + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.SCALAR) + def plus_two(v): + assert isinstance(v, pd.Series) + return v + 2 + return plus_two + + @property + def pandas_agg_mean_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def avg(v): + return v.mean() + return avg + + @property + def pandas_agg_sum_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def sum(v): + return v.sum() + return sum + + @property + def pandas_agg_weighted_mean_udf(self): + import numpy as np + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def weighted_mean(v, w): + return np.average(v, weights=w) + return weighted_mean + + def test_manual(self): + df = self.data + sum_udf = self.pandas_agg_sum_udf + mean_udf = self.pandas_agg_mean_udf + + result1 = df.groupby('id').agg(sum_udf(df.v), mean_udf(df.v)).sort('id') + expected1 = self.spark.createDataFrame( + [[0, 245.0, 24.5], + [1, 255.0, 25.5], + [2, 265.0, 26.5], + [3, 275.0, 27.5], + [4, 285.0, 28.5], + [5, 295.0, 29.5], + [6, 305.0, 30.5], + [7, 315.0, 31.5], + [8, 325.0, 32.5], + [9, 335.0, 33.5]], + ['id', 'sum(v)', 'avg(v)']) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_basic(self): + from pyspark.sql.functions import col, lit, sum, mean + + df = self.data + weighted_mean_udf = self.pandas_agg_weighted_mean_udf + + # Groupby one column and aggregate one UDF with literal + result1 = df.groupby('id').agg(weighted_mean_udf(df.v, lit(1.0))).sort('id') + expected1 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort('id') + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + # Groupby one expression and aggregate one UDF with literal + result2 = df.groupby((col('id') + 1)).agg(weighted_mean_udf(df.v, lit(1.0)))\ + .sort(df.id + 1) + expected2 = df.groupby((col('id') + 1))\ + .agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort(df.id + 1) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + # Groupby one column and aggregate one UDF without literal + result3 = df.groupby('id').agg(weighted_mean_udf(df.v, df.w)).sort('id') + expected3 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, w)')).sort('id') + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + + # Groupby one expression and aggregate one UDF without literal + result4 = df.groupby((col('id') + 1).alias('id'))\ + .agg(weighted_mean_udf(df.v, df.w))\ + .sort('id') + expected4 = df.groupby((col('id') + 1).alias('id'))\ + .agg(mean(df.v).alias('weighted_mean(v, w)'))\ + .sort('id') + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + + def test_unsupported_types(self): + from pyspark.sql.types import ArrayType, DoubleType, MapType + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with QuietTest(self.sc): + with self.assertRaisesRegex(NotImplementedError, 'not supported'): + @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUP_AGG) + def mean_and_std_udf(v): + return [v.mean(), v.std()] + + with QuietTest(self.sc): + with self.assertRaisesRegex(NotImplementedError, 'not supported'): + @pandas_udf('mean double, std double', PandasUDFType.GROUP_AGG) + def mean_and_std_udf(v): + return v.mean(), v.std() + + with QuietTest(self.sc): + with self.assertRaisesRegex(NotImplementedError, 'not supported'): + @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUP_AGG) + def mean_and_std_udf(v): + return {v.mean(): v.std()} + + def test_alias(self): + from pyspark.sql.functions import mean + + df = self.data + mean_udf = self.pandas_agg_mean_udf + + result1 = df.groupby('id').agg(mean_udf(df.v).alias('mean_alias')) + expected1 = df.groupby('id').agg(mean(df.v).alias('mean_alias')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_mixed_sql(self): + """ + Test mixing group aggregate pandas UDF with sql expression. + """ + from pyspark.sql.functions import sum, mean + + df = self.data + sum_udf = self.pandas_agg_sum_udf + + # Mix group aggregate pandas UDF with sql expression + result1 = (df.groupby('id') + .agg(sum_udf(df.v) + 1) + .sort('id')) + expected1 = (df.groupby('id') + .agg(sum(df.v) + 1) + .sort('id')) + + # Mix group aggregate pandas UDF with sql expression (order swapped) + result2 = (df.groupby('id') + .agg(sum_udf(df.v + 1)) + .sort('id')) + + expected2 = (df.groupby('id') + .agg(sum(df.v + 1)) + .sort('id')) + + # Wrap group aggregate pandas UDF with two sql expressions + result3 = (df.groupby('id') + .agg(sum_udf(df.v + 1) + 2) + .sort('id')) + expected3 = (df.groupby('id') + .agg(sum(df.v + 1) + 2) + .sort('id')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + + def test_mixed_udfs(self): + """ + Test mixing group aggregate pandas UDF with python UDF and scalar pandas UDF. + """ + from pyspark.sql.functions import sum, mean + + df = self.data + plus_one = self.python_plus_one + plus_two = self.pandas_scalar_plus_two + sum_udf = self.pandas_agg_sum_udf + + # Mix group aggregate pandas UDF and python UDF + result1 = (df.groupby('id') + .agg(plus_one(sum_udf(df.v))) + .sort('id')) + expected1 = (df.groupby('id') + .agg(plus_one(sum(df.v))) + .sort('id')) + + # Mix group aggregate pandas UDF and python UDF (order swapped) + result2 = (df.groupby('id') + .agg(sum_udf(plus_one(df.v))) + .sort('id')) + expected2 = (df.groupby('id') + .agg(sum(plus_one(df.v))) + .sort('id')) + + # Mix group aggregate pandas UDF and scalar pandas UDF + result3 = (df.groupby('id') + .agg(sum_udf(plus_two(df.v))) + .sort('id')) + expected3 = (df.groupby('id') + .agg(sum(plus_two(df.v))) + .sort('id')) + + # Mix group aggregate pandas UDF and scalar pandas UDF (order swapped) + result4 = (df.groupby('id') + .agg(plus_two(sum_udf(df.v))) + .sort('id')) + expected4 = (df.groupby('id') + .agg(plus_two(sum(df.v))) + .sort('id')) + + # Wrap group aggregate pandas UDF with two python UDFs and use python UDF in groupby + result5 = (df.groupby(plus_one(df.id)) + .agg(plus_one(sum_udf(plus_one(df.v)))) + .sort('plus_one(id)')) + expected5 = (df.groupby(plus_one(df.id)) + .agg(plus_one(sum(plus_one(df.v)))) + .sort('plus_one(id)')) + + # Wrap group aggregate pandas UDF with two scala pandas UDF and user scala pandas UDF in + # groupby + result6 = (df.groupby(plus_two(df.id)) + .agg(plus_two(sum_udf(plus_two(df.v)))) + .sort('plus_two(id)')) + expected6 = (df.groupby(plus_two(df.id)) + .agg(plus_two(sum(plus_two(df.v)))) + .sort('plus_two(id)')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + self.assertPandasEqual(expected5.toPandas(), result5.toPandas()) + self.assertPandasEqual(expected6.toPandas(), result6.toPandas()) + + def test_multiple_udfs(self): + """ + Test multiple group aggregate pandas UDFs in one agg function. + """ + from pyspark.sql.functions import col, lit, sum, mean + + df = self.data + mean_udf = self.pandas_agg_mean_udf + sum_udf = self.pandas_agg_sum_udf + weighted_mean_udf = self.pandas_agg_weighted_mean_udf + + result1 = (df.groupBy('id') + .agg(mean_udf(df.v), + sum_udf(df.v), + weighted_mean_udf(df.v, df.w)) + .sort('id') + .toPandas()) + expected1 = (df.groupBy('id') + .agg(mean(df.v), + sum(df.v), + mean(df.v).alias('weighted_mean(v, w)')) + .sort('id') + .toPandas()) + + self.assertPandasEqual(expected1, result1) + + def test_complex_groupby(self): + from pyspark.sql.functions import lit, sum + + df = self.data + sum_udf = self.pandas_agg_sum_udf + plus_one = self.python_plus_one + plus_two = self.pandas_scalar_plus_two + + # groupby one expression + result1 = df.groupby(df.v % 2).agg(sum_udf(df.v)) + expected1 = df.groupby(df.v % 2).agg(sum(df.v)) + + # empty groupby + result2 = df.groupby().agg(sum_udf(df.v)) + expected2 = df.groupby().agg(sum(df.v)) + + # groupby one column and one sql expression + result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)) + expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)) + + # groupby one python UDF + result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v)) + expected4 = df.groupby(plus_one(df.id)).agg(sum(df.v)) + + # groupby one scalar pandas UDF + result5 = df.groupby(plus_two(df.id)).agg(sum_udf(df.v)) + expected5 = df.groupby(plus_two(df.id)).agg(sum(df.v)) + + # groupby one expression and one python UDF + result6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum_udf(df.v)) + expected6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum(df.v)) + + # groupby one expression and one scalar pandas UDF + result7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum_udf(df.v)).sort('sum(v)') + expected7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum(df.v)).sort('sum(v)') + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + self.assertPandasEqual(expected5.toPandas(), result5.toPandas()) + self.assertPandasEqual(expected6.toPandas(), result6.toPandas()) + self.assertPandasEqual(expected7.toPandas(), result7.toPandas()) + + def test_complex_expressions(self): + from pyspark.sql.functions import col, sum + + df = self.data + plus_one = self.python_plus_one + plus_two = self.pandas_scalar_plus_two + sum_udf = self.pandas_agg_sum_udf + + # Test complex expressions with sql expression, python UDF and + # group aggregate pandas UDF + result1 = (df.withColumn('v1', plus_one(df.v)) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum_udf(col('v')), + sum_udf(col('v1') + 3), + sum_udf(col('v2')) + 5, + plus_one(sum_udf(col('v1'))), + sum_udf(plus_one(col('v2')))) + .sort('id') + .toPandas()) + + expected1 = (df.withColumn('v1', df.v + 1) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum(col('v')), + sum(col('v1') + 3), + sum(col('v2')) + 5, + plus_one(sum(col('v1'))), + sum(plus_one(col('v2')))) + .sort('id') + .toPandas()) + + # Test complex expressions with sql expression, scala pandas UDF and + # group aggregate pandas UDF + result2 = (df.withColumn('v1', plus_one(df.v)) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum_udf(col('v')), + sum_udf(col('v1') + 3), + sum_udf(col('v2')) + 5, + plus_two(sum_udf(col('v1'))), + sum_udf(plus_two(col('v2')))) + .sort('id') + .toPandas()) + + expected2 = (df.withColumn('v1', df.v + 1) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum(col('v')), + sum(col('v1') + 3), + sum(col('v2')) + 5, + plus_two(sum(col('v1'))), + sum(plus_two(col('v2')))) + .sort('id') + .toPandas()) + + # Test sequential groupby aggregate + result3 = (df.groupby('id') + .agg(sum_udf(df.v).alias('v')) + .groupby('id') + .agg(sum_udf(col('v'))) + .sort('id') + .toPandas()) + + expected3 = (df.groupby('id') + .agg(sum(df.v).alias('v')) + .groupby('id') + .agg(sum(col('v'))) + .sort('id') + .toPandas()) + + self.assertPandasEqual(expected1, result1) + self.assertPandasEqual(expected2, result2) + self.assertPandasEqual(expected3, result3) + + def test_retain_group_columns(self): + from pyspark.sql.functions import sum, lit, col + orig_value = self.spark.conf.get("spark.sql.retainGroupColumns", None) + self.spark.conf.set("spark.sql.retainGroupColumns", False) + try: + df = self.data + sum_udf = self.pandas_agg_sum_udf + + result1 = df.groupby(df.id).agg(sum_udf(df.v)) + expected1 = df.groupby(df.id).agg(sum(df.v)) + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + finally: + if orig_value is None: + self.spark.conf.unset("spark.sql.retainGroupColumns") + else: + self.spark.conf.set("spark.sql.retainGroupColumns", orig_value) + + def test_invalid_args(self): + from pyspark.sql.functions import mean + + df = self.data + plus_one = self.python_plus_one + mean_udf = self.pandas_agg_mean_udf + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + 'nor.*aggregate function'): + df.groupby(df.id).agg(plus_one(df.v)).collect() + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + 'aggregate function.*argument.*aggregate function'): + df.groupby(df.id).agg(mean_udf(mean_udf(df.v))).collect() + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + 'mixture.*aggregate function.*group aggregate pandas UDF'): + df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 134badb8485f5..de96846c5c774 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -22,7 +22,8 @@ from pyspark import SparkContext, since from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column, _to_seq -from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string +from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \ + _parse_datatype_string __all__ = ["UDFRegistration"] @@ -36,8 +37,10 @@ def _wrap_function(sc, func, returnType): def _create_udf(f, returnType, evalType): - if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF or \ - evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + if evalType in (PythonEvalType.SQL_PANDAS_SCALAR_UDF, + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, + PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF): + import inspect from pyspark.sql.utils import require_minimum_pyarrow_version @@ -113,6 +116,10 @@ def returnType(self): and not isinstance(self._returnType_placeholder, StructType): raise ValueError("Invalid returnType: returnType must be a StructType for " "pandas_udf with function type GROUP_MAP") + elif self.evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF \ + and isinstance(self._returnType_placeholder, (StructType, ArrayType, MapType)): + raise NotImplementedError( + "ArrayType, StructType and MapType are not supported with PandasUDFType.GROUP_AGG") return self._returnType_placeholder diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index e6737ae1c1285..173d8fb2856fa 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -110,6 +110,17 @@ def wrapped(*series): return wrapped +def wrap_pandas_group_agg_udf(f, return_type): + arrow_return_type = to_arrow_type(return_type) + + def wrapped(*series): + import pandas as pd + result = f(*series) + return pd.Series(result) + + return lambda *a: (wrapped(*a), arrow_return_type) + + def read_single_udf(pickleSer, infile, eval_type): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] @@ -126,8 +137,12 @@ def read_single_udf(pickleSer, infile, eval_type): return arg_offsets, wrap_pandas_scalar_udf(row_func, return_type) elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: return arg_offsets, wrap_pandas_group_map_udf(row_func, return_type) - else: + elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF: + return arg_offsets, wrap_pandas_group_agg_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_BATCHED_UDF: return arg_offsets, wrap_udf(row_func, return_type) + else: + raise ValueError("Unknown eval type: {}".format(eval_type)) def read_udfs(pickleSer, infile, eval_type): @@ -148,8 +163,9 @@ def read_udfs(pickleSer, infile, eval_type): func = lambda _, it: map(mapper, it) - if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF \ - or eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + if eval_type in (PythonEvalType.SQL_PANDAS_SCALAR_UDF, + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, + PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF): timezone = utf8_deserializer.loads(infile) ser = ArrowStreamPandasSerializer(timezone) else: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index bbcec5627bd49..ef91d79f3302c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -153,11 +153,19 @@ trait CheckAnalysis extends PredicateHelper { s"of type ${condition.dataType.simpleString} is not a boolean.") case Aggregate(groupingExprs, aggregateExprs, child) => + def isAggregateExpression(expr: Expression) = { + expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupAggPandasUDF(expr) + } + def checkValidAggregateExpression(expr: Expression): Unit = expr match { - case aggExpr: AggregateExpression => - aggExpr.aggregateFunction.children.foreach { child => + case expr: Expression if isAggregateExpression(expr) => + val aggFunction = expr match { + case agg: AggregateExpression => agg.aggregateFunction + case udf: PythonUDF => udf + } + aggFunction.children.foreach { child => child.foreach { - case agg: AggregateExpression => + case expr: Expression if isAggregateExpression(expr) => failAnalysis( s"It is not allowed to use an aggregate function in the argument of " + s"another aggregate function. Please use the inner aggregate function " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala similarity index 60% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index d3f743d9eb61e..4ba8ff6e3802f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -15,12 +15,31 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.python +package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.api.python.PythonFunction -import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable, UserDefinedExpression} +import org.apache.spark.api.python.{PythonEvalType, PythonFunction} +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.types.DataType +/** + * Helper functions for [[PythonUDF]] + */ +object PythonUDF { + private[this] val SCALAR_TYPES = Set( + PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_PANDAS_SCALAR_UDF + ) + + def isScalarPythonUDF(e: Expression): Boolean = { + e.isInstanceOf[PythonUDF] && SCALAR_TYPES.contains(e.asInstanceOf[PythonUDF].evalType) + } + + def isGroupAggPandasUDF(e: Expression): Boolean = { + e.isInstanceOf[PythonUDF] && + e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF + } +} + /** * A serialized version of a Python lambda function. */ @@ -30,12 +49,16 @@ case class PythonUDF( dataType: DataType, children: Seq[Expression], evalType: Int, - udfDeterministic: Boolean) + udfDeterministic: Boolean, + resultId: ExprId = NamedExpression.newExprId) extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) override def toString: String = s"$name(${children.mkString(", ")})" + lazy val resultAttribute: Attribute = AttributeReference(toPrettySQL(this), dataType, nullable)( + exprId = resultId) + override def nullable: Boolean = true } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index cc391aae55787..132241061d510 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.planning +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression @@ -199,7 +200,7 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { object PhysicalAggregation { // groupingExpressions, aggregateExpressions, resultExpressions, child type ReturnType = - (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan) + (Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan) def unapply(a: Any): Option[ReturnType] = a match { case logical.Aggregate(groupingExpressions, resultExpressions, child) => @@ -213,7 +214,10 @@ object PhysicalAggregation { expr.collect { // addExpr() always returns false for non-deterministic expressions and do not add them. case agg: AggregateExpression - if (!equivalentAggregateExpressions.addExpr(agg)) => agg + if !equivalentAggregateExpressions.addExpr(agg) => agg + case udf: PythonUDF + if PythonUDF.isGroupAggPandasUDF(udf) && + !equivalentAggregateExpressions.addExpr(udf) => udf } } @@ -241,6 +245,10 @@ object PhysicalAggregation { // so replace each aggregate expression by its corresponding attribute in the set: equivalentAggregateExpressions.getEquivalentExprs(ae).headOption .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute + // Similar to AggregateExpression + case ue: PythonUDF if PythonUDF.isGroupAggPandasUDF(ue) => + equivalentAggregateExpressions.getEquivalentExprs(ue).headOption + .getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute case expression => // Since we're using `namedGroupingAttributes` to extract the grouping key // columns, we need to replace grouping key expressions with their corresponding diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index a009c00b0abc5..d320c1c359411 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.execution.python.PythonUDF import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{NumericType, StructType} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 910294853c318..ce512bc46563a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.{execution, AnalysisException, Strategy} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -288,9 +289,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case PhysicalAggregation( namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => + if (aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF)) { + throw new AnalysisException( + "Streaming aggregation doesn't support group aggregate pandas UDF") + } + aggregate.AggUtils.planStreamingAggregation( namedGroupingExpressions, - aggregateExpressions, + aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), rewrittenResultExpressions, planLater(child)) @@ -333,8 +339,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object Aggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalAggregation( - groupingExpressions, aggregateExpressions, resultExpressions, child) => + case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) + if aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression]) => + val aggregateExpressions = aggExpressions.map(expr => + expr.asInstanceOf[AggregateExpression]) val (functionsWithDistinct, functionsWithoutDistinct) = aggregateExpressions.partition(_.isDistinct) @@ -363,6 +371,21 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { aggregateOperator + case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) + if aggExpressions.forall(expr => expr.isInstanceOf[PythonUDF]) => + val udfExpressions = aggExpressions.map(expr => expr.asInstanceOf[PythonUDF]) + + Seq(execution.python.AggregateInPandasExec( + groupingExpressions, + udfExpressions, + resultExpressions, + planLater(child))) + + case PhysicalAggregation(_, _, _, _) => + // If cannot match the two cases above, then it's an error + throw new AnalysisException( + "Cannot use a mixture of aggregate function and group aggregate pandas UDF") + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala new file mode 100644 index 0000000000000..18e5f8605c60d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.File + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.util.Utils + +/** + * Physical node for aggregation with group aggregate Pandas UDF. + * + * This plan works by sending the necessary (projected) input grouped data as Arrow record batches + * to the python worker, the python worker invokes the UDF and sends the results to the executor, + * finally the executor evaluates any post-aggregation expressions and join the result with the + * grouped key. + */ +case class AggregateInPandasExec( + groupingExpressions: Seq[NamedExpression], + udfExpressions: Seq[PythonUDF], + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryExecNode { + + override val output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def requiredChildDistribution: Seq[Distribution] = { + if (groupingExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + } + + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { + udf.children match { + case Seq(u: PythonUDF) => + val (chained, children) = collectFunctions(u) + (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) + case children => + // There should not be any other UDFs, or the children can't be evaluated directly. + assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) + (ChainedPythonFunctions(Seq(udf.func)), udf.children) + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingExpressions.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute() + + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + val sessionLocalTimeZone = conf.sessionLocalTimeZone + val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + + val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip + + // Filter child output attributes down to only those that are UDF inputs. + // Also eliminate duplicate UDF inputs. + val allInputs = new ArrayBuffer[Expression] + val dataTypes = new ArrayBuffer[DataType] + val argOffsets = inputs.map { input => + input.map { e => + if (allInputs.exists(_.semanticEquals(e))) { + allInputs.indexWhere(_.semanticEquals(e)) + } else { + allInputs += e + dataTypes += e.dataType + allInputs.length - 1 + } + }.toArray + }.toArray + + // Schema of input rows to the python runner + val aggInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => + StructField(s"_$i", dt) + }) + + inputRDD.mapPartitionsInternal { iter => + val prunedProj = UnsafeProjection.create(allInputs, child.output) + + val grouped = if (groupingExpressions.isEmpty) { + // Use an empty unsafe row as a place holder for the grouping key + Iterator((new UnsafeRow(), iter)) + } else { + GroupedIterator(iter, groupingExpressions, child.output) + }.map { case (key, rows) => + (key, rows.map(prunedProj)) + } + + val context = TaskContext.get() + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = HybridRowQueue(context.taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingExpressions.length) + context.addTaskCompletionListener { _ => + queue.close() + } + + // Add rows to queue to join later with the result. + val projectedRowIter = grouped.map { case (groupingKey, rows) => + queue.add(groupingKey.asInstanceOf[UnsafeRow]) + rows + } + + val columnarBatchIter = new ArrowPythonRunner( + pyFuncs, bufferSize, reuseWorker, + PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF, argOffsets, aggInputSchema, + sessionLocalTimeZone, pandasRespectSessionTimeZone) + .compute(projectedRowIter, context.partitionId(), context) + + val joinedAttributes = + groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) + val joined = new JoinedRow + val resultProj = UnsafeProjection.create(resultExpressions, joinedAttributes) + + columnarBatchIter.map(_.rowIterator.next()).map { aggOutputRow => + val leftRow = queue.remove() + val joinedRow = joined(leftRow, aggOutputRow) + resultProj(joinedRow) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 2f53fe788c7d0..1862e3f6e12ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -39,12 +39,13 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { */ private def belongAggregate(e: Expression, agg: Aggregate): Boolean = { e.isInstanceOf[AggregateExpression] || + PythonUDF.isGroupAggPandasUDF(e) || agg.groupingExpressions.exists(_.semanticEquals(e)) } private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = { expr.find { - e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined + e => PythonUDF.isScalarPythonUDF(e) && e.find(belongAggregate(_, agg)).isDefined }.isDefined } @@ -93,7 +94,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { private def hasPythonUDF(e: Expression): Boolean = { - e.find(_.isInstanceOf[PythonUDF]).isDefined + e.find(PythonUDF.isScalarPythonUDF).isDefined } private def canEvaluateInPython(e: PythonUDF): Boolean = { @@ -106,12 +107,12 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { } private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match { - case udf: PythonUDF if canEvaluateInPython(udf) => Seq(udf) + case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf) case e => e.children.flatMap(collectEvaluatableUDF) } def apply(plan: SparkPlan): SparkPlan = plan transformUp { - // FlatMapGroupsInPandas can be evaluated directly in python worker + // AggregateInPandasExec and FlatMapGroupsInPandas can be evaluated directly in python worker // Therefore we don't need to extract the UDFs case plan: FlatMapGroupsInPandasExec => plan case plan: SparkPlan => extract(plan) @@ -149,10 +150,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { udf.references.subsetOf(child.outputSet) } if (validUdfs.nonEmpty) { - require(validUdfs.forall(udf => - udf.evalType == PythonEvalType.SQL_BATCHED_UDF || - udf.evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF - ), "Can only extract scalar vectorized udf or sql batch udf") + require( + validUdfs.forall(PythonUDF.isScalarPythonUDF), + "Can only extract scalar vectorized udf or sql batch udf") val resultAttrs = udfs.zipWithIndex.map { case (u, i) => AttributeReference(s"pythonUDF$i", u.dataType)() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 50dca32cb7861..f4c2d02ee9420 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.python import org.apache.spark.api.python.PythonFunction import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDF} import org.apache.spark.sql.types.DataType /** From 96cb60bc33936c1aaf728a1738781073891480ff Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 23 Jan 2018 04:08:32 -0800 Subject: [PATCH 0178/2461] [SPARK-22465][FOLLOWUP] Update the number of partitions of default partitioner when defaultParallelism is set ## What changes were proposed in this pull request? #20002 purposed a way to safe check the default partitioner, however, if `spark.default.parallelism` is set, the defaultParallelism still could be smaller than the proper number of partitions for upstreams RDDs. This PR tries to extend the approach to address the condition when `spark.default.parallelism` is set. The requirements where the PR helps with are : - Max partitioner is not eligible since it is atleast an order smaller, and - User has explicitly set 'spark.default.parallelism', and - Value of 'spark.default.parallelism' is lower than max partitioner - Since max partitioner was discarded due to being at least an order smaller, default parallelism is worse - even though user specified. Under the rest cases, the changes should be no-op. ## How was this patch tested? Add corresponding test cases in `PairRDDFunctionsSuite` and `PartitioningSuite`. Author: Xingbo Jiang Closes #20091 from jiangxb1987/partitioner. --- .../scala/org/apache/spark/Partitioner.scala | 51 ++++++++++--------- .../org/apache/spark/PartitioningSuite.scala | 44 +++++++++++++--- .../spark/rdd/PairRDDFunctionsSuite.scala | 45 +++++++++++++++- 3 files changed, 108 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 437bbaae1968b..c940cb25d478b 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -43,17 +43,19 @@ object Partitioner { /** * Choose a partitioner to use for a cogroup-like operation between a number of RDDs. * - * If any of the RDDs already has a partitioner, and the number of partitions of the - * partitioner is either greater than or is less than and within a single order of - * magnitude of the max number of upstream partitions, choose that one. + * If spark.default.parallelism is set, we'll use the value of SparkContext defaultParallelism + * as the default partitions number, otherwise we'll use the max number of upstream partitions. * - * Otherwise, we use a default HashPartitioner. For the number of partitions, if - * spark.default.parallelism is set, then we'll use the value from SparkContext - * defaultParallelism, otherwise we'll use the max number of upstream partitions. + * When available, we choose the partitioner from rdds with maximum number of partitions. If this + * partitioner is eligible (number of partitions within an order of maximum number of partitions + * in rdds), or has partition number higher than default partitions number - we use this + * partitioner. * - * Unless spark.default.parallelism is set, the number of partitions will be the - * same as the number of partitions in the largest upstream RDD, as this should - * be least likely to cause out-of-memory errors. + * Otherwise, we'll use a new HashPartitioner with the default partitions number. + * + * Unless spark.default.parallelism is set, the number of partitions will be the same as the + * number of partitions in the largest upstream RDD, as this should be least likely to cause + * out-of-memory errors. * * We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD. */ @@ -67,31 +69,32 @@ object Partitioner { None } - if (isEligiblePartitioner(hasMaxPartitioner, rdds)) { + val defaultNumPartitions = if (rdd.context.conf.contains("spark.default.parallelism")) { + rdd.context.defaultParallelism + } else { + rdds.map(_.partitions.length).max + } + + // If the existing max partitioner is an eligible one, or its partitions number is larger + // than the default number of partitions, use the existing partitioner. + if (hasMaxPartitioner.nonEmpty && (isEligiblePartitioner(hasMaxPartitioner.get, rdds) || + defaultNumPartitions < hasMaxPartitioner.get.getNumPartitions)) { hasMaxPartitioner.get.partitioner.get } else { - if (rdd.context.conf.contains("spark.default.parallelism")) { - new HashPartitioner(rdd.context.defaultParallelism) - } else { - new HashPartitioner(rdds.map(_.partitions.length).max) - } + new HashPartitioner(defaultNumPartitions) } } /** - * Returns true if the number of partitions of the RDD is either greater - * than or is less than and within a single order of magnitude of the - * max number of upstream partitions; - * otherwise, returns false + * Returns true if the number of partitions of the RDD is either greater than or is less than and + * within a single order of magnitude of the max number of upstream partitions, otherwise returns + * false. */ private def isEligiblePartitioner( - hasMaxPartitioner: Option[RDD[_]], + hasMaxPartitioner: RDD[_], rdds: Seq[RDD[_]]): Boolean = { - if (hasMaxPartitioner.isEmpty) { - return false - } val maxPartitions = rdds.map(_.partitions.length).max - log10(maxPartitions) - log10(hasMaxPartitioner.get.getNumPartitions) < 1 + log10(maxPartitions) - log10(hasMaxPartitioner.getNumPartitions) < 1 } } diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 155ca17db726b..9206b5debf4f3 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -262,14 +262,11 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva test("defaultPartitioner") { val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 150) - val rdd2 = sc - .parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + val rdd2 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) .partitionBy(new HashPartitioner(10)) - val rdd3 = sc - .parallelize(Array((1, 6), (7, 8), (3, 10), (5, 12), (13, 14))) + val rdd3 = sc.parallelize(Array((1, 6), (7, 8), (3, 10), (5, 12), (13, 14))) .partitionBy(new HashPartitioner(100)) - val rdd4 = sc - .parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + val rdd4 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) .partitionBy(new HashPartitioner(9)) val rdd5 = sc.parallelize((1 to 10).map(x => (x, x)), 11) @@ -284,7 +281,42 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva assert(partitioner3.numPartitions == rdd3.getNumPartitions) assert(partitioner4.numPartitions == rdd3.getNumPartitions) assert(partitioner5.numPartitions == rdd4.getNumPartitions) + } + test("defaultPartitioner when defaultParallelism is set") { + assert(!sc.conf.contains("spark.default.parallelism")) + try { + sc.conf.set("spark.default.parallelism", "4") + + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 150) + val rdd2 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + .partitionBy(new HashPartitioner(10)) + val rdd3 = sc.parallelize(Array((1, 6), (7, 8), (3, 10), (5, 12), (13, 14))) + .partitionBy(new HashPartitioner(100)) + val rdd4 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + .partitionBy(new HashPartitioner(9)) + val rdd5 = sc.parallelize((1 to 10).map(x => (x, x)), 11) + val rdd6 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + .partitionBy(new HashPartitioner(3)) + + val partitioner1 = Partitioner.defaultPartitioner(rdd1, rdd2) + val partitioner2 = Partitioner.defaultPartitioner(rdd2, rdd3) + val partitioner3 = Partitioner.defaultPartitioner(rdd3, rdd1) + val partitioner4 = Partitioner.defaultPartitioner(rdd1, rdd2, rdd3) + val partitioner5 = Partitioner.defaultPartitioner(rdd4, rdd5) + val partitioner6 = Partitioner.defaultPartitioner(rdd5, rdd5) + val partitioner7 = Partitioner.defaultPartitioner(rdd1, rdd6) + + assert(partitioner1.numPartitions == rdd2.getNumPartitions) + assert(partitioner2.numPartitions == rdd3.getNumPartitions) + assert(partitioner3.numPartitions == rdd3.getNumPartitions) + assert(partitioner4.numPartitions == rdd3.getNumPartitions) + assert(partitioner5.numPartitions == rdd4.getNumPartitions) + assert(partitioner6.numPartitions == sc.defaultParallelism) + assert(partitioner7.numPartitions == sc.defaultParallelism) + } finally { + sc.conf.remove("spark.default.parallelism") + } } } diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index a39e0469272fe..47af5c3320dd9 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -322,8 +322,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } // See SPARK-22465 - test("cogroup between multiple RDD" + - " with number of partitions similar in order of magnitude") { + test("cogroup between multiple RDD with number of partitions similar in order of magnitude") { val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 20) val rdd2 = sc .parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) @@ -332,6 +331,48 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { assert(joined.getNumPartitions == rdd2.getNumPartitions) } + test("cogroup between multiple RDD when defaultParallelism is set without proper partitioner") { + assert(!sc.conf.contains("spark.default.parallelism")) + try { + sc.conf.set("spark.default.parallelism", "4") + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 20) + val rdd2 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)), 10) + val joined = rdd1.cogroup(rdd2) + assert(joined.getNumPartitions == sc.defaultParallelism) + } finally { + sc.conf.remove("spark.default.parallelism") + } + } + + test("cogroup between multiple RDD when defaultParallelism is set with proper partitioner") { + assert(!sc.conf.contains("spark.default.parallelism")) + try { + sc.conf.set("spark.default.parallelism", "4") + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 20) + val rdd2 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + .partitionBy(new HashPartitioner(10)) + val joined = rdd1.cogroup(rdd2) + assert(joined.getNumPartitions == rdd2.getNumPartitions) + } finally { + sc.conf.remove("spark.default.parallelism") + } + } + + test("cogroup between multiple RDD when defaultParallelism is set; with huge number of " + + "partitions in upstream RDDs") { + assert(!sc.conf.contains("spark.default.parallelism")) + try { + sc.conf.set("spark.default.parallelism", "4") + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 1000) + val rdd2 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + .partitionBy(new HashPartitioner(10)) + val joined = rdd1.cogroup(rdd2) + assert(joined.getNumPartitions == rdd2.getNumPartitions) + } finally { + sc.conf.remove("spark.default.parallelism") + } + } + test("rightOuterJoin") { val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) From ee572ba8c1339d21c592001ec4f7f270005ff1cf Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 23 Jan 2018 21:36:20 +0900 Subject: [PATCH 0179/2461] [SPARK-20749][SQL][FOLLOW-UP] Override prettyName for bit_length and octet_length ## What changes were proposed in this pull request? We need to override the prettyName for bit_length and octet_length for getting the expected auto-generated alias name. ## How was this patch tested? The existing tests Author: gatorsmile Closes #20358 from gatorsmile/test2.3More. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../expressions/stringExpressions.scala | 4 ++ .../sql-tests/results/operators.sql.out | 4 +- .../scalar-subquery-predicate.sql.out | 45 ++++++++++--------- 4 files changed, 30 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 39d5e4ed56628..5fa75fe348e68 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -141,7 +141,7 @@ statement (LIKE? pattern=STRING)? #showTables | SHOW TABLE EXTENDED ((FROM | IN) db=identifier)? LIKE pattern=STRING partitionSpec? #showTable - | SHOW DATABASES (LIKE? pattern=STRING)? #showDatabases + | SHOW DATABASES (LIKE? pattern=STRING)? #showDatabases | SHOW TBLPROPERTIES table=tableIdentifier ('(' key=tablePropertyKey ')')? #showTblProperties | SHOW COLUMNS (FROM | IN) tableIdentifier diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index e004bfc6af473..5cf783f1a5979 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1708,6 +1708,8 @@ case class BitLength(child: Expression) extends UnaryExpression with ImplicitCas case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length * 8") } } + + override def prettyName: String = "bit_length" } /** @@ -1735,6 +1737,8 @@ case class OctetLength(child: Expression) extends UnaryExpression with ImplicitC case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") } } + + override def prettyName: String = "octet_length" } /** diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 237b618a8b904..840655b7a6447 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -425,7 +425,7 @@ struct<(7 % 2):int,(7 % 0):int,(0 % 2):int,(7 % CAST(NULL AS INT)):int,(CAST(NUL -- !query 51 select BIT_LENGTH('abc') -- !query 51 schema -struct +struct -- !query 51 output 24 @@ -449,7 +449,7 @@ struct -- !query 54 select OCTET_LENGTH('abc') -- !query 54 schema -struct +struct -- !query 54 output 3 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out index a2b86db3e4f4c..dd82efba0dde1 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 29 +-- Number of queries: 27 -- !query 0 @@ -307,7 +307,8 @@ struct val1c val1d --- !query 22 + +-- !query 20 SELECT count(t1a) FROM t1 RIGHT JOIN t2 ON t1d = t2d @@ -315,13 +316,13 @@ WHERE t1a < (SELECT max(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 22 schema +-- !query 20 schema struct --- !query 22 output +-- !query 20 output 7 --- !query 23 +-- !query 21 SELECT t1a FROM t1 WHERE t1b <= (SELECT max(t2b) @@ -332,14 +333,14 @@ AND t1b >= (SELECT min(t2b) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 23 schema +-- !query 21 schema struct --- !query 23 output +-- !query 21 output val1b val1c --- !query 24 +-- !query 22 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -353,14 +354,14 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 24 schema +-- !query 22 schema struct --- !query 24 output +-- !query 22 output val1b val1c --- !query 25 +-- !query 23 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -374,9 +375,9 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 25 schema +-- !query 23 schema struct --- !query 25 output +-- !query 23 output val1a val1a val1b @@ -387,7 +388,7 @@ val1d val1d --- !query 26 +-- !query 24 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -401,16 +402,16 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 26 schema +-- !query 24 schema struct --- !query 26 output +-- !query 24 output val1a val1b val1c val1d --- !query 27 +-- !query 25 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -424,13 +425,13 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 27 schema +-- !query 25 schema struct --- !query 27 output +-- !query 25 output val1a --- !query 28 +-- !query 26 SELECT t1a FROM t1 GROUP BY t1a, t1c @@ -438,8 +439,8 @@ HAVING max(t1b) <= (SELECT max(t2b) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 28 schema +-- !query 26 schema struct --- !query 28 output +-- !query 26 output val1b val1c From bdebb8e48eafcca0382d1a3173b2f3ce969abab3 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 23 Jan 2018 10:12:13 -0800 Subject: [PATCH 0180/2461] [SPARK-20664][SPARK-23103][CORE] Follow-up: remove workaround for . Author: Marcelo Vanzin Closes #20353 from vanzin/SPARK-20664. --- .../apache/spark/deploy/history/FsHistoryProviderSuite.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 787de59edf465..fde5f25bce456 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -716,9 +716,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } test("SPARK-21571: clean up removes invalid history files") { - // TODO: "maxTime" becoming negative in cleanLogs() causes this test to fail, so avoid that - // until we figure out what's causing the problem. - val clock = new ManualClock(TimeUnit.DAYS.toMillis(120)) + val clock = new ManualClock() val conf = createTestConf().set(MAX_LOG_AGE_S.key, s"2d") val provider = new FsHistoryProvider(conf, clock) { override def getNewLastScanTime(): Long = clock.getTimeMillis() From dc4761fd8f0eec1d001e53837e65f7c5fe4e248d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 23 Jan 2018 12:51:40 -0800 Subject: [PATCH 0181/2461] [SPARK-17088][HIVE] Fix 'sharesHadoopClasses' option when creating client. Because the call to the constructor of HiveClientImpl crosses class loader boundaries, different versions of the same class (Configuration in this case) were loaded, and that caused a runtime error when instantiating the client. By using a safer type in the signature of the constructor, it's possible to avoid the problem. I considered removing 'sharesHadoopClasses', but it may still be desired (even though there are 0 users of it since it was not working). When Spark starts to support Hadoop 3, it may be necessary to use that option to load clients for older Hive metastore versions that don't know about Hadoop 3. Tested with added unit test. Author: Marcelo Vanzin Closes #20169 from vanzin/SPARK-17088. --- .../spark/sql/hive/client/HiveClientImpl.scala | 8 +++++--- .../sql/hive/client/IsolatedClientLoader.scala | 16 ++++++++++------ .../sql/hive/client/HiveClientBuilder.scala | 6 ++++-- .../spark/sql/hive/client/HiveClientSuite.scala | 4 ++++ .../spark/sql/hive/client/HiveVersionSuite.scala | 11 ++++++++--- 5 files changed, 31 insertions(+), 14 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 4b923f5235a90..39d839059be75 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.hive.client import java.io.{File, PrintStream} -import java.util.Locale +import java.lang.{Iterable => JIterable} +import java.util.{Locale, Map => JMap} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -82,8 +83,9 @@ import org.apache.spark.util.{CircularBuffer, Utils} */ private[hive] class HiveClientImpl( override val version: HiveVersion, + warehouseDir: Option[String], sparkConf: SparkConf, - hadoopConf: Configuration, + hadoopConf: JIterable[JMap.Entry[String, String]], extraConfig: Map[String, String], initClassLoader: ClassLoader, val clientLoader: IsolatedClientLoader) @@ -130,7 +132,7 @@ private[hive] class HiveClientImpl( if (ret != null) { // hive.metastore.warehouse.dir is determined in SharedState after the CliSessionState // instance constructed, we need to follow that change here. - Option(hadoopConf.get(ConfVars.METASTOREWAREHOUSE.varname)).foreach { dir => + warehouseDir.foreach { dir => ret.getConf.setVar(ConfVars.METASTOREWAREHOUSE, dir) } ret diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 7a76fd3fd2eb3..dac0e333b63bc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -26,6 +26,7 @@ import scala.util.Try import org.apache.commons.io.{FileUtils, IOUtils} import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkSubmitUtils @@ -48,11 +49,12 @@ private[hive] object IsolatedClientLoader extends Logging { config: Map[String, String] = Map.empty, ivyPath: Option[String] = None, sharedPrefixes: Seq[String] = Seq.empty, - barrierPrefixes: Seq[String] = Seq.empty): IsolatedClientLoader = synchronized { + barrierPrefixes: Seq[String] = Seq.empty, + sharesHadoopClasses: Boolean = true): IsolatedClientLoader = synchronized { val resolvedVersion = hiveVersion(hiveMetastoreVersion) // We will first try to share Hadoop classes. If we cannot resolve the Hadoop artifact // with the given version, we will use Hadoop 2.6 and then will not share Hadoop classes. - var sharesHadoopClasses = true + var _sharesHadoopClasses = sharesHadoopClasses val files = if (resolvedVersions.contains((resolvedVersion, hadoopVersion))) { resolvedVersions((resolvedVersion, hadoopVersion)) } else { @@ -68,7 +70,7 @@ private[hive] object IsolatedClientLoader extends Logging { "Hadoop classes will not be shared between Spark and Hive metastore client. " + "It is recommended to set jars used by Hive metastore client through " + "spark.sql.hive.metastore.jars in the production environment.") - sharesHadoopClasses = false + _sharesHadoopClasses = false (downloadVersion(resolvedVersion, "2.6.5", ivyPath), "2.6.5") } resolvedVersions.put((resolvedVersion, actualHadoopVersion), downloadedFiles) @@ -81,7 +83,7 @@ private[hive] object IsolatedClientLoader extends Logging { execJars = files, hadoopConf = hadoopConf, config = config, - sharesHadoopClasses = sharesHadoopClasses, + sharesHadoopClasses = _sharesHadoopClasses, sharedPrefixes = sharedPrefixes, barrierPrefixes = barrierPrefixes) } @@ -249,8 +251,10 @@ private[hive] class IsolatedClientLoader( /** The isolated client interface to Hive. */ private[hive] def createClient(): HiveClient = synchronized { + val warehouseDir = Option(hadoopConf.get(ConfVars.METASTOREWAREHOUSE.varname)) if (!isolationOn) { - return new HiveClientImpl(version, sparkConf, hadoopConf, config, baseClassLoader, this) + return new HiveClientImpl(version, warehouseDir, sparkConf, hadoopConf, config, + baseClassLoader, this) } // Pre-reflective instantiation setup. logDebug("Initializing the logger to avoid disaster...") @@ -261,7 +265,7 @@ private[hive] class IsolatedClientLoader( classLoader .loadClass(classOf[HiveClientImpl].getName) .getConstructors.head - .newInstance(version, sparkConf, hadoopConf, config, classLoader, this) + .newInstance(version, warehouseDir, sparkConf, hadoopConf, config, classLoader, this) .asInstanceOf[HiveClient] } catch { case e: InvocationTargetException => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala index ae804ce7c7b07..ab73f668c6ca6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala @@ -46,13 +46,15 @@ private[client] object HiveClientBuilder { def buildClient( version: String, hadoopConf: Configuration, - extraConf: Map[String, String] = Map.empty): HiveClient = { + extraConf: Map[String, String] = Map.empty, + sharesHadoopClasses: Boolean = true): HiveClient = { IsolatedClientLoader.forVersion( hiveMetastoreVersion = version, hadoopVersion = VersionInfo.getVersion, sparkConf = new SparkConf(), hadoopConf = hadoopConf, config = buildConf(extraConf), - ivyPath = ivyPath).createClient() + ivyPath = ivyPath, + sharesHadoopClasses = sharesHadoopClasses).createClient() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index a5dfd89b3a574..f991352b207d4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -202,6 +202,10 @@ class HiveClientSuite(version: String) day1 :: day2 :: Nil) } + test("create client with sharesHadoopClasses = false") { + buildClient(new Configuration(), sharesHadoopClasses = false) + } + private def testMetastorePartitionFiltering( filterString: String, expectedDs: Seq[Int], diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala index bb8a4697b0a13..a70fb6464cc1d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala @@ -28,7 +28,9 @@ private[client] abstract class HiveVersionSuite(version: String) extends SparkFu override protected val enableAutoThreadAudit = false protected var client: HiveClient = null - protected def buildClient(hadoopConf: Configuration): HiveClient = { + protected def buildClient( + hadoopConf: Configuration, + sharesHadoopClasses: Boolean = true): HiveClient = { // Hive changed the default of datanucleus.schema.autoCreateAll from true to false and // hive.metastore.schema.verification from false to true since 2.0 // For details, see the JIRA HIVE-6113 and HIVE-12463 @@ -36,8 +38,11 @@ private[client] abstract class HiveVersionSuite(version: String) extends SparkFu hadoopConf.set("datanucleus.schema.autoCreateAll", "true") hadoopConf.set("hive.metastore.schema.verification", "false") } - HiveClientBuilder - .buildClient(version, hadoopConf, HiveUtils.formatTimeVarsForHiveClient(hadoopConf)) + HiveClientBuilder.buildClient( + version, + hadoopConf, + HiveUtils.formatTimeVarsForHiveClient(hadoopConf), + sharesHadoopClasses = sharesHadoopClasses) } override def suiteName: String = s"${super.suiteName}($version)" From 05839d164836e544af79c13de25802552eadd636 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Tue, 23 Jan 2018 14:11:23 -0800 Subject: [PATCH 0182/2461] [SPARK-22735][ML][DOC] Added VectorSizeHint docs and examples. ## What changes were proposed in this pull request? Added documentation for new transformer. Author: Bago Amirbekian Closes #20285 from MrBago/sizeHintDocs. --- docs/ml-features.md | 51 ++++++++++++ .../ml/JavaVectorSizeHintExample.java | 79 +++++++++++++++++++ .../python/ml/vector_size_hint_example.py | 57 +++++++++++++ .../examples/ml/VectorSizeHintExample.scala | 63 +++++++++++++++ 4 files changed, 250 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSizeHintExample.java create mode 100644 examples/src/main/python/ml/vector_size_hint_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorSizeHintExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index 466a8fbe99cf6..3370eb3893272 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1291,6 +1291,57 @@ for more details on the API.
+## VectorSizeHint + +It can sometimes be useful to explicitly specify the size of the vectors for a column of +`VectorType`. For example, `VectorAssembler` uses size information from its input columns to +produce size information and metadata for its output column. While in some cases this information +can be obtained by inspecting the contents of the column, in a streaming dataframe the contents are +not available until the stream is started. `VectorSizeHint` allows a user to explicitly specify the +vector size for a column so that `VectorAssembler`, or other transformers that might +need to know vector size, can use that column as an input. + +To use `VectorSizeHint` a user must set the `inputCol` and `size` parameters. Applying this +transformer to a dataframe produces a new dataframe with updated metadata for `inputCol` specifying +the vector size. Downstream operations on the resulting dataframe can get this size using the +meatadata. + +`VectorSizeHint` can also take an optional `handleInvalid` parameter which controls its +behaviour when the vector column contains nulls or vectors of the wrong size. By default +`handleInvalid` is set to "error", indicating an exception should be thrown. This parameter can +also be set to "skip", indicating that rows containing invalid values should be filtered out from +the resulting dataframe, or "optimistic", indicating that the column should not be checked for +invalid values and all rows should be kept. Note that the use of "optimistic" can cause the +resulting dataframe to be in an inconsistent state, me:aning the metadata for the column +`VectorSizeHint` was applied to does not match the contents of that column. Users should take care +to avoid this kind of inconsistent state. + +
+
+ +Refer to the [VectorSizeHint Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorSizeHint) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/VectorSizeHintExample.scala %} +
+ +
+ +Refer to the [VectorSizeHint Java docs](api/java/org/apache/spark/ml/feature/VectorSizeHint.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaVectorSizeHintExample.java %} +
+ +
+ +Refer to the [VectorSizeHint Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorSizeHint) +for more details on the API. + +{% include_example python/ml/vector_size_hint_example.py %} +
+
+ ## QuantileDiscretizer `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSizeHintExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSizeHintExample.java new file mode 100644 index 0000000000000..d649a2ccbaa72 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSizeHintExample.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.sql.SparkSession; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.ml.feature.VectorAssembler; +import org.apache.spark.ml.feature.VectorSizeHint; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaVectorSizeHintExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaVectorSizeHintExample") + .getOrCreate(); + + // $example on$ + StructType schema = createStructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("hour", IntegerType, false), + createStructField("mobile", DoubleType, false), + createStructField("userFeatures", new VectorUDT(), false), + createStructField("clicked", DoubleType, false) + }); + Row row0 = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); + Row row1 = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0), 0.0); + Dataset dataset = spark.createDataFrame(Arrays.asList(row0, row1), schema); + + VectorSizeHint sizeHint = new VectorSizeHint() + .setInputCol("userFeatures") + .setHandleInvalid("skip") + .setSize(3); + + Dataset datasetWithSize = sizeHint.transform(dataset); + System.out.println("Rows where 'userFeatures' is not the right size are filtered out"); + datasetWithSize.show(false); + + VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[]{"hour", "mobile", "userFeatures"}) + .setOutputCol("features"); + + // This dataframe can be used by downstream transformers as before + Dataset output = assembler.transform(datasetWithSize); + System.out.println("Assembled columns 'hour', 'mobile', 'userFeatures' to vector column " + + "'features'"); + output.select("features", "clicked").show(false); + // $example off$ + + spark.stop(); + } +} + diff --git a/examples/src/main/python/ml/vector_size_hint_example.py b/examples/src/main/python/ml/vector_size_hint_example.py new file mode 100644 index 0000000000000..fb77dacec629d --- /dev/null +++ b/examples/src/main/python/ml/vector_size_hint_example.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +# $example on$ +from pyspark.ml.linalg import Vectors +from pyspark.ml.feature import (VectorSizeHint, VectorAssembler) +# $example off$ +from pyspark.sql import SparkSession + +if __name__ == "__main__": + spark = SparkSession\ + .builder\ + .appName("VectorSizeHintExample")\ + .getOrCreate() + + # $example on$ + dataset = spark.createDataFrame( + [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0), + (0, 18, 1.0, Vectors.dense([0.0, 10.0]), 0.0)], + ["id", "hour", "mobile", "userFeatures", "clicked"]) + + sizeHint = VectorSizeHint( + inputCol="userFeatures", + handleInvalid="skip", + size=3) + + datasetWithSize = sizeHint.transform(dataset) + print("Rows where 'userFeatures' is not the right size are filtered out") + datasetWithSize.show(truncate=False) + + assembler = VectorAssembler( + inputCols=["hour", "mobile", "userFeatures"], + outputCol="features") + + # This dataframe can be used by downstream transformers as before + output = assembler.transform(datasetWithSize) + print("Assembled columns 'hour', 'mobile', 'userFeatures' to vector column 'features'") + output.select("features", "clicked").show(truncate=False) + # $example off$ + + spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSizeHintExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSizeHintExample.scala new file mode 100644 index 0000000000000..688731a791f35 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSizeHintExample.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.{VectorAssembler, VectorSizeHint} +import org.apache.spark.ml.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SparkSession + +object VectorSizeHintExample { + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName("VectorSizeHintExample") + .getOrCreate() + + // $example on$ + val dataset = spark.createDataFrame( + Seq( + (0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0), + (0, 18, 1.0, Vectors.dense(0.0, 10.0), 0.0)) + ).toDF("id", "hour", "mobile", "userFeatures", "clicked") + + val sizeHint = new VectorSizeHint() + .setInputCol("userFeatures") + .setHandleInvalid("skip") + .setSize(3) + + val datasetWithSize = sizeHint.transform(dataset) + println("Rows where 'userFeatures' is not the right size are filtered out") + datasetWithSize.show(false) + + val assembler = new VectorAssembler() + .setInputCols(Array("hour", "mobile", "userFeatures")) + .setOutputCol("features") + + // This dataframe can be used by downstream transformers as before + val output = assembler.transform(datasetWithSize) + println("Assembled columns 'hour', 'mobile', 'userFeatures' to vector column 'features'") + output.select("features", "clicked").show(false) + // $example off$ + + spark.stop() + } +} +// scalastyle:on println From 613c290336e3826111164c24319f66774b1f65a3 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 23 Jan 2018 14:56:28 -0800 Subject: [PATCH 0183/2461] [SPARK-23192][SQL] Keep the Hint after Using Cached Data ## What changes were proposed in this pull request? The hint of the plan segment is lost, if the plan segment is replaced by the cached data. ```Scala val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") df2.cache() val df3 = df1.join(broadcast(df2), Seq("key"), "inner") ``` This PR is to fix it. ## How was this patch tested? Added a test Author: gatorsmile Closes #20365 from gatorsmile/fixBroadcastHintloss. --- .../apache/spark/sql/execution/CacheManager.scala | 12 ++++++++---- .../sql/execution/joins/BroadcastJoinSuite.scala | 13 +++++++++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index b05fe49a6ac3b..432eb59d6fe57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.expressions.SubqueryExpression -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ResolvedHint} import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.storage.StorageLevel @@ -170,9 +170,13 @@ class CacheManager extends Logging { def useCachedData(plan: LogicalPlan): LogicalPlan = { val newPlan = plan transformDown { case currentFragment => - lookupCachedData(currentFragment) - .map(_.cachedRepresentation.withOutput(currentFragment.output)) - .getOrElse(currentFragment) + lookupCachedData(currentFragment).map { cached => + val cachedPlan = cached.cachedRepresentation.withOutput(currentFragment.output) + currentFragment match { + case hint: ResolvedHint => ResolvedHint(cachedPlan, hint.hints) + case _ => cachedPlan + } + }.getOrElse(currentFragment) } newPlan transformAllExpressions { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 0bcd54e1fceab..1704bc8376f0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -109,6 +109,19 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } } + test("broadcast hint is retained after using the cached data") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + df2.cache() + val df3 = df1.join(broadcast(df2), Seq("key"), "inner") + val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { + case b: BroadcastHashJoinExec => b + }.size + assert(numBroadCastHashJoin === 1) + } + } + test("broadcast hint isn't propagated after a join") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") From 44cc4daf3a03f1a220eef8ce3c86867745db9ab7 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 23 Jan 2018 16:17:09 -0800 Subject: [PATCH 0184/2461] [SPARK-23195][SQL] Keep the Hint of Cached Data ## What changes were proposed in this pull request? The broadcast hint of the cached plan is lost if we cache the plan. This PR is to correct it. ```Scala val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") broadcast(df2).cache() df2.collect() val df3 = df1.join(df2, Seq("key"), "inner") ``` ## How was this patch tested? Added a test. Author: gatorsmile Closes #20368 from gatorsmile/cachedBroadcastHint. --- .../execution/columnar/InMemoryRelation.scala | 4 ++-- .../sql/execution/joins/BroadcastJoinSuite.scala | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 51928d914841e..5945808c4abfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -63,7 +63,7 @@ case class InMemoryRelation( tableName: Option[String])( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics = null) + statsOfPlanToCache: Statistics) extends logical.LeafNode with MultiInstanceRelation { override protected def innerChildren: Seq[SparkPlan] = Seq(child) @@ -77,7 +77,7 @@ case class InMemoryRelation( // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache statsOfPlanToCache } else { - Statistics(sizeInBytes = batchStats.value.longValue) + Statistics(sizeInBytes = batchStats.value.longValue, hints = statsOfPlanToCache.hints) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 1704bc8376f0d..889cab0489534 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -139,6 +139,22 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } } + test("broadcast hint is retained in a cached plan") { + Seq(true, false).foreach { materialized => + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + broadcast(df2).cache() + if (materialized) df2.collect() + val df3 = df1.join(df2, Seq("key"), "inner") + val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { + case b: BroadcastHashJoinExec => b + }.size + assert(numBroadCastHashJoin === 1) + } + } + } + private def assertBroadcastJoin(df : Dataset[Row]) : Unit = { val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") val joined = df1.join(df, Seq("key"), "inner") From 15adcc8273e73352e5e1c3fc9915c0b004ec4836 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 23 Jan 2018 16:24:20 -0800 Subject: [PATCH 0185/2461] [SPARK-23197][DSTREAMS] Increased timeouts to resolve flakiness ## What changes were proposed in this pull request? Increased timeout from 50 ms to 300 ms (50 ms was really too low). ## How was this patch tested? Multiple rounds of tests. Author: Tathagata Das Closes #20371 from tdas/SPARK-23197. --- .../scala/org/apache/spark/streaming/ReceiverSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index 145c48e5a9a72..fc6218a33f741 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -105,13 +105,13 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable { assert(executor.errors.head.eq(exception)) // Verify restarting actually stops and starts the receiver - receiver.restart("restarting", null, 100) - eventually(timeout(50 millis), interval(10 millis)) { + receiver.restart("restarting", null, 600) + eventually(timeout(300 millis), interval(10 millis)) { // receiver will be stopped async assert(receiver.isStopped) assert(receiver.onStopCalled) } - eventually(timeout(1000 millis), interval(100 millis)) { + eventually(timeout(1000 millis), interval(10 millis)) { // receiver will be started async assert(receiver.onStartCalled) assert(executor.isReceiverStarted) From a3911cf896de6e9386042ae4d93632cba69eef0f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 24 Jan 2018 11:43:48 +0900 Subject: [PATCH 0186/2461] [SPARK-23177][SQL][PYSPARK] Extract zero-parameter UDFs from aggregate ## What changes were proposed in this pull request? We extract Python UDFs in logical aggregate which depends on aggregate expression or grouping key in ExtractPythonUDFFromAggregate rule. But Python UDFs which don't depend on above expressions should also be extracted to avoid the issue reported in the JIRA. A small code snippet to reproduce that issue looks like: ```python import pyspark.sql.functions as f df = spark.createDataFrame([(1,2), (3,4)]) f_udf = f.udf(lambda: str("const_str")) df2 = df.distinct().withColumn("a", f_udf()) df2.show() ``` Error exception is raised as: ``` : org.apache.spark.sql.catalyst.errors.package$TreeNodeException: Binding attribute, tree: pythonUDF0#50 at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:56) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:91) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:90) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$2.apply(TreeNode.scala:267) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$2.apply(TreeNode.scala:267) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70) at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:266) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:306) at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:187) at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:304) at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:272) at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:256) at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:90) at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$38.apply(HashAggregateExec.scala:514) at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$38.apply(HashAggregateExec.scala:513) ``` This exception raises because `HashAggregateExec` tries to bind the aliased Python UDF expression (e.g., `pythonUDF0#50 AS a#44`) to grouping key. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #20360 from viirya/SPARK-23177. --- python/pyspark/sql/tests.py | 8 ++++++++ .../spark/sql/execution/python/ExtractPythonUDFs.scala | 5 +++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 84e8eec71dd8a..a466ab87d882d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1106,6 +1106,14 @@ def myudf(x): rows = [r[0] for r in df.selectExpr("udf(id)").take(2)] self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)]) + def test_nonparam_udf_with_aggregate(self): + import pyspark.sql.functions as f + + df = self.spark.createDataFrame([(1, 2), (1, 2)]) + f_udf = f.udf(lambda: "const_str") + rows = df.distinct().withColumn("a", f_udf()).collect() + self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')]) + def test_infer_schema_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 1862e3f6e12ca..4ae4e164830be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} /** * Extracts all the Python UDFs in logical aggregate, which depends on aggregate expression or - * grouping key, evaluate them after aggregate. + * grouping key, or doesn't depend on any above expressions, evaluate them after aggregate. */ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { @@ -45,7 +45,8 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = { expr.find { - e => PythonUDF.isScalarPythonUDF(e) && e.find(belongAggregate(_, agg)).isDefined + e => PythonUDF.isScalarPythonUDF(e) && + (e.references.isEmpty || e.find(belongAggregate(_, agg)).isDefined) }.isDefined } From f54b65c15a732540f7a41a9083eeb7a08feca125 Mon Sep 17 00:00:00 2001 From: neilalex Date: Tue, 23 Jan 2018 22:31:14 -0800 Subject: [PATCH 0187/2461] [SPARK-21727][R] Allow multi-element atomic vector as column type in SparkR DataFrame MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? A fix to https://issues.apache.org/jira/browse/SPARK-21727, "Operating on an ArrayType in a SparkR DataFrame throws error" ## How was this patch tested? - Ran tests at R\pkg\tests\run-all.R (see below attached results) - Tested the following lines in SparkR, which now seem to execute without error: ``` indices <- 1:4 myDf <- data.frame(indices) myDf$data <- list(rep(0, 20)) mySparkDf <- as.DataFrame(myDf) collect(mySparkDf) ``` [2018-01-22 SPARK-21727 Test Results.txt](https://github.com/apache/spark/files/1653535/2018-01-22.SPARK-21727.Test.Results.txt) felixcheung yanboliang sun-rui shivaram _The contribution is my original work and I license the work to the project under the project’s open source license_ Author: neilalex Closes #20352 from neilalex/neilalex-sparkr-arraytype. --- R/pkg/R/serialize.R | 11 +++---- R/pkg/tests/fulltests/test_Serde.R | 47 ++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 3bbf60d9b668c..263b9b576c0c5 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -30,14 +30,17 @@ # POSIXct,POSIXlt -> Time # # list[T] -> Array[T], where T is one of above mentioned types +# Multi-element vector of any of the above (except raw) -> Array[T] # environment -> Map[String, T], where T is a native type # jobj -> Object, where jobj is an object created in the backend # nolint end getSerdeType <- function(object) { type <- class(object)[[1]] - if (type != "list") { - type + if (is.atomic(object) & !is.raw(object) & length(object) > 1) { + "array" + } else if (type != "list") { + type } else { # Check if all elements are of same type elemType <- unique(sapply(object, function(elem) { getSerdeType(elem) })) @@ -50,9 +53,7 @@ getSerdeType <- function(object) { } writeObject <- function(con, object, writeType = TRUE) { - # NOTE: In R vectors have same type as objects. So we don't support - # passing in vectors as arrays and instead require arrays to be passed - # as lists. + # NOTE: In R vectors have same type as objects type <- class(object)[[1]] # class of POSIXlt is c("POSIXlt", "POSIXt") # Checking types is needed here, since 'is.na' only handles atomic vectors, # lists and pairlists diff --git a/R/pkg/tests/fulltests/test_Serde.R b/R/pkg/tests/fulltests/test_Serde.R index 6bbd201bf1d82..3577929323b8b 100644 --- a/R/pkg/tests/fulltests/test_Serde.R +++ b/R/pkg/tests/fulltests/test_Serde.R @@ -37,6 +37,53 @@ test_that("SerDe of primitive types", { expect_equal(class(x), "character") }) +test_that("SerDe of multi-element primitive vectors inside R data.frame", { + # vector of integers embedded in R data.frame + indices <- 1L:3L + myDf <- data.frame(indices) + myDf$data <- list(rep(0L, 3L)) + mySparkDf <- as.DataFrame(myDf) + myResultingDf <- collect(mySparkDf) + myDfListedData <- data.frame(indices) + myDfListedData$data <- list(as.list(rep(0L, 3L))) + expect_equal(myResultingDf, myDfListedData) + expect_equal(class(myResultingDf[["data"]][[1]]), "list") + expect_equal(class(myResultingDf[["data"]][[1]][[1]]), "integer") + + # vector of numeric embedded in R data.frame + myDf <- data.frame(indices) + myDf$data <- list(rep(0, 3L)) + mySparkDf <- as.DataFrame(myDf) + myResultingDf <- collect(mySparkDf) + myDfListedData <- data.frame(indices) + myDfListedData$data <- list(as.list(rep(0, 3L))) + expect_equal(myResultingDf, myDfListedData) + expect_equal(class(myResultingDf[["data"]][[1]]), "list") + expect_equal(class(myResultingDf[["data"]][[1]][[1]]), "numeric") + + # vector of logical embedded in R data.frame + myDf <- data.frame(indices) + myDf$data <- list(rep(TRUE, 3L)) + mySparkDf <- as.DataFrame(myDf) + myResultingDf <- collect(mySparkDf) + myDfListedData <- data.frame(indices) + myDfListedData$data <- list(as.list(rep(TRUE, 3L))) + expect_equal(myResultingDf, myDfListedData) + expect_equal(class(myResultingDf[["data"]][[1]]), "list") + expect_equal(class(myResultingDf[["data"]][[1]][[1]]), "logical") + + # vector of character embedded in R data.frame + myDf <- data.frame(indices) + myDf$data <- list(rep("abc", 3L)) + mySparkDf <- as.DataFrame(myDf) + myResultingDf <- collect(mySparkDf) + myDfListedData <- data.frame(indices) + myDfListedData$data <- list(as.list(rep("abc", 3L))) + expect_equal(myResultingDf, myDfListedData) + expect_equal(class(myResultingDf[["data"]][[1]]), "list") + expect_equal(class(myResultingDf[["data"]][[1]][[1]]), "character") +}) + test_that("SerDe of list of primitive types", { x <- list(1L, 2L, 3L) y <- callJStatic("SparkRHandler", "echo", x) From 4e7b49041aceca0beafec20f697b63a473a2b42f Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 23 Jan 2018 22:38:20 -0800 Subject: [PATCH 0188/2461] Revert "[SPARK-23195][SQL] Keep the Hint of Cached Data" This reverts commit 44cc4daf3a03f1a220eef8ce3c86867745db9ab7. --- .../execution/columnar/InMemoryRelation.scala | 4 ++-- .../sql/execution/joins/BroadcastJoinSuite.scala | 16 ---------------- 2 files changed, 2 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 5945808c4abfb..51928d914841e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -63,7 +63,7 @@ case class InMemoryRelation( tableName: Option[String])( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics) + statsOfPlanToCache: Statistics = null) extends logical.LeafNode with MultiInstanceRelation { override protected def innerChildren: Seq[SparkPlan] = Seq(child) @@ -77,7 +77,7 @@ case class InMemoryRelation( // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache statsOfPlanToCache } else { - Statistics(sizeInBytes = batchStats.value.longValue, hints = statsOfPlanToCache.hints) + Statistics(sizeInBytes = batchStats.value.longValue) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 889cab0489534..1704bc8376f0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -139,22 +139,6 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } } - test("broadcast hint is retained in a cached plan") { - Seq(true, false).foreach { materialized => - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") - broadcast(df2).cache() - if (materialized) df2.collect() - val df3 = df1.join(df2, Seq("key"), "inner") - val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { - case b: BroadcastHashJoinExec => b - }.size - assert(numBroadCastHashJoin === 1) - } - } - } - private def assertBroadcastJoin(df : Dataset[Row]) : Unit = { val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") val joined = df1.join(df, Seq("key"), "inner") From 7af1a325da57daa2e25c713472a320f4ccb43d71 Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Wed, 24 Jan 2018 21:13:47 +0900 Subject: [PATCH 0189/2461] [SPARK-23174][BUILD][PYTHON] python code style checker update ## What changes were proposed in this pull request? Referencing latest python code style checking from PyPi/pycodestyle Removed pending TODO For now, in tox.ini excluded the additional style error discovered on existing python due to latest style checker (will fallback on review comment to finalize exclusion or fix py) Any further code styling requirement needs to be part of pycodestyle, not in SPARK. ## How was this patch tested? ./dev/run-tests Author: Rekha Joshi Author: rjoshi2 Closes #20338 from rekhajoshm/SPARK-11222. --- dev/lint-python | 37 ++++++++++++++++++------------------- dev/run-tests.py | 5 ++++- dev/tox.ini | 4 ++-- 3 files changed, 24 insertions(+), 22 deletions(-) diff --git a/dev/lint-python b/dev/lint-python index df8df037a5f69..e069cafa1b8c6 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -21,7 +21,7 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" # Exclude auto-generated configuration file. PATHS_TO_CHECK="$( cd "$SPARK_ROOT_DIR" && find . -name "*.py" )" -PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" +PYCODESTYLE_REPORT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-report.txt" PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt" SPHINXBUILD=${SPHINXBUILD:=sphinx-build} @@ -30,23 +30,22 @@ SPHINX_REPORT_PATH="$SPARK_ROOT_DIR/dev/sphinx-report.txt" cd "$SPARK_ROOT_DIR" # compileall: https://docs.python.org/2/library/compileall.html -python -B -m compileall -q -l $PATHS_TO_CHECK > "$PEP8_REPORT_PATH" +python -B -m compileall -q -l $PATHS_TO_CHECK > "$PYCODESTYLE_REPORT_PATH" compile_status="${PIPESTATUS[0]}" -# Get pep8 at runtime so that we don't rely on it being installed on the build server. +# Get pycodestyle at runtime so that we don't rely on it being installed on the build server. #+ See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 -#+ TODOs: -#+ - Download pep8 from PyPI. It's more "official". -PEP8_VERSION="1.7.0" -PEP8_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pep8-$PEP8_VERSION.py" -PEP8_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/jcrocholl/pep8/$PEP8_VERSION/pep8.py" +# Updated to latest official version for pep8. pep8 is formally renamed to pycodestyle. +PYCODESTYLE_VERSION="2.3.1" +PYCODESTYLE_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-$PYCODESTYLE_VERSION.py" +PYCODESTYLE_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/PyCQA/pycodestyle/$PYCODESTYLE_VERSION/pycodestyle.py" -if [ ! -e "$PEP8_SCRIPT_PATH" ]; then - curl --silent -o "$PEP8_SCRIPT_PATH" "$PEP8_SCRIPT_REMOTE_PATH" +if [ ! -e "$PYCODESTYLE_SCRIPT_PATH" ]; then + curl --silent -o "$PYCODESTYLE_SCRIPT_PATH" "$PYCODESTYLE_SCRIPT_REMOTE_PATH" curl_status="$?" if [ "$curl_status" -ne 0 ]; then - echo "Failed to download pep8.py from \"$PEP8_SCRIPT_REMOTE_PATH\"." + echo "Failed to download pycodestyle.py from \"$PYCODESTYLE_SCRIPT_REMOTE_PATH\"." exit "$curl_status" fi fi @@ -64,23 +63,23 @@ export "PATH=$PYTHONPATH:$PATH" #+ first, but we do so so that the check status can #+ be output before the report, like with the #+ scalastyle and RAT checks. -python "$PEP8_SCRIPT_PATH" --config=dev/tox.ini $PATHS_TO_CHECK >> "$PEP8_REPORT_PATH" -pep8_status="${PIPESTATUS[0]}" +python "$PYCODESTYLE_SCRIPT_PATH" --config=dev/tox.ini $PATHS_TO_CHECK >> "$PYCODESTYLE_REPORT_PATH" +pycodestyle_status="${PIPESTATUS[0]}" -if [ "$compile_status" -eq 0 -a "$pep8_status" -eq 0 ]; then +if [ "$compile_status" -eq 0 -a "$pycodestyle_status" -eq 0 ]; then lint_status=0 else lint_status=1 fi if [ "$lint_status" -ne 0 ]; then - echo "PEP8 checks failed." - cat "$PEP8_REPORT_PATH" - rm "$PEP8_REPORT_PATH" + echo "PYCODESTYLE checks failed." + cat "$PYCODESTYLE_REPORT_PATH" + rm "$PYCODESTYLE_REPORT_PATH" exit "$lint_status" else - echo "PEP8 checks passed." - rm "$PEP8_REPORT_PATH" + echo "pycodestyle checks passed." + rm "$PYCODESTYLE_REPORT_PATH" fi # Check that the documentation builds acceptably, skip check if sphinx is not installed. diff --git a/dev/run-tests.py b/dev/run-tests.py index fb270c4ee0508..fe75ef4411c8c 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -576,7 +576,10 @@ def main(): for f in changed_files): # run_java_style_checks() pass - if not changed_files or any(f.endswith(".py") for f in changed_files): + if not changed_files or any(f.endswith("lint-python") + or f.endswith("tox.ini") + or f.endswith(".py") + for f in changed_files): run_python_style_checks() if not changed_files or any(f.endswith(".R") or f.endswith("lint-r") diff --git a/dev/tox.ini b/dev/tox.ini index eb8b1eb2c2886..583c1eaaa966b 100644 --- a/dev/tox.ini +++ b/dev/tox.ini @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -[pep8] -ignore=E402,E731,E241,W503,E226 +[pycodestyle] +ignore=E402,E731,E241,W503,E226,E722,E741,E305 max-line-length=100 exclude=cloudpickle.py,heapq3.py,shared.py,python/docs/conf.py,work/*/*.py,python/.eggs/* From de36f65d3a819c00d6bf6979deef46c824203669 Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Wed, 24 Jan 2018 21:19:09 +0900 Subject: [PATCH 0190/2461] [SPARK-23148][SQL] Allow pathnames with special characters for CSV / JSON / text MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …JSON / text ## What changes were proposed in this pull request? Fix for JSON and CSV data sources when file names include characters that would be changed by URL encoding. ## How was this patch tested? New unit tests for JSON, CSV and text suites Author: Henry Robinson Closes #20355 from henryr/spark-23148. --- .../execution/datasources/CodecStreams.scala | 6 +++--- .../datasources/csv/CSVDataSource.scala | 11 ++++++----- .../datasources/json/JsonDataSource.scala | 10 ++++++---- .../spark/sql/FileBasedDataSourceSuite.scala | 18 ++++++++++++++++-- 4 files changed, 31 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala index 54549f698aca5..c0df6c779d7bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala @@ -45,11 +45,11 @@ object CodecStreams { } /** - * Creates an input stream from the string path and add a closure for the input stream to be + * Creates an input stream from the given path and add a closure for the input stream to be * closed on task completion. */ - def createInputStreamWithCloseResource(config: Configuration, path: String): InputStream = { - val inputStream = createInputStream(config, new Path(path)) + def createInputStreamWithCloseResource(config: Configuration, path: Path): InputStream = { + val inputStream = createInputStream(config, path) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close())) inputStream } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 2031381dd2e10..4870d75fc5f08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.execution.datasources.csv +import java.net.URI import java.nio.charset.{Charset, StandardCharsets} import com.univocity.parsers.csv.CsvParser import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.Job @@ -32,7 +33,6 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat} import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -206,7 +206,7 @@ object MultiLineCSVDataSource extends CSVDataSource { parser: UnivocityParser, schema: StructType): Iterator[InternalRow] = { UnivocityParser.parseStream( - CodecStreams.createInputStreamWithCloseResource(conf, file.filePath), + CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))), parser.options.headerFlag, parser, schema) @@ -218,8 +218,9 @@ object MultiLineCSVDataSource extends CSVDataSource { parsedOptions: CSVOptions): StructType = { val csv = createBaseRdd(sparkSession, inputPaths, parsedOptions) csv.flatMap { lines => + val path = new Path(lines.getPath()) UnivocityParser.tokenizeStream( - CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()), + CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, path), shouldDropHeader = false, new CsvParser(parsedOptions.asParserSettings)) }.take(1).headOption match { @@ -230,7 +231,7 @@ object MultiLineCSVDataSource extends CSVDataSource { UnivocityParser.tokenizeStream( CodecStreams.createInputStreamWithCloseResource( lines.getConfiguration, - lines.getPath()), + new Path(lines.getPath())), parsedOptions.headerFlag, new CsvParser(parsedOptions.asParserSettings)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 8b7c2709afde1..77e7edc8e7a20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.execution.datasources.json import java.io.InputStream +import java.net.URI import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import com.google.common.io.ByteStreams import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.FileInputFormat @@ -168,9 +169,10 @@ object MultiLineJsonDataSource extends JsonDataSource { } private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { + val path = new Path(record.getPath()) CreateJacksonParser.inputStream( jsonFactory, - CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, record.getPath())) + CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, path)) } override def readFile( @@ -180,7 +182,7 @@ object MultiLineJsonDataSource extends JsonDataSource { schema: StructType): Iterator[InternalRow] = { def partitionedFileString(ignored: Any): UTF8String = { Utils.tryWithResource { - CodecStreams.createInputStreamWithCloseResource(conf, file.filePath) + CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))) } { inputStream => UTF8String.fromBytes(ByteStreams.toByteArray(inputStream)) } @@ -193,6 +195,6 @@ object MultiLineJsonDataSource extends JsonDataSource { parser.options.columnNameOfCorruptRecord) safeParser.parse( - CodecStreams.createInputStreamWithCloseResource(conf, file.filePath)) + CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath)))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 22fb496bc838e..c272c99ae45a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -23,6 +23,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { import testImplicits._ private val allFileBasedDataSources = Seq("orc", "parquet", "csv", "json", "text") + private val nameWithSpecialChars = "sp&cial%c hars" allFileBasedDataSources.foreach { format => test(s"Writing empty datasets should not fail - $format") { @@ -54,7 +55,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { // Only ORC/Parquet support this. `CSV` and `JSON` returns an empty schema. // `TEXT` data source always has a single column whose name is `value`. Seq("orc", "parquet").foreach { format => - test(s"SPARK-15474 Write and read back non-emtpy schema with empty dataframe - $format") { + test(s"SPARK-15474 Write and read back non-empty schema with empty dataframe - $format") { withTempPath { file => val path = file.getCanonicalPath val emptyDf = Seq((true, 1, "str")).toDF().limit(0) @@ -69,7 +70,6 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { allFileBasedDataSources.foreach { format => test(s"SPARK-22146 read files containing special characters using $format") { - val nameWithSpecialChars = s"sp&cial%chars" withTempDir { dir => val tmpFile = s"$dir/$nameWithSpecialChars" spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) @@ -78,4 +78,18 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { } } } + + // Separate test case for formats that support multiLine as an option. + Seq("json", "csv").foreach { format => + test("SPARK-23148 read files containing special characters " + + s"using $format with multiline enabled") { + withTempDir { dir => + val tmpFile = s"$dir/$nameWithSpecialChars" + spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) + val reader = spark.read.format(format).option("multiLine", true) + val fileContent = reader.load(tmpFile) + checkAnswer(fileContent, Seq(Row("a"), Row("b"))) + } + } + } } From 0ec95bb7df775be33fc8983f6c0983a67032d2c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Wed, 24 Jan 2018 11:34:59 -0600 Subject: [PATCH 0191/2461] [SPARK-22577][CORE] executor page blacklist status should update with TaskSet level blacklisting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In this PR stage blacklisting is propagated to UI by introducing a new Spark listener event (SparkListenerExecutorBlacklistedForStage) which indicates the executor is blacklisted for a stage. Either because of the number of failures are exceeded a limit given for an executor (spark.blacklist.stage.maxFailedTasksPerExecutor) or because of the whole node is blacklisted for a stage (spark.blacklist.stage.maxFailedExecutorsPerNode). In case of the node is blacklisting all executors will listed as blacklisted for the stage. Blacklisting state for a selected stage can be seen "Aggregated Metrics by Executor" table's blacklisting column, where after this change three possible labels could be found: - "for application": when the executor is blacklisted for the application (see the configuration spark.blacklist.application.maxFailedTasksPerExecutor for details) - "for stage": when the executor is **only** blacklisted for the stage - "false" : when the executor is not blacklisted at all ## How was this patch tested? It is tested both manually and with unit tests. #### Unit tests - HistoryServerSuite - TaskSetBlacklistSuite - AppStatusListenerSuite #### Manual test for executor blacklisting Running Spark as a local cluster: ``` $ bin/spark-shell --master "local-cluster[2,1,1024]" --conf "spark.blacklist.enabled=true" --conf "spark.blacklist.stage.maxFailedTasksPerExecutor=1" --conf "spark.blacklist.application.maxFailedTasksPerExecutor=10" --conf "spark.eventLog.enabled=true" ``` Executing: ``` scala import org.apache.spark.SparkEnv sc.parallelize(1 to 10, 10).map { x => if (SparkEnv.get.executorId == "0") throw new RuntimeException("Bad executor") else (x % 3, x) }.reduceByKey((a, b) => a + b).collect() ``` To see result check the "Aggregated Metrics by Executor" section at the bottom of picture: ![UI screenshot for stage level blacklisting executor](https://issues.apache.org/jira/secure/attachment/12905283/stage_blacklisting.png) #### Manual test for node blacklisting Running Spark as on a cluster: ``` bash ./bin/spark-shell --master yarn --deploy-mode client --executor-memory=2G --num-executors=8 --conf "spark.blacklist.enabled=true" --conf "spark.blacklist.stage.maxFailedTasksPerExecutor=1" --conf "spark.blacklist.stage.maxFailedExecutorsPerNode=1" --conf "spark.blacklist.application.maxFailedTasksPerExecutor=10" --conf "spark.eventLog.enabled=true" ``` And the job was: ``` scala import org.apache.spark.SparkEnv sc.parallelize(1 to 10000, 10).map { x => if (SparkEnv.get.executorId.toInt >= 4) throw new RuntimeException("Bad executor") else (x % 3, x) }.reduceByKey((a, b) => a + b).collect() ``` The result is: ![UI screenshot for stage level node blacklisting](https://issues.apache.org/jira/secure/attachment/12906833/node_blacklisting_for_stage.png) Here you can see apiros3.gce.test.com was node blacklisted for the stage because of failures on executor 4 and 5. As expected executor 3 is also blacklisted even it has no failures itself but sharing the node with 4 and 5. Author: “attilapiros” Author: Attila Zsolt Piros <2017933+attilapiros@users.noreply.github.com> Closes #20203 from attilapiros/SPARK-22577. --- .../apache/spark/SparkFirehoseListener.java | 12 + .../scheduler/EventLoggingListener.scala | 9 + .../spark/scheduler/SparkListener.scala | 35 + .../spark/scheduler/SparkListenerBus.scala | 4 + .../spark/scheduler/TaskSetBlacklist.scala | 19 +- .../spark/scheduler/TaskSetManager.scala | 2 +- .../spark/status/AppStatusListener.scala | 25 + .../org/apache/spark/status/LiveEntity.scala | 4 +- .../org/apache/spark/status/api/v1/api.scala | 3 +- .../apache/spark/ui/jobs/ExecutorTable.scala | 10 +- .../application_list_json_expectation.json | 70 +- .../blacklisting_for_stage_expectation.json | 639 ++++++++++++++ ...acklisting_node_for_stage_expectation.json | 783 ++++++++++++++++++ .../completed_app_list_json_expectation.json | 71 +- .../limit_app_list_json_expectation.json | 54 +- .../minDate_app_list_json_expectation.json | 62 +- .../minEndDate_app_list_json_expectation.json | 34 +- .../one_stage_attempt_json_expectation.json | 3 +- .../one_stage_json_expectation.json | 3 +- ...age_with_accumulable_json_expectation.json | 3 +- .../spark-events/app-20180109111548-0000 | 59 ++ .../application_1516285256255_0012 | 71 ++ .../deploy/history/HistoryServerSuite.scala | 2 + .../scheduler/BlacklistTrackerSuite.scala | 2 +- .../scheduler/TaskSetBlacklistSuite.scala | 119 ++- .../spark/status/AppStatusListenerSuite.scala | 43 + dev/.rat-excludes | 2 + 27 files changed, 2040 insertions(+), 103 deletions(-) create mode 100644 core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json create mode 100644 core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json create mode 100755 core/src/test/resources/spark-events/app-20180109111548-0000 create mode 100755 core/src/test/resources/spark-events/application_1516285256255_0012 diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index 3583856d88998..94c5c11b61a50 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -118,6 +118,18 @@ public final void onExecutorBlacklisted(SparkListenerExecutorBlacklisted executo onEvent(executorBlacklisted); } + @Override + public void onExecutorBlacklistedForStage( + SparkListenerExecutorBlacklistedForStage executorBlacklistedForStage) { + onEvent(executorBlacklistedForStage); + } + + @Override + public void onNodeBlacklistedForStage( + SparkListenerNodeBlacklistedForStage nodeBlacklistedForStage) { + onEvent(nodeBlacklistedForStage); + } + @Override public final void onExecutorUnblacklisted( SparkListenerExecutorUnblacklisted executorUnblacklisted) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index b3a5b1f1e05b3..69bc51c1ecf90 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -207,6 +207,15 @@ private[spark] class EventLoggingListener( logEvent(event, flushLogger = true) } + override def onExecutorBlacklistedForStage( + event: SparkListenerExecutorBlacklistedForStage): Unit = { + logEvent(event, flushLogger = true) + } + + override def onNodeBlacklistedForStage(event: SparkListenerNodeBlacklistedForStage): Unit = { + logEvent(event, flushLogger = true) + } + override def onExecutorUnblacklisted(event: SparkListenerExecutorUnblacklisted): Unit = { logEvent(event, flushLogger = true) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 3b677ca9657db..8a112f6a37b96 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -120,6 +120,24 @@ case class SparkListenerExecutorBlacklisted( taskFailures: Int) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerExecutorBlacklistedForStage( + time: Long, + executorId: String, + taskFailures: Int, + stageId: Int, + stageAttemptId: Int) + extends SparkListenerEvent + +@DeveloperApi +case class SparkListenerNodeBlacklistedForStage( + time: Long, + hostId: String, + executorFailures: Int, + stageId: Int, + stageAttemptId: Int) + extends SparkListenerEvent + @DeveloperApi case class SparkListenerExecutorUnblacklisted(time: Long, executorId: String) extends SparkListenerEvent @@ -261,6 +279,17 @@ private[spark] trait SparkListenerInterface { */ def onExecutorBlacklisted(executorBlacklisted: SparkListenerExecutorBlacklisted): Unit + /** + * Called when the driver blacklists an executor for a stage. + */ + def onExecutorBlacklistedForStage( + executorBlacklistedForStage: SparkListenerExecutorBlacklistedForStage): Unit + + /** + * Called when the driver blacklists a node for a stage. + */ + def onNodeBlacklistedForStage(nodeBlacklistedForStage: SparkListenerNodeBlacklistedForStage): Unit + /** * Called when the driver re-enables a previously blacklisted executor. */ @@ -339,6 +368,12 @@ abstract class SparkListener extends SparkListenerInterface { override def onExecutorBlacklisted( executorBlacklisted: SparkListenerExecutorBlacklisted): Unit = { } + def onExecutorBlacklistedForStage( + executorBlacklistedForStage: SparkListenerExecutorBlacklistedForStage): Unit = { } + + def onNodeBlacklistedForStage( + nodeBlacklistedForStage: SparkListenerNodeBlacklistedForStage): Unit = { } + override def onExecutorUnblacklisted( executorUnblacklisted: SparkListenerExecutorUnblacklisted): Unit = { } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 056c0cbded435..ff19cc65552e0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -61,6 +61,10 @@ private[spark] trait SparkListenerBus listener.onExecutorAdded(executorAdded) case executorRemoved: SparkListenerExecutorRemoved => listener.onExecutorRemoved(executorRemoved) + case executorBlacklistedForStage: SparkListenerExecutorBlacklistedForStage => + listener.onExecutorBlacklistedForStage(executorBlacklistedForStage) + case nodeBlacklistedForStage: SparkListenerNodeBlacklistedForStage => + listener.onNodeBlacklistedForStage(nodeBlacklistedForStage) case executorBlacklisted: SparkListenerExecutorBlacklisted => listener.onExecutorBlacklisted(executorBlacklisted) case executorUnblacklisted: SparkListenerExecutorUnblacklisted => diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala index 233781f3d9719..b680979a466a5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala @@ -36,8 +36,12 @@ import org.apache.spark.util.Clock * [[TaskSetManager]] this class is designed only to be called from code with a lock on the * TaskScheduler (e.g. its event handlers). It should not be called from other threads. */ -private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, val clock: Clock) - extends Logging { +private[scheduler] class TaskSetBlacklist( + private val listenerBus: LiveListenerBus, + val conf: SparkConf, + val stageId: Int, + val stageAttemptId: Int, + val clock: Clock) extends Logging { private val MAX_TASK_ATTEMPTS_PER_EXECUTOR = conf.get(config.MAX_TASK_ATTEMPTS_PER_EXECUTOR) private val MAX_TASK_ATTEMPTS_PER_NODE = conf.get(config.MAX_TASK_ATTEMPTS_PER_NODE) @@ -128,16 +132,23 @@ private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, } // Check if enough tasks have failed on the executor to blacklist it for the entire stage. - if (execFailures.numUniqueTasksWithFailures >= MAX_FAILURES_PER_EXEC_STAGE) { + val numFailures = execFailures.numUniqueTasksWithFailures + if (numFailures >= MAX_FAILURES_PER_EXEC_STAGE) { if (blacklistedExecs.add(exec)) { logInfo(s"Blacklisting executor ${exec} for stage $stageId") // This executor has been pushed into the blacklist for this stage. Let's check if it // pushes the whole node into the blacklist. val blacklistedExecutorsOnNode = execsWithFailuresOnNode.filter(blacklistedExecs.contains(_)) - if (blacklistedExecutorsOnNode.size >= MAX_FAILED_EXEC_PER_NODE_STAGE) { + val now = clock.getTimeMillis() + listenerBus.post( + SparkListenerExecutorBlacklistedForStage(now, exec, numFailures, stageId, stageAttemptId)) + val numFailExec = blacklistedExecutorsOnNode.size + if (numFailExec >= MAX_FAILED_EXEC_PER_NODE_STAGE) { if (blacklistedNodes.add(host)) { logInfo(s"Blacklisting ${host} for stage $stageId") + listenerBus.post( + SparkListenerNodeBlacklistedForStage(now, host, numFailExec, stageId, stageAttemptId)) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index c3ed11bfe352a..886c2c99f1ff3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -102,7 +102,7 @@ private[spark] class TaskSetManager( private[scheduler] val taskSetBlacklistHelperOpt: Option[TaskSetBlacklist] = { blacklistTracker.map { _ => - new TaskSetBlacklist(conf, stageId, clock) + new TaskSetBlacklist(sched.sc.listenerBus, conf, stageId, taskSet.stageAttemptId, clock) } } diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index b4edcf23abc09..3e34bdc0c7b63 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -211,6 +211,31 @@ private[spark] class AppStatusListener( updateBlackListStatus(event.executorId, true) } + override def onExecutorBlacklistedForStage( + event: SparkListenerExecutorBlacklistedForStage): Unit = { + Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => + val now = System.nanoTime() + val esummary = stage.executorSummary(event.executorId) + esummary.isBlacklisted = true + maybeUpdate(esummary, now) + } + } + + override def onNodeBlacklistedForStage(event: SparkListenerNodeBlacklistedForStage): Unit = { + val now = System.nanoTime() + + // Implicitly blacklist every available executor for the stage associated with this node + Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => + liveExecutors.values.foreach { exec => + if (exec.hostname == event.hostId) { + val esummary = stage.executorSummary(exec.executorId) + esummary.isBlacklisted = true + maybeUpdate(esummary, now) + } + } + } + } + override def onExecutorUnblacklisted(event: SparkListenerExecutorUnblacklisted): Unit = { updateBlackListStatus(event.executorId, false) } diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 4295e664e131c..d5f9e19ffdcd0 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -316,6 +316,7 @@ private class LiveExecutorStageSummary( var succeededTasks = 0 var failedTasks = 0 var killedTasks = 0 + var isBlacklisted = false var metrics = createMetrics(default = 0L) @@ -334,7 +335,8 @@ private class LiveExecutorStageSummary( metrics.shuffleWriteMetrics.bytesWritten, metrics.shuffleWriteMetrics.recordsWritten, metrics.memoryBytesSpilled, - metrics.diskBytesSpilled) + metrics.diskBytesSpilled, + isBlacklisted) new ExecutorStageSummaryWrapper(stageId, attemptId, executorId, info) } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 7d8e4de3c8efb..550eac3952bbb 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -68,7 +68,8 @@ class ExecutorStageSummary private[spark]( val shuffleWrite : Long, val shuffleWriteRecords : Long, val memoryBytesSpilled : Long, - val diskBytesSpilled : Long) + val diskBytesSpilled : Long, + val isBlacklistedForStage: Boolean) class ExecutorSummary private[spark]( val id: String, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 95c12b1e73653..0ff64f053f371 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -136,7 +136,15 @@ private[ui] class ExecutorTable(stage: StageData, store: AppStatusStore) { {Utils.bytesToString(v.diskBytesSpilled)} }} -
+ { + if (executor.map(_.isBlacklisted).getOrElse(false)) { + + } else if (v.isBlacklistedForStage) { + + } else { + + } + } } } diff --git a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json index f2c3ec5da8891..4fecf84db65a2 100644 --- a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json @@ -1,4 +1,34 @@ [ { + "id" : "application_1516285256255_0012", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-18T18:30:35.119GMT", + "endTime" : "2018-01-18T18:38:27.938GMT", + "lastUpdated" : "", + "duration" : 472819, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1516300235119, + "endTimeEpoch" : 1516300707938 + } ] +}, { + "id" : "app-20180109111548-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-09T10:15:42.372GMT", + "endTime" : "2018-01-09T10:24:37.606GMT", + "lastUpdated" : "", + "duration" : 535234, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1515492942372, + "endTimeEpoch" : 1515493477606 + } ] +}, { "id" : "app-20161116163331-0000", "name" : "Spark shell", "attempts" : [ { @@ -9,9 +39,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479335620587, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479335609916, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479335620587 } ] }, { "id" : "app-20161115172038-0000", @@ -24,9 +54,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479252138874, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479252037079, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479252138874 } ] }, { "id" : "local-1430917381534", @@ -39,9 +69,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917391398 } ] }, { "id" : "local-1430917381535", @@ -55,9 +85,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380950 }, { "attemptId" : "1", "startTime" : "2015-05-06T13:03:00.880GMT", @@ -67,9 +97,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380880, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380890 } ] }, { "id" : "local-1426533911241", @@ -83,9 +113,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426633910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426633945177 }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", @@ -95,9 +125,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426533910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426533945177 } ] }, { "id" : "local-1425081759269", @@ -110,9 +140,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1425081766912, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1425081758277, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1425081766912 } ] }, { "id" : "local-1422981780767", @@ -125,9 +155,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1422981788731, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1422981779720, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1422981788731 } ] }, { "id" : "local-1422981759269", @@ -140,8 +170,8 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1422981766912, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1422981758277, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1422981766912 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json b/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json new file mode 100644 index 0000000000000..5e9e8230e2745 --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json @@ -0,0 +1,639 @@ +{ + "status": "COMPLETE", + "stageId": 0, + "attemptId": 0, + "numTasks": 10, + "numActiveTasks": 0, + "numCompleteTasks": 10, + "numFailedTasks": 2, + "numKilledTasks": 0, + "numCompletedIndices": 10, + "executorRunTime": 761, + "executorCpuTime": 269916000, + "submissionTime": "2018-01-09T10:21:18.152GMT", + "firstTaskLaunchedTime": "2018-01-09T10:21:18.347GMT", + "completionTime": "2018-01-09T10:21:19.062GMT", + "inputBytes": 0, + "inputRecords": 0, + "outputBytes": 0, + "outputRecords": 0, + "shuffleReadBytes": 0, + "shuffleReadRecords": 0, + "shuffleWriteBytes": 460, + "shuffleWriteRecords": 10, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "name": "map at :26", + "details": "org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)", + "schedulingPool": "default", + "rddIds": [ + 1, + 0 + ], + "accumulatorUpdates": [], + "tasks": { + "0": { + "taskId": 0, + "index": 0, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.347GMT", + "duration": 562, + "executorId": "0", + "host": "172.30.65.138", + "status": "FAILED", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "errorMessage": "java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics": { + "executorDeserializeTime": 0, + "executorDeserializeCpuTime": 0, + "executorRunTime": 460, + "executorCpuTime": 0, + "resultSize": 0, + "jvmGcTime": 14, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 0, + "writeTime": 3873006, + "recordsWritten": 0 + } + } + }, + "5": { + "taskId": 5, + "index": 3, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.958GMT", + "duration": 22, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 3, + "executorDeserializeCpuTime": 2586000, + "executorRunTime": 9, + "executorCpuTime": 9635000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 262919, + "recordsWritten": 1 + } + } + }, + "10": { + "taskId": 10, + "index": 8, + "attempt": 0, + "launchTime": "2018-01-09T10:21:19.034GMT", + "duration": 12, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 2, + "executorDeserializeCpuTime": 1803000, + "executorRunTime": 6, + "executorCpuTime": 6157000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 243647, + "recordsWritten": 1 + } + } + }, + "1": { + "taskId": 1, + "index": 1, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.364GMT", + "duration": 565, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 301, + "executorDeserializeCpuTime": 200029000, + "executorRunTime": 212, + "executorCpuTime": 198479000, + "resultSize": 1115, + "jvmGcTime": 13, + "resultSerializationTime": 1, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 2409488, + "recordsWritten": 1 + } + } + }, + "6": { + "taskId": 6, + "index": 4, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.980GMT", + "duration": 16, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 3, + "executorDeserializeCpuTime": 2610000, + "executorRunTime": 10, + "executorCpuTime": 9622000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 385110, + "recordsWritten": 1 + } + } + }, + "9": { + "taskId": 9, + "index": 7, + "attempt": 0, + "launchTime": "2018-01-09T10:21:19.022GMT", + "duration": 12, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 2, + "executorDeserializeCpuTime": 1981000, + "executorRunTime": 7, + "executorCpuTime": 6335000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 259354, + "recordsWritten": 1 + } + } + }, + "2": { + "taskId": 2, + "index": 2, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.899GMT", + "duration": 27, + "executorId": "0", + "host": "172.30.65.138", + "status": "FAILED", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "errorMessage": "java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics": { + "executorDeserializeTime": 0, + "executorDeserializeCpuTime": 0, + "executorRunTime": 16, + "executorCpuTime": 0, + "resultSize": 0, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 0, + "writeTime": 126128, + "recordsWritten": 0 + } + } + }, + "7": { + "taskId": 7, + "index": 5, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.996GMT", + "duration": 15, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 2, + "executorDeserializeCpuTime": 2231000, + "executorRunTime": 9, + "executorCpuTime": 8407000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 205520, + "recordsWritten": 1 + } + } + }, + "3": { + "taskId": 3, + "index": 0, + "attempt": 1, + "launchTime": "2018-01-09T10:21:18.919GMT", + "duration": 24, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 8, + "executorDeserializeCpuTime": 8878000, + "executorRunTime": 10, + "executorCpuTime": 9364000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 207014, + "recordsWritten": 1 + } + } + }, + "11": { + "taskId": 11, + "index": 9, + "attempt": 0, + "launchTime": "2018-01-09T10:21:19.045GMT", + "duration": 15, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 3, + "executorDeserializeCpuTime": 2017000, + "executorRunTime": 6, + "executorCpuTime": 6676000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 233652, + "recordsWritten": 1 + } + } + }, + "8": { + "taskId": 8, + "index": 6, + "attempt": 0, + "launchTime": "2018-01-09T10:21:19.011GMT", + "duration": 11, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 1, + "executorDeserializeCpuTime": 1554000, + "executorRunTime": 7, + "executorCpuTime": 6034000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 213296, + "recordsWritten": 1 + } + } + }, + "4": { + "taskId": 4, + "index": 2, + "attempt": 1, + "launchTime": "2018-01-09T10:21:18.943GMT", + "duration": 16, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 2, + "executorDeserializeCpuTime": 2211000, + "executorRunTime": 9, + "executorCpuTime": 9207000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 292381, + "recordsWritten": 1 + } + } + } + }, + "executorSummary": { + "0": { + "taskTime": 589, + "failedTasks": 2, + "succeededTasks": 0, + "killedTasks": 0, + "inputBytes": 0, + "inputRecords": 0, + "outputBytes": 0, + "outputRecords": 0, + "shuffleRead": 0, + "shuffleReadRecords": 0, + "shuffleWrite": 0, + "shuffleWriteRecords": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "isBlacklistedForStage": true + }, + "1": { + "taskTime": 708, + "failedTasks": 0, + "succeededTasks": 10, + "killedTasks": 0, + "inputBytes": 0, + "inputRecords": 0, + "outputBytes": 0, + "outputRecords": 0, + "shuffleRead": 0, + "shuffleReadRecords": 0, + "shuffleWrite": 460, + "shuffleWriteRecords": 10, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "isBlacklistedForStage": false + } + }, + "killedTasksSummary": {} +} diff --git a/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json b/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json new file mode 100644 index 0000000000000..acd4cc53de6cd --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json @@ -0,0 +1,783 @@ +{ + "status" : "COMPLETE", + "stageId" : 0, + "attemptId" : 0, + "numTasks" : 10, + "numActiveTasks" : 0, + "numCompleteTasks" : 10, + "numFailedTasks" : 4, + "numKilledTasks" : 0, + "numCompletedIndices" : 10, + "executorRunTime" : 5080, + "executorCpuTime" : 1163210819, + "submissionTime" : "2018-01-18T18:33:12.658GMT", + "firstTaskLaunchedTime" : "2018-01-18T18:33:12.816GMT", + "completionTime" : "2018-01-18T18:33:15.279GMT", + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleReadBytes" : 0, + "shuffleReadRecords" : 0, + "shuffleWriteBytes" : 1461, + "shuffleWriteRecords" : 30, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "name" : "map at :27", + "details" : "org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:27)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)", + "schedulingPool" : "default", + "rddIds" : [ 1, 0 ], + "accumulatorUpdates" : [ ], + "tasks" : { + "0" : { + "taskId" : 0, + "index" : 0, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:12.816GMT", + "duration" : 2064, + "executorId" : "1", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 1081, + "executorDeserializeCpuTime" : 353981050, + "executorRunTime" : 914, + "executorCpuTime" : 368865439, + "resultSize" : 1134, + "jvmGcTime" : 75, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 144, + "writeTime" : 3662221, + "recordsWritten" : 3 + } + } + }, + "5" : { + "taskId" : 5, + "index" : 5, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:14.320GMT", + "duration" : 73, + "executorId" : "5", + "host" : "apiros-2.gce.test.com", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 27, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 191901, + "recordsWritten" : 0 + } + } + }, + "10" : { + "taskId" : 10, + "index" : 1, + "attempt" : 1, + "launchTime" : "2018-01-18T18:33:15.069GMT", + "duration" : 132, + "executorId" : "2", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 4598966, + "executorRunTime" : 76, + "executorCpuTime" : 20826337, + "resultSize" : 1091, + "jvmGcTime" : 0, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 144, + "writeTime" : 301705, + "recordsWritten" : 3 + } + } + }, + "1" : { + "taskId" : 1, + "index" : 1, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:12.832GMT", + "duration" : 1506, + "executorId" : "5", + "host" : "apiros-2.gce.test.com", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 1332, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 33, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 3075188, + "recordsWritten" : 0 + } + } + }, + "6" : { + "taskId" : 6, + "index" : 6, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:14.323GMT", + "duration" : 67, + "executorId" : "4", + "host" : "apiros-2.gce.test.com", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 51, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 183718, + "recordsWritten" : 0 + } + } + }, + "9" : { + "taskId" : 9, + "index" : 4, + "attempt" : 1, + "launchTime" : "2018-01-18T18:33:14.973GMT", + "duration" : 96, + "executorId" : "2", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 4793905, + "executorRunTime" : 48, + "executorCpuTime" : 25678331, + "resultSize" : 1091, + "jvmGcTime" : 0, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 366050, + "recordsWritten" : 3 + } + } + }, + "13" : { + "taskId" : 13, + "index" : 9, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:15.200GMT", + "duration" : 76, + "executorId" : "2", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 25, + "executorDeserializeCpuTime" : 5860574, + "executorRunTime" : 25, + "executorCpuTime" : 20585619, + "resultSize" : 1048, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 369513, + "recordsWritten" : 3 + } + } + }, + "2" : { + "taskId" : 2, + "index" : 2, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:12.832GMT", + "duration" : 1774, + "executorId" : "3", + "host" : "apiros-2.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 1206, + "executorDeserializeCpuTime" : 263386625, + "executorRunTime" : 493, + "executorCpuTime" : 278399617, + "resultSize" : 1134, + "jvmGcTime" : 78, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 144, + "writeTime" : 3322956, + "recordsWritten" : 3 + } + } + }, + "12" : { + "taskId" : 12, + "index" : 8, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:15.165GMT", + "duration" : 60, + "executorId" : "1", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 4010338, + "executorRunTime" : 34, + "executorCpuTime" : 21657558, + "resultSize" : 1048, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 319101, + "recordsWritten" : 3 + } + } + }, + "7" : { + "taskId" : 7, + "index" : 5, + "attempt" : 1, + "launchTime" : "2018-01-18T18:33:14.859GMT", + "duration" : 115, + "executorId" : "2", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 11, + "executorDeserializeCpuTime" : 10894331, + "executorRunTime" : 84, + "executorCpuTime" : 28283110, + "resultSize" : 1048, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 377601, + "recordsWritten" : 3 + } + } + }, + "3" : { + "taskId" : 3, + "index" : 3, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:12.833GMT", + "duration" : 2027, + "executorId" : "2", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 1282, + "executorDeserializeCpuTime" : 365807898, + "executorRunTime" : 681, + "executorCpuTime" : 349920830, + "resultSize" : 1134, + "jvmGcTime" : 102, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 3587839, + "recordsWritten" : 3 + } + } + }, + "11" : { + "taskId" : 11, + "index" : 7, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:15.072GMT", + "duration" : 93, + "executorId" : "1", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 4239884, + "executorRunTime" : 77, + "executorCpuTime" : 21689428, + "resultSize" : 1048, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 323898, + "recordsWritten" : 3 + } + } + }, + "8" : { + "taskId" : 8, + "index" : 6, + "attempt" : 1, + "launchTime" : "2018-01-18T18:33:14.879GMT", + "duration" : 194, + "executorId" : "1", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 56, + "executorDeserializeCpuTime" : 12246145, + "executorRunTime" : 54, + "executorCpuTime" : 27304550, + "resultSize" : 1048, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 311940, + "recordsWritten" : 3 + } + } + }, + "4" : { + "taskId" : 4, + "index" : 4, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:12.833GMT", + "duration" : 1522, + "executorId" : "4", + "host" : "apiros-2.gce.test.com", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 1184, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 82, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 16858066, + "recordsWritten" : 0 + } + } + } + }, + "executorSummary" : { + "4" : { + "taskTime" : 1589, + "failedTasks" : 2, + "succeededTasks" : 0, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 0, + "shuffleWriteRecords" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : true + }, + "5" : { + "taskTime" : 1579, + "failedTasks" : 2, + "succeededTasks" : 0, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 0, + "shuffleWriteRecords" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : true + }, + "1" : { + "taskTime" : 2411, + "failedTasks" : 0, + "succeededTasks" : 4, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 585, + "shuffleWriteRecords" : 12, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : false + }, + "2" : { + "taskTime" : 2446, + "failedTasks" : 0, + "succeededTasks" : 5, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 732, + "shuffleWriteRecords" : 15, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : false + }, + "3" : { + "taskTime" : 1774, + "failedTasks" : 0, + "succeededTasks" : 1, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 144, + "shuffleWriteRecords" : 3, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : true + } + }, + "killedTasksSummary" : { } +} \ No newline at end of file diff --git a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json index c925c1dd8a4d3..4fecf84db65a2 100644 --- a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json @@ -1,4 +1,34 @@ [ { + "id" : "application_1516285256255_0012", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-18T18:30:35.119GMT", + "endTime" : "2018-01-18T18:38:27.938GMT", + "lastUpdated" : "", + "duration" : 472819, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1516300235119, + "endTimeEpoch" : 1516300707938 + } ] +}, { + "id" : "app-20180109111548-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-09T10:15:42.372GMT", + "endTime" : "2018-01-09T10:24:37.606GMT", + "lastUpdated" : "", + "duration" : 535234, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1515492942372, + "endTimeEpoch" : 1515493477606 + } ] +}, { "id" : "app-20161116163331-0000", "name" : "Spark shell", "attempts" : [ { @@ -9,9 +39,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479335620587, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479335609916, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479335620587 } ] }, { "id" : "app-20161115172038-0000", @@ -24,9 +54,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479252138874, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479252037079, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479252138874 } ] }, { "id" : "local-1430917381534", @@ -39,9 +69,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917391398 } ] }, { "id" : "local-1430917381535", @@ -55,9 +85,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380950 }, { "attemptId" : "1", "startTime" : "2015-05-06T13:03:00.880GMT", @@ -67,9 +97,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380880, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380890 } ] }, { "id" : "local-1426533911241", @@ -83,9 +113,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426633910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426633945177 }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", @@ -95,9 +125,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426533910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426533945177 } ] }, { "id" : "local-1425081759269", @@ -110,10 +140,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "appSparkVersion" : "", - "endTimeEpoch" : 1425081766912, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1425081758277, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1425081766912 } ] }, { "id" : "local-1422981780767", @@ -126,9 +155,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1422981788731, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1422981779720, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1422981788731 } ] }, { "id" : "local-1422981759269", @@ -141,8 +170,8 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1422981766912, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1422981758277, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1422981766912 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json index cc0b2b0022bd3..79950b0dc6486 100644 --- a/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json @@ -1,46 +1,46 @@ [ { - "id" : "app-20161116163331-0000", + "id" : "application_1516285256255_0012", "name" : "Spark shell", "attempts" : [ { - "startTime" : "2016-11-16T22:33:29.916GMT", - "endTime" : "2016-11-16T22:33:40.587GMT", + "startTime" : "2018-01-18T18:30:35.119GMT", + "endTime" : "2018-01-18T18:38:27.938GMT", "lastUpdated" : "", - "duration" : 10671, - "sparkUser" : "jose", + "duration" : 472819, + "sparkUser" : "attilapiros", "completed" : true, - "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479335620587, - "startTimeEpoch" : 1479335609916, - "lastUpdatedEpoch" : 0 + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1516300235119, + "endTimeEpoch" : 1516300707938 } ] }, { - "id" : "app-20161115172038-0000", + "id" : "app-20180109111548-0000", "name" : "Spark shell", "attempts" : [ { - "startTime" : "2016-11-15T23:20:37.079GMT", - "endTime" : "2016-11-15T23:22:18.874GMT", + "startTime" : "2018-01-09T10:15:42.372GMT", + "endTime" : "2018-01-09T10:24:37.606GMT", "lastUpdated" : "", - "duration" : 101795, - "sparkUser" : "jose", + "duration" : 535234, + "sparkUser" : "attilapiros", "completed" : true, - "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479252138874, - "startTimeEpoch" : 1479252037079, - "lastUpdatedEpoch" : 0 + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1515492942372, + "endTimeEpoch" : 1515493477606 } ] }, { - "id" : "local-1430917381534", + "id" : "app-20161116163331-0000", "name" : "Spark shell", "attempts" : [ { - "startTime" : "2015-05-06T13:03:00.893GMT", - "endTime" : "2015-05-06T13:03:11.398GMT", + "startTime" : "2016-11-16T22:33:29.916GMT", + "endTime" : "2016-11-16T22:33:40.587GMT", "lastUpdated" : "", - "duration" : 10505, - "sparkUser" : "irashid", + "duration" : 10671, + "sparkUser" : "jose", "completed" : true, - "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917391398, - "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "appSparkVersion" : "2.1.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1479335609916, + "endTimeEpoch" : 1479335620587 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json index 5af50abd85330..7d60977dcd4fe 100644 --- a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json @@ -1,4 +1,34 @@ [ { + "id" : "application_1516285256255_0012", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-18T18:30:35.119GMT", + "endTime" : "2018-01-18T18:38:27.938GMT", + "lastUpdated" : "", + "duration" : 472819, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1516300235119, + "endTimeEpoch" : 1516300707938 + } ] +}, { + "id" : "app-20180109111548-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-09T10:15:42.372GMT", + "endTime" : "2018-01-09T10:24:37.606GMT", + "lastUpdated" : "", + "duration" : 535234, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1515492942372, + "endTimeEpoch" : 1515493477606 + } ] +}, { "id" : "app-20161116163331-0000", "name" : "Spark shell", "attempts" : [ { @@ -9,9 +39,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479335620587, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479335609916, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479335620587 } ] }, { "id" : "app-20161115172038-0000", @@ -24,9 +54,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479252138874, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479252037079, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479252138874 } ] }, { "id" : "local-1430917381534", @@ -39,9 +69,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917391398 } ] }, { "id" : "local-1430917381535", @@ -55,9 +85,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380950 }, { "attemptId" : "1", "startTime" : "2015-05-06T13:03:00.880GMT", @@ -67,9 +97,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380880, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380890 } ] }, { "id" : "local-1426533911241", @@ -83,9 +113,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426633910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426633945177 }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", @@ -95,9 +125,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426533910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426533945177 } ] }, { "id" : "local-1425081759269", @@ -110,8 +140,8 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1425081766912, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1425081758277, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1425081766912 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json index 7f896c74b5be1..dfbfd8aedcc23 100644 --- a/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json @@ -1,4 +1,34 @@ [ { + "id" : "application_1516285256255_0012", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-18T18:30:35.119GMT", + "endTime" : "2018-01-18T18:38:27.938GMT", + "lastUpdated" : "", + "duration" : 472819, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1516300235119, + "endTimeEpoch" : 1516300707938 + } ] +}, { + "id" : "app-20180109111548-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-09T10:15:42.372GMT", + "endTime" : "2018-01-09T10:24:37.606GMT", + "lastUpdated" : "", + "duration" : 535234, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1515492942372, + "endTimeEpoch" : 1515493477606 + } ] +}, { "id" : "app-20161116163331-0000", "name" : "Spark shell", "attempts" : [ { @@ -9,8 +39,8 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "startTimeEpoch" : 1479335609916, "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1479335609916, "endTimeEpoch" : 1479335620587 } ] }, { @@ -24,8 +54,8 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "startTimeEpoch" : 1479252037079, "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1479252037079, "endTimeEpoch" : 1479252138874 } ] }, { diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index 31093a661663b..03f886afa5413 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -421,7 +421,8 @@ "shuffleWrite" : 13180, "shuffleWriteRecords" : 0, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : false } }, "killedTasksSummary" : { } diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index 601d70695b17c..947c89906955d 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -421,7 +421,8 @@ "shuffleWrite" : 13180, "shuffleWriteRecords" : 0, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : false } }, "killedTasksSummary" : { } diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index 9cdcef0746185..963f010968b62 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -465,7 +465,8 @@ "shuffleWrite" : 0, "shuffleWriteRecords" : 0, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : false } }, "killedTasksSummary" : { } diff --git a/core/src/test/resources/spark-events/app-20180109111548-0000 b/core/src/test/resources/spark-events/app-20180109111548-0000 new file mode 100755 index 0000000000000..50893d3001b95 --- /dev/null +++ b/core/src/test/resources/spark-events/app-20180109111548-0000 @@ -0,0 +1,59 @@ +{"Event":"SparkListenerLogStart","Spark Version":"2.3.0-SNAPSHOT"} +{"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre","Java Version":"1.8.0_152 (Oracle Corporation)","Scala Version":"version 2.11.8"},"Spark Properties":{"spark.blacklist.enabled":"true","spark.driver.host":"172.30.65.138","spark.eventLog.enabled":"true","spark.driver.port":"64273","spark.repl.class.uri":"spark://172.30.65.138:64273/classes","spark.jars":"","spark.repl.class.outputDir":"/private/var/folders/9g/gf583nd1765cvfgb_lsvwgp00000gp/T/spark-811c1b49-eb66-4bfb-91ae-33b45efa269d/repl-c4438f51-ee23-41ed-8e04-71496e2f40f5","spark.app.name":"Spark shell","spark.scheduler.mode":"FIFO","spark.ui.showConsoleProgress":"true","spark.blacklist.stage.maxFailedTasksPerExecutor":"1","spark.executor.id":"driver","spark.submit.deployMode":"client","spark.master":"local-cluster[2,1,1024]","spark.home":"*********(redacted)","spark.sql.catalogImplementation":"in-memory","spark.blacklist.application.maxFailedTasksPerExecutor":"10","spark.app.id":"app-20180109111548-0000"},"System Properties":{"java.io.tmpdir":"/var/folders/9g/gf583nd1765cvfgb_lsvwgp00000gp/T/","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle Corporation","java.vm.specification.version":"1.8","user.home":"*********(redacted)","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","ftp.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","sun.arch.data.model":"64","sun.boot.library.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib","user.dir":"*********(redacted)","java.library.path":"*********(redacted)","sun.cpu.isalist":"","os.arch":"x86_64","java.vm.version":"25.152-b16","java.endorsed.dirs":"/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/endorsed","java.runtime.version":"1.8.0_152-b16","java.vm.info":"mixed mode","java.ext.dirs":"*********(redacted)","java.runtime.name":"Java(TM) SE Runtime Environment","file.separator":"/","java.class.version":"52.0","scala.usejavacp":"true","java.specification.name":"Java Platform API Specification","sun.boot.class.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/resources.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/rt.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/sunrsasign.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/jsse.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/jce.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/charsets.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/jfr.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/classes","file.encoding":"UTF-8","user.timezone":"*********(redacted)","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"10.12.6","sun.os.patch.level":"unknown","gopherProxySet":"false","java.vm.specification.vendor":"Oracle Corporation","user.country":"*********(redacted)","sun.jnu.encoding":"UTF-8","http.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","user.language":"*********(redacted)","socksNonProxyHosts":"local|*.local|169.254/16|*.169.254/16","java.vendor.url":"*********(redacted)","java.awt.printerjob":"sun.lwawt.macosx.CPrinterJob","java.awt.graphicsenv":"sun.awt.CGraphicsEnvironment","awt.toolkit":"sun.lwawt.macosx.LWCToolkit","os.name":"Mac OS X","java.vm.vendor":"Oracle Corporation","java.vendor.url.bug":"*********(redacted)","user.name":"*********(redacted)","java.vm.name":"Java HotSpot(TM) 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --master local-cluster[2,1,1024] --conf spark.blacklist.stage.maxFailedTasksPerExecutor=1 --conf spark.blacklist.enabled=true --conf spark.blacklist.application.maxFailedTasksPerExecutor=10 --conf spark.eventLog.enabled=true --class org.apache.spark.repl.Main --name Spark shell spark-shell","java.home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre","java.version":"1.8.0_152","sun.io.unicode.encoding":"UnicodeBig"},"Classpath Entries":{"/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/api-asn1-api-1.0.0-M20.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/json4s-jackson_2.11-3.2.11.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/oro-2.0.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/machinist_2.11-0.6.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/metrics-json-3.1.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/lz4-java-1.4.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-sketch_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-catalyst_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scala-reflect-2.11.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-app-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/activation-1.1.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jsr305-1.3.9.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/univocity-parsers-2.5.9.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hk2-locator-2.4.0-b34.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/curator-framework-2.6.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/avro-mapred-1.7.7-hadoop2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-jaxrs-1.9.13.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jtransforms-2.4.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/json4s-core_2.11-3.2.11.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/metrics-jvm-3.1.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-mapper-asl-1.9.13.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-encoding-1.8.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hk2-api-2.4.0-b34.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/py4j-0.10.6.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/zookeeper-3.4.6.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-core-asl-1.9.13.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/core-1.1.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-core-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-api-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-beanutils-1.7.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/RoaringBitmap-0.5.11.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-module-paranamer-2.7.9.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-common-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-common-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/javax.ws.rs-api-2.0.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-configuration-1.6.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/javax.inject-2.4.0-b34.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/xercesImpl-2.9.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/gson-2.2.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-hdfs-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/arrow-format-0.8.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-databind-2.6.7.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-guava-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-lang3-3.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/aopalliance-repackaged-2.4.0-b34.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-media-jaxb-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/janino-3.0.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-client-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-auth-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/javassist-3.18.1-GA.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-format-2.3.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/breeze-macros_2.11-0.13.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-compress-1.4.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-core-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/slf4j-log4j12-1.7.16.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-server-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-collections-3.2.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/stax-api-1.0-2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/guava-14.0.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/httpcore-4.4.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-mllib_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/osgi-resource-locator-1.0.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-network-common_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/arrow-memory-0.8.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/log4j-1.2.17.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/breeze_2.11-0.13.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/arrow-vector-0.8.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/opencsv-2.3.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-jackson-1.8.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/minlog-1.3.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-jobclient-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-network-shuffle_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/htrace-core-3.0.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/bcprov-jdk15on-1.58.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scalap-2.11.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/netty-all-4.1.17.Final.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hppc-0.7.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/antlr4-runtime-4.7.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-io-2.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/httpclient-4.5.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jcl-over-slf4j-1.7.16.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hk2-utils-2.4.0-b34.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/shapeless_2.11-2.3.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scala-parser-combinators_2.11-1.0.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-codec-1.10.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/protobuf-java-2.5.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/java-xmlbuilder-1.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-net-2.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/compress-lzf-1.0.3.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-beanutils-core-1.8.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/chill_2.11-0.8.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/flatbuffers-1.2.0-3f79e055.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/leveldbjni-all-1.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-client-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/orc-mapreduce-1.4.1-nohive.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/paranamer-2.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-launcher_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-httpclient-3.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/javax.servlet-api-3.1.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/aircompressor-0.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-sql_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-module-scala_2.11-2.6.7.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/slf4j-api-1.7.16.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/metrics-core-3.1.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-common-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-streaming_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-unsafe_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/xbean-asm5-shaded-4.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/orc-core-1.4.1-nohive.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scala-xml_2.11-1.0.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-core_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/javax.annotation-api-1.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-math3-3.4.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jets3t-0.9.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-crypto-1.0.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/base64-2.3.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-lang-2.6.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/curator-recipes-2.6.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spire-macros_2.11-0.13.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-compiler-3.0.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-repl_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/apacheds-i18n-2.0.0-M15.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scala-library-2.11.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/conf/":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-annotations-2.6.7.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-common-1.8.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jetty-util-6.1.26.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/metrics-graphite-3.1.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/stream-2.7.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/chill-java-0.8.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-common-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jul-to-slf4j-1.7.16.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/ivy-2.4.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/xz-1.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spire_2.11-0.13.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-hadoop-1.8.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scala-compiler-2.11.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-cli-1.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/avro-1.7.7.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-server-common-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/kryo-shaded-3.0.3.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-digester-1.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-client-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-graphx_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-shuffle-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-mllib-local_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/snappy-java-1.1.2.6.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/xmlenc-0.52.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-kvstore_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/validation-api-1.1.0.Final.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-core-2.6.7.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/macro-compat_2.11-1.1.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jaxb-api-2.2.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-column-1.8.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/zstd-jni-1.3.2-2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/arpack_combined_all-0.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/json4s-ast_2.11-3.2.11.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/netty-3.9.9.Final.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/apacheds-kerberos-codec-2.0.0-M15.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-tags_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-annotations-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/api-util-1.0.0-M20.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/curator-client-2.6.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/pyrolite-4.13.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/joda-time-2.9.3.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-xc-1.9.13.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/objenesis-2.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7.jar":"*********(redacted)"}} +{"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"app-20180109111548-0000","Timestamp":1515492942372,"User":"attilapiros"} +{"Event":"SparkListenerExecutorAdded","Timestamp":1515492965588,"Executor ID":"0","Executor Info":{"Host":"172.30.65.138","Total Cores":1,"Log Urls":{"stdout":"http://172.30.65.138:64279/logPage/?appId=app-20180109111548-0000&executorId=0&logType=stdout","stderr":"http://172.30.65.138:64279/logPage/?appId=app-20180109111548-0000&executorId=0&logType=stderr"}}} +{"Event":"SparkListenerExecutorAdded","Timestamp":1515492965598,"Executor ID":"1","Executor Info":{"Host":"172.30.65.138","Total Cores":1,"Log Urls":{"stdout":"http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout","stderr":"http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"0","Host":"172.30.65.138","Port":64290},"Maximum Memory":384093388,"Timestamp":1515492965643,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":0} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"1","Host":"172.30.65.138","Port":64291},"Maximum Memory":384093388,"Timestamp":1515492965652,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":0} +{"Event":"SparkListenerJobStart","Job ID":0,"Submission Time":1515493278122,"Stage Infos":[{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :26","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]},{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :29","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :29","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:29)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]}],"Stage IDs":[0,1],"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :26","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1515493278152,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1515493278347,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1515493278364,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1515493278899,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklistedForStage","time":1515493278918,"executorId":"0","taskFailures":1,"stageId":0,"stageAttemptId":0} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"460","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":5,"Update":"14","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"3873006","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1515493278347,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278909,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3873006,"Value":3873006,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":14,"Value":14,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":460,"Value":460,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":460,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":14,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":3873006,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":3,"Index":0,"Attempt":1,"Launch Time":1515493278919,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278943,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":207014,"Value":6615636,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":92,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":896,"Value":1792,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":2144,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9364000,"Value":207843000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":10,"Value":698,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":8878000,"Value":208907000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":8,"Value":309,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"16","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"126128","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1515493278899,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278926,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":126128,"Value":3999134,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":16,"Value":476,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":16,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":126128,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1515493278364,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278929,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":2409488,"Value":6408622,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":46,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":896,"Value":896,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":13,"Value":27,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1115,"Value":1115,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":198479000,"Value":198479000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":212,"Value":688,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":200029000,"Value":200029000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":301,"Value":301,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":301,"Executor Deserialize CPU Time":200029000,"Executor Run Time":212,"Executor CPU Time":198479000,"Result Size":1115,"JVM GC Time":13,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":2409488,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":4,"Index":2,"Attempt":1,"Launch Time":1515493278943,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278959,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":292381,"Value":6908017,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":2704,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":3173,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9207000,"Value":217050000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":9,"Value":707,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2211000,"Value":211118000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":311,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":3,"Index":0,"Attempt":1,"Launch Time":1515493278919,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278943,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":207014,"Value":6615636,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":92,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":896,"Value":1792,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":2144,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9364000,"Value":207843000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":10,"Value":698,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":8878000,"Value":208907000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":8,"Value":309,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":8,"Executor Deserialize CPU Time":8878000,"Executor Run Time":10,"Executor CPU Time":9364000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":207014,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":5,"Index":3,"Attempt":0,"Launch Time":1515493278958,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278980,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":262919,"Value":7170936,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":184,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":3616,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":4202,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9635000,"Value":226685000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":9,"Value":716,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2586000,"Value":213704000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":314,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":4,"Index":2,"Attempt":1,"Launch Time":1515493278943,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278959,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":292381,"Value":6908017,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":2704,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":3173,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9207000,"Value":217050000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":9,"Value":707,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2211000,"Value":211118000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":311,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":2211000,"Executor Run Time":9,"Executor CPU Time":9207000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":292381,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":6,"Index":4,"Attempt":0,"Launch Time":1515493278980,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":5,"Index":3,"Attempt":0,"Launch Time":1515493278958,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278980,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":262919,"Value":7170936,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":184,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":3616,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":4202,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9635000,"Value":226685000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":9,"Value":716,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2586000,"Value":213704000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":314,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":2586000,"Executor Run Time":9,"Executor CPU Time":9635000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":262919,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":7,"Index":5,"Attempt":0,"Launch Time":1515493278996,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":6,"Index":4,"Attempt":0,"Launch Time":1515493278980,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278996,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":385110,"Value":7556046,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":5,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":230,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":4528,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":5231,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9622000,"Value":236307000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":10,"Value":726,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2610000,"Value":216314000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":317,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":2610000,"Executor Run Time":10,"Executor CPU Time":9622000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":385110,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":8,"Index":6,"Attempt":0,"Launch Time":1515493279011,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":7,"Index":5,"Attempt":0,"Launch Time":1515493278996,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279011,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":205520,"Value":7761566,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":6,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":276,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":5440,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":6260,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":8407000,"Value":244714000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":9,"Value":735,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2231000,"Value":218545000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":319,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":2231000,"Executor Run Time":9,"Executor CPU Time":8407000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":205520,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":9,"Index":7,"Attempt":0,"Launch Time":1515493279022,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":8,"Index":6,"Attempt":0,"Launch Time":1515493279011,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279022,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":213296,"Value":7974862,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":6352,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":7289,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":6034000,"Value":250748000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":7,"Value":742,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1554000,"Value":220099000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":1,"Value":320,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1,"Executor Deserialize CPU Time":1554000,"Executor Run Time":7,"Executor CPU Time":6034000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":213296,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":10,"Index":8,"Attempt":0,"Launch Time":1515493279034,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":9,"Index":7,"Attempt":0,"Launch Time":1515493279022,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279034,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":259354,"Value":8234216,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":8,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":368,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":7264,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":8318,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":6335000,"Value":257083000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":7,"Value":749,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1981000,"Value":222080000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":322,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1981000,"Executor Run Time":7,"Executor CPU Time":6335000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":259354,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":11,"Index":9,"Attempt":0,"Launch Time":1515493279045,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":10,"Index":8,"Attempt":0,"Launch Time":1515493279034,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279046,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":243647,"Value":8477863,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":9,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":414,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":8176,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":9347,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":6157000,"Value":263240000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":6,"Value":755,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1803000,"Value":223883000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":324,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1803000,"Executor Run Time":6,"Executor CPU Time":6157000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":243647,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":11,"Index":9,"Attempt":0,"Launch Time":1515493279045,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279060,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":233652,"Value":8711515,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":460,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":9088,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":10376,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":6676000,"Value":269916000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":6,"Value":761,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2017000,"Value":225900000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":327,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":2017000,"Executor Run Time":6,"Executor CPU Time":6676000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":233652,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :26","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1515493278152,"Completion Time":1515493279062,"Accumulables":[{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Value":761,"Internal":true,"Count Failed Values":true},{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Value":8711515,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Value":27,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Value":10376,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Value":225900000,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Value":10,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Value":9088,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Value":460,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Value":269916000,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Value":1,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Value":327,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :29","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :29","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:29)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1515493279071,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":12,"Index":0,"Attempt":0,"Launch Time":1515493279077,"Executor ID":"0","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":13,"Index":1,"Attempt":0,"Launch Time":1515493279078,"Executor ID":"1","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":14,"Index":2,"Attempt":0,"Launch Time":1515493279152,"Executor ID":"1","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":13,"Index":1,"Attempt":0,"Launch Time":1515493279078,"Executor ID":"1","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279152,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":4,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":184,"Value":184,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":4,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":944,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":1286,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":41280000,"Value":41280000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":53,"Value":53,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":11820000,"Value":11820000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":17,"Value":17,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":17,"Executor Deserialize CPU Time":11820000,"Executor Run Time":53,"Executor CPU Time":41280000,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":4,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":184,"Total Records Read":4},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":15,"Index":3,"Attempt":0,"Launch Time":1515493279166,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":14,"Index":2,"Attempt":0,"Launch Time":1515493279152,"Executor ID":"1","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279167,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":3,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":138,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":3,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":1888,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":2572,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":7673000,"Value":48953000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":8,"Value":61,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1706000,"Value":13526000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":19,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1706000,"Executor Run Time":8,"Executor CPU Time":7673000,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":3,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":138,"Total Records Read":3},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":16,"Index":4,"Attempt":0,"Launch Time":1515493279179,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":15,"Index":3,"Attempt":0,"Launch Time":1515493279166,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279180,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":1888,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":3706,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":6972000,"Value":55925000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":7,"Value":68,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1569000,"Value":15095000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":21,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1569000,"Executor Run Time":7,"Executor CPU Time":6972000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":17,"Index":5,"Attempt":0,"Launch Time":1515493279190,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":16,"Index":4,"Attempt":0,"Launch Time":1515493279179,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279190,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":1888,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":4840,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":4905000,"Value":60830000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":5,"Value":73,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1882000,"Value":16977000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":23,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1882000,"Executor Run Time":5,"Executor CPU Time":4905000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":18,"Index":6,"Attempt":0,"Launch Time":1515493279193,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":12,"Index":0,"Attempt":0,"Launch Time":1515493279077,"Executor ID":"0","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279194,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":3,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":23,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":138,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":3,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":6126,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":56742000,"Value":117572000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":89,"Value":162,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":12625000,"Value":29602000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":18,"Value":41,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":18,"Executor Deserialize CPU Time":12625000,"Executor Run Time":89,"Executor CPU Time":56742000,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":3,"Local Blocks Fetched":0,"Fetch Wait Time":23,"Remote Bytes Read":138,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":3},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":19,"Index":7,"Attempt":0,"Launch Time":1515493279202,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":17,"Index":5,"Attempt":0,"Launch Time":1515493279190,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279203,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":7260,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":6476000,"Value":124048000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":7,"Value":169,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1890000,"Value":31492000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":43,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1890000,"Executor Run Time":7,"Executor CPU Time":6476000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":20,"Index":8,"Attempt":0,"Launch Time":1515493279215,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":19,"Index":7,"Attempt":0,"Launch Time":1515493279202,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279216,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":8394,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":6927000,"Value":130975000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":7,"Value":176,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2038000,"Value":33530000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":45,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":2038000,"Executor Run Time":7,"Executor CPU Time":6927000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":21,"Index":9,"Attempt":0,"Launch Time":1515493279218,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":18,"Index":6,"Attempt":0,"Launch Time":1515493279193,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279218,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":9528,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":11214000,"Value":142189000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":16,"Value":192,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2697000,"Value":36227000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":49,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":2697000,"Executor Run Time":16,"Executor CPU Time":11214000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":20,"Index":8,"Attempt":0,"Launch Time":1515493279215,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279226,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":10662,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":4905000,"Value":147094000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":5,"Value":197,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1700000,"Value":37927000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":51,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1700000,"Executor Run Time":5,"Executor CPU Time":4905000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":21,"Index":9,"Attempt":0,"Launch Time":1515493279218,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279232,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":11796,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":7850000,"Value":154944000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":8,"Value":205,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2186000,"Value":40113000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":54,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":2186000,"Executor Run Time":8,"Executor CPU Time":7850000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :29","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :29","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:29)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1515493279071,"Completion Time":1515493279232,"Accumulables":[{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Value":23,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Value":40113000,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Value":11796,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Value":138,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Value":322,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Value":54,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Value":2832,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Value":7,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Value":154944000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Value":205,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Value":3,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Value":0,"Internal":true,"Count Failed Values":true},{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Value":10,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerJobEnd","Job ID":0,"Completion Time":1515493279237,"Job Result":{"Result":"JobSucceeded"}} +{"Event":"SparkListenerApplicationEnd","Timestamp":1515493477606} diff --git a/core/src/test/resources/spark-events/application_1516285256255_0012 b/core/src/test/resources/spark-events/application_1516285256255_0012 new file mode 100755 index 0000000000000..3e1736c3fe224 --- /dev/null +++ b/core/src/test/resources/spark-events/application_1516285256255_0012 @@ -0,0 +1,71 @@ +{"Event":"SparkListenerLogStart","Spark Version":"2.3.0-SNAPSHOT"} +{"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre","Java Version":"1.8.0_161 (Oracle Corporation)","Scala Version":"version 2.11.8"},"Spark Properties":{"spark.blacklist.enabled":"true","spark.driver.host":"apiros-1.gce.test.com","spark.eventLog.enabled":"true","spark.driver.port":"33058","spark.repl.class.uri":"spark://apiros-1.gce.test.com:33058/classes","spark.jars":"","spark.repl.class.outputDir":"/tmp/spark-6781fb17-e07a-4b32-848b-9936c2e88b33/repl-c0fd7008-04be-471e-a173-6ad3e62d53d7","spark.app.name":"Spark shell","spark.blacklist.stage.maxFailedExecutorsPerNode":"1","spark.scheduler.mode":"FIFO","spark.executor.instances":"8","spark.ui.showConsoleProgress":"true","spark.blacklist.stage.maxFailedTasksPerExecutor":"1","spark.executor.id":"driver","spark.submit.deployMode":"client","spark.master":"yarn","spark.ui.filters":"org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter","spark.executor.memory":"2G","spark.home":"/github/spark","spark.sql.catalogImplementation":"hive","spark.driver.appUIAddress":"http://apiros-1.gce.test.com:4040","spark.blacklist.application.maxFailedTasksPerExecutor":"10","spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.param.PROXY_HOSTS":"apiros-1.gce.test.com","spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.param.PROXY_URI_BASES":"http://apiros-1.gce.test.com:8088/proxy/application_1516285256255_0012","spark.app.id":"application_1516285256255_0012"},"System Properties":{"java.io.tmpdir":"/tmp","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle Corporation","java.vm.specification.version":"1.8","user.home":"*********(redacted)","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","sun.arch.data.model":"64","sun.boot.library.path":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/amd64","user.dir":"*********(redacted)","java.library.path":"/usr/java/packages/lib/amd64:/usr/lib64:/lib64:/lib:/usr/lib","sun.cpu.isalist":"","os.arch":"amd64","java.vm.version":"25.161-b14","java.endorsed.dirs":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/endorsed","java.runtime.version":"1.8.0_161-b14","java.vm.info":"mixed mode","java.ext.dirs":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/ext:/usr/java/packages/lib/ext","java.runtime.name":"OpenJDK Runtime Environment","file.separator":"/","java.class.version":"52.0","scala.usejavacp":"true","java.specification.name":"Java Platform API Specification","sun.boot.class.path":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/resources.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/rt.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/sunrsasign.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/jsse.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/jce.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/charsets.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/jfr.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/classes","file.encoding":"UTF-8","user.timezone":"*********(redacted)","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"3.10.0-693.5.2.el7.x86_64","sun.os.patch.level":"unknown","java.vm.specification.vendor":"Oracle Corporation","user.country":"*********(redacted)","sun.jnu.encoding":"UTF-8","user.language":"*********(redacted)","java.vendor.url":"*********(redacted)","java.awt.printerjob":"sun.print.PSPrinterJob","java.awt.graphicsenv":"sun.awt.X11GraphicsEnvironment","awt.toolkit":"sun.awt.X11.XToolkit","os.name":"Linux","java.vm.vendor":"Oracle Corporation","java.vendor.url.bug":"*********(redacted)","user.name":"*********(redacted)","java.vm.name":"OpenJDK 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --master yarn --deploy-mode client --conf spark.blacklist.stage.maxFailedTasksPerExecutor=1 --conf spark.blacklist.enabled=true --conf spark.blacklist.application.maxFailedTasksPerExecutor=10 --conf spark.blacklist.stage.maxFailedExecutorsPerNode=1 --conf spark.eventLog.enabled=true --class org.apache.spark.repl.Main --name Spark shell --executor-memory 2G --num-executors 8 spark-shell","java.home":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre","java.version":"1.8.0_161","sun.io.unicode.encoding":"UnicodeLittle"},"Classpath Entries":{"/github/spark/assembly/target/scala-2.11/jars/validation-api-1.1.0.Final.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/arrow-vector-0.8.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-io-2.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.servlet-api-3.1.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-hive_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-parser-combinators_2.11-1.0.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/stax-api-1.0-2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/json4s-ast_2.11-3.2.11.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/apache-log4j-extras-1.2.17.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hive-metastore-1.2.1.spark2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/avro-1.7.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/core-1.1.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-common-2.22.2.jar":"System Classpath","/github/spark/conf/":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/metrics-json-3.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/protobuf-java-2.5.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/aircompressor-0.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/stax-api-1.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/leveldbjni-all-1.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/snappy-java-1.1.2.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/curator-recipes-2.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-core-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/arrow-format-0.8.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/ivy-2.4.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/libthrift-0.9.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-lang-2.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-sketch_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-tags_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-common-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/slf4j-api-1.7.16.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-server-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/stringtemplate-3.2.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/pyrolite-4.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-crypto-1.0.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/datanucleus-api-jdo-3.2.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-net-2.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-annotations-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/orc-core-1.4.1-nohive.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spire_2.11-0.13.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/arrow-memory-0.8.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/log4j-1.2.17.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-core-asl-1.9.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scalap-2.11.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-xml_2.11-1.0.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/JavaEWAH-0.3.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/bcprov-jdk15on-1.58.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-reflect-2.11.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-sql_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javolution-5.5.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/libfb303-0.9.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-media-jaxb-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jodd-core-3.5.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/janino-3.0.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-unsafe_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/antlr4-runtime-4.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/snappy-0.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/guice-3.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/java-xmlbuilder-1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/chill_2.11-0.8.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/apacheds-kerberos-codec-2.0.0-M15.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/stream-2.7.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/ST4-4.0.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/datanucleus-core-3.2.10.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-api-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/guice-servlet-3.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/avro-mapred-1.7.7-hadoop2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hive-exec-1.2.1.spark2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-beanutils-1.7.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jetty-6.1.26.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-server-common-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-configuration-1.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/minlog-1.3.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/base64-2.3.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/slf4j-log4j12-1.7.16.jar":"System Classpath","/etc/hadoop/conf/":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-httpclient-3.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-mapper-asl-1.9.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-yarn_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-repl_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spire-macros_2.11-0.13.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-client-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-jaxrs-1.9.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/apacheds-i18n-2.0.0-M15.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-cli-1.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.annotation-api-1.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/lz4-java-1.4.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-mllib-local_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-compress-1.4.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/breeze-macros_2.11-0.13.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-module-scala_2.11-2.6.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/curator-framework-2.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/curator-client-2.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/netty-3.9.9.Final.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/calcite-avatica-1.2.0-incubating.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-annotations-2.6.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/machinist_2.11-0.6.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jaxb-api-2.2.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/api-asn1-api-1.0.0-M20.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/calcite-linq4j-1.2.0-incubating.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-network-common_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-auth-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/orc-mapreduce-1.4.1-nohive.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-common-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-common-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/xercesImpl-2.9.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hppc-0.7.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-beanutils-core-1.8.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-math3-3.4.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-core_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-library-2.11.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-app-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-hadoop-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-catalyst_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/metrics-jvm-3.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-compiler-2.11.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/objenesis-2.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/shapeless_2.11-2.3.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/activation-1.1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/py4j-0.10.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-core-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/zookeeper-3.4.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-hadoop-bundle-1.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/antlr-runtime-3.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-mllib_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/oro-2.0.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/eigenbase-properties-1.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-graphx_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hk2-locator-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.ws.rs-api-2.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/aopalliance-repackaged-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-network-shuffle_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-format-2.3.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-launcher_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-shuffle-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/paranamer-2.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jta-1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/derby-10.12.1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/xz-1.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-client-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-logging-1.1.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-pool-1.5.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-streaming_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javassist-3.18.1-GA.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/guava-14.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/xmlenc-0.52.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/htrace-core-3.0.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.inject-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/httpclient-4.5.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-databind-2.6.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-column-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/zstd-jni-1.3.2-2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-server-web-proxy-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-kvstore_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-encoding-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/univocity-parsers-2.5.9.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/compress-lzf-1.0.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-collections-3.2.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-jobclient-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/osgi-resource-locator-1.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-client-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/chill-java-0.8.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/antlr-2.7.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hk2-utils-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/RoaringBitmap-0.5.11.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jul-to-slf4j-1.7.16.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/xbean-asm5-shaded-4.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/datanucleus-rdbms-3.2.9.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/arpack_combined_all-0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hk2-api-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/metrics-graphite-3.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-common-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-hdfs-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.inject-1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/opencsv-2.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/api-util-1.0.0-M20.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jdo-api-3.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-module-paranamer-2.7.9.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/kryo-shaded-3.0.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-dbcp-1.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/netty-all-4.1.17.Final.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-jackson-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/gson-2.2.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/calcite-core-1.2.0-incubating.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/macro-compat_2.11-1.1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/flatbuffers-1.2.0-3f79e055.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/json4s-core_2.11-3.2.11.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/breeze_2.11-0.13.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-digester-1.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jsr305-1.3.9.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jtransforms-2.4.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jets3t-0.9.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-core-2.6.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-xc-1.9.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/aopalliance-1.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/bonecp-0.8.0.RELEASE.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jetty-util-6.1.26.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/joda-time-2.9.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/json4s-jackson_2.11-3.2.11.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/metrics-core-3.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jcl-over-slf4j-1.7.16.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/httpcore-4.4.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-lang3-3.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-guava-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-codec-1.10.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-compiler-3.0.8.jar":"System Classpath"}} +{"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"application_1516285256255_0012","Timestamp":1516300235119,"User":"attilapiros"} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300252095,"Executor ID":"2","Executor Info":{"Host":"apiros-3.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stdout?start=-4096","stderr":"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"2","Host":"apiros-3.gce.test.com","Port":38670},"Maximum Memory":956615884,"Timestamp":1516300252260,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300252715,"Executor ID":"3","Executor Info":{"Host":"apiros-2.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000004/attilapiros/stdout?start=-4096","stderr":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000004/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300252918,"Executor ID":"1","Executor Info":{"Host":"apiros-3.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stdout?start=-4096","stderr":"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"3","Host":"apiros-2.gce.test.com","Port":38641},"Maximum Memory":956615884,"Timestamp":1516300252959,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"1","Host":"apiros-3.gce.test.com","Port":34970},"Maximum Memory":956615884,"Timestamp":1516300252988,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300253542,"Executor ID":"4","Executor Info":{"Host":"apiros-2.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000005/attilapiros/stdout?start=-4096","stderr":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000005/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"4","Host":"apiros-2.gce.test.com","Port":33229},"Maximum Memory":956615884,"Timestamp":1516300253653,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300254323,"Executor ID":"5","Executor Info":{"Host":"apiros-2.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000007/attilapiros/stdout?start=-4096","stderr":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000007/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"5","Host":"apiros-2.gce.test.com","Port":45147},"Maximum Memory":956615884,"Timestamp":1516300254385,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerJobStart","Job ID":0,"Submission Time":1516300392631,"Stage Infos":[{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :27","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :27","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :27","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:27)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]},{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :30","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :30","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:30)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]}],"Stage IDs":[0,1],"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :27","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :27","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :27","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:27)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1516300392658,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1516300392816,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1516300392832,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1516300392832,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":3,"Index":3,"Attempt":0,"Launch Time":1516300392833,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":4,"Index":4,"Attempt":0,"Launch Time":1516300392833,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":5,"Index":5,"Attempt":0,"Launch Time":1516300394320,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":6,"Index":6,"Attempt":0,"Launch Time":1516300394323,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklistedForStage","time":1516300394348,"executorId":"5","taskFailures":1,"stageId":0,"stageAttemptId":0} +{"Event":"org.apache.spark.scheduler.SparkListenerNodeBlacklistedForStage","time":1516300394348,"hostId":"apiros-2.gce.test.com","executorFailures":1,"stageId":0,"stageAttemptId":0} +{"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklistedForStage","time":1516300394356,"executorId":"4","taskFailures":1,"stageId":0,"stageAttemptId":0} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":28},{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"1332","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":5,"Update":"33","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"3075188","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1516300392832,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394338,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3075188,"Value":3075188,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":33,"Value":33,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1332,"Value":1332,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":1332,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":33,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":3075188,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":28},{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"1184","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":5,"Update":"82","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"16858066","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":4,"Index":4,"Attempt":0,"Launch Time":1516300392833,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394355,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":16858066,"Value":19933254,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":82,"Value":115,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1184,"Value":2516,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":1184,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":82,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":16858066,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":28},{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"51","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"183718","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":6,"Index":6,"Attempt":0,"Launch Time":1516300394323,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394390,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":183718,"Value":20116972,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":51,"Value":2567,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":51,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":183718,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":28},{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"27","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"191901","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":5,"Index":5,"Attempt":0,"Launch Time":1516300394320,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394393,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":191901,"Value":20308873,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":27,"Value":2594,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":27,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":191901,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1516300392832,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394606,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3322956,"Value":23631829,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":144,"Value":144,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":1080,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":78,"Value":193,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1134,"Value":1134,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":278399617,"Value":278399617,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":493,"Value":3087,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":263386625,"Value":263386625,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":1206,"Value":1206,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1206,"Executor Deserialize CPU Time":263386625,"Executor Run Time":493,"Executor CPU Time":278399617,"Result Size":1134,"JVM GC Time":78,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":144,"Shuffle Write Time":3322956,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":7,"Index":5,"Attempt":1,"Launch Time":1516300394859,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":3,"Index":3,"Attempt":0,"Launch Time":1516300392833,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394860,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3587839,"Value":27219668,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":6,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":291,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":2160,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":102,"Value":295,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1134,"Value":2268,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":349920830,"Value":628320447,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":681,"Value":3768,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":365807898,"Value":629194523,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":1282,"Value":2488,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1282,"Executor Deserialize CPU Time":365807898,"Executor Run Time":681,"Executor CPU Time":349920830,"Result Size":1134,"JVM GC Time":102,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":3587839,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":8,"Index":6,"Attempt":1,"Launch Time":1516300394879,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1516300392816,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394880,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3662221,"Value":30881889,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":9,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":144,"Value":435,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":3240,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":75,"Value":370,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1134,"Value":3402,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":368865439,"Value":997185886,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":914,"Value":4682,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":353981050,"Value":983175573,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":1081,"Value":3569,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1081,"Executor Deserialize CPU Time":353981050,"Executor Run Time":914,"Executor CPU Time":368865439,"Result Size":1134,"JVM GC Time":75,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":144,"Shuffle Write Time":3662221,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":9,"Index":4,"Attempt":1,"Launch Time":1516300394973,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":7,"Index":5,"Attempt":1,"Launch Time":1516300394859,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394974,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":377601,"Value":31259490,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":12,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":582,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":4320,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":4450,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":28283110,"Value":1025468996,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":84,"Value":4766,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":10894331,"Value":994069904,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":11,"Value":3580,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":11,"Executor Deserialize CPU Time":10894331,"Executor Run Time":84,"Executor CPU Time":28283110,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":377601,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":10,"Index":1,"Attempt":1,"Launch Time":1516300395069,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":9,"Index":4,"Attempt":1,"Launch Time":1516300394973,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395069,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":366050,"Value":31625540,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":15,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":729,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":5400,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1091,"Value":5541,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":25678331,"Value":1051147327,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":48,"Value":4814,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4793905,"Value":998863809,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":5,"Value":3585,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":5,"Executor Deserialize CPU Time":4793905,"Executor Run Time":48,"Executor CPU Time":25678331,"Result Size":1091,"JVM GC Time":0,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":366050,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":11,"Index":7,"Attempt":0,"Launch Time":1516300395072,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":8,"Index":6,"Attempt":1,"Launch Time":1516300394879,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395073,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":311940,"Value":31937480,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":18,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":876,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":6480,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":6589,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":27304550,"Value":1078451877,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":54,"Value":4868,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":12246145,"Value":1011109954,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":56,"Value":3641,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":56,"Executor Deserialize CPU Time":12246145,"Executor Run Time":54,"Executor CPU Time":27304550,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":311940,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":12,"Index":8,"Attempt":0,"Launch Time":1516300395165,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":11,"Index":7,"Attempt":0,"Launch Time":1516300395072,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395165,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":323898,"Value":32261378,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":21,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":1023,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":7560,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":7637,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":21689428,"Value":1100141305,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":77,"Value":4945,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4239884,"Value":1015349838,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":3645,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":4239884,"Executor Run Time":77,"Executor CPU Time":21689428,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":323898,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":13,"Index":9,"Attempt":0,"Launch Time":1516300395200,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":10,"Index":1,"Attempt":1,"Launch Time":1516300395069,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395201,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":301705,"Value":32563083,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":24,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":144,"Value":1167,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":8640,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":5,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1091,"Value":8728,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":20826337,"Value":1120967642,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":76,"Value":5021,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4598966,"Value":1019948804,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":5,"Value":3650,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":5,"Executor Deserialize CPU Time":4598966,"Executor Run Time":76,"Executor CPU Time":20826337,"Result Size":1091,"JVM GC Time":0,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":144,"Shuffle Write Time":301705,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":12,"Index":8,"Attempt":0,"Launch Time":1516300395165,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395225,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":319101,"Value":32882184,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":27,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":1314,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":9720,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":9776,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":21657558,"Value":1142625200,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":34,"Value":5055,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4010338,"Value":1023959142,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":3654,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":4010338,"Executor Run Time":34,"Executor CPU Time":21657558,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":319101,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":13,"Index":9,"Attempt":0,"Launch Time":1516300395200,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395276,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":369513,"Value":33251697,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":30,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":1461,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":10800,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":10824,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":20585619,"Value":1163210819,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":25,"Value":5080,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":5860574,"Value":1029819716,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":25,"Value":3679,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":25,"Executor Deserialize CPU Time":5860574,"Executor Run Time":25,"Executor CPU Time":20585619,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":369513,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :27","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :27","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :27","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:27)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1516300392658,"Completion Time":1516300395279,"Accumulables":[{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Value":5080,"Internal":true,"Count Failed Values":true},{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Value":33251697,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Value":370,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Value":10824,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Value":1029819716,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Value":30,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Value":10800,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Value":1461,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Value":1163210819,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Value":5,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Value":3679,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :30","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :30","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:30)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1516300395292,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":14,"Index":0,"Attempt":0,"Launch Time":1516300395302,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":15,"Index":1,"Attempt":0,"Launch Time":1516300395303,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":16,"Index":3,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":17,"Index":4,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":18,"Index":5,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":19,"Index":6,"Attempt":0,"Launch Time":1516300395525,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":17,"Index":4,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395525,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":1134,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":52455999,"Value":52455999,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":95,"Value":95,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":23136577,"Value":23136577,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":82,"Value":82,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":82,"Executor Deserialize CPU Time":23136577,"Executor Run Time":95,"Executor CPU Time":52455999,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":20,"Index":7,"Attempt":0,"Launch Time":1516300395575,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":19,"Index":6,"Attempt":0,"Launch Time":1516300395525,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395576,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":2268,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":13617615,"Value":66073614,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":29,"Value":124,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3469612,"Value":26606189,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":86,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":3469612,"Executor Run Time":29,"Executor CPU Time":13617615,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":21,"Index":8,"Attempt":0,"Launch Time":1516300395581,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":18,"Index":5,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395581,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":3402,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":55540208,"Value":121613822,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":179,"Value":303,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":22400065,"Value":49006254,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":78,"Value":164,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":78,"Executor Deserialize CPU Time":22400065,"Executor Run Time":179,"Executor CPU Time":55540208,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":22,"Index":9,"Attempt":0,"Launch Time":1516300395593,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":16,"Index":3,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395593,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":4536,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":52311573,"Value":173925395,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":153,"Value":456,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":20519033,"Value":69525287,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":67,"Value":231,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":67,"Executor Deserialize CPU Time":20519033,"Executor Run Time":153,"Executor CPU Time":52311573,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":20,"Index":7,"Attempt":0,"Launch Time":1516300395575,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395660,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":5670,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":11294260,"Value":185219655,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":33,"Value":489,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3570887,"Value":73096174,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":235,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":3570887,"Executor Run Time":33,"Executor CPU Time":11294260,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":22,"Index":9,"Attempt":0,"Launch Time":1516300395593,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395669,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":6804,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":12983732,"Value":198203387,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":44,"Value":533,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3518757,"Value":76614931,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":239,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":3518757,"Executor Run Time":44,"Executor CPU Time":12983732,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":21,"Index":8,"Attempt":0,"Launch Time":1516300395581,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395674,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":7938,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":14706240,"Value":212909627,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":64,"Value":597,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":7698059,"Value":84312990,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":21,"Value":260,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":21,"Executor Deserialize CPU Time":7698059,"Executor Run Time":64,"Executor CPU Time":14706240,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":23,"Index":2,"Attempt":0,"Launch Time":1516300395686,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":14,"Index":0,"Attempt":0,"Launch Time":1516300395302,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395687,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":10,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":52,"Value":52,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":195,"Value":195,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":292,"Value":292,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":4,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":6,"Value":6,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":944,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":9224,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":91696783,"Value":304606410,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":221,"Value":818,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":24063461,"Value":108376451,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":150,"Value":410,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":150,"Executor Deserialize CPU Time":24063461,"Executor Run Time":221,"Executor CPU Time":91696783,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":6,"Local Blocks Fetched":4,"Fetch Wait Time":52,"Remote Bytes Read":292,"Remote Bytes Read To Disk":0,"Local Bytes Read":195,"Total Records Read":10},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":15,"Index":1,"Attempt":0,"Launch Time":1516300395303,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395687,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":10,"Value":20,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":107,"Value":159,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":244,"Value":439,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":243,"Value":535,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":5,"Value":9,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":5,"Value":11,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":1888,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":10510,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":91683507,"Value":396289917,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":289,"Value":1107,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":22106726,"Value":130483177,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":79,"Value":489,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":79,"Executor Deserialize CPU Time":22106726,"Executor Run Time":289,"Executor CPU Time":91683507,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":5,"Local Blocks Fetched":5,"Fetch Wait Time":107,"Remote Bytes Read":243,"Remote Bytes Read To Disk":0,"Local Bytes Read":244,"Total Records Read":10},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":23,"Index":2,"Attempt":0,"Launch Time":1516300395686,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395728,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":10,"Value":30,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":159,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":195,"Value":634,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":292,"Value":827,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":4,"Value":13,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":6,"Value":17,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":11796,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":17607810,"Value":413897727,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":33,"Value":1140,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2897647,"Value":133380824,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":491,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":2897647,"Executor Run Time":33,"Executor CPU Time":17607810,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":6,"Local Blocks Fetched":4,"Fetch Wait Time":0,"Remote Bytes Read":292,"Remote Bytes Read To Disk":0,"Local Bytes Read":195,"Total Records Read":10},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :30","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :30","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:30)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1516300395292,"Completion Time":1516300395728,"Accumulables":[{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Value":159,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Value":133380824,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Value":11796,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Value":827,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Value":634,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Value":491,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Value":2832,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Value":13,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Value":413897727,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Value":1140,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Value":17,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Value":0,"Internal":true,"Count Failed Values":true},{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Value":30,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerJobEnd","Job ID":0,"Completion Time":1516300395734,"Job Result":{"Result":"JobSucceeded"}} +{"Event":"SparkListenerApplicationEnd","Timestamp":1516300707938} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 7aa60f2b60796..87f12f303cd5e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -156,6 +156,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers "applications/local-1426533911241/1/stages/0/0/taskList", "stage task list from multi-attempt app json(2)" -> "applications/local-1426533911241/2/stages/0/0/taskList", + "blacklisting for stage" -> "applications/app-20180109111548-0000/stages/0/0", + "blacklisting node for stage" -> "applications/application_1516285256255_0012/stages/0/0", "rdd list storage json" -> "applications/local-1422981780767/storage/rdd", "executor node blacklisting" -> "applications/app-20161116163331-0000/executors", diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index cd1b7a9e5ab18..afebcdd7b9e31 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -92,7 +92,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M } def createTaskSetBlacklist(stageId: Int = 0): TaskSetBlacklist = { - new TaskSetBlacklist(conf, stageId, clock) + new TaskSetBlacklist(listenerBusMock, conf, stageId, stageAttemptId = 0, clock = clock) } test("executors can be blacklisted with only a few failures per stage") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala index 18981d5be2f94..6e2709dbe1e8b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala @@ -16,18 +16,32 @@ */ package org.apache.spark.scheduler +import org.mockito.Matchers.isA +import org.mockito.Mockito.{never, verify} +import org.scalatest.BeforeAndAfterEach +import org.scalatest.mockito.MockitoSugar + import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.config -import org.apache.spark.util.{ManualClock, SystemClock} +import org.apache.spark.util.ManualClock + +class TaskSetBlacklistSuite extends SparkFunSuite with BeforeAndAfterEach with MockitoSugar { -class TaskSetBlacklistSuite extends SparkFunSuite { + private var listenerBusMock: LiveListenerBus = _ + + override def beforeEach(): Unit = { + listenerBusMock = mock[LiveListenerBus] + super.beforeEach() + } test("Blacklisting tasks, executors, and nodes") { val conf = new SparkConf().setAppName("test").setMaster("local") .set(config.BLACKLIST_ENABLED.key, "true") val clock = new ManualClock + val attemptId = 0 + val taskSetBlacklist = new TaskSetBlacklist( + listenerBusMock, conf, stageId = 0, stageAttemptId = attemptId, clock = clock) - val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, clock = clock) clock.setTime(0) // We will mark task 0 & 1 failed on both executor 1 & 2. // We should blacklist all executors on that host, for all tasks for the stage. Note the API @@ -46,27 +60,53 @@ class TaskSetBlacklistSuite extends SparkFunSuite { val shouldBeBlacklisted = (executor == "exec1" && index == 0) assert(taskSetBlacklist.isExecutorBlacklistedForTask(executor, index) === shouldBeBlacklisted) } + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerExecutorBlacklistedForStage])) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) // Mark task 1 failed on exec1 -- this pushes the executor into the blacklist taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "exec1", index = 1, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + verify(listenerBusMock).post( + SparkListenerExecutorBlacklistedForStage(0, "exec1", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) + // Mark one task as failed on exec2 -- not enough for any further blacklisting yet. taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "exec2", index = 0, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec2")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) + // Mark another task as failed on exec2 -- now we blacklist exec2, which also leads to // blacklisting the entire node. taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "exec2", index = 1, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec2")) + verify(listenerBusMock).post( + SparkListenerExecutorBlacklistedForStage(0, "exec2", 2, 0, attemptId)) + assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock).post( + SparkListenerNodeBlacklistedForStage(0, "hostA", 2, 0, attemptId)) + // Make sure the blacklist has the correct per-task && per-executor responses, over a wider // range of inputs. for { @@ -81,6 +121,10 @@ class TaskSetBlacklistSuite extends SparkFunSuite { // intentional, it keeps it fast and is sufficient for usage in the scheduler. taskSetBlacklist.isExecutorBlacklistedForTask(executor, index) === (badExec && badIndex)) assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet(executor) === badExec) + if (badExec) { + verify(listenerBusMock).post( + SparkListenerExecutorBlacklistedForStage(0, executor, 2, 0, attemptId)) + } } } assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) @@ -110,7 +154,14 @@ class TaskSetBlacklistSuite extends SparkFunSuite { .set(config.MAX_TASK_ATTEMPTS_PER_NODE, 3) .set(config.MAX_FAILURES_PER_EXEC_STAGE, 2) .set(config.MAX_FAILED_EXEC_PER_NODE_STAGE, 3) - val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, new SystemClock()) + val clock = new ManualClock + + val attemptId = 0 + val taskSetBlacklist = new TaskSetBlacklist( + listenerBusMock, conf, stageId = 0, stageAttemptId = attemptId, clock = clock) + + var time = 0 + clock.setTime(time) // Fail a task twice on hostA, exec:1 taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "1", index = 0, failureReason = "testing") @@ -118,37 +169,75 @@ class TaskSetBlacklistSuite extends SparkFunSuite { "hostA", exec = "1", index = 0, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTask("1", 0)) assert(!taskSetBlacklist.isNodeBlacklistedForTask("hostA", 0)) + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + verify(listenerBusMock, never()).post( + SparkListenerExecutorBlacklistedForStage(time, "1", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()).post( + SparkListenerNodeBlacklistedForStage(time, "hostA", 2, 0, attemptId)) // Fail the same task once more on hostA, exec:2 + time += 1 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "2", index = 0, failureReason = "testing") assert(taskSetBlacklist.isNodeBlacklistedForTask("hostA", 0)) + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) + verify(listenerBusMock, never()).post( + SparkListenerExecutorBlacklistedForStage(time, "2", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()).post( + SparkListenerNodeBlacklistedForStage(time, "hostA", 2, 0, attemptId)) // Fail another task on hostA, exec:1. Now that executor has failures on two different tasks, // so its blacklisted + time += 1 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "1", index = 1, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + verify(listenerBusMock) + .post(SparkListenerExecutorBlacklistedForStage(time, "1", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) // Fail a third task on hostA, exec:2, so that exec is blacklisted for the whole task set + time += 1 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "2", index = 2, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) + verify(listenerBusMock) + .post(SparkListenerExecutorBlacklistedForStage(time, "2", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) // Fail a fourth & fifth task on hostA, exec:3. Now we've got three executors that are // blacklisted for the taskset, so blacklist the whole node. + time += 1 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "3", index = 3, failureReason = "testing") taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "3", index = 4, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("3")) + verify(listenerBusMock) + .post(SparkListenerExecutorBlacklistedForStage(time, "3", 2, 0, attemptId)) + assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock).post( + SparkListenerNodeBlacklistedForStage(time, "hostA", 3, 0, attemptId)) } test("only blacklist nodes for the task set when all the blacklisted executors are all on " + @@ -157,22 +246,42 @@ class TaskSetBlacklistSuite extends SparkFunSuite { // lead to any node blacklisting val conf = new SparkConf().setAppName("test").setMaster("local") .set(config.BLACKLIST_ENABLED.key, "true") - val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, new SystemClock()) + val clock = new ManualClock + + val attemptId = 0 + val taskSetBlacklist = new TaskSetBlacklist( + listenerBusMock, conf, stageId = 0, stageAttemptId = attemptId, clock = clock) + var time = 0 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "1", index = 0, failureReason = "testing") taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "1", index = 1, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + verify(listenerBusMock) + .post(SparkListenerExecutorBlacklistedForStage(time, "1", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()).post( + SparkListenerNodeBlacklistedForStage(time, "hostA", 2, 0, attemptId)) + time += 1 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostB", exec = "2", index = 0, failureReason = "testing") taskSetBlacklist.updateBlacklistForFailedTask( "hostB", exec = "2", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) + verify(listenerBusMock) + .post(SparkListenerExecutorBlacklistedForStage(time, "2", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostB")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) } } diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index e7981bec6d64b..042bba7f226fd 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -251,6 +251,49 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } + // Blacklisting executor for stage + time += 1 + listener.onExecutorBlacklistedForStage(SparkListenerExecutorBlacklistedForStage( + time = time, + executorId = execIds.head, + taskFailures = 2, + stageId = stages.head.stageId, + stageAttemptId = stages.head.attemptId)) + + val executorStageSummaryWrappers = + store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") + .first(key(stages.head)) + .last(key(stages.head)) + .asScala.toSeq + + assert(executorStageSummaryWrappers.nonEmpty) + executorStageSummaryWrappers.foreach { exec => + // only the first executor is expected to be blacklisted + val expectedBlacklistedFlag = exec.executorId == execIds.head + assert(exec.info.isBlacklistedForStage === expectedBlacklistedFlag) + } + + // Blacklisting node for stage + time += 1 + listener.onNodeBlacklistedForStage(SparkListenerNodeBlacklistedForStage( + time = time, + hostId = "2.example.com", // this is where the second executor is hosted + executorFailures = 1, + stageId = stages.head.stageId, + stageAttemptId = stages.head.attemptId)) + + val executorStageSummaryWrappersForNode = + store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") + .first(key(stages.head)) + .last(key(stages.head)) + .asScala.toSeq + + assert(executorStageSummaryWrappersForNode.nonEmpty) + executorStageSummaryWrappersForNode.foreach { exec => + // both executor is expected to be blacklisted + assert(exec.info.isBlacklistedForStage === true) + } + // Fail one of the tasks, re-start it. time += 1 s1Tasks.head.markFinished(TaskState.FAILED, time) diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 607234b4068d0..243fbe3e1bc24 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -73,8 +73,10 @@ logs .*dependency-reduced-pom.xml known_translations json_expectation +app-20180109111548-0000 app-20161115172038-0000 app-20161116163331-0000 +application_1516285256255_0012 local-1422981759269 local-1422981780767 local-1425081759269 From e18d6f5326e0d9ea03d31de5ce04cb84d3b8ab37 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Wed, 24 Jan 2018 09:37:54 -0800 Subject: [PATCH 0192/2461] [SPARK-20906][SPARKR] Add API doc example for Constrained Logistic Regression ## What changes were proposed in this pull request? doc only changes ## How was this patch tested? manual Author: Felix Cheung Closes #20380 from felixcheung/rclrdoc. --- R/pkg/R/mllib_classification.R | 15 ++++++++++++++- R/pkg/tests/fulltests/test_mllib_classification.R | 10 +++++----- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index 7cd072a1d6f89..f6e9b1357561b 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -279,11 +279,24 @@ function(object, path, overwrite = FALSE) { #' savedModel <- read.ml(path) #' summary(savedModel) #' -#' # multinomial logistic regression +#' # binary logistic regression against two classes with +#' # upperBoundsOnCoefficients and upperBoundsOnIntercepts +#' ubc <- matrix(c(1.0, 0.0, 1.0, 0.0), nrow = 1, ncol = 4) +#' model <- spark.logit(training, Species ~ ., +#' upperBoundsOnCoefficients = ubc, +#' upperBoundsOnIntercepts = 1.0) #' +#' # multinomial logistic regression #' model <- spark.logit(training, Class ~ ., regParam = 0.5) #' summary <- summary(model) #' +#' # multinomial logistic regression with +#' # lowerBoundsOnCoefficients and lowerBoundsOnIntercepts +#' lbc <- matrix(c(0.0, -1.0, 0.0, -1.0, 0.0, -1.0, 0.0, -1.0), nrow = 2, ncol = 4) +#' lbi <- as.array(c(0.0, 0.0)) +#' model <- spark.logit(training, Species ~ ., family = "multinomial", +#' lowerBoundsOnCoefficients = lbc, +#' lowerBoundsOnIntercepts = lbi) #' } #' @note spark.logit since 2.1.0 setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"), diff --git a/R/pkg/tests/fulltests/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R index ad47717ddc12f..a46c47dccd02e 100644 --- a/R/pkg/tests/fulltests/test_mllib_classification.R +++ b/R/pkg/tests/fulltests/test_mllib_classification.R @@ -124,7 +124,7 @@ test_that("spark.logit", { # Petal.Width 0.42122607 # nolint end - # Test multinomial logistic regression againt three classes + # Test multinomial logistic regression against three classes df <- suppressWarnings(createDataFrame(iris)) model <- spark.logit(df, Species ~ ., regParam = 0.5) summary <- summary(model) @@ -196,7 +196,7 @@ test_that("spark.logit", { # # nolint end - # Test multinomial logistic regression againt two classes + # Test multinomial logistic regression against two classes df <- suppressWarnings(createDataFrame(iris)) training <- df[df$Species %in% c("versicolor", "virginica"), ] model <- spark.logit(training, Species ~ ., regParam = 0.5, family = "multinomial") @@ -208,7 +208,7 @@ test_that("spark.logit", { expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1)) expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1)) - # Test binomial logistic regression againt two classes + # Test binomial logistic regression against two classes model <- spark.logit(training, Species ~ ., regParam = 0.5) summary <- summary(model) coefsR <- c(-6.08, 0.25, 0.16, 0.48, 1.04) @@ -239,7 +239,7 @@ test_that("spark.logit", { prediction2 <- collect(select(predict(model2, df2), "prediction")) expect_equal(sort(prediction2$prediction), c("0.0", "0.0", "0.0", "0.0", "0.0")) - # Test binomial logistic regression againt two classes with upperBoundsOnCoefficients + # Test binomial logistic regression against two classes with upperBoundsOnCoefficients # and upperBoundsOnIntercepts u <- matrix(c(1.0, 0.0, 1.0, 0.0), nrow = 1, ncol = 4) model <- spark.logit(training, Species ~ ., upperBoundsOnCoefficients = u, @@ -252,7 +252,7 @@ test_that("spark.logit", { expect_error(spark.logit(training, Species ~ ., upperBoundsOnCoefficients = as.array(c(1, 2)), upperBoundsOnIntercepts = 1.0)) - # Test binomial logistic regression againt two classes with lowerBoundsOnCoefficients + # Test binomial logistic regression against two classes with lowerBoundsOnCoefficients # and lowerBoundsOnIntercepts l <- matrix(c(0.0, -1.0, 0.0, -1.0), nrow = 1, ncol = 4) model <- spark.logit(training, Species ~ ., lowerBoundsOnCoefficients = l, From 8c273b4162b6138c4abba64f595c2750d1ef8bcb Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 24 Jan 2018 10:00:42 -0800 Subject: [PATCH 0193/2461] [SPARK-23020][CORE][FOLLOWUP] Fix Java style check issues. ## What changes were proposed in this pull request? This is a follow-up of #20297 which broke lint-java checks. This pr fixes the lint-java issues. ``` [ERROR] src/test/java/org/apache/spark/launcher/BaseSuite.java:[21,8] (imports) UnusedImports: Unused import - java.util.concurrent.TimeUnit. [ERROR] src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java:[27,8] (imports) UnusedImports: Unused import - java.util.concurrent.TimeUnit. ``` ## How was this patch tested? Checked manually in my local environment. Author: Takuya UESHIN Closes #20376 from ueshin/issues/SPARK-23020/fup1. --- .../test/java/org/apache/spark/launcher/SparkLauncherSuite.java | 1 - launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java | 1 - 2 files changed, 2 deletions(-) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index a042375c6ae91..1543f4fdb0162 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -24,7 +24,6 @@ import java.util.List; import java.util.Map; import java.util.Properties; -import java.util.concurrent.TimeUnit; import org.junit.Test; import static org.junit.Assert.*; diff --git a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java index 3722a59d9438e..438349e027a24 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java @@ -18,7 +18,6 @@ package org.apache.spark.launcher; import java.time.Duration; -import java.util.concurrent.TimeUnit; import org.junit.After; import org.slf4j.bridge.SLF4JBridgeHandler; From bbb87b350d9d0d393db3fb7ca61dcbae538553bb Mon Sep 17 00:00:00 2001 From: zuotingbing Date: Wed, 24 Jan 2018 10:07:24 -0800 Subject: [PATCH 0194/2461] [SPARK-22837][SQL] Session timeout checker does not work in SessionManager. ## What changes were proposed in this pull request? Currently we do not call the `super.init(hiveConf)` in `SparkSQLSessionManager.init`. So we do not load the config `HIVE_SERVER2_SESSION_CHECK_INTERVAL HIVE_SERVER2_IDLE_SESSION_TIMEOUT HIVE_SERVER2_IDLE_SESSION_CHECK_OPERATION` , which cause the session timeout checker does not work. ## How was this patch tested? manual tests Author: zuotingbing Closes #20025 from zuotingbing/SPARK-22837. --- .../thriftserver/SparkSQLSessionManager.scala | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 48c0ebef3e0ce..2958b771f3648 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -40,22 +40,8 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, sqlContext: private lazy val sparkSqlOperationManager = new SparkSQLOperationManager() override def init(hiveConf: HiveConf) { - setSuperField(this, "hiveConf", hiveConf) - - // Create operation log root directory, if operation logging is enabled - if (hiveConf.getBoolVar(ConfVars.HIVE_SERVER2_LOGGING_OPERATION_ENABLED)) { - invoke(classOf[SessionManager], this, "initOperationLogRootDir") - } - - val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) - setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) - getAncestorField[Log](this, 3, "LOG").info( - s"HiveServer2: Async execution pool size $backgroundPoolSize") - setSuperField(this, "operationManager", sparkSqlOperationManager) - addService(sparkSqlOperationManager) - - initCompositeService(hiveConf) + super.init(hiveConf) } override def openSession( From 840dea64abd8a3a5960de830f19a57f5f1aa3bf6 Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Wed, 24 Jan 2018 13:13:44 -0500 Subject: [PATCH 0195/2461] [SPARK-23152][ML] - Correctly guard against empty datasets ## What changes were proposed in this pull request? Correctly guard against empty datasets in `org.apache.spark.ml.classification.Classifier` ## How was this patch tested? existing tests Author: Matthew Tovbin Closes #20321 from tovbinm/SPARK-23152. --- .../org/apache/spark/ml/classification/Classifier.scala | 2 +- .../apache/spark/ml/classification/ClassifierSuite.scala | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index bc0b49d48d323..9d1d5aa1e0cff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -109,7 +109,7 @@ abstract class Classifier[ case None => // Get number of classes from dataset itself. val maxLabelRow: Array[Row] = dataset.select(max($(labelCol))).take(1) - if (maxLabelRow.isEmpty) { + if (maxLabelRow.isEmpty || maxLabelRow(0).get(0) == null) { throw new SparkException("ML algorithm was given empty dataset.") } val maxDoubleLabel: Double = maxLabelRow.head.getDouble(0) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala index de712079329da..87bf2be06c2be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -90,6 +90,13 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { } assert(e.getMessage.contains("requires integers in range")) } + val df3 = getTestData(Seq.empty[Double]) + withClue("getNumClasses should fail if dataset is empty") { + val e: SparkException = intercept[SparkException] { + c.getNumClasses(df3) + } + assert(e.getMessage == "ML algorithm was given empty dataset.") + } } } From 0e178e1523175a0be9437920045e80deb0a2712b Mon Sep 17 00:00:00 2001 From: Mark Petruska Date: Wed, 24 Jan 2018 10:25:14 -0800 Subject: [PATCH 0196/2461] [SPARK-22297][CORE TESTS] Flaky test: BlockManagerSuite "Shuffle registration timeout and maxAttempts conf" ## What changes were proposed in this pull request? [Ticket](https://issues.apache.org/jira/browse/SPARK-22297) - one of the tests seems to produce unreliable results due to execution speed variability Since the original test was trying to connect to the test server with `40 ms` timeout, and the test server replied after `50 ms`, the error might be produced under the following conditions: - it might occur that the test server replies correctly after `50 ms` - but the client does only receive the timeout after `51 ms`s - this might happen if the executor has to schedule a big number of threads, and decides to delay the thread/actor that is responsible to watch the timeout, because of high CPU load - running an entire test suite usually produces high loads on the CPU executing the tests ## How was this patch tested? The test's check cases remain the same and the set-up emulates the previous version's. Author: Mark Petruska Closes #19671 from mpetruska/SPARK-22297. --- .../spark/storage/BlockManagerSuite.scala | 55 +++++++++++++------ 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 629eed49b04cc..b19d8ebf72c61 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.storage import java.nio.ByteBuffer import scala.collection.JavaConverters._ -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.concurrent.Future import scala.concurrent.duration._ @@ -44,8 +43,9 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, TempFileManager} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor} +import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} @@ -1325,9 +1325,18 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("SPARK-20640: Shuffle registration timeout and maxAttempts conf are working") { val tryAgainMsg = "test_spark_20640_try_again" + val timingoutExecutor = "timingoutExecutor" + val tryAgainExecutor = "tryAgainExecutor" + val succeedingExecutor = "succeedingExecutor" + // a server which delays response 50ms and must try twice for success. def newShuffleServer(port: Int): (TransportServer, Int) = { - val attempts = new mutable.HashMap[String, Int]() + val failure = new Exception(tryAgainMsg) + val success = ByteBuffer.wrap(new Array[Byte](0)) + + var secondExecutorFailedOnce = false + var thirdExecutorFailedOnce = false + val handler = new NoOpRpcHandler { override def receive( client: TransportClient, @@ -1335,15 +1344,26 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE callback: RpcResponseCallback): Unit = { val msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message) msgObj match { - case exec: RegisterExecutor => - Thread.sleep(50) - val attempt = attempts.getOrElse(exec.execId, 0) + 1 - attempts(exec.execId) = attempt - if (attempt < 2) { - callback.onFailure(new Exception(tryAgainMsg)) - return - } - callback.onSuccess(ByteBuffer.wrap(new Array[Byte](0))) + + case exec: RegisterExecutor if exec.execId == timingoutExecutor => + () // No reply to generate client-side timeout + + case exec: RegisterExecutor + if exec.execId == tryAgainExecutor && !secondExecutorFailedOnce => + secondExecutorFailedOnce = true + callback.onFailure(failure) + + case exec: RegisterExecutor if exec.execId == tryAgainExecutor => + callback.onSuccess(success) + + case exec: RegisterExecutor + if exec.execId == succeedingExecutor && !thirdExecutorFailedOnce => + thirdExecutorFailedOnce = true + callback.onFailure(failure) + + case exec: RegisterExecutor if exec.execId == succeedingExecutor => + callback.onSuccess(success) + } } } @@ -1352,6 +1372,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val transCtx = new TransportContext(transConf, handler, true) (transCtx.createServer(port, Seq.empty[TransportServerBootstrap].asJava), port) } + val candidatePort = RandomUtils.nextInt(1024, 65536) val (server, shufflePort) = Utils.startServiceOnPort(candidatePort, newShuffleServer, conf, "ShuffleServer") @@ -1360,21 +1381,21 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE conf.set("spark.shuffle.service.port", shufflePort.toString) conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "40") conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1") - var e = intercept[SparkException]{ - makeBlockManager(8000, "executor1") + var e = intercept[SparkException] { + makeBlockManager(8000, timingoutExecutor) }.getMessage assert(e.contains("TimeoutException")) conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000") conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1") - e = intercept[SparkException]{ - makeBlockManager(8000, "executor2") + e = intercept[SparkException] { + makeBlockManager(8000, tryAgainExecutor) }.getMessage assert(e.contains(tryAgainMsg)) conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000") conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "2") - makeBlockManager(8000, "executor3") + makeBlockManager(8000, succeedingExecutor) server.close() } From bc9641d9026aeae3571915b003ac971f6245d53c Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 24 Jan 2018 12:58:44 -0800 Subject: [PATCH 0197/2461] [SPARK-23198][SS][TEST] Fix KafkaContinuousSourceStressForDontFailOnDataLossSuite to test ContinuousExecution ## What changes were proposed in this pull request? Currently, `KafkaContinuousSourceStressForDontFailOnDataLossSuite` runs on `MicroBatchExecution`. It should test `ContinuousExecution`. ## How was this patch tested? Pass the updated test suite. Author: Dongjoon Hyun Closes #20374 from dongjoon-hyun/SPARK-23198. --- .../apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index b3dade414f625..a7083fa4e3417 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -91,6 +91,7 @@ class KafkaContinuousSourceStressForDontFailOnDataLossSuite ds.writeStream .format("memory") .queryName("memory") + .trigger(Trigger.Continuous("1 second")) .start() } } From 6f0ba8472d1128551fa8090deebcecde0daebc53 Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Wed, 24 Jan 2018 13:06:09 -0800 Subject: [PATCH 0198/2461] [MINOR][SQL] add new unit test to LimitPushdown MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR is repaired as follows 1、update y -> x in "left outer join" test case ,maybe is mistake. 2、add a new test case:"left outer join and left sides are limited" 3、add a new test case:"left outer join and right sides are limited" 4、add a new test case: "right outer join and right sides are limited" 5、add a new test case: "right outer join and left sides are limited" 6、Remove annotations without code implementation ## How was this patch tested? add new unit test case. Author: caoxuewen Closes #20381 from heary-cao/LimitPushdownSuite. --- .../sql/catalyst/optimizer/Optimizer.scala | 1 - .../optimizer/LimitPushdownSuite.scala | 30 ++++++++++++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0f9daa5f04c76..8d207708c12ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -352,7 +352,6 @@ object LimitPushDown extends Rule[LogicalPlan] { // on both sides if it is applied multiple times. Therefore: // - If one side is already limited, stack another limit on top if the new limit is smaller. // The redundant limit will be collapsed by the CombineLimits rule. - // - If neither side is limited, limit the side that is estimated to be bigger. case LocalLimit(exp, join @ Join(left, right, joinType, _)) => val newJoin = joinType match { case RightOuter => join.copy(right = maybePushLocalLimit(exp, right)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index cc98d2350c777..17fb9fc5d11e3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -93,7 +93,21 @@ class LimitPushdownSuite extends PlanTest { test("left outer join") { val originalQuery = x.join(y, LeftOuter).limit(1) val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = Limit(1, LocalLimit(1, y).join(y, LeftOuter)).analyze + val correctAnswer = Limit(1, LocalLimit(1, x).join(y, LeftOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("left outer join and left sides are limited") { + val originalQuery = x.limit(2).join(y, LeftOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, x).join(y, LeftOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("left outer join and right sides are limited") { + val originalQuery = x.join(y.limit(2), LeftOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, x).join(Limit(2, y), LeftOuter)).analyze comparePlans(optimized, correctAnswer) } @@ -104,6 +118,20 @@ class LimitPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("right outer join and right sides are limited") { + val originalQuery = x.join(y.limit(2), RightOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, x.join(LocalLimit(1, y), RightOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("right outer join and left sides are limited") { + val originalQuery = x.limit(2).join(y, RightOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, Limit(2, x).join(LocalLimit(1, y), RightOuter)).analyze + comparePlans(optimized, correctAnswer) + } + test("larger limits are not pushed on top of smaller ones in right outer join") { val originalQuery = x.join(y.limit(5), RightOuter).limit(10) val optimized = Optimize.execute(originalQuery.analyze) From 45b4bbfddc18a77011c3bc1bfd71b2cd3466443c Mon Sep 17 00:00:00 2001 From: zhoukang Date: Thu, 25 Jan 2018 15:24:52 +0800 Subject: [PATCH 0199/2461] [SPARK-23129][CORE] Make deserializeStream of DiskMapIterator init lazily ## What changes were proposed in this pull request? Currently,the deserializeStream in ExternalAppendOnlyMap#DiskMapIterator init when DiskMapIterator instance created.This will cause memory use overhead when ExternalAppendOnlyMap spill too much times. We can avoid this by making deserializeStream init when it is used the first time. This patch make deserializeStream init lazily. ## How was this patch tested? Exist tests Author: zhoukang Closes #20292 from caneGuy/zhoukang/lay-diskmapiterator. --- .../util/collection/ExternalAppendOnlyMap.scala | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 375f4a6921225..5c6dd45ec58e3 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -463,7 +463,7 @@ class ExternalAppendOnlyMap[K, V, C]( // An intermediate stream that reads from exactly one batch // This guards against pre-fetching and other arbitrary behavior of higher level streams - private var deserializeStream = nextBatchStream() + private var deserializeStream: DeserializationStream = null private var nextItem: (K, C) = null private var objectsRead = 0 @@ -528,7 +528,11 @@ class ExternalAppendOnlyMap[K, V, C]( override def hasNext: Boolean = { if (nextItem == null) { if (deserializeStream == null) { - return false + // In case of deserializeStream has not been initialized + deserializeStream = nextBatchStream() + if (deserializeStream == null) { + return false + } } nextItem = readNextItem() } @@ -536,19 +540,18 @@ class ExternalAppendOnlyMap[K, V, C]( } override def next(): (K, C) = { - val item = if (nextItem == null) readNextItem() else nextItem - if (item == null) { + if (!hasNext) { throw new NoSuchElementException } + val item = nextItem nextItem = null item } private def cleanup() { batchIndex = batchOffsets.length // Prevent reading any other batch - val ds = deserializeStream - if (ds != null) { - ds.close() + if (deserializeStream != null) { + deserializeStream.close() deserializeStream = null } if (fileStream != null) { From e29b08add92462a6505fef966629e74ba30e994e Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 25 Jan 2018 16:40:41 +0800 Subject: [PATCH 0200/2461] [SPARK-23208][SQL] Fix code generation for complex create array (related) expressions ## What changes were proposed in this pull request? The `GenArrayData.genCodeToCreateArrayData` produces illegal java code when code splitting is enabled. This is used in `CreateArray` and `CreateMap` expressions for complex object arrays. This issue is caused by a typo. ## How was this patch tested? Added a regression test in `complexTypesSuite`. Author: Herman van Hovell Closes #20391 from hvanhovell/SPARK-23208. --- .../sql/catalyst/expressions/complexTypeCreator.scala | 2 +- .../spark/sql/catalyst/optimizer/complexTypesSuite.scala | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 3dc2ee03a86e3..047b80ac5289c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -111,7 +111,7 @@ private [sql] object GenArrayData { val assignmentString = ctx.splitExpressionsWithCurrentInputs( expressions = assignments, funcName = "apply", - extraArguments = ("Object[]", arrayDataName) :: Nil) + extraArguments = ("Object[]", arrayName) :: Nil) (s"Object[] $arrayName = new Object[$numElements];", assignmentString, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 0d11958876ce9..de544ac314789 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ /** @@ -31,7 +32,7 @@ import org.apache.spark.sql.types._ * i.e. {{{create_named_struct(square, `x` * `x`).square}}} can be simplified to {{{`x` * `x`}}}. * sam applies to create_array and create_map */ -class ComplexTypesSuite extends PlanTest{ +class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { object Optimizer extends RuleExecutor[LogicalPlan] { val batches = @@ -171,6 +172,11 @@ class ComplexTypesSuite extends PlanTest{ assert(ctx.inlinedMutableStates.length == 0) } + test("SPARK-23208: Test code splitting for create array related methods") { + val inputs = (1 to 2500).map(x => Literal(s"l_$x")) + checkEvaluation(CreateArray(inputs), new GenericArrayData(inputs.map(_.eval()))) + } + test("simplify map ops") { val rel = relation .select( From 39ee2acf96f1e1496cff8e4d2614d27fca76d43b Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 25 Jan 2018 01:48:11 -0800 Subject: [PATCH 0201/2461] [SPARK-23163][DOC][PYTHON] Sync ML Python API with Scala ## What changes were proposed in this pull request? This syncs the ML Python API with Scala for differences found after the 2.3 QA audit. ## How was this patch tested? NA Author: Bryan Cutler Closes #20354 from BryanCutler/pyspark-ml-doc-sync-23163. --- python/pyspark/ml/evaluation.py | 8 +++++++- python/pyspark/ml/feature.py | 2 +- python/pyspark/ml/fpm.py | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index aa8dbe708a115..0cbce9b40048f 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -334,7 +334,13 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, .. note:: Experimental Evaluator for Clustering results, which expects two input - columns: prediction and features. + columns: prediction and features. The metric computes the Silhouette + measure using the squared Euclidean distance. + + The Silhouette is a measure for the validation of the consistency + within clusters. It ranges between 1 and -1, where a value close to + 1 means that the points in a cluster are close to the other points + in the same cluster and far from the points of the other clusters. >>> from pyspark.ml.linalg import Vectors >>> featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]), diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index eb79b193103e2..da85ba761a145 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -3440,7 +3440,7 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja selectorType = Param(Params._dummy(), "selectorType", "The selector type of the ChisqSelector. " + - "Supported options: numTopFeatures (default), percentile and fpr.", + "Supported options: numTopFeatures (default), percentile, fpr, fdr, fwe.", typeConverter=TypeConverters.toString) numTopFeatures = \ diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index dd7dda5f03124..b8dafd49d354d 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -144,7 +144,7 @@ def freqItemsets(self): @since("2.2.0") def associationRules(self): """ - Data with three columns: + DataFrame with three columns: * `antecedent` - Array of the same type as the input column. * `consequent` - Array of the same type as the input column. * `confidence` - Confidence for the rule (`DoubleType`). From d20bbc2d87ae6bd56d236a7c3d036b52c5f20ff5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 25 Jan 2018 19:49:58 +0800 Subject: [PATCH 0202/2461] [SPARK-21717][SQL] Decouple consume functions of physical operators in whole-stage codegen ## What changes were proposed in this pull request? It has been observed in SPARK-21603 that whole-stage codegen suffers performance degradation, if the generated functions are too long to be optimized by JIT. We basically produce a single function to incorporate generated codes from all physical operators in whole-stage. Thus, it is possibly to grow the size of generated function over a threshold that we can't have JIT optimization for it anymore. This patch is trying to decouple the logic of consuming rows in physical operators to avoid a giant function processing rows. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh Closes #18931 from viirya/SPARK-21717. --- .../expressions/codegen/CodeGenerator.scala | 38 ++++- .../apache/spark/sql/internal/SQLConf.scala | 12 ++ .../sql/execution/WholeStageCodegenExec.scala | 135 +++++++++++++++--- .../execution/WholeStageCodegenSuite.scala | 47 +++++- 4 files changed, 203 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f96ed7628fda1..4dcbb702893da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1245,6 +1245,31 @@ class CodegenContext { "" } } + + /** + * Returns the length of parameters for a Java method descriptor. `this` contributes one unit + * and a parameter of type long or double contributes two units. Besides, for nullable parameter, + * we also need to pass a boolean parameter for the null status. + */ + def calculateParamLength(params: Seq[Expression]): Int = { + def paramLengthForExpr(input: Expression): Int = { + // For a nullable expression, we need to pass in an extra boolean parameter. + (if (input.nullable) 1 else 0) + javaType(input.dataType) match { + case JAVA_LONG | JAVA_DOUBLE => 2 + case _ => 1 + } + } + // Initial value is 1 for `this`. + 1 + params.map(paramLengthForExpr(_)).sum + } + + /** + * In Java, a method descriptor is valid only if it represents method parameters with a total + * length less than a pre-defined constant. + */ + def isValidParamLength(paramLength: Int): Boolean = { + paramLength <= CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH + } } /** @@ -1311,26 +1336,29 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin object CodeGenerator extends Logging { // This is the value of HugeMethodLimit in the OpenJDK JVM settings - val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000 + final val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000 + + // The max valid length of method parameters in JVM. + final val MAX_JVM_METHOD_PARAMS_LENGTH = 255 // This is the threshold over which the methods in an inner class are grouped in a single // method which is going to be called by the outer class instead of the many small ones - val MERGE_SPLIT_METHODS_THRESHOLD = 3 + final val MERGE_SPLIT_METHODS_THRESHOLD = 3 // The number of named constants that can exist in the class is limited by the Constant Pool // limit, 65,536. We cannot know how many constants will be inserted for a class, so we use a // threshold of 1000k bytes to determine when a function should be inlined to a private, inner // class. - val GENERATED_CLASS_SIZE_THRESHOLD = 1000000 + final val GENERATED_CLASS_SIZE_THRESHOLD = 1000000 // This is the threshold for the number of global variables, whose types are primitive type or // complex type (e.g. more than one-dimensional array), that will be placed at the outer class - val OUTER_CLASS_VARIABLES_THRESHOLD = 10000 + final val OUTER_CLASS_VARIABLES_THRESHOLD = 10000 // This is the maximum number of array elements to keep global variables in one Java array // 32767 is the maximum integer value that does not require a constant pool entry in a Java // bytecode instruction - val MUTABLESTATEARRAY_SIZE_LIMIT = 32768 + final val MUTABLESTATEARRAY_SIZE_LIMIT = 32768 /** * Compile the Java source code into a Java class, using Janino. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1cef09a5bf053..470f88c213561 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -661,6 +661,15 @@ object SQLConf { .intConf .createWithDefault(CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT) + val WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR = + buildConf("spark.sql.codegen.splitConsumeFuncByOperator") + .internal() + .doc("When true, whole stage codegen would put the logic of consuming rows of each " + + "physical operator into individual methods, instead of a single big method. This can be " + + "used to avoid oversized function that can miss the opportunity of JIT optimization.") + .booleanConf + .createWithDefault(true) + val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes") .doc("The maximum number of bytes to pack into a single partition when reading files.") .longConf @@ -1263,6 +1272,9 @@ class SQLConf extends Serializable with Logging { def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT) + def wholeStageSplitConsumeFuncByOperator: Boolean = + getConf(WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR) + def tableRelationCacheSize: Int = getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 6102937852347..8ea9e81b2e53b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution import java.util.Locale +import scala.collection.mutable + import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -106,6 +108,31 @@ trait CodegenSupport extends SparkPlan { */ protected def doProduce(ctx: CodegenContext): String + private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = { + if (row != null) { + ExprCode("", "false", row) + } else { + if (colVars.nonEmpty) { + val colExprs = output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + val evaluateInputs = evaluateVariables(colVars) + // generate the code to create a UnsafeRow + ctx.INPUT_ROW = row + ctx.currentVars = colVars + val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) + val code = s""" + |$evaluateInputs + |${ev.code.trim} + """.stripMargin.trim + ExprCode(code, "false", ev.value) + } else { + // There is no columns + ExprCode("", "false", "unsafeRow") + } + } + } + /** * Consume the generated columns or row from current SparkPlan, call its parent's `doConsume()`. * @@ -126,28 +153,7 @@ trait CodegenSupport extends SparkPlan { } } - val rowVar = if (row != null) { - ExprCode("", "false", row) - } else { - if (outputVars.nonEmpty) { - val colExprs = output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable) - } - val evaluateInputs = evaluateVariables(outputVars) - // generate the code to create a UnsafeRow - ctx.INPUT_ROW = row - ctx.currentVars = outputVars - val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) - val code = s""" - |$evaluateInputs - |${ev.code.trim} - """.stripMargin.trim - ExprCode(code, "false", ev.value) - } else { - // There is no columns - ExprCode("", "false", "unsafeRow") - } - } + val rowVar = prepareRowVar(ctx, row, outputVars) // Set up the `currentVars` in the codegen context, as we generate the code of `inputVars` // before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to @@ -156,13 +162,96 @@ trait CodegenSupport extends SparkPlan { ctx.INPUT_ROW = null ctx.freshNamePrefix = parent.variablePrefix val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs) + + // Under certain conditions, we can put the logic to consume the rows of this operator into + // another function. So we can prevent a generated function too long to be optimized by JIT. + // The conditions: + // 1. The config "spark.sql.codegen.splitConsumeFuncByOperator" is enabled. + // 2. `inputVars` are all materialized. That is guaranteed to be true if the parent plan uses + // all variables in output (see `requireAllOutput`). + // 3. The number of output variables must less than maximum number of parameters in Java method + // declaration. + val confEnabled = SQLConf.get.wholeStageSplitConsumeFuncByOperator + val requireAllOutput = output.forall(parent.usedInputs.contains(_)) + val paramLength = ctx.calculateParamLength(output) + (if (row != null) 1 else 0) + val consumeFunc = if (confEnabled && requireAllOutput && ctx.isValidParamLength(paramLength)) { + constructDoConsumeFunction(ctx, inputVars, row) + } else { + parent.doConsume(ctx, inputVars, rowVar) + } s""" |${ctx.registerComment(s"CONSUME: ${parent.simpleString}")} |$evaluated - |${parent.doConsume(ctx, inputVars, rowVar)} + |$consumeFunc + """.stripMargin + } + + /** + * To prevent concatenated function growing too long to be optimized by JIT. We can separate the + * parent's `doConsume` codes of a `CodegenSupport` operator into a function to call. + */ + private def constructDoConsumeFunction( + ctx: CodegenContext, + inputVars: Seq[ExprCode], + row: String): String = { + val (args, params, inputVarsInFunc) = constructConsumeParameters(ctx, output, inputVars, row) + val rowVar = prepareRowVar(ctx, row, inputVarsInFunc) + + val doConsume = ctx.freshName("doConsume") + ctx.currentVars = inputVarsInFunc + ctx.INPUT_ROW = null + + val doConsumeFuncName = ctx.addNewFunction(doConsume, + s""" + | private void $doConsume(${params.mkString(", ")}) throws java.io.IOException { + | ${parent.doConsume(ctx, inputVarsInFunc, rowVar)} + | } + """.stripMargin) + + s""" + | $doConsumeFuncName(${args.mkString(", ")}); """.stripMargin } + /** + * Returns arguments for calling method and method definition parameters of the consume function. + * And also returns the list of `ExprCode` for the parameters. + */ + private def constructConsumeParameters( + ctx: CodegenContext, + attributes: Seq[Attribute], + variables: Seq[ExprCode], + row: String): (Seq[String], Seq[String], Seq[ExprCode]) = { + val arguments = mutable.ArrayBuffer[String]() + val parameters = mutable.ArrayBuffer[String]() + val paramVars = mutable.ArrayBuffer[ExprCode]() + + if (row != null) { + arguments += row + parameters += s"InternalRow $row" + } + + variables.zipWithIndex.foreach { case (ev, i) => + val paramName = ctx.freshName(s"expr_$i") + val paramType = ctx.javaType(attributes(i).dataType) + + arguments += ev.value + parameters += s"$paramType $paramName" + val paramIsNull = if (!attributes(i).nullable) { + // Use constant `false` without passing `isNull` for non-nullable variable. + "false" + } else { + val isNull = ctx.freshName(s"exprIsNull_$i") + arguments += ev.isNull + parameters += s"boolean $isNull" + isNull + } + + paramVars += ExprCode("", paramIsNull, paramName) + } + (arguments, parameters, paramVars) + } + /** * Returns source code to evaluate all the variables, and clear the code of them, to prevent * them to be evaluated twice. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 22ca128c27768..242bb48c22942 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -205,7 +205,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { val codeWithShortFunctions = genGroupByCode(3) val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions) assert(maxCodeSize1 < SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) - val codeWithLongFunctions = genGroupByCode(20) + val codeWithLongFunctions = genGroupByCode(50) val (_, maxCodeSize2) = CodeGenerator.compile(codeWithLongFunctions) assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) } @@ -228,4 +228,49 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } } } + + test("Control splitting consume function by operators with config") { + import testImplicits._ + val df = spark.range(10).select(Seq.tabulate(2) {i => ('id + i).as(s"c$i")} : _*) + + Seq(true, false).foreach { config => + withSQLConf(SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> s"$config") { + val plan = df.queryExecution.executedPlan + val wholeStageCodeGenExec = plan.find(p => p match { + case wp: WholeStageCodegenExec => true + case _ => false + }) + assert(wholeStageCodeGenExec.isDefined) + val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 + assert(code.body.contains("project_doConsume") == config) + } + } + } + + test("Skip splitting consume function when parameter number exceeds JVM limit") { + import testImplicits._ + + Seq((255, false), (254, true)).foreach { case (columnNum, hasSplit) => + withTempPath { dir => + val path = dir.getCanonicalPath + spark.range(10).select(Seq.tabulate(columnNum) {i => ('id + i).as(s"c$i")} : _*) + .write.mode(SaveMode.Overwrite).parquet(path) + + withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255", + SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> "true") { + val projection = Seq.tabulate(columnNum)(i => s"c$i + c$i as newC$i") + val df = spark.read.parquet(path).selectExpr(projection: _*) + + val plan = df.queryExecution.executedPlan + val wholeStageCodeGenExec = plan.find(p => p match { + case wp: WholeStageCodegenExec => true + case _ => false + }) + assert(wholeStageCodeGenExec.isDefined) + val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 + assert(code.body.contains("project_doConsume") == hasSplit) + } + } + } + } } From 8532e26f335b67b74c976712ad82c20ea6dbbf80 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Thu, 25 Jan 2018 15:01:22 +0200 Subject: [PATCH 0203/2461] [SPARK-23112][DOC] Add highlights and migration guide for 2.3 Update ML user guide with highlights and migration guide for `2.3`. ## How was this patch tested? Doc only. Author: Nick Pentreath Closes #20363 from MLnick/SPARK-23112-ml-guide. --- docs/ml-guide.md | 78 ++++++++++++++----------------------- docs/ml-migration-guides.md | 23 +++++++++++ 2 files changed, 52 insertions(+), 49 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index f6288e7c32d97..b957445579ffd 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -72,32 +72,31 @@ To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4 [^1]: To learn more about the benefits and background of system optimised natives, you may wish to watch Sam Halliday's ScalaX talk on [High Performance Linear Algebra in Scala](http://fommil.github.io/scalax14/#/). -# Highlights in 2.2 +# Highlights in 2.3 -The list below highlights some of the new features and enhancements added to MLlib in the `2.2` +The list below highlights some of the new features and enhancements added to MLlib in the `2.3` release of Spark: -* [`ALS`](ml-collaborative-filtering.html) methods for _top-k_ recommendations for all - users or items, matching the functionality in `mllib` - ([SPARK-19535](https://issues.apache.org/jira/browse/SPARK-19535)). - Performance was also improved for both `ml` and `mllib` - ([SPARK-11968](https://issues.apache.org/jira/browse/SPARK-11968) and - [SPARK-20587](https://issues.apache.org/jira/browse/SPARK-20587)) -* [`Correlation`](ml-statistics.html#correlation) and - [`ChiSquareTest`](ml-statistics.html#hypothesis-testing) stats functions for `DataFrames` - ([SPARK-19636](https://issues.apache.org/jira/browse/SPARK-19636) and - [SPARK-19635](https://issues.apache.org/jira/browse/SPARK-19635)) -* [`FPGrowth`](ml-frequent-pattern-mining.html#fp-growth) algorithm for frequent pattern mining - ([SPARK-14503](https://issues.apache.org/jira/browse/SPARK-14503)) -* `GLM` now supports the full `Tweedie` family - ([SPARK-18929](https://issues.apache.org/jira/browse/SPARK-18929)) -* [`Imputer`](ml-features.html#imputer) feature transformer to impute missing values in a dataset - ([SPARK-13568](https://issues.apache.org/jira/browse/SPARK-13568)) -* [`LinearSVC`](ml-classification-regression.html#linear-support-vector-machine) - for linear Support Vector Machine classification - ([SPARK-14709](https://issues.apache.org/jira/browse/SPARK-14709)) -* Logistic regression now supports constraints on the coefficients during training - ([SPARK-20047](https://issues.apache.org/jira/browse/SPARK-20047)) +* Built-in support for reading images into a `DataFrame` was added +([SPARK-21866](https://issues.apache.org/jira/browse/SPARK-21866)). +* [`OneHotEncoderEstimator`](ml-features.html#onehotencoderestimator) was added, and should be +used instead of the existing `OneHotEncoder` transformer. The new estimator supports +transforming multiple columns. +* Multiple column support was also added to `QuantileDiscretizer` and `Bucketizer` +([SPARK-22397](https://issues.apache.org/jira/browse/SPARK-22397) and +[SPARK-20542](https://issues.apache.org/jira/browse/SPARK-20542)) +* A new [`FeatureHasher`](ml-features.html#featurehasher) transformer was added + ([SPARK-13969](https://issues.apache.org/jira/browse/SPARK-13969)). +* Added support for evaluating multiple models in parallel when performing cross-validation using +[`TrainValidationSplit` or `CrossValidator`](ml-tuning.html) +([SPARK-19357](https://issues.apache.org/jira/browse/SPARK-19357)). +* Improved support for custom pipeline components in Python (see +[SPARK-21633](https://issues.apache.org/jira/browse/SPARK-21633) and +[SPARK-21542](https://issues.apache.org/jira/browse/SPARK-21542)). +* `DataFrame` functions for descriptive summary statistics over vector columns +([SPARK-19634](https://issues.apache.org/jira/browse/SPARK-19634)). +* Robust linear regression with Huber loss +([SPARK-3181](https://issues.apache.org/jira/browse/SPARK-3181)). # Migration guide @@ -115,36 +114,17 @@ There are no breaking changes. **Deprecations** -There are no deprecations. +* `OneHotEncoder` has been deprecated and will be removed in `3.0`. It has been replaced by the +new [`OneHotEncoderEstimator`](ml-features.html#onehotencoderestimator) +(see [SPARK-13030](https://issues.apache.org/jira/browse/SPARK-13030)). **Note** that +`OneHotEncoderEstimator` will be renamed to `OneHotEncoder` in `3.0` (but +`OneHotEncoderEstimator` will be kept as an alias). **Changes of behavior** * [SPARK-21027](https://issues.apache.org/jira/browse/SPARK-21027): - We are now setting the default parallelism used in `OneVsRest` to be 1 (i.e. serial), in 2.2 and earlier version, - the `OneVsRest` parallelism would be parallelism of the default threadpool in scala. - -## From 2.1 to 2.2 - -### Breaking changes - -There are no breaking changes. - -### Deprecations and changes of behavior - -**Deprecations** - -There are no deprecations. - -**Changes of behavior** - -* [SPARK-19787](https://issues.apache.org/jira/browse/SPARK-19787): - Default value of `regParam` changed from `1.0` to `0.1` for `ALS.train` method (marked `DeveloperApi`). - **Note** this does _not affect_ the `ALS` Estimator or Model, nor MLlib's `ALS` class. -* [SPARK-14772](https://issues.apache.org/jira/browse/SPARK-14772): - Fixed inconsistency between Python and Scala APIs for `Param.copy` method. -* [SPARK-11569](https://issues.apache.org/jira/browse/SPARK-11569): - `StringIndexer` now handles `NULL` values in the same way as unseen values. Previously an exception - would always be thrown regardless of the setting of the `handleInvalid` parameter. + We are now setting the default parallelism used in `OneVsRest` to be 1 (i.e. serial). In 2.2 and + earlier versions, the level of parallelism was set to the default threadpool size in Scala. ## Previous Spark versions diff --git a/docs/ml-migration-guides.md b/docs/ml-migration-guides.md index 687d7c8930362..f4b0df58cf63b 100644 --- a/docs/ml-migration-guides.md +++ b/docs/ml-migration-guides.md @@ -7,6 +7,29 @@ description: MLlib migration guides from before Spark SPARK_VERSION_SHORT The migration guide for the current Spark version is kept on the [MLlib Guide main page](ml-guide.html#migration-guide). +## From 2.1 to 2.2 + +### Breaking changes + +There are no breaking changes. + +### Deprecations and changes of behavior + +**Deprecations** + +There are no deprecations. + +**Changes of behavior** + +* [SPARK-19787](https://issues.apache.org/jira/browse/SPARK-19787): + Default value of `regParam` changed from `1.0` to `0.1` for `ALS.train` method (marked `DeveloperApi`). + **Note** this does _not affect_ the `ALS` Estimator or Model, nor MLlib's `ALS` class. +* [SPARK-14772](https://issues.apache.org/jira/browse/SPARK-14772): + Fixed inconsistency between Python and Scala APIs for `Param.copy` method. +* [SPARK-11569](https://issues.apache.org/jira/browse/SPARK-11569): + `StringIndexer` now handles `NULL` values in the same way as unseen values. Previously an exception + would always be thrown regardless of the setting of the `handleInvalid` parameter. + ## From 2.0 to 2.1 ### Breaking changes From 8480c0c57698b7dcccec5483d67b17cf2c7527ed Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 26 Jan 2018 07:50:48 +0900 Subject: [PATCH 0204/2461] [SPARK-23081][PYTHON] Add colRegex API to PySpark ## What changes were proposed in this pull request? Add colRegex API to PySpark ## How was this patch tested? add a test in sql/tests.py Author: Huaxin Gao Closes #20390 from huaxingao/spark-23081. --- python/pyspark/sql/dataframe.py | 23 +++++++++++++++++++ .../scala/org/apache/spark/sql/Dataset.scala | 8 +++---- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 2d5e9b91468cf..ac403080acfdf 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -819,6 +819,29 @@ def columns(self): """ return [f.name for f in self.schema.fields] + @since(2.3) + def colRegex(self, colName): + """ + Selects column based on the column name specified as a regex and returns it + as :class:`Column`. + + :param colName: string, column name specified as a regex. + + >>> df = spark.createDataFrame([("a", 1), ("b", 2), ("c", 3)], ["Col1", "Col2"]) + >>> df.select(df.colRegex("`(Col1)?+.+`")).show() + +----+ + |Col2| + +----+ + | 1| + | 2| + | 3| + +----+ + """ + if not isinstance(colName, basestring): + raise ValueError("colName should be provided as string") + jc = self._jdf.colRegex(colName) + return Column(jc) + @ignore_unicode_prefix @since(1.3) def alias(self, alias): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 912f411fa3845..edb6644ed5ac0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1194,7 +1194,7 @@ class Dataset[T] private[sql]( def orderBy(sortExprs: Column*): Dataset[T] = sort(sortExprs : _*) /** - * Selects column based on the column name and return it as a [[Column]]. + * Selects column based on the column name and returns it as a [[Column]]. * * @note The column name can also reference to a nested column like `a.b`. * @@ -1220,7 +1220,7 @@ class Dataset[T] private[sql]( } /** - * Selects column based on the column name and return it as a [[Column]]. + * Selects column based on the column name and returns it as a [[Column]]. * * @note The column name can also reference to a nested column like `a.b`. * @@ -1240,7 +1240,7 @@ class Dataset[T] private[sql]( } /** - * Selects column based on the column name specified as a regex and return it as [[Column]]. + * Selects column based on the column name specified as a regex and returns it as [[Column]]. * @group untypedrel * @since 2.3.0 */ @@ -2729,7 +2729,7 @@ class Dataset[T] private[sql]( } /** - * Return an iterator that contains all rows in this Dataset. + * Returns an iterator that contains all rows in this Dataset. * * The iterator will consume as much memory as the largest partition in this Dataset. * From e57f394818b0a62f99609e1032fede7e981f306f Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Thu, 25 Jan 2018 16:11:33 -0800 Subject: [PATCH 0205/2461] [SPARK-23032][SQL] Add a per-query codegenStageId to WholeStageCodegenExec ## What changes were proposed in this pull request? **Proposal** Add a per-query ID to the codegen stages as represented by `WholeStageCodegenExec` operators. This ID will be used in - the explain output of the physical plan, and in - the generated class name. Specifically, this ID will be stable within a query, counting up from 1 in depth-first post-order for all the `WholeStageCodegenExec` inserted into a plan. The ID value 0 is reserved for "free-floating" `WholeStageCodegenExec` objects, which may have been created for one-off purposes, e.g. for fallback handling of codegen stages that failed to codegen the whole stage and wishes to codegen a subset of the children operators (as seen in `org.apache.spark.sql.execution.FileSourceScanExec#doExecute`). Example: for the following query: ```scala scala> spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 1) scala> val df1 = spark.range(10).select('id as 'x, 'id + 1 as 'y).orderBy('x).select('x + 1 as 'z, 'y) df1: org.apache.spark.sql.DataFrame = [z: bigint, y: bigint] scala> val df2 = spark.range(5) df2: org.apache.spark.sql.Dataset[Long] = [id: bigint] scala> val query = df1.join(df2, 'z === 'id) query: org.apache.spark.sql.DataFrame = [z: bigint, y: bigint ... 1 more field] ``` The explain output before the change is: ```scala scala> query.explain == Physical Plan == *SortMergeJoin [z#9L], [id#13L], Inner :- *Sort [z#9L ASC NULLS FIRST], false, 0 : +- Exchange hashpartitioning(z#9L, 200) : +- *Project [(x#3L + 1) AS z#9L, y#4L] : +- *Sort [x#3L ASC NULLS FIRST], true, 0 : +- Exchange rangepartitioning(x#3L ASC NULLS FIRST, 200) : +- *Project [id#0L AS x#3L, (id#0L + 1) AS y#4L] : +- *Range (0, 10, step=1, splits=8) +- *Sort [id#13L ASC NULLS FIRST], false, 0 +- Exchange hashpartitioning(id#13L, 200) +- *Range (0, 5, step=1, splits=8) ``` Note how codegen'd operators are annotated with a prefix `"*"`. See how the `SortMergeJoin` operator and its direct children `Sort` operators are adjacent and all annotated with the `"*"`, so it's hard to tell they're actually in separate codegen stages. and after this change it'll be: ```scala scala> query.explain == Physical Plan == *(6) SortMergeJoin [z#9L], [id#13L], Inner :- *(3) Sort [z#9L ASC NULLS FIRST], false, 0 : +- Exchange hashpartitioning(z#9L, 200) : +- *(2) Project [(x#3L + 1) AS z#9L, y#4L] : +- *(2) Sort [x#3L ASC NULLS FIRST], true, 0 : +- Exchange rangepartitioning(x#3L ASC NULLS FIRST, 200) : +- *(1) Project [id#0L AS x#3L, (id#0L + 1) AS y#4L] : +- *(1) Range (0, 10, step=1, splits=8) +- *(5) Sort [id#13L ASC NULLS FIRST], false, 0 +- Exchange hashpartitioning(id#13L, 200) +- *(4) Range (0, 5, step=1, splits=8) ``` Note that the annotated prefix becomes `"*(id) "`. See how the `SortMergeJoin` operator and its direct children `Sort` operators have different codegen stage IDs. It'll also show up in the name of the generated class, as a suffix in the format of `GeneratedClass$GeneratedIterator$id`. For example, note how `GeneratedClass$GeneratedIteratorForCodegenStage3` and `GeneratedClass$GeneratedIteratorForCodegenStage6` in the following stack trace corresponds to the IDs shown in the explain output above: ``` "Executor task launch worker for task 42412957" daemon prio=5 tid=0x58 nid=NA runnable java.lang.Thread.State: RUNNABLE at org.apache.spark.sql.execution.UnsafeExternalRowSorter.insertRow(UnsafeExternalRowSorter.java:109) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.sort_addToSorter$(generated.java:32) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(generated.java:41) at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$9$$anon$1.hasNext(WholeStageCodegenExec.scala:494) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage6.findNextInnerJoinRows$(generated.java:42) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage6.processNext(generated.java:101) at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$11$$anon$2.hasNext(WholeStageCodegenExec.scala:513) at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253) at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:828) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:828) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324) at org.apache.spark.rdd.RDD.iterator(RDD.scala:288) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87) at org.apache.spark.scheduler.Task.run(Task.scala:109) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:748) ``` **Rationale** Right now, the codegen from Spark SQL lacks the means to differentiate between a couple of things: 1. It's hard to tell which physical operators are in the same WholeStageCodegen stage. Note that this "stage" is a separate notion from Spark's RDD execution stages; this one is only to delineate codegen units. There can be adjacent physical operators that are both codegen'd but are in separate codegen stages. Some of this is due to hacky implementation details, such as the case with `SortMergeJoin` and its `Sort` inputs -- they're hard coded to be split into separate stages although both are codegen'd. When printing out the explain output of the physical plan, you'd only see the codegen'd physical operators annotated with a preceding star (`'*'`) but would have no way to figure out if they're in the same stage. 2. Performance/error diagnosis The generated code has class/method names that are hard to differentiate between queries or even between codegen stages within the same query. If we use a Java-level profiler to collect profiles, or if we encounter a Java-level exception with a stack trace in it, it's really hard to tell which part of a query it's at. By introducing a per-query codegen stage ID, we'd at least be able to know which codegen stage (and in turn, which group of physical operators) was a profile tick or an exception happened. The reason why this proposal uses a per-query ID is because it's stable within a query, so that multiple runs of the same query will see the same resulting IDs. This both benefits understandability for users, and also it plays well with the codegen cache in Spark SQL which uses the generated source code as the key. The downside to using per-query IDs as opposed to a per-session or globally incrementing ID is of course we can't tell apart different query runs with this ID alone. But for now I believe this is a good enough tradeoff. ## How was this patch tested? Existing tests. This PR does not involve any runtime behavior changes other than some name changes. The SQL query test suites that compares explain outputs have been updates to ignore the newly added `codegenStageId`. Author: Kris Mok Closes #20224 from rednaxelafx/wsc-codegenstageid. --- .../apache/spark/sql/internal/SQLConf.scala | 10 +++ .../sql/execution/DataSourceScanExec.scala | 2 +- .../sql/execution/WholeStageCodegenExec.scala | 85 +++++++++++++++++-- .../columnar/InMemoryTableScanExec.scala | 2 +- .../datasources/v2/DataSourceV2ScanExec.scala | 2 +- .../apache/spark/sql/SQLQueryTestSuite.scala | 3 +- .../execution/WholeStageCodegenSuite.scala | 34 ++++++++ .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../sql/hive/execution/HiveExplainSuite.scala | 39 +++++++-- 9 files changed, 158 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 470f88c213561..b0d18b6dced76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -629,6 +629,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME = + buildConf("spark.sql.codegen.useIdInClassName") + .internal() + .doc("When true, embed the (whole-stage) codegen stage ID into " + + "the class name of the generated class as a suffix") + .booleanConf + .createWithDefault(true) + val WHOLESTAGE_MAX_NUM_FIELDS = buildConf("spark.sql.codegen.maxFields") .internal() .doc("The maximum number of fields (including nested fields) that will be supported before" + @@ -1264,6 +1272,8 @@ class SQLConf extends Serializable with Logging { def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED) + def wholeStageUseIdInClassName: Boolean = getConf(WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME) + def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS) def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 7c7d79c2bbd7c..aa66ee7e948ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -324,7 +324,7 @@ case class FileSourceScanExec( // in the case of fallback, this batched scan should never fail because of: // 1) only primitive types are supported // 2) the number of columns should be smaller than spark.sql.codegen.maxFields - WholeStageCodegenExec(this).execute() + WholeStageCodegenExec(this)(codegenStageId = 0).execute() } else { val unsafeRows = { val scan = inputRDD diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 8ea9e81b2e53b..0e525b1e22eb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import java.util.Locale +import java.util.function.Supplier import scala.collection.mutable @@ -414,6 +415,58 @@ object WholeStageCodegenExec { } } +object WholeStageCodegenId { + // codegenStageId: ID for codegen stages within a query plan. + // It does not affect equality, nor does it participate in destructuring pattern matching + // of WholeStageCodegenExec. + // + // This ID is used to help differentiate between codegen stages. It is included as a part + // of the explain output for physical plans, e.g. + // + // == Physical Plan == + // *(5) SortMergeJoin [x#3L], [y#9L], Inner + // :- *(2) Sort [x#3L ASC NULLS FIRST], false, 0 + // : +- Exchange hashpartitioning(x#3L, 200) + // : +- *(1) Project [(id#0L % 2) AS x#3L] + // : +- *(1) Filter isnotnull((id#0L % 2)) + // : +- *(1) Range (0, 5, step=1, splits=8) + // +- *(4) Sort [y#9L ASC NULLS FIRST], false, 0 + // +- Exchange hashpartitioning(y#9L, 200) + // +- *(3) Project [(id#6L % 2) AS y#9L] + // +- *(3) Filter isnotnull((id#6L % 2)) + // +- *(3) Range (0, 5, step=1, splits=8) + // + // where the ID makes it obvious that not all adjacent codegen'd plan operators are of the + // same codegen stage. + // + // The codegen stage ID is also optionally included in the name of the generated classes as + // a suffix, so that it's easier to associate a generated class back to the physical operator. + // This is controlled by SQLConf: spark.sql.codegen.useIdInClassName + // + // The ID is also included in various log messages. + // + // Within a query, a codegen stage in a plan starts counting from 1, in "insertion order". + // WholeStageCodegenExec operators are inserted into a plan in depth-first post-order. + // See CollapseCodegenStages.insertWholeStageCodegen for the definition of insertion order. + // + // 0 is reserved as a special ID value to indicate a temporary WholeStageCodegenExec object + // is created, e.g. for special fallback handling when an existing WholeStageCodegenExec + // failed to generate/compile code. + + private val codegenStageCounter = ThreadLocal.withInitial(new Supplier[Integer] { + override def get() = 1 // TODO: change to Scala lambda syntax when upgraded to Scala 2.12+ + }) + + def resetPerQuery(): Unit = codegenStageCounter.set(1) + + def getNextStageId(): Int = { + val counter = codegenStageCounter + val id = counter.get() + counter.set(id + 1) + id + } +} + /** * WholeStageCodegen compiles a subtree of plans that support codegen together into single Java * function. @@ -442,7 +495,8 @@ object WholeStageCodegenExec { * `doCodeGen()` will create a `CodeGenContext`, which will hold a list of variables for input, * used to generated code for [[BoundReference]]. */ -case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport { +case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) + extends UnaryExecNode with CodegenSupport { override def output: Seq[Attribute] = child.output @@ -454,6 +508,12 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co "pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext, WholeStageCodegenExec.PIPELINE_DURATION_METRIC)) + def generatedClassName(): String = if (conf.wholeStageUseIdInClassName) { + s"GeneratedIteratorForCodegenStage$codegenStageId" + } else { + "GeneratedIterator" + } + /** * Generates code for this subtree. * @@ -471,19 +531,23 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co } """, inlineToOuterClass = true) + val className = generatedClassName() + val source = s""" public Object generate(Object[] references) { - return new GeneratedIterator(references); + return new $className(references); } - ${ctx.registerComment(s"""Codegend pipeline for\n${child.treeString.trim}""")} - final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { + ${ctx.registerComment( + s"""Codegend pipeline for stage (id=$codegenStageId) + |${this.treeString.trim}""".stripMargin)} + final class $className extends ${classOf[BufferedRowIterator].getName} { private Object[] references; private scala.collection.Iterator[] inputs; ${ctx.declareMutableStates()} - public GeneratedIterator(Object[] references) { + public $className(Object[] references) { this.references = references; } @@ -516,7 +580,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co } catch { case _: Exception if !Utils.isTesting && sqlContext.conf.codegenFallback => // We should already saw the error message - logWarning(s"Whole-stage codegen disabled for this plan:\n $treeString") + logWarning(s"Whole-stage codegen disabled for plan (id=$codegenStageId):\n $treeString") return child.execute() } @@ -525,7 +589,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co logInfo(s"Found too long generated codes and JIT optimization might not work: " + s"the bytecode size ($maxCodeSize) is above the limit " + s"${sqlContext.conf.hugeMethodLimit}, and the whole-stage codegen was disabled " + - s"for this plan. To avoid this, you can raise the limit " + + s"for this plan (id=$codegenStageId). To avoid this, you can raise the limit " + s"`${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}`:\n$treeString") child match { // The fallback solution of batch file source scan still uses WholeStageCodegenExec @@ -603,10 +667,12 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co verbose: Boolean, prefix: String = "", addSuffix: Boolean = false): StringBuilder = { - child.generateTreeString(depth, lastChildren, builder, verbose, "*") + child.generateTreeString(depth, lastChildren, builder, verbose, s"*($codegenStageId) ") } override def needStopCheck: Boolean = true + + override protected def otherCopyArgs: Seq[AnyRef] = Seq(codegenStageId.asInstanceOf[Integer]) } @@ -657,13 +723,14 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { case plan if plan.output.length == 1 && plan.output.head.dataType.isInstanceOf[ObjectType] => plan.withNewChildren(plan.children.map(insertWholeStageCodegen)) case plan: CodegenSupport if supportCodegen(plan) => - WholeStageCodegenExec(insertInputAdapter(plan)) + WholeStageCodegenExec(insertInputAdapter(plan))(WholeStageCodegenId.getNextStageId()) case other => other.withNewChildren(other.children.map(insertWholeStageCodegen)) } def apply(plan: SparkPlan): SparkPlan = { if (conf.wholeStageEnabled) { + WholeStageCodegenId.resetPerQuery() insertWholeStageCodegen(plan) } else { plan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 28b3875505cd2..c167f1e7dc621 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -274,7 +274,7 @@ case class InMemoryTableScanExec( protected override def doExecute(): RDD[InternalRow] = { if (supportsBatch) { - WholeStageCodegenExec(this).execute() + WholeStageCodegenExec(this)(codegenStageId = 0).execute() } else { inputRDD } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 69d871df3e1dd..2c22239e81869 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -88,7 +88,7 @@ case class DataSourceV2ScanExec( override protected def doExecute(): RDD[InternalRow] = { if (supportsBatch) { - WholeStageCodegenExec(this).execute() + WholeStageCodegenExec(this)(codegenStageId = 0).execute() } else { val numOutputRows = longMetric("numOutputRows") inputRDD.map { r => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 054ada56d99ad..beac9699585d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -230,7 +230,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { .replaceAll("Location.*/sql/core/", s"Location ${notIncludedMsg}sql/core/") .replaceAll("Created By.*", s"Created By $notIncludedMsg") .replaceAll("Created Time.*", s"Created Time $notIncludedMsg") - .replaceAll("Last Access.*", s"Last Access $notIncludedMsg")) + .replaceAll("Last Access.*", s"Last Access $notIncludedMsg") + .replaceAll("\\*\\(\\d+\\) ", "*")) // remove the WholeStageCodegen codegenStageIds // If the output is not pre-sorted, sort it. if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 242bb48c22942..28ad712feaae6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.{QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -273,4 +274,37 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } } } + + test("codegen stage IDs should be preserved in transformations after CollapseCodegenStages") { + // test case adapted from DataFrameSuite to trigger ReuseExchange + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2") { + val df = spark.range(100) + val join = df.join(df, "id") + val plan = join.queryExecution.executedPlan + assert(!plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0).isDefined, + "codegen stage IDs should be preserved through ReuseExchange") + checkAnswer(join, df.toDF) + } + } + + test("including codegen stage ID in generated class name should not regress codegen caching") { + import testImplicits._ + + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> "true") { + val bytecodeSizeHisto = CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE + + // the same query run twice should hit the codegen cache + spark.range(3).select('id + 2).collect + val after1 = bytecodeSizeHisto.getCount + spark.range(3).select('id + 2).collect + val after2 = bytecodeSizeHisto.getCount // same query shape as above, deliberately + // bytecodeSizeHisto's count is always monotonically increasing if new compilation to + // bytecode had occurred. If the count stayed the same that means we've got a cache hit. + assert(after1 == after2, "Should hit codegen cache. No new compilation to bytecode expected") + + // a different query can result in codegen cache miss, that's by design + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index ff7c5e58e9863..2280da927cf70 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -477,7 +477,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { assert(planBeforeFilter.head.isInstanceOf[InMemoryTableScanExec]) val execPlan = if (enabled == "true") { - WholeStageCodegenExec(planBeforeFilter.head) + WholeStageCodegenExec(planBeforeFilter.head)(codegenStageId = 0) } else { planBeforeFilter.head } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index a4273de5fe260..f84d188075b72 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -154,14 +154,39 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto } } - test("EXPLAIN CODEGEN command") { - checkKeywordsExist(sql("EXPLAIN CODEGEN SELECT 1"), - "WholeStageCodegen", - "Generated code:", - "/* 001 */ public Object generate(Object[] references) {", - "/* 002 */ return new GeneratedIterator(references);", - "/* 003 */ }" + test("explain output of physical plan should contain proper codegen stage ID") { + checkKeywordsExist(sql( + """ + |EXPLAIN SELECT t1.id AS a, t2.id AS b FROM + |(SELECT * FROM range(3)) t1 JOIN + |(SELECT * FROM range(10)) t2 ON t1.id == t2.id % 3 + """.stripMargin), + "== Physical Plan ==", + "*(2) Project ", + "+- *(2) BroadcastHashJoin ", + " :- BroadcastExchange ", + " : +- *(1) Range ", + " +- *(2) Range " ) + } + + test("EXPLAIN CODEGEN command") { + // the generated class name in this test should stay in sync with + // org.apache.spark.sql.execution.WholeStageCodegenExec.generatedClassName() + for ((useIdInClassName, expectedClassName) <- Seq( + ("true", "GeneratedIteratorForCodegenStage1"), + ("false", "GeneratedIterator"))) { + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> useIdInClassName) { + checkKeywordsExist(sql("EXPLAIN CODEGEN SELECT 1"), + "WholeStageCodegen", + "Generated code:", + "/* 001 */ public Object generate(Object[] references) {", + s"/* 002 */ return new $expectedClassName(references);", + "/* 003 */ }" + ) + } + } checkKeywordsNotExist(sql("EXPLAIN CODEGEN SELECT 1"), "== Physical Plan ==" From 7bd46d9871567597216cc02e1dc72ff5806ecdf8 Mon Sep 17 00:00:00 2001 From: Sid Murching Date: Thu, 25 Jan 2018 18:15:29 -0600 Subject: [PATCH 0206/2461] [SPARK-23205][ML] Update ImageSchema.readImages to correctly set alpha values for four-channel images ## What changes were proposed in this pull request? When parsing raw image data in ImageSchema.decode(), we use a [java.awt.Color](https://docs.oracle.com/javase/7/docs/api/java/awt/Color.html#Color(int)) constructor that sets alpha = 255, even for four-channel images (which may have different alpha values). This PR fixes this issue & adds a unit test to verify correctness of reading four-channel images. ## How was this patch tested? Updates an existing unit test ("readImages pixel values test" in `ImageSchemaSuite`) to also verify correctness when reading a four-channel image. Author: Sid Murching Closes #20389 from smurching/image-schema-bugfix. --- data/mllib/images/multi-channel/BGRA_alpha_60.png | Bin 0 -> 747 bytes .../org/apache/spark/ml/image/ImageSchema.scala | 5 ++--- .../apache/spark/ml/image/ImageSchemaSuite.scala | 9 ++++++--- 3 files changed, 8 insertions(+), 6 deletions(-) create mode 100644 data/mllib/images/multi-channel/BGRA_alpha_60.png diff --git a/data/mllib/images/multi-channel/BGRA_alpha_60.png b/data/mllib/images/multi-channel/BGRA_alpha_60.png new file mode 100644 index 0000000000000000000000000000000000000000..913637cd2828ab4e2ff4b2bbd92c4cf362f871c4 GIT binary patch literal 747 zcmV zL3V>M390}vf5ExS_g|iB()io}T zdzc-a9(QkC83JQkyn5Lt1Z5CrRi#xHpD{9oI-@&0jtqe@HUD3>kZs0McGO5TM~1+d z7FE_NXiaZ3z?maMU@%u%R=mvsm?J}AJhD?K_8)UCLtrp7B&y$7129;Iz!*D2yjN8< z0y9X4z_>+*xc3{0=Ex8j*D|Cx*=8j4K{5o!c7|A?%)le4CMiSsY+o_Vo>AV}VFh50 z41v-2iea`H1DYg5U<}czr}rCyox2Qyfqu6aXGTB<$q*RG3>nT0#|)AoFmf{_nrt%+ z=R=0T=#wE~ZFN$PgH=pIpR#r$}~v0vQ6s<%(gm8QC*8vEQiGG6cq@D~4&`dke3xtTJT?jHV1R zn*p1-WHaVkhQK(LAu?mT_I%GyhQKgo$ZgD^p$y@(n<2xRQG=Ep8^{nCn;A09ds8(A zT2-xU83JRGAy_kN4BT(jY8e8f?Uz2S`+1phqfY#&mLV|C{nDp(Kbg^7%Mci;cTmZU z&sv7SNV$V*STh2UAuvMkprV>#Myssnk&_`_<2W4$=?*U$0wXp)$=dj2RgM;~p7O*^V>AfDC~V@@+tF<5118 dqE*&-`~khx#S9AHa*F@}002ovPDHLkV1lO_Oc($F literal 0 HcmV?d00001 diff --git a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala index f7850b238465b..dcc40b6668c7a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala @@ -169,12 +169,11 @@ object ImageSchema { var offset = 0 for (h <- 0 until height) { for (w <- 0 until width) { - val color = new Color(img.getRGB(w, h)) - + val color = new Color(img.getRGB(w, h), hasAlpha) decoded(offset) = color.getBlue.toByte decoded(offset + 1) = color.getGreen.toByte decoded(offset + 2) = color.getRed.toByte - if (nChannels == 4) { + if (hasAlpha) { decoded(offset + 3) = color.getAlpha.toByte } offset += nChannels diff --git a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala index dba61cd1eb1cc..a8833c615865d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala @@ -53,11 +53,11 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(df.count === 1) df = readImages(imagePath, null, true, -1, false, 1.0, 0) - assert(df.count === 9) + assert(df.count === 10) df = readImages(imagePath, null, true, -1, true, 1.0, 0) val countTotal = df.count - assert(countTotal === 7) + assert(countTotal === 8) df = readImages(imagePath, null, true, -1, true, 0.5, 0) // Random number about half of the size of the original dataset @@ -103,6 +103,9 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { -71, -58, -56, -73, -64))), "BGRA.png" -> (("CV_8UC4", Array[Byte](-128, -128, -8, -1, -128, -128, -8, -1, -128, - -128, -8, -1, 127, 127, -9, -1, 127, 127, -9, -1))) + -128, -8, -1, 127, 127, -9, -1, 127, 127, -9, -1))), + "BGRA_alpha_60.png" -> (("CV_8UC4", + Array[Byte](-128, -128, -8, 60, -128, -128, -8, 60, -128, + -128, -8, 60, 127, 127, -9, 60, 127, 127, -9, 60))) ) } From 70a68b328b856c17eb22cc86fee0ebe8d64f8825 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 26 Jan 2018 11:58:20 +0800 Subject: [PATCH 0207/2461] [SPARK-23020][CORE] Fix race in SparkAppHandle cleanup, again. Third time is the charm? There was still a race that was left in previous attempts. If the handle closes the connection, the close() implementation would clean up state that would prevent the thread from waiting on the connection thread to finish. That could cause the race causing the test flakiness reported in the bug. The fix is to move the "wait for connection thread" code to a separate close method that is used by the handle; that also simplifies the code a bit and makes it also easier to follow. I included an unrelated, but correct, change to a YARN test so that it triggers when the PR is built. Tested by inserting a sleep in the connection thread to mimic the race; test failed reliably with the sleep, passes now. (Sleep not included in the patch.) Also ran YARN tests to make sure. Author: Marcelo Vanzin Closes #20388 from vanzin/SPARK-23020. --- .../spark/launcher/AbstractAppHandle.java | 42 ++++++++------ .../spark/launcher/ChildProcAppHandle.java | 11 +--- .../spark/launcher/InProcessAppHandle.java | 9 +-- .../apache/spark/launcher/LauncherServer.java | 55 +++++++++---------- .../spark/deploy/yarn/YarnClusterSuite.scala | 5 +- 5 files changed, 55 insertions(+), 67 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java index daf0972f824dd..84a25a5254151 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Level; import java.util.logging.Logger; @@ -29,15 +30,15 @@ abstract class AbstractAppHandle implements SparkAppHandle { private final LauncherServer server; - private LauncherConnection connection; + private LauncherServer.ServerConnection connection; private List listeners; - private State state; + private AtomicReference state; private String appId; private volatile boolean disposed; protected AbstractAppHandle(LauncherServer server) { this.server = server; - this.state = State.UNKNOWN; + this.state = new AtomicReference<>(State.UNKNOWN); } @Override @@ -50,7 +51,7 @@ public synchronized void addListener(Listener l) { @Override public State getState() { - return state; + return state.get(); } @Override @@ -73,7 +74,7 @@ public synchronized void disconnect() { if (!isDisposed()) { if (connection != null) { try { - connection.close(); + connection.closeAndWait(); } catch (IOException ioe) { // no-op. } @@ -82,7 +83,7 @@ public synchronized void disconnect() { } } - void setConnection(LauncherConnection connection) { + void setConnection(LauncherServer.ServerConnection connection) { this.connection = connection; } @@ -99,12 +100,9 @@ boolean isDisposed() { */ synchronized void dispose() { if (!isDisposed()) { - // Unregister first to make sure that the connection with the app has been really - // terminated. server.unregister(this); - if (!getState().isFinal()) { - setState(State.LOST); - } + // Set state to LOST if not yet final. + setState(State.LOST, false); this.disposed = true; } } @@ -113,14 +111,24 @@ void setState(State s) { setState(s, false); } - synchronized void setState(State s, boolean force) { - if (force || !state.isFinal()) { - state = s; + void setState(State s, boolean force) { + if (force) { + state.set(s); fireEvent(false); - } else { - LOG.log(Level.WARNING, "Backend requested transition from final state {0} to {1}.", - new Object[] { state, s }); + return; } + + State current = state.get(); + while (!current.isFinal()) { + if (state.compareAndSet(current, s)) { + fireEvent(false); + return; + } + current = state.get(); + } + + LOG.log(Level.WARNING, "Backend requested transition from final state {0} to {1}.", + new Object[] { current, s }); } synchronized void setAppId(String appId) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index 2b99461652e1f..5e3c95676ecbe 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -104,19 +104,12 @@ void monitorChild() { ec = 1; } - State currState = getState(); - State newState = null; if (ec != 0) { + State currState = getState(); // Override state with failure if the current state is not final, or is success. if (!currState.isFinal() || currState == State.FINISHED) { - newState = State.FAILED; + setState(State.FAILED, true); } - } else if (!currState.isFinal()) { - newState = State.LOST; - } - - if (newState != null) { - setState(newState, true); } disconnect(); diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index f04263cb74a58..b8030e0063a37 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -66,14 +66,7 @@ synchronized void start(String appName, Method main, String[] args) { setState(State.FAILED); } - synchronized (InProcessAppHandle.this) { - if (!isDisposed()) { - disconnect(); - if (!getState().isFinal()) { - setState(State.LOST, true); - } - } - } + disconnect(); }); app.setName(String.format(THREAD_NAME_FMT, THREAD_IDS.incrementAndGet(), appName)); diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index 8091885c4f562..f4ecd52fdeab8 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -218,32 +218,6 @@ void unregister(AbstractAppHandle handle) { } } - // If there is a live connection for this handle, we need to wait for it to finish before - // returning, otherwise there might be a race between the connection thread processing - // buffered data and the handle cleaning up after itself, leading to potentially the wrong - // state being reported for the handle. - ServerConnection conn = null; - synchronized (clients) { - for (ServerConnection c : clients) { - if (c.handle == handle) { - conn = c; - break; - } - } - } - - if (conn != null) { - synchronized (conn) { - if (conn.isOpen()) { - try { - conn.wait(); - } catch (InterruptedException ie) { - // Ignore. - } - } - } - } - unref(); } @@ -312,9 +286,10 @@ private String createSecret() { } } - private class ServerConnection extends LauncherConnection { + class ServerConnection extends LauncherConnection { private TimerTask timeout; + private volatile Thread connectionThread; volatile AbstractAppHandle handle; ServerConnection(Socket socket, TimerTask timeout) throws IOException { @@ -322,6 +297,12 @@ private class ServerConnection extends LauncherConnection { this.timeout = timeout; } + @Override + public void run() { + this.connectionThread = Thread.currentThread(); + super.run(); + } + @Override protected void handle(Message msg) throws IOException { try { @@ -376,9 +357,23 @@ public void close() throws IOException { clients.remove(this); } - synchronized (this) { - super.close(); - notifyAll(); + super.close(); + } + + /** + * Close the connection and wait for any buffered data to be processed before returning. + * This ensures any changes reported by the child application take effect. + */ + public void closeAndWait() throws IOException { + close(); + + Thread connThread = this.connectionThread; + if (Thread.currentThread() != connThread) { + try { + connThread.join(); + } catch (InterruptedException ie) { + // Ignore. + } } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index e9dcfaf6ba4f0..5003326b440bf 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -45,8 +45,7 @@ import org.apache.spark.util.Utils /** * Integration tests for YARN; these tests use a mini Yarn cluster to run Spark-on-YARN - * applications, and require the Spark assembly to be built before they can be successfully - * run. + * applications. */ @ExtendedYarnTest class YarnClusterSuite extends BaseYarnClusterSuite { @@ -152,7 +151,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { } test("run Python application in yarn-cluster mode using " + - " spark.yarn.appMasterEnv to override local envvar") { + "spark.yarn.appMasterEnv to override local envvar") { testPySpark( clientMode = false, extraConf = Map( From d1721816d26bedee3c72eeb75db49da500568376 Mon Sep 17 00:00:00 2001 From: Santiago Saavedra Date: Fri, 26 Jan 2018 15:24:06 +0800 Subject: [PATCH 0208/2461] [SPARK-23200] Reset Kubernetes-specific config on Checkpoint restore ## What changes were proposed in this pull request? When using the Kubernetes cluster-manager and spawning a Streaming workload, it is important to reset many spark.kubernetes.* properties that are generated by spark-submit but which would get rewritten when restoring a Checkpoint. This is so, because the spark-submit codepath creates Kubernetes resources, such as a ConfigMap, a Secret and other variables, which have an autogenerated name and the previous one will not resolve anymore. In short, this change enables checkpoint restoration for streaming workloads, and thus enables Spark Streaming workloads in Kubernetes, which were not possible to restore from a checkpoint before if the workload went down. ## How was this patch tested? This patch was tested with the twitter-streaming example in AWS, using checkpoints in s3 with the s3a:// protocol, as supported by Hadoop. This is similar to the YARN related code for resetting a Spark Streaming workload, but for the Kubernetes scheduler. I'm adding the initcontainers properties because even if the discussion is not completely settled on the mailing list, my understanding is that at this moment they are going forward for the moment. For a previous discussion, see the non-rebased work at: https://github.com/apache-spark-on-k8s/spark/pull/516 Author: Santiago Saavedra Closes #20383 from ssaavedra/fix-k8s-checkpointing. --- .../org/apache/spark/streaming/Checkpoint.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index aed67a5027433..ed2a896033749 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -53,6 +53,21 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.driver.host", "spark.driver.bindAddress", "spark.driver.port", + "spark.kubernetes.driver.pod.name", + "spark.kubernetes.executor.podNamePrefix", + "spark.kubernetes.initcontainer.executor.configmapname", + "spark.kubernetes.initcontainer.executor.configmapkey", + "spark.kubernetes.initcontainer.downloadJarsResourceIdentifier", + "spark.kubernetes.initcontainer.downloadJarsSecretLocation", + "spark.kubernetes.initcontainer.downloadFilesResourceIdentifier", + "spark.kubernetes.initcontainer.downloadFilesSecretLocation", + "spark.kubernetes.initcontainer.remoteJars", + "spark.kubernetes.initcontainer.remoteFiles", + "spark.kubernetes.mountdependencies.jarsDownloadDir", + "spark.kubernetes.mountdependencies.filesDownloadDir", + "spark.kubernetes.initcontainer.executor.stagingServerSecret.name", + "spark.kubernetes.initcontainer.executor.stagingServerSecret.mountDir", + "spark.kubernetes.executor.limit.cores", "spark.master", "spark.yarn.jars", "spark.yarn.keytab", @@ -66,6 +81,7 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) .remove("spark.driver.host") .remove("spark.driver.bindAddress") + .remove("spark.kubernetes.driver.pod.name") .remove("spark.driver.port") val newReloadConf = new SparkConf(loadDefaults = true) propertiesToReload.foreach { prop => From cd3956df0f96dd416b6161bf7ce2962e06d0a62e Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 26 Jan 2018 12:23:14 +0200 Subject: [PATCH 0209/2461] [SPARK-22799][ML] Bucketizer should throw exception if single- and multi-column params are both set ## What changes were proposed in this pull request? Currently there is a mixed situation when both single- and multi-column are supported. In some cases exceptions are thrown, in others only a warning log is emitted. In this discussion https://issues.apache.org/jira/browse/SPARK-8418?focusedCommentId=16275049&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-16275049, the decision was to throw an exception. The PR throws an exception in `Bucketizer`, instead of logging a warning. ## How was this patch tested? modified UT Author: Marco Gaido Author: Joseph K. Bradley Closes #19993 from mgaido91/SPARK-22799. --- .../apache/spark/ml/feature/Bucketizer.scala | 44 +++++------- .../org/apache/spark/ml/param/params.scala | 69 +++++++++++++++++++ .../spark/ml/feature/BucketizerSuite.scala | 41 +++++------ .../apache/spark/ml/param/ParamsSuite.scala | 22 ++++++ 4 files changed, 131 insertions(+), 45 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 8299a3e95d822..c13bf47eacb94 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -32,11 +32,13 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** - * `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0, + * `Bucketizer` maps a column of continuous features to a column of feature buckets. + * + * Since 2.3.0, * `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that - * when both the `inputCol` and `inputCols` parameters are set, a log warning will be printed and - * only `inputCol` will take effect, while `inputCols` will be ignored. The `splits` parameter is - * only used for single column usage, and `splitsArray` is for multiple columns. + * when both the `inputCol` and `inputCols` parameters are set, an Exception will be thrown. The + * `splits` parameter is only used for single column usage, and `splitsArray` is for multiple + * columns. */ @Since("1.4.0") final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) @@ -134,28 +136,11 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("2.3.0") def setOutputCols(value: Array[String]): this.type = set(outputCols, value) - /** - * Determines whether this `Bucketizer` is going to map multiple columns. If and only if - * `inputCols` is set, it will map multiple columns. Otherwise, it just maps a column specified - * by `inputCol`. A warning will be printed if both are set. - */ - private[feature] def isBucketizeMultipleColumns(): Boolean = { - if (isSet(inputCols) && isSet(inputCol)) { - logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " + - "`Bucketizer` only map one column specified by `inputCol`") - false - } else if (isSet(inputCols)) { - true - } else { - false - } - } - @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { val transformedSchema = transformSchema(dataset.schema) - val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) { + val (inputColumns, outputColumns) = if (isSet(inputCols)) { ($(inputCols).toSeq, $(outputCols).toSeq) } else { (Seq($(inputCol)), Seq($(outputCol))) @@ -170,7 +155,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String } } - val seqOfSplits = if (isBucketizeMultipleColumns()) { + val seqOfSplits = if (isSet(inputCols)) { $(splitsArray).toSeq } else { Seq($(splits)) @@ -201,9 +186,18 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { - if (isBucketizeMultipleColumns()) { + ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol, splits), + Seq(outputCols, splitsArray)) + + if (isSet(inputCols)) { + require(getInputCols.length == getOutputCols.length && + getInputCols.length == getSplitsArray.length, s"Bucketizer $this has mismatched Params " + + s"for multi-column transform. Params (inputCols, outputCols, splitsArray) should have " + + s"equal lengths, but they have different lengths: " + + s"(${getInputCols.length}, ${getOutputCols.length}, ${getSplitsArray.length}).") + var transformedSchema = schema - $(inputCols).zip($(outputCols)).zipWithIndex.map { case ((inputCol, outputCol), idx) => + $(inputCols).zip($(outputCols)).zipWithIndex.foreach { case ((inputCol, outputCol), idx) => SchemaUtils.checkNumericType(transformedSchema, inputCol) transformedSchema = SchemaUtils.appendColumn(transformedSchema, prepOutputField($(splitsArray)(idx), outputCol)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 1b4b401ac4aa0..9a83a5882ce29 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -249,6 +249,75 @@ object ParamValidators { def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T]) => value.length > lowerBound } + + /** + * Utility for Param validity checks for Transformers which have both single- and multi-column + * support. This utility assumes that `inputCol` indicates single-column usage and + * that `inputCols` indicates multi-column usage. + * + * This checks to ensure that exactly one set of Params has been set, and it + * raises an `IllegalArgumentException` if not. + * + * @param singleColumnParams Params which should be set (or have defaults) if `inputCol` has been + * set. This does not need to include `inputCol`. + * @param multiColumnParams Params which should be set (or have defaults) if `inputCols` has been + * set. This does not need to include `inputCols`. + */ + def checkSingleVsMultiColumnParams( + model: Params, + singleColumnParams: Seq[Param[_]], + multiColumnParams: Seq[Param[_]]): Unit = { + val name = s"${model.getClass.getSimpleName} $model" + + def checkExclusiveParams( + isSingleCol: Boolean, + requiredParams: Seq[Param[_]], + excludedParams: Seq[Param[_]]): Unit = { + val badParamsMsgBuilder = new mutable.StringBuilder() + + val mustUnsetParams = excludedParams.filter(p => model.isSet(p)) + .map(_.name).mkString(", ") + if (mustUnsetParams.nonEmpty) { + badParamsMsgBuilder ++= + s"The following Params are not applicable and should not be set: $mustUnsetParams." + } + + val mustSetParams = requiredParams.filter(p => !model.isDefined(p)) + .map(_.name).mkString(", ") + if (mustSetParams.nonEmpty) { + badParamsMsgBuilder ++= + s"The following Params must be defined but are not set: $mustSetParams." + } + + val badParamsMsg = badParamsMsgBuilder.toString() + + if (badParamsMsg.nonEmpty) { + val errPrefix = if (isSingleCol) { + s"$name has the inputCol Param set for single-column transform." + } else { + s"$name has the inputCols Param set for multi-column transform." + } + throw new IllegalArgumentException(s"$errPrefix $badParamsMsg") + } + } + + val inputCol = model.getParam("inputCol") + val inputCols = model.getParam("inputCols") + + if (model.isSet(inputCol)) { + require(!model.isSet(inputCols), s"$name requires " + + s"exactly one of inputCol, inputCols Params to be set, but both are set.") + + checkExclusiveParams(isSingleCol = true, requiredParams = singleColumnParams, + excludedParams = multiColumnParams) + } else if (model.isSet(inputCols)) { + checkExclusiveParams(isSingleCol = false, requiredParams = multiColumnParams, + excludedParams = singleColumnParams) + } else { + throw new IllegalArgumentException(s"$name requires " + + s"exactly one of inputCol, inputCols Params to be set, but neither is set.") + } + } } // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index d9c97ae8067d3..7403680ae3fdc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -216,8 +216,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result1", "result2")) .setSplitsArray(splits) - assert(bucketizer1.isBucketizeMultipleColumns()) - bucketizer1.transform(dataFrame).select("result1", "expected1", "result2", "expected2") BucketizerSuite.checkBucketResults(bucketizer1.transform(dataFrame), Seq("result1", "result2"), @@ -233,8 +231,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result")) .setSplitsArray(Array(splits(0))) - assert(bucketizer2.isBucketizeMultipleColumns()) - withClue("Invalid feature value -0.9 was not caught as an invalid feature!") { intercept[SparkException] { bucketizer2.transform(badDF1).collect() @@ -268,8 +264,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result1", "result2")) .setSplitsArray(splits) - assert(bucketizer.isBucketizeMultipleColumns()) - BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame), Seq("result1", "result2"), Seq("expected1", "expected2")) @@ -295,8 +289,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result1", "result2")) .setSplitsArray(splits) - assert(bucketizer.isBucketizeMultipleColumns()) - bucketizer.setHandleInvalid("keep") BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame), Seq("result1", "result2"), @@ -335,7 +327,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setInputCols(Array("myInputCol")) .setOutputCols(Array("myOutputCol")) .setSplitsArray(Array(Array(0.1, 0.8, 0.9))) - assert(t.isBucketizeMultipleColumns()) testDefaultReadWrite(t) } @@ -348,8 +339,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result1", "result2")) .setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5))) - assert(bucket.isBucketizeMultipleColumns()) - val pl = new Pipeline() .setStages(Array(bucket)) .fit(df) @@ -401,15 +390,27 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } } - test("Both inputCol and inputCols are set") { - val bucket = new Bucketizer() - .setInputCol("feature1") - .setOutputCol("result") - .setSplits(Array(-0.5, 0.0, 0.5)) - .setInputCols(Array("feature1", "feature2")) - - // When both are set, we ignore `inputCols` and just map the column specified by `inputCol`. - assert(bucket.isBucketizeMultipleColumns() == false) + test("assert exception is thrown if both multi-column and single-column params are set") { + val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2") + ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"), + ("inputCols", Array("feature1", "feature2"))) + ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"), + ("outputCol", "result1"), ("splits", Array(-0.5, 0.0, 0.5)), + ("outputCols", Array("result1", "result2"))) + ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"), + ("outputCol", "result1"), ("splits", Array(-0.5, 0.0, 0.5)), + ("splitsArray", Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5)))) + + // this should fail because at least one of inputCol and inputCols must be set + ParamsSuite.testExclusiveParams(new Bucketizer, df, ("outputCol", "feature1"), + ("splits", Array(-0.5, 0.0, 0.5))) + + // the following should fail because not all the params are set + ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"), + ("outputCol", "result1")) + ParamsSuite.testExclusiveParams(new Bucketizer, df, + ("inputCols", Array("feature1", "feature2")), + ("outputCols", Array("result1", "result2"))) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 85198ad4c913a..36e06091d24de 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -20,8 +20,10 @@ package org.apache.spark.ml.param import java.io.{ByteArrayOutputStream, ObjectOutputStream} import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.{Estimator, Transformer} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.util.MyParams +import org.apache.spark.sql.Dataset class ParamsSuite extends SparkFunSuite { @@ -430,4 +432,24 @@ object ParamsSuite extends SparkFunSuite { require(copyReturnType === obj.getClass, s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.") } + + /** + * Checks that the class throws an exception in case multiple exclusive params are set. + * The params to be checked are passed as arguments with their value. + */ + def testExclusiveParams( + model: Params, + dataset: Dataset[_], + paramsAndValues: (String, Any)*): Unit = { + val m = model.copy(ParamMap.empty) + paramsAndValues.foreach { case (paramName, paramValue) => + m.set(m.getParam(paramName), paramValue) + } + intercept[IllegalArgumentException] { + m match { + case t: Transformer => t.transform(dataset) + case e: Estimator[_] => e.fit(dataset) + } + } + } } From c22eaa94e85aaac649566495dcf763a5de3c8d06 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Fri, 26 Jan 2018 12:28:27 +0200 Subject: [PATCH 0210/2461] [SPARK-22797][PYSPARK] Bucketizer support multi-column ## What changes were proposed in this pull request? Bucketizer support multi-column in the python side ## How was this patch tested? existing tests and added tests Author: Zheng RuiFeng Closes #19892 from zhengruifeng/20542_py. --- python/pyspark/ml/feature.py | 105 +++++++++++++++++++++------- python/pyspark/ml/param/__init__.py | 10 +++ python/pyspark/ml/tests.py | 9 +++ 3 files changed, 99 insertions(+), 25 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index da85ba761a145..fdc7787140490 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -317,26 +317,33 @@ class BucketedRandomProjectionLSHModel(LSHModel, JavaMLReadable, JavaMLWritable) @inherit_doc -class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid, - JavaMLReadable, JavaMLWritable): - """ - Maps a column of continuous features to a column of feature buckets. - - >>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)] - >>> df = spark.createDataFrame(values, ["values"]) +class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols, + HasHandleInvalid, JavaMLReadable, JavaMLWritable): + """ + Maps a column of continuous features to a column of feature buckets. Since 2.3.0, + :py:class:`Bucketizer` can map multiple columns at once by setting the :py:attr:`inputCols` + parameter. Note that when both the :py:attr:`inputCol` and :py:attr:`inputCols` parameters + are set, an Exception will be thrown. The :py:attr:`splits` parameter is only used for single + column usage, and :py:attr:`splitsArray` is for multiple columns. + + >>> values = [(0.1, 0.0), (0.4, 1.0), (1.2, 1.3), (1.5, float("nan")), + ... (float("nan"), 1.0), (float("nan"), 0.0)] + >>> df = spark.createDataFrame(values, ["values1", "values2"]) >>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")], - ... inputCol="values", outputCol="buckets") - >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df).collect() - >>> len(bucketed) - 6 - >>> bucketed[0].buckets - 0.0 - >>> bucketed[1].buckets - 0.0 - >>> bucketed[2].buckets - 1.0 - >>> bucketed[3].buckets - 2.0 + ... inputCol="values1", outputCol="buckets") + >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df.select("values1")) + >>> bucketed.show(truncate=False) + +-------+-------+ + |values1|buckets| + +-------+-------+ + |0.1 |0.0 | + |0.4 |0.0 | + |1.2 |1.0 | + |1.5 |2.0 | + |NaN |3.0 | + |NaN |3.0 | + +-------+-------+ + ... >>> bucketizer.setParams(outputCol="b").transform(df).head().b 0.0 >>> bucketizerPath = temp_path + "/bucketizer" @@ -347,6 +354,22 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid, >>> bucketed = bucketizer.setHandleInvalid("skip").transform(df).collect() >>> len(bucketed) 4 + >>> bucketizer2 = Bucketizer(splitsArray= + ... [[-float("inf"), 0.5, 1.4, float("inf")], [-float("inf"), 0.5, float("inf")]], + ... inputCols=["values1", "values2"], outputCols=["buckets1", "buckets2"]) + >>> bucketed2 = bucketizer2.setHandleInvalid("keep").transform(df) + >>> bucketed2.show(truncate=False) + +-------+-------+--------+--------+ + |values1|values2|buckets1|buckets2| + +-------+-------+--------+--------+ + |0.1 |0.0 |0.0 |0.0 | + |0.4 |1.0 |0.0 |1.0 | + |1.2 |1.3 |1.0 |1.0 | + |1.5 |NaN |2.0 |2.0 | + |NaN |1.0 |3.0 |1.0 | + |NaN |0.0 |3.0 |0.0 | + +-------+-------+--------+--------+ + ... .. versionadded:: 1.4.0 """ @@ -363,14 +386,30 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid, handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " + "Options are 'skip' (filter out rows with invalid values), " + - "'error' (throw an error), or 'keep' (keep invalid values in a special " + - "additional bucket).", + "'error' (throw an error), or 'keep' (keep invalid values in a " + + "special additional bucket). Note that in the multiple column " + + "case, the invalid handling is applied to all columns. That said " + + "for 'error' it will throw an error if any invalids are found in " + + "any column, for 'skip' it will skip rows with any invalids in " + + "any columns, etc.", typeConverter=TypeConverters.toString) + splitsArray = Param(Params._dummy(), "splitsArray", "The array of split points for mapping " + + "continuous features into buckets for multiple columns. For each input " + + "column, with n+1 splits, there are n buckets. A bucket defined by " + + "splits x,y holds values in the range [x,y) except the last bucket, " + + "which also includes y. The splits should be of length >= 3 and " + + "strictly increasing. Values at -inf, inf must be explicitly provided " + + "to cover all Double values; otherwise, values outside the splits " + + "specified will be treated as errors.", + typeConverter=TypeConverters.toListListFloat) + @keyword_only - def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"): + def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", + splitsArray=None, inputCols=None, outputCols=None): """ - __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error") + __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \ + splitsArray=None, inputCols=None, outputCols=None) """ super(Bucketizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid) @@ -380,9 +419,11 @@ def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="er @keyword_only @since("1.4.0") - def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"): + def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", + splitsArray=None, inputCols=None, outputCols=None): """ - setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error") + setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \ + splitsArray=None, inputCols=None, outputCols=None) Sets params for this Bucketizer. """ kwargs = self._input_kwargs @@ -402,6 +443,20 @@ def getSplits(self): """ return self.getOrDefault(self.splits) + @since("2.3.0") + def setSplitsArray(self, value): + """ + Sets the value of :py:attr:`splitsArray`. + """ + return self._set(splitsArray=value) + + @since("2.3.0") + def getSplitsArray(self): + """ + Gets the array of split points or its default value. + """ + return self.getOrDefault(self.splitsArray) + @inherit_doc class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 043c25cf9feb4..5b6b70292f099 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -134,6 +134,16 @@ def toListFloat(value): return [float(v) for v in value] raise TypeError("Could not convert %s to list of floats" % value) + @staticmethod + def toListListFloat(value): + """ + Convert a value to list of list of floats, if possible. + """ + if TypeConverters._can_convert_to_list(value): + value = TypeConverters.toList(value) + return [TypeConverters.toListFloat(v) for v in value] + raise TypeError("Could not convert %s to list of list of floats" % value) + @staticmethod def toListInt(value): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 1af2b91da900d..b8bddbd06f165 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -238,6 +238,15 @@ def test_bool(self): self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept=1)) self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept="false")) + def test_list_list_float(self): + b = Bucketizer(splitsArray=[[-0.1, 0.5, 3], [-5, 1.5]]) + self.assertEqual(b.getSplitsArray(), [[-0.1, 0.5, 3.0], [-5.0, 1.5]]) + self.assertTrue(all([type(v) == list for v in b.getSplitsArray()])) + self.assertTrue(all([type(v) == float for v in b.getSplitsArray()[0]])) + self.assertTrue(all([type(v) == float for v in b.getSplitsArray()[1]])) + self.assertRaises(TypeError, lambda: Bucketizer(splitsArray=["a", 1.0])) + self.assertRaises(TypeError, lambda: Bucketizer(splitsArray=[[-5, 1.5], ["a", 1.0]])) + class PipelineTests(PySparkTestCase): From 3e252514741447004f3c18ddd77c617b4e37cfaa Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Fri, 26 Jan 2018 19:18:18 +0800 Subject: [PATCH 0211/2461] [SPARK-22068][CORE] Reduce the duplicate code between putIteratorAsValues and putIteratorAsBytes ## What changes were proposed in this pull request? The code logic between `MemoryStore.putIteratorAsValues` and `Memory.putIteratorAsBytes` are almost same, so we should reduce the duplicate code between them. ## How was this patch tested? Existing UT. Author: Xianyang Liu Closes #19285 from ConeyLiu/rmemorystore. --- .../spark/storage/memory/MemoryStore.scala | 336 ++++++++++-------- 1 file changed, 178 insertions(+), 158 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 17f7a69ad6ba1..4cc5bcb7f9baf 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -162,7 +162,7 @@ private[spark] class MemoryStore( } /** - * Attempt to put the given block in memory store as values. + * Attempt to put the given block in memory store as values or bytes. * * It's possible that the iterator is too large to materialize and store in memory. To avoid * OOM exceptions, this method will gradually unroll the iterator while periodically checking @@ -170,18 +170,24 @@ private[spark] class MemoryStore( * temporary unroll memory used during the materialization is "transferred" to storage memory, * so we won't acquire more memory than is actually needed to store the block. * - * @return in case of success, the estimated size of the stored data. In case of failure, return - * an iterator containing the values of the block. The returned iterator will be backed - * by the combination of the partially-unrolled block and the remaining elements of the - * original input iterator. The caller must either fully consume this iterator or call - * `close()` on it in order to free the storage memory consumed by the partially-unrolled - * block. + * @param blockId The block id. + * @param values The values which need be stored. + * @param classTag the [[ClassTag]] for the block. + * @param memoryMode The values saved memory mode(ON_HEAP or OFF_HEAP). + * @param valuesHolder A holder that supports storing record of values into memory store as + * values or bytes. + * @return if the block is stored successfully, return the stored data size. Else return the + * memory has reserved for unrolling the block (There are two reasons for store failed: + * First, the block is partially-unrolled; second, the block is entirely unrolled and + * the actual stored data size is larger than reserved, but we can't request extra + * memory). */ - private[storage] def putIteratorAsValues[T]( + private def putIterator[T]( blockId: BlockId, values: Iterator[T], - classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = { - + classTag: ClassTag[T], + memoryMode: MemoryMode, + valuesHolder: ValuesHolder[T]): Either[Long, Long] = { require(!contains(blockId), s"Block $blockId is already present in the MemoryStore") // Number of elements unrolled so far @@ -198,12 +204,10 @@ private[spark] class MemoryStore( val memoryGrowthFactor = conf.get(UNROLL_MEMORY_GROWTH_FACTOR) // Keep track of unroll memory used by this particular block / putIterator() operation var unrollMemoryUsedByThisBlock = 0L - // Underlying vector for unrolling the block - var vector = new SizeTrackingVector[T]()(classTag) // Request enough memory to begin unrolling keepUnrolling = - reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, MemoryMode.ON_HEAP) + reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, memoryMode) if (!keepUnrolling) { logWarning(s"Failed to reserve initial memory threshold of " + @@ -214,14 +218,14 @@ private[spark] class MemoryStore( // Unroll this block safely, checking whether we have exceeded our threshold periodically while (values.hasNext && keepUnrolling) { - vector += values.next() + valuesHolder.storeValue(values.next()) if (elementsUnrolled % memoryCheckPeriod == 0) { + val currentSize = valuesHolder.estimatedSize() // If our vector's size has exceeded the threshold, request more memory - val currentSize = vector.estimateSize() if (currentSize >= memoryThreshold) { val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong keepUnrolling = - reserveUnrollMemoryForThisTask(blockId, amountToRequest, MemoryMode.ON_HEAP) + reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) if (keepUnrolling) { unrollMemoryUsedByThisBlock += amountToRequest } @@ -232,78 +236,86 @@ private[spark] class MemoryStore( elementsUnrolled += 1 } + // Make sure that we have enough memory to store the block. By this point, it is possible that + // the block's actual memory usage has exceeded the unroll memory by a small amount, so we + // perform one final call to attempt to allocate additional memory if necessary. if (keepUnrolling) { - // We successfully unrolled the entirety of this block - val arrayValues = vector.toArray - vector = null - val entry = - new DeserializedMemoryEntry[T](arrayValues, SizeEstimator.estimate(arrayValues), classTag) - val size = entry.size - def transferUnrollToStorage(amount: Long): Unit = { + val entryBuilder = valuesHolder.getBuilder() + val size = entryBuilder.preciseSize + if (size > unrollMemoryUsedByThisBlock) { + val amountToRequest = size - unrollMemoryUsedByThisBlock + keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) + if (keepUnrolling) { + unrollMemoryUsedByThisBlock += amountToRequest + } + } + + if (keepUnrolling) { + val entry = entryBuilder.build() // Synchronize so that transfer is atomic memoryManager.synchronized { - releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, amount) - val success = memoryManager.acquireStorageMemory(blockId, amount, MemoryMode.ON_HEAP) + releaseUnrollMemoryForThisTask(memoryMode, unrollMemoryUsedByThisBlock) + val success = memoryManager.acquireStorageMemory(blockId, entry.size, memoryMode) assert(success, "transferring unroll memory to storage memory failed") } - } - // Acquire storage memory if necessary to store this block in memory. - val enoughStorageMemory = { - if (unrollMemoryUsedByThisBlock <= size) { - val acquiredExtra = - memoryManager.acquireStorageMemory( - blockId, size - unrollMemoryUsedByThisBlock, MemoryMode.ON_HEAP) - if (acquiredExtra) { - transferUnrollToStorage(unrollMemoryUsedByThisBlock) - } - acquiredExtra - } else { // unrollMemoryUsedByThisBlock > size - // If this task attempt already owns more unroll memory than is necessary to store the - // block, then release the extra memory that will not be used. - val excessUnrollMemory = unrollMemoryUsedByThisBlock - size - releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, excessUnrollMemory) - transferUnrollToStorage(size) - true - } - } - if (enoughStorageMemory) { + entries.synchronized { entries.put(blockId, entry) } - logInfo("Block %s stored as values in memory (estimated size %s, free %s)".format( - blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed))) - Right(size) + + logInfo("Block %s stored as values in memory (estimated size %s, free %s)".format(blockId, + Utils.bytesToString(entry.size), Utils.bytesToString(maxMemory - blocksMemoryUsed))) + Right(entry.size) } else { - assert(currentUnrollMemoryForThisTask >= unrollMemoryUsedByThisBlock, - "released too much unroll memory") + // We ran out of space while unrolling the values for this block + logUnrollFailureMessage(blockId, entryBuilder.preciseSize) + Left(unrollMemoryUsedByThisBlock) + } + } else { + // We ran out of space while unrolling the values for this block + logUnrollFailureMessage(blockId, valuesHolder.estimatedSize()) + Left(unrollMemoryUsedByThisBlock) + } + } + + /** + * Attempt to put the given block in memory store as values. + * + * @return in case of success, the estimated size of the stored data. In case of failure, return + * an iterator containing the values of the block. The returned iterator will be backed + * by the combination of the partially-unrolled block and the remaining elements of the + * original input iterator. The caller must either fully consume this iterator or call + * `close()` on it in order to free the storage memory consumed by the partially-unrolled + * block. + */ + private[storage] def putIteratorAsValues[T]( + blockId: BlockId, + values: Iterator[T], + classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = { + + val valuesHolder = new DeserializedValuesHolder[T](classTag) + + putIterator(blockId, values, classTag, MemoryMode.ON_HEAP, valuesHolder) match { + case Right(storedSize) => Right(storedSize) + case Left(unrollMemoryUsedByThisBlock) => + val unrolledIterator = if (valuesHolder.vector != null) { + valuesHolder.vector.iterator + } else { + valuesHolder.arrayValues.toIterator + } + Left(new PartiallyUnrolledIterator( this, MemoryMode.ON_HEAP, unrollMemoryUsedByThisBlock, - unrolled = arrayValues.toIterator, - rest = Iterator.empty)) - } - } else { - // We ran out of space while unrolling the values for this block - logUnrollFailureMessage(blockId, vector.estimateSize()) - Left(new PartiallyUnrolledIterator( - this, - MemoryMode.ON_HEAP, - unrollMemoryUsedByThisBlock, - unrolled = vector.iterator, - rest = values)) + unrolled = unrolledIterator, + rest = values)) } } /** * Attempt to put the given block in memory store as bytes. * - * It's possible that the iterator is too large to materialize and store in memory. To avoid - * OOM exceptions, this method will gradually unroll the iterator while periodically checking - * whether there is enough free memory. If the block is successfully materialized, then the - * temporary unroll memory used during the materialization is "transferred" to storage memory, - * so we won't acquire more memory than is actually needed to store the block. - * * @return in case of success, the estimated size of the stored data. In case of failure, * return a handle which allows the caller to either finish the serialization by * spilling to disk or to deserialize the partially-serialized block and reconstruct @@ -319,25 +331,8 @@ private[spark] class MemoryStore( require(!contains(blockId), s"Block $blockId is already present in the MemoryStore") - val allocator = memoryMode match { - case MemoryMode.ON_HEAP => ByteBuffer.allocate _ - case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _ - } - - // Whether there is still enough memory for us to continue unrolling this block - var keepUnrolling = true - // Number of elements unrolled so far - var elementsUnrolled = 0L - // How often to check whether we need to request more memory - val memoryCheckPeriod = conf.get(UNROLL_MEMORY_CHECK_PERIOD) - // Memory to request as a multiple of current bbos size - val memoryGrowthFactor = conf.get(UNROLL_MEMORY_GROWTH_FACTOR) // Initial per-task memory to request for unrolling blocks (bytes). val initialMemoryThreshold = unrollMemoryThreshold - // Keep track of unroll memory used by this particular block / putIterator() operation - var unrollMemoryUsedByThisBlock = 0L - // Underlying buffer for unrolling the block - val redirectableStream = new RedirectableOutputStream val chunkSize = if (initialMemoryThreshold > Int.MaxValue) { logWarning(s"Initial memory threshold of ${Utils.bytesToString(initialMemoryThreshold)} " + s"is too large to be set as chunk size. Chunk size has been capped to " + @@ -346,85 +341,22 @@ private[spark] class MemoryStore( } else { initialMemoryThreshold.toInt } - val bbos = new ChunkedByteBufferOutputStream(chunkSize, allocator) - redirectableStream.setOutputStream(bbos) - val serializationStream: SerializationStream = { - val autoPick = !blockId.isInstanceOf[StreamBlockId] - val ser = serializerManager.getSerializer(classTag, autoPick).newInstance() - ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream)) - } - // Request enough memory to begin unrolling - keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, memoryMode) + val valuesHolder = new SerializedValuesHolder[T](blockId, chunkSize, classTag, + memoryMode, serializerManager) - if (!keepUnrolling) { - logWarning(s"Failed to reserve initial memory threshold of " + - s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.") - } else { - unrollMemoryUsedByThisBlock += initialMemoryThreshold - } - - def reserveAdditionalMemoryIfNecessary(): Unit = { - if (bbos.size > unrollMemoryUsedByThisBlock) { - val amountToRequest = (bbos.size * memoryGrowthFactor - unrollMemoryUsedByThisBlock).toLong - keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) - if (keepUnrolling) { - unrollMemoryUsedByThisBlock += amountToRequest - } - } - } - - // Unroll this block safely, checking whether we have exceeded our threshold - while (values.hasNext && keepUnrolling) { - serializationStream.writeObject(values.next())(classTag) - elementsUnrolled += 1 - if (elementsUnrolled % memoryCheckPeriod == 0) { - reserveAdditionalMemoryIfNecessary() - } - } - - // Make sure that we have enough memory to store the block. By this point, it is possible that - // the block's actual memory usage has exceeded the unroll memory by a small amount, so we - // perform one final call to attempt to allocate additional memory if necessary. - if (keepUnrolling) { - serializationStream.close() - if (bbos.size > unrollMemoryUsedByThisBlock) { - val amountToRequest = bbos.size - unrollMemoryUsedByThisBlock - keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) - if (keepUnrolling) { - unrollMemoryUsedByThisBlock += amountToRequest - } - } - } - - if (keepUnrolling) { - val entry = SerializedMemoryEntry[T](bbos.toChunkedByteBuffer, memoryMode, classTag) - // Synchronize so that transfer is atomic - memoryManager.synchronized { - releaseUnrollMemoryForThisTask(memoryMode, unrollMemoryUsedByThisBlock) - val success = memoryManager.acquireStorageMemory(blockId, entry.size, memoryMode) - assert(success, "transferring unroll memory to storage memory failed") - } - entries.synchronized { - entries.put(blockId, entry) - } - logInfo("Block %s stored as bytes in memory (estimated size %s, free %s)".format( - blockId, Utils.bytesToString(entry.size), - Utils.bytesToString(maxMemory - blocksMemoryUsed))) - Right(entry.size) - } else { - // We ran out of space while unrolling the values for this block - logUnrollFailureMessage(blockId, bbos.size) - Left( - new PartiallySerializedBlock( + putIterator(blockId, values, classTag, memoryMode, valuesHolder) match { + case Right(storedSize) => Right(storedSize) + case Left(unrollMemoryUsedByThisBlock) => + Left(new PartiallySerializedBlock( this, serializerManager, blockId, - serializationStream, - redirectableStream, + valuesHolder.serializationStream, + valuesHolder.redirectableStream, unrollMemoryUsedByThisBlock, memoryMode, - bbos, + valuesHolder.bbos, values, classTag)) } @@ -702,6 +634,94 @@ private[spark] class MemoryStore( } } +private trait MemoryEntryBuilder[T] { + def preciseSize: Long + def build(): MemoryEntry[T] +} + +private trait ValuesHolder[T] { + def storeValue(value: T): Unit + def estimatedSize(): Long + + /** + * Note: After this method is called, the ValuesHolder is invalid, we can't store data and + * get estimate size again. + * @return a MemoryEntryBuilder which is used to build a memory entry and get the stored data + * size. + */ + def getBuilder(): MemoryEntryBuilder[T] +} + +/** + * A holder for storing the deserialized values. + */ +private class DeserializedValuesHolder[T] (classTag: ClassTag[T]) extends ValuesHolder[T] { + // Underlying vector for unrolling the block + var vector = new SizeTrackingVector[T]()(classTag) + var arrayValues: Array[T] = null + + override def storeValue(value: T): Unit = { + vector += value + } + + override def estimatedSize(): Long = { + vector.estimateSize() + } + + override def getBuilder(): MemoryEntryBuilder[T] = new MemoryEntryBuilder[T] { + // We successfully unrolled the entirety of this block + arrayValues = vector.toArray + vector = null + + override val preciseSize: Long = SizeEstimator.estimate(arrayValues) + + override def build(): MemoryEntry[T] = + DeserializedMemoryEntry[T](arrayValues, preciseSize, classTag) + } +} + +/** + * A holder for storing the serialized values. + */ +private class SerializedValuesHolder[T]( + blockId: BlockId, + chunkSize: Int, + classTag: ClassTag[T], + memoryMode: MemoryMode, + serializerManager: SerializerManager) extends ValuesHolder[T] { + val allocator = memoryMode match { + case MemoryMode.ON_HEAP => ByteBuffer.allocate _ + case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _ + } + + val redirectableStream = new RedirectableOutputStream + val bbos = new ChunkedByteBufferOutputStream(chunkSize, allocator) + redirectableStream.setOutputStream(bbos) + val serializationStream: SerializationStream = { + val autoPick = !blockId.isInstanceOf[StreamBlockId] + val ser = serializerManager.getSerializer(classTag, autoPick).newInstance() + ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream)) + } + + override def storeValue(value: T): Unit = { + serializationStream.writeObject(value)(classTag) + } + + override def estimatedSize(): Long = { + bbos.size + } + + override def getBuilder(): MemoryEntryBuilder[T] = new MemoryEntryBuilder[T] { + // We successfully unrolled the entirety of this block + serializationStream.close() + + override def preciseSize(): Long = bbos.size + + override def build(): MemoryEntry[T] = + SerializedMemoryEntry[T](bbos.toChunkedByteBuffer, memoryMode, classTag) + } +} + /** * The result of a failed [[MemoryStore.putIteratorAsValues()]] call. * From dd8e257d1ccf20f4383dd7f30d634010b176f0d3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 26 Jan 2018 09:17:05 -0800 Subject: [PATCH 0212/2461] [SPARK-23218][SQL] simplify ColumnVector.getArray ## What changes were proposed in this pull request? `ColumnVector` is very flexible about how to implement array type. As a result `ColumnVector` has 3 abstract methods for array type: `arrayData`, `getArrayOffset`, `getArrayLength`. For example, in `WritableColumnVector` we use the first child vector as the array data vector, and store offsets and lengths in 2 arrays in the parent vector. `ArrowColumnVector` has a different implementation. This PR simplifies `ColumnVector` by using only one abstract method for array type: `getArray`. ## How was this patch tested? existing tests. rerun `ColumnarBatchBenchmark`, there is no performance regression. Author: Wenchen Fan Closes #20395 from cloud-fan/vector. --- .../datasources/orc/OrcColumnVector.java | 13 +-- .../vectorized/WritableColumnVector.java | 13 ++- .../sql/vectorized/ArrowColumnVector.java | 48 ++++------ .../spark/sql/vectorized/ColumnVector.java | 88 ++++++++++--------- .../spark/sql/vectorized/ColumnarArray.java | 2 + .../spark/sql/vectorized/ColumnarBatch.java | 2 + .../spark/sql/vectorized/ColumnarRow.java | 2 + .../vectorized/ColumnarBatchBenchmark.scala | 14 ++- 8 files changed, 87 insertions(+), 95 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index aaf2a380034a9..5078bc7922ee2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -24,6 +24,7 @@ import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.sql.vectorized.ColumnarArray; import org.apache.spark.unsafe.types.UTF8String; /** @@ -145,16 +146,6 @@ public double getDouble(int rowId) { return doubleData.vector[getRowIndex(rowId)]; } - @Override - public int getArrayLength(int rowId) { - throw new UnsupportedOperationException(); - } - - @Override - public int getArrayOffset(int rowId) { - throw new UnsupportedOperationException(); - } - @Override public Decimal getDecimal(int rowId, int precision, int scale) { BigDecimal data = decimalData.vector[getRowIndex(rowId)].getHiveDecimal().bigDecimalValue(); @@ -177,7 +168,7 @@ public byte[] getBinary(int rowId) { } @Override - public org.apache.spark.sql.vectorized.ColumnVector arrayData() { + public ColumnarArray getArray(int rowId) { throw new UnsupportedOperationException(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index ca4f00985c2a3..a8ec8ef2aadf8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -24,6 +24,7 @@ import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.*; import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.UTF8String; @@ -602,7 +603,17 @@ public final int appendStruct(boolean isNull) { // `WritableColumnVector` puts the data of array in the first child column vector, and puts the // array offsets and lengths in the current column vector. @Override - public WritableColumnVector arrayData() { return childColumns[0]; } + public final ColumnarArray getArray(int rowId) { + return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId)); + } + + public WritableColumnVector arrayData() { + return childColumns[0]; + } + + public abstract int getArrayLength(int rowId); + + public abstract int getArrayOffset(int rowId); @Override public WritableColumnVector getChild(int ordinal) { return childColumns[ordinal]; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index ca7a4751450d4..9803c3dec6de2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -17,17 +17,21 @@ package org.apache.spark.sql.vectorized; +import io.netty.buffer.ArrowBuf; import org.apache.arrow.vector.*; import org.apache.arrow.vector.complex.*; import org.apache.arrow.vector.holders.NullableVarCharHolder; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.arrow.ArrowUtils; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.UTF8String; /** - * A column vector backed by Apache Arrow. + * A column vector backed by Apache Arrow. Currently time interval type and map type are not + * supported. */ +@InterfaceStability.Evolving public final class ArrowColumnVector extends ColumnVector { private final ArrowVectorAccessor accessor; @@ -90,16 +94,6 @@ public double getDouble(int rowId) { return accessor.getDouble(rowId); } - @Override - public int getArrayLength(int rowId) { - return accessor.getArrayLength(rowId); - } - - @Override - public int getArrayOffset(int rowId) { - return accessor.getArrayOffset(rowId); - } - @Override public Decimal getDecimal(int rowId, int precision, int scale) { return accessor.getDecimal(rowId, precision, scale); @@ -116,7 +110,9 @@ public byte[] getBinary(int rowId) { } @Override - public ArrowColumnVector arrayData() { return childColumns[0]; } + public ColumnarArray getArray(int rowId) { + return accessor.getArray(rowId); + } @Override public ArrowColumnVector getChild(int ordinal) { return childColumns[ordinal]; } @@ -151,9 +147,6 @@ public ArrowColumnVector(ValueVector vector) { } else if (vector instanceof ListVector) { ListVector listVector = (ListVector) vector; accessor = new ArrayAccessor(listVector); - - childColumns = new ArrowColumnVector[1]; - childColumns[0] = new ArrowColumnVector(listVector.getDataVector()); } else if (vector instanceof NullableMapVector) { NullableMapVector mapVector = (NullableMapVector) vector; accessor = new StructAccessor(mapVector); @@ -180,10 +173,6 @@ boolean isNullAt(int rowId) { return vector.isNull(rowId); } - final int getValueCount() { - return vector.getValueCount(); - } - final int getNullCount() { return vector.getNullCount(); } @@ -232,11 +221,7 @@ byte[] getBinary(int rowId) { throw new UnsupportedOperationException(); } - int getArrayLength(int rowId) { - throw new UnsupportedOperationException(); - } - - int getArrayOffset(int rowId) { + ColumnarArray getArray(int rowId) { throw new UnsupportedOperationException(); } } @@ -433,10 +418,12 @@ final long getLong(int rowId) { private static class ArrayAccessor extends ArrowVectorAccessor { private final ListVector accessor; + private final ArrowColumnVector arrayData; ArrayAccessor(ListVector vector) { super(vector); this.accessor = vector; + this.arrayData = new ArrowColumnVector(vector.getDataVector()); } @Override @@ -450,13 +437,12 @@ final boolean isNullAt(int rowId) { } @Override - final int getArrayLength(int rowId) { - return accessor.getInnerValueCountAt(rowId); - } - - @Override - final int getArrayOffset(int rowId) { - return accessor.getOffsetBuffer().getInt(rowId * accessor.OFFSET_WIDTH); + final ColumnarArray getArray(int rowId) { + ArrowBuf offsets = accessor.getOffsetBuffer(); + int index = rowId * accessor.OFFSET_WIDTH; + int start = offsets.getInt(index); + int end = offsets.getInt(index + accessor.OFFSET_WIDTH); + return new ColumnarArray(arrayData, start, end - start); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index f9936214035b6..4b955ceddd0f2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.vectorized; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; @@ -29,11 +30,14 @@ * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values * in this ColumnVector. * + * Spark only calls specific `get` method according to the data type of this {@link ColumnVector}, + * e.g. if it's int type, Spark is guaranteed to only call {@link #getInt(int)} or + * {@link #getInts(int, int)}. + * * ColumnVector supports all the data types including nested types. To handle nested types, - * ColumnVector can have children and is a tree structure. For struct type, it stores the actual - * data of each field in the corresponding child ColumnVector, and only stores null information in - * the parent ColumnVector. For array type, it stores the actual array elements in the child - * ColumnVector, and stores null information, array offsets and lengths in the parent ColumnVector. + * ColumnVector can have children and is a tree structure. Please refer to {@link #getStruct(int)}, + * {@link #getArray(int)} and {@link #getMap(int)} for the details about how to implement nested + * types. * * ColumnVector is expected to be reused during the entire data loading process, to avoid allocating * memory again and again. @@ -43,6 +47,7 @@ * format. Since it is expected to reuse the ColumnVector instance while loading data, the storage * footprint is negligible. */ +@InterfaceStability.Evolving public abstract class ColumnVector implements AutoCloseable { /** @@ -70,12 +75,12 @@ public abstract class ColumnVector implements AutoCloseable { public abstract boolean isNullAt(int rowId); /** - * Returns the value for rowId. + * Returns the boolean type value for rowId. */ public abstract boolean getBoolean(int rowId); /** - * Gets values from [rowId, rowId + count) + * Gets boolean type values from [rowId, rowId + count) */ public boolean[] getBooleans(int rowId, int count) { boolean[] res = new boolean[count]; @@ -86,12 +91,12 @@ public boolean[] getBooleans(int rowId, int count) { } /** - * Returns the value for rowId. + * Returns the byte type value for rowId. */ public abstract byte getByte(int rowId); /** - * Gets values from [rowId, rowId + count) + * Gets byte type values from [rowId, rowId + count) */ public byte[] getBytes(int rowId, int count) { byte[] res = new byte[count]; @@ -102,12 +107,12 @@ public byte[] getBytes(int rowId, int count) { } /** - * Returns the value for rowId. + * Returns the short type value for rowId. */ public abstract short getShort(int rowId); /** - * Gets values from [rowId, rowId + count) + * Gets short type values from [rowId, rowId + count) */ public short[] getShorts(int rowId, int count) { short[] res = new short[count]; @@ -118,12 +123,12 @@ public short[] getShorts(int rowId, int count) { } /** - * Returns the value for rowId. + * Returns the int type value for rowId. */ public abstract int getInt(int rowId); /** - * Gets values from [rowId, rowId + count) + * Gets int type values from [rowId, rowId + count) */ public int[] getInts(int rowId, int count) { int[] res = new int[count]; @@ -134,12 +139,12 @@ public int[] getInts(int rowId, int count) { } /** - * Returns the value for rowId. + * Returns the long type value for rowId. */ public abstract long getLong(int rowId); /** - * Gets values from [rowId, rowId + count) + * Gets long type values from [rowId, rowId + count) */ public long[] getLongs(int rowId, int count) { long[] res = new long[count]; @@ -150,12 +155,12 @@ public long[] getLongs(int rowId, int count) { } /** - * Returns the value for rowId. + * Returns the float type value for rowId. */ public abstract float getFloat(int rowId); /** - * Gets values from [rowId, rowId + count) + * Gets float type values from [rowId, rowId + count) */ public float[] getFloats(int rowId, int count) { float[] res = new float[count]; @@ -166,12 +171,12 @@ public float[] getFloats(int rowId, int count) { } /** - * Returns the value for rowId. + * Returns the double type value for rowId. */ public abstract double getDouble(int rowId); /** - * Gets values from [rowId, rowId + count) + * Gets double type values from [rowId, rowId + count) */ public double[] getDoubles(int rowId, int count) { double[] res = new double[count]; @@ -182,57 +187,54 @@ public double[] getDoubles(int rowId, int count) { } /** - * Returns the length of the array for rowId. - */ - public abstract int getArrayLength(int rowId); - - /** - * Returns the offset of the array for rowId. - */ - public abstract int getArrayOffset(int rowId); - - /** - * Returns the struct for rowId. + * Returns the struct type value for rowId. + * + * To support struct type, implementations must implement {@link #getChild(int)} and make this + * vector a tree structure. The number of child vectors must be same as the number of fields of + * the struct type, and each child vector is responsible to store the data for its corresponding + * struct field. */ public final ColumnarRow getStruct(int rowId) { return new ColumnarRow(this, rowId); } /** - * Returns the array for rowId. + * Returns the array type value for rowId. + * + * To support array type, implementations must construct an {@link ColumnarArray} and return it in + * this method. {@link ColumnarArray} requires a {@link ColumnVector} that stores the data of all + * the elements of all the arrays in this vector, and an offset and length which points to a range + * in that {@link ColumnVector}, and the range represents the array for rowId. Implementations + * are free to decide where to put the data vector and offsets and lengths. For example, we can + * use the first child vector as the data vector, and store offsets and lengths in 2 int arrays in + * this vector. */ - public final ColumnarArray getArray(int rowId) { - return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId)); - } + public abstract ColumnarArray getArray(int rowId); /** - * Returns the map for rowId. + * Returns the map type value for rowId. */ public MapData getMap(int ordinal) { throw new UnsupportedOperationException(); } /** - * Returns the decimal for rowId. + * Returns the decimal type value for rowId. */ public abstract Decimal getDecimal(int rowId, int precision, int scale); /** - * Returns the UTF8String for rowId. Note that the returned UTF8String may point to the data of - * this column vector, please copy it if you want to keep it after this column vector is freed. + * Returns the string type value for rowId. Note that the returned UTF8String may point to the + * data of this column vector, please copy it if you want to keep it after this column vector is + * freed. */ public abstract UTF8String getUTF8String(int rowId); /** - * Returns the byte array for rowId. + * Returns the binary type value for rowId. */ public abstract byte[] getBinary(int rowId); - /** - * Returns the data for the underlying array. - */ - public abstract ColumnVector arrayData(); - /** * Returns the ordinal's child column vector. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 522c39580389f..0d2c3ec8648d3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.vectorized; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; @@ -25,6 +26,7 @@ /** * Array abstraction in {@link ColumnVector}. */ +@InterfaceStability.Evolving public final class ColumnarArray extends ArrayData { // The data for this array. This array contains elements from // data[offset] to data[offset + length). diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java index 4dc826cf60c15..d206c1df42abb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java @@ -18,6 +18,7 @@ import java.util.*; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.MutableColumnarRow; @@ -26,6 +27,7 @@ * batch so that Spark can access the data row by row. Instance of it is meant to be reused during * the entire data loading process. */ +@InterfaceStability.Evolving public final class ColumnarBatch { private int numRows; private final ColumnVector[] columns; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 2e59085a82768..25db7e09d20d0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.vectorized; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.util.MapData; @@ -26,6 +27,7 @@ /** * Row abstraction in {@link ColumnVector}. */ +@InterfaceStability.Evolving public final class ColumnarRow extends InternalRow { // The data for this row. // E.g. the value of 3rd int field is `data.getChild(3).getInt(rowId)`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index ad74fb99b0c73..1f31aa45a1220 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.parquet +package org.apache.spark.sql.execution.vectorized import java.nio.ByteBuffer import java.nio.charset.StandardCharsets @@ -23,8 +23,6 @@ import scala.util.Random import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector -import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType} import org.apache.spark.unsafe.Platform import org.apache.spark.util.Benchmark @@ -434,7 +432,6 @@ object ColumnarBatchBenchmark { } def readArrays(onHeap: Boolean): Unit = { - System.gc() val vector = if (onHeap) onHeapVector else offHeapVector var sum = 0L @@ -448,7 +445,6 @@ object ColumnarBatchBenchmark { } def readArrayElements(onHeap: Boolean): Unit = { - System.gc() val vector = if (onHeap) onHeapVector else offHeapVector var sum = 0L @@ -479,10 +475,10 @@ object ColumnarBatchBenchmark { Array Vector Read: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - On Heap Read Size Only 416 / 423 393.5 2.5 1.0X - Off Heap Read Size Only 396 / 404 413.6 2.4 1.1X - On Heap Read Elements 2569 / 2590 63.8 15.7 0.2X - Off Heap Read Elements 3302 / 3333 49.6 20.2 0.1X + On Heap Read Size Only 426 / 437 384.9 2.6 1.0X + Off Heap Read Size Only 406 / 421 404.0 2.5 1.0X + On Heap Read Elements 2636 / 2642 62.2 16.1 0.2X + Off Heap Read Elements 3770 / 3774 43.5 23.0 0.1X */ benchmark.run } From a8a3e9b7cf7b9346c43cfbbf7b26fd2fd28dd521 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Fri, 26 Jan 2018 23:48:02 +0200 Subject: [PATCH 0213/2461] Revert "[SPARK-22797][PYSPARK] Bucketizer support multi-column" This reverts commit c22eaa94e85aaac649566495dcf763a5de3c8d06. --- python/pyspark/ml/feature.py | 105 +++++++--------------------- python/pyspark/ml/param/__init__.py | 10 --- python/pyspark/ml/tests.py | 9 --- 3 files changed, 25 insertions(+), 99 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index fdc7787140490..da85ba761a145 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -317,33 +317,26 @@ class BucketedRandomProjectionLSHModel(LSHModel, JavaMLReadable, JavaMLWritable) @inherit_doc -class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols, - HasHandleInvalid, JavaMLReadable, JavaMLWritable): - """ - Maps a column of continuous features to a column of feature buckets. Since 2.3.0, - :py:class:`Bucketizer` can map multiple columns at once by setting the :py:attr:`inputCols` - parameter. Note that when both the :py:attr:`inputCol` and :py:attr:`inputCols` parameters - are set, an Exception will be thrown. The :py:attr:`splits` parameter is only used for single - column usage, and :py:attr:`splitsArray` is for multiple columns. - - >>> values = [(0.1, 0.0), (0.4, 1.0), (1.2, 1.3), (1.5, float("nan")), - ... (float("nan"), 1.0), (float("nan"), 0.0)] - >>> df = spark.createDataFrame(values, ["values1", "values2"]) +class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid, + JavaMLReadable, JavaMLWritable): + """ + Maps a column of continuous features to a column of feature buckets. + + >>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)] + >>> df = spark.createDataFrame(values, ["values"]) >>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")], - ... inputCol="values1", outputCol="buckets") - >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df.select("values1")) - >>> bucketed.show(truncate=False) - +-------+-------+ - |values1|buckets| - +-------+-------+ - |0.1 |0.0 | - |0.4 |0.0 | - |1.2 |1.0 | - |1.5 |2.0 | - |NaN |3.0 | - |NaN |3.0 | - +-------+-------+ - ... + ... inputCol="values", outputCol="buckets") + >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df).collect() + >>> len(bucketed) + 6 + >>> bucketed[0].buckets + 0.0 + >>> bucketed[1].buckets + 0.0 + >>> bucketed[2].buckets + 1.0 + >>> bucketed[3].buckets + 2.0 >>> bucketizer.setParams(outputCol="b").transform(df).head().b 0.0 >>> bucketizerPath = temp_path + "/bucketizer" @@ -354,22 +347,6 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOu >>> bucketed = bucketizer.setHandleInvalid("skip").transform(df).collect() >>> len(bucketed) 4 - >>> bucketizer2 = Bucketizer(splitsArray= - ... [[-float("inf"), 0.5, 1.4, float("inf")], [-float("inf"), 0.5, float("inf")]], - ... inputCols=["values1", "values2"], outputCols=["buckets1", "buckets2"]) - >>> bucketed2 = bucketizer2.setHandleInvalid("keep").transform(df) - >>> bucketed2.show(truncate=False) - +-------+-------+--------+--------+ - |values1|values2|buckets1|buckets2| - +-------+-------+--------+--------+ - |0.1 |0.0 |0.0 |0.0 | - |0.4 |1.0 |0.0 |1.0 | - |1.2 |1.3 |1.0 |1.0 | - |1.5 |NaN |2.0 |2.0 | - |NaN |1.0 |3.0 |1.0 | - |NaN |0.0 |3.0 |0.0 | - +-------+-------+--------+--------+ - ... .. versionadded:: 1.4.0 """ @@ -386,30 +363,14 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOu handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " + "Options are 'skip' (filter out rows with invalid values), " + - "'error' (throw an error), or 'keep' (keep invalid values in a " + - "special additional bucket). Note that in the multiple column " + - "case, the invalid handling is applied to all columns. That said " + - "for 'error' it will throw an error if any invalids are found in " + - "any column, for 'skip' it will skip rows with any invalids in " + - "any columns, etc.", + "'error' (throw an error), or 'keep' (keep invalid values in a special " + + "additional bucket).", typeConverter=TypeConverters.toString) - splitsArray = Param(Params._dummy(), "splitsArray", "The array of split points for mapping " + - "continuous features into buckets for multiple columns. For each input " + - "column, with n+1 splits, there are n buckets. A bucket defined by " + - "splits x,y holds values in the range [x,y) except the last bucket, " + - "which also includes y. The splits should be of length >= 3 and " + - "strictly increasing. Values at -inf, inf must be explicitly provided " + - "to cover all Double values; otherwise, values outside the splits " + - "specified will be treated as errors.", - typeConverter=TypeConverters.toListListFloat) - @keyword_only - def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", - splitsArray=None, inputCols=None, outputCols=None): + def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"): """ - __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \ - splitsArray=None, inputCols=None, outputCols=None) + __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error") """ super(Bucketizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid) @@ -419,11 +380,9 @@ def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="er @keyword_only @since("1.4.0") - def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", - splitsArray=None, inputCols=None, outputCols=None): + def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"): """ - setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \ - splitsArray=None, inputCols=None, outputCols=None) + setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error") Sets params for this Bucketizer. """ kwargs = self._input_kwargs @@ -443,20 +402,6 @@ def getSplits(self): """ return self.getOrDefault(self.splits) - @since("2.3.0") - def setSplitsArray(self, value): - """ - Sets the value of :py:attr:`splitsArray`. - """ - return self._set(splitsArray=value) - - @since("2.3.0") - def getSplitsArray(self): - """ - Gets the array of split points or its default value. - """ - return self.getOrDefault(self.splitsArray) - @inherit_doc class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 5b6b70292f099..043c25cf9feb4 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -134,16 +134,6 @@ def toListFloat(value): return [float(v) for v in value] raise TypeError("Could not convert %s to list of floats" % value) - @staticmethod - def toListListFloat(value): - """ - Convert a value to list of list of floats, if possible. - """ - if TypeConverters._can_convert_to_list(value): - value = TypeConverters.toList(value) - return [TypeConverters.toListFloat(v) for v in value] - raise TypeError("Could not convert %s to list of list of floats" % value) - @staticmethod def toListInt(value): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index b8bddbd06f165..1af2b91da900d 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -238,15 +238,6 @@ def test_bool(self): self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept=1)) self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept="false")) - def test_list_list_float(self): - b = Bucketizer(splitsArray=[[-0.1, 0.5, 3], [-5, 1.5]]) - self.assertEqual(b.getSplitsArray(), [[-0.1, 0.5, 3.0], [-5.0, 1.5]]) - self.assertTrue(all([type(v) == list for v in b.getSplitsArray()])) - self.assertTrue(all([type(v) == float for v in b.getSplitsArray()[0]])) - self.assertTrue(all([type(v) == float for v in b.getSplitsArray()[1]])) - self.assertRaises(TypeError, lambda: Bucketizer(splitsArray=["a", 1.0])) - self.assertRaises(TypeError, lambda: Bucketizer(splitsArray=[[-5, 1.5], ["a", 1.0]])) - class PipelineTests(PySparkTestCase): From 94c67a76ec1fda908a671a47a2a1fa63b3ab1b06 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 26 Jan 2018 15:01:03 -0800 Subject: [PATCH 0214/2461] [SPARK-23207][SQL] Shuffle+Repartition on a DataFrame could lead to incorrect answers ## What changes were proposed in this pull request? Currently shuffle repartition uses RoundRobinPartitioning, the generated result is nondeterministic since the sequence of input rows are not determined. The bug can be triggered when there is a repartition call following a shuffle (which would lead to non-deterministic row ordering), as the pattern shows below: upstream stage -> repartition stage -> result stage (-> indicate a shuffle) When one of the executors process goes down, some tasks on the repartition stage will be retried and generate inconsistent ordering, and some tasks of the result stage will be retried generating different data. The following code returns 931532, instead of 1000000: ``` import scala.sys.process._ import org.apache.spark.TaskContext val res = spark.range(0, 1000 * 1000, 1).repartition(200).map { x => x }.repartition(200).map { x => if (TaskContext.get.attemptNumber == 0 && TaskContext.get.partitionId < 2) { throw new Exception("pkill -f java".!!) } x } res.distinct().count() ``` In this PR, we propose a most straight-forward way to fix this problem by performing a local sort before partitioning, after we make the input row ordering deterministic, the function from rows to partitions is fully deterministic too. The downside of the approach is that with extra local sort inserted, the performance of repartition() will go down, so we add a new config named `spark.sql.execution.sortBeforeRepartition` to control whether this patch is applied. The patch is default enabled to be safe-by-default, but user may choose to manually turn it off to avoid performance regression. This patch also changes the output rows ordering of repartition(), that leads to a bunch of test cases failure because they are comparing the results directly. ## How was this patch tested? Add unit test in ExchangeSuite. With this patch(and `spark.sql.execution.sortBeforeRepartition` set to true), the following query returns 1000000: ``` import scala.sys.process._ import org.apache.spark.TaskContext spark.conf.set("spark.sql.execution.sortBeforeRepartition", "true") val res = spark.range(0, 1000 * 1000, 1).repartition(200).map { x => x }.repartition(200).map { x => if (TaskContext.get.attemptNumber == 0 && TaskContext.get.partitionId < 2) { throw new Exception("pkill -f java".!!) } x } res.distinct().count() res7: Long = 1000000 ``` Author: Xingbo Jiang Closes #20393 from jiangxb1987/shuffle-repartition. --- .../unsafe/sort/RecordComparator.java | 4 +- .../unsafe/sort/UnsafeInMemorySorter.java | 7 +- .../unsafe/sort/UnsafeSorterSpillMerger.java | 4 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 2 + .../sort/UnsafeExternalSorterSuite.java | 4 +- .../sort/UnsafeInMemorySorterSuite.java | 8 ++- .../spark/ml/feature/Word2VecSuite.scala | 3 +- .../sql/execution/RecordBinaryComparator.java | 70 +++++++++++++++++++ .../execution/UnsafeExternalRowSorter.java | 44 ++++++++++-- .../apache/spark/sql/internal/SQLConf.scala | 14 ++++ .../sql/execution/UnsafeKVExternalSorter.java | 8 ++- .../apache/spark/sql/execution/SortExec.scala | 2 +- .../exchange/ShuffleExchangeExec.scala | 52 +++++++++++++- .../spark/sql/execution/ExchangeSuite.scala | 26 ++++++- .../datasources/parquet/ParquetIOSuite.scala | 6 +- .../datasources/text/WholeTextFileSuite.scala | 2 +- .../streaming/ForeachSinkSuite.scala | 6 +- 17 files changed, 233 insertions(+), 29 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java index 09e4258792204..02b5de8e128c9 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java @@ -32,6 +32,8 @@ public abstract class RecordComparator { public abstract int compare( Object leftBaseObject, long leftBaseOffset, + int leftBaseLength, Object rightBaseObject, - long rightBaseOffset); + long rightBaseOffset, + int rightBaseLength); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 951d076420ee6..b3c27d83da172 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -62,12 +62,13 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { int uaoSize = UnsafeAlignedOffset.getUaoSize(); if (prefixComparisonResult == 0) { final Object baseObject1 = memoryManager.getPage(r1.recordPointer); - // skip length final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + uaoSize; + final int baseLength1 = UnsafeAlignedOffset.getSize(baseObject1, baseOffset1 - uaoSize); final Object baseObject2 = memoryManager.getPage(r2.recordPointer); - // skip length final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + uaoSize; - return recordComparator.compare(baseObject1, baseOffset1, baseObject2, baseOffset2); + final int baseLength2 = UnsafeAlignedOffset.getSize(baseObject2, baseOffset2 - uaoSize); + return recordComparator.compare(baseObject1, baseOffset1, baseLength1, baseObject2, + baseOffset2, baseLength2); } else { return prefixComparisonResult; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java index cf4dfde86ca91..ff0dcc259a4ad 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -35,8 +35,8 @@ final class UnsafeSorterSpillMerger { prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix()); if (prefixComparisonResult == 0) { return recordComparator.compare( - left.getBaseObject(), left.getBaseOffset(), - right.getBaseObject(), right.getBaseOffset()); + left.getBaseObject(), left.getBaseOffset(), left.getRecordLength(), + right.getBaseObject(), right.getBaseOffset(), right.getRecordLength()); } else { return prefixComparisonResult; } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 7859781e98223..0574abdca32ac 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -414,6 +414,8 @@ abstract class RDD[T: ClassTag]( * * If you are decreasing the number of partitions in this RDD, consider using `coalesce`, * which can avoid performing a shuffle. + * + * TODO Fix the Shuffle+Repartition data loss issue described in SPARK-23207. */ def repartition(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] = withScope { coalesce(numPartitions, shuffle = true) diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index af4975c888d65..411cd5cb57331 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -72,8 +72,10 @@ public class UnsafeExternalSorterSuite { public int compare( Object leftBaseObject, long leftBaseOffset, + int leftBaseLength, Object rightBaseObject, - long rightBaseOffset) { + long rightBaseOffset, + int rightBaseLength) { return 0; } }; diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 594f07dd780f9..c145532328514 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -98,8 +98,10 @@ public void testSortingOnlyByIntegerPrefix() throws Exception { public int compare( Object leftBaseObject, long leftBaseOffset, + int leftBaseLength, Object rightBaseObject, - long rightBaseOffset) { + long rightBaseOffset, + int rightBaseLength) { return 0; } }; @@ -164,8 +166,10 @@ public void freeAfterOOM() { public int compare( Object leftBaseObject, long leftBaseOffset, + int leftBaseLength, Object rightBaseObject, - long rightBaseOffset) { + long rightBaseOffset, + int rightBaseLength) { return 0; } }; diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 6183606a7b2ac..10682ba176aca 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -222,7 +222,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val oldModel = new OldWord2VecModel(word2VecMap) val instance = new Word2VecModel("myWord2VecModel", oldModel) val newInstance = testDefaultReadWrite(instance) - assert(newInstance.getVectors.collect() === instance.getVectors.collect()) + assert(newInstance.getVectors.collect().sortBy(_.getString(0)) === + instance.getVectors.collect().sortBy(_.getString(0))) } test("Word2Vec works with input that is non-nullable (NGram)") { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java new file mode 100644 index 0000000000000..bb77b5bf6de2a --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.util.collection.unsafe.sort.RecordComparator; + +public final class RecordBinaryComparator extends RecordComparator { + + // TODO(jiangxb) Add test suite for this. + @Override + public int compare( + Object leftObj, long leftOff, int leftLen, Object rightObj, long rightOff, int rightLen) { + int i = 0; + int res = 0; + + // If the arrays have different length, the longer one is larger. + if (leftLen != rightLen) { + return leftLen - rightLen; + } + + // The following logic uses `leftLen` as the length for both `leftObj` and `rightObj`, since + // we have guaranteed `leftLen` == `rightLen`. + + // check if stars align and we can get both offsets to be aligned + if ((leftOff % 8) == (rightOff % 8)) { + while ((leftOff + i) % 8 != 0 && i < leftLen) { + res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - + (Platform.getByte(rightObj, rightOff + i) & 0xff); + if (res != 0) return res; + i += 1; + } + } + // for architectures that support unaligned accesses, chew it up 8 bytes at a time + if (Platform.unaligned() || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 8 == 0))) { + while (i <= leftLen - 8) { + res = (int) ((Platform.getLong(leftObj, leftOff + i) - + Platform.getLong(rightObj, rightOff + i)) % Integer.MAX_VALUE); + if (res != 0) return res; + i += 8; + } + } + // this will finish off the unaligned comparisons, or do the entire aligned comparison + // whichever is needed. + while (i < leftLen) { + res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - + (Platform.getByte(rightObj, rightOff + i) & 0xff); + if (res != 0) return res; + i += 1; + } + + // The two arrays are equal. + return 0; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 6b002f0d3f8e8..78647b56d621f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -18,7 +18,9 @@ package org.apache.spark.sql.execution; import java.io.IOException; +import java.util.function.Supplier; +import org.apache.spark.sql.catalyst.util.TypeUtils; import scala.collection.AbstractIterator; import scala.collection.Iterator; import scala.math.Ordering; @@ -56,26 +58,50 @@ public abstract static class PrefixComputer { public static class Prefix { /** Key prefix value, or the null prefix value if isNull = true. **/ - long value; + public long value; /** Whether the key is null. */ - boolean isNull; + public boolean isNull; } /** * Computes prefix for the given row. For efficiency, the returned object may be reused in * further calls to a given PrefixComputer. */ - abstract Prefix computePrefix(InternalRow row); + public abstract Prefix computePrefix(InternalRow row); } - public UnsafeExternalRowSorter( + public static UnsafeExternalRowSorter createWithRecordComparator( + StructType schema, + Supplier recordComparatorSupplier, + PrefixComparator prefixComparator, + PrefixComputer prefixComputer, + long pageSizeBytes, + boolean canUseRadixSort) throws IOException { + return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator, + prefixComputer, pageSizeBytes, canUseRadixSort); + } + + public static UnsafeExternalRowSorter create( StructType schema, Ordering ordering, PrefixComparator prefixComparator, PrefixComputer prefixComputer, long pageSizeBytes, boolean canUseRadixSort) throws IOException { + Supplier recordComparatorSupplier = + () -> new RowComparator(ordering, schema.length()); + return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator, + prefixComputer, pageSizeBytes, canUseRadixSort); + } + + private UnsafeExternalRowSorter( + StructType schema, + Supplier recordComparatorSupplier, + PrefixComparator prefixComparator, + PrefixComputer prefixComputer, + long pageSizeBytes, + boolean canUseRadixSort) throws IOException { this.schema = schema; this.prefixComputer = prefixComputer; final SparkEnv sparkEnv = SparkEnv.get(); @@ -85,7 +111,7 @@ public UnsafeExternalRowSorter( sparkEnv.blockManager(), sparkEnv.serializerManager(), taskContext, - () -> new RowComparator(ordering, schema.length()), + recordComparatorSupplier, prefixComparator, sparkEnv.conf().getInt("spark.shuffle.sort.initialBufferSize", DEFAULT_INITIAL_SORT_BUFFER_SIZE), @@ -206,7 +232,13 @@ private static final class RowComparator extends RecordComparator { } @Override - public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { + public int compare( + Object baseObj1, + long baseOff1, + int baseLen1, + Object baseObj2, + long baseOff2, + int baseLen2) { // Note that since ordering doesn't need the total length of the record, we just pass 0 // into the row. row1.pointTo(baseObj1, baseOff1, 0); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b0d18b6dced76..76b9d6f6f33bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1145,6 +1145,18 @@ object SQLConf { .checkValues(PartitionOverwriteMode.values.map(_.toString)) .createWithDefault(PartitionOverwriteMode.STATIC.toString) + val SORT_BEFORE_REPARTITION = + buildConf("spark.sql.execution.sortBeforeRepartition") + .internal() + .doc("When perform a repartition following a shuffle, the output row ordering would be " + + "nondeterministic. If some downstream stages fail and some tasks of the repartition " + + "stage retry, these tasks may generate different data, and that can lead to correctness " + + "issues. Turn on this config to insert a local sort before actually doing repartition " + + "to generate consistent repartition results. The performance of repartition() may go " + + "down since we insert extra local sort before it.") + .booleanConf + .createWithDefault(true) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1300,6 +1312,8 @@ class SQLConf extends Serializable with Logging { def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader) + def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index eb2fe82007af3..b0b5383a081a0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -241,7 +241,13 @@ private static final class KVComparator extends RecordComparator { } @Override - public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { + public int compare( + Object baseObj1, + long baseOff1, + int baseLen1, + Object baseObj2, + long baseOff2, + int baseLen2) { // Note that since ordering doesn't need the total length of the record, we just pass 0 // into the row. row1.pointTo(baseObj1, baseOff1 + 4, 0); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index ef1bb1c2a4468..ac1c34d41c4f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -84,7 +84,7 @@ case class SortExec( } val pageSize = SparkEnv.get.memoryManager.pageSizeBytes - val sorter = new UnsafeExternalRowSorter( + val sorter = UnsafeExternalRowSorter.create( schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort) if (testSpillFrequency > 0) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 5a1e217082bc2..76c1fa65f924b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.exchange import java.util.Random +import java.util.function.Supplier import org.apache.spark._ import org.apache.spark.rdd.RDD @@ -25,13 +26,15 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType import org.apache.spark.util.MutablePair +import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator} /** * Performs a shuffle that will result in the desired `newPartitioning`. @@ -247,14 +250,57 @@ object ShuffleExchangeExec { case RangePartitioning(_, _) | SinglePartition => identity case _ => sys.error(s"Exchange not implemented for $newPartitioning") } + val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = { - if (needToCopyObjectsBeforeShuffle(part, serializer)) { + // [SPARK-23207] Have to make sure the generated RoundRobinPartitioning is deterministic, + // otherwise a retry task may output different rows and thus lead to data loss. + // + // Currently we following the most straight-forward way that perform a local sort before + // partitioning. + val newRdd = if (SQLConf.get.sortBeforeRepartition && + newPartitioning.isInstanceOf[RoundRobinPartitioning]) { rdd.mapPartitionsInternal { iter => + val recordComparatorSupplier = new Supplier[RecordComparator] { + override def get: RecordComparator = new RecordBinaryComparator() + } + // The comparator for comparing row hashcode, which should always be Integer. + val prefixComparator = PrefixComparators.LONG + val canUseRadixSort = SparkEnv.get.conf.get(SQLConf.RADIX_SORT_ENABLED) + // The prefix computer generates row hashcode as the prefix, so we may decrease the + // probability that the prefixes are equal when input rows choose column values from a + // limited range. + val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { + private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix + override def computePrefix(row: InternalRow): + UnsafeExternalRowSorter.PrefixComputer.Prefix = { + // The hashcode generated from the binary form of a [[UnsafeRow]] should not be null. + result.isNull = false + result.value = row.hashCode() + result + } + } + val pageSize = SparkEnv.get.memoryManager.pageSizeBytes + + val sorter = UnsafeExternalRowSorter.createWithRecordComparator( + StructType.fromAttributes(outputAttributes), + recordComparatorSupplier, + prefixComparator, + prefixComputer, + pageSize, + canUseRadixSort) + sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + } + } else { + rdd + } + + if (needToCopyObjectsBeforeShuffle(part, serializer)) { + newRdd.mapPartitionsInternal { iter => val getPartitionKey = getPartitionKeyExtractor() iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } } } else { - rdd.mapPartitionsInternal { iter => + newRdd.mapPartitionsInternal { iter => val getPartitionKey = getPartitionKeyExtractor() val mutablePair = new MutablePair[Int, InternalRow]() iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index aac8d56ba6201..697d7e6520713 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.Row +import scala.util.Random + +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext class ExchangeSuite extends SparkPlanTest with SharedSQLContext { @@ -101,4 +104,25 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(exchange4.sameResult(exchange5)) assert(exchange5 sameResult exchange4) } + + test("SPARK-23207: Make repartition() generate consistent output") { + def assertConsistency(ds: Dataset[java.lang.Long]): Unit = { + ds.persist() + + val exchange = ds.mapPartitions { iter => + Random.shuffle(iter) + }.repartition(111) + val exchange2 = ds.repartition(111) + + assert(exchange.rdd.collectPartitions() === exchange2.rdd.collectPartitions()) + } + + withSQLConf(SQLConf.SORT_BEFORE_REPARTITION.key -> "true") { + // repartition() should generate consistent output. + assertConsistency(spark.range(10000)) + + // case when input contains duplicated rows. + assertConsistency(spark.range(10000).map(i => Random.nextInt(1000).toLong)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 44a8b25c61dfb..f3ece5b15e26a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -662,7 +662,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val v = (row.getInt(0), row.getString(1)) result += v } - assert(data == result) + assert(data.toSet == result.toSet) } finally { reader.close() } @@ -678,7 +678,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val row = reader.getCurrentValue.asInstanceOf[InternalRow] result += row.getString(0) } - assert(data.map(_._2) == result) + assert(data.map(_._2).toSet == result.toSet) } finally { reader.close() } @@ -695,7 +695,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val v = (row.getString(0), row.getInt(1)) result += v } - assert(data.map { x => (x._2, x._1) } == result) + assert(data.map { x => (x._2, x._1) }.toSet == result.toSet) } finally { reader.close() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala index 8bd736bee69de..fff0f82f9bc2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala @@ -95,7 +95,7 @@ class WholeTextFileSuite extends QueryTest with SharedSQLContext { df1.write.option("compression", "gzip").mode("overwrite").text(path) // On reading through wholetext mode, one file will be read as a single row, i.e. not // delimited by "next line" character. - val expected = Row(Range(0, 1000).mkString("", "\n", "\n")) + val expected = Row(df1.collect().map(_.getString(0)).mkString("", "\n", "\n")) Seq(10, 100, 1000).foreach { bytes => withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> bytes.toString) { val df2 = spark.read.option("wholetext", "true").format("text").load(path) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala index 9137d650e906b..1248c670df45c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala @@ -52,13 +52,13 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf var expectedEventsForPartition0 = Seq( ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Process(value = 1), + ForeachSinkSuite.Process(value = 2), ForeachSinkSuite.Process(value = 3), ForeachSinkSuite.Close(None) ) var expectedEventsForPartition1 = Seq( ForeachSinkSuite.Open(partition = 1, version = 0), - ForeachSinkSuite.Process(value = 2), + ForeachSinkSuite.Process(value = 1), ForeachSinkSuite.Process(value = 4), ForeachSinkSuite.Close(None) ) @@ -162,7 +162,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf val allEvents = ForeachSinkSuite.allEvents() assert(allEvents.size === 1) assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0)) - assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1)) + assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 2)) // `close` should be called with the error val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close] From 073744985f439ca90afb9bd0bbc1332c53f7b4bb Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 26 Jan 2018 16:09:57 -0800 Subject: [PATCH 0215/2461] [SPARK-23242][SS][TESTS] Don't run tests in KafkaSourceSuiteBase twice ## What changes were proposed in this pull request? KafkaSourceSuiteBase should be abstract class, otherwise KafkaSourceSuiteBase will also run. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #20412 from zsxwing/SPARK-23242. --- .../scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 27dbb3f7a8f31..c4cb1bc4a2e18 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -546,7 +546,7 @@ class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { } } -class KafkaSourceSuiteBase extends KafkaSourceTest { +abstract class KafkaSourceSuiteBase extends KafkaSourceTest { import testImplicits._ From 5b5447c68ac79715e2256e487e1212861cdab1fc Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 26 Jan 2018 16:46:51 -0800 Subject: [PATCH 0216/2461] [SPARK-23214][SQL] cached data should not carry extra hint info ## What changes were proposed in this pull request? This is a regression introduced by https://github.com/apache/spark/pull/19864 When we lookup cache, we should not carry the hint info, as this cache entry might be added by a plan having hint info, while the input plan for this lookup may not have hint info, or have different hint info. ## How was this patch tested? a new test. Author: Wenchen Fan Closes #20394 from cloud-fan/cache. --- .../spark/sql/execution/CacheManager.scala | 17 +-- .../execution/columnar/InMemoryRelation.scala | 27 +++-- .../apache/spark/sql/CachedTableSuite.scala | 4 +- .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../execution/joins/BroadcastJoinSuite.scala | 103 +++++++++++------- 5 files changed, 94 insertions(+), 59 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 432eb59d6fe57..d68aeb275afda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -169,14 +169,17 @@ class CacheManager extends Logging { /** Replaces segments of the given logical plan with cached versions where possible. */ def useCachedData(plan: LogicalPlan): LogicalPlan = { val newPlan = plan transformDown { + // Do not lookup the cache by hint node. Hint node is special, we should ignore it when + // canonicalizing plans, so that plans which are same except hint can hit the same cache. + // However, we also want to keep the hint info after cache lookup. Here we skip the hint + // node, so that the returned caching plan won't replace the hint node and drop the hint info + // from the original plan. + case hint: ResolvedHint => hint + case currentFragment => - lookupCachedData(currentFragment).map { cached => - val cachedPlan = cached.cachedRepresentation.withOutput(currentFragment.output) - currentFragment match { - case hint: ResolvedHint => ResolvedHint(cachedPlan, hint.hints) - case _ => cachedPlan - } - }.getOrElse(currentFragment) + lookupCachedData(currentFragment) + .map(_.cachedRepresentation.withOutput(currentFragment.output)) + .getOrElse(currentFragment) } newPlan transformAllExpressions { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 51928d914841e..22e16913d4da9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.storage.StorageLevel import org.apache.spark.util.LongAccumulator @@ -62,8 +62,8 @@ case class InMemoryRelation( @transient child: SparkPlan, tableName: Option[String])( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, - val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics = null) + val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, + statsOfPlanToCache: Statistics) extends logical.LeafNode with MultiInstanceRelation { override protected def innerChildren: Seq[SparkPlan] = Seq(child) @@ -73,11 +73,16 @@ case class InMemoryRelation( @transient val partitionStatistics = new PartitionStatistics(output) override def computeStats(): Statistics = { - if (batchStats.value == 0L) { - // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache - statsOfPlanToCache + if (sizeInBytesStats.value == 0L) { + // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache. + // Note that we should drop the hint info here. We may cache a plan whose root node is a hint + // node. When we lookup the cache with a semantically same plan without hint info, the plan + // returned by cache lookup should not have hint info. If we lookup the cache with a + // semantically same plan with a different hint info, `CacheManager.useCachedData` will take + // care of it and retain the hint info in the lookup input plan. + statsOfPlanToCache.copy(hints = HintInfo()) } else { - Statistics(sizeInBytes = batchStats.value.longValue) + Statistics(sizeInBytes = sizeInBytesStats.value.longValue) } } @@ -122,7 +127,7 @@ case class InMemoryRelation( rowCount += 1 } - batchStats.add(totalSize) + sizeInBytesStats.add(totalSize) val stats = InternalRow.fromSeq( columnBuilders.flatMap(_.columnStats.collectedStatistics)) @@ -144,7 +149,7 @@ case class InMemoryRelation( def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { InMemoryRelation( newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, batchStats, statsOfPlanToCache) + _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache) } override def newInstance(): this.type = { @@ -156,12 +161,12 @@ case class InMemoryRelation( child, tableName)( _cachedColumnBuffers, - batchStats, + sizeInBytesStats, statsOfPlanToCache).asInstanceOf[this.type] } def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers override protected def otherCopyArgs: Seq[AnyRef] = - Seq(_cachedColumnBuffers, batchStats, statsOfPlanToCache) + Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 1e52445f28fc1..72fe0f42801f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -368,12 +368,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val toBeCleanedAccIds = new HashSet[Long] val accId1 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.batchStats.id + case i: InMemoryRelation => i.sizeInBytesStats.id }.head toBeCleanedAccIds += accId1 val accId2 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.batchStats.id + case i: InMemoryRelation => i.sizeInBytesStats.id }.head toBeCleanedAccIds += accId2 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 2280da927cf70..dc1766fb9a785 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -336,7 +336,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(cached, expectedAnswer) // Check that the right size was calculated. - assert(cached.batchStats.value === expectedAnswer.size * INT.defaultSize) + assert(cached.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize) } test("access primitive-type columns in CachedBatch without whole stage codegen") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 1704bc8376f0d..bcdee792f4c70 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,7 +22,8 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} -import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -70,8 +71,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { private def testBroadcastJoin[T: ClassTag]( joinType: String, forceBroadcast: Boolean = false): SparkPlan = { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") // Comparison at the end is for broadcast left semi join val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") @@ -109,30 +110,58 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } } - test("broadcast hint is retained after using the cached data") { + test("SPARK-23192: broadcast hint should be retained after using the cached data") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") - df2.cache() - val df3 = df1.join(broadcast(df2), Seq("key"), "inner") - val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { - case b: BroadcastHashJoinExec => b - }.size - assert(numBroadCastHashJoin === 1) + try { + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") + df2.cache() + val df3 = df1.join(broadcast(df2), Seq("key"), "inner") + val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { + case b: BroadcastHashJoinExec => b + }.size + assert(numBroadCastHashJoin === 1) + } finally { + spark.catalog.clearCache() + } + } + } + + test("SPARK-23214: cached data should not carry extra hint info") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + try { + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") + broadcast(df2).cache() + + val df3 = df1.join(df2, Seq("key"), "inner") + val numCachedPlan = df3.queryExecution.executedPlan.collect { + case i: InMemoryTableScanExec => i + }.size + // df2 should be cached. + assert(numCachedPlan === 1) + + val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { + case b: BroadcastHashJoinExec => b + }.size + // df2 should not be broadcasted. + assert(numBroadCastHashJoin === 0) + } finally { + spark.catalog.clearCache() + } } } test("broadcast hint isn't propagated after a join") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df3 = df1.join(broadcast(df2), Seq("key"), "inner").drop(df2("key")) - val df4 = spark.createDataFrame(Seq((1, "5"), (2, "5"))).toDF("key", "value") + val df4 = Seq((1, "5"), (2, "5")).toDF("key", "value") val df5 = df4.join(df3, Seq("key"), "inner") - val plan = - EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan) + val plan = EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan) assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1) assert(plan.collect { case p: SortMergeJoinExec => p }.size === 1) @@ -140,30 +169,30 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } private def assertBroadcastJoin(df : Dataset[Row]) : Unit = { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") val joined = df1.join(df, Seq("key"), "inner") - val plan = - EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan) + val plan = EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan) assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1) } test("broadcast hint programming API") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"), (3, "2"))).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "2")).toDF("key", "value") val broadcasted = broadcast(df2) - val df3 = spark.createDataFrame(Seq((2, "2"), (3, "3"))).toDF("key", "value") - - val cases = Seq(broadcasted.limit(2), - broadcasted.filter("value < 10"), - broadcasted.sample(true, 0.5), - broadcasted.distinct(), - broadcasted.groupBy("value").agg(min($"key").as("key")), - // except and intersect are semi/anti-joins which won't return more data then - // their left argument, so the broadcast hint should be propagated here - broadcasted.except(df3), - broadcasted.intersect(df3)) + val df3 = Seq((2, "2"), (3, "3")).toDF("key", "value") + + val cases = Seq( + broadcasted.limit(2), + broadcasted.filter("value < 10"), + broadcasted.sample(true, 0.5), + broadcasted.distinct(), + broadcasted.groupBy("value").agg(min($"key").as("key")), + // except and intersect are semi/anti-joins which won't return more data then + // their left argument, so the broadcast hint should be propagated here + broadcasted.except(df3), + broadcasted.intersect(df3)) cases.foreach(assertBroadcastJoin) } @@ -240,9 +269,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { test("Shouldn't change broadcast join buildSide if user clearly specified") { withTempView("t1", "t2") { - spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1") - spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value") - .createTempView("t2") + Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1") + Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2") val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes @@ -292,9 +320,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { test("Shouldn't bias towards build right if user didn't specify") { withTempView("t1", "t2") { - spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1") - spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value") - .createTempView("t2") + Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1") + Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2") val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes From e7bc9f0524822a08d857c3a5ba57119644ceae85 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 26 Jan 2018 18:57:32 -0600 Subject: [PATCH 0217/2461] [MINOR][SS][DOC] Fix `Trigger` Scala/Java doc examples ## What changes were proposed in this pull request? This PR fixes Scala/Java doc examples in `Trigger.java`. ## How was this patch tested? N/A. Author: Dongjoon Hyun Closes #20401 from dongjoon-hyun/SPARK-TRIGGER. --- .../src/main/java/org/apache/spark/sql/streaming/Trigger.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java index 33ae9a9e87668..5371a23230c98 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java +++ b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java @@ -50,7 +50,7 @@ public static Trigger ProcessingTime(long intervalMs) { * * {{{ * import java.util.concurrent.TimeUnit - * df.writeStream.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * df.writeStream().trigger(Trigger.ProcessingTime(10, TimeUnit.SECONDS)) * }}} * * @since 2.2.0 @@ -66,7 +66,7 @@ public static Trigger ProcessingTime(long interval, TimeUnit timeUnit) { * * {{{ * import scala.concurrent.duration._ - * df.writeStream.trigger(ProcessingTime(10.seconds)) + * df.writeStream.trigger(Trigger.ProcessingTime(10.seconds)) * }}} * @since 2.2.0 */ From 6328868e524121bd00595959d6d059f74e038a6b Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 26 Jan 2018 23:06:03 -0800 Subject: [PATCH 0218/2461] [SPARK-23245][SS][TESTS] Don't access `lastExecution.executedPlan` in StreamTest ## What changes were proposed in this pull request? `lastExecution.executedPlan` is lazy val so accessing it in StreamTest may need to acquire the lock of `lastExecution`. It may be waiting forever when the streaming thread is holding it and running a continuous Spark job. This PR changes to check if `s.lastExecution` is null to avoid accessing `lastExecution.executedPlan`. ## How was this patch tested? Jenkins Author: Jose Torres Closes #20413 from zsxwing/SPARK-23245. --- .../test/scala/org/apache/spark/sql/streaming/StreamTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index efdb0e0e7cf1c..d6433562fb29b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -472,7 +472,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be currentStream.awaitInitialization(streamingTimeout.toMillis) currentStream match { case s: ContinuousExecution => eventually("IncrementalExecution was not created") { - s.lastExecution.executedPlan // will fail if lastExecution is null + assert(s.lastExecution != null) } case _ => } From 3227d14feb1a65e95a2bf326cff6ac95615cc5ac Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 27 Jan 2018 11:26:09 -0800 Subject: [PATCH 0219/2461] [SPARK-23233][PYTHON] Reset the cache in asNondeterministic to set deterministic properly ## What changes were proposed in this pull request? Reproducer: ```python from pyspark.sql.functions import udf f = udf(lambda x: x) spark.range(1).select(f("id")) # cache JVM UDF instance. f = f.asNondeterministic() spark.range(1).select(f("id"))._jdf.logicalPlan().projectList().head().deterministic() ``` It should return `False` but the current master returns `True`. Seems it's because we cache the JVM UDF instance and then we reuse it even after setting `deterministic` disabled once it's called. ## How was this patch tested? Manually tested. I am not sure if I should add the test with a lot of JVM accesses with the intetnal stuff .. Let me know if anyone feels so. I will add. Author: hyukjinkwon Closes #20409 from HyukjinKwon/SPARK-23233. --- python/pyspark/sql/tests.py | 13 +++++++++++++ python/pyspark/sql/udf.py | 3 +++ 2 files changed, 16 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a466ab87d882d..ca7bbf8ffe71c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -441,6 +441,19 @@ def test_nondeterministic_udf2(self): pydoc.render_doc(random_udf1) pydoc.render_doc(udf(lambda x: x).asNondeterministic) + def test_nondeterministic_udf3(self): + # regression test for SPARK-23233 + from pyspark.sql.functions import udf + f = udf(lambda x: x) + # Here we cache the JVM UDF instance. + self.spark.range(1).select(f("id")) + # This should reset the cache to set the deterministic status correctly. + f = f.asNondeterministic() + # Check the deterministic status of udf. + df = self.spark.range(1).select(f("id")) + deterministic = df._jdf.logicalPlan().projectList().head().deterministic() + self.assertFalse(deterministic) + def test_nondeterministic_udf_in_aggregate(self): from pyspark.sql.functions import udf, sum import random diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index de96846c5c774..4f303304e5600 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -188,6 +188,9 @@ def asNondeterministic(self): .. versionadded:: 2.3 """ + # Here, we explicitly clean the cache to create a JVM UDF instance + # with 'deterministic' updated. See SPARK-23233. + self._judf_placeholder = None self.deterministic = False return self From b8c32dc57368e49baaacf660b7e8836eedab2df7 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 28 Jan 2018 10:33:06 +0900 Subject: [PATCH 0220/2461] [SPARK-23248][PYTHON][EXAMPLES] Relocate module docstrings to the top in PySpark examples ## What changes were proposed in this pull request? This PR proposes to relocate the docstrings in modules of examples to the top. Seems these are mistakes. So, for example, the below codes ```python >>> help(aft_survival_regression) ``` shows the module docstrings for examples as below: **Before** ``` Help on module aft_survival_regression: NAME aft_survival_regression ... DESCRIPTION # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ... (END) ``` **After** ``` Help on module aft_survival_regression: NAME aft_survival_regression ... DESCRIPTION An example demonstrating aft survival regression. Run with: bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py (END) ``` ## How was this patch tested? Manually checked. Author: hyukjinkwon Closes #20416 from HyukjinKwon/module-docstring-example. --- examples/src/main/python/avro_inputformat.py | 14 +++++++------- .../src/main/python/ml/aft_survival_regression.py | 11 +++++------ .../main/python/ml/bisecting_k_means_example.py | 11 +++++------ .../ml/bucketed_random_projection_lsh_example.py | 12 +++++------- .../src/main/python/ml/chi_square_test_example.py | 10 +++++----- .../src/main/python/ml/correlation_example.py | 10 +++++----- examples/src/main/python/ml/cross_validator.py | 15 +++++++-------- examples/src/main/python/ml/fpgrowth_example.py | 9 ++++----- .../main/python/ml/gaussian_mixture_example.py | 11 +++++------ .../ml/generalized_linear_regression_example.py | 11 +++++------ examples/src/main/python/ml/imputer_example.py | 9 ++++----- .../main/python/ml/isotonic_regression_example.py | 9 +++------ examples/src/main/python/ml/kmeans_example.py | 15 +++++++-------- examples/src/main/python/ml/lda_example.py | 12 +++++------- .../ml/logistic_regression_summary_example.py | 11 +++++------ .../src/main/python/ml/min_hash_lsh_example.py | 12 +++++------- .../src/main/python/ml/one_vs_rest_example.py | 13 ++++++------- .../src/main/python/ml/train_validation_split.py | 13 ++++++------- examples/src/main/python/parquet_inputformat.py | 12 ++++++------ examples/src/main/python/sql/basic.py | 11 +++++------ examples/src/main/python/sql/datasource.py | 11 +++++------ examples/src/main/python/sql/hive.py | 11 +++++------ 22 files changed, 115 insertions(+), 138 deletions(-) diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py index 4422f9e7a9589..6286ba6541fbd 100644 --- a/examples/src/main/python/avro_inputformat.py +++ b/examples/src/main/python/avro_inputformat.py @@ -15,13 +15,6 @@ # limitations under the License. # -from __future__ import print_function - -import sys - -from functools import reduce -from pyspark.sql import SparkSession - """ Read data file users.avro in local Spark distro: @@ -50,6 +43,13 @@ {u'favorite_color': None, u'name': u'Alyssa'} {u'favorite_color': u'red', u'name': u'Ben'} """ +from __future__ import print_function + +import sys + +from functools import reduce +from pyspark.sql import SparkSession + if __name__ == "__main__": if len(sys.argv) != 2 and len(sys.argv) != 3: print(""" diff --git a/examples/src/main/python/ml/aft_survival_regression.py b/examples/src/main/python/ml/aft_survival_regression.py index 2f0ca995e55c7..0a71f76418ea6 100644 --- a/examples/src/main/python/ml/aft_survival_regression.py +++ b/examples/src/main/python/ml/aft_survival_regression.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example demonstrating aft survival regression. +Run with: + bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py +""" from __future__ import print_function # $example on$ @@ -23,12 +28,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating aft survival regression. -Run with: - bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/bisecting_k_means_example.py b/examples/src/main/python/ml/bisecting_k_means_example.py index 1263cb5d177a8..7842d2009e238 100644 --- a/examples/src/main/python/ml/bisecting_k_means_example.py +++ b/examples/src/main/python/ml/bisecting_k_means_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example demonstrating bisecting k-means clustering. +Run with: + bin/spark-submit examples/src/main/python/ml/bisecting_k_means_example.py +""" from __future__ import print_function # $example on$ @@ -22,12 +27,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating bisecting k-means clustering. -Run with: - bin/spark-submit examples/src/main/python/ml/bisecting_k_means_example.py -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/bucketed_random_projection_lsh_example.py b/examples/src/main/python/ml/bucketed_random_projection_lsh_example.py index 1b7a458125cef..610176ea596ca 100644 --- a/examples/src/main/python/ml/bucketed_random_projection_lsh_example.py +++ b/examples/src/main/python/ml/bucketed_random_projection_lsh_example.py @@ -15,7 +15,11 @@ # limitations under the License. # - +""" +An example demonstrating BucketedRandomProjectionLSH. +Run with: + bin/spark-submit examples/src/main/python/ml/bucketed_random_projection_lsh_example.py +""" from __future__ import print_function # $example on$ @@ -25,12 +29,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating BucketedRandomProjectionLSH. -Run with: - bin/spark-submit examples/src/main/python/ml/bucketed_random_projection_lsh_example.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/chi_square_test_example.py b/examples/src/main/python/ml/chi_square_test_example.py index 8f25318ded00a..2af7e683cdb72 100644 --- a/examples/src/main/python/ml/chi_square_test_example.py +++ b/examples/src/main/python/ml/chi_square_test_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example for Chi-square hypothesis testing. +Run with: + bin/spark-submit examples/src/main/python/ml/chi_square_test_example.py +""" from __future__ import print_function from pyspark.sql import SparkSession @@ -23,11 +28,6 @@ from pyspark.ml.stat import ChiSquareTest # $example off$ -""" -An example for Chi-square hypothesis testing. -Run with: - bin/spark-submit examples/src/main/python/ml/chi_square_test_example.py -""" if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/correlation_example.py b/examples/src/main/python/ml/correlation_example.py index 0a9d30da5a42e..1f4e402ac1a51 100644 --- a/examples/src/main/python/ml/correlation_example.py +++ b/examples/src/main/python/ml/correlation_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example for computing correlation matrix. +Run with: + bin/spark-submit examples/src/main/python/ml/correlation_example.py +""" from __future__ import print_function # $example on$ @@ -23,11 +28,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example for computing correlation matrix. -Run with: - bin/spark-submit examples/src/main/python/ml/correlation_example.py -""" if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/cross_validator.py b/examples/src/main/python/ml/cross_validator.py index db7054307c2e3..6256d11504afb 100644 --- a/examples/src/main/python/ml/cross_validator.py +++ b/examples/src/main/python/ml/cross_validator.py @@ -15,6 +15,13 @@ # limitations under the License. # +""" +A simple example demonstrating model selection using CrossValidator. +This example also demonstrates how Pipelines are Estimators. +Run with: + + bin/spark-submit examples/src/main/python/ml/cross_validator.py +""" from __future__ import print_function # $example on$ @@ -26,14 +33,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -A simple example demonstrating model selection using CrossValidator. -This example also demonstrates how Pipelines are Estimators. -Run with: - - bin/spark-submit examples/src/main/python/ml/cross_validator.py -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/fpgrowth_example.py b/examples/src/main/python/ml/fpgrowth_example.py index c92c3c27abb21..39092e616d429 100644 --- a/examples/src/main/python/ml/fpgrowth_example.py +++ b/examples/src/main/python/ml/fpgrowth_example.py @@ -15,16 +15,15 @@ # limitations under the License. # -# $example on$ -from pyspark.ml.fpm import FPGrowth -# $example off$ -from pyspark.sql import SparkSession - """ An example demonstrating FPGrowth. Run with: bin/spark-submit examples/src/main/python/ml/fpgrowth_example.py """ +# $example on$ +from pyspark.ml.fpm import FPGrowth +# $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": spark = SparkSession\ diff --git a/examples/src/main/python/ml/gaussian_mixture_example.py b/examples/src/main/python/ml/gaussian_mixture_example.py index e4a0d314e9d91..4938a904189f9 100644 --- a/examples/src/main/python/ml/gaussian_mixture_example.py +++ b/examples/src/main/python/ml/gaussian_mixture_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +A simple example demonstrating Gaussian Mixture Model (GMM). +Run with: + bin/spark-submit examples/src/main/python/ml/gaussian_mixture_example.py +""" from __future__ import print_function # $example on$ @@ -22,12 +27,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -A simple example demonstrating Gaussian Mixture Model (GMM). -Run with: - bin/spark-submit examples/src/main/python/ml/gaussian_mixture_example.py -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/generalized_linear_regression_example.py b/examples/src/main/python/ml/generalized_linear_regression_example.py index 796752a60f3ab..a52f4650c1c6f 100644 --- a/examples/src/main/python/ml/generalized_linear_regression_example.py +++ b/examples/src/main/python/ml/generalized_linear_regression_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example demonstrating generalized linear regression. +Run with: + bin/spark-submit examples/src/main/python/ml/generalized_linear_regression_example.py +""" from __future__ import print_function from pyspark.sql import SparkSession @@ -22,12 +27,6 @@ from pyspark.ml.regression import GeneralizedLinearRegression # $example off$ -""" -An example demonstrating generalized linear regression. -Run with: - bin/spark-submit examples/src/main/python/ml/generalized_linear_regression_example.py -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/imputer_example.py b/examples/src/main/python/ml/imputer_example.py index b8437f827e56d..9ba0147763618 100644 --- a/examples/src/main/python/ml/imputer_example.py +++ b/examples/src/main/python/ml/imputer_example.py @@ -15,16 +15,15 @@ # limitations under the License. # -# $example on$ -from pyspark.ml.feature import Imputer -# $example off$ -from pyspark.sql import SparkSession - """ An example demonstrating Imputer. Run with: bin/spark-submit examples/src/main/python/ml/imputer_example.py """ +# $example on$ +from pyspark.ml.feature import Imputer +# $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": spark = SparkSession\ diff --git a/examples/src/main/python/ml/isotonic_regression_example.py b/examples/src/main/python/ml/isotonic_regression_example.py index 6ae15f1b4b0dd..89cba9dfc7e8f 100644 --- a/examples/src/main/python/ml/isotonic_regression_example.py +++ b/examples/src/main/python/ml/isotonic_regression_example.py @@ -17,6 +17,9 @@ """ Isotonic Regression Example. + +Run with: + bin/spark-submit examples/src/main/python/ml/isotonic_regression_example.py """ from __future__ import print_function @@ -25,12 +28,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating isotonic regression. -Run with: - bin/spark-submit examples/src/main/python/ml/isotonic_regression_example.py -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/kmeans_example.py b/examples/src/main/python/ml/kmeans_example.py index 5f77843e3743a..80a878af679f4 100644 --- a/examples/src/main/python/ml/kmeans_example.py +++ b/examples/src/main/python/ml/kmeans_example.py @@ -15,6 +15,13 @@ # limitations under the License. # +""" +An example demonstrating k-means clustering. +Run with: + bin/spark-submit examples/src/main/python/ml/kmeans_example.py + +This example requires NumPy (http://www.numpy.org/). +""" from __future__ import print_function # $example on$ @@ -24,14 +31,6 @@ from pyspark.sql import SparkSession -""" -An example demonstrating k-means clustering. -Run with: - bin/spark-submit examples/src/main/python/ml/kmeans_example.py - -This example requires NumPy (http://www.numpy.org/). -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/lda_example.py b/examples/src/main/python/ml/lda_example.py index a8b346f72cd6f..97d1a042d1479 100644 --- a/examples/src/main/python/ml/lda_example.py +++ b/examples/src/main/python/ml/lda_example.py @@ -15,7 +15,11 @@ # limitations under the License. # - +""" +An example demonstrating LDA. +Run with: + bin/spark-submit examples/src/main/python/ml/lda_example.py +""" from __future__ import print_function # $example on$ @@ -23,12 +27,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating LDA. -Run with: - bin/spark-submit examples/src/main/python/ml/lda_example.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/logistic_regression_summary_example.py b/examples/src/main/python/ml/logistic_regression_summary_example.py index bd440a1fbe8df..2274ff707b2a3 100644 --- a/examples/src/main/python/ml/logistic_regression_summary_example.py +++ b/examples/src/main/python/ml/logistic_regression_summary_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example demonstrating Logistic Regression Summary. +Run with: + bin/spark-submit examples/src/main/python/ml/logistic_regression_summary_example.py +""" from __future__ import print_function # $example on$ @@ -22,12 +27,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating Logistic Regression Summary. -Run with: - bin/spark-submit examples/src/main/python/ml/logistic_regression_summary_example.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/min_hash_lsh_example.py b/examples/src/main/python/ml/min_hash_lsh_example.py index 7b1dd611a865b..93136e6ae3cae 100644 --- a/examples/src/main/python/ml/min_hash_lsh_example.py +++ b/examples/src/main/python/ml/min_hash_lsh_example.py @@ -15,7 +15,11 @@ # limitations under the License. # - +""" +An example demonstrating MinHashLSH. +Run with: + bin/spark-submit examples/src/main/python/ml/min_hash_lsh_example.py +""" from __future__ import print_function # $example on$ @@ -25,12 +29,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating MinHashLSH. -Run with: - bin/spark-submit examples/src/main/python/ml/min_hash_lsh_example.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/one_vs_rest_example.py b/examples/src/main/python/ml/one_vs_rest_example.py index 8e00c25d9342e..956e94ae4ab62 100644 --- a/examples/src/main/python/ml/one_vs_rest_example.py +++ b/examples/src/main/python/ml/one_vs_rest_example.py @@ -15,6 +15,12 @@ # limitations under the License. # +""" +An example of Multiclass to Binary Reduction with One Vs Rest, +using Logistic Regression as the base classifier. +Run with: + bin/spark-submit examples/src/main/python/ml/one_vs_rest_example.py +""" from __future__ import print_function # $example on$ @@ -23,13 +29,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example of Multiclass to Binary Reduction with One Vs Rest, -using Logistic Regression as the base classifier. -Run with: - bin/spark-submit examples/src/main/python/ml/one_vs_rest_example.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/train_validation_split.py b/examples/src/main/python/ml/train_validation_split.py index d104f7d30a1bf..d4f9184bf576e 100644 --- a/examples/src/main/python/ml/train_validation_split.py +++ b/examples/src/main/python/ml/train_validation_split.py @@ -15,13 +15,6 @@ # limitations under the License. # -# $example on$ -from pyspark.ml.evaluation import RegressionEvaluator -from pyspark.ml.regression import LinearRegression -from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit -# $example off$ -from pyspark.sql import SparkSession - """ This example demonstrates applying TrainValidationSplit to split data and preform model selection. @@ -29,6 +22,12 @@ bin/spark-submit examples/src/main/python/ml/train_validation_split.py """ +# $example on$ +from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.ml.regression import LinearRegression +from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit +# $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": spark = SparkSession\ diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index 52e9662d528d8..a3f86cf8999cf 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -15,12 +15,6 @@ # limitations under the License. # -from __future__ import print_function - -import sys - -from pyspark.sql import SparkSession - """ Read data file users.parquet in local Spark distro: @@ -35,6 +29,12 @@ {u'favorite_color': u'red', u'name': u'Ben', u'favorite_numbers': []} <...more log output...> """ +from __future__ import print_function + +import sys + +from pyspark.sql import SparkSession + if __name__ == "__main__": if len(sys.argv) != 2: print(""" diff --git a/examples/src/main/python/sql/basic.py b/examples/src/main/python/sql/basic.py index c07fa8f2752b3..c8fb25d0533b5 100644 --- a/examples/src/main/python/sql/basic.py +++ b/examples/src/main/python/sql/basic.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +A simple example demonstrating basic Spark SQL features. +Run with: + ./bin/spark-submit examples/src/main/python/sql/basic.py +""" from __future__ import print_function # $example on:init_session$ @@ -30,12 +35,6 @@ from pyspark.sql.types import * # $example off:programmatic_schema$ -""" -A simple example demonstrating basic Spark SQL features. -Run with: - ./bin/spark-submit examples/src/main/python/sql/basic.py -""" - def basic_df_example(spark): # $example on:create_df$ diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py index b375fa775de39..d8c879dfe02ed 100644 --- a/examples/src/main/python/sql/datasource.py +++ b/examples/src/main/python/sql/datasource.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +A simple example demonstrating Spark SQL data sources. +Run with: + ./bin/spark-submit examples/src/main/python/sql/datasource.py +""" from __future__ import print_function from pyspark.sql import SparkSession @@ -22,12 +27,6 @@ from pyspark.sql import Row # $example off:schema_merging$ -""" -A simple example demonstrating Spark SQL data sources. -Run with: - ./bin/spark-submit examples/src/main/python/sql/datasource.py -""" - def basic_datasource_example(spark): # $example on:generic_load_save_functions$ diff --git a/examples/src/main/python/sql/hive.py b/examples/src/main/python/sql/hive.py index 1f83a6fb48b97..33fc2dfbeefa2 100644 --- a/examples/src/main/python/sql/hive.py +++ b/examples/src/main/python/sql/hive.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +A simple example demonstrating Spark SQL Hive integration. +Run with: + ./bin/spark-submit examples/src/main/python/sql/hive.py +""" from __future__ import print_function # $example on:spark_hive$ @@ -24,12 +29,6 @@ from pyspark.sql import Row # $example off:spark_hive$ -""" -A simple example demonstrating Spark SQL Hive integration. -Run with: - ./bin/spark-submit examples/src/main/python/sql/hive.py -""" - if __name__ == "__main__": # $example on:spark_hive$ From c40fda9e4cf32d6cd17af2ace959bbbbe7c782a4 Mon Sep 17 00:00:00 2001 From: Yacine Mazari Date: Sun, 28 Jan 2018 10:27:59 -0600 Subject: [PATCH 0221/2461] [SPARK-23166][ML] Add maxDF Parameter to CountVectorizer ## What changes were proposed in this pull request? Currently, the CountVectorizer has a minDF parameter. It might be useful to also have a maxDF parameter. It will be used as a threshold for filtering all the terms that occur very frequently in a text corpus, because they are not very informative or could even be stop-words. This is analogous to scikit-learn, CountVectorizer, max_df. Other changes: - Refactored code to invoke "filter()" conditioned on maxDF or minDF set. - Refactored code to unpersist input after counting is done. ## How was this patch tested? Unit tests. Author: Yacine Mazari Closes #20367 from ymazari/SPARK-23166. --- .../spark/ml/feature/CountVectorizer.scala | 67 ++++++++++++++--- .../ml/feature/CountVectorizerSuite.scala | 72 +++++++++++++++++++ 2 files changed, 131 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 1ebe29703bc47..60a4f918790a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -69,6 +69,25 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit /** @group getParam */ def getMinDF: Double = $(minDF) + /** + * Specifies the maximum number of different documents a term must appear in to be included + * in the vocabulary. + * If this is an integer greater than or equal to 1, this specifies the number of documents + * the term must appear in; if this is a double in [0,1), then this specifies the fraction of + * documents. + * + * Default: (2^64^) - 1 + * @group param + */ + val maxDF: DoubleParam = new DoubleParam(this, "maxDF", "Specifies the maximum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of documents.", + ParamValidators.gtEq(0.0)) + + /** @group getParam */ + def getMaxDF: Double = $(maxDF) + /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false)) @@ -113,7 +132,11 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit /** @group getParam */ def getBinary: Boolean = $(binary) - setDefault(vocabSize -> (1 << 18), minDF -> 1.0, minTF -> 1.0, binary -> false) + setDefault(vocabSize -> (1 << 18), + minDF -> 1.0, + maxDF -> Long.MaxValue, + minTF -> 1.0, + binary -> false) } /** @@ -142,6 +165,10 @@ class CountVectorizer @Since("1.5.0") (@Since("1.5.0") override val uid: String) @Since("1.5.0") def setMinDF(value: Double): this.type = set(minDF, value) + /** @group setParam */ + @Since("2.4.0") + def setMaxDF(value: Double): this.type = set(maxDF, value) + /** @group setParam */ @Since("1.5.0") def setMinTF(value: Double): this.type = set(minTF, value) @@ -155,12 +182,24 @@ class CountVectorizer @Since("1.5.0") (@Since("1.5.0") override val uid: String) transformSchema(dataset.schema, logging = true) val vocSize = $(vocabSize) val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) + val countingRequired = $(minDF) < 1.0 || $(maxDF) < 1.0 + val maybeInputSize = if (countingRequired) { + Some(input.cache().count()) + } else { + None + } val minDf = if ($(minDF) >= 1.0) { $(minDF) } else { - $(minDF) * input.cache().count() + $(minDF) * maybeInputSize.get } - val wordCounts: RDD[(String, Long)] = input.flatMap { case (tokens) => + val maxDf = if ($(maxDF) >= 1.0) { + $(maxDF) + } else { + $(maxDF) * maybeInputSize.get + } + require(maxDf >= minDf, "maxDF must be >= minDF.") + val allWordCounts = input.flatMap { case (tokens) => val wc = new OpenHashMap[String, Long] tokens.foreach { w => wc.changeValue(w, 1L, _ + 1L) @@ -168,11 +207,23 @@ class CountVectorizer @Since("1.5.0") (@Since("1.5.0") override val uid: String) wc.map { case (word, count) => (word, (count, 1)) } }.reduceByKey { case ((wc1, df1), (wc2, df2)) => (wc1 + wc2, df1 + df2) - }.filter { case (word, (wc, df)) => - df >= minDf - }.map { case (word, (count, dfCount)) => - (word, count) - }.cache() + } + + val filteringRequired = isSet(minDF) || isSet(maxDF) + val maybeFilteredWordCounts = if (filteringRequired) { + allWordCounts.filter { case (_, (_, df)) => df >= minDf && df <= maxDf } + } else { + allWordCounts + } + + val wordCounts = maybeFilteredWordCounts + .map { case (word, (count, _)) => (word, count) } + .cache() + + if (countingRequired) { + input.unpersist() + } + val fullVocabSize = wordCounts.count() val vocab = wordCounts diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index f213145f1ba0a..1784c07ca23e3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -119,6 +119,78 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } } + test("CountVectorizer maxDF") { + val df = Seq( + (0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0), (2, 1.0)))), + (1, split("a b c"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), + (2, split("a b"), Vectors.sparse(3, Seq((0, 1.0)))), + (3, split("a"), Vectors.sparse(3, Seq())) + ).toDF("id", "words", "expected") + + // maxDF: ignore terms with count more than 3 + val cvModel = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMaxDF(3) + .fit(df) + assert(cvModel.vocabulary === Array("b", "c", "d")) + + cvModel.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + + // maxDF: ignore terms with freq > 0.75 + val cvModel2 = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMaxDF(0.75) + .fit(df) + assert(cvModel2.vocabulary === Array("b", "c", "d")) + + cvModel2.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } + + test("CountVectorizer using both minDF and maxDF") { + // Ignore terms with count more than 3 AND less than 2 + val df = Seq( + (0, split("a b c d"), Vectors.sparse(2, Seq((0, 1.0), (1, 1.0)))), + (1, split("a b c"), Vectors.sparse(2, Seq((0, 1.0), (1, 1.0)))), + (2, split("a b"), Vectors.sparse(2, Seq((0, 1.0)))), + (3, split("a"), Vectors.sparse(2, Seq())) + ).toDF("id", "words", "expected") + + val cvModel = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMinDF(2) + .setMaxDF(3) + .fit(df) + assert(cvModel.vocabulary === Array("b", "c")) + + cvModel.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + + // Ignore terms with frequency higher than 0.75 AND less than 0.5 + val cvModel2 = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMinDF(0.5) + .setMaxDF(0.75) + .fit(df) + assert(cvModel2.vocabulary === Array("b", "c")) + + cvModel2.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } + test("CountVectorizer throws exception when vocab is empty") { intercept[IllegalArgumentException] { val df = Seq( From 686a622c93207564635569f054e1e6c921624e96 Mon Sep 17 00:00:00 2001 From: CCInCharge Date: Sun, 28 Jan 2018 14:55:43 -0600 Subject: [PATCH 0222/2461] [SPARK-23250][DOCS] Typo in JavaDoc/ScalaDoc for DataFrameWriter ## What changes were proposed in this pull request? Fix typo in ScalaDoc for DataFrameWriter - originally stated "This is applicable for all file-based data sources (e.g. Parquet, JSON) staring Spark 2.1.0", should be "starting with Spark 2.1.0". ## How was this patch tested? Check of correct spelling in ScalaDoc Please review http://spark.apache.org/contributing.html before opening a pull request. Author: CCInCharge Closes #20417 from CCInCharge/master. --- .../scala/org/apache/spark/sql/DataFrameWriter.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 5f3d4448e4e54..5c02eae05304b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -174,7 +174,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * predicates on the partitioned columns. In order for partitioning to work well, the number * of distinct values in each column should typically be less than tens of thousands. * - * This is applicable for all file-based data sources (e.g. Parquet, JSON) staring Spark 2.1.0. + * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark + * 2.1.0. * * @since 1.4.0 */ @@ -188,7 +189,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * Buckets the output by the given columns. If specified, the output is laid out on the file * system similar to Hive's bucketing scheme. * - * This is applicable for all file-based data sources (e.g. Parquet, JSON) staring Spark 2.1.0. + * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark + * 2.1.0. * * @since 2.0 */ @@ -202,7 +204,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { /** * Sorts the output in each bucket by the given columns. * - * This is applicable for all file-based data sources (e.g. Parquet, JSON) staring Spark 2.1.0. + * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark + * 2.1.0. * * @since 2.0 */ From 49b0207dc9327989c72700b4d04d2a714c92e159 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 29 Jan 2018 13:10:38 +0800 Subject: [PATCH 0223/2461] [SPARK-23196] Unify continuous and microbatch V2 sinks ## What changes were proposed in this pull request? Replace streaming V2 sinks with a unified StreamWriteSupport interface, with a shim to use it with microbatch execution. Add a new SQL config to use for disabling V2 sinks, falling back to the V1 sink implementation. ## How was this patch tested? Existing tests, which in the case of Kafka (the only existing continuous V2 sink) now use V2 for microbatch. Author: Jose Torres Closes #20369 from jose-torres/streaming-sink. --- .../sql/kafka010/KafkaSourceProvider.scala | 16 +-- ...usWriter.scala => KafkaStreamWriter.scala} | 30 ++--- .../kafka010/KafkaContinuousSinkSuite.scala | 8 +- .../spark/sql/kafka010/KafkaSinkSuite.scala | 14 ++- .../spark/sql/kafka010/KafkaSourceSuite.scala | 8 +- .../apache/spark/sql/internal/SQLConf.scala | 9 ++ .../v2/streaming/MicroBatchWriteSupport.java | 60 ---------- ...teSupport.java => StreamWriteSupport.java} | 12 +- ...ontinuousWriter.java => StreamWriter.java} | 34 +++++- .../sources/v2/writer/DataSourceV2Writer.java | 4 +- .../datasources/v2/WriteToDataSourceV2.scala | 11 +- .../streaming/MicroBatchExecution.scala | 19 +-- .../sql/execution/streaming/console.scala | 27 ++--- .../continuous/ContinuousExecution.scala | 19 ++- .../continuous/EpochCoordinator.scala | 9 +- .../streaming/sources/ConsoleWriter.scala | 59 ++------- .../streaming/sources/MicroBatchWriter.scala | 54 +++++++++ .../streaming/sources/memoryV2.scala | 29 ++--- .../sql/streaming/DataStreamWriter.scala | 10 +- .../sql/streaming/StreamingQueryManager.scala | 9 +- ...pache.spark.sql.sources.DataSourceRegister | 7 +- .../streaming/MemorySinkV2Suite.scala | 2 +- .../sources/StreamingDataSourceV2Suite.scala | 112 +++++++++--------- 23 files changed, 265 insertions(+), 297 deletions(-) rename external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/{KafkaContinuousWriter.scala => KafkaStreamWriter.scala} (78%) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/{ContinuousWriteSupport.java => StreamWriteSupport.java} (85%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/{ContinuousWriter.java => StreamWriter.java} (50%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 62a998fbfb30b..2deb7fa2cdf1e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -28,11 +28,11 @@ import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySe import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} -import org.apache.spark.sql.execution.streaming.{Offset, Sink, Source} +import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -46,7 +46,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider - with ContinuousWriteSupport + with StreamWriteSupport with ContinuousReadSupport with Logging { import KafkaSourceProvider._ @@ -223,11 +223,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - override def createContinuousWriter( + override def createStreamWriter( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): Optional[ContinuousWriter] = { + options: DataSourceV2Options): StreamWriter = { import scala.collection.JavaConverters._ val spark = SparkSession.getActiveSession.get @@ -238,7 +238,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister KafkaWriter.validateQuery( schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic) - Optional.of(new KafkaContinuousWriter(topic, producerParams, schema)) + new KafkaStreamWriter(topic, producerParams, schema) } private def strategy(caseInsensitiveParams: Map[String, String]) = diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala similarity index 78% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala index 9843f469c5b25..a24efdefa4464 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala @@ -17,19 +17,14 @@ package org.apache.spark.sql.kafka010 -import org.apache.kafka.clients.producer.{Callback, ProducerRecord, RecordMetadata} import scala.collection.JavaConverters._ -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} -import org.apache.spark.sql.kafka010.KafkaSourceProvider.{kafkaParamsForProducer, TOPIC_OPTION_KEY} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{BinaryType, StringType, StructType} +import org.apache.spark.sql.types.StructType /** * Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we @@ -38,23 +33,24 @@ import org.apache.spark.sql.types.{BinaryType, StringType, StructType} case object KafkaWriterCommitMessage extends WriterCommitMessage /** - * A [[ContinuousWriter]] for Kafka writing. Responsible for generating the writer factory. + * A [[StreamWriter]] for Kafka writing. Responsible for generating the writer factory. + * * @param topic The topic this writer is responsible for. If None, topic will be inferred from * a `topic` field in the incoming data. * @param producerParams Parameters for Kafka producers in each task. * @param schema The schema of the input data. */ -class KafkaContinuousWriter( +class KafkaStreamWriter( topic: Option[String], producerParams: Map[String, String], schema: StructType) - extends ContinuousWriter with SupportsWriteInternalRow { + extends StreamWriter with SupportsWriteInternalRow { validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) - override def createInternalRowWriterFactory(): KafkaContinuousWriterFactory = - KafkaContinuousWriterFactory(topic, producerParams, schema) + override def createInternalRowWriterFactory(): KafkaStreamWriterFactory = + KafkaStreamWriterFactory(topic, producerParams, schema) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - override def abort(messages: Array[WriterCommitMessage]): Unit = {} + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} } /** @@ -65,12 +61,12 @@ class KafkaContinuousWriter( * @param producerParams Parameters for Kafka producers in each task. * @param schema The schema of the input data. */ -case class KafkaContinuousWriterFactory( +case class KafkaStreamWriterFactory( topic: Option[String], producerParams: Map[String, String], schema: StructType) extends DataWriterFactory[InternalRow] { override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { - new KafkaContinuousDataWriter(topic, producerParams, schema.toAttributes) + new KafkaStreamDataWriter(topic, producerParams, schema.toAttributes) } } @@ -83,7 +79,7 @@ case class KafkaContinuousWriterFactory( * @param producerParams Parameters to use for the Kafka producer. * @param inputSchema The attributes in the input data. */ -class KafkaContinuousDataWriter( +class KafkaStreamDataWriter( targetTopic: Option[String], producerParams: Map[String, String], inputSchema: Seq[Attribute]) extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] { import scala.collection.JavaConverters._ diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala index 8487a69851237..fc890a0cfdac3 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -18,16 +18,14 @@ package org.apache.spark.sql.kafka010 import java.util.Locale -import java.util.concurrent.atomic.AtomicInteger import org.apache.kafka.clients.producer.ProducerConfig import org.apache.kafka.common.serialization.ByteArraySerializer import org.scalatest.time.SpanSugar._ import scala.collection.JavaConverters._ -import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} -import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{BinaryType, DataType} import org.apache.spark.util.Utils @@ -362,7 +360,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { } finally { writer.stop() } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) + assert(ex.getCause.getCause.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) } test("streaming - exception on config serializer") { @@ -424,7 +422,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) val inputSchema = Seq(AttributeReference("value", BinaryType)()) val data = new Array[Byte](15000) // large value - val writeTask = new KafkaContinuousDataWriter(Some(topic), options.asScala.toMap, inputSchema) + val writeTask = new KafkaStreamDataWriter(Some(topic), options.asScala.toMap, inputSchema) try { val fieldTypes: Array[DataType] = Array(BinaryType) val converter = UnsafeProjection.create(fieldTypes) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 2ab336c7ac476..42f8b4c7657e2 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -336,27 +336,31 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { } finally { writer.stop() } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) + assert(ex.getCause.getCause.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) } test("streaming - exception on config serializer") { val input = MemoryStream[String] var writer: StreamingQuery = null var ex: Exception = null - ex = intercept[IllegalArgumentException] { + ex = intercept[StreamingQueryException] { writer = createKafkaWriter( input.toDF(), withOptions = Map("kafka.key.serializer" -> "foo"))() + input.addData("1") + writer.processAllAvailable() } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + assert(ex.getCause.getMessage.toLowerCase(Locale.ROOT).contains( "kafka option 'key.serializer' is not supported")) - ex = intercept[IllegalArgumentException] { + ex = intercept[StreamingQueryException] { writer = createKafkaWriter( input.toDF(), withOptions = Map("kafka.value.serializer" -> "foo"))() + input.addData("1") + writer.processAllAvailable() } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + assert(ex.getCause.getMessage.toLowerCase(Locale.ROOT).contains( "kafka option 'value.serializer' is not supported")) } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index c4cb1bc4a2e18..02c87643568bd 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -29,19 +29,17 @@ import scala.util.Random import org.apache.kafka.clients.producer.RecordMetadata import org.apache.kafka.common.TopicPartition -import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext -import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2Exec} +import org.apache.spark.sql.{Dataset, ForeachWriter} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryWriter import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ -import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest, Trigger} +import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} import org.apache.spark.util.Utils diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 76b9d6f6f33bd..2c70b004bcff9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1127,6 +1127,13 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(100) + val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers") + .internal() + .doc("A comma-separated list of fully qualified data source register class names for which" + + " StreamWriteSupport is disabled. Writes to these sources will fail back to the V1 Sink.") + .stringConf + .createWithDefault("") + object PartitionOverwriteMode extends Enumeration { val STATIC, DYNAMIC = Value } @@ -1494,6 +1501,8 @@ class SQLConf extends Serializable with Logging { def continuousStreamingExecutorPollIntervalMs: Long = getConf(CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS) + def disabledV2StreamingWriters: String = getConf(DISABLED_V2_STREAMING_WRITERS) + def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java deleted file mode 100644 index 53ffa95ae0f4c..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.streaming; - -import java.util.Optional; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.streaming.BaseStreamingSink; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; -import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; -import org.apache.spark.sql.streaming.OutputMode; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data writing ability and save the data from a microbatch to the data source. - */ -@InterfaceStability.Evolving -public interface MicroBatchWriteSupport extends BaseStreamingSink { - - /** - * Creates an optional {@link DataSourceV2Writer} to save the data to this data source. Data - * sources can return None if there is no writing needed to be done. - * - * @param queryId A unique string for the writing query. It's possible that there are many writing - * queries running at the same time, and the returned {@link DataSourceV2Writer} - * can use this id to distinguish itself from others. - * @param epochId The unique numeric ID of the batch within this writing query. This is an - * incrementing counter representing a consistent set of data; the same batch may - * be started multiple times in failure recovery scenarios, but it will always - * contain the same records. - * @param schema the schema of the data to be written. - * @param mode the output mode which determines what successive batch output means to this - * sink, please refer to {@link OutputMode} for more details. - * @param options the options for the returned data source writer, which is an immutable - * case-insensitive string-to-string map. - */ - Optional createMicroBatchWriter( - String queryId, - long epochId, - StructType schema, - OutputMode mode, - DataSourceV2Options options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java similarity index 85% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java index dee493cadb71e..6cd219c67109a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java @@ -17,26 +17,24 @@ package org.apache.spark.sql.sources.v2.streaming; -import java.util.Optional; - import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceV2Options; -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter; +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter; import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data writing ability for continuous stream processing. + * provide data writing ability for structured streaming. */ @InterfaceStability.Evolving -public interface ContinuousWriteSupport extends BaseStreamingSink { +public interface StreamWriteSupport extends BaseStreamingSink { /** - * Creates an optional {@link ContinuousWriter} to save the data to this data source. Data + * Creates an optional {@link StreamWriter} to save the data to this data source. Data * sources can return None if there is no writing needed to be done. * * @param queryId A unique string for the writing query. It's possible that there are many @@ -48,7 +46,7 @@ public interface ContinuousWriteSupport extends BaseStreamingSink { * @param options the options for the returned data source writer, which is an immutable * case-insensitive string-to-string map. */ - Optional createContinuousWriter( + StreamWriter createStreamWriter( String queryId, StructType schema, OutputMode mode, diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java similarity index 50% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java index 723395bd1e963..3156c88933e5e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java @@ -23,10 +23,14 @@ import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; /** - * A {@link DataSourceV2Writer} for use with continuous stream processing. + * A {@link DataSourceV2Writer} for use with structured streaming. This writer handles commits and + * aborts relative to an epoch ID determined by the execution engine. + * + * {@link DataWriter} implementations generated by a StreamWriter may be reused for multiple epochs, + * and so must reset any internal state after a successful commit. */ @InterfaceStability.Evolving -public interface ContinuousWriter extends DataSourceV2Writer { +public interface StreamWriter extends DataSourceV2Writer { /** * Commits this writing job for the specified epoch with a list of commit messages. The commit * messages are collected from successful data writers and are produced by @@ -34,11 +38,35 @@ public interface ContinuousWriter extends DataSourceV2Writer { * * If this method fails (by throwing an exception), this writing job is considered to have been * failed, and the execution engine will attempt to call {@link #abort(WriterCommitMessage[])}. + * + * To support exactly-once processing, writer implementations should ensure that this method is + * idempotent. The execution engine may call commit() multiple times for the same epoch + * in some circumstances. */ void commit(long epochId, WriterCommitMessage[] messages); + /** + * Aborts this writing job because some data writers are failed and keep failing when retry, or + * the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} fails. + * + * If this method fails (by throwing an exception), the underlying data source may require manual + * cleanup. + * + * Unless the abort is triggered by the failure of commit, the given messages should have some + * null slots as there maybe only a few data writers that are committed before the abort + * happens, or some data writers were committed but their commit messages haven't reached the + * driver when the abort is triggered. So this is just a "best effort" for data sources to + * clean up the data left by data writers. + */ + void abort(long epochId, WriterCommitMessage[] messages); + default void commit(WriterCommitMessage[] messages) { throw new UnsupportedOperationException( - "Commit without epoch should not be called with ContinuousWriter"); + "Commit without epoch should not be called with StreamWriter"); + } + + default void abort(WriterCommitMessage[] messages) { + throw new UnsupportedOperationException( + "Abort without epoch should not be called with StreamWriter"); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java index f1ef411423162..8048f507a1dca 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java @@ -28,9 +28,7 @@ /** * A data source writer that is returned by * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}/ - * {@link org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport#createMicroBatchWriter( - * String, long, StructType, OutputMode, DataSourceV2Options)}/ - * {@link org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport#createContinuousWriter( + * {@link org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport#createStreamWriter( * String, StructType, OutputMode, DataSourceV2Options)}. * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 3dbdae7b4df9f..cd6b3e99b6bcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -26,9 +26,8 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -62,7 +61,9 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) try { val runTask = writer match { - case w: ContinuousWriter => + // This case means that we're doing continuous processing. In microbatch streaming, the + // StreamWriter is wrapped in a MicroBatchWriter, which is executed as a normal batch. + case w: StreamWriter => EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) @@ -82,13 +83,13 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) (index, message: WriterCommitMessage) => messages(index) = message ) - if (!writer.isInstanceOf[ContinuousWriter]) { + if (!writer.isInstanceOf[StreamWriter]) { logInfo(s"Data source writer $writer is committing.") writer.commit(messages) logInfo(s"Data source writer $writer committed.") } } catch { - case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] => + case _: InterruptedException if writer.isInstanceOf[StreamWriter] => // Interruption is how continuous queries are ended, so accept and ignore the exception. case cause: Throwable => logError(s"Data source writer $writer is aborting.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 7c3804547b736..975975243a3d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -28,9 +28,11 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Curre import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.streaming.{MicroBatchReadSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.{MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} @@ -440,15 +442,18 @@ class MicroBatchExecution( val triggerLogicalPlan = sink match { case _: Sink => newAttributePlan - case s: MicroBatchWriteSupport => - val writer = s.createMicroBatchWriter( + case s: StreamWriteSupport => + val writer = s.createStreamWriter( s"$runId", - currentBatchId, newAttributePlan.schema, outputMode, new DataSourceV2Options(extraOptions.asJava)) - assert(writer.isPresent, "microbatch writer must always be present") - WriteToDataSourceV2(writer.get, newAttributePlan) + if (writer.isInstanceOf[SupportsWriteInternalRow]) { + WriteToDataSourceV2( + new InternalRowMicroBatchWriter(currentBatchId, writer), newAttributePlan) + } else { + WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan) + } case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } @@ -471,7 +476,7 @@ class MicroBatchExecution( SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { sink match { case s: Sink => s.addBatch(currentBatchId, nextBatch) - case s: MicroBatchWriteSupport => + case _: StreamWriteSupport => // This doesn't accumulate any data - it just forces execution of the microbatch writer. nextBatch.collect() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index f2aa3259731d1..d5ac0bd1df52b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.execution.streaming -import java.util.Optional - import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming.sources.{ConsoleContinuousWriter, ConsoleMicroBatchWriter, ConsoleWriter} +import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} -import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter -import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer +import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -35,26 +32,16 @@ case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) } class ConsoleSinkProvider extends DataSourceV2 - with MicroBatchWriteSupport - with ContinuousWriteSupport + with StreamWriteSupport with DataSourceRegister with CreatableRelationProvider { - override def createMicroBatchWriter( - queryId: String, - batchId: Long, - schema: StructType, - mode: OutputMode, - options: DataSourceV2Options): Optional[DataSourceV2Writer] = { - Optional.of(new ConsoleMicroBatchWriter(batchId, schema, options)) - } - - override def createContinuousWriter( + override def createStreamWriter( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): Optional[ContinuousWriter] = { - Optional.of(new ConsoleContinuousWriter(schema, options)) + options: DataSourceV2Options): StreamWriter = { + new ConsoleWriter(schema, options) } def createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 462e7d9721d28..60f880f9c73b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -24,17 +24,16 @@ import java.util.function.UnaryOperator import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} -import org.apache.spark.{SparkEnv, SparkException} -import org.apache.spark.sql.{AnalysisException, SparkSession} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} +import org.apache.spark.SparkEnv +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, Utils} @@ -44,7 +43,7 @@ class ContinuousExecution( name: String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: ContinuousWriteSupport, + sink: StreamWriteSupport, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, @@ -195,12 +194,12 @@ class ContinuousExecution( "CurrentTimestamp and CurrentDate not yet supported for continuous processing") } - val writer = sink.createContinuousWriter( + val writer = sink.createStreamWriter( s"$runId", triggerLogicalPlan.schema, outputMode, new DataSourceV2Options(extraOptions.asJava)) - val withSink = WriteToDataSourceV2(writer.get(), triggerLogicalPlan) + val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { case DataSourceV2Relation(_, r: ContinuousReader) => r @@ -230,7 +229,7 @@ class ContinuousExecution( // Use the parent Spark session for the endpoint since it's where this query ID is registered. val epochEndpoint = EpochCoordinatorRef.create( - writer.get(), reader, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) + writer, reader, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) val epochUpdateThread = new Thread(new Runnable { override def run: Unit = { try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 90b3584aa0436..84d262116cb46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -17,17 +17,14 @@ package org.apache.spark.sql.execution.streaming.continuous -import java.util.concurrent.atomic.AtomicLong - import scala.collection.mutable import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, PartitionOffset} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage import org.apache.spark.util.RpcUtils @@ -85,7 +82,7 @@ private[sql] object EpochCoordinatorRef extends Logging { * Create a reference to a new [[EpochCoordinator]]. */ def create( - writer: ContinuousWriter, + writer: StreamWriter, reader: ContinuousReader, query: ContinuousExecution, epochCoordinatorId: String, @@ -118,7 +115,7 @@ private[sql] object EpochCoordinatorRef extends Logging { * have both committed and reported an end offset for a given epoch. */ private[continuous] class EpochCoordinator( - writer: ContinuousWriter, + writer: StreamWriter, reader: ContinuousReader, query: ContinuousExecution, startEpoch: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala index 6fb61dff60045..7c1700f1de48c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -20,14 +20,13 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter -import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage} import org.apache.spark.sql.types.StructType /** Common methods used to create writes for the the console sink */ -trait ConsoleWriter extends Logging { - - def options: DataSourceV2Options +class ConsoleWriter(schema: StructType, options: DataSourceV2Options) + extends StreamWriter with Logging { // Number of rows to display, by default 20 rows protected val numRowsToShow = options.getInt("numRows", 20) @@ -40,14 +39,20 @@ trait ConsoleWriter extends Logging { def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory - def abort(messages: Array[WriterCommitMessage]): Unit = {} + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2 + // behavior. + printRows(messages, schema, s"Batch: $epochId") + } + + def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} protected def printRows( commitMessages: Array[WriterCommitMessage], schema: StructType, printMessage: String): Unit = { val rows = commitMessages.collect { - case PackedRowCommitMessage(rows) => rows + case PackedRowCommitMessage(rs) => rs }.flatten // scalastyle:off println @@ -59,46 +64,8 @@ trait ConsoleWriter extends Logging { .createDataFrame(spark.sparkContext.parallelize(rows), schema) .show(numRowsToShow, isTruncated) } -} - - -/** - * A [[DataSourceV2Writer]] that collects results from a micro-batch query to the driver and - * prints them in the console. Created by - * [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. - * - * This sink should not be used for production, as it requires sending all rows to the driver - * and does not support recovery. - */ -class ConsoleMicroBatchWriter(batchId: Long, schema: StructType, val options: DataSourceV2Options) - extends DataSourceV2Writer with ConsoleWriter { - - override def commit(messages: Array[WriterCommitMessage]): Unit = { - printRows(messages, schema, s"Batch: $batchId") - } - - override def toString(): String = { - s"ConsoleMicroBatchWriter[numRows=$numRowsToShow, truncate=$isTruncated]" - } -} - - -/** - * A [[DataSourceV2Writer]] that collects results from a continuous query to the driver and - * prints them in the console. Created by - * [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. - * - * This sink should not be used for production, as it requires sending all rows to the driver - * and does not support recovery. - */ -class ConsoleContinuousWriter(schema: StructType, val options: DataSourceV2Options) - extends ContinuousWriter with ConsoleWriter { - - override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { - printRows(messages, schema, s"Continuous processing epoch $epochId") - } override def toString(): String = { - s"ConsoleContinuousWriter[numRows=$numRowsToShow, truncate=$isTruncated]" + s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala new file mode 100644 index 0000000000000..d7f3ba8856982 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} + +/** + * A [[DataSourceV2Writer]] used to hook V2 stream writers into a microbatch plan. It implements + * the non-streaming interface, forwarding the batch ID determined at construction to a wrapped + * streaming writer. + */ +class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceV2Writer { + override def commit(messages: Array[WriterCommitMessage]): Unit = { + writer.commit(batchId, messages) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages) + + override def createWriterFactory(): DataWriterFactory[Row] = writer.createWriterFactory() +} + +class InternalRowMicroBatchWriter(batchId: Long, writer: StreamWriter) + extends DataSourceV2Writer with SupportsWriteInternalRow { + override def commit(messages: Array[WriterCommitMessage]): Unit = { + writer.commit(batchId, messages) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages) + + override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = + writer match { + case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() + case _ => throw new IllegalStateException( + "InternalRowMicroBatchWriter should only be created with base writer support") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index da7c31cf62428..ce55e44d932bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -30,8 +30,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.Sink import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} -import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -40,24 +40,13 @@ import org.apache.spark.sql.types.StructType * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySinkV2 extends DataSourceV2 - with MicroBatchWriteSupport with ContinuousWriteSupport with Logging { - - override def createMicroBatchWriter( +class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { + override def createStreamWriter( queryId: String, - batchId: Long, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): java.util.Optional[DataSourceV2Writer] = { - java.util.Optional.of(new MemoryWriter(this, batchId, mode)) - } - - override def createContinuousWriter( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceV2Options): java.util.Optional[ContinuousWriter] = { - java.util.Optional.of(new ContinuousMemoryWriter(this, mode)) + options: DataSourceV2Options): StreamWriter = { + new MemoryStreamWriter(this, mode) } private case class AddedData(batchId: Long, data: Array[Row]) @@ -141,8 +130,8 @@ class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) } } -class ContinuousMemoryWriter(val sink: MemorySinkV2, outputMode: OutputMode) - extends ContinuousWriter { +class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) + extends StreamWriter { override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) @@ -153,7 +142,7 @@ class ContinuousMemoryWriter(val sink: MemorySinkV2, outputMode: OutputMode) sink.write(epochId, outputMode, newRows) } - override def abort(messages: Array[WriterCommitMessage]): Unit = { + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { // Don't accept any of the new input. } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index d24f0ddeab4de..3b5b30d77945c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -281,11 +281,9 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { trigger = trigger) } else { val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) - val sink = (ds.newInstance(), trigger) match { - case (w: ContinuousWriteSupport, _: ContinuousTrigger) => w - case (_, _: ContinuousTrigger) => throw new UnsupportedOperationException( - s"Data source $source does not support continuous writing") - case (w: MicroBatchWriteSupport, _) => w + val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",") + val sink = ds.newInstance() match { + case w: StreamWriteSupport if !disabledSources.contains(w.getClass.getCanonicalName) => w case _ => val ds = DataSource( df.sparkSession, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 4b27e0d4ef47b..fdd709cdb1f38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger} import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -241,7 +241,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } (sink, trigger) match { - case (v2Sink: ContinuousWriteSupport, trigger: ContinuousTrigger) => + case (v2Sink: StreamWriteSupport, trigger: ContinuousTrigger) => UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) new StreamingQueryWrapper(new ContinuousExecution( sparkSession, @@ -254,7 +254,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo outputMode, extraOptions, deleteCheckpointOnStop)) - case (_: MicroBatchWriteSupport, _) | (_: Sink, _) => + case _ => new StreamingQueryWrapper(new MicroBatchExecution( sparkSession, userSpecifiedName.orNull, @@ -266,9 +266,6 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo outputMode, extraOptions, deleteCheckpointOnStop)) - case (_: ContinuousWriteSupport, t) if !t.isInstanceOf[ContinuousTrigger] => - throw new AnalysisException( - "Sink only supports continuous writes, but a continuous trigger was not specified.") } } diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index a0b25b4e82364..46b38bed1c0fb 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -9,7 +9,6 @@ org.apache.spark.sql.streaming.sources.FakeReadMicroBatchOnly org.apache.spark.sql.streaming.sources.FakeReadContinuousOnly org.apache.spark.sql.streaming.sources.FakeReadBothModes org.apache.spark.sql.streaming.sources.FakeReadNeitherMode -org.apache.spark.sql.streaming.sources.FakeWriteMicroBatchOnly -org.apache.spark.sql.streaming.sources.FakeWriteContinuousOnly -org.apache.spark.sql.streaming.sources.FakeWriteBothModes -org.apache.spark.sql.streaming.sources.FakeWriteNeitherMode +org.apache.spark.sql.streaming.sources.FakeWrite +org.apache.spark.sql.streaming.sources.FakeNoWrite +org.apache.spark.sql.streaming.sources.FakeWriteV1Fallback diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 00d4f0b8503d8..9be22d94b5654 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -40,7 +40,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("continuous writer") { val sink = new MemorySinkV2 - val writer = new ContinuousMemoryWriter(sink, OutputMode.Append()) + val writer = new MemoryStreamWriter(sink, OutputMode.Append()) writer.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index f152174b0a7f0..d4f8bae96695d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -19,18 +19,18 @@ package org.apache.spark.sql.streaming.sources import java.util.Optional -import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming.{LongOffset, RateStreamOffset} +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, StreamingQueryWrapper} import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} +import org.apache.spark.sql.sources.v2.DataSourceV2Options import org.apache.spark.sql.sources.v2.reader.ReadTask -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport, MicroBatchReadSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming._ import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter -import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer -import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest, Trigger} +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -64,23 +64,12 @@ trait FakeContinuousReadSupport extends ContinuousReadSupport { options: DataSourceV2Options): ContinuousReader = FakeReader() } -trait FakeMicroBatchWriteSupport extends MicroBatchWriteSupport { - def createMicroBatchWriter( +trait FakeStreamWriteSupport extends StreamWriteSupport { + override def createStreamWriter( queryId: String, - epochId: Long, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): Optional[DataSourceV2Writer] = { - throw new IllegalStateException("fake sink - cannot actually write") - } -} - -trait FakeContinuousWriteSupport extends ContinuousWriteSupport { - def createContinuousWriter( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceV2Options): Optional[ContinuousWriter] = { + options: DataSourceV2Options): StreamWriter = { throw new IllegalStateException("fake sink - cannot actually write") } } @@ -102,23 +91,36 @@ class FakeReadNeitherMode extends DataSourceRegister { override def shortName(): String = "fake-read-neither-mode" } -class FakeWriteMicroBatchOnly extends DataSourceRegister with FakeMicroBatchWriteSupport { - override def shortName(): String = "fake-write-microbatch-only" +class FakeWrite extends DataSourceRegister with FakeStreamWriteSupport { + override def shortName(): String = "fake-write-microbatch-continuous" } -class FakeWriteContinuousOnly extends DataSourceRegister with FakeContinuousWriteSupport { - override def shortName(): String = "fake-write-continuous-only" +class FakeNoWrite extends DataSourceRegister { + override def shortName(): String = "fake-write-neither-mode" } -class FakeWriteBothModes extends DataSourceRegister - with FakeMicroBatchWriteSupport with FakeContinuousWriteSupport { - override def shortName(): String = "fake-write-microbatch-continuous" + +case class FakeWriteV1FallbackException() extends Exception + +class FakeSink extends Sink { + override def addBatch(batchId: Long, data: DataFrame): Unit = {} } -class FakeWriteNeitherMode extends DataSourceRegister { - override def shortName(): String = "fake-write-neither-mode" +class FakeWriteV1Fallback extends DataSourceRegister + with FakeStreamWriteSupport with StreamSinkProvider { + + override def createSink( + sqlContext: SQLContext, + parameters: Map[String, String], + partitionColumns: Seq[String], + outputMode: OutputMode): Sink = { + new FakeSink() + } + + override def shortName(): String = "fake-write-v1-fallback" } + class StreamingDataSourceV2Suite extends StreamTest { override def beforeAll(): Unit = { @@ -133,8 +135,6 @@ class StreamingDataSourceV2Suite extends StreamTest { "fake-read-microbatch-continuous", "fake-read-neither-mode") val writeFormats = Seq( - "fake-write-microbatch-only", - "fake-write-continuous-only", "fake-write-microbatch-continuous", "fake-write-neither-mode") val triggers = Seq( @@ -151,6 +151,7 @@ class StreamingDataSourceV2Suite extends StreamTest { .trigger(trigger) .start() query.stop() + query } private def testNegativeCase( @@ -184,6 +185,24 @@ class StreamingDataSourceV2Suite extends StreamTest { } } + test("disabled v2 write") { + // Ensure the V2 path works normally and generates a V2 sink.. + val v2Query = testPositiveCase( + "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) + assert(v2Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink + .isInstanceOf[FakeWriteV1Fallback]) + + // Ensure we create a V1 sink with the config. Note the config is a comma separated + // list, including other fake entries. + val fullSinkName = "org.apache.spark.sql.streaming.sources.FakeWriteV1Fallback" + withSQLConf(SQLConf.DISABLED_V2_STREAMING_WRITERS.key -> s"a,b,c,test,$fullSinkName,d,e") { + val v1Query = testPositiveCase( + "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) + assert(v1Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink + .isInstanceOf[FakeSink]) + } + } + // Get a list of (read, write, trigger) tuples for test cases. val cases = readFormats.flatMap { read => writeFormats.flatMap { write => @@ -199,12 +218,12 @@ class StreamingDataSourceV2Suite extends StreamTest { val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf).newInstance() (readSource, writeSource, trigger) match { // Valid microbatch queries. - case (_: MicroBatchReadSupport, _: MicroBatchWriteSupport, t) + case (_: MicroBatchReadSupport, _: StreamWriteSupport, t) if !t.isInstanceOf[ContinuousTrigger] => testPositiveCase(read, write, trigger) // Valid continuous queries. - case (_: ContinuousReadSupport, _: ContinuousWriteSupport, _: ContinuousTrigger) => + case (_: ContinuousReadSupport, _: StreamWriteSupport, _: ContinuousTrigger) => testPositiveCase(read, write, trigger) // Invalid - can't read at all @@ -214,31 +233,18 @@ class StreamingDataSourceV2Suite extends StreamTest { testNegativeCase(read, write, trigger, s"Data source $read does not support streamed reading") - // Invalid - trigger is continuous but writer is not - case (_, w, _: ContinuousTrigger) if !w.isInstanceOf[ContinuousWriteSupport] => - testNegativeCase(read, write, trigger, - s"Data source $write does not support continuous writing") - - // Invalid - can't write at all - case (_, w, _) - if !w.isInstanceOf[MicroBatchWriteSupport] - && !w.isInstanceOf[ContinuousWriteSupport] => + // Invalid - can't write + case (_, w, _) if !w.isInstanceOf[StreamWriteSupport] => testNegativeCase(read, write, trigger, s"Data source $write does not support streamed writing") - // Invalid - trigger and writer are continuous but reader is not - case (r, _: ContinuousWriteSupport, _: ContinuousTrigger) + // Invalid - trigger is continuous but reader is not + case (r, _: StreamWriteSupport, _: ContinuousTrigger) if !r.isInstanceOf[ContinuousReadSupport] => testNegativeCase(read, write, trigger, s"Data source $read does not support continuous processing") - // Invalid - trigger is microbatch but writer is not - case (_, w, t) - if !w.isInstanceOf[MicroBatchWriteSupport] && !t.isInstanceOf[ContinuousTrigger] => - testNegativeCase(read, write, trigger, - s"Data source $write does not support streamed writing") - - // Invalid - trigger and writer are microbatch but reader is not + // Invalid - trigger is microbatch but reader is not case (r, _, t) if !r.isInstanceOf[MicroBatchReadSupport] && !t.isInstanceOf[ContinuousTrigger] => testPostCreationNegativeCase(read, write, trigger, From 39d2c6b03488895a0acb1dd3c46329db00fdd357 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 29 Jan 2018 21:09:05 +0900 Subject: [PATCH 0224/2461] [SPARK-23238][SQL] Externalize SQLConf configurations exposed in documentation ## What changes were proposed in this pull request? This PR proposes to expose few internal configurations found in the documentation. Also it fixes the description for `spark.sql.execution.arrow.enabled`. It's quite self-explanatory. ## How was this patch tested? N/A Author: hyukjinkwon Closes #20403 from HyukjinKwon/minor-doc-arrow. --- .../org/apache/spark/sql/internal/SQLConf.scala | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2c70b004bcff9..61ea03d395afc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -123,14 +123,12 @@ object SQLConf { .createWithDefault(10) val COMPRESS_CACHED = buildConf("spark.sql.inMemoryColumnarStorage.compressed") - .internal() .doc("When set to true Spark SQL will automatically select a compression codec for each " + "column based on statistics of the data.") .booleanConf .createWithDefault(true) val COLUMN_BATCH_SIZE = buildConf("spark.sql.inMemoryColumnarStorage.batchSize") - .internal() .doc("Controls the size of batches for columnar caching. Larger batch sizes can improve " + "memory utilization and compression, but risk OOMs when caching data.") .intConf @@ -1043,11 +1041,11 @@ object SQLConf { val ARROW_EXECUTION_ENABLE = buildConf("spark.sql.execution.arrow.enabled") - .internal() - .doc("Make use of Apache Arrow for columnar data transfers. Currently available " + - "for use with pyspark.sql.DataFrame.toPandas with the following data types: " + - "StringType, BinaryType, BooleanType, DoubleType, FloatType, ByteType, IntegerType, " + - "LongType, ShortType") + .doc("When true, make use of Apache Arrow for columnar data transfers. Currently available " + + "for use with pyspark.sql.DataFrame.toPandas, and " + + "pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame. " + + "The following data types are unsupported: " + + "MapType, ArrayType of TimestampType, and nested StructType.") .booleanConf .createWithDefault(false) From badf0d0e0d1d9aa169ed655176ce9ae684d3905d Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Tue, 30 Jan 2018 00:50:49 +0800 Subject: [PATCH 0225/2461] [SPARK-23219][SQL] Rename ReadTask to DataReaderFactory in data source v2 ## What changes were proposed in this pull request? Currently we have `ReadTask` in data source v2 reader, while in writer we have `DataWriterFactory`. To make the naming consistent and better, renaming `ReadTask` to `DataReaderFactory`. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #20397 from gengliangwang/rename. --- .../sql/kafka010/KafkaContinuousReader.scala | 16 ++--- .../execution/UnsafeExternalRowSorter.java | 1 - .../v2/reader/ClusteredDistribution.java | 2 +- .../sql/sources/v2/reader/DataReader.java | 2 +- .../{ReadTask.java => DataReaderFactory.java} | 22 +++---- .../sources/v2/reader/DataSourceV2Reader.java | 11 ++-- .../sql/sources/v2/reader/Distribution.java | 6 +- .../sql/sources/v2/reader/Partitioning.java | 2 +- .../v2/reader/SupportsScanColumnarBatch.java | 11 ++-- .../v2/reader/SupportsScanUnsafeRow.java | 9 +-- .../v2/streaming/MicroBatchReadSupport.java | 4 +- .../v2/streaming/reader/ContinuousReader.java | 14 ++--- .../v2/streaming/reader/MicroBatchReader.java | 6 +- .../datasources/v2/DataSourceRDD.scala | 14 ++--- .../datasources/v2/DataSourceV2ScanExec.scala | 25 ++++---- .../ContinuousDataSourceRDDIter.scala | 11 ++-- .../ContinuousRateStreamSource.scala | 10 ++-- .../sources/RateStreamSourceV2.scala | 6 +- .../sources/v2/JavaAdvancedDataSourceV2.java | 20 +++---- .../sql/sources/v2/JavaBatchDataSourceV2.java | 10 ++-- .../v2/JavaPartitionAwareDataSource.java | 10 ++-- .../v2/JavaSchemaRequiredDataSource.java | 4 +- .../sources/v2/JavaSimpleDataSourceV2.java | 14 ++--- .../sources/v2/JavaUnsafeRowDataSourceV2.java | 13 ++-- .../streaming/RateSourceV2Suite.scala | 10 ++-- .../sql/sources/v2/DataSourceV2Suite.scala | 59 ++++++++++--------- .../sources/v2/SimpleWritableDataSource.scala | 12 ++-- .../sources/StreamingDataSourceV2Suite.scala | 4 +- 28 files changed, 172 insertions(+), 156 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{ReadTask.java => DataReaderFactory.java} (65%) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index fc977977504f7..9125cf5799d74 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -63,7 +63,7 @@ class KafkaContinuousReader( private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong - // Initialized when creating read tasks. If this diverges from the partitions at the latest + // Initialized when creating reader factories. If this diverges from the partitions at the latest // offsets, we need to reconfigure. // Exposed outside this object only for unit tests. private[sql] var knownPartitions: Set[TopicPartition] = _ @@ -89,7 +89,7 @@ class KafkaContinuousReader( KafkaSourceOffset(JsonUtils.partitionOffsets(json)) } - override def createUnsafeRowReadTasks(): ju.List[ReadTask[UnsafeRow]] = { + override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { import scala.collection.JavaConverters._ val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) @@ -109,9 +109,9 @@ class KafkaContinuousReader( startOffsets.toSeq.map { case (topicPartition, start) => - KafkaContinuousReadTask( + KafkaContinuousDataReaderFactory( topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) - .asInstanceOf[ReadTask[UnsafeRow]] + .asInstanceOf[DataReaderFactory[UnsafeRow]] }.asJava } @@ -149,8 +149,8 @@ class KafkaContinuousReader( } /** - * A read task for continuous Kafka processing. This will be serialized and transformed into a - * full reader on executors. + * A data reader factory for continuous Kafka processing. This will be serialized and transformed + * into a full reader on executors. * * @param topicPartition The (topic, partition) pair this task is responsible for. * @param startOffset The offset to start reading from within the partition. @@ -159,12 +159,12 @@ class KafkaContinuousReader( * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets * are skipped. */ -case class KafkaContinuousReadTask( +case class KafkaContinuousDataReaderFactory( topicPartition: TopicPartition, startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ReadTask[UnsafeRow] { + failOnDataLoss: Boolean) extends DataReaderFactory[UnsafeRow] { override def createDataReader(): KafkaContinuousDataReader = { new KafkaContinuousDataReader( topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 78647b56d621f..1b2f5eee5ccdd 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -20,7 +20,6 @@ import java.io.IOException; import java.util.function.Supplier; -import org.apache.spark.sql.catalyst.util.TypeUtils; import scala.collection.AbstractIterator; import scala.collection.Iterator; import scala.math.Ordering; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java index 7346500de45b6..27905e325df87 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java @@ -22,7 +22,7 @@ /** * A concrete implementation of {@link Distribution}. Represents a distribution where records that * share the same values for the {@link #clusteredColumns} will be produced by the same - * {@link ReadTask}. + * {@link DataReader}. */ @InterfaceStability.Evolving public class ClusteredDistribution implements Distribution { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java index 8f58c865b6201..bb9790a1c819e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java @@ -23,7 +23,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data reader returned by {@link ReadTask#createDataReader()} and is responsible for + * A data reader returned by {@link DataReaderFactory#createDataReader()} and is responsible for * outputting data for a RDD partition. * * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java similarity index 65% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java index fa161cdb8b347..077b95b837964 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java @@ -22,21 +22,23 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A read task returned by {@link DataSourceV2Reader#createReadTasks()} and is responsible for - * creating the actual data reader. The relationship between {@link ReadTask} and {@link DataReader} + * A reader factory returned by {@link DataSourceV2Reader#createDataReaderFactories()} and is + * responsible for creating the actual data reader. The relationship between + * {@link DataReaderFactory} and {@link DataReader} * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. * - * Note that, the read task will be serialized and sent to executors, then the data reader will be - * created on executors and do the actual reading. So {@link ReadTask} must be serializable and - * {@link DataReader} doesn't need to be. + * Note that, the reader factory will be serialized and sent to executors, then the data reader + * will be created on executors and do the actual reading. So {@link DataReaderFactory} must be + * serializable and {@link DataReader} doesn't need to be. */ @InterfaceStability.Evolving -public interface ReadTask extends Serializable { +public interface DataReaderFactory extends Serializable { /** - * The preferred locations where this read task can run faster, but Spark does not guarantee that - * this task will always run on these locations. The implementations should make sure that it can - * be run on any location. The location is a string representing the host name. + * The preferred locations where the data reader returned by this reader factory can run faster, + * but Spark does not guarantee to run the data reader on these locations. + * The implementations should make sure that it can be run on any location. + * The location is a string representing the host name. * * Note that if a host name cannot be recognized by Spark, it will be ignored as it was not in * the returned locations. By default this method returns empty string array, which means this @@ -50,7 +52,7 @@ default String[] preferredLocations() { } /** - * Returns a data reader to do the actual reading work for this read task. + * Returns a data reader to do the actual reading work. * * If this method fails (by throwing an exception), the corresponding Spark task would fail and * get retried until hitting the maximum retry times. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java index f23c3842bf1b1..0180cd9ea47f8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java @@ -30,7 +30,8 @@ * {@link org.apache.spark.sql.sources.v2.ReadSupportWithSchema#createReader( * StructType, org.apache.spark.sql.sources.v2.DataSourceV2Options)}. * It can mix in various query optimization interfaces to speed up the data scan. The actual scan - * logic is delegated to {@link ReadTask}s that are returned by {@link #createReadTasks()}. + * logic is delegated to {@link DataReaderFactory}s that are returned by + * {@link #createDataReaderFactories()}. * * There are mainly 3 kinds of query optimizations: * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column @@ -63,9 +64,9 @@ public interface DataSourceV2Reader { StructType readSchema(); /** - * Returns a list of read tasks. Each task is responsible for outputting data for one RDD - * partition. That means the number of tasks returned here is same as the number of RDD - * partitions this scan outputs. + * Returns a list of reader factories. Each factory is responsible for creating a data reader to + * output data for one RDD partition. That means the number of factories returned here is same as + * the number of RDD partitions this scan outputs. * * Note that, this may not be a full scan if the data source reader mixes in other optimization * interfaces like column pruning, filter push-down, etc. These optimizations are applied before @@ -74,5 +75,5 @@ public interface DataSourceV2Reader { * If this method fails (by throwing an exception), the action would fail and no Spark job was * submitted. */ - List> createReadTasks(); + List> createDataReaderFactories(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java index a6201a222f541..b37562167d9ef 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java @@ -21,9 +21,9 @@ /** * An interface to represent data distribution requirement, which specifies how the records should - * be distributed among the {@link ReadTask}s that are returned by - * {@link DataSourceV2Reader#createReadTasks()}. Note that this interface has nothing to do with - * the data ordering inside one partition(the output records of a single {@link ReadTask}). + * be distributed among the data partitions(one {@link DataReader} outputs data for one partition). + * Note that this interface has nothing to do with the data ordering inside one + * partition(the output records of a single {@link DataReader}). * * The instance of this interface is created and provided by Spark, then consumed by * {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java index 199e45d4a02ab..5e334d13a1215 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java @@ -29,7 +29,7 @@ public interface Partitioning { /** - * Returns the number of partitions(i.e., {@link ReadTask}s) the data source outputs. + * Returns the number of partitions(i.e., {@link DataReaderFactory}s) the data source outputs. */ int numPartitions(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java index 27cf3a77724f0..67da55554bbf3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java @@ -30,21 +30,22 @@ @InterfaceStability.Evolving public interface SupportsScanColumnarBatch extends DataSourceV2Reader { @Override - default List> createReadTasks() { + default List> createDataReaderFactories() { throw new IllegalStateException( - "createReadTasks not supported by default within SupportsScanColumnarBatch."); + "createDataReaderFactories not supported by default within SupportsScanColumnarBatch."); } /** - * Similar to {@link DataSourceV2Reader#createReadTasks()}, but returns columnar data in batches. + * Similar to {@link DataSourceV2Reader#createDataReaderFactories()}, but returns columnar data + * in batches. */ - List> createBatchReadTasks(); + List> createBatchDataReaderFactories(); /** * Returns true if the concrete data source reader can read data in batch according to the scan * properties like required columns, pushes filters, etc. It's possible that the implementation * can only support some certain columns with certain types. Users can overwrite this method and - * {@link #createReadTasks()} to fallback to normal read path under some conditions. + * {@link #createDataReaderFactories()} to fallback to normal read path under some conditions. */ default boolean enableBatchRead() { return true; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java index 2d3ad0eee65ff..156af69520f77 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java @@ -33,13 +33,14 @@ public interface SupportsScanUnsafeRow extends DataSourceV2Reader { @Override - default List> createReadTasks() { + default List> createDataReaderFactories() { throw new IllegalStateException( - "createReadTasks not supported by default within SupportsScanUnsafeRow"); + "createDataReaderFactories not supported by default within SupportsScanUnsafeRow"); } /** - * Similar to {@link DataSourceV2Reader#createReadTasks()}, but returns data in unsafe row format. + * Similar to {@link DataSourceV2Reader#createDataReaderFactories()}, + * but returns data in unsafe row format. */ - List> createUnsafeRowReadTasks(); + List> createUnsafeRowReaderFactories(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java index 3c87a3db68243..3b357c01a29fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java @@ -36,8 +36,8 @@ public interface MicroBatchReadSupport extends DataSourceV2 { * streaming query. * * The execution engine will create a micro-batch reader at the start of a streaming query, - * alternate calls to setOffsetRange and createReadTasks for each batch to process, and then - * call stop() when the execution is complete. Note that a single query may have multiple + * alternate calls to setOffsetRange and createDataReaderFactories for each batch to process, and + * then call stop() when the execution is complete. Note that a single query may have multiple * executions due to restart or failure recovery. * * @param schema the user provided schema, or empty() if none was provided diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java index 745f1ce502443..3ac979cb0b7b4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java @@ -27,7 +27,7 @@ * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this * interface to allow reading in a continuous processing mode stream. * - * Implementations must ensure each read task output is a {@link ContinuousDataReader}. + * Implementations must ensure each reader factory output is a {@link ContinuousDataReader}. * * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. @@ -47,9 +47,9 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceV2Reade Offset deserializeOffset(String json); /** - * Set the desired start offset for read tasks created from this reader. The scan will start - * from the first record after the provided offset, or from an implementation-defined inferred - * starting point if no offset is provided. + * Set the desired start offset for reader factories created from this reader. The scan will + * start from the first record after the provided offset, or from an implementation-defined + * inferred starting point if no offset is provided. */ void setOffset(Optional start); @@ -61,9 +61,9 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceV2Reade Offset getStartOffset(); /** - * The execution engine will call this method in every epoch to determine if new read tasks need - * to be generated, which may be required if for example the underlying source system has had - * partitions added or removed. + * The execution engine will call this method in every epoch to determine if new reader + * factories need to be generated, which may be required if for example the underlying + * source system has had partitions added or removed. * * If true, the query will be shut down and restarted with a new reader. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java index 02f37cebc7484..68887e569fc1d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java @@ -33,9 +33,9 @@ @InterfaceStability.Evolving public interface MicroBatchReader extends DataSourceV2Reader, BaseStreamingSource { /** - * Set the desired offset range for read tasks created from this reader. Read tasks will - * generate only data within (`start`, `end`]; that is, from the first record after `start` to - * the record with offset `end`. + * Set the desired offset range for reader factories created from this reader. Reader factories + * will generate only data within (`start`, `end`]; that is, from the first record after `start` + * to the record with offset `end`. * * @param start The initial offset to scan from. If not specified, scan from an * implementation-specified start point, such as the earliest available record. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index ac104d7cd0cb3..5ed0ba71e94c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -22,24 +22,24 @@ import scala.reflect.ClassTag import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.sources.v2.reader.ReadTask +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory -class DataSourceRDDPartition[T : ClassTag](val index: Int, val readTask: ReadTask[T]) +class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: DataReaderFactory[T]) extends Partition with Serializable class DataSourceRDD[T: ClassTag]( sc: SparkContext, - @transient private val readTasks: java.util.List[ReadTask[T]]) + @transient private val readerFactories: java.util.List[DataReaderFactory[T]]) extends RDD[T](sc, Nil) { override protected def getPartitions: Array[Partition] = { - readTasks.asScala.zipWithIndex.map { - case (readTask, index) => new DataSourceRDDPartition(index, readTask) + readerFactories.asScala.zipWithIndex.map { + case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) }.toArray } override def compute(split: Partition, context: TaskContext): Iterator[T] = { - val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readTask.createDataReader() + val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.createDataReader() context.addTaskCompletionListener(_ => reader.close()) val iter = new Iterator[T] { private[this] var valuePrepared = false @@ -63,6 +63,6 @@ class DataSourceRDD[T: ClassTag]( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition[T]].readTask.preferredLocations() + split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 2c22239e81869..3f808fbb40932 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -51,11 +51,11 @@ case class DataSourceV2ScanExec( case _ => super.outputPartitioning } - private lazy val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match { - case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks() + private lazy val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]] = reader match { + case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories() case _ => - reader.createReadTasks().asScala.map { - new RowToUnsafeRowReadTask(_, reader.readSchema()): ReadTask[UnsafeRow] + reader.createDataReaderFactories().asScala.map { + new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow] }.asJava } @@ -63,18 +63,19 @@ case class DataSourceV2ScanExec( case r: SupportsScanColumnarBatch if r.enableBatchRead() => assert(!reader.isInstanceOf[ContinuousReader], "continuous stream reader does not support columnar read yet.") - new DataSourceRDD(sparkContext, r.createBatchReadTasks()).asInstanceOf[RDD[InternalRow]] + new DataSourceRDD(sparkContext, r.createBatchDataReaderFactories()) + .asInstanceOf[RDD[InternalRow]] case _: ContinuousReader => EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) - .askSync[Unit](SetReaderPartitions(readTasks.size())) - new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks) + .askSync[Unit](SetReaderPartitions(readerFactories.size())) + new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories) .asInstanceOf[RDD[InternalRow]] case _ => - new DataSourceRDD(sparkContext, readTasks).asInstanceOf[RDD[InternalRow]] + new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]] } override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) @@ -99,14 +100,14 @@ case class DataSourceV2ScanExec( } } -class RowToUnsafeRowReadTask(rowReadTask: ReadTask[Row], schema: StructType) - extends ReadTask[UnsafeRow] { +class RowToUnsafeRowDataReaderFactory(rowReaderFactory: DataReaderFactory[Row], schema: StructType) + extends DataReaderFactory[UnsafeRow] { - override def preferredLocations: Array[String] = rowReadTask.preferredLocations + override def preferredLocations: Array[String] = rowReaderFactory.preferredLocations override def createDataReader: DataReader[UnsafeRow] = { new RowToUnsafeDataReader( - rowReadTask.createDataReader, RowEncoder.apply(schema).resolveAndBind()) + rowReaderFactory.createDataReader, RowEncoder.apply(schema).resolveAndBind()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index cd7065f5e6601..8a7a38b22caca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -39,15 +39,15 @@ import org.apache.spark.util.{SystemClock, ThreadUtils} class ContinuousDataSourceRDD( sc: SparkContext, sqlContext: SQLContext, - @transient private val readTasks: java.util.List[ReadTask[UnsafeRow]]) + @transient private val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]]) extends RDD[UnsafeRow](sc, Nil) { private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs override protected def getPartitions: Array[Partition] = { - readTasks.asScala.zipWithIndex.map { - case (readTask, index) => new DataSourceRDDPartition(index, readTask) + readerFactories.asScala.zipWithIndex.map { + case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) }.toArray } @@ -57,7 +57,8 @@ class ContinuousDataSourceRDD( throw new ContinuousTaskRetryException() } - val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader() + val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]] + .readerFactory.createDataReader() val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY) @@ -136,7 +137,7 @@ class ContinuousDataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.preferredLocations() + split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readerFactory.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index b4b21e7d2052f..61304480f4721 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -68,7 +68,7 @@ class RateStreamContinuousReader(options: DataSourceV2Options) override def getStartOffset(): Offset = offset - override def createReadTasks(): java.util.List[ReadTask[Row]] = { + override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { val partitionStartMap = offset match { case off: RateStreamOffset => off.partitionToValueAndRunTimeMs case off => @@ -86,13 +86,13 @@ class RateStreamContinuousReader(options: DataSourceV2Options) val start = partitionStartMap(i) // Have each partition advance by numPartitions each row, with starting points staggered // by their partition index. - RateStreamContinuousReadTask( + RateStreamContinuousDataReaderFactory( start.value, start.runTimeMs, i, numPartitions, perPartitionRate) - .asInstanceOf[ReadTask[Row]] + .asInstanceOf[DataReaderFactory[Row]] }.asJava } @@ -101,13 +101,13 @@ class RateStreamContinuousReader(options: DataSourceV2Options) } -case class RateStreamContinuousReadTask( +case class RateStreamContinuousDataReaderFactory( startValue: Long, startTimeMs: Long, partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ReadTask[Row] { + extends DataReaderFactory[Row] { override def createDataReader(): DataReader[Row] = new RateStreamContinuousDataReader( startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index c0ed12cec25ef..a25cc4f3b06f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -123,7 +123,7 @@ class RateStreamMicroBatchReader(options: DataSourceV2Options) RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def createReadTasks(): java.util.List[ReadTask[Row]] = { + override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { val startMap = start.partitionToValueAndRunTimeMs val endMap = end.partitionToValueAndRunTimeMs endMap.keys.toSeq.map { part => @@ -139,7 +139,7 @@ class RateStreamMicroBatchReader(options: DataSourceV2Options) outTimeMs += msPerPartitionBetweenRows } - RateStreamBatchTask(packedRows).asInstanceOf[ReadTask[Row]] + RateStreamBatchTask(packedRows).asInstanceOf[DataReaderFactory[Row]] }.toList.asJava } @@ -147,7 +147,7 @@ class RateStreamMicroBatchReader(options: DataSourceV2Options) override def stop(): Unit = {} } -case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends ReadTask[Row] { +case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactory[Row] { override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 1cfdc08217e6e..4026ee44bfdb7 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -60,8 +60,8 @@ public Filter[] pushedFilters() { } @Override - public List> createReadTasks() { - List> res = new ArrayList<>(); + public List> createDataReaderFactories() { + List> res = new ArrayList<>(); Integer lowerBound = null; for (Filter filter : filters) { @@ -75,25 +75,25 @@ public List> createReadTasks() { } if (lowerBound == null) { - res.add(new JavaAdvancedReadTask(0, 5, requiredSchema)); - res.add(new JavaAdvancedReadTask(5, 10, requiredSchema)); + res.add(new JavaAdvancedDataReaderFactory(0, 5, requiredSchema)); + res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema)); } else if (lowerBound < 4) { - res.add(new JavaAdvancedReadTask(lowerBound + 1, 5, requiredSchema)); - res.add(new JavaAdvancedReadTask(5, 10, requiredSchema)); + res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 5, requiredSchema)); + res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema)); } else if (lowerBound < 9) { - res.add(new JavaAdvancedReadTask(lowerBound + 1, 10, requiredSchema)); + res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 10, requiredSchema)); } return res; } } - static class JavaAdvancedReadTask implements ReadTask, DataReader { + static class JavaAdvancedDataReaderFactory implements DataReaderFactory, DataReader { private int start; private int end; private StructType requiredSchema; - JavaAdvancedReadTask(int start, int end, StructType requiredSchema) { + JavaAdvancedDataReaderFactory(int start, int end, StructType requiredSchema) { this.start = start; this.end = end; this.requiredSchema = requiredSchema; @@ -101,7 +101,7 @@ static class JavaAdvancedReadTask implements ReadTask, DataReader { @Override public DataReader createDataReader() { - return new JavaAdvancedReadTask(start - 1, end, requiredSchema); + return new JavaAdvancedDataReaderFactory(start - 1, end, requiredSchema); } @Override diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java index a5d77a90ece42..34e6c63801064 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -42,12 +42,14 @@ public StructType readSchema() { } @Override - public List> createBatchReadTasks() { - return java.util.Arrays.asList(new JavaBatchReadTask(0, 50), new JavaBatchReadTask(50, 90)); + public List> createBatchDataReaderFactories() { + return java.util.Arrays.asList( + new JavaBatchDataReaderFactory(0, 50), new JavaBatchDataReaderFactory(50, 90)); } } - static class JavaBatchReadTask implements ReadTask, DataReader { + static class JavaBatchDataReaderFactory + implements DataReaderFactory, DataReader { private int start; private int end; @@ -57,7 +59,7 @@ static class JavaBatchReadTask implements ReadTask, DataReader> createReadTasks() { + public List> createDataReaderFactories() { return java.util.Arrays.asList( - new SpecificReadTask(new int[]{1, 1, 3}, new int[]{4, 4, 6}), - new SpecificReadTask(new int[]{2, 4, 4}, new int[]{6, 2, 2})); + new SpecificDataReaderFactory(new int[]{1, 1, 3}, new int[]{4, 4, 6}), + new SpecificDataReaderFactory(new int[]{2, 4, 4}, new int[]{6, 2, 2})); } @Override @@ -70,12 +70,12 @@ public boolean satisfy(Distribution distribution) { } } - static class SpecificReadTask implements ReadTask, DataReader { + static class SpecificDataReaderFactory implements DataReaderFactory, DataReader { private int[] i; private int[] j; private int current = -1; - SpecificReadTask(int[] i, int[] j) { + SpecificDataReaderFactory(int[] i, int[] j) { assert i.length == j.length; this.i = i; this.j = j; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index a174bd8092cbd..f997366af1a64 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -24,7 +24,7 @@ import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; -import org.apache.spark.sql.sources.v2.reader.ReadTask; +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; import org.apache.spark.sql.types.StructType; public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema { @@ -42,7 +42,7 @@ public StructType readSchema() { } @Override - public List> createReadTasks() { + public List> createDataReaderFactories() { return java.util.Collections.emptyList(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 2d458b7f7e906..2beed431d301f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -26,7 +26,7 @@ import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.DataReader; -import org.apache.spark.sql.sources.v2.reader.ReadTask; +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; import org.apache.spark.sql.types.StructType; @@ -41,25 +41,25 @@ public StructType readSchema() { } @Override - public List> createReadTasks() { + public List> createDataReaderFactories() { return java.util.Arrays.asList( - new JavaSimpleReadTask(0, 5), - new JavaSimpleReadTask(5, 10)); + new JavaSimpleDataReaderFactory(0, 5), + new JavaSimpleDataReaderFactory(5, 10)); } } - static class JavaSimpleReadTask implements ReadTask, DataReader { + static class JavaSimpleDataReaderFactory implements DataReaderFactory, DataReader { private int start; private int end; - JavaSimpleReadTask(int start, int end) { + JavaSimpleDataReaderFactory(int start, int end) { this.start = start; this.end = end; } @Override public DataReader createDataReader() { - return new JavaSimpleReadTask(start - 1, end); + return new JavaSimpleDataReaderFactory(start - 1, end); } @Override diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java index f6aa00869a681..e8187524ea871 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java @@ -38,19 +38,20 @@ public StructType readSchema() { } @Override - public List> createUnsafeRowReadTasks() { + public List> createUnsafeRowReaderFactories() { return java.util.Arrays.asList( - new JavaUnsafeRowReadTask(0, 5), - new JavaUnsafeRowReadTask(5, 10)); + new JavaUnsafeRowDataReaderFactory(0, 5), + new JavaUnsafeRowDataReaderFactory(5, 10)); } } - static class JavaUnsafeRowReadTask implements ReadTask, DataReader { + static class JavaUnsafeRowDataReaderFactory + implements DataReaderFactory, DataReader { private int start; private int end; private UnsafeRow row; - JavaUnsafeRowReadTask(int start, int end) { + JavaUnsafeRowDataReaderFactory(int start, int end) { this.start = start; this.end = end; this.row = new UnsafeRow(2); @@ -59,7 +60,7 @@ static class JavaUnsafeRowReadTask implements ReadTask, DataReader createDataReader() { - return new JavaUnsafeRowReadTask(start - 1, end); + return new JavaUnsafeRowDataReaderFactory(start - 1, end); } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index 85085d43061bd..d2cfe7905f6fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -78,7 +78,7 @@ class RateSourceV2Suite extends StreamTest { val reader = new RateStreamMicroBatchReader( new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) reader.setOffsetRange(Optional.empty(), Optional.empty()) - val tasks = reader.createReadTasks() + val tasks = reader.createDataReaderFactories() assert(tasks.size == 11) } @@ -118,7 +118,7 @@ class RateSourceV2Suite extends StreamTest { val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createReadTasks() + val tasks = reader.createDataReaderFactories() assert(tasks.size == 1) assert(tasks.get(0).asInstanceOf[RateStreamBatchTask].vals.size == 20) } @@ -133,7 +133,7 @@ class RateSourceV2Suite extends StreamTest { }.toMap) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createReadTasks() + val tasks = reader.createDataReaderFactories() assert(tasks.size == 11) val readData = tasks.asScala @@ -161,12 +161,12 @@ class RateSourceV2Suite extends StreamTest { val reader = new RateStreamContinuousReader( new DataSourceV2Options(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) reader.setOffset(Optional.empty()) - val tasks = reader.createReadTasks() + val tasks = reader.createDataReaderFactories() assert(tasks.size == 2) val data = scala.collection.mutable.ListBuffer[Row]() tasks.asScala.foreach { - case t: RateStreamContinuousReadTask => + case t: RateStreamContinuousDataReaderFactory => val startTimeMs = reader.getStartOffset() .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 0620693b35d16..42c5d3bcea44b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -204,18 +204,20 @@ class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceV2Reader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createReadTasks(): JList[ReadTask[Row]] = { - java.util.Arrays.asList(new SimpleReadTask(0, 5), new SimpleReadTask(5, 10)) + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5), new SimpleDataReaderFactory(5, 10)) } } override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader } -class SimpleReadTask(start: Int, end: Int) extends ReadTask[Row] with DataReader[Row] { +class SimpleDataReaderFactory(start: Int, end: Int) + extends DataReaderFactory[Row] + with DataReader[Row] { private var current = start - 1 - override def createDataReader(): DataReader[Row] = new SimpleReadTask(start, end) + override def createDataReader(): DataReader[Row] = new SimpleDataReaderFactory(start, end) override def next(): Boolean = { current += 1 @@ -252,21 +254,21 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { requiredSchema } - override def createReadTasks(): JList[ReadTask[Row]] = { + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { val lowerBound = filters.collect { case GreaterThan("i", v: Int) => v }.headOption - val res = new ArrayList[ReadTask[Row]] + val res = new ArrayList[DataReaderFactory[Row]] if (lowerBound.isEmpty) { - res.add(new AdvancedReadTask(0, 5, requiredSchema)) - res.add(new AdvancedReadTask(5, 10, requiredSchema)) + res.add(new AdvancedDataReaderFactory(0, 5, requiredSchema)) + res.add(new AdvancedDataReaderFactory(5, 10, requiredSchema)) } else if (lowerBound.get < 4) { - res.add(new AdvancedReadTask(lowerBound.get + 1, 5, requiredSchema)) - res.add(new AdvancedReadTask(5, 10, requiredSchema)) + res.add(new AdvancedDataReaderFactory(lowerBound.get + 1, 5, requiredSchema)) + res.add(new AdvancedDataReaderFactory(5, 10, requiredSchema)) } else if (lowerBound.get < 9) { - res.add(new AdvancedReadTask(lowerBound.get + 1, 10, requiredSchema)) + res.add(new AdvancedDataReaderFactory(lowerBound.get + 1, 10, requiredSchema)) } res @@ -276,13 +278,13 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader } -class AdvancedReadTask(start: Int, end: Int, requiredSchema: StructType) - extends ReadTask[Row] with DataReader[Row] { +class AdvancedDataReaderFactory(start: Int, end: Int, requiredSchema: StructType) + extends DataReaderFactory[Row] with DataReader[Row] { private var current = start - 1 override def createDataReader(): DataReader[Row] = { - new AdvancedReadTask(start, end, requiredSchema) + new AdvancedDataReaderFactory(start, end, requiredSchema) } override def close(): Unit = {} @@ -307,16 +309,17 @@ class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceV2Reader with SupportsScanUnsafeRow { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createUnsafeRowReadTasks(): JList[ReadTask[UnsafeRow]] = { - java.util.Arrays.asList(new UnsafeRowReadTask(0, 5), new UnsafeRowReadTask(5, 10)) + override def createUnsafeRowReaderFactories(): JList[DataReaderFactory[UnsafeRow]] = { + java.util.Arrays.asList(new UnsafeRowDataReaderFactory(0, 5), + new UnsafeRowDataReaderFactory(5, 10)) } } override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader } -class UnsafeRowReadTask(start: Int, end: Int) - extends ReadTask[UnsafeRow] with DataReader[UnsafeRow] { +class UnsafeRowDataReaderFactory(start: Int, end: Int) + extends DataReaderFactory[UnsafeRow] with DataReader[UnsafeRow] { private val row = new UnsafeRow(2) row.pointTo(new Array[Byte](8 * 3), 8 * 3) @@ -341,7 +344,7 @@ class UnsafeRowReadTask(start: Int, end: Int) class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { class Reader(val readSchema: StructType) extends DataSourceV2Reader { - override def createReadTasks(): JList[ReadTask[Row]] = + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = java.util.Collections.emptyList() } @@ -354,16 +357,16 @@ class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceV2Reader with SupportsScanColumnarBatch { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createBatchReadTasks(): JList[ReadTask[ColumnarBatch]] = { - java.util.Arrays.asList(new BatchReadTask(0, 50), new BatchReadTask(50, 90)) + override def createBatchDataReaderFactories(): JList[DataReaderFactory[ColumnarBatch]] = { + java.util.Arrays.asList(new BatchDataReaderFactory(0, 50), new BatchDataReaderFactory(50, 90)) } } override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader } -class BatchReadTask(start: Int, end: Int) - extends ReadTask[ColumnarBatch] with DataReader[ColumnarBatch] { +class BatchDataReaderFactory(start: Int, end: Int) + extends DataReaderFactory[ColumnarBatch] with DataReader[ColumnarBatch] { private final val BATCH_SIZE = 20 private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) @@ -406,11 +409,11 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { class Reader extends DataSourceV2Reader with SupportsReportPartitioning { override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") - override def createReadTasks(): JList[ReadTask[Row]] = { + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { // Note that we don't have same value of column `a` across partitions. java.util.Arrays.asList( - new SpecificReadTask(Array(1, 1, 3), Array(4, 4, 6)), - new SpecificReadTask(Array(2, 4, 4), Array(6, 2, 2))) + new SpecificDataReaderFactory(Array(1, 1, 3), Array(4, 4, 6)), + new SpecificDataReaderFactory(Array(2, 4, 4), Array(6, 2, 2))) } override def outputPartitioning(): Partitioning = new MyPartitioning @@ -428,7 +431,9 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader } -class SpecificReadTask(i: Array[Int], j: Array[Int]) extends ReadTask[Row] with DataReader[Row] { +class SpecificDataReaderFactory(i: Array[Int], j: Array[Int]) + extends DataReaderFactory[Row] + with DataReader[Row] { assert(i.length == j.length) private var current = -1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index cd7252eb2e3d6..3310d6dd199d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataSourceV2Reader, ReadTask} +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceV2Reader} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -45,7 +45,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS class Reader(path: String, conf: Configuration) extends DataSourceV2Reader { override def readSchema(): StructType = schema - override def createReadTasks(): JList[ReadTask[Row]] = { + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { val dataPath = new Path(path) val fs = dataPath.getFileSystem(conf) if (fs.exists(dataPath)) { @@ -54,7 +54,9 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS name.startsWith("_") || name.startsWith(".") }.map { f => val serializableConf = new SerializableConfiguration(conf) - new SimpleCSVReadTask(f.getPath.toUri.toString, serializableConf): ReadTask[Row] + new SimpleCSVDataReaderFactory( + f.getPath.toUri.toString, + serializableConf): DataReaderFactory[Row] }.toList.asJava } else { Collections.emptyList() @@ -149,8 +151,8 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } } -class SimpleCSVReadTask(path: String, conf: SerializableConfiguration) - extends ReadTask[Row] with DataReader[Row] { +class SimpleCSVDataReaderFactory(path: String, conf: SerializableConfiguration) + extends DataReaderFactory[Row] with DataReader[Row] { @transient private var lines: Iterator[String] = _ @transient private var currentLine: String = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index d4f8bae96695d..dc8c857018457 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.reader.ReadTask +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory import org.apache.spark.sql.sources.v2.streaming._ import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter @@ -45,7 +45,7 @@ case class FakeReader() extends MicroBatchReader with ContinuousReader { def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) def setOffset(start: Optional[Offset]): Unit = {} - def createReadTasks(): java.util.ArrayList[ReadTask[Row]] = { + def createDataReaderFactories(): java.util.ArrayList[DataReaderFactory[Row]] = { throw new IllegalStateException("fake source - cannot actually read") } } From 54dd7cf4ef921bc9dc12f99cfb90d1da57939901 Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Mon, 29 Jan 2018 08:56:42 -0800 Subject: [PATCH 0226/2461] [SPARK-23199][SQL] improved Removes repetition from group expressions in Aggregate ## What changes were proposed in this pull request? Currently, all Aggregate operations will go into RemoveRepetitionFromGroupExpressions, but there is no group expression or there is no duplicate group expression in group expression, we not need copy for logic plan. ## How was this patch tested? the existed test case. Author: caoxuewen Closes #20375 from heary-cao/RepetitionGroupExpressions. --- .../apache/spark/sql/catalyst/optimizer/Optimizer.scala | 8 ++++++-- .../sql/catalyst/optimizer/AggregateOptimizeSuite.scala | 5 ++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 8d207708c12ad..a28b6a0feb8f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1302,8 +1302,12 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { */ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case a @ Aggregate(grouping, _, _) => + case a @ Aggregate(grouping, _, _) if grouping.size > 1 => val newGrouping = ExpressionSet(grouping).toSeq - a.copy(groupingExpressions = newGrouping) + if (newGrouping.size == grouping.size) { + a + } else { + a.copy(groupingExpressions = newGrouping) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index a3184a4266c7c..f8ddc93597070 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -67,10 +67,9 @@ class AggregateOptimizeSuite extends PlanTest { } test("remove repetition in grouping expression") { - val input = LocalRelation('a.int, 'b.int, 'c.int) - val query = input.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c)) + val query = testRelation.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c)) val optimized = Optimize.execute(analyzer.execute(query)) - val correctAnswer = input.groupBy('a + 1, 'b + 2)(sum('c)).analyze + val correctAnswer = testRelation.groupBy('a + 1, 'b + 2)(sum('c)).analyze comparePlans(optimized, correctAnswer) } From fbce2ed0fa5c3e9fb2bdf9d9741eb3ff0760f88c Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Mon, 29 Jan 2018 08:58:14 -0800 Subject: [PATCH 0227/2461] [SPARK-23059][SQL][TEST] Correct some improper with view related method usage ## What changes were proposed in this pull request? Correct some improper with view related method usage Only change test cases like: ``` test("list global temp views") { try { sql("CREATE GLOBAL TEMP VIEW v1 AS SELECT 3, 4") sql("CREATE TEMP VIEW v2 AS SELECT 1, 2") checkAnswer(sql(s"SHOW TABLES IN $globalTempDB"), Row(globalTempDB, "v1", true) :: Row("", "v2", true) :: Nil) assert(spark.catalog.listTables(globalTempDB).collect().toSeq.map(_.name) == Seq("v1", "v2")) } finally { spark.catalog.dropTempView("v1") spark.catalog.dropGlobalTempView("v2") } } ``` other change please review the code. ## How was this patch tested? See test case. Author: xubo245 <601450868@qq.com> Closes #20250 from xubo245/DropTempViewError. --- .../org/apache/spark/sql/SQLQuerySuite.scala | 48 ++++++++++--------- .../sql/execution/GlobalTempViewSuite.scala | 4 +- .../spark/sql/execution/SQLViewSuite.scala | 36 ++++++++------ .../sql/execution/command/DDLSuite.scala | 2 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 2 +- .../sql/hive/execution/HiveSQLViewSuite.scala | 26 +++++----- .../sql/hive/execution/SQLQuerySuite.scala | 44 +++++++++-------- 7 files changed, 88 insertions(+), 74 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a79ab47f0197e..ffd736d2ebbb6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1565,36 +1565,38 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("specifying database name for a temporary view is not allowed") { withTempPath { dir => - val path = dir.toURI.toString - val df = - sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") - df - .write - .format("parquet") - .save(path) - - // We don't support creating a temporary table while specifying a database - intercept[AnalysisException] { + withTempView("db.t") { + val path = dir.toURI.toString + val df = + sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + df + .write + .format("parquet") + .save(path) + + // We don't support creating a temporary table while specifying a database + intercept[AnalysisException] { + spark.sql( + s""" + |CREATE TEMPORARY VIEW db.t + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + + // If you use backticks to quote the name then it's OK. spark.sql( s""" - |CREATE TEMPORARY VIEW db.t + |CREATE TEMPORARY VIEW `db.t` |USING parquet |OPTIONS ( | path '$path' |) """.stripMargin) - }.getMessage - - // If you use backticks to quote the name then it's OK. - spark.sql( - s""" - |CREATE TEMPORARY VIEW `db.t` - |USING parquet - |OPTIONS ( - | path '$path' - |) - """.stripMargin) - checkAnswer(spark.table("`db.t`"), df) + checkAnswer(spark.table("`db.t`"), df) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala index dcc6fa6403f31..972b47e96fe06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -134,8 +134,8 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { assert(spark.catalog.listTables(globalTempDB).collect().toSeq.map(_.name) == Seq("v1", "v2")) } finally { - spark.catalog.dropTempView("v1") - spark.catalog.dropGlobalTempView("v2") + spark.catalog.dropGlobalTempView("v1") + spark.catalog.dropTempView("v2") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index ce8fde28a941c..8269d4d3a285d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -53,15 +53,17 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } test("create a temp view on a permanent view") { - withView("jtv1", "temp_jtv1") { - sql("CREATE VIEW jtv1 AS SELECT * FROM jt WHERE id > 3") - sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jtv1 WHERE id < 6") - checkAnswer(sql("select count(*) FROM temp_jtv1"), Row(2)) + withView("jtv1") { + withTempView("temp_jtv1") { + sql("CREATE VIEW jtv1 AS SELECT * FROM jt WHERE id > 3") + sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jtv1 WHERE id < 6") + checkAnswer(sql("select count(*) FROM temp_jtv1"), Row(2)) + } } } test("create a temp view on a temp view") { - withView("temp_jtv1", "temp_jtv2") { + withTempView("temp_jtv1", "temp_jtv2") { sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jt WHERE id > 3") sql("CREATE TEMPORARY VIEW temp_jtv2 AS SELECT * FROM temp_jtv1 WHERE id < 6") checkAnswer(sql("select count(*) FROM temp_jtv2"), Row(2)) @@ -222,10 +224,12 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } test("error handling: disallow IF NOT EXISTS for CREATE TEMPORARY VIEW") { - val e = intercept[AnalysisException] { - sql("CREATE TEMPORARY VIEW IF NOT EXISTS myabcdview AS SELECT * FROM jt") + withTempView("myabcdview") { + val e = intercept[AnalysisException] { + sql("CREATE TEMPORARY VIEW IF NOT EXISTS myabcdview AS SELECT * FROM jt") + } + assert(e.message.contains("It is not allowed to define a TEMPORARY view with IF NOT EXISTS")) } - assert(e.message.contains("It is not allowed to define a TEMPORARY view with IF NOT EXISTS")) } test("error handling: fail if the temp view sql itself is invalid") { @@ -274,7 +278,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } test("correctly parse CREATE TEMPORARY VIEW statement") { - withView("testView") { + withTempView("testView") { sql( """CREATE TEMPORARY VIEW |testView (c1 COMMENT 'blabla', c2 COMMENT 'blabla') @@ -286,7 +290,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } test("should NOT allow CREATE TEMPORARY VIEW when TEMPORARY VIEW with same name exists") { - withView("testView") { + withTempView("testView") { sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") val e = intercept[AnalysisException] { @@ -299,15 +303,19 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { test("should allow CREATE TEMPORARY VIEW when a permanent VIEW with same name exists") { withView("testView", "default.testView") { - sql("CREATE VIEW testView AS SELECT id FROM jt") - sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") + withTempView("testView") { + sql("CREATE VIEW testView AS SELECT id FROM jt") + sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") + } } } test("should allow CREATE permanent VIEW when a TEMPORARY VIEW with same name exists") { withView("testView", "default.testView") { - sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") - sql("CREATE VIEW testView AS SELECT id FROM jt") + withTempView("testView") { + sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") + sql("CREATE VIEW testView AS SELECT id FROM jt") + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 6ca21b5aa1595..ee3674ba17821 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -739,7 +739,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { // starts with 'jar:', and it is an illegal parameter for Path, so here we copy it // to a temp file by withResourceTempPath withResourceTempPath("test-data/cars.csv") { tmpFile => - withView("testview") { + withTempView("testview") { sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1 String, c2 String) USING " + "org.apache.spark.sql.execution.datasources.csv.CSVFileFormat " + s"OPTIONS (PATH '${tmpFile.toURI}')") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index fade143a1755e..859099a321bf7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -1151,7 +1151,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv test("create a temp view using hive") { val tableName = "tab1" - withTable(tableName) { + withTempView(tableName) { val e = intercept[AnalysisException] { sql( s""" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala index 97e4c2b6b2db8..5e6e114fc3fdc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala @@ -67,20 +67,22 @@ class HiveSQLViewSuite extends SQLViewSuite with TestHiveSingleton { classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDFUpper].getCanonicalName withUserDefinedFunction(tempFunctionName -> true) { sql(s"CREATE TEMPORARY FUNCTION $tempFunctionName AS '$functionClass'") - withView("view1", "tempView1") { - withTable("tab1") { - (1 to 10).map(i => s"$i").toDF("id").write.saveAsTable("tab1") + withView("view1") { + withTempView("tempView1") { + withTable("tab1") { + (1 to 10).map(i => s"$i").toDF("id").write.saveAsTable("tab1") - // temporary view - sql(s"CREATE TEMPORARY VIEW tempView1 AS SELECT $tempFunctionName(id) from tab1") - checkAnswer(sql("select count(*) FROM tempView1"), Row(10)) + // temporary view + sql(s"CREATE TEMPORARY VIEW tempView1 AS SELECT $tempFunctionName(id) from tab1") + checkAnswer(sql("select count(*) FROM tempView1"), Row(10)) - // permanent view - val e = intercept[AnalysisException] { - sql(s"CREATE VIEW view1 AS SELECT $tempFunctionName(id) from tab1") - }.getMessage - assert(e.contains("Not allowed to create a permanent view `view1` by referencing " + - s"a temporary function `$tempFunctionName`")) + // permanent view + val e = intercept[AnalysisException] { + sql(s"CREATE VIEW view1 AS SELECT $tempFunctionName(id) from tab1") + }.getMessage + assert(e.contains("Not allowed to create a permanent view `view1` by referencing " + + s"a temporary function `$tempFunctionName`")) + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 33bcae91fdaf4..baabc4a3bca2c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1203,35 +1203,37 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("specifying database name for a temporary view is not allowed") { withTempPath { dir => - val path = dir.toURI.toString - val df = sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") - df - .write - .format("parquet") - .save(path) - - // We don't support creating a temporary table while specifying a database - intercept[AnalysisException] { + withTempView("db.t") { + val path = dir.toURI.toString + val df = sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + df + .write + .format("parquet") + .save(path) + + // We don't support creating a temporary table while specifying a database + intercept[AnalysisException] { + spark.sql( + s""" + |CREATE TEMPORARY VIEW db.t + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + } + + // If you use backticks to quote the name then it's OK. spark.sql( s""" - |CREATE TEMPORARY VIEW db.t + |CREATE TEMPORARY VIEW `db.t` |USING parquet |OPTIONS ( | path '$path' |) """.stripMargin) + checkAnswer(spark.table("`db.t`"), df) } - - // If you use backticks to quote the name then it's OK. - spark.sql( - s""" - |CREATE TEMPORARY VIEW `db.t` - |USING parquet - |OPTIONS ( - | path '$path' - |) - """.stripMargin) - checkAnswer(spark.table("`db.t`"), df) } } From 2d903cf9d3a827e54217dfc9f1e4be99d8204387 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 29 Jan 2018 09:00:54 -0800 Subject: [PATCH 0228/2461] [SPARK-23223][SQL] Make stacking dataset transforms more performant ## What changes were proposed in this pull request? It is a common pattern to apply multiple transforms to a `Dataset` (using `Dataset.withColumn` for example. This is currently quite expensive because we run `CheckAnalysis` on the full plan and create an encoder for each intermediate `Dataset`. This PR extends the usage of the `AnalysisBarrier` to include `CheckAnalysis`. By doing this we hide the already analyzed plan from `CheckAnalysis` because barrier is a `LeafNode`. The `AnalysisBarrier` is in the `FinishAnalysis` phase of the optimizer. We also make binding the `Dataset` encoder lazy. The bound encoder is only needed when we materialize the dataset. ## How was this patch tested? Existing test should cover this. Author: Herman van Hovell Closes #20402 from hvanhovell/SPARK-23223. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 16 ++++++++++++++-- .../sql/catalyst/analysis/CheckAnalysis.scala | 1 + .../sql/catalyst/analysis/AnalysisTest.scala | 3 +-- .../scala/org/apache/spark/sql/Dataset.scala | 8 ++++++-- .../spark/sql/execution/QueryExecution.scala | 16 ++-------------- .../apache/spark/sql/hive/test/TestHive.scala | 2 +- 6 files changed, 25 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2b14c8220d43b..91cb0365a0856 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -98,6 +98,19 @@ class Analyzer( this(catalog, conf, conf.optimizerMaxIterations) } + def executeAndCheck(plan: LogicalPlan): LogicalPlan = { + val analyzed = execute(plan) + try { + checkAnalysis(analyzed) + EliminateBarriers(analyzed) + } catch { + case e: AnalysisException => + val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed)) + ae.setStackTrace(e.getStackTrace) + throw ae + } + } + override def execute(plan: LogicalPlan): LogicalPlan = { AnalysisContext.reset() try { @@ -178,8 +191,7 @@ class Analyzer( Batch("Subquery", Once, UpdateOuterReferences), Batch("Cleanup", fixedPoint, - CleanupAliases, - EliminateBarriers) + CleanupAliases) ) /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index ef91d79f3302c..90bda2a72ad82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -356,6 +356,7 @@ trait CheckAnalysis extends PredicateHelper { } extendedCheckRules.foreach(_(plan)) plan.foreachUp { + case AnalysisBarrier(child) if !child.resolved => checkAnalysis(child) case o if !o.resolved => failAnalysis(s"unresolved operator ${o.simpleString}") case _ => } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 549a4355dfba3..3d7c91870133b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -54,8 +54,7 @@ trait AnalysisTest extends PlanTest { expectedPlan: LogicalPlan, caseSensitive: Boolean = true): Unit = { val analyzer = getAnalyzer(caseSensitive) - val actualPlan = analyzer.execute(inputPlan) - analyzer.checkAnalysis(actualPlan) + val actualPlan = analyzer.executeAndCheck(inputPlan) comparePlans(actualPlan, expectedPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index edb6644ed5ac0..cc5b647b3f037 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -62,7 +62,11 @@ import org.apache.spark.util.Utils private[sql] object Dataset { def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = { - new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]]) + val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]]) + // Eagerly bind the encoder so we verify that the encoder matches the underlying + // schema. The user will get an error if this is not the case. + dataset.deserializer + dataset } def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = { @@ -204,7 +208,7 @@ class Dataset[T] private[sql]( // The deserializer expression which can be used to build a projection and turn rows to objects // of type T, after collecting rows to the driver side. - private val deserializer = + private lazy val deserializer = exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer).deserializer private implicit def classTag = exprEnc.clsTag diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 8bfe3eff0c3b3..7cae24bf5976c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -44,19 +44,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { // TODO: Move the planner an optimizer into here from SessionState. protected def planner = sparkSession.sessionState.planner - def assertAnalyzed(): Unit = { - // Analyzer is invoked outside the try block to avoid calling it again from within the - // catch block below. - analyzed - try { - sparkSession.sessionState.analyzer.checkAnalysis(analyzed) - } catch { - case e: AnalysisException => - val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed)) - ae.setStackTrace(e.getStackTrace) - throw ae - } - } + def assertAnalyzed(): Unit = analyzed def assertSupported(): Unit = { if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { @@ -66,7 +54,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { lazy val analyzed: LogicalPlan = { SparkSession.setActiveSession(sparkSession) - sparkSession.sessionState.analyzer.execute(logical) + sparkSession.sessionState.analyzer.executeAndCheck(logical) } lazy val withCachedData: LogicalPlan = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 7287e20d55bbe..59708e7a0f2ff 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -575,7 +575,7 @@ private[hive] class TestHiveQueryExecution( logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(sparkSession.loadTestTable) // Proceed with analysis. - sparkSession.sessionState.analyzer.execute(logical) + sparkSession.sessionState.analyzer.executeAndCheck(logical) } } From 0d60b3213fe9a7ae5e9b208639f92011fdb2ca32 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 29 Jan 2018 10:25:25 -0800 Subject: [PATCH 0229/2461] [SPARK-22221][DOCS] Adding User Documentation for Arrow ## What changes were proposed in this pull request? Adding user facing documentation for working with Arrow in Spark Author: Bryan Cutler Author: Li Jin Author: hyukjinkwon Closes #19575 from BryanCutler/arrow-user-docs-SPARK-2221. --- docs/sql-programming-guide.md | 134 +++++++++++++++++++++++++- examples/src/main/python/sql/arrow.py | 129 +++++++++++++++++++++++++ 2 files changed, 262 insertions(+), 1 deletion(-) create mode 100644 examples/src/main/python/sql/arrow.py diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 502c0a8c37e01..d49c8d869cba6 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1640,6 +1640,138 @@ Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` a You may run `./bin/spark-sql --help` for a complete list of all available options. +# PySpark Usage Guide for Pandas with Apache Arrow + +## Apache Arrow in Spark + +Apache Arrow is an in-memory columnar data format that is used in Spark to efficiently transfer +data between JVM and Python processes. This currently is most beneficial to Python users that +work with Pandas/NumPy data. Its usage is not automatic and might require some minor +changes to configuration or code to take full advantage and ensure compatibility. This guide will +give a high-level description of how to use Arrow in Spark and highlight any differences when +working with Arrow-enabled data. + +### Ensure PyArrow Installed + +If you install PySpark using pip, then PyArrow can be brought in as an extra dependency of the +SQL module with the command `pip install pyspark[sql]`. Otherwise, you must ensure that PyArrow +is installed and available on all cluster nodes. The current supported version is 0.8.0. +You can install using pip or conda from the conda-forge channel. See PyArrow +[installation](https://arrow.apache.org/docs/python/install.html) for details. + +## Enabling for Conversion to/from Pandas + +Arrow is available as an optimization when converting a Spark DataFrame to a Pandas DataFrame +using the call `toPandas()` and when creating a Spark DataFrame from a Pandas DataFrame with +`createDataFrame(pandas_df)`. To use Arrow when executing these calls, users need to first set +the Spark configuration 'spark.sql.execution.arrow.enabled' to 'true'. This is disabled by default. + +
+
+{% include_example dataframe_with_arrow python/sql/arrow.py %} +
+
+ +Using the above optimizations with Arrow will produce the same results as when Arrow is not +enabled. Note that even with Arrow, `toPandas()` results in the collection of all records in the +DataFrame to the driver program and should be done on a small subset of the data. Not all Spark +data types are currently supported and an error can be raised if a column has an unsupported type, +see [Supported Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`, +Spark will fall back to create the DataFrame without Arrow. + +## Pandas UDFs (a.k.a. Vectorized UDFs) + +Pandas UDFs are user defined functions that are executed by Spark using Arrow to transfer data and +Pandas to work with the data. A Pandas UDF is defined using the keyword `pandas_udf` as a decorator +or to wrap the function, no additional configuration is required. Currently, there are two types of +Pandas UDF: Scalar and Group Map. + +### Scalar + +Scalar Pandas UDFs are used for vectorizing scalar operations. They can be used with functions such +as `select` and `withColumn`. The Python function should take `pandas.Series` as inputs and return +a `pandas.Series` of the same length. Internally, Spark will execute a Pandas UDF by splitting +columns into batches and calling the function for each batch as a subset of the data, then +concatenating the results together. + +The following example shows how to create a scalar Pandas UDF that computes the product of 2 columns. + +
+
+{% include_example scalar_pandas_udf python/sql/arrow.py %} +
+
+ +### Group Map +Group map Pandas UDFs are used with `groupBy().apply()` which implements the "split-apply-combine" pattern. +Split-apply-combine consists of three steps: +* Split the data into groups by using `DataFrame.groupBy`. +* Apply a function on each group. The input and output of the function are both `pandas.DataFrame`. The + input data contains all the rows and columns for each group. +* Combine the results into a new `DataFrame`. + +To use `groupBy().apply()`, the user needs to define the following: +* A Python function that defines the computation for each group. +* A `StructType` object or a string that defines the schema of the output `DataFrame`. + +Note that all data for a group will be loaded into memory before the function is applied. This can +lead to out of memory exceptons, especially if the group sizes are skewed. The configuration for +[maxRecordsPerBatch](#setting-arrow-batch-size) is not applied on groups and it is up to the user +to ensure that the grouped data will fit into the available memory. + +The following example shows how to use `groupby().apply()` to subtract the mean from each value in the group. + +
+
+{% include_example group_map_pandas_udf python/sql/arrow.py %} +
+
+ +For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) and +[`pyspark.sql.GroupedData.apply`](api/python/pyspark.sql.html#pyspark.sql.GroupedData.apply). + +## Usage Notes + +### Supported SQL Types + +Currently, all Spark SQL data types are supported by Arrow-based conversion except `MapType`, +`ArrayType` of `TimestampType`, and nested `StructType`. + +### Setting Arrow Batch Size + +Data partitions in Spark are converted into Arrow record batches, which can temporarily lead to +high memory usage in the JVM. To avoid possible out of memory exceptions, the size of the Arrow +record batches can be adjusted by setting the conf "spark.sql.execution.arrow.maxRecordsPerBatch" +to an integer that will determine the maximum number of rows for each batch. The default value is +10,000 records per batch. If the number of columns is large, the value should be adjusted +accordingly. Using this limit, each data partition will be made into 1 or more record batches for +processing. + +### Timestamp with Time Zone Semantics + +Spark internally stores timestamps as UTC values, and timestamp data that is brought in without +a specified time zone is converted as local time to UTC with microsecond resolution. When timestamp +data is exported or displayed in Spark, the session time zone is used to localize the timestamp +values. The session time zone is set with the configuration 'spark.sql.session.timeZone' and will +default to the JVM system local time zone if not set. Pandas uses a `datetime64` type with nanosecond +resolution, `datetime64[ns]`, with optional time zone on a per-column basis. + +When timestamp data is transferred from Spark to Pandas it will be converted to nanoseconds +and each column will be converted to the Spark session time zone then localized to that time +zone, which removes the time zone and displays values as local time. This will occur +when calling `toPandas()` or `pandas_udf` with timestamp columns. + +When timestamp data is transferred from Pandas to Spark, it will be converted to UTC microseconds. This +occurs when calling `createDataFrame` with a Pandas DataFrame or when returning a timestamp from a +`pandas_udf`. These conversions are done automatically to ensure Spark will have data in the +expected format, so it is not necessary to do any of these conversions yourself. Any nanosecond +values will be truncated. + +Note that a standard UDF (non-Pandas) will load timestamp data as Python datetime objects, which is +different than a Pandas timestamp. It is recommended to use Pandas time series functionality when +working with timestamps in `pandas_udf`s to get the best performance, see +[here](https://pandas.pydata.org/pandas-docs/stable/timeseries.html) for details. + # Migration Guide ## Upgrading From Spark SQL 2.2 to 2.3 @@ -1788,7 +1920,7 @@ options. Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other combinations of scales and precisions because currently we only infer decimal type like `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type. - In PySpark, now we need Pandas 0.19.2 or upper if you want to use Pandas related functionalities, such as `toPandas`, `createDataFrame` from Pandas DataFrame, etc. - In PySpark, the behavior of timestamp values for Pandas related functionalities was changed to respect session timezone. If you want to use the old behavior, you need to set a configuration `spark.sql.execution.pandas.respectSessionTimeZone` to `False`. See [SPARK-22395](https://issues.apache.org/jira/browse/SPARK-22395) for details. - - In PySpark, `na.fill()` or `fillna` also accepts boolean and replaces nulls with booleans. In prior Spark versions, PySpark just ignores it and returns the original Dataset/DataFrame. + - In PySpark, `na.fill()` or `fillna` also accepts boolean and replaces nulls with booleans. In prior Spark versions, PySpark just ignores it and returns the original Dataset/DataFrame. - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489). - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py new file mode 100644 index 0000000000000..6c0028b3f1c1f --- /dev/null +++ b/examples/src/main/python/sql/arrow.py @@ -0,0 +1,129 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +A simple example demonstrating Arrow in Spark. +Run with: + ./bin/spark-submit examples/src/main/python/sql/arrow.py +""" + +from __future__ import print_function + +from pyspark.sql import SparkSession +from pyspark.sql.utils import require_minimum_pandas_version, require_minimum_pyarrow_version + +require_minimum_pandas_version() +require_minimum_pyarrow_version() + + +def dataframe_with_arrow_example(spark): + # $example on:dataframe_with_arrow$ + import numpy as np + import pandas as pd + + # Enable Arrow-based columnar data transfers + spark.conf.set("spark.sql.execution.arrow.enabled", "true") + + # Generate a Pandas DataFrame + pdf = pd.DataFrame(np.random.rand(100, 3)) + + # Create a Spark DataFrame from a Pandas DataFrame using Arrow + df = spark.createDataFrame(pdf) + + # Convert the Spark DataFrame back to a Pandas DataFrame using Arrow + result_pdf = df.select("*").toPandas() + # $example off:dataframe_with_arrow$ + print("Pandas DataFrame result statistics:\n%s\n" % str(result_pdf.describe())) + + +def scalar_pandas_udf_example(spark): + # $example on:scalar_pandas_udf$ + import pandas as pd + + from pyspark.sql.functions import col, pandas_udf + from pyspark.sql.types import LongType + + # Declare the function and create the UDF + def multiply_func(a, b): + return a * b + + multiply = pandas_udf(multiply_func, returnType=LongType()) + + # The function for a pandas_udf should be able to execute with local Pandas data + x = pd.Series([1, 2, 3]) + print(multiply_func(x, x)) + # 0 1 + # 1 4 + # 2 9 + # dtype: int64 + + # Create a Spark DataFrame, 'spark' is an existing SparkSession + df = spark.createDataFrame(pd.DataFrame(x, columns=["x"])) + + # Execute function as a Spark vectorized UDF + df.select(multiply(col("x"), col("x"))).show() + # +-------------------+ + # |multiply_func(x, x)| + # +-------------------+ + # | 1| + # | 4| + # | 9| + # +-------------------+ + # $example off:scalar_pandas_udf$ + + +def group_map_pandas_udf_example(spark): + # $example on:group_map_pandas_udf$ + from pyspark.sql.functions import pandas_udf, PandasUDFType + + df = spark.createDataFrame( + [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ("id", "v")) + + @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) + def substract_mean(pdf): + # pdf is a pandas.DataFrame + v = pdf.v + return pdf.assign(v=v - v.mean()) + + df.groupby("id").apply(substract_mean).show() + # +---+----+ + # | id| v| + # +---+----+ + # | 1|-0.5| + # | 1| 0.5| + # | 2|-3.0| + # | 2|-1.0| + # | 2| 4.0| + # +---+----+ + # $example off:group_map_pandas_udf$ + + +if __name__ == "__main__": + spark = SparkSession \ + .builder \ + .appName("Python Arrow-in-Spark example") \ + .getOrCreate() + + print("Running Pandas to/from conversion example") + dataframe_with_arrow_example(spark) + print("Running pandas_udf scalar example") + scalar_pandas_udf_example(spark) + print("Running pandas_udf group map example") + group_map_pandas_udf_example(spark) + + spark.stop() From e30b34f7bd9a687eb43d636fffeb98fe235fcbf4 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 29 Jan 2018 10:29:42 -0800 Subject: [PATCH 0230/2461] [SPARK-22916][SQL][FOLLOW-UP] Update the Description of Join Selection ## What changes were proposed in this pull request? This PR is to update the description of the join algorithm changes. ## How was this patch tested? N/A Author: gatorsmile Closes #20420 from gatorsmile/followUp22916. --- .../spark/sql/execution/SparkStrategies.scala | 60 +++++++++++++++---- 1 file changed, 47 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ce512bc46563a..82b4eb9fba242 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -91,23 +91,58 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Select the proper physical plan for join based on joining keys and size of logical plan. * * At first, uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the - * predicates can be evaluated by matching join keys. If found, Join implementations are chosen + * predicates can be evaluated by matching join keys. If found, join implementations are chosen * with the following precedence: * - * - Broadcast: We prefer to broadcast the join side with an explicit broadcast hint(e.g. the - * user applied the [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame). - * If both sides have the broadcast hint, we prefer to broadcast the side with a smaller - * estimated physical size. If neither one of the sides has the broadcast hint, - * we only broadcast the join side if its estimated physical size that is smaller than - * the user-configurable [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold. + * - Broadcast hash join (BHJ): + * BHJ is not supported for full outer join. For right outer join, we only can broadcast the + * left side. For left outer, left semi, left anti and the internal join type ExistenceJoin, + * we only can broadcast the right side. For inner like join, we can broadcast both sides. + * Normally, BHJ can perform faster than the other join algorithms when the broadcast side is + * small. However, broadcasting tables is a network-intensive operation. It could cause OOM + * or perform worse than the other join algorithms, especially when the build/broadcast side + * is big. + * + * For the supported cases, users can specify the broadcast hint (e.g. the user applied the + * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame) and session-based + * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold to adjust whether BHJ is used and + * which join side is broadcast. + * + * 1) Broadcast the join side with the broadcast hint, even if the size is larger than + * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]]. If both sides have the hint (only when the type + * is inner like join), the side with a smaller estimated physical size will be broadcast. + * 2) Respect the [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold and broadcast the side + * whose estimated physical size is smaller than the threshold. If both sides are below the + * threshold, broadcast the smaller side. If neither is smaller, BHJ is not used. + * * - Shuffle hash join: if the average size of a single partition is small enough to build a hash * table. + * * - Sort merge: if the matching join keys are sortable. * * If there is no joining keys, Join implementations are chosen with the following precedence: - * - BroadcastNestedLoopJoin: if one side of the join could be broadcasted - * - CartesianProduct: for Inner join - * - BroadcastNestedLoopJoin + * - BroadcastNestedLoopJoin (BNLJ): + * BNLJ supports all the join types but the impl is OPTIMIZED for the following scenarios: + * For right outer join, the left side is broadcast. For left outer, left semi, left anti + * and the internal join type ExistenceJoin, the right side is broadcast. For inner like + * joins, either side is broadcast. + * + * Like BHJ, users still can specify the broadcast hint and session-based + * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold to impact which side is broadcast. + * + * 1) Broadcast the join side with the broadcast hint, even if the size is larger than + * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]]. If both sides have the hint (i.e., just for + * inner-like join), the side with a smaller estimated physical size will be broadcast. + * 2) Respect the [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold and broadcast the side + * whose estimated physical size is smaller than the threshold. If both sides are below the + * threshold, broadcast the smaller side. If neither is smaller, BNLJ is not used. + * + * - CartesianProduct: for inner like join, CartesianProduct is the fallback option. + * + * - BroadcastNestedLoopJoin (BNLJ): + * For the other join types, BNLJ is the fallback option. Here, we just pick the broadcast + * side with the broadcast hint. If neither side has a hint, we broadcast the side with + * the smaller estimated physical size. */ object JoinSelection extends Strategy with PredicateHelper { @@ -140,8 +175,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } private def canBuildRight(joinType: JoinType): Boolean = joinType match { - case _: InnerLike | LeftOuter | LeftSemi | LeftAnti => true - case j: ExistenceJoin => true + case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true case _ => false } @@ -244,7 +278,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // --- Without joining keys ------------------------------------------------------------ - // Pick BroadcastNestedLoopJoin if one side could be broadcasted + // Pick BroadcastNestedLoopJoin if one side could be broadcast case j @ logical.Join(left, right, joinType, condition) if canBroadcastByHints(joinType, left, right) => val buildSide = broadcastSideByHints(joinType, left, right) From b834446ec1338349f6d974afd96f677db3e8fd1a Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 29 Jan 2018 16:09:14 -0600 Subject: [PATCH 0231/2461] [SPARK-23209][core] Allow credential manager to work when Hive not available. The JVM seems to be doing early binding of classes that the Hive provider depends on, causing an error to be thrown before it was caught by the code in the class. The fix wraps the creation of the provider in a try..catch so that the provider can be ignored when dependencies are missing. Added a unit test (which fails without the fix), and also tested that getting tokens still works in a real cluster. Author: Marcelo Vanzin Closes #20399 from vanzin/SPARK-23209. --- .../HadoopDelegationTokenManager.scala | 17 +++++- .../HadoopDelegationTokenManagerSuite.scala | 58 +++++++++++++++++++ 2 files changed, 72 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index 116a686fe1480..5151df00476f9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -64,9 +64,9 @@ private[spark] class HadoopDelegationTokenManager( } private def getDelegationTokenProviders: Map[String, HadoopDelegationTokenProvider] = { - val providers = List(new HadoopFSDelegationTokenProvider(fileSystems), - new HiveDelegationTokenProvider, - new HBaseDelegationTokenProvider) + val providers = Seq(new HadoopFSDelegationTokenProvider(fileSystems)) ++ + safeCreateProvider(new HiveDelegationTokenProvider) ++ + safeCreateProvider(new HBaseDelegationTokenProvider) // Filter out providers for which spark.security.credentials.{service}.enabled is false. providers @@ -75,6 +75,17 @@ private[spark] class HadoopDelegationTokenManager( .toMap } + private def safeCreateProvider( + createFn: => HadoopDelegationTokenProvider): Option[HadoopDelegationTokenProvider] = { + try { + Some(createFn) + } catch { + case t: Throwable => + logDebug(s"Failed to load built in provider.", t) + None + } + } + def isServiceEnabled(serviceName: String): Boolean = { val key = providerEnabledConfig.format(serviceName) diff --git a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala index eeffc36070b44..2849a10a2c81e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy.security +import org.apache.commons.io.IOUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.security.Credentials @@ -110,7 +111,64 @@ class HadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers { creds.getAllTokens.size should be (0) } + test("SPARK-23209: obtain tokens when Hive classes are not available") { + // This test needs a custom class loader to hide Hive classes which are in the classpath. + // Because the manager code loads the Hive provider directly instead of using reflection, we + // need to drive the test through the custom class loader so a new copy that cannot find + // Hive classes is loaded. + val currentLoader = Thread.currentThread().getContextClassLoader() + val noHive = new ClassLoader() { + override def loadClass(name: String, resolve: Boolean): Class[_] = { + if (name.startsWith("org.apache.hive") || name.startsWith("org.apache.hadoop.hive")) { + throw new ClassNotFoundException(name) + } + + if (name.startsWith("java") || name.startsWith("scala")) { + currentLoader.loadClass(name) + } else { + val classFileName = name.replaceAll("\\.", "/") + ".class" + val in = currentLoader.getResourceAsStream(classFileName) + if (in != null) { + val bytes = IOUtils.toByteArray(in) + defineClass(name, bytes, 0, bytes.length) + } else { + throw new ClassNotFoundException(name) + } + } + } + } + + try { + Thread.currentThread().setContextClassLoader(noHive) + val test = noHive.loadClass(NoHiveTest.getClass.getName().stripSuffix("$")) + test.getMethod("runTest").invoke(null) + } finally { + Thread.currentThread().setContextClassLoader(currentLoader) + } + } + private[spark] def hadoopFSsToAccess(hadoopConf: Configuration): Set[FileSystem] = { Set(FileSystem.get(hadoopConf)) } } + +/** Test code for SPARK-23209 to avoid using too much reflection above. */ +private object NoHiveTest extends Matchers { + + def runTest(): Unit = { + try { + val manager = new HadoopDelegationTokenManager(new SparkConf(), new Configuration(), + _ => Set()) + manager.getServiceDelegationTokenProvider("hive") should be (None) + } catch { + case e: Throwable => + // Throw a better exception in case the test fails, since there may be a lot of nesting. + var cause = e + while (cause.getCause() != null) { + cause = cause.getCause() + } + throw cause + } + } + +} From f235df66a4754cbb64d5b7b5cfd5a52bdd243b8a Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 29 Jan 2018 17:37:55 -0800 Subject: [PATCH 0232/2461] [SPARK-22221][SQL][FOLLOWUP] Externalize spark.sql.execution.arrow.maxRecordsPerBatch ## What changes were proposed in this pull request? This is a followup to #19575 which added a section on setting max Arrow record batches and this will externalize the conf that was referenced in the docs. ## How was this patch tested? NA Author: Bryan Cutler Closes #20423 from BryanCutler/arrow-user-doc-externalize-maxRecordsPerBatch-SPARK-22221. --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 61ea03d395afc..54a35594f505e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1051,7 +1051,6 @@ object SQLConf { val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH = buildConf("spark.sql.execution.arrow.maxRecordsPerBatch") - .internal() .doc("When using Apache Arrow, limit the maximum number of records that can be written " + "to a single ArrowRecordBatch in memory. If set to zero or negative there is no limit.") .intConf From 31bd1dab1301d27a16c9d5d1b0b3301d618b0516 Mon Sep 17 00:00:00 2001 From: Paul Mackles Date: Tue, 30 Jan 2018 11:15:27 +0800 Subject: [PATCH 0233/2461] [SPARK-23088][CORE] History server not showing incomplete/running applications ## What changes were proposed in this pull request? History server not showing incomplete/running applications when spark.history.ui.maxApplications property is set to a value that is smaller than the total number of applications. ## How was this patch tested? Verified manually against master and 2.2.2 branch. Author: Paul Mackles Closes #20335 from pmackles/SPARK-23088. --- .../resources/org/apache/spark/ui/static/historypage.js | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 2cde66b081a1c..f0b2a5a833a99 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -108,7 +108,12 @@ $(document).ready(function() { requestedIncomplete = getParameterByName("showIncomplete", searchString); requestedIncomplete = (requestedIncomplete == "true" ? true : false); - $.getJSON("api/v1/applications?limit=" + appLimit, function(response,status,jqXHR) { + appParams = { + limit: appLimit, + status: (requestedIncomplete ? "running" : "completed") + }; + + $.getJSON("api/v1/applications", appParams, function(response,status,jqXHR) { var array = []; var hasMultipleAttempts = false; for (i in response) { From b375397b1678b7fe20a0b7f87a7e8b37ae5646ef Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 30 Jan 2018 11:40:42 +0800 Subject: [PATCH 0234/2461] [SPARK-23207][SQL][FOLLOW-UP] Don't perform local sort for DataFrame.repartition(1) ## What changes were proposed in this pull request? In `ShuffleExchangeExec`, we don't need to insert extra local sort before round-robin partitioning, if the new partitioning has only 1 partition, because under that case all output rows go to the same partition. ## How was this patch tested? The existing test cases. Author: Xingbo Jiang Closes #20426 from jiangxb1987/repartition1. --- .../spark/sql/execution/exchange/ShuffleExchangeExec.scala | 4 ++++ .../spark/sql/execution/streaming/ForeachSinkSuite.scala | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 76c1fa65f924b..4d95ee34f30de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -257,7 +257,11 @@ object ShuffleExchangeExec { // // Currently we following the most straight-forward way that perform a local sort before // partitioning. + // + // Note that we don't perform local sort if the new partitioning has only 1 partition, under + // that case all output rows go to the same partition. val newRdd = if (SQLConf.get.sortBeforeRepartition && + newPartitioning.numPartitions > 1 && newPartitioning.isInstanceOf[RoundRobinPartitioning]) { rdd.mapPartitionsInternal { iter => val recordComparatorSupplier = new Supplier[RecordComparator] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala index 1248c670df45c..41434e6d8b974 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala @@ -162,7 +162,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf val allEvents = ForeachSinkSuite.allEvents() assert(allEvents.size === 1) assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0)) - assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 2)) + assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1)) // `close` should be called with the error val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close] From 8b983243e45dfe2617c043a3229a7d87f4c4b44b Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Mon, 29 Jan 2018 22:19:59 -0800 Subject: [PATCH 0235/2461] [SPARK-23157][SQL] Explain restriction on column expression in withColumn() ## What changes were proposed in this pull request? It's not obvious from the comments that any added column must be a function of the dataset that we are adding it to. Add a comment to that effect to Scala, Python and R Data* methods. Author: Henry Robinson Closes #20429 from henryr/SPARK-23157. --- R/pkg/R/DataFrame.R | 3 ++- python/pyspark/sql/dataframe.py | 4 ++++ sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 3 +++ 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 29f3e986eaab6..547b5ea48a555 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2090,7 +2090,8 @@ setMethod("selectExpr", #' #' @param x a SparkDataFrame. #' @param colName a column name. -#' @param col a Column expression, or an atomic vector in the length of 1 as literal value. +#' @param col a Column expression (which must refer only to this DataFrame), or an atomic vector in +#' the length of 1 as literal value. #' @return A SparkDataFrame with the new column added or the existing column replaced. #' @family SparkDataFrame functions #' @aliases withColumn,SparkDataFrame,character-method diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ac403080acfdf..055b2c4a0ffec 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1829,11 +1829,15 @@ def withColumn(self, colName, col): Returns a new :class:`DataFrame` by adding a column or replacing the existing column that has the same name. + The column expression must be an expression over this dataframe; attempting to add + a column from some other dataframe will raise an error. + :param colName: string, name of the new column. :param col: a :class:`Column` expression for the new column. >>> df.withColumn('age2', df.age + 2).collect() [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] + """ assert isinstance(col, Column), "col should be Column" return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index cc5b647b3f037..d47cd0aecf56a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2150,6 +2150,9 @@ class Dataset[T] private[sql]( * Returns a new Dataset by adding a column or replacing the existing column that has * the same name. * + * `column`'s expression must only refer to attributes supplied by this Dataset. It is an + * error to add a column that refers to some other Dataset. + * * @group untypedrel * @since 2.0.0 */ From 5056877e8bea56dd0f4dc9e3385669e1e78b2925 Mon Sep 17 00:00:00 2001 From: sethah Date: Tue, 30 Jan 2018 09:02:16 +0200 Subject: [PATCH 0236/2461] [SPARK-23138][ML][DOC] Multiclass logistic regression summary example and user guide ## What changes were proposed in this pull request? User guide and examples are updated to reflect multiclass logistic regression summary which was added in [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139). I did not make a separate summary example, but added the summary code to the multiclass example that already existed. I don't see the need for a separate example for the summary. ## How was this patch tested? Docs and examples only. Ran all examples locally using spark-submit. Author: sethah Closes #20332 from sethah/multiclass_summary_example. --- docs/ml-classification-regression.md | 22 +++---- .../JavaLogisticRegressionSummaryExample.java | 17 ++--- ...gisticRegressionWithElasticNetExample.java | 62 +++++++++++++++++++ ...ss_logistic_regression_with_elastic_net.py | 38 ++++++++++++ .../ml/LogisticRegressionSummaryExample.scala | 15 ++--- ...isticRegressionWithElasticNetExample.scala | 43 +++++++++++++ 6 files changed, 164 insertions(+), 33 deletions(-) diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index bf979f3c73a52..ddd2f4b49ca07 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -87,7 +87,7 @@ More details on parameters can be found in the [R API documentation](api/R/spark The `spark.ml` implementation of logistic regression also supports extracting a summary of the model over the training set. Note that the predictions and metrics which are stored as `DataFrame` in -`BinaryLogisticRegressionSummary` are annotated `@transient` and hence +`LogisticRegressionSummary` are annotated `@transient` and hence only available on the driver.
@@ -97,10 +97,9 @@ only available on the driver. [`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary) provides a summary for a [`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel). -Currently, only binary classification is supported and the -summary must be explicitly cast to -[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary). -This will likely change when multiclass classification is supported. +In the case of binary classification, certain additional metrics are +available, e.g. ROC curve. The binary summary can be accessed via the +`binarySummary` method. See [`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary). Continuing the earlier example: @@ -111,10 +110,9 @@ Continuing the earlier example: [`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html) provides a summary for a [`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html). -Currently, only binary classification is supported and the -summary must be explicitly cast to -[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). -Support for multiclass model summaries will be added in the future. +In the case of binary classification, certain additional metrics are +available, e.g. ROC curve. The binary summary can be accessed via the +`binarySummary` method. See [`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). Continuing the earlier example: @@ -125,7 +123,8 @@ Continuing the earlier example: [`LogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionSummary) provides a summary for a [`LogisticRegressionModel`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionModel). -Currently, only binary classification is supported. Support for multiclass model summaries will be added in the future. +In the case of binary classification, certain additional metrics are +available, e.g. ROC curve. See [`BinaryLogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.BinaryLogisticRegressionTrainingSummary). Continuing the earlier example: @@ -162,7 +161,8 @@ For a detailed derivation please see [here](https://en.wikipedia.org/wiki/Multin **Examples** The following example shows how to train a multiclass logistic regression -model with elastic net regularization. +model with elastic net regularization, as well as extract the multiclass +training summary for evaluating the model.
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java index dee56799d8aee..1529da16f051f 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java @@ -18,10 +18,9 @@ package org.apache.spark.examples.ml; // $example on$ -import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary; +import org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.LogisticRegressionModel; -import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; @@ -50,7 +49,7 @@ public static void main(String[] args) { // $example on$ // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier // example - LogisticRegressionTrainingSummary trainingSummary = lrModel.summary(); + BinaryLogisticRegressionTrainingSummary trainingSummary = lrModel.binarySummary(); // Obtain the loss per iteration. double[] objectiveHistory = trainingSummary.objectiveHistory(); @@ -58,21 +57,15 @@ public static void main(String[] args) { System.out.println(lossPerIteration); } - // Obtain the metrics useful to judge performance on test data. - // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a binary - // classification problem. - BinaryLogisticRegressionSummary binarySummary = - (BinaryLogisticRegressionSummary) trainingSummary; - // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. - Dataset roc = binarySummary.roc(); + Dataset roc = trainingSummary.roc(); roc.show(); roc.select("FPR").show(); - System.out.println(binarySummary.areaUnderROC()); + System.out.println(trainingSummary.areaUnderROC()); // Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with // this selected threshold. - Dataset fMeasure = binarySummary.fMeasureByThreshold(); + Dataset fMeasure = trainingSummary.fMeasureByThreshold(); double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0); double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure)) .select("threshold").head().getDouble(0); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java index da410cba2b3f1..801a82cd2f24f 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java @@ -20,6 +20,7 @@ // $example on$ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; @@ -48,6 +49,67 @@ public static void main(String[] args) { // Print the coefficients and intercept for multinomial logistic regression System.out.println("Coefficients: \n" + lrModel.coefficientMatrix() + " \nIntercept: " + lrModel.interceptVector()); + LogisticRegressionTrainingSummary trainingSummary = lrModel.summary(); + + // Obtain the loss per iteration. + double[] objectiveHistory = trainingSummary.objectiveHistory(); + for (double lossPerIteration : objectiveHistory) { + System.out.println(lossPerIteration); + } + + // for multiclass, we can inspect metrics on a per-label basis + System.out.println("False positive rate by label:"); + int i = 0; + double[] fprLabel = trainingSummary.falsePositiveRateByLabel(); + for (double fpr : fprLabel) { + System.out.println("label " + i + ": " + fpr); + i++; + } + + System.out.println("True positive rate by label:"); + i = 0; + double[] tprLabel = trainingSummary.truePositiveRateByLabel(); + for (double tpr : tprLabel) { + System.out.println("label " + i + ": " + tpr); + i++; + } + + System.out.println("Precision by label:"); + i = 0; + double[] precLabel = trainingSummary.precisionByLabel(); + for (double prec : precLabel) { + System.out.println("label " + i + ": " + prec); + i++; + } + + System.out.println("Recall by label:"); + i = 0; + double[] recLabel = trainingSummary.recallByLabel(); + for (double rec : recLabel) { + System.out.println("label " + i + ": " + rec); + i++; + } + + System.out.println("F-measure by label:"); + i = 0; + double[] fLabel = trainingSummary.fMeasureByLabel(); + for (double f : fLabel) { + System.out.println("label " + i + ": " + f); + i++; + } + + double accuracy = trainingSummary.accuracy(); + double falsePositiveRate = trainingSummary.weightedFalsePositiveRate(); + double truePositiveRate = trainingSummary.weightedTruePositiveRate(); + double fMeasure = trainingSummary.weightedFMeasure(); + double precision = trainingSummary.weightedPrecision(); + double recall = trainingSummary.weightedRecall(); + System.out.println("Accuracy: " + accuracy); + System.out.println("FPR: " + falsePositiveRate); + System.out.println("TPR: " + truePositiveRate); + System.out.println("F-measure: " + fMeasure); + System.out.println("Precision: " + precision); + System.out.println("Recall: " + recall); // $example off$ spark.stop(); diff --git a/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py b/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py index bb9cd82d6ba27..bec9860c79a2d 100644 --- a/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py +++ b/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py @@ -43,6 +43,44 @@ # Print the coefficients and intercept for multinomial logistic regression print("Coefficients: \n" + str(lrModel.coefficientMatrix)) print("Intercept: " + str(lrModel.interceptVector)) + + trainingSummary = lrModel.summary + + # Obtain the objective per iteration + objectiveHistory = trainingSummary.objectiveHistory + print("objectiveHistory:") + for objective in objectiveHistory: + print(objective) + + # for multiclass, we can inspect metrics on a per-label basis + print("False positive rate by label:") + for i, rate in enumerate(trainingSummary.falsePositiveRateByLabel): + print("label %d: %s" % (i, rate)) + + print("True positive rate by label:") + for i, rate in enumerate(trainingSummary.truePositiveRateByLabel): + print("label %d: %s" % (i, rate)) + + print("Precision by label:") + for i, prec in enumerate(trainingSummary.precisionByLabel): + print("label %d: %s" % (i, prec)) + + print("Recall by label:") + for i, rec in enumerate(trainingSummary.recallByLabel): + print("label %d: %s" % (i, rec)) + + print("F-measure by label:") + for i, f in enumerate(trainingSummary.fMeasureByLabel()): + print("label %d: %s" % (i, f)) + + accuracy = trainingSummary.accuracy + falsePositiveRate = trainingSummary.weightedFalsePositiveRate + truePositiveRate = trainingSummary.weightedTruePositiveRate + fMeasure = trainingSummary.weightedFMeasure() + precision = trainingSummary.weightedPrecision + recall = trainingSummary.weightedRecall + print("Accuracy: %s\nFPR: %s\nTPR: %s\nF-measure: %s\nPrecision: %s\nRecall: %s" + % (accuracy, falsePositiveRate, truePositiveRate, fMeasure, precision, recall)) # $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala index 1740a0d3f9d12..0368dcba460b5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala @@ -19,7 +19,7 @@ package org.apache.spark.examples.ml // $example on$ -import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression} +import org.apache.spark.ml.classification.LogisticRegression // $example off$ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions.max @@ -47,25 +47,20 @@ object LogisticRegressionSummaryExample { // $example on$ // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier // example - val trainingSummary = lrModel.summary + val trainingSummary = lrModel.binarySummary // Obtain the objective per iteration. val objectiveHistory = trainingSummary.objectiveHistory println("objectiveHistory:") objectiveHistory.foreach(loss => println(loss)) - // Obtain the metrics useful to judge performance on test data. - // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a - // binary classification problem. - val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary] - // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. - val roc = binarySummary.roc + val roc = trainingSummary.roc roc.show() - println(s"areaUnderROC: ${binarySummary.areaUnderROC}") + println(s"areaUnderROC: ${trainingSummary.areaUnderROC}") // Set the model threshold to maximize F-Measure - val fMeasure = binarySummary.fMeasureByThreshold + val fMeasure = trainingSummary.fMeasureByThreshold val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0) val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure) .select("threshold").head().getDouble(0) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala index 3e61dbe628c20..1f7dbddd454e8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala @@ -49,6 +49,49 @@ object MulticlassLogisticRegressionWithElasticNetExample { // Print the coefficients and intercept for multinomial logistic regression println(s"Coefficients: \n${lrModel.coefficientMatrix}") println(s"Intercepts: \n${lrModel.interceptVector}") + + val trainingSummary = lrModel.summary + + // Obtain the objective per iteration + val objectiveHistory = trainingSummary.objectiveHistory + println("objectiveHistory:") + objectiveHistory.foreach(println) + + // for multiclass, we can inspect metrics on a per-label basis + println("False positive rate by label:") + trainingSummary.falsePositiveRateByLabel.zipWithIndex.foreach { case (rate, label) => + println(s"label $label: $rate") + } + + println("True positive rate by label:") + trainingSummary.truePositiveRateByLabel.zipWithIndex.foreach { case (rate, label) => + println(s"label $label: $rate") + } + + println("Precision by label:") + trainingSummary.precisionByLabel.zipWithIndex.foreach { case (prec, label) => + println(s"label $label: $prec") + } + + println("Recall by label:") + trainingSummary.recallByLabel.zipWithIndex.foreach { case (rec, label) => + println(s"label $label: $rec") + } + + + println("F-measure by label:") + trainingSummary.fMeasureByLabel.zipWithIndex.foreach { case (f, label) => + println(s"label $label: $f") + } + + val accuracy = trainingSummary.accuracy + val falsePositiveRate = trainingSummary.weightedFalsePositiveRate + val truePositiveRate = trainingSummary.weightedTruePositiveRate + val fMeasure = trainingSummary.weightedFMeasure + val precision = trainingSummary.weightedPrecision + val recall = trainingSummary.weightedRecall + println(s"Accuracy: $accuracy\nFPR: $falsePositiveRate\nTPR: $truePositiveRate\n" + + s"F-measure: $fMeasure\nPrecision: $precision\nRecall: $recall") // $example off$ spark.stop() From 0a9ac0248b6514a1e83ff7e4c522424f01b8b78d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 30 Jan 2018 19:43:17 +0800 Subject: [PATCH 0237/2461] [SPARK-23260][SPARK-23262][SQL] several data source v2 naming cleanup ## What changes were proposed in this pull request? All other classes in the reader/writer package doesn't have `V2` in their names, and the streaming reader/writer don't have `V2` either. It's more consistent to remove `V2` from `DataSourceV2Reader` and `DataSourceVWriter`. Also rename `DataSourceV2Option` to remote the `V2`, we should only have `V2` in the root interface: `DataSourceV2`. This PR also fixes some places that the mix-in interface doesn't extend the interface it aimed to mix in. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #20427 from cloud-fan/ds-v2. --- .../sql/kafka010/KafkaContinuousReader.scala | 2 +- .../sql/kafka010/KafkaSourceProvider.scala | 6 ++--- ...eV2Options.java => DataSourceOptions.java} | 8 +++---- .../spark/sql/sources/v2/ReadSupport.java | 8 +++---- .../sql/sources/v2/ReadSupportWithSchema.java | 8 +++---- .../sql/sources/v2/SessionConfigSupport.java | 2 +- .../spark/sql/sources/v2/WriteSupport.java | 12 +++++----- .../sources/v2/reader/DataReaderFactory.java | 2 +- ...rceV2Reader.java => DataSourceReader.java} | 11 +++++---- .../SupportsPushDownCatalystFilters.java | 4 ++-- .../v2/reader/SupportsPushDownFilters.java | 4 ++-- .../SupportsPushDownRequiredColumns.java | 6 ++--- .../v2/reader/SupportsReportPartitioning.java | 4 ++-- .../v2/reader/SupportsReportStatistics.java | 4 ++-- .../v2/reader/SupportsScanColumnarBatch.java | 6 ++--- .../v2/reader/SupportsScanUnsafeRow.java | 6 ++--- .../v2/streaming/ContinuousReadSupport.java | 4 ++-- .../v2/streaming/MicroBatchReadSupport.java | 4 ++-- .../v2/streaming/StreamWriteSupport.java | 10 ++++---- .../v2/streaming/reader/ContinuousReader.java | 6 ++--- .../v2/streaming/reader/MicroBatchReader.java | 6 ++--- .../v2/streaming/writer/StreamWriter.java | 6 ++--- ...rceV2Writer.java => DataSourceWriter.java} | 8 +++---- .../sql/sources/v2/writer/DataWriter.java | 12 +++++----- .../sources/v2/writer/DataWriterFactory.java | 2 +- .../v2/writer/SupportsWriteInternalRow.java | 4 ++-- .../v2/writer/WriterCommitMessage.java | 4 ++-- .../apache/spark/sql/DataFrameReader.scala | 2 +- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../v2/DataSourceReaderHolder.scala | 2 +- .../datasources/v2/DataSourceV2Relation.scala | 6 ++--- .../datasources/v2/DataSourceV2ScanExec.scala | 2 +- .../datasources/v2/WriteToDataSourceV2.scala | 4 ++-- .../streaming/MicroBatchExecution.scala | 6 ++--- .../streaming/RateSourceProvider.scala | 2 +- .../sql/execution/streaming/console.scala | 4 ++-- .../continuous/ContinuousExecution.scala | 6 ++--- .../ContinuousRateStreamSource.scala | 7 +++--- .../streaming/sources/ConsoleWriter.scala | 4 ++-- .../streaming/sources/MicroBatchWriter.scala | 8 +++---- .../sources/PackedRowWriterFactory.scala | 4 ++-- .../sources/RateStreamSourceV2.scala | 6 ++--- .../streaming/sources/memoryV2.scala | 6 ++--- .../sql/streaming/DataStreamReader.scala | 4 ++-- .../sources/v2/JavaAdvancedDataSourceV2.java | 6 ++--- .../sql/sources/v2/JavaBatchDataSourceV2.java | 6 ++--- .../v2/JavaPartitionAwareDataSource.java | 6 ++--- .../v2/JavaSchemaRequiredDataSource.java | 8 +++---- .../sources/v2/JavaSimpleDataSourceV2.java | 8 +++---- .../sources/v2/JavaUnsafeRowDataSourceV2.java | 6 ++--- .../streaming/RateSourceV2Suite.scala | 18 +++++++------- ...ite.scala => DataSourceOptionsSuite.scala} | 16 ++++++------- .../sql/sources/v2/DataSourceV2Suite.scala | 24 +++++++++---------- .../sources/v2/SimpleWritableDataSource.scala | 12 +++++----- .../sources/StreamingDataSourceV2Suite.scala | 8 +++---- 55 files changed, 176 insertions(+), 176 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{DataSourceV2Options.java => DataSourceOptions.java} (94%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{DataSourceV2Reader.java => DataSourceReader.java} (91%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/{DataSourceV2Writer.java => DataSourceWriter.java} (96%) rename sql/core/src/test/scala/org/apache/spark/sql/sources/v2/{DataSourceV2OptionsSuite.scala => DataSourceOptionsSuite.scala} (80%) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 9125cf5799d74..8c733426b256f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -41,7 +41,7 @@ import org.apache.spark.unsafe.types.UTF8String * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be * read by per-task consumers generated later. * @param kafkaParams String params for per-task Kafka consumers. - * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceV2Options]] params which + * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceOptions]] params which * are not Kafka consumer params. * @param metadataPath Path to a directory this reader can use for writing metadata. * @param initialOffsets The Kafka offsets to start reading data at. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 2deb7fa2cdf1e..85e96b6783327 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -30,7 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.streaming.OutputMode @@ -109,7 +109,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister override def createContinuousReader( schema: Optional[StructType], metadataPath: String, - options: DataSourceV2Options): KafkaContinuousReader = { + options: DataSourceOptions): KafkaContinuousReader = { val parameters = options.asMap().asScala.toMap validateStreamOptions(parameters) // Each running query should use its own group id. Otherwise, the query may be only assigned @@ -227,7 +227,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): StreamWriter = { + options: DataSourceOptions): StreamWriter = { import scala.collection.JavaConverters._ val spark = SparkSession.getActiveSession.get diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java similarity index 94% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java index ddc2acca693ac..c32053580f016 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java @@ -29,18 +29,18 @@ * data source options. */ @InterfaceStability.Evolving -public class DataSourceV2Options { +public class DataSourceOptions { private final Map keyLowerCasedMap; private String toLowerCase(String key) { return key.toLowerCase(Locale.ROOT); } - public static DataSourceV2Options empty() { - return new DataSourceV2Options(new HashMap<>()); + public static DataSourceOptions empty() { + return new DataSourceOptions(new HashMap<>()); } - public DataSourceV2Options(Map originalMap) { + public DataSourceOptions(Map originalMap) { keyLowerCasedMap = new HashMap<>(originalMap.size()); for (Map.Entry entry : originalMap.entrySet()) { keyLowerCasedMap.put(toLowerCase(entry.getKey()), entry.getValue()); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java index 948e20bacf4a2..0ea4dc6b5def3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java @@ -18,17 +18,17 @@ package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to * provide data reading ability and scan the data from the data source. */ @InterfaceStability.Evolving -public interface ReadSupport { +public interface ReadSupport extends DataSourceV2 { /** - * Creates a {@link DataSourceV2Reader} to scan the data from this data source. + * Creates a {@link DataSourceReader} to scan the data from this data source. * * If this method fails (by throwing an exception), the action would fail and no Spark job was * submitted. @@ -36,5 +36,5 @@ public interface ReadSupport { * @param options the options for the returned data source reader, which is an immutable * case-insensitive string-to-string map. */ - DataSourceV2Reader createReader(DataSourceV2Options options); + DataSourceReader createReader(DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java index b69c6bed8d1b5..3801402268af1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.types.StructType; /** @@ -30,10 +30,10 @@ * supports both schema inference and user-specified schema. */ @InterfaceStability.Evolving -public interface ReadSupportWithSchema { +public interface ReadSupportWithSchema extends DataSourceV2 { /** - * Create a {@link DataSourceV2Reader} to scan the data from this data source. + * Create a {@link DataSourceReader} to scan the data from this data source. * * If this method fails (by throwing an exception), the action would fail and no Spark job was * submitted. @@ -45,5 +45,5 @@ public interface ReadSupportWithSchema { * @param options the options for the returned data source reader, which is an immutable * case-insensitive string-to-string map. */ - DataSourceV2Reader createReader(StructType schema, DataSourceV2Options options); + DataSourceReader createReader(StructType schema, DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java index 3cb020d2e0836..9d66805d79b9e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java @@ -25,7 +25,7 @@ * session. */ @InterfaceStability.Evolving -public interface SessionConfigSupport { +public interface SessionConfigSupport extends DataSourceV2 { /** * Key prefix of the session configs to propagate. Spark will extract all session configs that diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java index 1e3b644d8c4ae..cab56453816cc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java @@ -21,7 +21,7 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.SaveMode; -import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; import org.apache.spark.sql.types.StructType; /** @@ -29,17 +29,17 @@ * provide data writing ability and save the data to the data source. */ @InterfaceStability.Evolving -public interface WriteSupport { +public interface WriteSupport extends DataSourceV2 { /** - * Creates an optional {@link DataSourceV2Writer} to save the data to this data source. Data + * Creates an optional {@link DataSourceWriter} to save the data to this data source. Data * sources can return None if there is no writing needed to be done according to the save mode. * * If this method fails (by throwing an exception), the action would fail and no Spark job was * submitted. * * @param jobId A unique string for the writing job. It's possible that there are many writing - * jobs running at the same time, and the returned {@link DataSourceV2Writer} can + * jobs running at the same time, and the returned {@link DataSourceWriter} can * use this job id to distinguish itself from other jobs. * @param schema the schema of the data to be written. * @param mode the save mode which determines what to do when the data are already in this data @@ -47,6 +47,6 @@ public interface WriteSupport { * @param options the options for the returned data source writer, which is an immutable * case-insensitive string-to-string map. */ - Optional createWriter( - String jobId, StructType schema, SaveMode mode, DataSourceV2Options options); + Optional createWriter( + String jobId, StructType schema, SaveMode mode, DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java index 077b95b837964..32e98e8f5d8bd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A reader factory returned by {@link DataSourceV2Reader#createDataReaderFactories()} and is + * A reader factory returned by {@link DataSourceReader#createDataReaderFactories()} and is * responsible for creating the actual data reader. The relationship between * {@link DataReaderFactory} and {@link DataReader} * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java similarity index 91% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java index 0180cd9ea47f8..a470bccc5aad2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java @@ -21,14 +21,15 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.Row; +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; import org.apache.spark.sql.types.StructType; /** * A data source reader that is returned by - * {@link org.apache.spark.sql.sources.v2.ReadSupport#createReader( - * org.apache.spark.sql.sources.v2.DataSourceV2Options)} or - * {@link org.apache.spark.sql.sources.v2.ReadSupportWithSchema#createReader( - * StructType, org.apache.spark.sql.sources.v2.DataSourceV2Options)}. + * {@link ReadSupport#createReader(DataSourceOptions)} or + * {@link ReadSupportWithSchema#createReader(StructType, DataSourceOptions)}. * It can mix in various query optimization interfaces to speed up the data scan. The actual scan * logic is delegated to {@link DataReaderFactory}s that are returned by * {@link #createDataReaderFactories()}. @@ -52,7 +53,7 @@ * issues the scan request and does the actual data reading. */ @InterfaceStability.Evolving -public interface DataSourceV2Reader { +public interface DataSourceReader { /** * Returns the actual schema of this data source reader, which may be different from the physical diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index f76c687f450c8..98224102374aa 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to push down arbitrary expressions as predicates to the data source. * This is an experimental and unstable interface as {@link Expression} is not public and may get * changed in the future Spark versions. @@ -31,7 +31,7 @@ * process this interface. */ @InterfaceStability.Unstable -public interface SupportsPushDownCatalystFilters { +public interface SupportsPushDownCatalystFilters extends DataSourceReader { /** * Pushes down filters, and returns unsupported filters. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index 6b0c9d417eeae..f35c711b0387a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -21,7 +21,7 @@ import org.apache.spark.sql.sources.Filter; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to push down filters to the data source and reduce the size of the data to be read. * * Note that, if data source readers implement both this interface and @@ -29,7 +29,7 @@ * {@link SupportsPushDownCatalystFilters}. */ @InterfaceStability.Evolving -public interface SupportsPushDownFilters { +public interface SupportsPushDownFilters extends DataSourceReader { /** * Pushes down filters, and returns unsupported filters. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java index fe0ac8ee0ee32..427b4d00a1128 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java @@ -21,12 +21,12 @@ import org.apache.spark.sql.types.StructType; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to push down required columns to the data source and only read these columns during * scan to reduce the size of the data to be read. */ @InterfaceStability.Evolving -public interface SupportsPushDownRequiredColumns { +public interface SupportsPushDownRequiredColumns extends DataSourceReader { /** * Applies column pruning w.r.t. the given requiredSchema. @@ -35,7 +35,7 @@ public interface SupportsPushDownRequiredColumns { * also OK to do the pruning partially, e.g., a data source may not be able to prune nested * fields, and only prune top-level columns. * - * Note that, data source readers should update {@link DataSourceV2Reader#readSchema()} after + * Note that, data source readers should update {@link DataSourceReader#readSchema()} after * applying column pruning. */ void pruneColumns(StructType requiredSchema); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index f786472ccf345..a2383a9d7d680 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -20,11 +20,11 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A mix in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report data partitioning and try to avoid shuffle at Spark side. */ @InterfaceStability.Evolving -public interface SupportsReportPartitioning { +public interface SupportsReportPartitioning extends DataSourceReader { /** * Returns the output data partitioning that this reader guarantees. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java index c019d2f819ab7..11bb13fd3b211 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -20,11 +20,11 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A mix in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report statistics to Spark. */ @InterfaceStability.Evolving -public interface SupportsReportStatistics { +public interface SupportsReportStatistics extends DataSourceReader { /** * Returns the basic statistics of this data source. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java index 67da55554bbf3..2e5cfa78511f0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java @@ -24,11 +24,11 @@ import org.apache.spark.sql.vectorized.ColumnarBatch; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to output {@link ColumnarBatch} and make the scan faster. */ @InterfaceStability.Evolving -public interface SupportsScanColumnarBatch extends DataSourceV2Reader { +public interface SupportsScanColumnarBatch extends DataSourceReader { @Override default List> createDataReaderFactories() { throw new IllegalStateException( @@ -36,7 +36,7 @@ default List> createDataReaderFactories() { } /** - * Similar to {@link DataSourceV2Reader#createDataReaderFactories()}, but returns columnar data + * Similar to {@link DataSourceReader#createDataReaderFactories()}, but returns columnar data * in batches. */ List> createBatchDataReaderFactories(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java index 156af69520f77..9cd749e8e4ce9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java @@ -24,13 +24,13 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to output {@link UnsafeRow} directly and avoid the row copy at Spark side. * This is an experimental and unstable interface, as {@link UnsafeRow} is not public and may get * changed in the future Spark versions. */ @InterfaceStability.Unstable -public interface SupportsScanUnsafeRow extends DataSourceV2Reader { +public interface SupportsScanUnsafeRow extends DataSourceReader { @Override default List> createDataReaderFactories() { @@ -39,7 +39,7 @@ default List> createDataReaderFactories() { } /** - * Similar to {@link DataSourceV2Reader#createDataReaderFactories()}, + * Similar to {@link DataSourceReader#createDataReaderFactories()}, * but returns data in unsafe row format. */ List> createUnsafeRowReaderFactories(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java index 9a93a806b0efc..f79424e036a52 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java @@ -21,7 +21,7 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader; import org.apache.spark.sql.types.StructType; @@ -44,5 +44,5 @@ public interface ContinuousReadSupport extends DataSourceV2 { ContinuousReader createContinuousReader( Optional schema, String checkpointLocation, - DataSourceV2Options options); + DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java index 3b357c01a29fe..22660e42ad850 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java @@ -20,8 +20,8 @@ import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.streaming.reader.MicroBatchReader; import org.apache.spark.sql.types.StructType; @@ -50,5 +50,5 @@ public interface MicroBatchReadSupport extends DataSourceV2 { MicroBatchReader createMicroBatchReader( Optional schema, String checkpointLocation, - DataSourceV2Options options); + DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java index 6cd219c67109a..7c5f304425093 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java @@ -19,10 +19,10 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter; -import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; @@ -31,7 +31,7 @@ * provide data writing ability for structured streaming. */ @InterfaceStability.Evolving -public interface StreamWriteSupport extends BaseStreamingSink { +public interface StreamWriteSupport extends DataSourceV2, BaseStreamingSink { /** * Creates an optional {@link StreamWriter} to save the data to this data source. Data @@ -39,7 +39,7 @@ public interface StreamWriteSupport extends BaseStreamingSink { * * @param queryId A unique string for the writing query. It's possible that there are many * writing queries running at the same time, and the returned - * {@link DataSourceV2Writer} can use this id to distinguish itself from others. + * {@link DataSourceWriter} can use this id to distinguish itself from others. * @param schema the schema of the data to be written. * @param mode the output mode which determines what successive epoch output means to this * sink, please refer to {@link OutputMode} for more details. @@ -50,5 +50,5 @@ StreamWriter createStreamWriter( String queryId, StructType schema, OutputMode mode, - DataSourceV2Options options); + DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java index 3ac979cb0b7b4..6e5177ee83a62 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java @@ -19,12 +19,12 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import java.util.Optional; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to allow reading in a continuous processing mode stream. * * Implementations must ensure each reader factory output is a {@link ContinuousDataReader}. @@ -33,7 +33,7 @@ * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. */ @InterfaceStability.Evolving -public interface ContinuousReader extends BaseStreamingSource, DataSourceV2Reader { +public interface ContinuousReader extends BaseStreamingSource, DataSourceReader { /** * Merge partitioned offsets coming from {@link ContinuousDataReader} instances for each * partition to a single global offset. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java index 68887e569fc1d..fcec446d892f5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java @@ -18,20 +18,20 @@ package org.apache.spark.sql.sources.v2.streaming.reader; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; import java.util.Optional; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to indicate they allow micro-batch streaming reads. * * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. */ @InterfaceStability.Evolving -public interface MicroBatchReader extends DataSourceV2Reader, BaseStreamingSource { +public interface MicroBatchReader extends DataSourceReader, BaseStreamingSource { /** * Set the desired offset range for reader factories created from this reader. Reader factories * will generate only data within (`start`, `end`]; that is, from the first record after `start` diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java index 3156c88933e5e..915ee6c4fb390 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java @@ -18,19 +18,19 @@ package org.apache.spark.sql.sources.v2.streaming.writer; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; import org.apache.spark.sql.sources.v2.writer.DataWriter; import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; /** - * A {@link DataSourceV2Writer} for use with structured streaming. This writer handles commits and + * A {@link DataSourceWriter} for use with structured streaming. This writer handles commits and * aborts relative to an epoch ID determined by the execution engine. * * {@link DataWriter} implementations generated by a StreamWriter may be reused for multiple epochs, * and so must reset any internal state after a successful commit. */ @InterfaceStability.Evolving -public interface StreamWriter extends DataSourceV2Writer { +public interface StreamWriter extends DataSourceWriter { /** * Commits this writing job for the specified epoch with a list of commit messages. The commit * messages are collected from successful data writers and are produced by diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java similarity index 96% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index 8048f507a1dca..d89d27d0e5b1b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -20,16 +20,16 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.Row; import org.apache.spark.sql.SaveMode; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.WriteSupport; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; /** * A data source writer that is returned by - * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}/ + * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceOptions)}/ * {@link org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport#createStreamWriter( - * String, StructType, OutputMode, DataSourceV2Options)}. + * String, StructType, OutputMode, DataSourceOptions)}. * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. * @@ -52,7 +52,7 @@ * Please refer to the documentation of commit/abort methods for detailed specifications. */ @InterfaceStability.Evolving -public interface DataSourceV2Writer { +public interface DataSourceWriter { /** * Creates a writer factory which will be serialized and sent to executors. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 04b03e63de500..53941a89ba94e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -33,11 +33,11 @@ * * If this data writer succeeds(all records are successfully written and {@link #commit()} * succeeds), a {@link WriterCommitMessage} will be sent to the driver side and pass to - * {@link DataSourceV2Writer#commit(WriterCommitMessage[])} with commit messages from other data + * {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an * exception will be sent to the driver side, and Spark will retry this writing task for some times, * each time {@link DataWriterFactory#createDataWriter(int, int)} gets a different `attemptNumber`, - * and finally call {@link DataSourceV2Writer#abort(WriterCommitMessage[])} if all retry fail. + * and finally call {@link DataSourceWriter#abort(WriterCommitMessage[])} if all retry fail. * * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task * takes too long to finish. Different from retried tasks, which are launched one by one after the @@ -69,11 +69,11 @@ public interface DataWriter { /** * Commits this writer after all records are written successfully, returns a commit message which * will be sent back to driver side and passed to - * {@link DataSourceV2Writer#commit(WriterCommitMessage[])}. + * {@link DataSourceWriter#commit(WriterCommitMessage[])}. * * The written data should only be visible to data source readers after - * {@link DataSourceV2Writer#commit(WriterCommitMessage[])} succeeds, which means this method - * should still "hide" the written data and ask the {@link DataSourceV2Writer} at driver side to + * {@link DataSourceWriter#commit(WriterCommitMessage[])} succeeds, which means this method + * should still "hide" the written data and ask the {@link DataSourceWriter} at driver side to * do the final commit via {@link WriterCommitMessage}. * * If this method fails (by throwing an exception), {@link #abort()} will be called and this @@ -91,7 +91,7 @@ public interface DataWriter { * failed. * * If this method fails(by throwing an exception), the underlying data source may have garbage - * that need to be cleaned by {@link DataSourceV2Writer#abort(WriterCommitMessage[])} or manually, + * that need to be cleaned by {@link DataSourceWriter#abort(WriterCommitMessage[])} or manually, * but these garbage should not be visible to data source readers. * * @throws IOException if failure happens during disk/network IO like writing files. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index 18ec792f5a2c9..ea95442511ce5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A factory of {@link DataWriter} returned by {@link DataSourceV2Writer#createWriterFactory()}, + * A factory of {@link DataWriter} returned by {@link DataSourceWriter#createWriterFactory()}, * which is responsible for creating and initializing the actual data writer at executor side. * * Note that, the writer factory will be serialized and sent to executors, then the data writer diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java index 3e0518814f458..d2cf7e01c08c8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java @@ -22,14 +22,14 @@ import org.apache.spark.sql.catalyst.InternalRow; /** - * A mix-in interface for {@link DataSourceV2Writer}. Data source writers can implement this + * A mix-in interface for {@link DataSourceWriter}. Data source writers can implement this * interface to write {@link InternalRow} directly and avoid the row conversion at Spark side. * This is an experimental and unstable interface, as {@link InternalRow} is not public and may get * changed in the future Spark versions. */ @InterfaceStability.Unstable -public interface SupportsWriteInternalRow extends DataSourceV2Writer { +public interface SupportsWriteInternalRow extends DataSourceWriter { @Override default DataWriterFactory createWriterFactory() { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java index 082d6b5dc409f..9e38836c0edf9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java @@ -23,10 +23,10 @@ /** * A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side - * as the input parameter of {@link DataSourceV2Writer#commit(WriterCommitMessage[])}. + * as the input parameter of {@link DataSourceWriter#commit(WriterCommitMessage[])}. * * This is an empty interface, data sources should define their own message class and use it in - * their {@link DataWriter#commit()} and {@link DataSourceV2Writer#commit(WriterCommitMessage[])} + * their {@link DataWriter#commit()} and {@link DataSourceWriter#commit(WriterCommitMessage[])} * implementations. */ @InterfaceStability.Evolving diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b714a46b5f786..46b5f54a33f74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -186,7 +186,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { val ds = cls.newInstance() - val options = new DataSourceV2Options((extraOptions ++ + val options = new DataSourceOptions((extraOptions ++ DataSourceV2Utils.extractSessionConfigs( ds = ds.asInstanceOf[DataSourceV2], conf = sparkSession.sessionState.conf)).asJava) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 5c02eae05304b..ed7a9100cc7f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -243,7 +243,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val ds = cls.newInstance() ds match { case ws: WriteSupport => - val options = new DataSourceV2Options((extraOptions ++ + val options = new DataSourceOptions((extraOptions ++ DataSourceV2Utils.extractSessionConfigs( ds = ds.asInstanceOf[DataSourceV2], conf = df.sparkSession.sessionState.conf)).asJava) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala index 6093df26630cd..6460c97abe344 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala @@ -35,7 +35,7 @@ trait DataSourceReaderHolder { /** * The held data source reader. */ - def reader: DataSourceV2Reader + def reader: DataSourceReader /** * The metadata of this data source reader that can be used for equality test. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index cba20dd902007..3d4c64981373d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( fullOutput: Seq[AttributeReference], - reader: DataSourceV2Reader) extends LeafNode with DataSourceReaderHolder { + reader: DataSourceReader) extends LeafNode with DataSourceReaderHolder { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] @@ -41,12 +41,12 @@ case class DataSourceV2Relation( */ class StreamingDataSourceV2Relation( fullOutput: Seq[AttributeReference], - reader: DataSourceV2Reader) extends DataSourceV2Relation(fullOutput, reader) { + reader: DataSourceReader) extends DataSourceV2Relation(fullOutput, reader) { override def isStreaming: Boolean = true } object DataSourceV2Relation { - def apply(reader: DataSourceV2Reader): DataSourceV2Relation = { + def apply(reader: DataSourceReader): DataSourceV2Relation = { new DataSourceV2Relation(reader.readSchema().toAttributes, reader) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 3f808fbb40932..ee085820b0775 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.StructType */ case class DataSourceV2ScanExec( fullOutput: Seq[AttributeReference], - @transient reader: DataSourceV2Reader) + @transient reader: DataSourceReader) extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index cd6b3e99b6bcb..c544adbf32cdf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -35,7 +35,7 @@ import org.apache.spark.util.Utils /** * The logical plan for writing data into data source v2. */ -case class WriteToDataSourceV2(writer: DataSourceV2Writer, query: LogicalPlan) extends LogicalPlan { +case class WriteToDataSourceV2(writer: DataSourceWriter, query: LogicalPlan) extends LogicalPlan { override def children: Seq[LogicalPlan] = Seq(query) override def output: Seq[Attribute] = Nil } @@ -43,7 +43,7 @@ case class WriteToDataSourceV2(writer: DataSourceV2Writer, query: LogicalPlan) e /** * The physical plan for writing data into data source v2. */ -case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) extends SparkPlan { +case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) extends SparkPlan { override def children: Seq[SparkPlan] = Seq(query) override def output: Seq[Attribute] = Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 975975243a3d1..93572f7a63132 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.streaming.{MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow @@ -89,7 +89,7 @@ class MicroBatchExecution( val reader = source.createMicroBatchReader( Optional.empty(), // user specified schema metadataPath, - new DataSourceV2Options(options.asJava)) + new DataSourceOptions(options.asJava)) nextSourceId += 1 StreamingExecutionRelation(reader, output)(sparkSession) }) @@ -447,7 +447,7 @@ class MicroBatchExecution( s"$runId", newAttributePlan.schema, outputMode, - new DataSourceV2Options(extraOptions.asJava)) + new DataSourceOptions(extraOptions.asJava)) if (writer.isInstanceOf[SupportsWriteInternalRow]) { WriteToDataSourceV2( new InternalRowMicroBatchWriter(currentBatchId, writer), newAttributePlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index 66eb0169ac1ec..5e3fee633f591 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -111,7 +111,7 @@ class RateSourceProvider extends StreamSourceProvider with DataSourceRegister override def createContinuousReader( schema: Optional[StructType], checkpointLocation: String, - options: DataSourceV2Options): ContinuousReader = { + options: DataSourceOptions): ContinuousReader = { new RateStreamContinuousReader(options) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index d5ac0bd1df52b..3f5bb489d6528 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.streaming.OutputMode @@ -40,7 +40,7 @@ class ConsoleSinkProvider extends DataSourceV2 queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): StreamWriter = { + options: DataSourceOptions): StreamWriter = { new ConsoleWriter(schema, options) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 60f880f9c73b8..9402d7c1dcefd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} @@ -160,7 +160,7 @@ class ContinuousExecution( dataSource.createContinuousReader( java.util.Optional.empty[StructType](), metadataPath, - new DataSourceV2Options(extraReaderOptions.asJava)) + new DataSourceOptions(extraReaderOptions.asJava)) } uniqueSources = continuousSources.distinct @@ -198,7 +198,7 @@ class ContinuousExecution( s"$runId", triggerLogicalPlan.schema, outputMode, - new DataSourceV2Options(extraOptions.asJava)) + new DataSourceOptions(extraOptions.asJava)) val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 61304480f4721..ff028ebc4236a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -23,19 +23,18 @@ import org.json4s.DefaultFormats import org.json4s.jackson.Serialization import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair} import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} +import org.apache.spark.sql.types.StructType case class RateStreamPartitionOffset( partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset -class RateStreamContinuousReader(options: DataSourceV2Options) +class RateStreamContinuousReader(options: DataSourceOptions) extends ContinuousReader { implicit val defaultFormats: DefaultFormats = DefaultFormats diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala index 7c1700f1de48c..d46f4d7b86360 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage} import org.apache.spark.sql.types.StructType /** Common methods used to create writes for the the console sink */ -class ConsoleWriter(schema: StructType, options: DataSourceV2Options) +class ConsoleWriter(schema: StructType, options: DataSourceOptions) extends StreamWriter with Logging { // Number of rows to display, by default 20 rows diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala index d7f3ba8856982..d7ce9a7b84479 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter -import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} /** - * A [[DataSourceV2Writer]] used to hook V2 stream writers into a microbatch plan. It implements + * A [[DataSourceWriter]] used to hook V2 stream writers into a microbatch plan. It implements * the non-streaming interface, forwarding the batch ID determined at construction to a wrapped * streaming writer. */ -class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceV2Writer { +class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceWriter { override def commit(messages: Array[WriterCommitMessage]): Unit = { writer.commit(batchId, messages) } @@ -38,7 +38,7 @@ class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceV2 } class InternalRowMicroBatchWriter(batchId: Long, writer: StreamWriter) - extends DataSourceV2Writer with SupportsWriteInternalRow { + extends DataSourceWriter with SupportsWriteInternalRow { override def commit(messages: Array[WriterCommitMessage]): Unit = { writer.commit(batchId, messages) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala index 9282ba05bdb7b..248295e401a0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -21,11 +21,11 @@ import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.Row -import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage} /** * A simple [[DataWriterFactory]] whose tasks just pack rows into the commit message for delivery - * to a [[org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer]] on the driver. + * to a [[DataSourceWriter]] on the driver. * * Note that, because it sends all rows to the driver, this factory will generally be unsuitable * for production-quality sinks. It's intended for use in tests. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index a25cc4f3b06f8..43949e6180aaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset} @@ -44,14 +44,14 @@ class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with override def createMicroBatchReader( schema: Optional[StructType], checkpointLocation: String, - options: DataSourceV2Options): MicroBatchReader = { + options: DataSourceOptions): MicroBatchReader = { new RateStreamMicroBatchReader(options) } override def shortName(): String = "ratev2" } -class RateStreamMicroBatchReader(options: DataSourceV2Options) +class RateStreamMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader { implicit val defaultFormats: DefaultFormats = DefaultFormats diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index ce55e44d932bd..58767261dc684 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.Sink -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ @@ -45,7 +45,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): StreamWriter = { + options: DataSourceOptions): StreamWriter = { new MemoryStreamWriter(this, mode) } @@ -114,7 +114,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) - extends DataSourceV2Writer with Logging { + extends DataSourceWriter with Logging { override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 9f5ca9f914284..f1b3f93c4e1fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -158,7 +158,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf).newInstance() - val options = new DataSourceV2Options(extraOptions.asJava) + val options = new DataSourceOptions(extraOptions.asJava) // We need to generate the V1 data source so we can pass it to the V2 relation as a shim. // We can't be sure at this point whether we'll actually want to use V2, since we don't know the // writer or whether the query is continuous. diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 4026ee44bfdb7..d421f7d19563f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -24,15 +24,15 @@ import org.apache.spark.sql.catalyst.expressions.GenericRow; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceV2Reader, SupportsPushDownRequiredColumns, + class Reader implements DataSourceReader, SupportsPushDownRequiredColumns, SupportsPushDownFilters { private StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); @@ -131,7 +131,7 @@ public void close() throws IOException { @Override - public DataSourceV2Reader createReader(DataSourceV2Options options) { + public DataSourceReader createReader(DataSourceOptions options) { return new Reader(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java index 34e6c63801064..c55093768105b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -21,8 +21,8 @@ import java.util.List; import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.DataTypes; @@ -33,7 +33,7 @@ public class JavaBatchDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceV2Reader, SupportsScanColumnarBatch { + class Reader implements DataSourceReader, SupportsScanColumnarBatch { private final StructType schema = new StructType().add("i", "int").add("j", "int"); @Override @@ -108,7 +108,7 @@ public void close() throws IOException { @Override - public DataSourceV2Reader createReader(DataSourceV2Options options) { + public DataSourceReader createReader(DataSourceOptions options) { return new Reader(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index d0c87503ab455..99cca0f6dd626 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -23,15 +23,15 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { - class Reader implements DataSourceV2Reader, SupportsReportPartitioning { + class Reader implements DataSourceReader, SupportsReportPartitioning { private final StructType schema = new StructType().add("a", "int").add("b", "int"); @Override @@ -104,7 +104,7 @@ public DataReader createDataReader() { } @Override - public DataSourceV2Reader createReader(DataSourceV2Options options) { + public DataSourceReader createReader(DataSourceOptions options) { return new Reader(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index f997366af1a64..048d078dfaac4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -20,16 +20,16 @@ import java.util.List; import org.apache.spark.sql.Row; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; import org.apache.spark.sql.types.StructType; public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema { - class Reader implements DataSourceV2Reader { + class Reader implements DataSourceReader { private final StructType schema; Reader(StructType schema) { @@ -48,7 +48,7 @@ public List> createDataReaderFactories() { } @Override - public DataSourceV2Reader createReader(StructType schema, DataSourceV2Options options) { + public DataSourceReader createReader(StructType schema, DataSourceOptions options) { return new Reader(schema); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 2beed431d301f..96f55b8a76811 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -23,16 +23,16 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.expressions.GenericRow; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.DataReader; import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.types.StructType; public class JavaSimpleDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceV2Reader { + class Reader implements DataSourceReader { private final StructType schema = new StructType().add("i", "int").add("j", "int"); @Override @@ -80,7 +80,7 @@ public void close() throws IOException { } @Override - public DataSourceV2Reader createReader(DataSourceV2Options options) { + public DataSourceReader createReader(DataSourceOptions options) { return new Reader(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java index e8187524ea871..c3916e0b370b5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java @@ -21,15 +21,15 @@ import java.util.List; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; public class JavaUnsafeRowDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceV2Reader, SupportsScanUnsafeRow { + class Reader implements DataSourceReader, SupportsScanUnsafeRow { private final StructType schema = new StructType().add("i", "int").add("j", "int"); @Override @@ -83,7 +83,7 @@ public void close() throws IOException { } @Override - public DataSourceV2Reader createReader(DataSourceV2Options options) { + public DataSourceReader createReader(DataSourceOptions options) { return new Reader(); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index d2cfe7905f6fa..b060aeeef811d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock @@ -49,7 +49,7 @@ class RateSourceV2Suite extends StreamTest { test("microbatch in registry") { DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance() match { case ds: MicroBatchReadSupport => - val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceV2Options.empty()) + val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceOptions.empty()) assert(reader.isInstanceOf[RateStreamMicroBatchReader]) case _ => throw new IllegalStateException("Could not find v2 read support for rate") @@ -76,14 +76,14 @@ class RateSourceV2Suite extends StreamTest { test("microbatch - numPartitions propagated") { val reader = new RateStreamMicroBatchReader( - new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) reader.setOffsetRange(Optional.empty(), Optional.empty()) val tasks = reader.createDataReaderFactories() assert(tasks.size == 11) } test("microbatch - set offset") { - val reader = new RateStreamMicroBatchReader(DataSourceV2Options.empty()) + val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty()) val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000)))) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) @@ -93,7 +93,7 @@ class RateSourceV2Suite extends StreamTest { test("microbatch - infer offsets") { val reader = new RateStreamMicroBatchReader( - new DataSourceV2Options(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100) reader.setOffsetRange(Optional.empty(), Optional.empty()) reader.getStartOffset() match { @@ -114,7 +114,7 @@ class RateSourceV2Suite extends StreamTest { test("microbatch - predetermined batch size") { val reader = new RateStreamMicroBatchReader( - new DataSourceV2Options(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) @@ -125,7 +125,7 @@ class RateSourceV2Suite extends StreamTest { test("microbatch - data read") { val reader = new RateStreamMicroBatchReader( - new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => @@ -150,7 +150,7 @@ class RateSourceV2Suite extends StreamTest { test("continuous in registry") { DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { case ds: ContinuousReadSupport => - val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceV2Options.empty()) + val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) assert(reader.isInstanceOf[RateStreamContinuousReader]) case _ => throw new IllegalStateException("Could not find v2 read support for rate") @@ -159,7 +159,7 @@ class RateSourceV2Suite extends StreamTest { test("continuous data") { val reader = new RateStreamContinuousReader( - new DataSourceV2Options(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) + new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) reader.setOffset(Optional.empty()) val tasks = reader.createDataReaderFactories() assert(tasks.size == 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala similarity index 80% rename from sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala index 90d92864b26fa..31dfc55b23361 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala @@ -22,24 +22,24 @@ import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite /** - * A simple test suite to verify `DataSourceV2Options`. + * A simple test suite to verify `DataSourceOptions`. */ -class DataSourceV2OptionsSuite extends SparkFunSuite { +class DataSourceOptionsSuite extends SparkFunSuite { test("key is case-insensitive") { - val options = new DataSourceV2Options(Map("foo" -> "bar").asJava) + val options = new DataSourceOptions(Map("foo" -> "bar").asJava) assert(options.get("foo").get() == "bar") assert(options.get("FoO").get() == "bar") assert(!options.get("abc").isPresent) } test("value is case-sensitive") { - val options = new DataSourceV2Options(Map("foo" -> "bAr").asJava) + val options = new DataSourceOptions(Map("foo" -> "bAr").asJava) assert(options.get("foo").get == "bAr") } test("getInt") { - val options = new DataSourceV2Options(Map("numFOo" -> "1", "foo" -> "bar").asJava) + val options = new DataSourceOptions(Map("numFOo" -> "1", "foo" -> "bar").asJava) assert(options.getInt("numFOO", 10) == 1) assert(options.getInt("numFOO2", 10) == 10) @@ -49,7 +49,7 @@ class DataSourceV2OptionsSuite extends SparkFunSuite { } test("getBoolean") { - val options = new DataSourceV2Options( + val options = new DataSourceOptions( Map("isFoo" -> "true", "isFOO2" -> "false", "foo" -> "bar").asJava) assert(options.getBoolean("isFoo", false)) assert(!options.getBoolean("isFoo2", true)) @@ -59,7 +59,7 @@ class DataSourceV2OptionsSuite extends SparkFunSuite { } test("getLong") { - val options = new DataSourceV2Options(Map("numFoo" -> "9223372036854775807", + val options = new DataSourceOptions(Map("numFoo" -> "9223372036854775807", "foo" -> "bar").asJava) assert(options.getLong("numFOO", 0L) == 9223372036854775807L) assert(options.getLong("numFoo2", -1L) == -1L) @@ -70,7 +70,7 @@ class DataSourceV2OptionsSuite extends SparkFunSuite { } test("getDouble") { - val options = new DataSourceV2Options(Map("numFoo" -> "922337.1", + val options = new DataSourceOptions(Map("numFoo" -> "922337.1", "foo" -> "bar").asJava) assert(options.getDouble("numFOO", 0d) == 922337.1d) assert(options.getDouble("numFoo2", -1.02d) == -1.02d) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 42c5d3bcea44b..ee50e8a92270b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -201,7 +201,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceV2Reader { + class Reader extends DataSourceReader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { @@ -209,7 +209,7 @@ class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { } } - override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } class SimpleDataReaderFactory(start: Int, end: Int) @@ -233,7 +233,7 @@ class SimpleDataReaderFactory(start: Int, end: Int) class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceV2Reader + class Reader extends DataSourceReader with SupportsPushDownRequiredColumns with SupportsPushDownFilters { var requiredSchema = new StructType().add("i", "int").add("j", "int") @@ -275,7 +275,7 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { } } - override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } class AdvancedDataReaderFactory(start: Int, end: Int, requiredSchema: StructType) @@ -306,7 +306,7 @@ class AdvancedDataReaderFactory(start: Int, end: Int, requiredSchema: StructType class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceV2Reader with SupportsScanUnsafeRow { + class Reader extends DataSourceReader with SupportsScanUnsafeRow { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") override def createUnsafeRowReaderFactories(): JList[DataReaderFactory[UnsafeRow]] = { @@ -315,7 +315,7 @@ class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { } } - override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } class UnsafeRowDataReaderFactory(start: Int, end: Int) @@ -343,18 +343,18 @@ class UnsafeRowDataReaderFactory(start: Int, end: Int) class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { - class Reader(val readSchema: StructType) extends DataSourceV2Reader { + class Reader(val readSchema: StructType) extends DataSourceReader { override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = java.util.Collections.emptyList() } - override def createReader(schema: StructType, options: DataSourceV2Options): DataSourceV2Reader = + override def createReader(schema: StructType, options: DataSourceOptions): DataSourceReader = new Reader(schema) } class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceV2Reader with SupportsScanColumnarBatch { + class Reader extends DataSourceReader with SupportsScanColumnarBatch { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") override def createBatchDataReaderFactories(): JList[DataReaderFactory[ColumnarBatch]] = { @@ -362,7 +362,7 @@ class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { } } - override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } class BatchDataReaderFactory(start: Int, end: Int) @@ -406,7 +406,7 @@ class BatchDataReaderFactory(start: Int, end: Int) class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceV2Reader with SupportsReportPartitioning { + class Reader extends DataSourceReader with SupportsReportPartitioning { override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { @@ -428,7 +428,7 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { } } - override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } class SpecificDataReaderFactory(i: Array[Int], j: Array[Int]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 3310d6dd199d6..a131b16953e3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceV2Reader} +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceReader} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -42,7 +42,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS private val schema = new StructType().add("i", "long").add("j", "long") - class Reader(path: String, conf: Configuration) extends DataSourceV2Reader { + class Reader(path: String, conf: Configuration) extends DataSourceReader { override def readSchema(): StructType = schema override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { @@ -64,7 +64,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } } - class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceV2Writer { + class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter { override def createWriterFactory(): DataWriterFactory[Row] = { new SimpleCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) } @@ -104,7 +104,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } } - override def createReader(options: DataSourceV2Options): DataSourceV2Reader = { + override def createReader(options: DataSourceOptions): DataSourceReader = { val path = new Path(options.get("path").get()) val conf = SparkContext.getActive.get.hadoopConfiguration new Reader(path.toUri.toString, conf) @@ -114,7 +114,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS jobId: String, schema: StructType, mode: SaveMode, - options: DataSourceV2Options): Optional[DataSourceV2Writer] = { + options: DataSourceOptions): Optional[DataSourceWriter] = { assert(DataType.equalsStructurally(schema.asNullable, this.schema.asNullable)) assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false)) @@ -141,7 +141,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } private def createWriter( - jobId: String, path: Path, conf: Configuration, internal: Boolean): DataSourceV2Writer = { + jobId: String, path: Path, conf: Configuration, internal: Boolean): DataSourceWriter = { val pathStr = path.toUri.toString if (internal) { new InternalRowWriter(jobId, pathStr, conf) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index dc8c857018457..3127d664d32dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, Streami import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader.DataReaderFactory import org.apache.spark.sql.sources.v2.streaming._ import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} @@ -54,14 +54,14 @@ trait FakeMicroBatchReadSupport extends MicroBatchReadSupport { override def createMicroBatchReader( schema: Optional[StructType], checkpointLocation: String, - options: DataSourceV2Options): MicroBatchReader = FakeReader() + options: DataSourceOptions): MicroBatchReader = FakeReader() } trait FakeContinuousReadSupport extends ContinuousReadSupport { override def createContinuousReader( schema: Optional[StructType], checkpointLocation: String, - options: DataSourceV2Options): ContinuousReader = FakeReader() + options: DataSourceOptions): ContinuousReader = FakeReader() } trait FakeStreamWriteSupport extends StreamWriteSupport { @@ -69,7 +69,7 @@ trait FakeStreamWriteSupport extends StreamWriteSupport { queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): StreamWriter = { + options: DataSourceOptions): StreamWriter = { throw new IllegalStateException("fake sink - cannot actually write") } } From 7a2ada223e14d09271a76091be0338b2d375081e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 30 Jan 2018 21:55:55 +0900 Subject: [PATCH 0238/2461] [SPARK-23261][PYSPARK] Rename Pandas UDFs ## What changes were proposed in this pull request? Rename the public APIs and names of pandas udfs. - `PANDAS SCALAR UDF` -> `SCALAR PANDAS UDF` - `PANDAS GROUP MAP UDF` -> `GROUPED MAP PANDAS UDF` - `PANDAS GROUP AGG UDF` -> `GROUPED AGG PANDAS UDF` ## How was this patch tested? The existing tests Author: gatorsmile Closes #20428 from gatorsmile/renamePandasUDFs. --- .../spark/api/python/PythonRunner.scala | 12 +-- docs/sql-programming-guide.md | 8 +- examples/src/main/python/sql/arrow.py | 12 +-- python/pyspark/rdd.py | 6 +- python/pyspark/sql/functions.py | 34 +++---- python/pyspark/sql/group.py | 10 +- python/pyspark/sql/tests.py | 92 +++++++++---------- python/pyspark/sql/udf.py | 25 ++--- python/pyspark/worker.py | 24 ++--- .../sql/catalyst/expressions/PythonUDF.scala | 4 +- .../sql/catalyst/planning/patterns.scala | 1 - .../spark/sql/RelationalGroupedDataset.scala | 4 +- .../python/AggregateInPandasExec.scala | 2 +- .../python/ArrowEvalPythonExec.scala | 2 +- .../execution/python/ExtractPythonUDFs.scala | 2 +- .../python/FlatMapGroupsInPandasExec.scala | 2 +- 16 files changed, 120 insertions(+), 120 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 29148a7ee558b..f075a7e0eb0b4 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -37,16 +37,16 @@ private[spark] object PythonEvalType { val SQL_BATCHED_UDF = 100 - val SQL_PANDAS_SCALAR_UDF = 200 - val SQL_PANDAS_GROUP_MAP_UDF = 201 - val SQL_PANDAS_GROUP_AGG_UDF = 202 + val SQL_SCALAR_PANDAS_UDF = 200 + val SQL_GROUPED_MAP_PANDAS_UDF = 201 + val SQL_GROUPED_AGG_PANDAS_UDF = 202 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" case SQL_BATCHED_UDF => "SQL_BATCHED_UDF" - case SQL_PANDAS_SCALAR_UDF => "SQL_PANDAS_SCALAR_UDF" - case SQL_PANDAS_GROUP_MAP_UDF => "SQL_PANDAS_GROUP_MAP_UDF" - case SQL_PANDAS_GROUP_AGG_UDF => "SQL_PANDAS_GROUP_AGG_UDF" + case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF" + case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF" + case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF" } } diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d49c8d869cba6..a0e221b39cc34 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1684,7 +1684,7 @@ Spark will fall back to create the DataFrame without Arrow. Pandas UDFs are user defined functions that are executed by Spark using Arrow to transfer data and Pandas to work with the data. A Pandas UDF is defined using the keyword `pandas_udf` as a decorator or to wrap the function, no additional configuration is required. Currently, there are two types of -Pandas UDF: Scalar and Group Map. +Pandas UDF: Scalar and Grouped Map. ### Scalar @@ -1702,8 +1702,8 @@ The following example shows how to create a scalar Pandas UDF that computes the
-### Group Map -Group map Pandas UDFs are used with `groupBy().apply()` which implements the "split-apply-combine" pattern. +### Grouped Map +Grouped map Pandas UDFs are used with `groupBy().apply()` which implements the "split-apply-combine" pattern. Split-apply-combine consists of three steps: * Split the data into groups by using `DataFrame.groupBy`. * Apply a function on each group. The input and output of the function are both `pandas.DataFrame`. The @@ -1723,7 +1723,7 @@ The following example shows how to use `groupby().apply()` to subtract the mean
-{% include_example group_map_pandas_udf python/sql/arrow.py %} +{% include_example grouped_map_pandas_udf python/sql/arrow.py %}
diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py index 6c0028b3f1c1f..4c5aefb6ff4a6 100644 --- a/examples/src/main/python/sql/arrow.py +++ b/examples/src/main/python/sql/arrow.py @@ -86,15 +86,15 @@ def multiply_func(a, b): # $example off:scalar_pandas_udf$ -def group_map_pandas_udf_example(spark): - # $example on:group_map_pandas_udf$ +def grouped_map_pandas_udf_example(spark): + # $example on:grouped_map_pandas_udf$ from pyspark.sql.functions import pandas_udf, PandasUDFType df = spark.createDataFrame( [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")) - @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) + @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) def substract_mean(pdf): # pdf is a pandas.DataFrame v = pdf.v @@ -110,7 +110,7 @@ def substract_mean(pdf): # | 2|-1.0| # | 2| 4.0| # +---+----+ - # $example off:group_map_pandas_udf$ + # $example off:grouped_map_pandas_udf$ if __name__ == "__main__": @@ -123,7 +123,7 @@ def substract_mean(pdf): dataframe_with_arrow_example(spark) print("Running pandas_udf scalar example") scalar_pandas_udf_example(spark) - print("Running pandas_udf group map example") - group_map_pandas_udf_example(spark) + print("Running pandas_udf grouped map example") + grouped_map_pandas_udf_example(spark) spark.stop() diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 6b018c3a38444..93b8974a7e64a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -68,9 +68,9 @@ class PythonEvalType(object): SQL_BATCHED_UDF = 100 - SQL_PANDAS_SCALAR_UDF = 200 - SQL_PANDAS_GROUP_MAP_UDF = 201 - SQL_PANDAS_GROUP_AGG_UDF = 202 + SQL_SCALAR_PANDAS_UDF = 200 + SQL_GROUPED_MAP_PANDAS_UDF = 201 + SQL_GROUPED_AGG_PANDAS_UDF = 202 def portable_hash(x): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a291c9b71913f..3c8fb4c4d19e7 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1737,8 +1737,8 @@ def translate(srcCol, matching, replace): def create_map(*cols): """Creates a new map column. - :param cols: list of column names (string) or list of :class:`Column` expressions that grouped - as key-value pairs, e.g. (key1, value1, key2, value2, ...). + :param cols: list of column names (string) or list of :class:`Column` expressions that are + grouped as key-value pairs, e.g. (key1, value1, key2, value2, ...). >>> df.select(create_map('name', 'age').alias("map")).collect() [Row(map={u'Alice': 2}), Row(map={u'Bob': 5})] @@ -2085,11 +2085,11 @@ def map_values(col): class PandasUDFType(object): """Pandas UDF Types. See :meth:`pyspark.sql.functions.pandas_udf`. """ - SCALAR = PythonEvalType.SQL_PANDAS_SCALAR_UDF + SCALAR = PythonEvalType.SQL_SCALAR_PANDAS_UDF - GROUP_MAP = PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF + GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF - GROUP_AGG = PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF + GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF @since(1.3) @@ -2193,20 +2193,20 @@ def pandas_udf(f=None, returnType=None, functionType=None): Therefore, this can be used, for example, to ensure the length of each returned `pandas.Series`, and can not be used as the column length. - 2. GROUP_MAP + 2. GROUPED_MAP - A group map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` + A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` The returnType should be a :class:`StructType` describing the schema of the returned `pandas.DataFrame`. The length of the returned `pandas.DataFrame` can be arbitrary. - Group map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. + Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) # doctest: +SKIP - >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) # doctest: +SKIP + >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP ... def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) @@ -2223,9 +2223,9 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. seealso:: :meth:`pyspark.sql.GroupedData.apply` - 3. GROUP_AGG + 3. GROUPED_AGG - A group aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar + A grouped aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar The `returnType` should be a primitive data type, e.g., :class:`DoubleType`. The returned scalar can be either a python primitive type, e.g., `int` or `float` or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. @@ -2239,7 +2239,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) - >>> @pandas_udf("double", PandasUDFType.GROUP_AGG) # doctest: +SKIP + >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP ... def mean_udf(v): ... return v.mean() >>> df.groupby("id").agg(mean_udf(df['v'])).show() # doctest: +SKIP @@ -2285,21 +2285,21 @@ def pandas_udf(f=None, returnType=None, functionType=None): eval_type = returnType else: # @pandas_udf(dataType) or @pandas_udf(returnType=dataType) - eval_type = PythonEvalType.SQL_PANDAS_SCALAR_UDF + eval_type = PythonEvalType.SQL_SCALAR_PANDAS_UDF else: return_type = returnType if functionType is not None: eval_type = functionType else: - eval_type = PythonEvalType.SQL_PANDAS_SCALAR_UDF + eval_type = PythonEvalType.SQL_SCALAR_PANDAS_UDF if return_type is None: raise ValueError("Invalid returnType: returnType can not be None") - if eval_type not in [PythonEvalType.SQL_PANDAS_SCALAR_UDF, - PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, - PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF]: + if eval_type not in [PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF]: raise ValueError("Invalid functionType: " "functionType must be one the values from PandasUDFType") diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index f90a909d7c2b1..ab646535c864c 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -98,7 +98,7 @@ def agg(self, *exprs): [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> @pandas_udf('int', PandasUDFType.GROUP_AGG) # doctest: +SKIP + >>> @pandas_udf('int', PandasUDFType.GROUPED_AGG) # doctest: +SKIP ... def min_udf(v): ... return v.min() >>> sorted(gdf.agg(min_udf(df.age)).collect()) # doctest: +SKIP @@ -235,14 +235,14 @@ def apply(self, udf): into memory, so the user should be aware of the potential OOM risk if data is skewed and certain groups are too large to fit in memory. - :param udf: a group map user-defined function returned by + :param udf: a grouped map user-defined function returned by :func:`pyspark.sql.functions.pandas_udf`. >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) - >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) # doctest: +SKIP + >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP ... def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) @@ -262,9 +262,9 @@ def apply(self, udf): """ # Columns are special because hasattr always return True if isinstance(udf, Column) or not hasattr(udf, 'func') \ - or udf.evalType != PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + or udf.evalType != PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type " - "GROUP_MAP.") + "GROUPED_MAP.") df = self._df udf_column = udf(*[df[col] for col in df.columns]) jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ca7bbf8ffe71c..dc80870d3cd9f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3621,34 +3621,34 @@ def test_pandas_udf_basic(self): udf = pandas_udf(lambda x: x, DoubleType()) self.assertEqual(udf.returnType, DoubleType()) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) udf = pandas_udf(lambda x: x, DoubleType(), PandasUDFType.SCALAR) self.assertEqual(udf.returnType, DoubleType()) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) udf = pandas_udf(lambda x: x, 'double', PandasUDFType.SCALAR) self.assertEqual(udf.returnType, DoubleType()) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) udf = pandas_udf(lambda x: x, StructType([StructField("v", DoubleType())]), - PandasUDFType.GROUP_MAP) + PandasUDFType.GROUPED_MAP) self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUP_MAP) + udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP) self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) udf = pandas_udf(lambda x: x, 'v double', - functionType=PandasUDFType.GROUP_MAP) + functionType=PandasUDFType.GROUPED_MAP) self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) udf = pandas_udf(lambda x: x, returnType='v double', - functionType=PandasUDFType.GROUP_MAP) + functionType=PandasUDFType.GROUPED_MAP) self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) def test_pandas_udf_decorator(self): from pyspark.rdd import PythonEvalType @@ -3659,45 +3659,45 @@ def test_pandas_udf_decorator(self): def foo(x): return x self.assertEqual(foo.returnType, DoubleType()) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) @pandas_udf(returnType=DoubleType()) def foo(x): return x self.assertEqual(foo.returnType, DoubleType()) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) schema = StructType([StructField("v", DoubleType())]) - @pandas_udf(schema, PandasUDFType.GROUP_MAP) + @pandas_udf(schema, PandasUDFType.GROUPED_MAP) def foo(x): return x self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - @pandas_udf('v double', PandasUDFType.GROUP_MAP) + @pandas_udf('v double', PandasUDFType.GROUPED_MAP) def foo(x): return x self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - @pandas_udf(schema, functionType=PandasUDFType.GROUP_MAP) + @pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP) def foo(x): return x self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) @pandas_udf(returnType='v double', functionType=PandasUDFType.SCALAR) def foo(x): return x self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) - @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUP_MAP) + @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP) def foo(x): return x self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) def test_udf_wrong_arg(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -3724,15 +3724,15 @@ def zero_with_type(): return 1 with self.assertRaisesRegexp(TypeError, 'Invalid returnType'): - @pandas_udf(returnType=PandasUDFType.GROUP_MAP) + @pandas_udf(returnType=PandasUDFType.GROUPED_MAP) def foo(df): return df with self.assertRaisesRegexp(ValueError, 'Invalid returnType'): - @pandas_udf(returnType='double', functionType=PandasUDFType.GROUP_MAP) + @pandas_udf(returnType='double', functionType=PandasUDFType.GROUPED_MAP) def foo(df): return df with self.assertRaisesRegexp(ValueError, 'Invalid function'): - @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUP_MAP) + @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP) def foo(k, v): return k @@ -3804,11 +3804,11 @@ def test_register_nondeterministic_vectorized_udf_basic(self): random_pandas_udf = pandas_udf( lambda x: random.randint(6, 6) + x, IntegerType()).asNondeterministic() self.assertEqual(random_pandas_udf.deterministic, False) - self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) nondeterministic_pandas_udf = self.spark.catalog.registerFunction( "randomPandasUDF", random_pandas_udf) self.assertEqual(nondeterministic_pandas_udf.deterministic, False) - self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) [row] = self.spark.sql("SELECT randomPandasUDF(1)").collect() self.assertEqual(row[0], 7) @@ -4206,7 +4206,7 @@ def test_register_vectorized_udf_basic(self): col('id').cast('int').alias('b')) original_add = pandas_udf(lambda x, y: x + y, IntegerType()) self.assertEqual(original_add.deterministic, True) - self.assertEqual(original_add.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(original_add.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) new_add = self.spark.catalog.registerFunction("add1", original_add) res1 = df.select(new_add(col('a'), col('b'))) res2 = self.spark.sql( @@ -4237,20 +4237,20 @@ def test_simple(self): StructField('v', IntegerType()), StructField('v1', DoubleType()), StructField('v2', LongType())]), - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) self.assertPandasEqual(expected, result) - def test_register_group_map_udf(self): + def test_register_grouped_map_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType - foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUP_MAP) + foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP) with QuietTest(self.sc): with self.assertRaisesRegexp(ValueError, 'f must be either SQL_BATCHED_UDF or ' - 'SQL_PANDAS_SCALAR_UDF'): + 'SQL_SCALAR_PANDAS_UDF'): self.spark.catalog.registerFunction("foo_udf", foo_udf) def test_decorator(self): @@ -4259,7 +4259,7 @@ def test_decorator(self): @pandas_udf( 'id long, v int, v1 double, v2 long', - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) def foo(pdf): return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id) @@ -4275,7 +4275,7 @@ def test_coerce(self): foo = pandas_udf( lambda pdf: pdf, 'id long, v double', - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) result = df.groupby('id').apply(foo).sort('id').toPandas() @@ -4289,7 +4289,7 @@ def test_complex_groupby(self): @pandas_udf( 'id long, v int, norm double', - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) def normalize(pdf): v = pdf.v @@ -4308,7 +4308,7 @@ def test_empty_groupby(self): @pandas_udf( 'id long, v int, norm double', - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) def normalize(pdf): v = pdf.v @@ -4328,7 +4328,7 @@ def test_datatype_string(self): foo_udf = pandas_udf( lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), 'id long, v int, v1 double, v2 long', - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) result = df.groupby('id').apply(foo_udf).sort('id').toPandas() @@ -4342,7 +4342,7 @@ def test_wrong_return_type(self): foo = pandas_udf( lambda pdf: pdf, 'id long, v map', - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) with QuietTest(self.sc): @@ -4368,7 +4368,7 @@ def test_wrong_args(self): with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply( pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]))) - with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUP_MAP'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUPED_MAP'): df.groupby('id').apply( pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]), PandasUDFType.SCALAR)) @@ -4379,7 +4379,7 @@ def test_unsupported_types(self): [StructField("id", LongType(), True), StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(1, None,)], schema=schema) - f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUP_MAP) + f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUPED_MAP) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): df.groupby('id').apply(f).collect() @@ -4422,7 +4422,7 @@ def plus_two(v): def pandas_agg_mean_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType - @pandas_udf('double', PandasUDFType.GROUP_AGG) + @pandas_udf('double', PandasUDFType.GROUPED_AGG) def avg(v): return v.mean() return avg @@ -4431,7 +4431,7 @@ def avg(v): def pandas_agg_sum_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType - @pandas_udf('double', PandasUDFType.GROUP_AGG) + @pandas_udf('double', PandasUDFType.GROUPED_AGG) def sum(v): return v.sum() return sum @@ -4441,7 +4441,7 @@ def pandas_agg_weighted_mean_udf(self): import numpy as np from pyspark.sql.functions import pandas_udf, PandasUDFType - @pandas_udf('double', PandasUDFType.GROUP_AGG) + @pandas_udf('double', PandasUDFType.GROUPED_AGG) def weighted_mean(v, w): return np.average(v, weights=w) return weighted_mean @@ -4505,19 +4505,19 @@ def test_unsupported_types(self): with QuietTest(self.sc): with self.assertRaisesRegex(NotImplementedError, 'not supported'): - @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUP_AGG) + @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUPED_AGG) def mean_and_std_udf(v): return [v.mean(), v.std()] with QuietTest(self.sc): with self.assertRaisesRegex(NotImplementedError, 'not supported'): - @pandas_udf('mean double, std double', PandasUDFType.GROUP_AGG) + @pandas_udf('mean double, std double', PandasUDFType.GROUPED_AGG) def mean_and_std_udf(v): return v.mean(), v.std() with QuietTest(self.sc): with self.assertRaisesRegex(NotImplementedError, 'not supported'): - @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUP_AGG) + @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUPED_AGG) def mean_and_std_udf(v): return {v.mean(): v.std()} diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 4f303304e5600..0f759c448b8a7 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -37,9 +37,9 @@ def _wrap_function(sc, func, returnType): def _create_udf(f, returnType, evalType): - if evalType in (PythonEvalType.SQL_PANDAS_SCALAR_UDF, - PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, - PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF): + if evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): import inspect from pyspark.sql.utils import require_minimum_pyarrow_version @@ -47,16 +47,16 @@ def _create_udf(f, returnType, evalType): require_minimum_pyarrow_version() argspec = inspect.getargspec(f) - if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF and len(argspec.args) == 0 and \ + if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \ argspec.varargs is None: raise ValueError( "Invalid function: 0-arg pandas_udfs are not supported. " "Instead, create a 1-arg pandas_udf and ignore the arg in your function." ) - if evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF and len(argspec.args) != 1: + if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF and len(argspec.args) != 1: raise ValueError( - "Invalid function: pandas_udfs with function type GROUP_MAP " + "Invalid function: pandas_udfs with function type GROUPED_MAP " "must take a single arg that is a pandas DataFrame." ) @@ -112,14 +112,15 @@ def returnType(self): else: self._returnType_placeholder = _parse_datatype_string(self._returnType) - if self.evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF \ + if self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \ and not isinstance(self._returnType_placeholder, StructType): raise ValueError("Invalid returnType: returnType must be a StructType for " - "pandas_udf with function type GROUP_MAP") - elif self.evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF \ + "pandas_udf with function type GROUPED_MAP") + elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF \ and isinstance(self._returnType_placeholder, (StructType, ArrayType, MapType)): raise NotImplementedError( - "ArrayType, StructType and MapType are not supported with PandasUDFType.GROUP_AGG") + "ArrayType, StructType and MapType are not supported with " + "PandasUDFType.GROUPED_AGG") return self._returnType_placeholder @@ -292,9 +293,9 @@ def register(self, name, f, returnType=None): "Invalid returnType: data type can not be specified when f is" "a user-defined function, but got %s." % returnType) if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, - PythonEvalType.SQL_PANDAS_SCALAR_UDF]: + PythonEvalType.SQL_SCALAR_PANDAS_UDF]: raise ValueError( - "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF") + "Invalid f: f must be either SQL_BATCHED_UDF or SQL_SCALAR_PANDAS_UDF") register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, evalType=f.evalType, deterministic=f.deterministic) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 173d8fb2856fa..121b3dd1aeec9 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -74,7 +74,7 @@ def wrap_udf(f, return_type): return lambda *a: f(*a) -def wrap_pandas_scalar_udf(f, return_type): +def wrap_scalar_pandas_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) def verify_result_length(*a): @@ -90,7 +90,7 @@ def verify_result_length(*a): return lambda *a: (verify_result_length(*a), arrow_return_type) -def wrap_pandas_group_map_udf(f, return_type): +def wrap_grouped_map_pandas_udf(f, return_type): def wrapped(*series): import pandas as pd @@ -110,7 +110,7 @@ def wrapped(*series): return wrapped -def wrap_pandas_group_agg_udf(f, return_type): +def wrap_grouped_agg_pandas_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) def wrapped(*series): @@ -133,12 +133,12 @@ def read_single_udf(pickleSer, infile, eval_type): row_func = chain(row_func, f) # the last returnType will be the return type of UDF - if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF: - return arg_offsets, wrap_pandas_scalar_udf(row_func, return_type) - elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: - return arg_offsets, wrap_pandas_group_map_udf(row_func, return_type) - elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF: - return arg_offsets, wrap_pandas_group_agg_udf(row_func, return_type) + if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF: + return arg_offsets, wrap_scalar_pandas_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + return arg_offsets, wrap_grouped_map_pandas_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: + return arg_offsets, wrap_grouped_agg_pandas_udf(row_func, return_type) elif eval_type == PythonEvalType.SQL_BATCHED_UDF: return arg_offsets, wrap_udf(row_func, return_type) else: @@ -163,9 +163,9 @@ def read_udfs(pickleSer, infile, eval_type): func = lambda _, it: map(mapper, it) - if eval_type in (PythonEvalType.SQL_PANDAS_SCALAR_UDF, - PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, - PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF): + if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): timezone = utf8_deserializer.loads(infile) ser = ArrowStreamPandasSerializer(timezone) else: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index 4ba8ff6e3802f..efd664dde725a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.DataType object PythonUDF { private[this] val SCALAR_TYPES = Set( PythonEvalType.SQL_BATCHED_UDF, - PythonEvalType.SQL_PANDAS_SCALAR_UDF + PythonEvalType.SQL_SCALAR_PANDAS_UDF ) def isScalarPythonUDF(e: Expression): Boolean = { @@ -36,7 +36,7 @@ object PythonUDF { def isGroupAggPandasUDF(e: Expression): Boolean = { e.isInstanceOf[PythonUDF] && - e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF + e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 132241061d510..626f905707191 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.planning -import org.apache.spark.api.python.PythonEvalType import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index d320c1c359411..7147798d99533 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -449,8 +449,8 @@ class RelationalGroupedDataset protected[sql]( * workers. */ private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { - require(expr.evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, - "Must pass a group map udf") + require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + "Must pass a grouped map udf") require(expr.dataType.isInstanceOf[StructType], "The returnType of the udf must be a StructType") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 18e5f8605c60d..8e01e8e56a5bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -136,7 +136,7 @@ case class AggregateInPandasExec( val columnarBatchIter = new ArrowPythonRunner( pyFuncs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF, argOffsets, aggInputSchema, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, argOffsets, aggInputSchema, sessionLocalTimeZone, pandasRespectSessionTimeZone) .compute(projectedRowIter, context.partitionId(), context) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 47b146f076b62..c4de214679ae4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -81,7 +81,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val columnarBatchIter = new ArrowPythonRunner( funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_SCALAR_UDF, argOffsets, schema, + PythonEvalType.SQL_SCALAR_PANDAS_UDF, argOffsets, schema, sessionLocalTimeZone, pandasRespectSessionTimeZone) .compute(batchIter, context.partitionId(), context) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 4ae4e164830be..9d56f48249982 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -160,7 +160,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { } val evaluation = validUdfs.partition( - _.evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF + _.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF ) match { case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty => ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 59db66bd7adf1..c798fe5a92c54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -96,7 +96,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, argOffsets, schema, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, schema, sessionLocalTimeZone, pandasRespectSessionTimeZone) .compute(grouped, context.partitionId(), context) From 84bcf9dc88ffeae6fba4cfad9455ad75bed6e6f6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 30 Jan 2018 21:00:29 +0800 Subject: [PATCH 0239/2461] [SPARK-23222][SQL] Make DataFrameRangeSuite not flaky ## What changes were proposed in this pull request? It is reported that the test `Cancelling stage in a query with Range` in `DataFrameRangeSuite` fails a few times in unrelated PRs. I personally also saw it too in my PR. This test is not very flaky actually but only fails occasionally. Based on how the test works, I guess that is because `range` finishes before the listener calls `cancelStage`. I increase the range number from `1000000000L` to `100000000000L` and count the range in one partition. I also reduce the `interval` of checking stage id. Hopefully it can make the test not flaky anymore. ## How was this patch tested? The modified tests. Author: Liang-Chi Hsieh Closes #20431 from viirya/SPARK-23222. --- .../scala/org/apache/spark/sql/DataFrameRangeSuite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index 45afbd29d1907..57a930dfaf320 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -154,7 +154,7 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall test("Cancelling stage in a query with Range.") { val listener = new SparkListener { override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - eventually(timeout(10.seconds)) { + eventually(timeout(10.seconds), interval(1.millis)) { assert(DataFrameRangeSuite.stageToKill > 0) } sparkContext.cancelStage(DataFrameRangeSuite.stageToKill) @@ -166,7 +166,7 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { DataFrameRangeSuite.stageToKill = -1 val ex = intercept[SparkException] { - spark.range(1000000000L).map { x => + spark.range(0, 100000000000L, 1, 1).map { x => DataFrameRangeSuite.stageToKill = TaskContext.get().stageId() x }.toDF("id").agg(sum("id")).collect() @@ -184,6 +184,7 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) } } + sparkContext.removeSparkListener(listener) } test("SPARK-20430 Initialize Range parameters in a driver side") { From a23187f53037425c61f1180b5e7990a116f86a42 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 31 Jan 2018 00:51:00 +0900 Subject: [PATCH 0240/2461] [SPARK-23174][BUILD][PYTHON][FOLLOWUP] Add pycodestyle*.py to .gitignore file. ## What changes were proposed in this pull request? This is a follow-up pr of #20338 which changed the downloaded file name of the python code style checker but it's not contained in .gitignore file so the file remains as an untracked file for git after running the checker. This pr adds the file name to .gitignore file. ## How was this patch tested? Tested manually. Author: Takuya UESHIN Closes #20432 from ueshin/issues/SPARK-23174/fup1. --- dev/.gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/dev/.gitignore b/dev/.gitignore index 4a6027429e0d3..c673922f36d23 100644 --- a/dev/.gitignore +++ b/dev/.gitignore @@ -1 +1,2 @@ pep8*.py +pycodestyle*.py From 31c00ad8b090d7eddc4622e73dc4440cd32624de Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 30 Jan 2018 11:33:30 -0800 Subject: [PATCH 0241/2461] [SPARK-23267][SQL] Increase spark.sql.codegen.hugeMethodLimit to 65535 ## What changes were proposed in this pull request? Still saw the performance regression introduced by `spark.sql.codegen.hugeMethodLimit` in our internal workloads. There are two major issues in the current solution. - The size of the complied byte code is not identical to the bytecode size of the method. The detection is still not accurate. - The bytecode size of a single operator (e.g., `SerializeFromObject`) could still exceed 8K limit. We saw the performance regression in such scenario. Since it is close to the release of 2.3, we decide to increase it to 64K for avoiding the perf regression. ## How was this patch tested? N/A Author: gatorsmile Closes #20434 from gatorsmile/revertConf. --- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 11 ++++++----- .../spark/sql/execution/WholeStageCodegenSuite.scala | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 54a35594f505e..7394a0d7cf983 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -660,12 +660,13 @@ object SQLConf { val WHOLESTAGE_HUGE_METHOD_LIMIT = buildConf("spark.sql.codegen.hugeMethodLimit") .internal() .doc("The maximum bytecode size of a single compiled Java function generated by whole-stage " + - "codegen. When the compiled function exceeds this threshold, " + - "the whole-stage codegen is deactivated for this subtree of the current query plan. " + - s"The default value is ${CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT} and " + - "this is a limit in the OpenJDK JVM implementation.") + "codegen. When the compiled function exceeds this threshold, the whole-stage codegen is " + + "deactivated for this subtree of the current query plan. The default value is 65535, which " + + "is the largest bytecode size possible for a valid Java method. When running on HotSpot, " + + s"it may be preferable to set the value to ${CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT} " + + "to match HotSpot's implementation.") .intConf - .createWithDefault(CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT) + .createWithDefault(65535) val WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR = buildConf("spark.sql.codegen.splitConsumeFuncByOperator") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 28ad712feaae6..6e8d5a70d5a8f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -202,7 +202,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 } - test("SPARK-21871 check if we can get large code size when compiling too long functions") { + ignore("SPARK-21871 check if we can get large code size when compiling too long functions") { val codeWithShortFunctions = genGroupByCode(3) val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions) assert(maxCodeSize1 < SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) @@ -211,7 +211,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) } - test("bytecode of batch file scan exceeds the limit of WHOLESTAGE_HUGE_METHOD_LIMIT") { + ignore("bytecode of batch file scan exceeds the limit of WHOLESTAGE_HUGE_METHOD_LIMIT") { import testImplicits._ withTempPath { dir => val path = dir.getCanonicalPath From 58fcb5a95ee0b91300138cd23f3ce2165fab597f Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 30 Jan 2018 14:11:06 -0800 Subject: [PATCH 0242/2461] [SPARK-23275][SQL] hive/tests have been failing when run locally on the laptop (Mac) with OOM ## What changes were proposed in this pull request? hive tests have been failing when they are run locally (Mac Os) after a recent change in the trunk. After running the tests for some time, the test fails with OOM with Error: unable to create new native thread. I noticed the thread count goes all the way up to 2000+ after which we start getting these OOM errors. Most of the threads seem to be related to the connection pool in hive metastore (BoneCP-xxxxx-xxxx ). This behaviour change is happening after we made the following change to HiveClientImpl.reset() ``` SQL def reset(): Unit = withHiveState { try { // code } finally { runSqlHive("USE default") ===> this is causing the issue } ``` I am proposing to temporarily back-out part of a fix made to address SPARK-23000 to resolve this issue while we work-out the exact reason for this sudden increase in thread counts. ## How was this patch tested? Ran hive/test multiple times in different machines. (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Dilip Biswal Closes #20441 from dilipbiswal/hive_tests. --- .../sql/hive/client/HiveClientImpl.scala | 26 ++++++++----------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 39d839059be75..6c0f4144992ae 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -825,23 +825,19 @@ private[hive] class HiveClientImpl( } def reset(): Unit = withHiveState { - try { - client.getAllTables("default").asScala.foreach { t => - logDebug(s"Deleting table $t") - val table = client.getTable("default", t) - client.getIndexes("default", t, 255).asScala.foreach { index => - shim.dropIndex(client, "default", t, index.getIndexName) - } - if (!table.isIndexTable) { - client.dropTable("default", t) - } + client.getAllTables("default").asScala.foreach { t => + logDebug(s"Deleting table $t") + val table = client.getTable("default", t) + client.getIndexes("default", t, 255).asScala.foreach { index => + shim.dropIndex(client, "default", t, index.getIndexName) } - client.getAllDatabases.asScala.filterNot(_ == "default").foreach { db => - logDebug(s"Dropping Database: $db") - client.dropDatabase(db, true, false, true) + if (!table.isIndexTable) { + client.dropTable("default", t) } - } finally { - runSqlHive("USE default") + } + client.getAllDatabases.asScala.filterNot(_ == "default").foreach { db => + logDebug(s"Dropping Database: $db") + client.dropDatabase(db, true, false, true) } } } From 9623a98248837da302ba4ec240335d1c4268ee21 Mon Sep 17 00:00:00 2001 From: Shashwat Anand Date: Wed, 31 Jan 2018 07:37:25 +0900 Subject: [PATCH 0243/2461] [MINOR] Fix typos in dev/* scripts. ## What changes were proposed in this pull request? Consistency in style, grammar and removal of extraneous characters. ## How was this patch tested? Manually as this is a doc change. Author: Shashwat Anand Closes #20436 from ashashwat/SPARK-23174. --- dev/appveyor-guide.md | 6 +++--- dev/lint-python | 12 ++++++------ dev/run-pip-tests | 4 ++-- dev/run-tests-jenkins | 2 +- dev/sparktestsupport/modules.py | 8 ++++---- dev/sparktestsupport/toposort.py | 6 +++--- dev/tests/pr_merge_ability.sh | 4 ++-- dev/tests/pr_public_classes.sh | 4 ++-- 8 files changed, 23 insertions(+), 23 deletions(-) diff --git a/dev/appveyor-guide.md b/dev/appveyor-guide.md index d2e00b484727d..a842f39b3049a 100644 --- a/dev/appveyor-guide.md +++ b/dev/appveyor-guide.md @@ -1,6 +1,6 @@ # AppVeyor Guides -Currently, SparkR on Windows is being tested with [AppVeyor](https://ci.appveyor.com). This page describes how to set up AppVeyor with Spark, how to run the build, check the status and stop the build via this tool. There is the documenation for AppVeyor [here](https://www.appveyor.com/docs). Please refer this for full details. +Currently, SparkR on Windows is being tested with [AppVeyor](https://ci.appveyor.com). This page describes how to set up AppVeyor with Spark, how to run the build, check the status and stop the build via this tool. There is the documentation for AppVeyor [here](https://www.appveyor.com/docs). Please refer this for full details. ### Setting up AppVeyor @@ -45,7 +45,7 @@ Currently, SparkR on Windows is being tested with [AppVeyor](https://ci.appveyor 2016-08-30 12 16 35 -- Since we will use Github here, click the "GITHUB" button and then click "Authorize Github" so that AppVeyor can access to the Github logs (e.g. commits). +- Since we will use Github here, click the "GITHUB" button and then click "Authorize Github" so that AppVeyor can access the Github logs (e.g. commits). 2016-09-04 11 10 22 @@ -87,7 +87,7 @@ Currently, SparkR on Windows is being tested with [AppVeyor](https://ci.appveyor 2016-08-30 12 29 41 -- If the build is running, "CANCEL BUILD" buttom appears. Click this button top cancel the current build. +- If the build is running, "CANCEL BUILD" button appears. Click this button to cancel the current build. 2016-08-30 1 11 13 diff --git a/dev/lint-python b/dev/lint-python index e069cafa1b8c6..f738af9c49763 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -34,8 +34,8 @@ python -B -m compileall -q -l $PATHS_TO_CHECK > "$PYCODESTYLE_REPORT_PATH" compile_status="${PIPESTATUS[0]}" # Get pycodestyle at runtime so that we don't rely on it being installed on the build server. -#+ See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 -# Updated to latest official version for pep8. pep8 is formally renamed to pycodestyle. +# See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 +# Updated to the latest official version of pep8. pep8 is formally renamed to pycodestyle. PYCODESTYLE_VERSION="2.3.1" PYCODESTYLE_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-$PYCODESTYLE_VERSION.py" PYCODESTYLE_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/PyCQA/pycodestyle/$PYCODESTYLE_VERSION/pycodestyle.py" @@ -60,9 +60,9 @@ export "PYLINT_HOME=$PYTHONPATH" export "PATH=$PYTHONPATH:$PATH" # There is no need to write this output to a file -#+ first, but we do so so that the check status can -#+ be output before the report, like with the -#+ scalastyle and RAT checks. +# first, but we do so so that the check status can +# be output before the report, like with the +# scalastyle and RAT checks. python "$PYCODESTYLE_SCRIPT_PATH" --config=dev/tox.ini $PATHS_TO_CHECK >> "$PYCODESTYLE_REPORT_PATH" pycodestyle_status="${PIPESTATUS[0]}" @@ -73,7 +73,7 @@ else fi if [ "$lint_status" -ne 0 ]; then - echo "PYCODESTYLE checks failed." + echo "pycodestyle checks failed." cat "$PYCODESTYLE_REPORT_PATH" rm "$PYCODESTYLE_REPORT_PATH" exit "$lint_status" diff --git a/dev/run-pip-tests b/dev/run-pip-tests index d51dde12a03c5..1321c2be4c192 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -25,10 +25,10 @@ shopt -s nullglob FWDIR="$(cd "$(dirname "$0")"/..; pwd)" cd "$FWDIR" -echo "Constucting virtual env for testing" +echo "Constructing virtual env for testing" VIRTUALENV_BASE=$(mktemp -d) -# Clean up the virtual env enviroment used if we created one. +# Clean up the virtual env environment used if we created one. function delete_virtualenv() { echo "Cleaning up temporary directory - $VIRTUALENV_BASE" rm -rf "$VIRTUALENV_BASE" diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 03fd6ff0fba40..5bc03e41d1f2d 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -20,7 +20,7 @@ # Wrapper script that runs the Spark tests then reports QA results # to github via its API. # Environment variables are populated by the code here: -#+ https://github.com/jenkinsci/ghprb-plugin/blob/master/src/main/java/org/jenkinsci/plugins/ghprb/GhprbTrigger.java#L139 +# https://github.com/jenkinsci/ghprb-plugin/blob/master/src/main/java/org/jenkinsci/plugins/ghprb/GhprbTrigger.java#L139 FWDIR="$( cd "$( dirname "$0" )/.." && pwd )" cd "$FWDIR" diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index b900f0bd913c3..dfea762db98c6 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -25,10 +25,10 @@ @total_ordering class Module(object): """ - A module is the basic abstraction in our test runner script. Each module consists of a set of - source files, a set of test commands, and a set of dependencies on other modules. We use modules - to define a dependency graph that lets determine which tests to run based on which files have - changed. + A module is the basic abstraction in our test runner script. Each module consists of a set + of source files, a set of test commands, and a set of dependencies on other modules. We use + modules to define a dependency graph that let us determine which tests to run based on which + files have changed. """ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, diff --git a/dev/sparktestsupport/toposort.py b/dev/sparktestsupport/toposort.py index 6c67b4504bc3b..8b2688d20039f 100644 --- a/dev/sparktestsupport/toposort.py +++ b/dev/sparktestsupport/toposort.py @@ -43,8 +43,8 @@ def toposort(data): """Dependencies are expressed as a dictionary whose keys are items and whose values are a set of dependent items. Output is a list of sets in topological order. The first set consists of items with no -dependences, each subsequent set consists of items that depend upon -items in the preceeding sets. +dependencies, each subsequent set consists of items that depend upon +items in the preceding sets. """ # Special case empty input. @@ -59,7 +59,7 @@ def toposort(data): v.discard(k) # Find all items that don't depend on anything. extra_items_in_deps = _reduce(set.union, data.values()) - set(data.keys()) - # Add empty dependences where needed. + # Add empty dependencies where needed. data.update({item: set() for item in extra_items_in_deps}) while True: ordered = set(item for item, dep in data.items() if len(dep) == 0) diff --git a/dev/tests/pr_merge_ability.sh b/dev/tests/pr_merge_ability.sh index d9a347fe24a8c..25fdbccac4dd8 100755 --- a/dev/tests/pr_merge_ability.sh +++ b/dev/tests/pr_merge_ability.sh @@ -23,9 +23,9 @@ # found at dev/run-tests-jenkins. # # Arg1: The Github Pull Request Actual Commit -#+ known as `ghprbActualCommit` in `run-tests-jenkins` +# known as `ghprbActualCommit` in `run-tests-jenkins` # Arg2: The SHA1 hash -#+ known as `sha1` in `run-tests-jenkins` +# known as `sha1` in `run-tests-jenkins` # ghprbActualCommit="$1" diff --git a/dev/tests/pr_public_classes.sh b/dev/tests/pr_public_classes.sh index 41c5d3ee8cb3c..479d1851fe0b8 100755 --- a/dev/tests/pr_public_classes.sh +++ b/dev/tests/pr_public_classes.sh @@ -23,7 +23,7 @@ # found at dev/run-tests-jenkins. # # Arg1: The Github Pull Request Actual Commit -#+ known as `ghprbActualCommit` in `run-tests-jenkins` +# known as `ghprbActualCommit` in `run-tests-jenkins` ghprbActualCommit="$1" @@ -31,7 +31,7 @@ ghprbActualCommit="$1" # master commit and the tip of the pull request branch. # By diffing$ghprbActualCommit^...$ghprbActualCommit and filtering to examine the diffs of only -# non-test files, we can gets us changes introduced in the PR and not anything else added to master +# non-test files, we can get changes introduced in the PR and not anything else added to master # since the PR was branched. # Handle differences between GNU and BSD sed From 77866167330a665e174ae08a2f8902ef9dc3438b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 30 Jan 2018 17:14:17 -0800 Subject: [PATCH 0244/2461] [SPARK-23276][SQL][TEST] Enable UDT tests in (Hive)OrcHadoopFsRelationSuite ## What changes were proposed in this pull request? Like Parquet, ORC test suites should enable UDT tests. ## How was this patch tested? Pass the Jenkins with newly enabled test cases. Author: Dongjoon Hyun Closes #20440 from dongjoon-hyun/SPARK-23276. --- .../apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index a1f054b8e3f44..3b82a6c458ce4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -34,11 +34,10 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName - // ORC does not play well with NullType and UDT. + // ORC does not play well with NullType. override protected def supportsDataType(dataType: DataType): Boolean = dataType match { case _: NullType => false case _: CalendarIntervalType => false - case _: UserDefinedType[_] => false case _ => true } From ca04c3ff2387bf0a4308a4b010154e6761827278 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 30 Jan 2018 20:05:57 -0800 Subject: [PATCH 0245/2461] [SPARK-23274][SQL] Fix ReplaceExceptWithFilter when the right's Filter contains the references that are not in the left output ## What changes were proposed in this pull request? This PR is to fix the `ReplaceExceptWithFilter` rule when the right's Filter contains the references that are not in the left output. Before this PR, we got the error like ``` java.util.NoSuchElementException: key not found: a at scala.collection.MapLike$class.default(MapLike.scala:228) at scala.collection.AbstractMap.default(Map.scala:59) at scala.collection.MapLike$class.apply(MapLike.scala:141) at scala.collection.AbstractMap.apply(Map.scala:59) ``` After this PR, `ReplaceExceptWithFilter ` will not take an effect in this case. ## How was this patch tested? Added tests Author: gatorsmile Closes #20444 from gatorsmile/fixReplaceExceptWithFilter. --- .../optimizer/ReplaceExceptWithFilter.scala | 17 +++++++++++++---- .../optimizer/ReplaceOperatorSuite.scala | 15 +++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 8 ++++++++ 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala index 89bfcee078fba..45edf266bbce4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala @@ -46,18 +46,27 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] { } plan.transform { - case Except(left, right) if isEligible(left, right) => - Distinct(Filter(Not(transformCondition(left, skipProject(right))), left)) + case e @ Except(left, right) if isEligible(left, right) => + val newCondition = transformCondition(left, skipProject(right)) + newCondition.map { c => + Distinct(Filter(Not(c), left)) + }.getOrElse { + e + } } } - private def transformCondition(left: LogicalPlan, right: LogicalPlan): Expression = { + private def transformCondition(left: LogicalPlan, right: LogicalPlan): Option[Expression] = { val filterCondition = InferFiltersFromConstraints(combineFilters(right)).asInstanceOf[Filter].condition val attributeNameMap: Map[String, Attribute] = left.output.map(x => (x.name, x)).toMap - filterCondition.transform { case a : AttributeReference => attributeNameMap(a.name) } + if (filterCondition.references.forall(r => attributeNameMap.contains(r.name))) { + Some(filterCondition.transform { case a: AttributeReference => attributeNameMap(a.name) }) + } else { + None + } } // TODO: This can be further extended in the future. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index e9701ffd2c54b..52dc2e9fb076c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -168,6 +168,21 @@ class ReplaceOperatorSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("replace Except with Filter when only right filter can be applied to the left") { + val table = LocalRelation(Seq('a.int, 'b.int)) + val left = table.where('b < 1).select('a).as("left") + val right = table.where('b < 3).select('a).as("right") + + val query = Except(left, right) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(left.output, right.output, + Join(left, right, LeftAnti, Option($"left.a" <=> $"right.a"))).analyze + + comparePlans(optimized, correctAnswer) + } + test("replace Distinct with Aggregate") { val input = LocalRelation('a.int, 'b.int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 33707080c1301..8b66f77b2f923 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -589,6 +589,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Nil) } + test("SPARK-23274: except between two projects without references used in filter") { + val df = Seq((1, 2, 4), (1, 3, 5), (2, 2, 3), (2, 4, 5)).toDF("a", "b", "c") + val df1 = df.filter($"a" === 1) + val df2 = df.filter($"a" === 2) + checkAnswer(df1.select("b").except(df2.select("b")), Row(3) :: Nil) + checkAnswer(df1.select("b").except(df2.select("c")), Row(2) :: Nil) + } + test("except distinct - SQL compliance") { val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") val df_right = Seq(1, 3).toDF("id") From 8c6a9c90a36a938372f28ee8be72178192fbc313 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 31 Jan 2018 13:59:21 +0800 Subject: [PATCH 0246/2461] [SPARK-23279][SS] Avoid triggering distributed job for Console sink ## What changes were proposed in this pull request? Console sink will redistribute collected local data and trigger a distributed job in each batch, this is not necessary, so here change to local job. ## How was this patch tested? Existing UT and manual verification. Author: jerryshao Closes #20447 from jerryshao/console-minor. --- .../spark/sql/execution/streaming/sources/ConsoleWriter.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala index d46f4d7b86360..c57bdc4a28905 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming.sources +import scala.collection.JavaConverters._ + import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.sources.v2.DataSourceOptions @@ -61,7 +63,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions) println("-------------------------------------------") // scalastyle:off println spark - .createDataFrame(spark.sparkContext.parallelize(rows), schema) + .createDataFrame(rows.toList.asJava, schema) .show(numRowsToShow, isTruncated) } From 695f7146bca342a0ee192d8c7f5ec48d4d8577a8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 31 Jan 2018 15:13:15 +0800 Subject: [PATCH 0247/2461] [SPARK-23272][SQL] add calendar interval type support to ColumnVector ## What changes were proposed in this pull request? `ColumnVector` is aimed to support all the data types, but `CalendarIntervalType` is missing. Actually we do support interval type for inner fields, e.g. `ColumnarRow`, `ColumnarArray` both support interval type. It's weird if we don't support interval type at the top level. This PR adds the interval type support. This PR also makes `ColumnVector.getChild` protect. We need it public because `MutableColumnaRow.getInterval` needs it. Now the interval implementation is in `ColumnVector.getInterval`. ## How was this patch tested? a new test. Author: Wenchen Fan Closes #20438 from cloud-fan/interval. --- .../vectorized/MutableColumnarRow.java | 4 +- .../sql/vectorized/ArrowColumnVector.java | 2 +- .../spark/sql/vectorized/ColumnVector.java | 26 ++++++++++- .../spark/sql/vectorized/ColumnarArray.java | 4 +- .../spark/sql/vectorized/ColumnarRow.java | 4 +- .../vectorized/ColumnarBatchSuite.scala | 45 +++++++++++++++++-- 6 files changed, 70 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 2bab095d4d951..66668f3753604 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -146,9 +146,7 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { if (columns[ordinal].isNullAt(rowId)) return null; - final int months = columns[ordinal].getChild(0).getInt(rowId); - final long microseconds = columns[ordinal].getChild(1).getLong(rowId); - return new CalendarInterval(months, microseconds); + return columns[ordinal].getInterval(rowId); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 9803c3dec6de2..a75d76bd0f82e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -28,7 +28,7 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * A column vector backed by Apache Arrow. Currently time interval type and map type are not + * A column vector backed by Apache Arrow. Currently calendar interval type and map type are not * supported. */ @InterfaceStability.Evolving diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index 4b955ceddd0f2..111f5d9b358d4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -20,6 +20,7 @@ import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; /** @@ -195,6 +196,7 @@ public double[] getDoubles(int rowId, int count) { * struct field. */ public final ColumnarRow getStruct(int rowId) { + if (isNullAt(rowId)) return null; return new ColumnarRow(this, rowId); } @@ -236,9 +238,29 @@ public MapData getMap(int ordinal) { public abstract byte[] getBinary(int rowId); /** - * Returns the ordinal's child column vector. + * Returns the calendar interval type value for rowId. + * + * In Spark, calendar interval type value is basically an integer value representing the number of + * months in this interval, and a long value representing the number of microseconds in this + * interval. An interval type vector is the same as a struct type vector with 2 fields: `months` + * and `microseconds`. + * + * To support interval type, implementations must implement {@link #getChild(int)} and define 2 + * child vectors: the first child vector is an int type vector, containing all the month values of + * all the interval values in this vector. The second child vector is a long type vector, + * containing all the microsecond values of all the interval values in this vector. + */ + public final CalendarInterval getInterval(int rowId) { + if (isNullAt(rowId)) return null; + final int months = getChild(0).getInt(rowId); + final long microseconds = getChild(1).getLong(rowId); + return new CalendarInterval(months, microseconds); + } + + /** + * @return child [[ColumnVector]] at the given ordinal. */ - public abstract ColumnVector getChild(int ordinal); + protected abstract ColumnVector getChild(int ordinal); /** * Data type for this column. diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 0d2c3ec8648d3..72c07ee7cad3f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -135,9 +135,7 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { - int month = data.getChild(0).getInt(offset + ordinal); - long microseconds = data.getChild(1).getLong(offset + ordinal); - return new CalendarInterval(month, microseconds); + return data.getInterval(offset + ordinal); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 25db7e09d20d0..6ca749d7c6e85 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -139,9 +139,7 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { if (data.getChild(ordinal).isNullAt(rowId)) return null; - final int months = data.getChild(ordinal).getChild(0).getInt(rowId); - final long microseconds = data.getChild(ordinal).getChild(1).getLong(rowId); - return new CalendarInterval(months, microseconds); + return data.getChild(ordinal).getInterval(rowId); } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 1873c24ab063c..925c101fe1fee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -620,6 +620,39 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.arrayData().elementsAppended == 0) } + testVector("CalendarInterval APIs", 4, CalendarIntervalType) { + column => + val reference = mutable.ArrayBuffer.empty[CalendarInterval] + + val months = column.getChild(0) + val microseconds = column.getChild(1) + assert(months.dataType() == IntegerType) + assert(microseconds.dataType() == LongType) + + months.putInt(0, 1) + microseconds.putLong(0, 100) + reference += new CalendarInterval(1, 100) + + months.putInt(1, 0) + microseconds.putLong(1, 2000) + reference += new CalendarInterval(0, 2000) + + column.putNull(2) + reference += null + + months.putInt(3, 20) + microseconds.putLong(3, 0) + reference += new CalendarInterval(20, 0) + + reference.zipWithIndex.foreach { case (v, i) => + val errMsg = "VectorType=" + column.getClass.getSimpleName + assert(v == column.getInterval(i), errMsg) + if (v == null) assert(column.isNullAt(i), errMsg) + } + + column.close() + } + testVector("Int Array", 10, new ArrayType(IntegerType, true)) { column => @@ -739,14 +772,20 @@ class ColumnarBatchSuite extends SparkFunSuite { c1.putInt(0, 123) c2.putDouble(0, 3.45) - c1.putInt(1, 456) - c2.putDouble(1, 5.67) + + column.putNull(1) + + c1.putInt(2, 456) + c2.putDouble(2, 5.67) val s = column.getStruct(0) assert(s.getInt(0) == 123) assert(s.getDouble(1) == 3.45) - val s2 = column.getStruct(1) + assert(column.isNullAt(1)) + assert(column.getStruct(1) == null) + + val s2 = column.getStruct(2) assert(s2.getInt(0) == 456) assert(s2.getDouble(1) == 5.67) } From 161a3f2ae324271a601500e3d2900db9359ee2ef Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Wed, 31 Jan 2018 10:37:37 +0200 Subject: [PATCH 0248/2461] [SPARK-23112][DOC] Update ML migration guide with breaking and behavior changes. Add breaking changes, as well as update behavior changes, to `2.3` ML migration guide. ## How was this patch tested? Doc only Author: Nick Pentreath Closes #20421 from MLnick/SPARK-23112-ml-guide. --- docs/ml-guide.md | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index b957445579ffd..702bcf748fc74 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -108,7 +108,13 @@ and the migration guide below will explain all changes between releases. ### Breaking changes -There are no breaking changes. +* The class and trait hierarchy for logistic regression model summaries was changed to be cleaner +and better accommodate the addition of the multi-class summary. This is a breaking change for user +code that casts a `LogisticRegressionTrainingSummary` to a +` BinaryLogisticRegressionTrainingSummary`. Users should instead use the `model.binarySummary` +method. See [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139) for more detail +(_note_ this is an `Experimental` API). This _does not_ affect the Python `summary` method, which +will still work correctly for both multinomial and binary cases. ### Deprecations and changes of behavior @@ -123,8 +129,19 @@ new [`OneHotEncoderEstimator`](ml-features.html#onehotencoderestimator) **Changes of behavior** * [SPARK-21027](https://issues.apache.org/jira/browse/SPARK-21027): - We are now setting the default parallelism used in `OneVsRest` to be 1 (i.e. serial). In 2.2 and + The default parallelism used in `OneVsRest` is now set to 1 (i.e. serial). In `2.2` and earlier versions, the level of parallelism was set to the default threadpool size in Scala. +* [SPARK-22156](https://issues.apache.org/jira/browse/SPARK-22156): + The learning rate update for `Word2Vec` was incorrect when `numIterations` was set greater than + `1`. This will cause training results to be different between `2.3` and earlier versions. +* [SPARK-21681](https://issues.apache.org/jira/browse/SPARK-21681): + Fixed an edge case bug in multinomial logistic regression that resulted in incorrect coefficients + when some features had zero variance. +* [SPARK-16957](https://issues.apache.org/jira/browse/SPARK-16957): + Tree algorithms now use mid-points for split values. This may change results from model training. +* [SPARK-14657](https://issues.apache.org/jira/browse/SPARK-14657): + Fixed an issue where the features generated by `RFormula` without an intercept were inconsistent + with the output in R. This may change results from model training in this scenario. ## Previous Spark versions From 3d0911bbe47f76c341c090edad3737e88a67e3d7 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 31 Jan 2018 20:04:51 +0900 Subject: [PATCH 0249/2461] [SPARK-23228][PYSPARK] Add Python Created jsparkSession to JVM's defaultSession ## What changes were proposed in this pull request? In the current PySpark code, Python created `jsparkSession` doesn't add to JVM's defaultSession, this `SparkSession` object cannot be fetched from Java side, so the below scala code will be failed when loaded in PySpark application. ```scala class TestSparkSession extends SparkListener with Logging { override def onOtherEvent(event: SparkListenerEvent): Unit = { event match { case CreateTableEvent(db, table) => val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession) assert(session.isDefined) val tableInfo = session.get.sharedState.externalCatalog.getTable(db, table) logInfo(s"Table info ${tableInfo}") case e => logInfo(s"event $e") } } } ``` So here propose to add fresh create `jsparkSession` to `defaultSession`. ## How was this patch tested? Manual verification. Author: jerryshao Author: hyukjinkwon Author: Saisai Shao Closes #20404 from jerryshao/SPARK-23228. --- python/pyspark/sql/session.py | 10 +++++++++- python/pyspark/sql/tests.py | 28 +++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 6c84023c43fb6..1ed04298bc899 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -213,7 +213,12 @@ def __init__(self, sparkContext, jsparkSession=None): self._jsc = self._sc._jsc self._jvm = self._sc._jvm if jsparkSession is None: - jsparkSession = self._jvm.SparkSession(self._jsc.sc()) + if self._jvm.SparkSession.getDefaultSession().isDefined() \ + and not self._jvm.SparkSession.getDefaultSession().get() \ + .sparkContext().isStopped(): + jsparkSession = self._jvm.SparkSession.getDefaultSession().get() + else: + jsparkSession = self._jvm.SparkSession(self._jsc.sc()) self._jsparkSession = jsparkSession self._jwrapped = self._jsparkSession.sqlContext() self._wrapped = SQLContext(self._sc, self, self._jwrapped) @@ -225,6 +230,7 @@ def __init__(self, sparkContext, jsparkSession=None): if SparkSession._instantiatedSession is None \ or SparkSession._instantiatedSession._sc._jsc is None: SparkSession._instantiatedSession = self + self._jvm.SparkSession.setDefaultSession(self._jsparkSession) def _repr_html_(self): return """ @@ -759,6 +765,8 @@ def stop(self): """Stop the underlying :class:`SparkContext`. """ self._sc.stop() + # We should clean the default session up. See SPARK-23228. + self._jvm.SparkSession.clearDefaultSession() SparkSession._instantiatedSession = None @since(2.0) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index dc80870d3cd9f..dc26b96334c7a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -69,7 +69,7 @@ from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings from pyspark.sql.types import _merge_type -from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests +from pyspark.tests import QuietTest, ReusedPySparkTestCase, PySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException @@ -2925,6 +2925,32 @@ def test_sparksession_with_stopped_sparkcontext(self): sc.stop() +class SparkSessionTests(PySparkTestCase): + + # This test is separate because it's closely related with session's start and stop. + # See SPARK-23228. + def test_set_jvm_default_session(self): + spark = SparkSession.builder.getOrCreate() + try: + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined()) + finally: + spark.stop() + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isEmpty()) + + def test_jvm_default_session_already_set(self): + # Here, we assume there is the default session already set in JVM. + jsession = self.sc._jvm.SparkSession(self.sc._jsc.sc()) + self.sc._jvm.SparkSession.setDefaultSession(jsession) + + spark = SparkSession.builder.getOrCreate() + try: + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined()) + # The session should be the same with the exiting one. + self.assertTrue(jsession.equals(spark._jvm.SparkSession.getDefaultSession().get())) + finally: + spark.stop() + + class UDFInitializationTests(unittest.TestCase): def tearDown(self): if SparkSession._instantiatedSession is not None: From 48dd6a4c79e33a8f2dba8349b58aa07e4796a925 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 1 Feb 2018 00:24:42 +0800 Subject: [PATCH 0250/2461] revert [SPARK-22785][SQL] remove ColumnVector.anyNullsSet ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/19980 , we thought `anyNullsSet` can be simply implemented by `numNulls() > 0`. This is logically true, but may have performance problems. `OrcColumnVector` is an example. It doesn't have the `numNulls` property, only has a `noNulls` property. We will lose a lot of performance if we use `numNulls() > 0` to check null. This PR simply revert #19980, with a renaming to call it `hasNull`. Better name suggestions are welcome, e.g. `nullable`? ## How was this patch tested? existing test Author: Wenchen Fan Closes #20452 from cloud-fan/null. --- .../execution/datasources/orc/OrcColumnVector.java | 5 +++++ .../execution/vectorized/OffHeapColumnVector.java | 2 +- .../sql/execution/vectorized/OnHeapColumnVector.java | 2 +- .../execution/vectorized/WritableColumnVector.java | 7 ++++++- .../spark/sql/vectorized/ArrowColumnVector.java | 5 +++++ .../apache/spark/sql/vectorized/ColumnVector.java | 5 +++++ .../vectorized/ArrowColumnVectorSuite.scala | 12 ++++++++++++ .../execution/vectorized/ColumnarBatchSuite.scala | 9 +++++++++ 8 files changed, 44 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index 5078bc7922ee2..78203e3145c62 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -77,6 +77,11 @@ public void close() { } + @Override + public boolean hasNull() { + return !baseData.noNulls; + } + @Override public int numNulls() { if (baseData.isRepeating) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 1c45b846790b6..fa52e4a354786 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -123,7 +123,7 @@ public void putNulls(int rowId, int count) { @Override public void putNotNulls(int rowId, int count) { - if (numNulls == 0) return; + if (!hasNull()) return; long offset = nulls + rowId; for (int i = 0; i < count; ++i, ++offset) { Platform.putByte(null, offset, (byte) 0); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 1d538fe4181b7..cccef78aebdc8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -119,7 +119,7 @@ public void putNulls(int rowId, int count) { @Override public void putNotNulls(int rowId, int count) { - if (numNulls == 0) return; + if (!hasNull()) return; for (int i = 0; i < count; ++i) { nulls[rowId + i] = (byte)0; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index a8ec8ef2aadf8..8ebc1adf59c8b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -59,8 +59,8 @@ public void reset() { elementsAppended = 0; if (numNulls > 0) { putNotNulls(0, capacity); + numNulls = 0; } - numNulls = 0; } @Override @@ -102,6 +102,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { throw new RuntimeException(message, cause); } + @Override + public boolean hasNull() { + return numNulls > 0; + } + @Override public int numNulls() { return numNulls; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index a75d76bd0f82e..5ff6474c161f3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -37,6 +37,11 @@ public final class ArrowColumnVector extends ColumnVector { private final ArrowVectorAccessor accessor; private ArrowColumnVector[] childColumns; + @Override + public boolean hasNull() { + return accessor.getNullCount() > 0; + } + @Override public int numNulls() { return accessor.getNullCount(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index 111f5d9b358d4..d588956208047 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -65,6 +65,11 @@ public abstract class ColumnVector implements AutoCloseable { @Override public abstract void close(); + /** + * Returns true if this column vector contains any null values. + */ + public abstract boolean hasNull(); + /** * Returns the number of nulls in this column vector. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index e794f50781ff2..b55489cb2678a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -42,6 +42,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === BooleanType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -69,6 +70,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ByteType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -96,6 +98,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ShortType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -123,6 +126,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === IntegerType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -150,6 +154,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === LongType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -177,6 +182,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === FloatType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -204,6 +210,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === DoubleType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -232,6 +239,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === StringType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -258,6 +266,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === BinaryType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -300,6 +309,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ArrayType(IntegerType)) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) val array0 = columnVector.getArray(0) @@ -344,6 +354,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === schema) + assert(!columnVector.hasNull) assert(columnVector.numNulls === 0) val row0 = columnVector.getStruct(0) @@ -396,6 +407,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === schema) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) val row0 = columnVector.getStruct(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 925c101fe1fee..168bc5e3e480b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -66,22 +66,27 @@ class ColumnarBatchSuite extends SparkFunSuite { column => val reference = mutable.ArrayBuffer.empty[Boolean] var idx = 0 + assert(!column.hasNull) assert(column.numNulls() == 0) column.appendNotNull() reference += false + assert(!column.hasNull) assert(column.numNulls() == 0) column.appendNotNulls(3) (1 to 3).foreach(_ => reference += false) + assert(!column.hasNull) assert(column.numNulls() == 0) column.appendNull() reference += true + assert(column.hasNull) assert(column.numNulls() == 1) column.appendNulls(3) (1 to 3).foreach(_ => reference += true) + assert(column.hasNull) assert(column.numNulls() == 4) idx = column.elementsAppended @@ -89,11 +94,13 @@ class ColumnarBatchSuite extends SparkFunSuite { column.putNotNull(idx) reference += false idx += 1 + assert(column.hasNull) assert(column.numNulls() == 4) column.putNull(idx) reference += true idx += 1 + assert(column.hasNull) assert(column.numNulls() == 5) column.putNulls(idx, 3) @@ -101,6 +108,7 @@ class ColumnarBatchSuite extends SparkFunSuite { reference += true reference += true idx += 3 + assert(column.hasNull) assert(column.numNulls() == 8) column.putNotNulls(idx, 4) @@ -109,6 +117,7 @@ class ColumnarBatchSuite extends SparkFunSuite { reference += false reference += false idx += 4 + assert(column.hasNull) assert(column.numNulls() == 8) reference.zipWithIndex.foreach { v => From 8c21170decfb9ca4d3233e1ea13bd1b6e3199ed9 Mon Sep 17 00:00:00 2001 From: Glen Takahashi Date: Thu, 1 Feb 2018 01:14:01 +0800 Subject: [PATCH 0251/2461] [SPARK-23249][SQL] Improved block merging logic for partitions ## What changes were proposed in this pull request? Change DataSourceScanExec so that when grouping blocks together into partitions, also checks the end of the sorted list of splits to more efficiently fill out partitions. ## How was this patch tested? Updated old test to reflect the new logic, which causes the # of partitions to drop from 4 -> 3 Also, a current test exists to test large non-splittable files at https://github.com/glentakahashi/spark/blob/c575977a5952bf50b605be8079c9be1e30f3bd36/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala#L346 ## Rationale The current bin-packing method of next-fit descending for blocks into partitions is sub-optimal in a lot of cases and will result in extra partitions, un-even distribution of block-counts across partitions, and un-even distribution of partition sizes. As an example, 128 files ranging from 1MB, 2MB,...127MB,128MB. will result in 82 partitions with the current algorithm, but only 64 using this algorithm. Also in this example, the max # of blocks per partition in NFD is 13, while in this algorithm is is 2. More generally, running a simulation of 1000 runs using 128MB blocksize, between 1-1000 normally distributed file sizes between 1-500Mb, you can see an improvement of approx 5% reduction of partition counts, and a large reduction in standard deviation of blocks per partition. This algorithm also runs in O(n) time as NFD does, and in every case is strictly better results than NFD. Overall, the more even distribution of blocks across partitions and therefore reduced partition counts should result in a small but significant performance increase across the board Author: Glen Takahashi Closes #20372 from glentakahashi/feature/improved-block-merging. --- .../sql/execution/DataSourceScanExec.scala | 29 ++++++++++++++----- .../datasources/FileSourceStrategySuite.scala | 15 ++++------ 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index aa66ee7e948ea..f7732e2098c29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -445,16 +445,29 @@ case class FileSourceScanExec( currentSize = 0 } - // Assign files to partitions using "Next Fit Decreasing" - splitFiles.foreach { file => - if (currentSize + file.length > maxSplitBytes) { - closePartition() + def addFile(file: PartitionedFile): Unit = { + currentFiles += file + currentSize += file.length + openCostInBytes + } + + var frontIndex = 0 + var backIndex = splitFiles.length - 1 + + while (frontIndex <= backIndex) { + addFile(splitFiles(frontIndex)) + frontIndex += 1 + while (frontIndex <= backIndex && + currentSize + splitFiles(frontIndex).length <= maxSplitBytes) { + addFile(splitFiles(frontIndex)) + frontIndex += 1 + } + while (backIndex > frontIndex && + currentSize + splitFiles(backIndex).length <= maxSplitBytes) { + addFile(splitFiles(backIndex)) + backIndex -= 1 } - // Add the given file to the current partition. - currentSize += file.length + openCostInBytes - currentFiles += file + closePartition() } - closePartition() new FileScanRDD(fsRelation.sparkSession, readFile, partitions) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index c1d61b843d899..bfccc9335b361 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -141,16 +141,17 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "4", SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") { checkScan(table.select('c1)) { partitions => - // Files should be laid out [(file1), (file2, file3), (file4, file5), (file6)] - assert(partitions.size == 4, "when checking partitions") - assert(partitions(0).files.size == 1, "when checking partition 1") + // Files should be laid out [(file1, file6), (file2, file3), (file4, file5)] + assert(partitions.size == 3, "when checking partitions") + assert(partitions(0).files.size == 2, "when checking partition 1") assert(partitions(1).files.size == 2, "when checking partition 2") assert(partitions(2).files.size == 2, "when checking partition 3") - assert(partitions(3).files.size == 1, "when checking partition 4") - // First partition reads (file1) + // First partition reads (file1, file6) assert(partitions(0).files(0).start == 0) assert(partitions(0).files(0).length == 2) + assert(partitions(0).files(1).start == 0) + assert(partitions(0).files(1).length == 1) // Second partition reads (file2, file3) assert(partitions(1).files(0).start == 0) @@ -163,10 +164,6 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi assert(partitions(2).files(0).length == 1) assert(partitions(2).files(1).start == 0) assert(partitions(2).files(1).length == 1) - - // Final partition reads (file6) - assert(partitions(3).files(0).start == 0) - assert(partitions(3).files(0).length == 1) } checkPartitionSchema(StructType(Nil)) From dd242bad39cc6df7ff6c6b16642bdc92dccca6ac Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 31 Jan 2018 11:48:19 -0800 Subject: [PATCH 0252/2461] [SPARK-21525][STREAMING] Check error code from supervisor RPC. The code was ignoring the error code from the AddBlock RPC, which means that a failure to write to the WAL was being ignored by the receiver, and would lead to the block being acked (in the case of the Flume receiver) and data potentially lost. Author: Marcelo Vanzin Closes #20161 from vanzin/SPARK-21525. --- .../spark/streaming/receiver/ReceiverSupervisorImpl.scala | 4 +++- .../apache/spark/streaming/scheduler/ReceiverTracker.scala | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 27644a645727c..5d38c56aa5873 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -159,7 +159,9 @@ private[streaming] class ReceiverSupervisorImpl( logDebug(s"Pushed block $blockId in ${(System.currentTimeMillis - time)} ms") val numRecords = blockStoreResult.numRecords val blockInfo = ReceivedBlockInfo(streamId, numRecords, metadataOption, blockStoreResult) - trackerEndpoint.askSync[Boolean](AddBlock(blockInfo)) + if (!trackerEndpoint.askSync[Boolean](AddBlock(blockInfo))) { + throw new SparkException("Failed to add block to receiver tracker.") + } logDebug(s"Reported block $blockId") } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 6f130c803f310..c74ca1918a81d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -521,7 +521,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false if (active) { context.reply(addBlock(receivedBlockInfo)) } else { - throw new IllegalStateException("ReceiverTracker RpcEndpoint shut down.") + context.sendFailure( + new IllegalStateException("ReceiverTracker RpcEndpoint already shut down.")) } } }) From 9ff1d96f01e2c89acfd248db917e068b93f519a6 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 31 Jan 2018 13:52:47 -0800 Subject: [PATCH 0253/2461] [SPARK-23281][SQL] Query produces results in incorrect order when a composite order by clause refers to both original columns and aliases ## What changes were proposed in this pull request? Here is the test snippet. ``` SQL scala> Seq[(Integer, Integer)]( | (1, 1), | (1, 3), | (2, 3), | (3, 3), | (4, null), | (5, null) | ).toDF("key", "value").createOrReplaceTempView("src") scala> sql( | """ | |SELECT MAX(value) as value, key as col2 | |FROM src | |GROUP BY key | |ORDER BY value desc, key | """.stripMargin).show +-----+----+ |value|col2| +-----+----+ | 3| 3| | 3| 2| | 3| 1| | null| 5| | null| 4| +-----+----+ ```SQL Here is the explain output : ```SQL == Parsed Logical Plan == 'Sort ['value DESC NULLS LAST, 'key ASC NULLS FIRST], true +- 'Aggregate ['key], ['MAX('value) AS value#9, 'key AS col2#10] +- 'UnresolvedRelation `src` == Analyzed Logical Plan == value: int, col2: int Project [value#9, col2#10] +- Sort [value#9 DESC NULLS LAST, col2#10 DESC NULLS LAST], true +- Aggregate [key#5], [max(value#6) AS value#9, key#5 AS col2#10] +- SubqueryAlias src +- Project [_1#2 AS key#5, _2#3 AS value#6] +- LocalRelation [_1#2, _2#3] ``` SQL The sort direction is being wrongly changed from ASC to DSC while resolving ```Sort``` in resolveAggregateFunctions. The above testcase models TPCDS-Q71 and thus we have the same issue in Q71 as well. ## How was this patch tested? A few tests are added in SQLQuerySuite. Author: Dilip Biswal Closes #20453 from dilipbiswal/local_spark. --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 41 ++++++++++++++++++- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 91cb0365a0856..251099f750cf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1493,7 +1493,7 @@ class Analyzer( // to push down this ordering expression and can reference the original aggregate // expression instead. val needsPushDown = ArrayBuffer.empty[NamedExpression] - val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map { + val evaluatedOrderings = resolvedAliasedOrdering.zip(unresolvedSortOrders).map { case (evaluated, order) => val index = originalAggExprs.indexWhere { case Alias(child, _) => child semanticEquals evaluated.child diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ffd736d2ebbb6..8f14575c3325f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import java.io.File -import java.math.MathContext import java.net.{MalformedURLException, URL} import java.sql.Timestamp import java.util.concurrent.atomic.AtomicBoolean @@ -1618,6 +1617,46 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("SPARK-23281: verify the correctness of sort direction on composite order by clause") { + withTempView("src") { + Seq[(Integer, Integer)]( + (1, 1), + (1, 3), + (2, 3), + (3, 3), + (4, null), + (5, null) + ).toDF("key", "value").createOrReplaceTempView("src") + + checkAnswer(sql( + """ + |SELECT MAX(value) as value, key as col2 + |FROM src + |GROUP BY key + |ORDER BY value desc, key + """.stripMargin), + Seq(Row(3, 1), Row(3, 2), Row(3, 3), Row(null, 4), Row(null, 5))) + + checkAnswer(sql( + """ + |SELECT MAX(value) as value, key as col2 + |FROM src + |GROUP BY key + |ORDER BY value desc, key desc + """.stripMargin), + Seq(Row(3, 3), Row(3, 2), Row(3, 1), Row(null, 5), Row(null, 4))) + + checkAnswer(sql( + """ + |SELECT MAX(value) as value, key as col2 + |FROM src + |GROUP BY key + |ORDER BY value asc, key desc + """.stripMargin), + Seq(Row(null, 5), Row(null, 4), Row(3, 3), Row(3, 2), Row(3, 1))) + } + } + test("run sql directly on files") { val df = spark.range(100).toDF() withTempPath(f => { From f470df2fcf14e6234c577dc1bdfac27d49b441f5 Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Thu, 1 Feb 2018 11:15:17 +0900 Subject: [PATCH 0254/2461] [SPARK-23157][SQL][FOLLOW-UP] DataFrame -> SparkDataFrame in R comment Author: Henry Robinson Closes #20443 from henryr/SPARK-23157. --- R/pkg/R/DataFrame.R | 4 ++-- python/pyspark/sql/dataframe.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 547b5ea48a555..41c3c3a89fa72 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2090,8 +2090,8 @@ setMethod("selectExpr", #' #' @param x a SparkDataFrame. #' @param colName a column name. -#' @param col a Column expression (which must refer only to this DataFrame), or an atomic vector in -#' the length of 1 as literal value. +#' @param col a Column expression (which must refer only to this SparkDataFrame), or an atomic +#' vector in the length of 1 as literal value. #' @return A SparkDataFrame with the new column added or the existing column replaced. #' @family SparkDataFrame functions #' @aliases withColumn,SparkDataFrame,character-method diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 055b2c4a0ffec..1496cba91b90e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1829,7 +1829,7 @@ def withColumn(self, colName, col): Returns a new :class:`DataFrame` by adding a column or replacing the existing column that has the same name. - The column expression must be an expression over this dataframe; attempting to add + The column expression must be an expression over this DataFrame; attempting to add a column from some other dataframe will raise an error. :param colName: string, name of the new column. From 52e00f70663a87b5837235bdf72a3e6f84e11411 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 1 Feb 2018 11:56:06 +0800 Subject: [PATCH 0255/2461] [SPARK-23280][SQL] add map type support to ColumnVector ## What changes were proposed in this pull request? Fill the last missing piece of `ColumnVector`: the map type support. The idea is similar to the array type support. A map is basically 2 arrays: keys and values. We ask the implementations to provide a key array, a value array, and an offset and length to specify the range of this map in the key/value array. In `WritableColumnVector`, we put the key array in first child vector, and value array in second child vector, and offsets and lengths in the current vector, which is very similar to how array type is implemented here. ## How was this patch tested? a new test Author: Wenchen Fan Closes #20450 from cloud-fan/map. --- .../datasources/orc/OrcColumnVector.java | 6 ++ .../vectorized/ColumnVectorUtils.java | 15 ++++ .../vectorized/OffHeapColumnVector.java | 4 +- .../vectorized/OnHeapColumnVector.java | 4 +- .../vectorized/WritableColumnVector.java | 13 ++++ .../sql/vectorized/ArrowColumnVector.java | 5 ++ .../spark/sql/vectorized/ColumnVector.java | 14 +++- .../spark/sql/vectorized/ColumnarArray.java | 4 +- .../spark/sql/vectorized/ColumnarMap.java | 53 ++++++++++++++ .../spark/sql/vectorized/ColumnarRow.java | 5 +- .../vectorized/ColumnarBatchSuite.scala | 70 ++++++++++++++----- 11 files changed, 166 insertions(+), 27 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index 78203e3145c62..c8add4c9f486c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -25,6 +25,7 @@ import org.apache.spark.sql.types.Decimal; import org.apache.spark.sql.types.TimestampType; import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.unsafe.types.UTF8String; /** @@ -177,6 +178,11 @@ public ColumnarArray getArray(int rowId) { throw new UnsupportedOperationException(); } + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException(); + } + @Override public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) { throw new UnsupportedOperationException(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index a2853bbadc92b..829f3ce750fe6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -20,8 +20,10 @@ import java.math.BigInteger; import java.nio.charset.StandardCharsets; import java.sql.Date; +import java.util.HashMap; import java.util.Iterator; import java.util.List; +import java.util.Map; import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.Row; @@ -30,6 +32,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.sql.vectorized.ColumnarArray; import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -109,6 +112,18 @@ public static int[] toJavaIntArray(ColumnarArray array) { return array.toIntArray(); } + public static Map toJavaIntMap(ColumnarMap map) { + int[] keys = toJavaIntArray(map.keyArray()); + int[] values = toJavaIntArray(map.valueArray()); + assert keys.length == values.length; + + Map result = new HashMap<>(); + for (int i = 0; i < keys.length; i++) { + result.put(keys[i], values[i]); + } + return result; + } + private static void appendValue(WritableColumnVector dst, DataType t, Object o) { if (o == null) { if (t instanceof CalendarIntervalType) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index fa52e4a354786..754c26579ff08 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -60,7 +60,7 @@ public static OffHeapColumnVector[] allocateColumns(int capacity, StructField[] private long nulls; private long data; - // Set iff the type is array. + // Only set if type is Array or Map. private long lengthData; private long offsetData; @@ -530,7 +530,7 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) { @Override protected void reserveInternal(int newCapacity) { int oldCapacity = (nulls == 0L) ? 0 : capacity; - if (isArray()) { + if (isArray() || type instanceof MapType) { this.lengthData = Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4); this.offsetData = diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index cccef78aebdc8..23dcc104e67c4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -69,7 +69,7 @@ public static OnHeapColumnVector[] allocateColumns(int capacity, StructField[] f private float[] floatData; private double[] doubleData; - // Only set if type is Array. + // Only set if type is Array or Map. private int[] arrayLengths; private int[] arrayOffsets; @@ -503,7 +503,7 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) { // Spilt this function out since it is the slow path. @Override protected void reserveInternal(int newCapacity) { - if (isArray()) { + if (isArray() || type instanceof MapType) { int[] newLengths = new int[newCapacity]; int[] newOffsets = new int[newCapacity]; if (this.arrayLengths != null) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 8ebc1adf59c8b..c2e595455549c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -25,6 +25,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.UTF8String; @@ -612,6 +613,13 @@ public final ColumnarArray getArray(int rowId) { return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId)); } + // `WritableColumnVector` puts the key array in the first child column vector, value array in the + // second child column vector, and puts the offsets and lengths in the current column vector. + @Override + public final ColumnarMap getMap(int rowId) { + return new ColumnarMap(getChild(0), getChild(1), getArrayOffset(rowId), getArrayLength(rowId)); + } + public WritableColumnVector arrayData() { return childColumns[0]; } @@ -705,6 +713,11 @@ protected WritableColumnVector(int capacity, DataType type) { for (int i = 0; i < childColumns.length; ++i) { this.childColumns[i] = reserveNewColumn(capacity, st.fields()[i].dataType()); } + } else if (type instanceof MapType) { + MapType mapType = (MapType) type; + this.childColumns = new WritableColumnVector[2]; + this.childColumns[0] = reserveNewColumn(capacity, mapType.keyType()); + this.childColumns[1] = reserveNewColumn(capacity, mapType.valueType()); } else if (type instanceof CalendarIntervalType) { // Two columns. Months as int. Microseconds as Long. this.childColumns = new WritableColumnVector[2]; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 5ff6474c161f3..f3ece538c3b80 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -119,6 +119,11 @@ public ColumnarArray getArray(int rowId) { return accessor.getArray(rowId); } + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException(); + } + @Override public ArrowColumnVector getChild(int ordinal) { return childColumns[ordinal]; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index d588956208047..05271ec1f46ab 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -220,10 +220,18 @@ public final ColumnarRow getStruct(int rowId) { /** * Returns the map type value for rowId. + * + * In Spark, map type value is basically a key data array and a value data array. A key from the + * key array with a index and a value from the value array with the same index contribute to + * an entry of this map type value. + * + * To support map type, implementations must construct an {@link ColumnarMap} and return it in + * this method. {@link ColumnarMap} requires a {@link ColumnVector} that stores the data of all + * the keys of all the maps in this vector, and another {@link ColumnVector} that stores the data + * of all the values of all the maps in this vector, and a pair of offset and length which + * specify the range of the key/value array that belongs to the map type value at rowId. */ - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); - } + public abstract ColumnarMap getMap(int ordinal); /** * Returns the decimal type value for rowId. diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 72c07ee7cad3f..7c7a1c806a2b7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -149,8 +149,8 @@ public ColumnarArray getArray(int ordinal) { } @Override - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); + public ColumnarMap getMap(int ordinal) { + return data.getMap(offset + ordinal); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java new file mode 100644 index 0000000000000..35648e386c4f1 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.vectorized; + +import org.apache.spark.sql.catalyst.util.MapData; + +/** + * Map abstraction in {@link ColumnVector}. + */ +public final class ColumnarMap extends MapData { + private final ColumnarArray keys; + private final ColumnarArray values; + private final int length; + + public ColumnarMap(ColumnVector keys, ColumnVector values, int offset, int length) { + this.length = length; + this.keys = new ColumnarArray(keys, offset, length); + this.values = new ColumnarArray(values, offset, length); + } + + @Override + public int numElements() { return length; } + + @Override + public ColumnarArray keyArray() { + return keys; + } + + @Override + public ColumnarArray valueArray() { + return values; + } + + @Override + public ColumnarMap copy() { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 6ca749d7c6e85..0c9e92ed11fbd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -155,8 +155,9 @@ public ColumnarArray getArray(int ordinal) { } @Override - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); + public ColumnarMap getMap(int ordinal) { + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getMap(rowId); } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 168bc5e3e480b..8fe2985836f2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -673,35 +673,37 @@ class ColumnarBatchSuite extends SparkFunSuite { i += 1 } - // Populate it with arrays [0], [1, 2], [], [3, 4, 5] + // Populate it with arrays [0], [1, 2], null, [], [3, 4, 5] column.putArray(0, 0, 1) column.putArray(1, 1, 2) - column.putArray(2, 2, 0) - column.putArray(3, 3, 3) + column.putNull(2) + column.putArray(3, 3, 0) + column.putArray(4, 3, 3) + + assert(column.getArray(0).numElements == 1) + assert(column.getArray(1).numElements == 2) + assert(column.isNullAt(2)) + assert(column.getArray(3).numElements == 0) + assert(column.getArray(4).numElements == 3) val a1 = ColumnVectorUtils.toJavaIntArray(column.getArray(0)) val a2 = ColumnVectorUtils.toJavaIntArray(column.getArray(1)) - val a3 = ColumnVectorUtils.toJavaIntArray(column.getArray(2)) - val a4 = ColumnVectorUtils.toJavaIntArray(column.getArray(3)) + val a3 = ColumnVectorUtils.toJavaIntArray(column.getArray(3)) + val a4 = ColumnVectorUtils.toJavaIntArray(column.getArray(4)) assert(a1 === Array(0)) assert(a2 === Array(1, 2)) assert(a3 === Array.empty[Int]) assert(a4 === Array(3, 4, 5)) - // Verify the ArrayData APIs - assert(column.getArray(0).numElements() == 1) + // Verify the ArrayData get APIs assert(column.getArray(0).getInt(0) == 0) - assert(column.getArray(1).numElements() == 2) assert(column.getArray(1).getInt(0) == 1) assert(column.getArray(1).getInt(1) == 2) - assert(column.getArray(2).numElements() == 0) - - assert(column.getArray(3).numElements() == 3) - assert(column.getArray(3).getInt(0) == 3) - assert(column.getArray(3).getInt(1) == 4) - assert(column.getArray(3).getInt(2) == 5) + assert(column.getArray(4).getInt(0) == 3) + assert(column.getArray(4).getInt(1) == 4) + assert(column.getArray(4).getInt(2) == 5) // Add a longer array which requires resizing column.reset() @@ -711,8 +713,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(data.capacity == array.length * 2) data.putInts(0, array.length, array, 0) column.putArray(0, 0, array.length) - assert(ColumnVectorUtils.toJavaIntArray(column.getArray(0)) - === array) + assert(ColumnVectorUtils.toJavaIntArray(column.getArray(0)) === array) } test("toArray for primitive types") { @@ -770,6 +771,43 @@ class ColumnarBatchSuite extends SparkFunSuite { } } + test("Int Map") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => + val column = allocate(10, new MapType(IntegerType, IntegerType, false), memMode) + (0 to 1).foreach { colIndex => + val data = column.getChild(colIndex) + (0 to 5).foreach {i => + data.putInt(i, i * (colIndex + 1)) + } + } + + // Populate it with maps [0->0], [1->2, 2->4], null, [], [3->6, 4->8, 5->10] + column.putArray(0, 0, 1) + column.putArray(1, 1, 2) + column.putNull(2) + column.putArray(3, 3, 0) + column.putArray(4, 3, 3) + + assert(column.getMap(0).numElements == 1) + assert(column.getMap(1).numElements == 2) + assert(column.isNullAt(2)) + assert(column.getMap(3).numElements == 0) + assert(column.getMap(4).numElements == 3) + + val a1 = ColumnVectorUtils.toJavaIntMap(column.getMap(0)) + val a2 = ColumnVectorUtils.toJavaIntMap(column.getMap(1)) + val a4 = ColumnVectorUtils.toJavaIntMap(column.getMap(3)) + val a5 = ColumnVectorUtils.toJavaIntMap(column.getMap(4)) + + assert(a1.asScala == Map(0 -> 0)) + assert(a2.asScala == Map(1 -> 2, 2 -> 4)) + assert(a4.asScala == Map()) + assert(a5.asScala == Map(3 -> 6, 4 -> 8, 5 -> 10)) + + column.close() + } + } + testVector( "Struct Column", 10, From 2ac895be909de7e58e1051dc2a1bba98a25bf4be Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Thu, 1 Feb 2018 12:05:12 +0800 Subject: [PATCH 0256/2461] [SPARK-23247][SQL] combines Unsafe operations and statistics operations in Scan Data Source ## What changes were proposed in this pull request? Currently, we scan the execution plan of the data source, first the unsafe operation of each row of data, and then re traverse the data for the count of rows. In terms of performance, this is not necessary. this PR combines the two operations and makes statistics on the number of rows while performing the unsafe operation. Before modified, ``` val unsafeRow = rdd.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(schema) proj.initialize(index) iter.map(proj) } val numOutputRows = longMetric("numOutputRows") unsafeRow.map { r => numOutputRows += 1 r } ``` After modified, val numOutputRows = longMetric("numOutputRows") rdd.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(schema) proj.initialize(index) iter.map( r => { numOutputRows += 1 proj(r) }) } ## How was this patch tested? the existed test cases. Author: caoxuewen Closes #20415 from heary-cao/DataSourceScanExec. --- .../sql/execution/DataSourceScanExec.scala | 45 +++++++++---------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index f7732e2098c29..ba1157d5b6a49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -90,16 +90,15 @@ case class RowDataSourceScanExec( Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) protected override def doExecute(): RDD[InternalRow] = { - val unsafeRow = rdd.mapPartitionsWithIndexInternal { (index, iter) => + val numOutputRows = longMetric("numOutputRows") + + rdd.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(schema) proj.initialize(index) - iter.map(proj) - } - - val numOutputRows = longMetric("numOutputRows") - unsafeRow.map { r => - numOutputRows += 1 - r + iter.map( r => { + numOutputRows += 1 + proj(r) + }) } } @@ -326,22 +325,22 @@ case class FileSourceScanExec( // 2) the number of columns should be smaller than spark.sql.codegen.maxFields WholeStageCodegenExec(this)(codegenStageId = 0).execute() } else { - val unsafeRows = { - val scan = inputRDD - if (needsUnsafeRowConversion) { - scan.mapPartitionsWithIndexInternal { (index, iter) => - val proj = UnsafeProjection.create(schema) - proj.initialize(index) - iter.map(proj) - } - } else { - scan - } - } val numOutputRows = longMetric("numOutputRows") - unsafeRows.map { r => - numOutputRows += 1 - r + + if (needsUnsafeRowConversion) { + inputRDD.mapPartitionsWithIndexInternal { (index, iter) => + val proj = UnsafeProjection.create(schema) + proj.initialize(index) + iter.map( r => { + numOutputRows += 1 + proj(r) + }) + } + } else { + inputRDD.map { r => + numOutputRows += 1 + r + } } } } From 56ae32657e9e5d1e30b62afe77d9e14eb07cf4fb Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Wed, 31 Jan 2018 20:33:51 -0800 Subject: [PATCH 0257/2461] [SPARK-23268][SQL] Reorganize packages in data source V2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? 1. create a new package for partitioning/distribution related classes. As Spark will add new concrete implementations of `Distribution` in new releases, it is good to have a new package for partitioning/distribution related classes. 2. move streaming related class to package `org.apache.spark.sql.sources.v2.reader/writer.streaming`, instead of `org.apache.spark.sql.sources.v2.streaming.reader/writer`. So that the there won't be package reader/writer inside package streaming, which is quite confusing. Before change: ``` v2 ├── reader ├── streaming │   ├── reader │   └── writer └── writer ``` After change: ``` v2 ├── reader │   └── streaming └── writer └── streaming ``` ## How was this patch tested? Unit test. Author: Wang Gengliang Closes #20435 from gengliangwang/new_pkg. --- .../spark/sql/kafka010/KafkaContinuousReader.scala | 2 +- .../apache/spark/sql/kafka010/KafkaSourceOffset.scala | 2 +- .../spark/sql/kafka010/KafkaSourceProvider.scala | 5 +++-- .../apache/spark/sql/kafka010/KafkaStreamWriter.scala | 2 +- .../{streaming => reader}/ContinuousReadSupport.java | 4 ++-- .../{streaming => reader}/MicroBatchReadSupport.java | 4 ++-- .../sources/v2/reader/SupportsReportPartitioning.java | 1 + .../{ => partitioning}/ClusteredDistribution.java | 3 ++- .../v2/reader/{ => partitioning}/Distribution.java | 3 ++- .../v2/reader/{ => partitioning}/Partitioning.java | 4 +++- .../streaming}/ContinuousDataReader.java | 2 +- .../reader => reader/streaming}/ContinuousReader.java | 2 +- .../reader => reader/streaming}/MicroBatchReader.java | 2 +- .../{streaming/reader => reader/streaming}/Offset.java | 2 +- .../reader => reader/streaming}/PartitionOffset.java | 2 +- .../spark/sql/sources/v2/writer/DataSourceWriter.java | 2 +- .../v2/{streaming => writer}/StreamWriteSupport.java | 5 ++--- .../writer => writer/streaming}/StreamWriter.java | 2 +- .../datasources/v2/DataSourcePartitioning.scala | 2 +- .../datasources/v2/DataSourceV2ScanExec.scala | 2 +- .../execution/datasources/v2/WriteToDataSourceV2.scala | 2 +- .../sql/execution/streaming/MicroBatchExecution.scala | 6 +++--- .../sql/execution/streaming/RateSourceProvider.scala | 5 ++--- .../sql/execution/streaming/RateStreamOffset.scala | 2 +- .../sql/execution/streaming/StreamingRelation.scala | 2 +- .../apache/spark/sql/execution/streaming/console.scala | 4 ++-- .../continuous/ContinuousDataSourceRDDIter.scala | 10 +++------- .../streaming/continuous/ContinuousExecution.scala | 5 +++-- .../continuous/ContinuousRateStreamSource.scala | 2 +- .../streaming/continuous/EpochCoordinator.scala | 4 ++-- .../execution/streaming/sources/ConsoleWriter.scala | 2 +- .../execution/streaming/sources/MicroBatchWriter.scala | 2 +- .../streaming/sources/RateStreamSourceV2.scala | 3 +-- .../sql/execution/streaming/sources/memoryV2.scala | 3 +-- .../apache/spark/sql/streaming/DataStreamReader.scala | 2 +- .../apache/spark/sql/streaming/DataStreamWriter.scala | 2 +- .../spark/sql/streaming/StreamingQueryManager.scala | 2 +- .../sql/sources/v2/JavaPartitionAwareDataSource.java | 3 +++ .../sql/execution/streaming/RateSourceV2Suite.scala | 2 +- .../spark/sql/sources/v2/DataSourceV2Suite.scala | 1 + .../streaming/sources/StreamingDataSourceV2Suite.scala | 8 ++++---- 41 files changed, 64 insertions(+), 61 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming => reader}/ContinuousReadSupport.java (94%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming => reader}/MicroBatchReadSupport.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{ => partitioning}/ClusteredDistribution.java (92%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{ => partitioning}/Distribution.java (93%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{ => partitioning}/Partitioning.java (90%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming/reader => reader/streaming}/ContinuousDataReader.java (96%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming/reader => reader/streaming}/ContinuousReader.java (98%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming/reader => reader/streaming}/MicroBatchReader.java (98%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming/reader => reader/streaming}/Offset.java (97%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming/reader => reader/streaming}/PartitionOffset.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming => writer}/StreamWriteSupport.java (93%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming/writer => writer/streaming}/StreamWriter.java (98%) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 8c733426b256f..41c443bc12120 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRo import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index c82154cfbad7f..8d41c0da2b133 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} -import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2, PartitionOffset} /** * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 85e96b6783327..694ca76e24964 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -31,8 +31,9 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSessio import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala index a24efdefa4464..9307bfc001c03 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala @@ -22,8 +22,8 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.types.StructType /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java similarity index 94% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java index f79424e036a52..0c1d5d1a9577a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming; +package org.apache.spark.sql.sources.v2.reader; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader; +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader; import org.apache.spark.sql.types.StructType; /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java index 22660e42ad850..5e8f0c0dafdcf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming; +package org.apache.spark.sql.sources.v2.reader; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.streaming.reader.MicroBatchReader; +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader; import org.apache.spark.sql.types.StructType; /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index a2383a9d7d680..5405a916951b8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources.v2.reader; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; /** * A mix in interface for {@link DataSourceReader}. Data source readers can implement this diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java similarity index 92% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java index 27905e325df87..2d0ee50212b56 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java @@ -15,9 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.DataReader; /** * A concrete implementation of {@link Distribution}. Represents a distribution where records that diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java similarity index 93% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java index b37562167d9ef..f6b111fdf220d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java @@ -15,9 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.DataReader; /** * An interface to represent data distribution requirement, which specifies how the records should diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java similarity index 90% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java index 5e334d13a1215..309d9e5de0a0f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java @@ -15,9 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; +import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning; /** * An interface to represent the output data partitioning for a data source, which is returned by diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java similarity index 96% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java index 3f13a4dbf5793..47d26440841fd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.reader.DataReader; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java similarity index 98% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java index 6e5177ee83a62..d1d1e7ffd1dd4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java similarity index 98% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java index fcec446d892f5..67ebde30d61a9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java similarity index 97% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java index abba3e7188b13..e41c0351edc82 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java index 4688b85f49f5f..383e73db6762b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import java.io.Serializable; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index d89d27d0e5b1b..7096aec0d22c2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -28,7 +28,7 @@ /** * A data source writer that is returned by * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceOptions)}/ - * {@link org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport#createStreamWriter( + * {@link StreamWriteSupport#createStreamWriter( * String, StructType, OutputMode, DataSourceOptions)}. * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java similarity index 93% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java index 7c5f304425093..1c0e2e12f8d51 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java @@ -15,14 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming; +package org.apache.spark.sql.sources.v2.writer; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter; -import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java similarity index 98% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java index 915ee6c4fb390..4913341bd505d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.writer; +package org.apache.spark.sql.sources.v2.writer.streaming; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala index 943d0100aca56..017a6737161a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression} import org.apache.spark.sql.catalyst.plans.physical -import org.apache.spark.sql.sources.v2.reader.{ClusteredDistribution, Partitioning} +import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Partitioning} /** * An adapter from public data source partitioning to catalyst internal `Partitioning`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index ee085820b0775..df469af2c262a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types.StructType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index c544adbf32cdf..6592bd72fa338 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 93572f7a63132..d9aa8573ba930 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -30,9 +30,9 @@ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.{MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset => OffsetV2} -import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow +import org.apache.spark.sql.sources.v2.reader.MicroBatchReadSupport +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.writer.{StreamWriteSupport, SupportsWriteInternalRow} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index 5e3fee633f591..ce5e63f5bde85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -30,11 +30,10 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader -import org.apache.spark.sql.execution.streaming.sources.RateStreamMicroBatchReader import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader} +import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types._ import org.apache.spark.util.{ManualClock, SystemClock} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala index 261d69bbd9843..02fed50485b94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala @@ -23,7 +23,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.sources.v2 case class RateStreamOffset(partitionToValueAndRunTimeMs: Map[Int, ValueRunTimeMsPair]) - extends v2.streaming.reader.Offset { + extends v2.reader.streaming.Offset { implicit val defaultFormats: DefaultFormats = DefaultFormats override val json = Serialization.write(partitionToValueAndRunTimeMs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index a0ee683a895d8..845c8d2c14e43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.sources.v2.DataSourceV2 -import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 3f5bb489d6528..db600866067bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -21,8 +21,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} -import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index 8a7a38b22caca..cf02c0dda25d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -18,23 +18,19 @@ package org.apache.spark.sql.execution.streaming.continuous import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit} -import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} +import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader} -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, PartitionOffset} -import org.apache.spark.sql.streaming.ProcessingTime -import org.apache.spark.util.{SystemClock, ThreadUtils} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset} +import org.apache.spark.util.ThreadUtils class ContinuousDataSourceRDD( sc: SparkContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 9402d7c1dcefd..08c81419a9d34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -32,8 +32,9 @@ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index ff028ebc4236a..0eaaa4889ba9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamO import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType case class RateStreamPartitionOffset( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 84d262116cb46..cc6808065c0cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -23,9 +23,9 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, PartitionOffset} -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.util.RpcUtils private[continuous] sealed trait EpochCoordinatorMessage extends Serializable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala index c57bdc4a28905..d276403190b3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -22,8 +22,8 @@ import scala.collection.JavaConverters._ import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.types.StructType /** Common methods used to create writes for the the console sink */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala index d7ce9a7b84479..56f7ff25cbed0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter /** * A [[DataSourceWriter]] used to hook V2 stream writers into a microbatch plan. It implements diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index 43949e6180aaa..1315885da8a6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -31,8 +31,7 @@ import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeM import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport -import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} import org.apache.spark.util.{ManualClock, SystemClock} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 58767261dc684..3411edbc53412 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -30,9 +30,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.Sink import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} -import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index f1b3f93c4e1fc..116ac3da07b75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 3b5b30d77945c..9aac360fd4bbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index fdd709cdb1f38..ddb1edc433d5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger} import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport import org.apache.spark.util.{Clock, SystemClock, Utils} /** diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index 99cca0f6dd626..32fad59b97ff6 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -27,6 +27,9 @@ import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.sources.v2.reader.partitioning.ClusteredDistribution; +import org.apache.spark.sql.sources.v2.reader.partitioning.Distribution; +import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; import org.apache.spark.sql.types.StructType; public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index b060aeeef811d..3158995ec62f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index ee50e8a92270b..2f49b07018aaf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 3127d664d32dc..cb873ab688e96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -26,10 +26,10 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory -import org.apache.spark.sql.sources.v2.streaming._ -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.sources.v2.reader.{ContinuousReadSupport, DataReaderFactory, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils From b2e7677f4d3d8f47f5f148680af39d38f2b558f0 Mon Sep 17 00:00:00 2001 From: Atallah Hezbor Date: Wed, 31 Jan 2018 20:45:55 -0800 Subject: [PATCH 0258/2461] [SPARK-21396][SQL] Fixes MatchError when UDTs are passed through Hive Thriftserver Signed-off-by: Atallah Hezbor ## What changes were proposed in this pull request? This PR proposes modifying the match statement that gets the columns of a row in HiveThriftServer. There was previously no case for `UserDefinedType`, so querying a table that contained them would throw a match error. The changes catch that case and return the string representation. ## How was this patch tested? While I would have liked to add a unit test, I couldn't easily incorporate UDTs into the ``HiveThriftServer2Suites`` pipeline. With some guidance I would be happy to push a commit with tests. Instead I did a manual test by loading a `DataFrame` with Point UDT in a spark shell with a HiveThriftServer. Then in beeline, connecting to the server and querying that table. Here is the result before the change ``` 0: jdbc:hive2://localhost:10000> select * from chicago; Error: scala.MatchError: org.apache.spark.sql.PointUDT2d980dc3 (of class org.apache.spark.sql.PointUDT) (state=,code=0) ``` And after the change: ``` 0: jdbc:hive2://localhost:10000> select * from chicago; +---------------------------------------+--------------+------------------------+---------------------+--+ | __fid__ | case_number | dtg | geom | +---------------------------------------+--------------+------------------------+---------------------+--+ | 109602f9-54f8-414b-8c6f-42b1a337643e | 2 | 2016-01-01 19:00:00.0 | POINT (-77 38) | | 709602f9-fcff-4429-8027-55649b6fd7ed | 1 | 2015-12-31 19:00:00.0 | POINT (-76.5 38.5) | | 009602f9-fcb5-45b1-a867-eb8ba10cab40 | 3 | 2016-01-02 19:00:00.0 | POINT (-78 39) | +---------------------------------------+--------------+------------------------+---------------------+--+ ``` Author: Atallah Hezbor Closes #20385 from atallahhezbor/udts_over_hive. --- .../thriftserver/SparkExecuteStatementOperation.scala | 2 +- .../main/scala/org/apache/spark/sql/hive/HiveUtils.scala | 1 + .../scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala | 8 +++++++- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 664bc20601eaa..3cfc81b8a9579 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -102,7 +102,7 @@ private[hive] class SparkExecuteStatementOperation( to += from.getAs[Timestamp](ordinal) case BinaryType => to += from.getAs[Array[Byte]](ordinal) - case _: ArrayType | _: StructType | _: MapType => + case _: ArrayType | _: StructType | _: MapType | _: UserDefinedType[_] => val hiveString = HiveUtils.toHiveString((from.get(ordinal), dataTypes(ordinal))) to += hiveString } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index c7717d70c996f..d9627eb9790eb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -460,6 +460,7 @@ private[spark] object HiveUtils extends Logging { case (decimal: java.math.BigDecimal, DecimalType()) => // Hive strips trailing zeros so use its toString HiveDecimal.create(decimal).toString + case (other, _ : UserDefinedType[_]) => other.toString case (other, tpe) if primitiveTypes contains tpe => other.toString } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala index 8697d47e89e89..f2b75e4b23f02 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.QueryTest import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SQLTestUtils} import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader} class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { @@ -62,4 +62,10 @@ class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton Thread.currentThread().setContextClassLoader(contextClassLoader) } } + + test("toHiveString correctly handles UDTs") { + val point = new ExamplePoint(50.0, 50.0) + val tpe = new ExamplePointUDT() + assert(HiveUtils.toHiveString((point, tpe)) === "(50.0, 50.0)") + } } From cc41245fa3f954f961541bf4b4275c28473042b8 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 1 Feb 2018 12:56:07 +0800 Subject: [PATCH 0259/2461] [SPARK-23188][SQL] Make vectorized columar reader batch size configurable ## What changes were proposed in this pull request? This PR include the following changes: - Make the capacity of `VectorizedParquetRecordReader` configurable; - Make the capacity of `OrcColumnarBatchReader` configurable; - Update the error message when required capacity in writable columnar vector cannot be fulfilled. ## How was this patch tested? N/A Author: Xingbo Jiang Closes #20361 from jiangxb1987/vectorCapacity. --- .../apache/spark/sql/internal/SQLConf.scala | 16 ++++++++++++++ .../orc/OrcColumnarBatchReader.java | 22 ++++++++++--------- .../VectorizedParquetRecordReader.java | 20 ++++++++--------- .../vectorized/WritableColumnVector.java | 7 ++++-- .../datasources/orc/OrcFileFormat.scala | 3 ++- .../parquet/ParquetFileFormat.scala | 3 ++- .../parquet/ParquetEncodingSuite.scala | 12 +++++++--- .../datasources/parquet/ParquetIOSuite.scala | 21 +++++++++++++----- .../parquet/ParquetReadBenchmark.scala | 11 +++++++--- 9 files changed, 78 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7394a0d7cf983..90654e67457e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -375,6 +375,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_VECTORIZED_READER_BATCH_SIZE = buildConf("spark.sql.parquet.columnarReaderBatchSize") + .doc("The number of rows to include in a parquet vectorized reader batch. The number should " + + "be carefully chosen to minimize overhead and avoid OOMs in reading data.") + .intConf + .createWithDefault(4096) + val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec") .doc("Sets the compression codec used when writing ORC files. If either `compression` or " + "`orc.compress` is specified in the table-specific options/properties, the precedence " + @@ -398,6 +404,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val ORC_VECTORIZED_READER_BATCH_SIZE = buildConf("spark.sql.orc.columnarReaderBatchSize") + .doc("The number of rows to include in a orc vectorized reader batch. The number should " + + "be carefully chosen to minimize overhead and avoid OOMs in reading data.") + .intConf + .createWithDefault(4096) + val ORC_COPY_BATCH_TO_SPARK = buildConf("spark.sql.orc.copyBatchToSpark") .doc("Whether or not to copy the ORC columnar batch to Spark columnar batch in the " + "vectorized ORC reader.") @@ -1250,10 +1262,14 @@ class SQLConf extends Serializable with Logging { def orcVectorizedReaderEnabled: Boolean = getConf(ORC_VECTORIZED_READER_ENABLED) + def orcVectorizedReaderBatchSize: Int = getConf(ORC_VECTORIZED_READER_BATCH_SIZE) + def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED) + def parquetVectorizedReaderBatchSize: Int = getConf(PARQUET_VECTORIZED_READER_BATCH_SIZE) + def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index 5e7cad470e1d1..dcebdc39f0aa2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -49,8 +49,9 @@ * After creating, `initialize` and `initBatch` should be called sequentially. */ public class OrcColumnarBatchReader extends RecordReader { - // TODO: make this configurable. - private static final int CAPACITY = 4 * 1024; + + // The capacity of vectorized batch. + private int capacity; // Vectorized ORC Row Batch private VectorizedRowBatch batch; @@ -81,9 +82,10 @@ public class OrcColumnarBatchReader extends RecordReader { // Whether or not to copy the ORC columnar batch to Spark columnar batch. private final boolean copyToSpark; - public OrcColumnarBatchReader(boolean useOffHeap, boolean copyToSpark) { + public OrcColumnarBatchReader(boolean useOffHeap, boolean copyToSpark, int capacity) { MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; this.copyToSpark = copyToSpark; + this.capacity = capacity; } @@ -148,7 +150,7 @@ public void initBatch( StructField[] requiredFields, StructType partitionSchema, InternalRow partitionValues) { - batch = orcSchema.createRowBatch(CAPACITY); + batch = orcSchema.createRowBatch(capacity); assert(!batch.selectedInUse); // `selectedInUse` should be initialized with `false`. this.requiredFields = requiredFields; @@ -162,15 +164,15 @@ public void initBatch( if (copyToSpark) { if (MEMORY_MODE == MemoryMode.OFF_HEAP) { - columnVectors = OffHeapColumnVector.allocateColumns(CAPACITY, resultSchema); + columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); } else { - columnVectors = OnHeapColumnVector.allocateColumns(CAPACITY, resultSchema); + columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); } // Initialize the missing columns once. for (int i = 0; i < requiredFields.length; i++) { if (requestedColIds[i] == -1) { - columnVectors[i].putNulls(0, CAPACITY); + columnVectors[i].putNulls(0, capacity); columnVectors[i].setIsConstant(); } } @@ -193,8 +195,8 @@ public void initBatch( int colId = requestedColIds[i]; // Initialize the missing columns once. if (colId == -1) { - OnHeapColumnVector missingCol = new OnHeapColumnVector(CAPACITY, dt); - missingCol.putNulls(0, CAPACITY); + OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); + missingCol.putNulls(0, capacity); missingCol.setIsConstant(); orcVectorWrappers[i] = missingCol; } else { @@ -206,7 +208,7 @@ public void initBatch( int partitionIdx = requiredFields.length; for (int i = 0; i < partitionValues.numFields(); i++) { DataType dt = partitionSchema.fields()[i].dataType(); - OnHeapColumnVector partitionCol = new OnHeapColumnVector(CAPACITY, dt); + OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt); ColumnVectorUtils.populate(partitionCol, partitionValues, i); partitionCol.setIsConstant(); orcVectorWrappers[partitionIdx + i] = partitionCol; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index bb1b23611a7d7..5934a23db8af1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -50,8 +50,9 @@ * TODO: make this always return ColumnarBatches. */ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBase { - // TODO: make this configurable. - private static final int CAPACITY = 4 * 1024; + + // The capacity of vectorized batch. + private int capacity; /** * Batch of rows that we assemble and the current index we've returned. Every time this @@ -115,13 +116,10 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa */ private final MemoryMode MEMORY_MODE; - public VectorizedParquetRecordReader(TimeZone convertTz, boolean useOffHeap) { + public VectorizedParquetRecordReader(TimeZone convertTz, boolean useOffHeap, int capacity) { this.convertTz = convertTz; MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; - } - - public VectorizedParquetRecordReader(boolean useOffHeap) { - this(null, useOffHeap); + this.capacity = capacity; } /** @@ -199,9 +197,9 @@ private void initBatch( } if (memMode == MemoryMode.OFF_HEAP) { - columnVectors = OffHeapColumnVector.allocateColumns(CAPACITY, batchSchema); + columnVectors = OffHeapColumnVector.allocateColumns(capacity, batchSchema); } else { - columnVectors = OnHeapColumnVector.allocateColumns(CAPACITY, batchSchema); + columnVectors = OnHeapColumnVector.allocateColumns(capacity, batchSchema); } columnarBatch = new ColumnarBatch(columnVectors); if (partitionColumns != null) { @@ -215,7 +213,7 @@ private void initBatch( // Initialize missing columns with nulls. for (int i = 0; i < missingColumns.length; i++) { if (missingColumns[i]) { - columnVectors[i].putNulls(0, CAPACITY); + columnVectors[i].putNulls(0, capacity); columnVectors[i].setIsConstant(); } } @@ -257,7 +255,7 @@ public boolean nextBatch() throws IOException { if (rowsReturned >= totalRowCount) return false; checkEndOfRowGroup(); - int num = (int) Math.min((long) CAPACITY, totalCountLoadedSoFar - rowsReturned); + int num = (int) Math.min((long) capacity, totalCountLoadedSoFar - rowsReturned); for (int i = 0; i < columnReaders.length; ++i) { if (columnReaders[i] == null) continue; columnReaders[i].readBatch(num, columnVectors[i]); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index c2e595455549c..9d447cdc79063 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -98,8 +98,11 @@ public void reserve(int requiredCapacity) { private void throwUnsupportedException(int requiredCapacity, Throwable cause) { String message = "Cannot reserve additional contiguous bytes in the vectorized reader " + "(requested = " + requiredCapacity + " bytes). As a workaround, you can disable the " + - "vectorized reader by setting " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + - " to false."; + "vectorized reader, or increase the vectorized reader batch size. For parquet file " + + "format, refer to " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + " and " + + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() + "; for orc file format, refer to " + + SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + " and " + + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() + "."; throw new RuntimeException(message, cause); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 2dd314d165348..dbf3bc6f0ee6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -151,6 +151,7 @@ class OrcFileFormat val sqlConf = sparkSession.sessionState.conf val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled val enableVectorizedReader = supportBatch(sparkSession, resultSchema) + val capacity = sqlConf.orcVectorizedReaderBatchSize val copyToSpark = sparkSession.sessionState.conf.getConf(SQLConf.ORC_COPY_BATCH_TO_SPARK) val broadcastedConf = @@ -186,7 +187,7 @@ class OrcFileFormat val taskContext = Option(TaskContext.get()) if (enableVectorizedReader) { val batchReader = new OrcColumnarBatchReader( - enableOffHeapColumnVector && taskContext.isDefined, copyToSpark) + enableOffHeapColumnVector && taskContext.isDefined, copyToSpark, capacity) batchReader.initialize(fileSplit, taskAttemptContext) batchReader.initBatch( reader.getSchema, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index f53a97ba45a26..ba69f9a26c968 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -350,6 +350,7 @@ class ParquetFileFormat sparkSession.sessionState.conf.parquetRecordFilterEnabled val timestampConversion: Boolean = sparkSession.sessionState.conf.isParquetINT96TimestampConversion + val capacity = sqlConf.parquetVectorizedReaderBatchSize // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) @@ -396,7 +397,7 @@ class ParquetFileFormat val taskContext = Option(TaskContext.get()) val parquetReader = if (enableVectorizedReader) { val vectorizedReader = new VectorizedParquetRecordReader( - convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined) + convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined, capacity) vectorizedReader.initialize(split, hadoopAttemptContext) logDebug(s"Appending $partitionSchema ${file.partitionValues}") vectorizedReader.initBatch(partitionSchema, file.partitionValues) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala index edb1290ee2eb0..db73bfa149aa0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -40,7 +40,9 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex List.fill(n)(ROW).toDF.repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) reader.initialize(file.asInstanceOf[String], null) val batch = reader.resultBatch() assert(reader.nextBatch()) @@ -65,7 +67,9 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex data.repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) reader.initialize(file.asInstanceOf[String], null) val batch = reader.resultBatch() assert(reader.nextBatch()) @@ -94,7 +98,9 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex data.toDF("f").coalesce(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).asScala.head - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) reader.initialize(file, null /* set columns to null to project all columns */) val column = reader.resultBatch().column(0) assert(reader.nextBatch()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index f3ece5b15e26a..3af80930ec807 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -653,7 +653,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { spark.createDataFrame(data).repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0); { - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) try { reader.initialize(file, null) val result = mutable.ArrayBuffer.empty[(Int, String)] @@ -670,7 +672,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Project just one column { - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) try { reader.initialize(file, ("_2" :: Nil).asJava) val result = mutable.ArrayBuffer.empty[(String)] @@ -686,7 +690,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Project columns in opposite order { - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) try { reader.initialize(file, ("_2" :: "_1" :: Nil).asJava) val result = mutable.ArrayBuffer.empty[(String, Int)] @@ -703,7 +709,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Empty projection { - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) try { reader.initialize(file, List[String]().asJava) var result = 0 @@ -742,8 +750,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { dataTypes.zip(constantValues).foreach { case (dt, v) => val schema = StructType(StructField("pcol", dt) :: Nil) - val vectorizedReader = - new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val vectorizedReader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) val partitionValues = new GenericInternalRow(Array(v)) val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index 86a3c71a3c4f6..e43336d947364 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -76,6 +76,7 @@ object ParquetReadBenchmark { withTempPath { dir => withTempTable("t1", "tempTable") { val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled + val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize spark.range(values).createOrReplaceTempView("t1") spark.sql("select cast(id as INT) as id from t1") .write.parquet(dir.getCanonicalPath) @@ -96,7 +97,8 @@ object ParquetReadBenchmark { parquetReaderBenchmark.addCase("ParquetReader Vectorized") { num => var sum = 0L files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader(enableOffHeapColumnVector) + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) try { reader.initialize(p, ("id" :: Nil).asJava) val batch = reader.resultBatch() @@ -119,7 +121,8 @@ object ParquetReadBenchmark { parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num => var sum = 0L files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader(enableOffHeapColumnVector) + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) try { reader.initialize(p, ("id" :: Nil).asJava) val batch = reader.resultBatch() @@ -262,6 +265,7 @@ object ParquetReadBenchmark { withTempPath { dir => withTempTable("t1", "tempTable") { val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled + val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize spark.range(values).createOrReplaceTempView("t1") spark.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " + s"IF(rand(2) < $fractionOfNulls, NULL, cast(id as STRING)) as c2 from t1") @@ -279,7 +283,8 @@ object ParquetReadBenchmark { benchmark.addCase("PR Vectorized") { num => var sum = 0 files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader(enableOffHeapColumnVector) + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) try { reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) val batch = reader.resultBatch() From b6b50efc854f298d5b3e11c05dca995a85bec962 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Wed, 31 Jan 2018 20:59:19 -0800 Subject: [PATCH 0260/2461] [SQL][MINOR] Inline SpecifiedWindowFrame.defaultWindowFrame(). ## What changes were proposed in this pull request? SpecifiedWindowFrame.defaultWindowFrame(hasOrderSpecification, acceptWindowFrame) was designed to handle the cases when some Window functions don't support setting a window frame (e.g. rank). However this param is never used. We may inline the whole of this function to simplify the code. ## How was this patch tested? Existing tests. Author: Xingbo Jiang Closes #20463 from jiangxb1987/defaultWindowFrame. --- .../sql/catalyst/analysis/Analyzer.scala | 6 +++++- .../expressions/windowExpressions.scala | 21 ------------------- .../catalyst/ExpressionSQLBuilderSuite.scala | 5 +---- 3 files changed, 6 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 251099f750cf6..7848f88bda1c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2038,7 +2038,11 @@ class Analyzer( WindowExpression(wf, s.copy(frameSpecification = wf.frame)) case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) if e.resolved => - val frame = SpecifiedWindowFrame.defaultWindowFrame(o.nonEmpty, acceptWindowFrame = true) + val frame = if (o.nonEmpty) { + SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) + } else { + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing) + } we.copy(windowSpec = s.copy(frameSpecification = frame)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index dd13d9a3bba51..78895f1c2f6f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -265,27 +265,6 @@ case class SpecifiedWindowFrame( } } -object SpecifiedWindowFrame { - /** - * @param hasOrderSpecification If the window spec has order by expressions. - * @param acceptWindowFrame If the window function accepts user-specified frame. - * @return the default window frame. - */ - def defaultWindowFrame( - hasOrderSpecification: Boolean, - acceptWindowFrame: Boolean): SpecifiedWindowFrame = { - if (hasOrderSpecification && acceptWindowFrame) { - // If order spec is defined and the window function supports user specified window frames, - // the default frame is RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW. - SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) - } else { - // Otherwise, the default frame is - // ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING. - SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing) - } - } -} - case class UnresolvedWindowExpression( child: Expression, windowSpec: WindowSpecReference) extends UnaryExpression with Unevaluable { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala index d9cf1f361c1d6..61f9179042fe4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala @@ -108,10 +108,7 @@ class ExpressionSQLBuilderSuite extends QueryTest with TestHiveSingleton { } test("window specification") { - val frame = SpecifiedWindowFrame.defaultWindowFrame( - hasOrderSpecification = true, - acceptWindowFrame = true - ) + val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) checkSQL( WindowSpecDefinition('a.int :: Nil, Nil, frame), From 4b7cd479a28b274f5a0802c9b017b3eb15002c21 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 1 Feb 2018 13:58:13 +0800 Subject: [PATCH 0261/2461] Revert "[SPARK-23200] Reset Kubernetes-specific config on Checkpoint restore" This reverts commit d1721816d26bedee3c72eeb75db49da500568376. The patch is not fully tested and out-of-date. So revert it. --- .../org/apache/spark/streaming/Checkpoint.scala | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index ed2a896033749..aed67a5027433 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -53,21 +53,6 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.driver.host", "spark.driver.bindAddress", "spark.driver.port", - "spark.kubernetes.driver.pod.name", - "spark.kubernetes.executor.podNamePrefix", - "spark.kubernetes.initcontainer.executor.configmapname", - "spark.kubernetes.initcontainer.executor.configmapkey", - "spark.kubernetes.initcontainer.downloadJarsResourceIdentifier", - "spark.kubernetes.initcontainer.downloadJarsSecretLocation", - "spark.kubernetes.initcontainer.downloadFilesResourceIdentifier", - "spark.kubernetes.initcontainer.downloadFilesSecretLocation", - "spark.kubernetes.initcontainer.remoteJars", - "spark.kubernetes.initcontainer.remoteFiles", - "spark.kubernetes.mountdependencies.jarsDownloadDir", - "spark.kubernetes.mountdependencies.filesDownloadDir", - "spark.kubernetes.initcontainer.executor.stagingServerSecret.name", - "spark.kubernetes.initcontainer.executor.stagingServerSecret.mountDir", - "spark.kubernetes.executor.limit.cores", "spark.master", "spark.yarn.jars", "spark.yarn.keytab", @@ -81,7 +66,6 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) .remove("spark.driver.host") .remove("spark.driver.bindAddress") - .remove("spark.kubernetes.driver.pod.name") .remove("spark.driver.port") val newReloadConf = new SparkConf(loadDefaults = true) propertiesToReload.foreach { prop => From 07cee33736aabf9e9a4a89344eda2b8ea29b27ea Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 31 Jan 2018 22:26:27 -0800 Subject: [PATCH 0262/2461] [SPARK-22274][PYTHON][SQL][FOLLOWUP] Use `assertRaisesRegexp` instead of `assertRaisesRegex`. ## What changes were proposed in this pull request? This is a follow-up pr of #19872 which uses `assertRaisesRegex` but it doesn't exist in Python 2, so some tests fail when running tests in Python 2 environment. Unfortunately, we missed it because currently Python 2 environment of the pr builder doesn't have proper versions of pandas or pyarrow, so the tests were skipped. This pr modifies to use `assertRaisesRegexp` instead of `assertRaisesRegex`. ## How was this patch tested? Tested manually in my local environment. Author: Takuya UESHIN Closes #20467 from ueshin/issues/SPARK-22274/fup1. --- python/pyspark/sql/tests.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index dc26b96334c7a..b27363023ae77 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4530,19 +4530,19 @@ def test_unsupported_types(self): from pyspark.sql.functions import pandas_udf, PandasUDFType with QuietTest(self.sc): - with self.assertRaisesRegex(NotImplementedError, 'not supported'): + with self.assertRaisesRegexp(NotImplementedError, 'not supported'): @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUPED_AGG) def mean_and_std_udf(v): return [v.mean(), v.std()] with QuietTest(self.sc): - with self.assertRaisesRegex(NotImplementedError, 'not supported'): + with self.assertRaisesRegexp(NotImplementedError, 'not supported'): @pandas_udf('mean double, std double', PandasUDFType.GROUPED_AGG) def mean_and_std_udf(v): return v.mean(), v.std() with QuietTest(self.sc): - with self.assertRaisesRegex(NotImplementedError, 'not supported'): + with self.assertRaisesRegexp(NotImplementedError, 'not supported'): @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUPED_AGG) def mean_and_std_udf(v): return {v.mean(): v.std()} From e15da5b14c8d845028365a609c0c66731d024ee7 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 1 Feb 2018 11:25:01 +0200 Subject: [PATCH 0263/2461] [SPARK-23107][ML] ML 2.3 QA: New Scala APIs, docs. ## What changes were proposed in this pull request? Audit new APIs and docs in 2.3.0. ## How was this patch tested? No test. Author: Yanbo Liang Closes #20459 from yanboliang/SPARK-23107. --- mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala | 2 +- .../scala/org/apache/spark/ml/regression/LinearRegression.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 1155ea5fdd85b..22e7b8bbf1ff5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -74,7 +74,7 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol with * @group param */ @Since("2.3.0") - final override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", + override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle invalid data (unseen or NULL values) in features and label column of string " + "type. Options are 'skip' (filter out rows with invalid data), error (throw an error), " + "or 'keep' (put invalid data in a special additional bucket, at index numLabels).", diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index a5873d03b4161..6d3fe7a6c748c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -645,7 +645,7 @@ class LinearRegressionModel private[ml] ( extends RegressionModel[Vector, LinearRegressionModel] with LinearRegressionParams with MLWritable { - def this(uid: String, coefficients: Vector, intercept: Double) = + private[ml] def this(uid: String, coefficients: Vector, intercept: Double) = this(uid, coefficients, intercept, 1.0) private var trainingSummary: Option[LinearRegressionTrainingSummary] = None From 8bb70b068ea782e799e45238fcb093a6acb0fc9f Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 1 Feb 2018 21:25:02 +0900 Subject: [PATCH 0264/2461] [SPARK-23280][SQL][FOLLOWUP] Fix Java style check issues. ## What changes were proposed in this pull request? This is a follow-up of #20450 which broke lint-java checks. This pr fixes the lint-java issues. ``` [ERROR] src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java:[20,8] (imports) UnusedImports: Unused import - org.apache.spark.sql.catalyst.util.MapData. [ERROR] src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java:[21,8] (imports) UnusedImports: Unused import - org.apache.spark.sql.catalyst.util.MapData. [ERROR] src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java:[22,8] (imports) UnusedImports: Unused import - org.apache.spark.sql.catalyst.util.MapData. ``` ## How was this patch tested? Checked manually in my local environment. Author: Takuya UESHIN Closes #20468 from ueshin/issues/SPARK-23280/fup1. --- .../main/java/org/apache/spark/sql/vectorized/ColumnVector.java | 1 - .../main/java/org/apache/spark/sql/vectorized/ColumnarArray.java | 1 - .../main/java/org/apache/spark/sql/vectorized/ColumnarRow.java | 1 - 3 files changed, 3 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index 05271ec1f46ab..530d4d23d4eaf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.vectorized; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.types.CalendarInterval; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 7c7a1c806a2b7..72a192d089b9f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -18,7 +18,6 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.util.ArrayData; -import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 0c9e92ed11fbd..b400f7f93c1fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -19,7 +19,6 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; From 89e8d556b93d1bf1b28fe153fd284f154045b0ee Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 1 Feb 2018 21:28:53 +0900 Subject: [PATCH 0265/2461] [SPARK-23280][SQL][FOLLOWUP] Enable `MutableColumnarRow.getMap()`. ## What changes were proposed in this pull request? This is a followup pr of #20450. We should've enabled `MutableColumnarRow.getMap()` as well. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #20471 from ueshin/issues/SPARK-23280/fup2. --- .../spark/sql/execution/vectorized/MutableColumnarRow.java | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 66668f3753604..307c19032dee5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -21,10 +21,10 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; import org.apache.spark.sql.vectorized.ColumnarArray; import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.sql.vectorized.ColumnarRow; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.types.CalendarInterval; @@ -162,8 +162,9 @@ public ColumnarArray getArray(int ordinal) { } @Override - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); + public ColumnarMap getMap(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getMap(rowId); } @Override From ffbca84519011a747e0552632e88f5e4956e493d Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 1 Feb 2018 20:39:15 +0800 Subject: [PATCH 0266/2461] [SPARK-23202][SQL] Add new API in DataSourceWriter: onDataWriterCommit ## What changes were proposed in this pull request? The current DataSourceWriter API makes it hard to implement `onTaskCommit(taskCommit: TaskCommitMessage)` in `FileCommitProtocol`. In general, on receiving commit message, driver can start processing messages(e.g. persist messages into files) before all the messages are collected. The proposal to add a new API: `add(WriterCommitMessage message)`: Handles a commit message on receiving from a successful data writer. This should make the whole API of DataSourceWriter compatible with `FileCommitProtocol`, and more flexible. There was another radical attempt in #20386. This one should be more reasonable. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #20454 from gengliangwang/write_api. --- .../sources/v2/writer/DataSourceWriter.java | 14 +++++++++++-- .../datasources/v2/WriteToDataSourceV2.scala | 5 ++++- .../sql/sources/v2/DataSourceV2Suite.scala | 21 ++++++++++++++++++- .../sources/v2/SimpleWritableDataSource.scala | 21 +++++++++++++++++++ 4 files changed, 57 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index 7096aec0d22c2..52324b3792b8a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -62,6 +62,14 @@ public interface DataSourceWriter { */ DataWriterFactory createWriterFactory(); + /** + * Handles a commit message on receiving from a successful data writer. + * + * If this method fails (by throwing an exception), this writing job is considered to to have been + * failed, and {@link #abort(WriterCommitMessage[])} would be called. + */ + default void onDataWriterCommit(WriterCommitMessage message) {} + /** * Commits this writing job with a list of commit messages. The commit messages are collected from * successful data writers and are produced by {@link DataWriter#commit()}. @@ -78,8 +86,10 @@ public interface DataSourceWriter { void commit(WriterCommitMessage[] messages); /** - * Aborts this writing job because some data writers are failed and keep failing when retry, or - * the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} fails. + * Aborts this writing job because some data writers are failed and keep failing when retry, + * or the Spark job fails with some unknown reasons, + * or {@link #onDataWriterCommit(WriterCommitMessage)} fails, + * or {@link #commit(WriterCommitMessage[])} fails. * * If this method fails (by throwing an exception), the underlying data source may require manual * cleanup. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 6592bd72fa338..eefbcf4c0e087 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -80,7 +80,10 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e rdd, runTask, rdd.partitions.indices, - (index, message: WriterCommitMessage) => messages(index) = message + (index, message: WriterCommitMessage) => { + messages(index) = message + writer.onDataWriterCommit(message) + } ) if (!writer.isInstanceOf[StreamWriter]) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 2f49b07018aaf..1c3ba7826f7de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -21,7 +21,7 @@ import java.util.{ArrayList, List => JList} import test.org.apache.spark.sql.sources.v2._ -import org.apache.spark.SparkException +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -198,6 +198,25 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } } + + test("simple counter in writer with onDataWriterCommit") { + Seq(classOf[SimpleWritableDataSource]).foreach { cls => + withTempPath { file => + val path = file.getCanonicalPath + assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) + + val numPartition = 6 + spark.range(0, 10, 1, numPartition).select('id, -'id).write.format(cls.getName) + .option("path", path).save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(10).select('id, -'id)) + + assert(SimpleCounter.getCounter == numPartition, + "method onDataWriterCommit should be called as many as the number of partitions") + } + } + } } class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index a131b16953e3b..36dd2a350a055 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -66,9 +66,14 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter { override def createWriterFactory(): DataWriterFactory[Row] = { + SimpleCounter.resetCounter new SimpleCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) } + override def onDataWriterCommit(message: WriterCommitMessage): Unit = { + SimpleCounter.increaseCounter + } + override def commit(messages: Array[WriterCommitMessage]): Unit = { val finalPath = new Path(path) val jobPath = new Path(new Path(finalPath, "_temporary"), jobId) @@ -183,6 +188,22 @@ class SimpleCSVDataReaderFactory(path: String, conf: SerializableConfiguration) } } +private[v2] object SimpleCounter { + private var count: Int = 0 + + def increaseCounter: Unit = { + count += 1 + } + + def getCounter: Int = { + count + } + + def resetCounter: Unit = { + count = 0 + } +} + class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) extends DataWriterFactory[Row] { From ec63e2d0743a4f75e1cce21d0fe2b54407a86a4a Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 1 Feb 2018 21:00:47 +0800 Subject: [PATCH 0267/2461] [SPARK-23289][CORE] OneForOneBlockFetcher.DownloadCallback.onData should write the buffer fully ## What changes were proposed in this pull request? `channel.write(buf)` may not write the whole buffer since the underlying channel is a FileChannel, we should retry until the whole buffer is written. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #20461 from zsxwing/SPARK-23289. --- .../apache/spark/network/shuffle/OneForOneBlockFetcher.java | 4 +++- core/src/test/scala/org/apache/spark/FileSuite.scala | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 9cac7d00cc6b6..0bc571874f07c 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -171,7 +171,9 @@ private class DownloadCallback implements StreamCallback { @Override public void onData(String streamId, ByteBuffer buf) throws IOException { - channel.write(buf); + while (buf.hasRemaining()) { + channel.write(buf); + } } @Override diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index e9539dc73f6fa..55a9122cf9026 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -244,7 +244,10 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { for (i <- 0 until testOutputCopies) { // Shift values by i so that they're different in the output val alteredOutput = testOutput.map(b => (b + i).toByte) - channel.write(ByteBuffer.wrap(alteredOutput)) + val buffer = ByteBuffer.wrap(alteredOutput) + while (buffer.hasRemaining) { + channel.write(buffer) + } } channel.close() file.close() From f051f834036e63d5e480d86440ce39924f979e82 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 1 Feb 2018 10:36:31 -0800 Subject: [PATCH 0268/2461] [SPARK-13983][SQL] Fix HiveThriftServer2 can not get "--hiveconf" and ''--hivevar" variables since 2.0 ## What changes were proposed in this pull request? `--hiveconf` and `--hivevar` variables no longer work since Spark 2.0. The `spark-sql` client has fixed by [SPARK-15730](https://issues.apache.org/jira/browse/SPARK-15730) and [SPARK-18086](https://issues.apache.org/jira/browse/SPARK-18086). but `beeline`/[`Spark SQL HiveThriftServer2`](https://github.com/apache/spark/blob/v2.1.1/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala) is still broken. This pull request fix it. This pull request works for both `JDBC client` and `beeline`. ## How was this patch tested? unit tests for `JDBC client` manual tests for `beeline`: ``` git checkout origin/pr/17886 dev/make-distribution.sh --mvn mvn --tgz -Phive -Phive-thriftserver -Phadoop-2.6 -DskipTests tar -zxf spark-2.3.0-SNAPSHOT-bin-2.6.5.tgz && cd spark-2.3.0-SNAPSHOT-bin-2.6.5 sbin/start-thriftserver.sh ``` ``` cat < test.sql select '\${a}', '\${b}'; EOF beeline -u jdbc:hive2://localhost:10000 --hiveconf a=avalue --hivevar b=bvalue -f test.sql ``` Author: Yuming Wang Closes #17886 from wangyum/SPARK-13983-dev. --- .../service/cli/session/HiveSessionImpl.java | 74 ++++++++++++++++++- .../server/SparkSQLOperationManager.scala | 12 +++ .../HiveThriftServer2Suites.scala | 23 +++++- 3 files changed, 105 insertions(+), 4 deletions(-) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java index 108074cce3d6d..fc818bc69c761 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java @@ -44,7 +44,7 @@ import org.apache.hadoop.hive.ql.history.HiveHistory; import org.apache.hadoop.hive.ql.metadata.Hive; import org.apache.hadoop.hive.ql.metadata.HiveException; -import org.apache.hadoop.hive.ql.processors.SetProcessor; +import org.apache.hadoop.hive.ql.parse.VariableSubstitution; import org.apache.hadoop.hive.ql.session.SessionState; import org.apache.hadoop.hive.shims.ShimLoader; import org.apache.hive.common.util.HiveVersionInfo; @@ -71,6 +71,12 @@ import org.apache.hive.service.cli.thrift.TProtocolVersion; import org.apache.hive.service.server.ThreadWithGarbageCleanup; +import static org.apache.hadoop.hive.conf.SystemVariables.ENV_PREFIX; +import static org.apache.hadoop.hive.conf.SystemVariables.HIVECONF_PREFIX; +import static org.apache.hadoop.hive.conf.SystemVariables.HIVEVAR_PREFIX; +import static org.apache.hadoop.hive.conf.SystemVariables.METACONF_PREFIX; +import static org.apache.hadoop.hive.conf.SystemVariables.SYSTEM_PREFIX; + /** * HiveSession * @@ -209,7 +215,7 @@ private void configureSession(Map sessionConfMap) throws HiveSQL String key = entry.getKey(); if (key.startsWith("set:")) { try { - SetProcessor.setVariable(key.substring(4), entry.getValue()); + setVariable(key.substring(4), entry.getValue()); } catch (Exception e) { throw new HiveSQLException(e); } @@ -221,6 +227,70 @@ private void configureSession(Map sessionConfMap) throws HiveSQL } } + // Copy from org.apache.hadoop.hive.ql.processors.SetProcessor, only change: + // setConf(varname, propName, varvalue, true) when varname.startsWith(HIVECONF_PREFIX) + public static int setVariable(String varname, String varvalue) throws Exception { + SessionState ss = SessionState.get(); + if (varvalue.contains("\n")){ + ss.err.println("Warning: Value had a \\n character in it."); + } + varname = varname.trim(); + if (varname.startsWith(ENV_PREFIX)){ + ss.err.println("env:* variables can not be set."); + return 1; + } else if (varname.startsWith(SYSTEM_PREFIX)){ + String propName = varname.substring(SYSTEM_PREFIX.length()); + System.getProperties().setProperty(propName, + new VariableSubstitution().substitute(ss.getConf(),varvalue)); + } else if (varname.startsWith(HIVECONF_PREFIX)){ + String propName = varname.substring(HIVECONF_PREFIX.length()); + setConf(varname, propName, varvalue, true); + } else if (varname.startsWith(HIVEVAR_PREFIX)) { + String propName = varname.substring(HIVEVAR_PREFIX.length()); + ss.getHiveVariables().put(propName, + new VariableSubstitution().substitute(ss.getConf(),varvalue)); + } else if (varname.startsWith(METACONF_PREFIX)) { + String propName = varname.substring(METACONF_PREFIX.length()); + Hive hive = Hive.get(ss.getConf()); + hive.setMetaConf(propName, new VariableSubstitution().substitute(ss.getConf(), varvalue)); + } else { + setConf(varname, varname, varvalue, true); + } + return 0; + } + + // returns non-null string for validation fail + private static void setConf(String varname, String key, String varvalue, boolean register) + throws IllegalArgumentException { + HiveConf conf = SessionState.get().getConf(); + String value = new VariableSubstitution().substitute(conf, varvalue); + if (conf.getBoolVar(HiveConf.ConfVars.HIVECONFVALIDATION)) { + HiveConf.ConfVars confVars = HiveConf.getConfVars(key); + if (confVars != null) { + if (!confVars.isType(value)) { + StringBuilder message = new StringBuilder(); + message.append("'SET ").append(varname).append('=').append(varvalue); + message.append("' FAILED because ").append(key).append(" expects "); + message.append(confVars.typeString()).append(" type value."); + throw new IllegalArgumentException(message.toString()); + } + String fail = confVars.validate(value); + if (fail != null) { + StringBuilder message = new StringBuilder(); + message.append("'SET ").append(varname).append('=').append(varvalue); + message.append("' FAILED in validation : ").append(fail).append('.'); + throw new IllegalArgumentException(message.toString()); + } + } else if (key.startsWith("hive.")) { + throw new IllegalArgumentException("hive configuration " + key + " does not exists."); + } + } + conf.verifyAndSet(key, value); + if (register) { + SessionState.get().getOverriddenConfigurations().put(key, value); + } + } + @Override public void setOperationLogSessionDir(File operationLogRootDir) { if (!operationLogRootDir.exists()) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index a0e5012633f5e..bf7c01f60fb5c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -28,6 +28,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.SQLContext import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation} +import org.apache.spark.sql.internal.SQLConf /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. @@ -50,6 +51,9 @@ private[thriftserver] class SparkSQLOperationManager() require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" + s" initialized or had already closed.") val conf = sqlContext.sessionState.conf + val hiveSessionState = parentSession.getSessionState + setConfMap(conf, hiveSessionState.getOverriddenConfigurations) + setConfMap(conf, hiveSessionState.getHiveVariables) val runInBackground = async && conf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC) val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay, runInBackground)(sqlContext, sessionToActivePool) @@ -58,4 +62,12 @@ private[thriftserver] class SparkSQLOperationManager() s"runInBackground=$runInBackground") operation } + + def setConfMap(conf: SQLConf, confMap: java.util.Map[String, String]): Unit = { + val iterator = confMap.entrySet().iterator() + while (iterator.hasNext) { + val kv = iterator.next() + conf.setConfString(kv.getKey, kv.getValue) + } + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 7289da71a3365..496f8c82a6c61 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -135,6 +135,22 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } } + test("Support beeline --hiveconf and --hivevar") { + withJdbcStatement() { statement => + executeTest(hiveConfList) + executeTest(hiveVarList) + def executeTest(hiveList: String): Unit = { + hiveList.split(";").foreach{ m => + val kv = m.split("=") + // select "${a}"; ---> avalue + val resultSet = statement.executeQuery("select \"${" + kv(0) + "}\"") + resultSet.next() + assert(resultSet.getString(1) === kv(1)) + } + } + } + } + test("JDBC query execution") { withJdbcStatement("test") { statement => val queries = Seq( @@ -740,10 +756,11 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { s"""jdbc:hive2://localhost:$serverPort/ |default? |hive.server2.transport.mode=http; - |hive.server2.thrift.http.path=cliservice + |hive.server2.thrift.http.path=cliservice; + |${hiveConfList}#${hiveVarList} """.stripMargin.split("\n").mkString.trim } else { - s"jdbc:hive2://localhost:$serverPort/" + s"jdbc:hive2://localhost:$serverPort/?${hiveConfList}#${hiveVarList}" } def withMultipleConnectionJdbcStatement(tableNames: String*)(fs: (Statement => Unit)*) { @@ -779,6 +796,8 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl private var listeningPort: Int = _ protected def serverPort: Int = listeningPort + protected val hiveConfList = "a=avalue;b=bvalue" + protected val hiveVarList = "c=cvalue;d=dvalue" protected def user = System.getProperty("user.name") protected var warehousePath: File = _ From 73da3b6968630d9e2cafc742ccb6d4eb54957df4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 1 Feb 2018 10:48:34 -0800 Subject: [PATCH 0269/2461] [SPARK-23293][SQL] fix data source v2 self join ## What changes were proposed in this pull request? `DataSourceV2Relation` should extend `MultiInstanceRelation`, to take care of self-join. ## How was this patch tested? a new test Author: Wenchen Fan Closes #20466 from cloud-fan/dsv2-selfjoin. --- .../execution/datasources/v2/DataSourceV2Relation.scala | 8 +++++++- .../apache/spark/sql/sources/v2/DataSourceV2Suite.scala | 6 ++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 3d4c64981373d..eebfa29f91b99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.execution.datasources.v2 +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( fullOutput: Seq[AttributeReference], - reader: DataSourceReader) extends LeafNode with DataSourceReaderHolder { + reader: DataSourceReader) + extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] @@ -33,6 +35,10 @@ case class DataSourceV2Relation( case _ => Statistics(sizeInBytes = conf.defaultSizeInBytes) } + + override def newInstance(): DataSourceV2Relation = { + copy(fullOutput = fullOutput.map(_.newInstance())) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 1c3ba7826f7de..23147fffe8a08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -217,6 +217,12 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-23293: data source v2 self join") { + val df = spark.read.format(classOf[SimpleDataSourceV2].getName).load() + val df2 = df.select(($"i" + 1).as("k"), $"j") + checkAnswer(df.join(df2, "j"), (0 until 10).map(i => Row(-i, i, i + 1))) + } } class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { From 4bcfdefb9f6d5ba88335953683a1dabbee83e9ea Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 1 Feb 2018 14:56:40 -0800 Subject: [PATCH 0270/2461] [INFRA] Close stale PRs. Closes #20334 Closes #20262 From 032c11b83f0d276bf8085992229b8c598f02798a Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Thu, 1 Feb 2018 15:26:59 -0800 Subject: [PATCH 0271/2461] [SPARK-23296][YARN] Include stacktrace in YARN-app diagnostic ## What changes were proposed in this pull request? Include stacktrace in the diagnostics message upon abnormal unregister from RM ## How was this patch tested? Tested with a failing job, and confirmed a stacktrace in the client output and YARN webUI. Author: Gera Shegalov Closes #20470 from gerashegalov/gera/stacktrace-diagnostics. --- .../scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 4d5e3bb043671..2f88feb0f1fdf 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -30,6 +30,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration @@ -718,7 +719,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends logError("User class threw exception: " + cause, cause) finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_EXCEPTION_USER_CLASS, - "User class threw exception: " + cause) + "User class threw exception: " + StringUtils.stringifyException(cause)) } sparkContextPromise.tryFailure(e.getCause()) } finally { From 90848d507457d30abb36e3ba07618dfc87c34cd6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 2 Feb 2018 10:18:32 +0800 Subject: [PATCH 0272/2461] [SPARK-23284][SQL] Document the behavior of several ColumnVector's get APIs when accessing null slot ## What changes were proposed in this pull request? For some ColumnVector get APIs such as getDecimal, getBinary, getStruct, getArray, getInterval, getUTF8String, we should clearly document their behaviors when accessing null slot. They should return null in this case. Then we can remove null checks from the places using above APIs. For the APIs of primitive values like getInt, getInts, etc., this also documents their behaviors when accessing null slots. Their returning values are undefined and can be anything. ## How was this patch tested? Added tests into `ColumnarBatchSuite`. Author: Liang-Chi Hsieh Closes #20455 from viirya/SPARK-23272-followup. --- .../datasources/orc/OrcColumnVector.java | 3 + .../vectorized/MutableColumnarRow.java | 7 -- .../vectorized/WritableColumnVector.java | 5 ++ .../sql/vectorized/ArrowColumnVector.java | 4 + .../spark/sql/vectorized/ColumnVector.java | 63 ++++++++++------ .../spark/sql/vectorized/ColumnarRow.java | 7 -- .../vectorized/ColumnarBatchSuite.scala | 74 ++++++++++++++++++- 7 files changed, 124 insertions(+), 39 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index c8add4c9f486c..12f4d658b1868 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -154,12 +154,14 @@ public double getDouble(int rowId) { @Override public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; BigDecimal data = decimalData.vector[getRowIndex(rowId)].getHiveDecimal().bigDecimalValue(); return Decimal.apply(data, precision, scale); } @Override public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) return null; int index = getRowIndex(rowId); BytesColumnVector col = bytesData; return UTF8String.fromBytes(col.vector[index], col.start[index], col.length[index]); @@ -167,6 +169,7 @@ public UTF8String getUTF8String(int rowId) { @Override public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) return null; int index = getRowIndex(rowId); byte[] binary = new byte[bytesData.length[index]]; System.arraycopy(bytesData.vector[index], bytesData.start[index], binary, 0, binary.length); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 307c19032dee5..4e4242fe8d9b9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -127,43 +127,36 @@ public boolean anyNull() { @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getUTF8String(rowId); } @Override public byte[] getBinary(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getBinary(rowId); } @Override public CalendarInterval getInterval(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getInterval(rowId); } @Override public ColumnarRow getStruct(int ordinal, int numFields) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getStruct(rowId); } @Override public ColumnarArray getArray(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getArray(rowId); } @Override public ColumnarMap getMap(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getMap(rowId); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 9d447cdc79063..5275e4a91eac0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -341,6 +341,7 @@ public final int putByteArray(int rowId, byte[] value) { @Override public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; if (precision <= Decimal.MAX_INT_DIGITS()) { return Decimal.createUnsafe(getInt(rowId), precision, scale); } else if (precision <= Decimal.MAX_LONG_DIGITS()) { @@ -367,6 +368,7 @@ public void putDecimal(int rowId, Decimal value, int precision) { @Override public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) return null; if (dictionary == null) { return arrayData().getBytesAsUTF8String(getArrayOffset(rowId), getArrayLength(rowId)); } else { @@ -384,6 +386,7 @@ public UTF8String getUTF8String(int rowId) { @Override public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) return null; if (dictionary == null) { return arrayData().getBytes(getArrayOffset(rowId), getArrayLength(rowId)); } else { @@ -613,6 +616,7 @@ public final int appendStruct(boolean isNull) { // array offsets and lengths in the current column vector. @Override public final ColumnarArray getArray(int rowId) { + if (isNullAt(rowId)) return null; return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId)); } @@ -620,6 +624,7 @@ public final ColumnarArray getArray(int rowId) { // second child column vector, and puts the offsets and lengths in the current column vector. @Override public final ColumnarMap getMap(int rowId) { + if (isNullAt(rowId)) return null; return new ColumnarMap(getChild(0), getChild(1), getArrayOffset(rowId), getArrayLength(rowId)); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index f3ece538c3b80..f8e37e995a17f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -101,21 +101,25 @@ public double getDouble(int rowId) { @Override public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; return accessor.getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) return null; return accessor.getUTF8String(rowId); } @Override public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) return null; return accessor.getBinary(rowId); } @Override public ColumnarArray getArray(int rowId) { + if (isNullAt(rowId)) return null; return accessor.getArray(rowId); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index 530d4d23d4eaf..ad99b450a4809 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -80,12 +80,14 @@ public abstract class ColumnVector implements AutoCloseable { public abstract boolean isNullAt(int rowId); /** - * Returns the boolean type value for rowId. + * Returns the boolean type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. */ public abstract boolean getBoolean(int rowId); /** - * Gets boolean type values from [rowId, rowId + count) + * Gets boolean type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. */ public boolean[] getBooleans(int rowId, int count) { boolean[] res = new boolean[count]; @@ -96,12 +98,14 @@ public boolean[] getBooleans(int rowId, int count) { } /** - * Returns the byte type value for rowId. + * Returns the byte type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. */ public abstract byte getByte(int rowId); /** - * Gets byte type values from [rowId, rowId + count) + * Gets byte type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. */ public byte[] getBytes(int rowId, int count) { byte[] res = new byte[count]; @@ -112,12 +116,14 @@ public byte[] getBytes(int rowId, int count) { } /** - * Returns the short type value for rowId. + * Returns the short type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. */ public abstract short getShort(int rowId); /** - * Gets short type values from [rowId, rowId + count) + * Gets short type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. */ public short[] getShorts(int rowId, int count) { short[] res = new short[count]; @@ -128,12 +134,14 @@ public short[] getShorts(int rowId, int count) { } /** - * Returns the int type value for rowId. + * Returns the int type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. */ public abstract int getInt(int rowId); /** - * Gets int type values from [rowId, rowId + count) + * Gets int type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. */ public int[] getInts(int rowId, int count) { int[] res = new int[count]; @@ -144,12 +152,14 @@ public int[] getInts(int rowId, int count) { } /** - * Returns the long type value for rowId. + * Returns the long type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. */ public abstract long getLong(int rowId); /** - * Gets long type values from [rowId, rowId + count) + * Gets long type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. */ public long[] getLongs(int rowId, int count) { long[] res = new long[count]; @@ -160,12 +170,14 @@ public long[] getLongs(int rowId, int count) { } /** - * Returns the float type value for rowId. + * Returns the float type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. */ public abstract float getFloat(int rowId); /** - * Gets float type values from [rowId, rowId + count) + * Gets float type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. */ public float[] getFloats(int rowId, int count) { float[] res = new float[count]; @@ -176,12 +188,14 @@ public float[] getFloats(int rowId, int count) { } /** - * Returns the double type value for rowId. + * Returns the double type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. */ public abstract double getDouble(int rowId); /** - * Gets double type values from [rowId, rowId + count) + * Gets double type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. */ public double[] getDoubles(int rowId, int count) { double[] res = new double[count]; @@ -192,7 +206,7 @@ public double[] getDoubles(int rowId, int count) { } /** - * Returns the struct type value for rowId. + * Returns the struct type value for rowId. If the slot for rowId is null, it should return null. * * To support struct type, implementations must implement {@link #getChild(int)} and make this * vector a tree structure. The number of child vectors must be same as the number of fields of @@ -205,7 +219,7 @@ public final ColumnarRow getStruct(int rowId) { } /** - * Returns the array type value for rowId. + * Returns the array type value for rowId. If the slot for rowId is null, it should return null. * * To support array type, implementations must construct an {@link ColumnarArray} and return it in * this method. {@link ColumnarArray} requires a {@link ColumnVector} that stores the data of all @@ -218,13 +232,13 @@ public final ColumnarRow getStruct(int rowId) { public abstract ColumnarArray getArray(int rowId); /** - * Returns the map type value for rowId. + * Returns the map type value for rowId. If the slot for rowId is null, it should return null. * * In Spark, map type value is basically a key data array and a value data array. A key from the * key array with a index and a value from the value array with the same index contribute to * an entry of this map type value. * - * To support map type, implementations must construct an {@link ColumnarMap} and return it in + * To support map type, implementations must construct a {@link ColumnarMap} and return it in * this method. {@link ColumnarMap} requires a {@link ColumnVector} that stores the data of all * the keys of all the maps in this vector, and another {@link ColumnVector} that stores the data * of all the values of all the maps in this vector, and a pair of offset and length which @@ -233,24 +247,25 @@ public final ColumnarRow getStruct(int rowId) { public abstract ColumnarMap getMap(int ordinal); /** - * Returns the decimal type value for rowId. + * Returns the decimal type value for rowId. If the slot for rowId is null, it should return null. */ public abstract Decimal getDecimal(int rowId, int precision, int scale); /** - * Returns the string type value for rowId. Note that the returned UTF8String may point to the - * data of this column vector, please copy it if you want to keep it after this column vector is - * freed. + * Returns the string type value for rowId. If the slot for rowId is null, it should return null. + * Note that the returned UTF8String may point to the data of this column vector, please copy it + * if you want to keep it after this column vector is freed. */ public abstract UTF8String getUTF8String(int rowId); /** - * Returns the binary type value for rowId. + * Returns the binary type value for rowId. If the slot for rowId is null, it should return null. */ public abstract byte[] getBinary(int rowId); /** - * Returns the calendar interval type value for rowId. + * Returns the calendar interval type value for rowId. If the slot for rowId is null, it should + * return null. * * In Spark, calendar interval type value is basically an integer value representing the number of * months in this interval, and a long value representing the number of microseconds in this diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index b400f7f93c1fe..f2f2279590023 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -119,43 +119,36 @@ public boolean anyNull() { @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - if (data.getChild(ordinal).isNullAt(rowId)) return null; return data.getChild(ordinal).getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int ordinal) { - if (data.getChild(ordinal).isNullAt(rowId)) return null; return data.getChild(ordinal).getUTF8String(rowId); } @Override public byte[] getBinary(int ordinal) { - if (data.getChild(ordinal).isNullAt(rowId)) return null; return data.getChild(ordinal).getBinary(rowId); } @Override public CalendarInterval getInterval(int ordinal) { - if (data.getChild(ordinal).isNullAt(rowId)) return null; return data.getChild(ordinal).getInterval(rowId); } @Override public ColumnarRow getStruct(int ordinal, int numFields) { - if (data.getChild(ordinal).isNullAt(rowId)) return null; return data.getChild(ordinal).getStruct(rowId); } @Override public ColumnarArray getArray(int ordinal) { - if (data.getChild(ordinal).isNullAt(rowId)) return null; return data.getChild(ordinal).getArray(rowId); } @Override public ColumnarMap getMap(int ordinal) { - if (data.getChild(ordinal).isNullAt(rowId)) return null; return data.getChild(ordinal).getMap(rowId); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 8fe2985836f2e..772f687526008 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -572,7 +572,7 @@ class ColumnarBatchSuite extends SparkFunSuite { } } - testVector("String APIs", 6, StringType) { + testVector("String APIs", 7, StringType) { column => val reference = mutable.ArrayBuffer.empty[String] @@ -619,6 +619,10 @@ class ColumnarBatchSuite extends SparkFunSuite { idx += 1 assert(column.arrayData().elementsAppended == 17 + (s + s).length) + column.putNull(idx) + assert(column.getUTF8String(idx) == null) + idx += 1 + reference.zipWithIndex.foreach { v => val errMsg = "VectorType=" + column.getClass.getSimpleName assert(v._1.length == column.getArrayLength(v._2), errMsg) @@ -647,6 +651,7 @@ class ColumnarBatchSuite extends SparkFunSuite { reference += new CalendarInterval(0, 2000) column.putNull(2) + assert(column.getInterval(2) == null) reference += null months.putInt(3, 20) @@ -683,6 +688,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.getArray(0).numElements == 1) assert(column.getArray(1).numElements == 2) assert(column.isNullAt(2)) + assert(column.getArray(2) == null) assert(column.getArray(3).numElements == 0) assert(column.getArray(4).numElements == 3) @@ -785,6 +791,7 @@ class ColumnarBatchSuite extends SparkFunSuite { column.putArray(0, 0, 1) column.putArray(1, 1, 2) column.putNull(2) + assert(column.getMap(2) == null) column.putArray(3, 3, 0) column.putArray(4, 3, 3) @@ -821,6 +828,7 @@ class ColumnarBatchSuite extends SparkFunSuite { c2.putDouble(0, 3.45) column.putNull(1) + assert(column.getStruct(1) == null) c1.putInt(2, 456) c2.putDouble(2, 5.67) @@ -1261,4 +1269,68 @@ class ColumnarBatchSuite extends SparkFunSuite { batch.close() allocator.close() } + + testVector("Decimal API", 4, DecimalType.IntDecimal) { + column => + + val reference = mutable.ArrayBuffer.empty[Decimal] + + var idx = 0 + column.putDecimal(idx, new Decimal().set(10), 10) + reference += new Decimal().set(10) + idx += 1 + + column.putDecimal(idx, new Decimal().set(20), 10) + reference += new Decimal().set(20) + idx += 1 + + column.putNull(idx) + assert(column.getDecimal(idx, 10, 0) == null) + reference += null + idx += 1 + + column.putDecimal(idx, new Decimal().set(30), 10) + reference += new Decimal().set(30) + + reference.zipWithIndex.foreach { case (v, i) => + val errMsg = "VectorType=" + column.getClass.getSimpleName + assert(v == column.getDecimal(i, 10, 0), errMsg) + if (v == null) assert(column.isNullAt(i), errMsg) + } + + column.close() + } + + testVector("Binary APIs", 4, BinaryType) { + column => + + val reference = mutable.ArrayBuffer.empty[String] + var idx = 0 + column.putByteArray(idx, "Hello".getBytes(StandardCharsets.UTF_8)) + reference += "Hello" + idx += 1 + + column.putByteArray(idx, "World".getBytes(StandardCharsets.UTF_8)) + reference += "World" + idx += 1 + + column.putNull(idx) + reference += null + idx += 1 + + column.putByteArray(idx, "abc".getBytes(StandardCharsets.UTF_8)) + reference += "abc" + + reference.zipWithIndex.foreach { case (v, i) => + val errMsg = "VectorType=" + column.getClass.getSimpleName + if (v != null) { + assert(v == new String(column.getBinary(i)), errMsg) + } else { + assert(column.isNullAt(i), errMsg) + assert(column.getBinary(i) == null, errMsg) + } + } + + column.close() + } } From 969eda4a02faa7ca6cf3aff5cd10e6d51026b845 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 2 Feb 2018 11:43:22 +0800 Subject: [PATCH 0273/2461] [SPARK-23020][CORE] Fix another race in the in-process launcher test. First the bad news: there's an unfixable race in the launcher code. (By unfixable I mean it would take a lot more effort than this change to fix it.) The good news is that it should only affect super short lived applications, such as the one run by the flaky test, so it's possible to work around it in our test. The fix also uncovered an issue with the recently added "closeAndWait()" method; closing the connection would still possibly cause data loss, so this change waits a while for the connection to finish itself, and closes the socket if that times out. The existing connection timeout is reused so that if desired it's possible to control how long to wait. As part of that I also restored the old behavior that disconnect() would force a disconnection from the child app; the "wait for data to arrive" approach is only taken when disposing of the handle. I tested this by inserting a bunch of sleeps in the test and the socket handling code in the launcher library; with those I was able to reproduce the error from the jenkins jobs. With the changes, even with all the sleeps still in place, all tests pass. Author: Marcelo Vanzin Closes #20462 from vanzin/SPARK-23020. --- .../spark/launcher/SparkLauncherSuite.java | 40 ++++++++++++++--- .../spark/launcher/AbstractAppHandle.java | 45 ++++++++++++------- .../spark/launcher/ChildProcAppHandle.java | 2 +- .../spark/launcher/InProcessAppHandle.java | 2 +- .../apache/spark/launcher/LauncherServer.java | 30 ++++++++----- .../spark/launcher/LauncherServerSuite.java | 2 +- 6 files changed, 87 insertions(+), 34 deletions(-) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 1543f4fdb0162..2225591a4ff75 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -157,12 +157,24 @@ private void inProcessLauncherTestImpl() throws Exception { SparkAppHandle handle = null; try { - handle = new InProcessLauncher() - .setMaster("local") - .setAppResource(SparkLauncher.NO_RESOURCE) - .setMainClass(InProcessTestApp.class.getName()) - .addAppArgs("hello") - .startApplication(listener); + synchronized (InProcessTestApp.LOCK) { + handle = new InProcessLauncher() + .setMaster("local") + .setAppResource(SparkLauncher.NO_RESOURCE) + .setMainClass(InProcessTestApp.class.getName()) + .addAppArgs("hello") + .startApplication(listener); + + // SPARK-23020: see doc for InProcessTestApp.LOCK for a description of the race. Here + // we wait until we know that the connection between the app and the launcher has been + // established before allowing the app to finish. + final SparkAppHandle _handle = handle; + eventually(Duration.ofSeconds(5), Duration.ofMillis(10), () -> { + assertNotEquals(SparkAppHandle.State.UNKNOWN, _handle.getState()); + }); + + InProcessTestApp.LOCK.wait(5000); + } waitFor(handle); assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); @@ -193,10 +205,26 @@ public static void main(String[] args) throws Exception { public static class InProcessTestApp { + /** + * SPARK-23020: there's a race caused by a child app finishing too quickly. This would cause + * the InProcessAppHandle to dispose of itself even before the child connection was properly + * established, so no state changes would be detected for the application and its final + * state would be LOST. + * + * It's not really possible to fix that race safely in the handle code itself without changing + * the way in-process apps talk to the launcher library, so we work around that in the test by + * synchronizing on this object. + */ + public static final Object LOCK = new Object(); + public static void main(String[] args) throws Exception { assertNotEquals(0, args.length); assertEquals(args[0], "hello"); new SparkContext().stop(); + + synchronized (LOCK) { + LOCK.notifyAll(); + } } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java index 84a25a5254151..9cbebdaeb33d3 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java @@ -18,22 +18,22 @@ package org.apache.spark.launcher; import java.io.IOException; -import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Level; import java.util.logging.Logger; abstract class AbstractAppHandle implements SparkAppHandle { - private static final Logger LOG = Logger.getLogger(ChildProcAppHandle.class.getName()); + private static final Logger LOG = Logger.getLogger(AbstractAppHandle.class.getName()); private final LauncherServer server; private LauncherServer.ServerConnection connection; private List listeners; private AtomicReference state; - private String appId; + private volatile String appId; private volatile boolean disposed; protected AbstractAppHandle(LauncherServer server) { @@ -44,7 +44,7 @@ protected AbstractAppHandle(LauncherServer server) { @Override public synchronized void addListener(Listener l) { if (listeners == null) { - listeners = new ArrayList<>(); + listeners = new CopyOnWriteArrayList<>(); } listeners.add(l); } @@ -71,16 +71,14 @@ public void stop() { @Override public synchronized void disconnect() { - if (!isDisposed()) { - if (connection != null) { - try { - connection.closeAndWait(); - } catch (IOException ioe) { - // no-op. - } + if (connection != null && connection.isOpen()) { + try { + connection.close(); + } catch (IOException ioe) { + // no-op. } - dispose(); } + dispose(); } void setConnection(LauncherServer.ServerConnection connection) { @@ -97,10 +95,25 @@ boolean isDisposed() { /** * Mark the handle as disposed, and set it as LOST in case the current state is not final. + * + * This method should be called only when there's a reasonable expectation that the communication + * with the child application is not needed anymore, either because the code managing the handle + * has said so, or because the child application is finished. */ synchronized void dispose() { if (!isDisposed()) { + // First wait for all data from the connection to be read. Then unregister the handle. + // Otherwise, unregistering might cause the server to be stopped and all child connections + // to be closed. + if (connection != null) { + try { + connection.waitForClose(); + } catch (IOException ioe) { + // no-op. + } + } server.unregister(this); + // Set state to LOST if not yet final. setState(State.LOST, false); this.disposed = true; @@ -127,11 +140,13 @@ void setState(State s, boolean force) { current = state.get(); } - LOG.log(Level.WARNING, "Backend requested transition from final state {0} to {1}.", - new Object[] { current, s }); + if (s != State.LOST) { + LOG.log(Level.WARNING, "Backend requested transition from final state {0} to {1}.", + new Object[] { current, s }); + } } - synchronized void setAppId(String appId) { + void setAppId(String appId) { this.appId = appId; fireEvent(true); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index 5e3c95676ecbe..5609f8492f4f4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -112,7 +112,7 @@ void monitorChild() { } } - disconnect(); + dispose(); } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index b8030e0063a37..4b740d3fad20e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -66,7 +66,7 @@ synchronized void start(String appName, Method main, String[] args) { setState(State.FAILED); } - disconnect(); + dispose(); }); app.setName(String.format(THREAD_NAME_FMT, THREAD_IDS.incrementAndGet(), appName)); diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index f4ecd52fdeab8..607879fd02ea9 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -238,6 +238,7 @@ public void run() { }; ServerConnection clientConnection = new ServerConnection(client, timeout); Thread clientThread = factory.newThread(clientConnection); + clientConnection.setConnectionThread(clientThread); synchronized (clients) { clients.add(clientConnection); } @@ -290,17 +291,15 @@ class ServerConnection extends LauncherConnection { private TimerTask timeout; private volatile Thread connectionThread; - volatile AbstractAppHandle handle; + private volatile AbstractAppHandle handle; ServerConnection(Socket socket, TimerTask timeout) throws IOException { super(socket); this.timeout = timeout; } - @Override - public void run() { - this.connectionThread = Thread.currentThread(); - super.run(); + void setConnectionThread(Thread t) { + this.connectionThread = t; } @Override @@ -361,19 +360,30 @@ public void close() throws IOException { } /** - * Close the connection and wait for any buffered data to be processed before returning. + * Wait for the remote side to close the connection so that any pending data is processed. * This ensures any changes reported by the child application take effect. + * + * This method allows a short period for the above to happen (same amount of time as the + * connection timeout, which is configurable). This should be fine for well-behaved + * applications, where they close the connection arond the same time the app handle detects the + * app has finished. + * + * In case the connection is not closed within the grace period, this method forcefully closes + * it and any subsequent data that may arrive will be ignored. */ - public void closeAndWait() throws IOException { - close(); - + public void waitForClose() throws IOException { Thread connThread = this.connectionThread; if (Thread.currentThread() != connThread) { try { - connThread.join(); + connThread.join(getConnectionTimeout()); } catch (InterruptedException ie) { // Ignore. } + + if (connThread.isAlive()) { + LOG.log(Level.WARNING, "Timed out waiting for child connection to close."); + close(); + } } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 024efac33c391..d16337a319be3 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -94,8 +94,8 @@ public void infoChanged(SparkAppHandle handle) { Message stopMsg = client.inbound.poll(30, TimeUnit.SECONDS); assertTrue(stopMsg instanceof Stop); } finally { - handle.kill(); close(client); + handle.kill(); client.clientThread.join(); } } From b3a04283f490020c13b6750de021af734c449c3a Mon Sep 17 00:00:00 2001 From: Zhan Zhang Date: Fri, 2 Feb 2018 12:21:06 +0800 Subject: [PATCH 0274/2461] [SPARK-23306] Fix the oom caused by contention ## What changes were proposed in this pull request? here is race condition in TaskMemoryManger, which may cause OOM. The memory released may be taken by another task because there is a gap between releaseMemory and acquireMemory, e.g., UnifiedMemoryManager, causing the OOM. if the current is the only one that can perform spill. It can happen to BytesToBytesMap, as it only spill required bytes. Loop on current consumer if it still has memory to release. ## How was this patch tested? The race contention is hard to reproduce, but the current logic seems causing the issue. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Zhan Zhang Closes #20480 from zhzhan/oom. --- .../org/apache/spark/memory/TaskMemoryManager.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 632d718062212..d07faf1da1248 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -172,10 +172,7 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { currentEntry = sortedConsumers.lastEntry(); } List cList = currentEntry.getValue(); - MemoryConsumer c = cList.remove(cList.size() - 1); - if (cList.isEmpty()) { - sortedConsumers.remove(currentEntry.getKey()); - } + MemoryConsumer c = cList.get(cList.size() - 1); try { long released = c.spill(required - got, consumer); if (released > 0) { @@ -185,6 +182,11 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { if (got >= required) { break; } + } else { + cList.remove(cList.size() - 1); + if (cList.isEmpty()) { + sortedConsumers.remove(currentEntry.getKey()); + } } } catch (ClosedByInterruptException e) { // This called by user to kill a task (e.g: speculative task). From 19c7c7ebdef6c1c7a02ebac9af6a24f521b52c37 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 1 Feb 2018 20:44:46 -0800 Subject: [PATCH 0275/2461] [SPARK-23301][SQL] data source column pruning should work for arbitrary expressions ## What changes were proposed in this pull request? This PR fixes a mistake in the `PushDownOperatorsToDataSource` rule, the column pruning logic is incorrect about `Project`. ## How was this patch tested? a new test case for column pruning with arbitrary expressions, and improve the existing tests to make sure the `PushDownOperatorsToDataSource` really works. Author: Wenchen Fan Closes #20476 from cloud-fan/push-down. --- .../v2/PushDownOperatorsToDataSource.scala | 53 ++++---- .../sources/v2/JavaAdvancedDataSourceV2.java | 29 ++++- .../sql/sources/v2/DataSourceV2Suite.scala | 113 ++++++++++++++++-- 3 files changed, 155 insertions(+), 40 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index df034adf1e7d6..566a48394f02e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, Expression, NamedExpression, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, AttributeSet, Expression, NamedExpression, PredicateHelper} import org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule @@ -81,35 +81,34 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel // TODO: add more push down rules. - // TODO: nested fields pruning - def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: Seq[Attribute]): Unit = { - plan match { - case Project(projectList, child) => - val required = projectList.filter(requiredByParent.contains).flatMap(_.references) - pushDownRequiredColumns(child, required) - - case Filter(condition, child) => - val required = requiredByParent ++ condition.references - pushDownRequiredColumns(child, required) - - case DataSourceV2Relation(fullOutput, reader) => reader match { - case r: SupportsPushDownRequiredColumns => - // Match original case of attributes. - val attrMap = AttributeMap(fullOutput.zip(fullOutput)) - val requiredColumns = requiredByParent.map(attrMap) - r.pruneColumns(requiredColumns.toStructType) - case _ => - } + pushDownRequiredColumns(filterPushed, filterPushed.outputSet) + // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them. + RemoveRedundantProject(filterPushed) + } + + // TODO: nested fields pruning + private def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: AttributeSet): Unit = { + plan match { + case Project(projectList, child) => + val required = projectList.flatMap(_.references) + pushDownRequiredColumns(child, AttributeSet(required)) + + case Filter(condition, child) => + val required = requiredByParent ++ condition.references + pushDownRequiredColumns(child, required) - // TODO: there may be more operators can be used to calculate required columns, we can add - // more and more in the future. - case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.output)) + case relation: DataSourceV2Relation => relation.reader match { + case reader: SupportsPushDownRequiredColumns => + val requiredColumns = relation.output.filter(requiredByParent.contains) + reader.pruneColumns(requiredColumns.toStructType) + + case _ => } - } - pushDownRequiredColumns(filterPushed, filterPushed.output) - // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them. - RemoveRedundantProject(filterPushed) + // TODO: there may be more operators that can be used to calculate the required columns. We + // can add more and more in the future. + case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.outputSet)) + } } /** diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index d421f7d19563f..172e5d5eebcbe 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -32,11 +32,12 @@ public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceReader, SupportsPushDownRequiredColumns, + public class Reader implements DataSourceReader, SupportsPushDownRequiredColumns, SupportsPushDownFilters { - private StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); - private Filter[] filters = new Filter[0]; + // Exposed for testing. + public StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); + public Filter[] filters = new Filter[0]; @Override public StructType readSchema() { @@ -50,8 +51,26 @@ public void pruneColumns(StructType requiredSchema) { @Override public Filter[] pushFilters(Filter[] filters) { - this.filters = filters; - return new Filter[0]; + Filter[] supported = Arrays.stream(filters).filter(f -> { + if (f instanceof GreaterThan) { + GreaterThan gt = (GreaterThan) f; + return gt.attribute().equals("i") && gt.value() instanceof Integer; + } else { + return false; + } + }).toArray(Filter[]::new); + + Filter[] unsupported = Arrays.stream(filters).filter(f -> { + if (f instanceof GreaterThan) { + GreaterThan gt = (GreaterThan) f; + return !gt.attribute().equals("i") || !(gt.value() instanceof Integer); + } else { + return true; + } + }).toArray(Filter[]::new); + + this.filters = supported; + return unsupported; } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 23147fffe8a08..eccd45442a3b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -21,11 +21,13 @@ import java.util.{ArrayList, List => JList} import test.org.apache.spark.sql.sources.v2._ -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.SparkException +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector +import org.apache.spark.sql.functions._ import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning} @@ -48,14 +50,72 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("advanced implementation") { + def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader] + }.head + } + + def getJavaReader(query: DataFrame): JavaAdvancedDataSourceV2#Reader = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => d.reader.asInstanceOf[JavaAdvancedDataSourceV2#Reader] + }.head + } + Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() checkAnswer(df, (0 until 10).map(i => Row(i, -i))) - checkAnswer(df.select('j), (0 until 10).map(i => Row(-i))) - checkAnswer(df.filter('i > 3), (4 until 10).map(i => Row(i, -i))) - checkAnswer(df.select('j).filter('i > 6), (7 until 10).map(i => Row(-i))) - checkAnswer(df.select('i).filter('i > 10), Nil) + + val q1 = df.select('j) + checkAnswer(q1, (0 until 10).map(i => Row(-i))) + if (cls == classOf[AdvancedDataSourceV2]) { + val reader = getReader(q1) + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) + } else { + val reader = getJavaReader(q1) + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) + } + + val q2 = df.filter('i > 3) + checkAnswer(q2, (4 until 10).map(i => Row(i, -i))) + if (cls == classOf[AdvancedDataSourceV2]) { + val reader = getReader(q2) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i", "j")) + } else { + val reader = getJavaReader(q2) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i", "j")) + } + + val q3 = df.select('i).filter('i > 6) + checkAnswer(q3, (7 until 10).map(i => Row(i))) + if (cls == classOf[AdvancedDataSourceV2]) { + val reader = getReader(q3) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i")) + } else { + val reader = getJavaReader(q3) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i")) + } + + val q4 = df.select('j).filter('j < -10) + checkAnswer(q4, Nil) + if (cls == classOf[AdvancedDataSourceV2]) { + val reader = getReader(q4) + // 'j < 10 is not supported by the testing data source. + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) + } else { + val reader = getJavaReader(q4) + // 'j < 10 is not supported by the testing data source. + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) + } } } } @@ -223,6 +283,39 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val df2 = df.select(($"i" + 1).as("k"), $"j") checkAnswer(df.join(df2, "j"), (0 until 10).map(i => Row(-i, i, i + 1))) } + + test("SPARK-23301: column pruning with arbitrary expressions") { + def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader] + }.head + } + + val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() + + val q1 = df.select('i + 1) + checkAnswer(q1, (1 until 11).map(i => Row(i))) + val reader1 = getReader(q1) + assert(reader1.requiredSchema.fieldNames === Seq("i")) + + val q2 = df.select(lit(1)) + checkAnswer(q2, (0 until 10).map(i => Row(1))) + val reader2 = getReader(q2) + assert(reader2.requiredSchema.isEmpty) + + // 'j === 1 can't be pushed down, but we should still be able do column pruning + val q3 = df.filter('j === -1).select('j * 2) + checkAnswer(q3, Row(-2)) + val reader3 = getReader(q3) + assert(reader3.filters.isEmpty) + assert(reader3.requiredSchema.fieldNames === Seq("j")) + + // column pruning should work with other operators. + val q4 = df.sort('i).limit(1).select('i + 1) + checkAnswer(q4, Row(1)) + val reader4 = getReader(q4) + assert(reader4.requiredSchema.fieldNames === Seq("i")) + } } class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { @@ -270,8 +363,12 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { } override def pushFilters(filters: Array[Filter]): Array[Filter] = { - this.filters = filters - Array.empty + val (supported, unsupported) = filters.partition { + case GreaterThan("i", _: Int) => true + case _ => false + } + this.filters = supported + unsupported } override def pushedFilters(): Array[Filter] = filters From b9503fcbb3f4a3ce263164d1f11a8e99b9ca5710 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 2 Feb 2018 22:43:28 +0800 Subject: [PATCH 0276/2461] [SPARK-23312][SQL] add a config to turn off vectorized cache reader ## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-23309 reported a performance regression about cached table in Spark 2.3. While the investigating is still going on, this PR adds a conf to turn off the vectorized cache reader, to unblock the 2.3 release. ## How was this patch tested? a new test Author: Wenchen Fan Closes #20483 from cloud-fan/cache. --- .../org/apache/spark/sql/internal/SQLConf.scala | 8 ++++++++ .../columnar/InMemoryTableScanExec.scala | 2 +- .../org/apache/spark/sql/CachedTableSuite.scala | 15 +++++++++++++-- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 90654e67457e0..1e2501ee7757d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -141,6 +141,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val CACHE_VECTORIZED_READER_ENABLED = + buildConf("spark.sql.inMemoryColumnarStorage.enableVectorizedReader") + .doc("Enables vectorized reader for columnar caching.") + .booleanConf + .createWithDefault(true) + val COLUMN_VECTOR_OFFHEAP_ENABLED = buildConf("spark.sql.columnVector.offheap.enabled") .internal() @@ -1272,6 +1278,8 @@ class SQLConf extends Serializable with Logging { def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) + def cacheVectorizedReaderEnabled: Boolean = getConf(CACHE_VECTORIZED_READER_ENABLED) + def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) def targetPostShuffleInputSize: Long = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index c167f1e7dc621..e972f8b30d87c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -54,7 +54,7 @@ case class InMemoryTableScanExec( override val supportsBatch: Boolean = { // In the initial implementation, for ease of review // support only primitive data types and # of fields is less than wholeStageMaxNumFields - relation.schema.fields.forall(f => f.dataType match { + conf.cacheVectorizedReaderEnabled && relation.schema.fields.forall(f => f.dataType match { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true case _ => false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 72fe0f42801f1..9f27fa09127af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -21,8 +21,6 @@ import scala.collection.mutable.HashSet import scala.concurrent.duration._ import scala.language.postfixOps -import org.scalatest.concurrent.Eventually._ - import org.apache.spark.CleanerListener import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.SubqueryExpression @@ -30,6 +28,7 @@ import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.{AccumulatorContext, Utils} @@ -782,4 +781,16 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext assert(getNumInMemoryRelations(cachedDs2) == 1) } } + + test("SPARK-23312: vectorized cache reader can be disabled") { + Seq(true, false).foreach { vectorized => + withSQLConf(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { + val df = spark.range(10).cache() + df.queryExecution.executedPlan.foreach { + case i: InMemoryTableScanExec => assert(i.supportsBatch == vectorized) + case _ => + } + } + } + } } From dd52681bf542386711609cb037a55b3d264eddef Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 2 Feb 2018 09:10:50 -0600 Subject: [PATCH 0277/2461] [SPARK-23253][CORE][SHUFFLE] Only write shuffle temporary index file when there is not an existing one ## What changes were proposed in this pull request? Shuffle Index temporay file is used for atomic creating shuffle index file, it is not needed when the index file already exists after another attempts of same task had it done. ## How was this patch tested? exitsting ut cc squito Author: Kent Yao Closes #20422 from yaooqinn/SPARK-23253. --- .../shuffle/IndexShuffleBlockResolver.scala | 27 ++++----- .../sort/IndexShuffleBlockResolverSuite.scala | 59 ++++++++++++++----- 2 files changed, 56 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 266ee42e39cca..c5f3f6e2b42b6 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -141,19 +141,6 @@ private[spark] class IndexShuffleBlockResolver( val indexFile = getIndexFile(shuffleId, mapId) val indexTmp = Utils.tempFileWith(indexFile) try { - val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) - Utils.tryWithSafeFinally { - // We take in lengths of each block, need to convert it to offsets. - var offset = 0L - out.writeLong(offset) - for (length <- lengths) { - offset += length - out.writeLong(offset) - } - } { - out.close() - } - val dataFile = getDataFile(shuffleId, mapId) // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure // the following check and rename are atomic. @@ -166,10 +153,22 @@ private[spark] class IndexShuffleBlockResolver( if (dataTmp != null && dataTmp.exists()) { dataTmp.delete() } - indexTmp.delete() } else { // This is the first successful attempt in writing the map outputs for this task, // so override any existing index and data files with the ones we wrote. + val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) + Utils.tryWithSafeFinally { + // We take in lengths of each block, need to convert it to offsets. + var offset = 0L + out.writeLong(offset) + for (length <- lengths) { + offset += length + out.writeLong(offset) + } + } { + out.close() + } + if (indexFile.exists()) { indexFile.delete() } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala index d21ce73f4021e..4ce379b76b551 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle.sort -import java.io.{File, FileInputStream, FileOutputStream} +import java.io.{DataInputStream, File, FileInputStream, FileOutputStream} import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS @@ -64,6 +64,9 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } test("commit shuffle files multiple times") { + val shuffleId = 1 + val mapId = 2 + val idxName = s"shuffle_${shuffleId}_${mapId}_0.index" val resolver = new IndexShuffleBlockResolver(conf, blockManager) val lengths = Array[Long](10, 0, 20) val dataTmp = File.createTempFile("shuffle", null, tempDir) @@ -73,9 +76,13 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } { out.close() } - resolver.writeIndexFileAndCommit(1, 2, lengths, dataTmp) + resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp) - val dataFile = resolver.getDataFile(1, 2) + val indexFile = new File(tempDir.getAbsolutePath, idxName) + val dataFile = resolver.getDataFile(shuffleId, mapId) + + assert(indexFile.exists()) + assert(indexFile.length() === (lengths.length + 1) * 8) assert(dataFile.exists()) assert(dataFile.length() === 30) assert(!dataTmp.exists()) @@ -89,7 +96,9 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } { out2.close() } - resolver.writeIndexFileAndCommit(1, 2, lengths2, dataTmp2) + resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths2, dataTmp2) + + assert(indexFile.length() === (lengths.length + 1) * 8) assert(lengths2.toSeq === lengths.toSeq) assert(dataFile.exists()) assert(dataFile.length() === 30) @@ -97,18 +106,27 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa // The dataFile should be the previous one val firstByte = new Array[Byte](1) - val in = new FileInputStream(dataFile) + val dataIn = new FileInputStream(dataFile) Utils.tryWithSafeFinally { - in.read(firstByte) + dataIn.read(firstByte) } { - in.close() + dataIn.close() } assert(firstByte(0) === 0) + // The index file should not change + val indexIn = new DataInputStream(new FileInputStream(indexFile)) + Utils.tryWithSafeFinally { + indexIn.readLong() // the first offset is always 0 + assert(indexIn.readLong() === 10, "The index file should not change") + } { + indexIn.close() + } + // remove data file dataFile.delete() - val lengths3 = Array[Long](10, 10, 15) + val lengths3 = Array[Long](7, 10, 15, 3) val dataTmp3 = File.createTempFile("shuffle", null, tempDir) val out3 = new FileOutputStream(dataTmp3) Utils.tryWithSafeFinally { @@ -117,20 +135,29 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } { out3.close() } - resolver.writeIndexFileAndCommit(1, 2, lengths3, dataTmp3) + resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths3, dataTmp3) + assert(indexFile.length() === (lengths3.length + 1) * 8) assert(lengths3.toSeq != lengths.toSeq) assert(dataFile.exists()) assert(dataFile.length() === 35) - assert(!dataTmp2.exists()) + assert(!dataTmp3.exists()) - // The dataFile should be the previous one - val firstByte2 = new Array[Byte](1) - val in2 = new FileInputStream(dataFile) + // The dataFile should be the new one, since we deleted the dataFile from the first attempt + val dataIn2 = new FileInputStream(dataFile) + Utils.tryWithSafeFinally { + dataIn2.read(firstByte) + } { + dataIn2.close() + } + assert(firstByte(0) === 2) + + // The index file should be updated, since we deleted the dataFile from the first attempt + val indexIn2 = new DataInputStream(new FileInputStream(indexFile)) Utils.tryWithSafeFinally { - in2.read(firstByte2) + indexIn2.readLong() // the first offset is always 0 + assert(indexIn2.readLong() === 7, "The index file should be updated") } { - in2.close() + indexIn2.close() } - assert(firstByte2(0) === 2) } } From eefec93d193d43d5b71b8f8a4b1060286da971dd Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 2 Feb 2018 10:17:51 -0600 Subject: [PATCH 0278/2461] [SPARK-23295][BUILD][MINOR] Exclude Waring message when generating versions in make-distribution.sh ## What changes were proposed in this pull request? When we specified a wrong profile to make a spark distribution, such as `-Phadoop1000`, we will get an odd package named like `spark-[WARNING] The requested profile "hadoop1000" could not be activated because it does not exist.-bin-hadoop-2.7.tgz`, which actually should be `"spark-$VERSION-bin-$NAME.tgz"` ## How was this patch tested? ### before ``` build/mvn help:evaluate -Dexpression=scala.binary.version -Phadoop1000 2>/dev/null | grep -v "INFO" | tail -n 1 [WARNING] The requested profile "hadoop1000" could not be activated because it does not exist. ``` ``` build/mvn help:evaluate -Dexpression=project.version -Phadoop1000 2>/dev/null | grep -v "INFO" | tail -n 1 [WARNING] The requested profile "hadoop1000" could not be activated because it does not exist. ``` ### after ``` build/mvn help:evaluate -Dexpression=project.version -Phadoop1000 2>/dev/null | grep -v "INFO" | grep -v "WARNING" | tail -n 1 2.4.0-SNAPSHOT ``` ``` build/mvn help:evaluate -Dexpression=scala.binary.version -Dscala.binary.version=2.11.1 2>/dev/null | grep -v "INFO" | grep -v "WARNING" | tail -n 1 2.11.1 ``` cloud-fan srowen Author: Kent Yao Closes #20469 from yaooqinn/dist-minor. --- dev/make-distribution.sh | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 7245163ea2a51..8b02446b2f15f 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -117,15 +117,21 @@ if [ ! "$(command -v "$MVN")" ] ; then exit -1; fi -VERSION=$("$MVN" help:evaluate -Dexpression=project.version $@ 2>/dev/null | grep -v "INFO" | tail -n 1) +VERSION=$("$MVN" help:evaluate -Dexpression=project.version $@ 2>/dev/null\ + | grep -v "INFO"\ + | grep -v "WARNING"\ + | tail -n 1) SCALA_VERSION=$("$MVN" help:evaluate -Dexpression=scala.binary.version $@ 2>/dev/null\ | grep -v "INFO"\ + | grep -v "WARNING"\ | tail -n 1) SPARK_HADOOP_VERSION=$("$MVN" help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ | grep -v "INFO"\ + | grep -v "WARNING"\ | tail -n 1) SPARK_HIVE=$("$MVN" help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\ | grep -v "INFO"\ + | grep -v "WARNING"\ | fgrep --count "hive";\ # Reset exit status to 0, otherwise the script stops here if the last grep finds nothing\ # because we use "set -o pipefail" From eaf35de2471fac4337dd2920026836d52b1ec847 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 2 Feb 2018 17:37:51 -0800 Subject: [PATCH 0279/2461] [SPARK-23064][SS][DOCS] Stream-stream joins Documentation - follow up ## What changes were proposed in this pull request? Further clarification of caveats in using stream-stream outer joins. ## How was this patch tested? N/A Author: Tathagata Das Closes #20494 from tdas/SPARK-23064-2. --- docs/structured-streaming-programming-guide.md | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 62589a62ac4c4..48d6d0b542cc0 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1346,10 +1346,20 @@ joined <- join( -However, note that the outer NULL results will be generated with a delay (depends on the specified -watermark delay and the time range condition) because the engine has to wait for that long to ensure + +There are a few points to note regarding outer joins. + +- *The outer NULL results will be generated with a delay that depends on the specified watermark +delay and the time range condition.* This is because the engine has to wait for that long to ensure there were no matches and there will be no more matches in future. +- In the current implementation in the micro-batch engine, watermarks are advanced at the end of a +micro-batch, and the next micro-batch uses the updated watermark to clean up state and output +outer results. Since we trigger a micro-batch only when there is new data to be processed, the +generation of the outer result may get delayed if there no new data being received in the stream. +*In short, if any of the two input streams being joined does not receive data for a while, the +outer (both cases, left or right) output may get delayed.* + ##### Support matrix for joins in streaming queries
{executor.map(_.isBlacklisted).getOrElse(false)}for applicationfor stagefalse
From 3ff83ad43a704cc3354ef9783e711c065e2a1a22 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 2 Feb 2018 20:36:27 -0800 Subject: [PATCH 0280/2461] [SQL] Minor doc update: Add an example in DataFrameReader.schema ## What changes were proposed in this pull request? This patch adds a small example to the schema string definition of schema function. It isn't obvious how to use it, so an example would be useful. ## How was this patch tested? N/A - doc only. Author: Reynold Xin Closes #20491 from rxin/schema-doc. --- .../src/main/scala/org/apache/spark/sql/DataFrameReader.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 46b5f54a33f74..fcaf8d618c168 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -74,6 +74,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * infer the input schema automatically from data. By specifying the schema here, the underlying * data source can skip the schema inference step, and thus speed up data loading. * + * {{{ + * spark.read.schema("a INT, b STRING, c DOUBLE").csv("test.csv") + * }}} + * * @since 2.3.0 */ def schema(schemaString: String): DataFrameReader = { From fe73cb4b439169f16cc24cd851a11fd398ce7edf Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 2 Feb 2018 20:49:08 -0800 Subject: [PATCH 0281/2461] [SPARK-23317][SQL] rename ContinuousReader.setOffset to setStartOffset ## What changes were proposed in this pull request? In the document of `ContinuousReader.setOffset`, we say this method is used to specify the start offset. We also have a `ContinuousReader.getStartOffset` to get the value back. I think it makes more sense to rename `ContinuousReader.setOffset` to `setStartOffset`. ## How was this patch tested? N/A Author: Wenchen Fan Closes #20486 from cloud-fan/rename. --- .../org/apache/spark/sql/kafka010/KafkaContinuousReader.scala | 2 +- .../sql/sources/v2/reader/streaming/ContinuousReader.java | 4 ++-- .../execution/streaming/continuous/ContinuousExecution.scala | 2 +- .../streaming/continuous/ContinuousRateStreamSource.scala | 2 +- .../spark/sql/execution/streaming/RateSourceV2Suite.scala | 2 +- .../sql/streaming/sources/StreamingDataSourceV2Suite.scala | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 41c443bc12120..b049a054cb40e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -71,7 +71,7 @@ class KafkaContinuousReader( override def readSchema: StructType = KafkaOffsetReader.kafkaSchema private var offset: Offset = _ - override def setOffset(start: ju.Optional[Offset]): Unit = { + override def setStartOffset(start: ju.Optional[Offset]): Unit = { offset = start.orElse { val offsets = initialOffsets match { case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java index d1d1e7ffd1dd4..7fe7f00ac2fa8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java @@ -51,12 +51,12 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceReader * start from the first record after the provided offset, or from an implementation-defined * inferred starting point if no offset is provided. */ - void setOffset(Optional start); + void setStartOffset(Optional start); /** * Return the specified or inferred start offset for this reader. * - * @throws IllegalStateException if setOffset has not been called + * @throws IllegalStateException if setStartOffset has not been called */ Offset getStartOffset(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 08c81419a9d34..ed22b9100497a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -181,7 +181,7 @@ class ContinuousExecution( val loggedOffset = offsets.offsets(0) val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) - reader.setOffset(java.util.Optional.ofNullable(realOffset.orNull)) + reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) new StreamingDataSourceV2Relation(newOutput, reader) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 0eaaa4889ba9e..b63d8d3e20650 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -61,7 +61,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) private var offset: Offset = _ - override def setOffset(offset: java.util.Optional[Offset]): Unit = { + override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index 3158995ec62f1..0d68d9c3138aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -160,7 +160,7 @@ class RateSourceV2Suite extends StreamTest { test("continuous data") { val reader = new RateStreamContinuousReader( new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - reader.setOffset(Optional.empty()) + reader.setStartOffset(Optional.empty()) val tasks = reader.createDataReaderFactories() assert(tasks.size == 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index cb873ab688e96..51f44fa6285e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -43,7 +43,7 @@ case class FakeReader() extends MicroBatchReader with ContinuousReader { def readSchema(): StructType = StructType(Seq()) def stop(): Unit = {} def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) - def setOffset(start: Optional[Offset]): Unit = {} + def setStartOffset(start: Optional[Offset]): Unit = {} def createDataReaderFactories(): java.util.ArrayList[DataReaderFactory[Row]] = { throw new IllegalStateException("fake source - cannot actually read") From 63b49fa2e599080c2ba7d5189f9dde20a2e01fb4 Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Sat, 3 Feb 2018 00:02:03 -0800 Subject: [PATCH 0282/2461] [SPARK-23311][SQL][TEST] add FilterFunction test case for test CombineTypedFilters ## What changes were proposed in this pull request? In the current test case for CombineTypedFilters, we lack the test of FilterFunction, so let's add it. In addition, in TypedFilterOptimizationSuite's existing test cases, Let's extract a common LocalRelation. ## How was this patch tested? add new test cases. Author: caoxuewen Closes #20482 from heary-cao/TypedFilterOptimizationSuite. --- .../spark/sql/catalyst/dsl/package.scala | 3 + .../TypedFilterOptimizationSuite.scala | 95 ++++++++++++++++--- 2 files changed, 84 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 59cb26d5e6c36..efb2eba655e15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import scala.language.implicitConversions +import org.apache.spark.api.java.function.FilterFunction import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ @@ -301,6 +302,8 @@ package object dsl { def filter[T : Encoder](func: T => Boolean): LogicalPlan = TypedFilter(func, logicalPlan) + def filter[T : Encoder](func: FilterFunction[T]): LogicalPlan = TypedFilter(func, logicalPlan) + def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan) def deserialize[T : Encoder]: LogicalPlan = CatalystSerde.deserialize[T](logicalPlan) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala index 56f096f3ecf8c..5fc99a3a57c0f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.reflect.runtime.universe.TypeTag +import org.apache.spark.api.java.function.FilterFunction import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -38,18 +39,19 @@ class TypedFilterOptimizationSuite extends PlanTest { implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]() + val testRelation = LocalRelation('_1.int, '_2.int) + test("filter after serialize with the same object type") { - val input = LocalRelation('_1.int, '_2.int) val f = (i: (Int, Int)) => i._1 > 0 - val query = input + val query = testRelation .deserialize[(Int, Int)] .serialize[(Int, Int)] .filter(f).analyze val optimized = Optimize.execute(query) - val expected = input + val expected = testRelation .deserialize[(Int, Int)] .where(callFunction(f, BooleanType, 'obj)) .serialize[(Int, Int)].analyze @@ -58,10 +60,9 @@ class TypedFilterOptimizationSuite extends PlanTest { } test("filter after serialize with different object types") { - val input = LocalRelation('_1.int, '_2.int) val f = (i: OtherTuple) => i._1 > 0 - val query = input + val query = testRelation .deserialize[(Int, Int)] .serialize[(Int, Int)] .filter(f).analyze @@ -70,17 +71,16 @@ class TypedFilterOptimizationSuite extends PlanTest { } test("filter before deserialize with the same object type") { - val input = LocalRelation('_1.int, '_2.int) val f = (i: (Int, Int)) => i._1 > 0 - val query = input + val query = testRelation .filter(f) .deserialize[(Int, Int)] .serialize[(Int, Int)].analyze val optimized = Optimize.execute(query) - val expected = input + val expected = testRelation .deserialize[(Int, Int)] .where(callFunction(f, BooleanType, 'obj)) .serialize[(Int, Int)].analyze @@ -89,10 +89,9 @@ class TypedFilterOptimizationSuite extends PlanTest { } test("filter before deserialize with different object types") { - val input = LocalRelation('_1.int, '_2.int) val f = (i: OtherTuple) => i._1 > 0 - val query = input + val query = testRelation .filter(f) .deserialize[(Int, Int)] .serialize[(Int, Int)].analyze @@ -101,21 +100,89 @@ class TypedFilterOptimizationSuite extends PlanTest { } test("back to back filter with the same object type") { - val input = LocalRelation('_1.int, '_2.int) val f1 = (i: (Int, Int)) => i._1 > 0 val f2 = (i: (Int, Int)) => i._2 > 0 - val query = input.filter(f1).filter(f2).analyze + val query = testRelation.filter(f1).filter(f2).analyze val optimized = Optimize.execute(query) assert(optimized.collect { case t: TypedFilter => t }.length == 1) } test("back to back filter with different object types") { - val input = LocalRelation('_1.int, '_2.int) val f1 = (i: (Int, Int)) => i._1 > 0 val f2 = (i: OtherTuple) => i._2 > 0 - val query = input.filter(f1).filter(f2).analyze + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 2) + } + + test("back to back FilterFunction with the same object type") { + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._1 > 0 + } + val f2 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._2 > 0 + } + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 1) + } + + test("back to back FilterFunction with different object types") { + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._1 > 0 + } + val f2 = new FilterFunction[OtherTuple] { + override def call(value: OtherTuple): Boolean = value._2 > 0 + } + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 2) + } + + test("FilterFunction and filter with the same object type") { + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._1 > 0 + } + val f2 = (i: (Int, Int)) => i._2 > 0 + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 1) + } + + test("FilterFunction and filter with different object types") { + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._1 > 0 + } + val f2 = (i: OtherTuple) => i._2 > 0 + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 2) + } + + test("filter and FilterFunction with the same object type") { + val f2 = (i: (Int, Int)) => i._1 > 0 + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._2 > 0 + } + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 1) + } + + test("filter and FilterFunction with different object types") { + val f2 = (i: (Int, Int)) => i._1 > 0 + val f1 = new FilterFunction[OtherTuple] { + override def call(value: OtherTuple): Boolean = value._2 > 0 + } + + val query = testRelation.filter(f1).filter(f2).analyze val optimized = Optimize.execute(query) assert(optimized.collect { case t: TypedFilter => t }.length == 2) } From 522e0b1866a0298669c83de5a47ba380dc0b7c84 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 3 Feb 2018 00:04:00 -0800 Subject: [PATCH 0283/2461] [SPARK-23305][SQL][TEST] Test `spark.sql.files.ignoreMissingFiles` for all file-based data sources ## What changes were proposed in this pull request? Like Parquet, all file-based data source handles `spark.sql.files.ignoreMissingFiles` correctly. We had better have a test coverage for feature parity and in order to prevent future accidental regression for all data sources. ## How was this patch tested? Pass Jenkins with a newly added test case. Author: Dongjoon Hyun Closes #20479 from dongjoon-hyun/SPARK-23305. --- .../spark/sql/FileBasedDataSourceSuite.scala | 37 +++++++++++++++++++ .../parquet/ParquetQuerySuite.scala | 33 ----------------- 2 files changed, 37 insertions(+), 33 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index c272c99ae45a8..640d6b1583663 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkException +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { @@ -92,4 +96,37 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { } } } + + allFileBasedDataSources.foreach { format => + testQuietly(s"Enabling/disabling ignoreMissingFiles using $format") { + def testIgnoreMissingFiles(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + Seq("0").toDF("a").write.format(format).save(new Path(basePath, "first").toString) + Seq("1").toDF("a").write.format(format).save(new Path(basePath, "second").toString) + val thirdPath = new Path(basePath, "third") + Seq("2").toDF("a").write.format(format).save(thirdPath.toString) + val df = spark.read.format(format).load( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + + val fs = thirdPath.getFileSystem(spark.sparkContext.hadoopConfiguration) + assert(fs.delete(thirdPath, true)) + checkAnswer(df, Seq(Row("0"), Row("1"))) + } + } + + withSQLConf(SQLConf.IGNORE_MISSING_FILES.key -> "true") { + testIgnoreMissingFiles() + } + + withSQLConf(SQLConf.IGNORE_MISSING_FILES.key -> "false") { + val exception = intercept[SparkException] { + testIgnoreMissingFiles() + } + assert(exception.getMessage().contains("does not exist")) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 6ad88ed997ce7..55b0f729be8ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -355,39 +355,6 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } - testQuietly("Enabling/disabling ignoreMissingFiles") { - def testIgnoreMissingFiles(): Unit = { - withTempDir { dir => - val basePath = dir.getCanonicalPath - spark.range(1).toDF("a").write.parquet(new Path(basePath, "first").toString) - spark.range(1, 2).toDF("a").write.parquet(new Path(basePath, "second").toString) - val thirdPath = new Path(basePath, "third") - spark.range(2, 3).toDF("a").write.parquet(thirdPath.toString) - val df = spark.read.parquet( - new Path(basePath, "first").toString, - new Path(basePath, "second").toString, - new Path(basePath, "third").toString) - - val fs = thirdPath.getFileSystem(spark.sparkContext.hadoopConfiguration) - fs.delete(thirdPath, true) - checkAnswer( - df, - Seq(Row(0), Row(1))) - } - } - - withSQLConf(SQLConf.IGNORE_MISSING_FILES.key -> "true") { - testIgnoreMissingFiles() - } - - withSQLConf(SQLConf.IGNORE_MISSING_FILES.key -> "false") { - val exception = intercept[SparkException] { - testIgnoreMissingFiles() - } - assert(exception.getMessage().contains("does not exist")) - } - } - /** * this is part of test 'Enabling/disabling ignoreCorruptFiles' but run in a loop * to increase the chance of failure From 4aaa7d40bf495317e740b6d6f9c2a55dfd03521b Mon Sep 17 00:00:00 2001 From: Shashwat Anand Date: Sat, 3 Feb 2018 10:31:04 -0800 Subject: [PATCH 0284/2461] [MINOR][DOC] Use raw triple double quotes around docstrings where there are occurrences of backslashes. From [PEP 257](https://www.python.org/dev/peps/pep-0257/): > For consistency, always use """triple double quotes""" around docstrings. Use r"""raw triple double quotes""" if you use any backslashes in your docstrings. For Unicode docstrings, use u"""Unicode triple-quoted strings""". For example, this is what help (kafka_wordcount) shows: ``` DESCRIPTION Counts words in UTF8 encoded, ' ' delimited text received from the network every second. Usage: kafka_wordcount.py To run this on your local machine, you need to setup Kafka and create a producer first, see http://kafka.apache.org/documentation.html#quickstart and then run the example `$ bin/spark-submit --jars external/kafka-assembly/target/scala-*/spark-streaming-kafka-assembly-*.jar examples/src/main/python/streaming/kafka_wordcount.py localhost:2181 test` ``` This is what it shows, after the fix: ``` DESCRIPTION Counts words in UTF8 encoded, '\n' delimited text received from the network every second. Usage: kafka_wordcount.py To run this on your local machine, you need to setup Kafka and create a producer first, see http://kafka.apache.org/documentation.html#quickstart and then run the example `$ bin/spark-submit --jars \ external/kafka-assembly/target/scala-*/spark-streaming-kafka-assembly-*.jar \ examples/src/main/python/streaming/kafka_wordcount.py \ localhost:2181 test` ``` The thing worth noticing is no linebreak here in the help. ## What changes were proposed in this pull request? Change triple double quotes to raw triple double quotes when there are occurrences of backslashes in docstrings. ## How was this patch tested? Manually as this is a doc fix. Author: Shashwat Anand Closes #20497 from ashashwat/docstring-fixes. --- .../main/python/sql/streaming/structured_network_wordcount.py | 2 +- .../sql/streaming/structured_network_wordcount_windowed.py | 2 +- examples/src/main/python/streaming/direct_kafka_wordcount.py | 2 +- examples/src/main/python/streaming/flume_wordcount.py | 2 +- examples/src/main/python/streaming/kafka_wordcount.py | 2 +- examples/src/main/python/streaming/network_wordcount.py | 2 +- .../src/main/python/streaming/network_wordjoinsentiments.py | 2 +- examples/src/main/python/streaming/sql_network_wordcount.py | 2 +- .../src/main/python/streaming/stateful_network_wordcount.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount.py b/examples/src/main/python/sql/streaming/structured_network_wordcount.py index afde2550587ca..c3284c1d01017 100644 --- a/examples/src/main/python/sql/streaming/structured_network_wordcount.py +++ b/examples/src/main/python/sql/streaming/structured_network_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network. Usage: structured_network_wordcount.py and describe the TCP server that Structured Streaming diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py index 02a7d3363d780..db672551504b5 100644 --- a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py +++ b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network over a sliding window of configurable duration. Each line from the network is tagged with a timestamp that is used to determine the windows into which it falls. diff --git a/examples/src/main/python/streaming/direct_kafka_wordcount.py b/examples/src/main/python/streaming/direct_kafka_wordcount.py index 7097f7f4502bd..425df309011a0 100644 --- a/examples/src/main/python/streaming/direct_kafka_wordcount.py +++ b/examples/src/main/python/streaming/direct_kafka_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text directly received from Kafka in every 2 seconds. Usage: direct_kafka_wordcount.py diff --git a/examples/src/main/python/streaming/flume_wordcount.py b/examples/src/main/python/streaming/flume_wordcount.py index d75bc6daac138..5d6e6dc36d6f9 100644 --- a/examples/src/main/python/streaming/flume_wordcount.py +++ b/examples/src/main/python/streaming/flume_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network every second. Usage: flume_wordcount.py diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py index 8d697f620f467..704f6602e2297 100644 --- a/examples/src/main/python/streaming/kafka_wordcount.py +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network every second. Usage: kafka_wordcount.py diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py index 2b48bcfd55db0..9010fafb425e6 100644 --- a/examples/src/main/python/streaming/network_wordcount.py +++ b/examples/src/main/python/streaming/network_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network every second. Usage: network_wordcount.py and describe the TCP server that Spark Streaming would connect to receive data. diff --git a/examples/src/main/python/streaming/network_wordjoinsentiments.py b/examples/src/main/python/streaming/network_wordjoinsentiments.py index b309d9fad33f5..d51a380a5d5f9 100644 --- a/examples/src/main/python/streaming/network_wordjoinsentiments.py +++ b/examples/src/main/python/streaming/network_wordjoinsentiments.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Shows the most positive words in UTF8 encoded, '\n' delimited text directly received the network every 5 seconds. The streaming data is joined with a static RDD of the AFINN word list (http://neuro.imm.dtu.dk/wiki/AFINN) diff --git a/examples/src/main/python/streaming/sql_network_wordcount.py b/examples/src/main/python/streaming/sql_network_wordcount.py index 398ac8d2d8f5e..7f12281c0e3fe 100644 --- a/examples/src/main/python/streaming/sql_network_wordcount.py +++ b/examples/src/main/python/streaming/sql_network_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Use DataFrames and SQL to count words in UTF8 encoded, '\n' delimited text received from the network every second. diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py index f8bbc659c2ea7..d7bb61e729f18 100644 --- a/examples/src/main/python/streaming/stateful_network_wordcount.py +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network every second. From 551dff2bccb65e9b3f77b986f167aec90d9a6016 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 3 Feb 2018 10:40:21 -0800 Subject: [PATCH 0285/2461] [SPARK-21658][SQL][PYSPARK] Revert "[] Add default None for value in na.replace in PySpark" This reverts commit 0fcde87aadc9a92e138f11583119465ca4b5c518. See the discussion in [SPARK-21658](https://issues.apache.org/jira/browse/SPARK-21658), [SPARK-19454](https://issues.apache.org/jira/browse/SPARK-19454) and https://github.com/apache/spark/pull/16793 Author: hyukjinkwon Closes #20496 from HyukjinKwon/revert-SPARK-21658. --- python/pyspark/sql/dataframe.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1496cba91b90e..2e55407b5397b 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1577,16 +1577,6 @@ def replace(self, to_replace, value=None, subset=None): |null| null|null| +----+------+----+ - >>> df4.na.replace('Alice').show() - +----+------+----+ - | age|height|name| - +----+------+----+ - | 10| 80|null| - | 5| null| Bob| - |null| null| Tom| - |null| null|null| - +----+------+----+ - >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show() +----+------+----+ | age|height|name| @@ -2055,7 +2045,7 @@ def fill(self, value, subset=None): fill.__doc__ = DataFrame.fillna.__doc__ - def replace(self, to_replace, value=None, subset=None): + def replace(self, to_replace, value, subset=None): return self.df.replace(to_replace, value, subset) replace.__doc__ = DataFrame.replace.__doc__ From 715047b02df0ac9ec16ab2a73481ab7f36ffc6ca Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 4 Feb 2018 17:53:31 +0900 Subject: [PATCH 0286/2461] [SPARK-23256][ML][PYTHON] Add columnSchema method to PySpark image reader ## What changes were proposed in this pull request? This PR proposes to add `columnSchema` in Python side too. ```python >>> from pyspark.ml.image import ImageSchema >>> ImageSchema.columnSchema.simpleString() 'struct' ``` ## How was this patch tested? Manually tested and unittest was added in `python/pyspark/ml/tests.py`. Author: hyukjinkwon Closes #20475 from HyukjinKwon/SPARK-23256. --- python/pyspark/ml/image.py | 20 +++++++++++++++++++- python/pyspark/ml/tests.py | 1 + 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 2d86c7f03860c..45c936645f2a8 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -40,6 +40,7 @@ class _ImageSchema(object): def __init__(self): self._imageSchema = None self._ocvTypes = None + self._columnSchema = None self._imageFields = None self._undefinedImageType = None @@ -49,7 +50,7 @@ def imageSchema(self): Returns the image schema. :return: a :class:`StructType` with a single column of images - named "image" (nullable). + named "image" (nullable) and having the same type returned by :meth:`columnSchema`. .. versionadded:: 2.3.0 """ @@ -75,6 +76,23 @@ def ocvTypes(self): self._ocvTypes = dict(ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes()) return self._ocvTypes + @property + def columnSchema(self): + """ + Returns the schema for the image column. + + :return: a :class:`StructType` for image column, + ``struct``. + + .. versionadded:: 2.4.0 + """ + + if self._columnSchema is None: + ctx = SparkContext._active_spark_context + jschema = ctx._jvm.org.apache.spark.ml.image.ImageSchema.columnSchema() + self._columnSchema = _parse_datatype_json_string(jschema.json()) + return self._columnSchema + @property def imageFields(self): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 1af2b91da900d..75d04785a0710 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1852,6 +1852,7 @@ def test_read_images(self): self.assertEqual(len(array), first_row[1]) self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row) self.assertEqual(df.schema, ImageSchema.imageSchema) + self.assertEqual(df.schema["image"].dataType, ImageSchema.columnSchema) expected = {'CV_8UC3': 16, 'Undefined': -1, 'CV_8U': 0, 'CV_8UC1': 0, 'CV_8UC4': 24} self.assertEqual(ImageSchema.ocvTypes, expected) expected = ['origin', 'height', 'width', 'nChannels', 'mode', 'data'] From 6fb3fd15365d43733aefdb396db205d7ccf57f75 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 4 Feb 2018 09:15:48 -0800 Subject: [PATCH 0287/2461] [SPARK-22036][SQL][FOLLOWUP] Fix decimalArithmeticOperations.sql ## What changes were proposed in this pull request? Fix decimalArithmeticOperations.sql test ## How was this patch tested? N/A Author: Yuming Wang Author: wangyum Author: Yuming Wang Closes #20498 from wangyum/SPARK-22036. --- .../native/decimalArithmeticOperations.sql | 6 +- .../decimalArithmeticOperations.sql.out | 140 ++++++++++-------- 2 files changed, 80 insertions(+), 66 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql index c6d8a49d4b93a..9be7fcdadfea8 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql @@ -48,8 +48,9 @@ select 12345678901234567890.0 * 12345678901234567890.0; select 1e35 / 0.1; -- arithmetic operations causing a precision loss are truncated +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345; select 123456789123456789.1234567890 * 1.123456789123456789; -select 0.001 / 9876543210987654321098765432109876543.2 +select 12345678912345.123456789123 / 0.000000012345678; -- return NULL instead of rounding, according to old Spark versions' behavior set spark.sql.decimalOperations.allowPrecisionLoss=false; @@ -74,7 +75,8 @@ select 12345678901234567890.0 * 12345678901234567890.0; select 1e35 / 0.1; -- arithmetic operations causing a precision loss return NULL +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345; select 123456789123456789.1234567890 * 1.123456789123456789; -select 0.001 / 9876543210987654321098765432109876543.2 +select 12345678912345.123456789123 / 0.000000012345678; drop table decimals_test; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out index 4d70fe19d539f..6bfdb84548d4d 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 32 +-- Number of queries: 36 -- !query 0 @@ -146,146 +146,158 @@ NULL -- !query 17 -select 123456789123456789.1234567890 * 1.123456789123456789 +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 -- !query 17 schema -struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> +struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,6)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,6))):decimal(38,6)> -- !query 17 output -138698367904130467.654320988515622621 +10012345678912345678912345678911.246907 -- !query 18 -select 0.001 / 9876543210987654321098765432109876543.2 - -set spark.sql.decimalOperations.allowPrecisionLoss=false +select 123456789123456789.1234567890 * 1.123456789123456789 -- !query 18 schema -struct<> +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> -- !query 18 output -org.apache.spark.sql.catalyst.parser.ParseException - -mismatched input 'spark' expecting (line 3, pos 4) - -== SQL == -select 0.001 / 9876543210987654321098765432109876543.2 - -set spark.sql.decimalOperations.allowPrecisionLoss=false -----^^^ +138698367904130467.654320988515622621 -- !query 19 -select id, a+b, a-b, a*b, a/b from decimals_test order by id +select 12345678912345.123456789123 / 0.000000012345678 -- !query 19 schema -struct +struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,9)> -- !query 19 output -1 1099 -899 99900 0.1001 -2 24690.246 0 152402061.885129 1 -3 1234.2234567891011 -1233.9765432108989 152.358023 0.0001 -4 123456789123456790.12345678912345679 123456789123456787.87654321087654321 138698367904130467.515623 109890109097814272.043109 +1000000073899961059796.725866332 -- !query 20 -select id, a*10, b/10 from decimals_test order by id +set spark.sql.decimalOperations.allowPrecisionLoss=false -- !query 20 schema -struct +struct -- !query 20 output -1 1000 99.9 -2 123451.23 1234.5123 -3 1.234567891011 123.41 -4 1234567891234567890 0.112345678912345679 +spark.sql.decimalOperations.allowPrecisionLoss false -- !query 21 -select 10.3 * 3.0 +select id, a+b, a-b, a*b, a/b from decimals_test order by id -- !query 21 schema -struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> +struct -- !query 21 output -30.9 +1 1099 -899 NULL 0.1001001001001001 +2 24690.246 0 NULL 1 +3 1234.2234567891011 -1233.9765432108989 NULL 0.000100037913541123 +4 123456789123456790.123456789123456789 123456789123456787.876543210876543211 NULL 109890109097814272.043109406191131436 -- !query 22 -select 10.3000 * 3.0 +select id, a*10, b/10 from decimals_test order by id -- !query 22 schema -struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> +struct -- !query 22 output -30.9 +1 1000 99.9 +2 123451.23 1234.5123 +3 1.234567891011 123.41 +4 1234567891234567890 0.1123456789123456789 -- !query 23 -select 10.30000 * 30.0 +select 10.3 * 3.0 -- !query 23 schema -struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> +struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> -- !query 23 output -309 +30.9 -- !query 24 -select 10.300000000000000000 * 3.000000000000000000 +select 10.3000 * 3.0 -- !query 24 schema -struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,34)> +struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> -- !query 24 output 30.9 -- !query 25 -select 10.300000000000000000 * 3.0000000000000000000 +select 10.30000 * 30.0 -- !query 25 schema -struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,34)> +struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> -- !query 25 output -30.9 +309 -- !query 26 -select (5e36 + 0.1) + 5e36 +select 10.300000000000000000 * 3.000000000000000000 -- !query 26 schema -struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,36)> -- !query 26 output -NULL +30.9 -- !query 27 -select (-4e36 - 0.1) - 7e36 +select 10.300000000000000000 * 3.0000000000000000000 -- !query 27 schema -struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,37)> -- !query 27 output NULL -- !query 28 -select 12345678901234567890.0 * 12345678901234567890.0 +select (5e36 + 0.1) + 5e36 -- !query 28 schema -struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> -- !query 28 output NULL -- !query 29 -select 1e35 / 0.1 +select (-4e36 - 0.1) - 7e36 -- !query 29 schema -struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> -- !query 29 output NULL -- !query 30 -select 123456789123456789.1234567890 * 1.123456789123456789 +select 12345678901234567890.0 * 12345678901234567890.0 -- !query 30 schema -struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> -- !query 30 output -138698367904130467.654320988515622621 +NULL -- !query 31 -select 0.001 / 9876543210987654321098765432109876543.2 - -drop table decimals_test +select 1e35 / 0.1 -- !query 31 schema -struct<> +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)> -- !query 31 output -org.apache.spark.sql.catalyst.parser.ParseException +NULL -mismatched input 'table' expecting (line 3, pos 5) -== SQL == -select 0.001 / 9876543210987654321098765432109876543.2 +-- !query 32 +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 +-- !query 32 schema +struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,7)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,7))):decimal(38,7)> +-- !query 32 output +NULL + + +-- !query 33 +select 123456789123456789.1234567890 * 1.123456789123456789 +-- !query 33 schema +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)> +-- !query 33 output +NULL + +-- !query 34 +select 12345678912345.123456789123 / 0.000000012345678 +-- !query 34 schema +struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,18)> +-- !query 34 output +NULL + + +-- !query 35 drop table decimals_test ------^^^ +-- !query 35 schema +struct<> +-- !query 35 output + From a6bf3db20773ba65cbc4f2775db7bd215e78829a Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 5 Feb 2018 18:41:49 +0800 Subject: [PATCH 0288/2461] [SPARK-23307][WEBUI] Sort jobs/stages/tasks/queries with the completed timestamp before cleaning up them ## What changes were proposed in this pull request? Sort jobs/stages/tasks/queries with the completed timestamp before cleaning up them to make the behavior consistent with 2.2. ## How was this patch tested? - Jenkins. - Manually ran the following codes and checked the UI for jobs/stages/tasks/queries. ``` spark.ui.retainedJobs 10 spark.ui.retainedStages 10 spark.sql.ui.retainedExecutions 10 spark.ui.retainedTasks 10 ``` ``` new Thread() { override def run() { spark.range(1, 2).foreach { i => Thread.sleep(10000) } } }.start() Thread.sleep(5000) for (_ <- 1 to 20) { new Thread() { override def run() { spark.range(1, 2).foreach { i => } } }.start() } Thread.sleep(15000) spark.range(1, 2).foreach { i => } sc.makeRDD(1 to 100, 100).foreach { i => } ``` Author: Shixiong Zhu Closes #20481 from zsxwing/SPARK-23307. --- .../spark/status/AppStatusListener.scala | 13 +-- .../org/apache/spark/status/storeTypes.scala | 7 ++ .../spark/status/AppStatusListenerSuite.scala | 90 +++++++++++++++++++ .../execution/ui/SQLAppStatusListener.scala | 4 +- .../sql/execution/ui/SQLAppStatusStore.scala | 9 +- .../ui/SQLAppStatusListenerSuite.scala | 45 ++++++++++ 6 files changed, 158 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 3e34bdc0c7b63..ab01cddfca5b0 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -875,8 +875,8 @@ private[spark] class AppStatusListener( return } - val toDelete = KVUtils.viewToSeq(kvstore.view(classOf[JobDataWrapper]), - countToDelete.toInt) { j => + val view = kvstore.view(classOf[JobDataWrapper]).index("completionTime").first(0L) + val toDelete = KVUtils.viewToSeq(view, countToDelete.toInt) { j => j.info.status != JobExecutionStatus.RUNNING && j.info.status != JobExecutionStatus.UNKNOWN } toDelete.foreach { j => kvstore.delete(j.getClass(), j.info.jobId) } @@ -888,8 +888,8 @@ private[spark] class AppStatusListener( return } - val stages = KVUtils.viewToSeq(kvstore.view(classOf[StageDataWrapper]), - countToDelete.toInt) { s => + val view = kvstore.view(classOf[StageDataWrapper]).index("completionTime").first(0L) + val stages = KVUtils.viewToSeq(view, countToDelete.toInt) { s => s.info.status != v1.StageStatus.ACTIVE && s.info.status != v1.StageStatus.PENDING } @@ -945,8 +945,9 @@ private[spark] class AppStatusListener( val countToDelete = calculateNumberToRemove(stage.savedTasks.get(), maxTasksPerStage).toInt if (countToDelete > 0) { val stageKey = Array(stage.info.stageId, stage.info.attemptNumber) - val view = kvstore.view(classOf[TaskDataWrapper]).index("stage").first(stageKey) - .last(stageKey) + val view = kvstore.view(classOf[TaskDataWrapper]) + .index(TaskIndexNames.COMPLETION_TIME) + .parent(stageKey) // Try to delete finished tasks only. val toDelete = KVUtils.viewToSeq(view, countToDelete) { t => diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index c9cb996a55fcc..412644d3657b5 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -73,6 +73,8 @@ private[spark] class JobDataWrapper( @JsonIgnore @KVIndex private def id: Int = info.jobId + @JsonIgnore @KVIndex("completionTime") + private def completionTime: Long = info.completionTime.map(_.getTime).getOrElse(-1L) } private[spark] class StageDataWrapper( @@ -90,6 +92,8 @@ private[spark] class StageDataWrapper( @JsonIgnore @KVIndex("active") private def active: Boolean = info.status == StageStatus.ACTIVE + @JsonIgnore @KVIndex("completionTime") + private def completionTime: Long = info.completionTime.map(_.getTime).getOrElse(-1L) } /** @@ -134,6 +138,7 @@ private[spark] object TaskIndexNames { final val STAGE = "stage" final val STATUS = "sta" final val TASK_INDEX = "idx" + final val COMPLETION_TIME = "ct" } /** @@ -337,6 +342,8 @@ private[spark] class TaskDataWrapper( @JsonIgnore @KVIndex(value = TaskIndexNames.ERROR, parent = TaskIndexNames.STAGE) private def error: String = if (errorMessage.isDefined) errorMessage.get else "" + @JsonIgnore @KVIndex(value = TaskIndexNames.COMPLETION_TIME, parent = TaskIndexNames.STAGE) + private def completionTime: Long = launchTime + duration } private[spark] class RDDStorageInfoWrapper(val info: RDDStorageInfo) { diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 042bba7f226fd..b74d6ee2ec836 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -1010,6 +1010,96 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } + test("eviction should respect job completion time") { + val testConf = conf.clone().set(MAX_RETAINED_JOBS, 2) + val listener = new AppStatusListener(store, testConf, true) + + // Start job 1 and job 2 + time += 1 + listener.onJobStart(SparkListenerJobStart(1, time, Nil, null)) + time += 1 + listener.onJobStart(SparkListenerJobStart(2, time, Nil, null)) + + // Stop job 2 before job 1 + time += 1 + listener.onJobEnd(SparkListenerJobEnd(2, time, JobSucceeded)) + time += 1 + listener.onJobEnd(SparkListenerJobEnd(1, time, JobSucceeded)) + + // Start job 3 and job 2 should be evicted. + time += 1 + listener.onJobStart(SparkListenerJobStart(3, time, Nil, null)) + assert(store.count(classOf[JobDataWrapper]) === 2) + intercept[NoSuchElementException] { + store.read(classOf[JobDataWrapper], 2) + } + } + + test("eviction should respect stage completion time") { + val testConf = conf.clone().set(MAX_RETAINED_STAGES, 2) + val listener = new AppStatusListener(store, testConf, true) + + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") + val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2") + val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3") + + // Start stage 1 and stage 2 + time += 1 + stage1.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties())) + time += 1 + stage2.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage2, new Properties())) + + // Stop stage 2 before stage 1 + time += 1 + stage2.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage2)) + time += 1 + stage1.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage1)) + + // Start stage 3 and stage 2 should be evicted. + stage3.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage3, new Properties())) + assert(store.count(classOf[StageDataWrapper]) === 2) + intercept[NoSuchElementException] { + store.read(classOf[StageDataWrapper], Array(2, 0)) + } + } + + test("eviction should respect task completion time") { + val testConf = conf.clone().set(MAX_RETAINED_TASKS_PER_STAGE, 2) + val listener = new AppStatusListener(store, testConf, true) + + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") + stage1.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties())) + + // Start task 1 and task 2 + val tasks = createTasks(3, Array("1")) + tasks.take(2).foreach { task => + listener.onTaskStart(SparkListenerTaskStart(stage1.stageId, stage1.attemptNumber, task)) + } + + // Stop task 2 before task 1 + time += 1 + tasks(1).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd( + SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(1), null)) + time += 1 + tasks(0).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd( + SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(0), null)) + + // Start task 3 and task 2 should be evicted. + listener.onTaskStart(SparkListenerTaskStart(stage1.stageId, stage1.attemptNumber, tasks(2))) + assert(store.count(classOf[TaskDataWrapper]) === 2) + intercept[NoSuchElementException] { + store.read(classOf[TaskDataWrapper], tasks(1).id) + } + } + test("driver logs") { val listener = new AppStatusListener(store, conf, true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 73a105266e1c1..53fb9a0cc21cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -332,8 +332,8 @@ class SQLAppStatusListener( return } - val toDelete = KVUtils.viewToSeq(kvstore.view(classOf[SQLExecutionUIData]), - countToDelete.toInt) { e => e.completionTime.isDefined } + val view = kvstore.view(classOf[SQLExecutionUIData]).index("completionTime").first(0L) + val toDelete = KVUtils.viewToSeq(view, countToDelete.toInt)(_.completionTime.isDefined) toDelete.foreach { e => kvstore.delete(e.getClass(), e.executionId) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala index 910f2e52fdbb3..9a76584717f42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala @@ -23,11 +23,12 @@ import java.util.Date import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import com.fasterxml.jackson.annotation.JsonIgnore import com.fasterxml.jackson.databind.annotation.JsonDeserialize import org.apache.spark.JobExecutionStatus import org.apache.spark.status.KVUtils.KVIndexParam -import org.apache.spark.util.kvstore.KVStore +import org.apache.spark.util.kvstore.{KVIndex, KVStore} /** * Provides a view of a KVStore with methods that make it easy to query SQL-specific state. There's @@ -90,7 +91,11 @@ class SQLExecutionUIData( * from the SQL listener instance. */ @JsonDeserialize(keyAs = classOf[JLong]) - val metricValues: Map[Long, String]) + val metricValues: Map[Long, String]) { + + @JsonIgnore @KVIndex("completionTime") + private def completionTimeIndex: Long = completionTime.map(_.getTime).getOrElse(-1L) +} class SparkPlanGraphWrapper( @KVIndexParam val executionId: Long, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index 7d84f45d36bee..85face3994fd4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, SparkPlanInfo, SQLExecution} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.internal.StaticSQLConf.UI_RETAINED_EXECUTIONS import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.status.ElementTrackingStore import org.apache.spark.status.config._ @@ -510,6 +511,50 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with } } + test("eviction should respect execution completion time") { + val conf = sparkContext.conf.clone().set(UI_RETAINED_EXECUTIONS.key, "2") + val store = new ElementTrackingStore(new InMemoryStore, conf) + val listener = new SQLAppStatusListener(conf, store, live = true) + val statusStore = new SQLAppStatusStore(store, Some(listener)) + + var time = 0 + val df = createTestDataFrame + // Start execution 1 and execution 2 + time += 1 + listener.onOtherEvent(SparkListenerSQLExecutionStart( + 1, + "test", + "test", + df.queryExecution.toString, + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + time)) + time += 1 + listener.onOtherEvent(SparkListenerSQLExecutionStart( + 2, + "test", + "test", + df.queryExecution.toString, + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + time)) + + // Stop execution 2 before execution 1 + time += 1 + listener.onOtherEvent(SparkListenerSQLExecutionEnd(2, time)) + time += 1 + listener.onOtherEvent(SparkListenerSQLExecutionEnd(1, time)) + + // Start execution 3 and execution 2 should be evicted. + time += 1 + listener.onOtherEvent(SparkListenerSQLExecutionStart( + 3, + "test", + "test", + df.queryExecution.toString, + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + time)) + assert(statusStore.executionsCount === 2) + assert(statusStore.execution(2) === None) + } } From 03b7e120dd7ff7848c936c7a23644da5bd7219ab Mon Sep 17 00:00:00 2001 From: Sital Kedia Date: Mon, 5 Feb 2018 10:19:18 -0800 Subject: [PATCH 0289/2461] [SPARK-23310][CORE] Turn off read ahead input stream for unshafe shuffle reader To fix regression for TPC-DS queries Author: Sital Kedia Closes #20492 from sitalkedia/turn_off_async_inputstream. --- .../util/collection/unsafe/sort/UnsafeSorterSpillReader.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index e2f48e5508af6..71e7c7a95ebdb 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -76,8 +76,10 @@ public UnsafeSorterSpillReader( SparkEnv.get() == null ? 0.5 : SparkEnv.get().conf().getDouble("spark.unsafe.sorter.spill.read.ahead.fraction", 0.5); + // SPARK-23310: Disable read-ahead input stream, because it is causing lock contention and perf regression for + // TPC-DS queries. final boolean readAheadEnabled = SparkEnv.get() != null && - SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", true); + SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", false); final InputStream bs = new NioBufferedFileInputStream(file, (int) bufferSizeBytes); From c2766b07b4b9ed976931966a79c65043e81cf694 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Mon, 5 Feb 2018 14:17:11 -0800 Subject: [PATCH 0290/2461] [SPARK-23330][WEBUI] Spark UI SQL executions page throws NPE ## What changes were proposed in this pull request? Spark SQL executions page throws the following error and the page crashes: ``` HTTP ERROR 500 Problem accessing /SQL/. Reason: Server Error Caused by: java.lang.NullPointerException at scala.collection.immutable.StringOps$.length$extension(StringOps.scala:47) at scala.collection.immutable.StringOps.length(StringOps.scala:47) at scala.collection.IndexedSeqOptimized$class.isEmpty(IndexedSeqOptimized.scala:27) at scala.collection.immutable.StringOps.isEmpty(StringOps.scala:29) at scala.collection.TraversableOnce$class.nonEmpty(TraversableOnce.scala:111) at scala.collection.immutable.StringOps.nonEmpty(StringOps.scala:29) at org.apache.spark.sql.execution.ui.ExecutionTable.descriptionCell(AllExecutionsPage.scala:182) at org.apache.spark.sql.execution.ui.ExecutionTable.row(AllExecutionsPage.scala:155) at org.apache.spark.sql.execution.ui.ExecutionTable$$anonfun$8.apply(AllExecutionsPage.scala:204) at org.apache.spark.sql.execution.ui.ExecutionTable$$anonfun$8.apply(AllExecutionsPage.scala:204) at org.apache.spark.ui.UIUtils$$anonfun$listingTable$2.apply(UIUtils.scala:339) at org.apache.spark.ui.UIUtils$$anonfun$listingTable$2.apply(UIUtils.scala:339) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48) at scala.collection.TraversableLike$class.map(TraversableLike.scala:234) at scala.collection.AbstractTraversable.map(Traversable.scala:104) at org.apache.spark.ui.UIUtils$.listingTable(UIUtils.scala:339) at org.apache.spark.sql.execution.ui.ExecutionTable.toNodeSeq(AllExecutionsPage.scala:203) at org.apache.spark.sql.execution.ui.AllExecutionsPage.render(AllExecutionsPage.scala:67) at org.apache.spark.ui.WebUI$$anonfun$2.apply(WebUI.scala:82) at org.apache.spark.ui.WebUI$$anonfun$2.apply(WebUI.scala:82) at org.apache.spark.ui.JettyUtils$$anon$3.doGet(JettyUtils.scala:90) at javax.servlet.http.HttpServlet.service(HttpServlet.java:687) at javax.servlet.http.HttpServlet.service(HttpServlet.java:790) at org.eclipse.jetty.servlet.ServletHolder.handle(ServletHolder.java:848) at org.eclipse.jetty.servlet.ServletHandler.doHandle(ServletHandler.java:584) at org.eclipse.jetty.server.handler.ContextHandler.doHandle(ContextHandler.java:1180) at org.eclipse.jetty.servlet.ServletHandler.doScope(ServletHandler.java:512) at org.eclipse.jetty.server.handler.ContextHandler.doScope(ContextHandler.java:1112) at org.eclipse.jetty.server.handler.ScopedHandler.handle(ScopedHandler.java:141) at org.eclipse.jetty.server.handler.ContextHandlerCollection.handle(ContextHandlerCollection.java:213) at org.eclipse.jetty.server.handler.HandlerWrapper.handle(HandlerWrapper.java:134) at org.eclipse.jetty.server.Server.handle(Server.java:534) at org.eclipse.jetty.server.HttpChannel.handle(HttpChannel.java:320) at org.eclipse.jetty.server.HttpConnection.onFillable(HttpConnection.java:251) at org.eclipse.jetty.io.AbstractConnection$ReadCallback.succeeded(AbstractConnection.java:283) at org.eclipse.jetty.io.FillInterest.fillable(FillInterest.java:108) at org.eclipse.jetty.io.SelectChannelEndPoint$2.run(SelectChannelEndPoint.java:93) at org.eclipse.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:303) at org.eclipse.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148) at org.eclipse.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136) at org.eclipse.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671) at org.eclipse.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589) at java.lang.Thread.run(Thread.java:748) ``` One of the possible reason that this page fails may be the `SparkListenerSQLExecutionStart` event get dropped before processed, so the execution description and details don't get updated. This was not a issue in 2.2 because it would ignore any job start event that arrives before the corresponding execution start event, which doesn't sound like a good decision. We shall try to handle the null values in the front page side, that is, try to give a default value when `execution.details` or `execution.description` is null. Another possible approach is not to spill the `LiveExecutionData` in `SQLAppStatusListener.update(exec: LiveExecutionData)` if `exec.details` is null. This is not ideal because this way you will not see the execution if `SparkListenerSQLExecutionStart` event is lost, because `AllExecutionsPage` only read executions from KVStore. ## How was this patch tested? After the change, the page shows the following: ![image](https://user-images.githubusercontent.com/4784782/35775480-28cc5fde-093e-11e8-8ccc-f58c2ef4a514.png) Author: Xingbo Jiang Closes #20502 from jiangxb1987/executionPage. --- .../apache/spark/sql/execution/ui/AllExecutionsPage.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index 7019d98e1619f..e751ce39cd5d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -179,7 +179,7 @@ private[ui] abstract class ExecutionTable( } private def descriptionCell(execution: SQLExecutionUIData): Seq[Node] = { - val details = if (execution.details.nonEmpty) { + val details = if (execution.details != null && execution.details.nonEmpty) { +details ++ @@ -190,8 +190,10 @@ private[ui] abstract class ExecutionTable( Nil } - val desc = { + val desc = if (execution.description != null && execution.description.nonEmpty) { {execution.description} + } else { + {execution.executionId} }
{desc} {details}
From f3f1e14bb73dfdd2927d95b12d7d61d22de8a0ac Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 6 Feb 2018 14:42:42 +0800 Subject: [PATCH 0291/2461] [SPARK-23326][WEBUI] schedulerDelay should return 0 when the task is running ## What changes were proposed in this pull request? When a task is still running, metrics like executorRunTime are not available. Then `schedulerDelay` will be almost the same as `duration` and that's confusing. This PR makes `schedulerDelay` return 0 when the task is running which is the same behavior as 2.2. ## How was this patch tested? `AppStatusUtilsSuite.schedulerDelay` Author: Shixiong Zhu Closes #20493 from zsxwing/SPARK-23326. --- .../apache/spark/status/AppStatusUtils.scala | 11 ++- .../spark/status/AppStatusUtilsSuite.scala | 89 +++++++++++++++++++ 2 files changed, 98 insertions(+), 2 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala b/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala index 341bd4e0cd016..87f434daf4870 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala @@ -17,16 +17,23 @@ package org.apache.spark.status -import org.apache.spark.status.api.v1.{TaskData, TaskMetrics} +import org.apache.spark.status.api.v1.TaskData private[spark] object AppStatusUtils { + private val TASK_FINISHED_STATES = Set("FAILED", "KILLED", "SUCCESS") + + private def isTaskFinished(task: TaskData): Boolean = { + TASK_FINISHED_STATES.contains(task.status) + } + def schedulerDelay(task: TaskData): Long = { - if (task.taskMetrics.isDefined && task.duration.isDefined) { + if (isTaskFinished(task) && task.taskMetrics.isDefined && task.duration.isDefined) { val m = task.taskMetrics.get schedulerDelay(task.launchTime.getTime(), fetchStart(task), task.duration.get, m.executorDeserializeTime, m.resultSerializationTime, m.executorRunTime) } else { + // The task is still running and the metrics like executorRunTime are not available. 0L } } diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala new file mode 100644 index 0000000000000..9e74e86ad54b9 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status + +import java.util.Date + +import org.apache.spark.SparkFunSuite +import org.apache.spark.status.api.v1.{TaskData, TaskMetrics} + +class AppStatusUtilsSuite extends SparkFunSuite { + + test("schedulerDelay") { + val runningTask = new TaskData( + taskId = 0, + index = 0, + attempt = 0, + launchTime = new Date(1L), + resultFetchStart = None, + duration = Some(100L), + executorId = "1", + host = "localhost", + status = "RUNNING", + taskLocality = "PROCESS_LOCAL", + speculative = false, + accumulatorUpdates = Nil, + errorMessage = None, + taskMetrics = Some(new TaskMetrics( + executorDeserializeTime = 0L, + executorDeserializeCpuTime = 0L, + executorRunTime = 0L, + executorCpuTime = 0L, + resultSize = 0L, + jvmGcTime = 0L, + resultSerializationTime = 0L, + memoryBytesSpilled = 0L, + diskBytesSpilled = 0L, + peakExecutionMemory = 0L, + inputMetrics = null, + outputMetrics = null, + shuffleReadMetrics = null, + shuffleWriteMetrics = null))) + assert(AppStatusUtils.schedulerDelay(runningTask) === 0L) + + val finishedTask = new TaskData( + taskId = 0, + index = 0, + attempt = 0, + launchTime = new Date(1L), + resultFetchStart = None, + duration = Some(100L), + executorId = "1", + host = "localhost", + status = "SUCCESS", + taskLocality = "PROCESS_LOCAL", + speculative = false, + accumulatorUpdates = Nil, + errorMessage = None, + taskMetrics = Some(new TaskMetrics( + executorDeserializeTime = 5L, + executorDeserializeCpuTime = 3L, + executorRunTime = 90L, + executorCpuTime = 10L, + resultSize = 100L, + jvmGcTime = 10L, + resultSerializationTime = 2L, + memoryBytesSpilled = 0L, + diskBytesSpilled = 0L, + peakExecutionMemory = 100L, + inputMetrics = null, + outputMetrics = null, + shuffleReadMetrics = null, + shuffleWriteMetrics = null))) + assert(AppStatusUtils.schedulerDelay(finishedTask) === 3L) + } +} From a24c03138a6935a442b983c8a4c721b26df3f9e2 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 6 Feb 2018 14:52:25 +0800 Subject: [PATCH 0292/2461] [SPARK-23290][SQL][PYTHON] Use datetime.date for date type when converting Spark DataFrame to Pandas DataFrame. ## What changes were proposed in this pull request? In #18664, there was a change in how `DateType` is being returned to users ([line 1968 in dataframe.py](https://github.com/apache/spark/pull/18664/files#diff-6fc344560230bf0ef711bb9b5573f1faR1968)). This can cause client code which works in Spark 2.2 to fail. See [SPARK-23290](https://issues.apache.org/jira/browse/SPARK-23290?focusedCommentId=16350917&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-16350917) for an example. This pr modifies to use `datetime.date` for date type as Spark 2.2 does. ## How was this patch tested? Tests modified to fit the new behavior and existing tests. Author: Takuya UESHIN Closes #20506 from ueshin/issues/SPARK-23290. --- python/pyspark/serializers.py | 9 ++++-- python/pyspark/sql/dataframe.py | 7 ++-- python/pyspark/sql/tests.py | 57 ++++++++++++++++++++++++--------- python/pyspark/sql/types.py | 15 +++++++++ 4 files changed, 66 insertions(+), 22 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 88d6a191babca..e870325d202ca 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -267,12 +267,15 @@ def load_stream(self, stream): """ Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. """ - from pyspark.sql.types import _check_dataframe_localize_timestamps + from pyspark.sql.types import from_arrow_schema, _check_dataframe_convert_date, \ + _check_dataframe_localize_timestamps import pyarrow as pa reader = pa.open_stream(stream) + schema = from_arrow_schema(reader.schema) for batch in reader: - # NOTE: changed from pa.Columns.to_pandas, timezone issue in conversion fixed in 0.7.1 - pdf = _check_dataframe_localize_timestamps(batch.to_pandas(), self._timezone) + pdf = batch.to_pandas() + pdf = _check_dataframe_convert_date(pdf, schema) + pdf = _check_dataframe_localize_timestamps(pdf, self._timezone) yield [c for _, c in pdf.iteritems()] def __repr__(self): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 2e55407b5397b..59a417015b949 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1923,7 +1923,8 @@ def toPandas(self): if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": try: - from pyspark.sql.types import _check_dataframe_localize_timestamps + from pyspark.sql.types import _check_dataframe_convert_date, \ + _check_dataframe_localize_timestamps from pyspark.sql.utils import require_minimum_pyarrow_version import pyarrow require_minimum_pyarrow_version() @@ -1931,6 +1932,7 @@ def toPandas(self): if tables: table = pyarrow.concat_tables(tables) pdf = table.to_pandas() + pdf = _check_dataframe_convert_date(pdf, self.schema) return _check_dataframe_localize_timestamps(pdf, timezone) else: return pd.DataFrame.from_records([], columns=self.columns) @@ -2009,7 +2011,6 @@ def _to_corrected_pandas_type(dt): """ When converting Spark SQL records to Pandas DataFrame, the inferred data type may be wrong. This method gets the corrected data type for Pandas if that type may be inferred uncorrectly. - NOTE: DateType is inferred incorrectly as 'object', TimestampType is correct with datetime64[ns] """ import numpy as np if type(dt) == ByteType: @@ -2020,8 +2021,6 @@ def _to_corrected_pandas_type(dt): return np.int32 elif type(dt) == FloatType: return np.float32 - elif type(dt) == DateType: - return 'datetime64[ns]' else: return None diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b27363023ae77..545ec5aee08ff 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2816,7 +2816,7 @@ def test_to_pandas(self): self.assertEquals(types[1], np.object) self.assertEquals(types[2], np.bool) self.assertEquals(types[3], np.float32) - self.assertEquals(types[4], 'datetime64[ns]') + self.assertEquals(types[4], np.object) # datetime.date self.assertEquals(types[5], 'datetime64[ns]') @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed") @@ -3388,7 +3388,7 @@ class ArrowTests(ReusedSQLTestCase): @classmethod def setUpClass(cls): - from datetime import datetime + from datetime import date, datetime from decimal import Decimal ReusedSQLTestCase.setUpClass() @@ -3410,11 +3410,11 @@ def setUpClass(cls): StructField("7_date_t", DateType(), True), StructField("8_timestamp_t", TimestampType(), True)]) cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"), - datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), + date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"), - datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), + date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"), - datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] @classmethod def tearDownClass(cls): @@ -3461,7 +3461,9 @@ def _toPandas_arrow_toggle(self, df): def test_toPandas_arrow_toggle(self): df = self.spark.createDataFrame(self.data, schema=self.schema) pdf, pdf_arrow = self._toPandas_arrow_toggle(df) - self.assertPandasEqual(pdf_arrow, pdf) + expected = self.create_pandas_data_frame() + self.assertPandasEqual(expected, pdf) + self.assertPandasEqual(expected, pdf_arrow) def test_toPandas_respect_session_timezone(self): df = self.spark.createDataFrame(self.data, schema=self.schema) @@ -4062,18 +4064,42 @@ def test_vectorized_udf_unsupported_types(self): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): df.select(f(col('map'))).collect() - def test_vectorized_udf_null_date(self): + def test_vectorized_udf_dates(self): from pyspark.sql.functions import pandas_udf, col from datetime import date - schema = StructType().add("date", DateType()) - data = [(date(1969, 1, 1),), - (date(2012, 2, 2),), - (None,), - (date(2100, 4, 4),)] + schema = StructType().add("idx", LongType()).add("date", DateType()) + data = [(0, date(1969, 1, 1),), + (1, date(2012, 2, 2),), + (2, None,), + (3, date(2100, 4, 4),)] df = self.spark.createDataFrame(data, schema=schema) - date_f = pandas_udf(lambda t: t, returnType=DateType()) - res = df.select(date_f(col("date"))) - self.assertEquals(df.collect(), res.collect()) + + date_copy = pandas_udf(lambda t: t, returnType=DateType()) + df = df.withColumn("date_copy", date_copy(col("date"))) + + @pandas_udf(returnType=StringType()) + def check_data(idx, date, date_copy): + import pandas as pd + msgs = [] + is_equal = date.isnull() + for i in range(len(idx)): + if (is_equal[i] and data[idx[i]][1] is None) or \ + date[i] == data[idx[i]][1]: + msgs.append(None) + else: + msgs.append( + "date values are not equal (date='%s': data[%d][1]='%s')" + % (date[i], idx[i], data[idx[i]][1])) + return pd.Series(msgs) + + result = df.withColumn("check_data", + check_data(col("idx"), col("date"), col("date_copy"))).collect() + + self.assertEquals(len(data), len(result)) + for i in range(len(result)): + self.assertEquals(data[i][1], result[i][1]) # "date" col + self.assertEquals(data[i][1], result[i][2]) # "date_copy" col + self.assertIsNone(result[i][3]) # "check_data" col def test_vectorized_udf_timestamps(self): from pyspark.sql.functions import pandas_udf, col @@ -4114,6 +4140,7 @@ def check_data(idx, timestamp, timestamp_copy): self.assertEquals(len(data), len(result)) for i in range(len(result)): self.assertEquals(data[i][1], result[i][1]) # "timestamp" col + self.assertEquals(data[i][1], result[i][2]) # "timestamp_copy" col self.assertIsNone(result[i][3]) # "check_data" col def test_vectorized_udf_return_timestamp_tz(self): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 0dc5823f72a3c..093dae5a22e1f 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1694,6 +1694,21 @@ def from_arrow_schema(arrow_schema): for field in arrow_schema]) +def _check_dataframe_convert_date(pdf, schema): + """ Correct date type value to use datetime.date. + + Pandas DataFrame created from PyArrow uses datetime64[ns] for date type values, but we should + use datetime.date to match the behavior with when Arrow optimization is disabled. + + :param pdf: pandas.DataFrame + :param schema: a Spark schema of the pandas.DataFrame + """ + for field in schema: + if type(field.dataType) == DateType: + pdf[field.name] = pdf[field.name].dt.date + return pdf + + def _check_dataframe_localize_timestamps(pdf, timezone): """ Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone From 8141c3e3ddb55586906b9bc79ef515142c2b551a Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 6 Feb 2018 16:08:15 +0900 Subject: [PATCH 0293/2461] [SPARK-23300][TESTS] Prints out if Pandas and PyArrow are installed or not in PySpark SQL tests ## What changes were proposed in this pull request? This PR proposes to log if PyArrow and Pandas are installed or not so we can check if related tests are going to be skipped or not. ## How was this patch tested? Manually tested: I don't have PyArrow installed in PyPy. ```bash $ ./run-tests --python-executables=python3 ``` ``` ... Will test against the following Python executables: ['python3'] Will test the following Python modules: ['pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming'] Will test PyArrow related features against Python executable 'python3' in 'pyspark-sql' module. Will test Pandas related features against Python executable 'python3' in 'pyspark-sql' module. Starting test(python3): pyspark.mllib.tests Starting test(python3): pyspark.sql.tests Starting test(python3): pyspark.streaming.tests Starting test(python3): pyspark.tests ``` ```bash $ ./run-tests --modules=pyspark-streaming ``` ``` ... Will test against the following Python executables: ['python2.7', 'pypy'] Will test the following Python modules: ['pyspark-streaming'] Starting test(pypy): pyspark.streaming.tests Starting test(pypy): pyspark.streaming.util Starting test(python2.7): pyspark.streaming.tests Starting test(python2.7): pyspark.streaming.util ``` ```bash $ ./run-tests ``` ``` ... Will test against the following Python executables: ['python2.7', 'pypy'] Will test the following Python modules: ['pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming'] Will test PyArrow related features against Python executable 'python2.7' in 'pyspark-sql' module. Will test Pandas related features against Python executable 'python2.7' in 'pyspark-sql' module. Will skip PyArrow related features against Python executable 'pypy' in 'pyspark-sql' module. PyArrow >= 0.8.0 is required; however, PyArrow was not found. Will test Pandas related features against Python executable 'pypy' in 'pyspark-sql' module. Starting test(pypy): pyspark.streaming.tests Starting test(pypy): pyspark.sql.tests Starting test(pypy): pyspark.tests Starting test(python2.7): pyspark.mllib.tests ``` ```bash $ ./run-tests --modules=pyspark-sql --python-executables=pypy ``` ``` ... Will test against the following Python executables: ['pypy'] Will test the following Python modules: ['pyspark-sql'] Will skip PyArrow related features against Python executable 'pypy' in 'pyspark-sql' module. PyArrow >= 0.8.0 is required; however, PyArrow was not found. Will test Pandas related features against Python executable 'pypy' in 'pyspark-sql' module. Starting test(pypy): pyspark.sql.tests Starting test(pypy): pyspark.sql.catalog Starting test(pypy): pyspark.sql.column Starting test(pypy): pyspark.sql.conf ``` After some modification to produce other cases: ```bash $ ./run-tests ``` ``` ... Will test against the following Python executables: ['python2.7', 'pypy'] Will test the following Python modules: ['pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming'] Will skip PyArrow related features against Python executable 'python2.7' in 'pyspark-sql' module. PyArrow >= 20.0.0 is required; however, PyArrow 0.8.0 was found. Will skip Pandas related features against Python executable 'python2.7' in 'pyspark-sql' module. Pandas >= 20.0.0 is required; however, Pandas 0.20.2 was found. Will skip PyArrow related features against Python executable 'pypy' in 'pyspark-sql' module. PyArrow >= 20.0.0 is required; however, PyArrow was not found. Will skip Pandas related features against Python executable 'pypy' in 'pyspark-sql' module. Pandas >= 20.0.0 is required; however, Pandas 0.22.0 was found. Starting test(pypy): pyspark.sql.tests Starting test(pypy): pyspark.streaming.tests Starting test(pypy): pyspark.tests Starting test(python2.7): pyspark.mllib.tests ``` ```bash ./run-tests-with-coverage ``` ``` ... Will test against the following Python executables: ['python2.7', 'pypy'] Will test the following Python modules: ['pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming'] Will test PyArrow related features against Python executable 'python2.7' in 'pyspark-sql' module. Will test Pandas related features against Python executable 'python2.7' in 'pyspark-sql' module. Coverage is not installed in Python executable 'pypy' but 'COVERAGE_PROCESS_START' environment variable is set, exiting. ``` Author: hyukjinkwon Closes #20473 from HyukjinKwon/SPARK-23300. --- python/run-tests.py | 73 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 68 insertions(+), 5 deletions(-) diff --git a/python/run-tests.py b/python/run-tests.py index f03284c334285..6b41b5ee22814 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -31,6 +31,7 @@ import Queue else: import queue as Queue +from distutils.version import LooseVersion # Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module @@ -38,8 +39,8 @@ from sparktestsupport import SPARK_HOME # noqa (suppress pep8 warnings) -from sparktestsupport.shellutils import which, subprocess_check_output, run_cmd # noqa -from sparktestsupport.modules import all_modules # noqa +from sparktestsupport.shellutils import which, subprocess_check_output # noqa +from sparktestsupport.modules import all_modules, pyspark_sql # noqa python_modules = dict((m.name, m) for m in all_modules if m.python_test_goals if m.name != 'root') @@ -151,6 +152,67 @@ def parse_opts(): return opts +def _check_dependencies(python_exec, modules_to_test): + if "COVERAGE_PROCESS_START" in os.environ: + # Make sure if coverage is installed. + try: + subprocess_check_output( + [python_exec, "-c", "import coverage"], + stderr=open(os.devnull, 'w')) + except: + print_red("Coverage is not installed in Python executable '%s' " + "but 'COVERAGE_PROCESS_START' environment variable is set, " + "exiting." % python_exec) + sys.exit(-1) + + # If we should test 'pyspark-sql', it checks if PyArrow and Pandas are installed and + # explicitly prints out. See SPARK-23300. + if pyspark_sql in modules_to_test: + # TODO(HyukjinKwon): Relocate and deduplicate these version specifications. + minimum_pyarrow_version = '0.8.0' + minimum_pandas_version = '0.19.2' + + try: + pyarrow_version = subprocess_check_output( + [python_exec, "-c", "import pyarrow; print(pyarrow.__version__)"], + universal_newlines=True, + stderr=open(os.devnull, 'w')).strip() + if LooseVersion(pyarrow_version) >= LooseVersion(minimum_pyarrow_version): + LOGGER.info("Will test PyArrow related features against Python executable " + "'%s' in '%s' module." % (python_exec, pyspark_sql.name)) + else: + LOGGER.warning( + "Will skip PyArrow related features against Python executable " + "'%s' in '%s' module. PyArrow >= %s is required; however, PyArrow " + "%s was found." % ( + python_exec, pyspark_sql.name, minimum_pyarrow_version, pyarrow_version)) + except: + LOGGER.warning( + "Will skip PyArrow related features against Python executable " + "'%s' in '%s' module. PyArrow >= %s is required; however, PyArrow " + "was not found." % (python_exec, pyspark_sql.name, minimum_pyarrow_version)) + + try: + pandas_version = subprocess_check_output( + [python_exec, "-c", "import pandas; print(pandas.__version__)"], + universal_newlines=True, + stderr=open(os.devnull, 'w')).strip() + if LooseVersion(pandas_version) >= LooseVersion(minimum_pandas_version): + LOGGER.info("Will test Pandas related features against Python executable " + "'%s' in '%s' module." % (python_exec, pyspark_sql.name)) + else: + LOGGER.warning( + "Will skip Pandas related features against Python executable " + "'%s' in '%s' module. Pandas >= %s is required; however, Pandas " + "%s was found." % ( + python_exec, pyspark_sql.name, minimum_pandas_version, pandas_version)) + except: + LOGGER.warning( + "Will skip Pandas related features against Python executable " + "'%s' in '%s' module. Pandas >= %s is required; however, Pandas " + "was not found." % (python_exec, pyspark_sql.name, minimum_pandas_version)) + + def main(): opts = parse_opts() if (opts.verbose): @@ -175,9 +237,10 @@ def main(): task_queue = Queue.PriorityQueue() for python_exec in python_execs: - if "COVERAGE_PROCESS_START" in os.environ: - # Make sure if coverage is installed. - run_cmd([python_exec, "-c", "import coverage"]) + # Check if the python executable has proper dependencies installed to run tests + # for given modules properly. + _check_dependencies(python_exec, modules_to_test) + python_implementation = subprocess_check_output( [python_exec, "-c", "import platform; print(platform.python_implementation())"], universal_newlines=True).strip() From 63c5bf13ce5cd3b8d7e7fb88de881ed207fde720 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 6 Feb 2018 18:30:50 +0900 Subject: [PATCH 0294/2461] [SPARK-23334][SQL][PYTHON] Fix pandas_udf with return type StringType() to handle str type properly in Python 2. ## What changes were proposed in this pull request? In Python 2, when `pandas_udf` tries to return string type value created in the udf with `".."`, the execution fails. E.g., ```python from pyspark.sql.functions import pandas_udf, col import pandas as pd df = spark.range(10) str_f = pandas_udf(lambda x: pd.Series(["%s" % i for i in x]), "string") df.select(str_f(col('id'))).show() ``` raises the following exception: ``` ... java.lang.AssertionError: assertion failed: Invalid schema from pandas_udf: expected StringType, got BinaryType at scala.Predef$.assert(Predef.scala:170) at org.apache.spark.sql.execution.python.ArrowEvalPythonExec$$anon$2.(ArrowEvalPythonExec.scala:93) ... ``` Seems like pyarrow ignores `type` parameter for `pa.Array.from_pandas()` and consider it as binary type when the type is string type and the string values are `str` instead of `unicode` in Python 2. This pr adds a workaround for the case. ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN Closes #20507 from ueshin/issues/SPARK-23334. --- python/pyspark/serializers.py | 4 ++++ python/pyspark/sql/tests.py | 9 +++++++++ 2 files changed, 13 insertions(+) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index e870325d202ca..91a7f093cec19 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -230,6 +230,10 @@ def create_array(s, t): s = _check_series_convert_timestamps_internal(s.fillna(0), timezone) # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2 return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False) + elif t is not None and pa.types.is_string(t) and sys.version < '3': + # TODO: need decode before converting to Arrow in Python 2 + return pa.Array.from_pandas(s.apply( + lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t) return pa.Array.from_pandas(s, mask=mask, type=t) arrs = [create_array(s, t) for s, t in series] diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 545ec5aee08ff..89b7c2182d2d1 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3922,6 +3922,15 @@ def test_vectorized_udf_null_string(self): res = df.select(str_f(col('str'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_string_in_udf(self): + from pyspark.sql.functions import pandas_udf, col + import pandas as pd + df = self.spark.range(10) + str_f = pandas_udf(lambda x: pd.Series(map(str, x)), StringType()) + actual = df.select(str_f(col('id'))) + expected = df.select(col('id').cast('string')) + self.assertEquals(expected.collect(), actual.collect()) + def test_vectorized_udf_datatype_string(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10).select( From 7db9979babe52d15828967c86eb77e3fb2791579 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 6 Feb 2018 10:46:48 -0800 Subject: [PATCH 0295/2461] [SPARK-23310][CORE][FOLLOWUP] Fix Java style check issues. ## What changes were proposed in this pull request? This is a follow-up of #20492 which broke lint-java checks. This pr fixes the lint-java issues. ``` [ERROR] src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java:[79] (sizes) LineLength: Line is longer than 100 characters (found 114). ``` ## How was this patch tested? Checked manually in my local environment. Author: Takuya UESHIN Closes #20514 from ueshin/issues/SPARK-23310/fup1. --- .../util/collection/unsafe/sort/UnsafeSorterSpillReader.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 71e7c7a95ebdb..2c53c8d809d2e 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -76,8 +76,8 @@ public UnsafeSorterSpillReader( SparkEnv.get() == null ? 0.5 : SparkEnv.get().conf().getDouble("spark.unsafe.sorter.spill.read.ahead.fraction", 0.5); - // SPARK-23310: Disable read-ahead input stream, because it is causing lock contention and perf regression for - // TPC-DS queries. + // SPARK-23310: Disable read-ahead input stream, because it is causing lock contention and perf + // regression for TPC-DS queries. final boolean readAheadEnabled = SparkEnv.get() != null && SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", false); From ac7454cac04a1d9252b3856360eda5c3e8bcb8da Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 6 Feb 2018 12:27:37 -0800 Subject: [PATCH 0296/2461] [SPARK-23312][SQL][FOLLOWUP] add a config to turn off vectorized cache reader ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/20483 tried to provide a way to turn off the new columnar cache reader, to restore the behavior in 2.2. However even we turn off that config, the behavior is still different than 2.2. If the output data are rows, we still enable whole stage codegen for the scan node, which is different with 2.2, we should also fix it. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #20513 from cloud-fan/cache. --- .../spark/sql/execution/columnar/InMemoryTableScanExec.scala | 3 +++ .../src/test/scala/org/apache/spark/sql/CachedTableSuite.scala | 3 ++- .../apache/spark/sql/execution/WholeStageCodegenSuite.scala | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index e972f8b30d87c..a93e8a1ad954d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -61,6 +61,9 @@ case class InMemoryTableScanExec( }) && !WholeStageCodegenExec.isTooManyFields(conf, relation.schema) } + // TODO: revisit this. Shall we always turn off whole stage codegen if the output data are rows? + override def supportCodegen: Boolean = supportsBatch + override protected def needsUnsafeRowConversion: Boolean = false private val columnIndices = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 9f27fa09127af..669e5f2bf4e65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -787,7 +787,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext withSQLConf(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { val df = spark.range(10).cache() df.queryExecution.executedPlan.foreach { - case i: InMemoryTableScanExec => assert(i.supportsBatch == vectorized) + case i: InMemoryTableScanExec => + assert(i.supportsBatch == vectorized && i.supportCodegen == vectorized) case _ => } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 6e8d5a70d5a8f..ef16292a8e75c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -137,7 +137,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { val dsStringFilter = dsString.filter(_ == "1") val planString = dsStringFilter.queryExecution.executedPlan assert(planString.collect { - case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec)) if !i.supportsBatch => () + case i: InMemoryTableScanExec if !i.supportsBatch => () }.length == 1) assert(dsStringFilter.collect() === Array("1")) } From caf30445632de6aec810309293499199e7a20892 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 6 Feb 2018 12:30:04 -0800 Subject: [PATCH 0297/2461] [MINOR][TEST] Fix class name for Pandas UDF tests ## What changes were proposed in this pull request? In https://github.com/apache/spark/commit/b2ce17b4c9fea58140a57ca1846b2689b15c0d61, I mistakenly renamed `VectorizedUDFTests` to `ScalarPandasUDF`. This PR fixes the mistake. ## How was this patch tested? Existing tests. Author: Li Jin Closes #20489 from icexelloss/fix-scalar-udf-tests. --- python/pyspark/sql/tests.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 89b7c2182d2d1..53da7dd45c2f2 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3766,7 +3766,7 @@ def foo(k, v): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class ScalarPandasUDF(ReusedSQLTestCase): +class ScalarPandasUDFTests(ReusedSQLTestCase): @classmethod def setUpClass(cls): @@ -4279,7 +4279,7 @@ def test_register_vectorized_udf_basic(self): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class GroupbyApplyPandasUDFTests(ReusedSQLTestCase): +class GroupedMapPandasUDFTests(ReusedSQLTestCase): @property def data(self): @@ -4448,7 +4448,7 @@ def test_unsupported_types(self): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class GroupbyAggPandasUDFTests(ReusedSQLTestCase): +class GroupedAggPandasUDFTests(ReusedSQLTestCase): @property def data(self): From b96a083b1c6ff0d2c588be9499b456e1adce97dc Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 6 Feb 2018 12:43:45 -0800 Subject: [PATCH 0298/2461] [SPARK-23315][SQL] failed to get output from canonicalized data source v2 related plans ## What changes were proposed in this pull request? `DataSourceV2Relation` keeps a `fullOutput` and resolves the real output on demand by column name lookup. i.e. ``` lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name => fullOutput.find(_.name == name).get } ``` This will be broken after we canonicalize the plan, because all attribute names become "None", see https://github.com/apache/spark/blob/v2.3.0-rc1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala#L42 To fix this, `DataSourceV2Relation` should just keep `output`, and update the `output` when doing column pruning. ## How was this patch tested? a new test case Author: Wenchen Fan Closes #20485 from cloud-fan/canonicalize. --- .../v2/DataSourceReaderHolder.scala | 12 +++----- .../datasources/v2/DataSourceV2Relation.scala | 8 ++--- .../datasources/v2/DataSourceV2ScanExec.scala | 4 +-- .../v2/PushDownOperatorsToDataSource.scala | 29 +++++++++++++------ .../sql/sources/v2/DataSourceV2Suite.scala | 20 ++++++++++++- 5 files changed, 48 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala index 6460c97abe344..81219e9771bd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import java.util.Objects -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.sources.v2.reader._ /** @@ -28,9 +28,9 @@ import org.apache.spark.sql.sources.v2.reader._ trait DataSourceReaderHolder { /** - * The full output of the data source reader, without column pruning. + * The output of the data source reader, w.r.t. column pruning. */ - def fullOutput: Seq[AttributeReference] + def output: Seq[Attribute] /** * The held data source reader. @@ -46,7 +46,7 @@ trait DataSourceReaderHolder { case s: SupportsPushDownFilters => s.pushedFilters().toSet case _ => Nil } - Seq(fullOutput, reader.getClass, reader.readSchema(), filters) + Seq(output, reader.getClass, filters) } def canEqual(other: Any): Boolean @@ -61,8 +61,4 @@ trait DataSourceReaderHolder { override def hashCode(): Int = { metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) } - - lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name => - fullOutput.find(_.name == name).get - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index eebfa29f91b99..38f6b15224788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( - fullOutput: Seq[AttributeReference], + output: Seq[AttributeReference], reader: DataSourceReader) extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder { @@ -37,7 +37,7 @@ case class DataSourceV2Relation( } override def newInstance(): DataSourceV2Relation = { - copy(fullOutput = fullOutput.map(_.newInstance())) + copy(output = output.map(_.newInstance())) } } @@ -46,8 +46,8 @@ case class DataSourceV2Relation( * to the non-streaming relation. */ class StreamingDataSourceV2Relation( - fullOutput: Seq[AttributeReference], - reader: DataSourceReader) extends DataSourceV2Relation(fullOutput, reader) { + output: Seq[AttributeReference], + reader: DataSourceReader) extends DataSourceV2Relation(output, reader) { override def isStreaming: Boolean = true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index df469af2c262a..7d9581be4db89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -35,14 +35,12 @@ import org.apache.spark.sql.types.StructType * Physical plan node for scanning data from a data source. */ case class DataSourceV2ScanExec( - fullOutput: Seq[AttributeReference], + output: Seq[AttributeReference], @transient reader: DataSourceReader) extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] - override def producedAttributes: AttributeSet = AttributeSet(fullOutput) - override def outputPartitioning: physical.Partitioning = reader match { case s: SupportsReportPartitioning => new DataSourcePartitioning( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 566a48394f02e..1ca6cbf061b4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -81,33 +81,44 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel // TODO: add more push down rules. - pushDownRequiredColumns(filterPushed, filterPushed.outputSet) + val columnPruned = pushDownRequiredColumns(filterPushed, filterPushed.outputSet) // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them. - RemoveRedundantProject(filterPushed) + RemoveRedundantProject(columnPruned) } // TODO: nested fields pruning - private def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: AttributeSet): Unit = { + private def pushDownRequiredColumns( + plan: LogicalPlan, requiredByParent: AttributeSet): LogicalPlan = { plan match { - case Project(projectList, child) => + case p @ Project(projectList, child) => val required = projectList.flatMap(_.references) - pushDownRequiredColumns(child, AttributeSet(required)) + p.copy(child = pushDownRequiredColumns(child, AttributeSet(required))) - case Filter(condition, child) => + case f @ Filter(condition, child) => val required = requiredByParent ++ condition.references - pushDownRequiredColumns(child, required) + f.copy(child = pushDownRequiredColumns(child, required)) case relation: DataSourceV2Relation => relation.reader match { case reader: SupportsPushDownRequiredColumns => + // TODO: Enable the below assert after we make `DataSourceV2Relation` immutable. Fow now + // it's possible that the mutable reader being updated by someone else, and we need to + // always call `reader.pruneColumns` here to correct it. + // assert(relation.output.toStructType == reader.readSchema(), + // "Schema of data source reader does not match the relation plan.") + val requiredColumns = relation.output.filter(requiredByParent.contains) reader.pruneColumns(requiredColumns.toStructType) - case _ => + val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap + val newOutput = reader.readSchema().map(_.name).map(nameToAttr) + relation.copy(output = newOutput) + + case _ => relation } // TODO: there may be more operators that can be used to calculate the required columns. We // can add more and more in the future. - case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.outputSet)) + case _ => plan.mapChildren(c => pushDownRequiredColumns(c, c.outputSet)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index eccd45442a3b2..a1c87fb15542c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -24,7 +24,7 @@ import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.functions._ @@ -316,6 +316,24 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val reader4 = getReader(q4) assert(reader4.requiredSchema.fieldNames === Seq("i")) } + + test("SPARK-23315: get output from canonicalized data source v2 related plans") { + def checkCanonicalizedOutput(df: DataFrame, numOutput: Int): Unit = { + val logical = df.queryExecution.optimizedPlan.collect { + case d: DataSourceV2Relation => d + }.head + assert(logical.canonicalized.output.length == numOutput) + + val physical = df.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => d + }.head + assert(physical.canonicalized.output.length == numOutput) + } + + val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() + checkCanonicalizedOutput(df, 2) + checkCanonicalizedOutput(df.select('i), 1) + } } class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { From c36fecc3b416c38002779c3cf40b6a665ac4bf13 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 6 Feb 2018 16:46:43 -0800 Subject: [PATCH 0299/2461] [SPARK-23327][SQL] Update the description and tests of three external API or functions ## What changes were proposed in this pull request? Update the description and tests of three external API or functions `createFunction `, `length` and `repartitionByRange ` ## How was this patch tested? N/A Author: gatorsmile Closes #20495 from gatorsmile/updateFunc. --- R/pkg/R/functions.R | 4 +++- python/pyspark/sql/functions.py | 8 ++++--- .../sql/catalyst/catalog/SessionCatalog.scala | 7 ++++-- .../expressions/stringExpressions.scala | 23 ++++++++++--------- .../scala/org/apache/spark/sql/Dataset.scala | 2 ++ .../sql/execution/command/functions.scala | 14 +++++++---- .../org/apache/spark/sql/functions.scala | 4 +++- .../execution/command/DDLParserSuite.scala | 10 ++++---- 8 files changed, 44 insertions(+), 28 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 55365a41d774b..9f7c6317cd924 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1026,7 +1026,9 @@ setMethod("last_day", }) #' @details -#' \code{length}: Computes the length of a given string or binary column. +#' \code{length}: Computes the character length of a string data or number of bytes +#' of a binary data. The length of string data includes the trailing spaces. +#' The length of binary data includes binary zeros. #' #' @rdname column_string_functions #' @aliases length length,Column-method diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 3c8fb4c4d19e7..05031f5ec87d7 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1705,10 +1705,12 @@ def unhex(col): @ignore_unicode_prefix @since(1.5) def length(col): - """Calculates the length of a string or binary expression. + """Computes the character length of string data or number of bytes of binary data. + The length of character data includes the trailing spaces. The length of binary data + includes binary zeros. - >>> spark.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect() - [Row(length=3)] + >>> spark.createDataFrame([('ABC ',)], ['a']).select(length('a').alias('length')).collect() + [Row(length=4)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.length(_to_java_column(col))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index a129896230775..4b119c75260a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -988,8 +988,11 @@ class SessionCatalog( // ------------------------------------------------------- /** - * Create a metastore function in the database specified in `funcDefinition`. + * Create a function in the database specified in `funcDefinition`. * If no such database is specified, create it in the current database. + * + * @param ignoreIfExists: When true, ignore if the function with the specified name exists + * in the specified database. */ def createFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = { val db = formatDatabaseName(funcDefinition.identifier.database.getOrElse(getCurrentDatabase)) @@ -1061,7 +1064,7 @@ class SessionCatalog( } /** - * Check if the specified function exists. + * Check if the function with the specified name exists */ def functionExists(name: FunctionIdentifier): Boolean = { val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 5cf783f1a5979..d7612e30b4c57 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1653,19 +1653,19 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run * A function that returns the char length of the given string expression or * number of bytes of the given binary expression. */ -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the character length of `expr` or number of bytes in binary data.", + usage = "_FUNC_(expr) - Returns the character length of string data or number of bytes of " + + "binary data. The length of string data includes the trailing spaces. The length of binary " + + "data includes binary zeros.", examples = """ Examples: - > SELECT _FUNC_('Spark SQL'); - 9 - > SELECT CHAR_LENGTH('Spark SQL'); - 9 - > SELECT CHARACTER_LENGTH('Spark SQL'); - 9 + > SELECT _FUNC_('Spark SQL '); + 10 + > SELECT CHAR_LENGTH('Spark SQL '); + 10 + > SELECT CHARACTER_LENGTH('Spark SQL '); + 10 """) -// scalastyle:on line.size.limit case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -1687,7 +1687,7 @@ case class Length(child: Expression) extends UnaryExpression with ImplicitCastIn * A function that returns the bit length of the given string or binary expression. */ @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the bit length of `expr` or number of bits in binary data.", + usage = "_FUNC_(expr) - Returns the bit length of string data or number of bits of binary data.", examples = """ Examples: > SELECT _FUNC_('Spark SQL'); @@ -1716,7 +1716,8 @@ case class BitLength(child: Expression) extends UnaryExpression with ImplicitCas * A function that returns the byte length of the given string or binary expression. */ @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the byte length of `expr` or number of bytes in binary data.", + usage = "_FUNC_(expr) - Returns the byte length of string data or number of bytes of binary " + + "data.", examples = """ Examples: > SELECT _FUNC_('Spark SQL'); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d47cd0aecf56a..0aee1d7be5788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2825,6 +2825,7 @@ class Dataset[T] private[sql]( * * At least one partition-by expression must be specified. * When no explicit sort order is specified, "ascending nulls first" is assumed. + * Note, the rows are not sorted in each partition of the resulting Dataset. * * @group typedrel * @since 2.3.0 @@ -2848,6 +2849,7 @@ class Dataset[T] private[sql]( * * At least one partition-by expression must be specified. * When no explicit sort order is specified, "ascending nulls first" is assumed. + * Note, the rows are not sorted in each partition of the resulting Dataset. * * @group typedrel * @since 2.3.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index 4f92ffee687aa..1f7808c2f8e80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -40,6 +40,10 @@ import org.apache.spark.sql.types.{StringType, StructField, StructType} * CREATE [OR REPLACE] FUNCTION [IF NOT EXISTS] [databaseName.]functionName * AS className [USING JAR\FILE 'uri' [, JAR|FILE 'uri']] * }}} + * + * @param ignoreIfExists: When true, ignore if the function with the specified name exists + * in the specified database. + * @param replace: When true, alter the function with the specified name */ case class CreateFunctionCommand( databaseName: Option[String], @@ -47,17 +51,17 @@ case class CreateFunctionCommand( className: String, resources: Seq[FunctionResource], isTemp: Boolean, - ifNotExists: Boolean, + ignoreIfExists: Boolean, replace: Boolean) extends RunnableCommand { - if (ifNotExists && replace) { + if (ignoreIfExists && replace) { throw new AnalysisException("CREATE FUNCTION with both IF NOT EXISTS and REPLACE" + " is not allowed.") } // Disallow to define a temporary function with `IF NOT EXISTS` - if (ifNotExists && isTemp) { + if (ignoreIfExists && isTemp) { throw new AnalysisException( "It is not allowed to define a TEMPORARY function with IF NOT EXISTS.") } @@ -79,12 +83,12 @@ case class CreateFunctionCommand( // Handles `CREATE OR REPLACE FUNCTION AS ... USING ...` if (replace && catalog.functionExists(func.identifier)) { // alter the function in the metastore - catalog.alterFunction(CatalogFunction(func.identifier, className, resources)) + catalog.alterFunction(func) } else { // For a permanent, we will store the metadata into underlying external catalog. // This function will be loaded into the FunctionRegistry when a query uses it. // We do not load it into FunctionRegistry right now. - catalog.createFunction(CatalogFunction(func.identifier, className, resources), ifNotExists) + catalog.createFunction(func, ignoreIfExists) } } Seq.empty[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0d11682d80a3c..0d54c02c3d06f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2267,7 +2267,9 @@ object functions { } /** - * Computes the length of a given string or binary column. + * Computes the character length of a given string or number of bytes of a binary string. + * The length of character strings include the trailing spaces. The length of binary strings + * includes binary zeros. * * @group string_funcs * @since 1.5.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index 2b1aea08b1223..e0ccae15f1d05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -236,7 +236,7 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { Seq( FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar1"), FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar2")), - isTemp = true, ifNotExists = false, replace = false) + isTemp = true, ignoreIfExists = false, replace = false) val expected2 = CreateFunctionCommand( Some("hello"), "world", @@ -244,7 +244,7 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { Seq( FunctionResource(FunctionResourceType.fromString("archive"), "/path/to/archive"), FunctionResource(FunctionResourceType.fromString("file"), "/path/to/file")), - isTemp = false, ifNotExists = false, replace = false) + isTemp = false, ignoreIfExists = false, replace = false) val expected3 = CreateFunctionCommand( None, "helloworld3", @@ -252,7 +252,7 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { Seq( FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar1"), FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar2")), - isTemp = true, ifNotExists = false, replace = true) + isTemp = true, ignoreIfExists = false, replace = true) val expected4 = CreateFunctionCommand( Some("hello"), "world1", @@ -260,7 +260,7 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { Seq( FunctionResource(FunctionResourceType.fromString("archive"), "/path/to/archive"), FunctionResource(FunctionResourceType.fromString("file"), "/path/to/file")), - isTemp = false, ifNotExists = false, replace = true) + isTemp = false, ignoreIfExists = false, replace = true) val expected5 = CreateFunctionCommand( Some("hello"), "world2", @@ -268,7 +268,7 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { Seq( FunctionResource(FunctionResourceType.fromString("archive"), "/path/to/archive"), FunctionResource(FunctionResourceType.fromString("file"), "/path/to/file")), - isTemp = false, ifNotExists = true, replace = false) + isTemp = false, ignoreIfExists = true, replace = false) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) comparePlans(parsed3, expected3) From 9775df67f924663598d51723a878557ddafb8cfd Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 7 Feb 2018 23:24:16 +0900 Subject: [PATCH 0300/2461] [SPARK-23122][PYSPARK][FOLLOWUP] Replace registerTempTable by createOrReplaceTempView ## What changes were proposed in this pull request? Replace `registerTempTable` by `createOrReplaceTempView`. ## How was this patch tested? N/A Author: gatorsmile Closes #20523 from gatorsmile/updateExamples. --- python/pyspark/sql/udf.py | 2 +- .../src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 0f759c448b8a7..08c6b9e521e82 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -356,7 +356,7 @@ def registerJavaUDAF(self, name, javaClassName): >>> spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg") >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) - >>> df.registerTempTable("df") + >>> df.createOrReplaceTempView("df") >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] """ diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java index ddbaa45a483cb..08dc129f27a0c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java @@ -46,7 +46,7 @@ public void tearDown() { @SuppressWarnings("unchecked") @Test public void udf1Test() { - spark.range(1, 10).toDF("value").registerTempTable("df"); + spark.range(1, 10).toDF("value").createOrReplaceTempView("df"); spark.udf().registerJavaUDAF("myDoubleAvg", MyDoubleAvg.class.getName()); Row result = spark.sql("SELECT myDoubleAvg(value) as my_avg from df").head(); Assert.assertEquals(105.0, result.getDouble(0), 1.0e-6); From 71cfba04aeec5ae9b85a507b13996e80f8750edc Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 7 Feb 2018 23:28:10 +0900 Subject: [PATCH 0301/2461] [SPARK-23319][TESTS] Explicitly specify Pandas and PyArrow versions in PySpark tests (to skip or test) ## What changes were proposed in this pull request? This PR proposes to explicitly specify Pandas and PyArrow versions in PySpark tests to skip or test. We declared the extra dependencies: https://github.com/apache/spark/blob/b8bfce51abf28c66ba1fc67b0f25fe1617c81025/python/setup.py#L204 In case of PyArrow: Currently we only check if pyarrow is installed or not without checking the version. It already fails to run tests. For example, if PyArrow 0.7.0 is installed: ``` ====================================================================== ERROR: test_vectorized_udf_wrong_return_type (pyspark.sql.tests.ScalarPandasUDF) ---------------------------------------------------------------------- Traceback (most recent call last): File "/.../spark/python/pyspark/sql/tests.py", line 4019, in test_vectorized_udf_wrong_return_type f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) File "/.../spark/python/pyspark/sql/functions.py", line 2309, in pandas_udf return _create_udf(f=f, returnType=return_type, evalType=eval_type) File "/.../spark/python/pyspark/sql/udf.py", line 47, in _create_udf require_minimum_pyarrow_version() File "/.../spark/python/pyspark/sql/utils.py", line 132, in require_minimum_pyarrow_version "however, your version was %s." % pyarrow.__version__) ImportError: pyarrow >= 0.8.0 must be installed on calling Python process; however, your version was 0.7.0. ---------------------------------------------------------------------- Ran 33 tests in 8.098s FAILED (errors=33) ``` In case of Pandas: There are few tests for old Pandas which were tested only when Pandas version was lower, and I rewrote them to be tested when both Pandas version is lower and missing. ## How was this patch tested? Manually tested by modifying the condition: ``` test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.' test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.' ``` ``` test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' ``` ``` test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.' test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.' ``` ``` test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.' test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.' ``` Author: hyukjinkwon Closes #20487 from HyukjinKwon/pyarrow-pandas-skip. --- pom.xml | 4 ++ python/pyspark/sql/dataframe.py | 3 ++ python/pyspark/sql/session.py | 3 ++ python/pyspark/sql/tests.py | 87 ++++++++++++++++++--------------- python/pyspark/sql/utils.py | 30 +++++++++--- python/setup.py | 10 +++- 6 files changed, 89 insertions(+), 48 deletions(-) diff --git a/pom.xml b/pom.xml index 666d5d7169a15..d18831df1db6d 100644 --- a/pom.xml +++ b/pom.xml @@ -185,6 +185,10 @@ 2.8 1.8 1.0.0 + 0.8.0 ${java.home} diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 59a417015b949..8ec24db8717b2 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1913,6 +1913,9 @@ def toPandas(self): 0 2 Alice 1 5 Bob """ + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() + import pandas as pd if self.sql_ctx.getConf("spark.sql.execution.pandas.respectSessionTimeZone").lower() \ diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 1ed04298bc899..b3af9b82953f3 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -646,6 +646,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr except Exception: has_pandas = False if has_pandas and isinstance(data, pandas.DataFrame): + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() + if self.conf.get("spark.sql.execution.pandas.respectSessionTimeZone").lower() \ == "true": timezone = self.conf.get("spark.sql.session.timeZone") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 53da7dd45c2f2..58359b61dc83a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -48,19 +48,26 @@ else: import unittest -_have_pandas = False -_have_old_pandas = False +_pandas_requirement_message = None try: - import pandas - try: - from pyspark.sql.utils import require_minimum_pandas_version - require_minimum_pandas_version() - _have_pandas = True - except: - _have_old_pandas = True -except: - # No Pandas, but that's okay, we'll skip those tests - pass + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() +except ImportError as e: + from pyspark.util import _exception_message + # If Pandas version requirement is not satisfied, skip related tests. + _pandas_requirement_message = _exception_message(e) + +_pyarrow_requirement_message = None +try: + from pyspark.sql.utils import require_minimum_pyarrow_version + require_minimum_pyarrow_version() +except ImportError as e: + from pyspark.util import _exception_message + # If Arrow version requirement is not satisfied, skip related tests. + _pyarrow_requirement_message = _exception_message(e) + +_have_pandas = _pandas_requirement_message is None +_have_pyarrow = _pyarrow_requirement_message is None from pyspark import SparkContext from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row @@ -75,15 +82,6 @@ from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException -_have_arrow = False -try: - import pyarrow - _have_arrow = True -except: - # No Arrow, but that's okay, we'll skip those tests - pass - - class UTCOffsetTimezone(datetime.tzinfo): """ Specifies timezone in UTC offset @@ -2794,7 +2792,6 @@ def count_bucketed_cols(names, table="pyspark_bucket"): def _to_pandas(self): from datetime import datetime, date - import numpy as np schema = StructType().add("a", IntegerType()).add("b", StringType())\ .add("c", BooleanType()).add("d", FloatType())\ .add("dt", DateType()).add("ts", TimestampType()) @@ -2807,7 +2804,7 @@ def _to_pandas(self): df = self.spark.createDataFrame(data, schema) return df.toPandas() - @unittest.skipIf(not _have_pandas, "Pandas not installed") + @unittest.skipIf(not _have_pandas, _pandas_requirement_message) def test_to_pandas(self): import numpy as np pdf = self._to_pandas() @@ -2819,13 +2816,13 @@ def test_to_pandas(self): self.assertEquals(types[4], np.object) # datetime.date self.assertEquals(types[5], 'datetime64[ns]') - @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed") - def test_to_pandas_old(self): + @unittest.skipIf(_have_pandas, "Required Pandas was found.") + def test_to_pandas_required_pandas_not_found(self): with QuietTest(self.sc): with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'): self._to_pandas() - @unittest.skipIf(not _have_pandas, "Pandas not installed") + @unittest.skipIf(not _have_pandas, _pandas_requirement_message) def test_to_pandas_avoid_astype(self): import numpy as np schema = StructType().add("a", IntegerType()).add("b", StringType())\ @@ -2843,7 +2840,7 @@ def test_create_dataframe_from_array_of_long(self): df = self.spark.createDataFrame(data) self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807])) - @unittest.skipIf(not _have_pandas, "Pandas not installed") + @unittest.skipIf(not _have_pandas, _pandas_requirement_message) def test_create_dataframe_from_pandas_with_timestamp(self): import pandas as pd from datetime import datetime @@ -2858,14 +2855,16 @@ def test_create_dataframe_from_pandas_with_timestamp(self): self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType)) self.assertTrue(isinstance(df.schema['d'].dataType, DateType)) - @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed") - def test_create_dataframe_from_old_pandas(self): - import pandas as pd - from datetime import datetime - pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], - "d": [pd.Timestamp.now().date()]}) + @unittest.skipIf(_have_pandas, "Required Pandas was found.") + def test_create_dataframe_required_pandas_not_found(self): with QuietTest(self.sc): - with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'): + with self.assertRaisesRegexp( + ImportError, + '(Pandas >= .* must be installed|No module named pandas)'): + import pandas as pd + from datetime import datetime + pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], + "d": [pd.Timestamp.now().date()]}) self.spark.createDataFrame(pdf) @@ -3383,7 +3382,9 @@ def __init__(self, **kwargs): _make_type_verifier(data_type, nullable=False)(obj) -@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) class ArrowTests(ReusedSQLTestCase): @classmethod @@ -3641,7 +3642,9 @@ def test_createDataFrame_with_int_col_names(self): self.assertEqual(pdf_col_names, df_arrow.columns) -@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) class PandasUDFTests(ReusedSQLTestCase): def test_pandas_udf_basic(self): from pyspark.rdd import PythonEvalType @@ -3765,7 +3768,9 @@ def foo(k, v): return k -@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) class ScalarPandasUDFTests(ReusedSQLTestCase): @classmethod @@ -4278,7 +4283,9 @@ def test_register_vectorized_udf_basic(self): self.assertEquals(expected.collect(), res2.collect()) -@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) class GroupedMapPandasUDFTests(ReusedSQLTestCase): @property @@ -4447,7 +4454,9 @@ def test_unsupported_types(self): df.groupby('id').apply(f).collect() -@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) class GroupedAggPandasUDFTests(ReusedSQLTestCase): @property diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 08c34c6dccc5e..578298632dd4c 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -115,18 +115,32 @@ def toJArray(gateway, jtype, arr): def require_minimum_pandas_version(): """ Raise ImportError if minimum version of Pandas is not installed """ + # TODO(HyukjinKwon): Relocate and deduplicate the version specification. + minimum_pandas_version = "0.19.2" + from distutils.version import LooseVersion - import pandas - if LooseVersion(pandas.__version__) < LooseVersion('0.19.2'): - raise ImportError("Pandas >= 0.19.2 must be installed on calling Python process; " - "however, your version was %s." % pandas.__version__) + try: + import pandas + except ImportError: + raise ImportError("Pandas >= %s must be installed; however, " + "it was not found." % minimum_pandas_version) + if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version): + raise ImportError("Pandas >= %s must be installed; however, " + "your version was %s." % (minimum_pandas_version, pandas.__version__)) def require_minimum_pyarrow_version(): """ Raise ImportError if minimum version of pyarrow is not installed """ + # TODO(HyukjinKwon): Relocate and deduplicate the version specification. + minimum_pyarrow_version = "0.8.0" + from distutils.version import LooseVersion - import pyarrow - if LooseVersion(pyarrow.__version__) < LooseVersion('0.8.0'): - raise ImportError("pyarrow >= 0.8.0 must be installed on calling Python process; " - "however, your version was %s." % pyarrow.__version__) + try: + import pyarrow + except ImportError: + raise ImportError("PyArrow >= %s must be installed; however, " + "it was not found." % minimum_pyarrow_version) + if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version): + raise ImportError("PyArrow >= %s must be installed; however, " + "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__)) diff --git a/python/setup.py b/python/setup.py index 251d4526d4dd0..6a98401941d8d 100644 --- a/python/setup.py +++ b/python/setup.py @@ -100,6 +100,11 @@ def _supports_symlinks(): file=sys.stderr) exit(-1) +# If you are changing the versions here, please also change ./python/pyspark/sql/utils.py and +# ./python/run-tests.py. In case of Arrow, you should also check ./pom.xml. +_minimum_pandas_version = "0.19.2" +_minimum_pyarrow_version = "0.8.0" + try: # We copy the shell script to be under pyspark/python/pyspark so that the launcher scripts # find it where expected. The rest of the files aren't copied because they are accessed @@ -201,7 +206,10 @@ def _supports_symlinks(): extras_require={ 'ml': ['numpy>=1.7'], 'mllib': ['numpy>=1.7'], - 'sql': ['pandas>=0.19.2', 'pyarrow>=0.8.0'] + 'sql': [ + 'pandas>=%s' % _minimum_pandas_version, + 'pyarrow>=%s' % _minimum_pyarrow_version, + ] }, classifiers=[ 'Development Status :: 5 - Production/Stable', From 9841ae0313cbee1f083f131f9446808c90ed5a7b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 7 Feb 2018 09:48:49 -0800 Subject: [PATCH 0302/2461] [SPARK-23345][SQL] Remove open stream record even closing it fails ## What changes were proposed in this pull request? When `DebugFilesystem` closes opened stream, if any exception occurs, we still need to remove the open stream record from `DebugFilesystem`. Otherwise, it goes to report leaked filesystem connection. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #20524 from viirya/SPARK-23345. --- core/src/test/scala/org/apache/spark/DebugFilesystem.scala | 7 +++++-- .../org/apache/spark/sql/test/SharedSparkSession.scala | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala index 91355f7362900..a5bdc95790722 100644 --- a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala +++ b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala @@ -103,8 +103,11 @@ class DebugFilesystem extends LocalFileSystem { override def markSupported(): Boolean = wrapped.markSupported() override def close(): Unit = { - wrapped.close() - removeOpenStream(wrapped) + try { + wrapped.close() + } finally { + removeOpenStream(wrapped) + } } override def read(): Int = wrapped.read() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index 0b4629a51b425..e758c865b908f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -111,7 +111,7 @@ trait SharedSparkSession spark.sharedState.cacheManager.clearCache() // files can be closed from other threads, so wait a bit // normally this doesn't take more than 1s - eventually(timeout(10.seconds)) { + eventually(timeout(10.seconds), interval(2.seconds)) { DebugFilesystem.assertNoOpenStreams() } } From 30295bf5a6754d0ae43334f7bf00e7a29ed0f1af Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 7 Feb 2018 15:22:53 -0800 Subject: [PATCH 0303/2461] [SPARK-23092][SQL] Migrate MemoryStream to DataSourceV2 APIs ## What changes were proposed in this pull request? This PR migrates the MemoryStream to DataSourceV2 APIs. One additional change is in the reported keys in StreamingQueryProgress.durationMs. "getOffset" and "getBatch" replaced with "setOffsetRange" and "getEndOffset" as tracking these make more sense. Unit tests changed accordingly. ## How was this patch tested? Existing unit tests, few updated unit tests. Author: Tathagata Das Author: Burak Yavuz Closes #20445 from tdas/SPARK-23092. --- .../sql/execution/streaming/LongOffset.scala | 4 +- .../streaming/MicroBatchExecution.scala | 27 ++-- .../sql/execution/streaming/memory.scala | 132 +++++++++++------- .../sources/RateStreamSourceV2.scala | 2 +- .../streaming/ForeachSinkSuite.scala | 55 +++----- .../spark/sql/streaming/StreamSuite.scala | 8 +- .../spark/sql/streaming/StreamTest.scala | 2 +- .../StreamingQueryListenerSuite.scala | 5 +- .../sql/streaming/StreamingQuerySuite.scala | 70 ++++++---- 9 files changed, 171 insertions(+), 134 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala index 5f0b195fcfcb8..3ff5b86ac45d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.execution.streaming +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} + /** * A simple offset for sources that produce a single linear stream of data. */ -case class LongOffset(offset: Long) extends Offset { +case class LongOffset(offset: Long) extends OffsetV2 { override val json = offset.toString diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index d9aa8573ba930..045d2b4b9569c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -270,16 +270,17 @@ class MicroBatchExecution( } case s: MicroBatchReader => updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("getOffset") { - // Once v1 streaming source execution is gone, we can refactor this away. - // For now, we set the range here to get the source to infer the available end offset, - // get that offset, and then set the range again when we later execute. - s.setOffsetRange( - toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), - Optional.empty()) - - (s, Some(s.getEndOffset)) + reportTimeTaken("setOffsetRange") { + // Once v1 streaming source execution is gone, we can refactor this away. + // For now, we set the range here to get the source to infer the available end offset, + // get that offset, and then set the range again when we later execute. + s.setOffsetRange( + toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), + Optional.empty()) } + + val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() } + (s, Option(currentOffset)) }.toMap availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) @@ -401,10 +402,14 @@ class MicroBatchExecution( case (reader: MicroBatchReader, available) if committedOffsets.get(reader).map(_ != available).getOrElse(true) => val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json)) + val availableV2: OffsetV2 = available match { + case v1: SerializedOffset => reader.deserializeOffset(v1.json) + case v2: OffsetV2 => v2 + } reader.setOffsetRange( toJava(current), - Optional.of(available.asInstanceOf[OffsetV2])) - logDebug(s"Retrieving data from $reader: $current -> $available") + Optional.of(availableV2)) + logDebug(s"Retrieving data from $reader: $current -> $availableV2") Some(reader -> new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader)) case _ => None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 509a69dd922fb..352d4ce9fbcaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -17,21 +17,23 @@ package org.apache.spark.sql.execution.streaming +import java.{util => ju} +import java.util.Optional import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ -import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.encoders.encoderFor -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, Statistics} +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -51,9 +53,10 @@ object MemoryStream { * available. */ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) - extends Source with Logging { + extends MicroBatchReader with SupportsScanUnsafeRow with Logging { protected val encoder = encoderFor[A] - protected val logicalPlan = StreamingExecutionRelation(this, sqlContext.sparkSession) + private val attributes = encoder.schema.toAttributes + protected val logicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) protected val output = logicalPlan.output /** @@ -61,11 +64,17 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) * Stored in a ListBuffer to facilitate removing committed batches. */ @GuardedBy("this") - protected val batches = new ListBuffer[Dataset[A]] + protected val batches = new ListBuffer[Array[UnsafeRow]] @GuardedBy("this") protected var currentOffset: LongOffset = new LongOffset(-1) + @GuardedBy("this") + private var startOffset = new LongOffset(-1) + + @GuardedBy("this") + private var endOffset = new LongOffset(-1) + /** * Last offset that was discarded, or -1 if no commits have occurred. Note that the value * -1 is used in calculations below and isn't just an arbitrary constant. @@ -73,8 +82,6 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) @GuardedBy("this") protected var lastOffsetCommitted : LongOffset = new LongOffset(-1) - def schema: StructType = encoder.schema - def toDS(): Dataset[A] = { Dataset(sqlContext.sparkSession, logicalPlan) } @@ -88,72 +95,69 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } def addData(data: TraversableOnce[A]): Offset = { - val encoded = data.toVector.map(d => encoder.toRow(d).copy()) - val plan = new LocalRelation(schema.toAttributes, encoded, isStreaming = true) - val ds = Dataset[A](sqlContext.sparkSession, plan) - logDebug(s"Adding ds: $ds") + val objects = data.toSeq + val rows = objects.iterator.map(d => encoder.toRow(d).copy().asInstanceOf[UnsafeRow]).toArray + logDebug(s"Adding: $objects") this.synchronized { currentOffset = currentOffset + 1 - batches += ds + batches += rows currentOffset } } override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]" - override def getOffset: Option[Offset] = synchronized { - if (currentOffset.offset == -1) { - None - } else { - Some(currentOffset) + override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { + synchronized { + startOffset = start.orElse(LongOffset(-1)).asInstanceOf[LongOffset] + endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset] } } - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) - val startOrdinal = - start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1 - val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1 - - // Internal buffer only holds the batches after lastCommittedOffset. - val newBlocks = synchronized { - val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 - val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 - assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd") - batches.slice(sliceStart, sliceEnd) - } + override def readSchema(): StructType = encoder.schema - if (newBlocks.isEmpty) { - return sqlContext.internalCreateDataFrame( - sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) - } + override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) + + override def getStartOffset: OffsetV2 = synchronized { + if (startOffset.offset == -1) null else startOffset + } - logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal)) + override def getEndOffset: OffsetV2 = synchronized { + if (endOffset.offset == -1) null else endOffset + } - newBlocks - .map(_.toDF()) - .reduceOption(_ union _) - .getOrElse { - sys.error("No data selected!") + override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + synchronized { + // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) + val startOrdinal = startOffset.offset.toInt + 1 + val endOrdinal = endOffset.offset.toInt + 1 + + // Internal buffer only holds the batches after lastCommittedOffset. + val newBlocks = synchronized { + val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 + val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 + assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd") + batches.slice(sliceStart, sliceEnd) } + + logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) + + newBlocks.map { block => + new MemoryStreamDataReaderFactory(block).asInstanceOf[DataReaderFactory[UnsafeRow]] + }.asJava + } } private def generateDebugString( - blocks: TraversableOnce[Dataset[A]], + rows: Seq[UnsafeRow], startOrdinal: Int, endOrdinal: Int): String = { - val originalUnsupportedCheck = - sqlContext.getConf("spark.sql.streaming.unsupportedOperationCheck") - try { - sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", "false") - s"MemoryBatch [$startOrdinal, $endOrdinal]: " + - s"${blocks.flatMap(_.collect()).mkString(", ")}" - } finally { - sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", originalUnsupportedCheck) - } + val fromRow = encoder.resolveAndBind().fromRow _ + s"MemoryBatch [$startOrdinal, $endOrdinal]: " + + s"${rows.map(row => fromRow(row)).mkString(", ")}" } - override def commit(end: Offset): Unit = synchronized { + override def commit(end: OffsetV2): Unit = synchronized { def check(newOffset: LongOffset): Unit = { val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt @@ -176,11 +180,33 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) def reset(): Unit = synchronized { batches.clear() + startOffset = LongOffset(-1) + endOffset = LongOffset(-1) currentOffset = new LongOffset(-1) lastOffsetCommitted = new LongOffset(-1) } } + +class MemoryStreamDataReaderFactory(records: Array[UnsafeRow]) + extends DataReaderFactory[UnsafeRow] { + override def createDataReader(): DataReader[UnsafeRow] = { + new DataReader[UnsafeRow] { + private var currentIndex = -1 + + override def next(): Boolean = { + // Return true as long as the new index is in the array. + currentIndex += 1 + currentIndex < records.length + } + + override def get(): UnsafeRow = records(currentIndex) + + override def close(): Unit = {} + } + } +} + /** * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index 1315885da8a6f..077a255946a6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -151,7 +151,7 @@ case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactor } class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] { - var currentIndex = -1 + private var currentIndex = -1 override def next(): Boolean = { // Return true as long as the new index is in the seq. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala index 41434e6d8b974..b249dd41a84a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala @@ -46,49 +46,34 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf .foreach(new TestForeachWriter()) .start() - // -- batch 0 --------------------------------------- - input.addData(1, 2, 3, 4) - query.processAllAvailable() + def verifyOutput(expectedVersion: Int, expectedData: Seq[Int]): Unit = { + import ForeachSinkSuite._ - var expectedEventsForPartition0 = Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Process(value = 2), - ForeachSinkSuite.Process(value = 3), - ForeachSinkSuite.Close(None) - ) - var expectedEventsForPartition1 = Seq( - ForeachSinkSuite.Open(partition = 1, version = 0), - ForeachSinkSuite.Process(value = 1), - ForeachSinkSuite.Process(value = 4), - ForeachSinkSuite.Close(None) - ) + val events = ForeachSinkSuite.allEvents() + assert(events.size === 2) // one seq of events for each of the 2 partitions - var allEvents = ForeachSinkSuite.allEvents() - assert(allEvents.size === 2) - assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1)) + // Verify both seq of events have an Open event as the first event + assert(events.map(_.head).toSet === Set(0, 1).map(p => Open(p, expectedVersion))) + + // Verify all the Process event correspond to the expected data + val allProcessEvents = events.flatMap(_.filter(_.isInstanceOf[Process[_]])) + assert(allProcessEvents.toSet === expectedData.map { data => Process(data) }.toSet) + + // Verify both seq of events have a Close event as the last event + assert(events.map(_.last).toSet === Set(Close(None), Close(None))) + } + // -- batch 0 --------------------------------------- ForeachSinkSuite.clear() + input.addData(1, 2, 3, 4) + query.processAllAvailable() + verifyOutput(expectedVersion = 0, expectedData = 1 to 4) // -- batch 1 --------------------------------------- + ForeachSinkSuite.clear() input.addData(5, 6, 7, 8) query.processAllAvailable() - - expectedEventsForPartition0 = Seq( - ForeachSinkSuite.Open(partition = 0, version = 1), - ForeachSinkSuite.Process(value = 5), - ForeachSinkSuite.Process(value = 7), - ForeachSinkSuite.Close(None) - ) - expectedEventsForPartition1 = Seq( - ForeachSinkSuite.Open(partition = 1, version = 1), - ForeachSinkSuite.Process(value = 6), - ForeachSinkSuite.Process(value = 8), - ForeachSinkSuite.Close(None) - ) - - allEvents = ForeachSinkSuite.allEvents() - assert(allEvents.size === 2) - assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1)) + verifyOutput(expectedVersion = 1, expectedData = 5 to 8) query.stop() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index c65e5d3dd75c2..d1a04833390f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -492,16 +492,16 @@ class StreamSuite extends StreamTest { val explainWithoutExtended = q.explainInternal(false) // `extended = false` only displays the physical plan. - assert("LocalRelation".r.findAllMatchIn(explainWithoutExtended).size === 0) - assert("LocalTableScan".r.findAllMatchIn(explainWithoutExtended).size === 1) + assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) + assert("DataSourceV2Scan".r.findAllMatchIn(explainWithoutExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithoutExtended.contains("StateStoreRestore")) val explainWithExtended = q.explainInternal(true) // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical // plan. - assert("LocalRelation".r.findAllMatchIn(explainWithExtended).size === 3) - assert("LocalTableScan".r.findAllMatchIn(explainWithExtended).size === 1) + assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithExtended).size === 3) + assert("DataSourceV2Scan".r.findAllMatchIn(explainWithExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithExtended.contains("StateStoreRestore")) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index d6433562fb29b..37fe595529baf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -120,7 +120,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData { override def toString: String = s"AddData to $source: ${data.mkString(",")}" - override def addData(query: Option[StreamExecution]): (Source, Offset) = { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { (source, source.addData(data)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 79d65192a14aa..b96f2bcbdd644 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.sql.{Encoder, SparkSession} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.util.JsonProtocol @@ -298,9 +299,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { try { val input = new MemoryStream[Int](0, sqlContext) { @volatile var numTriggers = 0 - override def getOffset: Option[Offset] = { + override def getEndOffset: OffsetV2 = { numTriggers += 1 - super.getOffset + super.getEndOffset } } val clock = new StreamManualClock() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 76201c63a2701..3f9aa0d1fa5be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -17,25 +17,27 @@ package org.apache.spark.sql.streaming +import java.{util => ju} +import java.util.Optional import java.util.concurrent.CountDownLatch import org.apache.commons.lang3.RandomStringUtils -import org.mockito.Mockito._ import org.scalactic.TolerantNumerics import org.scalatest.BeforeAndAfter -import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.mockito.MockitoSugar import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType -import org.apache.spark.util.ManualClock class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging with MockitoSugar { @@ -206,19 +208,29 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi /** Custom MemoryStream that waits for manual clock to reach a time */ val inputData = new MemoryStream[Int](0, sqlContext) { - // getOffset should take 50 ms the first time it is called - override def getOffset: Option[Offset] = { - val offset = super.getOffset - if (offset.nonEmpty) { - clock.waitTillTime(1050) + + private def dataAdded: Boolean = currentOffset.offset != -1 + + // setOffsetRange should take 50 ms the first time it is called after data is added + override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { + synchronized { + if (dataAdded) clock.waitTillTime(1050) + super.setOffsetRange(start, end) } - offset + } + + // getEndOffset should take 100 ms the first time it is called after data is added + override def getEndOffset(): OffsetV2 = synchronized { + if (dataAdded) clock.waitTillTime(1150) + super.getEndOffset() } // getBatch should take 100 ms the first time it is called - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - if (start.isEmpty) clock.waitTillTime(1150) - super.getBatch(start, end) + override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + synchronized { + clock.waitTillTime(1350) + super.createUnsafeRowReaderFactories() + } } } @@ -258,39 +270,44 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.status.message === "Waiting for next trigger"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - // Test status and progress while offset is being fetched + // Test status and progress when setOffsetRange is being called AddData(inputData, 1, 2), - AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on getOffset + AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on setOffsetRange AssertStreamExecThreadIsWaitingForTime(1050), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message.startsWith("Getting offsets from")), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - // Test status and progress while batch is being fetched - AdvanceManualClock(50), // time = 1050 to unblock getOffset + AdvanceManualClock(50), // time = 1050 to unblock setOffsetRange AssertClockTime(1050), - AssertStreamExecThreadIsWaitingForTime(1150), // will block on getBatch that needs 1150 + AssertStreamExecThreadIsWaitingForTime(1150), // will block on getEndOffset that needs 1150 + AssertOnQuery(_.status.isDataAvailable === false), + AssertOnQuery(_.status.isTriggerActive === true), + AssertOnQuery(_.status.message.startsWith("Getting offsets from")), + AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), + + AdvanceManualClock(100), // time = 1150 to unblock getEndOffset + AssertClockTime(1150), + AssertStreamExecThreadIsWaitingForTime(1350), // will block on createReadTasks that needs 1350 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - // Test status and progress while batch is being processed - AdvanceManualClock(100), // time = 1150 to unblock getBatch - AssertClockTime(1150), - AssertStreamExecThreadIsWaitingForTime(1500), // will block in Spark job that needs 1500 + AdvanceManualClock(200), // time = 1350 to unblock createReadTasks + AssertClockTime(1350), + AssertStreamExecThreadIsWaitingForTime(1500), // will block on map task that needs 1500 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch processing has completed - AssertOnQuery { _ => clock.getTimeMillis() === 1150 }, - AdvanceManualClock(350), // time = 1500 to unblock job + AdvanceManualClock(150), // time = 1500 to unblock map task AssertClockTime(1500), CheckAnswer(2), - AssertStreamExecThreadIsWaitingForTime(2000), + AssertStreamExecThreadIsWaitingForTime(2000), // will block until the next trigger AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === false), AssertOnQuery(_.status.message === "Waiting for next trigger"), @@ -307,10 +324,11 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.numInputRows === 2) assert(progress.processedRowsPerSecond === 4.0) - assert(progress.durationMs.get("getOffset") === 50) - assert(progress.durationMs.get("getBatch") === 100) + assert(progress.durationMs.get("setOffsetRange") === 50) + assert(progress.durationMs.get("getEndOffset") === 100) assert(progress.durationMs.get("queryPlanning") === 0) assert(progress.durationMs.get("walCommit") === 0) + assert(progress.durationMs.get("addBatch") === 350) assert(progress.durationMs.get("triggerExecution") === 500) assert(progress.sources.length === 1) From a62f30d3fa032ff75bc2b7bebbd0813e67ea5fd5 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 8 Feb 2018 12:46:10 +0900 Subject: [PATCH 0304/2461] [SPARK-23319][TESTS][FOLLOWUP] Fix a test for Python 3 without pandas. ## What changes were proposed in this pull request? This is a followup pr of #20487. When importing module but it doesn't exists, the error message is slightly different between Python 2 and 3. E.g., in Python 2: ``` No module named pandas ``` in Python 3: ``` No module named 'pandas' ``` So, one test to check an import error fails in Python 3 without pandas. This pr fixes it. ## How was this patch tested? Tested manually in my local environment. Author: Takuya UESHIN Closes #20538 from ueshin/issues/SPARK-23319/fup1. --- python/pyspark/sql/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 58359b61dc83a..90ff084fed55e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2860,7 +2860,7 @@ def test_create_dataframe_required_pandas_not_found(self): with QuietTest(self.sc): with self.assertRaisesRegexp( ImportError, - '(Pandas >= .* must be installed|No module named pandas)'): + "(Pandas >= .* must be installed|No module named '?pandas'?)"): import pandas as pd from datetime import datetime pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], From 3473fda6dc77bdfd84b3de95d2082856ad4f8626 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 8 Feb 2018 12:21:18 +0800 Subject: [PATCH 0305/2461] Revert [SPARK-22279][SQL] Turn on spark.sql.hive.convertMetastoreOrc by default ## What changes were proposed in this pull request? This is to revert the changes made in https://github.com/apache/spark/pull/19499 , because this causes a regression. We should not ignore the table-specific compression conf when the Hive serde tables are converted to the data source tables. ## How was this patch tested? The existing tests. Author: gatorsmile Closes #20536 from gatorsmile/revert22279. --- .../src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index d9627eb9790eb..93f3f38e52aa9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -109,7 +109,7 @@ private[spark] object HiveUtils extends Logging { .doc("When set to true, the built-in ORC reader and writer are used to process " + "ORC tables created by using the HiveQL syntax, instead of Hive serde.") .booleanConf - .createWithDefault(true) + .createWithDefault(false) val HIVE_METASTORE_SHARED_PREFIXES = buildConf("spark.sql.hive.metastore.sharedPrefixes") .doc("A comma separated list of class prefixes that should be loaded using the classloader " + From 7f5f5fb1296275a38da0adfa05125dd8ebf729ff Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 8 Feb 2018 00:08:54 -0800 Subject: [PATCH 0306/2461] [SPARK-23348][SQL] append data using saveAsTable should adjust the data types ## What changes were proposed in this pull request? For inserting/appending data to an existing table, Spark should adjust the data types of the input query according to the table schema, or fail fast if it's uncastable. There are several ways to insert/append data: SQL API, `DataFrameWriter.insertInto`, `DataFrameWriter.saveAsTable`. The first 2 ways create `InsertIntoTable` plan, and the last way creates `CreateTable` plan. However, we only adjust input query data types for `InsertIntoTable`, and users may hit weird errors when appending data using `saveAsTable`. See the JIRA for the error case. This PR fixes this bug by adjusting data types for `CreateTable` too. ## How was this patch tested? new test. Author: Wenchen Fan Closes #20527 from cloud-fan/saveAsTable. --- .../sql/execution/datasources/rules.scala | 72 +++++++++++-------- .../sql/execution/command/DDLSuite.scala | 28 ++++++++ 2 files changed, 69 insertions(+), 31 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 5dbcf4a915cbf..5cc21eeaeaa94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -22,7 +22,7 @@ import java.util.Locale import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.command.DDLUtils @@ -178,7 +178,8 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi c.copy( tableDesc = existingTable, - query = Some(newQuery)) + query = Some(DDLPreprocessingUtils.castAndRenameQueryOutput( + newQuery, existingTable.schema.toAttributes, conf))) // Here we normalize partition, bucket and sort column names, w.r.t. the case sensitivity // config, and do various checks: @@ -316,7 +317,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi * table. It also does data type casting and field renaming, to make sure that the columns to be * inserted have the correct data type and fields have the correct names. */ -case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { +case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { private def preprocess( insert: InsertIntoTable, tblName: String, @@ -336,6 +337,8 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] wit s"including ${staticPartCols.size} partition column(s) having constant value(s).") } + val newQuery = DDLPreprocessingUtils.castAndRenameQueryOutput( + insert.query, expectedColumns, conf) if (normalizedPartSpec.nonEmpty) { if (normalizedPartSpec.size != partColNames.length) { throw new AnalysisException( @@ -346,37 +349,11 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] wit """.stripMargin) } - castAndRenameChildOutput(insert.copy(partition = normalizedPartSpec), expectedColumns) + insert.copy(query = newQuery, partition = normalizedPartSpec) } else { // All partition columns are dynamic because the InsertIntoTable command does // not explicitly specify partitioning columns. - castAndRenameChildOutput(insert, expectedColumns) - .copy(partition = partColNames.map(_ -> None).toMap) - } - } - - private def castAndRenameChildOutput( - insert: InsertIntoTable, - expectedOutput: Seq[Attribute]): InsertIntoTable = { - val newChildOutput = expectedOutput.zip(insert.query.output).map { - case (expected, actual) => - if (expected.dataType.sameType(actual.dataType) && - expected.name == actual.name && - expected.metadata == actual.metadata) { - actual - } else { - // Renaming is needed for handling the following cases like - // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 - // 2) Target tables have column metadata - Alias(cast(actual, expected.dataType), expected.name)( - explicitMetadata = Option(expected.metadata)) - } - } - - if (newChildOutput == insert.query.output) { - insert - } else { - insert.copy(query = Project(newChildOutput, insert.query)) + insert.copy(query = newQuery, partition = partColNames.map(_ -> None).toMap) } } @@ -491,3 +468,36 @@ object PreWriteCheck extends (LogicalPlan => Unit) { } } } + +object DDLPreprocessingUtils { + + /** + * Adjusts the name and data type of the input query output columns, to match the expectation. + */ + def castAndRenameQueryOutput( + query: LogicalPlan, + expectedOutput: Seq[Attribute], + conf: SQLConf): LogicalPlan = { + val newChildOutput = expectedOutput.zip(query.output).map { + case (expected, actual) => + if (expected.dataType.sameType(actual.dataType) && + expected.name == actual.name && + expected.metadata == actual.metadata) { + actual + } else { + // Renaming is needed for handling the following cases like + // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 + // 2) Target tables have column metadata + Alias( + Cast(actual, expected.dataType, Option(conf.sessionLocalTimeZone)), + expected.name)(explicitMetadata = Option(expected.metadata)) + } + } + + if (newChildOutput == query.output) { + query + } else { + Project(newChildOutput, query) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index ee3674ba17821..f76bfd2fda2b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -37,6 +37,8 @@ import org.apache.spark.util.Utils class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with BeforeAndAfterEach { + import testImplicits._ + override def afterEach(): Unit = { try { // drop all databases, tables and functions after each test @@ -132,6 +134,32 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo checkAnswer(spark.table("t"), Row(Row("a", 1)) :: Nil) } } + + // TODO: This test is copied from HiveDDLSuite, unify it later. + test("SPARK-23348: append data to data source table with saveAsTable") { + withTable("t", "t1") { + Seq(1 -> "a").toDF("i", "j").write.saveAsTable("t") + checkAnswer(spark.table("t"), Row(1, "a")) + + sql("INSERT INTO t SELECT 2, 'b'") + checkAnswer(spark.table("t"), Row(1, "a") :: Row(2, "b") :: Nil) + + Seq(3 -> "c").toDF("i", "j").write.mode("append").saveAsTable("t") + checkAnswer(spark.table("t"), Row(1, "a") :: Row(2, "b") :: Row(3, "c") :: Nil) + + Seq("c" -> 3).toDF("i", "j").write.mode("append").saveAsTable("t") + checkAnswer(spark.table("t"), Row(1, "a") :: Row(2, "b") :: Row(3, "c") + :: Row(null, "3") :: Nil) + + Seq(4 -> "d").toDF("i", "j").write.saveAsTable("t1") + + val e = intercept[AnalysisException] { + Seq(5 -> "e").toDF("i", "j").write.mode("append").format("json").saveAsTable("t1") + } + assert(e.message.contains("The format of the existing table default.t1 is " + + "`ParquetFileFormat`. It doesn't match the specified format `JsonFileFormat`.")) + } + } } abstract class DDLSuite extends QueryTest with SQLTestUtils { From a75f927173632eee1316879447cb62c8cf30ae37 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 8 Feb 2018 19:20:11 +0800 Subject: [PATCH 0307/2461] [SPARK-23268][SQL][FOLLOWUP] Reorganize packages in data source V2 ## What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/20435. While reorganizing the packages for streaming data source v2, the top level stream read/write support interfaces should not be in the reader/writer package, but should be in the `sources.v2` package, to follow the `ReadSupport`, `WriteSupport`, etc. ## How was this patch tested? N/A Author: Wenchen Fan Closes #20509 from cloud-fan/followup. --- .../org/apache/spark/sql/kafka010/KafkaSourceProvider.scala | 4 +--- .../sql/sources/v2/{reader => }/ContinuousReadSupport.java | 4 +--- .../sql/sources/v2/{reader => }/MicroBatchReadSupport.java | 4 +--- .../sql/sources/v2/{writer => }/StreamWriteSupport.java | 5 ++--- .../apache/spark/sql/sources/v2/writer/DataSourceWriter.java | 1 + .../spark/sql/execution/streaming/MicroBatchExecution.scala | 5 ++--- .../spark/sql/execution/streaming/RateSourceProvider.scala | 1 - .../spark/sql/execution/streaming/StreamingRelation.scala | 3 +-- .../org/apache/spark/sql/execution/streaming/console.scala | 3 +-- .../execution/streaming/continuous/ContinuousExecution.scala | 4 +--- .../sql/execution/streaming/sources/RateStreamSourceV2.scala | 2 +- .../spark/sql/execution/streaming/sources/memoryV2.scala | 2 +- .../org/apache/spark/sql/streaming/DataStreamReader.scala | 3 +-- .../org/apache/spark/sql/streaming/DataStreamWriter.scala | 2 +- .../apache/spark/sql/streaming/StreamingQueryManager.scala | 2 +- .../spark/sql/execution/streaming/RateSourceV2Suite.scala | 2 +- .../sql/streaming/sources/StreamingDataSourceV2Suite.scala | 5 ++--- 17 files changed, 19 insertions(+), 33 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{reader => }/ContinuousReadSupport.java (92%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{reader => }/MicroBatchReadSupport.java (93%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{writer => }/StreamWriteSupport.java (93%) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 694ca76e24964..d4fa0359c12d6 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -30,9 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport -import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java similarity index 92% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java index 0c1d5d1a9577a..7df5a451ae5f3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java @@ -15,13 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader; import org.apache.spark.sql.types.StructType; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java similarity index 93% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java index 5e8f0c0dafdcf..209ffa7a0b9fa 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java @@ -15,13 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader; import org.apache.spark.sql.types.StructType; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java similarity index 93% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java index 1c0e2e12f8d51..a77b01497269e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java @@ -15,12 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.writer; +package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index 52324b3792b8a..e3f682bf96a66 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -21,6 +21,7 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.SaveMode; import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.StreamWriteSupport; import org.apache.spark.sql.sources.v2.WriteSupport; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 045d2b4b9569c..812533313332e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -29,10 +29,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.MicroBatchReadSupport +import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} -import org.apache.spark.sql.sources.v2.writer.{StreamWriteSupport, SupportsWriteInternalRow} +import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index ce5e63f5bde85..649fbbfa184ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types._ import org.apache.spark.util.{ManualClock, SystemClock} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 845c8d2c14e43..7146190645b37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -25,8 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.sources.v2.DataSourceV2 -import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2} object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index db600866067bc..cfba1001c6de0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} -import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index ed22b9100497a..c3294d64b10cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -31,10 +31,8 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index 077a255946a6b..4e2459bb05bd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 3411edbc53412..f960208155e3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.Sink -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 116ac3da07b75..f23851655350a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -28,8 +28,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 9aac360fd4bbc..2fc903168cfa0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.StreamWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index ddb1edc433d5a..7cefd03e43bc3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger} import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.StreamWriteSupport import org.apache.spark.util.{Clock, SystemClock, Utils} /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index 0d68d9c3138aa..983ba1668f58f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -26,8 +26,8 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 51f44fa6285e4..af4618bed5456 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -25,10 +25,9 @@ import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, Streami import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.{ContinuousReadSupport, DataReaderFactory, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} import org.apache.spark.sql.types.StructType From 76e019d9bdcdca176c79c1cd71ddbf496333bf93 Mon Sep 17 00:00:00 2001 From: liuxian Date: Thu, 8 Feb 2018 23:41:30 +0800 Subject: [PATCH 0308/2461] [SPARK-21860][CORE] Improve memory reuse for heap memory in `HeapMemoryAllocator` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In `HeapMemoryAllocator`, when allocating memory from pool, and the key of pool is memory size. Actually some size of memory ,such as 1025bytes,1026bytes,......1032bytes, we can think they are the same,because we allocate memory in multiples of 8 bytes. In this case, we can improve memory reuse. ## How was this patch tested? Existing tests and added unit tests Author: liuxian Closes #19077 from 10110346/headmemoptimize. --- .../unsafe/memory/HeapMemoryAllocator.java | 18 +++++++++------ .../spark/unsafe/PlatformUtilSuite.java | 22 +++++++++++++++++++ 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index a9603c1aba051..2733760dd19ef 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -46,9 +46,12 @@ private boolean shouldPool(long size) { @Override public MemoryBlock allocate(long size) throws OutOfMemoryError { - if (shouldPool(size)) { + int numWords = (int) ((size + 7) / 8); + long alignedSize = numWords * 8L; + assert (alignedSize >= size); + if (shouldPool(alignedSize)) { synchronized (this) { - final LinkedList> pool = bufferPoolsBySize.get(size); + final LinkedList> pool = bufferPoolsBySize.get(alignedSize); if (pool != null) { while (!pool.isEmpty()) { final WeakReference arrayReference = pool.pop(); @@ -62,11 +65,11 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { return memory; } } - bufferPoolsBySize.remove(size); + bufferPoolsBySize.remove(alignedSize); } } } - long[] array = new long[(int) ((size + 7) / 8)]; + long[] array = new long[numWords]; MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); @@ -98,12 +101,13 @@ public void free(MemoryBlock memory) { long[] array = (long[]) memory.obj; memory.setObjAndOffset(null, 0); - if (shouldPool(size)) { + long alignedSize = ((size + 7) / 8) * 8; + if (shouldPool(alignedSize)) { synchronized (this) { - LinkedList> pool = bufferPoolsBySize.get(size); + LinkedList> pool = bufferPoolsBySize.get(alignedSize); if (pool == null) { pool = new LinkedList<>(); - bufferPoolsBySize.put(size, pool); + bufferPoolsBySize.put(alignedSize, pool); } pool.add(new WeakReference<>(array)); } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 62854837b05ed..71c53d35dcab8 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.unsafe; +import org.apache.spark.unsafe.memory.HeapMemoryAllocator; import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -134,4 +135,25 @@ public void memoryDebugFillEnabledInTest() { MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); MemoryAllocator.UNSAFE.free(offheap); } + + @Test + public void heapMemoryReuse() { + MemoryAllocator heapMem = new HeapMemoryAllocator(); + // The size is less than `HeapMemoryAllocator.POOLING_THRESHOLD_BYTES`,allocate new memory every time. + MemoryBlock onheap1 = heapMem.allocate(513); + Object obj1 = onheap1.getBaseObject(); + heapMem.free(onheap1); + MemoryBlock onheap2 = heapMem.allocate(514); + Assert.assertNotEquals(obj1, onheap2.getBaseObject()); + + // The size is greater than `HeapMemoryAllocator.POOLING_THRESHOLD_BYTES`, + // reuse the previous memory which has released. + MemoryBlock onheap3 = heapMem.allocate(1024 * 1024 + 1); + Assert.assertEquals(onheap3.size(), 1024 * 1024 + 1); + Object obj3 = onheap3.getBaseObject(); + heapMem.free(onheap3); + MemoryBlock onheap4 = heapMem.allocate(1024 * 1024 + 7); + Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7); + Assert.assertEquals(obj3, onheap4.getBaseObject()); + } } From 4df84c3f818aa536515729b442601e08c253ed35 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 8 Feb 2018 12:52:08 -0600 Subject: [PATCH 0309/2461] [SPARK-23336][BUILD] Upgrade snappy-java to 1.1.7.1 ## What changes were proposed in this pull request? This PR upgrade snappy-java from 1.1.2.6 to 1.1.7.1. 1.1.7.1 release notes: - Improved performance for big-endian architecture - The other performance improvement in [snappy-1.1.5](https://github.com/google/snappy/releases/tag/1.1.5) 1.1.4 release notes: - Fix a 1% performance regression when snappy is used in PIE executables. - Improve compression performance by 5%. - Improve decompression performance by 20%. More details: https://github.com/xerial/snappy-java/blob/master/Milestone.md ## How was this patch tested? manual tests Author: Yuming Wang Closes #20510 from wangyum/SPARK-23336. --- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- pom.xml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 48e54568e6fc6..99031384aa22e 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -182,7 +182,7 @@ slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snakeyaml-1.15.jar snappy-0.2.jar -snappy-java-1.1.2.6.jar +snappy-java-1.1.7.1.jar spire-macros_2.11-0.13.0.jar spire_2.11-0.13.0.jar stax-api-1.0-2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 1807a77900e52..cf8d2789b7ee9 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -183,7 +183,7 @@ slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snakeyaml-1.15.jar snappy-0.2.jar -snappy-java-1.1.2.6.jar +snappy-java-1.1.7.1.jar spire-macros_2.11-0.13.0.jar spire_2.11-0.13.0.jar stax-api-1.0-2.jar diff --git a/pom.xml b/pom.xml index d18831df1db6d..de949b94d676c 100644 --- a/pom.xml +++ b/pom.xml @@ -160,7 +160,7 @@ 1.9.13 2.6.7 2.6.7.1 - 1.1.2.6 + 1.1.7.1 1.1.2 1.2.0-incubating 1.10 From 8cbcc33876c773722163b2259644037bbb259bd1 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 9 Feb 2018 12:54:57 +0800 Subject: [PATCH 0310/2461] [SPARK-23186][SQL] Initialize DriverManager first before loading JDBC Drivers ## What changes were proposed in this pull request? Since some JDBC Drivers have class initialization code to call `DriverManager`, we need to initialize `DriverManager` first in order to avoid potential executor-side **deadlock** situations like the following (or [STORM-2527](https://issues.apache.org/jira/browse/STORM-2527)). ``` Thread 9587: (state = BLOCKED) - sun.reflect.NativeConstructorAccessorImpl.newInstance0(java.lang.reflect.Constructor, java.lang.Object[]) bci=0 (Compiled frame; information may be imprecise) - sun.reflect.NativeConstructorAccessorImpl.newInstance(java.lang.Object[]) bci=85, line=62 (Compiled frame) - sun.reflect.DelegatingConstructorAccessorImpl.newInstance(java.lang.Object[]) bci=5, line=45 (Compiled frame) - java.lang.reflect.Constructor.newInstance(java.lang.Object[]) bci=79, line=423 (Compiled frame) - java.lang.Class.newInstance() bci=138, line=442 (Compiled frame) - java.util.ServiceLoader$LazyIterator.nextService() bci=119, line=380 (Interpreted frame) - java.util.ServiceLoader$LazyIterator.next() bci=11, line=404 (Interpreted frame) - java.util.ServiceLoader$1.next() bci=37, line=480 (Interpreted frame) - java.sql.DriverManager$2.run() bci=21, line=603 (Interpreted frame) - java.sql.DriverManager$2.run() bci=1, line=583 (Interpreted frame) - java.security.AccessController.doPrivileged(java.security.PrivilegedAction) bci=0 (Compiled frame) - java.sql.DriverManager.loadInitialDrivers() bci=27, line=583 (Interpreted frame) - java.sql.DriverManager.() bci=32, line=101 (Interpreted frame) - org.apache.phoenix.mapreduce.util.ConnectionUtil.getConnection(java.lang.String, java.lang.Integer, java.lang.String, java.util.Properties) bci=12, line=98 (Interpreted frame) - org.apache.phoenix.mapreduce.util.ConnectionUtil.getInputConnection(org.apache.hadoop.conf.Configuration, java.util.Properties) bci=22, line=57 (Interpreted frame) - org.apache.phoenix.mapreduce.PhoenixInputFormat.getQueryPlan(org.apache.hadoop.mapreduce.JobContext, org.apache.hadoop.conf.Configuration) bci=61, line=116 (Interpreted frame) - org.apache.phoenix.mapreduce.PhoenixInputFormat.createRecordReader(org.apache.hadoop.mapreduce.InputSplit, org.apache.hadoop.mapreduce.TaskAttemptContext) bci=10, line=71 (Interpreted frame) - org.apache.spark.rdd.NewHadoopRDD$$anon$1.(org.apache.spark.rdd.NewHadoopRDD, org.apache.spark.Partition, org.apache.spark.TaskContext) bci=233, line=156 (Interpreted frame) Thread 9170: (state = BLOCKED) - org.apache.phoenix.jdbc.PhoenixDriver.() bci=35, line=125 (Interpreted frame) - sun.reflect.NativeConstructorAccessorImpl.newInstance0(java.lang.reflect.Constructor, java.lang.Object[]) bci=0 (Compiled frame) - sun.reflect.NativeConstructorAccessorImpl.newInstance(java.lang.Object[]) bci=85, line=62 (Compiled frame) - sun.reflect.DelegatingConstructorAccessorImpl.newInstance(java.lang.Object[]) bci=5, line=45 (Compiled frame) - java.lang.reflect.Constructor.newInstance(java.lang.Object[]) bci=79, line=423 (Compiled frame) - java.lang.Class.newInstance() bci=138, line=442 (Compiled frame) - org.apache.spark.sql.execution.datasources.jdbc.DriverRegistry$.register(java.lang.String) bci=89, line=46 (Interpreted frame) - org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$$anonfun$createConnectionFactory$2.apply() bci=7, line=53 (Interpreted frame) - org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$$anonfun$createConnectionFactory$2.apply() bci=1, line=52 (Interpreted frame) - org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD$$anon$1.(org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD, org.apache.spark.Partition, org.apache.spark.TaskContext) bci=81, line=347 (Interpreted frame) - org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD.compute(org.apache.spark.Partition, org.apache.spark.TaskContext) bci=7, line=339 (Interpreted frame) ``` ## How was this patch tested? N/A Author: Dongjoon Hyun Closes #20359 from dongjoon-hyun/SPARK-23186. --- .../sql/execution/datasources/jdbc/DriverRegistry.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala index 7a6c0f9fed2f9..1723596de1db2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala @@ -32,6 +32,13 @@ import org.apache.spark.util.Utils */ object DriverRegistry extends Logging { + /** + * Load DriverManager first to avoid any race condition between + * DriverManager static initialization block and specific driver class's + * static initialization block. e.g. PhoenixDriver + */ + DriverManager.getDrivers + private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty def register(className: String): Unit = { From 4b4ee2601079f12f8f410a38d2081793cbdedc14 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 9 Feb 2018 14:21:10 +0800 Subject: [PATCH 0311/2461] [SPARK-23328][PYTHON] Disallow default value None in na.replace/replace when 'to_replace' is not a dictionary ## What changes were proposed in this pull request? This PR proposes to disallow default value None when 'to_replace' is not a dictionary. It seems weird we set the default value of `value` to `None` and we ended up allowing the case as below: ```python >>> df.show() ``` ``` +----+------+-----+ | age|height| name| +----+------+-----+ | 10| 80|Alice| ... ``` ```python >>> df.na.replace('Alice').show() ``` ``` +----+------+----+ | age|height|name| +----+------+----+ | 10| 80|null| ... ``` **After** This PR targets to disallow the case above: ```python >>> df.na.replace('Alice').show() ``` ``` ... TypeError: value is required when to_replace is not a dictionary. ``` while we still allow when `to_replace` is a dictionary: ```python >>> df.na.replace({'Alice': None}).show() ``` ``` +----+------+----+ | age|height|name| +----+------+----+ | 10| 80|null| ... ``` ## How was this patch tested? Manually tested, tests were added in `python/pyspark/sql/tests.py` and doctests were fixed. Author: hyukjinkwon Closes #20499 from HyukjinKwon/SPARK-19454-followup. --- docs/sql-programming-guide.md | 1 + python/pyspark/__init__.py | 1 + python/pyspark/_globals.py | 70 +++++++++++++++++++++++++++++++++ python/pyspark/sql/dataframe.py | 26 +++++++++--- python/pyspark/sql/tests.py | 11 +++--- 5 files changed, 99 insertions(+), 10 deletions(-) create mode 100644 python/pyspark/_globals.py diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index a0e221b39cc34..eab4030ee25d2 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1929,6 +1929,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. All the arithmetic operations are affected by the change, ie. addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`) and positive module (`pmod`). - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible. + - In PySpark, `df.replace` does not allow to omit `value` when `to_replace` is not a dictionary. Previously, `value` could be omitted in the other cases and had `None` by default, which is counterintuitive and error prone. ## Upgrading From Spark SQL 2.1 to 2.2 diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 4d142c91629cc..58218918693ca 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -54,6 +54,7 @@ from pyspark.taskcontext import TaskContext from pyspark.profiler import Profiler, BasicProfiler from pyspark.version import __version__ +from pyspark._globals import _NoValue def since(version): diff --git a/python/pyspark/_globals.py b/python/pyspark/_globals.py new file mode 100644 index 0000000000000..8e6099db09963 --- /dev/null +++ b/python/pyspark/_globals.py @@ -0,0 +1,70 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Module defining global singleton classes. + +This module raises a RuntimeError if an attempt to reload it is made. In that +way the identities of the classes defined here are fixed and will remain so +even if pyspark itself is reloaded. In particular, a function like the following +will still work correctly after pyspark is reloaded: + + def foo(arg=pyspark._NoValue): + if arg is pyspark._NoValue: + ... + +See gh-7844 for a discussion of the reload problem that motivated this module. + +Note that this approach is taken after from NumPy. +""" + +__ALL__ = ['_NoValue'] + + +# Disallow reloading this module so as to preserve the identities of the +# classes defined here. +if '_is_loaded' in globals(): + raise RuntimeError('Reloading pyspark._globals is not allowed') +_is_loaded = True + + +class _NoValueType(object): + """Special keyword value. + + The instance of this class may be used as the default value assigned to a + deprecated keyword in order to check if it has been given a user defined + value. + + This class was copied from NumPy. + """ + __instance = None + + def __new__(cls): + # ensure that only one instance exists + if not cls.__instance: + cls.__instance = super(_NoValueType, cls).__new__(cls) + return cls.__instance + + # needed for python 2 to preserve identity through a pickle + def __reduce__(self): + return (self.__class__, ()) + + def __repr__(self): + return "" + + +_NoValue = _NoValueType() diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 8ec24db8717b2..faee870a2d2e2 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -27,7 +27,7 @@ import warnings -from pyspark import copy_func, since +from pyspark import copy_func, since, _NoValue from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \ UTF8Deserializer @@ -1532,7 +1532,7 @@ def fillna(self, value, subset=None): return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) @since(1.4) - def replace(self, to_replace, value=None, subset=None): + def replace(self, to_replace, value=_NoValue, subset=None): """Returns a new :class:`DataFrame` replacing a value with another value. :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are aliases of each other. @@ -1545,8 +1545,8 @@ def replace(self, to_replace, value=None, subset=None): :param to_replace: bool, int, long, float, string, list or dict. Value to be replaced. - If the value is a dict, then `value` is ignored and `to_replace` must be a - mapping between a value and a replacement. + If the value is a dict, then `value` is ignored or can be omitted, and `to_replace` + must be a mapping between a value and a replacement. :param value: bool, int, long, float, string, list or None. The replacement value must be a bool, int, long, float, string or None. If `value` is a list, `value` should be of the same length and type as `to_replace`. @@ -1577,6 +1577,16 @@ def replace(self, to_replace, value=None, subset=None): |null| null|null| +----+------+----+ + >>> df4.na.replace({'Alice': None}).show() + +----+------+----+ + | age|height|name| + +----+------+----+ + | 10| 80|null| + | 5| null| Bob| + |null| null| Tom| + |null| null|null| + +----+------+----+ + >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show() +----+------+----+ | age|height|name| @@ -1587,6 +1597,12 @@ def replace(self, to_replace, value=None, subset=None): |null| null|null| +----+------+----+ """ + if value is _NoValue: + if isinstance(to_replace, dict): + value = None + else: + raise TypeError("value argument is required when to_replace is not a dictionary.") + # Helper functions def all_of(types): """Given a type or tuple of types and a sequence of xs @@ -2047,7 +2063,7 @@ def fill(self, value, subset=None): fill.__doc__ = DataFrame.fillna.__doc__ - def replace(self, to_replace, value, subset=None): + def replace(self, to_replace, value=_NoValue, subset=None): return self.df.replace(to_replace, value, subset) replace.__doc__ = DataFrame.replace.__doc__ diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 90ff084fed55e..6ace16955000d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2243,11 +2243,6 @@ def test_replace(self): .replace(False, True).first()) self.assertTupleEqual(row, (True, True)) - # replace list while value is not given (default to None) - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first() - self.assertTupleEqual(row, (None, 10, 80.0)) - # replace string with None and then drop None rows row = self.spark.createDataFrame( [(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).dropna() @@ -2283,6 +2278,12 @@ def test_replace(self): self.spark.createDataFrame( [(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first() + with self.assertRaisesRegexp( + TypeError, + 'value argument is required when to_replace is not a dictionary.'): + self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first() + def test_capture_analysis_exception(self): self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc")) self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) From f77270b8811bbd8956d0c08fa556265d2c5ee20e Mon Sep 17 00:00:00 2001 From: liuxian Date: Fri, 9 Feb 2018 08:45:06 -0600 Subject: [PATCH 0312/2461] [SPARK-23358][CORE] When the number of partitions is greater than 2^28, it will result in an error result ## What changes were proposed in this pull request? In the `checkIndexAndDataFile`,the `blocks` is the ` Int` type, when it is greater than 2^28, `blocks*8` will overflow, and this will result in an error result. In fact, `blocks` is actually the number of partitions. ## How was this patch tested? Manual test Author: liuxian Closes #20544 from 10110346/overflow. --- .../org/apache/spark/shuffle/IndexShuffleBlockResolver.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index c5f3f6e2b42b6..d88b25cc7e258 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -84,7 +84,7 @@ private[spark] class IndexShuffleBlockResolver( */ private def checkIndexAndDataFile(index: File, data: File, blocks: Int): Array[Long] = { // the index file should have `block + 1` longs as offset. - if (index.length() != (blocks + 1) * 8) { + if (index.length() != (blocks + 1) * 8L) { return null } val lengths = new Array[Long](blocks) From 0fc26313f8071cdcb4ccd67bb1d6942983199d36 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 9 Feb 2018 08:46:27 -0600 Subject: [PATCH 0313/2461] [SPARK-21860][CORE][FOLLOWUP] fix java style error ## What changes were proposed in this pull request? #19077 introduced a Java style error (too long line). Quick fix. ## How was this patch tested? running `./dev/lint-java` Author: Marco Gaido Closes #20558 from mgaido91/SPARK-21860. --- .../test/java/org/apache/spark/unsafe/PlatformUtilSuite.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 71c53d35dcab8..3ad9ac7b4de9c 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -139,7 +139,8 @@ public void memoryDebugFillEnabledInTest() { @Test public void heapMemoryReuse() { MemoryAllocator heapMem = new HeapMemoryAllocator(); - // The size is less than `HeapMemoryAllocator.POOLING_THRESHOLD_BYTES`,allocate new memory every time. + // The size is less than `HeapMemoryAllocator.POOLING_THRESHOLD_BYTES`, + // allocate new memory every time. MemoryBlock onheap1 = heapMem.allocate(513); Object obj1 = onheap1.getBaseObject(); heapMem.free(onheap1); From 7f10cf83f311526737fc96d5bb8281d12e41932f Mon Sep 17 00:00:00 2001 From: Rob Vesse Date: Fri, 9 Feb 2018 11:21:20 -0800 Subject: [PATCH 0314/2461] [SPARK-16501][MESOS] Allow providing Mesos principal & secret via files This commit modifies the Mesos submission client to allow the principal and secret to be provided indirectly via files. The path to these files can be specified either via Spark configuration or via environment variable. Assuming these files are appropriately protected by FS/OS permissions this means we don't ever leak the actual values in process info like ps Environment variable specification is useful because it allows you to interpolate the location of this file when using per-user Mesos credentials. For some background as to why we have taken this approach I will briefly describe our set up. On our systems we provide each authorised user account with their own Mesos credentials to provide certain security and audit guarantees to our customers. These credentials are managed by a central Secret management service. In our `spark-env.sh` we determine the appropriate secret and principal files to use depending on the user who is invoking Spark hence the need to inject these via environment variables as well as by configuration properties. So we set these environment variables appropriately and our Spark read in the contents of those files to authenticate itself with Mesos. This is functionality we have been using it in production across multiple customer sites for some time. This has been in the field for around 18 months with no reported issues. These changes have been sufficient to meet our customer security and audit requirements. We have been building and deploying custom builds of Apache Spark with various minor tweaks like this which we are now looking to contribute back into the community in order that we can rely upon stock Apache Spark builds and stop maintaining our own internal fork. Author: Rob Vesse Closes #20167 from rvesse/SPARK-16501. --- docs/running-on-mesos.md | 40 ++++- .../cluster/mesos/MesosSchedulerUtils.scala | 55 ++++-- .../mesos/MesosSchedulerUtilsSuite.scala | 161 +++++++++++++++++- 3 files changed, 238 insertions(+), 18 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 2bb5ecf1b8509..8e58892e2689f 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -82,6 +82,27 @@ a Spark driver program configured to connect to Mesos. Alternatively, you can also install Spark in the same location in all the Mesos slaves, and configure `spark.mesos.executor.home` (defaults to SPARK_HOME) to point to that location. +## Authenticating to Mesos + +When Mesos Framework authentication is enabled it is necessary to provide a principal and secret by which to authenticate Spark to Mesos. Each Spark job will register with Mesos as a separate framework. + +Depending on your deployment environment you may wish to create a single set of framework credentials that are shared across all users or create framework credentials for each user. Creating and managing framework credentials should be done following the Mesos [Authentication documentation](http://mesos.apache.org/documentation/latest/authentication/). + +Framework credentials may be specified in a variety of ways depending on your deployment environment and security requirements. The most simple way is to specify the `spark.mesos.principal` and `spark.mesos.secret` values directly in your Spark configuration. Alternatively you may specify these values indirectly by instead specifying `spark.mesos.principal.file` and `spark.mesos.secret.file`, these settings point to files containing the principal and secret. These files must be plaintext files in UTF-8 encoding. Combined with appropriate file ownership and mode/ACLs this provides a more secure way to specify these credentials. + +Additionally if you prefer to use environment variables you can specify all of the above via environment variables instead, the environment variable names are simply the configuration settings uppercased with `.` replaced with `_` e.g. `SPARK_MESOS_PRINCIPAL`. + +### Credential Specification Preference Order + +Please note that if you specify multiple ways to obtain the credentials then the following preference order applies. Spark will use the first valid value found and any subsequent values are ignored: + +- `spark.mesos.principal` configuration setting +- `SPARK_MESOS_PRINCIPAL` environment variable +- `spark.mesos.principal.file` configuration setting +- `SPARK_MESOS_PRINCIPAL_FILE` environment variable + +An equivalent order applies for the secret. Essentially we prefer the configuration to be specified directly rather than indirectly by files, and we prefer that configuration settings are used over environment variables. + ## Uploading Spark Package When Mesos runs a task on a Mesos slave for the first time, that slave must have a Spark binary @@ -427,7 +448,14 @@ See the [configuration page](configuration.html) for information on Spark config
+ + + + + @@ -435,7 +463,15 @@ See the [configuration page](configuration.html) for information on Spark config + + + + + diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index e75450369ad85..ecbcc960fc5a0 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler.cluster.mesos +import java.io.File +import java.nio.charset.StandardCharsets import java.util.{List => JList} import java.util.concurrent.CountDownLatch @@ -25,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import com.google.common.base.Splitter +import com.google.common.io.Files import org.apache.mesos.{MesosSchedulerDriver, Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.Protos.FrameworkInfo.Capability @@ -71,26 +74,15 @@ trait MesosSchedulerUtils extends Logging { failoverTimeout: Option[Double] = None, frameworkId: Option[String] = None): SchedulerDriver = { val fwInfoBuilder = FrameworkInfo.newBuilder().setUser(sparkUser).setName(appName) - val credBuilder = Credential.newBuilder() + fwInfoBuilder.setHostname(Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse( + conf.get(DRIVER_HOST_ADDRESS))) webuiUrl.foreach { url => fwInfoBuilder.setWebuiUrl(url) } checkpoint.foreach { checkpoint => fwInfoBuilder.setCheckpoint(checkpoint) } failoverTimeout.foreach { timeout => fwInfoBuilder.setFailoverTimeout(timeout) } frameworkId.foreach { id => fwInfoBuilder.setId(FrameworkID.newBuilder().setValue(id).build()) } - fwInfoBuilder.setHostname(Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse( - conf.get(DRIVER_HOST_ADDRESS))) - conf.getOption("spark.mesos.principal").foreach { principal => - fwInfoBuilder.setPrincipal(principal) - credBuilder.setPrincipal(principal) - } - conf.getOption("spark.mesos.secret").foreach { secret => - credBuilder.setSecret(secret) - } - if (credBuilder.hasSecret && !fwInfoBuilder.hasPrincipal) { - throw new SparkException( - "spark.mesos.principal must be configured when spark.mesos.secret is set") - } + conf.getOption("spark.mesos.role").foreach { role => fwInfoBuilder.setRole(role) } @@ -98,6 +90,7 @@ trait MesosSchedulerUtils extends Logging { if (maxGpus > 0) { fwInfoBuilder.addCapabilities(Capability.newBuilder().setType(Capability.Type.GPU_RESOURCES)) } + val credBuilder = buildCredentials(conf, fwInfoBuilder) if (credBuilder.hasPrincipal) { new MesosSchedulerDriver( scheduler, fwInfoBuilder.build(), masterUrl, credBuilder.build()) @@ -106,6 +99,40 @@ trait MesosSchedulerUtils extends Logging { } } + def buildCredentials( + conf: SparkConf, + fwInfoBuilder: Protos.FrameworkInfo.Builder): Protos.Credential.Builder = { + val credBuilder = Credential.newBuilder() + conf.getOption("spark.mesos.principal") + .orElse(Option(conf.getenv("SPARK_MESOS_PRINCIPAL"))) + .orElse( + conf.getOption("spark.mesos.principal.file") + .orElse(Option(conf.getenv("SPARK_MESOS_PRINCIPAL_FILE"))) + .map { principalFile => + Files.toString(new File(principalFile), StandardCharsets.UTF_8) + } + ).foreach { principal => + fwInfoBuilder.setPrincipal(principal) + credBuilder.setPrincipal(principal) + } + conf.getOption("spark.mesos.secret") + .orElse(Option(conf.getenv("SPARK_MESOS_SECRET"))) + .orElse( + conf.getOption("spark.mesos.secret.file") + .orElse(Option(conf.getenv("SPARK_MESOS_SECRET_FILE"))) + .map { secretFile => + Files.toString(new File(secretFile), StandardCharsets.UTF_8) + } + ).foreach { secret => + credBuilder.setSecret(secret) + } + if (credBuilder.hasSecret && !fwInfoBuilder.hasPrincipal) { + throw new SparkException( + "spark.mesos.principal must be configured when spark.mesos.secret is set") + } + credBuilder + } + /** * Starts the MesosSchedulerDriver and stores the current running driver to this new instance. * This driver is expected to not be running. diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala index 7df738958f85c..8d90e1a8591ad 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala @@ -17,16 +17,20 @@ package org.apache.spark.scheduler.cluster.mesos +import java.io.{File, FileNotFoundException} + import scala.collection.JavaConverters._ import scala.language.reflectiveCalls -import org.apache.mesos.Protos.{Resource, Value} +import com.google.common.io.Files +import org.apache.mesos.Protos.{FrameworkInfo, Resource, Value} import org.mockito.Mockito._ import org.scalatest._ import org.scalatest.mockito.MockitoSugar -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.internal.config._ +import org.apache.spark.util.SparkConfWithEnv class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoSugar { @@ -237,4 +241,157 @@ class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoS val portsToUse = getRangesFromResources(resourcesToBeUsed).map{r => r._1} portsToUse.isEmpty shouldBe true } + + test("Principal specified via spark.mesos.principal") { + val conf = new SparkConf() + conf.set("spark.mesos.principal", "test-principal") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Principal specified via spark.mesos.principal.file") { + val pFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt"); + pFile.deleteOnExit() + Files.write("test-principal".getBytes("UTF-8"), pFile); + val conf = new SparkConf() + conf.set("spark.mesos.principal.file", pFile.getAbsolutePath()) + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Principal specified via spark.mesos.principal.file that does not exist") { + val conf = new SparkConf() + conf.set("spark.mesos.principal.file", "/tmp/does-not-exist") + + intercept[FileNotFoundException] { + utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + } + } + + test("Principal specified via SPARK_MESOS_PRINCIPAL") { + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL" -> "test-principal")) + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Principal specified via SPARK_MESOS_PRINCIPAL_FILE") { + val pFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt"); + pFile.deleteOnExit() + Files.write("test-principal".getBytes("UTF-8"), pFile); + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL_FILE" -> pFile.getAbsolutePath())) + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Principal specified via SPARK_MESOS_PRINCIPAL_FILE that does not exist") { + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL_FILE" -> "/tmp/does-not-exist")) + + intercept[FileNotFoundException] { + utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + } + } + + test("Secret specified via spark.mesos.secret") { + val conf = new SparkConf() + conf.set("spark.mesos.principal", "test-principal") + conf.set("spark.mesos.secret", "my-secret") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } + + test("Principal specified via spark.mesos.secret.file") { + val sFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt"); + sFile.deleteOnExit() + Files.write("my-secret".getBytes("UTF-8"), sFile); + val conf = new SparkConf() + conf.set("spark.mesos.principal", "test-principal") + conf.set("spark.mesos.secret.file", sFile.getAbsolutePath()) + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } + + test("Principal specified via spark.mesos.secret.file that does not exist") { + val conf = new SparkConf() + conf.set("spark.mesos.principal", "test-principal") + conf.set("spark.mesos.secret.file", "/tmp/does-not-exist") + + intercept[FileNotFoundException] { + utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + } + } + + test("Principal specified via SPARK_MESOS_SECRET") { + val env = Map("SPARK_MESOS_SECRET" -> "my-secret") + val conf = new SparkConfWithEnv(env) + conf.set("spark.mesos.principal", "test-principal") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } + + test("Principal specified via SPARK_MESOS_SECRET_FILE") { + val sFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt"); + sFile.deleteOnExit() + Files.write("my-secret".getBytes("UTF-8"), sFile); + + val sFilePath = sFile.getAbsolutePath() + val env = Map("SPARK_MESOS_SECRET_FILE" -> sFilePath) + val conf = new SparkConfWithEnv(env) + conf.set("spark.mesos.principal", "test-principal") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } + + test("Secret specified with no principal") { + val conf = new SparkConf() + conf.set("spark.mesos.secret", "my-secret") + + intercept[SparkException] { + utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + } + } + + test("Principal specification preference") { + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL" -> "other-principal")) + conf.set("spark.mesos.principal", "test-principal") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Secret specification preference") { + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_SECRET" -> "other-secret")) + conf.set("spark.mesos.principal", "test-principal") + conf.set("spark.mesos.secret", "my-secret") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } } From 557938e2839afce26a10a849a2a4be8fc4580427 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Fri, 9 Feb 2018 18:18:30 -0600 Subject: [PATCH 0315/2461] [MINOR][HIVE] Typo fixes ## What changes were proposed in this pull request? Typo fixes (with expanding a Hive property) ## How was this patch tested? local build. Awaiting Jenkins Author: Jacek Laskowski Closes #20550 from jaceklaskowski/hiveutils-typos. --- .../src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 93f3f38e52aa9..c448c5a9821be 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -304,7 +304,7 @@ private[spark] object HiveUtils extends Logging { throw new IllegalArgumentException( "Builtin jars can only be used when hive execution version == hive metastore version. " + s"Execution: $builtinHiveVersion != Metastore: $hiveMetastoreVersion. " + - "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + + s"Specify a valid path to the correct hive jars using ${HIVE_METASTORE_JARS.key} " + s"or change ${HIVE_METASTORE_VERSION.key} to $builtinHiveVersion.") } @@ -324,7 +324,7 @@ private[spark] object HiveUtils extends Logging { if (jars.length == 0) { throw new IllegalArgumentException( "Unable to locate hive jars to connect to metastore. " + - "Please set spark.sql.hive.metastore.jars.") + s"Please set ${HIVE_METASTORE_JARS.key}.") } logInfo( From 6d7c38330e68c7beb10f54eee8b4f607ee3c4136 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Fri, 9 Feb 2018 16:21:47 -0800 Subject: [PATCH 0316/2461] [SPARK-23275][SQL] fix the thread leaking in hive/tests ## What changes were proposed in this pull request? This is a follow up of https://github.com/apache/spark/pull/20441. The two lines actually can trigger the hive metastore bug: https://issues.apache.org/jira/browse/HIVE-16844 The two configs are not in the default `ObjectStore` properties, so any run hive commands after these two lines will set the `propsChanged` flag in the `ObjectStore.setConf` and then cause thread leaks. I don't think the two lines are very useful. They can be removed safely. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Feng Liu Closes #20562 from liufengdb/fix-omm. --- .../main/scala/org/apache/spark/sql/hive/test/TestHive.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 59708e7a0f2ff..19028939f3673 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -530,8 +530,6 @@ private[hive] class TestHiveSparkSession( // For some reason, RESET does not reset the following variables... // https://issues.apache.org/jira/browse/HIVE-9004 metadataHive.runSqlHive("set hive.table.parameters.default=") - metadataHive.runSqlHive("set datanucleus.cache.collections=true") - metadataHive.runSqlHive("set datanucleus.cache.collections.lazy=true") // Lots of tests fail if we do not change the partition whitelist from the default. metadataHive.runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") From 97a224a855c4410b2dfb9c0bcc6aae583bd28e92 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sun, 11 Feb 2018 01:08:02 +0900 Subject: [PATCH 0317/2461] [SPARK-23360][SQL][PYTHON] Get local timezone from environment via pytz, or dateutil. ## What changes were proposed in this pull request? Currently we use `tzlocal()` to get Python local timezone, but it sometimes causes unexpected behavior. I changed the way to get Python local timezone to use pytz if the timezone is specified in environment variable, or timezone file via dateutil . ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN Closes #20559 from ueshin/issues/SPARK-23360/master. --- python/pyspark/sql/tests.py | 28 ++++++++++++++++++++++++++++ python/pyspark/sql/types.py | 23 +++++++++++++++++++---- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6ace16955000d..1087c3fafdd16 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2868,6 +2868,34 @@ def test_create_dataframe_required_pandas_not_found(self): "d": [pd.Timestamp.now().date()]}) self.spark.createDataFrame(pdf) + # Regression test for SPARK-23360 + @unittest.skipIf(not _have_pandas, _pandas_requirement_message) + def test_create_dateframe_from_pandas_with_dst(self): + import pandas as pd + from datetime import datetime + + pdf = pd.DataFrame({'time': [datetime(2015, 10, 31, 22, 30)]}) + + df = self.spark.createDataFrame(pdf) + self.assertPandasEqual(pdf, df.toPandas()) + + orig_env_tz = os.environ.get('TZ', None) + orig_session_tz = self.spark.conf.get('spark.sql.session.timeZone') + try: + tz = 'America/Los_Angeles' + os.environ['TZ'] = tz + time.tzset() + self.spark.conf.set('spark.sql.session.timeZone', tz) + + df = self.spark.createDataFrame(pdf) + self.assertPandasEqual(pdf, df.toPandas()) + finally: + del os.environ['TZ'] + if orig_env_tz is not None: + os.environ['TZ'] = orig_env_tz + time.tzset() + self.spark.conf.set('spark.sql.session.timeZone', orig_session_tz) + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 093dae5a22e1f..2599dc5fdc599 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1709,6 +1709,21 @@ def _check_dataframe_convert_date(pdf, schema): return pdf +def _get_local_timezone(): + """ Get local timezone using pytz with environment variable, or dateutil. + + If there is a 'TZ' environment variable, pass it to pandas to use pytz and use it as timezone + string, otherwise use the special word 'dateutil/:' which means that pandas uses dateutil and + it reads system configuration to know the system local timezone. + + See also: + - https://github.com/pandas-dev/pandas/blob/0.19.x/pandas/tslib.pyx#L1753 + - https://github.com/dateutil/dateutil/blob/2.6.1/dateutil/tz/tz.py#L1338 + """ + import os + return os.environ.get('TZ', 'dateutil/:') + + def _check_dataframe_localize_timestamps(pdf, timezone): """ Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone @@ -1721,7 +1736,7 @@ def _check_dataframe_localize_timestamps(pdf, timezone): require_minimum_pandas_version() from pandas.api.types import is_datetime64tz_dtype - tz = timezone or 'tzlocal()' + tz = timezone or _get_local_timezone() for column, series in pdf.iteritems(): # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64tz_dtype(series.dtype): @@ -1744,7 +1759,7 @@ def _check_series_convert_timestamps_internal(s, timezone): from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64_dtype(s.dtype): - tz = timezone or 'tzlocal()' + tz = timezone or _get_local_timezone() return s.dt.tz_localize(tz).dt.tz_convert('UTC') elif is_datetime64tz_dtype(s.dtype): return s.dt.tz_convert('UTC') @@ -1766,8 +1781,8 @@ def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone): import pandas as pd from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype - from_tz = from_timezone or 'tzlocal()' - to_tz = to_timezone or 'tzlocal()' + from_tz = from_timezone or _get_local_timezone() + to_tz = to_timezone or _get_local_timezone() # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64tz_dtype(s.dtype): return s.dt.tz_convert(to_tz).dt.tz_localize(None) From 0783876c81f212e1422a1b7786c26e3ac8e84f9f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 10 Feb 2018 10:46:45 -0600 Subject: [PATCH 0318/2461] [SPARK-23344][PYTHON][ML] Add distanceMeasure param to KMeans ## What changes were proposed in this pull request? SPARK-22119 introduced a new parameter for KMeans, ie. `distanceMeasure`. The PR adds it also to the Python interface. ## How was this patch tested? added UTs Author: Marco Gaido Closes #20520 from mgaido91/SPARK-23344. --- python/pyspark/ml/clustering.py | 32 +++++++++++++++++++++++++++----- python/pyspark/ml/tests.py | 18 ++++++++++++++++++ 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 66fb00508522e..6448b76a0da88 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -403,17 +403,23 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol typeConverter=TypeConverters.toString) initSteps = Param(Params._dummy(), "initSteps", "The number of steps for k-means|| " + "initialization mode. Must be > 0.", typeConverter=TypeConverters.toInt) + distanceMeasure = Param(Params._dummy(), "distanceMeasure", "The distance measure. " + + "Supported options: 'euclidean' and 'cosine'.", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", k=2, - initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None): + initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, + distanceMeasure="euclidean"): """ __init__(self, featuresCol="features", predictionCol="prediction", k=2, \ - initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None) + initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, \ + distanceMeasure="euclidean") """ super(KMeans, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid) - self._setDefault(k=2, initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20) + self._setDefault(k=2, initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, + distanceMeasure="euclidean") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -423,10 +429,12 @@ def _create_model(self, java_model): @keyword_only @since("1.5.0") def setParams(self, featuresCol="features", predictionCol="prediction", k=2, - initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None): + initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, + distanceMeasure="euclidean"): """ setParams(self, featuresCol="features", predictionCol="prediction", k=2, \ - initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None) + initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, \ + distanceMeasure="euclidean") Sets params for KMeans. """ @@ -475,6 +483,20 @@ def getInitSteps(self): """ return self.getOrDefault(self.initSteps) + @since("2.4.0") + def setDistanceMeasure(self, value): + """ + Sets the value of :py:attr:`distanceMeasure`. + """ + return self._set(distanceMeasure=value) + + @since("2.4.0") + def getDistanceMeasure(self): + """ + Gets the value of `distanceMeasure` + """ + return self.getOrDefault(self.distanceMeasure) + class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 75d04785a0710..6d6737241e06e 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -418,6 +418,9 @@ def test_kmeans_param(self): self.assertEqual(algo.getK(), 10) algo.setInitSteps(10) self.assertEqual(algo.getInitSteps(), 10) + self.assertEqual(algo.getDistanceMeasure(), "euclidean") + algo.setDistanceMeasure("cosine") + self.assertEqual(algo.getDistanceMeasure(), "cosine") def test_hasseed(self): noSeedSpecd = TestParams() @@ -1620,6 +1623,21 @@ def test_kmeans_summary(self): self.assertEqual(s.k, 2) +class KMeansTests(SparkSessionTestCase): + + def test_kmeans_cosine_distance(self): + data = [(Vectors.dense([1.0, 1.0]),), (Vectors.dense([10.0, 10.0]),), + (Vectors.dense([1.0, 0.5]),), (Vectors.dense([10.0, 4.4]),), + (Vectors.dense([-1.0, 1.0]),), (Vectors.dense([-100.0, 90.0]),)] + df = self.spark.createDataFrame(data, ["features"]) + kmeans = KMeans(k=3, seed=1, distanceMeasure="cosine") + model = kmeans.fit(df) + result = model.transform(df).collect() + self.assertTrue(result[0].prediction == result[1].prediction) + self.assertTrue(result[2].prediction == result[3].prediction) + self.assertTrue(result[4].prediction == result[5].prediction) + + class OneVsRestTests(SparkSessionTestCase): def test_copy(self): From a34fce19bc0ee5a7e36c6ecba75d2aeb70fdcbc7 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Sun, 11 Feb 2018 17:31:35 +0900 Subject: [PATCH 0319/2461] [SPARK-23314][PYTHON] Add ambiguous=False when localizing tz-naive timestamps in Arrow codepath to deal with dst ## What changes were proposed in this pull request? When tz_localize a tz-naive timetamp, pandas will throw exception if the timestamp is during daylight saving time period, e.g., `2015-11-01 01:30:00`. This PR fixes this issue by setting `ambiguous=False` when calling tz_localize, which is the same default behavior of pytz. ## How was this patch tested? Add `test_timestamp_dst` Author: Li Jin Closes #20537 from icexelloss/SPARK-23314. --- python/pyspark/sql/tests.py | 39 +++++++++++++++++++++++++++++++++++++ python/pyspark/sql/types.py | 37 ++++++++++++++++++++++++++++++++--- 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1087c3fafdd16..4bc59fd99fca5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3670,6 +3670,21 @@ def test_createDataFrame_with_int_col_names(self): self.assertEqual(pdf_col_names, df.columns) self.assertEqual(pdf_col_names, df_arrow.columns) + # Regression test for SPARK-23314 + def test_timestamp_dst(self): + import pandas as pd + # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am + dt = [datetime.datetime(2015, 11, 1, 0, 30), + datetime.datetime(2015, 11, 1, 1, 30), + datetime.datetime(2015, 11, 1, 2, 30)] + pdf = pd.DataFrame({'time': dt}) + + df_from_python = self.spark.createDataFrame(dt, 'timestamp').toDF('time') + df_from_pandas = self.spark.createDataFrame(pdf) + + self.assertPandasEqual(pdf, df_from_python.toPandas()) + self.assertPandasEqual(pdf, df_from_pandas.toPandas()) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, @@ -4311,6 +4326,18 @@ def test_register_vectorized_udf_basic(self): self.assertEquals(expected.collect(), res1.collect()) self.assertEquals(expected.collect(), res2.collect()) + # Regression test for SPARK-23314 + def test_timestamp_dst(self): + from pyspark.sql.functions import pandas_udf + # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am + dt = [datetime.datetime(2015, 11, 1, 0, 30), + datetime.datetime(2015, 11, 1, 1, 30), + datetime.datetime(2015, 11, 1, 2, 30)] + df = self.spark.createDataFrame(dt, 'timestamp').toDF('time') + foo_udf = pandas_udf(lambda x: x, 'timestamp') + result = df.withColumn('time', foo_udf(df.time)) + self.assertEquals(df.collect(), result.collect()) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, @@ -4482,6 +4509,18 @@ def test_unsupported_types(self): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): df.groupby('id').apply(f).collect() + # Regression test for SPARK-23314 + def test_timestamp_dst(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am + dt = [datetime.datetime(2015, 11, 1, 0, 30), + datetime.datetime(2015, 11, 1, 1, 30), + datetime.datetime(2015, 11, 1, 2, 30)] + df = self.spark.createDataFrame(dt, 'timestamp').toDF('time') + foo_udf = pandas_udf(lambda pdf: pdf, 'time timestamp', PandasUDFType.GROUPED_MAP) + result = df.groupby('time').apply(foo_udf).sort('time') + self.assertPandasEqual(df.toPandas(), result.toPandas()) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 2599dc5fdc599..f7141b4549e4e 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1759,8 +1759,38 @@ def _check_series_convert_timestamps_internal(s, timezone): from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64_dtype(s.dtype): + # When tz_localize a tz-naive timestamp, the result is ambiguous if the tz-naive + # timestamp is during the hour when the clock is adjusted backward during due to + # daylight saving time (dst). + # E.g., for America/New_York, the clock is adjusted backward on 2015-11-01 2:00 to + # 2015-11-01 1:00 from dst-time to standard time, and therefore, when tz_localize + # a tz-naive timestamp 2015-11-01 1:30 with America/New_York timezone, it can be either + # dst time (2015-01-01 1:30-0400) or standard time (2015-11-01 1:30-0500). + # + # Here we explicit choose to use standard time. This matches the default behavior of + # pytz. + # + # Here are some code to help understand this behavior: + # >>> import datetime + # >>> import pandas as pd + # >>> import pytz + # >>> + # >>> t = datetime.datetime(2015, 11, 1, 1, 30) + # >>> ts = pd.Series([t]) + # >>> tz = pytz.timezone('America/New_York') + # >>> + # >>> ts.dt.tz_localize(tz, ambiguous=True) + # 0 2015-11-01 01:30:00-04:00 + # dtype: datetime64[ns, America/New_York] + # >>> + # >>> ts.dt.tz_localize(tz, ambiguous=False) + # 0 2015-11-01 01:30:00-05:00 + # dtype: datetime64[ns, America/New_York] + # >>> + # >>> str(tz.localize(t)) + # '2015-11-01 01:30:00-05:00' tz = timezone or _get_local_timezone() - return s.dt.tz_localize(tz).dt.tz_convert('UTC') + return s.dt.tz_localize(tz, ambiguous=False).dt.tz_convert('UTC') elif is_datetime64tz_dtype(s.dtype): return s.dt.tz_convert('UTC') else: @@ -1788,8 +1818,9 @@ def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone): return s.dt.tz_convert(to_tz).dt.tz_localize(None) elif is_datetime64_dtype(s.dtype) and from_tz != to_tz: # `s.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT. - return s.apply(lambda ts: ts.tz_localize(from_tz).tz_convert(to_tz).tz_localize(None) - if ts is not pd.NaT else pd.NaT) + return s.apply( + lambda ts: ts.tz_localize(from_tz, ambiguous=False).tz_convert(to_tz).tz_localize(None) + if ts is not pd.NaT else pd.NaT) else: return s From 8acb51f08b448628b65e90af3b268994f9550e45 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 11 Feb 2018 18:55:38 +0900 Subject: [PATCH 0320/2461] [SPARK-23084][PYTHON] Add unboundedPreceding(), unboundedFollowing() and currentRow() to PySpark ## What changes were proposed in this pull request? Added unboundedPreceding(), unboundedFollowing() and currentRow() to PySpark, also updated the rangeBetween API ## How was this patch tested? did unit test on my local. Please let me know if I need to add unit test in tests.py Author: Huaxin Gao Closes #20400 from huaxingao/spark_23084. --- python/pyspark/sql/functions.py | 30 ++++++++++++++ python/pyspark/sql/window.py | 70 ++++++++++++++++++++++++--------- 2 files changed, 82 insertions(+), 18 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 05031f5ec87d7..9bb9c323a5a60 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -809,6 +809,36 @@ def ntile(n): return Column(sc._jvm.functions.ntile(int(n))) +@since(2.4) +def unboundedPreceding(): + """ + Window function: returns the special frame boundary that represents the first row + in the window partition. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.unboundedPreceding()) + + +@since(2.4) +def unboundedFollowing(): + """ + Window function: returns the special frame boundary that represents the last row + in the window partition. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.unboundedFollowing()) + + +@since(2.4) +def currentRow(): + """ + Window function: returns the special frame boundary that represents the current row + in the window partition. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.currentRow()) + + # ---------------------- Date/Timestamp functions ------------------------------ @since(1.5) diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index 7ce27f9b102c0..bb841a9b9ff7c 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -16,9 +16,11 @@ # import sys +if sys.version >= '3': + long = int from pyspark import since, SparkContext -from pyspark.sql.column import _to_seq, _to_java_column +from pyspark.sql.column import Column, _to_seq, _to_java_column __all__ = ["Window", "WindowSpec"] @@ -120,20 +122,45 @@ def rangeBetween(start, end): and "5" means the five off after the current row. We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, - and ``Window.currentRow`` to specify special boundary values, rather than using integral - values directly. + ``Window.currentRow``, ``pyspark.sql.functions.unboundedPreceding``, + ``pyspark.sql.functions.unboundedFollowing`` and ``pyspark.sql.functions.currentRow`` + to specify special boundary values, rather than using integral values directly. :param start: boundary start, inclusive. - The frame is unbounded if this is ``Window.unboundedPreceding``, or + The frame is unbounded if this is ``Window.unboundedPreceding``, + a column returned by ``pyspark.sql.functions.unboundedPreceding``, or any value less than or equal to max(-sys.maxsize, -9223372036854775808). :param end: boundary end, inclusive. - The frame is unbounded if this is ``Window.unboundedFollowing``, or + The frame is unbounded if this is ``Window.unboundedFollowing``, + a column returned by ``pyspark.sql.functions.unboundedFollowing``, or any value greater than or equal to min(sys.maxsize, 9223372036854775807). + + >>> from pyspark.sql import functions as F, SparkSession, Window + >>> spark = SparkSession.builder.getOrCreate() + >>> df = spark.createDataFrame( + ... [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")], ["id", "category"]) + >>> window = Window.orderBy("id").partitionBy("category").rangeBetween( + ... F.currentRow(), F.lit(1)) + >>> df.withColumn("sum", F.sum("id").over(window)).show() + +---+--------+---+ + | id|category|sum| + +---+--------+---+ + | 1| b| 3| + | 2| b| 5| + | 3| b| 3| + | 1| a| 4| + | 1| a| 4| + | 2| a| 2| + +---+--------+---+ """ - if start <= Window._PRECEDING_THRESHOLD: - start = Window.unboundedPreceding - if end >= Window._FOLLOWING_THRESHOLD: - end = Window.unboundedFollowing + if isinstance(start, (int, long)) and isinstance(end, (int, long)): + if start <= Window._PRECEDING_THRESHOLD: + start = Window.unboundedPreceding + if end >= Window._FOLLOWING_THRESHOLD: + end = Window.unboundedFollowing + elif isinstance(start, Column) and isinstance(end, Column): + start = start._jc + end = end._jc sc = SparkContext._active_spark_context jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rangeBetween(start, end) return WindowSpec(jspec) @@ -208,27 +235,34 @@ def rangeBetween(self, start, end): and "5" means the five off after the current row. We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, - and ``Window.currentRow`` to specify special boundary values, rather than using integral - values directly. + ``Window.currentRow``, ``pyspark.sql.functions.unboundedPreceding``, + ``pyspark.sql.functions.unboundedFollowing`` and ``pyspark.sql.functions.currentRow`` + to specify special boundary values, rather than using integral values directly. :param start: boundary start, inclusive. - The frame is unbounded if this is ``Window.unboundedPreceding``, or + The frame is unbounded if this is ``Window.unboundedPreceding``, + a column returned by ``pyspark.sql.functions.unboundedPreceding``, or any value less than or equal to max(-sys.maxsize, -9223372036854775808). :param end: boundary end, inclusive. - The frame is unbounded if this is ``Window.unboundedFollowing``, or + The frame is unbounded if this is ``Window.unboundedFollowing``, + a column returned by ``pyspark.sql.functions.unboundedFollowing``, or any value greater than or equal to min(sys.maxsize, 9223372036854775807). """ - if start <= Window._PRECEDING_THRESHOLD: - start = Window.unboundedPreceding - if end >= Window._FOLLOWING_THRESHOLD: - end = Window.unboundedFollowing + if isinstance(start, (int, long)) and isinstance(end, (int, long)): + if start <= Window._PRECEDING_THRESHOLD: + start = Window.unboundedPreceding + if end >= Window._FOLLOWING_THRESHOLD: + end = Window.unboundedFollowing + elif isinstance(start, Column) and isinstance(end, Column): + start = start._jc + end = end._jc return WindowSpec(self._jspec.rangeBetween(start, end)) def _test(): import doctest SparkContext('local[4]', 'PythonTest') - (failure_count, test_count) = doctest.testmod() + (failure_count, test_count) = doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE) if failure_count: exit(-1) From eacb62fbbed317fd0e972102838af231385d54d8 Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Sun, 11 Feb 2018 19:23:15 +0900 Subject: [PATCH 0321/2461] [SPARK-22624][PYSPARK] Expose range partitioning shuffle introduced by spark-22614 ## What changes were proposed in this pull request? Expose range partitioning shuffle introduced by spark-22614 ## How was this patch tested? Unit test in dataframe.py Please review http://spark.apache.org/contributing.html before opening a pull request. Author: xubo245 <601450868@qq.com> Closes #20456 from xubo245/SPARK22624_PysparkRangePartition. --- python/pyspark/sql/dataframe.py | 45 +++++++++++++++++++++++++++++++++ python/pyspark/sql/tests.py | 28 ++++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index faee870a2d2e2..5cc8b63cdfadf 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -667,6 +667,51 @@ def repartition(self, numPartitions, *cols): else: raise TypeError("numPartitions should be an int or Column") + @since("2.4.0") + def repartitionByRange(self, numPartitions, *cols): + """ + Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The + resulting DataFrame is range partitioned. + + ``numPartitions`` can be an int to specify the target number of partitions or a Column. + If it is a Column, it will be used as the first partitioning column. If not specified, + the default number of partitions is used. + + At least one partition-by expression must be specified. + When no explicit sort order is specified, "ascending nulls first" is assumed. + + >>> df.repartitionByRange(2, "age").rdd.getNumPartitions() + 2 + >>> df.show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 5| Bob| + +---+-----+ + >>> df.repartitionByRange(1, "age").rdd.getNumPartitions() + 1 + >>> data = df.repartitionByRange("age") + >>> df.show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 5| Bob| + +---+-----+ + """ + if isinstance(numPartitions, int): + if len(cols) == 0: + return ValueError("At least one partition-by expression must be specified.") + else: + return DataFrame( + self._jdf.repartitionByRange(numPartitions, self._jcols(*cols)), self.sql_ctx) + elif isinstance(numPartitions, (basestring, Column)): + cols = (numPartitions,) + cols + return DataFrame(self._jdf.repartitionByRange(self._jcols(*cols)), self.sql_ctx) + else: + raise TypeError("numPartitions should be an int, string or Column") + @since(1.3) def distinct(self): """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4bc59fd99fca5..fe89bd0685027 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2148,6 +2148,34 @@ def test_expr(self): result = df.select(functions.expr("length(a)")).collect()[0].asDict() self.assertEqual(13, result["length(a)"]) + def test_repartitionByRange_dataframe(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True)]) + + df1 = self.spark.createDataFrame( + [(u'Bob', 27, 66.0), (u'Alice', 10, 10.0), (u'Bob', 10, 66.0)], schema) + df2 = self.spark.createDataFrame( + [(u'Alice', 10, 10.0), (u'Bob', 10, 66.0), (u'Bob', 27, 66.0)], schema) + + # test repartitionByRange(numPartitions, *cols) + df3 = df1.repartitionByRange(2, "name", "age") + self.assertEqual(df3.rdd.getNumPartitions(), 2) + self.assertEqual(df3.rdd.first(), df2.rdd.first()) + self.assertEqual(df3.rdd.take(3), df2.rdd.take(3)) + + # test repartitionByRange(numPartitions, *cols) + df4 = df1.repartitionByRange(3, "name", "age") + self.assertEqual(df4.rdd.getNumPartitions(), 3) + self.assertEqual(df4.rdd.first(), df2.rdd.first()) + self.assertEqual(df4.rdd.take(3), df2.rdd.take(3)) + + # test repartitionByRange(*cols) + df5 = df1.repartitionByRange("name", "age") + self.assertEqual(df5.rdd.first(), df2.rdd.first()) + self.assertEqual(df5.rdd.take(3), df2.rdd.take(3)) + def test_replace(self): schema = StructType([ StructField("name", StringType(), True), From 4bbd7443ebb005f81ed6bc39849940ac8db3b3cc Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 12 Feb 2018 00:03:49 +0800 Subject: [PATCH 0322/2461] [SPARK-23376][SQL] creating UnsafeKVExternalSorter with BytesToBytesMap may fail ## What changes were proposed in this pull request? This is a long-standing bug in `UnsafeKVExternalSorter` and was reported in the dev list multiple times. When creating `UnsafeKVExternalSorter` with `BytesToBytesMap`, we need to create a `UnsafeInMemorySorter` to sort the data in `BytesToBytesMap`. The data format of the sorter and the map is same, so no data movement is required. However, both the sorter and the map need a point array for some bookkeeping work. There is an optimization in `UnsafeKVExternalSorter`: reuse the point array between the sorter and the map, to avoid an extra memory allocation. This sounds like a reasonable optimization, the length of the `BytesToBytesMap` point array is at least 4 times larger than the number of keys(to avoid hash collision, the hash table size should be at least 2 times larger than the number of keys, and each key occupies 2 slots). `UnsafeInMemorySorter` needs the pointer array size to be 4 times of the number of entries, so we are safe to reuse the point array. However, the number of keys of the map doesn't equal to the number of entries in the map, because `BytesToBytesMap` supports duplicated keys. This breaks the assumption of the above optimization and we may run out of space when inserting data into the sorter, and hit error ``` java.lang.IllegalStateException: There is no space for new record at org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.insertRecord(UnsafeInMemorySorter.java:239) at org.apache.spark.sql.execution.UnsafeKVExternalSorter.(UnsafeKVExternalSorter.java:149) ... ``` This PR fixes this bug by creating a new point array if the existing one is not big enough. ## How was this patch tested? a new test Author: Wenchen Fan Closes #20561 from cloud-fan/bug. --- .../sql/execution/UnsafeKVExternalSorter.java | 31 +++++++++++---- .../UnsafeKVExternalSorterSuite.scala | 39 +++++++++++++++++++ 2 files changed, 62 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index b0b5383a081a0..9eb03430a7db2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -34,6 +34,7 @@ import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.KVIterator; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.unsafe.sort.*; @@ -98,19 +99,33 @@ public UnsafeKVExternalSorter( numElementsForSpillThreshold, canUseRadixSort); } else { - // The array will be used to do in-place sort, which require half of the space to be empty. - // Note: each record in the map takes two entries in the array, one is record pointer, - // another is the key prefix. - assert(map.numKeys() * 2 <= map.getArray().size() / 2); - // During spilling, the array in map will not be used, so we can borrow that and use it - // as the underlying array for in-memory sorter (it's always large enough). - // Since we will not grow the array, it's fine to pass `null` as consumer. + // During spilling, the pointer array in `BytesToBytesMap` will not be used, so we can borrow + // that and use it as the pointer array for `UnsafeInMemorySorter`. + LongArray pointerArray = map.getArray(); + // `BytesToBytesMap`'s pointer array is only guaranteed to hold all the distinct keys, but + // `UnsafeInMemorySorter`'s pointer array need to hold all the entries. Since + // `BytesToBytesMap` can have duplicated keys, here we need a check to make sure the pointer + // array can hold all the entries in `BytesToBytesMap`. + // The pointer array will be used to do in-place sort, which requires half of the space to be + // empty. Note: each record in the map takes two entries in the pointer array, one is record + // pointer, another is key prefix. So the required size of pointer array is `numRecords * 4`. + // TODO: It's possible to change UnsafeInMemorySorter to have multiple entries with same key, + // so that we can always reuse the pointer array. + if (map.numValues() > pointerArray.size() / 4) { + // Here we ask the map to allocate memory, so that the memory manager won't ask the map + // to spill, if the memory is not enough. + pointerArray = map.allocateArray(map.numValues() * 4L); + } + + // Since the pointer array(either reuse the one in the map, or create a new one) is guaranteed + // to be large enough, it's fine to pass `null` as consumer because we won't allocate more + // memory. final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( null, taskMemoryManager, comparatorSupplier.get(), prefixComparator, - map.getArray(), + pointerArray, canUseRadixSort); // We cannot use the destructive iterator here because we are reusing the existing memory diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 6af9f8b77f8d3..bf588d3bb7841 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.map.BytesToBytesMap /** * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data. @@ -205,4 +206,42 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { spill = true ) } + + test("SPARK-23376: Create UnsafeKVExternalSorter with BytesToByteMap having duplicated keys") { + val memoryManager = new TestMemoryManager(new SparkConf()) + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + val map = new BytesToBytesMap(taskMemoryManager, 64, taskMemoryManager.pageSizeBytes()) + + // Key/value are a unsafe rows with a single int column + val schema = new StructType().add("i", IntegerType) + val key = new UnsafeRow(1) + key.pointTo(new Array[Byte](32), 32) + key.setInt(0, 1) + val value = new UnsafeRow(1) + value.pointTo(new Array[Byte](32), 32) + value.setInt(0, 2) + + for (_ <- 1 to 65) { + val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) + loc.append( + key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) + } + + // Make sure we can successfully create a UnsafeKVExternalSorter with a `BytesToBytesMap` + // which has duplicated keys and the number of entries exceeds its capacity. + try { + TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, null, null)) + new UnsafeKVExternalSorter( + schema, + schema, + sparkContext.env.blockManager, + sparkContext.env.serializerManager, + taskMemoryManager.pageSizeBytes(), + Int.MaxValue, + map) + } finally { + TaskContext.unset() + } + } } From c0c902aedcf9ed24e482d873d766a7df63b964cb Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 11 Feb 2018 20:15:30 -0600 Subject: [PATCH 0323/2461] [SPARK-22119][FOLLOWUP][ML] Use spherical KMeans with cosine distance ## What changes were proposed in this pull request? In #19340 some comments considered needed to use spherical KMeans when cosine distance measure is specified, as Matlab does; instead of the implementation based on the behavior of other tools/libraries like Rapidminer, nltk and ELKI, ie. the centroids are computed as the mean of all the points in the clusters. The PR introduce the approach used in spherical KMeans. This behavior has the nice feature to minimize the within-cluster cosine distance. ## How was this patch tested? existing/improved UTs Author: Marco Gaido Closes #20518 from mgaido91/SPARK-22119_followup. --- .../spark/mllib/clustering/KMeans.scala | 54 ++++++++++++++++--- .../spark/ml/clustering/KMeansSuite.scala | 15 +++++- 2 files changed, 62 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 607145cb59fba..3c4ba0bc60c7f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -310,8 +310,7 @@ class KMeans private ( points.foreach { point => val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point) costAccum.add(cost) - val sum = sums(bestCenter) - axpy(1.0, point.vector, sum) + distanceMeasureInstance.updateClusterSum(point, sums(bestCenter)) counts(bestCenter) += 1 } @@ -319,10 +318,9 @@ class KMeans private ( }.reduceByKey { case ((sum1, count1), (sum2, count2)) => axpy(1.0, sum2, sum1) (sum1, count1 + count2) - }.mapValues { case (sum, count) => - scal(1.0 / count, sum) - new VectorWithNorm(sum) - }.collectAsMap() + }.collectAsMap().mapValues { case (sum, count) => + distanceMeasureInstance.centroid(sum, count) + } bcCenters.destroy(blocking = false) @@ -657,6 +655,26 @@ private[spark] abstract class DistanceMeasure extends Serializable { v1: VectorWithNorm, v2: VectorWithNorm): Double + /** + * Updates the value of `sum` adding the `point` vector. + * @param point a `VectorWithNorm` to be added to `sum` of a cluster + * @param sum the `sum` for a cluster to be updated + */ + def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { + axpy(1.0, point.vector, sum) + } + + /** + * Returns a centroid for a cluster given its `sum` vector and its `count` of points. + * + * @param sum the `sum` for a cluster + * @param count the number of points in the cluster + * @return the centroid of the cluster + */ + def centroid(sum: Vector, count: Long): VectorWithNorm = { + scal(1.0 / count, sum) + new VectorWithNorm(sum) + } } @Since("2.4.0") @@ -743,6 +761,30 @@ private[spark] class CosineDistanceMeasure extends DistanceMeasure { * @return the cosine distance between the two input vectors */ override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { + assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.") 1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm } + + /** + * Updates the value of `sum` adding the `point` vector. + * @param point a `VectorWithNorm` to be added to `sum` of a cluster + * @param sum the `sum` for a cluster to be updated + */ + override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { + axpy(1.0 / point.norm, point.vector, sum) + } + + /** + * Returns a centroid for a cluster given its `sum` vector and its `count` of points. + * + * @param sum the `sum` for a cluster + * @param count the number of points in the cluster + * @return the centroid of the cluster + */ + override def centroid(sum: Vector, count: Long): VectorWithNorm = { + scal(1.0 / count, sum) + val norm = Vectors.norm(sum, 2) + scal(1.0 / norm, sum) + new VectorWithNorm(sum, 1) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index e4506f23feb31..32830b39407ad 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.clustering import scala.util.Random -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} @@ -179,6 +179,19 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(predictionsMap(Vectors.dense(-1.0, 1.0)) == predictionsMap(Vectors.dense(-100.0, 90.0))) + model.clusterCenters.forall(Vectors.norm(_, 2) == 1.0) + } + + test("KMeans with cosine distance is not supported for 0-length vectors") { + val model = new KMeans().setDistanceMeasure(DistanceMeasure.COSINE).setK(2) + val df = spark.createDataFrame(spark.sparkContext.parallelize(Array( + Vectors.dense(0.0, 0.0), + Vectors.dense(10.0, 10.0), + Vectors.dense(1.0, 0.5) + )).map(v => TestRow(v))) + val e = intercept[SparkException](model.fit(df)) + assert(e.getCause.isInstanceOf[AssertionError]) + assert(e.getCause.getMessage.contains("Cosine distance is not defined")) } test("read/write") { From 6efd5d117e98074d1b16a5c991fbd38df9aa196e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 11 Feb 2018 23:46:23 -0800 Subject: [PATCH 0324/2461] [SPARK-23390][SQL] Flaky Test Suite: FileBasedDataSourceSuite in Spark 2.3/hadoop 2.7 ## What changes were proposed in this pull request? This test only fails with sbt on Hadoop 2.7, I can't reproduce it locally, but here is my speculation by looking at the code: 1. FileSystem.delete doesn't delete the directory entirely, somehow we can still open the file as a 0-length empty file.(just speculation) 2. ORC intentionally allow empty files, and the reader fails during reading without closing the file stream. This PR improves the test to make sure all files are deleted and can't be opened. ## How was this patch tested? N/A Author: Wenchen Fan Closes #20584 from cloud-fan/flaky-test. --- .../spark/sql/FileBasedDataSourceSuite.scala | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 640d6b1583663..2e332362ea644 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.io.FileNotFoundException + import org.apache.hadoop.fs.Path import org.apache.spark.SparkException @@ -102,17 +104,27 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { def testIgnoreMissingFiles(): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath + Seq("0").toDF("a").write.format(format).save(new Path(basePath, "first").toString) Seq("1").toDF("a").write.format(format).save(new Path(basePath, "second").toString) + val thirdPath = new Path(basePath, "third") + val fs = thirdPath.getFileSystem(spark.sparkContext.hadoopConfiguration) Seq("2").toDF("a").write.format(format).save(thirdPath.toString) + val files = fs.listStatus(thirdPath).filter(_.isFile).map(_.getPath) + val df = spark.read.format(format).load( new Path(basePath, "first").toString, new Path(basePath, "second").toString, new Path(basePath, "third").toString) - val fs = thirdPath.getFileSystem(spark.sparkContext.hadoopConfiguration) + // Make sure all data files are deleted and can't be opened. + files.foreach(f => fs.delete(f, false)) assert(fs.delete(thirdPath, true)) + for (f <- files) { + intercept[FileNotFoundException](fs.open(f)) + } + checkAnswer(df, Seq(Row("0"), Row("1"))) } } From c338c8cf8253c037ecd4f39bbd58ed5a86581b37 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 12 Feb 2018 20:49:36 +0900 Subject: [PATCH 0325/2461] [SPARK-23352][PYTHON] Explicitly specify supported types in Pandas UDFs ## What changes were proposed in this pull request? This PR targets to explicitly specify supported types in Pandas UDFs. The main change here is to add a deduplicated and explicit type checking in `returnType` ahead with documenting this; however, it happened to fix multiple things. 1. Currently, we don't support `BinaryType` in Pandas UDFs, for example, see: ```python from pyspark.sql.functions import pandas_udf pudf = pandas_udf(lambda x: x, "binary") df = spark.createDataFrame([[bytearray(1)]]) df.select(pudf("_1")).show() ``` ``` ... TypeError: Unsupported type in conversion to Arrow: BinaryType ``` We can document this behaviour for its guide. 2. Also, the grouped aggregate Pandas UDF fails fast on `ArrayType` but seems we can support this case. ```python from pyspark.sql.functions import pandas_udf, PandasUDFType foo = pandas_udf(lambda v: v.mean(), 'array', PandasUDFType.GROUPED_AGG) df = spark.range(100).selectExpr("id", "array(id) as value") df.groupBy("id").agg(foo("value")).show() ``` ``` ... NotImplementedError: ArrayType, StructType and MapType are not supported with PandasUDFType.GROUPED_AGG ``` 3. Since we can check the return type ahead, we can fail fast before actual execution. ```python # we can fail fast at this stage because we know the schema ahead pandas_udf(lambda x: x, BinaryType()) ``` ## How was this patch tested? Manually tested and unit tests for `BinaryType` and `ArrayType(...)` were added. Author: hyukjinkwon Closes #20531 from HyukjinKwon/pudf-cleanup. --- docs/sql-programming-guide.md | 4 +- python/pyspark/sql/tests.py | 130 +++++++++++------- python/pyspark/sql/types.py | 4 + python/pyspark/sql/udf.py | 36 +++-- python/pyspark/worker.py | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 2 +- 6 files changed, 111 insertions(+), 67 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index eab4030ee25d2..6174a93b68492 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1676,7 +1676,7 @@ Using the above optimizations with Arrow will produce the same results as when A enabled. Note that even with Arrow, `toPandas()` results in the collection of all records in the DataFrame to the driver program and should be done on a small subset of the data. Not all Spark data types are currently supported and an error can be raised if a column has an unsupported type, -see [Supported Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`, +see [Supported SQL Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`, Spark will fall back to create the DataFrame without Arrow. ## Pandas UDFs (a.k.a. Vectorized UDFs) @@ -1734,7 +1734,7 @@ For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/p ### Supported SQL Types -Currently, all Spark SQL data types are supported by Arrow-based conversion except `MapType`, +Currently, all Spark SQL data types are supported by Arrow-based conversion except `BinaryType`, `MapType`, `ArrayType` of `TimestampType`, and nested `StructType`. ### Setting Arrow Batch Size diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index fe89bd0685027..2af218a691026 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3790,10 +3790,10 @@ def foo(x): self.assertEqual(foo.returnType, schema) self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - @pandas_udf(returnType='v double', functionType=PandasUDFType.SCALAR) + @pandas_udf(returnType='double', functionType=PandasUDFType.SCALAR) def foo(x): return x - self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.returnType, DoubleType()) self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP) @@ -3830,7 +3830,7 @@ def zero_with_type(): @pandas_udf(returnType=PandasUDFType.GROUPED_MAP) def foo(df): return df - with self.assertRaisesRegexp(ValueError, 'Invalid returnType'): + with self.assertRaisesRegexp(TypeError, 'Invalid returnType'): @pandas_udf(returnType='double', functionType=PandasUDFType.GROUPED_MAP) def foo(df): return df @@ -3879,7 +3879,7 @@ def random_udf(v): return random_udf def test_vectorized_udf_basic(self): - from pyspark.sql.functions import pandas_udf, col + from pyspark.sql.functions import pandas_udf, col, array df = self.spark.range(10).select( col('id').cast('string').alias('str'), col('id').cast('int').alias('int'), @@ -3887,7 +3887,8 @@ def test_vectorized_udf_basic(self): col('id').cast('float').alias('float'), col('id').cast('double').alias('double'), col('id').cast('decimal').alias('decimal'), - col('id').cast('boolean').alias('bool')) + col('id').cast('boolean').alias('bool'), + array(col('id')).alias('array_long')) f = lambda x: x str_f = pandas_udf(f, StringType()) int_f = pandas_udf(f, IntegerType()) @@ -3896,10 +3897,11 @@ def test_vectorized_udf_basic(self): double_f = pandas_udf(f, DoubleType()) decimal_f = pandas_udf(f, DecimalType()) bool_f = pandas_udf(f, BooleanType()) + array_long_f = pandas_udf(f, ArrayType(LongType())) res = df.select(str_f(col('str')), int_f(col('int')), long_f(col('long')), float_f(col('float')), double_f(col('double')), decimal_f('decimal'), - bool_f(col('bool'))) + bool_f(col('bool')), array_long_f('array_long')) self.assertEquals(df.collect(), res.collect()) def test_register_nondeterministic_vectorized_udf_basic(self): @@ -4104,10 +4106,11 @@ def test_vectorized_udf_chained(self): def test_vectorized_udf_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10) - f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'): - df.select(f(col('id'))).collect() + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*MapType'): + pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) def test_vectorized_udf_return_scalar(self): from pyspark.sql.functions import pandas_udf, col @@ -4142,13 +4145,18 @@ def test_vectorized_udf_varargs(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_unsupported_types(self): - from pyspark.sql.functions import pandas_udf, col - schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) - df = self.spark.createDataFrame([(None,)], schema=schema) - f = pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) + from pyspark.sql.functions import pandas_udf with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported data type'): - df.select(f(col('map'))).collect() + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*MapType'): + pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*BinaryType'): + pandas_udf(lambda x: x, BinaryType()) def test_vectorized_udf_dates(self): from pyspark.sql.functions import pandas_udf, col @@ -4379,15 +4387,16 @@ def data(self): .withColumn("vs", array([lit(i) for i in range(20, 30)])) \ .withColumn("v", explode(col('vs'))).drop('vs') - def test_simple(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - df = self.data + def test_supported_types(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col + df = self.data.withColumn("arr", array(col("id"))) foo_udf = pandas_udf( lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), StructType( [StructField('id', LongType()), StructField('v', IntegerType()), + StructField('arr', ArrayType(LongType())), StructField('v1', DoubleType()), StructField('v2', LongType())]), PandasUDFType.GROUPED_MAP @@ -4490,17 +4499,15 @@ def test_datatype_string(self): def test_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, PandasUDFType - df = self.data - - foo = pandas_udf( - lambda pdf: pdf, - 'id long, v map', - PandasUDFType.GROUPED_MAP - ) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'): - df.groupby('id').apply(foo).sort('id').toPandas() + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*grouped map Pandas UDF.*MapType'): + pandas_udf( + lambda pdf: pdf, + 'id long, v map', + PandasUDFType.GROUPED_MAP) def test_wrong_args(self): from pyspark.sql.functions import udf, pandas_udf, sum, PandasUDFType @@ -4519,23 +4526,30 @@ def test_wrong_args(self): df.groupby('id').apply( pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))) with self.assertRaisesRegexp(ValueError, 'Invalid udf'): - df.groupby('id').apply( - pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]))) + df.groupby('id').apply(pandas_udf(lambda x, y: x, DoubleType())) with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUPED_MAP'): df.groupby('id').apply( - pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]), - PandasUDFType.SCALAR)) + pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR)) def test_unsupported_types(self): - from pyspark.sql.functions import pandas_udf, col, PandasUDFType + from pyspark.sql.functions import pandas_udf, PandasUDFType schema = StructType( [StructField("id", LongType(), True), StructField("map", MapType(StringType(), IntegerType()), True)]) - df = self.spark.createDataFrame([(1, None,)], schema=schema) - f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUPED_MAP) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported data type'): - df.groupby('id').apply(f).collect() + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*grouped map Pandas UDF.*MapType'): + pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP) + + schema = StructType( + [StructField("id", LongType(), True), + StructField("arr_ts", ArrayType(TimestampType()), True)]) + with QuietTest(self.sc): + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*grouped map Pandas UDF.*ArrayType.*TimestampType'): + pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP) # Regression test for SPARK-23314 def test_timestamp_dst(self): @@ -4614,23 +4628,32 @@ def weighted_mean(v, w): return weighted_mean def test_manual(self): + from pyspark.sql.functions import pandas_udf, array + df = self.data sum_udf = self.pandas_agg_sum_udf mean_udf = self.pandas_agg_mean_udf - - result1 = df.groupby('id').agg(sum_udf(df.v), mean_udf(df.v)).sort('id') + mean_arr_udf = pandas_udf( + self.pandas_agg_mean_udf.func, + ArrayType(self.pandas_agg_mean_udf.returnType), + self.pandas_agg_mean_udf.evalType) + + result1 = df.groupby('id').agg( + sum_udf(df.v), + mean_udf(df.v), + mean_arr_udf(array(df.v))).sort('id') expected1 = self.spark.createDataFrame( - [[0, 245.0, 24.5], - [1, 255.0, 25.5], - [2, 265.0, 26.5], - [3, 275.0, 27.5], - [4, 285.0, 28.5], - [5, 295.0, 29.5], - [6, 305.0, 30.5], - [7, 315.0, 31.5], - [8, 325.0, 32.5], - [9, 335.0, 33.5]], - ['id', 'sum(v)', 'avg(v)']) + [[0, 245.0, 24.5, [24.5]], + [1, 255.0, 25.5, [25.5]], + [2, 265.0, 26.5, [26.5]], + [3, 275.0, 27.5, [27.5]], + [4, 285.0, 28.5, [28.5]], + [5, 295.0, 29.5, [29.5]], + [6, 305.0, 30.5, [30.5]], + [7, 315.0, 31.5, [31.5]], + [8, 325.0, 32.5, [32.5]], + [9, 335.0, 33.5, [33.5]]], + ['id', 'sum(v)', 'avg(v)', 'avg(array(v))']) self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) @@ -4667,14 +4690,15 @@ def test_basic(self): self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) def test_unsupported_types(self): - from pyspark.sql.types import ArrayType, DoubleType, MapType + from pyspark.sql.types import DoubleType, MapType from pyspark.sql.functions import pandas_udf, PandasUDFType with QuietTest(self.sc): with self.assertRaisesRegexp(NotImplementedError, 'not supported'): - @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUPED_AGG) - def mean_and_std_udf(v): - return [v.mean(), v.std()] + pandas_udf( + lambda x: x, + ArrayType(ArrayType(TimestampType())), + PandasUDFType.GROUPED_AGG) with QuietTest(self.sc): with self.assertRaisesRegexp(NotImplementedError, 'not supported'): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index f7141b4549e4e..e25941cd37595 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1638,6 +1638,8 @@ def to_arrow_type(dt): # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read arrow_type = pa.timestamp('us', tz='UTC') elif type(dt) == ArrayType: + if type(dt.elementType) == TimestampType: + raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) arrow_type = pa.list_(to_arrow_type(dt.elementType)) else: raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) @@ -1680,6 +1682,8 @@ def from_arrow_type(at): elif types.is_timestamp(at): spark_type = TimestampType() elif types.is_list(at): + if types.is_timestamp(at.value_type): + raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) spark_type = ArrayType(from_arrow_type(at.value_type)) else: raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 08c6b9e521e82..e5b35fc60e167 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -23,7 +23,7 @@ from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \ - _parse_datatype_string + _parse_datatype_string, to_arrow_type, to_arrow_schema __all__ = ["UDFRegistration"] @@ -112,15 +112,31 @@ def returnType(self): else: self._returnType_placeholder = _parse_datatype_string(self._returnType) - if self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \ - and not isinstance(self._returnType_placeholder, StructType): - raise ValueError("Invalid returnType: returnType must be a StructType for " - "pandas_udf with function type GROUPED_MAP") - elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF \ - and isinstance(self._returnType_placeholder, (StructType, ArrayType, MapType)): - raise NotImplementedError( - "ArrayType, StructType and MapType are not supported with " - "PandasUDFType.GROUPED_AGG") + if self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF: + try: + to_arrow_type(self._returnType_placeholder) + except TypeError: + raise NotImplementedError( + "Invalid returnType with scalar Pandas UDFs: %s is " + "not supported" % str(self._returnType_placeholder)) + elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + if isinstance(self._returnType_placeholder, StructType): + try: + to_arrow_schema(self._returnType_placeholder) + except TypeError: + raise NotImplementedError( + "Invalid returnType with grouped map Pandas UDFs: " + "%s is not supported" % str(self._returnType_placeholder)) + else: + raise TypeError("Invalid returnType for grouped map Pandas " + "UDFs: returnType must be a StructType.") + elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: + try: + to_arrow_type(self._returnType_placeholder) + except TypeError: + raise NotImplementedError( + "Invalid returnType with grouped aggregate Pandas UDFs: " + "%s is not supported" % str(self._returnType_placeholder)) return self._returnType_placeholder diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 121b3dd1aeec9..89a3a92bc66d6 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -116,7 +116,7 @@ def wrap_grouped_agg_pandas_udf(f, return_type): def wrapped(*series): import pandas as pd result = f(*series) - return pd.Series(result) + return pd.Series([result]) return lambda *a: (wrapped(*a), arrow_return_type) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1e2501ee7757d..7835dbaa58439 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1064,7 +1064,7 @@ object SQLConf { "for use with pyspark.sql.DataFrame.toPandas, and " + "pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame. " + "The following data types are unsupported: " + - "MapType, ArrayType of TimestampType, and nested StructType.") + "BinaryType, MapType, ArrayType of TimestampType, and nested StructType.") .booleanConf .createWithDefault(false) From caeb108e25e5bfb7cffcf09ef9abbb1abcfa355d Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Mon, 12 Feb 2018 22:05:27 +0800 Subject: [PATCH 0326/2461] [MINOR][TEST] spark.testing` No effect on the SparkFunSuite unit test ## What changes were proposed in this pull request? Currently, we use SBT and MAVN to spark unit test, are affected by the parameters of `spark.testing`. However, when using the IDE test tool, `spark.testing` support is not very good, sometimes need to be manually added to the beforeEach. example: HiveSparkSubmitSuite RPackageUtilsSuite SparkSubmitSuite. The PR unified `spark.testing` parameter extraction to SparkFunSuite, support IDE test tool, and the test code is more compact. ## How was this patch tested? the existed test cases. Author: caoxuewen Closes #20582 from heary-cao/sparktesting. --- core/src/test/scala/org/apache/spark/SparkFunSuite.scala | 1 + .../test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala | 1 - .../test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala | 1 - .../spark/network/netty/NettyBlockTransferServiceSuite.scala | 1 + .../scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 1 - 5 files changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 3af9d82393bc4..31289026b0027 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -59,6 +59,7 @@ abstract class SparkFunSuite protected val enableAutoThreadAudit = true protected override def beforeAll(): Unit = { + System.setProperty("spark.testing", "true") if (enableAutoThreadAudit) { doThreadPreAudit() } diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala index 32dd3ecc2f027..ef947eb074647 100644 --- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala @@ -66,7 +66,6 @@ class RPackageUtilsSuite override def beforeEach(): Unit = { super.beforeEach() - System.setProperty("spark.testing", "true") lineBuffer.clear() } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 27dd435332348..803a38d77fb82 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -107,7 +107,6 @@ class SparkSubmitSuite override def beforeEach() { super.beforeEach() - System.setProperty("spark.testing", "true") } // scalastyle:off println diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index f7bc3725d7278..78423ee68a0ec 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -80,6 +80,7 @@ class NettyBlockTransferServiceSuite private def verifyServicePort(expectedPort: Int, actualPort: Int): Unit = { actualPort should be >= expectedPort // avoid testing equality in case of simultaneous tests + // if `spark.testing` is true, // the default value for `spark.port.maxRetries` is 100 under test actualPort should be <= (expectedPort + 100) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 10204f4694663..2d31781132edc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -50,7 +50,6 @@ class HiveSparkSubmitSuite override def beforeEach() { super.beforeEach() - System.setProperty("spark.testing", "true") } test("temporary Hive UDF: define a UDF and use it") { From 0e2c266de7189473177f45aa68ea6a45c7e47ec3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 12 Feb 2018 22:07:59 +0800 Subject: [PATCH 0327/2461] [SPARK-22977][SQL] fix web UI SQL tab for CTAS ## What changes were proposed in this pull request? This is a regression in Spark 2.3. In Spark 2.2, we have a fragile UI support for SQL data writing commands. We only track the input query plan of `FileFormatWriter` and display its metrics. This is not ideal because we don't know who triggered the writing(can be table insertion, CTAS, etc.), but it's still useful to see the metrics of the input query. In Spark 2.3, we introduced a new mechanism: `DataWritigCommand`, to fix the UI issue entirely. Now these writing commands have real children, and we don't need to hack into the `FileFormatWriter` for the UI. This also helps with `explain`, now `explain` can show the physical plan of the input query, while in 2.2 the physical writing plan is simply `ExecutedCommandExec` and it has no child. However there is a regression in CTAS. CTAS commands don't extend `DataWritigCommand`, and we don't have the UI hack in `FileFormatWriter` anymore, so the UI for CTAS is just an empty node. See https://issues.apache.org/jira/browse/SPARK-22977 for more information about this UI issue. To fix it, we should apply the `DataWritigCommand` mechanism to CTAS commands. TODO: In the future, we should refactor this part and create some physical layer code pieces for data writing, and reuse them in different writing commands. We should have different logical nodes for different operators, even some of them share some same logic, e.g. CTAS, CREATE TABLE, INSERT TABLE. Internally we can share the same physical logic. ## How was this patch tested? manually tested. For data source table 1 For hive table 2 Author: Wenchen Fan Closes #20521 from cloud-fan/UI. --- .../command/createDataSourceTables.scala | 21 +++---- .../execution/datasources/DataSource.scala | 44 +++++++++++++-- .../datasources/DataSourceStrategy.scala | 2 +- .../spark/sql/hive/HiveStrategies.scala | 2 +- .../CreateHiveTableAsSelectCommand.scala | 55 ++++++++++--------- .../sql/hive/execution/HiveExplainSuite.scala | 26 --------- 6 files changed, 80 insertions(+), 70 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 306f43dc4214a..e9747769dfcfc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -21,7 +21,9 @@ import java.net.URI import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types.StructType @@ -136,12 +138,11 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo case class CreateDataSourceTableAsSelectCommand( table: CatalogTable, mode: SaveMode, - query: LogicalPlan) - extends RunnableCommand { - - override protected def innerChildren: Seq[LogicalPlan] = Seq(query) + query: LogicalPlan, + outputColumns: Seq[Attribute]) + extends DataWritingCommand { - override def run(sparkSession: SparkSession): Seq[Row] = { + override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { assert(table.tableType != CatalogTableType.VIEW) assert(table.provider.isDefined) @@ -163,7 +164,7 @@ case class CreateDataSourceTableAsSelectCommand( } saveDataIntoTable( - sparkSession, table, table.storage.locationUri, query, SaveMode.Append, tableExists = true) + sparkSession, table, table.storage.locationUri, child, SaveMode.Append, tableExists = true) } else { assert(table.schema.isEmpty) @@ -173,7 +174,7 @@ case class CreateDataSourceTableAsSelectCommand( table.storage.locationUri } val result = saveDataIntoTable( - sparkSession, table, tableLocation, query, SaveMode.Overwrite, tableExists = false) + sparkSession, table, tableLocation, child, SaveMode.Overwrite, tableExists = false) val newTable = table.copy( storage = table.storage.copy(locationUri = tableLocation), // We will use the schema of resolved.relation as the schema of the table (instead of @@ -198,10 +199,10 @@ case class CreateDataSourceTableAsSelectCommand( session: SparkSession, table: CatalogTable, tableLocation: Option[URI], - data: LogicalPlan, + physicalPlan: SparkPlan, mode: SaveMode, tableExists: Boolean): BaseRelation = { - // Create the relation based on the input logical plan: `data`. + // Create the relation based on the input logical plan: `query`. val pathOption = tableLocation.map("path" -> CatalogUtils.URIToString(_)) val dataSource = DataSource( session, @@ -212,7 +213,7 @@ case class CreateDataSourceTableAsSelectCommand( catalogTable = if (tableExists) Some(table) else None) try { - dataSource.writeAndRead(mode, query) + dataSource.writeAndRead(mode, query, outputColumns, physicalPlan) } catch { case ex: AnalysisException => logError(s"Failed to write to table ${table.identifier.unquotedString}", ex) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 25e1210504273..6e1b5727e3fd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -31,8 +31,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogUtils} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider import org.apache.spark.sql.execution.datasources.json.JsonFileFormat @@ -435,10 +437,11 @@ case class DataSource( } /** - * Writes the given [[LogicalPlan]] out in this [[FileFormat]]. + * Creates a command node to write the given [[LogicalPlan]] out to the given [[FileFormat]]. + * The returned command is unresolved and need to be analyzed. */ private def planForWritingFileFormat( - format: FileFormat, mode: SaveMode, data: LogicalPlan): LogicalPlan = { + format: FileFormat, mode: SaveMode, data: LogicalPlan): InsertIntoHadoopFsRelationCommand = { // Don't glob path for the write path. The contracts here are: // 1. Only one output path can be specified on the write path; // 2. Output path must be a legal HDFS style file system path; @@ -482,9 +485,24 @@ case class DataSource( /** * Writes the given [[LogicalPlan]] out to this [[DataSource]] and returns a [[BaseRelation]] for * the following reading. + * + * @param mode The save mode for this writing. + * @param data The input query plan that produces the data to be written. Note that this plan + * is analyzed and optimized. + * @param outputColumns The original output columns of the input query plan. The optimizer may not + * preserve the output column's names' case, so we need this parameter + * instead of `data.output`. + * @param physicalPlan The physical plan of the input query plan. We should run the writing + * command with this physical plan instead of creating a new physical plan, + * so that the metrics can be correctly linked to the given physical plan and + * shown in the web UI. */ - def writeAndRead(mode: SaveMode, data: LogicalPlan): BaseRelation = { - if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { + def writeAndRead( + mode: SaveMode, + data: LogicalPlan, + outputColumns: Seq[Attribute], + physicalPlan: SparkPlan): BaseRelation = { + if (outputColumns.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { throw new AnalysisException("Cannot save interval data type into external storage.") } @@ -493,9 +511,23 @@ case class DataSource( dataSource.createRelation( sparkSession.sqlContext, mode, caseInsensitiveOptions, Dataset.ofRows(sparkSession, data)) case format: FileFormat => - sparkSession.sessionState.executePlan(planForWritingFileFormat(format, mode, data)).toRdd + val cmd = planForWritingFileFormat(format, mode, data) + val resolvedPartCols = cmd.partitionColumns.map { col => + // The partition columns created in `planForWritingFileFormat` should always be + // `UnresolvedAttribute` with a single name part. + assert(col.isInstanceOf[UnresolvedAttribute]) + val unresolved = col.asInstanceOf[UnresolvedAttribute] + assert(unresolved.nameParts.length == 1) + val name = unresolved.nameParts.head + outputColumns.find(a => equality(a.name, name)).getOrElse { + throw new AnalysisException( + s"Unable to resolve $name given [${data.output.map(_.name).mkString(", ")}]") + } + } + val resolved = cmd.copy(partitionColumns = resolvedPartCols, outputColumns = outputColumns) + resolved.run(sparkSession, physicalPlan) // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring - copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation() + copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation() case _ => sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index d94c5bbccdd84..3f41612c08065 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -139,7 +139,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast case CreateTable(tableDesc, mode, Some(query)) if query.resolved && DDLUtils.isDatasourceTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc.copy(schema = query.schema)) - CreateDataSourceTableAsSelectCommand(tableDesc, mode, query) + CreateDataSourceTableAsSelectCommand(tableDesc, mode, query, query.output) case InsertIntoTable(l @ LogicalRelation(_: InsertableRelation, _, _, _), parts, query, overwrite, false) if parts.isEmpty => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index ab857b9055720..8df05cbb20361 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -157,7 +157,7 @@ object HiveAnalysis extends Rule[LogicalPlan] { case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isHiveTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc) - CreateHiveTableAsSelectCommand(tableDesc, query, mode) + CreateHiveTableAsSelectCommand(tableDesc, query, query.output, mode) case InsertIntoDir(isLocal, storage, provider, child, overwrite) if DDLUtils.isHiveTable(provider) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 65e8b4e3c725c..1e801fe1845c4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.hive.execution import scala.util.control.NonFatal import org.apache.spark.sql.{AnalysisException, Row, SaveMode, SparkSession} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} -import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.command.DataWritingCommand /** @@ -36,15 +37,15 @@ import org.apache.spark.sql.execution.command.RunnableCommand case class CreateHiveTableAsSelectCommand( tableDesc: CatalogTable, query: LogicalPlan, + outputColumns: Seq[Attribute], mode: SaveMode) - extends RunnableCommand { + extends DataWritingCommand { private val tableIdentifier = tableDesc.identifier - override def innerChildren: Seq[LogicalPlan] = Seq(query) - - override def run(sparkSession: SparkSession): Seq[Row] = { - if (sparkSession.sessionState.catalog.tableExists(tableIdentifier)) { + override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog + if (catalog.tableExists(tableIdentifier)) { assert(mode != SaveMode.Overwrite, s"Expect the table $tableIdentifier has been dropped when the save mode is Overwrite") @@ -56,34 +57,36 @@ case class CreateHiveTableAsSelectCommand( return Seq.empty } - sparkSession.sessionState.executePlan( - InsertIntoTable( - UnresolvedRelation(tableIdentifier), - Map(), - query, - overwrite = false, - ifPartitionNotExists = false)).toRdd + InsertIntoHiveTable( + tableDesc, + Map.empty, + query, + overwrite = false, + ifPartitionNotExists = false, + outputColumns = outputColumns).run(sparkSession, child) } else { // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in case of failure occurs while data // processing. assert(tableDesc.schema.isEmpty) - sparkSession.sessionState.catalog.createTable( - tableDesc.copy(schema = query.schema), ignoreIfExists = false) + catalog.createTable(tableDesc.copy(schema = query.schema), ignoreIfExists = false) try { - sparkSession.sessionState.executePlan( - InsertIntoTable( - UnresolvedRelation(tableIdentifier), - Map(), - query, - overwrite = true, - ifPartitionNotExists = false)).toRdd + // Read back the metadata of the table which was created just now. + val createdTableMeta = catalog.getTableMetadata(tableDesc.identifier) + // For CTAS, there is no static partition values to insert. + val partition = createdTableMeta.partitionColumnNames.map(_ -> None).toMap + InsertIntoHiveTable( + createdTableMeta, + partition, + query, + overwrite = true, + ifPartitionNotExists = false, + outputColumns = outputColumns).run(sparkSession, child) } catch { case NonFatal(e) => // drop the created table. - sparkSession.sessionState.catalog.dropTable(tableIdentifier, ignoreIfNotExists = true, - purge = false) + catalog.dropTable(tableIdentifier, ignoreIfNotExists = true, purge = false) throw e } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index f84d188075b72..5d56f89c2271c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -128,32 +128,6 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto "src") } - test("SPARK-17409: The EXPLAIN output of CTAS only shows the analyzed plan") { - withTempView("jt") { - val ds = (1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""").toDS() - spark.read.json(ds).createOrReplaceTempView("jt") - val outputs = sql( - s""" - |EXPLAIN EXTENDED - |CREATE TABLE t1 - |AS - |SELECT * FROM jt - """.stripMargin).collect().map(_.mkString).mkString - - val shouldContain = - "== Parsed Logical Plan ==" :: "== Analyzed Logical Plan ==" :: "Subquery" :: - "== Optimized Logical Plan ==" :: "== Physical Plan ==" :: - "CreateHiveTableAsSelect" :: "InsertIntoHiveTable" :: "jt" :: Nil - for (key <- shouldContain) { - assert(outputs.contains(key), s"$key doesn't exist in result") - } - - val physicalIndex = outputs.indexOf("== Physical Plan ==") - assert(outputs.substring(physicalIndex).contains("Subquery"), - "Physical Plan should contain SubqueryAlias since the query should not be optimized") - } - } - test("explain output of physical plan should contain proper codegen stage ID") { checkKeywordsExist(sql( """ From 4a4dd4f36f65410ef5c87f7b61a960373f044e61 Mon Sep 17 00:00:00 2001 From: liuxian Date: Mon, 12 Feb 2018 08:49:45 -0600 Subject: [PATCH 0328/2461] [SPARK-23391][CORE] It may lead to overflow for some integer multiplication ## What changes were proposed in this pull request? In the `getBlockData`,`blockId.reduceId` is the `Int` type, when it is greater than 2^28, `blockId.reduceId*8` will overflow In the `decompress0`, `len` and `unitSize` are Int type, so `len * unitSize` may lead to overflow ## How was this patch tested? N/A Author: liuxian Closes #20581 from 10110346/overflow2. --- .../org/apache/spark/shuffle/IndexShuffleBlockResolver.scala | 4 ++-- .../execution/columnar/compression/compressionSchemes.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index d88b25cc7e258..d3f1c7ec1bbee 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -202,13 +202,13 @@ private[spark] class IndexShuffleBlockResolver( // class of issue from re-occurring in the future which is why they are left here even though // SPARK-22982 is fixed. val channel = Files.newByteChannel(indexFile.toPath) - channel.position(blockId.reduceId * 8) + channel.position(blockId.reduceId * 8L) val in = new DataInputStream(Channels.newInputStream(channel)) try { val offset = in.readLong() val nextOffset = in.readLong() val actualPosition = channel.position() - val expectedPosition = blockId.reduceId * 8 + 16 + val expectedPosition = blockId.reduceId * 8L + 16 if (actualPosition != expectedPosition) { throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " + s"expected $expectedPosition but actual position was $actualPosition.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala index 79dcf3a6105ce..00a1d54b41709 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala @@ -116,7 +116,7 @@ private[columnar] case object PassThrough extends CompressionScheme { while (pos < capacity) { if (pos != nextNullIndex) { val len = nextNullIndex - pos - assert(len * unitSize < Int.MaxValue) + assert(len * unitSize.toLong < Int.MaxValue) putFunction(columnVector, pos, bufferPos, len) bufferPos += len * unitSize pos += len From 5bb11411aec18b8d623e54caba5397d7cb8e89f0 Mon Sep 17 00:00:00 2001 From: James Thompson Date: Mon, 12 Feb 2018 11:34:56 -0800 Subject: [PATCH 0329/2461] [SPARK-23388][SQL] Support for Parquet Binary DecimalType in VectorizedColumnReader ## What changes were proposed in this pull request? Re-add support for parquet binary DecimalType in VectorizedColumnReader ## How was this patch tested? Existing test suite Author: James Thompson Closes #20580 from jamesthomp/jt/add-back-binary-decimal. --- .../execution/datasources/parquet/VectorizedColumnReader.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index c120863152a96..47dd625f4b154 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -444,7 +444,8 @@ private void readBinaryBatch(int rowId, int num, WritableColumnVector column) { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; - if (column.dataType() == DataTypes.StringType || column.dataType() == DataTypes.BinaryType) { + if (column.dataType() == DataTypes.StringType || column.dataType() == DataTypes.BinaryType + || DecimalType.isByteArrayDecimalType(column.dataType())) { defColumn.readBinarys(num, column, rowId, maxDefLevel, data); } else if (column.dataType() == DataTypes.TimestampType) { if (!shouldConvertTimestamps()) { From 0c66fe4f22f8af4932893134bb0fd56f00fabeae Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 12 Feb 2018 12:20:29 -0800 Subject: [PATCH 0330/2461] [SPARK-22002][SQL][FOLLOWUP][TEST] Add a test to check if the original schema doesn't have metadata. ## What changes were proposed in this pull request? This is a follow-up pr of #19231 which modified the behavior to remove metadata from JDBC table schema. This pr adds a test to check if the schema doesn't have metadata. ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN Closes #20585 from ueshin/issues/SPARK-22002/fup1. --- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index cb2df0ac54f4c..5238adce4a699 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1168,4 +1168,26 @@ class JDBCSuite extends SparkFunSuite val df3 = sql("SELECT * FROM test_sessionInitStatement") assert(df3.collect() === Array(Row(21519, 1234))) } + + test("jdbc data source shouldn't have unnecessary metadata in its schema") { + val schema = StructType(Seq( + StructField("NAME", StringType, true), StructField("THEID", IntegerType, true))) + + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("DbTaBle", "TEST.PEOPLE") + .load() + assert(df.schema === schema) + + withTempView("people_view") { + sql( + s""" + |CREATE TEMPORARY VIEW people_view + |USING org.apache.spark.sql.jdbc + |OPTIONS (uRl '$url', DbTaBlE 'TEST.PEOPLE', User 'testUser', PassWord 'testPass') + """.stripMargin.replaceAll("\n", " ")) + + assert(sql("select * from people_view").schema === schema) + } + } } From fba01b9a65e5d9438d35da0bd807c179ba741911 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Mon, 12 Feb 2018 14:58:31 -0800 Subject: [PATCH 0331/2461] [SPARK-23378][SQL] move setCurrentDatabase from HiveExternalCatalog to HiveClientImpl ## What changes were proposed in this pull request? This removes the special case that `alterPartitions` call from `HiveExternalCatalog` can reset the current database in the hive client as a side effect. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Feng Liu Closes #20564 from liufengdb/move. --- .../spark/sql/hive/HiveExternalCatalog.scala | 5 ---- .../sql/hive/client/HiveClientImpl.scala | 26 ++++++++++++++----- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 3b8a8ca301c27..1ee1d57b8ebe1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -1107,11 +1107,6 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - // Note: Before altering table partitions in Hive, you *must* set the current database - // to the one that contains the table of interest. Otherwise you will end up with the - // most helpful error message ever: "Unable to alter partition. alter is not possible." - // See HIVE-2742 for more detail. - client.setCurrentDatabase(db) client.alterPartitions(db, table, withStatsProps) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 6c0f4144992ae..c223f51b1be75 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -291,14 +291,18 @@ private[hive] class HiveClientImpl( state.err = stream } - override def setCurrentDatabase(databaseName: String): Unit = withHiveState { - if (databaseExists(databaseName)) { - state.setCurrentDatabase(databaseName) + private def setCurrentDatabaseRaw(db: String): Unit = { + if (databaseExists(db)) { + state.setCurrentDatabase(db) } else { - throw new NoSuchDatabaseException(databaseName) + throw new NoSuchDatabaseException(db) } } + override def setCurrentDatabase(databaseName: String): Unit = withHiveState { + setCurrentDatabaseRaw(databaseName) + } + override def createDatabase( database: CatalogDatabase, ignoreIfExists: Boolean): Unit = withHiveState { @@ -598,8 +602,18 @@ private[hive] class HiveClientImpl( db: String, table: String, newParts: Seq[CatalogTablePartition]): Unit = withHiveState { - val hiveTable = toHiveTable(getTable(db, table), Some(userName)) - shim.alterPartitions(client, table, newParts.map { p => toHivePartition(p, hiveTable) }.asJava) + // Note: Before altering table partitions in Hive, you *must* set the current database + // to the one that contains the table of interest. Otherwise you will end up with the + // most helpful error message ever: "Unable to alter partition. alter is not possible." + // See HIVE-2742 for more detail. + val original = state.getCurrentDatabase + try { + setCurrentDatabaseRaw(db) + val hiveTable = toHiveTable(getTable(db, table), Some(userName)) + shim.alterPartitions(client, table, newParts.map { toHivePartition(_, hiveTable) }.asJava) + } finally { + state.setCurrentDatabase(original) + } } /** From 6cb59708c70c03696c772fbb5d158eed57fe67d4 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 12 Feb 2018 15:26:37 -0800 Subject: [PATCH 0332/2461] [SPARK-23313][DOC] Add a migration guide for ORC ## What changes were proposed in this pull request? This PR adds a migration guide documentation for ORC. ![orc-guide](https://user-images.githubusercontent.com/9700541/36123859-ec165cae-1002-11e8-90b7-7313be7a81a5.png) ## How was this patch tested? N/A. Author: Dongjoon Hyun Closes #20484 from dongjoon-hyun/SPARK-23313. --- docs/sql-programming-guide.md | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 6174a93b68492..0f9f01e18682f 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1776,6 +1776,35 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.2 to 2.3 + - Since Spark 2.3, Spark supports a vectorized ORC reader with a new ORC file format for ORC files. To do that, the following configurations are newly added or change their default values. The vectorized reader is used for the native ORC tables (e.g., the ones created using the clause `USING ORC`) when `spark.sql.orc.impl` is set to `native` and `spark.sql.orc.enableVectorizedReader` is set to `true`. For the Hive ORC serde table (e.g., the ones created using the clause `USING HIVE OPTIONS (fileFormat 'ORC')`), the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is set to `true`. + + - New configurations + +
spark.mesos.principal (none) - Set the principal with which Spark framework will use to authenticate with Mesos. + Set the principal with which Spark framework will use to authenticate with Mesos. You can also specify this via the environment variable `SPARK_MESOS_PRINCIPAL`. +
spark.mesos.principal.file(none) + Set the file containing the principal with which Spark framework will use to authenticate with Mesos. Allows specifying the principal indirectly in more security conscious deployments. The file must be readable by the user launching the job and be UTF-8 encoded plaintext. You can also specify this via the environment variable `SPARK_MESOS_PRINCIPAL_FILE`.
(none) Set the secret with which Spark framework will use to authenticate with Mesos. Used, for example, when - authenticating with the registry. + authenticating with the registry. You can also specify this via the environment variable `SPARK_MESOS_SECRET`. +
spark.mesos.secret.file(none) + Set the file containing the secret with which Spark framework will use to authenticate with Mesos. Used, for example, when + authenticating with the registry. Allows for specifying the secret indirectly in more security conscious deployments. The file must be readable by the user launching the job and be UTF-8 encoded plaintext. You can also specify this via the environment variable `SPARK_MESOS_SECRET_FILE`.
+ + + + + + + + + + + +
Property NameDefaultMeaning
spark.sql.orc.implnativeThe name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4.1. `hive` means the ORC library in Hive 1.2.1 which is used prior to Spark 2.3.
spark.sql.orc.enableVectorizedReadertrueEnables vectorized orc decoding in native implementation. If false, a new non-vectorized ORC reader is used in native implementation. For hive implementation, this is ignored.
+ + - Changed configurations + + + + + + + + +
Property NameDefaultMeaning
spark.sql.orc.filterPushdowntrueEnables filter pushdown for ORC files. It is false by default prior to Spark 2.3.
+ - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. - The `percentile_approx` function previously accepted numeric type input and output double type results. Now it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles. - Since Spark 2.3, the Join/Filter's deterministic predicates that are after the first non-deterministic predicates are also pushed down/through the child operators, if possible. In prior Spark versions, these filters are not eligible for predicate pushdown. From 4104b68e958cd13975567a96541dac7cccd8195c Mon Sep 17 00:00:00 2001 From: sychen Date: Mon, 12 Feb 2018 16:00:47 -0800 Subject: [PATCH 0333/2461] [SPARK-23230][SQL] When hive.default.fileformat is other kinds of file types, create textfile table cause a serde error When hive.default.fileformat is other kinds of file types, create textfile table cause a serde error. We should take the default type of textfile and sequencefile both as org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe. ``` set hive.default.fileformat=orc; create table tbl( i string ) stored as textfile; desc formatted tbl; Serde Library org.apache.hadoop.hive.ql.io.orc.OrcSerde InputFormat org.apache.hadoop.mapred.TextInputFormat OutputFormat org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat ``` Author: sychen Closes #20406 from cxzl25/default_serde. --- .../apache/spark/sql/internal/HiveSerDe.scala | 6 ++++-- .../sql/hive/execution/HiveSerDeSuite.scala | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala index dac463641cfab..eca612f06f9bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala @@ -31,7 +31,8 @@ object HiveSerDe { "sequencefile" -> HiveSerDe( inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), - outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")), + outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat"), + serde = Option("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")), "rcfile" -> HiveSerDe( @@ -54,7 +55,8 @@ object HiveSerDe { "textfile" -> HiveSerDe( inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), + outputFormat = Option("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat"), + serde = Option("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")), "avro" -> HiveSerDe( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index 1c9f00141ae1d..d7752e987cb4b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -100,6 +100,25 @@ class HiveSerDeSuite extends HiveComparisonTest with PlanTest with BeforeAndAfte assert(output == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) assert(serde == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) } + + withSQLConf("hive.default.fileformat" -> "orc") { + val (desc, exists) = extractTableDesc( + "CREATE TABLE IF NOT EXISTS fileformat_test (id int) STORED AS textfile") + assert(exists) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) + assert(desc.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + } + + withSQLConf("hive.default.fileformat" -> "orc") { + val (desc, exists) = extractTableDesc( + "CREATE TABLE IF NOT EXISTS fileformat_test (id int) STORED AS sequencefile") + assert(exists) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.SequenceFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.mapred.SequenceFileOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + } } test("create hive serde table with new syntax - basic") { From c1bcef876c1415e39e624cfbca9c9bdeae24cbb9 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Tue, 13 Feb 2018 11:40:34 +0800 Subject: [PATCH 0334/2461] [SPARK-23323][SQL] Support commit coordinator for DataSourceV2 writes ## What changes were proposed in this pull request? DataSourceV2 batch writes should use the output commit coordinator if it is required by the data source. This adds a new method, `DataWriterFactory#useCommitCoordinator`, that determines whether the coordinator will be used. If the write factory returns true, `WriteToDataSourceV2` will use the coordinator for batch writes. ## How was this patch tested? This relies on existing write tests, which now use the commit coordinator. Author: Ryan Blue Closes #20490 from rdblue/SPARK-23323-add-commit-coordinator. --- .../sources/v2/writer/DataSourceWriter.java | 19 +++++++-- .../datasources/v2/WriteToDataSourceV2.scala | 41 +++++++++++++++---- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index e3f682bf96a66..0a0fd8db58035 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -63,6 +63,16 @@ public interface DataSourceWriter { */ DataWriterFactory createWriterFactory(); + /** + * Returns whether Spark should use the commit coordinator to ensure that at most one attempt for + * each task commits. + * + * @return true if commit coordinator should be used, false otherwise. + */ + default boolean useCommitCoordinator() { + return true; + } + /** * Handles a commit message on receiving from a successful data writer. * @@ -79,10 +89,11 @@ default void onDataWriterCommit(WriterCommitMessage message) {} * failed, and {@link #abort(WriterCommitMessage[])} would be called. The state of the destination * is undefined and @{@link #abort(WriterCommitMessage[])} may not be able to deal with it. * - * Note that, one partition may have multiple committed data writers because of speculative tasks. - * Spark will pick the first successful one and get its commit message. Implementations should be - * aware of this and handle it correctly, e.g., have a coordinator to make sure only one data - * writer can commit, or have a way to clean up the data of already-committed writers. + * Note that speculative execution may cause multiple tasks to run for a partition. By default, + * Spark uses the commit coordinator to allow at most one attempt to commit. Implementations can + * disable this behavior by overriding {@link #useCommitCoordinator()}. If disabled, multiple + * attempts may have committed successfully and one successful commit message per task will be + * passed to this commit method. The remaining commit messages are ignored by Spark. */ void commit(WriterCommitMessage[] messages); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index eefbcf4c0e087..535e7962d7439 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row @@ -53,6 +54,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) } + val useCommitCoordinator = writer.useCommitCoordinator val rdd = query.execute() val messages = new Array[WriterCommitMessage](rdd.partitions.length) @@ -73,7 +75,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e DataWritingSparkTask.runContinuous(writeTask, context, iter) case _ => (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.run(writeTask, context, iter) + DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator) } sparkContext.runJob( @@ -116,21 +118,44 @@ object DataWritingSparkTask extends Logging { def run( writeTask: DataWriterFactory[InternalRow], context: TaskContext, - iter: Iterator[InternalRow]): WriterCommitMessage = { - val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber()) + iter: Iterator[InternalRow], + useCommitCoordinator: Boolean): WriterCommitMessage = { + val stageId = context.stageId() + val partId = context.partitionId() + val attemptId = context.attemptNumber() + val dataWriter = writeTask.createDataWriter(partId, attemptId) // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { iter.foreach(dataWriter.write) - logInfo(s"Writer for partition ${context.partitionId()} is committing.") - val msg = dataWriter.commit() - logInfo(s"Writer for partition ${context.partitionId()} committed.") + + val msg = if (useCommitCoordinator) { + val coordinator = SparkEnv.get.outputCommitCoordinator + val commitAuthorized = coordinator.canCommit(context.stageId(), partId, attemptId) + if (commitAuthorized) { + logInfo(s"Writer for stage $stageId, task $partId.$attemptId is authorized to commit.") + dataWriter.commit() + } else { + val message = s"Stage $stageId, task $partId.$attemptId: driver did not authorize commit" + logInfo(message) + // throwing CommitDeniedException will trigger the catch block for abort + throw new CommitDeniedException(message, stageId, partId, attemptId) + } + + } else { + logInfo(s"Writer for partition ${context.partitionId()} is committing.") + dataWriter.commit() + } + + logInfo(s"Writer for stage $stageId, task $partId.$attemptId committed.") + msg + })(catchBlock = { // If there is an error, abort this writer - logError(s"Writer for partition ${context.partitionId()} is aborting.") + logError(s"Writer for stage $stageId, task $partId.$attemptId is aborting.") dataWriter.abort() - logError(s"Writer for partition ${context.partitionId()} aborted.") + logError(s"Writer for stage $stageId, task $partId.$attemptId aborted.") }) } From ed4e78bd606e7defc2cd01a5c2e9b47954baa424 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Mon, 12 Feb 2018 20:57:26 -0800 Subject: [PATCH 0335/2461] [SPARK-23379][SQL] skip when setting the same current database in HiveClientImpl ## What changes were proposed in this pull request? If the target database name is as same as the current database, we should be able to skip one metastore access. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Feng Liu Closes #20565 from liufengdb/remove-redundant. --- .../apache/spark/sql/hive/client/HiveClientImpl.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index c223f51b1be75..146fa54a1bce4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -292,10 +292,12 @@ private[hive] class HiveClientImpl( } private def setCurrentDatabaseRaw(db: String): Unit = { - if (databaseExists(db)) { - state.setCurrentDatabase(db) - } else { - throw new NoSuchDatabaseException(db) + if (state.getCurrentDatabase != db) { + if (databaseExists(db)) { + state.setCurrentDatabase(db) + } else { + throw new NoSuchDatabaseException(db) + } } } From f17b936f0ddb7d46d1349bd42f9a64c84c06e48d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 12 Feb 2018 21:12:22 -0800 Subject: [PATCH 0336/2461] [SPARK-23303][SQL] improve the explain result for data source v2 relations ## What changes were proposed in this pull request? The current explain result for data source v2 relation is unreadable: ``` == Parsed Logical Plan == 'Filter ('i > 6) +- AnalysisBarrier +- Project [j#1] +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Analyzed Logical Plan == j: int Project [j#1] +- Filter (i#0 > 6) +- Project [j#1, i#0] +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Optimized Logical Plan == Project [j#1] +- Filter isnotnull(i#0) +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Physical Plan == *(1) Project [j#1] +- *(1) Filter isnotnull(i#0) +- *(1) DataSourceV2Scan [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 ``` after this PR ``` == Parsed Logical Plan == 'Project [unresolvedalias('j, None)] +- AnalysisBarrier +- Relation AdvancedDataSourceV2[i#0, j#1] == Analyzed Logical Plan == j: int Project [j#1] +- Relation AdvancedDataSourceV2[i#0, j#1] == Optimized Logical Plan == Relation AdvancedDataSourceV2[j#1] == Physical Plan == *(1) Scan AdvancedDataSourceV2[j#1] ``` ------- ``` == Analyzed Logical Plan == i: int, j: int Filter (i#88 > 3) +- Relation JavaAdvancedDataSourceV2[i#88, j#89] == Optimized Logical Plan == Filter isnotnull(i#88) +- Relation JavaAdvancedDataSourceV2[i#88, j#89] (PushedFilter: [GreaterThan(i,3)]) == Physical Plan == *(1) Filter isnotnull(i#88) +- *(1) Scan JavaAdvancedDataSourceV2[i#88, j#89] (PushedFilter: [GreaterThan(i,3)]) ``` an example for streaming query ``` == Parsed Logical Plan == Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject cast(value#25 as string).toString, obj#4: java.lang.String +- Streaming Relation FakeDataSourceV2$[value#25] == Analyzed Logical Plan == value: string, count(1): bigint Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject cast(value#25 as string).toString, obj#4: java.lang.String +- Streaming Relation FakeDataSourceV2$[value#25] == Optimized Logical Plan == Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject value#25.toString, obj#4: java.lang.String +- Streaming Relation FakeDataSourceV2$[value#25] == Physical Plan == *(4) HashAggregate(keys=[value#6], functions=[count(1)], output=[value#6, count(1)#11L]) +- StateStoreSave [value#6], state info [ checkpoint = *********(redacted)/cloud/dev/spark/target/tmp/temporary-549f264b-2531-4fcb-a52f-433c77347c12/state, runId = f84d9da9-2f8c-45c1-9ea1-70791be684de, opId = 0, ver = 0, numPartitions = 5], Complete, 0 +- *(3) HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#16L]) +- StateStoreRestore [value#6], state info [ checkpoint = *********(redacted)/cloud/dev/spark/target/tmp/temporary-549f264b-2531-4fcb-a52f-433c77347c12/state, runId = f84d9da9-2f8c-45c1-9ea1-70791be684de, opId = 0, ver = 0, numPartitions = 5] +- *(2) HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#16L]) +- Exchange hashpartitioning(value#6, 5) +- *(1) HashAggregate(keys=[value#6], functions=[partial_count(1)], output=[value#6, count#16L]) +- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- *(1) MapElements , obj#5: java.lang.String +- *(1) DeserializeToObject value#25.toString, obj#4: java.lang.String +- *(1) Scan FakeDataSourceV2$[value#25] ``` ## How was this patch tested? N/A Author: Wenchen Fan Closes #20477 from cloud-fan/explain. --- .../kafka010/KafkaContinuousSourceSuite.scala | 18 +--- .../sql/kafka010/KafkaContinuousTest.scala | 3 +- .../spark/sql/kafka010/KafkaSourceSuite.scala | 3 +- .../apache/spark/sql/DataFrameReader.scala | 8 +- .../v2/DataSourceReaderHolder.scala | 64 ------------- .../v2/DataSourceV2QueryPlan.scala | 96 +++++++++++++++++++ .../datasources/v2/DataSourceV2Relation.scala | 26 +++-- .../datasources/v2/DataSourceV2ScanExec.scala | 6 +- .../datasources/v2/DataSourceV2Strategy.scala | 4 +- .../v2/PushDownOperatorsToDataSource.scala | 4 +- .../streaming/MicroBatchExecution.scala | 22 +++-- .../continuous/ContinuousExecution.scala | 9 +- .../spark/sql/streaming/StreamSuite.scala | 8 +- .../spark/sql/streaming/StreamTest.scala | 2 +- .../continuous/ContinuousSuite.scala | 11 +-- 15 files changed, 157 insertions(+), 127 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index a7083fa4e3417..72ee0c551ec3d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -17,20 +17,9 @@ package org.apache.spark.sql.kafka010 -import java.util.Properties -import java.util.concurrent.atomic.AtomicInteger - -import org.scalatest.time.SpanSugar._ -import scala.collection.mutable -import scala.util.Random - -import org.apache.spark.SparkContext -import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} +import org.apache.spark.sql.Dataset import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.streaming.{StreamTest, Trigger} -import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} +import org.apache.spark.sql.streaming.Trigger // Run tests in KafkaSourceSuiteBase in continuous execution mode. class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest @@ -71,7 +60,8 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => + r.reader.asInstanceOf[KafkaContinuousReader] }.exists { r => // Ensure the new topic is present and the old topic is gone. r.knownPartitions.exists(_.topic == topic2) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index 5a1a14f7a307a..d34458ac81014 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -47,7 +47,8 @@ trait KafkaContinuousTest extends KafkaSourceTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => + r.reader.asInstanceOf[KafkaContinuousReader] }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 02c87643568bd..cb09cce75ff6f 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -117,7 +117,8 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { - case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => + r.reader.asInstanceOf[KafkaContinuousReader] } }) if (sources.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index fcaf8d618c168..984b6510f2dbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -189,11 +189,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance() + val ds = cls.newInstance().asInstanceOf[DataSourceV2] val options = new DataSourceOptions((extraOptions ++ - DataSourceV2Utils.extractSessionConfigs( - ds = ds.asInstanceOf[DataSourceV2], - conf = sparkSession.sessionState.conf)).asJava) + DataSourceV2Utils.extractSessionConfigs(ds, sparkSession.sessionState.conf)).asJava) // Streaming also uses the data source V2 API. So it may be that the data source implements // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading @@ -221,7 +219,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { if (reader == null) { loadV1Source(paths: _*) } else { - Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + Dataset.ofRows(sparkSession, DataSourceV2Relation(ds, reader)) } } else { loadV1Source(paths: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala deleted file mode 100644 index 81219e9771bd8..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import java.util.Objects - -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.sources.v2.reader._ - -/** - * A base class for data source reader holder with customized equals/hashCode methods. - */ -trait DataSourceReaderHolder { - - /** - * The output of the data source reader, w.r.t. column pruning. - */ - def output: Seq[Attribute] - - /** - * The held data source reader. - */ - def reader: DataSourceReader - - /** - * The metadata of this data source reader that can be used for equality test. - */ - private def metadata: Seq[Any] = { - val filters: Any = reader match { - case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet - case s: SupportsPushDownFilters => s.pushedFilters().toSet - case _ => Nil - } - Seq(output, reader.getClass, filters) - } - - def canEqual(other: Any): Boolean - - override def equals(other: Any): Boolean = other match { - case other: DataSourceReaderHolder => - canEqual(other) && metadata.length == other.metadata.length && - metadata.zip(other.metadata).forall { case (l, r) => l == r } - case _ => false - } - - override def hashCode(): Int = { - metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala new file mode 100644 index 0000000000000..1e0d088f3a57c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util.Objects + +import org.apache.commons.lang3.StringUtils + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.DataSourceV2 +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.util.Utils + +/** + * A base class for data source v2 related query plan(both logical and physical). It defines the + * equals/hashCode methods, and provides a string representation of the query plan, according to + * some common information. + */ +trait DataSourceV2QueryPlan { + + /** + * The output of the data source reader, w.r.t. column pruning. + */ + def output: Seq[Attribute] + + /** + * The instance of this data source implementation. Note that we only consider its class in + * equals/hashCode, not the instance itself. + */ + def source: DataSourceV2 + + /** + * The created data source reader. Here we use it to get the filters that has been pushed down + * so far, itself doesn't take part in the equals/hashCode. + */ + def reader: DataSourceReader + + private lazy val filters = reader match { + case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet + case s: SupportsPushDownFilters => s.pushedFilters().toSet + case _ => Set.empty + } + + /** + * The metadata of this data source query plan that can be used for equality check. + */ + private def metadata: Seq[Any] = Seq(output, source.getClass, filters) + + def canEqual(other: Any): Boolean + + override def equals(other: Any): Boolean = other match { + case other: DataSourceV2QueryPlan => canEqual(other) && metadata == other.metadata + case _ => false + } + + override def hashCode(): Int = { + metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) + } + + def metadataString: String = { + val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] + if (filters.nonEmpty) entries += "PushedFilter" -> filters.mkString("[", ", ", "]") + + val outputStr = Utils.truncatedString(output, "[", ", ", "]") + + val entriesStr = if (entries.nonEmpty) { + Utils.truncatedString(entries.map { + case (key, value) => key + ": " + StringUtils.abbreviate(redact(value), 100) + }, " (", ", ", ")") + } else { + "" + } + + s"${source.getClass.getSimpleName}$outputStr$entriesStr" + } + + private def redact(text: String): String = { + Utils.redact(SQLConf.get.stringRedationPattern, text) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 38f6b15224788..cd97e0cab6b5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -20,15 +20,23 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} +import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( output: Seq[AttributeReference], - reader: DataSourceReader) - extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder { + source: DataSourceV2, + reader: DataSourceReader, + override val isStreaming: Boolean) + extends LeafNode with MultiInstanceRelation with DataSourceV2QueryPlan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] + override def simpleString: String = { + val streamingHeader = if (isStreaming) "Streaming " else "" + s"${streamingHeader}Relation $metadataString" + } + override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) @@ -41,18 +49,8 @@ case class DataSourceV2Relation( } } -/** - * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical - * to the non-streaming relation. - */ -class StreamingDataSourceV2Relation( - output: Seq[AttributeReference], - reader: DataSourceReader) extends DataSourceV2Relation(output, reader) { - override def isStreaming: Boolean = true -} - object DataSourceV2Relation { - def apply(reader: DataSourceReader): DataSourceV2Relation = { - new DataSourceV2Relation(reader.readSchema().toAttributes, reader) + def apply(source: DataSourceV2, reader: DataSourceReader): DataSourceV2Relation = { + new DataSourceV2Relation(reader.readSchema().toAttributes, source, reader, isStreaming = false) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 7d9581be4db89..c99d535efcf81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types.StructType @@ -36,11 +37,14 @@ import org.apache.spark.sql.types.StructType */ case class DataSourceV2ScanExec( output: Seq[AttributeReference], + @transient source: DataSourceV2, @transient reader: DataSourceReader) - extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { + extends LeafExecNode with DataSourceV2QueryPlan with ColumnarBatchScan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] + override def simpleString: String = s"Scan $metadataString" + override def outputPartitioning: physical.Partitioning = reader match { case s: SupportsReportPartitioning => new DataSourcePartitioning( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index df5b524485f54..fb61e6f32b1f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case DataSourceV2Relation(output, reader) => - DataSourceV2ScanExec(output, reader) :: Nil + case r: DataSourceV2Relation => + DataSourceV2ScanExec(r.output, r.source, r.reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 1ca6cbf061b4e..4cfdd50e8f46b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -39,11 +39,11 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel // TODO: Ideally column pruning should be implemented via a plan property that is propagated // top-down, then we can simplify the logic here and only collect target operators. val filterPushed = plan transformUp { - case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) => + case FilterAndProject(fields, condition, r: DataSourceV2Relation) => val (candidates, nonDeterministic) = splitConjunctivePredicates(condition).partition(_.deterministic) - val stayUpFilters: Seq[Expression] = reader match { + val stayUpFilters: Seq[Expression] = r.reader match { case r: SupportsPushDownCatalystFilters => r.pushCatalystFilters(candidates.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 812533313332e..84564b6639ac9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -27,9 +27,9 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} @@ -52,6 +52,8 @@ class MicroBatchExecution( @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty + private val readerToDataSourceMap = MutableMap.empty[MicroBatchReader, DataSourceV2] + private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) case OneTimeTrigger => OneTimeExecutor() @@ -90,6 +92,7 @@ class MicroBatchExecution( metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 + readerToDataSourceMap(reader) = source StreamingExecutionRelation(reader, output)(sparkSession) }) case s @ StreamingRelationV2(_, sourceName, _, output, v1Relation) => @@ -405,12 +408,15 @@ class MicroBatchExecution( case v1: SerializedOffset => reader.deserializeOffset(v1.json) case v2: OffsetV2 => v2 } - reader.setOffsetRange( - toJava(current), - Optional.of(availableV2)) + reader.setOffsetRange(toJava(current), Optional.of(availableV2)) logDebug(s"Retrieving data from $reader: $current -> $availableV2") - Some(reader -> - new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader)) + Some(reader -> new DataSourceV2Relation( + reader.readSchema().toAttributes, + // Provide a fake value here just in case something went wrong, e.g. the reader gives + // a wrong `equals` implementation. + readerToDataSourceMap.getOrElse(reader, FakeDataSourceV2), + reader, + isStreaming = true)) case _ => None } } @@ -500,3 +506,5 @@ class MicroBatchExecution( Optional.ofNullable(scalaOption.orNull) } } + +object FakeDataSourceV2 extends DataSourceV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index c3294d64b10cd..f87d57d0b3209 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} @@ -167,7 +167,7 @@ class ContinuousExecution( var insertedSourceId = 0 val withNewSources = logicalPlan transform { - case ContinuousExecutionRelation(_, _, output) => + case ContinuousExecutionRelation(ds, _, output) => val reader = continuousSources(insertedSourceId) insertedSourceId += 1 val newOutput = reader.readSchema().toAttributes @@ -180,7 +180,7 @@ class ContinuousExecution( val loggedOffset = offsets.offsets(0) val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) - new StreamingDataSourceV2Relation(newOutput, reader) + new DataSourceV2Relation(newOutput, ds, reader, isStreaming = true) } // Rewire the plan to use the new attributes that were returned by the source. @@ -201,7 +201,8 @@ class ContinuousExecution( val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { - case DataSourceV2Relation(_, r: ContinuousReader) => r + case r: DataSourceV2Relation if r.reader.isInstanceOf[ContinuousReader] => + r.reader.asInstanceOf[ContinuousReader] }.head reportTimeTaken("queryPlanning") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index d1a04833390f5..70eb9f0ac66d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -492,16 +492,16 @@ class StreamSuite extends StreamTest { val explainWithoutExtended = q.explainInternal(false) // `extended = false` only displays the physical plan. - assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) - assert("DataSourceV2Scan".r.findAllMatchIn(explainWithoutExtended).size === 1) + assert("Streaming Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) + assert("Scan FakeDataSourceV2".r.findAllMatchIn(explainWithoutExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithoutExtended.contains("StateStoreRestore")) val explainWithExtended = q.explainInternal(true) // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical // plan. - assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithExtended).size === 3) - assert("DataSourceV2Scan".r.findAllMatchIn(explainWithExtended).size === 1) + assert("Streaming Relation".r.findAllMatchIn(explainWithExtended).size === 3) + assert("Scan FakeDataSourceV2".r.findAllMatchIn(explainWithExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithExtended.contains("StateStoreRestore")) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 37fe595529baf..254394685857b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -605,7 +605,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be plan .collect { case StreamingExecutionRelation(s, _) => s - case DataSourceV2Relation(_, r) => r + case d: DataSourceV2Relation => d.reader } .zipWithIndex .find(_._1 == source) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 4b4ed82dc6520..9ee9aaf87f87c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.streaming.continuous -import java.util.UUID - -import org.apache.spark.{SparkContext, SparkEnv, SparkException} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql._ -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.test.TestSparkSession @@ -43,7 +40,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, r: RateStreamContinuousReader) => r + case DataSourceV2ScanExec(_, _, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 From 407f67249639709c40c46917700ed6dd736daa7d Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 13 Feb 2018 15:05:13 +0900 Subject: [PATCH 0337/2461] [SPARK-20090][FOLLOW-UP] Revert the deprecation of `names` in PySpark ## What changes were proposed in this pull request? Deprecating the field `name` in PySpark is not expected. This PR is to revert the change. ## How was this patch tested? N/A Author: gatorsmile Closes #20595 from gatorsmile/removeDeprecate. --- python/pyspark/sql/types.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index e25941cd37595..cd857402db8f7 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -455,9 +455,6 @@ class StructType(DataType): Iterating a :class:`StructType` will iterate its :class:`StructField`\\s. A contained :class:`StructField` can be accessed by name or position. - .. note:: `names` attribute is deprecated in 2.3. Use `fieldNames` method instead - to get a list of field names. - >>> struct1 = StructType([StructField("f1", StringType(), True)]) >>> struct1["f1"] StructField(f1,StringType,true) From 9dae715168a8e72e318ab231c34a1069bfa342a6 Mon Sep 17 00:00:00 2001 From: Arseniy Tashoyan Date: Tue, 13 Feb 2018 06:20:34 -0600 Subject: [PATCH 0338/2461] [SPARK-23318][ML] FP-growth: WARN FPGrowth: Input data is not cached ## What changes were proposed in this pull request? Cache the RDD of items in ml.FPGrowth before passing it to mllib.FPGrowth. Cache only when the user did not cache the input dataset of transactions. This fixes the warning about uncached data emerging from mllib.FPGrowth. ## How was this patch tested? Manually: 1. Run ml.FPGrowthExample - warning is there 2. Apply the fix 3. Run ml.FPGrowthExample again - no warning anymore Author: Arseniy Tashoyan Closes #20578 from tashoyan/SPARK-23318. --- .../scala/org/apache/spark/ml/fpm/FPGrowth.scala | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index aa7871d6ff29d..3d041fc80eb7f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -32,6 +32,7 @@ import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.storage.StorageLevel /** * Common params for FPGrowth and FPGrowthModel @@ -158,18 +159,30 @@ class FPGrowth @Since("2.2.0") ( } private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = { + val handlePersistence = dataset.storageLevel == StorageLevel.NONE + val data = dataset.select($(itemsCol)) - val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray) + val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[Any](0).toArray) val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport)) if (isSet(numPartitions)) { mllibFP.setNumPartitions($(numPartitions)) } + + if (handlePersistence) { + items.persist(StorageLevel.MEMORY_AND_DISK) + } + val parentModel = mllibFP.run(items) val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq)) val schema = StructType(Seq( StructField("items", dataset.schema($(itemsCol)).dataType, nullable = false), StructField("freq", LongType, nullable = false))) val frequentItems = dataset.sparkSession.createDataFrame(rows, schema) + + if (handlePersistence) { + items.unpersist() + } + copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) } From 300c40f50ab4258d697f06a814d1491dc875c847 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Tue, 13 Feb 2018 06:23:10 -0600 Subject: [PATCH 0339/2461] [SPARK-23384][WEB-UI] When it has no incomplete(completed) applications found, the last updated time is not formatted and client local time zone is not show in history server web ui. ## What changes were proposed in this pull request? When it has no incomplete(completed) applications found, the last updated time is not formatted and client local time zone is not show in history server web ui. It is a bug. fix before: ![1](https://user-images.githubusercontent.com/26266482/36070635-264d7cf0-0f3a-11e8-8426-14135ffedb16.png) fix after: ![2](https://user-images.githubusercontent.com/26266482/36070651-8ec3800e-0f3a-11e8-991c-6122cc9539fe.png) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #20573 from guoxiaolongzte/SPARK-23384. --- .../scala/org/apache/spark/deploy/history/HistoryPage.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 5d62a7d8bebb4..6fc12d721e6f1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -37,7 +37,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val lastUpdatedTime = parent.getLastUpdatedTime() val providerConfig = parent.getProviderConfig() val content = - + ++ +
    @@ -65,7 +66,6 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") if (allAppsSize > 0) { ++
    ++ - ++ ++ } else if (requestedIncomplete) { From 116c581d2658571d38f8b9b27a516ef517170589 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Tue, 13 Feb 2018 06:54:15 -0800 Subject: [PATCH 0340/2461] [SPARK-20659][CORE] Removing sc.getExecutorStorageStatus and making StorageStatus private MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In this PR StorageStatus is made to private and simplified a bit moreover SparkContext.getExecutorStorageStatus method is removed. The reason of keeping StorageStatus is that it is usage from SparkContext.getRDDStorageInfo. Instead of the method SparkContext.getExecutorStorageStatus executor infos are extended with additional memory metrics such as usedOnHeapStorageMemory, usedOffHeapStorageMemory, totalOnHeapStorageMemory, totalOffHeapStorageMemory. ## How was this patch tested? By running existing unit tests. Author: “attilapiros” Author: Attila Zsolt Piros <2017933+attilapiros@users.noreply.github.com> Closes #20546 from attilapiros/SPARK-20659. --- .../org/apache/spark/SparkExecutorInfo.java | 4 + .../scala/org/apache/spark/SparkContext.scala | 19 +- .../org/apache/spark/SparkStatusTracker.scala | 9 +- .../org/apache/spark/StatusAPIImpl.scala | 6 +- .../apache/spark/storage/StorageUtils.scala | 119 +--------- .../org/apache/spark/DistributedSuite.scala | 7 +- .../StandaloneDynamicAllocationSuite.scala | 2 +- .../apache/spark/storage/StorageSuite.scala | 219 ------------------ project/MimaExcludes.scala | 14 ++ .../spark/repl/SingletonReplSuite.scala | 6 +- 10 files changed, 44 insertions(+), 361 deletions(-) diff --git a/core/src/main/java/org/apache/spark/SparkExecutorInfo.java b/core/src/main/java/org/apache/spark/SparkExecutorInfo.java index dc3e826475987..2b93385adf103 100644 --- a/core/src/main/java/org/apache/spark/SparkExecutorInfo.java +++ b/core/src/main/java/org/apache/spark/SparkExecutorInfo.java @@ -30,4 +30,8 @@ public interface SparkExecutorInfo extends Serializable { int port(); long cacheSize(); int numRunningTasks(); + long usedOnHeapStorageMemory(); + long usedOffHeapStorageMemory(); + long totalOnHeapStorageMemory(); + long totalOffHeapStorageMemory(); } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3828d4f703247..c4f74c4f1f9c2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1715,7 +1715,13 @@ class SparkContext(config: SparkConf) extends Logging { private[spark] def getRDDStorageInfo(filter: RDD[_] => Boolean): Array[RDDInfo] = { assertNotStopped() val rddInfos = persistentRdds.values.filter(filter).map(RDDInfo.fromRdd).toArray - StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) + rddInfos.foreach { rddInfo => + val rddId = rddInfo.id + val rddStorageInfo = statusStore.asOption(statusStore.rdd(rddId)) + rddInfo.numCachedPartitions = rddStorageInfo.map(_.numCachedPartitions).getOrElse(0) + rddInfo.memSize = rddStorageInfo.map(_.memoryUsed).getOrElse(0L) + rddInfo.diskSize = rddStorageInfo.map(_.diskUsed).getOrElse(0L) + } rddInfos.filter(_.isCached) } @@ -1726,17 +1732,6 @@ class SparkContext(config: SparkConf) extends Logging { */ def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap - /** - * :: DeveloperApi :: - * Return information about blocks stored in all of the slaves - */ - @DeveloperApi - @deprecated("This method may change or be removed in a future release.", "2.2.0") - def getExecutorStorageStatus: Array[StorageStatus] = { - assertNotStopped() - env.blockManager.master.getStorageStatus - } - /** * :: DeveloperApi :: * Return pools for fair scheduler diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala index 70865cb58c571..815237eba0174 100644 --- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -97,7 +97,8 @@ class SparkStatusTracker private[spark] (sc: SparkContext, store: AppStatusStore } /** - * Returns information of all known executors, including host, port, cacheSize, numRunningTasks. + * Returns information of all known executors, including host, port, cacheSize, numRunningTasks + * and memory metrics. */ def getExecutorInfos: Array[SparkExecutorInfo] = { store.executorList(true).map { exec => @@ -113,7 +114,11 @@ class SparkStatusTracker private[spark] (sc: SparkContext, store: AppStatusStore host, port, cachedMem, - exec.activeTasks) + exec.activeTasks, + exec.memoryMetrics.map(_.usedOffHeapStorageMemory).getOrElse(0L), + exec.memoryMetrics.map(_.usedOnHeapStorageMemory).getOrElse(0L), + exec.memoryMetrics.map(_.totalOffHeapStorageMemory).getOrElse(0L), + exec.memoryMetrics.map(_.totalOnHeapStorageMemory).getOrElse(0L)) }.toArray } } diff --git a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala index c1f24a6377788..6a888c1e9e772 100644 --- a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala +++ b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala @@ -38,5 +38,9 @@ private class SparkExecutorInfoImpl( val host: String, val port: Int, val cacheSize: Long, - val numRunningTasks: Int) + val numRunningTasks: Int, + val usedOnHeapStorageMemory: Long, + val usedOffHeapStorageMemory: Long, + val totalOnHeapStorageMemory: Long, + val totalOffHeapStorageMemory: Long) extends SparkExecutorInfo diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index e9694fdbca2de..adc406bb1c441 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -24,19 +24,15 @@ import scala.collection.mutable import sun.nio.ch.DirectBuffer -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging /** - * :: DeveloperApi :: * Storage information for each BlockManager. * * This class assumes BlockId and BlockStatus are immutable, such that the consumers of this * class cannot mutate the source of the information. Accesses are not thread-safe. */ -@DeveloperApi -@deprecated("This class may be removed or made private in a future release.", "2.2.0") -class StorageStatus( +private[spark] class StorageStatus( val blockManagerId: BlockManagerId, val maxMemory: Long, val maxOnHeapMem: Option[Long], @@ -44,9 +40,6 @@ class StorageStatus( /** * Internal representation of the blocks stored in this block manager. - * - * We store RDD blocks and non-RDD blocks separately to allow quick retrievals of RDD blocks. - * These collections should only be mutated through the add/update/removeBlock methods. */ private val _rddBlocks = new mutable.HashMap[Int, mutable.Map[BlockId, BlockStatus]] private val _nonRddBlocks = new mutable.HashMap[BlockId, BlockStatus] @@ -87,9 +80,6 @@ class StorageStatus( */ def rddBlocks: Map[BlockId, BlockStatus] = _rddBlocks.flatMap { case (_, blocks) => blocks } - /** Return the blocks that belong to the given RDD stored in this block manager. */ - def rddBlocksById(rddId: Int): Map[BlockId, BlockStatus] = _rddBlocks.getOrElse(rddId, Map.empty) - /** Add the given block to this storage status. If it already exists, overwrite it. */ private[spark] def addBlock(blockId: BlockId, blockStatus: BlockStatus): Unit = { updateStorageInfo(blockId, blockStatus) @@ -101,46 +91,6 @@ class StorageStatus( } } - /** Update the given block in this storage status. If it doesn't already exist, add it. */ - private[spark] def updateBlock(blockId: BlockId, blockStatus: BlockStatus): Unit = { - addBlock(blockId, blockStatus) - } - - /** Remove the given block from this storage status. */ - private[spark] def removeBlock(blockId: BlockId): Option[BlockStatus] = { - updateStorageInfo(blockId, BlockStatus.empty) - blockId match { - case RDDBlockId(rddId, _) => - // Actually remove the block, if it exists - if (_rddBlocks.contains(rddId)) { - val removed = _rddBlocks(rddId).remove(blockId) - // If the given RDD has no more blocks left, remove the RDD - if (_rddBlocks(rddId).isEmpty) { - _rddBlocks.remove(rddId) - } - removed - } else { - None - } - case _ => - _nonRddBlocks.remove(blockId) - } - } - - /** - * Return whether the given block is stored in this block manager in O(1) time. - * - * @note This is much faster than `this.blocks.contains`, which is O(blocks) time. - */ - def containsBlock(blockId: BlockId): Boolean = { - blockId match { - case RDDBlockId(rddId, _) => - _rddBlocks.get(rddId).exists(_.contains(blockId)) - case _ => - _nonRddBlocks.contains(blockId) - } - } - /** * Return the given block stored in this block manager in O(1) time. * @@ -155,37 +105,12 @@ class StorageStatus( } } - /** - * Return the number of blocks stored in this block manager in O(RDDs) time. - * - * @note This is much faster than `this.blocks.size`, which is O(blocks) time. - */ - def numBlocks: Int = _nonRddBlocks.size + numRddBlocks - - /** - * Return the number of RDD blocks stored in this block manager in O(RDDs) time. - * - * @note This is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time. - */ - def numRddBlocks: Int = _rddBlocks.values.map(_.size).sum - - /** - * Return the number of blocks that belong to the given RDD in O(1) time. - * - * @note This is much faster than `this.rddBlocksById(rddId).size`, which is - * O(blocks in this RDD) time. - */ - def numRddBlocksById(rddId: Int): Int = _rddBlocks.get(rddId).map(_.size).getOrElse(0) - /** Return the max memory can be used by this block manager. */ def maxMem: Long = maxMemory /** Return the memory remaining in this block manager. */ def memRemaining: Long = maxMem - memUsed - /** Return the memory used by caching RDDs */ - def cacheSize: Long = onHeapCacheSize.getOrElse(0L) + offHeapCacheSize.getOrElse(0L) - /** Return the memory used by this block manager. */ def memUsed: Long = onHeapMemUsed.getOrElse(0L) + offHeapMemUsed.getOrElse(0L) @@ -220,15 +145,9 @@ class StorageStatus( /** Return the disk space used by this block manager. */ def diskUsed: Long = _nonRddStorageInfo.diskUsage + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum - /** Return the memory used by the given RDD in this block manager in O(1) time. */ - def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_.memoryUsage).getOrElse(0L) - /** Return the disk space used by the given RDD in this block manager in O(1) time. */ def diskUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_.diskUsage).getOrElse(0L) - /** Return the storage level, if any, used by the given RDD in this block manager. */ - def rddStorageLevel(rddId: Int): Option[StorageLevel] = _rddStorageInfo.get(rddId).map(_.level) - /** * Update the relevant storage info, taking into account any existing status for this block. */ @@ -295,40 +214,4 @@ private[spark] object StorageUtils extends Logging { cleaner.clean() } } - - /** - * Update the given list of RDDInfo with the given list of storage statuses. - * This method overwrites the old values stored in the RDDInfo's. - */ - def updateRddInfo(rddInfos: Seq[RDDInfo], statuses: Seq[StorageStatus]): Unit = { - rddInfos.foreach { rddInfo => - val rddId = rddInfo.id - // Assume all blocks belonging to the same RDD have the same storage level - val storageLevel = statuses - .flatMap(_.rddStorageLevel(rddId)).headOption.getOrElse(StorageLevel.NONE) - val numCachedPartitions = statuses.map(_.numRddBlocksById(rddId)).sum - val memSize = statuses.map(_.memUsedByRdd(rddId)).sum - val diskSize = statuses.map(_.diskUsedByRdd(rddId)).sum - - rddInfo.storageLevel = storageLevel - rddInfo.numCachedPartitions = numCachedPartitions - rddInfo.memSize = memSize - rddInfo.diskSize = diskSize - } - } - - /** - * Return a mapping from block ID to its locations for each block that belongs to the given RDD. - */ - def getRddBlockLocations(rddId: Int, statuses: Seq[StorageStatus]): Map[BlockId, Seq[String]] = { - val blockLocations = new mutable.HashMap[BlockId, mutable.ListBuffer[String]] - statuses.foreach { status => - status.rddBlocksById(rddId).foreach { case (bid, _) => - val location = status.blockManagerId.hostPort - blockLocations.getOrElseUpdate(bid, mutable.ListBuffer.empty) += location - } - } - blockLocations - } - } diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index e09d5f59817b9..28ea0c6f0bdba 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -160,11 +160,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val data = sc.parallelize(1 to 1000, 10) val cachedData = data.persist(storageLevel) assert(cachedData.count === 1000) - assert(sc.getExecutorStorageStatus.map(_.rddBlocksById(cachedData.id).size).sum === - storageLevel.replication * data.getNumPartitions) - assert(cachedData.count === 1000) - assert(cachedData.count === 1000) - + assert(sc.getRDDStorageInfo.filter(_.id == cachedData.id).map(_.numCachedPartitions).sum === + data.getNumPartitions) // Get all the locations of the first partition and try to fetch the partitions // from those locations. val blockIds = data.partitions.indices.map(index => RDDBlockId(data.id, index)).toArray diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index bf7480d79f8a1..c21ee7d26f8ca 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -610,7 +610,7 @@ class StandaloneDynamicAllocationSuite * we submit a request to kill them. This must be called before each kill request. */ private def syncExecutors(sc: SparkContext): Unit = { - val driverExecutors = sc.getExecutorStorageStatus + val driverExecutors = sc.env.blockManager.master.getStorageStatus .map(_.blockManagerId.executorId) .filter { _ != SparkContext.DRIVER_IDENTIFIER} val masterExecutors = getExecutorIds(sc) diff --git a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala index da198f946fd64..ca352387055f4 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala @@ -51,27 +51,6 @@ class StorageSuite extends SparkFunSuite { assert(status.diskUsed === 60L) } - test("storage status update non-RDD blocks") { - val status = storageStatus1 - status.updateBlock(TestBlockId("foo"), BlockStatus(memAndDisk, 50L, 100L)) - status.updateBlock(TestBlockId("fee"), BlockStatus(memAndDisk, 100L, 20L)) - assert(status.blocks.size === 3) - assert(status.memUsed === 160L) - assert(status.memRemaining === 840L) - assert(status.diskUsed === 140L) - } - - test("storage status remove non-RDD blocks") { - val status = storageStatus1 - status.removeBlock(TestBlockId("foo")) - status.removeBlock(TestBlockId("faa")) - assert(status.blocks.size === 1) - assert(status.blocks.contains(TestBlockId("fee"))) - assert(status.memUsed === 10L) - assert(status.memRemaining === 990L) - assert(status.diskUsed === 20L) - } - // For testing add, update, remove, get, and contains etc. for both RDD and non-RDD blocks private def storageStatus2: StorageStatus = { val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L, Some(1000L), Some(0L)) @@ -95,85 +74,6 @@ class StorageSuite extends SparkFunSuite { assert(status.rddBlocks.contains(RDDBlockId(2, 2))) assert(status.rddBlocks.contains(RDDBlockId(2, 3))) assert(status.rddBlocks.contains(RDDBlockId(2, 4))) - assert(status.rddBlocksById(0).size === 1) - assert(status.rddBlocksById(0).contains(RDDBlockId(0, 0))) - assert(status.rddBlocksById(1).size === 1) - assert(status.rddBlocksById(1).contains(RDDBlockId(1, 1))) - assert(status.rddBlocksById(2).size === 3) - assert(status.rddBlocksById(2).contains(RDDBlockId(2, 2))) - assert(status.rddBlocksById(2).contains(RDDBlockId(2, 3))) - assert(status.rddBlocksById(2).contains(RDDBlockId(2, 4))) - assert(status.memUsedByRdd(0) === 10L) - assert(status.memUsedByRdd(1) === 100L) - assert(status.memUsedByRdd(2) === 30L) - assert(status.diskUsedByRdd(0) === 20L) - assert(status.diskUsedByRdd(1) === 200L) - assert(status.diskUsedByRdd(2) === 80L) - assert(status.rddStorageLevel(0) === Some(memAndDisk)) - assert(status.rddStorageLevel(1) === Some(memAndDisk)) - assert(status.rddStorageLevel(2) === Some(memAndDisk)) - - // Verify default values for RDDs that don't exist - assert(status.rddBlocksById(10).isEmpty) - assert(status.memUsedByRdd(10) === 0L) - assert(status.diskUsedByRdd(10) === 0L) - assert(status.rddStorageLevel(10) === None) - } - - test("storage status update RDD blocks") { - val status = storageStatus2 - status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 5000L, 0L)) - status.updateBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 0L, 0L)) - status.updateBlock(RDDBlockId(2, 2), BlockStatus(memAndDisk, 0L, 1000L)) - assert(status.blocks.size === 7) - assert(status.rddBlocks.size === 5) - assert(status.rddBlocksById(0).size === 1) - assert(status.rddBlocksById(1).size === 1) - assert(status.rddBlocksById(2).size === 3) - assert(status.memUsedByRdd(0) === 0L) - assert(status.memUsedByRdd(1) === 100L) - assert(status.memUsedByRdd(2) === 20L) - assert(status.diskUsedByRdd(0) === 0L) - assert(status.diskUsedByRdd(1) === 200L) - assert(status.diskUsedByRdd(2) === 1060L) - } - - test("storage status remove RDD blocks") { - val status = storageStatus2 - status.removeBlock(TestBlockId("man")) - status.removeBlock(RDDBlockId(1, 1)) - status.removeBlock(RDDBlockId(2, 2)) - status.removeBlock(RDDBlockId(2, 4)) - assert(status.blocks.size === 3) - assert(status.rddBlocks.size === 2) - assert(status.rddBlocks.contains(RDDBlockId(0, 0))) - assert(status.rddBlocks.contains(RDDBlockId(2, 3))) - assert(status.rddBlocksById(0).size === 1) - assert(status.rddBlocksById(0).contains(RDDBlockId(0, 0))) - assert(status.rddBlocksById(1).size === 0) - assert(status.rddBlocksById(2).size === 1) - assert(status.rddBlocksById(2).contains(RDDBlockId(2, 3))) - assert(status.memUsedByRdd(0) === 10L) - assert(status.memUsedByRdd(1) === 0L) - assert(status.memUsedByRdd(2) === 10L) - assert(status.diskUsedByRdd(0) === 20L) - assert(status.diskUsedByRdd(1) === 0L) - assert(status.diskUsedByRdd(2) === 20L) - } - - test("storage status containsBlock") { - val status = storageStatus2 - // blocks that actually exist - assert(status.blocks.contains(TestBlockId("dan")) === status.containsBlock(TestBlockId("dan"))) - assert(status.blocks.contains(TestBlockId("man")) === status.containsBlock(TestBlockId("man"))) - assert(status.blocks.contains(RDDBlockId(0, 0)) === status.containsBlock(RDDBlockId(0, 0))) - assert(status.blocks.contains(RDDBlockId(1, 1)) === status.containsBlock(RDDBlockId(1, 1))) - assert(status.blocks.contains(RDDBlockId(2, 2)) === status.containsBlock(RDDBlockId(2, 2))) - assert(status.blocks.contains(RDDBlockId(2, 3)) === status.containsBlock(RDDBlockId(2, 3))) - assert(status.blocks.contains(RDDBlockId(2, 4)) === status.containsBlock(RDDBlockId(2, 4))) - // blocks that don't exist - assert(status.blocks.contains(TestBlockId("fan")) === status.containsBlock(TestBlockId("fan"))) - assert(status.blocks.contains(RDDBlockId(100, 0)) === status.containsBlock(RDDBlockId(100, 0))) } test("storage status getBlock") { @@ -191,40 +91,6 @@ class StorageSuite extends SparkFunSuite { assert(status.blocks.get(RDDBlockId(100, 0)) === status.getBlock(RDDBlockId(100, 0))) } - test("storage status num[Rdd]Blocks") { - val status = storageStatus2 - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - status.addBlock(TestBlockId("Foo"), BlockStatus(memAndDisk, 0L, 0L)) - status.addBlock(RDDBlockId(4, 4), BlockStatus(memAndDisk, 0L, 0L)) - status.addBlock(RDDBlockId(4, 8), BlockStatus(memAndDisk, 0L, 0L)) - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) - assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) - status.updateBlock(TestBlockId("Foo"), BlockStatus(memAndDisk, 0L, 10L)) - status.updateBlock(RDDBlockId(4, 0), BlockStatus(memAndDisk, 0L, 0L)) - status.updateBlock(RDDBlockId(4, 8), BlockStatus(memAndDisk, 0L, 0L)) - status.updateBlock(RDDBlockId(10, 10), BlockStatus(memAndDisk, 0L, 0L)) - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) - assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) - assert(status.rddBlocksById(100).size === status.numRddBlocksById(100)) - status.removeBlock(RDDBlockId(4, 0)) - status.removeBlock(RDDBlockId(10, 10)) - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) - assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) - // remove a block that doesn't exist - status.removeBlock(RDDBlockId(1000, 999)) - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) - assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) - assert(status.rddBlocksById(1000).size === status.numRddBlocksById(1000)) - } test("storage status memUsed, diskUsed, externalBlockStoreUsed") { val status = storageStatus2 @@ -237,17 +103,6 @@ class StorageSuite extends SparkFunSuite { status.addBlock(RDDBlockId(25, 25), BlockStatus(memAndDisk, 40L, 50L)) assert(status.memUsed === actualMemUsed) assert(status.diskUsed === actualDiskUsed) - status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 4L, 5L)) - status.updateBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 4L, 5L)) - status.updateBlock(RDDBlockId(1, 1), BlockStatus(memAndDisk, 4L, 5L)) - assert(status.memUsed === actualMemUsed) - assert(status.diskUsed === actualDiskUsed) - status.removeBlock(TestBlockId("fire")) - status.removeBlock(TestBlockId("man")) - status.removeBlock(RDDBlockId(2, 2)) - status.removeBlock(RDDBlockId(2, 3)) - assert(status.memUsed === actualMemUsed) - assert(status.diskUsed === actualDiskUsed) } // For testing StorageUtils.updateRddInfo and StorageUtils.getRddBlockLocations @@ -273,65 +128,6 @@ class StorageSuite extends SparkFunSuite { Seq(info0, info1) } - test("StorageUtils.updateRddInfo") { - val storageStatuses = stockStorageStatuses - val rddInfos = stockRDDInfos - StorageUtils.updateRddInfo(rddInfos, storageStatuses) - assert(rddInfos(0).storageLevel === memAndDisk) - assert(rddInfos(0).numCachedPartitions === 5) - assert(rddInfos(0).memSize === 5L) - assert(rddInfos(0).diskSize === 10L) - assert(rddInfos(0).externalBlockStoreSize === 0L) - assert(rddInfos(1).storageLevel === memAndDisk) - assert(rddInfos(1).numCachedPartitions === 3) - assert(rddInfos(1).memSize === 3L) - assert(rddInfos(1).diskSize === 6L) - assert(rddInfos(1).externalBlockStoreSize === 0L) - } - - test("StorageUtils.getRddBlockLocations") { - val storageStatuses = stockStorageStatuses - val blockLocations0 = StorageUtils.getRddBlockLocations(0, storageStatuses) - val blockLocations1 = StorageUtils.getRddBlockLocations(1, storageStatuses) - assert(blockLocations0.size === 5) - assert(blockLocations1.size === 3) - assert(blockLocations0.contains(RDDBlockId(0, 0))) - assert(blockLocations0.contains(RDDBlockId(0, 1))) - assert(blockLocations0.contains(RDDBlockId(0, 2))) - assert(blockLocations0.contains(RDDBlockId(0, 3))) - assert(blockLocations0.contains(RDDBlockId(0, 4))) - assert(blockLocations1.contains(RDDBlockId(1, 0))) - assert(blockLocations1.contains(RDDBlockId(1, 1))) - assert(blockLocations1.contains(RDDBlockId(1, 2))) - assert(blockLocations0(RDDBlockId(0, 0)) === Seq("dog:1")) - assert(blockLocations0(RDDBlockId(0, 1)) === Seq("dog:1")) - assert(blockLocations0(RDDBlockId(0, 2)) === Seq("duck:2")) - assert(blockLocations0(RDDBlockId(0, 3)) === Seq("duck:2")) - assert(blockLocations0(RDDBlockId(0, 4)) === Seq("cat:3")) - assert(blockLocations1(RDDBlockId(1, 0)) === Seq("duck:2")) - assert(blockLocations1(RDDBlockId(1, 1)) === Seq("duck:2")) - assert(blockLocations1(RDDBlockId(1, 2)) === Seq("cat:3")) - } - - test("StorageUtils.getRddBlockLocations with multiple locations") { - val storageStatuses = stockStorageStatuses - storageStatuses(0).addBlock(RDDBlockId(1, 0), BlockStatus(memAndDisk, 1L, 2L)) - storageStatuses(0).addBlock(RDDBlockId(0, 4), BlockStatus(memAndDisk, 1L, 2L)) - storageStatuses(2).addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 1L, 2L)) - val blockLocations0 = StorageUtils.getRddBlockLocations(0, storageStatuses) - val blockLocations1 = StorageUtils.getRddBlockLocations(1, storageStatuses) - assert(blockLocations0.size === 5) - assert(blockLocations1.size === 3) - assert(blockLocations0(RDDBlockId(0, 0)) === Seq("dog:1", "cat:3")) - assert(blockLocations0(RDDBlockId(0, 1)) === Seq("dog:1")) - assert(blockLocations0(RDDBlockId(0, 2)) === Seq("duck:2")) - assert(blockLocations0(RDDBlockId(0, 3)) === Seq("duck:2")) - assert(blockLocations0(RDDBlockId(0, 4)) === Seq("dog:1", "cat:3")) - assert(blockLocations1(RDDBlockId(1, 0)) === Seq("dog:1", "duck:2")) - assert(blockLocations1(RDDBlockId(1, 1)) === Seq("duck:2")) - assert(blockLocations1(RDDBlockId(1, 2)) === Seq("cat:3")) - } - private val offheap = StorageLevel.OFF_HEAP // For testing add, update, remove, get, and contains etc. for both RDD and non-RDD onheap // and offheap blocks @@ -373,21 +169,6 @@ class StorageSuite extends SparkFunSuite { status.addBlock(RDDBlockId(25, 25), BlockStatus(memAndDisk, 40L, 50L)) assert(status.memUsed === actualMemUsed) assert(status.diskUsed === actualDiskUsed) - - status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 4L, 5L)) - status.updateBlock(RDDBlockId(0, 0), BlockStatus(offheap, 4L, 0L)) - status.updateBlock(RDDBlockId(1, 1), BlockStatus(offheap, 4L, 0L)) - assert(status.memUsed === actualMemUsed) - assert(status.diskUsed === actualDiskUsed) - assert(status.onHeapMemUsed.get === actualOnHeapMemUsed) - assert(status.offHeapMemUsed.get === actualOffHeapMemUsed) - - status.removeBlock(TestBlockId("fire")) - status.removeBlock(TestBlockId("man")) - status.removeBlock(RDDBlockId(2, 2)) - status.removeBlock(RDDBlockId(2, 3)) - assert(status.memUsed === actualMemUsed) - assert(status.diskUsed === actualDiskUsed) } private def storageStatus4: StorageStatus = { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d35c50e1d00fe..381f7b5be1ddf 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,20 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-20659] Remove StorageStatus, or make it private + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOffHeapStorageMemory"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOffHeapStorageMemory"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOnHeapStorageMemory"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOnHeapStorageMemory"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.getExecutorStorageStatus"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numBlocks"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numRddBlocks"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.containsBlock"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddBlocksById"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numRddBlocksById"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.memUsedByRdd"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.cacheSize"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel") ) // Exclude rules for 2.3.x diff --git a/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala index ec3d790255ad3..d49e0fd85229f 100644 --- a/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala @@ -350,7 +350,7 @@ class SingletonReplSuite extends SparkFunSuite { """ |val timeout = 60000 // 60 seconds |val start = System.currentTimeMillis - |while(sc.getExecutorStorageStatus.size != 3 && + |while(sc.statusTracker.getExecutorInfos.size != 3 && | (System.currentTimeMillis - start) < timeout) { | Thread.sleep(10) |} @@ -361,11 +361,11 @@ class SingletonReplSuite extends SparkFunSuite { |case class Foo(i: Int) |val ret = sc.parallelize((1 to 100).map(Foo), 10).persist(MEMORY_AND_DISK_2) |ret.count() - |val res = sc.getExecutorStorageStatus.map(s => s.rddBlocksById(ret.id).size).sum + |val res = sc.getRDDStorageInfo.filter(_.id == ret.id).map(_.numCachedPartitions).sum """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) - assertContains("res: Int = 20", output) + assertContains("res: Int = 10", output) } test("should clone and clean line object in ClosureCleaner") { From d6e1958a2472898e60bd013902c2f35111596e40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Tue, 13 Feb 2018 09:54:52 -0600 Subject: [PATCH 0341/2461] [SPARK-23189][CORE][WEB UI] Reflect stage level blacklisting on executor tab MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? The purpose of this PR to reflect the stage level blacklisting on the executor tab for the currently active stages. After this change in the executor tab at the Status column one of the following label will be: - "Blacklisted" when the executor is blacklisted application level (old flag) - "Dead" when the executor is not Blacklisted and not Active - "Blacklisted in Stages: [...]" when the executor is Active but the there are active blacklisted stages for the executor. Within the [] coma separated active stageIDs are listed. - "Active" when the executor is Active and there is no active blacklisted stages for the executor ## How was this patch tested? Both with unit tests and manually. #### Manual test Spark was started as: ```bash bin/spark-shell --master "local-cluster[2,1,1024]" --conf "spark.blacklist.enabled=true" --conf "spark.blacklist.stage.maxFailedTasksPerExecutor=1" --conf "spark.blacklist.application.maxFailedTasksPerExecutor=10" ``` And the job was: ```scala import org.apache.spark.SparkEnv val pairs = sc.parallelize(1 to 10000, 10).map { x => if (SparkEnv.get.executorId.toInt == 0) throw new RuntimeException("Bad executor") else { Thread.sleep(10) (x % 10, x) } } val all = pairs.cogroup(pairs) all.collect() ``` UI screenshots about the running: - One executor is blacklisted in the two stages: ![One executor is blacklisted in two stages](https://issues.apache.org/jira/secure/attachment/12908314/multiple_stages_1.png) - One stage completes the other one is still running: ![One stage completes the other is still running](https://issues.apache.org/jira/secure/attachment/12908315/multiple_stages_2.png) - Both stages are completed: ![Both stages are completed](https://issues.apache.org/jira/secure/attachment/12908316/multiple_stages_3.png) ### Unit tests In AppStatusListenerSuite.scala both the node blacklisting for a stage and the executor blacklisting for stage are tested. Author: “attilapiros” Closes #20408 from attilapiros/SPARK-23189. --- .../apache/spark/ui/static/executorspage.js | 21 +++++--- .../spark/status/AppStatusListener.scala | 49 ++++++++++++++----- .../org/apache/spark/status/LiveEntity.scala | 7 ++- .../org/apache/spark/status/api/v1/api.scala | 3 +- .../executor_list_json_expectation.json | 3 +- .../executor_memory_usage_expectation.json | 15 ++++-- ...xecutor_node_blacklisting_expectation.json | 15 ++++-- ...acklisting_unblacklisting_expectation.json | 15 ++++-- .../spark/status/AppStatusListenerSuite.scala | 21 ++++++++ 9 files changed, 113 insertions(+), 36 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js index d430d8c5fb35a..6717af3ac4daf 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js @@ -25,12 +25,18 @@ function getThreadDumpEnabled() { return threadDumpEnabled; } -function formatStatus(status, type) { +function formatStatus(status, type, row) { + if (row.isBlacklisted) { + return "Blacklisted"; + } + if (status) { - return "Active" - } else { - return "Dead" + if (row.blacklistedInStages.length == 0) { + return "Active" + } + return "Active (Blacklisted in Stages: [" + row.blacklistedInStages.join(", ") + "])"; } + return "Dead" } jQuery.extend(jQuery.fn.dataTableExt.oSort, { @@ -415,9 +421,10 @@ $(document).ready(function () { } }, {data: 'hostPort'}, - {data: 'isActive', render: function (data, type, row) { - if (row.isBlacklisted) return "Blacklisted"; - else return formatStatus (data, type); + { + data: 'isActive', + render: function (data, type, row) { + return formatStatus (data, type, row); } }, {data: 'rddBlocks'}, diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index ab01cddfca5b0..79a17e26665fd 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -213,11 +213,13 @@ private[spark] class AppStatusListener( override def onExecutorBlacklistedForStage( event: SparkListenerExecutorBlacklistedForStage): Unit = { + val now = System.nanoTime() + Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => - val now = System.nanoTime() - val esummary = stage.executorSummary(event.executorId) - esummary.isBlacklisted = true - maybeUpdate(esummary, now) + setStageBlackListStatus(stage, now, event.executorId) + } + liveExecutors.get(event.executorId).foreach { exec => + addBlackListedStageTo(exec, event.stageId, now) } } @@ -226,16 +228,29 @@ private[spark] class AppStatusListener( // Implicitly blacklist every available executor for the stage associated with this node Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => - liveExecutors.values.foreach { exec => - if (exec.hostname == event.hostId) { - val esummary = stage.executorSummary(exec.executorId) - esummary.isBlacklisted = true - maybeUpdate(esummary, now) - } - } + val executorIds = liveExecutors.values.filter(_.host == event.hostId).map(_.executorId).toSeq + setStageBlackListStatus(stage, now, executorIds: _*) + } + liveExecutors.values.filter(_.hostname == event.hostId).foreach { exec => + addBlackListedStageTo(exec, event.stageId, now) } } + private def addBlackListedStageTo(exec: LiveExecutor, stageId: Int, now: Long): Unit = { + exec.blacklistedInStages += stageId + liveUpdate(exec, now) + } + + private def setStageBlackListStatus(stage: LiveStage, now: Long, executorIds: String*): Unit = { + executorIds.foreach { executorId => + val executorStageSummary = stage.executorSummary(executorId) + executorStageSummary.isBlacklisted = true + maybeUpdate(executorStageSummary, now) + } + stage.blackListedExecutors ++= executorIds + maybeUpdate(stage, now) + } + override def onExecutorUnblacklisted(event: SparkListenerExecutorUnblacklisted): Unit = { updateBlackListStatus(event.executorId, false) } @@ -594,12 +609,24 @@ private[spark] class AppStatusListener( stage.executorSummaries.values.foreach(update(_, now)) update(stage, now, last = true) + + val executorIdsForStage = stage.blackListedExecutors + executorIdsForStage.foreach { executorId => + liveExecutors.get(executorId).foreach { exec => + removeBlackListedStageFrom(exec, event.stageInfo.stageId, now) + } + } } appSummary = new AppSummary(appSummary.numCompletedJobs, appSummary.numCompletedStages + 1) kvstore.write(appSummary) } + private def removeBlackListedStageFrom(exec: LiveExecutor, stageId: Int, now: Long) = { + exec.blacklistedInStages -= stageId + liveUpdate(exec, now) + } + override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded): Unit = { // This needs to set fields that are already set by onExecutorAdded because the driver is // considered an "executor" in the UI, but does not have a SparkListenerExecutorAdded event. diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index d5f9e19ffdcd0..79e3f13b826ce 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -20,6 +20,7 @@ package org.apache.spark.status import java.util.Date import java.util.concurrent.atomic.AtomicInteger +import scala.collection.immutable.{HashSet, TreeSet} import scala.collection.mutable.HashMap import com.google.common.collect.Interners @@ -254,6 +255,7 @@ private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveE var totalShuffleRead = 0L var totalShuffleWrite = 0L var isBlacklisted = false + var blacklistedInStages: Set[Int] = TreeSet() var executorLogs = Map[String, String]() @@ -299,7 +301,8 @@ private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveE Option(removeTime), Option(removeReason), executorLogs, - memoryMetrics) + memoryMetrics, + blacklistedInStages) new ExecutorSummaryWrapper(info) } @@ -371,6 +374,8 @@ private class LiveStage extends LiveEntity { val executorSummaries = new HashMap[String, LiveExecutorStageSummary]() + var blackListedExecutors = new HashSet[String]() + // Used for cleanup of tasks after they reach the configured limit. Not written to the store. @volatile var cleaning = false var savedTasks = new AtomicInteger(0) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 550eac3952bbb..a333f1aaf6325 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -95,7 +95,8 @@ class ExecutorSummary private[spark]( val removeTime: Option[Date], val removeReason: Option[String], val executorLogs: Map[String, String], - val memoryMetrics: Option[MemoryMetrics]) + val memoryMetrics: Option[MemoryMetrics], + val blacklistedInStages: Set[Int]) class MemoryMetrics private[spark]( val usedOnHeapStorageMemory: Long, diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json index 942e6d8f04363..7bb8fe8fd8f98 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json @@ -19,5 +19,6 @@ "isBlacklisted" : false, "maxMemory" : 278302556, "addTime" : "2015-02-03T16:43:00.906GMT", - "executorLogs" : { } + "executorLogs" : { }, + "blacklistedInStages" : [ ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json index ed33c90dd39ba..dd5b1dcb7372b 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json @@ -25,7 +25,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "3", "hostPort" : "172.22.0.167:51485", @@ -56,7 +57,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] } ,{ "id" : "2", "hostPort" : "172.22.0.167:51487", @@ -87,7 +89,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -118,7 +121,8 @@ "usedOffHeapStorageMemory": 0, "totalOnHeapStorageMemory": 384093388, "totalOffHeapStorageMemory": 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "0", "hostPort" : "172.22.0.167:51491", @@ -149,5 +153,6 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json index 73519f1d9e2e4..3e55d3d9d7eb9 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json @@ -25,7 +25,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "3", "hostPort" : "172.22.0.167:51485", @@ -56,7 +57,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "2", "hostPort" : "172.22.0.167:51487", @@ -87,7 +89,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -118,7 +121,8 @@ "usedOffHeapStorageMemory": 0, "totalOnHeapStorageMemory": 384093388, "totalOffHeapStorageMemory": 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "0", "hostPort" : "172.22.0.167:51491", @@ -149,5 +153,6 @@ "usedOffHeapStorageMemory": 0, "totalOnHeapStorageMemory": 384093388, "totalOffHeapStorageMemory": 524288000 - } + }, + "blacklistedInStages" : [ ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json index 6931fead3d2ff..e87f3e78f2dc8 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json @@ -19,7 +19,8 @@ "isBlacklisted" : false, "maxMemory" : 384093388, "addTime" : "2016-11-15T23:20:38.836GMT", - "executorLogs" : { } + "executorLogs" : { }, + "blacklistedInStages" : [ ] }, { "id" : "3", "hostPort" : "172.22.0.111:64543", @@ -44,7 +45,8 @@ "executorLogs" : { "stdout" : "http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stdout", "stderr" : "http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stderr" - } + }, + "blacklistedInStages" : [ ] }, { "id" : "2", "hostPort" : "172.22.0.111:64539", @@ -69,7 +71,8 @@ "executorLogs" : { "stdout" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stdout", "stderr" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stderr" - } + }, + "blacklistedInStages" : [ ] }, { "id" : "1", "hostPort" : "172.22.0.111:64541", @@ -94,7 +97,8 @@ "executorLogs" : { "stdout" : "http://172.22.0.111:64518/logPage/?appId=app-20161115172038-0000&executorId=1&logType=stdout", "stderr" : "http://172.22.0.111:64518/logPage/?appId=app-20161115172038-0000&executorId=1&logType=stderr" - } + }, + "blacklistedInStages" : [ ] }, { "id" : "0", "hostPort" : "172.22.0.111:64540", @@ -119,5 +123,6 @@ "executorLogs" : { "stdout" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stdout", "stderr" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stderr" - } + }, + "blacklistedInStages" : [ ] } ] diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index b74d6ee2ec836..749502709b5c8 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -273,6 +273,10 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(exec.info.isBlacklistedForStage === expectedBlacklistedFlag) } + check[ExecutorSummaryWrapper](execIds.head) { exec => + assert(exec.info.blacklistedInStages === Set(stages.head.stageId)) + } + // Blacklisting node for stage time += 1 listener.onNodeBlacklistedForStage(SparkListenerNodeBlacklistedForStage( @@ -439,6 +443,10 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(stage.info.numCompleteTasks === pending.size) } + check[ExecutorSummaryWrapper](execIds.head) { exec => + assert(exec.info.blacklistedInStages === Set()) + } + // Submit stage 2. time += 1 stages.last.submissionTime = Some(time) @@ -453,6 +461,19 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(stage.info.submissionTime === Some(new Date(stages.last.submissionTime.get))) } + // Blacklisting node for stage + time += 1 + listener.onNodeBlacklistedForStage(SparkListenerNodeBlacklistedForStage( + time = time, + hostId = "1.example.com", + executorFailures = 1, + stageId = stages.last.stageId, + stageAttemptId = stages.last.attemptId)) + + check[ExecutorSummaryWrapper](execIds.head) { exec => + assert(exec.info.blacklistedInStages === Set(stages.last.stageId)) + } + // Start and fail all tasks of stage 2. time += 1 val s2Tasks = createTasks(4, execIds) From 091a000d27f324de8c5c527880854ecfcf5de9a4 Mon Sep 17 00:00:00 2001 From: huangtengfei Date: Tue, 13 Feb 2018 09:59:21 -0600 Subject: [PATCH 0342/2461] [SPARK-23053][CORE] taskBinarySerialization and task partitions calculate in DagScheduler.submitMissingTasks should keep the same RDD checkpoint status ## What changes were proposed in this pull request? When we run concurrent jobs using the same rdd which is marked to do checkpoint. If one job has finished running the job, and start the process of RDD.doCheckpoint, while another job is submitted, then submitStage and submitMissingTasks will be called. In [submitMissingTasks](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L961), will serialize taskBinaryBytes and calculate task partitions which are both affected by the status of checkpoint, if the former is calculated before doCheckpoint finished, while the latter is calculated after doCheckpoint finished, when run task, rdd.compute will be called, for some rdds with particular partition type such as [UnionRDD](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala) who will do partition type cast, will get a ClassCastException because the part params is actually a CheckpointRDDPartition. This error occurs because rdd.doCheckpoint occurs in the same thread that called sc.runJob, while the task serialization occurs in the DAGSchedulers event loop. ## How was this patch tested? the exist uts and also add a test case in DAGScheduerSuite to show the exception case. Author: huangtengfei Closes #20244 from ivoson/branch-taskpart-mistype. --- .../apache/spark/scheduler/DAGScheduler.scala | 27 ++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 199937b8c27af..8c46a84323392 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -39,7 +39,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{RDD, RDDCheckpointData} import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -1016,15 +1016,24 @@ class DAGScheduler( // might modify state of objects referenced in their closures. This is necessary in Hadoop // where the JobConf/Configuration object is not thread-safe. var taskBinary: Broadcast[Array[Byte]] = null + var partitions: Array[Partition] = null try { // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep). // For ResultTask, serialize and broadcast (rdd, func). - val taskBinaryBytes: Array[Byte] = stage match { - case stage: ShuffleMapStage => - JavaUtils.bufferToArray( - closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef)) - case stage: ResultStage => - JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef)) + var taskBinaryBytes: Array[Byte] = null + // taskBinaryBytes and partitions are both effected by the checkpoint status. We need + // this synchronization in case another concurrent job is checkpointing this RDD, so we get a + // consistent view of both variables. + RDDCheckpointData.synchronized { + taskBinaryBytes = stage match { + case stage: ShuffleMapStage => + JavaUtils.bufferToArray( + closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef)) + case stage: ResultStage => + JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef)) + } + + partitions = stage.rdd.partitions } taskBinary = sc.broadcast(taskBinaryBytes) @@ -1049,7 +1058,7 @@ class DAGScheduler( stage.pendingPartitions.clear() partitionsToCompute.map { id => val locs = taskIdToLocations(id) - val part = stage.rdd.partitions(id) + val part = partitions(id) stage.pendingPartitions += id new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId), @@ -1059,7 +1068,7 @@ class DAGScheduler( case stage: ResultStage => partitionsToCompute.map { id => val p: Int = stage.partitions(id) - val part = stage.rdd.partitions(p) + val part = partitions(p) val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, id, properties, serializedTaskMetrics, From bd24731722a9142c90cf3d76008115f308203844 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Tue, 13 Feb 2018 11:39:33 -0600 Subject: [PATCH 0343/2461] [SPARK-23382][WEB-UI] Spark Streaming ui about the contents of the for need to have hidden and show features, when the table records very much. ## What changes were proposed in this pull request? Spark Streaming ui about the contents of the for need to have hidden and show features, when the table records very much. please refer to https://github.com/apache/spark/pull/20216 fix after: ![1](https://user-images.githubusercontent.com/26266482/36068644-df029328-0f14-11e8-8350-cfdde9733ffc.png) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #20570 from guoxiaolongzte/SPARK-23382. --- .../org/apache/spark/ui/static/webui.js | 2 + .../spark/streaming/ui/StreamingPage.scala | 37 ++++++++++++++++--- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index e575c4c78970d..83009df91d30a 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -80,4 +80,6 @@ $(function() { collapseTablePageLoad('collapse-aggregated-poolActiveStages','aggregated-poolActiveStages'); collapseTablePageLoad('collapse-aggregated-tasks','aggregated-tasks'); collapseTablePageLoad('collapse-aggregated-rdds','aggregated-rdds'); + collapseTablePageLoad('collapse-aggregated-activeBatches','aggregated-activeBatches'); + collapseTablePageLoad('collapse-aggregated-completedBatches','aggregated-completedBatches'); }); \ No newline at end of file diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 7abafd6ba7908..3a176f64cdd60 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -490,15 +490,40 @@ private[ui] class StreamingPage(parent: StreamingTab) sortBy(_.batchTime.milliseconds).reverse val activeBatchesContent = { -

    Active Batches ({runningBatches.size + waitingBatches.size})

    ++ - new ActiveBatchTable(runningBatches, waitingBatches, listener.batchDuration).toNodeSeq +
    +
    + +

    + + Active Batches ({runningBatches.size + waitingBatches.size}) +

    +
    +
    + {new ActiveBatchTable(runningBatches, waitingBatches, listener.batchDuration).toNodeSeq} +
    +
    +
    } val completedBatchesContent = { -

    - Completed Batches (last {completedBatches.size} out of {listener.numTotalCompletedBatches}) -

    ++ - new CompletedBatchTable(completedBatches, listener.batchDuration).toNodeSeq +
    +
    + +

    + + Completed Batches (last {completedBatches.size} + out of {listener.numTotalCompletedBatches}) +

    +
    +
    + {new CompletedBatchTable(completedBatches, listener.batchDuration).toNodeSeq} +
    +
    +
    } activeBatchesContent ++ completedBatchesContent From 263531466f4a7e223c94caa8705e6e8394a12054 Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Tue, 13 Feb 2018 11:45:20 -0600 Subject: [PATCH 0344/2461] [SPARK-23392][TEST] Add some test cases for images feature ## What changes were proposed in this pull request? Add some test cases for images feature ## How was this patch tested? Add some test cases in ImageSchemaSuite Author: xubo245 <601450868@qq.com> Closes #20583 from xubo245/CARBONDATA23392_AddTestForImage. --- .../spark/ml/image/ImageSchemaSuite.scala | 62 ++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala index a8833c615865d..527b3f8955968 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala @@ -65,11 +65,71 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(count50 > 0 && count50 < countTotal) } + test("readImages test: recursive = false") { + val df = readImages(imagePath, null, false, 3, true, 1.0, 0) + assert(df.count() === 0) + } + + test("readImages test: read jpg image") { + val df = readImages(imagePath + "/kittens/DP153539.jpg", null, false, 3, true, 1.0, 0) + assert(df.count() === 1) + } + + test("readImages test: read png image") { + val df = readImages(imagePath + "/multi-channel/BGRA.png", null, false, 3, true, 1.0, 0) + assert(df.count() === 1) + } + + test("readImages test: read non image") { + val df = readImages(imagePath + "/kittens/not-image.txt", null, false, 3, true, 1.0, 0) + assert(df.schema("image").dataType == columnSchema, "data do not fit ImageSchema") + assert(df.count() === 0) + } + + test("readImages test: read non image and dropImageFailures is false") { + val df = readImages(imagePath + "/kittens/not-image.txt", null, false, 3, false, 1.0, 0) + assert(df.count() === 1) + } + + test("readImages test: sampleRatio > 1") { + val e = intercept[IllegalArgumentException] { + readImages(imagePath, null, true, 3, true, 1.1, 0) + } + assert(e.getMessage.contains("sampleRatio")) + } + + test("readImages test: sampleRatio < 0") { + val e = intercept[IllegalArgumentException] { + readImages(imagePath, null, true, 3, true, -0.1, 0) + } + assert(e.getMessage.contains("sampleRatio")) + } + + test("readImages test: sampleRatio = 0") { + val df = readImages(imagePath, null, true, 3, true, 0.0, 0) + assert(df.count() === 0) + } + + test("readImages test: with sparkSession") { + val df = readImages(imagePath, sparkSession = spark, true, 3, true, 1.0, 0) + assert(df.count() === 8) + } + test("readImages partition test") { val df = readImages(imagePath, null, true, 3, true, 1.0, 0) assert(df.rdd.getNumPartitions === 3) } + test("readImages partition test: < 0") { + val df = readImages(imagePath, null, true, -3, true, 1.0, 0) + assert(df.rdd.getNumPartitions === spark.sparkContext.defaultParallelism) + } + + test("readImages partition test: = 0") { + val df = readImages(imagePath, null, true, 0, true, 1.0, 0) + assert(df.rdd.getNumPartitions === spark.sparkContext.defaultParallelism) + } + // Images with the different number of channels test("readImages pixel values test") { @@ -93,7 +153,7 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { // - default representation for 3-channel RGB images is BGR row-wise: // (B00, G00, R00, B10, G10, R10, ...) // - default representation for 4-channel RGB images is BGRA row-wise: - // (B00, G00, R00, A00, B10, G10, R10, A00, ...) + // (B00, G00, R00, A00, B10, G10, R10, A10, ...) private val firstBytes20 = Map( "grayscale.jpg" -> (("CV_8UC1", Array[Byte](-2, -33, -61, -60, -59, -59, -64, -59, -66, -67, -73, -73, -62, From 05d051293fe46938e9cb012342fea6e8a3715cd4 Mon Sep 17 00:00:00 2001 From: Bogdan Raducanu Date: Tue, 13 Feb 2018 09:49:52 -0800 Subject: [PATCH 0345/2461] [SPARK-23316][SQL] AnalysisException after max iteration reached for IN query ## What changes were proposed in this pull request? Added flag ignoreNullability to DataType.equalsStructurally. The previous semantic is for ignoreNullability=false. When ignoreNullability=true equalsStructurally ignores nullability of contained types (map key types, value types, array element types, structure field types). In.checkInputTypes calls equalsStructurally to check if the children types match. They should match regardless of nullability (which is just a hint), so it is now called with ignoreNullability=true. ## How was this patch tested? New test in SubquerySuite Author: Bogdan Raducanu Closes #20548 from bogdanrdc/SPARK-23316. --- .../sql/catalyst/expressions/predicates.scala | 3 ++- .../org/apache/spark/sql/types/DataType.scala | 18 ++++++++++++------ .../org/apache/spark/sql/SubquerySuite.scala | 5 +++++ 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index b469f5cb7586a..a6d41ea7d00d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -157,7 +157,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { require(list != null, "list should not be null") override def checkInputDataTypes(): TypeCheckResult = { - val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType)) + val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType, + ignoreNullability = true)) if (mismatchOpt.isDefined) { list match { case ListQuery(_, _, _, childOutputs) :: Nil => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index d6e0df12218ad..0bef11659fc9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -295,25 +295,31 @@ object DataType { } /** - * Returns true if the two data types share the same "shape", i.e. the types (including - * nullability) are the same, but the field names don't need to be the same. + * Returns true if the two data types share the same "shape", i.e. the types + * are the same, but the field names don't need to be the same. + * + * @param ignoreNullability whether to ignore nullability when comparing the types */ - def equalsStructurally(from: DataType, to: DataType): Boolean = { + def equalsStructurally( + from: DataType, + to: DataType, + ignoreNullability: Boolean = false): Boolean = { (from, to) match { case (left: ArrayType, right: ArrayType) => equalsStructurally(left.elementType, right.elementType) && - left.containsNull == right.containsNull + (ignoreNullability || left.containsNull == right.containsNull) case (left: MapType, right: MapType) => equalsStructurally(left.keyType, right.keyType) && equalsStructurally(left.valueType, right.valueType) && - left.valueContainsNull == right.valueContainsNull + (ignoreNullability || left.valueContainsNull == right.valueContainsNull) case (StructType(fromFields), StructType(toFields)) => fromFields.length == toFields.length && fromFields.zip(toFields) .forall { case (l, r) => - equalsStructurally(l.dataType, r.dataType) && l.nullable == r.nullable + equalsStructurally(l.dataType, r.dataType) && + (ignoreNullability || l.nullable == r.nullable) } case (fromDataType, toDataType) => fromDataType == toDataType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 8673dc14f7597..31e8b0e8dede0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -950,4 +950,9 @@ class SubquerySuite extends QueryTest with SharedSQLContext { assert(join.duplicateResolved) assert(optimizedPlan.resolved) } + + test("SPARK-23316: AnalysisException after max iteration reached for IN query") { + // before the fix this would throw AnalysisException + spark.range(10).where("(id,id) in (select id, null from range(3))").count + } } From 4e0fb010ccdf13fe411f2a4796bbadc385b01520 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 13 Feb 2018 11:51:19 -0600 Subject: [PATCH 0346/2461] [SPARK-23217][ML] Add cosine distance measure to ClusteringEvaluator ## What changes were proposed in this pull request? The PR provided an implementation of ClusteringEvaluator using the cosine distance measure. This allows to evaluate clustering results created using the cosine distance, introduced in SPARK-22119. In the corresponding JIRA, there is a design document for the algorithm implemented here. ## How was this patch tested? Added UT which compares the result to the one provided by python sklearn. Author: Marco Gaido Closes #20396 from mgaido91/SPARK-23217. --- .../ml/evaluation/ClusteringEvaluator.scala | 334 ++++++++++++++---- .../evaluation/ClusteringEvaluatorSuite.scala | 32 +- 2 files changed, 300 insertions(+), 66 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala index d6ec5223237bb..8d4ae562b3d2b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala @@ -20,11 +20,12 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors, VectorUDT} +import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} -import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, + SchemaUtils} +import org.apache.spark.sql.{Column, DataFrame, Dataset} import org.apache.spark.sql.functions.{avg, col, udf} import org.apache.spark.sql.types.DoubleType @@ -32,15 +33,11 @@ import org.apache.spark.sql.types.DoubleType * :: Experimental :: * * Evaluator for clustering results. - * The metric computes the Silhouette measure - * using the squared Euclidean distance. - * - * The Silhouette is a measure for the validation - * of the consistency within clusters. It ranges - * between 1 and -1, where a value close to 1 - * means that the points in a cluster are close - * to the other points in the same cluster and - * far from the points of the other clusters. + * The metric computes the Silhouette measure using the specified distance measure. + * + * The Silhouette is a measure for the validation of the consistency within clusters. It ranges + * between 1 and -1, where a value close to 1 means that the points in a cluster are close to the + * other points in the same cluster and far from the points of the other clusters. */ @Experimental @Since("2.3.0") @@ -84,18 +81,40 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str @Since("2.3.0") def setMetricName(value: String): this.type = set(metricName, value) - setDefault(metricName -> "silhouette") + /** + * param for distance measure to be used in evaluation + * (supports `"squaredEuclidean"` (default), `"cosine"`) + * @group param + */ + @Since("2.4.0") + val distanceMeasure: Param[String] = { + val availableValues = Array("squaredEuclidean", "cosine") + val allowedParams = ParamValidators.inArray(availableValues) + new Param(this, "distanceMeasure", "distance measure in evaluation. Supported options: " + + availableValues.mkString("'", "', '", "'"), allowedParams) + } + + /** @group getParam */ + @Since("2.4.0") + def getDistanceMeasure: String = $(distanceMeasure) + + /** @group setParam */ + @Since("2.4.0") + def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) + + setDefault(metricName -> "silhouette", distanceMeasure -> "squaredEuclidean") @Since("2.3.0") override def evaluate(dataset: Dataset[_]): Double = { SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) SchemaUtils.checkNumericType(dataset.schema, $(predictionCol)) - $(metricName) match { - case "silhouette" => + ($(metricName), $(distanceMeasure)) match { + case ("silhouette", "squaredEuclidean") => SquaredEuclideanSilhouette.computeSilhouetteScore( - dataset, $(predictionCol), $(featuresCol) - ) + dataset, $(predictionCol), $(featuresCol)) + case ("silhouette", "cosine") => + CosineSilhouette.computeSilhouetteScore(dataset, $(predictionCol), $(featuresCol)) } } } @@ -111,6 +130,48 @@ object ClusteringEvaluator } +private[evaluation] abstract class Silhouette { + + /** + * It computes the Silhouette coefficient for a point. + */ + def pointSilhouetteCoefficient( + clusterIds: Set[Double], + pointClusterId: Double, + pointClusterNumOfPoints: Long, + averageDistanceToCluster: (Double) => Double): Double = { + // Here we compute the average dissimilarity of the current point to any cluster of which the + // point is not a member. + // The cluster with the lowest average dissimilarity - i.e. the nearest cluster to the current + // point - is said to be the "neighboring cluster". + val otherClusterIds = clusterIds.filter(_ != pointClusterId) + val neighboringClusterDissimilarity = otherClusterIds.map(averageDistanceToCluster).min + + // adjustment for excluding the node itself from the computation of the average dissimilarity + val currentClusterDissimilarity = if (pointClusterNumOfPoints == 1) { + 0.0 + } else { + averageDistanceToCluster(pointClusterId) * pointClusterNumOfPoints / + (pointClusterNumOfPoints - 1) + } + + if (currentClusterDissimilarity < neighboringClusterDissimilarity) { + 1 - (currentClusterDissimilarity / neighboringClusterDissimilarity) + } else if (currentClusterDissimilarity > neighboringClusterDissimilarity) { + (neighboringClusterDissimilarity / currentClusterDissimilarity) - 1 + } else { + 0.0 + } + } + + /** + * Compute the mean Silhouette values of all samples. + */ + def overallScore(df: DataFrame, scoreColumn: Column): Double = { + df.select(avg(scoreColumn)).collect()(0).getDouble(0) + } +} + /** * SquaredEuclideanSilhouette computes the average of the * Silhouette over all the data of the dataset, which is @@ -259,7 +320,7 @@ object ClusteringEvaluator * `N` is the number of points in the dataset and `W` is the number * of worker nodes. */ -private[evaluation] object SquaredEuclideanSilhouette { +private[evaluation] object SquaredEuclideanSilhouette extends Silhouette { private[this] var kryoRegistrationPerformed: Boolean = false @@ -336,18 +397,19 @@ private[evaluation] object SquaredEuclideanSilhouette { * It computes the Silhouette coefficient for a point. * * @param broadcastedClustersMap A map of the precomputed values for each cluster. - * @param features The [[org.apache.spark.ml.linalg.Vector]] representing the current point. + * @param point The [[org.apache.spark.ml.linalg.Vector]] representing the current point. * @param clusterId The id of the cluster the current point belongs to. * @param squaredNorm The `$\Xi_{X}$` (which is the squared norm) precomputed for the point. * @return The Silhouette for the point. */ def computeSilhouetteCoefficient( broadcastedClustersMap: Broadcast[Map[Double, ClusterStats]], - features: Vector, + point: Vector, clusterId: Double, squaredNorm: Double): Double = { - def compute(squaredNorm: Double, point: Vector, clusterStats: ClusterStats): Double = { + def compute(targetClusterId: Double): Double = { + val clusterStats = broadcastedClustersMap.value(targetClusterId) val pointDotClusterFeaturesSum = BLAS.dot(point, clusterStats.featureSum) squaredNorm + @@ -355,41 +417,14 @@ private[evaluation] object SquaredEuclideanSilhouette { 2 * pointDotClusterFeaturesSum / clusterStats.numOfPoints } - // Here we compute the average dissimilarity of the - // current point to any cluster of which the point - // is not a member. - // The cluster with the lowest average dissimilarity - // - i.e. the nearest cluster to the current point - - // is said to be the "neighboring cluster". - var neighboringClusterDissimilarity = Double.MaxValue - broadcastedClustersMap.value.keySet.foreach { - c => - if (c != clusterId) { - val dissimilarity = compute(squaredNorm, features, broadcastedClustersMap.value(c)) - if(dissimilarity < neighboringClusterDissimilarity) { - neighboringClusterDissimilarity = dissimilarity - } - } - } - val currentCluster = broadcastedClustersMap.value(clusterId) - // adjustment for excluding the node itself from - // the computation of the average dissimilarity - val currentClusterDissimilarity = if (currentCluster.numOfPoints == 1) { - 0 - } else { - compute(squaredNorm, features, currentCluster) * currentCluster.numOfPoints / - (currentCluster.numOfPoints - 1) - } - - (currentClusterDissimilarity compare neighboringClusterDissimilarity).signum match { - case -1 => 1 - (currentClusterDissimilarity / neighboringClusterDissimilarity) - case 1 => (neighboringClusterDissimilarity / currentClusterDissimilarity) - 1 - case 0 => 0.0 - } + pointSilhouetteCoefficient(broadcastedClustersMap.value.keySet, + clusterId, + broadcastedClustersMap.value(clusterId).numOfPoints, + compute) } /** - * Compute the mean Silhouette values of all samples. + * Compute the Silhouette score of the dataset using squared Euclidean distance measure. * * @param dataset The input dataset (previously clustered) on which compute the Silhouette. * @param predictionCol The name of the column which contains the predicted cluster id @@ -412,7 +447,7 @@ private[evaluation] object SquaredEuclideanSilhouette { val clustersStatsMap = SquaredEuclideanSilhouette .computeClusterStats(dfWithSquaredNorm, predictionCol, featuresCol) - // Silhouette is reasonable only when the number of clusters is grater then 1 + // Silhouette is reasonable only when the number of clusters is greater then 1 assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.") val bClustersStatsMap = dataset.sparkSession.sparkContext.broadcast(clustersStatsMap) @@ -421,13 +456,190 @@ private[evaluation] object SquaredEuclideanSilhouette { computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double, _: Double) } - val silhouetteScore = dfWithSquaredNorm - .select(avg( - computeSilhouetteCoefficientUDF( - col(featuresCol), col(predictionCol).cast(DoubleType), col("squaredNorm")) - )) - .collect()(0) - .getDouble(0) + val silhouetteScore = overallScore(dfWithSquaredNorm, + computeSilhouetteCoefficientUDF(col(featuresCol), col(predictionCol).cast(DoubleType), + col("squaredNorm"))) + + bClustersStatsMap.destroy() + + silhouetteScore + } +} + + +/** + * The algorithm which is implemented in this object, instead, is an efficient and parallel + * implementation of the Silhouette using the cosine distance measure. The cosine distance + * measure is defined as `1 - s` where `s` is the cosine similarity between two points. + * + * The total distance of the point `X` to the points `$C_{i}$` belonging to the cluster `$\Gamma$` + * is: + * + *
    + * $$ + * \sum\limits_{i=1}^N d(X, C_{i} ) = + * \sum\limits_{i=1}^N \Big( 1 - \frac{\sum\limits_{j=1}^D x_{j}c_{ij} }{ \|X\|\|C_{i}\|} \Big) + * = \sum\limits_{i=1}^N 1 - \sum\limits_{i=1}^N \sum\limits_{j=1}^D \frac{x_{j}}{\|X\|} + * \frac{c_{ij}}{\|C_{i}\|} + * = N - \sum\limits_{j=1}^D \frac{x_{j}}{\|X\|} \Big( \sum\limits_{i=1}^N + * \frac{c_{ij}}{\|C_{i}\|} \Big) + * $$ + *
    + * + * where `$x_{j}$` is the `j`-th dimension of the point `X` and `$c_{ij}$` is the `j`-th dimension + * of the `i`-th point in cluster `$\Gamma$`. + * + * Then, we can define the vector: + * + *
    + * $$ + * \xi_{X} : \xi_{X i} = \frac{x_{i}}{\|X\|}, i = 1, ..., D + * $$ + *
    + * + * which can be precomputed for each point and the vector + * + *
    + * $$ + * \Omega_{\Gamma} : \Omega_{\Gamma i} = \sum\limits_{j=1}^N \xi_{C_{j}i}, i = 1, ..., D + * $$ + *
    + * + * which can be precomputed too for each cluster `$\Gamma$` by its points `$C_{i}$`. + * + * With these definitions, the numerator becomes: + * + *
    + * $$ + * N - \sum\limits_{j=1}^D \xi_{X j} \Omega_{\Gamma j} + * $$ + *
    + * + * Thus the average distance of a point `X` to the points of the cluster `$\Gamma$` is: + * + *
    + * $$ + * 1 - \frac{\sum\limits_{j=1}^D \xi_{X j} \Omega_{\Gamma j}}{N} + * $$ + *
    + * + * In the implementation, the precomputed values for the clusters are distributed among the worker + * nodes via broadcasted variables, because we can assume that the clusters are limited in number. + * + * The main strengths of this algorithm are the low computational complexity and the intrinsic + * parallelism. The precomputed information for each point and for each cluster can be computed + * with a computational complexity which is `O(N/W)`, where `N` is the number of points in the + * dataset and `W` is the number of worker nodes. After that, every point can be analyzed + * independently from the others. + * + * For every point we need to compute the average distance to all the clusters. Since the formula + * above requires `O(D)` operations, this phase has a computational complexity which is + * `O(C*D*N/W)` where `C` is the number of clusters (which we assume quite low), `D` is the number + * of dimensions, `N` is the number of points in the dataset and `W` is the number of worker + * nodes. + */ +private[evaluation] object CosineSilhouette extends Silhouette { + + private[this] val normalizedFeaturesColName = "normalizedFeatures" + + /** + * The method takes the input dataset and computes the aggregated values + * about a cluster which are needed by the algorithm. + * + * @param df The DataFrame which contains the input data + * @param predictionCol The name of the column which contains the predicted cluster id + * for the point. + * @return A [[scala.collection.immutable.Map]] which associates each cluster id to a + * its statistics (ie. the precomputed values `N` and `$\Omega_{\Gamma}$`). + */ + def computeClusterStats(df: DataFrame, predictionCol: String): Map[Double, (Vector, Long)] = { + val numFeatures = df.select(col(normalizedFeaturesColName)).first().getAs[Vector](0).size + val clustersStatsRDD = df.select( + col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName)) + .rdd + .map { row => (row.getDouble(0), row.getAs[Vector](1)) } + .aggregateByKey[(DenseVector, Long)]((Vectors.zeros(numFeatures).toDense, 0L))( + seqOp = { + case ((normalizedFeaturesSum: DenseVector, numOfPoints: Long), (normalizedFeatures)) => + BLAS.axpy(1.0, normalizedFeatures, normalizedFeaturesSum) + (normalizedFeaturesSum, numOfPoints + 1) + }, + combOp = { + case ((normalizedFeaturesSum1, numOfPoints1), (normalizedFeaturesSum2, numOfPoints2)) => + BLAS.axpy(1.0, normalizedFeaturesSum2, normalizedFeaturesSum1) + (normalizedFeaturesSum1, numOfPoints1 + numOfPoints2) + } + ) + + clustersStatsRDD + .collectAsMap() + .toMap + } + + /** + * It computes the Silhouette coefficient for a point. + * + * @param broadcastedClustersMap A map of the precomputed values for each cluster. + * @param normalizedFeatures The [[org.apache.spark.ml.linalg.Vector]] representing the + * normalized features of the current point. + * @param clusterId The id of the cluster the current point belongs to. + */ + def computeSilhouetteCoefficient( + broadcastedClustersMap: Broadcast[Map[Double, (Vector, Long)]], + normalizedFeatures: Vector, + clusterId: Double): Double = { + + def compute(targetClusterId: Double): Double = { + val (normalizedFeatureSum, numOfPoints) = broadcastedClustersMap.value(targetClusterId) + 1 - BLAS.dot(normalizedFeatures, normalizedFeatureSum) / numOfPoints + } + + pointSilhouetteCoefficient(broadcastedClustersMap.value.keySet, + clusterId, + broadcastedClustersMap.value(clusterId)._2, + compute) + } + + /** + * Compute the Silhouette score of the dataset using the cosine distance measure. + * + * @param dataset The input dataset (previously clustered) on which compute the Silhouette. + * @param predictionCol The name of the column which contains the predicted cluster id + * for the point. + * @param featuresCol The name of the column which contains the feature vector of the point. + * @return The average of the Silhouette values of the clustered data. + */ + def computeSilhouetteScore( + dataset: Dataset[_], + predictionCol: String, + featuresCol: String): Double = { + val normalizeFeatureUDF = udf { + features: Vector => { + val norm = Vectors.norm(features, 2.0) + features match { + case d: DenseVector => Vectors.dense(d.values.map(_ / norm)) + case s: SparseVector => Vectors.sparse(s.size, s.indices, s.values.map(_ / norm)) + } + } + } + val dfWithNormalizedFeatures = dataset.withColumn(normalizedFeaturesColName, + normalizeFeatureUDF(col(featuresCol))) + + // compute aggregate values for clusters needed by the algorithm + val clustersStatsMap = computeClusterStats(dfWithNormalizedFeatures, predictionCol) + + // Silhouette is reasonable only when the number of clusters is greater then 1 + assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.") + + val bClustersStatsMap = dataset.sparkSession.sparkContext.broadcast(clustersStatsMap) + + val computeSilhouetteCoefficientUDF = udf { + computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double) + } + + val silhouetteScore = overallScore(dfWithNormalizedFeatures, + computeSilhouetteCoefficientUDF(col(normalizedFeaturesColName), + col(predictionCol).cast(DoubleType))) bClustersStatsMap.destroy() diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala index 677ce49a903ab..3bf34770f5687 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala @@ -66,16 +66,38 @@ class ClusteringEvaluatorSuite assert(evaluator.evaluate(irisDataset) ~== 0.6564679231 relTol 1e-5) } - test("number of clusters must be greater than one") { - val singleClusterDataset = irisDataset.where($"label" === 0.0) + /* + Use the following python code to load the data and evaluate it using scikit-learn package. + + from sklearn import datasets + from sklearn.metrics import silhouette_score + iris = datasets.load_iris() + round(silhouette_score(iris.data, iris.target, metric='cosine'), 10) + + 0.7222369298 + */ + test("cosine Silhouette") { val evaluator = new ClusteringEvaluator() .setFeaturesCol("features") .setPredictionCol("label") + .setDistanceMeasure("cosine") + + assert(evaluator.evaluate(irisDataset) ~== 0.7222369298 relTol 1e-5) + } + + test("number of clusters must be greater than one") { + val singleClusterDataset = irisDataset.where($"label" === 0.0) + Seq("squaredEuclidean", "cosine").foreach { distanceMeasure => + val evaluator = new ClusteringEvaluator() + .setFeaturesCol("features") + .setPredictionCol("label") + .setDistanceMeasure(distanceMeasure) - val e = intercept[AssertionError]{ - evaluator.evaluate(singleClusterDataset) + val e = intercept[AssertionError] { + evaluator.evaluate(singleClusterDataset) + } + assert(e.getMessage.contains("Number of clusters must be greater than one")) } - assert(e.getMessage.contains("Number of clusters must be greater than one")) } } From d58fe28836639e68e262812d911f167cb071007b Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 13 Feb 2018 11:18:45 -0800 Subject: [PATCH 0347/2461] [SPARK-23154][ML][DOC] Document backwards compatibility guarantees for ML persistence ## What changes were proposed in this pull request? Added documentation about what MLlib guarantees in terms of loading ML models and Pipelines from old Spark versions. Discussed & confirmed on linked JIRA. Author: Joseph K. Bradley Closes #20592 from jkbradley/SPARK-23154-backwards-compat-doc. --- docs/ml-pipeline.md | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/docs/ml-pipeline.md b/docs/ml-pipeline.md index aa92c0a37c0f4..e22e9003c30f6 100644 --- a/docs/ml-pipeline.md +++ b/docs/ml-pipeline.md @@ -188,9 +188,36 @@ Parameters belong to specific instances of `Estimator`s and `Transformer`s. For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. -## Saving and Loading Pipelines +## ML persistence: Saving and Loading Pipelines -Often times it is worth it to save a model or a pipeline to disk for later use. In Spark 1.6, a model import/export functionality was added to the Pipeline API. Most basic transformers are supported as well as some of the more basic ML models. Please refer to the algorithm's API documentation to see if saving and loading is supported. +Often times it is worth it to save a model or a pipeline to disk for later use. In Spark 1.6, a model import/export functionality was added to the Pipeline API. +As of Spark 2.3, the DataFrame-based API in `spark.ml` and `pyspark.ml` has complete coverage. + +ML persistence works across Scala, Java and Python. However, R currently uses a modified format, +so models saved in R can only be loaded back in R; this should be fixed in the future and is +tracked in [SPARK-15572](https://issues.apache.org/jira/browse/SPARK-15572). + +### Backwards compatibility for ML persistence + +In general, MLlib maintains backwards compatibility for ML persistence. I.e., if you save an ML +model or Pipeline in one version of Spark, then you should be able to load it back and use it in a +future version of Spark. However, there are rare exceptions, described below. + +Model persistence: Is a model or Pipeline saved using Apache Spark ML persistence in Spark +version X loadable by Spark version Y? + +* Major versions: No guarantees, but best-effort. +* Minor and patch versions: Yes; these are backwards compatible. +* Note about the format: There are no guarantees for a stable persistence format, but model loading itself is designed to be backwards compatible. + +Model behavior: Does a model or Pipeline in Spark version X behave identically in Spark version Y? + +* Major versions: No guarantees, but best-effort. +* Minor and patch versions: Identical behavior, except for bug fixes. + +For both model persistence and model behavior, any breaking changes across a minor version or patch +version are reported in the Spark version release notes. If a breakage is not reported in release +notes, then it should be treated as a bug to be fixed. # Code examples From 2ee76c22b6e48e643694c9475e5f0d37124215e7 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 13 Feb 2018 11:56:49 -0800 Subject: [PATCH 0348/2461] [SPARK-23400][SQL] Add a constructors for ScalaUDF ## What changes were proposed in this pull request? In this upcoming 2.3 release, we changed the interface of `ScalaUDF`. Unfortunately, some Spark packages (e.g., spark-deep-learning) are using our internal class `ScalaUDF`. In the release 2.3, we added new parameters into this class. The users hit the binary compatibility issues and got the exception: ``` > java.lang.NoSuchMethodError: org.apache.spark.sql.catalyst.expressions.ScalaUDF.<init>(Ljava/lang/Object;Lorg/apache/spark/sql/types/DataType;Lscala/collection/Seq;Lscala/collection/Seq;Lscala/Option;)V ``` This PR is to improve the backward compatibility. However, we definitely should not encourage the external packages to use our internal classes. This might make us hard to maintain/develop the codes in Spark. ## How was this patch tested? N/A Author: gatorsmile Closes #20591 from gatorsmile/scalaUDF. --- .../spark/sql/catalyst/expressions/ScalaUDF.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 388ef42883ad3..989c02305620a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -49,6 +49,17 @@ case class ScalaUDF( udfDeterministic: Boolean = true) extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression { + // The constructor for SPARK 2.1 and 2.2 + def this( + function: AnyRef, + dataType: DataType, + children: Seq[Expression], + inputTypes: Seq[DataType], + udfName: Option[String]) = { + this( + function, dataType, children, inputTypes, udfName, nullable = true, udfDeterministic = true) + } + override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) override def toString: String = From a5a4b83501526e02d0e3cd0056e4a5c0e1c8284f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Tue, 13 Feb 2018 16:46:43 -0600 Subject: [PATCH 0349/2461] [SPARK-23235][CORE] Add executor Threaddump to api MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Extending api with the executor thread dump data. For this new REST URL is introduced: - GET http://localhost:4040/api/v1/applications/{applicationId}/executors/{executorId}/threads
    Example response: ``` javascript [ { "threadId" : 52, "threadName" : "context-cleaner-periodic-gc", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:1093)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:809)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1385411893})", "holdingLocks" : [ ] }, { "threadId" : 48, "threadName" : "dag-scheduler-event-loop", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingDeque.takeFirst(LinkedBlockingDeque.java:492)\njava.util.concurrent.LinkedBlockingDeque.take(LinkedBlockingDeque.java:680)\norg.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:46)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1138053349})", "holdingLocks" : [ ] }, { "threadId" : 17, "threadName" : "dispatcher-event-loop-0", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker832743930})" ] }, { "threadId" : 18, "threadName" : "dispatcher-event-loop-1", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker834153999})" ] }, { "threadId" : 19, "threadName" : "dispatcher-event-loop-2", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker664836465})" ] }, { "threadId" : 20, "threadName" : "dispatcher-event-loop-3", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1645557354})" ] }, { "threadId" : 21, "threadName" : "dispatcher-event-loop-4", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1188871851})" ] }, { "threadId" : 22, "threadName" : "dispatcher-event-loop-5", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker920926249})" ] }, { "threadId" : 23, "threadName" : "dispatcher-event-loop-6", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker355222677})" ] }, { "threadId" : 24, "threadName" : "dispatcher-event-loop-7", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1589745212})" ] }, { "threadId" : 49, "threadName" : "driver-heartbeater", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:1093)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:809)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1602885835})", "holdingLocks" : [ ] }, { "threadId" : 53, "threadName" : "element-tracking-store-worker", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1439439099})", "holdingLocks" : [ ] }, { "threadId" : 3, "threadName" : "Finalizer", "threadState" : "WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.lang.ref.ReferenceQueue.remove(ReferenceQueue.java:143)\njava.lang.ref.ReferenceQueue.remove(ReferenceQueue.java:164)\njava.lang.ref.Finalizer$FinalizerThread.run(Finalizer.java:209)", "blockedByLock" : "Lock(java.lang.ref.ReferenceQueue$Lock1213098236})", "holdingLocks" : [ ] }, { "threadId" : 15, "threadName" : "ForkJoinPool-1-worker-13", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\nscala.concurrent.forkjoin.ForkJoinPool.scan(ForkJoinPool.java:2075)\nscala.concurrent.forkjoin.ForkJoinPool.runWorker(ForkJoinPool.java:1979)\nscala.concurrent.forkjoin.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:107)", "blockedByLock" : "Lock(scala.concurrent.forkjoin.ForkJoinPool380286413})", "holdingLocks" : [ ] }, { "threadId" : 45, "threadName" : "heartbeat-receiver-event-loop-thread", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:1093)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:809)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject715135812})", "holdingLocks" : [ ] }, { "threadId" : 1, "threadName" : "main", "threadState" : "RUNNABLE", "stackTrace" : "java.io.FileInputStream.read0(Native Method)\njava.io.FileInputStream.read(FileInputStream.java:207)\nscala.tools.jline_embedded.internal.NonBlockingInputStream.read(NonBlockingInputStream.java:169) => holding Monitor(scala.tools.jline_embedded.internal.NonBlockingInputStream46248392})\nscala.tools.jline_embedded.internal.NonBlockingInputStream.read(NonBlockingInputStream.java:137)\nscala.tools.jline_embedded.internal.NonBlockingInputStream.read(NonBlockingInputStream.java:246)\nscala.tools.jline_embedded.internal.InputStreamReader.read(InputStreamReader.java:261) => holding Monitor(scala.tools.jline_embedded.internal.NonBlockingInputStream46248392})\nscala.tools.jline_embedded.internal.InputStreamReader.read(InputStreamReader.java:198) => holding Monitor(scala.tools.jline_embedded.internal.NonBlockingInputStream46248392})\nscala.tools.jline_embedded.console.ConsoleReader.readCharacter(ConsoleReader.java:2145)\nscala.tools.jline_embedded.console.ConsoleReader.readLine(ConsoleReader.java:2349)\nscala.tools.jline_embedded.console.ConsoleReader.readLine(ConsoleReader.java:2269)\nscala.tools.nsc.interpreter.jline_embedded.InteractiveReader.readOneLine(JLineReader.scala:57)\nscala.tools.nsc.interpreter.InteractiveReader$$anonfun$readLine$2.apply(InteractiveReader.scala:37)\nscala.tools.nsc.interpreter.InteractiveReader$$anonfun$readLine$2.apply(InteractiveReader.scala:37)\nscala.tools.nsc.interpreter.InteractiveReader$.restartSysCalls(InteractiveReader.scala:44)\nscala.tools.nsc.interpreter.InteractiveReader$class.readLine(InteractiveReader.scala:37)\nscala.tools.nsc.interpreter.jline_embedded.InteractiveReader.readLine(JLineReader.scala:28)\nscala.tools.nsc.interpreter.ILoop.readOneLine(ILoop.scala:404)\nscala.tools.nsc.interpreter.ILoop.loop(ILoop.scala:413)\nscala.tools.nsc.interpreter.ILoop$$anonfun$process$1.apply$mcZ$sp(ILoop.scala:923)\nscala.tools.nsc.interpreter.ILoop$$anonfun$process$1.apply(ILoop.scala:909)\nscala.tools.nsc.interpreter.ILoop$$anonfun$process$1.apply(ILoop.scala:909)\nscala.reflect.internal.util.ScalaClassLoader$.savingContextLoader(ScalaClassLoader.scala:97)\nscala.tools.nsc.interpreter.ILoop.process(ILoop.scala:909)\norg.apache.spark.repl.Main$.doMain(Main.scala:76)\norg.apache.spark.repl.Main$.main(Main.scala:56)\norg.apache.spark.repl.Main.main(Main.scala)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\norg.apache.spark.deploy.JavaMainApplication.start(SparkApplication.scala:52)\norg.apache.spark.deploy.SparkSubmit$.org$apache$spark$deploy$SparkSubmit$$runMain(SparkSubmit.scala:879)\norg.apache.spark.deploy.SparkSubmit$.doRunMain$1(SparkSubmit.scala:197)\norg.apache.spark.deploy.SparkSubmit$.submit(SparkSubmit.scala:227)\norg.apache.spark.deploy.SparkSubmit$.main(SparkSubmit.scala:136)\norg.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(scala.tools.jline_embedded.internal.NonBlockingInputStream46248392})" ] }, { "threadId" : 26, "threadName" : "map-output-dispatcher-0", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1791280119})" ] }, { "threadId" : 27, "threadName" : "map-output-dispatcher-1", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1947378744})" ] }, { "threadId" : 28, "threadName" : "map-output-dispatcher-2", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker507507251})" ] }, { "threadId" : 29, "threadName" : "map-output-dispatcher-3", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1016408627})" ] }, { "threadId" : 30, "threadName" : "map-output-dispatcher-4", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1879219501})" ] }, { "threadId" : 31, "threadName" : "map-output-dispatcher-5", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker290509937})" ] }, { "threadId" : 32, "threadName" : "map-output-dispatcher-6", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1889468930})" ] }, { "threadId" : 33, "threadName" : "map-output-dispatcher-7", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1699637904})" ] }, { "threadId" : 47, "threadName" : "netty-rpc-env-timeout", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:1093)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:809)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject977194847})", "holdingLocks" : [ ] }, { "threadId" : 14, "threadName" : "NonBlockingInputStreamThread", "threadState" : "WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\nscala.tools.jline_embedded.internal.NonBlockingInputStream.run(NonBlockingInputStream.java:278)\njava.lang.Thread.run(Thread.java:748)", "blockedByThreadId" : 1, "blockedByLock" : "Lock(scala.tools.jline_embedded.internal.NonBlockingInputStream46248392})", "holdingLocks" : [ ] }, { "threadId" : 2, "threadName" : "Reference Handler", "threadState" : "WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.lang.Object.wait(Object.java:502)\njava.lang.ref.Reference.tryHandlePending(Reference.java:191)\njava.lang.ref.Reference$ReferenceHandler.run(Reference.java:153)", "blockedByLock" : "Lock(java.lang.ref.Reference$Lock1359433302})", "holdingLocks" : [ ] }, { "threadId" : 35, "threadName" : "refresh progress", "threadState" : "TIMED_WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.util.TimerThread.mainLoop(Timer.java:552)\njava.util.TimerThread.run(Timer.java:505)", "blockedByLock" : "Lock(java.util.TaskQueue44276328})", "holdingLocks" : [ ] }, { "threadId" : 34, "threadName" : "RemoteBlock-temp-file-clean-thread", "threadState" : "TIMED_WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.lang.ref.ReferenceQueue.remove(ReferenceQueue.java:143)\norg.apache.spark.storage.BlockManager$RemoteBlockTempFileManager.org$apache$spark$storage$BlockManager$RemoteBlockTempFileManager$$keepCleaning(BlockManager.scala:1630)\norg.apache.spark.storage.BlockManager$RemoteBlockTempFileManager$$anon$1.run(BlockManager.scala:1608)", "blockedByLock" : "Lock(java.lang.ref.ReferenceQueue$Lock391748181})", "holdingLocks" : [ ] }, { "threadId" : 25, "threadName" : "rpc-server-3-1", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl2057702496})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nio.netty.channel.nio.SelectedSelectionKeySetSelector.select(SelectedSelectionKeySetSelector.java:62)\nio.netty.channel.nio.NioEventLoop.select(NioEventLoop.java:753)\nio.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:409)\nio.netty.util.concurrent.SingleThreadEventExecutor$5.run(SingleThreadEventExecutor.java:858)\nio.netty.util.concurrent.DefaultThreadFactory$DefaultRunnableDecorator.run(DefaultThreadFactory.java:138)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(io.netty.channel.nio.SelectedSelectionKeySet1066929256})", "Monitor(java.util.Collections$UnmodifiableSet561426729})", "Monitor(sun.nio.ch.KQueueSelectorImpl2057702496})" ] }, { "threadId" : 50, "threadName" : "shuffle-server-5-1", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl1401522546})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nio.netty.channel.nio.SelectedSelectionKeySetSelector.select(SelectedSelectionKeySetSelector.java:62)\nio.netty.channel.nio.NioEventLoop.select(NioEventLoop.java:753)\nio.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:409)\nio.netty.util.concurrent.SingleThreadEventExecutor$5.run(SingleThreadEventExecutor.java:858)\nio.netty.util.concurrent.DefaultThreadFactory$DefaultRunnableDecorator.run(DefaultThreadFactory.java:138)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(io.netty.channel.nio.SelectedSelectionKeySet385972319})", "Monitor(java.util.Collections$UnmodifiableSet477937109})", "Monitor(sun.nio.ch.KQueueSelectorImpl1401522546})" ] }, { "threadId" : 4, "threadName" : "Signal Dispatcher", "threadState" : "RUNNABLE", "stackTrace" : "", "blockedByLock" : "", "holdingLocks" : [ ] }, { "threadId" : 51, "threadName" : "Spark Context Cleaner", "threadState" : "TIMED_WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.lang.ref.ReferenceQueue.remove(ReferenceQueue.java:143)\norg.apache.spark.ContextCleaner$$anonfun$org$apache$spark$ContextCleaner$$keepCleaning$1.apply$mcV$sp(ContextCleaner.scala:181)\norg.apache.spark.util.Utils$.tryOrStopSparkContext(Utils.scala:1319)\norg.apache.spark.ContextCleaner.org$apache$spark$ContextCleaner$$keepCleaning(ContextCleaner.scala:178)\norg.apache.spark.ContextCleaner$$anon$1.run(ContextCleaner.scala:73)", "blockedByLock" : "Lock(java.lang.ref.ReferenceQueue$Lock1739420764})", "holdingLocks" : [ ] }, { "threadId" : 16, "threadName" : "spark-listener-group-appStatus", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.scheduler.AsyncEventQueue$$anonfun$org$apache$spark$scheduler$AsyncEventQueue$$dispatch$1.apply(AsyncEventQueue.scala:94)\nscala.util.DynamicVariable.withValue(DynamicVariable.scala:58)\norg.apache.spark.scheduler.AsyncEventQueue.org$apache$spark$scheduler$AsyncEventQueue$$dispatch(AsyncEventQueue.scala:83)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1$$anonfun$run$1.apply$mcV$sp(AsyncEventQueue.scala:79)\norg.apache.spark.util.Utils$.tryOrStopSparkContext(Utils.scala:1319)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1.run(AsyncEventQueue.scala:78)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1287190987})", "holdingLocks" : [ ] }, { "threadId" : 44, "threadName" : "spark-listener-group-executorManagement", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.scheduler.AsyncEventQueue$$anonfun$org$apache$spark$scheduler$AsyncEventQueue$$dispatch$1.apply(AsyncEventQueue.scala:94)\nscala.util.DynamicVariable.withValue(DynamicVariable.scala:58)\norg.apache.spark.scheduler.AsyncEventQueue.org$apache$spark$scheduler$AsyncEventQueue$$dispatch(AsyncEventQueue.scala:83)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1$$anonfun$run$1.apply$mcV$sp(AsyncEventQueue.scala:79)\norg.apache.spark.util.Utils$.tryOrStopSparkContext(Utils.scala:1319)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1.run(AsyncEventQueue.scala:78)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject943262890})", "holdingLocks" : [ ] }, { "threadId" : 54, "threadName" : "spark-listener-group-shared", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.scheduler.AsyncEventQueue$$anonfun$org$apache$spark$scheduler$AsyncEventQueue$$dispatch$1.apply(AsyncEventQueue.scala:94)\nscala.util.DynamicVariable.withValue(DynamicVariable.scala:58)\norg.apache.spark.scheduler.AsyncEventQueue.org$apache$spark$scheduler$AsyncEventQueue$$dispatch(AsyncEventQueue.scala:83)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1$$anonfun$run$1.apply$mcV$sp(AsyncEventQueue.scala:79)\norg.apache.spark.util.Utils$.tryOrStopSparkContext(Utils.scala:1319)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1.run(AsyncEventQueue.scala:78)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject334604425})", "holdingLocks" : [ ] }, { "threadId" : 37, "threadName" : "SparkUI-37", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\norg.spark_project.jetty.util.BlockingArrayQueue.poll(BlockingArrayQueue.java:392)\norg.spark_project.jetty.util.thread.QueuedThreadPool.idleJobPoll(QueuedThreadPool.java:563)\norg.spark_project.jetty.util.thread.QueuedThreadPool.access$800(QueuedThreadPool.java:48)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:626)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1503479572})", "holdingLocks" : [ ] }, { "threadId" : 38, "threadName" : "SparkUI-38", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl841741934})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:101)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.select(ManagedSelector.java:243)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.produce(ManagedSelector.java:191)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:249)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(sun.nio.ch.Util$3873523986})", "Monitor(java.util.Collections$UnmodifiableSet1769333189})", "Monitor(sun.nio.ch.KQueueSelectorImpl841741934})" ] }, { "threadId" : 40, "threadName" : "SparkUI-40-acceptor-034929380-Spark3a557b62{HTTP/1.1,[http/1.1]}{0.0.0.0:4040}", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.ServerSocketChannelImpl.accept0(Native Method)\nsun.nio.ch.ServerSocketChannelImpl.accept(ServerSocketChannelImpl.java:422)\nsun.nio.ch.ServerSocketChannelImpl.accept(ServerSocketChannelImpl.java:250) => holding Monitor(java.lang.Object1134240909})\norg.spark_project.jetty.server.ServerConnector.accept(ServerConnector.java:371)\norg.spark_project.jetty.server.AbstractConnector$Acceptor.run(AbstractConnector.java:601)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(java.lang.Object1134240909})" ] }, { "threadId" : 43, "threadName" : "SparkUI-43", "threadState" : "RUNNABLE", "stackTrace" : "sun.management.ThreadImpl.dumpThreads0(Native Method)\nsun.management.ThreadImpl.dumpAllThreads(ThreadImpl.java:454)\norg.apache.spark.util.Utils$.getThreadDump(Utils.scala:2170)\norg.apache.spark.SparkContext.getExecutorThreadDump(SparkContext.scala:596)\norg.apache.spark.status.api.v1.AbstractApplicationResource$$anonfun$threadDump$1$$anonfun$apply$1.apply(OneApplicationResource.scala:66)\norg.apache.spark.status.api.v1.AbstractApplicationResource$$anonfun$threadDump$1$$anonfun$apply$1.apply(OneApplicationResource.scala:65)\nscala.Option.flatMap(Option.scala:171)\norg.apache.spark.status.api.v1.AbstractApplicationResource$$anonfun$threadDump$1.apply(OneApplicationResource.scala:65)\norg.apache.spark.status.api.v1.AbstractApplicationResource$$anonfun$threadDump$1.apply(OneApplicationResource.scala:58)\norg.apache.spark.status.api.v1.BaseAppResource$$anonfun$withUI$1.apply(ApiRootResource.scala:139)\norg.apache.spark.status.api.v1.BaseAppResource$$anonfun$withUI$1.apply(ApiRootResource.scala:134)\norg.apache.spark.ui.SparkUI.withSparkUI(SparkUI.scala:106)\norg.apache.spark.status.api.v1.BaseAppResource$class.withUI(ApiRootResource.scala:134)\norg.apache.spark.status.api.v1.AbstractApplicationResource.withUI(OneApplicationResource.scala:32)\norg.apache.spark.status.api.v1.AbstractApplicationResource.threadDump(OneApplicationResource.scala:58)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\norg.glassfish.jersey.server.model.internal.ResourceMethodInvocationHandlerFactory$1.invoke(ResourceMethodInvocationHandlerFactory.java:81)\norg.glassfish.jersey.server.model.internal.AbstractJavaResourceMethodDispatcher$1.run(AbstractJavaResourceMethodDispatcher.java:144)\norg.glassfish.jersey.server.model.internal.AbstractJavaResourceMethodDispatcher.invoke(AbstractJavaResourceMethodDispatcher.java:161)\norg.glassfish.jersey.server.model.internal.JavaResourceMethodDispatcherProvider$TypeOutInvoker.doDispatch(JavaResourceMethodDispatcherProvider.java:205)\norg.glassfish.jersey.server.model.internal.AbstractJavaResourceMethodDispatcher.dispatch(AbstractJavaResourceMethodDispatcher.java:99)\norg.glassfish.jersey.server.model.ResourceMethodInvoker.invoke(ResourceMethodInvoker.java:389)\norg.glassfish.jersey.server.model.ResourceMethodInvoker.apply(ResourceMethodInvoker.java:347)\norg.glassfish.jersey.server.model.ResourceMethodInvoker.apply(ResourceMethodInvoker.java:102)\norg.glassfish.jersey.server.ServerRuntime$2.run(ServerRuntime.java:326)\norg.glassfish.jersey.internal.Errors$1.call(Errors.java:271)\norg.glassfish.jersey.internal.Errors$1.call(Errors.java:267)\norg.glassfish.jersey.internal.Errors.process(Errors.java:315)\norg.glassfish.jersey.internal.Errors.process(Errors.java:297)\norg.glassfish.jersey.internal.Errors.process(Errors.java:267)\norg.glassfish.jersey.process.internal.RequestScope.runInScope(RequestScope.java:317)\norg.glassfish.jersey.server.ServerRuntime.process(ServerRuntime.java:305)\norg.glassfish.jersey.server.ApplicationHandler.handle(ApplicationHandler.java:1154)\norg.glassfish.jersey.servlet.WebComponent.serviceImpl(WebComponent.java:473)\norg.glassfish.jersey.servlet.WebComponent.service(WebComponent.java:427)\norg.glassfish.jersey.servlet.ServletContainer.service(ServletContainer.java:388)\norg.glassfish.jersey.servlet.ServletContainer.service(ServletContainer.java:341)\norg.glassfish.jersey.servlet.ServletContainer.service(ServletContainer.java:228)\norg.spark_project.jetty.servlet.ServletHolder.handle(ServletHolder.java:848)\norg.spark_project.jetty.servlet.ServletHandler.doHandle(ServletHandler.java:584)\norg.spark_project.jetty.server.handler.ContextHandler.doHandle(ContextHandler.java:1180)\norg.spark_project.jetty.servlet.ServletHandler.doScope(ServletHandler.java:512)\norg.spark_project.jetty.server.handler.ContextHandler.doScope(ContextHandler.java:1112)\norg.spark_project.jetty.server.handler.ScopedHandler.handle(ScopedHandler.java:141)\norg.spark_project.jetty.server.handler.gzip.GzipHandler.handle(GzipHandler.java:493)\norg.spark_project.jetty.server.handler.ContextHandlerCollection.handle(ContextHandlerCollection.java:213)\norg.spark_project.jetty.server.handler.HandlerWrapper.handle(HandlerWrapper.java:134)\norg.spark_project.jetty.server.Server.handle(Server.java:534)\norg.spark_project.jetty.server.HttpChannel.handle(HttpChannel.java:320)\norg.spark_project.jetty.server.HttpConnection.onFillable(HttpConnection.java:251)\norg.spark_project.jetty.io.AbstractConnection$ReadCallback.succeeded(AbstractConnection.java:283)\norg.spark_project.jetty.io.FillInterest.fillable(FillInterest.java:108)\norg.spark_project.jetty.io.SelectChannelEndPoint$2.run(SelectChannelEndPoint.java:93)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:303)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ ] }, { "threadId" : 67, "threadName" : "SparkUI-67", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl1837806480})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:101)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.select(ManagedSelector.java:243)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.produce(ManagedSelector.java:191)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:249)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(sun.nio.ch.Util$3881415814})", "Monitor(java.util.Collections$UnmodifiableSet62050480})", "Monitor(sun.nio.ch.KQueueSelectorImpl1837806480})" ] }, { "threadId" : 68, "threadName" : "SparkUI-68", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl223607814})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:101)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.select(ManagedSelector.java:243)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.produce(ManagedSelector.java:191)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:249)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(sun.nio.ch.Util$3543145185})", "Monitor(java.util.Collections$UnmodifiableSet897441546})", "Monitor(sun.nio.ch.KQueueSelectorImpl223607814})" ] }, { "threadId" : 71, "threadName" : "SparkUI-71", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\norg.spark_project.jetty.util.BlockingArrayQueue.poll(BlockingArrayQueue.java:392)\norg.spark_project.jetty.util.thread.QueuedThreadPool.idleJobPoll(QueuedThreadPool.java:563)\norg.spark_project.jetty.util.thread.QueuedThreadPool.access$800(QueuedThreadPool.java:48)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:626)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1503479572})", "holdingLocks" : [ ] }, { "threadId" : 77, "threadName" : "SparkUI-77", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\norg.spark_project.jetty.util.BlockingArrayQueue.poll(BlockingArrayQueue.java:392)\norg.spark_project.jetty.util.thread.QueuedThreadPool.idleJobPoll(QueuedThreadPool.java:563)\norg.spark_project.jetty.util.thread.QueuedThreadPool.access$800(QueuedThreadPool.java:48)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:626)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1503479572})", "holdingLocks" : [ ] }, { "threadId" : 78, "threadName" : "SparkUI-78", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl403077801})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:101)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.select(ManagedSelector.java:243)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.produce(ManagedSelector.java:191)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:249)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(sun.nio.ch.Util$3261312406})", "Monitor(java.util.Collections$UnmodifiableSet852901260})", "Monitor(sun.nio.ch.KQueueSelectorImpl403077801})" ] }, { "threadId" : 72, "threadName" : "SparkUI-JettyScheduler", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:1093)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:809)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1587346642})", "holdingLocks" : [ ] }, { "threadId" : 63, "threadName" : "task-result-getter-0", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject537563105})", "holdingLocks" : [ ] }, { "threadId" : 64, "threadName" : "task-result-getter-1", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject537563105})", "holdingLocks" : [ ] }, { "threadId" : 65, "threadName" : "task-result-getter-2", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject537563105})", "holdingLocks" : [ ] }, { "threadId" : 66, "threadName" : "task-result-getter-3", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject537563105})", "holdingLocks" : [ ] }, { "threadId" : 46, "threadName" : "Timer-0", "threadState" : "WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.lang.Object.wait(Object.java:502)\njava.util.TimerThread.mainLoop(Timer.java:526)\njava.util.TimerThread.run(Timer.java:505)", "blockedByLock" : "Lock(java.util.TaskQueue635634547})", "holdingLocks" : [ ] } ] ```
    ## How was this patch tested? It was tested manually. Old executor page with thread dumps: screen shot 2018-02-01 at 14 31 19 New api: screen shot 2018-02-01 at 14 31 56 Testing error cases. Initial state: ![screen shot 2018-02-06 at 13 05 05](https://user-images.githubusercontent.com/2017933/35858990-ad2982be-0b3e-11e8-879b-656112065c7f.png) Dead executor: ```bash $ curl -o - -s -w "\n%{http_code}\n" http://localhost:4040/api/v1/applications/app-20180206122543-0000/executors/1/threads Executor is not active. 400 ``` Never existed (but well formatted: number) executor ID: ```bash $ curl -o - -s -w "\n%{http_code}\n" http://localhost:4040/api/v1/applications/app-20180206122543-0000/executors/42/threads Executor does not exist. 404 ``` Not available stacktrace (dead executor but UI has not registered as dead yet): ```bash $ kill -9 ; curl -o - -s -w "\n%{http_code}\n" http://localhost:4040/api/v1/applications/app-20180206122543-0000/executors/2/threads No thread dump is available. 404 ``` Invalid executor ID format: ```bash $ curl -o - -s -w "\n%{http_code}\n" http://localhost:4040/api/v1/applications/app-20180206122543-0000/executors/something6/threads Invalid executorId: neither 'driver' nor number. 400 ``` Author: “attilapiros” Closes #20474 from attilapiros/SPARK-23235. --- .../scala/org/apache/spark/SparkContext.scala | 1 + .../spark/status/api/v1/ApiRootResource.scala | 8 +++++ .../api/v1/OneApplicationResource.scala | 29 +++++++++++++++-- .../org/apache/spark/status/api/v1/api.scala | 9 ++++++ .../ui/exec/ExecutorThreadDumpPage.scala | 13 +------- .../apache/spark/util/ThreadStackTrace.scala | 31 ------------------- .../scala/org/apache/spark/util/Utils.scala | 18 ++++++++++- docs/monitoring.md | 7 +++++ 8 files changed, 69 insertions(+), 47 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c4f74c4f1f9c2..dc531e3337014 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -54,6 +54,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, StandaloneSchedulerBackend} import org.apache.spark.scheduler.local.LocalSchedulerBackend import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1.ThreadStackTrace import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.TriggerThreadDump import org.apache.spark.ui.{ConsoleProgressBar, SparkUI} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index ed9bdc6e1e3c2..7127397f6205c 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -157,6 +157,14 @@ private[v1] class NotFoundException(msg: String) extends WebApplicationException .build() ) +private[v1] class ServiceUnavailable(msg: String) extends WebApplicationException( + new ServiceUnavailableException(msg), + Response + .status(Response.Status.SERVICE_UNAVAILABLE) + .entity(ErrorWrapper(msg)) + .build() +) + private[v1] class BadParameterException(msg: String) extends WebApplicationException( new IllegalArgumentException(msg), Response diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala index bd4df07e7afc6..974697890dd03 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala @@ -19,13 +19,13 @@ package org.apache.spark.status.api.v1 import java.io.OutputStream import java.util.{List => JList} import java.util.zip.ZipOutputStream -import javax.ws.rs.{GET, Path, PathParam, Produces, QueryParam} +import javax.ws.rs._ import javax.ws.rs.core.{MediaType, Response, StreamingOutput} import scala.util.control.NonFatal -import org.apache.spark.JobExecutionStatus -import org.apache.spark.ui.SparkUI +import org.apache.spark.{JobExecutionStatus, SparkContext} +import org.apache.spark.ui.UIUtils @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class AbstractApplicationResource extends BaseAppResource { @@ -51,6 +51,29 @@ private[v1] class AbstractApplicationResource extends BaseAppResource { @Path("executors") def executorList(): Seq[ExecutorSummary] = withUI(_.store.executorList(true)) + @GET + @Path("executors/{executorId}/threads") + def threadDump(@PathParam("executorId") execId: String): Array[ThreadStackTrace] = withUI { ui => + if (execId != SparkContext.DRIVER_IDENTIFIER && !execId.forall(Character.isDigit)) { + throw new BadParameterException( + s"Invalid executorId: neither '${SparkContext.DRIVER_IDENTIFIER}' nor number.") + } + + val safeSparkContext = ui.sc.getOrElse { + throw new ServiceUnavailable("Thread dumps not available through the history server.") + } + + ui.store.asOption(ui.store.executorSummary(execId)) match { + case Some(executorSummary) if executorSummary.isActive => + val safeThreadDump = safeSparkContext.getExecutorThreadDump(execId).getOrElse { + throw new NotFoundException("No thread dump is available.") + } + safeThreadDump + case Some(_) => throw new BadParameterException("Executor is not active.") + case _ => throw new NotFoundException("Executor does not exist.") + } + } + @GET @Path("allexecutors") def allExecutorList(): Seq[ExecutorSummary] = withUI(_.store.executorList(false)) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index a333f1aaf6325..369e98b683b1a 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -316,3 +316,12 @@ class RuntimeInfo private[spark]( val javaVersion: String, val javaHome: String, val scalaVersion: String) + +case class ThreadStackTrace( + val threadId: Long, + val threadName: String, + val threadState: Thread.State, + val stackTrace: String, + val blockedByThreadId: Option[Long], + val blockedByLock: String, + val holdingLocks: Seq[String]) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index f4686ea3cf91f..7a9aaf29a8b05 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -17,7 +17,6 @@ package org.apache.spark.ui.exec -import java.util.Locale import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Text} @@ -41,17 +40,7 @@ private[ui] class ExecutorThreadDumpPage( val maybeThreadDump = sc.get.getExecutorThreadDump(executorId) val content = maybeThreadDump.map { threadDump => - val dumpRows = threadDump.sortWith { - case (threadTrace1, threadTrace2) => - val v1 = if (threadTrace1.threadName.contains("Executor task launch")) 1 else 0 - val v2 = if (threadTrace2.threadName.contains("Executor task launch")) 1 else 0 - if (v1 == v2) { - threadTrace1.threadName.toLowerCase(Locale.ROOT) < - threadTrace2.threadName.toLowerCase(Locale.ROOT) - } else { - v1 > v2 - } - }.map { thread => + val dumpRows = threadDump.map { thread => val threadId = thread.threadId val blockedBy = thread.blockedByThreadId match { case Some(_) => diff --git a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala deleted file mode 100644 index b1217980faf1f..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -/** - * Used for shipping per-thread stacktraces from the executors to driver. - */ -private[spark] case class ThreadStackTrace( - threadId: Long, - threadName: String, - threadState: Thread.State, - stackTrace: String, - blockedByThreadId: Option[Long], - blockedByLock: String, - holdingLocks: Seq[String]) - diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 5853302973140..d493663f0b168 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -63,6 +63,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} +import org.apache.spark.status.api.v1.ThreadStackTrace /** CallSite represents a place in user code. It can have a short and a long form. */ private[spark] case class CallSite(shortForm: String, longForm: String) @@ -2168,7 +2169,22 @@ private[spark] object Utils extends Logging { // We need to filter out null values here because dumpAllThreads() may return null array // elements for threads that are dead / don't exist. val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null) - threadInfos.sortBy(_.getThreadId).map(threadInfoToThreadStackTrace) + threadInfos.sortWith { case (threadTrace1, threadTrace2) => + val v1 = if (threadTrace1.getThreadName.contains("Executor task launch")) 1 else 0 + val v2 = if (threadTrace2.getThreadName.contains("Executor task launch")) 1 else 0 + if (v1 == v2) { + val name1 = threadTrace1.getThreadName().toLowerCase(Locale.ROOT) + val name2 = threadTrace2.getThreadName().toLowerCase(Locale.ROOT) + val nameCmpRes = name1.compareTo(name2) + if (nameCmpRes == 0) { + threadTrace1.getThreadId < threadTrace2.getThreadId + } else { + nameCmpRes < 0 + } + } else { + v1 > v2 + } + }.map(threadInfoToThreadStackTrace) } def getThreadDumpForThread(threadId: Long): Option[ThreadStackTrace] = { diff --git a/docs/monitoring.md b/docs/monitoring.md index 6f6cfc1288d73..d5f7ffcc260a1 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -347,6 +347,13 @@ can be identified by their `[attempt-id]`. In the API listed below, when running /applications/[app-id]/executors A list of all active executors for the given application. + + /applications/[app-id]/executors/[executor-id]/threads + + Stack traces of all the threads running within the given active executor. + Not available via the history server. + + /applications/[app-id]/allexecutors A list of all(active and dead) executors for the given application. From d6f5e172b480c62165be168deae0deff8062f476 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 13 Feb 2018 16:21:17 -0800 Subject: [PATCH 0350/2461] Revert "[SPARK-23303][SQL] improve the explain result for data source v2 relations" This reverts commit f17b936f0ddb7d46d1349bd42f9a64c84c06e48d. --- .../kafka010/KafkaContinuousSourceSuite.scala | 18 +++- .../sql/kafka010/KafkaContinuousTest.scala | 3 +- .../spark/sql/kafka010/KafkaSourceSuite.scala | 3 +- .../apache/spark/sql/DataFrameReader.scala | 8 +- .../v2/DataSourceReaderHolder.scala | 64 +++++++++++++ .../v2/DataSourceV2QueryPlan.scala | 96 ------------------- .../datasources/v2/DataSourceV2Relation.scala | 26 ++--- .../datasources/v2/DataSourceV2ScanExec.scala | 6 +- .../datasources/v2/DataSourceV2Strategy.scala | 4 +- .../v2/PushDownOperatorsToDataSource.scala | 4 +- .../streaming/MicroBatchExecution.scala | 22 ++--- .../continuous/ContinuousExecution.scala | 9 +- .../spark/sql/streaming/StreamSuite.scala | 8 +- .../spark/sql/streaming/StreamTest.scala | 2 +- .../continuous/ContinuousSuite.scala | 11 ++- 15 files changed, 127 insertions(+), 157 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index 72ee0c551ec3d..a7083fa4e3417 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -17,9 +17,20 @@ package org.apache.spark.sql.kafka010 -import org.apache.spark.sql.Dataset +import java.util.Properties +import java.util.concurrent.atomic.AtomicInteger + +import org.scalatest.time.SpanSugar._ +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.SparkContext +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.streaming.Trigger +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.streaming.{StreamTest, Trigger} +import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} // Run tests in KafkaSourceSuiteBase in continuous execution mode. class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest @@ -60,8 +71,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => - r.reader.asInstanceOf[KafkaContinuousReader] + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r }.exists { r => // Ensure the new topic is present and the old topic is gone. r.knownPartitions.exists(_.topic == topic2) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index d34458ac81014..5a1a14f7a307a 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -47,8 +47,7 @@ trait KafkaContinuousTest extends KafkaSourceTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => - r.reader.asInstanceOf[KafkaContinuousReader] + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index cb09cce75ff6f..02c87643568bd 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -117,8 +117,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { - case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => - r.reader.asInstanceOf[KafkaContinuousReader] + case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader } }) if (sources.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 984b6510f2dbe..fcaf8d618c168 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -189,9 +189,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance().asInstanceOf[DataSourceV2] + val ds = cls.newInstance() val options = new DataSourceOptions((extraOptions ++ - DataSourceV2Utils.extractSessionConfigs(ds, sparkSession.sessionState.conf)).asJava) + DataSourceV2Utils.extractSessionConfigs( + ds = ds.asInstanceOf[DataSourceV2], + conf = sparkSession.sessionState.conf)).asJava) // Streaming also uses the data source V2 API. So it may be that the data source implements // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading @@ -219,7 +221,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { if (reader == null) { loadV1Source(paths: _*) } else { - Dataset.ofRows(sparkSession, DataSourceV2Relation(ds, reader)) + Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) } } else { loadV1Source(paths: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala new file mode 100644 index 0000000000000..81219e9771bd8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util.Objects + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.sources.v2.reader._ + +/** + * A base class for data source reader holder with customized equals/hashCode methods. + */ +trait DataSourceReaderHolder { + + /** + * The output of the data source reader, w.r.t. column pruning. + */ + def output: Seq[Attribute] + + /** + * The held data source reader. + */ + def reader: DataSourceReader + + /** + * The metadata of this data source reader that can be used for equality test. + */ + private def metadata: Seq[Any] = { + val filters: Any = reader match { + case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet + case s: SupportsPushDownFilters => s.pushedFilters().toSet + case _ => Nil + } + Seq(output, reader.getClass, filters) + } + + def canEqual(other: Any): Boolean + + override def equals(other: Any): Boolean = other match { + case other: DataSourceReaderHolder => + canEqual(other) && metadata.length == other.metadata.length && + metadata.zip(other.metadata).forall { case (l, r) => l == r } + case _ => false + } + + override def hashCode(): Int = { + metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala deleted file mode 100644 index 1e0d088f3a57c..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import java.util.Objects - -import org.apache.commons.lang3.StringUtils - -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.DataSourceV2 -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.util.Utils - -/** - * A base class for data source v2 related query plan(both logical and physical). It defines the - * equals/hashCode methods, and provides a string representation of the query plan, according to - * some common information. - */ -trait DataSourceV2QueryPlan { - - /** - * The output of the data source reader, w.r.t. column pruning. - */ - def output: Seq[Attribute] - - /** - * The instance of this data source implementation. Note that we only consider its class in - * equals/hashCode, not the instance itself. - */ - def source: DataSourceV2 - - /** - * The created data source reader. Here we use it to get the filters that has been pushed down - * so far, itself doesn't take part in the equals/hashCode. - */ - def reader: DataSourceReader - - private lazy val filters = reader match { - case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet - case s: SupportsPushDownFilters => s.pushedFilters().toSet - case _ => Set.empty - } - - /** - * The metadata of this data source query plan that can be used for equality check. - */ - private def metadata: Seq[Any] = Seq(output, source.getClass, filters) - - def canEqual(other: Any): Boolean - - override def equals(other: Any): Boolean = other match { - case other: DataSourceV2QueryPlan => canEqual(other) && metadata == other.metadata - case _ => false - } - - override def hashCode(): Int = { - metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) - } - - def metadataString: String = { - val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] - if (filters.nonEmpty) entries += "PushedFilter" -> filters.mkString("[", ", ", "]") - - val outputStr = Utils.truncatedString(output, "[", ", ", "]") - - val entriesStr = if (entries.nonEmpty) { - Utils.truncatedString(entries.map { - case (key, value) => key + ": " + StringUtils.abbreviate(redact(value), 100) - }, " (", ", ", ")") - } else { - "" - } - - s"${source.getClass.getSimpleName}$outputStr$entriesStr" - } - - private def redact(text: String): String = { - Utils.redact(SQLConf.get.stringRedationPattern, text) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index cd97e0cab6b5c..38f6b15224788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -20,23 +20,15 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} -import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( output: Seq[AttributeReference], - source: DataSourceV2, - reader: DataSourceReader, - override val isStreaming: Boolean) - extends LeafNode with MultiInstanceRelation with DataSourceV2QueryPlan { + reader: DataSourceReader) + extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] - override def simpleString: String = { - val streamingHeader = if (isStreaming) "Streaming " else "" - s"${streamingHeader}Relation $metadataString" - } - override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) @@ -49,8 +41,18 @@ case class DataSourceV2Relation( } } +/** + * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical + * to the non-streaming relation. + */ +class StreamingDataSourceV2Relation( + output: Seq[AttributeReference], + reader: DataSourceReader) extends DataSourceV2Relation(output, reader) { + override def isStreaming: Boolean = true +} + object DataSourceV2Relation { - def apply(source: DataSourceV2, reader: DataSourceReader): DataSourceV2Relation = { - new DataSourceV2Relation(reader.readSchema().toAttributes, source, reader, isStreaming = false) + def apply(reader: DataSourceReader): DataSourceV2Relation = { + new DataSourceV2Relation(reader.readSchema().toAttributes, reader) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index c99d535efcf81..7d9581be4db89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types.StructType @@ -37,14 +36,11 @@ import org.apache.spark.sql.types.StructType */ case class DataSourceV2ScanExec( output: Seq[AttributeReference], - @transient source: DataSourceV2, @transient reader: DataSourceReader) - extends LeafExecNode with DataSourceV2QueryPlan with ColumnarBatchScan { + extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] - override def simpleString: String = s"Scan $metadataString" - override def outputPartitioning: physical.Partitioning = reader match { case s: SupportsReportPartitioning => new DataSourcePartitioning( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index fb61e6f32b1f4..df5b524485f54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case r: DataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.reader) :: Nil + case DataSourceV2Relation(output, reader) => + DataSourceV2ScanExec(output, reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 4cfdd50e8f46b..1ca6cbf061b4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -39,11 +39,11 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel // TODO: Ideally column pruning should be implemented via a plan property that is propagated // top-down, then we can simplify the logic here and only collect target operators. val filterPushed = plan transformUp { - case FilterAndProject(fields, condition, r: DataSourceV2Relation) => + case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) => val (candidates, nonDeterministic) = splitConjunctivePredicates(condition).partition(_.deterministic) - val stayUpFilters: Seq[Expression] = r.reader match { + val stayUpFilters: Seq[Expression] = reader match { case r: SupportsPushDownCatalystFilters => r.pushCatalystFilters(candidates.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 84564b6639ac9..812533313332e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -27,9 +27,9 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} @@ -52,8 +52,6 @@ class MicroBatchExecution( @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty - private val readerToDataSourceMap = MutableMap.empty[MicroBatchReader, DataSourceV2] - private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) case OneTimeTrigger => OneTimeExecutor() @@ -92,7 +90,6 @@ class MicroBatchExecution( metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 - readerToDataSourceMap(reader) = source StreamingExecutionRelation(reader, output)(sparkSession) }) case s @ StreamingRelationV2(_, sourceName, _, output, v1Relation) => @@ -408,15 +405,12 @@ class MicroBatchExecution( case v1: SerializedOffset => reader.deserializeOffset(v1.json) case v2: OffsetV2 => v2 } - reader.setOffsetRange(toJava(current), Optional.of(availableV2)) + reader.setOffsetRange( + toJava(current), + Optional.of(availableV2)) logDebug(s"Retrieving data from $reader: $current -> $availableV2") - Some(reader -> new DataSourceV2Relation( - reader.readSchema().toAttributes, - // Provide a fake value here just in case something went wrong, e.g. the reader gives - // a wrong `equals` implementation. - readerToDataSourceMap.getOrElse(reader, FakeDataSourceV2), - reader, - isStreaming = true)) + Some(reader -> + new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader)) case _ => None } } @@ -506,5 +500,3 @@ class MicroBatchExecution( Optional.ofNullable(scalaOption.orNull) } } - -object FakeDataSourceV2 extends DataSourceV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index f87d57d0b3209..c3294d64b10cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} @@ -167,7 +167,7 @@ class ContinuousExecution( var insertedSourceId = 0 val withNewSources = logicalPlan transform { - case ContinuousExecutionRelation(ds, _, output) => + case ContinuousExecutionRelation(_, _, output) => val reader = continuousSources(insertedSourceId) insertedSourceId += 1 val newOutput = reader.readSchema().toAttributes @@ -180,7 +180,7 @@ class ContinuousExecution( val loggedOffset = offsets.offsets(0) val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) - new DataSourceV2Relation(newOutput, ds, reader, isStreaming = true) + new StreamingDataSourceV2Relation(newOutput, reader) } // Rewire the plan to use the new attributes that were returned by the source. @@ -201,8 +201,7 @@ class ContinuousExecution( val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { - case r: DataSourceV2Relation if r.reader.isInstanceOf[ContinuousReader] => - r.reader.asInstanceOf[ContinuousReader] + case DataSourceV2Relation(_, r: ContinuousReader) => r }.head reportTimeTaken("queryPlanning") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 70eb9f0ac66d5..d1a04833390f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -492,16 +492,16 @@ class StreamSuite extends StreamTest { val explainWithoutExtended = q.explainInternal(false) // `extended = false` only displays the physical plan. - assert("Streaming Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) - assert("Scan FakeDataSourceV2".r.findAllMatchIn(explainWithoutExtended).size === 1) + assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) + assert("DataSourceV2Scan".r.findAllMatchIn(explainWithoutExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithoutExtended.contains("StateStoreRestore")) val explainWithExtended = q.explainInternal(true) // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical // plan. - assert("Streaming Relation".r.findAllMatchIn(explainWithExtended).size === 3) - assert("Scan FakeDataSourceV2".r.findAllMatchIn(explainWithExtended).size === 1) + assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithExtended).size === 3) + assert("DataSourceV2Scan".r.findAllMatchIn(explainWithExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithExtended.contains("StateStoreRestore")) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 254394685857b..37fe595529baf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -605,7 +605,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be plan .collect { case StreamingExecutionRelation(s, _) => s - case d: DataSourceV2Relation => d.reader + case DataSourceV2Relation(_, r) => r } .zipWithIndex .find(_._1 == source) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 9ee9aaf87f87c..4b4ed82dc6520 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.streaming.continuous -import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} +import java.util.UUID + +import org.apache.spark.{SparkContext, SparkEnv, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart} import org.apache.spark.sql._ -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.test.TestSparkSession @@ -40,7 +43,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, _, r: RateStreamContinuousReader) => r + case DataSourceV2ScanExec(_, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 From 357babde5a8eb9710de7016d7ae82dee21fa4ef3 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 14 Feb 2018 10:55:24 +0800 Subject: [PATCH 0351/2461] [SPARK-23399][SQL] Register a task completion listener first for OrcColumnarBatchReader ## What changes were proposed in this pull request? This PR aims to resolve an open file leakage issue reported at [SPARK-23390](https://issues.apache.org/jira/browse/SPARK-23390) by moving the listener registration position. Currently, the sequence is like the following. 1. Create `batchReader` 2. `batchReader.initialize` opens a ORC file. 3. `batchReader.initBatch` may take a long time to alloc memory in some environment and cause errors. 4. `Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))` This PR moves 4 before 2 and 3. To sum up, the new sequence is 1 -> 4 -> 2 -> 3. ## How was this patch tested? Manual. The following test case makes OOM intentionally to cause leaked filesystem connection in the current code base. With this patch, leakage doesn't occurs. ```scala // This should be tested manually because it raises OOM intentionally // in order to cause `Leaked filesystem connection`. test("SPARK-23399 Register a task completion listener first for OrcColumnarBatchReader") { withSQLConf(SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE.key -> s"${Int.MaxValue}") { withTempDir { dir => val basePath = dir.getCanonicalPath Seq(0).toDF("a").write.format("orc").save(new Path(basePath, "first").toString) Seq(1).toDF("a").write.format("orc").save(new Path(basePath, "second").toString) val df = spark.read.orc( new Path(basePath, "first").toString, new Path(basePath, "second").toString) val e = intercept[SparkException] { df.collect() } assert(e.getCause.isInstanceOf[OutOfMemoryError]) } } } ``` Author: Dongjoon Hyun Closes #20590 from dongjoon-hyun/SPARK-23399. --- .../sql/execution/datasources/orc/OrcFileFormat.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index dbf3bc6f0ee6c..1de2ca2914c44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -188,6 +188,12 @@ class OrcFileFormat if (enableVectorizedReader) { val batchReader = new OrcColumnarBatchReader( enableOffHeapColumnVector && taskContext.isDefined, copyToSpark, capacity) + // SPARK-23399 Register a task completion listener first to call `close()` in all cases. + // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM) + // after opening a file. + val iter = new RecordReaderIterator(batchReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + batchReader.initialize(fileSplit, taskAttemptContext) batchReader.initBatch( reader.getSchema, @@ -196,8 +202,6 @@ class OrcFileFormat partitionSchema, file.partitionValues) - val iter = new RecordReaderIterator(batchReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) iter.asInstanceOf[Iterator[InternalRow]] } else { val orcRecordReader = new OrcInputFormat[OrcStruct] From 140f87533a468b1046504fc3ff01fbe1637e41cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Wed, 14 Feb 2018 06:45:54 -0800 Subject: [PATCH 0352/2461] [SPARK-23394][UI] In RDD storage page show the executor addresses instead of the IDs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Extending RDD storage page to show executor addresses in the block table. ## How was this patch tested? Manually: ![screen shot 2018-02-13 at 10 30 59](https://user-images.githubusercontent.com/2017933/36142668-0b3578f8-10a9-11e8-95ea-2f57703ee4af.png) Author: “attilapiros” Closes #20589 from attilapiros/SPARK-23394. --- .../org/apache/spark/ui/storage/RDDPage.scala | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 02cee7f8c5b33..2674b9291203a 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -23,7 +23,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Unparsed} import org.apache.spark.status.AppStatusStore -import org.apache.spark.status.api.v1.{RDDDataDistribution, RDDPartitionInfo} +import org.apache.spark.status.api.v1.{ExecutorSummary, RDDDataDistribution, RDDPartitionInfo} import org.apache.spark.ui._ import org.apache.spark.util.Utils @@ -76,7 +76,8 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web rddStorageInfo.partitions.get, blockPageSize, blockSortColumn, - blockSortDesc) + blockSortDesc, + store.executorList(true)) _blockTable.table(page) } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => @@ -182,7 +183,8 @@ private[ui] class BlockDataSource( rddPartitions: Seq[RDDPartitionInfo], pageSize: Int, sortColumn: String, - desc: Boolean) extends PagedDataSource[BlockTableRowData](pageSize) { + desc: Boolean, + executorIdToAddress: Map[String, String]) extends PagedDataSource[BlockTableRowData](pageSize) { private val data = rddPartitions.map(blockRow).sorted(ordering(sortColumn, desc)) @@ -198,7 +200,10 @@ private[ui] class BlockDataSource( rddPartition.storageLevel, rddPartition.memoryUsed, rddPartition.diskUsed, - rddPartition.executors.mkString(" ")) + rddPartition.executors + .map { id => executorIdToAddress.get(id).getOrElse(id) } + .sorted + .mkString(" ")) } /** @@ -226,7 +231,8 @@ private[ui] class BlockPagedTable( rddPartitions: Seq[RDDPartitionInfo], pageSize: Int, sortColumn: String, - desc: Boolean) extends PagedTable[BlockTableRowData] { + desc: Boolean, + executorSummaries: Seq[ExecutorSummary]) extends PagedTable[BlockTableRowData] { override def tableId: String = "rdd-storage-by-block-table" @@ -243,7 +249,8 @@ private[ui] class BlockPagedTable( rddPartitions, pageSize, sortColumn, - desc) + desc, + executorSummaries.map { ex => (ex.id, ex.hostPort) }.toMap) override def pageLink(page: Int): String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") From 400a1d9e25c1196f0be87323bd89fb3af0660166 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 14 Feb 2018 10:57:12 -0800 Subject: [PATCH 0353/2461] Revert "[SPARK-23249][SQL] Improved block merging logic for partitions" This reverts commit 8c21170decfb9ca4d3233e1ea13bd1b6e3199ed9. --- .../sql/execution/DataSourceScanExec.scala | 29 +++++-------------- .../datasources/FileSourceStrategySuite.scala | 15 ++++++---- 2 files changed, 17 insertions(+), 27 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index ba1157d5b6a49..08ff33afbba3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -444,29 +444,16 @@ case class FileSourceScanExec( currentSize = 0 } - def addFile(file: PartitionedFile): Unit = { - currentFiles += file - currentSize += file.length + openCostInBytes - } - - var frontIndex = 0 - var backIndex = splitFiles.length - 1 - - while (frontIndex <= backIndex) { - addFile(splitFiles(frontIndex)) - frontIndex += 1 - while (frontIndex <= backIndex && - currentSize + splitFiles(frontIndex).length <= maxSplitBytes) { - addFile(splitFiles(frontIndex)) - frontIndex += 1 - } - while (backIndex > frontIndex && - currentSize + splitFiles(backIndex).length <= maxSplitBytes) { - addFile(splitFiles(backIndex)) - backIndex -= 1 + // Assign files to partitions using "Next Fit Decreasing" + splitFiles.foreach { file => + if (currentSize + file.length > maxSplitBytes) { + closePartition() } - closePartition() + // Add the given file to the current partition. + currentSize += file.length + openCostInBytes + currentFiles += file } + closePartition() new FileScanRDD(fsRelation.sparkSession, readFile, partitions) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index bfccc9335b361..c1d61b843d899 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -141,17 +141,16 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "4", SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") { checkScan(table.select('c1)) { partitions => - // Files should be laid out [(file1, file6), (file2, file3), (file4, file5)] - assert(partitions.size == 3, "when checking partitions") - assert(partitions(0).files.size == 2, "when checking partition 1") + // Files should be laid out [(file1), (file2, file3), (file4, file5), (file6)] + assert(partitions.size == 4, "when checking partitions") + assert(partitions(0).files.size == 1, "when checking partition 1") assert(partitions(1).files.size == 2, "when checking partition 2") assert(partitions(2).files.size == 2, "when checking partition 3") + assert(partitions(3).files.size == 1, "when checking partition 4") - // First partition reads (file1, file6) + // First partition reads (file1) assert(partitions(0).files(0).start == 0) assert(partitions(0).files(0).length == 2) - assert(partitions(0).files(1).start == 0) - assert(partitions(0).files(1).length == 1) // Second partition reads (file2, file3) assert(partitions(1).files(0).start == 0) @@ -164,6 +163,10 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi assert(partitions(2).files(0).length == 1) assert(partitions(2).files(1).start == 0) assert(partitions(2).files(1).length == 1) + + // Final partition reads (file6) + assert(partitions(3).files(0).start == 0) + assert(partitions(3).files(0).length == 1) } checkPartitionSchema(StructType(Nil)) From 658d9d9d785a30857bf35d164e6cbbd9799d6959 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 14 Feb 2018 14:27:02 -0800 Subject: [PATCH 0354/2461] [SPARK-23406][SS] Enable stream-stream self-joins ## What changes were proposed in this pull request? Solved two bugs to enable stream-stream self joins. ### Incorrect analysis due to missing MultiInstanceRelation trait Streaming leaf nodes did not extend MultiInstanceRelation, which is necessary for the catalyst analyzer to convert the self-join logical plan DAG into a tree (by creating new instances of the leaf relations). This was causing the error `Failure when resolving conflicting references in Join:` (see JIRA for details). ### Incorrect attribute rewrite when splicing batch plans in MicroBatchExecution When splicing the source's batch plan into the streaming plan (by replacing the StreamingExecutionPlan), we were rewriting the attribute reference in the streaming plan with the new attribute references from the batch plan. This was incorrectly handling the scenario when multiple StreamingExecutionRelation point to the same source, and therefore eventually point to the same batch plan returned by the source. Here is an example query, and its corresponding plan transformations. ``` val df = input.toDF val join = df.select('value % 5 as "key", 'value).join( df.select('value % 5 as "key", 'value), "key") ``` Streaming logical plan before splicing the batch plan ``` Project [key#6, value#1, value#12] +- Join Inner, (key#6 = key#9) :- Project [(value#1 % 5) AS key#6, value#1] : +- StreamingExecutionRelation Memory[#1], value#1 +- Project [(value#12 % 5) AS key#9, value#12] +- StreamingExecutionRelation Memory[#1], value#12 // two different leaves pointing to same source ``` Batch logical plan after splicing the batch plan and before rewriting ``` Project [key#6, value#1, value#12] +- Join Inner, (key#6 = key#9) :- Project [(value#1 % 5) AS key#6, value#1] : +- LocalRelation [value#66] // replaces StreamingExecutionRelation Memory[#1], value#1 +- Project [(value#12 % 5) AS key#9, value#12] +- LocalRelation [value#66] // replaces StreamingExecutionRelation Memory[#1], value#12 ``` Batch logical plan after rewriting the attributes. Specifically, for spliced, the new output attributes (value#66) replace the earlier output attributes (value#12, and value#1, one for each StreamingExecutionRelation). ``` Project [key#6, value#66, value#66] // both value#1 and value#12 replaces by value#66 +- Join Inner, (key#6 = key#9) :- Project [(value#66 % 5) AS key#6, value#66] : +- LocalRelation [value#66] +- Project [(value#66 % 5) AS key#9, value#66] +- LocalRelation [value#66] ``` This causes the optimizer to eliminate value#66 from one side of the join. ``` Project [key#6, value#66, value#66] +- Join Inner, (key#6 = key#9) :- Project [(value#66 % 5) AS key#6, value#66] : +- LocalRelation [value#66] +- Project [(value#66 % 5) AS key#9] // this does not generate value, incorrect join results +- LocalRelation [value#66] ``` **Solution**: Instead of rewriting attributes, use a Project to introduce aliases between the output attribute references and the new reference generated by the spliced plans. The analyzer and optimizer will take care of the rest. ``` Project [key#6, value#1, value#12] +- Join Inner, (key#6 = key#9) :- Project [(value#1 % 5) AS key#6, value#1] : +- Project [value#66 AS value#1] // solution: project with aliases : +- LocalRelation [value#66] +- Project [(value#12 % 5) AS key#9, value#12] +- Project [value#66 AS value#12] // solution: project with aliases +- LocalRelation [value#66] ``` ## How was this patch tested? New unit test Author: Tathagata Das Closes #20598 from tdas/SPARK-23406. --- .../streaming/MicroBatchExecution.scala | 16 ++++++------ .../streaming/StreamingRelation.scala | 20 ++++++++++----- .../sql/streaming/StreamingJoinSuite.scala | 25 ++++++++++++++++++- 3 files changed, 45 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 812533313332e..ac73ba3417904 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -24,8 +24,8 @@ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} @@ -415,8 +415,6 @@ class MicroBatchExecution( } } - // A list of attributes that will need to be updated. - val replacements = new ArrayBuffer[(Attribute, Attribute)] // Replace sources in the logical plan with data that has arrived since the last batch. val newBatchesPlan = logicalPlan transform { case StreamingExecutionRelation(source, output) => @@ -424,18 +422,18 @@ class MicroBatchExecution( assert(output.size == dataPlan.output.size, s"Invalid batch: ${Utils.truncatedString(output, ",")} != " + s"${Utils.truncatedString(dataPlan.output, ",")}") - replacements ++= output.zip(dataPlan.output) - dataPlan + + val aliases = output.zip(dataPlan.output).map { case (to, from) => + Alias(from, to.name)(exprId = to.exprId, explicitMetadata = Some(from.metadata)) + } + Project(aliases, dataPlan) }.getOrElse { LocalRelation(output, isStreaming = true) } } // Rewire the plan to use the new attributes that were returned by the source. - val replacementMap = AttributeMap(replacements) val newAttributePlan = newBatchesPlan transformAllExpressions { - case a: Attribute if replacementMap.contains(a) => - replacementMap(a).withMetadata(a.metadata) case ct: CurrentTimestamp => CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs, ct.dataType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 7146190645b37..f02d3a2c3733f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.LeafNode -import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2} @@ -42,7 +42,7 @@ object StreamingRelation { * passing to [[StreamExecution]] to run a query. */ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: Seq[Attribute]) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { override def isStreaming: Boolean = true override def toString: String = sourceName @@ -53,6 +53,8 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: override def computeStats(): Statistics = Statistics( sizeInBytes = BigInt(dataSource.sparkSession.sessionState.conf.defaultSizeInBytes) ) + + override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance())) } /** @@ -62,7 +64,7 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: case class StreamingExecutionRelation( source: BaseStreamingSource, output: Seq[Attribute])(session: SparkSession) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { override def isStreaming: Boolean = true override def toString: String = source.toString @@ -74,6 +76,8 @@ case class StreamingExecutionRelation( override def computeStats(): Statistics = Statistics( sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) ) + + override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session) } // We have to pack in the V1 data source as a shim, for the case when a source implements @@ -92,13 +96,15 @@ case class StreamingRelationV2( extraOptions: Map[String, String], output: Seq[Attribute], v1Relation: Option[StreamingRelation])(session: SparkSession) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { override def isStreaming: Boolean = true override def toString: String = sourceName override def computeStats(): Statistics = Statistics( sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) ) + + override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session) } /** @@ -108,7 +114,7 @@ case class ContinuousExecutionRelation( source: ContinuousReadSupport, extraOptions: Map[String, String], output: Seq[Attribute])(session: SparkSession) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { override def isStreaming: Boolean = true override def toString: String = source.toString @@ -120,6 +126,8 @@ case class ContinuousExecutionRelation( override def computeStats(): Statistics = Statistics( sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) ) + + override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 54eb863dacc83..92087f68ad74a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -28,7 +28,9 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Literal} import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter} -import org.apache.spark.sql.execution.LogicalRDD +import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.execution.{FileSourceScanExec, LogicalRDD} +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreProviderId} import org.apache.spark.sql.functions._ @@ -323,6 +325,27 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with assert(e.toString.contains("Stream stream joins without equality predicate is not supported")) } + test("stream stream self join") { + val input = MemoryStream[Int] + val df = input.toDF + val join = + df.select('value % 5 as "key", 'value).join( + df.select('value % 5 as "key", 'value), "key") + + testStream(join)( + AddData(input, 1, 2), + CheckAnswer((1, 1, 1), (2, 2, 2)), + StopStream, + StartStream(), + AddData(input, 3, 6), + /* + (1, 1) (1, 1) + (2, 2) x (2, 2) = (1, 1, 1), (1, 1, 6), (2, 2, 2), (1, 6, 1), (1, 6, 6) + (1, 6) (1, 6) + */ + CheckAnswer((3, 3, 3), (1, 1, 1), (1, 1, 6), (2, 2, 2), (1, 6, 1), (1, 6, 6))) + } + test("locality preferences of StateStoreAwareZippedRDD") { import StreamingSymmetricHashJoinHelper._ From a77ebb0921e390cf4fc6279a8c0a92868ad7e69b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 14 Feb 2018 23:52:59 -0800 Subject: [PATCH 0355/2461] [SPARK-23421][SPARK-22356][SQL] Document the behavior change in ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/19579 introduces a behavior change. We need to document it in the migration guide. ## How was this patch tested? Also update the HiveExternalCatalogVersionsSuite to verify it. Author: gatorsmile Closes #20606 from gatorsmile/addMigrationGuide. --- docs/sql-programming-guide.md | 2 ++ .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 0f9f01e18682f..cf9529a79f4f9 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1963,6 +1963,8 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.1 to 2.2 - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. + + - Since Spark 2.2.1 and 2.3.0, the schema is always inferred at runtime when the data source tables have the columns that exist in both partition schema and data schema. The inferred schema does not have the partitioned columns. When reading the table, Spark respects the partition values of these overlapping columns instead of the values stored in the data source files. In 2.2.0 and 2.1.x release, the inferred schema is partitioned but the data of the table is invisible to users (i.e., the result set is empty). ## Upgrading From Spark SQL 2.0 to 2.1 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index ae4aeb7b4ce4a..c13a750dbb270 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -195,7 +195,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0") + val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0", "2.2.1") protected var spark: SparkSession = _ @@ -249,7 +249,7 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { // SPARK-22356: overlapped columns between data and partition schema in data source tables val tbl_with_col_overlap = s"tbl_with_col_overlap_$index" - // For Spark 2.2.0 and 2.1.x, the behavior is different from Spark 2.0. + // For Spark 2.2.0 and 2.1.x, the behavior is different from Spark 2.0, 2.2.1, 2.3+ if (testingVersions(index).startsWith("2.1") || testingVersions(index) == "2.2.0") { spark.sql("msck repair table " + tbl_with_col_overlap) assert(spark.table(tbl_with_col_overlap).columns === Array("i", "j", "p")) From 95e4b4916065e66a4f8dba57e98e725796f75e04 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 14 Feb 2018 23:56:02 -0800 Subject: [PATCH 0356/2461] [SPARK-23094] Revert [] Fix invalid character handling in JsonDataSource ## What changes were proposed in this pull request? This PR is to revert the PR https://github.com/apache/spark/pull/20302, because it causes a regression. ## How was this patch tested? N/A Author: gatorsmile Closes #20614 from gatorsmile/revertJsonFix. --- .../catalyst/json/CreateJacksonParser.scala | 5 ++- .../sources/JsonHadoopFsRelationSuite.scala | 34 ------------------- 2 files changed, 2 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala index b1672e7e2fca2..025a388aacaa5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala @@ -40,11 +40,10 @@ private[sql] object CreateJacksonParser extends Serializable { } def text(jsonFactory: JsonFactory, record: Text): JsonParser = { - val bain = new ByteArrayInputStream(record.getBytes, 0, record.getLength) - jsonFactory.createParser(new InputStreamReader(bain, "UTF-8")) + jsonFactory.createParser(record.getBytes, 0, record.getLength) } def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = { - jsonFactory.createParser(new InputStreamReader(record, "UTF-8")) + jsonFactory.createParser(record) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala index 27f398ebf301a..49be30435ad2f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -28,8 +28,6 @@ import org.apache.spark.sql.types._ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = "json" - private val badJson = "\u0000\u0000\u0000A\u0001AAA" - // JSON does not write data of NullType and does not play well with BinaryType. override protected def supportsDataType(dataType: DataType): Boolean = dataType match { case _: NullType => false @@ -107,36 +105,4 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { ) } } - - test("invalid json with leading nulls - from file (multiLine=true)") { - import testImplicits._ - withTempDir { tempDir => - val path = tempDir.getAbsolutePath - Seq(badJson, """{"a":1}""").toDS().write.mode("overwrite").text(path) - val expected = s"""$badJson\n{"a":1}\n""" - val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) - val df = - spark.read.format(dataSourceName).option("multiLine", true).schema(schema).load(path) - checkAnswer(df, Row(null, expected)) - } - } - - test("invalid json with leading nulls - from file (multiLine=false)") { - import testImplicits._ - withTempDir { tempDir => - val path = tempDir.getAbsolutePath - Seq(badJson, """{"a":1}""").toDS().write.mode("overwrite").text(path) - val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) - val df = - spark.read.format(dataSourceName).option("multiLine", false).schema(schema).load(path) - checkAnswer(df, Seq(Row(1, null), Row(null, badJson))) - } - } - - test("invalid json with leading nulls - from dataset") { - import testImplicits._ - checkAnswer( - spark.read.json(Seq(badJson).toDS()), - Row(badJson)) - } } From f38c760638063f1fb45e9ee2c772090fb203a4a0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 15 Feb 2018 16:59:44 +0800 Subject: [PATCH 0357/2461] [SPARK-23419][SPARK-23416][SS] data source v2 write path should re-throw interruption exceptions directly ## What changes were proposed in this pull request? Streaming execution has a list of exceptions that means interruption, and handle them specially. `WriteToDataSourceV2Exec` should also respect this list and not wrap them with `SparkException`. ## How was this patch tested? existing test. Author: Wenchen Fan Closes #20605 from cloud-fan/write. --- .../datasources/v2/WriteToDataSourceV2.scala | 11 ++++- .../execution/streaming/StreamExecution.scala | 40 ++++++++++--------- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 535e7962d7439..41cdfc80d8a19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.v2 +import scala.util.control.NonFatal + import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging @@ -27,6 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter @@ -107,7 +110,13 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e throw new SparkException("Writing job failed.", cause) } logError(s"Data source writer $writer aborted.") - throw new SparkException("Writing job aborted.", cause) + cause match { + // Do not wrap interruption exceptions that will be handled by streaming specially. + case _ if StreamExecution.isInterruptionException(cause) => throw cause + // Only wrap non fatal exceptions. + case NonFatal(e) => throw new SparkException("Writing job aborted.", e) + case _ => throw cause + } } sparkContext.emptyRDD diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index e7982d7880ceb..3fc8c7887896a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -356,25 +356,7 @@ abstract class StreamExecution( private def isInterruptedByStop(e: Throwable): Boolean = { if (state.get == TERMINATED) { - e match { - // InterruptedIOException - thrown when an I/O operation is interrupted - // ClosedByInterruptException - thrown when an I/O operation upon a channel is interrupted - case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException => - true - // The cause of the following exceptions may be one of the above exceptions: - // - // UncheckedIOException - thrown by codes that cannot throw a checked IOException, such as - // BiFunction.apply - // ExecutionException - thrown by codes running in a thread pool and these codes throw an - // exception - // UncheckedExecutionException - thrown by codes that cannot throw a checked - // ExecutionException, such as BiFunction.apply - case e2 @ (_: UncheckedIOException | _: ExecutionException | _: UncheckedExecutionException) - if e2.getCause != null => - isInterruptedByStop(e2.getCause) - case _ => - false - } + StreamExecution.isInterruptionException(e) } else { false } @@ -565,6 +547,26 @@ abstract class StreamExecution( object StreamExecution { val QUERY_ID_KEY = "sql.streaming.queryId" + + def isInterruptionException(e: Throwable): Boolean = e match { + // InterruptedIOException - thrown when an I/O operation is interrupted + // ClosedByInterruptException - thrown when an I/O operation upon a channel is interrupted + case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException => + true + // The cause of the following exceptions may be one of the above exceptions: + // + // UncheckedIOException - thrown by codes that cannot throw a checked IOException, such as + // BiFunction.apply + // ExecutionException - thrown by codes running in a thread pool and these codes throw an + // exception + // UncheckedExecutionException - thrown by codes that cannot throw a checked + // ExecutionException, such as BiFunction.apply + case e2 @ (_: UncheckedIOException | _: ExecutionException | _: UncheckedExecutionException) + if e2.getCause != null => + isInterruptionException(e2.getCause) + case _ => + false + } } /** From 7539ae59d6c354c95c50528abe9ddff6972e960f Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Thu, 15 Feb 2018 17:09:06 +0800 Subject: [PATCH 0358/2461] [SPARK-23366] Improve hot reading path in ReadAheadInputStream MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? `ReadAheadInputStream` was introduced in https://github.com/apache/spark/pull/18317/ to optimize reading spill files from disk. However, from the profiles it seems that the hot path of reading small amounts of data (like readInt) is inefficient - it involves taking locks, and multiple checks. Optimize locking: Lock is not needed when simply accessing the active buffer. Only lock when needing to swap buffers or trigger async reading, or get information about the async state. Optimize short-path single byte reads, that are used e.g. by Java library DataInputStream.readInt. The asyncReader used to call "read" only once on the underlying stream, that never filled the underlying buffer when it was wrapping an LZ4BlockInputStream. If the buffer was returned unfilled, that would trigger the async reader to be triggered to fill the read ahead buffer on each call, because the reader would see that the active buffer is below the refill threshold all the time. However, filling the full buffer all the time could introduce increased latency, so also add an `AtomicBoolean` flag for the async reader to return earlier if there is a reader waiting for data. Remove `readAheadThresholdInBytes` and instead immediately trigger async read when switching the buffers. It allows to simplify code paths, especially the hot one that then only has to check if there is available data in the active buffer, without worrying if it needs to retrigger async read. It seems to have positive effect on perf. ## How was this patch tested? It was noticed as a regression in some workloads after upgrading to Spark 2.3.  It was particularly visible on TPCDS Q95 running on instances with fast disk (i3 AWS instances). Running with profiling: * Spark 2.2 - 5.2-5.3 minutes 9.5% in LZ4BlockInputStream.read * Spark 2.3 - 6.4-6.6 minutes 31.1% in ReadAheadInputStream.read * Spark 2.3 + fix - 5.3-5.4 minutes 13.3% in ReadAheadInputStream.read - very slightly slower, practically within noise. We didn't see other regressions, and many workloads in general seem to be faster with Spark 2.3 (not investigated if thanks to async readed, or unrelated). Author: Juliusz Sompolski Closes #20555 from juliuszsompolski/SPARK-23366. --- .../apache/spark/io/ReadAheadInputStream.java | 119 +++++++++--------- .../unsafe/sort/UnsafeSorterSpillReader.java | 10 +- .../spark/io/GenericFileInputStreamSuite.java | 98 ++++++++------- .../spark/io/NioBufferedInputStreamSuite.java | 6 +- .../spark/io/ReadAheadInputStreamSuite.java | 17 ++- 5 files changed, 133 insertions(+), 117 deletions(-) diff --git a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java index 5b45d268ace8d..0cced9e222952 100644 --- a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java +++ b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java @@ -27,6 +27,7 @@ import java.nio.ByteBuffer; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.ReentrantLock; @@ -78,9 +79,8 @@ public class ReadAheadInputStream extends InputStream { // whether there is a read ahead task running, private boolean isReading; - // If the remaining data size in the current buffer is below this threshold, - // we issue an async read from the underlying input stream. - private final int readAheadThresholdInBytes; + // whether there is a reader waiting for data. + private AtomicBoolean isWaiting = new AtomicBoolean(false); private final InputStream underlyingInputStream; @@ -97,20 +97,13 @@ public class ReadAheadInputStream extends InputStream { * * @param inputStream The underlying input stream. * @param bufferSizeInBytes The buffer size. - * @param readAheadThresholdInBytes If the active buffer has less data than the read-ahead - * threshold, an async read is triggered. */ public ReadAheadInputStream( - InputStream inputStream, int bufferSizeInBytes, int readAheadThresholdInBytes) { + InputStream inputStream, int bufferSizeInBytes) { Preconditions.checkArgument(bufferSizeInBytes > 0, "bufferSizeInBytes should be greater than 0, but the value is " + bufferSizeInBytes); - Preconditions.checkArgument(readAheadThresholdInBytes > 0 && - readAheadThresholdInBytes < bufferSizeInBytes, - "readAheadThresholdInBytes should be greater than 0 and less than bufferSizeInBytes, " + - "but the value is " + readAheadThresholdInBytes); activeBuffer = ByteBuffer.allocate(bufferSizeInBytes); readAheadBuffer = ByteBuffer.allocate(bufferSizeInBytes); - this.readAheadThresholdInBytes = readAheadThresholdInBytes; this.underlyingInputStream = inputStream; activeBuffer.flip(); readAheadBuffer.flip(); @@ -166,12 +159,17 @@ public void run() { // in that case the reader waits for this async read to complete. // So there is no race condition in both the situations. int read = 0; + int off = 0, len = arr.length; Throwable exception = null; try { - while (true) { - read = underlyingInputStream.read(arr); - if (0 != read) break; - } + // try to fill the read ahead buffer. + // if a reader is waiting, possibly return early. + do { + read = underlyingInputStream.read(arr, off, len); + if (read <= 0) break; + off += read; + len -= read; + } while (len > 0 && !isWaiting.get()); } catch (Throwable ex) { exception = ex; if (ex instanceof Error) { @@ -181,13 +179,12 @@ public void run() { } } finally { stateChangeLock.lock(); + readAheadBuffer.limit(off); if (read < 0 || (exception instanceof EOFException)) { endOfStream = true; } else if (exception != null) { readAborted = true; readException = exception; - } else { - readAheadBuffer.limit(read); } readInProgress = false; signalAsyncReadComplete(); @@ -230,7 +227,10 @@ private void signalAsyncReadComplete() { private void waitForAsyncReadComplete() throws IOException { stateChangeLock.lock(); + isWaiting.set(true); try { + // There is only one reader, and one writer, so the writer should signal only once, + // but a while loop checking the wake up condition is still needed to avoid spurious wakeups. while (readInProgress) { asyncReadComplete.await(); } @@ -239,6 +239,7 @@ private void waitForAsyncReadComplete() throws IOException { iio.initCause(e); throw iio; } finally { + isWaiting.set(false); stateChangeLock.unlock(); } checkReadException(); @@ -246,8 +247,13 @@ private void waitForAsyncReadComplete() throws IOException { @Override public int read() throws IOException { - byte[] oneByteArray = oneByte.get(); - return read(oneByteArray, 0, 1) == -1 ? -1 : oneByteArray[0] & 0xFF; + if (activeBuffer.hasRemaining()) { + // short path - just get one byte. + return activeBuffer.get() & 0xFF; + } else { + byte[] oneByteArray = oneByte.get(); + return read(oneByteArray, 0, 1) == -1 ? -1 : oneByteArray[0] & 0xFF; + } } @Override @@ -258,54 +264,43 @@ public int read(byte[] b, int offset, int len) throws IOException { if (len == 0) { return 0; } - stateChangeLock.lock(); - try { - return readInternal(b, offset, len); - } finally { - stateChangeLock.unlock(); - } - } - /** - * flip the active and read ahead buffer - */ - private void swapBuffers() { - ByteBuffer temp = activeBuffer; - activeBuffer = readAheadBuffer; - readAheadBuffer = temp; - } - - /** - * Internal read function which should be called only from read() api. The assumption is that - * the stateChangeLock is already acquired in the caller before calling this function. - */ - private int readInternal(byte[] b, int offset, int len) throws IOException { - assert (stateChangeLock.isLocked()); if (!activeBuffer.hasRemaining()) { - waitForAsyncReadComplete(); - if (readAheadBuffer.hasRemaining()) { - swapBuffers(); - } else { - // The first read or activeBuffer is skipped. - readAsync(); + // No remaining in active buffer - lock and switch to write ahead buffer. + stateChangeLock.lock(); + try { waitForAsyncReadComplete(); - if (isEndOfStream()) { - return -1; + if (!readAheadBuffer.hasRemaining()) { + // The first read. + readAsync(); + waitForAsyncReadComplete(); + if (isEndOfStream()) { + return -1; + } } + // Swap the newly read read ahead buffer in place of empty active buffer. swapBuffers(); + // After swapping buffers, trigger another async read for read ahead buffer. + readAsync(); + } finally { + stateChangeLock.unlock(); } - } else { - checkReadException(); } len = Math.min(len, activeBuffer.remaining()); activeBuffer.get(b, offset, len); - if (activeBuffer.remaining() <= readAheadThresholdInBytes && !readAheadBuffer.hasRemaining()) { - readAsync(); - } return len; } + /** + * flip the active and read ahead buffer + */ + private void swapBuffers() { + ByteBuffer temp = activeBuffer; + activeBuffer = readAheadBuffer; + readAheadBuffer = temp; + } + @Override public int available() throws IOException { stateChangeLock.lock(); @@ -323,6 +318,11 @@ public long skip(long n) throws IOException { if (n <= 0L) { return 0L; } + if (n <= activeBuffer.remaining()) { + // Only skipping from active buffer is sufficient + activeBuffer.position((int) n + activeBuffer.position()); + return n; + } stateChangeLock.lock(); long skipped; try { @@ -346,21 +346,14 @@ private long skipInternal(long n) throws IOException { if (available() >= n) { // we can skip from the internal buffers int toSkip = (int) n; - if (toSkip <= activeBuffer.remaining()) { - // Only skipping from active buffer is sufficient - activeBuffer.position(toSkip + activeBuffer.position()); - if (activeBuffer.remaining() <= readAheadThresholdInBytes - && !readAheadBuffer.hasRemaining()) { - readAsync(); - } - return n; - } // We need to skip from both active buffer and read ahead buffer toSkip -= activeBuffer.remaining(); + assert(toSkip > 0); // skipping from activeBuffer already handled. activeBuffer.position(0); activeBuffer.flip(); readAheadBuffer.position(toSkip + readAheadBuffer.position()); swapBuffers(); + // Trigger async read to emptied read ahead buffer. readAsync(); return n; } else { diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 2c53c8d809d2e..fb179d07edebc 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -72,21 +72,15 @@ public UnsafeSorterSpillReader( bufferSizeBytes = DEFAULT_BUFFER_SIZE_BYTES; } - final double readAheadFraction = - SparkEnv.get() == null ? 0.5 : - SparkEnv.get().conf().getDouble("spark.unsafe.sorter.spill.read.ahead.fraction", 0.5); - - // SPARK-23310: Disable read-ahead input stream, because it is causing lock contention and perf - // regression for TPC-DS queries. final boolean readAheadEnabled = SparkEnv.get() != null && - SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", false); + SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", true); final InputStream bs = new NioBufferedFileInputStream(file, (int) bufferSizeBytes); try { if (readAheadEnabled) { this.in = new ReadAheadInputStream(serializerManager.wrapStream(blockId, bs), - (int) bufferSizeBytes, (int) (bufferSizeBytes * readAheadFraction)); + (int) bufferSizeBytes); } else { this.in = serializerManager.wrapStream(blockId, bs); } diff --git a/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java index 3440e1aea2f46..22db3592ecc96 100644 --- a/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java +++ b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java @@ -37,7 +37,7 @@ public abstract class GenericFileInputStreamSuite { protected File inputFile; - protected InputStream inputStream; + protected InputStream[] inputStreams; @Before public void setUp() throws IOException { @@ -54,77 +54,91 @@ public void tearDown() { @Test public void testReadOneByte() throws IOException { - for (int i = 0; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + for (int i = 0; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testReadMultipleBytes() throws IOException { - byte[] readBytes = new byte[8 * 1024]; - int i = 0; - while (i < randomBytes.length) { - int read = inputStream.read(readBytes, 0, 8 * 1024); - for (int j = 0; j < read; j++) { - assertEquals(randomBytes[i], readBytes[j]); - i++; + for (InputStream inputStream: inputStreams) { + byte[] readBytes = new byte[8 * 1024]; + int i = 0; + while (i < randomBytes.length) { + int read = inputStream.read(readBytes, 0, 8 * 1024); + for (int j = 0; j < read; j++) { + assertEquals(randomBytes[i], readBytes[j]); + i++; + } } } } @Test public void testBytesSkipped() throws IOException { - assertEquals(1024, inputStream.skip(1024)); - for (int i = 1024; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + assertEquals(1024, inputStream.skip(1024)); + for (int i = 1024; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testBytesSkippedAfterRead() throws IOException { - for (int i = 0; i < 1024; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); - } - assertEquals(1024, inputStream.skip(1024)); - for (int i = 2048; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + for (int i = 0; i < 1024; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } + assertEquals(1024, inputStream.skip(1024)); + for (int i = 2048; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testNegativeBytesSkippedAfterRead() throws IOException { - for (int i = 0; i < 1024; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); - } - // Skipping negative bytes should essential be a no-op - assertEquals(0, inputStream.skip(-1)); - assertEquals(0, inputStream.skip(-1024)); - assertEquals(0, inputStream.skip(Long.MIN_VALUE)); - assertEquals(1024, inputStream.skip(1024)); - for (int i = 2048; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + for (int i = 0; i < 1024; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } + // Skipping negative bytes should essential be a no-op + assertEquals(0, inputStream.skip(-1)); + assertEquals(0, inputStream.skip(-1024)); + assertEquals(0, inputStream.skip(Long.MIN_VALUE)); + assertEquals(1024, inputStream.skip(1024)); + for (int i = 2048; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testSkipFromFileChannel() throws IOException { - // Since the buffer is smaller than the skipped bytes, this will guarantee - // we skip from underlying file channel. - assertEquals(1024, inputStream.skip(1024)); - for (int i = 1024; i < 2048; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); - } - assertEquals(256, inputStream.skip(256)); - assertEquals(256, inputStream.skip(256)); - assertEquals(512, inputStream.skip(512)); - for (int i = 3072; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + // Since the buffer is smaller than the skipped bytes, this will guarantee + // we skip from underlying file channel. + assertEquals(1024, inputStream.skip(1024)); + for (int i = 1024; i < 2048; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } + assertEquals(256, inputStream.skip(256)); + assertEquals(256, inputStream.skip(256)); + assertEquals(512, inputStream.skip(512)); + for (int i = 3072; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testBytesSkippedAfterEOF() throws IOException { - assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1)); - assertEquals(-1, inputStream.read()); + for (InputStream inputStream: inputStreams) { + assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1)); + assertEquals(-1, inputStream.read()); + } } } diff --git a/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java index 211b33a1a9fb0..a320f8662f707 100644 --- a/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java +++ b/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java @@ -18,6 +18,7 @@ import org.junit.Before; +import java.io.InputStream; import java.io.IOException; /** @@ -28,6 +29,9 @@ public class NioBufferedInputStreamSuite extends GenericFileInputStreamSuite { @Before public void setUp() throws IOException { super.setUp(); - inputStream = new NioBufferedFileInputStream(inputFile); + inputStreams = new InputStream[] { + new NioBufferedFileInputStream(inputFile), // default + new NioBufferedFileInputStream(inputFile, 123) // small, unaligned buffer + }; } } diff --git a/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java index 918ddc4517ec4..bfa1e0b908824 100644 --- a/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java +++ b/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java @@ -19,16 +19,27 @@ import org.junit.Before; import java.io.IOException; +import java.io.InputStream; /** - * Tests functionality of {@link NioBufferedFileInputStream} + * Tests functionality of {@link ReadAheadInputStreamSuite} */ public class ReadAheadInputStreamSuite extends GenericFileInputStreamSuite { @Before public void setUp() throws IOException { super.setUp(); - inputStream = new ReadAheadInputStream( - new NioBufferedFileInputStream(inputFile), 8 * 1024, 4 * 1024); + inputStreams = new InputStream[] { + // Tests equal and aligned buffers of wrapped an outer stream. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 8 * 1024), 8 * 1024), + // Tests aligned buffers, wrapped bigger than outer. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 3 * 1024), 2 * 1024), + // Tests aligned buffers, wrapped smaller than outer. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 2 * 1024), 3 * 1024), + // Tests unaligned buffers, wrapped bigger than outer. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 321), 123), + // Tests unaligned buffers, wrapped smaller than outer. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 123), 321) + }; } } From ed8647609883fcef16be5d24c2cb4ebda25bd6f0 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 15 Feb 2018 17:13:05 +0800 Subject: [PATCH 0359/2461] [SPARK-23359][SQL] Adds an alias 'names' of 'fieldNames' in Scala's StructType ## What changes were proposed in this pull request? This PR proposes to add an alias 'names' of 'fieldNames' in Scala. Please see the discussion in [SPARK-20090](https://issues.apache.org/jira/browse/SPARK-20090). ## How was this patch tested? Unit tests added in `DataTypeSuite.scala`. Author: hyukjinkwon Closes #20545 from HyukjinKwon/SPARK-23359. --- .../scala/org/apache/spark/sql/types/StructType.scala | 7 +++++++ .../scala/org/apache/spark/sql/types/DataTypeSuite.scala | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index e3b0969283a84..d5011c3cb87e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -104,6 +104,13 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru /** Returns all field names in an array. */ def fieldNames: Array[String] = fields.map(_.name) + /** + * Returns all field names in an array. This is an alias of `fieldNames`. + * + * @since 2.4.0 + */ + def names: Array[String] = fieldNames + private lazy val fieldNamesSet: Set[String] = fieldNames.toSet private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 8e2b32c2b9a08..5a86f4055dce7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -134,6 +134,14 @@ class DataTypeSuite extends SparkFunSuite { assert(mapped === expected) } + test("fieldNames and names returns field names") { + val struct = StructType( + StructField("a", LongType) :: StructField("b", FloatType) :: Nil) + + assert(struct.fieldNames === Seq("a", "b")) + assert(struct.names === Seq("a", "b")) + } + test("merge where right contains type conflict") { val left = StructType( StructField("a", LongType) :: From 44e20c42254bc6591b594f54cd94ced5fcfadae3 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Thu, 15 Feb 2018 03:52:40 -0800 Subject: [PATCH 0360/2461] =?UTF-8?q?[SPARK-23422][CORE]=20YarnShuffleInte?= =?UTF-8?q?grationSuite=20fix=20when=20SPARK=5FPREPEN=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …D_CLASSES set to 1 ## What changes were proposed in this pull request? YarnShuffleIntegrationSuite fails when SPARK_PREPEND_CLASSES set to 1. Normally mllib built before yarn module. When SPARK_PREPEND_CLASSES used mllib classes are on yarn test classpath. Before 2.3 that did not cause issues. But 2.3 has SPARK-22450, which registered some mllib classes with the kryo serializer. Now it dies with the following error: ` 18/02/13 07:33:29 INFO SparkContext: Starting job: collect at YarnShuffleIntegrationSuite.scala:143 Exception in thread "dag-scheduler-event-loop" java.lang.NoClassDefFoundError: breeze/linalg/DenseMatrix ` In this PR NoClassDefFoundError caught only in case of testing and then do nothing. ## How was this patch tested? Automated: Pass the Jenkins. Author: Gabor Somogyi Closes #20608 from gaborgsomogyi/SPARK-23422. --- .../main/scala/org/apache/spark/serializer/KryoSerializer.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 538ae05e4eea1..72427dd6ce4d4 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -206,6 +206,7 @@ class KryoSerializer(conf: SparkConf) kryo.register(clazz) } catch { case NonFatal(_) => // do nothing + case _: NoClassDefFoundError if Utils.isTesting => // See SPARK-23422. } } From f217d7d9b22c4b9c947fc5467379af17f036ee61 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 15 Feb 2018 07:47:40 -0800 Subject: [PATCH 0361/2461] [INFRA] Close stale PRs. Closes #20587 Closes #20586 From 2f0498d1e85a53b60da6a47d20bbdf56b42b7dcb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 15 Feb 2018 08:55:39 -0800 Subject: [PATCH 0362/2461] [SPARK-23426][SQL] Use `hive` ORC impl and disable PPD for Spark 2.3.0 ## What changes were proposed in this pull request? To prevent any regressions, this PR changes ORC implementation to `hive` by default like Spark 2.2.X. Users can enable `native` ORC. Also, ORC PPD is also restored to `false` like Spark 2.2.X. ![orc_section](https://user-images.githubusercontent.com/9700541/36221575-57a1d702-1173-11e8-89fe-dca5842f4ca7.png) ## How was this patch tested? Pass all test cases. Author: Dongjoon Hyun Closes #20610 from dongjoon-hyun/SPARK-ORC-DISABLE. --- docs/sql-programming-guide.md | 52 ++++++++----------- .../apache/spark/sql/internal/SQLConf.scala | 6 +-- .../spark/sql/FileBasedDataSourceSuite.scala | 17 +++++- .../sql/streaming/FileStreamSinkSuite.scala | 13 +++++ .../sql/streaming/FileStreamSourceSuite.scala | 13 +++++ 5 files changed, 68 insertions(+), 33 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index cf9529a79f4f9..91e43678481d6 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1004,6 +1004,29 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession +## ORC Files + +Since Spark 2.3, Spark supports a vectorized ORC reader with a new ORC file format for ORC files. +To do that, the following configurations are newly added. The vectorized reader is used for the +native ORC tables (e.g., the ones created using the clause `USING ORC`) when `spark.sql.orc.impl` +is set to `native` and `spark.sql.orc.enableVectorizedReader` is set to `true`. For the Hive ORC +serde tables (e.g., the ones created using the clause `USING HIVE OPTIONS (fileFormat 'ORC')`), +the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is also set to `true`. + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.sql.orc.implhiveThe name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4.1. `hive` means the ORC library in Hive 1.2.1.
    spark.sql.orc.enableVectorizedReadertrueEnables vectorized orc decoding in native implementation. If false, a new non-vectorized ORC reader is used in native implementation. For hive implementation, this is ignored.
    + ## JSON Datasets
    @@ -1776,35 +1799,6 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.2 to 2.3 - - Since Spark 2.3, Spark supports a vectorized ORC reader with a new ORC file format for ORC files. To do that, the following configurations are newly added or change their default values. The vectorized reader is used for the native ORC tables (e.g., the ones created using the clause `USING ORC`) when `spark.sql.orc.impl` is set to `native` and `spark.sql.orc.enableVectorizedReader` is set to `true`. For the Hive ORC serde table (e.g., the ones created using the clause `USING HIVE OPTIONS (fileFormat 'ORC')`), the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is set to `true`. - - - New configurations - - - - - - - - - - - - - -
    Property NameDefaultMeaning
    spark.sql.orc.implnativeThe name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4.1. `hive` means the ORC library in Hive 1.2.1 which is used prior to Spark 2.3.
    spark.sql.orc.enableVectorizedReadertrueEnables vectorized orc decoding in native implementation. If false, a new non-vectorized ORC reader is used in native implementation. For hive implementation, this is ignored.
    - - - Changed configurations - - - - - - - - -
    Property NameDefaultMeaning
    spark.sql.orc.filterPushdowntrueEnables filter pushdown for ORC files. It is false by default prior to Spark 2.3.
    - - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. - The `percentile_approx` function previously accepted numeric type input and output double type results. Now it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles. - Since Spark 2.3, the Join/Filter's deterministic predicates that are after the first non-deterministic predicates are also pushed down/through the child operators, if possible. In prior Spark versions, these filters are not eligible for predicate pushdown. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7835dbaa58439..f24fd7ff74d3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -399,11 +399,11 @@ object SQLConf { val ORC_IMPLEMENTATION = buildConf("spark.sql.orc.impl") .doc("When native, use the native version of ORC support instead of the ORC library in Hive " + - "1.2.1. It is 'hive' by default prior to Spark 2.3.") + "1.2.1. It is 'hive' by default.") .internal() .stringConf .checkValues(Set("hive", "native")) - .createWithDefault("native") + .createWithDefault("hive") val ORC_VECTORIZED_READER_ENABLED = buildConf("spark.sql.orc.enableVectorizedReader") .doc("Enables vectorized orc decoding.") @@ -426,7 +426,7 @@ object SQLConf { val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf - .createWithDefault(true) + .createWithDefault(false) val HIVE_VERIFY_PARTITION_PATH = buildConf("spark.sql.hive.verifyPartitionPath") .doc("When true, check all the partition paths under the table\'s root directory " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 2e332362ea644..b5d4c558f0d3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -20,14 +20,29 @@ package org.apache.spark.sql import java.io.FileNotFoundException import org.apache.hadoop.fs.Path +import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { + +class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with BeforeAndAfterAll { import testImplicits._ + override def beforeAll(): Unit = { + super.beforeAll() + spark.sessionState.conf.setConf(SQLConf.ORC_IMPLEMENTATION, "native") + } + + override def afterAll(): Unit = { + try { + spark.sessionState.conf.unsetConf(SQLConf.ORC_IMPLEMENTATION) + } finally { + super.afterAll() + } + } + private val allFileBasedDataSources = Seq("orc", "parquet", "csv", "json", "text") private val nameWithSpecialChars = "sp&cial%c hars" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 8c4e1fd00b0a2..ba48bc1ce0c4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -33,6 +33,19 @@ import org.apache.spark.util.Utils class FileStreamSinkSuite extends StreamTest { import testImplicits._ + override def beforeAll(): Unit = { + super.beforeAll() + spark.sessionState.conf.setConf(SQLConf.ORC_IMPLEMENTATION, "native") + } + + override def afterAll(): Unit = { + try { + spark.sessionState.conf.unsetConf(SQLConf.ORC_IMPLEMENTATION) + } finally { + super.afterAll() + } + } + test("unpartitioned writing and batch reading") { val inputData = MemoryStream[Int] val df = inputData.toDF() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 5bb0f4d643bbe..d4bd9c7987f2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -207,6 +207,19 @@ class FileStreamSourceSuite extends FileStreamSourceTest { .collect { case s @ StreamingRelation(dataSource, _, _) => s.schema }.head } + override def beforeAll(): Unit = { + super.beforeAll() + spark.sessionState.conf.setConf(SQLConf.ORC_IMPLEMENTATION, "native") + } + + override def afterAll(): Unit = { + try { + spark.sessionState.conf.unsetConf(SQLConf.ORC_IMPLEMENTATION) + } finally { + super.afterAll() + } + } + // ============= Basic parameter exists tests ================ test("FileStreamSource schema: no path") { From 6968c3cfd70961c4e86daffd6a156d0a9c1d7a2a Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 15 Feb 2018 09:40:08 -0800 Subject: [PATCH 0363/2461] [MINOR][SQL] Fix an error message about inserting into bucketed tables ## What changes were proposed in this pull request? This replaces `Sparkcurrently` to `Spark currently` in the following error message. ```scala scala> sql("insert into t2 select * from v1") org.apache.spark.sql.AnalysisException: Output Hive table `default`.`t2` is bucketed but Sparkcurrently does NOT populate bucketed ... ``` ## How was this patch tested? Manual. Author: Dongjoon Hyun Closes #20617 from dongjoon-hyun/SPARK-ERROR-MSG. --- .../apache/spark/sql/hive/execution/InsertIntoHiveTable.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 3ce5b8469d6fc..02a60f16b3b3a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -172,7 +172,7 @@ case class InsertIntoHiveTable( val enforceBucketingConfig = "hive.enforce.bucketing" val enforceSortingConfig = "hive.enforce.sorting" - val message = s"Output Hive table ${table.identifier} is bucketed but Spark" + + val message = s"Output Hive table ${table.identifier} is bucketed but Spark " + "currently does NOT populate bucketed output which is compatible with Hive." if (hadoopConf.get(enforceBucketingConfig, "true").toBoolean || From db45daab90ede4c03c1abc9096f4eac584e9db17 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 15 Feb 2018 09:54:39 -0800 Subject: [PATCH 0364/2461] [SPARK-23377][ML] Fixes Bucketizer with multiple columns persistence bug ## What changes were proposed in this pull request? #### Problem: Since 2.3, `Bucketizer` supports multiple input/output columns. We will check if exclusive params are set during transformation. E.g., if `inputCols` and `outputCol` are both set, an error will be thrown. However, when we write `Bucketizer`, looks like the default params and user-supplied params are merged during writing. All saved params are loaded back and set to created model instance. So the default `outputCol` param in `HasOutputCol` trait will be set in `paramMap` and become an user-supplied param. That makes the check of exclusive params failed. #### Fix: This changes the saving logic of Bucketizer to handle this case. This is a quick fix to catch the time of 2.3. We should consider modify the persistence mechanism later. Please see the discussion in the JIRA. Note: The multi-column `QuantileDiscretizer` also has the same issue. ## How was this patch tested? Modified tests. Author: Liang-Chi Hsieh Closes #20594 from viirya/SPARK-23377-2. --- .../apache/spark/ml/feature/Bucketizer.scala | 28 +++++++++++++++++++ .../ml/feature/QuantileDiscretizer.scala | 28 +++++++++++++++++++ .../spark/ml/feature/BucketizerSuite.scala | 12 ++++++-- .../ml/feature/QuantileDiscretizerSuite.scala | 14 ++++++++-- 4 files changed, 78 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index c13bf47eacb94..f49c410cbcfe2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -19,6 +19,10 @@ package org.apache.spark.ml.feature import java.{util => ju} +import org.json4s.JsonDSL._ +import org.json4s.JValue +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.ml.Model @@ -213,6 +217,8 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String override def copy(extra: ParamMap): Bucketizer = { defaultCopy[Bucketizer](extra).setParent(parent) } + + override def write: MLWriter = new Bucketizer.BucketizerWriter(this) } @Since("1.6.0") @@ -290,6 +296,28 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { } } + + private[Bucketizer] class BucketizerWriter(instance: Bucketizer) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + // SPARK-23377: The default params will be saved and loaded as user-supplied params. + // Once `inputCols` is set, the default value of `outputCol` param causes the error + // when checking exclusive params. As a temporary to fix it, we skip the default value + // of `outputCol` if `inputCols` is set when saving the metadata. + // TODO: If we modify the persistence mechanism later to better handle default params, + // we can get rid of this. + var paramWithoutOutputCol: Option[JValue] = None + if (instance.isSet(instance.inputCols)) { + val params = instance.extractParamMap().toSeq + val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList + paramWithoutOutputCol = Some(render(jsonParams)) + } + DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol) + } + } + @Since("1.6.0") override def load(path: String): Bucketizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 1ec5f8cb6139b..3b4c25478fb1d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -17,6 +17,10 @@ package org.apache.spark.ml.feature +import org.json4s.JsonDSL._ +import org.json4s.JValue +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml._ @@ -249,11 +253,35 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("1.6.0") override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra) + + override def write: MLWriter = new QuantileDiscretizer.QuantileDiscretizerWriter(this) } @Since("1.6.0") object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { + private[QuantileDiscretizer] + class QuantileDiscretizerWriter(instance: QuantileDiscretizer) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + // SPARK-23377: The default params will be saved and loaded as user-supplied params. + // Once `inputCols` is set, the default value of `outputCol` param causes the error + // when checking exclusive params. As a temporary to fix it, we skip the default value + // of `outputCol` if `inputCols` is set when saving the metadata. + // TODO: If we modify the persistence mechanism later to better handle default params, + // we can get rid of this. + var paramWithoutOutputCol: Option[JValue] = None + if (instance.isSet(instance.inputCols)) { + val params = instance.extractParamMap().toSeq + val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList + paramWithoutOutputCol = Some(render(jsonParams)) + } + DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol) + } + } + @Since("1.6.0") override def load(path: String): QuantileDiscretizer = super.load(path) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 7403680ae3fdc..41cf72fe3470a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -172,7 +172,10 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setInputCol("myInputCol") .setOutputCol("myOutputCol") .setSplits(Array(0.1, 0.8, 0.9)) - testDefaultReadWrite(t) + + val bucketizer = testDefaultReadWrite(t) + val data = Seq((1.0, 2.0), (10.0, 100.0), (101.0, -1.0)).toDF("myInputCol", "myInputCol2") + bucketizer.transform(data) } test("Bucket numeric features") { @@ -327,7 +330,12 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setInputCols(Array("myInputCol")) .setOutputCols(Array("myOutputCol")) .setSplitsArray(Array(Array(0.1, 0.8, 0.9))) - testDefaultReadWrite(t) + + val bucketizer = testDefaultReadWrite(t) + val data = Seq((1.0, 2.0), (10.0, 100.0), (101.0, -1.0)).toDF("myInputCol", "myInputCol2") + bucketizer.transform(data) + assert(t.hasDefault(t.outputCol)) + assert(bucketizer.hasDefault(bucketizer.outputCol)) } test("Bucketizer in a pipeline") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index e9a75e931e6a8..6c363799dd300 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.functions.udf class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("Test observed number of buckets and their sizes match expected values") { val spark = this.spark import spark.implicits._ @@ -132,7 +134,10 @@ class QuantileDiscretizerSuite .setInputCol("myInputCol") .setOutputCol("myOutputCol") .setNumBuckets(6) - testDefaultReadWrite(t) + + val readDiscretizer = testDefaultReadWrite(t) + val data = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("myInputCol") + readDiscretizer.fit(data) } test("Verify resulting model has parent") { @@ -379,7 +384,12 @@ class QuantileDiscretizerSuite .setInputCols(Array("input1", "input2")) .setOutputCols(Array("result1", "result2")) .setNumBucketsArray(Array(5, 10)) - testDefaultReadWrite(discretizer) + + val readDiscretizer = testDefaultReadWrite(discretizer) + val data = Seq((1.0, 2.0), (2.0, 3.0), (3.0, 4.0)).toDF("input1", "input2") + readDiscretizer.fit(data) + assert(discretizer.hasDefault(discretizer.outputCol)) + assert(readDiscretizer.hasDefault(readDiscretizer.outputCol)) } test("Multiple Columns: Both inputCol and inputCols are set") { From 1dc2c1d5e85c5f404f470aeb44c1f3c22786bdea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Thu, 15 Feb 2018 13:51:24 -0600 Subject: [PATCH 0365/2461] [SPARK-23413][UI] Fix sorting tasks by Host / Executor ID at the Stage page MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Fixing exception got at sorting tasks by Host / Executor ID: ``` java.lang.IllegalArgumentException: Invalid sort column: Host at org.apache.spark.ui.jobs.ApiHelper$.indexName(StagePage.scala:1017) at org.apache.spark.ui.jobs.TaskDataSource.sliceData(StagePage.scala:694) at org.apache.spark.ui.PagedDataSource.pageData(PagedTable.scala:61) at org.apache.spark.ui.PagedTable$class.table(PagedTable.scala:96) at org.apache.spark.ui.jobs.TaskPagedTable.table(StagePage.scala:708) at org.apache.spark.ui.jobs.StagePage.liftedTree1$1(StagePage.scala:293) at org.apache.spark.ui.jobs.StagePage.render(StagePage.scala:282) at org.apache.spark.ui.WebUI$$anonfun$2.apply(WebUI.scala:82) at org.apache.spark.ui.WebUI$$anonfun$2.apply(WebUI.scala:82) at org.apache.spark.ui.JettyUtils$$anon$3.doGet(JettyUtils.scala:90) at javax.servlet.http.HttpServlet.service(HttpServlet.java:687) at javax.servlet.http.HttpServlet.service(HttpServlet.java:790) at org.spark_project.jetty.servlet.ServletHolder.handle(ServletHolder.java:848) at org.spark_project.jetty.servlet.ServletHandler.doHandle(ServletHandler.java:584) ``` Moreover some refactoring to avoid similar problems by introducing constants for each header name and reusing them at the identification of the corresponding sorting index. ## How was this patch tested? Manually: ![screen shot 2018-02-13 at 18 57 10](https://user-images.githubusercontent.com/2017933/36166532-1cfdf3b8-10f3-11e8-8d32-5fcaad2af214.png) Author: “attilapiros” Closes #20601 from attilapiros/SPARK-23413. --- .../org/apache/spark/status/storeTypes.scala | 2 + .../org/apache/spark/ui/jobs/StagePage.scala | 121 +++++++++++------- .../org/apache/spark/ui/StagePageSuite.scala | 63 ++++++++- 3 files changed, 139 insertions(+), 47 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index 412644d3657b5..646cf25880e37 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -109,6 +109,7 @@ private[spark] object TaskIndexNames { final val DURATION = "dur" final val ERROR = "err" final val EXECUTOR = "exe" + final val HOST = "hst" final val EXEC_CPU_TIME = "ect" final val EXEC_RUN_TIME = "ert" final val GC_TIME = "gc" @@ -165,6 +166,7 @@ private[spark] class TaskDataWrapper( val duration: Long, @KVIndexParam(value = TaskIndexNames.EXECUTOR, parent = TaskIndexNames.STAGE) val executorId: String, + @KVIndexParam(value = TaskIndexNames.HOST, parent = TaskIndexNames.STAGE) val host: String, @KVIndexParam(value = TaskIndexNames.STATUS, parent = TaskIndexNames.STAGE) val status: String, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 5c2b0c3a19996..a9265d4dbcdfb 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -750,37 +750,39 @@ private[ui] class TaskPagedTable( } def headers: Seq[Node] = { + import ApiHelper._ + val taskHeadersAndCssClasses: Seq[(String, String)] = Seq( - ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), - ("Executor ID", ""), ("Host", ""), ("Launch Time", ""), ("Duration", ""), - ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), - ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), - ("GC Time", ""), - ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), - ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME), - ("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) ++ - {if (hasAccumulators(stage)) Seq(("Accumulators", "")) else Nil} ++ - {if (hasInput(stage)) Seq(("Input Size / Records", "")) else Nil} ++ - {if (hasOutput(stage)) Seq(("Output Size / Records", "")) else Nil} ++ + (HEADER_TASK_INDEX, ""), (HEADER_ID, ""), (HEADER_ATTEMPT, ""), (HEADER_STATUS, ""), + (HEADER_LOCALITY, ""), (HEADER_EXECUTOR, ""), (HEADER_HOST, ""), (HEADER_LAUNCH_TIME, ""), + (HEADER_DURATION, ""), (HEADER_SCHEDULER_DELAY, TaskDetailsClassNames.SCHEDULER_DELAY), + (HEADER_DESER_TIME, TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), + (HEADER_GC_TIME, ""), + (HEADER_SER_TIME, TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), + (HEADER_GETTING_RESULT_TIME, TaskDetailsClassNames.GETTING_RESULT_TIME), + (HEADER_PEAK_MEM, TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) ++ + {if (hasAccumulators(stage)) Seq((HEADER_ACCUMULATORS, "")) else Nil} ++ + {if (hasInput(stage)) Seq((HEADER_INPUT_SIZE, "")) else Nil} ++ + {if (hasOutput(stage)) Seq((HEADER_OUTPUT_SIZE, "")) else Nil} ++ {if (hasShuffleRead(stage)) { - Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), - ("Shuffle Read Size / Records", ""), - ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) + Seq((HEADER_SHUFFLE_READ_TIME, TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), + (HEADER_SHUFFLE_TOTAL_READS, ""), + (HEADER_SHUFFLE_REMOTE_READS, TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) } else { Nil }} ++ {if (hasShuffleWrite(stage)) { - Seq(("Write Time", ""), ("Shuffle Write Size / Records", "")) + Seq((HEADER_SHUFFLE_WRITE_TIME, ""), (HEADER_SHUFFLE_WRITE_SIZE, "")) } else { Nil }} ++ {if (hasBytesSpilled(stage)) { - Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) + Seq((HEADER_MEM_SPILL, ""), (HEADER_DISK_SPILL, "")) } else { Nil }} ++ - Seq(("Errors", "")) + Seq((HEADER_ERROR, "")) if (!taskHeadersAndCssClasses.map(_._1).contains(sortColumn)) { throw new IllegalArgumentException(s"Unknown column: $sortColumn") @@ -961,35 +963,62 @@ private[ui] class TaskPagedTable( } } -private object ApiHelper { - - - private val COLUMN_TO_INDEX = Map( - "ID" -> null.asInstanceOf[String], - "Index" -> TaskIndexNames.TASK_INDEX, - "Attempt" -> TaskIndexNames.ATTEMPT, - "Status" -> TaskIndexNames.STATUS, - "Locality Level" -> TaskIndexNames.LOCALITY, - "Executor ID / Host" -> TaskIndexNames.EXECUTOR, - "Launch Time" -> TaskIndexNames.LAUNCH_TIME, - "Duration" -> TaskIndexNames.DURATION, - "Scheduler Delay" -> TaskIndexNames.SCHEDULER_DELAY, - "Task Deserialization Time" -> TaskIndexNames.DESER_TIME, - "GC Time" -> TaskIndexNames.GC_TIME, - "Result Serialization Time" -> TaskIndexNames.SER_TIME, - "Getting Result Time" -> TaskIndexNames.GETTING_RESULT_TIME, - "Peak Execution Memory" -> TaskIndexNames.PEAK_MEM, - "Accumulators" -> TaskIndexNames.ACCUMULATORS, - "Input Size / Records" -> TaskIndexNames.INPUT_SIZE, - "Output Size / Records" -> TaskIndexNames.OUTPUT_SIZE, - "Shuffle Read Blocked Time" -> TaskIndexNames.SHUFFLE_READ_TIME, - "Shuffle Read Size / Records" -> TaskIndexNames.SHUFFLE_TOTAL_READS, - "Shuffle Remote Reads" -> TaskIndexNames.SHUFFLE_REMOTE_READS, - "Write Time" -> TaskIndexNames.SHUFFLE_WRITE_TIME, - "Shuffle Write Size / Records" -> TaskIndexNames.SHUFFLE_WRITE_SIZE, - "Shuffle Spill (Memory)" -> TaskIndexNames.MEM_SPILL, - "Shuffle Spill (Disk)" -> TaskIndexNames.DISK_SPILL, - "Errors" -> TaskIndexNames.ERROR) +private[ui] object ApiHelper { + + val HEADER_ID = "ID" + val HEADER_TASK_INDEX = "Index" + val HEADER_ATTEMPT = "Attempt" + val HEADER_STATUS = "Status" + val HEADER_LOCALITY = "Locality Level" + val HEADER_EXECUTOR = "Executor ID" + val HEADER_HOST = "Host" + val HEADER_LAUNCH_TIME = "Launch Time" + val HEADER_DURATION = "Duration" + val HEADER_SCHEDULER_DELAY = "Scheduler Delay" + val HEADER_DESER_TIME = "Task Deserialization Time" + val HEADER_GC_TIME = "GC Time" + val HEADER_SER_TIME = "Result Serialization Time" + val HEADER_GETTING_RESULT_TIME = "Getting Result Time" + val HEADER_PEAK_MEM = "Peak Execution Memory" + val HEADER_ACCUMULATORS = "Accumulators" + val HEADER_INPUT_SIZE = "Input Size / Records" + val HEADER_OUTPUT_SIZE = "Output Size / Records" + val HEADER_SHUFFLE_READ_TIME = "Shuffle Read Blocked Time" + val HEADER_SHUFFLE_TOTAL_READS = "Shuffle Read Size / Records" + val HEADER_SHUFFLE_REMOTE_READS = "Shuffle Remote Reads" + val HEADER_SHUFFLE_WRITE_TIME = "Write Time" + val HEADER_SHUFFLE_WRITE_SIZE = "Shuffle Write Size / Records" + val HEADER_MEM_SPILL = "Shuffle Spill (Memory)" + val HEADER_DISK_SPILL = "Shuffle Spill (Disk)" + val HEADER_ERROR = "Errors" + + private[ui] val COLUMN_TO_INDEX = Map( + HEADER_ID -> null.asInstanceOf[String], + HEADER_TASK_INDEX -> TaskIndexNames.TASK_INDEX, + HEADER_ATTEMPT -> TaskIndexNames.ATTEMPT, + HEADER_STATUS -> TaskIndexNames.STATUS, + HEADER_LOCALITY -> TaskIndexNames.LOCALITY, + HEADER_EXECUTOR -> TaskIndexNames.EXECUTOR, + HEADER_HOST -> TaskIndexNames.HOST, + HEADER_LAUNCH_TIME -> TaskIndexNames.LAUNCH_TIME, + HEADER_DURATION -> TaskIndexNames.DURATION, + HEADER_SCHEDULER_DELAY -> TaskIndexNames.SCHEDULER_DELAY, + HEADER_DESER_TIME -> TaskIndexNames.DESER_TIME, + HEADER_GC_TIME -> TaskIndexNames.GC_TIME, + HEADER_SER_TIME -> TaskIndexNames.SER_TIME, + HEADER_GETTING_RESULT_TIME -> TaskIndexNames.GETTING_RESULT_TIME, + HEADER_PEAK_MEM -> TaskIndexNames.PEAK_MEM, + HEADER_ACCUMULATORS -> TaskIndexNames.ACCUMULATORS, + HEADER_INPUT_SIZE -> TaskIndexNames.INPUT_SIZE, + HEADER_OUTPUT_SIZE -> TaskIndexNames.OUTPUT_SIZE, + HEADER_SHUFFLE_READ_TIME -> TaskIndexNames.SHUFFLE_READ_TIME, + HEADER_SHUFFLE_TOTAL_READS -> TaskIndexNames.SHUFFLE_TOTAL_READS, + HEADER_SHUFFLE_REMOTE_READS -> TaskIndexNames.SHUFFLE_REMOTE_READS, + HEADER_SHUFFLE_WRITE_TIME -> TaskIndexNames.SHUFFLE_WRITE_TIME, + HEADER_SHUFFLE_WRITE_SIZE -> TaskIndexNames.SHUFFLE_WRITE_SIZE, + HEADER_MEM_SPILL -> TaskIndexNames.MEM_SPILL, + HEADER_DISK_SPILL -> TaskIndexNames.DISK_SPILL, + HEADER_ERROR -> TaskIndexNames.ERROR) def hasAccumulators(stageData: StageData): Boolean = { stageData.accumulatorUpdates.exists { acc => acc.name != null && acc.value != null } diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 0aeddf730cd35..6044563f7dde7 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -28,13 +28,74 @@ import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1.{AccumulableInfo => UIAccumulableInfo, StageData, StageStatus} import org.apache.spark.status.config._ -import org.apache.spark.ui.jobs.{StagePage, StagesTab} +import org.apache.spark.ui.jobs.{ApiHelper, StagePage, StagesTab, TaskPagedTable} class StagePageSuite extends SparkFunSuite with LocalSparkContext { private val peakExecutionMemory = 10 + test("ApiHelper.COLUMN_TO_INDEX should match headers of the task table") { + val conf = new SparkConf(false).set(LIVE_ENTITY_UPDATE_PERIOD, 0L) + val statusStore = AppStatusStore.createLiveStore(conf) + try { + val stageData = new StageData( + status = StageStatus.ACTIVE, + stageId = 1, + attemptId = 1, + numTasks = 1, + numActiveTasks = 1, + numCompleteTasks = 1, + numFailedTasks = 1, + numKilledTasks = 1, + numCompletedIndices = 1, + + executorRunTime = 1L, + executorCpuTime = 1L, + submissionTime = None, + firstTaskLaunchedTime = None, + completionTime = None, + failureReason = None, + + inputBytes = 1L, + inputRecords = 1L, + outputBytes = 1L, + outputRecords = 1L, + shuffleReadBytes = 1L, + shuffleReadRecords = 1L, + shuffleWriteBytes = 1L, + shuffleWriteRecords = 1L, + memoryBytesSpilled = 1L, + diskBytesSpilled = 1L, + + name = "stage1", + description = Some("description"), + details = "detail", + schedulingPool = "pool1", + + rddIds = Seq(1), + accumulatorUpdates = Seq(new UIAccumulableInfo(0L, "acc", None, "value")), + tasks = None, + executorSummary = None, + killedTasksSummary = Map.empty + ) + val taskTable = new TaskPagedTable( + stageData, + basePath = "/a/b/c", + currentTime = 0, + pageSize = 10, + sortColumn = "Index", + desc = false, + store = statusStore + ) + val columnNames = (taskTable.headers \ "th" \ "a").map(_.child(1).text).toSet + assert(columnNames === ApiHelper.COLUMN_TO_INDEX.keySet) + } finally { + statusStore.close() + } + } + test("peak execution memory should displayed") { val html = renderStagePage().toString().toLowerCase(Locale.ROOT) val targetString = "peak execution memory" From c5857e496ff0d170ed0339f14afc7d36b192da6d Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 16 Feb 2018 09:41:17 -0800 Subject: [PATCH 0366/2461] [SPARK-23446][PYTHON] Explicitly check supported types in toPandas ## What changes were proposed in this pull request? This PR explicitly specifies and checks the types we supported in `toPandas`. This was a hole. For example, we haven't finished the binary type support in Python side yet but now it allows as below: ```python spark.conf.set("spark.sql.execution.arrow.enabled", "false") df = spark.createDataFrame([[bytearray("a")]]) df.toPandas() spark.conf.set("spark.sql.execution.arrow.enabled", "true") df.toPandas() ``` ``` _1 0 [97] _1 0 a ``` This should be disallowed. I think the same things also apply to nested timestamps too. I also added some nicer message about `spark.sql.execution.arrow.enabled` in the error message. ## How was this patch tested? Manually tested and tests added in `python/pyspark/sql/tests.py`. Author: hyukjinkwon Closes #20625 from HyukjinKwon/pandas_convertion_supported_type. --- python/pyspark/sql/dataframe.py | 15 +++++++++------ python/pyspark/sql/tests.py | 9 ++++++++- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5cc8b63cdfadf..f37777e13ee12 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1988,10 +1988,11 @@ def toPandas(self): if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": try: from pyspark.sql.types import _check_dataframe_convert_date, \ - _check_dataframe_localize_timestamps + _check_dataframe_localize_timestamps, to_arrow_schema from pyspark.sql.utils import require_minimum_pyarrow_version - import pyarrow require_minimum_pyarrow_version() + import pyarrow + to_arrow_schema(self.schema) tables = self._collectAsArrow() if tables: table = pyarrow.concat_tables(tables) @@ -2000,10 +2001,12 @@ def toPandas(self): return _check_dataframe_localize_timestamps(pdf, timezone) else: return pd.DataFrame.from_records([], columns=self.columns) - except ImportError as e: - msg = "note: pyarrow must be installed and available on calling Python process " \ - "if using spark.sql.execution.arrow.enabled=true" - raise ImportError("%s\n%s" % (_exception_message(e), msg)) + except Exception as e: + msg = ( + "Note: toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true. Please set it to false " + "to disable this.") + raise RuntimeError("%s\n%s" % (_exception_message(e), msg)) else: pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2af218a691026..19653072ea316 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3497,7 +3497,14 @@ def test_unsupported_datatype(self): schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported data type'): + with self.assertRaisesRegexp(Exception, 'Unsupported type'): + df.toPandas() + + df = self.spark.createDataFrame([(None,)], schema="a binary") + with QuietTest(self.sc): + with self.assertRaisesRegexp( + Exception, + 'Unsupported type.*\nNote: toPandas attempted Arrow optimization because'): df.toPandas() def test_null_conversion(self): From 0a73aa31f41c83503d5d99eff3c9d7b406014ab3 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 16 Feb 2018 14:30:19 -0800 Subject: [PATCH 0367/2461] [SPARK-23362][SS] Migrate Kafka Microbatch source to v2 ## What changes were proposed in this pull request? Migrating KafkaSource (with data source v1) to KafkaMicroBatchReader (with data source v2). Performance comparison: In a unit test with in-process Kafka broker, I tested the read throughput of V1 and V2 using 20M records in a single partition. They were comparable. ## How was this patch tested? Existing tests, few modified to be better tests than the existing ones. Author: Tathagata Das Closes #20554 from tdas/SPARK-23362. --- dev/.rat-excludes | 1 + .../sql/kafka010/CachedKafkaConsumer.scala | 2 +- .../sql/kafka010/KafkaContinuousReader.scala | 29 +- .../sql/kafka010/KafkaMicroBatchReader.scala | 403 ++++++++++++++++++ .../KafkaRecordToUnsafeRowConverter.scala | 52 +++ .../spark/sql/kafka010/KafkaSource.scala | 19 +- .../sql/kafka010/KafkaSourceProvider.scala | 70 ++- ...a-source-initial-offset-future-version.bin | 2 + ...ka-source-initial-offset-version-2.1.0.bin | 2 +- ...scala => KafkaMicroBatchSourceSuite.scala} | 254 +++++++---- .../apache/spark/sql/internal/SQLConf.scala | 15 +- .../streaming/MicroBatchExecution.scala | 20 +- 12 files changed, 741 insertions(+), 128 deletions(-) create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala create mode 100644 external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-future-version.bin rename external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/{KafkaSourceSuite.scala => KafkaMicroBatchSourceSuite.scala} (85%) diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 243fbe3e1bc24..9552d001a079c 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -105,3 +105,4 @@ META-INF/* spark-warehouse structured-streaming/* kafka-source-initial-offset-version-2.1.0.bin +kafka-source-initial-offset-future-version.bin diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala index 90ed7b1fba2f8..e97881cb0a163 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala @@ -27,7 +27,7 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging -import org.apache.spark.sql.kafka010.KafkaSource._ +import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.util.UninterruptibleThread diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index b049a054cb40e..97a0f66e1880d 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType @@ -187,13 +187,9 @@ class KafkaContinuousDataReader( kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { - private val topic = topicPartition.topic - private val kafkaPartition = topicPartition.partition - private val consumer = CachedKafkaConsumer.createUncached(topic, kafkaPartition, kafkaParams) - - private val sharedRow = new UnsafeRow(7) - private val bufferHolder = new BufferHolder(sharedRow) - private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) + private val consumer = + CachedKafkaConsumer.createUncached(topicPartition.topic, topicPartition.partition, kafkaParams) + private val converter = new KafkaRecordToUnsafeRowConverter private var nextKafkaOffset = startOffset private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _ @@ -232,22 +228,7 @@ class KafkaContinuousDataReader( } override def get(): UnsafeRow = { - bufferHolder.reset() - - if (currentRecord.key == null) { - rowWriter.setNullAt(0) - } else { - rowWriter.write(0, currentRecord.key) - } - rowWriter.write(1, currentRecord.value) - rowWriter.write(2, UTF8String.fromString(currentRecord.topic)) - rowWriter.write(3, currentRecord.partition) - rowWriter.write(4, currentRecord.offset) - rowWriter.write(5, - DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(currentRecord.timestamp))) - rowWriter.write(6, currentRecord.timestampType.id) - sharedRow.setTotalSize(bufferHolder.totalSize) - sharedRow + converter.toUnsafeRow(currentRecord) } override def getOffset(): KafkaSourcePartitionOffset = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala new file mode 100644 index 0000000000000..fb647ca7e70dd --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -0,0 +1,403 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} +import java.io._ +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.IOUtils +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset} +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.UninterruptibleThread + +/** + * A [[MicroBatchReader]] that reads data from Kafka. + * + * The [[KafkaSourceOffset]] is the custom [[Offset]] defined for this source that contains + * a map of TopicPartition -> offset. Note that this offset is 1 + (available offset). For + * example if the last record in a Kafka topic "t", partition 2 is offset 5, then + * KafkaSourceOffset will contain TopicPartition("t", 2) -> 6. This is done keep it consistent + * with the semantics of `KafkaConsumer.position()`. + * + * Zero data lost is not guaranteed when topics are deleted. If zero data lost is critical, the user + * must make sure all messages in a topic have been processed when deleting a topic. + * + * There is a known issue caused by KAFKA-1894: the query using Kafka maybe cannot be stopped. + * To avoid this issue, you should make sure stopping the query before stopping the Kafka brokers + * and not use wrong broker addresses. + */ +private[kafka010] class KafkaMicroBatchReader( + kafkaOffsetReader: KafkaOffsetReader, + executorKafkaParams: ju.Map[String, Object], + options: DataSourceOptions, + metadataPath: String, + startingOffsets: KafkaOffsetRangeLimit, + failOnDataLoss: Boolean) + extends MicroBatchReader with SupportsScanUnsafeRow with Logging { + + type PartitionOffsetMap = Map[TopicPartition, Long] + + private var startPartitionOffsets: PartitionOffsetMap = _ + private var endPartitionOffsets: PartitionOffsetMap = _ + + private val pollTimeoutMs = options.getLong( + "kafkaConsumer.pollTimeoutMs", + SparkEnv.get.conf.getTimeAsMs("spark.network.timeout", "120s")) + + private val maxOffsetsPerTrigger = + Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong) + + /** + * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only + * called in StreamExecutionThread. Otherwise, interrupting a thread while running + * `KafkaConsumer.poll` may hang forever (KAFKA-1894). + */ + private lazy val initialPartitionOffsets = getOrCreateInitialPartitionOffsets() + + override def setOffsetRange(start: ju.Optional[Offset], end: ju.Optional[Offset]): Unit = { + // Make sure initialPartitionOffsets is initialized + initialPartitionOffsets + + startPartitionOffsets = Option(start.orElse(null)) + .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets) + .getOrElse(initialPartitionOffsets) + + endPartitionOffsets = Option(end.orElse(null)) + .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets) + .getOrElse { + val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets() + maxOffsetsPerTrigger.map { maxOffsets => + rateLimit(maxOffsets, startPartitionOffsets, latestPartitionOffsets) + }.getOrElse { + latestPartitionOffsets + } + } + } + + override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + // Find the new partitions, and get their earliest offsets + val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) + val newPartitionOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) + if (newPartitionOffsets.keySet != newPartitions) { + // We cannot get from offsets for some partitions. It means they got deleted. + val deletedPartitions = newPartitions.diff(newPartitionOffsets.keySet) + reportDataLoss( + s"Cannot find earliest offsets of ${deletedPartitions}. Some data may have been missed") + } + logInfo(s"Partitions added: $newPartitionOffsets") + newPartitionOffsets.filter(_._2 != 0).foreach { case (p, o) => + reportDataLoss( + s"Added partition $p starts from $o instead of 0. Some data may have been missed") + } + + // Find deleted partitions, and report data loss if required + val deletedPartitions = startPartitionOffsets.keySet.diff(endPartitionOffsets.keySet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"$deletedPartitions are gone. Some data may have been missed") + } + + // Use the until partitions to calculate offset ranges to ignore partitions that have + // been deleted + val topicPartitions = endPartitionOffsets.keySet.filter { tp => + // Ignore partitions that we don't know the from offsets. + newPartitionOffsets.contains(tp) || startPartitionOffsets.contains(tp) + }.toSeq + logDebug("TopicPartitions: " + topicPartitions.mkString(", ")) + + val sortedExecutors = getSortedExecutorList() + val numExecutors = sortedExecutors.length + logDebug("Sorted executors: " + sortedExecutors.mkString(", ")) + + // Calculate offset ranges + val factories = topicPartitions.flatMap { tp => + val fromOffset = startPartitionOffsets.get(tp).getOrElse { + newPartitionOffsets.getOrElse( + tp, { + // This should not happen since newPartitionOffsets contains all partitions not in + // fromPartitionOffsets + throw new IllegalStateException(s"$tp doesn't have a from offset") + }) + } + val untilOffset = endPartitionOffsets(tp) + + if (untilOffset >= fromOffset) { + // This allows cached KafkaConsumers in the executors to be re-used to read the same + // partition in every batch. + val preferredLoc = if (numExecutors > 0) { + Some(sortedExecutors(Math.floorMod(tp.hashCode, numExecutors))) + } else None + val range = KafkaOffsetRange(tp, fromOffset, untilOffset) + Some( + new KafkaMicroBatchDataReaderFactory( + range, preferredLoc, executorKafkaParams, pollTimeoutMs, failOnDataLoss)) + } else { + reportDataLoss( + s"Partition $tp's offset was changed from " + + s"$fromOffset to $untilOffset, some data may have been missed") + None + } + } + factories.map(_.asInstanceOf[DataReaderFactory[UnsafeRow]]).asJava + } + + override def getStartOffset: Offset = { + KafkaSourceOffset(startPartitionOffsets) + } + + override def getEndOffset: Offset = { + KafkaSourceOffset(endPartitionOffsets) + } + + override def deserializeOffset(json: String): Offset = { + KafkaSourceOffset(JsonUtils.partitionOffsets(json)) + } + + override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema + + override def commit(end: Offset): Unit = {} + + override def stop(): Unit = { + kafkaOffsetReader.close() + } + + override def toString(): String = s"Kafka[$kafkaOffsetReader]" + + /** + * Read initial partition offsets from the checkpoint, or decide the offsets and write them to + * the checkpoint. + */ + private def getOrCreateInitialPartitionOffsets(): PartitionOffsetMap = { + // Make sure that `KafkaConsumer.poll` is only called in StreamExecutionThread. + // Otherwise, interrupting a thread while running `KafkaConsumer.poll` may hang forever + // (KAFKA-1894). + assert(Thread.currentThread().isInstanceOf[UninterruptibleThread]) + + // SparkSession is required for getting Hadoop configuration for writing to checkpoints + assert(SparkSession.getActiveSession.nonEmpty) + + val metadataLog = + new KafkaSourceInitialOffsetWriter(SparkSession.getActiveSession.get, metadataPath) + metadataLog.get(0).getOrElse { + val offsets = startingOffsets match { + case EarliestOffsetRangeLimit => + KafkaSourceOffset(kafkaOffsetReader.fetchEarliestOffsets()) + case LatestOffsetRangeLimit => + KafkaSourceOffset(kafkaOffsetReader.fetchLatestOffsets()) + case SpecificOffsetRangeLimit(p) => + kafkaOffsetReader.fetchSpecificOffsets(p, reportDataLoss) + } + metadataLog.add(0, offsets) + logInfo(s"Initial offsets: $offsets") + offsets + }.partitionToOffsets + } + + /** Proportionally distribute limit number of offsets among topicpartitions */ + private def rateLimit( + limit: Long, + from: PartitionOffsetMap, + until: PartitionOffsetMap): PartitionOffsetMap = { + val fromNew = kafkaOffsetReader.fetchEarliestOffsets(until.keySet.diff(from.keySet).toSeq) + val sizes = until.flatMap { + case (tp, end) => + // If begin isn't defined, something's wrong, but let alert logic in getBatch handle it + from.get(tp).orElse(fromNew.get(tp)).flatMap { begin => + val size = end - begin + logDebug(s"rateLimit $tp size is $size") + if (size > 0) Some(tp -> size) else None + } + } + val total = sizes.values.sum.toDouble + if (total < 1) { + until + } else { + until.map { + case (tp, end) => + tp -> sizes.get(tp).map { size => + val begin = from.get(tp).getOrElse(fromNew(tp)) + val prorate = limit * (size / total) + // Don't completely starve small topicpartitions + val off = begin + (if (prorate < 1) Math.ceil(prorate) else Math.floor(prorate)).toLong + // Paranoia, make sure not to return an offset that's past end + Math.min(end, off) + }.getOrElse(end) + } + } + } + + private def getSortedExecutorList(): Array[String] = { + + def compare(a: ExecutorCacheTaskLocation, b: ExecutorCacheTaskLocation): Boolean = { + if (a.host == b.host) { + a.executorId > b.executorId + } else { + a.host > b.host + } + } + + val bm = SparkEnv.get.blockManager + bm.master.getPeers(bm.blockManagerId).toArray + .map(x => ExecutorCacheTaskLocation(x.host, x.executorId)) + .sortWith(compare) + .map(_.toString) + } + + /** + * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. + * Otherwise, just log a warning. + */ + private def reportDataLoss(message: String): Unit = { + if (failOnDataLoss) { + throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE") + } else { + logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") + } + } + + /** A version of [[HDFSMetadataLog]] specialized for saving the initial offsets. */ + class KafkaSourceInitialOffsetWriter(sparkSession: SparkSession, metadataPath: String) + extends HDFSMetadataLog[KafkaSourceOffset](sparkSession, metadataPath) { + + val VERSION = 1 + + override def serialize(metadata: KafkaSourceOffset, out: OutputStream): Unit = { + out.write(0) // A zero byte is written to support Spark 2.1.0 (SPARK-19517) + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): KafkaSourceOffset = { + in.read() // A zero byte is read to support Spark 2.1.0 (SPARK-19517) + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. + assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) + KafkaSourceOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + // The log was generated by Spark 2.1.0 + KafkaSourceOffset(SerializedOffset(content)) + } + } + } +} + +/** A [[DataReaderFactory]] for reading Kafka data in a micro-batch streaming query. */ +private[kafka010] class KafkaMicroBatchDataReaderFactory( + range: KafkaOffsetRange, + preferredLoc: Option[String], + executorKafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean) extends DataReaderFactory[UnsafeRow] { + + override def preferredLocations(): Array[String] = preferredLoc.toArray + + override def createDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader( + range, executorKafkaParams, pollTimeoutMs, failOnDataLoss) +} + +/** A [[DataReader]] for reading Kafka data in a micro-batch streaming query. */ +private[kafka010] class KafkaMicroBatchDataReader( + offsetRange: KafkaOffsetRange, + executorKafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean) extends DataReader[UnsafeRow] with Logging { + + private val consumer = CachedKafkaConsumer.getOrCreate( + offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + private val rangeToRead = resolveRange(offsetRange) + private val converter = new KafkaRecordToUnsafeRowConverter + + private var nextOffset = rangeToRead.fromOffset + private var nextRow: UnsafeRow = _ + + override def next(): Boolean = { + if (nextOffset < rangeToRead.untilOffset) { + val record = consumer.get(nextOffset, rangeToRead.untilOffset, pollTimeoutMs, failOnDataLoss) + if (record != null) { + nextRow = converter.toUnsafeRow(record) + true + } else { + false + } + } else { + false + } + } + + override def get(): UnsafeRow = { + assert(nextRow != null) + nextOffset += 1 + nextRow + } + + override def close(): Unit = { + // Indicate that we're no longer using this consumer + CachedKafkaConsumer.releaseKafkaConsumer( + offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + } + + private def resolveRange(range: KafkaOffsetRange): KafkaOffsetRange = { + if (range.fromOffset < 0 || range.untilOffset < 0) { + // Late bind the offset range + val availableOffsetRange = consumer.getAvailableOffsetRange() + val fromOffset = if (range.fromOffset < 0) { + assert(range.fromOffset == KafkaOffsetRangeLimit.EARLIEST, + s"earliest offset ${range.fromOffset} does not equal ${KafkaOffsetRangeLimit.EARLIEST}") + availableOffsetRange.earliest + } else { + range.fromOffset + } + val untilOffset = if (range.untilOffset < 0) { + assert(range.untilOffset == KafkaOffsetRangeLimit.LATEST, + s"latest offset ${range.untilOffset} does not equal ${KafkaOffsetRangeLimit.LATEST}") + availableOffsetRange.latest + } else { + range.untilOffset + } + KafkaOffsetRange(range.topicPartition, fromOffset, untilOffset) + } else { + range + } + } +} + +private[kafka010] case class KafkaOffsetRange( + topicPartition: TopicPartition, fromOffset: Long, untilOffset: Long) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala new file mode 100644 index 0000000000000..1acdd56125741 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.clients.consumer.ConsumerRecord + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.unsafe.types.UTF8String + +/** A simple class for converting Kafka ConsumerRecord to UnsafeRow */ +private[kafka010] class KafkaRecordToUnsafeRowConverter { + private val sharedRow = new UnsafeRow(7) + private val bufferHolder = new BufferHolder(sharedRow) + private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) + + def toUnsafeRow(record: ConsumerRecord[Array[Byte], Array[Byte]]): UnsafeRow = { + bufferHolder.reset() + + if (record.key == null) { + rowWriter.setNullAt(0) + } else { + rowWriter.write(0, record.key) + } + rowWriter.write(1, record.value) + rowWriter.write(2, UTF8String.fromString(record.topic)) + rowWriter.write(3, record.partition) + rowWriter.write(4, record.offset) + rowWriter.write( + 5, + DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp))) + rowWriter.write(6, record.timestampType.id) + sharedRow.setTotalSize(bufferHolder.totalSize) + sharedRow + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 169a5d006fb04..1c7b3a29a861f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.kafka010.KafkaSource._ +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -306,7 +307,7 @@ private[kafka010] class KafkaSource( kafkaReader.close() } - override def toString(): String = s"KafkaSource[$kafkaReader]" + override def toString(): String = s"KafkaSourceV1[$kafkaReader]" /** * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. @@ -323,22 +324,6 @@ private[kafka010] class KafkaSource( /** Companion object for the [[KafkaSource]]. */ private[kafka010] object KafkaSource { - val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE = - """ - |Some data may have been lost because they are not available in Kafka any more; either the - | data was aged out by Kafka or the topic may have been deleted before all the data in the - | topic was processed. If you want your streaming query to fail on such cases, set the source - | option "failOnDataLoss" to "true". - """.stripMargin - - val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE = - """ - |Some data may have been lost because they are not available in Kafka any more; either the - | data was aged out by Kafka or the topic may have been deleted before all the data in the - | topic was processed. If you don't want your streaming query to fail on such cases, set the - | source option "failOnDataLoss" to "false". - """.stripMargin - private[kafka010] val VERSION = 1 def getSortedExecutorList(sc: SparkContext): Array[String] = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index d4fa0359c12d6..0aa64a6a9cf90 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -30,13 +30,13 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType /** - * The provider class for the [[KafkaSource]]. This provider is designed such that it throws + * The provider class for all Kafka readers and writers. It is designed such that it throws * IllegalArgumentException when the Kafka Dataset is created, so that it can catch * missing options even before the query is started. */ @@ -47,6 +47,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with CreatableRelationProvider with StreamWriteSupport with ContinuousReadSupport + with MicroBatchReadSupport with Logging { import KafkaSourceProvider._ @@ -105,6 +106,52 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParams)) } + /** + * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader]] to read batches + * of Kafka data in a micro-batch streaming query. + */ + override def createMicroBatchReader( + schema: Optional[StructType], + metadataPath: String, + options: DataSourceOptions): KafkaMicroBatchReader = { + + val parameters = options.asMap().asScala.toMap + validateStreamOptions(parameters) + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. + val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + val specifiedKafkaParams = + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + + val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, + STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + + val kafkaOffsetReader = new KafkaOffsetReader( + strategy(caseInsensitiveParams), + kafkaParamsForDriver(specifiedKafkaParams), + parameters, + driverGroupIdPrefix = s"$uniqueGroupId-driver") + + new KafkaMicroBatchReader( + kafkaOffsetReader, + kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), + options, + metadataPath, + startingStreamOffsets, + failOnDataLoss(caseInsensitiveParams)) + } + + /** + * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.ContinuousDataReader]] to read + * Kafka data in a continuous streaming query. + */ override def createContinuousReader( schema: Optional[StructType], metadataPath: String, @@ -408,8 +455,27 @@ private[kafka010] object KafkaSourceProvider extends Logging { private[kafka010] val STARTING_OFFSETS_OPTION_KEY = "startingoffsets" private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets" private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss" + val TOPIC_OPTION_KEY = "topic" + val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE = + """ + |Some data may have been lost because they are not available in Kafka any more; either the + | data was aged out by Kafka or the topic may have been deleted before all the data in the + | topic was processed. If you want your streaming query to fail on such cases, set the source + | option "failOnDataLoss" to "true". + """.stripMargin + + val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE = + """ + |Some data may have been lost because they are not available in Kafka any more; either the + | data was aged out by Kafka or the topic may have been deleted before all the data in the + | topic was processed. If you don't want your streaming query to fail on such cases, set the + | source option "failOnDataLoss" to "false". + """.stripMargin + + + private val deserClassName = classOf[ByteArrayDeserializer].getName def getKafkaOffsetRangeLimit( diff --git a/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-future-version.bin b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-future-version.bin new file mode 100644 index 0000000000000..d530773f57327 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-future-version.bin @@ -0,0 +1,2 @@ +0v99999 +{"kafka-initial-offset-future-version":{"2":2,"1":1,"0":0}} \ No newline at end of file diff --git a/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin index ae928e724967d..8c78d9e390a0e 100644 --- a/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin +++ b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin @@ -1 +1 @@ -2{"kafka-initial-offset-2-1-0":{"2":0,"1":0,"0":0}} \ No newline at end of file +2{"kafka-initial-offset-2-1-0":{"2":2,"1":1,"0":0}} \ No newline at end of file diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala similarity index 85% rename from external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala rename to external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 02c87643568bd..ed4ecfeafa972 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -25,6 +25,7 @@ import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable +import scala.io.Source import scala.util.Random import org.apache.kafka.clients.producer.RecordMetadata @@ -42,7 +43,6 @@ import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} -import org.apache.spark.util.Utils abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { @@ -112,14 +112,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") - val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: KafkaSource, _) => source - } ++ (query.get.lastExecution match { - case null => Seq() - case e => e.logical.collect { - case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader - } - }) + val sources = { + query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: KafkaSource, _) => source + case StreamingExecutionRelation(source: KafkaMicroBatchReader, _) => source + } ++ (query.get.lastExecution match { + case null => Seq() + case e => e.logical.collect { + case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + } + }) + }.distinct + if (sources.isEmpty) { throw new Exception( "Could not find Kafka source in the StreamExecution logical plan to add data to") @@ -155,7 +159,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}" } -class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { +abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { import testImplicits._ @@ -303,94 +307,105 @@ class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { ) } - testWithUninterruptibleThread( - "deserialization of initial offset with Spark 2.1.0") { + test("ensure that initial offset are written with an extra byte in the beginning (SPARK-19517)") { withTempDir { metadataPath => - val topic = newTopic - testUtils.createTopic(topic, partitions = 3) + val topic = "kafka-initial-offset-current" + testUtils.createTopic(topic, partitions = 1) - val provider = new KafkaSourceProvider - val parameters = Map( - "kafka.bootstrap.servers" -> testUtils.brokerAddress, - "subscribe" -> topic - ) - val source = provider.createSource(spark.sqlContext, metadataPath.getAbsolutePath, None, - "", parameters) - source.getOffset.get // Write initial offset - - // Make sure Spark 2.1.0 will throw an exception when reading the new log - intercept[java.lang.IllegalArgumentException] { - // Simulate how Spark 2.1.0 reads the log - Utils.tryWithResource(new FileInputStream(metadataPath.getAbsolutePath + "/0")) { in => - val length = in.read() - val bytes = new Array[Byte](length) - in.read(bytes) - KafkaSourceOffset(SerializedOffset(new String(bytes, UTF_8))) - } + val initialOffsetFile = Paths.get(s"${metadataPath.getAbsolutePath}/sources/0/0").toFile + + val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .option("startingOffsets", s"earliest") + .load() + + // Test the written initial offset file has 0 byte in the beginning, so that + // Spark 2.1.0 can read the offsets (see SPARK-19517) + testStream(df)( + StartStream(checkpointLocation = metadataPath.getAbsolutePath), + makeSureGetOffsetCalled) + + val binarySource = Source.fromFile(initialOffsetFile) + try { + assert(binarySource.next().toInt == 0) // first byte is binary 0 + } finally { + binarySource.close() } } } - testWithUninterruptibleThread("deserialization of initial offset written by Spark 2.1.0") { + test("deserialization of initial offset written by Spark 2.1.0 (SPARK-19517)") { withTempDir { metadataPath => val topic = "kafka-initial-offset-2-1-0" testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessages(topic, Array("0", "1", "2"), Some(0)) + testUtils.sendMessages(topic, Array("0", "10", "20"), Some(1)) + testUtils.sendMessages(topic, Array("0", "100", "200"), Some(2)) - val provider = new KafkaSourceProvider - val parameters = Map( - "kafka.bootstrap.servers" -> testUtils.brokerAddress, - "subscribe" -> topic - ) - + // Copy the initial offset file into the right location inside the checkpoint root directory + // such that the Kafka source can read it for initial offsets. val from = new File( getClass.getResource("/kafka-source-initial-offset-version-2.1.0.bin").toURI).toPath - val to = Paths.get(s"${metadataPath.getAbsolutePath}/0") + val to = Paths.get(s"${metadataPath.getAbsolutePath}/sources/0/0") + Files.createDirectories(to.getParent) Files.copy(from, to) - val source = provider.createSource( - spark.sqlContext, metadataPath.toURI.toString, None, "", parameters) - val deserializedOffset = source.getOffset.get - val referenceOffset = KafkaSourceOffset((topic, 0, 0L), (topic, 1, 0L), (topic, 2, 0L)) - assert(referenceOffset == deserializedOffset) + val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .option("startingOffsets", s"earliest") + .load() + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(_.toInt) + + // Test that the query starts from the expected initial offset (i.e. read older offsets, + // even though startingOffsets is latest). + testStream(df)( + StartStream(checkpointLocation = metadataPath.getAbsolutePath), + AddKafkaData(Set(topic), 1000), + CheckAnswer(0, 1, 2, 10, 20, 200, 1000)) } } - testWithUninterruptibleThread("deserialization of initial offset written by future version") { + test("deserialization of initial offset written by future version") { withTempDir { metadataPath => - val futureMetadataLog = - new HDFSMetadataLog[KafkaSourceOffset](sqlContext.sparkSession, - metadataPath.getAbsolutePath) { - override def serialize(metadata: KafkaSourceOffset, out: OutputStream): Unit = { - out.write(0) - val writer = new BufferedWriter(new OutputStreamWriter(out, UTF_8)) - writer.write(s"v99999\n${metadata.json}") - writer.flush - } - } - - val topic = newTopic + val topic = "kafka-initial-offset-future-version" testUtils.createTopic(topic, partitions = 3) - val offset = KafkaSourceOffset((topic, 0, 0L), (topic, 1, 0L), (topic, 2, 0L)) - futureMetadataLog.add(0, offset) - - val provider = new KafkaSourceProvider - val parameters = Map( - "kafka.bootstrap.servers" -> testUtils.brokerAddress, - "subscribe" -> topic - ) - val source = provider.createSource(spark.sqlContext, metadataPath.getAbsolutePath, None, - "", parameters) - val e = intercept[java.lang.IllegalStateException] { - source.getOffset.get // Read initial offset - } + // Copy the initial offset file into the right location inside the checkpoint root directory + // such that the Kafka source can read it for initial offsets. + val from = new File( + getClass.getResource("/kafka-source-initial-offset-future-version.bin").toURI).toPath + val to = Paths.get(s"${metadataPath.getAbsolutePath}/sources/0/0") + Files.createDirectories(to.getParent) + Files.copy(from, to) - Seq( - s"maximum supported log version is v${KafkaSource.VERSION}, but encountered v99999", - "produced by a newer version of Spark and cannot be read by this version" - ).foreach { message => - assert(e.getMessage.contains(message)) - } + val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .load() + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(_.toInt) + + testStream(df)( + StartStream(checkpointLocation = metadataPath.getAbsolutePath), + ExpectFailure[IllegalStateException](e => { + Seq( + s"maximum supported log version is v1, but encountered v99999", + "produced by a newer version of Spark and cannot be read by this version" + ).foreach { message => + assert(e.toString.contains(message)) + } + })) } } @@ -542,6 +557,91 @@ class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { CheckLastBatch(120 to 124: _*) ) } + + test("ensure stream-stream self-join generates only one offset in offset log") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 2) + require(testUtils.getLatestOffsets(Set(topic)).size === 2) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .load() + + val values = kafka + .selectExpr("CAST(CAST(value AS STRING) AS INT) AS value", + "CAST(CAST(value AS STRING) AS INT) % 5 AS key") + + val join = values.join(values, "key") + + testStream(join)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2), + CheckAnswer((1, 1, 1), (2, 2, 2)), + AddKafkaData(Set(topic), 6, 3), + CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 6, 6)) + ) + } +} + + +class KafkaMicroBatchV1SourceSuite extends KafkaMicroBatchSourceSuiteBase { + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set( + "spark.sql.streaming.disabledV2MicroBatchReaders", + classOf[KafkaSourceProvider].getCanonicalName) + } + + test("V1 Source is used when disabled through SQLConf") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + + val kafka = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topic.*") + .load() + + testStream(kafka)( + makeSureGetOffsetCalled, + AssertOnQuery { query => + query.logicalPlan.collect { + case StreamingExecutionRelation(_: KafkaSource, _) => true + }.nonEmpty + } + ) + } +} + +class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { + + test("V2 Source is used by default") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + + val kafka = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topic.*") + .load() + + testStream(kafka)( + makeSureGetOffsetCalled, + AssertOnQuery { query => + query.logicalPlan.collect { + case StreamingExecutionRelation(_: KafkaMicroBatchReader, _) => true + }.nonEmpty + } + ) + } } abstract class KafkaSourceSuiteBase extends KafkaSourceTest { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f24fd7ff74d3f..e75e1d66ebcf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1146,10 +1146,20 @@ object SQLConf { val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers") .internal() .doc("A comma-separated list of fully qualified data source register class names for which" + - " StreamWriteSupport is disabled. Writes to these sources will fail back to the V1 Sink.") + " StreamWriteSupport is disabled. Writes to these sources will fall back to the V1 Sinks.") .stringConf .createWithDefault("") + val DISABLED_V2_STREAMING_MICROBATCH_READERS = + buildConf("spark.sql.streaming.disabledV2MicroBatchReaders") + .internal() + .doc( + "A comma-separated list of fully qualified data source register class names for which " + + "MicroBatchReadSupport is disabled. Reads from these sources will fall back to the " + + "V1 Sources.") + .stringConf + .createWithDefault("") + object PartitionOverwriteMode extends Enumeration { val STATIC, DYNAMIC = Value } @@ -1525,6 +1535,9 @@ class SQLConf extends Serializable with Logging { def disabledV2StreamingWriters: String = getConf(DISABLED_V2_STREAMING_WRITERS) + def disabledV2StreamingMicroBatchReaders: String = + getConf(DISABLED_V2_STREAMING_MICROBATCH_READERS) + def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index ac73ba3417904..84655013ba957 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -72,27 +72,36 @@ class MicroBatchExecution( // Note that we have to use the previous `output` as attributes in StreamingExecutionRelation, // since the existing logical plan has already used those attributes. The per-microbatch // transformation is responsible for replacing attributes with their final values. + + val disabledSources = + sparkSession.sqlContext.conf.disabledV2StreamingMicroBatchReaders.split(",") + val _logicalPlan = analyzedPlan.transform { - case streamingRelation@StreamingRelation(dataSource, _, output) => + case streamingRelation@StreamingRelation(dataSourceV1, sourceName, output) => toExecutionRelationMap.getOrElseUpdate(streamingRelation, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - val source = dataSource.createSource(metadataPath) + val source = dataSourceV1.createSource(metadataPath) nextSourceId += 1 + logInfo(s"Using Source [$source] from DataSourceV1 named '$sourceName' [$dataSourceV1]") StreamingExecutionRelation(source, output)(sparkSession) }) - case s @ StreamingRelationV2(source: MicroBatchReadSupport, _, options, output, _) => + case s @ StreamingRelationV2( + dataSourceV2: MicroBatchReadSupport, sourceName, options, output, _) if + !disabledSources.contains(dataSourceV2.getClass.getCanonicalName) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - val reader = source.createMicroBatchReader( + val reader = dataSourceV2.createMicroBatchReader( Optional.empty(), // user specified schema metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 + logInfo(s"Using MicroBatchReader [$reader] from " + + s"DataSourceV2 named '$sourceName' [$dataSourceV2]") StreamingExecutionRelation(reader, output)(sparkSession) }) - case s @ StreamingRelationV2(_, sourceName, _, output, v1Relation) => + case s @ StreamingRelationV2(dataSourceV2, sourceName, _, output, v1Relation) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" @@ -102,6 +111,7 @@ class MicroBatchExecution( } val source = v1Relation.get.dataSource.createSource(metadataPath) nextSourceId += 1 + logInfo(s"Using Source [$source] from DataSourceV2 named '$sourceName' [$dataSourceV2]") StreamingExecutionRelation(source, output)(sparkSession) }) } From d5ed2108d32e1d95b26ee7fed39e8a733e935e2c Mon Sep 17 00:00:00 2001 From: Shintaro Murakami Date: Fri, 16 Feb 2018 17:17:55 -0800 Subject: [PATCH 0368/2461] [SPARK-23381][CORE] Murmur3 hash generates a different value from other implementations ## What changes were proposed in this pull request? Murmur3 hash generates a different value from the original and other implementations (like Scala standard library and Guava or so) when the length of a bytes array is not multiple of 4. ## How was this patch tested? Added a unit test. **Note: When we merge this PR, please give all the credits to Shintaro Murakami.** Author: Shintaro Murakami Author: gatorsmile Author: Shintaro Murakami Closes #20630 from gatorsmile/pr-20568. --- .../spark/util/sketch/Murmur3_x86_32.java | 16 +++++++++ .../spark/unsafe/hash/Murmur3_x86_32.java | 16 +++++++++ .../unsafe/hash/Murmur3_x86_32Suite.java | 19 +++++++++++ .../spark/ml/feature/FeatureHasher.scala | 33 ++++++++++++++++++- .../spark/mllib/feature/HashingTF.scala | 2 +- .../spark/ml/feature/FeatureHasherSuite.scala | 11 ++++++- python/pyspark/ml/feature.py | 4 +-- 7 files changed, 96 insertions(+), 5 deletions(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java index a61ce4fb7241d..e83b331391e39 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java @@ -60,6 +60,8 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, i } public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + // This is not compatible with original and another implementations. + // But remain it for backward compatibility for the components existing before 2.3. assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; int h1 = hashBytesByInt(base, offset, lengthAligned, seed); @@ -71,6 +73,20 @@ public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, i return fmix(h1, lengthInBytes); } + public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) { + // This is compatible with original and another implementations. + // Use this method for new components after Spark 2.3. + assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; + int lengthAligned = lengthInBytes - lengthInBytes % 4; + int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + int k1 = 0; + for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) { + k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift; + } + h1 ^= mixK1(k1); + return fmix(h1, lengthInBytes); + } + private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { assert (lengthInBytes % 4 == 0); int h1 = seed; diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index 5e7ee480cafd1..d239de6083ad0 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -60,6 +60,8 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, i } public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + // This is not compatible with original and another implementations. + // But remain it for backward compatibility for the components existing before 2.3. assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; int h1 = hashBytesByInt(base, offset, lengthAligned, seed); @@ -71,6 +73,20 @@ public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, i return fmix(h1, lengthInBytes); } + public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) { + // This is compatible with original and another implementations. + // Use this method for new components after Spark 2.3. + assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; + int lengthAligned = lengthInBytes - lengthInBytes % 4; + int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + int k1 = 0; + for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) { + k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift; + } + h1 ^= mixK1(k1); + return fmix(h1, lengthInBytes); + } + private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { assert (lengthInBytes % 4 == 0); int h1 = seed; diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java index e759cb33b3e6a..6348a73bf3895 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -22,6 +22,8 @@ import java.util.Random; import java.util.Set; +import scala.util.hashing.MurmurHash3$; + import org.apache.spark.unsafe.Platform; import org.junit.Assert; import org.junit.Test; @@ -51,6 +53,23 @@ public void testKnownLongInputs() { Assert.assertEquals(-2106506049, hasher.hashLong(Long.MAX_VALUE)); } + // SPARK-23381 Check whether the hash of the byte array is the same as another implementations + @Test + public void testKnownBytesInputs() { + byte[] test = "test".getBytes(StandardCharsets.UTF_8); + Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(test, 0), + Murmur3_x86_32.hashUnsafeBytes2(test, Platform.BYTE_ARRAY_OFFSET, test.length, 0)); + byte[] test1 = "test1".getBytes(StandardCharsets.UTF_8); + Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(test1, 0), + Murmur3_x86_32.hashUnsafeBytes2(test1, Platform.BYTE_ARRAY_OFFSET, test1.length, 0)); + byte[] te = "te".getBytes(StandardCharsets.UTF_8); + Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(te, 0), + Murmur3_x86_32.hashUnsafeBytes2(te, Platform.BYTE_ARRAY_OFFSET, te.length, 0)); + byte[] tes = "tes".getBytes(StandardCharsets.UTF_8); + Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(tes, 0), + Murmur3_x86_32.hashUnsafeBytes2(tes, Platform.BYTE_ARRAY_OFFSET, tes.length, 0)); + } + @Test public void randomizedStressTest() { int size = 65536; diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index a918dd4c075da..c78f61ac3ef71 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.AttributeGroup @@ -28,6 +29,8 @@ import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF} import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashMap @@ -138,7 +141,7 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme @Since("2.3.0") override def transform(dataset: Dataset[_]): DataFrame = { - val hashFunc: Any => Int = OldHashingTF.murmur3Hash + val hashFunc: Any => Int = FeatureHasher.murmur3Hash val n = $(numFeatures) val localInputCols = $(inputCols) val catCols = if (isSet(categoricalCols)) { @@ -218,4 +221,32 @@ object FeatureHasher extends DefaultParamsReadable[FeatureHasher] { @Since("2.3.0") override def load(path: String): FeatureHasher = super.load(path) + + private val seed = OldHashingTF.seed + + /** + * Calculate a hash code value for the term object using + * Austin Appleby's MurmurHash 3 algorithm (MurmurHash3_x86_32). + * This is the default hash algorithm used from Spark 2.0 onwards. + * Use hashUnsafeBytes2 to match the original algorithm with the value. + * See SPARK-23381. + */ + @Since("2.3.0") + private[feature] def murmur3Hash(term: Any): Int = { + term match { + case null => seed + case b: Boolean => hashInt(if (b) 1 else 0, seed) + case b: Byte => hashInt(b, seed) + case s: Short => hashInt(s, seed) + case i: Int => hashInt(i, seed) + case l: Long => hashLong(l, seed) + case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed) + case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) + case s: String => + val utf8 = UTF8String.fromString(s) + hashUnsafeBytes2(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed) + case _ => throw new SparkException("FeatureHasher with murmur3 algorithm does not " + + s"support type ${term.getClass.getCanonicalName} of input data.") + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala index 9abdd44a635d1..8935c8496cdbb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -135,7 +135,7 @@ object HashingTF { private[HashingTF] val Murmur3: String = "murmur3" - private val seed = 42 + private[spark] val seed = 42 /** * Calculate a hash code value for the term object using the native Scala implementation. diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala index 3fc3cbb62d5b5..7bc1825b69c43 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils class FeatureHasherSuite extends SparkFunSuite with MLlibTestSparkContext @@ -34,7 +35,7 @@ class FeatureHasherSuite extends SparkFunSuite import testImplicits._ - import HashingTFSuite.murmur3FeatureIdx + import FeatureHasherSuite.murmur3FeatureIdx implicit private val vectorEncoder = ExpressionEncoder[Vector]() @@ -216,3 +217,11 @@ class FeatureHasherSuite extends SparkFunSuite testDefaultReadWrite(t) } } + +object FeatureHasherSuite { + + private[feature] def murmur3FeatureIdx(numFeatures: Int)(term: Any): Int = { + Utils.nonNegativeMod(FeatureHasher.murmur3Hash(term), numFeatures) + } + +} diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index da85ba761a145..04b07e6a05481 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -741,9 +741,9 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, >>> df = spark.createDataFrame(data, cols) >>> hasher = FeatureHasher(inputCols=cols, outputCol="features") >>> hasher.transform(df).head().features - SparseVector(262144, {51871: 1.0, 63643: 1.0, 174475: 2.0, 253195: 1.0}) + SparseVector(262144, {174475: 2.0, 247670: 1.0, 257907: 1.0, 262126: 1.0}) >>> hasher.setCategoricalCols(["real"]).transform(df).head().features - SparseVector(262144, {51871: 1.0, 63643: 1.0, 171257: 1.0, 253195: 1.0}) + SparseVector(262144, {171257: 1.0, 247670: 1.0, 257907: 1.0, 262126: 1.0}) >>> hasherPath = temp_path + "/hasher" >>> hasher.save(hasherPath) >>> loadedHasher = FeatureHasher.load(hasherPath) From 15ad4a7f1000c83cefbecd41e315c964caa3c39f Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Sat, 17 Feb 2018 10:54:14 +0800 Subject: [PATCH 0369/2461] [SPARK-23447][SQL] Cleanup codegen template for Literal ## What changes were proposed in this pull request? Cleaned up the codegen templates for `Literal`s, to make sure that the `ExprCode` returned from `Literal.doGenCode()` has: 1. an empty `code` field; 2. an `isNull` field of either literal `true` or `false`; 3. a `value` field that is just a simple literal/constant. Before this PR, there are a couple of paths that would return a non-trivial `code` and all of them are actually unnecessary. The `NaN` and `Infinity` constants for `double` and `float` can be accessed through constants directly available so there's no need to add a reference for them. Also took the opportunity to add a new util method for ease of creating `ExprCode` for inline-able non-null values. ## How was this patch tested? Existing tests. Author: Kris Mok Closes #20626 from rednaxelafx/codegen-literal. --- .../expressions/codegen/CodeGenerator.scala | 6 +++ .../sql/catalyst/expressions/literals.scala | 51 ++++++++++--------- 2 files changed, 34 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4dcbb702893da..31ba29ae8d8ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -58,6 +58,12 @@ import org.apache.spark.util.{ParentClassLoader, Utils} */ case class ExprCode(var code: String, var isNull: String, var value: String) +object ExprCode { + def forNonNullValue(value: String): ExprCode = { + ExprCode(code = "", isNull = "false", value = value) + } +} + /** * State used for subexpression elimination. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index cd176d941819f..c1e65e34c2ea6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -278,40 +278,45 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) - // change the isNull and primitive to consts, to inline them if (value == null) { - ev.isNull = "true" - ev.copy(s"final $javaType ${ev.value} = ${ctx.defaultValue(dataType)};") + val defaultValueLiteral = ctx.defaultValue(javaType) match { + case "null" => s"(($javaType)null)" + case lit => lit + } + ExprCode(code = "", isNull = "true", value = defaultValueLiteral) } else { - ev.isNull = "false" dataType match { case BooleanType | IntegerType | DateType => - ev.copy(code = "", value = value.toString) + ExprCode.forNonNullValue(value.toString) case FloatType => - val v = value.asInstanceOf[Float] - if (v.isNaN || v.isInfinite) { - val boxedValue = ctx.addReferenceObj("boxedValue", v) - val code = s"final $javaType ${ev.value} = ($javaType) $boxedValue;" - ev.copy(code = code) - } else { - ev.copy(code = "", value = s"${value}f") + value.asInstanceOf[Float] match { + case v if v.isNaN => + ExprCode.forNonNullValue("Float.NaN") + case Float.PositiveInfinity => + ExprCode.forNonNullValue("Float.POSITIVE_INFINITY") + case Float.NegativeInfinity => + ExprCode.forNonNullValue("Float.NEGATIVE_INFINITY") + case _ => + ExprCode.forNonNullValue(s"${value}F") } case DoubleType => - val v = value.asInstanceOf[Double] - if (v.isNaN || v.isInfinite) { - val boxedValue = ctx.addReferenceObj("boxedValue", v) - val code = s"final $javaType ${ev.value} = ($javaType) $boxedValue;" - ev.copy(code = code) - } else { - ev.copy(code = "", value = s"${value}D") + value.asInstanceOf[Double] match { + case v if v.isNaN => + ExprCode.forNonNullValue("Double.NaN") + case Double.PositiveInfinity => + ExprCode.forNonNullValue("Double.POSITIVE_INFINITY") + case Double.NegativeInfinity => + ExprCode.forNonNullValue("Double.NEGATIVE_INFINITY") + case _ => + ExprCode.forNonNullValue(s"${value}D") } case ByteType | ShortType => - ev.copy(code = "", value = s"($javaType)$value") + ExprCode.forNonNullValue(s"($javaType)$value") case TimestampType | LongType => - ev.copy(code = "", value = s"${value}L") + ExprCode.forNonNullValue(s"${value}L") case _ => - ev.copy(code = "", value = ctx.addReferenceObj("literal", value, - ctx.javaType(dataType))) + val constRef = ctx.addReferenceObj("literal", value, javaType) + ExprCode.forNonNullValue(constRef) } } } From 3ee3b2ae1ff8fbeb43a08becef43a9bd763b06bb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 17 Feb 2018 00:25:36 -0800 Subject: [PATCH 0370/2461] [SPARK-23340][SQL] Upgrade Apache ORC to 1.4.3 ## What changes were proposed in this pull request? This PR updates Apache ORC dependencies to 1.4.3 released on February 9th. Apache ORC 1.4.2 release removes unnecessary dependencies and 1.4.3 has 5 more patches (https://s.apache.org/Fll8). Especially, the following ORC-285 is fixed at 1.4.3. ```scala scala> val df = Seq(Array.empty[Float]).toDF() scala> df.write.format("orc").save("/tmp/floatarray") scala> spark.read.orc("/tmp/floatarray") res1: org.apache.spark.sql.DataFrame = [value: array] scala> spark.read.orc("/tmp/floatarray").show() 18/02/12 22:09:10 ERROR Executor: Exception in task 0.0 in stage 1.0 (TID 1) java.io.IOException: Error reading file: file:/tmp/floatarray/part-00000-9c0b461b-4df1-4c23-aac1-3e4f349ac7d6-c000.snappy.orc at org.apache.orc.impl.RecordReaderImpl.nextBatch(RecordReaderImpl.java:1191) at org.apache.orc.mapreduce.OrcMapreduceRecordReader.ensureBatch(OrcMapreduceRecordReader.java:78) ... Caused by: java.io.EOFException: Read past EOF for compressed stream Stream for column 2 kind DATA position: 0 length: 0 range: 0 offset: 0 limit: 0 ``` ## How was this patch tested? Pass the Jenkins test. Author: Dongjoon Hyun Closes #20511 from dongjoon-hyun/SPARK-23340. --- dev/deps/spark-deps-hadoop-2.6 | 4 ++-- dev/deps/spark-deps-hadoop-2.7 | 4 ++-- pom.xml | 6 +----- .../sql/execution/datasources/orc/OrcSourceSuite.scala | 9 +++++++++ .../apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala | 10 ++++++++++ 5 files changed, 24 insertions(+), 9 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 99031384aa22e..ed310507d14ed 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -157,8 +157,8 @@ objenesis-2.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.1-nohive.jar -orc-mapreduce-1.4.1-nohive.jar +orc-core-1.4.3-nohive.jar +orc-mapreduce-1.4.3-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index cf8d2789b7ee9..04dec04796af4 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -158,8 +158,8 @@ objenesis-2.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.1-nohive.jar -orc-mapreduce-1.4.1-nohive.jar +orc-core-1.4.3-nohive.jar +orc-mapreduce-1.4.3-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/pom.xml b/pom.xml index de949b94d676c..ac30107066389 100644 --- a/pom.xml +++ b/pom.xml @@ -130,7 +130,7 @@ 1.2.1 10.12.1.1 1.8.2 - 1.4.1 + 1.4.3 nohive 1.6.0 9.3.20.v20170531 @@ -1740,10 +1740,6 @@ org.apache.hive hive-storage-api - - io.airlift - slice - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 6f5f2fd795f74..523f7cf77e103 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -160,6 +160,15 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { } } } + + test("SPARK-23340 Empty float/double array columns raise EOFException") { + Seq(Seq(Array.empty[Float]).toDF(), Seq(Array.empty[Double]).toDF()).foreach { df => + withTempPath { path => + df.write.format("orc").save(path.getCanonicalPath) + checkAnswer(spark.read.orc(path.getCanonicalPath), df) + } + } + } } class OrcSourceSuite extends OrcSuite with SharedSQLContext { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala index 92b2f069cacd6..597b0f56a55e4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala @@ -208,4 +208,14 @@ class HiveOrcQuerySuite extends OrcQueryTest with TestHiveSingleton { } } } + + test("SPARK-23340 Empty float/double array columns raise EOFException") { + withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "false") { + withTable("spark_23340") { + sql("CREATE TABLE spark_23340(a array, b array) STORED AS ORC") + sql("INSERT INTO spark_23340 VALUES (array(), array())") + checkAnswer(spark.table("spark_23340"), Seq(Row(Array.empty[Float], Array.empty[Double]))) + } + } + } } From f5850e78924d03448ad243cdd32b24c3fe0ea8af Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 20 Feb 2018 13:33:03 +0800 Subject: [PATCH 0371/2461] [SPARK-23457][SQL] Register task completion listeners first in ParquetFileFormat ## What changes were proposed in this pull request? ParquetFileFormat leaks opened files in some cases. This PR prevents that by registering task completion listers first before initialization. - [spark-branch-2.3-test-sbt-hadoop-2.7](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-branch-2.3-test-sbt-hadoop-2.7/205/testReport/org.apache.spark.sql/FileBasedDataSourceSuite/_It_is_not_a_test_it_is_a_sbt_testing_SuiteSelector_/) - [spark-master-test-sbt-hadoop-2.6](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.6/4228/testReport/junit/org.apache.spark.sql.execution.datasources.parquet/ParquetQuerySuite/_It_is_not_a_test_it_is_a_sbt_testing_SuiteSelector_/) ``` Caused by: sbt.ForkMain$ForkError: java.lang.Throwable: null at org.apache.spark.DebugFilesystem$.addOpenStream(DebugFilesystem.scala:36) at org.apache.spark.DebugFilesystem.open(DebugFilesystem.scala:70) at org.apache.hadoop.fs.FileSystem.open(FileSystem.java:769) at org.apache.parquet.hadoop.ParquetFileReader.(ParquetFileReader.java:538) at org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.initialize(SpecificParquetRecordReaderBase.java:149) at org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader.initialize(VectorizedParquetRecordReader.java:133) at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$$anonfun$buildReaderWithPartitionValues$1.apply(ParquetFileFormat.scala:400) at ``` ## How was this patch tested? Manual. The following test case generates the same leakage. ```scala test("SPARK-23457 Register task completion listeners first in ParquetFileFormat") { withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE.key -> s"${Int.MaxValue}") { withTempDir { dir => val basePath = dir.getCanonicalPath Seq(0).toDF("a").write.format("parquet").save(new Path(basePath, "first").toString) Seq(1).toDF("a").write.format("parquet").save(new Path(basePath, "second").toString) val df = spark.read.parquet( new Path(basePath, "first").toString, new Path(basePath, "second").toString) val e = intercept[SparkException] { df.collect() } assert(e.getCause.isInstanceOf[OutOfMemoryError]) } } } ``` Author: Dongjoon Hyun Closes #20619 from dongjoon-hyun/SPARK-23390. --- .../parquet/ParquetFileFormat.scala | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index ba69f9a26c968..476bd02374364 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -395,16 +395,21 @@ class ParquetFileFormat ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get) } val taskContext = Option(TaskContext.get()) - val parquetReader = if (enableVectorizedReader) { + if (enableVectorizedReader) { val vectorizedReader = new VectorizedParquetRecordReader( convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined, capacity) + val iter = new RecordReaderIterator(vectorizedReader) + // SPARK-23457 Register a task completion lister before `initialization`. + taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) vectorizedReader.initialize(split, hadoopAttemptContext) logDebug(s"Appending $partitionSchema ${file.partitionValues}") vectorizedReader.initBatch(partitionSchema, file.partitionValues) if (returningBatch) { vectorizedReader.enableReturningBatches() } - vectorizedReader + + // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy. + iter.asInstanceOf[Iterator[InternalRow]] } else { logDebug(s"Falling back to parquet-mr") // ParquetRecordReader returns UnsafeRow @@ -414,18 +419,11 @@ class ParquetFileFormat } else { new ParquetRecordReader[UnsafeRow](new ParquetReadSupport(convertTz)) } + val iter = new RecordReaderIterator(reader) + // SPARK-23457 Register a task completion lister before `initialization`. + taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) reader.initialize(split, hadoopAttemptContext) - reader - } - val iter = new RecordReaderIterator(parquetReader) - taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) - - // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy. - if (parquetReader.isInstanceOf[VectorizedParquetRecordReader] && - enableVectorizedReader) { - iter.asInstanceOf[Iterator[InternalRow]] - } else { val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes val joinedRow = new JoinedRow() val appendPartitionColumns = GenerateUnsafeProjection.generate(fullSchema, fullSchema) From 651b0277fe989119932d5ae1ef729c9768aa018d Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 20 Feb 2018 13:56:38 +0800 Subject: [PATCH 0372/2461] [SPARK-23436][SQL] Infer partition as Date only if it can be casted to Date ## What changes were proposed in this pull request? Before the patch, Spark could infer as Date a partition value which cannot be casted to Date (this can happen when there are extra characters after a valid date, like `2018-02-15AAA`). When this happens and the input format has metadata which define the schema of the table, then `null` is returned as a value for the partition column, because the `cast` operator used in (`PartitioningAwareFileIndex.inferPartitioning`) is unable to convert the value. The PR checks in the partition inference that values can be casted to Date and Timestamp, in order to infer that datatype to them. ## How was this patch tested? added UT Author: Marco Gaido Closes #20621 from mgaido91/SPARK-23436. --- .../datasources/PartitioningUtils.scala | 40 ++++++++++++++----- .../ParquetPartitionDiscoverySuite.scala | 14 +++++++ 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 472bf82d3604d..379acb67f7c71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -407,6 +407,34 @@ object PartitioningUtils { Literal(bigDecimal) } + val dateTry = Try { + // try and parse the date, if no exception occurs this is a candidate to be resolved as + // DateType + DateTimeUtils.getThreadLocalDateFormat.parse(raw) + // SPARK-23436: Casting the string to date may still return null if a bad Date is provided. + // This can happen since DateFormat.parse may not use the entire text of the given string: + // so if there are extra-characters after the date, it returns correctly. + // We need to check that we can cast the raw string since we later can use Cast to get + // the partition values with the right DataType (see + // org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex.inferPartitioning) + val dateValue = Cast(Literal(raw), DateType).eval() + // Disallow DateType if the cast returned null + require(dateValue != null) + Literal.create(dateValue, DateType) + } + + val timestampTry = Try { + val unescapedRaw = unescapePathName(raw) + // try and parse the date, if no exception occurs this is a candidate to be resolved as + // TimestampType + DateTimeUtils.getThreadLocalTimestampFormat(timeZone).parse(unescapedRaw) + // SPARK-23436: see comment for date + val timestampValue = Cast(Literal(unescapedRaw), TimestampType, Some(timeZone.getID)).eval() + // Disallow TimestampType if the cast returned null + require(timestampValue != null) + Literal.create(timestampValue, TimestampType) + } + if (typeInference) { // First tries integral types Try(Literal.create(Integer.parseInt(raw), IntegerType)) @@ -415,16 +443,8 @@ object PartitioningUtils { // Then falls back to fractional types .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) // Then falls back to date/timestamp types - .orElse(Try( - Literal.create( - DateTimeUtils.getThreadLocalTimestampFormat(timeZone) - .parse(unescapePathName(raw)).getTime * 1000L, - TimestampType))) - .orElse(Try( - Literal.create( - DateTimeUtils.millisToDays( - DateTimeUtils.getThreadLocalDateFormat.parse(raw).getTime), - DateType))) + .orElse(timestampTry) + .orElse(dateTry) // Then falls back to string .getOrElse { if (raw == DEFAULT_PARTITION_NAME) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index d4902641e335f..edb3da904d10d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -1120,4 +1120,18 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha Row(3, BigDecimal("2" * 30)) :: Nil) } } + + test("SPARK-23436: invalid Dates should be inferred as String in partition inference") { + withTempPath { path => + val data = Seq(("1", "2018-01", "2018-01-01-04", "test")) + .toDF("id", "date_month", "date_hour", "data") + + data.write.partitionBy("date_month", "date_hour").parquet(path.getAbsolutePath) + val input = spark.read.parquet(path.getAbsolutePath).select("id", + "date_month", "date_hour", "data") + + assert(input.schema.sameType(input.schema)) + checkAnswer(input, data) + } + } } From aadf9535b4a11b42fd9d72f636576d2da0766199 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Tue, 20 Feb 2018 16:04:22 +0800 Subject: [PATCH 0373/2461] [SPARK-23203][SQL] DataSourceV2: Use immutable logical plans. ## What changes were proposed in this pull request? SPARK-23203: DataSourceV2 should use immutable catalyst trees instead of wrapping a mutable DataSourceV2Reader. This commit updates DataSourceV2Relation and consolidates much of the DataSourceV2 API requirements for the read path in it. Instead of wrapping a reader that changes, the relation lazily produces a reader from its configuration. This commit also updates the predicate and projection push-down. Instead of the implementation from SPARK-22197, this reuses the rule matching from the Hive and DataSource read paths (using `PhysicalOperation`) and copies most of the implementation of `SparkPlanner.pruneFilterProject`, with updates for DataSourceV2. By reusing the implementation from other read paths, this should have fewer regressions from other read paths and is less code to maintain. The new push-down rules also supports the following edge cases: * The output of DataSourceV2Relation should be what is returned by the reader, in case the reader can only partially satisfy the requested schema projection * The requested projection passed to the DataSourceV2Reader should include filter columns * The push-down rule may be run more than once if filters are not pushed through projections ## How was this patch tested? Existing push-down and read tests. Author: Ryan Blue Closes #20387 from rdblue/SPARK-22386-push-down-immutable-trees. --- .../kafka010/KafkaContinuousSourceSuite.scala | 19 +- .../sql/kafka010/KafkaContinuousTest.scala | 4 +- .../kafka010/KafkaMicroBatchSourceSuite.scala | 4 +- .../apache/spark/sql/DataFrameReader.scala | 41 +--- .../datasources/v2/DataSourceV2Relation.scala | 212 +++++++++++++++++- .../datasources/v2/DataSourceV2Strategy.scala | 7 +- .../v2/PushDownOperatorsToDataSource.scala | 159 ++++--------- .../continuous/ContinuousExecution.scala | 2 +- .../sql/sources/v2/DataSourceV2Suite.scala | 2 +- .../spark/sql/streaming/StreamTest.scala | 6 +- 10 files changed, 269 insertions(+), 187 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index a7083fa4e3417..f679e9bfc0450 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -17,20 +17,9 @@ package org.apache.spark.sql.kafka010 -import java.util.Properties -import java.util.concurrent.atomic.AtomicInteger - -import org.scalatest.time.SpanSugar._ -import scala.collection.mutable -import scala.util.Random - -import org.apache.spark.SparkContext -import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.streaming.{StreamTest, Trigger} -import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation +import org.apache.spark.sql.streaming.Trigger // Run tests in KafkaSourceSuiteBase in continuous execution mode. class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest @@ -71,7 +60,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r }.exists { r => // Ensure the new topic is present and the old topic is gone. r.knownPartitions.exists(_.topic == topic2) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index 5a1a14f7a307a..48ac3fc1e8f9d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.SparkContext import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.streaming.Trigger @@ -47,7 +47,7 @@ trait KafkaContinuousTest extends KafkaSourceTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index ed4ecfeafa972..89c9ef4cc73b5 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -35,7 +35,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext import org.apache.spark.sql.{Dataset, ForeachWriter} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.functions.{count, window} @@ -119,7 +119,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { - case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + case StreamingDataSourceV2Relation(_, reader: KafkaContinuousReader) => reader } }) }.distinct diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index fcaf8d618c168..4274f120a375a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.{DataSourceV2, ReadSupport, ReadSupportWithSchema} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -189,39 +189,16 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance() - val options = new DataSourceOptions((extraOptions ++ - DataSourceV2Utils.extractSessionConfigs( - ds = ds.asInstanceOf[DataSourceV2], - conf = sparkSession.sessionState.conf)).asJava) - - // Streaming also uses the data source V2 API. So it may be that the data source implements - // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading - // the dataframe as a v1 source. - val reader = (ds, userSpecifiedSchema) match { - case (ds: ReadSupportWithSchema, Some(schema)) => - ds.createReader(schema, options) - - case (ds: ReadSupport, None) => - ds.createReader(options) - - case (ds: ReadSupportWithSchema, None) => - throw new AnalysisException(s"A schema needs to be specified when using $ds.") - - case (ds: ReadSupport, Some(schema)) => - val reader = ds.createReader(options) - if (reader.readSchema() != schema) { - throw new AnalysisException(s"$ds does not allow user-specified schemas.") - } - reader - - case _ => null // fall back to v1 - } + val ds = cls.newInstance().asInstanceOf[DataSourceV2] + if (ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema]) { + val sessionOptions = DataSourceV2Utils.extractSessionConfigs( + ds = ds, conf = sparkSession.sessionState.conf) + Dataset.ofRows(sparkSession, DataSourceV2Relation.create( + ds, extraOptions.toMap ++ sessionOptions, + userSpecifiedSchema = userSpecifiedSchema)) - if (reader == null) { - loadV1Source(paths: _*) } else { - Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + loadV1Source(paths: _*) } } else { loadV1Source(paths: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 38f6b15224788..a98dd4866f82a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -17,17 +17,80 @@ package org.apache.spark.sql.execution.datasources.v2 +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} -import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources.{DataSourceRegister, Filter} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsReportStatistics} +import org.apache.spark.sql.types.StructType case class DataSourceV2Relation( - output: Seq[AttributeReference], - reader: DataSourceReader) - extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder { + source: DataSourceV2, + options: Map[String, String], + projection: Seq[AttributeReference], + filters: Option[Seq[Expression]] = None, + userSpecifiedSchema: Option[StructType] = None) extends LeafNode with MultiInstanceRelation { + + import DataSourceV2Relation._ + + override def simpleString: String = { + s"DataSourceV2Relation(source=${source.name}, " + + s"schema=[${output.map(a => s"$a ${a.dataType.simpleString}").mkString(", ")}], " + + s"filters=[${pushedFilters.mkString(", ")}], options=$options)" + } + + override lazy val schema: StructType = reader.readSchema() + + override lazy val output: Seq[AttributeReference] = { + // use the projection attributes to avoid assigning new ids. fields that are not projected + // will be assigned new ids, which is okay because they are not projected. + val attrMap = projection.map(a => a.name -> a).toMap + schema.map(f => attrMap.getOrElse(f.name, + AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())) + } + + private lazy val v2Options: DataSourceOptions = makeV2Options(options) + + lazy val ( + reader: DataSourceReader, + unsupportedFilters: Seq[Expression], + pushedFilters: Seq[Expression]) = { + val newReader = userSpecifiedSchema match { + case Some(s) => + source.asReadSupportWithSchema.createReader(s, v2Options) + case _ => + source.asReadSupport.createReader(v2Options) + } + + DataSourceV2Relation.pushRequiredColumns(newReader, projection.toStructType) - override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] + val (remainingFilters, pushedFilters) = filters match { + case Some(filterSeq) => + DataSourceV2Relation.pushFilters(newReader, filterSeq) + case _ => + (Nil, Nil) + } + + (newReader, remainingFilters, pushedFilters) + } + + override def doCanonicalize(): LogicalPlan = { + val c = super.doCanonicalize().asInstanceOf[DataSourceV2Relation] + + // override output with canonicalized output to avoid attempting to configure a reader + val canonicalOutput: Seq[AttributeReference] = this.output + .map(a => QueryPlan.normalizeExprId(a, projection)) + + new DataSourceV2Relation(c.source, c.options, c.projection) { + override lazy val output: Seq[AttributeReference] = canonicalOutput + } + } override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => @@ -37,7 +100,9 @@ case class DataSourceV2Relation( } override def newInstance(): DataSourceV2Relation = { - copy(output = output.map(_.newInstance())) + // projection is used to maintain id assignment. + // if projection is not set, use output so the copy is not equal to the original + copy(projection = projection.map(_.newInstance())) } } @@ -45,14 +110,137 @@ case class DataSourceV2Relation( * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical * to the non-streaming relation. */ -class StreamingDataSourceV2Relation( +case class StreamingDataSourceV2Relation( output: Seq[AttributeReference], - reader: DataSourceReader) extends DataSourceV2Relation(output, reader) { + reader: DataSourceReader) + extends LeafNode with DataSourceReaderHolder with MultiInstanceRelation { override def isStreaming: Boolean = true + + override def canEqual(other: Any): Boolean = other.isInstanceOf[StreamingDataSourceV2Relation] + + override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance())) + + override def computeStats(): Statistics = reader match { + case r: SupportsReportStatistics => + Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + case _ => + Statistics(sizeInBytes = conf.defaultSizeInBytes) + } } object DataSourceV2Relation { - def apply(reader: DataSourceReader): DataSourceV2Relation = { - new DataSourceV2Relation(reader.readSchema().toAttributes, reader) + private implicit class SourceHelpers(source: DataSourceV2) { + def asReadSupport: ReadSupport = { + source match { + case support: ReadSupport => + support + case _: ReadSupportWithSchema => + // this method is only called if there is no user-supplied schema. if there is no + // user-supplied schema and ReadSupport was not implemented, throw a helpful exception. + throw new AnalysisException(s"Data source requires a user-supplied schema: $name") + case _ => + throw new AnalysisException(s"Data source is not readable: $name") + } + } + + def asReadSupportWithSchema: ReadSupportWithSchema = { + source match { + case support: ReadSupportWithSchema => + support + case _: ReadSupport => + throw new AnalysisException( + s"Data source does not support user-supplied schema: $name") + case _ => + throw new AnalysisException(s"Data source is not readable: $name") + } + } + + def name: String = { + source match { + case registered: DataSourceRegister => + registered.shortName() + case _ => + source.getClass.getSimpleName + } + } + } + + private def makeV2Options(options: Map[String, String]): DataSourceOptions = { + new DataSourceOptions(options.asJava) + } + + private def schema( + source: DataSourceV2, + v2Options: DataSourceOptions, + userSchema: Option[StructType]): StructType = { + val reader = userSchema match { + // TODO: remove this case because it is confusing for users + case Some(s) if !source.isInstanceOf[ReadSupportWithSchema] => + val reader = source.asReadSupport.createReader(v2Options) + if (reader.readSchema() != s) { + throw new AnalysisException(s"${source.name} does not allow user-specified schemas.") + } + reader + case Some(s) => + source.asReadSupportWithSchema.createReader(s, v2Options) + case _ => + source.asReadSupport.createReader(v2Options) + } + reader.readSchema() + } + + def create( + source: DataSourceV2, + options: Map[String, String], + filters: Option[Seq[Expression]] = None, + userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { + val projection = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes + DataSourceV2Relation(source, options, projection, filters, + // if the source does not implement ReadSupportWithSchema, then the userSpecifiedSchema must + // be equal to the reader's schema. the schema method enforces this. because the user schema + // and the reader's schema are identical, drop the user schema. + if (source.isInstanceOf[ReadSupportWithSchema]) userSpecifiedSchema else None) + } + + private def pushRequiredColumns(reader: DataSourceReader, struct: StructType): Unit = { + reader match { + case projectionSupport: SupportsPushDownRequiredColumns => + projectionSupport.pruneColumns(struct) + case _ => + } + } + + private def pushFilters( + reader: DataSourceReader, + filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + reader match { + case catalystFilterSupport: SupportsPushDownCatalystFilters => + ( + catalystFilterSupport.pushCatalystFilters(filters.toArray), + catalystFilterSupport.pushedCatalystFilters() + ) + + case filterSupport: SupportsPushDownFilters => + // A map from original Catalyst expressions to corresponding translated data source + // filters. If a predicate is not in this map, it means it cannot be pushed down. + val translatedMap: Map[Expression, Filter] = filters.flatMap { p => + DataSourceStrategy.translateFilter(p).map(f => p -> f) + }.toMap + + // Catalyst predicate expressions that cannot be converted to data source filters. + val nonConvertiblePredicates = filters.filterNot(translatedMap.contains) + + // Data source filters that cannot be pushed down. An unhandled filter means + // the data source cannot guarantee the rows returned can pass the filter. + // As a result we must return it so Spark can plan an extra filter operator. + val unhandledFilters = filterSupport.pushFilters(translatedMap.values.toArray).toSet + val (unhandledPredicates, pushedPredicates) = translatedMap.partition { case (_, f) => + unhandledFilters.contains(f) + } + + (nonConvertiblePredicates ++ unhandledPredicates.keys, pushedPredicates.keys.toSeq) + + case _ => (filters, Nil) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index df5b524485f54..c4e7644683c36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,8 +23,11 @@ import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case DataSourceV2Relation(output, reader) => - DataSourceV2ScanExec(output, reader) :: Nil + case relation: DataSourceV2Relation => + DataSourceV2ScanExec(relation.output, relation.reader) :: Nil + + case relation: StreamingDataSourceV2Relation => + DataSourceV2ScanExec(relation.output, relation.reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 1ca6cbf061b4e..f23d228567241 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -17,130 +17,55 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, AttributeSet, Expression, NamedExpression, PredicateHelper} -import org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.sources -import org.apache.spark.sql.sources.v2.reader._ -/** - * Pushes down various operators to the underlying data source for better performance. Operators are - * being pushed down with a specific order. As an example, given a LIMIT has a FILTER child, you - * can't push down LIMIT if FILTER is not completely pushed down. When both are pushed down, the - * data source should execute FILTER before LIMIT. And required columns are calculated at the end, - * because when more operators are pushed down, we may need less columns at Spark side. - */ -object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHelper { - override def apply(plan: LogicalPlan): LogicalPlan = { - // Note that, we need to collect the target operator along with PROJECT node, as PROJECT may - // appear in many places for column pruning. - // TODO: Ideally column pruning should be implemented via a plan property that is propagated - // top-down, then we can simplify the logic here and only collect target operators. - val filterPushed = plan transformUp { - case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) => - val (candidates, nonDeterministic) = - splitConjunctivePredicates(condition).partition(_.deterministic) - - val stayUpFilters: Seq[Expression] = reader match { - case r: SupportsPushDownCatalystFilters => - r.pushCatalystFilters(candidates.toArray) - - case r: SupportsPushDownFilters => - // A map from original Catalyst expressions to corresponding translated data source - // filters. If a predicate is not in this map, it means it cannot be pushed down. - val translatedMap: Map[Expression, sources.Filter] = candidates.flatMap { p => - DataSourceStrategy.translateFilter(p).map(f => p -> f) - }.toMap - - // Catalyst predicate expressions that cannot be converted to data source filters. - val nonConvertiblePredicates = candidates.filterNot(translatedMap.contains) - - // Data source filters that cannot be pushed down. An unhandled filter means - // the data source cannot guarantee the rows returned can pass the filter. - // As a result we must return it so Spark can plan an extra filter operator. - val unhandledFilters = r.pushFilters(translatedMap.values.toArray).toSet - val unhandledPredicates = translatedMap.filter { case (_, f) => - unhandledFilters.contains(f) - }.keys - - nonConvertiblePredicates ++ unhandledPredicates - - case _ => candidates - } - - val filterCondition = (stayUpFilters ++ nonDeterministic).reduceLeftOption(And) - val withFilter = filterCondition.map(Filter(_, r)).getOrElse(r) - if (withFilter.output == fields) { - withFilter - } else { - Project(fields, withFilter) - } - } - - // TODO: add more push down rules. - - val columnPruned = pushDownRequiredColumns(filterPushed, filterPushed.outputSet) - // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them. - RemoveRedundantProject(columnPruned) - } - - // TODO: nested fields pruning - private def pushDownRequiredColumns( - plan: LogicalPlan, requiredByParent: AttributeSet): LogicalPlan = { - plan match { - case p @ Project(projectList, child) => - val required = projectList.flatMap(_.references) - p.copy(child = pushDownRequiredColumns(child, AttributeSet(required))) - - case f @ Filter(condition, child) => - val required = requiredByParent ++ condition.references - f.copy(child = pushDownRequiredColumns(child, required)) - - case relation: DataSourceV2Relation => relation.reader match { - case reader: SupportsPushDownRequiredColumns => - // TODO: Enable the below assert after we make `DataSourceV2Relation` immutable. Fow now - // it's possible that the mutable reader being updated by someone else, and we need to - // always call `reader.pruneColumns` here to correct it. - // assert(relation.output.toStructType == reader.readSchema(), - // "Schema of data source reader does not match the relation plan.") - - val requiredColumns = relation.output.filter(requiredByParent.contains) - reader.pruneColumns(requiredColumns.toStructType) - - val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap - val newOutput = reader.readSchema().map(_.name).map(nameToAttr) - relation.copy(output = newOutput) - - case _ => relation +object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { + override def apply( + plan: LogicalPlan): LogicalPlan = plan transformUp { + // PhysicalOperation guarantees that filters are deterministic; no need to check + case PhysicalOperation(project, newFilters, relation : DataSourceV2Relation) => + // merge the filters + val filters = relation.filters match { + case Some(existing) => + existing ++ newFilters + case _ => + newFilters } - // TODO: there may be more operators that can be used to calculate the required columns. We - // can add more and more in the future. - case _ => plan.mapChildren(c => pushDownRequiredColumns(c, c.outputSet)) - } - } - - /** - * Finds a Filter node(with an optional Project child) above data source relation. - */ - object FilterAndProject { - // returns the project list, the filter condition and the data source relation. - def unapply(plan: LogicalPlan) - : Option[(Seq[NamedExpression], Expression, DataSourceV2Relation)] = plan match { + val projectAttrs = project.map(_.toAttribute) + val projectSet = AttributeSet(project.flatMap(_.references)) + val filterSet = AttributeSet(filters.flatMap(_.references)) + + val projection = if (filterSet.subsetOf(projectSet) && + AttributeSet(projectAttrs) == projectSet) { + // When the required projection contains all of the filter columns and column pruning alone + // can produce the required projection, push the required projection. + // A final projection may still be needed if the data source produces a different column + // order or if it cannot prune all of the nested columns. + projectAttrs + } else { + // When there are filter columns not already in the required projection or when the required + // projection is more complicated than column pruning, base column pruning on the set of + // all columns needed by both. + (projectSet ++ filterSet).toSeq + } - case Filter(condition, r: DataSourceV2Relation) => Some((r.output, condition, r)) + val newRelation = relation.copy( + projection = projection.asInstanceOf[Seq[AttributeReference]], + filters = Some(filters)) - case Filter(condition, Project(fields, r: DataSourceV2Relation)) - if fields.forall(_.deterministic) => - val attributeMap = AttributeMap(fields.map(e => e.toAttribute -> e)) - val substituted = condition.transform { - case a: Attribute => attributeMap.getOrElse(a, a) - } - Some((fields, substituted, r)) + // Add a Filter for any filters that could not be pushed + val unpushedFilter = newRelation.unsupportedFilters.reduceLeftOption(And) + val filtered = unpushedFilter.map(Filter(_, newRelation)).getOrElse(newRelation) - case _ => None - } + // Add a Project to ensure the output matches the required projection + if (newRelation.output != projectAttrs) { + Project(project, filtered) + } else { + filtered + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index c3294d64b10cd..2c1d6c509d21b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -201,7 +201,7 @@ class ContinuousExecution( val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { - case DataSourceV2Relation(_, r: ContinuousReader) => r + case StreamingDataSourceV2Relation(_, r: ContinuousReader) => r }.head reportTimeTaken("queryPlanning") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index a1c87fb15542c..1157a350461d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -146,7 +146,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { Seq(classOf[SchemaRequiredDataSource], classOf[JavaSchemaRequiredDataSource]).foreach { cls => withClue(cls.getName) { val e = intercept[AnalysisException](spark.read.format(cls.getName).load()) - assert(e.message.contains("A schema needs to be specified")) + assert(e.message.contains("requires a user-supplied schema")) val schema = new StructType().add("i", "int").add("s", "string") val df = spark.read.format(cls.getName).schema(schema).load() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 37fe595529baf..159dd0ecb5902 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -38,9 +38,9 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch} +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ @@ -605,7 +605,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be plan .collect { case StreamingExecutionRelation(s, _) => s - case DataSourceV2Relation(_, r) => r + case StreamingDataSourceV2Relation(_, r) => r } .zipWithIndex .find(_._1 == source) From 862fa697d829cdddf0f25e5613c91b040f9d9652 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Tue, 20 Feb 2018 20:26:26 +0900 Subject: [PATCH 0374/2461] [SPARK-23240][PYTHON] Better error message when extraneous data in pyspark.daemon's stdout ## What changes were proposed in this pull request? Print more helpful message when daemon module's stdout is empty or contains a bad port number. ## How was this patch tested? Manually recreated the environmental issues that caused the mysterious exceptions at one site. Tested that the expected messages are logged. Also, ran all scala unit tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Bruce Robbins Closes #20424 from bersprockets/SPARK-23240_prop2. --- .../api/python/PythonWorkerFactory.scala | 29 +++++++++++++++++-- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 30976ac752a8a..2340580b54f67 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.python -import java.io.{DataInputStream, DataOutputStream, InputStream, OutputStreamWriter} +import java.io.{DataInputStream, DataOutputStream, EOFException, InputStream, OutputStreamWriter} import java.net.{InetAddress, ServerSocket, Socket, SocketException} import java.nio.charset.StandardCharsets import java.util.Arrays @@ -182,7 +182,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String try { // Create and start the daemon - val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", daemonModule)) + val command = Arrays.asList(pythonExec, "-m", daemonModule) + val pb = new ProcessBuilder(command) val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) @@ -191,7 +192,29 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String daemon = pb.start() val in = new DataInputStream(daemon.getInputStream) - daemonPort = in.readInt() + try { + daemonPort = in.readInt() + } catch { + case _: EOFException => + throw new SparkException(s"No port number in $daemonModule's stdout") + } + + // test that the returned port number is within a valid range. + // note: this does not cover the case where the port number + // is arbitrary data but is also coincidentally within range + if (daemonPort < 1 || daemonPort > 0xffff) { + val exceptionMessage = f""" + |Bad data in $daemonModule's standard output. Invalid port number: + | $daemonPort (0x$daemonPort%08x) + |Python command to execute the daemon was: + | ${command.asScala.mkString(" ")} + |Check that you don't have any unexpected modules or libraries in + |your PYTHONPATH: + | $pythonPath + |Also, check if you have a sitecustomize.py module in your python path, + |or in your python installation, that is printing to standard output""" + throw new SparkException(exceptionMessage.stripMargin) + } // Redirect daemon stdout and stderr redirectStreamsToStderr(in, daemon.getErrorStream) From 189f56f3dcdad4d997248c01aa5490617f018bd0 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 20 Feb 2018 07:51:30 -0600 Subject: [PATCH 0375/2461] [SPARK-23383][BUILD][MINOR] Make a distribution should exit with usage while detecting wrong options ## What changes were proposed in this pull request? ```shell ./dev/make-distribution.sh --name ne-1.0.0-SNAPSHOT xyz --tgz -Phadoop-2.7 +++ dirname ./dev/make-distribution.sh ++ cd ./dev/.. ++ pwd + SPARK_HOME=/Users/Kent/Documents/spark + DISTDIR=/Users/Kent/Documents/spark/dist + MAKE_TGZ=false + MAKE_PIP=false + MAKE_R=false + NAME=none + MVN=/Users/Kent/Documents/spark/build/mvn + (( 5 )) + case $1 in + NAME=ne-1.0.0-SNAPSHOT + shift + shift + (( 3 )) + case $1 in + break + '[' -z /Users/Kent/.jenv/candidates/java/current ']' + '[' -z /Users/Kent/.jenv/candidates/java/current ']' ++ command -v git + '[' /usr/local/bin/git ']' ++ git rev-parse --short HEAD + GITREV=98ea6a7 + '[' '!' -z 98ea6a7 ']' + GITREVSTRING=' (git revision 98ea6a7)' + unset GITREV ++ command -v /Users/Kent/Documents/spark/build/mvn + '[' '!' /Users/Kent/Documents/spark/build/mvn ']' ++ /Users/Kent/Documents/spark/build/mvn help:evaluate -Dexpression=project.version xyz --tgz -Phadoop-2.7 ++ grep -v INFO ++ tail -n 1 + VERSION=' -X,--debug Produce execution debug output' ``` It is better to declare the mistakes and exit with usage than `break` ## How was this patch tested? manually cc srowen Author: Kent Yao Closes #20571 from yaooqinn/SPARK-23383. --- dev/make-distribution.sh | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 8b02446b2f15f..84233c64caa9c 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -72,9 +72,17 @@ while (( "$#" )); do --help) exit_with_usage ;; - *) + --*) + echo "Error: $1 is not supported" + exit_with_usage + ;; + -*) break ;; + *) + echo "Error: $1 is not supported" + exit_with_usage + ;; esac shift done From 83c008762af444eef73d835eb6f506ecf5aebc17 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 20 Feb 2018 09:14:56 -0800 Subject: [PATCH 0376/2461] [SPARK-23456][SPARK-21783] Turn on `native` ORC impl and PPD by default ## What changes were proposed in this pull request? Apache Spark 2.3 introduced `native` ORC supports with vectorization and many fixes. However, it's shipped as a not-default option. This PR enables `native` ORC implementation and predicate-pushdown by default for Apache Spark 2.4. We will improve and stabilize ORC data source before Apache Spark 2.4. And, eventually, Apache Spark will drop old Hive-based ORC code. ## How was this patch tested? Pass the Jenkins with existing tests. Author: Dongjoon Hyun Closes #20634 from dongjoon-hyun/SPARK-23456. --- docs/sql-programming-guide.md | 6 +++++- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 6 +++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 91e43678481d6..c37c338a134f3 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1018,7 +1018,7 @@ the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is also spark.sql.orc.impl hive - The name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4.1. `hive` means the ORC library in Hive 1.2.1. + The name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4. `hive` means the ORC library in Hive 1.2.1. spark.sql.orc.enableVectorizedReader @@ -1797,6 +1797,10 @@ working with timestamps in `pandas_udf`s to get the best performance, see # Migration Guide +## Upgrading From Spark SQL 2.3 to 2.4 + + - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. + ## Upgrading From Spark SQL 2.2 to 2.3 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e75e1d66ebcf8..ce3f94618edeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -399,11 +399,11 @@ object SQLConf { val ORC_IMPLEMENTATION = buildConf("spark.sql.orc.impl") .doc("When native, use the native version of ORC support instead of the ORC library in Hive " + - "1.2.1. It is 'hive' by default.") + "1.2.1. It is 'hive' by default prior to Spark 2.4.") .internal() .stringConf .checkValues(Set("hive", "native")) - .createWithDefault("hive") + .createWithDefault("native") val ORC_VECTORIZED_READER_ENABLED = buildConf("spark.sql.orc.enableVectorizedReader") .doc("Enables vectorized orc decoding.") @@ -426,7 +426,7 @@ object SQLConf { val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val HIVE_VERIFY_PARTITION_PATH = buildConf("spark.sql.hive.verifyPartitionPath") .doc("When true, check all the partition paths under the table\'s root directory " + From 3e48f3b9ee7645e4218ad3ff7559e578d4bd9667 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 20 Feb 2018 16:02:44 -0800 Subject: [PATCH 0377/2461] [SPARK-23434][SQL] Spark should not warn `metadata directory` for a HDFS file path ## What changes were proposed in this pull request? In a kerberized cluster, when Spark reads a file path (e.g. `people.json`), it warns with a wrong warning message during looking up `people.json/_spark_metadata`. The root cause of this situation is the difference between `LocalFileSystem` and `DistributedFileSystem`. `LocalFileSystem.exists()` returns `false`, but `DistributedFileSystem.exists` raises `org.apache.hadoop.security.AccessControlException`. ```scala scala> spark.version res0: String = 2.4.0-SNAPSHOT scala> spark.read.json("file:///usr/hdp/current/spark-client/examples/src/main/resources/people.json").show +----+-------+ | age| name| +----+-------+ |null|Michael| | 30| Andy| | 19| Justin| +----+-------+ scala> spark.read.json("hdfs:///tmp/people.json") 18/02/15 05:00:48 WARN streaming.FileStreamSink: Error while looking for metadata directory. 18/02/15 05:00:48 WARN streaming.FileStreamSink: Error while looking for metadata directory. ``` After this PR, ```scala scala> spark.read.json("hdfs:///tmp/people.json").show +----+-------+ | age| name| +----+-------+ |null|Michael| | 30| Andy| | 19| Justin| +----+-------+ ``` ## How was this patch tested? Manual. Author: Dongjoon Hyun Closes #20616 from dongjoon-hyun/SPARK-23434. --- .../spark/sql/execution/streaming/FileStreamSink.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 2715fa93d0e98..87a17cebdc10c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -42,9 +42,11 @@ object FileStreamSink extends Logging { try { val hdfsPath = new Path(singlePath) val fs = hdfsPath.getFileSystem(hadoopConf) - val metadataPath = new Path(hdfsPath, metadataDir) - val res = fs.exists(metadataPath) - res + if (fs.isDirectory(hdfsPath)) { + fs.exists(new Path(hdfsPath, metadataDir)) + } else { + false + } } catch { case NonFatal(e) => logWarning(s"Error while looking for metadata directory.") From 2ba77ed9e51922303e3c3533e368b95788bd7de5 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 20 Feb 2018 17:54:06 -0800 Subject: [PATCH 0378/2461] [SPARK-23470][UI] Use first attempt of last stage to define job description. This is much faster than finding out what the last attempt is, and the data should be the same. There's room for improvement in this page (like only loading data for the jobs being shown, instead of loading all available jobs and sorting them), but this should bring performance on par with the 2.2 version. Author: Marcelo Vanzin Closes #20644 from vanzin/SPARK-23470. --- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index a9265d4dbcdfb..ac83de10f9237 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -1048,7 +1048,7 @@ private[ui] object ApiHelper { } def lastStageNameAndDescription(store: AppStatusStore, job: JobData): (String, String) = { - val stage = store.asOption(store.lastStageAttempt(job.stageIds.max)) + val stage = store.asOption(store.stageAttempt(job.stageIds.max, 0)) (stage.map(_.name).getOrElse(""), stage.flatMap(_.description).getOrElse(job.name)) } From 6d398c05cbad69aa9093429e04ae44c73b81cd5a Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 20 Feb 2018 18:06:21 -0800 Subject: [PATCH 0379/2461] [SPARK-23468][CORE] Stringify auth secret before storing it in credentials. The secret is used as a string in many parts of the code, so it has to be turned into a hex string to avoid issues such as the random byte sequence not containing a valid UTF8 sequence. Author: Marcelo Vanzin Closes #20643 from vanzin/SPARK-23468. --- core/src/main/scala/org/apache/spark/SecurityManager.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 4c1dbe3ffb4ad..5b15a1c57779d 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -541,7 +541,8 @@ private[spark] class SecurityManager( rnd.nextBytes(secretBytes) val creds = new Credentials() - creds.addSecretKey(SECRET_LOOKUP_KEY, secretBytes) + val secretStr = HashCodes.fromBytes(secretBytes).toString() + creds.addSecretKey(SECRET_LOOKUP_KEY, secretStr.getBytes(UTF_8)) UserGroupInformation.getCurrentUser().addCredentials(creds) } From 601d653bff9160db8477f86d961e609fc2190237 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 20 Feb 2018 18:16:10 -0800 Subject: [PATCH 0380/2461] [SPARK-23454][SS][DOCS] Added trigger information to the Structured Streaming programming guide ## What changes were proposed in this pull request? - Added clear information about triggers - Made the semantics guarantees of watermarks more clear for streaming aggregations and stream-stream joins. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Tathagata Das Closes #20631 from tdas/SPARK-23454. --- .../structured-streaming-programming-guide.md | 214 +++++++++++++++++- 1 file changed, 207 insertions(+), 7 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 48d6d0b542cc0..9a83f157452ad 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -904,7 +904,7 @@ windowedCounts <- count(
    -### Handling Late Data and Watermarking +#### Handling Late Data and Watermarking Now consider what happens if one of the events arrives late to the application. For example, say, a word generated at 12:04 (i.e. event time) could be received by the application at 12:11. The application should use the time 12:04 instead of 12:11 @@ -925,7 +925,9 @@ specifying the event time column and the threshold on how late the data is expec event time. For a specific window starting at time `T`, the engine will maintain state and allow late data to update the state until `(max event time seen by the engine - late threshold > T)`. In other words, late data within the threshold will be aggregated, -but data later than the threshold will be dropped. Let's understand this with an example. We can +but data later than the threshold will start getting dropped +(see [later]((#semantic-guarantees-of-aggregation-with-watermarking)) +in the section for the exact guarantees). Let's understand this with an example. We can easily define watermarking on the previous example using `withWatermark()` as shown below.
    @@ -1031,7 +1033,9 @@ then drops intermediate state of a window < watermark, and appends the final counts to the Result Table/sink. For example, the final counts of window `12:00 - 12:10` is appended to the Result Table only after the watermark is updated to `12:11`. -**Conditions for watermarking to clean aggregation state** +##### Conditions for watermarking to clean aggregation state +{:.no_toc} + It is important to note that the following conditions must be satisfied for the watermarking to clean the state in aggregation queries *(as of Spark 2.1.1, subject to change in the future)*. @@ -1051,6 +1055,16 @@ from the aggregation column. For example, `df.groupBy("time").count().withWatermark("time", "1 min")` is invalid in Append output mode. +##### Semantic Guarantees of Aggregation with Watermarking +{:.no_toc} + +- A watermark delay (set with `withWatermark`) of "2 hours" guarantees that the engine will never +drop any data that is less than 2 hours delayed. In other words, any data less than 2 hours behind +(in terms of event-time) the latest data processed till then is guaranteed to be aggregated. + +- However, the guarantee is strict only in one direction. Data delayed by more than 2 hours is +not guaranteed to be dropped; it may or may not get aggregated. More delayed is the data, less +likely is the engine going to process it. ### Join Operations Structured Streaming supports joining a streaming Dataset/DataFrame with a static Dataset/DataFrame @@ -1062,7 +1076,7 @@ Dataset/DataFrame will be the exactly the same as if it was with a static Datase containing the same data in the stream. -#### Stream-static joins +#### Stream-static Joins Since the introduction in Spark 2.0, Structured Streaming has supported joins (inner join and some type of outer joins) between a streaming and a static DataFrame/Dataset. Here is a simple example. @@ -1269,6 +1283,12 @@ joined <- join(
+###### Semantic Guarantees of Stream-stream Inner Joins with Watermarking +{:.no_toc} +This is similar to the [guarantees provided by watermarking on aggregations](#semantic-guarantees-of-aggregation-with-watermarking). +A watermark delay of "2 hours" guarantees that the engine will never drop any data that is less than + 2 hours delayed. But data delayed by more than 2 hours may or may not get processed. + ##### Outer Joins with Watermarking While the watermark + event-time constraints is optional for inner joins, for left and right outer joins they must be specified. This is because for generating the NULL results in outer join, the @@ -1347,7 +1367,14 @@ joined <- join(
-There are a few points to note regarding outer joins. +###### Semantic Guarantees of Stream-stream Outer Joins with Watermarking +{:.no_toc} +Outer joins have the same guarantees as [inner joins](#semantic-guarantees-of-stream-stream-inner-joins-with-watermarking) +regarding watermark delays and whether data will be dropped or not. + +###### Caveats +{:.no_toc} +There are a few important characteristics to note regarding how the outer results are generated. - *The outer NULL results will be generated with a delay that depends on the specified watermark delay and the time range condition.* This is because the engine has to wait for that long to ensure @@ -1962,7 +1989,7 @@ head(sql("select * from aggregates")) -#### Using Foreach +##### Using Foreach The `foreach` operation allows arbitrary operations to be computed on the output data. As of Spark 2.1, this is available only for Scala and Java. To use this, you will have to implement the interface `ForeachWriter` ([Scala](api/scala/index.html#org.apache.spark.sql.ForeachWriter)/[Java](api/java/org/apache/spark/sql/ForeachWriter.html) docs), which has methods that get called whenever there is a sequence of rows generated as output after a trigger. Note the following important points. @@ -1979,6 +2006,172 @@ which has methods that get called whenever there is a sequence of rows generated - Whenever `open` is called, `close` will also be called (unless the JVM exits due to some error). This is true even if `open` returns false. If there is any error in processing and writing the data, `close` will be called with the error. It is your responsibility to clean up state (e.g. connections, transactions, etc.) that have been created in `open` such that there are no resource leaks. +#### Triggers +The trigger settings of a streaming query defines the timing of streaming data processing, whether +the query is going to executed as micro-batch query with a fixed batch interval or as a continuous processing query. +Here are the different kinds of triggers that are supported. + + + + + + + + + + + + + + + + + + + + + + +
Trigger TypeDescription
unspecified (default) + If no trigger setting is explicitly specified, then by default, the query will be + executed in micro-batch mode, where micro-batches will be generated as soon as + the previous micro-batch has completed processing. +
Fixed interval micro-batches + The query will be executed with micro-batches mode, where micro-batches will be kicked off + at the user-specified intervals. +
    +
  • If the previous micro-batch completes within the interval, then the engine will wait until + the interval is over before kicking off the next micro-batch.
  • + +
  • If the previous micro-batch takes longer than the interval to complete (i.e. if an + interval boundary is missed), then the next micro-batch will start as soon as the + previous one completes (i.e., it will not wait for the next interval boundary).
  • + +
  • If no new data is available, then no micro-batch will be kicked off.
  • +
+
One-time micro-batch + The query will execute *only one* micro-batch to process all the available data and then + stop on its own. This is useful in scenarios you want to periodically spin up a cluster, + process everything that is available since the last period, and then shutdown the + cluster. In some case, this may lead to significant cost savings. +
Continuous with fixed checkpoint interval
(experimental)
+ The query will be executed in the new low-latency, continuous processing mode. Read more + about this in the Continuous Processing section below. +
+ +Here are a few code examples. + +
+
+ +{% highlight scala %} +import org.apache.spark.sql.streaming.Trigger + +// Default trigger (runs micro-batch as soon as it can) +df.writeStream + .format("console") + .start() + +// ProcessingTime trigger with two-seconds micro-batch interval +df.writeStream + .format("console") + .trigger(Trigger.ProcessingTime("2 seconds")) + .start() + +// One-time trigger +df.writeStream + .format("console") + .trigger(Trigger.Once()) + .start() + +// Continuous trigger with one-second checkpointing interval +df.writeStream + .format("console") + .trigger(Trigger.Continuous("1 second")) + .start() + +{% endhighlight %} + + +
+
+ +{% highlight java %} +import org.apache.spark.sql.streaming.Trigger + +// Default trigger (runs micro-batch as soon as it can) +df.writeStream + .format("console") + .start(); + +// ProcessingTime trigger with two-seconds micro-batch interval +df.writeStream + .format("console") + .trigger(Trigger.ProcessingTime("2 seconds")) + .start(); + +// One-time trigger +df.writeStream + .format("console") + .trigger(Trigger.Once()) + .start(); + +// Continuous trigger with one-second checkpointing interval +df.writeStream + .format("console") + .trigger(Trigger.Continuous("1 second")) + .start(); + +{% endhighlight %} + +
+
+ +{% highlight python %} + +# Default trigger (runs micro-batch as soon as it can) +df.writeStream \ + .format("console") \ + .start() + +# ProcessingTime trigger with two-seconds micro-batch interval +df.writeStream \ + .format("console") \ + .trigger(processingTime='2 seconds') \ + .start() + +# One-time trigger +df.writeStream \ + .format("console") \ + .trigger(once=True) \ + .start() + +# Continuous trigger with one-second checkpointing interval +df.writeStream + .format("console") + .trigger(continuous='1 second') + .start() + +{% endhighlight %} +
+
+ +{% highlight r %} +# Default trigger (runs micro-batch as soon as it can) +write.stream(df, "console") + +# ProcessingTime trigger with two-seconds micro-batch interval +write.stream(df, "console", trigger.processingTime = "2 seconds") + +# One-time trigger +write.stream(df, "console", trigger.once = TRUE) + +# Continuous trigger is not yet supported +{% endhighlight %} +
+
+ + ## Managing Streaming Queries The `StreamingQuery` object created when a query is started can be used to monitor and manage the query. @@ -2516,7 +2709,10 @@ write.stream(aggDF, "memory", outputMode = "complete", checkpointLocation = "pat -# Continuous Processing [Experimental] +# Continuous Processing +## [Experimental] +{:.no_toc} + **Continuous processing** is a new, experimental streaming execution mode introduced in Spark 2.3 that enables low (~1 ms) end-to-end latency with at-least-once fault-tolerance guarantees. Compare this with the default *micro-batch processing* engine which can achieve exactly-once guarantees but achieve latencies of ~100ms at best. For some types of queries (discussed below), you can choose which mode to execute them in without modifying the application logic (i.e. without changing the DataFrame/Dataset operations). To run a supported query in continuous processing mode, all you need to do is specify a **continuous trigger** with the desired checkpoint interval as a parameter. For example, @@ -2589,6 +2785,8 @@ spark \ A checkpoint interval of 1 second means that the continuous processing engine will records the progress of the query every second. The resulting checkpoints are in a format compatible with the micro-batch engine, hence any query can be restarted with any trigger. For example, a supported query started with the micro-batch mode can be restarted in continuous mode, and vice versa. Note that any time you switch to continuous mode, you will get at-least-once fault-tolerance guarantees. ## Supported Queries +{:.no_toc} + As of Spark 2.3, only the following type of queries are supported in the continuous processing mode. - *Operations*: Only map-like Dataset/DataFrame operations are supported in continuous mode, that is, only projections (`select`, `map`, `flatMap`, `mapPartitions`, etc.) and selections (`where`, `filter`, etc.). @@ -2606,6 +2804,8 @@ As of Spark 2.3, only the following type of queries are supported in the continu See [Input Sources](#input-sources) and [Output Sinks](#output-sinks) sections for more details on them. While the console sink is good for testing, the end-to-end low-latency processing can be best observed with Kafka as the source and sink, as this allows the engine to process the data and make the results available in the output topic within milliseconds of the input data being available in the input topic. ## Caveats +{:.no_toc} + - Continuous processing engine launches multiple long-running tasks that continuously read data from sources, process it and continuously write to sinks. The number of tasks required by the query depends on how many partitions the query can read from the sources in parallel. Therefore, before starting a continuous processing query, you must ensure there are enough cores in the cluster to all the tasks in parallel. For example, if you are reading from a Kafka topic that has 10 partitions, then the cluster must have at least 10 cores for the query to make progress. - Stopping a continuous processing stream may produce spurious task termination warnings. These can be safely ignored. - There are currently no automatic retries of failed tasks. Any failure will lead to the query being stopped and it needs to be manually restarted from the checkpoint. From 95e25ed1a8b56937345eff637c0032aea85a503d Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 21 Feb 2018 11:26:06 +0800 Subject: [PATCH 0381/2461] [SPARK-23424][SQL] Add codegenStageId in comment ## What changes were proposed in this pull request? This PR always adds `codegenStageId` in comment of the generated class. This is a replication of #20419 for post-Spark 2.3. Closes #20419 ``` /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIteratorForCodegenStage1(references); /* 003 */ } /* 004 */ /* 005 */ // codegenStageId=1 /* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator { /* 007 */ private Object[] references; ... ``` ## How was this patch tested? Existing tests Author: Kazuaki Ishizaki Closes #20612 from kiszk/SPARK-23424. --- .../expressions/codegen/CodeGenerator.scala | 21 ++++++++++++++++--- .../sql/execution/WholeStageCodegenExec.scala | 4 +++- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 31ba29ae8d8ce..60a6f50472504 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1232,14 +1232,29 @@ class CodegenContext { /** * Register a comment and return the corresponding place holder + * + * @param placeholderId an optionally specified identifier for the comment's placeholder. + * The caller should make sure this identifier is unique within the + * compilation unit. If this argument is not specified, a fresh identifier + * will be automatically created and used as the placeholder. + * @param force whether to force registering the comments */ - def registerComment(text: => String): String = { + def registerComment( + text: => String, + placeholderId: String = "", + force: Boolean = false): String = { // By default, disable comments in generated code because computing the comments themselves can // be extremely expensive in certain cases, such as deeply-nested expressions which operate over // inputs with wide schemas. For more details on the performance issues that motivated this // flat, see SPARK-15680. - if (SparkEnv.get != null && SparkEnv.get.conf.getBoolean("spark.sql.codegen.comments", false)) { - val name = freshName("c") + if (force || + SparkEnv.get != null && SparkEnv.get.conf.getBoolean("spark.sql.codegen.comments", false)) { + val name = if (placeholderId != "") { + assert(!placeHolderToComments.contains(placeholderId)) + placeholderId + } else { + freshName("c") + } val comment = if (text.contains("\n") || text.contains("\r")) { text.split("(\r\n)|\r|\n").mkString("/**\n * ", "\n * ", "\n */") } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 0e525b1e22eb9..deb0a044c2fb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -540,7 +540,9 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) ${ctx.registerComment( s"""Codegend pipeline for stage (id=$codegenStageId) - |${this.treeString.trim}""".stripMargin)} + |${this.treeString.trim}""".stripMargin, + "wsc_codegenPipeline")} + ${ctx.registerComment(s"codegenStageId=$codegenStageId", "wsc_codegenStageId", true)} final class $className extends ${classOf[BufferedRowIterator].getName} { private Object[] references; From c8c4441dfdfeda22f8d92e25aee1b6a6269752f9 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 21 Feb 2018 15:10:08 +0800 Subject: [PATCH 0382/2461] [SPARK-23418][SQL] Fail DataSourceV2 reads when user schema is passed, but not supported. ## What changes were proposed in this pull request? DataSourceV2 initially allowed user-supplied schemas when a source doesn't implement `ReadSupportWithSchema`, as long as the schema was identical to the source's schema. This is confusing behavior because changes to an underlying table can cause a previously working job to fail with an exception that user-supplied schemas are not allowed. This reverts commit adcb25a0624, which was added to #20387 so that it could be removed in a separate JIRA issue and PR. ## How was this patch tested? Existing tests. Author: Ryan Blue Closes #20603 from rdblue/SPARK-23418-revert-adcb25a0624. --- .../datasources/v2/DataSourceV2Relation.scala | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index a98dd4866f82a..cc6cb631e3f06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -174,13 +174,6 @@ object DataSourceV2Relation { v2Options: DataSourceOptions, userSchema: Option[StructType]): StructType = { val reader = userSchema match { - // TODO: remove this case because it is confusing for users - case Some(s) if !source.isInstanceOf[ReadSupportWithSchema] => - val reader = source.asReadSupport.createReader(v2Options) - if (reader.readSchema() != s) { - throw new AnalysisException(s"${source.name} does not allow user-specified schemas.") - } - reader case Some(s) => source.asReadSupportWithSchema.createReader(s, v2Options) case _ => @@ -195,11 +188,7 @@ object DataSourceV2Relation { filters: Option[Seq[Expression]] = None, userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { val projection = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes - DataSourceV2Relation(source, options, projection, filters, - // if the source does not implement ReadSupportWithSchema, then the userSpecifiedSchema must - // be equal to the reader's schema. the schema method enforces this. because the user schema - // and the reader's schema are identical, drop the user schema. - if (source.isInstanceOf[ReadSupportWithSchema]) userSpecifiedSchema else None) + DataSourceV2Relation(source, options, projection, filters, userSpecifiedSchema) } private def pushRequiredColumns(reader: DataSourceReader, struct: StructType): Unit = { From e836c27ce011ca9aef822bef6320b4a7059ec343 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 21 Feb 2018 12:39:36 -0600 Subject: [PATCH 0383/2461] [SPARK-23217][ML][PYTHON] Add distanceMeasure param to ClusteringEvaluator Python API ## What changes were proposed in this pull request? The PR adds the `distanceMeasure` param to ClusteringEvaluator in the Python API. This allows the user to specify `cosine` as distance measure in addition to the default `squaredEuclidean`. ## How was this patch tested? added UT Author: Marco Gaido Closes #20627 from mgaido91/SPARK-23217_python. --- python/pyspark/ml/evaluation.py | 28 +++++++++++++++++++++++----- python/pyspark/ml/tests.py | 16 ++++++++++++++-- 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 0cbce9b40048f..695d8ab27cc96 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -362,18 +362,21 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, metricName = Param(Params._dummy(), "metricName", "metric name in evaluation (silhouette)", typeConverter=TypeConverters.toString) + distanceMeasure = Param(Params._dummy(), "distanceMeasure", "The distance measure. " + + "Supported options: 'squaredEuclidean' and 'cosine'.", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, predictionCol="prediction", featuresCol="features", - metricName="silhouette"): + metricName="silhouette", distanceMeasure="squaredEuclidean"): """ __init__(self, predictionCol="prediction", featuresCol="features", \ - metricName="silhouette") + metricName="silhouette", distanceMeasure="squaredEuclidean") """ super(ClusteringEvaluator, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.ClusteringEvaluator", self.uid) - self._setDefault(metricName="silhouette") + self._setDefault(metricName="silhouette", distanceMeasure="squaredEuclidean") kwargs = self._input_kwargs self._set(**kwargs) @@ -394,15 +397,30 @@ def getMetricName(self): @keyword_only @since("2.3.0") def setParams(self, predictionCol="prediction", featuresCol="features", - metricName="silhouette"): + metricName="silhouette", distanceMeasure="squaredEuclidean"): """ setParams(self, predictionCol="prediction", featuresCol="features", \ - metricName="silhouette") + metricName="silhouette", distanceMeasure="squaredEuclidean") Sets params for clustering evaluator. """ kwargs = self._input_kwargs return self._set(**kwargs) + @since("2.4.0") + def setDistanceMeasure(self, value): + """ + Sets the value of :py:attr:`distanceMeasure`. + """ + return self._set(distanceMeasure=value) + + @since("2.4.0") + def getDistanceMeasure(self): + """ + Gets the value of `distanceMeasure` + """ + return self.getOrDefault(self.distanceMeasure) + + if __name__ == "__main__": import doctest import tempfile diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6d6737241e06e..116885969345c 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -51,7 +51,7 @@ from pyspark.ml.classification import * from pyspark.ml.clustering import * from pyspark.ml.common import _java2py, _py2java -from pyspark.ml.evaluation import BinaryClassificationEvaluator, \ +from pyspark.ml.evaluation import BinaryClassificationEvaluator, ClusteringEvaluator, \ MulticlassClassificationEvaluator, RegressionEvaluator from pyspark.ml.feature import * from pyspark.ml.fpm import FPGrowth, FPGrowthModel @@ -541,6 +541,15 @@ def test_java_params(self): self.assertEqual(evaluator._java_obj.getMetricName(), "r2") self.assertEqual(evaluatorCopy._java_obj.getMetricName(), "mae") + def test_clustering_evaluator_with_cosine_distance(self): + featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]), + [([1.0, 1.0], 1.0), ([10.0, 10.0], 1.0), ([1.0, 0.5], 2.0), + ([10.0, 4.4], 2.0), ([-1.0, 1.0], 3.0), ([-100.0, 90.0], 3.0)]) + dataset = self.spark.createDataFrame(featureAndPredictions, ["features", "prediction"]) + evaluator = ClusteringEvaluator(predictionCol="prediction", distanceMeasure="cosine") + self.assertEqual(evaluator.getDistanceMeasure(), "cosine") + self.assertTrue(np.isclose(evaluator.evaluate(dataset), 0.992671213, atol=1e-5)) + class FeatureTests(SparkSessionTestCase): @@ -1961,11 +1970,14 @@ def test_java_params(self): import pyspark.ml.feature import pyspark.ml.classification import pyspark.ml.clustering + import pyspark.ml.evaluation import pyspark.ml.pipeline import pyspark.ml.recommendation import pyspark.ml.regression + modules = [pyspark.ml.feature, pyspark.ml.classification, pyspark.ml.clustering, - pyspark.ml.pipeline, pyspark.ml.recommendation, pyspark.ml.regression] + pyspark.ml.evaluation, pyspark.ml.pipeline, pyspark.ml.recommendation, + pyspark.ml.regression] for module in modules: for name, cls in inspect.getmembers(module, inspect.isclass): if not name.endswith('Model') and issubclass(cls, JavaParams)\ From 3fd0ccb13fea44727d970479af1682ef00592147 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 21 Feb 2018 14:56:13 -0800 Subject: [PATCH 0384/2461] [SPARK-23484][SS] Fix possible race condition in KafkaContinuousReader ## What changes were proposed in this pull request? var `KafkaContinuousReader.knownPartitions` should be threadsafe as it is accessed from multiple threads - the query thread at the time of reader factory creation, and the epoch tracking thread at the time of `needsReconfiguration`. ## How was this patch tested? Existing tests. Author: Tathagata Das Closes #20655 from tdas/SPARK-23484. --- .../org/apache/spark/sql/kafka010/KafkaContinuousReader.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 97a0f66e1880d..ecd1170321f3f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -66,7 +66,7 @@ class KafkaContinuousReader( // Initialized when creating reader factories. If this diverges from the partitions at the latest // offsets, we need to reconfigure. // Exposed outside this object only for unit tests. - private[sql] var knownPartitions: Set[TopicPartition] = _ + @volatile private[sql] var knownPartitions: Set[TopicPartition] = _ override def readSchema: StructType = KafkaOffsetReader.kafkaSchema From 744d5af652ee8cece361cbca31e5201134e0fb42 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 21 Feb 2018 15:37:28 -0800 Subject: [PATCH 0385/2461] [SPARK-23481][WEBUI] lastStageAttempt should fail when a stage doesn't exist ## What changes were proposed in this pull request? The issue here is `AppStatusStore.lastStageAttempt` will return the next available stage in the store when a stage doesn't exist. This PR adds `last(stageId)` to ensure it returns a correct `StageData` ## How was this patch tested? The new unit test. Author: Shixiong Zhu Closes #20654 from zsxwing/SPARK-23481. --- .../apache/spark/status/AppStatusStore.scala | 6 +++- .../spark/status/AppStatusListenerSuite.scala | 33 +++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index efc28538a33db..688f25a9fdea1 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -95,7 +95,11 @@ private[spark] class AppStatusStore( } def lastStageAttempt(stageId: Int): v1.StageData = { - val it = store.view(classOf[StageDataWrapper]).index("stageId").reverse().first(stageId) + val it = store.view(classOf[StageDataWrapper]) + .index("stageId") + .reverse() + .first(stageId) + .last(stageId) .closeableIterator() try { if (it.hasNext()) { diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 749502709b5c8..673d191b5a4db 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -1121,6 +1121,39 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } + test("lastStageAttempt should fail when the stage doesn't exist") { + val testConf = conf.clone().set(MAX_RETAINED_STAGES, 1) + val listener = new AppStatusListener(store, testConf, true) + val appStore = new AppStatusStore(store) + + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") + val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2") + val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3") + + time += 1 + stage1.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties())) + stage1.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage1)) + + // Make stage 3 complete before stage 2 so that stage 3 will be evicted + time += 1 + stage3.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage3, new Properties())) + stage3.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage3)) + + time += 1 + stage2.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage2, new Properties())) + stage2.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage2)) + + assert(appStore.asOption(appStore.lastStageAttempt(1)) === None) + assert(appStore.asOption(appStore.lastStageAttempt(2)).map(_.stageId) === Some(2)) + assert(appStore.asOption(appStore.lastStageAttempt(3)) === None) + } + test("driver logs") { val listener = new AppStatusListener(store, conf, true) From 45cf714ee6d4eead2fe00794a0d754fa6d33d4a6 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 21 Feb 2018 19:43:11 -0800 Subject: [PATCH 0386/2461] [SPARK-23475][WEBUI] Skipped stages should be evicted before completed stages ## What changes were proposed in this pull request? The root cause of missing completed stages is because `cleanupStages` will never remove skipped stages. This PR changes the logic to always remove skipped stage first. This is safe since the job itself contains enough information to render skipped stages in the UI. ## How was this patch tested? The new unit tests. Author: Shixiong Zhu Closes #20656 from zsxwing/SPARK-23475. --- .../spark/status/AppStatusListener.scala | 5 ++- .../spark/status/AppStatusListenerSuite.scala | 36 +++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 79a17e26665fd..5ea161cd0d151 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -915,7 +915,10 @@ private[spark] class AppStatusListener( return } - val view = kvstore.view(classOf[StageDataWrapper]).index("completionTime").first(0L) + // As the completion time of a skipped stage is always -1, we will remove skipped stages first. + // This is safe since the job itself contains enough information to render skipped stages in the + // UI. + val view = kvstore.view(classOf[StageDataWrapper]).index("completionTime") val stages = KVUtils.viewToSeq(view, countToDelete.toInt) { s => s.info.status != v1.StageStatus.ACTIVE && s.info.status != v1.StageStatus.PENDING } diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 673d191b5a4db..1cd71955ad4d9 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -1089,6 +1089,42 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } + test("skipped stages should be evicted before completed stages") { + val testConf = conf.clone().set(MAX_RETAINED_STAGES, 2) + val listener = new AppStatusListener(store, testConf, true) + + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") + val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2") + + // Sart job 1 + time += 1 + listener.onJobStart(SparkListenerJobStart(1, time, Seq(stage1, stage2), null)) + + // Start and stop stage 1 + time += 1 + stage1.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties())) + + time += 1 + stage1.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage1)) + + // Stop job 1 and stage 2 will become SKIPPED + time += 1 + listener.onJobEnd(SparkListenerJobEnd(1, time, JobSucceeded)) + + // Submit stage 3 and verify stage 2 is evicted + val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3") + time += 1 + stage3.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage3, new Properties())) + + assert(store.count(classOf[StageDataWrapper]) === 2) + intercept[NoSuchElementException] { + store.read(classOf[StageDataWrapper], Array(2, 0)) + } + } + test("eviction should respect task completion time") { val testConf = conf.clone().set(MAX_RETAINED_TASKS_PER_STAGE, 2) val listener = new AppStatusListener(store, testConf, true) From 87293c746e19d66f475d506d0adb43421f496843 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 22 Feb 2018 11:00:12 -0800 Subject: [PATCH 0387/2461] [SPARK-23475][UI] Show also skipped stages ## What changes were proposed in this pull request? SPARK-20648 introduced the status `SKIPPED` for the stages. On the UI, previously, skipped stages were shown as `PENDING`; after this change, they are not shown on the UI. The PR introduce a new section in order to show also `SKIPPED` stages in a proper table. ## How was this patch tested? manual tests Author: Marco Gaido Closes #20651 from mgaido91/SPARK-23475. --- .../org/apache/spark/ui/static/webui.js | 1 + .../apache/spark/ui/jobs/AllStagesPage.scala | 27 +++++++++++++++++++ .../org/apache/spark/ui/UISeleniumSuite.scala | 17 ++++++++++++ 3 files changed, 45 insertions(+) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index 83009df91d30a..f01c567ba58ad 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -72,6 +72,7 @@ $(function() { collapseTablePageLoad('collapse-aggregated-allActiveStages','aggregated-allActiveStages'); collapseTablePageLoad('collapse-aggregated-allPendingStages','aggregated-allPendingStages'); collapseTablePageLoad('collapse-aggregated-allCompletedStages','aggregated-allCompletedStages'); + collapseTablePageLoad('collapse-aggregated-allSkippedStages','aggregated-allSkippedStages'); collapseTablePageLoad('collapse-aggregated-allFailedStages','aggregated-allFailedStages'); collapseTablePageLoad('collapse-aggregated-activeStages','aggregated-activeStages'); collapseTablePageLoad('collapse-aggregated-pendingOrSkippedStages','aggregated-pendingOrSkippedStages'); diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index 606dc1e180e5b..38450b9126ff0 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -36,6 +36,7 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { val activeStages = allStages.filter(_.status == StageStatus.ACTIVE) val pendingStages = allStages.filter(_.status == StageStatus.PENDING) + val skippedStages = allStages.filter(_.status == StageStatus.SKIPPED) val completedStages = allStages.filter(_.status == StageStatus.COMPLETE) val failedStages = allStages.filter(_.status == StageStatus.FAILED).reverse @@ -51,6 +52,9 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { val completedStagesTable = new StageTableBase(parent.store, request, completedStages, "completed", "completedStage", parent.basePath, subPath, parent.isFairScheduler, false, false) + val skippedStagesTable = + new StageTableBase(parent.store, request, skippedStages, "skipped", "skippedStage", + parent.basePath, subPath, parent.isFairScheduler, false, false) val failedStagesTable = new StageTableBase(parent.store, request, failedStages, "failed", "failedStage", parent.basePath, subPath, parent.isFairScheduler, false, true) @@ -66,6 +70,7 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { val shouldShowActiveStages = activeStages.nonEmpty val shouldShowPendingStages = pendingStages.nonEmpty val shouldShowCompletedStages = completedStages.nonEmpty + val shouldShowSkippedStages = skippedStages.nonEmpty val shouldShowFailedStages = failedStages.nonEmpty val appSummary = parent.store.appSummary() @@ -102,6 +107,14 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { } } + { + if (shouldShowSkippedStages) { +
  • + Skipped Stages: + {skippedStages.size} +
  • + } + } { if (shouldShowFailedStages) {
  • @@ -172,6 +185,20 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { {completedStagesTable.toNodeSeq} } + if (shouldShowSkippedStages) { + content ++= + +

    + + Skipped Stages ({skippedStages.size}) +

    +
    ++ +
    + {skippedStagesTable.toNodeSeq} +
    + } if (shouldShowFailedStages) { content ++= + val rdd = sc.parallelize(0 to 100, 100).repartition(10).cache() + rdd.count() + rdd.count() + + eventually(timeout(5 seconds), interval(50 milliseconds)) { + goToUi(sc, "/stages") + find(id("skipped")).get.text should be("Skipped Stages (1)") + } + val stagesJson = getJson(sc.ui.get, "stages") + stagesJson.children.size should be (4) + val stagesStatus = stagesJson.children.map(_ \ "status") + stagesStatus.count(_ == JString(StageStatus.SKIPPED.name())) should be (1) + } + } + def goToUi(sc: SparkContext, path: String): Unit = { goToUi(sc.ui.get, path) } From c5abb3c2d16f601d507bee3c53663d4e117eb8b5 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Thu, 22 Feb 2018 12:07:51 -0800 Subject: [PATCH 0388/2461] [SPARK-23476][CORE] Generate secret in local mode when authentication on ## What changes were proposed in this pull request? If spark is run with "spark.authenticate=true", then it will fail to start in local mode. This PR generates secret in local mode when authentication on. ## How was this patch tested? Modified existing unit test. Manually started spark-shell. Author: Gabor Somogyi Closes #20652 from gaborgsomogyi/SPARK-23476. --- .../org/apache/spark/SecurityManager.scala | 16 ++++-- .../apache/spark/SecurityManagerSuite.scala | 50 +++++++++++++------ docs/security.md | 2 +- 3 files changed, 46 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 5b15a1c57779d..2519d266879aa 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -520,19 +520,25 @@ private[spark] class SecurityManager( * * If authentication is disabled, do nothing. * - * In YARN mode, generate a new secret and store it in the current user's credentials. + * In YARN and local mode, generate a new secret and store it in the current user's credentials. * * In other modes, assert that the auth secret is set in the configuration. */ def initializeAuth(): Unit = { + import SparkMasterRegex._ + if (!sparkConf.get(NETWORK_AUTH_ENABLED)) { return } - if (sparkConf.get(SparkLauncher.SPARK_MASTER, null) != "yarn") { - require(sparkConf.contains(SPARK_AUTH_SECRET_CONF), - s"A secret key must be specified via the $SPARK_AUTH_SECRET_CONF config.") - return + val master = sparkConf.get(SparkLauncher.SPARK_MASTER, "") + master match { + case "yarn" | "local" | LOCAL_N_REGEX(_) | LOCAL_N_FAILURES_REGEX(_, _) => + // Secret generation allowed here + case _ => + require(sparkConf.contains(SPARK_AUTH_SECRET_CONF), + s"A secret key must be specified via the $SPARK_AUTH_SECRET_CONF config.") + return } val rnd = new SecureRandom() diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index cf59265dd646d..106ece7aed0a4 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -440,23 +440,41 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { assert(keyFromEnv === new SecurityManager(conf2).getSecretKey()) } - test("secret key generation in yarn mode") { - val conf = new SparkConf() - .set(NETWORK_AUTH_ENABLED, true) - .set(SparkLauncher.SPARK_MASTER, "yarn") - val mgr = new SecurityManager(conf) - - UserGroupInformation.createUserForTesting("authTest", Array()).doAs( - new PrivilegedExceptionAction[Unit]() { - override def run(): Unit = { - mgr.initializeAuth() - val creds = UserGroupInformation.getCurrentUser().getCredentials() - val secret = creds.getSecretKey(SecurityManager.SECRET_LOOKUP_KEY) - assert(secret != null) - assert(new String(secret, UTF_8) === mgr.getSecretKey()) + test("secret key generation") { + Seq( + ("yarn", true), + ("local", true), + ("local[*]", true), + ("local[1, 2]", true), + ("local-cluster[2, 1, 1024]", false), + ("invalid", false) + ).foreach { case (master, shouldGenerateSecret) => + val conf = new SparkConf() + .set(NETWORK_AUTH_ENABLED, true) + .set(SparkLauncher.SPARK_MASTER, master) + val mgr = new SecurityManager(conf) + + UserGroupInformation.createUserForTesting("authTest", Array()).doAs( + new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + if (shouldGenerateSecret) { + mgr.initializeAuth() + val creds = UserGroupInformation.getCurrentUser().getCredentials() + val secret = creds.getSecretKey(SecurityManager.SECRET_LOOKUP_KEY) + assert(secret != null) + assert(new String(secret, UTF_8) === mgr.getSecretKey()) + } else { + intercept[IllegalArgumentException] { + mgr.initializeAuth() + } + intercept[IllegalArgumentException] { + mgr.getSecretKey() + } + } + } } - } - ) + ) + } } } diff --git a/docs/security.md b/docs/security.md index bebc28ddbfb0e..0f384b411812a 100644 --- a/docs/security.md +++ b/docs/security.md @@ -6,7 +6,7 @@ title: Security Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the `spark.authenticate` configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate. The shared secret is created as follows: -* For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. +* For Spark on [YARN](running-on-yarn.html) and local deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. * For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. ## Web UI From 049f243c59737699fee54fdc9d65cbd7c788032a Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 22 Feb 2018 21:49:25 -0800 Subject: [PATCH 0389/2461] [SPARK-23490][SQL] Check storage.locationUri with existing table in CreateTable ## What changes were proposed in this pull request? For CreateTable with Append mode, we should check if `storage.locationUri` is the same with existing table in `PreprocessTableCreation` In the current code, there is only a simple exception if the `storage.locationUri` is different with existing table: `org.apache.spark.sql.AnalysisException: Table or view not found:` which can be improved. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #20660 from gengliangwang/locationUri. --- .../sql/execution/datasources/rules.scala | 8 +++++ .../sql/execution/command/DDLSuite.scala | 29 +++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 5cc21eeaeaa94..0dea767840ed3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -118,6 +118,14 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi s"`${existingProvider.getSimpleName}`. It doesn't match the specified format " + s"`${specifiedProvider.getSimpleName}`.") } + tableDesc.storage.locationUri match { + case Some(location) if location.getPath != existingTable.location.getPath => + throw new AnalysisException( + s"The location of the existing table ${tableIdentWithDB.quotedString} is " + + s"`${existingTable.location}`. It doesn't match the specified location " + + s"`${tableDesc.location}`.") + case _ => + } if (query.schema.length != existingTable.schema.length) { throw new AnalysisException( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index f76bfd2fda2b9..b800e6ff5b0ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -536,6 +536,35 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("create table - append to a non-partitioned table created with different paths") { + import testImplicits._ + withTempDir { dir1 => + withTempDir { dir2 => + withTable("path_test") { + Seq(1L -> "a").toDF("v1", "v2") + .write + .mode(SaveMode.Append) + .format("json") + .option("path", dir1.getCanonicalPath) + .saveAsTable("path_test") + + val ex = intercept[AnalysisException] { + Seq((3L, "c")).toDF("v1", "v2") + .write + .mode(SaveMode.Append) + .format("json") + .option("path", dir2.getCanonicalPath) + .saveAsTable("path_test") + }.getMessage + assert(ex.contains("The location of the existing table `default`.`path_test`")) + + checkAnswer( + spark.table("path_test"), Row(1L, "a") :: Nil) + } + } + } + } + test("Refresh table after changing the data source table partitioning") { import testImplicits._ From 855ce13d045569b7b16fdc7eee9c981f4ff3a545 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 23 Feb 2018 12:40:58 -0800 Subject: [PATCH 0390/2461] [SPARK-23408][SS] Synchronize successive AddData actions in Streaming*JoinSuite **The best way to review this PR is to ignore whitespace/indent changes. Use this link - https://github.com/apache/spark/pull/20650/files?w=1** ## What changes were proposed in this pull request? The stream-stream join tests add data to multiple sources and expect it all to show up in the next batch. But there's a race condition; the new batch might trigger when only one of the AddData actions has been reached. Prior attempt to solve this issue by jose-torres in #20646 attempted to simultaneously synchronize on all memory sources together when consecutive AddData was found in the actions. However, this carries the risk of deadlock as well as unintended modification of stress tests (see the above PR for a detailed explanation). Instead, this PR attempts the following. - A new action called `StreamProgressBlockedActions` that allows multiple actions to be executed while the streaming query is blocked from making progress. This allows data to be added to multiple sources that are made visible simultaneously in the next batch. - An alias of `StreamProgressBlockedActions` called `MultiAddData` is explicitly used in the `Streaming*JoinSuites` to add data to two memory sources simultaneously. This should avoid unintentional modification of the stress tests (or any other test for that matter) while making sure that the flaky tests are deterministic. ## How was this patch tested? Modified test cases in `Streaming*JoinSuites` where there are consecutive `AddData` actions. Author: Tathagata Das Closes #20650 from tdas/SPARK-23408. --- .../streaming/MicroBatchExecution.scala | 10 + .../spark/sql/streaming/StreamTest.scala | 472 ++++++++++-------- .../sql/streaming/StreamingJoinSuite.scala | 54 +- 3 files changed, 284 insertions(+), 252 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 84655013ba957..6bd03972c301d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -504,6 +504,16 @@ class MicroBatchExecution( } } + /** Execute a function while locking the stream from making an progress */ + private[sql] def withProgressLocked(f: => Unit): Unit = { + awaitProgressLock.lock() + try { + f + } finally { + awaitProgressLock.unlock() + } + } + private def toJava(scalaOption: Option[OffsetV2]): Optional[OffsetV2] = { Optional.ofNullable(scalaOption.orNull) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 159dd0ecb5902..08f722ecb10e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -102,6 +102,19 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be AddDataMemory(source, data) } + /** + * Adds data to multiple memory streams such that all the data will be made visible in the + * same batch. This is applicable only to MicroBatchExecution, as this coordination cannot be + * performed at the driver in ContinuousExecutions. + */ + object MultiAddData { + def apply[A] + (source1: MemoryStream[A], data1: A*)(source2: MemoryStream[A], data2: A*): StreamAction = { + val actions = Seq(AddDataMemory(source1, data1), AddDataMemory(source2, data2)) + StreamProgressLockedActions(actions, desc = actions.mkString("[ ", " | ", " ]")) + } + } + /** A trait that can be extended when testing a source. */ trait AddData extends StreamAction { /** @@ -217,6 +230,19 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be s"ExpectFailure[${causeClass.getName}, isFatalError: $isFatalError]" } + /** + * Performs multiple actions while locking the stream from progressing. + * This is applicable only to MicroBatchExecution, as progress of ContinuousExecution + * cannot be controlled from the driver. + */ + case class StreamProgressLockedActions(actions: Seq[StreamAction], desc: String = null) + extends StreamAction { + + override def toString(): String = { + if (desc != null) desc else super.toString + } + } + /** Assert that a body is true */ class Assert(condition: => Boolean, val message: String = "") extends StreamAction { def run(): Unit = { Assertions.assert(condition) } @@ -295,6 +321,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode) val resetConfValues = mutable.Map[String, Option[String]]() + val defaultCheckpointLocation = + Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath + var manualClockExpectedTime = -1L @volatile var streamThreadDeathCause: Throwable = null @@ -425,243 +454,254 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } } - var manualClockExpectedTime = -1L - val defaultCheckpointLocation = - Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath - try { - startedTest.foreach { action => - logInfo(s"Processing test stream action: $action") - action match { - case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) => - verify(currentStream == null, "stream already running") - verify(triggerClock.isInstanceOf[SystemClock] - || triggerClock.isInstanceOf[StreamManualClock], - "Use either SystemClock or StreamManualClock to start the stream") - if (triggerClock.isInstanceOf[StreamManualClock]) { - manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis() + def executeAction(action: StreamAction): Unit = { + logInfo(s"Processing test stream action: $action") + action match { + case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) => + verify(currentStream == null, "stream already running") + verify(triggerClock.isInstanceOf[SystemClock] + || triggerClock.isInstanceOf[StreamManualClock], + "Use either SystemClock or StreamManualClock to start the stream") + if (triggerClock.isInstanceOf[StreamManualClock]) { + manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis() + } + val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation) + + additionalConfs.foreach(pair => { + val value = + if (sparkSession.conf.contains(pair._1)) { + Some(sparkSession.conf.get(pair._1)) + } else None + resetConfValues(pair._1) = value + sparkSession.conf.set(pair._1, pair._2) + }) + + lastStream = currentStream + currentStream = + sparkSession + .streams + .startQuery( + None, + Some(metadataRoot), + stream, + Map(), + sink, + outputMode, + trigger = trigger, + triggerClock = triggerClock) + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery + // Wait until the initialization finishes, because some tests need to use `logicalPlan` + // after starting the query. + try { + currentStream.awaitInitialization(streamingTimeout.toMillis) + currentStream match { + case s: ContinuousExecution => eventually("IncrementalExecution was not created") { + assert(s.lastExecution != null) + } + case _ => } - val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation) + } catch { + case _: StreamingQueryException => + // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well. + } - additionalConfs.foreach(pair => { - val value = - if (sparkSession.conf.contains(pair._1)) { - Some(sparkSession.conf.get(pair._1)) - } else None - resetConfValues(pair._1) = value - sparkSession.conf.set(pair._1, pair._2) - }) + case AdvanceManualClock(timeToAdd) => + verify(currentStream != null, + "can not advance manual clock when a stream is not running") + verify(currentStream.triggerClock.isInstanceOf[StreamManualClock], + s"can not advance clock of type ${currentStream.triggerClock.getClass}") + val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] + assert(manualClockExpectedTime >= 0) + + // Make sure we don't advance ManualClock too early. See SPARK-16002. + eventually("StreamManualClock has not yet entered the waiting state") { + assert(clock.isStreamWaitingAt(manualClockExpectedTime)) + } + clock.advance(timeToAdd) + manualClockExpectedTime += timeToAdd + verify(clock.getTimeMillis() === manualClockExpectedTime, + s"Unexpected clock time after updating: " + + s"expecting $manualClockExpectedTime, current ${clock.getTimeMillis()}") + + case StopStream => + verify(currentStream != null, "can not stop a stream that is not running") + try failAfter(streamingTimeout) { + currentStream.stop() + verify(!currentStream.queryExecutionThread.isAlive, + s"microbatch thread not stopped") + verify(!currentStream.isActive, + "query.isActive() is false even after stopping") + verify(currentStream.exception.isEmpty, + s"query.exception() is not empty after clean stop: " + + currentStream.exception.map(_.toString()).getOrElse("")) + } catch { + case _: InterruptedException => + case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => + failTest( + "Timed out while stopping and waiting for microbatchthread to terminate.", e) + case t: Throwable => + failTest("Error while stopping stream", t) + } finally { lastStream = currentStream - currentStream = - sparkSession - .streams - .startQuery( - None, - Some(metadataRoot), - stream, - Map(), - sink, - outputMode, - trigger = trigger, - triggerClock = triggerClock) - .asInstanceOf[StreamingQueryWrapper] - .streamingQuery - // Wait until the initialization finishes, because some tests need to use `logicalPlan` - // after starting the query. - try { - currentStream.awaitInitialization(streamingTimeout.toMillis) - currentStream match { - case s: ContinuousExecution => eventually("IncrementalExecution was not created") { - assert(s.lastExecution != null) - } - case _ => - } - } catch { - case _: StreamingQueryException => - // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well. - } + currentStream = null + } - case AdvanceManualClock(timeToAdd) => - verify(currentStream != null, - "can not advance manual clock when a stream is not running") - verify(currentStream.triggerClock.isInstanceOf[StreamManualClock], - s"can not advance clock of type ${currentStream.triggerClock.getClass}") - val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] - assert(manualClockExpectedTime >= 0) - - // Make sure we don't advance ManualClock too early. See SPARK-16002. - eventually("StreamManualClock has not yet entered the waiting state") { - assert(clock.isStreamWaitingAt(manualClockExpectedTime)) + case ef: ExpectFailure[_] => + verify(currentStream != null, "can not expect failure when stream is not running") + try failAfter(streamingTimeout) { + val thrownException = intercept[StreamingQueryException] { + currentStream.awaitTermination() } - - clock.advance(timeToAdd) - manualClockExpectedTime += timeToAdd - verify(clock.getTimeMillis() === manualClockExpectedTime, - s"Unexpected clock time after updating: " + - s"expecting $manualClockExpectedTime, current ${clock.getTimeMillis()}") - - case StopStream => - verify(currentStream != null, "can not stop a stream that is not running") - try failAfter(streamingTimeout) { - currentStream.stop() - verify(!currentStream.queryExecutionThread.isAlive, - s"microbatch thread not stopped") - verify(!currentStream.isActive, - "query.isActive() is false even after stopping") - verify(currentStream.exception.isEmpty, - s"query.exception() is not empty after clean stop: " + - currentStream.exception.map(_.toString()).getOrElse("")) - } catch { - case _: InterruptedException => - case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => - failTest( - "Timed out while stopping and waiting for microbatchthread to terminate.", e) - case t: Throwable => - failTest("Error while stopping stream", t) - } finally { - lastStream = currentStream - currentStream = null + eventually("microbatch thread not stopped after termination with failure") { + assert(!currentStream.queryExecutionThread.isAlive) } + verify(currentStream.exception === Some(thrownException), + s"incorrect exception returned by query.exception()") + + val exception = currentStream.exception.get + verify(exception.cause.getClass === ef.causeClass, + "incorrect cause in exception returned by query.exception()\n" + + s"\tExpected: ${ef.causeClass}\n\tReturned: ${exception.cause.getClass}") + if (ef.isFatalError) { + // This is a fatal error, `streamThreadDeathCause` should be set to this error in + // UncaughtExceptionHandler. + verify(streamThreadDeathCause != null && + streamThreadDeathCause.getClass === ef.causeClass, + "UncaughtExceptionHandler didn't receive the correct error\n" + + s"\tExpected: ${ef.causeClass}\n\tReturned: $streamThreadDeathCause") + streamThreadDeathCause = null + } + ef.assertFailure(exception.getCause) + } catch { + case _: InterruptedException => + case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => + failTest("Timed out while waiting for failure", e) + case t: Throwable => + failTest("Error while checking stream failure", t) + } finally { + lastStream = currentStream + currentStream = null + } - case ef: ExpectFailure[_] => - verify(currentStream != null, "can not expect failure when stream is not running") - try failAfter(streamingTimeout) { - val thrownException = intercept[StreamingQueryException] { - currentStream.awaitTermination() - } - eventually("microbatch thread not stopped after termination with failure") { - assert(!currentStream.queryExecutionThread.isAlive) + case a: AssertOnQuery => + verify(currentStream != null || lastStream != null, + "cannot assert when no stream has been started") + val streamToAssert = Option(currentStream).getOrElse(lastStream) + try { + verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") + } catch { + case NonFatal(e) => + failTest(s"Assert on query failed: ${a.message}", e) + } + + case a: Assert => + val streamToAssert = Option(currentStream).getOrElse(lastStream) + verify({ a.run(); true }, s"Assert failed: ${a.message}") + + case a: AddData => + try { + + // If the query is running with manual clock, then wait for the stream execution + // thread to start waiting for the clock to increment. This is needed so that we + // are adding data when there is no trigger that is active. This would ensure that + // the data gets deterministically added to the next batch triggered after the manual + // clock is incremented in following AdvanceManualClock. This avoid race conditions + // between the test thread and the stream execution thread in tests using manual + // clock. + if (currentStream != null && + currentStream.triggerClock.isInstanceOf[StreamManualClock]) { + val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] + eventually("Error while synchronizing with manual clock before adding data") { + if (currentStream.isActive) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } } - verify(currentStream.exception === Some(thrownException), - s"incorrect exception returned by query.exception()") - - val exception = currentStream.exception.get - verify(exception.cause.getClass === ef.causeClass, - "incorrect cause in exception returned by query.exception()\n" + - s"\tExpected: ${ef.causeClass}\n\tReturned: ${exception.cause.getClass}") - if (ef.isFatalError) { - // This is a fatal error, `streamThreadDeathCause` should be set to this error in - // UncaughtExceptionHandler. - verify(streamThreadDeathCause != null && - streamThreadDeathCause.getClass === ef.causeClass, - "UncaughtExceptionHandler didn't receive the correct error\n" + - s"\tExpected: ${ef.causeClass}\n\tReturned: $streamThreadDeathCause") - streamThreadDeathCause = null + if (!currentStream.isActive) { + failTest("Query terminated while synchronizing with manual clock") } - ef.assertFailure(exception.getCause) - } catch { - case _: InterruptedException => - case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => - failTest("Timed out while waiting for failure", e) - case t: Throwable => - failTest("Error while checking stream failure", t) - } finally { - lastStream = currentStream - currentStream = null } - - case a: AssertOnQuery => - verify(currentStream != null || lastStream != null, - "cannot assert when no stream has been started") - val streamToAssert = Option(currentStream).getOrElse(lastStream) - try { - verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") - } catch { - case NonFatal(e) => - failTest(s"Assert on query failed: ${a.message}", e) + // Add data + val queryToUse = Option(currentStream).orElse(Option(lastStream)) + val (source, offset) = a.addData(queryToUse) + + def findSourceIndex(plan: LogicalPlan): Option[Int] = { + plan + .collect { + case StreamingExecutionRelation(s, _) => s + case StreamingDataSourceV2Relation(_, r) => r + } + .zipWithIndex + .find(_._1 == source) + .map(_._2) } - case a: Assert => - val streamToAssert = Option(currentStream).getOrElse(lastStream) - verify({ a.run(); true }, s"Assert failed: ${a.message}") - - case a: AddData => - try { - - // If the query is running with manual clock, then wait for the stream execution - // thread to start waiting for the clock to increment. This is needed so that we - // are adding data when there is no trigger that is active. This would ensure that - // the data gets deterministically added to the next batch triggered after the manual - // clock is incremented in following AdvanceManualClock. This avoid race conditions - // between the test thread and the stream execution thread in tests using manual - // clock. - if (currentStream != null && - currentStream.triggerClock.isInstanceOf[StreamManualClock]) { - val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] - eventually("Error while synchronizing with manual clock before adding data") { - if (currentStream.isActive) { - assert(clock.isStreamWaitingAt(clock.getTimeMillis())) - } + // Try to find the index of the source to which data was added. Either get the index + // from the current active query or the original input logical plan. + val sourceIndex = + queryToUse.flatMap { query => + findSourceIndex(query.logicalPlan) + }.orElse { + findSourceIndex(stream.logicalPlan) + }.orElse { + queryToUse.flatMap { q => + findSourceIndex(q.lastExecution.logical) } - if (!currentStream.isActive) { - failTest("Query terminated while synchronizing with manual clock") - } - } - // Add data - val queryToUse = Option(currentStream).orElse(Option(lastStream)) - val (source, offset) = a.addData(queryToUse) - - def findSourceIndex(plan: LogicalPlan): Option[Int] = { - plan - .collect { - case StreamingExecutionRelation(s, _) => s - case StreamingDataSourceV2Relation(_, r) => r - } - .zipWithIndex - .find(_._1 == source) - .map(_._2) + }.getOrElse { + throw new IllegalArgumentException( + "Could not find index of the source to which data was added") } - // Try to find the index of the source to which data was added. Either get the index - // from the current active query or the original input logical plan. - val sourceIndex = - queryToUse.flatMap { query => - findSourceIndex(query.logicalPlan) - }.orElse { - findSourceIndex(stream.logicalPlan) - }.orElse { - queryToUse.flatMap { q => - findSourceIndex(q.lastExecution.logical) - } - }.getOrElse { - throw new IllegalArgumentException( - "Could not find index of the source to which data was added") - } + // Store the expected offset of added data to wait for it later + awaiting.put(sourceIndex, offset) + } catch { + case NonFatal(e) => + failTest("Error adding data", e) + } - // Store the expected offset of added data to wait for it later - awaiting.put(sourceIndex, offset) - } catch { - case NonFatal(e) => - failTest("Error adding data", e) - } + case e: ExternalAction => + e.runAction() - case e: ExternalAction => - e.runAction() + case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) => + val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) + QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach { + error => failTest(error) + } - case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) => - val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) - QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach { - error => failTest(error) - } + case CheckAnswerRowsContains(expectedAnswer, lastOnly) => + val sparkAnswer = currentStream match { + case null => fetchStreamAnswer(lastStream, lastOnly) + case s => fetchStreamAnswer(s, lastOnly) + } + QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach { + error => failTest(error) + } - case CheckAnswerRowsContains(expectedAnswer, lastOnly) => - val sparkAnswer = currentStream match { - case null => fetchStreamAnswer(lastStream, lastOnly) - case s => fetchStreamAnswer(s, lastOnly) - } - QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach { - error => failTest(error) - } + case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) => + val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) + try { + globalCheckFunction(sparkAnswer) + } catch { + case e: Throwable => failTest(e.toString) + } + } + pos += 1 + } - case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) => - val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) - try { - globalCheckFunction(sparkAnswer) - } catch { - case e: Throwable => failTest(e.toString) - } - } - pos += 1 + try { + startedTest.foreach { + case StreamProgressLockedActions(actns, _) => + // Perform actions while holding the stream from progressing + assert(currentStream != null, + s"Cannot perform stream-progress-locked actions $actns when query is not active") + assert(currentStream.isInstanceOf[MicroBatchExecution], + s"Cannot perform stream-progress-locked actions on non-microbatch queries") + currentStream.asInstanceOf[MicroBatchExecution].withProgressLocked { + actns.foreach(executeAction) + } + + case action: StreamAction => executeAction(action) } if (streamThreadDeathCause != null) { failTest("Stream Thread Died", streamThreadDeathCause) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 92087f68ad74a..11bdd13942dcb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -462,15 +462,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) testStream(joined)( - AddData(leftInput, 1, 2, 3), - AddData(rightInput, 3, 4, 5), + MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), // The left rows with leftValue <= 4 should generate their outer join row now and // not get added to the state. CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)), assertNumStateRows(total = 4, updated = 4), // We shouldn't get more outer join rows when the watermark advances. - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), AddData(rightInput, 20), CheckLastBatch((20, 30, 40, "60")) @@ -493,15 +491,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) testStream(joined)( - AddData(leftInput, 3, 4, 5), - AddData(rightInput, 1, 2, 3), + MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), // The right rows with value <= 7 should never be added to the state. CheckLastBatch(Row(3, 10, 6, "9")), assertNumStateRows(total = 4, updated = 4), // When the watermark advances, we get the outer join rows just as we would if they // were added but didn't match the full join condition. - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), AddData(rightInput, 20), CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, 8, null), Row(5, 10, 10, null)) @@ -524,15 +520,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) testStream(joined)( - AddData(leftInput, 1, 2, 3), - AddData(rightInput, 3, 4, 5), + MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), // The left rows with value <= 4 should never be added to the state. CheckLastBatch(Row(3, 10, 6, "9")), assertNumStateRows(total = 4, updated = 4), // When the watermark advances, we get the outer join rows just as we would if they // were added but didn't match the full join condition. - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), AddData(rightInput, 20), CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, null, "12"), Row(5, 10, null, "15")) @@ -555,15 +549,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) testStream(joined)( - AddData(leftInput, 3, 4, 5), - AddData(rightInput, 1, 2, 3), + MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), // The right rows with rightValue <= 7 should generate their outer join row now and // not get added to the state. CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")), assertNumStateRows(total = 4, updated = 4), // We shouldn't get more outer join rows when the watermark advances. - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), AddData(rightInput, 20), CheckLastBatch((20, 30, 40, "60")) @@ -575,13 +567,11 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // Test inner part of the join. - AddData(leftInput, 1, 2, 3, 4, 5), - AddData(rightInput, 3, 4, 5, 6, 7), + MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), // Old state doesn't get dropped until the batch *after* it gets introduced, so the // nulls won't show up until the next batch after the watermark advances. - AddData(leftInput, 21), - AddData(rightInput, 22), + MultiAddData(leftInput, 21)(rightInput, 22), CheckLastBatch(), assertNumStateRows(total = 12, updated = 2), AddData(leftInput, 22), @@ -595,13 +585,11 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // Test inner part of the join. - AddData(leftInput, 1, 2, 3, 4, 5), - AddData(rightInput, 3, 4, 5, 6, 7), + MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), // Old state doesn't get dropped until the batch *after* it gets introduced, so the // nulls won't show up until the next batch after the watermark advances. - AddData(leftInput, 21), - AddData(rightInput, 22), + MultiAddData(leftInput, 21)(rightInput, 22), CheckLastBatch(), assertNumStateRows(total = 12, updated = 2), AddData(leftInput, 22), @@ -676,11 +664,9 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // leftValue <= 10 should generate outer join rows even though it matches right keys - AddData(leftInput, 1, 2, 3), - AddData(rightInput, 1, 2, 3), + MultiAddData(leftInput, 1, 2, 3)(rightInput, 1, 2, 3), CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), assertNumStateRows(total = 5, updated = 2), AddData(rightInput, 20), @@ -688,22 +674,18 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with Row(20, 30, 40, 60)), assertNumStateRows(total = 3, updated = 1), // leftValue and rightValue both satisfying condition should not generate outer join rows - AddData(leftInput, 40, 41), - AddData(rightInput, 40, 41), + MultiAddData(leftInput, 40, 41)(rightInput, 40, 41), CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)), - AddData(leftInput, 70), - AddData(rightInput, 71), + MultiAddData(leftInput, 70)(rightInput, 71), CheckLastBatch(), assertNumStateRows(total = 6, updated = 2), AddData(rightInput, 70), CheckLastBatch((70, 80, 140, 210)), assertNumStateRows(total = 3, updated = 1), // rightValue between 300 and 1000 should generate outer join rows even though it matches left - AddData(leftInput, 101, 102, 103), - AddData(rightInput, 101, 102, 103), + MultiAddData(leftInput, 101, 102, 103)(rightInput, 101, 102, 103), CheckLastBatch(), - AddData(leftInput, 1000), - AddData(rightInput, 1001), + MultiAddData(leftInput, 1000)(rightInput, 1001), CheckLastBatch(), assertNumStateRows(total = 8, updated = 2), AddData(rightInput, 1000), From 1a198ce8f580bcf35b9cbfab403fc40f821046a1 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 23 Feb 2018 16:30:32 -0800 Subject: [PATCH 0391/2461] [SPARK-23459][SQL] Improve the error message when unknown column is specified in partition columns ## What changes were proposed in this pull request? This PR avoids to print schema internal information when unknown column is specified in partition columns. This PR prints column names in the schema with more readable format. The following is an example. Source code ``` test("save with an unknown partition column") { withTempDir { dir => val path = dir.getCanonicalPath Seq(1L -> "a").toDF("i", "j").write .format("parquet") .partitionBy("unknownColumn") .save(path) } ``` Output without this PR ``` Partition column unknownColumn not found in schema StructType(StructField(i,LongType,false), StructField(j,StringType,true)); ``` Output with this PR ``` Partition column unknownColumn not found in schema struct; ``` ## How was this patch tested? Manually tested Author: Kazuaki Ishizaki Closes #20653 from kiszk/SPARK-23459. --- .../datasources/PartitioningUtils.scala | 3 ++- .../apache/spark/sql/sources/SaveLoadSuite.scala | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 379acb67f7c71..f9a24806953e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -486,7 +486,8 @@ object PartitioningUtils { val equality = columnNameEquality(caseSensitive) StructType(partitionColumns.map { col => schema.find(f => equality(f.name, col)).getOrElse { - throw new AnalysisException(s"Partition column $col not found in schema $schema") + val schemaCatalog = schema.catalogString + throw new AnalysisException(s"Partition column `$col` not found in schema $schemaCatalog") } }).asNullable } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 773d34dfaf9a8..12779b46bfe8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -126,4 +126,20 @@ class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndA checkLoad(df2, "jsonTable2") } + + test("SPARK-23459: Improve error message when specified unknown column in partition columns") { + withTempDir { dir => + val path = dir.getCanonicalPath + val unknown = "unknownColumn" + val df = Seq(1L -> "a").toDF("i", "j") + val schemaCatalog = df.schema.catalogString + val e = intercept[AnalysisException] { + df.write + .format("parquet") + .partitionBy(unknown) + .save(path) + }.getMessage + assert(e.contains(s"Partition column `$unknown` not found in schema $schemaCatalog")) + } + } } From 3ca9a2c56513444d7b233088b020d2d43fa6b77a Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Sun, 25 Feb 2018 09:29:59 -0600 Subject: [PATCH 0392/2461] =?UTF-8?q?[SPARK-22886][ML][TESTS]=20ML=20test?= =?UTF-8?q?=20for=20structured=20streaming:=20ml.recomme=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Converting spark.ml.recommendation tests to also check code with structured streaming, using the ML testing infrastructure implemented in SPARK-22882. ## How was this patch tested? Automated: Pass the Jenkins. Author: Gabor Somogyi Closes #20362 from gaborgsomogyi/SPARK-22886. --- .../spark/ml/recommendation/ALSSuite.scala | 213 ++++++++++++------ .../apache/spark/ml/util/MLTestingUtils.scala | 44 ---- 2 files changed, 143 insertions(+), 114 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index addcd21d50aac..e3dfe2faf5698 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -22,8 +22,7 @@ import java.util.Random import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.WrappedArray +import scala.collection.mutable.{ArrayBuffer, WrappedArray} import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} @@ -35,21 +34,20 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.recommendation.ALS._ -import org.apache.spark.ml.recommendation.ALS.Rating -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.recommendation.MatrixFactorizationModelSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted} -import org.apache.spark.sql.{DataFrame, Row, SparkSession} -import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.{DataFrame, Encoder, Row, SparkSession} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.streaming.StreamingQueryException import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class ALSSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging { +class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { override def beforeAll(): Unit = { super.beforeAll() @@ -413,34 +411,36 @@ class ALSSuite .setSeed(0) val alpha = als.getAlpha val model = als.fit(training.toDF()) - val predictions = model.transform(test.toDF()).select("rating", "prediction").rdd.map { - case Row(rating: Float, prediction: Float) => - (rating.toDouble, prediction.toDouble) + testTransformerByGlobalCheckFunc[Rating[Int]](test.toDF(), model, "rating", "prediction") { + case rows: Seq[Row] => + val predictions = rows.map(row => (row.getFloat(0).toDouble, row.getFloat(1).toDouble)) + + val rmse = + if (implicitPrefs) { + // TODO: Use a better (rank-based?) evaluation metric for implicit feedback. + // We limit the ratings and the predictions to interval [0, 1] and compute the + // weighted RMSE with the confidence scores as weights. + val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) => + val confidence = 1.0 + alpha * math.abs(rating) + val rating01 = math.max(math.min(rating, 1.0), 0.0) + val prediction01 = math.max(math.min(prediction, 1.0), 0.0) + val err = prediction01 - rating01 + (confidence, confidence * err * err) + }.reduce[(Double, Double)] { case ((c0, e0), (c1, e1)) => + (c0 + c1, e0 + e1) + } + math.sqrt(weightedSumSq / totalWeight) + } else { + val errorSquares = predictions.map { case (rating, prediction) => + val err = rating - prediction + err * err + } + val mse = errorSquares.sum / errorSquares.length + math.sqrt(mse) + } + logInfo(s"Test RMSE is $rmse.") + assert(rmse < targetRMSE) } - val rmse = - if (implicitPrefs) { - // TODO: Use a better (rank-based?) evaluation metric for implicit feedback. - // We limit the ratings and the predictions to interval [0, 1] and compute the weighted RMSE - // with the confidence scores as weights. - val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) => - val confidence = 1.0 + alpha * math.abs(rating) - val rating01 = math.max(math.min(rating, 1.0), 0.0) - val prediction01 = math.max(math.min(prediction, 1.0), 0.0) - val err = prediction01 - rating01 - (confidence, confidence * err * err) - }.reduce { case ((c0, e0), (c1, e1)) => - (c0 + c1, e0 + e1) - } - math.sqrt(weightedSumSq / totalWeight) - } else { - val mse = predictions.map { case (rating, prediction) => - val err = rating - prediction - err * err - }.mean() - math.sqrt(mse) - } - logInfo(s"Test RMSE is $rmse.") - assert(rmse < targetRMSE) MLTestingUtils.checkCopyAndUids(als, model) } @@ -586,6 +586,68 @@ class ALSSuite allModelParamSettings, checkModelData) } + private def checkNumericTypesALS( + estimator: ALS, + spark: SparkSession, + column: String, + baseType: NumericType) + (check: (ALSModel, ALSModel) => Unit) + (check2: (ALSModel, ALSModel, DataFrame, Encoder[_]) => Unit): Unit = { + val dfs = genRatingsDFWithNumericCols(spark, column) + val df = dfs.find { + case (numericTypeWithEncoder, _) => numericTypeWithEncoder.numericType == baseType + } match { + case Some((_, df)) => df + } + val expected = estimator.fit(df) + val actuals = dfs.filter(_ != baseType).map(t => (t, estimator.fit(t._2))) + actuals.foreach { case (_, actual) => check(expected, actual) } + actuals.foreach { case (t, actual) => check2(expected, actual, t._2, t._1.encoder) } + + val baseDF = dfs.find(_._1.numericType == baseType).get._2 + val others = baseDF.columns.toSeq.diff(Seq(column)).map(col) + val cols = Seq(col(column).cast(StringType)) ++ others + val strDF = baseDF.select(cols: _*) + val thrown = intercept[IllegalArgumentException] { + estimator.fit(strDF) + } + assert(thrown.getMessage.contains( + s"$column must be of type NumericType but was actually of type StringType")) + } + + private class NumericTypeWithEncoder[A](val numericType: NumericType) + (implicit val encoder: Encoder[(A, Int, Double)]) + + private def genRatingsDFWithNumericCols( + spark: SparkSession, + column: String) = { + + import testImplicits._ + + val df = spark.createDataFrame(Seq( + (0, 10, 1.0), + (1, 20, 2.0), + (2, 30, 3.0), + (3, 40, 4.0), + (4, 50, 5.0) + )).toDF("user", "item", "rating") + + val others = df.columns.toSeq.diff(Seq(column)).map(col) + val types = + Seq(new NumericTypeWithEncoder[Short](ShortType), + new NumericTypeWithEncoder[Long](LongType), + new NumericTypeWithEncoder[Int](IntegerType), + new NumericTypeWithEncoder[Float](FloatType), + new NumericTypeWithEncoder[Byte](ByteType), + new NumericTypeWithEncoder[Double](DoubleType), + new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder()) + ) + types.map { t => + val cols = Seq(col(column).cast(t.numericType)) ++ others + t -> df.select(cols: _*) + } + } + test("input type validation") { val spark = this.spark import spark.implicits._ @@ -595,12 +657,16 @@ class ALSSuite val als = new ALS().setMaxIter(1).setRank(1) Seq(("user", IntegerType), ("item", IntegerType), ("rating", FloatType)).foreach { case (colName, sqlType) => - MLTestingUtils.checkNumericTypesALS(als, spark, colName, sqlType) { + checkNumericTypesALS(als, spark, colName, sqlType) { (ex, act) => - ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1) - } { (ex, act, _) => - ex.transform(_: DataFrame).select("prediction").first.getDouble(0) ~== - act.transform(_: DataFrame).select("prediction").first.getDouble(0) absTol 1e-6 + ex.userFactors.first().getSeq[Float](1) === act.userFactors.first().getSeq[Float](1) + } { (ex, act, df, enc) => + val expected = ex.transform(df).selectExpr("prediction") + .first().getFloat(0) + testTransformerByGlobalCheckFunc(df, act, "prediction") { + case rows: Seq[Row] => + expected ~== rows.head.getFloat(0) absTol 1e-6 + }(enc) } } // check user/item ids falling outside of Int range @@ -628,18 +694,22 @@ class ALSSuite } withClue("transform should fail when ids exceed integer range. ") { val model = als.fit(df) - assert(intercept[SparkException] { - model.transform(df.select(df("user_big").as("user"), df("item"))).first - }.getMessage.contains(msg)) - assert(intercept[SparkException] { - model.transform(df.select(df("user_small").as("user"), df("item"))).first - }.getMessage.contains(msg)) - assert(intercept[SparkException] { - model.transform(df.select(df("item_big").as("item"), df("user"))).first - }.getMessage.contains(msg)) - assert(intercept[SparkException] { - model.transform(df.select(df("item_small").as("item"), df("user"))).first - }.getMessage.contains(msg)) + def testTransformIdExceedsIntRange[A : Encoder](dataFrame: DataFrame): Unit = { + assert(intercept[SparkException] { + model.transform(dataFrame).first + }.getMessage.contains(msg)) + assert(intercept[StreamingQueryException] { + testTransformer[A](dataFrame, model, "prediction") { _ => } + }.getMessage.contains(msg)) + } + testTransformIdExceedsIntRange[(Long, Int)](df.select(df("user_big").as("user"), + df("item"))) + testTransformIdExceedsIntRange[(Double, Int)](df.select(df("user_small").as("user"), + df("item"))) + testTransformIdExceedsIntRange[(Long, Int)](df.select(df("item_big").as("item"), + df("user"))) + testTransformIdExceedsIntRange[(Double, Int)](df.select(df("item_small").as("item"), + df("user"))) } } @@ -662,28 +732,31 @@ class ALSSuite val knownItem = data.select(max("item")).as[Int].first() val unknownItem = knownItem + 20 val test = Seq( - (unknownUser, unknownItem), - (knownUser, unknownItem), - (unknownUser, knownItem), - (knownUser, knownItem) - ).toDF("user", "item") + (unknownUser, unknownItem, true), + (knownUser, unknownItem, true), + (unknownUser, knownItem, true), + (knownUser, knownItem, false) + ).toDF("user", "item", "expectedIsNaN") val als = new ALS().setMaxIter(1).setRank(1) // default is 'nan' val defaultModel = als.fit(data) - val defaultPredictions = defaultModel.transform(test).select("prediction").as[Float].collect() - assert(defaultPredictions.length == 4) - assert(defaultPredictions.slice(0, 3).forall(_.isNaN)) - assert(!defaultPredictions.last.isNaN) + testTransformer[(Int, Int, Boolean)](test, defaultModel, "expectedIsNaN", "prediction") { + case Row(expectedIsNaN: Boolean, prediction: Float) => + assert(prediction.isNaN === expectedIsNaN) + } // check 'drop' strategy should filter out rows with unknown users/items - val dropPredictions = defaultModel - .setColdStartStrategy("drop") - .transform(test) - .select("prediction").as[Float].collect() - assert(dropPredictions.length == 1) - assert(!dropPredictions.head.isNaN) - assert(dropPredictions.head ~== defaultPredictions.last relTol 1e-14) + val defaultPrediction = defaultModel.transform(test).select("prediction") + .as[Float].filter(!_.isNaN).first() + testTransformerByGlobalCheckFunc[(Int, Int, Boolean)](test, + defaultModel.setColdStartStrategy("drop"), "prediction") { + case rows: Seq[Row] => + val dropPredictions = rows.map(_.getFloat(0)) + assert(dropPredictions.length == 1) + assert(!dropPredictions.head.isNaN) + assert(dropPredictions.head ~== defaultPrediction relTol 1e-14) + } } test("case insensitive cold start param value") { @@ -693,7 +766,7 @@ class ALSSuite val data = ratings.toDF val model = new ALS().fit(data) Seq("nan", "NaN", "Nan", "drop", "DROP", "Drop").foreach { s => - model.setColdStartStrategy(s).transform(data) + testTransformer[Rating[Int]](data, model.setColdStartStrategy(s), "prediction") { _ => } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index aef81c8c173a0..c328d81b4bc3a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -91,30 +91,6 @@ object MLTestingUtils extends SparkFunSuite { } } - def checkNumericTypesALS( - estimator: ALS, - spark: SparkSession, - column: String, - baseType: NumericType) - (check: (ALSModel, ALSModel) => Unit) - (check2: (ALSModel, ALSModel, DataFrame) => Unit): Unit = { - val dfs = genRatingsDFWithNumericCols(spark, column) - val expected = estimator.fit(dfs(baseType)) - val actuals = dfs.keys.filter(_ != baseType).map(t => (t, estimator.fit(dfs(t)))) - actuals.foreach { case (_, actual) => check(expected, actual) } - actuals.foreach { case (t, actual) => check2(expected, actual, dfs(t)) } - - val baseDF = dfs(baseType) - val others = baseDF.columns.toSeq.diff(Seq(column)).map(col) - val cols = Seq(col(column).cast(StringType)) ++ others - val strDF = baseDF.select(cols: _*) - val thrown = intercept[IllegalArgumentException] { - estimator.fit(strDF) - } - assert(thrown.getMessage.contains( - s"$column must be of type NumericType but was actually of type StringType")) - } - def checkNumericTypes[T <: Evaluator](evaluator: T, spark: SparkSession): Unit = { val dfs = genEvaluatorDFWithNumericLabelCol(spark, "label", "prediction") val expected = evaluator.evaluate(dfs(DoubleType)) @@ -176,26 +152,6 @@ object MLTestingUtils extends SparkFunSuite { }.toMap } - def genRatingsDFWithNumericCols( - spark: SparkSession, - column: String): Map[NumericType, DataFrame] = { - val df = spark.createDataFrame(Seq( - (0, 10, 1.0), - (1, 20, 2.0), - (2, 30, 3.0), - (3, 40, 4.0), - (4, 50, 5.0) - )).toDF("user", "item", "rating") - - val others = df.columns.toSeq.diff(Seq(column)).map(col) - val types: Seq[NumericType] = - Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) - types.map { t => - val cols = Seq(col(column).cast(t)) ++ others - t -> df.select(cols: _*) - }.toMap - } - def genEvaluatorDFWithNumericLabelCol( spark: SparkSession, labelColName: String = "label", From b308182f233b8840dfe0e6b5736d2f2746f40757 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Mon, 26 Feb 2018 08:39:44 -0800 Subject: [PATCH 0393/2461] [SPARK-23438][DSTREAMS] Fix DStreams data loss with WAL when driver crashes ## What changes were proposed in this pull request? There is a race condition introduced in SPARK-11141 which could cause data loss. The problem is that ReceivedBlockTracker.insertAllocatedBatch function assumes that all blocks from streamIdToUnallocatedBlockQueues allocated to the batch and clears the queue. In this PR only the allocated blocks will be removed from the queue which will prevent data loss. ## How was this patch tested? Additional unit test + manually. Author: Gabor Somogyi Closes #20620 from gaborgsomogyi/SPARK-23438. --- .../scheduler/ReceivedBlockTracker.scala | 11 +++++---- .../streaming/ReceivedBlockTrackerSuite.scala | 23 ++++++++++++++++++- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 5d9a8ac0d9297..dacff69d55dd2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -193,12 +193,15 @@ private[streaming] class ReceivedBlockTracker( getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo } - // Insert the recovered block-to-batch allocations and clear the queue of received blocks - // (when the blocks were originally allocated to the batch, the queue must have been cleared). + // Insert the recovered block-to-batch allocations and removes them from queue of + // received blocks. def insertAllocatedBatch(batchTime: Time, allocatedBlocks: AllocatedBlocks) { logTrace(s"Recovery: Inserting allocated batch for time $batchTime to " + s"${allocatedBlocks.streamIdToAllocatedBlocks}") - streamIdToUnallocatedBlockQueues.values.foreach { _.clear() } + allocatedBlocks.streamIdToAllocatedBlocks.foreach { + case (streamId, allocatedBlocksInStream) => + getReceivedBlockQueue(streamId).dequeueAll(allocatedBlocksInStream.toSet) + } timeToAllocatedBlocks.put(batchTime, allocatedBlocks) lastAllocatedBatchTime = batchTime } @@ -227,7 +230,7 @@ private[streaming] class ReceivedBlockTracker( } /** Write an update to the tracker to the write ahead log */ - private def writeToLog(record: ReceivedBlockTrackerLogEvent): Boolean = { + private[streaming] def writeToLog(record: ReceivedBlockTrackerLogEvent): Boolean = { if (isWriteAheadLogEnabled) { logTrace(s"Writing record: $record") try { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index 107c3f5dcc08d..4fa236bd39663 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult -import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.scheduler.{AllocatedBlocks, _} import org.apache.spark.streaming.util._ import org.apache.spark.streaming.util.WriteAheadLogSuite._ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} @@ -94,6 +94,27 @@ class ReceivedBlockTrackerSuite receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos } + test("recovery with write ahead logs should remove only allocated blocks from received queue") { + val manualClock = new ManualClock + val batchTime = manualClock.getTimeMillis() + + val tracker1 = createTracker(clock = manualClock) + tracker1.isWriteAheadLogEnabled should be (true) + + val allocatedBlockInfos = generateBlockInfos() + val unallocatedBlockInfos = generateBlockInfos() + val receivedBlockInfos = allocatedBlockInfos ++ unallocatedBlockInfos + receivedBlockInfos.foreach { b => tracker1.writeToLog(BlockAdditionEvent(b)) } + val allocatedBlocks = AllocatedBlocks(Map(streamId -> allocatedBlockInfos)) + tracker1.writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks)) + tracker1.stop() + + val tracker2 = createTracker(clock = manualClock, recoverFromWriteAheadLog = true) + tracker2.getBlocksOfBatch(batchTime) shouldEqual allocatedBlocks.streamIdToAllocatedBlocks + tracker2.getUnallocatedBlocks(streamId) shouldEqual unallocatedBlockInfos + tracker2.stop() + } + test("recovery and cleanup with write ahead logs") { val manualClock = new ManualClock // Set the time increment level to twice the rotation interval so that every increment creates From 185f5bc7dd52cebe8fac9393ecb2bd0968bc5867 Mon Sep 17 00:00:00 2001 From: Andrew Korzhuev Date: Mon, 26 Feb 2018 10:28:45 -0800 Subject: [PATCH 0394/2461] [SPARK-23449][K8S] Preserve extraJavaOptions ordering For some JVM options, like `-XX:+UnlockExperimentalVMOptions` ordering is necessary. ## What changes were proposed in this pull request? Keep original `extraJavaOptions` ordering, when passing them through environment variables inside the Docker container. ## How was this patch tested? Ran base branch a couple of times and checked startup command in logs. Ordering differed every time. Added sorting, ordering was consistent to what user had in `extraJavaOptions`. Author: Andrew Korzhuev Closes #20628 from andrusha/patch-2. --- .../kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index b9090dc2852a5..3d67b0a702dd4 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -41,7 +41,7 @@ fi shift 1 SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" -env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt +env | grep SPARK_JAVA_OPT_ | sort -t_ -k4 -n | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt readarray -t SPARK_JAVA_OPTS < /tmp/java_opts.txt if [ -n "$SPARK_MOUNTED_CLASSPATH" ]; then SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_MOUNTED_CLASSPATH" From 7ec83658fbc88505dfc2d8a6f76e90db747f1292 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 26 Feb 2018 11:28:44 -0800 Subject: [PATCH 0395/2461] [SPARK-23491][SS] Remove explicit job cancellation from ContinuousExecution reconfiguring ## What changes were proposed in this pull request? Remove queryExecutionThread.interrupt() from ContinuousExecution. As detailed in the JIRA, interrupting the thread is only relevant in the microbatch case; for continuous processing the query execution can quickly clean itself up without. ## How was this patch tested? existing tests Author: Jose Torres Closes #20622 from jose-torres/SPARK-23441. --- .../streaming/continuous/ContinuousExecution.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 2c1d6c509d21b..daebd1dd010ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -236,9 +236,7 @@ class ContinuousExecution( startTrigger() if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { - stopSources() if (queryExecutionThread.isAlive) { - sparkSession.sparkContext.cancelJobGroup(runId.toString) queryExecutionThread.interrupt() } false @@ -266,12 +264,20 @@ class ContinuousExecution( SQLExecution.withNewExecutionId( sparkSessionForQuery, lastExecution)(lastExecution.toRdd) } + } catch { + case t: Throwable + if StreamExecution.isInterruptionException(t) && state.get() == RECONFIGURING => + logInfo(s"Query $id ignoring exception from reconfiguring: $t") + // interrupted by reconfiguration - swallow exception so we can restart the query } finally { epochEndpoint.askSync[Unit](StopContinuousExecutionWrites) SparkEnv.get.rpcEnv.stop(epochEndpoint) epochUpdateThread.interrupt() epochUpdateThread.join() + + stopSources() + sparkSession.sparkContext.cancelJobGroup(runId.toString) } } From 8077bb04f350fd35df83ef896135c0672dc3f7b0 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Mon, 26 Feb 2018 23:37:31 -0800 Subject: [PATCH 0396/2461] [SPARK-23445] ColumnStat refactoring ## What changes were proposed in this pull request? Refactor ColumnStat to be more flexible. * Split `ColumnStat` and `CatalogColumnStat` just like `CatalogStatistics` is split from `Statistics`. This detaches how the statistics are stored from how they are processed in the query plan. `CatalogColumnStat` keeps `min` and `max` as `String`, making it not depend on dataType information. * For `CatalogColumnStat`, parse column names from property names in the metastore (`KEY_VERSION` property), not from metastore schema. This means that `CatalogColumnStat`s can be created for columns even if the schema itself is not stored in the metastore. * Make all fields optional. `min`, `max` and `histogram` for columns were optional already. Having them all optional is more consistent, and gives flexibility to e.g. drop some of the fields through transformations if they are difficult / impossible to calculate. The added flexibility will make it possible to have alternative implementations for stats, and separates stats collection from stats and estimation processing in plans. ## How was this patch tested? Refactored existing tests to work with refactored `ColumnStat` and `CatalogColumnStat`. New tests added in `StatisticsSuite` checking that backwards / forwards compatibility is not broken. Author: Juliusz Sompolski Closes #20624 from juliuszsompolski/SPARK-23445. --- .../sql/catalyst/catalog/interface.scala | 146 ++++++++- .../optimizer/StarSchemaDetection.scala | 6 +- .../catalyst/plans/logical/Statistics.scala | 256 ++-------------- .../statsEstimation/AggregateEstimation.scala | 6 +- .../statsEstimation/EstimationUtils.scala | 20 +- .../statsEstimation/FilterEstimation.scala | 98 +++--- .../statsEstimation/JoinEstimation.scala | 55 ++-- .../catalyst/optimizer/JoinReorderSuite.scala | 25 +- .../StarJoinCostBasedReorderSuite.scala | 96 ++---- .../optimizer/StarJoinReorderSuite.scala | 77 ++--- .../AggregateEstimationSuite.scala | 24 +- .../BasicStatsEstimationSuite.scala | 12 +- .../FilterEstimationSuite.scala | 279 +++++++++--------- .../statsEstimation/JoinEstimationSuite.scala | 138 +++++---- .../ProjectEstimationSuite.scala | 70 +++-- .../StatsEstimationTestBase.scala | 10 +- .../command/AnalyzeColumnCommand.scala | 138 ++++++++- .../spark/sql/execution/command/tables.scala | 9 +- .../spark/sql/StatisticsCollectionSuite.scala | 9 +- .../sql/StatisticsCollectionTestBase.scala | 168 +++++++++-- .../spark/sql/hive/HiveExternalCatalog.scala | 63 ++-- .../spark/sql/hive/StatisticsSuite.scala | 162 +++------- 22 files changed, 995 insertions(+), 872 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 95b6fbb0cd61a..f3e67dc4e975c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -21,7 +21,9 @@ import java.net.URI import java.util.Date import scala.collection.mutable +import scala.util.control.NonFatal +import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation @@ -30,7 +32,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ /** @@ -361,7 +363,7 @@ object CatalogTable { case class CatalogStatistics( sizeInBytes: BigInt, rowCount: Option[BigInt] = None, - colStats: Map[String, ColumnStat] = Map.empty) { + colStats: Map[String, CatalogColumnStat] = Map.empty) { /** * Convert [[CatalogStatistics]] to [[Statistics]], and match column stats to attributes based @@ -369,7 +371,8 @@ case class CatalogStatistics( */ def toPlanStats(planOutput: Seq[Attribute], cboEnabled: Boolean): Statistics = { if (cboEnabled && rowCount.isDefined) { - val attrStats = AttributeMap(planOutput.flatMap(a => colStats.get(a.name).map(a -> _))) + val attrStats = AttributeMap(planOutput + .flatMap(a => colStats.get(a.name).map(a -> _.toPlanStat(a.name, a.dataType)))) // Estimate size as number of rows * row size. val size = EstimationUtils.getOutputSize(planOutput, rowCount.get, attrStats) Statistics(sizeInBytes = size, rowCount = rowCount, attributeStats = attrStats) @@ -387,6 +390,143 @@ case class CatalogStatistics( } } +/** + * This class of statistics for a column is used in [[CatalogTable]] to interact with metastore. + */ +case class CatalogColumnStat( + distinctCount: Option[BigInt] = None, + min: Option[String] = None, + max: Option[String] = None, + nullCount: Option[BigInt] = None, + avgLen: Option[Long] = None, + maxLen: Option[Long] = None, + histogram: Option[Histogram] = None) { + + /** + * Returns a map from string to string that can be used to serialize the column stats. + * The key is the name of the column and name of the field (e.g. "colName.distinctCount"), + * and the value is the string representation for the value. + * min/max values are stored as Strings. They can be deserialized using + * [[CatalogColumnStat.fromExternalString]]. + * + * As part of the protocol, the returned map always contains a key called "version". + * Any of the fields that are null (None) won't appear in the map. + */ + def toMap(colName: String): Map[String, String] = { + val map = new scala.collection.mutable.HashMap[String, String] + map.put(s"${colName}.${CatalogColumnStat.KEY_VERSION}", "1") + distinctCount.foreach { v => + map.put(s"${colName}.${CatalogColumnStat.KEY_DISTINCT_COUNT}", v.toString) + } + nullCount.foreach { v => + map.put(s"${colName}.${CatalogColumnStat.KEY_NULL_COUNT}", v.toString) + } + avgLen.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_AVG_LEN}", v.toString) } + maxLen.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_MAX_LEN}", v.toString) } + min.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_MIN_VALUE}", v) } + max.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_MAX_VALUE}", v) } + histogram.foreach { h => + map.put(s"${colName}.${CatalogColumnStat.KEY_HISTOGRAM}", HistogramSerializer.serialize(h)) + } + map.toMap + } + + /** Convert [[CatalogColumnStat]] to [[ColumnStat]]. */ + def toPlanStat( + colName: String, + dataType: DataType): ColumnStat = + ColumnStat( + distinctCount = distinctCount, + min = min.map(CatalogColumnStat.fromExternalString(_, colName, dataType)), + max = max.map(CatalogColumnStat.fromExternalString(_, colName, dataType)), + nullCount = nullCount, + avgLen = avgLen, + maxLen = maxLen, + histogram = histogram) +} + +object CatalogColumnStat extends Logging { + + // List of string keys used to serialize CatalogColumnStat + val KEY_VERSION = "version" + private val KEY_DISTINCT_COUNT = "distinctCount" + private val KEY_MIN_VALUE = "min" + private val KEY_MAX_VALUE = "max" + private val KEY_NULL_COUNT = "nullCount" + private val KEY_AVG_LEN = "avgLen" + private val KEY_MAX_LEN = "maxLen" + private val KEY_HISTOGRAM = "histogram" + + /** + * Converts from string representation of data type to the corresponding Catalyst data type. + */ + def fromExternalString(s: String, name: String, dataType: DataType): Any = { + dataType match { + case BooleanType => s.toBoolean + case DateType => DateTimeUtils.fromJavaDate(java.sql.Date.valueOf(s)) + case TimestampType => DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(s)) + case ByteType => s.toByte + case ShortType => s.toShort + case IntegerType => s.toInt + case LongType => s.toLong + case FloatType => s.toFloat + case DoubleType => s.toDouble + case _: DecimalType => Decimal(s) + // This version of Spark does not use min/max for binary/string types so we ignore it. + case BinaryType | StringType => null + case _ => + throw new AnalysisException("Column statistics deserialization is not supported for " + + s"column $name of data type: $dataType.") + } + } + + /** + * Converts the given value from Catalyst data type to string representation of external + * data type. + */ + def toExternalString(v: Any, colName: String, dataType: DataType): String = { + val externalValue = dataType match { + case DateType => DateTimeUtils.toJavaDate(v.asInstanceOf[Int]) + case TimestampType => DateTimeUtils.toJavaTimestamp(v.asInstanceOf[Long]) + case BooleanType | _: IntegralType | FloatType | DoubleType => v + case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal + // This version of Spark does not use min/max for binary/string types so we ignore it. + case _ => + throw new AnalysisException("Column statistics serialization is not supported for " + + s"column $colName of data type: $dataType.") + } + externalValue.toString + } + + + /** + * Creates a [[CatalogColumnStat]] object from the given map. + * This is used to deserialize column stats from some external storage. + * The serialization side is defined in [[CatalogColumnStat.toMap]]. + */ + def fromMap( + table: String, + colName: String, + map: Map[String, String]): Option[CatalogColumnStat] = { + + try { + Some(CatalogColumnStat( + distinctCount = map.get(s"${colName}.${KEY_DISTINCT_COUNT}").map(v => BigInt(v.toLong)), + min = map.get(s"${colName}.${KEY_MIN_VALUE}"), + max = map.get(s"${colName}.${KEY_MAX_VALUE}"), + nullCount = map.get(s"${colName}.${KEY_NULL_COUNT}").map(v => BigInt(v.toLong)), + avgLen = map.get(s"${colName}.${KEY_AVG_LEN}").map(_.toLong), + maxLen = map.get(s"${colName}.${KEY_MAX_LEN}").map(_.toLong), + histogram = map.get(s"${colName}.${KEY_HISTOGRAM}").map(HistogramSerializer.deserialize) + )) + } catch { + case NonFatal(e) => + logWarning(s"Failed to parse column statistics for column ${colName} in table $table", e) + None + } + } +} + case class CatalogTableType private(name: String) object CatalogTableType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala index 1f20b7661489e..2aa762e2595ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala @@ -187,11 +187,11 @@ object StarSchemaDetection extends PredicateHelper { stats.rowCount match { case Some(rowCount) if rowCount >= 0 => if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) { - val colStats = stats.attributeStats.get(col) - if (colStats.get.nullCount > 0) { + val colStats = stats.attributeStats.get(col).get + if (!colStats.hasCountStats || colStats.nullCount.get > 0) { false } else { - val distinctCount = colStats.get.distinctCount + val distinctCount = colStats.distinctCount.get val relDiff = math.abs((distinctCount.toDouble / rowCount.toDouble) - 1.0d) // ndvMaxErr adjusted based on TPCDS 1TB data results relDiff <= conf.ndvMaxError * 2 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 96b199d7f20b0..b3a48860aa63b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -27,6 +27,7 @@ import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.CatalogColumnStat import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils} @@ -79,11 +80,10 @@ case class Statistics( /** * Statistics collected for a column. * - * 1. Supported data types are defined in `ColumnStat.supportsType`. - * 2. The JVM data type stored in min/max is the internal data type for the corresponding + * 1. The JVM data type stored in min/max is the internal data type for the corresponding * Catalyst data type. For example, the internal type of DateType is Int, and that the internal * type of TimestampType is Long. - * 3. There is no guarantee that the statistics collected are accurate. Approximation algorithms + * 2. There is no guarantee that the statistics collected are accurate. Approximation algorithms * (sketches) might have been used, and the data collected can also be stale. * * @param distinctCount number of distinct values @@ -95,240 +95,32 @@ case class Statistics( * @param histogram histogram of the values */ case class ColumnStat( - distinctCount: BigInt, - min: Option[Any], - max: Option[Any], - nullCount: BigInt, - avgLen: Long, - maxLen: Long, + distinctCount: Option[BigInt] = None, + min: Option[Any] = None, + max: Option[Any] = None, + nullCount: Option[BigInt] = None, + avgLen: Option[Long] = None, + maxLen: Option[Long] = None, histogram: Option[Histogram] = None) { - // We currently don't store min/max for binary/string type. This can change in the future and - // then we need to remove this require. - require(min.isEmpty || (!min.get.isInstanceOf[Array[Byte]] && !min.get.isInstanceOf[String])) - require(max.isEmpty || (!max.get.isInstanceOf[Array[Byte]] && !max.get.isInstanceOf[String])) - - /** - * Returns a map from string to string that can be used to serialize the column stats. - * The key is the name of the field (e.g. "distinctCount" or "min"), and the value is the string - * representation for the value. min/max values are converted to the external data type. For - * example, for DateType we store java.sql.Date, and for TimestampType we store - * java.sql.Timestamp. The deserialization side is defined in [[ColumnStat.fromMap]]. - * - * As part of the protocol, the returned map always contains a key called "version". - * In the case min/max values are null (None), they won't appear in the map. - */ - def toMap(colName: String, dataType: DataType): Map[String, String] = { - val map = new scala.collection.mutable.HashMap[String, String] - map.put(ColumnStat.KEY_VERSION, "1") - map.put(ColumnStat.KEY_DISTINCT_COUNT, distinctCount.toString) - map.put(ColumnStat.KEY_NULL_COUNT, nullCount.toString) - map.put(ColumnStat.KEY_AVG_LEN, avgLen.toString) - map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString) - min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, toExternalString(v, colName, dataType)) } - max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, toExternalString(v, colName, dataType)) } - histogram.foreach { h => map.put(ColumnStat.KEY_HISTOGRAM, HistogramSerializer.serialize(h)) } - map.toMap - } - - /** - * Converts the given value from Catalyst data type to string representation of external - * data type. - */ - private def toExternalString(v: Any, colName: String, dataType: DataType): String = { - val externalValue = dataType match { - case DateType => DateTimeUtils.toJavaDate(v.asInstanceOf[Int]) - case TimestampType => DateTimeUtils.toJavaTimestamp(v.asInstanceOf[Long]) - case BooleanType | _: IntegralType | FloatType | DoubleType => v - case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal - // This version of Spark does not use min/max for binary/string types so we ignore it. - case _ => - throw new AnalysisException("Column statistics deserialization is not supported for " + - s"column $colName of data type: $dataType.") - } - externalValue.toString - } - -} + // Are distinctCount and nullCount statistics defined? + val hasCountStats = distinctCount.isDefined && nullCount.isDefined + // Are min and max statistics defined? + val hasMinMaxStats = min.isDefined && max.isDefined -object ColumnStat extends Logging { - - // List of string keys used to serialize ColumnStat - val KEY_VERSION = "version" - private val KEY_DISTINCT_COUNT = "distinctCount" - private val KEY_MIN_VALUE = "min" - private val KEY_MAX_VALUE = "max" - private val KEY_NULL_COUNT = "nullCount" - private val KEY_AVG_LEN = "avgLen" - private val KEY_MAX_LEN = "maxLen" - private val KEY_HISTOGRAM = "histogram" - - /** Returns true iff the we support gathering column statistics on column of the given type. */ - def supportsType(dataType: DataType): Boolean = dataType match { - case _: IntegralType => true - case _: DecimalType => true - case DoubleType | FloatType => true - case BooleanType => true - case DateType => true - case TimestampType => true - case BinaryType | StringType => true - case _ => false - } - - /** Returns true iff the we support gathering histogram on column of the given type. */ - def supportsHistogram(dataType: DataType): Boolean = dataType match { - case _: IntegralType => true - case _: DecimalType => true - case DoubleType | FloatType => true - case DateType => true - case TimestampType => true - case _ => false - } - - /** - * Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats - * from some external storage. The serialization side is defined in [[ColumnStat.toMap]]. - */ - def fromMap(table: String, field: StructField, map: Map[String, String]): Option[ColumnStat] = { - try { - Some(ColumnStat( - distinctCount = BigInt(map(KEY_DISTINCT_COUNT).toLong), - // Note that flatMap(Option.apply) turns Option(null) into None. - min = map.get(KEY_MIN_VALUE) - .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply), - max = map.get(KEY_MAX_VALUE) - .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply), - nullCount = BigInt(map(KEY_NULL_COUNT).toLong), - avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong, - maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong, - histogram = map.get(KEY_HISTOGRAM).map(HistogramSerializer.deserialize) - )) - } catch { - case NonFatal(e) => - logWarning(s"Failed to parse column statistics for column ${field.name} in table $table", e) - None - } - } - - /** - * Converts from string representation of external data type to the corresponding Catalyst data - * type. - */ - private def fromExternalString(s: String, name: String, dataType: DataType): Any = { - dataType match { - case BooleanType => s.toBoolean - case DateType => DateTimeUtils.fromJavaDate(java.sql.Date.valueOf(s)) - case TimestampType => DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(s)) - case ByteType => s.toByte - case ShortType => s.toShort - case IntegerType => s.toInt - case LongType => s.toLong - case FloatType => s.toFloat - case DoubleType => s.toDouble - case _: DecimalType => Decimal(s) - // This version of Spark does not use min/max for binary/string types so we ignore it. - case BinaryType | StringType => null - case _ => - throw new AnalysisException("Column statistics deserialization is not supported for " + - s"column $name of data type: $dataType.") - } - } - - /** - * Constructs an expression to compute column statistics for a given column. - * - * The expression should create a single struct column with the following schema: - * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long, - * distinctCountsForIntervals: Array[Long] - * - * Together with [[rowToColumnStat]], this function is used to create [[ColumnStat]] and - * as a result should stay in sync with it. - */ - def statExprs( - col: Attribute, - conf: SQLConf, - colPercentiles: AttributeMap[ArrayData]): CreateNamedStruct = { - def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr => - expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() } - }) - val one = Literal(1, LongType) - - // the approximate ndv (num distinct value) should never be larger than the number of rows - val numNonNulls = if (col.nullable) Count(col) else Count(one) - val ndv = Least(Seq(HyperLogLogPlusPlus(col, conf.ndvMaxError), numNonNulls)) - val numNulls = Subtract(Count(one), numNonNulls) - val defaultSize = Literal(col.dataType.defaultSize, LongType) - val nullArray = Literal(null, ArrayType(LongType)) - - def fixedLenTypeStruct: CreateNamedStruct = { - val genHistogram = - ColumnStat.supportsHistogram(col.dataType) && colPercentiles.contains(col) - val intervalNdvsExpr = if (genHistogram) { - ApproxCountDistinctForIntervals(col, - Literal(colPercentiles(col), ArrayType(col.dataType)), conf.ndvMaxError) - } else { - nullArray - } - // For fixed width types, avg size should be the same as max size. - struct(ndv, Cast(Min(col), col.dataType), Cast(Max(col), col.dataType), numNulls, - defaultSize, defaultSize, intervalNdvsExpr) - } - - col.dataType match { - case _: IntegralType => fixedLenTypeStruct - case _: DecimalType => fixedLenTypeStruct - case DoubleType | FloatType => fixedLenTypeStruct - case BooleanType => fixedLenTypeStruct - case DateType => fixedLenTypeStruct - case TimestampType => fixedLenTypeStruct - case BinaryType | StringType => - // For string and binary type, we don't compute min, max or histogram - val nullLit = Literal(null, col.dataType) - struct( - ndv, nullLit, nullLit, numNulls, - // Set avg/max size to default size if all the values are null or there is no value. - Coalesce(Seq(Ceil(Average(Length(col))), defaultSize)), - Coalesce(Seq(Cast(Max(Length(col)), LongType), defaultSize)), - nullArray) - case _ => - throw new AnalysisException("Analyzing column statistics is not supported for column " + - s"${col.name} of data type: ${col.dataType}.") - } - } - - /** Convert a struct for column stats (defined in `statExprs`) into [[ColumnStat]]. */ - def rowToColumnStat( - row: InternalRow, - attr: Attribute, - rowCount: Long, - percentiles: Option[ArrayData]): ColumnStat = { - // The first 6 fields are basic column stats, the 7th is ndvs for histogram bins. - val cs = ColumnStat( - distinctCount = BigInt(row.getLong(0)), - // for string/binary min/max, get should return null - min = Option(row.get(1, attr.dataType)), - max = Option(row.get(2, attr.dataType)), - nullCount = BigInt(row.getLong(3)), - avgLen = row.getLong(4), - maxLen = row.getLong(5) - ) - if (row.isNullAt(6)) { - cs - } else { - val ndvs = row.getArray(6).toLongArray() - assert(percentiles.get.numElements() == ndvs.length + 1) - val endpoints = percentiles.get.toArray[Any](attr.dataType).map(_.toString.toDouble) - // Construct equi-height histogram - val bins = ndvs.zipWithIndex.map { case (ndv, i) => - HistogramBin(endpoints(i), endpoints(i + 1), ndv) - } - val nonNullRows = rowCount - cs.nullCount - val histogram = Histogram(nonNullRows.toDouble / ndvs.length, bins) - cs.copy(histogram = Some(histogram)) - } - } + // Are avgLen and maxLen statistics defined? + val hasLenStats = avgLen.isDefined && maxLen.isDefined + def toCatalogColumnStat(colName: String, dataType: DataType): CatalogColumnStat = + CatalogColumnStat( + distinctCount = distinctCount, + min = min.map(CatalogColumnStat.toExternalString(_, colName, dataType)), + max = max.map(CatalogColumnStat.toExternalString(_, colName, dataType)), + nullCount = nullCount, + avgLen = avgLen, + maxLen = maxLen, + histogram = histogram) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala index c41fac4015ec0..111c594a53e52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala @@ -32,13 +32,15 @@ object AggregateEstimation { val childStats = agg.child.stats // Check if we have column stats for all group-by columns. val colStatsExist = agg.groupingExpressions.forall { e => - e.isInstanceOf[Attribute] && childStats.attributeStats.contains(e.asInstanceOf[Attribute]) + e.isInstanceOf[Attribute] && + childStats.attributeStats.get(e.asInstanceOf[Attribute]).exists(_.hasCountStats) } if (rowCountsExist(agg.child) && colStatsExist) { // Multiply distinct counts of group-by columns. This is an upper bound, which assumes // the data contains all combinations of distinct values of group-by columns. var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))( - (res, expr) => res * childStats.attributeStats(expr.asInstanceOf[Attribute]).distinctCount) + (res, expr) => res * + childStats.attributeStats(expr.asInstanceOf[Attribute]).distinctCount.get) outputRows = if (agg.groupingExpressions.isEmpty) { // If there's no group-by columns, the output is a single row containing values of aggregate diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index d793f77413d18..0f147f0ffb135 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.math.BigDecimal.RoundingMode @@ -38,9 +39,18 @@ object EstimationUtils { } } + /** Check if each attribute has column stat containing distinct and null counts + * in the corresponding statistic. */ + def columnStatsWithCountsExist(statsAndAttr: (Statistics, Attribute)*): Boolean = { + statsAndAttr.forall { case (stats, attr) => + stats.attributeStats.get(attr).map(_.hasCountStats).getOrElse(false) + } + } + + /** Statistics for a Column containing only NULLs. */ def nullColumnStat(dataType: DataType, rowCount: BigInt): ColumnStat = { - ColumnStat(distinctCount = 0, min = None, max = None, nullCount = rowCount, - avgLen = dataType.defaultSize, maxLen = dataType.defaultSize) + ColumnStat(distinctCount = Some(0), min = None, max = None, nullCount = Some(rowCount), + avgLen = Some(dataType.defaultSize), maxLen = Some(dataType.defaultSize)) } /** @@ -70,13 +80,13 @@ object EstimationUtils { // We assign a generic overhead for a Row object, the actual overhead is different for different // Row format. val sizePerRow = 8 + attributes.map { attr => - if (attrStats.contains(attr)) { + if (attrStats.get(attr).map(_.avgLen.isDefined).getOrElse(false)) { attr.dataType match { case StringType => // UTF8String: base + offset + numBytes - attrStats(attr).avgLen + 8 + 4 + attrStats(attr).avgLen.get + 8 + 4 case _ => - attrStats(attr).avgLen + attrStats(attr).avgLen.get } } else { attr.dataType.defaultSize diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 4cc32de2d32d7..0538c9d88584b 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -225,7 +225,7 @@ case class FilterEstimation(plan: Filter) extends Logging { attr: Attribute, isNull: Boolean, update: Boolean): Option[Double] = { - if (!colStatsMap.contains(attr)) { + if (!colStatsMap.contains(attr) || !colStatsMap(attr).hasCountStats) { logDebug("[CBO] No statistics for " + attr) return None } @@ -234,14 +234,14 @@ case class FilterEstimation(plan: Filter) extends Logging { val nullPercent: Double = if (rowCountValue == 0) { 0 } else { - (BigDecimal(colStat.nullCount) / BigDecimal(rowCountValue)).toDouble + (BigDecimal(colStat.nullCount.get) / BigDecimal(rowCountValue)).toDouble } if (update) { val newStats = if (isNull) { - colStat.copy(distinctCount = 0, min = None, max = None) + colStat.copy(distinctCount = Some(0), min = None, max = None) } else { - colStat.copy(nullCount = 0) + colStat.copy(nullCount = Some(0)) } colStatsMap.update(attr, newStats) } @@ -322,17 +322,21 @@ case class FilterEstimation(plan: Filter) extends Logging { // value. val newStats = attr.dataType match { case StringType | BinaryType => - colStat.copy(distinctCount = 1, nullCount = 0) + colStat.copy(distinctCount = Some(1), nullCount = Some(0)) case _ => - colStat.copy(distinctCount = 1, min = Some(literal.value), - max = Some(literal.value), nullCount = 0) + colStat.copy(distinctCount = Some(1), min = Some(literal.value), + max = Some(literal.value), nullCount = Some(0)) } colStatsMap.update(attr, newStats) } if (colStat.histogram.isEmpty) { - // returns 1/ndv if there is no histogram - Some(1.0 / colStat.distinctCount.toDouble) + if (!colStat.distinctCount.isEmpty) { + // returns 1/ndv if there is no histogram + Some(1.0 / colStat.distinctCount.get.toDouble) + } else { + None + } } else { Some(computeEqualityPossibilityByHistogram(literal, colStat)) } @@ -378,13 +382,13 @@ case class FilterEstimation(plan: Filter) extends Logging { attr: Attribute, hSet: Set[Any], update: Boolean): Option[Double] = { - if (!colStatsMap.contains(attr)) { + if (!colStatsMap.hasDistinctCount(attr)) { logDebug("[CBO] No statistics for " + attr) return None } val colStat = colStatsMap(attr) - val ndv = colStat.distinctCount + val ndv = colStat.distinctCount.get val dataType = attr.dataType var newNdv = ndv @@ -407,8 +411,8 @@ case class FilterEstimation(plan: Filter) extends Logging { // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5. newNdv = ndv.min(BigInt(validQuerySet.size)) if (update) { - val newStats = colStat.copy(distinctCount = newNdv, min = Some(newMin), - max = Some(newMax), nullCount = 0) + val newStats = colStat.copy(distinctCount = Some(newNdv), min = Some(newMin), + max = Some(newMax), nullCount = Some(0)) colStatsMap.update(attr, newStats) } @@ -416,7 +420,7 @@ case class FilterEstimation(plan: Filter) extends Logging { case StringType | BinaryType => newNdv = ndv.min(BigInt(hSet.size)) if (update) { - val newStats = colStat.copy(distinctCount = newNdv, nullCount = 0) + val newStats = colStat.copy(distinctCount = Some(newNdv), nullCount = Some(0)) colStatsMap.update(attr, newStats) } } @@ -443,12 +447,17 @@ case class FilterEstimation(plan: Filter) extends Logging { literal: Literal, update: Boolean): Option[Double] = { + if (!colStatsMap.hasMinMaxStats(attr) || !colStatsMap.hasDistinctCount(attr)) { + logDebug("[CBO] No statistics for " + attr) + return None + } + val colStat = colStatsMap(attr) val statsInterval = ValueInterval(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericValueInterval] val max = statsInterval.max val min = statsInterval.min - val ndv = colStat.distinctCount.toDouble + val ndv = colStat.distinctCount.get.toDouble // determine the overlapping degree between predicate interval and column's interval val numericLiteral = EstimationUtils.toDouble(literal.value, literal.dataType) @@ -520,8 +529,8 @@ case class FilterEstimation(plan: Filter) extends Logging { newMax = newValue } - val newStats = colStat.copy(distinctCount = ceil(ndv * percent), - min = newMin, max = newMax, nullCount = 0) + val newStats = colStat.copy(distinctCount = Some(ceil(ndv * percent)), + min = newMin, max = newMax, nullCount = Some(0)) colStatsMap.update(attr, newStats) } @@ -637,11 +646,11 @@ case class FilterEstimation(plan: Filter) extends Logging { attrRight: Attribute, update: Boolean): Option[Double] = { - if (!colStatsMap.contains(attrLeft)) { + if (!colStatsMap.hasCountStats(attrLeft)) { logDebug("[CBO] No statistics for " + attrLeft) return None } - if (!colStatsMap.contains(attrRight)) { + if (!colStatsMap.hasCountStats(attrRight)) { logDebug("[CBO] No statistics for " + attrRight) return None } @@ -668,7 +677,7 @@ case class FilterEstimation(plan: Filter) extends Logging { val minRight = statsIntervalRight.min // determine the overlapping degree between predicate interval and column's interval - val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0) + val allNotNull = (colStatLeft.nullCount.get == 0) && (colStatRight.nullCount.get == 0) val (noOverlap: Boolean, completeOverlap: Boolean) = op match { // Left < Right or Left <= Right // - no overlap: @@ -707,14 +716,14 @@ case class FilterEstimation(plan: Filter) extends Logging { case _: EqualTo => ((maxLeft < minRight) || (maxRight < minLeft), (minLeft == minRight) && (maxLeft == maxRight) && allNotNull - && (colStatLeft.distinctCount == colStatRight.distinctCount) + && (colStatLeft.distinctCount.get == colStatRight.distinctCount.get) ) case _: EqualNullSafe => // For null-safe equality, we use a very restrictive condition to evaluate its overlap. // If null values exists, we set it to partial overlap. (((maxLeft < minRight) || (maxRight < minLeft)) && allNotNull, (minLeft == minRight) && (maxLeft == maxRight) && allNotNull - && (colStatLeft.distinctCount == colStatRight.distinctCount) + && (colStatLeft.distinctCount.get == colStatRight.distinctCount.get) ) } @@ -731,9 +740,9 @@ case class FilterEstimation(plan: Filter) extends Logging { if (update) { // Need to adjust new min/max after the filter condition is applied - val ndvLeft = BigDecimal(colStatLeft.distinctCount) + val ndvLeft = BigDecimal(colStatLeft.distinctCount.get) val newNdvLeft = ceil(ndvLeft * percent) - val ndvRight = BigDecimal(colStatRight.distinctCount) + val ndvRight = BigDecimal(colStatRight.distinctCount.get) val newNdvRight = ceil(ndvRight * percent) var newMaxLeft = colStatLeft.max @@ -817,10 +826,10 @@ case class FilterEstimation(plan: Filter) extends Logging { } } - val newStatsLeft = colStatLeft.copy(distinctCount = newNdvLeft, min = newMinLeft, + val newStatsLeft = colStatLeft.copy(distinctCount = Some(newNdvLeft), min = newMinLeft, max = newMaxLeft) colStatsMap(attrLeft) = newStatsLeft - val newStatsRight = colStatRight.copy(distinctCount = newNdvRight, min = newMinRight, + val newStatsRight = colStatRight.copy(distinctCount = Some(newNdvRight), min = newMinRight, max = newMaxRight) colStatsMap(attrRight) = newStatsRight } @@ -849,17 +858,35 @@ case class ColumnStatsMap(originalMap: AttributeMap[ColumnStat]) { def contains(a: Attribute): Boolean = updatedMap.contains(a.exprId) || originalMap.contains(a) /** - * Gets column stat for the given attribute. Prefer the column stat in updatedMap than that in - * originalMap, because updatedMap has the latest (updated) column stats. + * Gets an Option of column stat for the given attribute. + * Prefer the column stat in updatedMap than that in originalMap, + * because updatedMap has the latest (updated) column stats. */ - def apply(a: Attribute): ColumnStat = { + def get(a: Attribute): Option[ColumnStat] = { if (updatedMap.contains(a.exprId)) { - updatedMap(a.exprId)._2 + updatedMap.get(a.exprId).map(_._2) } else { - originalMap(a) + originalMap.get(a) } } + def hasCountStats(a: Attribute): Boolean = + get(a).map(_.hasCountStats).getOrElse(false) + + def hasDistinctCount(a: Attribute): Boolean = + get(a).map(_.distinctCount.isDefined).getOrElse(false) + + def hasMinMaxStats(a: Attribute): Boolean = + get(a).map(_.hasCountStats).getOrElse(false) + + /** + * Gets column stat for the given attribute. Prefer the column stat in updatedMap than that in + * originalMap, because updatedMap has the latest (updated) column stats. + */ + def apply(a: Attribute): ColumnStat = { + get(a).get + } + /** Updates column stats in updatedMap. */ def update(a: Attribute, stats: ColumnStat): Unit = updatedMap.update(a.exprId, a -> stats) @@ -871,11 +898,14 @@ case class ColumnStatsMap(originalMap: AttributeMap[ColumnStat]) { : AttributeMap[ColumnStat] = { val newColumnStats = originalMap.map { case (attr, oriColStat) => val colStat = updatedMap.get(attr.exprId).map(_._2).getOrElse(oriColStat) - val newNdv = if (colStat.distinctCount > 1) { + val newNdv = if (colStat.distinctCount.isEmpty) { + // No NDV in the original stats. + None + } else if (colStat.distinctCount.get > 1) { // Update ndv based on the overall filter selectivity: scale down ndv if the number of rows // decreases; otherwise keep it unchanged. - EstimationUtils.updateNdv(oldNumRows = rowsBeforeFilter, - newNumRows = rowsAfterFilter, oldNdv = oriColStat.distinctCount) + Some(EstimationUtils.updateNdv(oldNumRows = rowsBeforeFilter, + newNumRows = rowsAfterFilter, oldNdv = oriColStat.distinctCount.get)) } else { // no need to scale down since it is already down to 1 (for skewed distribution case) colStat.distinctCount diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index f0294a4246703..2543e38a92c0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -85,7 +85,8 @@ case class JoinEstimation(join: Join) extends Logging { // 3. Update statistics based on the output of join val inputAttrStats = AttributeMap( leftStats.attributeStats.toSeq ++ rightStats.attributeStats.toSeq) - val attributesWithStat = join.output.filter(a => inputAttrStats.contains(a)) + val attributesWithStat = join.output.filter(a => + inputAttrStats.get(a).map(_.hasCountStats).getOrElse(false)) val (fromLeft, fromRight) = attributesWithStat.partition(join.left.outputSet.contains(_)) val outputStats: Seq[(Attribute, ColumnStat)] = if (outputRows == 0) { @@ -106,10 +107,10 @@ case class JoinEstimation(join: Join) extends Logging { case FullOuter => fromLeft.map { a => val oriColStat = inputAttrStats(a) - (a, oriColStat.copy(nullCount = oriColStat.nullCount + rightRows)) + (a, oriColStat.copy(nullCount = Some(oriColStat.nullCount.get + rightRows))) } ++ fromRight.map { a => val oriColStat = inputAttrStats(a) - (a, oriColStat.copy(nullCount = oriColStat.nullCount + leftRows)) + (a, oriColStat.copy(nullCount = Some(oriColStat.nullCount.get + leftRows))) } case _ => assert(joinType == Inner || joinType == Cross) @@ -219,19 +220,27 @@ case class JoinEstimation(join: Join) extends Logging { private def computeByNdv( leftKey: AttributeReference, rightKey: AttributeReference, - newMin: Option[Any], - newMax: Option[Any]): (BigInt, ColumnStat) = { + min: Option[Any], + max: Option[Any]): (BigInt, ColumnStat) = { val leftKeyStat = leftStats.attributeStats(leftKey) val rightKeyStat = rightStats.attributeStats(rightKey) - val maxNdv = leftKeyStat.distinctCount.max(rightKeyStat.distinctCount) + val maxNdv = leftKeyStat.distinctCount.get.max(rightKeyStat.distinctCount.get) // Compute cardinality by the basic formula. val card = BigDecimal(leftStats.rowCount.get * rightStats.rowCount.get) / BigDecimal(maxNdv) // Get the intersected column stat. - val newNdv = leftKeyStat.distinctCount.min(rightKeyStat.distinctCount) - val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen) - val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2 - val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) + val newNdv = Some(leftKeyStat.distinctCount.get.min(rightKeyStat.distinctCount.get)) + val newMaxLen = if (leftKeyStat.maxLen.isDefined && rightKeyStat.maxLen.isDefined) { + Some(math.min(leftKeyStat.maxLen.get, rightKeyStat.maxLen.get)) + } else { + None + } + val newAvgLen = if (leftKeyStat.avgLen.isDefined && rightKeyStat.avgLen.isDefined) { + Some((leftKeyStat.avgLen.get + rightKeyStat.avgLen.get) / 2) + } else { + None + } + val newStats = ColumnStat(newNdv, min, max, Some(0), newAvgLen, newMaxLen) (ceil(card), newStats) } @@ -267,9 +276,17 @@ case class JoinEstimation(join: Join) extends Logging { val leftKeyStat = leftStats.attributeStats(leftKey) val rightKeyStat = rightStats.attributeStats(rightKey) - val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen) - val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2 - val newStats = ColumnStat(ceil(totalNdv), newMin, newMax, 0, newAvgLen, newMaxLen) + val newMaxLen = if (leftKeyStat.maxLen.isDefined && rightKeyStat.maxLen.isDefined) { + Some(math.min(leftKeyStat.maxLen.get, rightKeyStat.maxLen.get)) + } else { + None + } + val newAvgLen = if (leftKeyStat.avgLen.isDefined && rightKeyStat.avgLen.isDefined) { + Some((leftKeyStat.avgLen.get + rightKeyStat.avgLen.get) / 2) + } else { + None + } + val newStats = ColumnStat(Some(ceil(totalNdv)), newMin, newMax, Some(0), newAvgLen, newMaxLen) (ceil(card), newStats) } @@ -292,10 +309,14 @@ case class JoinEstimation(join: Join) extends Logging { } else { val oldColStat = oldAttrStats(a) val oldNdv = oldColStat.distinctCount - val newNdv = if (join.left.outputSet.contains(a)) { - updateNdv(oldNumRows = leftRows, newNumRows = outputRows, oldNdv = oldNdv) + val newNdv = if (oldNdv.isDefined) { + Some(if (join.left.outputSet.contains(a)) { + updateNdv(oldNumRows = leftRows, newNumRows = outputRows, oldNdv = oldNdv.get) + } else { + updateNdv(oldNumRows = rightRows, newNumRows = outputRows, oldNdv = oldNdv.get) + }) } else { - updateNdv(oldNumRows = rightRows, newNumRows = outputRows, oldNdv = oldNdv) + None } val newColStat = oldColStat.copy(distinctCount = newNdv) // TODO: support nullCount updates for specific outer joins @@ -313,7 +334,7 @@ case class JoinEstimation(join: Join) extends Logging { // Note: join keys from EqualNullSafe also fall into this case (Coalesce), consider to // support it in the future by using `nullCount` in column stats. case (lk: AttributeReference, rk: AttributeReference) - if columnStatsExist((leftStats, lk), (rightStats, rk)) => (lk, rk) + if columnStatsWithCountsExist((leftStats, lk), (rightStats, rk)) => (lk, rk) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index 2fb587d50a4cb..565b0a10154a8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -62,24 +62,15 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { } } - /** Set up tables and columns for testing */ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - attr("t1.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t1.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t2.k-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t3.v-1-100") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t4.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t4.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t5.k-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t5.v-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4) + attr("t1.k-1-2") -> rangeColumnStat(2, 0), + attr("t1.v-1-10") -> rangeColumnStat(10, 0), + attr("t2.k-1-5") -> rangeColumnStat(5, 0), + attr("t3.v-1-100") -> rangeColumnStat(100, 0), + attr("t4.k-1-2") -> rangeColumnStat(2, 0), + attr("t4.v-1-10") -> rangeColumnStat(10, 0), + attr("t5.k-1-5") -> rangeColumnStat(5, 0), + attr("t5.v-1-5") -> rangeColumnStat(5, 0) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala index ada6e2a43ea0f..d4d23ad69b2c2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala @@ -68,88 +68,56 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( // F1 (fact table) - attr("f1_fk1") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_fk2") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_fk3") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_c1") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_c2") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk1") -> rangeColumnStat(100, 0), + attr("f1_fk2") -> rangeColumnStat(100, 0), + attr("f1_fk3") -> rangeColumnStat(100, 0), + attr("f1_c1") -> rangeColumnStat(100, 0), + attr("f1_c2") -> rangeColumnStat(100, 0), // D1 (dimension) - attr("d1_pk") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c2") -> ColumnStat(distinctCount = 50, min = Some(1), max = Some(50), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c3") -> ColumnStat(distinctCount = 50, min = Some(1), max = Some(50), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_pk") -> rangeColumnStat(100, 0), + attr("d1_c2") -> rangeColumnStat(50, 0), + attr("d1_c3") -> rangeColumnStat(50, 0), // D2 (dimension) - attr("d2_pk") -> ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d2_c2") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d2_c3") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_pk") -> rangeColumnStat(20, 0), + attr("d2_c2") -> rangeColumnStat(10, 0), + attr("d2_c3") -> rangeColumnStat(10, 0), // D3 (dimension) - attr("d3_pk") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_pk") -> rangeColumnStat(10, 0), + attr("d3_c2") -> rangeColumnStat(5, 0), + attr("d3_c3") -> rangeColumnStat(5, 0), // T1 (regular table i.e. outside star) - attr("t1_c1") -> ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t1_c2") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t1_c3") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t1_c1") -> rangeColumnStat(20, 1), + attr("t1_c2") -> rangeColumnStat(10, 1), + attr("t1_c3") -> rangeColumnStat(10, 1), // T2 (regular table) - attr("t2_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t2_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t2_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t2_c1") -> rangeColumnStat(5, 1), + attr("t2_c2") -> rangeColumnStat(5, 1), + attr("t2_c3") -> rangeColumnStat(5, 1), // T3 (regular table) - attr("t3_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t3_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t3_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t3_c1") -> rangeColumnStat(5, 1), + attr("t3_c2") -> rangeColumnStat(5, 1), + attr("t3_c3") -> rangeColumnStat(5, 1), // T4 (regular table) - attr("t4_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t4_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t4_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t4_c1") -> rangeColumnStat(5, 1), + attr("t4_c2") -> rangeColumnStat(5, 1), + attr("t4_c3") -> rangeColumnStat(5, 1), // T5 (regular table) - attr("t5_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t5_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t5_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t5_c1") -> rangeColumnStat(5, 1), + attr("t5_c2") -> rangeColumnStat(5, 1), + attr("t5_c3") -> rangeColumnStat(5, 1), // T6 (regular table) - attr("t6_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t6_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t6_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4) + attr("t6_c1") -> rangeColumnStat(5, 1), + attr("t6_c2") -> rangeColumnStat(5, 1), + attr("t6_c3") -> rangeColumnStat(5, 1) )) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala index 777c5637201ed..4e0883e91e84a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala @@ -70,59 +70,40 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { // Tables' cardinality: f1 > d3 > d1 > d2 > s3 private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( // F1 - attr("f1_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_fk2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_fk3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_c4") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk1") -> rangeColumnStat(3, 0), + attr("f1_fk2") -> rangeColumnStat(3, 0), + attr("f1_fk3") -> rangeColumnStat(4, 0), + attr("f1_c4") -> rangeColumnStat(4, 0), // D1 - attr("d1_pk1") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c4") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_pk1") -> rangeColumnStat(4, 0), + attr("d1_c2") -> rangeColumnStat(3, 0), + attr("d1_c3") -> rangeColumnStat(4, 0), + attr("d1_c4") -> ColumnStat(distinctCount = Some(2), min = Some("2"), max = Some("3"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), // D2 - attr("d2_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("d2_pk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d2_c3") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d2_c4") -> ColumnStat(distinctCount = 2, min = Some(3), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_c2") -> ColumnStat(distinctCount = Some(3), min = Some("1"), max = Some("3"), + nullCount = Some(1), avgLen = Some(4), maxLen = Some(4)), + attr("d2_pk1") -> rangeColumnStat(3, 0), + attr("d2_c3") -> rangeColumnStat(3, 0), + attr("d2_c4") -> ColumnStat(distinctCount = Some(2), min = Some("3"), max = Some("4"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), // D3 - attr("d3_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_pk1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_c4") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_fk1") -> rangeColumnStat(3, 0), + attr("d3_c2") -> rangeColumnStat(3, 0), + attr("d3_pk1") -> rangeColumnStat(5, 0), + attr("d3_c4") -> ColumnStat(distinctCount = Some(2), min = Some("2"), max = Some("3"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), // S3 - attr("s3_pk1") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("s3_c2") -> ColumnStat(distinctCount = 1, min = Some(3), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("s3_c3") -> ColumnStat(distinctCount = 1, min = Some(3), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("s3_c4") -> ColumnStat(distinctCount = 2, min = Some(3), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("s3_pk1") -> rangeColumnStat(2, 0), + attr("s3_c2") -> rangeColumnStat(1, 0), + attr("s3_c3") -> rangeColumnStat(1, 0), + attr("s3_c4") -> ColumnStat(distinctCount = Some(2), min = Some("3"), max = Some("4"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), // F11 - attr("f11_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f11_fk2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f11_fk3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f11_c4") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4) + attr("f11_fk1") -> rangeColumnStat(3, 0), + attr("f11_fk2") -> rangeColumnStat(3, 0), + attr("f11_fk3") -> rangeColumnStat(4, 0), + attr("f11_c4") -> rangeColumnStat(4, 0) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala index 23f95a6cc2ac2..8213d568fe85e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala @@ -29,16 +29,16 @@ class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest { /** Columns for testing */ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - attr("key11") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key12") -> ColumnStat(distinctCount = 4, min = Some(10), max = Some(40), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key21") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key22") -> ColumnStat(distinctCount = 2, min = Some(10), max = Some(20), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key31") -> ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 0, - avgLen = 4, maxLen = 4) + attr("key11") -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key12") -> ColumnStat(distinctCount = Some(4), min = Some(10), max = Some(40), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key21") -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key22") -> ColumnStat(distinctCount = Some(2), min = Some(10), max = Some(20), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key31") -> ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) @@ -63,8 +63,8 @@ class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest { tableRowCount = 6, groupByColumns = Seq("key21", "key22"), // Row count = product of ndv - expectedOutputRowCount = nameToColInfo("key21")._2.distinctCount * nameToColInfo("key22")._2 - .distinctCount) + expectedOutputRowCount = nameToColInfo("key21")._2.distinctCount.get * + nameToColInfo("key22")._2.distinctCount.get) } test("empty group-by column") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 7d532ff343178..953094cb0dd52 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.types.IntegerType class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { val attribute = attr("key") - val colStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStat = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val plan = StatsTestPlan( outputList = Seq(attribute), @@ -116,13 +116,17 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { sizeInBytes = 40, rowCount = Some(10), attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4)))) + AttributeReference("c1", IntegerType)() -> ColumnStat(distinctCount = Some(10), + min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))))) val expectedCboStats = Statistics( sizeInBytes = 4, rowCount = Some(1), attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4)))) + AttributeReference("c1", IntegerType)() -> ColumnStat(distinctCount = Some(10), + min = Some(5), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))))) val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats) checkStats( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 2b1fe987a7960..43440d51dede6 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -37,59 +37,61 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 val attrInt = AttributeReference("cint", IntegerType)() - val colStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cbool has only 2 distinct values val attrBool = AttributeReference("cbool", BooleanType)() - val colStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1) + val colStatBool = ColumnStat(distinctCount = Some(2), min = Some(false), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)) // column cdate has 10 values from 2017-01-01 through 2017-01-10. val dMin = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01")) val dMax = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-10")) val attrDate = AttributeReference("cdate", DateType)() - val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatDate = ColumnStat(distinctCount = Some(10), + min = Some(dMin), max = Some(dMax), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. val decMin = Decimal("0.200000000000000000") val decMax = Decimal("0.800000000000000000") val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() - val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), - nullCount = 0, avgLen = 8, maxLen = 8) + val colStatDecimal = ColumnStat(distinctCount = Some(4), + min = Some(decMin), max = Some(decMax), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) // column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 val attrDouble = AttributeReference("cdouble", DoubleType)() - val colStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), - nullCount = 0, avgLen = 8, maxLen = 8) + val colStatDouble = ColumnStat(distinctCount = Some(10), min = Some(1.0), max = Some(10.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) // column cstring has 10 String values: // "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9" val attrString = AttributeReference("cstring", StringType)() - val colStatString = ColumnStat(distinctCount = 10, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2) + val colStatString = ColumnStat(distinctCount = Some(10), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)) // column cint2 has values: 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 // Hence, distinctCount:10, min:7, max:16, nullCount:0, avgLen:4, maxLen:4 // This column is created to test "cint < cint2 val attrInt2 = AttributeReference("cint2", IntegerType)() - val colStatInt2 = ColumnStat(distinctCount = 10, min = Some(7), max = Some(16), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt2 = ColumnStat(distinctCount = Some(10), min = Some(7), max = Some(16), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cint3 has values: 30, 31, 32, 33, 34, 35, 36, 37, 38, 39 // Hence, distinctCount:10, min:30, max:39, nullCount:0, avgLen:4, maxLen:4 // This column is created to test "cint = cint3 without overlap at all. val attrInt3 = AttributeReference("cint3", IntegerType)() - val colStatInt3 = ColumnStat(distinctCount = 10, min = Some(30), max = Some(39), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt3 = ColumnStat(distinctCount = Some(10), min = Some(30), max = Some(39), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cint4 has values in the range from 1 to 10 // distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 // This column is created to test complete overlap val attrInt4 = AttributeReference("cint4", IntegerType)() - val colStatInt4 = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt4 = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cintHgm has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 with histogram. // Note that cintHgm has an even distribution with histogram information built. @@ -98,8 +100,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val hgmInt = Histogram(2.0, Array(HistogramBin(1.0, 2.0, 2), HistogramBin(2.0, 4.0, 2), HistogramBin(4.0, 6.0, 2), HistogramBin(6.0, 8.0, 2), HistogramBin(8.0, 10.0, 2))) - val colStatIntHgm = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt)) + val colStatIntHgm = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt)) // column cintSkewHgm has values: 1, 4, 4, 5, 5, 5, 5, 6, 6, 10 with histogram. // Note that cintSkewHgm has a skewed distribution with histogram information built. @@ -108,8 +110,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val hgmIntSkew = Histogram(2.0, Array(HistogramBin(1.0, 4.0, 2), HistogramBin(4.0, 5.0, 2), HistogramBin(5.0, 5.0, 1), HistogramBin(5.0, 6.0, 2), HistogramBin(6.0, 10.0, 2))) - val colStatIntSkewHgm = ColumnStat(distinctCount = 5, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew)) + val colStatIntSkewHgm = ColumnStat(distinctCount = Some(5), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew)) val attributeMap = AttributeMap(Seq( attrInt -> colStatInt, @@ -172,7 +174,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType)) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 3)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(3))), expectedRowCount = 3) } @@ -180,7 +182,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 8)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(8))), expectedRowCount = 8) } @@ -196,23 +198,23 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrInt, Literal(3)), Not(Literal(null, IntegerType)))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 8)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(8))), expectedRowCount = 8) } test("cint = 2") { validateEstimatedStats( Filter(EqualTo(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(1), min = Some(2), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 1) } test("cint <=> 2") { validateEstimatedStats( Filter(EqualNullSafe(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(1), min = Some(2), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 1) } @@ -227,8 +229,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint < 3") { validateEstimatedStats( Filter(LessThan(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } @@ -243,16 +245,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint <= 3") { validateEstimatedStats( Filter(LessThanOrEqual(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } test("cint > 6") { validateEstimatedStats( Filter(GreaterThan(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(5), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 5) } @@ -267,8 +269,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint >= 6") { validateEstimatedStats( Filter(GreaterThanOrEqual(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(5), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 5) } @@ -282,8 +284,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint IS NOT NULL") { validateEstimatedStats( Filter(IsNotNull(attrInt), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 10) } @@ -301,8 +303,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(3), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -310,7 +312,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(EqualTo(attrInt, Literal(3)), EqualTo(attrInt, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 2)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(2))), expectedRowCount = 2) } @@ -318,7 +320,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6)))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 6)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(6))), expectedRowCount = 6) } @@ -326,7 +328,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(Or(LessThanOrEqual(attrInt, Literal(3)), GreaterThan(attrInt, Literal(6)))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 5)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(5))), expectedRowCount = 5) } @@ -342,47 +344,47 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(Or(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8")))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 9), - attrString -> colStatString.copy(distinctCount = 9)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(9)), + attrString -> colStatString.copy(distinctCount = Some(9))), expectedRowCount = 9) } test("cint IN (3, 4, 5)") { validateEstimatedStats( Filter(InSet(attrInt, Set(3, 4, 5)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(3), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 7)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(7))), expectedRowCount = 7) } test("cbool IN (true)") { validateEstimatedStats( Filter(InSet(attrBool, Set(true)), childStatsTestPlan(Seq(attrBool), 10L)), - Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1)), + Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(true), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))), expectedRowCount = 5) } test("cbool = true") { validateEstimatedStats( Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)), - Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1)), + Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(true), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))), expectedRowCount = 5) } test("cbool > false") { validateEstimatedStats( Filter(GreaterThan(attrBool, Literal(false)), childStatsTestPlan(Seq(attrBool), 10L)), - Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(false), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1)), + Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(false), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))), expectedRowCount = 5) } @@ -391,18 +393,21 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(EqualTo(attrDate, Literal(d20170102, DateType)), childStatsTestPlan(Seq(attrDate), 10L)), - Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrDate -> ColumnStat(distinctCount = Some(1), + min = Some(d20170102), max = Some(d20170102), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 1) } test("cdate < cast('2017-01-03' AS DATE)") { + val d20170101 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01")) val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03")) validateEstimatedStats( Filter(LessThan(attrDate, Literal(d20170103, DateType)), childStatsTestPlan(Seq(attrDate), 10L)), - Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(dMin), max = Some(d20170103), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrDate -> ColumnStat(distinctCount = Some(3), + min = Some(d20170101), max = Some(d20170103), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } @@ -414,8 +419,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(In(attrDate, Seq(Literal(d20170103, DateType), Literal(d20170104, DateType), Literal(d20170105, DateType))), childStatsTestPlan(Seq(attrDate), 10L)), - Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrDate -> ColumnStat(distinctCount = Some(3), + min = Some(d20170103), max = Some(d20170105), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } @@ -424,42 +430,45 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(EqualTo(attrDecimal, Literal(dec_0_40)), childStatsTestPlan(Seq(attrDecimal), 4L)), - Seq(attrDecimal -> ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40), - nullCount = 0, avgLen = 8, maxLen = 8)), + Seq(attrDecimal -> ColumnStat(distinctCount = Some(1), + min = Some(dec_0_40), max = Some(dec_0_40), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))), expectedRowCount = 1) } test("cdecimal < 0.60 ") { + val dec_0_20 = Decimal("0.200000000000000000") val dec_0_60 = Decimal("0.600000000000000000") validateEstimatedStats( Filter(LessThan(attrDecimal, Literal(dec_0_60)), childStatsTestPlan(Seq(attrDecimal), 4L)), - Seq(attrDecimal -> ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60), - nullCount = 0, avgLen = 8, maxLen = 8)), + Seq(attrDecimal -> ColumnStat(distinctCount = Some(3), + min = Some(dec_0_20), max = Some(dec_0_60), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))), expectedRowCount = 3) } test("cdouble < 3.0") { validateEstimatedStats( Filter(LessThan(attrDouble, Literal(3.0)), childStatsTestPlan(Seq(attrDouble), 10L)), - Seq(attrDouble -> ColumnStat(distinctCount = 3, min = Some(1.0), max = Some(3.0), - nullCount = 0, avgLen = 8, maxLen = 8)), + Seq(attrDouble -> ColumnStat(distinctCount = Some(3), min = Some(1.0), max = Some(3.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))), expectedRowCount = 3) } test("cstring = 'A2'") { validateEstimatedStats( Filter(EqualTo(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)), - Seq(attrString -> ColumnStat(distinctCount = 1, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2)), + Seq(attrString -> ColumnStat(distinctCount = Some(1), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))), expectedRowCount = 1) } test("cstring < 'A2' - unsupported condition") { validateEstimatedStats( Filter(LessThan(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)), - Seq(attrString -> ColumnStat(distinctCount = 10, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2)), + Seq(attrString -> ColumnStat(distinctCount = Some(10), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))), expectedRowCount = 10) } @@ -468,8 +477,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // valid values in IN clause is greater than the number of distinct values for a given column. // For example, column has only 2 distinct values 1 and 6. // The predicate is: column IN (1, 2, 3, 4, 5). - val cornerChildColStatInt = ColumnStat(distinctCount = 2, min = Some(1), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4) + val cornerChildColStatInt = ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val cornerChildStatsTestplan = StatsTestPlan( outputList = Seq(attrInt), rowCount = 2L, @@ -477,16 +487,17 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ) validateEstimatedStats( Filter(InSet(attrInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), - Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 2) } // This is a limitation test. We should remove it after the limitation is removed. test("don't estimate IsNull or IsNotNull if the child is a non-leaf node") { val attrIntLargerRange = AttributeReference("c1", IntegerType)() - val colStatIntLargerRange = ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), - nullCount = 10, avgLen = 4, maxLen = 4) + val colStatIntLargerRange = ColumnStat(distinctCount = Some(20), + min = Some(1), max = Some(20), + nullCount = Some(10), avgLen = Some(4), maxLen = Some(4)) val smallerTable = childStatsTestPlan(Seq(attrInt), 10L) val largerTable = StatsTestPlan( outputList = Seq(attrIntLargerRange), @@ -508,10 +519,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(EqualTo(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt2 -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -519,10 +530,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(GreaterThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt2 -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -530,10 +541,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(LessThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(16), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt2 -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(16), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -541,10 +552,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // complete overlap case validateEstimatedStats( Filter(EqualTo(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt4 -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt4 -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 10) } @@ -552,10 +563,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(LessThan(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt4 -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt4 -> ColumnStat(distinctCount = Some(4), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -571,10 +582,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // all table records qualify. validateEstimatedStats( Filter(LessThan(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt3 -> ColumnStat(distinctCount = 10, min = Some(30), max = Some(39), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt3 -> ColumnStat(distinctCount = Some(10), min = Some(30), max = Some(39), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 10) } @@ -592,11 +603,11 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt, attrInt4, attrString), 10L)), Seq( - attrInt -> ColumnStat(distinctCount = 5, min = Some(3), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt4 -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4), - attrString -> colStatString.copy(distinctCount = 5)), + attrInt -> ColumnStat(distinctCount = Some(5), min = Some(3), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt4 -> ColumnStat(distinctCount = Some(5), min = Some(1), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrString -> colStatString.copy(distinctCount = Some(5))), expectedRowCount = 5) } @@ -606,15 +617,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrIntHgm, Literal(3)), Literal(null, IntegerType))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = 7)), + Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = Some(7))), expectedRowCount = 7) } test("cintHgm = 5") { validateEstimatedStats( Filter(EqualTo(attrIntHgm, Literal(5)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 1, min = Some(5), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(1), min = Some(5), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 1) } @@ -629,8 +640,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintHgm < 3") { validateEstimatedStats( Filter(LessThan(attrIntHgm, Literal(3)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 3) } @@ -645,16 +656,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintHgm <= 3") { validateEstimatedStats( Filter(LessThanOrEqual(attrIntHgm, Literal(3)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 3) } test("cintHgm > 6") { validateEstimatedStats( Filter(GreaterThan(attrIntHgm, Literal(6)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(4), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 4) } @@ -669,8 +680,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintHgm >= 6") { validateEstimatedStats( Filter(GreaterThanOrEqual(attrIntHgm, Literal(6)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(5), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 5) } @@ -679,8 +690,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Literal(3)), LessThanOrEqual(attrIntHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(4), min = Some(3), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 4) } @@ -688,7 +699,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(EqualTo(attrIntHgm, Literal(3)), EqualTo(attrIntHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = 3)), + Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = Some(3))), expectedRowCount = 3) } @@ -698,15 +709,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrIntSkewHgm, Literal(3)), Literal(null, IntegerType))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = 5)), + Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = Some(5))), expectedRowCount = 9) } test("cintSkewHgm = 5") { validateEstimatedStats( Filter(EqualTo(attrIntSkewHgm, Literal(5)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(5), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(5), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 4) } @@ -721,8 +732,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintSkewHgm < 3") { validateEstimatedStats( Filter(LessThan(attrIntSkewHgm, Literal(3)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 2) } @@ -738,16 +749,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(LessThanOrEqual(attrIntSkewHgm, Literal(3)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 2) } test("cintSkewHgm > 6") { validateEstimatedStats( Filter(GreaterThan(attrIntSkewHgm, Literal(6)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 2) } @@ -764,8 +775,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(GreaterThanOrEqual(attrIntSkewHgm, Literal(6)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 2, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(2), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 3) } @@ -774,8 +785,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Literal(3)), LessThanOrEqual(attrIntSkewHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(4), min = Some(3), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 8) } @@ -783,7 +794,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(EqualTo(attrIntSkewHgm, Literal(3)), EqualTo(attrIntSkewHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = 2)), + Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = Some(2))), expectedRowCount = 3) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala index 26139d85d25fb..12c0a7be21292 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala @@ -33,16 +33,16 @@ class JoinEstimationSuite extends StatsEstimationTestBase { /** Set up tables and its columns for testing */ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - attr("key-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-5-9") -> ColumnStat(distinctCount = 5, min = Some(5), max = Some(9), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-2-4") -> ColumnStat(distinctCount = 3, min = Some(2), max = Some(4), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-2-3") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + attr("key-1-5") -> ColumnStat(distinctCount = Some(5), min = Some(1), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-5-9") -> ColumnStat(distinctCount = Some(5), min = Some(5), max = Some(9), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-1-2") -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-2-4") -> ColumnStat(distinctCount = Some(3), min = Some(2), max = Some(4), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-2-3") -> ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) @@ -70,8 +70,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { private def estimateByHistogram( leftHistogram: Histogram, rightHistogram: Histogram, - expectedMin: Double, - expectedMax: Double, + expectedMin: Any, + expectedMax: Any, expectedNdv: Long, expectedRows: Long): Unit = { val col1 = attr("key1") @@ -86,9 +86,11 @@ class JoinEstimationSuite extends StatsEstimationTestBase { rowCount = Some(expectedRows), attributeStats = AttributeMap(Seq( col1 -> c1.stats.attributeStats(col1).copy( - distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax)), + distinctCount = Some(expectedNdv), + min = Some(expectedMin), max = Some(expectedMax)), col2 -> c2.stats.attributeStats(col2).copy( - distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax)))) + distinctCount = Some(expectedNdv), + min = Some(expectedMin), max = Some(expectedMax)))) ) // Join order should not affect estimation result. @@ -100,9 +102,9 @@ class JoinEstimationSuite extends StatsEstimationTestBase { private def generateJoinChild( col: Attribute, histogram: Histogram, - expectedMin: Double, - expectedMax: Double): LogicalPlan = { - val colStat = inferColumnStat(histogram) + expectedMin: Any, + expectedMax: Any): LogicalPlan = { + val colStat = inferColumnStat(histogram, expectedMin, expectedMax) StatsTestPlan( outputList = Seq(col), rowCount = (histogram.height * histogram.bins.length).toLong, @@ -110,7 +112,11 @@ class JoinEstimationSuite extends StatsEstimationTestBase { } /** Column statistics should be consistent with histograms in tests. */ - private def inferColumnStat(histogram: Histogram): ColumnStat = { + private def inferColumnStat( + histogram: Histogram, + expectedMin: Any, + expectedMax: Any): ColumnStat = { + var ndv = 0L for (i <- histogram.bins.indices) { val bin = histogram.bins(i) @@ -118,8 +124,9 @@ class JoinEstimationSuite extends StatsEstimationTestBase { ndv += bin.ndv } } - ColumnStat(distinctCount = ndv, min = Some(histogram.bins.head.lo), - max = Some(histogram.bins.last.hi), nullCount = 0, avgLen = 4, maxLen = 4, + ColumnStat(distinctCount = Some(ndv), + min = Some(expectedMin), max = Some(expectedMax), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(histogram)) } @@ -343,10 +350,10 @@ class JoinEstimationSuite extends StatsEstimationTestBase { rowCount = Some(5 + 3), attributeStats = AttributeMap( // Update null count in column stats. - Seq(nameToAttr("key-1-5") -> columnInfo(nameToAttr("key-1-5")).copy(nullCount = 3), - nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = 3), - nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = 5), - nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = 5)))) + Seq(nameToAttr("key-1-5") -> columnInfo(nameToAttr("key-1-5")).copy(nullCount = Some(3)), + nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = Some(3)), + nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = Some(5)), + nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = Some(5))))) assert(join.stats == expectedStats) } @@ -356,11 +363,11 @@ class JoinEstimationSuite extends StatsEstimationTestBase { val join = Join(table1, table2, Inner, Some(EqualTo(nameToAttr("key-1-5"), nameToAttr("key-1-2")))) // Update column stats for equi-join keys (key-1-5 and key-1-2). - val joinedColStat = ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // Update column stat for other column if #outputRow / #sideRow < 1 (key-5-9), or keep it // unchanged (key-2-4). - val colStatForkey59 = nameToColInfo("key-5-9")._2.copy(distinctCount = 5 * 3 / 5) + val colStatForkey59 = nameToColInfo("key-5-9")._2.copy(distinctCount = Some(5 * 3 / 5)) val expectedStats = Statistics( sizeInBytes = 3 * (8 + 4 * 4), @@ -379,10 +386,10 @@ class JoinEstimationSuite extends StatsEstimationTestBase { EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))))) // Update column stats for join keys. - val joinedColStat1 = ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4) - val joinedColStat2 = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat1 = ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) + val joinedColStat2 = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val expectedStats = Statistics( sizeInBytes = 2 * (8 + 4 * 4), @@ -398,8 +405,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3) val join = Join(table3, table2, LeftOuter, Some(EqualTo(nameToAttr("key-2-3"), nameToAttr("key-2-4")))) - val joinedColStat = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val expectedStats = Statistics( sizeInBytes = 2 * (8 + 4 * 4), @@ -416,8 +423,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3) val join = Join(table2, table3, RightOuter, Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3")))) - val joinedColStat = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val expectedStats = Statistics( sizeInBytes = 2 * (8 + 4 * 4), @@ -466,30 +473,40 @@ class JoinEstimationSuite extends StatsEstimationTestBase { val date = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08")) val timestamp = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01")) mutable.LinkedHashMap[Attribute, ColumnStat]( - AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 1, - min = Some(false), max = Some(false), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 1, - min = Some(1.toByte), max = Some(1.toByte), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 1, - min = Some(1.toShort), max = Some(1.toShort), nullCount = 0, avgLen = 2, maxLen = 2), - AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 1, - min = Some(1), max = Some(1), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 1, - min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 1, - min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 1, - min = Some(1.0f), max = Some(1.0f), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 1, - min = Some(dec), max = Some(dec), nullCount = 0, avgLen = 16, maxLen = 16), - AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 1, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = 1, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = 1, - min = Some(date), max = Some(date), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 1, - min = Some(timestamp), max = Some(timestamp), nullCount = 0, avgLen = 8, maxLen = 8) + AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = Some(1), + min = Some(false), max = Some(false), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.toByte), max = Some(1.toByte), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.toShort), max = Some(1.toShort), + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)), + AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1), max = Some(1), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1L), max = Some(1L), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.0), max = Some(1.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.0f), max = Some(1.0f), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat( + distinctCount = Some(1), min = Some(dec), max = Some(dec), + nullCount = Some(0), avgLen = Some(16), maxLen = Some(16)), + AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = Some(1), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = Some(1), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = Some(1), + min = Some(date), max = Some(date), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = Some(1), + min = Some(timestamp), max = Some(timestamp), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) ) } @@ -520,7 +537,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { test("join with null column") { val (nullColumn, nullColStat) = (attr("cnull"), - ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 1, avgLen = 4, maxLen = 4)) + ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(1), avgLen = Some(4), maxLen = Some(4))) val nullTable = StatsTestPlan( outputList = Seq(nullColumn), rowCount = 1, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala index cda54fa9d64f4..dcb37017329fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala @@ -28,10 +28,10 @@ import org.apache.spark.sql.types._ class ProjectEstimationSuite extends StatsEstimationTestBase { test("project with alias") { - val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = 2, min = Some(1), - max = Some(2), nullCount = 0, avgLen = 4, maxLen = 4)) - val (ar2, colStat2) = (attr("key2"), ColumnStat(distinctCount = 1, min = Some(10), - max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4)) + val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = Some(2), min = Some(1), + max = Some(2), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))) + val (ar2, colStat2) = (attr("key2"), ColumnStat(distinctCount = Some(1), min = Some(10), + max = Some(10), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))) val child = StatsTestPlan( outputList = Seq(ar1, ar2), @@ -49,8 +49,8 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { } test("project on empty table") { - val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = 0, min = None, max = None, - nullCount = 0, avgLen = 4, maxLen = 4)) + val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))) val child = StatsTestPlan( outputList = Seq(ar1), rowCount = 0, @@ -71,30 +71,40 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { val t2 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-09 00:00:02")) val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 2, - min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 2, - min = Some(1.toByte), max = Some(2.toByte), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 2, - min = Some(1.toShort), max = Some(3.toShort), nullCount = 0, avgLen = 2, maxLen = 2), - AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 2, - min = Some(1), max = Some(4), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 2, - min = Some(1L), max = Some(5L), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 2, - min = Some(1.0), max = Some(6.0), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 2, - min = Some(1.0f), max = Some(7.0f), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 2, - min = Some(dec1), max = Some(dec2), nullCount = 0, avgLen = 16, maxLen = 16), - AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 2, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = 2, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = 2, - min = Some(d1), max = Some(d2), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 2, - min = Some(t1), max = Some(t2), nullCount = 0, avgLen = 8, maxLen = 8) + AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = Some(2), + min = Some(false), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)), + AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(4), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(5), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1.0), max = Some(6.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1.0), max = Some(7.0), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat( + distinctCount = Some(2), min = Some(dec1), max = Some(dec2), + nullCount = Some(0), avgLen = Some(16), maxLen = Some(16)), + AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = Some(2), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = Some(2), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = Some(2), + min = Some(d1), max = Some(d2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = Some(2), + min = Some(t1), max = Some(t2), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) )) val columnSizes: Map[Attribute, Long] = columnInfo.map(kv => (kv._1, getColSize(kv._1, kv._2))) val child = StatsTestPlan( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala index 31dea2e3e7f1d..9dceca59f5b87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -42,8 +42,8 @@ trait StatsEstimationTestBase extends SparkFunSuite { def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match { // For UTF8String: base + offset + numBytes - case StringType => colStat.avgLen + 8 + 4 - case _ => colStat.avgLen + case StringType => colStat.avgLen.getOrElse(attribute.dataType.defaultSize.toLong) + 8 + 4 + case _ => colStat.avgLen.getOrElse(attribute.dataType.defaultSize) } def attr(colName: String): AttributeReference = AttributeReference(colName, IntegerType)() @@ -54,6 +54,12 @@ trait StatsEstimationTestBase extends SparkFunSuite { val nameToAttr: Map[String, Attribute] = plan.output.map(a => (a.name, a)).toMap AttributeMap(colStats.map(kv => nameToAttr(kv._1) -> kv._2)) } + + /** Get a test ColumnStat with given distinctCount and nullCount */ + def rangeColumnStat(distinctCount: Int, nullCount: Int): ColumnStat = + ColumnStat(distinctCount = Some(distinctCount), + min = Some(1), max = Some(distinctCount), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 1122522ccb4cb..640e01336aa75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.command import scala.collection.mutable import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableType} +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, CatalogTableType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ /** @@ -64,12 +66,12 @@ case class AnalyzeColumnCommand( /** * Compute stats for the given columns. - * @return (row count, map from column name to ColumnStats) + * @return (row count, map from column name to CatalogColumnStats) */ private def computeColumnStats( sparkSession: SparkSession, tableIdent: TableIdentifier, - columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = { + columnNames: Seq[String]): (Long, Map[String, CatalogColumnStat]) = { val conf = sparkSession.sessionState.conf val relation = sparkSession.table(tableIdent).logicalPlan @@ -81,7 +83,7 @@ case class AnalyzeColumnCommand( // Make sure the column types are supported for stats gathering. attributesToAnalyze.foreach { attr => - if (!ColumnStat.supportsType(attr.dataType)) { + if (!supportsType(attr.dataType)) { throw new AnalysisException( s"Column ${attr.name} in table $tableIdent is of type ${attr.dataType}, " + "and Spark does not support statistics collection on this column type.") @@ -103,7 +105,7 @@ case class AnalyzeColumnCommand( // will be structs containing all column stats. // The layout of each struct follows the layout of the ColumnStats. val expressions = Count(Literal(1)).toAggregateExpression() +: - attributesToAnalyze.map(ColumnStat.statExprs(_, conf, attributePercentiles)) + attributesToAnalyze.map(statExprs(_, conf, attributePercentiles)) val namedExpressions = expressions.map(e => Alias(e, e.toString)()) val statsRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExpressions, relation)) @@ -111,9 +113,9 @@ case class AnalyzeColumnCommand( val rowCount = statsRow.getLong(0) val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) => - // according to `ColumnStat.statExprs`, the stats struct always have 7 fields. - (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1, 7), attr, rowCount, - attributePercentiles.get(attr))) + // according to `statExprs`, the stats struct always have 7 fields. + (attr.name, rowToColumnStat(statsRow.getStruct(i + 1, 7), attr, rowCount, + attributePercentiles.get(attr)).toCatalogColumnStat(attr.name, attr.dataType)) }.toMap (rowCount, columnStats) } @@ -124,7 +126,7 @@ case class AnalyzeColumnCommand( sparkSession: SparkSession, relation: LogicalPlan): AttributeMap[ArrayData] = { val attrsToGenHistogram = if (conf.histogramEnabled) { - attributesToAnalyze.filter(a => ColumnStat.supportsHistogram(a.dataType)) + attributesToAnalyze.filter(a => supportsHistogram(a.dataType)) } else { Nil } @@ -154,4 +156,120 @@ case class AnalyzeColumnCommand( AttributeMap(attributePercentiles.toSeq) } + /** Returns true iff the we support gathering column statistics on column of the given type. */ + private def supportsType(dataType: DataType): Boolean = dataType match { + case _: IntegralType => true + case _: DecimalType => true + case DoubleType | FloatType => true + case BooleanType => true + case DateType => true + case TimestampType => true + case BinaryType | StringType => true + case _ => false + } + + /** Returns true iff the we support gathering histogram on column of the given type. */ + private def supportsHistogram(dataType: DataType): Boolean = dataType match { + case _: IntegralType => true + case _: DecimalType => true + case DoubleType | FloatType => true + case DateType => true + case TimestampType => true + case _ => false + } + + /** + * Constructs an expression to compute column statistics for a given column. + * + * The expression should create a single struct column with the following schema: + * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long, + * distinctCountsForIntervals: Array[Long] + * + * Together with [[rowToColumnStat]], this function is used to create [[ColumnStat]] and + * as a result should stay in sync with it. + */ + private def statExprs( + col: Attribute, + conf: SQLConf, + colPercentiles: AttributeMap[ArrayData]): CreateNamedStruct = { + def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr => + expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() } + }) + val one = Literal(1, LongType) + + // the approximate ndv (num distinct value) should never be larger than the number of rows + val numNonNulls = if (col.nullable) Count(col) else Count(one) + val ndv = Least(Seq(HyperLogLogPlusPlus(col, conf.ndvMaxError), numNonNulls)) + val numNulls = Subtract(Count(one), numNonNulls) + val defaultSize = Literal(col.dataType.defaultSize, LongType) + val nullArray = Literal(null, ArrayType(LongType)) + + def fixedLenTypeStruct: CreateNamedStruct = { + val genHistogram = + supportsHistogram(col.dataType) && colPercentiles.contains(col) + val intervalNdvsExpr = if (genHistogram) { + ApproxCountDistinctForIntervals(col, + Literal(colPercentiles(col), ArrayType(col.dataType)), conf.ndvMaxError) + } else { + nullArray + } + // For fixed width types, avg size should be the same as max size. + struct(ndv, Cast(Min(col), col.dataType), Cast(Max(col), col.dataType), numNulls, + defaultSize, defaultSize, intervalNdvsExpr) + } + + col.dataType match { + case _: IntegralType => fixedLenTypeStruct + case _: DecimalType => fixedLenTypeStruct + case DoubleType | FloatType => fixedLenTypeStruct + case BooleanType => fixedLenTypeStruct + case DateType => fixedLenTypeStruct + case TimestampType => fixedLenTypeStruct + case BinaryType | StringType => + // For string and binary type, we don't compute min, max or histogram + val nullLit = Literal(null, col.dataType) + struct( + ndv, nullLit, nullLit, numNulls, + // Set avg/max size to default size if all the values are null or there is no value. + Coalesce(Seq(Ceil(Average(Length(col))), defaultSize)), + Coalesce(Seq(Cast(Max(Length(col)), LongType), defaultSize)), + nullArray) + case _ => + throw new AnalysisException("Analyzing column statistics is not supported for column " + + s"${col.name} of data type: ${col.dataType}.") + } + } + + /** Convert a struct for column stats (defined in `statExprs`) into [[ColumnStat]]. */ + private def rowToColumnStat( + row: InternalRow, + attr: Attribute, + rowCount: Long, + percentiles: Option[ArrayData]): ColumnStat = { + // The first 6 fields are basic column stats, the 7th is ndvs for histogram bins. + val cs = ColumnStat( + distinctCount = Option(BigInt(row.getLong(0))), + // for string/binary min/max, get should return null + min = Option(row.get(1, attr.dataType)), + max = Option(row.get(2, attr.dataType)), + nullCount = Option(BigInt(row.getLong(3))), + avgLen = Option(row.getLong(4)), + maxLen = Option(row.getLong(5)) + ) + if (row.isNullAt(6) || cs.nullCount.isEmpty) { + cs + } else { + val ndvs = row.getArray(6).toLongArray() + assert(percentiles.get.numElements() == ndvs.length + 1) + val endpoints = percentiles.get.toArray[Any](attr.dataType).map(_.toString.toDouble) + // Construct equi-height histogram + val bins = ndvs.zipWithIndex.map { case (ndv, i) => + HistogramBin(endpoints(i), endpoints(i + 1), ndv) + } + val nonNullRows = rowCount - cs.nullCount.get + val histogram = Histogram(nonNullRows.toDouble / ndvs.length, bins) + cs.copy(histogram = Some(histogram)) + } + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index e400975f19708..44749190c79eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -695,10 +695,11 @@ case class DescribeColumnCommand( // Show column stats when EXTENDED or FORMATTED is specified. buffer += Row("min", cs.flatMap(_.min.map(_.toString)).getOrElse("NULL")) buffer += Row("max", cs.flatMap(_.max.map(_.toString)).getOrElse("NULL")) - buffer += Row("num_nulls", cs.map(_.nullCount.toString).getOrElse("NULL")) - buffer += Row("distinct_count", cs.map(_.distinctCount.toString).getOrElse("NULL")) - buffer += Row("avg_col_len", cs.map(_.avgLen.toString).getOrElse("NULL")) - buffer += Row("max_col_len", cs.map(_.maxLen.toString).getOrElse("NULL")) + buffer += Row("num_nulls", cs.flatMap(_.nullCount.map(_.toString)).getOrElse("NULL")) + buffer += Row("distinct_count", + cs.flatMap(_.distinctCount.map(_.toString)).getOrElse("NULL")) + buffer += Row("avg_col_len", cs.flatMap(_.avgLen.map(_.toString)).getOrElse("NULL")) + buffer += Row("max_col_len", cs.flatMap(_.maxLen.map(_.toString)).getOrElse("NULL")) val histDesc = for { c <- cs hist <- c.histogram diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index b11e798532056..ed4ea0231f1a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import scala.collection.mutable import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogColumnStat import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -95,7 +96,8 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared assert(fetchedStats2.get.sizeInBytes == 0) val expectedColStat = - "key" -> ColumnStat(0, None, None, 0, IntegerType.defaultSize, IntegerType.defaultSize) + "key" -> CatalogColumnStat(Some(0), None, None, Some(0), + Some(IntegerType.defaultSize), Some(IntegerType.defaultSize)) // There won't be histogram for empty column. Seq("true", "false").foreach { histogramEnabled => @@ -156,7 +158,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared Seq(stats, statsWithHgms).foreach { s => s.zip(df.schema).foreach { case ((k, v), field) => withClue(s"column $k with type ${field.dataType}") { - val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap(k, field.dataType)) + val roundtrip = CatalogColumnStat.fromMap("table_is_foo", field.name, v.toMap(k)) assert(roundtrip == Some(v)) } } @@ -187,7 +189,8 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared }.mkString(", ")) val expectedColStats = dataTypes.map { case (tpe, idx) => - (s"col$idx", ColumnStat(0, None, None, 1, tpe.defaultSize.toLong, tpe.defaultSize.toLong)) + (s"col$idx", CatalogColumnStat(Some(0), None, None, Some(1), + Some(tpe.defaultSize.toLong), Some(tpe.defaultSize.toLong))) } // There won't be histograms for null columns. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala index 65ccc1915882f..bf4abb6e625c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala @@ -24,8 +24,8 @@ import scala.collection.mutable import scala.util.Random import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, HiveTableRelation} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, HistogramBin, LogicalPlan} +import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, CatalogTable, HiveTableRelation} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, HistogramBin, HistogramSerializer, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} @@ -67,18 +67,21 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils /** A mapping from column to the stats collected. */ protected val stats = mutable.LinkedHashMap( - "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1), - "cbyte" -> ColumnStat(2, Some(1.toByte), Some(2.toByte), 1, 1, 1), - "cshort" -> ColumnStat(2, Some(1.toShort), Some(3.toShort), 1, 2, 2), - "cint" -> ColumnStat(2, Some(1), Some(4), 1, 4, 4), - "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8), - "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8), - "cfloat" -> ColumnStat(2, Some(1.0f), Some(7.0f), 1, 4, 4), - "cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16), - "cstring" -> ColumnStat(2, None, None, 1, 3, 3), - "cbinary" -> ColumnStat(2, None, None, 1, 3, 3), - "cdate" -> ColumnStat(2, Some(d1Internal), Some(d2Internal), 1, 4, 4), - "ctimestamp" -> ColumnStat(2, Some(t1Internal), Some(t2Internal), 1, 8, 8) + "cbool" -> CatalogColumnStat(Some(2), Some("false"), Some("true"), Some(1), Some(1), Some(1)), + "cbyte" -> CatalogColumnStat(Some(2), Some("1"), Some("2"), Some(1), Some(1), Some(1)), + "cshort" -> CatalogColumnStat(Some(2), Some("1"), Some("3"), Some(1), Some(2), Some(2)), + "cint" -> CatalogColumnStat(Some(2), Some("1"), Some("4"), Some(1), Some(4), Some(4)), + "clong" -> CatalogColumnStat(Some(2), Some("1"), Some("5"), Some(1), Some(8), Some(8)), + "cdouble" -> CatalogColumnStat(Some(2), Some("1.0"), Some("6.0"), Some(1), Some(8), Some(8)), + "cfloat" -> CatalogColumnStat(Some(2), Some("1.0"), Some("7.0"), Some(1), Some(4), Some(4)), + "cdecimal" -> CatalogColumnStat(Some(2), Some(dec1.toString), Some(dec2.toString), Some(1), + Some(16), Some(16)), + "cstring" -> CatalogColumnStat(Some(2), None, None, Some(1), Some(3), Some(3)), + "cbinary" -> CatalogColumnStat(Some(2), None, None, Some(1), Some(3), Some(3)), + "cdate" -> CatalogColumnStat(Some(2), Some(d1.toString), Some(d2.toString), Some(1), Some(4), + Some(4)), + "ctimestamp" -> CatalogColumnStat(Some(2), Some(t1.toString), Some(t2.toString), Some(1), + Some(8), Some(8)) ) /** @@ -110,6 +113,110 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils colStats } + val expectedSerializedColStats = Map( + "spark.sql.statistics.colStats.cbinary.avgLen" -> "3", + "spark.sql.statistics.colStats.cbinary.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbinary.maxLen" -> "3", + "spark.sql.statistics.colStats.cbinary.nullCount" -> "1", + "spark.sql.statistics.colStats.cbinary.version" -> "1", + "spark.sql.statistics.colStats.cbool.avgLen" -> "1", + "spark.sql.statistics.colStats.cbool.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbool.max" -> "true", + "spark.sql.statistics.colStats.cbool.maxLen" -> "1", + "spark.sql.statistics.colStats.cbool.min" -> "false", + "spark.sql.statistics.colStats.cbool.nullCount" -> "1", + "spark.sql.statistics.colStats.cbool.version" -> "1", + "spark.sql.statistics.colStats.cbyte.avgLen" -> "1", + "spark.sql.statistics.colStats.cbyte.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbyte.max" -> "2", + "spark.sql.statistics.colStats.cbyte.maxLen" -> "1", + "spark.sql.statistics.colStats.cbyte.min" -> "1", + "spark.sql.statistics.colStats.cbyte.nullCount" -> "1", + "spark.sql.statistics.colStats.cbyte.version" -> "1", + "spark.sql.statistics.colStats.cdate.avgLen" -> "4", + "spark.sql.statistics.colStats.cdate.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdate.max" -> "2016-05-09", + "spark.sql.statistics.colStats.cdate.maxLen" -> "4", + "spark.sql.statistics.colStats.cdate.min" -> "2016-05-08", + "spark.sql.statistics.colStats.cdate.nullCount" -> "1", + "spark.sql.statistics.colStats.cdate.version" -> "1", + "spark.sql.statistics.colStats.cdecimal.avgLen" -> "16", + "spark.sql.statistics.colStats.cdecimal.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdecimal.max" -> "8.000000000000000000", + "spark.sql.statistics.colStats.cdecimal.maxLen" -> "16", + "spark.sql.statistics.colStats.cdecimal.min" -> "1.000000000000000000", + "spark.sql.statistics.colStats.cdecimal.nullCount" -> "1", + "spark.sql.statistics.colStats.cdecimal.version" -> "1", + "spark.sql.statistics.colStats.cdouble.avgLen" -> "8", + "spark.sql.statistics.colStats.cdouble.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdouble.max" -> "6.0", + "spark.sql.statistics.colStats.cdouble.maxLen" -> "8", + "spark.sql.statistics.colStats.cdouble.min" -> "1.0", + "spark.sql.statistics.colStats.cdouble.nullCount" -> "1", + "spark.sql.statistics.colStats.cdouble.version" -> "1", + "spark.sql.statistics.colStats.cfloat.avgLen" -> "4", + "spark.sql.statistics.colStats.cfloat.distinctCount" -> "2", + "spark.sql.statistics.colStats.cfloat.max" -> "7.0", + "spark.sql.statistics.colStats.cfloat.maxLen" -> "4", + "spark.sql.statistics.colStats.cfloat.min" -> "1.0", + "spark.sql.statistics.colStats.cfloat.nullCount" -> "1", + "spark.sql.statistics.colStats.cfloat.version" -> "1", + "spark.sql.statistics.colStats.cint.avgLen" -> "4", + "spark.sql.statistics.colStats.cint.distinctCount" -> "2", + "spark.sql.statistics.colStats.cint.max" -> "4", + "spark.sql.statistics.colStats.cint.maxLen" -> "4", + "spark.sql.statistics.colStats.cint.min" -> "1", + "spark.sql.statistics.colStats.cint.nullCount" -> "1", + "spark.sql.statistics.colStats.cint.version" -> "1", + "spark.sql.statistics.colStats.clong.avgLen" -> "8", + "spark.sql.statistics.colStats.clong.distinctCount" -> "2", + "spark.sql.statistics.colStats.clong.max" -> "5", + "spark.sql.statistics.colStats.clong.maxLen" -> "8", + "spark.sql.statistics.colStats.clong.min" -> "1", + "spark.sql.statistics.colStats.clong.nullCount" -> "1", + "spark.sql.statistics.colStats.clong.version" -> "1", + "spark.sql.statistics.colStats.cshort.avgLen" -> "2", + "spark.sql.statistics.colStats.cshort.distinctCount" -> "2", + "spark.sql.statistics.colStats.cshort.max" -> "3", + "spark.sql.statistics.colStats.cshort.maxLen" -> "2", + "spark.sql.statistics.colStats.cshort.min" -> "1", + "spark.sql.statistics.colStats.cshort.nullCount" -> "1", + "spark.sql.statistics.colStats.cshort.version" -> "1", + "spark.sql.statistics.colStats.cstring.avgLen" -> "3", + "spark.sql.statistics.colStats.cstring.distinctCount" -> "2", + "spark.sql.statistics.colStats.cstring.maxLen" -> "3", + "spark.sql.statistics.colStats.cstring.nullCount" -> "1", + "spark.sql.statistics.colStats.cstring.version" -> "1", + "spark.sql.statistics.colStats.ctimestamp.avgLen" -> "8", + "spark.sql.statistics.colStats.ctimestamp.distinctCount" -> "2", + "spark.sql.statistics.colStats.ctimestamp.max" -> "2016-05-09 00:00:02.0", + "spark.sql.statistics.colStats.ctimestamp.maxLen" -> "8", + "spark.sql.statistics.colStats.ctimestamp.min" -> "2016-05-08 00:00:01.0", + "spark.sql.statistics.colStats.ctimestamp.nullCount" -> "1", + "spark.sql.statistics.colStats.ctimestamp.version" -> "1" + ) + + val expectedSerializedHistograms = Map( + "spark.sql.statistics.colStats.cbyte.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cbyte").histogram.get), + "spark.sql.statistics.colStats.cshort.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cshort").histogram.get), + "spark.sql.statistics.colStats.cint.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cint").histogram.get), + "spark.sql.statistics.colStats.clong.histogram" -> + HistogramSerializer.serialize(statsWithHgms("clong").histogram.get), + "spark.sql.statistics.colStats.cdouble.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdouble").histogram.get), + "spark.sql.statistics.colStats.cfloat.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cfloat").histogram.get), + "spark.sql.statistics.colStats.cdecimal.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdecimal").histogram.get), + "spark.sql.statistics.colStats.cdate.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdate").histogram.get), + "spark.sql.statistics.colStats.ctimestamp.histogram" -> + HistogramSerializer.serialize(statsWithHgms("ctimestamp").histogram.get) + ) + private val randomName = new Random(31) def getCatalogTable(tableName: String): CatalogTable = { @@ -151,7 +258,7 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils */ def checkColStats( df: DataFrame, - colStats: mutable.LinkedHashMap[String, ColumnStat]): Unit = { + colStats: mutable.LinkedHashMap[String, CatalogColumnStat]): Unit = { val tableName = "column_stats_test_" + randomName.nextInt(1000) withTable(tableName) { df.write.saveAsTable(tableName) @@ -161,14 +268,24 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils colStats.keys.mkString(", ")) // Validate statistics - val table = getCatalogTable(tableName) - assert(table.stats.isDefined) - assert(table.stats.get.colStats.size == colStats.size) - - colStats.foreach { case (k, v) => - withClue(s"column $k") { - assert(table.stats.get.colStats(k) == v) - } + validateColStats(tableName, colStats) + } + } + + /** + * Validate if the given catalog table has the provided statistics. + */ + def validateColStats( + tableName: String, + colStats: mutable.LinkedHashMap[String, CatalogColumnStat]): Unit = { + + val table = getCatalogTable(tableName) + assert(table.stats.isDefined) + assert(table.stats.get.colStats.size == colStats.size) + + colStats.foreach { case (k, v) => + withClue(s"column $k") { + assert(table.stats.get.colStats(k) == v) } } } @@ -215,12 +332,13 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils case catalogRel: HiveTableRelation => (catalogRel, catalogRel.tableMeta) case logicalRel: LogicalRelation => (logicalRel, logicalRel.catalogTable.get) }.head - val emptyColStat = ColumnStat(0, None, None, 0, 4, 4) + val emptyColStat = ColumnStat(Some(0), None, None, Some(0), Some(4), Some(4)) + val emptyCatalogColStat = CatalogColumnStat(Some(0), None, None, Some(0), Some(4), Some(4)) // Check catalog statistics assert(catalogTable.stats.isDefined) assert(catalogTable.stats.get.sizeInBytes == 0) assert(catalogTable.stats.get.rowCount == Some(0)) - assert(catalogTable.stats.get.colStats == Map("c1" -> emptyColStat)) + assert(catalogTable.stats.get.colStats == Map("c1" -> emptyCatalogColStat)) // Check relation statistics withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 1ee1d57b8ebe1..28c340a176d91 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -663,14 +663,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat requireTableExists(db, table) val rawTable = getRawTable(db, table) - // For datasource tables and hive serde tables created by spark 2.1 or higher, - // the data schema is stored in the table properties. - val schema = restoreTableMetadata(rawTable).schema - // convert table statistics to properties so that we can persist them through hive client val statsProperties = if (stats.isDefined) { - statsToProperties(stats.get, schema) + statsToProperties(stats.get) } else { new mutable.HashMap[String, String]() } @@ -1028,9 +1024,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat currentFullPath } - private def statsToProperties( - stats: CatalogStatistics, - schema: StructType): Map[String, String] = { + private def statsToProperties(stats: CatalogStatistics): Map[String, String] = { val statsProperties = new mutable.HashMap[String, String]() statsProperties += STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString() @@ -1038,11 +1032,12 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() } - val colNameTypeMap: Map[String, DataType] = - schema.fields.map(f => (f.name, f.dataType)).toMap stats.colStats.foreach { case (colName, colStat) => - colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => - statsProperties += (columnStatKeyPropName(colName, k) -> v) + colStat.toMap(colName).foreach { case (k, v) => + // Fully qualified name used in table properties for a particular column stat. + // For example, for column "mycol", and "min" stat, this should return + // "spark.sql.statistics.colStats.mycol.min". + statsProperties += (STATISTICS_COL_STATS_PREFIX + k -> v) } } @@ -1058,23 +1053,20 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat if (statsProps.isEmpty) { None } else { + val colStats = new mutable.HashMap[String, CatalogColumnStat] + val colStatsProps = properties.filterKeys(_.startsWith(STATISTICS_COL_STATS_PREFIX)).map { + case (k, v) => k.drop(STATISTICS_COL_STATS_PREFIX.length) -> v + } - val colStats = new mutable.HashMap[String, ColumnStat] - - // For each column, recover its column stats. Note that this is currently a O(n^2) operation, - // but given the number of columns it usually not enormous, this is probably OK as a start. - // If we want to map this a linear operation, we'd need a stronger contract between the - // naming convention used for serialization. - schema.foreach { field => - if (statsProps.contains(columnStatKeyPropName(field.name, ColumnStat.KEY_VERSION))) { - // If "version" field is defined, then the column stat is defined. - val keyPrefix = columnStatKeyPropName(field.name, "") - val colStatMap = statsProps.filterKeys(_.startsWith(keyPrefix)).map { case (k, v) => - (k.drop(keyPrefix.length), v) - } - ColumnStat.fromMap(table, field, colStatMap).foreach { cs => - colStats += field.name -> cs - } + // Find all the column names by matching the KEY_VERSION properties for them. + colStatsProps.keys.filter { + k => k.endsWith(CatalogColumnStat.KEY_VERSION) + }.map { k => + k.dropRight(CatalogColumnStat.KEY_VERSION.length + 1) + }.foreach { fieldName => + // and for each, create a column stat. + CatalogColumnStat.fromMap(table, fieldName, colStatsProps).foreach { cs => + colStats += fieldName -> cs } } @@ -1093,14 +1085,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val rawTable = getRawTable(db, table) - // For datasource tables and hive serde tables created by spark 2.1 or higher, - // the data schema is stored in the table properties. - val schema = restoreTableMetadata(rawTable).schema - // convert partition statistics to properties so that we can persist them through hive api val withStatsProps = lowerCasedParts.map { p => if (p.stats.isDefined) { - val statsProperties = statsToProperties(p.stats.get, schema) + val statsProperties = statsToProperties(p.stats.get) p.copy(parameters = p.parameters ++ statsProperties) } else { p @@ -1310,15 +1298,6 @@ object HiveExternalCatalog { val EMPTY_DATA_SCHEMA = new StructType() .add("col", "array", nullable = true, comment = "from deserializer") - /** - * Returns the fully qualified name used in table properties for a particular column stat. - * For example, for column "mycol", and "min" stat, this should return - * "spark.sql.statistics.colStats.mycol.min". - */ - private def columnStatKeyPropName(columnName: String, statKey: String): String = { - STATISTICS_COL_STATS_PREFIX + columnName + "." + statKey - } - // A persisted data source table always store its schema in the catalog. private def getSchemaFromTableProperties(metadata: CatalogTable): StructType = { val errorMessage = "Could not read schema from the hive metastore because it is corrupted." diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 3af8af0814bb4..61cec82984795 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, HiveTableRelation} +import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, HiveTableRelation} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, HistogramBin, HistogramSerializer} import org.apache.spark.sql.catalyst.util.{DateTimeUtils, StringUtils} import org.apache.spark.sql.execution.command.DDLUtils @@ -177,8 +177,8 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val fetchedStats0 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(2)) assert(fetchedStats0.get.colStats == Map( - "a" -> ColumnStat(2, Some(1), Some(2), 0, 4, 4), - "b" -> ColumnStat(1, Some(1), Some(1), 0, 4, 4))) + "a" -> CatalogColumnStat(Some(2), Some("1"), Some("2"), Some(0), Some(4), Some(4)), + "b" -> CatalogColumnStat(Some(1), Some("1"), Some("1"), Some(0), Some(4), Some(4)))) } } @@ -208,8 +208,8 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val fetchedStats1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(1)).get assert(fetchedStats1.colStats == Map( - "C1" -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(1), nullCount = 0, - avgLen = 4, maxLen = 4))) + "C1" -> CatalogColumnStat(distinctCount = Some(1), min = Some("1"), max = Some("1"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)))) } } @@ -596,7 +596,8 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") val fetchedStats0 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) - assert(fetchedStats0.get.colStats == Map("c1" -> ColumnStat(0, None, None, 0, 4, 4))) + assert(fetchedStats0.get.colStats == + Map("c1" -> CatalogColumnStat(Some(0), None, None, Some(0), Some(4), Some(4)))) // Insert new data and analyze: have the latest column stats. sql(s"INSERT INTO TABLE $table SELECT 1, 'a', 10.0") @@ -604,18 +605,18 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val fetchedStats1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(1)).get assert(fetchedStats1.colStats == Map( - "c1" -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(1), nullCount = 0, - avgLen = 4, maxLen = 4))) + "c1" -> CatalogColumnStat(distinctCount = Some(1), min = Some("1"), max = Some("1"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)))) // Analyze another column: since the table is not changed, the precious column stats are kept. sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2") val fetchedStats2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(1)).get assert(fetchedStats2.colStats == Map( - "c1" -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(1), nullCount = 0, - avgLen = 4, maxLen = 4), - "c2" -> ColumnStat(distinctCount = 1, min = None, max = None, nullCount = 0, - avgLen = 1, maxLen = 1))) + "c1" -> CatalogColumnStat(distinctCount = Some(1), min = Some("1"), max = Some("1"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + "c2" -> CatalogColumnStat(distinctCount = Some(1), min = None, max = None, + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)))) // Insert new data and analyze: stale column stats are removed and newly collected column // stats are added. @@ -624,10 +625,10 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val fetchedStats3 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(2)).get assert(fetchedStats3.colStats == Map( - "c1" -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4), - "c3" -> ColumnStat(distinctCount = 2, min = Some(10.0), max = Some(20.0), nullCount = 0, - avgLen = 8, maxLen = 8))) + "c1" -> CatalogColumnStat(distinctCount = Some(2), min = Some("1"), max = Some("2"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + "c3" -> CatalogColumnStat(distinctCount = Some(2), min = Some("10.0"), max = Some("20.0"), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)))) } } @@ -999,115 +1000,11 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("verify serialized column stats after analyzing columns") { import testImplicits._ - val tableName = "column_stats_test2" + val tableName = "column_stats_test_ser" // (data.head.productArity - 1) because the last column does not support stats collection. assert(stats.size == data.head.productArity - 1) val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) - val expectedSerializedColStats = Map( - "spark.sql.statistics.colStats.cbinary.avgLen" -> "3", - "spark.sql.statistics.colStats.cbinary.distinctCount" -> "2", - "spark.sql.statistics.colStats.cbinary.maxLen" -> "3", - "spark.sql.statistics.colStats.cbinary.nullCount" -> "1", - "spark.sql.statistics.colStats.cbinary.version" -> "1", - "spark.sql.statistics.colStats.cbool.avgLen" -> "1", - "spark.sql.statistics.colStats.cbool.distinctCount" -> "2", - "spark.sql.statistics.colStats.cbool.max" -> "true", - "spark.sql.statistics.colStats.cbool.maxLen" -> "1", - "spark.sql.statistics.colStats.cbool.min" -> "false", - "spark.sql.statistics.colStats.cbool.nullCount" -> "1", - "spark.sql.statistics.colStats.cbool.version" -> "1", - "spark.sql.statistics.colStats.cbyte.avgLen" -> "1", - "spark.sql.statistics.colStats.cbyte.distinctCount" -> "2", - "spark.sql.statistics.colStats.cbyte.max" -> "2", - "spark.sql.statistics.colStats.cbyte.maxLen" -> "1", - "spark.sql.statistics.colStats.cbyte.min" -> "1", - "spark.sql.statistics.colStats.cbyte.nullCount" -> "1", - "spark.sql.statistics.colStats.cbyte.version" -> "1", - "spark.sql.statistics.colStats.cdate.avgLen" -> "4", - "spark.sql.statistics.colStats.cdate.distinctCount" -> "2", - "spark.sql.statistics.colStats.cdate.max" -> "2016-05-09", - "spark.sql.statistics.colStats.cdate.maxLen" -> "4", - "spark.sql.statistics.colStats.cdate.min" -> "2016-05-08", - "spark.sql.statistics.colStats.cdate.nullCount" -> "1", - "spark.sql.statistics.colStats.cdate.version" -> "1", - "spark.sql.statistics.colStats.cdecimal.avgLen" -> "16", - "spark.sql.statistics.colStats.cdecimal.distinctCount" -> "2", - "spark.sql.statistics.colStats.cdecimal.max" -> "8.000000000000000000", - "spark.sql.statistics.colStats.cdecimal.maxLen" -> "16", - "spark.sql.statistics.colStats.cdecimal.min" -> "1.000000000000000000", - "spark.sql.statistics.colStats.cdecimal.nullCount" -> "1", - "spark.sql.statistics.colStats.cdecimal.version" -> "1", - "spark.sql.statistics.colStats.cdouble.avgLen" -> "8", - "spark.sql.statistics.colStats.cdouble.distinctCount" -> "2", - "spark.sql.statistics.colStats.cdouble.max" -> "6.0", - "spark.sql.statistics.colStats.cdouble.maxLen" -> "8", - "spark.sql.statistics.colStats.cdouble.min" -> "1.0", - "spark.sql.statistics.colStats.cdouble.nullCount" -> "1", - "spark.sql.statistics.colStats.cdouble.version" -> "1", - "spark.sql.statistics.colStats.cfloat.avgLen" -> "4", - "spark.sql.statistics.colStats.cfloat.distinctCount" -> "2", - "spark.sql.statistics.colStats.cfloat.max" -> "7.0", - "spark.sql.statistics.colStats.cfloat.maxLen" -> "4", - "spark.sql.statistics.colStats.cfloat.min" -> "1.0", - "spark.sql.statistics.colStats.cfloat.nullCount" -> "1", - "spark.sql.statistics.colStats.cfloat.version" -> "1", - "spark.sql.statistics.colStats.cint.avgLen" -> "4", - "spark.sql.statistics.colStats.cint.distinctCount" -> "2", - "spark.sql.statistics.colStats.cint.max" -> "4", - "spark.sql.statistics.colStats.cint.maxLen" -> "4", - "spark.sql.statistics.colStats.cint.min" -> "1", - "spark.sql.statistics.colStats.cint.nullCount" -> "1", - "spark.sql.statistics.colStats.cint.version" -> "1", - "spark.sql.statistics.colStats.clong.avgLen" -> "8", - "spark.sql.statistics.colStats.clong.distinctCount" -> "2", - "spark.sql.statistics.colStats.clong.max" -> "5", - "spark.sql.statistics.colStats.clong.maxLen" -> "8", - "spark.sql.statistics.colStats.clong.min" -> "1", - "spark.sql.statistics.colStats.clong.nullCount" -> "1", - "spark.sql.statistics.colStats.clong.version" -> "1", - "spark.sql.statistics.colStats.cshort.avgLen" -> "2", - "spark.sql.statistics.colStats.cshort.distinctCount" -> "2", - "spark.sql.statistics.colStats.cshort.max" -> "3", - "spark.sql.statistics.colStats.cshort.maxLen" -> "2", - "spark.sql.statistics.colStats.cshort.min" -> "1", - "spark.sql.statistics.colStats.cshort.nullCount" -> "1", - "spark.sql.statistics.colStats.cshort.version" -> "1", - "spark.sql.statistics.colStats.cstring.avgLen" -> "3", - "spark.sql.statistics.colStats.cstring.distinctCount" -> "2", - "spark.sql.statistics.colStats.cstring.maxLen" -> "3", - "spark.sql.statistics.colStats.cstring.nullCount" -> "1", - "spark.sql.statistics.colStats.cstring.version" -> "1", - "spark.sql.statistics.colStats.ctimestamp.avgLen" -> "8", - "spark.sql.statistics.colStats.ctimestamp.distinctCount" -> "2", - "spark.sql.statistics.colStats.ctimestamp.max" -> "2016-05-09 00:00:02.0", - "spark.sql.statistics.colStats.ctimestamp.maxLen" -> "8", - "spark.sql.statistics.colStats.ctimestamp.min" -> "2016-05-08 00:00:01.0", - "spark.sql.statistics.colStats.ctimestamp.nullCount" -> "1", - "spark.sql.statistics.colStats.ctimestamp.version" -> "1" - ) - - val expectedSerializedHistograms = Map( - "spark.sql.statistics.colStats.cbyte.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cbyte").histogram.get), - "spark.sql.statistics.colStats.cshort.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cshort").histogram.get), - "spark.sql.statistics.colStats.cint.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cint").histogram.get), - "spark.sql.statistics.colStats.clong.histogram" -> - HistogramSerializer.serialize(statsWithHgms("clong").histogram.get), - "spark.sql.statistics.colStats.cdouble.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cdouble").histogram.get), - "spark.sql.statistics.colStats.cfloat.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cfloat").histogram.get), - "spark.sql.statistics.colStats.cdecimal.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cdecimal").histogram.get), - "spark.sql.statistics.colStats.cdate.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cdate").histogram.get), - "spark.sql.statistics.colStats.ctimestamp.histogram" -> - HistogramSerializer.serialize(statsWithHgms("ctimestamp").histogram.get) - ) - def checkColStatsProps(expected: Map[String, String]): Unit = { sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS " + stats.keys.mkString(", ")) val table = hiveClient.getTable("default", tableName) @@ -1129,6 +1026,29 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } + test("verify column stats can be deserialized from tblproperties") { + import testImplicits._ + + val tableName = "column_stats_test_de" + // (data.head.productArity - 1) because the last column does not support stats collection. + assert(stats.size == data.head.productArity - 1) + val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) + + withTable(tableName) { + df.write.saveAsTable(tableName) + + // Put in stats properties manually. + val table = getCatalogTable(tableName) + val newTable = table.copy( + properties = table.properties ++ + expectedSerializedColStats ++ expectedSerializedHistograms + + ("spark.sql.statistics.totalSize" -> "1") /* totalSize always required */) + hiveClient.alterTable(newTable) + + validateColStats(tableName, statsWithHgms) + } + } + test("serialization and deserialization of histograms to/from hive metastore") { import testImplicits._ From 649ed9c5732f85ef1306576fdd3a9278a2a6410c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 27 Feb 2018 08:18:41 -0600 Subject: [PATCH 0397/2461] [SPARK-23509][BUILD] Upgrade commons-net from 2.2 to 3.1 ## What changes were proposed in this pull request? This PR avoids version conflicts of `commons-net` by upgrading commons-net from 2.2 to 3.1. We are seeing the following message during the build using sbt. ``` [warn] Found version conflict(s) in library dependencies; some are suspected to be binary incompatible: ... [warn] * commons-net:commons-net:3.1 is selected over 2.2 [warn] +- org.apache.hadoop:hadoop-common:2.6.5 (depends on 3.1) [warn] +- org.apache.spark:spark-core_2.11:2.4.0-SNAPSHOT (depends on 2.2) [warn] ``` [Here](https://commons.apache.org/proper/commons-net/changes-report.html) is a release history. [Here](https://commons.apache.org/proper/commons-net/migration.html) is a migration guide from 2.x to 3.0. ## How was this patch tested? Existing tests Author: Kazuaki Ishizaki Closes #20672 from kiszk/SPARK-23509. --- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- pom.xml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index ed310507d14ed..c3d1dd444b506 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -48,7 +48,7 @@ commons-lang-2.6.jar commons-lang3-3.5.jar commons-logging-1.1.3.jar commons-math3-3.4.1.jar -commons-net-2.2.jar +commons-net-3.1.jar commons-pool-1.5.4.jar compress-lzf-1.0.3.jar core-1.1.2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 04dec04796af4..290867035f91d 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -48,7 +48,7 @@ commons-lang-2.6.jar commons-lang3-3.5.jar commons-logging-1.1.3.jar commons-math3-3.4.1.jar -commons-net-2.2.jar +commons-net-3.1.jar commons-pool-1.5.4.jar compress-lzf-1.0.3.jar core-1.1.2.jar diff --git a/pom.xml b/pom.xml index ac30107066389..b8396166f6b1b 100644 --- a/pom.xml +++ b/pom.xml @@ -579,7 +579,7 @@ commons-net commons-net - 2.2 + 3.1 io.netty From eac0b067222a3dfa52be20360a453cb7bd420bf2 Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Tue, 27 Feb 2018 08:21:11 -0600 Subject: [PATCH 0398/2461] [SPARK-17147][STREAMING][KAFKA] Allow non-consecutive offsets ## What changes were proposed in this pull request? Add a configuration spark.streaming.kafka.allowNonConsecutiveOffsets to allow streaming jobs to proceed on compacted topics (or other situations involving gaps between offsets in the log). ## How was this patch tested? Added new unit test justinrmiller has been testing this branch in production for a few weeks Author: cody koeninger Closes #20572 from koeninger/SPARK-17147. --- .../kafka010/CachedKafkaConsumer.scala | 55 +++- .../spark/streaming/kafka010/KafkaRDD.scala | 236 +++++++++++++----- .../streaming/kafka010/KafkaRDDSuite.scala | 106 ++++++++ .../streaming/kafka010/KafkaTestUtils.scala | 25 +- .../kafka010/mocks/MockScheduler.scala | 96 +++++++ .../streaming/kafka010/mocks/MockTime.scala | 51 ++++ 6 files changed, 487 insertions(+), 82 deletions(-) create mode 100644 external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala create mode 100644 external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala index fa3ea6131a507..aeb8c1dc342b3 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala @@ -22,10 +22,8 @@ import java.{ util => ju } import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord, KafkaConsumer } import org.apache.kafka.common.{ KafkaException, TopicPartition } -import org.apache.spark.SparkConf import org.apache.spark.internal.Logging - /** * Consumer of single topicpartition, intended for cached reuse. * Underlying consumer is not threadsafe, so neither is this, @@ -38,7 +36,7 @@ class CachedKafkaConsumer[K, V] private( val partition: Int, val kafkaParams: ju.Map[String, Object]) extends Logging { - assert(groupId == kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG), + require(groupId == kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG), "groupId used for cache key must match the groupId in kafkaParams") val topicPartition = new TopicPartition(topic, partition) @@ -53,7 +51,7 @@ class CachedKafkaConsumer[K, V] private( // TODO if the buffer was kept around as a random-access structure, // could possibly optimize re-calculating of an RDD in the same batch - protected var buffer = ju.Collections.emptyList[ConsumerRecord[K, V]]().iterator + protected var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]() protected var nextOffset = -2L def close(): Unit = consumer.close() @@ -71,7 +69,7 @@ class CachedKafkaConsumer[K, V] private( } if (!buffer.hasNext()) { poll(timeout) } - assert(buffer.hasNext(), + require(buffer.hasNext(), s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") var record = buffer.next() @@ -79,17 +77,56 @@ class CachedKafkaConsumer[K, V] private( logInfo(s"Buffer miss for $groupId $topic $partition $offset") seek(offset) poll(timeout) - assert(buffer.hasNext(), + require(buffer.hasNext(), s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") record = buffer.next() - assert(record.offset == offset, - s"Got wrong record for $groupId $topic $partition even after seeking to offset $offset") + require(record.offset == offset, + s"Got wrong record for $groupId $topic $partition even after seeking to offset $offset " + + s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " + + "spark.streaming.kafka.allowNonConsecutiveOffsets" + ) } nextOffset = offset + 1 record } + /** + * Start a batch on a compacted topic + */ + def compactedStart(offset: Long, timeout: Long): Unit = { + logDebug(s"compacted start $groupId $topic $partition starting $offset") + // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics + if (offset != nextOffset) { + logInfo(s"Initial fetch for compacted $groupId $topic $partition $offset") + seek(offset) + poll(timeout) + } + } + + /** + * Get the next record in the batch from a compacted topic. + * Assumes compactedStart has been called first, and ignores gaps. + */ + def compactedNext(timeout: Long): ConsumerRecord[K, V] = { + if (!buffer.hasNext()) { + poll(timeout) + } + require(buffer.hasNext(), + s"Failed to get records for compacted $groupId $topic $partition after polling for $timeout") + val record = buffer.next() + nextOffset = record.offset + 1 + record + } + + /** + * Rewind to previous record in the batch from a compacted topic. + * @throws NoSuchElementException if no previous element + */ + def compactedPrevious(): ConsumerRecord[K, V] = { + buffer.previous() + } + private def seek(offset: Long): Unit = { logDebug(s"Seeking to $topicPartition $offset") consumer.seek(topicPartition, offset) @@ -99,7 +136,7 @@ class CachedKafkaConsumer[K, V] private( val p = consumer.poll(timeout) val r = p.records(topicPartition) logDebug(s"Polled ${p.partitions()} ${r.size}") - buffer = r.iterator + buffer = r.listIterator } } diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index d9fc9cc206647..07239eda64d2e 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -55,12 +55,12 @@ private[spark] class KafkaRDD[K, V]( useConsumerCache: Boolean ) extends RDD[ConsumerRecord[K, V]](sc, Nil) with Logging with HasOffsetRanges { - assert("none" == + require("none" == kafkaParams.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG).asInstanceOf[String], ConsumerConfig.AUTO_OFFSET_RESET_CONFIG + " must be set to none for executor kafka params, else messages may not match offsetRange") - assert(false == + require(false == kafkaParams.get(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG).asInstanceOf[Boolean], ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG + " must be set to false for executor kafka params, else offsets may commit before processing") @@ -74,6 +74,8 @@ private[spark] class KafkaRDD[K, V]( conf.getInt("spark.streaming.kafka.consumer.cache.maxCapacity", 64) private val cacheLoadFactor = conf.getDouble("spark.streaming.kafka.consumer.cache.loadFactor", 0.75).toFloat + private val compacted = + conf.getBoolean("spark.streaming.kafka.allowNonConsecutiveOffsets", false) override def persist(newLevel: StorageLevel): this.type = { logError("Kafka ConsumerRecord is not serializable. " + @@ -87,48 +89,63 @@ private[spark] class KafkaRDD[K, V]( }.toArray } - override def count(): Long = offsetRanges.map(_.count).sum + override def count(): Long = + if (compacted) { + super.count() + } else { + offsetRanges.map(_.count).sum + } override def countApprox( timeout: Long, confidence: Double = 0.95 - ): PartialResult[BoundedDouble] = { - val c = count - new PartialResult(new BoundedDouble(c, 1.0, c, c), true) - } - - override def isEmpty(): Boolean = count == 0L - - override def take(num: Int): Array[ConsumerRecord[K, V]] = { - val nonEmptyPartitions = this.partitions - .map(_.asInstanceOf[KafkaRDDPartition]) - .filter(_.count > 0) + ): PartialResult[BoundedDouble] = + if (compacted) { + super.countApprox(timeout, confidence) + } else { + val c = count + new PartialResult(new BoundedDouble(c, 1.0, c, c), true) + } - if (num < 1 || nonEmptyPartitions.isEmpty) { - return new Array[ConsumerRecord[K, V]](0) + override def isEmpty(): Boolean = + if (compacted) { + super.isEmpty() + } else { + count == 0L } - // Determine in advance how many messages need to be taken from each partition - val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => - val remain = num - result.values.sum - if (remain > 0) { - val taken = Math.min(remain, part.count) - result + (part.index -> taken.toInt) + override def take(num: Int): Array[ConsumerRecord[K, V]] = + if (compacted) { + super.take(num) + } else if (num < 1) { + Array.empty[ConsumerRecord[K, V]] + } else { + val nonEmptyPartitions = this.partitions + .map(_.asInstanceOf[KafkaRDDPartition]) + .filter(_.count > 0) + + if (nonEmptyPartitions.isEmpty) { + Array.empty[ConsumerRecord[K, V]] } else { - result + // Determine in advance how many messages need to be taken from each partition + val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => + val remain = num - result.values.sum + if (remain > 0) { + val taken = Math.min(remain, part.count) + result + (part.index -> taken.toInt) + } else { + result + } + } + + context.runJob( + this, + (tc: TaskContext, it: Iterator[ConsumerRecord[K, V]]) => + it.take(parts(tc.partitionId)).toArray, parts.keys.toArray + ).flatten } } - val buf = new ArrayBuffer[ConsumerRecord[K, V]] - val res = context.runJob( - this, - (tc: TaskContext, it: Iterator[ConsumerRecord[K, V]]) => - it.take(parts(tc.partitionId)).toArray, parts.keys.toArray - ) - res.foreach(buf ++= _) - buf.toArray - } - private def executors(): Array[ExecutorCacheTaskLocation] = { val bm = sparkContext.env.blockManager bm.master.getPeers(bm.blockManagerId).toArray @@ -172,57 +189,138 @@ private[spark] class KafkaRDD[K, V]( override def compute(thePart: Partition, context: TaskContext): Iterator[ConsumerRecord[K, V]] = { val part = thePart.asInstanceOf[KafkaRDDPartition] - assert(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part)) + require(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part)) if (part.fromOffset == part.untilOffset) { logInfo(s"Beginning offset ${part.fromOffset} is the same as ending offset " + s"skipping ${part.topic} ${part.partition}") Iterator.empty } else { - new KafkaRDDIterator(part, context) + logInfo(s"Computing topic ${part.topic}, partition ${part.partition} " + + s"offsets ${part.fromOffset} -> ${part.untilOffset}") + if (compacted) { + new CompactedKafkaRDDIterator[K, V]( + part, + context, + kafkaParams, + useConsumerCache, + pollTimeout, + cacheInitialCapacity, + cacheMaxCapacity, + cacheLoadFactor + ) + } else { + new KafkaRDDIterator[K, V]( + part, + context, + kafkaParams, + useConsumerCache, + pollTimeout, + cacheInitialCapacity, + cacheMaxCapacity, + cacheLoadFactor + ) + } } } +} - /** - * An iterator that fetches messages directly from Kafka for the offsets in partition. - * Uses a cached consumer where possible to take advantage of prefetching - */ - private class KafkaRDDIterator( - part: KafkaRDDPartition, - context: TaskContext) extends Iterator[ConsumerRecord[K, V]] { - - logInfo(s"Computing topic ${part.topic}, partition ${part.partition} " + - s"offsets ${part.fromOffset} -> ${part.untilOffset}") - - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] +/** + * An iterator that fetches messages directly from Kafka for the offsets in partition. + * Uses a cached consumer where possible to take advantage of prefetching + */ +private class KafkaRDDIterator[K, V]( + part: KafkaRDDPartition, + context: TaskContext, + kafkaParams: ju.Map[String, Object], + useConsumerCache: Boolean, + pollTimeout: Long, + cacheInitialCapacity: Int, + cacheMaxCapacity: Int, + cacheLoadFactor: Float +) extends Iterator[ConsumerRecord[K, V]] { + + val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + + context.addTaskCompletionListener(_ => closeIfNeeded()) + + val consumer = if (useConsumerCache) { + CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) + if (context.attemptNumber >= 1) { + // just in case the prior attempt failures were cache related + CachedKafkaConsumer.remove(groupId, part.topic, part.partition) + } + CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams) + } else { + CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams) + } - context.addTaskCompletionListener{ context => closeIfNeeded() } + var requestOffset = part.fromOffset - val consumer = if (useConsumerCache) { - CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) - if (context.attemptNumber >= 1) { - // just in case the prior attempt failures were cache related - CachedKafkaConsumer.remove(groupId, part.topic, part.partition) - } - CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams) - } else { - CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams) + def closeIfNeeded(): Unit = { + if (!useConsumerCache && consumer != null) { + consumer.close() } + } - var requestOffset = part.fromOffset + override def hasNext(): Boolean = requestOffset < part.untilOffset - def closeIfNeeded(): Unit = { - if (!useConsumerCache && consumer != null) { - consumer.close - } + override def next(): ConsumerRecord[K, V] = { + if (!hasNext) { + throw new ju.NoSuchElementException("Can't call getNext() once untilOffset has been reached") } + val r = consumer.get(requestOffset, pollTimeout) + requestOffset += 1 + r + } +} - override def hasNext(): Boolean = requestOffset < part.untilOffset - - override def next(): ConsumerRecord[K, V] = { - assert(hasNext(), "Can't call getNext() once untilOffset has been reached") - val r = consumer.get(requestOffset, pollTimeout) - requestOffset += 1 - r +/** + * An iterator that fetches messages directly from Kafka for the offsets in partition. + * Uses a cached consumer where possible to take advantage of prefetching. + * Intended for compacted topics, or other cases when non-consecutive offsets are ok. + */ +private class CompactedKafkaRDDIterator[K, V]( + part: KafkaRDDPartition, + context: TaskContext, + kafkaParams: ju.Map[String, Object], + useConsumerCache: Boolean, + pollTimeout: Long, + cacheInitialCapacity: Int, + cacheMaxCapacity: Int, + cacheLoadFactor: Float + ) extends KafkaRDDIterator[K, V]( + part, + context, + kafkaParams, + useConsumerCache, + pollTimeout, + cacheInitialCapacity, + cacheMaxCapacity, + cacheLoadFactor + ) { + + consumer.compactedStart(part.fromOffset, pollTimeout) + + private var nextRecord = consumer.compactedNext(pollTimeout) + + private var okNext: Boolean = true + + override def hasNext(): Boolean = okNext + + override def next(): ConsumerRecord[K, V] = { + if (!hasNext) { + throw new ju.NoSuchElementException("Can't call getNext() once untilOffset has been reached") + } + val r = nextRecord + if (r.offset + 1 >= part.untilOffset) { + okNext = false + } else { + nextRecord = consumer.compactedNext(pollTimeout) + if (nextRecord.offset >= part.untilOffset) { + okNext = false + consumer.compactedPrevious() + } } + r } } diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala index be373af0599cc..271adea1df731 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala @@ -18,16 +18,22 @@ package org.apache.spark.streaming.kafka010 import java.{ util => ju } +import java.io.File import scala.collection.JavaConverters._ import scala.util.Random +import kafka.common.TopicAndPartition +import kafka.log._ +import kafka.message._ +import kafka.utils.Pool import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.serialization.StringDeserializer import org.scalatest.BeforeAndAfterAll import org.apache.spark._ import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.streaming.kafka010.mocks.MockTime class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { @@ -64,6 +70,41 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { private val preferredHosts = LocationStrategies.PreferConsistent + private def compactLogs(topic: String, partition: Int, messages: Array[(String, String)]) { + val mockTime = new MockTime() + // LogCleaner in 0.10 version of Kafka is still expecting the old TopicAndPartition api + val logs = new Pool[TopicAndPartition, Log]() + val logDir = kafkaTestUtils.brokerLogDir + val dir = new File(logDir, topic + "-" + partition) + dir.mkdirs() + val logProps = new ju.Properties() + logProps.put(LogConfig.CleanupPolicyProp, LogConfig.Compact) + logProps.put(LogConfig.MinCleanableDirtyRatioProp, java.lang.Float.valueOf(0.1f)) + val log = new Log( + dir, + LogConfig(logProps), + 0L, + mockTime.scheduler, + mockTime + ) + messages.foreach { case (k, v) => + val msg = new ByteBufferMessageSet( + NoCompressionCodec, + new Message(v.getBytes, k.getBytes, Message.NoTimestamp, Message.CurrentMagicValue)) + log.append(msg) + } + log.roll() + logs.put(TopicAndPartition(topic, partition), log) + + val cleaner = new LogCleaner(CleanerConfig(), logDirs = Array(dir), logs = logs) + cleaner.startup() + cleaner.awaitCleaned(topic, partition, log.activeSegment.baseOffset, 1000) + + cleaner.shutdown() + mockTime.scheduler.shutdown() + } + + test("basic usage") { val topic = s"topicbasic-${Random.nextInt}-${System.currentTimeMillis}" kafkaTestUtils.createTopic(topic) @@ -102,6 +143,71 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { } } + test("compacted topic") { + val compactConf = sparkConf.clone() + compactConf.set("spark.streaming.kafka.allowNonConsecutiveOffsets", "true") + sc.stop() + sc = new SparkContext(compactConf) + val topic = s"topiccompacted-${Random.nextInt}-${System.currentTimeMillis}" + + val messages = Array( + ("a", "1"), + ("a", "2"), + ("b", "1"), + ("c", "1"), + ("c", "2"), + ("b", "2"), + ("b", "3") + ) + val compactedMessages = Array( + ("a", "2"), + ("b", "3"), + ("c", "2") + ) + + compactLogs(topic, 0, messages) + + val props = new ju.Properties() + props.put("cleanup.policy", "compact") + props.put("flush.messages", "1") + props.put("segment.ms", "1") + props.put("segment.bytes", "256") + kafkaTestUtils.createTopic(topic, 1, props) + + + val kafkaParams = getKafkaParams() + + val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size)) + + val rdd = KafkaUtils.createRDD[String, String]( + sc, kafkaParams, offsetRanges, preferredHosts + ).map(m => m.key -> m.value) + + val received = rdd.collect.toSet + assert(received === compactedMessages.toSet) + + // size-related method optimizations return sane results + assert(rdd.count === compactedMessages.size) + assert(rdd.countApprox(0).getFinalValue.mean === compactedMessages.size) + assert(!rdd.isEmpty) + assert(rdd.take(1).size === 1) + assert(rdd.take(1).head === compactedMessages.head) + assert(rdd.take(messages.size + 10).size === compactedMessages.size) + + val emptyRdd = KafkaUtils.createRDD[String, String]( + sc, kafkaParams, Array(OffsetRange(topic, 0, 0, 0)), preferredHosts) + + assert(emptyRdd.isEmpty) + + // invalid offset ranges throw exceptions + val badRanges = Array(OffsetRange(topic, 0, 0, messages.size + 1)) + intercept[SparkException] { + val result = KafkaUtils.createRDD[String, String](sc, kafkaParams, badRanges, preferredHosts) + .map(_.value) + .collect() + } + } + test("iterator boundary conditions") { // the idea is to find e.g. off-by-one errors between what kafka has available and the rdd val topic = s"topicboundary-${Random.nextInt}-${System.currentTimeMillis}" diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala index 6c7024ea4b5a5..70b579d96d692 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala @@ -162,17 +162,22 @@ private[kafka010] class KafkaTestUtils extends Logging { } /** Create a Kafka topic and wait until it is propagated to the whole cluster */ - def createTopic(topic: String, partitions: Int): Unit = { - AdminUtils.createTopic(zkUtils, topic, partitions, 1) + def createTopic(topic: String, partitions: Int, config: Properties): Unit = { + AdminUtils.createTopic(zkUtils, topic, partitions, 1, config) // wait until metadata is propagated (0 until partitions).foreach { p => waitUntilMetadataIsPropagated(topic, p) } } + /** Create a Kafka topic and wait until it is propagated to the whole cluster */ + def createTopic(topic: String, partitions: Int): Unit = { + createTopic(topic, partitions, new Properties()) + } + /** Create a Kafka topic and wait until it is propagated to the whole cluster */ def createTopic(topic: String): Unit = { - createTopic(topic, 1) + createTopic(topic, 1, new Properties()) } /** Java-friendly function for sending messages to the Kafka broker */ @@ -196,12 +201,24 @@ private[kafka010] class KafkaTestUtils extends Logging { producer = null } + /** Send the array of (key, value) messages to the Kafka broker */ + def sendMessages(topic: String, messages: Array[(String, String)]): Unit = { + producer = new KafkaProducer[String, String](producerConfiguration) + messages.foreach { message => + producer.send(new ProducerRecord[String, String](topic, message._1, message._2)) + } + producer.close() + producer = null + } + + val brokerLogDir = Utils.createTempDir().getAbsolutePath + private def brokerConfiguration: Properties = { val props = new Properties() props.put("broker.id", "0") props.put("host.name", "localhost") props.put("port", brokerPort.toString) - props.put("log.dir", Utils.createTempDir().getAbsolutePath) + props.put("log.dir", brokerLogDir) props.put("zookeeper.connect", zkAddress) props.put("log.flush.interval.messages", "1") props.put("replica.socket.timeout.ms", "1500") diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala new file mode 100644 index 0000000000000..928e1a6ef54b9 --- /dev/null +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010.mocks + +import java.util.concurrent.TimeUnit + +import scala.collection.mutable.PriorityQueue + +import kafka.utils.{Scheduler, Time} + +/** + * A mock scheduler that executes tasks synchronously using a mock time instance. + * Tasks are executed synchronously when the time is advanced. + * This class is meant to be used in conjunction with MockTime. + * + * Example usage + * + * val time = new MockTime + * time.scheduler.schedule("a task", println("hello world: " + time.milliseconds), delay = 1000) + * time.sleep(1001) // this should cause our scheduled task to fire + * + * + * Incrementing the time to the exact next execution time of a task will result in that task + * executing (it as if execution itself takes no time). + */ +private[kafka010] class MockScheduler(val time: Time) extends Scheduler { + + /* a priority queue of tasks ordered by next execution time */ + var tasks = new PriorityQueue[MockTask]() + + def isStarted: Boolean = true + + def startup(): Unit = {} + + def shutdown(): Unit = synchronized { + tasks.foreach(_.fun()) + tasks.clear() + } + + /** + * Check for any tasks that need to execute. Since this is a mock scheduler this check only occurs + * when this method is called and the execution happens synchronously in the calling thread. + * If you are using the scheduler associated with a MockTime instance this call + * will be triggered automatically. + */ + def tick(): Unit = synchronized { + val now = time.milliseconds + while(!tasks.isEmpty && tasks.head.nextExecution <= now) { + /* pop and execute the task with the lowest next execution time */ + val curr = tasks.dequeue + curr.fun() + /* if the task is periodic, reschedule it and re-enqueue */ + if(curr.periodic) { + curr.nextExecution += curr.period + this.tasks += curr + } + } + } + + def schedule( + name: String, + fun: () => Unit, + delay: Long = 0, + period: Long = -1, + unit: TimeUnit = TimeUnit.MILLISECONDS): Unit = synchronized { + tasks += MockTask(name, fun, time.milliseconds + delay, period = period) + tick() + } + +} + +case class MockTask( + val name: String, + val fun: () => Unit, + var nextExecution: Long, + val period: Long) extends Ordered[MockTask] { + def periodic: Boolean = period >= 0 + def compare(t: MockTask): Int = { + java.lang.Long.compare(t.nextExecution, nextExecution) + } +} diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala new file mode 100644 index 0000000000000..a68f94db1f689 --- /dev/null +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010.mocks + +import java.util.concurrent._ + +import kafka.utils.Time + +/** + * A class used for unit testing things which depend on the Time interface. + * + * This class never manually advances the clock, it only does so when you call + * sleep(ms) + * + * It also comes with an associated scheduler instance for managing background tasks in + * a deterministic way. + */ +private[kafka010] class MockTime(@volatile private var currentMs: Long) extends Time { + + val scheduler = new MockScheduler(this) + + def this() = this(System.currentTimeMillis) + + def milliseconds: Long = currentMs + + def nanoseconds: Long = + TimeUnit.NANOSECONDS.convert(currentMs, TimeUnit.MILLISECONDS) + + def sleep(ms: Long) { + this.currentMs += ms + scheduler.tick() + } + + override def toString(): String = s"MockTime($milliseconds)" + +} From 414ee867ba0835b97aae2e8d4e489e1879c251dd Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 27 Feb 2018 08:44:25 -0800 Subject: [PATCH 0399/2461] [SPARK-23523][SQL] Fix the incorrect result caused by the rule OptimizeMetadataOnlyQuery ## What changes were proposed in this pull request? ```Scala val tablePath = new File(s"${path.getCanonicalPath}/cOl3=c/cOl1=a/cOl5=e") Seq(("a", "b", "c", "d", "e")).toDF("cOl1", "cOl2", "cOl3", "cOl4", "cOl5") .write.json(tablePath.getCanonicalPath) val df = spark.read.json(path.getCanonicalPath).select("CoL1", "CoL5", "CoL3").distinct() df.show() ``` It generates a wrong result. ``` [c,e,a] ``` We have a bug in the rule `OptimizeMetadataOnlyQuery `. We should respect the attribute order in the original leaf node. This PR is to fix it. ## How was this patch tested? Added a test case Author: gatorsmile Closes #20684 from gatorsmile/optimizeMetadataOnly. --- .../plans/logical/LocalRelation.scala | 9 ++++---- .../execution/OptimizeMetadataOnlyQuery.scala | 12 ++++++++-- .../datasources/HadoopFsRelation.scala | 3 +++ .../OptimizeMetadataOnlyQuerySuite.scala | 22 +++++++++++++++++++ 4 files changed, 40 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index d73d7e73f28d5..b05508db786ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -43,10 +43,11 @@ object LocalRelation { } } -case class LocalRelation(output: Seq[Attribute], - data: Seq[InternalRow] = Nil, - // Indicates whether this relation has data from a streaming source. - override val isStreaming: Boolean = false) +case class LocalRelation( + output: Seq[Attribute], + data: Seq[InternalRow] = Nil, + // Indicates whether this relation has data from a streaming source. + override val isStreaming: Boolean = false) extends LeafNode with analysis.MultiInstanceRelation { // A local relation must have resolved output. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index 18f6f697bc857..0613d9053f826 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution +import java.util.Locale + +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.{HiveTableRelation, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ @@ -80,8 +83,13 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic private def getPartitionAttrs( partitionColumnNames: Seq[String], relation: LogicalPlan): Seq[Attribute] = { - val partColumns = partitionColumnNames.map(_.toLowerCase).toSet - relation.output.filter(a => partColumns.contains(a.name.toLowerCase)) + val attrMap = relation.output.map(_.name.toLowerCase(Locale.ROOT)).zip(relation.output).toMap + partitionColumnNames.map { colName => + attrMap.getOrElse(colName.toLowerCase(Locale.ROOT), + throw new AnalysisException(s"Unable to find the column `$colName` " + + s"given [${relation.output.map(_.name).mkString(", ")}]") + ) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index 6b34638529770..ac574b07ec497 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -67,6 +67,9 @@ case class HadoopFsRelation( } } + // When data schema and partition schema have the overlapped columns, the output + // schema respects the order of data schema for the overlapped columns, but respect + // the data types of partition schema val schema: StructType = { StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++ partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala index 78c1e5dae566d..a543eb8351656 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.execution +import java.io.File + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_METADATA_ONLY import org.apache.spark.sql.test.SharedSQLContext class OptimizeMetadataOnlyQuerySuite extends QueryTest with SharedSQLContext { @@ -125,4 +128,23 @@ class OptimizeMetadataOnlyQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT COUNT(DISTINCT p) FROM t_1000").collect() } } + + test("Incorrect result caused by the rule OptimizeMetadataOnlyQuery") { + withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") { + withTempPath { path => + val tablePath = new File(s"${path.getCanonicalPath}/cOl3=c/cOl1=a/cOl5=e") + Seq(("a", "b", "c", "d", "e")).toDF("cOl1", "cOl2", "cOl3", "cOl4", "cOl5") + .write.json(tablePath.getCanonicalPath) + + val df = spark.read.json(path.getCanonicalPath).select("CoL1", "CoL5", "CoL3").distinct() + checkAnswer(df, Row("a", "e", "c")) + + val localRelation = df.queryExecution.optimizedPlan.collectFirst { + case l: LocalRelation => l + } + assert(localRelation.nonEmpty, "expect to see a LocalRelation") + assert(localRelation.get.output.map(_.name) == Seq("cOl3", "cOl1", "cOl5")) + } + } + } } From ecb8b383af1cf1b67f3111c148229e00c9c17c40 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 27 Feb 2018 11:12:32 -0800 Subject: [PATCH 0400/2461] [SPARK-23365][CORE] Do not adjust num executors when killing idle executors. The ExecutorAllocationManager should not adjust the target number of executors when killing idle executors, as it has already adjusted the target number down based on the task backlog. The name `replace` was misleading with DynamicAllocation on, as the target number of executors is changed outside of the call to `killExecutors`, so I adjusted that name. Also separated out the logic of `countFailures` as you don't always want that tied to `replace`. While I was there I made two changes that weren't directly related to this: 1) Fixed `countFailures` in a couple cases where it was getting an incorrect value since it used to be tied to `replace`, eg. when killing executors on a blacklisted node. 2) hard error if you call `sc.killExecutors` with dynamic allocation on, since that's another way the ExecutorAllocationManager and the CoarseGrainedSchedulerBackend would get out of sync. Added a unit test case which verifies that the calls to ExecutorAllocationClient do not adjust the number of executors. Author: Imran Rashid Closes #20604 from squito/SPARK-23365. --- .../spark/ExecutorAllocationClient.scala | 15 +++-- .../spark/ExecutorAllocationManager.scala | 20 ++++-- .../scala/org/apache/spark/SparkContext.scala | 13 +++- .../spark/scheduler/BlacklistTracker.scala | 3 +- .../CoarseGrainedSchedulerBackend.scala | 22 ++++--- .../ExecutorAllocationManagerSuite.scala | 66 ++++++++++++++++++- .../StandaloneDynamicAllocationSuite.scala | 3 +- .../scheduler/BlacklistTrackerSuite.scala | 14 ++-- 8 files changed, 121 insertions(+), 35 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala index 9112d93a86b2a..63d87b4cd385c 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala @@ -55,18 +55,18 @@ private[spark] trait ExecutorAllocationClient { /** * Request that the cluster manager kill the specified executors. * - * When asking the executor to be replaced, the executor loss is considered a failure, and - * killed tasks that are running on the executor will count towards the failure limits. If no - * replacement is being requested, then the tasks will not count towards the limit. - * * @param executorIds identifiers of executors to kill - * @param replace whether to replace the killed executors with new ones, default false + * @param adjustTargetNumExecutors whether the target number of executors will be adjusted down + * after these executors have been killed + * @param countFailures if there are tasks running on the executors when they are killed, whether + * to count those failures toward task failure limits * @param force whether to force kill busy executors, default false * @return the ids of the executors acknowledged by the cluster manager to be removed. */ def killExecutors( executorIds: Seq[String], - replace: Boolean = false, + adjustTargetNumExecutors: Boolean, + countFailures: Boolean, force: Boolean = false): Seq[String] /** @@ -81,7 +81,8 @@ private[spark] trait ExecutorAllocationClient { * @return whether the request is acknowledged by the cluster manager. */ def killExecutor(executorId: String): Boolean = { - val killedExecutors = killExecutors(Seq(executorId)) + val killedExecutors = killExecutors(Seq(executorId), adjustTargetNumExecutors = true, + countFailures = false) killedExecutors.nonEmpty && killedExecutors(0).equals(executorId) } } diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 6c59038f2a6c1..189d91333c045 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{DYN_ALLOCATION_MAX_EXECUTORS, DYN_ALLOCATION_MIN_EXECUTORS} import org.apache.spark.metrics.source.Source import org.apache.spark.scheduler._ +import org.apache.spark.storage.BlockManagerMaster import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} /** @@ -81,7 +82,8 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} private[spark] class ExecutorAllocationManager( client: ExecutorAllocationClient, listenerBus: LiveListenerBus, - conf: SparkConf) + conf: SparkConf, + blockManagerMaster: BlockManagerMaster) extends Logging { allocationManager => @@ -151,7 +153,7 @@ private[spark] class ExecutorAllocationManager( private var clock: Clock = new SystemClock() // Listener for Spark events that impact the allocation policy - private val listener = new ExecutorAllocationListener + val listener = new ExecutorAllocationListener // Executor that handles the scheduling task. private val executor = @@ -334,6 +336,11 @@ private[spark] class ExecutorAllocationManager( // If the new target has not changed, avoid sending a message to the cluster manager if (numExecutorsTarget < oldNumExecutorsTarget) { + // We lower the target number of executors but don't actively kill any yet. Killing is + // controlled separately by an idle timeout. It's still helpful to reduce the target number + // in case an executor just happens to get lost (eg., bad hardware, or the cluster manager + // preempts it) -- in that case, there is no point in trying to immediately get a new + // executor, since we wouldn't even use it yet. client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) logDebug(s"Lowering target number of executors to $numExecutorsTarget (previously " + s"$oldNumExecutorsTarget) because not all requested executors are actually needed") @@ -455,7 +462,10 @@ private[spark] class ExecutorAllocationManager( val executorsRemoved = if (testing) { executorIdsToBeRemoved } else { - client.killExecutors(executorIdsToBeRemoved) + // We don't want to change our target number of executors, because we already did that + // when the task backlog decreased. + client.killExecutors(executorIdsToBeRemoved, adjustTargetNumExecutors = false, + countFailures = false, force = false) } // [SPARK-21834] killExecutors api reduces the target number of executors. // So we need to update the target with desired value. @@ -575,7 +585,7 @@ private[spark] class ExecutorAllocationManager( // Note that it is not necessary to query the executors since all the cached // blocks we are concerned with are reported to the driver. Note that this // does not include broadcast blocks. - val hasCachedBlocks = SparkEnv.get.blockManager.master.hasCachedBlocks(executorId) + val hasCachedBlocks = blockManagerMaster.hasCachedBlocks(executorId) val now = clock.getTimeMillis() val timeout = { if (hasCachedBlocks) { @@ -610,7 +620,7 @@ private[spark] class ExecutorAllocationManager( * This class is intentionally conservative in its assumptions about the relative ordering * and consistency of events returned by the listener. */ - private class ExecutorAllocationListener extends SparkListener { + private[spark] class ExecutorAllocationListener extends SparkListener { private val stageIdToNumTasks = new mutable.HashMap[Int, Int] // Number of running tasks per stage including speculative tasks. diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index dc531e3337014..5e8595603cc90 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -534,7 +534,8 @@ class SparkContext(config: SparkConf) extends Logging { schedulerBackend match { case b: ExecutorAllocationClient => Some(new ExecutorAllocationManager( - schedulerBackend.asInstanceOf[ExecutorAllocationClient], listenerBus, _conf)) + schedulerBackend.asInstanceOf[ExecutorAllocationClient], listenerBus, _conf, + _env.blockManager.master)) case _ => None } @@ -1633,6 +1634,8 @@ class SparkContext(config: SparkConf) extends Logging { * :: DeveloperApi :: * Request that the cluster manager kill the specified executors. * + * This is not supported when dynamic allocation is turned on. + * * @note This is an indication to the cluster manager that the application wishes to adjust * its resource usage downwards. If the application wishes to replace the executors it kills * through this method with new ones, it should follow up explicitly with a call to @@ -1644,7 +1647,10 @@ class SparkContext(config: SparkConf) extends Logging { def killExecutors(executorIds: Seq[String]): Boolean = { schedulerBackend match { case b: ExecutorAllocationClient => - b.killExecutors(executorIds, replace = false, force = true).nonEmpty + require(executorAllocationManager.isEmpty, + "killExecutors() unsupported with Dynamic Allocation turned on") + b.killExecutors(executorIds, adjustTargetNumExecutors = true, countFailures = false, + force = true).nonEmpty case _ => logWarning("Killing executors is not supported by current scheduler.") false @@ -1682,7 +1688,8 @@ class SparkContext(config: SparkConf) extends Logging { private[spark] def killAndReplaceExecutor(executorId: String): Boolean = { schedulerBackend match { case b: ExecutorAllocationClient => - b.killExecutors(Seq(executorId), replace = true, force = true).nonEmpty + b.killExecutors(Seq(executorId), adjustTargetNumExecutors = false, countFailures = true, + force = true).nonEmpty case _ => logWarning("Killing executors is not supported by current scheduler.") false diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index cd8e61d6d0208..952598f6de19d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -152,7 +152,8 @@ private[scheduler] class BlacklistTracker ( case Some(a) => logInfo(s"Killing blacklisted executor id $exec " + s"since ${config.BLACKLIST_KILL_ENABLED.key} is set.") - a.killExecutors(Seq(exec), true, true) + a.killExecutors(Seq(exec), adjustTargetNumExecutors = false, countFailures = false, + force = true) case None => logWarning(s"Not attempting to kill blacklisted executor id $exec " + s"since allocation client is not defined.") diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 4d75063fbf1c5..5627a557a12f3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -147,7 +147,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case KillExecutorsOnHost(host) => scheduler.getExecutorsAliveOnHost(host).foreach { exec => - killExecutors(exec.toSeq, replace = true, force = true) + killExecutors(exec.toSeq, adjustTargetNumExecutors = false, countFailures = false, + force = true) } case UpdateDelegationTokens(newDelegationTokens) => @@ -584,18 +585,18 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp /** * Request that the cluster manager kill the specified executors. * - * When asking the executor to be replaced, the executor loss is considered a failure, and - * killed tasks that are running on the executor will count towards the failure limits. If no - * replacement is being requested, then the tasks will not count towards the limit. - * * @param executorIds identifiers of executors to kill - * @param replace whether to replace the killed executors with new ones, default false + * @param adjustTargetNumExecutors whether the target number of executors be adjusted down + * after these executors have been killed + * @param countFailures if there are tasks running on the executors when they are killed, whether + * those failures be counted to task failure limits? * @param force whether to force kill busy executors, default false * @return the ids of the executors acknowledged by the cluster manager to be removed. */ final override def killExecutors( executorIds: Seq[String], - replace: Boolean, + adjustTargetNumExecutors: Boolean, + countFailures: Boolean, force: Boolean): Seq[String] = { logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}") @@ -610,7 +611,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val executorsToKill = knownExecutors .filter { id => !executorsPendingToRemove.contains(id) } .filter { id => force || !scheduler.isExecutorBusy(id) } - executorsToKill.foreach { id => executorsPendingToRemove(id) = !replace } + executorsToKill.foreach { id => executorsPendingToRemove(id) = !countFailures } logInfo(s"Actual list of executor(s) to be killed is ${executorsToKill.mkString(", ")}") @@ -618,12 +619,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // with the cluster manager to avoid allocating new ones. When computing the new target, // take into account executors that are pending to be added or removed. val adjustTotalExecutors = - if (!replace) { + if (adjustTargetNumExecutors) { requestedTotalExecutors = math.max(requestedTotalExecutors - executorsToKill.size, 0) if (requestedTotalExecutors != (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) { logDebug( - s"""killExecutors($executorIds, $replace, $force): Executor counts do not match: + s"""killExecutors($executorIds, $adjustTargetNumExecutors, $countFailures, $force): + |Executor counts do not match: |requestedTotalExecutors = $requestedTotalExecutors |numExistingExecutors = $numExistingExecutors |numPendingExecutors = $numPendingExecutors diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index a0cae5a9e011c..9807d1269e3d4 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark import scala.collection.mutable +import org.mockito.Matchers.{any, eq => meq} +import org.mockito.Mockito.{mock, never, verify, when} import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics @@ -26,6 +28,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.ExternalClusterManager import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.local.LocalSchedulerBackend +import org.apache.spark.storage.BlockManagerMaster import org.apache.spark.util.ManualClock /** @@ -1050,6 +1053,66 @@ class ExecutorAllocationManagerSuite assert(removeTimes(manager) === Map.empty) } + test("SPARK-23365 Don't update target num executors when killing idle executors") { + val minExecutors = 1 + val initialExecutors = 1 + val maxExecutors = 2 + val conf = new SparkConf() + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.shuffle.service.enabled", "true") + .set("spark.dynamicAllocation.minExecutors", minExecutors.toString) + .set("spark.dynamicAllocation.maxExecutors", maxExecutors.toString) + .set("spark.dynamicAllocation.initialExecutors", initialExecutors.toString) + .set("spark.dynamicAllocation.schedulerBacklogTimeout", "1000ms") + .set("spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", "1000ms") + .set("spark.dynamicAllocation.executorIdleTimeout", s"3000ms") + val mockAllocationClient = mock(classOf[ExecutorAllocationClient]) + val mockBMM = mock(classOf[BlockManagerMaster]) + val manager = new ExecutorAllocationManager( + mockAllocationClient, mock(classOf[LiveListenerBus]), conf, mockBMM) + val clock = new ManualClock() + manager.setClock(clock) + + when(mockAllocationClient.requestTotalExecutors(meq(2), any(), any())).thenReturn(true) + // test setup -- job with 2 tasks, scale up to two executors + assert(numExecutorsTarget(manager) === 1) + manager.listener.onExecutorAdded(SparkListenerExecutorAdded( + clock.getTimeMillis(), "executor-1", new ExecutorInfo("host1", 1, Map.empty))) + manager.listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 2))) + clock.advance(1000) + manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.getTimeMillis()) + assert(numExecutorsTarget(manager) === 2) + val taskInfo0 = createTaskInfo(0, 0, "executor-1") + manager.listener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo0)) + manager.listener.onExecutorAdded(SparkListenerExecutorAdded( + clock.getTimeMillis(), "executor-2", new ExecutorInfo("host1", 1, Map.empty))) + val taskInfo1 = createTaskInfo(1, 1, "executor-2") + manager.listener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo1)) + assert(numExecutorsTarget(manager) === 2) + + // have one task finish -- we should adjust the target number of executors down + // but we should *not* kill any executors yet + manager.listener.onTaskEnd(SparkListenerTaskEnd(0, 0, null, Success, taskInfo0, null)) + assert(maxNumExecutorsNeeded(manager) === 1) + assert(numExecutorsTarget(manager) === 2) + clock.advance(1000) + manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.getTimeMillis()) + assert(numExecutorsTarget(manager) === 1) + verify(mockAllocationClient, never).killExecutors(any(), any(), any(), any()) + + // now we cross the idle timeout for executor-1, so we kill it. the really important + // thing here is that we do *not* ask the executor allocation client to adjust the target + // number of executors down + when(mockAllocationClient.killExecutors(Seq("executor-1"), false, false, false)) + .thenReturn(Seq("executor-1")) + clock.advance(3000) + schedule(manager) + assert(maxNumExecutorsNeeded(manager) === 1) + assert(numExecutorsTarget(manager) === 1) + // here's the important verify -- we did kill the executors, but did not adjust the target count + verify(mockAllocationClient).killExecutors(Seq("executor-1"), false, false, false) + } + private def createSparkContext( minExecutors: Int = 1, maxExecutors: Int = 5, @@ -1268,7 +1331,8 @@ private class DummyLocalSchedulerBackend (sc: SparkContext, sb: SchedulerBackend override def killExecutors( executorIds: Seq[String], - replace: Boolean, + adjustTargetNumExecutors: Boolean, + countFailures: Boolean, force: Boolean): Seq[String] = executorIds override def start(): Unit = sb.start() diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index c21ee7d26f8ca..27cc47496c805 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -573,7 +573,8 @@ class StandaloneDynamicAllocationSuite syncExecutors(sc) sc.schedulerBackend match { case b: CoarseGrainedSchedulerBackend => - b.killExecutors(Seq(executorId), replace = false, force) + b.killExecutors(Seq(executorId), adjustTargetNumExecutors = true, countFailures = false, + force) case _ => fail("expected coarse grained scheduler") } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index afebcdd7b9e31..06d7afaaff55c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -479,7 +479,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M test("blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") { val allocationClientMock = mock[ExecutorAllocationClient] - when(allocationClientMock.killExecutors(any(), any(), any())).thenReturn(Seq("called")) + when(allocationClientMock.killExecutors(any(), any(), any(), any())).thenReturn(Seq("called")) when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] { // To avoid a race between blacklisting and killing, it is important that the nodeBlacklist // is updated before we ask the executor allocation client to kill all the executors @@ -517,7 +517,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist1.execToFailures) - verify(allocationClientMock, never).killExecutors(any(), any(), any()) + verify(allocationClientMock, never).killExecutors(any(), any(), any(), any()) verify(allocationClientMock, never).killExecutorsOnHost(any()) // Enable auto-kill. Blacklist an executor and make sure killExecutors is called. @@ -533,7 +533,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist2.execToFailures) - verify(allocationClientMock).killExecutors(Seq("1"), true, true) + verify(allocationClientMock).killExecutors(Seq("1"), false, false, true) val taskSetBlacklist3 = createTaskSetBlacklist(stageId = 1) // Fail 4 tasks in one task set on executor 2, so that executor gets blacklisted for the whole @@ -545,13 +545,13 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist3.execToFailures) - verify(allocationClientMock).killExecutors(Seq("2"), true, true) + verify(allocationClientMock).killExecutors(Seq("2"), false, false, true) verify(allocationClientMock).killExecutorsOnHost("hostA") } test("fetch failure blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") { val allocationClientMock = mock[ExecutorAllocationClient] - when(allocationClientMock.killExecutors(any(), any(), any())).thenReturn(Seq("called")) + when(allocationClientMock.killExecutors(any(), any(), any(), any())).thenReturn(Seq("called")) when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] { // To avoid a race between blacklisting and killing, it is important that the nodeBlacklist // is updated before we ask the executor allocation client to kill all the executors @@ -571,7 +571,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M conf.set(config.BLACKLIST_KILL_ENABLED, false) blacklist.updateBlacklistForFetchFailure("hostA", exec = "1") - verify(allocationClientMock, never).killExecutors(any(), any(), any()) + verify(allocationClientMock, never).killExecutors(any(), any(), any(), any()) verify(allocationClientMock, never).killExecutorsOnHost(any()) // Enable auto-kill. Blacklist an executor and make sure killExecutors is called. @@ -580,7 +580,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M clock.advance(1000) blacklist.updateBlacklistForFetchFailure("hostA", exec = "1") - verify(allocationClientMock).killExecutors(Seq("1"), true, true) + verify(allocationClientMock).killExecutors(Seq("1"), false, false, true) verify(allocationClientMock, never).killExecutorsOnHost(any()) assert(blacklist.executorIdToBlacklistStatus.contains("1")) From 598446b74b61fee272d3aee3a2e9a3fc90a70d6a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 27 Feb 2018 11:33:10 -0800 Subject: [PATCH 0401/2461] [SPARK-23501][UI] Refactor AllStagesPage in order to avoid redundant code As suggested in #20651, the code is very redundant in `AllStagesPage` and modifying it is a copy-and-paste work. We should avoid such a pattern, which is error prone, and have a cleaner solution which avoids code redundancy. existing UTs Author: Marco Gaido Closes #20663 from mgaido91/SPARK-23475_followup. --- .../apache/spark/ui/jobs/AllStagesPage.scala | 261 +++++++----------- 1 file changed, 102 insertions(+), 159 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index 38450b9126ff0..4658aa1cea3f1 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -19,46 +19,20 @@ package org.apache.spark.ui.jobs import javax.servlet.http.HttpServletRequest -import scala.xml.{Node, NodeSeq} +import scala.xml.{Attribute, Elem, Node, NodeSeq, Null, Text} import org.apache.spark.scheduler.Schedulable -import org.apache.spark.status.PoolData -import org.apache.spark.status.api.v1._ +import org.apache.spark.status.{AppSummary, PoolData} +import org.apache.spark.status.api.v1.{StageData, StageStatus} import org.apache.spark.ui.{UIUtils, WebUIPage} /** Page showing list of all ongoing and recently finished stages and pools */ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { private val sc = parent.sc + private val subPath = "stages" private def isFairScheduler = parent.isFairScheduler def render(request: HttpServletRequest): Seq[Node] = { - val allStages = parent.store.stageList(null) - - val activeStages = allStages.filter(_.status == StageStatus.ACTIVE) - val pendingStages = allStages.filter(_.status == StageStatus.PENDING) - val skippedStages = allStages.filter(_.status == StageStatus.SKIPPED) - val completedStages = allStages.filter(_.status == StageStatus.COMPLETE) - val failedStages = allStages.filter(_.status == StageStatus.FAILED).reverse - - val numFailedStages = failedStages.size - val subPath = "stages" - - val activeStagesTable = - new StageTableBase(parent.store, request, activeStages, "active", "activeStage", - parent.basePath, subPath, parent.isFairScheduler, parent.killEnabled, false) - val pendingStagesTable = - new StageTableBase(parent.store, request, pendingStages, "pending", "pendingStage", - parent.basePath, subPath, parent.isFairScheduler, false, false) - val completedStagesTable = - new StageTableBase(parent.store, request, completedStages, "completed", "completedStage", - parent.basePath, subPath, parent.isFairScheduler, false, false) - val skippedStagesTable = - new StageTableBase(parent.store, request, skippedStages, "skipped", "skippedStage", - parent.basePath, subPath, parent.isFairScheduler, false, false) - val failedStagesTable = - new StageTableBase(parent.store, request, failedStages, "failed", "failedStage", - parent.basePath, subPath, parent.isFairScheduler, false, true) - // For now, pool information is only accessible in live UIs val pools = sc.map(_.getAllPools).getOrElse(Seq.empty[Schedulable]).map { pool => val uiPool = parent.store.asOption(parent.store.pool(pool.name)).getOrElse( @@ -67,152 +41,121 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { }.toMap val poolTable = new PoolTable(pools, parent) - val shouldShowActiveStages = activeStages.nonEmpty - val shouldShowPendingStages = pendingStages.nonEmpty - val shouldShowCompletedStages = completedStages.nonEmpty - val shouldShowSkippedStages = skippedStages.nonEmpty - val shouldShowFailedStages = failedStages.nonEmpty + val allStatuses = Seq(StageStatus.ACTIVE, StageStatus.PENDING, StageStatus.COMPLETE, + StageStatus.SKIPPED, StageStatus.FAILED) + val allStages = parent.store.stageList(null) val appSummary = parent.store.appSummary() - val completedStageNumStr = if (appSummary.numCompletedStages == completedStages.size) { - s"${appSummary.numCompletedStages}" - } else { - s"${appSummary.numCompletedStages}, only showing ${completedStages.size}" - } + + val (summaries, tables) = allStatuses.map( + summaryAndTableForStatus(allStages, appSummary, _, request)).unzip val summary: NodeSeq =
      - { - if (shouldShowActiveStages) { -
    • - Active Stages: - {activeStages.size} -
    • - } - } - { - if (shouldShowPendingStages) { -
    • - Pending Stages: - {pendingStages.size} -
    • - } - } - { - if (shouldShowCompletedStages) { -
    • - Completed Stages: - {completedStageNumStr} -
    • - } - } - { - if (shouldShowSkippedStages) { -
    • - Skipped Stages: - {skippedStages.size} -
    • - } - } - { - if (shouldShowFailedStages) { -
    • - Failed Stages: - {numFailedStages} -
    • - } - } + {summaries.flatten}
    - var content = summary ++ - { - if (sc.isDefined && isFairScheduler) { - -

    - - Fair Scheduler Pools ({pools.size}) -

    -
    ++ -
    - {poolTable.toNodeSeq} -
    - } else { - Seq.empty[Node] - } - } - if (shouldShowActiveStages) { - content ++= - -

    - - Active Stages ({activeStages.size}) -

    -
    ++ -
    - {activeStagesTable.toNodeSeq} -
    - } - if (shouldShowPendingStages) { - content ++= - + val poolsDescription = if (sc.isDefined && isFairScheduler) { +

    - Pending Stages ({pendingStages.size}) + Fair Scheduler Pools ({pools.size})

    ++ -
    - {pendingStagesTable.toNodeSeq} +
    + {poolTable.toNodeSeq}
    + } else { + Seq.empty[Node] + } + + val content = summary ++ poolsDescription ++ tables.flatten.flatten + + UIUtils.headerSparkPage("Stages for All Jobs", content, parent) + } + + private def summaryAndTableForStatus( + allStages: Seq[StageData], + appSummary: AppSummary, + status: StageStatus, + request: HttpServletRequest): (Option[Elem], Option[NodeSeq]) = { + val stages = if (status == StageStatus.FAILED) { + allStages.filter(_.status == status).reverse + } else { + allStages.filter(_.status == status) } - if (shouldShowCompletedStages) { - content ++= - -

    - - Completed Stages ({completedStageNumStr}) -

    -
    ++ -
    - {completedStagesTable.toNodeSeq} -
    + + if (stages.isEmpty) { + (None, None) + } else { + val killEnabled = status == StageStatus.ACTIVE && parent.killEnabled + val isFailedStage = status == StageStatus.FAILED + + val stagesTable = + new StageTableBase(parent.store, request, stages, statusName(status), stageTag(status), + parent.basePath, subPath, parent.isFairScheduler, killEnabled, isFailedStage) + val stagesSize = stages.size + (Some(summary(appSummary, status, stagesSize)), + Some(table(appSummary, status, stagesTable, stagesSize))) } - if (shouldShowSkippedStages) { - content ++= - -

    - - Skipped Stages ({skippedStages.size}) -

    -
    ++ -
    - {skippedStagesTable.toNodeSeq} -
    + } + + private def statusName(status: StageStatus): String = status match { + case StageStatus.ACTIVE => "active" + case StageStatus.COMPLETE => "completed" + case StageStatus.FAILED => "failed" + case StageStatus.PENDING => "pending" + case StageStatus.SKIPPED => "skipped" + } + + private def stageTag(status: StageStatus): String = s"${statusName(status)}Stage" + + private def headerDescription(status: StageStatus): String = statusName(status).capitalize + + private def summaryContent(appSummary: AppSummary, status: StageStatus, size: Int): String = { + if (status == StageStatus.COMPLETE && appSummary.numCompletedStages != size) { + s"${appSummary.numCompletedStages}, only showing $size" + } else { + s"$size" } - if (shouldShowFailedStages) { - content ++= - -

    - - Failed Stages ({numFailedStages}) -

    -
    ++ -
    - {failedStagesTable.toNodeSeq} -
    + } + + private def summary(appSummary: AppSummary, status: StageStatus, size: Int): Elem = { + val summary = +
  • + + {headerDescription(status)} Stages: + + {summaryContent(appSummary, status, size)} +
  • + + if (status == StageStatus.COMPLETE) { + summary % Attribute(None, "id", Text("completed-summary"), Null) + } else { + summary } - UIUtils.headerSparkPage("Stages for All Jobs", content, parent) + } + + private def table( + appSummary: AppSummary, + status: StageStatus, + stagesTable: StageTableBase, + size: Int): NodeSeq = { + val classSuffix = s"${statusName(status).capitalize}Stages" + +

    + + {headerDescription(status)} Stages ({summaryContent(appSummary, status, size)}) +

    +
    ++ +
    + {stagesTable.toNodeSeq} +
    } } From 23ac3aaba4a33bc3d31d01f21e93c4681ef6de03 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Wed, 28 Feb 2018 09:25:02 +0900 Subject: [PATCH 0402/2461] [SPARK-23417][PYTHON] Fix the build instructions supplied by exception messages in python streaming tests ## What changes were proposed in this pull request? Fix the build instructions supplied by exception messages in python streaming tests. I also added -DskipTests to the maven instructions to avoid the 170 minutes of scala tests that occurs each time one wants to add a jar to the assembly directory. ## How was this patch tested? - clone branch - run build/sbt package - run python/run-tests --modules "pyspark-streaming" , expect error message - follow instructions in error message. i.e., run build/sbt assembly/package streaming-kafka-0-8-assembly/assembly - rerun python tests, expect error message - follow instructions in error message. i.e run build/sbt -Pflume assembly/package streaming-flume-assembly/assembly - rerun python tests, see success. - repeated all of the above for mvn version of the process. Author: Bruce Robbins Closes #20638 from bersprockets/SPARK-23417_propa. --- python/pyspark/streaming/tests.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 5b86c1cb2c390..71f8101e34c50 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1477,8 +1477,8 @@ def search_kafka_assembly_jar(): raise Exception( ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + "You need to build Spark with " - "'build/sbt assembly/package streaming-kafka-0-8-assembly/assembly' or " - "'build/mvn -Pkafka-0-8 package' before running this test.") + "'build/sbt -Pkafka-0-8 assembly/package streaming-kafka-0-8-assembly/assembly' or " + "'build/mvn -DskipTests -Pkafka-0-8 package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kafka assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) @@ -1494,8 +1494,8 @@ def search_flume_assembly_jar(): raise Exception( ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) + "You need to build Spark with " - "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " - "'build/mvn -Pflume package' before running this test.") + "'build/sbt -Pflume assembly/package streaming-flume-assembly/assembly' or " + "'build/mvn -DskipTests -Pflume package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Flume assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) From b14993e1fcb68e1c946a671c6048605ab4afdf58 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 28 Feb 2018 11:00:54 +0900 Subject: [PATCH 0403/2461] [SPARK-23448][SQL] Clarify JSON and CSV parser behavior in document ## What changes were proposed in this pull request? Clarify JSON and CSV reader behavior in document. JSON doesn't support partial results for corrupted records. CSV only supports partial results for the records with more or less tokens. ## How was this patch tested? Pass existing tests. Author: Liang-Chi Hsieh Closes #20666 from viirya/SPARK-23448-2. --- python/pyspark/sql/readwriter.py | 30 ++++++++++--------- python/pyspark/sql/streaming.py | 30 ++++++++++--------- .../sql/catalyst/json/JacksonParser.scala | 3 ++ .../apache/spark/sql/DataFrameReader.scala | 22 +++++++------- .../datasources/csv/UnivocityParser.scala | 5 ++++ .../sql/streaming/DataStreamReader.scala | 22 +++++++------- 6 files changed, 64 insertions(+), 48 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 49af1bcee5ef8..9d05ac7cb39be 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -209,13 +209,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record, and puts the malformed string into a field configured by \ - ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ - a string type field named ``columnNameOfCorruptRecord`` in an user-defined \ - schema. If a schema does not have the field, it drops corrupt records during \ - parsing. When inferring a schema, it implicitly adds a \ - ``columnNameOfCorruptRecord`` field in an output schema. + * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + fields to ``null``. To keep corrupt records, an user can set a string type \ + field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ + schema does not have the field, it drops corrupt records during parsing. \ + When inferring a schema, it implicitly adds a ``columnNameOfCorruptRecord`` \ + field in an output schema. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. @@ -393,13 +393,15 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record, and puts the malformed string into a field configured by \ - ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ - a string type field named ``columnNameOfCorruptRecord`` in an \ - user-defined schema. If a schema does not have the field, it drops corrupt \ - records during parsing. When a length of parsed CSV tokens is shorter than \ - an expected length of a schema, it sets `null` for extra fields. + * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + fields to ``null``. To keep corrupt records, an user can set a string type \ + field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ + schema does not have the field, it drops corrupt records during parsing. \ + A record with less/more tokens than schema is not a corrupted record to CSV. \ + When it meets a record having fewer tokens than the length of the schema, \ + sets ``null`` to extra fields. When the record has more tokens than the \ + length of the schema, it drops extra tokens. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index e2a97acb5e2a7..cc622decfd682 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -442,13 +442,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record, and puts the malformed string into a field configured by \ - ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ - a string type field named ``columnNameOfCorruptRecord`` in an user-defined \ - schema. If a schema does not have the field, it drops corrupt records during \ - parsing. When inferring a schema, it implicitly adds a \ - ``columnNameOfCorruptRecord`` field in an output schema. + * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + fields to ``null``. To keep corrupt records, an user can set a string type \ + field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ + schema does not have the field, it drops corrupt records during parsing. \ + When inferring a schema, it implicitly adds a ``columnNameOfCorruptRecord`` \ + field in an output schema. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. @@ -621,13 +621,15 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record, and puts the malformed string into a field configured by \ - ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ - a string type field named ``columnNameOfCorruptRecord`` in an \ - user-defined schema. If a schema does not have the field, it drops corrupt \ - records during parsing. When a length of parsed CSV tokens is shorter than \ - an expected length of a schema, it sets `null` for extra fields. + * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + fields to ``null``. To keep corrupt records, an user can set a string type \ + field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ + schema does not have the field, it drops corrupt records during parsing. \ + A record with less/more tokens than schema is not a corrupted record to CSV. \ + When it meets a record having fewer tokens than the length of the schema, \ + sets ``null`` to extra fields. When the record has more tokens than the \ + length of the schema, it drops extra tokens. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index bd144c9575c72..7f6956994f31f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -357,6 +357,9 @@ class JacksonParser( } } catch { case e @ (_: RuntimeException | _: JsonProcessingException) => + // JSON parser currently doesn't support partial results for corrupted records. + // For such records, all fields other than the field configured by + // `columnNameOfCorruptRecord` are set to `null`. throw BadRecordException(() => recordLiteral(record), () => None, e) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 4274f120a375a..0139913aaa4e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -345,12 +345,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records * during parsing. *
      - *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep - * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` - * in an user-defined schema. If a schema does not have the field, it drops corrupt records - * during parsing. When inferring a schema, it implicitly adds a `columnNameOfCorruptRecord` - * field in an output schema.
    • + *
    • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a + * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To + * keep corrupt records, an user can set a string type field named + * `columnNameOfCorruptRecord` in an user-defined schema. If a schema does not have the + * field, it drops corrupt records during parsing. When inferring a schema, it implicitly + * adds a `columnNameOfCorruptRecord` field in an output schema.
    • *
    • `DROPMALFORMED` : ignores the whole corrupted records.
    • *
    • `FAILFAST` : throws an exception when it meets corrupted records.
    • *
    @@ -550,12 +550,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records * during parsing. It supports the following case-insensitive modes. *
      - *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep + *
    • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a + * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To keep * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` * in an user-defined schema. If a schema does not have the field, it drops corrupt records - * during parsing. When a length of parsed CSV tokens is shorter than an expected length - * of a schema, it sets `null` for extra fields.
    • + * during parsing. A record with less/more tokens than schema is not a corrupted record to + * CSV. When it meets a record having fewer tokens than the length of the schema, sets + * `null` to extra fields. When the record has more tokens than the length of the schema, + * it drops extra tokens. *
    • `DROPMALFORMED` : ignores the whole corrupted records.
    • *
    • `FAILFAST` : throws an exception when it meets corrupted records.
    • *
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 7d6d7e7eef926..3d6cc30f2ba83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -203,6 +203,8 @@ class UnivocityParser( case _: BadRecordException => None } } + // For records with less or more tokens than the schema, tries to return partial results + // if possible. throw BadRecordException( () => getCurrentInput, () => getPartialResult(), @@ -218,6 +220,9 @@ class UnivocityParser( row } catch { case NonFatal(e) => + // For corrupted records with the number of tokens same as the schema, + // CSV reader doesn't support partial results. All fields other than the field + // configured by `columnNameOfCorruptRecord` are set to `null`. throw BadRecordException(() => getCurrentInput, () => None, e) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index f23851655350a..61e22fac854f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -236,12 +236,12 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records * during parsing. *
      - *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep - * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` - * in an user-defined schema. If a schema does not have the field, it drops corrupt records - * during parsing. When inferring a schema, it implicitly adds a `columnNameOfCorruptRecord` - * field in an output schema.
    • + *
    • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a + * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To + * keep corrupt records, an user can set a string type field named + * `columnNameOfCorruptRecord` in an user-defined schema. If a schema does not have the + * field, it drops corrupt records during parsing. When inferring a schema, it implicitly + * adds a `columnNameOfCorruptRecord` field in an output schema.
    • *
    • `DROPMALFORMED` : ignores the whole corrupted records.
    • *
    • `FAILFAST` : throws an exception when it meets corrupted records.
    • *
    @@ -316,12 +316,14 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records * during parsing. It supports the following case-insensitive modes. *
      - *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep + *
    • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a + * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To keep * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` * in an user-defined schema. If a schema does not have the field, it drops corrupt records - * during parsing. When a length of parsed CSV tokens is shorter than an expected length - * of a schema, it sets `null` for extra fields.
    • + * during parsing. A record with less/more tokens than schema is not a corrupted record to + * CSV. When it meets a record having fewer tokens than the length of the schema, sets + * `null` to extra fields. When the record has more tokens than the length of the schema, + * it drops extra tokens. *
    • `DROPMALFORMED` : ignores the whole corrupted records.
    • *
    • `FAILFAST` : throws an exception when it meets corrupted records.
    • *
    From 6a8abe29ef3369b387d9bc2ee3459a6611246ab1 Mon Sep 17 00:00:00 2001 From: zhoukang Date: Wed, 28 Feb 2018 23:16:29 +0800 Subject: [PATCH 0404/2461] [SPARK-23508][CORE] Fix BlockmanagerId in case blockManagerIdCache cause oom MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … cause oom ## What changes were proposed in this pull request? blockManagerIdCache in BlockManagerId will not remove old values which may cause oom `val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()` Since whenever we apply a new BlockManagerId, it will put into this map. This patch will use guava cahce for blockManagerIdCache instead. A heap dump show in [SPARK-23508](https://issues.apache.org/jira/browse/SPARK-23508) ## How was this patch tested? Exist tests. Author: zhoukang Closes #20667 from caneGuy/zhoukang/fix-history. --- .../org/apache/spark/storage/BlockManagerId.scala | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index 2c3da0ee85e06..d4a59c33b974c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -18,7 +18,8 @@ package org.apache.spark.storage import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} -import java.util.concurrent.ConcurrentHashMap + +import com.google.common.cache.{CacheBuilder, CacheLoader} import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi @@ -132,10 +133,17 @@ private[spark] object BlockManagerId { getCachedBlockManagerId(obj) } - val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() + /** + * The max cache size is hardcoded to 10000, since the size of a BlockManagerId + * object is about 48B, the total memory cost should be below 1MB which is feasible. + */ + val blockManagerIdCache = CacheBuilder.newBuilder() + .maximumSize(10000) + .build(new CacheLoader[BlockManagerId, BlockManagerId]() { + override def load(id: BlockManagerId) = id + }) def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { - blockManagerIdCache.putIfAbsent(id, id) blockManagerIdCache.get(id) } } From fab563b9bd1581112462c0fc0b299ad6510b6564 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 1 Mar 2018 00:44:13 +0900 Subject: [PATCH 0405/2461] [SPARK-23517][PYTHON] Make `pyspark.util._exception_message` produce the trace from Java side by Py4JJavaError ## What changes were proposed in this pull request? This PR proposes for `pyspark.util._exception_message` to produce the trace from Java side by `Py4JJavaError`. Currently, in Python 2, it uses `message` attribute which `Py4JJavaError` didn't happen to have: ```python >>> from pyspark.util import _exception_message >>> try: ... sc._jvm.java.lang.String(None) ... except Exception as e: ... pass ... >>> e.message '' ``` Seems we should use `str` instead for now: https://github.com/bartdag/py4j/blob/aa6c53b59027925a426eb09b58c453de02c21b7c/py4j-python/src/py4j/protocol.py#L412 but this doesn't address the problem with non-ascii string from Java side - `https://github.com/bartdag/py4j/issues/306` So, we could directly call `__str__()`: ```python >>> e.__str__() u'An error occurred while calling None.java.lang.String.\n: java.lang.NullPointerException\n\tat java.lang.String.(String.java:588)\n\tat sun.reflect.NativeConstructorAccessorImpl.newInstance0(Native Method)\n\tat sun.reflect.NativeConstructorAccessorImpl.newInstance(NativeConstructorAccessorImpl.java:62)\n\tat sun.reflect.DelegatingConstructorAccessorImpl.newInstance(DelegatingConstructorAccessorImpl.java:45)\n\tat java.lang.reflect.Constructor.newInstance(Constructor.java:422)\n\tat py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:247)\n\tat py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\n\tat py4j.Gateway.invoke(Gateway.java:238)\n\tat py4j.commands.ConstructorCommand.invokeConstructor(ConstructorCommand.java:80)\n\tat py4j.commands.ConstructorCommand.execute(ConstructorCommand.java:69)\n\tat py4j.GatewayConnection.run(GatewayConnection.java:214)\n\tat java.lang.Thread.run(Thread.java:745)\n' ``` which doesn't type coerce unicodes to `str` in Python 2. This can be actually a problem: ```python from pyspark.sql.functions import udf spark.conf.set("spark.sql.execution.arrow.enabled", True) spark.range(1).select(udf(lambda x: [[]])()).toPandas() ``` **Before** ``` Traceback (most recent call last): File "", line 1, in File "/.../spark/python/pyspark/sql/dataframe.py", line 2009, in toPandas raise RuntimeError("%s\n%s" % (_exception_message(e), msg)) RuntimeError: Note: toPandas attempted Arrow optimization because 'spark.sql.execution.arrow.enabled' is set to true. Please set it to false to disable this. ``` **After** ``` Traceback (most recent call last): File "", line 1, in File "/.../spark/python/pyspark/sql/dataframe.py", line 2009, in toPandas raise RuntimeError("%s\n%s" % (_exception_message(e), msg)) RuntimeError: An error occurred while calling o47.collectAsArrowToPython. : org.apache.spark.SparkException: Job aborted due to stage failure: Task 7 in stage 0.0 failed 1 times, most recent failure: Lost task 7.0 in stage 0.0 (TID 7, localhost, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last): File "/.../spark/python/pyspark/worker.py", line 245, in main process() File "/.../spark/python/pyspark/worker.py", line 240, in process ... Note: toPandas attempted Arrow optimization because 'spark.sql.execution.arrow.enabled' is set to true. Please set it to false to disable this. ``` ## How was this patch tested? Manually tested and unit tests were added. Author: hyukjinkwon Closes #20680 from HyukjinKwon/SPARK-23517. --- python/pyspark/tests.py | 11 +++++++++++ python/pyspark/util.py | 7 +++++++ 2 files changed, 18 insertions(+) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 511585763cb01..9111dbbed5929 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2293,6 +2293,17 @@ def set(self, x=None, other=None, other_x=None): self.assertEqual(b._x, 2) +class UtilTests(PySparkTestCase): + def test_py4j_exception_message(self): + from pyspark.util import _exception_message + + with self.assertRaises(Py4JJavaError) as context: + # This attempts java.lang.String(null) which throws an NPE. + self.sc._jvm.java.lang.String(None) + + self.assertTrue('NullPointerException' in _exception_message(context.exception)) + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index e5d332ce54429..ad4a0bc68ef41 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from py4j.protocol import Py4JJavaError __all__ = [] @@ -33,6 +34,12 @@ def _exception_message(excp): >>> msg == _exception_message(excp) True """ + if isinstance(excp, Py4JJavaError): + # 'Py4JJavaError' doesn't contain the stack trace available on the Java side in 'message' + # attribute in Python 2. We should call 'str' function on this exception in general but + # 'Py4JJavaError' has an issue about addressing non-ascii strings. So, here we work + # around by the direct call, '__str__()'. Please see SPARK-23517. + return excp.__str__() if hasattr(excp, "message"): return excp.message return str(excp) From 476a7f026bc45462067ebd39cd269147e84cd641 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Wed, 28 Feb 2018 08:44:53 -0800 Subject: [PATCH 0406/2461] [SPARK-23514] Use SessionState.newHadoopConf() to propage hadoop configs set in SQLConf. ## What changes were proposed in this pull request? A few places in `spark-sql` were using `sc.hadoopConfiguration` directly. They should be using `sessionState.newHadoopConf()` to blend in configs that were set through `SQLConf`. Also, for better UX, for these configs blended in from `SQLConf`, we should consider removing the `spark.hadoop` prefix, so that the settings are recognized whether or not they were specified by the user. ## How was this patch tested? Tested that AlterTableRecoverPartitions now correctly recognizes settings that are passed in to the FileSystem through SQLConf. Author: Juliusz Sompolski Closes #20679 from juliuszsompolski/SPARK-23514. --- .../scala/org/apache/spark/sql/execution/command/ddl.scala | 6 +++--- .../scala/org/apache/spark/sql/hive/test/TestHive.scala | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 0142f17ce62e2..964cbca049b27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -610,10 +610,10 @@ case class AlterTableRecoverPartitionsCommand( val root = new Path(table.location) logInfo(s"Recover all the partitions in $root") - val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) + val hadoopConf = spark.sessionState.newHadoopConf() + val fs = root.getFileSystem(hadoopConf) val threshold = spark.conf.get("spark.rdd.parallelListingThreshold", "10").toInt - val hadoopConf = spark.sparkContext.hadoopConfiguration val pathFilter = getPathFilter(hadoopConf) val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8) @@ -697,7 +697,7 @@ case class AlterTableRecoverPartitionsCommand( pathFilter: PathFilter, threshold: Int): GenMap[String, PartitionStatistics] = { if (partitionSpecsAndLocs.length > threshold) { - val hadoopConf = spark.sparkContext.hadoopConfiguration + val hadoopConf = spark.sessionState.newHadoopConf() val serializableConfiguration = new SerializableConfiguration(hadoopConf) val serializedPaths = partitionSpecsAndLocs.map(_._2.toString).toArray diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 19028939f3673..fcf2025d34432 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -518,8 +518,9 @@ private[hive] class TestHiveSparkSession( // an HDFS scratch dir: ${hive.exec.scratchdir}/ is created, with // ${hive.scratch.dir.permission}. To resolve the permission issue, the simplest way is to // delete it. Later, it will be re-created with the right permission. - val location = new Path(sc.hadoopConfiguration.get(ConfVars.SCRATCHDIR.varname)) - val fs = location.getFileSystem(sc.hadoopConfiguration) + val hadoopConf = sessionState.newHadoopConf() + val location = new Path(hadoopConf.get(ConfVars.SCRATCHDIR.varname)) + val fs = location.getFileSystem(hadoopConf) fs.delete(location, true) // Some tests corrupt this value on purpose, which breaks the RESET call below. From 25c2776dd9ae3f9792048c78be2cbd958fd99841 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Wed, 28 Feb 2018 12:16:26 -0800 Subject: [PATCH 0407/2461] [SPARK-23523][SQL][FOLLOWUP] Minor refactor of OptimizeMetadataOnlyQuery ## What changes were proposed in this pull request? Inside `OptimizeMetadataOnlyQuery.getPartitionAttrs`, avoid using `zip` to generate attribute map. Also include other minor update of comments and format. ## How was this patch tested? Existing test cases. Author: Xingbo Jiang Closes #20693 from jiangxb1987/SPARK-23523. --- .../spark/sql/execution/OptimizeMetadataOnlyQuery.scala | 2 +- .../spark/sql/execution/datasources/HadoopFsRelation.scala | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index 0613d9053f826..dc4aff9f12580 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -83,7 +83,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic private def getPartitionAttrs( partitionColumnNames: Seq[String], relation: LogicalPlan): Seq[Attribute] = { - val attrMap = relation.output.map(_.name.toLowerCase(Locale.ROOT)).zip(relation.output).toMap + val attrMap = relation.output.map(a => a.name.toLowerCase(Locale.ROOT) -> a).toMap partitionColumnNames.map { colName => attrMap.getOrElse(colName.toLowerCase(Locale.ROOT), throw new AnalysisException(s"Unable to find the column `$colName` " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index ac574b07ec497..b2f73b7f8d1fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -67,9 +67,9 @@ case class HadoopFsRelation( } } - // When data schema and partition schema have the overlapped columns, the output - // schema respects the order of data schema for the overlapped columns, but respect - // the data types of partition schema + // When data and partition schemas have overlapping columns, the output + // schema respects the order of the data schema for the overlapping columns, and it + // respects the data types of the partition schema. val schema: StructType = { StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++ partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f)))) From 22f3d3334c85c042c6e90f5a02f308d7cd1c1498 Mon Sep 17 00:00:00 2001 From: liuxian Date: Thu, 1 Mar 2018 14:28:28 +0800 Subject: [PATCH 0408/2461] [SPARK-23389][CORE] When the shuffle dependency specifies aggregation ,and `dependency.mapSideCombine =false`, we should be able to use serialized sorting. ## What changes were proposed in this pull request? When the shuffle dependency specifies aggregation ,and `dependency.mapSideCombine=false`, in the map side,there is no need for aggregation and sorting, so we should be able to use serialized sorting. ## How was this patch tested? Existing unit test Author: liuxian Closes #20576 from 10110346/mapsidecombine. --- .../scala/org/apache/spark/Dependency.scala | 3 +++ .../spark/shuffle/BlockStoreShuffleReader.scala | 1 - .../spark/shuffle/sort/SortShuffleManager.scala | 6 +++--- .../spark/shuffle/sort/SortShuffleWriter.scala | 2 -- .../shuffle/sort/SortShuffleManagerSuite.scala | 17 +++++++++-------- 5 files changed, 15 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index ca52ecafa2cc8..9ea6d2fa2fd95 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -76,6 +76,9 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val mapSideCombine: Boolean = false) extends Dependency[Product2[K, V]] { + if (mapSideCombine) { + require(aggregator.isDefined, "Map-side combine without Aggregator specified!") + } override def rdd: RDD[Product2[K, V]] = _rdd.asInstanceOf[RDD[Product2[K, V]]] private[spark] val keyClassName: String = reflect.classTag[K].runtimeClass.getName diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 0562d45ff57c5..edd69715c9602 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -90,7 +90,6 @@ private[spark] class BlockStoreShuffleReader[K, C]( dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) } } else { - require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index bfb4dc698e325..d9fad64f34c7c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -188,9 +188,9 @@ private[spark] object SortShuffleManager extends Logging { log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " + s"${dependency.serializer.getClass.getName}, does not support object relocation") false - } else if (dependency.aggregator.isDefined) { - log.debug( - s"Can't use serialized shuffle for shuffle $shufId because an aggregator is defined") + } else if (dependency.mapSideCombine) { + log.debug(s"Can't use serialized shuffle for shuffle $shufId because we need to do " + + s"map-side aggregation") false } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) { log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " + diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 636b88e792bf3..274399b9cc1f3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -50,7 +50,6 @@ private[spark] class SortShuffleWriter[K, V, C]( /** Write a bunch of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { sorter = if (dep.mapSideCombine) { - require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") new ExternalSorter[K, V, C]( context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) } else { @@ -107,7 +106,6 @@ private[spark] object SortShuffleWriter { def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { // We cannot bypass sorting if we need to do map-side aggregation. if (dep.mapSideCombine) { - require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") false } else { val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala index 55cebe7c8b6a8..f29dac965c803 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala @@ -85,6 +85,14 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { mapSideCombine = false ))) + // We support serialized shuffle if we do not need to do map-side aggregation + assert(canUseSerializedShuffle(shuffleDep( + partitioner = new HashPartitioner(2), + serializer = kryo, + keyOrdering = None, + aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), + mapSideCombine = false + ))) } test("unsupported shuffle dependencies for serialized shuffle") { @@ -111,14 +119,7 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { mapSideCombine = false ))) - // We do not support shuffles that perform aggregation - assert(!canUseSerializedShuffle(shuffleDep( - partitioner = new HashPartitioner(2), - serializer = kryo, - keyOrdering = None, - aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), - mapSideCombine = false - ))) + // We do not support serialized shuffle if we need to do map-side aggregation assert(!canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, From ff1480189b827af0be38605d566a4ee71b4c36f6 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 1 Mar 2018 16:26:11 +0800 Subject: [PATCH 0409/2461] [SPARK-23510][SQL] Support Hive 2.2 and Hive 2.3 metastore ## What changes were proposed in this pull request? This is based on https://github.com/apache/spark/pull/20668 for supporting Hive 2.2 and Hive 2.3 metastore. When we merge the PR, we should give the major credit to wangyum ## How was this patch tested? Added the test cases Author: Yuming Wang Author: gatorsmile Closes #20671 from gatorsmile/pr-20668. --- .../org/apache/spark/sql/hive/HiveUtils.scala | 2 +- .../sql/hive/client/HiveClientImpl.scala | 3 +- .../spark/sql/hive/client/HiveShim.scala | 8 +-- .../hive/client/IsolatedClientLoader.scala | 2 + .../spark/sql/hive/client/package.scala | 10 +++- .../sql/hive/execution/SaveAsHiveFile.scala | 3 +- .../sql/hive/client/HiveClientVersions.scala | 3 +- .../sql/hive/client/HiveVersionSuite.scala | 2 +- .../spark/sql/hive/client/VersionsSuite.scala | 51 +++++++++++++++++-- 9 files changed, 72 insertions(+), 12 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index c448c5a9821be..10c9603745379 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -62,7 +62,7 @@ private[spark] object HiveUtils extends Logging { val HIVE_METASTORE_VERSION = buildConf("spark.sql.hive.metastore.version") .doc("Version of the Hive metastore. Available options are " + - s"0.12.0 through 2.1.1.") + s"0.12.0 through 2.3.2.") .stringConf .createWithDefault(builtinHiveVersion) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 146fa54a1bce4..da9fe2d3088b4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -25,7 +25,6 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.conf.HiveConf @@ -104,6 +103,8 @@ private[hive] class HiveClientImpl( case hive.v1_2 => new Shim_v1_2() case hive.v2_0 => new Shim_v2_0() case hive.v2_1 => new Shim_v2_1() + case hive.v2_2 => new Shim_v2_2() + case hive.v2_3 => new Shim_v2_3() } // Create an internal session state for this HiveClientImpl. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 1eac70dbf19cd..948ba542b5733 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -880,9 +880,7 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { } -private[client] class Shim_v1_0 extends Shim_v0_14 { - -} +private[client] class Shim_v1_0 extends Shim_v0_14 private[client] class Shim_v1_1 extends Shim_v1_0 { @@ -1146,3 +1144,7 @@ private[client] class Shim_v2_1 extends Shim_v2_0 { alterPartitionsMethod.invoke(hive, tableName, newParts, environmentContextInAlterTable) } } + +private[client] class Shim_v2_2 extends Shim_v2_1 + +private[client] class Shim_v2_3 extends Shim_v2_1 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index dac0e333b63bc..12975bc85b971 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -97,6 +97,8 @@ private[hive] object IsolatedClientLoader extends Logging { case "1.2" | "1.2.0" | "1.2.1" | "1.2.2" => hive.v1_2 case "2.0" | "2.0.0" | "2.0.1" => hive.v2_0 case "2.1" | "2.1.0" | "2.1.1" => hive.v2_1 + case "2.2" | "2.2.0" => hive.v2_2 + case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" => hive.v2_3 } private def downloadVersion( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index c14154a3b3c21..681ee9200f02b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -71,7 +71,15 @@ package object client { exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) - val allSupportedHiveVersions = Set(v12, v13, v14, v1_0, v1_1, v1_2, v2_0, v2_1) + case object v2_2 extends HiveVersion("2.2.0", + exclusions = Seq("org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm")) + + case object v2_3 extends HiveVersion("2.3.2", + exclusions = Seq("org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm")) + + val allSupportedHiveVersions = Set(v12, v13, v14, v1_0, v1_1, v1_2, v2_0, v2_1, v2_2, v2_3) } // scalastyle:on diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index e484356906e87..6a7b25b36d9a5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -114,7 +114,8 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { // staging directory under the table director for Hive prior to 1.1, the staging directory will // be removed by Hive when Hive is trying to empty the table directory. val hiveVersionsUsingOldExternalTempPath: Set[HiveVersion] = Set(v12, v13, v14, v1_0) - val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = Set(v1_1, v1_2, v2_0, v2_1) + val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = + Set(v1_1, v1_2, v2_0, v2_1, v2_2, v2_3) // Ensure all the supported versions are considered here. assert(hiveVersionsUsingNewExternalTempPath ++ hiveVersionsUsingOldExternalTempPath == diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala index 2e7dfde8b2fa5..30592a3f85428 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala @@ -22,5 +22,6 @@ import scala.collection.immutable.IndexedSeq import org.apache.spark.SparkFunSuite private[client] trait HiveClientVersions { - protected val versions = IndexedSeq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1") + protected val versions = + IndexedSeq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1", "2.2", "2.3") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala index a70fb6464cc1d..e5963d03f6b52 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala @@ -34,7 +34,7 @@ private[client] abstract class HiveVersionSuite(version: String) extends SparkFu // Hive changed the default of datanucleus.schema.autoCreateAll from true to false and // hive.metastore.schema.verification from false to true since 2.0 // For details, see the JIRA HIVE-6113 and HIVE-12463 - if (version == "2.0" || version == "2.1") { + if (version == "2.0" || version == "2.1" || version == "2.2" || version == "2.3") { hadoopConf.set("datanucleus.schema.autoCreateAll", "true") hadoopConf.set("hive.metastore.schema.verification", "false") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 72536b833481a..6176273c88db1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -21,6 +21,7 @@ import java.io.{ByteArrayOutputStream, File, PrintStream, PrintWriter} import java.net.URI import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.mapred.TextInputFormat @@ -110,7 +111,8 @@ class VersionsSuite extends SparkFunSuite with Logging { assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") } - private val versions = Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1") + private val versions = + Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1", "2.2", "2.3") private var client: HiveClient = null @@ -125,7 +127,7 @@ class VersionsSuite extends SparkFunSuite with Logging { // Hive changed the default of datanucleus.schema.autoCreateAll from true to false and // hive.metastore.schema.verification from false to true since 2.0 // For details, see the JIRA HIVE-6113 and HIVE-12463 - if (version == "2.0" || version == "2.1") { + if (version == "2.0" || version == "2.1" || version == "2.2" || version == "2.3") { hadoopConf.set("datanucleus.schema.autoCreateAll", "true") hadoopConf.set("hive.metastore.schema.verification", "false") } @@ -422,15 +424,18 @@ class VersionsSuite extends SparkFunSuite with Logging { test(s"$version: alterPartitions") { val spec = Map("key1" -> "1", "key2" -> "2") + val parameters = Map(StatsSetupConst.TOTAL_SIZE -> "0", StatsSetupConst.NUM_FILES -> "1") val newLocation = new URI(Utils.createTempDir().toURI.toString.stripSuffix("/")) val storage = storageFormat.copy( locationUri = Some(newLocation), // needed for 0.12 alter partitions serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) - val partition = CatalogTablePartition(spec, storage) + val partition = CatalogTablePartition(spec, storage, parameters) client.alterPartitions("default", "src_part", Seq(partition)) assert(client.getPartition("default", "src_part", spec) .storage.locationUri == Some(newLocation)) + assert(client.getPartition("default", "src_part", spec) + .parameters.get(StatsSetupConst.TOTAL_SIZE) == Some("0")) } test(s"$version: dropPartitions") { @@ -633,6 +638,46 @@ class VersionsSuite extends SparkFunSuite with Logging { } } + test(s"$version: CREATE Partitioned TABLE AS SELECT") { + withTable("tbl") { + versionSpark.sql( + """ + |CREATE TABLE tbl(c1 string) + |PARTITIONED BY (ds STRING) + """.stripMargin) + versionSpark.sql("INSERT OVERWRITE TABLE tbl partition (ds='2') SELECT '1'") + + assert(versionSpark.table("tbl").collect().toSeq == Seq(Row("1", "2"))) + val partMeta = versionSpark.sessionState.catalog.getPartition( + TableIdentifier("tbl"), spec = Map("ds" -> "2")).parameters + val totalSize = partMeta.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) + val numFiles = partMeta.get(StatsSetupConst.NUM_FILES).map(_.toLong) + // Except 0.12, all the following versions will fill the Hive-generated statistics + if (version == "0.12") { + assert(totalSize.isEmpty && numFiles.isEmpty) + } else { + assert(totalSize.nonEmpty && numFiles.nonEmpty) + } + + versionSpark.sql( + """ + |ALTER TABLE tbl PARTITION (ds='2') + |SET SERDEPROPERTIES ('newKey' = 'vvv') + """.stripMargin) + val newPartMeta = versionSpark.sessionState.catalog.getPartition( + TableIdentifier("tbl"), spec = Map("ds" -> "2")).parameters + + val newTotalSize = newPartMeta.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) + val newNumFiles = newPartMeta.get(StatsSetupConst.NUM_FILES).map(_.toLong) + // Except 0.12, all the following versions will fill the Hive-generated statistics + if (version == "0.12") { + assert(newTotalSize.isEmpty && newNumFiles.isEmpty) + } else { + assert(newTotalSize.nonEmpty && newNumFiles.nonEmpty) + } + } + } + test(s"$version: Delete the temporary staging directory and files after each insert") { withTempDir { tmpDir => withTable("tab") { From cdcccd7b41c43d79edff2fec7a84cd00e9524f75 Mon Sep 17 00:00:00 2001 From: KaiXinXiaoLei <584620569@qq.com> Date: Fri, 2 Mar 2018 00:09:44 +0800 Subject: [PATCH 0410/2461] [SPARK-23405] Generate additional constraints for Join's children ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) I run a sql: `select ls.cs_order_number from ls left semi join catalog_sales cs on ls.cs_order_number = cs.cs_order_number`, The `ls` table is a small table ,and the number is one. The `catalog_sales` table is a big table, and the number is 10 billion. The task will be hang up. And i find the many null values of `cs_order_number` in the `catalog_sales` table. I think the null value should be removed in the logical plan. >== Optimized Logical Plan == >Join LeftSemi, (cs_order_number#1 = cs_order_number#22) >:- Project cs_order_number#1 > : +- Filter isnotnull(cs_order_number#1) > : +- MetastoreRelation 100t, ls >+- Project cs_order_number#22 > +- MetastoreRelation 100t, catalog_sales Now, use this patch, the plan will be: >== Optimized Logical Plan == >Join LeftSemi, (cs_order_number#1 = cs_order_number#22) >:- Project cs_order_number#1 > : +- Filter isnotnull(cs_order_number#1) > : +- MetastoreRelation 100t, ls >+- Project cs_order_number#22 > : **+- Filter isnotnull(cs_order_number#22)** > :+- MetastoreRelation 100t, catalog_sales ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: KaiXinXiaoLei <584620569@qq.com> Author: hanghang <584620569@qq.com> Closes #20670 from KaiXinXiaoLei/Spark-23405. --- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../plans/logical/QueryPlanConstraints.scala | 27 ++++++++++--------- .../InferFiltersFromConstraintsSuite.scala | 12 +++++++++ 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a28b6a0feb8f9..91208479be03b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -661,7 +661,7 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe case join @ Join(left, right, joinType, conditionOpt) => // Only consider constraints that can be pushed down completely to either the left or the // right child - val constraints = join.constraints.filter { c => + val constraints = join.allConstraints.filter { c => c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet) } // Remove those constraints that are already enforced by either the left or the right child diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 5c7b8e5b97883..046848875548b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -23,25 +23,28 @@ import org.apache.spark.sql.catalyst.expressions._ trait QueryPlanConstraints { self: LogicalPlan => /** - * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For - * example, if this set contains the expression `a = 2` then that expression is guaranteed to - * evaluate to `true` for all rows produced. + * An [[ExpressionSet]] that contains an additional set of constraints, such as equality + * constraints and `isNotNull` constraints, etc. */ - lazy val constraints: ExpressionSet = { + lazy val allConstraints: ExpressionSet = { if (conf.constraintPropagationEnabled) { - ExpressionSet( - validConstraints - .union(inferAdditionalConstraints(validConstraints)) - .union(constructIsNotNullConstraints(validConstraints)) - .filter { c => - c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic - } - ) + ExpressionSet(validConstraints + .union(inferAdditionalConstraints(validConstraints)) + .union(constructIsNotNullConstraints(validConstraints))) } else { ExpressionSet(Set.empty) } } + /** + * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For + * example, if this set contains the expression `a = 2` then that expression is guaranteed to + * evaluate to `true` for all rows produced. + */ + lazy val constraints: ExpressionSet = ExpressionSet(allConstraints.filter { c => + c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic + }) + /** * This method can be overridden by any child class of QueryPlan to specify a set of constraints * based on the given operator's constraint propagation logic. These constraints are then diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 178c4b8c270a0..f78c2356e35a5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -192,4 +192,16 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(Optimize.execute(original.analyze), correct.analyze) } + + test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y, LeftSemi, condition).analyze + val left = x.where(IsNotNull('a)) + val right = y.where(IsNotNull('a)) + val correctAnswer = left.join(right, LeftSemi, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } } From 34811e0b908449fd59bca476604612b1d200778d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 1 Mar 2018 17:26:39 -0800 Subject: [PATCH 0411/2461] [SPARK-23551][BUILD] Exclude `hadoop-mapreduce-client-core` dependency from `orc-mapreduce` ## What changes were proposed in this pull request? This PR aims to prevent `orc-mapreduce` dependency from making IDEs and maven confused. **BEFORE** Please note that `2.6.4` at `Spark Project SQL`. ``` $ mvn dependency:tree -Phadoop-2.7 -Dincludes=org.apache.hadoop:hadoop-mapreduce-client-core ... [INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project Catalyst 2.4.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ [INFO] [INFO] --- maven-dependency-plugin:3.0.2:tree (default-cli) spark-catalyst_2.11 --- [INFO] org.apache.spark:spark-catalyst_2.11:jar:2.4.0-SNAPSHOT [INFO] \- org.apache.spark:spark-core_2.11:jar:2.4.0-SNAPSHOT:compile [INFO] \- org.apache.hadoop:hadoop-client:jar:2.7.3:compile [INFO] \- org.apache.hadoop:hadoop-mapreduce-client-core:jar:2.7.3:compile [INFO] [INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project SQL 2.4.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ [INFO] [INFO] --- maven-dependency-plugin:3.0.2:tree (default-cli) spark-sql_2.11 --- [INFO] org.apache.spark:spark-sql_2.11:jar:2.4.0-SNAPSHOT [INFO] \- org.apache.orc:orc-mapreduce:jar:nohive:1.4.3:compile [INFO] \- org.apache.hadoop:hadoop-mapreduce-client-core:jar:2.6.4:compile ``` **AFTER** ``` $ mvn dependency:tree -Phadoop-2.7 -Dincludes=org.apache.hadoop:hadoop-mapreduce-client-core ... [INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project Catalyst 2.4.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ [INFO] [INFO] --- maven-dependency-plugin:3.0.2:tree (default-cli) spark-catalyst_2.11 --- [INFO] org.apache.spark:spark-catalyst_2.11:jar:2.4.0-SNAPSHOT [INFO] \- org.apache.spark:spark-core_2.11:jar:2.4.0-SNAPSHOT:compile [INFO] \- org.apache.hadoop:hadoop-client:jar:2.7.3:compile [INFO] \- org.apache.hadoop:hadoop-mapreduce-client-core:jar:2.7.3:compile [INFO] [INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project SQL 2.4.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ [INFO] [INFO] --- maven-dependency-plugin:3.0.2:tree (default-cli) spark-sql_2.11 --- [INFO] org.apache.spark:spark-sql_2.11:jar:2.4.0-SNAPSHOT [INFO] \- org.apache.spark:spark-core_2.11:jar:2.4.0-SNAPSHOT:compile [INFO] \- org.apache.hadoop:hadoop-client:jar:2.7.3:compile [INFO] \- org.apache.hadoop:hadoop-mapreduce-client-core:jar:2.7.3:compile ``` ## How was this patch tested? 1. Pass the Jenkins with `dev/test-dependencies.sh` with the existing dependencies. 2. Manually do the following and see the change. ``` mvn dependency:tree -Phadoop-2.7 -Dincludes=org.apache.hadoop:hadoop-mapreduce-client-core ``` Author: Dongjoon Hyun Closes #20704 from dongjoon-hyun/SPARK-23551. --- pom.xml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pom.xml b/pom.xml index b8396166f6b1b..0a711f287a53f 100644 --- a/pom.xml +++ b/pom.xml @@ -1753,6 +1753,10 @@ org.apache.hadoop hadoop-common + + org.apache.hadoop + hadoop-mapreduce-client-core + org.apache.orc orc-core From 119f6a0e4729aa952e811d2047790a32ee90bf69 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 1 Mar 2018 21:04:01 -0800 Subject: [PATCH 0412/2461] [SPARK-22883][ML][TEST] Streaming tests for spark.ml.feature, from A to H ## What changes were proposed in this pull request? Adds structured streaming tests using testTransformer for these suites: * BinarizerSuite * BucketedRandomProjectionLSHSuite * BucketizerSuite * ChiSqSelectorSuite * CountVectorizerSuite * DCTSuite.scala * ElementwiseProductSuite * FeatureHasherSuite * HashingTFSuite ## How was this patch tested? It tests itself because it is a bunch of tests! Author: Joseph K. Bradley Closes #20111 from jkbradley/SPARK-22883-streaming-featureAM. --- .../spark/ml/feature/BinarizerSuite.scala | 8 ++-- .../BucketedRandomProjectionLSHSuite.scala | 26 ++++++++--- .../spark/ml/feature/BucketizerSuite.scala | 11 +++-- .../spark/ml/feature/ChiSqSelectorSuite.scala | 36 +++++++-------- .../ml/feature/CountVectorizerSuite.scala | 23 +++++----- .../apache/spark/ml/feature/DCTSuite.scala | 14 +++--- .../ml/feature/ElementwiseProductSuite.scala | 30 ++++++++++--- .../spark/ml/feature/FeatureHasherSuite.scala | 45 +++++++++---------- .../spark/ml/feature/HashingTFSuite.scala | 34 ++++++++------ 9 files changed, 126 insertions(+), 101 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 4455d35210878..05d4a6ee2dabf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -17,14 +17,12 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.{DataFrame, Row} -class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class BinarizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -47,7 +45,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau .setInputCol("feature") .setOutputCol("binarized_feature") - binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach { + testTransformer[(Double, Double)](dataFrame, binarizer, "binarized_feature", "expected") { case Row(x: Double, y: Double) => assert(x === y, "The feature value is not correct after binarization.") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala index 7175c721bff36..ed9a39d8d1512 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala @@ -20,16 +20,15 @@ package org.apache.spark.ml.feature import breeze.numerics.{cos, sin} import breeze.numerics.constants.Pi -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Dataset +import org.apache.spark.sql.{Dataset, Row} -class BucketedRandomProjectionLSHSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class BucketedRandomProjectionLSHSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ @transient var dataset: Dataset[_] = _ @@ -98,6 +97,21 @@ class BucketedRandomProjectionLSHSuite MLTestingUtils.checkCopyAndUids(brp, brpModel) } + test("BucketedRandomProjectionLSH: streaming transform") { + val brp = new BucketedRandomProjectionLSH() + .setNumHashTables(2) + .setInputCol("keys") + .setOutputCol("values") + .setBucketLength(1.0) + .setSeed(12345) + val brpModel = brp.fit(dataset) + + testTransformer[Tuple1[Vector]](dataset.toDF(), brpModel, "values") { + case Row(values: Seq[_]) => + assert(values.length === brp.getNumHashTables) + } + } + test("BucketedRandomProjectionLSH: test of LSH property") { // Project from 2 dimensional Euclidean Space to 1 dimensions val brp = new BucketedRandomProjectionLSH() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 41cf72fe3470a..9ea15e1918532 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -23,14 +23,13 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class BucketizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -50,7 +49,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCol("result") .setSplits(splits) - bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { + testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. Expected $y but found $x") @@ -84,7 +83,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCol("result") .setSplits(splits) - bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { + testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. Expected $y but found $x") @@ -103,7 +102,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setSplits(splits) bucketizer.setHandleInvalid("keep") - bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { + testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. Expected $y but found $x") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index c83909c4498f2..c843df9f33e3e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} -class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { +class ChiSqSelectorSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ @transient var dataset: Dataset[_] = _ @@ -119,32 +118,32 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext test("Test Chi-Square selector: numTopFeatures") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(1) - val model = ChiSqSelectorSuite.testSelector(selector, dataset) + val model = testSelector(selector, dataset) MLTestingUtils.checkCopyAndUids(selector, model) } test("Test Chi-Square selector: percentile") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.17) - ChiSqSelectorSuite.testSelector(selector, dataset) + testSelector(selector, dataset) } test("Test Chi-Square selector: fpr") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("fpr").setFpr(0.02) - ChiSqSelectorSuite.testSelector(selector, dataset) + testSelector(selector, dataset) } test("Test Chi-Square selector: fdr") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("fdr").setFdr(0.12) - ChiSqSelectorSuite.testSelector(selector, dataset) + testSelector(selector, dataset) } test("Test Chi-Square selector: fwe") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("fwe").setFwe(0.12) - ChiSqSelectorSuite.testSelector(selector, dataset) + testSelector(selector, dataset) } test("read/write") { @@ -163,18 +162,19 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext assert(expected.selectedFeatures === actual.selectedFeatures) } } -} -object ChiSqSelectorSuite { - - private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): ChiSqSelectorModel = { - val selectorModel = selector.fit(dataset) - selectorModel.transform(dataset).select("filtered", "topFeature").collect() - .foreach { case Row(vec1: Vector, vec2: Vector) => + private def testSelector(selector: ChiSqSelector, data: Dataset[_]): ChiSqSelectorModel = { + val selectorModel = selector.fit(data) + testTransformer[(Double, Vector, Vector)](data.toDF(), selectorModel, + "filtered", "topFeature") { + case Row(vec1: Vector, vec2: Vector) => assert(vec1 ~== vec2 absTol 1e-1) - } + } selectorModel } +} + +object ChiSqSelectorSuite { /** * Mapping from all Params to valid settings which differ from the defaults. diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index 1784c07ca23e3..61217669d9277 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -16,16 +16,13 @@ */ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row -class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { +class CountVectorizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -50,7 +47,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) .setInputCol("words") .setOutputCol("features") - cv.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -72,7 +69,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext MLTestingUtils.checkCopyAndUids(cv, cvm) assert(cvm.vocabulary.toSet === Set("a", "b", "c", "d", "e")) - cvm.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cvm, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -100,7 +97,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .fit(df) assert(cvModel2.vocabulary === Array("a", "b")) - cvModel2.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cvModel2, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -113,7 +110,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .fit(df) assert(cvModel3.vocabulary === Array("a", "b")) - cvModel3.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cvModel3, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -219,7 +216,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .setInputCol("words") .setOutputCol("features") .setMinTF(3) - cv.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -238,7 +235,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .setInputCol("words") .setOutputCol("features") .setMinTF(0.3) - cv.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -258,7 +255,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .setOutputCol("features") .setBinary(true) .fit(df) - cv.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -268,7 +265,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .setInputCol("words") .setOutputCol("features") .setBinary(true) - cv2.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv2, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala index 8dd3dd75e1be5..6734336aac39c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala @@ -21,16 +21,14 @@ import scala.beans.BeanInfo import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.Row @BeanInfo case class DCTTestData(vec: Vector, wantedVec: Vector) -class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class DCTSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -72,11 +70,9 @@ class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead .setOutputCol("resultVec") .setInverse(inverse) - transformer.transform(dataset) - .select("resultVec", "wantedVec") - .collect() - .foreach { case Row(resultVec: Vector, wantedVec: Vector) => - assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6) + testTransformer[(Vector, Vector)](dataset, transformer, "resultVec", "wantedVec") { + case Row(resultVec: Vector, wantedVec: Vector) => + assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala index a4cca27be7815..3a8d0762e2ab7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala @@ -17,13 +17,31 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.linalg.Vectors -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.sql.Row -class ElementwiseProductSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class ElementwiseProductSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ + + test("streaming transform") { + val scalingVec = Vectors.dense(0.1, 10.0) + val data = Seq( + (Vectors.dense(0.1, 1.0), Vectors.dense(0.01, 10.0)), + (Vectors.dense(0.0, -1.1), Vectors.dense(0.0, -11.0)) + ) + val df = spark.createDataFrame(data).toDF("features", "expected") + val ep = new ElementwiseProduct() + .setInputCol("features") + .setOutputCol("actual") + .setScalingVec(scalingVec) + testTransformer[(Vector, Vector)](df, ep, "actual", "expected") { + case Row(actual: Vector, expected: Vector) => + assert(actual ~== expected relTol 1e-14) + } + } test("read/write") { val ep = new ElementwiseProduct() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala index 7bc1825b69c43..d799ba6011fa8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala @@ -17,27 +17,24 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class FeatureHasherSuite extends SparkFunSuite - with MLlibTestSparkContext - with DefaultReadWriteTest { +class FeatureHasherSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ import FeatureHasherSuite.murmur3FeatureIdx - implicit private val vectorEncoder = ExpressionEncoder[Vector]() + implicit private val vectorEncoder: ExpressionEncoder[Vector] = ExpressionEncoder[Vector]() test("params") { ParamsSuite.checkParams(new FeatureHasher) @@ -52,31 +49,31 @@ class FeatureHasherSuite extends SparkFunSuite } test("feature hashing") { + val numFeatures = 100 + // Assume perfect hash on field names in computing expected results + def idx: Any => Int = murmur3FeatureIdx(numFeatures) + val df = Seq( - (2.0, true, "1", "foo"), - (3.0, false, "2", "bar") - ).toDF("real", "bool", "stringNum", "string") + (2.0, true, "1", "foo", + Vectors.sparse(numFeatures, Seq((idx("real"), 2.0), (idx("bool=true"), 1.0), + (idx("stringNum=1"), 1.0), (idx("string=foo"), 1.0)))), + (3.0, false, "2", "bar", + Vectors.sparse(numFeatures, Seq((idx("real"), 3.0), (idx("bool=false"), 1.0), + (idx("stringNum=2"), 1.0), (idx("string=bar"), 1.0)))) + ).toDF("real", "bool", "stringNum", "string", "expected") - val n = 100 val hasher = new FeatureHasher() .setInputCols("real", "bool", "stringNum", "string") .setOutputCol("features") - .setNumFeatures(n) + .setNumFeatures(numFeatures) val output = hasher.transform(df) val attrGroup = AttributeGroup.fromStructField(output.schema("features")) - assert(attrGroup.numAttributes === Some(n)) + assert(attrGroup.numAttributes === Some(numFeatures)) - val features = output.select("features").as[Vector].collect() - // Assume perfect hash on field names - def idx: Any => Int = murmur3FeatureIdx(n) - // check expected indices - val expected = Seq( - Vectors.sparse(n, Seq((idx("real"), 2.0), (idx("bool=true"), 1.0), - (idx("stringNum=1"), 1.0), (idx("string=foo"), 1.0))), - Vectors.sparse(n, Seq((idx("real"), 3.0), (idx("bool=false"), 1.0), - (idx("stringNum=2"), 1.0), (idx("string=bar"), 1.0))) - ) - assert(features.zip(expected).forall { case (e, a) => e ~== a absTol 1e-14 }) + testTransformer[(Double, Boolean, String, String, Vector)](df, hasher, "features", "expected") { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14 ) + } } test("setting explicit numerical columns to treat as categorical") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala index a46272fdce1fb..c5183ecfef7d7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -17,17 +17,16 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{HashingTF => MLlibHashingTF} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row import org.apache.spark.util.Utils -class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class HashingTFSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ import HashingTFSuite.murmur3FeatureIdx @@ -37,21 +36,28 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } test("hashingTF") { - val df = Seq((0, "a a b b c d".split(" ").toSeq)).toDF("id", "words") - val n = 100 + val numFeatures = 100 + // Assume perfect hash when computing expected features. + def idx: Any => Int = murmur3FeatureIdx(numFeatures) + val data = Seq( + ("a a b b c d".split(" ").toSeq, + Vectors.sparse(numFeatures, + Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))) + ) + + val df = data.toDF("words", "expected") val hashingTF = new HashingTF() .setInputCol("words") .setOutputCol("features") - .setNumFeatures(n) + .setNumFeatures(numFeatures) val output = hashingTF.transform(df) val attrGroup = AttributeGroup.fromStructField(output.schema("features")) - require(attrGroup.numAttributes === Some(n)) - val features = output.select("features").first().getAs[Vector](0) - // Assume perfect hash on "a", "b", "c", and "d". - def idx: Any => Int = murmur3FeatureIdx(n) - val expected = Vectors.sparse(n, - Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0))) - assert(features ~== expected absTol 1e-14) + require(attrGroup.numAttributes === Some(numFeatures)) + + testTransformer[(Seq[String], Vector)](df, hashingTF, "features", "expected") { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } } test("applying binary term freqs") { From 0b6ceadeb563205cbd6bd03bc88e608086273b5b Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Fri, 2 Mar 2018 09:23:39 -0800 Subject: [PATCH 0413/2461] [SPARKR][DOC] fix link in vignettes ## What changes were proposed in this pull request? Fix doc link that was changed in 2.3 shivaram Author: Felix Cheung Closes #20711 from felixcheung/rvigmean. --- R/pkg/vignettes/sparkr-vignettes.Rmd | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index feca617c2554c..d4713de7806a1 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -46,7 +46,7 @@ Sys.setenv("_JAVA_OPTIONS" = paste("-XX:-UsePerfData", old_java_opt, sep = " ")) ## Overview -SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. With Spark `r packageVersion("SparkR")`, SparkR provides a distributed data frame implementation that supports data processing operations like selection, filtering, aggregation etc. and distributed machine learning using [MLlib](http://spark.apache.org/mllib/). +SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. With Spark `r packageVersion("SparkR")`, SparkR provides a distributed data frame implementation that supports data processing operations like selection, filtering, aggregation etc. and distributed machine learning using [MLlib](https://spark.apache.org/mllib/). ## Getting Started @@ -132,7 +132,7 @@ sparkR.session.stop() Different from many other R packages, to use SparkR, you need an additional installation of Apache Spark. The Spark installation will be used to run a backend process that will compile and execute SparkR programs. -After installing the SparkR package, you can call `sparkR.session` as explained in the previous section to start and it will check for the Spark installation. If you are working with SparkR from an interactive shell (eg. R, RStudio) then Spark is downloaded and cached automatically if it is not found. Alternatively, we provide an easy-to-use function `install.spark` for running this manually. If you don't have Spark installed on the computer, you may download it from [Apache Spark Website](http://spark.apache.org/downloads.html). +After installing the SparkR package, you can call `sparkR.session` as explained in the previous section to start and it will check for the Spark installation. If you are working with SparkR from an interactive shell (eg. R, RStudio) then Spark is downloaded and cached automatically if it is not found. Alternatively, we provide an easy-to-use function `install.spark` for running this manually. If you don't have Spark installed on the computer, you may download it from [Apache Spark Website](https://spark.apache.org/downloads.html). ```{r, eval=FALSE} install.spark() @@ -147,7 +147,7 @@ sparkR.session(sparkHome = "/HOME/spark") ### Spark Session {#SetupSparkSession} -In addition to `sparkHome`, many other options can be specified in `sparkR.session`. For a complete list, see [Starting up: SparkSession](http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession) and [SparkR API doc](http://spark.apache.org/docs/latest/api/R/sparkR.session.html). +In addition to `sparkHome`, many other options can be specified in `sparkR.session`. For a complete list, see [Starting up: SparkSession](https://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession) and [SparkR API doc](https://spark.apache.org/docs/latest/api/R/sparkR.session.html). In particular, the following Spark driver properties can be set in `sparkConfig`. @@ -169,7 +169,7 @@ sparkR.session(spark.sql.warehouse.dir = spark_warehouse_path) #### Cluster Mode -SparkR can connect to remote Spark clusters. [Cluster Mode Overview](http://spark.apache.org/docs/latest/cluster-overview.html) is a good introduction to different Spark cluster modes. +SparkR can connect to remote Spark clusters. [Cluster Mode Overview](https://spark.apache.org/docs/latest/cluster-overview.html) is a good introduction to different Spark cluster modes. When connecting SparkR to a remote Spark cluster, make sure that the Spark version and Hadoop version on the machine match the corresponding versions on the cluster. Current SparkR package is compatible with ```{r, echo=FALSE, tidy = TRUE} @@ -177,7 +177,7 @@ paste("Spark", packageVersion("SparkR")) ``` It should be used both on the local computer and on the remote cluster. -To connect, pass the URL of the master node to `sparkR.session`. A complete list can be seen in [Spark Master URLs](http://spark.apache.org/docs/latest/submitting-applications.html#master-urls). +To connect, pass the URL of the master node to `sparkR.session`. A complete list can be seen in [Spark Master URLs](https://spark.apache.org/docs/latest/submitting-applications.html#master-urls). For example, to connect to a local standalone Spark master, we can call ```{r, eval=FALSE} @@ -317,7 +317,7 @@ A common flow of grouping and aggregation is 2. Feed the `GroupedData` object to `agg` or `summarize` functions, with some provided aggregation functions to compute a number within each group. -A number of widely used functions are supported to aggregate data after grouping, including `avg`, `countDistinct`, `count`, `first`, `kurtosis`, `last`, `max`, `mean`, `min`, `sd`, `skewness`, `stddev_pop`, `stddev_samp`, `sumDistinct`, `sum`, `var_pop`, `var_samp`, `var`. See the [API doc for `mean`](http://spark.apache.org/docs/latest/api/R/mean.html) and other `agg_funcs` linked there. +A number of widely used functions are supported to aggregate data after grouping, including `avg`, `countDistinct`, `count`, `first`, `kurtosis`, `last`, `max`, `mean`, `min`, `sd`, `skewness`, `stddev_pop`, `stddev_samp`, `sumDistinct`, `sum`, `var_pop`, `var_samp`, `var`. See the [API doc for aggregate functions](https://spark.apache.org/docs/latest/api/R/column_aggregate_functions.html) linked there. For example we can compute a histogram of the number of cylinders in the `mtcars` dataset as shown below. @@ -935,7 +935,7 @@ perplexity #### Alternating Least Squares -`spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](http://dl.acm.org/citation.cfm?id=1608614). +`spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](https://dl.acm.org/citation.cfm?id=1608614). There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, and `nonnegative`. For a complete list, refer to the help file. @@ -1171,11 +1171,11 @@ env | map ## References -* [Spark Cluster Mode Overview](http://spark.apache.org/docs/latest/cluster-overview.html) +* [Spark Cluster Mode Overview](https://spark.apache.org/docs/latest/cluster-overview.html) -* [Submitting Spark Applications](http://spark.apache.org/docs/latest/submitting-applications.html) +* [Submitting Spark Applications](https://spark.apache.org/docs/latest/submitting-applications.html) -* [Machine Learning Library Guide (MLlib)](http://spark.apache.org/docs/latest/ml-guide.html) +* [Machine Learning Library Guide (MLlib)](https://spark.apache.org/docs/latest/ml-guide.html) * [SparkR: Scaling R Programs with Spark](https://people.csail.mit.edu/matei/papers/2016/sigmod_sparkr.pdf), Shivaram Venkataraman, Zongheng Yang, Davies Liu, Eric Liang, Hossein Falaki, Xiangrui Meng, Reynold Xin, Ali Ghodsi, Michael Franklin, Ion Stoica, and Matei Zaharia. SIGMOD 2016. June 2016. From 3a4d15e5d2b9ddbaeb2a6ab2d86d059ada6407b2 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Fri, 2 Mar 2018 10:38:50 -0800 Subject: [PATCH 0414/2461] [SPARK-23518][SQL] Avoid metastore access when the users only want to read and write data frames ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/18944 added one patch, which allowed a spark session to be created when the hive metastore server is down. However, it did not allow running any commands with the spark session. This brings troubles to the user who only wants to read / write data frames without metastore setup. ## How was this patch tested? Added some unit tests to read and write data frames based on the original HiveMetastoreLazyInitializationSuite. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Feng Liu Closes #20681 from liufengdb/completely-lazy. --- R/pkg/tests/fulltests/test_sparkSQL.R | 2 ++ .../sql/catalyst/catalog/SessionCatalog.scala | 11 +++++++---- .../sql/internal/BaseSessionStateBuilder.scala | 4 ++-- .../HiveMetastoreLazyInitializationSuite.scala | 14 ++++++++++++++ .../apache/spark/sql/hive/HiveSessionCatalog.scala | 8 ++++---- .../spark/sql/hive/HiveSessionStateBuilder.scala | 10 +++++----- 6 files changed, 34 insertions(+), 15 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 5197838eaac66..bd0a0dcd0674c 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -67,6 +67,8 @@ sparkSession <- if (windows_with_hadoop()) { sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) } sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) +# materialize the catalog implementation +listTables() mockLines <- c("{\"name\":\"Michael\"}", "{\"name\":\"Andy\", \"age\":30}", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 4b119c75260a7..64e7ca11270b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -54,8 +54,8 @@ object SessionCatalog { * This class must be thread-safe. */ class SessionCatalog( - val externalCatalog: ExternalCatalog, - globalTempViewManager: GlobalTempViewManager, + externalCatalogBuilder: () => ExternalCatalog, + globalTempViewManagerBuilder: () => GlobalTempViewManager, functionRegistry: FunctionRegistry, conf: SQLConf, hadoopConf: Configuration, @@ -70,8 +70,8 @@ class SessionCatalog( functionRegistry: FunctionRegistry, conf: SQLConf) { this( - externalCatalog, - new GlobalTempViewManager("global_temp"), + () => externalCatalog, + () => new GlobalTempViewManager("global_temp"), functionRegistry, conf, new Configuration(), @@ -87,6 +87,9 @@ class SessionCatalog( new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) } + lazy val externalCatalog = externalCatalogBuilder() + lazy val globalTempViewManager = globalTempViewManagerBuilder() + /** List of temporary views, mapping from table name to their logical plan. */ @GuardedBy("this") protected val tempViews = new mutable.HashMap[String, LogicalPlan] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 007f8760edf82..3a0db7e16c23a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -130,8 +130,8 @@ abstract class BaseSessionStateBuilder( */ protected lazy val catalog: SessionCatalog = { val catalog = new SessionCatalog( - session.sharedState.externalCatalog, - session.sharedState.globalTempViewManager, + () => session.sharedState.externalCatalog, + () => session.sharedState.globalTempViewManager, functionRegistry, conf, SessionState.newHadoopConf(session.sparkContext.hadoopConfiguration, conf), diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala index 3f135cc864983..277df548aefd0 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala @@ -38,6 +38,20 @@ class HiveMetastoreLazyInitializationSuite extends SparkFunSuite { // We should be able to run Spark jobs without Hive client. assert(spark.sparkContext.range(0, 1).count() === 1) + // We should be able to use Spark SQL if no table references. + assert(spark.sql("select 1 + 1").count() === 1) + assert(spark.range(0, 1).count() === 1) + + // We should be able to use fs + val path = Utils.createTempDir() + path.delete() + try { + spark.range(0, 1).write.parquet(path.getAbsolutePath) + assert(spark.read.parquet(path.getAbsolutePath).count() === 1) + } finally { + Utils.deleteRecursively(path) + } + // Make sure that we are not using the local derby metastore. val exceptionString = Utils.exceptionString(intercept[AnalysisException] { spark.sql("show tables") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 1f11adbd4f62e..e5aff3b99d0b9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -39,8 +39,8 @@ import org.apache.spark.sql.types.{DecimalType, DoubleType} private[sql] class HiveSessionCatalog( - externalCatalog: HiveExternalCatalog, - globalTempViewManager: GlobalTempViewManager, + externalCatalogBuilder: () => HiveExternalCatalog, + globalTempViewManagerBuilder: () => GlobalTempViewManager, val metastoreCatalog: HiveMetastoreCatalog, functionRegistry: FunctionRegistry, conf: SQLConf, @@ -48,8 +48,8 @@ private[sql] class HiveSessionCatalog( parser: ParserInterface, functionResourceLoader: FunctionResourceLoader) extends SessionCatalog( - externalCatalog, - globalTempViewManager, + externalCatalogBuilder, + globalTempViewManagerBuilder, functionRegistry, conf, hadoopConf, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 12c74368dd184..40b9bb51ca9a0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -42,8 +42,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session * Create a Hive aware resource loader. */ override protected lazy val resourceLoader: HiveSessionResourceLoader = { - val client: HiveClient = externalCatalog.client - new HiveSessionResourceLoader(session, client) + new HiveSessionResourceLoader(session, () => externalCatalog.client) } /** @@ -51,8 +50,8 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session */ override protected lazy val catalog: HiveSessionCatalog = { val catalog = new HiveSessionCatalog( - externalCatalog, - session.sharedState.globalTempViewManager, + () => externalCatalog, + () => session.sharedState.globalTempViewManager, new HiveMetastoreCatalog(session), functionRegistry, conf, @@ -105,8 +104,9 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session class HiveSessionResourceLoader( session: SparkSession, - client: HiveClient) + clientBuilder: () => HiveClient) extends SessionResourceLoader(session) { + private lazy val client = clientBuilder() override def addJar(path: String): Unit = { client.addJar(path) super.addJar(path) From 707e6506d0dbdb598a6c99d666f3c66746113b67 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 2 Mar 2018 12:27:42 -0800 Subject: [PATCH 0415/2461] [SPARK-23097][SQL][SS] Migrate text socket source to V2 ## What changes were proposed in this pull request? This PR moves structured streaming text socket source to V2. Questions: do we need to remove old "socket" source? ## How was this patch tested? Unit test and manual verification. Author: jerryshao Closes #20382 from jerryshao/SPARK-23097. --- ...pache.spark.sql.sources.DataSourceRegister | 2 +- .../execution/datasources/DataSource.scala | 5 +- .../streaming/{ => sources}/socket.scala | 178 ++++++---- .../sql/streaming/DataStreamReader.scala | 21 +- .../streaming/TextSocketStreamSuite.scala | 231 ------------- .../sources/TextSocketStreamSuite.scala | 306 ++++++++++++++++++ 6 files changed, 434 insertions(+), 309 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/{ => sources}/socket.scala (51%) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 0259c774bbf4a..1fe9c093af99f 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,6 +5,6 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider -org.apache.spark.sql.execution.streaming.TextSocketSourceProvider org.apache.spark.sql.execution.streaming.RateSourceProvider +org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 6e1b5727e3fd5..35fcff69b14d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode @@ -563,6 +564,7 @@ object DataSource extends Logging { val libsvm = "org.apache.spark.ml.source.libsvm.LibSVMFileFormat" val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat" val nativeOrc = classOf[OrcFileFormat].getCanonicalName + val socket = classOf[TextSocketSourceProvider].getCanonicalName Map( "org.apache.spark.sql.jdbc" -> jdbc, @@ -583,7 +585,8 @@ object DataSource extends Logging { "org.apache.spark.sql.execution.datasources.orc" -> nativeOrc, "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm, "org.apache.spark.ml.source.libsvm" -> libsvm, - "com.databricks.spark.csv" -> csv + "com.databricks.spark.csv" -> csv, + "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala similarity index 51% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 0b22cbc46e6bf..5aae46b463398 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -15,27 +15,29 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.streaming +package org.apache.spark.sql.execution.streaming.sources import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket import java.sql.Timestamp import java.text.SimpleDateFormat -import java.util.{Calendar, Locale} +import java.util.{Calendar, List => JList, Locale, Optional} import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} +import org.apache.spark.sql.execution.streaming.LongOffset +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} -import org.apache.spark.unsafe.types.UTF8String - -object TextSocketSource { +object TextSocketMicroBatchReader { val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: StructField("timestamp", TimestampType) :: Nil) @@ -43,12 +45,17 @@ object TextSocketSource { } /** - * A source that reads text lines through a TCP socket, designed only for tutorials and debugging. - * This source will *not* work in production applications due to multiple reasons, including no - * support for fault recovery and keeping all of the text read in memory forever. + * A MicroBatchReader that reads text lines through a TCP socket, designed only for tutorials and + * debugging. This MicroBatchReader will *not* work in production applications due to multiple + * reasons, including no support for fault recovery. */ -class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlContext: SQLContext) - extends Source with Logging { +class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader with Logging { + + private var startOffset: Offset = _ + private var endOffset: Offset = _ + + private val host: String = options.get("host").get() + private val port: Int = options.get("port").get().toInt @GuardedBy("this") private var socket: Socket = null @@ -61,16 +68,21 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo * Stored in a ListBuffer to facilitate removing committed batches. */ @GuardedBy("this") - protected val batches = new ListBuffer[(String, Timestamp)] + private val batches = new ListBuffer[(String, Timestamp)] @GuardedBy("this") - protected var currentOffset: LongOffset = new LongOffset(-1) + private var currentOffset: LongOffset = LongOffset(-1L) @GuardedBy("this") - protected var lastOffsetCommitted : LongOffset = new LongOffset(-1) + private var lastOffsetCommitted: LongOffset = LongOffset(-1L) initialize() + /** This method is only used for unit test */ + private[sources] def getCurrentOffset(): LongOffset = synchronized { + currentOffset.copy() + } + private def initialize(): Unit = synchronized { socket = new Socket(host, port) val reader = new BufferedReader(new InputStreamReader(socket.getInputStream)) @@ -86,12 +98,12 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo logWarning(s"Stream closed by $host:$port") return } - TextSocketSource.this.synchronized { + TextSocketMicroBatchReader.this.synchronized { val newData = (line, Timestamp.valueOf( - TextSocketSource.DATE_FORMAT.format(Calendar.getInstance().getTime())) - ) - currentOffset = currentOffset + 1 + TextSocketMicroBatchReader.DATE_FORMAT.format(Calendar.getInstance().getTime())) + ) + currentOffset += 1 batches.append(newData) } } @@ -103,23 +115,37 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo readThread.start() } - /** Returns the schema of the data from this source */ - override def schema: StructType = if (includeTimestamp) TextSocketSource.SCHEMA_TIMESTAMP - else TextSocketSource.SCHEMA_REGULAR + override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = synchronized { + startOffset = start.orElse(LongOffset(-1L)) + endOffset = end.orElse(currentOffset) + } + + override def getStartOffset(): Offset = { + Option(startOffset).getOrElse(throw new IllegalStateException("start offset not set")) + } + + override def getEndOffset(): Offset = { + Option(endOffset).getOrElse(throw new IllegalStateException("end offset not set")) + } + + override def deserializeOffset(json: String): Offset = { + LongOffset(json.toLong) + } - override def getOffset: Option[Offset] = synchronized { - if (currentOffset.offset == -1) { - None + override def readSchema(): StructType = { + if (options.getBoolean("includeTimestamp", false)) { + TextSocketMicroBatchReader.SCHEMA_TIMESTAMP } else { - Some(currentOffset) + TextSocketMicroBatchReader.SCHEMA_REGULAR } } - /** Returns the data that is between the offsets (`start`, `end`]. */ - override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized { - val startOrdinal = - start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1 - val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1 + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + assert(startOffset != null && endOffset != null, + "start offset and end offset should already be set before create read tasks.") + + val startOrdinal = LongOffset.convert(startOffset).get.offset.toInt + 1 + val endOrdinal = LongOffset.convert(endOffset).get.offset.toInt + 1 // Internal buffer only holds the batches after lastOffsetCommitted val rawList = synchronized { @@ -128,10 +154,34 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo batches.slice(sliceStart, sliceEnd) } - val rdd = sqlContext.sparkContext - .parallelize(rawList) - .map { case (v, ts) => InternalRow(UTF8String.fromString(v), ts.getTime) } - sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) + assert(SparkSession.getActiveSession.isDefined) + val spark = SparkSession.getActiveSession.get + val numPartitions = spark.sparkContext.defaultParallelism + + val slices = Array.fill(numPartitions)(new ListBuffer[(String, Timestamp)]) + rawList.zipWithIndex.foreach { case (r, idx) => + slices(idx % numPartitions).append(r) + } + + (0 until numPartitions).map { i => + val slice = slices(i) + new DataReaderFactory[Row] { + override def createDataReader(): DataReader[Row] = new DataReader[Row] { + private var currentIdx = -1 + + override def next(): Boolean = { + currentIdx += 1 + currentIdx < slice.size + } + + override def get(): Row = { + Row(slice(currentIdx)._1, slice(currentIdx)._2) + } + + override def close(): Unit = {} + } + } + }.toList.asJava } override def commit(end: Offset): Unit = synchronized { @@ -164,54 +214,40 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo } } - override def toString: String = s"TextSocketSource[host: $host, port: $port]" + override def toString: String = s"TextSocket[host: $host, port: $port]" } -class TextSocketSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging { - private def parseIncludeTimestamp(params: Map[String, String]): Boolean = { - Try(params.getOrElse("includeTimestamp", "false").toBoolean) match { - case Success(bool) => bool - case Failure(_) => - throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"") - } - } +class TextSocketSourceProvider extends DataSourceV2 + with MicroBatchReadSupport with DataSourceRegister with Logging { - /** Returns the name and schema of the source that can be used to continually read data. */ - override def sourceSchema( - sqlContext: SQLContext, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): (String, StructType) = { + private def checkParameters(params: DataSourceOptions): Unit = { logWarning("The socket source should not be used for production applications! " + "It does not support recovery.") - if (!parameters.contains("host")) { + if (!params.get("host").isPresent) { throw new AnalysisException("Set a host to read from with option(\"host\", ...).") } - if (!parameters.contains("port")) { + if (!params.get("port").isPresent) { throw new AnalysisException("Set a port to read from with option(\"port\", ...).") } - if (schema.nonEmpty) { - throw new AnalysisException("The socket source does not support a user-specified schema.") + Try { + params.get("includeTimestamp").orElse("false").toBoolean + } match { + case Success(_) => + case Failure(_) => + throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"") } - - val sourceSchema = - if (parseIncludeTimestamp(parameters)) { - TextSocketSource.SCHEMA_TIMESTAMP - } else { - TextSocketSource.SCHEMA_REGULAR - } - ("textSocket", sourceSchema) } - override def createSource( - sqlContext: SQLContext, - metadataPath: String, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): Source = { - val host = parameters("host") - val port = parameters("port").toInt - new TextSocketSource(host, port, parseIncludeTimestamp(parameters), sqlContext) + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): MicroBatchReader = { + checkParameters(options) + if (schema.isPresent) { + throw new AnalysisException("The socket source does not support a user-specified schema.") + } + + new TextSocketMicroBatchReader(options) } /** String that represents the format that this data source provider uses. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 61e22fac854f9..c393dcdfdd7e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -172,15 +173,25 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } ds match { case s: MicroBatchReadSupport => - val tempReader = s.createMicroBatchReader( - Optional.ofNullable(userSpecifiedSchema.orNull), - Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, - options) + var tempReader: MicroBatchReader = null + val schema = try { + tempReader = s.createMicroBatchReader( + Optional.ofNullable(userSpecifiedSchema.orNull), + Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, + options) + tempReader.readSchema() + } finally { + // Stop tempReader to avoid side-effect thing + if (tempReader != null) { + tempReader.stop() + tempReader = null + } + } Dataset.ofRows( sparkSession, StreamingRelationV2( s, source, extraOptions.toMap, - tempReader.readSchema().toAttributes, v1Relation)(sparkSession)) + schema.toAttributes, v1Relation)(sparkSession)) case s: ContinuousReadSupport => val tempReader = s.createContinuousReader( Optional.ofNullable(userSpecifiedSchema.orNull), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala deleted file mode 100644 index ec11549073650..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.io.{IOException, OutputStreamWriter} -import java.net.ServerSocket -import java.sql.Timestamp -import java.util.concurrent.LinkedBlockingQueue - -import org.scalatest.BeforeAndAfterEach - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} - -class TextSocketStreamSuite extends StreamTest with SharedSQLContext with BeforeAndAfterEach { - import testImplicits._ - - override def afterEach() { - sqlContext.streams.active.foreach(_.stop()) - if (serverThread != null) { - serverThread.interrupt() - serverThread.join() - serverThread = null - } - if (source != null) { - source.stop() - source = null - } - } - - private var serverThread: ServerThread = null - private var source: Source = null - - test("basic usage") { - serverThread = new ServerThread() - serverThread.start() - - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString) - val schema = provider.sourceSchema(sqlContext, None, "", parameters)._2 - assert(schema === StructType(StructField("value", StringType) :: Nil)) - - source = provider.createSource(sqlContext, "", None, "", parameters) - - failAfter(streamingTimeout) { - serverThread.enqueue("hello") - while (source.getOffset.isEmpty) { - Thread.sleep(10) - } - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { - val offset1 = source.getOffset.get - val batch1 = source.getBatch(None, offset1) - assert(batch1.as[String].collect().toSeq === Seq("hello")) - - serverThread.enqueue("world") - while (source.getOffset.get === offset1) { - Thread.sleep(10) - } - val offset2 = source.getOffset.get - val batch2 = source.getBatch(Some(offset1), offset2) - assert(batch2.as[String].collect().toSeq === Seq("world")) - - val both = source.getBatch(None, offset2) - assert(both.as[String].collect().sorted.toSeq === Seq("hello", "world")) - } - - // Try stopping the source to make sure this does not block forever. - source.stop() - source = null - } - } - - test("timestamped usage") { - serverThread = new ServerThread() - serverThread.start() - - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString, - "includeTimestamp" -> "true") - val schema = provider.sourceSchema(sqlContext, None, "", parameters)._2 - assert(schema === StructType(StructField("value", StringType) :: - StructField("timestamp", TimestampType) :: Nil)) - - source = provider.createSource(sqlContext, "", None, "", parameters) - - failAfter(streamingTimeout) { - serverThread.enqueue("hello") - while (source.getOffset.isEmpty) { - Thread.sleep(10) - } - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { - val offset1 = source.getOffset.get - val batch1 = source.getBatch(None, offset1) - val batch1Seq = batch1.as[(String, Timestamp)].collect().toSeq - assert(batch1Seq.map(_._1) === Seq("hello")) - val batch1Stamp = batch1Seq(0)._2 - - serverThread.enqueue("world") - while (source.getOffset.get === offset1) { - Thread.sleep(10) - } - val offset2 = source.getOffset.get - val batch2 = source.getBatch(Some(offset1), offset2) - val batch2Seq = batch2.as[(String, Timestamp)].collect().toSeq - assert(batch2Seq.map(_._1) === Seq("world")) - val batch2Stamp = batch2Seq(0)._2 - assert(!batch2Stamp.before(batch1Stamp)) - } - - // Try stopping the source to make sure this does not block forever. - source.stop() - source = null - } - } - - test("params not given") { - val provider = new TextSocketSourceProvider - intercept[AnalysisException] { - provider.sourceSchema(sqlContext, None, "", Map()) - } - intercept[AnalysisException] { - provider.sourceSchema(sqlContext, None, "", Map("host" -> "localhost")) - } - intercept[AnalysisException] { - provider.sourceSchema(sqlContext, None, "", Map("port" -> "1234")) - } - } - - test("non-boolean includeTimestamp") { - val provider = new TextSocketSourceProvider - intercept[AnalysisException] { - provider.sourceSchema(sqlContext, None, "", Map("host" -> "localhost", - "port" -> "1234", "includeTimestamp" -> "fasle")) - } - } - - test("user-specified schema given") { - val provider = new TextSocketSourceProvider - val userSpecifiedSchema = StructType( - StructField("name", StringType) :: - StructField("area", StringType) :: Nil) - val exception = intercept[AnalysisException] { - provider.sourceSchema( - sqlContext, Some(userSpecifiedSchema), - "", - Map("host" -> "localhost", "port" -> "1234")) - } - assert(exception.getMessage.contains( - "socket source does not support a user-specified schema")) - } - - test("no server up") { - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> "0") - intercept[IOException] { - source = provider.createSource(sqlContext, "", None, "", parameters) - } - } - - test("input row metrics") { - serverThread = new ServerThread() - serverThread.start() - - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString) - source = provider.createSource(sqlContext, "", None, "", parameters) - - failAfter(streamingTimeout) { - serverThread.enqueue("hello") - while (source.getOffset.isEmpty) { - Thread.sleep(10) - } - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { - val batch = source.getBatch(None, source.getOffset.get).as[String] - batch.collect() - val numRowsMetric = - batch.queryExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows") - assert(numRowsMetric.nonEmpty) - assert(numRowsMetric.get.value === 1) - } - source.stop() - source = null - } - } - - private class ServerThread extends Thread with Logging { - private val serverSocket = new ServerSocket(0) - private val messageQueue = new LinkedBlockingQueue[String]() - - val port = serverSocket.getLocalPort - - override def run(): Unit = { - try { - val clientSocket = serverSocket.accept() - clientSocket.setTcpNoDelay(true) - val out = new OutputStreamWriter(clientSocket.getOutputStream) - while (true) { - val line = messageQueue.take() - out.write(line + "\n") - out.flush() - } - } catch { - case e: InterruptedException => - } finally { - serverSocket.close() - } - } - - def enqueue(line: String): Unit = { - messageQueue.put(line) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala new file mode 100644 index 0000000000000..a15a980bb92fd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io.IOException +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.nio.channels.ServerSocketChannel +import java.sql.Timestamp +import java.util.Optional +import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.JavaConverters._ + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} + +class TextSocketStreamSuite extends StreamTest with SharedSQLContext with BeforeAndAfterEach { + + override def afterEach() { + sqlContext.streams.active.foreach(_.stop()) + if (serverThread != null) { + serverThread.interrupt() + serverThread.join() + serverThread = null + } + if (batchReader != null) { + batchReader.stop() + batchReader = null + } + } + + private var serverThread: ServerThread = null + private var batchReader: MicroBatchReader = null + + case class AddSocketData(data: String*) extends AddData { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + require( + query.nonEmpty, + "Cannot add data when there is no query for finding the active socket source") + + val sources = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: TextSocketMicroBatchReader, _) => source + } + if (sources.isEmpty) { + throw new Exception( + "Could not find socket source in the StreamExecution logical plan to add data to") + } else if (sources.size > 1) { + throw new Exception( + "Could not select the socket source in the StreamExecution logical plan as there" + + "are multiple socket sources:\n\t" + sources.mkString("\n\t")) + } + val socketSource = sources.head + + assert(serverThread != null && serverThread.port != 0) + val currOffset = socketSource.getCurrentOffset() + data.foreach(serverThread.enqueue) + + val newOffset = LongOffset(currOffset.offset + data.size) + (socketSource, newOffset) + } + + override def toString: String = s"AddSocketData(data = $data)" + } + + test("backward compatibility with old path") { + DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider", + spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + assert(ds.isInstanceOf[TextSocketSourceProvider]) + case _ => + throw new IllegalStateException("Could not find socket source") + } + } + + test("basic usage") { + serverThread = new ServerThread() + serverThread.start() + + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val ref = spark + import ref.implicits._ + + val socket = spark + .readStream + .format("socket") + .options(Map("host" -> "localhost", "port" -> serverThread.port.toString)) + .load() + .as[String] + + assert(socket.schema === StructType(StructField("value", StringType) :: Nil)) + + testStream(socket)( + StartStream(), + AddSocketData("hello"), + CheckAnswer("hello"), + AddSocketData("world"), + CheckLastBatch("world"), + CheckAnswer("hello", "world"), + StopStream + ) + } + } + + test("timestamped usage") { + serverThread = new ServerThread() + serverThread.start() + + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val socket = spark + .readStream + .format("socket") + .options(Map( + "host" -> "localhost", + "port" -> serverThread.port.toString, + "includeTimestamp" -> "true")) + .load() + + assert(socket.schema === StructType(StructField("value", StringType) :: + StructField("timestamp", TimestampType) :: Nil)) + + var batch1Stamp: Timestamp = null + var batch2Stamp: Timestamp = null + + val curr = System.currentTimeMillis() + testStream(socket)( + StartStream(), + AddSocketData("hello"), + CheckAnswerRowsByFunc( + rows => { + assert(rows.size === 1) + assert(rows.head.getAs[String](0) === "hello") + batch1Stamp = rows.head.getAs[Timestamp](1) + Thread.sleep(10) + }, + true), + AddSocketData("world"), + CheckAnswerRowsByFunc( + rows => { + assert(rows.size === 1) + assert(rows.head.getAs[String](0) === "world") + batch2Stamp = rows.head.getAs[Timestamp](1) + }, + true), + StopStream + ) + + // Timestamp for rate stream is round to second which leads to milliseconds lost, that will + // make batch1stamp smaller than current timestamp if both of them are in the same second. + // Comparing by second to make sure the correct behavior. + assert(batch1Stamp.getTime >= curr / 1000 * 1000) + assert(!batch2Stamp.before(batch1Stamp)) + } + } + + test("params not given") { + val provider = new TextSocketSourceProvider + intercept[AnalysisException] { + provider.createMicroBatchReader(Optional.empty(), "", + new DataSourceOptions(Map.empty[String, String].asJava)) + } + intercept[AnalysisException] { + provider.createMicroBatchReader(Optional.empty(), "", + new DataSourceOptions(Map("host" -> "localhost").asJava)) + } + intercept[AnalysisException] { + provider.createMicroBatchReader(Optional.empty(), "", + new DataSourceOptions(Map("port" -> "1234").asJava)) + } + } + + test("non-boolean includeTimestamp") { + val provider = new TextSocketSourceProvider + val params = Map("host" -> "localhost", "port" -> "1234", "includeTimestamp" -> "fasle") + intercept[AnalysisException] { + val a = new DataSourceOptions(params.asJava) + provider.createMicroBatchReader(Optional.empty(), "", a) + } + } + + test("user-specified schema given") { + val provider = new TextSocketSourceProvider + val userSpecifiedSchema = StructType( + StructField("name", StringType) :: + StructField("area", StringType) :: Nil) + val params = Map("host" -> "localhost", "port" -> "1234") + val exception = intercept[AnalysisException] { + provider.createMicroBatchReader( + Optional.of(userSpecifiedSchema), "", new DataSourceOptions(params.asJava)) + } + assert(exception.getMessage.contains( + "socket source does not support a user-specified schema")) + } + + test("no server up") { + val provider = new TextSocketSourceProvider + val parameters = Map("host" -> "localhost", "port" -> "0") + intercept[IOException] { + batchReader = provider.createMicroBatchReader( + Optional.empty(), "", new DataSourceOptions(parameters.asJava)) + } + } + + test("input row metrics") { + serverThread = new ServerThread() + serverThread.start() + + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val ref = spark + import ref.implicits._ + + val socket = spark + .readStream + .format("socket") + .options(Map("host" -> "localhost", "port" -> serverThread.port.toString)) + .load() + .as[String] + + assert(socket.schema === StructType(StructField("value", StringType) :: Nil)) + + testStream(socket)( + StartStream(), + AddSocketData("hello"), + CheckAnswer("hello"), + AssertOnQuery { q => + val numRowMetric = + q.lastExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows") + numRowMetric.nonEmpty && numRowMetric.get.value == 1 + }, + StopStream + ) + } + } + + private class ServerThread extends Thread with Logging { + private val serverSocketChannel = ServerSocketChannel.open() + serverSocketChannel.bind(new InetSocketAddress(0)) + private val messageQueue = new LinkedBlockingQueue[String]() + + val port = serverSocketChannel.socket().getLocalPort + + override def run(): Unit = { + try { + while (true) { + val clientSocketChannel = serverSocketChannel.accept() + clientSocketChannel.configureBlocking(false) + clientSocketChannel.socket().setTcpNoDelay(true) + + // Check whether remote client is closed but still send data to this closed socket. + // This happens in DataStreamReader where a source will be created to get the schema. + var remoteIsClosed = false + var cnt = 0 + while (cnt < 3 && !remoteIsClosed) { + if (clientSocketChannel.read(ByteBuffer.allocate(1)) != -1) { + cnt += 1 + Thread.sleep(100) + } else { + remoteIsClosed = true + } + } + + if (remoteIsClosed) { + logInfo(s"remote client ${clientSocketChannel.socket()} is closed") + } else { + while (true) { + val line = messageQueue.take() + "\n" + clientSocketChannel.write(ByteBuffer.wrap(line.getBytes("UTF-8"))) + } + } + } + } catch { + case e: InterruptedException => + } finally { + serverSocketChannel.close() + } + } + + def enqueue(line: String): Unit = { + messageQueue.put(line) + } + } +} From 487377e693af65b2ff3d6b874ca7326c1ff0076c Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 2 Mar 2018 14:30:37 -0800 Subject: [PATCH 0416/2461] [SPARK-23570][SQL] Add Spark 2.3.0 in HiveExternalCatalogVersionsSuite ## What changes were proposed in this pull request? Add Spark 2.3.0 in HiveExternalCatalogVersionsSuite since Spark 2.3.0 is released for ensuring backward compatibility. ## How was this patch tested? N/A Author: gatorsmile Closes #20720 from gatorsmile/add2.3. --- .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index c13a750dbb270..6ca58e68d31eb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -195,7 +195,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0", "2.2.1") + val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0", "2.2.1", "2.3.0") protected var spark: SparkSession = _ From 9e26473c0f29ee4281519104ac5e182a3bd4bf23 Mon Sep 17 00:00:00 2001 From: Alessandro Solimando <18898964+asolimando@users.noreply.github.com> Date: Fri, 2 Mar 2018 16:24:29 -0800 Subject: [PATCH 0417/2461] [SPARK-3159][ML] Add decision tree pruning ## What changes were proposed in this pull request? Added subtree pruning in the translation from LearningNode to Node: a learning node having a single prediction value for all the leaves in the subtree rooted at it is translated into a LeafNode, instead of a (redundant) InternalNode ## How was this patch tested? Added two unit tests under "mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala": - test("SPARK-3159 tree model redundancy - classification") - test("SPARK-3159 tree model redundancy - regression") 4 existing unit tests relying on the tree structure (existence of a specific redundant subtree) had to be adapted as the tested components in the output tree are now pruned (fixed by adding an extra _prune_ parameter which can be used to disable pruning for testing) Author: Alessandro Solimando <18898964+asolimando@users.noreply.github.com> Closes #20632 from asolimando/master. --- .../scala/org/apache/spark/ml/tree/Node.scala | 22 ++-- .../spark/ml/tree/impl/RandomForest.scala | 10 +- .../DecisionTreeClassifierSuite.scala | 38 ------- .../ml/tree/impl/RandomForestSuite.scala | 100 ++++++++++++++++-- .../spark/mllib/tree/DecisionTreeSuite.scala | 10 +- 5 files changed, 115 insertions(+), 65 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index 07e98a142b10e..d30be452a436e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -19,8 +19,7 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.linalg.Vector import org.apache.spark.mllib.tree.impurity.ImpurityCalculator -import org.apache.spark.mllib.tree.model.{ImpurityStats, - InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} +import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} /** * Decision tree node interface. @@ -266,15 +265,23 @@ private[tree] class LearningNode( var isLeaf: Boolean, var stats: ImpurityStats) extends Serializable { + def toNode: Node = toNode(prune = true) + /** * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children. */ - def toNode: Node = { - if (leftChild.nonEmpty) { - assert(rightChild.nonEmpty && split.nonEmpty && stats != null, + def toNode(prune: Boolean = true): Node = { + + if (!leftChild.isEmpty || !rightChild.isEmpty) { + assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null, "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.") - new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, - leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator) + (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match { + case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction => + new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator) + case (l, r) => + new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, + l, r, split.get, stats.impurityCalculator) + } } else { if (stats.valid) { new LeafNode(stats.impurityCalculator.predict, stats.impurity, @@ -283,7 +290,6 @@ private[tree] class LearningNode( // Here we want to keep same behavior with the old mllib.DecisionTreeModel new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) } - } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index acfc6399c553b..8e514f11e78ea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -92,6 +92,7 @@ private[spark] object RandomForest extends Logging { featureSubsetStrategy: String, seed: Long, instr: Option[Instrumentation[_]], + prune: Boolean = true, // exposed for testing only, real trees are always pruned parentUID: Option[String] = None): Array[DecisionTreeModel] = { val timer = new TimeTracker() @@ -223,22 +224,23 @@ private[spark] object RandomForest extends Logging { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures, + new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures, strategy.getNumClasses) } } else { topNodes.map { rootNode => - new DecisionTreeRegressionModel(uid, rootNode.toNode, numFeatures) + new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures) } } case None => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(rootNode.toNode, numFeatures, + new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures, strategy.getNumClasses) } } else { - topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode, numFeatures)) + topNodes.map(rootNode => + new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures)) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 98c879ece62d6..38b265d62611b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -280,44 +280,6 @@ class DecisionTreeClassifierSuite dt.fit(df) } - test("Use soft prediction for binary classification with ordered categorical features") { - // The following dataset is set up such that the best split is {1} vs. {0, 2}. - // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen. - val arr = Array( - LabeledPoint(0.0, Vectors.dense(0.0)), - LabeledPoint(0.0, Vectors.dense(0.0)), - LabeledPoint(0.0, Vectors.dense(0.0)), - LabeledPoint(1.0, Vectors.dense(0.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(2.0)), - LabeledPoint(0.0, Vectors.dense(2.0)), - LabeledPoint(0.0, Vectors.dense(2.0)), - LabeledPoint(1.0, Vectors.dense(2.0))) - val data = sc.parallelize(arr) - val df = TreeTests.setMetadata(data, Map(0 -> 3), 2) - - // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. - val dt = new DecisionTreeClassifier() - .setImpurity("gini") - .setMaxDepth(1) - .setMaxBins(3) - val model = dt.fit(df) - model.rootNode match { - case n: InternalNode => - n.split match { - case s: CategoricalSplit => - assert(s.leftCategories === Array(1.0)) - case other => - fail(s"All splits should be categorical, but got ${other.getClass.getName}: $other.") - } - case other => - fail(s"Root node should be an internal node, but got ${other.getClass.getName}: $other.") - } - } - test("Feature importance with toy data") { val dt = new DecisionTreeClassifier() .setImpurity("gini") diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index dbe2ea931fb9c..5f0d26eb5c058 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.tree.impl +import scala.annotation.tailrec import scala.collection.mutable import org.apache.spark.SparkFunSuite @@ -38,6 +39,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { import RandomForestSuite.mapToVec + private val seed = 42 + ///////////////////////////////////////////////////////////////////////////// // Tests for split calculation ///////////////////////////////////////////////////////////////////////////// @@ -320,10 +323,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.isLeaf === false) assert(topNode.stats === null) - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) + val nodesForGroup = Map(0 -> Array(topNode)) + val treeToNodeToIndexInfo = Map(0 -> Map( + topNode.id -> new RandomForest.NodeIndexInfo(0, None) + )) val nodeStack = new mutable.ArrayStack[(Int, LearningNode)] RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack) @@ -362,10 +365,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.isLeaf === false) assert(topNode.stats === null) - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) + val nodesForGroup = Map(0 -> Array(topNode)) + val treeToNodeToIndexInfo = Map(0 -> Map( + topNode.id -> new RandomForest.NodeIndexInfo(0, None) + )) val nodeStack = new mutable.ArrayStack[(Int, LearningNode)] RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack) @@ -407,7 +410,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3) val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all", - seed = 42, instr = None).head + seed = 42, instr = None, prune = false).head + model.rootNode match { case n: InternalNode => n.split match { case s: CategoricalSplit => @@ -631,13 +635,89 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0) assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) } + + /////////////////////////////////////////////////////////////////////////////// + // Tests for pruning of redundant subtrees (generated by a split improving the + // impurity measure, but always leading to the same prediction). + /////////////////////////////////////////////////////////////////////////////// + + test("SPARK-3159 tree model redundancy - classification") { + // The following dataset is set up such that splitting over feature_1 for points having + // feature_0 = 0 improves the impurity measure, despite the prediction will always be 0 + // in both branches. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)) + ) + val rdd = sc.parallelize(arr) + + val numClasses = 2 + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4, + numClasses = numClasses, maxBins = 32) + + val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None).head + + val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None, prune = false).head + + assert(prunedTree.numNodes === 5) + assert(unprunedTree.numNodes === 7) + + assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size) + } + + test("SPARK-3159 tree model redundancy - regression") { + // The following dataset is set up such that splitting over feature_0 for points having + // feature_1 = 1 improves the impurity measure, despite the prediction will always be 0.5 + // in both branches. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.5, Vectors.dense(1.0, 1.0)) + ) + val rdd = sc.parallelize(arr) + + val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance, maxDepth = 4, + numClasses = 0, maxBins = 32) + + val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None).head + + val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None, prune = false).head + + assert(prunedTree.numNodes === 3) + assert(unprunedTree.numNodes === 5) + assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size) + } } private object RandomForestSuite { - def mapToVec(map: Map[Int, Double]): Vector = { val size = (map.keys.toSeq :+ 0).max + 1 val (indices, values) = map.toSeq.sortBy(_._1).unzip Vectors.sparse(size, indices.toArray, values.toArray) } + + @tailrec + private def getSumLeafCounters(nodes: List[Node], acc: Long = 0): Long = { + if (nodes.isEmpty) { + acc + } + else { + nodes.head match { + case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc) + case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.count) + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 441d0f7614bf6..bc59f3f4125fb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -363,10 +363,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { // if a split does not satisfy min instances per node requirements, // this split is invalid, even though the information gain of split is large. val arr = Array( - LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), - LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), - LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), - LabeledPoint(0.0, Vectors.dense(0.0, 0.0))) + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 0.0))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, @@ -541,7 +541,7 @@ object DecisionTreeSuite extends SparkFunSuite { Array[LabeledPoint] = { val arr = new Array[LabeledPoint](3000) for (i <- 0 until 3000) { - if (i < 1000) { + if (i < 1001) { arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0)) } else if (i < 2000) { arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0)) From dea381dfaa73e0cfb9a833b79c741b15ae274f64 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Sat, 3 Mar 2018 09:10:48 +0800 Subject: [PATCH 0418/2461] [SPARK-23514][FOLLOW-UP] Remove more places using sparkContext.hadoopConfiguration directly ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/20679 I missed a few places in SQL tests. For hygiene, they should also use the sessionState interface where possible. ## How was this patch tested? Modified existing tests. Author: Juliusz Sompolski Closes #20718 from juliuszsompolski/SPARK-23514-followup. --- .../scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala | 2 +- .../org/apache/spark/sql/execution/command/DDLSuite.scala | 2 +- .../spark/sql/execution/datasources/FileIndexSuite.scala | 2 +- .../execution/datasources/parquet/ParquetCommitterSuite.scala | 2 +- .../datasources/parquet/ParquetFileFormatSuite.scala | 4 ++-- .../datasources/parquet/ParquetInteroperabilitySuite.scala | 2 +- .../sql/execution/datasources/parquet/ParquetQuerySuite.scala | 2 +- .../org/apache/spark/sql/streaming/FileStreamSinkSuite.scala | 2 +- 8 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index b5d4c558f0d3e..73e3df3b6202e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -124,7 +124,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo Seq("1").toDF("a").write.format(format).save(new Path(basePath, "second").toString) val thirdPath = new Path(basePath, "third") - val fs = thirdPath.getFileSystem(spark.sparkContext.hadoopConfiguration) + val fs = thirdPath.getFileSystem(spark.sessionState.newHadoopConf()) Seq("2").toDF("a").write.format(format).save(thirdPath.toString) val files = fs.listStatus(thirdPath).filter(_.isFile).map(_.getPath) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index b800e6ff5b0ce..db9023b7ec8b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1052,7 +1052,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val part2 = Map("a" -> "2", "b" -> "6") val root = new Path(catalog.getTableMetadata(tableIdent).location) - val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) + val fs = root.getFileSystem(spark.sessionState.newHadoopConf()) // valid fs.mkdirs(new Path(new Path(root, "a=1"), "b=5")) fs.createNewFile(new Path(new Path(root, "a=1/b=5"), "a.csv")) // file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index b4616826e40b3..18bb4bfe661ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -59,7 +59,7 @@ class FileIndexSuite extends SharedSQLContext { require(!unqualifiedDirPath.toString.contains("file:")) require(!unqualifiedFilePath.toString.contains("file:")) - val fs = unqualifiedDirPath.getFileSystem(sparkContext.hadoopConfiguration) + val fs = unqualifiedDirPath.getFileSystem(spark.sessionState.newHadoopConf()) val qualifiedFilePath = fs.makeQualified(new Path(file.getCanonicalPath)) require(qualifiedFilePath.toString.startsWith("file:")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala index caa4f6d70c6a9..f3ecc5ced689f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala @@ -101,7 +101,7 @@ class ParquetCommitterSuite extends SparkFunSuite with SQLTestUtils if (check) { result = Some(MarkingFileOutput.checkMarker( destPath, - spark.sparkContext.hadoopConfiguration)) + spark.sessionState.newHadoopConf())) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala index ccb34355f1bac..3a0867fd2b78b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala @@ -29,7 +29,7 @@ class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLCo test("read parquet footers in parallel") { def testReadFooters(ignoreCorruptFiles: Boolean): Unit = { withTempDir { dir => - val fs = FileSystem.get(sparkContext.hadoopConfiguration) + val fs = FileSystem.get(spark.sessionState.newHadoopConf()) val basePath = dir.getCanonicalPath val path1 = new Path(basePath, "first") @@ -44,7 +44,7 @@ class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLCo Seq(fs.listStatus(path1), fs.listStatus(path2), fs.listStatus(path3)).flatten val footers = ParquetFileFormat.readParquetFootersInParallel( - sparkContext.hadoopConfiguration, fileStatuses, ignoreCorruptFiles) + spark.sessionState.newHadoopConf(), fileStatuses, ignoreCorruptFiles) assert(footers.size == 2) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala index e3edafa9c84e1..fbd83a0fa425a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -163,7 +163,7 @@ class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedS // Just to be defensive in case anything ever changes in parquet, this test checks // the assumption on column stats, and also the end-to-end behavior. - val hadoopConf = sparkContext.hadoopConfiguration + val hadoopConf = spark.sessionState.newHadoopConf() val fs = FileSystem.get(hadoopConf) val parts = fs.listStatus(new Path(tableDir.getAbsolutePath), new PathFilter { override def accept(path: Path): Boolean = !path.getName.startsWith("_") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 55b0f729be8ce..e1f094d0a7af3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -819,7 +819,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val path = dir.getCanonicalPath spark.range(3).write.parquet(path) - val fs = FileSystem.get(sparkContext.hadoopConfiguration) + val fs = FileSystem.get(spark.sessionState.newHadoopConf()) val files = fs.listFiles(new Path(path), true) while (files.hasNext) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index ba48bc1ce0c4d..31e5527d7366a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -353,7 +353,7 @@ class FileStreamSinkSuite extends StreamTest { } test("FileStreamSink.ancestorIsMetadataDirectory()") { - val hadoopConf = spark.sparkContext.hadoopConfiguration + val hadoopConf = spark.sessionState.newHadoopConf() def assertAncestorIsMetadataDirectory(path: String): Unit = assert(FileStreamSink.ancestorIsMetadataDirectory(new Path(path), hadoopConf)) def assertAncestorIsNotMetadataDirectory(path: String): Unit = From 486f99eefead4e664a30a861eca65cab8568e70b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 2 Mar 2018 18:14:13 -0800 Subject: [PATCH 0419/2461] [SPARK-23541][SS] Allow Kafka source to read data with greater parallelism than the number of topic-partitions ## What changes were proposed in this pull request? Currently, when the Kafka source reads from Kafka, it generates as many tasks as the number of partitions in the topic(s) to be read. In some case, it may be beneficial to read the data with greater parallelism, that is, with more number partitions/tasks. That means, offset ranges must be divided up into smaller ranges such the number of records in partition ~= total records in batch / desired partitions. This would also balance out any data skews between topic-partitions. In this patch, I have added a new option called `minPartitions`, which allows the user to specify the desired level of parallelism. ## How was this patch tested? New tests in KafkaMicroBatchV2SourceSuite. Author: Tathagata Das Closes #20698 from tdas/SPARK-23541. --- .../sql/kafka010/KafkaMicroBatchReader.scala | 109 ++++++------- .../kafka010/KafkaOffsetRangeCalculator.scala | 105 +++++++++++++ .../sql/kafka010/KafkaSourceProvider.scala | 7 + .../apache/spark/sql/kafka010/package.scala | 24 +++ .../kafka010/KafkaMicroBatchSourceSuite.scala | 56 ++++++- .../KafkaOffsetRangeCalculatorSuite.scala | 147 ++++++++++++++++++ 6 files changed, 388 insertions(+), 60 deletions(-) create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index fb647ca7e70dd..8a5f3a249b11c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -24,7 +24,6 @@ import java.nio.charset.StandardCharsets import scala.collection.JavaConverters._ import org.apache.commons.io.IOUtils -import org.apache.kafka.common.TopicPartition import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging @@ -64,8 +63,6 @@ private[kafka010] class KafkaMicroBatchReader( failOnDataLoss: Boolean) extends MicroBatchReader with SupportsScanUnsafeRow with Logging { - type PartitionOffsetMap = Map[TopicPartition, Long] - private var startPartitionOffsets: PartitionOffsetMap = _ private var endPartitionOffsets: PartitionOffsetMap = _ @@ -76,6 +73,7 @@ private[kafka010] class KafkaMicroBatchReader( private val maxOffsetsPerTrigger = Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong) + private val rangeCalculator = KafkaOffsetRangeCalculator(options) /** * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only * called in StreamExecutionThread. Otherwise, interrupting a thread while running @@ -106,15 +104,15 @@ private[kafka010] class KafkaMicroBatchReader( override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { // Find the new partitions, and get their earliest offsets val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) - val newPartitionOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) - if (newPartitionOffsets.keySet != newPartitions) { + val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) + if (newPartitionInitialOffsets.keySet != newPartitions) { // We cannot get from offsets for some partitions. It means they got deleted. - val deletedPartitions = newPartitions.diff(newPartitionOffsets.keySet) + val deletedPartitions = newPartitions.diff(newPartitionInitialOffsets.keySet) reportDataLoss( s"Cannot find earliest offsets of ${deletedPartitions}. Some data may have been missed") } - logInfo(s"Partitions added: $newPartitionOffsets") - newPartitionOffsets.filter(_._2 != 0).foreach { case (p, o) => + logInfo(s"Partitions added: $newPartitionInitialOffsets") + newPartitionInitialOffsets.filter(_._2 != 0).foreach { case (p, o) => reportDataLoss( s"Added partition $p starts from $o instead of 0. Some data may have been missed") } @@ -125,46 +123,28 @@ private[kafka010] class KafkaMicroBatchReader( reportDataLoss(s"$deletedPartitions are gone. Some data may have been missed") } - // Use the until partitions to calculate offset ranges to ignore partitions that have + // Use the end partitions to calculate offset ranges to ignore partitions that have // been deleted val topicPartitions = endPartitionOffsets.keySet.filter { tp => // Ignore partitions that we don't know the from offsets. - newPartitionOffsets.contains(tp) || startPartitionOffsets.contains(tp) + newPartitionInitialOffsets.contains(tp) || startPartitionOffsets.contains(tp) }.toSeq logDebug("TopicPartitions: " + topicPartitions.mkString(", ")) - val sortedExecutors = getSortedExecutorList() - val numExecutors = sortedExecutors.length - logDebug("Sorted executors: " + sortedExecutors.mkString(", ")) - // Calculate offset ranges - val factories = topicPartitions.flatMap { tp => - val fromOffset = startPartitionOffsets.get(tp).getOrElse { - newPartitionOffsets.getOrElse( - tp, { - // This should not happen since newPartitionOffsets contains all partitions not in - // fromPartitionOffsets - throw new IllegalStateException(s"$tp doesn't have a from offset") - }) - } - val untilOffset = endPartitionOffsets(tp) - - if (untilOffset >= fromOffset) { - // This allows cached KafkaConsumers in the executors to be re-used to read the same - // partition in every batch. - val preferredLoc = if (numExecutors > 0) { - Some(sortedExecutors(Math.floorMod(tp.hashCode, numExecutors))) - } else None - val range = KafkaOffsetRange(tp, fromOffset, untilOffset) - Some( - new KafkaMicroBatchDataReaderFactory( - range, preferredLoc, executorKafkaParams, pollTimeoutMs, failOnDataLoss)) - } else { - reportDataLoss( - s"Partition $tp's offset was changed from " + - s"$fromOffset to $untilOffset, some data may have been missed") - None - } + val offsetRanges = rangeCalculator.getRanges( + fromOffsets = startPartitionOffsets ++ newPartitionInitialOffsets, + untilOffsets = endPartitionOffsets, + executorLocations = getSortedExecutorList()) + + // Reuse Kafka consumers only when all the offset ranges have distinct TopicPartitions, + // that is, concurrent tasks will not read the same TopicPartitions. + val reuseKafkaConsumer = offsetRanges.map(_.topicPartition).toSet.size == offsetRanges.size + + // Generate factories based on the offset ranges + val factories = offsetRanges.map { range => + new KafkaMicroBatchDataReaderFactory( + range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) } factories.map(_.asInstanceOf[DataReaderFactory[UnsafeRow]]).asJava } @@ -320,28 +300,39 @@ private[kafka010] class KafkaMicroBatchReader( } /** A [[DataReaderFactory]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] class KafkaMicroBatchDataReaderFactory( - range: KafkaOffsetRange, - preferredLoc: Option[String], +private[kafka010] case class KafkaMicroBatchDataReaderFactory( + offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends DataReaderFactory[UnsafeRow] { + failOnDataLoss: Boolean, + reuseKafkaConsumer: Boolean) extends DataReaderFactory[UnsafeRow] { - override def preferredLocations(): Array[String] = preferredLoc.toArray + override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray override def createDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader( - range, executorKafkaParams, pollTimeoutMs, failOnDataLoss) + offsetRange, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) } /** A [[DataReader]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] class KafkaMicroBatchDataReader( +private[kafka010] case class KafkaMicroBatchDataReader( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends DataReader[UnsafeRow] with Logging { + failOnDataLoss: Boolean, + reuseKafkaConsumer: Boolean) extends DataReader[UnsafeRow] with Logging { + + private val consumer = { + if (!reuseKafkaConsumer) { + // If we can't reuse CachedKafkaConsumers, creating a new CachedKafkaConsumer. We + // uses `assign` here, hence we don't need to worry about the "group.id" conflicts. + CachedKafkaConsumer.createUncached( + offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + } else { + CachedKafkaConsumer.getOrCreate( + offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + } + } - private val consumer = CachedKafkaConsumer.getOrCreate( - offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) private val rangeToRead = resolveRange(offsetRange) private val converter = new KafkaRecordToUnsafeRowConverter @@ -369,9 +360,14 @@ private[kafka010] class KafkaMicroBatchDataReader( } override def close(): Unit = { - // Indicate that we're no longer using this consumer - CachedKafkaConsumer.releaseKafkaConsumer( - offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + if (!reuseKafkaConsumer) { + // Don't forget to close non-reuse KafkaConsumers. You may take down your cluster! + consumer.close() + } else { + // Indicate that we're no longer using this consumer + CachedKafkaConsumer.releaseKafkaConsumer( + offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + } } private def resolveRange(range: KafkaOffsetRange): KafkaOffsetRange = { @@ -392,12 +388,9 @@ private[kafka010] class KafkaMicroBatchDataReader( } else { range.untilOffset } - KafkaOffsetRange(range.topicPartition, fromOffset, untilOffset) + KafkaOffsetRange(range.topicPartition, fromOffset, untilOffset, None) } else { range } } } - -private[kafka010] case class KafkaOffsetRange( - topicPartition: TopicPartition, fromOffset: Long, untilOffset: Long) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala new file mode 100644 index 0000000000000..6631ae84167c8 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.sql.sources.v2.DataSourceOptions + + +/** + * Class to calculate offset ranges to process based on the the from and until offsets, and + * the configured `minPartitions`. + */ +private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int]) { + require(minPartitions.isEmpty || minPartitions.get > 0) + + import KafkaOffsetRangeCalculator._ + /** + * Calculate the offset ranges that we are going to process this batch. If `minPartitions` + * is not set or is set less than or equal the number of `topicPartitions` that we're going to + * consume, then we fall back to a 1-1 mapping of Spark tasks to Kafka partitions. If + * `numPartitions` is set higher than the number of our `topicPartitions`, then we will split up + * the read tasks of the skewed partitions to multiple Spark tasks. + * The number of Spark tasks will be *approximately* `numPartitions`. It can be less or more + * depending on rounding errors or Kafka partitions that didn't receive any new data. + */ + def getRanges( + fromOffsets: PartitionOffsetMap, + untilOffsets: PartitionOffsetMap, + executorLocations: Seq[String] = Seq.empty): Seq[KafkaOffsetRange] = { + val partitionsToRead = untilOffsets.keySet.intersect(fromOffsets.keySet) + + val offsetRanges = partitionsToRead.toSeq.map { tp => + KafkaOffsetRange(tp, fromOffsets(tp), untilOffsets(tp), preferredLoc = None) + }.filter(_.size > 0) + + // If minPartitions not set or there are enough partitions to satisfy minPartitions + if (minPartitions.isEmpty || offsetRanges.size > minPartitions.get) { + // Assign preferred executor locations to each range such that the same topic-partition is + // preferentially read from the same executor and the KafkaConsumer can be reused. + offsetRanges.map { range => + range.copy(preferredLoc = getLocation(range.topicPartition, executorLocations)) + } + } else { + + // Splits offset ranges with relatively large amount of data to smaller ones. + val totalSize = offsetRanges.map(_.size).sum + val idealRangeSize = totalSize.toDouble / minPartitions.get + + offsetRanges.flatMap { range => + // Split the current range into subranges as close to the ideal range size + val numSplitsInRange = math.round(range.size.toDouble / idealRangeSize).toInt + + (0 until numSplitsInRange).map { i => + val splitStart = range.fromOffset + range.size * (i.toDouble / numSplitsInRange) + val splitEnd = range.fromOffset + range.size * ((i.toDouble + 1) / numSplitsInRange) + KafkaOffsetRange( + range.topicPartition, splitStart.toLong, splitEnd.toLong, preferredLoc = None) + } + } + } + } + + private def getLocation(tp: TopicPartition, executorLocations: Seq[String]): Option[String] = { + def floorMod(a: Long, b: Int): Int = ((a % b).toInt + b) % b + + val numExecutors = executorLocations.length + if (numExecutors > 0) { + // This allows cached KafkaConsumers in the executors to be re-used to read the same + // partition in every batch. + Some(executorLocations(floorMod(tp.hashCode, numExecutors))) + } else None + } +} + +private[kafka010] object KafkaOffsetRangeCalculator { + + def apply(options: DataSourceOptions): KafkaOffsetRangeCalculator = { + val optionalValue = Option(options.get("minPartitions").orElse(null)).map(_.toInt) + new KafkaOffsetRangeCalculator(optionalValue) + } +} + +private[kafka010] case class KafkaOffsetRange( + topicPartition: TopicPartition, + fromOffset: Long, + untilOffset: Long, + preferredLoc: Option[String]) { + lazy val size: Long = untilOffset - fromOffset +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 0aa64a6a9cf90..36b9f0466566b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -348,6 +348,12 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister throw new IllegalArgumentException("Unknown option") } + // Validate minPartitions value if present + if (caseInsensitiveParams.contains(MIN_PARTITIONS_OPTION_KEY)) { + val p = caseInsensitiveParams(MIN_PARTITIONS_OPTION_KEY).toInt + if (p <= 0) throw new IllegalArgumentException("minPartitions must be positive") + } + // Validate user-specified Kafka options if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) { @@ -455,6 +461,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { private[kafka010] val STARTING_OFFSETS_OPTION_KEY = "startingoffsets" private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets" private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss" + private val MIN_PARTITIONS_OPTION_KEY = "minpartitions" val TOPIC_OPTION_KEY = "topic" diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala new file mode 100644 index 0000000000000..43acd6a8d9473 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.kafka.common.TopicPartition + +package object kafka010 { // scalastyle:ignore + // ^^ scalastyle:ignore is for ignoring warnings about digits in package name + type PartitionOffsetMap = Map[TopicPartition, Long] +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 89c9ef4cc73b5..f2b3ff7615e74 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.kafka010 import java.io._ import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.{Files, Paths} -import java.util.{Locale, Properties} +import java.util.{Locale, Optional, Properties} import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.io.Source import scala.util.Random @@ -34,15 +35,19 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext -import org.apache.spark.sql.{Dataset, ForeachWriter} +import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} +import org.apache.spark.sql.types.StructType abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { @@ -642,6 +647,53 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { } ) } + + testWithUninterruptibleThread("minPartitions is supported") { + import testImplicits._ + + val topic = newTopic() + val tp = new TopicPartition(topic, 0) + testUtils.createTopic(topic, partitions = 1) + + def test( + minPartitions: String, + numPartitionsGenerated: Int, + reusesConsumers: Boolean): Unit = { + + SparkSession.setActiveSession(spark) + withTempDir { dir => + val provider = new KafkaSourceProvider() + val options = Map( + "kafka.bootstrap.servers" -> testUtils.brokerAddress, + "subscribe" -> topic + ) ++ Option(minPartitions).map { p => "minPartitions" -> p} + val reader = provider.createMicroBatchReader( + Optional.empty[StructType], dir.getAbsolutePath, new DataSourceOptions(options.asJava)) + reader.setOffsetRange( + Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))), + Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L))) + ) + val factories = reader.createUnsafeRowReaderFactories().asScala + .map(_.asInstanceOf[KafkaMicroBatchDataReaderFactory]) + withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") { + assert(factories.size == numPartitionsGenerated) + factories.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) } + } + } + } + + // Test cases when minPartitions is used and not used + test(minPartitions = null, numPartitionsGenerated = 1, reusesConsumers = true) + test(minPartitions = "1", numPartitionsGenerated = 1, reusesConsumers = true) + test(minPartitions = "4", numPartitionsGenerated = 4, reusesConsumers = false) + + // Test illegal minPartitions values + intercept[IllegalArgumentException] { test(minPartitions = "a", 1, true) } + intercept[IllegalArgumentException] { test(minPartitions = "1.0", 1, true) } + intercept[IllegalArgumentException] { test(minPartitions = "0", 1, true) } + intercept[IllegalArgumentException] { test(minPartitions = "-1", 1, true) } + } + } abstract class KafkaSourceSuiteBase extends KafkaSourceTest { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala new file mode 100644 index 0000000000000..2ccf3e291bea7 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import scala.collection.JavaConverters._ + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.sources.v2.DataSourceOptions + +class KafkaOffsetRangeCalculatorSuite extends SparkFunSuite { + + def testWithMinPartitions(name: String, minPartition: Int) + (f: KafkaOffsetRangeCalculator => Unit): Unit = { + val options = new DataSourceOptions(Map("minPartitions" -> minPartition.toString).asJava) + test(s"with minPartition = $minPartition: $name") { + f(KafkaOffsetRangeCalculator(options)) + } + } + + + test("with no minPartition: N TopicPartitions to N offset ranges") { + val calc = KafkaOffsetRangeCalculator(DataSourceOptions.empty()) + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 2)) == + Seq(KafkaOffsetRange(tp1, 1, 2, None))) + + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 2, tp2 -> 1), Seq.empty) == + Seq(KafkaOffsetRange(tp1, 1, 2, None))) + + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1), + untilOffsets = Map(tp1 -> 2)) == + Seq(KafkaOffsetRange(tp1, 1, 2, None))) + + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1), + untilOffsets = Map(tp1 -> 2), + executorLocations = Seq("location")) == + Seq(KafkaOffsetRange(tp1, 1, 2, Some("location")))) + } + + test("with no minPartition: empty ranges ignored") { + val calc = KafkaOffsetRangeCalculator(DataSourceOptions.empty()) + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1), + untilOffsets = Map(tp1 -> 2, tp2 -> 1)) == + Seq(KafkaOffsetRange(tp1, 1, 2, None))) + } + + testWithMinPartitions("N TopicPartitions to N offset ranges", 3) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1, tp3 -> 1), + untilOffsets = Map(tp1 -> 2, tp2 -> 2, tp3 -> 2)) == + Seq( + KafkaOffsetRange(tp1, 1, 2, None), + KafkaOffsetRange(tp2, 1, 2, None), + KafkaOffsetRange(tp3, 1, 2, None))) + } + + testWithMinPartitions("1 TopicPartition to N offset ranges", 4) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 5)) == + Seq( + KafkaOffsetRange(tp1, 1, 2, None), + KafkaOffsetRange(tp1, 2, 3, None), + KafkaOffsetRange(tp1, 3, 4, None), + KafkaOffsetRange(tp1, 4, 5, None))) + + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 5), + executorLocations = Seq("location")) == + Seq( + KafkaOffsetRange(tp1, 1, 2, None), + KafkaOffsetRange(tp1, 2, 3, None), + KafkaOffsetRange(tp1, 3, 4, None), + KafkaOffsetRange(tp1, 4, 5, None))) // location pref not set when minPartition is set + } + + testWithMinPartitions("N skewed TopicPartitions to M offset ranges", 3) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1), + untilOffsets = Map(tp1 -> 5, tp2 -> 21)) == + Seq( + KafkaOffsetRange(tp1, 1, 5, None), + KafkaOffsetRange(tp2, 1, 7, None), + KafkaOffsetRange(tp2, 7, 14, None), + KafkaOffsetRange(tp2, 14, 21, None))) + } + + testWithMinPartitions("range inexact multiple of minPartitions", 3) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 11)) == + Seq( + KafkaOffsetRange(tp1, 1, 4, None), + KafkaOffsetRange(tp1, 4, 7, None), + KafkaOffsetRange(tp1, 7, 11, None))) + } + + testWithMinPartitions("empty ranges ignored", 3) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1, tp3 -> 1), + untilOffsets = Map(tp1 -> 5, tp2 -> 21, tp3 -> 1)) == + Seq( + KafkaOffsetRange(tp1, 1, 5, None), + KafkaOffsetRange(tp2, 1, 7, None), + KafkaOffsetRange(tp2, 7, 14, None), + KafkaOffsetRange(tp2, 14, 21, None))) + } + + private val tp1 = new TopicPartition("t1", 1) + private val tp2 = new TopicPartition("t2", 1) + private val tp3 = new TopicPartition("t3", 1) +} From a89cdf55fa76fa23a524f0443e323498c3cc8664 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 5 Mar 2018 07:32:24 +0900 Subject: [PATCH 0420/2461] [SQL][MINOR] XPathDouble prettyPrint should say 'double' not 'float' ## What changes were proposed in this pull request? It looks like this was incorrectly copied from `XPathFloat` in the class above. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Eric Liang Closes #20730 from ericl/fix-typo-xpath. --- .../org/apache/spark/sql/catalyst/expressions/xml/xpath.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala index d0185562c9cfc..aacf1a44e2ad0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala @@ -160,7 +160,7 @@ case class XPathFloat(xml: Expression, path: Expression) extends XPathExtract { """) // scalastyle:on line.size.limit case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract { - override def prettyName: String = "xpath_float" + override def prettyName: String = "xpath_double" override def dataType: DataType = DoubleType override def nullSafeEval(xml: Any, path: Any): Any = { From 7965c91d8a67c213ca5eebda5e46e7c49a8ba121 Mon Sep 17 00:00:00 2001 From: "Michael (Stu) Stewart" Date: Mon, 5 Mar 2018 13:36:42 +0900 Subject: [PATCH 0421/2461] [SPARK-23569][PYTHON] Allow pandas_udf to work with python3 style type-annotated functions ## What changes were proposed in this pull request? Check python version to determine whether to use `inspect.getargspec` or `inspect.getfullargspec` before applying `pandas_udf` core logic to a function. The former is python2.7 (deprecated in python3) and the latter is python3.x. The latter correctly accounts for type annotations, which are syntax errors in python2.x. ## How was this patch tested? Locally, on python 2.7 and 3.6. Author: Michael (Stu) Stewart Closes #20728 from mstewart141/pandas_udf_fix. --- python/pyspark/sql/tests.py | 18 ++++++++++++++++++ python/pyspark/sql/udf.py | 9 ++++++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 19653072ea316..fa3b7203e10ac 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4381,6 +4381,24 @@ def test_timestamp_dst(self): result = df.withColumn('time', foo_udf(df.time)) self.assertEquals(df.collect(), result.collect()) + @unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.") + def test_type_annotation(self): + from pyspark.sql.functions import pandas_udf + # Regression test to check if type hints can be used. See SPARK-23569. + # Note that it throws an error during compilation in lower Python versions if 'exec' + # is not used. Also, note that we explicitly use another dictionary to avoid modifications + # in the current 'locals()'. + # + # Hyukjin: I think it's an ugly way to test issues about syntax specific in + # higher versions of Python, which we shouldn't encourage. This was the last resort + # I could come up with at that time. + _locals = {} + exec( + "import pandas as pd\ndef noop(col: pd.Series) -> pd.Series: return col", + _locals) + df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id')) + self.assertEqual(df.first()[0], 0) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index e5b35fc60e167..b9b490874f4fb 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -42,10 +42,17 @@ def _create_udf(f, returnType, evalType): PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): import inspect + import sys from pyspark.sql.utils import require_minimum_pyarrow_version require_minimum_pyarrow_version() - argspec = inspect.getargspec(f) + + if sys.version_info[0] < 3: + # `getargspec` is deprecated since python3.0 (incompatible with function annotations). + # See SPARK-23569. + argspec = inspect.getargspec(f) + else: + argspec = inspect.getfullargspec(f) if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \ argspec.varargs is None: From 269cd53590dd155aeb5269efc909a6e228f21e22 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 4 Mar 2018 21:22:30 -0800 Subject: [PATCH 0422/2461] [MINOR][DOCS] Fix a link in "Compatibility with Apache Hive" ## What changes were proposed in this pull request? This PR fixes a broken link as below: **Before:** 2018-03-05 12 23 58 **After:** 2018-03-05 12 23 20 Also see https://spark.apache.org/docs/2.3.0/sql-programming-guide.html#compatibility-with-apache-hive ## How was this patch tested? Manually tested. I checked the same instances in `docs` directory. Seems this is the only one. Author: hyukjinkwon Closes #20733 from HyukjinKwon/minor-link. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index c37c338a134f3..4d0f015f401bb 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -2223,7 +2223,7 @@ referencing a singleton. Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 2.1.1. Also see [Interacting with Different Versions of Hive Metastore] (#interacting-with-different-versions-of-hive-metastore)). +(from 0.12.0 to 2.1.1. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses From 2ce37b50fc01558f49ad22f89c8659f50544ffec Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 5 Mar 2018 11:39:01 +0100 Subject: [PATCH 0423/2461] [SPARK-23546][SQL] Refactor stateless methods/values in CodegenContext ## What changes were proposed in this pull request? A current `CodegenContext` class has immutable value or method without mutable state, too. This refactoring moves them to `CodeGenerator` object class which can be accessed from anywhere without an instantiated `CodegenContext` in the program. ## How was this patch tested? Existing tests Author: Kazuaki Ishizaki Closes #20700 from kiszk/SPARK-23546. --- .../catalyst/expressions/BoundAttribute.scala | 9 +- .../spark/sql/catalyst/expressions/Cast.scala | 35 +- .../sql/catalyst/expressions/Expression.scala | 16 +- .../MonotonicallyIncreasingID.scala | 8 +- .../sql/catalyst/expressions/ScalaUDF.scala | 7 +- .../expressions/SparkPartitionID.scala | 7 +- .../sql/catalyst/expressions/TimeWindow.scala | 4 +- .../sql/catalyst/expressions/arithmetic.scala | 51 +- .../expressions/bitwiseExpressions.scala | 2 +- .../expressions/codegen/CodeGenerator.scala | 458 +++++++++--------- .../expressions/codegen/CodegenFallback.scala | 7 +- .../codegen/GenerateMutableProjection.scala | 6 +- .../codegen/GenerateOrdering.scala | 4 +- .../codegen/GenerateSafeProjection.scala | 6 +- .../codegen/GenerateUnsafeProjection.scala | 11 +- .../expressions/collectionOperations.scala | 6 +- .../expressions/complexTypeCreator.scala | 4 +- .../expressions/complexTypeExtractors.scala | 15 +- .../expressions/conditionalExpressions.scala | 10 +- .../expressions/datetimeExpressions.scala | 18 +- .../spark/sql/catalyst/expressions/hash.scala | 25 +- .../catalyst/expressions/inputFileBlock.scala | 8 +- .../sql/catalyst/expressions/literals.scala | 8 +- .../expressions/mathExpressions.scala | 5 +- .../expressions/nullExpressions.scala | 22 +- .../expressions/objects/objects.scala | 99 ++-- .../sql/catalyst/expressions/predicates.scala | 14 +- .../expressions/randomExpressions.scala | 8 +- .../expressions/regexpExpressions.scala | 8 +- .../expressions/stringExpressions.scala | 39 +- .../expressions/CodeGenerationSuite.scala | 4 +- .../sql/execution/ColumnarBatchScan.scala | 13 +- .../spark/sql/execution/ExpandExec.scala | 5 +- .../spark/sql/execution/GenerateExec.scala | 8 +- .../apache/spark/sql/execution/SortExec.scala | 5 +- .../sql/execution/WholeStageCodegenExec.scala | 2 +- .../aggregate/HashAggregateExec.scala | 16 +- .../aggregate/HashMapGenerator.scala | 8 +- .../aggregate/RowBasedHashMapGenerator.scala | 8 +- .../VectorizedHashMapGenerator.scala | 11 +- .../execution/basicPhysicalOperators.scala | 10 +- .../columnar/GenerateColumnAccessor.scala | 2 +- .../joins/BroadcastHashJoinExec.scala | 5 +- .../execution/joins/SortMergeJoinExec.scala | 8 +- .../apache/spark/sql/execution/limit.scala | 7 +- 45 files changed, 535 insertions(+), 497 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 6a17a397b3ef2..89ffbb0016916 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types._ /** @@ -66,13 +66,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) ev.copy(code = oev.code) } else { assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.") - val javaType = ctx.javaType(dataType) - val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) + val javaType = CodeGenerator.javaType(dataType) + val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (nullable) { ev.copy(code = s""" |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); - |$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); + |$javaType ${ev.value} = ${ev.isNull} ? + | ${CodeGenerator.defaultValue(dataType)} : ($value); """.stripMargin) } else { ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 79b051670e9e4..12330bfa55ab9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -669,7 +669,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = { s""" boolean $resultIsNull = $inputIsNull; - ${ctx.javaType(resultType)} $result = ${ctx.defaultValue(resultType)}; + ${CodeGenerator.javaType(resultType)} $result = ${CodeGenerator.defaultValue(resultType)}; if (!$inputIsNull) { ${cast(input, result, resultIsNull)} } @@ -685,7 +685,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val funcName = ctx.freshName("elementToString") val elementToStringFunc = ctx.addNewFunction(funcName, s""" - |private UTF8String $funcName(${ctx.javaType(et)} element) { + |private UTF8String $funcName(${CodeGenerator.javaType(et)} element) { | UTF8String elementStr = null; | ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)} | return elementStr; @@ -697,13 +697,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$buffer.append("["); |if ($array.numElements() > 0) { | if (!$array.isNullAt(0)) { - | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")})); + | $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, "0")})); | } | for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) { | $buffer.append(","); | if (!$array.isNullAt($loopIndex)) { | $buffer.append(" "); - | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)})); + | $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, loopIndex)})); | } | } |} @@ -723,7 +723,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val dataToStringCode = castToStringCode(dataType, ctx) ctx.addNewFunction(funcName, s""" - |private UTF8String $funcName(${ctx.javaType(dataType)} data) { + |private UTF8String $funcName(${CodeGenerator.javaType(dataType)} data) { | UTF8String dataStr = null; | ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)} | return dataStr; @@ -734,23 +734,26 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val keyToStringFunc = dataToStringFunc("keyToString", kt) val valueToStringFunc = dataToStringFunc("valueToString", vt) val loopIndex = ctx.freshName("loopIndex") + val getMapFirstKey = CodeGenerator.getValue(s"$map.keyArray()", kt, "0") + val getMapFirstValue = CodeGenerator.getValue(s"$map.valueArray()", vt, "0") + val getMapKeyArray = CodeGenerator.getValue(s"$map.keyArray()", kt, loopIndex) + val getMapValueArray = CodeGenerator.getValue(s"$map.valueArray()", vt, loopIndex) s""" |$buffer.append("["); |if ($map.numElements() > 0) { - | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, "0")})); + | $buffer.append($keyToStringFunc($getMapFirstKey)); | $buffer.append(" ->"); | if (!$map.valueArray().isNullAt(0)) { | $buffer.append(" "); - | $buffer.append($valueToStringFunc(${ctx.getValue(s"$map.valueArray()", vt, "0")})); + | $buffer.append($valueToStringFunc($getMapFirstValue)); | } | for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) { | $buffer.append(", "); - | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, loopIndex)})); + | $buffer.append($keyToStringFunc($getMapKeyArray)); | $buffer.append(" ->"); | if (!$map.valueArray().isNullAt($loopIndex)) { | $buffer.append(" "); - | $buffer.append($valueToStringFunc( - | ${ctx.getValue(s"$map.valueArray()", vt, loopIndex)})); + | $buffer.append($valueToStringFunc($getMapValueArray)); | } | } |} @@ -773,7 +776,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String | ${if (i != 0) s"""$buffer.append(" ");""" else ""} | | // Append $i field into the string buffer - | ${ctx.javaType(ft)} $field = ${ctx.getValue(row, ft, s"$i")}; + | ${CodeGenerator.javaType(ft)} $field = ${CodeGenerator.getValue(row, ft, s"$i")}; | UTF8String $fieldStr = null; | ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)} | $buffer.append($fieldStr); @@ -1202,8 +1205,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $values[$j] = null; } else { boolean $fromElementNull = false; - ${ctx.javaType(fromType)} $fromElementPrim = - ${ctx.getValue(c, fromType, j)}; + ${CodeGenerator.javaType(fromType)} $fromElementPrim = + ${CodeGenerator.getValue(c, fromType, j)}; ${castCode(ctx, fromElementPrim, fromElementNull, toElementPrim, toElementNull, toType, elementCast)} if ($toElementNull) { @@ -1259,20 +1262,20 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val fromFieldNull = ctx.freshName("ffn") val toFieldPrim = ctx.freshName("tfp") val toFieldNull = ctx.freshName("tfn") - val fromType = ctx.javaType(from.fields(i).dataType) + val fromType = CodeGenerator.javaType(from.fields(i).dataType) s""" boolean $fromFieldNull = $tmpInput.isNullAt($i); if ($fromFieldNull) { $tmpResult.setNullAt($i); } else { $fromType $fromFieldPrim = - ${ctx.getValue(tmpInput, from.fields(i).dataType, i.toString)}; + ${CodeGenerator.getValue(tmpInput, from.fields(i).dataType, i.toString)}; ${castCode(ctx, fromFieldPrim, fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)} if ($toFieldNull) { $tmpResult.setNullAt($i); } else { - ${ctx.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)}; + ${CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)}; } } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4568714933095..ed90b185865a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -119,7 +119,7 @@ abstract class Expression extends TreeNode[Expression] { // TODO: support whole stage codegen too if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") { - val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull") + val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull eval.isNull = globalIsNull s"$globalIsNull = $localIsNull;" @@ -127,7 +127,7 @@ abstract class Expression extends TreeNode[Expression] { "" } - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val newValue = ctx.freshName("value") val funcName = ctx.freshName(nodeName) @@ -411,14 +411,14 @@ abstract class UnaryExpression extends Expression { ev.copy(code = s""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { ev.copy(code = s""" boolean ${ev.isNull} = false; ${childGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = "false") } } @@ -510,7 +510,7 @@ abstract class BinaryExpression extends Expression { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { @@ -518,7 +518,7 @@ abstract class BinaryExpression extends Expression { boolean ${ev.isNull} = false; ${leftGen.code} ${rightGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = "false") } } @@ -654,7 +654,7 @@ abstract class TernaryExpression extends Expression { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval""") } else { ev.copy(code = s""" @@ -662,7 +662,7 @@ abstract class TernaryExpression extends Expression { ${leftGen.code} ${midGen.code} ${rightGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 11fb579dfa88c..4523079060896 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types.{DataType, LongType} /** @@ -65,14 +65,14 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count") + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count") val partitionMaskTerm = "partitionMask" - ctx.addImmutableStateIfNotExists(ctx.JAVA_LONG, partitionMaskTerm) + ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_LONG, partitionMaskTerm) ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; + final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; $countTerm++;""", isNull = "false") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 989c02305620a..e869258469a97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -1018,11 +1018,12 @@ case class ScalaUDF( val udf = ctx.addReferenceObj("udf", function, s"scala.Function${children.length}") val getFuncResult = s"$udf.apply(${funcArgs.mkString(", ")})" val resultConverter = s"$convertersTerm[${children.length}]" + val boxedType = CodeGenerator.boxedType(dataType) val callFunc = s""" - |${ctx.boxedType(dataType)} $resultTerm = null; + |$boxedType $resultTerm = null; |try { - | $resultTerm = (${ctx.boxedType(dataType)})$resultConverter.apply($getFuncResult); + | $resultTerm = ($boxedType)$resultConverter.apply($getFuncResult); |} catch (Exception e) { | throw new org.apache.spark.SparkException($errorMsgTerm, e); |} @@ -1035,7 +1036,7 @@ case class ScalaUDF( |$callFunc | |boolean ${ev.isNull} = $resultTerm == null; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |if (!${ev.isNull}) { | ${ev.value} = $resultTerm; |} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index a160b9b275290..cc6a769d032d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -44,8 +44,9 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val idTerm = "partitionId" - ctx.addImmutableStateIfNotExists(ctx.JAVA_INT, idTerm) + ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false") + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", + isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 9a9f579b37f58..6c4a3601c1730 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -22,7 +22,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -165,7 +165,7 @@ case class PreciseTimestampConversion( val eval = child.genCode(ctx) ev.copy(code = eval.code + s"""boolean ${ev.isNull} = ${eval.isNull}; - |${ctx.javaType(dataType)} ${ev.value} = ${eval.value}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value}; """.stripMargin) } override def nullSafeEval(input: Any): Any = input diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 8bb14598a6d7b..508bdd5050b54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -49,8 +49,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression // codegen would fail to compile if we just write (-($c)) // for example, we could not write --9223372036854775808L in code s""" - ${ctx.javaType(dt)} $originValue = (${ctx.javaType(dt)})($eval); - ${ev.value} = (${ctx.javaType(dt)})(-($originValue)); + ${CodeGenerator.javaType(dt)} $originValue = (${CodeGenerator.javaType(dt)})($eval); + ${ev.value} = (${CodeGenerator.javaType(dt)})(-($originValue)); """}) case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") } @@ -107,7 +107,7 @@ case class Abs(child: Expression) case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.abs()") case dt: NumericType => - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(java.lang.Math.abs($c))") + defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dt)})(java.lang.Math.abs($c))") } protected override def nullSafeEval(input: Any): Any = numeric.abs(input) @@ -129,7 +129,7 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") case _ => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") } @@ -167,7 +167,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)") case ByteType | ShortType => defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") case CalendarIntervalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)") case _ => @@ -203,7 +203,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)") case ByteType | ShortType => defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") case CalendarIntervalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)") case _ => @@ -278,7 +278,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } else { s"${eval2.value} == 0" } - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val divide = if (dataType.isInstanceOf[DecimalType]) { s"${eval1.value}.$decimalMethod(${eval2.value})" } else { @@ -288,7 +288,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if ($isZero) { ${ev.isNull} = true; } else { @@ -299,7 +299,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { @@ -365,7 +365,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } else { s"${eval2.value} == 0" } - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val remainder = if (dataType.isInstanceOf[DecimalType]) { s"${eval1.value}.$decimalMethod(${eval2.value})" } else { @@ -375,7 +375,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if ($isZero) { ${ev.isNull} = true; } else { @@ -386,7 +386,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { @@ -454,13 +454,13 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { s"${eval2.value} == 0" } val remainder = ctx.freshName("remainder") - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val result = dataType match { case DecimalType.Fixed(_, _) => val decimalAdd = "$plus" s""" - ${ctx.javaType(dataType)} $remainder = ${eval1.value}.remainder(${eval2.value}); + $javaType $remainder = ${eval1.value}.remainder(${eval2.value}); if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) { ${ev.value}=($remainder.$decimalAdd(${eval2.value})).remainder(${eval2.value}); } else { @@ -470,17 +470,16 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => s""" - ${ctx.javaType(dataType)} $remainder = - (${ctx.javaType(dataType)})(${eval1.value} % ${eval2.value}); + $javaType $remainder = ($javaType)(${eval1.value} % ${eval2.value}); if ($remainder < 0) { - ${ev.value}=(${ctx.javaType(dataType)})(($remainder + ${eval2.value}) % ${eval2.value}); + ${ev.value}=($javaType)(($remainder + ${eval2.value}) % ${eval2.value}); } else { ${ev.value}=$remainder; } """ case _ => s""" - ${ctx.javaType(dataType)} $remainder = ${eval1.value} % ${eval2.value}; + $javaType $remainder = ${eval1.value} % ${eval2.value}; if ($remainder < 0) { ${ev.value}=($remainder + ${eval2.value}) % ${eval2.value}; } else { @@ -493,7 +492,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if ($isZero) { ${ev.isNull} = true; } else { @@ -504,7 +503,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { @@ -602,7 +601,7 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) val evals = evalChildren.map(eval => s""" |${eval.code} @@ -614,7 +613,7 @@ case class Least(children: Seq[Expression]) extends Expression { """.stripMargin ) - val resultType = ctx.javaType(dataType) + val resultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "least", @@ -629,7 +628,7 @@ case class Least(children: Seq[Expression]) extends Expression { ev.copy(code = s""" |${ev.isNull} = true; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes """.stripMargin) } @@ -681,7 +680,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) val evals = evalChildren.map(eval => s""" |${eval.code} @@ -693,7 +692,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { """.stripMargin ) - val resultType = ctx.javaType(dataType) + val resultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "greatest", @@ -708,7 +707,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { ev.copy(code = s""" |${ev.isNull} = true; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 173481f06a716..cc24e397cc14a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -147,7 +147,7 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)") + defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dataType)}) ~($c)") } protected override def nullSafeEval(input: Any): Any = not(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 60a6f50472504..793824b0b0a2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -59,6 +59,11 @@ import org.apache.spark.util.{ParentClassLoader, Utils} case class ExprCode(var code: String, var isNull: String, var value: String) object ExprCode { + def forNullValue(dataType: DataType): ExprCode = { + val defaultValueLiteral = CodeGenerator.defaultValue(dataType, typedNull = true) + ExprCode(code = "", isNull = "true", value = defaultValueLiteral) + } + def forNonNullValue(value: String): ExprCode = { ExprCode(code = "", isNull = "false", value = value) } @@ -105,6 +110,8 @@ private[codegen] case class NewFunctionSpec( */ class CodegenContext { + import CodeGenerator._ + /** * Holding a list of objects that could be used passed into generated class. */ @@ -196,11 +203,11 @@ class CodegenContext { /** * Returns the reference of next available slot in current compacted array. The size of each - * compacted array is controlled by the constant `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. + * compacted array is controlled by the constant `MUTABLESTATEARRAY_SIZE_LIMIT`. * Once reaching the threshold, new compacted array is created. */ def getNextSlot(): String = { - if (currentIndex < CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT) { + if (currentIndex < MUTABLESTATEARRAY_SIZE_LIMIT) { val res = s"${arrayNames.last}[$currentIndex]" currentIndex += 1 res @@ -247,10 +254,10 @@ class CodegenContext { * are satisfied: * 1. forceInline is true * 2. its type is primitive type and the total number of the inlined mutable variables - * is less than `CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD` + * is less than `OUTER_CLASS_VARIABLES_THRESHOLD` * 3. its type is multi-dimensional array * When a variable is compacted into an array, the max size of the array for compaction - * is given by `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. + * is given by `MUTABLESTATEARRAY_SIZE_LIMIT`. */ def addMutableState( javaType: String, @@ -261,7 +268,7 @@ class CodegenContext { // want to put a primitive type variable at outerClass for performance val canInlinePrimitive = isPrimitiveType(javaType) && - (inlinedMutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) + (inlinedMutableStates.length < OUTER_CLASS_VARIABLES_THRESHOLD) if (forceInline || canInlinePrimitive || javaType.contains("[][]")) { val varName = if (useFreshName) freshName(variableName) else variableName val initCode = initFunc(varName) @@ -339,7 +346,7 @@ class CodegenContext { val length = if (index + 1 == numArrays) { mutableStateArrays.getCurrentIndex } else { - CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + MUTABLESTATEARRAY_SIZE_LIMIT } if (javaType.contains("[]")) { // initializer had an one-dimensional array variable @@ -468,7 +475,7 @@ class CodegenContext { inlineToOuterClass: Boolean): NewFunctionSpec = { val (className, classInstance) = if (inlineToOuterClass) { outerClassName -> "" - } else if (currClassSize > CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD) { + } else if (currClassSize > GENERATED_CLASS_SIZE_THRESHOLD) { val className = freshName("NestedClass") val classInstance = freshName("nestedClassInstance") @@ -537,14 +544,6 @@ class CodegenContext { extraClasses.append(code) } - final val JAVA_BOOLEAN = "boolean" - final val JAVA_BYTE = "byte" - final val JAVA_SHORT = "short" - final val JAVA_INT = "int" - final val JAVA_LONG = "long" - final val JAVA_FLOAT = "float" - final val JAVA_DOUBLE = "double" - /** * The map from a variable name to it's next ID. */ @@ -580,196 +579,6 @@ class CodegenContext { } } - /** - * Returns the specialized code to access a value from `inputRow` at `ordinal`. - */ - def getValue(input: String, dataType: DataType, ordinal: String): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)" - case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})" - case StringType => s"$input.getUTF8String($ordinal)" - case BinaryType => s"$input.getBinary($ordinal)" - case CalendarIntervalType => s"$input.getInterval($ordinal)" - case t: StructType => s"$input.getStruct($ordinal, ${t.size})" - case _: ArrayType => s"$input.getArray($ordinal)" - case _: MapType => s"$input.getMap($ordinal)" - case NullType => "null" - case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal) - case _ => s"($jt)$input.get($ordinal, null)" - } - } - - /** - * Returns the code to update a column in Row for a given DataType. - */ - def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" - case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" - case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) - // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy - // it to avoid keeping a "pointer" to a memory region which may get updated afterwards. - case StringType | _: StructType | _: ArrayType | _: MapType => - s"$row.update($ordinal, $value.copy())" - case _ => s"$row.update($ordinal, $value)" - } - } - - /** - * Update a column in MutableRow from ExprCode. - * - * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise - */ - def updateColumn( - row: String, - dataType: DataType, - ordinal: Int, - ev: ExprCode, - nullable: Boolean, - isVectorized: Boolean = false): String = { - if (nullable) { - // Can't call setNullAt on DecimalType, because we need to keep the offset - if (!isVectorized && dataType.isInstanceOf[DecimalType]) { - s""" - if (!${ev.isNull}) { - ${setColumn(row, dataType, ordinal, ev.value)}; - } else { - ${setColumn(row, dataType, ordinal, "null")}; - } - """ - } else { - s""" - if (!${ev.isNull}) { - ${setColumn(row, dataType, ordinal, ev.value)}; - } else { - $row.setNullAt($ordinal); - } - """ - } - } else { - s"""${setColumn(row, dataType, ordinal, ev.value)};""" - } - } - - /** - * Returns the specialized code to set a given value in a column vector for a given `DataType`. - */ - def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => - s"$vector.put${primitiveTypeName(jt)}($rowId, $value);" - case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});" - case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());" - case _ => - throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType") - } - } - - /** - * Returns the specialized code to set a given value in a column vector for a given `DataType` - * that could potentially be nullable. - */ - def updateColumn( - vector: String, - rowId: String, - dataType: DataType, - ev: ExprCode, - nullable: Boolean): String = { - if (nullable) { - s""" - if (!${ev.isNull}) { - ${setValue(vector, rowId, dataType, ev.value)} - } else { - $vector.putNull($rowId); - } - """ - } else { - s"""${setValue(vector, rowId, dataType, ev.value)};""" - } - } - - /** - * Returns the specialized code to access a value from a column vector for a given `DataType`. - */ - def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = { - if (dataType.isInstanceOf[StructType]) { - // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an - // `ordinal` parameter. - s"$vector.getStruct($rowId)" - } else { - getValue(vector, dataType, rowId) - } - } - - /** - * Returns the name used in accessor and setter for a Java primitive type. - */ - def primitiveTypeName(jt: String): String = jt match { - case JAVA_INT => "Int" - case _ => boxedType(jt) - } - - def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt)) - - /** - * Returns the Java type for a DataType. - */ - def javaType(dt: DataType): String = dt match { - case BooleanType => JAVA_BOOLEAN - case ByteType => JAVA_BYTE - case ShortType => JAVA_SHORT - case IntegerType | DateType => JAVA_INT - case LongType | TimestampType => JAVA_LONG - case FloatType => JAVA_FLOAT - case DoubleType => JAVA_DOUBLE - case dt: DecimalType => "Decimal" - case BinaryType => "byte[]" - case StringType => "UTF8String" - case CalendarIntervalType => "CalendarInterval" - case _: StructType => "InternalRow" - case _: ArrayType => "ArrayData" - case _: MapType => "MapData" - case udt: UserDefinedType[_] => javaType(udt.sqlType) - case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" - case ObjectType(cls) => cls.getName - case _ => "Object" - } - - /** - * Returns the boxed type in Java. - */ - def boxedType(jt: String): String = jt match { - case JAVA_BOOLEAN => "Boolean" - case JAVA_BYTE => "Byte" - case JAVA_SHORT => "Short" - case JAVA_INT => "Integer" - case JAVA_LONG => "Long" - case JAVA_FLOAT => "Float" - case JAVA_DOUBLE => "Double" - case other => other - } - - def boxedType(dt: DataType): String = boxedType(javaType(dt)) - - /** - * Returns the representation of default value for a given Java Type. - */ - def defaultValue(jt: String): String = jt match { - case JAVA_BOOLEAN => "false" - case JAVA_BYTE => "(byte)-1" - case JAVA_SHORT => "(short)-1" - case JAVA_INT => "-1" - case JAVA_LONG => "-1L" - case JAVA_FLOAT => "-1.0f" - case JAVA_DOUBLE => "-1.0" - case _ => "null" - } - - def defaultValue(dt: DataType): String = defaultValue(javaType(dt)) - /** * Generates code for equal expression in Java. */ @@ -812,6 +621,7 @@ class CodegenContext { val isNullB = freshName("isNullB") val compareFunc = freshName("compareArray") val minLength = freshName("minLength") + val jt = javaType(elementType) val funcCode: String = s""" public int $compareFunc(ArrayData a, ArrayData b) { @@ -833,8 +643,8 @@ class CodegenContext { } else if ($isNullB) { return 1; } else { - ${javaType(elementType)} $elementA = ${getValue("a", elementType, "i")}; - ${javaType(elementType)} $elementB = ${getValue("b", elementType, "i")}; + $jt $elementA = ${getValue("a", elementType, "i")}; + $jt $elementB = ${getValue("b", elementType, "i")}; int comp = ${genComp(elementType, elementA, elementB)}; if (comp != 0) { return comp; @@ -906,19 +716,6 @@ class CodegenContext { } } - /** - * List of java data types that have special accessors and setters in [[InternalRow]]. - */ - val primitiveTypes = - Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE) - - /** - * Returns true if the Java type has a special accessor and setter in [[InternalRow]]. - */ - def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt) - - def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) - /** * Splits the generated code of expressions into multiple functions, because function has * 64kb code size limit in JVM. If the class to which the function would be inlined would grow @@ -1089,7 +886,7 @@ class CodegenContext { // for performance reasons, the functions are prepended, instead of appended, // thus here they are in reversed order val orderedFunctions = innerClassFunctions.reverse - if (orderedFunctions.size > CodeGenerator.MERGE_SPLIT_METHODS_THRESHOLD) { + if (orderedFunctions.size > MERGE_SPLIT_METHODS_THRESHOLD) { // Adding a new function to each inner class which contains the invocation of all the // ones which have been added to that inner class. For example, // private class NestedClass { @@ -1289,7 +1086,7 @@ class CodegenContext { * length less than a pre-defined constant. */ def isValidParamLength(paramLength: Int): Boolean = { - paramLength <= CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH + paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH } } @@ -1524,4 +1321,221 @@ object CodeGenerator extends Logging { result } }) + + /** + * Name of Java primitive data type + */ + final val JAVA_BOOLEAN = "boolean" + final val JAVA_BYTE = "byte" + final val JAVA_SHORT = "short" + final val JAVA_INT = "int" + final val JAVA_LONG = "long" + final val JAVA_FLOAT = "float" + final val JAVA_DOUBLE = "double" + + /** + * List of java primitive data types + */ + val primitiveTypes = + Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE) + + /** + * Returns true if a Java type is Java primitive primitive type + */ + def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt) + + def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) + + /** + * Returns the specialized code to access a value from `inputRow` at `ordinal`. + */ + def getValue(input: String, dataType: DataType, ordinal: String): String = { + val jt = javaType(dataType) + dataType match { + case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)" + case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})" + case StringType => s"$input.getUTF8String($ordinal)" + case BinaryType => s"$input.getBinary($ordinal)" + case CalendarIntervalType => s"$input.getInterval($ordinal)" + case t: StructType => s"$input.getStruct($ordinal, ${t.size})" + case _: ArrayType => s"$input.getArray($ordinal)" + case _: MapType => s"$input.getMap($ordinal)" + case NullType => "null" + case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal) + case _ => s"($jt)$input.get($ordinal, null)" + } + } + + /** + * Returns the code to update a column in Row for a given DataType. + */ + def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { + val jt = javaType(dataType) + dataType match { + case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" + case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" + case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) + // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy + // it to avoid keeping a "pointer" to a memory region which may get updated afterwards. + case StringType | _: StructType | _: ArrayType | _: MapType => + s"$row.update($ordinal, $value.copy())" + case _ => s"$row.update($ordinal, $value)" + } + } + + /** + * Update a column in MutableRow from ExprCode. + * + * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise + */ + def updateColumn( + row: String, + dataType: DataType, + ordinal: Int, + ev: ExprCode, + nullable: Boolean, + isVectorized: Boolean = false): String = { + if (nullable) { + // Can't call setNullAt on DecimalType, because we need to keep the offset + if (!isVectorized && dataType.isInstanceOf[DecimalType]) { + s""" + |if (!${ev.isNull}) { + | ${setColumn(row, dataType, ordinal, ev.value)}; + |} else { + | ${setColumn(row, dataType, ordinal, "null")}; + |} + """.stripMargin + } else { + s""" + |if (!${ev.isNull}) { + | ${setColumn(row, dataType, ordinal, ev.value)}; + |} else { + | $row.setNullAt($ordinal); + |} + """.stripMargin + } + } else { + s"""${setColumn(row, dataType, ordinal, ev.value)};""" + } + } + + /** + * Returns the specialized code to set a given value in a column vector for a given `DataType`. + */ + def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = { + val jt = javaType(dataType) + dataType match { + case _ if isPrimitiveType(jt) => + s"$vector.put${primitiveTypeName(jt)}($rowId, $value);" + case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});" + case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());" + case _ => + throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType") + } + } + + /** + * Returns the specialized code to set a given value in a column vector for a given `DataType` + * that could potentially be nullable. + */ + def updateColumn( + vector: String, + rowId: String, + dataType: DataType, + ev: ExprCode, + nullable: Boolean): String = { + if (nullable) { + s""" + |if (!${ev.isNull}) { + | ${setValue(vector, rowId, dataType, ev.value)} + |} else { + | $vector.putNull($rowId); + |} + """.stripMargin + } else { + s"""${setValue(vector, rowId, dataType, ev.value)};""" + } + } + + /** + * Returns the specialized code to access a value from a column vector for a given `DataType`. + */ + def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = { + if (dataType.isInstanceOf[StructType]) { + // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an + // `ordinal` parameter. + s"$vector.getStruct($rowId)" + } else { + getValue(vector, dataType, rowId) + } + } + + /** + * Returns the name used in accessor and setter for a Java primitive type. + */ + def primitiveTypeName(jt: String): String = jt match { + case JAVA_INT => "Int" + case _ => boxedType(jt) + } + + def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt)) + + /** + * Returns the Java type for a DataType. + */ + def javaType(dt: DataType): String = dt match { + case BooleanType => JAVA_BOOLEAN + case ByteType => JAVA_BYTE + case ShortType => JAVA_SHORT + case IntegerType | DateType => JAVA_INT + case LongType | TimestampType => JAVA_LONG + case FloatType => JAVA_FLOAT + case DoubleType => JAVA_DOUBLE + case _: DecimalType => "Decimal" + case BinaryType => "byte[]" + case StringType => "UTF8String" + case CalendarIntervalType => "CalendarInterval" + case _: StructType => "InternalRow" + case _: ArrayType => "ArrayData" + case _: MapType => "MapData" + case udt: UserDefinedType[_] => javaType(udt.sqlType) + case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" + case ObjectType(cls) => cls.getName + case _ => "Object" + } + + /** + * Returns the boxed type in Java. + */ + def boxedType(jt: String): String = jt match { + case JAVA_BOOLEAN => "Boolean" + case JAVA_BYTE => "Byte" + case JAVA_SHORT => "Short" + case JAVA_INT => "Integer" + case JAVA_LONG => "Long" + case JAVA_FLOAT => "Float" + case JAVA_DOUBLE => "Double" + case other => other + } + + def boxedType(dt: DataType): String = boxedType(javaType(dt)) + + /** + * Returns the representation of default value for a given Java Type. + * @param jt the string name of the Java type + * @param typedNull if true, for null literals, return a typed (with a cast) version + */ + def defaultValue(jt: String, typedNull: Boolean): String = jt match { + case JAVA_BOOLEAN => "false" + case JAVA_BYTE => "(byte)-1" + case JAVA_SHORT => "(short)-1" + case JAVA_INT => "-1" + case JAVA_LONG => "-1L" + case JAVA_FLOAT => "-1.0f" + case JAVA_DOUBLE => "-1.0" + case _ => if (typedNull) s"(($jt)null)" else "null" + } + + def defaultValue(dt: DataType, typedNull: Boolean = false): String = + defaultValue(javaType(dt), typedNull) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 0322d1dd6a9ff..e12420bb5dfdd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -44,20 +44,21 @@ trait CodegenFallback extends Expression { } val objectTerm = ctx.freshName("obj") val placeHolder = ctx.registerComment(this.toString) + val javaType = CodeGenerator.javaType(this.dataType) if (nullable) { ev.copy(code = s""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); boolean ${ev.isNull} = $objectTerm == null; - ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(this.dataType)}; if (!${ev.isNull}) { - ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; + ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; }""") } else { ev.copy(code = s""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); - ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; + $javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; """, isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index b53c0087e7e2d..d35fd8ecb4d63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -62,9 +62,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map { case (ev, i) => val e = expressions(i) - val value = ctx.addMutableState(ctx.javaType(e.dataType), "value") + val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value") if (e.nullable) { - val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "isNull") + val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "isNull") (s""" |${ev.code} |$isNull = ${ev.isNull}; @@ -84,7 +84,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val updates = validExpr.zip(projectionCodes).map { case (e, (_, isNull, value, i)) => val ev = ExprCode("", isNull, value) - ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) + CodeGenerator.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 4a459571ed634..9a51be6ed5aeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -89,7 +89,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR s""" ${ctx.INPUT_ROW} = a; boolean $isNullA; - ${ctx.javaType(order.child.dataType)} $primitiveA; + ${CodeGenerator.javaType(order.child.dataType)} $primitiveA; { ${eval.code} $isNullA = ${eval.isNull}; @@ -97,7 +97,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR } ${ctx.INPUT_ROW} = b; boolean $isNullB; - ${ctx.javaType(order.child.dataType)} $primitiveB; + ${CodeGenerator.javaType(order.child.dataType)} $primitiveB; { ${eval.code} $isNullB = ${eval.isNull}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 3dcbb518ba42a..f92f70ee71fef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -53,7 +53,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val rowClass = classOf[GenericInternalRow].getName val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => - val converter = convertToSafe(ctx, ctx.getValue(tmpInput, dt, i.toString), dt) + val converter = convertToSafe(ctx, CodeGenerator.getValue(tmpInput, dt, i.toString), dt) s""" if (!$tmpInput.isNullAt($i)) { ${converter.code} @@ -90,7 +90,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val arrayClass = classOf[GenericArrayData].getName val elementConverter = convertToSafe( - ctx, ctx.getValue(tmpInput, elementType, index), elementType) + ctx, CodeGenerator.getValue(tmpInput, elementType, index), elementType) val code = s""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); @@ -153,7 +153,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] mutableRow.setNullAt($i); } else { ${converter.code} - ${ctx.setColumn("mutableRow", e.dataType, i, converter.value)}; + ${CodeGenerator.setColumn("mutableRow", e.dataType, i, converter.value)}; } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 36ffa8dcdd2b6..22717f5954a45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -52,7 +52,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - ExprCode("", s"$tmpInput.isNullAt($i)", ctx.getValue(tmpInput, dt, i.toString)) + ExprCode("", s"$tmpInput.isNullAt($i)", CodeGenerator.getValue(tmpInput, dt, i.toString)) } s""" @@ -195,16 +195,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case other => other } - val jt = ctx.javaType(et) + val jt = CodeGenerator.javaType(et) val elementOrOffsetSize = et match { case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8 - case _ if ctx.isPrimitiveType(jt) => et.defaultSize + case _ if CodeGenerator.isPrimitiveType(jt) => et.defaultSize case _ => 8 // we need 8 bytes to store offset and length } val tmpCursor = ctx.freshName("tmpCursor") - val element = ctx.getValue(tmpInput, et, index) + val element = CodeGenerator.getValue(tmpInput, et, index) val writeElement = et match { case t: StructType => s""" @@ -235,7 +235,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => s"$arrayWriter.write($index, $element);" } - val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else "" + val primitiveTypeName = + if (CodeGenerator.isPrimitiveType(jt)) CodeGenerator.primitiveTypeName(et) else "" s""" final ArrayData $tmpInput = $input; if ($tmpInput instanceof UnsafeArrayData) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4270b987d6de0..beb84694c44e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -20,7 +20,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -54,7 +54,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType ev.copy(code = s""" boolean ${ev.isNull} = false; ${childGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : (${childGen.value}).numElements();""", isNull = "false") } } @@ -270,7 +270,7 @@ case class ArrayContains(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (arr, value) => { val i = ctx.freshName("i") - val getValue = ctx.getValue(arr, right.dataType, i) + val getValue = CodeGenerator.getValue(arr, right.dataType, i) s""" for (int $i = 0; $i < $arr.numElements(); $i ++) { if ($arr.isNullAt($i)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 047b80ac5289c..85facdad43db7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -90,7 +90,7 @@ private [sql] object GenArrayData { val arrayDataName = ctx.freshName("arrayData") val numElements = elementsCode.length - if (!ctx.isPrimitiveType(elementType)) { + if (!CodeGenerator.isPrimitiveType(elementType)) { val arrayName = ctx.freshName("arrayObject") val genericArrayClass = classOf[GenericArrayData].getName @@ -124,7 +124,7 @@ private [sql] object GenArrayData { ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) val baseOffset = Platform.BYTE_ARRAY_OFFSET - val primitiveValueTypeName = ctx.primitiveTypeName(elementType) + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (!isMapKey) { s"$arrayDataName.setNullAt($i);" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 7e53ca3908905..6cdad19168dce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -129,12 +129,12 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] if ($eval.isNullAt($ordinal)) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + ${ev.value} = ${CodeGenerator.getValue(eval, dataType, ordinal.toString)}; } """ } else { s""" - ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + ${ev.value} = ${CodeGenerator.getValue(eval, dataType, ordinal.toString)}; """ } }) @@ -205,7 +205,7 @@ case class GetArrayStructFields( } else { final InternalRow $row = $eval.getStruct($j, $numFields); $nullSafeEval { - $values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)}; + $values[$j] = ${CodeGenerator.getValue(row, field.dataType, ordinal.toString)}; } } } @@ -260,7 +260,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) if ($index >= $eval1.numElements() || $index < 0$nullCheck) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(eval1, dataType, index)}; + ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; } """ }) @@ -327,6 +327,7 @@ case class GetMapValue(child: Expression, key: Expression) } else { "" } + val keyJavaType = CodeGenerator.javaType(keyType) nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" final int $length = $eval1.numElements(); @@ -336,7 +337,7 @@ case class GetMapValue(child: Expression, key: Expression) int $index = 0; boolean $found = false; while ($index < $length && !$found) { - final ${ctx.javaType(keyType)} $key = ${ctx.getValue(keys, keyType, index)}; + final $keyJavaType $key = ${CodeGenerator.getValue(keys, keyType, index)}; if (${ctx.genEqual(keyType, key, eval2)}) { $found = true; } else { @@ -347,7 +348,7 @@ case class GetMapValue(child: Expression, key: Expression) if (!$found$nullCheck) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(values, dataType, index)}; + ${ev.value} = ${CodeGenerator.getValue(values, dataType, index)}; } """ }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index b444c3a7be92a..f4e9619bac59d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -69,7 +69,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi s""" |${condEval.code} |boolean ${ev.isNull} = false; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |if (!${condEval.isNull} && ${condEval.value}) { | ${trueEval.code} | ${ev.isNull} = ${trueEval.isNull}; @@ -191,7 +191,7 @@ case class CaseWhen( // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, // We won't go on anymore on the computation. val resultState = ctx.freshName("caseWhenResultState") - ev.value = ctx.addMutableState(ctx.javaType(dataType), ev.value) + ev.value = ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value) // these blocks are meant to be inside a // do { @@ -244,10 +244,10 @@ case class CaseWhen( val codes = ctx.splitExpressionsWithCurrentInputs( expressions = allConditions, funcName = "caseWhen", - returnType = ctx.JAVA_BYTE, + returnType = CodeGenerator.JAVA_BYTE, makeSplitFunction = func => s""" - |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; + |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $func |} while (false); @@ -264,7 +264,7 @@ case class CaseWhen( ev.copy(code = s""" - |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; + |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $codes |} while (false); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 424871f2047e9..1ae4e5a2f716b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -26,7 +26,7 @@ import scala.util.control.NonFatal import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -673,18 +673,19 @@ abstract class UnixTime } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val javaType = CodeGenerator.javaType(dataType) left.dataType match { case StringType if right.foldable => val df = classOf[DateFormat].getName if (formatter == null) { - ExprCode("", "true", ctx.defaultValue(dataType)) + ExprCode.forNullValue(dataType) } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val eval1 = left.genCode(ctx) ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { try { ${ev.value} = $formatterName.parse(${eval1.value}.toString()).getTime() / 1000L; @@ -713,7 +714,7 @@ abstract class UnixTime ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = ${eval1.value} / 1000000L; }""") @@ -724,7 +725,7 @@ abstract class UnixTime ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $dtu.daysToMillis(${eval1.value}, $tz) / 1000L; }""") @@ -819,7 +820,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ ev.copy(code = s""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { try { ${ev.value} = UTF8String.fromString($formatterName.format( @@ -1344,18 +1345,19 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes { : ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val javaType = CodeGenerator.javaType(dataType) if (format.foldable) { if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { val t = instant.genCode(ctx) val truncFuncStr = truncFunc(t.value, truncLevel.toString) ev.copy(code = s""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $dtu.$truncFuncStr; }""") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 055ebf6c0da54..b702422ed7a1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -278,7 +278,7 @@ abstract class HashExpression[E] extends Expression { } } - val hashResultType = ctx.javaType(dataType) + val hashResultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = childrenHash, funcName = "computeHash", @@ -307,9 +307,10 @@ abstract class HashExpression[E] extends Expression { ctx: CodegenContext): String = { val element = ctx.freshName("element") + val jt = CodeGenerator.javaType(elementType) ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") { s""" - final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)}; + final $jt $element = ${CodeGenerator.getValue(input, elementType, index)}; ${computeHash(element, elementType, result, ctx)} """ } @@ -407,7 +408,7 @@ abstract class HashExpression[E] extends Expression { val fieldsHash = fields.zipWithIndex.map { case (field, index) => nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) } - val hashResultType = ctx.javaType(dataType) + val hashResultType = CodeGenerator.javaType(dataType) ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", @@ -651,11 +652,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = childrenHash, funcName = "computeHash", - extraArguments = Seq(ctx.JAVA_INT -> ev.value), - returnType = ctx.JAVA_INT, + extraArguments = Seq(CodeGenerator.JAVA_INT -> ev.value), + returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" - |${ctx.JAVA_INT} $childHash = 0; + |${CodeGenerator.JAVA_INT} $childHash = 0; |$body |return ${ev.value}; """.stripMargin, @@ -664,8 +665,8 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { ev.copy(code = s""" - |${ctx.JAVA_INT} ${ev.value} = $seed; - |${ctx.JAVA_INT} $childHash = 0; + |${CodeGenerator.JAVA_INT} ${ev.value} = $seed; + |${CodeGenerator.JAVA_INT} $childHash = 0; |$codes """.stripMargin) } @@ -780,14 +781,14 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """.stripMargin } - s"${ctx.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions( + s"${CodeGenerator.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", - arguments = Seq("InternalRow" -> input, ctx.JAVA_INT -> result), - returnType = ctx.JAVA_INT, + arguments = Seq("InternalRow" -> input, CodeGenerator.JAVA_INT -> result), + returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" - |${ctx.JAVA_INT} $childResult = 0; + |${CodeGenerator.JAVA_INT} $childResult = 0; |$body |return $result; """.stripMargin, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 7a8edabed1757..07785e7448586 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -42,7 +42,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + s"$className.getInputFilePath();", isNull = "false") } } @@ -65,7 +65,7 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + s"$className.getStartOffset();", isNull = "false") } } @@ -88,7 +88,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + s"$className.getLength();", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index c1e65e34c2ea6..7395609a04ba5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -277,13 +277,9 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { override def eval(input: InternalRow): Any = value override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) if (value == null) { - val defaultValueLiteral = ctx.defaultValue(javaType) match { - case "null" => s"(($javaType)null)" - case lit => lit - } - ExprCode(code = "", isNull = "true", value = defaultValueLiteral) + ExprCode.forNullValue(dataType) } else { dataType match { case BooleanType | IntegerType | DateType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index d8dc0862f1141..2c2cf3d2e6227 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1128,15 +1128,16 @@ abstract class RoundBase(child: Expression, scale: Expression, }""" } + val javaType = CodeGenerator.javaType(dataType) if (scaleV == null) { // if scale is null, no need to eval its child at all ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { ev.copy(code = s""" ${ce.code} boolean ${ev.isNull} = ${ce.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { $evaluationCode }""") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 470d5da041ea5..b35fa72e95d1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -72,7 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => @@ -87,14 +87,14 @@ case class Coalesce(children: Seq[Expression]) extends Expression { """.stripMargin } - val resultType = ctx.javaType(dataType) + val resultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "coalesce", returnType = resultType, makeSplitFunction = func => s""" - |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |do { | $func |} while (false); @@ -113,7 +113,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = s""" |${ev.isNull} = true; - |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |do { | $codes |} while (false); @@ -234,7 +234,7 @@ case class IsNaN(child: Expression) extends UnaryExpression case DoubleType | FloatType => ev.copy(code = s""" ${eval.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = "false") } } @@ -281,7 +281,7 @@ case class NaNvl(left: Expression, right: Expression) ev.copy(code = s""" ${leftGen.code} boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${leftGen.isNull}) { ${ev.isNull} = true; } else { @@ -416,8 +416,8 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "atLeastNNonNulls", - extraArguments = (ctx.JAVA_INT, nonnull) :: Nil, - returnType = ctx.JAVA_INT, + extraArguments = (CodeGenerator.JAVA_INT, nonnull) :: Nil, + returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" |do { @@ -436,11 +436,11 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate ev.copy(code = s""" - |${ctx.JAVA_INT} $nonnull = 0; + |${CodeGenerator.JAVA_INT} $nonnull = 0; |do { | $codes |} while (false); - |${ctx.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; + |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; """.stripMargin, isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 64da9bb9cdec1..80618af1e859f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ @@ -62,13 +62,13 @@ trait InvokeLike extends Expression with NonSQLExpression { def prepareArguments(ctx: CodegenContext): (String, String, String) = { val resultIsNull = if (needNullCheck) { - val resultIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "resultIsNull") + val resultIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "resultIsNull") resultIsNull } else { "false" } val argValues = arguments.map { e => - val argValue = ctx.addMutableState(ctx.javaType(e.dataType), "argValue") + val argValue = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "argValue") argValue } @@ -137,7 +137,7 @@ case class StaticInvoke( throw new UnsupportedOperationException("Only code-generated evaluation is supported.") override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val (argCode, argString, resultIsNull) = prepareArguments(ctx) @@ -151,7 +151,7 @@ case class StaticInvoke( } val evaluate = if (returnNullable) { - if (ctx.defaultValue(dataType) == "null") { + if (CodeGenerator.defaultValue(dataType) == "null") { s""" ${ev.value} = $callFunc; ${ev.isNull} = ${ev.value} == null; @@ -159,7 +159,7 @@ case class StaticInvoke( } else { val boxedResult = ctx.freshName("boxedResult") s""" - ${ctx.boxedType(dataType)} $boxedResult = $callFunc; + ${CodeGenerator.boxedType(dataType)} $boxedResult = $callFunc; ${ev.isNull} = $boxedResult == null; if (!${ev.isNull}) { ${ev.value} = $boxedResult; @@ -173,7 +173,7 @@ case class StaticInvoke( val code = s""" $argCode $prepareIsNull - $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!$resultIsNull) { $evaluate } @@ -228,7 +228,7 @@ case class Invoke( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val obj = targetObject.genCode(ctx) val (argCode, argString, resultIsNull) = prepareArguments(ctx) @@ -255,11 +255,11 @@ case class Invoke( // If the function can return null, we do an extra check to make sure our null bit is still // set correctly. val assignResult = if (!returnNullable) { - s"${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;" + s"${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult;" } else { s""" if ($funcResult != null) { - ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; + ${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult; } else { ${ev.isNull} = true; } @@ -275,7 +275,7 @@ case class Invoke( val code = s""" ${obj.code} boolean ${ev.isNull} = true; - $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${obj.isNull}) { $argCode ${ev.isNull} = $resultIsNull; @@ -341,7 +341,7 @@ case class NewInstance( throw new UnsupportedOperationException("Only code-generated evaluation is supported.") override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val (argCode, argString, resultIsNull) = prepareArguments(ctx) @@ -358,7 +358,8 @@ case class NewInstance( val code = s""" $argCode ${outer.map(_.code).getOrElse("")} - final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $constructorCall; + final $javaType ${ev.value} = ${ev.isNull} ? + ${CodeGenerator.defaultValue(dataType)} : $constructorCall; """ ev.copy(code = code) } @@ -385,15 +386,15 @@ case class UnwrapOption( throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val inputObject = child.genCode(ctx) val code = s""" ${inputObject.code} final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty(); - $javaType ${ev.value} = ${ev.isNull} ? - ${ctx.defaultValue(javaType)} : (${ctx.boxedType(javaType)}) ${inputObject.value}.get(); + $javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} : + (${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get(); """ ev.copy(code = code) } @@ -546,7 +547,7 @@ case class MapObjects private( ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val elementJavaType = ctx.javaType(loopVarDataType) + val elementJavaType = CodeGenerator.javaType(loopVarDataType) ctx.addMutableState(elementJavaType, loopValue, forceInline = true, useFreshName = false) val genInputData = inputData.genCode(ctx) val genFunction = lambdaFunction.genCode(ctx) @@ -554,7 +555,7 @@ case class MapObjects private( val convertedArray = ctx.freshName("convertedArray") val loopIndex = ctx.freshName("loopIndex") - val convertedType = ctx.boxedType(lambdaFunction.dataType) + val convertedType = CodeGenerator.boxedType(lambdaFunction.dataType) // Because of the way Java defines nested arrays, we have to handle the syntax specially. // Specifically, we have to insert the [$dataLength] in between the type and any extra nested @@ -621,7 +622,7 @@ case class MapObjects private( ( s"${genInputData.value}.numElements()", "", - ctx.getValue(genInputData.value, et, loopIndex) + CodeGenerator.getValue(genInputData.value, et, loopIndex) ) case ObjectType(cls) if cls == classOf[Object] => val it = ctx.freshName("it") @@ -643,7 +644,8 @@ case class MapObjects private( } val loopNullCheck = if (loopIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false) + ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false) inputDataType match { case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" case _ => s"$loopIsNull = $loopValue == null;" @@ -695,7 +697,7 @@ case class MapObjects private( val code = s""" ${genInputData.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { $determineCollectionType @@ -806,10 +808,10 @@ case class CatalystToExternalMap private( } val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType] - val keyElementJavaType = ctx.javaType(mapType.keyType) + val keyElementJavaType = CodeGenerator.javaType(mapType.keyType) ctx.addMutableState(keyElementJavaType, keyLoopValue, forceInline = true, useFreshName = false) val genKeyFunction = keyLambdaFunction.genCode(ctx) - val valueElementJavaType = ctx.javaType(mapType.valueType) + val valueElementJavaType = CodeGenerator.javaType(mapType.valueType) ctx.addMutableState(valueElementJavaType, valueLoopValue, forceInline = true, useFreshName = false) val genValueFunction = valueLambdaFunction.genCode(ctx) @@ -825,10 +827,11 @@ case class CatalystToExternalMap private( val valueArray = ctx.freshName("valueArray") val getKeyArray = s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();" - val getKeyLoopVar = ctx.getValue(keyArray, inputDataType(mapType.keyType), loopIndex) + val getKeyLoopVar = CodeGenerator.getValue(keyArray, inputDataType(mapType.keyType), loopIndex) val getValueArray = s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();" - val getValueLoopVar = ctx.getValue(valueArray, inputDataType(mapType.valueType), loopIndex) + val getValueLoopVar = CodeGenerator.getValue( + valueArray, inputDataType(mapType.valueType), loopIndex) // Make a copy of the data if it's unsafe-backed def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = @@ -844,7 +847,7 @@ case class CatalystToExternalMap private( val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) val valueLoopNullCheck = if (valueLoopIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true, + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true, useFreshName = false) s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" } else { @@ -873,7 +876,7 @@ case class CatalystToExternalMap private( val code = s""" ${genInputData.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { int $dataLength = $getLength; @@ -993,8 +996,8 @@ case class ExternalMapToCatalyst private( val entry = ctx.freshName("entry") val entries = ctx.freshName("entries") - val keyElementJavaType = ctx.javaType(keyType) - val valueElementJavaType = ctx.javaType(valueType) + val keyElementJavaType = CodeGenerator.javaType(keyType) + val valueElementJavaType = CodeGenerator.javaType(valueType) ctx.addMutableState(keyElementJavaType, key, forceInline = true, useFreshName = false) ctx.addMutableState(valueElementJavaType, value, forceInline = true, useFreshName = false) @@ -1009,8 +1012,8 @@ case class ExternalMapToCatalyst private( val defineKeyValue = s""" final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next(); - $key = (${ctx.boxedType(keyType)}) $entry.getKey(); - $value = (${ctx.boxedType(valueType)}) $entry.getValue(); + $key = (${CodeGenerator.boxedType(keyType)}) $entry.getKey(); + $value = (${CodeGenerator.boxedType(valueType)}) $entry.getValue(); """ defineEntries -> defineKeyValue @@ -1024,22 +1027,24 @@ case class ExternalMapToCatalyst private( val defineKeyValue = s""" final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next(); - $key = (${ctx.boxedType(keyType)}) $entry._1(); - $value = (${ctx.boxedType(valueType)}) $entry._2(); + $key = (${CodeGenerator.boxedType(keyType)}) $entry._1(); + $value = (${CodeGenerator.boxedType(valueType)}) $entry._2(); """ defineEntries -> defineKeyValue } val keyNullCheck = if (keyIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false) + ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false) s"$keyIsNull = $key == null;" } else { "" } val valueNullCheck = if (valueIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false) + ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false) s"$valueIsNull = $value == null;" } else { "" @@ -1047,12 +1052,12 @@ case class ExternalMapToCatalyst private( val arrayCls = classOf[GenericArrayData].getName val mapCls = classOf[ArrayBasedMapData].getName - val convertedKeyType = ctx.boxedType(keyConverter.dataType) - val convertedValueType = ctx.boxedType(valueConverter.dataType) + val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType) + val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType) val code = s""" ${inputMap.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${inputMap.isNull}) { final int $length = ${inputMap.value}.size(); final Object[] $convertedKeys = new Object[$length]; @@ -1174,12 +1179,13 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) // Code to serialize. val input = child.genCode(ctx) - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val serialize = s"$serializer.serialize(${input.value}, null).array()" val code = s""" ${input.code} - final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $serialize; + final $javaType ${ev.value} = + ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize; """ ev.copy(code = code, isNull = input.isNull) } @@ -1223,13 +1229,14 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B // Code to deserialize. val input = child.genCode(ctx) - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val deserialize = s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" val code = s""" ${input.code} - final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $deserialize; + final $javaType ${ev.value} = + ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize; """ ev.copy(code = code, isNull = input.isNull) } @@ -1254,7 +1261,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val instanceGen = beanInstance.genCode(ctx) val javaBeanInstance = ctx.freshName("javaBean") - val beanInstanceJavaType = ctx.javaType(beanInstance.dataType) + val beanInstanceJavaType = CodeGenerator.javaType(beanInstance.dataType) val initialize = setters.map { case (setterMethod, fieldValue) => @@ -1405,15 +1412,15 @@ case class ValidateExternalType(child: Expression, expected: DataType) case _: ArrayType => s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()" case _ => - s"$obj instanceof ${ctx.boxedType(dataType)}" + s"$obj instanceof ${CodeGenerator.boxedType(dataType)}" } val code = s""" ${input.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${input.isNull}) { if ($typeCheck) { - ${ev.value} = (${ctx.boxedType(dataType)}) $obj; + ${ev.value} = (${CodeGenerator.boxedType(dataType)}) $obj; } else { throw new RuntimeException($obj.getClass().getName() + $errMsgField); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index a6d41ea7d00d4..4b85d9adbe311 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -21,7 +21,7 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -235,7 +235,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaDataType = ctx.javaType(value.dataType) + val javaDataType = CodeGenerator.javaType(value.dataType) val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) // inTmpResult has 3 possible values: @@ -263,8 +263,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = listCode, funcName = "valueIn", - extraArguments = (javaDataType, valueArg) :: (ctx.JAVA_BYTE, tmpResult) :: Nil, - returnType = ctx.JAVA_BYTE, + extraArguments = (javaDataType, valueArg) :: (CodeGenerator.JAVA_BYTE, tmpResult) :: Nil, + returnType = CodeGenerator.JAVA_BYTE, makeSplitFunction = body => s""" |do { @@ -348,8 +348,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with ev.copy(code = s""" |${childGen.code} - |${ctx.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; - |${ctx.JAVA_BOOLEAN} ${ev.value} = false; + |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; + |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false; |if (!${ev.isNull}) { | ${ev.value} = $setTerm.contains(${childGen.value}); | $setIsNull @@ -505,7 +505,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (ctx.isPrimitiveType(left.dataType) + if (CodeGenerator.isPrimitiveType(left.dataType) && left.dataType != BooleanType // java boolean doesn't support > or < operator && left.dataType != FloatType && left.dataType != DoubleType) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 8bc936fcbfc31..6c9937dacc70b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -82,7 +82,8 @@ case class Rand(child: Expression) extends RDG { ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false") + final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", + isNull = "false") } } @@ -116,7 +117,8 @@ case class Randn(child: Expression) extends RDG { ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false") + final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", + isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index f3e8f6de58975..ad0c0791d895f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -126,7 +126,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi ev.copy(code = s""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $pattern.matcher(${eval.value}.toString()).matches(); } @@ -134,7 +134,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi } else { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) } } else { @@ -201,7 +201,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress ev.copy(code = s""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $pattern.matcher(${eval.value}.toString()).find(0); } @@ -209,7 +209,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress } else { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) } } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index d7612e30b4c57..22fbb8998ed89 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -102,11 +102,11 @@ case class Concat(children: Seq[Expression]) extends Expression { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = inputs, funcName = "valueConcat", - extraArguments = (s"${ctx.javaType(dataType)}[]", args) :: Nil) + extraArguments = (s"${CodeGenerator.javaType(dataType)}[]", args) :: Nil) ev.copy(s""" $initCode $codes - ${ctx.javaType(dataType)} ${ev.value} = $concatenator.concat($args); + ${CodeGenerator.javaType(dataType)} ${ev.value} = $concatenator.concat($args); boolean ${ev.isNull} = ${ev.value} == null; """) } @@ -196,7 +196,7 @@ case class ConcatWs(children: Seq[Expression]) } else { val array = ctx.freshName("array") val varargNum = ctx.freshName("varargNum") - val idxInVararg = ctx.freshName("idxInVararg") + val idxVararg = ctx.freshName("idxInVararg") val evals = children.map(_.genCode(ctx)) val (varargCount, varargBuild) = children.tail.zip(evals.tail).map { case (child, eval) => @@ -206,7 +206,7 @@ case class ConcatWs(children: Seq[Expression]) if (eval.isNull == "true") { "" } else { - s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};" + s"$array[$idxVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};" }) case _: ArrayType => val size = ctx.freshName("n") @@ -222,7 +222,7 @@ case class ConcatWs(children: Seq[Expression]) if (!${eval.isNull}) { final int $size = ${eval.value}.numElements(); for (int j = 0; j < $size; j ++) { - $array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")}; + $array[$idxVararg ++] = ${CodeGenerator.getValue(eval.value, StringType, "j")}; } } """) @@ -247,20 +247,20 @@ case class ConcatWs(children: Seq[Expression]) val varargBuilds = ctx.splitExpressionsWithCurrentInputs( expressions = varargBuild, funcName = "varargBuildsConcatWs", - extraArguments = ("UTF8String []", array) :: ("int", idxInVararg) :: Nil, + extraArguments = ("UTF8String []", array) :: ("int", idxVararg) :: Nil, returnType = "int", makeSplitFunction = body => s""" |$body - |return $idxInVararg; + |return $idxVararg; """.stripMargin, - foldFunctions = _.map(funcCall => s"$idxInVararg = $funcCall;").mkString("\n")) + foldFunctions = _.map(funcCall => s"$idxVararg = $funcCall;").mkString("\n")) ev.copy( s""" $codes int $varargNum = ${children.count(_.dataType == StringType) - 1}; - int $idxInVararg = 0; + int $idxVararg = 0; $varargCounts UTF8String[] $array = new UTF8String[$varargNum]; $varargBuilds @@ -333,7 +333,7 @@ case class Elt(children: Seq[Expression]) extends Expression { val indexVal = ctx.freshName("index") val indexMatched = ctx.freshName("eltIndexMatched") - val inputVal = ctx.addMutableState(ctx.javaType(dataType), "inputVal") + val inputVal = ctx.addMutableState(CodeGenerator.javaType(dataType), "inputVal") val assignInputValue = inputs.zipWithIndex.map { case (eval, index) => s""" @@ -350,10 +350,10 @@ case class Elt(children: Seq[Expression]) extends Expression { expressions = assignInputValue, funcName = "eltFunc", extraArguments = ("int", indexVal) :: Nil, - returnType = ctx.JAVA_BOOLEAN, + returnType = CodeGenerator.JAVA_BOOLEAN, makeSplitFunction = body => s""" - |${ctx.JAVA_BOOLEAN} $indexMatched = false; + |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; |do { | $body |} while (false); @@ -372,12 +372,12 @@ case class Elt(children: Seq[Expression]) extends Expression { s""" |${index.code} |final int $indexVal = ${index.value}; - |${ctx.JAVA_BOOLEAN} $indexMatched = false; + |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; |$inputVal = null; |do { | $codes |} while (false); - |final ${ctx.javaType(dataType)} ${ev.value} = $inputVal; + |final ${CodeGenerator.javaType(dataType)} ${ev.value} = $inputVal; |final boolean ${ev.isNull} = ${ev.value} == null; """.stripMargin) } @@ -1410,10 +1410,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC val numArgLists = argListGen.length val argListCode = argListGen.zipWithIndex.map { case(v, index) => val value = - if (ctx.boxedType(v._1) != ctx.javaType(v._1)) { + if (CodeGenerator.boxedType(v._1) != CodeGenerator.javaType(v._1)) { // Java primitives get boxed in order to allow null values. - s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " + - s"new ${ctx.boxedType(v._1)}(${v._2.value})" + s"(${v._2.isNull}) ? (${CodeGenerator.boxedType(v._1)}) null : " + + s"new ${CodeGenerator.boxedType(v._1)}(${v._2.value})" } else { s"(${v._2.isNull}) ? null : ${v._2.value}" } @@ -1434,7 +1434,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC ev.copy(code = s""" ${pattern.code} boolean ${ev.isNull} = ${pattern.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { $stringBuffer $sb = new $stringBuffer(); $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US); @@ -2110,7 +2110,8 @@ case class FormatNumber(x: Expression, d: Expression) val usLocale = "US" val i = ctx.freshName("i") val dFormat = ctx.freshName("dFormat") - val lastDValue = ctx.addMutableState(ctx.JAVA_INT, "lastDValue", v => s"$v = -100;") + val lastDValue = + ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;") val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") val numberFormat = ctx.addMutableState(df, "numberFormat", v => s"""$v = new $df("", new $dfs($l.$usLocale));""") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 676ba3956ddc8..1e48c7b8df9da 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -405,12 +405,12 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-18016: define mutable states by using an array") { val ctx1 = new CodegenContext for (i <- 1 to CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) { - ctx1.addMutableState(ctx1.JAVA_INT, "i", v => s"$v = $i;") + ctx1.addMutableState(CodeGenerator.JAVA_INT, "i", v => s"$v = $i;") } assert(ctx1.inlinedMutableStates.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) // When the number of primitive type mutable states is over the threshold, others are // allocated into an array - assert(ctx1.arrayCompactedMutableStates.get(ctx1.JAVA_INT).get.arrayNames.size == 1) + assert(ctx1.arrayCompactedMutableStates.get(CodeGenerator.JAVA_INT).get.arrayNames.size == 1) assert(ctx1.mutableStateInitCode.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) val ctx2 = new CodegenContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 04f2619ed7541..392906a022903 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -49,15 +49,15 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { ordinal: String, dataType: DataType, nullable: Boolean): ExprCode = { - val javaType = ctx.javaType(dataType) - val value = ctx.getValueFromVector(columnVar, dataType, ordinal) + val javaType = CodeGenerator.javaType(dataType) + val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal) val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" val code = s"${ctx.registerComment(str)}\n" + (if (nullable) { s""" boolean $isNullVar = $columnVar.isNullAt($ordinal); - $javaType $valueVar = $isNullVar ? ${ctx.defaultValue(dataType)} : ($value); + $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value); """ } else { s"$javaType $valueVar = $value;" @@ -85,12 +85,13 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { // metrics val numOutputRows = metricTerm(ctx, "numOutputRows") val scanTimeMetric = metricTerm(ctx, "scanTime") - val scanTimeTotalNs = ctx.addMutableState(ctx.JAVA_LONG, "scanTime") // init as scanTime = 0 + val scanTimeTotalNs = + ctx.addMutableState(CodeGenerator.JAVA_LONG, "scanTime") // init as scanTime = 0 val columnarBatchClz = classOf[ColumnarBatch].getName val batch = ctx.addMutableState(columnarBatchClz, "batch") - val idx = ctx.addMutableState(ctx.JAVA_INT, "batchIdx") // init as batchIdx = 0 + val idx = ctx.addMutableState(CodeGenerator.JAVA_INT, "batchIdx") // init as batchIdx = 0 val columnVectorClzs = vectorTypes.getOrElse( Seq.fill(output.indices.size)(classOf[ColumnVector].getName)) val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index a7bd5ebf93ecd..12ae1ea4a7c13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -154,7 +154,8 @@ case class ExpandExec( val value = ctx.freshName("value") val code = s""" |boolean $isNull = true; - |${ctx.javaType(firstExpr.dataType)} $value = ${ctx.defaultValue(firstExpr.dataType)}; + |${CodeGenerator.javaType(firstExpr.dataType)} $value = + | ${CodeGenerator.defaultValue(firstExpr.dataType)}; """.stripMargin ExprCode(code, isNull, value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 0c2c4a1a9100d..384f0398a1ec0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} @@ -305,15 +305,15 @@ case class GenerateExec( nullable: Boolean, initialChecks: Seq[String]): ExprCode = { val value = ctx.freshName(name) - val javaType = ctx.javaType(dt) - val getter = ctx.getValue(source, dt, index) + val javaType = CodeGenerator.javaType(dt) + val getter = CodeGenerator.getValue(source, dt, index) val checks = initialChecks ++ optionalCode(nullable, s"$source.isNullAt($index)") if (checks.nonEmpty) { val isNull = ctx.freshName("isNull") val code = s""" |boolean $isNull = ${checks.mkString(" || ")}; - |$javaType $value = $isNull ? ${ctx.defaultValue(dt)} : $getter; + |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter; """.stripMargin ExprCode(code, isNull, value) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index ac1c34d41c4f1..0dc16ba5ce281 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -22,7 +22,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics @@ -133,7 +133,8 @@ case class SortExec( override def needStopCheck: Boolean = false override protected def doProduce(ctx: CodegenContext): String = { - val needToSort = ctx.addMutableState(ctx.JAVA_BOOLEAN, "needToSort", v => s"$v = true;") + val needToSort = + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "needToSort", v => s"$v = true;") // Initialize the class member variables. This includes the instance of the Sorter and // the iterator to return sorted rows. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index deb0a044c2fb2..f89e3fb0e536f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -234,7 +234,7 @@ trait CodegenSupport extends SparkPlan { variables.zipWithIndex.foreach { case (ev, i) => val paramName = ctx.freshName(s"expr_$i") - val paramType = ctx.javaType(attributes(i).dataType) + val paramType = CodeGenerator.javaType(attributes(i).dataType) arguments += ev.value parameters += s"$paramType $paramName" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index ce3c68810f3b6..1926e9373bc55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -178,7 +178,7 @@ case class HashAggregateExec( private var bufVars: Seq[ExprCode] = _ private def doProduceWithoutKeys(ctx: CodegenContext): String = { - val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg") + val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") // The generated function doesn't have input row in the code context. ctx.INPUT_ROW = null @@ -186,8 +186,8 @@ case class HashAggregateExec( val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val initExpr = functions.flatMap(f => f.initialValues) bufVars = initExpr.map { e => - val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull") - val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue") + val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull") + val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") // The initial expression should not access any column val ev = e.genCode(ctx) val initVars = s""" @@ -532,7 +532,7 @@ case class HashAggregateExec( */ private def checkIfFastHashMapSupported(ctx: CodegenContext): Boolean = { val isSupported = - (groupingKeySchema ++ bufferSchema).forall(f => ctx.isPrimitiveType(f.dataType) || + (groupingKeySchema ++ bufferSchema).forall(f => CodeGenerator.isPrimitiveType(f.dataType) || f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType]) && bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge) @@ -565,7 +565,7 @@ case class HashAggregateExec( } private def doProduceWithKeys(ctx: CodegenContext): String = { - val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg") + val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") if (sqlContext.conf.enableTwoLevelAggMap) { enableTwoLevelHashMap(ctx) } else { @@ -757,7 +757,7 @@ case class HashAggregateExec( val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, incCounter) = if (testFallbackStartsAt.isDefined) { - val countTerm = ctx.addMutableState(ctx.JAVA_INT, "fallbackCounter") + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter") (s"$countTerm < ${testFallbackStartsAt.get._1}", s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;") } else { @@ -832,7 +832,7 @@ case class HashAggregateExec( } val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType - ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) + CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) } s""" |// common sub-expressions @@ -855,7 +855,7 @@ case class HashAggregateExec( } val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType - ctx.updateColumn( + CodeGenerator.updateColumn( fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorizedHashMapEnabled) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 1c613b19c4ab1..6b60b414ffe5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types._ /** @@ -41,13 +41,13 @@ abstract class HashMapGenerator( val groupingKeys = groupingKeySchema.map(k => Buffer(k.dataType, ctx.freshName("key"))) val bufferValues = bufferSchema.map(k => Buffer(k.dataType, ctx.freshName("value"))) val groupingKeySignature = - groupingKeys.map(key => s"${ctx.javaType(key.dataType)} ${key.name}").mkString(", ") + groupingKeys.map(key => s"${CodeGenerator.javaType(key.dataType)} ${key.name}").mkString(", ") val buffVars: Seq[ExprCode] = { val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val initExpr = functions.flatMap(f => f.initialValues) initExpr.map { e => - val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull") - val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue") + val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull") + val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") val ev = e.genCode(ctx) val initVars = s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index fd25707dd4ca6..8617be88f3570 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator} import org.apache.spark.sql.types._ /** @@ -114,7 +114,7 @@ class RowBasedHashMapGenerator( def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - s"""(${ctx.genEqual(key.dataType, ctx.getValue("row", + s"""(${ctx.genEqual(key.dataType, CodeGenerator.getValue("row", key.dataType, ordinal.toString()), key.name)})""" }.mkString(" && ") } @@ -147,7 +147,7 @@ class RowBasedHashMapGenerator( case t: DecimalType => s"agg_rowWriter.write(${ordinal}, ${key.name}, ${t.precision}, ${t.scale})" case t: DataType => - if (!t.isInstanceOf[StringType] && !ctx.isPrimitiveType(t)) { + if (!t.isInstanceOf[StringType] && !CodeGenerator.isPrimitiveType(t)) { throw new IllegalArgumentException(s"cannot generate code for unsupported type: $t") } s"agg_rowWriter.write(${ordinal}, ${key.name})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 633eeac180974..7b3580cecc60d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator} import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, OnHeapColumnVector} import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -127,7 +127,8 @@ class VectorizedHashMapGenerator( def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - val value = ctx.getValueFromVector(s"vectors[$ordinal]", key.dataType, "buckets[idx]") + val value = CodeGenerator.getValueFromVector(s"vectors[$ordinal]", key.dataType, + "buckets[idx]") s"(${ctx.genEqual(key.dataType, value, key.name)})" }.mkString(" && ") } @@ -182,14 +183,14 @@ class VectorizedHashMapGenerator( def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - ctx.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name) + CodeGenerator.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name) } } def genCodeToSetAggBuffers(bufferValues: Seq[Buffer]): Seq[String] = { bufferValues.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - ctx.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType, - buffVars(ordinal), nullable = true) + CodeGenerator.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", + key.dataType, buffVars(ordinal), nullable = true) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index a15a8d11aa2a0..4707022f74547 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -24,7 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, ExpressionCanonicalizer} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType @@ -364,8 +364,8 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) protected override def doProduce(ctx: CodegenContext): String = { val numOutput = metricTerm(ctx, "numOutputRows") - val initTerm = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initRange") - val number = ctx.addMutableState(ctx.JAVA_LONG, "number") + val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange") + val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number") val value = ctx.freshName("value") val ev = ExprCode("", "false", value) @@ -385,10 +385,10 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // the metrics. // Once number == batchEnd, it's time to progress to the next batch. - val batchEnd = ctx.addMutableState(ctx.JAVA_LONG, "batchEnd") + val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd") // How many values should still be generated by this range operator. - val numElementsTodo = ctx.addMutableState(ctx.JAVA_LONG, "numElementsTodo") + val numElementsTodo = ctx.addMutableState(CodeGenerator.JAVA_LONG, "numElementsTodo") // How many values should be generated in the next batch. val nextBatchTodo = ctx.freshName("nextBatchTodo") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 4f28eeb725cbb..3b5655ba0582e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -91,7 +91,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera val accessorName = ctx.addMutableState(accessorCls, "accessor") val createCode = dt match { - case t if ctx.isPrimitiveType(dt) => + case t if CodeGenerator.isPrimitiveType(dt) => s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" case NullType | StringType | BinaryType => s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 1918fcc5482db..487d6a2383318 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -22,7 +22,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} @@ -182,9 +182,10 @@ case class BroadcastHashJoinExec( // the variables are needed even there is no matched rows val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") + val javaType = CodeGenerator.javaType(a.dataType) val code = s""" |boolean $isNull = true; - |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)}; + |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)}; |if ($matched != null) { | ${ev.code} | $isNull = ${ev.isNull}; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 2de2f30eb05d3..5a511b30e4fd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, @@ -516,9 +516,9 @@ case class SortMergeJoinExec( ctx.INPUT_ROW = leftRow left.output.zipWithIndex.map { case (a, i) => val value = ctx.freshName("value") - val valueCode = ctx.getValue(leftRow, a.dataType, i.toString) - val javaType = ctx.javaType(a.dataType) - val defaultValue = ctx.defaultValue(a.dataType) + val valueCode = CodeGenerator.getValue(leftRow, a.dataType, i.toString) + val javaType = CodeGenerator.javaType(a.dataType) + val defaultValue = CodeGenerator.defaultValue(a.dataType) if (a.nullable) { val isNull = ctx.freshName("isNull") val code = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index cccee63bc0680..66bcda8913738 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, LazilyGeneratedOrdering} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.util.Utils @@ -71,7 +71,8 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val stopEarly = ctx.addMutableState(ctx.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false + val stopEarly = + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false ctx.addNewFunction("stopEarly", s""" @Override @@ -79,7 +80,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { return $stopEarly; } """, inlineToOuterClass = true) - val countTerm = ctx.addMutableState(ctx.JAVA_INT, "count") // init as count = 0 + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "count") // init as count = 0 s""" | if ($countTerm < $limit) { | $countTerm += 1; From 42cf48e20cd5e47e1b7557af9c71c4eea142f10f Mon Sep 17 00:00:00 2001 From: Ala Luszczak Date: Mon, 5 Mar 2018 14:33:12 +0100 Subject: [PATCH 0424/2461] [SPARK-23496][CORE] Locality of coalesced partitions can be severely skewed by the order of input partitions ## What changes were proposed in this pull request? The algorithm in `DefaultPartitionCoalescer.setupGroups` is responsible for picking preferred locations for coalesced partitions. It analyzes the preferred locations of input partitions. It starts by trying to create one partition for each unique location in the input. However, if the the requested number of coalesced partitions is higher that the number of unique locations, it has to pick duplicate locations. Previously, the duplicate locations would be picked by iterating over the input partitions in order, and copying their preferred locations to coalesced partitions. If the input partitions were clustered by location, this could result in severe skew. With the fix, instead of iterating over the list of input partitions in order, we pick them at random. It's not perfectly balanced, but it's much better. ## How was this patch tested? Unit test reproducing the behavior was added. Author: Ala Luszczak Closes #20664 from ala/SPARK-23496. --- .../org/apache/spark/rdd/CoalescedRDD.scala | 8 ++-- .../scala/org/apache/spark/rdd/RDDSuite.scala | 42 +++++++++++++++++++ 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 10451a324b0f4..94e7d0b38cba3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -266,17 +266,17 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10) numCreated += 1 } } - tries = 0 // if we don't have enough partition groups, create duplicates while (numCreated < targetLen) { - val (nxt_replica, nxt_part) = partitionLocs.partsWithLocs(tries) - tries += 1 + // Copy the preferred location from a random input partition. + // This helps in avoiding skew when the input partitions are clustered by preferred location. + val (nxt_replica, nxt_part) = partitionLocs.partsWithLocs( + rnd.nextInt(partitionLocs.partsWithLocs.length)) val pgroup = new PartitionGroup(Some(nxt_replica)) groupArr += pgroup groupHash.getOrElseUpdate(nxt_replica, ArrayBuffer()) += pgroup addPartToPGroup(nxt_part, pgroup) numCreated += 1 - if (tries >= partitionLocs.partsWithLocs.length) tries = 0 } } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index e994d724c462f..191c61250ce21 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -1129,6 +1129,35 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { }.collect() } + test("SPARK-23496: order of input partitions can result in severe skew in coalesce") { + val numInputPartitions = 100 + val numCoalescedPartitions = 50 + val locations = Array("locA", "locB") + + val inputRDD = sc.makeRDD(Range(0, numInputPartitions).toArray[Int], numInputPartitions) + assert(inputRDD.getNumPartitions == numInputPartitions) + + val locationPrefRDD = new LocationPrefRDD(inputRDD, { (p: Partition) => + if (p.index < numCoalescedPartitions) { + Seq(locations(0)) + } else { + Seq(locations(1)) + } + }) + val coalescedRDD = new CoalescedRDD(locationPrefRDD, numCoalescedPartitions) + + val numPartsPerLocation = coalescedRDD + .getPartitions + .map(coalescedRDD.getPreferredLocations(_).head) + .groupBy(identity) + .mapValues(_.size) + + // Make sure the coalesced partitions are distributed fairly evenly between the two locations. + // This should not become flaky since the DefaultPartitionsCoalescer uses a fixed seed. + assert(numPartsPerLocation(locations(0)) > 0.4 * numCoalescedPartitions) + assert(numPartsPerLocation(locations(1)) > 0.4 * numCoalescedPartitions) + } + // NOTE // Below tests calling sc.stop() have to be the last tests in this suite. If there are tests // running after them and if they access sc those tests will fail as sc is already closed, because @@ -1210,3 +1239,16 @@ class SizeBasedCoalescer(val maxSize: Int) extends PartitionCoalescer with Seria groups.toArray } } + +/** Alters the preferred locations of the parent RDD using provided function. */ +class LocationPrefRDD[T: ClassTag]( + @transient var prev: RDD[T], + val locationPicker: Partition => Seq[String]) extends RDD[T](prev) { + override protected def getPartitions: Array[Partition] = prev.partitions + + override def compute(partition: Partition, context: TaskContext): Iterator[T] = + null.asInstanceOf[Iterator[T]] + + override def getPreferredLocations(partition: Partition): Seq[String] = + locationPicker(partition) +} From 5ff72ffcf495d2823f7f1186078d1cb261667c3d Mon Sep 17 00:00:00 2001 From: Anirudh Date: Mon, 5 Mar 2018 23:17:16 +0900 Subject: [PATCH 0425/2461] [SPARK-23566][MINOR][DOC] Argument name mismatch fixed Argument name mismatch fixed. ## What changes were proposed in this pull request? `col` changed to `new` in doc string to match the argument list. Patch file added: https://issues.apache.org/jira/browse/SPARK-23566 Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Anirudh Closes #20716 from animenon/master. --- python/pyspark/sql/dataframe.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index f37777e13ee12..9d8e85cde914f 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -588,6 +588,8 @@ def coalesce(self, numPartitions): """ Returns a new :class:`DataFrame` that has exactly `numPartitions` partitions. + :param numPartitions: int, to specify the target number of partitions + Similar to coalesce defined on an :class:`RDD`, this operation results in a narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of the 100 new partitions will @@ -612,9 +614,10 @@ def repartition(self, numPartitions, *cols): Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The resulting DataFrame is hash partitioned. - ``numPartitions`` can be an int to specify the target number of partitions or a Column. - If it is a Column, it will be used as the first partitioning column. If not specified, - the default number of partitions is used. + :param numPartitions: + can be an int to specify the target number of partitions or a Column. + If it is a Column, it will be used as the first partitioning column. If not specified, + the default number of partitions is used. .. versionchanged:: 1.6 Added optional arguments to specify the partitioning columns. Also made numPartitions @@ -673,9 +676,10 @@ def repartitionByRange(self, numPartitions, *cols): Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The resulting DataFrame is range partitioned. - ``numPartitions`` can be an int to specify the target number of partitions or a Column. - If it is a Column, it will be used as the first partitioning column. If not specified, - the default number of partitions is used. + :param numPartitions: + can be an int to specify the target number of partitions or a Column. + If it is a Column, it will be used as the first partitioning column. If not specified, + the default number of partitions is used. At least one partition-by expression must be specified. When no explicit sort order is specified, "ascending nulls first" is assumed. @@ -892,6 +896,8 @@ def colRegex(self, colName): def alias(self, alias): """Returns a new :class:`DataFrame` with an alias set. + :param alias: string, an alias name to be set for the DataFrame. + >>> from pyspark.sql.functions import * >>> df_as1 = df.alias("df_as1") >>> df_as2 = df.alias("df_as2") @@ -1900,7 +1906,7 @@ def withColumnRenamed(self, existing, new): This is a no-op if schema doesn't contain the given column name. :param existing: string, name of the existing column to rename. - :param col: string, new name of the column. + :param new: string, new name of the column. >>> df.withColumnRenamed('age', 'age2').collect() [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')] From a366b950b90650693ad0eb1e5b9a988ad028d845 Mon Sep 17 00:00:00 2001 From: Mihaly Toth Date: Mon, 5 Mar 2018 23:46:40 +0900 Subject: [PATCH 0426/2461] [SPARK-23329][SQL] Fix documentation of trigonometric functions ## What changes were proposed in this pull request? Provide more details in trigonometric function documentations. Referenced `java.lang.Math` for further details in the descriptions. ## How was this patch tested? Ran full build, checked generated documentation manually Author: Mihaly Toth Closes #20618 from misutoth/trigonometric-doc. --- R/pkg/R/functions.R | 34 ++-- python/pyspark/sql/functions.py | 62 ++++--- .../expressions/mathExpressions.scala | 99 ++++++++--- .../org/apache/spark/sql/functions.scala | 160 ++++++++++++------ 4 files changed, 248 insertions(+), 107 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 9f7c6317cd924..29ee146ab14f9 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -278,8 +278,8 @@ setMethod("abs", }) #' @details -#' \code{acos}: Computes the cosine inverse of the given value; the returned angle is in -#' the range 0.0 through pi. +#' \code{acos}: Returns the inverse cosine of the given value, +#' as if computed by \code{java.lang.Math.acos()} #' #' @rdname column_math_functions #' @export @@ -334,8 +334,8 @@ setMethod("ascii", }) #' @details -#' \code{asin}: Computes the sine inverse of the given value; the returned angle is in -#' the range -pi/2 through pi/2. +#' \code{asin}: Returns the inverse sine of the given value, +#' as if computed by \code{java.lang.Math.asin()} #' #' @rdname column_math_functions #' @export @@ -349,8 +349,8 @@ setMethod("asin", }) #' @details -#' \code{atan}: Computes the tangent inverse of the given value; the returned angle is in the range -#' -pi/2 through pi/2. +#' \code{atan}: Returns the inverse tangent of the given value, +#' as if computed by \code{java.lang.Math.atan()} #' #' @rdname column_math_functions #' @export @@ -613,7 +613,8 @@ setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOr }) #' @details -#' \code{cos}: Computes the cosine of the given value. Units in radians. +#' \code{cos}: Returns the cosine of the given value, +#' as if computed by \code{java.lang.Math.cos()}. Units in radians. #' #' @rdname column_math_functions #' @aliases cos cos,Column-method @@ -627,7 +628,8 @@ setMethod("cos", }) #' @details -#' \code{cosh}: Computes the hyperbolic cosine of the given value. +#' \code{cosh}: Returns the hyperbolic cosine of the given value, +#' as if computed by \code{java.lang.Math.cosh()}. #' #' @rdname column_math_functions #' @aliases cosh cosh,Column-method @@ -1463,7 +1465,8 @@ setMethod("sign", signature(x = "Column"), }) #' @details -#' \code{sin}: Computes the sine of the given value. Units in radians. +#' \code{sin}: Returns the sine of the given value, +#' as if computed by \code{java.lang.Math.sin()}. Units in radians. #' #' @rdname column_math_functions #' @aliases sin sin,Column-method @@ -1477,7 +1480,8 @@ setMethod("sin", }) #' @details -#' \code{sinh}: Computes the hyperbolic sine of the given value. +#' \code{sinh}: Returns the hyperbolic sine of the given value, +#' as if computed by \code{java.lang.Math.sinh()}. #' #' @rdname column_math_functions #' @aliases sinh sinh,Column-method @@ -1653,7 +1657,9 @@ setMethod("sumDistinct", }) #' @details -#' \code{tan}: Computes the tangent of the given value. Units in radians. +#' \code{tan}: Returns the tangent of the given value, +#' as if computed by \code{java.lang.Math.tan()}. +#' Units in radians. #' #' @rdname column_math_functions #' @aliases tan tan,Column-method @@ -1667,7 +1673,8 @@ setMethod("tan", }) #' @details -#' \code{tanh}: Computes the hyperbolic tangent of the given value. +#' \code{tanh}: Returns the hyperbolic tangent of the given value, +#' as if computed by \code{java.lang.Math.tanh()}. #' #' @rdname column_math_functions #' @aliases tanh tanh,Column-method @@ -1973,7 +1980,8 @@ setMethod("year", #' @details #' \code{atan2}: Returns the angle theta from the conversion of rectangular coordinates -#' (x, y) to polar coordinates (r, theta). Units in radians. +#' (x, y) to polar coordinates (r, theta), +#' as if computed by \code{java.lang.Math.atan2()}. Units in radians. #' #' @rdname column_math_functions #' @aliases atan2 atan2,Column-method diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9bb9c323a5a60..b9c0c57262c5d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -106,18 +106,15 @@ def _(): _functions_1_4 = { # unary math functions - 'acos': 'Computes the cosine inverse of the given value; the returned angle is in the range' + - '0.0 through pi.', - 'asin': 'Computes the sine inverse of the given value; the returned angle is in the range' + - '-pi/2 through pi/2.', - 'atan': 'Computes the tangent inverse of the given value; the returned angle is in the range' + - '-pi/2 through pi/2', + 'acos': ':return: inverse cosine of `col`, as if computed by `java.lang.Math.acos()`', + 'asin': ':return: inverse sine of `col`, as if computed by `java.lang.Math.asin()`', + 'atan': ':return: inverse tangent of `col`, as if computed by `java.lang.Math.atan()`', 'cbrt': 'Computes the cube-root of the given value.', 'ceil': 'Computes the ceiling of the given value.', - 'cos': """Computes the cosine of the given value. - - :param col: :class:`DoubleType` column, units in radians.""", - 'cosh': 'Computes the hyperbolic cosine of the given value.', + 'cos': """:param col: angle in radians + :return: cosine of the angle, as if computed by `java.lang.Math.cos()`.""", + 'cosh': """:param col: hyperbolic angle + :return: hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh()`""", 'exp': 'Computes the exponential of the given value.', 'expm1': 'Computes the exponential of the given value minus one.', 'floor': 'Computes the floor of the given value.', @@ -127,14 +124,16 @@ def _(): 'rint': 'Returns the double value that is closest in value to the argument and' + ' is equal to a mathematical integer.', 'signum': 'Computes the signum of the given value.', - 'sin': """Computes the sine of the given value. - - :param col: :class:`DoubleType` column, units in radians.""", - 'sinh': 'Computes the hyperbolic sine of the given value.', - 'tan': """Computes the tangent of the given value. - - :param col: :class:`DoubleType` column, units in radians.""", - 'tanh': 'Computes the hyperbolic tangent of the given value.', + 'sin': """:param col: angle in radians + :return: sine of the angle, as if computed by `java.lang.Math.sin()`""", + 'sinh': """:param col: hyperbolic angle + :return: hyperbolic sine of the given value, + as if computed by `java.lang.Math.sinh()`""", + 'tan': """:param col: angle in radians + :return: tangent of the given value, as if computed by `java.lang.Math.tan()`""", + 'tanh': """:param col: hyperbolic angle + :return: hyperbolic tangent of the given value, + as if computed by `java.lang.Math.tanh()`""", 'toDegrees': '.. note:: Deprecated in 2.1, use :func:`degrees` instead.', 'toRadians': '.. note:: Deprecated in 2.1, use :func:`radians` instead.', 'bitwiseNOT': 'Computes bitwise not.', @@ -173,16 +172,31 @@ def _(): _functions_2_1 = { # unary math functions - 'degrees': 'Converts an angle measured in radians to an approximately equivalent angle ' + - 'measured in degrees.', - 'radians': 'Converts an angle measured in degrees to an approximately equivalent angle ' + - 'measured in radians.', + 'degrees': """ + Converts an angle measured in radians to an approximately equivalent angle + measured in degrees. + :param col: angle in radians + :return: angle in degrees, as if computed by `java.lang.Math.toDegrees()` + """, + 'radians': """ + Converts an angle measured in degrees to an approximately equivalent angle + measured in radians. + :param col: angle in degrees + :return: angle in radians, as if computed by `java.lang.Math.toRadians()` + """, } # math functions that take two arguments as input _binary_mathfunctions = { - 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + - 'polar coordinates (r, theta). Units in radians.', + 'atan2': """ + :param col1: coordinate on y-axis + :param col2: coordinate on x-axis + :return: the `theta` component of the point + (`r`, `theta`) + in polar coordinates that corresponds to the point + (`x`, `y`) in Cartesian coordinates, + as if computed by `java.lang.Math.atan2()` + """, 'hypot': 'Computes ``sqrt(a^2 + b^2)`` without intermediate overflow or underflow.', 'pow': 'Returns the value of the first argument raised to the power of the second argument.', } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 2c2cf3d2e6227..bc4cfcec47425 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -168,9 +168,11 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI") //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the inverse cosine (a.k.a. arccosine) of `expr` if -1<=`expr`<=1 or NaN otherwise.", + usage = """ + _FUNC_(expr) - Returns the inverse cosine (a.k.a. arc cosine) of `expr`, as if computed by + `java.lang.Math._FUNC_`. + """, examples = """ Examples: > SELECT _FUNC_(1); @@ -178,12 +180,13 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI") > SELECT _FUNC_(2); NaN """) -// scalastyle:on line.size.limit case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the inverse sine (a.k.a. arcsine) the arc sin of `expr` if -1<=`expr`<=1 or NaN otherwise.", + usage = """ + _FUNC_(expr) - Returns the inverse sine (a.k.a. arc sine) the arc sin of `expr`, + as if computed by `java.lang.Math._FUNC_`. + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -191,18 +194,18 @@ case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS" > SELECT _FUNC_(2); NaN """) -// scalastyle:on line.size.limit case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN") -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the inverse tangent (a.k.a. arctangent).", + usage = """ + _FUNC_(expr) - Returns the inverse tangent (a.k.a. arc tangent) of `expr`, as if computed by + `java.lang.Math._FUNC_` + """, examples = """ Examples: > SELECT _FUNC_(0); 0.0 """) -// scalastyle:on line.size.limit case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN") @ExpressionDescription( @@ -252,7 +255,14 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" } @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the cosine of `expr`.", + usage = """ + _FUNC_(expr) - Returns the cosine of `expr`, as if computed by + `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -261,7 +271,14 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the hyperbolic cosine of `expr`.", + usage = """ + _FUNC_(expr) - Returns the hyperbolic cosine of `expr`, as if computed by + `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - hyperbolic angle + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -512,7 +529,11 @@ case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the sine of `expr`.", + usage = "_FUNC_(expr) - Returns the sine of `expr`, as if computed by `java.lang.Math._FUNC_`.", + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -521,7 +542,13 @@ case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "S case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the hyperbolic sine of `expr`.", + usage = """ + _FUNC_(expr) - Returns hyperbolic sine of `expr`, as if computed by `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - hyperbolic angle + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -539,7 +566,13 @@ case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH" case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the tangent of `expr`.", + usage = """ + _FUNC_(expr) - Returns the tangent of `expr`, as if computed by `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -548,7 +581,13 @@ case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT" case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the cotangent of `expr`.", + usage = """ + _FUNC_(expr) - Returns the cotangent of `expr`, as if computed by `1/java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(1); @@ -562,7 +601,14 @@ case class Cot(child: Expression) } @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the hyperbolic tangent of `expr`.", + usage = """ + _FUNC_(expr) - Returns the hyperbolic tangent of `expr`, as if computed by + `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - hyperbolic angle + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -572,6 +618,10 @@ case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH" @ExpressionDescription( usage = "_FUNC_(expr) - Converts radians to degrees.", + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(3.141592653589793); @@ -583,6 +633,10 @@ case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegre @ExpressionDescription( usage = "_FUNC_(expr) - Converts degrees to radians.", + arguments = """ + Arguments: + * expr - angle in degrees + """, examples = """ Examples: > SELECT _FUNC_(180); @@ -768,15 +822,22 @@ case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInp //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr1, expr2) - Returns the angle in radians between the positive x-axis of a plane and the point given by the coordinates (`expr1`, `expr2`).", + usage = """ + _FUNC_(exprY, exprX) - Returns the angle in radians between the positive x-axis of a plane + and the point given by the coordinates (`exprX`, `exprY`), as if computed by + `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * exprY - coordinate on y-axis + * exprX - coordinate on x-axis + """, examples = """ Examples: > SELECT _FUNC_(0, 0); 0.0 """) -// scalastyle:on line.size.limit case class Atan2(left: Expression, right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0d54c02c3d06f..c9ca9a8996344 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1313,8 +1313,7 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Computes the cosine inverse of the given value; the returned angle is in the range - * 0.0 through pi. + * @return inverse cosine of `e` in radians, as if computed by `java.lang.Math.acos` * * @group math_funcs * @since 1.4.0 @@ -1322,8 +1321,7 @@ object functions { def acos(e: Column): Column = withExpr { Acos(e.expr) } /** - * Computes the cosine inverse of the given column; the returned angle is in the range - * 0.0 through pi. + * @return inverse cosine of `columnName`, as if computed by `java.lang.Math.acos` * * @group math_funcs * @since 1.4.0 @@ -1331,8 +1329,7 @@ object functions { def acos(columnName: String): Column = acos(Column(columnName)) /** - * Computes the sine inverse of the given value; the returned angle is in the range - * -pi/2 through pi/2. + * @return inverse sine of `e` in radians, as if computed by `java.lang.Math.asin` * * @group math_funcs * @since 1.4.0 @@ -1340,8 +1337,7 @@ object functions { def asin(e: Column): Column = withExpr { Asin(e.expr) } /** - * Computes the sine inverse of the given column; the returned angle is in the range - * -pi/2 through pi/2. + * @return inverse sine of `columnName`, as if computed by `java.lang.Math.asin` * * @group math_funcs * @since 1.4.0 @@ -1349,8 +1345,7 @@ object functions { def asin(columnName: String): Column = asin(Column(columnName)) /** - * Computes the tangent inverse of the given column; the returned angle is in the range - * -pi/2 through pi/2 + * @return inverse tangent of `e`, as if computed by `java.lang.Math.atan` * * @group math_funcs * @since 1.4.0 @@ -1358,8 +1353,7 @@ object functions { def atan(e: Column): Column = withExpr { Atan(e.expr) } /** - * Computes the tangent inverse of the given column; the returned angle is in the range - * -pi/2 through pi/2 + * @return inverse tangent of `columnName`, as if computed by `java.lang.Math.atan` * * @group math_funcs * @since 1.4.0 @@ -1367,77 +1361,117 @@ object functions { def atan(columnName: String): Column = atan(Column(columnName)) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). Units in radians. + * @param y coordinate on y-axis + * @param x coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, r: Column): Column = withExpr { Atan2(l.expr, r.expr) } + def atan2(y: Column, x: Column): Column = withExpr { Atan2(y.expr, x.expr) } /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param y coordinate on y-axis + * @param xName coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, rightName: String): Column = atan2(l, Column(rightName)) + def atan2(y: Column, xName: String): Column = atan2(y, Column(xName)) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yName coordinate on y-axis + * @param x coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(leftName: String, r: Column): Column = atan2(Column(leftName), r) + def atan2(yName: String, x: Column): Column = atan2(Column(yName), x) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yName coordinate on y-axis + * @param xName coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(leftName: String, rightName: String): Column = - atan2(Column(leftName), Column(rightName)) + def atan2(yName: String, xName: String): Column = + atan2(Column(yName), Column(xName)) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param y coordinate on y-axis + * @param xValue coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, r: Double): Column = atan2(l, lit(r)) + def atan2(y: Column, xValue: Double): Column = atan2(y, lit(xValue)) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yName coordinate on y-axis + * @param xValue coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(leftName: String, r: Double): Column = atan2(Column(leftName), r) + def atan2(yName: String, xValue: Double): Column = atan2(Column(yName), xValue) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yValue coordinate on y-axis + * @param x coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Double, r: Column): Column = atan2(lit(l), r) + def atan2(yValue: Double, x: Column): Column = atan2(lit(yValue), x) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yValue coordinate on y-axis + * @param xName coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Double, rightName: String): Column = atan2(l, Column(rightName)) + def atan2(yValue: Double, xName: String): Column = atan2(yValue, Column(xName)) /** * An expression that returns the string representation of the binary value of the given long @@ -1500,7 +1534,8 @@ object functions { } /** - * Computes the cosine of the given value. Units in radians. + * @param e angle in radians + * @return cosine of the angle, as if computed by `java.lang.Math.cos` * * @group math_funcs * @since 1.4.0 @@ -1508,7 +1543,8 @@ object functions { def cos(e: Column): Column = withExpr { Cos(e.expr) } /** - * Computes the cosine of the given column. + * @param columnName angle in radians + * @return cosine of the angle, as if computed by `java.lang.Math.cos` * * @group math_funcs * @since 1.4.0 @@ -1516,7 +1552,8 @@ object functions { def cos(columnName: String): Column = cos(Column(columnName)) /** - * Computes the hyperbolic cosine of the given value. + * @param e hyperbolic angle + * @return hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh` * * @group math_funcs * @since 1.4.0 @@ -1524,7 +1561,8 @@ object functions { def cosh(e: Column): Column = withExpr { Cosh(e.expr) } /** - * Computes the hyperbolic cosine of the given column. + * @param columnName hyperbolic angle + * @return hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh` * * @group math_funcs * @since 1.4.0 @@ -1967,7 +2005,8 @@ object functions { def signum(columnName: String): Column = signum(Column(columnName)) /** - * Computes the sine of the given value. Units in radians. + * @param e angle in radians + * @return sine of the angle, as if computed by `java.lang.Math.sin` * * @group math_funcs * @since 1.4.0 @@ -1975,7 +2014,8 @@ object functions { def sin(e: Column): Column = withExpr { Sin(e.expr) } /** - * Computes the sine of the given column. + * @param columnName angle in radians + * @return sine of the angle, as if computed by `java.lang.Math.sin` * * @group math_funcs * @since 1.4.0 @@ -1983,7 +2023,8 @@ object functions { def sin(columnName: String): Column = sin(Column(columnName)) /** - * Computes the hyperbolic sine of the given value. + * @param e hyperbolic angle + * @return hyperbolic sine of the given value, as if computed by `java.lang.Math.sinh` * * @group math_funcs * @since 1.4.0 @@ -1991,7 +2032,8 @@ object functions { def sinh(e: Column): Column = withExpr { Sinh(e.expr) } /** - * Computes the hyperbolic sine of the given column. + * @param columnName hyperbolic angle + * @return hyperbolic sine of the given value, as if computed by `java.lang.Math.sinh` * * @group math_funcs * @since 1.4.0 @@ -1999,7 +2041,8 @@ object functions { def sinh(columnName: String): Column = sinh(Column(columnName)) /** - * Computes the tangent of the given value. Units in radians. + * @param e angle in radians + * @return tangent of the given value, as if computed by `java.lang.Math.tan` * * @group math_funcs * @since 1.4.0 @@ -2007,7 +2050,8 @@ object functions { def tan(e: Column): Column = withExpr { Tan(e.expr) } /** - * Computes the tangent of the given column. + * @param columnName angle in radians + * @return tangent of the given value, as if computed by `java.lang.Math.tan` * * @group math_funcs * @since 1.4.0 @@ -2015,7 +2059,8 @@ object functions { def tan(columnName: String): Column = tan(Column(columnName)) /** - * Computes the hyperbolic tangent of the given value. + * @param e hyperbolic angle + * @return hyperbolic tangent of the given value, as if computed by `java.lang.Math.tanh` * * @group math_funcs * @since 1.4.0 @@ -2023,7 +2068,8 @@ object functions { def tanh(e: Column): Column = withExpr { Tanh(e.expr) } /** - * Computes the hyperbolic tangent of the given column. + * @param columnName hyperbolic angle + * @return hyperbolic tangent of the given value, as if computed by `java.lang.Math.tanh` * * @group math_funcs * @since 1.4.0 @@ -2047,6 +2093,9 @@ object functions { /** * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. * + * @param e angle in radians + * @return angle in degrees, as if computed by `java.lang.Math.toDegrees` + * * @group math_funcs * @since 2.1.0 */ @@ -2055,6 +2104,9 @@ object functions { /** * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. * + * @param columnName angle in radians + * @return angle in degrees, as if computed by `java.lang.Math.toDegrees` + * * @group math_funcs * @since 2.1.0 */ @@ -2077,6 +2129,9 @@ object functions { /** * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. * + * @param e angle in degrees + * @return angle in radians, as if computed by `java.lang.Math.toRadians` + * * @group math_funcs * @since 2.1.0 */ @@ -2085,6 +2140,9 @@ object functions { /** * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. * + * @param columnName angle in degrees + * @return angle in radians, as if computed by `java.lang.Math.toRadians` + * * @group math_funcs * @since 2.1.0 */ @@ -2873,7 +2931,7 @@ object functions { * or equal to the `windowDuration`. Check * `org.apache.spark.unsafe.types.CalendarInterval` for valid duration * identifiers. This duration is likewise absolute, and does not vary - * according to a calendar. + * according to a calendar. * @param startTime The offset with respect to 1970-01-01 00:00:00 UTC with which to start * window intervals. For example, in order to have hourly tumbling windows that * start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide @@ -2929,7 +2987,7 @@ object functions { * or equal to the `windowDuration`. Check * `org.apache.spark.unsafe.types.CalendarInterval` for valid duration * identifiers. This duration is likewise absolute, and does not vary - * according to a calendar. + * according to a calendar. * * @group datetime_funcs * @since 2.0.0 From 947b4e6f09db6aa5d92409344b6e273e9faeb24e Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 5 Mar 2018 16:21:02 +0100 Subject: [PATCH 0427/2461] [SPARK-23510][DOC][FOLLOW-UP] Update spark.sql.hive.metastore.version ## What changes were proposed in this pull request? Update `spark.sql.hive.metastore.version` to 2.3.2, same as HiveUtils.scala: https://github.com/apache/spark/blob/ff1480189b827af0be38605d566a4ee71b4c36f6/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala#L63-L65 ## How was this patch tested? N/A Author: Yuming Wang Closes #20734 from wangyum/SPARK-23510-FOLLOW-UP. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 4d0f015f401bb..01e2076555ee6 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1214,7 +1214,7 @@ The following options can be used to configure the version of Hive that is used 1.2.1 Version of the Hive metastore. Available - options are 0.12.0 through 1.2.1. + options are 0.12.0 through 2.3.2. From 4586eada42d6a16bb78d1650d145531c51fa747f Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Mon, 5 Mar 2018 09:30:49 -0800 Subject: [PATCH 0428/2461] [SPARK-22430][R][DOCS] Unknown tag warnings when building R docs with Roxygen 6.0.1 ## What changes were proposed in this pull request? Removed export tag to get rid of unknown tag warnings ## How was this patch tested? Existing tests Author: Rekha Joshi Author: rjoshi2 Closes #20501 from rekhajoshm/SPARK-22430. --- R/pkg/R/DataFrame.R | 92 --------- R/pkg/R/SQLContext.R | 16 -- R/pkg/R/WindowSpec.R | 8 - R/pkg/R/broadcast.R | 3 - R/pkg/R/catalog.R | 18 -- R/pkg/R/column.R | 7 - R/pkg/R/context.R | 6 - R/pkg/R/functions.R | 181 ----------------- R/pkg/R/generics.R | 343 --------------------------------- R/pkg/R/group.R | 7 - R/pkg/R/install.R | 1 - R/pkg/R/jvm.R | 3 - R/pkg/R/mllib_classification.R | 20 -- R/pkg/R/mllib_clustering.R | 23 --- R/pkg/R/mllib_fpm.R | 6 - R/pkg/R/mllib_recommendation.R | 5 - R/pkg/R/mllib_regression.R | 17 -- R/pkg/R/mllib_stat.R | 4 - R/pkg/R/mllib_tree.R | 33 ---- R/pkg/R/mllib_utils.R | 3 - R/pkg/R/schema.R | 7 - R/pkg/R/sparkR.R | 7 - R/pkg/R/stats.R | 6 - R/pkg/R/streaming.R | 9 - R/pkg/R/utils.R | 1 - R/pkg/R/window.R | 4 - 26 files changed, 830 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 41c3c3a89fa72..c4852024c0f49 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -36,7 +36,6 @@ setOldClass("structType") #' @slot sdf A Java object reference to the backing Scala DataFrame #' @seealso \link{createDataFrame}, \link{read.json}, \link{table} #' @seealso \url{https://spark.apache.org/docs/latest/sparkr.html#sparkr-dataframes} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -77,7 +76,6 @@ setWriteMode <- function(write, mode) { write } -#' @export #' @param sdf A Java object reference to the backing Scala DataFrame #' @param isCached TRUE if the SparkDataFrame is cached #' @noRd @@ -97,7 +95,6 @@ dataFrame <- function(sdf, isCached = FALSE) { #' @rdname printSchema #' @name printSchema #' @aliases printSchema,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -123,7 +120,6 @@ setMethod("printSchema", #' @rdname schema #' @name schema #' @aliases schema,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -146,7 +142,6 @@ setMethod("schema", #' @aliases explain,SparkDataFrame-method #' @rdname explain #' @name explain -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -178,7 +173,6 @@ setMethod("explain", #' @rdname isLocal #' @name isLocal #' @aliases isLocal,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -209,7 +203,6 @@ setMethod("isLocal", #' @aliases showDF,SparkDataFrame-method #' @rdname showDF #' @name showDF -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -241,7 +234,6 @@ setMethod("showDF", #' @rdname show #' @aliases show,SparkDataFrame-method #' @name show -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -269,7 +261,6 @@ setMethod("show", "SparkDataFrame", #' @rdname dtypes #' @name dtypes #' @aliases dtypes,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -296,7 +287,6 @@ setMethod("dtypes", #' @rdname columns #' @name columns #' @aliases columns,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -388,7 +378,6 @@ setMethod("colnames<-", #' @aliases coltypes,SparkDataFrame-method #' @name coltypes #' @family SparkDataFrame functions -#' @export #' @examples #'\dontrun{ #' irisDF <- createDataFrame(iris) @@ -445,7 +434,6 @@ setMethod("coltypes", #' @rdname coltypes #' @name coltypes<- #' @aliases coltypes<-,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -494,7 +482,6 @@ setMethod("coltypes<-", #' @rdname createOrReplaceTempView #' @name createOrReplaceTempView #' @aliases createOrReplaceTempView,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -521,7 +508,6 @@ setMethod("createOrReplaceTempView", #' @rdname registerTempTable-deprecated #' @name registerTempTable #' @aliases registerTempTable,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -552,7 +538,6 @@ setMethod("registerTempTable", #' @rdname insertInto #' @name insertInto #' @aliases insertInto,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -580,7 +565,6 @@ setMethod("insertInto", #' @aliases cache,SparkDataFrame-method #' @rdname cache #' @name cache -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -611,7 +595,6 @@ setMethod("cache", #' @rdname persist #' @name persist #' @aliases persist,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -641,7 +624,6 @@ setMethod("persist", #' @rdname unpersist #' @aliases unpersist,SparkDataFrame-method #' @name unpersist -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -669,7 +651,6 @@ setMethod("unpersist", #' @rdname storageLevel #' @aliases storageLevel,SparkDataFrame-method #' @name storageLevel -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -707,7 +688,6 @@ setMethod("storageLevel", #' @name coalesce #' @aliases coalesce,SparkDataFrame-method #' @seealso \link{repartition} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -744,7 +724,6 @@ setMethod("coalesce", #' @name repartition #' @aliases repartition,SparkDataFrame-method #' @seealso \link{coalesce} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -793,7 +772,6 @@ setMethod("repartition", #' @rdname toJSON #' @name toJSON #' @aliases toJSON,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -826,7 +804,6 @@ setMethod("toJSON", #' @rdname write.json #' @name write.json #' @aliases write.json,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -858,7 +835,6 @@ setMethod("write.json", #' @aliases write.orc,SparkDataFrame,character-method #' @rdname write.orc #' @name write.orc -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -890,7 +866,6 @@ setMethod("write.orc", #' @rdname write.parquet #' @name write.parquet #' @aliases write.parquet,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -911,7 +886,6 @@ setMethod("write.parquet", #' @rdname write.parquet #' @name saveAsParquetFile #' @aliases saveAsParquetFile,SparkDataFrame,character-method -#' @export #' @note saveAsParquetFile since 1.4.0 setMethod("saveAsParquetFile", signature(x = "SparkDataFrame", path = "character"), @@ -936,7 +910,6 @@ setMethod("saveAsParquetFile", #' @aliases write.text,SparkDataFrame,character-method #' @rdname write.text #' @name write.text -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -963,7 +936,6 @@ setMethod("write.text", #' @aliases distinct,SparkDataFrame-method #' @rdname distinct #' @name distinct -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1004,7 +976,6 @@ setMethod("unique", #' @aliases sample,SparkDataFrame-method #' @rdname sample #' @name sample -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1061,7 +1032,6 @@ setMethod("sample_frac", #' @rdname nrow #' @name nrow #' @aliases count,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1094,7 +1064,6 @@ setMethod("nrow", #' @rdname ncol #' @name ncol #' @aliases ncol,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1118,7 +1087,6 @@ setMethod("ncol", #' @rdname dim #' @aliases dim,SparkDataFrame-method #' @name dim -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1144,7 +1112,6 @@ setMethod("dim", #' @rdname collect #' @aliases collect,SparkDataFrame-method #' @name collect -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1229,7 +1196,6 @@ setMethod("collect", #' @rdname limit #' @name limit #' @aliases limit,SparkDataFrame,numeric-method -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -1253,7 +1219,6 @@ setMethod("limit", #' @rdname take #' @name take #' @aliases take,SparkDataFrame,numeric-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1282,7 +1247,6 @@ setMethod("take", #' @aliases head,SparkDataFrame-method #' @rdname head #' @name head -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1307,7 +1271,6 @@ setMethod("head", #' @aliases first,SparkDataFrame-method #' @rdname first #' @name first -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1359,7 +1322,6 @@ setMethod("toRDD", #' @aliases groupBy,SparkDataFrame-method #' @rdname groupBy #' @name groupBy -#' @export #' @examples #' \dontrun{ #' # Compute the average for all numeric columns grouped by department. @@ -1401,7 +1363,6 @@ setMethod("group_by", #' @aliases agg,SparkDataFrame-method #' @rdname summarize #' @name agg -#' @export #' @note agg since 1.4.0 setMethod("agg", signature(x = "SparkDataFrame"), @@ -1460,7 +1421,6 @@ setClassUnion("characterOrstructType", c("character", "structType")) #' @aliases dapply,SparkDataFrame,function,characterOrstructType-method #' @name dapply #' @seealso \link{dapplyCollect} -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(iris) @@ -1519,7 +1479,6 @@ setMethod("dapply", #' @aliases dapplyCollect,SparkDataFrame,function-method #' @name dapplyCollect #' @seealso \link{dapply} -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(iris) @@ -1576,7 +1535,6 @@ setMethod("dapplyCollect", #' @rdname gapply #' @name gapply #' @seealso \link{gapplyCollect} -#' @export #' @examples #' #' \dontrun{ @@ -1673,7 +1631,6 @@ setMethod("gapply", #' @rdname gapplyCollect #' @name gapplyCollect #' @seealso \link{gapply} -#' @export #' @examples #' #' \dontrun{ @@ -1947,7 +1904,6 @@ setMethod("[", signature(x = "SparkDataFrame"), #' @param ... currently not used. #' @return A new SparkDataFrame containing only the rows that meet the condition with selected #' columns. -#' @export #' @family SparkDataFrame functions #' @aliases subset,SparkDataFrame-method #' @seealso \link{withColumn} @@ -1992,7 +1948,6 @@ setMethod("subset", signature(x = "SparkDataFrame"), #' If more than one column is assigned in \code{col}, \code{...} #' should be left empty. #' @return A new SparkDataFrame with selected columns. -#' @export #' @family SparkDataFrame functions #' @rdname select #' @aliases select,SparkDataFrame,character-method @@ -2024,7 +1979,6 @@ setMethod("select", signature(x = "SparkDataFrame", col = "character"), }) #' @rdname select -#' @export #' @aliases select,SparkDataFrame,Column-method #' @note select(SparkDataFrame, Column) since 1.4.0 setMethod("select", signature(x = "SparkDataFrame", col = "Column"), @@ -2037,7 +1991,6 @@ setMethod("select", signature(x = "SparkDataFrame", col = "Column"), }) #' @rdname select -#' @export #' @aliases select,SparkDataFrame,list-method #' @note select(SparkDataFrame, list) since 1.4.0 setMethod("select", @@ -2066,7 +2019,6 @@ setMethod("select", #' @aliases selectExpr,SparkDataFrame,character-method #' @rdname selectExpr #' @name selectExpr -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2098,7 +2050,6 @@ setMethod("selectExpr", #' @rdname withColumn #' @name withColumn #' @seealso \link{rename} \link{mutate} \link{subset} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2137,7 +2088,6 @@ setMethod("withColumn", #' @rdname mutate #' @name mutate #' @seealso \link{rename} \link{withColumn} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2208,7 +2158,6 @@ setMethod("mutate", }) #' @param _data a SparkDataFrame. -#' @export #' @rdname mutate #' @aliases transform,SparkDataFrame-method #' @name transform @@ -2232,7 +2181,6 @@ setMethod("transform", #' @name withColumnRenamed #' @aliases withColumnRenamed,SparkDataFrame,character,character-method #' @seealso \link{mutate} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2258,7 +2206,6 @@ setMethod("withColumnRenamed", #' @rdname rename #' @name rename #' @aliases rename,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2304,7 +2251,6 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' @aliases arrange,SparkDataFrame,Column-method #' @rdname arrange #' @name arrange -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2335,7 +2281,6 @@ setMethod("arrange", #' @rdname arrange #' @name arrange #' @aliases arrange,SparkDataFrame,character-method -#' @export #' @note arrange(SparkDataFrame, character) since 1.4.0 setMethod("arrange", signature(x = "SparkDataFrame", col = "character"), @@ -2368,7 +2313,6 @@ setMethod("arrange", #' @rdname arrange #' @aliases orderBy,SparkDataFrame,characterOrColumn-method -#' @export #' @note orderBy(SparkDataFrame, characterOrColumn) since 1.4.0 setMethod("orderBy", signature(x = "SparkDataFrame", col = "characterOrColumn"), @@ -2389,7 +2333,6 @@ setMethod("orderBy", #' @rdname filter #' @name filter #' @family subsetting functions -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2432,7 +2375,6 @@ setMethod("where", #' @aliases dropDuplicates,SparkDataFrame-method #' @rdname dropDuplicates #' @name dropDuplicates -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2481,7 +2423,6 @@ setMethod("dropDuplicates", #' @rdname join #' @name join #' @seealso \link{merge} \link{crossJoin} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2533,7 +2474,6 @@ setMethod("join", #' @rdname crossJoin #' @name crossJoin #' @seealso \link{merge} \link{join} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2581,7 +2521,6 @@ setMethod("crossJoin", #' @aliases merge,SparkDataFrame,SparkDataFrame-method #' @rdname merge #' @seealso \link{join} \link{crossJoin} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2721,7 +2660,6 @@ genAliasesForIntersectedCols <- function(x, intersectedColNames, suffix) { #' @name union #' @aliases union,SparkDataFrame,SparkDataFrame-method #' @seealso \link{rbind} \link{unionByName} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2742,7 +2680,6 @@ setMethod("union", #' @rdname union #' @name unionAll #' @aliases unionAll,SparkDataFrame,SparkDataFrame-method -#' @export #' @note unionAll since 1.4.0 setMethod("unionAll", signature(x = "SparkDataFrame", y = "SparkDataFrame"), @@ -2769,7 +2706,6 @@ setMethod("unionAll", #' @name unionByName #' @aliases unionByName,SparkDataFrame,SparkDataFrame-method #' @seealso \link{rbind} \link{union} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2802,7 +2738,6 @@ setMethod("unionByName", #' @rdname rbind #' @name rbind #' @seealso \link{union} \link{unionByName} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2835,7 +2770,6 @@ setMethod("rbind", #' @aliases intersect,SparkDataFrame,SparkDataFrame-method #' @rdname intersect #' @name intersect -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2863,7 +2797,6 @@ setMethod("intersect", #' @aliases except,SparkDataFrame,SparkDataFrame-method #' @rdname except #' @name except -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2872,7 +2805,6 @@ setMethod("intersect", #' exceptDF <- except(df, df2) #' } #' @rdname except -#' @export #' @note except since 1.4.0 setMethod("except", signature(x = "SparkDataFrame", y = "SparkDataFrame"), @@ -2909,7 +2841,6 @@ setMethod("except", #' @aliases write.df,SparkDataFrame-method #' @rdname write.df #' @name write.df -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2944,7 +2875,6 @@ setMethod("write.df", #' @rdname write.df #' @name saveDF #' @aliases saveDF,SparkDataFrame,character-method -#' @export #' @note saveDF since 1.4.0 setMethod("saveDF", signature(df = "SparkDataFrame", path = "character"), @@ -2978,7 +2908,6 @@ setMethod("saveDF", #' @aliases saveAsTable,SparkDataFrame,character-method #' @rdname saveAsTable #' @name saveAsTable -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3015,7 +2944,6 @@ setMethod("saveAsTable", #' @aliases describe,SparkDataFrame,character-method describe,SparkDataFrame,ANY-method #' @rdname describe #' @name describe -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3071,7 +2999,6 @@ setMethod("describe", #' @rdname summary #' @name summary #' @aliases summary,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3117,7 +3044,6 @@ setMethod("summary", #' @rdname nafunctions #' @aliases dropna,SparkDataFrame-method #' @name dropna -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3148,7 +3074,6 @@ setMethod("dropna", #' @rdname nafunctions #' @name na.omit #' @aliases na.omit,SparkDataFrame-method -#' @export #' @note na.omit since 1.5.0 setMethod("na.omit", signature(object = "SparkDataFrame"), @@ -3168,7 +3093,6 @@ setMethod("na.omit", #' @rdname nafunctions #' @name fillna #' @aliases fillna,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3399,7 +3323,6 @@ setMethod("str", #' @rdname drop #' @name drop #' @aliases drop,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3427,7 +3350,6 @@ setMethod("drop", #' @name drop #' @rdname drop #' @aliases drop,ANY-method -#' @export setMethod("drop", signature(x = "ANY"), function(x) { @@ -3446,7 +3368,6 @@ setMethod("drop", #' @rdname histogram #' @aliases histogram,SparkDataFrame,characterOrColumn-method #' @family SparkDataFrame functions -#' @export #' @examples #' \dontrun{ #' @@ -3582,7 +3503,6 @@ setMethod("histogram", #' @rdname write.jdbc #' @name write.jdbc #' @aliases write.jdbc,SparkDataFrame,character,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3611,7 +3531,6 @@ setMethod("write.jdbc", #' @aliases randomSplit,SparkDataFrame,numeric-method #' @rdname randomSplit #' @name randomSplit -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3645,7 +3564,6 @@ setMethod("randomSplit", #' @aliases getNumPartitions,SparkDataFrame-method #' @rdname getNumPartitions #' @name getNumPartitions -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3672,7 +3590,6 @@ setMethod("getNumPartitions", #' @rdname isStreaming #' @name isStreaming #' @seealso \link{read.stream} \link{write.stream} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3726,7 +3643,6 @@ setMethod("isStreaming", #' @aliases write.stream,SparkDataFrame-method #' @rdname write.stream #' @name write.stream -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3819,7 +3735,6 @@ setMethod("write.stream", #' @rdname checkpoint #' @name checkpoint #' @seealso \link{setCheckpointDir} -#' @export #' @examples #'\dontrun{ #' setCheckpointDir("/checkpoint") @@ -3847,7 +3762,6 @@ setMethod("checkpoint", #' @aliases localCheckpoint,SparkDataFrame-method #' @rdname localCheckpoint #' @name localCheckpoint -#' @export #' @examples #'\dontrun{ #' df <- localCheckpoint(df) @@ -3874,7 +3788,6 @@ setMethod("localCheckpoint", #' @aliases cube,SparkDataFrame-method #' @rdname cube #' @name cube -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(mtcars) @@ -3909,7 +3822,6 @@ setMethod("cube", #' @aliases rollup,SparkDataFrame-method #' @rdname rollup #' @name rollup -#' @export #' @examples #'\dontrun{ #' df <- createDataFrame(mtcars) @@ -3942,7 +3854,6 @@ setMethod("rollup", #' @aliases hint,SparkDataFrame,character-method #' @rdname hint #' @name hint -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(mtcars) @@ -3966,7 +3877,6 @@ setMethod("hint", #' @family SparkDataFrame functions #' @rdname alias #' @name alias -#' @export #' @examples #' \dontrun{ #' df <- alias(createDataFrame(mtcars), "mtcars") @@ -3997,7 +3907,6 @@ setMethod("alias", #' @family SparkDataFrame functions #' @rdname broadcast #' @name broadcast -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(mtcars) @@ -4041,7 +3950,6 @@ setMethod("broadcast", #' @family SparkDataFrame functions #' @rdname withWatermark #' @name withWatermark -#' @export #' @examples #' \dontrun{ #' sparkR.session() diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 9d0a2d5e074e4..ebec0ce3d1920 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -123,7 +123,6 @@ infer_type <- function(x) { #' @return a list of config values with keys as their names #' @rdname sparkR.conf #' @name sparkR.conf -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -163,7 +162,6 @@ sparkR.conf <- function(key, defaultValue) { #' @return a character string of the Spark version #' @rdname sparkR.version #' @name sparkR.version -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -191,7 +189,6 @@ getDefaultSqlSource <- function() { #' limited by length of the list or number of rows of the data.frame #' @return A SparkDataFrame. #' @rdname createDataFrame -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -294,7 +291,6 @@ createDataFrame <- function(x, ...) { #' @rdname createDataFrame #' @aliases createDataFrame -#' @export #' @method as.DataFrame default #' @note as.DataFrame since 1.6.0 as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0, numPartitions = NULL) { @@ -304,7 +300,6 @@ as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0, numPa #' @param ... additional argument(s). #' @rdname createDataFrame #' @aliases as.DataFrame -#' @export as.DataFrame <- function(data, ...) { dispatchFunc("as.DataFrame(data, schema = NULL)", data, ...) } @@ -342,7 +337,6 @@ setMethod("toDF", signature(x = "RDD"), #' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.json -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -371,7 +365,6 @@ read.json <- function(x, ...) { #' @rdname read.json #' @name jsonFile -#' @export #' @method jsonFile default #' @note jsonFile since 1.4.0 jsonFile.default <- function(path) { @@ -423,7 +416,6 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { #' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.orc -#' @export #' @name read.orc #' @note read.orc since 2.0.0 read.orc <- function(path, ...) { @@ -444,7 +436,6 @@ read.orc <- function(path, ...) { #' @param path path of file to read. A vector of multiple paths is allowed. #' @return SparkDataFrame #' @rdname read.parquet -#' @export #' @name read.parquet #' @method read.parquet default #' @note read.parquet since 1.6.0 @@ -466,7 +457,6 @@ read.parquet <- function(x, ...) { #' @param ... argument(s) passed to the method. #' @rdname read.parquet #' @name parquetFile -#' @export #' @method parquetFile default #' @note parquetFile since 1.4.0 parquetFile.default <- function(...) { @@ -490,7 +480,6 @@ parquetFile <- function(x, ...) { #' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.text -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -522,7 +511,6 @@ read.text <- function(x, ...) { #' @param sqlQuery A character vector containing the SQL query #' @return SparkDataFrame #' @rdname sql -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -556,7 +544,6 @@ sql <- function(x, ...) { #' @return SparkDataFrame #' @rdname tableToDF #' @name tableToDF -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -591,7 +578,6 @@ tableToDF <- function(tableName) { #' @rdname read.df #' @name read.df #' @seealso \link{read.json} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -681,7 +667,6 @@ loadDF <- function(x = NULL, ...) { #' @return SparkDataFrame #' @rdname read.jdbc #' @name read.jdbc -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -734,7 +719,6 @@ read.jdbc <- function(url, tableName, #' @rdname read.stream #' @name read.stream #' @seealso \link{write.stream} -#' @export #' @examples #'\dontrun{ #' sparkR.session() diff --git a/R/pkg/R/WindowSpec.R b/R/pkg/R/WindowSpec.R index debc7cbde55e7..ee7f4adf726e6 100644 --- a/R/pkg/R/WindowSpec.R +++ b/R/pkg/R/WindowSpec.R @@ -28,7 +28,6 @@ NULL #' @seealso \link{windowPartitionBy}, \link{windowOrderBy} #' #' @param sws A Java object reference to the backing Scala WindowSpec -#' @export #' @note WindowSpec since 2.0.0 setClass("WindowSpec", slots = list(sws = "jobj")) @@ -44,7 +43,6 @@ windowSpec <- function(sws) { } #' @rdname show -#' @export #' @note show(WindowSpec) since 2.0.0 setMethod("show", "WindowSpec", function(object) { @@ -63,7 +61,6 @@ setMethod("show", "WindowSpec", #' @name partitionBy #' @aliases partitionBy,WindowSpec-method #' @family windowspec_method -#' @export #' @examples #' \dontrun{ #' partitionBy(ws, "col1", "col2") @@ -97,7 +94,6 @@ setMethod("partitionBy", #' @aliases orderBy,WindowSpec,character-method #' @family windowspec_method #' @seealso See \link{arrange} for use in sorting a SparkDataFrame -#' @export #' @examples #' \dontrun{ #' orderBy(ws, "col1", "col2") @@ -113,7 +109,6 @@ setMethod("orderBy", #' @rdname orderBy #' @name orderBy #' @aliases orderBy,WindowSpec,Column-method -#' @export #' @note orderBy(WindowSpec, Column) since 2.0.0 setMethod("orderBy", signature(x = "WindowSpec", col = "Column"), @@ -142,7 +137,6 @@ setMethod("orderBy", #' @aliases rowsBetween,WindowSpec,numeric,numeric-method #' @name rowsBetween #' @family windowspec_method -#' @export #' @examples #' \dontrun{ #' rowsBetween(ws, 0, 3) @@ -174,7 +168,6 @@ setMethod("rowsBetween", #' @aliases rangeBetween,WindowSpec,numeric,numeric-method #' @name rangeBetween #' @family windowspec_method -#' @export #' @examples #' \dontrun{ #' rangeBetween(ws, 0, 3) @@ -202,7 +195,6 @@ setMethod("rangeBetween", #' @name over #' @aliases over,Column,WindowSpec-method #' @family colum_func -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(mtcars) diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R index 398dffc4ab1b4..282f8a6857738 100644 --- a/R/pkg/R/broadcast.R +++ b/R/pkg/R/broadcast.R @@ -32,14 +32,12 @@ # @seealso broadcast # # @param id Id of the backing Spark broadcast variable -# @export setClass("Broadcast", slots = list(id = "character")) # @rdname broadcast-class # @param value Value of the broadcast variable # @param jBroadcastRef reference to the backing Java broadcast object # @param objName name of broadcasted object -# @export Broadcast <- function(id, value, jBroadcastRef, objName) { .broadcastValues[[id]] <- value .broadcastNames[[as.character(objName)]] <- jBroadcastRef @@ -73,7 +71,6 @@ setMethod("value", # @param bcastId The id of broadcast variable to set # @param value The value to be set -# @export setBroadcastValue <- function(bcastId, value) { bcastIdStr <- as.character(bcastId) .broadcastValues[[bcastIdStr]] <- value diff --git a/R/pkg/R/catalog.R b/R/pkg/R/catalog.R index e59a7024333ac..baf4d861fcf86 100644 --- a/R/pkg/R/catalog.R +++ b/R/pkg/R/catalog.R @@ -34,7 +34,6 @@ #' @return A SparkDataFrame. #' @rdname createExternalTable-deprecated #' @seealso \link{createTable} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -71,7 +70,6 @@ createExternalTable <- function(x, ...) { #' @return A SparkDataFrame. #' @rdname createTable #' @seealso \link{createExternalTable} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -110,7 +108,6 @@ createTable <- function(tableName, path = NULL, source = NULL, schema = NULL, .. #' identifier is provided, it refers to a table in the current database. #' @return SparkDataFrame #' @rdname cacheTable -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -140,7 +137,6 @@ cacheTable <- function(x, ...) { #' identifier is provided, it refers to a table in the current database. #' @return SparkDataFrame #' @rdname uncacheTable -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -167,7 +163,6 @@ uncacheTable <- function(x, ...) { #' Removes all cached tables from the in-memory cache. #' #' @rdname clearCache -#' @export #' @examples #' \dontrun{ #' clearCache() @@ -193,7 +188,6 @@ clearCache <- function() { #' @param tableName The name of the SparkSQL table to be dropped. #' @seealso \link{dropTempView} #' @rdname dropTempTable-deprecated -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -225,7 +219,6 @@ dropTempTable <- function(x, ...) { #' @return TRUE if the view is dropped successfully, FALSE otherwise. #' @rdname dropTempView #' @name dropTempView -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -251,7 +244,6 @@ dropTempView <- function(viewName) { #' @return a SparkDataFrame #' @rdname tables #' @seealso \link{listTables} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -276,7 +268,6 @@ tables <- function(x, ...) { #' @param databaseName (optional) name of the database #' @return a list of table names #' @rdname tableNames -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -304,7 +295,6 @@ tableNames <- function(x, ...) { #' @return name of the current default database. #' @rdname currentDatabase #' @name currentDatabase -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -324,7 +314,6 @@ currentDatabase <- function() { #' @param databaseName name of the database #' @rdname setCurrentDatabase #' @name setCurrentDatabase -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -347,7 +336,6 @@ setCurrentDatabase <- function(databaseName) { #' @return a SparkDataFrame of the list of databases. #' @rdname listDatabases #' @name listDatabases -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -370,7 +358,6 @@ listDatabases <- function() { #' @rdname listTables #' @name listTables #' @seealso \link{tables} -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -403,7 +390,6 @@ listTables <- function(databaseName = NULL) { #' @return a SparkDataFrame of the list of column descriptions. #' @rdname listColumns #' @name listColumns -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -433,7 +419,6 @@ listColumns <- function(tableName, databaseName = NULL) { #' @return a SparkDataFrame of the list of function descriptions. #' @rdname listFunctions #' @name listFunctions -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -463,7 +448,6 @@ listFunctions <- function(databaseName = NULL) { #' identifier is provided, it refers to a table in the current database. #' @rdname recoverPartitions #' @name recoverPartitions -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -490,7 +474,6 @@ recoverPartitions <- function(tableName) { #' identifier is provided, it refers to a table in the current database. #' @rdname refreshTable #' @name refreshTable -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -512,7 +495,6 @@ refreshTable <- function(tableName) { #' @param path the path of the data source. #' @rdname refreshByPath #' @name refreshByPath -#' @export #' @examples #' \dontrun{ #' sparkR.session() diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 3095adb918b67..9727efc354f10 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -29,7 +29,6 @@ setOldClass("jobj") #' @rdname column #' #' @slot jc reference to JVM SparkDataFrame column -#' @export #' @note Column since 1.4.0 setClass("Column", slots = list(jc = "jobj")) @@ -56,7 +55,6 @@ setMethod("column", #' @rdname show #' @name show #' @aliases show,Column-method -#' @export #' @note show(Column) since 1.4.0 setMethod("show", "Column", function(object) { @@ -134,7 +132,6 @@ createMethods() #' @name alias #' @aliases alias,Column-method #' @family colum_func -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(iris) @@ -270,7 +267,6 @@ setMethod("cast", #' @name %in% #' @aliases %in%,Column-method #' @return A matched values as a result of comparing with given values. -#' @export #' @examples #' \dontrun{ #' filter(df, "age in (10, 30)") @@ -296,7 +292,6 @@ setMethod("%in%", #' @name otherwise #' @family colum_func #' @aliases otherwise,Column-method -#' @export #' @note otherwise since 1.5.0 setMethod("otherwise", signature(x = "Column", value = "ANY"), @@ -318,7 +313,6 @@ setMethod("otherwise", #' @rdname eq_null_safe #' @name %<=>% #' @aliases %<=>%,Column-method -#' @export #' @examples #' \dontrun{ #' df1 <- createDataFrame(data.frame( @@ -348,7 +342,6 @@ setMethod("%<=>%", #' @rdname not #' @name not #' @aliases !,Column-method -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(data.frame(x = c(-1, 0, 1))) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 443c2ff8f9ace..8ec727dd042bc 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -308,7 +308,6 @@ setCheckpointDirSC <- function(sc, dirName) { #' @rdname spark.addFile #' @param path The path of the file to be added #' @param recursive Whether to add files recursively from the path. Default is FALSE. -#' @export #' @examples #'\dontrun{ #' spark.addFile("~/myfile") @@ -323,7 +322,6 @@ spark.addFile <- function(path, recursive = FALSE) { #' #' @rdname spark.getSparkFilesRootDirectory #' @return the root directory that contains files added through spark.addFile -#' @export #' @examples #'\dontrun{ #' spark.getSparkFilesRootDirectory() @@ -344,7 +342,6 @@ spark.getSparkFilesRootDirectory <- function() { # nolint #' @rdname spark.getSparkFiles #' @param fileName The name of the file added through spark.addFile #' @return the absolute path of a file added through spark.addFile. -#' @export #' @examples #'\dontrun{ #' spark.getSparkFiles("myfile") @@ -391,7 +388,6 @@ spark.getSparkFiles <- function(fileName) { #' @param list the list of elements #' @param func a function that takes one argument. #' @return a list of results (the exact type being determined by the function) -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -412,7 +408,6 @@ spark.lapply <- function(list, func) { #' #' @rdname setLogLevel #' @param level New log level -#' @export #' @examples #'\dontrun{ #' setLogLevel("ERROR") @@ -431,7 +426,6 @@ setLogLevel <- function(level) { #' @rdname setCheckpointDir #' @param directory Directory path to checkpoint to #' @seealso \link{checkpoint} -#' @export #' @examples #'\dontrun{ #' setCheckpointDir("/checkpoint") diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 29ee146ab14f9..a527426b19674 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -244,7 +244,6 @@ NULL #' If the parameter is a Column, it is returned unchanged. #' #' @rdname column_nonaggregate_functions -#' @export #' @aliases lit lit,ANY-method #' @examples #' @@ -267,7 +266,6 @@ setMethod("lit", signature("ANY"), #' \code{abs}: Computes the absolute value. #' #' @rdname column_math_functions -#' @export #' @aliases abs abs,Column-method #' @note abs since 1.5.0 setMethod("abs", @@ -282,7 +280,6 @@ setMethod("abs", #' as if computed by \code{java.lang.Math.acos()} #' #' @rdname column_math_functions -#' @export #' @aliases acos acos,Column-method #' @note acos since 1.5.0 setMethod("acos", @@ -296,7 +293,6 @@ setMethod("acos", #' \code{approxCountDistinct}: Returns the approximate number of distinct items in a group. #' #' @rdname column_aggregate_functions -#' @export #' @aliases approxCountDistinct approxCountDistinct,Column-method #' @examples #' @@ -319,7 +315,6 @@ setMethod("approxCountDistinct", #' and returns the result as an int column. #' #' @rdname column_string_functions -#' @export #' @aliases ascii ascii,Column-method #' @examples #' @@ -338,7 +333,6 @@ setMethod("ascii", #' as if computed by \code{java.lang.Math.asin()} #' #' @rdname column_math_functions -#' @export #' @aliases asin asin,Column-method #' @note asin since 1.5.0 setMethod("asin", @@ -353,7 +347,6 @@ setMethod("asin", #' as if computed by \code{java.lang.Math.atan()} #' #' @rdname column_math_functions -#' @export #' @aliases atan atan,Column-method #' @note atan since 1.5.0 setMethod("atan", @@ -370,7 +363,6 @@ setMethod("atan", #' @rdname avg #' @name avg #' @family aggregate functions -#' @export #' @aliases avg,Column-method #' @examples \dontrun{avg(df$c)} #' @note avg since 1.4.0 @@ -386,7 +378,6 @@ setMethod("avg", #' a string column. This is the reverse of unbase64. #' #' @rdname column_string_functions -#' @export #' @aliases base64 base64,Column-method #' @examples #' @@ -410,7 +401,6 @@ setMethod("base64", #' of the given long column. For example, bin("12") returns "1100". #' #' @rdname column_math_functions -#' @export #' @aliases bin bin,Column-method #' @note bin since 1.5.0 setMethod("bin", @@ -424,7 +414,6 @@ setMethod("bin", #' \code{bitwiseNOT}: Computes bitwise NOT. #' #' @rdname column_nonaggregate_functions -#' @export #' @aliases bitwiseNOT bitwiseNOT,Column-method #' @examples #' @@ -442,7 +431,6 @@ setMethod("bitwiseNOT", #' \code{cbrt}: Computes the cube-root of the given value. #' #' @rdname column_math_functions -#' @export #' @aliases cbrt cbrt,Column-method #' @note cbrt since 1.4.0 setMethod("cbrt", @@ -456,7 +444,6 @@ setMethod("cbrt", #' \code{ceil}: Computes the ceiling of the given value. #' #' @rdname column_math_functions -#' @export #' @aliases ceil ceil,Column-method #' @note ceil since 1.5.0 setMethod("ceil", @@ -471,7 +458,6 @@ setMethod("ceil", #' #' @rdname column_math_functions #' @aliases ceiling ceiling,Column-method -#' @export #' @note ceiling since 1.5.0 setMethod("ceiling", signature(x = "Column"), @@ -483,7 +469,6 @@ setMethod("ceiling", #' \code{coalesce}: Returns the first column that is not NA, or NA if all inputs are. #' #' @rdname column_nonaggregate_functions -#' @export #' @aliases coalesce,Column-method #' @note coalesce(Column) since 2.1.1 setMethod("coalesce", @@ -514,7 +499,6 @@ col <- function(x) { #' @rdname column #' @name column #' @family non-aggregate functions -#' @export #' @aliases column,character-method #' @examples \dontrun{column("name")} #' @note column since 1.6.0 @@ -533,7 +517,6 @@ setMethod("column", #' @rdname corr #' @name corr #' @family aggregate functions -#' @export #' @aliases corr,Column-method #' @examples #' \dontrun{ @@ -557,7 +540,6 @@ setMethod("corr", signature(x = "Column"), #' @rdname cov #' @name cov #' @family aggregate functions -#' @export #' @aliases cov,characterOrColumn-method #' @examples #' \dontrun{ @@ -598,7 +580,6 @@ setMethod("covar_samp", signature(col1 = "characterOrColumn", col2 = "characterO #' #' @rdname cov #' @name covar_pop -#' @export #' @aliases covar_pop,characterOrColumn,characterOrColumn-method #' @note covar_pop since 2.0.0 setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOrColumn"), @@ -618,7 +599,6 @@ setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOr #' #' @rdname column_math_functions #' @aliases cos cos,Column-method -#' @export #' @note cos since 1.5.0 setMethod("cos", signature(x = "Column"), @@ -633,7 +613,6 @@ setMethod("cos", #' #' @rdname column_math_functions #' @aliases cosh cosh,Column-method -#' @export #' @note cosh since 1.5.0 setMethod("cosh", signature(x = "Column"), @@ -651,7 +630,6 @@ setMethod("cosh", #' @name count #' @family aggregate functions #' @aliases count,Column-method -#' @export #' @examples \dontrun{count(df$c)} #' @note count since 1.4.0 setMethod("count", @@ -667,7 +645,6 @@ setMethod("count", #' #' @rdname column_misc_functions #' @aliases crc32 crc32,Column-method -#' @export #' @note crc32 since 1.5.0 setMethod("crc32", signature(x = "Column"), @@ -682,7 +659,6 @@ setMethod("crc32", #' #' @rdname column_misc_functions #' @aliases hash hash,Column-method -#' @export #' @note hash since 2.0.0 setMethod("hash", signature(x = "Column"), @@ -701,7 +677,6 @@ setMethod("hash", #' #' @rdname column_datetime_functions #' @aliases dayofmonth dayofmonth,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -723,7 +698,6 @@ setMethod("dayofmonth", #' #' @rdname column_datetime_functions #' @aliases dayofweek dayofweek,Column-method -#' @export #' @note dayofweek since 2.3.0 setMethod("dayofweek", signature(x = "Column"), @@ -738,7 +712,6 @@ setMethod("dayofweek", #' #' @rdname column_datetime_functions #' @aliases dayofyear dayofyear,Column-method -#' @export #' @note dayofyear since 1.5.0 setMethod("dayofyear", signature(x = "Column"), @@ -756,7 +729,6 @@ setMethod("dayofyear", #' #' @rdname column_string_functions #' @aliases decode decode,Column,character-method -#' @export #' @note decode since 1.6.0 setMethod("decode", signature(x = "Column", charset = "character"), @@ -771,7 +743,6 @@ setMethod("decode", #' #' @rdname column_string_functions #' @aliases encode encode,Column,character-method -#' @export #' @note encode since 1.6.0 setMethod("encode", signature(x = "Column", charset = "character"), @@ -785,7 +756,6 @@ setMethod("encode", #' #' @rdname column_math_functions #' @aliases exp exp,Column-method -#' @export #' @note exp since 1.5.0 setMethod("exp", signature(x = "Column"), @@ -799,7 +769,6 @@ setMethod("exp", #' #' @rdname column_math_functions #' @aliases expm1 expm1,Column-method -#' @export #' @note expm1 since 1.5.0 setMethod("expm1", signature(x = "Column"), @@ -813,7 +782,6 @@ setMethod("expm1", #' #' @rdname column_math_functions #' @aliases factorial factorial,Column-method -#' @export #' @note factorial since 1.5.0 setMethod("factorial", signature(x = "Column"), @@ -836,7 +804,6 @@ setMethod("factorial", #' @name first #' @aliases first,characterOrColumn-method #' @family aggregate functions -#' @export #' @examples #' \dontrun{ #' first(df$c) @@ -860,7 +827,6 @@ setMethod("first", #' #' @rdname column_math_functions #' @aliases floor floor,Column-method -#' @export #' @note floor since 1.5.0 setMethod("floor", signature(x = "Column"), @@ -874,7 +840,6 @@ setMethod("floor", #' #' @rdname column_math_functions #' @aliases hex hex,Column-method -#' @export #' @note hex since 1.5.0 setMethod("hex", signature(x = "Column"), @@ -888,7 +853,6 @@ setMethod("hex", #' #' @rdname column_datetime_functions #' @aliases hour hour,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -911,7 +875,6 @@ setMethod("hour", #' #' @rdname column_string_functions #' @aliases initcap initcap,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -946,7 +909,6 @@ setMethod("isnan", #' #' @rdname column_nonaggregate_functions #' @aliases is.nan is.nan,Column-method -#' @export #' @note is.nan since 2.0.0 setMethod("is.nan", signature(x = "Column"), @@ -959,7 +921,6 @@ setMethod("is.nan", #' #' @rdname column_aggregate_functions #' @aliases kurtosis kurtosis,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -988,7 +949,6 @@ setMethod("kurtosis", #' @name last #' @aliases last,characterOrColumn-method #' @family aggregate functions -#' @export #' @examples #' \dontrun{ #' last(df$c) @@ -1014,7 +974,6 @@ setMethod("last", #' #' @rdname column_datetime_functions #' @aliases last_day last_day,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1034,7 +993,6 @@ setMethod("last_day", #' #' @rdname column_string_functions #' @aliases length length,Column-method -#' @export #' @note length since 1.5.0 setMethod("length", signature(x = "Column"), @@ -1048,7 +1006,6 @@ setMethod("length", #' #' @rdname column_math_functions #' @aliases log log,Column-method -#' @export #' @note log since 1.5.0 setMethod("log", signature(x = "Column"), @@ -1062,7 +1019,6 @@ setMethod("log", #' #' @rdname column_math_functions #' @aliases log10 log10,Column-method -#' @export #' @note log10 since 1.5.0 setMethod("log10", signature(x = "Column"), @@ -1076,7 +1032,6 @@ setMethod("log10", #' #' @rdname column_math_functions #' @aliases log1p log1p,Column-method -#' @export #' @note log1p since 1.5.0 setMethod("log1p", signature(x = "Column"), @@ -1090,7 +1045,6 @@ setMethod("log1p", #' #' @rdname column_math_functions #' @aliases log2 log2,Column-method -#' @export #' @note log2 since 1.5.0 setMethod("log2", signature(x = "Column"), @@ -1104,7 +1058,6 @@ setMethod("log2", #' #' @rdname column_string_functions #' @aliases lower lower,Column-method -#' @export #' @note lower since 1.4.0 setMethod("lower", signature(x = "Column"), @@ -1119,7 +1072,6 @@ setMethod("lower", #' #' @rdname column_string_functions #' @aliases ltrim ltrim,Column,missing-method -#' @export #' @examples #' #' \dontrun{ @@ -1143,7 +1095,6 @@ setMethod("ltrim", #' @param trimString a character string to trim with #' @rdname column_string_functions #' @aliases ltrim,Column,character-method -#' @export #' @note ltrim(Column, character) since 2.3.0 setMethod("ltrim", signature(x = "Column", trimString = "character"), @@ -1171,7 +1122,6 @@ setMethod("max", #' #' @rdname column_misc_functions #' @aliases md5 md5,Column-method -#' @export #' @note md5 since 1.5.0 setMethod("md5", signature(x = "Column"), @@ -1185,7 +1135,6 @@ setMethod("md5", #' #' @rdname column_aggregate_functions #' @aliases mean mean,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1211,7 +1160,6 @@ setMethod("mean", #' #' @rdname column_aggregate_functions #' @aliases min min,Column-method -#' @export #' @note min since 1.5.0 setMethod("min", signature(x = "Column"), @@ -1225,7 +1173,6 @@ setMethod("min", #' #' @rdname column_datetime_functions #' @aliases minute minute,Column-method -#' @export #' @note minute since 1.5.0 setMethod("minute", signature(x = "Column"), @@ -1248,7 +1195,6 @@ setMethod("minute", #' #' @rdname column_nonaggregate_functions #' @aliases monotonically_increasing_id monotonically_increasing_id,missing-method -#' @export #' @examples #' #' \dontrun{head(select(df, monotonically_increasing_id()))} @@ -1264,7 +1210,6 @@ setMethod("monotonically_increasing_id", #' #' @rdname column_datetime_functions #' @aliases month month,Column-method -#' @export #' @note month since 1.5.0 setMethod("month", signature(x = "Column"), @@ -1278,7 +1223,6 @@ setMethod("month", #' #' @rdname column_nonaggregate_functions #' @aliases negate negate,Column-method -#' @export #' @note negate since 1.5.0 setMethod("negate", signature(x = "Column"), @@ -1292,7 +1236,6 @@ setMethod("negate", #' #' @rdname column_datetime_functions #' @aliases quarter quarter,Column-method -#' @export #' @note quarter since 1.5.0 setMethod("quarter", signature(x = "Column"), @@ -1306,7 +1249,6 @@ setMethod("quarter", #' #' @rdname column_string_functions #' @aliases reverse reverse,Column-method -#' @export #' @note reverse since 1.5.0 setMethod("reverse", signature(x = "Column"), @@ -1321,7 +1263,6 @@ setMethod("reverse", #' #' @rdname column_math_functions #' @aliases rint rint,Column-method -#' @export #' @note rint since 1.5.0 setMethod("rint", signature(x = "Column"), @@ -1336,7 +1277,6 @@ setMethod("rint", #' #' @rdname column_math_functions #' @aliases round round,Column-method -#' @export #' @note round since 1.5.0 setMethod("round", signature(x = "Column"), @@ -1356,7 +1296,6 @@ setMethod("round", #' to the left of the decimal point when \code{scale} < 0. #' @rdname column_math_functions #' @aliases bround bround,Column-method -#' @export #' @note bround since 2.0.0 setMethod("bround", signature(x = "Column"), @@ -1371,7 +1310,6 @@ setMethod("bround", #' #' @rdname column_string_functions #' @aliases rtrim rtrim,Column,missing-method -#' @export #' @note rtrim since 1.5.0 setMethod("rtrim", signature(x = "Column", trimString = "missing"), @@ -1382,7 +1320,6 @@ setMethod("rtrim", #' @rdname column_string_functions #' @aliases rtrim,Column,character-method -#' @export #' @note rtrim(Column, character) since 2.3.0 setMethod("rtrim", signature(x = "Column", trimString = "character"), @@ -1396,7 +1333,6 @@ setMethod("rtrim", #' #' @rdname column_aggregate_functions #' @aliases sd sd,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1414,7 +1350,6 @@ setMethod("sd", #' #' @rdname column_datetime_functions #' @aliases second second,Column-method -#' @export #' @note second since 1.5.0 setMethod("second", signature(x = "Column"), @@ -1429,7 +1364,6 @@ setMethod("second", #' #' @rdname column_misc_functions #' @aliases sha1 sha1,Column-method -#' @export #' @note sha1 since 1.5.0 setMethod("sha1", signature(x = "Column"), @@ -1443,7 +1377,6 @@ setMethod("sha1", #' #' @rdname column_math_functions #' @aliases signum signum,Column-method -#' @export #' @note signum since 1.5.0 setMethod("signum", signature(x = "Column"), @@ -1457,7 +1390,6 @@ setMethod("signum", #' #' @rdname column_math_functions #' @aliases sign sign,Column-method -#' @export #' @note sign since 1.5.0 setMethod("sign", signature(x = "Column"), function(x) { @@ -1470,7 +1402,6 @@ setMethod("sign", signature(x = "Column"), #' #' @rdname column_math_functions #' @aliases sin sin,Column-method -#' @export #' @note sin since 1.5.0 setMethod("sin", signature(x = "Column"), @@ -1485,7 +1416,6 @@ setMethod("sin", #' #' @rdname column_math_functions #' @aliases sinh sinh,Column-method -#' @export #' @note sinh since 1.5.0 setMethod("sinh", signature(x = "Column"), @@ -1499,7 +1429,6 @@ setMethod("sinh", #' #' @rdname column_aggregate_functions #' @aliases skewness skewness,Column-method -#' @export #' @note skewness since 1.6.0 setMethod("skewness", signature(x = "Column"), @@ -1513,7 +1442,6 @@ setMethod("skewness", #' #' @rdname column_string_functions #' @aliases soundex soundex,Column-method -#' @export #' @note soundex since 1.5.0 setMethod("soundex", signature(x = "Column"), @@ -1530,7 +1458,6 @@ setMethod("soundex", #' #' @rdname column_nonaggregate_functions #' @aliases spark_partition_id spark_partition_id,missing-method -#' @export #' @examples #' #' \dontrun{head(select(df, spark_partition_id()))} @@ -1560,7 +1487,6 @@ setMethod("stddev", #' #' @rdname column_aggregate_functions #' @aliases stddev_pop stddev_pop,Column-method -#' @export #' @note stddev_pop since 1.6.0 setMethod("stddev_pop", signature(x = "Column"), @@ -1574,7 +1500,6 @@ setMethod("stddev_pop", #' #' @rdname column_aggregate_functions #' @aliases stddev_samp stddev_samp,Column-method -#' @export #' @note stddev_samp since 1.6.0 setMethod("stddev_samp", signature(x = "Column"), @@ -1588,7 +1513,6 @@ setMethod("stddev_samp", #' #' @rdname column_nonaggregate_functions #' @aliases struct struct,characterOrColumn-method -#' @export #' @examples #' #' \dontrun{ @@ -1614,7 +1538,6 @@ setMethod("struct", #' #' @rdname column_math_functions #' @aliases sqrt sqrt,Column-method -#' @export #' @note sqrt since 1.5.0 setMethod("sqrt", signature(x = "Column"), @@ -1628,7 +1551,6 @@ setMethod("sqrt", #' #' @rdname column_aggregate_functions #' @aliases sum sum,Column-method -#' @export #' @note sum since 1.5.0 setMethod("sum", signature(x = "Column"), @@ -1642,7 +1564,6 @@ setMethod("sum", #' #' @rdname column_aggregate_functions #' @aliases sumDistinct sumDistinct,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1663,7 +1584,6 @@ setMethod("sumDistinct", #' #' @rdname column_math_functions #' @aliases tan tan,Column-method -#' @export #' @note tan since 1.5.0 setMethod("tan", signature(x = "Column"), @@ -1678,7 +1598,6 @@ setMethod("tan", #' #' @rdname column_math_functions #' @aliases tanh tanh,Column-method -#' @export #' @note tanh since 1.5.0 setMethod("tanh", signature(x = "Column"), @@ -1693,7 +1612,6 @@ setMethod("tanh", #' #' @rdname column_math_functions #' @aliases toDegrees toDegrees,Column-method -#' @export #' @note toDegrees since 1.4.0 setMethod("toDegrees", signature(x = "Column"), @@ -1708,7 +1626,6 @@ setMethod("toDegrees", #' #' @rdname column_math_functions #' @aliases toRadians toRadians,Column-method -#' @export #' @note toRadians since 1.4.0 setMethod("toRadians", signature(x = "Column"), @@ -1728,7 +1645,6 @@ setMethod("toRadians", #' #' @rdname column_datetime_functions #' @aliases to_date to_date,Column,missing-method -#' @export #' @examples #' #' \dontrun{ @@ -1749,7 +1665,6 @@ setMethod("to_date", #' @rdname column_datetime_functions #' @aliases to_date,Column,character-method -#' @export #' @note to_date(Column, character) since 2.2.0 setMethod("to_date", signature(x = "Column", format = "character"), @@ -1765,7 +1680,6 @@ setMethod("to_date", #' #' @rdname column_collection_functions #' @aliases to_json to_json,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1803,7 +1717,6 @@ setMethod("to_json", signature(x = "Column"), #' #' @rdname column_datetime_functions #' @aliases to_timestamp to_timestamp,Column,missing-method -#' @export #' @note to_timestamp(Column) since 2.2.0 setMethod("to_timestamp", signature(x = "Column", format = "missing"), @@ -1814,7 +1727,6 @@ setMethod("to_timestamp", #' @rdname column_datetime_functions #' @aliases to_timestamp,Column,character-method -#' @export #' @note to_timestamp(Column, character) since 2.2.0 setMethod("to_timestamp", signature(x = "Column", format = "character"), @@ -1829,7 +1741,6 @@ setMethod("to_timestamp", #' #' @rdname column_string_functions #' @aliases trim trim,Column,missing-method -#' @export #' @note trim since 1.5.0 setMethod("trim", signature(x = "Column", trimString = "missing"), @@ -1840,7 +1751,6 @@ setMethod("trim", #' @rdname column_string_functions #' @aliases trim,Column,character-method -#' @export #' @note trim(Column, character) since 2.3.0 setMethod("trim", signature(x = "Column", trimString = "character"), @@ -1855,7 +1765,6 @@ setMethod("trim", #' #' @rdname column_string_functions #' @aliases unbase64 unbase64,Column-method -#' @export #' @note unbase64 since 1.5.0 setMethod("unbase64", signature(x = "Column"), @@ -1870,7 +1779,6 @@ setMethod("unbase64", #' #' @rdname column_math_functions #' @aliases unhex unhex,Column-method -#' @export #' @note unhex since 1.5.0 setMethod("unhex", signature(x = "Column"), @@ -1884,7 +1792,6 @@ setMethod("unhex", #' #' @rdname column_string_functions #' @aliases upper upper,Column-method -#' @export #' @note upper since 1.4.0 setMethod("upper", signature(x = "Column"), @@ -1898,7 +1805,6 @@ setMethod("upper", #' #' @rdname column_aggregate_functions #' @aliases var var,Column-method -#' @export #' @examples #' #'\dontrun{ @@ -1913,7 +1819,6 @@ setMethod("var", #' @rdname column_aggregate_functions #' @aliases variance variance,Column-method -#' @export #' @note variance since 1.6.0 setMethod("variance", signature(x = "Column"), @@ -1927,7 +1832,6 @@ setMethod("variance", #' #' @rdname column_aggregate_functions #' @aliases var_pop var_pop,Column-method -#' @export #' @note var_pop since 1.5.0 setMethod("var_pop", signature(x = "Column"), @@ -1941,7 +1845,6 @@ setMethod("var_pop", #' #' @rdname column_aggregate_functions #' @aliases var_samp var_samp,Column-method -#' @export #' @note var_samp since 1.6.0 setMethod("var_samp", signature(x = "Column"), @@ -1955,7 +1858,6 @@ setMethod("var_samp", #' #' @rdname column_datetime_functions #' @aliases weekofyear weekofyear,Column-method -#' @export #' @note weekofyear since 1.5.0 setMethod("weekofyear", signature(x = "Column"), @@ -1969,7 +1871,6 @@ setMethod("weekofyear", #' #' @rdname column_datetime_functions #' @aliases year year,Column-method -#' @export #' @note year since 1.5.0 setMethod("year", signature(x = "Column"), @@ -1985,7 +1886,6 @@ setMethod("year", #' #' @rdname column_math_functions #' @aliases atan2 atan2,Column-method -#' @export #' @note atan2 since 1.5.0 setMethod("atan2", signature(y = "Column"), function(y, x) { @@ -2001,7 +1901,6 @@ setMethod("atan2", signature(y = "Column"), #' #' @rdname column_datetime_diff_functions #' @aliases datediff datediff,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2025,7 +1924,6 @@ setMethod("datediff", signature(y = "Column"), #' #' @rdname column_math_functions #' @aliases hypot hypot,Column-method -#' @export #' @note hypot since 1.4.0 setMethod("hypot", signature(y = "Column"), function(y, x) { @@ -2041,7 +1939,6 @@ setMethod("hypot", signature(y = "Column"), #' #' @rdname column_string_functions #' @aliases levenshtein levenshtein,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2064,7 +1961,6 @@ setMethod("levenshtein", signature(y = "Column"), #' #' @rdname column_datetime_diff_functions #' @aliases months_between months_between,Column-method -#' @export #' @note months_between since 1.5.0 setMethod("months_between", signature(y = "Column"), function(y, x) { @@ -2082,7 +1978,6 @@ setMethod("months_between", signature(y = "Column"), #' #' @rdname column_nonaggregate_functions #' @aliases nanvl nanvl,Column-method -#' @export #' @note nanvl since 1.5.0 setMethod("nanvl", signature(y = "Column"), function(y, x) { @@ -2099,7 +1994,6 @@ setMethod("nanvl", signature(y = "Column"), #' #' @rdname column_math_functions #' @aliases pmod pmod,Column-method -#' @export #' @note pmod since 1.5.0 setMethod("pmod", signature(y = "Column"), function(y, x) { @@ -2114,7 +2008,6 @@ setMethod("pmod", signature(y = "Column"), #' #' @rdname column_aggregate_functions #' @aliases approxCountDistinct,Column-method -#' @export #' @note approxCountDistinct(Column, numeric) since 1.4.0 setMethod("approxCountDistinct", signature(x = "Column"), @@ -2128,7 +2021,6 @@ setMethod("approxCountDistinct", #' #' @rdname column_aggregate_functions #' @aliases countDistinct countDistinct,Column-method -#' @export #' @note countDistinct since 1.4.0 setMethod("countDistinct", signature(x = "Column"), @@ -2148,7 +2040,6 @@ setMethod("countDistinct", #' #' @rdname column_string_functions #' @aliases concat concat,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2177,7 +2068,6 @@ setMethod("concat", #' #' @rdname column_nonaggregate_functions #' @aliases greatest greatest,Column-method -#' @export #' @note greatest since 1.5.0 setMethod("greatest", signature(x = "Column"), @@ -2197,7 +2087,6 @@ setMethod("greatest", #' #' @rdname column_nonaggregate_functions #' @aliases least least,Column-method -#' @export #' @note least since 1.5.0 setMethod("least", signature(x = "Column"), @@ -2216,7 +2105,6 @@ setMethod("least", #' #' @rdname column_aggregate_functions #' @aliases n_distinct n_distinct,Column-method -#' @export #' @note n_distinct since 1.4.0 setMethod("n_distinct", signature(x = "Column"), function(x, ...) { @@ -2226,7 +2114,6 @@ setMethod("n_distinct", signature(x = "Column"), #' @rdname count #' @name n #' @aliases n,Column-method -#' @export #' @examples \dontrun{n(df$c)} #' @note n since 1.4.0 setMethod("n", signature(x = "Column"), @@ -2245,7 +2132,6 @@ setMethod("n", signature(x = "Column"), #' @rdname column_datetime_diff_functions #' #' @aliases date_format date_format,Column,character-method -#' @export #' @note date_format since 1.5.0 setMethod("date_format", signature(y = "Column", x = "character"), function(y, x) { @@ -2263,7 +2149,6 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' Since Spark 2.3, the DDL-formatted string is also supported for the schema. #' @param as.json.array indicating if input string is JSON array of objects or a single object. #' @aliases from_json from_json,Column,characterOrstructType-method -#' @export #' @examples #' #' \dontrun{ @@ -2306,7 +2191,6 @@ setMethod("from_json", signature(x = "Column", schema = "characterOrstructType") #' @rdname column_datetime_diff_functions #' #' @aliases from_utc_timestamp from_utc_timestamp,Column,character-method -#' @export #' @examples #' #' \dontrun{ @@ -2328,7 +2212,6 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), #' #' @rdname column_string_functions #' @aliases instr instr,Column,character-method -#' @export #' @examples #' #' \dontrun{ @@ -2351,7 +2234,6 @@ setMethod("instr", signature(y = "Column", x = "character"), #' #' @rdname column_datetime_diff_functions #' @aliases next_day next_day,Column,character-method -#' @export #' @note next_day since 1.5.0 setMethod("next_day", signature(y = "Column", x = "character"), function(y, x) { @@ -2366,7 +2248,6 @@ setMethod("next_day", signature(y = "Column", x = "character"), #' #' @rdname column_datetime_diff_functions #' @aliases to_utc_timestamp to_utc_timestamp,Column,character-method -#' @export #' @note to_utc_timestamp since 1.5.0 setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { @@ -2379,7 +2260,6 @@ setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), #' #' @rdname column_datetime_diff_functions #' @aliases add_months add_months,Column,numeric-method -#' @export #' @examples #' #' \dontrun{ @@ -2400,7 +2280,6 @@ setMethod("add_months", signature(y = "Column", x = "numeric"), #' #' @rdname column_datetime_diff_functions #' @aliases date_add date_add,Column,numeric-method -#' @export #' @note date_add since 1.5.0 setMethod("date_add", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2414,7 +2293,6 @@ setMethod("date_add", signature(y = "Column", x = "numeric"), #' @rdname column_datetime_diff_functions #' #' @aliases date_sub date_sub,Column,numeric-method -#' @export #' @note date_sub since 1.5.0 setMethod("date_sub", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2431,7 +2309,6 @@ setMethod("date_sub", signature(y = "Column", x = "numeric"), #' #' @rdname column_string_functions #' @aliases format_number format_number,Column,numeric-method -#' @export #' @examples #' #' \dontrun{ @@ -2454,7 +2331,6 @@ setMethod("format_number", signature(y = "Column", x = "numeric"), #' #' @rdname column_misc_functions #' @aliases sha2 sha2,Column,numeric-method -#' @export #' @note sha2 since 1.5.0 setMethod("sha2", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2468,7 +2344,6 @@ setMethod("sha2", signature(y = "Column", x = "numeric"), #' #' @rdname column_math_functions #' @aliases shiftLeft shiftLeft,Column,numeric-method -#' @export #' @note shiftLeft since 1.5.0 setMethod("shiftLeft", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2484,7 +2359,6 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), #' #' @rdname column_math_functions #' @aliases shiftRight shiftRight,Column,numeric-method -#' @export #' @note shiftRight since 1.5.0 setMethod("shiftRight", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2500,7 +2374,6 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"), #' #' @rdname column_math_functions #' @aliases shiftRightUnsigned shiftRightUnsigned,Column,numeric-method -#' @export #' @note shiftRightUnsigned since 1.5.0 setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2517,7 +2390,6 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' @param sep separator to use. #' @rdname column_string_functions #' @aliases concat_ws concat_ws,character,Column-method -#' @export #' @note concat_ws since 1.5.0 setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { @@ -2533,7 +2405,6 @@ setMethod("concat_ws", signature(sep = "character", x = "Column"), #' @param toBase base to convert to. #' @rdname column_math_functions #' @aliases conv conv,Column,numeric,numeric-method -#' @export #' @note conv since 1.5.0 setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeric"), function(x, fromBase, toBase) { @@ -2551,7 +2422,6 @@ setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeri #' #' @rdname column_nonaggregate_functions #' @aliases expr expr,character-method -#' @export #' @note expr since 1.5.0 setMethod("expr", signature(x = "character"), function(x) { @@ -2566,7 +2436,6 @@ setMethod("expr", signature(x = "character"), #' @param format a character object of format strings. #' @rdname column_string_functions #' @aliases format_string format_string,character,Column-method -#' @export #' @note format_string since 1.5.0 setMethod("format_string", signature(format = "character", x = "Column"), function(format, x, ...) { @@ -2587,7 +2456,6 @@ setMethod("format_string", signature(format = "character", x = "Column"), #' @rdname column_datetime_functions #' #' @aliases from_unixtime from_unixtime,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2629,7 +2497,6 @@ setMethod("from_unixtime", signature(x = "Column"), #' \code{startTime} as \code{"15 minutes"}. #' @rdname column_datetime_functions #' @aliases window window,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2680,7 +2547,6 @@ setMethod("window", signature(x = "Column"), #' @param pos start position of search. #' @rdname column_string_functions #' @aliases locate locate,character,Column-method -#' @export #' @note locate since 1.5.0 setMethod("locate", signature(substr = "character", str = "Column"), function(substr, str, pos = 1) { @@ -2697,7 +2563,6 @@ setMethod("locate", signature(substr = "character", str = "Column"), #' @param pad a character string to be padded with. #' @rdname column_string_functions #' @aliases lpad lpad,Column,numeric,character-method -#' @export #' @note lpad since 1.5.0 setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { @@ -2714,7 +2579,6 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' @rdname column_nonaggregate_functions #' @param seed a random seed. Can be missing. #' @aliases rand rand,missing-method -#' @export #' @examples #' #' \dontrun{ @@ -2729,7 +2593,6 @@ setMethod("rand", signature(seed = "missing"), #' @rdname column_nonaggregate_functions #' @aliases rand,numeric-method -#' @export #' @note rand(numeric) since 1.5.0 setMethod("rand", signature(seed = "numeric"), function(seed) { @@ -2743,7 +2606,6 @@ setMethod("rand", signature(seed = "numeric"), #' #' @rdname column_nonaggregate_functions #' @aliases randn randn,missing-method -#' @export #' @note randn since 1.5.0 setMethod("randn", signature(seed = "missing"), function(seed) { @@ -2753,7 +2615,6 @@ setMethod("randn", signature(seed = "missing"), #' @rdname column_nonaggregate_functions #' @aliases randn,numeric-method -#' @export #' @note randn(numeric) since 1.5.0 setMethod("randn", signature(seed = "numeric"), function(seed) { @@ -2770,7 +2631,6 @@ setMethod("randn", signature(seed = "numeric"), #' @param idx a group index. #' @rdname column_string_functions #' @aliases regexp_extract regexp_extract,Column,character,numeric-method -#' @export #' @examples #' #' \dontrun{ @@ -2799,7 +2659,6 @@ setMethod("regexp_extract", #' @param replacement a character string that a matched \code{pattern} is replaced with. #' @rdname column_string_functions #' @aliases regexp_replace regexp_replace,Column,character,character-method -#' @export #' @note regexp_replace since 1.5.0 setMethod("regexp_replace", signature(x = "Column", pattern = "character", replacement = "character"), @@ -2815,7 +2674,6 @@ setMethod("regexp_replace", #' #' @rdname column_string_functions #' @aliases rpad rpad,Column,numeric,character-method -#' @export #' @note rpad since 1.5.0 setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { @@ -2838,7 +2696,6 @@ setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), #' counting from the right. #' @rdname column_string_functions #' @aliases substring_index substring_index,Column,character,numeric-method -#' @export #' @note substring_index since 1.5.0 setMethod("substring_index", signature(x = "Column", delim = "character", count = "numeric"), @@ -2861,7 +2718,6 @@ setMethod("substring_index", #' at the same location, if any. #' @rdname column_string_functions #' @aliases translate translate,Column,character,character-method -#' @export #' @note translate since 1.5.0 setMethod("translate", signature(x = "Column", matchingString = "character", replaceString = "character"), @@ -2876,7 +2732,6 @@ setMethod("translate", #' #' @rdname column_datetime_functions #' @aliases unix_timestamp unix_timestamp,missing,missing-method -#' @export #' @note unix_timestamp since 1.5.0 setMethod("unix_timestamp", signature(x = "missing", format = "missing"), function(x, format) { @@ -2886,7 +2741,6 @@ setMethod("unix_timestamp", signature(x = "missing", format = "missing"), #' @rdname column_datetime_functions #' @aliases unix_timestamp,Column,missing-method -#' @export #' @note unix_timestamp(Column) since 1.5.0 setMethod("unix_timestamp", signature(x = "Column", format = "missing"), function(x, format) { @@ -2896,7 +2750,6 @@ setMethod("unix_timestamp", signature(x = "Column", format = "missing"), #' @rdname column_datetime_functions #' @aliases unix_timestamp,Column,character-method -#' @export #' @note unix_timestamp(Column, character) since 1.5.0 setMethod("unix_timestamp", signature(x = "Column", format = "character"), function(x, format = "yyyy-MM-dd HH:mm:ss") { @@ -2912,7 +2765,6 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), #' @param condition the condition to test on. Must be a Column expression. #' @param value result expression. #' @aliases when when,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2941,7 +2793,6 @@ setMethod("when", signature(condition = "Column", value = "ANY"), #' @param yes return values for \code{TRUE} elements of test. #' @param no return values for \code{FALSE} elements of test. #' @aliases ifelse ifelse,Column-method -#' @export #' @note ifelse since 1.5.0 setMethod("ifelse", signature(test = "Column", yes = "ANY", no = "ANY"), @@ -2967,7 +2818,6 @@ setMethod("ifelse", #' #' @rdname column_window_functions #' @aliases cume_dist cume_dist,missing-method -#' @export #' @note cume_dist since 1.6.0 setMethod("cume_dist", signature("missing"), @@ -2988,7 +2838,6 @@ setMethod("cume_dist", #' #' @rdname column_window_functions #' @aliases dense_rank dense_rank,missing-method -#' @export #' @note dense_rank since 1.6.0 setMethod("dense_rank", signature("missing"), @@ -3005,7 +2854,6 @@ setMethod("dense_rank", #' #' @rdname column_window_functions #' @aliases lag lag,characterOrColumn-method -#' @export #' @note lag since 1.6.0 setMethod("lag", signature(x = "characterOrColumn"), @@ -3030,7 +2878,6 @@ setMethod("lag", #' #' @rdname column_window_functions #' @aliases lead lead,characterOrColumn,numeric-method -#' @export #' @note lead since 1.6.0 setMethod("lead", signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"), @@ -3054,7 +2901,6 @@ setMethod("lead", #' #' @rdname column_window_functions #' @aliases ntile ntile,numeric-method -#' @export #' @note ntile since 1.6.0 setMethod("ntile", signature(x = "numeric"), @@ -3072,7 +2918,6 @@ setMethod("ntile", #' #' @rdname column_window_functions #' @aliases percent_rank percent_rank,missing-method -#' @export #' @note percent_rank since 1.6.0 setMethod("percent_rank", signature("missing"), @@ -3093,7 +2938,6 @@ setMethod("percent_rank", #' #' @rdname column_window_functions #' @aliases rank rank,missing-method -#' @export #' @note rank since 1.6.0 setMethod("rank", signature(x = "missing"), @@ -3104,7 +2948,6 @@ setMethod("rank", #' @rdname column_window_functions #' @aliases rank,ANY-method -#' @export setMethod("rank", signature(x = "ANY"), function(x, ...) { @@ -3118,7 +2961,6 @@ setMethod("rank", #' #' @rdname column_window_functions #' @aliases row_number row_number,missing-method -#' @export #' @note row_number since 1.6.0 setMethod("row_number", signature("missing"), @@ -3136,7 +2978,6 @@ setMethod("row_number", #' @param value a value to be checked if contained in the column #' @rdname column_collection_functions #' @aliases array_contains array_contains,Column-method -#' @export #' @note array_contains since 1.6.0 setMethod("array_contains", signature(x = "Column", value = "ANY"), @@ -3150,7 +2991,6 @@ setMethod("array_contains", #' #' @rdname column_collection_functions #' @aliases map_keys map_keys,Column-method -#' @export #' @note map_keys since 2.3.0 setMethod("map_keys", signature(x = "Column"), @@ -3164,7 +3004,6 @@ setMethod("map_keys", #' #' @rdname column_collection_functions #' @aliases map_values map_values,Column-method -#' @export #' @note map_values since 2.3.0 setMethod("map_values", signature(x = "Column"), @@ -3178,7 +3017,6 @@ setMethod("map_values", #' #' @rdname column_collection_functions #' @aliases explode explode,Column-method -#' @export #' @note explode since 1.5.0 setMethod("explode", signature(x = "Column"), @@ -3192,7 +3030,6 @@ setMethod("explode", #' #' @rdname column_collection_functions #' @aliases size size,Column-method -#' @export #' @note size since 1.5.0 setMethod("size", signature(x = "Column"), @@ -3210,7 +3047,6 @@ setMethod("size", #' TRUE, sorting is in ascending order. #' FALSE, sorting is in descending order. #' @aliases sort_array sort_array,Column-method -#' @export #' @note sort_array since 1.6.0 setMethod("sort_array", signature(x = "Column"), @@ -3225,7 +3061,6 @@ setMethod("sort_array", #' #' @rdname column_collection_functions #' @aliases posexplode posexplode,Column-method -#' @export #' @note posexplode since 2.1.0 setMethod("posexplode", signature(x = "Column"), @@ -3240,7 +3075,6 @@ setMethod("posexplode", #' #' @rdname column_nonaggregate_functions #' @aliases create_array create_array,Column-method -#' @export #' @note create_array since 2.3.0 setMethod("create_array", signature(x = "Column"), @@ -3261,7 +3095,6 @@ setMethod("create_array", #' #' @rdname column_nonaggregate_functions #' @aliases create_map create_map,Column-method -#' @export #' @note create_map since 2.3.0 setMethod("create_map", signature(x = "Column"), @@ -3279,7 +3112,6 @@ setMethod("create_map", #' #' @rdname column_aggregate_functions #' @aliases collect_list collect_list,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3299,7 +3131,6 @@ setMethod("collect_list", #' #' @rdname column_aggregate_functions #' @aliases collect_set collect_set,Column-method -#' @export #' @note collect_set since 2.3.0 setMethod("collect_set", signature(x = "Column"), @@ -3314,7 +3145,6 @@ setMethod("collect_set", #' #' @rdname column_string_functions #' @aliases split_string split_string,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3337,7 +3167,6 @@ setMethod("split_string", #' @param n number of repetitions. #' @rdname column_string_functions #' @aliases repeat_string repeat_string,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3360,7 +3189,6 @@ setMethod("repeat_string", #' #' @rdname column_collection_functions #' @aliases explode_outer explode_outer,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3385,7 +3213,6 @@ setMethod("explode_outer", #' #' @rdname column_collection_functions #' @aliases posexplode_outer posexplode_outer,Column-method -#' @export #' @note posexplode_outer since 2.3.0 setMethod("posexplode_outer", signature(x = "Column"), @@ -3406,7 +3233,6 @@ setMethod("posexplode_outer", #' @name not #' @aliases not,Column-method #' @family non-aggregate functions -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(data.frame( @@ -3434,7 +3260,6 @@ setMethod("not", #' #' @rdname column_aggregate_functions #' @aliases grouping_bit grouping_bit,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3467,7 +3292,6 @@ setMethod("grouping_bit", #' #' @rdname column_aggregate_functions #' @aliases grouping_id grouping_id,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3502,7 +3326,6 @@ setMethod("grouping_id", #' #' @rdname column_nonaggregate_functions #' @aliases input_file_name input_file_name,missing-method -#' @export #' @examples #' #' \dontrun{ @@ -3520,7 +3343,6 @@ setMethod("input_file_name", signature("missing"), #' #' @rdname column_datetime_functions #' @aliases trunc trunc,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3540,7 +3362,6 @@ setMethod("trunc", #' #' @rdname column_datetime_functions #' @aliases date_trunc date_trunc,character,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3559,7 +3380,6 @@ setMethod("date_trunc", #' #' @rdname column_datetime_functions #' @aliases current_date current_date,missing-method -#' @export #' @examples #' \dontrun{ #' head(select(df, current_date(), current_timestamp()))} @@ -3576,7 +3396,6 @@ setMethod("current_date", #' #' @rdname column_datetime_functions #' @aliases current_timestamp current_timestamp,missing-method -#' @export #' @note current_timestamp since 2.3.0 setMethod("current_timestamp", signature("missing"), diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index e0dde3339fabc..6fba4b6c761dd 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -19,7 +19,6 @@ # @rdname aggregateRDD # @seealso reduce -# @export setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) @@ -27,21 +26,17 @@ setGeneric("cacheRDD", function(x) { standardGeneric("cacheRDD") }) # @rdname coalesce # @seealso repartition -# @export setGeneric("coalesceRDD", function(x, numPartitions, ...) { standardGeneric("coalesceRDD") }) # @rdname checkpoint-methods -# @export setGeneric("checkpointRDD", function(x) { standardGeneric("checkpointRDD") }) setGeneric("collectRDD", function(x, ...) { standardGeneric("collectRDD") }) # @rdname collect-methods -# @export setGeneric("collectAsMap", function(x) { standardGeneric("collectAsMap") }) # @rdname collect-methods -# @export setGeneric("collectPartition", function(x, partitionId) { standardGeneric("collectPartition") @@ -52,19 +47,15 @@ setGeneric("countRDD", function(x) { standardGeneric("countRDD") }) setGeneric("lengthRDD", function(x) { standardGeneric("lengthRDD") }) # @rdname countByValue -# @export setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) # @rdname crosstab -# @export setGeneric("crosstab", function(x, col1, col2) { standardGeneric("crosstab") }) # @rdname freqItems -# @export setGeneric("freqItems", function(x, cols, support = 0.01) { standardGeneric("freqItems") }) # @rdname approxQuantile -# @export setGeneric("approxQuantile", function(x, cols, probabilities, relativeError) { standardGeneric("approxQuantile") @@ -73,18 +64,15 @@ setGeneric("approxQuantile", setGeneric("distinctRDD", function(x, numPartitions = 1) { standardGeneric("distinctRDD") }) # @rdname filterRDD -# @export setGeneric("filterRDD", function(x, f) { standardGeneric("filterRDD") }) setGeneric("firstRDD", function(x, ...) { standardGeneric("firstRDD") }) # @rdname flatMap -# @export setGeneric("flatMap", function(X, FUN) { standardGeneric("flatMap") }) # @rdname fold # @seealso reduce -# @export setGeneric("fold", function(x, zeroValue, op) { standardGeneric("fold") }) setGeneric("foreach", function(x, func) { standardGeneric("foreach") }) @@ -95,17 +83,14 @@ setGeneric("foreachPartition", function(x, func) { standardGeneric("foreachParti setGeneric("getJRDD", function(rdd, ...) { standardGeneric("getJRDD") }) # @rdname glom -# @export setGeneric("glom", function(x) { standardGeneric("glom") }) # @rdname histogram -# @export setGeneric("histogram", function(df, col, nbins=10) { standardGeneric("histogram") }) setGeneric("joinRDD", function(x, y, ...) { standardGeneric("joinRDD") }) # @rdname keyBy -# @export setGeneric("keyBy", function(x, func) { standardGeneric("keyBy") }) setGeneric("lapplyPartition", function(X, FUN) { standardGeneric("lapplyPartition") }) @@ -123,47 +108,37 @@ setGeneric("mapPartitionsWithIndex", function(X, FUN) { standardGeneric("mapPartitionsWithIndex") }) # @rdname maximum -# @export setGeneric("maximum", function(x) { standardGeneric("maximum") }) # @rdname minimum -# @export setGeneric("minimum", function(x) { standardGeneric("minimum") }) # @rdname sumRDD -# @export setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") }) # @rdname name -# @export setGeneric("name", function(x) { standardGeneric("name") }) # @rdname getNumPartitionsRDD -# @export setGeneric("getNumPartitionsRDD", function(x) { standardGeneric("getNumPartitionsRDD") }) # @rdname getNumPartitions -# @export setGeneric("numPartitions", function(x) { standardGeneric("numPartitions") }) setGeneric("persistRDD", function(x, newLevel) { standardGeneric("persistRDD") }) # @rdname pipeRDD -# @export setGeneric("pipeRDD", function(x, command, env = list()) { standardGeneric("pipeRDD")}) # @rdname pivot -# @export setGeneric("pivot", function(x, colname, values = list()) { standardGeneric("pivot") }) # @rdname reduce -# @export setGeneric("reduce", function(x, func) { standardGeneric("reduce") }) setGeneric("repartitionRDD", function(x, ...) { standardGeneric("repartitionRDD") }) # @rdname sampleRDD -# @export setGeneric("sampleRDD", function(x, withReplacement, fraction, seed) { standardGeneric("sampleRDD") @@ -171,21 +146,17 @@ setGeneric("sampleRDD", # @rdname saveAsObjectFile # @seealso objectFile -# @export setGeneric("saveAsObjectFile", function(x, path) { standardGeneric("saveAsObjectFile") }) # @rdname saveAsTextFile -# @export setGeneric("saveAsTextFile", function(x, path) { standardGeneric("saveAsTextFile") }) # @rdname setName -# @export setGeneric("setName", function(x, name) { standardGeneric("setName") }) setGeneric("showRDD", function(object, ...) { standardGeneric("showRDD") }) # @rdname sortBy -# @export setGeneric("sortBy", function(x, func, ascending = TRUE, numPartitions = 1) { standardGeneric("sortBy") @@ -194,88 +165,71 @@ setGeneric("sortBy", setGeneric("takeRDD", function(x, num) { standardGeneric("takeRDD") }) # @rdname takeOrdered -# @export setGeneric("takeOrdered", function(x, num) { standardGeneric("takeOrdered") }) # @rdname takeSample -# @export setGeneric("takeSample", function(x, withReplacement, num, seed) { standardGeneric("takeSample") }) # @rdname top -# @export setGeneric("top", function(x, num) { standardGeneric("top") }) # @rdname unionRDD -# @export setGeneric("unionRDD", function(x, y) { standardGeneric("unionRDD") }) setGeneric("unpersistRDD", function(x, ...) { standardGeneric("unpersistRDD") }) # @rdname zipRDD -# @export setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") }) # @rdname zipRDD -# @export setGeneric("zipPartitions", function(..., func) { standardGeneric("zipPartitions") }, signature = "...") # @rdname zipWithIndex # @seealso zipWithUniqueId -# @export setGeneric("zipWithIndex", function(x) { standardGeneric("zipWithIndex") }) # @rdname zipWithUniqueId # @seealso zipWithIndex -# @export setGeneric("zipWithUniqueId", function(x) { standardGeneric("zipWithUniqueId") }) ############ Binary Functions ############# # @rdname cartesian -# @export setGeneric("cartesian", function(x, other) { standardGeneric("cartesian") }) # @rdname countByKey -# @export setGeneric("countByKey", function(x) { standardGeneric("countByKey") }) # @rdname flatMapValues -# @export setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") }) # @rdname intersection -# @export setGeneric("intersection", function(x, other, numPartitions = 1) { standardGeneric("intersection") }) # @rdname keys -# @export setGeneric("keys", function(x) { standardGeneric("keys") }) # @rdname lookup -# @export setGeneric("lookup", function(x, key) { standardGeneric("lookup") }) # @rdname mapValues -# @export setGeneric("mapValues", function(X, FUN) { standardGeneric("mapValues") }) # @rdname sampleByKey -# @export setGeneric("sampleByKey", function(x, withReplacement, fractions, seed) { standardGeneric("sampleByKey") }) # @rdname values -# @export setGeneric("values", function(x) { standardGeneric("values") }) @@ -283,14 +237,12 @@ setGeneric("values", function(x) { standardGeneric("values") }) # @rdname aggregateByKey # @seealso foldByKey, combineByKey -# @export setGeneric("aggregateByKey", function(x, zeroValue, seqOp, combOp, numPartitions) { standardGeneric("aggregateByKey") }) # @rdname cogroup -# @export setGeneric("cogroup", function(..., numPartitions) { standardGeneric("cogroup") @@ -299,7 +251,6 @@ setGeneric("cogroup", # @rdname combineByKey # @seealso groupByKey, reduceByKey -# @export setGeneric("combineByKey", function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) { standardGeneric("combineByKey") @@ -307,64 +258,53 @@ setGeneric("combineByKey", # @rdname foldByKey # @seealso aggregateByKey, combineByKey -# @export setGeneric("foldByKey", function(x, zeroValue, func, numPartitions) { standardGeneric("foldByKey") }) # @rdname join-methods -# @export setGeneric("fullOuterJoin", function(x, y, numPartitions) { standardGeneric("fullOuterJoin") }) # @rdname groupByKey # @seealso reduceByKey -# @export setGeneric("groupByKey", function(x, numPartitions) { standardGeneric("groupByKey") }) # @rdname join-methods -# @export setGeneric("join", function(x, y, ...) { standardGeneric("join") }) # @rdname join-methods -# @export setGeneric("leftOuterJoin", function(x, y, numPartitions) { standardGeneric("leftOuterJoin") }) setGeneric("partitionByRDD", function(x, ...) { standardGeneric("partitionByRDD") }) # @rdname reduceByKey # @seealso groupByKey -# @export setGeneric("reduceByKey", function(x, combineFunc, numPartitions) { standardGeneric("reduceByKey")}) # @rdname reduceByKeyLocally # @seealso reduceByKey -# @export setGeneric("reduceByKeyLocally", function(x, combineFunc) { standardGeneric("reduceByKeyLocally") }) # @rdname join-methods -# @export setGeneric("rightOuterJoin", function(x, y, numPartitions) { standardGeneric("rightOuterJoin") }) # @rdname sortByKey -# @export setGeneric("sortByKey", function(x, ascending = TRUE, numPartitions = 1) { standardGeneric("sortByKey") }) # @rdname subtract -# @export setGeneric("subtract", function(x, other, numPartitions = 1) { standardGeneric("subtract") }) # @rdname subtractByKey -# @export setGeneric("subtractByKey", function(x, other, numPartitions = 1) { standardGeneric("subtractByKey") @@ -374,7 +314,6 @@ setGeneric("subtractByKey", ################### Broadcast Variable Methods ################# # @rdname broadcast -# @export setGeneric("value", function(bcast) { standardGeneric("value") }) @@ -384,7 +323,6 @@ setGeneric("value", function(bcast) { standardGeneric("value") }) #' @param ... further arguments to be passed to or from other methods. #' @return A SparkDataFrame. #' @rdname summarize -#' @export setGeneric("agg", function(x, ...) { standardGeneric("agg") }) #' alias @@ -399,11 +337,9 @@ setGeneric("agg", function(x, ...) { standardGeneric("agg") }) NULL #' @rdname arrange -#' @export setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") }) #' @rdname as.data.frame -#' @export setGeneric("as.data.frame", function(x, row.names = NULL, optional = FALSE, ...) { standardGeneric("as.data.frame") @@ -411,52 +347,41 @@ setGeneric("as.data.frame", # Do not document the generic because of signature changes across R versions #' @noRd -#' @export setGeneric("attach") #' @rdname cache -#' @export setGeneric("cache", function(x) { standardGeneric("cache") }) #' @rdname checkpoint -#' @export setGeneric("checkpoint", function(x, eager = TRUE) { standardGeneric("checkpoint") }) #' @rdname coalesce #' @param x a SparkDataFrame. #' @param ... additional argument(s). -#' @export setGeneric("coalesce", function(x, ...) { standardGeneric("coalesce") }) #' @rdname collect -#' @export setGeneric("collect", function(x, ...) { standardGeneric("collect") }) #' @param do.NULL currently not used. #' @param prefix currently not used. #' @rdname columns -#' @export setGeneric("colnames", function(x, do.NULL = TRUE, prefix = "col") { standardGeneric("colnames") }) #' @rdname columns -#' @export setGeneric("colnames<-", function(x, value) { standardGeneric("colnames<-") }) #' @rdname coltypes -#' @export setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) #' @rdname coltypes -#' @export setGeneric("coltypes<-", function(x, value) { standardGeneric("coltypes<-") }) #' @rdname columns -#' @export setGeneric("columns", function(x) {standardGeneric("columns") }) #' @param x a GroupedData or Column. #' @rdname count -#' @export setGeneric("count", function(x) { standardGeneric("count") }) #' @rdname cov @@ -464,7 +389,6 @@ setGeneric("count", function(x) { standardGeneric("count") }) #' @param ... additional argument(s). If \code{x} is a Column, a Column #' should be provided. If \code{x} is a SparkDataFrame, two column names should #' be provided. -#' @export setGeneric("cov", function(x, ...) {standardGeneric("cov") }) #' @rdname corr @@ -472,294 +396,229 @@ setGeneric("cov", function(x, ...) {standardGeneric("cov") }) #' @param ... additional argument(s). If \code{x} is a Column, a Column #' should be provided. If \code{x} is a SparkDataFrame, two column names should #' be provided. -#' @export setGeneric("corr", function(x, ...) {standardGeneric("corr") }) #' @rdname cov -#' @export setGeneric("covar_samp", function(col1, col2) {standardGeneric("covar_samp") }) #' @rdname cov -#' @export setGeneric("covar_pop", function(col1, col2) {standardGeneric("covar_pop") }) #' @rdname createOrReplaceTempView -#' @export setGeneric("createOrReplaceTempView", function(x, viewName) { standardGeneric("createOrReplaceTempView") }) # @rdname crossJoin -# @export setGeneric("crossJoin", function(x, y) { standardGeneric("crossJoin") }) #' @rdname cube -#' @export setGeneric("cube", function(x, ...) { standardGeneric("cube") }) #' @rdname dapply -#' @export setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") }) #' @rdname dapplyCollect -#' @export setGeneric("dapplyCollect", function(x, func) { standardGeneric("dapplyCollect") }) #' @param x a SparkDataFrame or GroupedData. #' @param ... additional argument(s) passed to the method. #' @rdname gapply -#' @export setGeneric("gapply", function(x, ...) { standardGeneric("gapply") }) #' @param x a SparkDataFrame or GroupedData. #' @param ... additional argument(s) passed to the method. #' @rdname gapplyCollect -#' @export setGeneric("gapplyCollect", function(x, ...) { standardGeneric("gapplyCollect") }) # @rdname getNumPartitions -# @export setGeneric("getNumPartitions", function(x) { standardGeneric("getNumPartitions") }) #' @rdname describe -#' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) #' @rdname distinct -#' @export setGeneric("distinct", function(x) { standardGeneric("distinct") }) #' @rdname drop -#' @export setGeneric("drop", function(x, ...) { standardGeneric("drop") }) #' @rdname dropDuplicates -#' @export setGeneric("dropDuplicates", function(x, ...) { standardGeneric("dropDuplicates") }) #' @rdname nafunctions -#' @export setGeneric("dropna", function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { standardGeneric("dropna") }) #' @rdname nafunctions -#' @export setGeneric("na.omit", function(object, ...) { standardGeneric("na.omit") }) #' @rdname dtypes -#' @export setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) #' @rdname explain -#' @export #' @param x a SparkDataFrame or a StreamingQuery. #' @param extended Logical. If extended is FALSE, prints only the physical plan. #' @param ... further arguments to be passed to or from other methods. setGeneric("explain", function(x, ...) { standardGeneric("explain") }) #' @rdname except -#' @export setGeneric("except", function(x, y) { standardGeneric("except") }) #' @rdname nafunctions -#' @export setGeneric("fillna", function(x, value, cols = NULL) { standardGeneric("fillna") }) #' @rdname filter -#' @export setGeneric("filter", function(x, condition) { standardGeneric("filter") }) #' @rdname first -#' @export setGeneric("first", function(x, ...) { standardGeneric("first") }) #' @rdname groupBy -#' @export setGeneric("group_by", function(x, ...) { standardGeneric("group_by") }) #' @rdname groupBy -#' @export setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) #' @rdname hint -#' @export setGeneric("hint", function(x, name, ...) { standardGeneric("hint") }) #' @rdname insertInto -#' @export setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertInto") }) #' @rdname intersect -#' @export setGeneric("intersect", function(x, y) { standardGeneric("intersect") }) #' @rdname isLocal -#' @export setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) #' @rdname isStreaming -#' @export setGeneric("isStreaming", function(x) { standardGeneric("isStreaming") }) #' @rdname limit -#' @export setGeneric("limit", function(x, num) {standardGeneric("limit") }) #' @rdname localCheckpoint -#' @export setGeneric("localCheckpoint", function(x, eager = TRUE) { standardGeneric("localCheckpoint") }) #' @rdname merge -#' @export setGeneric("merge") #' @rdname mutate -#' @export setGeneric("mutate", function(.data, ...) {standardGeneric("mutate") }) #' @rdname orderBy -#' @export setGeneric("orderBy", function(x, col, ...) { standardGeneric("orderBy") }) #' @rdname persist -#' @export setGeneric("persist", function(x, newLevel) { standardGeneric("persist") }) #' @rdname printSchema -#' @export setGeneric("printSchema", function(x) { standardGeneric("printSchema") }) #' @rdname registerTempTable-deprecated -#' @export setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") }) #' @rdname rename -#' @export setGeneric("rename", function(x, ...) { standardGeneric("rename") }) #' @rdname repartition -#' @export setGeneric("repartition", function(x, ...) { standardGeneric("repartition") }) #' @rdname sample -#' @export setGeneric("sample", function(x, withReplacement = FALSE, fraction, seed) { standardGeneric("sample") }) #' @rdname rollup -#' @export setGeneric("rollup", function(x, ...) { standardGeneric("rollup") }) #' @rdname sample -#' @export setGeneric("sample_frac", function(x, withReplacement = FALSE, fraction, seed) { standardGeneric("sample_frac") }) #' @rdname sampleBy -#' @export setGeneric("sampleBy", function(x, col, fractions, seed) { standardGeneric("sampleBy") }) #' @rdname saveAsTable -#' @export setGeneric("saveAsTable", function(df, tableName, source = NULL, mode = "error", ...) { standardGeneric("saveAsTable") }) -#' @export setGeneric("str") #' @rdname take -#' @export setGeneric("take", function(x, num) { standardGeneric("take") }) #' @rdname mutate -#' @export setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") }) #' @rdname write.df -#' @export setGeneric("write.df", function(df, path = NULL, source = NULL, mode = "error", ...) { standardGeneric("write.df") }) #' @rdname write.df -#' @export setGeneric("saveDF", function(df, path, source = NULL, mode = "error", ...) { standardGeneric("saveDF") }) #' @rdname write.jdbc -#' @export setGeneric("write.jdbc", function(x, url, tableName, mode = "error", ...) { standardGeneric("write.jdbc") }) #' @rdname write.json -#' @export setGeneric("write.json", function(x, path, ...) { standardGeneric("write.json") }) #' @rdname write.orc -#' @export setGeneric("write.orc", function(x, path, ...) { standardGeneric("write.orc") }) #' @rdname write.parquet -#' @export setGeneric("write.parquet", function(x, path, ...) { standardGeneric("write.parquet") }) #' @rdname write.parquet -#' @export setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) #' @rdname write.stream -#' @export setGeneric("write.stream", function(df, source = NULL, outputMode = NULL, ...) { standardGeneric("write.stream") }) #' @rdname write.text -#' @export setGeneric("write.text", function(x, path, ...) { standardGeneric("write.text") }) #' @rdname schema -#' @export setGeneric("schema", function(x) { standardGeneric("schema") }) #' @rdname select -#' @export setGeneric("select", function(x, col, ...) { standardGeneric("select") }) #' @rdname selectExpr -#' @export setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") }) #' @rdname showDF -#' @export setGeneric("showDF", function(x, ...) { standardGeneric("showDF") }) # @rdname storageLevel -# @export setGeneric("storageLevel", function(x) { standardGeneric("storageLevel") }) #' @rdname subset -#' @export setGeneric("subset", function(x, ...) { standardGeneric("subset") }) #' @rdname summarize -#' @export setGeneric("summarize", function(x, ...) { standardGeneric("summarize") }) #' @rdname summary -#' @export setGeneric("summary", function(object, ...) { standardGeneric("summary") }) setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) @@ -767,830 +626,660 @@ setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) #' @rdname union -#' @export setGeneric("union", function(x, y) { standardGeneric("union") }) #' @rdname union -#' @export setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) #' @rdname unionByName -#' @export setGeneric("unionByName", function(x, y) { standardGeneric("unionByName") }) #' @rdname unpersist -#' @export setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") }) #' @rdname filter -#' @export setGeneric("where", function(x, condition) { standardGeneric("where") }) #' @rdname with -#' @export setGeneric("with") #' @rdname withColumn -#' @export setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn") }) #' @rdname rename -#' @export setGeneric("withColumnRenamed", function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) #' @rdname withWatermark -#' @export setGeneric("withWatermark", function(x, eventTime, delayThreshold) { standardGeneric("withWatermark") }) #' @rdname write.df -#' @export setGeneric("write.df", function(df, path = NULL, ...) { standardGeneric("write.df") }) #' @rdname randomSplit -#' @export setGeneric("randomSplit", function(x, weights, seed) { standardGeneric("randomSplit") }) #' @rdname broadcast -#' @export setGeneric("broadcast", function(x) { standardGeneric("broadcast") }) ###################### Column Methods ########################## #' @rdname columnfunctions -#' @export setGeneric("asc", function(x) { standardGeneric("asc") }) #' @rdname between -#' @export setGeneric("between", function(x, bounds) { standardGeneric("between") }) #' @rdname cast -#' @export setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) #' @rdname columnfunctions #' @param x a Column object. #' @param ... additional argument(s). -#' @export setGeneric("contains", function(x, ...) { standardGeneric("contains") }) #' @rdname columnfunctions -#' @export setGeneric("desc", function(x) { standardGeneric("desc") }) #' @rdname endsWith -#' @export setGeneric("endsWith", function(x, suffix) { standardGeneric("endsWith") }) #' @rdname columnfunctions -#' @export setGeneric("getField", function(x, ...) { standardGeneric("getField") }) #' @rdname columnfunctions -#' @export setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) #' @rdname columnfunctions -#' @export setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) #' @rdname columnfunctions -#' @export setGeneric("isNull", function(x) { standardGeneric("isNull") }) #' @rdname columnfunctions -#' @export setGeneric("isNotNull", function(x) { standardGeneric("isNotNull") }) #' @rdname columnfunctions -#' @export setGeneric("like", function(x, ...) { standardGeneric("like") }) #' @rdname columnfunctions -#' @export setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) #' @rdname startsWith -#' @export setGeneric("startsWith", function(x, prefix) { standardGeneric("startsWith") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("when", function(condition, value) { standardGeneric("when") }) #' @rdname otherwise -#' @export setGeneric("otherwise", function(x, value) { standardGeneric("otherwise") }) #' @rdname over -#' @export setGeneric("over", function(x, window) { standardGeneric("over") }) #' @rdname eq_null_safe -#' @export setGeneric("%<=>%", function(x, value) { standardGeneric("%<=>%") }) ###################### WindowSpec Methods ########################## #' @rdname partitionBy -#' @export setGeneric("partitionBy", function(x, ...) { standardGeneric("partitionBy") }) #' @rdname rowsBetween -#' @export setGeneric("rowsBetween", function(x, start, end) { standardGeneric("rowsBetween") }) #' @rdname rangeBetween -#' @export setGeneric("rangeBetween", function(x, start, end) { standardGeneric("rangeBetween") }) #' @rdname windowPartitionBy -#' @export setGeneric("windowPartitionBy", function(col, ...) { standardGeneric("windowPartitionBy") }) #' @rdname windowOrderBy -#' @export setGeneric("windowOrderBy", function(col, ...) { standardGeneric("windowOrderBy") }) ###################### Expression Function Methods ########################## #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) #' @param x Column to compute on or a GroupedData object. #' @param ... additional argument(s) when \code{x} is a GroupedData object. #' @rdname avg -#' @export setGeneric("avg", function(x, ...) { standardGeneric("avg") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("base64", function(x) { standardGeneric("base64") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("bin", function(x) { standardGeneric("bin") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("bround", function(x, ...) { standardGeneric("bround") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("ceil", function(x) { standardGeneric("ceil") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("collect_list", function(x) { standardGeneric("collect_list") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("collect_set", function(x) { standardGeneric("collect_set") }) #' @rdname column -#' @export setGeneric("column", function(x) { standardGeneric("column") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("concat", function(x, ...) { standardGeneric("concat") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("concat_ws", function(sep, x, ...) { standardGeneric("concat_ws") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("conv", function(x, fromBase, toBase) { standardGeneric("conv") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("crc32", function(x) { standardGeneric("crc32") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("create_array", function(x, ...) { standardGeneric("create_array") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("create_map", function(x, ...) { standardGeneric("create_map") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("hash", function(x, ...) { standardGeneric("hash") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("cume_dist", function(x = "missing") { standardGeneric("cume_dist") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("current_date", function(x = "missing") { standardGeneric("current_date") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("current_timestamp", function(x = "missing") { standardGeneric("current_timestamp") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("datediff", function(y, x) { standardGeneric("datediff") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("date_add", function(y, x) { standardGeneric("date_add") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("date_format", function(y, x) { standardGeneric("date_format") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("date_sub", function(y, x) { standardGeneric("date_sub") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("date_trunc", function(format, x) { standardGeneric("date_trunc") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("dayofweek", function(x) { standardGeneric("dayofweek") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("decode", function(x, charset) { standardGeneric("decode") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("encode", function(x, charset) { standardGeneric("encode") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("explode", function(x) { standardGeneric("explode") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("expr", function(x) { standardGeneric("expr") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("format_number", function(y, x) { standardGeneric("format_number") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("format_string", function(format, x, ...) { standardGeneric("format_string") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("from_json", function(x, schema, ...) { standardGeneric("from_json") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("greatest", function(x, ...) { standardGeneric("greatest") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("grouping_bit", function(x) { standardGeneric("grouping_bit") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("grouping_id", function(x, ...) { standardGeneric("grouping_id") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("hex", function(x) { standardGeneric("hex") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("hour", function(x) { standardGeneric("hour") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("initcap", function(x) { standardGeneric("initcap") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("input_file_name", function(x = "missing") { standardGeneric("input_file_name") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("instr", function(y, x) { standardGeneric("instr") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("isnan", function(x) { standardGeneric("isnan") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("lag", function(x, ...) { standardGeneric("lag") }) #' @rdname last -#' @export setGeneric("last", function(x, ...) { standardGeneric("last") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("last_day", function(x) { standardGeneric("last_day") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("lead", function(x, offset, defaultValue = NULL) { standardGeneric("lead") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("least", function(x, ...) { standardGeneric("least") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("levenshtein", function(y, x) { standardGeneric("levenshtein") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("lit", function(x) { standardGeneric("lit") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("locate", function(substr, str, ...) { standardGeneric("locate") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("lower", function(x) { standardGeneric("lower") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("map_keys", function(x) { standardGeneric("map_keys") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("map_values", function(x) { standardGeneric("map_values") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("md5", function(x) { standardGeneric("md5") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("minute", function(x) { standardGeneric("minute") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("monotonically_increasing_id", function(x = "missing") { standardGeneric("monotonically_increasing_id") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("month", function(x) { standardGeneric("month") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("months_between", function(y, x) { standardGeneric("months_between") }) #' @rdname count -#' @export setGeneric("n", function(x) { standardGeneric("n") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("negate", function(x) { standardGeneric("negate") }) #' @rdname not -#' @export setGeneric("not", function(x) { standardGeneric("not") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("ntile", function(x) { standardGeneric("ntile") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("percent_rank", function(x = "missing") { standardGeneric("percent_rank") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("posexplode", function(x) { standardGeneric("posexplode") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("posexplode_outer", function(x) { standardGeneric("posexplode_outer") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("quarter", function(x) { standardGeneric("quarter") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("rand", function(seed) { standardGeneric("rand") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("randn", function(seed) { standardGeneric("randn") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("rank", function(x, ...) { standardGeneric("rank") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("regexp_extract", function(x, pattern, idx) { standardGeneric("regexp_extract") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("regexp_replace", function(x, pattern, replacement) { standardGeneric("regexp_replace") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("reverse", function(x) { standardGeneric("reverse") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("rint", function(x) { standardGeneric("rint") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("row_number", function(x = "missing") { standardGeneric("row_number") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("rtrim", function(x, trimString) { standardGeneric("rtrim") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("sd", function(x, na.rm = FALSE) { standardGeneric("sd") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("second", function(x) { standardGeneric("second") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("sha1", function(x) { standardGeneric("sha1") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("sha2", function(y, x) { standardGeneric("sha2") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("shiftLeft", function(y, x) { standardGeneric("shiftLeft") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("shiftRight", function(y, x) { standardGeneric("shiftRight") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUnsigned") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("signum", function(x) { standardGeneric("signum") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("size", function(x) { standardGeneric("size") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("skewness", function(x) { standardGeneric("skewness") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("split_string", function(x, pattern) { standardGeneric("split_string") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("soundex", function(x) { standardGeneric("soundex") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("spark_partition_id", function(x = "missing") { standardGeneric("spark_partition_id") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("stddev", function(x) { standardGeneric("stddev") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("struct", function(x, ...) { standardGeneric("struct") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("to_date", function(x, format) { standardGeneric("to_date") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("to_json", function(x, ...) { standardGeneric("to_json") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("to_timestamp", function(x, format) { standardGeneric("to_timestamp") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("to_utc_timestamp", function(y, x) { standardGeneric("to_utc_timestamp") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("translate", function(x, matchingString, replaceString) { standardGeneric("translate") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("trim", function(x, trimString) { standardGeneric("trim") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("unbase64", function(x) { standardGeneric("unbase64") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("unhex", function(x) { standardGeneric("unhex") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timestamp") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("upper", function(x) { standardGeneric("upper") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("var", function(x, y = NULL, na.rm = FALSE, use) { standardGeneric("var") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("variance", function(x) { standardGeneric("variance") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("var_pop", function(x) { standardGeneric("var_pop") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("var_samp", function(x) { standardGeneric("var_samp") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("window", function(x, ...) { standardGeneric("window") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("year", function(x) { standardGeneric("year") }) @@ -1598,142 +1287,110 @@ setGeneric("year", function(x) { standardGeneric("year") }) ###################### Spark.ML Methods ########################## #' @rdname fitted -#' @export setGeneric("fitted") # Do not carry stats::glm usage and param here, and do not document the generic -#' @export #' @noRd setGeneric("glm") #' @param object a fitted ML model object. #' @param ... additional argument(s) passed to the method. #' @rdname predict -#' @export setGeneric("predict", function(object, ...) { standardGeneric("predict") }) #' @rdname rbind -#' @export setGeneric("rbind", signature = "...") #' @rdname spark.als -#' @export setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") }) #' @rdname spark.bisectingKmeans -#' @export setGeneric("spark.bisectingKmeans", function(data, formula, ...) { standardGeneric("spark.bisectingKmeans") }) #' @rdname spark.gaussianMixture -#' @export setGeneric("spark.gaussianMixture", function(data, formula, ...) { standardGeneric("spark.gaussianMixture") }) #' @rdname spark.gbt -#' @export setGeneric("spark.gbt", function(data, formula, ...) { standardGeneric("spark.gbt") }) #' @rdname spark.glm -#' @export setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") }) #' @rdname spark.isoreg -#' @export setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") }) #' @rdname spark.kmeans -#' @export setGeneric("spark.kmeans", function(data, formula, ...) { standardGeneric("spark.kmeans") }) #' @rdname spark.kstest -#' @export setGeneric("spark.kstest", function(data, ...) { standardGeneric("spark.kstest") }) #' @rdname spark.lda -#' @export setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") }) #' @rdname spark.logit -#' @export setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark.logit") }) #' @rdname spark.mlp -#' @export setGeneric("spark.mlp", function(data, formula, ...) { standardGeneric("spark.mlp") }) #' @rdname spark.naiveBayes -#' @export setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") }) #' @rdname spark.decisionTree -#' @export setGeneric("spark.decisionTree", function(data, formula, ...) { standardGeneric("spark.decisionTree") }) #' @rdname spark.randomForest -#' @export setGeneric("spark.randomForest", function(data, formula, ...) { standardGeneric("spark.randomForest") }) #' @rdname spark.survreg -#' @export setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") }) #' @rdname spark.svmLinear -#' @export setGeneric("spark.svmLinear", function(data, formula, ...) { standardGeneric("spark.svmLinear") }) #' @rdname spark.lda -#' @export setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark.posterior") }) #' @rdname spark.lda -#' @export setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") }) #' @rdname spark.fpGrowth -#' @export setGeneric("spark.fpGrowth", function(data, ...) { standardGeneric("spark.fpGrowth") }) #' @rdname spark.fpGrowth -#' @export setGeneric("spark.freqItemsets", function(object) { standardGeneric("spark.freqItemsets") }) #' @rdname spark.fpGrowth -#' @export setGeneric("spark.associationRules", function(object) { standardGeneric("spark.associationRules") }) #' @param object a fitted ML model object. #' @param path the directory where the model is saved. #' @param ... additional argument(s) passed to the method. #' @rdname write.ml -#' @export setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") }) ###################### Streaming Methods ########################## #' @rdname awaitTermination -#' @export setGeneric("awaitTermination", function(x, timeout = NULL) { standardGeneric("awaitTermination") }) #' @rdname isActive -#' @export setGeneric("isActive", function(x) { standardGeneric("isActive") }) #' @rdname lastProgress -#' @export setGeneric("lastProgress", function(x) { standardGeneric("lastProgress") }) #' @rdname queryName -#' @export setGeneric("queryName", function(x) { standardGeneric("queryName") }) #' @rdname status -#' @export setGeneric("status", function(x) { standardGeneric("status") }) #' @rdname stopQuery -#' @export setGeneric("stopQuery", function(x) { standardGeneric("stopQuery") }) diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 54ef9f07d6fae..f751b952f3915 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -30,7 +30,6 @@ setOldClass("jobj") #' @seealso groupBy #' #' @param sgd A Java object reference to the backing Scala GroupedData -#' @export #' @note GroupedData since 1.4.0 setClass("GroupedData", slots = list(sgd = "jobj")) @@ -48,7 +47,6 @@ groupedData <- function(sgd) { #' @rdname show #' @aliases show,GroupedData-method -#' @export #' @note show(GroupedData) since 1.4.0 setMethod("show", "GroupedData", function(object) { @@ -63,7 +61,6 @@ setMethod("show", "GroupedData", #' @return A SparkDataFrame. #' @rdname count #' @aliases count,GroupedData-method -#' @export #' @examples #' \dontrun{ #' count(groupBy(df, "name")) @@ -87,7 +84,6 @@ setMethod("count", #' @aliases agg,GroupedData-method #' @name agg #' @family agg_funcs -#' @export #' @examples #' \dontrun{ #' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)' @@ -150,7 +146,6 @@ methods <- c("avg", "max", "mean", "min", "sum") #' @rdname pivot #' @aliases pivot,GroupedData,character-method #' @name pivot -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(data.frame( @@ -202,7 +197,6 @@ createMethods() #' @rdname gapply #' @aliases gapply,GroupedData-method #' @name gapply -#' @export #' @note gapply(GroupedData) since 2.0.0 setMethod("gapply", signature(x = "GroupedData"), @@ -216,7 +210,6 @@ setMethod("gapply", #' @rdname gapplyCollect #' @aliases gapplyCollect,GroupedData-method #' @name gapplyCollect -#' @export #' @note gapplyCollect(GroupedData) since 2.0.0 setMethod("gapplyCollect", signature(x = "GroupedData"), diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R index 04dc7562e5346..6d1edf6b6f3cf 100644 --- a/R/pkg/R/install.R +++ b/R/pkg/R/install.R @@ -58,7 +58,6 @@ #' @rdname install.spark #' @name install.spark #' @aliases install.spark -#' @export #' @examples #'\dontrun{ #' install.spark() diff --git a/R/pkg/R/jvm.R b/R/pkg/R/jvm.R index bb5c77544a3da..9a1b26b0fa3c5 100644 --- a/R/pkg/R/jvm.R +++ b/R/pkg/R/jvm.R @@ -35,7 +35,6 @@ #' @param ... parameters to pass to the Java method. #' @return the return value of the Java method. Either returned as a R object #' if it can be deserialized or returned as a "jobj". See details section for more. -#' @export #' @seealso \link{sparkR.callJStatic}, \link{sparkR.newJObject} #' @rdname sparkR.callJMethod #' @examples @@ -69,7 +68,6 @@ sparkR.callJMethod <- function(x, methodName, ...) { #' @param ... parameters to pass to the Java method. #' @return the return value of the Java method. Either returned as a R object #' if it can be deserialized or returned as a "jobj". See details section for more. -#' @export #' @seealso \link{sparkR.callJMethod}, \link{sparkR.newJObject} #' @rdname sparkR.callJStatic #' @examples @@ -100,7 +98,6 @@ sparkR.callJStatic <- function(x, methodName, ...) { #' @param ... arguments to be passed to the constructor. #' @return the object created. Either returned as a R object #' if it can be deserialized or returned as a "jobj". See details section for more. -#' @export #' @seealso \link{sparkR.callJMethod}, \link{sparkR.callJStatic} #' @rdname sparkR.newJObject #' @examples diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index f6e9b1357561b..2964fdeff0957 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -21,28 +21,24 @@ #' S4 class that represents an LinearSVCModel #' #' @param jobj a Java object reference to the backing Scala LinearSVCModel -#' @export #' @note LinearSVCModel since 2.2.0 setClass("LinearSVCModel", representation(jobj = "jobj")) #' S4 class that represents an LogisticRegressionModel #' #' @param jobj a Java object reference to the backing Scala LogisticRegressionModel -#' @export #' @note LogisticRegressionModel since 2.1.0 setClass("LogisticRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a MultilayerPerceptronClassificationModel #' #' @param jobj a Java object reference to the backing Scala MultilayerPerceptronClassifierWrapper -#' @export #' @note MultilayerPerceptronClassificationModel since 2.1.0 setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj")) #' S4 class that represents a NaiveBayesModel #' #' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper -#' @export #' @note NaiveBayesModel since 2.0.0 setClass("NaiveBayesModel", representation(jobj = "jobj")) @@ -82,7 +78,6 @@ setClass("NaiveBayesModel", representation(jobj = "jobj")) #' @rdname spark.svmLinear #' @aliases spark.svmLinear,SparkDataFrame,formula-method #' @name spark.svmLinear -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -131,7 +126,6 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu #' @return \code{predict} returns the predicted values based on a LinearSVCModel. #' @rdname spark.svmLinear #' @aliases predict,LinearSVCModel,SparkDataFrame-method -#' @export #' @note predict(LinearSVCModel) since 2.2.0 setMethod("predict", signature(object = "LinearSVCModel"), function(object, newData) { @@ -146,7 +140,6 @@ setMethod("predict", signature(object = "LinearSVCModel"), #' \code{numClasses} (number of classes), \code{numFeatures} (number of features). #' @rdname spark.svmLinear #' @aliases summary,LinearSVCModel-method -#' @export #' @note summary(LinearSVCModel) since 2.2.0 setMethod("summary", signature(object = "LinearSVCModel"), function(object) { @@ -169,7 +162,6 @@ setMethod("summary", signature(object = "LinearSVCModel"), #' #' @rdname spark.svmLinear #' @aliases write.ml,LinearSVCModel,character-method -#' @export #' @note write.ml(LogisticRegression, character) since 2.2.0 setMethod("write.ml", signature(object = "LinearSVCModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -257,7 +249,6 @@ function(object, path, overwrite = FALSE) { #' @rdname spark.logit #' @aliases spark.logit,SparkDataFrame,formula-method #' @name spark.logit -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -374,7 +365,6 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula") #' The list includes \code{coefficients} (coefficients matrix of the fitted model). #' @rdname spark.logit #' @aliases summary,LogisticRegressionModel-method -#' @export #' @note summary(LogisticRegressionModel) since 2.1.0 setMethod("summary", signature(object = "LogisticRegressionModel"), function(object) { @@ -402,7 +392,6 @@ setMethod("summary", signature(object = "LogisticRegressionModel"), #' @return \code{predict} returns the predicted values based on an LogisticRegressionModel. #' @rdname spark.logit #' @aliases predict,LogisticRegressionModel,SparkDataFrame-method -#' @export #' @note predict(LogisticRegressionModel) since 2.1.0 setMethod("predict", signature(object = "LogisticRegressionModel"), function(object, newData) { @@ -417,7 +406,6 @@ setMethod("predict", signature(object = "LogisticRegressionModel"), #' #' @rdname spark.logit #' @aliases write.ml,LogisticRegressionModel,character-method -#' @export #' @note write.ml(LogisticRegression, character) since 2.1.0 setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -458,7 +446,6 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char #' @aliases spark.mlp,SparkDataFrame,formula-method #' @name spark.mlp #' @seealso \link{read.ml} -#' @export #' @examples #' \dontrun{ #' df <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") @@ -517,7 +504,6 @@ setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"), #' For \code{weights}, it is a numeric vector with length equal to the expected #' given the architecture (i.e., for 8-10-2 network, 112 connection weights). #' @rdname spark.mlp -#' @export #' @aliases summary,MultilayerPerceptronClassificationModel-method #' @note summary(MultilayerPerceptronClassificationModel) since 2.1.0 setMethod("summary", signature(object = "MultilayerPerceptronClassificationModel"), @@ -538,7 +524,6 @@ setMethod("summary", signature(object = "MultilayerPerceptronClassificationModel #' "prediction". #' @rdname spark.mlp #' @aliases predict,MultilayerPerceptronClassificationModel-method -#' @export #' @note predict(MultilayerPerceptronClassificationModel) since 2.1.0 setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel"), function(object, newData) { @@ -553,7 +538,6 @@ setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel #' #' @rdname spark.mlp #' @aliases write.ml,MultilayerPerceptronClassificationModel,character-method -#' @export #' @seealso \link{write.ml} #' @note write.ml(MultilayerPerceptronClassificationModel, character) since 2.1.0 setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationModel", @@ -585,7 +569,6 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode #' @aliases spark.naiveBayes,SparkDataFrame,formula-method #' @name spark.naiveBayes #' @seealso e1071: \url{https://cran.r-project.org/package=e1071} -#' @export #' @examples #' \dontrun{ #' data <- as.data.frame(UCBAdmissions) @@ -624,7 +607,6 @@ setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "form #' The list includes \code{apriori} (the label distribution) and #' \code{tables} (conditional probabilities given the target label). #' @rdname spark.naiveBayes -#' @export #' @note summary(NaiveBayesModel) since 2.0.0 setMethod("summary", signature(object = "NaiveBayesModel"), function(object) { @@ -648,7 +630,6 @@ setMethod("summary", signature(object = "NaiveBayesModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named #' "prediction". #' @rdname spark.naiveBayes -#' @export #' @note predict(NaiveBayesModel) since 2.0.0 setMethod("predict", signature(object = "NaiveBayesModel"), function(object, newData) { @@ -662,7 +643,6 @@ setMethod("predict", signature(object = "NaiveBayesModel"), #' which means throw exception if the output path exists. #' #' @rdname spark.naiveBayes -#' @export #' @seealso \link{write.ml} #' @note write.ml(NaiveBayesModel, character) since 2.0.0 setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"), diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R index a25bf81c6d977..900be685824da 100644 --- a/R/pkg/R/mllib_clustering.R +++ b/R/pkg/R/mllib_clustering.R @@ -20,28 +20,24 @@ #' S4 class that represents a BisectingKMeansModel #' #' @param jobj a Java object reference to the backing Scala BisectingKMeansModel -#' @export #' @note BisectingKMeansModel since 2.2.0 setClass("BisectingKMeansModel", representation(jobj = "jobj")) #' S4 class that represents a GaussianMixtureModel #' #' @param jobj a Java object reference to the backing Scala GaussianMixtureModel -#' @export #' @note GaussianMixtureModel since 2.1.0 setClass("GaussianMixtureModel", representation(jobj = "jobj")) #' S4 class that represents a KMeansModel #' #' @param jobj a Java object reference to the backing Scala KMeansModel -#' @export #' @note KMeansModel since 2.0.0 setClass("KMeansModel", representation(jobj = "jobj")) #' S4 class that represents an LDAModel #' #' @param jobj a Java object reference to the backing Scala LDAWrapper -#' @export #' @note LDAModel since 2.1.0 setClass("LDAModel", representation(jobj = "jobj")) @@ -68,7 +64,6 @@ setClass("LDAModel", representation(jobj = "jobj")) #' @rdname spark.bisectingKmeans #' @aliases spark.bisectingKmeans,SparkDataFrame,formula-method #' @name spark.bisectingKmeans -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -117,7 +112,6 @@ setMethod("spark.bisectingKmeans", signature(data = "SparkDataFrame", formula = #' (cluster centers of the transformed data; cluster is NULL if is.loaded is TRUE), #' and \code{is.loaded} (whether the model is loaded from a saved file). #' @rdname spark.bisectingKmeans -#' @export #' @note summary(BisectingKMeansModel) since 2.2.0 setMethod("summary", signature(object = "BisectingKMeansModel"), function(object) { @@ -144,7 +138,6 @@ setMethod("summary", signature(object = "BisectingKMeansModel"), #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns the predicted values based on a bisecting k-means model. #' @rdname spark.bisectingKmeans -#' @export #' @note predict(BisectingKMeansModel) since 2.2.0 setMethod("predict", signature(object = "BisectingKMeansModel"), function(object, newData) { @@ -160,7 +153,6 @@ setMethod("predict", signature(object = "BisectingKMeansModel"), #' or \code{"classes"} for assigned classes. #' @return \code{fitted} returns a SparkDataFrame containing fitted values. #' @rdname spark.bisectingKmeans -#' @export #' @note fitted since 2.2.0 setMethod("fitted", signature(object = "BisectingKMeansModel"), function(object, method = c("centers", "classes")) { @@ -181,7 +173,6 @@ setMethod("fitted", signature(object = "BisectingKMeansModel"), #' which means throw exception if the output path exists. #' #' @rdname spark.bisectingKmeans -#' @export #' @note write.ml(BisectingKMeansModel, character) since 2.2.0 setMethod("write.ml", signature(object = "BisectingKMeansModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -208,7 +199,6 @@ setMethod("write.ml", signature(object = "BisectingKMeansModel", path = "charact #' @rdname spark.gaussianMixture #' @name spark.gaussianMixture #' @seealso mixtools: \url{https://cran.r-project.org/package=mixtools} -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -251,7 +241,6 @@ setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = #' \code{sigma} (sigma), \code{loglik} (loglik), and \code{posterior} (posterior). #' @aliases spark.gaussianMixture,SparkDataFrame,formula-method #' @rdname spark.gaussianMixture -#' @export #' @note summary(GaussianMixtureModel) since 2.1.0 setMethod("summary", signature(object = "GaussianMixtureModel"), function(object) { @@ -291,7 +280,6 @@ setMethod("summary", signature(object = "GaussianMixtureModel"), #' "prediction". #' @aliases predict,GaussianMixtureModel,SparkDataFrame-method #' @rdname spark.gaussianMixture -#' @export #' @note predict(GaussianMixtureModel) since 2.1.0 setMethod("predict", signature(object = "GaussianMixtureModel"), function(object, newData) { @@ -306,7 +294,6 @@ setMethod("predict", signature(object = "GaussianMixtureModel"), #' #' @aliases write.ml,GaussianMixtureModel,character-method #' @rdname spark.gaussianMixture -#' @export #' @note write.ml(GaussianMixtureModel, character) since 2.1.0 setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -336,7 +323,6 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact #' @rdname spark.kmeans #' @aliases spark.kmeans,SparkDataFrame,formula-method #' @name spark.kmeans -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -385,7 +371,6 @@ setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula" #' (the actual number of cluster centers. When using initMode = "random", #' \code{clusterSize} may not equal to \code{k}). #' @rdname spark.kmeans -#' @export #' @note summary(KMeansModel) since 2.0.0 setMethod("summary", signature(object = "KMeansModel"), function(object) { @@ -413,7 +398,6 @@ setMethod("summary", signature(object = "KMeansModel"), #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns the predicted values based on a k-means model. #' @rdname spark.kmeans -#' @export #' @note predict(KMeansModel) since 2.0.0 setMethod("predict", signature(object = "KMeansModel"), function(object, newData) { @@ -431,7 +415,6 @@ setMethod("predict", signature(object = "KMeansModel"), #' @param ... additional argument(s) passed to the method. #' @return \code{fitted} returns a SparkDataFrame containing fitted values. #' @rdname fitted -#' @export #' @examples #' \dontrun{ #' model <- spark.kmeans(trainingData, ~ ., 2) @@ -458,7 +441,6 @@ setMethod("fitted", signature(object = "KMeansModel"), #' which means throw exception if the output path exists. #' #' @rdname spark.kmeans -#' @export #' @note write.ml(KMeansModel, character) since 2.0.0 setMethod("write.ml", signature(object = "KMeansModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -496,7 +478,6 @@ setMethod("write.ml", signature(object = "KMeansModel", path = "character"), #' @rdname spark.lda #' @aliases spark.lda,SparkDataFrame-method #' @seealso topicmodels: \url{https://cran.r-project.org/package=topicmodels} -#' @export #' @examples #' \dontrun{ #' text <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm") @@ -558,7 +539,6 @@ setMethod("spark.lda", signature(data = "SparkDataFrame"), #' It is only for distributed LDA model (i.e., optimizer = "em")} #' @rdname spark.lda #' @aliases summary,LDAModel-method -#' @export #' @note summary(LDAModel) since 2.1.0 setMethod("summary", signature(object = "LDAModel"), function(object, maxTermsPerTopic) { @@ -596,7 +576,6 @@ setMethod("summary", signature(object = "LDAModel"), #' perplexity of the training data if missing argument "data". #' @rdname spark.lda #' @aliases spark.perplexity,LDAModel-method -#' @export #' @note spark.perplexity(LDAModel) since 2.1.0 setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFrame"), function(object, data) { @@ -611,7 +590,6 @@ setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFr #' vectors named "topicDistribution". #' @rdname spark.lda #' @aliases spark.posterior,LDAModel,SparkDataFrame-method -#' @export #' @note spark.posterior(LDAModel) since 2.1.0 setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkDataFrame"), function(object, newData) { @@ -626,7 +604,6 @@ setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkData #' #' @rdname spark.lda #' @aliases write.ml,LDAModel,character-method -#' @export #' @seealso \link{read.ml} #' @note write.ml(LDAModel, character) since 2.1.0 setMethod("write.ml", signature(object = "LDAModel", path = "character"), diff --git a/R/pkg/R/mllib_fpm.R b/R/pkg/R/mllib_fpm.R index dfcb45a1b66c9..e2394906d8012 100644 --- a/R/pkg/R/mllib_fpm.R +++ b/R/pkg/R/mllib_fpm.R @@ -20,7 +20,6 @@ #' S4 class that represents a FPGrowthModel #' #' @param jobj a Java object reference to the backing Scala FPGrowthModel -#' @export #' @note FPGrowthModel since 2.2.0 setClass("FPGrowthModel", slots = list(jobj = "jobj")) @@ -45,7 +44,6 @@ setClass("FPGrowthModel", slots = list(jobj = "jobj")) #' @rdname spark.fpGrowth #' @name spark.fpGrowth #' @aliases spark.fpGrowth,SparkDataFrame-method -#' @export #' @examples #' \dontrun{ #' raw_data <- read.df( @@ -109,7 +107,6 @@ setMethod("spark.fpGrowth", signature(data = "SparkDataFrame"), #' and \code{freq} (frequency of the itemset). #' @rdname spark.fpGrowth #' @aliases freqItemsets,FPGrowthModel-method -#' @export #' @note spark.freqItemsets(FPGrowthModel) since 2.2.0 setMethod("spark.freqItemsets", signature(object = "FPGrowthModel"), function(object) { @@ -125,7 +122,6 @@ setMethod("spark.freqItemsets", signature(object = "FPGrowthModel"), #' and \code{condfidence} (confidence). #' @rdname spark.fpGrowth #' @aliases associationRules,FPGrowthModel-method -#' @export #' @note spark.associationRules(FPGrowthModel) since 2.2.0 setMethod("spark.associationRules", signature(object = "FPGrowthModel"), function(object) { @@ -138,7 +134,6 @@ setMethod("spark.associationRules", signature(object = "FPGrowthModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted values. #' @rdname spark.fpGrowth #' @aliases predict,FPGrowthModel-method -#' @export #' @note predict(FPGrowthModel) since 2.2.0 setMethod("predict", signature(object = "FPGrowthModel"), function(object, newData) { @@ -153,7 +148,6 @@ setMethod("predict", signature(object = "FPGrowthModel"), #' if the output path exists. #' @rdname spark.fpGrowth #' @aliases write.ml,FPGrowthModel,character-method -#' @export #' @seealso \link{read.ml} #' @note write.ml(FPGrowthModel, character) since 2.2.0 setMethod("write.ml", signature(object = "FPGrowthModel", path = "character"), diff --git a/R/pkg/R/mllib_recommendation.R b/R/pkg/R/mllib_recommendation.R index 5441c4a4022a9..9a77b07462585 100644 --- a/R/pkg/R/mllib_recommendation.R +++ b/R/pkg/R/mllib_recommendation.R @@ -20,7 +20,6 @@ #' S4 class that represents an ALSModel #' #' @param jobj a Java object reference to the backing Scala ALSWrapper -#' @export #' @note ALSModel since 2.1.0 setClass("ALSModel", representation(jobj = "jobj")) @@ -55,7 +54,6 @@ setClass("ALSModel", representation(jobj = "jobj")) #' @rdname spark.als #' @aliases spark.als,SparkDataFrame-method #' @name spark.als -#' @export #' @examples #' \dontrun{ #' ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), @@ -118,7 +116,6 @@ setMethod("spark.als", signature(data = "SparkDataFrame"), #' and \code{rank} (rank of the matrix factorization model). #' @rdname spark.als #' @aliases summary,ALSModel-method -#' @export #' @note summary(ALSModel) since 2.1.0 setMethod("summary", signature(object = "ALSModel"), function(object) { @@ -139,7 +136,6 @@ setMethod("summary", signature(object = "ALSModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted values. #' @rdname spark.als #' @aliases predict,ALSModel-method -#' @export #' @note predict(ALSModel) since 2.1.0 setMethod("predict", signature(object = "ALSModel"), function(object, newData) { @@ -155,7 +151,6 @@ setMethod("predict", signature(object = "ALSModel"), #' #' @rdname spark.als #' @aliases write.ml,ALSModel,character-method -#' @export #' @seealso \link{read.ml} #' @note write.ml(ALSModel, character) since 2.1.0 setMethod("write.ml", signature(object = "ALSModel", path = "character"), diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index 545be5e1d89f0..95c1a29905197 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -21,21 +21,18 @@ #' S4 class that represents a AFTSurvivalRegressionModel #' #' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper -#' @export #' @note AFTSurvivalRegressionModel since 2.0.0 setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a generalized linear model #' #' @param jobj a Java object reference to the backing Scala GeneralizedLinearRegressionWrapper -#' @export #' @note GeneralizedLinearRegressionModel since 2.0.0 setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj")) #' S4 class that represents an IsotonicRegressionModel #' #' @param jobj a Java object reference to the backing Scala IsotonicRegressionModel -#' @export #' @note IsotonicRegressionModel since 2.1.0 setClass("IsotonicRegressionModel", representation(jobj = "jobj")) @@ -85,7 +82,6 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' @return \code{spark.glm} returns a fitted generalized linear model. #' @rdname spark.glm #' @name spark.glm -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -211,7 +207,6 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @return \code{glm} returns a fitted generalized linear model. #' @rdname glm #' @aliases glm -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -244,7 +239,6 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat #' and \code{iter} (number of iterations IRLS takes). If there are collinear columns in #' the data, the coefficients matrix only provides coefficients. #' @rdname spark.glm -#' @export #' @note summary(GeneralizedLinearRegressionModel) since 2.0.0 setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), function(object) { @@ -290,7 +284,6 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), #' @rdname spark.glm #' @param x summary object of fitted generalized linear model returned by \code{summary} function. -#' @export #' @note print.summary.GeneralizedLinearRegressionModel since 2.0.0 print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { if (x$is.loaded) { @@ -324,7 +317,6 @@ print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { #' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named #' "prediction". #' @rdname spark.glm -#' @export #' @note predict(GeneralizedLinearRegressionModel) since 1.5.0 setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"), function(object, newData) { @@ -338,7 +330,6 @@ setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"), #' which means throw exception if the output path exists. #' #' @rdname spark.glm -#' @export #' @note write.ml(GeneralizedLinearRegressionModel, character) since 2.0.0 setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -363,7 +354,6 @@ setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", pat #' @rdname spark.isoreg #' @aliases spark.isoreg,SparkDataFrame,formula-method #' @name spark.isoreg -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -412,7 +402,6 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula" #' and \code{predictions} (predictions associated with the boundaries at the same index). #' @rdname spark.isoreg #' @aliases summary,IsotonicRegressionModel-method -#' @export #' @note summary(IsotonicRegressionModel) since 2.1.0 setMethod("summary", signature(object = "IsotonicRegressionModel"), function(object) { @@ -429,7 +418,6 @@ setMethod("summary", signature(object = "IsotonicRegressionModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted values. #' @rdname spark.isoreg #' @aliases predict,IsotonicRegressionModel,SparkDataFrame-method -#' @export #' @note predict(IsotonicRegressionModel) since 2.1.0 setMethod("predict", signature(object = "IsotonicRegressionModel"), function(object, newData) { @@ -444,7 +432,6 @@ setMethod("predict", signature(object = "IsotonicRegressionModel"), #' #' @rdname spark.isoreg #' @aliases write.ml,IsotonicRegressionModel,character-method -#' @export #' @note write.ml(IsotonicRegression, character) since 2.1.0 setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -477,7 +464,6 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char #' @return \code{spark.survreg} returns a fitted AFT survival regression model. #' @rdname spark.survreg #' @seealso survival: \url{https://cran.r-project.org/package=survival} -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(ovarian) @@ -517,7 +503,6 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula #' The list includes the model's \code{coefficients} (features, coefficients, #' intercept and log(scale)). #' @rdname spark.survreg -#' @export #' @note summary(AFTSurvivalRegressionModel) since 2.0.0 setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), function(object) { @@ -537,7 +522,6 @@ setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted values #' on the original scale of the data (mean predicted value at scale = 1.0). #' @rdname spark.survreg -#' @export #' @note predict(AFTSurvivalRegressionModel) since 2.0.0 setMethod("predict", signature(object = "AFTSurvivalRegressionModel"), function(object, newData) { @@ -550,7 +534,6 @@ setMethod("predict", signature(object = "AFTSurvivalRegressionModel"), #' @param overwrite overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. #' @rdname spark.survreg -#' @export #' @note write.ml(AFTSurvivalRegressionModel, character) since 2.0.0 #' @seealso \link{write.ml} setMethod("write.ml", signature(object = "AFTSurvivalRegressionModel", path = "character"), diff --git a/R/pkg/R/mllib_stat.R b/R/pkg/R/mllib_stat.R index 3e013f1d45e38..f8c3329359961 100644 --- a/R/pkg/R/mllib_stat.R +++ b/R/pkg/R/mllib_stat.R @@ -20,7 +20,6 @@ #' S4 class that represents an KSTest #' #' @param jobj a Java object reference to the backing Scala KSTestWrapper -#' @export #' @note KSTest since 2.1.0 setClass("KSTest", representation(jobj = "jobj")) @@ -52,7 +51,6 @@ setClass("KSTest", representation(jobj = "jobj")) #' @name spark.kstest #' @seealso \href{http://spark.apache.org/docs/latest/mllib-statistics.html#hypothesis-testing}{ #' MLlib: Hypothesis Testing} -#' @export #' @examples #' \dontrun{ #' data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25)) @@ -94,7 +92,6 @@ setMethod("spark.kstest", signature(data = "SparkDataFrame"), #' parameters tested against) and \code{degreesOfFreedom} (degrees of freedom of the test). #' @rdname spark.kstest #' @aliases summary,KSTest-method -#' @export #' @note summary(KSTest) since 2.1.0 setMethod("summary", signature(object = "KSTest"), function(object) { @@ -117,7 +114,6 @@ setMethod("summary", signature(object = "KSTest"), #' @rdname spark.kstest #' @param x summary object of KSTest returned by \code{summary}. -#' @export #' @note print.summary.KSTest since 2.1.0 print.summary.KSTest <- function(x, ...) { jobj <- x$jobj diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 4e5ddf22ee16d..6769be038efa9 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -20,42 +20,36 @@ #' S4 class that represents a GBTRegressionModel #' #' @param jobj a Java object reference to the backing Scala GBTRegressionModel -#' @export #' @note GBTRegressionModel since 2.1.0 setClass("GBTRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a GBTClassificationModel #' #' @param jobj a Java object reference to the backing Scala GBTClassificationModel -#' @export #' @note GBTClassificationModel since 2.1.0 setClass("GBTClassificationModel", representation(jobj = "jobj")) #' S4 class that represents a RandomForestRegressionModel #' #' @param jobj a Java object reference to the backing Scala RandomForestRegressionModel -#' @export #' @note RandomForestRegressionModel since 2.1.0 setClass("RandomForestRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a RandomForestClassificationModel #' #' @param jobj a Java object reference to the backing Scala RandomForestClassificationModel -#' @export #' @note RandomForestClassificationModel since 2.1.0 setClass("RandomForestClassificationModel", representation(jobj = "jobj")) #' S4 class that represents a DecisionTreeRegressionModel #' #' @param jobj a Java object reference to the backing Scala DecisionTreeRegressionModel -#' @export #' @note DecisionTreeRegressionModel since 2.3.0 setClass("DecisionTreeRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a DecisionTreeClassificationModel #' #' @param jobj a Java object reference to the backing Scala DecisionTreeClassificationModel -#' @export #' @note DecisionTreeClassificationModel since 2.3.0 setClass("DecisionTreeClassificationModel", representation(jobj = "jobj")) @@ -179,7 +173,6 @@ print.summary.decisionTree <- function(x) { #' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model. #' @rdname spark.gbt #' @name spark.gbt -#' @export #' @examples #' \dontrun{ #' # fit a Gradient Boosted Tree Regression Model @@ -261,7 +254,6 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), #' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights). #' @rdname spark.gbt #' @aliases summary,GBTRegressionModel-method -#' @export #' @note summary(GBTRegressionModel) since 2.1.0 setMethod("summary", signature(object = "GBTRegressionModel"), function(object) { @@ -275,7 +267,6 @@ setMethod("summary", signature(object = "GBTRegressionModel"), #' @param x summary object of Gradient Boosted Tree regression model or classification model #' returned by \code{summary}. #' @rdname spark.gbt -#' @export #' @note print.summary.GBTRegressionModel since 2.1.0 print.summary.GBTRegressionModel <- function(x, ...) { print.summary.treeEnsemble(x) @@ -285,7 +276,6 @@ print.summary.GBTRegressionModel <- function(x, ...) { #' @rdname spark.gbt #' @aliases summary,GBTClassificationModel-method -#' @export #' @note summary(GBTClassificationModel) since 2.1.0 setMethod("summary", signature(object = "GBTClassificationModel"), function(object) { @@ -297,7 +287,6 @@ setMethod("summary", signature(object = "GBTClassificationModel"), # Prints the summary of Gradient Boosted Tree Classification Model #' @rdname spark.gbt -#' @export #' @note print.summary.GBTClassificationModel since 2.1.0 print.summary.GBTClassificationModel <- function(x, ...) { print.summary.treeEnsemble(x) @@ -310,7 +299,6 @@ print.summary.GBTClassificationModel <- function(x, ...) { #' "prediction". #' @rdname spark.gbt #' @aliases predict,GBTRegressionModel-method -#' @export #' @note predict(GBTRegressionModel) since 2.1.0 setMethod("predict", signature(object = "GBTRegressionModel"), function(object, newData) { @@ -319,7 +307,6 @@ setMethod("predict", signature(object = "GBTRegressionModel"), #' @rdname spark.gbt #' @aliases predict,GBTClassificationModel-method -#' @export #' @note predict(GBTClassificationModel) since 2.1.0 setMethod("predict", signature(object = "GBTClassificationModel"), function(object, newData) { @@ -334,7 +321,6 @@ setMethod("predict", signature(object = "GBTClassificationModel"), #' which means throw exception if the output path exists. #' @aliases write.ml,GBTRegressionModel,character-method #' @rdname spark.gbt -#' @export #' @note write.ml(GBTRegressionModel, character) since 2.1.0 setMethod("write.ml", signature(object = "GBTRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -343,7 +329,6 @@ setMethod("write.ml", signature(object = "GBTRegressionModel", path = "character #' @aliases write.ml,GBTClassificationModel,character-method #' @rdname spark.gbt -#' @export #' @note write.ml(GBTClassificationModel, character) since 2.1.0 setMethod("write.ml", signature(object = "GBTClassificationModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -402,7 +387,6 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' @return \code{spark.randomForest} returns a fitted Random Forest model. #' @rdname spark.randomForest #' @name spark.randomForest -#' @export #' @examples #' \dontrun{ #' # fit a Random Forest Regression Model @@ -480,7 +464,6 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo #' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights). #' @rdname spark.randomForest #' @aliases summary,RandomForestRegressionModel-method -#' @export #' @note summary(RandomForestRegressionModel) since 2.1.0 setMethod("summary", signature(object = "RandomForestRegressionModel"), function(object) { @@ -494,7 +477,6 @@ setMethod("summary", signature(object = "RandomForestRegressionModel"), #' @param x summary object of Random Forest regression model or classification model #' returned by \code{summary}. #' @rdname spark.randomForest -#' @export #' @note print.summary.RandomForestRegressionModel since 2.1.0 print.summary.RandomForestRegressionModel <- function(x, ...) { print.summary.treeEnsemble(x) @@ -504,7 +486,6 @@ print.summary.RandomForestRegressionModel <- function(x, ...) { #' @rdname spark.randomForest #' @aliases summary,RandomForestClassificationModel-method -#' @export #' @note summary(RandomForestClassificationModel) since 2.1.0 setMethod("summary", signature(object = "RandomForestClassificationModel"), function(object) { @@ -516,7 +497,6 @@ setMethod("summary", signature(object = "RandomForestClassificationModel"), # Prints the summary of Random Forest Classification Model #' @rdname spark.randomForest -#' @export #' @note print.summary.RandomForestClassificationModel since 2.1.0 print.summary.RandomForestClassificationModel <- function(x, ...) { print.summary.treeEnsemble(x) @@ -529,7 +509,6 @@ print.summary.RandomForestClassificationModel <- function(x, ...) { #' "prediction". #' @rdname spark.randomForest #' @aliases predict,RandomForestRegressionModel-method -#' @export #' @note predict(RandomForestRegressionModel) since 2.1.0 setMethod("predict", signature(object = "RandomForestRegressionModel"), function(object, newData) { @@ -538,7 +517,6 @@ setMethod("predict", signature(object = "RandomForestRegressionModel"), #' @rdname spark.randomForest #' @aliases predict,RandomForestClassificationModel-method -#' @export #' @note predict(RandomForestClassificationModel) since 2.1.0 setMethod("predict", signature(object = "RandomForestClassificationModel"), function(object, newData) { @@ -554,7 +532,6 @@ setMethod("predict", signature(object = "RandomForestClassificationModel"), #' #' @aliases write.ml,RandomForestRegressionModel,character-method #' @rdname spark.randomForest -#' @export #' @note write.ml(RandomForestRegressionModel, character) since 2.1.0 setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -563,7 +540,6 @@ setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = " #' @aliases write.ml,RandomForestClassificationModel,character-method #' @rdname spark.randomForest -#' @export #' @note write.ml(RandomForestClassificationModel, character) since 2.1.0 setMethod("write.ml", signature(object = "RandomForestClassificationModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -617,7 +593,6 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path #' @return \code{spark.decisionTree} returns a fitted Decision Tree model. #' @rdname spark.decisionTree #' @name spark.decisionTree -#' @export #' @examples #' \dontrun{ #' # fit a Decision Tree Regression Model @@ -690,7 +665,6 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo #' trees). #' @rdname spark.decisionTree #' @aliases summary,DecisionTreeRegressionModel-method -#' @export #' @note summary(DecisionTreeRegressionModel) since 2.3.0 setMethod("summary", signature(object = "DecisionTreeRegressionModel"), function(object) { @@ -704,7 +678,6 @@ setMethod("summary", signature(object = "DecisionTreeRegressionModel"), #' @param x summary object of Decision Tree regression model or classification model #' returned by \code{summary}. #' @rdname spark.decisionTree -#' @export #' @note print.summary.DecisionTreeRegressionModel since 2.3.0 print.summary.DecisionTreeRegressionModel <- function(x, ...) { print.summary.decisionTree(x) @@ -714,7 +687,6 @@ print.summary.DecisionTreeRegressionModel <- function(x, ...) { #' @rdname spark.decisionTree #' @aliases summary,DecisionTreeClassificationModel-method -#' @export #' @note summary(DecisionTreeClassificationModel) since 2.3.0 setMethod("summary", signature(object = "DecisionTreeClassificationModel"), function(object) { @@ -726,7 +698,6 @@ setMethod("summary", signature(object = "DecisionTreeClassificationModel"), # Prints the summary of Decision Tree Classification Model #' @rdname spark.decisionTree -#' @export #' @note print.summary.DecisionTreeClassificationModel since 2.3.0 print.summary.DecisionTreeClassificationModel <- function(x, ...) { print.summary.decisionTree(x) @@ -739,7 +710,6 @@ print.summary.DecisionTreeClassificationModel <- function(x, ...) { #' "prediction". #' @rdname spark.decisionTree #' @aliases predict,DecisionTreeRegressionModel-method -#' @export #' @note predict(DecisionTreeRegressionModel) since 2.3.0 setMethod("predict", signature(object = "DecisionTreeRegressionModel"), function(object, newData) { @@ -748,7 +718,6 @@ setMethod("predict", signature(object = "DecisionTreeRegressionModel"), #' @rdname spark.decisionTree #' @aliases predict,DecisionTreeClassificationModel-method -#' @export #' @note predict(DecisionTreeClassificationModel) since 2.3.0 setMethod("predict", signature(object = "DecisionTreeClassificationModel"), function(object, newData) { @@ -764,7 +733,6 @@ setMethod("predict", signature(object = "DecisionTreeClassificationModel"), #' #' @aliases write.ml,DecisionTreeRegressionModel,character-method #' @rdname spark.decisionTree -#' @export #' @note write.ml(DecisionTreeRegressionModel, character) since 2.3.0 setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -773,7 +741,6 @@ setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = " #' @aliases write.ml,DecisionTreeClassificationModel,character-method #' @rdname spark.decisionTree -#' @export #' @note write.ml(DecisionTreeClassificationModel, character) since 2.3.0 setMethod("write.ml", signature(object = "DecisionTreeClassificationModel", path = "character"), function(object, path, overwrite = FALSE) { diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R index a53c92c2c4815..7d04bffcba3a4 100644 --- a/R/pkg/R/mllib_utils.R +++ b/R/pkg/R/mllib_utils.R @@ -31,7 +31,6 @@ #' MLlib model below. #' @rdname write.ml #' @name write.ml -#' @export #' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree}, #' @seealso \link{spark.gaussianMixture}, \link{spark.gbt}, #' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg}, @@ -48,7 +47,6 @@ NULL #' MLlib model below. #' @rdname predict #' @name predict -#' @export #' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree}, #' @seealso \link{spark.gaussianMixture}, \link{spark.gbt}, #' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg}, @@ -75,7 +73,6 @@ predict_internal <- function(object, newData) { #' @return A fitted MLlib model. #' @rdname read.ml #' @name read.ml -#' @export #' @seealso \link{write.ml} #' @examples #' \dontrun{ diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 65f418740c643..9831fc3cc6d01 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -29,7 +29,6 @@ #' @param ... additional structField objects #' @return a structType object #' @rdname structType -#' @export #' @examples #'\dontrun{ #' schema <- structType(structField("a", "integer"), structField("c", "string"), @@ -49,7 +48,6 @@ structType <- function(x, ...) { #' @rdname structType #' @method structType jobj -#' @export structType.jobj <- function(x, ...) { obj <- structure(list(), class = "structType") obj$jobj <- x @@ -59,7 +57,6 @@ structType.jobj <- function(x, ...) { #' @rdname structType #' @method structType structField -#' @export structType.structField <- function(x, ...) { fields <- list(x, ...) if (!all(sapply(fields, inherits, "structField"))) { @@ -76,7 +73,6 @@ structType.structField <- function(x, ...) { #' @rdname structType #' @method structType character -#' @export structType.character <- function(x, ...) { if (!is.character(x)) { stop("schema must be a DDL-formatted string.") @@ -119,7 +115,6 @@ print.structType <- function(x, ...) { #' @param ... additional argument(s) passed to the method. #' @return A structField object. #' @rdname structField -#' @export #' @examples #'\dontrun{ #' field1 <- structField("a", "integer") @@ -137,7 +132,6 @@ structField <- function(x, ...) { #' @rdname structField #' @method structField jobj -#' @export structField.jobj <- function(x, ...) { obj <- structure(list(), class = "structField") obj$jobj <- x @@ -212,7 +206,6 @@ checkType <- function(type) { #' @param type The data type of the field #' @param nullable A logical vector indicating whether or not the field is nullable #' @rdname structField -#' @export structField.character <- function(x, type, nullable = TRUE, ...) { if (class(x) != "character") { stop("Field name must be a string.") diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 965471f3b07a0..a480ac606f10d 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -35,7 +35,6 @@ connExists <- function(env) { #' Also terminates the backend this R session is connected to. #' @rdname sparkR.session.stop #' @name sparkR.session.stop -#' @export #' @note sparkR.session.stop since 2.0.0 sparkR.session.stop <- function() { env <- .sparkREnv @@ -84,7 +83,6 @@ sparkR.session.stop <- function() { #' @rdname sparkR.session.stop #' @name sparkR.stop -#' @export #' @note sparkR.stop since 1.4.0 sparkR.stop <- function() { sparkR.session.stop() @@ -103,7 +101,6 @@ sparkR.stop <- function() { #' @param sparkPackages Character vector of package coordinates #' @seealso \link{sparkR.session} #' @rdname sparkR.init-deprecated -#' @export #' @examples #'\dontrun{ #' sc <- sparkR.init("local[2]", "SparkR", "/home/spark") @@ -270,7 +267,6 @@ sparkR.sparkContext <- function( #' @param jsc The existing JavaSparkContext created with SparkR.init() #' @seealso \link{sparkR.session} #' @rdname sparkRSQL.init-deprecated -#' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() @@ -298,7 +294,6 @@ sparkRSQL.init <- function(jsc = NULL) { #' @param jsc The existing JavaSparkContext created with SparkR.init() #' @seealso \link{sparkR.session} #' @rdname sparkRHive.init-deprecated -#' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() @@ -347,7 +342,6 @@ sparkRHive.init <- function(jsc = NULL) { #' @param enableHiveSupport enable support for Hive, fallback if not built with Hive support; once #' set, this cannot be turned off on an existing session #' @param ... named Spark properties passed to the method. -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -442,7 +436,6 @@ sparkR.session <- function( #' @return the SparkUI URL, or NA if it is disabled, or not started. #' @rdname sparkR.uiWebUrl #' @name sparkR.uiWebUrl -#' @export #' @examples #'\dontrun{ #' sparkR.session() diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index c8af798830b30..497f18c763048 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -37,7 +37,6 @@ setOldClass("jobj") #' @name crosstab #' @aliases crosstab,SparkDataFrame,character,character-method #' @family stat functions -#' @export #' @examples #' \dontrun{ #' df <- read.json("/path/to/file.json") @@ -63,7 +62,6 @@ setMethod("crosstab", #' @rdname cov #' @aliases cov,SparkDataFrame-method #' @family stat functions -#' @export #' @examples #' #' \dontrun{ @@ -92,7 +90,6 @@ setMethod("cov", #' @name corr #' @aliases corr,SparkDataFrame-method #' @family stat functions -#' @export #' @examples #' #' \dontrun{ @@ -124,7 +121,6 @@ setMethod("corr", #' @name freqItems #' @aliases freqItems,SparkDataFrame,character-method #' @family stat functions -#' @export #' @examples #' \dontrun{ #' df <- read.json("/path/to/file.json") @@ -168,7 +164,6 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"), #' @name approxQuantile #' @aliases approxQuantile,SparkDataFrame,character,numeric,numeric-method #' @family stat functions -#' @export #' @examples #' \dontrun{ #' df <- read.json("/path/to/file.json") @@ -205,7 +200,6 @@ setMethod("approxQuantile", #' @aliases sampleBy,SparkDataFrame,character,list,numeric-method #' @name sampleBy #' @family stat functions -#' @export #' @examples #'\dontrun{ #' df <- read.json("/path/to/file.json") diff --git a/R/pkg/R/streaming.R b/R/pkg/R/streaming.R index 8390bd5e6de72..fc83463f72cd4 100644 --- a/R/pkg/R/streaming.R +++ b/R/pkg/R/streaming.R @@ -28,7 +28,6 @@ NULL #' @seealso \link{read.stream} #' #' @param ssq A Java object reference to the backing Scala StreamingQuery -#' @export #' @note StreamingQuery since 2.2.0 #' @note experimental setClass("StreamingQuery", @@ -45,7 +44,6 @@ streamingQuery <- function(ssq) { } #' @rdname show -#' @export #' @note show(StreamingQuery) since 2.2.0 setMethod("show", "StreamingQuery", function(object) { @@ -70,7 +68,6 @@ setMethod("show", "StreamingQuery", #' @aliases queryName,StreamingQuery-method #' @family StreamingQuery methods #' @seealso \link{write.stream} -#' @export #' @examples #' \dontrun{ queryName(sq) } #' @note queryName(StreamingQuery) since 2.2.0 @@ -85,7 +82,6 @@ setMethod("queryName", #' @name explain #' @aliases explain,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ explain(sq) } #' @note explain(StreamingQuery) since 2.2.0 @@ -104,7 +100,6 @@ setMethod("explain", #' @name lastProgress #' @aliases lastProgress,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ lastProgress(sq) } #' @note lastProgress(StreamingQuery) since 2.2.0 @@ -129,7 +124,6 @@ setMethod("lastProgress", #' @name status #' @aliases status,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ status(sq) } #' @note status(StreamingQuery) since 2.2.0 @@ -150,7 +144,6 @@ setMethod("status", #' @name isActive #' @aliases isActive,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ isActive(sq) } #' @note isActive(StreamingQuery) since 2.2.0 @@ -177,7 +170,6 @@ setMethod("isActive", #' @name awaitTermination #' @aliases awaitTermination,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ awaitTermination(sq, 10000) } #' @note awaitTermination(StreamingQuery) since 2.2.0 @@ -202,7 +194,6 @@ setMethod("awaitTermination", #' @name stopQuery #' @aliases stopQuery,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ stopQuery(sq) } #' @note stopQuery(StreamingQuery) since 2.2.0 diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 164cd6d01a347..f1b5ecaa017df 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -108,7 +108,6 @@ isRDD <- function(name, env) { #' #' @param key the object to be hashed #' @return the hash code as an integer -#' @export #' @examples #'\dontrun{ #' hashCode(1L) # 1 diff --git a/R/pkg/R/window.R b/R/pkg/R/window.R index 0799d841e5dc9..396b27bee80c6 100644 --- a/R/pkg/R/window.R +++ b/R/pkg/R/window.R @@ -29,7 +29,6 @@ #' @rdname windowPartitionBy #' @name windowPartitionBy #' @aliases windowPartitionBy,character-method -#' @export #' @examples #' \dontrun{ #' ws <- orderBy(windowPartitionBy("key1", "key2"), "key3") @@ -52,7 +51,6 @@ setMethod("windowPartitionBy", #' @rdname windowPartitionBy #' @name windowPartitionBy #' @aliases windowPartitionBy,Column-method -#' @export #' @note windowPartitionBy(Column) since 2.0.0 setMethod("windowPartitionBy", signature(col = "Column"), @@ -78,7 +76,6 @@ setMethod("windowPartitionBy", #' @rdname windowOrderBy #' @name windowOrderBy #' @aliases windowOrderBy,character-method -#' @export #' @examples #' \dontrun{ #' ws <- windowOrderBy("key1", "key2") @@ -101,7 +98,6 @@ setMethod("windowOrderBy", #' @rdname windowOrderBy #' @name windowOrderBy #' @aliases windowOrderBy,Column-method -#' @export #' @note windowOrderBy(Column) since 2.0.0 setMethod("windowOrderBy", signature(col = "Column"), From 98a5c0a35f0a24730f5074522939acf57ef95422 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 5 Mar 2018 10:50:00 -0800 Subject: [PATCH 0429/2461] [SPARK-22882][ML][TESTS] ML test for structured streaming: ml.classification ## What changes were proposed in this pull request? adding Structured Streaming tests for all Models/Transformers in spark.ml.classification ## How was this patch tested? N/A Author: WeichenXu Closes #20121 from WeichenXu123/ml_stream_test_classification. --- .../DecisionTreeClassifierSuite.scala | 29 +-- .../classification/GBTClassifierSuite.scala | 77 ++---- .../ml/classification/LinearSVCSuite.scala | 15 +- .../LogisticRegressionSuite.scala | 229 +++++++----------- .../MultilayerPerceptronClassifierSuite.scala | 44 ++-- .../ml/classification/NaiveBayesSuite.scala | 47 ++-- .../ml/classification/OneVsRestSuite.scala | 21 +- .../ProbabilisticClassifierSuite.scala | 29 +-- .../RandomForestClassifierSuite.scala | 16 +- 9 files changed, 202 insertions(+), 305 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 38b265d62611b..eeb0324187c5b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -23,15 +23,14 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode} import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} -import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, + DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} -class DecisionTreeClassifierSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { import DecisionTreeClassifierSuite.compareAPIs import testImplicits._ @@ -251,20 +250,18 @@ class DecisionTreeClassifierSuite MLTestingUtils.checkCopyAndUids(dt, newTree) - val predictions = newTree.transform(newData) - .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol) - .collect() - - predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) => - assert(pred === rawPred.argmax, - s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") - val sum = rawPred.toArray.sum - assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred, - "probability prediction mismatch") + testTransformer[(Vector, Double)](newData, newTree, + "prediction", "rawPrediction", "probability") { + case Row(pred: Double, rawPred: Vector, probPred: Vector) => + assert(pred === rawPred.argmax, + s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") + val sum = rawPred.toArray.sum + assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred, + "probability prediction mismatch") } ProbabilisticClassifierSuite.testPredictMethods[ - Vector, DecisionTreeClassificationModel](newTree, newData) + Vector, DecisionTreeClassificationModel](this, newTree, newData) } test("training with 1-category categorical feature") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 978f89c459f0a..092b4a01d5b0d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -26,13 +26,12 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.loss.LogLoss -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.util.Utils @@ -40,8 +39,7 @@ import org.apache.spark.util.Utils /** * Test suite for [[GBTClassifier]]. */ -class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { +class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ import GBTClassifierSuite.compareAPIs @@ -126,14 +124,15 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext // should predict all zeros binaryModel.setThresholds(Array(0.0, 1.0)) - val binaryZeroPredictions = binaryModel.transform(df).select("prediction").collect() - assert(binaryZeroPredictions.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](df, binaryModel, "prediction") { + case Row(prediction: Double) => prediction === 0.0 + } // should predict all ones binaryModel.setThresholds(Array(1.0, 0.0)) - val binaryOnePredictions = binaryModel.transform(df).select("prediction").collect() - assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0)) - + testTransformer[(Double, Vector)](df, binaryModel, "prediction") { + case Row(prediction: Double) => prediction === 1.0 + } val gbtBase = new GBTClassifier val model = gbtBase.fit(df) @@ -141,15 +140,18 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext // constant threshold scaling is the same as no thresholds binaryModel.setThresholds(Array(1.0, 1.0)) - val scaledPredictions = binaryModel.transform(df).select("prediction").collect() - assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => - scaled.getDouble(0) === base.getDouble(0) - }) + testTransformerByGlobalCheckFunc[(Double, Vector)](df, binaryModel, "prediction") { + scaledPredictions: Seq[Row] => + assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => + scaled.getDouble(0) === base.getDouble(0) + }) + } // force it to use the predict method model.setRawPredictionCol("").setProbabilityCol("").setThresholds(Array(0, 1)) - val predictionsWithPredict = model.transform(df).select("prediction").collect() - assert(predictionsWithPredict.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](df, model, "prediction") { + case Row(prediction: Double) => prediction === 0.0 + } } test("GBTClassifier: Predictor, Classifier methods") { @@ -169,61 +171,30 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext val blas = BLAS.getInstance() val validationDataset = validationData.toDF(labelCol, featuresCol) - val results = gbtModel.transform(validationDataset) - // check that raw prediction is tree predictions dot tree weights - results.select(rawPredictionCol, featuresCol).collect().foreach { - case Row(raw: Vector, features: Vector) => + testTransformer[(Double, Vector)](validationDataset, gbtModel, + "rawPrediction", "features", "probability", "prediction") { + case Row(raw: Vector, features: Vector, prob: Vector, pred: Double) => assert(raw.size === 2) + // check that raw prediction is tree predictions dot tree weights val treePredictions = gbtModel.trees.map(_.rootNode.predictImpl(features).prediction) val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1, gbtModel.treeWeights, 1) assert(raw ~== Vectors.dense(-prediction, prediction) relTol eps) - } - // Compare rawPrediction with probability - results.select(rawPredictionCol, probabilityCol).collect().foreach { - case Row(raw: Vector, prob: Vector) => - assert(raw.size === 2) + // Compare rawPrediction with probability assert(prob.size === 2) // Note: we should check other loss types for classification if they are added val predFromRaw = raw.toDense.values.map(value => LogLoss.computeProbability(value)) assert(prob(0) ~== predFromRaw(0) relTol eps) assert(prob(1) ~== predFromRaw(1) relTol eps) assert(prob(0) + prob(1) ~== 1.0 absTol absEps) - } - // Compare prediction with probability - results.select(predictionCol, probabilityCol).collect().foreach { - case Row(pred: Double, prob: Vector) => + // Compare prediction with probability val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 assert(pred == predFromProb) } - // force it to use raw2prediction - gbtModel.setRawPredictionCol(rawPredictionCol).setProbabilityCol("") - val resultsUsingRaw2Predict = - gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect() - resultsUsingRaw2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use probability2prediction - gbtModel.setRawPredictionCol("").setProbabilityCol(probabilityCol) - val resultsUsingProb2Predict = - gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect() - resultsUsingProb2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use predict - gbtModel.setRawPredictionCol("").setProbabilityCol("") - val resultsUsingPredict = - gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect() - resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - ProbabilisticClassifierSuite.testPredictMethods[ - Vector, GBTClassificationModel](gbtModel, validationDataset) + Vector, GBTClassificationModel](this, gbtModel, validationDataset) } test("GBT parameter stepSize should be in interval (0, 1]") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index 41a5d22dd6283..a93825b8a812d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -21,20 +21,18 @@ import scala.util.Random import breeze.linalg.{DenseVector => BDV} -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LinearSVCSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.optim.aggregator.HingeAggregator import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions.udf -class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class LinearSVCSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -141,10 +139,11 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau threshold: Double, expected: Set[(Int, Double)]): Unit = { model.setThreshold(threshold) - val results = model.transform(df).select("id", "prediction").collect() - .map(r => (r.getInt(0), r.getDouble(1))) - .toSet - assert(results === expected, s"Failed for threshold = $threshold") + testTransformerByGlobalCheckFunc[(Int, Vector)](df, model, "id", "prediction") { + rows: Seq[Row] => + val results = rows.map(r => (r.getInt(0), r.getDouble(1))).toSet + assert(results === expected, s"Failed for threshold = $threshold") + } } def checkResults(threshold: Double, expected: Set[(Int, Double)]): Unit = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index a5f81a38face9..9987cbf6ba116 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -22,22 +22,20 @@ import scala.language.existentials import scala.util.Random import scala.util.control.Breaks._ -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix, Vector, Vectors} import org.apache.spark.ml.optim.aggregator.LogisticAggregator import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions.{col, lit, rand} import org.apache.spark.sql.types.LongType -class LogisticRegressionSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -332,15 +330,14 @@ class LogisticRegressionSuite val binaryModel = blr.fit(smallBinaryDataset) binaryModel.setThreshold(1.0) - val binaryZeroPredictions = - binaryModel.transform(smallBinaryDataset).select("prediction").collect() - assert(binaryZeroPredictions.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](smallBinaryDataset.toDF(), binaryModel, "prediction") { + row => assert(row.getDouble(0) === 0.0) + } binaryModel.setThreshold(0.0) - val binaryOnePredictions = - binaryModel.transform(smallBinaryDataset).select("prediction").collect() - assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0)) - + testTransformer[(Double, Vector)](smallBinaryDataset.toDF(), binaryModel, "prediction") { + row => assert(row.getDouble(0) === 1.0) + } val mlr = new LogisticRegression().setFamily("multinomial") val model = mlr.fit(smallMultinomialDataset) @@ -348,31 +345,36 @@ class LogisticRegressionSuite // should predict all zeros model.setThresholds(Array(1, 1000, 1000)) - val zeroPredictions = model.transform(smallMultinomialDataset).select("prediction").collect() - assert(zeroPredictions.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") { + row => assert(row.getDouble(0) === 0.0) + } // should predict all ones model.setThresholds(Array(1000, 1, 1000)) - val onePredictions = model.transform(smallMultinomialDataset).select("prediction").collect() - assert(onePredictions.forall(_.getDouble(0) === 1.0)) + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") { + row => assert(row.getDouble(0) === 1.0) + } // should predict all twos model.setThresholds(Array(1000, 1000, 1)) - val twoPredictions = model.transform(smallMultinomialDataset).select("prediction").collect() - assert(twoPredictions.forall(_.getDouble(0) === 2.0)) + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") { + row => assert(row.getDouble(0) === 2.0) + } // constant threshold scaling is the same as no thresholds model.setThresholds(Array(1000, 1000, 1000)) - val scaledPredictions = model.transform(smallMultinomialDataset).select("prediction").collect() - assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => - scaled.getDouble(0) === base.getDouble(0) - }) + testTransformerByGlobalCheckFunc[(Double, Vector)](smallMultinomialDataset.toDF(), model, + "prediction") { scaledPredictions: Seq[Row] => + assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => + scaled.getDouble(0) === base.getDouble(0) + }) + } // force it to use the predict method model.setRawPredictionCol("").setProbabilityCol("").setThresholds(Array(0, 1, 1)) - val predictionsWithPredict = - model.transform(smallMultinomialDataset).select("prediction").collect() - assert(predictionsWithPredict.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") { + row => assert(row.getDouble(0) === 0.0) + } } test("logistic regression doesn't fit intercept when fitIntercept is off") { @@ -403,21 +405,19 @@ class LogisticRegressionSuite // Modify model params, and check that the params worked. model.setThreshold(1.0) - val predAllZero = model.transform(smallBinaryDataset) - .select("prediction", "myProbability") - .collect() - .map { case Row(pred: Double, prob: Vector) => pred } - assert(predAllZero.forall(_ === 0), - s"With threshold=1.0, expected predictions to be all 0, but only" + - s" ${predAllZero.count(_ === 0)} of ${smallBinaryDataset.count()} were 0.") + testTransformerByGlobalCheckFunc[(Double, Vector)](smallBinaryDataset.toDF(), + model, "prediction", "myProbability") { rows => + val predAllZero = rows.map(_.getDouble(0)) + assert(predAllZero.forall(_ === 0), + s"With threshold=1.0, expected predictions to be all 0, but only" + + s" ${predAllZero.count(_ === 0)} of ${smallBinaryDataset.count()} were 0.") + } // Call transform with params, and check that the params worked. - val predNotAllZero = - model.transform(smallBinaryDataset, model.threshold -> 0.0, - model.probabilityCol -> "myProb") - .select("prediction", "myProb") - .collect() - .map { case Row(pred: Double, prob: Vector) => pred } - assert(predNotAllZero.exists(_ !== 0.0)) + testTransformerByGlobalCheckFunc[(Double, Vector)](smallBinaryDataset.toDF(), + model.copy(ParamMap(model.threshold -> 0.0, + model.probabilityCol -> "myProb")), "prediction", "myProb") { + rows => assert(rows.map(_.getDouble(0)).exists(_ !== 0.0)) + } // Call fit() with new params, and check as many params as we can. lr.setThresholds(Array(0.6, 0.4)) @@ -441,10 +441,10 @@ class LogisticRegressionSuite val numFeatures = smallMultinomialDataset.select("features").first().getAs[Vector](0).size assert(model.numFeatures === numFeatures) - val results = model.transform(smallMultinomialDataset) - // check that raw prediction is coefficients dot features + intercept - results.select("rawPrediction", "features").collect().foreach { - case Row(raw: Vector, features: Vector) => + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), + model, "rawPrediction", "features", "probability") { + case Row(raw: Vector, features: Vector, prob: Vector) => + // check that raw prediction is coefficients dot features + intercept assert(raw.size === 3) val margins = Array.tabulate(3) { k => var margin = 0.0 @@ -455,12 +455,7 @@ class LogisticRegressionSuite margin } assert(raw ~== Vectors.dense(margins) relTol eps) - } - - // Compare rawPrediction with probability - results.select("rawPrediction", "probability").collect().foreach { - case Row(raw: Vector, prob: Vector) => - assert(raw.size === 3) + // Compare rawPrediction with probability assert(prob.size === 3) val max = raw.toArray.max val subtract = if (max > 0) max else 0.0 @@ -472,39 +467,8 @@ class LogisticRegressionSuite assert(prob(2) ~== 1.0 - probFromRaw1 - probFromRaw0 relTol eps) } - // Compare prediction with probability - results.select("prediction", "probability").collect().foreach { - case Row(pred: Double, prob: Vector) => - val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 - assert(pred == predFromProb) - } - - // force it to use raw2prediction - model.setRawPredictionCol("rawPrediction").setProbabilityCol("") - val resultsUsingRaw2Predict = - model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() - resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use probability2prediction - model.setRawPredictionCol("").setProbabilityCol("probability") - val resultsUsingProb2Predict = - model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() - resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use predict - model.setRawPredictionCol("").setProbabilityCol("") - val resultsUsingPredict = - model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() - resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - ProbabilisticClassifierSuite.testPredictMethods[ - Vector, LogisticRegressionModel](model, smallMultinomialDataset) + Vector, LogisticRegressionModel](this, model, smallMultinomialDataset) } test("binary logistic regression: Predictor, Classifier methods") { @@ -517,51 +481,22 @@ class LogisticRegressionSuite val numFeatures = smallBinaryDataset.select("features").first().getAs[Vector](0).size assert(model.numFeatures === numFeatures) - val results = model.transform(smallBinaryDataset) - - // Compare rawPrediction with probability - results.select("rawPrediction", "probability").collect().foreach { - case Row(raw: Vector, prob: Vector) => + testTransformer[(Double, Vector)](smallBinaryDataset.toDF(), + model, "rawPrediction", "probability", "prediction") { + case Row(raw: Vector, prob: Vector, pred: Double) => + // Compare rawPrediction with probability assert(raw.size === 2) assert(prob.size === 2) val probFromRaw1 = 1.0 / (1.0 + math.exp(-raw(1))) assert(prob(1) ~== probFromRaw1 relTol eps) assert(prob(0) ~== 1.0 - probFromRaw1 relTol eps) - } - - // Compare prediction with probability - results.select("prediction", "probability").collect().foreach { - case Row(pred: Double, prob: Vector) => + // Compare prediction with probability val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 assert(pred == predFromProb) } - // force it to use raw2prediction - model.setRawPredictionCol("rawPrediction").setProbabilityCol("") - val resultsUsingRaw2Predict = - model.transform(smallBinaryDataset).select("prediction").as[Double].collect() - resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use probability2prediction - model.setRawPredictionCol("").setProbabilityCol("probability") - val resultsUsingProb2Predict = - model.transform(smallBinaryDataset).select("prediction").as[Double].collect() - resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use predict - model.setRawPredictionCol("").setProbabilityCol("") - val resultsUsingPredict = - model.transform(smallBinaryDataset).select("prediction").as[Double].collect() - resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - ProbabilisticClassifierSuite.testPredictMethods[ - Vector, LogisticRegressionModel](model, smallBinaryDataset) + Vector, LogisticRegressionModel](this, model, smallBinaryDataset) } test("coefficients and intercept methods") { @@ -616,19 +551,21 @@ class LogisticRegressionSuite LabeledPoint(1.0, Vectors.dense(0.0, 1000.0)), LabeledPoint(1.0, Vectors.dense(0.0, -1.0)) ).toDF() - val results = model.transform(overFlowData).select("rawPrediction", "probability").collect() - - // probabilities are correct when margins have to be adjusted - val raw1 = results(0).getAs[Vector](0) - val prob1 = results(0).getAs[Vector](1) - assert(raw1 === Vectors.dense(1000.0, 2000.0, 3000.0)) - assert(prob1 ~== Vectors.dense(0.0, 0.0, 1.0) absTol eps) - - // probabilities are correct when margins don't have to be adjusted - val raw2 = results(1).getAs[Vector](0) - val prob2 = results(1).getAs[Vector](1) - assert(raw2 === Vectors.dense(-1.0, -2.0, -3.0)) - assert(prob2 ~== Vectors.dense(0.66524096, 0.24472847, 0.09003057) relTol eps) + + testTransformerByGlobalCheckFunc[(Double, Vector)](overFlowData.toDF(), + model, "rawPrediction", "probability") { results: Seq[Row] => + // probabilities are correct when margins have to be adjusted + val raw1 = results(0).getAs[Vector](0) + val prob1 = results(0).getAs[Vector](1) + assert(raw1 === Vectors.dense(1000.0, 2000.0, 3000.0)) + assert(prob1 ~== Vectors.dense(0.0, 0.0, 1.0) absTol eps) + + // probabilities are correct when margins don't have to be adjusted + val raw2 = results(1).getAs[Vector](0) + val prob2 = results(1).getAs[Vector](1) + assert(raw2 === Vectors.dense(-1.0, -2.0, -3.0)) + assert(prob2 ~== Vectors.dense(0.66524096, 0.24472847, 0.09003057) relTol eps) + } } test("MultiClassSummarizer") { @@ -2567,10 +2504,13 @@ class LogisticRegressionSuite val model1 = lr.fit(smallBinaryDataset) val lr2 = new LogisticRegression().setInitialModel(model1).setMaxIter(5).setFamily("binomial") val model2 = lr2.fit(smallBinaryDataset) - val predictions1 = model1.transform(smallBinaryDataset).select("prediction").collect() - val predictions2 = model2.transform(smallBinaryDataset).select("prediction").collect() - predictions1.zip(predictions2).foreach { case (Row(p1: Double), Row(p2: Double)) => - assert(p1 === p2) + val binaryExpected = model1.transform(smallBinaryDataset).select("prediction").collect() + .map(_.getDouble(0)) + for (model <- Seq(model1, model2)) { + testTransformerByGlobalCheckFunc[(Double, Vector)](smallBinaryDataset.toDF(), model, + "prediction") { rows: Seq[Row] => + rows.map(_.getDouble(0)).toArray === binaryExpected + } } assert(model2.summary.totalIterations === 1) @@ -2579,10 +2519,13 @@ class LogisticRegressionSuite val lr4 = new LogisticRegression() .setInitialModel(model3).setMaxIter(5).setFamily("multinomial") val model4 = lr4.fit(smallMultinomialDataset) - val predictions3 = model3.transform(smallMultinomialDataset).select("prediction").collect() - val predictions4 = model4.transform(smallMultinomialDataset).select("prediction").collect() - predictions3.zip(predictions4).foreach { case (Row(p1: Double), Row(p2: Double)) => - assert(p1 === p2) + val multinomialExpected = model3.transform(smallMultinomialDataset).select("prediction") + .collect().map(_.getDouble(0)) + for (model <- Seq(model3, model4)) { + testTransformerByGlobalCheckFunc[(Double, Vector)](smallMultinomialDataset.toDF(), model, + "prediction") { rows: Seq[Row] => + rows.map(_.getDouble(0)).toArray === multinomialExpected + } } assert(model4.summary.totalIterations === 1) } @@ -2638,8 +2581,8 @@ class LogisticRegressionSuite LabeledPoint(4.0, Vectors.dense(2.0))).toDF() val mlr = new LogisticRegression().setFamily("multinomial") val model = mlr.fit(constantData) - val results = model.transform(constantData) - results.select("rawPrediction", "probability", "prediction").collect().foreach { + testTransformer[(Double, Vector)](constantData, model, + "rawPrediction", "probability", "prediction") { case Row(raw: Vector, prob: Vector, pred: Double) => assert(raw === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, Double.PositiveInfinity))) assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0))) @@ -2653,8 +2596,8 @@ class LogisticRegressionSuite LabeledPoint(0.0, Vectors.dense(1.0)), LabeledPoint(0.0, Vectors.dense(2.0))).toDF() val modelZeroLabel = mlr.setFitIntercept(false).fit(constantZeroData) - val resultsZero = modelZeroLabel.transform(constantZeroData) - resultsZero.select("rawPrediction", "probability", "prediction").collect().foreach { + testTransformer[(Double, Vector)](constantZeroData, modelZeroLabel, + "rawPrediction", "probability", "prediction") { case Row(raw: Vector, prob: Vector, pred: Double) => assert(prob === Vectors.dense(Array(1.0))) assert(pred === 0.0) @@ -2666,8 +2609,8 @@ class LogisticRegressionSuite val constantDataWithMetadata = constantData .select(constantData("label").as("label", labelMeta), constantData("features")) val modelWithMetadata = mlr.setFitIntercept(true).fit(constantDataWithMetadata) - val resultsWithMetadata = modelWithMetadata.transform(constantDataWithMetadata) - resultsWithMetadata.select("rawPrediction", "probability", "prediction").collect().foreach { + testTransformer[(Double, Vector)](constantDataWithMetadata, modelWithMetadata, + "rawPrediction", "probability", "prediction") { case Row(raw: Vector, prob: Vector, pred: Double) => assert(raw === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, Double.PositiveInfinity, 0.0))) assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0, 0.0))) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index d3141ec708560..daa58a56896d7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -17,22 +17,17 @@ package org.apache.spark.ml.classification -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} -import org.apache.spark.sql.functions._ -class MultilayerPerceptronClassifierSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class MultilayerPerceptronClassifierSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -75,11 +70,9 @@ class MultilayerPerceptronClassifierSuite .setMaxIter(100) .setSolver("l-bfgs") val model = trainer.fit(dataset) - val result = model.transform(dataset) MLTestingUtils.checkCopyAndUids(trainer, model) - val predictionAndLabels = result.select("prediction", "label").collect() - predictionAndLabels.foreach { case Row(p: Double, l: Double) => - assert(p == l) + testTransformer[(Vector, Double)](dataset.toDF(), model, "prediction", "label") { + case Row(p: Double, l: Double) => assert(p == l) } } @@ -99,13 +92,12 @@ class MultilayerPerceptronClassifierSuite .setMaxIter(100) .setSolver("l-bfgs") val model = trainer.fit(strongDataset) - val result = model.transform(strongDataset) - result.select("probability", "expectedProbability").collect().foreach { - case Row(p: Vector, e: Vector) => - assert(p ~== e absTol 1e-3) + testTransformer[(Vector, Double, Vector)](strongDataset.toDF(), model, + "probability", "expectedProbability") { + case Row(p: Vector, e: Vector) => assert(p ~== e absTol 1e-3) } ProbabilisticClassifierSuite.testPredictMethods[ - Vector, MultilayerPerceptronClassificationModel](model, strongDataset) + Vector, MultilayerPerceptronClassificationModel](this, model, strongDataset) } test("test model probability") { @@ -118,11 +110,10 @@ class MultilayerPerceptronClassifierSuite .setSolver("l-bfgs") val model = trainer.fit(dataset) model.setProbabilityCol("probability") - val result = model.transform(dataset) - val features2prob = udf { features: Vector => model.mlpModel.predict(features) } - result.select(features2prob(col("features")), col("probability")).collect().foreach { - case Row(p1: Vector, p2: Vector) => - assert(p1 ~== p2 absTol 1e-3) + testTransformer[(Vector, Double)](dataset.toDF(), model, "features", "probability") { + case Row(features: Vector, prob: Vector) => + val prob2 = model.mlpModel.predict(features) + assert(prob ~== prob2 absTol 1e-3) } } @@ -175,9 +166,6 @@ class MultilayerPerceptronClassifierSuite val model = trainer.fit(dataFrame) val numFeatures = dataFrame.select("features").first().getAs[Vector](0).size assert(model.numFeatures === numFeatures) - val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label").rdd.map { - case Row(p: Double, l: Double) => (p, l) - } // train multinomial logistic regression val lr = new LogisticRegressionWithLBFGS() .setIntercept(true) @@ -189,8 +177,12 @@ class MultilayerPerceptronClassifierSuite lrModel.predict(data.rdd.map(p => OldVectors.fromML(p.features))).zip(data.rdd.map(_.label)) // MLP's predictions should not differ a lot from LR's. val lrMetrics = new MulticlassMetrics(lrPredictionAndLabels) - val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels) - assert(mlpMetrics.confusionMatrix.asML ~== lrMetrics.confusionMatrix.asML absTol 100) + testTransformerByGlobalCheckFunc[(Double, Vector)](dataFrame, model, "prediction", "label") { + rows: Seq[Row] => + val mlpPredictionAndLabels = rows.map(x => (x.getDouble(0), x.getDouble(1))) + val mlpMetrics = new MulticlassMetrics(sc.makeRDD(mlpPredictionAndLabels)) + assert(mlpMetrics.confusionMatrix.asML ~== lrMetrics.confusionMatrix.asML absTol 100) + } } test("read/write: MultilayerPerceptronClassifier") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 0d3adf993383f..49115c8a4db30 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -28,12 +28,11 @@ import org.apache.spark.ml.classification.NaiveBayesSuite._ import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, Row} -class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class NaiveBayesSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -56,13 +55,13 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa bernoulliDataset = generateNaiveBayesInput(pi, theta, 100, seed, "bernoulli").toDF() } - def validatePrediction(predictionAndLabels: DataFrame): Unit = { - val numOfErrorPredictions = predictionAndLabels.collect().count { + def validatePrediction(predictionAndLabels: Seq[Row]): Unit = { + val numOfErrorPredictions = predictionAndLabels.filter { case Row(prediction: Double, label: Double) => prediction != label - } + }.length // At least 80% of the predictions should be on. - assert(numOfErrorPredictions < predictionAndLabels.count() / 5) + assert(numOfErrorPredictions < predictionAndLabels.length / 5) } def validateModelFit( @@ -92,10 +91,10 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } def validateProbabilities( - featureAndProbabilities: DataFrame, + featureAndProbabilities: Seq[Row], model: NaiveBayesModel, modelType: String): Unit = { - featureAndProbabilities.collect().foreach { + featureAndProbabilities.foreach { case Row(features: Vector, probability: Vector) => assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10) val expected = modelType match { @@ -154,15 +153,18 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val validationDataset = generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF() - val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") - validatePrediction(predictionAndLabels) + testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, + "prediction", "label") { predictionAndLabels: Seq[Row] => + validatePrediction(predictionAndLabels) + } - val featureAndProbabilities = model.transform(validationDataset) - .select("features", "probability") - validateProbabilities(featureAndProbabilities, model, "multinomial") + testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, + "features", "probability") { featureAndProbabilities: Seq[Row] => + validateProbabilities(featureAndProbabilities, model, "multinomial") + } ProbabilisticClassifierSuite.testPredictMethods[ - Vector, NaiveBayesModel](model, testDataset) + Vector, NaiveBayesModel](this, model, testDataset) } test("Naive Bayes with weighted samples") { @@ -210,15 +212,18 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val validationDataset = generateNaiveBayesInput(piArray, thetaArray, nPoints, 20, "bernoulli").toDF() - val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") - validatePrediction(predictionAndLabels) + testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, + "prediction", "label") { predictionAndLabels: Seq[Row] => + validatePrediction(predictionAndLabels) + } - val featureAndProbabilities = model.transform(validationDataset) - .select("features", "probability") - validateProbabilities(featureAndProbabilities, model, "bernoulli") + testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, + "features", "probability") { featureAndProbabilities: Seq[Row] => + validateProbabilities(featureAndProbabilities, model, "bernoulli") + } ProbabilisticClassifierSuite.testPredictMethods[ - Vector, NaiveBayesModel](model, testDataset) + Vector, NaiveBayesModel](this, model, testDataset) } test("detect negative values") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 25bad59b9c9cf..11e88367108b4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -17,26 +17,24 @@ package org.apache.spark.ml.classification -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.feature.StringIndexer -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.Metadata -class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class OneVsRestSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -85,10 +83,6 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol) assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(3)) - val ovaResults = transformedDataset.select("prediction", "label").rdd.map { - row => (row.getDouble(0), row.getDouble(1)) - } - val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses) lr.optimizer.setRegParam(0.1).setNumIterations(100) @@ -97,8 +91,13 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau // determine the #confusion matrix in each class. // bound how much error we allow compared to multinomial logistic regression. val expectedMetrics = new MulticlassMetrics(results) - val ovaMetrics = new MulticlassMetrics(ovaResults) - assert(expectedMetrics.confusionMatrix.asML ~== ovaMetrics.confusionMatrix.asML absTol 400) + + testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), ovaModel, + "prediction", "label") { rows => + val ovaResults = rows.map { row => (row.getDouble(0), row.getDouble(1)) } + val ovaMetrics = new MulticlassMetrics(sc.makeRDD(ovaResults)) + assert(expectedMetrics.confusionMatrix.asML ~== ovaMetrics.confusionMatrix.asML absTol 400) + } } test("one-vs-rest: tuning parallelism does not change output") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index d649ceac949c4..1c8c9829f18d1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.MLTest import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.sql.{Dataset, Row} @@ -122,13 +123,15 @@ object ProbabilisticClassifierSuite { def testPredictMethods[ FeaturesType, M <: ProbabilisticClassificationModel[FeaturesType, M]]( - model: M, testData: Dataset[_]): Unit = { + mlTest: MLTest, model: M, testData: Dataset[_]): Unit = { val allColModel = model.copy(ParamMap.empty) .setRawPredictionCol("rawPredictionAll") .setProbabilityCol("probabilityAll") .setPredictionCol("predictionAll") - val allColResult = allColModel.transform(testData) + + val allColResult = allColModel.transform(testData.select(allColModel.getFeaturesCol)) + .select(allColModel.getFeaturesCol, "rawPredictionAll", "probabilityAll", "predictionAll") for (rawPredictionCol <- Seq("", "rawPredictionSingle")) { for (probabilityCol <- Seq("", "probabilitySingle")) { @@ -138,22 +141,14 @@ object ProbabilisticClassifierSuite { .setProbabilityCol(probabilityCol) .setPredictionCol(predictionCol) - val result = newModel.transform(allColResult) - - import org.apache.spark.sql.functions._ - - val resultRawPredictionCol = - if (rawPredictionCol.isEmpty) col("rawPredictionAll") else col(rawPredictionCol) - val resultProbabilityCol = - if (probabilityCol.isEmpty) col("probabilityAll") else col(probabilityCol) - val resultPredictionCol = - if (predictionCol.isEmpty) col("predictionAll") else col(predictionCol) + import allColResult.sparkSession.implicits._ - result.select( - resultRawPredictionCol, col("rawPredictionAll"), - resultProbabilityCol, col("probabilityAll"), - resultPredictionCol, col("predictionAll") - ).collect().foreach { + mlTest.testTransformer[(Vector, Vector, Vector, Double)](allColResult, newModel, + if (rawPredictionCol.isEmpty) "rawPredictionAll" else rawPredictionCol, + "rawPredictionAll", + if (probabilityCol.isEmpty) "probabilityAll" else probabilityCol, "probabilityAll", + if (predictionCol.isEmpty) "predictionAll" else predictionCol, "predictionAll" + ) { case Row( rawPredictionSingle: Vector, rawPredictionAll: Vector, probabilitySingle: Vector, probabilityAll: Vector, diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 2cca2e6c04698..02a9d5c2a18c0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -23,11 +23,10 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} @@ -35,8 +34,7 @@ import org.apache.spark.sql.{DataFrame, Row} /** * Test suite for [[RandomForestClassifier]]. */ -class RandomForestClassifierSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { import RandomForestClassifierSuite.compareAPIs import testImplicits._ @@ -143,11 +141,8 @@ class RandomForestClassifierSuite MLTestingUtils.checkCopyAndUids(rf, model) - val predictions = model.transform(df) - .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol) - .collect() - - predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) => + testTransformer[(Vector, Double)](df, model, "prediction", "rawPrediction", + "probability") { case Row(pred: Double, rawPred: Vector, probPred: Vector) => assert(pred === rawPred.argmax, s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") val sum = rawPred.toArray.sum @@ -155,8 +150,9 @@ class RandomForestClassifierSuite "probability prediction mismatch") assert(probPred.toArray.sum ~== 1.0 relTol 1E-5) } + ProbabilisticClassifierSuite.testPredictMethods[ - Vector, RandomForestClassificationModel](model, df) + Vector, RandomForestClassificationModel](this, model, df) } test("Fitting without numClasses in metadata") { From ba622f45caa808a9320c1f7ba4a4f344365dcf90 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 5 Mar 2018 20:43:03 +0100 Subject: [PATCH 0430/2461] [SPARK-23585][SQL] Add interpreted execution to UnwrapOption ## What changes were proposed in this pull request? The PR adds interpreted execution to UnwrapOption. ## How was this patch tested? added UT Author: Marco Gaido Closes #20736 from mgaido91/SPARK-23586. --- .../sql/catalyst/expressions/objects/objects.scala | 10 ++++++++-- .../catalyst/expressions/ObjectExpressionsSuite.scala | 11 ++++++++++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 80618af1e859f..03cc8eaceb4e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -382,8 +382,14 @@ case class UnwrapOption( override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override def eval(input: InternalRow): Any = { + val inputObject = child.eval(input) + if (inputObject == null) { + null + } else { + inputObject.asInstanceOf[Option[_]].orNull + } + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = CodeGenerator.javaType(dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 3edcc02f15264..d95db5867b19c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, UnwrapOption} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types.{IntegerType, ObjectType} @@ -66,4 +66,13 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvalutionWithUnsafeProjection( mapEncoder.serializer.head, mapExpected, mapInputRow) } + + test("SPARK-23585: UnwrapOption should support interpreted execution") { + val cls = classOf[Option[Int]] + val inputObject = BoundReference(0, ObjectType(cls), nullable = true) + val unwrapObject = UnwrapOption(IntegerType, inputObject) + Seq((Some(1), 1), (None, null), (null, null)).foreach { case (input, expected) => + checkEvaluation(unwrapObject, expected, InternalRow.fromSeq(Seq(input))) + } + } } From b0f422c3861a5a3831e481b8ffac08f6fa085d00 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 5 Mar 2018 13:23:01 -0800 Subject: [PATCH 0431/2461] [SPARK-23559][SS] Add epoch ID to DataWriterFactory. ## What changes were proposed in this pull request? Add an epoch ID argument to DataWriterFactory for use in streaming. As a side effect of passing in this value, DataWriter will now have a consistent lifecycle; commit() or abort() ends the lifecycle of a DataWriter instance in any execution mode. I considered making a separate streaming interface and adding the epoch ID only to that one, but I think it requires a lot of extra work for no real gain. I think it makes sense to define epoch 0 as the one and only epoch of a non-streaming query. ## How was this patch tested? existing unit tests Author: Jose Torres Closes #20710 from jose-torres/api2. --- .../sql/kafka010/KafkaStreamWriter.scala | 5 +++- .../sql/sources/v2/writer/DataWriter.java | 12 ++++++--- .../sources/v2/writer/DataWriterFactory.java | 5 +++- .../v2/writer/streaming/StreamWriter.java | 19 +++++++------- .../datasources/v2/WriteToDataSourceV2.scala | 25 +++++++++++++------ .../streaming/MicroBatchExecution.scala | 7 ++++++ .../sources/PackedRowWriterFactory.scala | 5 +++- .../streaming/sources/memoryV2.scala | 5 +++- .../sources/v2/SimpleWritableDataSource.scala | 10 ++++++-- 9 files changed, 65 insertions(+), 28 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala index 9307bfc001c03..ae5b5c52d514e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala @@ -65,7 +65,10 @@ case class KafkaStreamWriterFactory( topic: Option[String], producerParams: Map[String, String], schema: StructType) extends DataWriterFactory[InternalRow] { - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[InternalRow] = { new KafkaStreamDataWriter(topic, producerParams, schema.toAttributes) } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 53941a89ba94e..39bf458298862 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int)} and is + * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int, long)} and is * responsible for writing data for an input RDD partition. * * One Spark task has one exclusive data writer, so there is no thread-safe concern. @@ -31,13 +31,17 @@ * the {@link #write(Object)}, {@link #abort()} is called afterwards and the remaining records will * not be processed. If all records are successfully written, {@link #commit()} is called. * + * Once a data writer returns successfully from {@link #commit()} or {@link #abort()}, its lifecycle + * is over and Spark will not use it again. + * * If this data writer succeeds(all records are successfully written and {@link #commit()} * succeeds), a {@link WriterCommitMessage} will be sent to the driver side and pass to * {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an - * exception will be sent to the driver side, and Spark will retry this writing task for some times, - * each time {@link DataWriterFactory#createDataWriter(int, int)} gets a different `attemptNumber`, - * and finally call {@link DataSourceWriter#abort(WriterCommitMessage[])} if all retry fail. + * exception will be sent to the driver side, and Spark may retry this writing task a few times. + * In each retry, {@link DataWriterFactory#createDataWriter(int, int, long)} will receive a + * different `attemptNumber`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])} + * when the configured number of retries is exhausted. * * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task * takes too long to finish. Different from retried tasks, which are launched one by one after the diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index ea95442511ce5..c2c2ab73257e8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -48,6 +48,9 @@ public interface DataWriterFactory extends Serializable { * same task id but different attempt number, which means there are multiple * tasks with the same task id running at the same time. Implementations can * use this attempt number to distinguish writers of different task attempts. + * @param epochId A monotonically increasing id for streaming queries that are split in to + * discrete periods of execution. For non-streaming queries, + * this ID will always be 0. */ - DataWriter createDataWriter(int partitionId, int attemptNumber); + DataWriter createDataWriter(int partitionId, int attemptNumber, long epochId); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java index 4913341bd505d..a316b2a4c1d82 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java @@ -23,11 +23,10 @@ import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; /** - * A {@link DataSourceWriter} for use with structured streaming. This writer handles commits and - * aborts relative to an epoch ID determined by the execution engine. + * A {@link DataSourceWriter} for use with structured streaming. * - * {@link DataWriter} implementations generated by a StreamWriter may be reused for multiple epochs, - * and so must reset any internal state after a successful commit. + * Streaming queries are divided into intervals of data called epochs, with a monotonically + * increasing numeric ID. This writer handles commits and aborts for each successive epoch. */ @InterfaceStability.Evolving public interface StreamWriter extends DataSourceWriter { @@ -39,21 +38,21 @@ public interface StreamWriter extends DataSourceWriter { * If this method fails (by throwing an exception), this writing job is considered to have been * failed, and the execution engine will attempt to call {@link #abort(WriterCommitMessage[])}. * - * To support exactly-once processing, writer implementations should ensure that this method is - * idempotent. The execution engine may call commit() multiple times for the same epoch - * in some circumstances. + * The execution engine may call commit() multiple times for the same epoch in some circumstances. + * To support exactly-once data semantics, implementations must ensure that multiple commits for + * the same epoch are idempotent. */ void commit(long epochId, WriterCommitMessage[] messages); /** - * Aborts this writing job because some data writers are failed and keep failing when retry, or + * Aborts this writing job because some data writers are failed and keep failing when retried, or * the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} fails. * * If this method fails (by throwing an exception), the underlying data source may require manual * cleanup. * - * Unless the abort is triggered by the failure of commit, the given messages should have some - * null slots as there maybe only a few data writers that are committed before the abort + * Unless the abort is triggered by the failure of commit, the given messages will have some + * null slots, as there may be only a few data writers that were committed before the abort * happens, or some data writers were committed but their commit messages haven't reached the * driver when the abort is triggered. So this is just a "best effort" for data sources to * clean up the data left by data writers. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 41cdfc80d8a19..e80b44c1cdc66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.{MicroBatchExecution, StreamExecution} import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter @@ -132,7 +132,8 @@ object DataWritingSparkTask extends Logging { val stageId = context.stageId() val partId = context.partitionId() val attemptId = context.attemptNumber() - val dataWriter = writeTask.createDataWriter(partId, attemptId) + val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0") + val dataWriter = writeTask.createDataWriter(partId, attemptId, epochId.toLong) // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { @@ -172,7 +173,6 @@ object DataWritingSparkTask extends Logging { writeTask: DataWriterFactory[InternalRow], context: TaskContext, iter: Iterator[InternalRow]): WriterCommitMessage = { - val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber()) val epochCoordinator = EpochCoordinatorRef.get( context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) @@ -180,10 +180,15 @@ object DataWritingSparkTask extends Logging { var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong do { + var dataWriter: DataWriter[InternalRow] = null // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { try { - iter.foreach(dataWriter.write) + dataWriter = writeTask.createDataWriter( + context.partitionId(), context.attemptNumber(), currentEpoch) + while (iter.hasNext) { + dataWriter.write(iter.next()) + } logInfo(s"Writer for partition ${context.partitionId()} is committing.") val msg = dataWriter.commit() logInfo(s"Writer for partition ${context.partitionId()} committed.") @@ -196,9 +201,10 @@ object DataWritingSparkTask extends Logging { // Continuous shutdown always involves an interrupt. Just finish the task. } })(catchBlock = { - // If there is an error, abort this writer + // If there is an error, abort this writer. We enter this callback in the middle of + // rethrowing an exception, so runContinuous will stop executing at this point. logError(s"Writer for partition ${context.partitionId()} is aborting.") - dataWriter.abort() + if (dataWriter != null) dataWriter.abort() logError(s"Writer for partition ${context.partitionId()} aborted.") }) } while (!context.isInterrupted()) @@ -211,9 +217,12 @@ class InternalRowDataWriterFactory( rowWriterFactory: DataWriterFactory[Row], schema: StructType) extends DataWriterFactory[InternalRow] { - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[InternalRow] = { new InternalRowDataWriter( - rowWriterFactory.createDataWriter(partitionId, attemptNumber), + rowWriterFactory.createDataWriter(partitionId, attemptNumber, epochId), RowEncoder.apply(schema).resolveAndBind()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 6bd03972c301d..ff4be9c7ab874 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -469,6 +469,9 @@ class MicroBatchExecution( case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } + sparkSessionToRunBatch.sparkContext.setLocalProperty( + MicroBatchExecution.BATCH_ID_KEY, currentBatchId.toString) + reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( sparkSessionToRunBatch, @@ -518,3 +521,7 @@ class MicroBatchExecution( Optional.ofNullable(scalaOption.orNull) } } + +object MicroBatchExecution { + val BATCH_ID_KEY = "streaming.sql.batchId" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala index 248295e401a0d..e07355aa37dba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -31,7 +31,10 @@ import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, Dat * for production-quality sinks. It's intended for use in tests. */ case object PackedRowWriterFactory extends DataWriterFactory[Row] { - def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[Row] = { new PackedRowDataWriter() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index f960208155e3b..5f58246083bb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -147,7 +147,10 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) } case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] { - def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[Row] = { new MemoryDataWriter(partitionId, outputMode) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 36dd2a350a055..a5007fa321359 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -207,7 +207,10 @@ private[v2] object SimpleCounter { class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) extends DataWriterFactory[Row] { - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[Row] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") val fs = filePath.getFileSystem(conf.value) @@ -240,7 +243,10 @@ class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[Row] { class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) extends DataWriterFactory[InternalRow] { - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[InternalRow] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") val fs = filePath.getFileSystem(conf.value) From f2cab56ca22ed5db5ff604cd78cdb55aaa58f651 Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Mon, 5 Mar 2018 14:57:32 -0800 Subject: [PATCH 0432/2461] [SPARK-23040][CORE] Returns interruptible iterator for shuffle reader ## What changes were proposed in this pull request? Before this commit, a non-interruptible iterator is returned if aggregator or ordering is specified. This commit also ensures that sorter is closed even when task is cancelled(killed) in the middle of sorting. ## How was this patch tested? Add a unit test in JobCancellationSuite Author: Xianjin YE Closes #20449 from advancedxy/SPARK-23040. --- .../shuffle/BlockStoreShuffleReader.scala | 9 ++- .../apache/spark/JobCancellationSuite.scala | 65 ++++++++++++++++++- 2 files changed, 72 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index edd69715c9602..85e7e56a04a7d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -94,7 +94,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( } // Sort the output if there is a sort ordering defined. - dep.keyOrdering match { + val resultIter = dep.keyOrdering match { case Some(keyOrd: Ordering[K]) => // Create an ExternalSorter to sort the data. val sorter = @@ -103,9 +103,16 @@ private[spark] class BlockStoreShuffleReader[K, C]( context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) + // Use completion callback to stop sorter if task was finished/cancelled. + context.addTaskCompletionListener(_ => { + sorter.stop() + }) CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) case None => aggregatedIter } + // Use another interruptible iterator here to support task cancellation as aggregator or(and) + // sorter may have consumed previous interruptible iterator. + new InterruptibleIterator[Product2[K, C]](context, resultIter) } } diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 8a77aea75a992..3b793bb231cf3 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.util.concurrent.Semaphore +import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.Future @@ -26,7 +27,7 @@ import scala.concurrent.duration._ import org.scalatest.BeforeAndAfter import org.scalatest.Matchers -import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} +import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart} import org.apache.spark.util.ThreadUtils /** @@ -40,6 +41,10 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft override def afterEach() { try { resetSparkContext() + JobCancellationSuite.taskStartedSemaphore.drainPermits() + JobCancellationSuite.taskCancelledSemaphore.drainPermits() + JobCancellationSuite.twoJobsSharingStageSemaphore.drainPermits() + JobCancellationSuite.executionOfInterruptibleCounter.set(0) } finally { super.afterEach() } @@ -320,6 +325,62 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft f2.get() } + test("interruptible iterator of shuffle reader") { + // In this test case, we create a Spark job of two stages. The second stage is cancelled during + // execution and a counter is used to make sure that the corresponding tasks are indeed + // cancelled. + import JobCancellationSuite._ + sc = new SparkContext("local[2]", "test interruptible iterator") + + val taskCompletedSem = new Semaphore(0) + + sc.addSparkListener(new SparkListener { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + // release taskCancelledSemaphore when cancelTasks event has been posted + if (stageCompleted.stageInfo.stageId == 1) { + taskCancelledSemaphore.release(1000) + } + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + if (taskEnd.stageId == 1) { // make sure tasks are completed + taskCompletedSem.release() + } + } + }) + + val f = sc.parallelize(1 to 1000).map { i => (i, i) } + .repartitionAndSortWithinPartitions(new HashPartitioner(1)) + .mapPartitions { iter => + taskStartedSemaphore.release() + iter + }.foreachAsync { x => + if (x._1 >= 10) { + // This block of code is partially executed. It will be blocked when x._1 >= 10 and the + // next iteration will be cancelled if the source iterator is interruptible. Then in this + // case, the maximum num of increment would be 10(|1...10|) + taskCancelledSemaphore.acquire() + } + executionOfInterruptibleCounter.getAndIncrement() + } + + taskStartedSemaphore.acquire() + // Job is cancelled when: + // 1. task in reduce stage has been started, guaranteed by previous line. + // 2. task in reduce stage is blocked after processing at most 10 records as + // taskCancelledSemaphore is not released until cancelTasks event is posted + // After job being cancelled, task in reduce stage will be cancelled and no more iteration are + // executed. + f.cancel() + + val e = intercept[SparkException](f.get()).getCause + assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) + + // Make sure tasks are indeed completed. + taskCompletedSem.acquire() + assert(executionOfInterruptibleCounter.get() <= 10) + } + def testCount() { // Cancel before launching any tasks { @@ -381,7 +442,9 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft object JobCancellationSuite { + // To avoid any headaches, reset these global variables in the companion class's afterEach block val taskStartedSemaphore = new Semaphore(0) val taskCancelledSemaphore = new Semaphore(0) val twoJobsSharingStageSemaphore = new Semaphore(0) + val executionOfInterruptibleCounter = new AtomicInteger(0) } From 508573958dc9b6402e684cd6dd37202deaaa97f6 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 5 Mar 2018 15:03:27 -0800 Subject: [PATCH 0433/2461] [SPARK-23538][CORE] Remove custom configuration for SSL client. These options were used to configure the built-in JRE SSL libraries when downloading files from HTTPS servers. But because they were also used to set up the now (long) removed internal HTTPS file server, their default configuration chose convenience over security by having overly lenient settings. This change removes the configuration options that affect the JRE SSL libraries. The JRE trust store can still be configured via system properties (or globally in the JRE security config). The only lost functionality is not being able to disable the default hostname verifier when using spark-submit, which should be fine since Spark itself is not using https for any internal functionality anymore. I also removed the HTTP-related code from the REPL class loader, since we haven't had a HTTP server for REPL-generated classes for a while. Author: Marcelo Vanzin Closes #20723 from vanzin/SPARK-23538. --- .../org/apache/spark/SecurityManager.scala | 45 ------------ .../scala/org/apache/spark/util/Utils.scala | 15 ---- .../org/apache/spark/SSLSampleConfigs.scala | 68 ------------------- .../apache/spark/SecurityManagerSuite.scala | 45 ------------ docs/security.md | 4 -- .../spark/repl/ExecutorClassLoader.scala | 53 ++------------- 6 files changed, 7 insertions(+), 223 deletions(-) delete mode 100644 core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 2519d266879aa..da1c89cd78901 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -256,51 +256,6 @@ private[spark] class SecurityManager( // the default SSL configuration - it will be used by all communication layers unless overwritten private val defaultSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl", defaults = None) - // SSL configuration for the file server. This is used by Utils.setupSecureURLConnection(). - val fileServerSSLOptions = getSSLOptions("fs") - val (sslSocketFactory, hostnameVerifier) = if (fileServerSSLOptions.enabled) { - val trustStoreManagers = - for (trustStore <- fileServerSSLOptions.trustStore) yield { - val input = Files.asByteSource(fileServerSSLOptions.trustStore.get).openStream() - - try { - val ks = KeyStore.getInstance(KeyStore.getDefaultType) - ks.load(input, fileServerSSLOptions.trustStorePassword.get.toCharArray) - - val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm) - tmf.init(ks) - tmf.getTrustManagers - } finally { - input.close() - } - } - - lazy val credulousTrustStoreManagers = Array({ - logWarning("Using 'accept-all' trust manager for SSL connections.") - new X509TrustManager { - override def getAcceptedIssuers: Array[X509Certificate] = null - - override def checkClientTrusted(x509Certificates: Array[X509Certificate], s: String) {} - - override def checkServerTrusted(x509Certificates: Array[X509Certificate], s: String) {} - }: TrustManager - }) - - require(fileServerSSLOptions.protocol.isDefined, - "spark.ssl.protocol is required when enabling SSL connections.") - - val sslContext = SSLContext.getInstance(fileServerSSLOptions.protocol.get) - sslContext.init(null, trustStoreManagers.getOrElse(credulousTrustStoreManagers), null) - - val hostVerifier = new HostnameVerifier { - override def verify(s: String, sslSession: SSLSession): Boolean = true - } - - (Some(sslContext.getSocketFactory), Some(hostVerifier)) - } else { - (None, None) - } - def getSSLOptions(module: String): SSLOptions = { val opts = SSLOptions.parse(sparkConf, s"spark.ssl.$module", Some(defaultSSLOptions)) logDebug(s"Created SSL options for $module: $opts") diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index d493663f0b168..2e2a4a259e9af 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -673,7 +673,6 @@ private[spark] object Utils extends Logging { logDebug("fetchFile not using security") uc = new URL(url).openConnection() } - Utils.setupSecureURLConnection(uc, securityMgr) val timeoutMs = conf.getTimeAsSeconds("spark.files.fetchTimeout", "60s").toInt * 1000 @@ -2363,20 +2362,6 @@ private[spark] object Utils extends Logging { PropertyConfigurator.configure(pro) } - /** - * If the given URL connection is HttpsURLConnection, it sets the SSL socket factory and - * the host verifier from the given security manager. - */ - def setupSecureURLConnection(urlConnection: URLConnection, sm: SecurityManager): URLConnection = { - urlConnection match { - case https: HttpsURLConnection => - sm.sslSocketFactory.foreach(https.setSSLSocketFactory) - sm.hostnameVerifier.foreach(https.setHostnameVerifier) - https - case connection => connection - } - } - def invoke( clazz: Class[_], obj: AnyRef, diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala deleted file mode 100644 index 33270bec6247c..0000000000000 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import java.io.File - -object SSLSampleConfigs { - val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath - val untrustedKeyStorePath = new File( - this.getClass.getResource("/untrusted-keystore").toURI).getAbsolutePath - val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath - - val enabledAlgorithms = - // A reasonable set of TLSv1.2 Oracle security provider suites - "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " + - "TLS_RSA_WITH_AES_256_CBC_SHA256, " + - "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, " + - "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " + - "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, " + - // and their equivalent names in the IBM Security provider - "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " + - "SSL_RSA_WITH_AES_256_CBC_SHA256, " + - "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256, " + - "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " + - "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256" - - def sparkSSLConfig(): SparkConf = { - val conf = new SparkConf(loadDefaults = false) - conf.set("spark.ssl.enabled", "true") - conf.set("spark.ssl.keyStore", keyStorePath) - conf.set("spark.ssl.keyStorePassword", "password") - conf.set("spark.ssl.keyPassword", "password") - conf.set("spark.ssl.trustStore", trustStorePath) - conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms) - conf.set("spark.ssl.protocol", "TLSv1.2") - conf - } - - def sparkSSLConfigUntrusted(): SparkConf = { - val conf = new SparkConf(loadDefaults = false) - conf.set("spark.ssl.enabled", "true") - conf.set("spark.ssl.keyStore", untrustedKeyStorePath) - conf.set("spark.ssl.keyStorePassword", "password") - conf.set("spark.ssl.keyPassword", "password") - conf.set("spark.ssl.trustStore", trustStorePath) - conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms) - conf.set("spark.ssl.protocol", "TLSv1.2") - conf - } - -} diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index 106ece7aed0a4..e357299770a2e 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -370,51 +370,6 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { assert(securityManager.checkModifyPermissions("user1") === false) } - test("ssl on setup") { - val conf = SSLSampleConfigs.sparkSSLConfig() - val expectedAlgorithms = Set( - "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384", - "TLS_RSA_WITH_AES_256_CBC_SHA256", - "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256", - "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", - "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256", - "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384", - "SSL_RSA_WITH_AES_256_CBC_SHA256", - "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256", - "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256", - "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256") - - val securityManager = new SecurityManager(conf) - - assert(securityManager.fileServerSSLOptions.enabled === true) - - assert(securityManager.sslSocketFactory.isDefined === true) - assert(securityManager.hostnameVerifier.isDefined === true) - - assert(securityManager.fileServerSSLOptions.trustStore.isDefined === true) - assert(securityManager.fileServerSSLOptions.trustStore.get.getName === "truststore") - assert(securityManager.fileServerSSLOptions.keyStore.isDefined === true) - assert(securityManager.fileServerSSLOptions.keyStore.get.getName === "keystore") - assert(securityManager.fileServerSSLOptions.trustStorePassword === Some("password")) - assert(securityManager.fileServerSSLOptions.keyStorePassword === Some("password")) - assert(securityManager.fileServerSSLOptions.keyPassword === Some("password")) - assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1.2")) - assert(securityManager.fileServerSSLOptions.enabledAlgorithms === expectedAlgorithms) - } - - test("ssl off setup") { - val file = File.createTempFile("SSLOptionsSuite", "conf", Utils.createTempDir()) - - System.setProperty("spark.ssl.configFile", file.getAbsolutePath) - val conf = new SparkConf() - - val securityManager = new SecurityManager(conf) - - assert(securityManager.fileServerSSLOptions.enabled === false) - assert(securityManager.sslSocketFactory.isDefined === false) - assert(securityManager.hostnameVerifier.isDefined === false) - } - test("missing secret authentication key") { val conf = new SparkConf().set("spark.authenticate", "true") val mgr = new SecurityManager(conf) diff --git a/docs/security.md b/docs/security.md index 0f384b411812a..913d9df50eb1c 100644 --- a/docs/security.md +++ b/docs/security.md @@ -44,10 +44,6 @@ component-specific configuration namespaces used to override the default setting Config Namespace Component - - spark.ssl.fs - File download client (used to download jars and files from HTTPS-enabled servers). - spark.ssl.ui Spark application Web UI diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 127f67329f266..4dc399827ffed 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -17,12 +17,10 @@ package org.apache.spark.repl -import java.io.{ByteArrayOutputStream, FileNotFoundException, FilterInputStream, InputStream, IOException} -import java.net.{HttpURLConnection, URI, URL, URLEncoder} +import java.io.{ByteArrayOutputStream, FileNotFoundException, FilterInputStream, InputStream} +import java.net.{URI, URL, URLEncoder} import java.nio.channels.Channels -import scala.util.control.NonFatal - import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.xbean.asm5._ import org.apache.xbean.asm5.Opcodes._ @@ -30,13 +28,13 @@ import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.util.{ParentClassLoader, Utils} +import org.apache.spark.util.ParentClassLoader /** - * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI, used to load classes - * defined by the interpreter when the REPL is used. Allows the user to specify if user class path - * should be first. This class loader delegates getting/finding resources to parent loader, which - * makes sense until REPL never provide resource dynamically. + * A ClassLoader that reads classes from a Hadoop FileSystem or Spark RPC endpoint, used to load + * classes defined by the interpreter when the REPL is used. Allows the user to specify if user + * class path should be first. This class loader delegates getting/finding resources to parent + * loader, which makes sense until REPL never provide resource dynamically. * * Note: [[ClassLoader]] will preferentially load class from parent. Only when parent is null or * the load failed, that it will call the overridden `findClass` function. To avoid the potential @@ -60,7 +58,6 @@ class ExecutorClassLoader( private val fetchFn: (String) => InputStream = uri.getScheme() match { case "spark" => getClassFileInputStreamFromSparkRPC - case "http" | "https" | "ftp" => getClassFileInputStreamFromHttpServer case _ => val fileSystem = FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf)) getClassFileInputStreamFromFileSystem(fileSystem) @@ -113,42 +110,6 @@ class ExecutorClassLoader( } } - private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = { - val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { - val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) - val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager) - newuri.toURL - } else { - new URL(classUri + "/" + urlEncode(pathInDirectory)) - } - val connection: HttpURLConnection = Utils.setupSecureURLConnection(url.openConnection(), - SparkEnv.get.securityManager).asInstanceOf[HttpURLConnection] - // Set the connection timeouts (for testing purposes) - if (httpUrlConnectionTimeoutMillis != -1) { - connection.setConnectTimeout(httpUrlConnectionTimeoutMillis) - connection.setReadTimeout(httpUrlConnectionTimeoutMillis) - } - connection.connect() - try { - if (connection.getResponseCode != 200) { - // Close the error stream so that the connection is eligible for re-use - try { - connection.getErrorStream.close() - } catch { - case ioe: IOException => - logError("Exception while closing error stream", ioe) - } - throw new ClassNotFoundException(s"Class file not found at URL $url") - } else { - connection.getInputStream - } - } catch { - case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] => - connection.disconnect() - throw e - } - } - private def getClassFileInputStreamFromFileSystem(fileSystem: FileSystem)( pathInDirectory: String): InputStream = { val path = new Path(directory, pathInDirectory) From 7706eea6a8bdcd73e9dde5212368f8825e2f1801 Mon Sep 17 00:00:00 2001 From: Yogesh Garg Date: Mon, 5 Mar 2018 15:53:10 -0800 Subject: [PATCH 0434/2461] [SPARK-18630][PYTHON][ML] Move del method from JavaParams to JavaWrapper; add tests The `__del__` method that explicitly detaches the object was moved from `JavaParams` to `JavaWrapper` class, this way model summaries could also be garbage collected in Java. A test case was added to make sure that relevant error messages are thrown after the objects are deleted. I ran pyspark tests agains `pyspark-ml` module `./python/run-tests --python-executables=$(which python) --modules=pyspark-ml` Author: Yogesh Garg Closes #20724 from yogeshg/java_wrapper_memory. --- python/pyspark/ml/tests.py | 39 ++++++++++++++++++++++++++++++++++++ python/pyspark/ml/wrapper.py | 8 ++++---- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 116885969345c..6dee6938d8916 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -173,6 +173,45 @@ class MockModel(MockTransformer, Model, HasFake): pass +class JavaWrapperMemoryTests(SparkSessionTestCase): + + def test_java_object_gets_detached(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", weightCol="weight", + fitIntercept=False) + + model = lr.fit(df) + summary = model.summary + + self.assertIsInstance(model, JavaWrapper) + self.assertIsInstance(summary, JavaWrapper) + self.assertIsInstance(model, JavaParams) + self.assertNotIsInstance(summary, JavaParams) + + error_no_object = 'Target Object ID does not exist for this gateway' + + self.assertIn("LinearRegression_", model._java_obj.toString()) + self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) + + model.__del__() + + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + model._java_obj.toString() + self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) + + try: + summary.__del__() + except: + pass + + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + model._java_obj.toString() + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + summary._java_obj.toString() + + class ParamTypeConversionTests(PySparkTestCase): """ Test that param type conversion happens. diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 0f846fbc5b5ef..5061f6434794a 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -36,6 +36,10 @@ def __init__(self, java_obj=None): super(JavaWrapper, self).__init__() self._java_obj = java_obj + def __del__(self): + if SparkContext._active_spark_context and self._java_obj is not None: + SparkContext._active_spark_context._gateway.detach(self._java_obj) + @classmethod def _create_from_java_class(cls, java_class, *args): """ @@ -100,10 +104,6 @@ class JavaParams(JavaWrapper, Params): __metaclass__ = ABCMeta - def __del__(self): - if SparkContext._active_spark_context: - SparkContext._active_spark_context._gateway.detach(self._java_obj) - def _make_java_param_pair(self, param, value): """ Makes a Java param pair. From f6b49f9d1b6f218408197f7272c1999fe3d94328 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 6 Mar 2018 01:37:51 +0100 Subject: [PATCH 0435/2461] [SPARK-23586][SQL] Add interpreted execution to WrapOption ## What changes were proposed in this pull request? The PR adds interpreted execution to WrapOption. ## How was this patch tested? added UT Author: Marco Gaido Closes #20741 from mgaido91/SPARK-23586_2. --- .../sql/catalyst/expressions/objects/objects.scala | 3 +-- .../catalyst/expressions/ObjectExpressionsSuite.scala | 11 ++++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 03cc8eaceb4e6..d832fe0a6857c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -422,8 +422,7 @@ case class WrapOption(child: Expression, optType: DataType) override def inputTypes: Seq[AbstractDataType] = optType :: Nil - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override def eval(input: InternalRow): Any = Option(child.eval(input)) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputObject = child.genCode(ctx) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index d95db5867b19c..d535578a7eb06 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, UnwrapOption} +import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types.{IntegerType, ObjectType} @@ -75,4 +75,13 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(unwrapObject, expected, InternalRow.fromSeq(Seq(input))) } } + + test("SPARK-23586: WrapOption should support interpreted execution") { + val cls = ObjectType(classOf[java.lang.Integer]) + val inputObject = BoundReference(0, cls, nullable = true) + val wrapObject = WrapOption(inputObject, cls) + Seq((1, Some(1)), (null, None)).foreach { case (input, expected) => + checkEvaluation(wrapObject, expected, InternalRow.fromSeq(Seq(input))) + } + } } From 8c5b34c425bda2079a1ff969b12c067f2bb3f18f Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Mon, 5 Mar 2018 16:49:24 -0800 Subject: [PATCH 0436/2461] =?UTF-8?q?[SPARK-23604][SQL]=20Change=20Statist?= =?UTF-8?q?ics.isEmpty=20to=20!Statistics.hasNonNul=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …lValue ## What changes were proposed in this pull request? Parquet 1.9 will change the semantics of Statistics.isEmpty slightly to reflect if the null value count has been set. That breaks a timestamp interoperability test that cares only about whether there are column values present in the statistics of a written file for an INT96 column. Fix by using Statistics.hasNonNullValue instead. ## How was this patch tested? Unit tests continue to pass against Parquet 1.8, and also pass against a Parquet build including PARQUET-1217. Author: Henry Robinson Closes #20740 from henryr/spark-23604. --- .../datasources/parquet/ParquetInteroperabilitySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala index fbd83a0fa425a..9c75965639d8a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -184,7 +184,7 @@ class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedS // when the data is read back as mentioned above, b/c int96 is unsigned. This // assert makes sure this holds even if we change parquet versions (if eg. there // were ever statistics even on unsigned columns). - assert(columnStats.isEmpty) + assert(!columnStats.hasNonNullValue) } // These queries should return the entire dataset with the conversion applied, From ad640a5affceaaf3979e25848628fb1dfcdf932a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 5 Mar 2018 20:35:14 -0800 Subject: [PATCH 0437/2461] [SPARK-23303][SQL] improve the explain result for data source v2 relations ## What changes were proposed in this pull request? The proposed explain format: **[streaming header] [RelationV2/ScanV2] [data source name] [output] [pushed filters] [options]** **streaming header**: if it's a streaming relation, put a "Streaming" at the beginning. **RelationV2/ScanV2**: if it's a logical plan, put a "RelationV2", else, put a "ScanV2" **data source name**: the simple class name of the data source implementation **output**: a string of the plan output attributes **pushed filters**: a string of all the filters that have been pushed to this data source **options**: all the options to create the data source reader. The current explain result for data source v2 relation is unreadable: ``` == Parsed Logical Plan == 'Filter ('i > 6) +- AnalysisBarrier +- Project [j#1] +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Analyzed Logical Plan == j: int Project [j#1] +- Filter (i#0 > 6) +- Project [j#1, i#0] +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Optimized Logical Plan == Project [j#1] +- Filter isnotnull(i#0) +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Physical Plan == *(1) Project [j#1] +- *(1) Filter isnotnull(i#0) +- *(1) DataSourceV2Scan [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 ``` after this PR ``` == Parsed Logical Plan == 'Project [unresolvedalias('j, None)] +- AnalysisBarrier +- RelationV2 AdvancedDataSourceV2[i#0, j#1] == Analyzed Logical Plan == j: int Project [j#1] +- RelationV2 AdvancedDataSourceV2[i#0, j#1] == Optimized Logical Plan == RelationV2 AdvancedDataSourceV2[j#1] == Physical Plan == *(1) ScanV2 AdvancedDataSourceV2[j#1] ``` ------- ``` == Analyzed Logical Plan == i: int, j: int Filter (i#88 > 3) +- RelationV2 JavaAdvancedDataSourceV2[i#88, j#89] == Optimized Logical Plan == Filter isnotnull(i#88) +- RelationV2 JavaAdvancedDataSourceV2[i#88, j#89] (Pushed Filters: [GreaterThan(i,3)]) == Physical Plan == *(1) Filter isnotnull(i#88) +- *(1) ScanV2 JavaAdvancedDataSourceV2[i#88, j#89] (Pushed Filters: [GreaterThan(i,3)]) ``` an example for streaming query ``` == Parsed Logical Plan == Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject cast(value#25 as string).toString, obj#4: java.lang.String +- Streaming RelationV2 MemoryStreamDataSource[value#25] == Analyzed Logical Plan == value: string, count(1): bigint Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject cast(value#25 as string).toString, obj#4: java.lang.String +- Streaming RelationV2 MemoryStreamDataSource[value#25] == Optimized Logical Plan == Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject value#25.toString, obj#4: java.lang.String +- Streaming RelationV2 MemoryStreamDataSource[value#25] == Physical Plan == *(4) HashAggregate(keys=[value#6], functions=[count(1)], output=[value#6, count(1)#11L]) +- StateStoreSave [value#6], state info [ checkpoint = *********(redacted)/cloud/dev/spark/target/tmp/temporary-549f264b-2531-4fcb-a52f-433c77347c12/state, runId = f84d9da9-2f8c-45c1-9ea1-70791be684de, opId = 0, ver = 0, numPartitions = 5], Complete, 0 +- *(3) HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#16L]) +- StateStoreRestore [value#6], state info [ checkpoint = *********(redacted)/cloud/dev/spark/target/tmp/temporary-549f264b-2531-4fcb-a52f-433c77347c12/state, runId = f84d9da9-2f8c-45c1-9ea1-70791be684de, opId = 0, ver = 0, numPartitions = 5] +- *(2) HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#16L]) +- Exchange hashpartitioning(value#6, 5) +- *(1) HashAggregate(keys=[value#6], functions=[partial_count(1)], output=[value#6, count#16L]) +- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- *(1) MapElements , obj#5: java.lang.String +- *(1) DeserializeToObject value#25.toString, obj#4: java.lang.String +- *(1) ScanV2 MemoryStreamDataSource[value#25] ``` ## How was this patch tested? N/A Author: Wenchen Fan Closes #20647 from cloud-fan/explain. --- .../kafka010/KafkaContinuousSourceSuite.scala | 2 +- .../sql/kafka010/KafkaContinuousTest.scala | 2 +- .../kafka010/KafkaMicroBatchSourceSuite.scala | 2 +- .../v2/DataSourceReaderHolder.scala | 64 ------------- .../datasources/v2/DataSourceV2Relation.scala | 34 +++++-- .../datasources/v2/DataSourceV2ScanExec.scala | 18 +++- .../datasources/v2/DataSourceV2Strategy.scala | 8 +- .../v2/DataSourceV2StringFormat.scala | 94 +++++++++++++++++++ .../streaming/MicroBatchExecution.scala | 29 +++++- .../continuous/ContinuousExecution.scala | 8 +- .../spark/sql/streaming/StreamSuite.scala | 12 ++- .../spark/sql/streaming/StreamTest.scala | 4 +- .../continuous/ContinuousSuite.scala | 11 +-- 13 files changed, 183 insertions(+), 105 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index f679e9bfc0450..aab8ec42189fb 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -60,7 +60,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r + case StreamingDataSourceV2Relation(_, _, _, r: KafkaContinuousReader) => r }.exists { r => // Ensure the new topic is present and the old topic is gone. r.knownPartitions.exists(_.topic == topic2) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index 48ac3fc1e8f9d..fa1468a3943c8 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -47,7 +47,7 @@ trait KafkaContinuousTest extends KafkaSourceTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r + case StreamingDataSourceV2Relation(_, _, _, r: KafkaContinuousReader) => r }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index f2b3ff7615e74..e017fd9b84d21 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -124,7 +124,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { - case StreamingDataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + case StreamingDataSourceV2Relation(_, _, _, reader: KafkaContinuousReader) => reader } }) }.distinct diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala deleted file mode 100644 index 81219e9771bd8..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import java.util.Objects - -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.sources.v2.reader._ - -/** - * A base class for data source reader holder with customized equals/hashCode methods. - */ -trait DataSourceReaderHolder { - - /** - * The output of the data source reader, w.r.t. column pruning. - */ - def output: Seq[Attribute] - - /** - * The held data source reader. - */ - def reader: DataSourceReader - - /** - * The metadata of this data source reader that can be used for equality test. - */ - private def metadata: Seq[Any] = { - val filters: Any = reader match { - case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet - case s: SupportsPushDownFilters => s.pushedFilters().toSet - case _ => Nil - } - Seq(output, reader.getClass, filters) - } - - def canEqual(other: Any): Boolean - - override def equals(other: Any): Boolean = other match { - case other: DataSourceReaderHolder => - canEqual(other) && metadata.length == other.metadata.length && - metadata.zip(other.metadata).forall { case (l, r) => l == r } - case _ => false - } - - override def hashCode(): Int = { - metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index cc6cb631e3f06..2b282ffae2390 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -35,15 +35,12 @@ case class DataSourceV2Relation( options: Map[String, String], projection: Seq[AttributeReference], filters: Option[Seq[Expression]] = None, - userSpecifiedSchema: Option[StructType] = None) extends LeafNode with MultiInstanceRelation { + userSpecifiedSchema: Option[StructType] = None) + extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { import DataSourceV2Relation._ - override def simpleString: String = { - s"DataSourceV2Relation(source=${source.name}, " + - s"schema=[${output.map(a => s"$a ${a.dataType.simpleString}").mkString(", ")}], " + - s"filters=[${pushedFilters.mkString(", ")}], options=$options)" - } + override def simpleString: String = "RelationV2 " + metadataString override lazy val schema: StructType = reader.readSchema() @@ -107,19 +104,36 @@ case class DataSourceV2Relation( } /** - * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical - * to the non-streaming relation. + * A specialization of [[DataSourceV2Relation]] with the streaming bit set to true. + * + * Note that, this plan has a mutable reader, so Spark won't apply operator push-down for this plan, + * to avoid making the plan mutable. We should consolidate this plan and [[DataSourceV2Relation]] + * after we figure out how to apply operator push-down for streaming data sources. */ case class StreamingDataSourceV2Relation( output: Seq[AttributeReference], + source: DataSourceV2, + options: Map[String, String], reader: DataSourceReader) - extends LeafNode with DataSourceReaderHolder with MultiInstanceRelation { + extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { + override def isStreaming: Boolean = true - override def canEqual(other: Any): Boolean = other.isInstanceOf[StreamingDataSourceV2Relation] + override def simpleString: String = "Streaming RelationV2 " + metadataString override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance())) + // TODO: unify the equal/hashCode implementation for all data source v2 query plans. + override def equals(other: Any): Boolean = other match { + case other: StreamingDataSourceV2Relation => + output == other.output && reader.getClass == other.reader.getClass && options == other.options + case _ => false + } + + override def hashCode(): Int = { + Seq(output, source, options).hashCode() + } + override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 7d9581be4db89..cb691ba297076 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types.StructType @@ -36,10 +37,23 @@ import org.apache.spark.sql.types.StructType */ case class DataSourceV2ScanExec( output: Seq[AttributeReference], + @transient source: DataSourceV2, + @transient options: Map[String, String], @transient reader: DataSourceReader) - extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { + extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan { - override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] + override def simpleString: String = "ScanV2 " + metadataString + + // TODO: unify the equal/hashCode implementation for all data source v2 query plans. + override def equals(other: Any): Boolean = other match { + case other: DataSourceV2ScanExec => + output == other.output && reader.getClass == other.reader.getClass && options == other.options + case _ => false + } + + override def hashCode(): Int = { + Seq(output, source, options).hashCode() + } override def outputPartitioning: physical.Partitioning = reader match { case s: SupportsReportPartitioning => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index c4e7644683c36..1ac9572de6412 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,11 +23,11 @@ import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case relation: DataSourceV2Relation => - DataSourceV2ScanExec(relation.output, relation.reader) :: Nil + case r: DataSourceV2Relation => + DataSourceV2ScanExec(r.output, r.source, r.options, r.reader) :: Nil - case relation: StreamingDataSourceV2Relation => - DataSourceV2ScanExec(relation.output, relation.reader) :: Nil + case r: StreamingDataSourceV2Relation => + DataSourceV2ScanExec(r.output, r.source, r.options, r.reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala new file mode 100644 index 0000000000000..aed55a429bfd7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.commons.lang3.StringUtils + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.DataSourceV2 +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.util.Utils + +/** + * A trait that can be used by data source v2 related query plans(both logical and physical), to + * provide a string format of the data source information for explain. + */ +trait DataSourceV2StringFormat { + + /** + * The instance of this data source implementation. Note that we only consider its class in + * equals/hashCode, not the instance itself. + */ + def source: DataSourceV2 + + /** + * The output of the data source reader, w.r.t. column pruning. + */ + def output: Seq[Attribute] + + /** + * The options for this data source reader. + */ + def options: Map[String, String] + + /** + * The created data source reader. Here we use it to get the filters that has been pushed down + * so far, itself doesn't take part in the equals/hashCode. + */ + def reader: DataSourceReader + + private lazy val filters = reader match { + case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet + case s: SupportsPushDownFilters => s.pushedFilters().toSet + case _ => Set.empty + } + + private def sourceName: String = source match { + case registered: DataSourceRegister => registered.shortName() + case _ => source.getClass.getSimpleName.stripSuffix("$") + } + + def metadataString: String = { + val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] + + if (filters.nonEmpty) { + entries += "Filters" -> filters.mkString("[", ", ", "]") + } + + // TODO: we should only display some standard options like path, table, etc. + if (options.nonEmpty) { + entries += "Options" -> Utils.redact(options).map { + case (k, v) => s"$k=$v" + }.mkString("[", ",", "]") + } + + val outputStr = Utils.truncatedString(output, "[", ", ", "]") + + val entriesStr = if (entries.nonEmpty) { + Utils.truncatedString(entries.map { + case (key, value) => key + ": " + StringUtils.abbreviate(value, 100) + }, " (", ", ", ")") + } else { + "" + } + + s"$sourceName$outputStr$entriesStr" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index ff4be9c7ab874..6e231970f4a22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -20,16 +20,16 @@ package org.apache.spark.sql.execution.streaming import java.util.Optional import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} +import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} +import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} @@ -52,6 +52,9 @@ class MicroBatchExecution( @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty + private val readerToDataSourceMap = + MutableMap.empty[MicroBatchReader, (DataSourceV2, Map[String, String])] + private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) case OneTimeTrigger => OneTimeExecutor() @@ -97,6 +100,7 @@ class MicroBatchExecution( metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 + readerToDataSourceMap(reader) = dataSourceV2 -> options logInfo(s"Using MicroBatchReader [$reader] from " + s"DataSourceV2 named '$sourceName' [$dataSourceV2]") StreamingExecutionRelation(reader, output)(sparkSession) @@ -419,8 +423,19 @@ class MicroBatchExecution( toJava(current), Optional.of(availableV2)) logDebug(s"Retrieving data from $reader: $current -> $availableV2") - Some(reader -> - new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader)) + + val (source, options) = reader match { + // `MemoryStream` is special. It's for test only and doesn't have a `DataSourceV2` + // implementation. We provide a fake one here for explain. + case _: MemoryStream[_] => MemoryStreamDataSource -> Map.empty[String, String] + // Provide a fake value here just in case something went wrong, e.g. the reader gives + // a wrong `equals` implementation. + case _ => readerToDataSourceMap.getOrElse(reader, { + FakeDataSourceV2 -> Map.empty[String, String] + }) + } + Some(reader -> StreamingDataSourceV2Relation( + reader.readSchema().toAttributes, source, options, reader)) case _ => None } } @@ -525,3 +540,7 @@ class MicroBatchExecution( object MicroBatchExecution { val BATCH_ID_KEY = "streaming.sql.batchId" } + +object MemoryStreamDataSource extends DataSourceV2 + +object FakeDataSourceV2 extends DataSourceV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index daebd1dd010ac..1758b3844bd62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} @@ -167,7 +167,7 @@ class ContinuousExecution( var insertedSourceId = 0 val withNewSources = logicalPlan transform { - case ContinuousExecutionRelation(_, _, output) => + case ContinuousExecutionRelation(source, options, output) => val reader = continuousSources(insertedSourceId) insertedSourceId += 1 val newOutput = reader.readSchema().toAttributes @@ -180,7 +180,7 @@ class ContinuousExecution( val loggedOffset = offsets.offsets(0) val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) - new StreamingDataSourceV2Relation(newOutput, reader) + StreamingDataSourceV2Relation(newOutput, source, options, reader) } // Rewire the plan to use the new attributes that were returned by the source. @@ -201,7 +201,7 @@ class ContinuousExecution( val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { - case StreamingDataSourceV2Relation(_, r: ContinuousReader) => r + case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r }.head reportTimeTaken("queryPlanning") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index d1a04833390f5..c1ec1eba69fb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -492,16 +492,20 @@ class StreamSuite extends StreamTest { val explainWithoutExtended = q.explainInternal(false) // `extended = false` only displays the physical plan. - assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) - assert("DataSourceV2Scan".r.findAllMatchIn(explainWithoutExtended).size === 1) + assert("Streaming RelationV2 MemoryStreamDataSource".r + .findAllMatchIn(explainWithoutExtended).size === 0) + assert("ScanV2 MemoryStreamDataSource".r + .findAllMatchIn(explainWithoutExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithoutExtended.contains("StateStoreRestore")) val explainWithExtended = q.explainInternal(true) // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical // plan. - assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithExtended).size === 3) - assert("DataSourceV2Scan".r.findAllMatchIn(explainWithExtended).size === 1) + assert("Streaming RelationV2 MemoryStreamDataSource".r + .findAllMatchIn(explainWithExtended).size === 3) + assert("ScanV2 MemoryStreamDataSource".r + .findAllMatchIn(explainWithExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithExtended.contains("StateStoreRestore")) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 08f722ecb10e5..e44aef09f1f3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -629,8 +629,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def findSourceIndex(plan: LogicalPlan): Option[Int] = { plan .collect { - case StreamingExecutionRelation(s, _) => s - case StreamingDataSourceV2Relation(_, r) => r + case r: StreamingExecutionRelation => r.source + case r: StreamingDataSourceV2Relation => r.reader } .zipWithIndex .find(_._1 == source) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 4b4ed82dc6520..f5884b9c8de12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.streaming.continuous -import java.util.UUID - -import org.apache.spark.{SparkContext, SparkEnv, SparkException} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql._ -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.test.TestSparkSession @@ -43,7 +40,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, r: RateStreamContinuousReader) => r + case DataSourceV2ScanExec(_, _, _, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 From e8a259d66dda0d4c76f3af8933676bade8a7451d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 6 Mar 2018 13:55:13 +0100 Subject: [PATCH 0438/2461] [SPARK-23594][SQL] GetExternalRowField should support interpreted execution ## What changes were proposed in this pull request? This pr added interpreted execution for `GetExternalRowField`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Takeshi Yamamuro Closes #20746 from maropu/SPARK-23594. --- .../expressions/objects/objects.scala | 14 ++++++++++--- .../expressions/ObjectExpressionsSuite.scala | 20 +++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index d832fe0a6857c..97e3ff88858d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1358,11 +1358,19 @@ case class GetExternalRowField( override def dataType: DataType = ObjectType(classOf[Object]) - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") - private val errMsg = s"The ${index}th field '$fieldName' of input row cannot be null." + override def eval(input: InternalRow): Any = { + val inputRow = child.eval(input).asInstanceOf[Row] + if (inputRow == null) { + throw new RuntimeException("The input external row cannot be null.") + } + if (inputRow.isNullAt(index)) { + throw new RuntimeException(errMsg) + } + inputRow.get(index) + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the field is null. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index d535578a7eb06..0f376c4b63c15 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.objects._ @@ -84,4 +85,23 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(wrapObject, expected, InternalRow.fromSeq(Seq(input))) } } + + test("SPARK-23594 GetExternalRowField should support interpreted execution") { + val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) + val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0") + Seq((Row(1), 1), (Row(3), 3)).foreach { case (input, expected) => + checkEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input))) + } + + // If an input row or a field are null, a runtime exception will be thrown + val errMsg1 = intercept[RuntimeException] { + evaluate(getRowField, InternalRow.fromSeq(Seq(null))) + }.getMessage + assert(errMsg1 === "The input external row cannot be null.") + + val errMsg2 = intercept[RuntimeException] { + evaluate(getRowField, InternalRow.fromSeq(Seq(Row(null)))) + }.getMessage + assert(errMsg2 === "The 0th field 'c0' of input row cannot be null.") + } } From 8bceb899dc3220998a4ea4021f3b477f78faaca8 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 6 Mar 2018 08:52:28 -0600 Subject: [PATCH 0439/2461] [SPARK-23601][BUILD] Remove .md5 files from release ## What changes were proposed in this pull request? Remove .md5 files from release artifacts ## How was this patch tested? N/A Author: Sean Owen Closes #20737 from srowen/SPARK-23601. --- dev/create-release/release-build.sh | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index a3579f21fc539..c00b00b845401 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -164,8 +164,6 @@ if [[ "$1" == "package" ]]; then tar cvzf spark-$SPARK_VERSION.tgz spark-$SPARK_VERSION echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour --output spark-$SPARK_VERSION.tgz.asc \ --detach-sig spark-$SPARK_VERSION.tgz - echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md MD5 spark-$SPARK_VERSION.tgz > \ - spark-$SPARK_VERSION.tgz.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 spark-$SPARK_VERSION.tgz > spark-$SPARK_VERSION.tgz.sha512 rm -rf spark-$SPARK_VERSION @@ -215,9 +213,6 @@ if [[ "$1" == "package" ]]; then echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ --output $R_DIST_NAME.asc \ --detach-sig $R_DIST_NAME - echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ - MD5 $R_DIST_NAME > \ - $R_DIST_NAME.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 $R_DIST_NAME > \ $R_DIST_NAME.sha512 @@ -234,9 +229,6 @@ if [[ "$1" == "package" ]]; then echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ --output $PYTHON_DIST_NAME.asc \ --detach-sig $PYTHON_DIST_NAME - echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ - MD5 $PYTHON_DIST_NAME > \ - $PYTHON_DIST_NAME.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 $PYTHON_DIST_NAME > \ $PYTHON_DIST_NAME.sha512 @@ -247,9 +239,6 @@ if [[ "$1" == "package" ]]; then echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ --output spark-$SPARK_VERSION-bin-$NAME.tgz.asc \ --detach-sig spark-$SPARK_VERSION-bin-$NAME.tgz - echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ - MD5 spark-$SPARK_VERSION-bin-$NAME.tgz > \ - spark-$SPARK_VERSION-bin-$NAME.tgz.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 spark-$SPARK_VERSION-bin-$NAME.tgz > \ spark-$SPARK_VERSION-bin-$NAME.tgz.sha512 @@ -382,18 +371,11 @@ if [[ "$1" == "publish-release" ]]; then find . -type f |grep -v \.jar |grep -v \.pom | xargs rm echo "Creating hash and signature files" - # this must have .asc, .md5 and .sha1 - it really doesn't like anything else there + # this must have .asc and .sha1 - it really doesn't like anything else there for file in $(find . -type f) do echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --output $file.asc \ --detach-sig --armour $file; - if [ $(command -v md5) ]; then - # Available on OS X; -q to keep only hash - md5 -q $file > $file.md5 - else - # Available on Linux; cut to keep only hash - md5sum $file | cut -f1 -d' ' > $file.md5 - fi sha1sum $file | cut -f1 -d' ' > $file.sha1 done From 4c587eb4887623c839854c1505f495de42898229 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 6 Mar 2018 17:42:17 +0100 Subject: [PATCH 0440/2461] [SPARK-23590][SQL] Add interpreted execution to CreateExternalRow ## What changes were proposed in this pull request? The PR adds interpreted execution to CreateExternalRow ## How was this patch tested? added UT Author: Marco Gaido Closes #20749 from mgaido91/SPARK-23590. --- .../spark/sql/catalyst/expressions/objects/objects.scala | 6 ++++-- .../sql/catalyst/expressions/ExpressionEvalHelper.scala | 4 +++- .../sql/catalyst/expressions/ObjectExpressionsSuite.scala | 8 +++++++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 97e3ff88858d0..721d589709131 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1111,8 +1111,10 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) override def nullable: Boolean = false - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override def eval(input: InternalRow): Any = { + val values = children.map(_.eval(input)).toArray + new GenericRowWithSchema(values, schema) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericRowWithSchema].getName diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index b4c8eab19c5cc..29f0cc0d991aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -24,6 +24,7 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer} import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -60,7 +61,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { /** * Check the equality between result of expression and expected value, it will handle - * Array[Byte], Spread[Double], and MapData. + * Array[Byte], Spread[Double], MapData and Row. */ protected def checkResult(result: Any, expected: Any, dataType: DataType): Boolean = { (result, expected) match { @@ -88,6 +89,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { if (expected.isNaN) result.isNaN else expected == result case (result: Float, expected: Float) => if (expected.isNaN) result.isNaN else expected == result + case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema) case _ => result == expected } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 0f376c4b63c15..50e57737a4612 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} -import org.apache.spark.sql.types.{IntegerType, ObjectType} +import org.apache.spark.sql.types._ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -86,6 +86,12 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-23590: CreateExternalRow should support interpreted execution") { + val schema = new StructType().add("a", IntegerType).add("b", StringType) + val createExternalRow = CreateExternalRow(Seq(Literal(1), Literal("x")), schema) + checkEvaluation(createExternalRow, Row.fromSeq(Seq(1, "x")), InternalRow.fromSeq(Seq())) + } + test("SPARK-23594 GetExternalRowField should support interpreted execution") { val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0") From 04e71c31603af3a13bc13300df799f003fe185f7 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 7 Mar 2018 17:01:29 +0800 Subject: [PATCH 0441/2461] [MINOR][YARN] Add disable yarn.nodemanager.vmem-check-enabled option to memLimitExceededLogMessage My spark application sometimes will throw `Container killed by YARN for exceeding memory limits`. Even I increased `spark.yarn.executor.memoryOverhead` to 10G, this error still happen. The latest config: memory-config And error message: ``` ExecutorLostFailure (executor 121 exited caused by one of the running tasks) Reason: Container killed by YARN for exceeding memory limits. 30.7 GB of 30 GB physical memory used. Consider boosting spark.yarn.executor.memoryOverhead. ``` This is because of [Linux glibc >= 2.10 (RHEL 6) malloc may show excessive virtual memory usage](https://www.ibm.com/developerworks/community/blogs/kevgrig/entry/linux_glibc_2_10_rhel_6_malloc_may_show_excessive_virtual_memory_usage?lang=en). So disable `yarn.nodemanager.vmem-check-enabled` looks like a good option as [MapR mentioned ](https://mapr.com/blog/best-practices-yarn-resource-management). This PR add disable `yarn.nodemanager.vmem-check-enabled` option to memLimitExceededLogMessage. More details: https://issues.apache.org/jira/browse/YARN-4714 https://stackoverflow.com/a/31450291 https://stackoverflow.com/a/42091255 After this PR: yarn N/A Author: Yuming Wang Author: Yuming Wang Closes #20735 from wangyum/YARN-4714. Change-Id: Ie10836e2c07b6384d228c3f9e89f802823bd9f16 --- .../scala/org/apache/spark/deploy/yarn/YarnAllocator.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 506adb363aa90..a537243d641cb 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -736,7 +736,8 @@ private object YarnAllocator { def memLimitExceededLogMessage(diagnostics: String, pattern: Pattern): String = { val matcher = pattern.matcher(diagnostics) val diag = if (matcher.find()) " " + matcher.group() + "." else "" - ("Container killed by YARN for exceeding memory limits." + diag - + " Consider boosting spark.yarn.executor.memoryOverhead.") + s"Container killed by YARN for exceeding memory limits. $diag " + + "Consider boosting spark.yarn.executor.memoryOverhead or " + + "disabling yarn.nodemanager.vmem-check-enabled because of YARN-4714." } } From 33c2cb22b3b246a413717042a5f741da04ded69d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 7 Mar 2018 13:10:51 +0100 Subject: [PATCH 0442/2461] [SPARK-23611][SQL] Add a helper function to check exception for expr evaluation ## What changes were proposed in this pull request? This pr added a helper function in `ExpressionEvalHelper` to check exceptions in all the path of expression evaluation. ## How was this patch tested? Modified the existing tests. Author: Takeshi Yamamuro Closes #20748 from maropu/SPARK-23611. --- .../expressions/ExpressionEvalHelper.scala | 83 ++++++++++++++----- .../expressions/MathExpressionsSuite.scala | 2 +- .../expressions/MiscExpressionsSuite.scala | 2 +- .../expressions/NullExpressionsSuite.scala | 2 +- .../expressions/ObjectExpressionsSuite.scala | 17 ++-- .../expressions/RegexpExpressionsSuite.scala | 8 +- .../expressions/StringExpressionsSuite.scala | 2 +- .../expressions/TimeWindowSuite.scala | 2 +- 8 files changed, 79 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 29f0cc0d991aa..58d0c07622eb9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.reflect.ClassTag + import org.scalacheck.Gen import org.scalactic.TripleEqualsSupport.Spread import org.scalatest.exceptions.TestFailedException @@ -45,11 +47,15 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst)) } - protected def checkEvaluation( - expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + private def prepareEvaluation(expression: Expression): Expression = { val serializer = new JavaSerializer(new SparkConf()).newInstance val resolver = ResolveTimeZone(new SQLConf) - val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) + resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) + } + + protected def checkEvaluation( + expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + val expr = prepareEvaluation(expression) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) @@ -95,7 +101,31 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } - protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = { + protected def checkExceptionInExpression[T <: Throwable : ClassTag]( + expression: => Expression, + inputRow: InternalRow, + expectedErrMsg: String): Unit = { + + def checkException(eval: => Unit, testMode: String): Unit = { + withClue(s"($testMode)") { + val errMsg = intercept[T] { + eval + }.getMessage + if (errMsg != expectedErrMsg) { + fail(s"Expected error message is `$expectedErrMsg`, but `$errMsg` found") + } + } + } + val expr = prepareEvaluation(expression) + checkException(evaluateWithoutCodegen(expr, inputRow), "non-codegen mode") + checkException(evaluateWithGeneratedMutableProjection(expr, inputRow), "codegen mode") + if (GenerateUnsafeProjection.canSupport(expr.dataType)) { + checkException(evaluateWithUnsafeProjection(expr, inputRow), "unsafe mode") + } + } + + protected def evaluateWithoutCodegen( + expression: Expression, inputRow: InternalRow = EmptyRow): Any = { expression.foreach { case n: Nondeterministic => n.initialize(0) case _ => @@ -124,7 +154,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { + val actual = try evaluateWithoutCodegen(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } if (!checkResult(actual, expected, expression.dataType)) { @@ -139,33 +169,29 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + val actual = evaluateWithGeneratedMutableProjection(expression, inputRow) + if (!checkResult(actual, expected, expression.dataType)) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + private def evaluateWithGeneratedMutableProjection( + expression: Expression, + inputRow: InternalRow = EmptyRow): Any = { val plan = generateProject( GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) plan.initialize(0) - val actual = plan(inputRow).get(0, expression.dataType) - if (!checkResult(actual, expected, expression.dataType)) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input") - } + plan(inputRow).get(0, expression.dataType) } protected def checkEvalutionWithUnsafeProjection( expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - // SPARK-16489 Explicitly doing code generation twice so code gen will fail if - // some expression is reusing variable names across different instances. - // This behavior is tested in ExpressionEvalHelperSuite. - val plan = generateProject( - UnsafeProjection.create( - Alias(expression, s"Optimized($expression)1")() :: - Alias(expression, s"Optimized($expression)2")() :: Nil), - expression) - - val unsafeRow = plan(inputRow) + val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow) val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" if (expected == null) { @@ -185,6 +211,21 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } + private def evaluateWithUnsafeProjection( + expression: Expression, + inputRow: InternalRow = EmptyRow): InternalRow = { + // SPARK-16489 Explicitly doing code generation twice so code gen will fail if + // some expression is reusing variable names across different instances. + // This behavior is tested in ExpressionEvalHelperSuite. + val plan = generateProject( + UnsafeProjection.create( + Alias(expression, s"Optimized($expression)1")() :: + Alias(expression, s"Optimized($expression)2")() :: Nil), + expression) + + plan(inputRow) + } + protected def checkEvaluationWithOptimization( expression: Expression, expected: Any, @@ -294,7 +335,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { private def cmpInterpretWithCodegen(inputRow: InternalRow, expr: Expression): Unit = { val interpret = try { - evaluate(expr, inputRow) + evaluateWithoutCodegen(expr, inputRow) } catch { case e: Exception => fail(s"Exception evaluating $expr", e) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 39e0060d41dd4..3a094079380fd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -124,7 +124,7 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { private def checkNaNWithoutCodegen( expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { + val actual = try evaluateWithoutCodegen(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } if (!actual.asInstanceOf[Double].isNaN) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index facc863081303..a21c139fe71d0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -41,6 +41,6 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("uuid") { checkEvaluation(Length(Uuid()), 36) - assert(evaluate(Uuid()) !== evaluate(Uuid())) + assert(evaluateWithoutCodegen(Uuid()) !== evaluateWithoutCodegen(Uuid())) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index cc6c15cb2c909..424c3a4696077 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -51,7 +51,7 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("AssertNotNUll") { val ex = intercept[RuntimeException] { - evaluate(AssertNotNull(Literal(null), Seq.empty[String])) + evaluateWithoutCodegen(AssertNotNull(Literal(null), Seq.empty[String])) }.getMessage assert(ex.contains("Null value appeared in non-nullable field")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 50e57737a4612..cbfbb6573ae8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -100,14 +100,13 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } // If an input row or a field are null, a runtime exception will be thrown - val errMsg1 = intercept[RuntimeException] { - evaluate(getRowField, InternalRow.fromSeq(Seq(null))) - }.getMessage - assert(errMsg1 === "The input external row cannot be null.") - - val errMsg2 = intercept[RuntimeException] { - evaluate(getRowField, InternalRow.fromSeq(Seq(Row(null)))) - }.getMessage - assert(errMsg2 === "The 0th field 'c0' of input row cannot be null.") + checkExceptionInExpression[RuntimeException]( + getRowField, + InternalRow.fromSeq(Seq(null)), + "The input external row cannot be null.") + checkExceptionInExpression[RuntimeException]( + getRowField, + InternalRow.fromSeq(Seq(Row(null))), + "The 0th field 'c0' of input row cannot be null.") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index 2a0a42c65b086..d532dc4f77198 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -100,12 +100,12 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // invalid escaping val invalidEscape = intercept[AnalysisException] { - evaluate("""a""" like """\a""") + evaluateWithoutCodegen("""a""" like """\a""") } assert(invalidEscape.getMessage.contains("pattern")) val endEscape = intercept[AnalysisException] { - evaluate("""a""" like """a\""") + evaluateWithoutCodegen("""a""" like """a\""") } assert(endEscape.getMessage.contains("pattern")) @@ -147,11 +147,11 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkLiteralRow("abc" rlike _, "^bc", false) intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike "**") + evaluateWithoutCodegen("abbbbc" rlike "**") } intercept[java.util.regex.PatternSyntaxException] { val regex = 'a.string.at(0) - evaluate("abbbbc" rlike regex, create_row("**")) + evaluateWithoutCodegen("abbbbc" rlike regex, create_row("**")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 97ddbeba2c5ca..9a1a4da074ce3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -756,7 +756,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // exceptional cases intercept[java.util.regex.PatternSyntaxException] { - evaluate(ParseUrl(Seq(Literal("http://spark.apache.org/path?"), + evaluateWithoutCodegen(ParseUrl(Seq(Literal("http://spark.apache.org/path?"), Literal("QUERY"), Literal("???")))) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala index d6c8fcf291842..351d4d0c2eac9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala @@ -27,7 +27,7 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva test("time window is unevaluable") { intercept[UnsupportedOperationException] { - evaluate(TimeWindow(Literal(10L), "1 second", "1 second", "0 second")) + evaluateWithoutCodegen(TimeWindow(Literal(10L), "1 second", "1 second", "0 second")) } } From aff7d81cb73133483fc2256ca10e21b4b8101647 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 7 Mar 2018 18:31:59 +0100 Subject: [PATCH 0443/2461] [SPARK-23591][SQL] Add interpreted execution to EncodeUsingSerializer ## What changes were proposed in this pull request? The PR adds interpreted execution to EncodeUsingSerializer. ## How was this patch tested? added UT Author: Marco Gaido Closes #20751 from mgaido91/SPARK-23591. --- .../expressions/objects/objects.scala | 114 ++++++++++-------- .../expressions/ObjectExpressionsSuite.scala | 16 ++- 2 files changed, 77 insertions(+), 53 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 721d589709131..7bbc3c732e782 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -105,6 +105,61 @@ trait InvokeLike extends Expression with NonSQLExpression { } } +/** + * Common trait for [[DecodeUsingSerializer]] and [[EncodeUsingSerializer]] + */ +trait SerializerSupport { + /** + * If true, Kryo serialization is used, otherwise the Java one is used + */ + val kryo: Boolean + + /** + * The serializer instance to be used for serialization/deserialization in interpreted execution + */ + lazy val serializerInstance: SerializerInstance = SerializerSupport.newSerializer(kryo) + + /** + * Adds a immutable state to the generated class containing a reference to the serializer. + * @return a string containing the name of the variable referencing the serializer + */ + def addImmutableSerializerIfNeeded(ctx: CodegenContext): String = { + val (serializerInstance, serializerInstanceClass) = { + if (kryo) { + ("kryoSerializer", + classOf[KryoSerializerInstance].getName) + } else { + ("javaSerializer", + classOf[JavaSerializerInstance].getName) + } + } + val newSerializerMethod = s"${classOf[SerializerSupport].getName}$$.MODULE$$.newSerializer" + // Code to initialize the serializer + ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializerInstance, v => + s""" + |$v = ($serializerInstanceClass) $newSerializerMethod($kryo); + """.stripMargin) + serializerInstance + } +} + +object SerializerSupport { + /** + * It creates a new `SerializerInstance` which is either a `KryoSerializerInstance` (is + * `useKryo` is set to `true`) or a `JavaSerializerInstance`. + */ + def newSerializer(useKryo: Boolean): SerializerInstance = { + // try conf from env, otherwise create a new one + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) + val s = if (useKryo) { + new KryoSerializer(conf) + } else { + new JavaSerializer(conf) + } + s.newInstance() + } +} + /** * Invokes a static function, returning the result. By default, any of the arguments being null * will result in returning null instead of calling the function. @@ -1154,36 +1209,14 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) * @param kryo if true, use Kryo. Otherwise, use Java. */ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) - extends UnaryExpression with NonSQLExpression { + extends UnaryExpression with NonSQLExpression with SerializerSupport { - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override def nullSafeEval(input: Any): Any = { + serializerInstance.serialize(input).array() + } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - // Code to initialize the serializer. - val (serializer, serializerClass, serializerInstanceClass) = { - if (kryo) { - ("kryoSerializer", - classOf[KryoSerializer].getName, - classOf[KryoSerializerInstance].getName) - } else { - ("javaSerializer", - classOf[JavaSerializer].getName, - classOf[JavaSerializerInstance].getName) - } - } - // try conf from env, otherwise create a new one - val env = s"${classOf[SparkEnv].getName}.get()" - val sparkConf = s"new ${classOf[SparkConf].getName}()" - ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v => - s""" - |if ($env == null) { - | $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); - |} else { - | $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); - |} - """.stripMargin) - + val serializer = addImmutableSerializerIfNeeded(ctx) // Code to serialize. val input = child.genCode(ctx) val javaType = CodeGenerator.javaType(dataType) @@ -1207,33 +1240,10 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) * @param kryo if true, use Kryo. Otherwise, use Java. */ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) - extends UnaryExpression with NonSQLExpression { + extends UnaryExpression with NonSQLExpression with SerializerSupport { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - // Code to initialize the serializer. - val (serializer, serializerClass, serializerInstanceClass) = { - if (kryo) { - ("kryoSerializer", - classOf[KryoSerializer].getName, - classOf[KryoSerializerInstance].getName) - } else { - ("javaSerializer", - classOf[JavaSerializer].getName, - classOf[JavaSerializerInstance].getName) - } - } - // try conf from env, otherwise create a new one - val env = s"${classOf[SparkEnv].getName}.get()" - val sparkConf = s"new ${classOf[SparkConf].getName}()" - ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v => - s""" - |if ($env == null) { - | $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); - |} else { - | $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); - |} - """.stripMargin) - + val serializer = addImmutableSerializerIfNeeded(ctx) // Code to deserialize. val input = child.genCode(ctx) val javaType = CodeGenerator.javaType(dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index cbfbb6573ae8e..346b13277c709 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -109,4 +110,17 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { InternalRow.fromSeq(Seq(Row(null))), "The 0th field 'c0' of input row cannot be null.") } + + test("SPARK-23591: EncodeUsingSerializer should support interpreted execution") { + val cls = ObjectType(classOf[java.lang.Integer]) + val inputObject = BoundReference(0, cls, nullable = true) + val conf = new SparkConf() + Seq(true, false).foreach { useKryo => + val serializer = if (useKryo) new KryoSerializer(conf) else new JavaSerializer(conf) + val expected = serializer.newInstance().serialize(new Integer(1)).array() + val encodeUsingSerializer = EncodeUsingSerializer(inputObject, useKryo) + checkEvaluation(encodeUsingSerializer, expected, InternalRow.fromSeq(Seq(1))) + checkEvaluation(encodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) + } + } } From 53561d27c45db31893bcabd4aca2387fde869b72 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 7 Mar 2018 09:37:42 -0800 Subject: [PATCH 0444/2461] [SPARK-23291][SQL][R] R's substr should not reduce starting position by 1 when calling Scala API ## What changes were proposed in this pull request? Seems R's substr API treats Scala substr API as zero based and so subtracts the given starting position by 1. Because Scala's substr API also accepts zero-based starting position (treated as the first element), so the current R's substr test results are correct as they all use 1 as starting positions. ## How was this patch tested? Modified tests. Author: Liang-Chi Hsieh Closes #20464 from viirya/SPARK-23291. --- R/pkg/R/column.R | 10 ++++++++-- R/pkg/tests/fulltests/test_sparkSQL.R | 1 + docs/sparkr.md | 4 ++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 9727efc354f10..7926a9a2467ee 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -161,12 +161,18 @@ setMethod("alias", #' @aliases substr,Column-method #' #' @param x a Column. -#' @param start starting position. +#' @param start starting position. It should be 1-base. #' @param stop ending position. +#' @examples +#' \dontrun{ +#' df <- createDataFrame(list(list(a="abcdef"))) +#' collect(select(df, substr(df$a, 1, 4))) # the result is `abcd`. +#' collect(select(df, substr(df$a, 2, 4))) # the result is `bcd`. +#' } #' @note substr since 1.4.0 setMethod("substr", signature(x = "Column"), function(x, start, stop) { - jc <- callJMethod(x@jc, "substr", as.integer(start - 1), as.integer(stop - start + 1)) + jc <- callJMethod(x@jc, "substr", as.integer(start), as.integer(stop - start + 1)) column(jc) }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index bd0a0dcd0674c..439191adb23ea 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1651,6 +1651,7 @@ test_that("string operators", { expect_false(first(select(df, startsWith(df$name, "m")))[[1]]) expect_true(first(select(df, endsWith(df$name, "el")))[[1]]) expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") + expect_equal(first(select(df, substr(df$name, 4, 6)))[[1]], "hae") if (as.numeric(R.version$major) >= 3 && as.numeric(R.version$minor) >= 3) { expect_true(startsWith("Hello World", "Hello")) expect_false(endsWith("Hello World", "a")) diff --git a/docs/sparkr.md b/docs/sparkr.md index 6685b585a393a..2909247e79e95 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -663,3 +663,7 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma - The `stringsAsFactors` parameter was previously ignored with `collect`, for example, in `collect(createDataFrame(iris), stringsAsFactors = TRUE))`. It has been corrected. - For `summary`, option for statistics to compute has been added. Its output is changed from that from `describe`. - A warning can be raised if versions of SparkR package and the Spark JVM do not match. + +## Upgrading to Spark 2.4.0 + + - The `start` parameter of `substr` method was wrongly subtracted by one, previously. In other words, the index specified by `start` parameter was considered as 0-base. This can lead to inconsistent substring results and also does not match with the behaviour with `substr` in R. It has been fixed so the `start` parameter of `substr` method is now 1-base, e.g., therefore to get the same result as `substr(df$a, 2, 5)`, it should be changed to `substr(df$a, 1, 4)`. From c99fc9ad9b600095baba003053dbf84304ca392b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 7 Mar 2018 13:42:06 -0800 Subject: [PATCH 0445/2461] [SPARK-23550][CORE] Cleanup `Utils`. A few different things going on: - Remove unused methods. - Move JSON methods to the only class that uses them. - Move test-only methods to TestUtils. - Make getMaxResultSize() a config constant. - Reuse functionality from existing libraries (JRE or JavaUtils) where possible. The change also includes changes to a few tests to call `Utils.createTempFile` correctly, so that temp dirs are created under the designated top-level temp dir instead of potentially polluting git index. Author: Marcelo Vanzin Closes #20706 from vanzin/SPARK-23550. --- .../scala/org/apache/spark/TestUtils.scala | 26 +++- .../spark/deploy/SparkSubmitArguments.scala | 4 +- .../org/apache/spark/executor/Executor.scala | 4 +- .../spark/internal/config/ConfigBuilder.scala | 3 +- .../spark/internal/config/package.scala | 5 + .../spark/scheduler/TaskSetManager.scala | 3 +- .../org/apache/spark/util/JsonProtocol.scala | 124 +++++++++-------- .../scala/org/apache/spark/util/Utils.scala | 131 +----------------- .../sort/UnsafeShuffleWriterSuite.java | 2 +- .../scala/org/apache/spark/DriverSuite.scala | 2 +- .../spark/deploy/SparkSubmitSuite.scala | 18 +-- .../spark/scheduler/ReplayListenerSuite.scala | 12 +- .../org/apache/spark/util/UtilsSuite.scala | 1 + scalastyle-config.xml | 2 +- .../datasources/orc/OrcSourceSuite.scala | 4 +- .../metric/SQLMetricsTestUtils.scala | 6 +- .../sql/sources/PartitionedWriteSuite.scala | 15 +- .../HiveThriftServer2Suites.scala | 2 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 20 +-- .../spark/streaming/CheckpointSuite.scala | 2 +- .../spark/streaming/MapWithStateSuite.scala | 2 +- 21 files changed, 152 insertions(+), 236 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 93e7ee3d2a404..b5c4c705dcbc7 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -22,7 +22,7 @@ import java.net.{HttpURLConnection, URI, URL} import java.nio.charset.StandardCharsets import java.security.SecureRandom import java.security.cert.X509Certificate -import java.util.Arrays +import java.util.{Arrays, Properties} import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} import java.util.jar.{JarEntry, JarOutputStream} import javax.net.ssl._ @@ -35,6 +35,7 @@ import scala.sys.process.{Process, ProcessLogger} import scala.util.Try import com.google.common.io.{ByteStreams, Files} +import org.apache.log4j.PropertyConfigurator import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ @@ -256,6 +257,29 @@ private[spark] object TestUtils { s"Can't find $numExecutors executors before $timeout milliseconds elapsed") } + /** + * config a log4j properties used for testsuite + */ + def configTestLog4j(level: String): Unit = { + val pro = new Properties() + pro.put("log4j.rootLogger", s"$level, console") + pro.put("log4j.appender.console", "org.apache.log4j.ConsoleAppender") + pro.put("log4j.appender.console.target", "System.err") + pro.put("log4j.appender.console.layout", "org.apache.log4j.PatternLayout") + pro.put("log4j.appender.console.layout.ConversionPattern", + "%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n") + PropertyConfigurator.configure(pro) + } + + /** + * Lists files recursively. + */ + def recursiveList(f: File): Array[File] = { + require(f.isDirectory) + val current = f.listFiles + current ++ current.filter(_.isDirectory).flatMap(recursiveList) + } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 9db7a1fe3106d..e7796d4ddbe34 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy -import java.io.{ByteArrayOutputStream, PrintStream} +import java.io.{ByteArrayOutputStream, File, PrintStream} import java.lang.reflect.InvocationTargetException import java.net.URI import java.nio.charset.StandardCharsets @@ -233,7 +233,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S // Set name from main class if not given name = Option(name).orElse(Option(mainClass)).orNull if (name == null && primaryResource != null) { - name = Utils.stripDirectory(primaryResource) + name = new File(primaryResource).getName() } // Action should be SUBMIT unless otherwise specified diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 2c3a8ef74800b..dcec3ec21b546 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -35,6 +35,7 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager} import org.apache.spark.rpc.RpcTimeout import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task, TaskDescription} @@ -141,8 +142,7 @@ private[spark] class Executor( conf.getSizeAsBytes("spark.task.maxDirectResultSize", 1L << 20), RpcUtils.maxMessageSizeBytes(conf)) - // Limit of bytes for total size of results (default is 1GB) - private val maxResultSize = Utils.getMaxResultSize(conf) + private val maxResultSize = conf.get(MAX_RESULT_SIZE) // Maintains the list of running tasks. private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index b0cd7110a3b47..f27aca03773a9 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -23,6 +23,7 @@ import java.util.regex.PatternSyntaxException import scala.util.matching.Regex import org.apache.spark.network.util.{ByteUnit, JavaUtils} +import org.apache.spark.util.Utils private object ConfigHelpers { @@ -45,7 +46,7 @@ private object ConfigHelpers { } def stringToSeq[T](str: String, converter: String => T): Seq[T] = { - str.split(",").map(_.trim()).filter(_.nonEmpty).map(converter) + Utils.stringToSeq(str).map(converter) } def seqToString[T](v: Seq[T], stringConverter: T => String): String = { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index bbfcfbaa7363c..a313ad0554a3a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -520,4 +520,9 @@ package object config { .checkValue(v => v > 0, "The threshold should be positive.") .createWithDefault(10000000) + private[spark] val MAX_RESULT_SIZE = ConfigBuilder("spark.driver.maxResultSize") + .doc("Size limit for results.") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("1g") + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 886c2c99f1ff3..d958658527f6d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -64,8 +64,7 @@ private[spark] class TaskSetManager( val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75) val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5) - // Limit of bytes for total size of results (default is 1GB) - val maxResultSize = Utils.getMaxResultSize(conf) + val maxResultSize = conf.get(config.MAX_RESULT_SIZE) val speculationEnabled = conf.getBoolean("spark.speculation", false) diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index ff83301d631c4..40383fe05026b 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -48,7 +48,7 @@ import org.apache.spark.storage._ * To ensure that we provide these guarantees, follow these rules when modifying these methods: * * - Never delete any JSON fields. - * - Any new JSON fields should be optional; use `Utils.jsonOption` when reading these fields + * - Any new JSON fields should be optional; use `jsonOption` when reading these fields * in `*FromJson` methods. */ private[spark] object JsonProtocol { @@ -408,7 +408,7 @@ private[spark] object JsonProtocol { ("Loss Reason" -> reason.map(_.toString)) case taskKilled: TaskKilled => ("Kill Reason" -> taskKilled.reason) - case _ => Utils.emptyJson + case _ => emptyJson } ("Reason" -> reason) ~ json } @@ -422,7 +422,7 @@ private[spark] object JsonProtocol { def jobResultToJson(jobResult: JobResult): JValue = { val result = Utils.getFormattedClassName(jobResult) val json = jobResult match { - case JobSucceeded => Utils.emptyJson + case JobSucceeded => emptyJson case jobFailed: JobFailed => JObject("Exception" -> exceptionToJson(jobFailed.exception)) } @@ -573,7 +573,7 @@ private[spark] object JsonProtocol { def taskStartFromJson(json: JValue): SparkListenerTaskStart = { val stageId = (json \ "Stage ID").extract[Int] val stageAttemptId = - Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) + jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val taskInfo = taskInfoFromJson(json \ "Task Info") SparkListenerTaskStart(stageId, stageAttemptId, taskInfo) } @@ -586,7 +586,7 @@ private[spark] object JsonProtocol { def taskEndFromJson(json: JValue): SparkListenerTaskEnd = { val stageId = (json \ "Stage ID").extract[Int] val stageAttemptId = - Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) + jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val taskType = (json \ "Task Type").extract[String] val taskEndReason = taskEndReasonFromJson(json \ "Task End Reason") val taskInfo = taskInfoFromJson(json \ "Task Info") @@ -597,11 +597,11 @@ private[spark] object JsonProtocol { def jobStartFromJson(json: JValue): SparkListenerJobStart = { val jobId = (json \ "Job ID").extract[Int] val submissionTime = - Utils.jsonOption(json \ "Submission Time").map(_.extract[Long]).getOrElse(-1L) + jsonOption(json \ "Submission Time").map(_.extract[Long]).getOrElse(-1L) val stageIds = (json \ "Stage IDs").extract[List[JValue]].map(_.extract[Int]) val properties = propertiesFromJson(json \ "Properties") // The "Stage Infos" field was added in Spark 1.2.0 - val stageInfos = Utils.jsonOption(json \ "Stage Infos") + val stageInfos = jsonOption(json \ "Stage Infos") .map(_.extract[Seq[JValue]].map(stageInfoFromJson)).getOrElse { stageIds.map { id => new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown") @@ -613,7 +613,7 @@ private[spark] object JsonProtocol { def jobEndFromJson(json: JValue): SparkListenerJobEnd = { val jobId = (json \ "Job ID").extract[Int] val completionTime = - Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]).getOrElse(-1L) + jsonOption(json \ "Completion Time").map(_.extract[Long]).getOrElse(-1L) val jobResult = jobResultFromJson(json \ "Job Result") SparkListenerJobEnd(jobId, completionTime, jobResult) } @@ -630,15 +630,15 @@ private[spark] object JsonProtocol { def blockManagerAddedFromJson(json: JValue): SparkListenerBlockManagerAdded = { val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID") val maxMem = (json \ "Maximum Memory").extract[Long] - val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) - val maxOnHeapMem = Utils.jsonOption(json \ "Maximum Onheap Memory").map(_.extract[Long]) - val maxOffHeapMem = Utils.jsonOption(json \ "Maximum Offheap Memory").map(_.extract[Long]) + val time = jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) + val maxOnHeapMem = jsonOption(json \ "Maximum Onheap Memory").map(_.extract[Long]) + val maxOffHeapMem = jsonOption(json \ "Maximum Offheap Memory").map(_.extract[Long]) SparkListenerBlockManagerAdded(time, blockManagerId, maxMem, maxOnHeapMem, maxOffHeapMem) } def blockManagerRemovedFromJson(json: JValue): SparkListenerBlockManagerRemoved = { val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID") - val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) + val time = jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) SparkListenerBlockManagerRemoved(time, blockManagerId) } @@ -648,11 +648,11 @@ private[spark] object JsonProtocol { def applicationStartFromJson(json: JValue): SparkListenerApplicationStart = { val appName = (json \ "App Name").extract[String] - val appId = Utils.jsonOption(json \ "App ID").map(_.extract[String]) + val appId = jsonOption(json \ "App ID").map(_.extract[String]) val time = (json \ "Timestamp").extract[Long] val sparkUser = (json \ "User").extract[String] - val appAttemptId = Utils.jsonOption(json \ "App Attempt ID").map(_.extract[String]) - val driverLogs = Utils.jsonOption(json \ "Driver Logs").map(mapFromJson) + val appAttemptId = jsonOption(json \ "App Attempt ID").map(_.extract[String]) + val driverLogs = jsonOption(json \ "Driver Logs").map(mapFromJson) SparkListenerApplicationStart(appName, appId, time, sparkUser, appAttemptId, driverLogs) } @@ -703,19 +703,19 @@ private[spark] object JsonProtocol { def stageInfoFromJson(json: JValue): StageInfo = { val stageId = (json \ "Stage ID").extract[Int] - val attemptId = Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) + val attemptId = jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val stageName = (json \ "Stage Name").extract[String] val numTasks = (json \ "Number of Tasks").extract[Int] val rddInfos = (json \ "RDD Info").extract[List[JValue]].map(rddInfoFromJson) - val parentIds = Utils.jsonOption(json \ "Parent IDs") + val parentIds = jsonOption(json \ "Parent IDs") .map { l => l.extract[List[JValue]].map(_.extract[Int]) } .getOrElse(Seq.empty) - val details = Utils.jsonOption(json \ "Details").map(_.extract[String]).getOrElse("") - val submissionTime = Utils.jsonOption(json \ "Submission Time").map(_.extract[Long]) - val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]) - val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String]) + val details = jsonOption(json \ "Details").map(_.extract[String]).getOrElse("") + val submissionTime = jsonOption(json \ "Submission Time").map(_.extract[Long]) + val completionTime = jsonOption(json \ "Completion Time").map(_.extract[Long]) + val failureReason = jsonOption(json \ "Failure Reason").map(_.extract[String]) val accumulatedValues = { - Utils.jsonOption(json \ "Accumulables").map(_.extract[List[JValue]]) match { + jsonOption(json \ "Accumulables").map(_.extract[List[JValue]]) match { case Some(values) => values.map(accumulableInfoFromJson) case None => Seq.empty[AccumulableInfo] } @@ -735,17 +735,17 @@ private[spark] object JsonProtocol { def taskInfoFromJson(json: JValue): TaskInfo = { val taskId = (json \ "Task ID").extract[Long] val index = (json \ "Index").extract[Int] - val attempt = Utils.jsonOption(json \ "Attempt").map(_.extract[Int]).getOrElse(1) + val attempt = jsonOption(json \ "Attempt").map(_.extract[Int]).getOrElse(1) val launchTime = (json \ "Launch Time").extract[Long] val executorId = (json \ "Executor ID").extract[String].intern() val host = (json \ "Host").extract[String].intern() val taskLocality = TaskLocality.withName((json \ "Locality").extract[String]) - val speculative = Utils.jsonOption(json \ "Speculative").exists(_.extract[Boolean]) + val speculative = jsonOption(json \ "Speculative").exists(_.extract[Boolean]) val gettingResultTime = (json \ "Getting Result Time").extract[Long] val finishTime = (json \ "Finish Time").extract[Long] val failed = (json \ "Failed").extract[Boolean] - val killed = Utils.jsonOption(json \ "Killed").exists(_.extract[Boolean]) - val accumulables = Utils.jsonOption(json \ "Accumulables").map(_.extract[Seq[JValue]]) match { + val killed = jsonOption(json \ "Killed").exists(_.extract[Boolean]) + val accumulables = jsonOption(json \ "Accumulables").map(_.extract[Seq[JValue]]) match { case Some(values) => values.map(accumulableInfoFromJson) case None => Seq.empty[AccumulableInfo] } @@ -762,13 +762,13 @@ private[spark] object JsonProtocol { def accumulableInfoFromJson(json: JValue): AccumulableInfo = { val id = (json \ "ID").extract[Long] - val name = Utils.jsonOption(json \ "Name").map(_.extract[String]) - val update = Utils.jsonOption(json \ "Update").map { v => accumValueFromJson(name, v) } - val value = Utils.jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) } - val internal = Utils.jsonOption(json \ "Internal").exists(_.extract[Boolean]) + val name = jsonOption(json \ "Name").map(_.extract[String]) + val update = jsonOption(json \ "Update").map { v => accumValueFromJson(name, v) } + val value = jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) } + val internal = jsonOption(json \ "Internal").exists(_.extract[Boolean]) val countFailedValues = - Utils.jsonOption(json \ "Count Failed Values").exists(_.extract[Boolean]) - val metadata = Utils.jsonOption(json \ "Metadata").map(_.extract[String]) + jsonOption(json \ "Count Failed Values").exists(_.extract[Boolean]) + val metadata = jsonOption(json \ "Metadata").map(_.extract[String]) new AccumulableInfo(id, name, update, value, internal, countFailedValues, metadata) } @@ -821,49 +821,49 @@ private[spark] object JsonProtocol { metrics.incDiskBytesSpilled((json \ "Disk Bytes Spilled").extract[Long]) // Shuffle read metrics - Utils.jsonOption(json \ "Shuffle Read Metrics").foreach { readJson => + jsonOption(json \ "Shuffle Read Metrics").foreach { readJson => val readMetrics = metrics.createTempShuffleReadMetrics() readMetrics.incRemoteBlocksFetched((readJson \ "Remote Blocks Fetched").extract[Int]) readMetrics.incLocalBlocksFetched((readJson \ "Local Blocks Fetched").extract[Int]) readMetrics.incRemoteBytesRead((readJson \ "Remote Bytes Read").extract[Long]) - Utils.jsonOption(readJson \ "Remote Bytes Read To Disk") + jsonOption(readJson \ "Remote Bytes Read To Disk") .foreach { v => readMetrics.incRemoteBytesReadToDisk(v.extract[Long])} readMetrics.incLocalBytesRead( - Utils.jsonOption(readJson \ "Local Bytes Read").map(_.extract[Long]).getOrElse(0L)) + jsonOption(readJson \ "Local Bytes Read").map(_.extract[Long]).getOrElse(0L)) readMetrics.incFetchWaitTime((readJson \ "Fetch Wait Time").extract[Long]) readMetrics.incRecordsRead( - Utils.jsonOption(readJson \ "Total Records Read").map(_.extract[Long]).getOrElse(0L)) + jsonOption(readJson \ "Total Records Read").map(_.extract[Long]).getOrElse(0L)) metrics.mergeShuffleReadMetrics() } // Shuffle write metrics // TODO: Drop the redundant "Shuffle" since it's inconsistent with related classes. - Utils.jsonOption(json \ "Shuffle Write Metrics").foreach { writeJson => + jsonOption(json \ "Shuffle Write Metrics").foreach { writeJson => val writeMetrics = metrics.shuffleWriteMetrics writeMetrics.incBytesWritten((writeJson \ "Shuffle Bytes Written").extract[Long]) writeMetrics.incRecordsWritten( - Utils.jsonOption(writeJson \ "Shuffle Records Written").map(_.extract[Long]).getOrElse(0L)) + jsonOption(writeJson \ "Shuffle Records Written").map(_.extract[Long]).getOrElse(0L)) writeMetrics.incWriteTime((writeJson \ "Shuffle Write Time").extract[Long]) } // Output metrics - Utils.jsonOption(json \ "Output Metrics").foreach { outJson => + jsonOption(json \ "Output Metrics").foreach { outJson => val outputMetrics = metrics.outputMetrics outputMetrics.setBytesWritten((outJson \ "Bytes Written").extract[Long]) outputMetrics.setRecordsWritten( - Utils.jsonOption(outJson \ "Records Written").map(_.extract[Long]).getOrElse(0L)) + jsonOption(outJson \ "Records Written").map(_.extract[Long]).getOrElse(0L)) } // Input metrics - Utils.jsonOption(json \ "Input Metrics").foreach { inJson => + jsonOption(json \ "Input Metrics").foreach { inJson => val inputMetrics = metrics.inputMetrics inputMetrics.incBytesRead((inJson \ "Bytes Read").extract[Long]) inputMetrics.incRecordsRead( - Utils.jsonOption(inJson \ "Records Read").map(_.extract[Long]).getOrElse(0L)) + jsonOption(inJson \ "Records Read").map(_.extract[Long]).getOrElse(0L)) } // Updated blocks - Utils.jsonOption(json \ "Updated Blocks").foreach { blocksJson => + jsonOption(json \ "Updated Blocks").foreach { blocksJson => metrics.setUpdatedBlockStatuses(blocksJson.extract[List[JValue]].map { blockJson => val id = BlockId((blockJson \ "Block ID").extract[String]) val status = blockStatusFromJson(blockJson \ "Status") @@ -897,7 +897,7 @@ private[spark] object JsonProtocol { val shuffleId = (json \ "Shuffle ID").extract[Int] val mapId = (json \ "Map ID").extract[Int] val reduceId = (json \ "Reduce ID").extract[Int] - val message = Utils.jsonOption(json \ "Message").map(_.extract[String]) + val message = jsonOption(json \ "Message").map(_.extract[String]) new FetchFailed(blockManagerAddress, shuffleId, mapId, reduceId, message.getOrElse("Unknown reason")) case `exceptionFailure` => @@ -905,9 +905,9 @@ private[spark] object JsonProtocol { val description = (json \ "Description").extract[String] val stackTrace = stackTraceFromJson(json \ "Stack Trace") val fullStackTrace = - Utils.jsonOption(json \ "Full Stack Trace").map(_.extract[String]).orNull + jsonOption(json \ "Full Stack Trace").map(_.extract[String]).orNull // Fallback on getting accumulator updates from TaskMetrics, which was logged in Spark 1.x - val accumUpdates = Utils.jsonOption(json \ "Accumulator Updates") + val accumUpdates = jsonOption(json \ "Accumulator Updates") .map(_.extract[List[JValue]].map(accumulableInfoFromJson)) .getOrElse(taskMetricsFromJson(json \ "Metrics").accumulators().map(acc => { acc.toInfo(Some(acc.value), None) @@ -915,21 +915,21 @@ private[spark] object JsonProtocol { ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates) case `taskResultLost` => TaskResultLost case `taskKilled` => - val killReason = Utils.jsonOption(json \ "Kill Reason") + val killReason = jsonOption(json \ "Kill Reason") .map(_.extract[String]).getOrElse("unknown reason") TaskKilled(killReason) case `taskCommitDenied` => // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON // de/serialization logic was not added until 1.5.1. To provide backward compatibility // for reading those logs, we need to provide default values for all the fields. - val jobId = Utils.jsonOption(json \ "Job ID").map(_.extract[Int]).getOrElse(-1) - val partitionId = Utils.jsonOption(json \ "Partition ID").map(_.extract[Int]).getOrElse(-1) - val attemptNo = Utils.jsonOption(json \ "Attempt Number").map(_.extract[Int]).getOrElse(-1) + val jobId = jsonOption(json \ "Job ID").map(_.extract[Int]).getOrElse(-1) + val partitionId = jsonOption(json \ "Partition ID").map(_.extract[Int]).getOrElse(-1) + val attemptNo = jsonOption(json \ "Attempt Number").map(_.extract[Int]).getOrElse(-1) TaskCommitDenied(jobId, partitionId, attemptNo) case `executorLostFailure` => - val exitCausedByApp = Utils.jsonOption(json \ "Exit Caused By App").map(_.extract[Boolean]) - val executorId = Utils.jsonOption(json \ "Executor ID").map(_.extract[String]) - val reason = Utils.jsonOption(json \ "Loss Reason").map(_.extract[String]) + val exitCausedByApp = jsonOption(json \ "Exit Caused By App").map(_.extract[Boolean]) + val executorId = jsonOption(json \ "Executor ID").map(_.extract[String]) + val reason = jsonOption(json \ "Loss Reason").map(_.extract[String]) ExecutorLostFailure( executorId.getOrElse("Unknown"), exitCausedByApp.getOrElse(true), @@ -968,11 +968,11 @@ private[spark] object JsonProtocol { def rddInfoFromJson(json: JValue): RDDInfo = { val rddId = (json \ "RDD ID").extract[Int] val name = (json \ "Name").extract[String] - val scope = Utils.jsonOption(json \ "Scope") + val scope = jsonOption(json \ "Scope") .map(_.extract[String]) .map(RDDOperationScope.fromJson) - val callsite = Utils.jsonOption(json \ "Callsite").map(_.extract[String]).getOrElse("") - val parentIds = Utils.jsonOption(json \ "Parent IDs") + val callsite = jsonOption(json \ "Callsite").map(_.extract[String]).getOrElse("") + val parentIds = jsonOption(json \ "Parent IDs") .map { l => l.extract[List[JValue]].map(_.extract[Int]) } .getOrElse(Seq.empty) val storageLevel = storageLevelFromJson(json \ "Storage Level") @@ -1029,7 +1029,7 @@ private[spark] object JsonProtocol { } def propertiesFromJson(json: JValue): Properties = { - Utils.jsonOption(json).map { value => + jsonOption(json).map { value => val properties = new Properties mapFromJson(json).foreach { case (k, v) => properties.setProperty(k, v) } properties @@ -1058,4 +1058,14 @@ private[spark] object JsonProtocol { e } + /** Return an option that translates JNothing to None */ + private def jsonOption(json: JValue): Option[JValue] = { + json match { + case JNothing => None + case value: JValue => Some(value) + } + } + + private def emptyJson: JObject = JObject(List[JField]()) + } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2e2a4a259e9af..29d26ea2c85df 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -25,7 +25,7 @@ import java.net._ import java.nio.ByteBuffer import java.nio.channels.{Channels, FileChannel} import java.nio.charset.StandardCharsets -import java.nio.file.{Files, Paths} +import java.nio.file.Files import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean @@ -51,9 +51,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException -import org.json4s._ import org.slf4j.Logger import org.apache.spark._ @@ -1017,70 +1015,18 @@ private[spark] object Utils extends Logging { " " + (System.currentTimeMillis - startTimeMs) + " ms" } - private def listFilesSafely(file: File): Seq[File] = { - if (file.exists()) { - val files = file.listFiles() - if (files == null) { - throw new IOException("Failed to list files for dir: " + file) - } - files - } else { - List() - } - } - - /** - * Lists files recursively. - */ - def recursiveList(f: File): Array[File] = { - require(f.isDirectory) - val current = f.listFiles - current ++ current.filter(_.isDirectory).flatMap(recursiveList) - } - /** * Delete a file or directory and its contents recursively. * Don't follow directories if they are symlinks. * Throws an exception if deletion is unsuccessful. */ - def deleteRecursively(file: File) { + def deleteRecursively(file: File): Unit = { if (file != null) { - try { - if (file.isDirectory && !isSymlink(file)) { - var savedIOException: IOException = null - for (child <- listFilesSafely(file)) { - try { - deleteRecursively(child) - } catch { - // In case of multiple exceptions, only last one will be thrown - case ioe: IOException => savedIOException = ioe - } - } - if (savedIOException != null) { - throw savedIOException - } - ShutdownHookManager.removeShutdownDeleteDir(file) - } - } finally { - if (file.delete()) { - logTrace(s"${file.getAbsolutePath} has been deleted") - } else { - // Delete can also fail if the file simply did not exist - if (file.exists()) { - throw new IOException("Failed to delete: " + file.getAbsolutePath) - } - } - } + JavaUtils.deleteRecursively(file) + ShutdownHookManager.removeShutdownDeleteDir(file) } } - /** - * Check to see if file is a symbolic link. - */ - def isSymlink(file: File): Boolean = { - return Files.isSymbolicLink(Paths.get(file.toURI)) - } - /** * Determines if a directory contains any files newer than cutoff seconds. * @@ -1828,7 +1774,7 @@ private[spark] object Utils extends Logging { * [[scala.collection.Iterator#size]] because it uses a for loop, which is slightly slower * in the current version of Scala. */ - def getIteratorSize[T](iterator: Iterator[T]): Long = { + def getIteratorSize(iterator: Iterator[_]): Long = { var count = 0L while (iterator.hasNext) { count += 1L @@ -1875,17 +1821,6 @@ private[spark] object Utils extends Logging { obj.getClass.getSimpleName.replace("$", "") } - /** Return an option that translates JNothing to None */ - def jsonOption(json: JValue): Option[JValue] = { - json match { - case JNothing => None - case value: JValue => Some(value) - } - } - - /** Return an empty JSON object */ - def emptyJson: JsonAST.JObject = JObject(List[JField]()) - /** * Return a Hadoop FileSystem with the scheme encoded in the given path. */ @@ -1900,15 +1835,6 @@ private[spark] object Utils extends Logging { getHadoopFileSystem(new URI(path), conf) } - /** - * Return the absolute path of a file in the given directory. - */ - def getFilePath(dir: File, fileName: String): Path = { - assert(dir.isDirectory) - val path = new File(dir, fileName).getAbsolutePath - new Path(path) - } - /** * Whether the underlying operating system is Windows. */ @@ -1931,13 +1857,6 @@ private[spark] object Utils extends Logging { sys.env.contains("SPARK_TESTING") || sys.props.contains("spark.testing") } - /** - * Strip the directory from a path name - */ - def stripDirectory(path: String): String = { - new File(path).getName - } - /** * Terminates a process waiting for at most the specified duration. * @@ -2348,36 +2267,6 @@ private[spark] object Utils extends Logging { org.apache.log4j.Logger.getRootLogger().setLevel(l) } - /** - * config a log4j properties used for testsuite - */ - def configTestLog4j(level: String): Unit = { - val pro = new Properties() - pro.put("log4j.rootLogger", s"$level, console") - pro.put("log4j.appender.console", "org.apache.log4j.ConsoleAppender") - pro.put("log4j.appender.console.target", "System.err") - pro.put("log4j.appender.console.layout", "org.apache.log4j.PatternLayout") - pro.put("log4j.appender.console.layout.ConversionPattern", - "%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n") - PropertyConfigurator.configure(pro) - } - - def invoke( - clazz: Class[_], - obj: AnyRef, - methodName: String, - args: (Class[_], AnyRef)*): AnyRef = { - val (types, values) = args.unzip - val method = clazz.getDeclaredMethod(methodName, types: _*) - method.setAccessible(true) - method.invoke(obj, values.toSeq: _*) - } - - // Limit of bytes for total size of results (default is 1GB) - def getMaxResultSize(conf: SparkConf): Long = { - memoryStringToMb(conf.get("spark.driver.maxResultSize", "1g")).toLong << 20 - } - /** * Return the current system LD_LIBRARY_PATH name */ @@ -2610,16 +2499,6 @@ private[spark] object Utils extends Logging { SignalUtils.registerLogger(log) } - /** - * Unions two comma-separated lists of files and filters out empty strings. - */ - def unionFileLists(leftList: Option[String], rightList: Option[String]): Set[String] = { - var allFiles = Set.empty[String] - leftList.foreach { value => allFiles ++= value.split(",") } - rightList.foreach { value => allFiles ++= value.split(",") } - allFiles.filter { _.nonEmpty } - } - /** * Return the jar files pointed by the "spark.jars" property. Spark internally will distribute * these jars through file server. In the YARN mode, it will return an empty list, since YARN diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 24a55df84a240..0d5c5ea7903e9 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -95,7 +95,7 @@ public void tearDown() { @SuppressWarnings("unchecked") public void setUp() throws IOException { MockitoAnnotations.initMocks(this); - tempDir = Utils.createTempDir("test", "test"); + tempDir = Utils.createTempDir(null, "test"); mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); partitionSizesInMergedFile = null; spillFilesCreated.clear(); diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index 962945e5b6bb1..896cd2e80aaef 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -51,7 +51,7 @@ class DriverSuite extends SparkFunSuite with TimeLimits { */ object DriverWithoutCleanup { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf val sc = new SparkContext(args(0), "DriverWithoutCleanup", conf) sc.parallelize(1 to 100, 4).count() diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 803a38d77fb82..d265643a80b4e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -35,6 +35,7 @@ import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.TestUtils import org.apache.spark.TestUtils.JavaSourceFromString import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.SparkSubmit._ @@ -761,18 +762,6 @@ class SparkSubmitSuite } } - test("comma separated list of files are unioned correctly") { - val left = Option("/tmp/a.jar,/tmp/b.jar") - val right = Option("/tmp/c.jar,/tmp/a.jar") - val emptyString = Option("") - Utils.unionFileLists(left, right) should be (Set("/tmp/a.jar", "/tmp/b.jar", "/tmp/c.jar")) - Utils.unionFileLists(emptyString, emptyString) should be (Set.empty) - Utils.unionFileLists(Option("/tmp/a.jar"), emptyString) should be (Set("/tmp/a.jar")) - Utils.unionFileLists(emptyString, Option("/tmp/a.jar")) should be (Set("/tmp/a.jar")) - Utils.unionFileLists(None, Option("/tmp/a.jar")) should be (Set("/tmp/a.jar")) - Utils.unionFileLists(Option("/tmp/a.jar"), None) should be (Set("/tmp/a.jar")) - } - test("support glob path") { val tmpJarDir = Utils.createTempDir() val jar1 = TestUtils.createJarWithFiles(Map("test.resource" -> "1"), tmpJarDir) @@ -1042,6 +1031,7 @@ class SparkSubmitSuite assert(exception.getMessage() === "hello") } + } object SparkSubmitSuite extends SparkFunSuite with TimeLimits { @@ -1076,7 +1066,7 @@ object SparkSubmitSuite extends SparkFunSuite with TimeLimits { object JarCreationTest extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() val sc = new SparkContext(conf) val result = sc.makeRDD(1 to 100, 10).mapPartitions { x => @@ -1100,7 +1090,7 @@ object JarCreationTest extends Logging { object SimpleApplicationTest { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() val sc = new SparkContext(conf) val configs = Seq("spark.master", "spark.app.name") diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 73e7b3fe8c1de..e24d550a62665 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -47,7 +47,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp } test("Simple replay") { - val logFilePath = Utils.getFilePath(testDir, "events.txt") + val logFilePath = getFilePath(testDir, "events.txt") val fstream = fileSystem.create(logFilePath) val writer = new PrintWriter(fstream) val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, @@ -97,7 +97,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp // scalastyle:on println } - val logFilePath = Utils.getFilePath(testDir, "events.lz4.inprogress") + val logFilePath = getFilePath(testDir, "events.lz4.inprogress") val bytes = buffered.toByteArray Utils.tryWithResource(fileSystem.create(logFilePath)) { fstream => fstream.write(bytes, 0, buffered.size) @@ -129,7 +129,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp } test("Replay incompatible event log") { - val logFilePath = Utils.getFilePath(testDir, "incompatible.txt") + val logFilePath = getFilePath(testDir, "incompatible.txt") val fstream = fileSystem.create(logFilePath) val writer = new PrintWriter(fstream) val applicationStart = SparkListenerApplicationStart("Incompatible App", None, @@ -226,6 +226,12 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp } } + private def getFilePath(dir: File, fileName: String): Path = { + assert(dir.isDirectory) + val path = new File(dir, fileName).getAbsolutePath + new Path(path) + } + /** * A simple listener that buffers all the events it receives. * diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index eaea6b030c154..3b4273184f1e9 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -648,6 +648,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { test("fetch hcfs dir") { val tempDir = Utils.createTempDir() val sourceDir = new File(tempDir, "source-dir") + sourceDir.mkdir() val innerSourceDir = Utils.createTempDir(root = sourceDir.getPath) val sourceFile = File.createTempFile("someprefix", "somesuffix", innerSourceDir) val targetDir = new File(tempDir, "target-dir") diff --git a/scalastyle-config.xml b/scalastyle-config.xml index e2fa5754afaee..e65e3aafe5b5b 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -229,7 +229,7 @@ This file is divided into 3 sections: extractOpt - Use Utils.jsonOption(x).map(.extract[T]) instead of .extractOpt[T], as the latter + Use jsonOption(x).map(.extract[T]) instead of .extractOpt[T], as the latter is slower. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 523f7cf77e103..8a3bbd03a26dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -39,8 +39,8 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { protected override def beforeAll(): Unit = { super.beforeAll() - orcTableAsDir = Utils.createTempDir("orctests", "sparksql") - orcTableDir = Utils.createTempDir("orctests", "sparksql") + orcTableAsDir = Utils.createTempDir(namePrefix = "orctests") + orcTableDir = Utils.createTempDir(namePrefix = "orctests") sparkContext .makeRDD(1 to 10) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 122d28798136f..534d8bb629b8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -21,13 +21,13 @@ import java.io.File import scala.collection.mutable.HashMap +import org.apache.spark.TestUtils import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SQLAppStatusStore} import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.util.Utils trait SQLMetricsTestUtils extends SQLTestUtils { @@ -91,7 +91,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils { (0 until 100).map(i => (i, i + 1)).toDF("i", "j").repartition(2) .write.format(dataFormat).mode("overwrite").insertInto(tableName) } - assert(Utils.recursiveList(tableLocation).count(_.getName.startsWith("part-")) == 2) + assert(TestUtils.recursiveList(tableLocation).count(_.getName.startsWith("part-")) == 2) } } @@ -121,7 +121,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils { .mode("overwrite") .insertInto(tableName) } - assert(Utils.recursiveList(dir).count(_.getName.startsWith("part-")) == 40) + assert(TestUtils.recursiveList(dir).count(_.getName.startsWith("part-")) == 40) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index 0fe33e87318a5..27c983f270bf6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -22,6 +22,7 @@ import java.sql.Timestamp import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.spark.TestUtils import org.apache.spark.internal.Logging import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils @@ -86,15 +87,15 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { withTempDir { f => spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", 1).mode("overwrite").parquet(f.getAbsolutePath) - assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) + assert(TestUtils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", 2).mode("overwrite").parquet(f.getAbsolutePath) - assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2) + assert(TestUtils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2) spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", -1).mode("overwrite").parquet(f.getAbsolutePath) - assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 1) + assert(TestUtils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 1) } } @@ -106,7 +107,7 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { .option("maxRecordsPerFile", 1) .mode("overwrite") .parquet(f.getAbsolutePath) - assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) + assert(TestUtils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) } } @@ -133,14 +134,14 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { val df = Seq((1, ts)).toDF("i", "ts") withTempPath { f => df.write.partitionBy("ts").parquet(f.getAbsolutePath) - val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = TestUtils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) checkPartitionValues(files.head, "2016-12-01 00:00:00") } withTempPath { f => df.write.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .partitionBy("ts").parquet(f.getAbsolutePath) - val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = TestUtils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) // use timeZone option "GMT" to format partition value. checkPartitionValues(files.head, "2016-12-01 08:00:00") @@ -148,7 +149,7 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { withTempPath { f => withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") { df.write.partitionBy("ts").parquet(f.getAbsolutePath) - val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = TestUtils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) // if there isn't timeZone option, then use session local timezone. checkPartitionValues(files.head, "2016-12-01 08:00:00") diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 496f8c82a6c61..b32c547cefefe 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -804,7 +804,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl protected var metastorePath: File = _ protected def metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" - private val pidDir: File = Utils.createTempDir("thriftserver-pid") + private val pidDir: File = Utils.createTempDir(namePrefix = "thriftserver-pid") protected var logPath: File = _ protected var operationLogPath: File = _ private var logTailingProcess: Process = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 2d31781132edc..079fe45860544 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -330,7 +330,7 @@ class HiveSparkSubmitSuite object SetMetastoreURLTest extends Logging { def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val sparkConf = new SparkConf(loadDefaults = true) val builder = SparkSession.builder() @@ -368,7 +368,7 @@ object SetMetastoreURLTest extends Logging { object SetWarehouseLocationTest extends Logging { def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val sparkConf = new SparkConf(loadDefaults = true).set("spark.ui.enabled", "false") val providedExpectedWarehouseLocation = @@ -447,7 +447,7 @@ object SetWarehouseLocationTest extends Logging { // can load the jar defined with the function. object TemporaryHiveUDFTest extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() conf.set("spark.ui.enabled", "false") val sc = new SparkContext(conf) @@ -485,7 +485,7 @@ object TemporaryHiveUDFTest extends Logging { // can load the jar defined with the function. object PermanentHiveUDFTest1 extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() conf.set("spark.ui.enabled", "false") val sc = new SparkContext(conf) @@ -523,7 +523,7 @@ object PermanentHiveUDFTest1 extends Logging { // can load the jar defined with the function. object PermanentHiveUDFTest2 extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() conf.set("spark.ui.enabled", "false") val sc = new SparkContext(conf) @@ -558,7 +558,7 @@ object PermanentHiveUDFTest2 extends Logging { // We test if we can load user jars in both driver and executors when HiveContext is used. object SparkSubmitClassLoaderTest extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() val hiveWarehouseLocation = Utils.createTempDir() conf.set("spark.ui.enabled", "false") @@ -628,7 +628,7 @@ object SparkSubmitClassLoaderTest extends Logging { // We test if we can correctly set spark sql configurations when HiveContext is used. object SparkSQLConfTest extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") // We override the SparkConf to add spark.sql.hive.metastore.version and // spark.sql.hive.metastore.jars to the beginning of the conf entry array. // So, if metadataHive get initialized after we set spark.sql.hive.metastore.version but @@ -669,7 +669,7 @@ object SPARK_9757 extends QueryTest { protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val hiveWarehouseLocation = Utils.createTempDir() val sparkContext = new SparkContext( @@ -718,7 +718,7 @@ object SPARK_11009 extends QueryTest { protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val sparkContext = new SparkContext( new SparkConf() @@ -749,7 +749,7 @@ object SPARK_14244 extends QueryTest { protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val sparkContext = new SparkContext( new SparkConf() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index ee2fd45a7e851..19b621f11759d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -97,7 +97,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite => val batchDurationMillis = batchDuration.milliseconds // Setup the stream computation - val checkpointDir = Utils.createTempDir(this.getClass.getSimpleName()).toString + val checkpointDir = Utils.createTempDir(namePrefix = this.getClass.getSimpleName()).toString logDebug(s"Using checkpoint directory $checkpointDir") val ssc = createContextForCheckpointOperation(batchDuration) require(ssc.conf.get("spark.streaming.clock") === classOf[ManualClock].getName, diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala index 3b662ec1833aa..06c0c2aa97ee1 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala @@ -39,7 +39,7 @@ class MapWithStateSuite extends SparkFunSuite before { StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } - checkpointDir = Utils.createTempDir("checkpoint") + checkpointDir = Utils.createTempDir(namePrefix = "checkpoint") } after { From ac76eff6a88f6358a321b84cb5e60fb9d6403419 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Wed, 7 Mar 2018 13:51:44 -0800 Subject: [PATCH 0446/2461] [SPARK-23525][SQL] Support ALTER TABLE CHANGE COLUMN COMMENT for external hive table ## What changes were proposed in this pull request? The following query doesn't work as expected: ``` CREATE EXTERNAL TABLE ext_table(a STRING, b INT, c STRING) PARTITIONED BY (d STRING) LOCATION 'sql/core/spark-warehouse/ext_table'; ALTER TABLE ext_table CHANGE a a STRING COMMENT "new comment"; DESC ext_table; ``` The comment of column `a` is not updated, that's because `HiveExternalCatalog.doAlterTable` ignores table schema changes. To fix the issue, we should call `doAlterTableDataSchema` instead of `doAlterTable`. ## How was this patch tested? Updated `DDLSuite.testChangeColumn`. Author: Xingbo Jiang Closes #20696 from jiangxb1987/alterColumnComment. --- .../spark/sql/execution/command/ddl.scala | 12 ++++++------ .../sql-tests/inputs/change-column.sql | 1 + .../sql-tests/results/change-column.sql.out | 19 ++++++++++++++----- .../sql/execution/command/DDLSuite.scala | 1 + 4 files changed, 22 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 964cbca049b27..bf4d96fa18d0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -314,8 +314,8 @@ case class AlterTableChangeColumnCommand( val resolver = sparkSession.sessionState.conf.resolver DDLUtils.verifyAlterTableType(catalog, table, isView = false) - // Find the origin column from schema by column name. - val originColumn = findColumnByName(table.schema, columnName, resolver) + // Find the origin column from dataSchema by column name. + val originColumn = findColumnByName(table.dataSchema, columnName, resolver) // Throw an AnalysisException if the column name/dataType is changed. if (!columnEqual(originColumn, newColumn, resolver)) { throw new AnalysisException( @@ -324,7 +324,7 @@ case class AlterTableChangeColumnCommand( s"'${newColumn.name}' with type '${newColumn.dataType}'") } - val newSchema = table.schema.fields.map { field => + val newDataSchema = table.dataSchema.fields.map { field => if (field.name == originColumn.name) { // Create a new column from the origin column with the new comment. addComment(field, newColumn.getComment) @@ -332,8 +332,7 @@ case class AlterTableChangeColumnCommand( field } } - val newTable = table.copy(schema = StructType(newSchema)) - catalog.alterTable(newTable) + catalog.alterTableDataSchema(tableName, StructType(newDataSchema)) Seq.empty[Row] } @@ -345,7 +344,8 @@ case class AlterTableChangeColumnCommand( schema.fields.collectFirst { case field if resolver(field.name, name) => field }.getOrElse(throw new AnalysisException( - s"Invalid column reference '$name', table schema is '${schema}'")) + s"Can't find column `$name` given table data columns " + + s"${schema.fieldNames.mkString("[`", "`, `", "`]")}")) } // Add the comment to a column, if comment is empty, return the original column. diff --git a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql index ad0f885f63d3d..2909024e4c9f7 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql @@ -49,6 +49,7 @@ ALTER TABLE global_temp.global_temp_view CHANGE a a INT COMMENT 'this is column -- Change column in partition spec (not supported yet) CREATE TABLE partition_table(a INT, b STRING, c INT, d STRING) USING parquet PARTITIONED BY (c, d); ALTER TABLE partition_table PARTITION (c = 1) CHANGE COLUMN a new_a INT; +ALTER TABLE partition_table CHANGE COLUMN c c INT COMMENT 'this is column C'; -- DROP TEST TABLE DROP TABLE test_change; diff --git a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out index ba8bc936f0c79..ff1ecbcc44c23 100644 --- a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 32 +-- Number of queries: 33 -- !query 0 @@ -154,7 +154,7 @@ ALTER TABLE test_change CHANGE invalid_col invalid_col INT struct<> -- !query 15 output org.apache.spark.sql.AnalysisException -Invalid column reference 'invalid_col', table schema is 'StructType(StructField(a,IntegerType,true), StructField(b,StringType,true), StructField(c,IntegerType,true))'; +Can't find column `invalid_col` given table data columns [`a`, `b`, `c`]; -- !query 16 @@ -291,16 +291,25 @@ ALTER TABLE partition_table PARTITION (c = 1) CHANGE COLUMN a new_a INT -- !query 30 -DROP TABLE test_change +ALTER TABLE partition_table CHANGE COLUMN c c INT COMMENT 'this is column C' -- !query 30 schema struct<> -- !query 30 output - +org.apache.spark.sql.AnalysisException +Can't find column `c` given table data columns [`a`, `b`]; -- !query 31 -DROP TABLE partition_table +DROP TABLE test_change -- !query 31 schema struct<> -- !query 31 output + + +-- !query 32 +DROP TABLE partition_table +-- !query 32 schema +struct<> +-- !query 32 output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index db9023b7ec8b6..4041176262426 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1597,6 +1597,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { // Ensure that change column will preserve other metadata fields. sql("ALTER TABLE dbx.tab1 CHANGE COLUMN col1 col1 INT COMMENT 'this is col1'") assert(getMetadata("col1").getString("key") == "value") + assert(getMetadata("col1").getString("comment") == "this is col1") } test("drop build-in function") { From 77c91cc746f93e609c412f3a220495d9e931f696 Mon Sep 17 00:00:00 2001 From: jx158167 Date: Wed, 7 Mar 2018 20:08:32 -0800 Subject: [PATCH 0447/2461] [SPARK-23524] Big local shuffle blocks should not be checked for corruption. ## What changes were proposed in this pull request? In current code, all local blocks will be checked for corruption no matter it's big or not. The reasons are as below: Size in FetchResult for local block is set to be 0 (https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala#L327) SPARK-4105 meant to only check the small blocks(size Closes #20685 from jinxing64/SPARK-23524. --- .../storage/ShuffleBlockFetcherIterator.scala | 14 +++--- .../ShuffleBlockFetcherIteratorSuite.scala | 45 +++++++++++++++++++ 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 98b5a735a4529..dd9df74689a13 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -90,7 +90,7 @@ final class ShuffleBlockFetcherIterator( private[this] val startTime = System.currentTimeMillis /** Local blocks to fetch, excluding zero-sized blocks. */ - private[this] val localBlocks = new ArrayBuffer[BlockId]() + private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[BlockId]() /** Remote blocks to fetch, excluding zero-sized blocks. */ private[this] val remoteBlocks = new HashSet[BlockId]() @@ -316,6 +316,7 @@ final class ShuffleBlockFetcherIterator( * track in-memory are the ManagedBuffer references themselves. */ private[this] def fetchLocalBlocks() { + logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}") val iter = localBlocks.iterator while (iter.hasNext) { val blockId = iter.next() @@ -324,7 +325,8 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incLocalBlocksFetched(1) shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() - results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false)) + results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, + buf.size(), buf, false)) } catch { case e: Exception => // If we see an exception, stop immediately. @@ -397,7 +399,9 @@ final class ShuffleBlockFetcherIterator( } shuffleMetrics.incRemoteBlocksFetched(1) } - bytesInFlight -= size + if (!localBlocks.contains(blockId)) { + bytesInFlight -= size + } if (isNetworkReqDone) { reqsInFlight -= 1 logDebug("Number of requests in flight " + reqsInFlight) @@ -583,8 +587,8 @@ object ShuffleBlockFetcherIterator { * Result of a fetch from a remote block successfully. * @param blockId block id * @param address BlockManager that the block was fetched from. - * @param size estimated size of the block, used to calculate bytesInFlight. - * Note that this is NOT the exact bytes. + * @param size estimated size of the block. Note that this is NOT the exact bytes. + * Size of remote block is used to calculate bytesInFlight. * @param buf `ManagedBuffer` for the content. * @param isNetworkReqDone Is this the last network request for this host in this fetch request. */ diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 5bfe9905ff17b..692ae3bf597e0 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -352,6 +352,51 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } } + test("big blocks are not checked for corruption") { + val corruptStream = mock(classOf[InputStream]) + when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) + val corruptBuffer = mock(classOf[ManagedBuffer]) + when(corruptBuffer.createInputStream()).thenReturn(corruptStream) + doReturn(10000L).when(corruptBuffer).size() + + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + doReturn(corruptBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0)) + val localBlockLengths = Seq[Tuple2[BlockId, Long]]( + ShuffleBlockId(0, 0, 0) -> corruptBuffer.size() + ) + + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val remoteBlockLengths = Seq[Tuple2[BlockId, Long]]( + ShuffleBlockId(0, 1, 0) -> corruptBuffer.size() + ) + + val transfer = createMockTransfer( + Map(ShuffleBlockId(0, 0, 0) -> corruptBuffer, ShuffleBlockId(0, 1, 0) -> corruptBuffer)) + + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (localBmId, localBlockLengths), + (remoteBmId, remoteBlockLengths) + ) + + val taskContext = TaskContext.empty() + val iterator = new ShuffleBlockFetcherIterator( + taskContext, + transfer, + blockManager, + blocksByAddress, + (_, in) => new LimitedInputStream(in, 10000), + 2048, + Int.MaxValue, + Int.MaxValue, + Int.MaxValue, + true) + // Blocks should be returned without exceptions. + assert(Set(iterator.next()._1, iterator.next()._1) === + Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0))) + } + test("retry corrupt blocks (disabled)") { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) From fe22f32041572596a22e5f7441fa0bfbd9608648 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 8 Mar 2018 10:50:09 +0100 Subject: [PATCH 0448/2461] [SPARK-23620] Splitting thread dump lines by using the br tag ## What changes were proposed in this pull request? I propose to replace `'\n'` by the `
    ` tag in generated html of thread dump page. The `
    ` tag will split thread lines in more reliable way. For now it could look like on the screen shot if the html is proxied and `'\n'` is replaced by another whitespace. The changes allow to more easily read and copy stack traces. ## How was this patch tested? I tested it manually by checking the thread dump page and its source. Author: Maxim Gekk Closes #20762 from MaxGekk/br-thread-dump. --- .../org/apache/spark/status/api/v1/api.scala | 24 ++++++++++++++++++- .../ui/exec/ExecutorThreadDumpPage.scala | 2 +- .../scala/org/apache/spark/util/Utils.scala | 6 ++--- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 369e98b683b1a..971d7e90fa7b8 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -19,6 +19,8 @@ package org.apache.spark.status.api.v1 import java.lang.{Long => JLong} import java.util.Date +import scala.xml.{NodeSeq, Text} + import com.fasterxml.jackson.annotation.JsonIgnoreProperties import com.fasterxml.jackson.databind.annotation.JsonDeserialize @@ -317,11 +319,31 @@ class RuntimeInfo private[spark]( val javaHome: String, val scalaVersion: String) +case class StackTrace(elems: Seq[String]) { + override def toString: String = elems.mkString + + def html: NodeSeq = { + val withNewLine = elems.foldLeft(NodeSeq.Empty) { (acc, elem) => + if (acc.isEmpty) { + acc :+ Text(elem) + } else { + acc :+
    :+ Text(elem) + } + } + + withNewLine + } + + def mkString(start: String, sep: String, end: String): String = { + elems.mkString(start, sep, end) + } +} + case class ThreadStackTrace( val threadId: Long, val threadName: String, val threadState: Thread.State, - val stackTrace: String, + val stackTrace: StackTrace, val blockedByThreadId: Option[Long], val blockedByLock: String, val holdingLocks: Seq[String]) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index 7a9aaf29a8b05..9bb026c60565e 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -60,7 +60,7 @@ private[ui] class ExecutorThreadDumpPage( {thread.threadName} {thread.threadState} {blockedBy}{heldLocks} - {thread.stackTrace} + {thread.stackTrace.html} } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 29d26ea2c85df..5caedeb526469 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -61,7 +61,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} -import org.apache.spark.status.api.v1.ThreadStackTrace +import org.apache.spark.status.api.v1.{StackTrace, ThreadStackTrace} /** CallSite represents a place in user code. It can have a short and a long form. */ private[spark] case class CallSite(shortForm: String, longForm: String) @@ -2118,14 +2118,14 @@ private[spark] object Utils extends Logging { private def threadInfoToThreadStackTrace(threadInfo: ThreadInfo): ThreadStackTrace = { val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap - val stackTrace = threadInfo.getStackTrace.map { frame => + val stackTrace = StackTrace(threadInfo.getStackTrace.map { frame => monitors.get(frame) match { case Some(monitor) => monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}" case None => frame.toString } - }.mkString("\n") + }) // use a set to dedup re-entrant locks that are held at multiple places val heldLocks = From 9bb239c8b174d31981dfff63baa38bb8cecfe913 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 8 Mar 2018 20:19:55 +0900 Subject: [PATCH 0449/2461] [SPARK-23159][PYTHON] Update cloudpickle to v0.4.3 ## What changes were proposed in this pull request? The version of cloudpickle in PySpark was close to version 0.4.0 with some additional backported fixes and some minor additions for Spark related things. This update removes Spark related changes and matches cloudpickle [v0.4.3](https://github.com/cloudpipe/cloudpickle/releases/tag/v0.4.3): Changes by updating to 0.4.3 include: * Fix pickling of named tuples https://github.com/cloudpipe/cloudpickle/pull/113 * Built in type constructors for PyPy compatibility [here](https://github.com/cloudpipe/cloudpickle/commit/d84980ccaafc7982a50d4e04064011f401f17d1b) * Fix memoryview support https://github.com/cloudpipe/cloudpickle/pull/122 * Improved compatibility with other cloudpickle versions https://github.com/cloudpipe/cloudpickle/pull/128 * Several cleanups https://github.com/cloudpipe/cloudpickle/pull/121 and [here](https://github.com/cloudpipe/cloudpickle/commit/c91aaf110441991307f5097f950764079d0f9652) * [MRG] Regression on pickling classes from the __main__ module https://github.com/cloudpipe/cloudpickle/pull/149 * BUG: Handle instance methods of builtin types https://github.com/cloudpipe/cloudpickle/pull/154 * Fix #129 : do not silence RuntimeError in dump() https://github.com/cloudpipe/cloudpickle/pull/153 ## How was this patch tested? Existing pyspark.tests using python 2.7.14, 3.5.2, 3.6.3 Author: Bryan Cutler Closes #20373 from BryanCutler/pyspark-update-cloudpickle-42-SPARK-23159. --- python/pyspark/accumulators.py | 1 - python/pyspark/cloudpickle.py | 320 ++++++++++++++------------------- python/pyspark/serializers.py | 14 +- 3 files changed, 151 insertions(+), 184 deletions(-) diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 6ef8cf53cc747..7def676b89a24 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -94,7 +94,6 @@ else: import socketserver as SocketServer import threading -from pyspark.cloudpickle import CloudPickler from pyspark.serializers import read_int, PickleSerializer diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 40e91a2d0655d..ea845b98b3db2 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -57,7 +57,6 @@ import types import weakref -from pyspark.util import _exception_message if sys.version < '3': from pickle import Pickler @@ -181,6 +180,32 @@ def _builtin_type(name): return getattr(types, name) +def _make__new__factory(type_): + def _factory(): + return type_.__new__ + return _factory + + +# NOTE: These need to be module globals so that they're pickleable as globals. +_get_dict_new = _make__new__factory(dict) +_get_frozenset_new = _make__new__factory(frozenset) +_get_list_new = _make__new__factory(list) +_get_set_new = _make__new__factory(set) +_get_tuple_new = _make__new__factory(tuple) +_get_object_new = _make__new__factory(object) + +# Pre-defined set of builtin_function_or_method instances that can be +# serialized. +_BUILTIN_TYPE_CONSTRUCTORS = { + dict.__new__: _get_dict_new, + frozenset.__new__: _get_frozenset_new, + set.__new__: _get_set_new, + list.__new__: _get_list_new, + tuple.__new__: _get_tuple_new, + object.__new__: _get_object_new, +} + + if sys.version_info < (3, 4): def _walk_global_ops(code): """ @@ -237,28 +262,16 @@ def dump(self, obj): if 'recursion' in e.args[0]: msg = """Could not pickle object as excessively deep recursion required.""" raise pickle.PicklingError(msg) - except pickle.PickleError: - raise - except Exception as e: - emsg = _exception_message(e) - if "'i' format requires" in emsg: - msg = "Object too large to serialize: %s" % emsg else: - msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg) - print_exec(sys.stderr) - raise pickle.PicklingError(msg) - + raise def save_memoryview(self, obj): - """Fallback to save_string""" - Pickler.save_string(self, str(obj)) + self.save(obj.tobytes()) + dispatch[memoryview] = save_memoryview - def save_buffer(self, obj): - """Fallback to save_string""" - Pickler.save_string(self,str(obj)) - if PY3: - dispatch[memoryview] = save_memoryview - else: + if not PY3: + def save_buffer(self, obj): + self.save(str(obj)) dispatch[buffer] = save_buffer def save_unsupported(self, obj): @@ -318,6 +331,24 @@ def save_function(self, obj, name=None): Determines what kind of function obj is (e.g. lambda, defined at interactive prompt, etc) and handles the pickling appropriately. """ + try: + should_special_case = obj in _BUILTIN_TYPE_CONSTRUCTORS + except TypeError: + # Methods of builtin types aren't hashable in python 2. + should_special_case = False + + if should_special_case: + # We keep a special-cased cache of built-in type constructors at + # global scope, because these functions are structured very + # differently in different python versions and implementations (for + # example, they're instances of types.BuiltinFunctionType in + # CPython, but they're ordinary types.FunctionType instances in + # PyPy). + # + # If the function we've received is in that cache, we just + # serialize it as a lookup into the cache. + return self.save_reduce(_BUILTIN_TYPE_CONSTRUCTORS[obj], (), obj=obj) + write = self.write if name is None: @@ -344,7 +375,7 @@ def save_function(self, obj, name=None): return self.save_global(obj, name) # a builtin_function_or_method which comes in as an attribute of some - # object (e.g., object.__new__, itertools.chain.from_iterable) will end + # object (e.g., itertools.chain.from_iterable) will end # up with modname "__main__" and so end up here. But these functions # have no __code__ attribute in CPython, so the handling for # user-defined functions below will fail. @@ -352,16 +383,13 @@ def save_function(self, obj, name=None): # for different python versions. if not hasattr(obj, '__code__'): if PY3: - if sys.version_info < (3, 4): - raise pickle.PicklingError("Can't pickle %r" % obj) - else: - rv = obj.__reduce_ex__(self.proto) + rv = obj.__reduce_ex__(self.proto) else: if hasattr(obj, '__self__'): rv = (getattr, (obj.__self__, name)) else: raise pickle.PicklingError("Can't pickle %r" % obj) - return Pickler.save_reduce(self, obj=obj, *rv) + return self.save_reduce(obj=obj, *rv) # if func is lambda, def'ed at prompt, is in main, or is nested, then # we'll pickle the actual function object rather than simply saving a @@ -420,20 +448,18 @@ def save_dynamic_class(self, obj): from global modules. """ clsdict = dict(obj.__dict__) # copy dict proxy to a dict - if not isinstance(clsdict.get('__dict__', None), property): - # don't extract dict that are properties - clsdict.pop('__dict__', None) - clsdict.pop('__weakref__', None) - - # hack as __new__ is stored differently in the __dict__ - new_override = clsdict.get('__new__', None) - if new_override: - clsdict['__new__'] = obj.__new__ - - # namedtuple is a special case for Spark where we use the _load_namedtuple function - if getattr(obj, '_is_namedtuple_', False): - self.save_reduce(_load_namedtuple, (obj.__name__, obj._fields)) - return + clsdict.pop('__weakref__', None) + + # On PyPy, __doc__ is a readonly attribute, so we need to include it in + # the initial skeleton class. This is safe because we know that the + # doc can't participate in a cycle with the original class. + type_kwargs = {'__doc__': clsdict.pop('__doc__', None)} + + # If type overrides __dict__ as a property, include it in the type kwargs. + # In Python 2, we can't set this attribute after construction. + __dict__ = clsdict.pop('__dict__', None) + if isinstance(__dict__, property): + type_kwargs['__dict__'] = __dict__ save = self.save write = self.write @@ -453,23 +479,12 @@ def save_dynamic_class(self, obj): # Push the rehydration function. save(_rehydrate_skeleton_class) - # Mark the start of the args for the rehydration function. + # Mark the start of the args tuple for the rehydration function. write(pickle.MARK) - # On PyPy, __doc__ is a readonly attribute, so we need to include it in - # the initial skeleton class. This is safe because we know that the - # doc can't participate in a cycle with the original class. - doc_dict = {'__doc__': clsdict.pop('__doc__', None)} - - # Create and memoize an empty class with obj's name and bases. - save(type(obj)) - save(( - obj.__name__, - obj.__bases__, - doc_dict, - )) - write(pickle.REDUCE) - self.memoize(obj) + # Create and memoize an skeleton class with obj's name and bases. + tp = type(obj) + self.save_reduce(tp, (obj.__name__, obj.__bases__, type_kwargs), obj=obj) # Now save the rest of obj's __dict__. Any references to obj # encountered while saving will point to the skeleton class. @@ -522,17 +537,22 @@ def save_function_tuple(self, func): self.memoize(func) # save the rest of the func data needed by _fill_function - save(f_globals) - save(defaults) - save(dct) - save(func.__module__) - save(closure_values) + state = { + 'globals': f_globals, + 'defaults': defaults, + 'dict': dct, + 'module': func.__module__, + 'closure_values': closure_values, + } + if hasattr(func, '__qualname__'): + state['qualname'] = func.__qualname__ + save(state) write(pickle.TUPLE) write(pickle.REDUCE) # applies _fill_function on the tuple _extract_code_globals_cache = ( weakref.WeakKeyDictionary() - if sys.version_info >= (2, 7) and not hasattr(sys, "pypy_version_info") + if not hasattr(sys, "pypy_version_info") else {}) @classmethod @@ -608,37 +628,22 @@ def save_global(self, obj, name=None, pack=struct.pack): The name of this method is somewhat misleading: all types get dispatched here. """ - if obj.__module__ == "__builtin__" or obj.__module__ == "builtins": - if obj in _BUILTIN_TYPE_NAMES: - return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj) - - if name is None: - name = obj.__name__ - - modname = getattr(obj, "__module__", None) - if modname is None: - try: - # whichmodule() could fail, see - # https://bitbucket.org/gutworth/six/issues/63/importing-six-breaks-pickling - modname = pickle.whichmodule(obj, name) - except Exception: - modname = '__main__' + if obj.__module__ == "__main__": + return self.save_dynamic_class(obj) - if modname == '__main__': - themodule = None - else: - __import__(modname) - themodule = sys.modules[modname] - self.modules.add(themodule) + try: + return Pickler.save_global(self, obj, name=name) + except Exception: + if obj.__module__ == "__builtin__" or obj.__module__ == "builtins": + if obj in _BUILTIN_TYPE_NAMES: + return self.save_reduce( + _builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj) - if hasattr(themodule, name) and getattr(themodule, name) is obj: - return Pickler.save_global(self, obj, name) + typ = type(obj) + if typ is not obj and isinstance(obj, (type, types.ClassType)): + return self.save_dynamic_class(obj) - typ = type(obj) - if typ is not obj and isinstance(obj, (type, types.ClassType)): - self.save_dynamic_class(obj) - else: - raise pickle.PicklingError("Can't pickle %r" % obj) + raise dispatch[type] = save_global dispatch[types.ClassType] = save_global @@ -709,12 +714,7 @@ def save_property(self, obj): dispatch[property] = save_property def save_classmethod(self, obj): - try: - orig_func = obj.__func__ - except AttributeError: # Python 2.6 - orig_func = obj.__get__(None, object) - if isinstance(obj, classmethod): - orig_func = orig_func.__func__ # Unbind + orig_func = obj.__func__ self.save_reduce(type(obj), (orig_func,), obj=obj) dispatch[classmethod] = save_classmethod dispatch[staticmethod] = save_classmethod @@ -754,64 +754,6 @@ def __getattribute__(self, item): if type(operator.attrgetter) is type: dispatch[operator.attrgetter] = save_attrgetter - def save_reduce(self, func, args, state=None, - listitems=None, dictitems=None, obj=None): - # Assert that args is a tuple or None - if not isinstance(args, tuple): - raise pickle.PicklingError("args from reduce() should be a tuple") - - # Assert that func is callable - if not hasattr(func, '__call__'): - raise pickle.PicklingError("func from reduce should be callable") - - save = self.save - write = self.write - - # Protocol 2 special case: if func's name is __newobj__, use NEWOBJ - if self.proto >= 2 and getattr(func, "__name__", "") == "__newobj__": - cls = args[0] - if not hasattr(cls, "__new__"): - raise pickle.PicklingError( - "args[0] from __newobj__ args has no __new__") - if obj is not None and cls is not obj.__class__: - raise pickle.PicklingError( - "args[0] from __newobj__ args has the wrong class") - args = args[1:] - save(cls) - - save(args) - write(pickle.NEWOBJ) - else: - save(func) - save(args) - write(pickle.REDUCE) - - if obj is not None: - self.memoize(obj) - - # More new special cases (that work with older protocols as - # well): when __reduce__ returns a tuple with 4 or 5 items, - # the 4th and 5th item should be iterators that provide list - # items and dict items (as (key, value) tuples), or None. - - if listitems is not None: - self._batch_appends(listitems) - - if dictitems is not None: - self._batch_setitems(dictitems) - - if state is not None: - save(state) - write(pickle.BUILD) - - def save_partial(self, obj): - """Partial objects do not serialize correctly in python2.x -- this fixes the bugs""" - self.save_reduce(_genpartial, (obj.func, obj.args, obj.keywords)) - - if sys.version_info < (2,7): # 2.7 supports partial pickling - dispatch[partial] = save_partial - - def save_file(self, obj): """Save a file""" try: @@ -867,23 +809,21 @@ def save_not_implemented(self, obj): dispatch[type(Ellipsis)] = save_ellipsis dispatch[type(NotImplemented)] = save_not_implemented - # WeakSet was added in 2.7. - if hasattr(weakref, 'WeakSet'): - def save_weakset(self, obj): - self.save_reduce(weakref.WeakSet, (list(obj),)) - - dispatch[weakref.WeakSet] = save_weakset + def save_weakset(self, obj): + self.save_reduce(weakref.WeakSet, (list(obj),)) - """Special functions for Add-on libraries""" - def inject_addons(self): - """Plug in system. Register additional pickling functions if modules already loaded""" - pass + dispatch[weakref.WeakSet] = save_weakset def save_logger(self, obj): self.save_reduce(logging.getLogger, (obj.name,), obj=obj) dispatch[logging.Logger] = save_logger + """Special functions for Add-on libraries""" + def inject_addons(self): + """Plug in system. Register additional pickling functions if modules already loaded""" + pass + # Tornado support @@ -913,11 +853,12 @@ def dump(obj, file, protocol=2): def dumps(obj, protocol=2): file = StringIO() - - cp = CloudPickler(file,protocol) - cp.dump(obj) - - return file.getvalue() + try: + cp = CloudPickler(file,protocol) + cp.dump(obj) + return file.getvalue() + finally: + file.close() # including pickles unloading functions in this namespace load = pickle.load @@ -1019,18 +960,40 @@ def __reduce__(cls): return cls.__name__ -def _fill_function(func, globals, defaults, dict, module, closure_values): - """ Fills in the rest of function data into the skeleton function object - that were created via _make_skel_func(). +def _fill_function(*args): + """Fills in the rest of function data into the skeleton function object + + The skeleton itself is create by _make_skel_func(). """ - func.__globals__.update(globals) - func.__defaults__ = defaults - func.__dict__ = dict - func.__module__ = module + if len(args) == 2: + func = args[0] + state = args[1] + elif len(args) == 5: + # Backwards compat for cloudpickle v0.4.0, after which the `module` + # argument was introduced + func = args[0] + keys = ['globals', 'defaults', 'dict', 'closure_values'] + state = dict(zip(keys, args[1:])) + elif len(args) == 6: + # Backwards compat for cloudpickle v0.4.1, after which the function + # state was passed as a dict to the _fill_function it-self. + func = args[0] + keys = ['globals', 'defaults', 'dict', 'module', 'closure_values'] + state = dict(zip(keys, args[1:])) + else: + raise ValueError('Unexpected _fill_value arguments: %r' % (args,)) + + func.__globals__.update(state['globals']) + func.__defaults__ = state['defaults'] + func.__dict__ = state['dict'] + if 'module' in state: + func.__module__ = state['module'] + if 'qualname' in state: + func.__qualname__ = state['qualname'] cells = func.__closure__ if cells is not None: - for cell, value in zip(cells, closure_values): + for cell, value in zip(cells, state['closure_values']): if value is not _empty_cell_value: cell_set(cell, value) @@ -1087,13 +1050,6 @@ def _find_module(mod_name): file.close() return path, description -def _load_namedtuple(name, fields): - """ - Loads a class generated by namedtuple - """ - from collections import namedtuple - return namedtuple(name, fields) - """Constructors for 3rd party libraries Note: These can never be renamed due to client compatibility issues""" diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 91a7f093cec19..917e258d8a602 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -68,6 +68,7 @@ xrange = range from pyspark import cloudpickle +from pyspark.util import _exception_message __all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"] @@ -565,7 +566,18 @@ def loads(self, obj, encoding=None): class CloudPickleSerializer(PickleSerializer): def dumps(self, obj): - return cloudpickle.dumps(obj, 2) + try: + return cloudpickle.dumps(obj, 2) + except pickle.PickleError: + raise + except Exception as e: + emsg = _exception_message(e) + if "'i' format requires" in emsg: + msg = "Object too large to serialize: %s" % emsg + else: + msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg) + cloudpickle.print_exec(sys.stderr) + raise pickle.PicklingError(msg) class MarshalSerializer(FramedSerializer): From d6632d185e147fcbe6724545488ad80dce20277e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 8 Mar 2018 20:22:07 +0900 Subject: [PATCH 0450/2461] [SPARK-23380][PYTHON] Adds a conf for Arrow fallback in toPandas/createDataFrame with Pandas DataFrame ## What changes were proposed in this pull request? This PR adds a configuration to control the fallback of Arrow optimization for `toPandas` and `createDataFrame` with Pandas DataFrame. ## How was this patch tested? Manually tested and unit tests added. You can test this by: **`createDataFrame`** ```python spark.conf.set("spark.sql.execution.arrow.enabled", False) pdf = spark.createDataFrame([[{'a': 1}]]).toPandas() spark.conf.set("spark.sql.execution.arrow.enabled", True) spark.conf.set("spark.sql.execution.arrow.fallback.enabled", True) spark.createDataFrame(pdf, "a: map") ``` ```python spark.conf.set("spark.sql.execution.arrow.enabled", False) pdf = spark.createDataFrame([[{'a': 1}]]).toPandas() spark.conf.set("spark.sql.execution.arrow.enabled", True) spark.conf.set("spark.sql.execution.arrow.fallback.enabled", False) spark.createDataFrame(pdf, "a: map") ``` **`toPandas`** ```python spark.conf.set("spark.sql.execution.arrow.enabled", True) spark.conf.set("spark.sql.execution.arrow.fallback.enabled", True) spark.createDataFrame([[{'a': 1}]]).toPandas() ``` ```python spark.conf.set("spark.sql.execution.arrow.enabled", True) spark.conf.set("spark.sql.execution.arrow.fallback.enabled", False) spark.createDataFrame([[{'a': 1}]]).toPandas() ``` Author: hyukjinkwon Closes #20678 from HyukjinKwon/SPARK-23380-conf. --- docs/sql-programming-guide.md | 5 + python/pyspark/sql/dataframe.py | 120 ++++++++++++------ python/pyspark/sql/session.py | 22 +++- python/pyspark/sql/tests.py | 84 ++++++++++-- .../apache/spark/sql/internal/SQLConf.scala | 13 +- 5 files changed, 186 insertions(+), 58 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 01e2076555ee6..451b814ab6c53 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1689,6 +1689,10 @@ using the call `toPandas()` and when creating a Spark DataFrame from a Pandas Da `createDataFrame(pandas_df)`. To use Arrow when executing these calls, users need to first set the Spark configuration 'spark.sql.execution.arrow.enabled' to 'true'. This is disabled by default. +In addition, optimizations enabled by 'spark.sql.execution.arrow.enabled' could fallback automatically +to non-Arrow optimization implementation if an error occurs before the actual computation within Spark. +This can be controlled by 'spark.sql.execution.arrow.fallback.enabled'. +
    {% include_example dataframe_with_arrow python/sql/arrow.py %} @@ -1800,6 +1804,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.3 to 2.4 - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. + - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unabled to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 9d8e85cde914f..8f90a367e8bf8 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1992,55 +1992,91 @@ def toPandas(self): timezone = None if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": + use_arrow = True try: - from pyspark.sql.types import _check_dataframe_convert_date, \ - _check_dataframe_localize_timestamps, to_arrow_schema + from pyspark.sql.types import to_arrow_schema from pyspark.sql.utils import require_minimum_pyarrow_version + require_minimum_pyarrow_version() - import pyarrow to_arrow_schema(self.schema) - tables = self._collectAsArrow() - if tables: - table = pyarrow.concat_tables(tables) - pdf = table.to_pandas() - pdf = _check_dataframe_convert_date(pdf, self.schema) - return _check_dataframe_localize_timestamps(pdf, timezone) - else: - return pd.DataFrame.from_records([], columns=self.columns) except Exception as e: - msg = ( - "Note: toPandas attempted Arrow optimization because " - "'spark.sql.execution.arrow.enabled' is set to true. Please set it to false " - "to disable this.") - raise RuntimeError("%s\n%s" % (_exception_message(e), msg)) - else: - pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) - dtype = {} + if self.sql_ctx.getConf("spark.sql.execution.arrow.fallback.enabled", "true") \ + .lower() == "true": + msg = ( + "toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "Attempts non-optimization as " + "'spark.sql.execution.arrow.fallback.enabled' is set to " + "true." % _exception_message(e)) + warnings.warn(msg) + use_arrow = False + else: + msg = ( + "toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "For fallback to non-optimization automatically, please set true to " + "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e)) + raise RuntimeError(msg) + + # Try to use Arrow optimization when the schema is supported and the required version + # of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled. + if use_arrow: + try: + from pyspark.sql.types import _check_dataframe_convert_date, \ + _check_dataframe_localize_timestamps + import pyarrow + + tables = self._collectAsArrow() + if tables: + table = pyarrow.concat_tables(tables) + pdf = table.to_pandas() + pdf = _check_dataframe_convert_date(pdf, self.schema) + return _check_dataframe_localize_timestamps(pdf, timezone) + else: + return pd.DataFrame.from_records([], columns=self.columns) + except Exception as e: + # We might have to allow fallback here as well but multiple Spark jobs can + # be executed. So, simply fail in this case for now. + msg = ( + "toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed unexpectedly:\n %s\n" + "Note that 'spark.sql.execution.arrow.fallback.enabled' does " + "not have an effect in such failure in the middle of " + "computation." % _exception_message(e)) + raise RuntimeError(msg) + + # Below is toPandas without Arrow optimization. + pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) + + dtype = {} + for field in self.schema: + pandas_type = _to_corrected_pandas_type(field.dataType) + # SPARK-21766: if an integer field is nullable and has null values, it can be + # inferred by pandas as float column. Once we convert the column with NaN back + # to integer type e.g., np.int16, we will hit exception. So we use the inferred + # float type, not the corrected type from the schema in this case. + if pandas_type is not None and \ + not(isinstance(field.dataType, IntegralType) and field.nullable and + pdf[field.name].isnull().any()): + dtype[field.name] = pandas_type + + for f, t in dtype.items(): + pdf[f] = pdf[f].astype(t, copy=False) + + if timezone is None: + return pdf + else: + from pyspark.sql.types import _check_series_convert_timestamps_local_tz for field in self.schema: - pandas_type = _to_corrected_pandas_type(field.dataType) - # SPARK-21766: if an integer field is nullable and has null values, it can be - # inferred by pandas as float column. Once we convert the column with NaN back - # to integer type e.g., np.int16, we will hit exception. So we use the inferred - # float type, not the corrected type from the schema in this case. - if pandas_type is not None and \ - not(isinstance(field.dataType, IntegralType) and field.nullable and - pdf[field.name].isnull().any()): - dtype[field.name] = pandas_type - - for f, t in dtype.items(): - pdf[f] = pdf[f].astype(t, copy=False) - - if timezone is None: - return pdf - else: - from pyspark.sql.types import _check_series_convert_timestamps_local_tz - for field in self.schema: - # TODO: handle nested timestamps, such as ArrayType(TimestampType())? - if isinstance(field.dataType, TimestampType): - pdf[field.name] = \ - _check_series_convert_timestamps_local_tz(pdf[field.name], timezone) - return pdf + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if isinstance(field.dataType, TimestampType): + pdf[field.name] = \ + _check_series_convert_timestamps_local_tz(pdf[field.name], timezone) + return pdf def _collectAsArrow(self): """ diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index b3af9b82953f3..215bb3e5c5173 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -666,8 +666,26 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr try: return self._create_from_pandas_with_arrow(data, schema, timezone) except Exception as e: - warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e)) - # Fallback to create DataFrame without arrow if raise some exception + from pyspark.util import _exception_message + + if self.conf.get("spark.sql.execution.arrow.fallback.enabled", "true") \ + .lower() == "true": + msg = ( + "createDataFrame attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "Attempts non-optimization as " + "'spark.sql.execution.arrow.fallback.enabled' is set to " + "true." % _exception_message(e)) + warnings.warn(msg) + else: + msg = ( + "createDataFrame attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "For fallback to non-optimization automatically, please set true to " + "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e)) + raise RuntimeError(msg) data = self._convert_from_pandas(data, schema, timezone) if isinstance(schema, StructType): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index fa3b7203e10ac..a9fe0b425ad3e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -32,7 +32,9 @@ import datetime import array import ctypes +import warnings import py4j +from contextlib import contextmanager try: import xmlrunner @@ -48,12 +50,13 @@ else: import unittest +from pyspark.util import _exception_message + _pandas_requirement_message = None try: from pyspark.sql.utils import require_minimum_pandas_version require_minimum_pandas_version() except ImportError as e: - from pyspark.util import _exception_message # If Pandas version requirement is not satisfied, skip related tests. _pandas_requirement_message = _exception_message(e) @@ -62,7 +65,6 @@ from pyspark.sql.utils import require_minimum_pyarrow_version require_minimum_pyarrow_version() except ImportError as e: - from pyspark.util import _exception_message # If Arrow version requirement is not satisfied, skip related tests. _pyarrow_requirement_message = _exception_message(e) @@ -195,6 +197,28 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() cls.spark.stop() + @contextmanager + def sql_conf(self, pairs): + """ + A convenient context manager to test some configuration specific logic. This sets + `value` to the configuration `key` and then restores it back when it exits. + """ + assert isinstance(pairs, dict), "pairs should be a dictionary." + + keys = pairs.keys() + new_values = pairs.values() + old_values = [self.spark.conf.get(key, None) for key in keys] + for key, new_value in zip(keys, new_values): + self.spark.conf.set(key, new_value) + try: + yield + finally: + for key, old_value in zip(keys, old_values): + if old_value is None: + self.spark.conf.unset(key) + else: + self.spark.conf.set(key, old_value) + def assertPandasEqual(self, expected, result): msg = ("DataFrames are not equal: " + "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) + @@ -3458,6 +3482,8 @@ def setUpClass(cls): cls.spark.conf.set("spark.sql.session.timeZone", tz) cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + # Disable fallback by default to easily detect the failures. + cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false") cls.schema = StructType([ StructField("1_str_t", StringType(), True), StructField("2_int_t", IntegerType(), True), @@ -3493,20 +3519,30 @@ def create_pandas_data_frame(self): data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) return pd.DataFrame(data=data_dict) - def test_unsupported_datatype(self): + def test_toPandas_fallback_enabled(self): + import pandas as pd + + with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}): + schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) + df = self.spark.createDataFrame([({u'a': 1},)], schema=schema) + with QuietTest(self.sc): + with warnings.catch_warnings(record=True) as warns: + pdf = df.toPandas() + # Catch and check the last UserWarning. + user_warns = [ + warn.message for warn in warns if isinstance(warn.message, UserWarning)] + self.assertTrue(len(user_warns) > 0) + self.assertTrue( + "Attempts non-optimization" in _exception_message(user_warns[-1])) + self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) + + def test_toPandas_fallback_disabled(self): schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported type'): df.toPandas() - df = self.spark.createDataFrame([(None,)], schema="a binary") - with QuietTest(self.sc): - with self.assertRaisesRegexp( - Exception, - 'Unsupported type.*\nNote: toPandas attempted Arrow optimization because'): - df.toPandas() - def test_null_conversion(self): df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + self.data) @@ -3625,7 +3661,7 @@ def test_createDataFrame_with_incorrect_schema(self): pdf = self.create_pandas_data_frame() wrong_schema = StructType(list(reversed(self.schema))) with QuietTest(self.sc): - with self.assertRaisesRegexp(TypeError, ".*field.*can.not.accept.*type"): + with self.assertRaisesRegexp(RuntimeError, ".*No cast.*string.*timestamp.*"): self.spark.createDataFrame(pdf, schema=wrong_schema) def test_createDataFrame_with_names(self): @@ -3650,7 +3686,7 @@ def test_createDataFrame_column_name_encoding(self): def test_createDataFrame_with_single_data_type(self): import pandas as pd with QuietTest(self.sc): - with self.assertRaisesRegexp(TypeError, ".*IntegerType.*tuple"): + with self.assertRaisesRegexp(RuntimeError, ".*IntegerType.*not supported.*"): self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int") def test_createDataFrame_does_not_modify_input(self): @@ -3705,6 +3741,30 @@ def test_createDataFrame_with_int_col_names(self): self.assertEqual(pdf_col_names, df.columns) self.assertEqual(pdf_col_names, df_arrow.columns) + def test_createDataFrame_fallback_enabled(self): + import pandas as pd + + with QuietTest(self.sc): + with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}): + with warnings.catch_warnings(record=True) as warns: + df = self.spark.createDataFrame( + pd.DataFrame([[{u'a': 1}]]), "a: map") + # Catch and check the last UserWarning. + user_warns = [ + warn.message for warn in warns if isinstance(warn.message, UserWarning)] + self.assertTrue(len(user_warns) > 0) + self.assertTrue( + "Attempts non-optimization" in _exception_message(user_warns[-1])) + self.assertEqual(df.collect(), [Row(a={u'a': 1})]) + + def test_createDataFrame_fallback_disabled(self): + import pandas as pd + + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported type'): + self.spark.createDataFrame( + pd.DataFrame([[{u'a': 1}]]), "a: map") + # Regression test for SPARK-23314 def test_timestamp_dst(self): import pandas as pd diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ce3f94618edeb..3f96112659c11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1058,7 +1058,7 @@ object SQLConf { .intConf .createWithDefault(100) - val ARROW_EXECUTION_ENABLE = + val ARROW_EXECUTION_ENABLED = buildConf("spark.sql.execution.arrow.enabled") .doc("When true, make use of Apache Arrow for columnar data transfers. Currently available " + "for use with pyspark.sql.DataFrame.toPandas, and " + @@ -1068,6 +1068,13 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ARROW_FALLBACK_ENABLED = + buildConf("spark.sql.execution.arrow.fallback.enabled") + .doc("When true, optimizations enabled by 'spark.sql.execution.arrow.enabled' will " + + "fallback automatically to non-optimized implementations if an error occurs.") + .booleanConf + .createWithDefault(true) + val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH = buildConf("spark.sql.execution.arrow.maxRecordsPerBatch") .doc("When using Apache Arrow, limit the maximum number of records that can be written " + @@ -1518,7 +1525,9 @@ class SQLConf extends Serializable with Logging { def rangeExchangeSampleSizePerPartition: Int = getConf(RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION) - def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE) + def arrowEnabled: Boolean = getConf(ARROW_EXECUTION_ENABLED) + + def arrowFallbackEnabled: Boolean = getConf(ARROW_FALLBACK_ENABLED) def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) From 2cb23a8f51a151970c121015fcbad9beeafa8295 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 8 Mar 2018 20:29:07 +0900 Subject: [PATCH 0451/2461] [SPARK-23011][SQL][PYTHON] Support alternative function form with group aggregate pandas UDF ## What changes were proposed in this pull request? This PR proposes to support an alternative function from with group aggregate pandas UDF. The current form: ``` def foo(pdf): return ... ``` Takes a single arg that is a pandas DataFrame. With this PR, an alternative form is supported: ``` def foo(key, pdf): return ... ``` The alternative form takes two argument - a tuple that presents the grouping key, and a pandas DataFrame represents the data. ## How was this patch tested? GroupbyApplyTests Author: Li Jin Closes #20295 from icexelloss/SPARK-23011-groupby-apply-key. --- python/pyspark/serializers.py | 18 ++- python/pyspark/sql/functions.py | 25 ++++ python/pyspark/sql/tests.py | 121 ++++++++++++++++-- python/pyspark/sql/types.py | 45 ++++++- python/pyspark/sql/udf.py | 19 +-- python/pyspark/util.py | 16 +++ python/pyspark/worker.py | 49 +++++-- .../python/FlatMapGroupsInPandasExec.scala | 56 +++++++- 8 files changed, 294 insertions(+), 55 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 917e258d8a602..ebf549396f463 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -250,6 +250,15 @@ def __init__(self, timezone): super(ArrowStreamPandasSerializer, self).__init__() self._timezone = timezone + def arrow_to_pandas(self, arrow_column): + from pyspark.sql.types import from_arrow_type, \ + _check_series_convert_date, _check_series_localize_timestamps + + s = arrow_column.to_pandas() + s = _check_series_convert_date(s, from_arrow_type(arrow_column.type)) + s = _check_series_localize_timestamps(s, self._timezone) + return s + def dump_stream(self, iterator, stream): """ Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or @@ -272,16 +281,11 @@ def load_stream(self, stream): """ Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. """ - from pyspark.sql.types import from_arrow_schema, _check_dataframe_convert_date, \ - _check_dataframe_localize_timestamps import pyarrow as pa reader = pa.open_stream(stream) - schema = from_arrow_schema(reader.schema) + for batch in reader: - pdf = batch.to_pandas() - pdf = _check_dataframe_convert_date(pdf, schema) - pdf = _check_dataframe_localize_timestamps(pdf, self._timezone) - yield [c for _, c in pdf.iteritems()] + yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()] def __repr__(self): return "ArrowStreamPandasSerializer" diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b9c0c57262c5d..dc1341ac74d3d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2267,6 +2267,31 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 2| 1.1094003924504583| +---+-------------------+ + Alternatively, the user can define a function that takes two arguments. + In this case, the grouping key will be passed as the first argument and the data will + be passed as the second argument. The grouping key will be passed as a tuple of numpy + data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in + as a `pandas.DataFrame` containing all columns from the original Spark DataFrame. + This is useful when the user does not want to hardcode grouping key in the function. + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> import pandas as pd # doctest: +SKIP + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) # doctest: +SKIP + >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP + ... def mean_udf(key, pdf): + ... # key is a tuple of one numpy.int64, which is the value + ... # of 'id' for the current group + ... return pd.DataFrame([key + (pdf.v.mean(),)]) + >>> df.groupby('id').apply(mean_udf).show() # doctest: +SKIP + +---+---+ + | id| v| + +---+---+ + | 1|1.5| + | 2|6.0| + +---+---+ + .. seealso:: :meth:`pyspark.sql.GroupedData.apply` 3. GROUPED_AGG diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a9fe0b425ad3e..480815d27333f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3903,7 +3903,7 @@ def foo(df): return df with self.assertRaisesRegexp(ValueError, 'Invalid function'): @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP) - def foo(k, v): + def foo(k, v, w): return k @@ -4476,20 +4476,45 @@ def test_supported_types(self): from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col df = self.data.withColumn("arr", array(col("id"))) - foo_udf = pandas_udf( + # Different forms of group map pandas UDF, results of these are the same + + output_schema = StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('arr', ArrayType(LongType())), + StructField('v1', DoubleType()), + StructField('v2', LongType())]) + + udf1 = pandas_udf( lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), - StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('arr', ArrayType(LongType())), - StructField('v1', DoubleType()), - StructField('v2', LongType())]), + output_schema, PandasUDFType.GROUPED_MAP ) - result = df.groupby('id').apply(foo_udf).sort('id').toPandas() - expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) - self.assertPandasEqual(expected, result) + udf2 = pandas_udf( + lambda _, pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + output_schema, + PandasUDFType.GROUPED_MAP + ) + + udf3 = pandas_udf( + lambda key, pdf: pdf.assign(id=key[0], v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + output_schema, + PandasUDFType.GROUPED_MAP + ) + + result1 = df.groupby('id').apply(udf1).sort('id').toPandas() + expected1 = df.toPandas().groupby('id').apply(udf1.func).reset_index(drop=True) + + result2 = df.groupby('id').apply(udf2).sort('id').toPandas() + expected2 = expected1 + + result3 = df.groupby('id').apply(udf3).sort('id').toPandas() + expected3 = expected1 + + self.assertPandasEqual(expected1, result1) + self.assertPandasEqual(expected2, result2) + self.assertPandasEqual(expected3, result3) def test_register_grouped_map_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4648,6 +4673,80 @@ def test_timestamp_dst(self): result = df.groupby('time').apply(foo_udf).sort('time') self.assertPandasEqual(df.toPandas(), result.toPandas()) + def test_udf_with_key(self): + from pyspark.sql.functions import pandas_udf, col, PandasUDFType + df = self.data + pdf = df.toPandas() + + def foo1(key, pdf): + import numpy as np + assert type(key) == tuple + assert type(key[0]) == np.int64 + + return pdf.assign(v1=key[0], + v2=pdf.v * key[0], + v3=pdf.v * pdf.id, + v4=pdf.v * pdf.id.mean()) + + def foo2(key, pdf): + import numpy as np + assert type(key) == tuple + assert type(key[0]) == np.int64 + assert type(key[1]) == np.int32 + + return pdf.assign(v1=key[0], + v2=key[1], + v3=pdf.v * key[0], + v4=pdf.v + key[1]) + + def foo3(key, pdf): + assert type(key) == tuple + assert len(key) == 0 + return pdf.assign(v1=pdf.v * pdf.id) + + # v2 is int because numpy.int64 * pd.Series results in pd.Series + # v3 is long because pd.Series * pd.Series results in pd.Series + udf1 = pandas_udf( + foo1, + 'id long, v int, v1 long, v2 int, v3 long, v4 double', + PandasUDFType.GROUPED_MAP) + + udf2 = pandas_udf( + foo2, + 'id long, v int, v1 long, v2 int, v3 int, v4 int', + PandasUDFType.GROUPED_MAP) + + udf3 = pandas_udf( + foo3, + 'id long, v int, v1 long', + PandasUDFType.GROUPED_MAP) + + # Test groupby column + result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas() + expected1 = pdf.groupby('id')\ + .apply(lambda x: udf1.func((x.id.iloc[0],), x))\ + .sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected1, result1) + + # Test groupby expression + result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas() + expected2 = pdf.groupby(pdf.id % 2)\ + .apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\ + .sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected2, result2) + + # Test complex groupby + result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas() + expected3 = pdf.groupby([pdf.id, pdf.v % 2])\ + .apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\ + .sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected3, result3) + + # Test empty groupby + result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas() + expected4 = udf3.func((), pdf) + self.assertPandasEqual(expected4, result4) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index cd857402db8f7..1632862d3f1ba 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1695,6 +1695,19 @@ def from_arrow_schema(arrow_schema): for field in arrow_schema]) +def _check_series_convert_date(series, data_type): + """ + Cast the series to datetime.date if it's a date type, otherwise returns the original series. + + :param series: pandas.Series + :param data_type: a Spark data type for the series + """ + if type(data_type) == DateType: + return series.dt.date + else: + return series + + def _check_dataframe_convert_date(pdf, schema): """ Correct date type value to use datetime.date. @@ -1705,8 +1718,7 @@ def _check_dataframe_convert_date(pdf, schema): :param schema: a Spark schema of the pandas.DataFrame """ for field in schema: - if type(field.dataType) == DateType: - pdf[field.name] = pdf[field.name].dt.date + pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType) return pdf @@ -1725,6 +1737,29 @@ def _get_local_timezone(): return os.environ.get('TZ', 'dateutil/:') +def _check_series_localize_timestamps(s, timezone): + """ + Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone. + + If the input series is not a timestamp series, then the same series is returned. If the input + series is a timestamp series, then a converted series is returned. + + :param s: pandas.Series + :param timezone: the timezone to convert. if None then use local timezone + :return pandas.Series that have been converted to tz-naive + """ + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() + + from pandas.api.types import is_datetime64tz_dtype + tz = timezone or _get_local_timezone() + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if is_datetime64tz_dtype(s.dtype): + return s.dt.tz_convert(tz).dt.tz_localize(None) + else: + return s + + def _check_dataframe_localize_timestamps(pdf, timezone): """ Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone @@ -1736,12 +1771,8 @@ def _check_dataframe_localize_timestamps(pdf, timezone): from pyspark.sql.utils import require_minimum_pandas_version require_minimum_pandas_version() - from pandas.api.types import is_datetime64tz_dtype - tz = timezone or _get_local_timezone() for column, series in pdf.iteritems(): - # TODO: handle nested timestamps, such as ArrayType(TimestampType())? - if is_datetime64tz_dtype(series.dtype): - pdf[column] = series.dt.tz_convert(tz).dt.tz_localize(None) + pdf[column] = _check_series_localize_timestamps(series, timezone) return pdf diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index b9b490874f4fb..ce804c18e9b14 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -17,6 +17,8 @@ """ User-defined function related classes and functions """ +import sys +import inspect import functools from pyspark import SparkContext, since @@ -24,6 +26,7 @@ from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \ _parse_datatype_string, to_arrow_type, to_arrow_schema +from pyspark.util import _get_argspec __all__ = ["UDFRegistration"] @@ -41,18 +44,10 @@ def _create_udf(f, returnType, evalType): PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): - import inspect - import sys from pyspark.sql.utils import require_minimum_pyarrow_version - require_minimum_pyarrow_version() - if sys.version_info[0] < 3: - # `getargspec` is deprecated since python3.0 (incompatible with function annotations). - # See SPARK-23569. - argspec = inspect.getargspec(f) - else: - argspec = inspect.getfullargspec(f) + argspec = _get_argspec(f) if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \ argspec.varargs is None: @@ -61,11 +56,11 @@ def _create_udf(f, returnType, evalType): "Instead, create a 1-arg pandas_udf and ignore the arg in your function." ) - if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF and len(argspec.args) != 1: + if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \ + and len(argspec.args) not in (1, 2): raise ValueError( "Invalid function: pandas_udfs with function type GROUPED_MAP " - "must take a single arg that is a pandas DataFrame." - ) + "must take either one argument (data) or two arguments (key, data).") # Set the name of the UserDefinedFunction object to be the name of function f udf_obj = UserDefinedFunction( diff --git a/python/pyspark/util.py b/python/pyspark/util.py index ad4a0bc68ef41..6837b18b7d7a5 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -15,6 +15,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +import sys +import inspect from py4j.protocol import Py4JJavaError __all__ = [] @@ -45,6 +48,19 @@ def _exception_message(excp): return str(excp) +def _get_argspec(f): + """ + Get argspec of a function. Supports both Python 2 and Python 3. + """ + # `getargspec` is deprecated since python3.0 (incompatible with function annotations). + # See SPARK-23569. + if sys.version_info[0] < 3: + argspec = inspect.getargspec(f) + else: + argspec = inspect.getfullargspec(f) + return argspec + + if __name__ == "__main__": import doctest (failure_count, test_count) = doctest.testmod() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 89a3a92bc66d6..202cac350aafc 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -34,6 +34,7 @@ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type +from pyspark.util import _get_argspec from pyspark import shuffle pickleSer = PickleSerializer() @@ -91,10 +92,16 @@ def verify_result_length(*a): def wrap_grouped_map_pandas_udf(f, return_type): - def wrapped(*series): + def wrapped(key_series, value_series): import pandas as pd + argspec = _get_argspec(f) + + if len(argspec.args) == 1: + result = f(pd.concat(value_series, axis=1)) + elif len(argspec.args) == 2: + key = tuple(s[0] for s in key_series) + result = f(key, pd.concat(value_series, axis=1)) - result = f(pd.concat(series, axis=1)) if not isinstance(result, pd.DataFrame): raise TypeError("Return type of the user-defined function should be " "pandas.DataFrame, but is {}".format(type(result))) @@ -149,18 +156,36 @@ def read_udfs(pickleSer, infile, eval_type): num_udfs = read_int(infile) udfs = {} call_udf = [] - for i in range(num_udfs): + mapper_str = "" + if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + # Create function like this: + # lambda a: f([a[0]], [a[0], a[1]]) + + # We assume there is only one UDF here because grouped map doesn't + # support combining multiple UDFs. + assert num_udfs == 1 + + # See FlatMapGroupsInPandasExec for how arg_offsets are used to + # distinguish between grouping attributes and data attributes arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) - udfs['f%d' % i] = udf - args = ["a[%d]" % o for o in arg_offsets] - call_udf.append("f%d(%s)" % (i, ", ".join(args))) - # Create function like this: - # lambda a: (f0(a0), f1(a1, a2), f2(a3)) - # In the special case of a single UDF this will return a single result rather - # than a tuple of results; this is the format that the JVM side expects. - mapper_str = "lambda a: (%s)" % (", ".join(call_udf)) - mapper = eval(mapper_str, udfs) + udfs['f'] = udf + split_offset = arg_offsets[0] + 1 + arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]] + arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]] + mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1)) + else: + # Create function like this: + # lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3])) + # In the special case of a single UDF this will return a single result rather + # than a tuple of results; this is the format that the JVM side expects. + for i in range(num_udfs): + arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) + udfs['f%d' % i] = udf + args = ["a[%d]" % o for o in arg_offsets] + call_udf.append("f%d(%s)" % (i, ", ".join(args))) + mapper_str = "lambda a: (%s)" % (", ".join(call_udf)) + mapper = eval(mapper_str, udfs) func = lambda _, it: map(mapper, it) if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index c798fe5a92c54..513e174c7733e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.python import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} @@ -75,20 +76,63 @@ case class FlatMapGroupsInPandasExec( val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) - val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray) - val schema = StructType(child.schema.drop(groupingAttributes.length)) val sessionLocalTimeZone = conf.sessionLocalTimeZone val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + // Deduplicate the grouping attributes. + // If a grouping attribute also appears in data attributes, then we don't need to send the + // grouping attribute to Python worker. If a grouping attribute is not in data attributes, + // then we need to send this grouping attribute to python worker. + // + // We use argOffsets to distinguish grouping attributes and data attributes as following: + // + // argOffsets[0] is the length of grouping attributes + // argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes + // argOffsets[argOffsets[0]+1 .. ] is the arg offsets for data attributes + + val dataAttributes = child.output.drop(groupingAttributes.length) + val groupingIndicesInData = groupingAttributes.map { attribute => + dataAttributes.indexWhere(attribute.semanticEquals) + } + + val groupingArgOffsets = new ArrayBuffer[Int] + val nonDupGroupingAttributes = new ArrayBuffer[Attribute] + val nonDupGroupingSize = groupingIndicesInData.count(_ == -1) + + // Non duplicate grouping attributes are added to nonDupGroupingAttributes and + // their offsets are 0, 1, 2 ... + // Duplicate grouping attributes are NOT added to nonDupGroupingAttributes and + // their offsets are n + index, where n is the total number of non duplicate grouping + // attributes and index is the index in the data attributes that the grouping attribute + // is a duplicate of. + + groupingAttributes.zip(groupingIndicesInData).foreach { + case (attribute, index) => + if (index == -1) { + groupingArgOffsets += nonDupGroupingAttributes.length + nonDupGroupingAttributes += attribute + } else { + groupingArgOffsets += index + nonDupGroupingSize + } + } + + val dataArgOffsets = nonDupGroupingAttributes.length until + (nonDupGroupingAttributes.length + dataAttributes.length) + + val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets) + + // Attributes after deduplication + val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes + val dedupSchema = StructType.fromAttributes(dedupAttributes) + inputRDD.mapPartitionsInternal { iter => val grouped = if (groupingAttributes.isEmpty) { Iterator(iter) } else { val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) - val dropGrouping = - UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output) + val dedupProj = UnsafeProjection.create(dedupAttributes, child.output) groupedIter.map { - case (_, groupedRowIter) => groupedRowIter.map(dropGrouping) + case (_, groupedRowIter) => groupedRowIter.map(dedupProj) } } @@ -96,7 +140,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, schema, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema, sessionLocalTimeZone, pandasRespectSessionTimeZone) .compute(grouped, context.partitionId(), context) From 7013eea11cb32b1e0038dc751c485da5c94a484b Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 8 Mar 2018 20:38:34 +0900 Subject: [PATCH 0452/2461] [SPARK-23522][PYTHON] always use sys.exit over builtin exit The exit() builtin is only for interactive use. applications should use sys.exit(). ## What changes were proposed in this pull request? All usage of the builtin `exit()` function is replaced by `sys.exit()`. ## How was this patch tested? I ran `python/run-tests`. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Benjamin Peterson Closes #20682 from benjaminp/sys-exit. --- dev/merge_spark_pr.py | 2 +- dev/run-tests.py | 2 +- examples/src/main/python/avro_inputformat.py | 2 +- examples/src/main/python/kmeans.py | 2 +- examples/src/main/python/logistic_regression.py | 2 +- examples/src/main/python/ml/dataframe_example.py | 2 +- examples/src/main/python/mllib/correlations.py | 2 +- examples/src/main/python/mllib/kmeans.py | 2 +- examples/src/main/python/mllib/logistic_regression.py | 2 +- examples/src/main/python/mllib/random_rdd_generation.py | 2 +- examples/src/main/python/mllib/sampled_rdds.py | 4 ++-- .../python/mllib/streaming_linear_regression_example.py | 2 +- examples/src/main/python/pagerank.py | 2 +- examples/src/main/python/parquet_inputformat.py | 2 +- examples/src/main/python/sort.py | 2 +- .../main/python/sql/streaming/structured_kafka_wordcount.py | 2 +- .../python/sql/streaming/structured_network_wordcount.py | 2 +- .../sql/streaming/structured_network_wordcount_windowed.py | 2 +- .../src/main/python/streaming/direct_kafka_wordcount.py | 2 +- examples/src/main/python/streaming/flume_wordcount.py | 2 +- examples/src/main/python/streaming/hdfs_wordcount.py | 2 +- examples/src/main/python/streaming/kafka_wordcount.py | 2 +- examples/src/main/python/streaming/network_wordcount.py | 2 +- .../src/main/python/streaming/network_wordjoinsentiments.py | 2 +- .../main/python/streaming/recoverable_network_wordcount.py | 2 +- examples/src/main/python/streaming/sql_network_wordcount.py | 2 +- .../src/main/python/streaming/stateful_network_wordcount.py | 2 +- examples/src/main/python/wordcount.py | 2 +- python/pyspark/accumulators.py | 2 +- python/pyspark/broadcast.py | 2 +- python/pyspark/conf.py | 2 +- python/pyspark/context.py | 2 +- python/pyspark/daemon.py | 2 +- python/pyspark/find_spark_home.py | 2 +- python/pyspark/heapq3.py | 3 ++- python/pyspark/ml/classification.py | 3 ++- python/pyspark/ml/clustering.py | 4 +++- python/pyspark/ml/evaluation.py | 3 ++- python/pyspark/ml/feature.py | 2 +- python/pyspark/ml/image.py | 4 +++- python/pyspark/ml/linalg/__init__.py | 2 +- python/pyspark/ml/recommendation.py | 4 +++- python/pyspark/ml/regression.py | 3 ++- python/pyspark/ml/stat.py | 4 +++- python/pyspark/ml/tuning.py | 6 ++++-- python/pyspark/mllib/classification.py | 3 ++- python/pyspark/mllib/clustering.py | 2 +- python/pyspark/mllib/evaluation.py | 3 ++- python/pyspark/mllib/feature.py | 2 +- python/pyspark/mllib/fpm.py | 4 +++- python/pyspark/mllib/linalg/__init__.py | 2 +- python/pyspark/mllib/linalg/distributed.py | 2 +- python/pyspark/mllib/random.py | 3 ++- python/pyspark/mllib/recommendation.py | 3 ++- python/pyspark/mllib/regression.py | 6 ++++-- python/pyspark/mllib/stat/_statistics.py | 2 +- python/pyspark/mllib/tree.py | 3 ++- python/pyspark/mllib/util.py | 2 +- python/pyspark/profiler.py | 3 ++- python/pyspark/rdd.py | 2 +- python/pyspark/serializers.py | 2 +- python/pyspark/shuffle.py | 3 ++- python/pyspark/sql/catalog.py | 3 ++- python/pyspark/sql/column.py | 2 +- python/pyspark/sql/conf.py | 4 +++- python/pyspark/sql/context.py | 2 +- python/pyspark/sql/dataframe.py | 2 +- python/pyspark/sql/functions.py | 2 +- python/pyspark/sql/group.py | 4 +++- python/pyspark/sql/readwriter.py | 2 +- python/pyspark/sql/session.py | 2 +- python/pyspark/sql/streaming.py | 2 +- python/pyspark/sql/types.py | 2 +- python/pyspark/sql/udf.py | 3 ++- python/pyspark/sql/window.py | 2 +- python/pyspark/streaming/util.py | 3 ++- python/pyspark/util.py | 4 +++- python/pyspark/worker.py | 6 +++--- python/setup.py | 6 +++--- 79 files changed, 120 insertions(+), 86 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 6b244d8184b2c..5ea205fbed4aa 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -510,7 +510,7 @@ def main(): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) try: main() except: diff --git a/dev/run-tests.py b/dev/run-tests.py index fe75ef4411c8c..164c1e2200aa9 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -621,7 +621,7 @@ def _test(): import doctest failure_count = doctest.testmod()[0] if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py index 6286ba6541fbd..a18722c687f8b 100644 --- a/examples/src/main/python/avro_inputformat.py +++ b/examples/src/main/python/avro_inputformat.py @@ -61,7 +61,7 @@ Assumes you have Avro data stored in . Reader schema can be optionally specified in [reader_schema_file]. """, file=sys.stderr) - exit(-1) + sys.exit(-1) path = sys.argv[1] diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py index 92e0a3ae2ee60..a42d711fc505f 100755 --- a/examples/src/main/python/kmeans.py +++ b/examples/src/main/python/kmeans.py @@ -49,7 +49,7 @@ def closestPoint(p, centers): if len(sys.argv) != 4: print("Usage: kmeans ", file=sys.stderr) - exit(-1) + sys.exit(-1) print("""WARN: This is a naive implementation of KMeans Clustering and is given as an example! Please refer to examples/src/main/python/ml/kmeans_example.py for an diff --git a/examples/src/main/python/logistic_regression.py b/examples/src/main/python/logistic_regression.py index 01c938454b108..bcc4e0f4e8eae 100755 --- a/examples/src/main/python/logistic_regression.py +++ b/examples/src/main/python/logistic_regression.py @@ -48,7 +48,7 @@ def readPointBatch(iterator): if len(sys.argv) != 3: print("Usage: logistic_regression ", file=sys.stderr) - exit(-1) + sys.exit(-1) print("""WARN: This is a naive implementation of Logistic Regression and is given as an example! diff --git a/examples/src/main/python/ml/dataframe_example.py b/examples/src/main/python/ml/dataframe_example.py index 109f901012c9c..d62cf2338a1fe 100644 --- a/examples/src/main/python/ml/dataframe_example.py +++ b/examples/src/main/python/ml/dataframe_example.py @@ -33,7 +33,7 @@ if __name__ == "__main__": if len(sys.argv) > 2: print("Usage: dataframe_example.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) elif len(sys.argv) == 2: input = sys.argv[1] else: diff --git a/examples/src/main/python/mllib/correlations.py b/examples/src/main/python/mllib/correlations.py index 0e13546b88e67..089504fa7064b 100755 --- a/examples/src/main/python/mllib/correlations.py +++ b/examples/src/main/python/mllib/correlations.py @@ -31,7 +31,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: print("Usage: correlations ()", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonCorrelations") if len(sys.argv) == 2: filepath = sys.argv[1] diff --git a/examples/src/main/python/mllib/kmeans.py b/examples/src/main/python/mllib/kmeans.py index 002fc75799648..1bdb3e9b4a2af 100755 --- a/examples/src/main/python/mllib/kmeans.py +++ b/examples/src/main/python/mllib/kmeans.py @@ -36,7 +36,7 @@ def parseVector(line): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: kmeans ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="KMeans") lines = sc.textFile(sys.argv[1]) data = lines.map(parseVector) diff --git a/examples/src/main/python/mllib/logistic_regression.py b/examples/src/main/python/mllib/logistic_regression.py index d4f1d34e2d8cf..87efe17375226 100755 --- a/examples/src/main/python/mllib/logistic_regression.py +++ b/examples/src/main/python/mllib/logistic_regression.py @@ -42,7 +42,7 @@ def parsePoint(line): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: logistic_regression ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonLR") points = sc.textFile(sys.argv[1]).map(parsePoint) iterations = int(sys.argv[2]) diff --git a/examples/src/main/python/mllib/random_rdd_generation.py b/examples/src/main/python/mllib/random_rdd_generation.py index 729bae30b152c..9a429b5f8abdf 100755 --- a/examples/src/main/python/mllib/random_rdd_generation.py +++ b/examples/src/main/python/mllib/random_rdd_generation.py @@ -29,7 +29,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: print("Usage: random_rdd_generation", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonRandomRDDGeneration") diff --git a/examples/src/main/python/mllib/sampled_rdds.py b/examples/src/main/python/mllib/sampled_rdds.py index b7033ab7daeb3..00e7cf4bbcdbf 100755 --- a/examples/src/main/python/mllib/sampled_rdds.py +++ b/examples/src/main/python/mllib/sampled_rdds.py @@ -29,7 +29,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: print("Usage: sampled_rdds ", file=sys.stderr) - exit(-1) + sys.exit(-1) if len(sys.argv) == 2: datapath = sys.argv[1] else: @@ -43,7 +43,7 @@ numExamples = examples.count() if numExamples == 0: print("Error: Data file had no samples to load.", file=sys.stderr) - exit(1) + sys.exit(1) print('Loaded data with %d examples from file: %s' % (numExamples, datapath)) # Example: RDD.sample() and RDD.takeSample() diff --git a/examples/src/main/python/mllib/streaming_linear_regression_example.py b/examples/src/main/python/mllib/streaming_linear_regression_example.py index f600496867c11..714c9a0de7217 100644 --- a/examples/src/main/python/mllib/streaming_linear_regression_example.py +++ b/examples/src/main/python/mllib/streaming_linear_regression_example.py @@ -36,7 +36,7 @@ if len(sys.argv) != 3: print("Usage: streaming_linear_regression_example.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonLogisticRegressionWithLBFGSExample") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py index 0d6c253d397a0..2c19e8700ab16 100755 --- a/examples/src/main/python/pagerank.py +++ b/examples/src/main/python/pagerank.py @@ -47,7 +47,7 @@ def parseNeighbors(urls): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: pagerank ", file=sys.stderr) - exit(-1) + sys.exit(-1) print("WARN: This is a naive implementation of PageRank and is given as an example!\n" + "Please refer to PageRank implementation provided by graphx", diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index a3f86cf8999cf..83041f0040a0c 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -45,7 +45,7 @@ /path/to/examples/parquet_inputformat.py Assumes you have Parquet data stored in . """, file=sys.stderr) - exit(-1) + sys.exit(-1) path = sys.argv[1] diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py index 81898cf6d5ce6..d3cd985d197e3 100755 --- a/examples/src/main/python/sort.py +++ b/examples/src/main/python/sort.py @@ -25,7 +25,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: sort ", file=sys.stderr) - exit(-1) + sys.exit(-1) spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py b/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py index 9e8a552b3b10b..921067891352a 100644 --- a/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py +++ b/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py @@ -49,7 +49,7 @@ print(""" Usage: structured_kafka_wordcount.py """, file=sys.stderr) - exit(-1) + sys.exit(-1) bootstrapServers = sys.argv[1] subscribeType = sys.argv[2] diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount.py b/examples/src/main/python/sql/streaming/structured_network_wordcount.py index c3284c1d01017..9ac392164735b 100644 --- a/examples/src/main/python/sql/streaming/structured_network_wordcount.py +++ b/examples/src/main/python/sql/streaming/structured_network_wordcount.py @@ -38,7 +38,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: structured_network_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) host = sys.argv[1] port = int(sys.argv[2]) diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py index db672551504b5..c4e3bbf44cd5a 100644 --- a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py +++ b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py @@ -53,7 +53,7 @@ msg = ("Usage: structured_network_wordcount_windowed.py " " []") print(msg, file=sys.stderr) - exit(-1) + sys.exit(-1) host = sys.argv[1] port = int(sys.argv[2]) diff --git a/examples/src/main/python/streaming/direct_kafka_wordcount.py b/examples/src/main/python/streaming/direct_kafka_wordcount.py index 425df309011a0..c5c186c11f79a 100644 --- a/examples/src/main/python/streaming/direct_kafka_wordcount.py +++ b/examples/src/main/python/streaming/direct_kafka_wordcount.py @@ -39,7 +39,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: direct_kafka_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingDirectKafkaWordCount") ssc = StreamingContext(sc, 2) diff --git a/examples/src/main/python/streaming/flume_wordcount.py b/examples/src/main/python/streaming/flume_wordcount.py index 5d6e6dc36d6f9..c8ea92b61ca6e 100644 --- a/examples/src/main/python/streaming/flume_wordcount.py +++ b/examples/src/main/python/streaming/flume_wordcount.py @@ -39,7 +39,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: flume_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingFlumeWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py index f815dd26823d1..f9a5c43a8eaa9 100644 --- a/examples/src/main/python/streaming/hdfs_wordcount.py +++ b/examples/src/main/python/streaming/hdfs_wordcount.py @@ -35,7 +35,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: hdfs_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingHDFSWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py index 704f6602e2297..e9ee08b9fd228 100644 --- a/examples/src/main/python/streaming/kafka_wordcount.py +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -39,7 +39,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: kafka_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingKafkaWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py index 9010fafb425e6..f3099d2517cd5 100644 --- a/examples/src/main/python/streaming/network_wordcount.py +++ b/examples/src/main/python/streaming/network_wordcount.py @@ -35,7 +35,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: network_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingNetworkWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/network_wordjoinsentiments.py b/examples/src/main/python/streaming/network_wordjoinsentiments.py index d51a380a5d5f9..2b5434c0c845a 100644 --- a/examples/src/main/python/streaming/network_wordjoinsentiments.py +++ b/examples/src/main/python/streaming/network_wordjoinsentiments.py @@ -47,7 +47,7 @@ def print_happiest_words(rdd): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: network_wordjoinsentiments.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingNetworkWordJoinSentiments") ssc = StreamingContext(sc, 5) diff --git a/examples/src/main/python/streaming/recoverable_network_wordcount.py b/examples/src/main/python/streaming/recoverable_network_wordcount.py index 52b2639cdf55c..60167dc772544 100644 --- a/examples/src/main/python/streaming/recoverable_network_wordcount.py +++ b/examples/src/main/python/streaming/recoverable_network_wordcount.py @@ -101,7 +101,7 @@ def filterFunc(wordCount): if len(sys.argv) != 5: print("Usage: recoverable_network_wordcount.py " " ", file=sys.stderr) - exit(-1) + sys.exit(-1) host, port, checkpoint, output = sys.argv[1:] ssc = StreamingContext.getOrCreate(checkpoint, lambda: createContext(host, int(port), output)) diff --git a/examples/src/main/python/streaming/sql_network_wordcount.py b/examples/src/main/python/streaming/sql_network_wordcount.py index 7f12281c0e3fe..ab3cfc067994d 100644 --- a/examples/src/main/python/streaming/sql_network_wordcount.py +++ b/examples/src/main/python/streaming/sql_network_wordcount.py @@ -48,7 +48,7 @@ def getSparkSessionInstance(sparkConf): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: sql_network_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) host, port = sys.argv[1:] sc = SparkContext(appName="PythonSqlNetworkWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py index d7bb61e729f18..d5d1eba6c5969 100644 --- a/examples/src/main/python/streaming/stateful_network_wordcount.py +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -39,7 +39,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: stateful_network_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount") ssc = StreamingContext(sc, 1) ssc.checkpoint("checkpoint") diff --git a/examples/src/main/python/wordcount.py b/examples/src/main/python/wordcount.py index 3d5e44d5b2df1..a05e24ff3ff95 100755 --- a/examples/src/main/python/wordcount.py +++ b/examples/src/main/python/wordcount.py @@ -26,7 +26,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: wordcount ", file=sys.stderr) - exit(-1) + sys.exit(-1) spark = SparkSession\ .builder\ diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 7def676b89a24..f730d290273fe 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -265,4 +265,4 @@ def _start_update_server(): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 02fc515fb824a..b3dfc99962a35 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -162,4 +162,4 @@ def clear(self): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index 491b3a81972bc..ab429d9ab10de 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -217,7 +217,7 @@ def _test(): import doctest (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 24905f1c97b21..7c664966ed74e 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -1035,7 +1035,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 7f06d4288c872..7bed5216eabf3 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -89,7 +89,7 @@ def shutdown(code): signal.signal(SIGTERM, SIG_DFL) # Send SIGHUP to notify workers of shutdown os.kill(0, SIGHUP) - exit(code) + sys.exit(code) def handle_sigterm(*args): shutdown(1) diff --git a/python/pyspark/find_spark_home.py b/python/pyspark/find_spark_home.py index 212a618b767ab..9cf0e8c8d2fe9 100755 --- a/python/pyspark/find_spark_home.py +++ b/python/pyspark/find_spark_home.py @@ -68,7 +68,7 @@ def is_spark_home(path): return next(path for path in paths if is_spark_home(path)) except StopIteration: print("Could not find valid SPARK_HOME while searching {0}".format(paths), file=sys.stderr) - exit(-1) + sys.exit(-1) if __name__ == "__main__": print(_find_spark_home()) diff --git a/python/pyspark/heapq3.py b/python/pyspark/heapq3.py index b27e91a4cc251..6af084adcf373 100644 --- a/python/pyspark/heapq3.py +++ b/python/pyspark/heapq3.py @@ -884,6 +884,7 @@ def nlargest(n, iterable, key=None): if __name__ == "__main__": import doctest + import sys (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 27ad1e80aa0d3..fbbe3d0307c81 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -16,6 +16,7 @@ # import operator +import sys from multiprocessing.pool import ThreadPool from pyspark import since, keyword_only @@ -2043,4 +2044,4 @@ def _to_java(self): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 6448b76a0da88..b3d5fb17f6b81 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + from pyspark import since, keyword_only from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper @@ -1181,4 +1183,4 @@ def getKeepLastCheckpoint(self): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 695d8ab27cc96..8eaf07645a37f 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -15,6 +15,7 @@ # limitations under the License. # +import sys from abc import abstractmethod, ABCMeta from pyspark import since, keyword_only @@ -446,4 +447,4 @@ def getDistanceMeasure(self): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 04b07e6a05481..f2e357f0bede5 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -3717,4 +3717,4 @@ def setSize(self, value): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 45c936645f2a8..96d702f844839 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -24,6 +24,8 @@ :members: """ +import sys + import numpy as np from pyspark import SparkContext from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string @@ -251,7 +253,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index ad1b487676fa7..6a611a2b5b59d 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -1158,7 +1158,7 @@ def _test(): import doctest (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index e8bcbe4cd34cb..a8eae9bd268d3 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + from pyspark import since, keyword_only from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel @@ -480,4 +482,4 @@ def recommendForItemSubset(self, dataset, numUsers): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index f0812bd1d4a39..de0a0fa9f3bf8 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -15,6 +15,7 @@ # limitations under the License. # +import sys import warnings from pyspark import since, keyword_only @@ -1812,4 +1813,4 @@ def __repr__(self): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index 079b0833e1c6d..0eeb5e528434a 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + from pyspark import since, SparkContext from pyspark.ml.common import _java2py, _py2java from pyspark.ml.wrapper import _jvm @@ -151,4 +153,4 @@ def corr(dataset, column, method="pearson"): failure_count, test_count = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 6c0cad6cbaaa1..545e24ca05aa5 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -15,9 +15,11 @@ # limitations under the License. # import itertools -import numpy as np +import sys from multiprocessing.pool import ThreadPool +import numpy as np + from pyspark import since, keyword_only from pyspark.ml import Estimator, Model from pyspark.ml.common import _py2java @@ -727,4 +729,4 @@ def _to_java(self): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index cce703d432b5a..bb281981fd56b 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -16,6 +16,7 @@ # from math import exp +import sys import warnings import numpy @@ -761,7 +762,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index bb687a7da6ffd..0cbabab13a896 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -1048,7 +1048,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 2cd1da3fbf9aa..36cb03369b8c0 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -15,6 +15,7 @@ # limitations under the License. # +import sys import warnings from pyspark import since @@ -542,7 +543,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index e5231dc3a27a8..40ecd2e0ff4be 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -819,7 +819,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": sys.path.pop(0) diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index f58ea5dfb0874..de18dad1f675d 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + import numpy from numpy import array from collections import namedtuple @@ -197,7 +199,7 @@ def _test(): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 7b24b3c74a9fa..60d96d8d5ceb8 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -1370,7 +1370,7 @@ def _test(): import doctest (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index 4cb802514be52..bba88542167ad 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -1377,7 +1377,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 61213ddf62e8b..a8833cb446923 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -19,6 +19,7 @@ Python package for random data generation. """ +import sys from functools import wraps from pyspark import since @@ -421,7 +422,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 81182881352bb..3d4eae85132bb 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -16,6 +16,7 @@ # import array +import sys from collections import namedtuple from pyspark import SparkContext, since @@ -326,7 +327,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index ea107d400621d..6be45f51862c9 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -15,9 +15,11 @@ # limitations under the License. # +import sys +import warnings + import numpy as np from numpy import array -import warnings from pyspark import RDD, since from pyspark.streaming.dstream import DStream @@ -837,7 +839,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 49b26446dbc32..3c75b132ecad2 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -313,7 +313,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 619fa16d463f5..b05734ce489d9 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -17,6 +17,7 @@ from __future__ import absolute_import +import sys import random from pyspark import SparkContext, RDD, since @@ -654,7 +655,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 97755807ef262..fc7809387b13a 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -521,7 +521,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py index 44d17bd629473..3c7656ab5758c 100644 --- a/python/pyspark/profiler.py +++ b/python/pyspark/profiler.py @@ -19,6 +19,7 @@ import pstats import os import atexit +import sys from pyspark.accumulators import AccumulatorParam @@ -173,4 +174,4 @@ def stats(self): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 93b8974a7e64a..4b44f76747264 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2498,7 +2498,7 @@ def _test(): globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index ebf549396f463..15753f77bd903 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -715,4 +715,4 @@ def write_with_length(obj, stream): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index e974cda9fc3e1..02c773302e9da 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -23,6 +23,7 @@ import itertools import operator import random +import sys import pyspark.heapq3 as heapq from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ @@ -810,4 +811,4 @@ def load_partition(j): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 6aef0f22340be..b0d8357f4feec 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -15,6 +15,7 @@ # limitations under the License. # +import sys import warnings from collections import namedtuple @@ -306,7 +307,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 43b38a2cd477c..e05a7b33c11a7 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -660,7 +660,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index 792c420ca6386..d929834aeeaa5 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + from pyspark import since from pyspark.rdd import ignore_unicode_prefix @@ -80,7 +82,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(pyspark.sql.conf, globs=globs) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index cc1cd1a5842d9..6cb90399dd616 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -543,7 +543,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 8f90a367e8bf8..3fc194d8ec1d1 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -2231,7 +2231,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index dc1341ac74d3d..dff590983b4d9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2404,7 +2404,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index ab646535c864c..35cac406e0965 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + from pyspark import since from pyspark.rdd import ignore_unicode_prefix, PythonEvalType from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal @@ -299,7 +301,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 9d05ac7cb39be..803f561ece67b 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -970,7 +970,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) sc.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 215bb3e5c5173..e82a9750a0014 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -830,7 +830,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index cc622decfd682..e8966c20a8f42 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -930,7 +930,7 @@ def _test(): globs['spark'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 1632862d3f1ba..826aab97e58db 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1890,7 +1890,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index ce804c18e9b14..24dd06c26089c 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -20,6 +20,7 @@ import sys import inspect import functools +import sys from pyspark import SparkContext, since from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix @@ -397,7 +398,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index bb841a9b9ff7c..e667fba099fb9 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -264,7 +264,7 @@ def _test(): SparkContext('local[4]', 'PythonTest') (failure_count, test_count) = doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE) if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index abbbf6eb9394f..df184471993ff 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -18,6 +18,7 @@ import time from datetime import datetime import traceback +import sys from pyspark import SparkContext, RDD @@ -147,4 +148,4 @@ def rddToFileName(prefix, suffix, timestamp): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 6837b18b7d7a5..ed1bdd0e4be83 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -22,6 +22,8 @@ __all__ = [] +import sys + def _exception_message(excp): """Return the message from an exception as either a str or unicode object. Supports both @@ -65,4 +67,4 @@ def _get_argspec(f): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 202cac350aafc..a1a4336b1e8de 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -205,7 +205,7 @@ def main(infile, outfile): boot_time = time.time() split_index = read_int(infile) if split_index == -1: # for unit tests - exit(-1) + sys.exit(-1) version = utf8_deserializer.loads(infile) if version != "%d.%d" % sys.version_info[:2]: @@ -279,7 +279,7 @@ def process(): # Write the error to stderr if it happened while serializing print("PySpark worker failed with exception:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - exit(-1) + sys.exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) write_long(shuffle.MemoryBytesSpilled, outfile) @@ -297,7 +297,7 @@ def process(): else: # write a different value to tell JVM to not reuse this worker write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) - exit(-1) + sys.exit(-1) if __name__ == '__main__': diff --git a/python/setup.py b/python/setup.py index 6a98401941d8d..794ceceae3008 100644 --- a/python/setup.py +++ b/python/setup.py @@ -26,7 +26,7 @@ if sys.version_info < (2, 7): print("Python versions prior to 2.7 are not supported for pip installed PySpark.", file=sys.stderr) - exit(-1) + sys.exit(-1) try: exec(open('pyspark/version.py').read()) @@ -98,7 +98,7 @@ def _supports_symlinks(): except: print("Temp path for symlink to parent already exists {0}".format(TEMP_PATH), file=sys.stderr) - exit(-1) + sys.exit(-1) # If you are changing the versions here, please also change ./python/pyspark/sql/utils.py and # ./python/run-tests.py. In case of Arrow, you should also check ./pom.xml. @@ -140,7 +140,7 @@ def _supports_symlinks(): if not os.path.isdir(SCRIPTS_TARGET): print(incorrect_invocation_message, file=sys.stderr) - exit(-1) + sys.exit(-1) # Scripts directive requires a list of each script path and does not take wild cards. script_names = os.listdir(SCRIPTS_TARGET) From 92e7ecbbbd6817378abdbd56541a9c13dcea8659 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 8 Mar 2018 14:18:14 +0100 Subject: [PATCH 0453/2461] [SPARK-23592][SQL] Add interpreted execution to DecodeUsingSerializer ## What changes were proposed in this pull request? The PR adds interpreted execution to DecodeUsingSerializer. ## How was this patch tested? added UT Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Marco Gaido Closes #20760 from mgaido91/SPARK-23592. --- .../catalyst/expressions/objects/objects.scala | 5 +++++ .../expressions/ObjectExpressionsSuite.scala | 15 +++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 7bbc3c732e782..adf9ddf327c96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1242,6 +1242,11 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) extends UnaryExpression with NonSQLExpression with SerializerSupport { + override def nullSafeEval(input: Any): Any = { + val inputBytes = java.nio.ByteBuffer.wrap(input.asInstanceOf[Array[Byte]]) + serializerInstance.deserialize(inputBytes) + } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val serializer = addImmutableSerializerIfNeeded(ctx) // Code to deserialize. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 346b13277c709..ffeec2a38c532 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.reflect.ClassTag + import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.Row @@ -123,4 +125,17 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(encodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) } } + + test("SPARK-23592: DecodeUsingSerializer should support interpreted execution") { + val cls = classOf[java.lang.Integer] + val inputObject = BoundReference(0, ObjectType(classOf[Array[Byte]]), nullable = true) + val conf = new SparkConf() + Seq(true, false).foreach { useKryo => + val serializer = if (useKryo) new KryoSerializer(conf) else new JavaSerializer(conf) + val input = serializer.newInstance().serialize(new Integer(1)).array() + val decodeUsingSerializer = DecodeUsingSerializer(inputObject, ClassTag(cls), useKryo) + checkEvaluation(decodeUsingSerializer, new Integer(1), InternalRow.fromSeq(Seq(input))) + checkEvaluation(decodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) + } + } } From 3be4adf6485ca19cdc5db23394c3f5a660d7dc6f Mon Sep 17 00:00:00 2001 From: lucio <576632108@qq.com> Date: Thu, 8 Mar 2018 08:03:24 -0600 Subject: [PATCH 0454/2461] [SPARK-22751][ML] Improve ML RandomForest shuffle performance ## What changes were proposed in this pull request? As I mentioned in [SPARK-22751](https://issues.apache.org/jira/browse/SPARK-22751?jql=project%20%3D%20SPARK%20AND%20component%20%3D%20ML%20AND%20text%20~%20randomforest), there is a shuffle performance problem in ML Randomforest when train a RF in high dimensional data. The reason is that, in _org.apache.spark.tree.impl.RandomForest_, the function _findSplitsBySorting_ will actually flatmap a sparse vector into a dense vector, then in groupByKey there will be a huge shuffle write size. To avoid this, we can add a filter in flatmap, to filter out zero value. And in function _findSplitsForContinuousFeature_, we can infer the number of zero value by _metadata_. In addition, if a feature only contains zero value, _continuousSplits_ will not has the key of feature id. So I add a check when using _continuousSplits_. ## How was this patch tested? Ran model locally using spark-submit. Author: lucio <576632108@qq.com> Closes #20472 from lucio-yz/master. --- .../spark/ml/tree/impl/RandomForest.scala | 52 ++++++++++++++----- .../ml/tree/impl/RandomForestSuite.scala | 23 ++++---- 2 files changed, 50 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 8e514f11e78ea..16f32d76b9984 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -892,13 +892,7 @@ private[spark] object RandomForest extends Logging { // Sample the input only if there are continuous features. val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous) val sampledInput = if (continuousFeatures.nonEmpty) { - // Calculate the number of samples for approximate quantile calculation. - val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) - val fraction = if (requiredSamples < metadata.numExamples) { - requiredSamples.toDouble / metadata.numExamples - } else { - 1.0 - } + val fraction = samplesFractionForFindSplits(metadata) logDebug("fraction of data used for calculating quantiles = " + fraction) input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()) } else { @@ -920,8 +914,9 @@ private[spark] object RandomForest extends Logging { val numPartitions = math.min(continuousFeatures.length, input.partitions.length) input - .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx)))) - .groupByKey(numPartitions) + .flatMap { point => + continuousFeatures.map(idx => (idx, point.features(idx))).filter(_._2 != 0.0) + }.groupByKey(numPartitions) .map { case (idx, samples) => val thresholds = findSplitsForContinuousFeature(samples, metadata, idx) val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh)) @@ -933,7 +928,8 @@ private[spark] object RandomForest extends Logging { val numFeatures = metadata.numFeatures val splits: Array[Array[Split]] = Array.tabulate(numFeatures) { case i if metadata.isContinuous(i) => - val split = continuousSplits(i) + // some features may contain only zero, so continuousSplits will not have a record + val split = continuousSplits.getOrElse(i, Array.empty[Split]) metadata.setNumSplits(i, split.length) split @@ -1003,11 +999,22 @@ private[spark] object RandomForest extends Logging { } else { val numSplits = metadata.numSplits(featureIndex) - // get count for each distinct value - val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) { - case ((m, cnt), x) => - (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1) + // get count for each distinct value except zero value + val partNumSamples = featureSamples.size + val partValueCountMap = scala.collection.mutable.Map[Double, Int]() + featureSamples.foreach { x => + partValueCountMap(x) = partValueCountMap.getOrElse(x, 0) + 1 } + + // Calculate the expected number of samples for finding splits + val numSamples = (samplesFractionForFindSplits(metadata) * metadata.numExamples).toInt + // add expected zero value count and get complete statistics + val valueCountMap: Map[Double, Int] = if (numSamples - partNumSamples > 0) { + partValueCountMap.toMap + (0.0 -> (numSamples - partNumSamples)) + } else { + partValueCountMap.toMap + } + // sort distinct values val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray @@ -1149,4 +1156,21 @@ private[spark] object RandomForest extends Logging { 3 * totalBins } } + + /** + * Calculate the subsample fraction for finding splits + * + * @param metadata decision tree metadata + * @return subsample fraction + */ + private def samplesFractionForFindSplits( + metadata: DecisionTreeMetadata): Double = { + // Calculate the number of samples for approximate quantile calculation. + val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) + if (requiredSamples < metadata.numExamples) { + requiredSamples.toDouble / metadata.numExamples + } else { + 1.0 + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 5f0d26eb5c058..743dacf146fe7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -93,12 +93,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { test("find splits for a continuous feature") { // find splits for normal case { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 200000, 0, 0, Map(), Set(), Array(6), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array.fill(200000)(math.random) + val featureSamples = Array.fill(10000)(math.random).filter(_ != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) assert(splits.length === 5) assert(fakeMetadata.numSplits(0) === 5) @@ -109,7 +109,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // SPARK-16957: Use midpoints for split values. { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 8, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -117,7 +117,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // possibleSplits <= numSplits { - val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble) + val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble).filter(_ != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((0.0 + 1.0) / 2) assert(splits === expectedSplits) @@ -125,7 +125,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // possibleSplits > numSplits { - val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble) + val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble).filter(_ != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2) assert(splits === expectedSplits) @@ -135,7 +135,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits should not return identical splits // when there are not enough split candidates, reduce the number of splits in metadata { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 12, 0, 0, Map(), Set(), Array(5), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -150,7 +150,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most samples close to the minimum { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 18, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -164,12 +164,13 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most samples close to the maximum { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 17, 0, 0, Map(), Set(), Array(2), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) + val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) + .map(_.toDouble).filter(_ != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((1.0 + 2.0) / 2) assert(splits === expectedSplits) @@ -177,12 +178,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits for constant feature { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 3, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(0, 0, 0).map(_.toDouble) + val featureSamples = Array(0, 0, 0).map(_.toDouble).filter(_ != 0.0) val featureSamplesEmpty = Array.empty[Double] val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) assert(splits === Array.empty[Double]) From ea480990e726aed59750f1cea8d40adba56d991a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 8 Mar 2018 11:09:15 -0800 Subject: [PATCH 0455/2461] [SPARK-23628][SQL] calculateParamLength should not return 1 + num of epressions ## What changes were proposed in this pull request? There was a bug in `calculateParamLength` which caused it to return always 1 + the number of expressions. This could lead to Exceptions especially with expressions of type long. ## How was this patch tested? added UT + fixed previous UT Author: Marco Gaido Closes #20772 from mgaido91/SPARK-23628. --- .../expressions/codegen/CodeGenerator.scala | 51 ++++++++++--------- .../expressions/CodeGenerationSuite.scala | 6 +++ .../sql/execution/WholeStageCodegenExec.scala | 5 +- .../execution/WholeStageCodegenSuite.scala | 16 +++--- 4 files changed, 43 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 793824b0b0a2f..fe5e63ec0a2bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1063,31 +1063,6 @@ class CodegenContext { "" } } - - /** - * Returns the length of parameters for a Java method descriptor. `this` contributes one unit - * and a parameter of type long or double contributes two units. Besides, for nullable parameter, - * we also need to pass a boolean parameter for the null status. - */ - def calculateParamLength(params: Seq[Expression]): Int = { - def paramLengthForExpr(input: Expression): Int = { - // For a nullable expression, we need to pass in an extra boolean parameter. - (if (input.nullable) 1 else 0) + javaType(input.dataType) match { - case JAVA_LONG | JAVA_DOUBLE => 2 - case _ => 1 - } - } - // Initial value is 1 for `this`. - 1 + params.map(paramLengthForExpr(_)).sum - } - - /** - * In Java, a method descriptor is valid only if it represents method parameters with a total - * length less than a pre-defined constant. - */ - def isValidParamLength(paramLength: Int): Boolean = { - paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH - } } /** @@ -1538,4 +1513,30 @@ object CodeGenerator extends Logging { def defaultValue(dt: DataType, typedNull: Boolean = false): String = defaultValue(javaType(dt), typedNull) + + /** + * Returns the length of parameters for a Java method descriptor. `this` contributes one unit + * and a parameter of type long or double contributes two units. Besides, for nullable parameter, + * we also need to pass a boolean parameter for the null status. + */ + def calculateParamLength(params: Seq[Expression]): Int = { + def paramLengthForExpr(input: Expression): Int = { + val javaParamLength = javaType(input.dataType) match { + case JAVA_LONG | JAVA_DOUBLE => 2 + case _ => 1 + } + // For a nullable expression, we need to pass in an extra boolean parameter. + (if (input.nullable) 1 else 0) + javaParamLength + } + // Initial value is 1 for `this`. + 1 + params.map(paramLengthForExpr).sum + } + + /** + * In Java, a method descriptor is valid only if it represents method parameters with a total + * length less than a pre-defined constant. + */ + def isValidParamLength(paramLength: Int): Boolean = { + paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 1e48c7b8df9da..64c13e8972036 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -436,4 +436,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { ctx.addImmutableStateIfNotExists("String", mutableState2) assert(ctx.inlinedMutableStates.length == 2) } + + test("SPARK-23628: calculateParamLength should compute properly the param length") { + assert(CodeGenerator.calculateParamLength(Seq.range(0, 100).map(Literal(_))) == 101) + assert(CodeGenerator.calculateParamLength( + Seq.range(0, 100).map(x => Literal(x.toLong))) == 201) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index f89e3fb0e536f..6ddaacfee1a40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -174,8 +174,9 @@ trait CodegenSupport extends SparkPlan { // declaration. val confEnabled = SQLConf.get.wholeStageSplitConsumeFuncByOperator val requireAllOutput = output.forall(parent.usedInputs.contains(_)) - val paramLength = ctx.calculateParamLength(output) + (if (row != null) 1 else 0) - val consumeFunc = if (confEnabled && requireAllOutput && ctx.isValidParamLength(paramLength)) { + val paramLength = CodeGenerator.calculateParamLength(output) + (if (row != null) 1 else 0) + val consumeFunc = if (confEnabled && requireAllOutput + && CodeGenerator.isValidParamLength(paramLength)) { constructDoConsumeFunction(ctx, inputVars, row) } else { parent.doConsume(ctx, inputVars, rowVar) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index ef16292a8e75c..0fb9dd2017a09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.expressions.scalalang.typed -import org.apache.spark.sql.functions.{avg, broadcast, col, max} +import org.apache.spark.sql.functions.{avg, broadcast, col, lit, max} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} @@ -249,12 +249,12 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } test("Skip splitting consume function when parameter number exceeds JVM limit") { - import testImplicits._ - - Seq((255, false), (254, true)).foreach { case (columnNum, hasSplit) => + // since every field is nullable we have 2 params for each input column (one for the value + // and one for the isNull variable) + Seq((128, false), (127, true)).foreach { case (columnNum, hasSplit) => withTempPath { dir => val path = dir.getCanonicalPath - spark.range(10).select(Seq.tabulate(columnNum) {i => ('id + i).as(s"c$i")} : _*) + spark.range(10).select(Seq.tabulate(columnNum) {i => lit(i).as(s"c$i")} : _*) .write.mode(SaveMode.Overwrite).parquet(path) withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255", @@ -263,10 +263,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { val df = spark.read.parquet(path).selectExpr(projection: _*) val plan = df.queryExecution.executedPlan - val wholeStageCodeGenExec = plan.find(p => p match { - case wp: WholeStageCodegenExec => true + val wholeStageCodeGenExec = plan.find { + case _: WholeStageCodegenExec => true case _ => false - }) + } assert(wholeStageCodeGenExec.isDefined) val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 assert(code.body.contains("project_doConsume") == hasSplit) From e7bbca88964d95593fa15eb94643ba519801e352 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 8 Mar 2018 22:02:28 +0100 Subject: [PATCH 0456/2461] [SPARK-23602][SQL] PrintToStderr prints value also in interpreted mode ## What changes were proposed in this pull request? `PrintToStderr` was doing what is it supposed to only when code generation is enabled. The PR adds the same behavior in interpreted mode too. ## How was this patch tested? added UT Author: Marco Gaido Closes #20773 from mgaido91/SPARK-23602. --- .../spark/sql/catalyst/expressions/misc.scala | 7 +++++- .../expressions/MiscExpressionsSuite.scala | 25 +++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 4b9006ab5b423..38e4fe44b15ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -31,7 +31,12 @@ case class PrintToStderr(child: Expression) extends UnaryExpression { override def dataType: DataType = child.dataType - protected override def nullSafeEval(input: Any): Any = input + protected override def nullSafeEval(input: Any): Any = { + // scalastyle:off println + System.err.println(outputPrefix + input) + // scalastyle:on println + input + } private val outputPrefix = s"Result of ${child.simpleString} is " diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index a21c139fe71d0..c3d08bf68c7bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.io.PrintStream + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ @@ -43,4 +45,27 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Length(Uuid()), 36) assert(evaluateWithoutCodegen(Uuid()) !== evaluateWithoutCodegen(Uuid())) } + + test("PrintToStderr") { + val inputExpr = Literal(1) + val systemErr = System.err + + val (outputEval, outputCodegen) = try { + val errorStream = new java.io.ByteArrayOutputStream() + System.setErr(new PrintStream(errorStream)) + // check without codegen + checkEvaluationWithoutCodegen(PrintToStderr(inputExpr), 1) + val outputEval = errorStream.toString + errorStream.reset() + // check with codegen + checkEvaluationWithGeneratedMutableProjection(PrintToStderr(inputExpr), 1) + val outputCodegen = errorStream.toString + (outputEval, outputCodegen) + } finally { + System.setErr(systemErr) + } + + assert(outputCodegen.contains(s"Result of $inputExpr is 1")) + assert(outputEval.contains(s"Result of $inputExpr is 1")) + } } From d90e77bd0ec19f8ba9198a24ec2ab3db7708eca8 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 8 Mar 2018 14:58:40 -0800 Subject: [PATCH 0457/2461] [SPARK-23271][SQL] Parquet output contains only _SUCCESS file after writing an empty dataframe ## What changes were proposed in this pull request? Below are the two cases. ``` SQL case 1 scala> List.empty[String].toDF().rdd.partitions.length res18: Int = 1 ``` When we write the above data frame as parquet, we create a parquet file containing just the schema of the data frame. Case 2 ``` SQL scala> val anySchema = StructType(StructField("anyName", StringType, nullable = false) :: Nil) anySchema: org.apache.spark.sql.types.StructType = StructType(StructField(anyName,StringType,false)) scala> spark.read.schema(anySchema).csv("/tmp/empty_folder").rdd.partitions.length res22: Int = 0 ``` For the 2nd case, since number of partitions = 0, we don't call the write task (the task has logic to create the empty metadata only parquet file) The fix is to create a dummy single partition RDD and set up the write task based on it to ensure the metadata-only file. ## How was this patch tested? A new test is added to DataframeReaderWriterSuite. Author: Dilip Biswal Closes #20525 from dilipbiswal/spark-23271. --- docs/sql-programming-guide.md | 1 + .../datasources/FileFormatWriter.scala | 15 ++++++++++++--- .../spark/sql/FileBasedDataSourceSuite.scala | 18 ++++++++++++++++++ .../sql/test/DataFrameReaderWriterSuite.scala | 1 - 4 files changed, 31 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 451b814ab6c53..d2132d2ae7441 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1805,6 +1805,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unabled to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. + - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 1d80a69bc5a1d..401597f967218 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -190,9 +190,18 @@ object FileFormatWriter extends Logging { global = false, child = plan).execute() } - val ret = new Array[WriteTaskResult](rdd.partitions.length) + + // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single + // partition rdd to make sure we at least set up one write task to write the metadata. + val rddWithNonEmptyPartitions = if (rdd.partitions.length == 0) { + sparkSession.sparkContext.parallelize(Array.empty[InternalRow], 1) + } else { + rdd + } + + val ret = new Array[WriteTaskResult](rddWithNonEmptyPartitions.partitions.length) sparkSession.sparkContext.runJob( - rdd, + rddWithNonEmptyPartitions, (taskContext: TaskContext, iter: Iterator[InternalRow]) => { executeTask( description = description, @@ -202,7 +211,7 @@ object FileFormatWriter extends Logging { committer, iterator = iter) }, - 0 until rdd.partitions.length, + rddWithNonEmptyPartitions.partitions.indices, (index, res: WriteTaskResult) => { committer.onTaskCommit(res.commitMsg) ret(index) = res diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 73e3df3b6202e..bd3071bcf9010 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -89,6 +90,23 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } } + Seq("orc", "parquet").foreach { format => + test(s"SPARK-23271 empty RDD when saved should write a metadata only file - $format") { + withTempPath { outputPath => + val df = spark.emptyDataFrame.select(lit(1).as("i")) + df.write.format(format).save(outputPath.toString) + val partFiles = outputPath.listFiles() + .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_")) + assert(partFiles.length === 1) + + // Now read the file. + val df1 = spark.read.format(format).load(outputPath.toString) + checkAnswer(df1, Seq.empty[Row]) + assert(df1.schema.equals(df.schema.asNullable)) + } + } + } + allFileBasedDataSources.foreach { format => test(s"SPARK-22146 read files containing special characters using $format") { withTempDir { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 8c9bb7d56a35f..a707a88dfa670 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -301,7 +301,6 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be intercept[AnalysisException] { spark.range(10).write.format("csv").mode("overwrite").partitionBy("id").save(path) } - spark.emptyDataFrame.write.format("parquet").mode("overwrite").save(path) } } From 2c3673680e16f88f1d1cd73a3f7445ded5b3daa8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 9 Mar 2018 10:36:38 -0800 Subject: [PATCH 0458/2461] [SPARK-23630][YARN] Allow user's hadoop conf customizations to take effect. This change restores functionality that was inadvertently removed as part of the fix for SPARK-22372. Also modified an existing unit test to make sure the feature works as intended. Author: Marcelo Vanzin Closes #20776 from vanzin/SPARK-23630. --- .../apache/spark/deploy/SparkHadoopUtil.scala | 11 +++++- .../org/apache/spark/deploy/yarn/Client.scala | 14 ++++---- .../spark/deploy/yarn/YarnClusterSuite.scala | 34 ++++++++++++++----- 3 files changed, 44 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index e14f9845e6db6..177295fb7af0f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -111,7 +111,9 @@ class SparkHadoopUtil extends Logging { * subsystems. */ def newConfiguration(conf: SparkConf): Configuration = { - SparkHadoopUtil.newConfiguration(conf) + val hadoopConf = SparkHadoopUtil.newConfiguration(conf) + hadoopConf.addResource(SparkHadoopUtil.SPARK_HADOOP_CONF_FILE) + hadoopConf } /** @@ -435,6 +437,13 @@ object SparkHadoopUtil { */ private[spark] val UPDATE_INPUT_METRICS_INTERVAL_RECORDS = 1000 + /** + * Name of the file containing the gateway's Hadoop configuration, to be overlayed on top of the + * cluster's Hadoop config. It is up to the Spark code launching the application to create + * this file if it's desired. If the file doesn't exist, it will just be ignored. + */ + private[spark] val SPARK_HADOOP_CONF_FILE = "__spark_hadoop_conf__.xml" + def get: SparkHadoopUtil = instance /** diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 8cd3cd9746a3a..28087dee831d1 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -696,7 +696,13 @@ private[spark] class Client( } } - Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey => + // SPARK-23630: during testing, Spark scripts filter out hadoop conf dirs so that user's + // environments do not interfere with tests. This allows a special env variable during + // tests so that custom conf dirs can be used by unit tests. + val confDirs = Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR") ++ + (if (Utils.isTesting) Seq("SPARK_TEST_HADOOP_CONF_DIR") else Nil) + + confDirs.foreach { envKey => sys.env.get(envKey).foreach { path => val dir = new File(path) if (dir.isDirectory()) { @@ -753,7 +759,7 @@ private[spark] class Client( // Save the YARN configuration into a separate file that will be overlayed on top of the // cluster's Hadoop conf. - confStream.putNextEntry(new ZipEntry(SPARK_HADOOP_CONF_FILE)) + confStream.putNextEntry(new ZipEntry(SparkHadoopUtil.SPARK_HADOOP_CONF_FILE)) hadoopConf.writeXml(confStream) confStream.closeEntry() @@ -1220,10 +1226,6 @@ private object Client extends Logging { // Name of the file in the conf archive containing Spark configuration. val SPARK_CONF_FILE = "__spark_conf__.properties" - // Name of the file containing the gateway's Hadoop configuration, to be overlayed on top of the - // cluster's Hadoop config. - val SPARK_HADOOP_CONF_FILE = "__spark_hadoop_conf__.xml" - // Subdirectory where the user's python files (not archives) will be placed. val LOCALIZED_PYTHON_DIR = "__pyfiles__" diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 5003326b440bf..33d400a5b1b2e 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -114,12 +114,25 @@ class YarnClusterSuite extends BaseYarnClusterSuite { )) } - test("yarn-cluster should respect conf overrides in SparkHadoopUtil (SPARK-16414)") { + test("yarn-cluster should respect conf overrides in SparkHadoopUtil (SPARK-16414, SPARK-23630)") { + // Create a custom hadoop config file, to make sure it's contents are propagated to the driver. + val customConf = Utils.createTempDir() + val coreSite = """ + | + | + | spark.test.key + | testvalue + | + | + |""".stripMargin + Files.write(coreSite, new File(customConf, "core-site.xml"), StandardCharsets.UTF_8) + val result = File.createTempFile("result", null, tempDir) val finalState = runSpark(false, mainClassName(YarnClusterDriverUseSparkHadoopUtilConf.getClass), - appArgs = Seq("key=value", result.getAbsolutePath()), - extraConf = Map("spark.hadoop.key" -> "value")) + appArgs = Seq("key=value", "spark.test.key=testvalue", result.getAbsolutePath()), + extraConf = Map("spark.hadoop.key" -> "value"), + extraEnv = Map("SPARK_TEST_HADOOP_CONF_DIR" -> customConf.getAbsolutePath())) checkResult(finalState, result) } @@ -319,13 +332,13 @@ private object YarnClusterDriverWithFailure extends Logging with Matchers { private object YarnClusterDriverUseSparkHadoopUtilConf extends Logging with Matchers { def main(args: Array[String]): Unit = { - if (args.length != 2) { + if (args.length < 2) { // scalastyle:off println System.err.println( s""" |Invalid command line: ${args.mkString(" ")} | - |Usage: YarnClusterDriverUseSparkHadoopUtilConf [hadoopConfKey=value] [result file] + |Usage: YarnClusterDriverUseSparkHadoopUtilConf [hadoopConfKey=value]+ [result file] """.stripMargin) // scalastyle:on println System.exit(1) @@ -335,11 +348,16 @@ private object YarnClusterDriverUseSparkHadoopUtilConf extends Logging with Matc .set("spark.extraListeners", classOf[SaveExecutorInfo].getName) .setAppName("yarn test using SparkHadoopUtil's conf")) - val kv = args(0).split("=") - val status = new File(args(1)) + val kvs = args.take(args.length - 1).map { kv => + val parsed = kv.split("=") + (parsed(0), parsed(1)) + } + val status = new File(args.last) var result = "failure" try { - SparkHadoopUtil.get.conf.get(kv(0)) should be (kv(1)) + kvs.foreach { case (k, v) => + SparkHadoopUtil.get.conf.get(k) should be (v) + } result = "success" } finally { Files.write(result, status, StandardCharsets.UTF_8) From 2ca9bb083c515511d2bfee271fc3e0269aceb9d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20=C5=9Awitakowski?= Date: Fri, 9 Mar 2018 14:29:31 -0800 Subject: [PATCH 0459/2461] [SPARK-23173][SQL] Avoid creating corrupt parquet files when loading data from JSON MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? The from_json() function accepts an additional parameter, where the user might specify the schema. The issue is that the specified schema might not be compatible with data. In particular, the JSON data might be missing data for fields declared as non-nullable in the schema. The from_json() function does not verify the data against such errors. When data with missing fields is sent to the parquet encoder, there is no verification either. The end results is a corrupt parquet file. To avoid corruptions, make sure that all fields in the user-specified schema are set to be nullable. Since this changes the behavior of a public function, we need to include it in release notes. The behavior can be reverted by setting `spark.sql.fromJsonForceNullableSchema=false` ## How was this patch tested? Added two new tests. Author: Michał Świtakowski Closes #20694 from mswit-databricks/SPARK-23173. --- .../expressions/jsonExpressions.scala | 22 +++++++++----- .../apache/spark/sql/internal/SQLConf.scala | 8 +++++ .../expressions/JsonExpressionsSuite.scala | 30 ++++++++++++++++++- .../datasources/parquet/ParquetIOSuite.scala | 19 ++++++++++++ 4 files changed, 70 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 18b4fed597447..fdd672c416a03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -515,10 +516,15 @@ case class JsonToStructs( child: Expression, timeZoneId: Option[String] = None) extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { - override def nullable: Boolean = true - def this(schema: DataType, options: Map[String, String], child: Expression) = - this(schema, options, child, None) + val forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA) + + // The JSON input data might be missing certain fields. We force the nullability + // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder + // can generate incorrect files if values are missing in columns declared as non-nullable. + val nullableSchema = if (forceNullableSchema) schema.asNullable else schema + + override def nullable: Boolean = true // Used in `FunctionRegistry` def this(child: Expression, schema: Expression) = @@ -535,22 +541,22 @@ case class JsonToStructs( child = child, timeZoneId = None) - override def checkInputDataTypes(): TypeCheckResult = schema match { + override def checkInputDataTypes(): TypeCheckResult = nullableSchema match { case _: StructType | ArrayType(_: StructType, _) => super.checkInputDataTypes() case _ => TypeCheckResult.TypeCheckFailure( - s"Input schema ${schema.simpleString} must be a struct or an array of structs.") + s"Input schema ${nullableSchema.simpleString} must be a struct or an array of structs.") } @transient - lazy val rowSchema = schema match { + lazy val rowSchema = nullableSchema match { case st: StructType => st case ArrayType(st: StructType, _) => st } // This converts parsed rows to the desired output by the given schema. @transient - lazy val converter = schema match { + lazy val converter = nullableSchema match { case _: StructType => (rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null case ArrayType(_: StructType, _) => @@ -563,7 +569,7 @@ case class JsonToStructs( rowSchema, new JSONOptions(options + ("mode" -> FailFastMode.name), timeZoneId.get)) - override def dataType: DataType = schema + override def dataType: DataType = nullableSchema override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3f96112659c11..11864bd1b1847 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -493,6 +493,14 @@ object SQLConf { .stringConf .createWithDefault("_corrupt_record") + val FROM_JSON_FORCE_NULLABLE_SCHEMA = buildConf("spark.sql.fromJsonForceNullableSchema") + .internal() + .doc("When true, force the output schema of the from_json() function to be nullable " + + "(including all the fields). Otherwise, the schema might not be compatible with" + + "actual data, which leads to curruptions.") + .booleanConf + .createWithDefault(true) + val BROADCAST_TIMEOUT = buildConf("spark.sql.broadcastTimeout") .doc("Timeout in seconds for the broadcast wait time in broadcast joins.") .timeConf(TimeUnit.SECONDS) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index a0bbe02f92354..7812319756eae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -22,11 +22,13 @@ import java.util.Calendar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeTestUtils, DateTimeUtils, GenericArrayData, PermissiveMode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { +class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with PlanTestBase { val json = """ |{"store":{"fruit":[{"weight":8,"type":"apple"},{"weight":9,"type":"pear"}], @@ -680,4 +682,30 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ) } } + + test("from_json missing fields") { + for (forceJsonNullableSchema <- Seq(false, true)) { + withSQLConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA.key -> forceJsonNullableSchema.toString) { + val input = + """{ + | "a": 1, + | "c": "foo" + |} + |""".stripMargin + val jsonSchema = new StructType() + .add("a", LongType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + val output = InternalRow(1L, null, UTF8String.fromString("foo")) + checkEvaluation( + JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId), + output + ) + val schema = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId) + .dataType + val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema + assert(schemaToCompare == schema) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 3af80930ec807..0b3e8ca060d87 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -780,6 +781,24 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { assert(option.compressionCodecClassName == "UNCOMPRESSED") } } + + test("SPARK-23173 Writing a file with data converted from JSON with and incorrect user schema") { + withTempPath { file => + val jsonData = + """{ + | "a": 1, + | "c": "foo" + |} + |""".stripMargin + val jsonSchema = new StructType() + .add("a", LongType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + spark.range(1).select(from_json(lit(jsonData), jsonSchema) as "input") + .write.parquet(file.getAbsolutePath) + checkAnswer(spark.read.parquet(file.getAbsolutePath), Seq(Row(Row(1, null, "foo")))) + } + } } class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) From 10b0657b035641ce735055bba2c8459e71bc2400 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Fri, 9 Mar 2018 15:41:19 -0800 Subject: [PATCH 0460/2461] [SPARK-23624][SQL] Revise doc of method pushFilters in Datasource V2 ## What changes were proposed in this pull request? Revise doc of method pushFilters in SupportsPushDownFilters/SupportsPushDownCatalystFilters In `FileSourceStrategy`, except `partitionKeyFilters`(the references of which is subset of partition keys), all filters needs to be evaluated after scanning. Otherwise, Spark will get wrong result from data sources like Orc/Parquet. This PR is to improve the doc. Author: Wang Gengliang Closes #20769 from gengliangwang/revise_pushdown_doc. --- .../sql/sources/v2/reader/SupportsPushDownCatalystFilters.java | 2 +- .../spark/sql/sources/v2/reader/SupportsPushDownFilters.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index 98224102374aa..290d614805ac7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -34,7 +34,7 @@ public interface SupportsPushDownCatalystFilters extends DataSourceReader { /** - * Pushes down filters, and returns unsupported filters. + * Pushes down filters, and returns filters that need to be evaluated after scanning. */ Expression[] pushCatalystFilters(Expression[] filters); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index f35c711b0387a..1cff024232a44 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -32,7 +32,7 @@ public interface SupportsPushDownFilters extends DataSourceReader { /** - * Pushes down filters, and returns unsupported filters. + * Pushes down filters, and returns filters that need to be evaluated after scanning. */ Filter[] pushFilters(Filter[] filters); From 1a54f48b6744032b16543594651ee6d5e3ad4bda Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 9 Mar 2018 15:54:55 -0800 Subject: [PATCH 0461/2461] [SPARK-23510][SQL][FOLLOW-UP] Support Hive 2.2 and Hive 2.3 metastore ## What changes were proposed in this pull request? In the PR https://github.com/apache/spark/pull/20671, I forgot to update the doc about this new support. ## How was this patch tested? N/A Author: gatorsmile Closes #20789 from gatorsmile/docUpdate. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d2132d2ae7441..0e092e0e37ccf 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -2229,7 +2229,7 @@ referencing a singleton. Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 2.1.1. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). +(from 0.12.0 to 2.3.2. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses From b6f837c9d3cb0f76f0a52df37e34aea8944f6867 Mon Sep 17 00:00:00 2001 From: DylanGuedes Date: Sat, 10 Mar 2018 19:48:29 +0900 Subject: [PATCH 0462/2461] [PYTHON] Changes input variable to not conflict with built-in function Signed-off-by: DylanGuedes ## What changes were proposed in this pull request? Changes variable name conflict: [input is a built-in python function](https://stackoverflow.com/questions/20670732/is-input-a-keyword-in-python). ## How was this patch tested? I runned the example and it works fine. Author: DylanGuedes Closes #20775 from DylanGuedes/input_variable. --- examples/src/main/python/ml/dataframe_example.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/src/main/python/ml/dataframe_example.py b/examples/src/main/python/ml/dataframe_example.py index d62cf2338a1fe..cabc3de68f2f4 100644 --- a/examples/src/main/python/ml/dataframe_example.py +++ b/examples/src/main/python/ml/dataframe_example.py @@ -17,7 +17,7 @@ """ An example of how to use DataFrame for ML. Run with:: - bin/spark-submit examples/src/main/python/ml/dataframe_example.py + bin/spark-submit examples/src/main/python/ml/dataframe_example.py """ from __future__ import print_function @@ -35,18 +35,18 @@ print("Usage: dataframe_example.py ", file=sys.stderr) sys.exit(-1) elif len(sys.argv) == 2: - input = sys.argv[1] + input_path = sys.argv[1] else: - input = "data/mllib/sample_libsvm_data.txt" + input_path = "data/mllib/sample_libsvm_data.txt" spark = SparkSession \ .builder \ .appName("DataFrameExample") \ .getOrCreate() - # Load input data - print("Loading LIBSVM file with UDT from " + input + ".") - df = spark.read.format("libsvm").load(input).cache() + # Load an input file + print("Loading LIBSVM file with UDT from " + input_path + ".") + df = spark.read.format("libsvm").load(input_path).cache() print("Schema from LIBSVM:") df.printSchema() print("Loaded training data as a DataFrame with " + From b304e07e0671faf96530f9d8f49c55a83b07fa15 Mon Sep 17 00:00:00 2001 From: Xiayun Sun Date: Mon, 12 Mar 2018 22:13:28 +0900 Subject: [PATCH 0463/2461] [SPARK-23462][SQL] improve missing field error message in `StructType` ## What changes were proposed in this pull request? The error message ```s"""Field "$name" does not exist."""``` is thrown when looking up an unknown field in StructType. In the error message, we should also contain the information about which columns/fields exist in this struct. ## How was this patch tested? Added new unit tests. Note: I created a new `StructTypeSuite.scala` as I couldn't find an existing suite that's suitable to place these tests. I may be missing something so feel free to propose new locations. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Xiayun Sun Closes #20649 from xysun/SPARK-23462. --- .../apache/spark/sql/types/StructType.scala | 11 +++-- .../spark/sql/types/StructTypeSuite.scala | 40 +++++++++++++++++++ 2 files changed, 48 insertions(+), 3 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index d5011c3cb87e9..362676b252126 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -271,7 +271,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru */ def apply(name: String): StructField = { nameToField.getOrElse(name, - throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) + throw new IllegalArgumentException( + s"""Field "$name" does not exist. + |Available fields: ${fieldNames.mkString(", ")}""".stripMargin)) } /** @@ -284,7 +286,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru val nonExistFields = names -- fieldNamesSet if (nonExistFields.nonEmpty) { throw new IllegalArgumentException( - s"Field ${nonExistFields.mkString(",")} does not exist.") + s"""Nonexistent field(s): ${nonExistFields.mkString(", ")}. + |Available fields: ${fieldNames.mkString(", ")}""".stripMargin) } // Preserve the original order of fields. StructType(fields.filter(f => names.contains(f.name))) @@ -297,7 +300,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru */ def fieldIndex(name: String): Int = { nameToIndex.getOrElse(name, - throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) + throw new IllegalArgumentException( + s"""Field "$name" does not exist. + |Available fields: ${fieldNames.mkString(", ")}""".stripMargin)) } private[sql] def getFieldIndex(name: String): Option[Int] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala new file mode 100644 index 0000000000000..c6ca8bb005429 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.apache.spark.SparkFunSuite + +class StructTypeSuite extends SparkFunSuite { + + val s = StructType.fromDDL("a INT, b STRING") + + test("lookup a single missing field should output existing fields") { + val e = intercept[IllegalArgumentException](s("c")).getMessage + assert(e.contains("Available fields: a, b")) + } + + test("lookup a set of missing fields should output existing fields") { + val e = intercept[IllegalArgumentException](s(Set("a", "c"))).getMessage + assert(e.contains("Available fields: a, b")) + } + + test("lookup fieldIndex for missing field should output existing fields") { + val e = intercept[IllegalArgumentException](s.fieldIndex("c")).getMessage + assert(e.contains("Available fields: a, b")) + } +} From d5b41aea62201cd5b1baad2f68f5fc7eb99c62c5 Mon Sep 17 00:00:00 2001 From: Jooseong Kim Date: Mon, 12 Mar 2018 11:31:34 -0700 Subject: [PATCH 0464/2461] [SPARK-23618][K8S][BUILD] Initialize BUILD_ARGS in docker-image-tool.sh ## What changes were proposed in this pull request? This change initializes BUILD_ARGS to an empty array when $SPARK_HOME/RELEASE exists. In function build, "local BUILD_ARGS" effectively creates an array of one element where the first and only element is an empty string, so "${BUILD_ARGS[]}" expands to "" and passes an extra argument to docker. Setting BUILD_ARGS to an empty array makes "${BUILD_ARGS[]}" expand to nothing. ## How was this patch tested? Manually tested. $ cat RELEASE Spark 2.3.0 (git revision a0d7949896) built for Hadoop 2.7.3 Build flags: -Phadoop-2.7 -Phive -Phive-thriftserver -Pkafka-0-8 -Pmesos -Pyarn -Pkubernetes -Pflume -Psparkr -DzincPort=3036 $ ./bin/docker-image-tool.sh -m t testing build Sending build context to Docker daemon 256.4MB ... vanzin Author: Jooseong Kim Closes #20791 from jooseong/SPARK-23618. --- bin/docker-image-tool.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index 071406336d1b1..0d0f564bb8b9b 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -57,6 +57,7 @@ function build { else # Not passed as an argument to docker, but used to validate the Spark directory. IMG_PATH="kubernetes/dockerfiles" + BUILD_ARGS=() fi if [ ! -d "$IMG_PATH" ]; then From 567bd31e0ae8b632357baa93e1469b666fb06f3d Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 12 Mar 2018 14:53:15 -0500 Subject: [PATCH 0465/2461] [SPARK-23412][ML] Add cosine distance to BisectingKMeans ## What changes were proposed in this pull request? The PR adds the option to specify a distance measure in BisectingKMeans. Moreover, it introduces the ability to use the cosine distance measure in it. ## How was this patch tested? added UTs + existing UTs Author: Marco Gaido Closes #20600 from mgaido91/SPARK-23412. --- .../spark/ml/clustering/BisectingKMeans.scala | 16 +- .../apache/spark/ml/clustering/KMeans.scala | 11 +- .../ml/param/shared/SharedParamsCodeGen.scala | 6 +- .../spark/ml/param/shared/sharedParams.scala | 19 ++ .../mllib/clustering/BisectingKMeans.scala | 139 ++++---- .../clustering/BisectingKMeansModel.scala | 115 +++++-- .../mllib/clustering/DistanceMeasure.scala | 303 ++++++++++++++++++ .../spark/mllib/clustering/KMeans.scala | 196 +---------- .../ml/clustering/BisectingKMeansSuite.scala | 44 ++- project/MimaExcludes.scala | 6 + 10 files changed, 557 insertions(+), 298 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 4c20e6563bad1..f7c422dc0faea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -26,7 +26,8 @@ import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} +import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, + BisectingKMeansModel => MLlibBisectingKMeansModel} import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD @@ -38,8 +39,8 @@ import org.apache.spark.sql.types.{IntegerType, StructType} /** * Common params for BisectingKMeans and BisectingKMeansModel */ -private[clustering] trait BisectingKMeansParams extends Params - with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol { +private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter + with HasFeaturesCol with HasSeed with HasPredictionCol with HasDistanceMeasure { /** * The desired number of leaf clusters. Must be > 1. Default: 4. @@ -104,6 +105,10 @@ class BisectingKMeansModel private[ml] ( @Since("2.1.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group expertSetParam */ + @Since("2.4.0") + def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) @@ -248,6 +253,10 @@ class BisectingKMeans @Since("2.0.0") ( @Since("2.0.0") def setMinDivisibleClusterSize(value: Double): this.type = set(minDivisibleClusterSize, value) + /** @group expertSetParam */ + @Since("2.4.0") + def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): BisectingKMeansModel = { transformSchema(dataset.schema, logging = true) @@ -263,6 +272,7 @@ class BisectingKMeans @Since("2.0.0") ( .setMaxIterations($(maxIter)) .setMinDivisibleClusterSize($(minDivisibleClusterSize)) .setSeed($(seed)) + .setDistanceMeasure($(distanceMeasure)) val parentModel = bkm.run(rdd) val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this)) val summary = new BisectingKMeansSummary( diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index c8145de564cbe..987a4285ebad4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -40,7 +40,7 @@ import org.apache.spark.util.VersionUtils.majorVersion * Common params for KMeans and KMeansModel */ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol - with HasSeed with HasPredictionCol with HasTol { + with HasSeed with HasPredictionCol with HasTol with HasDistanceMeasure { /** * The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than @@ -71,15 +71,6 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") def getInitMode: String = $(initMode) - @Since("2.4.0") - final val distanceMeasure = new Param[String](this, "distanceMeasure", "The distance measure. " + - "Supported options: 'euclidean' and 'cosine'.", - (value: String) => MLlibKMeans.validateDistanceMeasure(value)) - - /** @group expertGetParam */ - @Since("2.4.0") - def getDistanceMeasure: String = $(distanceMeasure) - /** * Param for the number of steps for the k-means|| initialization mode. This is an advanced * setting -- the default of 2 is almost always enough. Must be > 0. Default: 2. diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 6ad44af9ef7eb..b9c3170cc3c28 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -91,7 +91,11 @@ private[shared] object SharedParamsCodeGen { "after fitting. If set to true, then all sub-models will be available. Warning: For " + "large models, collecting all sub-models can cause OOMs on the Spark driver", Some("false"), isExpertParam = true), - ParamDesc[String]("loss", "the loss function to be optimized", finalFields = false) + ParamDesc[String]("loss", "the loss function to be optimized", finalFields = false), + ParamDesc[String]("distanceMeasure", "The distance measure. Supported options: 'euclidean'" + + " and 'cosine'", Some("org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN"), + isValid = "(value: String) => " + + "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)") ) val code = genSharedParams(params) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index be8b2f273164b..282ea6ebcbf7f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -504,4 +504,23 @@ trait HasLoss extends Params { /** @group getParam */ final def getLoss: String = $(loss) } + +/** + * Trait for shared param distanceMeasure (default: org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN). This trait may be changed or + * removed between minor versions. + */ +@DeveloperApi +trait HasDistanceMeasure extends Params { + + /** + * Param for The distance measure. Supported options: 'euclidean' and 'cosine'. + * @group param + */ + final val distanceMeasure: Param[String] = new Param[String](this, "distanceMeasure", "The distance measure. Supported options: 'euclidean' and 'cosine'", (value: String) => org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)) + + setDefault(distanceMeasure, org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN) + + /** @group getParam */ + final def getDistanceMeasure: String = $(distanceMeasure) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala index 2221f4c0edc17..98af487306dcc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -25,7 +25,7 @@ import scala.collection.mutable import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging -import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -57,7 +57,8 @@ class BisectingKMeans private ( private var k: Int, private var maxIterations: Int, private var minDivisibleClusterSize: Double, - private var seed: Long) extends Logging { + private var seed: Long, + private var distanceMeasure: String) extends Logging { import BisectingKMeans._ @@ -65,7 +66,7 @@ class BisectingKMeans private ( * Constructs with the default configuration */ @Since("1.6.0") - def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##) + def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##, DistanceMeasure.EUCLIDEAN) /** * Sets the desired number of leaf clusters (default: 4). @@ -134,6 +135,22 @@ class BisectingKMeans private ( @Since("1.6.0") def getSeed: Long = this.seed + /** + * The distance suite used by the algorithm. + */ + @Since("2.4.0") + def getDistanceMeasure: String = distanceMeasure + + /** + * Set the distance suite used by the algorithm. + */ + @Since("2.4.0") + def setDistanceMeasure(distanceMeasure: String): this.type = { + DistanceMeasure.validateDistanceMeasure(distanceMeasure) + this.distanceMeasure = distanceMeasure + this + } + /** * Runs the bisecting k-means algorithm. * @param input RDD of vectors @@ -147,11 +164,13 @@ class BisectingKMeans private ( } val d = input.map(_.size).first() logInfo(s"Feature dimension: $d.") + + val dMeasure: DistanceMeasure = DistanceMeasure.decodeFromString(this.distanceMeasure) // Compute and cache vector norms for fast distance computation. val norms = input.map(v => Vectors.norm(v, 2.0)).persist(StorageLevel.MEMORY_AND_DISK) val vectors = input.zip(norms).map { case (x, norm) => new VectorWithNorm(x, norm) } var assignments = vectors.map(v => (ROOT_INDEX, v)) - var activeClusters = summarize(d, assignments) + var activeClusters = summarize(d, assignments, dMeasure) val rootSummary = activeClusters(ROOT_INDEX) val n = rootSummary.size logInfo(s"Number of points: $n.") @@ -184,24 +203,25 @@ class BisectingKMeans private ( val divisibleIndices = divisibleClusters.keys.toSet logInfo(s"Dividing ${divisibleIndices.size} clusters on level $level.") var newClusterCenters = divisibleClusters.flatMap { case (index, summary) => - val (left, right) = splitCenter(summary.center, random) + val (left, right) = splitCenter(summary.center, random, dMeasure) Iterator((leftChildIndex(index), left), (rightChildIndex(index), right)) }.map(identity) // workaround for a Scala bug (SI-7005) that produces a not serializable map var newClusters: Map[Long, ClusterSummary] = null var newAssignments: RDD[(Long, VectorWithNorm)] = null for (iter <- 0 until maxIterations) { - newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters) + newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters, + dMeasure) .filter { case (index, _) => divisibleIndices.contains(parentIndex(index)) } - newClusters = summarize(d, newAssignments) + newClusters = summarize(d, newAssignments, dMeasure) newClusterCenters = newClusters.mapValues(_.center).map(identity) } if (preIndices != null) { preIndices.unpersist(false) } preIndices = indices - indices = updateAssignments(assignments, divisibleIndices, newClusterCenters).keys + indices = updateAssignments(assignments, divisibleIndices, newClusterCenters, dMeasure).keys .persist(StorageLevel.MEMORY_AND_DISK) assignments = indices.zip(vectors) inactiveClusters ++= activeClusters @@ -222,8 +242,8 @@ class BisectingKMeans private ( } norms.unpersist(false) val clusters = activeClusters ++ inactiveClusters - val root = buildTree(clusters) - new BisectingKMeansModel(root) + val root = buildTree(clusters, dMeasure) + new BisectingKMeansModel(root, this.distanceMeasure) } /** @@ -266,8 +286,9 @@ private object BisectingKMeans extends Serializable { */ private def summarize( d: Int, - assignments: RDD[(Long, VectorWithNorm)]): Map[Long, ClusterSummary] = { - assignments.aggregateByKey(new ClusterSummaryAggregator(d))( + assignments: RDD[(Long, VectorWithNorm)], + distanceMeasure: DistanceMeasure): Map[Long, ClusterSummary] = { + assignments.aggregateByKey(new ClusterSummaryAggregator(d, distanceMeasure))( seqOp = (agg, v) => agg.add(v), combOp = (agg1, agg2) => agg1.merge(agg2) ).mapValues(_.summary) @@ -278,7 +299,8 @@ private object BisectingKMeans extends Serializable { * Cluster summary aggregator. * @param d feature dimension */ - private class ClusterSummaryAggregator(val d: Int) extends Serializable { + private class ClusterSummaryAggregator(val d: Int, val distanceMeasure: DistanceMeasure) + extends Serializable { private var n: Long = 0L private val sum: Vector = Vectors.zeros(d) private var sumSq: Double = 0.0 @@ -288,7 +310,7 @@ private object BisectingKMeans extends Serializable { n += 1L // TODO: use a numerically stable approach to estimate cost sumSq += v.norm * v.norm - BLAS.axpy(1.0, v.vector, sum) + distanceMeasure.updateClusterSum(v, sum) this } @@ -296,19 +318,15 @@ private object BisectingKMeans extends Serializable { def merge(other: ClusterSummaryAggregator): this.type = { n += other.n sumSq += other.sumSq - BLAS.axpy(1.0, other.sum, sum) + distanceMeasure.updateClusterSum(new VectorWithNorm(other.sum), sum) this } /** Returns the summary. */ def summary: ClusterSummary = { - val mean = sum.copy - if (n > 0L) { - BLAS.scal(1.0 / n, mean) - } - val center = new VectorWithNorm(mean) - val cost = math.max(sumSq - n * center.norm * center.norm, 0.0) - new ClusterSummary(n, center, cost) + val center = distanceMeasure.centroid(sum.copy, n) + val cost = distanceMeasure.clusterCost(center, new VectorWithNorm(sum), n, sumSq) + ClusterSummary(n, center, cost) } } @@ -321,16 +339,13 @@ private object BisectingKMeans extends Serializable { */ private def splitCenter( center: VectorWithNorm, - random: Random): (VectorWithNorm, VectorWithNorm) = { + random: Random, + distanceMeasure: DistanceMeasure): (VectorWithNorm, VectorWithNorm) = { val d = center.vector.size val norm = center.norm val level = 1e-4 * norm val noise = Vectors.dense(Array.fill(d)(random.nextDouble())) - val left = center.vector.copy - BLAS.axpy(-level, noise, left) - val right = center.vector.copy - BLAS.axpy(level, noise, right) - (new VectorWithNorm(left), new VectorWithNorm(right)) + distanceMeasure.symmetricCentroids(level, noise, center.vector) } /** @@ -343,16 +358,20 @@ private object BisectingKMeans extends Serializable { private def updateAssignments( assignments: RDD[(Long, VectorWithNorm)], divisibleIndices: Set[Long], - newClusterCenters: Map[Long, VectorWithNorm]): RDD[(Long, VectorWithNorm)] = { + newClusterCenters: Map[Long, VectorWithNorm], + distanceMeasure: DistanceMeasure): RDD[(Long, VectorWithNorm)] = { assignments.map { case (index, v) => if (divisibleIndices.contains(index)) { val children = Seq(leftChildIndex(index), rightChildIndex(index)) - val newClusterChildren = children.filter(newClusterCenters.contains(_)) + val newClusterChildren = children.filter(newClusterCenters.contains) + val newClusterChildrenCenterToId = + newClusterChildren.map(id => newClusterCenters(id) -> id).toMap + val newClusterChildrenCenters = newClusterChildrenCenterToId.keys.toArray if (newClusterChildren.nonEmpty) { - val selected = newClusterChildren.minBy { child => - EuclideanDistanceMeasure.fastSquaredDistance(newClusterCenters(child), v) - } - (selected, v) + val selected = distanceMeasure.findClosest(newClusterChildrenCenters, v)._1 + val center = newClusterChildrenCenters(selected) + val id = newClusterChildrenCenterToId(center) + (id, v) } else { (index, v) } @@ -367,7 +386,9 @@ private object BisectingKMeans extends Serializable { * @param clusters a map from cluster indices to corresponding cluster summaries * @return the root node of the clustering tree */ - private def buildTree(clusters: Map[Long, ClusterSummary]): ClusteringTreeNode = { + private def buildTree( + clusters: Map[Long, ClusterSummary], + distanceMeasure: DistanceMeasure): ClusteringTreeNode = { var leafIndex = 0 var internalIndex = -1 @@ -385,11 +406,11 @@ private object BisectingKMeans extends Serializable { internalIndex -= 1 val leftIndex = leftChildIndex(rawIndex) val rightIndex = rightChildIndex(rawIndex) - val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_)) - val height = math.sqrt(indexes.map { childIndex => - EuclideanDistanceMeasure.fastSquaredDistance(center, clusters(childIndex).center) - }.max) - val children = indexes.map(buildSubTree(_)).toArray + val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains) + val height = indexes.map { childIndex => + distanceMeasure.distance(center, clusters(childIndex).center) + }.max + val children = indexes.map(buildSubTree).toArray new ClusteringTreeNode(index, size, center, cost, height, children) } else { val index = leafIndex @@ -441,42 +462,45 @@ private[clustering] class ClusteringTreeNode private[clustering] ( def center: Vector = centerWithNorm.vector /** Predicts the leaf cluster node index that the input point belongs to. */ - def predict(point: Vector): Int = { - val (index, _) = predict(new VectorWithNorm(point)) + def predict(point: Vector, distanceMeasure: DistanceMeasure): Int = { + val (index, _) = predict(new VectorWithNorm(point), distanceMeasure) index } /** Returns the full prediction path from root to leaf. */ - def predictPath(point: Vector): Array[ClusteringTreeNode] = { - predictPath(new VectorWithNorm(point)).toArray + def predictPath(point: Vector, distanceMeasure: DistanceMeasure): Array[ClusteringTreeNode] = { + predictPath(new VectorWithNorm(point), distanceMeasure).toArray } /** Returns the full prediction path from root to leaf. */ - private def predictPath(pointWithNorm: VectorWithNorm): List[ClusteringTreeNode] = { + private def predictPath( + pointWithNorm: VectorWithNorm, + distanceMeasure: DistanceMeasure): List[ClusteringTreeNode] = { if (isLeaf) { this :: Nil } else { val selected = children.minBy { child => - EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm) + distanceMeasure.distance(child.centerWithNorm, pointWithNorm) } - selected :: selected.predictPath(pointWithNorm) + selected :: selected.predictPath(pointWithNorm, distanceMeasure) } } /** - * Computes the cost (squared distance to the predicted leaf cluster center) of the input point. + * Computes the cost of the input point. */ - def computeCost(point: Vector): Double = { - val (_, cost) = predict(new VectorWithNorm(point)) + def computeCost(point: Vector, distanceMeasure: DistanceMeasure): Double = { + val (_, cost) = predict(new VectorWithNorm(point), distanceMeasure) cost } /** * Predicts the cluster index and the cost of the input point. */ - private def predict(pointWithNorm: VectorWithNorm): (Int, Double) = { - predict(pointWithNorm, - EuclideanDistanceMeasure.fastSquaredDistance(centerWithNorm, pointWithNorm)) + private def predict( + pointWithNorm: VectorWithNorm, + distanceMeasure: DistanceMeasure): (Int, Double) = { + predict(pointWithNorm, distanceMeasure.cost(centerWithNorm, pointWithNorm), distanceMeasure) } /** @@ -486,14 +510,17 @@ private[clustering] class ClusteringTreeNode private[clustering] ( * @return (predicted leaf cluster index, cost) */ @tailrec - private def predict(pointWithNorm: VectorWithNorm, cost: Double): (Int, Double) = { + private def predict( + pointWithNorm: VectorWithNorm, + cost: Double, + distanceMeasure: DistanceMeasure): (Int, Double) = { if (isLeaf) { (index, cost) } else { val (selectedChild, minCost) = children.map { child => - (child, EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm)) + (child, distanceMeasure.cost(child.centerWithNorm, pointWithNorm)) }.minBy(_._2) - selectedChild.predict(pointWithNorm, minCost) + selectedChild.predict(pointWithNorm, minCost, distanceMeasure) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala index 633bda6aac804..9d115afcea75d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala @@ -40,9 +40,16 @@ import org.apache.spark.sql.{Row, SparkSession} */ @Since("1.6.0") class BisectingKMeansModel private[clustering] ( - private[clustering] val root: ClusteringTreeNode + private[clustering] val root: ClusteringTreeNode, + @Since("2.4.0") val distanceMeasure: String ) extends Serializable with Saveable with Logging { + @Since("1.6.0") + def this(root: ClusteringTreeNode) = this(root, DistanceMeasure.EUCLIDEAN) + + private val distanceMeasureInstance: DistanceMeasure = + DistanceMeasure.decodeFromString(distanceMeasure) + /** * Leaf cluster centers. */ @@ -59,7 +66,7 @@ class BisectingKMeansModel private[clustering] ( */ @Since("1.6.0") def predict(point: Vector): Int = { - root.predict(point) + root.predict(point, distanceMeasureInstance) } /** @@ -67,7 +74,7 @@ class BisectingKMeansModel private[clustering] ( */ @Since("1.6.0") def predict(points: RDD[Vector]): RDD[Int] = { - points.map { p => root.predict(p) } + points.map { p => root.predict(p, distanceMeasureInstance) } } /** @@ -82,7 +89,7 @@ class BisectingKMeansModel private[clustering] ( */ @Since("1.6.0") def computeCost(point: Vector): Double = { - root.computeCost(point) + root.computeCost(point, distanceMeasureInstance) } /** @@ -91,7 +98,7 @@ class BisectingKMeansModel private[clustering] ( */ @Since("1.6.0") def computeCost(data: RDD[Vector]): Double = { - data.map(root.computeCost).sum() + data.map(root.computeCost(_, distanceMeasureInstance)).sum() } /** @@ -113,18 +120,19 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { @Since("2.0.0") override def load(sc: SparkContext, path: String): BisectingKMeansModel = { - val (loadedClassName, formatVersion, metadata) = Loader.loadMetadata(sc, path) - implicit val formats = DefaultFormats - val rootId = (metadata \ "rootId").extract[Int] - val classNameV1_0 = SaveLoadV1_0.thisClassName + val (loadedClassName, formatVersion, __) = Loader.loadMetadata(sc, path) (loadedClassName, formatVersion) match { - case (classNameV1_0, "1.0") => - val model = SaveLoadV1_0.load(sc, path, rootId) + case (SaveLoadV1_0.thisClassName, SaveLoadV1_0.thisFormatVersion) => + val model = SaveLoadV1_0.load(sc, path) + model + case (SaveLoadV2_0.thisClassName, SaveLoadV2_0.thisFormatVersion) => + val model = SaveLoadV1_0.load(sc, path) model case _ => throw new Exception( s"BisectingKMeansModel.load did not recognize model with (className, format version):" + s"($loadedClassName, $formatVersion). Supported:\n" + - s" ($classNameV1_0, 1.0)") + s" (${SaveLoadV1_0.thisClassName}, ${SaveLoadV1_0.thisClassName}\n" + + s" (${SaveLoadV2_0.thisClassName}, ${SaveLoadV2_0.thisClassName})") } } @@ -136,8 +144,28 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { r.getDouble(4), r.getDouble(5), r.getSeq[Int](6)) } + private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = { + if (node.children.isEmpty) { + Array(node) + } else { + node.children.flatMap(getNodes) ++ Array(node) + } + } + + private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = { + val root = nodes(rootId) + if (root.children.isEmpty) { + new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), + root.cost, root.height, new Array[ClusteringTreeNode](0)) + } else { + val children = root.children.map(c => buildTree(c, nodes)) + new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), + root.cost, root.height, children.toArray) + } + } + private[clustering] object SaveLoadV1_0 { - private val thisFormatVersion = "1.0" + private[clustering] val thisFormatVersion = "1.0" private[clustering] val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel" @@ -155,34 +183,55 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { spark.createDataFrame(data).write.parquet(Loader.dataPath(path)) } - private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = { - if (node.children.isEmpty) { - Array(node) - } else { - node.children.flatMap(getNodes(_)) ++ Array(node) - } - } - - def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = { + def load(sc: SparkContext, path: String): BisectingKMeansModel = { + implicit val formats: DefaultFormats = DefaultFormats + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val rootId = (metadata \ "rootId").extract[Int] val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val rows = spark.read.parquet(Loader.dataPath(path)) Loader.checkSchema[Data](rows.schema) val data = rows.select("index", "size", "center", "norm", "cost", "height", "children") val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap val rootNode = buildTree(rootId, nodes) - new BisectingKMeansModel(rootNode) + new BisectingKMeansModel(rootNode, DistanceMeasure.EUCLIDEAN) } + } + + private[clustering] object SaveLoadV2_0 { + private[clustering] val thisFormatVersion = "2.0" - private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = { - val root = nodes.get(rootId).get - if (root.children.isEmpty) { - new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), - root.cost, root.height, new Array[ClusteringTreeNode](0)) - } else { - val children = root.children.map(c => buildTree(c, nodes)) - new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), - root.cost, root.height, children.toArray) - } + private[clustering] + val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel" + + def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = { + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) + ~ ("rootId" -> model.root.index) ~ ("distanceMeasure" -> model.distanceMeasure))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val data = getNodes(model.root).map(node => Data(node.index, node.size, + node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height, + node.children.map(_.index))) + spark.createDataFrame(data).write.parquet(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): BisectingKMeansModel = { + implicit val formats: DefaultFormats = DefaultFormats + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val rootId = (metadata \ "rootId").extract[Int] + val distanceMeasure = (metadata \ "distanceMeasure").extract[String] + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val rows = spark.read.parquet(Loader.dataPath(path)) + Loader.checkSchema[Data](rows.schema) + val data = rows.select("index", "size", "center", "norm", "cost", "height", "children") + val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap + val rootNode = buildTree(rootId, nodes) + new BisectingKMeansModel(rootNode, distanceMeasure) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala new file mode 100644 index 0000000000000..683360efabc76 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.apache.spark.annotation.Since +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal} +import org.apache.spark.mllib.util.MLUtils + +private[spark] abstract class DistanceMeasure extends Serializable { + + /** + * @return the index of the closest center to the given point, as well as the cost. + */ + def findClosest( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): (Int, Double) = { + var bestDistance = Double.PositiveInfinity + var bestIndex = 0 + var i = 0 + centers.foreach { center => + val currentDistance = distance(center, point) + if (currentDistance < bestDistance) { + bestDistance = currentDistance + bestIndex = i + } + i += 1 + } + (bestIndex, bestDistance) + } + + /** + * @return the K-means cost of a given point against the given cluster centers. + */ + def pointCost( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): Double = { + findClosest(centers, point)._2 + } + + /** + * @return whether a center converged or not, given the epsilon parameter. + */ + def isCenterConverged( + oldCenter: VectorWithNorm, + newCenter: VectorWithNorm, + epsilon: Double): Boolean = { + distance(oldCenter, newCenter) <= epsilon + } + + /** + * @return the distance between two points. + */ + def distance( + v1: VectorWithNorm, + v2: VectorWithNorm): Double + + /** + * @return the total cost of the cluster from its aggregated properties + */ + def clusterCost( + centroid: VectorWithNorm, + pointsSum: VectorWithNorm, + numberOfPoints: Long, + pointsSquaredNorm: Double): Double + + /** + * Updates the value of `sum` adding the `point` vector. + * @param point a `VectorWithNorm` to be added to `sum` of a cluster + * @param sum the `sum` for a cluster to be updated + */ + def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { + axpy(1.0, point.vector, sum) + } + + /** + * Returns a centroid for a cluster given its `sum` vector and its `count` of points. + * + * @param sum the `sum` for a cluster + * @param count the number of points in the cluster + * @return the centroid of the cluster + */ + def centroid(sum: Vector, count: Long): VectorWithNorm = { + scal(1.0 / count, sum) + new VectorWithNorm(sum) + } + + /** + * Returns two new centroids symmetric to the specified centroid applying `noise` with the + * with the specified `level`. + * + * @param level the level of `noise` to apply to the given centroid. + * @param noise a noise vector + * @param centroid the parent centroid + * @return a left and right centroid symmetric to `centroid` + */ + def symmetricCentroids( + level: Double, + noise: Vector, + centroid: Vector): (VectorWithNorm, VectorWithNorm) = { + val left = centroid.copy + axpy(-level, noise, left) + val right = centroid.copy + axpy(level, noise, right) + (new VectorWithNorm(left), new VectorWithNorm(right)) + } + + /** + * @return the cost of a point to be assigned to the cluster centroid + */ + def cost( + point: VectorWithNorm, + centroid: VectorWithNorm): Double = distance(point, centroid) +} + +@Since("2.4.0") +object DistanceMeasure { + + @Since("2.4.0") + val EUCLIDEAN = "euclidean" + @Since("2.4.0") + val COSINE = "cosine" + + private[spark] def decodeFromString(distanceMeasure: String): DistanceMeasure = + distanceMeasure match { + case EUCLIDEAN => new EuclideanDistanceMeasure + case COSINE => new CosineDistanceMeasure + case _ => throw new IllegalArgumentException(s"distanceMeasure must be one of: " + + s"$EUCLIDEAN, $COSINE. $distanceMeasure provided.") + } + + private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean = { + distanceMeasure match { + case DistanceMeasure.EUCLIDEAN => true + case DistanceMeasure.COSINE => true + case _ => false + } + } +} + +private[spark] class EuclideanDistanceMeasure extends DistanceMeasure { + /** + * @return the index of the closest center to the given point, as well as the squared distance. + */ + override def findClosest( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): (Int, Double) = { + var bestDistance = Double.PositiveInfinity + var bestIndex = 0 + var i = 0 + centers.foreach { center => + // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary + // distance computation. + var lowerBoundOfSqDist = center.norm - point.norm + lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist + if (lowerBoundOfSqDist < bestDistance) { + val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point) + if (distance < bestDistance) { + bestDistance = distance + bestIndex = i + } + } + i += 1 + } + (bestIndex, bestDistance) + } + + /** + * @return whether a center converged or not, given the epsilon parameter. + */ + override def isCenterConverged( + oldCenter: VectorWithNorm, + newCenter: VectorWithNorm, + epsilon: Double): Boolean = { + EuclideanDistanceMeasure.fastSquaredDistance(newCenter, oldCenter) <= epsilon * epsilon + } + + /** + * @param v1: first vector + * @param v2: second vector + * @return the Euclidean distance between the two input vectors + */ + override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { + Math.sqrt(EuclideanDistanceMeasure.fastSquaredDistance(v1, v2)) + } + + /** + * @return the total cost of the cluster from its aggregated properties + */ + override def clusterCost( + centroid: VectorWithNorm, + pointsSum: VectorWithNorm, + numberOfPoints: Long, + pointsSquaredNorm: Double): Double = { + math.max(pointsSquaredNorm - numberOfPoints * centroid.norm * centroid.norm, 0.0) + } + + /** + * @return the cost of a point to be assigned to the cluster centroid + */ + override def cost( + point: VectorWithNorm, + centroid: VectorWithNorm): Double = { + EuclideanDistanceMeasure.fastSquaredDistance(point, centroid) + } +} + + +private[spark] object EuclideanDistanceMeasure { + /** + * @return the squared Euclidean distance between two vectors computed by + * [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]]. + */ + private[clustering] def fastSquaredDistance( + v1: VectorWithNorm, + v2: VectorWithNorm): Double = { + MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm) + } +} + +private[spark] class CosineDistanceMeasure extends DistanceMeasure { + /** + * @param v1: first vector + * @param v2: second vector + * @return the cosine distance between the two input vectors + */ + override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { + assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.") + 1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm + } + + /** + * Updates the value of `sum` adding the `point` vector. + * @param point a `VectorWithNorm` to be added to `sum` of a cluster + * @param sum the `sum` for a cluster to be updated + */ + override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { + assert(point.norm > 0, "Cosine distance is not defined for zero-length vectors.") + axpy(1.0 / point.norm, point.vector, sum) + } + + /** + * Returns a centroid for a cluster given its `sum` vector and its `count` of points. + * + * @param sum the `sum` for a cluster + * @param count the number of points in the cluster + * @return the centroid of the cluster + */ + override def centroid(sum: Vector, count: Long): VectorWithNorm = { + scal(1.0 / count, sum) + val norm = Vectors.norm(sum, 2) + scal(1.0 / norm, sum) + new VectorWithNorm(sum, 1) + } + + /** + * @return the total cost of the cluster from its aggregated properties + */ + override def clusterCost( + centroid: VectorWithNorm, + pointsSum: VectorWithNorm, + numberOfPoints: Long, + pointsSquaredNorm: Double): Double = { + val costVector = pointsSum.vector.copy + math.max(numberOfPoints - dot(centroid.vector, costVector) / centroid.norm, 0.0) + } + + /** + * Returns two new centroids symmetric to the specified centroid applying `noise` with the + * with the specified `level`. + * + * @param level the level of `noise` to apply to the given centroid. + * @param noise a noise vector + * @param centroid the parent centroid + * @return a left and right centroid symmetric to `centroid` + */ + override def symmetricCentroids( + level: Double, + noise: Vector, + centroid: Vector): (VectorWithNorm, VectorWithNorm) = { + val (left, right) = super.symmetricCentroids(level, noise, centroid) + val leftVector = left.vector + val rightVector = right.vector + scal(1.0 / left.norm, leftVector) + scal(1.0 / right.norm, rightVector) + (new VectorWithNorm(leftVector, 1.0), new VectorWithNorm(rightVector, 1.0)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 3c4ba0bc60c7f..b5b1be3490497 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -25,8 +25,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.clustering.{KMeans => NewKMeans} import org.apache.spark.ml.util.Instrumentation import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal} -import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.linalg.BLAS.axpy import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -204,7 +203,7 @@ class KMeans private ( */ @Since("2.4.0") def setDistanceMeasure(distanceMeasure: String): this.type = { - KMeans.validateDistanceMeasure(distanceMeasure) + DistanceMeasure.validateDistanceMeasure(distanceMeasure) this.distanceMeasure = distanceMeasure this } @@ -582,14 +581,6 @@ object KMeans { case _ => false } } - - private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean = { - distanceMeasure match { - case DistanceMeasure.EUCLIDEAN => true - case DistanceMeasure.COSINE => true - case _ => false - } - } } /** @@ -605,186 +596,3 @@ private[clustering] class VectorWithNorm(val vector: Vector, val norm: Double) /** Converts the vector to a dense vector. */ def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm) } - - -private[spark] abstract class DistanceMeasure extends Serializable { - - /** - * @return the index of the closest center to the given point, as well as the cost. - */ - def findClosest( - centers: TraversableOnce[VectorWithNorm], - point: VectorWithNorm): (Int, Double) = { - var bestDistance = Double.PositiveInfinity - var bestIndex = 0 - var i = 0 - centers.foreach { center => - val currentDistance = distance(center, point) - if (currentDistance < bestDistance) { - bestDistance = currentDistance - bestIndex = i - } - i += 1 - } - (bestIndex, bestDistance) - } - - /** - * @return the K-means cost of a given point against the given cluster centers. - */ - def pointCost( - centers: TraversableOnce[VectorWithNorm], - point: VectorWithNorm): Double = { - findClosest(centers, point)._2 - } - - /** - * @return whether a center converged or not, given the epsilon parameter. - */ - def isCenterConverged( - oldCenter: VectorWithNorm, - newCenter: VectorWithNorm, - epsilon: Double): Boolean = { - distance(oldCenter, newCenter) <= epsilon - } - - /** - * @return the cosine distance between two points. - */ - def distance( - v1: VectorWithNorm, - v2: VectorWithNorm): Double - - /** - * Updates the value of `sum` adding the `point` vector. - * @param point a `VectorWithNorm` to be added to `sum` of a cluster - * @param sum the `sum` for a cluster to be updated - */ - def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { - axpy(1.0, point.vector, sum) - } - - /** - * Returns a centroid for a cluster given its `sum` vector and its `count` of points. - * - * @param sum the `sum` for a cluster - * @param count the number of points in the cluster - * @return the centroid of the cluster - */ - def centroid(sum: Vector, count: Long): VectorWithNorm = { - scal(1.0 / count, sum) - new VectorWithNorm(sum) - } -} - -@Since("2.4.0") -object DistanceMeasure { - - @Since("2.4.0") - val EUCLIDEAN = "euclidean" - @Since("2.4.0") - val COSINE = "cosine" - - private[spark] def decodeFromString(distanceMeasure: String): DistanceMeasure = - distanceMeasure match { - case EUCLIDEAN => new EuclideanDistanceMeasure - case COSINE => new CosineDistanceMeasure - case _ => throw new IllegalArgumentException(s"distanceMeasure must be one of: " + - s"$EUCLIDEAN, $COSINE. $distanceMeasure provided.") - } -} - -private[spark] class EuclideanDistanceMeasure extends DistanceMeasure { - /** - * @return the index of the closest center to the given point, as well as the squared distance. - */ - override def findClosest( - centers: TraversableOnce[VectorWithNorm], - point: VectorWithNorm): (Int, Double) = { - var bestDistance = Double.PositiveInfinity - var bestIndex = 0 - var i = 0 - centers.foreach { center => - // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary - // distance computation. - var lowerBoundOfSqDist = center.norm - point.norm - lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist - if (lowerBoundOfSqDist < bestDistance) { - val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point) - if (distance < bestDistance) { - bestDistance = distance - bestIndex = i - } - } - i += 1 - } - (bestIndex, bestDistance) - } - - /** - * @return whether a center converged or not, given the epsilon parameter. - */ - override def isCenterConverged( - oldCenter: VectorWithNorm, - newCenter: VectorWithNorm, - epsilon: Double): Boolean = { - EuclideanDistanceMeasure.fastSquaredDistance(newCenter, oldCenter) <= epsilon * epsilon - } - - /** - * @param v1: first vector - * @param v2: second vector - * @return the Euclidean distance between the two input vectors - */ - override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { - Math.sqrt(EuclideanDistanceMeasure.fastSquaredDistance(v1, v2)) - } -} - - -private[spark] object EuclideanDistanceMeasure { - /** - * @return the squared Euclidean distance between two vectors computed by - * [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]]. - */ - private[clustering] def fastSquaredDistance( - v1: VectorWithNorm, - v2: VectorWithNorm): Double = { - MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm) - } -} - -private[spark] class CosineDistanceMeasure extends DistanceMeasure { - /** - * @param v1: first vector - * @param v2: second vector - * @return the cosine distance between the two input vectors - */ - override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { - assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.") - 1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm - } - - /** - * Updates the value of `sum` adding the `point` vector. - * @param point a `VectorWithNorm` to be added to `sum` of a cluster - * @param sum the `sum` for a cluster to be updated - */ - override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { - axpy(1.0 / point.norm, point.vector, sum) - } - - /** - * Returns a centroid for a cluster given its `sum` vector and its `count` of points. - * - * @param sum the `sum` for a cluster - * @param count the number of points in the cluster - * @return the centroid of the cluster - */ - override def centroid(sum: Vector, count: Long): VectorWithNorm = { - scal(1.0 / count, sum) - val norm = Vectors.norm(sum, 2) - scal(1.0 / norm, sum) - new VectorWithNorm(sum, 1) - } -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index fa7471fa2d658..02880f96ae6d9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.ml.clustering -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.clustering.DistanceMeasure import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -140,6 +142,46 @@ class BisectingKMeansSuite testEstimatorAndModelReadWrite(bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, BisectingKMeansSuite.allParamSettings, checkModelData) } + + test("BisectingKMeans with cosine distance is not supported for 0-length vectors") { + val model = new BisectingKMeans().setK(2).setDistanceMeasure(DistanceMeasure.COSINE).setSeed(1) + val df = spark.createDataFrame(spark.sparkContext.parallelize(Array( + Vectors.dense(0.0, 0.0), + Vectors.dense(10.0, 10.0), + Vectors.dense(1.0, 0.5) + )).map(v => TestRow(v))) + val e = intercept[SparkException](model.fit(df)) + assert(e.getCause.isInstanceOf[AssertionError]) + assert(e.getCause.getMessage.contains("Cosine distance is not defined")) + } + + test("BisectingKMeans with cosine distance") { + val df = spark.createDataFrame(spark.sparkContext.parallelize(Array( + Vectors.dense(1.0, 1.0), + Vectors.dense(10.0, 10.0), + Vectors.dense(1.0, 0.5), + Vectors.dense(10.0, 4.4), + Vectors.dense(-1.0, 1.0), + Vectors.dense(-100.0, 90.0) + )).map(v => TestRow(v))) + val model = new BisectingKMeans() + .setK(3) + .setDistanceMeasure(DistanceMeasure.COSINE) + .setSeed(1) + .fit(df) + val predictionDf = model.transform(df) + assert(predictionDf.select("prediction").distinct().count() == 3) + val predictionsMap = predictionDf.collect().map(row => + row.getAs[Vector]("features") -> row.getAs[Int]("prediction")).toMap + assert(predictionsMap(Vectors.dense(1.0, 1.0)) == + predictionsMap(Vectors.dense(10.0, 10.0))) + assert(predictionsMap(Vectors.dense(1.0, 0.5)) == + predictionsMap(Vectors.dense(10.0, 4.4))) + assert(predictionsMap(Vectors.dense(-1.0, 1.0)) == + predictionsMap(Vectors.dense(-100.0, 90.0))) + + model.clusterCenters.forall(Vectors.norm(_, 2) == 1.0) + } } object BisectingKMeansSuite { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 381f7b5be1ddf..1b6d1dec69d49 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,12 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-23412][ML] Add cosine distance measure to BisectingKmeans + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.org$apache$spark$ml$param$shared$HasDistanceMeasure$_setter_$distanceMeasure_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.getDistanceMeasure"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.distanceMeasure"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.BisectingKMeansModel#SaveLoadV1_0.load"), + // [SPARK-20659] Remove StorageStatus, or make it private ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOffHeapStorageMemory"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOffHeapStorageMemory"), From 23370554d0f88b82154d4232744b874cc58c7848 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 13 Mar 2018 15:20:09 +0100 Subject: [PATCH 0466/2461] [SPARK-23656][TEST] Perform assertions in XXH64Suite.testKnownByteArrayInputs() on big endian platform, too ## What changes were proposed in this pull request? This PR enables assertions in `XXH64Suite.testKnownByteArrayInputs()` on big endian platform, too. The current implementation performs them only on little endian platform. This PR increase test coverage of big endian platform. ## How was this patch tested? Updated `XXH64Suite` Tested on big endian platform using JIT compiler or interpreter `-Xint`. Author: Kazuaki Ishizaki Closes #20804 from kiszk/SPARK-23656. --- .../sql/catalyst/expressions/XXH64Suite.java | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java index 711887f02832a..1baee91b3439c 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java @@ -74,9 +74,6 @@ public void testKnownByteArrayInputs() { Assert.assertEquals(0x739840CB819FA723L, XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 1, PRIME)); - // These tests currently fail in a big endian environment because the test data and expected - // answers are generated with little endian the assumptions. We could revisit this when Platform - // becomes endian aware. if (ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN) { Assert.assertEquals(0x9256E58AA397AEF1L, hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 4)); @@ -94,6 +91,23 @@ public void testKnownByteArrayInputs() { hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE)); Assert.assertEquals(0xCAA65939306F1E21L, XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE, PRIME)); + } else { + Assert.assertEquals(0x7F875412350ADDDCL, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 4)); + Assert.assertEquals(0x564D279F524D8516L, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 4, PRIME)); + Assert.assertEquals(0x7D9F07E27E0EB006L, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 8)); + Assert.assertEquals(0x893CEF564CB7858L, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 8, PRIME)); + Assert.assertEquals(0xC6198C4C9CC49E17L, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 14)); + Assert.assertEquals(0x4E21BEF7164D4BBL, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 14, PRIME)); + Assert.assertEquals(0xBCF5FAEDEE1F2B5AL, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE)); + Assert.assertEquals(0x6F680C877A358FE5L, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE, PRIME)); } } From 9ddd1e2ceac8155b30beebb6bbfdcd32296fab2d Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 13 Mar 2018 23:31:08 +0900 Subject: [PATCH 0467/2461] [MINOR][SQL][TEST] Create table using `dataSourceName` in `HadoopFsRelationTest` ## What changes were proposed in this pull request? This PR fixes a minor issue in `HadoopFsRelationTest`, that you should create table using `dataSourceName` instead of `parquet`. The issue won't affect the correctness, but it will generate wrong error message in case the test fails. ## How was this patch tested? Exsiting tests. Author: Xingbo Jiang Closes #20780 from jiangxb1987/dataSourceName. --- .../apache/spark/sql/sources/HadoopFsRelationTest.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index 80aff446bc24b..53397991e59dc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -335,16 +335,17 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes test("saveAsTable()/load() - non-partitioned table - ErrorIfExists") { withTable("t") { - sql("CREATE TABLE t(i INT) USING parquet") - intercept[AnalysisException] { + sql(s"CREATE TABLE t(i INT) USING $dataSourceName") + val msg = intercept[AnalysisException] { testDF.write.format(dataSourceName).mode(SaveMode.ErrorIfExists).saveAsTable("t") - } + }.getMessage + assert(msg.contains("Table `t` already exists")) } } test("saveAsTable()/load() - non-partitioned table - Ignore") { withTable("t") { - sql("CREATE TABLE t(i INT) USING parquet") + sql(s"CREATE TABLE t(i INT) USING $dataSourceName") testDF.write.format(dataSourceName).mode(SaveMode.Ignore).saveAsTable("t") assert(spark.table("t").collect().isEmpty) } From 918fb9beee6a2fd499b8f18dfe0d460f078f5290 Mon Sep 17 00:00:00 2001 From: zuotingbing Date: Tue, 13 Mar 2018 11:31:32 -0700 Subject: [PATCH 0468/2461] [SPARK-23547][SQL] Cleanup the .pipeout file when the Hive Session closed ## What changes were proposed in this pull request? ![2018-03-07_121010](https://user-images.githubusercontent.com/24823338/37073232-922e10d2-2200-11e8-8172-6e03aa984b39.png) when the hive session closed, we should also cleanup the .pipeout file. ## How was this patch tested? Added test cases. Author: zuotingbing Closes #20702 from zuotingbing/SPARK-23547. --- .../service/cli/session/HiveSessionImpl.java | 18 +++++++++++ .../HiveThriftServer2Suites.scala | 32 ++++++++++++++++++- 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java index fc818bc69c761..f59cdcd3188e6 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java @@ -641,6 +641,8 @@ public void close() throws HiveSQLException { opHandleSet.clear(); // Cleanup session log directory. cleanupSessionLogDir(); + // Cleanup pipeout file. + cleanupPipeoutFile(); HiveHistory hiveHist = sessionState.getHiveHistory(); if (null != hiveHist) { hiveHist.closeStream(); @@ -665,6 +667,22 @@ public void close() throws HiveSQLException { } } + private void cleanupPipeoutFile() { + String lScratchDir = hiveConf.getVar(ConfVars.LOCALSCRATCHDIR); + String sessionID = hiveConf.getVar(ConfVars.HIVESESSIONID); + + File[] fileAry = new File(lScratchDir).listFiles( + (dir, name) -> name.startsWith(sessionID) && name.endsWith(".pipeout")); + + for (File file : fileAry) { + try { + FileUtils.forceDelete(file); + } catch (Exception e) { + LOG.error("Failed to cleanup pipeout file: " + file, e); + } + } + } + private void cleanupSessionLogDir() { if (isOperationLogEnabled) { try { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index b32c547cefefe..192f33a45e273 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.hive.thriftserver -import java.io.File +import java.io.{File, FilenameFilter} import java.net.URL import java.nio.charset.StandardCharsets import java.sql.{Date, DriverManager, SQLException, Statement} +import java.util.UUID import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -613,6 +614,28 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { bufferSrc.close() } } + + test("SPARK-23547 Cleanup the .pipeout file when the Hive Session closed") { + def pipeoutFileList(sessionID: UUID): Array[File] = { + lScratchDir.listFiles(new FilenameFilter { + override def accept(dir: File, name: String): Boolean = { + name.startsWith(sessionID.toString) && name.endsWith(".pipeout") + } + }) + } + + withCLIServiceClient { client => + val user = System.getProperty("user.name") + val sessionHandle = client.openSession(user, "") + val sessionID = sessionHandle.getSessionId + + assert(pipeoutFileList(sessionID).length == 1) + + client.closeSession(sessionHandle) + + assert(pipeoutFileList(sessionID).length == 0) + } + } } class SingleSessionSuite extends HiveThriftJdbcTest { @@ -807,6 +830,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl private val pidDir: File = Utils.createTempDir(namePrefix = "thriftserver-pid") protected var logPath: File = _ protected var operationLogPath: File = _ + protected var lScratchDir: File = _ private var logTailingProcess: Process = _ private var diagnosisBuffer: ArrayBuffer[String] = ArrayBuffer.empty[String] @@ -844,6 +868,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode | --hiveconf ${ConfVars.HIVE_SERVER2_LOGGING_OPERATION_LOG_LOCATION}=$operationLogPath + | --hiveconf ${ConfVars.LOCALSCRATCHDIR}=$lScratchDir | --hiveconf $portConf=$port | --driver-class-path $driverClassPath | --driver-java-options -Dlog4j.debug @@ -873,6 +898,8 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl metastorePath.delete() operationLogPath = Utils.createTempDir() operationLogPath.delete() + lScratchDir = Utils.createTempDir() + lScratchDir.delete() logPath = null logTailingProcess = null @@ -956,6 +983,9 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl operationLogPath.delete() operationLogPath = null + lScratchDir.delete() + lScratchDir = null + Option(logPath).foreach(_.delete()) logPath = null From 1098933b0ac5cdb18101d3aebefa773c2ce05a50 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 13 Mar 2018 23:04:16 +0100 Subject: [PATCH 0469/2461] [SPARK-23598][SQL] Make methods in BufferedRowIterator public to avoid runtime error for a large query ## What changes were proposed in this pull request? This PR fixes runtime error regarding a large query when a generated code has split classes. The issue is `append()`, `stopEarly()`, and other methods are not accessible from split classes that are not subclasses of `BufferedRowIterator`. This PR fixes this issue by making them `public`. Before applying the PR, we see the following exception by running the attached program with `CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD=-1`. ``` test("SPARK-23598") { // When set -1 to CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD, an exception is thrown val df_pet_age = Seq((8, "bat"), (15, "mouse"), (5, "horse")).toDF("age", "name") df_pet_age.groupBy("name").avg("age").show() } ``` Exception: ``` 19:40:52.591 WARN org.apache.hadoop.util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable 19:41:32.319 ERROR org.apache.spark.executor.Executor: Exception in task 0.0 in stage 0.0 (TID 0) java.lang.IllegalAccessError: tried to access method org.apache.spark.sql.execution.BufferedRowIterator.shouldStop()Z from class org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1$agg_NestedClass1 at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1$agg_NestedClass1.agg_doAggregateWithKeys$(generated.java:203) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(generated.java:160) at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$11$$anon$1.hasNext(WholeStageCodegenExec.scala:616) at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408) at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:125) at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96) at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53) at org.apache.spark.scheduler.Task.run(Task.scala:109) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ... ``` Generated code (line 195 calles `stopEarly()`). ``` /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIteratorForCodegenStage1(references); /* 003 */ } /* 004 */ /* 005 */ // codegenStageId=1 /* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator { /* 007 */ private Object[] references; /* 008 */ private scala.collection.Iterator[] inputs; /* 009 */ private boolean agg_initAgg; /* 010 */ private boolean agg_bufIsNull; /* 011 */ private double agg_bufValue; /* 012 */ private boolean agg_bufIsNull1; /* 013 */ private long agg_bufValue1; /* 014 */ private agg_FastHashMap agg_fastHashMap; /* 015 */ private org.apache.spark.unsafe.KVIterator agg_fastHashMapIter; /* 016 */ private org.apache.spark.unsafe.KVIterator agg_mapIter; /* 017 */ private org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap agg_hashMap; /* 018 */ private org.apache.spark.sql.execution.UnsafeKVExternalSorter agg_sorter; /* 019 */ private scala.collection.Iterator inputadapter_input; /* 020 */ private boolean agg_agg_isNull11; /* 021 */ private boolean agg_agg_isNull25; /* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder[] agg_mutableStateArray1 = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder[2]; /* 023 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] agg_mutableStateArray2 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2]; /* 024 */ private UnsafeRow[] agg_mutableStateArray = new UnsafeRow[2]; /* 025 */ /* 026 */ public GeneratedIteratorForCodegenStage1(Object[] references) { /* 027 */ this.references = references; /* 028 */ } /* 029 */ /* 030 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 031 */ partitionIndex = index; /* 032 */ this.inputs = inputs; /* 033 */ /* 034 */ agg_fastHashMap = new agg_FastHashMap(((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).getTaskMemoryManager(), ((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).getEmptyAggregationBuffer()); /* 035 */ agg_hashMap = ((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).createHashMap(); /* 036 */ inputadapter_input = inputs[0]; /* 037 */ agg_mutableStateArray[0] = new UnsafeRow(1); /* 038 */ agg_mutableStateArray1[0] = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_mutableStateArray[0], 32); /* 039 */ agg_mutableStateArray2[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_mutableStateArray1[0], 1); /* 040 */ agg_mutableStateArray[1] = new UnsafeRow(3); /* 041 */ agg_mutableStateArray1[1] = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_mutableStateArray[1], 32); /* 042 */ agg_mutableStateArray2[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_mutableStateArray1[1], 3); /* 043 */ /* 044 */ } /* 045 */ /* 046 */ public class agg_FastHashMap { /* 047 */ private org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch; /* 048 */ private int[] buckets; /* 049 */ private int capacity = 1 << 16; /* 050 */ private double loadFactor = 0.5; /* 051 */ private int numBuckets = (int) (capacity / loadFactor); /* 052 */ private int maxSteps = 2; /* 053 */ private int numRows = 0; /* 054 */ private org.apache.spark.sql.types.StructType keySchema = new org.apache.spark.sql.types.StructType().add(((java.lang.String) references[1] /* keyName */), org.apache.spark.sql.types.DataTypes.StringType); /* 055 */ private org.apache.spark.sql.types.StructType valueSchema = new org.apache.spark.sql.types.StructType().add(((java.lang.String) references[2] /* keyName */), org.apache.spark.sql.types.DataTypes.DoubleType) /* 056 */ .add(((java.lang.String) references[3] /* keyName */), org.apache.spark.sql.types.DataTypes.LongType); /* 057 */ private Object emptyVBase; /* 058 */ private long emptyVOff; /* 059 */ private int emptyVLen; /* 060 */ private boolean isBatchFull = false; /* 061 */ /* 062 */ public agg_FastHashMap( /* 063 */ org.apache.spark.memory.TaskMemoryManager taskMemoryManager, /* 064 */ InternalRow emptyAggregationBuffer) { /* 065 */ batch = org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch /* 066 */ .allocate(keySchema, valueSchema, taskMemoryManager, capacity); /* 067 */ /* 068 */ final UnsafeProjection valueProjection = UnsafeProjection.create(valueSchema); /* 069 */ final byte[] emptyBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); /* 070 */ /* 071 */ emptyVBase = emptyBuffer; /* 072 */ emptyVOff = Platform.BYTE_ARRAY_OFFSET; /* 073 */ emptyVLen = emptyBuffer.length; /* 074 */ /* 075 */ buckets = new int[numBuckets]; /* 076 */ java.util.Arrays.fill(buckets, -1); /* 077 */ } /* 078 */ /* 079 */ public org.apache.spark.sql.catalyst.expressions.UnsafeRow findOrInsert(UTF8String agg_key) { /* 080 */ long h = hash(agg_key); /* 081 */ int step = 0; /* 082 */ int idx = (int) h & (numBuckets - 1); /* 083 */ while (step < maxSteps) { /* 084 */ // Return bucket index if it's either an empty slot or already contains the key /* 085 */ if (buckets[idx] == -1) { /* 086 */ if (numRows < capacity && !isBatchFull) { /* 087 */ // creating the unsafe for new entry /* 088 */ UnsafeRow agg_result = new UnsafeRow(1); /* 089 */ org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder /* 090 */ = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, /* 091 */ 32); /* 092 */ org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter /* 093 */ = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter( /* 094 */ agg_holder, /* 095 */ 1); /* 096 */ agg_holder.reset(); //TODO: investigate if reset or zeroout are actually needed /* 097 */ agg_rowWriter.zeroOutNullBytes(); /* 098 */ agg_rowWriter.write(0, agg_key); /* 099 */ agg_result.setTotalSize(agg_holder.totalSize()); /* 100 */ Object kbase = agg_result.getBaseObject(); /* 101 */ long koff = agg_result.getBaseOffset(); /* 102 */ int klen = agg_result.getSizeInBytes(); /* 103 */ /* 104 */ UnsafeRow vRow /* 105 */ = batch.appendRow(kbase, koff, klen, emptyVBase, emptyVOff, emptyVLen); /* 106 */ if (vRow == null) { /* 107 */ isBatchFull = true; /* 108 */ } else { /* 109 */ buckets[idx] = numRows++; /* 110 */ } /* 111 */ return vRow; /* 112 */ } else { /* 113 */ // No more space /* 114 */ return null; /* 115 */ } /* 116 */ } else if (equals(idx, agg_key)) { /* 117 */ return batch.getValueRow(buckets[idx]); /* 118 */ } /* 119 */ idx = (idx + 1) & (numBuckets - 1); /* 120 */ step++; /* 121 */ } /* 122 */ // Didn't find it /* 123 */ return null; /* 124 */ } /* 125 */ /* 126 */ private boolean equals(int idx, UTF8String agg_key) { /* 127 */ UnsafeRow row = batch.getKeyRow(buckets[idx]); /* 128 */ return (row.getUTF8String(0).equals(agg_key)); /* 129 */ } /* 130 */ /* 131 */ private long hash(UTF8String agg_key) { /* 132 */ long agg_hash = 0; /* 133 */ /* 134 */ int agg_result = 0; /* 135 */ byte[] agg_bytes = agg_key.getBytes(); /* 136 */ for (int i = 0; i < agg_bytes.length; i++) { /* 137 */ int agg_hash1 = agg_bytes[i]; /* 138 */ agg_result = (agg_result ^ (0x9e3779b9)) + agg_hash1 + (agg_result << 6) + (agg_result >>> 2); /* 139 */ } /* 140 */ /* 141 */ agg_hash = (agg_hash ^ (0x9e3779b9)) + agg_result + (agg_hash << 6) + (agg_hash >>> 2); /* 142 */ /* 143 */ return agg_hash; /* 144 */ } /* 145 */ /* 146 */ public org.apache.spark.unsafe.KVIterator rowIterator() { /* 147 */ return batch.rowIterator(); /* 148 */ } /* 149 */ /* 150 */ public void close() { /* 151 */ batch.close(); /* 152 */ } /* 153 */ /* 154 */ } /* 155 */ /* 156 */ protected void processNext() throws java.io.IOException { /* 157 */ if (!agg_initAgg) { /* 158 */ agg_initAgg = true; /* 159 */ long wholestagecodegen_beforeAgg = System.nanoTime(); /* 160 */ agg_nestedClassInstance1.agg_doAggregateWithKeys(); /* 161 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[8] /* aggTime */).add((System.nanoTime() - wholestagecodegen_beforeAgg) / 1000000); /* 162 */ } /* 163 */ /* 164 */ // output the result /* 165 */ /* 166 */ while (agg_fastHashMapIter.next()) { /* 167 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_fastHashMapIter.getKey(); /* 168 */ UnsafeRow agg_aggBuffer = (UnsafeRow) agg_fastHashMapIter.getValue(); /* 169 */ wholestagecodegen_nestedClassInstance.agg_doAggregateWithKeysOutput(agg_aggKey, agg_aggBuffer); /* 170 */ /* 171 */ if (shouldStop()) return; /* 172 */ } /* 173 */ agg_fastHashMap.close(); /* 174 */ /* 175 */ while (agg_mapIter.next()) { /* 176 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey(); /* 177 */ UnsafeRow agg_aggBuffer = (UnsafeRow) agg_mapIter.getValue(); /* 178 */ wholestagecodegen_nestedClassInstance.agg_doAggregateWithKeysOutput(agg_aggKey, agg_aggBuffer); /* 179 */ /* 180 */ if (shouldStop()) return; /* 181 */ } /* 182 */ /* 183 */ agg_mapIter.close(); /* 184 */ if (agg_sorter == null) { /* 185 */ agg_hashMap.free(); /* 186 */ } /* 187 */ } /* 188 */ /* 189 */ private wholestagecodegen_NestedClass wholestagecodegen_nestedClassInstance = new wholestagecodegen_NestedClass(); /* 190 */ private agg_NestedClass1 agg_nestedClassInstance1 = new agg_NestedClass1(); /* 191 */ private agg_NestedClass agg_nestedClassInstance = new agg_NestedClass(); /* 192 */ /* 193 */ private class agg_NestedClass1 { /* 194 */ private void agg_doAggregateWithKeys() throws java.io.IOException { /* 195 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 196 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 197 */ int inputadapter_value = inputadapter_row.getInt(0); /* 198 */ boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1); /* 199 */ UTF8String inputadapter_value1 = inputadapter_isNull1 ? /* 200 */ null : (inputadapter_row.getUTF8String(1)); /* 201 */ /* 202 */ agg_nestedClassInstance.agg_doConsume(inputadapter_row, inputadapter_value, inputadapter_value1, inputadapter_isNull1); /* 203 */ if (shouldStop()) return; /* 204 */ } /* 205 */ /* 206 */ agg_fastHashMapIter = agg_fastHashMap.rowIterator(); /* 207 */ agg_mapIter = ((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).finishAggregate(agg_hashMap, agg_sorter, ((org.apache.spark.sql.execution.metric.SQLMetric) references[4] /* peakMemory */), ((org.apache.spark.sql.execution.metric.SQLMetric) references[5] /* spillSize */), ((org.apache.spark.sql.execution.metric.SQLMetric) references[6] /* avgHashProbe */)); /* 208 */ /* 209 */ } /* 210 */ /* 211 */ } /* 212 */ /* 213 */ private class wholestagecodegen_NestedClass { /* 214 */ private void agg_doAggregateWithKeysOutput(UnsafeRow agg_keyTerm, UnsafeRow agg_bufferTerm) /* 215 */ throws java.io.IOException { /* 216 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[7] /* numOutputRows */).add(1); /* 217 */ /* 218 */ boolean agg_isNull35 = agg_keyTerm.isNullAt(0); /* 219 */ UTF8String agg_value37 = agg_isNull35 ? /* 220 */ null : (agg_keyTerm.getUTF8String(0)); /* 221 */ boolean agg_isNull36 = agg_bufferTerm.isNullAt(0); /* 222 */ double agg_value38 = agg_isNull36 ? /* 223 */ -1.0 : (agg_bufferTerm.getDouble(0)); /* 224 */ boolean agg_isNull37 = agg_bufferTerm.isNullAt(1); /* 225 */ long agg_value39 = agg_isNull37 ? /* 226 */ -1L : (agg_bufferTerm.getLong(1)); /* 227 */ /* 228 */ agg_mutableStateArray1[1].reset(); /* 229 */ /* 230 */ agg_mutableStateArray2[1].zeroOutNullBytes(); /* 231 */ /* 232 */ if (agg_isNull35) { /* 233 */ agg_mutableStateArray2[1].setNullAt(0); /* 234 */ } else { /* 235 */ agg_mutableStateArray2[1].write(0, agg_value37); /* 236 */ } /* 237 */ /* 238 */ if (agg_isNull36) { /* 239 */ agg_mutableStateArray2[1].setNullAt(1); /* 240 */ } else { /* 241 */ agg_mutableStateArray2[1].write(1, agg_value38); /* 242 */ } /* 243 */ /* 244 */ if (agg_isNull37) { /* 245 */ agg_mutableStateArray2[1].setNullAt(2); /* 246 */ } else { /* 247 */ agg_mutableStateArray2[1].write(2, agg_value39); /* 248 */ } /* 249 */ agg_mutableStateArray[1].setTotalSize(agg_mutableStateArray1[1].totalSize()); /* 250 */ append(agg_mutableStateArray[1]); /* 251 */ /* 252 */ } /* 253 */ /* 254 */ } /* 255 */ /* 256 */ private class agg_NestedClass { /* 257 */ private void agg_doConsume(InternalRow inputadapter_row, int agg_expr_0, UTF8String agg_expr_1, boolean agg_exprIsNull_1) throws java.io.IOException { /* 258 */ UnsafeRow agg_unsafeRowAggBuffer = null; /* 259 */ UnsafeRow agg_fastAggBuffer = null; /* 260 */ /* 261 */ if (true) { /* 262 */ if (!agg_exprIsNull_1) { /* 263 */ agg_fastAggBuffer = agg_fastHashMap.findOrInsert( /* 264 */ agg_expr_1); /* 265 */ } /* 266 */ } /* 267 */ // Cannot find the key in fast hash map, try regular hash map. /* 268 */ if (agg_fastAggBuffer == null) { /* 269 */ // generate grouping key /* 270 */ agg_mutableStateArray1[0].reset(); /* 271 */ /* 272 */ agg_mutableStateArray2[0].zeroOutNullBytes(); /* 273 */ /* 274 */ if (agg_exprIsNull_1) { /* 275 */ agg_mutableStateArray2[0].setNullAt(0); /* 276 */ } else { /* 277 */ agg_mutableStateArray2[0].write(0, agg_expr_1); /* 278 */ } /* 279 */ agg_mutableStateArray[0].setTotalSize(agg_mutableStateArray1[0].totalSize()); /* 280 */ int agg_value7 = 42; /* 281 */ /* 282 */ if (!agg_exprIsNull_1) { /* 283 */ agg_value7 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(agg_expr_1.getBaseObject(), agg_expr_1.getBaseOffset(), agg_expr_1.numBytes(), agg_value7); /* 284 */ } /* 285 */ if (true) { /* 286 */ // try to get the buffer from hash map /* 287 */ agg_unsafeRowAggBuffer = /* 288 */ agg_hashMap.getAggregationBufferFromUnsafeRow(agg_mutableStateArray[0], agg_value7); /* 289 */ } /* 290 */ // Can't allocate buffer from the hash map. Spill the map and fallback to sort-based /* 291 */ // aggregation after processing all input rows. /* 292 */ if (agg_unsafeRowAggBuffer == null) { /* 293 */ if (agg_sorter == null) { /* 294 */ agg_sorter = agg_hashMap.destructAndCreateExternalSorter(); /* 295 */ } else { /* 296 */ agg_sorter.merge(agg_hashMap.destructAndCreateExternalSorter()); /* 297 */ } /* 298 */ /* 299 */ // the hash map had be spilled, it should have enough memory now, /* 300 */ // try to allocate buffer again. /* 301 */ agg_unsafeRowAggBuffer = agg_hashMap.getAggregationBufferFromUnsafeRow( /* 302 */ agg_mutableStateArray[0], agg_value7); /* 303 */ if (agg_unsafeRowAggBuffer == null) { /* 304 */ // failed to allocate the first page /* 305 */ throw new OutOfMemoryError("No enough memory for aggregation"); /* 306 */ } /* 307 */ } /* 308 */ /* 309 */ } /* 310 */ /* 311 */ if (agg_fastAggBuffer != null) { /* 312 */ // common sub-expressions /* 313 */ boolean agg_isNull21 = false; /* 314 */ long agg_value23 = -1L; /* 315 */ if (!false) { /* 316 */ agg_value23 = (long) agg_expr_0; /* 317 */ } /* 318 */ // evaluate aggregate function /* 319 */ boolean agg_isNull23 = true; /* 320 */ double agg_value25 = -1.0; /* 321 */ /* 322 */ boolean agg_isNull24 = agg_fastAggBuffer.isNullAt(0); /* 323 */ double agg_value26 = agg_isNull24 ? /* 324 */ -1.0 : (agg_fastAggBuffer.getDouble(0)); /* 325 */ if (!agg_isNull24) { /* 326 */ agg_agg_isNull25 = true; /* 327 */ double agg_value27 = -1.0; /* 328 */ do { /* 329 */ boolean agg_isNull26 = agg_isNull21; /* 330 */ double agg_value28 = -1.0; /* 331 */ if (!agg_isNull21) { /* 332 */ agg_value28 = (double) agg_value23; /* 333 */ } /* 334 */ if (!agg_isNull26) { /* 335 */ agg_agg_isNull25 = false; /* 336 */ agg_value27 = agg_value28; /* 337 */ continue; /* 338 */ } /* 339 */ /* 340 */ boolean agg_isNull27 = false; /* 341 */ double agg_value29 = -1.0; /* 342 */ if (!false) { /* 343 */ agg_value29 = (double) 0; /* 344 */ } /* 345 */ if (!agg_isNull27) { /* 346 */ agg_agg_isNull25 = false; /* 347 */ agg_value27 = agg_value29; /* 348 */ continue; /* 349 */ } /* 350 */ /* 351 */ } while (false); /* 352 */ /* 353 */ agg_isNull23 = false; // resultCode could change nullability. /* 354 */ agg_value25 = agg_value26 + agg_value27; /* 355 */ /* 356 */ } /* 357 */ boolean agg_isNull29 = false; /* 358 */ long agg_value31 = -1L; /* 359 */ if (!false && agg_isNull21) { /* 360 */ boolean agg_isNull31 = agg_fastAggBuffer.isNullAt(1); /* 361 */ long agg_value33 = agg_isNull31 ? /* 362 */ -1L : (agg_fastAggBuffer.getLong(1)); /* 363 */ agg_isNull29 = agg_isNull31; /* 364 */ agg_value31 = agg_value33; /* 365 */ } else { /* 366 */ boolean agg_isNull32 = true; /* 367 */ long agg_value34 = -1L; /* 368 */ /* 369 */ boolean agg_isNull33 = agg_fastAggBuffer.isNullAt(1); /* 370 */ long agg_value35 = agg_isNull33 ? /* 371 */ -1L : (agg_fastAggBuffer.getLong(1)); /* 372 */ if (!agg_isNull33) { /* 373 */ agg_isNull32 = false; // resultCode could change nullability. /* 374 */ agg_value34 = agg_value35 + 1L; /* 375 */ /* 376 */ } /* 377 */ agg_isNull29 = agg_isNull32; /* 378 */ agg_value31 = agg_value34; /* 379 */ } /* 380 */ // update fast row /* 381 */ if (!agg_isNull23) { /* 382 */ agg_fastAggBuffer.setDouble(0, agg_value25); /* 383 */ } else { /* 384 */ agg_fastAggBuffer.setNullAt(0); /* 385 */ } /* 386 */ /* 387 */ if (!agg_isNull29) { /* 388 */ agg_fastAggBuffer.setLong(1, agg_value31); /* 389 */ } else { /* 390 */ agg_fastAggBuffer.setNullAt(1); /* 391 */ } /* 392 */ } else { /* 393 */ // common sub-expressions /* 394 */ boolean agg_isNull7 = false; /* 395 */ long agg_value9 = -1L; /* 396 */ if (!false) { /* 397 */ agg_value9 = (long) agg_expr_0; /* 398 */ } /* 399 */ // evaluate aggregate function /* 400 */ boolean agg_isNull9 = true; /* 401 */ double agg_value11 = -1.0; /* 402 */ /* 403 */ boolean agg_isNull10 = agg_unsafeRowAggBuffer.isNullAt(0); /* 404 */ double agg_value12 = agg_isNull10 ? /* 405 */ -1.0 : (agg_unsafeRowAggBuffer.getDouble(0)); /* 406 */ if (!agg_isNull10) { /* 407 */ agg_agg_isNull11 = true; /* 408 */ double agg_value13 = -1.0; /* 409 */ do { /* 410 */ boolean agg_isNull12 = agg_isNull7; /* 411 */ double agg_value14 = -1.0; /* 412 */ if (!agg_isNull7) { /* 413 */ agg_value14 = (double) agg_value9; /* 414 */ } /* 415 */ if (!agg_isNull12) { /* 416 */ agg_agg_isNull11 = false; /* 417 */ agg_value13 = agg_value14; /* 418 */ continue; /* 419 */ } /* 420 */ /* 421 */ boolean agg_isNull13 = false; /* 422 */ double agg_value15 = -1.0; /* 423 */ if (!false) { /* 424 */ agg_value15 = (double) 0; /* 425 */ } /* 426 */ if (!agg_isNull13) { /* 427 */ agg_agg_isNull11 = false; /* 428 */ agg_value13 = agg_value15; /* 429 */ continue; /* 430 */ } /* 431 */ /* 432 */ } while (false); /* 433 */ /* 434 */ agg_isNull9 = false; // resultCode could change nullability. /* 435 */ agg_value11 = agg_value12 + agg_value13; /* 436 */ /* 437 */ } /* 438 */ boolean agg_isNull15 = false; /* 439 */ long agg_value17 = -1L; /* 440 */ if (!false && agg_isNull7) { /* 441 */ boolean agg_isNull17 = agg_unsafeRowAggBuffer.isNullAt(1); /* 442 */ long agg_value19 = agg_isNull17 ? /* 443 */ -1L : (agg_unsafeRowAggBuffer.getLong(1)); /* 444 */ agg_isNull15 = agg_isNull17; /* 445 */ agg_value17 = agg_value19; /* 446 */ } else { /* 447 */ boolean agg_isNull18 = true; /* 448 */ long agg_value20 = -1L; /* 449 */ /* 450 */ boolean agg_isNull19 = agg_unsafeRowAggBuffer.isNullAt(1); /* 451 */ long agg_value21 = agg_isNull19 ? /* 452 */ -1L : (agg_unsafeRowAggBuffer.getLong(1)); /* 453 */ if (!agg_isNull19) { /* 454 */ agg_isNull18 = false; // resultCode could change nullability. /* 455 */ agg_value20 = agg_value21 + 1L; /* 456 */ /* 457 */ } /* 458 */ agg_isNull15 = agg_isNull18; /* 459 */ agg_value17 = agg_value20; /* 460 */ } /* 461 */ // update unsafe row buffer /* 462 */ if (!agg_isNull9) { /* 463 */ agg_unsafeRowAggBuffer.setDouble(0, agg_value11); /* 464 */ } else { /* 465 */ agg_unsafeRowAggBuffer.setNullAt(0); /* 466 */ } /* 467 */ /* 468 */ if (!agg_isNull15) { /* 469 */ agg_unsafeRowAggBuffer.setLong(1, agg_value17); /* 470 */ } else { /* 471 */ agg_unsafeRowAggBuffer.setNullAt(1); /* 472 */ } /* 473 */ /* 474 */ } /* 475 */ /* 476 */ } /* 477 */ /* 478 */ } /* 479 */ /* 480 */ } ``` ## How was this patch tested? Added UT into `WholeStageCodegenSuite` Author: Kazuaki Ishizaki Closes #20779 from kiszk/SPARK-23598. --- .../spark/sql/execution/BufferedRowIterator.java | 12 ++++++++---- .../spark/sql/execution/WholeStageCodegenSuite.scala | 12 ++++++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java index 730a4ae8d5605..74c9c05992719 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -62,10 +62,14 @@ public long durationMs() { */ public abstract void init(int index, Iterator[] iters); + /* + * Attributes of the following four methods are public. Thus, they can be also accessed from + * methods in inner classes. See SPARK-23598 + */ /** * Append a row to currentRows. */ - protected void append(InternalRow row) { + public void append(InternalRow row) { currentRows.add(row); } @@ -75,7 +79,7 @@ protected void append(InternalRow row) { * If it returns true, the caller should exit the loop that [[InputAdapter]] generates. * This interface is mainly used to limit the number of input rows. */ - protected boolean stopEarly() { + public boolean stopEarly() { return false; } @@ -84,14 +88,14 @@ protected boolean stopEarly() { * * If it returns true, the caller should exit the loop (return from processNext()). */ - protected boolean shouldStop() { + public boolean shouldStop() { return !currentRows.isEmpty(); } /** * Increase the peak execution memory for current task. */ - protected void incPeakExecutionMemory(long size) { + public void incPeakExecutionMemory(long size) { TaskContext.get().taskMetrics().incPeakExecutionMemory(size); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 0fb9dd2017a09..4b40e4ef7571c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -32,6 +32,8 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructType} class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + test("range/filter should be combined") { val df = spark.range(10).filter("id = 1").selectExpr("id + 1") val plan = df.queryExecution.executedPlan @@ -307,4 +309,14 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { // a different query can result in codegen cache miss, that's by design } } + + test("SPARK-23598: Codegen working for lots of aggregation operations without runtime errors") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + var df = Seq((8, "bat"), (15, "mouse"), (5, "horse")).toDF("age", "name") + for (i <- 0 until 70) { + df = df.groupBy("name").agg(avg("age").alias("age")) + } + assert(df.limit(1).collect() === Array(Row("bat", 8.0))) + } + } } From 279b3db8970809104c30941254e57e3d62da5041 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Wed, 14 Mar 2018 18:36:31 -0700 Subject: [PATCH 0470/2461] [SPARK-22915][MLLIB] Streaming tests for spark.ml.feature, from N to Z MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What changes were proposed in this pull request? Adds structured streaming tests using testTransformer for these suites: - NGramSuite - NormalizerSuite - OneHotEncoderEstimatorSuite - OneHotEncoderSuite - PCASuite - PolynomialExpansionSuite - QuantileDiscretizerSuite - RFormulaSuite - SQLTransformerSuite - StandardScalerSuite - StopWordsRemoverSuite - StringIndexerSuite - TokenizerSuite - RegexTokenizerSuite - VectorAssemblerSuite - VectorIndexerSuite - VectorSizeHintSuite - VectorSlicerSuite - Word2VecSuite # How was this patch tested? They are unit test. Author: “attilapiros” Closes #20686 from attilapiros/SPARK-22915. --- .../apache/spark/ml/feature/NGramSuite.scala | 23 +- .../spark/ml/feature/NormalizerSuite.scala | 57 ++--- .../feature/OneHotEncoderEstimatorSuite.scala | 193 ++++++++--------- .../spark/ml/feature/OneHotEncoderSuite.scala | 124 ++++++----- .../apache/spark/ml/feature/PCASuite.scala | 14 +- .../ml/feature/PolynomialExpansionSuite.scala | 62 +++--- .../ml/feature/QuantileDiscretizerSuite.scala | 198 +++++++++-------- .../spark/ml/feature/RFormulaSuite.scala | 158 +++++++------- .../ml/feature/SQLTransformerSuite.scala | 35 +-- .../ml/feature/StandardScalerSuite.scala | 33 +-- .../ml/feature/StopWordsRemoverSuite.scala | 37 ++-- .../spark/ml/feature/StringIndexerSuite.scala | 204 +++++++++--------- .../spark/ml/feature/TokenizerSuite.scala | 30 +-- .../spark/ml/feature/VectorIndexerSuite.scala | 183 +++++++++------- .../ml/feature/VectorSizeHintSuite.scala | 88 +++++--- .../spark/ml/feature/VectorSlicerSuite.scala | 27 +-- .../spark/ml/feature/Word2VecSuite.scala | 28 +-- .../org/apache/spark/ml/util/MLTest.scala | 33 ++- 18 files changed, 809 insertions(+), 718 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index d4975c0b4e20e..e5956ee9942aa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -19,17 +19,15 @@ package org.apache.spark.ml.feature import scala.beans.BeanInfo -import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{DataFrame, Row} + @BeanInfo case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) -class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class NGramSuite extends MLTest with DefaultReadWriteTest { - import org.apache.spark.ml.feature.NGramSuite._ import testImplicits._ test("default behavior yields bigram features") { @@ -83,16 +81,11 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe .setN(3) testDefaultReadWrite(t) } -} - -object NGramSuite extends SparkFunSuite { - def testNGram(t: NGram, dataset: Dataset[_]): Unit = { - t.transform(dataset) - .select("nGrams", "wantedNGrams") - .collect() - .foreach { case Row(actualNGrams, wantedNGrams) => + def testNGram(t: NGram, dataFrame: DataFrame): Unit = { + testTransformer[(Seq[String], Seq[String])](dataFrame, t, "nGrams", "wantedNGrams") { + case Row(actualNGrams : Seq[String], wantedNGrams: Seq[String]) => assert(actualNGrams === wantedNGrams) - } + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index c75027fb4553d..eff57f1223af4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -17,21 +17,17 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class NormalizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @transient var data: Array[Vector] = _ - @transient var dataFrame: DataFrame = _ - @transient var normalizer: Normalizer = _ @transient var l1Normalized: Array[Vector] = _ @transient var l2Normalized: Array[Vector] = _ @@ -62,49 +58,40 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa Vectors.dense(0.897906166, 0.113419726, 0.42532397), Vectors.sparse(3, Seq()) ) - - dataFrame = data.map(NormalizerSuite.FeatureData).toSeq.toDF() - normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normalized_features") - } - - def collectResult(result: DataFrame): Array[Vector] = { - result.select("normalized_features").collect().map { - case Row(features: Vector) => features - } } - def assertTypeOfVector(lhs: Array[Vector], rhs: Array[Vector]): Unit = { - assert((lhs, rhs).zipped.forall { + def assertTypeOfVector(lhs: Vector, rhs: Vector): Unit = { + assert((lhs, rhs) match { case (v1: DenseVector, v2: DenseVector) => true case (v1: SparseVector, v2: SparseVector) => true case _ => false }, "The vector type should be preserved after normalization.") } - def assertValues(lhs: Array[Vector], rhs: Array[Vector]): Unit = { - assert((lhs, rhs).zipped.forall { (vector1, vector2) => - vector1 ~== vector2 absTol 1E-5 - }, "The vector value is not correct after normalization.") + def assertValues(lhs: Vector, rhs: Vector): Unit = { + assert(lhs ~== rhs absTol 1E-5, "The vector value is not correct after normalization.") } test("Normalization with default parameter") { - val result = collectResult(normalizer.transform(dataFrame)) - - assertTypeOfVector(data, result) + val normalizer = new Normalizer().setInputCol("features").setOutputCol("normalized") + val dataFrame: DataFrame = data.zip(l2Normalized).seq.toDF("features", "expected") - assertValues(result, l2Normalized) + testTransformer[(Vector, Vector)](dataFrame, normalizer, "features", "normalized", "expected") { + case Row(features: Vector, normalized: Vector, expected: Vector) => + assertTypeOfVector(normalized, features) + assertValues(normalized, expected) + } } test("Normalization with setter") { - normalizer.setP(1) + val dataFrame: DataFrame = data.zip(l1Normalized).seq.toDF("features", "expected") + val normalizer = new Normalizer().setInputCol("features").setOutputCol("normalized").setP(1) - val result = collectResult(normalizer.transform(dataFrame)) - - assertTypeOfVector(data, result) - - assertValues(result, l1Normalized) + testTransformer[(Vector, Vector)](dataFrame, normalizer, "features", "normalized", "expected") { + case Row(features: Vector, normalized: Vector, expected: Vector) => + assertTypeOfVector(normalized, features) + assertValues(normalized, expected) + } } test("read/write") { @@ -115,7 +102,3 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa testDefaultReadWrite(t) } } - -private object NormalizerSuite { - case class FeatureData(features: Vector) -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala index 1d3f845586426..d549e13262273 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala @@ -17,18 +17,16 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{Encoder, Row} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ -class OneHotEncoderEstimatorSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class OneHotEncoderEstimatorSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -57,13 +55,10 @@ class OneHotEncoderEstimatorSuite assert(encoder.getDropLast === true) encoder.setDropLast(false) assert(encoder.getDropLast === false) - val model = encoder.fit(df) - val encoded = model.transform(df) - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector)](df, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) } } @@ -87,11 +82,9 @@ class OneHotEncoderEstimatorSuite .setOutputCols(Array("output")) val model = encoder.fit(df) - val encoded = model.transform(df) - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector)](df, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) } } @@ -103,11 +96,12 @@ class OneHotEncoderEstimatorSuite .setInputCols(Array("size")) .setOutputCols(Array("encoded")) val model = encoder.fit(df) - val output = model.transform(df) - val group = AttributeGroup.fromStructField(output.schema("encoded")) - assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows => + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + } } test("input column without ML attribute") { @@ -116,11 +110,12 @@ class OneHotEncoderEstimatorSuite .setInputCols(Array("index")) .setOutputCols(Array("encoded")) val model = encoder.fit(df) - val output = model.transform(df) - val group = AttributeGroup.fromStructField(output.schema("encoded")) - assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) + testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows => + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) + } } test("read/write") { @@ -151,29 +146,30 @@ class OneHotEncoderEstimatorSuite val df = spark.createDataFrame(sc.parallelize(data), schema) - val dfWithTypes = df - .withColumn("shortInput", df("input").cast(ShortType)) - .withColumn("longInput", df("input").cast(LongType)) - .withColumn("intInput", df("input").cast(IntegerType)) - .withColumn("floatInput", df("input").cast(FloatType)) - .withColumn("decimalInput", df("input").cast(DecimalType(10, 0))) - - val cols = Array("input", "shortInput", "longInput", "intInput", - "floatInput", "decimalInput") - for (col <- cols) { - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array(col)) + class NumericTypeWithEncoder[A](val numericType: NumericType) + (implicit val encoder: Encoder[(A, Vector)]) + + val types = Seq( + new NumericTypeWithEncoder[Short](ShortType), + new NumericTypeWithEncoder[Long](LongType), + new NumericTypeWithEncoder[Int](IntegerType), + new NumericTypeWithEncoder[Float](FloatType), + new NumericTypeWithEncoder[Byte](ByteType), + new NumericTypeWithEncoder[Double](DoubleType), + new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder())) + + for (t <- types) { + val dfWithTypes = df.select(col("input").cast(t.numericType), col("expected")) + val estimator = new OneHotEncoderEstimator() + .setInputCols(Array("input")) .setOutputCols(Array("output")) .setDropLast(false) - val model = encoder.fit(dfWithTypes) - val encoded = model.transform(dfWithTypes) - - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) - } + val model = estimator.fit(dfWithTypes) + testTransformer(dfWithTypes, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + }(t.encoder) } } @@ -202,12 +198,16 @@ class OneHotEncoderEstimatorSuite assert(encoder.getDropLast === false) val model = encoder.fit(df) - val encoded = model.transform(df) - encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3)) - }.collect().foreach { case (vec1, vec2, vec3, vec4) => - assert(vec1 === vec2) - assert(vec3 === vec4) + testTransformer[(Double, Vector, Double, Vector)]( + df, + model, + "output1", + "output2", + "expected1", + "expected2") { + case Row(output1: Vector, output2: Vector, expected1: Vector, expected2: Vector) => + assert(output1 === expected1) + assert(output2 === expected2) } } @@ -233,12 +233,16 @@ class OneHotEncoderEstimatorSuite .setOutputCols(Array("output1", "output2")) val model = encoder.fit(df) - val encoded = model.transform(df) - encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3)) - }.collect().foreach { case (vec1, vec2, vec3, vec4) => - assert(vec1 === vec2) - assert(vec3 === vec4) + testTransformer[(Double, Vector, Double, Vector)]( + df, + model, + "output1", + "output2", + "expected1", + "expected2") { + case Row(output1: Vector, output2: Vector, expected1: Vector, expected2: Vector) => + assert(output1 === expected1) + assert(output2 === expected2) } } @@ -253,10 +257,12 @@ class OneHotEncoderEstimatorSuite .setOutputCols(Array("encoded")) val model = encoder.fit(trainingDF) - val err = intercept[SparkException] { - model.transform(testDF).show - } - err.getMessage.contains("Unseen value: 3.0. To handle unseen values") + testTransformerByInterceptingException[(Int, Int)]( + testDF, + model, + expectedMessagePart = "Unseen value: 3.0. To handle unseen values", + firstResultCol = "encoded") + } test("Can't transform on negative input") { @@ -268,10 +274,11 @@ class OneHotEncoderEstimatorSuite .setOutputCols(Array("encoded")) val model = encoder.fit(trainingDF) - val err = intercept[SparkException] { - model.transform(testDF).collect() - } - err.getMessage.contains("Negative value: -1.0. Input can't be negative") + testTransformerByInterceptingException[(Int, Int)]( + testDF, + model, + expectedMessagePart = "Negative value: -1.0. Input can't be negative", + firstResultCol = "encoded") } test("Keep on invalid values: dropLast = false") { @@ -295,11 +302,9 @@ class OneHotEncoderEstimatorSuite .setDropLast(false) val model = encoder.fit(trainingDF) - val encoded = model.transform(testDF) - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector)](testDF, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) } } @@ -324,11 +329,9 @@ class OneHotEncoderEstimatorSuite .setDropLast(true) val model = encoder.fit(trainingDF) - val encoded = model.transform(testDF) - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector)](testDF, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) } } @@ -355,19 +358,15 @@ class OneHotEncoderEstimatorSuite val model = encoder.fit(df) model.setDropLast(false) - val encoded1 = model.transform(df) - encoded1.select("output", "expected1").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector, Vector)](df, model, "output", "expected1") { + case Row(output: Vector, expected1: Vector) => + assert(output === expected1) } model.setDropLast(true) - val encoded2 = model.transform(df) - encoded2.select("output", "expected2").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector, Vector)](df, model, "output", "expected2") { + case Row(output: Vector, expected2: Vector) => + assert(output === expected2) } } @@ -392,13 +391,14 @@ class OneHotEncoderEstimatorSuite val model = encoder.fit(trainingDF) model.setHandleInvalid("error") - val err = intercept[SparkException] { - model.transform(testDF).collect() - } - err.getMessage.contains("Unseen value: 3.0. To handle unseen values") + testTransformerByInterceptingException[(Double, Vector)]( + testDF, + model, + expectedMessagePart = "Unseen value: 3.0. To handle unseen values", + firstResultCol = "output") model.setHandleInvalid("keep") - model.transform(testDF).collect() + testTransformerByGlobalCheckFunc[(Double, Vector)](testDF, model, "output") { _ => } } test("Transforming on mismatched attributes") { @@ -413,9 +413,10 @@ class OneHotEncoderEstimatorSuite val testAttr = NominalAttribute.defaultAttr.withValues("tiny", "small", "medium", "large") val testDF = Seq(0.0, 1.0, 2.0, 3.0).map(Tuple1.apply).toDF("size") .select(col("size").as("size", testAttr.toMetadata())) - val err = intercept[Exception] { - model.transform(testDF).collect() - } - err.getMessage.contains("OneHotEncoderModel expected 2 categorical values") + testTransformerByInterceptingException[(Double)]( + testDF, + model, + expectedMessagePart = "OneHotEncoderModel expected 2 categorical values", + firstResultCol = "encoded") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index c44c6813a94be..41b32b2ffa096 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -17,18 +17,18 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.DataFrame +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{DataFrame, Encoder, Row} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ class OneHotEncoderSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -54,16 +54,19 @@ class OneHotEncoderSuite assert(encoder.getDropLast === true) encoder.setDropLast(false) assert(encoder.getDropLast === false) - val encoded = encoder.transform(transformed) - - val output = encoded.select("id", "labelVec").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1), vec(2)) - }.collect().toSet - // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0), - (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0)) - assert(output === expected) + val expected = Seq( + (0, Vectors.sparse(3, Seq((0, 1.0)))), + (1, Vectors.sparse(3, Seq((2, 1.0)))), + (2, Vectors.sparse(3, Seq((1, 1.0)))), + (3, Vectors.sparse(3, Seq((0, 1.0)))), + (4, Vectors.sparse(3, Seq((0, 1.0)))), + (5, Vectors.sparse(3, Seq((1, 1.0))))).toDF("id", "expected") + + val withExpected = transformed.join(expected, "id") + testTransformer[(Int, String, Double, Vector)](withExpected, encoder, "labelVec", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + } } test("OneHotEncoder dropLast = true") { @@ -71,16 +74,19 @@ class OneHotEncoderSuite val encoder = new OneHotEncoder() .setInputCol("labelIndex") .setOutputCol("labelVec") - val encoded = encoder.transform(transformed) - - val output = encoded.select("id", "labelVec").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1)) - }.collect().toSet - // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0), - (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0)) - assert(output === expected) + val expected = Seq( + (0, Vectors.sparse(2, Seq((0, 1.0)))), + (1, Vectors.sparse(2, Seq())), + (2, Vectors.sparse(2, Seq((1, 1.0)))), + (3, Vectors.sparse(2, Seq((0, 1.0)))), + (4, Vectors.sparse(2, Seq((0, 1.0)))), + (5, Vectors.sparse(2, Seq((1, 1.0))))).toDF("id", "expected") + + val withExpected = transformed.join(expected, "id") + testTransformer[(Int, String, Double, Vector)](withExpected, encoder, "labelVec", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + } } test("input column with ML attribute") { @@ -90,20 +96,22 @@ class OneHotEncoderSuite val encoder = new OneHotEncoder() .setInputCol("size") .setOutputCol("encoded") - val output = encoder.transform(df) - val group = AttributeGroup.fromStructField(output.schema("encoded")) - assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + testTransformerByGlobalCheckFunc[(Double)](df, encoder, "encoded") { rows => + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + } } + test("input column without ML attribute") { val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("index") val encoder = new OneHotEncoder() .setInputCol("index") .setOutputCol("encoded") - val output = encoder.transform(df) - val group = AttributeGroup.fromStructField(output.schema("encoded")) + val rows = encoder.transform(df).select("encoded").collect() + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) assert(group.size === 2) assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) @@ -119,29 +127,41 @@ class OneHotEncoderSuite test("OneHotEncoder with varying types") { val df = stringIndexed() - val dfWithTypes = df - .withColumn("shortLabel", df("labelIndex").cast(ShortType)) - .withColumn("longLabel", df("labelIndex").cast(LongType)) - .withColumn("intLabel", df("labelIndex").cast(IntegerType)) - .withColumn("floatLabel", df("labelIndex").cast(FloatType)) - .withColumn("decimalLabel", df("labelIndex").cast(DecimalType(10, 0))) - val cols = Array("labelIndex", "shortLabel", "longLabel", "intLabel", - "floatLabel", "decimalLabel") - for (col <- cols) { + val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") + val expected = Seq( + (0, Vectors.sparse(3, Seq((0, 1.0)))), + (1, Vectors.sparse(3, Seq((2, 1.0)))), + (2, Vectors.sparse(3, Seq((1, 1.0)))), + (3, Vectors.sparse(3, Seq((0, 1.0)))), + (4, Vectors.sparse(3, Seq((0, 1.0)))), + (5, Vectors.sparse(3, Seq((1, 1.0))))).toDF("id", "expected") + + val withExpected = df.join(expected, "id") + + class NumericTypeWithEncoder[A](val numericType: NumericType) + (implicit val encoder: Encoder[(A, Vector)]) + + val types = Seq( + new NumericTypeWithEncoder[Short](ShortType), + new NumericTypeWithEncoder[Long](LongType), + new NumericTypeWithEncoder[Int](IntegerType), + new NumericTypeWithEncoder[Float](FloatType), + new NumericTypeWithEncoder[Byte](ByteType), + new NumericTypeWithEncoder[Double](DoubleType), + new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder())) + + for (t <- types) { + val dfWithTypes = withExpected.select(col("labelIndex") + .cast(t.numericType).as("labelIndex", attr.toMetadata()), col("expected")) val encoder = new OneHotEncoder() - .setInputCol(col) + .setInputCol("labelIndex") .setOutputCol("labelVec") .setDropLast(false) - val encoded = encoder.transform(dfWithTypes) - - val output = encoded.select("id", "labelVec").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1), vec(2)) - }.collect().toSet - // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0), - (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0)) - assert(output === expected) + + testTransformer(dfWithTypes, encoder, "labelVec", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + }(t.encoder) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index 3067a52a4df76..531b1d7c4d9f7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.linalg.distributed.RowMatrix -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row -class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class PCASuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -62,10 +60,10 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val pcaModel = pca.fit(df) MLTestingUtils.checkCopyAndUids(pca, pcaModel) - - pcaModel.transform(df).select("pca_features", "expected").collect().foreach { - case Row(x: Vector, y: Vector) => - assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") + testTransformer[(Vector, Vector)](df, pcaModel, "pca_features", "expected") { + case Row(result: Vector, expected: Vector) => + assert(result ~== expected absTol 1e-5, + "Transformed vector is different with expected vector.") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index e4b0ddf98bfad..0be7aa6c83f29 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -17,18 +17,13 @@ package org.apache.spark.ml.feature -import org.scalatest.exceptions.TestFailedException - -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row -class PolynomialExpansionSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class PolynomialExpansionSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -60,6 +55,18 @@ class PolynomialExpansionSuite -1.08, 3.3, 1.98, -3.63, 9.0, 5.4, -9.9, -27.0), Vectors.sparse(19, Array.empty, Array.empty)) + def assertTypeOfVector(lhs: Vector, rhs: Vector): Unit = { + assert((lhs, rhs) match { + case (v1: DenseVector, v2: DenseVector) => true + case (v1: SparseVector, v2: SparseVector) => true + case _ => false + }, "The vector type should be preserved after polynomial expansion.") + } + + def assertValues(lhs: Vector, rhs: Vector): Unit = { + assert(lhs ~== rhs absTol 1e-1, "The vector value is not correct after polynomial expansion.") + } + test("Polynomial expansion with default parameter") { val df = data.zip(twoDegreeExpansion).toSeq.toDF("features", "expected") @@ -67,13 +74,10 @@ class PolynomialExpansionSuite .setInputCol("features") .setOutputCol("polyFeatures") - polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { - case Row(expanded: DenseVector, expected: DenseVector) => - assert(expanded ~== expected absTol 1e-1) - case Row(expanded: SparseVector, expected: SparseVector) => - assert(expanded ~== expected absTol 1e-1) - case _ => - throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + testTransformer[(Vector, Vector)](df, polynomialExpansion, "polyFeatures", "expected") { + case Row(expanded: Vector, expected: Vector) => + assertTypeOfVector(expanded, expected) + assertValues(expanded, expected) } } @@ -85,13 +89,10 @@ class PolynomialExpansionSuite .setOutputCol("polyFeatures") .setDegree(3) - polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { - case Row(expanded: DenseVector, expected: DenseVector) => - assert(expanded ~== expected absTol 1e-1) - case Row(expanded: SparseVector, expected: SparseVector) => - assert(expanded ~== expected absTol 1e-1) - case _ => - throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + testTransformer[(Vector, Vector)](df, polynomialExpansion, "polyFeatures", "expected") { + case Row(expanded: Vector, expected: Vector) => + assertTypeOfVector(expanded, expected) + assertValues(expanded, expected) } } @@ -103,11 +104,9 @@ class PolynomialExpansionSuite .setOutputCol("polyFeatures") .setDegree(1) - polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { + testTransformer[(Vector, Vector)](df, polynomialExpansion, "polyFeatures", "expected") { case Row(expanded: Vector, expected: Vector) => - assert(expanded ~== expected absTol 1e-1) - case _ => - throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + assertValues(expanded, expected) } } @@ -133,12 +132,13 @@ class PolynomialExpansionSuite .setOutputCol("polyFeatures") for (i <- Seq(10, 11)) { - val transformed = t.setDegree(i) - .transform(df) - .select(s"expectedPoly${i}size", "polyFeatures") - .rdd.map { case Row(expected: Int, v: Vector) => expected == v.size } - - assert(transformed.collect.forall(identity)) + testTransformer[(Vector, Int, Int)]( + df, + t.setDegree(i), + s"expectedPoly${i}size", + "polyFeatures") { case Row(size: Int, expected: Vector) => + assert(size === expected.size) + } } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 6c363799dd300..b009038bbd833 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -17,15 +17,11 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql._ -import org.apache.spark.sql.functions.udf -class QuantileDiscretizerSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -40,19 +36,19 @@ class QuantileDiscretizerSuite .setInputCol("input") .setOutputCol("result") .setNumBuckets(numBuckets) - val result = discretizer.fit(df).transform(df) - - val observedNumBuckets = result.select("result").distinct.count - assert(observedNumBuckets === numBuckets, - "Observed number of buckets does not equal expected number of buckets.") + val model = discretizer.fit(df) - val relativeError = discretizer.getRelativeError - val isGoodBucket = udf { - (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize) + testTransformerByGlobalCheckFunc[(Double)](df, model, "result") { rows => + val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result") + val observedNumBuckets = result.select("result").distinct.count + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") + val relativeError = discretizer.getRelativeError + val numGoodBuckets = result.groupBy("result").count + .filter(s"abs(count - ${datasetSize / numBuckets}) <= ${relativeError * datasetSize}").count + assert(numGoodBuckets === numBuckets, + "Bucket sizes are not within expected relative error tolerance.") } - val numGoodBuckets = result.groupBy("result").count.filter(isGoodBucket($"count")).count - assert(numGoodBuckets === numBuckets, - "Bucket sizes are not within expected relative error tolerance.") } test("Test on data with high proportion of duplicated values") { @@ -67,11 +63,14 @@ class QuantileDiscretizerSuite .setInputCol("input") .setOutputCol("result") .setNumBuckets(numBuckets) - val result = discretizer.fit(df).transform(df) - val observedNumBuckets = result.select("result").distinct.count - assert(observedNumBuckets == expectedNumBuckets, - s"Observed number of buckets are not correct." + - s" Expected $expectedNumBuckets but found $observedNumBuckets") + val model = discretizer.fit(df) + testTransformerByGlobalCheckFunc[(Double)](df, model, "result") { rows => + val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result") + val observedNumBuckets = result.select("result").distinct.count + assert(observedNumBuckets == expectedNumBuckets, + s"Observed number of buckets are not correct." + + s" Expected $expectedNumBuckets but found $observedNumBuckets") + } } test("Test transform on data with NaN value") { @@ -90,17 +89,20 @@ class QuantileDiscretizerSuite withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") { val dataFrame: DataFrame = validData.toSeq.toDF("input") - intercept[SparkException] { - discretizer.fit(dataFrame).transform(dataFrame).collect() - } + val model = discretizer.fit(dataFrame) + testTransformerByInterceptingException[(Double)]( + dataFrame, + model, + expectedMessagePart = "Bucketizer encountered NaN value.", + firstResultCol = "result") } List(("keep", expectedKeep), ("skip", expectedSkip)).foreach{ case(u, v) => discretizer.setHandleInvalid(u) val dataFrame: DataFrame = validData.zip(v).toSeq.toDF("input", "expected") - val result = discretizer.fit(dataFrame).transform(dataFrame) - result.select("result", "expected").collect().foreach { + val model = discretizer.fit(dataFrame) + testTransformer[(Double, Double)](dataFrame, model, "result", "expected") { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. Expected $y but found $x") @@ -119,14 +121,17 @@ class QuantileDiscretizerSuite .setOutputCol("result") .setNumBuckets(5) - val result = discretizer.fit(trainDF).transform(testDF) - val firstBucketSize = result.filter(result("result") === 0.0).count - val lastBucketSize = result.filter(result("result") === 4.0).count + val model = discretizer.fit(trainDF) + testTransformerByGlobalCheckFunc[(Double)](testDF, model, "result") { rows => + val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result") + val firstBucketSize = result.filter(result("result") === 0.0).count + val lastBucketSize = result.filter(result("result") === 4.0).count - assert(firstBucketSize === 30L, - s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.") - assert(lastBucketSize === 31L, - s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.") + assert(firstBucketSize === 30L, + s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.") + assert(lastBucketSize === 31L, + s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.") + } } test("read/write") { @@ -167,21 +172,24 @@ class QuantileDiscretizerSuite .setInputCols(Array("input1", "input2")) .setOutputCols(Array("result1", "result2")) .setNumBuckets(numBuckets) - val result = discretizer.fit(df).transform(df) - - val relativeError = discretizer.getRelativeError - val isGoodBucket = udf { - (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize) - } - - for (i <- 1 to 2) { - val observedNumBuckets = result.select("result" + i).distinct.count - assert(observedNumBuckets === numBuckets, - "Observed number of buckets does not equal expected number of buckets.") - - val numGoodBuckets = result.groupBy("result" + i).count.filter(isGoodBucket($"count")).count - assert(numGoodBuckets === numBuckets, - "Bucket sizes are not within expected relative error tolerance.") + val model = discretizer.fit(df) + testTransformerByGlobalCheckFunc[(Double, Double)](df, model, "result1", "result2") { rows => + val result = + rows.map { r => Tuple2(r.getDouble(0), r.getDouble(1)) }.toDF("result1", "result2") + val relativeError = discretizer.getRelativeError + for (i <- 1 to 2) { + val observedNumBuckets = result.select("result" + i).distinct.count + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") + + val numGoodBuckets = result + .groupBy("result" + i) + .count + .filter(s"abs(count - ${datasetSize / numBuckets}) <= ${relativeError * datasetSize}") + .count + assert(numGoodBuckets === numBuckets, + "Bucket sizes are not within expected relative error tolerance.") + } } } @@ -198,12 +206,16 @@ class QuantileDiscretizerSuite .setInputCols(Array("input1", "input2")) .setOutputCols(Array("result1", "result2")) .setNumBuckets(numBuckets) - val result = discretizer.fit(df).transform(df) - for (i <- 1 to 2) { - val observedNumBuckets = result.select("result" + i).distinct.count - assert(observedNumBuckets == expectedNumBucket, - s"Observed number of buckets are not correct." + - s" Expected $expectedNumBucket but found ($observedNumBuckets") + val model = discretizer.fit(df) + testTransformerByGlobalCheckFunc[(Double, Double)](df, model, "result1", "result2") { rows => + val result = + rows.map { r => Tuple2(r.getDouble(0), r.getDouble(1)) }.toDF("result1", "result2") + for (i <- 1 to 2) { + val observedNumBuckets = result.select("result" + i).distinct.count + assert(observedNumBuckets == expectedNumBucket, + s"Observed number of buckets are not correct." + + s" Expected $expectedNumBucket but found ($observedNumBuckets") + } } } @@ -226,9 +238,12 @@ class QuantileDiscretizerSuite withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") { val dataFrame: DataFrame = validData1.zip(validData2).toSeq.toDF("input1", "input2") - intercept[SparkException] { - discretizer.fit(dataFrame).transform(dataFrame).collect() - } + val model = discretizer.fit(dataFrame) + testTransformerByInterceptingException[(Double, Double)]( + dataFrame, + model, + expectedMessagePart = "Bucketizer encountered NaN value.", + firstResultCol = "result1") } List(("keep", expectedKeep1, expectedKeep2), ("skip", expectedSkip1, expectedSkip2)).foreach { @@ -237,8 +252,14 @@ class QuantileDiscretizerSuite val dataFrame: DataFrame = validData1.zip(validData2).zip(v).zip(w).map { case (((a, b), c), d) => (a, b, c, d) }.toSeq.toDF("input1", "input2", "expected1", "expected2") - val result = discretizer.fit(dataFrame).transform(dataFrame) - result.select("result1", "expected1", "result2", "expected2").collect().foreach { + val model = discretizer.fit(dataFrame) + testTransformer[(Double, Double, Double, Double)]( + dataFrame, + model, + "result1", + "expected1", + "result2", + "expected2") { case Row(x: Double, y: Double, z: Double, w: Double) => assert(x === y && w === z) } @@ -270,9 +291,16 @@ class QuantileDiscretizerSuite .setOutputCols(Array("result1", "result2", "result3")) .setNumBucketsArray(numBucketsArray) - discretizer.fit(df).transform(df). - select("result1", "expected1", "result2", "expected2", "result3", "expected3") - .collect().foreach { + val model = discretizer.fit(df) + testTransformer[(Double, Double, Double, Double, Double, Double)]( + df, + model, + "result1", + "expected1", + "result2", + "expected2", + "result3", + "expected3") { case Row(r1: Double, e1: Double, r2: Double, e2: Double, r3: Double, e3: Double) => assert(r1 === e1, s"The result value is not correct after bucketing. Expected $e1 but found $r1") @@ -324,20 +352,16 @@ class QuantileDiscretizerSuite .setStages(Array(discretizerForCol1, discretizerForCol2, discretizerForCol3)) .fit(df) - val resultForMultiCols = plForMultiCols.transform(df) - .select("result1", "result2", "result3") - .collect() - - val resultForSingleCol = plForSingleCol.transform(df) - .select("result1", "result2", "result3") - .collect() + val expected = plForSingleCol.transform(df).select("result1", "result2", "result3").collect() - resultForSingleCol.zip(resultForMultiCols).foreach { - case (rowForSingle, rowForMultiCols) => - assert(rowForSingle.getDouble(0) == rowForMultiCols.getDouble(0) && - rowForSingle.getDouble(1) == rowForMultiCols.getDouble(1) && - rowForSingle.getDouble(2) == rowForMultiCols.getDouble(2)) - } + testTransformerByGlobalCheckFunc[(Double, Double, Double)]( + df, + plForMultiCols, + "result1", + "result2", + "result3") { rows => + assert(rows === expected) + } } test("Multiple Columns: Comparing setting numBuckets with setting numBucketsArray " + @@ -364,18 +388,16 @@ class QuantileDiscretizerSuite .setOutputCols(Array("result1", "result2", "result3")) .setNumBucketsArray(Array(10, 10, 10)) - val result1 = discretizerSingleNumBuckets.fit(df).transform(df) - .select("result1", "result2", "result3") - .collect() - val result2 = discretizerNumBucketsArray.fit(df).transform(df) - .select("result1", "result2", "result3") - .collect() - - result1.zip(result2).foreach { - case (row1, row2) => - assert(row1.getDouble(0) == row2.getDouble(0) && - row1.getDouble(1) == row2.getDouble(1) && - row1.getDouble(2) == row2.getDouble(2)) + val model = discretizerSingleNumBuckets.fit(df) + val expected = model.transform(df).select("result1", "result2", "result3").collect() + + testTransformerByGlobalCheckFunc[(Double, Double, Double)]( + df, + discretizerNumBucketsArray.fit(df), + "result1", + "result2", + "result3") { rows => + assert(rows === expected) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index bfe38d32dd77d..27d570f0b68ad 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkException import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite @@ -32,10 +31,20 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { def testRFormulaTransform[A: Encoder]( dataframe: DataFrame, formulaModel: RFormulaModel, - expected: DataFrame): Unit = { + expected: DataFrame, + expectedAttributes: AttributeGroup*): Unit = { + val resultSchema = formulaModel.transformSchema(dataframe.schema) + assert(resultSchema.json === expected.schema.json) + assert(resultSchema === expected.schema) val (first +: rest) = expected.schema.fieldNames.toSeq val expectedRows = expected.collect() testTransformerByGlobalCheckFunc[A](dataframe, formulaModel, first, rest: _*) { rows => + assert(rows.head.schema.toString() == resultSchema.toString()) + for (expectedAttributeGroup <- expectedAttributes) { + val attributeGroup = + AttributeGroup.fromStructField(rows.head.schema(expectedAttributeGroup.name)) + assert(attributeGroup === expectedAttributeGroup) + } assert(rows === expectedRows) } } @@ -49,15 +58,10 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2") val model = formula.fit(original) MLTestingUtils.checkCopyAndUids(formula, model) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) val expected = Seq( (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0), (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0) ).toDF("id", "v1", "v2", "features", "label") - // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString - assert(result.schema.toString == resultSchema.toString) - assert(resultSchema == expected.schema) testRFormulaTransform[(Int, Double, Double)](original, model, expected) } @@ -73,9 +77,13 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") val original = Seq((0, 1.0), (2, 2.0)).toDF("x", "y") val model = formula.fit(original) + val expected = Seq( + (0, 1.0, Vectors.dense(0.0)), + (2, 2.0, Vectors.dense(2.0)) + ).toDF("x", "y", "features") val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) - assert(resultSchema.toString == model.transform(original).schema.toString) + testRFormulaTransform[(Int, Double)](original, model, expected) } test("label column already exists but forceIndexLabel was set with true") { @@ -93,9 +101,11 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { intercept[IllegalArgumentException] { model.transformSchema(original.schema) } - intercept[IllegalArgumentException] { - model.transform(original) - } + testTransformerByInterceptingException[(Int, Boolean)]( + original, + model, + "Label column already exists and is not of type NumericType.", + "x") } test("allow missing label column for test datasets") { @@ -105,21 +115,22 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) assert(!resultSchema.exists(_.name == "label")) - assert(resultSchema.toString == model.transform(original).schema.toString) + val expected = Seq( + (0, 1.0, Vectors.dense(0.0)), + (2, 2.0, Vectors.dense(2.0)) + ).toDF("x", "_not_y", "features") + testRFormulaTransform[(Int, Double)](original, model, expected) } test("allow empty label") { val original = Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0)).toDF("id", "a", "b") val formula = new RFormula().setFormula("~ a + b") val model = formula.fit(original) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) val expected = Seq( (1, 2.0, 3.0, Vectors.dense(2.0, 3.0)), (4, 5.0, 6.0, Vectors.dense(5.0, 6.0)), (7, 8.0, 9.0, Vectors.dense(8.0, 9.0)) ).toDF("id", "a", "b", "features") - assert(result.schema.toString == resultSchema.toString) testRFormulaTransform[(Int, Double, Double)](original, model, expected) } @@ -128,15 +139,12 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) val expected = Seq( (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0), (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0) ).toDF("id", "a", "b", "features", "label") - assert(result.schema.toString == resultSchema.toString) testRFormulaTransform[(Int, String, Int)](original, model, expected) } @@ -175,9 +183,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { var idx = 0 for (orderType <- StringIndexer.supportedStringOrderType) { val model = formula.setStringIndexerOrderType(orderType).fit(original) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) - assert(result.schema.toString == resultSchema.toString) testRFormulaTransform[(Int, String, Int)](original, model, expected(idx)) idx += 1 } @@ -218,9 +223,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { ).toDF("id", "a", "b", "features", "label") val model = formula.fit(original) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) - assert(result.schema.toString == resultSchema.toString) testRFormulaTransform[(Int, String, Int)](original, model, expected) } @@ -254,19 +256,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val formula1 = new RFormula().setFormula("id ~ a + b + c - 1") .setStringIndexerOrderType(StringIndexer.alphabetDesc) val model1 = formula1.fit(original) - val result1 = model1.transform(original) - val resultSchema1 = model1.transformSchema(original.schema) - // Note the column order is different between R and Spark. - val expected1 = Seq( - (1, "foo", "zq", 4, Vectors.sparse(5, Array(0, 4), Array(1.0, 4.0)), 1.0), - (2, "bar", "zz", 4, Vectors.dense(0.0, 0.0, 1.0, 1.0, 4.0), 2.0), - (3, "bar", "zz", 5, Vectors.dense(0.0, 0.0, 1.0, 1.0, 5.0), 3.0), - (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0) - ).toDF("id", "a", "b", "c", "features", "label") - assert(result1.schema.toString == resultSchema1.toString) - testRFormulaTransform[(Int, String, String, Int)](original, model1, expected1) - - val attrs1 = AttributeGroup.fromStructField(result1.schema("features")) val expectedAttrs1 = new AttributeGroup( "features", Array[Attribute]( @@ -275,14 +264,20 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { new BinaryAttribute(Some("a_bar"), Some(3)), new BinaryAttribute(Some("b_zz"), Some(4)), new NumericAttribute(Some("c"), Some(5)))) - assert(attrs1 === expectedAttrs1) + // Note the column order is different between R and Spark. + val expected1 = Seq( + (1, "foo", "zq", 4, Vectors.sparse(5, Array(0, 4), Array(1.0, 4.0)), 1.0), + (2, "bar", "zz", 4, Vectors.dense(0.0, 0.0, 1.0, 1.0, 4.0), 2.0), + (3, "bar", "zz", 5, Vectors.dense(0.0, 0.0, 1.0, 1.0, 5.0), 3.0), + (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0) + ).toDF("id", "a", "b", "c", "features", "label") + + testRFormulaTransform[(Int, String, String, Int)](original, model1, expected1, expectedAttrs1) // There is no impact for string terms interaction. val formula2 = new RFormula().setFormula("id ~ a:b + c - 1") .setStringIndexerOrderType(StringIndexer.alphabetDesc) val model2 = formula2.fit(original) - val result2 = model2.transform(original) - val resultSchema2 = model2.transformSchema(original.schema) // Note the column order is different between R and Spark. val expected2 = Seq( (1, "foo", "zq", 4, Vectors.sparse(7, Array(1, 6), Array(1.0, 4.0)), 1.0), @@ -290,10 +285,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (3, "bar", "zz", 5, Vectors.sparse(7, Array(4, 6), Array(1.0, 5.0)), 3.0), (4, "baz", "zz", 5, Vectors.sparse(7, Array(2, 6), Array(1.0, 5.0)), 4.0) ).toDF("id", "a", "b", "c", "features", "label") - assert(result2.schema.toString == resultSchema2.toString) - testRFormulaTransform[(Int, String, String, Int)](original, model2, expected2) - - val attrs2 = AttributeGroup.fromStructField(result2.schema("features")) val expectedAttrs2 = new AttributeGroup( "features", Array[Attribute]( @@ -304,7 +295,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { new NumericAttribute(Some("a_bar:b_zz"), Some(5)), new NumericAttribute(Some("a_bar:b_zq"), Some(6)), new NumericAttribute(Some("c"), Some(7)))) - assert(attrs2 === expectedAttrs2) + + testRFormulaTransform[(Int, String, String, Int)](original, model2, expected2, expectedAttrs2) } test("index string label") { @@ -313,13 +305,14 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) + val attr = NominalAttribute.defaultAttr val expected = Seq( ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), ("female", "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 0.0), ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0) ).toDF("id", "a", "b", "features", "label") - // assert(result.schema.toString == resultSchema.toString) + .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata())) testRFormulaTransform[(String, String, Int)](original, model, expected) } @@ -329,13 +322,14 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5)) ).toDF("id", "a", "b") val model = formula.fit(original) - val expected = spark.createDataFrame( - Seq( + val attr = NominalAttribute.defaultAttr + val expected = Seq( (1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0), (1.0, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), (0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0), (1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0)) - ).toDF("id", "a", "b", "features", "label") + .toDF("id", "a", "b", "features", "label") + .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata())) testRFormulaTransform[(Double, String, Int)](original, model, expected) } @@ -344,15 +338,20 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) - val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expected = Seq( + (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), + (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), + (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0), + (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0)) + .toDF("id", "a", "b", "features", "label") val expectedAttrs = new AttributeGroup( "features", Array( new BinaryAttribute(Some("a_bar"), Some(1)), new BinaryAttribute(Some("a_foo"), Some(2)), new NumericAttribute(Some("b"), Some(3)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, String, Int)](original, model, expected, expectedAttrs) + } test("vector attribute generation") { @@ -360,14 +359,19 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val original = Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) .toDF("id", "vec") val model = formula.fit(original) - val result = model.transform(original) - val attrs = AttributeGroup.fromStructField(result.schema("features")) + val attrs = new AttributeGroup("vec", 2) + val expected = Seq( + (1, Vectors.dense(0.0, 1.0), Vectors.dense(0.0, 1.0), 1.0), + (2, Vectors.dense(1.0, 2.0), Vectors.dense(1.0, 2.0), 2.0)) + .toDF("id", "vec", "features", "label") + .select($"id", $"vec".as("vec", attrs.toMetadata()), $"features", $"label") val expectedAttrs = new AttributeGroup( "features", Array[Attribute]( new NumericAttribute(Some("vec_0"), Some(1)), new NumericAttribute(Some("vec_1"), Some(2)))) - assert(attrs === expectedAttrs) + + testRFormulaTransform[(Int, Vector)](original, model, expected, expectedAttrs) } test("vector attribute generation with unnamed input attrs") { @@ -381,31 +385,31 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { NumericAttribute.defaultAttr)).toMetadata() val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata)) val model = formula.fit(original) - val result = model.transform(original) - val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expected = Seq( + (1, Vectors.dense(0.0, 1.0), Vectors.dense(0.0, 1.0), 1.0), + (2, Vectors.dense(1.0, 2.0), Vectors.dense(1.0, 2.0), 2.0) + ).toDF("id", "vec2", "features", "label") + .select($"id", $"vec2".as("vec2", metadata), $"features", $"label") val expectedAttrs = new AttributeGroup( "features", Array[Attribute]( new NumericAttribute(Some("vec2_0"), Some(1)), new NumericAttribute(Some("vec2_1"), Some(2)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, Vector)](original, model, expected, expectedAttrs) } test("numeric interaction") { val formula = new RFormula().setFormula("a ~ b:c:d") val original = Seq((1, 2, 4, 2), (2, 3, 4, 1)).toDF("a", "b", "c", "d") val model = formula.fit(original) - val result = model.transform(original) val expected = Seq( (1, 2, 4, 2, Vectors.dense(16.0), 1.0), (2, 3, 4, 1, Vectors.dense(12.0), 2.0) ).toDF("a", "b", "c", "d", "features", "label") - testRFormulaTransform[(Int, Int, Int, Int)](original, model, expected) - val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", Array[Attribute](new NumericAttribute(Some("b:c:d"), Some(1)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, Int, Int, Int)](original, model, expected, expectedAttrs) } test("factor numeric interaction") { @@ -414,7 +418,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) val expected = Seq( (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0), @@ -423,15 +426,13 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0) ).toDF("id", "a", "b", "features", "label") - testRFormulaTransform[(Int, String, Int)](original, model, expected) - val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", Array[Attribute]( new NumericAttribute(Some("a_baz:b"), Some(1)), new NumericAttribute(Some("a_bar:b"), Some(2)), new NumericAttribute(Some("a_foo:b"), Some(3)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, String, Int)](original, model, expected, expectedAttrs) } test("factor factor interaction") { @@ -439,14 +440,12 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val original = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")).toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) val expected = Seq( (1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0), (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0), (3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0) ).toDF("id", "a", "b", "features", "label") testRFormulaTransform[(Int, String, String)](original, model, expected) - val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", Array[Attribute]( @@ -454,7 +453,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { new NumericAttribute(Some("a_bar:b_zz"), Some(2)), new NumericAttribute(Some("a_foo:b_zq"), Some(3)), new NumericAttribute(Some("a_foo:b_zz"), Some(4)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, String, String)](original, model, expected, expectedAttrs) } test("read/write: RFormula") { @@ -517,9 +516,11 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { // Handle unseen features. val formula1 = new RFormula().setFormula("id ~ a + b") - intercept[SparkException] { - formula1.fit(df1).transform(df2).collect() - } + testTransformerByInterceptingException[(Int, String, String)]( + df2, + formula1.fit(df1), + "Unseen label:", + "features") val model1 = formula1.setHandleInvalid("skip").fit(df1) val model2 = formula1.setHandleInvalid("keep").fit(df1) @@ -538,21 +539,28 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { // Handle unseen labels. val formula2 = new RFormula().setFormula("b ~ a + id") - intercept[SparkException] { - formula2.fit(df1).transform(df2).collect() - } + testTransformerByInterceptingException[(Int, String, String)]( + df2, + formula2.fit(df1), + "Unseen label:", + "label") + val model3 = formula2.setHandleInvalid("skip").fit(df1) val model4 = formula2.setHandleInvalid("keep").fit(df1) + val attr = NominalAttribute.defaultAttr val expected3 = Seq( (1, "foo", "zq", Vectors.dense(0.0, 1.0), 0.0), (2, "bar", "zq", Vectors.dense(1.0, 2.0), 0.0) ).toDF("id", "a", "b", "features", "label") + .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata())) + val expected4 = Seq( (1, "foo", "zq", Vectors.dense(0.0, 1.0, 1.0), 0.0), (2, "bar", "zq", Vectors.dense(1.0, 0.0, 2.0), 0.0), (3, "bar", "zy", Vectors.dense(1.0, 0.0, 3.0), 2.0) ).toDF("id", "a", "b", "features", "label") + .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata())) testRFormulaTransform[(Int, String, String)](df2, model3, expected3) testRFormulaTransform[(Int, String, String)](df2, model4, expected4) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala index 673a146e619f2..cf09418d8e0a2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -17,15 +17,12 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.types.{LongType, StructField, StructType} import org.apache.spark.storage.StorageLevel -class SQLTransformerSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class SQLTransformerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -37,14 +34,22 @@ class SQLTransformerSuite val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2") val sqlTrans = new SQLTransformer().setStatement( "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") - val result = sqlTrans.transform(original) - val resultSchema = sqlTrans.transformSchema(original.schema) - val expected = Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0)) + val expected = Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0)) .toDF("id", "v1", "v2", "v3", "v4") - assert(result.schema.toString == resultSchema.toString) - assert(resultSchema == expected.schema) - assert(result.collect().toSeq == expected.collect().toSeq) - assert(original.sparkSession.catalog.listTables().count() == 0) + val resultSchema = sqlTrans.transformSchema(original.schema) + testTransformerByGlobalCheckFunc[(Int, Double, Double)]( + original, + sqlTrans, + "id", + "v1", + "v2", + "v3", + "v4") { rows => + assert(rows.head.schema.toString == resultSchema.toString) + assert(resultSchema == expected.schema) + assert(rows == expected.collect().toSeq) + assert(original.sparkSession.catalog.listTables().count() == 0) + } } test("read/write") { @@ -63,13 +68,13 @@ class SQLTransformerSuite } test("SPARK-22538: SQLTransformer should not unpersist given dataset") { - val df = spark.range(10) + val df = spark.range(10).toDF() df.cache() df.count() assert(df.storageLevel != StorageLevel.NONE) - new SQLTransformer() + val sqlTrans = new SQLTransformer() .setStatement("SELECT id + 1 AS id1 FROM __THIS__") - .transform(df) + testTransformerByGlobalCheckFunc[Long](df, sqlTrans, "id1") { _ => } assert(df.storageLevel != StorageLevel.NONE) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index 350ba44baa1eb..c5c49d67194e4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -17,16 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { +class StandardScalerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -60,12 +57,10 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext ) } - def assertResult(df: DataFrame): Unit = { - df.select("standardized_features", "expected").collect().foreach { - case Row(vector1: Vector, vector2: Vector) => - assert(vector1 ~== vector2 absTol 1E-5, - "The vector value is not correct after standardization.") - } + def assertResult: Row => Unit = { + case Row(vector1: Vector, vector2: Vector) => + assert(vector1 ~== vector2 absTol 1E-5, + "The vector value is not correct after standardization.") } test("params") { @@ -83,7 +78,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext val standardScaler0 = standardScalerEst0.fit(df0) MLTestingUtils.checkCopyAndUids(standardScalerEst0, standardScaler0) - assertResult(standardScaler0.transform(df0)) + testTransformer[(Vector, Vector)](df0, standardScaler0, "standardized_features", "expected")( + assertResult) } test("Standardization with setter") { @@ -112,9 +108,12 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext .setWithStd(false) .fit(df3) - assertResult(standardScaler1.transform(df1)) - assertResult(standardScaler2.transform(df2)) - assertResult(standardScaler3.transform(df3)) + testTransformer[(Vector, Vector)](df1, standardScaler1, "standardized_features", "expected")( + assertResult) + testTransformer[(Vector, Vector)](df2, standardScaler2, "standardized_features", "expected")( + assertResult) + testTransformer[(Vector, Vector)](df3, standardScaler3, "standardized_features", "expected")( + assertResult) } test("sparse data and withMean") { @@ -130,7 +129,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext .setWithMean(true) .setWithStd(false) .fit(df) - assertResult(standardScaler.transform(df)) + testTransformer[(Vector, Vector)](df, standardScaler, "standardized_features", "expected")( + assertResult) } test("StandardScaler read/write") { @@ -149,4 +149,5 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext assert(newInstance.std === instance.std) assert(newInstance.mean === instance.mean) } + } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index 5262b146b184e..21259a50916d2 100755 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -17,28 +17,20 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Dataset, Row} - -object StopWordsRemoverSuite extends SparkFunSuite { - def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = { - t.transform(dataset) - .select("filtered", "expected") - .collect() - .foreach { case Row(tokens, wantedTokens) => - assert(tokens === wantedTokens) - } - } -} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{DataFrame, Row} -class StopWordsRemoverSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest { - import StopWordsRemoverSuite._ import testImplicits._ + def testStopWordsRemover(t: StopWordsRemover, dataFrame: DataFrame): Unit = { + testTransformer[(Array[String], Array[String])](dataFrame, t, "filtered", "expected") { + case Row(tokens: Seq[_], wantedTokens: Seq[_]) => + assert(tokens === wantedTokens) + } + } + test("StopWordsRemover default") { val remover = new StopWordsRemover() .setInputCol("raw") @@ -151,9 +143,10 @@ class StopWordsRemoverSuite .setOutputCol(outputCol) val dataSet = Seq((Seq("The", "the", "swift"), Seq("swift"))).toDF("raw", outputCol) - val thrown = intercept[IllegalArgumentException] { - testStopWordsRemover(remover, dataSet) - } - assert(thrown.getMessage == s"requirement failed: Column $outputCol already exists.") + testTransformerByInterceptingException[(Array[String], Array[String])]( + dataSet, + remover, + s"requirement failed: Column $outputCol already exists.", + "expected") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 775a04d3df050..df24367177011 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,17 +17,14 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType} -class StringIndexerSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class StringIndexerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -46,19 +43,23 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") val indexerModel = indexer.fit(df) - MLTestingUtils.checkCopyAndUids(indexer, indexerModel) - - val transformed = indexerModel.transform(df) - val attr = Attribute.fromStructField(transformed.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attr.values.get === Array("a", "c", "b")) - val output = transformed.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) - assert(output === expected) + val expected = Seq( + (0, 0.0), + (1, 2.0), + (2, 1.0), + (3, 0.0), + (4, 0.0), + (5, 1.0) + ).toDF("id", "labelIndex") + + testTransformerByGlobalCheckFunc[(Int, String)](df, indexerModel, "id", "labelIndex") { rows => + val attr = Attribute.fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("a", "c", "b")) + assert(rows.seq === expected.collect().toSeq) + } } test("StringIndexerUnseen") { @@ -70,36 +71,38 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") .fit(df) + // Verify we throw by default with unseen values - intercept[SparkException] { - indexer.transform(df2).collect() - } + testTransformerByInterceptingException[(Int, String)]( + df2, + indexer, + "Unseen label:", + "labelIndex") - indexer.setHandleInvalid("skip") // Verify that we skip the c record - val transformedSkip = indexer.transform(df2) - val attrSkip = Attribute.fromStructField(transformedSkip.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrSkip.values.get === Array("b", "a")) - val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet // a -> 1, b -> 0 - val expectedSkip = Set((0, 1.0), (1, 0.0)) - assert(outputSkip === expectedSkip) + indexer.setHandleInvalid("skip") + + val expectedSkip = Seq((0, 1.0), (1, 0.0)).toDF() + testTransformerByGlobalCheckFunc[(Int, String)](df2, indexer, "id", "labelIndex") { rows => + val attrSkip = Attribute.fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrSkip.values.get === Array("b", "a")) + assert(rows.seq === expectedSkip.collect().toSeq) + } indexer.setHandleInvalid("keep") - // Verify that we keep the unseen records - val transformedKeep = indexer.transform(df2) - val attrKeep = Attribute.fromStructField(transformedKeep.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrKeep.values.get === Array("b", "a", "__unknown")) - val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet + // a -> 1, b -> 0, c -> 2, d -> 3 - val expectedKeep = Set((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0)) - assert(outputKeep === expectedKeep) + val expectedKeep = Seq((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0)).toDF() + + // Verify that we keep the unseen records + testTransformerByGlobalCheckFunc[(Int, String)](df2, indexer, "id", "labelIndex") { rows => + val attrKeep = Attribute.fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrKeep.values.get === Array("b", "a", "__unknown")) + assert(rows === expectedKeep.collect().toSeq) + } } test("StringIndexer with a numeric input column") { @@ -109,16 +112,14 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") .fit(df) - val transformed = indexer.transform(df) - val attr = Attribute.fromStructField(transformed.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attr.values.get === Array("100", "300", "200")) - val output = transformed.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet // 100 -> 0, 200 -> 2, 300 -> 1 - val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) - assert(output === expected) + val expected = Seq((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)).toDF() + testTransformerByGlobalCheckFunc[(Int, String)](df, indexer, "id", "labelIndex") { rows => + val attr = Attribute.fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("100", "300", "200")) + assert(rows === expected.collect().toSeq) + } } test("StringIndexer with NULLs") { @@ -133,37 +134,36 @@ class StringIndexerSuite withClue("StringIndexer should throw error when setHandleInvalid=error " + "when given NULL values") { - intercept[SparkException] { - indexer.setHandleInvalid("error") - indexer.fit(df).transform(df2).collect() - } + indexer.setHandleInvalid("error") + testTransformerByInterceptingException[(Int, String)]( + df2, + indexer.fit(df), + "StringIndexer encountered NULL value.", + "labelIndex") } indexer.setHandleInvalid("skip") - val transformedSkip = indexer.fit(df).transform(df2) - val attrSkip = Attribute - .fromStructField(transformedSkip.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrSkip.values.get === Array("b", "a")) - val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet + val modelSkip = indexer.fit(df) // a -> 1, b -> 0 - val expectedSkip = Set((0, 1.0), (1, 0.0)) - assert(outputSkip === expectedSkip) + val expectedSkip = Seq((0, 1.0), (1, 0.0)).toDF() + testTransformerByGlobalCheckFunc[(Int, String)](df2, modelSkip, "id", "labelIndex") { rows => + val attrSkip = + Attribute.fromStructField(rows.head.schema("labelIndex")).asInstanceOf[NominalAttribute] + assert(attrSkip.values.get === Array("b", "a")) + assert(rows === expectedSkip.collect().toSeq) + } indexer.setHandleInvalid("keep") - val transformedKeep = indexer.fit(df).transform(df2) - val attrKeep = Attribute - .fromStructField(transformedKeep.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrKeep.values.get === Array("b", "a", "__unknown")) - val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet // a -> 1, b -> 0, null -> 2 - val expectedKeep = Set((0, 1.0), (1, 0.0), (3, 2.0)) - assert(outputKeep === expectedKeep) + val expectedKeep = Seq((0, 1.0), (1, 0.0), (3, 2.0)).toDF() + val modelKeep = indexer.fit(df) + testTransformerByGlobalCheckFunc[(Int, String)](df2, modelKeep, "id", "labelIndex") { rows => + val attrKeep = Attribute + .fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrKeep.values.get === Array("b", "a", "__unknown")) + assert(rows === expectedKeep.collect().toSeq) + } } test("StringIndexerModel should keep silent if the input column does not exist.") { @@ -171,7 +171,9 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") val df = spark.range(0L, 10L).toDF() - assert(indexerModel.transform(df).collect().toSet === df.collect().toSet) + testTransformerByGlobalCheckFunc[Long](df, indexerModel, "id") { rows => + assert(rows.toSet === df.collect().toSet) + } } test("StringIndexerModel can't overwrite output column") { @@ -188,9 +190,12 @@ class StringIndexerSuite .setOutputCol("indexedInput") .fit(df) - intercept[IllegalArgumentException] { - indexer.setOutputCol("output").transform(df) - } + testTransformerByInterceptingException[(Int, String)]( + df, + indexer.setOutputCol("output"), + "Output column output already exists.", + "labelIndex") + } test("StringIndexer read/write") { @@ -223,7 +228,8 @@ class StringIndexerSuite .setInputCol("index") .setOutputCol("actual") .setLabels(labels) - idxToStr0.transform(df0).select("actual", "expected").collect().foreach { + + testTransformer[(Int, String)](df0, idxToStr0, "actual", "expected") { case Row(actual, expected) => assert(actual === expected) } @@ -234,7 +240,8 @@ class StringIndexerSuite val idxToStr1 = new IndexToString() .setInputCol("indexWithAttr") .setOutputCol("actual") - idxToStr1.transform(df1).select("actual", "expected").collect().foreach { + + testTransformer[(Int, String)](df1, idxToStr1, "actual", "expected") { case Row(actual, expected) => assert(actual === expected) } @@ -252,9 +259,10 @@ class StringIndexerSuite .setInputCol("labelIndex") .setOutputCol("sameLabel") .setLabels(indexer.labels) - idx2str.transform(transformed).select("label", "sameLabel").collect().foreach { - case Row(a: String, b: String) => - assert(a === b) + + testTransformer[(Int, String, Double)](transformed, idx2str, "sameLabel", "label") { + case Row(sameLabel, label) => + assert(sameLabel === label) } } @@ -286,10 +294,11 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") .fit(df) - val transformed = indexer.transform(df) - val attrs = - NominalAttribute.decodeStructField(transformed.schema("labelIndex"), preserveName = true) - assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex") + testTransformerByGlobalCheckFunc[(Int, String)](df, indexer, "labelIndex") { rows => + val attrs = + NominalAttribute.decodeStructField(rows.head.schema("labelIndex"), preserveName = true) + assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex") + } } test("StringIndexer order types") { @@ -299,18 +308,17 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") - val expected = Seq(Set((0, 0.0), (1, 0.0), (2, 2.0), (3, 1.0), (4, 1.0), (5, 0.0)), - Set((0, 2.0), (1, 2.0), (2, 0.0), (3, 1.0), (4, 1.0), (5, 2.0)), - Set((0, 1.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 1.0)), - Set((0, 1.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 1.0))) + val expected = Seq(Seq((0, 0.0), (1, 0.0), (2, 2.0), (3, 1.0), (4, 1.0), (5, 0.0)), + Seq((0, 2.0), (1, 2.0), (2, 0.0), (3, 1.0), (4, 1.0), (5, 2.0)), + Seq((0, 1.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 1.0)), + Seq((0, 1.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 1.0))) var idx = 0 for (orderType <- StringIndexer.supportedStringOrderType) { - val transformed = indexer.setStringOrderType(orderType).fit(df).transform(df) - val output = transformed.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet - assert(output === expected(idx)) + val model = indexer.setStringOrderType(orderType).fit(df) + testTransformerByGlobalCheckFunc[(Int, String)](df, model, "id", "labelIndex") { rows => + assert(rows === expected(idx).toDF().collect().toSeq) + } idx += 1 } } @@ -328,7 +336,11 @@ class StringIndexerSuite .setOutputCol("CITYIndexed") .fit(dfNoBristol) - val dfWithIndex = model.transform(dfNoBristol) - assert(dfWithIndex.filter($"CITYIndexed" === 1.0).count == 1) + testTransformerByGlobalCheckFunc[(String, String, String)]( + dfNoBristol, + model, + "CITYIndexed") { rows => + assert(rows.toList.count(_.getDouble(0) == 1.0) === 1) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index c895659a2d8be..be59b0af2c78e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -19,16 +19,14 @@ package org.apache.spark.ml.feature import scala.beans.BeanInfo -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) -class TokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class TokenizerSuite extends MLTest with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new Tokenizer) @@ -42,12 +40,17 @@ class TokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } } -class RegexTokenizerSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class RegexTokenizerSuite extends MLTest with DefaultReadWriteTest { - import org.apache.spark.ml.feature.RegexTokenizerSuite._ import testImplicits._ + def testRegexTokenizer(t: RegexTokenizer, dataframe: DataFrame): Unit = { + testTransformer[(String, Seq[String])](dataframe, t, "tokens", "wantedTokens") { + case Row(tokens, wantedTokens) => + assert(tokens === wantedTokens) + } + } + test("params") { ParamsSuite.checkParams(new RegexTokenizer) } @@ -105,14 +108,3 @@ class RegexTokenizerSuite } } -object RegexTokenizerSuite extends SparkFunSuite { - - def testRegexTokenizer(t: RegexTokenizer, dataset: Dataset[_]): Unit = { - t.transform(dataset) - .select("tokens", "wantedTokens") - .collect() - .foreach { case Row(tokens, wantedTokens) => - assert(tokens === wantedTokens) - } - } -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 69a7b75e32eb7..e5675e31bbecf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -19,18 +19,16 @@ package org.apache.spark.ml.feature import scala.beans.{BeanInfo, BeanProperty} -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} -class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest with Logging { +class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging { import testImplicits._ import VectorIndexerSuite.FeatureData @@ -128,18 +126,27 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext MLTestingUtils.checkCopyAndUids(vectorIndexer, model) - model.transform(densePoints1) // should work - model.transform(sparsePoints1) // should work + testTransformer[FeatureData](densePoints1, model, "indexed") { _ => } + testTransformer[FeatureData](sparsePoints1, model, "indexed") { _ => } + // If the data is local Dataset, it throws AssertionError directly. - intercept[AssertionError] { - model.transform(densePoints2).collect() - logInfo("Did not throw error when fit, transform were called on vectors of different lengths") + withClue("Did not throw error when fit, transform were called on " + + "vectors of different lengths") { + testTransformerByInterceptingException[FeatureData]( + densePoints2, + model, + "VectorIndexerModel expected vector of length 3 but found length 4", + "indexed") } // If the data is distributed Dataset, it throws SparkException // which is the wrapper of AssertionError. - intercept[SparkException] { - model.transform(densePoints2.repartition(2)).collect() - logInfo("Did not throw error when fit, transform were called on vectors of different lengths") + withClue("Did not throw error when fit, transform were called " + + "on vectors of different lengths") { + testTransformerByInterceptingException[FeatureData]( + densePoints2.repartition(2), + model, + "VectorIndexerModel expected vector of length 3 but found length 4", + "indexed") } intercept[SparkException] { vectorIndexer.fit(badPoints) @@ -178,46 +185,48 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext val categoryMaps = model.categoryMaps // Chose correct categorical features assert(categoryMaps.keys.toSet === categoricalFeatures) - val transformed = model.transform(data).select("indexed") - val indexedRDD: RDD[Vector] = transformed.rdd.map(_.getAs[Vector](0)) - val featureAttrs = AttributeGroup.fromStructField(transformed.schema("indexed")) - assert(featureAttrs.name === "indexed") - assert(featureAttrs.attributes.get.length === model.numFeatures) - categoricalFeatures.foreach { feature: Int => - val origValueSet = collectedData.map(_(feature)).toSet - val targetValueIndexSet = Range(0, origValueSet.size).toSet - val catMap = categoryMaps(feature) - assert(catMap.keys.toSet === origValueSet) // Correct categories - assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices - if (origValueSet.contains(0.0)) { - assert(catMap(0.0) === 0) // value 0 gets index 0 - } - // Check transformed data - assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet) - // Check metadata - val featureAttr = featureAttrs(feature) - assert(featureAttr.index.get === feature) - featureAttr match { - case attr: BinaryAttribute => - assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) - case attr: NominalAttribute => - assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) - assert(attr.isOrdinal.get === false) - case _ => - throw new RuntimeException(errMsg + s". Categorical feature $feature failed" + - s" metadata check. Found feature attribute: $featureAttr.") + testTransformerByGlobalCheckFunc[FeatureData](data, model, "indexed") { rows => + val transformed = rows.map { r => Tuple1(r.getAs[Vector](0)) }.toDF("indexed") + val indexedRDD: RDD[Vector] = transformed.rdd.map(_.getAs[Vector](0)) + val featureAttrs = AttributeGroup.fromStructField(rows.head.schema("indexed")) + assert(featureAttrs.name === "indexed") + assert(featureAttrs.attributes.get.length === model.numFeatures) + categoricalFeatures.foreach { feature: Int => + val origValueSet = collectedData.map(_(feature)).toSet + val targetValueIndexSet = Range(0, origValueSet.size).toSet + val catMap = categoryMaps(feature) + assert(catMap.keys.toSet === origValueSet) // Correct categories + assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices + if (origValueSet.contains(0.0)) { + assert(catMap(0.0) === 0) // value 0 gets index 0 + } + // Check transformed data + assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet) + // Check metadata + val featureAttr = featureAttrs(feature) + assert(featureAttr.index.get === feature) + featureAttr match { + case attr: BinaryAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + case attr: NominalAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + assert(attr.isOrdinal.get === false) + case _ => + throw new RuntimeException(errMsg + s". Categorical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } } - } - // Check numerical feature metadata. - Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature)) - .foreach { feature: Int => - val featureAttr = featureAttrs(feature) - featureAttr match { - case attr: NumericAttribute => - assert(featureAttr.index.get === feature) - case _ => - throw new RuntimeException(errMsg + s". Numerical feature $feature failed" + - s" metadata check. Found feature attribute: $featureAttr.") + // Check numerical feature metadata. + Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature)) + .foreach { feature: Int => + val featureAttr = featureAttrs(feature) + featureAttr match { + case attr: NumericAttribute => + assert(featureAttr.index.get === feature) + case _ => + throw new RuntimeException(errMsg + s". Numerical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } } } } catch { @@ -236,25 +245,32 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext (sparsePoints1, sparsePoints1TestInvalid))) { val vectorIndexer = getIndexer.setMaxCategories(4).setHandleInvalid("error") val model = vectorIndexer.fit(points) - intercept[SparkException] { - model.transform(pointsTestInvalid).collect() - } + testTransformerByInterceptingException[FeatureData]( + pointsTestInvalid, + model, + "VectorIndexer encountered invalid value", + "indexed") val vectorIndexer1 = getIndexer.setMaxCategories(4).setHandleInvalid("skip") val model1 = vectorIndexer1.fit(points) - val invalidTransformed1 = model1.transform(pointsTestInvalid).select("indexed") - .collect().map(_(0)) - val transformed1 = model1.transform(points).select("indexed").collect().map(_(0)) - assert(transformed1 === invalidTransformed1) - + val expected = Seq( + Vectors.dense(1.0, 2.0, 0.0), + Vectors.dense(0.0, 1.0, 2.0), + Vectors.dense(0.0, 0.0, 1.0), + Vectors.dense(1.0, 3.0, 2.0)) + testTransformerByGlobalCheckFunc[FeatureData](pointsTestInvalid, model1, "indexed") { rows => + assert(rows.map(_(0)) == expected) + } + testTransformerByGlobalCheckFunc[FeatureData](points, model1, "indexed") { rows => + assert(rows.map(_(0)) == expected) + } val vectorIndexer2 = getIndexer.setMaxCategories(4).setHandleInvalid("keep") val model2 = vectorIndexer2.fit(points) - val invalidTransformed2 = model2.transform(pointsTestInvalid).select("indexed") - .collect().map(_(0)) - assert(invalidTransformed2 === transformed1 ++ Array( - Vectors.dense(2.0, 2.0, 0.0), - Vectors.dense(0.0, 4.0, 2.0), - Vectors.dense(1.0, 3.0, 3.0)) - ) + testTransformerByGlobalCheckFunc[FeatureData](pointsTestInvalid, model2, "indexed") { rows => + assert(rows.map(_(0)) == expected ++ Array( + Vectors.dense(2.0, 2.0, 0.0), + Vectors dense(0.0, 4.0, 2.0), + Vectors.dense(1.0, 3.0, 3.0))) + } } } @@ -263,12 +279,12 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext val points = data.collect().map(_.getAs[Vector](0)) val vectorIndexer = getIndexer.setMaxCategories(maxCategories) val model = vectorIndexer.fit(data) - val indexedPoints = - model.transform(data).select("indexed").rdd.map(_.getAs[Vector](0)).collect() - points.zip(indexedPoints).foreach { - case (orig: SparseVector, indexed: SparseVector) => - assert(orig.indices.length == indexed.indices.length) - case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen + testTransformerByGlobalCheckFunc[FeatureData](data, model, "indexed") { rows => + points.zip(rows.map(_(0))).foreach { + case (orig: SparseVector, indexed: SparseVector) => + assert(orig.indices.length == indexed.indices.length) + case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen + } } } checkSparsity(sparsePoints1, maxCategories = 2) @@ -286,17 +302,18 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext val vectorIndexer = getIndexer.setMaxCategories(2) val model = vectorIndexer.fit(densePoints1WithMeta) // Check that ML metadata are preserved. - val indexedPoints = model.transform(densePoints1WithMeta) - val transAttributes: Array[Attribute] = - AttributeGroup.fromStructField(indexedPoints.schema("indexed")).attributes.get - featureAttributes.zip(transAttributes).foreach { case (orig, trans) => - assert(orig.name === trans.name) - (orig, trans) match { - case (orig: NumericAttribute, trans: NumericAttribute) => - assert(orig.max.nonEmpty && orig.max === trans.max) - case _ => + testTransformerByGlobalCheckFunc[FeatureData](densePoints1WithMeta, model, "indexed") { rows => + val transAttributes: Array[Attribute] = + AttributeGroup.fromStructField(rows.head.schema("indexed")).attributes.get + featureAttributes.zip(transAttributes).foreach { case (orig, trans) => + assert(orig.name === trans.name) + (orig, trans) match { + case (orig: NumericAttribute, trans: NumericAttribute) => + assert(orig.max.nonEmpty && orig.max === trans.max) + case _ => // do nothing // TODO: Once input features marked as categorical are handled correctly, check that here. + } } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala index f6c9a76599fae..d89d10b320d84 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.streaming.StreamTest class VectorSizeHintSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -40,16 +38,23 @@ class VectorSizeHintSuite val data = Seq((Vectors.dense(1, 2), 0)).toDF("vector", "intValue") val noSizeTransformer = new VectorSizeHint().setInputCol("vector") - intercept[NoSuchElementException] (noSizeTransformer.transform(data)) + testTransformerByInterceptingException[(Vector, Int)]( + data, + noSizeTransformer, + "Failed to find a default value for size", + "vector") intercept[NoSuchElementException] (noSizeTransformer.transformSchema(data.schema)) val noInputColTransformer = new VectorSizeHint().setSize(2) - intercept[NoSuchElementException] (noInputColTransformer.transform(data)) + testTransformerByInterceptingException[(Vector, Int)]( + data, + noInputColTransformer, + "Failed to find a default value for inputCol", + "vector") intercept[NoSuchElementException] (noInputColTransformer.transformSchema(data.schema)) } test("Adding size to column of vectors.") { - val size = 3 val vectorColName = "vector" val denseVector = Vectors.dense(1, 2, 3) @@ -66,12 +71,15 @@ class VectorSizeHintSuite .setInputCol(vectorColName) .setSize(size) .setHandleInvalid(handleInvalid) - val withSize = transformer.transform(dataFrame) - assert( - AttributeGroup.fromStructField(withSize.schema(vectorColName)).size == size, - "Transformer did not add expected size data.") - val numRows = withSize.collect().length - assert(numRows === data.length, s"Expecting ${data.length} rows, got $numRows.") + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataFrame, transformer, vectorColName) { + rows => { + assert( + AttributeGroup.fromStructField(rows.head.schema(vectorColName)).size == size, + "Transformer did not add expected size data.") + val numRows = rows.length + assert(numRows === data.length, s"Expecting ${data.length} rows, got $numRows.") + } + } } } @@ -93,14 +101,16 @@ class VectorSizeHintSuite .setInputCol(vectorColName) .setSize(size) .setHandleInvalid(handleInvalid) - val withSize = transformer.transform(dataFrameWithMetadata) - - val newGroup = AttributeGroup.fromStructField(withSize.schema(vectorColName)) - assert(newGroup.size === size, "Column has incorrect size metadata.") - assert( - newGroup.attributes.get === group.attributes.get, - "VectorSizeHint did not preserve attributes.") - withSize.collect + testTransformerByGlobalCheckFunc[(Int, Int, Int, Vector)]( + dataFrameWithMetadata, + transformer, + vectorColName) { rows => + val newGroup = AttributeGroup.fromStructField(rows.head.schema(vectorColName)) + assert(newGroup.size === size, "Column has incorrect size metadata.") + assert( + newGroup.attributes.get === group.attributes.get, + "VectorSizeHint did not preserve attributes.") + } } } @@ -120,7 +130,11 @@ class VectorSizeHintSuite .setInputCol(vectorColName) .setSize(size) .setHandleInvalid(handleInvalid) - intercept[IllegalArgumentException](transformer.transform(dataFrameWithMetadata)) + testTransformerByInterceptingException[(Int, Int, Int, Vector)]( + dataFrameWithMetadata, + transformer, + "Trying to set size of vectors in `vector` to 4 but size already set to 3.", + vectorColName) } } @@ -136,18 +150,36 @@ class VectorSizeHintSuite .setHandleInvalid("error") .setSize(3) - intercept[SparkException](sizeHint.transform(dataWithNull).collect()) - intercept[SparkException](sizeHint.transform(dataWithShort).collect()) + testTransformerByInterceptingException[Tuple1[Vector]]( + dataWithNull, + sizeHint, + "Got null vector in VectorSizeHint", + "vector") + + testTransformerByInterceptingException[Tuple1[Vector]]( + dataWithShort, + sizeHint, + "VectorSizeHint Expecting a vector of size 3 but got 1", + "vector") sizeHint.setHandleInvalid("skip") - assert(sizeHint.transform(dataWithNull).count() === 1) - assert(sizeHint.transform(dataWithShort).count() === 1) + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithNull, sizeHint, "vector") { rows => + assert(rows.length === 1) + } + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithShort, sizeHint, "vector") { rows => + assert(rows.length === 1) + } sizeHint.setHandleInvalid("optimistic") - assert(sizeHint.transform(dataWithNull).count() === 2) - assert(sizeHint.transform(dataWithShort).count() === 2) + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithNull, sizeHint, "vector") { rows => + assert(rows.length === 2) + } + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithShort, sizeHint, "vector") { rows => + assert(rows.length === 2) + } } + test("read/write") { val sizeHint = new VectorSizeHint() .setInputCol("myInputCol") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala index 1746ce53107c4..3d90f9d9ac764 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala @@ -17,16 +17,16 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.Row import org.apache.spark.sql.types.{StructField, StructType} -class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class VectorSlicerSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ test("params") { val slicer = new VectorSlicer().setInputCol("feature") @@ -84,12 +84,12 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result") - def validateResults(df: DataFrame): Unit = { - df.select("result", "expected").collect().foreach { case Row(vec1: Vector, vec2: Vector) => + def validateResults(rows: Seq[Row]): Unit = { + rows.foreach { case Row(vec1: Vector, vec2: Vector) => assert(vec1 === vec2) } - val resultMetadata = AttributeGroup.fromStructField(df.schema("result")) - val expectedMetadata = AttributeGroup.fromStructField(df.schema("expected")) + val resultMetadata = AttributeGroup.fromStructField(rows.head.schema("result")) + val expectedMetadata = AttributeGroup.fromStructField(rows.head.schema("expected")) assert(resultMetadata.numAttributes === expectedMetadata.numAttributes) resultMetadata.attributes.get.zip(expectedMetadata.attributes.get).foreach { case (a, b) => assert(a === b) @@ -97,13 +97,16 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De } vectorSlicer.setIndices(Array(1, 4)).setNames(Array.empty) - validateResults(vectorSlicer.transform(df)) + testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")( + validateResults) vectorSlicer.setIndices(Array(1)).setNames(Array("f4")) - validateResults(vectorSlicer.transform(df)) + testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")( + validateResults) vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4")) - validateResults(vectorSlicer.transform(df)) + testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")( + validateResults) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 10682ba176aca..b59c4e7967338 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -17,17 +17,17 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row import org.apache.spark.util.Utils -class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class Word2VecSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ test("params") { ParamsSuite.checkParams(new Word2Vec) @@ -36,10 +36,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } test("Word2Vec") { - - val spark = this.spark - import spark.implicits._ - val sentence = "a b " * 100 + "a c " * 10 val numOfWords = sentence.split(" ").size val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) @@ -70,17 +66,13 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul // These expectations are just magic values, characterizing the current // behavior. The test needs to be updated to be more general, see SPARK-11502 val magicExp = Vectors.dense(0.30153007534417237, -0.6833061711354689, 0.5116530778733167) - model.transform(docDF).select("result", "expected").collect().foreach { + testTransformer[(Seq[String], Vector)](docDF, model, "result", "expected") { case Row(vector1: Vector, vector2: Vector) => assert(vector1 ~== magicExp absTol 1E-5, "Transformed vector is different with expected.") } } test("getVectors") { - - val spark = this.spark - import spark.implicits._ - val sentence = "a b " * 100 + "a c " * 10 val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) @@ -119,9 +111,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("findSynonyms") { - val spark = this.spark - import spark.implicits._ - val sentence = "a b " * 100 + "a c " * 10 val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) val docDF = doc.zip(doc).toDF("text", "alsotext") @@ -154,9 +143,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("window size") { - val spark = this.spark - import spark.implicits._ - val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10 val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) val docDF = doc.zip(doc).toDF("text", "alsotext") @@ -227,8 +213,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } test("Word2Vec works with input that is non-nullable (NGram)") { - val spark = this.spark - import spark.implicits._ val sentence = "a q s t q s t b b b s t m s t m q " val docDF = sc.parallelize(Seq(sentence, sentence)).map(_.split(" ")).toDF("text") @@ -243,7 +227,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .fit(ngramDF) // Just test that this transformation succeeds - model.transform(ngramDF).collect() + testTransformerByGlobalCheckFunc[(Seq[String], Seq[String])](ngramDF, model, "result") { _ => } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala index 17678aa611a48..795fd0e2ac0e4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala @@ -22,9 +22,10 @@ import java.io.File import org.scalatest.Suite import org.apache.spark.SparkContext -import org.apache.spark.ml.{PipelineModel, Transformer} +import org.apache.spark.ml.Transformer import org.apache.spark.sql.{DataFrame, Encoder, Row} import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions.col import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.test.TestSparkSession import org.apache.spark.util.Utils @@ -62,8 +63,10 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => val columnNames = dataframe.schema.fieldNames val stream = MemoryStream[A] - val streamDF = stream.toDS().toDF(columnNames: _*) - + val columnsWithMetadata = dataframe.schema.map { structField => + col(structField.name).as(structField.name, structField.metadata) + } + val streamDF = stream.toDS().toDF(columnNames: _*).select(columnsWithMetadata: _*) val data = dataframe.as[A].collect() val streamOutput = transformer.transform(streamDF) @@ -108,5 +111,29 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => otherResultCols: _*)(globalCheckFunction) testTransformerOnDF(dataframe, transformer, firstResultCol, otherResultCols: _*)(globalCheckFunction) + } + + def testTransformerByInterceptingException[A : Encoder]( + dataframe: DataFrame, + transformer: Transformer, + expectedMessagePart : String, + firstResultCol: String) { + + def hasExpectedMessage(exception: Throwable): Boolean = + exception.getMessage.contains(expectedMessagePart) || + (exception.getCause != null && exception.getCause.getMessage.contains(expectedMessagePart)) + + withClue(s"""Expected message part "${expectedMessagePart}" is not found in DF test.""") { + val exceptionOnDf = intercept[Throwable] { + testTransformerOnDF(dataframe, transformer, firstResultCol)(_ => Unit) + } + assert(hasExpectedMessage(exceptionOnDf)) + } + withClue(s"""Expected message part "${expectedMessagePart}" is not found in stream test.""") { + val exceptionOnStreamData = intercept[Throwable] { + testTransformerOnStreamData(dataframe, transformer, firstResultCol)(_ => Unit) + } + assert(hasExpectedMessage(exceptionOnStreamData)) + } } } From 4f5bad615b47d743b8932aea1071652293981604 Mon Sep 17 00:00:00 2001 From: smallory Date: Thu, 15 Mar 2018 11:58:54 +0900 Subject: [PATCH 0471/2461] [SPARK-23642][DOCS] AccumulatorV2 subclass isZero scaladoc fix Added/corrected scaladoc for isZero on the DoubleAccumulator, CollectionAccumulator, and LongAccumulator subclasses of AccumulatorV2, particularly noting where there are requirements in addition to having a value of zero in order to return true. ## What changes were proposed in this pull request? Three scaladoc comments are updated in AccumulatorV2.scala No changes outside of comment blocks were made. ## How was this patch tested? Running "sbt unidoc", fixing style errors found, and reviewing the resulting local scaladoc in firefox. Author: smallory Closes #20790 from smallory/patch-1. --- .../main/scala/org/apache/spark/util/AccumulatorV2.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index f4a736d6d439a..0f84ea9752cf5 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -290,7 +290,8 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] { private var _count = 0L /** - * Adds v to the accumulator, i.e. increment sum by v and count by 1. + * Returns false if this accumulator has had any values added to it or the sum is non-zero. + * * @since 2.0.0 */ override def isZero: Boolean = _sum == 0L && _count == 0 @@ -368,6 +369,9 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { private var _sum = 0.0 private var _count = 0L + /** + * Returns false if this accumulator has had any values added to it or the sum is non-zero. + */ override def isZero: Boolean = _sum == 0.0 && _count == 0 override def copy(): DoubleAccumulator = { @@ -441,6 +445,9 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] { private val _list: java.util.List[T] = Collections.synchronizedList(new ArrayList[T]()) + /** + * Returns false if this accumulator instance has any values in it. + */ override def isZero: Boolean = _list.isEmpty override def copyAndReset(): CollectionAccumulator[T] = new CollectionAccumulator From 7c3e8995f18a1fb57c1f2c1b98a1d47590e28f38 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Thu, 15 Mar 2018 00:04:28 -0700 Subject: [PATCH 0472/2461] [SPARK-23533][SS] Add support for changing ContinuousDataReader's startOffset ## What changes were proposed in this pull request? As discussion in #20675, we need add a new interface `ContinuousDataReaderFactory` to support the requirements of setting start offset in Continuous Processing. ## How was this patch tested? Existing UT. Author: Yuanjian Li Closes #20689 from xuanyuanking/SPARK-23533. --- .../sql/kafka010/KafkaContinuousReader.scala | 11 +++++- .../reader/ContinuousDataReaderFactory.java | 35 +++++++++++++++++++ .../ContinuousRateStreamSource.scala | 15 +++++++- 3 files changed, 59 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index ecd1170321f3f..6e56b0a72d671 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -164,7 +164,16 @@ case class KafkaContinuousDataReaderFactory( startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends DataReaderFactory[UnsafeRow] { + failOnDataLoss: Boolean) extends ContinuousDataReaderFactory[UnsafeRow] { + + override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[UnsafeRow] = { + val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset] + require(kafkaOffset.topicPartition == topicPartition, + s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}") + new KafkaContinuousDataReader( + topicPartition, kafkaOffset.partitionOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) + } + override def createDataReader(): KafkaContinuousDataReader = { new KafkaContinuousDataReader( topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java new file mode 100644 index 0000000000000..a61697649c43e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset; + +/** + * A mix-in interface for {@link DataReaderFactory}. Continuous data reader factories can + * implement this interface to provide creating {@link DataReader} with particular offset. + */ +@InterfaceStability.Evolving +public interface ContinuousDataReaderFactory extends DataReaderFactory { + /** + * Create a DataReader with particular offset as its startOffset. + * + * @param offset offset want to set as the DataReader's startOffset. + */ + DataReader createDataReaderWithOffset(PartitionOffset offset); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index b63d8d3e20650..20d90069163a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -106,7 +106,20 @@ case class RateStreamContinuousDataReaderFactory( partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends DataReaderFactory[Row] { + extends ContinuousDataReaderFactory[Row] { + + override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[Row] = { + val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset] + require(rateStreamOffset.partition == partitionIndex, + s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}") + new RateStreamContinuousDataReader( + rateStreamOffset.currentValue, + rateStreamOffset.currentTimeMs, + partitionIndex, + increment, + rowsPerSecond) + } + override def createDataReader(): DataReader[Row] = new RateStreamContinuousDataReader( startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) From 56e8f48a43eb51e8582db2461a585b13a771a00a Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 15 Mar 2018 10:55:33 -0700 Subject: [PATCH 0473/2461] [SPARK-23695][PYTHON] Fix the error message for Kinesis streaming tests ## What changes were proposed in this pull request? This PR proposes to fix the error message for Kinesis in PySpark when its jar is missing but explicitly enabled. ```bash ENABLE_KINESIS_TESTS=1 SPARK_TESTING=1 bin/pyspark pyspark.streaming.tests ``` Before: ``` Skipped test_flume_stream (enable by setting environment variable ENABLE_FLUME_TESTS=1Skipped test_kafka_stream (enable by setting environment variable ENABLE_KAFKA_0_8_TESTS=1Traceback (most recent call last): File "/usr/local/Cellar/python/2.7.14_3/Frameworks/Python.framework/Versions/2.7/lib/python2.7/runpy.py", line 174, in _run_module_as_main "__main__", fname, loader, pkg_name) File "/usr/local/Cellar/python/2.7.14_3/Frameworks/Python.framework/Versions/2.7/lib/python2.7/runpy.py", line 72, in _run_code exec code in run_globals File "/.../spark/python/pyspark/streaming/tests.py", line 1572, in % kinesis_asl_assembly_dir) + NameError: name 'kinesis_asl_assembly_dir' is not defined ``` After: ``` Skipped test_flume_stream (enable by setting environment variable ENABLE_FLUME_TESTS=1Skipped test_kafka_stream (enable by setting environment variable ENABLE_KAFKA_0_8_TESTS=1Traceback (most recent call last): File "/usr/local/Cellar/python/2.7.14_3/Frameworks/Python.framework/Versions/2.7/lib/python2.7/runpy.py", line 174, in _run_module_as_main "__main__", fname, loader, pkg_name) File "/usr/local/Cellar/python/2.7.14_3/Frameworks/Python.framework/Versions/2.7/lib/python2.7/runpy.py", line 72, in _run_code exec code in run_globals File "/.../spark/python/pyspark/streaming/tests.py", line 1576, in "You need to build Spark with 'build/sbt -Pkinesis-asl " Exception: Failed to find Spark Streaming Kinesis assembly jar in /.../spark/external/kinesis-asl-assembly. You need to build Spark with 'build/sbt -Pkinesis-asl assembly/package streaming-kinesis-asl-assembly/assembly'or 'build/mvn -Pkinesis-asl package' before running this test. ``` ## How was this patch tested? Manually tested. Author: hyukjinkwon Closes #20834 from HyukjinKwon/minor-variable. --- python/pyspark/streaming/tests.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 71f8101e34c50..7dde7c0928c08 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1503,10 +1503,13 @@ def search_flume_assembly_jar(): return jars[0] -def search_kinesis_asl_assembly_jar(): +def _kinesis_asl_assembly_dir(): SPARK_HOME = os.environ["SPARK_HOME"] - kinesis_asl_assembly_dir = os.path.join(SPARK_HOME, "external/kinesis-asl-assembly") - jars = search_jar(kinesis_asl_assembly_dir, "spark-streaming-kinesis-asl-assembly") + return os.path.join(SPARK_HOME, "external/kinesis-asl-assembly") + + +def search_kinesis_asl_assembly_jar(): + jars = search_jar(_kinesis_asl_assembly_dir(), "spark-streaming-kinesis-asl-assembly") if not jars: return None elif len(jars) > 1: @@ -1569,7 +1572,7 @@ def search_kinesis_asl_assembly_jar(): else: raise Exception( ("Failed to find Spark Streaming Kinesis assembly jar in %s. " - % kinesis_asl_assembly_dir) + + % _kinesis_asl_assembly_dir()) + "You need to build Spark with 'build/sbt -Pkinesis-asl " "assembly/package streaming-kinesis-asl-assembly/assembly'" "or 'build/mvn -Pkinesis-asl package' before running this test.") From 15c3c983008557165cc91713ddaf2dbd6d5a506c Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 15 Mar 2018 19:54:58 +0100 Subject: [PATCH 0474/2461] [HOT-FIX] Fix SparkOutOfMemoryError: Unable to acquire 262144 bytes of memory, got 224631 ## What changes were proposed in this pull request? https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/88263/testReport https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/88260/testReport https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/88257/testReport https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/88224/testReport These tests all failed: ``` org.apache.spark.memory.SparkOutOfMemoryError: Unable to acquire 262144 bytes of memory, got 224631 at org.apache.spark.memory.MemoryConsumer.throwOom(MemoryConsumer.java:157) at org.apache.spark.memory.MemoryConsumer.allocateArray(MemoryConsumer.java:98) at org.apache.spark.unsafe.map.BytesToBytesMap.allocate(BytesToBytesMap.java:787) at org.apache.spark.unsafe.map.BytesToBytesMap.(BytesToBytesMap.java:204) at org.apache.spark.unsafe.map.BytesToBytesMap.(BytesToBytesMap.java:219) ... ``` This PR ignore this test. ## How was this patch tested? N/A Author: Yuming Wang Closes #20835 from wangyum/SPARK-23598. --- .../org/apache/spark/sql/execution/WholeStageCodegenSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 4b40e4ef7571c..9180a22c260f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -310,7 +310,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } } - test("SPARK-23598: Codegen working for lots of aggregation operations without runtime errors") { + ignore("SPARK-23598: Codegen working for lots of aggregation operations without runtime errors") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { var df = Seq((8, "bat"), (15, "mouse"), (5, "horse")).toDF("age", "name") for (i <- 0 until 70) { From 7618896e855579f111dd92cd76794a5672a087e5 Mon Sep 17 00:00:00 2001 From: Sahil Takiar Date: Thu, 15 Mar 2018 17:04:39 -0700 Subject: [PATCH 0475/2461] [SPARK-23658][LAUNCHER] InProcessAppHandle uses the wrong class in getLogger ## What changes were proposed in this pull request? Changed `Logger` in `InProcessAppHandle` to use `InProcessAppHandle` instead of `ChildProcAppHandle` Author: Sahil Takiar Closes #20815 from sahilTakiar/master. --- .../main/java/org/apache/spark/launcher/InProcessAppHandle.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index 4b740d3fad20e..15fbca0facef2 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -25,7 +25,7 @@ class InProcessAppHandle extends AbstractAppHandle { private static final String THREAD_NAME_FMT = "spark-app-%d: '%s'"; - private static final Logger LOG = Logger.getLogger(ChildProcAppHandle.class.getName()); + private static final Logger LOG = Logger.getLogger(InProcessAppHandle.class.getName()); private static final AtomicLong THREAD_IDS = new AtomicLong(); // Avoid really long thread names. From 18f8575e0166c6997569358d45bdae2cf45bf624 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 15 Mar 2018 17:12:01 -0700 Subject: [PATCH 0476/2461] [SPARK-23671][CORE] Fix condition to enable the SHS thread pool. Author: Marcelo Vanzin Closes #20814 from vanzin/SPARK-23671. --- .../org/apache/spark/deploy/history/FsHistoryProvider.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index f9d0b5ee4e23e..ace6d9e00c838 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -173,7 +173,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * Fixed size thread pool to fetch and parse log files. */ private val replayExecutor: ExecutorService = { - if (Utils.isTesting) { + if (!Utils.isTesting) { ThreadUtils.newDaemonFixedThreadPool(NUM_PROCESSING_THREADS, "log-replay-executor") } else { MoreExecutors.sameThreadExecutor() From 3675af7247e841e9a689666dc20891ba55c612b3 Mon Sep 17 00:00:00 2001 From: Ye Zhou Date: Thu, 15 Mar 2018 17:15:53 -0700 Subject: [PATCH 0477/2461] [SPARK-23608][CORE][WEBUI] Add synchronization in SHS between attachSparkUI and detachSparkUI functions to avoid concurrent modification issue to Jetty Handlers Jetty handlers are dynamically attached/detached while SHS is running. But the attach and detach operations might be taking place at the same time due to the async in load/clear in Guava Cache. ## What changes were proposed in this pull request? Add synchronization between attachSparkUI and detachSparkUI in SHS. ## How was this patch tested? With this patch, the jetty handlers missing issue never happens again in our production cluster SHS. Author: Ye Zhou Closes #20744 from zhouyejoe/SPARK-23608. --- .../apache/spark/deploy/history/HistoryServer.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 0ec4afad0308c..611fa563a7cd9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -150,14 +150,18 @@ class HistoryServer( ui: SparkUI, completed: Boolean) { assert(serverInfo.isDefined, "HistoryServer must be bound before attaching SparkUIs") - ui.getHandlers.foreach(attachHandler) - addFilters(ui.getHandlers, conf) + handlers.synchronized { + ui.getHandlers.foreach(attachHandler) + addFilters(ui.getHandlers, conf) + } } /** Detach a reconstructed UI from this server. Only valid after bind(). */ override def detachSparkUI(appId: String, attemptId: Option[String], ui: SparkUI): Unit = { assert(serverInfo.isDefined, "HistoryServer must be bound before detaching SparkUIs") - ui.getHandlers.foreach(detachHandler) + handlers.synchronized { + ui.getHandlers.foreach(detachHandler) + } provider.onUIDetached(appId, attemptId, ui) } From c2632edebd978716dbfa7874a2fc0a8f5a4a9951 Mon Sep 17 00:00:00 2001 From: myroslavlisniak Date: Thu, 15 Mar 2018 17:20:17 -0700 Subject: [PATCH 0478/2461] [SPARK-23670][SQL] Fix memory leak on SparkPlanGraphWrapper Clean up SparkPlanGraphWrapper objects from InMemoryStore together with cleaning up SQLExecutionUIData existing unit test was extended to check also SparkPlanGraphWrapper object count vanzin Author: myroslavlisniak Closes #20813 from myroslavlisniak/master. --- .../apache/spark/sql/execution/ui/SQLAppStatusListener.scala | 5 ++++- .../apache/spark/sql/execution/ui/SQLAppStatusStore.scala | 4 ++++ .../spark/sql/execution/ui/SQLAppStatusListenerSuite.scala | 1 + 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 53fb9a0cc21cf..71e9f93c4566e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -334,7 +334,10 @@ class SQLAppStatusListener( val view = kvstore.view(classOf[SQLExecutionUIData]).index("completionTime").first(0L) val toDelete = KVUtils.viewToSeq(view, countToDelete.toInt)(_.completionTime.isDefined) - toDelete.foreach { e => kvstore.delete(e.getClass(), e.executionId) } + toDelete.foreach { e => + kvstore.delete(e.getClass(), e.executionId) + kvstore.delete(classOf[SparkPlanGraphWrapper], e.executionId) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala index 9a76584717f42..241001a857c8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala @@ -54,6 +54,10 @@ class SQLAppStatusStore( store.count(classOf[SQLExecutionUIData]) } + def planGraphCount(): Long = { + store.count(classOf[SparkPlanGraphWrapper]) + } + def executionMetrics(executionId: Long): Map[Long, String] = { def metricsFromStore(): Option[Map[Long, String]] = { val exec = store.read(classOf[SQLExecutionUIData], executionId) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index 85face3994fd4..f3f08839c1d3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -611,6 +611,7 @@ class SQLAppStatusListenerMemoryLeakSuite extends SparkFunSuite { sc.listenerBus.waitUntilEmpty(10000) val statusStore = spark.sharedState.statusStore assert(statusStore.executionsCount() <= 50) + assert(statusStore.planGraphCount() <= 50) // No live data should be left behind after all executions end. assert(statusStore.listener.get.noLiveData()) } From ca83526de55f0f8784df58cc8b7c0a7cb0c96e23 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 16 Mar 2018 15:12:26 +0800 Subject: [PATCH 0479/2461] [SPARK-23644][CORE][UI] Use absolute path for REST call in SHS ## What changes were proposed in this pull request? SHS is using a relative path for the REST API call to get the list of the application is a relative path call. In case of the SHS being consumed through a proxy, it can be an issue if the path doesn't end with a "/". Therefore, we should use an absolute path for the REST call as it is done for all the other resources. ## How was this patch tested? manual tests Before the change: ![screen shot 2018-03-10 at 4 22 02 pm](https://user-images.githubusercontent.com/8821783/37244190-8ccf9d40-2485-11e8-8fa9-345bc81472fc.png) After the change: ![screen shot 2018-03-10 at 4 36 34 pm 1](https://user-images.githubusercontent.com/8821783/37244201-a1922810-2485-11e8-8856-eeab2bf5e180.png) Author: Marco Gaido Closes #20794 from mgaido91/SPARK-23644. --- .../main/resources/org/apache/spark/ui/static/historypage.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index f0b2a5a833a99..abc2ec0fa6531 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -113,7 +113,7 @@ $(document).ready(function() { status: (requestedIncomplete ? "running" : "completed") }; - $.getJSON("api/v1/applications", appParams, function(response,status,jqXHR) { + $.getJSON(uiRoot + "/api/v1/applications", appParams, function(response,status,jqXHR) { var array = []; var hasMultipleAttempts = false; for (i in response) { @@ -151,7 +151,7 @@ $(document).ready(function() { "showCompletedColumns": !requestedIncomplete, } - $.get("static/historypage-template.html", function(template) { + $.get(uiRoot + "/static/historypage-template.html", function(template) { var sibling = historySummary.prev(); historySummary.detach(); var apps = $(Mustache.render($(template).filter("#history-summary-template").html(),data)); From c952000487ee003200221b3c4e25dcb06e359f0a Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 16 Mar 2018 16:22:03 +0800 Subject: [PATCH 0480/2461] [SPARK-23635][YARN] AM env variable should not overwrite same name env variable set through spark.executorEnv. ## What changes were proposed in this pull request? In the current Spark on YARN code, AM always will copy and overwrite its env variables to executors, so we cannot set different values for executors. To reproduce issue, user could start spark-shell like: ``` ./bin/spark-shell --master yarn-client --conf spark.executorEnv.SPARK_ABC=executor_val --conf spark.yarn.appMasterEnv.SPARK_ABC=am_val ``` Then check executor env variables by ``` sc.parallelize(1 to 1).flatMap \{ i => sys.env.toSeq }.collect.foreach(println) ``` We will always get `am_val` instead of `executor_val`. So we should not let AM to overwrite specifically set executor env variables. ## How was this patch tested? Added UT and tested in local cluster. Author: jerryshao Closes #20799 from jerryshao/SPARK-23635. --- .../spark/deploy/yarn/ExecutorRunnable.scala | 22 +++++++----- .../spark/deploy/yarn/YarnClusterSuite.scala | 36 +++++++++++++++++++ 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 3f4d236571ffd..ab08698035c98 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -220,12 +220,6 @@ private[yarn] class ExecutorRunnable( val env = new HashMap[String, String]() Client.populateClasspath(null, conf, sparkConf, env, sparkConf.get(EXECUTOR_CLASS_PATH)) - sparkConf.getExecutorEnv.foreach { case (key, value) => - // This assumes each executor environment variable set here is a path - // This is kept for backward compatibility and consistency with hadoop - YarnSparkHadoopUtil.addPathToEnvironment(env, key, value) - } - // lookup appropriate http scheme for container log urls val yarnHttpPolicy = conf.get( YarnConfiguration.YARN_HTTP_POLICY_KEY, @@ -233,6 +227,20 @@ private[yarn] class ExecutorRunnable( ) val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" + System.getenv().asScala.filterKeys(_.startsWith("SPARK")) + .foreach { case (k, v) => env(k) = v } + + sparkConf.getExecutorEnv.foreach { case (key, value) => + if (key == Environment.CLASSPATH.name()) { + // If the key of env variable is CLASSPATH, we assume it is a path and append it. + // This is kept for backward compatibility and consistency with hadoop + YarnSparkHadoopUtil.addPathToEnvironment(env, key, value) + } else { + // For other env variables, simply overwrite the value. + env(key) = value + } + } + // Add log urls container.foreach { c => sys.env.get("SPARK_USER").foreach { user => @@ -245,8 +253,6 @@ private[yarn] class ExecutorRunnable( } } - System.getenv().asScala.filterKeys(_.startsWith("SPARK")) - .foreach { case (k, v) => env(k) = v } env } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 33d400a5b1b2e..a129be7c06b53 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -225,6 +225,14 @@ class YarnClusterSuite extends BaseYarnClusterSuite { finalState should be (SparkAppHandle.State.FAILED) } + test("executor env overwrite AM env in client mode") { + testExecutorEnv(true) + } + + test("executor env overwrite AM env in cluster mode") { + testExecutorEnv(false) + } + private def testBasicYarnApp(clientMode: Boolean, conf: Map[String, String] = Map()): Unit = { val result = File.createTempFile("result", null, tempDir) val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), @@ -305,6 +313,17 @@ class YarnClusterSuite extends BaseYarnClusterSuite { checkResult(finalState, executorResult, "OVERRIDDEN") } + private def testExecutorEnv(clientMode: Boolean): Unit = { + val result = File.createTempFile("result", null, tempDir) + val finalState = runSpark(clientMode, mainClassName(ExecutorEnvTestApp.getClass), + appArgs = Seq(result.getAbsolutePath), + extraConf = Map( + "spark.yarn.appMasterEnv.TEST_ENV" -> "am_val", + "spark.executorEnv.TEST_ENV" -> "executor_val" + ) + ) + checkResult(finalState, result, "true") + } } private[spark] class SaveExecutorInfo extends SparkListener { @@ -526,3 +545,20 @@ private object SparkContextTimeoutApp { } } + +private object ExecutorEnvTestApp { + + def main(args: Array[String]): Unit = { + val status = args(0) + val sparkConf = new SparkConf() + val sc = new SparkContext(sparkConf) + val executorEnvs = sc.parallelize(Seq(1)).flatMap { _ => sys.env }.collect().toMap + val result = sparkConf.getExecutorEnv.forall { case (k, v) => + executorEnvs.get(k).contains(v) + } + + Files.write(result.toString, new File(status), StandardCharsets.UTF_8) + sc.stop() + } + +} From 5414abca4fec6a68174c34d22d071c20027e959d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 16 Mar 2018 09:36:30 -0700 Subject: [PATCH 0481/2461] [SPARK-23553][TESTS] Tests should not assume the default value of `spark.sql.sources.default` ## What changes were proposed in this pull request? Currently, some tests have an assumption that `spark.sql.sources.default=parquet`. In fact, that is a correct assumption, but that assumption makes it difficult to test new data source format. This PR aims to - Improve test suites more robust and makes it easy to test new data sources in the future. - Test new native ORC data source with the full existing Apache Spark test coverage. As an example, the PR uses `spark.sql.sources.default=orc` during reviews. The value should be `parquet` when this PR is accepted. ## How was this patch tested? Pass the Jenkins with updated tests. Author: Dongjoon Hyun Closes #20705 from dongjoon-hyun/SPARK-23553. --- python/pyspark/sql/readwriter.py | 4 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 9 +-- .../columnar/InMemoryColumnarQuerySuite.scala | 5 +- .../sql/execution/command/DDLSuite.scala | 11 ++- .../ParquetPartitionDiscoverySuite.scala | 10 +++ .../sql/test/DataFrameReaderWriterSuite.scala | 3 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 72 +++++++++---------- .../PartitionProviderCompatibilitySuite.scala | 6 +- .../hive/PartitionedTablePerfStatsSuite.scala | 2 +- .../sql/hive/execution/HiveDDLSuite.scala | 11 +-- .../sql/hive/execution/SQLQuerySuite.scala | 19 ++--- 11 files changed, 81 insertions(+), 71 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 803f561ece67b..facc16bc53108 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -147,8 +147,8 @@ def load(self, path=None, format=None, schema=None, **options): or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). :param options: all other string options - >>> df = spark.read.load('python/test_support/sql/parquet_partitioned', opt1=True, - ... opt2=1, opt3='str') + >>> df = spark.read.format("parquet").load('python/test_support/sql/parquet_partitioned', + ... opt1=True, opt2=1, opt3='str') >>> df.dtypes [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 8f14575c3325f..640affc10ee58 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2150,7 +2150,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("data source table created in InMemoryCatalog should be able to read/write") { withTable("tbl") { - sql("CREATE TABLE tbl(i INT, j STRING) USING parquet") + val provider = spark.sessionState.conf.defaultDataSourceName + sql(s"CREATE TABLE tbl(i INT, j STRING) USING $provider") checkAnswer(sql("SELECT i, j FROM tbl"), Nil) Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto("tbl") @@ -2474,9 +2475,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-16975: Column-partition path starting '_' should be handled correctly") { withTempDir { dir => - val parquetDir = new File(dir, "parquet").getCanonicalPath - spark.range(10).withColumn("_col", $"id").write.partitionBy("_col").save(parquetDir) - spark.read.parquet(parquetDir) + val dataDir = new File(dir, "data").getCanonicalPath + spark.range(10).withColumn("_col", $"id").write.partitionBy("_col").save(dataDir) + spark.read.load(dataDir) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index dc1766fb9a785..26b63e8e8490f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -487,7 +487,10 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-22673: InMemoryRelation should utilize existing stats of the plan to be cached") { - withSQLConf("spark.sql.cbo.enabled" -> "true") { + // This test case depends on the size of parquet in statistics. + withSQLConf( + SQLConf.CBO_ENABLED.key -> "true", + SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "parquet") { withTempPath { workDir => withTable("table1") { val workDirPath = workDir.getAbsolutePath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 4041176262426..4df8fbfe1c0db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -154,10 +154,15 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo Seq(4 -> "d").toDF("i", "j").write.saveAsTable("t1") val e = intercept[AnalysisException] { - Seq(5 -> "e").toDF("i", "j").write.mode("append").format("json").saveAsTable("t1") + val format = if (spark.sessionState.conf.defaultDataSourceName.equalsIgnoreCase("json")) { + "orc" + } else { + "json" + } + Seq(5 -> "e").toDF("i", "j").write.mode("append").format(format).saveAsTable("t1") } - assert(e.message.contains("The format of the existing table default.t1 is " + - "`ParquetFileFormat`. It doesn't match the specified format `JsonFileFormat`.")) + assert(e.message.contains("The format of the existing table default.t1 is ")) + assert(e.message.contains("It doesn't match the specified format")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index edb3da904d10d..e887c9734a8b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -57,6 +57,16 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val timeZone = TimeZone.getDefault() val timeZoneId = timeZone.getID + protected override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.DEFAULT_DATA_SOURCE_NAME.key, "parquet") + } + + protected override def afterAll(): Unit = { + spark.conf.unset(SQLConf.DEFAULT_DATA_SOURCE_NAME.key) + super.afterAll() + } + test("column type inference") { def check(raw: String, literal: Literal, timeZone: TimeZone = timeZone): Unit = { assert(inferPartitionColumnValue(raw, true, timeZone) === literal) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index a707a88dfa670..14b1feb2adc20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -562,7 +562,8 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be "and a same-name temp view exist") { withTable("same_name") { withTempView("same_name") { - sql("CREATE TABLE same_name(id LONG) USING parquet") + val format = spark.sessionState.conf.defaultDataSourceName + sql(s"CREATE TABLE same_name(id LONG) USING $format") spark.range(10).createTempView("same_name") spark.range(20).write.mode(SaveMode.Append).saveAsTable("same_name") checkAnswer(spark.table("same_name"), spark.range(10).toDF()) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 859099a321bf7..d93215fefb810 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -591,7 +591,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } test("Pre insert nullability check (ArrayType)") { - withTable("arrayInParquet") { + withTable("array") { { val df = (Tuple1(Seq(Int.box(1), null: Integer)) :: Nil).toDF("a") val expectedSchema = @@ -604,9 +604,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(df.schema === expectedSchema) df.write - .format("parquet") .mode(SaveMode.Overwrite) - .saveAsTable("arrayInParquet") + .saveAsTable("array") } { @@ -621,25 +620,24 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(df.schema === expectedSchema) df.write - .format("parquet") .mode(SaveMode.Append) - .insertInto("arrayInParquet") + .insertInto("array") } (Tuple1(Seq(4, 5)) :: Nil).toDF("a") .write .mode(SaveMode.Append) - .saveAsTable("arrayInParquet") // This one internally calls df2.insertInto. + .saveAsTable("array") // This one internally calls df2.insertInto. (Tuple1(Seq(Int.box(6), null: Integer)) :: Nil).toDF("a") .write .mode(SaveMode.Append) - .saveAsTable("arrayInParquet") + .saveAsTable("array") - sparkSession.catalog.refreshTable("arrayInParquet") + sparkSession.catalog.refreshTable("array") checkAnswer( - sql("SELECT a FROM arrayInParquet"), + sql("SELECT a FROM array"), Row(ArrayBuffer(1, null)) :: Row(ArrayBuffer(2, 3)) :: Row(ArrayBuffer(4, 5)) :: @@ -648,7 +646,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } test("Pre insert nullability check (MapType)") { - withTable("mapInParquet") { + withTable("map") { { val df = (Tuple1(Map(1 -> (null: Integer))) :: Nil).toDF("a") val expectedSchema = @@ -661,9 +659,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(df.schema === expectedSchema) df.write - .format("parquet") .mode(SaveMode.Overwrite) - .saveAsTable("mapInParquet") + .saveAsTable("map") } { @@ -678,27 +675,24 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(df.schema === expectedSchema) df.write - .format("parquet") .mode(SaveMode.Append) - .insertInto("mapInParquet") + .insertInto("map") } (Tuple1(Map(4 -> 5)) :: Nil).toDF("a") .write - .format("parquet") .mode(SaveMode.Append) - .saveAsTable("mapInParquet") // This one internally calls df2.insertInto. + .saveAsTable("map") // This one internally calls df2.insertInto. (Tuple1(Map(6 -> null.asInstanceOf[Integer])) :: Nil).toDF("a") .write - .format("parquet") .mode(SaveMode.Append) - .saveAsTable("mapInParquet") + .saveAsTable("map") - sparkSession.catalog.refreshTable("mapInParquet") + sparkSession.catalog.refreshTable("map") checkAnswer( - sql("SELECT a FROM mapInParquet"), + sql("SELECT a FROM map"), Row(Map(1 -> null)) :: Row(Map(2 -> 3)) :: Row(Map(4 -> 5)) :: @@ -852,52 +846,52 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv (from to to).map(i => i -> s"str$i").toDF("c1", "c2") } - withTable("insertParquet") { - createDF(0, 9).write.format("parquet").saveAsTable("insertParquet") + withTable("t") { + createDF(0, 9).write.saveAsTable("t") checkAnswer( - sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + sql("SELECT p.c1, p.c2 FROM t p WHERE p.c1 > 5"), (6 to 9).map(i => Row(i, s"str$i"))) intercept[AnalysisException] { - createDF(10, 19).write.format("parquet").saveAsTable("insertParquet") + createDF(10, 19).write.saveAsTable("t") } - createDF(10, 19).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") + createDF(10, 19).write.mode(SaveMode.Append).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + sql("SELECT p.c1, p.c2 FROM t p WHERE p.c1 > 5"), (6 to 19).map(i => Row(i, s"str$i"))) - createDF(20, 29).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") + createDF(20, 29).write.mode(SaveMode.Append).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 25"), + sql("SELECT p.c1, c2 FROM t p WHERE p.c1 > 5 AND p.c1 < 25"), (6 to 24).map(i => Row(i, s"str$i"))) intercept[AnalysisException] { - createDF(30, 39).write.saveAsTable("insertParquet") + createDF(30, 39).write.saveAsTable("t") } - createDF(30, 39).write.mode(SaveMode.Append).saveAsTable("insertParquet") + createDF(30, 39).write.mode(SaveMode.Append).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 35"), + sql("SELECT p.c1, c2 FROM t p WHERE p.c1 > 5 AND p.c1 < 35"), (6 to 34).map(i => Row(i, s"str$i"))) - createDF(40, 49).write.mode(SaveMode.Append).insertInto("insertParquet") + createDF(40, 49).write.mode(SaveMode.Append).insertInto("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 45"), + sql("SELECT p.c1, c2 FROM t p WHERE p.c1 > 5 AND p.c1 < 45"), (6 to 44).map(i => Row(i, s"str$i"))) - createDF(50, 59).write.mode(SaveMode.Overwrite).saveAsTable("insertParquet") + createDF(50, 59).write.mode(SaveMode.Overwrite).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 51 AND p.c1 < 55"), + sql("SELECT p.c1, c2 FROM t p WHERE p.c1 > 51 AND p.c1 < 55"), (52 to 54).map(i => Row(i, s"str$i"))) - createDF(60, 69).write.mode(SaveMode.Ignore).saveAsTable("insertParquet") + createDF(60, 69).write.mode(SaveMode.Ignore).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p"), + sql("SELECT p.c1, c2 FROM t p"), (50 to 59).map(i => Row(i, s"str$i"))) - createDF(70, 79).write.mode(SaveMode.Overwrite).insertInto("insertParquet") + createDF(70, 79).write.mode(SaveMode.Overwrite).insertInto("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p"), + sql("SELECT p.c1, c2 FROM t p"), (70 to 79).map(i => Row(i, s"str$i"))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala index 9440a17677ebf..80afc9d8f44bc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala @@ -37,11 +37,11 @@ class PartitionProviderCompatibilitySuite spark.range(5).selectExpr("id as fieldOne", "id as partCol").write .partitionBy("partCol") .mode("overwrite") - .parquet(dir.getAbsolutePath) + .save(dir.getAbsolutePath) spark.sql(s""" |create table $tableName (fieldOne long, partCol int) - |using parquet + |using ${spark.sessionState.conf.defaultDataSourceName} |options (path "${dir.toURI}") |partitioned by (partCol)""".stripMargin) } @@ -358,7 +358,7 @@ class PartitionProviderCompatibilitySuite try { spark.sql(s""" |create table test (id long, P1 int, P2 int) - |using parquet + |using ${spark.sessionState.conf.defaultDataSourceName} |options (path "${base.toURI}") |partitioned by (P1, P2)""".stripMargin) spark.sql(s"alter table test add partition (P1=0, P2=0) location '${a.toURI}'") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala index 54d3962a46b4d..1a86c604d5da3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala @@ -417,7 +417,7 @@ class PartitionedTablePerfStatsSuite import spark.implicits._ Seq(1).toDF("a").write.mode("overwrite").save(dir.getAbsolutePath) HiveCatalogMetrics.reset() - spark.read.parquet(dir.getAbsolutePath) + spark.read.load(dir.getAbsolutePath) assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 1) assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 1) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 65be244418670..db76ec9d084cb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1658,8 +1658,8 @@ class HiveDDLSuite Seq(5 -> "e").toDF("i", "j") .write.format("hive").mode("append").saveAsTable("t1") } - assert(e.message.contains("The format of the existing table default.t1 is " + - "`ParquetFileFormat`. It doesn't match the specified format `HiveFileFormat`.")) + assert(e.message.contains("The format of the existing table default.t1 is ")) + assert(e.message.contains("It doesn't match the specified format `HiveFileFormat`.")) } } @@ -1709,11 +1709,12 @@ class HiveDDLSuite spark.sessionState.catalog.getTableMetadata(TableIdentifier(tblName)).schema.map(_.name) } + val provider = spark.sessionState.conf.defaultDataSourceName withTable("t", "t1", "t2", "t3", "t4", "t5", "t6") { - sql("CREATE TABLE t(a int, b int, c int, d int) USING parquet PARTITIONED BY (d, b)") + sql(s"CREATE TABLE t(a int, b int, c int, d int) USING $provider PARTITIONED BY (d, b)") assert(getTableColumns("t") == Seq("a", "c", "d", "b")) - sql("CREATE TABLE t1 USING parquet PARTITIONED BY (d, b) AS SELECT 1 a, 1 b, 1 c, 1 d") + sql(s"CREATE TABLE t1 USING $provider PARTITIONED BY (d, b) AS SELECT 1 a, 1 b, 1 c, 1 d") assert(getTableColumns("t1") == Seq("a", "c", "d", "b")) Seq((1, 1, 1, 1)).toDF("a", "b", "c", "d").write.partitionBy("d", "b").saveAsTable("t2") @@ -1723,7 +1724,7 @@ class HiveDDLSuite val dataPath = new File(new File(path, "d=1"), "b=1").getCanonicalPath Seq(1 -> 1).toDF("a", "c").write.save(dataPath) - sql(s"CREATE TABLE t3 USING parquet LOCATION '${path.toURI}'") + sql(s"CREATE TABLE t3 USING $provider LOCATION '${path.toURI}'") assert(getTableColumns("t3") == Seq("a", "c", "d", "b")) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index baabc4a3bca2c..73f83d593bbfb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -516,24 +516,19 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("CTAS with default fileformat") { val table = "ctas1" val ctas = s"CREATE TABLE IF NOT EXISTS $table SELECT key k, value FROM src" - withSQLConf(SQLConf.CONVERT_CTAS.key -> "true") { - withSQLConf("hive.default.fileformat" -> "textfile") { + Seq("orc", "parquet").foreach { dataSourceFormat => + withSQLConf( + SQLConf.CONVERT_CTAS.key -> "true", + SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> dataSourceFormat, + "hive.default.fileformat" -> "textfile") { withTable(table) { sql(ctas) - // We should use parquet here as that is the default datasource fileformat. The default - // datasource file format is controlled by `spark.sql.sources.default` configuration. + // The default datasource file format is controlled by `spark.sql.sources.default`. // This testcase verifies that setting `hive.default.fileformat` has no impact on // the target table's fileformat in case of CTAS. - assert(sessionState.conf.defaultDataSourceName === "parquet") - checkRelation(tableName = table, isDataSourceTable = true, format = "parquet") + checkRelation(tableName = table, isDataSourceTable = true, format = dataSourceFormat) } } - withSQLConf("spark.sql.sources.default" -> "orc") { - withTable(table) { - sql(ctas) - checkRelation(tableName = table, isDataSourceTable = true, format = "orc") - } - } } } From dffeac3691daa620446ae949c5b147518d128e08 Mon Sep 17 00:00:00 2001 From: Sebastian Arzt Date: Fri, 16 Mar 2018 12:25:58 -0500 Subject: [PATCH 0482/2461] [SPARK-18371][STREAMING] Spark Streaming backpressure generates batch with large number of records MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Omit rounding of backpressure rate. Effects: - no batch with large number of records is created when rate from PID estimator is one - the number of records per batch and partition is more fine-grained improving backpressure accuracy ## How was this patch tested? This was tested by running: - `mvn test -pl external/kafka-0-8` - `mvn test -pl external/kafka-0-10` - a streaming application which was suffering from the issue JasonMWhite The contribution is my original work and I license the work to the project under the project’s open source license Author: Sebastian Arzt Closes #17774 from arzt/kafka-back-pressure. --- .../kafka010/DirectKafkaInputDStream.scala | 6 +-- .../kafka010/DirectKafkaStreamSuite.scala | 48 +++++++++++++++++ .../kafka/DirectKafkaInputDStream.scala | 6 +-- .../kafka/DirectKafkaStreamSuite.scala | 51 +++++++++++++++++++ 4 files changed, 105 insertions(+), 6 deletions(-) diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 0fa3287f36db8..9cb2448fea0f4 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -138,17 +138,17 @@ private[spark] class DirectKafkaInputDStream[K, V]( lagPerPartition.map { case (tp, lag) => val maxRateLimitPerPartition = ppc.maxRatePerPartition(tp) - val backpressureRate = Math.round(lag / totalLag.toFloat * rate) + val backpressureRate = lag / totalLag.toDouble * rate tp -> (if (maxRateLimitPerPartition > 0) { Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate) } - case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp) } + case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp).toDouble } } if (effectiveRateLimitPerPartition.values.sum > 0) { val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 Some(effectiveRateLimitPerPartition.map { - case (tp, limit) => tp -> (secsPerBatch * limit).toLong + case (tp, limit) => tp -> Math.max((secsPerBatch * limit).toLong, 1L) }) } else { None diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index 453b5e5ab20d3..8524743ee2846 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -617,6 +617,54 @@ class DirectKafkaStreamSuite ssc.stop() } + test("maxMessagesPerPartition with zero offset and rate equal to one") { + val topic = "backpressure" + val kafkaParams = getKafkaParams() + val batchIntervalMilliseconds = 60000 + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.kafka.maxRatePerPartition", "100") + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds)) + val estimateRate = 1L + val fromOffsets = Map( + new TopicPartition(topic, 0) -> 0L, + new TopicPartition(topic, 1) -> 0L, + new TopicPartition(topic, 2) -> 0L, + new TopicPartition(topic, 3) -> 0L + ) + val kafkaStream = withClue("Error creating direct stream") { + new DirectKafkaInputDStream[String, String]( + ssc, + preferredHosts, + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala), + new DefaultPerPartitionConfig(sparkConf) + ) { + currentOffsets = fromOffsets + override val rateController = Some(new ConstantRateController(id, null, estimateRate)) + } + } + + val offsets = Map[TopicPartition, Long]( + new TopicPartition(topic, 0) -> 0, + new TopicPartition(topic, 1) -> 100L, + new TopicPartition(topic, 2) -> 200L, + new TopicPartition(topic, 3) -> 300L + ) + val result = kafkaStream.maxMessagesPerPartition(offsets) + val expected = Map( + new TopicPartition(topic, 0) -> 1L, + new TopicPartition(topic, 1) -> 10L, + new TopicPartition(topic, 2) -> 20L, + new TopicPartition(topic, 3) -> 30L + ) + assert(result.contains(expected), s"Number of messages per partition must be at least 1") + } + /** Get the generated offset ranges from the DirectKafkaStream */ private def getOffsetRanges[K, V]( kafkaStream: DStream[ConsumerRecord[K, V]]): Seq[(Time, Array[OffsetRange])] = { diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index d52c230eb7849..d6dd0744441e4 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -104,17 +104,17 @@ class DirectKafkaInputDStream[ val totalLag = lagPerPartition.values.sum lagPerPartition.map { case (tp, lag) => - val backpressureRate = Math.round(lag / totalLag.toFloat * rate) + val backpressureRate = lag / totalLag.toDouble * rate tp -> (if (maxRateLimitPerPartition > 0) { Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate) } - case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition } + case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition.toDouble } } if (effectiveRateLimitPerPartition.values.sum > 0) { val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 Some(effectiveRateLimitPerPartition.map { - case (tp, limit) => tp -> (secsPerBatch * limit).toLong + case (tp, limit) => tp -> Math.max((secsPerBatch * limit).toLong, 1L) }) } else { None diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 06ef5bc3f8bd0..3fea6cfd910bf 100644 --- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -456,6 +456,57 @@ class DirectKafkaStreamSuite ssc.stop() } + test("maxMessagesPerPartition with zero offset and rate equal to one") { + val topic = "backpressure" + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "smallest" + ) + + val batchIntervalMilliseconds = 60000 + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.kafka.maxRatePerPartition", "100") + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds)) + val estimatedRate = 1L + val kafkaStream = withClue("Error creating direct stream") { + val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message) + val fromOffsets = Map( + TopicAndPartition(topic, 0) -> 0L, + TopicAndPartition(topic, 1) -> 0L, + TopicAndPartition(topic, 2) -> 0L, + TopicAndPartition(topic, 3) -> 0L + ) + new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)]( + ssc, kafkaParams, fromOffsets, messageHandler) { + override protected[streaming] val rateController = + Some(new DirectKafkaRateController(id, null) { + override def getLatestRate() = estimatedRate + }) + } + } + + val offsets = Map( + TopicAndPartition(topic, 0) -> 0L, + TopicAndPartition(topic, 1) -> 100L, + TopicAndPartition(topic, 2) -> 200L, + TopicAndPartition(topic, 3) -> 300L + ) + val result = kafkaStream.maxMessagesPerPartition(offsets) + val expected = Map( + TopicAndPartition(topic, 0) -> 1L, + TopicAndPartition(topic, 1) -> 10L, + TopicAndPartition(topic, 2) -> 20L, + TopicAndPartition(topic, 3) -> 30L + ) + assert(result.contains(expected), s"Number of messages per partition must be at least 1") + } + /** Get the generated offset ranges from the DirectKafkaStream */ private def getOffsetRanges[K, V]( kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = { From 88d8de9260edf6e9d5449ff7ef6e35d16051fc9f Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 16 Mar 2018 18:28:16 +0100 Subject: [PATCH 0483/2461] [SPARK-23581][SQL] Add interpreted unsafe projection ## What changes were proposed in this pull request? We currently can only create unsafe rows using code generation. This is a problem for situations in which code generation fails. There is no fallback, and as a result we cannot execute the query. This PR adds an interpreted version of `UnsafeProjection`. The implementation is modeled after `InterpretedMutableProjection`. It stores the expression results in a `GenericInternalRow`, and it then uses a conversion function to convert the `GenericInternalRow` into an `UnsafeRow`. This PR does not implement the actual code generated to interpreted fallback logic. This will be done in a follow-up. ## How was this patch tested? I am piggybacking on exiting `UnsafeProjection` tests, and I have added an interpreted version for each of these. Author: Herman van Hovell Closes #20750 from hvanhovell/SPARK-23581. --- .../codegen/UnsafeArrayWriter.java | 32 +- .../expressions/codegen/UnsafeRowWriter.java | 30 +- .../expressions/codegen/UnsafeWriter.java | 43 ++ .../sql/catalyst/expressions/Expression.scala | 26 ++ .../InterpretedUnsafeProjection.scala | 366 ++++++++++++++++++ .../MonotonicallyIncreasingID.scala | 4 +- .../sql/catalyst/expressions/Projection.scala | 19 +- .../codegen/GenerateUnsafeProjection.scala | 2 +- .../expressions/randomExpressions.scala | 6 +- .../expressions/ComplexTypeSuite.scala | 2 +- .../expressions/ExpressionEvalHelper.scala | 20 +- .../expressions/ObjectExpressionsSuite.scala | 21 +- .../catalyst/expressions/ScalaUDFSuite.scala | 2 +- .../expressions/UnsafeRowConverterSuite.scala | 56 +-- 14 files changed, 555 insertions(+), 74 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 791e8d80e6cba..82cd1b24607e1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -30,7 +30,7 @@ * A helper class to write data into global row buffer using `UnsafeArrayData` format, * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}. */ -public class UnsafeArrayWriter { +public final class UnsafeArrayWriter extends UnsafeWriter { private BufferHolder holder; @@ -83,7 +83,7 @@ private long getElementOffset(int ordinal, int elementSize) { return startingOffset + headerInBytes + ordinal * elementSize; } - public void setOffsetAndSize(int ordinal, long currentCursor, int size) { + public void setOffsetAndSize(int ordinal, int currentCursor, int size) { assertIndexIsValid(ordinal); final long relativeOffset = currentCursor - startingOffset; final long offsetAndSize = (relativeOffset << 32) | (long)size; @@ -96,49 +96,31 @@ private void setNullBit(int ordinal) { BitSetMethods.set(holder.buffer, startingOffset + 8, ordinal); } - public void setNullBoolean(int ordinal) { - setNullBit(ordinal); - // put zero into the corresponding field when set null - Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), false); - } - - public void setNullByte(int ordinal) { + public void setNull1Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), (byte)0); } - public void setNullShort(int ordinal) { + public void setNull2Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), (short)0); } - public void setNullInt(int ordinal) { + public void setNull4Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), 0); } - public void setNullLong(int ordinal) { + public void setNull8Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), (long)0); } - public void setNullFloat(int ordinal) { - setNullBit(ordinal); - // put zero into the corresponding field when set null - Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), (float)0); - } - - public void setNullDouble(int ordinal) { - setNullBit(ordinal); - // put zero into the corresponding field when set null - Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), (double)0); - } - - public void setNull(int ordinal) { setNullLong(ordinal); } + public void setNull(int ordinal) { setNull8Bytes(ordinal); } public void write(int ordinal, boolean value) { assertIndexIsValid(ordinal); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 5d9515c0725da..2620bbcfb87a2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -38,7 +38,7 @@ * beginning of the global row buffer, we don't need to update `startingOffset` and can just call * `zeroOutNullBytes` before writing new data. */ -public class UnsafeRowWriter { +public final class UnsafeRowWriter extends UnsafeWriter { private final BufferHolder holder; // The offset of the global buffer where we start to write this row. @@ -93,18 +93,38 @@ public void setNullAt(int ordinal) { Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L); } + @Override + public void setNull1Bytes(int ordinal) { + setNullAt(ordinal); + } + + @Override + public void setNull2Bytes(int ordinal) { + setNullAt(ordinal); + } + + @Override + public void setNull4Bytes(int ordinal) { + setNullAt(ordinal); + } + + @Override + public void setNull8Bytes(int ordinal) { + setNullAt(ordinal); + } + public long getFieldOffset(int ordinal) { return startingOffset + nullBitsSize + 8 * ordinal; } - public void setOffsetAndSize(int ordinal, long size) { + public void setOffsetAndSize(int ordinal, int size) { setOffsetAndSize(ordinal, holder.cursor, size); } - public void setOffsetAndSize(int ordinal, long currentCursor, long size) { + public void setOffsetAndSize(int ordinal, int currentCursor, int size) { final long relativeOffset = currentCursor - startingOffset; final long fieldOffset = getFieldOffset(ordinal); - final long offsetAndSize = (relativeOffset << 32) | size; + final long offsetAndSize = (relativeOffset << 32) | (long) size; Platform.putLong(holder.buffer, fieldOffset, offsetAndSize); } @@ -174,7 +194,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) { if (input == null || !input.changePrecision(precision, scale)) { BitSetMethods.set(holder.buffer, startingOffset, ordinal); // keep the offset for future update - setOffsetAndSize(ordinal, 0L); + setOffsetAndSize(ordinal, 0); } else { final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); assert bytes.length <= 16; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java new file mode 100644 index 0000000000000..c94b5c7a367ef --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.codegen; + +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * Base class for writing Unsafe* structures. + */ +public abstract class UnsafeWriter { + public abstract void setNull1Bytes(int ordinal); + public abstract void setNull2Bytes(int ordinal); + public abstract void setNull4Bytes(int ordinal); + public abstract void setNull8Bytes(int ordinal); + public abstract void write(int ordinal, boolean value); + public abstract void write(int ordinal, byte value); + public abstract void write(int ordinal, short value); + public abstract void write(int ordinal, int value); + public abstract void write(int ordinal, long value); + public abstract void write(int ordinal, float value); + public abstract void write(int ordinal, double value); + public abstract void write(int ordinal, Decimal input, int precision, int scale); + public abstract void write(int ordinal, UTF8String input); + public abstract void write(int ordinal, byte[] input); + public abstract void write(int ordinal, CalendarInterval input); + public abstract void setOffsetAndSize(int ordinal, int currentCursor, int size); +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index ed90b185865a0..d7f9e38915dd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -328,6 +328,32 @@ trait Nondeterministic extends Expression { protected def evalInternal(input: InternalRow): Any } +/** + * An expression that contains mutable state. A stateful expression is always non-deterministic + * because the results it produces during evaluation are not only dependent on the given input + * but also on its internal state. + * + * The state of the expressions is generally not exposed in the parameter list and this makes + * comparing stateful expressions problematic because similar stateful expressions (with the same + * parameter list) but with different internal state will be considered equal. This is especially + * problematic during tree transformations. In order to counter this the `fastEquals` method for + * stateful expressions only returns `true` for the same reference. + * + * A stateful expression should never be evaluated multiple times for a single row. This should + * only be a problem for interpreted execution. This can be prevented by creating fresh copies + * of the stateful expression before execution, these can be made using the `freshCopy` function. + */ +trait Stateful extends Nondeterministic { + /** + * Return a fresh uninitialized copy of the stateful expression. + */ + def freshCopy(): Stateful + + /** + * Only the same reference is considered equal. + */ + override def fastEquals(other: TreeNode[_]): Boolean = this eq other +} /** * A leaf expression, i.e. one without any child expressions. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala new file mode 100644 index 0000000000000..0da5ece7e47fe --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter} +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types.{UserDefinedType, _} +import org.apache.spark.unsafe.Platform + +/** + * An interpreted unsafe projection. This class reuses the [[UnsafeRow]] it produces, a consumer + * should copy the row if it is being buffered. This class is not thread safe. + * + * @param expressions that produces the resulting fields. These expressions must be bound + * to a schema. + */ +class InterpretedUnsafeProjection(expressions: Array[Expression]) extends UnsafeProjection { + import InterpretedUnsafeProjection._ + + /** Number of (top level) fields in the resulting row. */ + private[this] val numFields = expressions.length + + /** Array that expression results. */ + private[this] val values = new Array[Any](numFields) + + /** The row representing the expression results. */ + private[this] val intermediate = new GenericInternalRow(values) + + /** The row returned by the projection. */ + private[this] val result = new UnsafeRow(numFields) + + /** The buffer which holds the resulting row's backing data. */ + private[this] val holder = new BufferHolder(result, numFields * 32) + + /** The writer that writes the intermediate result to the result row. */ + private[this] val writer: InternalRow => Unit = { + val rowWriter = new UnsafeRowWriter(holder, numFields) + val baseWriter = generateStructWriter( + holder, + rowWriter, + expressions.map(e => StructField("", e.dataType, e.nullable))) + if (!expressions.exists(_.nullable)) { + // No nullable fields. The top-level null bit mask will always be zeroed out. + baseWriter + } else { + // Zero out the null bit mask before we write the row. + row => { + rowWriter.zeroOutNullBytes() + baseWriter(row) + } + } + } + + override def initialize(partitionIndex: Int): Unit = { + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initialize(partitionIndex) + case _ => + }) + } + + override def apply(row: InternalRow): UnsafeRow = { + // Put the expression results in the intermediate row. + var i = 0 + while (i < numFields) { + values(i) = expressions(i).eval(row) + i += 1 + } + + // Write the intermediate row to an unsafe row. + holder.reset() + writer(intermediate) + result.setTotalSize(holder.totalSize()) + result + } +} + +/** + * Helper functions for creating an [[InterpretedUnsafeProjection]]. + */ +object InterpretedUnsafeProjection extends UnsafeProjectionCreator { + + /** + * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. + */ + override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { + // We need to make sure that we do not reuse stateful expressions. + val cleanedExpressions = exprs.map(_.transform { + case s: Stateful => s.freshCopy() + }) + new InterpretedUnsafeProjection(cleanedExpressions.toArray) + } + + /** + * Generate a struct writer function. The generated function writes an [[InternalRow]] to the + * given buffer using the given [[UnsafeRowWriter]]. + */ + private def generateStructWriter( + bufferHolder: BufferHolder, + rowWriter: UnsafeRowWriter, + fields: Array[StructField]): InternalRow => Unit = { + val numFields = fields.length + + // Create field writers. + val fieldWriters = fields.map { field => + generateFieldWriter(bufferHolder, rowWriter, field.dataType, field.nullable) + } + // Create basic writer. + row => { + var i = 0 + while (i < numFields) { + fieldWriters(i).apply(row, i) + i += 1 + } + } + } + + /** + * Generate a writer function for a struct field, array element, map key or map value. The + * generated function writes the element at an index in a [[SpecializedGetters]] object (row + * or array) to the given buffer using the given [[UnsafeWriter]]. + */ + private def generateFieldWriter( + bufferHolder: BufferHolder, + writer: UnsafeWriter, + dt: DataType, + nullable: Boolean): (SpecializedGetters, Int) => Unit = { + + // Create the the basic writer. + val unsafeWriter: (SpecializedGetters, Int) => Unit = dt match { + case BooleanType => + (v, i) => writer.write(i, v.getBoolean(i)) + + case ByteType => + (v, i) => writer.write(i, v.getByte(i)) + + case ShortType => + (v, i) => writer.write(i, v.getShort(i)) + + case IntegerType | DateType => + (v, i) => writer.write(i, v.getInt(i)) + + case LongType | TimestampType => + (v, i) => writer.write(i, v.getLong(i)) + + case FloatType => + (v, i) => writer.write(i, v.getFloat(i)) + + case DoubleType => + (v, i) => writer.write(i, v.getDouble(i)) + + case DecimalType.Fixed(precision, scale) => + (v, i) => writer.write(i, v.getDecimal(i, precision, scale), precision, scale) + + case CalendarIntervalType => + (v, i) => writer.write(i, v.getInterval(i)) + + case BinaryType => + (v, i) => writer.write(i, v.getBinary(i)) + + case StringType => + (v, i) => writer.write(i, v.getUTF8String(i)) + + case StructType(fields) => + val numFields = fields.length + val rowWriter = new UnsafeRowWriter(bufferHolder, numFields) + val structWriter = generateStructWriter(bufferHolder, rowWriter, fields) + (v, i) => { + val tmpCursor = bufferHolder.cursor + v.getStruct(i, fields.length) match { + case row: UnsafeRow => + writeUnsafeData( + bufferHolder, + row.getBaseObject, + row.getBaseOffset, + row.getSizeInBytes) + case row => + // Nested struct. We don't know where this will start because a row can be + // variable length, so we need to update the offsets and zero out the bit mask. + rowWriter.reset() + structWriter.apply(row) + } + writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + } + + case ArrayType(elementType, containsNull) => + val arrayWriter = new UnsafeArrayWriter + val elementSize = getElementSize(elementType) + val elementWriter = generateFieldWriter( + bufferHolder, + arrayWriter, + elementType, + containsNull) + (v, i) => { + val tmpCursor = bufferHolder.cursor + writeArray(bufferHolder, arrayWriter, elementWriter, v.getArray(i), elementSize) + writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + } + + case MapType(keyType, valueType, valueContainsNull) => + val keyArrayWriter = new UnsafeArrayWriter + val keySize = getElementSize(keyType) + val keyWriter = generateFieldWriter( + bufferHolder, + keyArrayWriter, + keyType, + nullable = false) + val valueArrayWriter = new UnsafeArrayWriter + val valueSize = getElementSize(valueType) + val valueWriter = generateFieldWriter( + bufferHolder, + valueArrayWriter, + valueType, + valueContainsNull) + (v, i) => { + val tmpCursor = bufferHolder.cursor + v.getMap(i) match { + case map: UnsafeMapData => + writeUnsafeData( + bufferHolder, + map.getBaseObject, + map.getBaseOffset, + map.getSizeInBytes) + case map => + // preserve 8 bytes to write the key array numBytes later. + bufferHolder.grow(8) + bufferHolder.cursor += 8 + + // Write the keys and write the numBytes of key array into the first 8 bytes. + writeArray(bufferHolder, keyArrayWriter, keyWriter, map.keyArray(), keySize) + Platform.putLong(bufferHolder.buffer, tmpCursor, bufferHolder.cursor - tmpCursor - 8) + + // Write the values. + writeArray(bufferHolder, valueArrayWriter, valueWriter, map.valueArray(), valueSize) + } + writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + } + + case udt: UserDefinedType[_] => + generateFieldWriter(bufferHolder, writer, udt.sqlType, nullable) + + case NullType => + (_, _) => {} + + case _ => + throw new SparkException(s"Unsupported data type $dt") + } + + // Always wrap the writer with a null safe version. + dt match { + case _: UserDefinedType[_] => + // The null wrapper depends on the sql type and not on the UDT. + unsafeWriter + case DecimalType.Fixed(precision, _) if precision > Decimal.MAX_LONG_DIGITS => + // We can't call setNullAt() for DecimalType with precision larger than 18, we call write + // directly. We can use the unwrapped writer directly. + unsafeWriter + case BooleanType | ByteType => + (v, i) => { + if (!v.isNullAt(i)) { + unsafeWriter(v, i) + } else { + writer.setNull1Bytes(i) + } + } + case ShortType => + (v, i) => { + if (!v.isNullAt(i)) { + unsafeWriter(v, i) + } else { + writer.setNull2Bytes(i) + } + } + case IntegerType | DateType | FloatType => + (v, i) => { + if (!v.isNullAt(i)) { + unsafeWriter(v, i) + } else { + writer.setNull4Bytes(i) + } + } + case _ => + (v, i) => { + if (!v.isNullAt(i)) { + unsafeWriter(v, i) + } else { + writer.setNull8Bytes(i) + } + } + } + } + + /** + * Get the number of bytes elements of a data type will occupy in the fixed part of an + * [[UnsafeArrayData]] object. Reference types are stored as an 8 byte combination of an + * offset (upper 4 bytes) and a length (lower 4 bytes), these point to the variable length + * portion of the array object. Primitives take up to 8 bytes, depending on the size of the + * underlying data type. + */ + private def getElementSize(dataType: DataType): Int = dataType match { + case NullType | StringType | BinaryType | CalendarIntervalType | + _: DecimalType | _: StructType | _: ArrayType | _: MapType => 8 + case _ => dataType.defaultSize + } + + /** + * Write an array to the buffer. If the array is already in serialized form (an instance of + * [[UnsafeArrayData]]) then we copy the bytes directly, otherwise we do an element-by-element + * copy. + */ + private def writeArray( + bufferHolder: BufferHolder, + arrayWriter: UnsafeArrayWriter, + elementWriter: (SpecializedGetters, Int) => Unit, + array: ArrayData, + elementSize: Int): Unit = array match { + case unsafe: UnsafeArrayData => + writeUnsafeData( + bufferHolder, + unsafe.getBaseObject, + unsafe.getBaseOffset, + unsafe.getSizeInBytes) + case _ => + val numElements = array.numElements() + arrayWriter.initialize(bufferHolder, numElements, elementSize) + var i = 0 + while (i < numElements) { + elementWriter.apply(array, i) + i += 1 + } + } + + /** + * Write an opaque block of data to the buffer. This is used to copy + * [[UnsafeRow]], [[UnsafeArrayData]] and [[UnsafeMapData]] objects. + */ + private def writeUnsafeData( + bufferHolder: BufferHolder, + baseObject: AnyRef, + baseOffset: Long, + sizeInBytes: Int) : Unit = { + bufferHolder.grow(sizeInBytes) + Platform.copyMemory( + baseObject, + baseOffset, + bufferHolder.buffer, + bufferHolder.cursor, + sizeInBytes) + bufferHolder.cursor += sizeInBytes + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 4523079060896..dd523d312e3b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.types.{DataType, LongType} within each partition. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records. """) -case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic { +case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { /** * Record ID within each partition. By being transient, count's value is reset to 0 every time @@ -79,4 +79,6 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis override def prettyName: String = "monotonically_increasing_id" override def sql: String = s"$prettyName()" + + override def freshCopy(): MonotonicallyIncreasingID = MonotonicallyIncreasingID() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 64b94f0a2c103..3cd73682188bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -108,8 +108,7 @@ abstract class UnsafeProjection extends Projection { override def apply(row: InternalRow): UnsafeRow } -object UnsafeProjection { - +trait UnsafeProjectionCreator { /** * Returns an UnsafeProjection for given StructType. * @@ -127,13 +126,13 @@ object UnsafeProjection { } /** - * Returns an UnsafeProjection for given sequence of Expressions (bounded). + * Returns an UnsafeProjection for given sequence of bound Expressions. */ def create(exprs: Seq[Expression]): UnsafeProjection = { val unsafeExprs = exprs.map(_ transform { case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) }) - GenerateUnsafeProjection.generate(unsafeExprs) + createProjection(unsafeExprs) } def create(expr: Expression): UnsafeProjection = create(Seq(expr)) @@ -146,6 +145,18 @@ object UnsafeProjection { create(exprs.map(BindReferences.bindReference(_, inputSchema))) } + /** + * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. + */ + protected def createProjection(exprs: Seq[Expression]): UnsafeProjection +} + +object UnsafeProjection extends UnsafeProjectionCreator { + + override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { + GenerateUnsafeProjection.generate(exprs) + } + /** * Same as other create()'s but allowing enabling/disabling subexpression elimination. * TODO: refactor the plumbing and clean this up. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 22717f5954a45..6682ba55b18b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -247,7 +247,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro for (int $index = 0; $index < $numElements; $index++) { if ($tmpInput.isNullAt($index)) { - $arrayWriter.setNull$primitiveTypeName($index); + $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); } else { $writeElement } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 6c9937dacc70b..f36633867316e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.random.XORShiftRandom * * Since this expression is stateful, it cannot be a case object. */ -abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic { +abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful { /** * Record ID within each partition. By being transient, the Random Number Generator is * reset every time we serialize and deserialize and initialize it. @@ -85,6 +85,8 @@ case class Rand(child: Expression) extends RDG { final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false") } + + override def freshCopy(): Rand = Rand(child) } object Rand { @@ -120,6 +122,8 @@ case class Randn(child: Expression) extends RDG { final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false") } + + override def freshCopy(): Randn = Randn(child) } object Randn { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 84190f0bd5f7d..b4138ce366b3a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -180,7 +180,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { null, null) } intercept[RuntimeException] { - checkEvalutionWithUnsafeProjection( + checkEvaluationWithUnsafeProjection( CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), null, null) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 58d0c07622eb9..c6343b1cbf600 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -60,7 +60,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) if (GenerateUnsafeProjection.canSupport(expr.dataType)) { - checkEvalutionWithUnsafeProjection(expr, catalystValue, inputRow) + checkEvaluationWithUnsafeProjection(expr, catalystValue, inputRow) } checkEvaluationWithOptimization(expr, catalystValue, inputRow) } @@ -187,11 +187,20 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { plan(inputRow).get(0, expression.dataType) } - protected def checkEvalutionWithUnsafeProjection( + protected def checkEvaluationWithUnsafeProjection( expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow) + checkEvaluationWithUnsafeProjection(expression, expected, inputRow, UnsafeProjection) + checkEvaluationWithUnsafeProjection(expression, expected, inputRow, InterpretedUnsafeProjection) + } + + protected def checkEvaluationWithUnsafeProjection( + expression: Expression, + expected: Any, + inputRow: InternalRow, + factory: UnsafeProjectionCreator): Unit = { + val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow, factory) val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" if (expected == null) { @@ -203,7 +212,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } else { val lit = InternalRow(expected, expected) val expectedRow = - UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit) + factory.create(Array(expression.dataType, expression.dataType)).apply(lit) if (unsafeRow != expectedRow) { fail("Incorrect evaluation in unsafe mode: " + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") @@ -213,7 +222,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { private def evaluateWithUnsafeProjection( expression: Expression, - inputRow: InternalRow = EmptyRow): InternalRow = { + inputRow: InternalRow = EmptyRow, + factory: UnsafeProjectionCreator = UnsafeProjection): InternalRow = { // SPARK-16489 Explicitly doing code generation twice so code gen will fail if // some expression is reusing variable names across different instances. // This behavior is tested in ExpressionEvalHelperSuite. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index ffeec2a38c532..1f6964dfef598 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -45,16 +45,22 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val structInputRow = InternalRow.fromSeq(Seq(Array((1, 2), (3, 4)))) val structExpected = new GenericArrayData( Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4)))) - checkEvalutionWithUnsafeProjection( - structEncoder.serializer.head, structExpected, structInputRow) + checkEvaluationWithUnsafeProjection( + structEncoder.serializer.head, + structExpected, + structInputRow, + UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed // test UnsafeArray-backed data val arrayEncoder = ExpressionEncoder[Array[Array[Int]]] val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4)))) val arrayExpected = new GenericArrayData( Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4)))) - checkEvalutionWithUnsafeProjection( - arrayEncoder.serializer.head, arrayExpected, arrayInputRow) + checkEvaluationWithUnsafeProjection( + arrayEncoder.serializer.head, + arrayExpected, + arrayInputRow, + UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed // test UnsafeMap-backed data val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]] @@ -67,8 +73,11 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { new ArrayBasedMapData( new GenericArrayData(Array(3, 4)), new GenericArrayData(Array(300, 400))))) - checkEvalutionWithUnsafeProjection( - mapEncoder.serializer.head, mapExpected, mapInputRow) + checkEvaluationWithUnsafeProjection( + mapEncoder.serializer.head, + mapExpected, + mapInputRow, + UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed } test("SPARK-23585: UnwrapOption should support interpreted execution") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala index 10e3ffd0dff97..e083ae0089244 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -43,7 +43,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { assert(e1.getMessage.contains("Failed to execute user defined function")) val e2 = intercept[SparkException] { - checkEvalutionWithUnsafeProjection(udf, null) + checkEvaluationWithUnsafeProjection(udf, null) } assert(e2.getMessage.contains("Failed to execute user defined function")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index cf3cbe270753e..c07da122cd7b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.{IntegerType, LongType, _} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String @@ -33,10 +33,18 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { private def roundedSize(size: Int) = ByteArrayMethods.roundNumberOfBytesToNearestWord(size) - test("basic conversion with only primitive types") { - val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) - val converter = UnsafeProjection.create(fieldTypes) + private def testWithFactory( + name: String)( + f: UnsafeProjectionCreator => Unit): Unit = { + test(name) { + f(UnsafeProjection) + f(InterpretedUnsafeProjection) + } + } + testWithFactory("basic conversion with only primitive types") { factory => + val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) + val converter = factory.create(fieldTypes) val row = new SpecificInternalRow(fieldTypes) row.setLong(0, 0) row.setLong(1, 1) @@ -71,9 +79,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow2.getInt(2) === 2) } - test("basic conversion with primitive, string and binary types") { + testWithFactory("basic conversion with primitive, string and binary types") { factory => val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new SpecificInternalRow(fieldTypes) row.setLong(0, 0) @@ -90,9 +98,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getBinary(2) === "World".getBytes(StandardCharsets.UTF_8)) } - test("basic conversion with primitive, string, date and timestamp types") { + testWithFactory("basic conversion with primitive, string, date and timestamp types") { factory => val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new SpecificInternalRow(fieldTypes) row.setLong(0, 0) @@ -119,7 +127,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { (Timestamp.valueOf("2015-06-22 08:10:25")) } - test("null handling") { + testWithFactory("null handling") { factory => val fieldTypes: Array[DataType] = Array( NullType, BooleanType, @@ -135,7 +143,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { DecimalType.SYSTEM_DEFAULT // ArrayType(IntegerType) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val rowWithAllNullColumns: InternalRow = { val r = new SpecificInternalRow(fieldTypes) @@ -240,7 +248,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } - test("NaN canonicalization") { + testWithFactory("NaN canonicalization") { factory => val fieldTypes: Array[DataType] = Array(FloatType, DoubleType) val row1 = new SpecificInternalRow(fieldTypes) @@ -251,17 +259,17 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff)) row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes) } - test("basic conversion with struct type") { + testWithFactory("basic conversion with struct type") { factory => val fieldTypes: Array[DataType] = Array( new StructType().add("i", IntegerType), new StructType().add("nest", new StructType().add("l", LongType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, InternalRow(1)) @@ -317,12 +325,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(map.getSizeInBytes == 8 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes) } - test("basic conversion with array type") { + testWithFactory("basic conversion with array type") { factory => val fieldTypes: Array[DataType] = Array( ArrayType(IntegerType), ArrayType(ArrayType(IntegerType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, createArray(1, 2)) @@ -347,12 +355,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size) } - test("basic conversion with map type") { + testWithFactory("basic conversion with map type") { factory => val fieldTypes: Array[DataType] = Array( MapType(IntegerType, IntegerType), MapType(IntegerType, MapType(IntegerType, IntegerType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val map1 = createMap(1, 2)(3, 4) @@ -393,12 +401,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size) } - test("basic conversion with struct and array") { + testWithFactory("basic conversion with struct and array") { factory => val fieldTypes: Array[DataType] = Array( new StructType().add("arr", ArrayType(IntegerType)), ArrayType(new StructType().add("l", LongType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, InternalRow(createArray(1))) @@ -432,12 +440,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } - test("basic conversion with struct and map") { + testWithFactory("basic conversion with struct and map") { factory => val fieldTypes: Array[DataType] = Array( new StructType().add("map", MapType(IntegerType, IntegerType)), MapType(IntegerType, new StructType().add("l", LongType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, InternalRow(createMap(1)(2))) @@ -478,12 +486,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } - test("basic conversion with array and map") { + testWithFactory("basic conversion with array and map") { factory => val fieldTypes: Array[DataType] = Array( ArrayType(MapType(IntegerType, IntegerType)), MapType(IntegerType, ArrayType(IntegerType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, createArray(createMap(1)(2))) From 9945b0227efcd952c8e835453b2831a8c6d5d607 Mon Sep 17 00:00:00 2001 From: Ricardo Martinelli de Oliveira Date: Fri, 16 Mar 2018 10:37:11 -0700 Subject: [PATCH 0484/2461] [SPARK-23680] Fix entrypoint.sh to properly support Arbitrary UIDs ## What changes were proposed in this pull request? As described in SPARK-23680, entrypoint.sh returns an error code because of a command pipeline execution where it is expected in case of Openshift environments, where arbitrary UIDs are used to run containers ## How was this patch tested? This patch was manually tested by using docker-image-toll.sh script to generate a Spark driver image and running an example against an OpenShift cluster. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Ricardo Martinelli de Oliveira Closes #20822 from rimolive/rmartine-spark-23680. --- .../kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 3d67b0a702dd4..d0cf284f035ea 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -22,7 +22,10 @@ set -ex # Check whether there is a passwd entry for the container UID myuid=$(id -u) mygid=$(id -g) +# turn off -e for getent because it will return error code in anonymous uid case +set +e uidentry=$(getent passwd $myuid) +set -e # If there is no passwd entry for the container UID, attempt to create one if [ -z "$uidentry" ] ; then From bd201bf61e8e1713deb91b962f670c76c9e3492b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 16 Mar 2018 11:11:07 -0700 Subject: [PATCH 0485/2461] [SPARK-23623][SS] Avoid concurrent use of cached consumers in CachedKafkaConsumer ## What changes were proposed in this pull request? CacheKafkaConsumer in the project `kafka-0-10-sql` is designed to maintain a pool of KafkaConsumers that can be reused. However, it was built with the assumption there will be only one task using trying to read the same Kafka TopicPartition at the same time. Hence, the cache was keyed by the TopicPartition a consumer is supposed to read. And any cases where this assumption may not be true, we have SparkPlan flag to disable the use of a cache. So it was up to the planner to correctly identify when it was not safe to use the cache and set the flag accordingly. Fundamentally, this is the wrong way to approach the problem. It is HARD for a high-level planner to reason about the low-level execution model, whether there will be multiple tasks in the same query trying to read the same partition. Case in point, 2.3.0 introduced stream-stream joins, and you can build a streaming self-join query on Kafka. It's pretty non-trivial to figure out how this leads to two tasks reading the same partition twice, possibly concurrently. And due to the non-triviality, it is hard to figure this out in the planner and set the flag to avoid the cache / consumer pool. And this can inadvertently lead to ConcurrentModificationException ,or worse, silent reading of incorrect data. Here is a better way to design this. The planner shouldnt have to understand these low-level optimizations. Rather the consumer pool should be smart enough avoid concurrent use of a cached consumer. Currently, it tries to do so but incorrectly (the flag inuse is not checked when returning a cached consumer, see [this](https://github.com/apache/spark/blob/master/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala#L403)). If there is another request for the same partition as a currently in-use consumer, the pool should automatically return a fresh consumer that should be closed when the task is done. Then the planner does not have to have a flag to avoid reuses. This PR is a step towards that goal. It does the following. - There are effectively two kinds of consumer that may be generated - Cached consumer - this should be returned to the pool at task end - Non-cached consumer - this should be closed at task end - A trait called KafkaConsumer is introduced to hide this difference from the users of the consumer so that the client code does not have to reason about whether to stop and release. They simply called `val consumer = KafkaConsumer.acquire` and then `consumer.release()`. - If there is request for a consumer that is in-use, then a new consumer is generated. - If there is a concurrent attempt of the same task, then a new consumer is generated, and the existing cached consumer is marked for close upon release. - In addition, I renamed the classes because CachedKafkaConsumer is a misnomer given that what it returns may or may not be cached. This PR does not remove the planner flag to avoid reuse to make this patch safe enough for merging in branch-2.3. This can be done later in master-only. ## How was this patch tested? A new stress test that verifies it is safe to concurrently get consumers for the same partition from the consumer pool. Author: Tathagata Das Closes #20767 from tdas/SPARK-23623. --- .../sql/kafka010/KafkaContinuousReader.scala | 5 +- ...Consumer.scala => KafkaDataConsumer.scala} | 242 ++++++++++++------ .../sql/kafka010/KafkaMicroBatchReader.scala | 22 +- .../spark/sql/kafka010/KafkaSourceRDD.scala | 23 +- .../kafka010/CachedKafkaConsumerSuite.scala | 34 --- .../sql/kafka010/KafkaDataConsumerSuite.scala | 124 +++++++++ 6 files changed, 295 insertions(+), 155 deletions(-) rename external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/{CachedKafkaConsumer.scala => KafkaDataConsumer.scala} (66%) delete mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 6e56b0a72d671..e7e27876088f3 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -196,8 +196,7 @@ class KafkaContinuousDataReader( kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { - private val consumer = - CachedKafkaConsumer.createUncached(topicPartition.topic, topicPartition.partition, kafkaParams) + private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false) private val converter = new KafkaRecordToUnsafeRowConverter private var nextKafkaOffset = startOffset @@ -245,6 +244,6 @@ class KafkaContinuousDataReader( } override def close(): Unit = { - consumer.close() + consumer.release() } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala similarity index 66% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala index e97881cb0a163..48508d057a540 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala @@ -27,30 +27,73 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging +import org.apache.spark.sql.kafka010.KafkaDataConsumer.AvailableOffsetRange import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.util.UninterruptibleThread +private[kafka010] sealed trait KafkaDataConsumer { + /** + * Get the record for the given offset if available. Otherwise it will either throw error + * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset), + * or null. + * + * @param offset the offset to fetch. + * @param untilOffset the max offset to fetch. Exclusive. + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + * @param failOnDataLoss When `failOnDataLoss` is `true`, this method will either return record at + * offset if available, or throw exception.when `failOnDataLoss` is `false`, + * this method will either return record at offset if available, or return + * the next earliest available record less than untilOffset, or null. It + * will not throw any exception. + */ + def get( + offset: Long, + untilOffset: Long, + pollTimeoutMs: Long, + failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = { + internalConsumer.get(offset, untilOffset, pollTimeoutMs, failOnDataLoss) + } + + /** + * Return the available offset range of the current partition. It's a pair of the earliest offset + * and the latest offset. + */ + def getAvailableOffsetRange(): AvailableOffsetRange = internalConsumer.getAvailableOffsetRange() + + /** + * Release this consumer from being further used. Depending on its implementation, + * this consumer will be either finalized, or reset for reuse later. + */ + def release(): Unit + + /** Reference to the internal implementation that this wrapper delegates to */ + protected def internalConsumer: InternalKafkaConsumer +} + /** - * Consumer of single topicpartition, intended for cached reuse. - * Underlying consumer is not threadsafe, so neither is this, - * but processing the same topicpartition and group id in multiple threads is usually bad anyway. + * A wrapper around Kafka's KafkaConsumer that throws error when data loss is detected. + * This is not for direct use outside this file. */ -private[kafka010] case class CachedKafkaConsumer private( +private[kafka010] case class InternalKafkaConsumer( topicPartition: TopicPartition, kafkaParams: ju.Map[String, Object]) extends Logging { - import CachedKafkaConsumer._ + import InternalKafkaConsumer._ private val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - private var consumer = createConsumer + @volatile private var consumer = createConsumer /** indicates whether this consumer is in use or not */ - private var inuse = true + @volatile var inUse = true + + /** indicate whether this consumer is going to be stopped in the next release */ + @volatile var markedForClose = false /** Iterator to the already fetch data */ - private var fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] - private var nextOffsetInFetchedData = UNKNOWN_OFFSET + @volatile private var fetchedData = + ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] + @volatile private var nextOffsetInFetchedData = UNKNOWN_OFFSET /** Create a KafkaConsumer to fetch records for `topicPartition` */ private def createConsumer: KafkaConsumer[Array[Byte], Array[Byte]] = { @@ -61,8 +104,6 @@ private[kafka010] case class CachedKafkaConsumer private( c } - case class AvailableOffsetRange(earliest: Long, latest: Long) - private def runUninterruptiblyIfPossible[T](body: => T): T = Thread.currentThread match { case ut: UninterruptibleThread => ut.runUninterruptibly(body) @@ -313,21 +354,51 @@ private[kafka010] case class CachedKafkaConsumer private( } } -private[kafka010] object CachedKafkaConsumer extends Logging { - private val UNKNOWN_OFFSET = -2L +private[kafka010] object KafkaDataConsumer extends Logging { + + case class AvailableOffsetRange(earliest: Long, latest: Long) + + private case class CachedKafkaDataConsumer(internalConsumer: InternalKafkaConsumer) + extends KafkaDataConsumer { + assert(internalConsumer.inUse) // make sure this has been set to true + override def release(): Unit = { KafkaDataConsumer.release(internalConsumer) } + } + + private case class NonCachedKafkaDataConsumer(internalConsumer: InternalKafkaConsumer) + extends KafkaDataConsumer { + override def release(): Unit = { internalConsumer.close() } + } - private case class CacheKey(groupId: String, topicPartition: TopicPartition) + private case class CacheKey(groupId: String, topicPartition: TopicPartition) { + def this(topicPartition: TopicPartition, kafkaParams: ju.Map[String, Object]) = + this(kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String], topicPartition) + } + // This cache has the following important properties. + // - We make a best-effort attempt to maintain the max size of the cache as configured capacity. + // The capacity is not guaranteed to be maintained, especially when there are more active + // tasks simultaneously using consumers than the capacity. private lazy val cache = { val conf = SparkEnv.get.conf val capacity = conf.getInt("spark.sql.kafkaConsumerCache.capacity", 64) - new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer](capacity, 0.75f, true) { + new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer](capacity, 0.75f, true) { override def removeEldestEntry( - entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer]): Boolean = { - if (entry.getValue.inuse == false && this.size > capacity) { - logWarning(s"KafkaConsumer cache hitting max capacity of $capacity, " + - s"removing consumer for ${entry.getKey}") + entry: ju.Map.Entry[CacheKey, InternalKafkaConsumer]): Boolean = { + + // Try to remove the least-used entry if its currently not in use. + // + // If you cannot remove it, then the cache will keep growing. In the worst case, + // the cache will grow to the max number of concurrent tasks that can run in the executor, + // (that is, number of tasks slots) after which it will never reduce. This is unlikely to + // be a serious problem because an executor with more than 64 (default) tasks slots is + // likely running on a beefy machine that can handle a large number of simultaneously + // active consumers. + + if (entry.getValue.inUse == false && this.size > capacity) { + logWarning( + s"KafkaConsumer cache hitting max capacity of $capacity, " + + s"removing consumer for ${entry.getKey}") try { entry.getValue.close() } catch { @@ -342,80 +413,87 @@ private[kafka010] object CachedKafkaConsumer extends Logging { } } - def releaseKafkaConsumer( - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): Unit = { - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - val topicPartition = new TopicPartition(topic, partition) - val key = CacheKey(groupId, topicPartition) - - synchronized { - val consumer = cache.get(key) - if (consumer != null) { - consumer.inuse = false - } else { - logWarning(s"Attempting to release consumer that does not exist") - } - } - } - /** - * Removes (and closes) the Kafka Consumer for the given topic, partition and group id. + * Get a cached consumer for groupId, assigned to topic and partition. + * If matching consumer doesn't already exist, will be created using kafkaParams. + * The returned consumer must be released explicitly using [[KafkaDataConsumer.release()]]. + * + * Note: This method guarantees that the consumer returned is not currently in use by any one + * else. Within this guarantee, this method will make a best effort attempt to re-use consumers by + * caching them and tracking when they are in use. */ - def removeKafkaConsumer( - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): Unit = { - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - val topicPartition = new TopicPartition(topic, partition) - val key = CacheKey(groupId, topicPartition) + def acquire( + topicPartition: TopicPartition, + kafkaParams: ju.Map[String, Object], + useCache: Boolean): KafkaDataConsumer = synchronized { + val key = new CacheKey(topicPartition, kafkaParams) + val existingInternalConsumer = cache.get(key) - synchronized { - val removedConsumer = cache.remove(key) - if (removedConsumer != null) { - removedConsumer.close() + lazy val newInternalConsumer = new InternalKafkaConsumer(topicPartition, kafkaParams) + + if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) { + // If this is reattempt at running the task, then invalidate cached consumer if any and + // start with a new one. + if (existingInternalConsumer != null) { + // Consumer exists in cache. If its in use, mark it for closing later, or close it now. + if (existingInternalConsumer.inUse) { + existingInternalConsumer.markedForClose = true + } else { + existingInternalConsumer.close() + } } + cache.remove(key) // Invalidate the cache in any case + NonCachedKafkaDataConsumer(newInternalConsumer) + + } else if (!useCache) { + // If planner asks to not reuse consumers, then do not use it, return a new consumer + NonCachedKafkaDataConsumer(newInternalConsumer) + + } else if (existingInternalConsumer == null) { + // If consumer is not already cached, then put a new in the cache and return it + cache.put(key, newInternalConsumer) + newInternalConsumer.inUse = true + CachedKafkaDataConsumer(newInternalConsumer) + + } else if (existingInternalConsumer.inUse) { + // If consumer is already cached but is currently in use, then return a new consumer + NonCachedKafkaDataConsumer(newInternalConsumer) + + } else { + // If consumer is already cached and is currently not in use, then return that consumer + existingInternalConsumer.inUse = true + CachedKafkaDataConsumer(existingInternalConsumer) } } - /** - * Get a cached consumer for groupId, assigned to topic and partition. - * If matching consumer doesn't already exist, will be created using kafkaParams. - */ - def getOrCreate( - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = synchronized { - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - val topicPartition = new TopicPartition(topic, partition) - val key = CacheKey(groupId, topicPartition) - - // If this is reattempt at running the task, then invalidate cache and start with - // a new consumer - if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) { - removeKafkaConsumer(topic, partition, kafkaParams) - val consumer = new CachedKafkaConsumer(topicPartition, kafkaParams) - consumer.inuse = true - cache.put(key, consumer) - consumer - } else { - if (!cache.containsKey(key)) { - cache.put(key, new CachedKafkaConsumer(topicPartition, kafkaParams)) + private def release(intConsumer: InternalKafkaConsumer): Unit = { + synchronized { + + // Clear the consumer from the cache if this is indeed the consumer present in the cache + val key = new CacheKey(intConsumer.topicPartition, intConsumer.kafkaParams) + val cachedIntConsumer = cache.get(key) + if (intConsumer.eq(cachedIntConsumer)) { + // The released consumer is the same object as the cached one. + if (intConsumer.markedForClose) { + intConsumer.close() + cache.remove(key) + } else { + intConsumer.inUse = false + } + } else { + // The released consumer is either not the same one as in the cache, or not in the cache + // at all. This may happen if the cache was invalidate while this consumer was being used. + // Just close this consumer. + intConsumer.close() + logInfo(s"Released a supposedly cached consumer that was not found in the cache") } - val consumer = cache.get(key) - consumer.inuse = true - consumer } } +} - /** Create an [[CachedKafkaConsumer]] but don't put it into cache. */ - def createUncached( - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = { - new CachedKafkaConsumer(new TopicPartition(topic, partition), kafkaParams) - } +private[kafka010] object InternalKafkaConsumer extends Logging { + + private val UNKNOWN_OFFSET = -2L private def reportDataLoss0( failOnDataLoss: Boolean, diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index 8a5f3a249b11c..2ed49ba3f5495 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -321,17 +321,8 @@ private[kafka010] case class KafkaMicroBatchDataReader( failOnDataLoss: Boolean, reuseKafkaConsumer: Boolean) extends DataReader[UnsafeRow] with Logging { - private val consumer = { - if (!reuseKafkaConsumer) { - // If we can't reuse CachedKafkaConsumers, creating a new CachedKafkaConsumer. We - // uses `assign` here, hence we don't need to worry about the "group.id" conflicts. - CachedKafkaConsumer.createUncached( - offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) - } else { - CachedKafkaConsumer.getOrCreate( - offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) - } - } + private val consumer = KafkaDataConsumer.acquire( + offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) private val rangeToRead = resolveRange(offsetRange) private val converter = new KafkaRecordToUnsafeRowConverter @@ -360,14 +351,7 @@ private[kafka010] case class KafkaMicroBatchDataReader( } override def close(): Unit = { - if (!reuseKafkaConsumer) { - // Don't forget to close non-reuse KafkaConsumers. You may take down your cluster! - consumer.close() - } else { - // Indicate that we're no longer using this consumer - CachedKafkaConsumer.releaseKafkaConsumer( - offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) - } + consumer.release() } private def resolveRange(range: KafkaOffsetRange): KafkaOffsetRange = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala index 66b3409c0cd04..498e344ea39f4 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -52,7 +52,7 @@ private[kafka010] case class KafkaSourceRDDPartition( * An RDD that reads data from Kafka based on offset ranges across multiple partitions. * Additionally, it allows preferred locations to be set for each topic + partition, so that * the [[KafkaSource]] can ensure the same executor always reads the same topic + partition - * and cached KafkaConsumers (see [[CachedKafkaConsumer]] can be used read data efficiently. + * and cached KafkaConsumers (see [[KafkaDataConsumer]] can be used read data efficiently. * * @param sc the [[SparkContext]] * @param executorKafkaParams Kafka configuration for creating KafkaConsumer on the executors @@ -126,14 +126,9 @@ private[kafka010] class KafkaSourceRDD( val sourcePartition = thePart.asInstanceOf[KafkaSourceRDDPartition] val topic = sourcePartition.offsetRange.topic val kafkaPartition = sourcePartition.offsetRange.partition - val consumer = - if (!reuseKafkaConsumer) { - // If we can't reuse CachedKafkaConsumers, creating a new CachedKafkaConsumer. As here we - // uses `assign`, we don't need to worry about the "group.id" conflicts. - CachedKafkaConsumer.createUncached(topic, kafkaPartition, executorKafkaParams) - } else { - CachedKafkaConsumer.getOrCreate(topic, kafkaPartition, executorKafkaParams) - } + val consumer = KafkaDataConsumer.acquire( + sourcePartition.offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) + val range = resolveRange(consumer, sourcePartition.offsetRange) assert( range.fromOffset <= range.untilOffset, @@ -167,13 +162,7 @@ private[kafka010] class KafkaSourceRDD( } override protected def close(): Unit = { - if (!reuseKafkaConsumer) { - // Don't forget to close non-reuse KafkaConsumers. You may take down your cluster! - consumer.close() - } else { - // Indicate that we're no longer using this consumer - CachedKafkaConsumer.releaseKafkaConsumer(topic, kafkaPartition, executorKafkaParams) - } + consumer.release() } } // Release consumer, either by removing it or indicating we're no longer using it @@ -184,7 +173,7 @@ private[kafka010] class KafkaSourceRDD( } } - private def resolveRange(consumer: CachedKafkaConsumer, range: KafkaSourceRDDOffsetRange) = { + private def resolveRange(consumer: KafkaDataConsumer, range: KafkaSourceRDDOffsetRange) = { if (range.fromOffset < 0 || range.untilOffset < 0) { // Late bind the offset range val availableOffsetRange = consumer.getAvailableOffsetRange() diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala deleted file mode 100644 index 7aa7dd096c07b..0000000000000 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import org.scalatest.PrivateMethodTester - -import org.apache.spark.sql.test.SharedSQLContext - -class CachedKafkaConsumerSuite extends SharedSQLContext with PrivateMethodTester { - - test("SPARK-19886: Report error cause correctly in reportDataLoss") { - val cause = new Exception("D'oh!") - val reportDataLoss = PrivateMethod[Unit]('reportDataLoss0) - val e = intercept[IllegalStateException] { - CachedKafkaConsumer.invokePrivate(reportDataLoss(true, "message", cause)) - } - assert(e.getCause === cause) - } -} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala new file mode 100644 index 0000000000000..0d0fb9c3ab5af --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.concurrent.{Executors, TimeUnit} + +import scala.collection.JavaConverters._ +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration.Duration +import scala.util.Random + +import org.apache.kafka.clients.consumer.ConsumerConfig +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.serialization.ByteArrayDeserializer +import org.scalatest.PrivateMethodTester + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.ThreadUtils + +class KafkaDataConsumerSuite extends SharedSQLContext with PrivateMethodTester { + + protected var testUtils: KafkaTestUtils = _ + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils(Map[String, Object]()) + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + super.afterAll() + } + + test("SPARK-19886: Report error cause correctly in reportDataLoss") { + val cause = new Exception("D'oh!") + val reportDataLoss = PrivateMethod[Unit]('reportDataLoss0) + val e = intercept[IllegalStateException] { + InternalKafkaConsumer.invokePrivate(reportDataLoss(true, "message", cause)) + } + assert(e.getCause === cause) + } + + test("SPARK-23623: concurrent use of KafkaDataConsumer") { + val topic = "topic" + Random.nextInt() + val data = (1 to 1000).map(_.toString) + testUtils.createTopic(topic, 1) + testUtils.sendMessages(topic, data.toArray) + val topicPartition = new TopicPartition(topic, 0) + + import ConsumerConfig._ + val kafkaParams = Map[String, Object]( + GROUP_ID_CONFIG -> "groupId", + BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress, + KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + AUTO_OFFSET_RESET_CONFIG -> "earliest", + ENABLE_AUTO_COMMIT_CONFIG -> "false" + ) + + val numThreads = 100 + val numConsumerUsages = 500 + + @volatile var error: Throwable = null + + def consume(i: Int): Unit = { + val useCache = Random.nextBoolean + val taskContext = if (Random.nextBoolean) { + new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null) + } else { + null + } + TaskContext.setTaskContext(taskContext) + val consumer = KafkaDataConsumer.acquire( + topicPartition, kafkaParams.asJava, useCache) + try { + val range = consumer.getAvailableOffsetRange() + val rcvd = range.earliest until range.latest map { offset => + val bytes = consumer.get(offset, Long.MaxValue, 10000, failOnDataLoss = false).value() + new String(bytes) + } + assert(rcvd == data) + } catch { + case e: Throwable => + error = e + throw e + } finally { + consumer.release() + } + } + + val threadpool = Executors.newFixedThreadPool(numThreads) + try { + val futures = (1 to numConsumerUsages).map { i => + threadpool.submit(new Runnable { + override def run(): Unit = { consume(i) } + }) + } + futures.foreach(_.get(1, TimeUnit.MINUTES)) + assert(error == null) + } finally { + threadpool.shutdown() + } + } +} From 8a72734f33f6a0abbd3207b0d661633c8b25d9ad Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 16 Mar 2018 11:42:57 -0700 Subject: [PATCH 0486/2461] [SPARK-15009][PYTHON][ML] Construct a CountVectorizerModel from a vocabulary list ## What changes were proposed in this pull request? Added a class method to construct CountVectorizerModel from a list of vocabulary strings, equivalent to the Scala version. Introduced a common param base class `_CountVectorizerParams` to allow the Python model to also own the parameters. This now matches the Scala class hierarchy. ## How was this patch tested? Added to CountVectorizer doctests to do a transform on a model constructed from vocab, and unit test to verify params and vocab are constructed correctly. Author: Bryan Cutler Closes #16770 from BryanCutler/pyspark-CountVectorizerModel-vocab_ctor-SPARK-15009. --- python/pyspark/ml/feature.py | 168 +++++++++++++++++++++++------------ python/pyspark/ml/tests.py | 32 ++++++- 2 files changed, 142 insertions(+), 58 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index f2e357f0bede5..a1ceb7f02da8b 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -19,12 +19,12 @@ if sys.version > '3': basestring = str -from pyspark import since, keyword_only +from pyspark import since, keyword_only, SparkContext from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.linalg import _convert_to_vector from pyspark.ml.param.shared import * from pyspark.ml.util import JavaMLReadable, JavaMLWritable -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaTransformer, _jvm from pyspark.ml.common import inherit_doc __all__ = ['Binarizer', @@ -403,8 +403,69 @@ def getSplits(self): return self.getOrDefault(self.splits) +class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol): + """ + Params for :py:attr:`CountVectorizer` and :py:attr:`CountVectorizerModel`. + """ + + minTF = Param( + Params._dummy(), "minTF", "Filter to ignore rare words in" + + " a document. For each document, terms with frequency/count less than the given" + + " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + + " times the term must appear in the document); if this is a double in [0,1), then this " + + "specifies a fraction (out of the document's token count). Note that the parameter is " + + "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0", + typeConverter=TypeConverters.toFloat) + minDF = Param( + Params._dummy(), "minDF", "Specifies the minimum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + + " Default 1.0", typeConverter=TypeConverters.toFloat) + vocabSize = Param( + Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.", + typeConverter=TypeConverters.toInt) + binary = Param( + Params._dummy(), "binary", "Binary toggle to control the output vector values." + + " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" + + " for discrete probabilistic models that model binary events rather than integer counts." + + " Default False", typeConverter=TypeConverters.toBoolean) + + def __init__(self, *args): + super(_CountVectorizerParams, self).__init__(*args) + self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) + + @since("1.6.0") + def getMinTF(self): + """ + Gets the value of minTF or its default value. + """ + return self.getOrDefault(self.minTF) + + @since("1.6.0") + def getMinDF(self): + """ + Gets the value of minDF or its default value. + """ + return self.getOrDefault(self.minDF) + + @since("1.6.0") + def getVocabSize(self): + """ + Gets the value of vocabSize or its default value. + """ + return self.getOrDefault(self.vocabSize) + + @since("2.0.0") + def getBinary(self): + """ + Gets the value of binary or its default value. + """ + return self.getOrDefault(self.binary) + + @inherit_doc -class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): +class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, JavaMLWritable): """ Extracts a vocabulary from document collections and generates a :py:attr:`CountVectorizerModel`. @@ -437,33 +498,20 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, >>> loadedModel = CountVectorizerModel.load(modelPath) >>> loadedModel.vocabulary == model.vocabulary True + >>> fromVocabModel = CountVectorizerModel.from_vocabulary(["a", "b", "c"], + ... inputCol="raw", outputCol="vectors") + >>> fromVocabModel.transform(df).show(truncate=False) + +-----+---------------+-------------------------+ + |label|raw |vectors | + +-----+---------------+-------------------------+ + |0 |[a, b, c] |(3,[0,1,2],[1.0,1.0,1.0])| + |1 |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])| + +-----+---------------+-------------------------+ + ... .. versionadded:: 1.6.0 """ - minTF = Param( - Params._dummy(), "minTF", "Filter to ignore rare words in" + - " a document. For each document, terms with frequency/count less than the given" + - " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + - " times the term must appear in the document); if this is a double in [0,1), then this " + - "specifies a fraction (out of the document's token count). Note that the parameter is " + - "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0", - typeConverter=TypeConverters.toFloat) - minDF = Param( - Params._dummy(), "minDF", "Specifies the minimum number of" + - " different documents a term must appear in to be included in the vocabulary." + - " If this is an integer >= 1, this specifies the number of documents the term must" + - " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + - " Default 1.0", typeConverter=TypeConverters.toFloat) - vocabSize = Param( - Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.", - typeConverter=TypeConverters.toInt) - binary = Param( - Params._dummy(), "binary", "Binary toggle to control the output vector values." + - " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" + - " for discrete probabilistic models that model binary events rather than integer counts." + - " Default False", typeConverter=TypeConverters.toBoolean) - @keyword_only def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, outputCol=None): @@ -474,7 +522,6 @@ def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputC super(CountVectorizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", self.uid) - self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -497,13 +544,6 @@ def setMinTF(self, value): """ return self._set(minTF=value) - @since("1.6.0") - def getMinTF(self): - """ - Gets the value of minTF or its default value. - """ - return self.getOrDefault(self.minTF) - @since("1.6.0") def setMinDF(self, value): """ @@ -511,13 +551,6 @@ def setMinDF(self, value): """ return self._set(minDF=value) - @since("1.6.0") - def getMinDF(self): - """ - Gets the value of minDF or its default value. - """ - return self.getOrDefault(self.minDF) - @since("1.6.0") def setVocabSize(self, value): """ @@ -525,13 +558,6 @@ def setVocabSize(self, value): """ return self._set(vocabSize=value) - @since("1.6.0") - def getVocabSize(self): - """ - Gets the value of vocabSize or its default value. - """ - return self.getOrDefault(self.vocabSize) - @since("2.0.0") def setBinary(self, value): """ @@ -539,24 +565,40 @@ def setBinary(self, value): """ return self._set(binary=value) - @since("2.0.0") - def getBinary(self): - """ - Gets the value of binary or its default value. - """ - return self.getOrDefault(self.binary) - def _create_model(self, java_model): return CountVectorizerModel(java_model) -class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable): +@inherit_doc +class CountVectorizerModel(JavaModel, _CountVectorizerParams, JavaMLReadable, JavaMLWritable): """ Model fitted by :py:class:`CountVectorizer`. .. versionadded:: 1.6.0 """ + @classmethod + @since("2.4.0") + def from_vocabulary(cls, vocabulary, inputCol, outputCol=None, minTF=None, binary=None): + """ + Construct the model directly from a vocabulary list of strings, + requires an active SparkContext. + """ + sc = SparkContext._active_spark_context + java_class = sc._gateway.jvm.java.lang.String + jvocab = CountVectorizerModel._new_java_array(vocabulary, java_class) + model = CountVectorizerModel._create_from_java_class( + "org.apache.spark.ml.feature.CountVectorizerModel", jvocab) + model.setInputCol(inputCol) + if outputCol is not None: + model.setOutputCol(outputCol) + if minTF is not None: + model.setMinTF(minTF) + if binary is not None: + model.setBinary(binary) + model._set(vocabSize=len(vocabulary)) + return model + @property @since("1.6.0") def vocabulary(self): @@ -565,6 +607,20 @@ def vocabulary(self): """ return self._call_java("vocabulary") + @since("2.4.0") + def setMinTF(self, value): + """ + Sets the value of :py:attr:`minTF`. + """ + return self._set(minTF=value) + + @since("2.4.0") + def setBinary(self, value): + """ + Sets the value of :py:attr:`binary`. + """ + return self._set(binary=value) + @inherit_doc class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6dee6938d8916..fd45fd00b270b 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -679,6 +679,34 @@ def test_count_vectorizer_with_binary(self): feature, expected = r self.assertEqual(feature, expected) + def test_count_vectorizer_from_vocab(self): + model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words", + outputCol="features", minTF=2) + self.assertEqual(model.vocabulary, ["a", "b", "c"]) + self.assertEqual(model.getMinTF(), 2) + + dataset = self.spark.createDataFrame([ + (0, "a a a b b c".split(' '), SparseVector(3, {0: 3.0, 1: 2.0}),), + (1, "a a".split(' '), SparseVector(3, {0: 2.0}),), + (2, "a b".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) + + transformed_list = model.transform(dataset).select("features", "expected").collect() + + for r in transformed_list: + feature, expected = r + self.assertEqual(feature, expected) + + # Test an empty vocabulary + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"): + CountVectorizerModel.from_vocabulary([], inputCol="words") + + # Test model with default settings can transform + model_default = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words") + transformed_list = model_default.transform(dataset)\ + .select(model_default.getOrDefault(model_default.outputCol)).collect() + self.assertEqual(len(transformed_list), 3) + def test_rformula_force_index_label(self): df = self.spark.createDataFrame([ (1.0, 1.0, "a"), @@ -2019,8 +2047,8 @@ def test_java_params(self): pyspark.ml.regression] for module in modules: for name, cls in inspect.getmembers(module, inspect.isclass): - if not name.endswith('Model') and issubclass(cls, JavaParams)\ - and not inspect.isabstract(cls): + if not name.endswith('Model') and not name.endswith('Params')\ + and issubclass(cls, JavaParams) and not inspect.isabstract(cls): # NOTE: disable check_params_exist until there is parity with Scala API ParamTests.check_params(self, cls(), check_params_exist=False) From 8a1efe3076f29259151f1fba2ff894487efb6c4e Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Fri, 16 Mar 2018 15:40:21 -0700 Subject: [PATCH 0487/2461] [SPARK-23683][SQL] FileCommitProtocol.instantiate() hardening ## What changes were proposed in this pull request? With SPARK-20236, `FileCommitProtocol.instantiate()` looks for a three argument constructor, passing in the `dynamicPartitionOverwrite` parameter. If there is no such constructor, it falls back to the classic two-arg one. When `InsertIntoHadoopFsRelationCommand` passes down that `dynamicPartitionOverwrite` flag `to FileCommitProtocol.instantiate(`), it assumes that the instantiated protocol supports the specific requirements of dynamic partition overwrite. It does not notice when this does not hold, and so the output generated may be incorrect. This patch changes `FileCommitProtocol.instantiate()` so when `dynamicPartitionOverwrite == true`, it requires the protocol implementation to have a 3-arg constructor. Classic two arg constructors are supported when it is false. Also it adds some debug level logging for anyone trying to understand what's going on. ## How was this patch tested? Unit tests verify that * classes with only 2-arg constructor cannot be used with dynamic overwrite * classes with only 2-arg constructor can be used without dynamic overwrite * classes with 3 arg constructors can be used with both. * the fallback to any two arg ctor takes place after the attempt to load the 3-arg ctor, * passing in invalid class types fail as expected (regression tests on expected behavior) Author: Steve Loughran Closes #20824 from steveloughran/stevel/SPARK-23683-protocol-instantiate. --- .../internal/io/FileCommitProtocol.scala | 11 +- ...FileCommitProtocolInstantiationSuite.scala | 148 ++++++++++++++++++ 2 files changed, 158 insertions(+), 1 deletion(-) create mode 100644 core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index 6d0059b6a0272..e6e9c9e328853 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -20,6 +20,7 @@ package org.apache.spark.internal.io import org.apache.hadoop.fs._ import org.apache.hadoop.mapreduce._ +import org.apache.spark.internal.Logging import org.apache.spark.util.Utils @@ -132,7 +133,7 @@ abstract class FileCommitProtocol { } -object FileCommitProtocol { +object FileCommitProtocol extends Logging { class TaskCommitMessage(val obj: Any) extends Serializable object EmptyTaskCommitMessage extends TaskCommitMessage(null) @@ -145,15 +146,23 @@ object FileCommitProtocol { jobId: String, outputPath: String, dynamicPartitionOverwrite: Boolean = false): FileCommitProtocol = { + + logDebug(s"Creating committer $className; job $jobId; output=$outputPath;" + + s" dynamic=$dynamicPartitionOverwrite") val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]] // First try the constructor with arguments (jobId: String, outputPath: String, // dynamicPartitionOverwrite: Boolean). // If that doesn't exist, try the one with (jobId: string, outputPath: String). try { val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean]) + logDebug("Using (String, String, Boolean) constructor") ctor.newInstance(jobId, outputPath, dynamicPartitionOverwrite.asInstanceOf[java.lang.Boolean]) } catch { case _: NoSuchMethodException => + logDebug("Falling back to (String, String) constructor") + require(!dynamicPartitionOverwrite, + "Dynamic Partition Overwrite is enabled but" + + s" the committer ${className} does not have the appropriate constructor") val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) ctor.newInstance(jobId, outputPath) } diff --git a/core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala b/core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala new file mode 100644 index 0000000000000..2bd32fc927e21 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import org.apache.spark.SparkFunSuite + +/** + * Unit tests for instantiation of FileCommitProtocol implementations. + */ +class FileCommitProtocolInstantiationSuite extends SparkFunSuite { + + test("Dynamic partitions require appropriate constructor") { + + // you cannot instantiate a two-arg client with dynamic partitions + // enabled. + val ex = intercept[IllegalArgumentException] { + instantiateClassic(true) + } + // check the contents of the message and rethrow if unexpected. + // this preserves the stack trace of the unexpected + // exception. + if (!ex.toString.contains("Dynamic Partition Overwrite")) { + fail(s"Wrong text in caught exception $ex", ex) + } + } + + test("Standard partitions work with classic constructor") { + instantiateClassic(false) + } + + test("Three arg constructors have priority") { + assert(3 == instantiateNew(false).argCount, + "Wrong constructor argument count") + } + + test("Three arg constructors have priority when dynamic") { + assert(3 == instantiateNew(true).argCount, + "Wrong constructor argument count") + } + + test("The protocol must be of the correct class") { + intercept[ClassCastException] { + FileCommitProtocol.instantiate( + classOf[Other].getCanonicalName, + "job", + "path", + false) + } + } + + test("If there is no matching constructor, class hierarchy is irrelevant") { + intercept[NoSuchMethodException] { + FileCommitProtocol.instantiate( + classOf[NoMatchingArgs].getCanonicalName, + "job", + "path", + false) + } + } + + /** + * Create a classic two-arg protocol instance. + * @param dynamic dyanmic partitioning mode + * @return the instance + */ + private def instantiateClassic(dynamic: Boolean): ClassicConstructorCommitProtocol = { + FileCommitProtocol.instantiate( + classOf[ClassicConstructorCommitProtocol].getCanonicalName, + "job", + "path", + dynamic).asInstanceOf[ClassicConstructorCommitProtocol] + } + + /** + * Create a three-arg protocol instance. + * @param dynamic dyanmic partitioning mode + * @return the instance + */ + private def instantiateNew( + dynamic: Boolean): FullConstructorCommitProtocol = { + FileCommitProtocol.instantiate( + classOf[FullConstructorCommitProtocol].getCanonicalName, + "job", + "path", + dynamic).asInstanceOf[FullConstructorCommitProtocol] + } + +} + +/** + * This protocol implementation does not have the new three-arg + * constructor. + */ +private class ClassicConstructorCommitProtocol(arg1: String, arg2: String) + extends HadoopMapReduceCommitProtocol(arg1, arg2) { +} + +/** + * This protocol implementation does have the new three-arg constructor + * alongside the original, and a 4 arg one for completeness. + * The final value of the real constructor is the number of arguments + * used in the 2- and 3- constructor, for test assertions. + */ +private class FullConstructorCommitProtocol( + arg1: String, + arg2: String, + b: Boolean, + val argCount: Int) + extends HadoopMapReduceCommitProtocol(arg1, arg2, b) { + + def this(arg1: String, arg2: String) = { + this(arg1, arg2, false, 2) + } + + def this(arg1: String, arg2: String, b: Boolean) = { + this(arg1, arg2, false, 3) + } +} + +/** + * This has the 2-arity constructor, but isn't the right class. + */ +private class Other(arg1: String, arg2: String) { + +} + +/** + * This has no matching arguments as well as being the wrong class. + */ +private class NoMatchingArgs() { + +} + From 61487b308b0169e3108c2ad31674a0c80b8ac5f3 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 18 Mar 2018 20:24:14 +0900 Subject: [PATCH 0488/2461] [SPARK-23706][PYTHON] spark.conf.get(value, default=None) should produce None in PySpark ## What changes were proposed in this pull request? Scala: ``` scala> spark.conf.get("hey", null) res1: String = null ``` ``` scala> spark.conf.get("spark.sql.sources.partitionOverwriteMode", null) res2: String = null ``` Python: **Before** ``` >>> spark.conf.get("hey", None) ... py4j.protocol.Py4JJavaError: An error occurred while calling o30.get. : java.util.NoSuchElementException: hey ... ``` ``` >>> spark.conf.get("spark.sql.sources.partitionOverwriteMode", None) u'STATIC' ``` **After** ``` >>> spark.conf.get("hey", None) is None True ``` ``` >>> spark.conf.get("spark.sql.sources.partitionOverwriteMode", None) is None True ``` *Note that this PR preserves the case below: ``` >>> spark.conf.get("spark.sql.sources.partitionOverwriteMode") u'STATIC' ``` ## How was this patch tested? Manually tested and unit tests were added. Author: hyukjinkwon Closes #20841 from HyukjinKwon/spark-conf-get. --- python/pyspark/sql/conf.py | 9 +++++---- python/pyspark/sql/context.py | 8 ++++---- python/pyspark/sql/tests.py | 11 +++++++++++ 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index d929834aeeaa5..b82224b6194ed 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -17,7 +17,7 @@ import sys -from pyspark import since +from pyspark import since, _NoValue from pyspark.rdd import ignore_unicode_prefix @@ -39,15 +39,16 @@ def set(self, key, value): @ignore_unicode_prefix @since(2.0) - def get(self, key, default=None): + def get(self, key, default=_NoValue): """Returns the value of Spark runtime configuration property for the given key, assuming it is set. """ self._checkType(key, "key") - if default is None: + if default is _NoValue: return self._jconf.get(key) else: - self._checkType(default, "default") + if default is not None: + self._checkType(default, "default") return self._jconf.get(key, default) @ignore_unicode_prefix diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 6cb90399dd616..e9ec7ba866761 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -22,7 +22,7 @@ if sys.version >= '3': basestring = unicode = str -from pyspark import since +from pyspark import since, _NoValue from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.session import _monkey_patch_RDD, SparkSession from pyspark.sql.dataframe import DataFrame @@ -124,11 +124,11 @@ def setConf(self, key, value): @ignore_unicode_prefix @since(1.3) - def getConf(self, key, defaultValue=None): + def getConf(self, key, defaultValue=_NoValue): """Returns the value of Spark SQL configuration property for the given key. - If the key is not set and defaultValue is not None, return - defaultValue. If the key is not set and defaultValue is None, return + If the key is not set and defaultValue is set, return + defaultValue. If the key is not set and defaultValue is not set, return the system default value. >>> sqlContext.getConf("spark.sql.shuffle.partitions") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 480815d27333f..a0d547ad620e5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2504,6 +2504,17 @@ def test_conf(self): spark.conf.unset("bogo") self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia") + self.assertEqual(spark.conf.get("hyukjin", None), None) + + # This returns 'STATIC' because it's the default value of + # 'spark.sql.sources.partitionOverwriteMode', and `defaultValue` in + # `spark.conf.get` is unset. + self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode"), "STATIC") + + # This returns None because 'spark.sql.sources.partitionOverwriteMode' is unset, but + # `defaultValue` in `spark.conf.get` is set to None. + self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode", None), None) + def test_current_database(self): spark = self.spark spark.catalog._reset() From 745c8c0901ac522ba92c1356ca74bd0dd7701496 Mon Sep 17 00:00:00 2001 From: zhoukang Date: Mon, 19 Mar 2018 13:31:21 +0800 Subject: [PATCH 0489/2461] [SPARK-23708][CORE] Correct comment for function addShutDownHook in ShutdownHookManager ## What changes were proposed in this pull request? Minor modification.Comment below is not right. ``` /** * Adds a shutdown hook with the given priority. Hooks with lower priority values run * first. * * param hook The code to run during shutdown. * return A handle that can be used to unregister the shutdown hook. */ def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = { shutdownHooks.add(priority, hook) } ``` ## How was this patch tested? UT Author: zhoukang Closes #20845 from caneGuy/zhoukang/fix-shutdowncomment. --- .../main/scala/org/apache/spark/util/ShutdownHookManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index 4001fac3c3d5a..b702838fa257f 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -143,7 +143,7 @@ private[spark] object ShutdownHookManager extends Logging { } /** - * Adds a shutdown hook with the given priority. Hooks with lower priority values run + * Adds a shutdown hook with the given priority. Hooks with higher priority values run * first. * * @param hook The code to run during shutdown. From 4de638c1976dea74761bbe5c30da808178ee885d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 19 Mar 2018 09:41:43 +0100 Subject: [PATCH 0490/2461] [SPARK-23599][SQL] Add a UUID generator from Pseudo-Random Numbers ## What changes were proposed in this pull request? This patch adds a UUID generator from Pseudo-Random Numbers. We can use it later to have deterministic `UUID()` expression. ## How was this patch tested? Added unit tests. Author: Liang-Chi Hsieh Closes #20817 from viirya/SPARK-23599. --- .../catalyst/util/RandomUUIDGenerator.scala | 43 ++++++++++++++ .../util/RandomUUIDGeneratorSuite.scala | 57 +++++++++++++++++++ 2 files changed, 100 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGenerator.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGeneratorSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGenerator.scala new file mode 100644 index 0000000000000..4fe07a071c1ca --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGenerator.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import java.util.UUID + +import org.apache.commons.math3.random.MersenneTwister + +import org.apache.spark.unsafe.types.UTF8String + +/** + * This class is used to generate a UUID from Pseudo-Random Numbers. + * + * For the algorithm, see RFC 4122: A Universally Unique IDentifier (UUID) URN Namespace, + * section 4.4 "Algorithms for Creating a UUID from Truly Random or Pseudo-Random Numbers". + */ +case class RandomUUIDGenerator(randomSeed: Long) { + private val random = new MersenneTwister(randomSeed) + + def getNextUUID(): UUID = { + val mostSigBits = (random.nextLong() & 0xFFFFFFFFFFFF0FFFL) | 0x0000000000004000L + val leastSigBits = (random.nextLong() | 0x8000000000000000L) & 0xBFFFFFFFFFFFFFFFL + + new UUID(mostSigBits, leastSigBits) + } + + def getNextUUIDUTF8String(): UTF8String = UTF8String.fromString(getNextUUID().toString()) +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGeneratorSuite.scala new file mode 100644 index 0000000000000..b75739e5a3a65 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGeneratorSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.util.Random + +import org.apache.spark.SparkFunSuite + +class RandomUUIDGeneratorSuite extends SparkFunSuite { + test("RandomUUIDGenerator should generate version 4, variant 2 UUIDs") { + val generator = RandomUUIDGenerator(new Random().nextLong()) + for (_ <- 0 to 100) { + val uuid = generator.getNextUUID() + assert(uuid.version() == 4) + assert(uuid.variant() == 2) + } + } + + test("UUID from RandomUUIDGenerator should be deterministic") { + val r1 = new Random(100) + val generator1 = RandomUUIDGenerator(r1.nextLong()) + val r2 = new Random(100) + val generator2 = RandomUUIDGenerator(r2.nextLong()) + val r3 = new Random(101) + val generator3 = RandomUUIDGenerator(r3.nextLong()) + + for (_ <- 0 to 100) { + val uuid1 = generator1.getNextUUID() + val uuid2 = generator2.getNextUUID() + val uuid3 = generator3.getNextUUID() + assert(uuid1 == uuid2) + assert(uuid1 != uuid3) + } + } + + test("Get UTF8String UUID") { + val generator = RandomUUIDGenerator(new Random().nextLong()) + val utf8StringUUID = generator.getNextUUIDUTF8String() + val uuid = java.util.UUID.fromString(utf8StringUUID.toString) + assert(uuid.version() == 4 && uuid.variant() == 2 && utf8StringUUID.toString == uuid.toString) + } +} From f15906da153f139b698e192ec6f82f078f896f1e Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Mon, 19 Mar 2018 11:29:56 -0700 Subject: [PATCH 0491/2461] [SPARK-22839][K8S] Remove the use of init-container for downloading remote dependencies ## What changes were proposed in this pull request? Removal of the init-container for downloading remote dependencies. Built off of the work done by vanzin in an attempt to refactor driver/executor configuration elaborated in [this](https://issues.apache.org/jira/browse/SPARK-22839) ticket. ## How was this patch tested? This patch was tested with unit and integration tests. Author: Ilan Filonenko Closes #20669 from ifilonenko/remove-init-container. --- bin/docker-image-tool.sh | 9 +- .../org/apache/spark/deploy/SparkSubmit.scala | 2 - docs/running-on-kubernetes.md | 71 +------- .../spark/examples/SparkRemoteFileTest.scala | 48 ++++++ .../org/apache/spark/deploy/k8s/Config.scala | 73 +------- .../apache/spark/deploy/k8s/Constants.scala | 21 +-- .../deploy/k8s/InitContainerBootstrap.scala | 120 ------------- .../spark/deploy/k8s/KubernetesUtils.scala | 63 +------ .../k8s/PodWithDetachedInitContainer.scala | 31 ---- .../deploy/k8s/SparkPodInitContainer.scala | 116 ------------- .../k8s/submit/DriverConfigOrchestrator.scala | 45 +---- .../submit/KubernetesClientApplication.scala | 84 +++++---- .../steps/BasicDriverConfigurationStep.scala | 32 ++-- .../steps/DependencyResolutionStep.scala | 18 +- .../DriverInitContainerBootstrapStep.scala | 95 ----------- .../DriverKubernetesCredentialsStep.scala | 2 +- .../BasicInitContainerConfigurationStep.scala | 67 -------- .../InitContainerConfigOrchestrator.scala | 79 --------- .../InitContainerConfigurationStep.scala | 25 --- .../InitContainerMountSecretsStep.scala | 36 ---- .../initcontainer/InitContainerSpec.scala | 37 ---- .../cluster/k8s/ExecutorPodFactory.scala | 43 +---- .../k8s/KubernetesClusterManager.scala | 65 +------ .../k8s/SparkPodInitContainerSuite.scala | 86 ---------- .../spark/deploy/k8s/submit/ClientSuite.scala | 82 ++++----- .../DriverConfigOrchestratorSuite.scala | 41 +---- .../BasicDriverConfigurationStepSuite.scala | 8 +- .../steps/DependencyResolutionStepSuite.scala | 32 ++-- ...riverInitContainerBootstrapStepSuite.scala | 160 ------------------ ...cInitContainerConfigurationStepSuite.scala | 95 ----------- ...InitContainerConfigOrchestratorSuite.scala | 80 --------- .../InitContainerMountSecretsStepSuite.scala | 52 ------ .../cluster/k8s/ExecutorPodFactorySuite.scala | 67 +------- .../src/main/dockerfiles/spark/Dockerfile | 1 - .../src/main/dockerfiles/spark/entrypoint.sh | 20 +-- 35 files changed, 241 insertions(+), 1665 deletions(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index 0d0f564bb8b9b..f090240065bf1 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -64,9 +64,11 @@ function build { error "Cannot find docker image. This script must be run from a runnable distribution of Apache Spark." fi + local DOCKERFILE=${DOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} + docker build "${BUILD_ARGS[@]}" \ -t $(image_ref spark) \ - -f "$IMG_PATH/spark/Dockerfile" . + -f "$DOCKERFILE" . } function push { @@ -84,6 +86,7 @@ Commands: push Push a pre-built image to a registry. Requires a repository address to be provided. Options: + -f file Dockerfile to build. By default builds the Dockerfile shipped with Spark. -r repo Repository address. -t tag Tag to apply to the built image, or to identify the image to be pushed. -m Use minikube's Docker daemon. @@ -113,10 +116,12 @@ fi REPO= TAG= -while getopts mr:t: option +DOCKERFILE= +while getopts f:mr:t: option do case "${option}" in + f) DOCKERFILE=${OPTARG};; r) REPO=${OPTARG};; t) TAG=${OPTARG};; m) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 1e381965c52ba..329bde08718fe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -320,8 +320,6 @@ object SparkSubmit extends CommandLineUtils with Logging { printErrorAndExit("Python applications are currently not supported for Kubernetes.") case (KUBERNETES, _) if args.isR => printErrorAndExit("R applications are currently not supported for Kubernetes.") - case (KUBERNETES, CLIENT) => - printErrorAndExit("Client mode is currently not supported for Kubernetes.") case (LOCAL, CLUSTER) => printErrorAndExit("Cluster deploy mode is not compatible with master \"local\"") case (_, CLUSTER) if isShell(args.primaryResource) => diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 3c7586e8544ba..975b28de47e20 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -126,29 +126,6 @@ Those dependencies can be added to the classpath by referencing them with `local dependencies in custom-built Docker images in `spark-submit`. Note that using application dependencies from the submission client's local file system is currently not yet supported. - -### Using Remote Dependencies -When there are application dependencies hosted in remote locations like HDFS or HTTP servers, the driver and executor pods -need a Kubernetes [init-container](https://kubernetes.io/docs/concepts/workloads/pods/init-containers/) for downloading -the dependencies so the driver and executor containers can use them locally. - -The init-container handles remote dependencies specified in `spark.jars` (or the `--jars` option of `spark-submit`) and -`spark.files` (or the `--files` option of `spark-submit`). It also handles remotely hosted main application resources, e.g., -the main application jar. The following shows an example of using remote dependencies with the `spark-submit` command: - -```bash -$ bin/spark-submit \ - --master k8s://https://: \ - --deploy-mode cluster \ - --name spark-pi \ - --class org.apache.spark.examples.SparkPi \ - --jars https://path/to/dependency1.jar,https://path/to/dependency2.jar - --files hdfs://host:port/path/to/file1,hdfs://host:port/path/to/file2 - --conf spark.executor.instances=5 \ - --conf spark.kubernetes.container.image= \ - https://path/to/examples.jar -``` - ## Secret Management Kubernetes [Secrets](https://kubernetes.io/docs/concepts/configuration/secret/) can be used to provide credentials for a Spark application to access secured services. To mount a user-specified secret into the driver container, users can use @@ -163,10 +140,6 @@ namespace as that of the driver and executor pods. For example, to mount a secre --conf spark.kubernetes.executor.secrets.spark-secret=/etc/secrets ``` -Note that if an init-container is used, any secret mounted into the driver container will also be mounted into the -init-container of the driver. Similarly, any secret mounted into an executor container will also be mounted into the -init-container of the executor. - ## Introspection and Debugging These are the different ways in which you can investigate a running/completed Spark application, monitor progress, and @@ -604,51 +577,12 @@ specific to Spark on Kubernetes. the Driver process. The user can specify multiple of these to set multiple environment variables. - - spark.kubernetes.mountDependencies.jarsDownloadDir - /var/spark-data/spark-jars - - Location to download jars to in the driver and executors. - This directory must be empty and will be mounted as an empty directory volume on the driver and executor pods. - - - - spark.kubernetes.mountDependencies.filesDownloadDir - /var/spark-data/spark-files - - Location to download jars to in the driver and executors. - This directory must be empty and will be mounted as an empty directory volume on the driver and executor pods. - - - - spark.kubernetes.mountDependencies.timeout - 300s - - Timeout in seconds before aborting the attempt to download and unpack dependencies from remote locations into - the driver and executor pods. - - - - spark.kubernetes.mountDependencies.maxSimultaneousDownloads - 5 - - Maximum number of remote dependencies to download simultaneously in a driver or executor pod. - - - - spark.kubernetes.initContainer.image - (value of spark.kubernetes.container.image) - - Custom container image for the init container of both driver and executors. - - spark.kubernetes.driver.secrets.[SecretName] (none) Add the Kubernetes Secret named SecretName to the driver pod on the path specified in the value. For example, - spark.kubernetes.driver.secrets.spark-secret=/etc/secrets. Note that if an init-container is used, - the secret will also be added to the init-container in the driver pod. + spark.kubernetes.driver.secrets.spark-secret=/etc/secrets. @@ -656,8 +590,7 @@ specific to Spark on Kubernetes. (none) Add the Kubernetes Secret named SecretName to the executor pod on the path specified in the value. For example, - spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. Note that if an init-container is used, - the secret will also be added to the init-container in the executor pod. + spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. \ No newline at end of file diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala b/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala new file mode 100644 index 0000000000000..64076f2deb706 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples + +import java.io.File + +import org.apache.spark.SparkFiles +import org.apache.spark.sql.SparkSession + +/** Usage: SparkRemoteFileTest [file] */ +object SparkRemoteFileTest { + def main(args: Array[String]) { + if (args.length < 1) { + System.err.println("Usage: SparkRemoteFileTest ") + System.exit(1) + } + val spark = SparkSession + .builder() + .appName("SparkRemoteFileTest") + .getOrCreate() + val sc = spark.sparkContext + val rdd = sc.parallelize(Seq(1)).map(_ => { + val localLocation = SparkFiles.get(args(0)) + println(s"${args(0)} is stored at: $localLocation") + new File(localLocation).isFile + }) + val truthCheck = rdd.collect().head + println(s"Mounting of ${args(0)} was $truthCheck") + spark.stop() + } +} +// scalastyle:on println diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 471196ac0e3f6..da34a7e06238a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -79,6 +79,12 @@ private[spark] object Config extends Logging { .stringConf .createOptional + val KUBERNETES_DRIVER_SUBMIT_CHECK = + ConfigBuilder("spark.kubernetes.submitInDriver") + .internal() + .booleanConf + .createOptional + val KUBERNETES_EXECUTOR_LIMIT_CORES = ConfigBuilder("spark.kubernetes.executor.limit.cores") .doc("Specify the hard cpu limit for each executor pod") @@ -135,73 +141,6 @@ private[spark] object Config extends Logging { .checkValue(interval => interval > 0, s"Logging interval must be a positive time value.") .createWithDefaultString("1s") - val JARS_DOWNLOAD_LOCATION = - ConfigBuilder("spark.kubernetes.mountDependencies.jarsDownloadDir") - .doc("Location to download jars to in the driver and executors. When using " + - "spark-submit, this directory must be empty and will be mounted as an empty directory " + - "volume on the driver and executor pod.") - .stringConf - .createWithDefault("/var/spark-data/spark-jars") - - val FILES_DOWNLOAD_LOCATION = - ConfigBuilder("spark.kubernetes.mountDependencies.filesDownloadDir") - .doc("Location to download files to in the driver and executors. When using " + - "spark-submit, this directory must be empty and will be mounted as an empty directory " + - "volume on the driver and executor pods.") - .stringConf - .createWithDefault("/var/spark-data/spark-files") - - val INIT_CONTAINER_IMAGE = - ConfigBuilder("spark.kubernetes.initContainer.image") - .doc("Image for the driver and executor's init-container for downloading dependencies.") - .fallbackConf(CONTAINER_IMAGE) - - val INIT_CONTAINER_MOUNT_TIMEOUT = - ConfigBuilder("spark.kubernetes.mountDependencies.timeout") - .doc("Timeout before aborting the attempt to download and unpack dependencies from remote " + - "locations into the driver and executor pods.") - .timeConf(TimeUnit.SECONDS) - .createWithDefault(300) - - val INIT_CONTAINER_MAX_THREAD_POOL_SIZE = - ConfigBuilder("spark.kubernetes.mountDependencies.maxSimultaneousDownloads") - .doc("Maximum number of remote dependencies to download simultaneously in a driver or " + - "executor pod.") - .intConf - .createWithDefault(5) - - val INIT_CONTAINER_REMOTE_JARS = - ConfigBuilder("spark.kubernetes.initContainer.remoteJars") - .doc("Comma-separated list of jar URIs to download in the init-container. This is " + - "calculated from spark.jars.") - .internal() - .stringConf - .createOptional - - val INIT_CONTAINER_REMOTE_FILES = - ConfigBuilder("spark.kubernetes.initContainer.remoteFiles") - .doc("Comma-separated list of file URIs to download in the init-container. This is " + - "calculated from spark.files.") - .internal() - .stringConf - .createOptional - - val INIT_CONTAINER_CONFIG_MAP_NAME = - ConfigBuilder("spark.kubernetes.initContainer.configMapName") - .doc("Name of the config map to use in the init-container that retrieves submitted files " + - "for the executor.") - .internal() - .stringConf - .createOptional - - val INIT_CONTAINER_CONFIG_MAP_KEY_CONF = - ConfigBuilder("spark.kubernetes.initContainer.configMapKey") - .doc("Key for the entry in the init container config map for submitted files that " + - "corresponds to the properties for this init-container.") - .internal() - .stringConf - .createOptional - val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = "spark.kubernetes.authenticate.submission" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 9411956996843..8da5f24044aad 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -63,22 +63,13 @@ private[spark] object Constants { val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH" val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_" val ENV_CLASSPATH = "SPARK_CLASSPATH" - val ENV_DRIVER_MAIN_CLASS = "SPARK_DRIVER_CLASS" - val ENV_DRIVER_ARGS = "SPARK_DRIVER_ARGS" - val ENV_DRIVER_JAVA_OPTS = "SPARK_DRIVER_JAVA_OPTS" val ENV_DRIVER_BIND_ADDRESS = "SPARK_DRIVER_BIND_ADDRESS" - val ENV_DRIVER_MEMORY = "SPARK_DRIVER_MEMORY" - val ENV_MOUNTED_FILES_DIR = "SPARK_MOUNTED_FILES_DIR" - - // Bootstrapping dependencies with the init-container - val INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME = "download-jars-volume" - val INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME = "download-files-volume" - val INIT_CONTAINER_PROPERTIES_FILE_VOLUME = "spark-init-properties" - val INIT_CONTAINER_PROPERTIES_FILE_DIR = "/etc/spark-init" - val INIT_CONTAINER_PROPERTIES_FILE_NAME = "spark-init.properties" - val INIT_CONTAINER_PROPERTIES_FILE_PATH = - s"$INIT_CONTAINER_PROPERTIES_FILE_DIR/$INIT_CONTAINER_PROPERTIES_FILE_NAME" - val INIT_CONTAINER_SECRET_VOLUME_NAME = "spark-init-secret" + val ENV_SPARK_CONF_DIR = "SPARK_CONF_DIR" + // Spark app configs for containers + val SPARK_CONF_VOLUME = "spark-conf-volume" + val SPARK_CONF_DIR_INTERNAL = "/opt/spark/conf" + val SPARK_CONF_FILE_NAME = "spark.properties" + val SPARK_CONF_PATH = s"$SPARK_CONF_DIR_INTERNAL/$SPARK_CONF_FILE_NAME" // Miscellaneous val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala deleted file mode 100644 index f6a57dfe00171..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, EmptyDirVolumeSource, EnvVarBuilder, PodBuilder, VolumeMount, VolumeMountBuilder} - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ - -/** - * Bootstraps an init-container for downloading remote dependencies. This is separated out from - * the init-container steps API because this component can be used to bootstrap init-containers - * for both the driver and executors. - */ -private[spark] class InitContainerBootstrap( - initContainerImage: String, - imagePullPolicy: String, - jarsDownloadPath: String, - filesDownloadPath: String, - configMapName: String, - configMapKey: String, - sparkRole: String, - sparkConf: SparkConf) { - - /** - * Bootstraps an init-container that downloads dependencies to be used by a main container. - */ - def bootstrapInitContainer( - original: PodWithDetachedInitContainer): PodWithDetachedInitContainer = { - val sharedVolumeMounts = Seq[VolumeMount]( - new VolumeMountBuilder() - .withName(INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME) - .withMountPath(jarsDownloadPath) - .build(), - new VolumeMountBuilder() - .withName(INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME) - .withMountPath(filesDownloadPath) - .build()) - - val customEnvVarKeyPrefix = sparkRole match { - case SPARK_POD_DRIVER_ROLE => KUBERNETES_DRIVER_ENV_KEY - case SPARK_POD_EXECUTOR_ROLE => "spark.executorEnv." - case _ => throw new SparkException(s"$sparkRole is not a valid Spark pod role") - } - val customEnvVars = sparkConf.getAllWithPrefix(customEnvVarKeyPrefix).toSeq.map { - case (key, value) => - new EnvVarBuilder() - .withName(key) - .withValue(value) - .build() - } - - val initContainer = new ContainerBuilder(original.initContainer) - .withName("spark-init") - .withImage(initContainerImage) - .withImagePullPolicy(imagePullPolicy) - .addAllToEnv(customEnvVars.asJava) - .addNewVolumeMount() - .withName(INIT_CONTAINER_PROPERTIES_FILE_VOLUME) - .withMountPath(INIT_CONTAINER_PROPERTIES_FILE_DIR) - .endVolumeMount() - .addToVolumeMounts(sharedVolumeMounts: _*) - .addToArgs("init") - .addToArgs(INIT_CONTAINER_PROPERTIES_FILE_PATH) - .build() - - val podWithBasicVolumes = new PodBuilder(original.pod) - .editSpec() - .addNewVolume() - .withName(INIT_CONTAINER_PROPERTIES_FILE_VOLUME) - .withNewConfigMap() - .withName(configMapName) - .addNewItem() - .withKey(configMapKey) - .withPath(INIT_CONTAINER_PROPERTIES_FILE_NAME) - .endItem() - .endConfigMap() - .endVolume() - .addNewVolume() - .withName(INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME) - .withEmptyDir(new EmptyDirVolumeSource()) - .endVolume() - .addNewVolume() - .withName(INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME) - .withEmptyDir(new EmptyDirVolumeSource()) - .endVolume() - .endSpec() - .build() - - val mainContainer = new ContainerBuilder(original.mainContainer) - .addToVolumeMounts(sharedVolumeMounts: _*) - .addNewEnv() - .withName(ENV_MOUNTED_FILES_DIR) - .withValue(filesDownloadPath) - .endEnv() - .build() - - PodWithDetachedInitContainer( - podWithBasicVolumes, - initContainer, - mainContainer) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 37331d8bbf9b7..5bc070147d3a8 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -16,10 +16,6 @@ */ package org.apache.spark.deploy.k8s -import java.io.File - -import io.fabric8.kubernetes.api.model.{Container, Pod, PodBuilder} - import org.apache.spark.SparkConf import org.apache.spark.util.Utils @@ -43,72 +39,23 @@ private[spark] object KubernetesUtils { opt1.foreach { _ => require(opt2.isEmpty, errMessage) } } - /** - * Append the given init-container to a pod's list of init-containers. - * - * @param originalPodSpec original specification of the pod - * @param initContainer the init-container to add to the pod - * @return the pod with the init-container added to the list of InitContainers - */ - def appendInitContainer(originalPodSpec: Pod, initContainer: Container): Pod = { - new PodBuilder(originalPodSpec) - .editOrNewSpec() - .addToInitContainers(initContainer) - .endSpec() - .build() - } - /** * For the given collection of file URIs, resolves them as follows: - * - File URIs with scheme file:// are resolved to the given download path. * - File URIs with scheme local:// resolve to just the path of the URI. * - Otherwise, the URIs are returned as-is. */ - def resolveFileUris( - fileUris: Iterable[String], - fileDownloadPath: String): Iterable[String] = { - fileUris.map { uri => - resolveFileUri(uri, fileDownloadPath, false) - } - } - - /** - * If any file uri has any scheme other than local:// it is mapped as if the file - * was downloaded to the file download path. Otherwise, it is mapped to the path - * part of the URI. - */ - def resolveFilePaths(fileUris: Iterable[String], fileDownloadPath: String): Iterable[String] = { + def resolveFileUrisAndPath(fileUris: Iterable[String]): Iterable[String] = { fileUris.map { uri => - resolveFileUri(uri, fileDownloadPath, true) - } - } - - /** - * Get from a given collection of file URIs the ones that represent remote files. - */ - def getOnlyRemoteFiles(uris: Iterable[String]): Iterable[String] = { - uris.filter { uri => - val scheme = Utils.resolveURI(uri).getScheme - scheme != "file" && scheme != "local" + resolveFileUri(uri) } } - private def resolveFileUri( - uri: String, - fileDownloadPath: String, - assumesDownloaded: Boolean): String = { + private def resolveFileUri(uri: String): String = { val fileUri = Utils.resolveURI(uri) val fileScheme = Option(fileUri.getScheme).getOrElse("file") fileScheme match { - case "local" => - fileUri.getPath - case _ => - if (assumesDownloaded || fileScheme == "file") { - val fileName = new File(fileUri.getPath).getName - s"$fileDownloadPath/$fileName" - } else { - uri - } + case "local" => fileUri.getPath + case _ => uri } } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala deleted file mode 100644 index 0b79f8b12e806..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s - -import io.fabric8.kubernetes.api.model.{Container, Pod} - -/** - * Represents a pod with a detached init-container (not yet added to the pod). - * - * @param pod the pod - * @param initContainer the init-container in the pod - * @param mainContainer the main container in the pod - */ -private[spark] case class PodWithDetachedInitContainer( - pod: Pod, - initContainer: Container, - mainContainer: Container) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala deleted file mode 100644 index c0f08786b76a1..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s - -import java.io.File -import java.util.concurrent.TimeUnit - -import scala.concurrent.{ExecutionContext, Future} - -import org.apache.spark.{SecurityManager => SparkSecurityManager, SparkConf} -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.internal.Logging -import org.apache.spark.util.{ThreadUtils, Utils} - -/** - * Process that fetches files from a resource staging server and/or arbitrary remote locations. - * - * The init-container can handle fetching files from any of those sources, but not all of the - * sources need to be specified. This allows for composing multiple instances of this container - * with different configurations for different download sources, or using the same container to - * download everything at once. - */ -private[spark] class SparkPodInitContainer( - sparkConf: SparkConf, - fileFetcher: FileFetcher) extends Logging { - - private val maxThreadPoolSize = sparkConf.get(INIT_CONTAINER_MAX_THREAD_POOL_SIZE) - private implicit val downloadExecutor = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("download-executor", maxThreadPoolSize)) - - private val jarsDownloadDir = new File(sparkConf.get(JARS_DOWNLOAD_LOCATION)) - private val filesDownloadDir = new File(sparkConf.get(FILES_DOWNLOAD_LOCATION)) - - private val remoteJars = sparkConf.get(INIT_CONTAINER_REMOTE_JARS) - private val remoteFiles = sparkConf.get(INIT_CONTAINER_REMOTE_FILES) - - private val downloadTimeoutMinutes = sparkConf.get(INIT_CONTAINER_MOUNT_TIMEOUT) - - def run(): Unit = { - logInfo(s"Downloading remote jars: $remoteJars") - downloadFiles( - remoteJars, - jarsDownloadDir, - s"Remote jars download directory specified at $jarsDownloadDir does not exist " + - "or is not a directory.") - - logInfo(s"Downloading remote files: $remoteFiles") - downloadFiles( - remoteFiles, - filesDownloadDir, - s"Remote files download directory specified at $filesDownloadDir does not exist " + - "or is not a directory.") - - downloadExecutor.shutdown() - downloadExecutor.awaitTermination(downloadTimeoutMinutes, TimeUnit.MINUTES) - } - - private def downloadFiles( - filesCommaSeparated: Option[String], - downloadDir: File, - errMessage: String): Unit = { - filesCommaSeparated.foreach { files => - require(downloadDir.isDirectory, errMessage) - Utils.stringToSeq(files).foreach { file => - Future[Unit] { - fileFetcher.fetchFile(file, downloadDir) - } - } - } - } -} - -private class FileFetcher(sparkConf: SparkConf, securityManager: SparkSecurityManager) { - - def fetchFile(uri: String, targetDir: File): Unit = { - Utils.fetchFile( - url = uri, - targetDir = targetDir, - conf = sparkConf, - securityMgr = securityManager, - hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf), - timestamp = System.currentTimeMillis(), - useCache = false) - } -} - -object SparkPodInitContainer extends Logging { - - def main(args: Array[String]): Unit = { - logInfo("Starting init-container to download Spark application dependencies.") - val sparkConf = new SparkConf(true) - if (args.nonEmpty) { - Utils.loadDefaultSparkProperties(sparkConf, args(0)) - } - - val securityManager = new SparkSecurityManager(sparkConf) - val fileFetcher = new FileFetcher(sparkConf, securityManager) - new SparkPodInitContainer(sparkConf, fileFetcher).run() - logInfo("Finished downloading application dependencies.") - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala index ae70904621184..b4d3f04a1bc32 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala @@ -16,16 +16,11 @@ */ package org.apache.spark.deploy.k8s.submit -import java.util.UUID - -import com.google.common.primitives.Longs - import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit.steps._ -import org.apache.spark.deploy.k8s.submit.steps.initcontainer.InitContainerConfigOrchestrator import org.apache.spark.launcher.SparkLauncher import org.apache.spark.util.SystemClock import org.apache.spark.util.Utils @@ -34,13 +29,11 @@ import org.apache.spark.util.Utils * Figures out and returns the complete ordered list of needed DriverConfigurationSteps to * configure the Spark driver pod. The returned steps will be applied one by one in the given * order to produce a final KubernetesDriverSpec that is used in KubernetesClientApplication - * to construct and create the driver pod. It uses the InitContainerConfigOrchestrator to - * configure the driver init-container if one is needed, i.e., when there are remote dependencies - * to localize. + * to construct and create the driver pod. */ private[spark] class DriverConfigOrchestrator( kubernetesAppId: String, - launchTime: Long, + kubernetesResourceNamePrefix: String, mainAppResource: Option[MainAppResource], appName: String, mainClass: String, @@ -50,15 +43,8 @@ private[spark] class DriverConfigOrchestrator( // The resource name prefix is derived from the Spark application name, making it easy to connect // the names of the Kubernetes resources from e.g. kubectl or the Kubernetes dashboard to the // application the user submitted. - private val kubernetesResourceNamePrefix = { - val uuid = UUID.nameUUIDFromBytes(Longs.toByteArray(launchTime)).toString.replaceAll("-", "") - s"$appName-$uuid".toLowerCase.replaceAll("\\.", "-") - } private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) - private val initContainerConfigMapName = s"$kubernetesResourceNamePrefix-init-config" - private val jarsDownloadPath = sparkConf.get(JARS_DOWNLOAD_LOCATION) - private val filesDownloadPath = sparkConf.get(FILES_DOWNLOAD_LOCATION) def getAllConfigurationSteps: Seq[DriverConfigurationStep] = { val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( @@ -126,9 +112,7 @@ private[spark] class DriverConfigOrchestrator( val dependencyResolutionStep = if (sparkJars.nonEmpty || sparkFiles.nonEmpty) { Seq(new DependencyResolutionStep( sparkJars, - sparkFiles, - jarsDownloadPath, - filesDownloadPath)) + sparkFiles)) } else { Nil } @@ -139,33 +123,12 @@ private[spark] class DriverConfigOrchestrator( Nil } - val initContainerBootstrapStep = if (existNonContainerLocalFiles(sparkJars ++ sparkFiles)) { - val orchestrator = new InitContainerConfigOrchestrator( - sparkJars, - sparkFiles, - jarsDownloadPath, - filesDownloadPath, - imagePullPolicy, - initContainerConfigMapName, - INIT_CONTAINER_PROPERTIES_FILE_NAME, - sparkConf) - val bootstrapStep = new DriverInitContainerBootstrapStep( - orchestrator.getAllConfigurationSteps, - initContainerConfigMapName, - INIT_CONTAINER_PROPERTIES_FILE_NAME) - - Seq(bootstrapStep) - } else { - Nil - } - Seq( initialSubmissionStep, serviceBootstrapStep, kubernetesCredentialsStep) ++ dependencyResolutionStep ++ - mountSecretsStep ++ - initContainerBootstrapStep + mountSecretsStep } private def existSubmissionLocalFiles(files: Seq[String]): Boolean = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index 5884348cb3e41..e16d1add600b2 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -16,14 +16,14 @@ */ package org.apache.spark.deploy.k8s.submit +import java.io.StringWriter import java.util.{Collections, UUID} - -import scala.collection.JavaConverters._ -import scala.collection.mutable -import scala.util.control.NonFatal +import java.util.Properties import io.fabric8.kubernetes.api.model._ import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.mutable +import scala.util.control.NonFatal import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkApplication @@ -32,6 +32,7 @@ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory import org.apache.spark.deploy.k8s.submit.steps.DriverConfigurationStep import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.ConfigBuilder import org.apache.spark.util.Utils /** @@ -93,10 +94,8 @@ private[spark] class Client( kubernetesClient: KubernetesClient, waitForAppCompletion: Boolean, appName: String, - watcher: LoggingPodStatusWatcher) extends Logging { - - private val driverJavaOptions = sparkConf.get( - org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS) + watcher: LoggingPodStatusWatcher, + kubernetesResourceNamePrefix: String) extends Logging { /** * Run command that initializes a DriverSpec that will be updated after each @@ -110,33 +109,31 @@ private[spark] class Client( for (nextStep <- submissionSteps) { currentDriverSpec = nextStep.configureDriver(currentDriverSpec) } - - val resolvedDriverJavaOpts = currentDriverSpec - .driverSparkConf - // Remove this as the options are instead extracted and set individually below using - // environment variables with prefix SPARK_JAVA_OPT_. - .remove(org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS) - .getAll - .map { - case (confKey, confValue) => s"-D$confKey=$confValue" - } ++ driverJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) - val driverJavaOptsEnvs: Seq[EnvVar] = resolvedDriverJavaOpts.zipWithIndex.map { - case (option, index) => - new EnvVarBuilder() - .withName(s"$ENV_JAVA_OPT_PREFIX$index") - .withValue(option) - .build() - } - + val configMapName = s"$kubernetesResourceNamePrefix-driver-conf-map" + val configMap = buildConfigMap(configMapName, currentDriverSpec.driverSparkConf) + // The include of the ENV_VAR for "SPARK_CONF_DIR" is to allow for the + // Spark command builder to pickup on the Java Options present in the ConfigMap val resolvedDriverContainer = new ContainerBuilder(currentDriverSpec.driverContainer) - .addAllToEnv(driverJavaOptsEnvs.asJava) + .addNewEnv() + .withName(ENV_SPARK_CONF_DIR) + .withValue(SPARK_CONF_DIR_INTERNAL) + .endEnv() + .addNewVolumeMount() + .withName(SPARK_CONF_VOLUME) + .withMountPath(SPARK_CONF_DIR_INTERNAL) + .endVolumeMount() .build() val resolvedDriverPod = new PodBuilder(currentDriverSpec.driverPod) .editSpec() .addToContainers(resolvedDriverContainer) + .addNewVolume() + .withName(SPARK_CONF_VOLUME) + .withNewConfigMap() + .withName(configMapName) + .endConfigMap() + .endVolume() .endSpec() .build() - Utils.tryWithResource( kubernetesClient .pods() @@ -145,7 +142,8 @@ private[spark] class Client( val createdDriverPod = kubernetesClient.pods().create(resolvedDriverPod) try { if (currentDriverSpec.otherKubernetesResources.nonEmpty) { - val otherKubernetesResources = currentDriverSpec.otherKubernetesResources + val otherKubernetesResources = + currentDriverSpec.otherKubernetesResources ++ Seq(configMap) addDriverOwnerReference(createdDriverPod, otherKubernetesResources) kubernetesClient.resourceList(otherKubernetesResources: _*).createOrReplace() } @@ -180,6 +178,26 @@ private[spark] class Client( originalMetadata.setOwnerReferences(Collections.singletonList(driverPodOwnerReference)) } } + + // Build a Config Map that will house spark conf properties in a single file for spark-submit + private def buildConfigMap(configMapName: String, conf: SparkConf): ConfigMap = { + val properties = new Properties() + conf.getAll.foreach { case (k, v) => + properties.setProperty(k, v) + } + val propertiesWriter = new StringWriter() + properties.store(propertiesWriter, + s"Java properties built from Kubernetes config map with name: $configMapName") + + val namespace = conf.get(KUBERNETES_NAMESPACE) + new ConfigMapBuilder() + .withNewMetadata() + .withName(configMapName) + .withNamespace(namespace) + .endMetadata() + .addToData(SPARK_CONF_FILE_NAME, propertiesWriter.toString) + .build() + } } /** @@ -202,6 +220,9 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val launchTime = System.currentTimeMillis() val waitForAppCompletion = sparkConf.get(WAIT_FOR_APP_COMPLETION) val appName = sparkConf.getOption("spark.app.name").getOrElse("spark") + val kubernetesResourceNamePrefix = { + s"$appName-$launchTime".toLowerCase.replaceAll("\\.", "-") + } // The master URL has been checked for validity already in SparkSubmit. // We just need to get rid of the "k8s://" prefix here. val master = sparkConf.get("spark.master").substring("k8s://".length) @@ -211,7 +232,7 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val orchestrator = new DriverConfigOrchestrator( kubernetesAppId, - launchTime, + kubernetesResourceNamePrefix, clientArguments.mainAppResource, appName, clientArguments.mainClass, @@ -231,7 +252,8 @@ private[spark] class KubernetesClientApplication extends SparkApplication { kubernetesClient, waitForAppCompletion, appName, - watcher) + watcher, + kubernetesResourceNamePrefix) client.run() } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index 164e2e5594778..347c4d2d66826 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -26,6 +26,7 @@ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.KubernetesUtils import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec import org.apache.spark.internal.config.{DRIVER_CLASS_PATH, DRIVER_MEMORY, DRIVER_MEMORY_OVERHEAD} +import org.apache.spark.launcher.SparkLauncher /** * Performs basic configuration for the driver pod. @@ -56,8 +57,6 @@ private[spark] class BasicDriverConfigurationStep( // Memory settings private val driverMemoryMiB = sparkConf.get(DRIVER_MEMORY) - private val driverMemoryString = sparkConf.get( - DRIVER_MEMORY.key, DRIVER_MEMORY.defaultValueString) private val memoryOverheadMiB = sparkConf .get(DRIVER_MEMORY_OVERHEAD) .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) @@ -103,24 +102,12 @@ private[spark] class BasicDriverConfigurationStep( ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) } - val driverContainer = new ContainerBuilder(driverSpec.driverContainer) + val driverContainerWithoutArgs = new ContainerBuilder(driverSpec.driverContainer) .withName(DRIVER_CONTAINER_NAME) .withImage(driverContainerImage) .withImagePullPolicy(imagePullPolicy) .addAllToEnv(driverCustomEnvs.asJava) .addToEnv(driverExtraClasspathEnv.toSeq: _*) - .addNewEnv() - .withName(ENV_DRIVER_MEMORY) - .withValue(driverMemoryString) - .endEnv() - .addNewEnv() - .withName(ENV_DRIVER_MAIN_CLASS) - .withValue(mainClass) - .endEnv() - .addNewEnv() - .withName(ENV_DRIVER_ARGS) - .withValue(appArgs.mkString(" ")) - .endEnv() .addNewEnv() .withName(ENV_DRIVER_BIND_ADDRESS) .withValueFrom(new EnvVarSourceBuilder() @@ -134,7 +121,16 @@ private[spark] class BasicDriverConfigurationStep( .addToLimits(maybeCpuLimitQuantity.toMap.asJava) .endResources() .addToArgs("driver") - .build() + .addToArgs("--properties-file", SPARK_CONF_PATH) + .addToArgs("--class", mainClass) + // The user application jar is merged into the spark.jars list and managed through that + // property, so there is no need to reference it explicitly here. + .addToArgs(SparkLauncher.NO_RESOURCE) + + val driverContainer = appArgs.toList match { + case "" :: Nil | Nil => driverContainerWithoutArgs.build() + case _ => driverContainerWithoutArgs.addToArgs(appArgs: _*).build() + } val baseDriverPod = new PodBuilder(driverSpec.driverPod) .editOrNewMetadata() @@ -152,10 +148,14 @@ private[spark] class BasicDriverConfigurationStep( .setIfMissing(KUBERNETES_DRIVER_POD_NAME, driverPodName) .set("spark.app.id", kubernetesAppId) .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, resourceNamePrefix) + // to set the config variables to allow client-mode spark-submit from driver + .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) driverSpec.copy( driverPod = baseDriverPod, driverSparkConf = resolvedSparkConf, driverContainer = driverContainer) } + } + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala index d4b83235b4e3b..43de329f239ad 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala @@ -30,13 +30,11 @@ import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec */ private[spark] class DependencyResolutionStep( sparkJars: Seq[String], - sparkFiles: Seq[String], - jarsDownloadPath: String, - filesDownloadPath: String) extends DriverConfigurationStep { + sparkFiles: Seq[String]) extends DriverConfigurationStep { override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val resolvedSparkJars = KubernetesUtils.resolveFileUris(sparkJars, jarsDownloadPath) - val resolvedSparkFiles = KubernetesUtils.resolveFileUris(sparkFiles, filesDownloadPath) + val resolvedSparkJars = KubernetesUtils.resolveFileUrisAndPath(sparkJars) + val resolvedSparkFiles = KubernetesUtils.resolveFileUrisAndPath(sparkFiles) val sparkConf = driverSpec.driverSparkConf.clone() if (resolvedSparkJars.nonEmpty) { @@ -45,14 +43,12 @@ private[spark] class DependencyResolutionStep( if (resolvedSparkFiles.nonEmpty) { sparkConf.set("spark.files", resolvedSparkFiles.mkString(",")) } - - val resolvedClasspath = KubernetesUtils.resolveFilePaths(sparkJars, jarsDownloadPath) - val resolvedDriverContainer = if (resolvedClasspath.nonEmpty) { + val resolvedDriverContainer = if (resolvedSparkJars.nonEmpty) { new ContainerBuilder(driverSpec.driverContainer) .addNewEnv() - .withName(ENV_MOUNTED_CLASSPATH) - .withValue(resolvedClasspath.mkString(File.pathSeparator)) - .endEnv() + .withName(ENV_MOUNTED_CLASSPATH) + .withValue(resolvedSparkJars.mkString(File.pathSeparator)) + .endEnv() .build() } else { driverSpec.driverContainer diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala deleted file mode 100644 index 9fb3dafdda540..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.StringWriter -import java.util.Properties - -import io.fabric8.kubernetes.api.model.{ConfigMap, ConfigMapBuilder, ContainerBuilder, HasMetadata} - -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.KubernetesUtils -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.deploy.k8s.submit.steps.initcontainer.{InitContainerConfigurationStep, InitContainerSpec} - -/** - * Configures the driver init-container that localizes remote dependencies into the driver pod. - * It applies the given InitContainerConfigurationSteps in the given order to produce a final - * InitContainerSpec that is then used to configure the driver pod with the init-container attached. - * It also builds a ConfigMap that will be mounted into the init-container. The ConfigMap carries - * configuration properties for the init-container. - */ -private[spark] class DriverInitContainerBootstrapStep( - steps: Seq[InitContainerConfigurationStep], - configMapName: String, - configMapKey: String) - extends DriverConfigurationStep { - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - var initContainerSpec = InitContainerSpec( - properties = Map.empty[String, String], - driverSparkConf = Map.empty[String, String], - initContainer = new ContainerBuilder().build(), - driverContainer = driverSpec.driverContainer, - driverPod = driverSpec.driverPod, - dependentResources = Seq.empty[HasMetadata]) - for (nextStep <- steps) { - initContainerSpec = nextStep.configureInitContainer(initContainerSpec) - } - - val configMap = buildConfigMap( - configMapName, - configMapKey, - initContainerSpec.properties) - val resolvedDriverSparkConf = driverSpec.driverSparkConf - .clone() - .set(INIT_CONTAINER_CONFIG_MAP_NAME, configMapName) - .set(INIT_CONTAINER_CONFIG_MAP_KEY_CONF, configMapKey) - .setAll(initContainerSpec.driverSparkConf) - val resolvedDriverPod = KubernetesUtils.appendInitContainer( - initContainerSpec.driverPod, initContainerSpec.initContainer) - - driverSpec.copy( - driverPod = resolvedDriverPod, - driverContainer = initContainerSpec.driverContainer, - driverSparkConf = resolvedDriverSparkConf, - otherKubernetesResources = - driverSpec.otherKubernetesResources ++ - initContainerSpec.dependentResources ++ - Seq(configMap)) - } - - private def buildConfigMap( - configMapName: String, - configMapKey: String, - config: Map[String, String]): ConfigMap = { - val properties = new Properties() - config.foreach { entry => - properties.setProperty(entry._1, entry._2) - } - val propertiesWriter = new StringWriter() - properties.store(propertiesWriter, - s"Java properties built from Kubernetes config map with name: $configMapName " + - s"and config map key: $configMapKey") - new ConfigMapBuilder() - .withNewMetadata() - .withName(configMapName) - .endMetadata() - .addToData(configMapKey, propertiesWriter.toString) - .build() - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala index ccc18908658f1..2424e63999a82 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala @@ -99,7 +99,7 @@ private[spark] class DriverKubernetesCredentialsStep( }.getOrElse(driverSpec.driverPod) ) - val driverContainerWithMountedSecretVolume = kubernetesCredentialsSecret.map { secret => + val driverContainerWithMountedSecretVolume = kubernetesCredentialsSecret.map { _ => new ContainerBuilder(driverSpec.driverContainer) .addNewVolumeMount() .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala deleted file mode 100644 index 01469853dacc2..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, PodWithDetachedInitContainer} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.KubernetesUtils - -/** - * Performs basic configuration for the driver init-container with most of the work delegated to - * the given InitContainerBootstrap. - */ -private[spark] class BasicInitContainerConfigurationStep( - sparkJars: Seq[String], - sparkFiles: Seq[String], - jarsDownloadPath: String, - filesDownloadPath: String, - bootstrap: InitContainerBootstrap) - extends InitContainerConfigurationStep { - - override def configureInitContainer(spec: InitContainerSpec): InitContainerSpec = { - val remoteJarsToDownload = KubernetesUtils.getOnlyRemoteFiles(sparkJars) - val remoteFilesToDownload = KubernetesUtils.getOnlyRemoteFiles(sparkFiles) - val remoteJarsConf = if (remoteJarsToDownload.nonEmpty) { - Map(INIT_CONTAINER_REMOTE_JARS.key -> remoteJarsToDownload.mkString(",")) - } else { - Map() - } - val remoteFilesConf = if (remoteFilesToDownload.nonEmpty) { - Map(INIT_CONTAINER_REMOTE_FILES.key -> remoteFilesToDownload.mkString(",")) - } else { - Map() - } - - val baseInitContainerConfig = Map( - JARS_DOWNLOAD_LOCATION.key -> jarsDownloadPath, - FILES_DOWNLOAD_LOCATION.key -> filesDownloadPath) ++ - remoteJarsConf ++ - remoteFilesConf - - val bootstrapped = bootstrap.bootstrapInitContainer( - PodWithDetachedInitContainer( - spec.driverPod, - spec.initContainer, - spec.driverContainer)) - - spec.copy( - initContainer = bootstrapped.initContainer, - driverContainer = bootstrapped.mainContainer, - driverPod = bootstrapped.pod, - properties = spec.properties ++ baseInitContainerConfig) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala deleted file mode 100644 index f2c29c7ce1076..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ - -/** - * Figures out and returns the complete ordered list of InitContainerConfigurationSteps required to - * configure the driver init-container. The returned steps will be applied in the given order to - * produce a final InitContainerSpec that is used to construct the driver init-container in - * DriverInitContainerBootstrapStep. This class is only used when an init-container is needed, i.e., - * when there are remote application dependencies to localize. - */ -private[spark] class InitContainerConfigOrchestrator( - sparkJars: Seq[String], - sparkFiles: Seq[String], - jarsDownloadPath: String, - filesDownloadPath: String, - imagePullPolicy: String, - configMapName: String, - configMapKey: String, - sparkConf: SparkConf) { - - private val initContainerImage = sparkConf - .get(INIT_CONTAINER_IMAGE) - .getOrElse(throw new SparkException( - "Must specify the init-container image when there are remote dependencies")) - - def getAllConfigurationSteps: Seq[InitContainerConfigurationStep] = { - val initContainerBootstrap = new InitContainerBootstrap( - initContainerImage, - imagePullPolicy, - jarsDownloadPath, - filesDownloadPath, - configMapName, - configMapKey, - SPARK_POD_DRIVER_ROLE, - sparkConf) - val baseStep = new BasicInitContainerConfigurationStep( - sparkJars, - sparkFiles, - jarsDownloadPath, - filesDownloadPath, - initContainerBootstrap) - - val secretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_DRIVER_SECRETS_PREFIX) - // Mount user-specified driver secrets also into the driver's init-container. The - // init-container may need credentials in the secrets to be able to download remote - // dependencies. The driver's main container and its init-container share the secrets - // because the init-container is sort of an implementation details and this sharing - // avoids introducing a dedicated configuration property just for the init-container. - val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { - Seq(new InitContainerMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) - } else { - Nil - } - - Seq(baseStep) ++ mountSecretsStep - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala deleted file mode 100644 index 0372ad5270951..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -/** - * Represents a step in configuring the driver init-container. - */ -private[spark] trait InitContainerConfigurationStep { - - def configureInitContainer(spec: InitContainerSpec): InitContainerSpec -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala deleted file mode 100644 index 0daa7b95e8aae..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import org.apache.spark.deploy.k8s.MountSecretsBootstrap - -/** - * An init-container configuration step for mounting user-specified secrets onto user-specified - * paths. - * - * @param bootstrap a utility actually handling mounting of the secrets - */ -private[spark] class InitContainerMountSecretsStep( - bootstrap: MountSecretsBootstrap) extends InitContainerConfigurationStep { - - override def configureInitContainer(spec: InitContainerSpec) : InitContainerSpec = { - // Mount the secret volumes given that the volumes have already been added to the driver pod - // when mounting the secrets into the main driver container. - val initContainer = bootstrap.mountSecrets(spec.initContainer) - spec.copy(initContainer = initContainer) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala deleted file mode 100644 index b52c343f0c0ed..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import io.fabric8.kubernetes.api.model.{Container, HasMetadata, Pod} - -/** - * Represents a specification of the init-container for the driver pod. - * - * @param properties properties that should be set on the init-container - * @param driverSparkConf Spark configuration properties that will be carried back to the driver - * @param initContainer the init-container object - * @param driverContainer the driver container object - * @param driverPod the driver pod object - * @param dependentResources resources the init-container depends on to work - */ -private[spark] case class InitContainerSpec( - properties: Map[String, String], - driverSparkConf: Map[String, String], - initContainer: Container, - driverContainer: Container, - driverPod: Pod, - dependentResources: Seq[HasMetadata]) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index 141bd2827e7c5..98cbd5607da00 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model._ import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap, PodWithDetachedInitContainer} +import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} @@ -34,18 +34,10 @@ import org.apache.spark.util.Utils * @param sparkConf Spark configuration * @param mountSecretsBootstrap an optional component for mounting user-specified secrets onto * user-specified paths into the executor container - * @param initContainerBootstrap an optional component for bootstrapping the executor init-container - * if one is needed, i.e., when there are remote dependencies to - * localize - * @param initContainerMountSecretsBootstrap an optional component for mounting user-specified - * secrets onto user-specified paths into the executor - * init-container */ private[spark] class ExecutorPodFactory( sparkConf: SparkConf, - mountSecretsBootstrap: Option[MountSecretsBootstrap], - initContainerBootstrap: Option[InitContainerBootstrap], - initContainerMountSecretsBootstrap: Option[MountSecretsBootstrap]) { + mountSecretsBootstrap: Option[MountSecretsBootstrap]) { private val executorExtraClasspath = sparkConf.get(EXECUTOR_CLASS_PATH) @@ -94,8 +86,6 @@ private[spark] class ExecutorPodFactory( private val executorCores = sparkConf.getDouble("spark.executor.cores", 1) private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) - private val executorJarsDownloadDir = sparkConf.get(JARS_DOWNLOAD_LOCATION) - /** * Configure and construct an executor pod with the given parameters. */ @@ -147,8 +137,9 @@ private[spark] class ExecutorPodFactory( (ENV_EXECUTOR_CORES, math.ceil(executorCores).toInt.toString), (ENV_EXECUTOR_MEMORY, executorMemoryString), (ENV_APPLICATION_ID, applicationId), - (ENV_EXECUTOR_ID, executorId), - (ENV_MOUNTED_CLASSPATH, s"$executorJarsDownloadDir/*")) ++ executorEnvs) + // This is to set the SPARK_CONF_DIR to be /opt/spark/conf + (ENV_SPARK_CONF_DIR, SPARK_CONF_DIR_INTERNAL), + (ENV_EXECUTOR_ID, executorId)) ++ executorEnvs) .map(env => new EnvVarBuilder() .withName(env._1) .withValue(env._2) @@ -221,30 +212,10 @@ private[spark] class ExecutorPodFactory( (bootstrap.addSecretVolumes(executorPod), bootstrap.mountSecrets(containerWithLimitCores)) }.getOrElse((executorPod, containerWithLimitCores)) - val (bootstrappedPod, bootstrappedContainer) = - initContainerBootstrap.map { bootstrap => - val podWithInitContainer = bootstrap.bootstrapInitContainer( - PodWithDetachedInitContainer( - maybeSecretsMountedPod, - new ContainerBuilder().build(), - maybeSecretsMountedContainer)) - - val (pod, mayBeSecretsMountedInitContainer) = - initContainerMountSecretsBootstrap.map { bootstrap => - // Mount the secret volumes given that the volumes have already been added to the - // executor pod when mounting the secrets into the main executor container. - (podWithInitContainer.pod, bootstrap.mountSecrets(podWithInitContainer.initContainer)) - }.getOrElse((podWithInitContainer.pod, podWithInitContainer.initContainer)) - - val bootstrappedPod = KubernetesUtils.appendInitContainer( - pod, mayBeSecretsMountedInitContainer) - - (bootstrappedPod, podWithInitContainer.mainContainer) - }.getOrElse((maybeSecretsMountedPod, maybeSecretsMountedContainer)) - new PodBuilder(bootstrappedPod) + new PodBuilder(maybeSecretsMountedPod) .editSpec() - .addToContainers(bootstrappedContainer) + .addToContainers(maybeSecretsMountedContainer) .endSpec() .build() } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index a942db6ae02db..ff5f6801da2a3 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -21,7 +21,7 @@ import java.io.File import io.fabric8.kubernetes.client.Config import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap, SparkKubernetesClientFactory} +import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging @@ -33,7 +33,9 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit override def canCreate(masterURL: String): Boolean = masterURL.startsWith("k8s") override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler = { - if (masterURL.startsWith("k8s") && sc.deployMode == "client") { + if (masterURL.startsWith("k8s") && + sc.deployMode == "client" && + !sc.conf.get(KUBERNETES_DRIVER_SUBMIT_CHECK).getOrElse(false)) { throw new SparkException("Client mode is currently not supported for Kubernetes.") } @@ -44,74 +46,23 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit sc: SparkContext, masterURL: String, scheduler: TaskScheduler): SchedulerBackend = { - val sparkConf = sc.getConf - val initContainerConfigMap = sparkConf.get(INIT_CONTAINER_CONFIG_MAP_NAME) - val initContainerConfigMapKey = sparkConf.get(INIT_CONTAINER_CONFIG_MAP_KEY_CONF) - - if (initContainerConfigMap.isEmpty) { - logWarning("The executor's init-container config map is not specified. Executors will " + - "therefore not attempt to fetch remote or submitted dependencies.") - } - - if (initContainerConfigMapKey.isEmpty) { - logWarning("The executor's init-container config map key is not specified. Executors will " + - "therefore not attempt to fetch remote or submitted dependencies.") - } - - // Only set up the bootstrap if they've provided both the config map key and the config map - // name. The config map might not be provided if init-containers aren't being used to - // bootstrap dependencies. - val initContainerBootstrap = for { - configMap <- initContainerConfigMap - configMapKey <- initContainerConfigMapKey - } yield { - val initContainerImage = sparkConf - .get(INIT_CONTAINER_IMAGE) - .getOrElse(throw new SparkException( - "Must specify the init-container image when there are remote dependencies")) - new InitContainerBootstrap( - initContainerImage, - sparkConf.get(CONTAINER_IMAGE_PULL_POLICY), - sparkConf.get(JARS_DOWNLOAD_LOCATION), - sparkConf.get(FILES_DOWNLOAD_LOCATION), - configMap, - configMapKey, - SPARK_POD_EXECUTOR_ROLE, - sparkConf) - } - val executorSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) + sc.conf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) val mountSecretBootstrap = if (executorSecretNamesToMountPaths.nonEmpty) { Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths)) } else { None } - // Mount user-specified executor secrets also into the executor's init-container. The - // init-container may need credentials in the secrets to be able to download remote - // dependencies. The executor's main container and its init-container share the secrets - // because the init-container is sort of an implementation details and this sharing - // avoids introducing a dedicated configuration property just for the init-container. - val initContainerMountSecretsBootstrap = if (initContainerBootstrap.nonEmpty && - executorSecretNamesToMountPaths.nonEmpty) { - Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths)) - } else { - None - } val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( KUBERNETES_MASTER_INTERNAL_URL, - Some(sparkConf.get(KUBERNETES_NAMESPACE)), + Some(sc.conf.get(KUBERNETES_NAMESPACE)), KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX, - sparkConf, + sc.conf, Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - val executorPodFactory = new ExecutorPodFactory( - sparkConf, - mountSecretBootstrap, - initContainerBootstrap, - initContainerMountSecretsBootstrap) + val executorPodFactory = new ExecutorPodFactory(sc.conf, mountSecretBootstrap) val allocatorExecutor = ThreadUtils .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala deleted file mode 100644 index e0f29ecd0fb53..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s - -import java.io.File -import java.util.UUID - -import com.google.common.base.Charsets -import com.google.common.io.Files -import org.mockito.Mockito -import org.scalatest.BeforeAndAfter -import org.scalatest.mockito.MockitoSugar._ - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.util.Utils - -class SparkPodInitContainerSuite extends SparkFunSuite with BeforeAndAfter { - - private val DOWNLOAD_JARS_SECRET_LOCATION = createTempFile("txt") - private val DOWNLOAD_FILES_SECRET_LOCATION = createTempFile("txt") - - private var downloadJarsDir: File = _ - private var downloadFilesDir: File = _ - private var downloadJarsSecretValue: String = _ - private var downloadFilesSecretValue: String = _ - private var fileFetcher: FileFetcher = _ - - override def beforeAll(): Unit = { - downloadJarsSecretValue = Files.toString( - new File(DOWNLOAD_JARS_SECRET_LOCATION), Charsets.UTF_8) - downloadFilesSecretValue = Files.toString( - new File(DOWNLOAD_FILES_SECRET_LOCATION), Charsets.UTF_8) - } - - before { - downloadJarsDir = Utils.createTempDir() - downloadFilesDir = Utils.createTempDir() - fileFetcher = mock[FileFetcher] - } - - after { - downloadJarsDir.delete() - downloadFilesDir.delete() - } - - test("Downloads from remote server should invoke the file fetcher") { - val sparkConf = getSparkConfForRemoteFileDownloads - val initContainerUnderTest = new SparkPodInitContainer(sparkConf, fileFetcher) - initContainerUnderTest.run() - Mockito.verify(fileFetcher).fetchFile("http://localhost:9000/jar1.jar", downloadJarsDir) - Mockito.verify(fileFetcher).fetchFile("hdfs://localhost:9000/jar2.jar", downloadJarsDir) - Mockito.verify(fileFetcher).fetchFile("http://localhost:9000/file.txt", downloadFilesDir) - } - - private def getSparkConfForRemoteFileDownloads: SparkConf = { - new SparkConf(true) - .set(INIT_CONTAINER_REMOTE_JARS, - "http://localhost:9000/jar1.jar,hdfs://localhost:9000/jar2.jar") - .set(INIT_CONTAINER_REMOTE_FILES, - "http://localhost:9000/file.txt") - .set(JARS_DOWNLOAD_LOCATION, downloadJarsDir.getAbsolutePath) - .set(FILES_DOWNLOAD_LOCATION, downloadFilesDir.getAbsolutePath) - } - - private def createTempFile(extension: String): String = { - val dir = Utils.createTempDir() - val file = new File(dir, s"${UUID.randomUUID().toString}.$extension") - Files.write(UUID.randomUUID().toString, file, Charsets.UTF_8) - file.getAbsolutePath - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index bf4ec04893204..6a501592f42a3 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -38,6 +38,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { private val DRIVER_POD_UID = "pod-id" private val DRIVER_POD_API_VERSION = "v1" private val DRIVER_POD_KIND = "pod" + private val KUBERNETES_RESOURCE_PREFIX = "resource-example" private type ResourceList = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ HasMetadata, Boolean] @@ -61,6 +62,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { private val submissionSteps = Seq(FirstTestConfigurationStep, SecondTestConfigurationStep) private var createdPodArgumentCaptor: ArgumentCaptor[Pod] = _ private var createdResourcesArgumentCaptor: ArgumentCaptor[HasMetadata] = _ + private var createdContainerArgumentCaptor: ArgumentCaptor[Container] = _ before { MockitoAnnotations.initMocks(this) @@ -94,7 +96,8 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { kubernetesClient, false, "spark", - loggingPodStatusWatcher) + loggingPodStatusWatcher, + KUBERNETES_RESOURCE_PREFIX) submissionClient.run() val createdPod = createdPodArgumentCaptor.getValue assert(createdPod.getMetadata.getName === FirstTestConfigurationStep.podName) @@ -108,62 +111,52 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { SecondTestConfigurationStep.containerName) } - test("The client should create the secondary Kubernetes resources.") { + test("The client should create Kubernetes resources") { + val EXAMPLE_JAVA_OPTS = "-XX:+HeapDumpOnOutOfMemoryError -XX:+PrintGCDetails" + val EXPECTED_JAVA_OPTS = "-XX\\:+HeapDumpOnOutOfMemoryError -XX\\:+PrintGCDetails" val submissionClient = new Client( submissionSteps, - new SparkConf(false), + new SparkConf(false) + .set(org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS, EXAMPLE_JAVA_OPTS), kubernetesClient, false, "spark", - loggingPodStatusWatcher) + loggingPodStatusWatcher, + KUBERNETES_RESOURCE_PREFIX) submissionClient.run() val createdPod = createdPodArgumentCaptor.getValue val otherCreatedResources = createdResourcesArgumentCaptor.getAllValues - assert(otherCreatedResources.size === 1) - val createdResource = Iterables.getOnlyElement(otherCreatedResources).asInstanceOf[Secret] - assert(createdResource.getMetadata.getName === FirstTestConfigurationStep.secretName) - assert(createdResource.getData.asScala === + assert(otherCreatedResources.size === 2) + val secrets = otherCreatedResources.toArray + .filter(_.isInstanceOf[Secret]).map(_.asInstanceOf[Secret]) + val configMaps = otherCreatedResources.toArray + .filter(_.isInstanceOf[ConfigMap]).map(_.asInstanceOf[ConfigMap]) + assert(secrets.nonEmpty) + val secret = secrets.head + assert(secret.getMetadata.getName === FirstTestConfigurationStep.secretName) + assert(secret.getData.asScala === Map(FirstTestConfigurationStep.secretKey -> FirstTestConfigurationStep.secretData)) - val ownerReference = Iterables.getOnlyElement(createdResource.getMetadata.getOwnerReferences) + val ownerReference = Iterables.getOnlyElement(secret.getMetadata.getOwnerReferences) assert(ownerReference.getName === createdPod.getMetadata.getName) assert(ownerReference.getKind === DRIVER_POD_KIND) assert(ownerReference.getUid === DRIVER_POD_UID) assert(ownerReference.getApiVersion === DRIVER_POD_API_VERSION) - } - - test("The client should attach the driver container with the appropriate JVM options.") { - val sparkConf = new SparkConf(false) - .set("spark.logConf", "true") - .set( - org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS, - "-XX:+HeapDumpOnOutOfMemoryError -XX:+PrintGCDetails") - val submissionClient = new Client( - submissionSteps, - sparkConf, - kubernetesClient, - false, - "spark", - loggingPodStatusWatcher) - submissionClient.run() - val createdPod = createdPodArgumentCaptor.getValue + assert(configMaps.nonEmpty) + val configMap = configMaps.head + assert(configMap.getMetadata.getName === + s"$KUBERNETES_RESOURCE_PREFIX-driver-conf-map") + assert(configMap.getData.containsKey(SPARK_CONF_FILE_NAME)) + assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains(EXPECTED_JAVA_OPTS)) + assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains( + "spark.custom-conf=custom-conf-value")) val driverContainer = Iterables.getOnlyElement(createdPod.getSpec.getContainers) assert(driverContainer.getName === SecondTestConfigurationStep.containerName) - val driverJvmOptsEnvs = driverContainer.getEnv.asScala.filter { env => - env.getName.startsWith(ENV_JAVA_OPT_PREFIX) - }.sortBy(_.getName) - assert(driverJvmOptsEnvs.size === 4) - - val expectedJvmOptsValues = Seq( - "-Dspark.logConf=true", - s"-D${SecondTestConfigurationStep.sparkConfKey}=" + - s"${SecondTestConfigurationStep.sparkConfValue}", - "-XX:+HeapDumpOnOutOfMemoryError", - "-XX:+PrintGCDetails") - driverJvmOptsEnvs.zip(expectedJvmOptsValues).zipWithIndex.foreach { - case ((resolvedEnv, expectedJvmOpt), index) => - assert(resolvedEnv.getName === s"$ENV_JAVA_OPT_PREFIX$index") - assert(resolvedEnv.getValue === expectedJvmOpt) - } + val driverEnv = driverContainer.getEnv.asScala.head + assert(driverEnv.getName === ENV_SPARK_CONF_DIR) + assert(driverEnv.getValue === SPARK_CONF_DIR_INTERNAL) + val driverMount = driverContainer.getVolumeMounts.asScala.head + assert(driverMount.getName === SPARK_CONF_VOLUME) + assert(driverMount.getMountPath === SPARK_CONF_DIR_INTERNAL) } test("Waiting for app completion should stall on the watcher") { @@ -173,7 +166,8 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { kubernetesClient, true, "spark", - loggingPodStatusWatcher) + loggingPodStatusWatcher, + KUBERNETES_RESOURCE_PREFIX) submissionClient.run() verify(loggingPodStatusWatcher).awaitCompletion() } @@ -209,13 +203,11 @@ private object FirstTestConfigurationStep extends DriverConfigurationStep { } private object SecondTestConfigurationStep extends DriverConfigurationStep { - val annotationKey = "second-submit" val annotationValue = "submitted" val sparkConfKey = "spark.custom-conf" val sparkConfValue = "custom-conf-value" val containerName = "driverContainer" - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { val modifiedPod = new PodBuilder(driverSpec.driverPod) .editMetadata() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala index 033d303e946fd..df34d2dbcb5be 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala @@ -25,7 +25,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { private val DRIVER_IMAGE = "driver-image" private val IC_IMAGE = "init-container-image" private val APP_ID = "spark-app-id" - private val LAUNCH_TIME = 975256L + private val KUBERNETES_RESOURCE_PREFIX = "example-prefix" private val APP_NAME = "spark" private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" private val APP_ARGS = Array("arg1", "arg2") @@ -38,7 +38,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") val orchestrator = new DriverConfigOrchestrator( APP_ID, - LAUNCH_TIME, + KUBERNETES_RESOURCE_PREFIX, Some(mainAppResource), APP_NAME, MAIN_CLASS, @@ -49,15 +49,14 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { classOf[BasicDriverConfigurationStep], classOf[DriverServiceBootstrapStep], classOf[DriverKubernetesCredentialsStep], - classOf[DependencyResolutionStep] - ) + classOf[DependencyResolutionStep]) } test("Base submission steps without a main app resource.") { val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE) val orchestrator = new DriverConfigOrchestrator( APP_ID, - LAUNCH_TIME, + KUBERNETES_RESOURCE_PREFIX, Option.empty, APP_NAME, MAIN_CLASS, @@ -67,31 +66,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { orchestrator, classOf[BasicDriverConfigurationStep], classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep] - ) - } - - test("Submission steps with an init-container.") { - val sparkConf = new SparkConf(false) - .set(CONTAINER_IMAGE, DRIVER_IMAGE) - .set(INIT_CONTAINER_IMAGE.key, IC_IMAGE) - .set("spark.jars", "hdfs://localhost:9000/var/apps/jars/jar1.jar") - val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - LAUNCH_TIME, - Some(mainAppResource), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep], - classOf[DependencyResolutionStep], - classOf[DriverInitContainerBootstrapStep]) + classOf[DriverKubernetesCredentialsStep]) } test("Submission steps with driver secrets to mount") { @@ -102,7 +77,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") val orchestrator = new DriverConfigOrchestrator( APP_ID, - LAUNCH_TIME, + KUBERNETES_RESOURCE_PREFIX, Some(mainAppResource), APP_NAME, MAIN_CLASS, @@ -122,7 +97,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { .set(CONTAINER_IMAGE, DRIVER_IMAGE) var orchestrator = new DriverConfigOrchestrator( APP_ID, - LAUNCH_TIME, + KUBERNETES_RESOURCE_PREFIX, Some(JavaMainAppResource("file:///var/apps/jars/main.jar")), APP_NAME, MAIN_CLASS, @@ -135,7 +110,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { sparkConf.set("spark.files", "/path/to/file1,/path/to/file2") orchestrator = new DriverConfigOrchestrator( APP_ID, - LAUNCH_TIME, + KUBERNETES_RESOURCE_PREFIX, Some(JavaMainAppResource("local:///var/apps/jars/main.jar")), APP_NAME, MAIN_CLASS, diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index b136f2c02ffba..ce068531c7673 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -73,16 +73,13 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { assert(preparedDriverSpec.driverContainer.getImage === "spark-driver:latest") assert(preparedDriverSpec.driverContainer.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY) - assert(preparedDriverSpec.driverContainer.getEnv.size === 7) + assert(preparedDriverSpec.driverContainer.getEnv.size === 4) val envs = preparedDriverSpec.driverContainer .getEnv .asScala .map(env => (env.getName, env.getValue)) .toMap assert(envs(ENV_CLASSPATH) === "/opt/spark/spark-examples.jar") - assert(envs(ENV_DRIVER_MEMORY) === "256M") - assert(envs(ENV_DRIVER_MAIN_CLASS) === MAIN_CLASS) - assert(envs(ENV_DRIVER_ARGS) === "arg1 arg2 \"arg 3\"") assert(envs(DRIVER_CUSTOM_ENV_KEY1) === "customDriverEnv1") assert(envs(DRIVER_CUSTOM_ENV_KEY2) === "customDriverEnv2") @@ -112,7 +109,8 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { val expectedSparkConf = Map( KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", "spark.app.id" -> APP_ID, - KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX) + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, + "spark.kubernetes.submitInDriver" -> "true") assert(resolvedSparkConf === expectedSparkConf) } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala index 991b03cafb76c..ca43fc97dc991 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala @@ -29,24 +29,17 @@ import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec class DependencyResolutionStepSuite extends SparkFunSuite { private val SPARK_JARS = Seq( - "hdfs://localhost:9000/apps/jars/jar1.jar", - "file:///home/user/apps/jars/jar2.jar", - "local:///var/apps/jars/jar3.jar") + "apps/jars/jar1.jar", + "local:///var/apps/jars/jar2.jar") private val SPARK_FILES = Seq( - "file:///home/user/apps/files/file1.txt", - "hdfs://localhost:9000/apps/files/file2.txt", - "local:///var/apps/files/file3.txt") - - private val JARS_DOWNLOAD_PATH = "/mnt/spark-data/jars" - private val FILES_DOWNLOAD_PATH = "/mnt/spark-data/files" + "apps/files/file1.txt", + "local:///var/apps/files/file2.txt") test("Added dependencies should be resolved in Spark configuration and environment") { val dependencyResolutionStep = new DependencyResolutionStep( SPARK_JARS, - SPARK_FILES, - JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_PATH) + SPARK_FILES) val driverPod = new PodBuilder().build() val baseDriverSpec = KubernetesDriverSpec( driverPod = driverPod, @@ -58,24 +51,19 @@ class DependencyResolutionStepSuite extends SparkFunSuite { assert(preparedDriverSpec.otherKubernetesResources.isEmpty) val resolvedSparkJars = preparedDriverSpec.driverSparkConf.get("spark.jars").split(",").toSet val expectedResolvedSparkJars = Set( - "hdfs://localhost:9000/apps/jars/jar1.jar", - s"$JARS_DOWNLOAD_PATH/jar2.jar", - "/var/apps/jars/jar3.jar") + "apps/jars/jar1.jar", + "/var/apps/jars/jar2.jar") assert(resolvedSparkJars === expectedResolvedSparkJars) val resolvedSparkFiles = preparedDriverSpec.driverSparkConf.get("spark.files").split(",").toSet val expectedResolvedSparkFiles = Set( - s"$FILES_DOWNLOAD_PATH/file1.txt", - s"hdfs://localhost:9000/apps/files/file2.txt", - s"/var/apps/files/file3.txt") + "apps/files/file1.txt", + "/var/apps/files/file2.txt") assert(resolvedSparkFiles === expectedResolvedSparkFiles) val driverEnv = preparedDriverSpec.driverContainer.getEnv.asScala assert(driverEnv.size === 1) assert(driverEnv.head.getName === ENV_MOUNTED_CLASSPATH) val resolvedDriverClasspath = driverEnv.head.getValue.split(File.pathSeparator).toSet - val expectedResolvedDriverClasspath = Set( - s"$JARS_DOWNLOAD_PATH/jar1.jar", - s"$JARS_DOWNLOAD_PATH/jar2.jar", - "/var/apps/jars/jar3.jar") + val expectedResolvedDriverClasspath = expectedResolvedSparkJars assert(resolvedDriverClasspath === expectedResolvedDriverClasspath) } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala deleted file mode 100644 index 758871e2ba356..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.StringReader -import java.util.Properties - -import scala.collection.JavaConverters._ - -import com.google.common.collect.Maps -import io.fabric8.kubernetes.api.model.{ConfigMap, ContainerBuilder, HasMetadata, PodBuilder, SecretBuilder} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.deploy.k8s.submit.steps.initcontainer.{InitContainerConfigurationStep, InitContainerSpec} -import org.apache.spark.util.Utils - -class DriverInitContainerBootstrapStepSuite extends SparkFunSuite { - - private val CONFIG_MAP_NAME = "spark-init-config-map" - private val CONFIG_MAP_KEY = "spark-init-config-map-key" - - test("The init container bootstrap step should use all of the init container steps") { - val baseDriverSpec = KubernetesDriverSpec( - driverPod = new PodBuilder().build(), - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) - val initContainerSteps = Seq( - FirstTestInitContainerConfigurationStep, - SecondTestInitContainerConfigurationStep) - val bootstrapStep = new DriverInitContainerBootstrapStep( - initContainerSteps, - CONFIG_MAP_NAME, - CONFIG_MAP_KEY) - - val preparedDriverSpec = bootstrapStep.configureDriver(baseDriverSpec) - - assert(preparedDriverSpec.driverPod.getMetadata.getLabels.asScala === - FirstTestInitContainerConfigurationStep.additionalLabels) - val additionalDriverEnv = preparedDriverSpec.driverContainer.getEnv.asScala - assert(additionalDriverEnv.size === 1) - assert(additionalDriverEnv.head.getName === - FirstTestInitContainerConfigurationStep.additionalMainContainerEnvKey) - assert(additionalDriverEnv.head.getValue === - FirstTestInitContainerConfigurationStep.additionalMainContainerEnvValue) - - assert(preparedDriverSpec.otherKubernetesResources.size === 2) - assert(preparedDriverSpec.otherKubernetesResources.contains( - FirstTestInitContainerConfigurationStep.additionalKubernetesResource)) - assert(preparedDriverSpec.otherKubernetesResources.exists { - case configMap: ConfigMap => - val hasMatchingName = configMap.getMetadata.getName == CONFIG_MAP_NAME - val configMapData = configMap.getData.asScala - val hasCorrectNumberOfEntries = configMapData.size == 1 - val initContainerPropertiesRaw = configMapData(CONFIG_MAP_KEY) - val initContainerProperties = new Properties() - Utils.tryWithResource(new StringReader(initContainerPropertiesRaw)) { - initContainerProperties.load(_) - } - val initContainerPropertiesMap = Maps.fromProperties(initContainerProperties).asScala - val expectedInitContainerProperties = Map( - SecondTestInitContainerConfigurationStep.additionalInitContainerPropertyKey -> - SecondTestInitContainerConfigurationStep.additionalInitContainerPropertyValue) - val hasMatchingProperties = initContainerPropertiesMap == expectedInitContainerProperties - hasMatchingName && hasCorrectNumberOfEntries && hasMatchingProperties - - case _ => false - }) - - val initContainers = preparedDriverSpec.driverPod.getSpec.getInitContainers - assert(initContainers.size() === 1) - val initContainerEnv = initContainers.get(0).getEnv.asScala - assert(initContainerEnv.size === 1) - assert(initContainerEnv.head.getName === - SecondTestInitContainerConfigurationStep.additionalInitContainerEnvKey) - assert(initContainerEnv.head.getValue === - SecondTestInitContainerConfigurationStep.additionalInitContainerEnvValue) - - val expectedSparkConf = Map( - INIT_CONTAINER_CONFIG_MAP_NAME.key -> CONFIG_MAP_NAME, - INIT_CONTAINER_CONFIG_MAP_KEY_CONF.key -> CONFIG_MAP_KEY, - SecondTestInitContainerConfigurationStep.additionalDriverSparkConfKey -> - SecondTestInitContainerConfigurationStep.additionalDriverSparkConfValue) - assert(preparedDriverSpec.driverSparkConf.getAll.toMap === expectedSparkConf) - } -} - -private object FirstTestInitContainerConfigurationStep extends InitContainerConfigurationStep { - - val additionalLabels = Map("additionalLabelkey" -> "additionalLabelValue") - val additionalMainContainerEnvKey = "TEST_ENV_MAIN_KEY" - val additionalMainContainerEnvValue = "TEST_ENV_MAIN_VALUE" - val additionalKubernetesResource = new SecretBuilder() - .withNewMetadata() - .withName("test-secret") - .endMetadata() - .addToData("secret-key", "secret-value") - .build() - - override def configureInitContainer(initContainerSpec: InitContainerSpec): InitContainerSpec = { - val driverPod = new PodBuilder(initContainerSpec.driverPod) - .editOrNewMetadata() - .addToLabels(additionalLabels.asJava) - .endMetadata() - .build() - val mainContainer = new ContainerBuilder(initContainerSpec.driverContainer) - .addNewEnv() - .withName(additionalMainContainerEnvKey) - .withValue(additionalMainContainerEnvValue) - .endEnv() - .build() - initContainerSpec.copy( - driverPod = driverPod, - driverContainer = mainContainer, - dependentResources = initContainerSpec.dependentResources ++ - Seq(additionalKubernetesResource)) - } -} - -private object SecondTestInitContainerConfigurationStep extends InitContainerConfigurationStep { - val additionalInitContainerEnvKey = "TEST_ENV_INIT_KEY" - val additionalInitContainerEnvValue = "TEST_ENV_INIT_VALUE" - val additionalInitContainerPropertyKey = "spark.initcontainer.testkey" - val additionalInitContainerPropertyValue = "testvalue" - val additionalDriverSparkConfKey = "spark.driver.testkey" - val additionalDriverSparkConfValue = "spark.driver.testvalue" - - override def configureInitContainer(initContainerSpec: InitContainerSpec): InitContainerSpec = { - val initContainer = new ContainerBuilder(initContainerSpec.initContainer) - .addNewEnv() - .withName(additionalInitContainerEnvKey) - .withValue(additionalInitContainerEnvValue) - .endEnv() - .build() - val initContainerProperties = initContainerSpec.properties ++ - Map(additionalInitContainerPropertyKey -> additionalInitContainerPropertyValue) - val driverSparkConf = initContainerSpec.driverSparkConf ++ - Map(additionalDriverSparkConfKey -> additionalDriverSparkConfValue) - initContainerSpec.copy( - initContainer = initContainer, - properties = initContainerProperties, - driverSparkConf = driverSparkConf) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala deleted file mode 100644 index 4553f9f6b1d45..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model._ -import org.mockito.{Mock, MockitoAnnotations} -import org.mockito.Matchers.any -import org.mockito.Mockito.when -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer -import org.scalatest.BeforeAndAfter - -import org.apache.spark.SparkFunSuite -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, PodWithDetachedInitContainer} -import org.apache.spark.deploy.k8s.Config._ - -class BasicInitContainerConfigurationStepSuite extends SparkFunSuite with BeforeAndAfter { - - private val SPARK_JARS = Seq( - "hdfs://localhost:9000/app/jars/jar1.jar", "file:///app/jars/jar2.jar") - private val SPARK_FILES = Seq( - "hdfs://localhost:9000/app/files/file1.txt", "file:///app/files/file2.txt") - private val JARS_DOWNLOAD_PATH = "/var/data/jars" - private val FILES_DOWNLOAD_PATH = "/var/data/files" - private val POD_LABEL = Map("bootstrap" -> "true") - private val INIT_CONTAINER_NAME = "init-container" - private val DRIVER_CONTAINER_NAME = "driver-container" - - @Mock - private var podAndInitContainerBootstrap : InitContainerBootstrap = _ - - before { - MockitoAnnotations.initMocks(this) - when(podAndInitContainerBootstrap.bootstrapInitContainer( - any[PodWithDetachedInitContainer])).thenAnswer(new Answer[PodWithDetachedInitContainer] { - override def answer(invocation: InvocationOnMock) : PodWithDetachedInitContainer = { - val pod = invocation.getArgumentAt(0, classOf[PodWithDetachedInitContainer]) - pod.copy( - pod = new PodBuilder(pod.pod) - .withNewMetadata() - .addToLabels("bootstrap", "true") - .endMetadata() - .withNewSpec().endSpec() - .build(), - initContainer = new ContainerBuilder() - .withName(INIT_CONTAINER_NAME) - .build(), - mainContainer = new ContainerBuilder() - .withName(DRIVER_CONTAINER_NAME) - .build() - )}}) - } - - test("additionalDriverSparkConf with mix of remote files and jars") { - val baseInitStep = new BasicInitContainerConfigurationStep( - SPARK_JARS, - SPARK_FILES, - JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_PATH, - podAndInitContainerBootstrap) - val expectedDriverSparkConf = Map( - JARS_DOWNLOAD_LOCATION.key -> JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_LOCATION.key -> FILES_DOWNLOAD_PATH, - INIT_CONTAINER_REMOTE_JARS.key -> "hdfs://localhost:9000/app/jars/jar1.jar", - INIT_CONTAINER_REMOTE_FILES.key -> "hdfs://localhost:9000/app/files/file1.txt") - val initContainerSpec = InitContainerSpec( - Map.empty[String, String], - Map.empty[String, String], - new Container(), - new Container(), - new Pod, - Seq.empty[HasMetadata]) - val returnContainerSpec = baseInitStep.configureInitContainer(initContainerSpec) - assert(expectedDriverSparkConf === returnContainerSpec.properties) - assert(returnContainerSpec.initContainer.getName === INIT_CONTAINER_NAME) - assert(returnContainerSpec.driverContainer.getName === DRIVER_CONTAINER_NAME) - assert(returnContainerSpec.driverPod.getMetadata.getLabels.asScala === POD_LABEL) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala deleted file mode 100644 index 09b42e4484d86..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ - -class InitContainerConfigOrchestratorSuite extends SparkFunSuite { - - private val DOCKER_IMAGE = "init-container" - private val SPARK_JARS = Seq( - "hdfs://localhost:9000/app/jars/jar1.jar", "file:///app/jars/jar2.jar") - private val SPARK_FILES = Seq( - "hdfs://localhost:9000/app/files/file1.txt", "file:///app/files/file2.txt") - private val JARS_DOWNLOAD_PATH = "/var/data/jars" - private val FILES_DOWNLOAD_PATH = "/var/data/files" - private val DOCKER_IMAGE_PULL_POLICY: String = "IfNotPresent" - private val CUSTOM_LABEL_KEY = "customLabel" - private val CUSTOM_LABEL_VALUE = "customLabelValue" - private val INIT_CONTAINER_CONFIG_MAP_NAME = "spark-init-config-map" - private val INIT_CONTAINER_CONFIG_MAP_KEY = "spark-init-config-map-key" - private val SECRET_FOO = "foo" - private val SECRET_BAR = "bar" - private val SECRET_MOUNT_PATH = "/etc/secrets/init-container" - - test("including basic configuration step") { - val sparkConf = new SparkConf(true) - .set(CONTAINER_IMAGE, DOCKER_IMAGE) - .set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$CUSTOM_LABEL_KEY", CUSTOM_LABEL_VALUE) - - val orchestrator = new InitContainerConfigOrchestrator( - SPARK_JARS.take(1), - SPARK_FILES, - JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_PATH, - DOCKER_IMAGE_PULL_POLICY, - INIT_CONTAINER_CONFIG_MAP_NAME, - INIT_CONTAINER_CONFIG_MAP_KEY, - sparkConf) - val initSteps = orchestrator.getAllConfigurationSteps - assert(initSteps.lengthCompare(1) == 0) - assert(initSteps.head.isInstanceOf[BasicInitContainerConfigurationStep]) - } - - test("including step to mount user-specified secrets") { - val sparkConf = new SparkConf(false) - .set(CONTAINER_IMAGE, DOCKER_IMAGE) - .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) - .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) - - val orchestrator = new InitContainerConfigOrchestrator( - SPARK_JARS.take(1), - SPARK_FILES, - JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_PATH, - DOCKER_IMAGE_PULL_POLICY, - INIT_CONTAINER_CONFIG_MAP_NAME, - INIT_CONTAINER_CONFIG_MAP_KEY, - sparkConf) - val initSteps = orchestrator.getAllConfigurationSteps - assert(initSteps.length === 2) - assert(initSteps.head.isInstanceOf[BasicInitContainerConfigurationStep]) - assert(initSteps(1).isInstanceOf[InitContainerMountSecretsStep]) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala deleted file mode 100644 index 7ac0bde80dfe6..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder} - -import org.apache.spark.SparkFunSuite -import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} - -class InitContainerMountSecretsStepSuite extends SparkFunSuite { - - private val SECRET_FOO = "foo" - private val SECRET_BAR = "bar" - private val SECRET_MOUNT_PATH = "/etc/secrets/init-container" - - test("mounts all given secrets") { - val baseInitContainerSpec = InitContainerSpec( - Map.empty, - Map.empty, - new ContainerBuilder().build(), - new ContainerBuilder().build(), - new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build(), - Seq.empty) - val secretNamesToMountPaths = Map( - SECRET_FOO -> SECRET_MOUNT_PATH, - SECRET_BAR -> SECRET_MOUNT_PATH) - - val mountSecretsBootstrap = new MountSecretsBootstrap(secretNamesToMountPaths) - val initContainerMountSecretsStep = new InitContainerMountSecretsStep(mountSecretsBootstrap) - val configuredInitContainerSpec = initContainerMountSecretsStep.configureInitContainer( - baseInitContainerSpec) - val initContainerWithSecretsMounted = configuredInitContainerSpec.initContainer - - Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => - assert(SecretVolumeUtils.containerHasVolume( - initContainerWithSecretsMounted, volumeName, SECRET_MOUNT_PATH))) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index a3c615be031d2..7755b93835047 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -19,15 +19,13 @@ package org.apache.spark.scheduler.cluster.k8s import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model._ -import org.mockito.{AdditionalAnswers, MockitoAnnotations} -import org.mockito.Matchers.any -import org.mockito.Mockito._ +import org.mockito.MockitoAnnotations import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, MountSecretsBootstrap, PodWithDetachedInitContainer, SecretVolumeUtils} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.MountSecretsBootstrap class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { @@ -55,10 +53,11 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef .set(KUBERNETES_DRIVER_POD_NAME, driverPodName) .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) .set(CONTAINER_IMAGE, executorImage) + .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) } test("basic executor pod has reasonable defaults") { - val factory = new ExecutorPodFactory(baseConf, None, None, None) + val factory = new ExecutorPodFactory(baseConf, None) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) @@ -89,7 +88,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple") - val factory = new ExecutorPodFactory(conf, None, None, None) + val factory = new ExecutorPodFactory(conf, None) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) @@ -101,7 +100,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") - val factory = new ExecutorPodFactory(conf, None, None, None) + val factory = new ExecutorPodFactory(conf, None) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)]("qux" -> "quux"), driverPod, Map[String, Int]()) @@ -116,11 +115,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef val conf = baseConf.clone() val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1")) - val factory = new ExecutorPodFactory( - conf, - Some(secretsBootstrap), - None, - None) + val factory = new ExecutorPodFactory(conf, Some(secretsBootstrap)) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) @@ -138,50 +133,6 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef checkOwnerReferences(executor, driverPodUid) } - test("init-container bootstrap step adds an init container") { - val conf = baseConf.clone() - val initContainerBootstrap = mock(classOf[InitContainerBootstrap]) - when(initContainerBootstrap.bootstrapInitContainer( - any(classOf[PodWithDetachedInitContainer]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - - val factory = new ExecutorPodFactory( - conf, - None, - Some(initContainerBootstrap), - None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - assert(executor.getSpec.getInitContainers.size() === 1) - checkOwnerReferences(executor, driverPodUid) - } - - test("init-container with secrets mount bootstrap") { - val conf = baseConf.clone() - val initContainerBootstrap = mock(classOf[InitContainerBootstrap]) - when(initContainerBootstrap.bootstrapInitContainer( - any(classOf[PodWithDetachedInitContainer]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1")) - - val factory = new ExecutorPodFactory( - conf, - Some(secretsBootstrap), - Some(initContainerBootstrap), - Some(secretsBootstrap)) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - assert(executor.getSpec.getVolumes.size() === 1) - assert(SecretVolumeUtils.podHasVolume(executor, "secret1-volume")) - assert(SecretVolumeUtils.containerHasVolume( - executor.getSpec.getContainers.get(0), "secret1-volume", "/var/secret1")) - assert(executor.getSpec.getInitContainers.size() === 1) - assert(SecretVolumeUtils.containerHasVolume( - executor.getSpec.getInitContainers.get(0), "secret1-volume", "/var/secret1")) - - checkOwnerReferences(executor, driverPodUid) - } - // There is always exactly one controller reference, and it points to the driver pod. private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { assert(executor.getMetadata.getOwnerReferences.size() === 1) @@ -197,8 +148,8 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef ENV_EXECUTOR_CORES -> "1", ENV_EXECUTOR_MEMORY -> "1g", ENV_APPLICATION_ID -> "dummy", - ENV_EXECUTOR_POD_IP -> null, - ENV_MOUNTED_CLASSPATH -> "/var/spark-data/spark-jars/*") ++ additionalEnvVars + ENV_SPARK_CONF_DIR -> SPARK_CONF_DIR_INTERNAL, + ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars assert(executor.getSpec.getContainers.size() === 1) assert(executor.getSpec.getContainers.get(0).getEnv.size() === defaultEnvs.size) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index 491b7cf692478..9badf8556afc3 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -40,7 +40,6 @@ RUN set -ex && \ COPY ${spark_jars} /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin -COPY conf /opt/spark/conf COPY ${img_path}/spark/entrypoint.sh /opt/ COPY examples /opt/spark/examples COPY data /opt/spark/data diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index d0cf284f035ea..3e166116aa3fd 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -56,14 +56,10 @@ fi case "$SPARK_K8S_CMD" in driver) CMD=( - ${JAVA_HOME}/bin/java - "${SPARK_JAVA_OPTS[@]}" - -cp "$SPARK_CLASSPATH" - -Xms$SPARK_DRIVER_MEMORY - -Xmx$SPARK_DRIVER_MEMORY - -Dspark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS - $SPARK_DRIVER_CLASS - $SPARK_DRIVER_ARGS + "$SPARK_HOME/bin/spark-submit" + --conf "spark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS" + --deploy-mode client + "$@" ) ;; @@ -83,14 +79,6 @@ case "$SPARK_K8S_CMD" in ) ;; - init) - CMD=( - "$SPARK_HOME/bin/spark-class" - "org.apache.spark.deploy.k8s.SparkPodInitContainer" - "$@" - ) - ;; - *) echo "Unknown command: $SPARK_K8S_CMD" 1>&2 exit 1 From 5f4deff19511b6870f056eba5489104b9cac05a9 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Mon, 19 Mar 2018 18:02:04 -0700 Subject: [PATCH 0492/2461] [SPARK-23660] Fix exception in yarn cluster mode when application ended fast ## What changes were proposed in this pull request? Yarn throws the following exception in cluster mode when the application is really small: ``` 18/03/07 23:34:22 WARN netty.NettyRpcEnv: Ignored failure: java.util.concurrent.RejectedExecutionException: Task java.util.concurrent.ScheduledThreadPoolExecutor$ScheduledFutureTask7c974942 rejected from java.util.concurrent.ScheduledThreadPoolExecutor1eea9d2d[Terminated, pool size = 0, active threads = 0, queued tasks = 0, completed tasks = 0] 18/03/07 23:34:22 ERROR yarn.ApplicationMaster: Uncaught exception: org.apache.spark.SparkException: Exception thrown in awaitResult: at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:205) at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75) at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:92) at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:76) at org.apache.spark.deploy.yarn.YarnAllocator.(YarnAllocator.scala:102) at org.apache.spark.deploy.yarn.YarnRMClient.register(YarnRMClient.scala:77) at org.apache.spark.deploy.yarn.ApplicationMaster.registerAM(ApplicationMaster.scala:450) at org.apache.spark.deploy.yarn.ApplicationMaster.runDriver(ApplicationMaster.scala:493) at org.apache.spark.deploy.yarn.ApplicationMaster.org$apache$spark$deploy$yarn$ApplicationMaster$$runImpl(ApplicationMaster.scala:345) at org.apache.spark.deploy.yarn.ApplicationMaster$$anonfun$run$2.apply$mcV$sp(ApplicationMaster.scala:260) at org.apache.spark.deploy.yarn.ApplicationMaster$$anonfun$run$2.apply(ApplicationMaster.scala:260) at org.apache.spark.deploy.yarn.ApplicationMaster$$anonfun$run$2.apply(ApplicationMaster.scala:260) at org.apache.spark.deploy.yarn.ApplicationMaster$$anon$5.run(ApplicationMaster.scala:810) at java.security.AccessController.doPrivileged(Native Method) at javax.security.auth.Subject.doAs(Subject.java:422) at org.apache.hadoop.security.UserGroupInformation.doAs(UserGroupInformation.java:1920) at org.apache.spark.deploy.yarn.ApplicationMaster.doAsUser(ApplicationMaster.scala:809) at org.apache.spark.deploy.yarn.ApplicationMaster.run(ApplicationMaster.scala:259) at org.apache.spark.deploy.yarn.ApplicationMaster$.main(ApplicationMaster.scala:834) at org.apache.spark.deploy.yarn.ApplicationMaster.main(ApplicationMaster.scala) Caused by: org.apache.spark.rpc.RpcEnvStoppedException: RpcEnv already stopped. at org.apache.spark.rpc.netty.Dispatcher.postMessage(Dispatcher.scala:158) at org.apache.spark.rpc.netty.Dispatcher.postLocalMessage(Dispatcher.scala:135) at org.apache.spark.rpc.netty.NettyRpcEnv.ask(NettyRpcEnv.scala:229) at org.apache.spark.rpc.netty.NettyRpcEndpointRef.ask(NettyRpcEnv.scala:523) at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:91) ... 17 more 18/03/07 23:34:22 INFO yarn.ApplicationMaster: Final app status: FAILED, exitCode: 13, (reason: Uncaught exception: org.apache.spark.SparkException: Exception thrown in awaitResult: ) ``` Example application: ``` object ExampleApp { def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("ExampleApp") val sc = new SparkContext(conf) try { // Do nothing } finally { sc.stop() } } ``` This PR pauses user class thread after `SparkContext` created and keeps it so until application master initialises properly. ## How was this patch tested? Automated: Existing unit tests Manual: Application submitted into small cluster Author: Gabor Somogyi Closes #20807 from gaborgsomogyi/SPARK-23660. --- .../spark/deploy/yarn/ApplicationMaster.scala | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 2f88feb0f1fdf..6e35d23def6f0 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -418,7 +418,19 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } private def sparkContextInitialized(sc: SparkContext) = { - sparkContextPromise.success(sc) + sparkContextPromise.synchronized { + // Notify runDriver function that SparkContext is available + sparkContextPromise.success(sc) + // Pause the user class thread in order to make proper initialization in runDriver function. + sparkContextPromise.wait() + } + } + + private def resumeDriver(): Unit = { + // When initialization in runDriver happened the user class thread has to be resumed. + sparkContextPromise.synchronized { + sparkContextPromise.notify() + } } private def registerAM( @@ -497,6 +509,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends // if the user app did not create a SparkContext. throw new IllegalStateException("User did not initialize spark context!") } + resumeDriver() userClassThread.join() } catch { case e: SparkException if e.getCause().isInstanceOf[TimeoutException] => @@ -506,6 +519,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_SC_NOT_INITED, "Timed out waiting for SparkContext.") + } finally { + resumeDriver() } } From 566321852b2d60641fe86acbc8914b4a7063b58e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 19 Mar 2018 21:25:37 -0700 Subject: [PATCH 0493/2461] [SPARK-23691][PYTHON] Use sql_conf util in PySpark tests where possible ## What changes were proposed in this pull request? https://github.com/apache/spark/commit/d6632d185e147fcbe6724545488ad80dce20277e added an useful util ```python contextmanager def sql_conf(self, pairs): ... ``` to allow configuration set/unset within a block: ```python with self.sql_conf({"spark.blah.blah.blah", "blah"}) # test codes ``` This PR proposes to use this util where possible in PySpark tests. Note that there look already few places affecting tests without restoring the original value back in unittest classes. ## How was this patch tested? Manually tested via: ``` ./run-tests --modules=pyspark-sql --python-executables=python2 ./run-tests --modules=pyspark-sql --python-executables=python3 ``` Author: hyukjinkwon Closes #20830 from HyukjinKwon/cleanup-sql-conf. --- python/pyspark/sql/tests.py | 130 ++++++++++++++---------------------- 1 file changed, 50 insertions(+), 80 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a0d547ad620e5..39d6c5226f138 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2461,17 +2461,13 @@ def test_join_without_on(self): df1 = self.spark.range(1).toDF("a") df2 = self.spark.range(1).toDF("b") - try: - self.spark.conf.set("spark.sql.crossJoin.enabled", "false") + with self.sql_conf({"spark.sql.crossJoin.enabled": False}): self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect()) - self.spark.conf.set("spark.sql.crossJoin.enabled", "true") + with self.sql_conf({"spark.sql.crossJoin.enabled": True}): actual = df1.join(df2, how="inner").collect() expected = [Row(a=0, b=0)] self.assertEqual(actual, expected) - finally: - # We should unset this. Otherwise, other tests are affected. - self.spark.conf.unset("spark.sql.crossJoin.enabled") # Regression test for invalid join methods when on is None, Spark-14761 def test_invalid_join_method(self): @@ -2943,21 +2939,18 @@ def test_create_dateframe_from_pandas_with_dst(self): self.assertPandasEqual(pdf, df.toPandas()) orig_env_tz = os.environ.get('TZ', None) - orig_session_tz = self.spark.conf.get('spark.sql.session.timeZone') try: tz = 'America/Los_Angeles' os.environ['TZ'] = tz time.tzset() - self.spark.conf.set('spark.sql.session.timeZone', tz) - - df = self.spark.createDataFrame(pdf) - self.assertPandasEqual(pdf, df.toPandas()) + with self.sql_conf({'spark.sql.session.timeZone': tz}): + df = self.spark.createDataFrame(pdf) + self.assertPandasEqual(pdf, df.toPandas()) finally: del os.environ['TZ'] if orig_env_tz is not None: os.environ['TZ'] = orig_env_tz time.tzset() - self.spark.conf.set('spark.sql.session.timeZone', orig_session_tz) class HiveSparkSubmitTests(SparkSubmitTests): @@ -3562,12 +3555,11 @@ def test_null_conversion(self): self.assertTrue(all([c == 1 for c in null_counts])) def _toPandas_arrow_toggle(self, df): - self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") - try: + with self.sql_conf({"spark.sql.execution.arrow.enabled": False}): pdf = df.toPandas() - finally: - self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + pdf_arrow = df.toPandas() + return pdf, pdf_arrow def test_toPandas_arrow_toggle(self): @@ -3579,16 +3571,17 @@ def test_toPandas_arrow_toggle(self): def test_toPandas_respect_session_timezone(self): df = self.spark.createDataFrame(self.data, schema=self.schema) - orig_tz = self.spark.conf.get("spark.sql.session.timeZone") - try: - timezone = "America/New_York" - self.spark.conf.set("spark.sql.session.timeZone", timezone) - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") - try: - pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) - self.assertPandasEqual(pdf_arrow_la, pdf_la) - finally: - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + + timezone = "America/New_York" + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": False, + "spark.sql.session.timeZone": timezone}): + pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) + self.assertPandasEqual(pdf_arrow_la, pdf_la) + + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": True, + "spark.sql.session.timeZone": timezone}): pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df) self.assertPandasEqual(pdf_arrow_ny, pdf_ny) @@ -3601,8 +3594,6 @@ def test_toPandas_respect_session_timezone(self): pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz( pdf_la_corrected[field.name], timezone) self.assertPandasEqual(pdf_ny, pdf_la_corrected) - finally: - self.spark.conf.set("spark.sql.session.timeZone", orig_tz) def test_pandas_round_trip(self): pdf = self.create_pandas_data_frame() @@ -3618,12 +3609,11 @@ def test_filtered_frame(self): self.assertTrue(pdf.empty) def _createDataFrame_toggle(self, pdf, schema=None): - self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") - try: + with self.sql_conf({"spark.sql.execution.arrow.enabled": False}): df_no_arrow = self.spark.createDataFrame(pdf, schema=schema) - finally: - self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + df_arrow = self.spark.createDataFrame(pdf, schema=schema) + return df_no_arrow, df_arrow def test_createDataFrame_toggle(self): @@ -3634,18 +3624,18 @@ def test_createDataFrame_toggle(self): def test_createDataFrame_respect_session_timezone(self): from datetime import timedelta pdf = self.create_pandas_data_frame() - orig_tz = self.spark.conf.get("spark.sql.session.timeZone") - try: - timezone = "America/New_York" - self.spark.conf.set("spark.sql.session.timeZone", timezone) - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") - try: - df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema) - result_la = df_no_arrow_la.collect() - result_arrow_la = df_arrow_la.collect() - self.assertEqual(result_la, result_arrow_la) - finally: - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + timezone = "America/New_York" + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": False, + "spark.sql.session.timeZone": timezone}): + df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema) + result_la = df_no_arrow_la.collect() + result_arrow_la = df_arrow_la.collect() + self.assertEqual(result_la, result_arrow_la) + + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": True, + "spark.sql.session.timeZone": timezone}): df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema) result_ny = df_no_arrow_ny.collect() result_arrow_ny = df_arrow_ny.collect() @@ -3658,8 +3648,6 @@ def test_createDataFrame_respect_session_timezone(self): for k, v in row.asDict().items()}) for row in result_la] self.assertEqual(result_ny, result_la_corrected) - finally: - self.spark.conf.set("spark.sql.session.timeZone", orig_tz) def test_createDataFrame_with_schema(self): pdf = self.create_pandas_data_frame() @@ -4336,9 +4324,7 @@ def gen_timestamps(id): def test_vectorized_udf_check_config(self): from pyspark.sql.functions import pandas_udf, col import pandas as pd - orig_value = self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None) - self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3) - try: + with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}): df = self.spark.range(10, numPartitions=1) @pandas_udf(returnType=LongType()) @@ -4348,11 +4334,6 @@ def check_records_per_batch(x): result = df.select(check_records_per_batch(col("id"))).collect() for (r,) in result: self.assertTrue(r <= 3) - finally: - if orig_value is None: - self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") - else: - self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value) def test_vectorized_udf_timestamps_respect_session_timezone(self): from pyspark.sql.functions import pandas_udf, col @@ -4371,30 +4352,27 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self): internal_value = pandas_udf( lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType()) - orig_tz = self.spark.conf.get("spark.sql.session.timeZone") - try: - timezone = "America/New_York" - self.spark.conf.set("spark.sql.session.timeZone", timezone) - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") - try: - df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ - .withColumn("internal_value", internal_value(col("timestamp"))) - result_la = df_la.select(col("idx"), col("internal_value")).collect() - # Correct result_la by adjusting 3 hours difference between Los Angeles and New York - diff = 3 * 60 * 60 * 1000 * 1000 * 1000 - result_la_corrected = \ - df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect() - finally: - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + timezone = "America/New_York" + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": False, + "spark.sql.session.timeZone": timezone}): + df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ + .withColumn("internal_value", internal_value(col("timestamp"))) + result_la = df_la.select(col("idx"), col("internal_value")).collect() + # Correct result_la by adjusting 3 hours difference between Los Angeles and New York + diff = 3 * 60 * 60 * 1000 * 1000 * 1000 + result_la_corrected = \ + df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect() + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": True, + "spark.sql.session.timeZone": timezone}): df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ .withColumn("internal_value", internal_value(col("timestamp"))) result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect() self.assertNotEqual(result_ny, result_la) self.assertEqual(result_ny, result_la_corrected) - finally: - self.spark.conf.set("spark.sql.session.timeZone", orig_tz) def test_nondeterministic_vectorized_udf(self): # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations @@ -5170,9 +5148,7 @@ def test_complex_expressions(self): def test_retain_group_columns(self): from pyspark.sql.functions import sum, lit, col - orig_value = self.spark.conf.get("spark.sql.retainGroupColumns", None) - self.spark.conf.set("spark.sql.retainGroupColumns", False) - try: + with self.sql_conf({"spark.sql.retainGroupColumns": False}): df = self.data sum_udf = self.pandas_agg_sum_udf @@ -5180,12 +5156,6 @@ def test_retain_group_columns(self): expected1 = df.groupby(df.id).agg(sum(df.v)) self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - finally: - if orig_value is None: - self.spark.conf.unset("spark.sql.retainGroupColumns") - else: - self.spark.conf.set("spark.sql.retainGroupColumns", orig_value) - def test_invalid_args(self): from pyspark.sql.functions import mean From 5e7bc2acef4a1e11d0d8056ef5c12cd5c8f220da Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 20 Mar 2018 10:34:56 -0700 Subject: [PATCH 0494/2461] [SPARK-23649][SQL] Skipping chars disallowed in UTF-8 ## What changes were proposed in this pull request? The mapping of UTF-8 char's first byte to char's size doesn't cover whole range 0-255. It is defined only for 0-253: https://github.com/apache/spark/blob/master/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java#L60-L65 https://github.com/apache/spark/blob/master/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java#L190 If the first byte of a char is 253-255, IndexOutOfBoundsException is thrown. Besides of that values for 244-252 are not correct according to recent unicode standard for UTF-8: http://www.unicode.org/versions/Unicode10.0.0/UnicodeStandard-10.0.pdf As a consequence of the exception above, the length of input string in UTF-8 encoding cannot be calculated if the string contains chars started from 253 code. It is visible on user's side as for example crashing of schema inferring of csv file which contains such chars but the file can be read if the schema is specified explicitly or if the mode set to multiline. The proposed changes build correct mapping of first byte of UTF-8 char to its size (now it covers all cases) and skip disallowed chars (counts it as one octet). ## How was this patch tested? Added a test and a file with a char which is disallowed in UTF-8 - 0xFF. Author: Maxim Gekk Closes #20796 from MaxGekk/skip-wrong-utf8-chars. --- .../apache/spark/unsafe/types/UTF8String.java | 48 +++++++++++++++---- .../spark/unsafe/types/UTF8StringSuite.java | 23 ++++++++- 2 files changed, 62 insertions(+), 9 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index b0d0c44823e68..5d468aed42337 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -57,12 +57,43 @@ public final class UTF8String implements Comparable, Externalizable, public Object getBaseObject() { return base; } public long getBaseOffset() { return offset; } - private static int[] bytesOfCodePointInUTF8 = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4, - 5, 5, 5, 5, - 6, 6}; + /** + * A char in UTF-8 encoding can take 1-4 bytes depending on the first byte which + * indicates the size of the char. See Unicode standard in page 126, Table 3-6: + * http://www.unicode.org/versions/Unicode10.0.0/UnicodeStandard-10.0.pdf + * + * Binary Hex Comments + * 0xxxxxxx 0x00..0x7F Only byte of a 1-byte character encoding + * 10xxxxxx 0x80..0xBF Continuation bytes (1-3 continuation bytes) + * 110xxxxx 0xC0..0xDF First byte of a 2-byte character encoding + * 1110xxxx 0xE0..0xEF First byte of a 3-byte character encoding + * 11110xxx 0xF0..0xF7 First byte of a 4-byte character encoding + * + * As a consequence of the well-formedness conditions specified in + * Table 3-7 (page 126), the following byte values are disallowed in UTF-8: + * C0–C1, F5–FF. + */ + private static byte[] bytesOfCodePointInUTF8 = { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x00..0x0F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x10..0x1F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x20..0x2F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x30..0x3F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x40..0x4F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x50..0x5F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x60..0x6F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x70..0x7F + // Continuation bytes cannot appear as the first byte + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x80..0x8F + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x90..0x9F + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0xA0..0xAF + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0xB0..0xBF + 0, 0, // 0xC0..0xC1 - disallowed in UTF-8 + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // 0xC2..0xCF + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // 0xD0..0xDF + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, // 0xE0..0xEF + 4, 4, 4, 4, 4, // 0xF0..0xF4 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 // 0xF5..0xFF - disallowed in UTF-8 + }; private static final boolean IS_LITTLE_ENDIAN = ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN; @@ -187,8 +218,9 @@ public void writeTo(OutputStream out) throws IOException { * @param b The first byte of a code point */ private static int numBytesForFirstByte(final byte b) { - final int offset = (b & 0xFF) - 192; - return (offset >= 0) ? bytesOfCodePointInUTF8[offset] : 1; + final int offset = b & 0xFF; + byte numBytes = bytesOfCodePointInUTF8[offset]; + return (numBytes == 0) ? 1: numBytes; // Skip the first byte disallowed in UTF-8 } /** diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 9b303fa5bc6c5..7c34d419574ef 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -58,8 +58,12 @@ private static void checkBasic(String str, int len) { @Test public void basicTest() { checkBasic("", 0); - checkBasic("hello", 5); + checkBasic("¡", 1); // 2 bytes char + checkBasic("ку", 2); // 2 * 2 bytes chars + checkBasic("hello", 5); // 5 * 1 byte chars checkBasic("大 千 世 界", 7); + checkBasic("︽﹋%", 3); // 3 * 3 bytes chars + checkBasic("\uD83E\uDD19", 1); // 4 bytes char } @Test @@ -791,4 +795,21 @@ public void trimRightWithTrimString() { assertEquals(fromString("头"), fromString("头a???/").trimRight(fromString("数?/*&^%a"))); assertEquals(fromString("头"), fromString("头数b数数 [").trimRight(fromString(" []数b"))); } + + @Test + public void skipWrongFirstByte() { + int[] wrongFirstBytes = { + 0x80, 0x9F, 0xBF, // Skip Continuation bytes + 0xC0, 0xC2, // 0xC0..0xC1 - disallowed in UTF-8 + // 0xF5..0xFF - disallowed in UTF-8 + 0xF5, 0xF6, 0xF7, 0xF8, 0xF9, + 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, 0xFF + }; + byte[] c = new byte[1]; + + for (int i = 0; i < wrongFirstBytes.length; ++i) { + c[0] = (byte)wrongFirstBytes[i]; + assertEquals(fromBytes(c).numChars(), 1); + } + } } From 7f5e8aa2606b0ee0297ceb6f4603bd368e3b0291 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 20 Mar 2018 11:14:34 -0700 Subject: [PATCH 0495/2461] [SPARK-21898][ML] Feature parity for KolmogorovSmirnovTest in MLlib ## What changes were proposed in this pull request? Feature parity for KolmogorovSmirnovTest in MLlib. Implement `DataFrame` interface for `KolmogorovSmirnovTest` in `mllib.stat`. ## How was this patch tested? Test suite added. Author: WeichenXu Author: jkbradley Closes #19108 from WeichenXu123/ml-ks-test. --- .../spark/ml/stat/KolmogorovSmirnovTest.scala | 113 ++++++++++++++ .../stat/JavaKolmogorovSmirnovTestSuite.java | 84 +++++++++++ .../ml/stat/KolmogorovSmirnovTestSuite.scala | 140 ++++++++++++++++++ 3 files changed, 337 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala create mode 100644 mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala new file mode 100644 index 0000000000000..8d80e7768cb6e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat + +import scala.annotation.varargs + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.function.Function +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.mllib.stat.{Statistics => OldStatistics} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.col + +/** + * :: Experimental :: + * + * Conduct the two-sided Kolmogorov Smirnov (KS) test for data sampled from a + * continuous distribution. By comparing the largest difference between the empirical cumulative + * distribution of the sample data and the theoretical distribution we can provide a test for the + * the null hypothesis that the sample data comes from that theoretical distribution. + * For more information on KS Test: + * @see + * Kolmogorov-Smirnov test (Wikipedia) + */ +@Experimental +@Since("2.4.0") +object KolmogorovSmirnovTest { + + /** Used to construct output schema of test */ + private case class KolmogorovSmirnovTestResult( + pValue: Double, + statistic: Double) + + private def getSampleRDD(dataset: DataFrame, sampleCol: String): RDD[Double] = { + SchemaUtils.checkNumericType(dataset.schema, sampleCol) + import dataset.sparkSession.implicits._ + dataset.select(col(sampleCol).cast("double")).as[Double].rdd + } + + /** + * Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a + * continuous distribution. By comparing the largest difference between the empirical cumulative + * distribution of the sample data and the theoretical distribution we can provide a test for the + * the null hypothesis that the sample data comes from that theoretical distribution. + * + * @param dataset a `DataFrame` containing the sample of data to test + * @param sampleCol Name of sample column in dataset, of any numerical type + * @param cdf a `Double => Double` function to calculate the theoretical CDF at a given value + * @return DataFrame containing the test result for the input sampled data. + * This DataFrame will contain a single Row with the following fields: + * - `pValue: Double` + * - `statistic: Double` + */ + @Since("2.4.0") + def test(dataset: DataFrame, sampleCol: String, cdf: Double => Double): DataFrame = { + val spark = dataset.sparkSession + + val rdd = getSampleRDD(dataset, sampleCol) + val testResult = OldStatistics.kolmogorovSmirnovTest(rdd, cdf) + spark.createDataFrame(Seq(KolmogorovSmirnovTestResult( + testResult.pValue, testResult.statistic))) + } + + /** + * Java-friendly version of `test(dataset: DataFrame, sampleCol: String, cdf: Double => Double)` + */ + @Since("2.4.0") + def test(dataset: DataFrame, sampleCol: String, + cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = { + test(dataset, sampleCol, (x: Double) => cdf.call(x)) + } + + /** + * Convenience function to conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability + * distribution equality. Currently supports the normal distribution, taking as parameters + * the mean and standard deviation. + * + * @param dataset a `DataFrame` containing the sample of data to test + * @param sampleCol Name of sample column in dataset, of any numerical type + * @param distName a `String` name for a theoretical distribution, currently only support "norm". + * @param params `Double*` specifying the parameters to be used for the theoretical distribution + * @return DataFrame containing the test result for the input sampled data. + * This DataFrame will contain a single Row with the following fields: + * - `pValue: Double` + * - `statistic: Double` + */ + @Since("2.4.0") + @varargs + def test(dataset: DataFrame, sampleCol: String, distName: String, params: Double*): DataFrame = { + val spark = dataset.sparkSession + + val rdd = getSampleRDD(dataset, sampleCol) + val testResult = OldStatistics.kolmogorovSmirnovTest(rdd, distName, params: _*) + spark.createDataFrame(Seq(KolmogorovSmirnovTestResult( + testResult.pValue, testResult.statistic))) + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java b/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java new file mode 100644 index 0000000000000..021272dd5a40c --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.junit.Test; + +import org.apache.spark.SharedSparkSession; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; + + +public class JavaKolmogorovSmirnovTestSuite extends SharedSparkSession { + + private transient Dataset dataset; + + @Override + public void setUp() throws IOException { + super.setUp(); + List points = Arrays.asList(0.1, 1.1, 10.1, -1.1); + + dataset = spark.createDataset(points, Encoders.DOUBLE()).toDF("sample"); + } + + @Test + public void testKSTestCDF() { + // Create theoretical distributions + NormalDistribution stdNormalDist = new NormalDistribution(0, 1); + + // set seeds + Long seed = 10L; + stdNormalDist.reseedRandomGenerator(seed); + Function stdNormalCDF = (x) -> stdNormalDist.cumulativeProbability(x); + + double pThreshold = 0.05; + + // Comparing a standard normal sample to a standard normal distribution + Row results = KolmogorovSmirnovTest + .test(dataset, "sample", stdNormalCDF).head(); + double pValue1 = results.getDouble(0); + // Cannot reject null hypothesis + assert(pValue1 > pThreshold); + } + + @Test + public void testKSTestNamedDistribution() { + double pThreshold = 0.05; + + // Comparing a standard normal sample to a standard normal distribution + Row results = KolmogorovSmirnovTest + .test(dataset, "sample", "norm", 0.0, 1.0).head(); + double pValue1 = results.getDouble(0); + // Cannot reject null hypothesis + assert(pValue1 > pThreshold); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala new file mode 100644 index 0000000000000..1312de3a1b522 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat + +import org.apache.commons.math3.distribution.{ExponentialDistribution, NormalDistribution, + RealDistribution, UniformRealDistribution} +import org.apache.commons.math3.stat.inference.{KolmogorovSmirnovTest => Math3KSTest} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row + +class KolmogorovSmirnovTestSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + + def apacheCommonMath3EquivalenceTest( + sampleDist: RealDistribution, + theoreticalDist: RealDistribution, + theoreticalDistByName: (String, Array[Double]), + rejectNullHypothesis: Boolean): Unit = { + + // set seeds + val seed = 10L + sampleDist.reseedRandomGenerator(seed) + if (theoreticalDist != null) { + theoreticalDist.reseedRandomGenerator(seed) + } + + // Sample data from the distributions and parallelize it + val n = 100000 + val sampledArray = sampleDist.sample(n) + val sampledDF = sc.parallelize(sampledArray, 10).toDF("sample") + + // Use a apache math commons local KS test to verify calculations + val ksTest = new Math3KSTest() + val pThreshold = 0.05 + + // Comparing a standard normal sample to a standard normal distribution + val Row(pValue1: Double, statistic1: Double) = + if (theoreticalDist != null) { + val cdf = (x: Double) => theoreticalDist.cumulativeProbability(x) + KolmogorovSmirnovTest.test(sampledDF, "sample", cdf).head() + } else { + KolmogorovSmirnovTest.test(sampledDF, "sample", + theoreticalDistByName._1, + theoreticalDistByName._2: _* + ).head() + } + val theoreticalDistMath3 = if (theoreticalDist == null) { + assert(theoreticalDistByName._1 == "norm") + val params = theoreticalDistByName._2 + new NormalDistribution(params(0), params(1)) + } else { + theoreticalDist + } + val referenceStat1 = ksTest.kolmogorovSmirnovStatistic(theoreticalDistMath3, sampledArray) + val referencePVal1 = 1 - ksTest.cdf(referenceStat1, n) + // Verify vs apache math commons ks test + assert(statistic1 ~== referenceStat1 relTol 1e-4) + assert(pValue1 ~== referencePVal1 relTol 1e-4) + + if (rejectNullHypothesis) { + assert(pValue1 < pThreshold) + } else { + assert(pValue1 > pThreshold) + } + } + + test("1 sample Kolmogorov-Smirnov test: apache commons math3 implementation equivalence") { + // Create theoretical distributions + val stdNormalDist = new NormalDistribution(0.0, 1.0) + val expDist = new ExponentialDistribution(0.6) + val uniformDist = new UniformRealDistribution(0.0, 1.0) + val expDist2 = new ExponentialDistribution(0.2) + val stdNormByName = Tuple2("norm", Array(0.0, 1.0)) + + apacheCommonMath3EquivalenceTest(stdNormalDist, null, stdNormByName, false) + apacheCommonMath3EquivalenceTest(expDist, null, stdNormByName, true) + apacheCommonMath3EquivalenceTest(uniformDist, null, stdNormByName, true) + apacheCommonMath3EquivalenceTest(expDist, expDist2, null, true) + } + + test("1 sample Kolmogorov-Smirnov test: R implementation equivalence") { + /* + Comparing results with R's implementation of Kolmogorov-Smirnov for 1 sample + > sessionInfo() + R version 3.2.0 (2015-04-16) + Platform: x86_64-apple-darwin13.4.0 (64-bit) + > set.seed(20) + > v <- rnorm(20) + > v + [1] 1.16268529 -0.58592447 1.78546500 -1.33259371 -0.44656677 0.56960612 + [7] -2.88971761 -0.86901834 -0.46170268 -0.55554091 -0.02013537 -0.15038222 + [13] -0.62812676 1.32322085 -1.52135057 -0.43742787 0.97057758 0.02822264 + [19] -0.08578219 0.38921440 + > ks.test(v, pnorm, alternative = "two.sided") + + One-sample Kolmogorov-Smirnov test + + data: v + D = 0.18874, p-value = 0.4223 + alternative hypothesis: two-sided + */ + + val rKSStat = 0.18874 + val rKSPVal = 0.4223 + val rData = sc.parallelize( + Array( + 1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501, + -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555, + -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063, + -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691, + 0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942 + ) + ).toDF("sample") + val Row(pValue: Double, statistic: Double) = KolmogorovSmirnovTest + .test(rData, "sample", "norm", 0, 1).head() + assert(statistic ~== rKSStat relTol 1e-4) + assert(pValue ~== rKSPVal relTol 1e-4) + } +} From 2c4b9962fdf8c1beb66126ca41628c72eb6c2383 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 20 Mar 2018 11:46:51 -0700 Subject: [PATCH 0496/2461] [SPARK-23574][SQL] Report SinglePartition in DataSourceV2ScanExec when there's exactly 1 data reader factory. ## What changes were proposed in this pull request? Report SinglePartition in DataSourceV2ScanExec when there's exactly 1 data reader factory. Note that this means reader factories end up being constructed as partitioning is checked; let me know if you think that could be a problem. ## How was this patch tested? existing unit tests Author: Jose Torres Author: Jose Torres Closes #20726 from jose-torres/SPARK-23574. --- .../v2/reader/SupportsReportPartitioning.java | 3 ++ .../datasources/v2/DataSourceRDD.scala | 4 +-- .../datasources/v2/DataSourceV2ScanExec.scala | 29 ++++++++++++++----- .../ContinuousDataSourceRDDIter.scala | 4 +-- .../sql/sources/v2/DataSourceV2Suite.scala | 20 ++++++++++++- .../sql/streaming/StreamingQuerySuite.scala | 4 +-- 6 files changed, 50 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index 5405a916951b8..607628746e873 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -23,6 +23,9 @@ /** * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report data partitioning and try to avoid shuffle at Spark side. + * + * Note that, when the reader creates exactly one {@link DataReaderFactory}, Spark may avoid + * adding a shuffle even if the reader does not implement this interface. */ @InterfaceStability.Evolving public interface SupportsReportPartitioning extends DataSourceReader { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 5ed0ba71e94c7..f85971be394b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -29,11 +29,11 @@ class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: Da class DataSourceRDD[T: ClassTag]( sc: SparkContext, - @transient private val readerFactories: java.util.List[DataReaderFactory[T]]) + @transient private val readerFactories: Seq[DataReaderFactory[T]]) extends RDD[T](sc, Nil) { override protected def getPartitions: Array[Partition] = { - readerFactories.asScala.zipWithIndex.map { + readerFactories.zipWithIndex.map { case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) }.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index cb691ba297076..3a5e7bf89e142 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -25,12 +25,14 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.catalyst.plans.physical.SinglePartition import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch /** * Physical plan node for scanning data from a data source. @@ -56,6 +58,15 @@ case class DataSourceV2ScanExec( } override def outputPartitioning: physical.Partitioning = reader match { + case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchReaderFactories.size == 1 => + SinglePartition + + case r: SupportsScanColumnarBatch if !r.enableBatchRead() && readerFactories.size == 1 => + SinglePartition + + case r if !r.isInstanceOf[SupportsScanColumnarBatch] && readerFactories.size == 1 => + SinglePartition + case s: SupportsReportPartitioning => new DataSourcePartitioning( s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name))) @@ -63,29 +74,33 @@ case class DataSourceV2ScanExec( case _ => super.outputPartitioning } - private lazy val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]] = reader match { - case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories() + private lazy val readerFactories: Seq[DataReaderFactory[UnsafeRow]] = reader match { + case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories().asScala case _ => reader.createDataReaderFactories().asScala.map { new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow] - }.asJava + } } - private lazy val inputRDD: RDD[InternalRow] = reader match { + private lazy val batchReaderFactories: Seq[DataReaderFactory[ColumnarBatch]] = reader match { case r: SupportsScanColumnarBatch if r.enableBatchRead() => assert(!reader.isInstanceOf[ContinuousReader], "continuous stream reader does not support columnar read yet.") - new DataSourceRDD(sparkContext, r.createBatchDataReaderFactories()) - .asInstanceOf[RDD[InternalRow]] + r.createBatchDataReaderFactories().asScala + } + private lazy val inputRDD: RDD[InternalRow] = reader match { case _: ContinuousReader => EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) - .askSync[Unit](SetReaderPartitions(readerFactories.size())) + .askSync[Unit](SetReaderPartitions(readerFactories.size)) new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories) .asInstanceOf[RDD[InternalRow]] + case r: SupportsScanColumnarBatch if r.enableBatchRead() => + new DataSourceRDD(sparkContext, batchReaderFactories).asInstanceOf[RDD[InternalRow]] + case _ => new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index cf02c0dda25d7..06754f01657d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -35,14 +35,14 @@ import org.apache.spark.util.ThreadUtils class ContinuousDataSourceRDD( sc: SparkContext, sqlContext: SQLContext, - @transient private val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]]) + @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]]) extends RDD[UnsafeRow](sc, Nil) { private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs override protected def getPartitions: Array[Partition] = { - readerFactories.asScala.zipWithIndex.map { + readerFactories.zipWithIndex.map { case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) }.toArray } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 1157a350461d8..e0a53272cd222 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -25,7 +25,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec} -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.functions._ import org.apache.spark.sql.sources.{Filter, GreaterThan} @@ -191,6 +191,11 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } + test("SPARK-23574: no shuffle exchange with single partition") { + val df = spark.read.format(classOf[SimpleSinglePartitionSource].getName).load().agg(count("*")) + assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.isEmpty) + } + test("simple writable data source") { // TODO: java implementation. Seq(classOf[SimpleWritableDataSource]).foreach { cls => @@ -336,6 +341,19 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } +class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceReader { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5)) + } + } + + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader +} + class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 3f9aa0d1fa5be..ebc9a87b23f84 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -326,9 +326,9 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.durationMs.get("setOffsetRange") === 50) assert(progress.durationMs.get("getEndOffset") === 100) - assert(progress.durationMs.get("queryPlanning") === 0) + assert(progress.durationMs.get("queryPlanning") === 200) assert(progress.durationMs.get("walCommit") === 0) - assert(progress.durationMs.get("addBatch") === 350) + assert(progress.durationMs.get("addBatch") === 150) assert(progress.durationMs.get("triggerExecution") === 500) assert(progress.sources.length === 1) From 477d6bd7265e255fd43e53edda02019b32f29bb2 Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Tue, 20 Mar 2018 13:27:50 -0700 Subject: [PATCH 0497/2461] [SPARK-23500][SQL] Fix complex type simplification rules to apply to entire plan ## What changes were proposed in this pull request? Complex type simplification optimizer rules were not applied to the entire plan, just the expressions reachable from the root node. This patch fixes the rules to transform the entire plan. ## How was this patch tested? New unit test + ran sql / core tests. Author: Henry Robinson Author: Henry Robinson Closes #20687 from henryr/spark-25000. --- .../sql/catalyst/optimizer/ComplexTypes.scala | 61 ++++++++----------- .../sql/catalyst/optimizer/Optimizer.scala | 4 +- .../optimizer/complexTypesSuite.scala | 55 +++++++++++++++-- 3 files changed, 76 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala index be0009ec8c760..db7d6d3254bd2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala @@ -18,39 +18,39 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule /** -* push down operations into [[CreateNamedStructLike]]. -*/ -object SimplifyCreateStructOps extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = { - plan.transformExpressionsUp { - // push down field extraction + * Simplify redundant [[CreateNamedStructLike]], [[CreateArray]] and [[CreateMap]] expressions. + */ +object SimplifyExtractValueOps extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // One place where this optimization is invalid is an aggregation where the select + // list expression is a function of a grouping expression: + // + // SELECT struct(a,b).a FROM tbl GROUP BY struct(a,b) + // + // cannot be simplified to SELECT a FROM tbl GROUP BY struct(a,b). So just skip this + // optimization for Aggregates (although this misses some cases where the optimization + // can be made). + case a: Aggregate => a + case p => p.transformExpressionsUp { + // Remove redundant field extraction. case GetStructField(createNamedStructLike: CreateNamedStructLike, ordinal, _) => createNamedStructLike.valExprs(ordinal) - } - } -} -/** -* push down operations into [[CreateArray]]. -*/ -object SimplifyCreateArrayOps extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = { - plan.transformExpressionsUp { - // push down field selection (array of structs) - case GetArrayStructFields(CreateArray(elems), field, ordinal, numFields, containsNull) => - // instead f selecting the field on the entire array, - // select it from each member of the array. - // pushing down the operation this way open other optimizations opportunities - // (i.e. struct(...,x,...).x) + // Remove redundant array indexing. + case GetArrayStructFields(CreateArray(elems), field, ordinal, _, _) => + // Instead of selecting the field on the entire array, select it from each member + // of the array. Pushing down the operation this way may open other optimizations + // opportunities (i.e. struct(...,x,...).x) CreateArray(elems.map(GetStructField(_, ordinal, Some(field.name)))) - // push down item selection. + + // Remove redundant map lookup. case ga @ GetArrayItem(CreateArray(elems), IntegerLiteral(idx)) => - // instead of creating the array and then selecting one row, - // remove array creation altgether. + // Instead of creating the array and then selecting one row, remove array creation + // altogether. if (idx >= 0 && idx < elems.size) { // valid index elems(idx) @@ -58,18 +58,7 @@ object SimplifyCreateArrayOps extends Rule[LogicalPlan] { // out of bounds, mimic the runtime behavior and return null Literal(null, ga.dataType) } - } - } -} - -/** -* push down operations into [[CreateMap]]. -*/ -object SimplifyCreateMapOps extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = { - plan.transformExpressionsUp { case GetMapValue(CreateMap(elems), key) => CaseKeyWhen(key, elems) } } } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 91208479be03b..2829d1d81eb1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -85,9 +85,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) EliminateSerialization, RemoveRedundantAliases, RemoveRedundantProject, - SimplifyCreateStructOps, - SimplifyCreateArrayOps, - SimplifyCreateMapOps, + SimplifyExtractValueOps, CombineConcats) ++ extendedOperatorOptimizationRules diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index de544ac314789..e44a6692ad8e2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -44,14 +44,13 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { BooleanSimplification, SimplifyConditionals, SimplifyBinaryComparison, - SimplifyCreateStructOps, - SimplifyCreateArrayOps, - SimplifyCreateMapOps) :: Nil + SimplifyExtractValueOps) :: Nil } val idAtt = ('id).long.notNull + val nullableIdAtt = ('nullable_id).long - lazy val relation = LocalRelation(idAtt ) + lazy val relation = LocalRelation(idAtt, nullableIdAtt) test("explicit get from namedStruct") { val query = relation @@ -321,7 +320,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select( CaseWhen(Seq( (EqualTo(2L, 'id), ('id + 1L)), - // these two are possible matches, we can't tell untill runtime + // these two are possible matches, we can't tell until runtime (EqualTo(2L, ('id + 1L)), ('id + 2L)), (EqualTo(2L, 'id + 2L), Literal.create(null, LongType)), // this is a definite match (two constants), @@ -331,4 +330,50 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .analyze comparePlans(Optimizer execute rel, expected) } + + test("SPARK-23500: Simplify complex ops that aren't at the plan root") { + val structRel = relation + .select(GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None) as "foo") + .groupBy($"foo")("1").analyze + val structExpected = relation + .select('nullable_id as "foo") + .groupBy($"foo")("1").analyze + comparePlans(Optimizer execute structRel, structExpected) + + // These tests must use nullable attributes from the base relation for the following reason: + // in the 'original' plans below, the Aggregate node produced by groupBy() has a + // nullable AttributeReference to a1, because both array indexing and map lookup are + // nullable expressions. After optimization, the same attribute is now non-nullable, + // but the AttributeReference is not updated to reflect this. In the 'expected' plans, + // the grouping expressions have the same nullability as the original attribute in the + // relation. If that attribute is non-nullable, the tests will fail as the plans will + // compare differently, so for these tests we must use a nullable attribute. See + // SPARK-23634. + val arrayRel = relation + .select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1") + .groupBy($"a1")("1").analyze + val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1").analyze + comparePlans(Optimizer execute arrayRel, arrayExpected) + + val mapRel = relation + .select(GetMapValue(CreateMap(Seq("id", 'nullable_id)), "id") as "m1") + .groupBy($"m1")("1").analyze + val mapExpected = relation + .select('nullable_id as "m1") + .groupBy($"m1")("1").analyze + comparePlans(Optimizer execute mapRel, mapExpected) + } + + test("SPARK-23500: Ensure that aggregation expressions are not simplified") { + // Make sure that aggregation exprs are correctly ignored. Maps can't be used in + // grouping exprs so aren't tested here. + val structAggRel = relation.groupBy( + CreateNamedStruct(Seq("att1", 'nullable_id)))( + GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)).analyze + comparePlans(Optimizer execute structAggRel, structAggRel) + + val arrayAggRel = relation.groupBy( + CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)).analyze + comparePlans(Optimizer execute arrayAggRel, arrayAggRel) + } } From 983e8d9d64b6b1304c43ea6e5dffdc1078138ef9 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 20 Mar 2018 23:17:49 -0700 Subject: [PATCH 0498/2461] [SPARK-23666][SQL] Do not display exprIds of Alias in user-facing info. ## What changes were proposed in this pull request? To drop `exprId`s for `Alias` in user-facing info., this pr added an entry for `Alias` in `NonSQLExpression.sql` ## How was this patch tested? Added tests in `UDFSuite`. Author: Takeshi Yamamuro Closes #20827 from maropu/SPARK-23666. --- docs/sql-programming-guide.md | 1 + .../sql/catalyst/expressions/Expression.scala | 1 + .../scala/org/apache/spark/sql/UDFSuite.scala | 132 ++++++++++-------- 3 files changed, 78 insertions(+), 56 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 0e092e0e37ccf..5b47fd77f2cbc 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1806,6 +1806,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unabled to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. + - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index d7f9e38915dd5..38caf67d465d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -288,6 +288,7 @@ trait NonSQLExpression extends Expression { final override def sql: String = { transform { case a: Attribute => new PrettyAttribute(a) + case a: Alias => PrettyAttribute(a.sql, a.dataType) }.toString } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index af6a10b425b9f..21afdc7e2a33f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -144,73 +144,81 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDF in a WHERE") { - spark.udf.register("oneArgFilter", (n: Int) => { n > 80 }) + withTempView("integerData") { + spark.udf.register("oneArgFilter", (n: Int) => { n > 80 }) - val df = sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))).toDF() - df.createOrReplaceTempView("integerData") + val df = sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toDF() + df.createOrReplaceTempView("integerData") - val result = - sql("SELECT * FROM integerData WHERE oneArgFilter(key)") - assert(result.count() === 20) + val result = + sql("SELECT * FROM integerData WHERE oneArgFilter(key)") + assert(result.count() === 20) + } } test("UDF in a HAVING") { - spark.udf.register("havingFilter", (n: Long) => { n > 5 }) - - val df = Seq(("red", 1), ("red", 2), ("blue", 10), - ("green", 100), ("green", 200)).toDF("g", "v") - df.createOrReplaceTempView("groupData") - - val result = - sql( - """ - | SELECT g, SUM(v) as s - | FROM groupData - | GROUP BY g - | HAVING havingFilter(s) - """.stripMargin) - - assert(result.count() === 2) + withTempView("groupData") { + spark.udf.register("havingFilter", (n: Long) => { n > 5 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.createOrReplaceTempView("groupData") + + val result = + sql( + """ + | SELECT g, SUM(v) as s + | FROM groupData + | GROUP BY g + | HAVING havingFilter(s) + """.stripMargin) + + assert(result.count() === 2) + } } test("UDF in a GROUP BY") { - spark.udf.register("groupFunction", (n: Int) => { n > 10 }) - - val df = Seq(("red", 1), ("red", 2), ("blue", 10), - ("green", 100), ("green", 200)).toDF("g", "v") - df.createOrReplaceTempView("groupData") - - val result = - sql( - """ - | SELECT SUM(v) - | FROM groupData - | GROUP BY groupFunction(v) - """.stripMargin) - assert(result.count() === 2) + withTempView("groupData") { + spark.udf.register("groupFunction", (n: Int) => { n > 10 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.createOrReplaceTempView("groupData") + + val result = + sql( + """ + | SELECT SUM(v) + | FROM groupData + | GROUP BY groupFunction(v) + """.stripMargin) + assert(result.count() === 2) + } } test("UDFs everywhere") { - spark.udf.register("groupFunction", (n: Int) => { n > 10 }) - spark.udf.register("havingFilter", (n: Long) => { n > 2000 }) - spark.udf.register("whereFilter", (n: Int) => { n < 150 }) - spark.udf.register("timesHundred", (n: Long) => { n * 100 }) - - val df = Seq(("red", 1), ("red", 2), ("blue", 10), - ("green", 100), ("green", 200)).toDF("g", "v") - df.createOrReplaceTempView("groupData") - - val result = - sql( - """ - | SELECT timesHundred(SUM(v)) as v100 - | FROM groupData - | WHERE whereFilter(v) - | GROUP BY groupFunction(v) - | HAVING havingFilter(v100) - """.stripMargin) - assert(result.count() === 1) + withTempView("groupData") { + spark.udf.register("groupFunction", (n: Int) => { n > 10 }) + spark.udf.register("havingFilter", (n: Long) => { n > 2000 }) + spark.udf.register("whereFilter", (n: Int) => { n < 150 }) + spark.udf.register("timesHundred", (n: Long) => { n * 100 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.createOrReplaceTempView("groupData") + + val result = + sql( + """ + | SELECT timesHundred(SUM(v)) as v100 + | FROM groupData + | WHERE whereFilter(v) + | GROUP BY groupFunction(v) + | HAVING havingFilter(v100) + """.stripMargin) + assert(result.count() === 1) + } } test("struct UDF") { @@ -304,4 +312,16 @@ class UDFSuite extends QueryTest with SharedSQLContext { assert(explainStr(spark.range(1).select(udf1(udf2(functions.lit(1))))) .contains(s"UDF:$udf1Name(UDF:$udf2Name(1))")) } + + test("SPARK-23666 Do not display exprId in argument names") { + withTempView("x") { + Seq(((1, 2), 3)).toDF("a", "b").createOrReplaceTempView("x") + spark.udf.register("f", (a: Int) => a) + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + spark.sql("SELECT f(a._1) FROM x").show + } + assert(outputStream.toString.contains("UDF:f(a._1 AS `_1`)")) + } + } } From 500b21c3d6247015e550be7e144e9b4b26fe28be Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 21 Mar 2018 10:19:02 -0500 Subject: [PATCH 0499/2461] [SPARK-23568][ML] Use metadata numAttributes if available in Silhouette ## What changes were proposed in this pull request? Silhouette need to know the number of features. This was taken using `first` and checking the size of the vector. Despite this works fine, if the number of attributes is present in metadata, we can avoid to trigger a job for this and use the metadata value. This can help improving performances of course. ## How was this patch tested? existing UTs + added UT Author: Marco Gaido Closes #20719 from mgaido91/SPARK-23568. --- .../ml/evaluation/ClusteringEvaluator.scala | 22 ++++++++++++++---- .../evaluation/ClusteringEvaluatorSuite.scala | 23 ++++++++++++++++++- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala index 8d4ae562b3d2b..4353c46781e9d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} @@ -170,6 +171,15 @@ private[evaluation] abstract class Silhouette { def overallScore(df: DataFrame, scoreColumn: Column): Double = { df.select(avg(scoreColumn)).collect()(0).getDouble(0) } + + protected def getNumberOfFeatures(dataFrame: DataFrame, columnName: String): Int = { + val group = AttributeGroup.fromStructField(dataFrame.schema(columnName)) + if (group.size < 0) { + dataFrame.select(col(columnName)).first().getAs[Vector](0).size + } else { + group.size + } + } } /** @@ -360,7 +370,7 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette { df: DataFrame, predictionCol: String, featuresCol: String): Map[Double, ClusterStats] = { - val numFeatures = df.select(col(featuresCol)).first().getAs[Vector](0).size + val numFeatures = getNumberOfFeatures(df, featuresCol) val clustersStatsRDD = df.select( col(predictionCol).cast(DoubleType), col(featuresCol), col("squaredNorm")) .rdd @@ -552,8 +562,11 @@ private[evaluation] object CosineSilhouette extends Silhouette { * @return A [[scala.collection.immutable.Map]] which associates each cluster id to a * its statistics (ie. the precomputed values `N` and `$\Omega_{\Gamma}$`). */ - def computeClusterStats(df: DataFrame, predictionCol: String): Map[Double, (Vector, Long)] = { - val numFeatures = df.select(col(normalizedFeaturesColName)).first().getAs[Vector](0).size + def computeClusterStats( + df: DataFrame, + featuresCol: String, + predictionCol: String): Map[Double, (Vector, Long)] = { + val numFeatures = getNumberOfFeatures(df, featuresCol) val clustersStatsRDD = df.select( col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName)) .rdd @@ -626,7 +639,8 @@ private[evaluation] object CosineSilhouette extends Silhouette { normalizeFeatureUDF(col(featuresCol))) // compute aggregate values for clusters needed by the algorithm - val clustersStatsMap = computeClusterStats(dfWithNormalizedFeatures, predictionCol) + val clustersStatsMap = computeClusterStats(dfWithNormalizedFeatures, featuresCol, + predictionCol) // Silhouette is reasonable only when the number of clusters is greater then 1 assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.") diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala index 3bf34770f5687..2c175ff68e0b8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.ml.evaluation -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.ml.util.TestingUtils._ @@ -100,4 +102,23 @@ class ClusteringEvaluatorSuite } } + test("SPARK-23568: we should use metadata to determine features number") { + val attributesNum = irisDataset.select("features").rdd.first().getAs[Vector](0).size + val attrGroup = new AttributeGroup("features", attributesNum) + val df = irisDataset.select($"features".as("features", attrGroup.toMetadata()), $"label") + require(AttributeGroup.fromStructField(df.schema("features")) + .numAttributes.isDefined, "numAttributes metadata should be defined") + val evaluator = new ClusteringEvaluator() + .setFeaturesCol("features") + .setPredictionCol("label") + + // with the proper metadata we compute correctly the result + assert(evaluator.evaluate(df) ~== 0.6564679231 relTol 1e-5) + + val wrongAttrGroup = new AttributeGroup("features", attributesNum + 1) + val dfWrong = irisDataset.select($"features".as("features", wrongAttrGroup.toMetadata()), + $"label") + // with wrong metadata the evaluator throws an Exception + intercept[SparkException](evaluator.evaluate(dfWrong)) + } } From bf09f2f71276d3b3a84a8f89109bd785a066c3e6 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 21 Mar 2018 09:39:14 -0700 Subject: [PATCH 0500/2461] [SPARK-10884][ML] Support prediction on single instance for regression and classification related models ## What changes were proposed in this pull request? Support prediction on single instance for regression and classification related models (i.e., PredictionModel, ClassificationModel and their sub classes). Add corresponding test cases. ## How was this patch tested? Test cases added. Author: WeichenXu Closes #19381 from WeichenXu123/single_prediction. --- .../scala/org/apache/spark/ml/Predictor.scala | 5 ++-- .../spark/ml/classification/Classifier.scala | 6 ++--- .../DecisionTreeClassifier.scala | 2 +- .../ml/classification/GBTClassifier.scala | 2 +- .../spark/ml/classification/LinearSVC.scala | 2 +- .../classification/LogisticRegression.scala | 2 +- .../MultilayerPerceptronClassifier.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../spark/ml/regression/GBTRegressor.scala | 2 +- .../GeneralizedLinearRegression.scala | 2 +- .../ml/regression/LinearRegression.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 2 +- .../DecisionTreeClassifierSuite.scala | 17 ++++++++++++- .../classification/GBTClassifierSuite.scala | 9 +++++++ .../ml/classification/LinearSVCSuite.scala | 6 +++++ .../LogisticRegressionSuite.scala | 9 +++++++ .../MultilayerPerceptronClassifierSuite.scala | 12 ++++++++++ .../ml/classification/NaiveBayesSuite.scala | 22 +++++++++++++++++ .../RandomForestClassifierSuite.scala | 16 +++++++++++++ .../DecisionTreeRegressorSuite.scala | 15 ++++++++++++ .../ml/regression/GBTRegressorSuite.scala | 8 +++++++ .../GeneralizedLinearRegressionSuite.scala | 8 +++++++ .../ml/regression/LinearRegressionSuite.scala | 7 ++++++ .../RandomForestRegressorSuite.scala | 24 +++++++++++++++---- .../org/apache/spark/ml/util/MLTest.scala | 15 ++++++++++-- 25 files changed, 176 insertions(+), 23 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index 08b0cb9b8f6a5..d8f3dfa874439 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -219,7 +219,8 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, /** * Predict label for the given features. - * This internal method is used to implement `transform()` and output [[predictionCol]]. + * This method is used to implement `transform()` and output [[predictionCol]]. */ - protected def predict(features: FeaturesType): Double + @Since("2.4.0") + def predict(features: FeaturesType): Double } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 9d1d5aa1e0cff..7e5790ab70ee9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkException -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, VectorUDT} @@ -192,12 +192,12 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur /** * Predict label for the given features. - * This internal method is used to implement `transform()` and output [[predictionCol]]. + * This method is used to implement `transform()` and output [[predictionCol]]. * * This default implementation for classification predicts the index of the maximum value * from `predictRaw()`. */ - override protected def predict(features: FeaturesType): Double = { + override def predict(features: FeaturesType): Double = { raw2prediction(predictRaw(features)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 9f60f0896ec52..65cce697d8202 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -181,7 +181,7 @@ class DecisionTreeClassificationModel private[ml] ( private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) = this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses) - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { rootNode.predictImpl(features).prediction } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index f11bc1d8fe415..cd44489f618b2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -267,7 +267,7 @@ class GBTClassificationModel private[ml]( dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { // If thresholds defined, use predictRaw to get probabilities, otherwise use optimization if (isDefined(thresholds)) { super.predict(features) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index ce400f4f1faf7..8f950cd28c3aa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -316,7 +316,7 @@ class LinearSVCModel private[classification] ( BLAS.dot(features, coefficients) + intercept } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { if (margin(features) > $(threshold)) 1.0 else 0.0 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index fa191604218db..3ae4db3f3f965 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1090,7 +1090,7 @@ class LogisticRegressionModel private[spark] ( * Predict label for the given feature vector. * The behavior of this can be adjusted using `thresholds`. */ - override protected def predict(features: Vector): Double = if (isMultinomial) { + override def predict(features: Vector): Double = if (isMultinomial) { super.predict(features) } else { // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index fd4c98f22132f..af2e4699924e5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -322,7 +322,7 @@ class MultilayerPerceptronClassificationModel private[ml] ( * Predict label for the given features. * This internal method is used to implement `transform()` and output [[predictionCol]]. */ - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { LabelConverter.decodeLabel(mlpModel.predict(features)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 0291a57487c47..ad154fcd010cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -178,7 +178,7 @@ class DecisionTreeRegressionModel private[ml] ( private[ml] def this(rootNode: Node, numFeatures: Int) = this(Identifiable.randomUID("dtr"), rootNode, numFeatures) - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { rootNode.predictImpl(features).prediction } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index f41d15b62dddd..6569ff2a5bfc1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -230,7 +230,7 @@ class GBTRegressionModel private[ml]( dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 917a4d238d467..9f1f2405c428e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -1010,7 +1010,7 @@ class GeneralizedLinearRegressionModel private[ml] ( private lazy val familyAndLink = FamilyAndLink(this) - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { predict(features, 0.0) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 6d3fe7a6c748c..92510154d500e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -699,7 +699,7 @@ class LinearRegressionModel private[ml] ( } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { dot(features, coefficients) + intercept } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 200b234b79978..2d594460c2475 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -199,7 +199,7 @@ class RandomForestRegressionModel private[ml] ( dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128 // Predict average of tree predictions. // Ignore the weights since all are 1.0 for now. diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index eeb0324187c5b..2930f4900d50e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode} +import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} @@ -264,6 +264,21 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { Vector, DecisionTreeClassificationModel](this, newTree, newData) } + test("prediction on single instance") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3) + val numClasses = 3 + + val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val newTree = dt.fit(newData) + + testPredictionModelSinglePrediction(newTree, newData) + } + test("training with 1-category categorical feature") { val data = sc.parallelize(Seq( LabeledPoint(0, Vectors.dense(0, 2, 3)), diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 092b4a01d5b0d..57796069f6052 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -197,6 +197,15 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { Vector, GBTClassificationModel](this, gbtModel, validationDataset) } + test("prediction on single instance") { + + val gbt = new GBTClassifier().setSeed(123) + val trainingDataset = trainData.toDF("label", "features") + val gbtModel = gbt.fit(trainingDataset) + + testPredictionModelSinglePrediction(gbtModel, trainingDataset) + } + test("GBT parameter stepSize should be in interval (0, 1]") { withClue("GBT parameter stepSize should be in interval (0, 1]") { intercept[IllegalArgumentException] { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index a93825b8a812d..c05c896df5cb1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -201,6 +201,12 @@ class LinearSVCSuite extends MLTest with DefaultReadWriteTest { dataset.as[LabeledPoint], estimator, modelEquals, 42L) } + test("prediction on single instance") { + val trainer = new LinearSVC() + val model = trainer.fit(smallBinaryDataset) + testPredictionModelSinglePrediction(model, smallBinaryDataset) + } + test("linearSVC comparison with R e1071 and scikit-learn") { val trainer1 = new LinearSVC() .setRegParam(0.00002) // set regParam = 2.0 / datasize / c diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 9987cbf6ba116..36b7e51f93d01 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -499,6 +499,15 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { Vector, LogisticRegressionModel](this, model, smallBinaryDataset) } + test("prediction on single instance") { + val blor = new LogisticRegression().setFamily("binomial") + val blorModel = blor.fit(smallBinaryDataset) + testPredictionModelSinglePrediction(blorModel, smallBinaryDataset) + val mlor = new LogisticRegression().setFamily("multinomial") + val mlorModel = mlor.fit(smallMultinomialDataset) + testPredictionModelSinglePrediction(mlorModel, smallMultinomialDataset) + } + test("coefficients and intercept methods") { val mlr = new LogisticRegression().setMaxIter(1).setFamily("multinomial") val mlrModel = mlr.fit(smallMultinomialDataset) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index daa58a56896d7..6b5fe6e49ffea 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -76,6 +76,18 @@ class MultilayerPerceptronClassifierSuite extends MLTest with DefaultReadWriteTe } } + test("prediction on single instance") { + val layers = Array[Int](2, 5, 2) + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(1) + .setSeed(123L) + .setMaxIter(100) + .setSolver("l-bfgs") + val model = trainer.fit(dataset) + testPredictionModelSinglePrediction(model, dataset) + } + test("Predicted class probabilities: calibration on toy dataset") { val layers = Array[Int](4, 5, 2) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 49115c8a4db30..5f9ab98a2c3ce 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -167,6 +167,28 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest { Vector, NaiveBayesModel](this, model, testDataset) } + test("prediction on single instance") { + val nPoints = 1000 + val piArray = Array(0.5, 0.1, 0.4).map(math.log) + val thetaArray = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) + val pi = Vectors.dense(piArray) + val theta = new DenseMatrix(3, 4, thetaArray.flatten, true) + + val trainDataset = + generateNaiveBayesInput(piArray, thetaArray, nPoints, seed, "multinomial").toDF() + val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial") + val model = nb.fit(trainDataset) + + val validationDataset = + generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF() + + testPredictionModelSinglePrediction(model, validationDataset) + } + test("Naive Bayes with weighted samples") { val numClasses = 3 def modelEquals(m1: NaiveBayesModel, m2: NaiveBayesModel): Unit = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 02a9d5c2a18c0..ba4a9cf082785 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -155,6 +155,22 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { Vector, RandomForestClassificationModel](this, model, df) } + test("prediction on single instance") { + val rdd = orderedLabeledPoints5_20 + val rf = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(3) + .setNumTrees(3) + .setSeed(123) + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + + val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val model = rf.fit(df) + + testPredictionModelSinglePrediction(model, df) + } + test("Fitting without numClasses in metadata") { val df: DataFrame = TreeTests.featureImportanceData(sc).toDF() val rf = new RandomForestClassifier().setMaxDepth(1).setNumTrees(1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 68a1218c23ece..29a438396516b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -136,6 +136,21 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest { assert(importances.toArray.forall(_ >= 0.0)) } + test("prediction on single instance") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(3) + .setSeed(123) + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0) + + val model = dt.fit(df) + testPredictionModelSinglePrediction(model, df) + } + test("should support all NumericType labels and not support other types") { val dt = new DecisionTreeRegressor().setMaxDepth(1) MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor]( diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 11c593b521e65..fad11d078250f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -99,6 +99,14 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { } } + test("prediction on single instance") { + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setMaxIter(2) + val model = gbt.fit(trainData.toDF()) + testPredictionModelSinglePrediction(model, validationData.toDF) + } + test("Checkpointing") { val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index ef2ff94a5e213..d5bcbb221783e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -211,6 +211,14 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest assert(model.getLink === "identity") } + test("prediction on single instance") { + val glr = new GeneralizedLinearRegression + val model = glr.setFamily("gaussian").setLink("identity") + .fit(datasetGaussianIdentity) + + testPredictionModelSinglePrediction(model, datasetGaussianIdentity) + } + test("generalized linear regression: gaussian family against glm") { /* R code: diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index d42cb1714478f..9b19f63eba1bd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -636,6 +636,13 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { } } + test("prediction on single instance") { + val trainer = new LinearRegression + val model = trainer.fit(datasetWithDenseFeature) + + testPredictionModelSinglePrediction(model, datasetWithDenseFeature) + } + test("linear regression model with constant label") { /* R code: diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 8b8e8a655f47b..e83c49f932973 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -19,22 +19,22 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} /** * Test suite for [[RandomForestRegressor]]. */ -class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest{ +class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest{ import RandomForestRegressorSuite.compareAPIs + import testImplicits._ private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _ @@ -74,6 +74,20 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex regressionTestWithContinuousFeatures(rf) } + test("prediction on single instance") { + val rf = new RandomForestRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(10) + .setNumTrees(1) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + + val df = orderedLabeledPoints50_1000.toDF() + val model = rf.fit(df) + testPredictionModelSinglePrediction(model, df) + } + test("Feature importance with toy data") { val rf = new RandomForestRegressor() .setImpurity("variance") diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala index 795fd0e2ac0e4..76d41f9b23715 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala @@ -22,8 +22,9 @@ import java.io.File import org.scalatest.Suite import org.apache.spark.SparkContext -import org.apache.spark.ml.Transformer -import org.apache.spark.sql.{DataFrame, Encoder, Row} +import org.apache.spark.ml.{PredictionModel, Transformer} +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions.col import org.apache.spark.sql.streaming.StreamTest @@ -136,4 +137,14 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => assert(hasExpectedMessage(exceptionOnStreamData)) } } + + def testPredictionModelSinglePrediction(model: PredictionModel[Vector, _], + dataset: Dataset[_]): Unit = { + + model.transform(dataset).select(model.getFeaturesCol, model.getPredictionCol) + .collect().foreach { + case Row(features: Vector, prediction: Double) => + assert(prediction === model.predict(features)) + } + } } From 8d79113b812a91073d2c24a3a9ad94cc3b90b24a Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 21 Mar 2018 09:46:47 -0700 Subject: [PATCH 0501/2461] [SPARK-23577][SQL] Supports custom line separator for text datasource ## What changes were proposed in this pull request? This PR proposes to add `lineSep` option for a configurable line separator in text datasource. It supports this option by using `LineRecordReader`'s functionality with passing it to the constructor. ## How was this patch tested? Manual tests and unit tests were added. Author: hyukjinkwon Closes #20727 from HyukjinKwon/linesep-text. --- python/pyspark/sql/readwriter.py | 14 ++++--- python/pyspark/sql/streaming.py | 8 +++- python/pyspark/sql/tests.py | 24 ++++++++++- .../apache/spark/sql/DataFrameReader.scala | 30 ++++++++------ .../apache/spark/sql/DataFrameWriter.scala | 2 + .../datasources/HadoopFileLinesReader.scala | 23 ++++++++++- .../datasources/text/TextFileFormat.scala | 16 ++++---- .../datasources/text/TextOptions.scala | 12 ++++++ .../sql/streaming/DataStreamReader.scala | 12 +++++- .../datasources/text/TextSuite.scala | 40 +++++++++++++++++++ 10 files changed, 147 insertions(+), 34 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index facc16bc53108..e5288636c596e 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -304,16 +304,18 @@ def parquet(self, *paths): @ignore_unicode_prefix @since(1.6) - def text(self, paths, wholetext=False): + def text(self, paths, wholetext=False, lineSep=None): """ Loads text files and returns a :class:`DataFrame` whose schema starts with a string column named "value", and followed by partitioned columns if there are any. - Each line in the text file is a new row in the resulting DataFrame. + By default, each line in the text file is a new row in the resulting DataFrame. :param paths: string, or list of strings, for input path(s). :param wholetext: if true, read each file from input path(s) as a single row. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. >>> df = spark.read.text('python/test_support/sql/text-test.txt') >>> df.collect() @@ -322,7 +324,7 @@ def text(self, paths, wholetext=False): >>> df.collect() [Row(value=u'hello\\nthis')] """ - self._set_opts(wholetext=wholetext) + self._set_opts(wholetext=wholetext, lineSep=lineSep) if isinstance(paths, basestring): paths = [paths] return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(paths))) @@ -804,18 +806,20 @@ def parquet(self, path, mode=None, partitionBy=None, compression=None): self._jwrite.parquet(path) @since(1.6) - def text(self, path, compression=None): + def text(self, path, compression=None, lineSep=None): """Saves the content of the DataFrame in a text file at the specified path. :param path: the path in any Hadoop supported file system :param compression: compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, bzip2, gzip, lz4, snappy and deflate). + :param lineSep: defines the line separator that should be used for writing. If None is + set, it uses the default value, ``\\n``. The DataFrame must have only one column that is of string type. Each row becomes a new line in the output file. """ - self._set_opts(compression=compression) + self._set_opts(compression=compression, lineSep=lineSep) self._jwrite.text(path) @since(2.0) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index e8966c20a8f42..07f9ac1b5aa9e 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -531,17 +531,20 @@ def parquet(self, path): @ignore_unicode_prefix @since(2.0) - def text(self, path): + def text(self, path, wholetext=False, lineSep=None): """ Loads a text file stream and returns a :class:`DataFrame` whose schema starts with a string column named "value", and followed by partitioned columns if there are any. - Each line in the text file is a new row in the resulting DataFrame. + By default, each line in the text file is a new row in the resulting DataFrame. .. note:: Evolving. :param paths: string, or list of strings, for input path(s). + :param wholetext: if true, read each file from input path(s) as a single row. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. >>> text_sdf = spark.readStream.text(tempfile.mkdtemp()) >>> text_sdf.isStreaming @@ -549,6 +552,7 @@ def text(self, path): >>> "value" in str(text_sdf.schema) True """ + self._set_opts(wholetext=wholetext, lineSep=lineSep) if isinstance(path, basestring): return self._df(self._jreader.text(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 39d6c5226f138..967cc83166f3f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -648,7 +648,29 @@ def test_non_existed_udaf(self): self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf")) - def test_multiLine_json(self): + def test_linesep_text(self): + df = self.spark.read.text("python/test_support/sql/ages_newlines.csv", lineSep=",") + expected = [Row(value=u'Joe'), Row(value=u'20'), Row(value=u'"Hi'), + Row(value=u'\nI am Jeo"\nTom'), Row(value=u'30'), + Row(value=u'"My name is Tom"\nHyukjin'), Row(value=u'25'), + Row(value=u'"I am Hyukjin\n\nI love Spark!"\n')] + self.assertEqual(df.collect(), expected) + + tpath = tempfile.mkdtemp() + shutil.rmtree(tpath) + try: + df.write.text(tpath, lineSep="!") + expected = [Row(value=u'Joe!20!"Hi!'), Row(value=u'I am Jeo"'), + Row(value=u'Tom!30!"My name is Tom"'), + Row(value=u'Hyukjin!25!"I am Hyukjin'), + Row(value=u''), Row(value=u'I love Spark!"'), + Row(value=u'!')] + readback = self.spark.read.text(tpath) + self.assertEqual(readback.collect(), expected) + finally: + shutil.rmtree(tpath) + + def test_multiline_json(self): people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", multiLine=True) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 0139913aaa4e2..1a5e47508c070 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -647,14 +647,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * Loads text files and returns a `DataFrame` whose schema starts with a string column named * "value", and followed by partitioned columns if there are any. * - * You can set the following text-specific option(s) for reading text files: - *
      - *
    • `wholetext` ( default `false`): If true, read a file as a single row and not split by "\n". - *
    • - *
    - * By default, each line in the text files is a new row in the resulting DataFrame. - * - * Usage example: + * By default, each line in the text files is a new row in the resulting DataFrame. For example: * {{{ * // Scala: * spark.read.text("/path/to/spark/README.md") @@ -663,6 +656,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * spark.read().text("/path/to/spark/README.md") * }}} * + * You can set the following text-specific option(s) for reading text files: + *
      + *
    • `wholetext` (default `false`): If true, read a file as a single row and not split by "\n". + *
    • + *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
    • + *
    + * * @param paths input paths * @since 1.6.0 */ @@ -686,11 +687,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * If the directory structure of the text files contains partitioning information, those are * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. * - * You can set the following textFile-specific option(s) for reading text files: - *
      - *
    • `wholetext` ( default `false`): If true, read a file as a single row and not split by "\n". - *
    • - *
    * By default, each line in the text files is a new row in the resulting DataFrame. For example: * {{{ * // Scala: @@ -700,6 +696,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * spark.read().textFile("/path/to/spark/README.md") * }}} * + * You can set the following textFile-specific option(s) for reading text files: + *
      + *
    • `wholetext` (default `false`): If true, read a file as a single row and not split by "\n". + *
    • + *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
    • + *
    + * * @param paths input path * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index ed7a9100cc7f1..bb93889dc55e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -587,6 +587,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `compression` (default `null`): compression codec to use when saving to file. This can be * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`).
  • + *
  • `lineSep` (default `\n`): defines the line separator that should + * be used for writing.
  • * * * @since 1.6.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala index 83cf26c63a175..00a78f7343c59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala @@ -30,9 +30,22 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl /** * An adaptor from a [[PartitionedFile]] to an [[Iterator]] of [[Text]], which are all of the lines * in that file. + * + * @param file A part (i.e. "block") of a single file that should be read line by line. + * @param lineSeparator A line separator that should be used for each line. If the value is `None`, + * it covers `\r`, `\r\n` and `\n`. + * @param conf Hadoop configuration + * + * @note The behavior when `lineSeparator` is `None` (covering `\r`, `\r\n` and `\n`) is defined + * by [[LineRecordReader]], not within Spark. */ class HadoopFileLinesReader( - file: PartitionedFile, conf: Configuration) extends Iterator[Text] with Closeable { + file: PartitionedFile, + lineSeparator: Option[Array[Byte]], + conf: Configuration) extends Iterator[Text] with Closeable { + + def this(file: PartitionedFile, conf: Configuration) = this(file, None, conf) + private val iterator = { val fileSplit = new FileSplit( new Path(new URI(file.filePath)), @@ -42,7 +55,13 @@ class HadoopFileLinesReader( Array.empty) val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) - val reader = new LineRecordReader() + + val reader = lineSeparator match { + case Some(sep) => new LineRecordReader(sep) + // If the line separator is `None`, it covers `\r`, `\r\n` and `\n`. + case _ => new LineRecordReader() + } + reader.initialize(fileSplit, hadoopAttemptContext) new RecordReaderIterator(reader) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index c661e9bd3b94c..9647f09867643 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -17,11 +17,8 @@ package org.apache.spark.sql.execution.datasources.text -import java.io.Closeable - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.TaskContext @@ -89,7 +86,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new TextOutputWriter(path, dataSchema, context) + new TextOutputWriter(path, dataSchema, textOptions.lineSeparatorInWrite, context) } override def getFileExtension(context: TaskAttemptContext): String = { @@ -113,18 +110,18 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - readToUnsafeMem(broadcastedHadoopConf, requiredSchema, textOptions.wholeText) + readToUnsafeMem(broadcastedHadoopConf, requiredSchema, textOptions) } private def readToUnsafeMem( conf: Broadcast[SerializableConfiguration], requiredSchema: StructType, - wholeTextMode: Boolean): (PartitionedFile) => Iterator[UnsafeRow] = { + textOptions: TextOptions): (PartitionedFile) => Iterator[UnsafeRow] = { (file: PartitionedFile) => { val confValue = conf.value.value - val reader = if (!wholeTextMode) { - new HadoopFileLinesReader(file, confValue) + val reader = if (!textOptions.wholeText) { + new HadoopFileLinesReader(file, textOptions.lineSeparatorInRead, confValue) } else { new HadoopFileWholeTextReader(file, confValue) } @@ -152,6 +149,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { class TextOutputWriter( path: String, dataSchema: StructType, + lineSeparator: Array[Byte], context: TaskAttemptContext) extends OutputWriter { @@ -162,7 +160,7 @@ class TextOutputWriter( val utf8string = row.getUTF8String(0) utf8string.writeTo(writer) } - writer.write('\n') + writer.write(lineSeparator) } override def close(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala index 2a661561ab51e..18698df9fd8e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.text +import java.nio.charset.StandardCharsets + import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs} /** @@ -39,9 +41,19 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti */ val wholeText = parameters.getOrElse(WHOLETEXT, "false").toBoolean + private val lineSeparator: Option[String] = parameters.get(LINE_SEPARATOR).map { sep => + require(sep.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.") + sep + } + // Note that the option 'lineSep' uses a different default value in read and write. + val lineSeparatorInRead: Option[Array[Byte]] = + lineSeparator.map(_.getBytes(StandardCharsets.UTF_8)) + val lineSeparatorInWrite: Array[Byte] = + lineSeparatorInRead.getOrElse("\n".getBytes(StandardCharsets.UTF_8)) } private[text] object TextOptions { val COMPRESSION = "compression" val WHOLETEXT = "wholetext" + val LINE_SEPARATOR = "lineSep" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index c393dcdfdd7e5..9b17406a816b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -387,7 +387,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * Loads text files and returns a `DataFrame` whose schema starts with a string column named * "value", and followed by partitioned columns if there are any. * - * Each line in the text files is a new row in the resulting DataFrame. For example: + * By default, each line in the text files is a new row in the resulting DataFrame. For example: * {{{ * // Scala: * spark.readStream.text("/path/to/directory/") @@ -400,6 +400,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
      *
    • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be * considered in every trigger.
    • + *
    • `wholetext` (default `false`): If true, read a file as a single row and not split by "\n". + *
    • + *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
    • *
    * * @since 2.0.0 @@ -413,7 +417,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * If the directory structure of the text files contains partitioning information, those are * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. * - * Each line in the text file is a new element in the resulting Dataset. For example: + * By default, each line in the text file is a new element in the resulting Dataset. For example: * {{{ * // Scala: * spark.readStream.textFile("/path/to/spark/README.md") @@ -426,6 +430,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
      *
    • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be * considered in every trigger.
    • + *
    • `wholetext` (default `false`): If true, read a file as a single row and not split by "\n". + *
    • + *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
    • *
    * * @param path input path diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index 33287044f279e..e8a5299d6ba9d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -18,10 +18,13 @@ package org.apache.spark.sql.execution.datasources.text import java.io.File +import java.nio.charset.StandardCharsets +import java.nio.file.Files import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec +import org.apache.spark.TestUtils import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -172,6 +175,43 @@ class TextSuite extends QueryTest with SharedSQLContext { } } + def testLineSeparator(lineSep: String): Unit = { + test(s"SPARK-23577: Support line separator - lineSep: '$lineSep'") { + // Read + val values = Seq("a", "b", "\nc") + val data = values.mkString(lineSep) + val dataWithTrailingLineSep = s"$data$lineSep" + Seq(data, dataWithTrailingLineSep).foreach { lines => + withTempPath { path => + Files.write(path.toPath, lines.getBytes(StandardCharsets.UTF_8)) + val df = spark.read.option("lineSep", lineSep).text(path.getAbsolutePath) + checkAnswer(df, Seq("a", "b", "\nc").toDF()) + } + } + + // Write + withTempPath { path => + values.toDF().coalesce(1) + .write.option("lineSep", lineSep).text(path.getAbsolutePath) + val partFile = TestUtils.recursiveList(path).filter(f => f.getName.startsWith("part-")).head + val readBack = new String(Files.readAllBytes(partFile.toPath), StandardCharsets.UTF_8) + assert(readBack === s"a${lineSep}b${lineSep}\nc${lineSep}") + } + + // Roundtrip + withTempPath { path => + val df = values.toDF() + df.write.option("lineSep", lineSep).text(path.getAbsolutePath) + val readBack = spark.read.option("lineSep", lineSep).text(path.getAbsolutePath) + checkAnswer(df, readBack) + } + } + } + + Seq("|", "^", "::", "!!!@3", 0x1E.toChar.toString).foreach { lineSep => + testLineSeparator(lineSep) + } + private def testFile: String = { Thread.currentThread().getContextClassLoader.getResource("test-data/text-suite.txt").toString } From 98d0ea3f6091730285293321a50148f69e94c9cd Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 21 Mar 2018 09:52:28 -0700 Subject: [PATCH 0502/2461] [SPARK-23264][SQL] Fix scala.MatchError in literals.sql.out ## What changes were proposed in this pull request? To fix `scala.MatchError` in `literals.sql.out`, this pr added an entry for `CalendarIntervalType` in `QueryExecution.toHiveStructString`. ## How was this patch tested? Existing tests and added tests in `literals.sql` Author: Takeshi Yamamuro Closes #20872 from maropu/FixIntervalTests. --- .../spark/sql/execution/QueryExecution.scala | 2 ++ .../resources/sql-tests/inputs/literals.sql | 3 +++ .../sql-tests/results/literals.sql.out | 20 ++++++++++++------- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 7cae24bf5976c..15379a0663f7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -155,6 +155,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { case (null, _) => "null" case (s: String, StringType) => "\"" + s + "\"" case (decimal, DecimalType()) => decimal.toString + case (interval, CalendarIntervalType) => interval.toString case (other, tpe) if primitiveTypes contains tpe => other.toString } @@ -178,6 +179,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone)) case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8) case (decimal: java.math.BigDecimal, DecimalType()) => formatDecimal(decimal) + case (interval, CalendarIntervalType) => interval.toString case (other, tpe) if primitiveTypes.contains(tpe) => other.toString } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/literals.sql b/sql/core/src/test/resources/sql-tests/inputs/literals.sql index 37b4b7606d12b..a743cf1ec2cde 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/literals.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/literals.sql @@ -105,3 +105,6 @@ select X'XuZ'; -- Hive literal_double test. SELECT 3.14, -3.14, 3.14e8, 3.14e-8, -3.14e8, -3.14e-8, 3.14e+8, 3.14E8, 3.14E-8; + +-- map + interval test +select map(1, interval 1 day, 2, interval 3 week); diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index 95d4413148f64..b8c91dc8b59a4 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 43 +-- Number of queries: 44 -- !query 0 @@ -323,19 +323,17 @@ select timestamp '2016-33-11 20:54:00.000' -- !query 34 select interval 13.123456789 seconds, interval -13.123456789 second -- !query 34 schema -struct<> +struct -- !query 34 output -scala.MatchError -(interval 13 seconds 123 milliseconds 456 microseconds,CalendarIntervalType) (of class scala.Tuple2) +interval 13 seconds 123 milliseconds 456 microseconds interval -12 seconds -876 milliseconds -544 microseconds -- !query 35 select interval 1 year 2 month 3 week 4 day 5 hour 6 minute 7 seconds 8 millisecond, 9 microsecond -- !query 35 schema -struct<> +struct -- !query 35 output -scala.MatchError -(interval 1 years 2 months 3 weeks 4 days 5 hours 6 minutes 7 seconds 8 milliseconds,CalendarIntervalType) (of class scala.Tuple2) +interval 1 years 2 months 3 weeks 4 days 5 hours 6 minutes 7 seconds 8 milliseconds 9 -- !query 36 @@ -416,3 +414,11 @@ SELECT 3.14, -3.14, 3.14e8, 3.14e-8, -3.14e8, -3.14e-8, 3.14e+8, 3.14E8, 3.14E-8 struct<3.14:decimal(3,2),-3.14:decimal(3,2),3.14E+8:decimal(3,-6),3.14E-8:decimal(10,10),-3.14E+8:decimal(3,-6),-3.14E-8:decimal(10,10),3.14E+8:decimal(3,-6),3.14E+8:decimal(3,-6),3.14E-8:decimal(10,10)> -- !query 42 output 3.14 -3.14 314000000 0.0000000314 -314000000 -0.0000000314 314000000 314000000 0.0000000314 + + +-- !query 43 +select map(1, interval 1 day, 2, interval 3 week) +-- !query 43 schema +struct> +-- !query 43 output +{1:interval 1 days,2:interval 3 weeks} From 918c7e99afdcea05c36626e230636c4f8aabf82c Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Wed, 21 Mar 2018 10:06:26 -0700 Subject: [PATCH 0503/2461] [SPARK-23288][SS] Fix output metrics with parquet sink ## What changes were proposed in this pull request? Output metrics were not filled when parquet sink used. This PR fixes this problem by passing a `BasicWriteJobStatsTracker` in `FileStreamSink`. ## How was this patch tested? Additional unit test added. Author: Gabor Somogyi Closes #20745 from gaborgsomogyi/SPARK-23288. --- .../command/DataWritingCommand.scala | 11 +--- .../datasources/BasicWriteStatsTracker.scala | 25 +++++++-- .../execution/streaming/FileStreamSink.scala | 10 +++- .../sql/streaming/FileStreamSinkSuite.scala | 52 +++++++++++++++++++ 4 files changed, 82 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala index e56f8105fc9a7..e11dbd201004d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.command import org.apache.hadoop.conf.Configuration -import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} @@ -45,15 +44,7 @@ trait DataWritingCommand extends Command { // Output columns of the analyzed input query plan def outputColumns: Seq[Attribute] - lazy val metrics: Map[String, SQLMetric] = { - val sparkContext = SparkContext.getActive.get - Map( - "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files"), - "numOutputBytes" -> SQLMetrics.createMetric(sparkContext, "bytes of written output"), - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numParts" -> SQLMetrics.createMetric(sparkContext, "number of dynamic part") - ) - } + lazy val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics def basicWriteJobStatsTracker(hadoopConf: Configuration): BasicWriteJobStatsTracker = { val serializableHadoopConf = new SerializableConfiguration(hadoopConf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index 9dbbe9946ee99..69c03d862391e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -153,12 +153,29 @@ class BasicWriteJobStatsTracker( totalNumOutput += summary.numRows } - metrics("numFiles").add(numFiles) - metrics("numOutputBytes").add(totalNumBytes) - metrics("numOutputRows").add(totalNumOutput) - metrics("numParts").add(numPartitions) + metrics(BasicWriteJobStatsTracker.NUM_FILES_KEY).add(numFiles) + metrics(BasicWriteJobStatsTracker.NUM_OUTPUT_BYTES_KEY).add(totalNumBytes) + metrics(BasicWriteJobStatsTracker.NUM_OUTPUT_ROWS_KEY).add(totalNumOutput) + metrics(BasicWriteJobStatsTracker.NUM_PARTS_KEY).add(numPartitions) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toList) } } + +object BasicWriteJobStatsTracker { + private val NUM_FILES_KEY = "numFiles" + private val NUM_OUTPUT_BYTES_KEY = "numOutputBytes" + private val NUM_OUTPUT_ROWS_KEY = "numOutputRows" + private val NUM_PARTS_KEY = "numParts" + + def metrics: Map[String, SQLMetric] = { + val sparkContext = SparkContext.getActive.get + Map( + NUM_FILES_KEY -> SQLMetrics.createMetric(sparkContext, "number of written files"), + NUM_OUTPUT_BYTES_KEY -> SQLMetrics.createMetric(sparkContext, "bytes of written output"), + NUM_OUTPUT_ROWS_KEY -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + NUM_PARTS_KEY -> SQLMetrics.createMetric(sparkContext, "number of dynamic part") + ) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 87a17cebdc10c..b3d12f67b5d63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -26,7 +26,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.datasources.{FileFormat, FileFormatWriter} +import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FileFormat, FileFormatWriter} +import org.apache.spark.util.SerializableConfiguration object FileStreamSink extends Logging { // The name of the subdirectory that is used to store metadata about which files are valid. @@ -97,6 +98,11 @@ class FileStreamSink( new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, logPath.toUri.toString) private val hadoopConf = sparkSession.sessionState.newHadoopConf() + private def basicWriteJobStatsTracker: BasicWriteJobStatsTracker = { + val serializableHadoopConf = new SerializableConfiguration(hadoopConf) + new BasicWriteJobStatsTracker(serializableHadoopConf, BasicWriteJobStatsTracker.metrics) + } + override def addBatch(batchId: Long, data: DataFrame): Unit = { if (batchId <= fileLog.getLatest().map(_._1).getOrElse(-1L)) { logInfo(s"Skipping already committed batch $batchId") @@ -131,7 +137,7 @@ class FileStreamSink( hadoopConf = hadoopConf, partitionColumns = partitionColumns, bucketSpec = None, - statsTrackers = Nil, + statsTrackers = Seq(basicWriteJobStatsTracker), options = options) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 31e5527d7366a..cf41d7e0e4fe1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -21,6 +21,7 @@ import java.util.Locale import org.apache.hadoop.fs.Path +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ @@ -405,4 +406,55 @@ class FileStreamSinkSuite extends StreamTest { } } } + + test("SPARK-23288 writing and checking output metrics") { + Seq("parquet", "orc", "text", "json").foreach { format => + val inputData = MemoryStream[String] + val df = inputData.toDF() + + withTempDir { outputDir => + withTempDir { checkpointDir => + + var query: StreamingQuery = null + + var numTasks = 0 + var recordsWritten: Long = 0L + var bytesWritten: Long = 0L + try { + spark.sparkContext.addSparkListener(new SparkListener() { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + val outputMetrics = taskEnd.taskMetrics.outputMetrics + recordsWritten += outputMetrics.recordsWritten + bytesWritten += outputMetrics.bytesWritten + numTasks += 1 + } + }) + + query = + df.writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .format(format) + .start(outputDir.getCanonicalPath) + + inputData.addData("1", "2", "3") + inputData.addData("4", "5") + + failAfter(streamingTimeout) { + query.processAllAvailable() + } + spark.sparkContext.listenerBus.waitUntilEmpty(streamingTimeout.toMillis) + + assert(numTasks > 0) + assert(recordsWritten === 5) + // This is heavily file type/version specific but should be filled + assert(bytesWritten > 0) + } finally { + if (query != null) { + query.stop() + } + } + } + } + } + } } From 2b89e4aa2e8bd8b88f6e5eb60d95c1a58e5c4ace Mon Sep 17 00:00:00 2001 From: akonopko Date: Wed, 21 Mar 2018 14:40:21 -0500 Subject: [PATCH 0504/2461] [SPARK-18580][DSTREAM][KAFKA] Add spark.streaming.backpressure.initialRate to direct Kafka streams ## What changes were proposed in this pull request? Add `spark.streaming.backpressure.initialRate` to direct Kafka Streams for Kafka 0.8 and 0.10 This is required in order to be able to use backpressure with huge lags, which cannot be processed at once. Without this parameter `DirectKafkaInputDStream` with backpressure enabled would try to get all the possible data from Kafka before adjusting consumption rate ## How was this patch tested? - Tests added to `org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala` and `org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala` - Manual tests on YARN cluster Author: akonopko Author: Alexander Konopko Closes #19431 from akonopko/SPARK-18580-initialrate. --- .../kafka010/DirectKafkaInputDStream.scala | 8 ++- .../kafka010/DirectKafkaStreamSuite.scala | 51 +++++++++++++++- .../kafka/DirectKafkaInputDStream.scala | 9 ++- .../kafka/DirectKafkaStreamSuite.scala | 59 ++++++++++++++++++- 4 files changed, 120 insertions(+), 7 deletions(-) diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 9cb2448fea0f4..215b7cab703fb 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -56,6 +56,9 @@ private[spark] class DirectKafkaInputDStream[K, V]( ppc: PerPartitionConfig ) extends InputDStream[ConsumerRecord[K, V]](_ssc) with Logging with CanCommitOffsets { + private val initialRate = context.sparkContext.getConf.getLong( + "spark.streaming.backpressure.initialRate", 0) + val executorKafkaParams = { val ekp = new ju.HashMap[String, Object](consumerStrategy.executorKafkaParams) KafkaUtils.fixKafkaParams(ekp) @@ -126,7 +129,10 @@ private[spark] class DirectKafkaInputDStream[K, V]( protected[streaming] def maxMessagesPerPartition( offsets: Map[TopicPartition, Long]): Option[Map[TopicPartition, Long]] = { - val estimatedRateLimit = rateController.map(_.getLatestRate()) + val estimatedRateLimit = rateController.map { x => { + val lr = x.getLatestRate() + if (lr > 0) lr else initialRate + }} // calculate a per-partition rate limit based on current lag val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match { diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index 8524743ee2846..35e4678f2e3c8 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.kafka010 import java.io.File import java.lang.{ Long => JLong } -import java.util.{ Arrays, HashMap => JHashMap, Map => JMap } +import java.util.{ Arrays, HashMap => JHashMap, Map => JMap, UUID } import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicLong @@ -34,7 +34,7 @@ import org.apache.kafka.common.serialization.StringDeserializer import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} @@ -617,6 +617,53 @@ class DirectKafkaStreamSuite ssc.stop() } + test("backpressure.initialRate should honor maxRatePerPartition") { + backpressureTest(maxRatePerPartition = 1000, initialRate = 500, maxMessagesPerPartition = 250) + } + + test("use backpressure.initialRate with backpressure") { + backpressureTest(maxRatePerPartition = 300, initialRate = 1000, maxMessagesPerPartition = 150) + } + + private def backpressureTest( + maxRatePerPartition: Int, + initialRate: Int, + maxMessagesPerPartition: Int) = { + + val topic = UUID.randomUUID().toString + val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest") + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.backpressure.enabled", "true") + .set("spark.streaming.backpressure.initialRate", initialRate.toString) + .set("spark.streaming.kafka.maxRatePerPartition", maxRatePerPartition.toString) + + val messages = Map("foo" -> 5000) + kafkaTestUtils.sendMessages(topic, messages) + + ssc = new StreamingContext(sparkConf, Milliseconds(500)) + + val kafkaStream = withClue("Error creating direct stream") { + new DirectKafkaInputDStream[String, String]( + ssc, + preferredHosts, + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala), + new DefaultPerPartitionConfig(sparkConf) + ) + } + kafkaStream.start() + + val input = Map(new TopicPartition(topic, 0) -> 1000L) + + assert(kafkaStream.maxMessagesPerPartition(input).get == + Map(new TopicPartition(topic, 0) -> maxMessagesPerPartition)) // we run for half a second + + kafkaStream.stop() + } + test("maxMessagesPerPartition with zero offset and rate equal to one") { val topic = "backpressure" val kafkaParams = getKafkaParams() diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index d6dd0744441e4..9297c39d170c4 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -91,9 +91,16 @@ class DirectKafkaInputDStream[ private val maxRateLimitPerPartition: Long = context.sparkContext.getConf.getLong( "spark.streaming.kafka.maxRatePerPartition", 0) + private val initialRate = context.sparkContext.getConf.getLong( + "spark.streaming.backpressure.initialRate", 0) + protected[streaming] def maxMessagesPerPartition( offsets: Map[TopicAndPartition, Long]): Option[Map[TopicAndPartition, Long]] = { - val estimatedRateLimit = rateController.map(_.getLatestRate()) + + val estimatedRateLimit = rateController.map { x => { + val lr = x.getLatestRate() + if (lr > 0) lr else initialRate + }} // calculate a per-partition rate limit based on current lag val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match { diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 3fea6cfd910bf..ecca38784e777 100644 --- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.kafka import java.io.File -import java.util.Arrays +import java.util.{ Arrays, UUID } import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicLong @@ -32,12 +32,11 @@ import kafka.serializer.StringDecoder import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset import org.apache.spark.streaming.scheduler._ import org.apache.spark.streaming.scheduler.rate.RateEstimator import org.apache.spark.util.Utils @@ -456,6 +455,60 @@ class DirectKafkaStreamSuite ssc.stop() } + test("use backpressure.initialRate with backpressure") { + backpressureTest(maxRatePerPartition = 1000, initialRate = 500, maxMessagesPerPartition = 250) + } + + test("backpressure.initialRate should honor maxRatePerPartition") { + backpressureTest(maxRatePerPartition = 300, initialRate = 1000, maxMessagesPerPartition = 150) + } + + private def backpressureTest( + maxRatePerPartition: Int, + initialRate: Int, + maxMessagesPerPartition: Int) = { + + val topic = UUID.randomUUID().toString + val topicPartitions = Set(TopicAndPartition(topic, 0)) + kafkaTestUtils.createTopic(topic, 1) + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "smallest" + ) + + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.backpressure.enabled", "true") + .set("spark.streaming.backpressure.initialRate", initialRate.toString) + .set("spark.streaming.kafka.maxRatePerPartition", maxRatePerPartition.toString) + + val messages = Map("foo" -> 5000) + kafkaTestUtils.sendMessages(topic, messages) + + ssc = new StreamingContext(sparkConf, Milliseconds(500)) + + val kafkaStream = withClue("Error creating direct stream") { + val kc = new KafkaCluster(kafkaParams) + val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message) + val m = kc.getEarliestLeaderOffsets(topicPartitions) + .fold(e => Map.empty[TopicAndPartition, Long], m => m.mapValues(lo => lo.offset)) + + new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)]( + ssc, kafkaParams, m, messageHandler) + } + kafkaStream.start() + + val input = Map(new TopicAndPartition(topic, 0) -> 1000L) + + assert(kafkaStream.maxMessagesPerPartition(input).get == + Map(new TopicAndPartition(topic, 0) -> maxMessagesPerPartition)) + + kafkaStream.stop() + } + test("maxMessagesPerPartition with zero offset and rate equal to one") { val topic = "backpressure" val kafkaParams = Map( From a091ee676b8707819e94d92693956237310a6145 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 21 Mar 2018 13:52:03 -0700 Subject: [PATCH 0505/2461] [MINOR] Fix Java lint from new JavaKolmogorovSmirnovTestSuite ## What changes were proposed in this pull request? Fix lint-java from https://github.com/apache/spark/pull/19108 addition of JavaKolmogorovSmirnovTestSuite Author: Joseph K. Bradley Closes #20875 from jkbradley/kstest-lint-fix. --- .../spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java | 7 ------- 1 file changed, 7 deletions(-) diff --git a/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java b/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java index 021272dd5a40c..830f668fe07b8 100644 --- a/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java @@ -18,18 +18,11 @@ package org.apache.spark.ml.stat; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; import org.apache.commons.math3.distribution.NormalDistribution; -import org.apache.spark.ml.linalg.VectorUDT; -import org.apache.spark.sql.Encoder; import org.apache.spark.sql.Encoders; -import org.apache.spark.sql.types.DoubleType; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; import org.junit.Test; import org.apache.spark.SharedSparkSession; From 0604beaff2baa2d0fed86c0c87fd2a16a1838b5f Mon Sep 17 00:00:00 2001 From: Mihaly Toth Date: Wed, 21 Mar 2018 17:05:39 -0700 Subject: [PATCH 0506/2461] [SPARK-23729][CORE] Respect URI fragment when resolving globs Firstly, glob resolution will not result in swallowing the remote name part (that is preceded by the `#` sign) in case of `--files` or `--archives` options Moreover in the special case of multiple resolutions when the remote naming does not make sense and error is returned. Enhanced current test and wrote additional test for the error case Author: Mihaly Toth Closes #20853 from misutoth/glob-with-remote-name. --- .../apache/spark/deploy/DependencyUtils.scala | 34 +++++++++++---- .../org/apache/spark/deploy/SparkSubmit.scala | 13 ++++++ .../spark/deploy/SparkSubmitSuite.scala | 41 +++++++++++++++---- 3 files changed, 72 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala index ecc82d7ac8001..ab319c860ee69 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala @@ -18,12 +18,13 @@ package org.apache.spark.deploy import java.io.File +import java.net.URI import org.apache.commons.lang3.StringUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.util.{MutableURLClassLoader, Utils} private[deploy] object DependencyUtils { @@ -137,16 +138,31 @@ private[deploy] object DependencyUtils { def resolveGlobPaths(paths: String, hadoopConf: Configuration): String = { require(paths != null, "paths cannot be null.") Utils.stringToSeq(paths).flatMap { path => - val uri = Utils.resolveURI(path) - uri.getScheme match { - case "local" | "http" | "https" | "ftp" => Array(path) - case _ => - val fs = FileSystem.get(uri, hadoopConf) - Option(fs.globStatus(new Path(uri))).map { status => - status.filter(_.isFile).map(_.getPath.toUri.toString) - }.getOrElse(Array(path)) + val (base, fragment) = splitOnFragment(path) + (resolveGlobPath(base, hadoopConf), fragment) match { + case (resolved, Some(_)) if resolved.length > 1 => throw new SparkException( + s"${base.toString} resolves ambiguously to multiple files: ${resolved.mkString(",")}") + case (resolved, Some(namedAs)) => resolved.map(_ + "#" + namedAs) + case (resolved, _) => resolved } }.mkString(",") } + private def splitOnFragment(path: String): (URI, Option[String]) = { + val uri = Utils.resolveURI(path) + val withoutFragment = new URI(uri.getScheme, uri.getSchemeSpecificPart, null) + (withoutFragment, Option(uri.getFragment)) + } + + private def resolveGlobPath(uri: URI, hadoopConf: Configuration): Array[String] = { + uri.getScheme match { + case "local" | "http" | "https" | "ftp" => Array(uri.toString) + case _ => + val fs = FileSystem.get(uri, hadoopConf) + Option(fs.globStatus(new Path(uri))).map { status => + status.filter(_.isFile).map(_.getPath.toUri.toString) + }.getOrElse(Array(uri.toString)) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 329bde08718fe..3965f17f4b56e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -245,6 +245,19 @@ object SparkSubmit extends CommandLineUtils with Logging { args: SparkSubmitArguments, conf: Option[HadoopConfiguration] = None) : (Seq[String], Seq[String], SparkConf, String) = { + try { + doPrepareSubmitEnvironment(args, conf) + } catch { + case e: SparkException => + printErrorAndExit(e.getMessage) + throw e + } + } + + private def doPrepareSubmitEnvironment( + args: SparkSubmitArguments, + conf: Option[HadoopConfiguration] = None) + : (Seq[String], Seq[String], SparkConf, String) = { // Return values val childArgs = new ArrayBuffer[String]() val childClasspath = new ArrayBuffer[String]() diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index d265643a80b4e..2d0c192db4915 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy import java.io._ import java.net.URI import java.nio.charset.StandardCharsets -import java.nio.file.Files +import java.nio.file.{Files, Paths} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -606,10 +606,13 @@ class SparkSubmitSuite } test("resolves command line argument paths correctly") { - val jars = "/jar1,/jar2" // --jars - val files = "local:/file1,file2" // --files - val archives = "file:/archive1,archive2" // --archives - val pyFiles = "py-file1,py-file2" // --py-files + val dir = Utils.createTempDir() + val archive = Paths.get(dir.toPath.toString, "single.zip") + Files.createFile(archive) + val jars = "/jar1,/jar2" + val files = "local:/file1,file2" + val archives = s"file:/archive1,${dir.toPath.toAbsolutePath.toString}/*.zip#archive3" + val pyFiles = "py-file1,py-file2" // Test jars and files val clArgs = Seq( @@ -636,9 +639,10 @@ class SparkSubmitSuite val appArgs2 = new SparkSubmitArguments(clArgs2) val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2) appArgs2.files should be (Utils.resolveURIs(files)) - appArgs2.archives should be (Utils.resolveURIs(archives)) + appArgs2.archives should fullyMatch regex ("file:/archive1,file:.*#archive3") conf2.get("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) - conf2.get("spark.yarn.dist.archives") should be (Utils.resolveURIs(archives)) + conf2.get("spark.yarn.dist.archives") should fullyMatch regex + ("file:/archive1,file:.*#archive3") // Test python files val clArgs3 = Seq( @@ -657,6 +661,29 @@ class SparkSubmitSuite conf3.get(PYSPARK_PYTHON.key) should be ("python3.5") } + test("ambiguous archive mapping results in error message") { + val dir = Utils.createTempDir() + val archive1 = Paths.get(dir.toPath.toString, "first.zip") + val archive2 = Paths.get(dir.toPath.toString, "second.zip") + Files.createFile(archive1) + Files.createFile(archive2) + val jars = "/jar1,/jar2" + val files = "local:/file1,file2" + val archives = s"file:/archive1,${dir.toPath.toAbsolutePath.toString}/*.zip#archive3" + val pyFiles = "py-file1,py-file2" + + // Test files and archives (Yarn) + val clArgs2 = Seq( + "--master", "yarn", + "--class", "org.SomeClass", + "--files", files, + "--archives", archives, + "thejar.jar" + ) + + testPrematureExit(clArgs2.toArray, "resolves ambiguously to multiple files") + } + test("resolves config paths correctly") { val jars = "/jar1,/jar2" // spark.jars val files = "local:/file1,file2" // spark.files / spark.yarn.dist.files From 95e51ff849a4c46cae463636b1ee393042469e7b Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Wed, 21 Mar 2018 21:21:36 -0700 Subject: [PATCH 0507/2461] [SPARK-23760][SQL] CodegenContext.withSubExprEliminationExprs should save/restore CSE state correctly ## What changes were proposed in this pull request? Fixed `CodegenContext.withSubExprEliminationExprs()` so that it saves/restores CSE state correctly. ## How was this patch tested? Added new unit test to verify that the old CSE state is indeed saved and restored around the `withSubExprEliminationExprs()` call. Manually verified that this test fails without this patch. Author: Kris Mok Closes #20870 from rednaxelafx/codegen-subexpr-fix. --- .../expressions/codegen/CodeGenerator.scala | 16 +++---- .../expressions/CodeGenerationSuite.scala | 44 +++++++++++++++++++ 2 files changed, 51 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index fe5e63ec0a2bb..84b1e3fbda876 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -402,7 +402,7 @@ class CodegenContext { val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions // Foreach expression that is participating in subexpression elimination, the state to use. - val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] + var subExprEliminationExprs = Map.empty[Expression, SubExprEliminationState] // The collection of sub-expression result resetting methods that need to be called on each row. val subexprFunctions = mutable.ArrayBuffer.empty[String] @@ -921,14 +921,12 @@ class CodegenContext { newSubExprEliminationExprs: Map[Expression, SubExprEliminationState])( f: => Seq[ExprCode]): Seq[ExprCode] = { val oldsubExprEliminationExprs = subExprEliminationExprs - subExprEliminationExprs.clear - newSubExprEliminationExprs.foreach(subExprEliminationExprs += _) + subExprEliminationExprs = newSubExprEliminationExprs val genCodes = f // Restore previous subExprEliminationExprs - subExprEliminationExprs.clear - oldsubExprEliminationExprs.foreach(subExprEliminationExprs += _) + subExprEliminationExprs = oldsubExprEliminationExprs genCodes } @@ -942,7 +940,7 @@ class CodegenContext { def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = { // Create a clear EquivalentExpressions and SubExprEliminationState mapping val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions - val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] + val localSubExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] // Add each expression tree and compute the common subexpressions. expressions.foreach(equivalentExpressions.addExprTree) @@ -955,10 +953,10 @@ class CodegenContext { // Generate the code for this expression tree. val eval = expr.genCode(this) val state = SubExprEliminationState(eval.isNull, eval.value) - e.foreach(subExprEliminationExprs.put(_, state)) + e.foreach(localSubExprEliminationExprs.put(_, state)) eval.code.trim } - SubExprCodes(codes, subExprEliminationExprs.toMap) + SubExprCodes(codes, localSubExprEliminationExprs.toMap) } /** @@ -1006,7 +1004,7 @@ class CodegenContext { subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" val state = SubExprEliminationState(isNull, value) - e.foreach(subExprEliminationExprs.put(_, state)) + subExprEliminationExprs ++= e.map(_ -> state).toMap } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 64c13e8972036..398b6767654fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -442,4 +442,48 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(CodeGenerator.calculateParamLength( Seq.range(0, 100).map(x => Literal(x.toLong))) == 201) } + + test("SPARK-23760: CodegenContext.withSubExprEliminationExprs should save/restore correctly") { + + val ref = BoundReference(0, IntegerType, true) + val add1 = Add(ref, ref) + val add2 = Add(add1, add1) + + // raw testing of basic functionality + { + val ctx = new CodegenContext + val e = ref.genCode(ctx) + // before + ctx.subExprEliminationExprs += ref -> SubExprEliminationState(e.isNull, e.value) + assert(ctx.subExprEliminationExprs.contains(ref)) + // call withSubExprEliminationExprs + ctx.withSubExprEliminationExprs(Map(add1 -> SubExprEliminationState("dummy", "dummy"))) { + assert(ctx.subExprEliminationExprs.contains(add1)) + assert(!ctx.subExprEliminationExprs.contains(ref)) + Seq.empty + } + // after + assert(ctx.subExprEliminationExprs.nonEmpty) + assert(ctx.subExprEliminationExprs.contains(ref)) + assert(!ctx.subExprEliminationExprs.contains(add1)) + } + + // emulate an actual codegen workload + { + val ctx = new CodegenContext + // before + ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE + assert(ctx.subExprEliminationExprs.contains(add1)) + // call withSubExprEliminationExprs + ctx.withSubExprEliminationExprs(Map(ref -> SubExprEliminationState("dummy", "dummy"))) { + assert(ctx.subExprEliminationExprs.contains(ref)) + assert(!ctx.subExprEliminationExprs.contains(add1)) + Seq.empty + } + // after + assert(ctx.subExprEliminationExprs.nonEmpty) + assert(ctx.subExprEliminationExprs.contains(add1)) + assert(!ctx.subExprEliminationExprs.contains(ref)) + } + } } From 5c9eaa6b585e9febd782da8eb6490b24d0d39ff3 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 21 Mar 2018 21:49:02 -0700 Subject: [PATCH 0508/2461] [SPARK-23372][SQL] Writing empty struct in parquet fails during execution. It should fail earlier in the processing. ## What changes were proposed in this pull request? Currently we allow writing data frames with empty schema into a file based datasource for certain file formats such as JSON, ORC etc. For formats such as Parquet and Text, we raise error at different times of execution. For text format, we return error from the driver early on in processing where as for format such as parquet, the error is raised from executor. **Example** spark.emptyDataFrame.write.format("parquet").mode("overwrite").save(path) **Results in** ``` SQL org.apache.parquet.schema.InvalidSchemaException: Cannot write a schema with an empty group: message spark_schema { } at org.apache.parquet.schema.TypeUtil$1.visit(TypeUtil.java:27) at org.apache.parquet.schema.TypeUtil$1.visit(TypeUtil.java:37) at org.apache.parquet.schema.MessageType.accept(MessageType.java:58) at org.apache.parquet.schema.TypeUtil.checkValidWriteSchema(TypeUtil.java:23) at org.apache.parquet.hadoop.ParquetFileWriter.(ParquetFileWriter.java:225) at org.apache.parquet.hadoop.ParquetOutputFormat.getRecordWriter(ParquetOutputFormat.java:342) at org.apache.parquet.hadoop.ParquetOutputFormat.getRecordWriter(ParquetOutputFormat.java:302) at org.apache.spark.sql.execution.datasources.parquet.ParquetOutputWriter.(ParquetOutputWriter.scala:37) at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$$anon$1.newInstance(ParquetFileFormat.scala:151) at org.apache.spark.sql.execution.datasources.FileFormatWriter$SingleDirectoryWriteTask.newOutputWriter(FileFormatWriter.scala:376) at org.apache.spark.sql.execution.datasources.FileFormatWriter$SingleDirectoryWriteTask.execute(FileFormatWriter.scala:387) at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$org$apache$spark$sql$execution$datasources$FileFormatWriter$$executeTask$3.apply(FileFormatWriter.scala:278) at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$org$apache$spark$sql$execution$datasources$FileFormatWriter$$executeTask$3.apply(FileFormatWriter.scala:276) at org.apache.spark.util.Utils$.tryWithSafeFinallyAndFailureCallbacks(Utils.scala:1411) at org.apache.spark.sql.execution.datasources.FileFormatWriter$.org$apache$spark$sql$execution$datasources$FileFormatWriter$$executeTask(FileFormatWriter.scala:281) at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply(FileFormatWriter.scala:206) at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply(FileFormatWriter.scala:205) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87) at org.apache.spark.scheduler.Task.run(Task.scala:109) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread. ``` In this PR, we unify the error processing and raise error on attempt to write empty schema based dataframes into file based datasource (orc, parquet, text , csv, json etc) early on in the processing. ## How was this patch tested? Unit tests added in FileBasedDatasourceSuite. Author: Dilip Biswal Closes #20579 from dilipbiswal/spark-23372. --- docs/sql-programming-guide.md | 1 + .../execution/datasources/DataSource.scala | 26 ++++++++++++++++- .../spark/sql/FileBasedDataSourceSuite.scala | 28 +++++++++++++++++++ 3 files changed, 54 insertions(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 5b47fd77f2cbc..421e2eaf62bfb 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1807,6 +1807,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unabled to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. + - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 35fcff69b14d8..31fa89b4570a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{CalendarIntervalType, StructType} +import org.apache.spark.sql.types.{CalendarIntervalType, StructField, StructType} import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.util.Utils @@ -546,6 +546,7 @@ case class DataSource( case dataSource: CreatableRelationProvider => SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode) case format: FileFormat => + DataSource.validateSchema(data.schema) planForWritingFileFormat(format, mode, data) case _ => sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") @@ -719,4 +720,27 @@ object DataSource extends Logging { } globPath } + + /** + * Called before writing into a FileFormat based data source to make sure the + * supplied schema is not empty. + * @param schema + */ + private def validateSchema(schema: StructType): Unit = { + def hasEmptySchema(schema: StructType): Boolean = { + schema.size == 0 || schema.find { + case StructField(_, b: StructType, _, _) => hasEmptySchema(b) + case _ => false + }.isDefined + } + + + if (hasEmptySchema(schema)) { + throw new AnalysisException( + s""" + |Datasource does not support writing empty or nested empty schemas. + |Please make sure the data schema has at least one or more column(s). + """.stripMargin) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index bd3071bcf9010..06303099f5310 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with BeforeAndAfterAll { @@ -107,6 +108,33 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } } + allFileBasedDataSources.foreach { format => + test(s"SPARK-23372 error while writing empty schema files using $format") { + withTempPath { outputPath => + val errMsg = intercept[AnalysisException] { + spark.emptyDataFrame.write.format(format).save(outputPath.toString) + } + assert(errMsg.getMessage.contains( + "Datasource does not support writing empty or nested empty schemas")) + } + + // Nested empty schema + withTempPath { outputPath => + val schema = StructType(Seq( + StructField("a", IntegerType), + StructField("b", StructType(Nil)), + StructField("c", IntegerType) + )) + val df = spark.createDataFrame(sparkContext.emptyRDD[Row], schema) + val errMsg = intercept[AnalysisException] { + df.write.format(format).save(outputPath.toString) + } + assert(errMsg.getMessage.contains( + "Datasource does not support writing empty or nested empty schemas")) + } + } + } + allFileBasedDataSources.foreach { format => test(s"SPARK-22146 read files containing special characters using $format") { withTempDir { dir => From 4d37008c78d7d6b8f8a649b375ecc090700eca4f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 22 Mar 2018 19:57:32 +0100 Subject: [PATCH 0509/2461] [SPARK-23599][SQL] Use RandomUUIDGenerator in Uuid expression ## What changes were proposed in this pull request? As stated in Jira, there are problems with current `Uuid` expression which uses `java.util.UUID.randomUUID` for UUID generation. This patch uses the newly added `RandomUUIDGenerator` for UUID generation. So we can make `Uuid` deterministic between retries. ## How was this patch tested? Added unit tests. Author: Liang-Chi Hsieh Closes #20861 from viirya/SPARK-23599-2. --- .../sql/catalyst/analysis/Analyzer.scala | 16 ++++ .../spark/sql/catalyst/expressions/misc.scala | 26 +++++-- .../ResolvedUuidExpressionsSuite.scala | 73 +++++++++++++++++++ .../expressions/ExpressionEvalHelper.scala | 5 +- .../expressions/MiscExpressionsSuite.scala | 19 ++++- .../org/apache/spark/sql/DataFrameSuite.scala | 6 ++ 6 files changed, 136 insertions(+), 9 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7848f88bda1c9..e821e96522f7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable.ArrayBuffer +import scala.util.Random import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ @@ -177,6 +178,7 @@ class Analyzer( TimeWindowing :: ResolveInlineTables(conf) :: ResolveTimeZone(conf) :: + ResolvedUuidExpressions :: TypeCoercion.typeCoercionRules(conf) ++ extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), @@ -1994,6 +1996,20 @@ class Analyzer( } } + /** + * Set the seed for random number generation in Uuid expressions. + */ + object ResolvedUuidExpressions extends Rule[LogicalPlan] { + private lazy val random = new Random() + + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + case p if p.resolved => p + case p => p transformExpressionsUp { + case Uuid(None) => Uuid(Some(random.nextLong())) + } + } + } + /** * Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the * null check. When user defines a UDF with primitive parameters, there is no way to tell if the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 38e4fe44b15ab..ec93620038cff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -21,6 +21,7 @@ import java.util.UUID import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -122,18 +123,33 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { 46707d92-02f4-4817-8116-a4c3b23e6266 """) // scalastyle:on line.size.limit -case class Uuid() extends LeafExpression { +case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Nondeterministic { - override lazy val deterministic: Boolean = false + def this() = this(None) + + override lazy val resolved: Boolean = randomSeed.isDefined override def nullable: Boolean = false override def dataType: DataType = StringType - override def eval(input: InternalRow): Any = UTF8String.fromString(UUID.randomUUID().toString) + @transient private[this] var randomGenerator: RandomUUIDGenerator = _ + + override protected def initializeInternal(partitionIndex: Int): Unit = + randomGenerator = RandomUUIDGenerator(randomSeed.get + partitionIndex) + + override protected def evalInternal(input: InternalRow): Any = + randomGenerator.getNextUUIDUTF8String() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.copy(code = s"final UTF8String ${ev.value} = " + - s"UTF8String.fromString(java.util.UUID.randomUUID().toString());", isNull = "false") + val randomGen = ctx.freshName("randomGen") + ctx.addMutableState("org.apache.spark.sql.catalyst.util.RandomUUIDGenerator", randomGen, + forceInline = true, + useFreshName = false) + ctx.addPartitionInitializationStatement(s"$randomGen = " + + "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" + + s"${randomSeed.get}L + partitionIndex);") + ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", + isNull = "false") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala new file mode 100644 index 0000000000000..fe57c199b8744 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} + +/** + * Test suite for resolving Uuid expressions. + */ +class ResolvedUuidExpressionsSuite extends AnalysisTest { + + private lazy val a = 'a.int + private lazy val r = LocalRelation(a) + private lazy val uuid1 = Uuid().as('_uuid1) + private lazy val uuid2 = Uuid().as('_uuid2) + private lazy val uuid3 = Uuid().as('_uuid3) + private lazy val uuid1Ref = uuid1.toAttribute + + private val analyzer = getAnalyzer(caseSensitive = true) + + private def getUuidExpressions(plan: LogicalPlan): Seq[Uuid] = { + plan.flatMap { + case p => + p.expressions.flatMap(_.collect { + case u: Uuid => u + }) + } + } + + test("analyzed plan sets random seed for Uuid expression") { + val plan = r.select(a, uuid1) + val resolvedPlan = analyzer.executeAndCheck(plan) + getUuidExpressions(resolvedPlan).foreach { u => + assert(u.resolved) + assert(u.randomSeed.isDefined) + } + } + + test("Uuid expressions should have different random seeds") { + val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3) + val resolvedPlan = analyzer.executeAndCheck(plan) + assert(getUuidExpressions(resolvedPlan).map(_.randomSeed.get).distinct.length == 3) + } + + test("Different analyzed plans should have different random seeds in Uuids") { + val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3) + val resolvedPlan1 = analyzer.executeAndCheck(plan) + val resolvedPlan2 = analyzer.executeAndCheck(plan) + val uuids1 = getUuidExpressions(resolvedPlan1) + val uuids2 = getUuidExpressions(resolvedPlan2) + assert(uuids1.distinct.length == 3) + assert(uuids2.distinct.length == 3) + assert(uuids1.intersect(uuids2).length == 0) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index c6343b1cbf600..3828f172a15cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -176,7 +176,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } - private def evaluateWithGeneratedMutableProjection( + protected def evaluateWithGeneratedMutableProjection( expression: Expression, inputRow: InternalRow = EmptyRow): Any = { val plan = generateProject( @@ -220,7 +220,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } - private def evaluateWithUnsafeProjection( + protected def evaluateWithUnsafeProjection( expression: Expression, inputRow: InternalRow = EmptyRow, factory: UnsafeProjectionCreator = UnsafeProjection): InternalRow = { @@ -233,6 +233,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { Alias(expression, s"Optimized($expression)2")() :: Nil), expression) + plan.initialize(0) plan(inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index c3d08bf68c7bb..3383d421f5616 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import java.io.PrintStream +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ @@ -42,8 +44,21 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("uuid") { - checkEvaluation(Length(Uuid()), 36) - assert(evaluateWithoutCodegen(Uuid()) !== evaluateWithoutCodegen(Uuid())) + checkEvaluation(Length(Uuid(Some(0))), 36) + val r = new Random() + val seed1 = Some(r.nextLong()) + assert(evaluateWithoutCodegen(Uuid(seed1)) === evaluateWithoutCodegen(Uuid(seed1))) + assert(evaluateWithGeneratedMutableProjection(Uuid(seed1)) === + evaluateWithGeneratedMutableProjection(Uuid(seed1))) + assert(evaluateWithUnsafeProjection(Uuid(seed1)) === + evaluateWithUnsafeProjection(Uuid(seed1))) + + val seed2 = Some(r.nextLong()) + assert(evaluateWithoutCodegen(Uuid(seed1)) !== evaluateWithoutCodegen(Uuid(seed2))) + assert(evaluateWithGeneratedMutableProjection(Uuid(seed1)) !== + evaluateWithGeneratedMutableProjection(Uuid(seed2))) + assert(evaluateWithUnsafeProjection(Uuid(seed1)) !== + evaluateWithUnsafeProjection(Uuid(seed2))) } test("PrintToStderr") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 8b66f77b2f923..f7b3393f65cb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.Uuid import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -2264,4 +2265,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(0, 10) :: Nil) assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) } + + test("Uuid expressions should produce same results at retries in the same DataFrame") { + val df = spark.range(1).select($"id", new Column(Uuid())) + checkAnswer(df, df.collect()) + } } From a649fcf32a7e610da2a2b4e3d94f5d1372c825d6 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 22 Mar 2018 21:20:41 -0700 Subject: [PATCH 0510/2461] [MINOR][PYTHON] Remove unused codes in schema parsing logics of PySpark ## What changes were proposed in this pull request? This PR proposes to remove out unused codes, `_ignore_brackets_split` and `_BRACKETS`. `_ignore_brackets_split` was introduced in https://github.com/apache/spark/commit/d57daf1f7732a7ac54a91fe112deeda0a254f9ef to refactor and support `toDF("...")`; however, https://github.com/apache/spark/commit/ebc124d4c44d4c84f7868f390f778c0ff5cd66cb replaced the logics here. Seems `_ignore_brackets_split` is not referred anymore. `_BRACKETS` was introduced in https://github.com/apache/spark/commit/880eabec37c69ce4e9594d7babfac291b0f93f50; however, all other usages were removed out in https://github.com/apache/spark/commit/648a8626b82d27d84db3e48bccfd73d020828586. This is rather a followup for https://github.com/apache/spark/commit/ebc124d4c44d4c84f7868f390f778c0ff5cd66cb which I missed in that PR. ## How was this patch tested? Manually tested. Existing tests should cover this. I also double checked by `grep` in the whole repo. Author: hyukjinkwon Closes #20878 from HyukjinKwon/minor-remove-unused. --- python/pyspark/sql/types.py | 35 ----------------------------------- 1 file changed, 35 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 826aab97e58db..5d5919e451b46 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -752,41 +752,6 @@ def __eq__(self, other): _FIXED_DECIMAL = re.compile("decimal\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*\\)") -_BRACKETS = {'(': ')', '[': ']', '{': '}'} - - -def _ignore_brackets_split(s, separator): - """ - Splits the given string by given separator, but ignore separators inside brackets pairs, e.g. - given "a,b" and separator ",", it will return ["a", "b"], but given "a, d", it will return - ["a", "d"]. - """ - parts = [] - buf = "" - level = 0 - for c in s: - if c in _BRACKETS.keys(): - level += 1 - buf += c - elif c in _BRACKETS.values(): - if level == 0: - raise ValueError("Brackets are not correctly paired: %s" % s) - level -= 1 - buf += c - elif c == separator and level > 0: - buf += c - elif c == separator: - parts.append(buf) - buf = "" - else: - buf += c - - if len(buf) == 0: - raise ValueError("The %s cannot be the last char: %s" % (separator, s)) - parts.append(buf) - return parts - - def _parse_datatype_string(s): """ Parses the given data type string to a :class:`DataType`. The data type string format equals From b2edc30db1dcc6102687d20c158a2700965fdf51 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 22 Mar 2018 21:23:25 -0700 Subject: [PATCH 0511/2461] [SPARK-23614][SQL] Fix incorrect reuse exchange when caching is used ## What changes were proposed in this pull request? We should provide customized canonicalize plan for `InMemoryRelation` and `InMemoryTableScanExec`. Otherwise, we can wrongly treat two different cached plans as same result. It causes wrongly reused exchange then. For a test query like this: ```scala val cached = spark.createDataset(Seq(TestDataUnion(1, 2, 3), TestDataUnion(4, 5, 6))).cache() val group1 = cached.groupBy("x").agg(min(col("y")) as "value") val group2 = cached.groupBy("x").agg(min(col("z")) as "value") group1.union(group2) ``` Canonicalized plans before: First exchange: ``` Exchange hashpartitioning(none#0, 5) +- *(1) HashAggregate(keys=[none#0], functions=[partial_min(none#1)], output=[none#0, none#4]) +- *(1) InMemoryTableScan [none#0, none#1] +- InMemoryRelation [x#4253, y#4254, z#4255], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas) +- LocalTableScan [x#4253, y#4254, z#4255] ``` Second exchange: ``` Exchange hashpartitioning(none#0, 5) +- *(3) HashAggregate(keys=[none#0], functions=[partial_min(none#1)], output=[none#0, none#4]) +- *(3) InMemoryTableScan [none#0, none#1] +- InMemoryRelation [x#4253, y#4254, z#4255], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas) +- LocalTableScan [x#4253, y#4254, z#4255] ``` You can find that they have the canonicalized plans are the same, although we use different columns in two `InMemoryTableScan`s. Canonicalized plan after: First exchange: ``` Exchange hashpartitioning(none#0, 5) +- *(1) HashAggregate(keys=[none#0], functions=[partial_min(none#1)], output=[none#0, none#4]) +- *(1) InMemoryTableScan [none#0, none#1] +- InMemoryRelation [none#0, none#1, none#2], true, 10000, StorageLevel(memory, 1 replicas) +- LocalTableScan [none#0, none#1, none#2] ``` Second exchange: ``` Exchange hashpartitioning(none#0, 5) +- *(3) HashAggregate(keys=[none#0], functions=[partial_min(none#1)], output=[none#0, none#4]) +- *(3) InMemoryTableScan [none#0, none#2] +- InMemoryRelation [none#0, none#1, none#2], true, 10000, StorageLevel(memory, 1 replicas) +- LocalTableScan [none#0, none#1, none#2] ``` ## How was this patch tested? Added unit test. Author: Liang-Chi Hsieh Closes #20831 from viirya/SPARK-23614. --- .../execution/columnar/InMemoryRelation.scala | 10 ++++++++++ .../columnar/InMemoryTableScanExec.scala | 19 +++++++++++++------ .../org/apache/spark/sql/DatasetSuite.scala | 9 +++++++++ .../spark/sql/execution/ExchangeSuite.scala | 7 +++++++ 4 files changed, 39 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 22e16913d4da9..2579046e30708 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics} import org.apache.spark.sql.execution.SparkPlan @@ -68,6 +69,15 @@ case class InMemoryRelation( override protected def innerChildren: Seq[SparkPlan] = Seq(child) + override def doCanonicalize(): logical.LogicalPlan = + copy(output = output.map(QueryPlan.normalizeExprId(_, child.output)), + storageLevel = StorageLevel.NONE, + child = child.canonicalized, + tableName = None)( + _cachedColumnBuffers, + sizeInBytesStats, + statsOfPlanToCache) + override def producedAttributes: AttributeSet = outputSet @transient val partitionStatistics = new PartitionStatistics(output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index a93e8a1ad954d..e73e1378d52e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} -import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} +import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.vectorized._ import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -38,6 +38,11 @@ case class InMemoryTableScanExec( override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren + override def doCanonicalize(): SparkPlan = + copy(attributes = attributes.map(QueryPlan.normalizeExprId(_, relation.output)), + predicates = predicates.map(QueryPlan.normalizeExprId(_, relation.output)), + relation = relation.canonicalized.asInstanceOf[InMemoryRelation]) + override def vectorTypes: Option[Seq[String]] = Option(Seq.fill(attributes.length)( if (!conf.offHeapColumnVectorEnabled) { @@ -169,11 +174,13 @@ case class InMemoryTableScanExec( override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) - private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) + // Keeps relation's partition statistics because we don't serialize relation. + private val stats = relation.partitionStatistics + private def statsFor(a: Attribute) = stats.forAttribute(a) // Returned filter predicate should return false iff it is impossible for the input expression // to evaluate to `true' based on statistics collected about this partition batch. - @transient val buildFilter: PartialFunction[Expression, Expression] = { + @transient lazy val buildFilter: PartialFunction[Expression, Expression] = { case And(lhs: Expression, rhs: Expression) if buildFilter.isDefinedAt(lhs) || buildFilter.isDefinedAt(rhs) => (buildFilter.lift(lhs) ++ buildFilter.lift(rhs)).reduce(_ && _) @@ -213,14 +220,14 @@ case class InMemoryTableScanExec( l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _) } - val partitionFilters: Seq[Expression] = { + lazy val partitionFilters: Seq[Expression] = { predicates.flatMap { p => val filter = buildFilter.lift(p) val boundFilter = filter.map( BindReferences.bindReference( _, - relation.partitionStatistics.schema, + stats.schema, allowFailures = true)) boundFilter.foreach(_ => @@ -243,7 +250,7 @@ case class InMemoryTableScanExec( private def filteredCachedBatches(): RDD[CachedBatch] = { // Using these variables here to avoid serialization of entire objects (if referenced directly) // within the map Partitions closure. - val schema = relation.partitionStatistics.schema + val schema = stats.schema val schemaIndex = schema.zipWithIndex val buffers = relation.cachedColumnBuffers diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 49c59cf695dc1..9b745befcb611 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1446,8 +1446,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val data = Seq(("a", null)) checkDataset(data.toDS(), data: _*) } + + test("SPARK-23614: Union produces incorrect results when caching is used") { + val cached = spark.createDataset(Seq(TestDataUnion(1, 2, 3), TestDataUnion(4, 5, 6))).cache() + val group1 = cached.groupBy("x").agg(min(col("y")) as "value") + val group2 = cached.groupBy("x").agg(min(col("z")) as "value") + checkAnswer(group1.union(group2), Row(4, 5) :: Row(1, 2) :: Row(4, 6) :: Row(1, 3) :: Nil) + } } +case class TestDataUnion(x: Int, y: Int, z: Int) + case class SingleData(id: Int) case class DoubleData(id: Int, val1: String) case class TripleData(id: Int, val1: String, val2: Long) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 697d7e6520713..bde2de5b39fd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -125,4 +125,11 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assertConsistency(spark.range(10000).map(i => Random.nextInt(1000).toLong)) } } + + test("SPARK-23614: Fix incorrect reuse exchange when caching is used") { + val cached = spark.createDataset(Seq((1, 2, 3), (4, 5, 6))).cache() + val projection1 = cached.select("_1", "_2").queryExecution.executedPlan + val projection2 = cached.select("_1", "_3").queryExecution.executedPlan + assert(!projection1.sameResult(projection2)) + } } From 5fa438471110afbf4e2174df449ac79e292501f8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 23 Mar 2018 13:59:21 +0800 Subject: [PATCH 0512/2461] [SPARK-23361][YARN] Allow AM to restart after initial tokens expire. Currently, the Spark AM relies on the initial set of tokens created by the submission client to be able to talk to HDFS and other services that require delegation tokens. This means that after those tokens expire, a new AM will fail to start (e.g. when there is an application failure and re-attempts are enabled). This PR makes it so that the first thing the AM does when the user provides a principal and keytab is to create new delegation tokens for use. This makes sure that the AM can be started irrespective of how old the original token set is. It also allows all of the token management to be done by the AM - there is no need for the submission client to set configuration values to tell the AM when to renew tokens. Note that even though in this case the AM will not be using the delegation tokens created by the submission client, those tokens still need to be provided to YARN, since they are used to do log aggregation. To be able to re-use the code in the AMCredentialRenewal for the above purposes, I refactored that class a bit so that it can fetch tokens into a pre-defined UGI, insted of always logging in. Another issue with re-attempts is that, after the fix that allows the AM to restart correctly, new executors would get confused about when to update credentials, because the credential updater used the update time initially set up by the submission code. This could make the executor fail to update credentials in time, since that value would be very out of date in the situation described in the bug. To fix that, I changed the YARN code to use the new RPC-based mechanism for distributing tokens to executors. This allowed the old credential updater code to be removed, and a lot of code in the renewer to be simplified. I also made two currently hardcoded values (the renewal time ratio, and the retry wait) configurable; while this probably never needs to be set by anyone in a production environment, it helps with testing; that's also why they're not documented. Tested on real cluster with a specially crafted application to test this functionality: checked proper access to HDFS, Hive and HBase in cluster mode with token renewal on and AM restarts. Tested things still work in client mode too. Author: Marcelo Vanzin Closes #20657 from vanzin/SPARK-23361. --- .../scala/org/apache/spark/SparkConf.scala | 12 +- .../apache/spark/deploy/SparkHadoopUtil.scala | 32 +- .../CoarseGrainedExecutorBackend.scala | 12 - .../spark/internal/config/package.scala | 12 + .../MesosHadoopDelegationTokenManager.scala | 11 +- .../spark/deploy/yarn/ApplicationMaster.scala | 117 +++---- .../org/apache/spark/deploy/yarn/Client.scala | 102 ++---- .../deploy/yarn/YarnSparkHadoopUtil.scala | 20 -- .../org/apache/spark/deploy/yarn/config.scala | 25 -- .../yarn/security/AMCredentialRenewer.scala | 291 +++++++----------- .../yarn/security/CredentialUpdater.scala | 131 -------- .../YARNHadoopDelegationTokenManager.scala | 9 +- .../cluster/YarnClientSchedulerBackend.scala | 9 +- .../cluster/YarnSchedulerBackend.scala | 10 +- ...ARNHadoopDelegationTokenManagerSuite.scala | 7 +- .../apache/spark/streaming/Checkpoint.scala | 3 - 16 files changed, 238 insertions(+), 565 deletions(-) delete mode 100644 resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index f53b2bed74c6e..129956e9f9ffa 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -603,13 +603,15 @@ private[spark] object SparkConf extends Logging { "Please use spark.kryoserializer.buffer instead. The default value for " + "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + "are no longer accepted. To specify the equivalent now, one may use '64k'."), - DeprecatedConfig("spark.rpc", "2.0", "Not used any more."), + DeprecatedConfig("spark.rpc", "2.0", "Not used anymore."), DeprecatedConfig("spark.scheduler.executorTaskBlacklistTime", "2.1.0", "Please use the new blacklisting options, spark.blacklist.*"), - DeprecatedConfig("spark.yarn.am.port", "2.0.0", "Not used any more"), - DeprecatedConfig("spark.executor.port", "2.0.0", "Not used any more"), + DeprecatedConfig("spark.yarn.am.port", "2.0.0", "Not used anymore"), + DeprecatedConfig("spark.executor.port", "2.0.0", "Not used anymore"), DeprecatedConfig("spark.shuffle.service.index.cache.entries", "2.3.0", - "Not used any more. Please use spark.shuffle.service.index.cache.size") + "Not used anymore. Please use spark.shuffle.service.index.cache.size"), + DeprecatedConfig("spark.yarn.credentials.file.retention.count", "2.4.0", "Not used anymore."), + DeprecatedConfig("spark.yarn.credentials.file.retention.days", "2.4.0", "Not used anymore.") ) Map(configs.map { cfg => (cfg.key -> cfg) } : _*) @@ -748,7 +750,7 @@ private[spark] object SparkConf extends Logging { } if (key.startsWith("spark.akka") || key.startsWith("spark.ssl.akka")) { logWarning( - s"The configuration key $key is not supported any more " + + s"The configuration key $key is not supported anymore " + s"because Spark doesn't use Akka since 2.0") } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 177295fb7af0f..8353e64a619cf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -40,6 +40,7 @@ import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdenti import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.util.Utils /** @@ -146,7 +147,8 @@ class SparkHadoopUtil extends Logging { private[spark] def addDelegationTokens(tokens: Array[Byte], sparkConf: SparkConf) { UserGroupInformation.setConfiguration(newConfiguration(sparkConf)) val creds = deserialize(tokens) - logInfo(s"Adding/updating delegation tokens ${dumpTokens(creds)}") + logInfo("Updating delegation tokens for current user.") + logDebug(s"Adding/updating delegation tokens ${dumpTokens(creds)}") addCurrentUserCredentials(creds) } @@ -321,19 +323,6 @@ class SparkHadoopUtil extends Logging { } } - /** - * Return a fresh Hadoop configuration, bypassing the HDFS cache mechanism. - * This is to prevent the DFSClient from using an old cached token to connect to the NameNode. - */ - private[spark] def getConfBypassingFSCache( - hadoopConf: Configuration, - scheme: String): Configuration = { - val newConf = new Configuration(hadoopConf) - val confKey = s"fs.${scheme}.impl.disable.cache" - newConf.setBoolean(confKey, true) - newConf - } - /** * Dump the credentials' tokens to string values. * @@ -447,16 +436,17 @@ object SparkHadoopUtil { def get: SparkHadoopUtil = instance /** - * Given an expiration date (e.g. for Hadoop Delegation Tokens) return a the date - * when a given fraction of the duration until the expiration date has passed. - * Formula: current time + (fraction * (time until expiration)) + * Given an expiration date for the current set of credentials, calculate the time when new + * credentials should be created. + * * @param expirationDate Drop-dead expiration date - * @param fraction fraction of the time until expiration return - * @return Date when the fraction of the time until expiration has passed + * @param conf Spark configuration + * @return Timestamp when new credentials should be created. */ - private[spark] def getDateOfNextUpdate(expirationDate: Long, fraction: Double): Long = { + private[spark] def nextCredentialRenewalTime(expirationDate: Long, conf: SparkConf): Long = { val ct = System.currentTimeMillis - (ct + (fraction * (expirationDate - ct))).toLong + val ratio = conf.get(CREDENTIALS_RENEWAL_INTERVAL_RATIO) + (ct + (ratio * (expirationDate - ct))).toLong } /** diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 9b62e4b1b7150..48d3630abd1f9 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -213,13 +213,6 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { driverConf.set(key, value) } } - if (driverConf.contains("spark.yarn.credentials.file")) { - logInfo("Will periodically update credentials from: " + - driverConf.get("spark.yarn.credentials.file")) - Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") - .getMethod("startCredentialUpdater", classOf[SparkConf]) - .invoke(null, driverConf) - } cfg.hadoopDelegationCreds.foreach { tokens => SparkHadoopUtil.get.addDelegationTokens(tokens, driverConf) @@ -234,11 +227,6 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } env.rpcEnv.awaitTermination() - if (driverConf.contains("spark.yarn.credentials.file")) { - Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") - .getMethod("stopCredentialUpdater") - .invoke(null) - } } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index a313ad0554a3a..407545aa4a47a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -525,4 +525,16 @@ package object config { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("1g") + private[spark] val CREDENTIALS_RENEWAL_INTERVAL_RATIO = + ConfigBuilder("spark.security.credentials.renewalRatio") + .doc("Ratio of the credential's expiration time when Spark should fetch new credentials.") + .doubleConf + .createWithDefault(0.75d) + + private[spark] val CREDENTIALS_RENEWAL_RETRY_WAIT = + ConfigBuilder("spark.security.credentials.retryWait") + .doc("How long to wait before retrying to fetch new credentials after a failure.") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("1h") + } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala index 7165bfae18a5e..a1bf4f0c048fe 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala @@ -29,6 +29,7 @@ import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.{config, Logging} import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.UpdateDelegationTokens +import org.apache.spark.ui.UIUtils import org.apache.spark.util.ThreadUtils @@ -63,7 +64,7 @@ private[spark] class MesosHadoopDelegationTokenManager( val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) val rt = tokenManager.obtainDelegationTokens(hadoopConf, creds) logInfo(s"Initialized tokens: ${SparkHadoopUtil.get.dumpTokens(creds)}") - (SparkHadoopUtil.get.serialize(creds), SparkHadoopUtil.getDateOfNextUpdate(rt, 0.75)) + (SparkHadoopUtil.get.serialize(creds), SparkHadoopUtil.nextCredentialRenewalTime(rt, conf)) } catch { case e: Exception => logError(s"Failed to fetch Hadoop delegation tokens $e") @@ -104,8 +105,10 @@ private[spark] class MesosHadoopDelegationTokenManager( } catch { case e: Exception => // Log the error and try to write new tokens back in an hour - logWarning("Couldn't broadcast tokens, trying again in an hour", e) - credentialRenewerThread.schedule(this, 1, TimeUnit.HOURS) + val delay = TimeUnit.SECONDS.toMillis(conf.get(config.CREDENTIALS_RENEWAL_RETRY_WAIT)) + logWarning( + s"Couldn't broadcast tokens, trying again in ${UIUtils.formatDuration(delay)}", e) + credentialRenewerThread.schedule(this, delay, TimeUnit.MILLISECONDS) return } scheduleRenewal(this) @@ -135,7 +138,7 @@ private[spark] class MesosHadoopDelegationTokenManager( "related configurations in the target services.") currTime } else { - SparkHadoopUtil.getDateOfNextUpdate(nextRenewalTime, 0.75) + SparkHadoopUtil.nextCredentialRenewalTime(nextRenewalTime, conf) } logInfo(s"Time of next renewal is in ${timeOfNextRenewal - System.currentTimeMillis()} ms") diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 6e35d23def6f0..d04989e138f83 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -29,7 +29,6 @@ import scala.concurrent.duration.Duration import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ @@ -41,7 +40,7 @@ import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.{AMCredentialRenewer, YARNHadoopDelegationTokenManager} +import org.apache.spark.deploy.yarn.security.AMCredentialRenewer import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.rpc._ @@ -79,42 +78,43 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends private val yarnConf = new YarnConfiguration(SparkHadoopUtil.newConfiguration(sparkConf)) - private val ugi = { - val original = UserGroupInformation.getCurrentUser() - - // If a principal and keytab were provided, log in to kerberos, and set up a thread to - // renew the kerberos ticket when needed. Because the UGI API does not expose the TTL - // of the TGT, use a configuration to define how often to check that a relogin is necessary. - // checkTGTAndReloginFromKeytab() is a no-op if the relogin is not yet needed. - val principal = sparkConf.get(PRINCIPAL).orNull - val keytab = sparkConf.get(KEYTAB).orNull - if (principal != null && keytab != null) { - UserGroupInformation.loginUserFromKeytab(principal, keytab) - - val renewer = new Thread() { - override def run(): Unit = Utils.tryLogNonFatalError { - while (true) { - TimeUnit.SECONDS.sleep(sparkConf.get(KERBEROS_RELOGIN_PERIOD)) - UserGroupInformation.getCurrentUser().checkTGTAndReloginFromKeytab() - } - } + private val userClassLoader = { + val classpath = Client.getUserClasspath(sparkConf) + val urls = classpath.map { entry => + new URL("file:" + new File(entry.getPath()).getAbsolutePath()) + } + + if (isClusterMode) { + if (Client.isUserClassPathFirst(sparkConf, isDriver = true)) { + new ChildFirstURLClassLoader(urls, Utils.getContextOrSparkClassLoader) + } else { + new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) } - renewer.setName("am-kerberos-renewer") - renewer.setDaemon(true) - renewer.start() - - // Transfer the original user's tokens to the new user, since that's needed to connect to - // YARN. It also copies over any delegation tokens that might have been created by the - // client, which will then be transferred over when starting executors (until new ones - // are created by the periodic task). - val newUser = UserGroupInformation.getCurrentUser() - SparkHadoopUtil.get.transferCredentials(original, newUser) - newUser } else { - SparkHadoopUtil.get.createSparkUser() + new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) } } + private val credentialRenewer: Option[AMCredentialRenewer] = sparkConf.get(KEYTAB).map { _ => + new AMCredentialRenewer(sparkConf, yarnConf) + } + + private val ugi = credentialRenewer match { + case Some(cr) => + // Set the context class loader so that the token renewer has access to jars distributed + // by the user. + val currentLoader = Thread.currentThread().getContextClassLoader() + Thread.currentThread().setContextClassLoader(userClassLoader) + try { + cr.start() + } finally { + Thread.currentThread().setContextClassLoader(currentLoader) + } + + case _ => + SparkHadoopUtil.get.createSparkUser() + } + private val client = doAsUser { new YarnRMClient() } // Default to twice the number of executors (twice the maximum number of executors if dynamic @@ -148,23 +148,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends // A flag to check whether user has initialized spark context @volatile private var registered = false - private val userClassLoader = { - val classpath = Client.getUserClasspath(sparkConf) - val urls = classpath.map { entry => - new URL("file:" + new File(entry.getPath()).getAbsolutePath()) - } - - if (isClusterMode) { - if (Client.isUserClassPathFirst(sparkConf, isDriver = true)) { - new ChildFirstURLClassLoader(urls, Utils.getContextOrSparkClassLoader) - } else { - new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) - } - } else { - new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) - } - } - // Lock for controlling the allocator (heartbeat) thread. private val allocatorLock = new Object() @@ -189,8 +172,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends // In cluster mode, used to tell the AM when the user's SparkContext has been initialized. private val sparkContextPromise = Promise[SparkContext]() - private var credentialRenewer: AMCredentialRenewer = _ - // Load the list of localized files set by the client. This is used when launching executors, // and is loaded here so that these configs don't pollute the Web UI's environment page in // cluster mode. @@ -316,31 +297,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } } - // If the credentials file config is present, we must periodically renew tokens. So create - // a new AMDelegationTokenRenewer - if (sparkConf.contains(CREDENTIALS_FILE_PATH)) { - // Start a short-lived thread for AMCredentialRenewer, the only purpose is to set the - // classloader so that main jar and secondary jars could be used by AMCredentialRenewer. - val credentialRenewerThread = new Thread { - setName("AMCredentialRenewerStarter") - setContextClassLoader(userClassLoader) - - override def run(): Unit = { - val credentialManager = new YARNHadoopDelegationTokenManager( - sparkConf, - yarnConf, - conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) - - val credentialRenewer = - new AMCredentialRenewer(sparkConf, yarnConf, credentialManager) - credentialRenewer.scheduleLoginFromKeytab() - } - } - - credentialRenewerThread.start() - credentialRenewerThread.join() - } - if (isClusterMode) { runDriver() } else { @@ -409,9 +365,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends logDebug("shutting down user thread") userClassThread.interrupt() } - if (!inShutdown && credentialRenewer != null) { - credentialRenewer.stop() - credentialRenewer = null + if (!inShutdown) { + credentialRenewer.foreach(_.stop()) } } } @@ -468,6 +423,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends securityMgr, localResources) + credentialRenewer.foreach(_.setDriverRef(driverRef)) + // Initialize the AM endpoint *after* the allocator has been initialized. This ensures // that when the driver sends an initial executor request (e.g. after an AM restart), // the allocator is ready to service requests. diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 28087dee831d1..5763c3dbc5a8a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -93,11 +93,21 @@ private[spark] class Client( private val distCacheMgr = new ClientDistributedCacheManager() - private var loginFromKeytab = false - private var principal: String = null - private var keytab: String = null - private var credentials: Credentials = null - private var amKeytabFileName: String = null + private val principal = sparkConf.get(PRINCIPAL).orNull + private val keytab = sparkConf.get(KEYTAB).orNull + private val loginFromKeytab = principal != null + private val amKeytabFileName: String = { + require((principal == null) == (keytab == null), + "Both principal and keytab must be defined, or neither.") + if (loginFromKeytab) { + logInfo(s"Kerberos credentials: principal = $principal, keytab = $keytab") + // Generate a file name that can be used for the keytab file, that does not conflict + // with any user file. + new File(keytab).getName() + "-" + UUID.randomUUID().toString + } else { + null + } + } private val launcherBackend = new LauncherBackend() { override protected def conf: SparkConf = sparkConf @@ -120,11 +130,6 @@ private[spark] class Client( private val appStagingBaseDir = sparkConf.get(STAGING_DIR).map { new Path(_) } .getOrElse(FileSystem.get(hadoopConf).getHomeDirectory()) - private val credentialManager = new YARNHadoopDelegationTokenManager( - sparkConf, - hadoopConf, - conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) - def reportLauncherState(state: SparkAppHandle.State): Unit = { launcherBackend.setState(state) } @@ -145,9 +150,6 @@ private[spark] class Client( var appId: ApplicationId = null try { launcherBackend.connect() - // Setup the credentials before doing anything else, - // so we have don't have issues at any point. - setupCredentials() yarnClient.init(hadoopConf) yarnClient.start() @@ -288,8 +290,26 @@ private[spark] class Client( appContext } - /** Set up security tokens for launching our ApplicationMaster container. */ + /** + * Set up security tokens for launching our ApplicationMaster container. + * + * This method will obtain delegation tokens from all the registered providers, and set them in + * the AM's launch context. + */ private def setupSecurityToken(amContainer: ContainerLaunchContext): Unit = { + val credentials = UserGroupInformation.getCurrentUser().getCredentials() + val credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf) + credentialManager.obtainDelegationTokens(hadoopConf, credentials) + + // When using a proxy user, copy the delegation tokens to the user's credentials. Avoid + // that for regular users, since in those case the user already has access to the TGT, + // and adding delegation tokens could lead to expired or cancelled tokens being used + // later, as reported in SPARK-15754. + val currentUser = UserGroupInformation.getCurrentUser() + if (SparkHadoopUtil.get.isProxyUser(currentUser)) { + currentUser.addCredentials(credentials) + } + val dob = new DataOutputBuffer credentials.writeTokenStorageToStream(dob) amContainer.setTokens(ByteBuffer.wrap(dob.getData)) @@ -384,36 +404,6 @@ private[spark] class Client( // and add them as local resources to the application master. val fs = destDir.getFileSystem(hadoopConf) - // Merge credentials obtained from registered providers - val nearestTimeOfNextRenewal = credentialManager.obtainDelegationTokens(hadoopConf, credentials) - - if (credentials != null) { - // Add credentials to current user's UGI, so that following operations don't need to use the - // Kerberos tgt to get delegations again in the client side. - val currentUser = UserGroupInformation.getCurrentUser() - if (SparkHadoopUtil.get.isProxyUser(currentUser)) { - currentUser.addCredentials(credentials) - } - logDebug(SparkHadoopUtil.get.dumpTokens(credentials).mkString("\n")) - } - - // If we use principal and keytab to login, also credentials can be renewed some time - // after current time, we should pass the next renewal and updating time to credential - // renewer and updater. - if (loginFromKeytab && nearestTimeOfNextRenewal > System.currentTimeMillis() && - nearestTimeOfNextRenewal != Long.MaxValue) { - - // Valid renewal time is 75% of next renewal time, and the valid update time will be - // slightly later then renewal time (80% of next renewal time). This is to make sure - // credentials are renewed and updated before expired. - val currTime = System.currentTimeMillis() - val renewalTime = (nearestTimeOfNextRenewal - currTime) * 0.75 + currTime - val updateTime = (nearestTimeOfNextRenewal - currTime) * 0.8 + currTime - - sparkConf.set(CREDENTIALS_RENEWAL_TIME, renewalTime.toLong) - sparkConf.set(CREDENTIALS_UPDATE_TIME, updateTime.toLong) - } - // Used to keep track of URIs added to the distributed cache. If the same URI is added // multiple times, YARN will fail to launch containers for the app with an internal // error. @@ -793,11 +783,6 @@ private[spark] class Client( populateClasspath(args, hadoopConf, sparkConf, env, sparkConf.get(DRIVER_CLASS_PATH)) env("SPARK_YARN_STAGING_DIR") = stagingDirPath.toString env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() - if (loginFromKeytab) { - val credentialsFile = "credentials-" + UUID.randomUUID().toString - sparkConf.set(CREDENTIALS_FILE_PATH, new Path(stagingDirPath, credentialsFile).toString) - logInfo(s"Credentials file set to: $credentialsFile") - } // Pick up any environment variables for the AM provided through spark.yarn.appMasterEnv.* val amEnvPrefix = "spark.yarn.appMasterEnv." @@ -1014,25 +999,6 @@ private[spark] class Client( amContainer } - def setupCredentials(): Unit = { - loginFromKeytab = sparkConf.contains(PRINCIPAL.key) - if (loginFromKeytab) { - principal = sparkConf.get(PRINCIPAL).get - keytab = sparkConf.get(KEYTAB).orNull - - require(keytab != null, "Keytab must be specified when principal is specified.") - logInfo("Attempting to login to the Kerberos" + - s" using principal: $principal and keytab: $keytab") - val f = new File(keytab) - // Generate a file name that can be used for the keytab file, that does not conflict - // with any user file. - amKeytabFileName = f.getName + "-" + UUID.randomUUID().toString - sparkConf.set(PRINCIPAL.key, principal) - } - // Defensive copy of the credentials - credentials = new Credentials(UserGroupInformation.getCurrentUser.getCredentials) - } - /** * Report the state of an application until it has exited, either successfully or * due to some failure, then return a pair of the yarn application state (FINISHED, FAILED, diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index f406fabd61860..8eda6cb1277c5 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -30,7 +30,6 @@ import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.CredentialUpdater import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager import org.apache.spark.internal.config._ import org.apache.spark.launcher.YarnCommandBuilderUtils @@ -38,8 +37,6 @@ import org.apache.spark.util.Utils object YarnSparkHadoopUtil { - private var credentialUpdater: CredentialUpdater = _ - // Additional memory overhead // 10% was arrived at experimentally. In the interest of minimizing memory waste while covering // the common cases. Memory overhead tends to grow with container size. @@ -206,21 +203,4 @@ object YarnSparkHadoopUtil { filesystemsToAccess + stagingFS } - def startCredentialUpdater(sparkConf: SparkConf): Unit = { - val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) - val credentialManager = new YARNHadoopDelegationTokenManager( - sparkConf, - hadoopConf, - conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) - credentialUpdater = new CredentialUpdater(sparkConf, hadoopConf, credentialManager) - credentialUpdater.start() - } - - def stopCredentialUpdater(): Unit = { - if (credentialUpdater != null) { - credentialUpdater.stop() - credentialUpdater = null - } - } - } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 3ba3ae5ab4401..1a99b3bd57672 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -231,16 +231,6 @@ package object config { /* Security configuration. */ - private[spark] val CREDENTIAL_FILE_MAX_COUNT = - ConfigBuilder("spark.yarn.credentials.file.retention.count") - .intConf - .createWithDefault(5) - - private[spark] val CREDENTIALS_FILE_MAX_RETENTION = - ConfigBuilder("spark.yarn.credentials.file.retention.days") - .intConf - .createWithDefault(5) - private[spark] val NAMENODES_TO_ACCESS = ConfigBuilder("spark.yarn.access.namenodes") .doc("Extra NameNode URLs for which to request delegation tokens. The NameNode that hosts " + "fs.defaultFS does not need to be listed here.") @@ -271,11 +261,6 @@ package object config { /* Private configs. */ - private[spark] val CREDENTIALS_FILE_PATH = ConfigBuilder("spark.yarn.credentials.file") - .internal() - .stringConf - .createWithDefault(null) - // Internal config to propagate the location of the user's jar to the driver/executors private[spark] val APP_JAR = ConfigBuilder("spark.yarn.user.jar") .internal() @@ -329,16 +314,6 @@ package object config { .stringConf .createOptional - private[spark] val CREDENTIALS_RENEWAL_TIME = ConfigBuilder("spark.yarn.credentials.renewalTime") - .internal() - .timeConf(TimeUnit.MILLISECONDS) - .createWithDefault(Long.MaxValue) - - private[spark] val CREDENTIALS_UPDATE_TIME = ConfigBuilder("spark.yarn.credentials.updateTime") - .internal() - .timeConf(TimeUnit.MILLISECONDS) - .createWithDefault(Long.MaxValue) - private[spark] val KERBEROS_RELOGIN_PERIOD = ConfigBuilder("spark.yarn.kerberos.relogin.period") .timeConf(TimeUnit.SECONDS) .createWithDefaultString("1m") diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala index eaf2cff111a49..bc8d47dbd54c6 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala @@ -18,221 +18,160 @@ package org.apache.spark.deploy.yarn.security import java.security.PrivilegedExceptionAction import java.util.concurrent.{ScheduledExecutorService, TimeUnit} +import java.util.concurrent.atomic.AtomicReference import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.security.HadoopDelegationTokenManager -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.UpdateDelegationTokens +import org.apache.spark.ui.UIUtils import org.apache.spark.util.ThreadUtils /** - * The following methods are primarily meant to make sure long-running apps like Spark - * Streaming apps can run without interruption while accessing secured services. The - * scheduleLoginFromKeytab method is called on the AM to get the new credentials. - * This method wakes up a thread that logs into the KDC - * once 75% of the renewal interval of the original credentials used for the container - * has elapsed. It then obtains new credentials and writes them to HDFS in a - * pre-specified location - the prefix of which is specified in the sparkConf by - * spark.yarn.credentials.file (so the file(s) would be named c-timestamp1-1, c-timestamp2-2 etc. - * - each update goes to a new file, with a monotonically increasing suffix), also the - * timestamp1, timestamp2 here indicates the time of next update for CredentialUpdater. - * After this, the credentials are renewed once 75% of the new tokens renewal interval has elapsed. + * A manager tasked with periodically updating delegation tokens needed by the application. * - * On the executor and driver (yarn client mode) side, the updateCredentialsIfRequired method is - * called once 80% of the validity of the original credentials has elapsed. At that time the - * executor finds the credentials file with the latest timestamp and checks if it has read those - * credentials before (by keeping track of the suffix of the last file it read). If a new file has - * appeared, it will read the credentials and update the currently running UGI with it. This - * process happens again once 80% of the validity of this has expired. + * This manager is meant to make sure long-running apps (such as Spark Streaming apps) can run + * without interruption while accessing secured services. It periodically logs in to the KDC with + * user-provided credentials, and contacts all the configured secure services to obtain delegation + * tokens to be distributed to the rest of the application. + * + * This class will manage the kerberos login, by renewing the TGT when needed. Because the UGI API + * does not expose the TTL of the TGT, a configuration controls how often to check that a relogin is + * necessary. This is done reasonably often since the check is a no-op when the relogin is not yet + * needed. The check period can be overridden in the configuration. + * + * New delegation tokens are created once 75% of the renewal interval of the original tokens has + * elapsed. The new tokens are sent to the Spark driver endpoint once it's registered with the AM. + * The driver is tasked with distributing the tokens to other processes that might need them. */ private[yarn] class AMCredentialRenewer( sparkConf: SparkConf, - hadoopConf: Configuration, - credentialManager: YARNHadoopDelegationTokenManager) extends Logging { + hadoopConf: Configuration) extends Logging { - private var lastCredentialsFileSuffix = 0 + private val principal = sparkConf.get(PRINCIPAL).get + private val keytab = sparkConf.get(KEYTAB).get + private val credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf) - private val credentialRenewerThread: ScheduledExecutorService = + private val renewalExecutor: ScheduledExecutorService = ThreadUtils.newDaemonSingleThreadScheduledExecutor("Credential Refresh Thread") - private val hadoopUtil = SparkHadoopUtil.get + private val driverRef = new AtomicReference[RpcEndpointRef]() - private val credentialsFile = sparkConf.get(CREDENTIALS_FILE_PATH) - private val daysToKeepFiles = sparkConf.get(CREDENTIALS_FILE_MAX_RETENTION) - private val numFilesToKeep = sparkConf.get(CREDENTIAL_FILE_MAX_COUNT) - private val freshHadoopConf = - hadoopUtil.getConfBypassingFSCache(hadoopConf, new Path(credentialsFile).toUri.getScheme) + private val renewalTask = new Runnable() { + override def run(): Unit = { + updateTokensTask() + } + } - @volatile private var timeOfNextRenewal: Long = sparkConf.get(CREDENTIALS_RENEWAL_TIME) + def setDriverRef(ref: RpcEndpointRef): Unit = { + driverRef.set(ref) + } /** - * Schedule a login from the keytab and principal set using the --principal and --keytab - * arguments to spark-submit. This login happens only when the credentials of the current user - * are about to expire. This method reads spark.yarn.principal and spark.yarn.keytab from - * SparkConf to do the login. This method is a no-op in non-YARN mode. + * Start the token renewer. Upon start, the renewer will: * + * - log in the configured user, and set up a task to keep that user's ticket renewed + * - obtain delegation tokens from all available providers + * - schedule a periodic task to update the tokens when needed. + * + * @return The newly logged in user. */ - private[spark] def scheduleLoginFromKeytab(): Unit = { - val principal = sparkConf.get(PRINCIPAL).get - val keytab = sparkConf.get(KEYTAB).get - - /** - * Schedule re-login and creation of new credentials. If credentials have already expired, this - * method will synchronously create new ones. - */ - def scheduleRenewal(runnable: Runnable): Unit = { - // Run now! - val remainingTime = timeOfNextRenewal - System.currentTimeMillis() - if (remainingTime <= 0) { - logInfo("Credentials have expired, creating new ones now.") - runnable.run() - } else { - logInfo(s"Scheduling login from keytab in $remainingTime millis.") - credentialRenewerThread.schedule(runnable, remainingTime, TimeUnit.MILLISECONDS) + def start(): UserGroupInformation = { + val originalCreds = UserGroupInformation.getCurrentUser().getCredentials() + val ugi = doLogin() + + val tgtRenewalTask = new Runnable() { + override def run(): Unit = { + ugi.checkTGTAndReloginFromKeytab() } } + val tgtRenewalPeriod = sparkConf.get(KERBEROS_RELOGIN_PERIOD) + renewalExecutor.scheduleAtFixedRate(tgtRenewalTask, tgtRenewalPeriod, tgtRenewalPeriod, + TimeUnit.SECONDS) - // This thread periodically runs on the AM to update the credentials on HDFS. - val credentialRenewerRunnable = - new Runnable { - override def run(): Unit = { - try { - writeNewCredentialsToHDFS(principal, keytab) - cleanupOldFiles() - } catch { - case e: Exception => - // Log the error and try to write new tokens back in an hour - logWarning("Failed to write out new credentials to HDFS, will try again in an " + - "hour! If this happens too often tasks will fail.", e) - credentialRenewerThread.schedule(this, 1, TimeUnit.HOURS) - return - } - scheduleRenewal(this) - } - } - // Schedule update of credentials. This handles the case of updating the credentials right now - // as well, since the renewal interval will be 0, and the thread will get scheduled - // immediately. - scheduleRenewal(credentialRenewerRunnable) + val creds = obtainTokensAndScheduleRenewal(ugi) + ugi.addCredentials(creds) + + // Transfer the original user's tokens to the new user, since that's needed to connect to + // YARN. Explicitly avoid overwriting tokens that already exist in the current user's + // credentials, since those were freshly obtained above (see SPARK-23361). + val existing = ugi.getCredentials() + existing.mergeAll(originalCreds) + ugi.addCredentials(existing) + + ugi + } + + def stop(): Unit = { + renewalExecutor.shutdown() + } + + private def scheduleRenewal(delay: Long): Unit = { + val _delay = math.max(0, delay) + logInfo(s"Scheduling login from keytab in ${UIUtils.formatDuration(delay)}.") + renewalExecutor.schedule(renewalTask, _delay, TimeUnit.MILLISECONDS) } - // Keeps only files that are newer than daysToKeepFiles days, and deletes everything else. At - // least numFilesToKeep files are kept for safety - private def cleanupOldFiles(): Unit = { - import scala.concurrent.duration._ + /** + * Periodic task to login to the KDC and create new delegation tokens. Re-schedules itself + * to fetch the next set of tokens when needed. + */ + private def updateTokensTask(): Unit = { try { - val remoteFs = FileSystem.get(freshHadoopConf) - val credentialsPath = new Path(credentialsFile) - val thresholdTime = System.currentTimeMillis() - (daysToKeepFiles.days).toMillis - hadoopUtil.listFilesSorted( - remoteFs, credentialsPath.getParent, - credentialsPath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) - .dropRight(numFilesToKeep) - .takeWhile(_.getModificationTime < thresholdTime) - .foreach(x => remoteFs.delete(x.getPath, true)) + val freshUGI = doLogin() + val creds = obtainTokensAndScheduleRenewal(freshUGI) + val tokens = SparkHadoopUtil.get.serialize(creds) + + val driver = driverRef.get() + if (driver != null) { + logInfo("Updating delegation tokens.") + driver.send(UpdateDelegationTokens(tokens)) + } else { + // This shouldn't really happen, since the driver should register way before tokens expire + // (or the AM should time out the application). + logWarning("Delegation tokens close to expiration but no driver has registered yet.") + SparkHadoopUtil.get.addDelegationTokens(tokens, sparkConf) + } } catch { - // Such errors are not fatal, so don't throw. Make sure they are logged though case e: Exception => - logWarning("Error while attempting to cleanup old credentials. If you are seeing many " + - "such warnings there may be an issue with your HDFS cluster.", e) + val delay = TimeUnit.SECONDS.toMillis(sparkConf.get(CREDENTIALS_RENEWAL_RETRY_WAIT)) + logWarning(s"Failed to update tokens, will try again in ${UIUtils.formatDuration(delay)}!" + + " If this happens too often tasks will fail.", e) + scheduleRenewal(delay) } } - private def writeNewCredentialsToHDFS(principal: String, keytab: String): Unit = { - // Keytab is copied by YARN to the working directory of the AM, so full path is - // not needed. - - // HACK: - // HDFS will not issue new delegation tokens, if the Credentials object - // passed in already has tokens for that FS even if the tokens are expired (it really only - // checks if there are tokens for the service, and not if they are valid). So the only real - // way to get new tokens is to make sure a different Credentials object is used each time to - // get new tokens and then the new tokens are copied over the current user's Credentials. - // So: - // - we login as a different user and get the UGI - // - use that UGI to get the tokens (see doAs block below) - // - copy the tokens over to the current user's credentials (this will overwrite the tokens - // in the current user's Credentials object for this FS). - // The login to KDC happens each time new tokens are required, but this is rare enough to not - // have to worry about (like once every day or so). This makes this code clearer than having - // to login and then relogin every time (the HDFS API may not relogin since we don't use this - // UGI directly for HDFS communication. - logInfo(s"Attempting to login to KDC using principal: $principal") - val keytabLoggedInUGI = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) - logInfo("Successfully logged into KDC.") - val tempCreds = keytabLoggedInUGI.getCredentials - val credentialsPath = new Path(credentialsFile) - val dst = credentialsPath.getParent - var nearestNextRenewalTime = Long.MaxValue - keytabLoggedInUGI.doAs(new PrivilegedExceptionAction[Void] { - // Get a copy of the credentials - override def run(): Void = { - nearestNextRenewalTime = credentialManager.obtainDelegationTokens( - freshHadoopConf, - tempCreds) - null + /** + * Obtain new delegation tokens from the available providers. Schedules a new task to fetch + * new tokens before the new set expires. + * + * @return Credentials containing the new tokens. + */ + private def obtainTokensAndScheduleRenewal(ugi: UserGroupInformation): Credentials = { + ugi.doAs(new PrivilegedExceptionAction[Credentials]() { + override def run(): Credentials = { + val creds = new Credentials() + val nextRenewal = credentialManager.obtainDelegationTokens(hadoopConf, creds) + + val timeToWait = SparkHadoopUtil.nextCredentialRenewalTime(nextRenewal, sparkConf) - + System.currentTimeMillis() + scheduleRenewal(timeToWait) + creds } }) - - val currTime = System.currentTimeMillis() - val timeOfNextUpdate = if (nearestNextRenewalTime <= currTime) { - // If next renewal time is earlier than current time, we set next renewal time to current - // time, this will trigger next renewal immediately. Also set next update time to current - // time. There still has a gap between token renewal and update will potentially introduce - // issue. - logWarning(s"Next credential renewal time ($nearestNextRenewalTime) is earlier than " + - s"current time ($currTime), which is unexpected, please check your credential renewal " + - "related configurations in the target services.") - timeOfNextRenewal = currTime - currTime - } else { - // Next valid renewal time is about 75% of credential renewal time, and update time is - // slightly later than valid renewal time (80% of renewal time). - timeOfNextRenewal = - SparkHadoopUtil.getDateOfNextUpdate(nearestNextRenewalTime, 0.75) - SparkHadoopUtil.getDateOfNextUpdate(nearestNextRenewalTime, 0.8) - } - - // Add the temp credentials back to the original ones. - UserGroupInformation.getCurrentUser.addCredentials(tempCreds) - val remoteFs = FileSystem.get(freshHadoopConf) - // If lastCredentialsFileSuffix is 0, then the AM is either started or restarted. If the AM - // was restarted, then the lastCredentialsFileSuffix might be > 0, so find the newest file - // and update the lastCredentialsFileSuffix. - if (lastCredentialsFileSuffix == 0) { - hadoopUtil.listFilesSorted( - remoteFs, credentialsPath.getParent, - credentialsPath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) - .lastOption.foreach { status => - lastCredentialsFileSuffix = hadoopUtil.getSuffixForCredentialsPath(status.getPath) - } - } - val nextSuffix = lastCredentialsFileSuffix + 1 - - val tokenPathStr = - credentialsFile + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + - timeOfNextUpdate.toLong.toString + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + - nextSuffix - val tokenPath = new Path(tokenPathStr) - val tempTokenPath = new Path(tokenPathStr + SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) - - logInfo("Writing out delegation tokens to " + tempTokenPath.toString) - val credentials = UserGroupInformation.getCurrentUser.getCredentials - credentials.writeTokenStorageFile(tempTokenPath, freshHadoopConf) - logInfo(s"Delegation Tokens written out successfully. Renaming file to $tokenPathStr") - remoteFs.rename(tempTokenPath, tokenPath) - logInfo("Delegation token file rename complete.") - lastCredentialsFileSuffix = nextSuffix } - def stop(): Unit = { - credentialRenewerThread.shutdown() + private def doLogin(): UserGroupInformation = { + logInfo(s"Attempting to login to KDC using principal: $principal") + val ugi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) + logInfo("Successfully logged into KDC.") + ugi } + } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala deleted file mode 100644 index fe173dffc22a8..0000000000000 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn.security - -import java.util.concurrent.{Executors, TimeUnit} - -import scala.util.control.NonFatal - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.security.{Credentials, UserGroupInformation} - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.internal.Logging -import org.apache.spark.util.{ThreadUtils, Utils} - -private[spark] class CredentialUpdater( - sparkConf: SparkConf, - hadoopConf: Configuration, - credentialManager: YARNHadoopDelegationTokenManager) extends Logging { - - @volatile private var lastCredentialsFileSuffix = 0 - - private val credentialsFile = sparkConf.get(CREDENTIALS_FILE_PATH) - private val freshHadoopConf = - SparkHadoopUtil.get.getConfBypassingFSCache( - hadoopConf, new Path(credentialsFile).toUri.getScheme) - - private val credentialUpdater = - Executors.newSingleThreadScheduledExecutor( - ThreadUtils.namedThreadFactory("Credential Refresh Thread")) - - // This thread wakes up and picks up new credentials from HDFS, if any. - private val credentialUpdaterRunnable = - new Runnable { - override def run(): Unit = Utils.logUncaughtExceptions(updateCredentialsIfRequired()) - } - - /** Start the credential updater task */ - def start(): Unit = { - val startTime = sparkConf.get(CREDENTIALS_UPDATE_TIME) - val remainingTime = startTime - System.currentTimeMillis() - if (remainingTime <= 0) { - credentialUpdater.schedule(credentialUpdaterRunnable, 1, TimeUnit.MINUTES) - } else { - logInfo(s"Scheduling credentials refresh from HDFS in $remainingTime ms.") - credentialUpdater.schedule(credentialUpdaterRunnable, remainingTime, TimeUnit.MILLISECONDS) - } - } - - private def updateCredentialsIfRequired(): Unit = { - val timeToNextUpdate = try { - val credentialsFilePath = new Path(credentialsFile) - val remoteFs = FileSystem.get(freshHadoopConf) - SparkHadoopUtil.get.listFilesSorted( - remoteFs, credentialsFilePath.getParent, - credentialsFilePath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) - .lastOption.map { credentialsStatus => - val suffix = SparkHadoopUtil.get.getSuffixForCredentialsPath(credentialsStatus.getPath) - if (suffix > lastCredentialsFileSuffix) { - logInfo("Reading new credentials from " + credentialsStatus.getPath) - val newCredentials = getCredentialsFromHDFSFile(remoteFs, credentialsStatus.getPath) - lastCredentialsFileSuffix = suffix - UserGroupInformation.getCurrentUser.addCredentials(newCredentials) - logInfo("Credentials updated from credentials file.") - - val remainingTime = (getTimeOfNextUpdateFromFileName(credentialsStatus.getPath) - - System.currentTimeMillis()) - if (remainingTime <= 0) TimeUnit.MINUTES.toMillis(1) else remainingTime - } else { - // If current credential file is older than expected, sleep 1 hour and check again. - TimeUnit.HOURS.toMillis(1) - } - }.getOrElse { - // Wait for 1 minute to check again if there's no credential file currently - TimeUnit.MINUTES.toMillis(1) - } - } catch { - // Since the file may get deleted while we are reading it, catch the Exception and come - // back in an hour to try again - case NonFatal(e) => - logWarning("Error while trying to update credentials, will try again in 1 hour", e) - TimeUnit.HOURS.toMillis(1) - } - - logInfo(s"Scheduling credentials refresh from HDFS in $timeToNextUpdate ms.") - credentialUpdater.schedule( - credentialUpdaterRunnable, timeToNextUpdate, TimeUnit.MILLISECONDS) - } - - private def getCredentialsFromHDFSFile(remoteFs: FileSystem, tokenPath: Path): Credentials = { - val stream = remoteFs.open(tokenPath) - try { - val newCredentials = new Credentials() - newCredentials.readTokenStorageStream(stream) - newCredentials - } finally { - stream.close() - } - } - - private def getTimeOfNextUpdateFromFileName(credentialsPath: Path): Long = { - val name = credentialsPath.getName - val index = name.lastIndexOf(SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM) - val slice = name.substring(0, index) - val last2index = slice.lastIndexOf(SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM) - name.substring(last2index + 1, index).toLong - } - - def stop(): Unit = { - credentialUpdater.shutdown() - } - -} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala index 163cfb4eb8624..d4eeb6bbcf886 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala @@ -22,11 +22,11 @@ import java.util.ServiceLoader import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.security.Credentials import org.apache.spark.SparkConf import org.apache.spark.deploy.security.HadoopDelegationTokenManager +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.util.Utils @@ -37,11 +37,10 @@ import org.apache.spark.util.Utils */ private[yarn] class YARNHadoopDelegationTokenManager( sparkConf: SparkConf, - hadoopConf: Configuration, - fileSystems: Configuration => Set[FileSystem]) extends Logging { + hadoopConf: Configuration) extends Logging { - private val delegationTokenManager = - new HadoopDelegationTokenManager(sparkConf, hadoopConf, fileSystems) + private val delegationTokenManager = new HadoopDelegationTokenManager(sparkConf, hadoopConf, + conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) // public for testing val credentialProviders = getCredentialProviders diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 0c6206eebe41d..06e54a2eaf95a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.yarn.api.records.YarnApplicationState import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil} +import org.apache.spark.deploy.yarn.{Client, ClientArguments} import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.launcher.SparkAppHandle @@ -62,12 +62,6 @@ private[spark] class YarnClientSchedulerBackend( super.start() waitForApplication() - // SPARK-8851: In yarn-client mode, the AM still does the credentials refresh. The driver - // reads the credentials from HDFS, just like the executors and updates its own credentials - // cache. - if (conf.contains("spark.yarn.credentials.file")) { - YarnSparkHadoopUtil.startCredentialUpdater(conf) - } monitorThread = asyncMonitorApplication() monitorThread.start() } @@ -153,7 +147,6 @@ private[spark] class YarnClientSchedulerBackend( client.reportLauncherState(SparkAppHandle.State.FINISHED) super.stop() - YarnSparkHadoopUtil.stopCredentialUpdater() client.stop() logInfo("Stopped") } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index bb615c36cd97f..63bea3e7a5003 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -24,9 +24,11 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.util.{Failure, Success} import scala.util.control.NonFatal +import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} import org.apache.spark.SparkContext +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.rpc._ import org.apache.spark.scheduler._ @@ -70,6 +72,7 @@ private[spark] abstract class YarnSchedulerBackend( /** Scheduler extension services. */ private val services: SchedulerExtensionServices = new SchedulerExtensionServices() + /** * Bind to YARN. This *must* be done before calling [[start()]]. * @@ -263,8 +266,13 @@ private[spark] abstract class YarnSchedulerBackend( logWarning(s"Requesting driver to remove executor $executorId for reason $reason") driverEndpoint.send(r) } - } + case u @ UpdateDelegationTokens(tokens) => + // Add the tokens to the current user and send a message to the scheduler so that it + // notifies all registered executors of the new tokens. + SparkHadoopUtil.get.addDelegationTokens(tokens, sc.conf) + driverEndpoint.send(u) + } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case r: RequestExecutors => diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala index 3c7cdc0f1dab8..9fa749b14c98c 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala @@ -22,7 +22,6 @@ import org.apache.hadoop.security.Credentials import org.scalatest.Matchers import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers { private var credentialManager: YARNHadoopDelegationTokenManager = null @@ -36,11 +35,7 @@ class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers } test("Correctly loads credential providers") { - credentialManager = new YARNHadoopDelegationTokenManager( - sparkConf, - hadoopConf, - conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) - + credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf) credentialManager.credentialProviders.get("yarn-test") should not be (None) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index aed67a5027433..3703a87cdb9ab 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -57,9 +57,6 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.yarn.jars", "spark.yarn.keytab", "spark.yarn.principal", - "spark.yarn.credentials.file", - "spark.yarn.credentials.renewalTime", - "spark.yarn.credentials.updateTime", "spark.ui.filters", "spark.mesos.driver.frameworkId") From 92e952557dbd8a170d66d615e25c6c6a8399dd43 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 23 Mar 2018 21:01:07 +0900 Subject: [PATCH 0513/2461] [MINOR][R] Fix R lint failure ## What changes were proposed in this pull request? The lint failure bugged me: ```R R/SQLContext.R:715:97: style: Trailing whitespace is superfluous. #' file-based streaming data source. \code{timeZone} to indicate a timezone to be used to ^ tests/fulltests/test_streaming.R:239:45: style: Commas should always have a space after. expect_equal(times[order(times$eventTime),][1, 2], 2) ^ lintr checks failed. ``` and I actually saw https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.6-ubuntu-test/500/console too. If I understood correctly, there is a try about moving to Unbuntu one. ## How was this patch tested? Manually tested by `./dev/lint-r`: ``` ... lintr checks passed. ``` Author: hyukjinkwon Closes #20879 from HyukjinKwon/minor-r-lint. --- R/pkg/R/SQLContext.R | 2 +- R/pkg/tests/fulltests/test_streaming.R | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index ebec0ce3d1920..429dd5d565492 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -712,7 +712,7 @@ read.jdbc <- function(url, tableName, #' @param schema The data schema defined in structType or a DDL-formatted string, this is #' required for file-based streaming data source #' @param ... additional external data source specific named options, for instance \code{path} for -#' file-based streaming data source. \code{timeZone} to indicate a timezone to be used to +#' file-based streaming data source. \code{timeZone} to indicate a timezone to be used to #' parse timestamps in the JSON/CSV data sources or partition values; If it isn't set, it #' uses the default value, session local timezone. #' @return SparkDataFrame diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R index a354d50c6b54e..bfb1a046490ec 100644 --- a/R/pkg/tests/fulltests/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -236,7 +236,7 @@ test_that("Watermark", { times <- collect(sql("SELECT * FROM times")) # looks like write timing can affect the first bucket; but it should be t - expect_equal(times[order(times$eventTime),][1, 2], 2) + expect_equal(times[order(times$eventTime), ][1, 2], 2) stopQuery(q) unlink(parquetPath) From 6ac4fba69290e1c7de2c0a5863f224981dedb919 Mon Sep 17 00:00:00 2001 From: arucard21 Date: Fri, 23 Mar 2018 21:02:34 +0900 Subject: [PATCH 0514/2461] [SPARK-23769][CORE] Remove comments that unnecessarily disable Scalastyle check ## What changes were proposed in this pull request? We re-enabled the Scalastyle checker on a line of code. It was previously disabled, but it does not violate any of the rules. So there's no reason to disable the Scalastyle checker here. ## How was this patch tested? We tested this by running `build/mvn scalastyle:check` after removing the comments that disable the checker. This check passed with no errors or warnings for Spark Core ``` [INFO] [INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project Core 2.4.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ [INFO] [INFO] --- scalastyle-maven-plugin:1.0.0:check (default-cli) spark-core_2.11 --- Saving to outputFile=/spark/core/target/scalastyle-output.xml Processed 485 file(s) Found 0 errors Found 0 warnings Found 0 infos ``` We did not run all tests (with `dev/run-tests`) since this Scalastyle check seemed sufficient. ## Co-contributors: chialun-yeh Hrayo712 vpourquie Author: arucard21 Closes #20880 from arucard21/scalastyle_util. --- .../org/apache/spark/storage/BlockReplicationPolicy.scala | 4 +--- .../main/scala/org/apache/spark/util/CompletionIterator.scala | 2 -- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala index 353eac60df171..0bacc34cdfd90 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala @@ -54,10 +54,9 @@ trait BlockReplicationPolicy { } object BlockReplicationUtils { - // scalastyle:off line.size.limit /** * Uses sampling algorithm by Robert Floyd. Finds a random sample in O(n) while - * minimizing space usage. Please see + * minimizing space usage. Please see * here. * * @param n total number of indices @@ -65,7 +64,6 @@ object BlockReplicationUtils { * @param r random number generator * @return list of m random unique indices */ - // scalastyle:on line.size.limit private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = { val indices = (n - m + 1 to n).foldLeft(mutable.LinkedHashSet.empty[Int]) {case (set, i) => val t = r.nextInt(i) + 1 diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala index 31d230d0fec8e..21acaa95c5645 100644 --- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala @@ -22,9 +22,7 @@ package org.apache.spark.util * through all the elements. */ private[spark] -// scalastyle:off abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterator[A] { -// scalastyle:on private[this] var completed = false def next(): A = sub.next() From 8b56f16640fc4156aa7bd529c54469d27635b951 Mon Sep 17 00:00:00 2001 From: bag_of_tricks Date: Fri, 23 Mar 2018 10:36:23 -0700 Subject: [PATCH 0515/2461] [SPARK-23759][UI] Unable to bind Spark UI to specific host name / IP ## What changes were proposed in this pull request? Fixes SPARK-23759 by moving connector.start() after connector.setHost() Problem was created due connector.setHost(hostName) call was after connector.start() ## How was this patch tested? Patch was tested after build and deployment. This patch requires SPARK_LOCAL_IP environment variable to be set on spark-env.sh Author: bag_of_tricks Closes #20883 from felixalbani/SPARK-23759. --- core/src/main/scala/org/apache/spark/ui/JettyUtils.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 0adeb4058b6e4..0e8a6307de6a8 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -343,12 +343,13 @@ private[spark] object JettyUtils extends Logging { -1, connectionFactories: _*) connector.setPort(port) - connector.start() + connector.setHost(hostName) // Currently we only use "SelectChannelConnector" // Limit the max acceptor number to 8 so that we don't waste a lot of threads connector.setAcceptQueueSize(math.min(connector.getAcceptors, 8)) - connector.setHost(hostName) + + connector.start() // The number of selectors always equals to the number of acceptors minThreads += connector.getAcceptors * 2 From cb43bbe13606673349511829fd71d1f34fc39c45 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 23 Mar 2018 11:42:40 -0700 Subject: [PATCH 0516/2461] [SPARK-21685][PYTHON][ML] PySpark Params isSet state should not change after transform ## What changes were proposed in this pull request? Currently when a PySpark Model is transformed, default params that have not been explicitly set are then set on the Java side on the call to `wrapper._transfer_values_to_java`. This incorrectly changes the state of the Param as it should still be marked as a default value only. ## How was this patch tested? Added a new test to verify that when transferring Params to Java, default params have their state preserved. Author: Bryan Cutler Closes #18982 from BryanCutler/pyspark-ml-param-to-java-defaults-SPARK-21685. --- python/pyspark/ml/tests.py | 20 +++++++++++++++++++- python/pyspark/ml/wrapper.py | 13 ++++++++++--- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index fd45fd00b270b..080119959a4e8 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -369,7 +369,7 @@ def test_property(self): raise RuntimeError("Test property to raise error when invoked") -class ParamTests(PySparkTestCase): +class ParamTests(SparkSessionTestCase): def test_copy_new_parent(self): testParams = TestParams() @@ -514,6 +514,24 @@ def test_logistic_regression_check_thresholds(self): LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5] ) + def test_preserve_set_state(self): + dataset = self.spark.createDataFrame([(0.5,)], ["data"]) + binarizer = Binarizer(inputCol="data") + self.assertFalse(binarizer.isSet("threshold")) + binarizer.transform(dataset) + binarizer._transfer_params_from_java() + self.assertFalse(binarizer.isSet("threshold"), + "Params not explicitly set should remain unset after transform") + + def test_default_params_transferred(self): + dataset = self.spark.createDataFrame([(0.5,)], ["data"]) + binarizer = Binarizer(inputCol="data") + # intentionally change the pyspark default, but don't set it + binarizer._defaultParamMap[binarizer.outputCol] = "my_default" + result = binarizer.transform(dataset).select("my_default").collect() + self.assertFalse(binarizer.isSet(binarizer.outputCol)) + self.assertEqual(result[0][0], 1.0) + @staticmethod def check_params(test_self, py_stage, check_params_exist=True): """ diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 5061f6434794a..d325633195ddb 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -118,11 +118,18 @@ def _transfer_params_to_java(self): """ Transforms the embedded params to the companion Java object. """ - paramMap = self.extractParamMap() + pair_defaults = [] for param in self.params: - if param in paramMap: - pair = self._make_java_param_pair(param, paramMap[param]) + if self.isSet(param): + pair = self._make_java_param_pair(param, self._paramMap[param]) self._java_obj.set(pair) + if self.hasDefault(param): + pair = self._make_java_param_pair(param, self._defaultParamMap[param]) + pair_defaults.append(pair) + if len(pair_defaults) > 0: + sc = SparkContext._active_spark_context + pair_defaults_seq = sc._jvm.PythonUtils.toSeq(pair_defaults) + self._java_obj.setDefault(pair_defaults_seq) def _transfer_param_map_to_java(self, pyParamMap): """ From 95c03cbd27cea2255d9d748f9a84a0a38e54594d Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 23 Mar 2018 11:56:17 -0700 Subject: [PATCH 0517/2461] [SPARK-23783][SPARK-11239][ML] Add PMML export to Spark ML pipelines ## What changes were proposed in this pull request? Adds PMML export support to Spark ML pipelines in the style of Spark's DataSource API to allow library authors to add their own model export formats. Includes a specific implementation for Spark ML linear regression PMML export. In addition to adding PMML to reach parity with our current MLlib implementation, this approach will allow other libraries & formats (like PFA) to implement and export models with a unified API. ## How was this patch tested? Basic unit test. Author: Holden Karau Author: Holden Karau Closes #19876 from holdenk/SPARK-11171-SPARK-11237-Add-PMML-export-for-ML-KMeans-r2. --- .../org.apache.spark.ml.util.MLFormatRegister | 2 + .../ml/regression/LinearRegression.scala | 70 ++++--- .../org/apache/spark/ml/util/ReadWrite.scala | 173 +++++++++++++++++- .../org.apache.spark.ml.util.MLFormatRegister | 3 + .../ml/regression/LinearRegressionSuite.scala | 27 ++- .../spark/ml/util/PMMLReadWriteTest.scala | 55 ++++++ .../org/apache/spark/ml/util/PMMLUtils.scala | 43 +++++ .../apache/spark/ml/util/ReadWriteSuite.scala | 132 +++++++++++++ 8 files changed, 474 insertions(+), 31 deletions(-) create mode 100644 mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister create mode 100644 mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister new file mode 100644 index 0000000000000..5e5484fd8784d --- /dev/null +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister @@ -0,0 +1,2 @@ +org.apache.spark.ml.regression.InternalLinearRegressionModelWriter +org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 92510154d500e..f67d9d831f327 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging -import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.{PipelineStage, PredictorParams} import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.linalg.BLAS._ @@ -39,10 +39,11 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.VectorImplicits._ +import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel} import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} import org.apache.spark.storage.StorageLevel @@ -643,7 +644,7 @@ class LinearRegressionModel private[ml] ( @Since("1.3.0") val intercept: Double, @Since("2.3.0") val scale: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams with MLWritable { + with LinearRegressionParams with GeneralMLWritable { private[ml] def this(uid: String, coefficients: Vector, intercept: Double) = this(uid, coefficients, intercept, 1.0) @@ -710,7 +711,7 @@ class LinearRegressionModel private[ml] ( } /** - * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance. + * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance. * * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. @@ -718,7 +719,50 @@ class LinearRegressionModel private[ml] ( * This also does not save the [[parent]] currently. */ @Since("1.6.0") - override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this) + override def write: GeneralMLWriter = new GeneralMLWriter(this) +} + +/** A writer for LinearRegression that handles the "internal" (or default) format */ +private class InternalLinearRegressionModelWriter + extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "internal" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + + private case class Data(intercept: Double, coefficients: Vector, scale: Double) + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val instance = stage.asInstanceOf[LinearRegressionModel] + val sc = sparkSession.sparkContext + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: intercept, coefficients, scale + val data = Data(instance.intercept, instance.coefficients, instance.scale) + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } +} + +/** A writer for LinearRegression that handles the "pmml" format */ +private class PMMLLinearRegressionModelWriter + extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "pmml" + + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + + private case class Data(intercept: Double, coefficients: Vector) + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val sc = sparkSession.sparkContext + // Construct the MLLib model which knows how to write to PMML. + val instance = stage.asInstanceOf[LinearRegressionModel] + val oldModel = new OldLinearRegressionModel(instance.coefficients, instance.intercept) + // Save PMML + oldModel.toPMML(sc, path) + } } @Since("1.6.0") @@ -730,22 +774,6 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { @Since("1.6.0") override def load(path: String): LinearRegressionModel = super.load(path) - /** [[MLWriter]] instance for [[LinearRegressionModel]] */ - private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) - extends MLWriter with Logging { - - private case class Data(intercept: Double, coefficients: Vector, scale: Double) - - override protected def saveImpl(path: String): Unit = { - // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) - // Save model data: intercept, coefficients, scale - val data = Data(instance.intercept, instance.coefficients, instance.scale) - val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) - } - } - private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] { /** Checked against metadata when loading model */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index a616907800969..7edcd498678cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -18,9 +18,11 @@ package org.apache.spark.ml.util import java.io.IOException -import java.util.Locale +import java.util.{Locale, ServiceLoader} +import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.util.{Failure, Success, Try} import org.apache.hadoop.fs.Path import org.json4s._ @@ -28,8 +30,8 @@ import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.SparkContext -import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml._ import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel} @@ -86,7 +88,82 @@ private[util] sealed trait BaseReadWrite { } /** - * Abstract class for utility classes that can save ML instances. + * Abstract class to be implemented by objects that provide ML exportability. + * + * A new instance of this class will be instantiated each time a save call is made. + * + * Must have a valid zero argument constructor which will be called to instantiate. + * + * @since 2.4.0 + */ +@InterfaceStability.Unstable +@Since("2.4.0") +trait MLWriterFormat { + /** + * Function to write the provided pipeline stage out. + * + * @param path The path to write the result out to. + * @param session SparkSession associated with the write request. + * @param optionMap User provided options stored as strings. + * @param stage The pipeline stage to be saved. + */ + @Since("2.4.0") + def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String], + stage: PipelineStage): Unit +} + +/** + * ML export formats for should implement this trait so that users can specify a shortname rather + * than the fully qualified class name of the exporter. + * + * A new instance of this class will be instantiated each time a save call is made. + * + * @since 2.4.0 + */ +@InterfaceStability.Unstable +@Since("2.4.0") +trait MLFormatRegister extends MLWriterFormat { + /** + * The string that represents the format that this format provider uses. This is, along with + * stageName, is overridden by children to provide a nice alias for the writer. For example: + * + * {{{ + * override def format(): String = + * "pmml" + * }}} + * Indicates that this format is capable of saving a pmml model. + * + * Must have a valid zero argument constructor which will be called to instantiate. + * + * Format discovery is done using a ServiceLoader so make sure to list your format in + * META-INF/services. + * @since 2.4.0 + */ + @Since("2.4.0") + def format(): String + + /** + * The string that represents the stage type that this writer supports. This is, along with + * format, is overridden by children to provide a nice alias for the writer. For example: + * + * {{{ + * override def stageName(): String = + * "org.apache.spark.ml.regression.LinearRegressionModel" + * }}} + * Indicates that this format is capable of saving Spark's own PMML model. + * + * Format discovery is done using a ServiceLoader so make sure to list your format in + * META-INF/services. + * @since 2.4.0 + */ + @Since("2.4.0") + def stageName(): String + + private[ml] def shortName(): String = s"${format()}+${stageName()}" +} + +/** + * Abstract class for utility classes that can save ML instances in Spark's internal format. */ @Since("1.6.0") abstract class MLWriter extends BaseReadWrite with Logging { @@ -110,6 +187,15 @@ abstract class MLWriter extends BaseReadWrite with Logging { @Since("1.6.0") protected def saveImpl(path: String): Unit + /** + * Overwrites if the output path already exists. + */ + @Since("1.6.0") + def overwrite(): this.type = { + shouldOverwrite = true + this + } + /** * Map to store extra options for this writer. */ @@ -126,15 +212,73 @@ abstract class MLWriter extends BaseReadWrite with Logging { this } + // override for Java compatibility + @Since("1.6.0") + override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) + + // override for Java compatibility + @Since("1.6.0") + override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) +} + +/** + * A ML Writer which delegates based on the requested format. + */ +@InterfaceStability.Unstable +@Since("2.4.0") +class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { + private var source: String = "internal" + /** - * Overwrites if the output path already exists. + * Specifies the format of ML export (e.g. "pmml", "internal", or + * the fully qualified class name for export). */ - @Since("1.6.0") - def overwrite(): this.type = { - shouldOverwrite = true + @Since("2.4.0") + def format(source: String): this.type = { + this.source = source this } + /** + * Dispatches the save to the correct MLFormat. + */ + @Since("2.4.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + @throws[SparkException]("If multiple sources for a given short name format are found.") + override protected def saveImpl(path: String): Unit = { + val loader = Utils.getContextOrSparkClassLoader + val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], loader) + val stageName = stage.getClass.getName + val targetName = s"$source+$stageName" + val formats = serviceLoader.asScala.toList + val shortNames = formats.map(_.shortName()) + val writerCls = formats.filter(_.shortName().equalsIgnoreCase(targetName)) match { + // requested name did not match any given registered alias + case Nil => + Try(loader.loadClass(source)) match { + case Success(writer) => + // Found the ML writer using the fully qualified path + writer + case Failure(error) => + throw new SparkException( + s"Could not load requested format $source for $stageName ($targetName) had $formats" + + s"supporting $shortNames", error) + } + case head :: Nil => + head.getClass + case _ => + // Multiple sources + throw new SparkException( + s"Multiple writers found for $source+$stageName, try using the class name of the writer") + } + if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) { + val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat] + writer.write(path, sparkSession, optionMap, stage) + } else { + throw new SparkException(s"ML source $source is not a valid MLWriterFormat") + } + } + // override for Java compatibility override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) @@ -162,6 +306,19 @@ trait MLWritable { def save(path: String): Unit = write.save(path) } +/** + * Trait for classes that provide `GeneralMLWriter`. + */ +@Since("2.4.0") +@InterfaceStability.Unstable +trait GeneralMLWritable extends MLWritable { + /** + * Returns an `MLWriter` instance for this ML instance. + */ + @Since("2.4.0") + override def write: GeneralMLWriter +} + /** * :: DeveloperApi :: * diff --git a/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister new file mode 100644 index 0000000000000..100ef2545418f --- /dev/null +++ b/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister @@ -0,0 +1,3 @@ +org.apache.spark.ml.util.DuplicateLinearRegressionWriter1 +org.apache.spark.ml.util.DuplicateLinearRegressionWriter2 +org.apache.spark.ml.util.FakeLinearRegressionWriterWithName diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 9b19f63eba1bd..90ceb7dee38f7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -17,18 +17,23 @@ package org.apache.spark.ml.regression +import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.util.Random +import org.dmg.pmml.{OpType, PMML, RegressionModel => PMMLRegressionModel} + import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.ml.util._ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.sql.{DataFrame, Row} -class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { + +class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTest { import testImplicits._ @@ -1052,6 +1057,24 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { LinearRegressionSuite.allParamSettings, checkModelData) } + test("pmml export") { + val lr = new LinearRegression() + val model = lr.fit(datasetWithWeight) + def checkModel(pmml: PMML): Unit = { + val dd = pmml.getDataDictionary + assert(dd.getNumberOfFields === 3) + val fields = dd.getDataFields.asScala + assert(fields(0).getName().toString === "field_0") + assert(fields(0).getOpType() == OpType.CONTINUOUS) + val pmmlRegressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel] + val pmmlPredictors = pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors + val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient()).toList + assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3) + assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3) + } + testPMMLWrite(sc, model, checkModel) + } + test("should support all NumericType labels and weights, and not support other types") { for (solver <- Seq("auto", "l-bfgs", "normal")) { val lr = new LinearRegression().setMaxIter(1).setSolver(solver) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala new file mode 100644 index 0000000000000..d2c4832b12bac --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.io.{File, IOException} + +import org.dmg.pmml.PMML +import org.scalatest.Suite + +import org.apache.spark.SparkContext +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset + +trait PMMLReadWriteTest extends TempDirectory { self: Suite => + /** + * Test PMML export. Requires exported model is small enough to be loaded locally. + * Checks that the model can be exported and the result is valid PMML, but does not check + * the specific contents of the model. + */ + def testPMMLWrite[T <: Params with GeneralMLWritable](sc: SparkContext, instance: T, + checkModelData: PMML => Unit): Unit = { + val uid = instance.uid + val subdirName = Identifiable.randomUID("pmml-") + + val subdir = new File(tempDir, subdirName) + val path = new File(subdir, uid).getPath + + instance.write.format("pmml").save(path) + intercept[IOException] { + instance.write.format("pmml").save(path) + } + instance.write.format("pmml").overwrite().save(path) + val pmmlStr = sc.textFile(path).collect.mkString("\n") + val pmmlModel = PMMLUtils.loadFromString(pmmlStr) + assert(pmmlModel.getHeader().getApplication().getName().startsWith("Apache Spark")) + checkModelData(pmmlModel) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala new file mode 100644 index 0000000000000..dbdc69f95d841 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.util + +import java.io.StringReader +import javax.xml.bind.Unmarshaller +import javax.xml.transform.Source + +import org.dmg.pmml._ +import org.jpmml.model.{ImportFilter, JAXBUtil} +import org.xml.sax.InputSource + +/** + * Testing utils for working with PMML. + * Predictive Model Markup Language (PMML) is an XML-based file format + * developed by the Data Mining Group (www.dmg.org). + */ +private[spark] object PMMLUtils { + /** + * :: Experimental :: + * Load a PMML model from a string. Note: for testing only, PMML model evaluation is supported + * through external spark-packages. + */ + def loadFromString(input: String): PMML = { + val is = new StringReader(input) + val transformed = ImportFilter.apply(new InputSource(is)) + JAXBUtil.unmarshalPMML(transformed) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala new file mode 100644 index 0000000000000..f4c1f0bdb32cd --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import scala.collection.mutable + +import org.apache.spark.SparkException +import org.apache.spark.ml.PipelineStage +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.sql.{DataFrame, SparkSession} + +class FakeLinearRegressionWriter extends MLWriterFormat { + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Fake writer doesn't writestart") + } +} + +class FakeLinearRegressionWriterWithName extends MLFormatRegister { + override def format(): String = "fakeWithName" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Fake writer doesn't writestart") + } +} + + +class DuplicateLinearRegressionWriter1 extends MLFormatRegister { + override def format(): String = "dupe" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Duplicate writer shouldn't have been called") + } +} + +class DuplicateLinearRegressionWriter2 extends MLFormatRegister { + override def format(): String = "dupe" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Duplicate writer shouldn't have been called") + } +} + +class ReadWriteSuite extends MLTest { + + import testImplicits._ + + private val seed: Int = 42 + @transient var dataset: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + dataset = sc.parallelize(LinearDataGenerator.generateLinearInput( + intercept = 0.0, weights = Array(1.0, 2.0), xMean = Array(0.0, 1.0), + xVariance = Array(2.0, 1.0), nPoints = 10, seed, eps = 0.2)).map(_.asML).toDF() + } + + test("unsupported/non existent export formats") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + // Does not exist with a long class name + val thrownDNE = intercept[SparkException] { + model.write.format("com.holdenkarau.boop").save("boop") + } + assert(thrownDNE.getMessage(). + contains("Could not load requested format")) + + // Does not exist with a short name + val thrownDNEShort = intercept[SparkException] { + model.write.format("boop").save("boop") + } + assert(thrownDNEShort.getMessage(). + contains("Could not load requested format")) + + // Check with a valid class that is not a writer format. + val thrownInvalid = intercept[SparkException] { + model.write.format("org.apache.spark.SparkContext").save("boop2") + } + assert(thrownInvalid.getMessage() + .contains("ML source org.apache.spark.SparkContext is not a valid MLWriterFormat")) + } + + test("invalid paths fail") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + val thrown = intercept[Exception] { + model.write.format("pmml").save("") + } + assert(thrown.getMessage().contains("Can not create a Path from an empty string")) + } + + test("dummy export format is called") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + val thrown = intercept[Exception] { + model.write.format("org.apache.spark.ml.util.FakeLinearRegressionWriter").save("name") + } + assert(thrown.getMessage().contains("Fake writer doesn't write")) + val thrownWithName = intercept[Exception] { + model.write.format("fakeWithName").save("name") + } + assert(thrownWithName.getMessage().contains("Fake writer doesn't write")) + } + + test("duplicate format raises error") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + val thrown = intercept[Exception] { + model.write.format("dupe").save("dupepanda") + } + assert(thrown.getMessage().contains("Multiple writers found for")) + } +} From a33655348c4066d9c1d8ad2055aadfbc892ba7fd Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 23 Mar 2018 15:58:48 -0700 Subject: [PATCH 0518/2461] [SPARK-23615][ML][PYSPARK] Add maxDF Parameter to Python CountVectorizer ## What changes were proposed in this pull request? The maxDF parameter is for filtering out frequently occurring terms. This param was recently added to the Scala CountVectorizer and needs to be added to Python also. ## How was this patch tested? add test Author: Huaxin Gao Closes #20777 from huaxingao/spark-23615. --- .../spark/ml/feature/CountVectorizer.scala | 20 +++++----- python/pyspark/ml/feature.py | 40 ++++++++++++++----- python/pyspark/ml/tests.py | 25 ++++++++++++ 3 files changed, 67 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 60a4f918790a3..9e0ed437e7bfc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -70,19 +70,21 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit def getMinDF: Double = $(minDF) /** - * Specifies the maximum number of different documents a term must appear in to be included - * in the vocabulary. - * If this is an integer greater than or equal to 1, this specifies the number of documents - * the term must appear in; if this is a double in [0,1), then this specifies the fraction of - * documents. + * Specifies the maximum number of different documents a term could appear in to be included + * in the vocabulary. A term that appears more than the threshold will be ignored. If this is an + * integer greater than or equal to 1, this specifies the maximum number of documents the term + * could appear in; if this is a double in [0,1), then this specifies the maximum fraction of + * documents the term could appear in. * - * Default: (2^64^) - 1 + * Default: (2^63^) - 1 * @group param */ val maxDF: DoubleParam = new DoubleParam(this, "maxDF", "Specifies the maximum number of" + - " different documents a term must appear in to be included in the vocabulary." + - " If this is an integer >= 1, this specifies the number of documents the term must" + - " appear in; if this is a double in [0,1), then this specifies the fraction of documents.", + " different documents a term could appear in to be included in the vocabulary." + + " A term that appears more than the threshold will be ignored. If this is an integer >= 1," + + " this specifies the maximum number of documents the term could appear in;" + + " if this is a double in [0,1), then this specifies the maximum fraction of" + + " documents the term could appear in.", ParamValidators.gtEq(0.0)) /** @group getParam */ diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index a1ceb7f02da8b..fcb0dfc563720 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -422,6 +422,14 @@ class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol): " If this is an integer >= 1, this specifies the number of documents the term must" + " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + " Default 1.0", typeConverter=TypeConverters.toFloat) + maxDF = Param( + Params._dummy(), "maxDF", "Specifies the maximum number of" + + " different documents a term could appear in to be included in the vocabulary." + + " A term that appears more than the threshold will be ignored. If this is an" + + " integer >= 1, this specifies the maximum number of documents the term could appear in;" + + " if this is a double in [0,1), then this specifies the maximum" + + " fraction of documents the term could appear in." + + " Default (2^63) - 1", typeConverter=TypeConverters.toFloat) vocabSize = Param( Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.", typeConverter=TypeConverters.toInt) @@ -433,7 +441,7 @@ class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol): def __init__(self, *args): super(_CountVectorizerParams, self).__init__(*args) - self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) + self._setDefault(minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False) @since("1.6.0") def getMinTF(self): @@ -449,6 +457,13 @@ def getMinDF(self): """ return self.getOrDefault(self.minDF) + @since("2.4.0") + def getMaxDF(self): + """ + Gets the value of maxDF or its default value. + """ + return self.getOrDefault(self.maxDF) + @since("1.6.0") def getVocabSize(self): """ @@ -513,11 +528,11 @@ class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, Jav """ @keyword_only - def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, - outputCol=None): + def __init__(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False, + inputCol=None, outputCol=None): """ - __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\ - outputCol=None) + __init__(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False,\ + inputCol=None,outputCol=None) """ super(CountVectorizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", @@ -527,11 +542,11 @@ def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputC @keyword_only @since("1.6.0") - def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, - outputCol=None): + def setParams(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False, + inputCol=None, outputCol=None): """ - setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\ - outputCol=None) + setParams(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False,\ + inputCol=None, outputCol=None) Set the params for the CountVectorizer """ kwargs = self._input_kwargs @@ -551,6 +566,13 @@ def setMinDF(self, value): """ return self._set(minDF=value) + @since("2.4.0") + def setMaxDF(self, value): + """ + Sets the value of :py:attr:`maxDF`. + """ + return self._set(maxDF=value) + @since("1.6.0") def setVocabSize(self, value): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 080119959a4e8..cf1ffa181ecec 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -697,6 +697,31 @@ def test_count_vectorizer_with_binary(self): feature, expected = r self.assertEqual(feature, expected) + def test_count_vectorizer_with_maxDF(self): + dataset = self.spark.createDataFrame([ + (0, "a b c d".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), + (1, "a b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), + (2, "a b".split(' '), SparseVector(3, {0: 1.0}),), + (3, "a".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) + cv = CountVectorizer(inputCol="words", outputCol="features") + model1 = cv.setMaxDF(3).fit(dataset) + self.assertEqual(model1.vocabulary, ['b', 'c', 'd']) + + transformedList1 = model1.transform(dataset).select("features", "expected").collect() + + for r in transformedList1: + feature, expected = r + self.assertEqual(feature, expected) + + model2 = cv.setMaxDF(0.75).fit(dataset) + self.assertEqual(model2.vocabulary, ['b', 'c', 'd']) + + transformedList2 = model2.transform(dataset).select("features", "expected").collect() + + for r in transformedList2: + feature, expected = r + self.assertEqual(feature, expected) + def test_count_vectorizer_from_vocab(self): model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words", outputCol="features", minTF=2) From 816a5496ba4caac438f70400f72bb10bfcc02418 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Sat, 24 Mar 2018 18:21:01 -0700 Subject: [PATCH 0519/2461] [SPARK-23788][SS] Fix race in StreamingQuerySuite ## What changes were proposed in this pull request? The serializability test uses the same MemoryStream instance for 3 different queries. If any of those queries ask it to commit before the others have run, the rest will see empty dataframes. This can fail the test if q3 is affected. We should use one instance per query instead. ## How was this patch tested? Existing unit test. If I move q2.processAllAvailable() before starting q3, the test always fails without the fix. Author: Jose Torres Closes #20896 from jose-torres/fixrace. --- .../spark/sql/streaming/StreamingQuerySuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index ebc9a87b23f84..08749b49997e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -550,22 +550,22 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi .start() } - val input = MemoryStream[Int] - val q1 = startQuery(input.toDS, "stream_serializable_test_1") - val q2 = startQuery(input.toDS.map { i => + val input = MemoryStream[Int] :: MemoryStream[Int] :: MemoryStream[Int] :: Nil + val q1 = startQuery(input(0).toDS, "stream_serializable_test_1") + val q2 = startQuery(input(1).toDS.map { i => // Emulate that `StreamingQuery` get captured with normal usage unintentionally. // It should not fail the query. q1 i }, "stream_serializable_test_2") - val q3 = startQuery(input.toDS.map { i => + val q3 = startQuery(input(2).toDS.map { i => // Emulate that `StreamingQuery` is used in executors. We should fail the query with a clear // error message. q1.explain() i }, "stream_serializable_test_3") try { - input.addData(1) + input.foreach(_.addData(1)) // q2 should not fail since it doesn't use `q1` in the closure q2.processAllAvailable() From 5f653d4f7c84e6147cd323cd650da65e0381ebe8 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 25 Mar 2018 09:18:26 -0700 Subject: [PATCH 0520/2461] [SPARK-23167][SQL] Add TPCDS queries v2.7 in TPCDSQuerySuite ## What changes were proposed in this pull request? This pr added TPCDS v2.7 (latest) queries in `TPCDSQuerySuite` because the current `TPCDSQuerySuite` tests older one (v1.4) and some queries are different from v1.4 and v2.7. Since the original v2.7 queries have the syntaxes that Spark cannot parse, I changed these queries in a following way: - [date] + 14 days -> date + `INTERVAL` 14 days - [column name] as "30 days" -> [column name] as \`30 days\` - Fix some syntax errors, e.g., missing brackets ## How was this patch tested? Added tests in `TPCDSQuerySuite`. Author: Takeshi Yamamuro Closes #20343 from maropu/TPCDSV2_7. --- .../src/test/resources/tpcds-v2.7.0/q10a.sql | 69 ++++++ .../src/test/resources/tpcds-v2.7.0/q11.sql | 84 +++++++ .../src/test/resources/tpcds-v2.7.0/q12.sql | 23 ++ .../src/test/resources/tpcds-v2.7.0/q14.sql | 135 +++++++++++ .../src/test/resources/tpcds-v2.7.0/q14a.sql | 215 ++++++++++++++++++ .../src/test/resources/tpcds-v2.7.0/q18a.sql | 133 +++++++++++ .../src/test/resources/tpcds-v2.7.0/q20.sql | 19 ++ .../src/test/resources/tpcds-v2.7.0/q22.sql | 15 ++ .../src/test/resources/tpcds-v2.7.0/q22a.sql | 94 ++++++++ .../src/test/resources/tpcds-v2.7.0/q24.sql | 40 ++++ .../src/test/resources/tpcds-v2.7.0/q27a.sql | 70 ++++++ .../src/test/resources/tpcds-v2.7.0/q34.sql | 37 +++ .../src/test/resources/tpcds-v2.7.0/q35.sql | 65 ++++++ .../src/test/resources/tpcds-v2.7.0/q35a.sql | 62 +++++ .../src/test/resources/tpcds-v2.7.0/q36a.sql | 70 ++++++ .../src/test/resources/tpcds-v2.7.0/q47.sql | 64 ++++++ .../src/test/resources/tpcds-v2.7.0/q49.sql | 133 +++++++++++ .../src/test/resources/tpcds-v2.7.0/q51a.sql | 103 +++++++++ .../src/test/resources/tpcds-v2.7.0/q57.sql | 57 +++++ .../src/test/resources/tpcds-v2.7.0/q5a.sql | 158 +++++++++++++ .../src/test/resources/tpcds-v2.7.0/q6.sql | 23 ++ .../src/test/resources/tpcds-v2.7.0/q64.sql | 111 +++++++++ .../src/test/resources/tpcds-v2.7.0/q67a.sql | 208 +++++++++++++++++ .../src/test/resources/tpcds-v2.7.0/q70a.sql | 70 ++++++ .../src/test/resources/tpcds-v2.7.0/q72.sql | 40 ++++ .../src/test/resources/tpcds-v2.7.0/q74.sql | 60 +++++ .../src/test/resources/tpcds-v2.7.0/q75.sql | 78 +++++++ .../src/test/resources/tpcds-v2.7.0/q77a.sql | 121 ++++++++++ .../src/test/resources/tpcds-v2.7.0/q78.sql | 75 ++++++ .../src/test/resources/tpcds-v2.7.0/q80a.sql | 147 ++++++++++++ .../src/test/resources/tpcds-v2.7.0/q86a.sql | 61 +++++ .../src/test/resources/tpcds-v2.7.0/q98.sql | 22 ++ .../apache/spark/sql/TPCDSQuerySuite.scala | 38 +++- 33 files changed, 2691 insertions(+), 9 deletions(-) create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q10a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q11.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q12.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q14.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q14a.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q18a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q20.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q22.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q22a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q24.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q27a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q34.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q35.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q35a.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q36a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q47.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q49.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q51a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q57.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q5a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q6.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q64.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q67a.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q70a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q72.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q74.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q75.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q77a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q78.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q80a.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q86a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q98.sql diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q10a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q10a.sql new file mode 100644 index 0000000000000..50e521567eb3a --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q10a.sql @@ -0,0 +1,69 @@ +-- This is a new query in TPCDS v2.7 +select + cd_gender, + cd_marital_status, + cd_education_status, + count(*) cnt1, + cd_purchase_estimate, + count(*) cnt2, + cd_credit_rating, + count(*) cnt3, + cd_dep_count, + count(*) cnt4, + cd_dep_employed_count, + count(*) cnt5, + cd_dep_college_count, + count(*) cnt6 +from + customer c,customer_address ca,customer_demographics +where + c.c_current_addr_sk = ca.ca_address_sk + and ca_county in ('Walker County', 'Richland County', 'Gaines County', 'Douglas County', 'Dona Ana County') + and cd_demo_sk = c.c_current_cdemo_sk + and exists ( + select * + from store_sales,date_dim + where c.c_customer_sk = ss_customer_sk + and ss_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4 + 3) + and exists ( + select * + from ( + select + ws_bill_customer_sk as customer_sk, + d_year, + d_moy + from web_sales, date_dim + where ws_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4 + 3 + union all + select + cs_ship_customer_sk as customer_sk, + d_year, + d_moy + from catalog_sales, date_dim + where cs_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4 + 3) x + where c.c_customer_sk = customer_sk) +group by + cd_gender, + cd_marital_status, + cd_education_status, + cd_purchase_estimate, + cd_credit_rating, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +order by + cd_gender, + cd_marital_status, + cd_education_status, + cd_purchase_estimate, + cd_credit_rating, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q11.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q11.sql new file mode 100755 index 0000000000000..97bed33721742 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q11.sql @@ -0,0 +1,84 @@ +WITH year_total AS ( + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + c_preferred_cust_flag customer_preferred_cust_flag, + c_birth_country customer_birth_country, + c_login customer_login, + c_email_address customer_email_address, + d_year dyear, + sum(ss_ext_list_price - ss_ext_discount_amt) year_total, + 's' sale_type + FROM customer, store_sales, date_dim + WHERE c_customer_sk = ss_customer_sk + AND ss_sold_date_sk = d_date_sk + GROUP BY c_customer_id + , c_first_name + , c_last_name + , d_year + , c_preferred_cust_flag + , c_birth_country + , c_login + , c_email_address + , d_year + UNION ALL + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + c_preferred_cust_flag customer_preferred_cust_flag, + c_birth_country customer_birth_country, + c_login customer_login, + c_email_address customer_email_address, + d_year dyear, + sum(ws_ext_list_price - ws_ext_discount_amt) year_total, + 'w' sale_type + FROM customer, web_sales, date_dim + WHERE c_customer_sk = ws_bill_customer_sk + AND ws_sold_date_sk = d_date_sk + GROUP BY + c_customer_id, c_first_name, c_last_name, c_preferred_cust_flag, c_birth_country, + c_login, c_email_address, d_year) +SELECT + -- select list of q11 in TPCDS v1.4 is below: + -- t_s_secyear.customer_preferred_cust_flag + t_s_secyear.customer_id, + t_s_secyear.customer_first_name, + t_s_secyear.customer_last_name, + t_s_secyear.customer_email_address +FROM year_total t_s_firstyear + , year_total t_s_secyear + , year_total t_w_firstyear + , year_total t_w_secyear +WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id + AND t_s_firstyear.customer_id = t_w_secyear.customer_id + AND t_s_firstyear.customer_id = t_w_firstyear.customer_id + AND t_s_firstyear.sale_type = 's' + AND t_w_firstyear.sale_type = 'w' + AND t_s_secyear.sale_type = 's' + AND t_w_secyear.sale_type = 'w' + AND t_s_firstyear.dyear = 2001 + AND t_s_secyear.dyear = 2001 + 1 + AND t_w_firstyear.dyear = 2001 + AND t_w_secyear.dyear = 2001 + 1 + AND t_s_firstyear.year_total > 0 + AND t_w_firstyear.year_total > 0 + AND CASE WHEN t_w_firstyear.year_total > 0 + THEN t_w_secyear.year_total / t_w_firstyear.year_total + -- q11 in TPCDS v1.4 used NULL + -- ELSE NULL END + ELSE 0.0 END + > CASE WHEN t_s_firstyear.year_total > 0 + THEN t_s_secyear.year_total / t_s_firstyear.year_total + -- q11 in TPCDS v1.4 used NULL + -- ELSE NULL END + ELSE 0.0 END +ORDER BY + -- order-by list of q11 in TPCDS v1.4 is below: + -- t_s_secyear.customer_preferred_cust_flag + t_s_secyear.customer_id, + t_s_secyear.customer_first_name, + t_s_secyear.customer_last_name, + t_s_secyear.customer_email_address +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q12.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q12.sql new file mode 100755 index 0000000000000..7a6fafd22428a --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q12.sql @@ -0,0 +1,23 @@ +SELECT + i_item_id, -- This column did not exist in TPCDS v1.4 + i_item_desc, + i_category, + i_class, + i_current_price, + sum(ws_ext_sales_price) AS itemrevenue, + sum(ws_ext_sales_price) * 100 / sum(sum(ws_ext_sales_price)) + OVER + (PARTITION BY i_class) AS revenueratio +FROM + web_sales, item, date_dim +WHERE + ws_item_sk = i_item_sk + AND i_category IN ('Sports', 'Books', 'Home') + AND ws_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('1999-02-22' AS DATE) + AND (cast('1999-02-22' AS DATE) + INTERVAL 30 days) +GROUP BY + i_item_id, i_item_desc, i_category, i_class, i_current_price +ORDER BY + i_category, i_class, i_item_id, i_item_desc, revenueratio +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q14.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q14.sql new file mode 100644 index 0000000000000..b2ca3ddaf2baf --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q14.sql @@ -0,0 +1,135 @@ +-- This query is the alternative form of sql/core/src/test/resources/tpcds/q14a.sql +with cross_items as ( + select + i_item_sk ss_item_sk + from item, ( + select + iss.i_brand_id brand_id, + iss.i_class_id class_id, + iss.i_category_id category_id + from + store_sales, item iss, date_dim d1 + where + ss_item_sk = iss.i_item_sk + and ss_sold_date_sk = d1.d_date_sk + and d1.d_year between 1998 AND 1998 + 2 + intersect + select + ics.i_brand_id, + ics.i_class_id, + ics.i_category_id + from + catalog_sales, item ics, date_dim d2 + where + cs_item_sk = ics.i_item_sk + and cs_sold_date_sk = d2.d_date_sk + and d2.d_year between 1998 AND 1998 + 2 + intersect + select + iws.i_brand_id, + iws.i_class_id, + iws.i_category_id + from + web_sales, item iws, date_dim d3 + where + ws_item_sk = iws.i_item_sk + and ws_sold_date_sk = d3.d_date_sk + and d3.d_year between 1998 AND 1998 + 2) x + where + i_brand_id = brand_id + and i_class_id = class_id + and i_category_id = category_id), +avg_sales as ( + select + avg(quantity*list_price) average_sales + from ( + select + ss_quantity quantity, + ss_list_price list_price + from + store_sales, date_dim + where + ss_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2 + union all + select + cs_quantity quantity, + cs_list_price list_price + from + catalog_sales, date_dim + where + cs_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2 + union all + select + ws_quantity quantity, + ws_list_price list_price + from + web_sales, date_dim + where + ws_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2) x) +select + * +from ( + select + 'store' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ss_quantity * ss_list_price) sales, + count(*) number_sales + from + store_sales, item, date_dim + where + ss_item_sk in (select ss_item_sk from cross_items) + and ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and d_week_seq = ( + select d_week_seq + from date_dim + where d_year = 1998 + 1 + and d_moy = 12 + and d_dom = 16) + group by + i_brand_id, + i_class_id, + i_category_id + having + sum(ss_quantity*ss_list_price) > (select average_sales from avg_sales)) this_year, + ( + select + 'store' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ss_quantity * ss_list_price) sales, + count(*) number_sales + from + store_sales, item, date_dim + where + ss_item_sk in (select ss_item_sk from cross_items) + and ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and d_week_seq = ( + select d_week_seq + from date_dim + where d_year = 1998 + and d_moy = 12 + and d_dom = 16) + group by + i_brand_id, + i_class_id, + i_category_id + having + sum(ss_quantity * ss_list_price) > (select average_sales from avg_sales)) last_year +where + this_year.i_brand_id = last_year.i_brand_id + and this_year.i_class_id = last_year.i_class_id + and this_year.i_category_id = last_year.i_category_id +order by + this_year.channel, + this_year.i_brand_id, + this_year.i_class_id, + this_year.i_category_id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q14a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q14a.sql new file mode 100644 index 0000000000000..bfa70fe62d8d5 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q14a.sql @@ -0,0 +1,215 @@ +-- This query is the alternative form of sql/core/src/test/resources/tpcds/q14b.sql +with cross_items as ( + select + i_item_sk ss_item_sk + from item, ( + select + iss.i_brand_id brand_id, + iss.i_class_id class_id, + iss.i_category_id category_id + from + store_sales, item iss, date_dim d1 + where + ss_item_sk = iss.i_item_sk + and ss_sold_date_sk = d1.d_date_sk + and d1.d_year between 1999 AND 1999 + 2 + intersect + select + ics.i_brand_id, + ics.i_class_id, + ics.i_category_id + from + catalog_sales, item ics, date_dim d2 + where + cs_item_sk = ics.i_item_sk + and cs_sold_date_sk = d2.d_date_sk + and d2.d_year between 1999 AND 1999 + 2 + intersect + select + iws.i_brand_id, + iws.i_class_id, + iws.i_category_id + from + web_sales, item iws, date_dim d3 + where + ws_item_sk = iws.i_item_sk + and ws_sold_date_sk = d3.d_date_sk + and d3.d_year between 1999 AND 1999 + 2) x + where + i_brand_id = brand_id + and i_class_id = class_id + and i_category_id = category_id), +avg_sales as ( + select + avg(quantity*list_price) average_sales + from ( + select + ss_quantity quantity, + ss_list_price list_price + from + store_sales, date_dim + where + ss_sold_date_sk = d_date_sk + and d_year between 1999 and 2001 + union all + select + cs_quantity quantity, + cs_list_price list_price + from + catalog_sales, date_dim + where + cs_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2 + union all + select + ws_quantity quantity, + ws_list_price list_price + from + web_sales, date_dim + where + ws_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2) x), +results AS ( + select + channel, + i_brand_id, + i_class_id, + i_category_id, + sum(sales) sum_sales, + sum(number_sales) number_sales + from ( + select + 'store' channel, + i_brand_id,i_class_id, + i_category_id, + sum(ss_quantity*ss_list_price) sales, + count(*) number_sales + from + store_sales, item, date_dim + where + ss_item_sk in (select ss_item_sk from cross_items) + and ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and d_year = 1998 + 2 + and d_moy = 11 + group by + i_brand_id, + i_class_id, + i_category_id + having + sum(ss_quantity * ss_list_price) > (select average_sales from avg_sales) + union all + select + 'catalog' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(cs_quantity*cs_list_price) sales, + count(*) number_sales + from + catalog_sales, item, date_dim + where + cs_item_sk in (select ss_item_sk from cross_items) + and cs_item_sk = i_item_sk + and cs_sold_date_sk = d_date_sk + and d_year = 1998+2 + and d_moy = 11 + group by + i_brand_id,i_class_id,i_category_id + having + sum(cs_quantity*cs_list_price) > (select average_sales from avg_sales) + union all + select + 'web' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ws_quantity*ws_list_price) sales, + count(*) number_sales + from + web_sales, item, date_dim + where + ws_item_sk in (select ss_item_sk from cross_items) + and ws_item_sk = i_item_sk + and ws_sold_date_sk = d_date_sk + and d_year = 1998 + 2 + and d_moy = 11 + group by + i_brand_id, + i_class_id, + i_category_id + having + sum(ws_quantity*ws_list_price) > (select average_sales from avg_sales)) y + group by + channel, + i_brand_id, + i_class_id, + i_category_id) +select + channel, + i_brand_id, + i_class_id, + i_category_id, + sum_sales, + number_sales +from ( + select + channel, + i_brand_id, + i_class_id, + i_category_id, + sum_sales, + number_sales + from + results + union + select + channel, + i_brand_id, + i_class_id, + null as i_category_id, + sum(sum_sales), + sum(number_sales) + from results + group by + channel, + i_brand_id, + i_class_id + union + select + channel, + i_brand_id, + null as i_class_id, + null as i_category_id, + sum(sum_sales), + sum(number_sales) + from results + group by + channel, + i_brand_id + union + select + channel, + null as i_brand_id, + null as i_class_id, + null as i_category_id, + sum(sum_sales), + sum(number_sales) + from results + group by + channel + union + select + null as channel, + null as i_brand_id, + null as i_class_id, + null as i_category_id, + sum(sum_sales), + sum(number_sales) + from results) z +order by + channel, + i_brand_id, + i_class_id, + i_category_id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q18a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q18a.sql new file mode 100644 index 0000000000000..2201a302ab352 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q18a.sql @@ -0,0 +1,133 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + i_item_id, + ca_country, + ca_state, + ca_county, + cast(cs_quantity as decimal(12,2)) agg1, + cast(cs_list_price as decimal(12,2)) agg2, + cast(cs_coupon_amt as decimal(12,2)) agg3, + cast(cs_sales_price as decimal(12,2)) agg4, + cast(cs_net_profit as decimal(12,2)) agg5, + cast(c_birth_year as decimal(12,2)) agg6, + cast(cd1.cd_dep_count as decimal(12,2)) agg7 + from + catalog_sales, customer_demographics cd1, customer_demographics cd2, customer, + customer_address, date_dim, item + where + cs_sold_date_sk = d_date_sk + and cs_item_sk = i_item_sk + and cs_bill_cdemo_sk = cd1.cd_demo_sk + and cs_bill_customer_sk = c_customer_sk + and cd1.cd_gender = 'M' + and cd1.cd_education_status = 'College' + and c_current_cdemo_sk = cd2.cd_demo_sk + and c_current_addr_sk = ca_address_sk + and c_birth_month in (9,5,12,4,1,10) + and d_year = 2001 + and ca_state in ('ND','WI','AL','NC','OK','MS','TN')) +select + i_item_id, + ca_country, + ca_state, + ca_county, + agg1, + agg2, + agg3, + agg4, + agg5, + agg6, + agg7 +from ( + select + i_item_id, + ca_country, + ca_state, + ca_county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from + results + group by + i_item_id, + ca_country, + ca_state, + ca_county + union all + select + i_item_id, + ca_country, + ca_state, + NULL as county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from + results + group by + i_item_id, + ca_country, + ca_state + union all + select + i_item_id, + ca_country, + NULL as ca_state, + NULL as county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from results + group by + i_item_id, + ca_country + union all + select + i_item_id, + NULL as ca_country, + NULL as ca_state, + NULL as county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from results + group by + i_item_id + union all + select + NULL AS i_item_id, + NULL as ca_country, + NULL as ca_state, + NULL as county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from results) foo +order by + ca_country, + ca_state, + ca_county, + i_item_id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q20.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q20.sql new file mode 100755 index 0000000000000..34d46b1394d8f --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q20.sql @@ -0,0 +1,19 @@ +SELECT + i_item_id, -- This column did not exist in TPCDS v1.4 + i_item_desc, + i_category, + i_class, + i_current_price, + sum(cs_ext_sales_price) AS itemrevenue, + sum(cs_ext_sales_price) * 100 / sum(sum(cs_ext_sales_price)) + OVER + (PARTITION BY i_class) AS revenueratio +FROM catalog_sales, item, date_dim +WHERE cs_item_sk = i_item_sk + AND i_category IN ('Sports', 'Books', 'Home') + AND cs_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('1999-02-22' AS DATE) +AND (cast('1999-02-22' AS DATE) + INTERVAL 30 days) +GROUP BY i_item_id, i_item_desc, i_category, i_class, i_current_price +ORDER BY i_category, i_class, i_item_id, i_item_desc, revenueratio +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q22.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q22.sql new file mode 100755 index 0000000000000..e7bea0804f162 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q22.sql @@ -0,0 +1,15 @@ +SELECT + i_product_name, + i_brand, + i_class, + i_category, + avg(inv_quantity_on_hand) qoh +FROM inventory, date_dim, item, warehouse +WHERE inv_date_sk = d_date_sk + AND inv_item_sk = i_item_sk + -- q22 in TPCDS v1.4 had a condition below: + -- AND inv_warehouse_sk = w_warehouse_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 +GROUP BY ROLLUP (i_product_name, i_brand, i_class, i_category) +ORDER BY qoh, i_product_name, i_brand, i_class, i_category +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q22a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q22a.sql new file mode 100644 index 0000000000000..c886e6271511b --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q22a.sql @@ -0,0 +1,94 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + i_product_name, + i_brand, + i_class, + i_category, + avg(inv_quantity_on_hand) qoh + from + inventory, date_dim, item, warehouse + where + inv_date_sk = d_date_sk + and inv_item_sk = i_item_sk + and inv_warehouse_sk = w_warehouse_sk + and d_month_seq between 1212 and 1212 + 11 + group by + i_product_name, + i_brand, + i_class, + i_category), +results_rollup as ( + select + i_product_name, + i_brand, + i_class, + i_category, + avg(qoh) qoh + from + results + group by + i_product_name, + i_brand, + i_class, + i_category + union all + select + i_product_name, + i_brand, + i_class, + null i_category, + avg(qoh) qoh + from + results + group by + i_product_name, + i_brand, + i_class + union all + select + i_product_name, + i_brand, + null i_class, + null i_category, + avg(qoh) qoh + from + results + group by + i_product_name, + i_brand + union all + select + i_product_name, + null i_brand, + null i_class, + null i_category, + avg(qoh) qoh + from + results + group by + i_product_name + union all + select + null i_product_name, + null i_brand, + null i_class, + null i_category, + avg(qoh) qoh + from + results) +select + i_product_name, + i_brand, + i_class, + i_category, + qoh +from + results_rollup +order by + qoh, + i_product_name, + i_brand, + i_class, + i_category +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q24.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q24.sql new file mode 100755 index 0000000000000..92d64bc7eba78 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q24.sql @@ -0,0 +1,40 @@ +WITH ssales AS +(SELECT + c_last_name, + c_first_name, + s_store_name, + ca_state, + s_state, + i_color, + i_current_price, + i_manager_id, + i_units, + i_size, + sum(ss_net_paid) netpaid + FROM store_sales, store_returns, store, item, customer, customer_address + WHERE ss_ticket_number = sr_ticket_number + AND ss_item_sk = sr_item_sk + AND ss_customer_sk = c_customer_sk + AND ss_item_sk = i_item_sk + AND ss_store_sk = s_store_sk + AND c_current_addr_sk = ca_address_sk -- This condition did not exist in TPCDS v1.4 + AND c_birth_country = upper(ca_country) + AND s_zip = ca_zip + AND s_market_id = 8 + GROUP BY c_last_name, c_first_name, s_store_name, ca_state, s_state, i_color, + i_current_price, i_manager_id, i_units, i_size) +SELECT + c_last_name, + c_first_name, + s_store_name, + sum(netpaid) paid +FROM ssales +WHERE i_color = 'pale' +GROUP BY c_last_name, c_first_name, s_store_name +HAVING sum(netpaid) > (SELECT 0.05 * avg(netpaid) +FROM ssales) +-- no order-by exists in q24a of TPCDS v1.4 +ORDER BY + c_last_name, + c_first_name, + s_store_name diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q27a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q27a.sql new file mode 100644 index 0000000000000..c70a2420e8387 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q27a.sql @@ -0,0 +1,70 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + i_item_id, + s_state, 0 as g_state, + ss_quantity agg1, + ss_list_price agg2, + ss_coupon_amt agg3, + ss_sales_price agg4 + from + store_sales, customer_demographics, date_dim, store, item + where + ss_sold_date_sk = d_date_sk + and ss_item_sk = i_item_sk + and ss_store_sk = s_store_sk + and ss_cdemo_sk = cd_demo_sk + and cd_gender = 'F' + and cd_marital_status = 'W' + and cd_education_status = 'Primary' + and d_year = 1998 + and s_state in ('TN','TN', 'TN', 'TN', 'TN', 'TN')) +select + i_item_id, + s_state, + g_state, + agg1, + agg2, + agg3, + agg4 +from ( + select + i_item_id, + s_state, + 0 as g_state, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4 + from + results + group by + i_item_id, + s_state + union all + select + i_item_id, + NULL AS s_state, + 1 AS g_state, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4 + from results + group by + i_item_id + union all + select + NULL AS i_item_id, + NULL as s_state, + 1 as g_state, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4 + from + results) foo +order by + i_item_id, + s_state +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q34.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q34.sql new file mode 100755 index 0000000000000..bbede62acc9a7 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q34.sql @@ -0,0 +1,37 @@ +SELECT + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag, + ss_ticket_number, + cnt +FROM + (SELECT + ss_ticket_number, + ss_customer_sk, + count(*) cnt + FROM store_sales, date_dim, store, household_demographics + WHERE store_sales.ss_sold_date_sk = date_dim.d_date_sk + AND store_sales.ss_store_sk = store.s_store_sk + AND store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + AND (date_dim.d_dom BETWEEN 1 AND 3 OR date_dim.d_dom BETWEEN 25 AND 28) + AND (household_demographics.hd_buy_potential = '>10000' OR + household_demographics.hd_buy_potential = 'unknown') + AND household_demographics.hd_vehicle_count > 0 + AND (CASE WHEN household_demographics.hd_vehicle_count > 0 + THEN household_demographics.hd_dep_count / household_demographics.hd_vehicle_count + ELSE NULL + END) > 1.2 + AND date_dim.d_year IN (1999, 1999 + 1, 1999 + 2) + AND store.s_county IN + ('Williamson County', 'Williamson County', 'Williamson County', 'Williamson County', + 'Williamson County', 'Williamson County', 'Williamson County', 'Williamson County') + GROUP BY ss_ticket_number, ss_customer_sk) dn, customer +WHERE ss_customer_sk = c_customer_sk + AND cnt BETWEEN 15 AND 20 +ORDER BY + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag DESC, + ss_ticket_number -- This order-by condition did not exist in TPCDS v1.4 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q35.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q35.sql new file mode 100755 index 0000000000000..27116a563d5c6 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q35.sql @@ -0,0 +1,65 @@ +SELECT + -- select list of q35 in TPCDS v1.4 is below: + -- ca_state, + -- cd_gender, + -- cd_marital_status, + -- count(*) cnt1, + -- min(cd_dep_count), + -- max(cd_dep_count), + -- avg(cd_dep_count), + -- cd_dep_employed_count, + -- count(*) cnt2, + -- min(cd_dep_employed_count), + -- max(cd_dep_employed_count), + -- avg(cd_dep_employed_count), + -- cd_dep_college_count, + -- count(*) cnt3, + -- min(cd_dep_college_count), + -- max(cd_dep_college_count), + -- avg(cd_dep_college_count) + ca_state, + cd_gender, + cd_marital_status, + cd_dep_count, + count(*) cnt1, + avg(cd_dep_count), + max(cd_dep_count), + sum(cd_dep_count), + cd_dep_employed_count, + count(*) cnt2, + avg(cd_dep_employed_count), + max(cd_dep_employed_count), + sum(cd_dep_employed_count), + cd_dep_college_count, + count(*) cnt3, + avg(cd_dep_college_count), + max(cd_dep_college_count), + sum(cd_dep_college_count) +FROM + customer c, customer_address ca, customer_demographics +WHERE + c.c_current_addr_sk = ca.ca_address_sk AND + cd_demo_sk = c.c_current_cdemo_sk AND + exists(SELECT * + FROM store_sales, date_dim + WHERE c.c_customer_sk = ss_customer_sk AND + ss_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_qoy < 4) AND + (exists(SELECT * + FROM web_sales, date_dim + WHERE c.c_customer_sk = ws_bill_customer_sk AND + ws_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_qoy < 4) OR + exists(SELECT * + FROM catalog_sales, date_dim + WHERE c.c_customer_sk = cs_ship_customer_sk AND + cs_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_qoy < 4)) +GROUP BY ca_state, cd_gender, cd_marital_status, cd_dep_count, + cd_dep_employed_count, cd_dep_college_count +ORDER BY ca_state, cd_gender, cd_marital_status, cd_dep_count, + cd_dep_employed_count, cd_dep_college_count +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q35a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q35a.sql new file mode 100644 index 0000000000000..1c1463e44777f --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q35a.sql @@ -0,0 +1,62 @@ +-- This is a new query in TPCDS v2.7 +select + ca_state, + cd_gender, + cd_marital_status, + cd_dep_count, + count(*) cnt1, + avg(cd_dep_count), + max(cd_dep_count), + sum(cd_dep_count), + cd_dep_employed_count, + count(*) cnt2, + avg(cd_dep_employed_count), + max(cd_dep_employed_count), + sum(cd_dep_employed_count), + cd_dep_college_count, + count(*) cnt3, + avg(cd_dep_college_count), + max(cd_dep_college_count), + sum(cd_dep_college_count) +from + customer c, customer_address ca, customer_demographics +where + c.c_current_addr_sk = ca.ca_address_sk + and cd_demo_sk = c.c_current_cdemo_sk + and exists ( + select * + from store_sales, date_dim + where c.c_customer_sk = ss_customer_sk + and ss_sold_date_sk = d_date_sk + and d_year = 1999 + and d_qoy < 4) + and exists ( + select * + from ( + select ws_bill_customer_sk customsk + from web_sales, date_dim + where ws_sold_date_sk = d_date_sk + and d_year = 1999 + and d_qoy < 4 + union all + select cs_ship_customer_sk customsk + from catalog_sales, date_dim + where cs_sold_date_sk = d_date_sk + and d_year = 1999 + and d_qoy < 4) x + where x.customsk = c.c_customer_sk) +group by + ca_state, + cd_gender, + cd_marital_status, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +order by + ca_state, + cd_gender, + cd_marital_status, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q36a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q36a.sql new file mode 100644 index 0000000000000..9d98f32add508 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q36a.sql @@ -0,0 +1,70 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + sum(ss_net_profit) as ss_net_profit, + sum(ss_ext_sales_price) as ss_ext_sales_price, + sum(ss_net_profit)/sum(ss_ext_sales_price) as gross_margin, + i_category, + i_class, + 0 as g_category, + 0 as g_class + from + store_sales, + date_dim d1, + item, + store + where + d1.d_year = 2001 + and d1.d_date_sk = ss_sold_date_sk + and i_item_sk = ss_item_sk + and s_store_sk = ss_store_sk + and s_state in ('TN', 'TN', 'TN', 'TN', 'TN', 'TN', 'TN', 'TN') + group by + i_category, + i_class), + results_rollup as ( + select + gross_margin, + i_category, + i_class, + 0 as t_category, + 0 as t_class, + 0 as lochierarchy + from + results + union + select + sum(ss_net_profit) / sum(ss_ext_sales_price) as gross_margin, + i_category, NULL AS i_class, + 0 as t_category, + 1 as t_class, + 1 as lochierarchy + from + results + group by + i_category + union + select + sum(ss_net_profit) / sum(ss_ext_sales_price) as gross_margin, + NULL AS i_category, + NULL AS i_class, + 1 as t_category, + 1 as t_class, + 2 as lochierarchy + from + results) +select + gross_margin, + i_category, + i_class, + lochierarchy, + rank() over ( + partition by lochierarchy, case when t_class = 0 then i_category end + order by gross_margin asc) as rank_within_parent +from + results_rollup +order by + lochierarchy desc, + case when lochierarchy = 0 then i_category end, + rank_within_parent +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q47.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q47.sql new file mode 100755 index 0000000000000..9f7ee457ea45f --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q47.sql @@ -0,0 +1,64 @@ +WITH v1 AS ( + SELECT + i_category, + i_brand, + s_store_name, + s_company_name, + d_year, + d_moy, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) + OVER + (PARTITION BY i_category, i_brand, + s_store_name, s_company_name, d_year) + avg_monthly_sales, + rank() + OVER + (PARTITION BY i_category, i_brand, + s_store_name, s_company_name + ORDER BY d_year, d_moy) rn + FROM item, store_sales, date_dim, store + WHERE ss_item_sk = i_item_sk AND + ss_sold_date_sk = d_date_sk AND + ss_store_sk = s_store_sk AND + ( + d_year = 1999 OR + (d_year = 1999 - 1 AND d_moy = 12) OR + (d_year = 1999 + 1 AND d_moy = 1) + ) + GROUP BY i_category, i_brand, + s_store_name, s_company_name, + d_year, d_moy), + v2 AS ( + SELECT + v1.i_category, + -- q47 in TPCDS v1.4 had more columns below: + -- v1.i_brand, + -- v1.s_store_name, + -- v1.s_company_name, + v1.d_year, + v1.d_moy, + v1.avg_monthly_sales, + v1.sum_sales, + v1_lag.sum_sales psum, + v1_lead.sum_sales nsum + FROM v1, v1 v1_lag, v1 v1_lead + WHERE v1.i_category = v1_lag.i_category AND + v1.i_category = v1_lead.i_category AND + v1.i_brand = v1_lag.i_brand AND + v1.i_brand = v1_lead.i_brand AND + v1.s_store_name = v1_lag.s_store_name AND + v1.s_store_name = v1_lead.s_store_name AND + v1.s_company_name = v1_lag.s_company_name AND + v1.s_company_name = v1_lead.s_company_name AND + v1.rn = v1_lag.rn + 1 AND + v1.rn = v1_lead.rn - 1) +SELECT * +FROM v2 +WHERE d_year = 1999 AND + avg_monthly_sales > 0 AND + CASE WHEN avg_monthly_sales > 0 + THEN abs(sum_sales - avg_monthly_sales) / avg_monthly_sales + ELSE NULL END > 0.1 +ORDER BY sum_sales - avg_monthly_sales, 3 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q49.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q49.sql new file mode 100755 index 0000000000000..e8061bde4159e --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q49.sql @@ -0,0 +1,133 @@ +-- The first SELECT query below is different from q49 of TPCDS v1.4 +SELECT + channel, + item, + return_ratio, + return_rank, + currency_rank +FROM ( + SELECT + 'web' as channel, + in_web.item, + in_web.return_ratio, + in_web.return_rank, + in_web.currency_rank + FROM + (SELECT + item, + return_ratio, + currency_ratio, + rank() over (ORDER BY return_ratio) AS return_rank, + rank() over (ORDER BY currency_ratio) AS currency_rank + FROM ( + SELECT + ws.ws_item_sk AS item, + CAST(SUM(COALESCE(wr.wr_return_quantity, 0)) AS DECIMAL(15, 4)) / + CAST(SUM(COALESCE(ws.ws_quantity, 0)) AS DECIMAL(15, 4)) AS return_ratio, + CAST(SUM(COALESCE(wr.wr_return_amt, 0)) AS DECIMAL(15, 4)) / + CAST(SUM(COALESCE(ws.ws_net_paid, 0)) AS DECIMAL(15, 4)) AS currency_ratio + FROM + web_sales ws LEFT OUTER JOIN web_returns wr + ON (ws.ws_order_number = wr.wr_order_number AND ws.ws_item_sk = wr.wr_item_sk), + date_dim + WHERE + wr.wr_return_amt > 10000 + AND ws.ws_net_profit > 1 + AND ws.ws_net_paid > 0 + AND ws.ws_quantity > 0 + AND ws_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 12 + GROUP BY + ws.ws_item_sk) + ) in_web + ) web +WHERE (web.return_rank <= 10 OR web.currency_rank <= 10) +UNION +SELECT + 'catalog' AS channel, + catalog.item, + catalog.return_ratio, + catalog.return_rank, + catalog.currency_rank +FROM ( + SELECT + item, + return_ratio, + currency_ratio, + rank() + OVER ( + ORDER BY return_ratio) AS return_rank, + rank() + OVER ( + ORDER BY currency_ratio) AS currency_rank + FROM + (SELECT + cs.cs_item_sk AS item, + (cast(sum(coalesce(cr.cr_return_quantity, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(cs.cs_quantity, 0)) AS DECIMAL(15, 4))) AS return_ratio, + (cast(sum(coalesce(cr.cr_return_amount, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(cs.cs_net_paid, 0)) AS DECIMAL(15, 4))) AS currency_ratio + FROM + catalog_sales cs LEFT OUTER JOIN catalog_returns cr + ON (cs.cs_order_number = cr.cr_order_number AND + cs.cs_item_sk = cr.cr_item_sk) + , date_dim + WHERE + cr.cr_return_amount > 10000 + AND cs.cs_net_profit > 1 + AND cs.cs_net_paid > 0 + AND cs.cs_quantity > 0 + AND cs_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 12 + GROUP BY cs.cs_item_sk + ) in_cat + ) catalog +WHERE (catalog.return_rank <= 10 OR catalog.currency_rank <= 10) +UNION +SELECT + 'store' AS channel, + store.item, + store.return_ratio, + store.return_rank, + store.currency_rank +FROM ( + SELECT + item, + return_ratio, + currency_ratio, + rank() + OVER ( + ORDER BY return_ratio) AS return_rank, + rank() + OVER ( + ORDER BY currency_ratio) AS currency_rank + FROM + (SELECT + sts.ss_item_sk AS item, + (cast(sum(coalesce(sr.sr_return_quantity, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(sts.ss_quantity, 0)) AS DECIMAL(15, 4))) AS return_ratio, + (cast(sum(coalesce(sr.sr_return_amt, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(sts.ss_net_paid, 0)) AS DECIMAL(15, 4))) AS currency_ratio + FROM + store_sales sts LEFT OUTER JOIN store_returns sr + ON (sts.ss_ticket_number = sr.sr_ticket_number AND sts.ss_item_sk = sr.sr_item_sk) + , date_dim + WHERE + sr.sr_return_amt > 10000 + AND sts.ss_net_profit > 1 + AND sts.ss_net_paid > 0 + AND sts.ss_quantity > 0 + AND ss_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 12 + GROUP BY sts.ss_item_sk + ) in_store + ) store +WHERE (store.return_rank <= 10 OR store.currency_rank <= 10) +ORDER BY + -- order-by list of q49 in TPCDS v1.4 is below: + -- 1, 4, 5 + 1, 4, 5, 2 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q51a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q51a.sql new file mode 100644 index 0000000000000..b8cbbbc8ef7d5 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q51a.sql @@ -0,0 +1,103 @@ +-- This is a new query in TPCDS v2.7 +WITH web_tv as ( + select + ws_item_sk item_sk, + d_date, + sum(ws_sales_price) sumws, + row_number() over (partition by ws_item_sk order by d_date) rk + from + web_sales, date_dim + where + ws_sold_date_sk=d_date_sk + and d_month_seq between 1212 and 1212 + 11 + and ws_item_sk is not NULL + group by + ws_item_sk, d_date), +web_v1 as ( + select + v1.item_sk, + v1.d_date, + v1.sumws, + sum(v2.sumws) cume_sales + from + web_tv v1, web_tv v2 + where + v1.item_sk = v2.item_sk + and v1.rk >= v2.rk + group by + v1.item_sk, + v1.d_date, + v1.sumws), +store_tv as ( + select + ss_item_sk item_sk, + d_date, + sum(ss_sales_price) sumss, + row_number() over (partition by ss_item_sk order by d_date) rk + from + store_sales, date_dim + where + ss_sold_date_sk = d_date_sk + and d_month_seq between 1212 and 1212 + 11 + and ss_item_sk is not NULL + group by ss_item_sk, d_date), +store_v1 as ( + select + v1.item_sk, + v1.d_date, + v1.sumss, + sum(v2.sumss) cume_sales + from + store_tv v1, store_tv v2 + where + v1.item_sk = v2.item_sk + and v1.rk >= v2.rk + group by + v1.item_sk, + v1.d_date, + v1.sumss), +v as ( + select + item_sk, + d_date, + web_sales, + store_sales, + row_number() over (partition by item_sk order by d_date) rk + from ( + select + case when web.item_sk is not null + then web.item_sk + else store.item_sk end item_sk, + case when web.d_date is not null + then web.d_date + else store.d_date end d_date, + web.cume_sales web_sales, + store.cume_sales store_sales + from + web_v1 web full outer join store_v1 store + on (web.item_sk = store.item_sk and web.d_date = store.d_date))) +select * +from ( + select + v1.item_sk, + v1.d_date, + v1.web_sales, + v1.store_sales, + max(v2.web_sales) web_cumulative, + max(v2.store_sales) store_cumulative + from + v v1, v v2 + where + v1.item_sk = v2.item_sk + and v1.rk >= v2.rk + group by + v1.item_sk, + v1.d_date, + v1.web_sales, + v1.store_sales) x +where + web_cumulative > store_cumulative +order by + item_sk, + d_date +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q57.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q57.sql new file mode 100755 index 0000000000000..ccefaac3c12ca --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q57.sql @@ -0,0 +1,57 @@ +WITH v1 AS ( + SELECT + i_category, + i_brand, + cc_name, + d_year, + d_moy, + sum(cs_sales_price) sum_sales, + avg(sum(cs_sales_price)) + OVER + (PARTITION BY i_category, i_brand, cc_name, d_year) + avg_monthly_sales, + rank() + OVER + (PARTITION BY i_category, i_brand, cc_name + ORDER BY d_year, d_moy) rn + FROM item, catalog_sales, date_dim, call_center + WHERE cs_item_sk = i_item_sk AND + cs_sold_date_sk = d_date_sk AND + cc_call_center_sk = cs_call_center_sk AND + ( + d_year = 1999 OR + (d_year = 1999 - 1 AND d_moy = 12) OR + (d_year = 1999 + 1 AND d_moy = 1) + ) + GROUP BY i_category, i_brand, + cc_name, d_year, d_moy), + v2 AS ( + SELECT + v1.i_category, + v1.i_brand, + -- q57 in TPCDS v1.4 had a column below: + -- v1.cc_name, + v1.d_year, + v1.d_moy, + v1.avg_monthly_sales, + v1.sum_sales, + v1_lag.sum_sales psum, + v1_lead.sum_sales nsum + FROM v1, v1 v1_lag, v1 v1_lead + WHERE v1.i_category = v1_lag.i_category AND + v1.i_category = v1_lead.i_category AND + v1.i_brand = v1_lag.i_brand AND + v1.i_brand = v1_lead.i_brand AND + v1.cc_name = v1_lag.cc_name AND + v1.cc_name = v1_lead.cc_name AND + v1.rn = v1_lag.rn + 1 AND + v1.rn = v1_lead.rn - 1) +SELECT * +FROM v2 +WHERE d_year = 1999 AND + avg_monthly_sales > 0 AND + CASE WHEN avg_monthly_sales > 0 + THEN abs(sum_sales - avg_monthly_sales) / avg_monthly_sales + ELSE NULL END > 0.1 +ORDER BY sum_sales - avg_monthly_sales, 3 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q5a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q5a.sql new file mode 100644 index 0000000000000..42bcf59c2aeb1 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q5a.sql @@ -0,0 +1,158 @@ +-- This is a new query in TPCDS v2.7 +with ssr as( + select + s_store_id, + sum(sales_price) as sales, + sum(profit) as profit, + sum(return_amt) as returns, + sum(net_loss) as profit_loss + from ( + select + ss_store_sk as store_sk, + ss_sold_date_sk as date_sk, + ss_ext_sales_price as sales_price, + ss_net_profit as profit, + cast(0 as decimal(7,2)) as return_amt, + cast(0 as decimal(7,2)) as net_loss + from + store_sales + union all + select + sr_store_sk as store_sk, + sr_returned_date_sk as date_sk, + cast(0 as decimal(7,2)) as sales_price, + cast(0 as decimal(7,2)) as profit, + sr_return_amt as return_amt, + sr_net_loss as net_loss + from + store_returns) salesreturns, + date_dim, + store + where + date_sk = d_date_sk and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + INTERVAL 14 days) + and store_sk = s_store_sk + group by + s_store_id), +csr as ( + select + cp_catalog_page_id, + sum(sales_price) as sales, + sum(profit) as profit, + sum(return_amt) as returns, + sum(net_loss) as profit_loss + from ( + select + cs_catalog_page_sk as page_sk, + cs_sold_date_sk as date_sk, + cs_ext_sales_price as sales_price, + cs_net_profit as profit, + cast(0 as decimal(7,2)) as return_amt, + cast(0 as decimal(7,2)) as net_loss + from catalog_sales + union all + select + cr_catalog_page_sk as page_sk, + cr_returned_date_sk as date_sk, + cast(0 as decimal(7,2)) as sales_price, + cast(0 as decimal(7,2)) as profit, + cr_return_amount as return_amt, + cr_net_loss as net_loss + from catalog_returns) salesreturns, + date_dim, + catalog_page + where + date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + INTERVAL 14 days) + and page_sk = cp_catalog_page_sk + group by + cp_catalog_page_id), +wsr as ( + select + web_site_id, + sum(sales_price) as sales, + sum(profit) as profit, + sum(return_amt) as returns, + sum(net_loss) as profit_loss + from ( + select + ws_web_site_sk as wsr_web_site_sk, + ws_sold_date_sk as date_sk, + ws_ext_sales_price as sales_price, + ws_net_profit as profit, + cast(0 as decimal(7,2)) as return_amt, + cast(0 as decimal(7,2)) as net_loss + from + web_sales + union all + select + ws_web_site_sk as wsr_web_site_sk, + wr_returned_date_sk as date_sk, + cast(0 as decimal(7,2)) as sales_price, + cast(0 as decimal(7,2)) as profit, + wr_return_amt as return_amt, + wr_net_loss as net_loss + from + web_returns + left outer join web_sales on ( + wr_item_sk = ws_item_sk and wr_order_number = ws_order_number) + ) salesreturns, + date_dim, + web_site + where + date_sk = d_date_sk and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + INTERVAL 14 days) + and wsr_web_site_sk = web_site_sk + group by + web_site_id), +results as ( + select + channel, + id, + sum(sales) as sales, + sum(returns) as returns, + sum(profit) as profit + from ( + select + 'store channel' as channel, + 'store' || s_store_id as id, + sales, + returns, + (profit - profit_loss) as profit + from + ssr + union all + select + 'catalog channel' as channel, + 'catalog_page' || cp_catalog_page_id as id, + sales, + returns, + (profit - profit_loss) as profit + from + csr + union all + select + 'web channel' as channel, + 'web_site' || web_site_id as id, + sales, + returns, + (profit - profit_loss) as profit + from + wsr) x + group by + channel, id) +select + channel, id, sales, returns, profit +from ( + select channel, id, sales, returns, profit + from results + union + select channel, null as id, sum(sales), sum(returns), sum(profit) + from results + group by channel + union + select null as channel, null as id, sum(sales), sum(returns), sum(profit) + from results) foo + order by channel, id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q6.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q6.sql new file mode 100755 index 0000000000000..c0bfa40ad44a8 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q6.sql @@ -0,0 +1,23 @@ +SELECT + a.ca_state state, + count(*) cnt +FROM + customer_address a, customer c, store_sales s, date_dim d, item i +WHERE a.ca_address_sk = c.c_current_addr_sk + AND c.c_customer_sk = s.ss_customer_sk + AND s.ss_sold_date_sk = d.d_date_sk + AND s.ss_item_sk = i.i_item_sk + AND d.d_month_seq = + (SELECT DISTINCT (d_month_seq) + FROM date_dim + WHERE d_year = 2000 AND d_moy = 1) + AND i.i_current_price > 1.2 * + (SELECT avg(j.i_current_price) + FROM item j + WHERE j.i_category = i.i_category) +GROUP BY a.ca_state +HAVING count(*) >= 10 +-- order-by list of q6 in TPCDS v1.4 is below: +-- order by cnt +order by cnt, a.ca_state +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q64.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q64.sql new file mode 100755 index 0000000000000..cdcd8486b363d --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q64.sql @@ -0,0 +1,111 @@ +WITH cs_ui AS +(SELECT + cs_item_sk, + sum(cs_ext_list_price) AS sale, + sum(cr_refunded_cash + cr_reversed_charge + cr_store_credit) AS refund + FROM catalog_sales + , catalog_returns + WHERE cs_item_sk = cr_item_sk + AND cs_order_number = cr_order_number + GROUP BY cs_item_sk + HAVING sum(cs_ext_list_price) > 2 * sum(cr_refunded_cash + cr_reversed_charge + cr_store_credit)), + cross_sales AS + (SELECT + i_product_name product_name, + i_item_sk item_sk, + s_store_name store_name, + s_zip store_zip, + ad1.ca_street_number b_street_number, + ad1.ca_street_name b_streen_name, + ad1.ca_city b_city, + ad1.ca_zip b_zip, + ad2.ca_street_number c_street_number, + ad2.ca_street_name c_street_name, + ad2.ca_city c_city, + ad2.ca_zip c_zip, + d1.d_year AS syear, + d2.d_year AS fsyear, + d3.d_year s2year, + count(*) cnt, + sum(ss_wholesale_cost) s1, + sum(ss_list_price) s2, + sum(ss_coupon_amt) s3 + FROM store_sales, store_returns, cs_ui, date_dim d1, date_dim d2, date_dim d3, + store, customer, customer_demographics cd1, customer_demographics cd2, + promotion, household_demographics hd1, household_demographics hd2, + customer_address ad1, customer_address ad2, income_band ib1, income_band ib2, item + WHERE ss_store_sk = s_store_sk AND + ss_sold_date_sk = d1.d_date_sk AND + ss_customer_sk = c_customer_sk AND + ss_cdemo_sk = cd1.cd_demo_sk AND + ss_hdemo_sk = hd1.hd_demo_sk AND + ss_addr_sk = ad1.ca_address_sk AND + ss_item_sk = i_item_sk AND + ss_item_sk = sr_item_sk AND + ss_ticket_number = sr_ticket_number AND + ss_item_sk = cs_ui.cs_item_sk AND + c_current_cdemo_sk = cd2.cd_demo_sk AND + c_current_hdemo_sk = hd2.hd_demo_sk AND + c_current_addr_sk = ad2.ca_address_sk AND + c_first_sales_date_sk = d2.d_date_sk AND + c_first_shipto_date_sk = d3.d_date_sk AND + ss_promo_sk = p_promo_sk AND + hd1.hd_income_band_sk = ib1.ib_income_band_sk AND + hd2.hd_income_band_sk = ib2.ib_income_band_sk AND + cd1.cd_marital_status <> cd2.cd_marital_status AND + i_color IN ('purple', 'burlywood', 'indian', 'spring', 'floral', 'medium') AND + i_current_price BETWEEN 64 AND 64 + 10 AND + i_current_price BETWEEN 64 + 1 AND 64 + 15 + GROUP BY + i_product_name, + i_item_sk, + s_store_name, + s_zip, + ad1.ca_street_number, + ad1.ca_street_name, + ad1.ca_city, + ad1.ca_zip, + ad2.ca_street_number, + ad2.ca_street_name, + ad2.ca_city, + ad2.ca_zip, + d1.d_year, + d2.d_year, + d3.d_year + ) +SELECT + cs1.product_name, + cs1.store_name, + cs1.store_zip, + cs1.b_street_number, + cs1.b_streen_name, + cs1.b_city, + cs1.b_zip, + cs1.c_street_number, + cs1.c_street_name, + cs1.c_city, + cs1.c_zip, + cs1.syear, + cs1.cnt, + cs1.s1, + cs1.s2, + cs1.s3, + cs2.s1, + cs2.s2, + cs2.s3, + cs2.syear, + cs2.cnt +FROM cross_sales cs1, cross_sales cs2 +WHERE cs1.item_sk = cs2.item_sk AND + cs1.syear = 1999 AND + cs2.syear = 1999 + 1 AND + cs2.cnt <= cs1.cnt AND + cs1.store_name = cs2.store_name AND + cs1.store_zip = cs2.store_zip +ORDER BY + cs1.product_name, + cs1.store_name, + cs2.cnt, + -- The two columns below are newly added in TPCDS v2.7 + cs1.s1, + cs2.s1 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q67a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q67a.sql new file mode 100644 index 0000000000000..70a14043bbb3d --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q67a.sql @@ -0,0 +1,208 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sum(coalesce(ss_sales_price * ss_quantity, 0)) sumsales + from + store_sales, date_dim, store, item + where + ss_sold_date_sk=d_date_sk + and ss_item_sk=i_item_sk + and ss_store_sk = s_store_sk + and d_month_seq between 1212 and 1212 + 11 + group by + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id), +results_rollup as ( + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sumsales + from + results + union all + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy + union all + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy + union all + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand, + i_product_name, + d_year + union all + select + i_category, + i_class, + i_brand, + i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand, + i_product_name + union all + select + i_category, + i_class, + i_brand, + null i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand + union all + select + i_category, + i_class, + null i_brand, + null i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class + union all + select + i_category, + null i_class, + null i_brand, + null i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from results + group by + i_category + union all + select + null i_category, + null i_class, + null i_brand, + null i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results) +select + * +from ( + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sumsales, + rank() over (partition by i_category order by sumsales desc) rk + from results_rollup) dw2 +where + rk <= 100 +order by + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sumsales, + rk +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q70a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q70a.sql new file mode 100644 index 0000000000000..4aec9c7fd1fd6 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q70a.sql @@ -0,0 +1,70 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + sum(ss_net_profit) as total_sum, + s_state ,s_county, + 0 as gstate, + 0 as g_county + from + store_sales, date_dim d1, store + where + d1.d_month_seq between 1212 and 1212 + 11 + and d1.d_date_sk = ss_sold_date_sk + and s_store_sk = ss_store_sk + and s_state in ( + select s_state + from ( + select + s_state as s_state, + rank() over (partition by s_state order by sum(ss_net_profit) desc) as ranking + from store_sales, store, date_dim + where d_month_seq between 1212 and 1212 + 11 + and d_date_sk = ss_sold_date_sk + and s_store_sk = ss_store_sk + group by s_state) tmp1 + where ranking <= 5) + group by + s_state, s_county), +results_rollup as ( + select + total_sum, + s_state, + s_county, + 0 as g_state, + 0 as g_county, + 0 as lochierarchy + from results + union + select + sum(total_sum) as total_sum,s_state, + NULL as s_county, + 0 as g_state, + 1 as g_county, + 1 as lochierarchy + from results + group by s_state + union + select + sum(total_sum) as total_sum, + NULL as s_state, + NULL as s_county, + 1 as g_state, + 1 as g_county, + 2 as lochierarchy + from results) +select + total_sum, + s_state, + s_county, + lochierarchy, + rank() over ( + partition by lochierarchy, + case when g_county = 0 then s_state end + order by total_sum desc) as rank_within_parent +from + results_rollup +order by + lochierarchy desc, + case when lochierarchy = 0 then s_state end, + rank_within_parent +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q72.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q72.sql new file mode 100755 index 0000000000000..066d6a587e917 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q72.sql @@ -0,0 +1,40 @@ +SELECT + i_item_desc, + w_warehouse_name, + d1.d_week_seq, + count(CASE WHEN p_promo_sk IS NULL + THEN 1 + ELSE 0 END) no_promo, + count(CASE WHEN p_promo_sk IS NOT NULL + THEN 1 + ELSE 0 END) promo, + count(*) total_cnt +FROM catalog_sales + JOIN inventory ON (cs_item_sk = inv_item_sk) + JOIN warehouse ON (w_warehouse_sk = inv_warehouse_sk) + JOIN item ON (i_item_sk = cs_item_sk) + JOIN customer_demographics ON (cs_bill_cdemo_sk = cd_demo_sk) + JOIN household_demographics ON (cs_bill_hdemo_sk = hd_demo_sk) + JOIN date_dim d1 ON (cs_sold_date_sk = d1.d_date_sk) + JOIN date_dim d2 ON (inv_date_sk = d2.d_date_sk) + JOIN date_dim d3 ON (cs_ship_date_sk = d3.d_date_sk) + LEFT OUTER JOIN promotion ON (cs_promo_sk = p_promo_sk) + LEFT OUTER JOIN catalog_returns ON (cr_item_sk = cs_item_sk AND cr_order_number = cs_order_number) +-- q72 in TPCDS v1.4 had conditions below: +-- WHERE d1.d_week_seq = d2.d_week_seq +-- AND inv_quantity_on_hand < cs_quantity +-- AND d3.d_date > (cast(d1.d_date AS DATE) + interval 5 days) +-- AND hd_buy_potential = '>10000' +-- AND d1.d_year = 1999 +-- AND hd_buy_potential = '>10000' +-- AND cd_marital_status = 'D' +-- AND d1.d_year = 1999 +WHERE d1.d_week_seq = d2.d_week_seq + AND inv_quantity_on_hand < cs_quantity + AND d3.d_date > d1.d_date + INTERVAL 5 days + AND hd_buy_potential = '1001-5000' + AND d1.d_year = 2001 + AND cd_marital_status = 'M' +GROUP BY i_item_desc, w_warehouse_name, d1.d_week_seq +ORDER BY total_cnt DESC, i_item_desc, w_warehouse_name, d_week_seq +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q74.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q74.sql new file mode 100755 index 0000000000000..94a0063b36c0c --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q74.sql @@ -0,0 +1,60 @@ +WITH year_total AS ( + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + d_year AS year, + sum(ss_net_paid) year_total, + 's' sale_type + FROM + customer, store_sales, date_dim + WHERE c_customer_sk = ss_customer_sk + AND ss_sold_date_sk = d_date_sk + AND d_year IN (2001, 2001 + 1) + GROUP BY + c_customer_id, c_first_name, c_last_name, d_year + UNION ALL + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + d_year AS year, + sum(ws_net_paid) year_total, + 'w' sale_type + FROM + customer, web_sales, date_dim + WHERE c_customer_sk = ws_bill_customer_sk + AND ws_sold_date_sk = d_date_sk + AND d_year IN (2001, 2001 + 1) + GROUP BY + c_customer_id, c_first_name, c_last_name, d_year) +SELECT + t_s_secyear.customer_id, + t_s_secyear.customer_first_name, + t_s_secyear.customer_last_name +FROM + year_total t_s_firstyear, year_total t_s_secyear, + year_total t_w_firstyear, year_total t_w_secyear +WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id + AND t_s_firstyear.customer_id = t_w_secyear.customer_id + AND t_s_firstyear.customer_id = t_w_firstyear.customer_id + AND t_s_firstyear.sale_type = 's' + AND t_w_firstyear.sale_type = 'w' + AND t_s_secyear.sale_type = 's' + AND t_w_secyear.sale_type = 'w' + AND t_s_firstyear.year = 2001 + AND t_s_secyear.year = 2001 + 1 + AND t_w_firstyear.year = 2001 + AND t_w_secyear.year = 2001 + 1 + AND t_s_firstyear.year_total > 0 + AND t_w_firstyear.year_total > 0 + AND CASE WHEN t_w_firstyear.year_total > 0 + THEN t_w_secyear.year_total / t_w_firstyear.year_total + ELSE NULL END + > CASE WHEN t_s_firstyear.year_total > 0 + THEN t_s_secyear.year_total / t_s_firstyear.year_total + ELSE NULL END +-- order-by list of q74 in TPCDS v1.4 is below: +-- ORDER BY 1, 1, 1 +ORDER BY 2, 1, 3 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q75.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q75.sql new file mode 100755 index 0000000000000..ae5dc97ef2317 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q75.sql @@ -0,0 +1,78 @@ +WITH all_sales AS ( + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + SUM(sales_cnt) AS sales_cnt, + SUM(sales_amt) AS sales_amt + FROM ( + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + cs_quantity - COALESCE(cr_return_quantity, 0) AS sales_cnt, + cs_ext_sales_price - COALESCE(cr_return_amount, 0.0) AS sales_amt + FROM catalog_sales + JOIN item ON i_item_sk = cs_item_sk + JOIN date_dim ON d_date_sk = cs_sold_date_sk + LEFT JOIN catalog_returns ON (cs_order_number = cr_order_number + AND cs_item_sk = cr_item_sk) + WHERE i_category = 'Books' + UNION + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + ss_quantity - COALESCE(sr_return_quantity, 0) AS sales_cnt, + ss_ext_sales_price - COALESCE(sr_return_amt, 0.0) AS sales_amt + FROM store_sales + JOIN item ON i_item_sk = ss_item_sk + JOIN date_dim ON d_date_sk = ss_sold_date_sk + LEFT JOIN store_returns ON (ss_ticket_number = sr_ticket_number + AND ss_item_sk = sr_item_sk) + WHERE i_category = 'Books' + UNION + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + ws_quantity - COALESCE(wr_return_quantity, 0) AS sales_cnt, + ws_ext_sales_price - COALESCE(wr_return_amt, 0.0) AS sales_amt + FROM web_sales + JOIN item ON i_item_sk = ws_item_sk + JOIN date_dim ON d_date_sk = ws_sold_date_sk + LEFT JOIN web_returns ON (ws_order_number = wr_order_number + AND ws_item_sk = wr_item_sk) + WHERE i_category = 'Books') sales_detail + GROUP BY d_year, i_brand_id, i_class_id, i_category_id, i_manufact_id) +SELECT + prev_yr.d_year AS prev_year, + curr_yr.d_year AS year, + curr_yr.i_brand_id, + curr_yr.i_class_id, + curr_yr.i_category_id, + curr_yr.i_manufact_id, + prev_yr.sales_cnt AS prev_yr_cnt, + curr_yr.sales_cnt AS curr_yr_cnt, + curr_yr.sales_cnt - prev_yr.sales_cnt AS sales_cnt_diff, + curr_yr.sales_amt - prev_yr.sales_amt AS sales_amt_diff +FROM all_sales curr_yr, all_sales prev_yr +WHERE curr_yr.i_brand_id = prev_yr.i_brand_id + AND curr_yr.i_class_id = prev_yr.i_class_id + AND curr_yr.i_category_id = prev_yr.i_category_id + AND curr_yr.i_manufact_id = prev_yr.i_manufact_id + AND curr_yr.d_year = 2002 + AND prev_yr.d_year = 2002 - 1 + AND CAST(curr_yr.sales_cnt AS DECIMAL(17, 2)) / CAST(prev_yr.sales_cnt AS DECIMAL(17, 2)) < 0.9 +ORDER BY + sales_cnt_diff, + sales_amt_diff -- This order-by condition did not exist in TPCDS v1.4 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q77a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q77a.sql new file mode 100644 index 0000000000000..fc69c43470f1e --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q77a.sql @@ -0,0 +1,121 @@ +-- This is a new query in TPCDS v2.7 +with ss as ( + select + s_store_sk, + sum(ss_ext_sales_price) as sales, + sum(ss_net_profit) as profit + from + store_sales, date_dim, store + where + ss_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and ss_store_sk = s_store_sk + group by + s_store_sk), +sr as ( + select + s_store_sk, + sum(sr_return_amt) as returns, + sum(sr_net_loss) as profit_loss + from + store_returns, date_dim, store + where + sr_returned_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and sr_store_sk = s_store_sk + group by + s_store_sk), +cs as ( + select + cs_call_center_sk, + sum(cs_ext_sales_price) as sales, + sum(cs_net_profit) as profit + from + catalog_sales, + date_dim + where + cs_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + group by + cs_call_center_sk), + cr as ( + select + sum(cr_return_amount) as returns, + sum(cr_net_loss) as profit_loss + from catalog_returns, + date_dim + where + cr_returned_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days)), +ws as ( select wp_web_page_sk, + sum(ws_ext_sales_price) as sales, + sum(ws_net_profit) as profit + from web_sales, + date_dim, + web_page + where ws_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and ws_web_page_sk = wp_web_page_sk + group by wp_web_page_sk), + wr as + (select wp_web_page_sk, + sum(wr_return_amt) as returns, + sum(wr_net_loss) as profit_loss + from web_returns, + date_dim, + web_page + where wr_returned_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and wr_web_page_sk = wp_web_page_sk + group by wp_web_page_sk) + , + results as + (select channel + , id + , sum(sales) as sales + , sum(returns) as returns + , sum(profit) as profit + from + (select 'store channel' as channel + , ss.s_store_sk as id + , sales + , coalesce(returns, 0) as returns + , (profit - coalesce(profit_loss,0)) as profit + from ss left join sr + on ss.s_store_sk = sr.s_store_sk + union all + select 'catalog channel' as channel + , cs_call_center_sk as id + , sales + , returns + , (profit - profit_loss) as profit + from cs + , cr + union all + select 'web channel' as channel + , ws.wp_web_page_sk as id + , sales + , coalesce(returns, 0) returns + , (profit - coalesce(profit_loss,0)) as profit + from ws left join wr + on ws.wp_web_page_sk = wr.wp_web_page_sk + ) x + group by channel, id ) + + select * + from ( + select channel, id, sales, returns, profit from results + union + select channel, NULL AS id, sum(sales) as sales, sum(returns) as returns, sum(profit) as profit from results group by channel + union + select NULL AS channel, NULL AS id, sum(sales) as sales, sum(returns) as returns, sum(profit) as profit from results +) foo +order by + channel, id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q78.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q78.sql new file mode 100755 index 0000000000000..d03d8af77174c --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q78.sql @@ -0,0 +1,75 @@ +WITH ws AS +(SELECT + d_year AS ws_sold_year, + ws_item_sk, + ws_bill_customer_sk ws_customer_sk, + sum(ws_quantity) ws_qty, + sum(ws_wholesale_cost) ws_wc, + sum(ws_sales_price) ws_sp + FROM web_sales + LEFT JOIN web_returns ON wr_order_number = ws_order_number AND ws_item_sk = wr_item_sk + JOIN date_dim ON ws_sold_date_sk = d_date_sk + WHERE wr_order_number IS NULL + GROUP BY d_year, ws_item_sk, ws_bill_customer_sk +), + cs AS + (SELECT + d_year AS cs_sold_year, + cs_item_sk, + cs_bill_customer_sk cs_customer_sk, + sum(cs_quantity) cs_qty, + sum(cs_wholesale_cost) cs_wc, + sum(cs_sales_price) cs_sp + FROM catalog_sales + LEFT JOIN catalog_returns ON cr_order_number = cs_order_number AND cs_item_sk = cr_item_sk + JOIN date_dim ON cs_sold_date_sk = d_date_sk + WHERE cr_order_number IS NULL + GROUP BY d_year, cs_item_sk, cs_bill_customer_sk + ), + ss AS + (SELECT + d_year AS ss_sold_year, + ss_item_sk, + ss_customer_sk, + sum(ss_quantity) ss_qty, + sum(ss_wholesale_cost) ss_wc, + sum(ss_sales_price) ss_sp + FROM store_sales + LEFT JOIN store_returns ON sr_ticket_number = ss_ticket_number AND ss_item_sk = sr_item_sk + JOIN date_dim ON ss_sold_date_sk = d_date_sk + WHERE sr_ticket_number IS NULL + GROUP BY d_year, ss_item_sk, ss_customer_sk + ) +SELECT + round(ss_qty / (coalesce(ws_qty + cs_qty, 1)), 2) ratio, + ss_qty store_qty, + ss_wc store_wholesale_cost, + ss_sp store_sales_price, + coalesce(ws_qty, 0) + coalesce(cs_qty, 0) other_chan_qty, + coalesce(ws_wc, 0) + coalesce(cs_wc, 0) other_chan_wholesale_cost, + coalesce(ws_sp, 0) + coalesce(cs_sp, 0) other_chan_sales_price +FROM ss + LEFT JOIN ws + ON (ws_sold_year = ss_sold_year AND ws_item_sk = ss_item_sk AND ws_customer_sk = ss_customer_sk) + LEFT JOIN cs + ON (cs_sold_year = ss_sold_year AND cs_item_sk = ss_item_sk AND cs_customer_sk = ss_customer_sk) +WHERE coalesce(ws_qty, 0) > 0 AND coalesce(cs_qty, 0) > 0 AND ss_sold_year = 2000 +ORDER BY + -- order-by list of q78 in TPCDS v1.4 is below: + -- ratio, + -- ss_qty DESC, ss_wc DESC, ss_sp DESC, + -- other_chan_qty, + -- other_chan_wholesale_cost, + -- other_chan_sales_price, + -- round(ss_qty / (coalesce(ws_qty + cs_qty, 1)), 2) + ss_sold_year, + ss_item_sk, + ss_customer_sk, + ss_qty desc, + ss_wc desc, + ss_sp desc, + other_chan_qty, + other_chan_wholesale_cost, + other_chan_sales_price, + ratio +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q80a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q80a.sql new file mode 100644 index 0000000000000..686e03ba2a6d0 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q80a.sql @@ -0,0 +1,147 @@ +-- This is a new query in TPCDS v2.7 +with ssr as ( + select + s_store_id as store_id, + sum(ss_ext_sales_price) as sales, + sum(coalesce(sr_return_amt, 0)) as returns, + sum(ss_net_profit - coalesce(sr_net_loss, 0)) as profit + from + store_sales left outer join store_returns on ( + ss_item_sk = sr_item_sk and ss_ticket_number = sr_ticket_number), + date_dim, + store, + item, + promotion + where + ss_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and ss_store_sk = s_store_sk + and ss_item_sk = i_item_sk + and i_current_price > 50 + and ss_promo_sk = p_promo_sk + and p_channel_tv = 'N' + group by + s_store_id), +csr as ( + select + cp_catalog_page_id as catalog_page_id, + sum(cs_ext_sales_price) as sales, + sum(coalesce(cr_return_amount, 0)) as returns, + sum(cs_net_profit - coalesce(cr_net_loss, 0)) as profit + from + catalog_sales left outer join catalog_returns on + (cs_item_sk = cr_item_sk and cs_order_number = cr_order_number), + date_dim, + catalog_page, + item, + promotion + where + cs_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and cs_catalog_page_sk = cp_catalog_page_sk + and cs_item_sk = i_item_sk + and i_current_price > 50 + and cs_promo_sk = p_promo_sk + and p_channel_tv = 'N' + group by + cp_catalog_page_id), +wsr as ( + select + web_site_id, + sum(ws_ext_sales_price) as sales, + sum(coalesce(wr_return_amt, 0)) as returns, + sum(ws_net_profit - coalesce(wr_net_loss, 0)) as profit + from + web_sales left outer join web_returns on ( + ws_item_sk = wr_item_sk and ws_order_number = wr_order_number), + date_dim, + web_site, + item, + promotion + where + ws_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and ws_web_site_sk = web_site_sk + and ws_item_sk = i_item_sk + and i_current_price > 50 + and ws_promo_sk = p_promo_sk + and p_channel_tv = 'N' + group by + web_site_id), +results as ( + select + channel, + id, + sum(sales) as sales, + sum(returns) as returns, + sum(profit) as profit + from ( + select + 'store channel' as channel, + 'store' || store_id as id, + sales, + returns, + profit + from + ssr + union all + select + 'catalog channel' as channel, + 'catalog_page' || catalog_page_id as id, + sales, + returns, + profit + from + csr + union all + select + 'web channel' as channel, + 'web_site' || web_site_id as id, + sales, + returns, + profit + from + wsr) x + group by + channel, id) +select + channel, + id, + sales, + returns, + profit +from ( + select + channel, + id, + sales, + returns, + profit + from + results + union + select + channel, + NULL AS id, + sum(sales) as sales, + sum(returns) as returns, + sum(profit) as profit + from + results + group by + channel + union + select + NULL AS channel, + NULL AS id, + sum(sales) as sales, + sum(returns) as returns, + sum(profit) as profit + from + results) foo +order by + channel, id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q86a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q86a.sql new file mode 100644 index 0000000000000..fff76b08d4ba0 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q86a.sql @@ -0,0 +1,61 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + sum(ws_net_paid) as total_sum, + i_category, i_class, + 0 as g_category, + 0 as g_class + from + web_sales, date_dim d1, item + where + d1.d_month_seq between 1212 and 1212 + 11 + and d1.d_date_sk = ws_sold_date_sk + and i_item_sk = ws_item_sk + group by + i_category, i_class), +results_rollup as( + select + total_sum, + i_category, + i_class, + g_category, + g_class, + 0 as lochierarchy + from + results + union + select + sum(total_sum) as total_sum, + i_category, + NULL as i_class, + 0 as g_category, + 1 as g_class, + 1 as lochierarchy + from + results + group by + i_category + union + select + sum(total_sum) as total_sum, + NULL as i_category, + NULL as i_class, + 1 as g_category, + 1 as g_class, + 2 as lochierarchy + from + results) +select + total_sum, + i_category ,i_class, lochierarchy, + rank() over ( + partition by lochierarchy, + case when g_class = 0 then i_category end + order by total_sum desc) as rank_within_parent +from + results_rollup +order by + lochierarchy desc, + case when lochierarchy = 0 then i_category end, + rank_within_parent +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q98.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q98.sql new file mode 100755 index 0000000000000..771117add2ed2 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q98.sql @@ -0,0 +1,22 @@ +SELECT + i_item_id, -- This column did not exist in TPCDS v1.4 + i_item_desc, + i_category, + i_class, + i_current_price, + sum(ss_ext_sales_price) AS itemrevenue, + sum(ss_ext_sales_price) * 100 / sum(sum(ss_ext_sales_price)) + OVER + (PARTITION BY i_class) AS revenueratio +FROM + store_sales, item, date_dim +WHERE + ss_item_sk = i_item_sk + AND i_category IN ('Sports', 'Books', 'Home') + AND ss_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('1999-02-22' AS DATE) + AND (cast('1999-02-22' AS DATE) + INTERVAL 30 days) +GROUP BY + i_item_id, i_item_desc, i_category, i_class, i_current_price +ORDER BY + i_category, i_class, i_item_id, i_item_desc, revenueratio diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala index 1a584187a06e5..bc95b4696190d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -62,7 +62,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`c_first_sales_date_sk` INT, `c_salutation` STRING, `c_first_name` STRING, |`c_last_name` STRING, `c_preferred_cust_flag` STRING, `c_birth_day` INT, |`c_birth_month` INT, `c_birth_year` INT, `c_birth_country` STRING, `c_login` STRING, - |`c_email_address` STRING, `c_last_review_date` STRING) + |`c_email_address` STRING, `c_last_review_date` INT) |USING parquet """.stripMargin) @@ -88,7 +88,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { sql( """ |CREATE TABLE `date_dim` ( - |`d_date_sk` INT, `d_date_id` STRING, `d_date` STRING, + |`d_date_sk` INT, `d_date_id` STRING, `d_date` DATE, |`d_month_seq` INT, `d_week_seq` INT, `d_quarter_seq` INT, `d_year` INT, `d_dow` INT, |`d_moy` INT, `d_dom` INT, `d_qoy` INT, `d_fy_year` INT, `d_fy_quarter_seq` INT, |`d_fy_week_seq` INT, `d_day_name` STRING, `d_quarter_name` STRING, `d_holiday` STRING, @@ -115,8 +115,8 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { sql( """ - |CREATE TABLE `item` (`i_item_sk` INT, `i_item_id` STRING, `i_rec_start_date` STRING, - |`i_rec_end_date` STRING, `i_item_desc` STRING, `i_current_price` DECIMAL(7,2), + |CREATE TABLE `item` (`i_item_sk` INT, `i_item_id` STRING, `i_rec_start_date` DATE, + |`i_rec_end_date` DATE, `i_item_desc` STRING, `i_current_price` DECIMAL(7,2), |`i_wholesale_cost` DECIMAL(7,2), `i_brand_id` INT, `i_brand` STRING, `i_class_id` INT, |`i_class` STRING, `i_category_id` INT, `i_category` STRING, `i_manufact_id` INT, |`i_manufact` STRING, `i_size` STRING, `i_formulation` STRING, `i_color` STRING, @@ -139,8 +139,8 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { sql( """ |CREATE TABLE `store` ( - |`s_store_sk` INT, `s_store_id` STRING, `s_rec_start_date` STRING, - |`s_rec_end_date` STRING, `s_closed_date_sk` INT, `s_store_name` STRING, + |`s_store_sk` INT, `s_store_id` STRING, `s_rec_start_date` DATE, + |`s_rec_end_date` DATE, `s_closed_date_sk` INT, `s_store_name` STRING, |`s_number_employees` INT, `s_floor_space` INT, `s_hours` STRING, `s_manager` STRING, |`s_market_id` INT, `s_geography_class` STRING, `s_market_desc` STRING, |`s_market_manager` STRING, `s_division_id` INT, `s_division_name` STRING, @@ -157,7 +157,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`sr_returned_date_sk` BIGINT, `sr_return_time_sk` BIGINT, `sr_item_sk` BIGINT, |`sr_customer_sk` BIGINT, `sr_cdemo_sk` BIGINT, `sr_hdemo_sk` BIGINT, `sr_addr_sk` BIGINT, |`sr_store_sk` BIGINT, `sr_reason_sk` BIGINT, `sr_ticket_number` BIGINT, - |`sr_return_quantity` BIGINT, `sr_return_amt` DECIMAL(7,2), `sr_return_tax` DECIMAL(7,2), + |`sr_return_quantity` INT, `sr_return_amt` DECIMAL(7,2), `sr_return_tax` DECIMAL(7,2), |`sr_return_amt_inc_tax` DECIMAL(7,2), `sr_fee` DECIMAL(7,2), |`sr_return_ship_cost` DECIMAL(7,2), `sr_refunded_cash` DECIMAL(7,2), |`sr_reversed_charge` DECIMAL(7,2), `sr_store_credit` DECIMAL(7,2), @@ -225,7 +225,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`wr_refunded_hdemo_sk` BIGINT, `wr_refunded_addr_sk` BIGINT, |`wr_returning_customer_sk` BIGINT, `wr_returning_cdemo_sk` BIGINT, |`wr_returning_hdemo_sk` BIGINT, `wr_returning_addr_sk` BIGINT, `wr_web_page_sk` BIGINT, - |`wr_reason_sk` BIGINT, `wr_order_number` BIGINT, `wr_return_quantity` BIGINT, + |`wr_reason_sk` BIGINT, `wr_order_number` BIGINT, `wr_return_quantity` INT, |`wr_return_amt` DECIMAL(7,2), `wr_return_tax` DECIMAL(7,2), |`wr_return_amt_inc_tax` DECIMAL(7,2), `wr_fee` DECIMAL(7,2), |`wr_return_ship_cost` DECIMAL(7,2), `wr_refunded_cash` DECIMAL(7,2), @@ -244,7 +244,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`web_company_id` INT, `web_company_name` STRING, `web_street_number` STRING, |`web_street_name` STRING, `web_street_type` STRING, `web_suite_number` STRING, |`web_city` STRING, `web_county` STRING, `web_state` STRING, `web_zip` STRING, - |`web_country` STRING, `web_gmt_offset` STRING, `web_tax_percentage` DECIMAL(5,2)) + |`web_country` STRING, `web_gmt_offset` DECIMAL(5,2), `web_tax_percentage` DECIMAL(5,2)) |USING parquet """.stripMargin) @@ -315,6 +315,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { """.stripMargin) } + // The TPCDS queries below are based on v1.4 val tpcdsQueries = Seq( "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14a", "q14b", "q15", "q16", "q17", "q18", "q19", "q20", @@ -339,6 +340,25 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { } } + // This list only includes TPCDS v2.7 queries that are different from v1.4 ones + val tpcdsQueriesV2_7_0 = Seq( + "q5a", "q6", "q10a", "q11", "q12", "q14", "q14a", "q18a", + "q20", "q22", "q22a", "q24", "q27a", "q34", "q35", "q35a", "q36a", "q47", "q49", + "q51a", "q57", "q64", "q67a", "q70a", "q72", "q74", "q75", "q77a", "q78", + "q80a", "q86a", "q98") + + tpcdsQueriesV2_7_0.foreach { name => + val queryString = resourceToString(s"tpcds-v2.7.0/$name.sql", + classLoader = Thread.currentThread().getContextClassLoader) + test(s"$name-v2.7") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + // check the plans can be properly generated + val plan = sql(queryString).queryExecution.executedPlan + checkGeneratedCode(plan) + } + } + } + // These queries are from https://github.com/cloudera/impala-tpcds-kit/tree/master/queries val modifiedTPCDSQueries = Seq( "q3", "q7", "q10", "q19", "q27", "q34", "q42", "q43", "q46", "q52", "q53", "q55", "q59", From e4bec7cb88b9ee63f8497e3f9e0ab0bfa5d5a77c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 25 Mar 2018 16:38:49 -0700 Subject: [PATCH 0521/2461] [SPARK-23549][SQL] Cast to timestamp when comparing timestamp with date ## What changes were proposed in this pull request? This PR fixes an incorrect comparison in SQL between timestamp and date. This is because both of them are casted to `string` and then are compared lexicographically. This implementation shows `false` regarding this query `spark.sql("select cast('2017-03-01 00:00:00' as timestamp) between cast('2017-02-28' as date) and cast('2017-03-01' as date)").show`. This PR shows `true` for this query by casting `date("2017-03-01")` to `timestamp("2017-03-01 00:00:00")`. (Please fill in changes proposed in this fix) ## How was this patch tested? Added new UTs to `TypeCoercionSuite`. Author: Kazuaki Ishizaki Closes #20774 from kiszk/SPARK-23549. --- docs/sql-programming-guide.md | 1 + .../sql/catalyst/analysis/TypeCoercion.scala | 29 ++++++++----- .../apache/spark/sql/internal/SQLConf.scala | 13 ++++++ .../catalyst/analysis/TypeCoercionSuite.scala | 34 ++++++++++++--- .../sql-tests/inputs/predicate-functions.sql | 7 ++++ .../results/predicate-functions.sql.out | 42 ++++++++++++++++++- 6 files changed, 108 insertions(+), 18 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 421e2eaf62bfb..2b393f30d1435 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1808,6 +1808,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. + - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e8669c4637d06..ec7e7761dc4c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -47,9 +47,9 @@ import org.apache.spark.sql.types._ object TypeCoercion { def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] = - InConversion :: + InConversion(conf) :: WidenSetOperationTypes :: - PromoteStrings :: + PromoteStrings(conf) :: DecimalPrecision :: BooleanEquality :: FunctionArgumentConversion :: @@ -127,7 +127,8 @@ object TypeCoercion { * is a String and the other is not. It also handles when one op is a Date and the * other is a Timestamp by making the target type to be String. */ - val findCommonTypeForBinaryComparison: (DataType, DataType) => Option[DataType] = { + private def findCommonTypeForBinaryComparison( + dt1: DataType, dt2: DataType, conf: SQLConf): Option[DataType] = (dt1, dt2) match { // We should cast all relative timestamp/date/string comparison into string comparisons // This behaves as a user would expect because timestamp strings sort lexicographically. // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true @@ -135,11 +136,17 @@ object TypeCoercion { case (DateType, StringType) => Some(StringType) case (StringType, TimestampType) => Some(StringType) case (TimestampType, StringType) => Some(StringType) - case (TimestampType, DateType) => Some(StringType) - case (DateType, TimestampType) => Some(StringType) case (StringType, NullType) => Some(StringType) case (NullType, StringType) => Some(StringType) + // Cast to TimestampType when we compare DateType with TimestampType + // if conf.compareDateTimestampInTimestamp is true + // i.e. TimeStamp('2017-03-01 00:00:00') eq Date('2017-03-01') = true + case (TimestampType, DateType) + => if (conf.compareDateTimestampInTimestamp) Some(TimestampType) else Some(StringType) + case (DateType, TimestampType) + => if (conf.compareDateTimestampInTimestamp) Some(TimestampType) else Some(StringType) + // There is no proper decimal type we can pick, // using double type is the best we can do. // See SPARK-22469 for details. @@ -147,7 +154,7 @@ object TypeCoercion { case (s: StringType, n: DecimalType) => Some(DoubleType) case (l: StringType, r: AtomicType) if r != StringType => Some(r) - case (l: AtomicType, r: StringType) if (l != StringType) => Some(l) + case (l: AtomicType, r: StringType) if l != StringType => Some(l) case (l, r) => None } @@ -313,7 +320,7 @@ object TypeCoercion { /** * Promotes strings that appear in arithmetic expressions. */ - object PromoteStrings extends TypeCoercionRule { + case class PromoteStrings(conf: SQLConf) extends TypeCoercionRule { private def castExpr(expr: Expression, targetType: DataType): Expression = { (expr.dataType, targetType) match { case (NullType, dt) => Literal.create(null, targetType) @@ -342,8 +349,8 @@ object TypeCoercion { p.makeCopy(Array(left, Cast(right, TimestampType))) case p @ BinaryComparison(left, right) - if findCommonTypeForBinaryComparison(left.dataType, right.dataType).isDefined => - val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get + if findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).isDefined => + val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).get p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType))) case Abs(e @ StringType()) => Abs(Cast(e, DoubleType)) @@ -374,7 +381,7 @@ object TypeCoercion { * operator type is found the original expression will be returned and an * Analysis Exception will be raised at the type checking phase. */ - object InConversion extends TypeCoercionRule { + case class InConversion(conf: SQLConf) extends TypeCoercionRule { private def flattenExpr(expr: Expression): Seq[Expression] = { expr match { // Multi columns in IN clause is represented as a CreateNamedStruct. @@ -400,7 +407,7 @@ object TypeCoercion { val rhs = sub.output val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => - findCommonTypeForBinaryComparison(l.dataType, r.dataType) + findCommonTypeForBinaryComparison(l.dataType, r.dataType, conf) .orElse(findTightestCommonType(l.dataType, r.dataType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 11864bd1b1847..9cb03b5bb6152 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -479,6 +479,16 @@ object SQLConf { .checkValues(HiveCaseSensitiveInferenceMode.values.map(_.toString)) .createWithDefault(HiveCaseSensitiveInferenceMode.INFER_AND_SAVE.toString) + val TYPECOERCION_COMPARE_DATE_TIMESTAMP_IN_TIMESTAMP = + buildConf("spark.sql.typeCoercion.compareDateTimestampInTimestamp") + .internal() + .doc("When true (default), compare Date with Timestamp after converting both sides to " + + "Timestamp. This behavior is compatible with Hive 2.2 or later. See HIVE-15236. " + + "When false, restore the behavior prior to Spark 2.4. Compare Date with Timestamp after " + + "converting both sides to string. This config will be removed in spark 3.0") + .booleanConf + .createWithDefault(true) + val OPTIMIZER_METADATA_ONLY = buildConf("spark.sql.optimizer.metadataOnly") .doc("When true, enable the metadata-only query optimization that use the table's metadata " + "to produce the partition columns instead of table scans. It applies when all the columns " + @@ -1332,6 +1342,9 @@ class SQLConf extends Serializable with Logging { def caseSensitiveInferenceMode: HiveCaseSensitiveInferenceMode.Value = HiveCaseSensitiveInferenceMode.withName(getConf(HIVE_CASE_SENSITIVE_INFERENCE)) + def compareDateTimestampInTimestamp : Boolean = + getConf(TYPECOERCION_COMPARE_DATE_TIMESTAMP_IN_TIMESTAMP) + def gatherFastStats: Boolean = getConf(GATHER_FASTSTAT) def optimizerMetadataOnly: Boolean = getConf(OPTIMIZER_METADATA_ONLY) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 52a7ebdafd7c7..8ac49dc05e3cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1207,7 +1207,7 @@ class TypeCoercionSuite extends AnalysisTest { */ test("make sure rules do not fire early") { // InConversion - val inConversion = TypeCoercion.InConversion + val inConversion = TypeCoercion.InConversion(conf) ruleTest(inConversion, In(UnresolvedAttribute("a"), Seq(Literal(1))), In(UnresolvedAttribute("a"), Seq(Literal(1))) @@ -1251,18 +1251,40 @@ class TypeCoercionSuite extends AnalysisTest { } test("binary comparison with string promotion") { - ruleTest(PromoteStrings, + val rule = TypeCoercion.PromoteStrings(conf) + ruleTest(rule, GreaterThan(Literal("123"), Literal(1)), GreaterThan(Cast(Literal("123"), IntegerType), Literal(1))) - ruleTest(PromoteStrings, + ruleTest(rule, LessThan(Literal(true), Literal("123")), LessThan(Literal(true), Cast(Literal("123"), BooleanType))) - ruleTest(PromoteStrings, + ruleTest(rule, EqualTo(Literal(Array(1, 2)), Literal("123")), EqualTo(Literal(Array(1, 2)), Literal("123"))) - ruleTest(PromoteStrings, + ruleTest(rule, GreaterThan(Literal("1.5"), Literal(BigDecimal("0.5"))), - GreaterThan(Cast(Literal("1.5"), DoubleType), Cast(Literal(BigDecimal("0.5")), DoubleType))) + GreaterThan(Cast(Literal("1.5"), DoubleType), Cast(Literal(BigDecimal("0.5")), + DoubleType))) + Seq(true, false).foreach { convertToTS => + withSQLConf( + "spark.sql.typeCoercion.compareDateTimestampInTimestamp" -> convertToTS.toString) { + val date0301 = Literal(java.sql.Date.valueOf("2017-03-01")) + val timestamp0301000000 = Literal(Timestamp.valueOf("2017-03-01 00:00:00")) + val timestamp0301000001 = Literal(Timestamp.valueOf("2017-03-01 00:00:01")) + if (convertToTS) { + // `Date` should be treated as timestamp at 00:00:00 See SPARK-23549 + ruleTest(rule, EqualTo(date0301, timestamp0301000000), + EqualTo(Cast(date0301, TimestampType), timestamp0301000000)) + ruleTest(rule, LessThan(date0301, timestamp0301000001), + LessThan(Cast(date0301, TimestampType), timestamp0301000001)) + } else { + ruleTest(rule, LessThan(date0301, timestamp0301000000), + LessThan(Cast(date0301, StringType), Cast(timestamp0301000000, StringType))) + ruleTest(rule, LessThan(date0301, timestamp0301000001), + LessThan(Cast(date0301, StringType), Cast(timestamp0301000001, StringType))) + } + } + } } test("cast WindowFrame boundaries to the type they operate upon") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql index e99d5cef81f64..fadb4bb27fa13 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql @@ -39,3 +39,10 @@ select 2.0 <= '2.2'; select 0.5 <= '1.5'; select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52'); select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52'; + +-- SPARK-23549: Cast to timestamp when comparing timestamp with date +select to_date('2017-03-01') = to_timestamp('2017-03-01 00:00:00'); +select to_timestamp('2017-03-01 00:00:01') > to_date('2017-03-01'); +select to_timestamp('2017-03-01 00:00:01') >= to_date('2017-03-01'); +select to_date('2017-03-01') < to_timestamp('2017-03-01 00:00:01'); +select to_date('2017-03-01') <= to_timestamp('2017-03-01 00:00:01'); diff --git a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out index d51f6d37e4b41..cf828c69af62a 100644 --- a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 32 +-- Number of queries: 37 -- !query 0 @@ -256,3 +256,43 @@ select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52' struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) <= 2009-07-30 04:17:52):boolean> -- !query 31 output true + + +-- !query 32 +select to_date('2017-03-01') = to_timestamp('2017-03-01 00:00:00') +-- !query 32 schema +struct<(CAST(to_date('2017-03-01') AS TIMESTAMP) = to_timestamp('2017-03-01 00:00:00')):boolean> +-- !query 32 output +true + + +-- !query 33 +select to_timestamp('2017-03-01 00:00:01') > to_date('2017-03-01') +-- !query 33 schema +struct<(to_timestamp('2017-03-01 00:00:01') > CAST(to_date('2017-03-01') AS TIMESTAMP)):boolean> +-- !query 33 output +true + + +-- !query 34 +select to_timestamp('2017-03-01 00:00:01') >= to_date('2017-03-01') +-- !query 34 schema +struct<(to_timestamp('2017-03-01 00:00:01') >= CAST(to_date('2017-03-01') AS TIMESTAMP)):boolean> +-- !query 34 output +true + + +-- !query 35 +select to_date('2017-03-01') < to_timestamp('2017-03-01 00:00:01') +-- !query 35 schema +struct<(CAST(to_date('2017-03-01') AS TIMESTAMP) < to_timestamp('2017-03-01 00:00:01')):boolean> +-- !query 35 output +true + + +-- !query 36 +select to_date('2017-03-01') <= to_timestamp('2017-03-01 00:00:01') +-- !query 36 schema +struct<(CAST(to_date('2017-03-01') AS TIMESTAMP) <= to_timestamp('2017-03-01 00:00:01')):boolean> +-- !query 36 output +true From a9350d7095b79c8374fb4a06fd3f1a1a67615f6f Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 26 Mar 2018 12:42:32 +0900 Subject: [PATCH 0522/2461] [SPARK-23700][PYTHON] Cleanup imports in pyspark.sql ## What changes were proposed in this pull request? This cleans up unused imports, mainly from pyspark.sql module. Added a note in function.py that imports `UserDefinedFunction` only to maintain backwards compatibility for using `from pyspark.sql.function import UserDefinedFunction`. ## How was this patch tested? Existing tests and built docs. Author: Bryan Cutler Closes #20892 from BryanCutler/pyspark-cleanup-imports-SPARK-23700. --- python/pyspark/sql/column.py | 1 - python/pyspark/sql/conf.py | 1 - python/pyspark/sql/functions.py | 3 +-- python/pyspark/sql/group.py | 3 +-- python/pyspark/sql/readwriter.py | 2 +- python/pyspark/sql/streaming.py | 2 -- python/pyspark/sql/types.py | 1 - python/pyspark/sql/udf.py | 6 ++---- python/pyspark/util.py | 2 -- 9 files changed, 5 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index e05a7b33c11a7..922c7cf288f8f 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -16,7 +16,6 @@ # import sys -import warnings import json if sys.version >= '3': diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index b82224b6194ed..db49040e17b63 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -67,7 +67,6 @@ def _checkType(self, obj, identifier): def _test(): import os import doctest - from pyspark.context import SparkContext from pyspark.sql.session import SparkSession import pyspark.sql.conf diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index dff590983b4d9..a4edb1e27b599 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -18,7 +18,6 @@ """ A collections of builtin functions """ -import math import sys import functools import warnings @@ -28,10 +27,10 @@ from pyspark import since, SparkContext from pyspark.rdd import ignore_unicode_prefix, PythonEvalType -from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import StringType, DataType +# Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409 from pyspark.sql.udf import UserDefinedFunction, _create_udf diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 35cac406e0965..3505065b648f2 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -19,9 +19,8 @@ from pyspark import since from pyspark.rdd import ignore_unicode_prefix, PythonEvalType -from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal +from pyspark.sql.column import Column, _to_seq from pyspark.sql.dataframe import DataFrame -from pyspark.sql.udf import UserDefinedFunction from pyspark.sql.types import * __all__ = ["GroupedData"] diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index e5288636c596e..4f9b9383a5ef4 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -22,7 +22,7 @@ from py4j.java_gateway import JavaClass -from pyspark import RDD, since, keyword_only +from pyspark import RDD, since from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.types import * diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 07f9ac1b5aa9e..c7907aaaf1f7b 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -24,8 +24,6 @@ else: intlike = (int, long) -from abc import ABCMeta, abstractmethod - from pyspark import since, keyword_only from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 5d5919e451b46..1f6534836d64a 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -35,7 +35,6 @@ from pyspark import SparkContext from pyspark.serializers import CloudPickleSerializer -from pyspark.util import _exception_message __all__ = [ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 24dd06c26089c..9dbe49b831cef 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -17,16 +17,14 @@ """ User-defined function related classes and functions """ -import sys -import inspect import functools import sys from pyspark import SparkContext, since from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column, _to_seq -from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \ - _parse_datatype_string, to_arrow_type, to_arrow_schema +from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\ + to_arrow_type, to_arrow_schema from pyspark.util import _get_argspec __all__ = ["UDFRegistration"] diff --git a/python/pyspark/util.py b/python/pyspark/util.py index ed1bdd0e4be83..49afc13640332 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -22,8 +22,6 @@ __all__ = [] -import sys - def _exception_message(excp): """Return the message from an exception as either a str or unicode object. Supports both From 087fb3142028d679524e22596b0ad4f74ff47e8d Mon Sep 17 00:00:00 2001 From: "Michael (Stu) Stewart" Date: Mon, 26 Mar 2018 12:45:45 +0900 Subject: [PATCH 0523/2461] [SPARK-23645][MINOR][DOCS][PYTHON] Add docs RE `pandas_udf` with keyword args ## What changes were proposed in this pull request? Add documentation about the limitations of `pandas_udf` with keyword arguments and related concepts, like `functools.partial` fn objects. NOTE: intermediate commits on this PR show some of the steps that can be taken to fix some (but not all) of these pain points. ### Survey of problems we face today: (Initialize) Note: python 3.6 and spark 2.4snapshot. ``` from pyspark.sql import SparkSession import inspect, functools from pyspark.sql.functions import pandas_udf, PandasUDFType, col, lit, udf spark = SparkSession.builder.getOrCreate() print(spark.version) df = spark.range(1,6).withColumn('b', col('id') * 2) def ok(a,b): return a+b ``` Using a keyword argument at the call site `b=...` (and yes, *full* stack trace below, haha): ``` ---> 14 df.withColumn('ok', pandas_udf(f=ok, returnType='bigint')('id', b='id')).show() # no kwargs TypeError: wrapper() got an unexpected keyword argument 'b' ``` Using partial with a keyword argument where the kw-arg is the first argument of the fn: *(Aside: kind of interesting that lines 15,16 work great and then 17 explodes)* ``` --------------------------------------------------------------------------- ValueError Traceback (most recent call last) in () 15 df.withColumn('ok', pandas_udf(f=functools.partial(ok, 7), returnType='bigint')('id')).show() 16 df.withColumn('ok', pandas_udf(f=functools.partial(ok, b=7), returnType='bigint')('id')).show() ---> 17 df.withColumn('ok', pandas_udf(f=functools.partial(ok, a=7), returnType='bigint')('id')).show() /Users/stu/ZZ/spark/python/pyspark/sql/functions.py in pandas_udf(f, returnType, functionType) 2378 return functools.partial(_create_udf, returnType=return_type, evalType=eval_type) 2379 else: -> 2380 return _create_udf(f=f, returnType=return_type, evalType=eval_type) 2381 2382 /Users/stu/ZZ/spark/python/pyspark/sql/udf.py in _create_udf(f, returnType, evalType) 54 argspec.varargs is None: 55 raise ValueError( ---> 56 "Invalid function: 0-arg pandas_udfs are not supported. " 57 "Instead, create a 1-arg pandas_udf and ignore the arg in your function." 58 ) ValueError: Invalid function: 0-arg pandas_udfs are not supported. Instead, create a 1-arg pandas_udf and ignore the arg in your function. ``` Author: Michael (Stu) Stewart Closes #20900 from mstewart141/udfkw2. --- python/pyspark/sql/functions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a4edb1e27b599..ad3e37c872628 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2154,6 +2154,8 @@ def udf(f=None, returnType=StringType()): in boolean expressions and it ends up with being executed all internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions. + .. note:: The user-defined functions do not take keyword arguments on the calling side. + :param f: python function if used as a standalone function :param returnType: the return type of the user-defined function. The value can be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. @@ -2337,6 +2339,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. note:: The user-defined functions do not support conditional expressions or short circuiting in boolean expressions and it ends up with being executed all internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions. + + .. note:: The user-defined functions do not take keyword arguments on the calling side. """ # decorator @pandas_udf(returnType, functionType) is_decorator = f is None or isinstance(f, (str, DataType)) From eb48edf9ca4f4b42c63f145718696472cb6a31ba Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 26 Mar 2018 14:01:04 +0800 Subject: [PATCH 0524/2461] [SPARK-23787][TESTS] Fix file download test in SparkSubmitSuite for Hadoop 2.9. This particular test assumed that Hadoop libraries did not support http as a file system. Hadoop 2.9 does, so the test failed. The test now forces a non-existent implementation for the http fs, which forces the expected error. There were also a couple of other issues in the same test: SparkSubmit arguments in the wrong order, and the wrong check later when asserting, which was being masked by the previous issues. Author: Marcelo Vanzin Closes #20895 from vanzin/SPARK-23787. --- .../spark/deploy/SparkSubmitSuite.scala | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 2d0c192db4915..d86ef907b4492 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -959,25 +959,28 @@ class SparkSubmitSuite } test("download remote resource if it is not supported by yarn service") { - testRemoteResources(isHttpSchemeBlacklisted = false, supportMockHttpFs = false) + testRemoteResources(enableHttpFs = false, blacklistHttpFs = false) } test("avoid downloading remote resource if it is supported by yarn service") { - testRemoteResources(isHttpSchemeBlacklisted = false, supportMockHttpFs = true) + testRemoteResources(enableHttpFs = true, blacklistHttpFs = false) } test("force download from blacklisted schemes") { - testRemoteResources(isHttpSchemeBlacklisted = true, supportMockHttpFs = true) + testRemoteResources(enableHttpFs = true, blacklistHttpFs = true) } - private def testRemoteResources(isHttpSchemeBlacklisted: Boolean, - supportMockHttpFs: Boolean): Unit = { + private def testRemoteResources( + enableHttpFs: Boolean, + blacklistHttpFs: Boolean): Unit = { val hadoopConf = new Configuration() updateConfWithFakeS3Fs(hadoopConf) - if (supportMockHttpFs) { + if (enableHttpFs) { hadoopConf.set("fs.http.impl", classOf[TestFileSystem].getCanonicalName) - hadoopConf.set("fs.http.impl.disable.cache", "true") + } else { + hadoopConf.set("fs.http.impl", getClass().getName() + ".DoesNotExist") } + hadoopConf.set("fs.http.impl.disable.cache", "true") val tmpDir = Utils.createTempDir() val mainResource = File.createTempFile("tmpPy", ".py", tmpDir) @@ -986,20 +989,19 @@ class SparkSubmitSuite val tmpHttpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir) val tmpHttpJarPath = s"http://${new File(tmpHttpJar.toURI).getAbsolutePath}" + val forceDownloadArgs = if (blacklistHttpFs) { + Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http") + } else { + Nil + } + val args = Seq( "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), "--name", "testApp", "--master", "yarn", "--deploy-mode", "client", - "--jars", s"$tmpS3JarPath,$tmpHttpJarPath", - s"s3a://$mainResource" - ) ++ ( - if (isHttpSchemeBlacklisted) { - Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http,https") - } else { - Nil - } - ) + "--jars", s"$tmpS3JarPath,$tmpHttpJarPath" + ) ++ forceDownloadArgs ++ Seq(s"s3a://$mainResource") val appArgs = new SparkSubmitArguments(args) val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf)) @@ -1009,7 +1011,7 @@ class SparkSubmitSuite // The URI of remote S3 resource should still be remote. assert(jars.contains(tmpS3JarPath)) - if (supportMockHttpFs) { + if (enableHttpFs && !blacklistHttpFs) { // If Http FS is supported by yarn service, the URI of remote http resource should // still be remote. assert(jars.contains(tmpHttpJarPath)) From b30a7d28b399950953d4b112c57d4c9b9ab223e9 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 26 Mar 2018 12:45:45 -0700 Subject: [PATCH 0525/2461] [SPARK-23572][DOCS] Bring "security.md" up to date. This change basically rewrites the security documentation so that it's up to date with new features, more correct, and more complete. Because security is such an important feature, I chose to move all the relevant configuration documentation to the security page, instead of having them peppered all over the place in the configuration page. This allows an almost one-stop shop for security configuration in Spark. The only exceptions are some YARN-specific minor features which I left in the YARN page. I also re-organized the page's topics, since they didn't make a lot of sense. You had kerberos features described inside paragraphs talking about UI access control, and other oddities. It should be easier now to find information about specific Spark security features. I also enabled TOCs for both the Security and YARN pages, since that makes it easier to see what is covered. I removed most of the comments from the SecurityManager javadoc since they just replicated information in the security doc, with different levels of out-of-dateness. Author: Marcelo Vanzin Closes #20742 from vanzin/SPARK-23572. --- .gitignore | 1 + .../org/apache/spark/SecurityManager.scala | 144 +--- docs/configuration.md | 359 +--------- docs/monitoring.md | 40 +- docs/running-on-yarn.md | 203 +++--- docs/security.md | 629 +++++++++++++++--- 6 files changed, 673 insertions(+), 703 deletions(-) diff --git a/.gitignore b/.gitignore index 39085904e324c..e4c44d0590d59 100644 --- a/.gitignore +++ b/.gitignore @@ -76,6 +76,7 @@ streaming-tests.log target/ unit-tests.log work/ +docs/.jekyll-metadata # For Hive TempStatsStore/ diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index da1c89cd78901..09ec8932353a0 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -42,148 +42,10 @@ import org.apache.spark.util.Utils * should access it from that. There are some cases where the SparkEnv hasn't been * initialized yet and this class must be instantiated directly. * - * Spark currently supports authentication via a shared secret. - * Authentication can be configured to be on via the 'spark.authenticate' configuration - * parameter. This parameter controls whether the Spark communication protocols do - * authentication using the shared secret. This authentication is a basic handshake to - * make sure both sides have the same shared secret and are allowed to communicate. - * If the shared secret is not identical they will not be allowed to communicate. - * - * The Spark UI can also be secured by using javax servlet filters. A user may want to - * secure the UI if it has data that other users should not be allowed to see. The javax - * servlet filter specified by the user can authenticate the user and then once the user - * is logged in, Spark can compare that user versus the view acls to make sure they are - * authorized to view the UI. The configs 'spark.acls.enable', 'spark.ui.view.acls' and - * 'spark.ui.view.acls.groups' control the behavior of the acls. Note that the person who - * started the application always has view access to the UI. - * - * Spark has a set of individual and group modify acls (`spark.modify.acls`) and - * (`spark.modify.acls.groups`) that controls which users and groups have permission to - * modify a single application. This would include things like killing the application. - * By default the person who started the application has modify access. For modify access - * through the UI, you must have a filter that does authentication in place for the modify - * acls to work properly. - * - * Spark also has a set of individual and group admin acls (`spark.admin.acls`) and - * (`spark.admin.acls.groups`) which is a set of users/administrators and admin groups - * who always have permission to view or modify the Spark application. - * - * Starting from version 1.3, Spark has partial support for encrypted connections with SSL. - * - * At this point spark has multiple communication protocols that need to be secured and - * different underlying mechanisms are used depending on the protocol: - * - * - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty - * for the HttpServer. Jetty supports multiple authentication mechanisms - - * Basic, Digest, Form, Spnego, etc. It also supports multiple different login - * services - Hash, JAAS, Spnego, JDBC, etc. Spark currently uses the HashLoginService - * to authenticate using DIGEST-MD5 via a single user and the shared secret. - * Since we are using DIGEST-MD5, the shared secret is not passed on the wire - * in plaintext. - * - * We currently support SSL (https) for this communication protocol (see the details - * below). - * - * The Spark HttpServer installs the HashLoginServer and configures it to DIGEST-MD5. - * Any clients must specify the user and password. There is a default - * Authenticator installed in the SecurityManager to how it does the authentication - * and in this case gets the user name and password from the request. - * - * - BlockTransferService -> The Spark BlockTransferServices uses java nio to asynchronously - * exchange messages. For this we use the Java SASL - * (Simple Authentication and Security Layer) API and again use DIGEST-MD5 - * as the authentication mechanism. This means the shared secret is not passed - * over the wire in plaintext. - * Note that SASL is pluggable as to what mechanism it uses. We currently use - * DIGEST-MD5 but this could be changed to use Kerberos or other in the future. - * Spark currently supports "auth" for the quality of protection, which means - * the connection does not support integrity or privacy protection (encryption) - * after authentication. SASL also supports "auth-int" and "auth-conf" which - * SPARK could support in the future to allow the user to specify the quality - * of protection they want. If we support those, the messages will also have to - * be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's. - * - * Since the NioBlockTransferService does asynchronous messages passing, the SASL - * authentication is a bit more complex. A ConnectionManager can be both a client - * and a Server, so for a particular connection it has to determine what to do. - * A ConnectionId was added to be able to track connections and is used to - * match up incoming messages with connections waiting for authentication. - * The ConnectionManager tracks all the sendingConnections using the ConnectionId, - * waits for the response from the server, and does the handshake before sending - * the real message. - * - * The NettyBlockTransferService ensures that SASL authentication is performed - * synchronously prior to any other communication on a connection. This is done in - * SaslClientBootstrap on the client side and SaslRpcHandler on the server side. - * - * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters - * can be used. Yarn requires a specific AmIpFilter be installed for security to work - * properly. For non-Yarn deployments, users can write a filter to go through their - * organization's normal login service. If an authentication filter is in place then the - * SparkUI can be configured to check the logged in user against the list of users who - * have view acls to see if that user is authorized. - * The filters can also be used for many different purposes. For instance filters - * could be used for logging, encryption, or compression. - * - * The exact mechanisms used to generate/distribute the shared secret are deployment-specific. - * - * For YARN deployments, the secret is automatically generated. The secret is placed in the Hadoop - * UGI which gets passed around via the Hadoop RPC mechanism. Hadoop RPC can be configured to - * support different levels of protection. See the Hadoop documentation for more details. Each - * Spark application on YARN gets a different shared secret. - * - * On YARN, the Spark UI gets configured to use the Hadoop YARN AmIpFilter which requires the user - * to go through the ResourceManager Proxy. That proxy is there to reduce the possibility of web - * based attacks through YARN. Hadoop can be configured to use filters to do authentication. That - * authentication then happens via the ResourceManager Proxy and Spark will use that to do - * authorization against the view acls. - * - * For other Spark deployments, the shared secret must be specified via the - * spark.authenticate.secret config. - * All the nodes (Master and Workers) and the applications need to have the same shared secret. - * This again is not ideal as one user could potentially affect another users application. - * This should be enhanced in the future to provide better protection. - * If the UI needs to be secure, the user needs to install a javax servlet filter to do the - * authentication. Spark will then use that user to compare against the view acls to do - * authorization. If not filter is in place the user is generally null and no authorization - * can take place. - * - * When authentication is being used, encryption can also be enabled by setting the option - * spark.authenticate.enableSaslEncryption to true. This is only supported by communication - * channels that use the network-common library, and can be used as an alternative to SSL in those - * cases. - * - * SSL can be used for encryption for certain communication channels. The user can configure the - * default SSL settings which will be used for all the supported communication protocols unless - * they are overwritten by protocol specific settings. This way the user can easily provide the - * common settings for all the protocols without disabling the ability to configure each one - * individually. - * - * All the SSL settings like `spark.ssl.xxx` where `xxx` is a particular configuration property, - * denote the global configuration for all the supported protocols. In order to override the global - * configuration for the particular protocol, the properties must be overwritten in the - * protocol-specific namespace. Use `spark.ssl.yyy.xxx` settings to overwrite the global - * configuration for particular protocol denoted by `yyy`. Currently `yyy` can be only`fs` for - * broadcast and file server. - * - * Refer to [[org.apache.spark.SSLOptions]] documentation for the list of - * options that can be specified. - * - * SecurityManager initializes SSLOptions objects for different protocols separately. SSLOptions - * object parses Spark configuration at a given namespace and builds the common representation - * of SSL settings. SSLOptions is then used to provide protocol-specific SSLContextFactory for - * Jetty. - * - * SSL must be configured on each node and configured for each component involved in - * communication using the particular protocol. In YARN clusters, the key-store can be prepared on - * the client side then distributed and used by the executors as the part of the application - * (YARN allows the user to deploy files before the application is started). - * In standalone deployment, the user needs to provide key-stores and configuration - * options for master and workers. In this mode, the user may allow the executors to use the SSL - * settings inherited from the worker which spawned that executor. It can be accomplished by - * setting `spark.ssl.useNodeLocalConf` to `true`. + * This class implements all of the configuration related to security features described + * in the "Security" document. Please refer to that document for specific features implemented + * here. */ - private[spark] class SecurityManager( sparkConf: SparkConf, val ioEncryptionKey: Option[Array[Byte]] = None) diff --git a/docs/configuration.md b/docs/configuration.md index e7f2419cc2fa4..2eb6a77434ea6 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -712,30 +712,6 @@ Apart from these, the following properties are also available, and may be useful When we fail to register to the external shuffle service, we will retry for maxAttempts times. - - spark.io.encryption.enabled - false - - Enable IO encryption. Currently supported by all modes except Mesos. It's recommended that RPC encryption - be enabled when using this feature. - - - - spark.io.encryption.keySizeBits - 128 - - IO encryption key size in bits. Supported values are 128, 192 and 256. - - - - spark.io.encryption.keygen.algorithm - HmacSHA1 - - The algorithm to use when generating the IO encryption key. The supported algorithms are - described in the KeyGenerator section of the Java Cryptography Architecture Standard Algorithm - Name Documentation. - - ### Spark UI @@ -893,6 +869,23 @@ Apart from these, the following properties are also available, and may be useful How many dead executors the Spark UI and status APIs remember before garbage collecting. + + spark.ui.filters + None + + Comma separated list of filter class names to apply to the Spark Web UI. The filter should be a + standard + javax servlet Filter. + +
    Filter parameters can also be specified in the configuration, by setting config entries + of the form spark.<class name of filter>.param.<param name>=<value> + +
    For example: +
    spark.ui.filters=com.test.filter1 +
    spark.com.test.filter1.param.name1=foo +
    spark.com.test.filter1.param.name2=bar + + ### Compression and Serialization @@ -1446,6 +1439,15 @@ Apart from these, the following properties are also available, and may be useful Duration for an RPC remote endpoint lookup operation to wait before timing out. + + spark.core.connection.ack.wait.timeout + spark.network.timeout + + How long for the connection to wait for ack to occur before timing + out and giving up. To avoid unwilling timeout caused by long pause like GC, + you can set larger value. + + ### Scheduling @@ -1817,313 +1819,8 @@ Apart from these, the following properties are also available, and may be useful ### Security - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    Property NameDefaultMeaning
    spark.acls.enablefalse - Whether Spark acls should be enabled. If enabled, this checks to see if the user has - access permissions to view or modify the job. Note this requires the user to be known, - so if the user comes across as null no checks are done. Filters can be used with the UI - to authenticate and set the user. -
    spark.admin.aclsEmpty - Comma separated list of users/administrators that have view and modify access to all Spark jobs. - This can be used if you run on a shared cluster and have a set of administrators or devs who - help debug when things do not work. Putting a "*" in the list means any user can have the - privilege of admin. -
    spark.admin.acls.groupsEmpty - Comma separated list of groups that have view and modify access to all Spark jobs. - This can be used if you have a set of administrators or developers who help maintain and debug - the underlying infrastructure. Putting a "*" in the list means any user in any group can have - the privilege of admin. The user groups are obtained from the instance of the groups mapping - provider specified by spark.user.groups.mapping. Check the entry - spark.user.groups.mapping for more details. -
    spark.user.groups.mappingorg.apache.spark.security.ShellBasedGroupsMappingProvider - The list of groups for a user is determined by a group mapping service defined by the trait - org.apache.spark.security.GroupMappingServiceProvider which can be configured by this property. - A default unix shell based implementation is provided org.apache.spark.security.ShellBasedGroupsMappingProvider - which can be specified to resolve a list of groups for a user. - Note: This implementation supports only a Unix/Linux based environment. Windows environment is - currently not supported. However, a new platform/protocol can be supported by implementing - the trait org.apache.spark.security.GroupMappingServiceProvider. -
    spark.authenticatefalse - Whether Spark authenticates its internal connections. See - spark.authenticate.secret if not running on YARN. -
    spark.authenticate.secretNone - Set the secret key used for Spark to authenticate between components. This needs to be set if - not running on YARN and authentication is enabled. -
    spark.network.crypto.enabledfalse - Enable encryption using the commons-crypto library for RPC and block transfer service. - Requires spark.authenticate to be enabled. -
    spark.network.crypto.keyLength128 - The length in bits of the encryption key to generate. Valid values are 128, 192 and 256. -
    spark.network.crypto.keyFactoryAlgorithmPBKDF2WithHmacSHA1 - The key factory algorithm to use when generating encryption keys. Should be one of the - algorithms supported by the javax.crypto.SecretKeyFactory class in the JRE being used. -
    spark.network.crypto.saslFallbacktrue - Whether to fall back to SASL authentication if authentication fails using Spark's internal - mechanism. This is useful when the application is connecting to old shuffle services that - do not support the internal Spark authentication protocol. On the server side, this can be - used to block older clients from authenticating against a new shuffle service. -
    spark.network.crypto.config.*None - Configuration values for the commons-crypto library, such as which cipher implementations to - use. The config name should be the name of commons-crypto configuration without the - "commons.crypto" prefix. -
    spark.authenticate.enableSaslEncryptionfalse - Enable encrypted communication when authentication is - enabled. This is supported by the block transfer service and the - RPC endpoints. -
    spark.network.sasl.serverAlwaysEncryptfalse - Disable unencrypted connections for services that support SASL authentication. -
    spark.core.connection.ack.wait.timeoutspark.network.timeout - How long for the connection to wait for ack to occur before timing - out and giving up. To avoid unwilling timeout caused by long pause like GC, - you can set larger value. -
    spark.modify.aclsEmpty - Comma separated list of users that have modify access to the Spark job. By default only the - user that started the Spark job has access to modify it (kill it for example). Putting a "*" in - the list means any user can have access to modify it. -
    spark.modify.acls.groupsEmpty - Comma separated list of groups that have modify access to the Spark job. This can be used if you - have a set of administrators or developers from the same team to have access to control the job. - Putting a "*" in the list means any user in any group has the access to modify the Spark job. - The user groups are obtained from the instance of the groups mapping provider specified by - spark.user.groups.mapping. Check the entry spark.user.groups.mapping - for more details. -
    spark.ui.filtersNone - Comma separated list of filter class names to apply to the Spark web UI. The filter should be a - standard - javax servlet Filter. Parameters to each filter can also be specified by setting a - java system property of:
    - spark.<class name of filter>.params='param1=value1,param2=value2'
    - For example:
    - -Dspark.ui.filters=com.test.filter1
    - -Dspark.com.test.filter1.params='param1=foo,param2=testing' -
    spark.ui.view.aclsEmpty - Comma separated list of users that have view access to the Spark web ui. By default only the - user that started the Spark job has view access. Putting a "*" in the list means any user can - have view access to this Spark job. -
    spark.ui.view.acls.groupsEmpty - Comma separated list of groups that have view access to the Spark web ui to view the Spark Job - details. This can be used if you have a set of administrators or developers or users who can - monitor the Spark job submitted. Putting a "*" in the list means any user in any group can view - the Spark job details on the Spark web ui. The user groups are obtained from the instance of the - groups mapping provider specified by spark.user.groups.mapping. Check the entry - spark.user.groups.mapping for more details. -
    - -### TLS / SSL - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    Property NameDefaultMeaning
    spark.ssl.enabledfalse - Whether to enable SSL connections on all supported protocols. - -
    When spark.ssl.enabled is configured, spark.ssl.protocol - is required. - -
    All the SSL settings like spark.ssl.xxx where xxx is a - particular configuration property, denote the global configuration for all the supported - protocols. In order to override the global configuration for the particular protocol, - the properties must be overwritten in the protocol-specific namespace. - -
    Use spark.ssl.YYY.XXX settings to overwrite the global configuration for - particular protocol denoted by YYY. Example values for YYY - include fs, ui, standalone, and - historyServer. See SSL - Configuration for details on hierarchical SSL configuration for services. -
    spark.ssl.[namespace].portNone - The port where the SSL service will listen on. - -
    The port must be defined within a namespace configuration; see - SSL Configuration for the available - namespaces. - -
    When not set, the SSL port will be derived from the non-SSL port for the - same service. A value of "0" will make the service bind to an ephemeral port. -
    spark.ssl.enabledAlgorithmsEmpty - A comma separated list of ciphers. The specified ciphers must be supported by JVM. - The reference list of protocols one can find on - this - page. - Note: If not set, it will use the default cipher suites of JVM. -
    spark.ssl.keyPasswordNone - A password to the private key in key-store. -
    spark.ssl.keyStoreNone - A path to a key-store file. The path can be absolute or relative to the directory where - the component is started in. -
    spark.ssl.keyStorePasswordNone - A password to the key-store. -
    spark.ssl.keyStoreTypeJKS - The type of the key-store. -
    spark.ssl.protocolNone - A protocol name. The protocol must be supported by JVM. The reference list of protocols - one can find on this - page. -
    spark.ssl.needClientAuthfalse - Set true if SSL needs client authentication. -
    spark.ssl.trustStoreNone - A path to a trust-store file. The path can be absolute or relative to the directory - where the component is started in. -
    spark.ssl.trustStorePasswordNone - A password to the trust-store. -
    spark.ssl.trustStoreTypeJKS - The type of the trust-store. -
    - +Please refer to the [Security](security.html) page for available options on how to secure different +Spark subsystems. ### Spark SQL diff --git a/docs/monitoring.md b/docs/monitoring.md index d5f7ffcc260a1..01736c77b0979 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -80,7 +80,10 @@ The history server can be configured as follows: -### Spark configuration options +### Spark History Server Configuration Options + +Security options for the Spark History Server are covered more detail in the +[Security](security.html#web-ui) page. @@ -160,41 +163,6 @@ The history server can be configured as follows: Location of the kerberos keytab file for the History Server. - - - - - - - - - - - - - - - diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index c010af35f8d2e..e07759a4dba87 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -2,6 +2,8 @@ layout: global title: Running Spark on YARN --- +* This will become a table of contents (this text will be scraped). +{:toc} Support for running on [YARN (Hadoop NextGen)](http://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/YARN.html) @@ -217,8 +219,8 @@ To use a custom metrics.properties for the application master and executors, upd @@ -265,19 +267,6 @@ To use a custom metrics.properties for the application master and executors, upd distribution. - - - - - @@ -373,31 +362,6 @@ To use a custom metrics.properties for the application master and executors, upd in YARN ApplicationReports, which can be used for filtering when querying YARN apps. - - - - - - - - - - - - - - - @@ -424,17 +388,6 @@ To use a custom metrics.properties for the application master and executors, upd See spark.yarn.config.gatewayPath. - - - - - @@ -468,48 +421,104 @@ To use a custom metrics.properties for the application master and executors, upd - The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. - The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. -# Running in a Secure Cluster +# Kerberos + +Standard Kerberos support in Spark is covered in the [Security](security.html#kerberos) page. + +In YARN mode, when accessing Hadoop file systems, aside from the service hosting the user's home +directory, Spark will also automatically obtain delegation tokens for the service hosting the +staging directory of the Spark application. + +If an application needs to interact with other secure Hadoop filesystems, their URIs need to be +explicitly provided to Spark at launch time. This is done by listing them in the +`spark.yarn.access.hadoopFileSystems` property, described in the configuration section below. -As covered in [security](security.html), Kerberos is used in a secure Hadoop cluster to -authenticate principals associated with services and clients. This allows clients to -make requests of these authenticated services; the services to grant rights -to the authenticated principals. +The YARN integration also supports custom delegation token providers using the Java Services +mechanism (see `java.util.ServiceLoader`). Implementations of +`org.apache.spark.deploy.yarn.security.ServiceCredentialProvider` can be made available to Spark +by listing their names in the corresponding file in the jar's `META-INF/services` directory. These +providers can be disabled individually by setting `spark.security.credentials.{service}.enabled` to +`false`, where `{service}` is the name of the credential provider. + +## YARN-specific Kerberos Configuration + +
    Property NameDefaultMeaning
    spark.history.ui.acls.enablefalse - Specifies whether acls should be checked to authorize users viewing the applications. - If enabled, access control checks are made regardless of what the individual application had - set for spark.ui.acls.enable when the application was run. The application owner - will always have authorization to view their own application and any users specified via - spark.ui.view.acls and groups specified via spark.ui.view.acls.groups - when the application was run will also have authorization to view that application. - If disabled, no access control checks are made. -
    spark.history.ui.admin.aclsempty - Comma separated list of users/administrators that have view access to all the Spark applications in - history server. By default only the users permitted to view the application at run-time could - access the related application history, with this, configured users/administrators could also - have the permission to access it. - Putting a "*" in the list means any user can have the privilege of admin. -
    spark.history.ui.admin.acls.groupsempty - Comma separated list of groups that have view access to all the Spark applications in - history server. By default only the groups permitted to view the application at run-time could - access the related application history, with this, configured groups could also - have the permission to access it. - Putting a "*" in the list means any group can have the privilege of admin. -
    spark.history.fs.cleaner.enabled falsespark.yarn.dist.forceDownloadSchemes (none) - Comma-separated list of schemes for which files will be downloaded to the local disk prior to - being added to YARN's distributed cache. For use in cases where the YARN service does not + Comma-separated list of schemes for which files will be downloaded to the local disk prior to + being added to YARN's distributed cache. For use in cases where the YARN service does not support schemes that are supported by Spark, like http, https and ftp.
    spark.yarn.access.hadoopFileSystems(none) - A comma-separated list of secure Hadoop filesystems your Spark application is going to access. For - example, spark.yarn.access.hadoopFileSystems=hdfs://nn1.com:8032,hdfs://nn2.com:8032, - webhdfs://nn3.com:50070. The Spark application must have access to the filesystems listed - and Kerberos must be properly configured to be able to access them (either in the same realm - or in a trusted realm). Spark acquires security tokens for each of the filesystems so that - the Spark application can access those remote Hadoop filesystems. spark.yarn.access.namenodes - is deprecated, please use this instead. -
    spark.yarn.appMasterEnv.[EnvironmentVariableName] (none)
    spark.yarn.keytab(none) - The full path to the file that contains the keytab for the principal specified above. - This keytab will be copied to the node running the YARN Application Master via the Secure Distributed Cache, - for renewing the login tickets and the delegation tokens periodically. (Works also with the "local" master) -
    spark.yarn.principal(none) - Principal to be used to login to KDC, while running on secure HDFS. (Works also with the "local" master) -
    spark.yarn.kerberos.relogin.period1m - How often to check whether the kerberos TGT should be renewed. This should be set to a value - that is shorter than the TGT renewal period (or the TGT lifetime if TGT renewal is not enabled). - The default value should be enough for most deployments. -
    spark.yarn.config.gatewayPath (none)
    spark.security.credentials.${service}.enabledtrue - Controls whether to obtain credentials for services when security is enabled. - By default, credentials for all supported services are retrieved when those services are - configured, but it's possible to disable that behavior if it somehow conflicts with the - application being run. For further details please see - [Running in a Secure Cluster](running-on-yarn.html#running-in-a-secure-cluster) -
    spark.yarn.rolledLog.includePattern (none)
    + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.yarn.keytab(none) + The full path to the file that contains the keytab for the principal specified above. This keytab + will be copied to the node running the YARN Application Master via the YARN Distributed Cache, and + will be used for renewing the login tickets and the delegation tokens periodically. Equivalent to + the --keytab command line argument. + +
    (Works also with the "local" master.) +
    spark.yarn.principal(none) + Principal to be used to login to KDC, while running on secure clusters. Equivalent to the + --principal command line argument. + +
    (Works also with the "local" master.) +
    spark.yarn.access.hadoopFileSystems(none) + A comma-separated list of secure Hadoop filesystems your Spark application is going to access. For + example, spark.yarn.access.hadoopFileSystems=hdfs://nn1.com:8032,hdfs://nn2.com:8032, + webhdfs://nn3.com:50070. The Spark application must have access to the filesystems listed + and Kerberos must be properly configured to be able to access them (either in the same realm + or in a trusted realm). Spark acquires security tokens for each of the filesystems so that + the Spark application can access those remote Hadoop filesystems. +
    spark.yarn.kerberos.relogin.period1m + How often to check whether the kerberos TGT should be renewed. This should be set to a value + that is shorter than the TGT renewal period (or the TGT lifetime if TGT renewal is not enabled). + The default value should be enough for most deployments. +
    -Hadoop services issue *hadoop tokens* to grant access to the services and data. -Clients must first acquire tokens for the services they will access and pass them along with their -application as it is launched in the YARN cluster. +## Troubleshooting Kerberos -For a Spark application to interact with any of the Hadoop filesystem (for example hdfs, webhdfs, etc), HBase and Hive, it must acquire the relevant tokens -using the Kerberos credentials of the user launching the application -—that is, the principal whose identity will become that of the launched Spark application. +Debugging Hadoop/Kerberos problems can be "difficult". One useful technique is to +enable extra logging of Kerberos operations in Hadoop by setting the `HADOOP_JAAS_DEBUG` +environment variable. -This is normally done at launch time: in a secure cluster Spark will automatically obtain a -token for the cluster's default Hadoop filesystem, and potentially for HBase and Hive. +```bash +export HADOOP_JAAS_DEBUG=true +``` -An HBase token will be obtained if HBase is in on classpath, the HBase configuration declares -the application is secure (i.e. `hbase-site.xml` sets `hbase.security.authentication` to `kerberos`), -and `spark.security.credentials.hbase.enabled` is not set to `false`. +The JDK classes can be configured to enable extra logging of their Kerberos and +SPNEGO/REST authentication via the system properties `sun.security.krb5.debug` +and `sun.security.spnego.debug=true` -Similarly, a Hive token will be obtained if Hive is on the classpath, its configuration -includes a URI of the metadata store in `"hive.metastore.uris`, and -`spark.security.credentials.hive.enabled` is not set to `false`. +``` +-Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true +``` -If an application needs to interact with other secure Hadoop filesystems, then -the tokens needed to access these clusters must be explicitly requested at -launch time. This is done by listing them in the `spark.yarn.access.hadoopFileSystems` property. +All these options can be enabled in the Application Master: ``` -spark.yarn.access.hadoopFileSystems hdfs://ireland.example.org:8020/,webhdfs://frankfurt.example.org:50070/ +spark.yarn.appMasterEnv.HADOOP_JAAS_DEBUG true +spark.yarn.am.extraJavaOptions -Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true ``` -Spark supports integrating with other security-aware services through Java Services mechanism (see -`java.util.ServiceLoader`). To do that, implementations of `org.apache.spark.deploy.yarn.security.ServiceCredentialProvider` -should be available to Spark by listing their names in the corresponding file in the jar's -`META-INF/services` directory. These plug-ins can be disabled by setting -`spark.security.credentials.{service}.enabled` to `false`, where `{service}` is the name of -credential provider. +Finally, if the log level for `org.apache.spark.deploy.yarn.Client` is set to `DEBUG`, the log +will include a list of all tokens obtained, and their expiry details -## Configuring the External Shuffle Service + +# Configuring the External Shuffle Service To start the Spark Shuffle Service on each `NodeManager` in your YARN cluster, follow these instructions: @@ -542,7 +551,7 @@ The following extra configuration options are available when the shuffle service -## Launching your application with Apache Oozie +# Launching your application with Apache Oozie Apache Oozie can launch Spark applications as part of a workflow. In a secure cluster, the launched application will need the relevant tokens to access the cluster's @@ -576,35 +585,7 @@ spark.security.credentials.hbase.enabled false The configuration option `spark.yarn.access.hadoopFileSystems` must be unset. -## Troubleshooting Kerberos - -Debugging Hadoop/Kerberos problems can be "difficult". One useful technique is to -enable extra logging of Kerberos operations in Hadoop by setting the `HADOOP_JAAS_DEBUG` -environment variable. - -```bash -export HADOOP_JAAS_DEBUG=true -``` - -The JDK classes can be configured to enable extra logging of their Kerberos and -SPNEGO/REST authentication via the system properties `sun.security.krb5.debug` -and `sun.security.spnego.debug=true` - -``` --Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true -``` - -All these options can be enabled in the Application Master: - -``` -spark.yarn.appMasterEnv.HADOOP_JAAS_DEBUG true -spark.yarn.am.extraJavaOptions -Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true -``` - -Finally, if the log level for `org.apache.spark.deploy.yarn.Client` is set to `DEBUG`, the log -will include a list of all tokens obtained, and their expiry details - -## Using the Spark History Server to replace the Spark Web UI +# Using the Spark History Server to replace the Spark Web UI It is possible to use the Spark History Server application page as the tracking URL for running applications when the application UI is disabled. This may be desirable on secure clusters, or to diff --git a/docs/security.md b/docs/security.md index 913d9df50eb1c..3e5607a9a0d67 100644 --- a/docs/security.md +++ b/docs/security.md @@ -3,47 +3,336 @@ layout: global displayTitle: Spark Security title: Security --- +* This will become a table of contents (this text will be scraped). +{:toc} -Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the `spark.authenticate` configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate. The shared secret is created as follows: +# Spark RPC -* For Spark on [YARN](running-on-yarn.html) and local deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. -* For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. +## Authentication -## Web UI +Spark currently supports authentication for RPC channels using a shared secret. Authentication can +be turned on by setting the `spark.authenticate` configuration parameter. -The Spark UI can be secured by using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html) via the `spark.ui.filters` setting -and by using [https/SSL](http://en.wikipedia.org/wiki/HTTPS) via [SSL settings](security.html#ssl-configuration). +The exact mechanism used to generate and distribute the shared secret is deployment-specific. -### Authentication +For Spark on [YARN](running-on-yarn.html) and local deployments, Spark will automatically handle +generating and distributing the shared secret. Each application will use a unique shared secret. In +the case of YARN, this feature relies on YARN RPC encryption being enabled for the distribution of +secrets to be secure. -A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user versus the view ACLs to make sure they are authorized to view the UI. The configs `spark.acls.enable`, `spark.ui.view.acls` and `spark.ui.view.acls.groups` control the behavior of the ACLs. Note that the user who started the application always has view access to the UI. On YARN, the Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. +For other resource managers, `spark.authenticate.secret` must be configured on each of the nodes. +This secret will be shared by all the daemons and applications, so this deployment configuration is +not as secure as the above, especially when considering multi-tenant clusters. -Spark also supports modify ACLs to control who has access to modify a running Spark application. This includes things like killing the application or a task. This is controlled by the configs `spark.acls.enable`, `spark.modify.acls` and `spark.modify.acls.groups`. Note that if you are authenticating the web UI, in order to use the kill button on the web UI it might be necessary to add the users in the modify acls to the view acls also. On YARN, the modify acls are passed in and control who has modify access via YARN interfaces. -Spark allows for a set of administrators to be specified in the acls who always have view and modify permissions to all the applications. is controlled by the configs `spark.admin.acls` and `spark.admin.acls.groups`. This is useful on a shared cluster where you might have administrators or support staff who help users debug applications. + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.authenticatefalseWhether Spark authenticates its internal connections.
    spark.authenticate.secretNone + The secret key used authentication. See above for when this configuration should be set. +
    + +## Encryption -## Event Logging +Spark supports AES-based encryption for RPC connections. For encryption to be enabled, RPC +authentication must also be enabled and properly configured. AES encryption uses the +[Apache Commons Crypto](http://commons.apache.org/proper/commons-crypto/) library, and Spark's +configuration system allows access to that library's configuration for advanced users. -If your applications are using event logging, the directory where the event logs go (`spark.eventLog.dir`) should be manually created and have the proper permissions set on it. If you want those log files secured, the permissions should be set to `drwxrwxrwxt` for that directory. The owner of the directory should be the super user who is running the history server and the group permissions should be restricted to super user group. This will allow all users to write to the directory but will prevent unprivileged users from removing or renaming a file unless they own the file or directory. The event log files will be created by Spark with permissions such that only the user and group have read and write access. +There is also support for SASL-based encryption, although it should be considered deprecated. It +is still required when talking to shuffle services from Spark versions older than 2.2.0. -## Encryption +The following table describes the different options available for configuring this feature. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.network.crypto.enabledfalse + Enable AES-based RPC encryption, including the new authentication protocol added in 2.2.0. +
    spark.network.crypto.keyLength128 + The length in bits of the encryption key to generate. Valid values are 128, 192 and 256. +
    spark.network.crypto.keyFactoryAlgorithmPBKDF2WithHmacSHA1 + The key factory algorithm to use when generating encryption keys. Should be one of the + algorithms supported by the javax.crypto.SecretKeyFactory class in the JRE being used. +
    spark.network.crypto.config.*None + Configuration values for the commons-crypto library, such as which cipher implementations to + use. The config name should be the name of commons-crypto configuration without the + commons.crypto prefix. +
    spark.network.crypto.saslFallbacktrue + Whether to fall back to SASL authentication if authentication fails using Spark's internal + mechanism. This is useful when the application is connecting to old shuffle services that + do not support the internal Spark authentication protocol. On the shuffle service side, + disabling this feature will block older clients from authenticating. +
    spark.authenticate.enableSaslEncryptionfalse + Enable SASL-based encrypted communication. +
    spark.network.sasl.serverAlwaysEncryptfalse + Disable unencrypted connections for ports using SASL authentication. This will deny connections + from clients that have authentication enabled, but do not request SASL-based encryption. +
    + + +# Local Storage Encryption + +Spark supports encrypting temporary data written to local disks. This covers shuffle files, shuffle +spills and data blocks stored on disk (for both caching and broadcast variables). It does not cover +encrypting output data generated by applications with APIs such as `saveAsHadoopFile` or +`saveAsTable`. + +The following settings cover enabling encryption for data written to disk: + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.io.encryption.enabledfalse + Enable local disk I/O encryption. Currently supported by all modes except Mesos. It's strongly + recommended that RPC encryption be enabled when using this feature. +
    spark.io.encryption.keySizeBits128 + IO encryption key size in bits. Supported values are 128, 192 and 256. +
    spark.io.encryption.keygen.algorithmHmacSHA1 + The algorithm to use when generating the IO encryption key. The supported algorithms are + described in the KeyGenerator section of the Java Cryptography Architecture Standard Algorithm + Name Documentation. +
    spark.io.encryption.commons.config.*None + Configuration values for the commons-crypto library, such as which cipher implementations to + use. The config name should be the name of commons-crypto configuration without the + commons.crypto prefix. +
    + + +# Web UI + +## Authentication and Authorization + +Enabling authentication for the Web UIs is done using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html). +You will need a filter that implements the authentication method you want to deploy. Spark does not +provide any built-in authentication filters. + +Spark also supports access control to the UI when an authentication filter is present. Each +application can be configured with its own separate access control lists (ACLs). Spark +differentiates between "view" permissions (who is allowed to see the application's UI), and "modify" +permissions (who can do things like kill jobs in a running application). + +ACLs can be configured for either users or groups. Configuration entries accept comma-separated +lists as input, meaning multiple users or groups can be given the desired privileges. This can be +used if you run on a shared cluster and have a set of administrators or developers who need to +monitor applications they may not have started themselves. A wildcard (`*`) added to specific ACL +means that all users will have the respective pivilege. By default, only the user submitting the +application is added to the ACLs. + +Group membership is established by using a configurable group mapping provider. The mapper is +configured using the spark.user.groups.mapping config option, described in the table +below. + +The following options control the authentication of Web UIs: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.ui.filtersNone + See the Spark UI configuration for how to configure + filters. +
    spark.acls.enablefalse + Whether UI ACLs should be enabled. If enabled, this checks to see if the user has access + permissions to view or modify the application. Note this requires the user to be authenticated, + so if no authentication filter is installed, this option does not do anything. +
    spark.admin.aclsNone + Comma-separated list of users that have view and modify access to the Spark application. +
    spark.admin.acls.groupsNone + Comma-separated list of groups that have view and modify access to the Spark application. +
    spark.modify.aclsNone + Comma-separated list of users that have modify access to the Spark application. +
    spark.modify.acls.groupsNone + Comma-separated list of groups that have modify access to the Spark application. +
    spark.ui.view.aclsNone + Comma-separated list of users that have view access to the Spark application. +
    spark.ui.view.acls.groupsNone + Comma-separated list of groups that have view access to the Spark application. +
    spark.user.groups.mappingorg.apache.spark.security.ShellBasedGroupsMappingProvider + The list of groups for a user is determined by a group mapping service defined by the trait + org.apache.spark.security.GroupMappingServiceProvider, which can be configured by + this property. + +
    By default, a Unix shell-based implementation is used, which collects this information + from the host OS. + +
    Note: This implementation supports only Unix/Linux-based environments. + Windows environment is currently not supported. However, a new platform/protocol can + be supported by implementing the trait mentioned above. +
    + +On YARN, the view and modify ACLs are provided to the YARN service when submitting applications, and +control who has the respective privileges via YARN interfaces. + +## Spark History Server ACLs -Spark supports SSL for HTTP protocols. SASL encryption is supported for the block transfer service -and the RPC endpoints. Shuffle files can also be encrypted if desired. +Authentication for the SHS Web UI is enabled the same way as for regular applications, using +servlet filters. -### SSL Configuration +To enable authorization in the SHS, a few extra options are used: + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.history.ui.acls.enablefalse + Specifies whether ACLs should be checked to authorize users viewing the applications in + the history server. If enabled, access control checks are performed regardless of what the + individual applications had set for spark.ui.acls.enable. The application owner + will always have authorization to view their own application and any users specified via + spark.ui.view.acls and groups specified via spark.ui.view.acls.groups + when the application was run will also have authorization to view that application. + If disabled, no access control checks are made for any application UIs available through + the history server. +
    spark.history.ui.admin.aclsNone + Comma separated list of users that have view access to all the Spark applications in history + server. +
    spark.history.ui.admin.acls.groupsNone + Comma separated list of groups that have view access to all the Spark applications in history + server. +
    + +The SHS uses the same options to configure the group mapping provider as regular applications. +In this case, the group mapping provider will apply to all UIs server by the SHS, and individual +application configurations will be ignored. + +## SSL Configuration Configuration for SSL is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the -protocols without disabling the ability to configure each one individually. The common SSL settings -are at `spark.ssl` namespace in Spark configuration. The following table describes the -component-specific configuration namespaces used to override the default settings: +protocols without disabling the ability to configure each one individually. The following table +describes the the SSL configuration namespaces: + + + + @@ -58,49 +347,205 @@ component-specific configuration namespaces used to override the default setting
    Config Namespace Component
    spark.ssl + The default SSL configuration. These values will apply to all namespaces below, unless + explicitly overridden at the namespace level. +
    spark.ssl.ui Spark application Web UI
    -The full breakdown of available SSL options can be found on the [configuration page](configuration.html). -SSL must be configured on each node and configured for each component involved in communication using the particular protocol. +The full breakdown of available SSL options can be found below. The `${ns}` placeholder should be +replaced with one of the above namespaces. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    ${ns}.enabledfalseEnables SSL. When enabled, ${ns}.ssl.protocol is required.
    ${ns}.portNone + The port where the SSL service will listen on. + +
    The port must be defined within a specific namespace configuration. The default + namespace is ignored when reading this configuration. + +
    When not set, the SSL port will be derived from the non-SSL port for the + same service. A value of "0" will make the service bind to an ephemeral port. +
    ${ns}.enabledAlgorithmsNone + A comma separated list of ciphers. The specified ciphers must be supported by JVM. + +
    The reference list of protocols can be found in the "JSSE Cipher Suite Names" section + of the Java security guide. The list for Java 8 can be found at + this + page. + +
    Note: If not set, the default cipher suite for the JRE will be used. +
    ${ns}.keyPasswordNone + The password to the private key in the key store. +
    ${ns}.keyStoreNone + Path to the key store file. The path can be absolute or relative to the directory in which the + process is started. +
    ${ns}.keyStorePasswordNonePassword to the key store.
    ${ns}.keyStoreTypeJKSThe type of the key store.
    ${ns}.protocolNone + TLS protocol to use. The protocol must be supported by JVM. + +
    The reference list of protocols can be found in the "Additional JSSE Standard Names" + section of the Java security guide. For Java 8, the list can be found at + this + page. +
    ${ns}.needClientAuthfalseWhether to require client authentication.
    ${ns}.trustStoreNone + Path to the trust store file. The path can be absolute or relative to the directory in which + the process is started. +
    ${ns}.trustStorePasswordNonePassword for the trust store.
    ${ns}.trustStoreTypeJKSThe type of the trust store.
    + +## Preparing the key stores + +Key stores can be generated by `keytool` program. The reference documentation for this tool for +Java 8 is [here](https://docs.oracle.com/javase/8/docs/technotes/tools/unix/keytool.html). +The most basic steps to configure the key stores and the trust store for a Spark Standalone +deployment mode is as follows: + +* Generate a key pair for each node +* Export the public key of the key pair to a file on each node +* Import all exported public keys into a single trust store +* Distribute the trust store to the cluster nodes ### YARN mode -The key-store can be prepared on the client side and then distributed and used by the executors as the part of the application. It is possible because the user is able to deploy files before the application is started in YARN by using `spark.yarn.dist.files` or `spark.yarn.dist.archives` configuration settings. The responsibility for encryption of transferring these files is on YARN side and has nothing to do with Spark. -For long-running apps like Spark Streaming apps to be able to write to HDFS, it is possible to pass a principal and keytab to `spark-submit` via the `--principal` and `--keytab` parameters respectively. The keytab passed in will be copied over to the machine running the Application Master via the Hadoop Distributed Cache (securely - if YARN is configured with SSL and HDFS encryption is enabled). The Kerberos login will be periodically renewed using this principal and keytab and the delegation tokens required for HDFS will be generated periodically so the application can continue writing to HDFS. +To provide a local trust store or key store file to drivers running in cluster mode, they can be +distributed with the application using the `--files` command line argument (or the equivalent +`spark.files` configuration). The files will be placed on the driver's working directory, so the TLS +configuration should just reference the file name with no absolute path. + +Distributing local key stores this way may require the files to be staged in HDFS (or other similar +distributed file system used by the cluster), so it's recommended that the undelying file system be +configured with security in mind (e.g. by enabling authentication and wire encryption). ### Standalone mode -The user needs to provide key-stores and configuration options for master and workers. They have to be set by attaching appropriate Java system properties in `SPARK_MASTER_OPTS` and in `SPARK_WORKER_OPTS` environment variables, or just in `SPARK_DAEMON_JAVA_OPTS`. In this mode, the user may allow the executors to use the SSL settings inherited from the worker which spawned that executor. It can be accomplished by setting `spark.ssl.useNodeLocalConf` to `true`. If that parameter is set, the settings provided by user on the client side, are not used by the executors. + +The user needs to provide key stores and configuration options for master and workers. They have to +be set by attaching appropriate Java system properties in `SPARK_MASTER_OPTS` and in +`SPARK_WORKER_OPTS` environment variables, or just in `SPARK_DAEMON_JAVA_OPTS`. + +The user may allow the executors to use the SSL settings inherited from the worker process. That +can be accomplished by setting `spark.ssl.useNodeLocalConf` to `true`. In that case, the settings +provided by the user on the client side are not used. ### Mesos mode -Mesos 1.3.0 and newer supports `Secrets` primitives as both file-based and environment based secrets. Spark allows the specification of file-based and environment variable based secrets with the `spark.mesos.driver.secret.filenames` and `spark.mesos.driver.secret.envkeys`, respectively. Depending on the secret store backend secrets can be passed by reference or by value with the `spark.mesos.driver.secret.names` and `spark.mesos.driver.secret.values` configuration properties, respectively. Reference type secrets are served by the secret store and referred to by name, for example `/mysecret`. Value type secrets are passed on the command line and translated into their appropriate files or environment variables. +Mesos 1.3.0 and newer supports `Secrets` primitives as both file-based and environment based +secrets. Spark allows the specification of file-based and environment variable based secrets with +`spark.mesos.driver.secret.filenames` and `spark.mesos.driver.secret.envkeys`, respectively. -### Preparing the key-stores -Key-stores can be generated by `keytool` program. The reference documentation for this tool is -[here](https://docs.oracle.com/javase/7/docs/technotes/tools/solaris/keytool.html). The most basic -steps to configure the key-stores and the trust-store for the standalone deployment mode is as -follows: +Depending on the secret store backend secrets can be passed by reference or by value with the +`spark.mesos.driver.secret.names` and `spark.mesos.driver.secret.values` configuration properties, +respectively. -* Generate a keys pair for each node -* Export the public key of the key pair to a file on each node -* Import all exported public keys into a single trust-store -* Distribute the trust-store over the nodes +Reference type secrets are served by the secret store and referred to by name, for example +`/mysecret`. Value type secrets are passed on the command line and translated into their +appropriate files or environment variables. -### Configuring SASL Encryption +## HTTP Security Headers -SASL encryption is currently supported for the block transfer service when authentication -(`spark.authenticate`) is enabled. To enable SASL encryption for an application, set -`spark.authenticate.enableSaslEncryption` to `true` in the application's configuration. +Apache Spark can be configured to include HTTP headers to aid in preventing Cross Site Scripting +(XSS), Cross-Frame Scripting (XFS), MIME-Sniffing, and also to enforce HTTP Strict Transport +Security. -When using an external shuffle service, it's possible to disable unencrypted connections by setting -`spark.network.sasl.serverAlwaysEncrypt` to `true` in the shuffle service's configuration. If that -option is enabled, applications that are not set up to use SASL encryption will fail to connect to -the shuffle service. + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.ui.xXssProtection1; mode=block + Value for HTTP X-XSS-Protection response header. You can choose appropriate value + from below: +
      +
    • 0 (Disables XSS filtering)
    • +
    • 1 (Enables XSS filtering. If a cross-site scripting attack is detected, + the browser will sanitize the page.)
    • +
    • 1; mode=block (Enables XSS filtering. The browser will prevent rendering + of the page if an attack is detected.)
    • +
    +
    spark.ui.xContentTypeOptions.enabledtrue + When enabled, X-Content-Type-Options HTTP response header will be set to "nosniff". +
    spark.ui.strictTransportSecurityNone + Value for HTTP Strict Transport Security (HSTS) Response Header. You can choose appropriate + value from below and set expire-time accordingly. This option is only used when + SSL/TLS is enabled. +
      +
    • max-age=<expire-time>
    • +
    • max-age=<expire-time>; includeSubDomains
    • +
    • max-age=<expire-time>; preload
    • +
    +
    -## Configuring Ports for Network Security + +# Configuring Ports for Network Security Spark makes heavy use of the network, and some environments have strict requirements for using tight firewall settings. Below are the primary ports that Spark uses for its communication and how to configure those ports. -### Standalone mode only +## Standalone mode only @@ -141,7 +586,7 @@ configure those ports.
    -### All cluster managers +## All cluster managers @@ -182,54 +627,70 @@ configure those ports.
    -### HTTP Security Headers -Apache Spark can be configured to include HTTP Headers which aids in preventing Cross -Site Scripting (XSS), Cross-Frame Scripting (XFS), MIME-Sniffing and also enforces HTTP -Strict Transport Security. +# Kerberos + +Spark supports submitting applications in environments that use Kerberos for authentication. +In most cases, Spark relies on the credentials of the current logged in user when authenticating +to Kerberos-aware services. Such credentials can be obtained by logging in to the configured KDC +with tools like `kinit`. + +When talking to Hadoop-based services, Spark needs to obtain delegation tokens so that non-local +processes can authenticate. Spark ships with support for HDFS and other Hadoop file systems, Hive +and HBase. + +When using a Hadoop filesystem (such HDFS or WebHDFS), Spark will acquire the relevant tokens +for the service hosting the user's home directory. + +An HBase token will be obtained if HBase is in the application's classpath, and the HBase +configuration has Kerberos authentication turned (`hbase.security.authentication=kerberos`). + +Similarly, a Hive token will be obtained if Hive is in the classpath, and the configuration includes +URIs for remote metastore services (`hive.metastore.uris` is not empty). + +Delegation token support is currently only supported in YARN and Mesos modes. Consult the +deployment-specific page for more information. + +The following options provides finer-grained control for this feature: - - - - - - + - - - - -
    Property NameDefaultMeaning
    spark.ui.xXssProtection1; mode=block - Value for HTTP X-XSS-Protection response header. You can choose appropriate value - from below: -
      -
    • 0 (Disables XSS filtering)
    • -
    • 1 (Enables XSS filtering. If a cross-site scripting attack is detected, - the browser will sanitize the page.)
    • -
    • 1; mode=block (Enables XSS filtering. The browser will prevent rendering - of the page if an attack is detected.)
    • -
    -
    spark.ui.xContentTypeOptions.enabledspark.security.credentials.${service}.enabled true - When value is set to "true", X-Content-Type-Options HTTP response header will be set - to "nosniff". Set "false" to disable. -
    spark.ui.strictTransportSecurityNone - Value for HTTP Strict Transport Security (HSTS) Response Header. You can choose appropriate - value from below and set expire-time accordingly, when Spark is SSL/TLS enabled. -
      -
    • max-age=<expire-time>
    • -
    • max-age=<expire-time>; includeSubDomains
    • -
    • max-age=<expire-time>; preload
    • -
    + Controls whether to obtain credentials for services when security is enabled. + By default, credentials for all supported services are retrieved when those services are + configured, but it's possible to disable that behavior if it somehow conflicts with the + application being run.
    - -See the [configuration page](configuration.html) for more details on the security configuration -parameters, and -org.apache.spark.SecurityManager for implementation details about security. +## Long-Running Applications + +Long-running applications may run into issues if their run time exceeds the maximum delegation +token lifetime configured in services it needs to access. + +Spark supports automatically creating new tokens for these applications when running in YARN mode. +Kerberos credentials need to be provided to the Spark application via the `spark-submit` command, +using the `--principal` and `--keytab` parameters. + +The provided keytab will be copied over to the machine running the Application Master via the Hadoop +Distributed Cache. For this reason, it's strongly recommended that both YARN and HDFS be secured +with encryption, at least. + +The Kerberos login will be periodically renewed using the provided credentials, and new delegation +tokens for supported will be created. + + +# Event Logging + +If your applications are using event logging, the directory where the event logs go +(`spark.eventLog.dir`) should be manually created with proper permissions. To secure the log files, +the directory permissions should be set to `drwxrwxrwxt`. The owner and group of the directory +should correspond to the super user who is running the Spark History Server. +This will allow all users to write to the directory but will prevent unprivileged users from +reading, removing or renaming a file unless they own it. The event log files will be created by +Spark with permissions such that only the user and group have read and write access. From 3e778f5a91b0553b09fe0e0ee84d771a71504960 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Mon, 26 Mar 2018 15:45:27 -0700 Subject: [PATCH 0526/2461] [SPARK-23162][PYSPARK][ML] Add r2adj into Python API in LinearRegressionSummary ## What changes were proposed in this pull request? Adding r2adj in LinearRegressionSummary for Python API. ## How was this patch tested? Added unit tests to exercise the api calls for the summary classes in tests.py. Author: Kevin Yu Closes #20842 from kevinyu98/spark-23162. --- python/pyspark/ml/regression.py | 18 ++++++++++++++++-- python/pyspark/ml/tests.py | 1 + 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index de0a0fa9f3bf8..9a66d87d7f211 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -336,10 +336,10 @@ def rootMeanSquaredError(self): @since("2.0.0") def r2(self): """ - Returns R^2^, the coefficient of determination. + Returns R^2, the coefficient of determination. .. seealso:: `Wikipedia coefficient of determination \ - ` + `_ .. note:: This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`. This will change in later Spark @@ -347,6 +347,20 @@ def r2(self): """ return self._call_java("r2") + @property + @since("2.4.0") + def r2adj(self): + """ + Returns Adjusted R^2, the adjusted coefficient of determination. + + .. seealso:: `Wikipedia coefficient of determination, Adjusted R^2 \ + `_ + + .. note:: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark versions. + """ + return self._call_java("r2adj") + @property @since("2.0.0") def residuals(self): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index cf1ffa181ecec..6b4376cbf14e8 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1559,6 +1559,7 @@ def test_linear_regression_summary(self): self.assertAlmostEqual(s.meanSquaredError, 0.0) self.assertAlmostEqual(s.rootMeanSquaredError, 0.0) self.assertAlmostEqual(s.r2, 1.0, 2) + self.assertAlmostEqual(s.r2adj, 1.0, 2) self.assertTrue(isinstance(s.residuals, DataFrame)) self.assertEqual(s.numInstances, 2) self.assertEqual(s.degreesOfFreedom, 1) From 35997b59f3116830af06b3d40a7675ef0dbf7091 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 27 Mar 2018 14:49:50 +0200 Subject: [PATCH 0527/2461] [SPARK-23794][SQL] Make UUID as stateful expression ## What changes were proposed in this pull request? The UUID() expression is stateful and should implement the `Stateful` trait instead of the `Nondeterministic` trait. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #20912 from viirya/SPARK-23794. --- .../org/apache/spark/sql/catalyst/expressions/misc.scala | 4 +++- .../sql/catalyst/expressions/MiscExpressionsSuite.scala | 6 ++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index ec93620038cff..a390f8ef7fd9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -123,7 +123,7 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { 46707d92-02f4-4817-8116-a4c3b23e6266 """) // scalastyle:on line.size.limit -case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Nondeterministic { +case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Stateful { def this() = this(None) @@ -152,4 +152,6 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Non ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", isNull = "false") } + + override def freshCopy(): Uuid = Uuid(randomSeed) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index 3383d421f5616..b6c269348b002 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -59,6 +59,12 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { evaluateWithGeneratedMutableProjection(Uuid(seed2))) assert(evaluateWithUnsafeProjection(Uuid(seed1)) !== evaluateWithUnsafeProjection(Uuid(seed2))) + + val uuid = Uuid(seed1) + assert(uuid.fastEquals(uuid)) + assert(!uuid.fastEquals(Uuid(seed1))) + assert(!uuid.fastEquals(uuid.freshCopy())) + assert(!uuid.fastEquals(Uuid(seed2))) } test("PrintToStderr") { From c68ec4e6a1ed9ea13345c7705ea60ff4df7aec7b Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 27 Mar 2018 14:39:05 -0700 Subject: [PATCH 0528/2461] [SPARK-23096][SS] Migrate rate source to V2 ## What changes were proposed in this pull request? This PR migrate micro batch rate source to V2 API and rewrite UTs to suite V2 test. ## How was this patch tested? UTs. Author: jerryshao Closes #20688 from jerryshao/SPARK-23096. --- ...pache.spark.sql.sources.DataSourceRegister | 3 +- .../execution/datasources/DataSource.scala | 6 +- .../streaming/RateSourceProvider.scala | 262 ------------- .../ContinuousRateStreamSource.scala | 25 +- .../sources/RateStreamMicroBatchReader.scala | 222 +++++++++++ .../sources/RateStreamProvider.scala | 125 +++++++ .../sources/RateStreamSourceV2.scala | 187 ---------- .../execution/streaming/RateSourceSuite.scala | 194 ---------- .../streaming/RateSourceV2Suite.scala | 191 ---------- .../sources/RateStreamProviderSuite.scala | 344 ++++++++++++++++++ 10 files changed, 715 insertions(+), 844 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 1fe9c093af99f..1b37905543b4e 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,6 +5,5 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider -org.apache.spark.sql.execution.streaming.RateSourceProvider +org.apache.spark.sql.execution.streaming.sources.RateStreamProvider org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider -org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 31fa89b4570a6..b84ea769808f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider +import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode @@ -566,6 +566,7 @@ object DataSource extends Logging { val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat" val nativeOrc = classOf[OrcFileFormat].getCanonicalName val socket = classOf[TextSocketSourceProvider].getCanonicalName + val rate = classOf[RateStreamProvider].getCanonicalName Map( "org.apache.spark.sql.jdbc" -> jdbc, @@ -587,7 +588,8 @@ object DataSource extends Logging { "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm, "org.apache.spark.ml.source.libsvm" -> libsvm, "com.databricks.spark.csv" -> csv, - "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket + "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket, + "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala deleted file mode 100644 index 649fbbfa184ec..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ /dev/null @@ -1,262 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.io._ -import java.nio.charset.StandardCharsets -import java.util.Optional -import java.util.concurrent.TimeUnit - -import org.apache.commons.io.IOUtils - -import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader -import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} -import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader -import org.apache.spark.sql.types._ -import org.apache.spark.util.{ManualClock, SystemClock} - -/** - * A source that generates increment long values with timestamps. Each generated row has two - * columns: a timestamp column for the generated time and an auto increment long column starting - * with 0L. - * - * This source supports the following options: - * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. - * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed - * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer - * seconds. - * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the - * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may - * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. - */ -class RateSourceProvider extends StreamSourceProvider with DataSourceRegister - with DataSourceV2 with ContinuousReadSupport { - - override def sourceSchema( - sqlContext: SQLContext, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): (String, StructType) = { - if (schema.nonEmpty) { - throw new AnalysisException("The rate source does not support a user-specified schema.") - } - - (shortName(), RateSourceProvider.SCHEMA) - } - - override def createSource( - sqlContext: SQLContext, - metadataPath: String, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): Source = { - val params = CaseInsensitiveMap(parameters) - - val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L) - if (rowsPerSecond <= 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " + - "must be positive") - } - - val rampUpTimeSeconds = - params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L) - if (rampUpTimeSeconds < 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " + - "must not be negative") - } - - val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse( - sqlContext.sparkContext.defaultParallelism) - if (numPartitions <= 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("numPartitions")}'. The option 'numPartitions' " + - "must be positive") - } - - new RateStreamSource( - sqlContext, - metadataPath, - rowsPerSecond, - rampUpTimeSeconds, - numPartitions, - params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing - ) - } - - override def createContinuousReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = { - new RateStreamContinuousReader(options) - } - - override def shortName(): String = "rate" -} - -object RateSourceProvider { - val SCHEMA = - StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) - - val VERSION = 1 -} - -class RateStreamSource( - sqlContext: SQLContext, - metadataPath: String, - rowsPerSecond: Long, - rampUpTimeSeconds: Long, - numPartitions: Int, - useManualClock: Boolean) extends Source with Logging { - - import RateSourceProvider._ - import RateStreamSource._ - - val clock = if (useManualClock) new ManualClock else new SystemClock - - private val maxSeconds = Long.MaxValue / rowsPerSecond - - if (rampUpTimeSeconds > maxSeconds) { - throw new ArithmeticException( - s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + - s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") - } - - private val startTimeMs = { - val metadataLog = - new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) { - override def serialize(metadata: LongOffset, out: OutputStream): Unit = { - val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) - writer.write("v" + VERSION + "\n") - writer.write(metadata.json) - writer.flush - } - - override def deserialize(in: InputStream): LongOffset = { - val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) - // HDFSMetadataLog guarantees that it never creates a partial file. - assert(content.length != 0) - if (content(0) == 'v') { - val indexOfNewLine = content.indexOf("\n") - if (indexOfNewLine > 0) { - val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) - LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } - } - - metadataLog.get(0).getOrElse { - val offset = LongOffset(clock.getTimeMillis()) - metadataLog.add(0, offset) - logInfo(s"Start time: $offset") - offset - }.offset - } - - /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */ - @volatile private var lastTimeMs = startTimeMs - - override def schema: StructType = RateSourceProvider.SCHEMA - - override def getOffset: Option[Offset] = { - val now = clock.getTimeMillis() - if (lastTimeMs < now) { - lastTimeMs = now - } - Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs))) - } - - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L) - val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) - assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") - if (endSeconds > maxSeconds) { - throw new ArithmeticException("Integer overflow. Max offset with " + - s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") - } - // Fix "lastTimeMs" for recovery - if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) { - lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs - } - val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) - val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) - logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + - s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") - - if (rangeStart == rangeEnd) { - return sqlContext.internalCreateDataFrame( - sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) - } - - val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) - val relativeMsPerValue = - TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) - - val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v => - val relative = math.round((v - rangeStart) * relativeMsPerValue) - InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v) - } - sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) - } - - override def stop(): Unit = {} - - override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " + - s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]" -} - -object RateStreamSource { - - /** Calculate the end value we will emit at the time `seconds`. */ - def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { - // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 - // Then speedDeltaPerSecond = 2 - // - // seconds = 0 1 2 3 4 5 6 - // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) - // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 - val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) - if (seconds <= rampUpTimeSeconds) { - // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to - // avoid overflow - if (seconds % 2 == 1) { - (seconds + 1) / 2 * speedDeltaPerSecond * seconds - } else { - seconds / 2 * speedDeltaPerSecond * (seconds + 1) - } - } else { - // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds - val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) - rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 20d90069163a6..2f0de2612c150 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -24,8 +24,8 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} +import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} @@ -40,8 +40,8 @@ class RateStreamContinuousReader(options: DataSourceOptions) val creationTime = System.currentTimeMillis() - val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt - val rowsPerSecond = options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong + val numPartitions = options.get(RateStreamProvider.NUM_PARTITIONS).orElse("5").toInt + val rowsPerSecond = options.get(RateStreamProvider.ROWS_PER_SECOND).orElse("6").toLong val perPartitionRate = rowsPerSecond.toDouble / numPartitions.toDouble override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { @@ -57,12 +57,12 @@ class RateStreamContinuousReader(options: DataSourceOptions) RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def readSchema(): StructType = RateSourceProvider.SCHEMA + override def readSchema(): StructType = RateStreamProvider.SCHEMA private var offset: Offset = _ override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { - this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime)) + this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime)) } override def getStartOffset(): Offset = offset @@ -98,6 +98,19 @@ class RateStreamContinuousReader(options: DataSourceOptions) override def commit(end: Offset): Unit = {} override def stop(): Unit = {} + private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { + RateStreamOffset( + Range(0, numPartitions).map { i => + // Note that the starting offset is exclusive, so we have to decrement the starting value + // by the increment that will later be applied. The first row output in each + // partition will have a value equal to the partition index. + (i, + ValueRunTimeMsPair( + (i - numPartitions).toLong, + creationTimeMs)) + }.toMap) + } + } case class RateStreamContinuousDataReaderFactory( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala new file mode 100644 index 0000000000000..6cf8520fc544f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io._ +import java.nio.charset.StandardCharsets +import java.util.Optional +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.IOUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{ManualClock, SystemClock} + +class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) + extends MicroBatchReader with Logging { + import RateStreamProvider._ + + private[sources] val clock = { + // The option to use a manual clock is provided only for unit testing purposes. + if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock + } + + private val rowsPerSecond = + options.get(ROWS_PER_SECOND).orElse("1").toLong + + private val rampUpTimeSeconds = + Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String])) + .map(JavaUtils.timeStringAsSec(_)) + .getOrElse(0L) + + private val maxSeconds = Long.MaxValue / rowsPerSecond + + if (rampUpTimeSeconds > maxSeconds) { + throw new ArithmeticException( + s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + + s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") + } + + private[sources] val creationTimeMs = { + val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession) + require(session.isDefined) + + val metadataLog = + new HDFSMetadataLog[LongOffset](session.get, checkpointLocation) { + override def serialize(metadata: LongOffset, out: OutputStream): Unit = { + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): LongOffset = { + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. + assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + parseVersion(content.substring(0, indexOfNewLine), VERSION) + LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } + } + + metadataLog.get(0).getOrElse { + val offset = LongOffset(clock.getTimeMillis()) + metadataLog.add(0, offset) + logInfo(s"Start time: $offset") + offset + }.offset + } + + @volatile private var lastTimeMs: Long = creationTimeMs + + private var start: LongOffset = _ + private var end: LongOffset = _ + + override def readSchema(): StructType = SCHEMA + + override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { + this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset] + this.end = end.orElse { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) + }.asInstanceOf[LongOffset] + } + + override def getStartOffset(): Offset = { + if (start == null) throw new IllegalStateException("start offset not set") + start + } + override def getEndOffset(): Offset = { + if (end == null) throw new IllegalStateException("end offset not set") + end + } + + override def deserializeOffset(json: String): Offset = { + LongOffset(json.toLong) + } + + override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) + val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) + assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") + if (endSeconds > maxSeconds) { + throw new ArithmeticException("Integer overflow. Max offset with " + + s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") + } + // Fix "lastTimeMs" for recovery + if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs) { + lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs + } + val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) + val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) + logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + + s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") + + if (rangeStart == rangeEnd) { + return List.empty.asJava + } + + val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) + val relativeMsPerValue = + TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) + val numPartitions = { + val activeSession = SparkSession.getActiveSession + require(activeSession.isDefined) + Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String])) + .map(_.toInt) + .getOrElse(activeSession.get.sparkContext.defaultParallelism) + } + + (0 until numPartitions).map { p => + new RateStreamMicroBatchDataReaderFactory( + p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) + : DataReaderFactory[Row] + }.toList.asJava + } + + override def commit(end: Offset): Unit = {} + + override def stop(): Unit = {} + + override def toString: String = s"MicroBatchRateSource[rowsPerSecond=$rowsPerSecond, " + + s"rampUpTimeSeconds=$rampUpTimeSeconds, " + + s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" +} + +class RateStreamMicroBatchDataReaderFactory( + partitionId: Int, + numPartitions: Int, + rangeStart: Long, + rangeEnd: Long, + localStartTimeMs: Long, + relativeMsPerValue: Double) extends DataReaderFactory[Row] { + + override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader( + partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) +} + +class RateStreamMicroBatchDataReader( + partitionId: Int, + numPartitions: Int, + rangeStart: Long, + rangeEnd: Long, + localStartTimeMs: Long, + relativeMsPerValue: Double) extends DataReader[Row] { + private var count = 0 + + override def next(): Boolean = { + rangeStart + partitionId + numPartitions * count < rangeEnd + } + + override def get(): Row = { + val currValue = rangeStart + partitionId + numPartitions * count + count += 1 + val relative = math.round((currValue - rangeStart) * relativeMsPerValue) + Row( + DateTimeUtils.toJavaTimestamp( + DateTimeUtils.fromMillis(relative + localStartTimeMs)), + currValue + ) + } + + override def close(): Unit = {} +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala new file mode 100644 index 0000000000000..6bdd492f0cb35 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.util.Optional + +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader} +import org.apache.spark.sql.types._ + +/** + * A source that generates increment long values with timestamps. Each generated row has two + * columns: a timestamp column for the generated time and an auto increment long column starting + * with 0L. + * + * This source supports the following options: + * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. + * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed + * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer + * seconds. + * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the + * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may + * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. + */ +class RateStreamProvider extends DataSourceV2 + with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister { + import RateStreamProvider._ + + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): MicroBatchReader = { + if (options.get(ROWS_PER_SECOND).isPresent) { + val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong + if (rowsPerSecond <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive") + } + } + + if (options.get(RAMP_UP_TIME).isPresent) { + val rampUpTimeSeconds = + JavaUtils.timeStringAsSec(options.get(RAMP_UP_TIME).get()) + if (rampUpTimeSeconds < 0) { + throw new IllegalArgumentException( + s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative") + } + } + + if (options.get(NUM_PARTITIONS).isPresent) { + val numPartitions = options.get(NUM_PARTITIONS).get().toInt + if (numPartitions <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive") + } + } + + if (schema.isPresent) { + throw new AnalysisException("The rate source does not support a user-specified schema.") + } + + new RateStreamMicroBatchReader(options, checkpointLocation) + } + + override def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options) + + override def shortName(): String = "rate" +} + +object RateStreamProvider { + val SCHEMA = + StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) + + val VERSION = 1 + + val NUM_PARTITIONS = "numPartitions" + val ROWS_PER_SECOND = "rowsPerSecond" + val RAMP_UP_TIME = "rampUpTime" + + /** Calculate the end value we will emit at the time `seconds`. */ + def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { + // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 + // Then speedDeltaPerSecond = 2 + // + // seconds = 0 1 2 3 4 5 6 + // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) + // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 + val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) + if (seconds <= rampUpTimeSeconds) { + // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to + // avoid overflow + if (seconds % 2 == 1) { + (seconds + 1) / 2 * speedDeltaPerSecond * seconds + } else { + seconds / 2 * speedDeltaPerSecond * (seconds + 1) + } + } else { + // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds + val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) + rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala deleted file mode 100644 index 4e2459bb05bd6..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.util.Optional - -import scala.collection.JavaConverters._ -import scala.collection.mutable - -import org.json4s.DefaultFormats -import org.json4s.jackson.Serialization - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} -import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} -import org.apache.spark.util.{ManualClock, SystemClock} - -/** - * This is a temporary register as we build out v2 migration. Microbatch read support should - * be implemented in the same register as v1. - */ -class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with DataSourceRegister { - override def createMicroBatchReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = { - new RateStreamMicroBatchReader(options) - } - - override def shortName(): String = "ratev2" -} - -class RateStreamMicroBatchReader(options: DataSourceOptions) - extends MicroBatchReader { - implicit val defaultFormats: DefaultFormats = DefaultFormats - - val clock = { - // The option to use a manual clock is provided only for unit testing purposes. - if (options.get("useManualClock").orElse("false").toBoolean) new ManualClock - else new SystemClock - } - - private val numPartitions = - options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt - private val rowsPerSecond = - options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong - - // The interval (in milliseconds) between rows in each partition. - // e.g. if there are 4 global rows per second, and 2 partitions, each partition - // should output rows every (1000 * 2 / 4) = 500 ms. - private val msPerPartitionBetweenRows = (1000 * numPartitions) / rowsPerSecond - - override def readSchema(): StructType = { - StructType( - StructField("timestamp", TimestampType, false) :: - StructField("value", LongType, false) :: Nil) - } - - val creationTimeMs = clock.getTimeMillis() - - private var start: RateStreamOffset = _ - private var end: RateStreamOffset = _ - - override def setOffsetRange( - start: Optional[Offset], - end: Optional[Offset]): Unit = { - this.start = start.orElse( - RateStreamSourceV2.createInitialOffset(numPartitions, creationTimeMs)) - .asInstanceOf[RateStreamOffset] - - this.end = end.orElse { - val currentTime = clock.getTimeMillis() - RateStreamOffset( - this.start.partitionToValueAndRunTimeMs.map { - case startOffset @ (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => - // Calculate the number of rows we should advance in this partition (based on the - // current time), and output a corresponding offset. - val readInterval = currentTime - currentReadTime - val numNewRows = readInterval / msPerPartitionBetweenRows - if (numNewRows <= 0) { - startOffset - } else { - (part, ValueRunTimeMsPair( - currentVal + (numNewRows * numPartitions), - currentReadTime + (numNewRows * msPerPartitionBetweenRows))) - } - } - ) - }.asInstanceOf[RateStreamOffset] - } - - override def getStartOffset(): Offset = { - if (start == null) throw new IllegalStateException("start offset not set") - start - } - override def getEndOffset(): Offset = { - if (end == null) throw new IllegalStateException("end offset not set") - end - } - - override def deserializeOffset(json: String): Offset = { - RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) - } - - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { - val startMap = start.partitionToValueAndRunTimeMs - val endMap = end.partitionToValueAndRunTimeMs - endMap.keys.toSeq.map { part => - val ValueRunTimeMsPair(endVal, _) = endMap(part) - val ValueRunTimeMsPair(startVal, startTimeMs) = startMap(part) - - val packedRows = mutable.ListBuffer[(Long, Long)]() - var outVal = startVal + numPartitions - var outTimeMs = startTimeMs - while (outVal <= endVal) { - packedRows.append((outTimeMs, outVal)) - outVal += numPartitions - outTimeMs += msPerPartitionBetweenRows - } - - RateStreamBatchTask(packedRows).asInstanceOf[DataReaderFactory[Row]] - }.toList.asJava - } - - override def commit(end: Offset): Unit = {} - override def stop(): Unit = {} -} - -case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactory[Row] { - override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals) -} - -class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] { - private var currentIndex = -1 - - override def next(): Boolean = { - // Return true as long as the new index is in the seq. - currentIndex += 1 - currentIndex < vals.size - } - - override def get(): Row = { - Row( - DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(vals(currentIndex)._1)), - vals(currentIndex)._2) - } - - override def close(): Unit = {} -} - -object RateStreamSourceV2 { - val NUM_PARTITIONS = "numPartitions" - val ROWS_PER_SECOND = "rowsPerSecond" - - private[sql] def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { - RateStreamOffset( - Range(0, numPartitions).map { i => - // Note that the starting offset is exclusive, so we have to decrement the starting value - // by the increment that will later be applied. The first row output in each - // partition will have a value equal to the partition index. - (i, - ValueRunTimeMsPair( - (i - numPartitions).toLong, - creationTimeMs)) - }.toMap) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala deleted file mode 100644 index 03d0f63fa4d7f..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.util.concurrent.TimeUnit - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} -import org.apache.spark.util.ManualClock - -class RateSourceSuite extends StreamTest { - - import testImplicits._ - - case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (Source, Offset) = { - assert(query.nonEmpty) - val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] => - source.asInstanceOf[RateStreamSource] - }.head - rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) - (rateSource, rateSource.getOffset.get) - } - } - - test("basic") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("useManualClock", "true") - .load() - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), - StopStream, - StartStream(), - // Advance 2 seconds because creating a new RateSource will also create a new ManualClock - AdvanceRateManualClock(seconds = 2), - CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) - ) - } - - test("uniform distribution of event timestamps") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "1500") - .option("useManualClock", "true") - .load() - .as[(java.sql.Timestamp, Long)] - .map(v => (v._1.getTime, v._2)) - val expectedAnswer = (0 until 1500).map { v => - (math.round(v * (1000.0 / 1500)), v) - } - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch(expectedAnswer: _*) - ) - } - - test("valueAtSecond") { - import RateStreamSource._ - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1) - assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3) - assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8) - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2) - assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6) - assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12) - assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20) - assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30) - } - - test("rampUpTime") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("rampUpTime", "4s") - .option("useManualClock", "true") - .load() - .as[(java.sql.Timestamp, Long)] - .map(v => (v._1.getTime, v._2)) - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch({ - Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11) - }: _*), // speed = 6 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8 - AdvanceRateManualClock(seconds = 1), - // Now we should reach full speed - CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10 - ) - } - - test("numPartitions") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("numPartitions", "6") - .option("useManualClock", "true") - .load() - .select(spark_partition_id()) - .distinct() - testStream(input)( - AdvanceRateManualClock(1), - CheckLastBatch((0 until 6): _*) - ) - } - - testQuietly("overflow") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", Long.MaxValue.toString) - .option("useManualClock", "true") - .load() - .select(spark_partition_id()) - .distinct() - testStream(input)( - AdvanceRateManualClock(2), - ExpectFailure[ArithmeticException](t => { - Seq("overflow", "rowsPerSecond").foreach { msg => - assert(t.getMessage.contains(msg)) - } - }) - ) - } - - testQuietly("illegal option values") { - def testIllegalOptionValue( - option: String, - value: String, - expectedMessages: Seq[String]): Unit = { - val e = intercept[StreamingQueryException] { - spark.readStream - .format("rate") - .option(option, value) - .load() - .writeStream - .format("console") - .start() - .awaitTermination() - } - assert(e.getCause.isInstanceOf[IllegalArgumentException]) - for (msg <- expectedMessages) { - assert(e.getCause.getMessage.contains(msg)) - } - } - - testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive")) - testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) - } - - test("user-specified schema given") { - val exception = intercept[AnalysisException] { - spark.readStream - .format("rate") - .schema(spark.range(1).schema) - .load() - } - assert(exception.getMessage.contains( - "rate source does not support a user-specified schema")) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala deleted file mode 100644 index 983ba1668f58f..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ /dev/null @@ -1,191 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.util.Optional -import java.util.concurrent.TimeUnit - -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.Row -import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.util.ManualClock - -class RateSourceV2Suite extends StreamTest { - import testImplicits._ - - case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { - assert(query.nonEmpty) - val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source - }.head - rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) - rateSource.setOffsetRange(Optional.empty(), Optional.empty()) - (rateSource, rateSource.getEndOffset()) - } - } - - test("microbatch in registry") { - DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamMicroBatchReader]) - case _ => - throw new IllegalStateException("Could not find v2 read support for rate") - } - } - - test("basic microbatch execution") { - val input = spark.readStream - .format("rateV2") - .option("numPartitions", "1") - .option("rowsPerSecond", "10") - .option("useManualClock", "true") - .load() - testStream(input, useV2Sink = true)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), - StopStream, - StartStream(), - // Advance 2 seconds because creating a new RateSource will also create a new ManualClock - AdvanceRateManualClock(seconds = 2), - CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) - ) - } - - test("microbatch - numPartitions propagated") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 11) - } - - test("microbatch - set offset") { - val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty()) - val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) - val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000)))) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - assert(reader.getStartOffset() == startOffset) - assert(reader.getEndOffset() == endOffset) - } - - test("microbatch - infer offsets") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) - reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - reader.getStartOffset() match { - case r: RateStreamOffset => - assert(r.partitionToValueAndRunTimeMs(0).runTimeMs == reader.creationTimeMs) - case _ => throw new IllegalStateException("unexpected offset type") - } - reader.getEndOffset() match { - case r: RateStreamOffset => - // End offset may be a bit beyond 100 ms/9 rows after creation if the wait lasted - // longer than 100ms. It should never be early. - assert(r.partitionToValueAndRunTimeMs(0).value >= 9) - assert(r.partitionToValueAndRunTimeMs(0).runTimeMs >= reader.creationTimeMs + 100) - - case _ => throw new IllegalStateException("unexpected offset type") - } - } - - test("microbatch - predetermined batch size") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) - val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) - val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 1) - assert(tasks.get(0).asInstanceOf[RateStreamBatchTask].vals.size == 20) - } - - test("microbatch - data read") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) - val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) - val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { - case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => - (part, ValueRunTimeMsPair(currentVal + 33, currentReadTime + 1000)) - }.toMap) - - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 11) - - val readData = tasks.asScala - .map(_.createDataReader()) - .flatMap { reader => - val buf = scala.collection.mutable.ListBuffer[Row]() - while (reader.next()) buf.append(reader.get()) - buf - } - - assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) - } - - test("continuous in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: ContinuousReadSupport => - val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamContinuousReader]) - case _ => - throw new IllegalStateException("Could not find v2 read support for rate") - } - } - - test("continuous data") { - val reader = new RateStreamContinuousReader( - new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 2) - - val data = scala.collection.mutable.ListBuffer[Row]() - tasks.asScala.foreach { - case t: RateStreamContinuousDataReaderFactory => - val startTimeMs = reader.getStartOffset() - .asInstanceOf[RateStreamOffset] - .partitionToValueAndRunTimeMs(t.partitionIndex) - .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] - for (rowIndex <- 0 to 9) { - r.next() - data.append(r.get()) - assert(r.getOffset() == - RateStreamPartitionOffset( - t.partitionIndex, - t.partitionIndex + rowIndex * 2, - startTimeMs + (rowIndex + 1) * 100)) - } - assert(System.currentTimeMillis() >= startTimeMs + 1000) - - case _ => throw new IllegalStateException("Unexpected task type") - } - - assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala new file mode 100644 index 0000000000000..9149e50962255 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.nio.file.Files +import java.util.Optional +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.Offset +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.util.ManualClock + +class RateSourceSuite extends StreamTest { + + import testImplicits._ + + protected override def beforeAll(): Unit = { + super.beforeAll() + SparkSession.setActiveSession(spark) + } + + override def afterAll(): Unit = { + SparkSession.clearActiveSession() + super.afterAll() + } + + case class AdvanceRateManualClock(seconds: Long) extends AddData { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + assert(query.nonEmpty) + val rateSource = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source + }.head + + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) + val offset = LongOffset(TimeUnit.MILLISECONDS.toSeconds( + rateSource.clock.getTimeMillis() - rateSource.creationTimeMs)) + (rateSource, offset) + } + } + + test("microbatch in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + val reader = ds.createMicroBatchReader(Optional.empty(), "dummy", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamMicroBatchReader]) + case _ => + throw new IllegalStateException("Could not find read support for rate") + } + } + + test("compatible with old path in registry") { + DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", + spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + assert(ds.isInstanceOf[RateStreamProvider]) + case _ => + throw new IllegalStateException("Could not find read support for rate") + } + } + + test("microbatch - basic") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("useManualClock", "true") + .load() + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + StopStream, + StartStream(), + // Advance 2 seconds because creating a new RateSource will also create a new ManualClock + AdvanceRateManualClock(seconds = 2), + CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + + test("microbatch - uniform distribution of event timestamps") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "1500") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + val expectedAnswer = (0 until 1500).map { v => + (math.round(v * (1000.0 / 1500)), v) + } + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch(expectedAnswer: _*) + ) + } + + test("microbatch - set offset") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + assert(reader.getStartOffset() == startOffset) + assert(reader.getEndOffset() == endOffset) + } + + test("microbatch - infer offsets") { + val tempFolder = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions( + Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), + tempFolder) + reader.clock.asInstanceOf[ManualClock].advance(100000) + reader.setOffsetRange(Optional.empty(), Optional.empty()) + reader.getStartOffset() match { + case r: LongOffset => assert(r.offset === 0L) + case _ => throw new IllegalStateException("unexpected offset type") + } + reader.getEndOffset() match { + case r: LongOffset => assert(r.offset >= 100) + case _ => throw new IllegalStateException("unexpected offset type") + } + } + + test("microbatch - predetermined batch size") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 1) + val dataReader = tasks.get(0).createDataReader() + val data = ArrayBuffer[Row]() + while (dataReader.next()) { + data.append(dataReader.get()) + } + assert(data.size === 20) + } + + test("microbatch - data read") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 11) + + val readData = tasks.asScala + .map(_.createDataReader()) + .flatMap { reader => + val buf = scala.collection.mutable.ListBuffer[Row]() + while (reader.next()) buf.append(reader.get()) + buf + } + + assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) + } + + test("valueAtSecond") { + import RateStreamProvider._ + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12) + assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20) + assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30) + } + + test("rampUpTime") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("rampUpTime", "4s") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch({ + Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11) + }: _*), // speed = 6 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8 + AdvanceRateManualClock(seconds = 1), + // Now we should reach full speed + CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10 + ) + } + + test("numPartitions") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("numPartitions", "6") + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(1), + CheckLastBatch((0 until 6): _*) + ) + } + + testQuietly("overflow") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", Long.MaxValue.toString) + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(2), + ExpectFailure[TreeNodeException[_]](t => { + Seq("overflow", "rowsPerSecond").foreach { msg => + assert(t.getCause.getMessage.contains(msg)) + } + }) + ) + } + + testQuietly("illegal option values") { + def testIllegalOptionValue( + option: String, + value: String, + expectedMessages: Seq[String]): Unit = { + val e = intercept[IllegalArgumentException] { + spark.readStream + .format("rate") + .option(option, value) + .load() + .writeStream + .format("console") + .start() + .awaitTermination() + } + for (msg <- expectedMessages) { + assert(e.getMessage.contains(msg)) + } + } + + testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive")) + testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) + } + + test("user-specified schema given") { + val exception = intercept[AnalysisException] { + spark.readStream + .format("rate") + .schema(spark.range(1).schema) + .load() + } + assert(exception.getMessage.contains( + "rate source does not support a user-specified schema")) + } + + test("continuous in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: ContinuousReadSupport => + val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamContinuousReader]) + case _ => + throw new IllegalStateException("Could not find read support for continuous rate") + } + } + + test("continuous data") { + val reader = new RateStreamContinuousReader( + new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) + reader.setStartOffset(Optional.empty()) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 2) + + val data = scala.collection.mutable.ListBuffer[Row]() + tasks.asScala.foreach { + case t: RateStreamContinuousDataReaderFactory => + val startTimeMs = reader.getStartOffset() + .asInstanceOf[RateStreamOffset] + .partitionToValueAndRunTimeMs(t.partitionIndex) + .runTimeMs + val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] + for (rowIndex <- 0 to 9) { + r.next() + data.append(r.get()) + assert(r.getOffset() == + RateStreamPartitionOffset( + t.partitionIndex, + t.partitionIndex + rowIndex * 2, + startTimeMs + (rowIndex + 1) * 100)) + } + assert(System.currentTimeMillis() >= startTimeMs + 1000) + + case _ => throw new IllegalStateException("Unexpected task type") + } + + assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) + } +} From ed72badb04a56d8046bbd185245abf5ae265ccfd Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 27 Mar 2018 20:06:12 -0700 Subject: [PATCH 0529/2461] [SPARK-23699][PYTHON][SQL] Raise same type of error caught with Arrow enabled ## What changes were proposed in this pull request? When using Arrow for createDataFrame or toPandas and an error is encountered with fallback disabled, this will raise the same type of error instead of a RuntimeError. This change also allows for the traceback of the error to be retained and prevents the accidental chaining of exceptions with Python 3. ## How was this patch tested? Updated existing tests to verify error type. Author: Bryan Cutler Closes #20839 from BryanCutler/arrow-raise-same-error-SPARK-23699. --- python/pyspark/sql/dataframe.py | 25 +++++++++++++------------ python/pyspark/sql/session.py | 13 +++++++------ python/pyspark/sql/tests.py | 10 +++++----- python/pyspark/sql/utils.py | 6 ++++++ 4 files changed, 31 insertions(+), 23 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 3fc194d8ec1d1..16f8e52dead7b 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -2007,7 +2007,7 @@ def toPandas(self): "toPandas attempted Arrow optimization because " "'spark.sql.execution.arrow.enabled' is set to true; however, " "failed by the reason below:\n %s\n" - "Attempts non-optimization as " + "Attempting non-optimization as " "'spark.sql.execution.arrow.fallback.enabled' is set to " "true." % _exception_message(e)) warnings.warn(msg) @@ -2015,11 +2015,12 @@ def toPandas(self): else: msg = ( "toPandas attempted Arrow optimization because " - "'spark.sql.execution.arrow.enabled' is set to true; however, " - "failed by the reason below:\n %s\n" - "For fallback to non-optimization automatically, please set true to " - "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e)) - raise RuntimeError(msg) + "'spark.sql.execution.arrow.enabled' is set to true, but has reached " + "the error below and will not continue because automatic fallback " + "with 'spark.sql.execution.arrow.fallback.enabled' has been set to " + "false.\n %s" % _exception_message(e)) + warnings.warn(msg) + raise # Try to use Arrow optimization when the schema is supported and the required version # of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled. @@ -2042,12 +2043,12 @@ def toPandas(self): # be executed. So, simply fail in this case for now. msg = ( "toPandas attempted Arrow optimization because " - "'spark.sql.execution.arrow.enabled' is set to true; however, " - "failed unexpectedly:\n %s\n" - "Note that 'spark.sql.execution.arrow.fallback.enabled' does " - "not have an effect in such failure in the middle of " - "computation." % _exception_message(e)) - raise RuntimeError(msg) + "'spark.sql.execution.arrow.enabled' is set to true, but has reached " + "the error below and can not continue. Note that " + "'spark.sql.execution.arrow.fallback.enabled' does not have an effect " + "on failures in the middle of computation.\n %s" % _exception_message(e)) + warnings.warn(msg) + raise # Below is toPandas without Arrow optimization. pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index e82a9750a0014..13d6e2e53dbd0 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -674,18 +674,19 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr "createDataFrame attempted Arrow optimization because " "'spark.sql.execution.arrow.enabled' is set to true; however, " "failed by the reason below:\n %s\n" - "Attempts non-optimization as " + "Attempting non-optimization as " "'spark.sql.execution.arrow.fallback.enabled' is set to " "true." % _exception_message(e)) warnings.warn(msg) else: msg = ( "createDataFrame attempted Arrow optimization because " - "'spark.sql.execution.arrow.enabled' is set to true; however, " - "failed by the reason below:\n %s\n" - "For fallback to non-optimization automatically, please set true to " - "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e)) - raise RuntimeError(msg) + "'spark.sql.execution.arrow.enabled' is set to true, but has reached " + "the error below and will not continue because automatic fallback " + "with 'spark.sql.execution.arrow.fallback.enabled' has been set to " + "false.\n %s" % _exception_message(e)) + warnings.warn(msg) + raise data = self._convert_from_pandas(data, schema, timezone) if isinstance(schema, StructType): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 967cc83166f3f..01c5dd6ff8c3f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3559,7 +3559,7 @@ def test_toPandas_fallback_enabled(self): warn.message for warn in warns if isinstance(warn.message, UserWarning)] self.assertTrue(len(user_warns) > 0) self.assertTrue( - "Attempts non-optimization" in _exception_message(user_warns[-1])) + "Attempting non-optimization" in _exception_message(user_warns[-1])) self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) def test_toPandas_fallback_disabled(self): @@ -3682,7 +3682,7 @@ def test_createDataFrame_with_incorrect_schema(self): pdf = self.create_pandas_data_frame() wrong_schema = StructType(list(reversed(self.schema))) with QuietTest(self.sc): - with self.assertRaisesRegexp(RuntimeError, ".*No cast.*string.*timestamp.*"): + with self.assertRaisesRegexp(Exception, ".*No cast.*string.*timestamp.*"): self.spark.createDataFrame(pdf, schema=wrong_schema) def test_createDataFrame_with_names(self): @@ -3707,7 +3707,7 @@ def test_createDataFrame_column_name_encoding(self): def test_createDataFrame_with_single_data_type(self): import pandas as pd with QuietTest(self.sc): - with self.assertRaisesRegexp(RuntimeError, ".*IntegerType.*not supported.*"): + with self.assertRaisesRegexp(ValueError, ".*IntegerType.*not supported.*"): self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int") def test_createDataFrame_does_not_modify_input(self): @@ -3775,14 +3775,14 @@ def test_createDataFrame_fallback_enabled(self): warn.message for warn in warns if isinstance(warn.message, UserWarning)] self.assertTrue(len(user_warns) > 0) self.assertTrue( - "Attempts non-optimization" in _exception_message(user_warns[-1])) + "Attempting non-optimization" in _exception_message(user_warns[-1])) self.assertEqual(df.collect(), [Row(a={u'a': 1})]) def test_createDataFrame_fallback_disabled(self): import pandas as pd with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported type'): + with self.assertRaisesRegexp(TypeError, 'Unsupported type'): self.spark.createDataFrame( pd.DataFrame([[{u'a': 1}]]), "a: map") diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 578298632dd4c..45363f089a73d 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -121,7 +121,10 @@ def require_minimum_pandas_version(): from distutils.version import LooseVersion try: import pandas + have_pandas = True except ImportError: + have_pandas = False + if not have_pandas: raise ImportError("Pandas >= %s must be installed; however, " "it was not found." % minimum_pandas_version) if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version): @@ -138,7 +141,10 @@ def require_minimum_pyarrow_version(): from distutils.version import LooseVersion try: import pyarrow + have_arrow = True except ImportError: + have_arrow = False + if not have_arrow: raise ImportError("PyArrow >= %s must be installed; however, " "it was not found." % minimum_pyarrow_version) if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version): From 34c4b9c57e114cdb390e4dbc7383284d82fea317 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 28 Mar 2018 19:49:27 +0800 Subject: [PATCH 0530/2461] [SPARK-23765][SQL] Supports custom line separator for json datasource ## What changes were proposed in this pull request? This PR proposes to add lineSep option for a configurable line separator in text datasource. It supports this option by using `LineRecordReader`'s functionality with passing it to the constructor. The approach is similar with https://github.com/apache/spark/pull/20727; however, one main difference is, it uses text datasource's `lineSep` option to parse line by line in JSON's schema inference. ## How was this patch tested? Manually tested and unit tests were added. Author: hyukjinkwon Author: hyukjinkwon Closes #20877 from HyukjinKwon/linesep-json. --- python/pyspark/sql/readwriter.py | 14 ++-- python/pyspark/sql/streaming.py | 6 +- python/pyspark/sql/tests.py | 17 +++++ .../spark/sql/catalyst/json/JSONOptions.scala | 11 ++++ .../sql/catalyst/json/JacksonGenerator.scala | 8 ++- .../apache/spark/sql/DataFrameReader.scala | 2 + .../apache/spark/sql/DataFrameWriter.scala | 2 + .../datasources/json/JsonDataSource.scala | 17 +++-- .../datasources/text/TextOptions.scala | 2 +- .../sql/streaming/DataStreamReader.scala | 2 + .../datasources/json/JsonSuite.scala | 66 ++++++++++++++++++- .../datasources/text/TextSuite.scala | 4 +- 12 files changed, 136 insertions(+), 15 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 4f9b9383a5ef4..6bd79bc2f43e5 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -176,7 +176,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -237,6 +237,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param allowUnquotedControlChars: allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes @@ -254,7 +256,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -746,7 +748,8 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) self._jwrite.saveAsTable(name) @since(1.4) - def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None): + def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None, + lineSep=None): """Saves the content of the :class:`DataFrame` in JSON format (`JSON Lines text format or newline-delimited JSON `_) at the specified path. @@ -770,12 +773,15 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. + :param lineSep: defines the line separator that should be used for writing. If None is + set, it uses the default value, ``\\n``. >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) self._set_opts( - compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat) + compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat, + lineSep=lineSep) self._jwrite.json(path) @since(1.4) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index c7907aaaf1f7b..15f9407389864 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -405,7 +405,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None): """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. @@ -468,6 +468,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param allowUnquotedControlChars: allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema) >>> json_sdf.isStreaming @@ -482,7 +484,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep) if isinstance(path, basestring): return self._df(self._jreader.json(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 01c5dd6ff8c3f..5181053a0d318 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -676,6 +676,23 @@ def test_multiline_json(self): multiLine=True) self.assertEqual(people1.collect(), people_array.collect()) + def test_linesep_json(self): + df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",") + expected = [Row(_corrupt_record=None, name=u'Michael'), + Row(_corrupt_record=u' "age":30}\n{"name":"Justin"', name=None), + Row(_corrupt_record=u' "age":19}\n', name=None)] + self.assertEqual(df.collect(), expected) + + tpath = tempfile.mkdtemp() + shutil.rmtree(tpath) + try: + df = self.spark.read.json("python/test_support/sql/people.json") + df.write.json(tpath, lineSep="!!") + readback = self.spark.read.json(tpath, lineSep="!!") + self.assertEqual(readback.collect(), df.collect()) + finally: + shutil.rmtree(tpath) + def test_multiline_csv(self): ages_newlines = self.spark.read.csv( "python/test_support/sql/ages_newlines.csv", multiLine=True) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 652412b34478a..5c9adc3332bc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.json +import java.nio.charset.StandardCharsets import java.util.{Locale, TimeZone} import com.fasterxml.jackson.core.{JsonFactory, JsonParser} @@ -85,6 +86,16 @@ private[sql] class JSONOptions( val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) + val lineSeparator: Option[String] = parameters.get("lineSep").map { sep => + require(sep.nonEmpty, "'lineSep' cannot be an empty string.") + sep + } + // Note that the option 'lineSep' uses a different default value in read and write. + val lineSeparatorInRead: Option[Array[Byte]] = + lineSeparator.map(_.getBytes(StandardCharsets.UTF_8)) + // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8. + val lineSeparatorInWrite: String = lineSeparator.getOrElse("\n") + /** Sets config options on a Jackson [[JsonFactory]]. */ def setJacksonOptions(factory: JsonFactory): Unit = { factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index eb06e4f304f0a..9c413de752a8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.json import java.io.Writer +import java.nio.charset.StandardCharsets import com.fasterxml.jackson.core._ @@ -74,6 +75,8 @@ private[sql] class JacksonGenerator( private val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) + private val lineSeparator: String = options.lineSeparatorInWrite + private def makeWriter(dataType: DataType): ValueWriter = dataType match { case NullType => (row: SpecializedGetters, ordinal: Int) => @@ -251,5 +254,8 @@ private[sql] class JacksonGenerator( mapType = dataType.asInstanceOf[MapType])) } - def writeLineEnding(): Unit = gen.writeRaw('\n') + def writeLineEnding(): Unit = { + // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8. + gen.writeRaw(lineSeparator) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 1a5e47508c070..ae3ba1690f696 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -366,6 +366,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * `java.text.SimpleDateFormat`. This applies to timestamp type. *
  • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
  • + *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index bb93889dc55e9..bbc063148a72c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -518,6 +518,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • + *
  • `lineSep` (default `\n`): defines the line separator that should + * be used for writing.
  • * * * @since 1.4.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 77e7edc8e7a20..5769c09c9a1d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.text.TextFileFormat +import org.apache.spark.sql.execution.datasources.text.{TextFileFormat, TextOptions} import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -92,7 +92,8 @@ object TextInputJsonDataSource extends JsonDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): StructType = { - val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths) + val json: Dataset[String] = createBaseDataset( + sparkSession, inputPaths, parsedOptions.lineSeparator) inferFromDataset(json, parsedOptions) } @@ -104,13 +105,19 @@ object TextInputJsonDataSource extends JsonDataSource { private def createBaseDataset( sparkSession: SparkSession, - inputPaths: Seq[FileStatus]): Dataset[String] = { + inputPaths: Seq[FileStatus], + lineSeparator: Option[String]): Dataset[String] = { + val textOptions = lineSeparator.map { lineSep => + Map(TextOptions.LINE_SEPARATOR -> lineSep) + }.getOrElse(Map.empty[String, String]) + val paths = inputPaths.map(_.getPath.toString) sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, paths = paths, - className = classOf[TextFileFormat].getName + className = classOf[TextFileFormat].getName, + options = textOptions ).resolveRelation(checkFilesExist = false)) .select("value").as(Encoders.STRING) } @@ -120,7 +127,7 @@ object TextInputJsonDataSource extends JsonDataSource { file: PartitionedFile, parser: JacksonParser, schema: StructType): Iterator[InternalRow] = { - val linesReader = new HadoopFileLinesReader(file, conf) + val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) val safeParser = new FailureSafeParser[Text]( input => parser.parse(input, CreateJacksonParser.text, textToUTF8String), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala index 18698df9fd8e5..5c1a35434f7b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala @@ -52,7 +52,7 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti lineSeparatorInRead.getOrElse("\n".getBytes(StandardCharsets.UTF_8)) } -private[text] object TextOptions { +private[datasources] object TextOptions { val COMPRESSION = "compression" val WHOLETEXT = "wholetext" val LINE_SEPARATOR = "lineSep" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 9b17406a816b5..ae93965bc50ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -268,6 +268,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * `java.text.SimpleDateFormat`. This applies to timestamp type. *
  • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
  • + *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 8c8d41ebf115a..10bac0554484a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.nio.charset.StandardCharsets +import java.nio.file.Files import java.sql.{Date, Timestamp} import java.util.Locale @@ -27,7 +28,7 @@ import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, TestUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{functions => F, _} import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} @@ -2063,4 +2064,67 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } + + def testLineSeparator(lineSep: String): Unit = { + test(s"SPARK-21289: Support line separator - lineSep: '$lineSep'") { + // Read + val data = + s""" + | {"f": + |"a", "f0": 1}$lineSep{"f": + | + |"c", "f0": 2}$lineSep{"f": "d", "f0": 3} + """.stripMargin + val dataWithTrailingLineSep = s"$data$lineSep" + + Seq(data, dataWithTrailingLineSep).foreach { lines => + withTempPath { path => + Files.write(path.toPath, lines.getBytes(StandardCharsets.UTF_8)) + val df = spark.read.option("lineSep", lineSep).json(path.getAbsolutePath) + val expectedSchema = + StructType(StructField("f", StringType) :: StructField("f0", LongType) :: Nil) + checkAnswer(df, Seq(("a", 1), ("c", 2), ("d", 3)).toDF()) + assert(df.schema === expectedSchema) + } + } + + // Write + withTempPath { path => + Seq("a", "b", "c").toDF("value").coalesce(1) + .write.option("lineSep", lineSep).json(path.getAbsolutePath) + val partFile = TestUtils.recursiveList(path).filter(f => f.getName.startsWith("part-")).head + val readBack = new String(Files.readAllBytes(partFile.toPath), StandardCharsets.UTF_8) + assert( + readBack === s"""{"value":"a"}$lineSep{"value":"b"}$lineSep{"value":"c"}$lineSep""") + } + + // Roundtrip + withTempPath { path => + val df = Seq("a", "b", "c").toDF() + df.write.option("lineSep", lineSep).json(path.getAbsolutePath) + val readBack = spark.read.option("lineSep", lineSep).json(path.getAbsolutePath) + checkAnswer(df, readBack) + } + } + } + + // scalastyle:off nonascii + Seq("|", "^", "::", "!!!@3", 0x1E.toChar.toString, "아").foreach { lineSep => + testLineSeparator(lineSep) + } + // scalastyle:on nonascii + + test("""SPARK-21289: Support line separator - default value \r, \r\n and \n""") { + val data = + "{\"f\": \"a\", \"f0\": 1}\r{\"f\": \"c\", \"f0\": 2}\r\n{\"f\": \"d\", \"f0\": 3}\n" + + withTempPath { path => + Files.write(path.toPath, data.getBytes(StandardCharsets.UTF_8)) + val df = spark.read.json(path.getAbsolutePath) + val expectedSchema = + StructType(StructField("f", StringType) :: StructField("f0", LongType) :: Nil) + checkAnswer(df, Seq(("a", 1), ("c", 2), ("d", 3)).toDF()) + assert(df.schema === expectedSchema) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index e8a5299d6ba9d..0e7f3afa9c3ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -208,9 +208,11 @@ class TextSuite extends QueryTest with SharedSQLContext { } } - Seq("|", "^", "::", "!!!@3", 0x1E.toChar.toString).foreach { lineSep => + // scalastyle:off nonascii + Seq("|", "^", "::", "!!!@3", 0x1E.toChar.toString, "아").foreach { lineSep => testLineSeparator(lineSep) } + // scalastyle:on nonascii private def testFile: String = { Thread.currentThread().getContextClassLoader.getResource("test-data/text-suite.txt").toString From 761565a3ccbf7f083e587fee14a27b61867a3886 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 28 Mar 2018 09:11:52 -0700 Subject: [PATCH 0531/2461] Revert "[SPARK-23096][SS] Migrate rate source to V2" This reverts commit c68ec4e6a1ed9ea13345c7705ea60ff4df7aec7b. --- ...pache.spark.sql.sources.DataSourceRegister | 3 +- .../execution/datasources/DataSource.scala | 6 +- .../streaming/RateSourceProvider.scala | 262 +++++++++++++ .../ContinuousRateStreamSource.scala | 25 +- .../sources/RateStreamMicroBatchReader.scala | 222 ----------- .../sources/RateStreamProvider.scala | 125 ------- .../sources/RateStreamSourceV2.scala | 187 ++++++++++ .../execution/streaming/RateSourceSuite.scala | 194 ++++++++++ .../streaming/RateSourceV2Suite.scala | 191 ++++++++++ .../sources/RateStreamProviderSuite.scala | 344 ------------------ 10 files changed, 844 insertions(+), 715 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 1b37905543b4e..1fe9c093af99f 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,5 +5,6 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider -org.apache.spark.sql.execution.streaming.sources.RateStreamProvider +org.apache.spark.sql.execution.streaming.RateSourceProvider org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider +org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b84ea769808f9..31fa89b4570a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider} +import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode @@ -566,7 +566,6 @@ object DataSource extends Logging { val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat" val nativeOrc = classOf[OrcFileFormat].getCanonicalName val socket = classOf[TextSocketSourceProvider].getCanonicalName - val rate = classOf[RateStreamProvider].getCanonicalName Map( "org.apache.spark.sql.jdbc" -> jdbc, @@ -588,8 +587,7 @@ object DataSource extends Logging { "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm, "org.apache.spark.ml.source.libsvm" -> libsvm, "com.databricks.spark.csv" -> csv, - "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket, - "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate + "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala new file mode 100644 index 0000000000000..649fbbfa184ec --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.io._ +import java.nio.charset.StandardCharsets +import java.util.Optional +import java.util.concurrent.TimeUnit + +import org.apache.commons.io.IOUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader +import org.apache.spark.sql.types._ +import org.apache.spark.util.{ManualClock, SystemClock} + +/** + * A source that generates increment long values with timestamps. Each generated row has two + * columns: a timestamp column for the generated time and an auto increment long column starting + * with 0L. + * + * This source supports the following options: + * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. + * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed + * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer + * seconds. + * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the + * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may + * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. + */ +class RateSourceProvider extends StreamSourceProvider with DataSourceRegister + with DataSourceV2 with ContinuousReadSupport { + + override def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = { + if (schema.nonEmpty) { + throw new AnalysisException("The rate source does not support a user-specified schema.") + } + + (shortName(), RateSourceProvider.SCHEMA) + } + + override def createSource( + sqlContext: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + val params = CaseInsensitiveMap(parameters) + + val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L) + if (rowsPerSecond <= 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " + + "must be positive") + } + + val rampUpTimeSeconds = + params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L) + if (rampUpTimeSeconds < 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " + + "must not be negative") + } + + val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse( + sqlContext.sparkContext.defaultParallelism) + if (numPartitions <= 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("numPartitions")}'. The option 'numPartitions' " + + "must be positive") + } + + new RateStreamSource( + sqlContext, + metadataPath, + rowsPerSecond, + rampUpTimeSeconds, + numPartitions, + params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing + ) + } + + override def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): ContinuousReader = { + new RateStreamContinuousReader(options) + } + + override def shortName(): String = "rate" +} + +object RateSourceProvider { + val SCHEMA = + StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) + + val VERSION = 1 +} + +class RateStreamSource( + sqlContext: SQLContext, + metadataPath: String, + rowsPerSecond: Long, + rampUpTimeSeconds: Long, + numPartitions: Int, + useManualClock: Boolean) extends Source with Logging { + + import RateSourceProvider._ + import RateStreamSource._ + + val clock = if (useManualClock) new ManualClock else new SystemClock + + private val maxSeconds = Long.MaxValue / rowsPerSecond + + if (rampUpTimeSeconds > maxSeconds) { + throw new ArithmeticException( + s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + + s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") + } + + private val startTimeMs = { + val metadataLog = + new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) { + override def serialize(metadata: LongOffset, out: OutputStream): Unit = { + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): LongOffset = { + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. + assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) + LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } + } + + metadataLog.get(0).getOrElse { + val offset = LongOffset(clock.getTimeMillis()) + metadataLog.add(0, offset) + logInfo(s"Start time: $offset") + offset + }.offset + } + + /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */ + @volatile private var lastTimeMs = startTimeMs + + override def schema: StructType = RateSourceProvider.SCHEMA + + override def getOffset: Option[Offset] = { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs))) + } + + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L) + val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) + assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") + if (endSeconds > maxSeconds) { + throw new ArithmeticException("Integer overflow. Max offset with " + + s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") + } + // Fix "lastTimeMs" for recovery + if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) { + lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs + } + val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) + val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) + logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + + s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") + + if (rangeStart == rangeEnd) { + return sqlContext.internalCreateDataFrame( + sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) + } + + val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) + val relativeMsPerValue = + TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) + + val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v => + val relative = math.round((v - rangeStart) * relativeMsPerValue) + InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v) + } + sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) + } + + override def stop(): Unit = {} + + override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " + + s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]" +} + +object RateStreamSource { + + /** Calculate the end value we will emit at the time `seconds`. */ + def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { + // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 + // Then speedDeltaPerSecond = 2 + // + // seconds = 0 1 2 3 4 5 6 + // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) + // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 + val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) + if (seconds <= rampUpTimeSeconds) { + // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to + // avoid overflow + if (seconds % 2 == 1) { + (seconds + 1) / 2 * speedDeltaPerSecond * seconds + } else { + seconds / 2 * speedDeltaPerSecond * (seconds + 1) + } + } else { + // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds + val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) + rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 2f0de2612c150..20d90069163a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -24,8 +24,8 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider +import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair} +import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} @@ -40,8 +40,8 @@ class RateStreamContinuousReader(options: DataSourceOptions) val creationTime = System.currentTimeMillis() - val numPartitions = options.get(RateStreamProvider.NUM_PARTITIONS).orElse("5").toInt - val rowsPerSecond = options.get(RateStreamProvider.ROWS_PER_SECOND).orElse("6").toLong + val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt + val rowsPerSecond = options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong val perPartitionRate = rowsPerSecond.toDouble / numPartitions.toDouble override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { @@ -57,12 +57,12 @@ class RateStreamContinuousReader(options: DataSourceOptions) RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def readSchema(): StructType = RateStreamProvider.SCHEMA + override def readSchema(): StructType = RateSourceProvider.SCHEMA private var offset: Offset = _ override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { - this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime)) + this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime)) } override def getStartOffset(): Offset = offset @@ -98,19 +98,6 @@ class RateStreamContinuousReader(options: DataSourceOptions) override def commit(end: Offset): Unit = {} override def stop(): Unit = {} - private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { - RateStreamOffset( - Range(0, numPartitions).map { i => - // Note that the starting offset is exclusive, so we have to decrement the starting value - // by the increment that will later be applied. The first row output in each - // partition will have a value equal to the partition index. - (i, - ValueRunTimeMsPair( - (i - numPartitions).toLong, - creationTimeMs)) - }.toMap) - } - } case class RateStreamContinuousDataReaderFactory( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala deleted file mode 100644 index 6cf8520fc544f..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ /dev/null @@ -1,222 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.io._ -import java.nio.charset.StandardCharsets -import java.util.Optional -import java.util.concurrent.TimeUnit - -import scala.collection.JavaConverters._ - -import org.apache.commons.io.IOUtils - -import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.{ManualClock, SystemClock} - -class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) - extends MicroBatchReader with Logging { - import RateStreamProvider._ - - private[sources] val clock = { - // The option to use a manual clock is provided only for unit testing purposes. - if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock - } - - private val rowsPerSecond = - options.get(ROWS_PER_SECOND).orElse("1").toLong - - private val rampUpTimeSeconds = - Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String])) - .map(JavaUtils.timeStringAsSec(_)) - .getOrElse(0L) - - private val maxSeconds = Long.MaxValue / rowsPerSecond - - if (rampUpTimeSeconds > maxSeconds) { - throw new ArithmeticException( - s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + - s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") - } - - private[sources] val creationTimeMs = { - val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession) - require(session.isDefined) - - val metadataLog = - new HDFSMetadataLog[LongOffset](session.get, checkpointLocation) { - override def serialize(metadata: LongOffset, out: OutputStream): Unit = { - val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) - writer.write("v" + VERSION + "\n") - writer.write(metadata.json) - writer.flush - } - - override def deserialize(in: InputStream): LongOffset = { - val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) - // HDFSMetadataLog guarantees that it never creates a partial file. - assert(content.length != 0) - if (content(0) == 'v') { - val indexOfNewLine = content.indexOf("\n") - if (indexOfNewLine > 0) { - parseVersion(content.substring(0, indexOfNewLine), VERSION) - LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } - } - - metadataLog.get(0).getOrElse { - val offset = LongOffset(clock.getTimeMillis()) - metadataLog.add(0, offset) - logInfo(s"Start time: $offset") - offset - }.offset - } - - @volatile private var lastTimeMs: Long = creationTimeMs - - private var start: LongOffset = _ - private var end: LongOffset = _ - - override def readSchema(): StructType = SCHEMA - - override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { - this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset] - this.end = end.orElse { - val now = clock.getTimeMillis() - if (lastTimeMs < now) { - lastTimeMs = now - } - LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) - }.asInstanceOf[LongOffset] - } - - override def getStartOffset(): Offset = { - if (start == null) throw new IllegalStateException("start offset not set") - start - } - override def getEndOffset(): Offset = { - if (end == null) throw new IllegalStateException("end offset not set") - end - } - - override def deserializeOffset(json: String): Offset = { - LongOffset(json.toLong) - } - - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { - val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) - val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) - assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") - if (endSeconds > maxSeconds) { - throw new ArithmeticException("Integer overflow. Max offset with " + - s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") - } - // Fix "lastTimeMs" for recovery - if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs) { - lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs - } - val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) - val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) - logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + - s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") - - if (rangeStart == rangeEnd) { - return List.empty.asJava - } - - val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) - val relativeMsPerValue = - TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) - val numPartitions = { - val activeSession = SparkSession.getActiveSession - require(activeSession.isDefined) - Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String])) - .map(_.toInt) - .getOrElse(activeSession.get.sparkContext.defaultParallelism) - } - - (0 until numPartitions).map { p => - new RateStreamMicroBatchDataReaderFactory( - p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) - : DataReaderFactory[Row] - }.toList.asJava - } - - override def commit(end: Offset): Unit = {} - - override def stop(): Unit = {} - - override def toString: String = s"MicroBatchRateSource[rowsPerSecond=$rowsPerSecond, " + - s"rampUpTimeSeconds=$rampUpTimeSeconds, " + - s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" -} - -class RateStreamMicroBatchDataReaderFactory( - partitionId: Int, - numPartitions: Int, - rangeStart: Long, - rangeEnd: Long, - localStartTimeMs: Long, - relativeMsPerValue: Double) extends DataReaderFactory[Row] { - - override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader( - partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) -} - -class RateStreamMicroBatchDataReader( - partitionId: Int, - numPartitions: Int, - rangeStart: Long, - rangeEnd: Long, - localStartTimeMs: Long, - relativeMsPerValue: Double) extends DataReader[Row] { - private var count = 0 - - override def next(): Boolean = { - rangeStart + partitionId + numPartitions * count < rangeEnd - } - - override def get(): Row = { - val currValue = rangeStart + partitionId + numPartitions * count - count += 1 - val relative = math.round((currValue - rangeStart) * relativeMsPerValue) - Row( - DateTimeUtils.toJavaTimestamp( - DateTimeUtils.fromMillis(relative + localStartTimeMs)), - currValue - ) - } - - override def close(): Unit = {} -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala deleted file mode 100644 index 6bdd492f0cb35..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.util.Optional - -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader} -import org.apache.spark.sql.types._ - -/** - * A source that generates increment long values with timestamps. Each generated row has two - * columns: a timestamp column for the generated time and an auto increment long column starting - * with 0L. - * - * This source supports the following options: - * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. - * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed - * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer - * seconds. - * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the - * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may - * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. - */ -class RateStreamProvider extends DataSourceV2 - with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister { - import RateStreamProvider._ - - override def createMicroBatchReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = { - if (options.get(ROWS_PER_SECOND).isPresent) { - val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong - if (rowsPerSecond <= 0) { - throw new IllegalArgumentException( - s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive") - } - } - - if (options.get(RAMP_UP_TIME).isPresent) { - val rampUpTimeSeconds = - JavaUtils.timeStringAsSec(options.get(RAMP_UP_TIME).get()) - if (rampUpTimeSeconds < 0) { - throw new IllegalArgumentException( - s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative") - } - } - - if (options.get(NUM_PARTITIONS).isPresent) { - val numPartitions = options.get(NUM_PARTITIONS).get().toInt - if (numPartitions <= 0) { - throw new IllegalArgumentException( - s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive") - } - } - - if (schema.isPresent) { - throw new AnalysisException("The rate source does not support a user-specified schema.") - } - - new RateStreamMicroBatchReader(options, checkpointLocation) - } - - override def createContinuousReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options) - - override def shortName(): String = "rate" -} - -object RateStreamProvider { - val SCHEMA = - StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) - - val VERSION = 1 - - val NUM_PARTITIONS = "numPartitions" - val ROWS_PER_SECOND = "rowsPerSecond" - val RAMP_UP_TIME = "rampUpTime" - - /** Calculate the end value we will emit at the time `seconds`. */ - def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { - // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 - // Then speedDeltaPerSecond = 2 - // - // seconds = 0 1 2 3 4 5 6 - // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) - // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 - val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) - if (seconds <= rampUpTimeSeconds) { - // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to - // avoid overflow - if (seconds % 2 == 1) { - (seconds + 1) / 2 * speedDeltaPerSecond * seconds - } else { - seconds / 2 * speedDeltaPerSecond * (seconds + 1) - } - } else { - // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds - val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) - rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala new file mode 100644 index 0000000000000..4e2459bb05bd6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.util.Optional + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.json4s.DefaultFormats +import org.json4s.jackson.Serialization + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} +import org.apache.spark.util.{ManualClock, SystemClock} + +/** + * This is a temporary register as we build out v2 migration. Microbatch read support should + * be implemented in the same register as v1. + */ +class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with DataSourceRegister { + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): MicroBatchReader = { + new RateStreamMicroBatchReader(options) + } + + override def shortName(): String = "ratev2" +} + +class RateStreamMicroBatchReader(options: DataSourceOptions) + extends MicroBatchReader { + implicit val defaultFormats: DefaultFormats = DefaultFormats + + val clock = { + // The option to use a manual clock is provided only for unit testing purposes. + if (options.get("useManualClock").orElse("false").toBoolean) new ManualClock + else new SystemClock + } + + private val numPartitions = + options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt + private val rowsPerSecond = + options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong + + // The interval (in milliseconds) between rows in each partition. + // e.g. if there are 4 global rows per second, and 2 partitions, each partition + // should output rows every (1000 * 2 / 4) = 500 ms. + private val msPerPartitionBetweenRows = (1000 * numPartitions) / rowsPerSecond + + override def readSchema(): StructType = { + StructType( + StructField("timestamp", TimestampType, false) :: + StructField("value", LongType, false) :: Nil) + } + + val creationTimeMs = clock.getTimeMillis() + + private var start: RateStreamOffset = _ + private var end: RateStreamOffset = _ + + override def setOffsetRange( + start: Optional[Offset], + end: Optional[Offset]): Unit = { + this.start = start.orElse( + RateStreamSourceV2.createInitialOffset(numPartitions, creationTimeMs)) + .asInstanceOf[RateStreamOffset] + + this.end = end.orElse { + val currentTime = clock.getTimeMillis() + RateStreamOffset( + this.start.partitionToValueAndRunTimeMs.map { + case startOffset @ (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => + // Calculate the number of rows we should advance in this partition (based on the + // current time), and output a corresponding offset. + val readInterval = currentTime - currentReadTime + val numNewRows = readInterval / msPerPartitionBetweenRows + if (numNewRows <= 0) { + startOffset + } else { + (part, ValueRunTimeMsPair( + currentVal + (numNewRows * numPartitions), + currentReadTime + (numNewRows * msPerPartitionBetweenRows))) + } + } + ) + }.asInstanceOf[RateStreamOffset] + } + + override def getStartOffset(): Offset = { + if (start == null) throw new IllegalStateException("start offset not set") + start + } + override def getEndOffset(): Offset = { + if (end == null) throw new IllegalStateException("end offset not set") + end + } + + override def deserializeOffset(json: String): Offset = { + RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) + } + + override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + val startMap = start.partitionToValueAndRunTimeMs + val endMap = end.partitionToValueAndRunTimeMs + endMap.keys.toSeq.map { part => + val ValueRunTimeMsPair(endVal, _) = endMap(part) + val ValueRunTimeMsPair(startVal, startTimeMs) = startMap(part) + + val packedRows = mutable.ListBuffer[(Long, Long)]() + var outVal = startVal + numPartitions + var outTimeMs = startTimeMs + while (outVal <= endVal) { + packedRows.append((outTimeMs, outVal)) + outVal += numPartitions + outTimeMs += msPerPartitionBetweenRows + } + + RateStreamBatchTask(packedRows).asInstanceOf[DataReaderFactory[Row]] + }.toList.asJava + } + + override def commit(end: Offset): Unit = {} + override def stop(): Unit = {} +} + +case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactory[Row] { + override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals) +} + +class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] { + private var currentIndex = -1 + + override def next(): Boolean = { + // Return true as long as the new index is in the seq. + currentIndex += 1 + currentIndex < vals.size + } + + override def get(): Row = { + Row( + DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(vals(currentIndex)._1)), + vals(currentIndex)._2) + } + + override def close(): Unit = {} +} + +object RateStreamSourceV2 { + val NUM_PARTITIONS = "numPartitions" + val ROWS_PER_SECOND = "rowsPerSecond" + + private[sql] def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { + RateStreamOffset( + Range(0, numPartitions).map { i => + // Note that the starting offset is exclusive, so we have to decrement the starting value + // by the increment that will later be applied. The first row output in each + // partition will have a value equal to the partition index. + (i, + ValueRunTimeMsPair( + (i - numPartitions).toLong, + creationTimeMs)) + }.toMap) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala new file mode 100644 index 0000000000000..03d0f63fa4d7f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.TimeUnit + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} +import org.apache.spark.util.ManualClock + +class RateSourceSuite extends StreamTest { + + import testImplicits._ + + case class AdvanceRateManualClock(seconds: Long) extends AddData { + override def addData(query: Option[StreamExecution]): (Source, Offset) = { + assert(query.nonEmpty) + val rateSource = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] => + source.asInstanceOf[RateStreamSource] + }.head + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) + (rateSource, rateSource.getOffset.get) + } + } + + test("basic") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("useManualClock", "true") + .load() + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + StopStream, + StartStream(), + // Advance 2 seconds because creating a new RateSource will also create a new ManualClock + AdvanceRateManualClock(seconds = 2), + CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + + test("uniform distribution of event timestamps") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "1500") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + val expectedAnswer = (0 until 1500).map { v => + (math.round(v * (1000.0 / 1500)), v) + } + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch(expectedAnswer: _*) + ) + } + + test("valueAtSecond") { + import RateStreamSource._ + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12) + assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20) + assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30) + } + + test("rampUpTime") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("rampUpTime", "4s") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch({ + Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11) + }: _*), // speed = 6 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8 + AdvanceRateManualClock(seconds = 1), + // Now we should reach full speed + CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10 + ) + } + + test("numPartitions") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("numPartitions", "6") + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(1), + CheckLastBatch((0 until 6): _*) + ) + } + + testQuietly("overflow") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", Long.MaxValue.toString) + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(2), + ExpectFailure[ArithmeticException](t => { + Seq("overflow", "rowsPerSecond").foreach { msg => + assert(t.getMessage.contains(msg)) + } + }) + ) + } + + testQuietly("illegal option values") { + def testIllegalOptionValue( + option: String, + value: String, + expectedMessages: Seq[String]): Unit = { + val e = intercept[StreamingQueryException] { + spark.readStream + .format("rate") + .option(option, value) + .load() + .writeStream + .format("console") + .start() + .awaitTermination() + } + assert(e.getCause.isInstanceOf[IllegalArgumentException]) + for (msg <- expectedMessages) { + assert(e.getCause.getMessage.contains(msg)) + } + } + + testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive")) + testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) + } + + test("user-specified schema given") { + val exception = intercept[AnalysisException] { + spark.readStream + .format("rate") + .schema(spark.range(1).schema) + .load() + } + assert(exception.getMessage.contains( + "rate source does not support a user-specified schema")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala new file mode 100644 index 0000000000000..983ba1668f58f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.Optional +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.util.ManualClock + +class RateSourceV2Suite extends StreamTest { + import testImplicits._ + + case class AdvanceRateManualClock(seconds: Long) extends AddData { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + assert(query.nonEmpty) + val rateSource = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source + }.head + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) + rateSource.setOffsetRange(Optional.empty(), Optional.empty()) + (rateSource, rateSource.getEndOffset()) + } + } + + test("microbatch in registry") { + DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamMicroBatchReader]) + case _ => + throw new IllegalStateException("Could not find v2 read support for rate") + } + } + + test("basic microbatch execution") { + val input = spark.readStream + .format("rateV2") + .option("numPartitions", "1") + .option("rowsPerSecond", "10") + .option("useManualClock", "true") + .load() + testStream(input, useV2Sink = true)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + StopStream, + StartStream(), + // Advance 2 seconds because creating a new RateSource will also create a new ManualClock + AdvanceRateManualClock(seconds = 2), + CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + + test("microbatch - numPartitions propagated") { + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) + reader.setOffsetRange(Optional.empty(), Optional.empty()) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 11) + } + + test("microbatch - set offset") { + val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty()) + val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) + val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000)))) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + assert(reader.getStartOffset() == startOffset) + assert(reader.getEndOffset() == endOffset) + } + + test("microbatch - infer offsets") { + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) + reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100) + reader.setOffsetRange(Optional.empty(), Optional.empty()) + reader.getStartOffset() match { + case r: RateStreamOffset => + assert(r.partitionToValueAndRunTimeMs(0).runTimeMs == reader.creationTimeMs) + case _ => throw new IllegalStateException("unexpected offset type") + } + reader.getEndOffset() match { + case r: RateStreamOffset => + // End offset may be a bit beyond 100 ms/9 rows after creation if the wait lasted + // longer than 100ms. It should never be early. + assert(r.partitionToValueAndRunTimeMs(0).value >= 9) + assert(r.partitionToValueAndRunTimeMs(0).runTimeMs >= reader.creationTimeMs + 100) + + case _ => throw new IllegalStateException("unexpected offset type") + } + } + + test("microbatch - predetermined batch size") { + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) + val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) + val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 1) + assert(tasks.get(0).asInstanceOf[RateStreamBatchTask].vals.size == 20) + } + + test("microbatch - data read") { + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) + val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) + val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { + case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => + (part, ValueRunTimeMsPair(currentVal + 33, currentReadTime + 1000)) + }.toMap) + + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 11) + + val readData = tasks.asScala + .map(_.createDataReader()) + .flatMap { reader => + val buf = scala.collection.mutable.ListBuffer[Row]() + while (reader.next()) buf.append(reader.get()) + buf + } + + assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) + } + + test("continuous in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: ContinuousReadSupport => + val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamContinuousReader]) + case _ => + throw new IllegalStateException("Could not find v2 read support for rate") + } + } + + test("continuous data") { + val reader = new RateStreamContinuousReader( + new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) + reader.setStartOffset(Optional.empty()) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 2) + + val data = scala.collection.mutable.ListBuffer[Row]() + tasks.asScala.foreach { + case t: RateStreamContinuousDataReaderFactory => + val startTimeMs = reader.getStartOffset() + .asInstanceOf[RateStreamOffset] + .partitionToValueAndRunTimeMs(t.partitionIndex) + .runTimeMs + val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] + for (rowIndex <- 0 to 9) { + r.next() + data.append(r.get()) + assert(r.getOffset() == + RateStreamPartitionOffset( + t.partitionIndex, + t.partitionIndex + rowIndex * 2, + startTimeMs + (rowIndex + 1) * 100)) + } + assert(System.currentTimeMillis() >= startTimeMs + 1000) + + case _ => throw new IllegalStateException("Unexpected task type") + } + + assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala deleted file mode 100644 index 9149e50962255..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ /dev/null @@ -1,344 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.nio.file.Files -import java.util.Optional -import java.util.concurrent.TimeUnit - -import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.Offset -import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.util.ManualClock - -class RateSourceSuite extends StreamTest { - - import testImplicits._ - - protected override def beforeAll(): Unit = { - super.beforeAll() - SparkSession.setActiveSession(spark) - } - - override def afterAll(): Unit = { - SparkSession.clearActiveSession() - super.afterAll() - } - - case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { - assert(query.nonEmpty) - val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source - }.head - - rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) - val offset = LongOffset(TimeUnit.MILLISECONDS.toSeconds( - rateSource.clock.getTimeMillis() - rateSource.creationTimeMs)) - (rateSource, offset) - } - } - - test("microbatch in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - val reader = ds.createMicroBatchReader(Optional.empty(), "dummy", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamMicroBatchReader]) - case _ => - throw new IllegalStateException("Could not find read support for rate") - } - } - - test("compatible with old path in registry") { - DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", - spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - assert(ds.isInstanceOf[RateStreamProvider]) - case _ => - throw new IllegalStateException("Could not find read support for rate") - } - } - - test("microbatch - basic") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("useManualClock", "true") - .load() - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), - StopStream, - StartStream(), - // Advance 2 seconds because creating a new RateSource will also create a new ManualClock - AdvanceRateManualClock(seconds = 2), - CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) - ) - } - - test("microbatch - uniform distribution of event timestamps") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "1500") - .option("useManualClock", "true") - .load() - .as[(java.sql.Timestamp, Long)] - .map(v => (v._1.getTime, v._2)) - val expectedAnswer = (0 until 1500).map { v => - (math.round(v * (1000.0 / 1500)), v) - } - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch(expectedAnswer: _*) - ) - } - - test("microbatch - set offset") { - val temp = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - assert(reader.getStartOffset() == startOffset) - assert(reader.getEndOffset() == endOffset) - } - - test("microbatch - infer offsets") { - val tempFolder = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions( - Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), - tempFolder) - reader.clock.asInstanceOf[ManualClock].advance(100000) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - reader.getStartOffset() match { - case r: LongOffset => assert(r.offset === 0L) - case _ => throw new IllegalStateException("unexpected offset type") - } - reader.getEndOffset() match { - case r: LongOffset => assert(r.offset >= 100) - case _ => throw new IllegalStateException("unexpected offset type") - } - } - - test("microbatch - predetermined batch size") { - val temp = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 1) - val dataReader = tasks.get(0).createDataReader() - val data = ArrayBuffer[Row]() - while (dataReader.next()) { - data.append(dataReader.get()) - } - assert(data.size === 20) - } - - test("microbatch - data read") { - val temp = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 11) - - val readData = tasks.asScala - .map(_.createDataReader()) - .flatMap { reader => - val buf = scala.collection.mutable.ListBuffer[Row]() - while (reader.next()) buf.append(reader.get()) - buf - } - - assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) - } - - test("valueAtSecond") { - import RateStreamProvider._ - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1) - assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3) - assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8) - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2) - assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6) - assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12) - assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20) - assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30) - } - - test("rampUpTime") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("rampUpTime", "4s") - .option("useManualClock", "true") - .load() - .as[(java.sql.Timestamp, Long)] - .map(v => (v._1.getTime, v._2)) - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch({ - Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11) - }: _*), // speed = 6 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8 - AdvanceRateManualClock(seconds = 1), - // Now we should reach full speed - CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10 - ) - } - - test("numPartitions") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("numPartitions", "6") - .option("useManualClock", "true") - .load() - .select(spark_partition_id()) - .distinct() - testStream(input)( - AdvanceRateManualClock(1), - CheckLastBatch((0 until 6): _*) - ) - } - - testQuietly("overflow") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", Long.MaxValue.toString) - .option("useManualClock", "true") - .load() - .select(spark_partition_id()) - .distinct() - testStream(input)( - AdvanceRateManualClock(2), - ExpectFailure[TreeNodeException[_]](t => { - Seq("overflow", "rowsPerSecond").foreach { msg => - assert(t.getCause.getMessage.contains(msg)) - } - }) - ) - } - - testQuietly("illegal option values") { - def testIllegalOptionValue( - option: String, - value: String, - expectedMessages: Seq[String]): Unit = { - val e = intercept[IllegalArgumentException] { - spark.readStream - .format("rate") - .option(option, value) - .load() - .writeStream - .format("console") - .start() - .awaitTermination() - } - for (msg <- expectedMessages) { - assert(e.getMessage.contains(msg)) - } - } - - testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive")) - testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) - } - - test("user-specified schema given") { - val exception = intercept[AnalysisException] { - spark.readStream - .format("rate") - .schema(spark.range(1).schema) - .load() - } - assert(exception.getMessage.contains( - "rate source does not support a user-specified schema")) - } - - test("continuous in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: ContinuousReadSupport => - val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamContinuousReader]) - case _ => - throw new IllegalStateException("Could not find read support for continuous rate") - } - } - - test("continuous data") { - val reader = new RateStreamContinuousReader( - new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 2) - - val data = scala.collection.mutable.ListBuffer[Row]() - tasks.asScala.foreach { - case t: RateStreamContinuousDataReaderFactory => - val startTimeMs = reader.getStartOffset() - .asInstanceOf[RateStreamOffset] - .partitionToValueAndRunTimeMs(t.partitionIndex) - .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] - for (rowIndex <- 0 to 9) { - r.next() - data.append(r.get()) - assert(r.getOffset() == - RateStreamPartitionOffset( - t.partitionIndex, - t.partitionIndex + rowIndex * 2, - startTimeMs + (rowIndex + 1) * 100)) - } - assert(System.currentTimeMillis() >= startTimeMs + 1000) - - case _ => throw new IllegalStateException("Unexpected task type") - } - - assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) - } -} From ea2fdc0d286e449884de44f22a908a26ab1248a5 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Wed, 28 Mar 2018 19:49:32 -0500 Subject: [PATCH 0532/2461] [SPARK-23675][WEB-UI] Title add spark logo, use spark logo image ## What changes were proposed in this pull request? Title add spark logo, use spark logo image. reference other big data system ui, so i think spark should add it. spark fix before: ![spark_fix_before](https://user-images.githubusercontent.com/26266482/37387866-2d5add0e-2799-11e8-9165-250f2b59df3f.png) spark fix after: ![spark_fix_after](https://user-images.githubusercontent.com/26266482/37387874-329e1876-2799-11e8-8bc5-c619fc1e680e.png) reference kafka ui: ![kafka](https://user-images.githubusercontent.com/26266482/37387878-35ca89d0-2799-11e8-834e-1598ae7158e1.png) reference storm ui: ![storm](https://user-images.githubusercontent.com/26266482/37387880-3854f12c-2799-11e8-8968-b428ba361995.png) reference yarn ui: ![yarn](https://user-images.githubusercontent.com/26266482/37387881-3a72e130-2799-11e8-97bb-dea85f573e95.png) reference nifi ui: ![nifi](https://user-images.githubusercontent.com/26266482/37387887-3cecfea0-2799-11e8-9a71-6c454d25840b.png) reference flink ui: ![flink](https://user-images.githubusercontent.com/26266482/37387888-3f16b1ee-2799-11e8-9d37-8355f0100548.png) ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #20818 from guoxiaolongzte/SPARK-23675. --- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index ba798df13c95d..02cf19e00ecde 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -224,6 +224,7 @@ private[spark] object UIUtils extends Logging { {commonHeaderNodes} {if (showVisualization) vizHeaderNodes else Seq.empty} {if (useDataTables) dataTablesHeaderNodes else Seq.empty} + {appName} - {title} @@ -265,6 +266,7 @@ private[spark] object UIUtils extends Logging { {commonHeaderNodes} {if (useDataTables) dataTablesHeaderNodes else Seq.empty} + {title} From 641aec68e8167546dbb922874c086c9b90198f08 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Thu, 29 Mar 2018 16:37:46 +0800 Subject: [PATCH 0533/2461] =?UTF-8?q?[SPARK-23806]=20Broadcast.unpersist?= =?UTF-8?q?=20can=20cause=20fatal=20exception=20when=20used=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … with dynamic allocation ## What changes were proposed in this pull request? ignore errors when you are waiting for a broadcast.unpersist. This is handling it the same way as doing rdd.unpersist in https://issues.apache.org/jira/browse/SPARK-22618 ## How was this patch tested? Patch was tested manually against a couple jobs that exhibit this behavior, with the change the application no longer dies due to this and just prints the warning. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Thomas Graves Closes #20924 from tgravescs/SPARK-23806. --- .../spark/storage/BlockManagerMasterEndpoint.scala | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 89a6a71a589a1..56b95c31eb4c3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -192,11 +192,15 @@ class BlockManagerMasterEndpoint( val requiredBlockManagers = blockManagerInfo.values.filter { info => removeFromDriver || !info.blockManagerId.isDriver } - Future.sequence( - requiredBlockManagers.map { bm => - bm.slaveEndpoint.ask[Int](removeMsg) - }.toSeq - ) + val futures = requiredBlockManagers.map { bm => + bm.slaveEndpoint.ask[Int](removeMsg).recover { + case e: IOException => + logWarning(s"Error trying to remove broadcast $broadcastId", e) + 0 // zero blocks were removed + } + }.toSeq + + Future.sequence(futures) } private def removeBlockManager(blockManagerId: BlockManagerId) { From 505480cb578af9f23acc77bc82348afc9d8468e8 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 29 Mar 2018 19:38:28 +0900 Subject: [PATCH 0534/2461] [SPARK-23770][R] Exposes repartitionByRange in SparkR ## What changes were proposed in this pull request? This PR proposes to expose `repartitionByRange`. ```R > df <- createDataFrame(iris) ... > getNumPartitions(repartitionByRange(df, 3, col = df$Species)) [1] 3 ``` ## How was this patch tested? Manually tested and the unit tests were added. The diff with `repartition` can be checked as below: ```R > df <- createDataFrame(mtcars) > take(repartition(df, 10, df$wt), 3) mpg cyl disp hp drat wt qsec vs am gear carb 1 14.3 8 360.0 245 3.21 3.570 15.84 0 0 3 4 2 10.4 8 460.0 215 3.00 5.424 17.82 0 0 3 4 3 32.4 4 78.7 66 4.08 2.200 19.47 1 1 4 1 > take(repartitionByRange(df, 10, df$wt), 3) mpg cyl disp hp drat wt qsec vs am gear carb 1 30.4 4 75.7 52 4.93 1.615 18.52 1 1 4 2 2 33.9 4 71.1 65 4.22 1.835 19.90 1 1 4 1 3 27.3 4 79.0 66 4.08 1.935 18.90 1 1 4 1 ``` Author: hyukjinkwon Closes #20902 from HyukjinKwon/r-repartitionByRange. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 65 ++++++++++++++++++++++++++- R/pkg/R/generics.R | 3 ++ R/pkg/tests/fulltests/test_sparkSQL.R | 45 +++++++++++++++++++ 4 files changed, 112 insertions(+), 2 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index c51eb0f39c4b1..190c50ea10482 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -151,6 +151,7 @@ exportMethods("arrange", "registerTempTable", "rename", "repartition", + "repartitionByRange", "rollup", "sample", "sample_frac", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index c4852024c0f49..a1c9495b0795e 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -687,7 +687,7 @@ setMethod("storageLevel", #' @rdname coalesce #' @name coalesce #' @aliases coalesce,SparkDataFrame-method -#' @seealso \link{repartition} +#' @seealso \link{repartition}, \link{repartitionByRange} #' @examples #'\dontrun{ #' sparkR.session() @@ -723,7 +723,7 @@ setMethod("coalesce", #' @rdname repartition #' @name repartition #' @aliases repartition,SparkDataFrame-method -#' @seealso \link{coalesce} +#' @seealso \link{coalesce}, \link{repartitionByRange} #' @examples #'\dontrun{ #' sparkR.session() @@ -759,6 +759,67 @@ setMethod("repartition", dataFrame(sdf) }) + +#' Repartition by range +#' +#' The following options for repartition by range are possible: +#' \itemize{ +#' \item{1.} {Return a new SparkDataFrame range partitioned by +#' the given columns into \code{numPartitions}.} +#' \item{2.} {Return a new SparkDataFrame range partitioned by the given column(s), +#' using \code{spark.sql.shuffle.partitions} as number of partitions.} +#'} +#' +#' @param x a SparkDataFrame. +#' @param numPartitions the number of partitions to use. +#' @param col the column by which the range partitioning will be performed. +#' @param ... additional column(s) to be used in the range partitioning. +#' +#' @family SparkDataFrame functions +#' @rdname repartitionByRange +#' @name repartitionByRange +#' @aliases repartitionByRange,SparkDataFrame-method +#' @seealso \link{repartition}, \link{coalesce} +#' @examples +#'\dontrun{ +#' sparkR.session() +#' path <- "path/to/file.json" +#' df <- read.json(path) +#' newDF <- repartitionByRange(df, col = df$col1, df$col2) +#' newDF <- repartitionByRange(df, 3L, col = df$col1, df$col2) +#'} +#' @note repartitionByRange since 2.4.0 +setMethod("repartitionByRange", + signature(x = "SparkDataFrame"), + function(x, numPartitions = NULL, col = NULL, ...) { + if (!is.null(numPartitions) && !is.null(col)) { + # number of partitions and columns both are specified + if (is.numeric(numPartitions) && class(col) == "Column") { + cols <- list(col, ...) + jcol <- lapply(cols, function(c) { c@jc }) + sdf <- callJMethod(x@sdf, "repartitionByRange", numToInt(numPartitions), jcol) + } else { + stop(paste("numPartitions and col must be numeric and Column; however, got", + class(numPartitions), "and", class(col))) + } + } else if (!is.null(col)) { + # only columns are specified + if (class(col) == "Column") { + cols <- list(col, ...) + jcol <- lapply(cols, function(c) { c@jc }) + sdf <- callJMethod(x@sdf, "repartitionByRange", jcol) + } else { + stop(paste("col must be Column; however, got", class(col))) + } + } else if (!is.null(numPartitions)) { + # only numPartitions is specified + stop("At least one partition-by column must be specified.") + } else { + stop("Please, specify a column(s) or the number of partitions with a column(s)") + } + dataFrame(sdf) + }) + #' toJSON #' #' Converts a SparkDataFrame into a SparkDataFrame of JSON string. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 6fba4b6c761dd..974beff1a3d76 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -531,6 +531,9 @@ setGeneric("rename", function(x, ...) { standardGeneric("rename") }) #' @rdname repartition setGeneric("repartition", function(x, ...) { standardGeneric("repartition") }) +#' @rdname repartitionByRange +setGeneric("repartitionByRange", function(x, ...) { standardGeneric("repartitionByRange") }) + #' @rdname sample setGeneric("sample", function(x, withReplacement = FALSE, fraction, seed) { diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 439191adb23ea..7105469ffc242 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -3104,6 +3104,51 @@ test_that("repartition by columns on DataFrame", { }) }) +test_that("repartitionByRange on a DataFrame", { + # The tasks here launch R workers with shuffles. So, we decrease the number of shuffle + # partitions to reduce the number of the tasks to speed up the test. This is particularly + # slow on Windows because the R workers are unable to be forked. See also SPARK-21693. + conf <- callJMethod(sparkSession, "conf") + shufflepartitionsvalue <- callJMethod(conf, "get", "spark.sql.shuffle.partitions") + callJMethod(conf, "set", "spark.sql.shuffle.partitions", "5") + tryCatch({ + df <- createDataFrame(mtcars) + expect_error(repartitionByRange(df, "haha", df$mpg), + "numPartitions and col must be numeric and Column.*") + expect_error(repartitionByRange(df), + ".*specify a column.*or the number of partitions with a column.*") + expect_error(repartitionByRange(df, col = "haha"), + "col must be Column; however, got.*") + expect_error(repartitionByRange(df, 3), + "At least one partition-by column must be specified.") + + # The order of rows should be different with a normal repartition. + actual <- repartitionByRange(df, 3, df$mpg) + expect_equal(getNumPartitions(actual), 3) + expect_false(identical(collect(actual), collect(repartition(df, 3, df$mpg)))) + + actual <- repartitionByRange(df, col = df$mpg) + expect_false(identical(collect(actual), collect(repartition(df, col = df$mpg)))) + + # They should have same data. + actual <- collect(repartitionByRange(df, 3, df$mpg)) + actual <- actual[order(actual$mpg), ] + expected <- collect(repartition(df, 3, df$mpg)) + expected <- expected[order(expected$mpg), ] + expect_true(all(actual == expected)) + + actual <- collect(repartitionByRange(df, col = df$mpg)) + actual <- actual[order(actual$mpg), ] + expected <- collect(repartition(df, col = df$mpg)) + expected <- expected[order(expected$mpg), ] + expect_true(all(actual == expected)) + }, + finally = { + # Resetting the conf back to default value + callJMethod(conf, "set", "spark.sql.shuffle.partitions", shufflepartitionsvalue) + }) +}) + test_that("coalesce, repartition, numPartitions", { df <- as.DataFrame(cars, numPartitions = 5) expect_equal(getNumPartitions(df), 5) From 491ec114fd3886ebd9fa29a482e3d112fb5a088c Mon Sep 17 00:00:00 2001 From: Sahil Takiar Date: Thu, 29 Mar 2018 10:23:23 -0700 Subject: [PATCH 0535/2461] [SPARK-23785][LAUNCHER] LauncherBackend doesn't check state of connection before setting state ## What changes were proposed in this pull request? Changed `LauncherBackend` `set` method so that it checks if the connection is open or not before writing to it (uses `isConnected`). ## How was this patch tested? None Author: Sahil Takiar Closes #20893 from sahilTakiar/master. --- .../spark/launcher/LauncherBackend.scala | 6 +++--- .../spark/launcher/LauncherServerSuite.java | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala index aaae33ca4e6f3..1b049b786023a 100644 --- a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala +++ b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala @@ -67,13 +67,13 @@ private[spark] abstract class LauncherBackend { } def setAppId(appId: String): Unit = { - if (connection != null) { + if (connection != null && isConnected) { connection.send(new SetAppId(appId)) } } def setState(state: SparkAppHandle.State): Unit = { - if (connection != null && lastState != state) { + if (connection != null && isConnected && lastState != state) { connection.send(new SetState(state)) lastState = state } @@ -114,10 +114,10 @@ private[spark] abstract class LauncherBackend { override def close(): Unit = { try { + _isConnected = false super.close() } finally { onDisconnected() - _isConnected = false } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index d16337a319be3..5413d3a416545 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -185,6 +185,26 @@ public void testStreamFiltering() throws Exception { } } + @Test + public void testAppHandleDisconnect() throws Exception { + LauncherServer server = LauncherServer.getOrCreateServer(); + ChildProcAppHandle handle = new ChildProcAppHandle(server); + String secret = server.registerHandle(handle); + + TestClient client = null; + try { + Socket s = new Socket(InetAddress.getLoopbackAddress(), server.getPort()); + client = new TestClient(s); + client.send(new Hello(secret, "1.4.0")); + handle.disconnect(); + waitForError(client, secret); + } finally { + handle.kill(); + close(client); + client.clientThread.join(); + } + } + private void close(Closeable c) { if (c != null) { try { From a7755fd8ce2f022118b9827aaac7d5d59f0f297a Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 29 Mar 2018 10:46:28 -0700 Subject: [PATCH 0536/2461] [SPARK-23639][SQL] Obtain token before init metastore client in SparkSQL CLI ## What changes were proposed in this pull request? In SparkSQLCLI, SessionState generates before SparkContext instantiating. When we use --proxy-user to impersonate, it's unable to initializing a metastore client to talk to the secured metastore for no kerberos ticket. This PR use real user ugi to obtain token for owner before talking to kerberized metastore. ## How was this patch tested? Manually verified with kerberized hive metasotre / hdfs. Author: Kent Yao Closes #20784 from yaooqinn/SPARK-23639. --- .../deploy/security/HiveDelegationTokenProvider.scala | 8 ++++---- .../spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala | 9 +++++++++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala index ece5ce79c650d..7249eb85ac7c7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala @@ -36,7 +36,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.KEYTAB import org.apache.spark.util.Utils -private[security] class HiveDelegationTokenProvider +private[spark] class HiveDelegationTokenProvider extends HadoopDelegationTokenProvider with Logging { override def serviceName: String = "hive" @@ -124,9 +124,9 @@ private[security] class HiveDelegationTokenProvider val currentUser = UserGroupInformation.getCurrentUser() val realUser = Option(currentUser.getRealUser()).getOrElse(currentUser) - // For some reason the Scala-generated anonymous class ends up causing an - // UndeclaredThrowableException, even if you annotate the method with @throws. - try { + // For some reason the Scala-generated anonymous class ends up causing an + // UndeclaredThrowableException, even if you annotate the method with @throws. + try { realUser.doAs(new PrivilegedExceptionAction[T]() { override def run(): T = fn }) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 832a15d09599f..084f8200102ba 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -34,11 +34,13 @@ import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.log4j.{Level, Logger} import org.apache.thrift.transport.TSocket import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.security.HiveDelegationTokenProvider import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.hive.HiveUtils @@ -121,6 +123,13 @@ private[hive] object SparkSQLCLIDriver extends Logging { } } + val tokenProvider = new HiveDelegationTokenProvider() + if (tokenProvider.delegationTokensRequired(sparkConf, hadoopConf)) { + val credentials = new Credentials() + tokenProvider.obtainDelegationTokens(hadoopConf, sparkConf, credentials) + UserGroupInformation.getCurrentUser.addCredentials(credentials) + } + SessionState.start(sessionState) // Clean up after we exit From b348901192b231153b58fe5720253168c87963d4 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 29 Mar 2018 21:36:56 -0700 Subject: [PATCH 0537/2461] [SPARK-23808][SQL] Set default Spark session in test-only spark sessions. ## What changes were proposed in this pull request? Set default Spark session in the TestSparkSession and TestHiveSparkSession constructors. ## How was this patch tested? new unit tests Author: Jose Torres Closes #20926 from jose-torres/test3. --- .../spark/sql/test/TestSQLContext.scala | 2 ++ .../sql/test/TestSparkSessionSuite.scala | 29 +++++++++++++++++++ .../apache/spark/sql/hive/test/TestHive.scala | 4 +++ 3 files changed, 35 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/TestSparkSessionSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 4286e8a6ca2c8..3038b822beb4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -34,6 +34,8 @@ private[spark] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) this(new SparkConf) } + SparkSession.setDefaultSession(this) + @transient override lazy val sessionState: SessionState = { new TestSQLSessionStateBuilder(this, None).build() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSparkSessionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSparkSessionSuite.scala new file mode 100644 index 0000000000000..4019c6888da98 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSparkSessionSuite.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkSession + +class TestSparkSessionSuite extends SparkFunSuite { + test("default session is set in constructor") { + val session = new TestSparkSession() + assert(SparkSession.getDefaultSession.contains(session)) + session.stop() + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index fcf2025d34432..814038d4ef7af 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -159,6 +159,10 @@ private[hive] class TestHiveSparkSession( private val loadTestTables: Boolean) extends SparkSession(sc) with Logging { self => + // TODO(SPARK-23826): TestHiveSparkSession should set default session the same way as + // TestSparkSession, but doing this the same way breaks many tests in the package. We need + // to investigate and find a different strategy. + def this(sc: SparkContext, loadTestTables: Boolean) { this( sc, From df05fb63abe6018ccbe572c34cf65fc3ecbf1166 Mon Sep 17 00:00:00 2001 From: Jongyoul Lee Date: Fri, 30 Mar 2018 14:07:35 +0800 Subject: [PATCH 0538/2461] [SPARK-23743][SQL] Changed a comparison logic from containing 'slf4j' to starting with 'org.slf4j' ## What changes were proposed in this pull request? isSharedClass returns if some classes can/should be shared or not. It checks if the classes names have some keywords or start with some names. Following the logic, it can occur unintended behaviors when a custom package has `slf4j` inside the package or class name. As I guess, the first intention seems to figure out the class containing `org.slf4j`. It would be better to change the comparison logic to `name.startsWith("org.slf4j")` ## How was this patch tested? This patch should pass all of the current tests and keep all of the current behaviors. In my case, I'm using ProtobufDeserializer to get a table schema from hive tables. Thus some Protobuf packages and names have `slf4j` inside. Without this patch, it cannot be resolved because of ClassCastException from different classloaders. Author: Jongyoul Lee Closes #20860 from jongyoul/SPARK-23743. --- .../apache/spark/sql/hive/client/IsolatedClientLoader.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 12975bc85b971..c2690ec32b9e7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -179,8 +179,9 @@ private[hive] class IsolatedClientLoader( val isHadoopClass = name.startsWith("org.apache.hadoop.") && !name.startsWith("org.apache.hadoop.hive.") - name.contains("slf4j") || - name.contains("log4j") || + name.startsWith("org.slf4j") || + name.startsWith("org.apache.log4j") || // log4j1.x + name.startsWith("org.apache.logging.log4j") || // log4j2 name.startsWith("org.apache.spark.") || (sharesHadoopClasses && isHadoopClass) || name.startsWith("scala.") || From b02e76cbffe9e589b7a4e60f91250ca12a4420b2 Mon Sep 17 00:00:00 2001 From: yucai Date: Fri, 30 Mar 2018 15:07:38 +0800 Subject: [PATCH 0539/2461] [SPARK-23727][SQL] Support for pushing down filters for DateType in parquet ## What changes were proposed in this pull request? This PR supports for pushing down filters for DateType in parquet ## How was this patch tested? Added UT and tested in local. Author: yucai Closes #20851 from yucai/SPARK-23727. --- .../apache/spark/sql/internal/SQLConf.scala | 9 ++++ .../datasources/parquet/ParquetFilters.scala | 33 ++++++++++++ .../parquet/ParquetFilterSuite.scala | 50 +++++++++++++++++-- 3 files changed, 89 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9cb03b5bb6152..13f31a6b2eb93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -353,6 +353,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_FILTER_PUSHDOWN_DATE_ENABLED = buildConf("spark.sql.parquet.filterPushdown.date") + .doc("If true, enables Parquet filter push-down optimization for Date. " + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + .internal() + .booleanConf + .createWithDefault(true) + val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") .doc("Whether to be compatible with the legacy Parquet format adopted by Spark 1.4 and prior " + "versions, when converting Parquet schema to Spark SQL schema and vice versa.") @@ -1329,6 +1336,8 @@ class SQLConf extends Serializable with Logging { def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) + def parquetFilterPushDownDate: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DATE_ENABLED) + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 763841efbd9f3..ccc8306866d68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -17,10 +17,15 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.sql.Date + import org.apache.parquet.filter2.predicate._ import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.io.api.Binary +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources import org.apache.spark.sql.types._ @@ -29,6 +34,10 @@ import org.apache.spark.sql.types._ */ private[parquet] object ParquetFilters { + private def dateToDays(date: Date): SQLDate = { + DateTimeUtils.fromJavaDate(date) + } + private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { case BooleanType => (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) @@ -50,6 +59,10 @@ private[parquet] object ParquetFilters { (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.eq( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -72,6 +85,10 @@ private[parquet] object ParquetFilters { (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.notEq( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -91,6 +108,10 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.lt( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -110,6 +131,10 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.ltEq( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -129,6 +154,10 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.gt( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -148,6 +177,10 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.gtEq( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 33801954ebd51..1d3476e747046 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.nio.charset.StandardCharsets +import java.sql.Date import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.parquet.filter2.predicate.FilterApi._ @@ -76,8 +77,10 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex expected: Seq[Row]): Unit = { val output = predicate.collect { case a: Attribute => a }.distinct - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + withSQLConf( + SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", + SQLConf.PARQUET_FILTER_PUSHDOWN_DATE_ENABLED.key -> "true", + SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { val query = df .select(output.map(e => Column(e)): _*) .where(Column(predicate)) @@ -102,7 +105,6 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex maybeFilter.exists(_.getClass === filterClass) } checker(stripSparkFilter(query), expected) - } } } @@ -313,6 +315,48 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + test("filter pushdown - date") { + implicit class StringToDate(s: String) { + def date: Date = Date.valueOf(s) + } + + val data = Seq("2018-03-18", "2018-03-19", "2018-03-20", "2018-03-21") + + withParquetDataFrame(data.map(i => Tuple1(i.date))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], data.map(i => Row.apply(i.date))) + + checkFilterPredicate('_1 === "2018-03-18".date, classOf[Eq[_]], "2018-03-18".date) + checkFilterPredicate('_1 <=> "2018-03-18".date, classOf[Eq[_]], "2018-03-18".date) + checkFilterPredicate('_1 =!= "2018-03-18".date, classOf[NotEq[_]], + Seq("2018-03-19", "2018-03-20", "2018-03-21").map(i => Row.apply(i.date))) + + checkFilterPredicate('_1 < "2018-03-19".date, classOf[Lt[_]], "2018-03-18".date) + checkFilterPredicate('_1 > "2018-03-20".date, classOf[Gt[_]], "2018-03-21".date) + checkFilterPredicate('_1 <= "2018-03-18".date, classOf[LtEq[_]], "2018-03-18".date) + checkFilterPredicate('_1 >= "2018-03-21".date, classOf[GtEq[_]], "2018-03-21".date) + + checkFilterPredicate( + Literal("2018-03-18".date) === '_1, classOf[Eq[_]], "2018-03-18".date) + checkFilterPredicate( + Literal("2018-03-18".date) <=> '_1, classOf[Eq[_]], "2018-03-18".date) + checkFilterPredicate( + Literal("2018-03-19".date) > '_1, classOf[Lt[_]], "2018-03-18".date) + checkFilterPredicate( + Literal("2018-03-20".date) < '_1, classOf[Gt[_]], "2018-03-21".date) + checkFilterPredicate( + Literal("2018-03-18".date) >= '_1, classOf[LtEq[_]], "2018-03-18".date) + checkFilterPredicate( + Literal("2018-03-21".date) <= '_1, classOf[GtEq[_]], "2018-03-21".date) + + checkFilterPredicate(!('_1 < "2018-03-21".date), classOf[GtEq[_]], "2018-03-21".date) + checkFilterPredicate( + '_1 < "2018-03-19".date || '_1 > "2018-03-20".date, + classOf[Operators.Or], + Seq(Row("2018-03-18".date), Row("2018-03-21".date))) + } + } + test("SPARK-6554: don't push down predicates which reference partition columns") { import testImplicits._ From 5b5a36ed6d2bb0971edfeccddf0f280936d2275f Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 30 Mar 2018 21:54:26 +0800 Subject: [PATCH 0540/2461] Roll forward "[SPARK-23096][SS] Migrate rate source to V2" ## What changes were proposed in this pull request? Roll forward c68ec4e (#20688). There are two minor test changes required: * An error which used to be TreeNodeException[ArithmeticException] is no longer wrapped and is now just ArithmeticException. * The test framework simply does not set the active Spark session. (Or rather, it doesn't do so early enough - I think it only happens when a query is analyzed.) I've added the required logic to SQLTestUtils. ## How was this patch tested? existing tests Author: Jose Torres Author: jerryshao Closes #20922 from jose-torres/ratefix. --- ...pache.spark.sql.sources.DataSourceRegister | 3 +- .../execution/datasources/DataSource.scala | 6 +- .../streaming/RateSourceProvider.scala | 262 ------------------ .../ContinuousRateStreamSource.scala | 25 +- .../sources/RateStreamMicroBatchReader.scala | 222 +++++++++++++++ .../sources/RateStreamProvider.scala | 125 +++++++++ .../sources/RateStreamSourceV2.scala | 187 ------------- .../streaming/RateSourceV2Suite.scala | 191 ------------- .../RateStreamProviderSuite.scala} | 166 ++++++++++- 9 files changed, 524 insertions(+), 663 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala rename sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/{RateSourceSuite.scala => sources/RateStreamProviderSuite.scala} (50%) diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 1fe9c093af99f..1b37905543b4e 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,6 +5,5 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider -org.apache.spark.sql.execution.streaming.RateSourceProvider +org.apache.spark.sql.execution.streaming.sources.RateStreamProvider org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider -org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 31fa89b4570a6..b84ea769808f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider +import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode @@ -566,6 +566,7 @@ object DataSource extends Logging { val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat" val nativeOrc = classOf[OrcFileFormat].getCanonicalName val socket = classOf[TextSocketSourceProvider].getCanonicalName + val rate = classOf[RateStreamProvider].getCanonicalName Map( "org.apache.spark.sql.jdbc" -> jdbc, @@ -587,7 +588,8 @@ object DataSource extends Logging { "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm, "org.apache.spark.ml.source.libsvm" -> libsvm, "com.databricks.spark.csv" -> csv, - "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket + "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket, + "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala deleted file mode 100644 index 649fbbfa184ec..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ /dev/null @@ -1,262 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.io._ -import java.nio.charset.StandardCharsets -import java.util.Optional -import java.util.concurrent.TimeUnit - -import org.apache.commons.io.IOUtils - -import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader -import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} -import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader -import org.apache.spark.sql.types._ -import org.apache.spark.util.{ManualClock, SystemClock} - -/** - * A source that generates increment long values with timestamps. Each generated row has two - * columns: a timestamp column for the generated time and an auto increment long column starting - * with 0L. - * - * This source supports the following options: - * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. - * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed - * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer - * seconds. - * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the - * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may - * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. - */ -class RateSourceProvider extends StreamSourceProvider with DataSourceRegister - with DataSourceV2 with ContinuousReadSupport { - - override def sourceSchema( - sqlContext: SQLContext, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): (String, StructType) = { - if (schema.nonEmpty) { - throw new AnalysisException("The rate source does not support a user-specified schema.") - } - - (shortName(), RateSourceProvider.SCHEMA) - } - - override def createSource( - sqlContext: SQLContext, - metadataPath: String, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): Source = { - val params = CaseInsensitiveMap(parameters) - - val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L) - if (rowsPerSecond <= 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " + - "must be positive") - } - - val rampUpTimeSeconds = - params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L) - if (rampUpTimeSeconds < 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " + - "must not be negative") - } - - val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse( - sqlContext.sparkContext.defaultParallelism) - if (numPartitions <= 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("numPartitions")}'. The option 'numPartitions' " + - "must be positive") - } - - new RateStreamSource( - sqlContext, - metadataPath, - rowsPerSecond, - rampUpTimeSeconds, - numPartitions, - params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing - ) - } - - override def createContinuousReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = { - new RateStreamContinuousReader(options) - } - - override def shortName(): String = "rate" -} - -object RateSourceProvider { - val SCHEMA = - StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) - - val VERSION = 1 -} - -class RateStreamSource( - sqlContext: SQLContext, - metadataPath: String, - rowsPerSecond: Long, - rampUpTimeSeconds: Long, - numPartitions: Int, - useManualClock: Boolean) extends Source with Logging { - - import RateSourceProvider._ - import RateStreamSource._ - - val clock = if (useManualClock) new ManualClock else new SystemClock - - private val maxSeconds = Long.MaxValue / rowsPerSecond - - if (rampUpTimeSeconds > maxSeconds) { - throw new ArithmeticException( - s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + - s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") - } - - private val startTimeMs = { - val metadataLog = - new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) { - override def serialize(metadata: LongOffset, out: OutputStream): Unit = { - val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) - writer.write("v" + VERSION + "\n") - writer.write(metadata.json) - writer.flush - } - - override def deserialize(in: InputStream): LongOffset = { - val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) - // HDFSMetadataLog guarantees that it never creates a partial file. - assert(content.length != 0) - if (content(0) == 'v') { - val indexOfNewLine = content.indexOf("\n") - if (indexOfNewLine > 0) { - val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) - LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } - } - - metadataLog.get(0).getOrElse { - val offset = LongOffset(clock.getTimeMillis()) - metadataLog.add(0, offset) - logInfo(s"Start time: $offset") - offset - }.offset - } - - /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */ - @volatile private var lastTimeMs = startTimeMs - - override def schema: StructType = RateSourceProvider.SCHEMA - - override def getOffset: Option[Offset] = { - val now = clock.getTimeMillis() - if (lastTimeMs < now) { - lastTimeMs = now - } - Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs))) - } - - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L) - val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) - assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") - if (endSeconds > maxSeconds) { - throw new ArithmeticException("Integer overflow. Max offset with " + - s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") - } - // Fix "lastTimeMs" for recovery - if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) { - lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs - } - val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) - val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) - logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + - s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") - - if (rangeStart == rangeEnd) { - return sqlContext.internalCreateDataFrame( - sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) - } - - val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) - val relativeMsPerValue = - TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) - - val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v => - val relative = math.round((v - rangeStart) * relativeMsPerValue) - InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v) - } - sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) - } - - override def stop(): Unit = {} - - override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " + - s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]" -} - -object RateStreamSource { - - /** Calculate the end value we will emit at the time `seconds`. */ - def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { - // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 - // Then speedDeltaPerSecond = 2 - // - // seconds = 0 1 2 3 4 5 6 - // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) - // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 - val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) - if (seconds <= rampUpTimeSeconds) { - // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to - // avoid overflow - if (seconds % 2 == 1) { - (seconds + 1) / 2 * speedDeltaPerSecond * seconds - } else { - seconds / 2 * speedDeltaPerSecond * (seconds + 1) - } - } else { - // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds - val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) - rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 20d90069163a6..2f0de2612c150 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -24,8 +24,8 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} +import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} @@ -40,8 +40,8 @@ class RateStreamContinuousReader(options: DataSourceOptions) val creationTime = System.currentTimeMillis() - val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt - val rowsPerSecond = options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong + val numPartitions = options.get(RateStreamProvider.NUM_PARTITIONS).orElse("5").toInt + val rowsPerSecond = options.get(RateStreamProvider.ROWS_PER_SECOND).orElse("6").toLong val perPartitionRate = rowsPerSecond.toDouble / numPartitions.toDouble override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { @@ -57,12 +57,12 @@ class RateStreamContinuousReader(options: DataSourceOptions) RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def readSchema(): StructType = RateSourceProvider.SCHEMA + override def readSchema(): StructType = RateStreamProvider.SCHEMA private var offset: Offset = _ override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { - this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime)) + this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime)) } override def getStartOffset(): Offset = offset @@ -98,6 +98,19 @@ class RateStreamContinuousReader(options: DataSourceOptions) override def commit(end: Offset): Unit = {} override def stop(): Unit = {} + private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { + RateStreamOffset( + Range(0, numPartitions).map { i => + // Note that the starting offset is exclusive, so we have to decrement the starting value + // by the increment that will later be applied. The first row output in each + // partition will have a value equal to the partition index. + (i, + ValueRunTimeMsPair( + (i - numPartitions).toLong, + creationTimeMs)) + }.toMap) + } + } case class RateStreamContinuousDataReaderFactory( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala new file mode 100644 index 0000000000000..6cf8520fc544f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io._ +import java.nio.charset.StandardCharsets +import java.util.Optional +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.IOUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{ManualClock, SystemClock} + +class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) + extends MicroBatchReader with Logging { + import RateStreamProvider._ + + private[sources] val clock = { + // The option to use a manual clock is provided only for unit testing purposes. + if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock + } + + private val rowsPerSecond = + options.get(ROWS_PER_SECOND).orElse("1").toLong + + private val rampUpTimeSeconds = + Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String])) + .map(JavaUtils.timeStringAsSec(_)) + .getOrElse(0L) + + private val maxSeconds = Long.MaxValue / rowsPerSecond + + if (rampUpTimeSeconds > maxSeconds) { + throw new ArithmeticException( + s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + + s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") + } + + private[sources] val creationTimeMs = { + val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession) + require(session.isDefined) + + val metadataLog = + new HDFSMetadataLog[LongOffset](session.get, checkpointLocation) { + override def serialize(metadata: LongOffset, out: OutputStream): Unit = { + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): LongOffset = { + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. + assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + parseVersion(content.substring(0, indexOfNewLine), VERSION) + LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } + } + + metadataLog.get(0).getOrElse { + val offset = LongOffset(clock.getTimeMillis()) + metadataLog.add(0, offset) + logInfo(s"Start time: $offset") + offset + }.offset + } + + @volatile private var lastTimeMs: Long = creationTimeMs + + private var start: LongOffset = _ + private var end: LongOffset = _ + + override def readSchema(): StructType = SCHEMA + + override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { + this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset] + this.end = end.orElse { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) + }.asInstanceOf[LongOffset] + } + + override def getStartOffset(): Offset = { + if (start == null) throw new IllegalStateException("start offset not set") + start + } + override def getEndOffset(): Offset = { + if (end == null) throw new IllegalStateException("end offset not set") + end + } + + override def deserializeOffset(json: String): Offset = { + LongOffset(json.toLong) + } + + override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) + val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) + assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") + if (endSeconds > maxSeconds) { + throw new ArithmeticException("Integer overflow. Max offset with " + + s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") + } + // Fix "lastTimeMs" for recovery + if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs) { + lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs + } + val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) + val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) + logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + + s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") + + if (rangeStart == rangeEnd) { + return List.empty.asJava + } + + val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) + val relativeMsPerValue = + TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) + val numPartitions = { + val activeSession = SparkSession.getActiveSession + require(activeSession.isDefined) + Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String])) + .map(_.toInt) + .getOrElse(activeSession.get.sparkContext.defaultParallelism) + } + + (0 until numPartitions).map { p => + new RateStreamMicroBatchDataReaderFactory( + p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) + : DataReaderFactory[Row] + }.toList.asJava + } + + override def commit(end: Offset): Unit = {} + + override def stop(): Unit = {} + + override def toString: String = s"MicroBatchRateSource[rowsPerSecond=$rowsPerSecond, " + + s"rampUpTimeSeconds=$rampUpTimeSeconds, " + + s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" +} + +class RateStreamMicroBatchDataReaderFactory( + partitionId: Int, + numPartitions: Int, + rangeStart: Long, + rangeEnd: Long, + localStartTimeMs: Long, + relativeMsPerValue: Double) extends DataReaderFactory[Row] { + + override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader( + partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) +} + +class RateStreamMicroBatchDataReader( + partitionId: Int, + numPartitions: Int, + rangeStart: Long, + rangeEnd: Long, + localStartTimeMs: Long, + relativeMsPerValue: Double) extends DataReader[Row] { + private var count = 0 + + override def next(): Boolean = { + rangeStart + partitionId + numPartitions * count < rangeEnd + } + + override def get(): Row = { + val currValue = rangeStart + partitionId + numPartitions * count + count += 1 + val relative = math.round((currValue - rangeStart) * relativeMsPerValue) + Row( + DateTimeUtils.toJavaTimestamp( + DateTimeUtils.fromMillis(relative + localStartTimeMs)), + currValue + ) + } + + override def close(): Unit = {} +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala new file mode 100644 index 0000000000000..6bdd492f0cb35 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.util.Optional + +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader} +import org.apache.spark.sql.types._ + +/** + * A source that generates increment long values with timestamps. Each generated row has two + * columns: a timestamp column for the generated time and an auto increment long column starting + * with 0L. + * + * This source supports the following options: + * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. + * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed + * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer + * seconds. + * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the + * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may + * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. + */ +class RateStreamProvider extends DataSourceV2 + with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister { + import RateStreamProvider._ + + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): MicroBatchReader = { + if (options.get(ROWS_PER_SECOND).isPresent) { + val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong + if (rowsPerSecond <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive") + } + } + + if (options.get(RAMP_UP_TIME).isPresent) { + val rampUpTimeSeconds = + JavaUtils.timeStringAsSec(options.get(RAMP_UP_TIME).get()) + if (rampUpTimeSeconds < 0) { + throw new IllegalArgumentException( + s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative") + } + } + + if (options.get(NUM_PARTITIONS).isPresent) { + val numPartitions = options.get(NUM_PARTITIONS).get().toInt + if (numPartitions <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive") + } + } + + if (schema.isPresent) { + throw new AnalysisException("The rate source does not support a user-specified schema.") + } + + new RateStreamMicroBatchReader(options, checkpointLocation) + } + + override def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options) + + override def shortName(): String = "rate" +} + +object RateStreamProvider { + val SCHEMA = + StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) + + val VERSION = 1 + + val NUM_PARTITIONS = "numPartitions" + val ROWS_PER_SECOND = "rowsPerSecond" + val RAMP_UP_TIME = "rampUpTime" + + /** Calculate the end value we will emit at the time `seconds`. */ + def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { + // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 + // Then speedDeltaPerSecond = 2 + // + // seconds = 0 1 2 3 4 5 6 + // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) + // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 + val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) + if (seconds <= rampUpTimeSeconds) { + // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to + // avoid overflow + if (seconds % 2 == 1) { + (seconds + 1) / 2 * speedDeltaPerSecond * seconds + } else { + seconds / 2 * speedDeltaPerSecond * (seconds + 1) + } + } else { + // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds + val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) + rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala deleted file mode 100644 index 4e2459bb05bd6..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.util.Optional - -import scala.collection.JavaConverters._ -import scala.collection.mutable - -import org.json4s.DefaultFormats -import org.json4s.jackson.Serialization - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} -import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} -import org.apache.spark.util.{ManualClock, SystemClock} - -/** - * This is a temporary register as we build out v2 migration. Microbatch read support should - * be implemented in the same register as v1. - */ -class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with DataSourceRegister { - override def createMicroBatchReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = { - new RateStreamMicroBatchReader(options) - } - - override def shortName(): String = "ratev2" -} - -class RateStreamMicroBatchReader(options: DataSourceOptions) - extends MicroBatchReader { - implicit val defaultFormats: DefaultFormats = DefaultFormats - - val clock = { - // The option to use a manual clock is provided only for unit testing purposes. - if (options.get("useManualClock").orElse("false").toBoolean) new ManualClock - else new SystemClock - } - - private val numPartitions = - options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt - private val rowsPerSecond = - options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong - - // The interval (in milliseconds) between rows in each partition. - // e.g. if there are 4 global rows per second, and 2 partitions, each partition - // should output rows every (1000 * 2 / 4) = 500 ms. - private val msPerPartitionBetweenRows = (1000 * numPartitions) / rowsPerSecond - - override def readSchema(): StructType = { - StructType( - StructField("timestamp", TimestampType, false) :: - StructField("value", LongType, false) :: Nil) - } - - val creationTimeMs = clock.getTimeMillis() - - private var start: RateStreamOffset = _ - private var end: RateStreamOffset = _ - - override def setOffsetRange( - start: Optional[Offset], - end: Optional[Offset]): Unit = { - this.start = start.orElse( - RateStreamSourceV2.createInitialOffset(numPartitions, creationTimeMs)) - .asInstanceOf[RateStreamOffset] - - this.end = end.orElse { - val currentTime = clock.getTimeMillis() - RateStreamOffset( - this.start.partitionToValueAndRunTimeMs.map { - case startOffset @ (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => - // Calculate the number of rows we should advance in this partition (based on the - // current time), and output a corresponding offset. - val readInterval = currentTime - currentReadTime - val numNewRows = readInterval / msPerPartitionBetweenRows - if (numNewRows <= 0) { - startOffset - } else { - (part, ValueRunTimeMsPair( - currentVal + (numNewRows * numPartitions), - currentReadTime + (numNewRows * msPerPartitionBetweenRows))) - } - } - ) - }.asInstanceOf[RateStreamOffset] - } - - override def getStartOffset(): Offset = { - if (start == null) throw new IllegalStateException("start offset not set") - start - } - override def getEndOffset(): Offset = { - if (end == null) throw new IllegalStateException("end offset not set") - end - } - - override def deserializeOffset(json: String): Offset = { - RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) - } - - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { - val startMap = start.partitionToValueAndRunTimeMs - val endMap = end.partitionToValueAndRunTimeMs - endMap.keys.toSeq.map { part => - val ValueRunTimeMsPair(endVal, _) = endMap(part) - val ValueRunTimeMsPair(startVal, startTimeMs) = startMap(part) - - val packedRows = mutable.ListBuffer[(Long, Long)]() - var outVal = startVal + numPartitions - var outTimeMs = startTimeMs - while (outVal <= endVal) { - packedRows.append((outTimeMs, outVal)) - outVal += numPartitions - outTimeMs += msPerPartitionBetweenRows - } - - RateStreamBatchTask(packedRows).asInstanceOf[DataReaderFactory[Row]] - }.toList.asJava - } - - override def commit(end: Offset): Unit = {} - override def stop(): Unit = {} -} - -case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactory[Row] { - override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals) -} - -class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] { - private var currentIndex = -1 - - override def next(): Boolean = { - // Return true as long as the new index is in the seq. - currentIndex += 1 - currentIndex < vals.size - } - - override def get(): Row = { - Row( - DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(vals(currentIndex)._1)), - vals(currentIndex)._2) - } - - override def close(): Unit = {} -} - -object RateStreamSourceV2 { - val NUM_PARTITIONS = "numPartitions" - val ROWS_PER_SECOND = "rowsPerSecond" - - private[sql] def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { - RateStreamOffset( - Range(0, numPartitions).map { i => - // Note that the starting offset is exclusive, so we have to decrement the starting value - // by the increment that will later be applied. The first row output in each - // partition will have a value equal to the partition index. - (i, - ValueRunTimeMsPair( - (i - numPartitions).toLong, - creationTimeMs)) - }.toMap) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala deleted file mode 100644 index 983ba1668f58f..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ /dev/null @@ -1,191 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.util.Optional -import java.util.concurrent.TimeUnit - -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.Row -import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.util.ManualClock - -class RateSourceV2Suite extends StreamTest { - import testImplicits._ - - case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { - assert(query.nonEmpty) - val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source - }.head - rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) - rateSource.setOffsetRange(Optional.empty(), Optional.empty()) - (rateSource, rateSource.getEndOffset()) - } - } - - test("microbatch in registry") { - DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamMicroBatchReader]) - case _ => - throw new IllegalStateException("Could not find v2 read support for rate") - } - } - - test("basic microbatch execution") { - val input = spark.readStream - .format("rateV2") - .option("numPartitions", "1") - .option("rowsPerSecond", "10") - .option("useManualClock", "true") - .load() - testStream(input, useV2Sink = true)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), - StopStream, - StartStream(), - // Advance 2 seconds because creating a new RateSource will also create a new ManualClock - AdvanceRateManualClock(seconds = 2), - CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) - ) - } - - test("microbatch - numPartitions propagated") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 11) - } - - test("microbatch - set offset") { - val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty()) - val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) - val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000)))) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - assert(reader.getStartOffset() == startOffset) - assert(reader.getEndOffset() == endOffset) - } - - test("microbatch - infer offsets") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) - reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - reader.getStartOffset() match { - case r: RateStreamOffset => - assert(r.partitionToValueAndRunTimeMs(0).runTimeMs == reader.creationTimeMs) - case _ => throw new IllegalStateException("unexpected offset type") - } - reader.getEndOffset() match { - case r: RateStreamOffset => - // End offset may be a bit beyond 100 ms/9 rows after creation if the wait lasted - // longer than 100ms. It should never be early. - assert(r.partitionToValueAndRunTimeMs(0).value >= 9) - assert(r.partitionToValueAndRunTimeMs(0).runTimeMs >= reader.creationTimeMs + 100) - - case _ => throw new IllegalStateException("unexpected offset type") - } - } - - test("microbatch - predetermined batch size") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) - val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) - val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 1) - assert(tasks.get(0).asInstanceOf[RateStreamBatchTask].vals.size == 20) - } - - test("microbatch - data read") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) - val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) - val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { - case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => - (part, ValueRunTimeMsPair(currentVal + 33, currentReadTime + 1000)) - }.toMap) - - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 11) - - val readData = tasks.asScala - .map(_.createDataReader()) - .flatMap { reader => - val buf = scala.collection.mutable.ListBuffer[Row]() - while (reader.next()) buf.append(reader.get()) - buf - } - - assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) - } - - test("continuous in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: ContinuousReadSupport => - val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamContinuousReader]) - case _ => - throw new IllegalStateException("Could not find v2 read support for rate") - } - } - - test("continuous data") { - val reader = new RateStreamContinuousReader( - new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 2) - - val data = scala.collection.mutable.ListBuffer[Row]() - tasks.asScala.foreach { - case t: RateStreamContinuousDataReaderFactory => - val startTimeMs = reader.getStartOffset() - .asInstanceOf[RateStreamOffset] - .partitionToValueAndRunTimeMs(t.partitionIndex) - .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] - for (rowIndex <- 0 to 9) { - r.next() - data.append(r.get()) - assert(r.getOffset() == - RateStreamPartitionOffset( - t.partitionIndex, - t.partitionIndex + rowIndex * 2, - startTimeMs + (rowIndex + 1) * 100)) - } - assert(System.currentTimeMillis() >= startTimeMs + 1000) - - case _ => throw new IllegalStateException("Unexpected task type") - } - - assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala similarity index 50% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index 03d0f63fa4d7f..ff14ec38e66a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -15,13 +15,24 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.streaming +package org.apache.spark.sql.execution.streaming.sources +import java.nio.file.Files +import java.util.Optional import java.util.concurrent.TimeUnit -import org.apache.spark.sql.AnalysisException +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.Offset +import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock class RateSourceSuite extends StreamTest { @@ -29,18 +40,40 @@ class RateSourceSuite extends StreamTest { import testImplicits._ case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (Source, Offset) = { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { assert(query.nonEmpty) val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] => - source.asInstanceOf[RateStreamSource] + case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source }.head + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) - (rateSource, rateSource.getOffset.get) + val offset = LongOffset(TimeUnit.MILLISECONDS.toSeconds( + rateSource.clock.getTimeMillis() - rateSource.creationTimeMs)) + (rateSource, offset) + } + } + + test("microbatch in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + val reader = ds.createMicroBatchReader(Optional.empty(), "dummy", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamMicroBatchReader]) + case _ => + throw new IllegalStateException("Could not find read support for rate") + } + } + + test("compatible with old path in registry") { + DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", + spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + assert(ds.isInstanceOf[RateStreamProvider]) + case _ => + throw new IllegalStateException("Could not find read support for rate") } } - test("basic") { + test("microbatch - basic") { val input = spark.readStream .format("rate") .option("rowsPerSecond", "10") @@ -57,7 +90,7 @@ class RateSourceSuite extends StreamTest { ) } - test("uniform distribution of event timestamps") { + test("microbatch - uniform distribution of event timestamps") { val input = spark.readStream .format("rate") .option("rowsPerSecond", "1500") @@ -74,8 +107,74 @@ class RateSourceSuite extends StreamTest { ) } + test("microbatch - set offset") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + assert(reader.getStartOffset() == startOffset) + assert(reader.getEndOffset() == endOffset) + } + + test("microbatch - infer offsets") { + val tempFolder = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions( + Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), + tempFolder) + reader.clock.asInstanceOf[ManualClock].advance(100000) + reader.setOffsetRange(Optional.empty(), Optional.empty()) + reader.getStartOffset() match { + case r: LongOffset => assert(r.offset === 0L) + case _ => throw new IllegalStateException("unexpected offset type") + } + reader.getEndOffset() match { + case r: LongOffset => assert(r.offset >= 100) + case _ => throw new IllegalStateException("unexpected offset type") + } + } + + test("microbatch - predetermined batch size") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 1) + val dataReader = tasks.get(0).createDataReader() + val data = ArrayBuffer[Row]() + while (dataReader.next()) { + data.append(dataReader.get()) + } + assert(data.size === 20) + } + + test("microbatch - data read") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 11) + + val readData = tasks.asScala + .map(_.createDataReader()) + .flatMap { reader => + val buf = scala.collection.mutable.ListBuffer[Row]() + while (reader.next()) buf.append(reader.get()) + buf + } + + assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) + } + test("valueAtSecond") { - import RateStreamSource._ + import RateStreamProvider._ assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) @@ -161,7 +260,7 @@ class RateSourceSuite extends StreamTest { option: String, value: String, expectedMessages: Seq[String]): Unit = { - val e = intercept[StreamingQueryException] { + val e = intercept[IllegalArgumentException] { spark.readStream .format("rate") .option(option, value) @@ -171,9 +270,8 @@ class RateSourceSuite extends StreamTest { .start() .awaitTermination() } - assert(e.getCause.isInstanceOf[IllegalArgumentException]) for (msg <- expectedMessages) { - assert(e.getCause.getMessage.contains(msg)) + assert(e.getMessage.contains(msg)) } } @@ -191,4 +289,46 @@ class RateSourceSuite extends StreamTest { assert(exception.getMessage.contains( "rate source does not support a user-specified schema")) } + + test("continuous in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: ContinuousReadSupport => + val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamContinuousReader]) + case _ => + throw new IllegalStateException("Could not find read support for continuous rate") + } + } + + test("continuous data") { + val reader = new RateStreamContinuousReader( + new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) + reader.setStartOffset(Optional.empty()) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 2) + + val data = scala.collection.mutable.ListBuffer[Row]() + tasks.asScala.foreach { + case t: RateStreamContinuousDataReaderFactory => + val startTimeMs = reader.getStartOffset() + .asInstanceOf[RateStreamOffset] + .partitionToValueAndRunTimeMs(t.partitionIndex) + .runTimeMs + val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] + for (rowIndex <- 0 to 9) { + r.next() + data.append(r.get()) + assert(r.getOffset() == + RateStreamPartitionOffset( + t.partitionIndex, + t.partitionIndex + rowIndex * 2, + startTimeMs + (rowIndex + 1) * 100)) + } + assert(System.currentTimeMillis() >= startTimeMs + 1000) + + case _ => throw new IllegalStateException("Unexpected task type") + } + + assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) + } } From bc8d0931170cfa20a4fb64b3b11a2027ddb0d6e9 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 30 Mar 2018 23:21:07 +0800 Subject: [PATCH 0541/2461] [SPARK-23500][SQL][FOLLOWUP] Fix complex type simplification rules to apply to entire plan ## What changes were proposed in this pull request? This PR is to improve the test coverage of the original PR https://github.com/apache/spark/pull/20687 ## How was this patch tested? N/A Author: gatorsmile Closes #20911 from gatorsmile/addTests. --- .../optimizer/complexTypesSuite.scala | 176 ++++++++++++------ .../apache/spark/sql/ComplexTypesSuite.scala | 109 +++++++++++ 2 files changed, 233 insertions(+), 52 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index e44a6692ad8e2..21ed987627b3b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -47,10 +47,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { SimplifyExtractValueOps) :: Nil } - val idAtt = ('id).long.notNull - val nullableIdAtt = ('nullable_id).long + private val idAtt = ('id).long.notNull + private val nullableIdAtt = ('nullable_id).long - lazy val relation = LocalRelation(idAtt, nullableIdAtt) + private val relation = LocalRelation(idAtt, nullableIdAtt) + private val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.double, 'e.int) + + private def checkRule(originalQuery: LogicalPlan, correctAnswer: LogicalPlan) = { + val optimized = Optimizer.execute(originalQuery.analyze) + assert(optimized.resolved, "optimized plans must be still resolvable") + comparePlans(optimized, correctAnswer.analyze) + } test("explicit get from namedStruct") { val query = relation @@ -58,31 +65,28 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { GetStructField( CreateNamedStruct(Seq("att", 'id )), 0, - None) as "outerAtt").analyze - val expected = relation.select('id as "outerAtt").analyze + None) as "outerAtt") + val expected = relation.select('id as "outerAtt") - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("explicit get from named_struct- expression maintains original deduced alias") { val query = relation .select(GetStructField(CreateNamedStruct(Seq("att", 'id)), 0, None)) - .analyze val expected = relation .select('id as "named_struct(att, id).att") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("collapsed getStructField ontop of namedStruct") { val query = relation .select(CreateNamedStruct(Seq("att", 'id)) as "struct1") .select(GetStructField('struct1, 0, None) as "struct1Att") - .analyze - val expected = relation.select('id as "struct1Att").analyze - comparePlans(Optimizer execute query, expected) + val expected = relation.select('id as "struct1Att") + checkRule(query, expected) } test("collapse multiple CreateNamedStruct/GetStructField pairs") { @@ -94,16 +98,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select( GetStructField('struct1, 0, None) as "struct1Att1", GetStructField('struct1, 1, None) as "struct1Att2") - .analyze val expected = relation. select( 'id as "struct1Att1", ('id * 'id) as "struct1Att2") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("collapsed2 - deduced names") { @@ -115,16 +117,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select( GetStructField('struct1, 0, None), GetStructField('struct1, 1, None)) - .analyze val expected = relation. select( 'id as "struct1.att1", ('id * 'id) as "struct1.att2") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplified array ops") { @@ -151,7 +151,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { 1, false), 1) as "a4") - .analyze val expected = relation .select( @@ -161,8 +160,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { "att2", (('id + 1L) * ('id + 1L)))) as "a2", ('id + 1L) as "a3", ('id + 1L) as "a4") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("SPARK-22570: CreateArray should not create a lot of global variables") { @@ -188,7 +186,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { GetStructField(GetMapValue('m, "r1"), 0, None) as "a2", GetMapValue('m, "r32") as "a3", GetStructField(GetMapValue('m, "r32"), 0, None) as "a4") - .analyze val expected = relation.select( @@ -201,8 +198,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ) ) as "a3", Literal.create(null, LongType) as "a4") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplify map ops, constant lookup, dynamic keys") { @@ -216,7 +212,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 13L) as "a") - .analyze val expected = relation .select( @@ -225,8 +220,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { (EqualTo(13L, ('id + 1L)), ('id + 2L)), (EqualTo(13L, ('id + 2L)), ('id + 3L)), (Literal(true), 'id))) as "a") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplify map ops, dynamic lookup, dynamic keys, lookup is equivalent to one of the keys") { @@ -240,7 +234,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), ('id + 3L)) as "a") - .analyze val expected = relation .select( CaseWhen(Seq( @@ -248,8 +241,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { (EqualTo('id + 3L, ('id + 1L)), ('id + 2L)), (EqualTo('id + 3L, ('id + 2L)), ('id + 3L)), (Literal(true), ('id + 4L)))) as "a") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplify map ops, no positive match") { @@ -263,7 +255,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 'id + 30L) as "a") - .analyze val expected = relation.select( CaseWhen(Seq( (EqualTo('id + 30L, 'id), ('id + 1L)), @@ -271,8 +262,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { (EqualTo('id + 30L, ('id + 2L)), ('id + 3L)), (EqualTo('id + 30L, ('id + 3L)), ('id + 4L)), (EqualTo('id + 30L, ('id + 4L)), ('id + 5L)))) as "a") - .analyze - comparePlans(Optimizer execute rel, expected) + checkRule(rel, expected) } test("simplify map ops, constant lookup, mixed keys, eliminated constants") { @@ -287,7 +277,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 13L) as "a") - .analyze val expected = relation .select( @@ -297,9 +286,8 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 2L), ('id + 3L), ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))) as "a") - .analyze - comparePlans(Optimizer execute rel, expected) + checkRule(rel, expected) } test("simplify map ops, potential dynamic match with null value + an absolute constant match") { @@ -314,7 +302,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 2L ) as "a") - .analyze val expected = relation .select( @@ -327,18 +314,69 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { // but it cannot override a potential match with ('id + 2L), // which is exactly what [[Coalesce]] would do in this case. (Literal.TrueLiteral, 'id))) as "a") - .analyze - comparePlans(Optimizer execute rel, expected) + checkRule(rel, expected) + } + + test("SPARK-23500: Simplify array ops that are not at the top node") { + val query = LocalRelation('id.long) + .select( + CreateArray(Seq( + CreateNamedStruct(Seq( + "att1", 'id, + "att2", 'id * 'id)), + CreateNamedStruct(Seq( + "att1", 'id + 1, + "att2", ('id + 1) * ('id + 1)) + )) + ) as "arr") + .select( + GetStructField(GetArrayItem('arr, 1), 0, None) as "a1", + GetArrayItem( + GetArrayStructFields('arr, + StructField("att1", LongType, nullable = false), + ordinal = 0, + numFields = 1, + containsNull = false), + ordinal = 1) as "a2") + .orderBy('id.asc) + + val expected = LocalRelation('id.long) + .select( + ('id + 1L) as "a1", + ('id + 1L) as "a2") + .orderBy('id.asc) + checkRule(query, expected) + } + + test("SPARK-23500: Simplify map ops that are not top nodes") { + val query = + LocalRelation('id.long) + .select( + CreateMap(Seq( + "r1", 'id, + "r2", 'id + 1L)) as "m") + .select( + GetMapValue('m, "r1") as "a1", + GetMapValue('m, "r32") as "a2") + .orderBy('id.asc) + .select('a1, 'a2) + + val expected = + LocalRelation('id.long).select( + 'id as "a1", + Literal.create(null, LongType) as "a2") + .orderBy('id.asc) + checkRule(query, expected) } test("SPARK-23500: Simplify complex ops that aren't at the plan root") { val structRel = relation .select(GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None) as "foo") - .groupBy($"foo")("1").analyze + .groupBy($"foo")("1") val structExpected = relation .select('nullable_id as "foo") - .groupBy($"foo")("1").analyze - comparePlans(Optimizer execute structRel, structExpected) + .groupBy($"foo")("1") + checkRule(structRel, structExpected) // These tests must use nullable attributes from the base relation for the following reason: // in the 'original' plans below, the Aggregate node produced by groupBy() has a @@ -351,17 +389,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { // SPARK-23634. val arrayRel = relation .select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1") - .groupBy($"a1")("1").analyze - val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1").analyze - comparePlans(Optimizer execute arrayRel, arrayExpected) + .groupBy($"a1")("1") + val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1") + checkRule(arrayRel, arrayExpected) val mapRel = relation .select(GetMapValue(CreateMap(Seq("id", 'nullable_id)), "id") as "m1") - .groupBy($"m1")("1").analyze + .groupBy($"m1")("1") val mapExpected = relation .select('nullable_id as "m1") - .groupBy($"m1")("1").analyze - comparePlans(Optimizer execute mapRel, mapExpected) + .groupBy($"m1")("1") + checkRule(mapRel, mapExpected) } test("SPARK-23500: Ensure that aggregation expressions are not simplified") { @@ -369,11 +407,45 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { // grouping exprs so aren't tested here. val structAggRel = relation.groupBy( CreateNamedStruct(Seq("att1", 'nullable_id)))( - GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)).analyze - comparePlans(Optimizer execute structAggRel, structAggRel) + GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)) + checkRule(structAggRel, structAggRel) val arrayAggRel = relation.groupBy( - CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)).analyze - comparePlans(Optimizer execute arrayAggRel, arrayAggRel) + CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)) + checkRule(arrayAggRel, arrayAggRel) + + // This could be done if we had a more complex rule that checks that + // the CreateMap does not come from key. + val originalQuery = relation + .groupBy('id)( + GetMapValue(CreateMap(Seq('id, 'id + 1L)), 0L) as "a" + ) + checkRule(originalQuery, originalQuery) + } + + test("SPARK-23500: namedStruct and getField in the same Project #1") { + val originalQuery = + testRelation + .select( + namedStruct("col1", 'b, "col2", 'c).as("s1"), 'a, 'b) + .select('s1 getField "col2" as 's1Col2, + namedStruct("col1", 'a, "col2", 'b).as("s2")) + .select('s1Col2, 's2 getField "col2" as 's2Col2) + val correctAnswer = + testRelation + .select('c as 's1Col2, 'b as 's2Col2) + checkRule(originalQuery, correctAnswer) + } + + test("SPARK-23500: namedStruct and getField in the same Project #2") { + val originalQuery = + testRelation + .select( + namedStruct("col1", 'b, "col2", 'c) getField "col2" as 'sCol2, + namedStruct("col1", 'a, "col2", 'c) getField "col1" as 'sCol1) + val correctAnswer = + testRelation + .select('c as 'sCol2, 'a as 'sCol1) + checkRule(originalQuery, correctAnswer) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala new file mode 100644 index 0000000000000..b74fe2f90df23 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.test.SharedSQLContext + +class ComplexTypesSuite extends QueryTest with SharedSQLContext { + + override def beforeAll() { + super.beforeAll() + spark.range(10).selectExpr( + "id + 1 as i1", "id + 2 as i2", "id + 3 as i3", "id + 4 as i4", "id + 5 as i5") + .write.saveAsTable("tab") + } + + override def afterAll() { + try { + spark.sql("DROP TABLE IF EXISTS tab") + } finally { + super.afterAll() + } + } + + def checkNamedStruct(plan: LogicalPlan, expectedCount: Int): Unit = { + var count = 0 + plan.foreach { operator => + operator.transformExpressions { + case c: CreateNamedStruct => + count += 1 + c + } + } + + if (expectedCount != count) { + fail(s"expect $expectedCount CreateNamedStruct but got $count.") + } + } + + test("simple case") { + val df = spark.table("tab").selectExpr( + "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4) as col2") + .filter("col2.c > 11").selectExpr("col1.a") + checkAnswer(df, Row(9) :: Row(10) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + } + + test("named_struct is used in the top Project") { + val df = spark.table("tab").selectExpr( + "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4)") + .selectExpr("col1.a", "col1") + .filter("col1.a > 8") + checkAnswer(df, Row(9, Row(9, 10)) :: Row(10, Row(10, 11)) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 1) + + val df1 = spark.table("tab").selectExpr( + "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4)") + .sort("col1") + .selectExpr("col1.a") + .filter("col1.a > 8") + checkAnswer(df1, Row(9) :: Row(10) :: Nil) + checkNamedStruct(df1.queryExecution.optimizedPlan, expectedCount = 1) + } + + test("expression in named_struct") { + val df = spark.table("tab") + .selectExpr("i5", "struct(i1 as exp, i2, i3) as cola") + .selectExpr("cola.exp", "cola.i3").filter("cola.i3 > 10") + checkAnswer(df, Row(9, 11) :: Row(10, 12) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + + val df1 = spark.table("tab") + .selectExpr("i5", "struct(i1 + 1 as exp, i2, i3) as cola") + .selectExpr("cola.i3").filter("cola.exp > 10") + checkAnswer(df1, Row(12) :: Nil) + checkNamedStruct(df1.queryExecution.optimizedPlan, expectedCount = 0) + } + + test("nested case") { + val df = spark.table("tab") + .selectExpr("struct(struct(i2, i3) as exp, i4) as cola") + .selectExpr("cola.exp.i2", "cola.i4").filter("cola.exp.i2 > 10") + checkAnswer(df, Row(11, 13) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + + val df1 = spark.table("tab") + .selectExpr("struct(i2, i3) as exp", "i4") + .selectExpr("struct(exp, i4) as cola") + .selectExpr("cola.exp.i2", "cola.i4").filter("cola.i4 > 11") + checkAnswer(df1, Row(10, 12) :: Row(11, 13) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + } +} From ae9172017c361e5c1039bc2ca94048117021974a Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 30 Mar 2018 14:09:14 -0700 Subject: [PATCH 0542/2461] [SPARK-23640][CORE] Fix hadoop config may override spark config ## What changes were proposed in this pull request? It may be get `spark.shuffle.service.port` from https://github.com/apache/spark/blob/9745ec3a61c99be59ef6a9d5eebd445e8af65b7a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala#L459 Therefore, the client configuration `spark.shuffle.service.port` does not working unless the configuration is `spark.hadoop.spark.shuffle.service.port`. - This configuration is not working: ``` bin/spark-sql --master yarn --conf spark.shuffle.service.port=7338 ``` - This configuration works: ``` bin/spark-sql --master yarn --conf spark.hadoop.spark.shuffle.service.port=7338 ``` This PR fix this issue. ## How was this patch tested? It's difficult to carry out unit testing. But I've tested it manually. Author: Yuming Wang Closes #20785 from wangyum/SPARK-23640. --- .../scala/org/apache/spark/util/Utils.scala | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 5caedeb526469..d2be93226e2a2 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2302,16 +2302,20 @@ private[spark] object Utils extends Logging { } /** - * Return the value of a config either through the SparkConf or the Hadoop configuration - * if this is Yarn mode. In the latter case, this defaults to the value set through SparkConf - * if the key is not set in the Hadoop configuration. + * Return the value of a config either through the SparkConf or the Hadoop configuration. + * We Check whether the key is set in the SparkConf before look at any Hadoop configuration. + * If the key is set in SparkConf, no matter whether it is running on YARN or not, + * gets the value from SparkConf. + * Only when the key is not set in SparkConf and running on YARN, + * gets the value from Hadoop configuration. */ def getSparkOrYarnConfig(conf: SparkConf, key: String, default: String): String = { - val sparkValue = conf.get(key, default) - if (conf.get(SparkLauncher.SPARK_MASTER, null) == "yarn") { - new YarnConfiguration(SparkHadoopUtil.get.newConfiguration(conf)).get(key, sparkValue) + if (conf.contains(key)) { + conf.get(key, default) + } else if (conf.get(SparkLauncher.SPARK_MASTER, null) == "yarn") { + new YarnConfiguration(SparkHadoopUtil.get.newConfiguration(conf)).get(key, default) } else { - sparkValue + default } } From 15298b99ac8944e781328423289586176cf824d7 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 30 Mar 2018 16:48:26 -0700 Subject: [PATCH 0543/2461] [SPARK-23827][SS] StreamingJoinExec should ensure that input data is partitioned into specific number of partitions ## What changes were proposed in this pull request? Currently, the requiredChildDistribution does not specify the partitions. This can cause the weird corner cases where the child's distribution is `SinglePartition` which satisfies the required distribution of `ClusterDistribution(no-num-partition-requirement)`, thus eliminating the shuffle needed to repartition input data into the required number of partitions (i.e. same as state stores). That can lead to "file not found" errors on the state store delta files as the micro-batch-with-no-shuffle will not run certain tasks and therefore not generate the expected state store delta files. This PR adds the required constraint on the number of partitions. ## How was this patch tested? Modified test harness to always check that ANY stateful operator should have a constraint on the number of partitions. As part of that, the existing opt-in checks on child output partitioning were removed, as they are redundant. Author: Tathagata Das Closes #20941 from tdas/SPARK-23827. --- .../streaming/IncrementalExecution.scala | 2 +- .../StreamingSymmetricHashJoinExec.scala | 3 +- .../sql/streaming/DeduplicateSuite.scala | 8 +-- .../FlatMapGroupsWithStateSuite.scala | 5 +- .../sql/streaming/StatefulOperatorTest.scala | 49 ------------------- .../spark/sql/streaming/StreamTest.scala | 19 +++++++ .../streaming/StreamingAggregationSuite.scala | 4 +- 7 files changed, 25 insertions(+), 65 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index a10ed5f2df1b5..1a83c884d55bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -62,7 +62,7 @@ class IncrementalExecution( StreamingDeduplicationStrategy :: Nil } - private val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key) + private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key) .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter) .getOrElse(sparkSession.sessionState.conf.numShufflePartitions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index c351f658cb955..fa7c8ee906ecd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -167,7 +167,8 @@ case class StreamingSymmetricHashJoinExec( val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: + ClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil override def output: Seq[Attribute] = joinType match { case _: InnerLike => left.output ++ right.output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala index caf2bab8a5859..0088b64d6195e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala @@ -25,9 +25,7 @@ import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplic import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.functions._ -class DeduplicateSuite extends StateStoreMetricsTest - with BeforeAndAfterAll - with StatefulOperatorTest { +class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { import testImplicits._ @@ -44,8 +42,6 @@ class DeduplicateSuite extends StateStoreMetricsTest AddData(inputData, "a"), CheckLastBatch("a"), assertNumStateRows(total = 1, updated = 1), - AssertOnQuery(sq => - checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("value"))), AddData(inputData, "a"), CheckLastBatch(), assertNumStateRows(total = 1, updated = 0), @@ -63,8 +59,6 @@ class DeduplicateSuite extends StateStoreMetricsTest AddData(inputData, "a" -> 1), CheckLastBatch("a" -> 1), assertNumStateRows(total = 1, updated = 1), - AssertOnQuery(sq => - checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("_1"))), AddData(inputData, "a" -> 2), // Dropped CheckLastBatch(), assertNumStateRows(total = 1, updated = 0), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index de2b51678cea6..b1416bff87ee7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -42,8 +42,7 @@ case class RunningCount(count: Long) case class Result(key: Long, count: Int) class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest - with BeforeAndAfterAll - with StatefulOperatorTest { + with BeforeAndAfterAll { import testImplicits._ import GroupStateImpl._ @@ -618,8 +617,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest AddData(inputData, "a"), CheckLastBatch(("a", "1")), assertNumStateRows(total = 1, updated = 1), - AssertOnQuery(sq => checkChildOutputHashPartitioning[FlatMapGroupsWithStateExec]( - sq, Seq("value"))), AddData(inputData, "a", "b"), CheckLastBatch(("a", "2"), ("b", "1")), assertNumStateRows(total = 2, updated = 2), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala deleted file mode 100644 index 45142278993bb..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.streaming - -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.streaming._ - -trait StatefulOperatorTest { - /** - * Check that the output partitioning of a child operator of a Stateful operator satisfies the - * distribution that we expect for our Stateful operator. - */ - protected def checkChildOutputHashPartitioning[T <: StatefulOperator]( - sq: StreamingQuery, - colNames: Seq[String]): Boolean = { - val attr = sq.asInstanceOf[StreamExecution].lastExecution.analyzed.output - val partitions = sq.sparkSession.sessionState.conf.numShufflePartitions - val groupingAttr = attr.filter(a => colNames.contains(a.name)) - checkChildOutputPartitioning(sq, HashPartitioning(groupingAttr, partitions)) - } - - /** - * Check that the output partitioning of a child operator of a Stateful operator satisfies the - * distribution that we expect for our Stateful operator. - */ - protected def checkChildOutputPartitioning[T <: StatefulOperator]( - sq: StreamingQuery, - expectedPartitioning: Partitioning): Boolean = { - val operator = sq.asInstanceOf[StreamExecution].lastExecution - .executedPlan.collect { case p: T => p } - operator.head.children.forall( - _.outputPartitioning.numPartitions == expectedPartitioning.numPartitions) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index e44aef09f1f3c..00741d660dd2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -37,6 +37,7 @@ import org.apache.spark.SparkEnv import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.physical.AllTuples import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ @@ -444,6 +445,24 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } } + val lastExecution = currentStream.lastExecution + if (currentStream.isInstanceOf[MicroBatchExecution] && lastExecution != null) { + // Verify if stateful operators have correct metadata and distribution + // This can often catch hard to debug errors when developing stateful operators + lastExecution.executedPlan.collect { case s: StatefulOperator => s }.foreach { s => + assert(s.stateInfo.map(_.numPartitions).contains(lastExecution.numStateStores)) + s.requiredChildDistribution.foreach { d => + withClue(s"$s specifies incorrect # partitions in requiredChildDistribution $d") { + assert(d.requiredNumPartitions.isDefined) + assert(d.requiredNumPartitions.get >= 1) + if (d != AllTuples) { + assert(d.requiredNumPartitions.get == s.stateInfo.get.numPartitions) + } + } + } + } + } + val (latestBatchData, allData) = sink match { case s: MemorySink => (s.latestBatchData, s.allData) case s: MemorySinkV2 => (s.latestBatchData, s.allData) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 97e065193fd05..1cae8cb8d47f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -44,7 +44,7 @@ object FailureSingleton { } class StreamingAggregationSuite extends StateStoreMetricsTest - with BeforeAndAfterAll with Assertions with StatefulOperatorTest { + with BeforeAndAfterAll with Assertions { override def afterAll(): Unit = { super.afterAll() @@ -281,8 +281,6 @@ class StreamingAggregationSuite extends StateStoreMetricsTest AddData(inputData, 0L, 5L, 5L, 10L), AdvanceManualClock(10 * 1000), CheckLastBatch((0L, 1), (5L, 2), (10L, 1)), - AssertOnQuery(sq => - checkChildOutputHashPartitioning[StateStoreRestoreExec](sq, Seq("value"))), // advance clock to 20 seconds, should retain keys >= 10 AddData(inputData, 15L, 15L, 20L), From 529f847105fa8d98a5dc4d20955e4870df6bc1c5 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Sat, 31 Mar 2018 10:34:01 +0800 Subject: [PATCH 0544/2461] [SPARK-23040][CORE][FOLLOW-UP] Avoid double wrap result Iterator. ## What changes were proposed in this pull request? Address https://github.com/apache/spark/pull/20449#discussion_r172414393, If `resultIter` is already a `InterruptibleIterator`, don't double wrap it. ## How was this patch tested? Existing tests. Author: Xingbo Jiang Closes #20920 from jiangxb1987/SPARK-23040. --- .../spark/shuffle/BlockStoreShuffleReader.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 85e7e56a04a7d..4103dfb10175e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -111,8 +111,13 @@ private[spark] class BlockStoreShuffleReader[K, C]( case None => aggregatedIter } - // Use another interruptible iterator here to support task cancellation as aggregator or(and) - // sorter may have consumed previous interruptible iterator. - new InterruptibleIterator[Product2[K, C]](context, resultIter) + + resultIter match { + case _: InterruptibleIterator[Product2[K, C]] => resultIter + case _ => + // Use another interruptible iterator here to support task cancellation as aggregator + // or(and) sorter may have consumed previous interruptible iterator. + new InterruptibleIterator[Product2[K, C]](context, resultIter) + } } } From 44a9f8e6e82c300dc61ca18515aee16f17f27501 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 2 Apr 2018 09:53:37 -0700 Subject: [PATCH 0545/2461] [SPARK-15009][PYTHON][FOLLOWUP] Add default param checks for CountVectorizerModel ## What changes were proposed in this pull request? Adding test for default params for `CountVectorizerModel` constructed from vocabulary. This required that the param `maxDF` be added, which was done in SPARK-23615. ## How was this patch tested? Added an explicit test for CountVectorizerModel in DefaultValuesTests. Author: Bryan Cutler Closes #20942 from BryanCutler/pyspark-CountVectorizerModel-default-param-test-SPARK-15009. --- python/pyspark/ml/tests.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6b4376cbf14e8..c2c4861e2aff4 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -2096,6 +2096,11 @@ def test_java_params(self): # NOTE: disable check_params_exist until there is parity with Scala API ParamTests.check_params(self, cls(), check_params_exist=False) + # Additional classes that need explicit construction + from pyspark.ml.feature import CountVectorizerModel + ParamTests.check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'), + check_params_exist=False) + def _squared_distance(a, b): if isinstance(a, Vector): From 6151f29f9f589301159482044fc32717f430db6e Mon Sep 17 00:00:00 2001 From: David Vogelbacher Date: Mon, 2 Apr 2018 12:00:37 -0700 Subject: [PATCH 0546/2461] [SPARK-23825][K8S] Requesting memory + memory overhead for pod memory ## What changes were proposed in this pull request? Kubernetes driver and executor pods should request `memory + memoryOverhead` as their resources instead of just `memory`, see https://issues.apache.org/jira/browse/SPARK-23825 ## How was this patch tested? Existing unit tests were adapted. Author: David Vogelbacher Closes #20943 from dvogelbacher/spark-23825. --- .../k8s/submit/steps/BasicDriverConfigurationStep.scala | 5 +---- .../spark/scheduler/cluster/k8s/ExecutorPodFactory.scala | 5 +---- .../submit/steps/BasicDriverConfigurationStepSuite.scala | 2 +- .../scheduler/cluster/k8s/ExecutorPodFactorySuite.scala | 6 ++++-- 4 files changed, 7 insertions(+), 11 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index 347c4d2d66826..b811db324108c 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -93,9 +93,6 @@ private[spark] class BasicDriverConfigurationStep( .withAmount(driverCpuCores) .build() val driverMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${driverMemoryMiB}Mi") - .build() - val driverMemoryLimitQuantity = new QuantityBuilder(false) .withAmount(s"${driverMemoryWithOverheadMiB}Mi") .build() val maybeCpuLimitQuantity = driverLimitCores.map { limitCores => @@ -117,7 +114,7 @@ private[spark] class BasicDriverConfigurationStep( .withNewResources() .addToRequests("cpu", driverCpuQuantity) .addToRequests("memory", driverMemoryQuantity) - .addToLimits("memory", driverMemoryLimitQuantity) + .addToLimits("memory", driverMemoryQuantity) .addToLimits(maybeCpuLimitQuantity.toMap.asJava) .endResources() .addToArgs("driver") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index 98cbd5607da00..ac42385459dda 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -108,9 +108,6 @@ private[spark] class ExecutorPodFactory( SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ executorLabels val executorMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${executorMemoryMiB}Mi") - .build() - val executorMemoryLimitQuantity = new QuantityBuilder(false) .withAmount(s"${executorMemoryWithOverhead}Mi") .build() val executorCpuQuantity = new QuantityBuilder(false) @@ -167,7 +164,7 @@ private[spark] class ExecutorPodFactory( .withImagePullPolicy(imagePullPolicy) .withNewResources() .addToRequests("memory", executorMemoryQuantity) - .addToLimits("memory", executorMemoryLimitQuantity) + .addToLimits("memory", executorMemoryQuantity) .addToRequests("cpu", executorCpuQuantity) .endResources() .addAllToEnv(executorEnv.asJava) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index ce068531c7673..e59c6d28a8cc2 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -91,7 +91,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { val resourceRequirements = preparedDriverSpec.driverContainer.getResources val requests = resourceRequirements.getRequests.asScala assert(requests("cpu").getAmount === "2") - assert(requests("memory").getAmount === "256Mi") + assert(requests("memory").getAmount === "456Mi") val limits = resourceRequirements.getLimits.asScala assert(limits("memory").getAmount === "456Mi") assert(limits("cpu").getAmount === "4") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 7755b93835047..cee8fe27039c9 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -66,12 +66,14 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef assert(executor.getMetadata.getLabels.size() === 3) assert(executor.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) === "1") - // There is exactly 1 container with no volume mounts and default memory limits. - // Default memory limit is 1024M + 384M (minimum overhead constant). + // There is exactly 1 container with no volume mounts and default memory limits and requests. + // Default memory limit/request is 1024M + 384M (minimum overhead constant). assert(executor.getSpec.getContainers.size() === 1) assert(executor.getSpec.getContainers.get(0).getImage === executorImage) assert(executor.getSpec.getContainers.get(0).getVolumeMounts.isEmpty) assert(executor.getSpec.getContainers.get(0).getResources.getLimits.size() === 1) + assert(executor.getSpec.getContainers.get(0).getResources + .getRequests.get("memory").getAmount === "1408Mi") assert(executor.getSpec.getContainers.get(0).getResources .getLimits.get("memory").getAmount === "1408Mi") From fe2b7a4568d65a62da6e6eb00fff05f248b4332c Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Mon, 2 Apr 2018 12:20:55 -0700 Subject: [PATCH 0547/2461] [SPARK-23285][K8S] Add a config property for specifying physical executor cores ## What changes were proposed in this pull request? As mentioned in SPARK-23285, this PR introduces a new configuration property `spark.kubernetes.executor.cores` for specifying the physical CPU cores requested for each executor pod. This is to avoid changing the semantics of `spark.executor.cores` and `spark.task.cpus` and their role in task scheduling, task parallelism, dynamic resource allocation, etc. The new configuration property only determines the physical CPU cores available to an executor. An executor can still run multiple tasks simultaneously by using appropriate values for `spark.executor.cores` and `spark.task.cpus`. ## How was this patch tested? Unit tests. felixcheung srowen jiangxb1987 jerryshao mccheah foxish Author: Yinan Li Author: Yinan Li Closes #20553 from liyinan926/master. --- docs/running-on-kubernetes.md | 15 ++++++++--- .../org/apache/spark/deploy/k8s/Config.scala | 6 +++++ .../cluster/k8s/ExecutorPodFactory.scala | 12 ++++++--- .../cluster/k8s/ExecutorPodFactorySuite.scala | 27 +++++++++++++++++++ 4 files changed, 53 insertions(+), 7 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 975b28de47e20..9c4644947c911 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -549,14 +549,23 @@ specific to Spark on Kubernetes. spark.kubernetes.driver.limit.cores (none) - Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for the driver pod. + Specify a hard cpu [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for the driver pod. + + spark.kubernetes.executor.request.cores + (none) + + Specify the cpu request for each executor pod. Values conform to the Kubernetes [convention](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#meaning-of-cpu). + Example values include 0.1, 500m, 1.5, 5, etc., with the definition of cpu units documented in [CPU units](https://kubernetes.io/docs/tasks/configure-pod-container/assign-cpu-resource/#cpu-units). + This is distinct from spark.executor.cores: it is only used and takes precedence over spark.executor.cores for specifying the executor pod cpu request if set. Task + parallelism, e.g., number of tasks an executor can run concurrently is not affected by this. + spark.kubernetes.executor.limit.cores (none) - Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for each executor pod launched for the Spark Application. + Specify a hard cpu [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for each executor pod launched for the Spark Application. @@ -593,4 +602,4 @@ specific to Spark on Kubernetes. spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. - \ No newline at end of file + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index da34a7e06238a..405ea476351bb 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -91,6 +91,12 @@ private[spark] object Config extends Logging { .stringConf .createOptional + val KUBERNETES_EXECUTOR_REQUEST_CORES = + ConfigBuilder("spark.kubernetes.executor.request.cores") + .doc("Specify the cpu request for each executor pod") + .stringConf + .createOptional + val KUBERNETES_DRIVER_POD_NAME = ConfigBuilder("spark.kubernetes.driver.pod.name") .doc("Name of the driver pod.") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index ac42385459dda..7143f7a6f0b71 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -83,7 +83,12 @@ private[spark] class ExecutorPodFactory( MEMORY_OVERHEAD_MIN_MIB)) private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB - private val executorCores = sparkConf.getDouble("spark.executor.cores", 1) + private val executorCores = sparkConf.getInt("spark.executor.cores", 1) + private val executorCoresRequest = if (sparkConf.contains(KUBERNETES_EXECUTOR_REQUEST_CORES)) { + sparkConf.get(KUBERNETES_EXECUTOR_REQUEST_CORES).get + } else { + executorCores.toString + } private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) /** @@ -111,7 +116,7 @@ private[spark] class ExecutorPodFactory( .withAmount(s"${executorMemoryWithOverhead}Mi") .build() val executorCpuQuantity = new QuantityBuilder(false) - .withAmount(executorCores.toString) + .withAmount(executorCoresRequest) .build() val executorExtraClasspathEnv = executorExtraClasspath.map { cp => new EnvVarBuilder() @@ -130,8 +135,7 @@ private[spark] class ExecutorPodFactory( }.getOrElse(Seq.empty[EnvVar]) val executorEnv = (Seq( (ENV_DRIVER_URL, driverUrl), - // Executor backend expects integral value for executor cores, so round it up to an int. - (ENV_EXECUTOR_CORES, math.ceil(executorCores).toInt.toString), + (ENV_EXECUTOR_CORES, executorCores.toString), (ENV_EXECUTOR_MEMORY, executorMemoryString), (ENV_APPLICATION_ID, applicationId), // This is to set the SPARK_CONF_DIR to be /opt/spark/conf diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index cee8fe27039c9..a71a2a1b888bc 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -85,6 +85,33 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef checkOwnerReferences(executor, driverPodUid) } + test("executor core request specification") { + var factory = new ExecutorPodFactory(baseConf, None) + var executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + assert(executor.getSpec.getContainers.size() === 1) + assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount + === "1") + + val conf = baseConf.clone() + + conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "0.1") + factory = new ExecutorPodFactory(conf, None) + executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + assert(executor.getSpec.getContainers.size() === 1) + assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount + === "0.1") + + conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "100m") + factory = new ExecutorPodFactory(conf, None) + conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "100m") + executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount + === "100m") + } + test("executor pod hostnames get truncated to 63 characters") { val conf = baseConf.clone() conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, From a7c19d9c21d59fd0109a7078c80b33d3da03fafd Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 2 Apr 2018 21:48:44 +0200 Subject: [PATCH 0548/2461] [SPARK-23713][SQL] Cleanup UnsafeWriter and BufferHolder classes ## What changes were proposed in this pull request? This PR implemented the following cleanups related to `UnsafeWriter` class: - Remove code duplication between `UnsafeRowWriter` and `UnsafeArrayWriter` - Make `BufferHolder` class internal by delegating its accessor methods to `UnsafeWriter` - Replace `UnsafeRow.setTotalSize(...)` with `UnsafeRowWriter.setTotalSize()` ## How was this patch tested? Tested by existing UTs Author: Kazuaki Ishizaki Closes #20850 from kiszk/SPARK-23713. --- .../sql/kafka010/KafkaContinuousReader.scala | 3 - .../KafkaRecordToUnsafeRowConverter.scala | 11 +- .../expressions/codegen/BufferHolder.java | 32 +-- .../codegen/UnsafeArrayWriter.java | 133 +++--------- .../expressions/codegen/UnsafeRowWriter.java | 189 +++++++----------- .../expressions/codegen/UnsafeWriter.java | 157 ++++++++++++++- .../InterpretedUnsafeProjection.scala | 90 ++++----- .../codegen/GenerateUnsafeProjection.scala | 124 +++++------- .../RowBasedKeyValueBatchSuite.java | 28 +-- .../aggregate/RowBasedHashMapGenerator.scala | 12 +- .../columnar/GenerateColumnAccessor.scala | 9 +- .../datasources/text/TextFileFormat.scala | 11 +- 12 files changed, 391 insertions(+), 408 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index e7e27876088f3..f26c134c2f6e9 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -27,13 +27,10 @@ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.types.UTF8String /** * A [[ContinuousReader]] for data from kafka. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala index 1acdd56125741..f35a143e00374 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala @@ -20,18 +20,16 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.clients.consumer.ConsumerRecord import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.unsafe.types.UTF8String /** A simple class for converting Kafka ConsumerRecord to UnsafeRow */ private[kafka010] class KafkaRecordToUnsafeRowConverter { - private val sharedRow = new UnsafeRow(7) - private val bufferHolder = new BufferHolder(sharedRow) - private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) + private val rowWriter = new UnsafeRowWriter(7) def toUnsafeRow(record: ConsumerRecord[Array[Byte], Array[Byte]]): UnsafeRow = { - bufferHolder.reset() + rowWriter.reset() if (record.key == null) { rowWriter.setNullAt(0) @@ -46,7 +44,6 @@ private[kafka010] class KafkaRecordToUnsafeRowConverter { 5, DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp))) rowWriter.write(6, record.timestampType.id) - sharedRow.setTotalSize(bufferHolder.totalSize) - sharedRow + rowWriter.getRow() } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 259976118c12f..537ef244b7e81 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -30,25 +30,21 @@ * this class per writing program, so that the memory segment/data buffer can be reused. Note that * for each incoming record, we should call `reset` of BufferHolder instance before write the record * and reuse the data buffer. - * - * Generally we should call `UnsafeRow.setTotalSize` and pass in `BufferHolder.totalSize` to update - * the size of the result row, after writing a record to the buffer. However, we can skip this step - * if the fields of row are all fixed-length, as the size of result row is also fixed. */ -public class BufferHolder { +final class BufferHolder { private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; - public byte[] buffer; - public int cursor = Platform.BYTE_ARRAY_OFFSET; + private byte[] buffer; + private int cursor = Platform.BYTE_ARRAY_OFFSET; private final UnsafeRow row; private final int fixedSize; - public BufferHolder(UnsafeRow row) { + BufferHolder(UnsafeRow row) { this(row, 64); } - public BufferHolder(UnsafeRow row, int initialSize) { + BufferHolder(UnsafeRow row, int initialSize) { int bitsetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()); if (row.numFields() > (ARRAY_MAX - initialSize - bitsetWidthInBytes) / 8) { throw new UnsupportedOperationException( @@ -64,7 +60,7 @@ public BufferHolder(UnsafeRow row, int initialSize) { /** * Grows the buffer by at least neededSize and points the row to the buffer. */ - public void grow(int neededSize) { + void grow(int neededSize) { if (neededSize > ARRAY_MAX - totalSize()) { throw new UnsupportedOperationException( "Cannot grow BufferHolder by size " + neededSize + " because the size after growing " + @@ -86,11 +82,23 @@ public void grow(int neededSize) { } } - public void reset() { + byte[] getBuffer() { + return buffer; + } + + int getCursor() { + return cursor; + } + + void increaseCursor(int val) { + cursor += val; + } + + void reset() { cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize; } - public int totalSize() { + int totalSize() { return cursor - Platform.BYTE_ARRAY_OFFSET; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 82cd1b24607e1..a78dd970d23e4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -21,8 +21,6 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; import static org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculateHeaderPortionInBytes; @@ -32,14 +30,12 @@ */ public final class UnsafeArrayWriter extends UnsafeWriter { - private BufferHolder holder; - - // The offset of the global buffer where we start to write this array. - private int startingOffset; - // The number of elements in this array private int numElements; + // The element size in this array + private int elementSize; + private int headerInBytes; private void assertIndexIsValid(int index) { @@ -47,13 +43,17 @@ private void assertIndexIsValid(int index) { assert index < numElements : "index (" + index + ") should < " + numElements; } - public void initialize(BufferHolder holder, int numElements, int elementSize) { + public UnsafeArrayWriter(UnsafeWriter writer, int elementSize) { + super(writer.getBufferHolder()); + this.elementSize = elementSize; + } + + public void initialize(int numElements) { // We need 8 bytes to store numElements in header this.numElements = numElements; this.headerInBytes = calculateHeaderPortionInBytes(numElements); - this.holder = holder; - this.startingOffset = holder.cursor; + this.startingOffset = cursor(); // Grows the global buffer ahead for header and fixed size data. int fixedPartInBytes = @@ -61,112 +61,92 @@ public void initialize(BufferHolder holder, int numElements, int elementSize) { holder.grow(headerInBytes + fixedPartInBytes); // Write numElements and clear out null bits to header - Platform.putLong(holder.buffer, startingOffset, numElements); + Platform.putLong(getBuffer(), startingOffset, numElements); for (int i = 8; i < headerInBytes; i += 8) { - Platform.putLong(holder.buffer, startingOffset + i, 0L); + Platform.putLong(getBuffer(), startingOffset + i, 0L); } // fill 0 into reminder part of 8-bytes alignment in unsafe array for (int i = elementSize * numElements; i < fixedPartInBytes; i++) { - Platform.putByte(holder.buffer, startingOffset + headerInBytes + i, (byte) 0); + Platform.putByte(getBuffer(), startingOffset + headerInBytes + i, (byte) 0); } - holder.cursor += (headerInBytes + fixedPartInBytes); + increaseCursor(headerInBytes + fixedPartInBytes); } - private void zeroOutPaddingBytes(int numBytes) { - if ((numBytes & 0x07) > 0) { - Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); - } - } - - private long getElementOffset(int ordinal, int elementSize) { + private long getElementOffset(int ordinal) { return startingOffset + headerInBytes + ordinal * elementSize; } - public void setOffsetAndSize(int ordinal, int currentCursor, int size) { - assertIndexIsValid(ordinal); - final long relativeOffset = currentCursor - startingOffset; - final long offsetAndSize = (relativeOffset << 32) | (long)size; - - write(ordinal, offsetAndSize); - } - private void setNullBit(int ordinal) { assertIndexIsValid(ordinal); - BitSetMethods.set(holder.buffer, startingOffset + 8, ordinal); + BitSetMethods.set(getBuffer(), startingOffset + 8, ordinal); } public void setNull1Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), (byte)0); + writeByte(getElementOffset(ordinal), (byte)0); } public void setNull2Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), (short)0); + writeShort(getElementOffset(ordinal), (short)0); } public void setNull4Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), 0); + writeInt(getElementOffset(ordinal), 0); } public void setNull8Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), (long)0); + writeLong(getElementOffset(ordinal), 0); } public void setNull(int ordinal) { setNull8Bytes(ordinal); } public void write(int ordinal, boolean value) { assertIndexIsValid(ordinal); - Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), value); + writeBoolean(getElementOffset(ordinal), value); } public void write(int ordinal, byte value) { assertIndexIsValid(ordinal); - Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), value); + writeByte(getElementOffset(ordinal), value); } public void write(int ordinal, short value) { assertIndexIsValid(ordinal); - Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), value); + writeShort(getElementOffset(ordinal), value); } public void write(int ordinal, int value) { assertIndexIsValid(ordinal); - Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), value); + writeInt(getElementOffset(ordinal), value); } public void write(int ordinal, long value) { assertIndexIsValid(ordinal); - Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), value); + writeLong(getElementOffset(ordinal), value); } public void write(int ordinal, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } assertIndexIsValid(ordinal); - Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), value); + writeFloat(getElementOffset(ordinal), value); } public void write(int ordinal, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } assertIndexIsValid(ordinal); - Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), value); + writeDouble(getElementOffset(ordinal), value); } public void write(int ordinal, Decimal input, int precision, int scale) { // make sure Decimal object has the same scale as DecimalType assertIndexIsValid(ordinal); - if (input.changePrecision(precision, scale)) { + if (input != null && input.changePrecision(precision, scale)) { if (precision <= Decimal.MAX_LONG_DIGITS()) { write(ordinal, input.toUnscaledLong()); } else { @@ -180,65 +160,14 @@ public void write(int ordinal, Decimal input, int precision, int scale) { // Write the bytes to the variable length portion. Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes); - setOffsetAndSize(ordinal, holder.cursor, numBytes); + bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes); + setOffsetAndSize(ordinal, numBytes); // move the cursor forward with 8-bytes boundary - holder.cursor += roundedSize; + increaseCursor(roundedSize); } } else { setNull(ordinal); } } - - public void write(int ordinal, UTF8String input) { - final int numBytes = input.numBytes(); - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - input.writeToMemory(holder.buffer, holder.cursor); - - setOffsetAndSize(ordinal, holder.cursor, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, byte[] input) { - final int numBytes = input.length; - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - Platform.copyMemory( - input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes); - - setOffsetAndSize(ordinal, holder.cursor, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, CalendarInterval input) { - // grow the global buffer before writing data. - holder.grow(16); - - // Write the months and microseconds fields of Interval to the variable length portion. - Platform.putLong(holder.buffer, holder.cursor, input.months); - Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds); - - setOffsetAndSize(ordinal, holder.cursor, 16); - - // move the cursor forward. - holder.cursor += 16; - } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 2620bbcfb87a2..71c49d8ed0177 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -20,10 +20,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; /** * A helper class to write data into global row buffer using `UnsafeRow` format. @@ -31,7 +28,7 @@ * It will remember the offset of row buffer which it starts to write, and move the cursor of row * buffer while writing. If new data(can be the input record if this is the outermost writer, or * nested struct if this is an inner writer) comes, the starting cursor of row buffer may be - * changed, so we need to call `UnsafeRowWriter.reset` before writing, to update the + * changed, so we need to call `UnsafeRowWriter.resetRowWriter` before writing, to update the * `startingOffset` and clear out null bits. * * Note that if this is the outermost writer, which means we will always write from the very @@ -40,29 +37,58 @@ */ public final class UnsafeRowWriter extends UnsafeWriter { - private final BufferHolder holder; - // The offset of the global buffer where we start to write this row. - private int startingOffset; + private final UnsafeRow row; + private final int nullBitsSize; private final int fixedSize; - public UnsafeRowWriter(BufferHolder holder, int numFields) { - this.holder = holder; + public UnsafeRowWriter(int numFields) { + this(new UnsafeRow(numFields)); + } + + public UnsafeRowWriter(int numFields, int initialBufferSize) { + this(new UnsafeRow(numFields), initialBufferSize); + } + + public UnsafeRowWriter(UnsafeWriter writer, int numFields) { + this(null, writer.getBufferHolder(), numFields); + } + + private UnsafeRowWriter(UnsafeRow row) { + this(row, new BufferHolder(row), row.numFields()); + } + + private UnsafeRowWriter(UnsafeRow row, int initialBufferSize) { + this(row, new BufferHolder(row, initialBufferSize), row.numFields()); + } + + private UnsafeRowWriter(UnsafeRow row, BufferHolder holder, int numFields) { + super(holder); + this.row = row; this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields); this.fixedSize = nullBitsSize + 8 * numFields; - this.startingOffset = holder.cursor; + this.startingOffset = cursor(); + } + + /** + * Updates total size of the UnsafeRow using the size collected by BufferHolder, and returns + * the UnsafeRow created at a constructor + */ + public UnsafeRow getRow() { + row.setTotalSize(totalSize()); + return row; } /** * Resets the `startingOffset` according to the current cursor of row buffer, and clear out null * bits. This should be called before we write a new nested struct to the row buffer. */ - public void reset() { - this.startingOffset = holder.cursor; + public void resetRowWriter() { + this.startingOffset = cursor(); // grow the global buffer to make sure it has enough space to write fixed-length data. - holder.grow(fixedSize); - holder.cursor += fixedSize; + grow(fixedSize); + increaseCursor(fixedSize); zeroOutNullBytes(); } @@ -72,25 +98,17 @@ public void reset() { */ public void zeroOutNullBytes() { for (int i = 0; i < nullBitsSize; i += 8) { - Platform.putLong(holder.buffer, startingOffset + i, 0L); - } - } - - private void zeroOutPaddingBytes(int numBytes) { - if ((numBytes & 0x07) > 0) { - Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); + Platform.putLong(getBuffer(), startingOffset + i, 0L); } } - public BufferHolder holder() { return holder; } - public boolean isNullAt(int ordinal) { - return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal); + return BitSetMethods.isSet(getBuffer(), startingOffset, ordinal); } public void setNullAt(int ordinal) { - BitSetMethods.set(holder.buffer, startingOffset, ordinal); - Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L); + BitSetMethods.set(getBuffer(), startingOffset, ordinal); + write(ordinal, 0L); } @Override @@ -117,67 +135,49 @@ public long getFieldOffset(int ordinal) { return startingOffset + nullBitsSize + 8 * ordinal; } - public void setOffsetAndSize(int ordinal, int size) { - setOffsetAndSize(ordinal, holder.cursor, size); - } - - public void setOffsetAndSize(int ordinal, int currentCursor, int size) { - final long relativeOffset = currentCursor - startingOffset; - final long fieldOffset = getFieldOffset(ordinal); - final long offsetAndSize = (relativeOffset << 32) | (long) size; - - Platform.putLong(holder.buffer, fieldOffset, offsetAndSize); - } - public void write(int ordinal, boolean value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putBoolean(holder.buffer, offset, value); + writeLong(offset, 0L); + writeBoolean(offset, value); } public void write(int ordinal, byte value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putByte(holder.buffer, offset, value); + writeLong(offset, 0L); + writeByte(offset, value); } public void write(int ordinal, short value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putShort(holder.buffer, offset, value); + writeLong(offset, 0L); + writeShort(offset, value); } public void write(int ordinal, int value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putInt(holder.buffer, offset, value); + writeLong(offset, 0L); + writeInt(offset, value); } public void write(int ordinal, long value) { - Platform.putLong(holder.buffer, getFieldOffset(ordinal), value); + writeLong(getFieldOffset(ordinal), value); } public void write(int ordinal, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putFloat(holder.buffer, offset, value); + writeLong(offset, 0); + writeFloat(offset, value); } public void write(int ordinal, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } - Platform.putDouble(holder.buffer, getFieldOffset(ordinal), value); + writeDouble(getFieldOffset(ordinal), value); } public void write(int ordinal, Decimal input, int precision, int scale) { if (precision <= Decimal.MAX_LONG_DIGITS()) { // make sure Decimal object has the same scale as DecimalType - if (input.changePrecision(precision, scale)) { - Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong()); + if (input != null && input.changePrecision(precision, scale)) { + write(ordinal, input.toUnscaledLong()); } else { setNullAt(ordinal); } @@ -185,82 +185,31 @@ public void write(int ordinal, Decimal input, int precision, int scale) { // grow the global buffer before writing data. holder.grow(16); - // zero-out the bytes - Platform.putLong(holder.buffer, holder.cursor, 0L); - Platform.putLong(holder.buffer, holder.cursor + 8, 0L); - // Make sure Decimal object has the same scale as DecimalType. // Note that we may pass in null Decimal object to set null for it. if (input == null || !input.changePrecision(precision, scale)) { - BitSetMethods.set(holder.buffer, startingOffset, ordinal); + // zero-out the bytes + Platform.putLong(getBuffer(), cursor(), 0L); + Platform.putLong(getBuffer(), cursor() + 8, 0L); + + BitSetMethods.set(getBuffer(), startingOffset, ordinal); // keep the offset for future update setOffsetAndSize(ordinal, 0); } else { final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); - assert bytes.length <= 16; + final int numBytes = bytes.length; + assert numBytes <= 16; + + zeroOutPaddingBytes(numBytes); // Write the bytes to the variable length portion. Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length); + bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes); setOffsetAndSize(ordinal, bytes.length); } // move the cursor forward. - holder.cursor += 16; + increaseCursor(16); } } - - public void write(int ordinal, UTF8String input) { - final int numBytes = input.numBytes(); - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - input.writeToMemory(holder.buffer, holder.cursor); - - setOffsetAndSize(ordinal, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, byte[] input) { - write(ordinal, input, 0, input.length); - } - - public void write(int ordinal, byte[] input, int offset, int numBytes) { - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET + offset, - holder.buffer, holder.cursor, numBytes); - - setOffsetAndSize(ordinal, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, CalendarInterval input) { - // grow the global buffer before writing data. - holder.grow(16); - - // Write the months and microseconds fields of Interval to the variable length portion. - Platform.putLong(holder.buffer, holder.cursor, input.months); - Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds); - - setOffsetAndSize(ordinal, 16); - - // move the cursor forward. - holder.cursor += 16; - } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index c94b5c7a367ef..de0eb6dbb76be 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen; import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -24,10 +26,73 @@ * Base class for writing Unsafe* structures. */ public abstract class UnsafeWriter { + // Keep internal buffer holder + protected final BufferHolder holder; + + // The offset of the global buffer where we start to write this structure. + protected int startingOffset; + + protected UnsafeWriter(BufferHolder holder) { + this.holder = holder; + } + + /** + * Accessor methods are delegated from BufferHolder class + */ + public final BufferHolder getBufferHolder() { + return holder; + } + + public final byte[] getBuffer() { + return holder.getBuffer(); + } + + public final void reset() { + holder.reset(); + } + + public final int totalSize() { + return holder.totalSize(); + } + + public final void grow(int neededSize) { + holder.grow(neededSize); + } + + public final int cursor() { + return holder.getCursor(); + } + + public final void increaseCursor(int val) { + holder.increaseCursor(val); + } + + public final void setOffsetAndSizeFromPreviousCursor(int ordinal, int previousCursor) { + setOffsetAndSize(ordinal, previousCursor, cursor() - previousCursor); + } + + protected void setOffsetAndSize(int ordinal, int size) { + setOffsetAndSize(ordinal, cursor(), size); + } + + protected void setOffsetAndSize(int ordinal, int currentCursor, int size) { + final long relativeOffset = currentCursor - startingOffset; + final long offsetAndSize = (relativeOffset << 32) | (long)size; + + write(ordinal, offsetAndSize); + } + + protected final void zeroOutPaddingBytes(int numBytes) { + if ((numBytes & 0x07) > 0) { + Platform.putLong(getBuffer(), cursor() + ((numBytes >> 3) << 3), 0L); + } + } + public abstract void setNull1Bytes(int ordinal); public abstract void setNull2Bytes(int ordinal); public abstract void setNull4Bytes(int ordinal); public abstract void setNull8Bytes(int ordinal); + public abstract void write(int ordinal, boolean value); public abstract void write(int ordinal, byte value); public abstract void write(int ordinal, short value); @@ -36,8 +101,92 @@ public abstract class UnsafeWriter { public abstract void write(int ordinal, float value); public abstract void write(int ordinal, double value); public abstract void write(int ordinal, Decimal input, int precision, int scale); - public abstract void write(int ordinal, UTF8String input); - public abstract void write(int ordinal, byte[] input); - public abstract void write(int ordinal, CalendarInterval input); - public abstract void setOffsetAndSize(int ordinal, int currentCursor, int size); + + public final void write(int ordinal, UTF8String input) { + final int numBytes = input.numBytes(); + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); + + // grow the global buffer before writing data. + grow(roundedSize); + + zeroOutPaddingBytes(numBytes); + + // Write the bytes to the variable length portion. + input.writeToMemory(getBuffer(), cursor()); + + setOffsetAndSize(ordinal, numBytes); + + // move the cursor forward. + increaseCursor(roundedSize); + } + + public final void write(int ordinal, byte[] input) { + write(ordinal, input, 0, input.length); + } + + public final void write(int ordinal, byte[] input, int offset, int numBytes) { + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length); + + // grow the global buffer before writing data. + grow(roundedSize); + + zeroOutPaddingBytes(numBytes); + + // Write the bytes to the variable length portion. + Platform.copyMemory( + input, Platform.BYTE_ARRAY_OFFSET + offset, getBuffer(), cursor(), numBytes); + + setOffsetAndSize(ordinal, numBytes); + + // move the cursor forward. + increaseCursor(roundedSize); + } + + public final void write(int ordinal, CalendarInterval input) { + // grow the global buffer before writing data. + grow(16); + + // Write the months and microseconds fields of Interval to the variable length portion. + Platform.putLong(getBuffer(), cursor(), input.months); + Platform.putLong(getBuffer(), cursor() + 8, input.microseconds); + + setOffsetAndSize(ordinal, 16); + + // move the cursor forward. + increaseCursor(16); + } + + protected final void writeBoolean(long offset, boolean value) { + Platform.putBoolean(getBuffer(), offset, value); + } + + protected final void writeByte(long offset, byte value) { + Platform.putByte(getBuffer(), offset, value); + } + + protected final void writeShort(long offset, short value) { + Platform.putShort(getBuffer(), offset, value); + } + + protected final void writeInt(long offset, int value) { + Platform.putInt(getBuffer(), offset, value); + } + + protected final void writeLong(long offset, long value) { + Platform.putLong(getBuffer(), offset, value); + } + + protected final void writeFloat(long offset, float value) { + if (Float.isNaN(value)) { + value = Float.NaN; + } + Platform.putFloat(getBuffer(), offset, value); + } + + protected final void writeDouble(long offset, double value) { + if (Double.isNaN(value)) { + value = Double.NaN; + } + Platform.putDouble(getBuffer(), offset, value); + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 0da5ece7e47fe..b31466f5c92d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter} +import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter} import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types.{UserDefinedType, _} import org.apache.spark.unsafe.Platform @@ -42,17 +42,12 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe /** The row representing the expression results. */ private[this] val intermediate = new GenericInternalRow(values) - /** The row returned by the projection. */ - private[this] val result = new UnsafeRow(numFields) - - /** The buffer which holds the resulting row's backing data. */ - private[this] val holder = new BufferHolder(result, numFields * 32) + /* The row writer for UnsafeRow result */ + private[this] val rowWriter = new UnsafeRowWriter(numFields, numFields * 32) /** The writer that writes the intermediate result to the result row. */ private[this] val writer: InternalRow => Unit = { - val rowWriter = new UnsafeRowWriter(holder, numFields) val baseWriter = generateStructWriter( - holder, rowWriter, expressions.map(e => StructField("", e.dataType, e.nullable))) if (!expressions.exists(_.nullable)) { @@ -83,10 +78,9 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe } // Write the intermediate row to an unsafe row. - holder.reset() + rowWriter.reset() writer(intermediate) - result.setTotalSize(holder.totalSize()) - result + rowWriter.getRow() } } @@ -111,14 +105,13 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * given buffer using the given [[UnsafeRowWriter]]. */ private def generateStructWriter( - bufferHolder: BufferHolder, rowWriter: UnsafeRowWriter, fields: Array[StructField]): InternalRow => Unit = { val numFields = fields.length // Create field writers. val fieldWriters = fields.map { field => - generateFieldWriter(bufferHolder, rowWriter, field.dataType, field.nullable) + generateFieldWriter(rowWriter, field.dataType, field.nullable) } // Create basic writer. row => { @@ -136,7 +129,6 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * or array) to the given buffer using the given [[UnsafeWriter]]. */ private def generateFieldWriter( - bufferHolder: BufferHolder, writer: UnsafeWriter, dt: DataType, nullable: Boolean): (SpecializedGetters, Int) => Unit = { @@ -178,81 +170,79 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { case StructType(fields) => val numFields = fields.length - val rowWriter = new UnsafeRowWriter(bufferHolder, numFields) - val structWriter = generateStructWriter(bufferHolder, rowWriter, fields) + val rowWriter = new UnsafeRowWriter(writer, numFields) + val structWriter = generateStructWriter(rowWriter, fields) (v, i) => { - val tmpCursor = bufferHolder.cursor + val previousCursor = writer.cursor() v.getStruct(i, fields.length) match { case row: UnsafeRow => writeUnsafeData( - bufferHolder, + rowWriter, row.getBaseObject, row.getBaseOffset, row.getSizeInBytes) case row => // Nested struct. We don't know where this will start because a row can be // variable length, so we need to update the offsets and zero out the bit mask. - rowWriter.reset() + rowWriter.resetRowWriter() structWriter.apply(row) } - writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case ArrayType(elementType, containsNull) => - val arrayWriter = new UnsafeArrayWriter - val elementSize = getElementSize(elementType) + val arrayWriter = new UnsafeArrayWriter(writer, getElementSize(elementType)) val elementWriter = generateFieldWriter( - bufferHolder, arrayWriter, elementType, containsNull) (v, i) => { - val tmpCursor = bufferHolder.cursor - writeArray(bufferHolder, arrayWriter, elementWriter, v.getArray(i), elementSize) - writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + val previousCursor = writer.cursor() + writeArray(arrayWriter, elementWriter, v.getArray(i)) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case MapType(keyType, valueType, valueContainsNull) => - val keyArrayWriter = new UnsafeArrayWriter - val keySize = getElementSize(keyType) + val keyArrayWriter = new UnsafeArrayWriter(writer, getElementSize(keyType)) val keyWriter = generateFieldWriter( - bufferHolder, keyArrayWriter, keyType, nullable = false) - val valueArrayWriter = new UnsafeArrayWriter - val valueSize = getElementSize(valueType) + val valueArrayWriter = new UnsafeArrayWriter(writer, getElementSize(valueType)) val valueWriter = generateFieldWriter( - bufferHolder, valueArrayWriter, valueType, valueContainsNull) (v, i) => { - val tmpCursor = bufferHolder.cursor + val previousCursor = writer.cursor() v.getMap(i) match { case map: UnsafeMapData => writeUnsafeData( - bufferHolder, + valueArrayWriter, map.getBaseObject, map.getBaseOffset, map.getSizeInBytes) case map => // preserve 8 bytes to write the key array numBytes later. - bufferHolder.grow(8) - bufferHolder.cursor += 8 + valueArrayWriter.grow(8) + valueArrayWriter.increaseCursor(8) // Write the keys and write the numBytes of key array into the first 8 bytes. - writeArray(bufferHolder, keyArrayWriter, keyWriter, map.keyArray(), keySize) - Platform.putLong(bufferHolder.buffer, tmpCursor, bufferHolder.cursor - tmpCursor - 8) + writeArray(keyArrayWriter, keyWriter, map.keyArray()) + Platform.putLong( + valueArrayWriter.getBuffer, + previousCursor, + valueArrayWriter.cursor - previousCursor - 8 + ) // Write the values. - writeArray(bufferHolder, valueArrayWriter, valueWriter, map.valueArray(), valueSize) + writeArray(valueArrayWriter, valueWriter, map.valueArray()) } - writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case udt: UserDefinedType[_] => - generateFieldWriter(bufferHolder, writer, udt.sqlType, nullable) + generateFieldWriter(writer, udt.sqlType, nullable) case NullType => (_, _) => {} @@ -324,20 +314,18 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * copy. */ private def writeArray( - bufferHolder: BufferHolder, arrayWriter: UnsafeArrayWriter, elementWriter: (SpecializedGetters, Int) => Unit, - array: ArrayData, - elementSize: Int): Unit = array match { + array: ArrayData): Unit = array match { case unsafe: UnsafeArrayData => writeUnsafeData( - bufferHolder, + arrayWriter, unsafe.getBaseObject, unsafe.getBaseOffset, unsafe.getSizeInBytes) case _ => val numElements = array.numElements() - arrayWriter.initialize(bufferHolder, numElements, elementSize) + arrayWriter.initialize(numElements) var i = 0 while (i < numElements) { elementWriter.apply(array, i) @@ -350,17 +338,17 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * [[UnsafeRow]], [[UnsafeArrayData]] and [[UnsafeMapData]] objects. */ private def writeUnsafeData( - bufferHolder: BufferHolder, + writer: UnsafeWriter, baseObject: AnyRef, baseOffset: Long, sizeInBytes: Int) : Unit = { - bufferHolder.grow(sizeInBytes) + writer.grow(sizeInBytes) Platform.copyMemory( baseObject, baseOffset, - bufferHolder.buffer, - bufferHolder.cursor, + writer.getBuffer, + writer.cursor, sizeInBytes) - bufferHolder.cursor += sizeInBytes + writer.increaseCursor(sizeInBytes) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 6682ba55b18b1..ab2254cd9f70a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -48,19 +48,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, input: String, fieldTypes: Seq[DataType], - bufferHolder: String): String = { + rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => ExprCode("", s"$tmpInput.isNullAt($i)", CodeGenerator.getValue(tmpInput, dt, i.toString)) } + val rowWriterClass = classOf[UnsafeRowWriter].getName + val structRowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", + v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});") + s""" final InternalRow $tmpInput = $input; if ($tmpInput instanceof UnsafeRow) { - ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", bufferHolder)} + ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", structRowWriter)} } else { - ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, bufferHolder)} + ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)} } """ } @@ -70,12 +74,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro row: String, inputs: Seq[ExprCode], inputTypes: Seq[DataType], - bufferHolder: String, + rowWriter: String, isTopLevel: Boolean = false): String = { - val rowWriterClass = classOf[UnsafeRowWriter].getName - val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", - v => s"$v = new $rowWriterClass($bufferHolder, ${inputs.length});") - val resetWriter = if (isTopLevel) { // For top level row writer, it always writes to the beginning of the global buffer holder, // which means its fixed-size region always in the same position, so we don't need to call @@ -88,7 +88,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.zeroOutNullBytes();" } } else { - s"$rowWriter.reset();" + s"$rowWriter.resetRowWriter();" } val writeFields = inputs.zip(inputTypes).zipWithIndex.map { @@ -97,7 +97,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case udt: UserDefinedType[_] => udt.sqlType case other => other } - val tmpCursor = ctx.freshName("tmpCursor") val setNull = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => @@ -105,33 +104,34 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});" case _ => s"$rowWriter.setNullAt($index);" } + val previousCursor = ctx.freshName("previousCursor") val writeField = dt match { case t: StructType => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - final int $tmpCursor = $bufferHolder.cursor; - ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $rowWriter.cursor(); + ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), rowWriter)} + $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case a @ ArrayType(et, _) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, input.value, et, bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $rowWriter.cursor(); + ${writeArrayToBuffer(ctx, input.value, et, rowWriter)} + $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case m @ MapType(kt, vt, _) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - final int $tmpCursor = $bufferHolder.cursor; - ${writeMapToBuffer(ctx, input.value, kt, vt, bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $rowWriter.cursor(); + ${writeMapToBuffer(ctx, input.value, kt, vt, rowWriter)} + $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case t: DecimalType => @@ -181,12 +181,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, input: String, elementType: DataType, - bufferHolder: String): String = { + rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") - val arrayWriterClass = classOf[UnsafeArrayWriter].getName - val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", - v => s"$v = new $arrayWriterClass();") val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") @@ -203,28 +200,32 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => 8 // we need 8 bytes to store offset and length } - val tmpCursor = ctx.freshName("tmpCursor") + val arrayWriterClass = classOf[UnsafeArrayWriter].getName + val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", + v => s"$v = new $arrayWriterClass($rowWriter, $elementOrOffsetSize);") + val previousCursor = ctx.freshName("previousCursor") + val element = CodeGenerator.getValue(tmpInput, et, index) val writeElement = et match { case t: StructType => s""" - final int $tmpCursor = $bufferHolder.cursor; - ${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $arrayWriter.cursor(); + ${writeStructToBuffer(ctx, element, t.map(_.dataType), arrayWriter)} + $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case a @ ArrayType(et, _) => s""" - final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, element, et, bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $arrayWriter.cursor(); + ${writeArrayToBuffer(ctx, element, et, arrayWriter)} + $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case m @ MapType(kt, vt, _) => s""" - final int $tmpCursor = $bufferHolder.cursor; - ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $arrayWriter.cursor(); + ${writeMapToBuffer(ctx, element, kt, vt, arrayWriter)} + $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case t: DecimalType => @@ -240,10 +241,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" final ArrayData $tmpInput = $input; if ($tmpInput instanceof UnsafeArrayData) { - ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", bufferHolder)} + ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", arrayWriter)} } else { final int $numElements = $tmpInput.numElements(); - $arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize); + $arrayWriter.initialize($numElements); for (int $index = 0; $index < $numElements; $index++) { if ($tmpInput.isNullAt($index)) { @@ -262,7 +263,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro input: String, keyType: DataType, valueType: DataType, - bufferHolder: String): String = { + rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val tmpCursor = ctx.freshName("tmpCursor") @@ -271,20 +272,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" final MapData $tmpInput = $input; if ($tmpInput instanceof UnsafeMapData) { - ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", bufferHolder)} + ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", rowWriter)} } else { // preserve 8 bytes to write the key array numBytes later. - $bufferHolder.grow(8); - $bufferHolder.cursor += 8; + $rowWriter.grow(8); + $rowWriter.increaseCursor(8); // Remember the current cursor so that we can write numBytes of key array later. - final int $tmpCursor = $bufferHolder.cursor; + final int $tmpCursor = $rowWriter.cursor(); - ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, bufferHolder)} + ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} // Write the numBytes of key array into the first 8 bytes. - Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor); + Platform.putLong($rowWriter.getBuffer(), $tmpCursor - 8, $rowWriter.cursor() - $tmpCursor); - ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, bufferHolder)} + ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} } """ } @@ -293,14 +294,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro * If the input is already in unsafe format, we don't need to go through all elements/fields, * we can directly write it. */ - private def writeUnsafeData(ctx: CodegenContext, input: String, bufferHolder: String) = { + private def writeUnsafeData(ctx: CodegenContext, input: String, rowWriter: String) = { val sizeInBytes = ctx.freshName("sizeInBytes") s""" final int $sizeInBytes = $input.getSizeInBytes(); // grow the global buffer before writing data. - $bufferHolder.grow($sizeInBytes); - $input.writeToMemory($bufferHolder.buffer, $bufferHolder.cursor); - $bufferHolder.cursor += $sizeInBytes; + $rowWriter.grow($sizeInBytes); + $input.writeToMemory($rowWriter.getBuffer(), $rowWriter.cursor()); + $rowWriter.increaseCursor($sizeInBytes); """ } @@ -317,38 +318,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => true } - val result = ctx.addMutableState("UnsafeRow", "result", - v => s"$v = new UnsafeRow(${expressions.length});") - - val holderClass = classOf[BufferHolder].getName - val holder = ctx.addMutableState(holderClass, "holder", - v => s"$v = new $holderClass($result, ${numVarLenFields * 32});") - - val resetBufferHolder = if (numVarLenFields == 0) { - "" - } else { - s"$holder.reset();" - } - val updateRowSize = if (numVarLenFields == 0) { - "" - } else { - s"$result.setTotalSize($holder.totalSize());" - } + val rowWriterClass = classOf[UnsafeRowWriter].getName + val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", + v => s"$v = new $rowWriterClass(${expressions.length}, ${numVarLenFields * 32});") // Evaluate all the subexpression. val evalSubexpr = ctx.subexprFunctions.mkString("\n") - val writeExpressions = - writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true) + val writeExpressions = writeExpressionsToBuffer( + ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true) val code = s""" - $resetBufferHolder + $rowWriter.reset(); $evalSubexpr $writeExpressions - $updateRowSize """ - ExprCode(code, "false", result) + ExprCode(code, "false", s"$rowWriter.getRow()") } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java index fb3dbe8ed1996..2da87113c6229 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java @@ -27,7 +27,6 @@ import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; import org.apache.spark.unsafe.types.UTF8String; @@ -55,36 +54,27 @@ private String getRandomString(int length) { } private UnsafeRow makeKeyRow(long k1, String k2) { - UnsafeRow row = new UnsafeRow(2); - BufferHolder holder = new BufferHolder(row, 32); - UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); - holder.reset(); + UnsafeRowWriter writer = new UnsafeRowWriter(2); + writer.reset(); writer.write(0, k1); writer.write(1, UTF8String.fromString(k2)); - row.setTotalSize(holder.totalSize()); - return row; + return writer.getRow(); } private UnsafeRow makeKeyRow(long k1, long k2) { - UnsafeRow row = new UnsafeRow(2); - BufferHolder holder = new BufferHolder(row, 0); - UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); - holder.reset(); + UnsafeRowWriter writer = new UnsafeRowWriter(2); + writer.reset(); writer.write(0, k1); writer.write(1, k2); - row.setTotalSize(holder.totalSize()); - return row; + return writer.getRow(); } private UnsafeRow makeValueRow(long v1, long v2) { - UnsafeRow row = new UnsafeRow(2); - BufferHolder holder = new BufferHolder(row, 0); - UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); - holder.reset(); + UnsafeRowWriter writer = new UnsafeRowWriter(2); + writer.reset(); writer.write(0, v1); writer.write(1, v2); - row.setTotalSize(holder.totalSize()); - return row; + return writer.getRow(); } private UnsafeRow appendRow(RowBasedKeyValueBatch batch, UnsafeRow key, UnsafeRow value) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index 8617be88f3570..d5508275c48c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -165,18 +165,14 @@ class RowBasedHashMapGenerator( | if (buckets[idx] == -1) { | if (numRows < capacity && !isBatchFull) { | // creating the unsafe for new entry - | UnsafeRow agg_result = new UnsafeRow(${groupingKeySchema.length}); - | org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder - | = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, - | ${numVarLenFields * 32}); | org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter | = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter( - | agg_holder, - | ${groupingKeySchema.length}); - | agg_holder.reset(); //TODO: investigate if reset or zeroout are actually needed + | ${groupingKeySchema.length}, ${numVarLenFields * 32}); + | agg_rowWriter.reset(); //TODO: investigate if reset or zeroout are actually needed | agg_rowWriter.zeroOutNullBytes(); | ${createUnsafeRowForKey}; - | agg_result.setTotalSize(agg_holder.totalSize()); + | org.apache.spark.sql.catalyst.expressions.UnsafeRow agg_result + | = agg_rowWriter.getRow(); | Object kbase = agg_result.getBaseObject(); | long koff = agg_result.getBaseOffset(); | int klen = agg_result.getSizeInBytes(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 3b5655ba0582e..2d699e8a9d088 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -165,9 +165,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private ByteOrder nativeOrder = null; private byte[][] buffers = null; - private UnsafeRow unsafeRow = new UnsafeRow($numFields); - private BufferHolder bufferHolder = new BufferHolder(unsafeRow); - private UnsafeRowWriter rowWriter = new UnsafeRowWriter(bufferHolder, $numFields); + private UnsafeRowWriter rowWriter = new UnsafeRowWriter($numFields); private MutableUnsafeRow mutableRow = null; private int currentRow = 0; @@ -212,11 +210,10 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera public InternalRow next() { currentRow += 1; - bufferHolder.reset(); + rowWriter.reset(); rowWriter.zeroOutNullBytes(); ${extractorCalls} - unsafeRow.setTotalSize(bufferHolder.totalSize()); - return unsafeRow; + return rowWriter.getRow(); } ${ctx.declareAddedFunctions()} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index 9647f09867643..e93908da43535 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -26,7 +26,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ @@ -130,16 +130,13 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { val emptyUnsafeRow = new UnsafeRow(0) reader.map(_ => emptyUnsafeRow) } else { - val unsafeRow = new UnsafeRow(1) - val bufferHolder = new BufferHolder(unsafeRow) - val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) + val unsafeRowWriter = new UnsafeRowWriter(1) reader.map { line => // Writes to an UnsafeRow directly - bufferHolder.reset() + unsafeRowWriter.reset() unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) - unsafeRow.setTotalSize(bufferHolder.totalSize()) - unsafeRow + unsafeRowWriter.getRow() } } } From 28ea4e3142b88eb396aa8dd5daf7b02b556204ba Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 2 Apr 2018 14:35:07 -0700 Subject: [PATCH 0549/2461] [SPARK-23834][TEST] Wait for connection before disconnect in LauncherServer test. It was possible that the disconnect() was called on the handle before the server had received the handshake messages, so no connection was yet attached to the handle. The fix waits until we're sure the handle has been mapped to a client connection. Author: Marcelo Vanzin Closes #20950 from vanzin/SPARK-23834. --- .../org/apache/spark/launcher/LauncherServerSuite.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 5413d3a416545..f8dc0ec7a0bf6 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -196,6 +196,14 @@ public void testAppHandleDisconnect() throws Exception { Socket s = new Socket(InetAddress.getLoopbackAddress(), server.getPort()); client = new TestClient(s); client.send(new Hello(secret, "1.4.0")); + client.send(new SetAppId("someId")); + + // Wait until we know the server has received the messages and matched the handle to the + // connection before disconnecting. + eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> { + assertEquals("someId", handle.getAppId()); + }); + handle.disconnect(); waitForError(client, secret); } finally { From a1351828d376a01e5ee0959cf608f767d756dd86 Mon Sep 17 00:00:00 2001 From: Yogesh Garg Date: Mon, 2 Apr 2018 16:41:26 -0700 Subject: [PATCH 0550/2461] [SPARK-23690][ML] Add handleinvalid to VectorAssembler ## What changes were proposed in this pull request? Introduce `handleInvalid` parameter in `VectorAssembler` that can take in `"keep", "skip", "error"` options. "error" throws an error on seeing a row containing a `null`, "skip" filters out all such rows, and "keep" adds relevant number of NaN. "keep" figures out an example to find out what this number of NaN s should be added and throws an error when no such number could be found. ## How was this patch tested? Unit tests are added to check the behavior of `assemble` on specific rows and the transformer is called on `DataFrame`s of different configurations to test different corner cases. Author: Yogesh Garg Author: Bago Amirbekian Author: Yogesh Garg <1059168+yogeshg@users.noreply.github.com> Closes #20829 from yogeshg/rformula_handleinvalid. --- .../spark/ml/feature/StringIndexer.scala | 2 +- .../spark/ml/feature/VectorAssembler.scala | 198 ++++++++++++++---- .../ml/feature/VectorAssemblerSuite.scala | 131 +++++++++++- 3 files changed, 284 insertions(+), 47 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 1cdcdfcaeab78..67cdb097217a2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -234,7 +234,7 @@ class StringIndexerModel ( val metadata = NominalAttribute.defaultAttr .withName($(outputCol)).withValues(filteredLabels).toMetadata() // If we are skipping invalid records, filter them out. - val (filteredDataset, keepInvalid) = getHandleInvalid match { + val (filteredDataset, keepInvalid) = $(handleInvalid) match { case StringIndexer.SKIP_INVALID => val filterer = udf { label: String => labelToIndex.contains(label) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index b373ae921ed38..6bf4aa38b1fcb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -17,14 +17,17 @@ package org.apache.spark.ml.feature -import scala.collection.mutable.ArrayBuilder +import java.util.NoSuchElementException + +import scala.collection.mutable +import scala.language.existentials import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute} import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset, Row} @@ -33,10 +36,14 @@ import org.apache.spark.sql.types._ /** * A feature transformer that merges multiple columns into a vector column. + * + * This requires one pass over the entire dataset. In case we need to infer column lengths from the + * data we require an additional call to the 'first' Dataset method, see 'handleInvalid' parameter. */ @Since("1.4.0") class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable { + extends Transformer with HasInputCols with HasOutputCol with HasHandleInvalid + with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("vecAssembler")) @@ -49,32 +56,63 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + @Since("2.4.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + + /** + * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with + * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the + * output). Column lengths are taken from the size of ML Attribute Group, which can be set using + * `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred + * from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'. + * Default: "error" + * @group param + */ + @Since("2.4.0") + override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", + """Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with + |invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the + |output). Column lengths are taken from the size of ML Attribute Group, which can be set using + |`VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred + |from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'. + |""".stripMargin.replaceAll("\n", " "), + ParamValidators.inArray(VectorAssembler.supportedHandleInvalids)) + + setDefault(handleInvalid, VectorAssembler.ERROR_INVALID) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) // Schema transformation. val schema = dataset.schema - lazy val first = dataset.toDF.first() - val attrs = $(inputCols).flatMap { c => + + val vectorCols = $(inputCols).filter { c => + schema(c).dataType match { + case _: VectorUDT => true + case _ => false + } + } + val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid)) + + val featureAttributesMap = $(inputCols).map { c => val field = schema(c) - val index = schema.fieldIndex(c) field.dataType match { case DoubleType => - val attr = Attribute.fromStructField(field) - // If the input column doesn't have ML attribute, assume numeric. - if (attr == UnresolvedAttribute) { - Some(NumericAttribute.defaultAttr.withName(c)) - } else { - Some(attr.withName(c)) + val attribute = Attribute.fromStructField(field) + attribute match { + case UnresolvedAttribute => + Seq(NumericAttribute.defaultAttr.withName(c)) + case _ => + Seq(attribute.withName(c)) } case _: NumericType | BooleanType => // If the input column type is a compatible scalar type, assume numeric. - Some(NumericAttribute.defaultAttr.withName(c)) + Seq(NumericAttribute.defaultAttr.withName(c)) case _: VectorUDT => - val group = AttributeGroup.fromStructField(field) - if (group.attributes.isDefined) { - // If attributes are defined, copy them with updated names. - group.attributes.get.zipWithIndex.map { case (attr, i) => + val attributeGroup = AttributeGroup.fromStructField(field) + if (attributeGroup.attributes.isDefined) { + attributeGroup.attributes.get.zipWithIndex.toSeq.map { case (attr, i) => if (attr.name.isDefined) { // TODO: Define a rigorous naming scheme. attr.withName(c + "_" + attr.name.get) @@ -85,18 +123,25 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) } else { // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes // from metadata, check the first row. - val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size) - Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i)) + (0 until vectorColsLengths(c)).map { i => + NumericAttribute.defaultAttr.withName(c + "_" + i) + } } case otherType => throw new SparkException(s"VectorAssembler does not support the $otherType type") } } - val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() - + val featureAttributes = featureAttributesMap.flatten[Attribute].toArray + val lengths = featureAttributesMap.map(a => a.length).toArray + val metadata = new AttributeGroup($(outputCol), featureAttributes).toMetadata() + val (filteredDataset, keepInvalid) = $(handleInvalid) match { + case VectorAssembler.SKIP_INVALID => (dataset.na.drop($(inputCols)), false) + case VectorAssembler.KEEP_INVALID => (dataset, true) + case VectorAssembler.ERROR_INVALID => (dataset, false) + } // Data transformation. val assembleFunc = udf { r: Row => - VectorAssembler.assemble(r.toSeq: _*) + VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*) }.asNondeterministic() val args = $(inputCols).map { c => schema(c).dataType match { @@ -106,7 +151,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) } } - dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata)) + filteredDataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata)) } @Since("1.4.0") @@ -136,34 +181,117 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) @Since("1.6.0") object VectorAssembler extends DefaultParamsReadable[VectorAssembler] { + private[feature] val SKIP_INVALID: String = "skip" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val KEEP_INVALID: String = "keep" + private[feature] val supportedHandleInvalids: Array[String] = + Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) + + /** + * Infers lengths of vector columns from the first row of the dataset + * @param dataset the dataset + * @param columns name of vector columns whose lengths need to be inferred + * @return map of column names to lengths + */ + private[feature] def getVectorLengthsFromFirstRow( + dataset: Dataset[_], + columns: Seq[String]): Map[String, Int] = { + try { + val first_row = dataset.toDF().select(columns.map(col): _*).first() + columns.zip(first_row.toSeq).map { + case (c, x) => c -> x.asInstanceOf[Vector].size + }.toMap + } catch { + case e: NullPointerException => throw new NullPointerException( + s"""Encountered null value while inferring lengths from the first row. Consider using + |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """ + .stripMargin.replaceAll("\n", " ") + e.toString) + case e: NoSuchElementException => throw new NoSuchElementException( + s"""Encountered empty dataframe while inferring lengths from the first row. Consider using + |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """ + .stripMargin.replaceAll("\n", " ") + e.toString) + } + } + + private[feature] def getLengths( + dataset: Dataset[_], + columns: Seq[String], + handleInvalid: String): Map[String, Int] = { + val groupSizes = columns.map { c => + c -> AttributeGroup.fromStructField(dataset.schema(c)).size + }.toMap + val missingColumns = groupSizes.filter(_._2 == -1).keys.toSeq + val firstSizes = (missingColumns.nonEmpty, handleInvalid) match { + case (true, VectorAssembler.ERROR_INVALID) => + getVectorLengthsFromFirstRow(dataset, missingColumns) + case (true, VectorAssembler.SKIP_INVALID) => + getVectorLengthsFromFirstRow(dataset.na.drop(missingColumns), missingColumns) + case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException( + s"""Can not infer column lengths with handleInvalid = "keep". Consider using VectorSizeHint + |to add metadata for columns: ${columns.mkString("[", ", ", "]")}.""" + .stripMargin.replaceAll("\n", " ")) + case (_, _) => Map.empty + } + groupSizes ++ firstSizes + } + + @Since("1.6.0") override def load(path: String): VectorAssembler = super.load(path) - private[feature] def assemble(vv: Any*): Vector = { - val indices = ArrayBuilder.make[Int] - val values = ArrayBuilder.make[Double] - var cur = 0 + /** + * Returns a function that has the required information to assemble each row. + * @param lengths an array of lengths of input columns, whose size should be equal to the number + * of cells in the row (vv) + * @param keepInvalid indicate whether to throw an error or not on seeing a null in the rows + * @return a udf that can be applied on each row + */ + private[feature] def assemble(lengths: Array[Int], keepInvalid: Boolean)(vv: Any*): Vector = { + val indices = mutable.ArrayBuilder.make[Int] + val values = mutable.ArrayBuilder.make[Double] + var featureIndex = 0 + + var inputColumnIndex = 0 vv.foreach { case v: Double => - if (v != 0.0) { - indices += cur + if (v.isNaN && !keepInvalid) { + throw new SparkException( + s"""Encountered NaN while assembling a row with handleInvalid = "error". Consider + |removing NaNs from dataset or using handleInvalid = "keep" or "skip".""" + .stripMargin) + } else if (v != 0.0) { + indices += featureIndex values += v } - cur += 1 + inputColumnIndex += 1 + featureIndex += 1 case vec: Vector => vec.foreachActive { case (i, v) => if (v != 0.0) { - indices += cur + i + indices += featureIndex + i values += v } } - cur += vec.size + inputColumnIndex += 1 + featureIndex += vec.size case null => - // TODO: output Double.NaN? - throw new SparkException("Values to assemble cannot be null.") + if (keepInvalid) { + val length: Int = lengths(inputColumnIndex) + Array.range(0, length).foreach { i => + indices += featureIndex + i + values += Double.NaN + } + inputColumnIndex += 1 + featureIndex += length + } else { + throw new SparkException( + s"""Encountered null while assembling a row with handleInvalid = "keep". Consider + |removing nulls from dataset or using handleInvalid = "keep" or "skip".""" + .stripMargin) + } case o => throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") } - Vectors.sparse(cur, indices.result(), values.result()).compressed + Vectors.sparse(featureIndex, indices.result(), values.result()).compressed } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index eca065f7e775d..91fb24a268b8c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.ml.feature import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions.{col, udf} class VectorAssemblerSuite @@ -31,30 +31,49 @@ class VectorAssemblerSuite import testImplicits._ + @transient var dfWithNullsAndNaNs: Dataset[_] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + val sv = Vectors.sparse(2, Array(1), Array(3.0)) + dfWithNullsAndNaNs = Seq[(Long, Long, java.lang.Double, Vector, String, Vector, Long, String)]( + (1, 2, 0.0, Vectors.dense(1.0, 2.0), "a", sv, 7L, null), + (2, 1, 0.0, null, "a", sv, 6L, null), + (3, 3, null, Vectors.dense(1.0, 2.0), "a", sv, 8L, null), + (4, 4, null, null, "a", sv, 9L, null), + (5, 5, java.lang.Double.NaN, Vectors.dense(1.0, 2.0), "a", sv, 7L, null), + (6, 6, java.lang.Double.NaN, null, "a", sv, 8L, null)) + .toDF("id1", "id2", "x", "y", "name", "z", "n", "nulls") + } + test("params") { ParamsSuite.checkParams(new VectorAssembler) } test("assemble") { import org.apache.spark.ml.feature.VectorAssembler.assemble - assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty)) - assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0))) + assert(assemble(Array(1), keepInvalid = true)(0.0) + === Vectors.sparse(1, Array.empty, Array.empty)) + assert(assemble(Array(1, 1), keepInvalid = true)(0.0, 1.0) + === Vectors.sparse(2, Array(1), Array(1.0))) val dv = Vectors.dense(2.0, 0.0) - assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0))) + assert(assemble(Array(1, 2, 1), keepInvalid = true)(0.0, dv, 1.0) === + Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0))) val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0)) - assert(assemble(0.0, dv, 1.0, sv) === + assert(assemble(Array(1, 2, 1, 2), keepInvalid = true)(0.0, dv, 1.0, sv) === Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0))) - for (v <- Seq(1, "a", null)) { - intercept[SparkException](assemble(v)) - intercept[SparkException](assemble(1.0, v)) + for (v <- Seq(1, "a")) { + intercept[SparkException](assemble(Array(1), keepInvalid = true)(v)) + intercept[SparkException](assemble(Array(1, 1), keepInvalid = true)(1.0, v)) } } test("assemble should compress vectors") { import org.apache.spark.ml.feature.VectorAssembler.assemble - val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0)) + val v1 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(0.0, 0.0, 0.0, Vectors.dense(4.0)) assert(v1.isInstanceOf[SparseVector]) - val v2 = assemble(1.0, 2.0, 3.0, Vectors.sparse(1, Array(0), Array(4.0))) + val sv = Vectors.sparse(1, Array(0), Array(4.0)) + val v2 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(1.0, 2.0, 3.0, sv) assert(v2.isInstanceOf[DenseVector]) } @@ -147,4 +166,94 @@ class VectorAssemblerSuite .filter(vectorUDF($"features") > 1) .count() == 1) } + + test("assemble should keep nulls when keepInvalid is true") { + import org.apache.spark.ml.feature.VectorAssembler.assemble + assert(assemble(Array(1, 1), keepInvalid = true)(1.0, null) === Vectors.dense(1.0, Double.NaN)) + assert(assemble(Array(1, 2), keepInvalid = true)(1.0, null) + === Vectors.dense(1.0, Double.NaN, Double.NaN)) + assert(assemble(Array(1), keepInvalid = true)(null) === Vectors.dense(Double.NaN)) + assert(assemble(Array(2), keepInvalid = true)(null) === Vectors.dense(Double.NaN, Double.NaN)) + } + + test("assemble should throw errors when keepInvalid is false") { + import org.apache.spark.ml.feature.VectorAssembler.assemble + intercept[SparkException](assemble(Array(1, 1), keepInvalid = false)(1.0, null)) + intercept[SparkException](assemble(Array(1, 2), keepInvalid = false)(1.0, null)) + intercept[SparkException](assemble(Array(1), keepInvalid = false)(null)) + intercept[SparkException](assemble(Array(2), keepInvalid = false)(null)) + } + + test("get lengths functions") { + import org.apache.spark.ml.feature.VectorAssembler._ + val df = dfWithNullsAndNaNs + assert(getVectorLengthsFromFirstRow(df, Seq("y")) === Map("y" -> 2)) + assert(intercept[NullPointerException](getVectorLengthsFromFirstRow(df.sort("id2"), Seq("y"))) + .getMessage.contains("VectorSizeHint")) + assert(intercept[NoSuchElementException](getVectorLengthsFromFirstRow(df.filter("id1 > 6"), + Seq("y"))).getMessage.contains("VectorSizeHint")) + + assert(getLengths(df.sort("id2"), Seq("y"), SKIP_INVALID).exists(_ == "y" -> 2)) + assert(intercept[NullPointerException](getLengths(df.sort("id2"), Seq("y"), ERROR_INVALID)) + .getMessage.contains("VectorSizeHint")) + assert(intercept[RuntimeException](getLengths(df.sort("id2"), Seq("y"), KEEP_INVALID)) + .getMessage.contains("VectorSizeHint")) + } + + test("Handle Invalid should behave properly") { + val assembler = new VectorAssembler() + .setInputCols(Array("x", "y", "z", "n")) + .setOutputCol("features") + + def runWithMetadata(mode: String, additional_filter: String = "true"): Dataset[_] = { + val attributeY = new AttributeGroup("y", 2) + val attributeZ = new AttributeGroup( + "z", + Array[Attribute]( + NumericAttribute.defaultAttr.withName("foo"), + NumericAttribute.defaultAttr.withName("bar"))) + val dfWithMetadata = dfWithNullsAndNaNs.withColumn("y", col("y"), attributeY.toMetadata()) + .withColumn("z", col("z"), attributeZ.toMetadata()).filter(additional_filter) + val output = assembler.setHandleInvalid(mode).transform(dfWithMetadata) + output.collect() + output + } + + def runWithFirstRow(mode: String): Dataset[_] = { + val output = assembler.setHandleInvalid(mode).transform(dfWithNullsAndNaNs) + output.collect() + output + } + + def runWithAllNullVectors(mode: String): Dataset[_] = { + val output = assembler.setHandleInvalid(mode) + .transform(dfWithNullsAndNaNs.filter("0 == id1 % 2")) + output.collect() + output + } + + // behavior when vector size hint is given + assert(runWithMetadata("keep").count() == 6, "should keep all rows") + assert(runWithMetadata("skip").count() == 1, "should skip rows with nulls") + // should throw error with nulls + intercept[SparkException](runWithMetadata("error")) + // should throw error with NaNs + intercept[SparkException](runWithMetadata("error", additional_filter = "id1 > 4")) + + // behavior when first row has information + assert(intercept[RuntimeException](runWithFirstRow("keep").count()) + .getMessage.contains("VectorSizeHint"), "should suggest to use metadata") + assert(runWithFirstRow("skip").count() == 1, "should infer size and skip rows with nulls") + intercept[SparkException](runWithFirstRow("error")) + + // behavior when vector column is all null + assert(intercept[RuntimeException](runWithAllNullVectors("skip")) + .getMessage.contains("VectorSizeHint"), "should suggest to use metadata") + assert(intercept[NullPointerException](runWithAllNullVectors("error")) + .getMessage.contains("VectorSizeHint"), "should suggest to use metadata") + + // behavior when scalar column is all null + assert(runWithMetadata("keep", additional_filter = "id1 > 2").count() == 4) + } + } From 441d0d0766e9a6ac4c6ff79680394999ff7191fd Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 3 Apr 2018 09:31:47 +0800 Subject: [PATCH 0551/2461] [SPARK-19964][CORE] Avoid reading from remote repos in SparkSubmitSuite. These tests can fail with a timeout if the remote repos are not responding, or slow. The tests don't need anything from those repos, so use an empty ivy config file to avoid setting up the defaults. The tests are passing reliably for me locally now, and failing more often than not today without this change since http://dl.bintray.com/spark-packages/maven doesn't seem to be loading from my machine. Author: Marcelo Vanzin Closes #20916 from vanzin/SPARK-19964. --- .../org/apache/spark/deploy/DependencyUtils.scala | 13 ++++++++----- .../scala/org/apache/spark/deploy/SparkSubmit.scala | 3 ++- .../apache/spark/deploy/SparkSubmitArguments.scala | 2 ++ .../apache/spark/deploy/worker/DriverWrapper.scala | 13 +++++++++---- .../org/apache/spark/deploy/SparkSubmitSuite.scala | 9 ++++++--- 5 files changed, 27 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala index ab319c860ee69..fac834a70b893 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala @@ -33,7 +33,8 @@ private[deploy] object DependencyUtils { packagesExclusions: String, packages: String, repositories: String, - ivyRepoPath: String): String = { + ivyRepoPath: String, + ivySettingsPath: Option[String]): String = { val exclusions: Seq[String] = if (!StringUtils.isBlank(packagesExclusions)) { packagesExclusions.split(",") @@ -41,10 +42,12 @@ private[deploy] object DependencyUtils { Nil } // Create the IvySettings, either load from file or build defaults - val ivySettings = sys.props.get("spark.jars.ivySettings").map { ivySettingsFile => - SparkSubmitUtils.loadIvySettings(ivySettingsFile, Option(repositories), Option(ivyRepoPath)) - }.getOrElse { - SparkSubmitUtils.buildIvySettings(Option(repositories), Option(ivyRepoPath)) + val ivySettings = ivySettingsPath match { + case Some(path) => + SparkSubmitUtils.loadIvySettings(path, Option(repositories), Option(ivyRepoPath)) + + case None => + SparkSubmitUtils.buildIvySettings(Option(repositories), Option(ivyRepoPath)) } SparkSubmitUtils.resolveMavenCoordinates(packages, ivySettings, exclusions = exclusions) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 3965f17f4b56e..eddbedeb1024d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -359,7 +359,8 @@ object SparkSubmit extends CommandLineUtils with Logging { // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files // too for packages that include Python code val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies( - args.packagesExclusions, args.packages, args.repositories, args.ivyRepoPath) + args.packagesExclusions, args.packages, args.repositories, args.ivyRepoPath, + args.ivySettingsPath) if (!StringUtils.isBlank(resolvedMavenCoordinates)) { args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index e7796d4ddbe34..8e7070593687b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -63,6 +63,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var packages: String = null var repositories: String = null var ivyRepoPath: String = null + var ivySettingsPath: Option[String] = None var packagesExclusions: String = null var verbose: Boolean = false var isPython: Boolean = false @@ -184,6 +185,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull files = Option(files).orElse(sparkProperties.get("spark.files")).orNull ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull + ivySettingsPath = sparkProperties.get("spark.jars.ivySettings") packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull packagesExclusions = Option(packagesExclusions) .orElse(sparkProperties.get("spark.jars.excludes")).orNull diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index b19c9904d5982..3f71237164a15 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -79,12 +79,17 @@ object DriverWrapper extends Logging { val secMgr = new SecurityManager(sparkConf) val hadoopConf = SparkHadoopUtil.newConfiguration(sparkConf) - val Seq(packagesExclusions, packages, repositories, ivyRepoPath) = - Seq("spark.jars.excludes", "spark.jars.packages", "spark.jars.repositories", "spark.jars.ivy") - .map(sys.props.get(_).orNull) + val Seq(packagesExclusions, packages, repositories, ivyRepoPath, ivySettingsPath) = + Seq( + "spark.jars.excludes", + "spark.jars.packages", + "spark.jars.repositories", + "spark.jars.ivy", + "spark.jars.ivySettings" + ).map(sys.props.get(_).orNull) val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies(packagesExclusions, - packages, repositories, ivyRepoPath) + packages, repositories, ivyRepoPath, Option(ivySettingsPath)) val jars = { val jarsProp = sys.props.get("spark.jars").orNull if (!StringUtils.isBlank(resolvedMavenCoordinates)) { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index d86ef907b4492..0d7c342a5eacd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -106,6 +106,9 @@ class SparkSubmitSuite // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x implicit val defaultSignaler: Signaler = ThreadSignaler + private val emptyIvySettings = File.createTempFile("ivy", ".xml") + FileUtils.write(emptyIvySettings, "", StandardCharsets.UTF_8) + override def beforeEach() { super.beforeEach() } @@ -520,6 +523,7 @@ class SparkSubmitSuite "--repositories", repo, "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", + "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}", unusedJar.toString, "my.great.lib.MyLib", "my.great.dep.MyLib") runSparkSubmit(args) @@ -530,7 +534,6 @@ class SparkSubmitSuite val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val main = MavenCoordinate("my.great.lib", "mylib", "0.1") val dep = MavenCoordinate("my.great.dep", "mylib", "0.1") - // Test using "spark.jars.packages" and "spark.jars.repositories" configurations. IvyTestUtils.withRepository(main, Some(dep.toString), None) { repo => val args = Seq( "--class", JarCreationTest.getClass.getName.stripSuffix("$"), @@ -540,6 +543,7 @@ class SparkSubmitSuite "--conf", s"spark.jars.repositories=$repo", "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", + "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}", unusedJar.toString, "my.great.lib.MyLib", "my.great.dep.MyLib") runSparkSubmit(args) @@ -550,7 +554,6 @@ class SparkSubmitSuite // See https://gist.github.com/shivaram/3a2fecce60768a603dac for a error log ignore("correctly builds R packages included in a jar with --packages") { assume(RUtils.isRInstalled, "R isn't installed on this machine.") - // Check if the SparkR package is installed assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.") val main = MavenCoordinate("my.great.lib", "mylib", "0.1") val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) @@ -563,6 +566,7 @@ class SparkSubmitSuite "--master", "local-cluster[2,1,1024]", "--packages", main.toString, "--repositories", repo, + "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}", "--verbose", "--conf", "spark.ui.enabled=false", rScriptDir) @@ -573,7 +577,6 @@ class SparkSubmitSuite test("include an external JAR in SparkR") { assume(RUtils.isRInstalled, "R isn't installed on this machine.") val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - // Check if the SparkR package is installed assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.") val rScriptDir = Seq(sparkHome, "R", "pkg", "tests", "fulltests", "jarTest.R").mkString(File.separator) From 8020f66fc47140a1b5f843fb18c34ec80541d5ca Mon Sep 17 00:00:00 2001 From: lemonjing <932191671@qq.com> Date: Tue, 3 Apr 2018 09:36:44 +0800 Subject: [PATCH 0552/2461] [MINOR][DOC] Fix a few markdown typos ## What changes were proposed in this pull request? Easy fix in the markdown. ## How was this patch tested? jekyII build test manually. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: lemonjing <932191671@qq.com> Closes #20897 from Lemonjing/master. --- docs/ml-guide.md | 2 +- docs/mllib-feature-extraction.md | 4 ++-- docs/mllib-pmml-model-export.md | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 702bcf748fc74..aea07be34cb86 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -111,7 +111,7 @@ and the migration guide below will explain all changes between releases. * The class and trait hierarchy for logistic regression model summaries was changed to be cleaner and better accommodate the addition of the multi-class summary. This is a breaking change for user code that casts a `LogisticRegressionTrainingSummary` to a -` BinaryLogisticRegressionTrainingSummary`. Users should instead use the `model.binarySummary` +`BinaryLogisticRegressionTrainingSummary`. Users should instead use the `model.binarySummary` method. See [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139) for more detail (_note_ this is an `Experimental` API). This _does not_ affect the Python `summary` method, which will still work correctly for both multinomial and binary cases. diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 75aea70601875..8b89296b14cdd 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -278,8 +278,8 @@ for details on the API. multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) between the input vector, `v` and transforming vector, `scalingVec`, to yield a result vector. -Qu8T948*1# -Denoting the `scalingVec` as "`w`," this transformation may be written as: + +Denoting the `scalingVec` as "`w`", this transformation may be written as: `\[ \begin{pmatrix} v_1 \\ diff --git a/docs/mllib-pmml-model-export.md b/docs/mllib-pmml-model-export.md index d3530908706d0..f567565437927 100644 --- a/docs/mllib-pmml-model-export.md +++ b/docs/mllib-pmml-model-export.md @@ -7,7 +7,7 @@ displayTitle: PMML model export - RDD-based API * Table of contents {:toc} -## `spark.mllib` supported models +## spark.mllib supported models `spark.mllib` supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)). @@ -15,7 +15,7 @@ The table below outlines the `spark.mllib` models that can be exported to PMML a - + From 7cf9fab33457ccc9b2d548f15dd5700d5e8d08ef Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 3 Apr 2018 21:26:49 +0800 Subject: [PATCH 0553/2461] [MINOR][CORE] Show block manager id when remove RDD/Broadcast fails. ## What changes were proposed in this pull request? Address https://github.com/apache/spark/pull/20924#discussion_r177987175, show block manager id when remove RDD/Broadcast fails. ## How was this patch tested? N/A Author: Xingbo Jiang Closes #20960 from jiangxb1987/bmid. --- .../apache/spark/storage/BlockManagerMasterEndpoint.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 56b95c31eb4c3..8e8f7d197c9ef 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -164,7 +164,8 @@ class BlockManagerMasterEndpoint( val futures = blockManagerInfo.values.map { bm => bm.slaveEndpoint.ask[Int](removeMsg).recover { case e: IOException => - logWarning(s"Error trying to remove RDD $rddId", e) + logWarning(s"Error trying to remove RDD $rddId from block manager ${bm.blockManagerId}", + e) 0 // zero blocks were removed } }.toSeq @@ -195,7 +196,8 @@ class BlockManagerMasterEndpoint( val futures = requiredBlockManagers.map { bm => bm.slaveEndpoint.ask[Int](removeMsg).recover { case e: IOException => - logWarning(s"Error trying to remove broadcast $broadcastId", e) + logWarning(s"Error trying to remove broadcast $broadcastId from block manager " + + s"${bm.blockManagerId}", e) 0 // zero blocks were removed } }.toSeq From 66a3a5a2dc83e03dedcee9839415c1ddc1fb8125 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 3 Apr 2018 11:05:29 -0700 Subject: [PATCH 0554/2461] [SPARK-23099][SS] Migrate foreach sink to DataSourceV2 ## What changes were proposed in this pull request? Migrate foreach sink to DataSourceV2. Since the previous attempt at this PR #20552, we've changed and strictly defined the lifecycle of writer components. This means we no longer need the complicated lifecycle shim from that PR; it just naturally works. ## How was this patch tested? existing tests Author: Jose Torres Closes #20951 from jose-torres/foreach. --- .../sql/execution/streaming/ForeachSink.scala | 68 ----------- .../sources/ForeachWriterProvider.scala | 111 ++++++++++++++++++ .../sql/streaming/DataStreamWriter.scala | 4 +- .../ForeachWriterSuite.scala} | 83 ++++++------- .../sql/streaming/StreamingQuerySuite.scala | 1 + 5 files changed, 156 insertions(+), 111 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala rename sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/{ForeachSinkSuite.scala => sources/ForeachWriterSuite.scala} (77%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala deleted file mode 100644 index 2cc54107f8b83..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import org.apache.spark.TaskContext -import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter} -import org.apache.spark.sql.catalyst.encoders.encoderFor - -/** - * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by - * [[ForeachWriter]]. - * - * @param writer The [[ForeachWriter]] to process all data. - * @tparam T The expected type of the sink. - */ -class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable { - - override def addBatch(batchId: Long, data: DataFrame): Unit = { - // This logic should've been as simple as: - // ``` - // data.as[T].foreachPartition { iter => ... } - // ``` - // - // Unfortunately, doing that would just break the incremental planing. The reason is, - // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` will - // create a new plan. Because StreamExecution uses the existing plan to collect metrics and - // update watermark, we should never create a new plan. Otherwise, metrics and watermark are - // updated in the new plan, and StreamExecution cannot retrieval them. - // - // Hence, we need to manually convert internal rows to objects using encoder. - val encoder = encoderFor[T].resolveAndBind( - data.logicalPlan.output, - data.sparkSession.sessionState.analyzer) - data.queryExecution.toRdd.foreachPartition { iter => - if (writer.open(TaskContext.getPartitionId(), batchId)) { - try { - while (iter.hasNext) { - writer.process(encoder.fromRow(iter.next())) - } - } catch { - case e: Throwable => - writer.close(e) - throw e - } - writer.close(null) - } else { - writer.close(null) - } - } - } - - override def toString(): String = "ForeachSink" -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala new file mode 100644 index 0000000000000..df5d69d57e36f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType + +/** + * A [[org.apache.spark.sql.sources.v2.DataSourceV2]] for forwarding data into the specified + * [[ForeachWriter]]. + * + * @param writer The [[ForeachWriter]] to process all data. + * @tparam T The expected type of the sink. + */ +case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport { + override def createStreamWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceOptions): StreamWriter = { + new StreamWriter with SupportsWriteInternalRow { + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + + override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { + val encoder = encoderFor[T].resolveAndBind( + schema.toAttributes, + SparkSession.getActiveSession.get.sessionState.analyzer) + ForeachWriterFactory(writer, encoder) + } + + override def toString: String = "ForeachSink" + } + } +} + +case class ForeachWriterFactory[T: Encoder]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T]) + extends DataWriterFactory[InternalRow] { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): ForeachDataWriter[T] = { + new ForeachDataWriter(writer, encoder, partitionId, epochId) + } +} + +/** + * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]]. + * @param writer The [[ForeachWriter]] to process all data. + * @param encoder An encoder which can convert [[InternalRow]] to the required type [[T]] + * @param partitionId + * @param epochId + * @tparam T The type expected by the writer. + */ +class ForeachDataWriter[T : Encoder]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T], + partitionId: Int, + epochId: Long) + extends DataWriter[InternalRow] { + + // If open returns false, we should skip writing rows. + private val opened = writer.open(partitionId, epochId) + + override def write(record: InternalRow): Unit = { + if (!opened) return + + try { + writer.process(encoder.fromRow(record)) + } catch { + case t: Throwable => + writer.close(t) + throw t + } + } + + override def commit(): WriterCommitMessage = { + writer.close(null) + ForeachWriterCommitMessage + } + + override def abort(): Unit = {} +} + +/** + * An empty [[WriterCommitMessage]]. [[ForeachWriter]] implementations have no global coordination. + */ +case object ForeachWriterCommitMessage extends WriterCommitMessage diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 2fc903168cfa0..effc1471e8e12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger -import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} +import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2} import org.apache.spark.sql.sources.v2.StreamWriteSupport /** @@ -269,7 +269,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = new ForeachSink[T](foreachWriter)(ds.exprEnc) + val sink = new ForeachWriterProvider[T](foreachWriter)(ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala similarity index 77% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala index b249dd41a84a6..03bf71b3f4b78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.streaming +package org.apache.spark.sql.execution.streaming.sources import java.util.concurrent.ConcurrentLinkedQueue @@ -25,11 +25,12 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkException import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext -class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { +class ForeachWriterSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { import testImplicits._ @@ -47,9 +48,9 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf .start() def verifyOutput(expectedVersion: Int, expectedData: Seq[Int]): Unit = { - import ForeachSinkSuite._ + import ForeachWriterSuite._ - val events = ForeachSinkSuite.allEvents() + val events = ForeachWriterSuite.allEvents() assert(events.size === 2) // one seq of events for each of the 2 partitions // Verify both seq of events have an Open event as the first event @@ -64,13 +65,13 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf } // -- batch 0 --------------------------------------- - ForeachSinkSuite.clear() + ForeachWriterSuite.clear() input.addData(1, 2, 3, 4) query.processAllAvailable() verifyOutput(expectedVersion = 0, expectedData = 1 to 4) // -- batch 1 --------------------------------------- - ForeachSinkSuite.clear() + ForeachWriterSuite.clear() input.addData(5, 6, 7, 8) query.processAllAvailable() verifyOutput(expectedVersion = 1, expectedData = 5 to 8) @@ -95,27 +96,27 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf input.addData(1, 2, 3, 4) query.processAllAvailable() - var allEvents = ForeachSinkSuite.allEvents() + var allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) var expectedEvents = Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Process(value = 4), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 0), + ForeachWriterSuite.Process(value = 4), + ForeachWriterSuite.Close(None) ) assert(allEvents === Seq(expectedEvents)) - ForeachSinkSuite.clear() + ForeachWriterSuite.clear() // -- batch 1 --------------------------------------- input.addData(5, 6, 7, 8) query.processAllAvailable() - allEvents = ForeachSinkSuite.allEvents() + allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) expectedEvents = Seq( - ForeachSinkSuite.Open(partition = 0, version = 1), - ForeachSinkSuite.Process(value = 8), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 1), + ForeachWriterSuite.Process(value = 8), + ForeachWriterSuite.Close(None) ) assert(allEvents === Seq(expectedEvents)) @@ -131,7 +132,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf .foreach(new TestForeachWriter() { override def process(value: Int): Unit = { super.process(value) - throw new RuntimeException("error") + throw new RuntimeException("ForeachSinkSuite error") } }).start() input.addData(1, 2, 3, 4) @@ -141,18 +142,18 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf query.processAllAvailable() } assert(e.getCause.isInstanceOf[SparkException]) - assert(e.getCause.getCause.getMessage === "error") + assert(e.getCause.getCause.getCause.getMessage === "ForeachSinkSuite error") assert(query.isActive === false) - val allEvents = ForeachSinkSuite.allEvents() + val allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) - assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0)) - assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1)) + assert(allEvents(0)(0) === ForeachWriterSuite.Open(partition = 0, version = 0)) + assert(allEvents(0)(1) === ForeachWriterSuite.Process(value = 1)) // `close` should be called with the error - val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close] + val errorEvent = allEvents(0)(2).asInstanceOf[ForeachWriterSuite.Close] assert(errorEvent.error.get.isInstanceOf[RuntimeException]) - assert(errorEvent.error.get.getMessage === "error") + assert(errorEvent.error.get.getMessage === "ForeachSinkSuite error") } } @@ -177,12 +178,12 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf inputData.addData(10, 11, 12) query.processAllAvailable() - val allEvents = ForeachSinkSuite.allEvents() + val allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) val expectedEvents = Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Process(value = 3), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 0), + ForeachWriterSuite.Process(value = 3), + ForeachWriterSuite.Close(None) ) assert(allEvents === Seq(expectedEvents)) } finally { @@ -216,21 +217,21 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf query.processAllAvailable() // There should be 3 batches and only does the last batch contain a value. - val allEvents = ForeachSinkSuite.allEvents() + val allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 3) val expectedEvents = Seq( Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 0), + ForeachWriterSuite.Close(None) ), Seq( - ForeachSinkSuite.Open(partition = 0, version = 1), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 1), + ForeachWriterSuite.Close(None) ), Seq( - ForeachSinkSuite.Open(partition = 0, version = 2), - ForeachSinkSuite.Process(value = 3), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 2), + ForeachWriterSuite.Process(value = 3), + ForeachWriterSuite.Close(None) ) ) assert(allEvents === expectedEvents) @@ -258,7 +259,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf } /** A global object to collect events in the executor */ -object ForeachSinkSuite { +object ForeachWriterSuite { trait Event @@ -285,21 +286,21 @@ object ForeachSinkSuite { /** A [[ForeachWriter]] that writes collected events to ForeachSinkSuite */ class TestForeachWriter extends ForeachWriter[Int] { - ForeachSinkSuite.clear() + ForeachWriterSuite.clear() - private val events = mutable.ArrayBuffer[ForeachSinkSuite.Event]() + private val events = mutable.ArrayBuffer[ForeachWriterSuite.Event]() override def open(partitionId: Long, version: Long): Boolean = { - events += ForeachSinkSuite.Open(partition = partitionId, version = version) + events += ForeachWriterSuite.Open(partition = partitionId, version = version) true } override def process(value: Int): Unit = { - events += ForeachSinkSuite.Process(value) + events += ForeachWriterSuite.Process(value) } override def close(errorOrNull: Throwable): Unit = { - events += ForeachSinkSuite.Close(error = Option(errorOrNull)) - ForeachSinkSuite.addEvents(events) + events += ForeachWriterSuite.Close(error = Option(errorOrNull)) + ForeachWriterSuite.addEvents(events) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 08749b49997e0..20942ed93897c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.reader.DataReaderFactory From 1035aaa61704b2790192d3186fe37e678553d36d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 4 Apr 2018 01:36:58 +0200 Subject: [PATCH 0555/2461] [SPARK-23587][SQL] Add interpreted execution for MapObjects expression ## What changes were proposed in this pull request? Add interpreted execution for `MapObjects` expression. ## How was this patch tested? Added unit test. Author: Liang-Chi Hsieh Closes #20771 from viirya/SPARK-23587. --- .../expressions/objects/objects.scala | 110 ++++++++++++++++-- .../expressions/ObjectExpressionsSuite.scala | 67 ++++++++++- 2 files changed, 165 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index adf9ddf327c96..0e9d357c19c63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects import java.lang.reflect.Modifier +import scala.collection.JavaConverters._ import scala.collection.mutable.Builder import scala.language.existentials import scala.reflect.ClassTag @@ -501,12 +502,22 @@ case class LambdaVariable( value: String, isNull: String, dataType: DataType, - nullable: Boolean = true) extends LeafExpression - with Unevaluable with NonSQLExpression { + nullable: Boolean = true) extends LeafExpression with NonSQLExpression { + + // Interpreted execution of `LambdaVariable` always get the 0-index element from input row. + override def eval(input: InternalRow): Any = { + assert(input.numFields == 1, + "The input row of interpreted LambdaVariable should have only 1 field.") + input.get(0, dataType) + } override def genCode(ctx: CodegenContext): ExprCode = { ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false") } + + // This won't be called as `genCode` is overrided, just overriding it to make + // `LambdaVariable` non-abstract. + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev } /** @@ -599,8 +610,92 @@ case class MapObjects private( override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + // The data with UserDefinedType are actually stored with the data type of its sqlType. + // When we want to apply MapObjects on it, we have to use it. + lazy private val inputDataType = inputData.dataType match { + case u: UserDefinedType[_] => u.sqlType + case _ => inputData.dataType + } + + private def executeFuncOnCollection(inputCollection: Seq[_]): Iterator[_] = { + val row = new GenericInternalRow(1) + inputCollection.toIterator.map { element => + row.update(0, element) + lambdaFunction.eval(row) + } + } + + private lazy val convertToSeq: Any => Seq[_] = inputDataType match { + case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + _.asInstanceOf[Seq[_]] + case ObjectType(cls) if cls.isArray => + _.asInstanceOf[Array[_]].toSeq + case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + _.asInstanceOf[java.util.List[_]].asScala + case ObjectType(cls) if cls == classOf[Object] => + (inputCollection) => { + if (inputCollection.getClass.isArray) { + inputCollection.asInstanceOf[Array[_]].toSeq + } else { + inputCollection.asInstanceOf[Seq[_]] + } + } + case ArrayType(et, _) => + _.asInstanceOf[ArrayData].array + } + + private lazy val mapElements: Seq[_] => Any = customCollectionCls match { + case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + // Scala sequence + executeFuncOnCollection(_).toSeq + case Some(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) => + // Scala set + executeFuncOnCollection(_).toSet + case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + // Java list + if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] || + cls == classOf[java.util.AbstractSequentialList[_]]) { + // Specifying non concrete implementations of `java.util.List` + executeFuncOnCollection(_).toSeq.asJava + } else { + val constructors = cls.getConstructors() + val intParamConstructor = constructors.find { constructor => + constructor.getParameterCount == 1 && constructor.getParameterTypes()(0) == classOf[Int] + } + val noParamConstructor = constructors.find { constructor => + constructor.getParameterCount == 0 + } + + val constructor = intParamConstructor.map { intConstructor => + (len: Int) => intConstructor.newInstance(len.asInstanceOf[Object]) + }.getOrElse { + (_: Int) => noParamConstructor.get.newInstance() + } + + // Specifying concrete implementations of `java.util.List` + (inputs) => { + val results = executeFuncOnCollection(inputs) + val builder = constructor(inputs.length).asInstanceOf[java.util.List[Any]] + results.foreach(builder.add(_)) + builder + } + } + case None => + // array + x => new GenericArrayData(executeFuncOnCollection(x).toArray) + case Some(cls) => + throw new RuntimeException(s"class `${cls.getName}` is not supported by `MapObjects` as " + + "resulting collection.") + } + + override def eval(input: InternalRow): Any = { + val inputCollection = inputData.eval(input) + + if (inputCollection == null) { + return null + } + mapElements(convertToSeq(inputCollection)) + } override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse( @@ -647,13 +742,6 @@ case class MapObjects private( case _ => "" } - // The data with PythonUserDefinedType are actually stored with the data type of its sqlType. - // When we want to apply MapObjects on it, we have to use it. - val inputDataType = inputData.dataType match { - case p: PythonUserDefinedType => p.sqlType - case _ => inputData.dataType - } - // `MapObjects` generates a while loop to traverse the elements of the input collection. We // need to take care of Seq and List because they may have O(n) complexity for indexed accessing // like `list.get(1)`. Here we use Iterator to traverse Seq and List. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 1f6964dfef598..0edd27c8241e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.{SparkConf, SparkFunSuite} @@ -25,7 +26,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ @@ -135,6 +136,70 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-23587: MapObjects should support interpreted execution") { + def testMapObjects(collection: Any, collectionCls: Class[_], inputType: DataType): Unit = { + val function = (lambda: Expression) => Add(lambda, Literal(1)) + val elementType = IntegerType + val expected = Seq(2, 3, 4) + + val inputObject = BoundReference(0, inputType, nullable = true) + val optClass = Option(collectionCls) + val mapObj = MapObjects(function, inputObject, elementType, true, optClass) + val row = InternalRow.fromSeq(Seq(collection)) + val result = mapObj.eval(row) + + collectionCls match { + case null => + assert(result.asInstanceOf[ArrayData].array.toSeq == expected) + case l if classOf[java.util.List[_]].isAssignableFrom(l) => + assert(result.asInstanceOf[java.util.List[_]].asScala.toSeq == expected) + case s if classOf[Seq[_]].isAssignableFrom(s) => + assert(result.asInstanceOf[Seq[_]].toSeq == expected) + case s if classOf[scala.collection.Set[_]].isAssignableFrom(s) => + assert(result.asInstanceOf[scala.collection.Set[_]] == expected.toSet) + } + } + + val customCollectionClasses = Seq(classOf[Seq[Int]], classOf[scala.collection.Set[Int]], + classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]], + classOf[java.util.AbstractSequentialList[Int]], classOf[java.util.Vector[Int]], + classOf[java.util.Stack[Int]], null) + + val list = new java.util.ArrayList[Int]() + list.add(1) + list.add(2) + list.add(3) + val arrayData = new GenericArrayData(Array(1, 2, 3)) + val vector = new java.util.Vector[Int]() + vector.add(1) + vector.add(2) + vector.add(3) + val stack = new java.util.Stack[Int]() + stack.add(1) + stack.add(2) + stack.add(3) + + Seq( + (Seq(1, 2, 3), ObjectType(classOf[Seq[Int]])), + (Array(1, 2, 3), ObjectType(classOf[Array[Int]])), + (Seq(1, 2, 3), ObjectType(classOf[Object])), + (Array(1, 2, 3), ObjectType(classOf[Object])), + (list, ObjectType(classOf[java.util.List[Int]])), + (vector, ObjectType(classOf[java.util.Vector[Int]])), + (stack, ObjectType(classOf[java.util.Stack[Int]])), + (arrayData, ArrayType(IntegerType)) + ).foreach { case (collection, inputType) => + customCollectionClasses.foreach(testMapObjects(collection, _, inputType)) + + // Unsupported custom collection class + val errMsg = intercept[RuntimeException] { + testMapObjects(collection, classOf[scala.collection.Map[Int, Int]], inputType) + }.getMessage() + assert(errMsg.contains("`scala.collection.Map` is not supported by `MapObjects` " + + "as resulting collection.")) + } + } + test("SPARK-23592: DecodeUsingSerializer should support interpreted execution") { val cls = classOf[java.lang.Integer] val inputObject = BoundReference(0, ObjectType(classOf[Array[Byte]]), nullable = true) From 359375eff74630c9f0ea5a90ab7d45bf1b281ed0 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 3 Apr 2018 17:09:12 -0700 Subject: [PATCH 0556/2461] [SPARK-23809][SQL] Active SparkSession should be set by getOrCreate ## What changes were proposed in this pull request? Currently, the active spark session is set inconsistently (e.g., in createDataFrame, prior to query execution). Many places in spark also incorrectly query active session when they should be calling activeSession.getOrElse(defaultSession) and so might get None even if a Spark session exists. The semantics here can be cleaned up if we also set the active session when the default session is set. Related: https://github.com/apache/spark/pull/20926/files ## How was this patch tested? Unit test, existing test. Note that if https://github.com/apache/spark/pull/20926 merges first we should also update the tests there. Author: Eric Liang Closes #20927 from ericl/active-session-cleanup. --- .../org/apache/spark/sql/SparkSession.scala | 14 +++++++++++++- .../spark/sql/SparkSessionBuilderSuite.scala | 18 ++++++++++++++++++ .../apache/spark/sql/test/TestSQLContext.scala | 1 + .../apache/spark/sql/hive/test/TestHive.scala | 3 +++ 4 files changed, 35 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 734573ba31f71..b107492fbb330 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -951,7 +951,8 @@ object SparkSession { session = new SparkSession(sparkContext, None, None, extensions) options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) } - defaultSession.set(session) + setDefaultSession(session) + setActiveSession(session) // Register a successfully instantiated context to the singleton. This should be at the // end of the class definition so that the singleton is updated only if there is no @@ -1027,6 +1028,17 @@ object SparkSession { */ def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + /** + * Returns the currently active SparkSession, otherwise the default one. If there is no default + * SparkSession, throws an exception. + * + * @since 2.4.0 + */ + def active: SparkSession = { + getActiveSession.getOrElse(getDefaultSession.getOrElse( + throw new IllegalStateException("No active or default Spark session found"))) + } + //////////////////////////////////////////////////////////////////////////////////////// // Private methods from now on //////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala index c0301f2ce2d66..44bf8624a6bcd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala @@ -50,6 +50,24 @@ class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach { assert(SparkSession.builder().getOrCreate() == session) } + test("sets default and active session") { + assert(SparkSession.getDefaultSession == None) + assert(SparkSession.getActiveSession == None) + val session = SparkSession.builder().master("local").getOrCreate() + assert(SparkSession.getDefaultSession == Some(session)) + assert(SparkSession.getActiveSession == Some(session)) + } + + test("get active or default session") { + val session = SparkSession.builder().master("local").getOrCreate() + assert(SparkSession.active == session) + SparkSession.clearActiveSession() + assert(SparkSession.active == session) + SparkSession.clearDefaultSession() + intercept[IllegalStateException](SparkSession.active) + session.stop() + } + test("config options are propagated to existing SparkSession") { val session1 = SparkSession.builder().master("local").config("spark-config1", "a").getOrCreate() assert(session1.conf.get("spark-config1") == "a") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 3038b822beb4a..17603deacdcdd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -35,6 +35,7 @@ private[spark] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) } SparkSession.setDefaultSession(this) + SparkSession.setActiveSession(this) @transient override lazy val sessionState: SessionState = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 814038d4ef7af..a7006a16d7b73 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -179,6 +179,9 @@ private[hive] class TestHiveSparkSession( loadTestTables) } + SparkSession.setDefaultSession(this) + SparkSession.setActiveSession(this) + { // set the metastore temporary configuration val metastoreTempConf = HiveUtils.newTemporaryConfiguration(useInMemoryDerby = false) ++ Map( ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true", From 5cfd5fabcdbd77a806b98a6dd59b02772d2f6dee Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Tue, 3 Apr 2018 17:25:54 -0700 Subject: [PATCH 0557/2461] [SPARK-23802][SQL] PropagateEmptyRelation can leave query plan in unresolved state ## What changes were proposed in this pull request? Add cast to nulls introduced by PropagateEmptyRelation so in cases they're part of coalesce they will not break its type checking rules ## How was this patch tested? Added unit test Author: Robert Kruszewski Closes #20914 from robert3005/rk/propagate-empty-fix. --- .../optimizer/PropagateEmptyRelation.scala | 8 ++++-- .../PropagateEmptyRelationSuite.scala | 26 ++++++++++++++----- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index a6e5aa6daca65..c3fdb924243df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.analysis.CastSupport import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf /** * Collapse plans consisting empty local relations generated by [[PruneFilters]]. @@ -32,7 +34,7 @@ import org.apache.spark.sql.catalyst.rules._ * - Aggregate with all empty children and at least one grouping expression. * - Generate(Explode) with all empty children. Others like Hive UDTF may return results. */ -object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper { +object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper with CastSupport { private def isEmptyLocalRelation(plan: LogicalPlan): Boolean = plan match { case p: LocalRelation => p.data.isEmpty case _ => false @@ -43,7 +45,9 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper { // Construct a project list from plan's output, while the value is always NULL. private def nullValueProjectList(plan: LogicalPlan): Seq[NamedExpression] = - plan.output.map{ a => Alias(Literal(null), a.name)(a.exprId) } + plan.output.map{ a => Alias(cast(Literal(null), a.dataType), a.name)(a.exprId) } + + override def conf: SQLConf = SQLConf.get def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p: Union if p.children.forall(isEmptyLocalRelation) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index 3964508e3a55e..f1ce7543ffdc1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{IntegerType, StructType} class PropagateEmptyRelationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -37,7 +37,8 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceIntersectWithSemiJoin, PushDownPredicate, PruneFilters, - PropagateEmptyRelation) :: Nil + PropagateEmptyRelation, + CollapseProject) :: Nil } object OptimizeWithoutPropagateEmptyRelation extends RuleExecutor[LogicalPlan] { @@ -48,7 +49,8 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, PushDownPredicate, - PruneFilters) :: Nil + PruneFilters, + CollapseProject) :: Nil } val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1))) @@ -79,9 +81,11 @@ class PropagateEmptyRelationSuite extends PlanTest { (true, false, Inner, Some(LocalRelation('a.int, 'b.int))), (true, false, Cross, Some(LocalRelation('a.int, 'b.int))), - (true, false, LeftOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)), + (true, false, LeftOuter, + Some(Project(Seq('a, Literal(null).cast(IntegerType).as('b)), testRelation1).analyze)), (true, false, RightOuter, Some(LocalRelation('a.int, 'b.int))), - (true, false, FullOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)), + (true, false, FullOuter, + Some(Project(Seq('a, Literal(null).cast(IntegerType).as('b)), testRelation1).analyze)), (true, false, LeftAnti, Some(testRelation1)), (true, false, LeftSemi, Some(LocalRelation('a.int))), @@ -89,8 +93,9 @@ class PropagateEmptyRelationSuite extends PlanTest { (false, true, Cross, Some(LocalRelation('a.int, 'b.int))), (false, true, LeftOuter, Some(LocalRelation('a.int, 'b.int))), (false, true, RightOuter, - Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)), - (false, true, FullOuter, Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)), + Some(Project(Seq(Literal(null).cast(IntegerType).as('a), 'b), testRelation2).analyze)), + (false, true, FullOuter, + Some(Project(Seq(Literal(null).cast(IntegerType).as('a), 'b), testRelation2).analyze)), (false, true, LeftAnti, Some(LocalRelation('a.int))), (false, true, LeftSemi, Some(LocalRelation('a.int))), @@ -209,4 +214,11 @@ class PropagateEmptyRelationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("propagate empty relation keeps the plan resolved") { + val query = testRelation1.join( + LocalRelation('a.int, 'b.int), UsingJoin(FullOuter, "a" :: Nil), None) + val optimized = Optimize.execute(query.analyze) + assert(optimized.resolved) + } } From 16ef6baa36ac11c72cfeafaa2363e6b69f0ba573 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 4 Apr 2018 14:31:03 +0800 Subject: [PATCH 0558/2461] [SPARK-23826][TEST] TestHiveSparkSession should set default session ## What changes were proposed in this pull request? In TestHive, the base spark session does this in getOrCreate(), we emulate that behavior for tests. ## How was this patch tested? N/A Author: gatorsmile Closes #20969 from gatorsmile/setDefault. --- .../main/scala/org/apache/spark/sql/hive/test/TestHive.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index a7006a16d7b73..965aea2b61456 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -159,10 +159,6 @@ private[hive] class TestHiveSparkSession( private val loadTestTables: Boolean) extends SparkSession(sc) with Logging { self => - // TODO(SPARK-23826): TestHiveSparkSession should set default session the same way as - // TestSparkSession, but doing this the same way breaks many tests in the package. We need - // to investigate and find a different strategy. - def this(sc: SparkContext, loadTestTables: Boolean) { this( sc, From 5197562afe8534b29f5a0d72683c2859f796275d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 4 Apr 2018 14:39:19 +0800 Subject: [PATCH 0559/2461] [SPARK-21351][SQL] Update nullability based on children's output ## What changes were proposed in this pull request? This pr added a new optimizer rule `UpdateNullabilityInAttributeReferences ` to update the nullability that `Filter` changes when having `IsNotNull`. In the master, optimized plans do not respect the nullability when `Filter` has `IsNotNull`. This wrongly generates unnecessary code. For example: ``` scala> val df = Seq((Some(1), Some(2))).toDF("a", "b") scala> val bIsNotNull = df.where($"b" =!= 2).select($"b") scala> val targetQuery = bIsNotNull.distinct scala> val targetQuery.queryExecution.optimizedPlan.output(0).nullable res5: Boolean = true scala> targetQuery.debugCodegen Found 2 WholeStageCodegen subtrees. == Subtree 1 / 2 == *HashAggregate(keys=[b#19], functions=[], output=[b#19]) +- Exchange hashpartitioning(b#19, 200) +- *HashAggregate(keys=[b#19], functions=[], output=[b#19]) +- *Project [_2#16 AS b#19] +- *Filter isnotnull(_2#16) +- LocalTableScan [_1#15, _2#16] Generated code: ... /* 124 */ protected void processNext() throws java.io.IOException { ... /* 132 */ // output the result /* 133 */ /* 134 */ while (agg_mapIter.next()) { /* 135 */ wholestagecodegen_numOutputRows.add(1); /* 136 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey(); /* 137 */ UnsafeRow agg_aggBuffer = (UnsafeRow) agg_mapIter.getValue(); /* 138 */ /* 139 */ boolean agg_isNull4 = agg_aggKey.isNullAt(0); /* 140 */ int agg_value4 = agg_isNull4 ? -1 : (agg_aggKey.getInt(0)); /* 141 */ agg_rowWriter1.zeroOutNullBytes(); /* 142 */ // We don't need this NULL check because NULL is filtered out in `$"b" =!=2` /* 143 */ if (agg_isNull4) { /* 144 */ agg_rowWriter1.setNullAt(0); /* 145 */ } else { /* 146 */ agg_rowWriter1.write(0, agg_value4); /* 147 */ } /* 148 */ append(agg_result1); /* 149 */ /* 150 */ if (shouldStop()) return; /* 151 */ } /* 152 */ /* 153 */ agg_mapIter.close(); /* 154 */ if (agg_sorter == null) { /* 155 */ agg_hashMap.free(); /* 156 */ } /* 157 */ } /* 158 */ /* 159 */ } ``` In the line 143, we don't need this NULL check because NULL is filtered out in `$"b" =!=2`. This pr could remove this NULL check; ``` scala> val targetQuery.queryExecution.optimizedPlan.output(0).nullable res5: Boolean = false scala> targetQuery.debugCodegen ... Generated code: ... /* 144 */ protected void processNext() throws java.io.IOException { ... /* 152 */ // output the result /* 153 */ /* 154 */ while (agg_mapIter.next()) { /* 155 */ wholestagecodegen_numOutputRows.add(1); /* 156 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey(); /* 157 */ UnsafeRow agg_aggBuffer = (UnsafeRow) agg_mapIter.getValue(); /* 158 */ /* 159 */ int agg_value4 = agg_aggKey.getInt(0); /* 160 */ agg_rowWriter1.write(0, agg_value4); /* 161 */ append(agg_result1); /* 162 */ /* 163 */ if (shouldStop()) return; /* 164 */ } /* 165 */ /* 166 */ agg_mapIter.close(); /* 167 */ if (agg_sorter == null) { /* 168 */ agg_hashMap.free(); /* 169 */ } /* 170 */ } ``` ## How was this patch tested? Added `UpdateNullabilityInAttributeReferencesSuite` for unit tests. Author: Takeshi Yamamuro Closes #18576 from maropu/SPARK-21351. --- .../sql/catalyst/optimizer/Optimizer.scala | 19 ++++++- ...ullabilityInAttributeReferencesSuite.scala | 57 +++++++++++++++++++ .../optimizer/complexTypesSuite.scala | 9 --- .../org/apache/spark/sql/DataFrameSuite.scala | 5 -- 4 files changed, 75 insertions(+), 15 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2829d1d81eb1a..9a1bbc675e397 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -153,7 +153,9 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) RewritePredicateSubquery, ColumnPruning, CollapseProject, - RemoveRedundantProject) + RemoveRedundantProject) :+ + Batch("UpdateAttributeReferences", Once, + UpdateNullabilityInAttributeReferences) } /** @@ -1309,3 +1311,18 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] { } } } + +/** + * Updates nullability in [[AttributeReference]]s if nullability is different between + * non-leaf plan's expressions and the children output. + */ +object UpdateNullabilityInAttributeReferences extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case p if !p.isInstanceOf[LeafNode] => + val nullabilityMap = AttributeMap(p.children.flatMap(_.output).map { x => x -> x.nullable }) + p transformExpressions { + case ar: AttributeReference if nullabilityMap.contains(ar) => + ar.withNullability(nullabilityMap(ar)) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala new file mode 100644 index 0000000000000..09b11f5aba2a0 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{CreateArray, GetArrayItem} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + + +class UpdateNullabilityInAttributeReferencesSuite extends PlanTest { + + object Optimizer extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Constant Folding", FixedPoint(10), + NullPropagation, + ConstantFolding, + BooleanSimplification, + SimplifyConditionals, + SimplifyBinaryComparison, + SimplifyExtractValueOps) :: + Batch("UpdateAttributeReferences", Once, + UpdateNullabilityInAttributeReferences) :: Nil + } + + test("update nullability in AttributeReference") { + val rel = LocalRelation('a.long.notNull) + // In the 'original' plans below, the Aggregate node produced by groupBy() has a + // nullable AttributeReference to `b`, because both array indexing and map lookup are + // nullable expressions. After optimization, the same attribute is now non-nullable, + // but the AttributeReference is not updated to reflect this. So, we need to update nullability + // by the `UpdateNullabilityInAttributeReferences` rule. + val original = rel + .select(GetArrayItem(CreateArray(Seq('a, 'a + 1L)), 0) as "b") + .groupBy($"b")("1") + val expected = rel.select('a as "b").groupBy($"b")("1").analyze + val optimized = Optimizer.execute(original.analyze) + comparePlans(optimized, expected) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 21ed987627b3b..633d86d495581 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -378,15 +378,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .groupBy($"foo")("1") checkRule(structRel, structExpected) - // These tests must use nullable attributes from the base relation for the following reason: - // in the 'original' plans below, the Aggregate node produced by groupBy() has a - // nullable AttributeReference to a1, because both array indexing and map lookup are - // nullable expressions. After optimization, the same attribute is now non-nullable, - // but the AttributeReference is not updated to reflect this. In the 'expected' plans, - // the grouping expressions have the same nullability as the original attribute in the - // relation. If that attribute is non-nullable, the tests will fail as the plans will - // compare differently, so for these tests we must use a nullable attribute. See - // SPARK-23634. val arrayRel = relation .select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1") .groupBy($"a1")("1") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f7b3393f65cb1..60e84e6ee7504 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2055,11 +2055,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { expr: String, expectedNonNullableColumns: Seq[String]): Unit = { val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr) - // In the logical plan, all the output columns of input dataframe are nullable - dfWithFilter.queryExecution.optimizedPlan.collect { - case e: Filter => assert(e.output.forall(_.nullable)) - } - dfWithFilter.queryExecution.executedPlan.collect { // When the child expression in isnotnull is null-intolerant (i.e. any null input will // result in null output), the involved columns are converted to not nullable; From a35523653cdac039ee2ddff316bc2c25d6514a91 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 4 Apr 2018 18:36:15 +0200 Subject: [PATCH 0560/2461] [SPARK-23583][SQL] Invoke should support interpreted execution ## What changes were proposed in this pull request? This pr added interpreted execution for `Invoke`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Kazuaki Ishizaki Closes #20797 from kiszk/SPARK-28583. --- .../spark/sql/catalyst/ScalaReflection.scala | 48 +++++++++++++- .../expressions/objects/objects.scala | 56 ++++++++++++++-- .../expressions/ObjectExpressionsSuite.scala | 65 +++++++++++++++++++ 3 files changed, 163 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 9a4bf0075a178..1aae3aea3a31a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -794,6 +794,52 @@ object ScalaReflection extends ScalaReflection { "interface", "long", "native", "new", "null", "package", "private", "protected", "public", "return", "short", "static", "strictfp", "super", "switch", "synchronized", "this", "throw", "throws", "transient", "true", "try", "void", "volatile", "while") + + val typeJavaMapping = Map[DataType, Class[_]]( + BooleanType -> classOf[Boolean], + ByteType -> classOf[Byte], + ShortType -> classOf[Short], + IntegerType -> classOf[Int], + LongType -> classOf[Long], + FloatType -> classOf[Float], + DoubleType -> classOf[Double], + StringType -> classOf[UTF8String], + DateType -> classOf[DateType.InternalType], + TimestampType -> classOf[TimestampType.InternalType], + BinaryType -> classOf[BinaryType.InternalType], + CalendarIntervalType -> classOf[CalendarInterval] + ) + + val typeBoxedJavaMapping = Map[DataType, Class[_]]( + BooleanType -> classOf[java.lang.Boolean], + ByteType -> classOf[java.lang.Byte], + ShortType -> classOf[java.lang.Short], + IntegerType -> classOf[java.lang.Integer], + LongType -> classOf[java.lang.Long], + FloatType -> classOf[java.lang.Float], + DoubleType -> classOf[java.lang.Double], + DateType -> classOf[java.lang.Integer], + TimestampType -> classOf[java.lang.Long] + ) + + def dataTypeJavaClass(dt: DataType): Class[_] = { + dt match { + case _: DecimalType => classOf[Decimal] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayData] + case _: MapType => classOf[MapData] + case ObjectType(cls) => cls + case _ => typeJavaMapping.getOrElse(dt, classOf[java.lang.Object]) + } + } + + def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = { + if (arguments != Nil) { + arguments.map(e => dataTypeJavaClass(e.dataType)) + } else { + Seq.empty + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 0e9d357c19c63..a455c1c821a26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.objects -import java.lang.reflect.Modifier +import java.lang.reflect.{Method, Modifier} import scala.collection.JavaConverters._ import scala.collection.mutable.Builder @@ -28,7 +28,7 @@ import scala.util.Try import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ @@ -104,6 +104,38 @@ trait InvokeLike extends Expression with NonSQLExpression { (argCode, argValues.mkString(", "), resultIsNull) } + + /** + * Evaluate each argument with a given row, invoke a method with a given object and arguments, + * and cast a return value if the return type can be mapped to a Java Boxed type + * + * @param obj the object for the method to be called. If null, perform s static method call + * @param method the method object to be called + * @param arguments the arguments used for the method call + * @param input the row used for evaluating arguments + * @param dataType the data type of the return object + * @return the return object of a method call + */ + def invoke( + obj: Any, + method: Method, + arguments: Seq[Expression], + input: InternalRow, + dataType: DataType): Any = { + val args = arguments.map(e => e.eval(input).asInstanceOf[Object]) + if (needNullCheck && args.exists(_ == null)) { + // return null if one of arguments is null + null + } else { + val ret = method.invoke(obj, args: _*) + val boxedClass = ScalaReflection.typeBoxedJavaMapping.get(dataType) + if (boxedClass.isDefined) { + boxedClass.get.cast(ret) + } else { + ret + } + } + } } /** @@ -264,12 +296,11 @@ case class Invoke( propagateNull: Boolean = true, returnNullable : Boolean = true) extends InvokeLike { + lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments) + override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable override def children: Seq[Expression] = targetObject +: arguments - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - private lazy val encodedFunctionName = TermName(functionName).encodedName.toString @transient lazy val method = targetObject.dataType match { @@ -283,6 +314,21 @@ case class Invoke( case _ => None } + override def eval(input: InternalRow): Any = { + val obj = targetObject.eval(input) + if (obj == null) { + // return null if obj is null + null + } else { + val invokeMethod = if (method.isDefined) { + method.get + } else { + obj.getClass.getDeclaredMethod(functionName, argClasses: _*) + } + invoke(obj, invokeMethod, arguments, input, dataType) + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = CodeGenerator.javaType(dataType) val obj = targetObject.genCode(ctx) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 0edd27c8241e8..9bfe2916b0820 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -24,11 +24,23 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +class InvokeTargetClass extends Serializable { + def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0 + def filterPrimitiveInt(e: Int): Boolean = e > 0 + def binOp(e1: Int, e2: Double): Double = e1 + e2 +} + +class InvokeTargetSubClass extends InvokeTargetClass { + override def binOp(e1: Int, e2: Double): Double = e1 - e2 +} class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -81,6 +93,41 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed } + test("SPARK-23583: Invoke should support interpreted execution") { + val targetObject = new InvokeTargetClass + val funcClass = classOf[InvokeTargetClass] + val funcObj = Literal.create(targetObject, ObjectType(funcClass)) + val targetSubObject = new InvokeTargetSubClass + val funcSubObj = Literal.create(targetSubObject, ObjectType(classOf[InvokeTargetSubClass])) + val funcNullObj = Literal.create(null, ObjectType(funcClass)) + + val inputInt = Seq(BoundReference(0, ObjectType(classOf[Any]), true)) + val inputPrimitiveInt = Seq(BoundReference(0, IntegerType, false)) + val inputSum = Seq(BoundReference(0, IntegerType, false), BoundReference(1, DoubleType, false)) + + checkObjectExprEvaluation( + Invoke(funcObj, "filterInt", ObjectType(classOf[Any]), inputInt), + java.lang.Boolean.valueOf(true), InternalRow.fromSeq(Seq(Integer.valueOf(1)))) + + checkObjectExprEvaluation( + Invoke(funcObj, "filterPrimitiveInt", BooleanType, inputPrimitiveInt), + false, InternalRow.fromSeq(Seq(-1))) + + checkObjectExprEvaluation( + Invoke(funcObj, "filterInt", ObjectType(classOf[Any]), inputInt), + null, InternalRow.fromSeq(Seq(null))) + + checkObjectExprEvaluation( + Invoke(funcNullObj, "filterInt", ObjectType(classOf[Any]), inputInt), + null, InternalRow.fromSeq(Seq(Integer.valueOf(1)))) + + checkObjectExprEvaluation( + Invoke(funcObj, "binOp", DoubleType, inputSum), 1.25, InternalRow.apply(1, 0.25)) + + checkObjectExprEvaluation( + Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25)) + } + test("SPARK-23585: UnwrapOption should support interpreted execution") { val cls = classOf[Option[Int]] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) @@ -105,6 +152,24 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(createExternalRow, Row.fromSeq(Seq(1, "x")), InternalRow.fromSeq(Seq())) } + // by scala values instead of catalyst values. + private def checkObjectExprEvaluation( + expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + val serializer = new JavaSerializer(new SparkConf()).newInstance + val resolver = ResolveTimeZone(new SQLConf) + val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) + checkEvaluationWithoutCodegen(expr, expected, inputRow) + checkEvaluationWithGeneratedMutableProjection(expr, expected, inputRow) + if (GenerateUnsafeProjection.canSupport(expr.dataType)) { + checkEvaluationWithUnsafeProjection( + expr, + expected, + inputRow, + UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + } + checkEvaluationWithOptimization(expr, expected, inputRow) + } + test("SPARK-23594 GetExternalRowField should support interpreted execution") { val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0") From cccaaa14ad775fb981e501452ba2cc06ff5c0f0a Mon Sep 17 00:00:00 2001 From: Andrew Korzhuev Date: Wed, 4 Apr 2018 12:30:52 -0700 Subject: [PATCH 0561/2461] [SPARK-23668][K8S] Add config option for passing through k8s Pod.spec.imagePullSecrets ## What changes were proposed in this pull request? Pass through the `imagePullSecrets` option to the k8s pod in order to allow user to access private image registries. See https://kubernetes.io/docs/tasks/configure-pod-container/pull-image-private-registry/ ## How was this patch tested? Unit tests + manual testing. Manual testing procedure: 1. Have private image registry. 2. Spark-submit application with no `spark.kubernetes.imagePullSecret` set. Do `kubectl describe pod ...`. See the error message: ``` Error syncing pod, skipping: failed to "StartContainer" for "spark-kubernetes-driver" with ErrImagePull: "rpc error: code = 2 desc = Error: Status 400 trying to pull repository ...: \"{\\n \\\"errors\\\" : [ {\\n \\\"status\\\" : 400,\\n \\\"message\\\" : \\\"Unsupported docker v1 repository request for '...'\\\"\\n } ]\\n}\"" ``` 3. Create secret `kubectl create secret docker-registry ...` 4. Spark-submit with `spark.kubernetes.imagePullSecret` set to the new secret. See that deployment was successful. Author: Andrew Korzhuev Author: Andrew Korzhuev Closes #20811 from andrusha/spark-23668-image-pull-secrets. --- .../org/apache/spark/deploy/k8s/Config.scala | 7 ++++ .../spark/deploy/k8s/KubernetesUtils.scala | 13 +++++++ .../steps/BasicDriverConfigurationStep.scala | 7 +++- .../cluster/k8s/ExecutorPodFactory.scala | 4 +++ .../deploy/k8s/KubernetesUtilsTest.scala | 36 +++++++++++++++++++ .../BasicDriverConfigurationStepSuite.scala | 8 ++++- .../cluster/k8s/ExecutorPodFactorySuite.scala | 5 +++ 7 files changed, 78 insertions(+), 2 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 405ea476351bb..82f6c714f3555 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -54,6 +54,13 @@ private[spark] object Config extends Logging { .checkValues(Set("Always", "Never", "IfNotPresent")) .createWithDefault("IfNotPresent") + val IMAGE_PULL_SECRETS = + ConfigBuilder("spark.kubernetes.container.image.pullSecrets") + .doc("Comma separated list of the Kubernetes secrets used " + + "to access private image registries.") + .stringConf + .createOptional + val KUBERNETES_AUTH_DRIVER_CONF_PREFIX = "spark.kubernetes.authenticate.driver" val KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX = diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 5bc070147d3a8..5b2bb819cdb14 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.deploy.k8s +import io.fabric8.kubernetes.api.model.LocalObjectReference + import org.apache.spark.SparkConf import org.apache.spark.util.Utils @@ -35,6 +37,17 @@ private[spark] object KubernetesUtils { sparkConf.getAllWithPrefix(prefix).toMap } + /** + * Parses comma-separated list of imagePullSecrets into K8s-understandable format + */ + def parseImagePullSecrets(imagePullSecrets: Option[String]): List[LocalObjectReference] = { + imagePullSecrets match { + case Some(secretsCommaSeparated) => + secretsCommaSeparated.split(',').map(_.trim).map(new LocalObjectReference(_)).toList + case None => Nil + } + } + def requireNandDefined(opt1: Option[_], opt2: Option[_], errMessage: String): Unit = { opt1.foreach { _ => require(opt2.isEmpty, errMessage) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index b811db324108c..fcb1db8008053 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.k8s.submit.steps import scala.collection.JavaConverters._ -import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarSourceBuilder, PodBuilder, QuantityBuilder} +import io.fabric8.kubernetes.api.model._ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.deploy.k8s.Config._ @@ -51,6 +51,8 @@ private[spark] class BasicDriverConfigurationStep( .get(DRIVER_CONTAINER_IMAGE) .getOrElse(throw new SparkException("Must specify the driver container image")) + private val imagePullSecrets = sparkConf.get(IMAGE_PULL_SECRETS) + // CPU settings private val driverCpuCores = sparkConf.getOption("spark.driver.cores").getOrElse("1") private val driverLimitCores = sparkConf.get(KUBERNETES_DRIVER_LIMIT_CORES) @@ -129,6 +131,8 @@ private[spark] class BasicDriverConfigurationStep( case _ => driverContainerWithoutArgs.addToArgs(appArgs: _*).build() } + val parsedImagePullSecrets = KubernetesUtils.parseImagePullSecrets(imagePullSecrets) + val baseDriverPod = new PodBuilder(driverSpec.driverPod) .editOrNewMetadata() .withName(driverPodName) @@ -138,6 +142,7 @@ private[spark] class BasicDriverConfigurationStep( .withNewSpec() .withRestartPolicy("Never") .withNodeSelector(nodeSelector.asJava) + .withImagePullSecrets(parsedImagePullSecrets.asJava) .endSpec() .build() diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index 7143f7a6f0b71..8607d6fba3234 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -68,6 +68,7 @@ private[spark] class ExecutorPodFactory( .get(EXECUTOR_CONTAINER_IMAGE) .getOrElse(throw new SparkException("Must specify the executor container image")) private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) + private val imagePullSecrets = sparkConf.get(IMAGE_PULL_SECRETS) private val blockManagerPort = sparkConf .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) @@ -103,6 +104,8 @@ private[spark] class ExecutorPodFactory( nodeToLocalTaskCount: Map[String, Int]): Pod = { val name = s"$executorPodNamePrefix-exec-$executorId" + val parsedImagePullSecrets = KubernetesUtils.parseImagePullSecrets(imagePullSecrets) + // hostname must be no longer than 63 characters, so take the last 63 characters of the pod // name as the hostname. This preserves uniqueness since the end of name contains // executorId @@ -194,6 +197,7 @@ private[spark] class ExecutorPodFactory( .withHostname(hostname) .withRestartPolicy("Never") .withNodeSelector(nodeSelector.asJava) + .withImagePullSecrets(parsedImagePullSecrets.asJava) .endSpec() .build() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala new file mode 100644 index 0000000000000..cf41b22e241af --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.LocalObjectReference + +import org.apache.spark.SparkFunSuite + +class KubernetesUtilsTest extends SparkFunSuite { + + test("testParseImagePullSecrets") { + val noSecrets = KubernetesUtils.parseImagePullSecrets(None) + assert(noSecrets === Nil) + + val oneSecret = KubernetesUtils.parseImagePullSecrets(Some("imagePullSecret")) + assert(oneSecret === new LocalObjectReference("imagePullSecret") :: Nil) + + val commaSeparatedSecrets = KubernetesUtils.parseImagePullSecrets(Some("s1, s2 , s3,s4")) + assert(commaSeparatedSecrets.map(_.getName) === "s1" :: "s2" :: "s3" :: "s4" :: Nil) + } + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index e59c6d28a8cc2..ee450fff8d376 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -51,6 +51,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { .set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$CUSTOM_ANNOTATION_KEY", CUSTOM_ANNOTATION_VALUE) .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY1", "customDriverEnv1") .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY2", "customDriverEnv2") + .set(IMAGE_PULL_SECRETS, "imagePullSecret1, imagePullSecret2") val submissionStep = new BasicDriverConfigurationStep( APP_ID, @@ -103,7 +104,12 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { CUSTOM_ANNOTATION_KEY -> CUSTOM_ANNOTATION_VALUE, SPARK_APP_NAME_ANNOTATION -> APP_NAME) assert(driverPodMetadata.getAnnotations.asScala === expectedAnnotations) - assert(preparedDriverSpec.driverPod.getSpec.getRestartPolicy === "Never") + + val driverPodSpec = preparedDriverSpec.driverPod.getSpec + assert(driverPodSpec.getRestartPolicy === "Never") + assert(driverPodSpec.getImagePullSecrets.size() === 2) + assert(driverPodSpec.getImagePullSecrets.get(0).getName === "imagePullSecret1") + assert(driverPodSpec.getImagePullSecrets.get(1).getName === "imagePullSecret2") val resolvedSparkConf = preparedDriverSpec.driverSparkConf.getAll.toMap val expectedSparkConf = Map( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index a71a2a1b888bc..d73df20f0f956 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -33,6 +33,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef private val driverPodUid: String = "driver-uid" private val executorPrefix: String = "base" private val executorImage: String = "executor-image" + private val imagePullSecrets: String = "imagePullSecret1, imagePullSecret2" private val driverPod = new PodBuilder() .withNewMetadata() .withName(driverPodName) @@ -54,6 +55,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) .set(CONTAINER_IMAGE, executorImage) .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) + .set(IMAGE_PULL_SECRETS, imagePullSecrets) } test("basic executor pod has reasonable defaults") { @@ -76,6 +78,9 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef .getRequests.get("memory").getAmount === "1408Mi") assert(executor.getSpec.getContainers.get(0).getResources .getLimits.get("memory").getAmount === "1408Mi") + assert(executor.getSpec.getImagePullSecrets.size() === 2) + assert(executor.getSpec.getImagePullSecrets.get(0).getName === "imagePullSecret1") + assert(executor.getSpec.getImagePullSecrets.get(1).getName === "imagePullSecret2") // The pod has no node selector, volumes. assert(executor.getSpec.getNodeSelector.isEmpty) From d8379e5bc3629f4e8233ad42831bdaf68c24cfeb Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 4 Apr 2018 15:43:58 -0700 Subject: [PATCH 0562/2461] [SPARK-23838][WEBUI] Running SQL query is displayed as "completed" in SQL tab ## What changes were proposed in this pull request? A running SQL query would appear as completed in the Spark UI: ![image1](https://user-images.githubusercontent.com/1097932/38170733-3d7cb00c-35bf-11e8-994c-43f2d4fa285d.png) We can see the query in "Completed queries", while in in the job page we see it's still running Job 132. ![image2](https://user-images.githubusercontent.com/1097932/38170735-48f2c714-35bf-11e8-8a41-6fae23543c46.png) After some time in the query still appears in "Completed queries" (while it's still running), but the "Duration" gets increased. ![image3](https://user-images.githubusercontent.com/1097932/38170737-50f87ea4-35bf-11e8-8b60-000f6f918964.png) To reproduce, we can run a query with multiple jobs. E.g. Run TPCDS q6. The reason is that updates from executions are written into kvstore periodically, and the job start event may be missed. ## How was this patch tested? Manually run the job again and check the SQL Tab. The fix is pretty simple. Author: Gengliang Wang Closes #20955 from gengliangwang/jobCompleted. --- .../apache/spark/sql/execution/ui/AllExecutionsPage.scala | 3 ++- .../spark/sql/execution/ui/SQLAppStatusListener.scala | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index e751ce39cd5d7..582528777f90e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -39,7 +39,8 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L val failed = new mutable.ArrayBuffer[SQLExecutionUIData]() sqlStore.executionsList().foreach { e => - val isRunning = e.jobs.exists { case (_, status) => status == JobExecutionStatus.RUNNING } + val isRunning = e.completionTime.isEmpty || + e.jobs.exists { case (_, status) => status == JobExecutionStatus.RUNNING } val isFailed = e.jobs.exists { case (_, status) => status == JobExecutionStatus.FAILED } if (isRunning) { running += e diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 71e9f93c4566e..2b6bb48467eb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -88,7 +88,7 @@ class SQLAppStatusListener( exec.jobs = exec.jobs + (jobId -> JobExecutionStatus.RUNNING) exec.stages ++= event.stageIds.toSet - update(exec) + update(exec, force = true) } override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = { @@ -308,11 +308,13 @@ class SQLAppStatusListener( }) } - private def update(exec: LiveExecutionData): Unit = { + private def update(exec: LiveExecutionData, force: Boolean = false): Unit = { val now = System.nanoTime() if (exec.endEvents >= exec.jobs.size + 1) { exec.write(kvstore, now) liveExecutions.remove(exec.executionId) + } else if (force) { + exec.write(kvstore, now) } else if (liveUpdatePeriodNs >= 0) { if (now - exec.lastWriteTime > liveUpdatePeriodNs) { exec.write(kvstore, now) From d3bd0435ee4ff3d414f32cce3f58b6b9f67e68bc Mon Sep 17 00:00:00 2001 From: jinxing Date: Wed, 4 Apr 2018 15:51:27 -0700 Subject: [PATCH 0563/2461] [SPARK-23637][YARN] Yarn might allocate more resource if a same executor is killed multiple times. ## What changes were proposed in this pull request? `YarnAllocator` uses `numExecutorsRunning` to track the number of running executor. `numExecutorsRunning` is used to check if there're executors missing and need to allocate more. In current code, `numExecutorsRunning` can be negative when driver asks to kill a same idle executor multiple times. ## How was this patch tested? UT added Author: jinxing Closes #20781 from jinxing64/SPARK-23637. --- .../spark/deploy/yarn/YarnAllocator.scala | 36 +++++++------- .../deploy/yarn/YarnAllocatorSuite.scala | 48 ++++++++++++++++++- 2 files changed, 66 insertions(+), 18 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index a537243d641cb..ebee3d431744d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -81,7 +81,8 @@ private[yarn] class YarnAllocator( private val releasedContainers = Collections.newSetFromMap[ContainerId]( new ConcurrentHashMap[ContainerId, java.lang.Boolean]) - private val numExecutorsRunning = new AtomicInteger(0) + private val runningExecutors = Collections.newSetFromMap[String]( + new ConcurrentHashMap[String, java.lang.Boolean]()) private val numExecutorsStarting = new AtomicInteger(0) @@ -166,7 +167,7 @@ private[yarn] class YarnAllocator( clock = newClock } - def getNumExecutorsRunning: Int = numExecutorsRunning.get() + def getNumExecutorsRunning: Int = runningExecutors.size() def getNumExecutorsFailed: Int = synchronized { val endTime = clock.getTimeMillis() @@ -242,12 +243,11 @@ private[yarn] class YarnAllocator( * Request that the ResourceManager release the container running the specified executor. */ def killExecutor(executorId: String): Unit = synchronized { - if (executorIdToContainer.contains(executorId)) { - val container = executorIdToContainer.get(executorId).get - internalReleaseContainer(container) - numExecutorsRunning.decrementAndGet() - } else { - logWarning(s"Attempted to kill unknown executor $executorId!") + executorIdToContainer.get(executorId) match { + case Some(container) if !releasedContainers.contains(container.getId) => + internalReleaseContainer(container) + runningExecutors.remove(executorId) + case _ => logWarning(s"Attempted to kill unknown executor $executorId!") } } @@ -274,7 +274,7 @@ private[yarn] class YarnAllocator( "Launching executor count: %d. Cluster resources: %s.") .format( allocatedContainers.size, - numExecutorsRunning.get, + runningExecutors.size, numExecutorsStarting.get, allocateResponse.getAvailableResources)) @@ -286,7 +286,7 @@ private[yarn] class YarnAllocator( logDebug("Completed %d containers".format(completedContainers.size)) processCompletedContainers(completedContainers.asScala) logDebug("Finished processing %d completed containers. Current running executor count: %d." - .format(completedContainers.size, numExecutorsRunning.get)) + .format(completedContainers.size, runningExecutors.size)) } } @@ -300,9 +300,9 @@ private[yarn] class YarnAllocator( val pendingAllocate = getPendingAllocate val numPendingAllocate = pendingAllocate.size val missing = targetNumExecutors - numPendingAllocate - - numExecutorsStarting.get - numExecutorsRunning.get + numExecutorsStarting.get - runningExecutors.size logDebug(s"Updating resource requests, target: $targetNumExecutors, " + - s"pending: $numPendingAllocate, running: ${numExecutorsRunning.get}, " + + s"pending: $numPendingAllocate, running: ${runningExecutors.size}, " + s"executorsStarting: ${numExecutorsStarting.get}") if (missing > 0) { @@ -502,7 +502,7 @@ private[yarn] class YarnAllocator( s"for executor with ID $executorId") def updateInternalState(): Unit = synchronized { - numExecutorsRunning.incrementAndGet() + runningExecutors.add(executorId) numExecutorsStarting.decrementAndGet() executorIdToContainer(executorId) = container containerIdToExecutorId(container.getId) = executorId @@ -513,7 +513,7 @@ private[yarn] class YarnAllocator( allocatedContainerToHostMap.put(containerId, executorHostname) } - if (numExecutorsRunning.get < targetNumExecutors) { + if (runningExecutors.size() < targetNumExecutors) { numExecutorsStarting.incrementAndGet() if (launchContainers) { launcherPool.execute(new Runnable { @@ -554,7 +554,7 @@ private[yarn] class YarnAllocator( } else { logInfo(("Skip launching executorRunnable as running executors count: %d " + "reached target executors count: %d.").format( - numExecutorsRunning.get, targetNumExecutors)) + runningExecutors.size, targetNumExecutors)) } } } @@ -569,7 +569,11 @@ private[yarn] class YarnAllocator( val exitReason = if (!alreadyReleased) { // Decrement the number of executors running. The next iteration of // the ApplicationMaster's reporting thread will take care of allocating. - numExecutorsRunning.decrementAndGet() + containerIdToExecutorId.get(containerId) match { + case Some(executorId) => runningExecutors.remove(executorId) + case None => logWarning(s"Cannot find executorId for container: ${containerId.toString}") + } + logInfo("Completed container %s%s (state: %s, exit status: %s)".format( containerId, onHostStr, diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index cb1e3c5268510..525abb6f2b350 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -251,11 +251,55 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Finished", 0) } handler.updateResourceRequests() - handler.processCompletedContainers(statuses.toSeq) + handler.processCompletedContainers(statuses) handler.getNumExecutorsRunning should be (0) handler.getPendingAllocate.size should be (1) } + test("kill same executor multiple times") { + val handler = createAllocator(2) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getPendingAllocate.size should be (2) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, container2)) + handler.getNumExecutorsRunning should be (2) + handler.getPendingAllocate.size should be (0) + + val executorToKill = handler.executorIdToContainer.keys.head + handler.killExecutor(executorToKill) + handler.getNumExecutorsRunning should be (1) + handler.killExecutor(executorToKill) + handler.killExecutor(executorToKill) + handler.killExecutor(executorToKill) + handler.getNumExecutorsRunning should be (1) + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map.empty, Set.empty) + handler.updateResourceRequests() + handler.getPendingAllocate.size should be (1) + } + + test("process same completed container multiple times") { + val handler = createAllocator(2) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getPendingAllocate.size should be (2) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, container2)) + handler.getNumExecutorsRunning should be (2) + handler.getPendingAllocate.size should be (0) + + val statuses = Seq(container1, container1, container2).map { c => + ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Finished", 0) + } + handler.processCompletedContainers(statuses) + handler.getNumExecutorsRunning should be (0) + + } + test("lost executor removed from backend") { val handler = createAllocator(4) handler.updateResourceRequests() @@ -272,7 +316,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Failed", -1) } handler.updateResourceRequests() - handler.processCompletedContainers(statuses.toSeq) + handler.processCompletedContainers(statuses) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) handler.getPendingAllocate.size should be (2) From c5c8b544047a83cb6128a20d31f1d943a15f9260 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 5 Apr 2018 13:39:45 +0200 Subject: [PATCH 0564/2461] [SPARK-23593][SQL] Add interpreted execution for InitializeJavaBean expression ## What changes were proposed in this pull request? Add interpreted execution for `InitializeJavaBean` expression. ## How was this patch tested? Added unit test. Author: Liang-Chi Hsieh Closes #20756 from viirya/SPARK-23593. --- .../expressions/objects/objects.scala | 47 ++++++++++++++++- .../expressions/ExpressionEvalHelper.scala | 9 ++-- .../expressions/ObjectExpressionsSuite.scala | 52 +++++++++++++++++++ 3 files changed, 103 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index a455c1c821a26..20c4f4c7324fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1410,8 +1410,47 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp override def children: Seq[Expression] = beanInstance +: setters.values.toSeq override def dataType: DataType = beanInstance.dataType - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + private lazy val resolvedSetters = { + assert(beanInstance.dataType.isInstanceOf[ObjectType]) + + val ObjectType(beanClass) = beanInstance.dataType + setters.map { + case (name, expr) => + // Looking for known type mapping. + // But also looking for general `Object`-type parameter for generic methods. + val paramTypes = ScalaReflection.expressionJavaClasses(Seq(expr)) ++ Seq(classOf[Object]) + val methods = paramTypes.flatMap { fieldClass => + try { + Some(beanClass.getDeclaredMethod(name, fieldClass)) + } catch { + case e: NoSuchMethodException => None + } + } + if (methods.isEmpty) { + throw new NoSuchMethodException(s"""A method named "$name" is not declared """ + + "in any enclosing class nor any supertype") + } + methods.head -> expr + } + } + + override def eval(input: InternalRow): Any = { + val instance = beanInstance.eval(input) + if (instance != null) { + val bean = instance.asInstanceOf[Object] + resolvedSetters.foreach { + case (setter, expr) => + val paramVal = expr.eval(input) + if (paramVal == null) { + throw new NullPointerException("The parameter value for setters in " + + "`InitializeJavaBean` can not be null") + } else { + setter.invoke(bean, paramVal.asInstanceOf[AnyRef]) + } + } + } + instance + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val instanceGen = beanInstance.genCode(ctx) @@ -1424,6 +1463,10 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val fieldGen = fieldValue.genCode(ctx) s""" |${fieldGen.code} + |if (${fieldGen.isNull}) { + | throw new NullPointerException("The parameter value for setters in " + + | "`InitializeJavaBean` can not be null"); + |} |$javaBeanInstance.$setterMethod(${fieldGen.value}); """.stripMargin } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 3828f172a15cf..a5ecd1b68fac4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -55,7 +55,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def checkEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val expr = prepareEvaluation(expression) + // Make it as method to obtain fresh expression everytime. + def expr = prepareEvaluation(expression) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) @@ -111,12 +112,14 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { val errMsg = intercept[T] { eval }.getMessage - if (errMsg != expectedErrMsg) { + if (!errMsg.contains(expectedErrMsg)) { fail(s"Expected error message is `$expectedErrMsg`, but `$errMsg` found") } } } - val expr = prepareEvaluation(expression) + + // Make it as method to obtain fresh expression everytime. + def expr = prepareEvaluation(expression) checkException(evaluateWithoutCodegen(expr, inputRow), "non-codegen mode") checkException(evaluateWithGeneratedMutableProjection(expr, inputRow), "codegen mode") if (GenerateUnsafeProjection.canSupport(expr.dataType)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 9bfe2916b0820..44fecd602e854 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -128,6 +128,50 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25)) } + test("SPARK-23593: InitializeJavaBean should support interpreted execution") { + val list = new java.util.LinkedList[Int]() + list.add(1) + + val initializeBean = InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]), + Map("add" -> Literal(1))) + checkEvaluation(initializeBean, list, InternalRow.fromSeq(Seq())) + + val initializeWithNonexistingMethod = InitializeJavaBean( + Literal.fromObject(new java.util.LinkedList[Int]), + Map("nonexisting" -> Literal(1))) + checkExceptionInExpression[Exception](initializeWithNonexistingMethod, + InternalRow.fromSeq(Seq()), + """A method named "nonexisting" is not declared in any enclosing class """ + + "nor any supertype") + + val initializeWithWrongParamType = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setX" -> Literal("1"))) + intercept[Exception] { + evaluateWithoutCodegen(initializeWithWrongParamType, InternalRow.fromSeq(Seq())) + }.getMessage.contains( + """A method named "setX" is not declared in any enclosing class """ + + "nor any supertype") + } + + test("Can not pass in null into setters in InitializeJavaBean") { + val initializeBean = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setNonPrimitive" -> Literal(null))) + intercept[NullPointerException] { + evaluateWithoutCodegen(initializeBean, InternalRow.fromSeq(Seq())) + }.getMessage.contains("The parameter value for setters in `InitializeJavaBean` can not be null") + intercept[NullPointerException] { + evaluateWithGeneratedMutableProjection(initializeBean, InternalRow.fromSeq(Seq())) + }.getMessage.contains("The parameter value for setters in `InitializeJavaBean` can not be null") + + val initializeBean2 = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setNonPrimitive" -> Literal("string"))) + evaluateWithoutCodegen(initializeBean2, InternalRow.fromSeq(Seq())) + evaluateWithGeneratedMutableProjection(initializeBean2, InternalRow.fromSeq(Seq())) + } + test("SPARK-23585: UnwrapOption should support interpreted execution") { val cls = classOf[Option[Int]] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) @@ -278,3 +322,11 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + +class TestBean extends Serializable { + private var x: Int = 0 + + def setX(i: Int): Unit = x = i + def setNonPrimitive(i: AnyRef): Unit = + assert(i != null, "this setter should not be called with null.") +} From 1822ecda51cc9e14bb18050e0b8c270fee47ced7 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 5 Apr 2018 13:47:06 +0200 Subject: [PATCH 0565/2461] [SPARK-23582][SQL] StaticInvoke should support interpreted execution ## What changes were proposed in this pull request? This pr added interpreted execution for `StaticInvoke`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Kazuaki Ishizaki Closes #20753 from kiszk/SPARK-23582. --- .../expressions/objects/objects.scala | 14 +++- .../expressions/ObjectExpressionsSuite.scala | 66 ++++++++++++++++++- 2 files changed, 77 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 20c4f4c7324fd..9ca0b6137679e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** * Common base class for [[StaticInvoke]], [[Invoke]], and [[NewInstance]]. @@ -217,12 +218,21 @@ case class StaticInvoke( returnNullable: Boolean = true) extends InvokeLike { val objectName = staticObject.getName.stripSuffix("$") + val cls = if (staticObject.getName == objectName) { + staticObject + } else { + Utils.classForName(objectName) + } override def nullable: Boolean = needNullCheck || returnNullable override def children: Seq[Expression] = arguments - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments) + @transient lazy val method = cls.getDeclaredMethod(functionName, argClasses : _*) + + override def eval(input: InternalRow): Any = { + invoke(null, method, arguments, input, dataType) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = CodeGenerator.javaType(dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 44fecd602e854..eb89e01b5ff9d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.{Date, Timestamp} + import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -28,9 +30,11 @@ import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{SQLDate, SQLTimestamp} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class InvokeTargetClass extends Serializable { def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0 @@ -93,6 +97,66 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed } + test("SPARK-23582: StaticInvoke should support interpreted execution") { + Seq((classOf[java.lang.Boolean], "true", true), + (classOf[java.lang.Byte], "1", 1.toByte), + (classOf[java.lang.Short], "257", 257.toShort), + (classOf[java.lang.Integer], "12345", 12345), + (classOf[java.lang.Long], "12345678", 12345678.toLong), + (classOf[java.lang.Float], "12.34", 12.34.toFloat), + (classOf[java.lang.Double], "1.2345678", 1.2345678) + ).foreach { case (cls, arg, expected) => + checkObjectExprEvaluation(StaticInvoke(cls, ObjectType(cls), "valueOf", + Seq(BoundReference(0, ObjectType(classOf[java.lang.String]), true))), + expected, InternalRow.fromSeq(Seq(arg))) + } + + // Return null when null argument is passed with propagateNull = true + val stringCls = classOf[java.lang.String] + checkObjectExprEvaluation(StaticInvoke(stringCls, ObjectType(stringCls), "valueOf", + Seq(BoundReference(0, ObjectType(classOf[Object]), true)), propagateNull = true), + null, InternalRow.fromSeq(Seq(null))) + checkObjectExprEvaluation(StaticInvoke(stringCls, ObjectType(stringCls), "valueOf", + Seq(BoundReference(0, ObjectType(classOf[Object]), true)), propagateNull = false), + "null", InternalRow.fromSeq(Seq(null))) + + // test no argument + val clCls = classOf[java.lang.ClassLoader] + checkObjectExprEvaluation(StaticInvoke(clCls, ObjectType(clCls), "getSystemClassLoader", Nil), + ClassLoader.getSystemClassLoader, InternalRow.empty) + // test more than one argument + val intCls = classOf[java.lang.Integer] + checkObjectExprEvaluation(StaticInvoke(intCls, ObjectType(intCls), "compare", + Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, false))), + 0, InternalRow.fromSeq(Seq(7, 7))) + + Seq((DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", ObjectType(classOf[Timestamp]), + new Timestamp(77777), DateTimeUtils.fromJavaTimestamp(new Timestamp(77777))), + (DateTimeUtils.getClass, DateType, "fromJavaDate", ObjectType(classOf[Date]), + new Date(88888888), DateTimeUtils.fromJavaDate(new Date(88888888))), + (classOf[UTF8String], StringType, "fromString", ObjectType(classOf[String]), + "abc", UTF8String.fromString("abc")), + (Decimal.getClass, DecimalType(38, 0), "fromDecimal", ObjectType(classOf[Any]), + BigInt(88888888), Decimal.fromDecimal(BigInt(88888888))), + (Decimal.getClass, DecimalType.SYSTEM_DEFAULT, + "apply", ObjectType(classOf[java.math.BigInteger]), + new java.math.BigInteger("88888888"), Decimal.apply(new java.math.BigInteger("88888888"))), + (classOf[ArrayData], ArrayType(IntegerType), "toArrayData", ObjectType(classOf[Any]), + Array[Int](1, 2, 3), ArrayData.toArrayData(Array[Int](1, 2, 3))), + (classOf[UnsafeArrayData], ArrayType(IntegerType, false), + "fromPrimitiveArray", ObjectType(classOf[Array[Int]]), + Array[Int](1, 2, 3), UnsafeArrayData.fromPrimitiveArray(Array[Int](1, 2, 3))), + (DateTimeUtils.getClass, ObjectType(classOf[Date]), + "toJavaDate", ObjectType(classOf[SQLDate]), 77777, DateTimeUtils.toJavaDate(77777)), + (DateTimeUtils.getClass, ObjectType(classOf[Timestamp]), + "toJavaTimestamp", ObjectType(classOf[SQLTimestamp]), + 88888888.toLong, DateTimeUtils.toJavaTimestamp(88888888)) + ).foreach { case (cls, dataType, methodName, argType, arg, expected) => + checkObjectExprEvaluation(StaticInvoke(cls, dataType, methodName, + Seq(BoundReference(0, argType, true))), expected, InternalRow.fromSeq(Seq(arg))) + } + } + test("SPARK-23583: Invoke should support interpreted execution") { val targetObject = new InvokeTargetClass val funcClass = classOf[InvokeTargetClass] From b2329fb1fcdc0e93c4bdc39d574cde7328ef6094 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 5 Apr 2018 13:57:41 +0200 Subject: [PATCH 0566/2461] Revert "[SPARK-23593][SQL] Add interpreted execution for InitializeJavaBean expression" This reverts commit c5c8b544047a83cb6128a20d31f1d943a15f9260. --- .../expressions/objects/objects.scala | 47 +---------------- .../expressions/ExpressionEvalHelper.scala | 9 ++-- .../expressions/ObjectExpressionsSuite.scala | 52 ------------------- 3 files changed, 5 insertions(+), 103 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 9ca0b6137679e..3fa91bd36bb60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1420,47 +1420,8 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp override def children: Seq[Expression] = beanInstance +: setters.values.toSeq override def dataType: DataType = beanInstance.dataType - private lazy val resolvedSetters = { - assert(beanInstance.dataType.isInstanceOf[ObjectType]) - - val ObjectType(beanClass) = beanInstance.dataType - setters.map { - case (name, expr) => - // Looking for known type mapping. - // But also looking for general `Object`-type parameter for generic methods. - val paramTypes = ScalaReflection.expressionJavaClasses(Seq(expr)) ++ Seq(classOf[Object]) - val methods = paramTypes.flatMap { fieldClass => - try { - Some(beanClass.getDeclaredMethod(name, fieldClass)) - } catch { - case e: NoSuchMethodException => None - } - } - if (methods.isEmpty) { - throw new NoSuchMethodException(s"""A method named "$name" is not declared """ + - "in any enclosing class nor any supertype") - } - methods.head -> expr - } - } - - override def eval(input: InternalRow): Any = { - val instance = beanInstance.eval(input) - if (instance != null) { - val bean = instance.asInstanceOf[Object] - resolvedSetters.foreach { - case (setter, expr) => - val paramVal = expr.eval(input) - if (paramVal == null) { - throw new NullPointerException("The parameter value for setters in " + - "`InitializeJavaBean` can not be null") - } else { - setter.invoke(bean, paramVal.asInstanceOf[AnyRef]) - } - } - } - instance - } + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val instanceGen = beanInstance.genCode(ctx) @@ -1473,10 +1434,6 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val fieldGen = fieldValue.genCode(ctx) s""" |${fieldGen.code} - |if (${fieldGen.isNull}) { - | throw new NullPointerException("The parameter value for setters in " + - | "`InitializeJavaBean` can not be null"); - |} |$javaBeanInstance.$setterMethod(${fieldGen.value}); """.stripMargin } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a5ecd1b68fac4..3828f172a15cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -55,8 +55,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def checkEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - // Make it as method to obtain fresh expression everytime. - def expr = prepareEvaluation(expression) + val expr = prepareEvaluation(expression) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) @@ -112,14 +111,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { val errMsg = intercept[T] { eval }.getMessage - if (!errMsg.contains(expectedErrMsg)) { + if (errMsg != expectedErrMsg) { fail(s"Expected error message is `$expectedErrMsg`, but `$errMsg` found") } } } - - // Make it as method to obtain fresh expression everytime. - def expr = prepareEvaluation(expression) + val expr = prepareEvaluation(expression) checkException(evaluateWithoutCodegen(expr, inputRow), "non-codegen mode") checkException(evaluateWithGeneratedMutableProjection(expr, inputRow), "codegen mode") if (GenerateUnsafeProjection.canSupport(expr.dataType)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index eb89e01b5ff9d..1d59b20077fa9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -192,50 +192,6 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25)) } - test("SPARK-23593: InitializeJavaBean should support interpreted execution") { - val list = new java.util.LinkedList[Int]() - list.add(1) - - val initializeBean = InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]), - Map("add" -> Literal(1))) - checkEvaluation(initializeBean, list, InternalRow.fromSeq(Seq())) - - val initializeWithNonexistingMethod = InitializeJavaBean( - Literal.fromObject(new java.util.LinkedList[Int]), - Map("nonexisting" -> Literal(1))) - checkExceptionInExpression[Exception](initializeWithNonexistingMethod, - InternalRow.fromSeq(Seq()), - """A method named "nonexisting" is not declared in any enclosing class """ + - "nor any supertype") - - val initializeWithWrongParamType = InitializeJavaBean( - Literal.fromObject(new TestBean), - Map("setX" -> Literal("1"))) - intercept[Exception] { - evaluateWithoutCodegen(initializeWithWrongParamType, InternalRow.fromSeq(Seq())) - }.getMessage.contains( - """A method named "setX" is not declared in any enclosing class """ + - "nor any supertype") - } - - test("Can not pass in null into setters in InitializeJavaBean") { - val initializeBean = InitializeJavaBean( - Literal.fromObject(new TestBean), - Map("setNonPrimitive" -> Literal(null))) - intercept[NullPointerException] { - evaluateWithoutCodegen(initializeBean, InternalRow.fromSeq(Seq())) - }.getMessage.contains("The parameter value for setters in `InitializeJavaBean` can not be null") - intercept[NullPointerException] { - evaluateWithGeneratedMutableProjection(initializeBean, InternalRow.fromSeq(Seq())) - }.getMessage.contains("The parameter value for setters in `InitializeJavaBean` can not be null") - - val initializeBean2 = InitializeJavaBean( - Literal.fromObject(new TestBean), - Map("setNonPrimitive" -> Literal("string"))) - evaluateWithoutCodegen(initializeBean2, InternalRow.fromSeq(Seq())) - evaluateWithGeneratedMutableProjection(initializeBean2, InternalRow.fromSeq(Seq())) - } - test("SPARK-23585: UnwrapOption should support interpreted execution") { val cls = classOf[Option[Int]] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) @@ -386,11 +342,3 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } - -class TestBean extends Serializable { - private var x: Int = 0 - - def setX(i: Int): Unit = x = i - def setNonPrimitive(i: AnyRef): Unit = - assert(i != null, "this setter should not be called with null.") -} From d9ca1c906bd0571802f2297c36b407e660fcdb64 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 5 Apr 2018 20:43:05 +0200 Subject: [PATCH 0567/2461] [SPARK-23593][SQL] Add interpreted execution for InitializeJavaBean expression ## What changes were proposed in this pull request? Add interpreted execution for `InitializeJavaBean` expression. ## How was this patch tested? Added unit test. Author: Liang-Chi Hsieh Closes #20985 from viirya/SPARK-23593-2. --- .../expressions/objects/objects.scala | 45 +++++++++++++++-- .../expressions/ExpressionEvalHelper.scala | 9 ++-- .../expressions/ObjectExpressionsSuite.scala | 48 +++++++++++++++++++ 3 files changed, 96 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 3fa91bd36bb60..9252425f86473 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1420,8 +1420,45 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp override def children: Seq[Expression] = beanInstance +: setters.values.toSeq override def dataType: DataType = beanInstance.dataType - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + private lazy val resolvedSetters = { + assert(beanInstance.dataType.isInstanceOf[ObjectType]) + + val ObjectType(beanClass) = beanInstance.dataType + setters.map { + case (name, expr) => + // Looking for known type mapping. + // But also looking for general `Object`-type parameter for generic methods. + val paramTypes = ScalaReflection.expressionJavaClasses(Seq(expr)) ++ Seq(classOf[Object]) + val methods = paramTypes.flatMap { fieldClass => + try { + Some(beanClass.getDeclaredMethod(name, fieldClass)) + } catch { + case e: NoSuchMethodException => None + } + } + if (methods.isEmpty) { + throw new NoSuchMethodException(s"""A method named "$name" is not declared """ + + "in any enclosing class nor any supertype") + } + methods.head -> expr + } + } + + override def eval(input: InternalRow): Any = { + val instance = beanInstance.eval(input) + if (instance != null) { + val bean = instance.asInstanceOf[Object] + resolvedSetters.foreach { + case (setter, expr) => + val paramVal = expr.eval(input) + // We don't call setter if input value is null. + if (paramVal != null) { + setter.invoke(bean, paramVal.asInstanceOf[AnyRef]) + } + } + } + instance + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val instanceGen = beanInstance.genCode(ctx) @@ -1434,7 +1471,9 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val fieldGen = fieldValue.genCode(ctx) s""" |${fieldGen.code} - |$javaBeanInstance.$setterMethod(${fieldGen.value}); + |if (!${fieldGen.isNull}) { + | $javaBeanInstance.$setterMethod(${fieldGen.value}); + |} """.stripMargin } val initializeCode = ctx.splitExpressionsWithCurrentInputs( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 3828f172a15cf..a5ecd1b68fac4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -55,7 +55,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def checkEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val expr = prepareEvaluation(expression) + // Make it as method to obtain fresh expression everytime. + def expr = prepareEvaluation(expression) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) @@ -111,12 +112,14 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { val errMsg = intercept[T] { eval }.getMessage - if (errMsg != expectedErrMsg) { + if (!errMsg.contains(expectedErrMsg)) { fail(s"Expected error message is `$expectedErrMsg`, but `$errMsg` found") } } } - val expr = prepareEvaluation(expression) + + // Make it as method to obtain fresh expression everytime. + def expr = prepareEvaluation(expression) checkException(evaluateWithoutCodegen(expr, inputRow), "non-codegen mode") checkException(evaluateWithGeneratedMutableProjection(expr, inputRow), "codegen mode") if (GenerateUnsafeProjection.canSupport(expr.dataType)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 1d59b20077fa9..b1bc67dfac1b5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -192,6 +192,46 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25)) } + test("SPARK-23593: InitializeJavaBean should support interpreted execution") { + val list = new java.util.LinkedList[Int]() + list.add(1) + + val initializeBean = InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]), + Map("add" -> Literal(1))) + checkEvaluation(initializeBean, list, InternalRow.fromSeq(Seq())) + + val initializeWithNonexistingMethod = InitializeJavaBean( + Literal.fromObject(new java.util.LinkedList[Int]), + Map("nonexisting" -> Literal(1))) + checkExceptionInExpression[Exception](initializeWithNonexistingMethod, + InternalRow.fromSeq(Seq()), + """A method named "nonexisting" is not declared in any enclosing class """ + + "nor any supertype") + + val initializeWithWrongParamType = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setX" -> Literal("1"))) + intercept[Exception] { + evaluateWithoutCodegen(initializeWithWrongParamType, InternalRow.fromSeq(Seq())) + }.getMessage.contains( + """A method named "setX" is not declared in any enclosing class """ + + "nor any supertype") + } + + test("InitializeJavaBean doesn't call setters if input in null") { + val initializeBean = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setNonPrimitive" -> Literal(null))) + evaluateWithoutCodegen(initializeBean, InternalRow.fromSeq(Seq())) + evaluateWithGeneratedMutableProjection(initializeBean, InternalRow.fromSeq(Seq())) + + val initializeBean2 = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setNonPrimitive" -> Literal("string"))) + evaluateWithoutCodegen(initializeBean2, InternalRow.fromSeq(Seq())) + evaluateWithGeneratedMutableProjection(initializeBean2, InternalRow.fromSeq(Seq())) + } + test("SPARK-23585: UnwrapOption should support interpreted execution") { val cls = classOf[Option[Int]] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) @@ -342,3 +382,11 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + +class TestBean extends Serializable { + private var x: Int = 0 + + def setX(i: Int): Unit = x = i + def setNonPrimitive(i: AnyRef): Unit = + assert(i != null, "this setter should not be called with null.") +} From 4807d381bb113a5c61e6dad88202f23a8b6dd141 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 6 Apr 2018 10:13:59 +0800 Subject: [PATCH 0568/2461] [SPARK-10399][CORE][SQL] Introduce multiple MemoryBlocks to choose several types of memory block ## What changes were proposed in this pull request? This PR allows us to use one of several types of `MemoryBlock`, such as byte array, int array, long array, or `java.nio.DirectByteBuffer`. To use `java.nio.DirectByteBuffer` allows to have off heap memory which is automatically deallocated by JVM. `MemoryBlock` class has primitive accessors like `Platform.getInt()`, `Platform.putint()`, or `Platform.copyMemory()`. This PR uses `MemoryBlock` for `OffHeapColumnVector`, `UTF8String`, and other places. This PR can improve performance of operations involving memory accesses (e.g. `UTF8String.trim`) by 1.8x. For now, this PR does not use `MemoryBlock` for `BufferHolder` based on cloud-fan's [suggestion](https://github.com/apache/spark/pull/11494#issuecomment-309694290). Since this PR is a successor of #11494, close #11494. Many codes were ported from #11494. Many efforts were put here. **I think this PR should credit to yzotov.** This PR can achieve **1.1-1.4x performance improvements** for operations in `UTF8String` or `Murmur3_x86_32`. Other operations are almost comparable performances. Without this PR ``` OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-22-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-22-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Hash byte arrays with length 268435487: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Murmur3_x86_32 526 / 536 0.0 131399881.5 1.0X UTF8String benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ hashCode 525 / 552 1022.6 1.0 1.0X substring 414 / 423 1298.0 0.8 1.3X ``` With this PR ``` OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-22-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Hash byte arrays with length 268435487: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Murmur3_x86_32 474 / 488 0.0 118552232.0 1.0X UTF8String benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ hashCode 476 / 480 1127.3 0.9 1.0X substring 287 / 291 1869.9 0.5 1.7X ``` Benchmark program ``` test("benchmark Murmur3_x86_32") { val length = 8192 * 32768 + 31 val seed = 42L val iters = 1 << 2 val random = new Random(seed) val arrays = Array.fill[MemoryBlock](numArrays) { val bytes = new Array[Byte](length) random.nextBytes(bytes) new ByteArrayMemoryBlock(bytes, Platform.BYTE_ARRAY_OFFSET, length) } val benchmark = new Benchmark("Hash byte arrays with length " + length, iters * numArrays, minNumIters = 20) benchmark.addCase("HiveHasher") { _: Int => var sum = 0L for (_ <- 0L until iters) { sum += HiveHasher.hashUnsafeBytesBlock( arrays(i), Platform.BYTE_ARRAY_OFFSET, length) } } benchmark.run() } test("benchmark UTF8String") { val N = 512 * 1024 * 1024 val iters = 2 val benchmark = new Benchmark("UTF8String benchmark", N, minNumIters = 20) val str0 = new java.io.StringWriter() { { for (i <- 0 until N) { write(" ") } } }.toString val s0 = UTF8String.fromString(str0) benchmark.addCase("hashCode") { _: Int => var h: Int = 0 for (_ <- 0L until iters) { h += s0.hashCode } } benchmark.addCase("substring") { _: Int => var s: UTF8String = null for (_ <- 0L until iters) { s = s0.substring(N / 2 - 5, N / 2 + 5) } } benchmark.run() } ``` I run [this benchmark program](https://gist.github.com/kiszk/94f75b506c93a663bbbc372ffe8f05de) using [the commit](https://github.com/apache/spark/pull/19222/commits/ee5a79861c18725fb1cd9b518cdfd2489c05b81d6). I got the following results: ``` OpenJDK 64-Bit Server VM 1.8.0_151-8u151-b12-0ubuntu0.16.04.2-b12 on Linux 4.4.0-66-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Memory access benchmarks: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ ByteArrayMemoryBlock get/putInt() 220 / 221 609.3 1.6 1.0X Platform get/putInt(byte[]) 220 / 236 610.9 1.6 1.0X Platform get/putInt(Object) 492 / 494 272.8 3.7 0.4X OnHeapMemoryBlock get/putLong() 322 / 323 416.5 2.4 0.7X long[] 221 / 221 608.0 1.6 1.0X Platform get/putLong(long[]) 321 / 321 418.7 2.4 0.7X Platform get/putLong(Object) 561 / 563 239.2 4.2 0.4X ``` I also run [this benchmark program](https://gist.github.com/kiszk/5fdb4e03733a5d110421177e289d1fb5) for comparing performance of `Platform.copyMemory()`. ``` OpenJDK 64-Bit Server VM 1.8.0_151-8u151-b12-0ubuntu0.16.04.2-b12 on Linux 4.4.0-66-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Platform copyMemory: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Object to Object 1961 / 1967 8.6 116.9 1.0X System.arraycopy Object to Object 1917 / 1921 8.8 114.3 1.0X byte array to byte array 1961 / 1968 8.6 116.9 1.0X System.arraycopy byte array to byte array 1909 / 1937 8.8 113.8 1.0X int array to int array 1921 / 1990 8.7 114.5 1.0X double array to double array 1918 / 1923 8.7 114.3 1.0X Object to byte array 1961 / 1967 8.6 116.9 1.0X Object to short array 1965 / 1972 8.5 117.1 1.0X Object to int array 1910 / 1915 8.8 113.9 1.0X Object to float array 1971 / 1978 8.5 117.5 1.0X Object to double array 1919 / 1944 8.7 114.4 1.0X byte array to Object 1959 / 1967 8.6 116.8 1.0X int array to Object 1961 / 1970 8.6 116.9 1.0X double array to Object 1917 / 1924 8.8 114.3 1.0X ``` These results show three facts: 1. According to the second/third or sixth/seventh results in the first experiment, if we use `Platform.get/putInt(Object)`, we achieve more than 2x worse performance than `Platform.get/putInt(byte[])` with concrete type (i.e. `byte[]`). 2. According to the second/third or fourth/fifth/sixth results in the first experiment, the fastest way to access an array element on Java heap is `array[]`. **Cons of `array[]` is that it is not possible to support unaligned-8byte access.** 3. According to the first/second/third or fourth/sixth/seventh results in the first experiment, `getInt()/putInt() or getLong()/putLong()` in subclasses of `MemoryBlock` can achieve comparable performance to `Platform.get/putInt()` or `Platform.get/putLong()` with concrete type (second or sixth result). There is no overhead regarding virtual call. 4. According to results in the second experiment, for `Platform.copy()`, to pass `Object` can achieve the same performance as to pass any type of primitive array as source or destination. 5. According to second/fourth results in the second experiment, `Platform.copy()` can achieve the same performance as `System.arrayCopy`. **It would be good to use `Platform.copy()` since `Platform.copy()` can take any types for src and dst.** We are incrementally replace `Platform.get/putXXX` with `MemoryBlock.get/putXXX`. This is because we have two advantages. 1) Achieve better performance due to having a concrete type for an array. 2) Use simple OO design instead of passing `Object` It is easy to use `MemoryBlock` in `InternalRow`, `BufferHolder`, `TaskMemoryManager`, and others that are already abstracted. It is not easy to use `MemoryBlock` in utility classes related to hashing or others. Other candidates are - UnsafeRow, UnsafeArrayData, UnsafeMapData, SpecificUnsafeRowJoiner - UTF8StringBuffer - BufferHolder - TaskMemoryManager - OnHeapColumnVector - BytesToBytesMap - CachedBatch - classes for hash - others. ## How was this patch tested? Added `UnsafeMemoryAllocator` Author: Kazuaki Ishizaki Closes #19222 from kiszk/SPARK-10399. --- .../sql/catalyst/expressions/HiveHasher.java | 12 +- .../org/apache/spark/unsafe/Platform.java | 2 +- .../spark/unsafe/array/ByteArrayMethods.java | 13 +- .../apache/spark/unsafe/array/LongArray.java | 17 +- .../spark/unsafe/hash/Murmur3_x86_32.java | 45 +++-- .../unsafe/memory/ByteArrayMemoryBlock.java | 128 +++++++++++++ .../unsafe/memory/HeapMemoryAllocator.java | 19 +- .../spark/unsafe/memory/MemoryAllocator.java | 4 +- .../spark/unsafe/memory/MemoryBlock.java | 157 ++++++++++++++-- .../spark/unsafe/memory/MemoryLocation.java | 54 ------ .../unsafe/memory/OffHeapMemoryBlock.java | 105 +++++++++++ .../unsafe/memory/OnHeapMemoryBlock.java | 132 +++++++++++++ .../unsafe/memory/UnsafeMemoryAllocator.java | 21 ++- .../apache/spark/unsafe/types/UTF8String.java | 148 +++++++-------- .../spark/unsafe/PlatformUtilSuite.java | 4 +- .../spark/unsafe/array/LongArraySuite.java | 5 +- .../unsafe/hash/Murmur3_x86_32Suite.java | 18 ++ .../spark/unsafe/memory/MemoryBlockSuite.java | 175 ++++++++++++++++++ .../spark/unsafe/types/UTF8StringSuite.java | 29 +-- .../spark/memory/TaskMemoryManager.java | 22 +-- .../shuffle/sort/ShuffleInMemorySorter.java | 14 +- .../shuffle/sort/ShuffleSortDataFormat.java | 11 +- .../unsafe/sort/UnsafeExternalSorter.java | 2 +- .../unsafe/sort/UnsafeInMemorySorter.java | 13 +- .../spark/memory/TaskMemoryManagerSuite.java | 2 +- .../util/collection/ExternalSorterSuite.scala | 7 +- .../unsafe/sort/RadixSortSuite.scala | 10 +- .../spark/ml/feature/FeatureHasher.scala | 5 +- .../spark/mllib/feature/HashingTF.scala | 2 +- .../catalyst/expressions/UnsafeArrayData.java | 4 +- .../sql/catalyst/expressions/UnsafeRow.java | 4 +- .../spark/sql/catalyst/expressions/XXH64.java | 46 +++-- .../spark/sql/catalyst/expressions/hash.scala | 39 ++-- .../catalyst/expressions/HiveHasherSuite.java | 20 +- .../sql/catalyst/expressions/XXH64Suite.java | 18 +- .../vectorized/OffHeapColumnVector.java | 3 +- .../sql/vectorized/ArrowColumnVector.java | 6 +- .../execution/benchmark/SortBenchmark.scala | 16 +- .../sql/execution/python/RowQueueSuite.scala | 4 +- 39 files changed, 1002 insertions(+), 334 deletions(-) create mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java delete mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java create mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java create mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java create mode 100644 common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java index 73577437ac506..5d905943a3aa7 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; /** * Simulates Hive's hashing function from Hive v1.2.1 @@ -38,12 +39,17 @@ public static int hashLong(long input) { return (int) ((input >>> 32) ^ input); } - public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) { + public static int hashUnsafeBytesBlock(MemoryBlock mb) { + long lengthInBytes = mb.size(); assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int result = 0; - for (int i = 0; i < lengthInBytes; i++) { - result = (result * 31) + (int) Platform.getByte(base, offset + i); + for (long i = 0; i < lengthInBytes; i++) { + result = (result * 31) + (int) mb.getByte(i); } return result; } + + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) { + return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes)); + } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index aca6fca00c48b..54dcadf3a7754 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -187,7 +187,7 @@ public static void setMemory(long address, byte value, long size) { } public static void copyMemory( - Object src, long srcOffset, Object dst, long dstOffset, long length) { + Object src, long srcOffset, Object dst, long dstOffset, long length) { // Check if dstOffset is before or after srcOffset to determine if we should copy // forward or backwards. This is necessary in case src and dst overlap. if (dstOffset < srcOffset) { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index a6b1f7a16d605..c334c9651cf6b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -18,6 +18,7 @@ package org.apache.spark.unsafe.array; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; public class ByteArrayMethods { @@ -48,6 +49,16 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { public static int MAX_ROUNDED_ARRAY_LENGTH = Integer.MAX_VALUE - 15; private static final boolean unaligned = Platform.unaligned(); + /** + * MemoryBlock equality check for MemoryBlocks. + * @return true if the arrays are equal, false otherwise + */ + public static boolean arrayEqualsBlock( + MemoryBlock leftBase, long leftOffset, MemoryBlock rightBase, long rightOffset, final long length) { + return arrayEquals(leftBase.getBaseObject(), leftBase.getBaseOffset() + leftOffset, + rightBase.getBaseObject(), rightBase.getBaseOffset() + rightOffset, length); + } + /** * Optimized byte array equality check for byte arrays. * @return true if the arrays are equal, false otherwise @@ -56,7 +67,7 @@ public static boolean arrayEquals( Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) { int i = 0; - // check if stars align and we can get both offsets to be aligned + // check if starts align and we can get both offsets to be aligned if ((leftOffset % 8) == (rightOffset % 8)) { while ((leftOffset + i) % 8 != 0 && i < length) { if (Platform.getByte(leftBase, leftOffset + i) != diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java index 2cd39bd60c2ac..b74d2de0691d5 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -17,7 +17,6 @@ package org.apache.spark.unsafe.array; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; /** @@ -33,16 +32,12 @@ public final class LongArray { private static final long WIDTH = 8; private final MemoryBlock memory; - private final Object baseObj; - private final long baseOffset; private final long length; public LongArray(MemoryBlock memory) { assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size >= Integer.MAX_VALUE elements"; this.memory = memory; - this.baseObj = memory.getBaseObject(); - this.baseOffset = memory.getBaseOffset(); this.length = memory.size() / WIDTH; } @@ -51,11 +46,11 @@ public MemoryBlock memoryBlock() { } public Object getBaseObject() { - return baseObj; + return memory.getBaseObject(); } public long getBaseOffset() { - return baseOffset; + return memory.getBaseOffset(); } /** @@ -69,8 +64,8 @@ public long size() { * Fill this all with 0L. */ public void zeroOut() { - for (long off = baseOffset; off < baseOffset + length * WIDTH; off += WIDTH) { - Platform.putLong(baseObj, off, 0); + for (long off = 0; off < length * WIDTH; off += WIDTH) { + memory.putLong(off, 0); } } @@ -80,7 +75,7 @@ public void zeroOut() { public void set(int index, long value) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - Platform.putLong(baseObj, baseOffset + index * WIDTH, value); + memory.putLong(index * WIDTH, value); } /** @@ -89,6 +84,6 @@ public void set(int index, long value) { public long get(int index) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - return Platform.getLong(baseObj, baseOffset + index * WIDTH); + return memory.getLong(index * WIDTH); } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index d239de6083ad0..f372b19fac119 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -17,7 +17,9 @@ package org.apache.spark.unsafe.hash; -import org.apache.spark.unsafe.Platform; +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.memory.MemoryBlock; /** * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. @@ -49,49 +51,66 @@ public static int hashInt(int input, int seed) { } public int hashUnsafeWords(Object base, long offset, int lengthInBytes) { - return hashUnsafeWords(base, offset, lengthInBytes, seed); + return hashUnsafeWordsBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); } - public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { + public static int hashUnsafeWordsBlock(MemoryBlock base, int seed) { // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. + int lengthInBytes = Ints.checkedCast(base.size()); assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; - int h1 = hashBytesByInt(base, offset, lengthInBytes, seed); + int h1 = hashBytesByIntBlock(base, seed); return fmix(h1, lengthInBytes); } - public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { + // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. + return hashUnsafeWordsBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); + } + + public static int hashUnsafeBytesBlock(MemoryBlock base, int seed) { // This is not compatible with original and another implementations. // But remain it for backward compatibility for the components existing before 2.3. + int lengthInBytes = Ints.checkedCast(base.size()); assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; - int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + int h1 = hashBytesByIntBlock(base.subBlock(0, lengthAligned), seed); for (int i = lengthAligned; i < lengthInBytes; i++) { - int halfWord = Platform.getByte(base, offset + i); + int halfWord = base.getByte(i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } return fmix(h1, lengthInBytes); } + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); + } + public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) { + return hashUnsafeBytes2Block(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); + } + + public static int hashUnsafeBytes2Block(MemoryBlock base, int seed) { // This is compatible with original and another implementations. // Use this method for new components after Spark 2.3. - assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; + int lengthInBytes = Ints.checkedCast(base.size()); + assert (lengthInBytes >= 0) : "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; - int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + int h1 = hashBytesByIntBlock(base.subBlock(0, lengthAligned), seed); int k1 = 0; for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) { - k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift; + k1 ^= (base.getByte(i) & 0xFF) << shift; } h1 ^= mixK1(k1); return fmix(h1, lengthInBytes); } - private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { + private static int hashBytesByIntBlock(MemoryBlock base, int seed) { + long lengthInBytes = base.size(); assert (lengthInBytes % 4 == 0); int h1 = seed; - for (int i = 0; i < lengthInBytes; i += 4) { - int halfWord = Platform.getInt(base, offset + i); + for (long i = 0; i < lengthInBytes; i += 4) { + int halfWord = base.getInt(i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java new file mode 100644 index 0000000000000..99a9868a49a79 --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.Platform; + +/** + * A consecutive block of memory with a byte array on Java heap. + */ +public final class ByteArrayMemoryBlock extends MemoryBlock { + + private final byte[] array; + + public ByteArrayMemoryBlock(byte[] obj, long offset, long size) { + super(obj, offset, size); + this.array = obj; + assert(offset + size <= Platform.BYTE_ARRAY_OFFSET + obj.length) : + "The sum of size " + size + " and offset " + offset + " should not be larger than " + + "the size of the given memory space " + (obj.length + Platform.BYTE_ARRAY_OFFSET); + } + + public ByteArrayMemoryBlock(long length) { + this(new byte[Ints.checkedCast(length)], Platform.BYTE_ARRAY_OFFSET, length); + } + + @Override + public MemoryBlock subBlock(long offset, long size) { + checkSubBlockRange(offset, size); + if (offset == 0 && size == this.size()) return this; + return new ByteArrayMemoryBlock(array, this.offset + offset, size); + } + + public byte[] getByteArray() { return array; } + + /** + * Creates a memory block pointing to the memory used by the byte array. + */ + public static ByteArrayMemoryBlock fromArray(final byte[] array) { + return new ByteArrayMemoryBlock(array, Platform.BYTE_ARRAY_OFFSET, array.length); + } + + @Override + public final int getInt(long offset) { + return Platform.getInt(array, this.offset + offset); + } + + @Override + public final void putInt(long offset, int value) { + Platform.putInt(array, this.offset + offset, value); + } + + @Override + public final boolean getBoolean(long offset) { + return Platform.getBoolean(array, this.offset + offset); + } + + @Override + public final void putBoolean(long offset, boolean value) { + Platform.putBoolean(array, this.offset + offset, value); + } + + @Override + public final byte getByte(long offset) { + return array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)]; + } + + @Override + public final void putByte(long offset, byte value) { + array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)] = value; + } + + @Override + public final short getShort(long offset) { + return Platform.getShort(array, this.offset + offset); + } + + @Override + public final void putShort(long offset, short value) { + Platform.putShort(array, this.offset + offset, value); + } + + @Override + public final long getLong(long offset) { + return Platform.getLong(array, this.offset + offset); + } + + @Override + public final void putLong(long offset, long value) { + Platform.putLong(array, this.offset + offset, value); + } + + @Override + public final float getFloat(long offset) { + return Platform.getFloat(array, this.offset + offset); + } + + @Override + public final void putFloat(long offset, float value) { + Platform.putFloat(array, this.offset + offset, value); + } + + @Override + public final double getDouble(long offset) { + return Platform.getDouble(array, this.offset + offset); + } + + @Override + public final void putDouble(long offset, double value) { + Platform.putDouble(array, this.offset + offset, value); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index 2733760dd19ef..acf28fd7ee59b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -58,7 +58,7 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { final long[] array = arrayReference.get(); if (array != null) { assert (array.length * 8L >= size); - MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); + MemoryBlock memory = OnHeapMemoryBlock.fromArray(array, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -70,7 +70,7 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { } } long[] array = new long[numWords]; - MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); + MemoryBlock memory = OnHeapMemoryBlock.fromArray(array, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -79,12 +79,13 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { @Override public void free(MemoryBlock memory) { - assert (memory.obj != null) : + assert(memory instanceof OnHeapMemoryBlock); + assert (memory.getBaseObject() != null) : "baseObject was null; are you trying to use the on-heap allocator to free off-heap memory?"; - assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + assert (memory.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : "page has already been freed"; - assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) - || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + assert ((memory.getPageNumber() == MemoryBlock.NO_PAGE_NUMBER) + || (memory.getPageNumber() == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator " + "free()"; @@ -94,12 +95,12 @@ public void free(MemoryBlock memory) { } // Mark the page as freed (so we can detect double-frees). - memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; + memory.setPageNumber(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER); // As an additional layer of defense against use-after-free bugs, we mutate the // MemoryBlock to null out its reference to the long[] array. - long[] array = (long[]) memory.obj; - memory.setObjAndOffset(null, 0); + long[] array = ((OnHeapMemoryBlock)memory).getLongArray(); + memory.resetObjAndOffset(); long alignedSize = ((size + 7) / 8) * 8; if (shouldPool(alignedSize)) { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java index 7b588681d9790..38315fb97b46a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java @@ -38,7 +38,7 @@ public interface MemoryAllocator { void free(MemoryBlock memory); - MemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); + UnsafeMemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); - MemoryAllocator HEAP = new HeapMemoryAllocator(); + HeapMemoryAllocator HEAP = new HeapMemoryAllocator(); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index c333857358d30..b086941108522 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -22,10 +22,10 @@ import org.apache.spark.unsafe.Platform; /** - * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. + * A representation of a consecutive memory block in Spark. It defines the common interfaces + * for memory accessing and mutating. */ -public class MemoryBlock extends MemoryLocation { - +public abstract class MemoryBlock { /** Special `pageNumber` value for pages which were not allocated by TaskMemoryManagers */ public static final int NO_PAGE_NUMBER = -1; @@ -45,38 +45,163 @@ public class MemoryBlock extends MemoryLocation { */ public static final int FREED_IN_ALLOCATOR_PAGE_NUMBER = -3; - private final long length; + @Nullable + protected Object obj; + + protected long offset; + + protected long length; /** * Optional page number; used when this MemoryBlock represents a page allocated by a - * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager, - * which lives in a different package. + * TaskMemoryManager. This field can be updated using setPageNumber method so that + * this can be modified by the TaskMemoryManager, which lives in a different package. */ - public int pageNumber = NO_PAGE_NUMBER; + private int pageNumber = NO_PAGE_NUMBER; - public MemoryBlock(@Nullable Object obj, long offset, long length) { - super(obj, offset); + protected MemoryBlock(@Nullable Object obj, long offset, long length) { + if (offset < 0 || length < 0) { + throw new IllegalArgumentException( + "Length " + length + " and offset " + offset + "must be non-negative"); + } + this.obj = obj; + this.offset = offset; this.length = length; } + protected MemoryBlock() { + this(null, 0, 0); + } + + public final Object getBaseObject() { + return obj; + } + + public final long getBaseOffset() { + return offset; + } + + public void resetObjAndOffset() { + this.obj = null; + this.offset = 0; + } + /** * Returns the size of the memory block. */ - public long size() { + public final long size() { return length; } - /** - * Creates a memory block pointing to the memory used by the long array. - */ - public static MemoryBlock fromLongArray(final long[] array) { - return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8L); + public final void setPageNumber(int pageNum) { + pageNumber = pageNum; + } + + public final int getPageNumber() { + return pageNumber; } /** * Fills the memory block with the specified byte value. */ - public void fill(byte value) { + public final void fill(byte value) { Platform.setMemory(obj, offset, length, value); } + + /** + * Instantiate MemoryBlock for given object type with new offset + */ + public final static MemoryBlock allocateFromObject(Object obj, long offset, long length) { + MemoryBlock mb = null; + if (obj instanceof byte[]) { + byte[] array = (byte[])obj; + mb = new ByteArrayMemoryBlock(array, offset, length); + } else if (obj instanceof long[]) { + long[] array = (long[])obj; + mb = new OnHeapMemoryBlock(array, offset, length); + } else if (obj == null) { + // we assume that to pass null pointer means off-heap + mb = new OffHeapMemoryBlock(offset, length); + } else { + throw new UnsupportedOperationException( + "Instantiate MemoryBlock for type " + obj.getClass() + " is not supported now"); + } + return mb; + } + + /** + * Just instantiate the sub-block with the same type of MemoryBlock with the new size and relative + * offset from the original offset. The data is not copied. + * If parameters are invalid, an exception is thrown. + */ + public abstract MemoryBlock subBlock(long offset, long size); + + protected void checkSubBlockRange(long offset, long size) { + if (offset < 0 || size < 0) { + throw new ArrayIndexOutOfBoundsException( + "Size " + size + " and offset " + offset + " must be non-negative"); + } + if (offset + size > length) { + throw new ArrayIndexOutOfBoundsException("The sum of size " + size + " and offset " + + offset + " should not be larger than the length " + length + " in the MemoryBlock"); + } + } + + /** + * getXXX/putXXX does not ensure guarantee behavior if the offset is invalid. e.g cause illegal + * memory access, throw an exception, or etc. + * getXXX/putXXX uses an index based on this.offset that includes the size of metadata such as + * JVM object header. The offset is 0-based and is expected as an logical offset in the memory + * block. + */ + public abstract int getInt(long offset); + + public abstract void putInt(long offset, int value); + + public abstract boolean getBoolean(long offset); + + public abstract void putBoolean(long offset, boolean value); + + public abstract byte getByte(long offset); + + public abstract void putByte(long offset, byte value); + + public abstract short getShort(long offset); + + public abstract void putShort(long offset, short value); + + public abstract long getLong(long offset); + + public abstract void putLong(long offset, long value); + + public abstract float getFloat(long offset); + + public abstract void putFloat(long offset, float value); + + public abstract double getDouble(long offset); + + public abstract void putDouble(long offset, double value); + + public static final void copyMemory( + MemoryBlock src, long srcOffset, MemoryBlock dst, long dstOffset, long length) { + assert(srcOffset + length <= src.length && dstOffset + length <= dst.length); + Platform.copyMemory(src.getBaseObject(), src.getBaseOffset() + srcOffset, + dst.getBaseObject(), dst.getBaseOffset() + dstOffset, length); + } + + public static final void copyMemory(MemoryBlock src, MemoryBlock dst, long length) { + assert(length <= src.length && length <= dst.length); + Platform.copyMemory(src.getBaseObject(), src.getBaseOffset(), + dst.getBaseObject(), dst.getBaseOffset(), length); + } + + public final void copyFrom(Object src, long srcOffset, long dstOffset, long length) { + assert(length <= this.length - srcOffset); + Platform.copyMemory(src, srcOffset, obj, offset + dstOffset, length); + } + + public final void writeTo(long srcOffset, Object dst, long dstOffset, long length) { + assert(length <= this.length - srcOffset); + Platform.copyMemory(obj, offset + srcOffset, dst, dstOffset, length); + } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java deleted file mode 100644 index 74ebc87dc978c..0000000000000 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.memory; - -import javax.annotation.Nullable; - -/** - * A memory location. Tracked either by a memory address (with off-heap allocation), - * or by an offset from a JVM object (in-heap allocation). - */ -public class MemoryLocation { - - @Nullable - Object obj; - - long offset; - - public MemoryLocation(@Nullable Object obj, long offset) { - this.obj = obj; - this.offset = offset; - } - - public MemoryLocation() { - this(null, 0); - } - - public void setObjAndOffset(Object newObj, long newOffset) { - this.obj = newObj; - this.offset = newOffset; - } - - public final Object getBaseObject() { - return obj; - } - - public final long getBaseOffset() { - return offset; - } -} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java new file mode 100644 index 0000000000000..f90f62bf21dcb --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import org.apache.spark.unsafe.Platform; + +public class OffHeapMemoryBlock extends MemoryBlock { + static public final OffHeapMemoryBlock NULL = new OffHeapMemoryBlock(0, 0); + + public OffHeapMemoryBlock(long address, long size) { + super(null, address, size); + } + + @Override + public MemoryBlock subBlock(long offset, long size) { + checkSubBlockRange(offset, size); + if (offset == 0 && size == this.size()) return this; + return new OffHeapMemoryBlock(this.offset + offset, size); + } + + @Override + public final int getInt(long offset) { + return Platform.getInt(null, this.offset + offset); + } + + @Override + public final void putInt(long offset, int value) { + Platform.putInt(null, this.offset + offset, value); + } + + @Override + public final boolean getBoolean(long offset) { + return Platform.getBoolean(null, this.offset + offset); + } + + @Override + public final void putBoolean(long offset, boolean value) { + Platform.putBoolean(null, this.offset + offset, value); + } + + @Override + public final byte getByte(long offset) { + return Platform.getByte(null, this.offset + offset); + } + + @Override + public final void putByte(long offset, byte value) { + Platform.putByte(null, this.offset + offset, value); + } + + @Override + public final short getShort(long offset) { + return Platform.getShort(null, this.offset + offset); + } + + @Override + public final void putShort(long offset, short value) { + Platform.putShort(null, this.offset + offset, value); + } + + @Override + public final long getLong(long offset) { + return Platform.getLong(null, this.offset + offset); + } + + @Override + public final void putLong(long offset, long value) { + Platform.putLong(null, this.offset + offset, value); + } + + @Override + public final float getFloat(long offset) { + return Platform.getFloat(null, this.offset + offset); + } + + @Override + public final void putFloat(long offset, float value) { + Platform.putFloat(null, this.offset + offset, value); + } + + @Override + public final double getDouble(long offset) { + return Platform.getDouble(null, this.offset + offset); + } + + @Override + public final void putDouble(long offset, double value) { + Platform.putDouble(null, this.offset + offset, value); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java new file mode 100644 index 0000000000000..12f67c7bd593e --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.Platform; + +/** + * A consecutive block of memory with a long array on Java heap. + */ +public final class OnHeapMemoryBlock extends MemoryBlock { + + private final long[] array; + + public OnHeapMemoryBlock(long[] obj, long offset, long size) { + super(obj, offset, size); + this.array = obj; + assert(offset + size <= obj.length * 8L + Platform.LONG_ARRAY_OFFSET) : + "The sum of size " + size + " and offset " + offset + " should not be larger than " + + "the size of the given memory space " + (obj.length * 8L + Platform.LONG_ARRAY_OFFSET); + } + + public OnHeapMemoryBlock(long size) { + this(new long[Ints.checkedCast((size + 7) / 8)], Platform.LONG_ARRAY_OFFSET, size); + } + + @Override + public MemoryBlock subBlock(long offset, long size) { + checkSubBlockRange(offset, size); + if (offset == 0 && size == this.size()) return this; + return new OnHeapMemoryBlock(array, this.offset + offset, size); + } + + public long[] getLongArray() { return array; } + + /** + * Creates a memory block pointing to the memory used by the long array. + */ + public static OnHeapMemoryBlock fromArray(final long[] array) { + return new OnHeapMemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8L); + } + + public static OnHeapMemoryBlock fromArray(final long[] array, long size) { + return new OnHeapMemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); + } + + @Override + public final int getInt(long offset) { + return Platform.getInt(array, this.offset + offset); + } + + @Override + public final void putInt(long offset, int value) { + Platform.putInt(array, this.offset + offset, value); + } + + @Override + public final boolean getBoolean(long offset) { + return Platform.getBoolean(array, this.offset + offset); + } + + @Override + public final void putBoolean(long offset, boolean value) { + Platform.putBoolean(array, this.offset + offset, value); + } + + @Override + public final byte getByte(long offset) { + return Platform.getByte(array, this.offset + offset); + } + + @Override + public final void putByte(long offset, byte value) { + Platform.putByte(array, this.offset + offset, value); + } + + @Override + public final short getShort(long offset) { + return Platform.getShort(array, this.offset + offset); + } + + @Override + public final void putShort(long offset, short value) { + Platform.putShort(array, this.offset + offset, value); + } + + @Override + public final long getLong(long offset) { + return Platform.getLong(array, this.offset + offset); + } + + @Override + public final void putLong(long offset, long value) { + Platform.putLong(array, this.offset + offset, value); + } + + @Override + public final float getFloat(long offset) { + return Platform.getFloat(array, this.offset + offset); + } + + @Override + public final void putFloat(long offset, float value) { + Platform.putFloat(array, this.offset + offset, value); + } + + @Override + public final double getDouble(long offset) { + return Platform.getDouble(array, this.offset + offset); + } + + @Override + public final void putDouble(long offset, double value) { + Platform.putDouble(array, this.offset + offset, value); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java index 4368fb615ba1e..5310bdf2779a9 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -25,9 +25,9 @@ public class UnsafeMemoryAllocator implements MemoryAllocator { @Override - public MemoryBlock allocate(long size) throws OutOfMemoryError { + public OffHeapMemoryBlock allocate(long size) throws OutOfMemoryError { long address = Platform.allocateMemory(size); - MemoryBlock memory = new MemoryBlock(null, address, size); + OffHeapMemoryBlock memory = new OffHeapMemoryBlock(address, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -36,22 +36,25 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { @Override public void free(MemoryBlock memory) { - assert (memory.obj == null) : - "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; - assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + assert(memory instanceof OffHeapMemoryBlock) : + "UnsafeMemoryAllocator can only free OffHeapMemoryBlock."; + if (memory == OffHeapMemoryBlock.NULL) return; + assert (memory.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : "page has already been freed"; - assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) - || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + assert ((memory.getPageNumber() == MemoryBlock.NO_PAGE_NUMBER) + || (memory.getPageNumber() == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : "TMM-allocated pages must be freed via TMM.freePage(), not directly in allocator free()"; if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); } + Platform.freeMemory(memory.offset); + // As an additional layer of defense against use-after-free bugs, we mutate the // MemoryBlock to reset its pointer. - memory.offset = 0; + memory.resetObjAndOffset(); // Mark the page as freed (so we can detect double-frees). - memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; + memory.setPageNumber(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER); } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 5d468aed42337..e9b3d9b045af5 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -30,9 +30,12 @@ import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; +import com.google.common.primitives.Ints; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; import static org.apache.spark.unsafe.Platform.*; @@ -50,12 +53,13 @@ public final class UTF8String implements Comparable, Externalizable, // These are only updated by readExternal() or read() @Nonnull - private Object base; - private long offset; + private MemoryBlock base; + // While numBytes has the same value as base.size(), to keep as int avoids cast from long to int private int numBytes; - public Object getBaseObject() { return base; } - public long getBaseOffset() { return offset; } + public MemoryBlock getMemoryBlock() { return base; } + public Object getBaseObject() { return base.getBaseObject(); } + public long getBaseOffset() { return base.getBaseOffset(); } /** * A char in UTF-8 encoding can take 1-4 bytes depending on the first byte which @@ -108,7 +112,8 @@ public final class UTF8String implements Comparable, Externalizable, */ public static UTF8String fromBytes(byte[] bytes) { if (bytes != null) { - return new UTF8String(bytes, BYTE_ARRAY_OFFSET, bytes.length); + return new UTF8String( + new ByteArrayMemoryBlock(bytes, BYTE_ARRAY_OFFSET, bytes.length)); } else { return null; } @@ -121,19 +126,13 @@ public static UTF8String fromBytes(byte[] bytes) { */ public static UTF8String fromBytes(byte[] bytes, int offset, int numBytes) { if (bytes != null) { - return new UTF8String(bytes, BYTE_ARRAY_OFFSET + offset, numBytes); + return new UTF8String( + new ByteArrayMemoryBlock(bytes, BYTE_ARRAY_OFFSET + offset, numBytes)); } else { return null; } } - /** - * Creates an UTF8String from given address (base and offset) and length. - */ - public static UTF8String fromAddress(Object base, long offset, int numBytes) { - return new UTF8String(base, offset, numBytes); - } - /** * Creates an UTF8String from String. */ @@ -150,16 +149,13 @@ public static UTF8String blankString(int length) { return fromBytes(spaces); } - protected UTF8String(Object base, long offset, int numBytes) { + public UTF8String(MemoryBlock base) { this.base = base; - this.offset = offset; - this.numBytes = numBytes; + this.numBytes = Ints.checkedCast(base.size()); } // for serialization - public UTF8String() { - this(null, 0, 0); - } + public UTF8String() {} /** * Writes the content of this string into a memory address, identified by an object and an offset. @@ -167,7 +163,7 @@ public UTF8String() { * bytes in this string. */ public void writeToMemory(Object target, long targetOffset) { - Platform.copyMemory(base, offset, target, targetOffset, numBytes); + base.writeTo(0, target, targetOffset, numBytes); } public void writeTo(ByteBuffer buffer) { @@ -187,8 +183,9 @@ public void writeTo(ByteBuffer buffer) { */ @Nonnull public ByteBuffer getByteBuffer() { - if (base instanceof byte[] && offset >= BYTE_ARRAY_OFFSET) { - final byte[] bytes = (byte[]) base; + long offset = base.getBaseOffset(); + if (base instanceof ByteArrayMemoryBlock && offset >= BYTE_ARRAY_OFFSET) { + final byte[] bytes = ((ByteArrayMemoryBlock) base).getByteArray(); // the offset includes an object header... this is only needed for unsafe copies final long arrayOffset = offset - BYTE_ARRAY_OFFSET; @@ -255,12 +252,12 @@ public long getPrefix() { long mask = 0; if (IS_LITTLE_ENDIAN) { if (numBytes >= 8) { - p = Platform.getLong(base, offset); + p = base.getLong(0); } else if (numBytes > 4) { - p = Platform.getLong(base, offset); + p = base.getLong(0); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = (long) Platform.getInt(base, offset); + p = (long) base.getInt(0); mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -269,12 +266,12 @@ public long getPrefix() { } else { // byteOrder == ByteOrder.BIG_ENDIAN if (numBytes >= 8) { - p = Platform.getLong(base, offset); + p = base.getLong(0); } else if (numBytes > 4) { - p = Platform.getLong(base, offset); + p = base.getLong(0); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = ((long) Platform.getInt(base, offset)) << 32; + p = ((long) base.getInt(0)) << 32; mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -289,12 +286,13 @@ public long getPrefix() { */ public byte[] getBytes() { // avoid copy if `base` is `byte[]` - if (offset == BYTE_ARRAY_OFFSET && base instanceof byte[] - && ((byte[]) base).length == numBytes) { - return (byte[]) base; + long offset = base.getBaseOffset(); + if (offset == BYTE_ARRAY_OFFSET && base instanceof ByteArrayMemoryBlock + && (((ByteArrayMemoryBlock) base).getByteArray()).length == numBytes) { + return ((ByteArrayMemoryBlock) base).getByteArray(); } else { byte[] bytes = new byte[numBytes]; - copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); + base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, numBytes); return bytes; } } @@ -324,7 +322,7 @@ public UTF8String substring(final int start, final int until) { if (i > j) { byte[] bytes = new byte[i - j]; - copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j); + base.writeTo(j, bytes, BYTE_ARRAY_OFFSET, i - j); return fromBytes(bytes); } else { return EMPTY_UTF8; @@ -365,14 +363,14 @@ public boolean contains(final UTF8String substring) { * Returns the byte at position `i`. */ private byte getByte(int i) { - return Platform.getByte(base, offset + i); + return base.getByte(i); } private boolean matchAt(final UTF8String s, int pos) { if (s.numBytes + pos > numBytes || pos < 0) { return false; } - return ByteArrayMethods.arrayEquals(base, offset + pos, s.base, s.offset, s.numBytes); + return ByteArrayMethods.arrayEqualsBlock(base, pos, s.base, 0, s.numBytes); } public boolean startsWith(final UTF8String prefix) { @@ -499,8 +497,7 @@ public int findInSet(UTF8String match) { for (int i = 0; i < numBytes; i++) { if (getByte(i) == (byte) ',') { if (i - (lastComma + 1) == match.numBytes && - ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, - match.numBytes)) { + ByteArrayMethods.arrayEqualsBlock(base, lastComma + 1, match.base, 0, match.numBytes)) { return n; } lastComma = i; @@ -508,8 +505,7 @@ public int findInSet(UTF8String match) { } } if (numBytes - (lastComma + 1) == match.numBytes && - ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, - match.numBytes)) { + ByteArrayMethods.arrayEqualsBlock(base, lastComma + 1, match.base, 0, match.numBytes)) { return n; } return 0; @@ -524,7 +520,7 @@ public int findInSet(UTF8String match) { private UTF8String copyUTF8String(int start, int end) { int len = end - start + 1; byte[] newBytes = new byte[len]; - copyMemory(base, offset + start, newBytes, BYTE_ARRAY_OFFSET, len); + base.writeTo(start, newBytes, BYTE_ARRAY_OFFSET, len); return UTF8String.fromBytes(newBytes); } @@ -671,8 +667,7 @@ public UTF8String reverse() { int i = 0; // position in byte while (i < numBytes) { int len = numBytesForFirstByte(getByte(i)); - copyMemory(this.base, this.offset + i, result, - BYTE_ARRAY_OFFSET + result.length - i - len, len); + base.writeTo(i, result, BYTE_ARRAY_OFFSET + result.length - i - len, len); i += len; } @@ -686,7 +681,7 @@ public UTF8String repeat(int times) { } byte[] newBytes = new byte[numBytes * times]; - copyMemory(this.base, this.offset, newBytes, BYTE_ARRAY_OFFSET, numBytes); + base.writeTo(0, newBytes, BYTE_ARRAY_OFFSET, numBytes); int copied = 1; while (copied < times) { @@ -723,7 +718,7 @@ public int indexOf(UTF8String v, int start) { if (i + v.numBytes > numBytes) { return -1; } - if (ByteArrayMethods.arrayEquals(base, offset + i, v.base, v.offset, v.numBytes)) { + if (ByteArrayMethods.arrayEqualsBlock(base, i, v.base, 0, v.numBytes)) { return c; } i += numBytesForFirstByte(getByte(i)); @@ -739,7 +734,7 @@ public int indexOf(UTF8String v, int start) { private int find(UTF8String str, int start) { assert (str.numBytes > 0); while (start <= numBytes - str.numBytes) { - if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { + if (ByteArrayMethods.arrayEqualsBlock(base, start, str.base, 0, str.numBytes)) { return start; } start += 1; @@ -753,7 +748,7 @@ private int find(UTF8String str, int start) { private int rfind(UTF8String str, int start) { assert (str.numBytes > 0); while (start >= 0) { - if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { + if (ByteArrayMethods.arrayEqualsBlock(base, start, str.base, 0, str.numBytes)) { return start; } start -= 1; @@ -786,7 +781,7 @@ public UTF8String subStringIndex(UTF8String delim, int count) { return EMPTY_UTF8; } byte[] bytes = new byte[idx]; - copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, idx); + base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, idx); return fromBytes(bytes); } else { @@ -806,7 +801,7 @@ public UTF8String subStringIndex(UTF8String delim, int count) { } int size = numBytes - delim.numBytes - idx; byte[] bytes = new byte[size]; - copyMemory(base, offset + idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size); + base.writeTo(idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size); return fromBytes(bytes); } } @@ -829,15 +824,15 @@ public UTF8String rpad(int len, UTF8String pad) { UTF8String remain = pad.substring(0, spaces - padChars * count); byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; - copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET, this.numBytes); + base.writeTo(0, data, BYTE_ARRAY_OFFSET, this.numBytes); int offset = this.numBytes; int idx = 0; while (idx < count) { - copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); + pad.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++ idx; offset += pad.numBytes; } - copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); + remain.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); return UTF8String.fromBytes(data); } @@ -865,13 +860,13 @@ public UTF8String lpad(int len, UTF8String pad) { int offset = 0; int idx = 0; while (idx < count) { - copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); + pad.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++ idx; offset += pad.numBytes; } - copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); + remain.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); offset += remain.numBytes; - copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET + offset, numBytes()); + base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, numBytes()); return UTF8String.fromBytes(data); } @@ -896,8 +891,8 @@ public static UTF8String concat(UTF8String... inputs) { int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].numBytes; - copyMemory( - inputs[i].base, inputs[i].offset, + inputs[i].base.writeTo( + 0, result, BYTE_ARRAY_OFFSET + offset, len); offset += len; @@ -936,8 +931,8 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { for (int i = 0, j = 0; i < inputs.length; i++) { if (inputs[i] != null) { int len = inputs[i].numBytes; - copyMemory( - inputs[i].base, inputs[i].offset, + inputs[i].base.writeTo( + 0, result, BYTE_ARRAY_OFFSET + offset, len); offset += len; @@ -945,8 +940,8 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { j++; // Add separator if this is not the last input. if (j < numInputs) { - copyMemory( - separator.base, separator.offset, + separator.base.writeTo( + 0, result, BYTE_ARRAY_OFFSET + offset, separator.numBytes); offset += separator.numBytes; @@ -1220,7 +1215,7 @@ public UTF8String clone() { public UTF8String copy() { byte[] bytes = new byte[numBytes]; - copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); + base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, numBytes); return fromBytes(bytes); } @@ -1228,11 +1223,10 @@ public UTF8String copy() { public int compareTo(@Nonnull final UTF8String other) { int len = Math.min(numBytes, other.numBytes); int wordMax = (len / 8) * 8; - long roffset = other.offset; - Object rbase = other.base; + MemoryBlock rbase = other.base; for (int i = 0; i < wordMax; i += 8) { - long left = getLong(base, offset + i); - long right = getLong(rbase, roffset + i); + long left = base.getLong(i); + long right = rbase.getLong(i); if (left != right) { if (IS_LITTLE_ENDIAN) { return Long.compareUnsigned(Long.reverseBytes(left), Long.reverseBytes(right)); @@ -1243,7 +1237,7 @@ public int compareTo(@Nonnull final UTF8String other) { } for (int i = wordMax; i < len; i++) { // In UTF-8, the byte should be unsigned, so we should compare them as unsigned int. - int res = (getByte(i) & 0xFF) - (Platform.getByte(rbase, roffset + i) & 0xFF); + int res = (getByte(i) & 0xFF) - (rbase.getByte(i) & 0xFF); if (res != 0) { return res; } @@ -1262,7 +1256,7 @@ public boolean equals(final Object other) { if (numBytes != o.numBytes) { return false; } - return ByteArrayMethods.arrayEquals(base, offset, o.base, o.offset, numBytes); + return ByteArrayMethods.arrayEqualsBlock(base, 0, o.base, 0, numBytes); } else { return false; } @@ -1318,8 +1312,8 @@ public int levenshteinDistance(UTF8String other) { num_bytes_j != numBytesForFirstByte(s.getByte(i_bytes))) { cost = 1; } else { - cost = (ByteArrayMethods.arrayEquals(t.base, t.offset + j_bytes, s.base, - s.offset + i_bytes, num_bytes_j)) ? 0 : 1; + cost = (ByteArrayMethods.arrayEqualsBlock(t.base, j_bytes, s.base, + i_bytes, num_bytes_j)) ? 0 : 1; } d[i + 1] = Math.min(Math.min(d[i] + 1, p[i + 1] + 1), p[i] + cost); } @@ -1334,7 +1328,7 @@ public int levenshteinDistance(UTF8String other) { @Override public int hashCode() { - return Murmur3_x86_32.hashUnsafeBytes(base, offset, numBytes, 42); + return Murmur3_x86_32.hashUnsafeBytesBlock(base,42); } /** @@ -1397,10 +1391,10 @@ public void writeExternal(ObjectOutput out) throws IOException { } public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - offset = BYTE_ARRAY_OFFSET; numBytes = in.readInt(); - base = new byte[numBytes]; - in.readFully((byte[]) base); + byte[] bytes = new byte[numBytes]; + in.readFully(bytes); + base = ByteArrayMemoryBlock.fromArray(bytes); } @Override @@ -1412,10 +1406,10 @@ public void write(Kryo kryo, Output out) { @Override public void read(Kryo kryo, Input in) { - this.offset = BYTE_ARRAY_OFFSET; - this.numBytes = in.readInt(); - this.base = new byte[numBytes]; - in.read((byte[]) base); + numBytes = in.readInt(); + byte[] bytes = new byte[numBytes]; + in.read(bytes); + base = ByteArrayMemoryBlock.fromArray(bytes); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 3ad9ac7b4de9c..583a148b3845d 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -81,7 +81,7 @@ public void freeingOnHeapMemoryBlockResetsBaseObjectAndOffset() { MemoryAllocator.HEAP.free(block); Assert.assertNull(block.getBaseObject()); Assert.assertEquals(0, block.getBaseOffset()); - Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.getPageNumber()); } @Test @@ -92,7 +92,7 @@ public void freeingOffHeapMemoryBlockResetsOffset() { MemoryAllocator.UNSAFE.free(block); Assert.assertNull(block.getBaseObject()); Assert.assertEquals(0, block.getBaseOffset()); - Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.getPageNumber()); } @Test(expected = AssertionError.class) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java index fb8e53b3348f3..8c2e98c2bfc54 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java @@ -20,14 +20,13 @@ import org.junit.Assert; import org.junit.Test; -import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; public class LongArraySuite { @Test public void basicTest() { - long[] bytes = new long[2]; - LongArray arr = new LongArray(MemoryBlock.fromLongArray(bytes)); + LongArray arr = new LongArray(new OnHeapMemoryBlock(16)); arr.set(0, 1L); arr.set(1, 2L); arr.set(1, 3L); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java index 6348a73bf3895..d7ed005db1891 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -70,6 +70,24 @@ public void testKnownBytesInputs() { Murmur3_x86_32.hashUnsafeBytes2(tes, Platform.BYTE_ARRAY_OFFSET, tes.length, 0)); } + @Test + public void testKnownWordsInputs() { + byte[] bytes = new byte[16]; + long offset = Platform.BYTE_ARRAY_OFFSET; + for (int i = 0; i < 16; i++) { + bytes[i] = 0; + } + Assert.assertEquals(-300363099, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + for (int i = 0; i < 16; i++) { + bytes[i] = -1; + } + Assert.assertEquals(-1210324667, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + for (int i = 0; i < 16; i++) { + bytes[i] = (byte)i; + } + Assert.assertEquals(-634919701, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + } + @Test public void randomizedStressTest() { int size = 65536; diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java new file mode 100644 index 0000000000000..47f05c928f2e5 --- /dev/null +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import org.apache.spark.unsafe.Platform; +import org.junit.Assert; +import org.junit.Test; + +import java.nio.ByteOrder; + +import static org.hamcrest.core.StringContains.containsString; + +public class MemoryBlockSuite { + private static final boolean bigEndianPlatform = + ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); + + private void check(MemoryBlock memory, Object obj, long offset, int length) { + memory.setPageNumber(1); + memory.fill((byte)-1); + memory.putBoolean(0, true); + memory.putByte(1, (byte)127); + memory.putShort(2, (short)257); + memory.putInt(4, 0x20000002); + memory.putLong(8, 0x1234567089ABCDEFL); + memory.putFloat(16, 1.0F); + memory.putLong(20, 0x1234567089ABCDEFL); + memory.putDouble(28, 2.0); + MemoryBlock.copyMemory(memory, 0L, memory, 36, 4); + int[] a = new int[2]; + a[0] = 0x12345678; + a[1] = 0x13579BDF; + memory.copyFrom(a, Platform.INT_ARRAY_OFFSET, 40, 8); + byte[] b = new byte[8]; + memory.writeTo(40, b, Platform.BYTE_ARRAY_OFFSET, 8); + + Assert.assertEquals(obj, memory.getBaseObject()); + Assert.assertEquals(offset, memory.getBaseOffset()); + Assert.assertEquals(length, memory.size()); + Assert.assertEquals(1, memory.getPageNumber()); + Assert.assertEquals(true, memory.getBoolean(0)); + Assert.assertEquals((byte)127, memory.getByte(1 )); + Assert.assertEquals((short)257, memory.getShort(2)); + Assert.assertEquals(0x20000002, memory.getInt(4)); + Assert.assertEquals(0x1234567089ABCDEFL, memory.getLong(8)); + Assert.assertEquals(1.0F, memory.getFloat(16), 0); + Assert.assertEquals(0x1234567089ABCDEFL, memory.getLong(20)); + Assert.assertEquals(2.0, memory.getDouble(28), 0); + Assert.assertEquals(true, memory.getBoolean(36)); + Assert.assertEquals((byte)127, memory.getByte(37 )); + Assert.assertEquals((short)257, memory.getShort(38)); + Assert.assertEquals(a[0], memory.getInt(40)); + Assert.assertEquals(a[1], memory.getInt(44)); + if (bigEndianPlatform) { + Assert.assertEquals(a[0], + ((int)b[0] & 0xff) << 24 | ((int)b[1] & 0xff) << 16 | + ((int)b[2] & 0xff) << 8 | ((int)b[3] & 0xff)); + Assert.assertEquals(a[1], + ((int)b[4] & 0xff) << 24 | ((int)b[5] & 0xff) << 16 | + ((int)b[6] & 0xff) << 8 | ((int)b[7] & 0xff)); + } else { + Assert.assertEquals(a[0], + ((int)b[3] & 0xff) << 24 | ((int)b[2] & 0xff) << 16 | + ((int)b[1] & 0xff) << 8 | ((int)b[0] & 0xff)); + Assert.assertEquals(a[1], + ((int)b[7] & 0xff) << 24 | ((int)b[6] & 0xff) << 16 | + ((int)b[5] & 0xff) << 8 | ((int)b[4] & 0xff)); + } + for (int i = 48; i < memory.size(); i++) { + Assert.assertEquals((byte) -1, memory.getByte(i)); + } + + assert(memory.subBlock(0, memory.size()) == memory); + + try { + memory.subBlock(-8, 8); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("non-negative")); + } + + try { + memory.subBlock(0, -8); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("non-negative")); + } + + try { + memory.subBlock(0, length + 8); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); + } + + try { + memory.subBlock(8, length - 4); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); + } + + try { + memory.subBlock(length + 8, 4); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); + } + } + + @Test + public void ByteArrayMemoryBlockTest() { + byte[] obj = new byte[56]; + long offset = Platform.BYTE_ARRAY_OFFSET; + int length = obj.length; + + MemoryBlock memory = new ByteArrayMemoryBlock(obj, offset, length); + check(memory, obj, offset, length); + + memory = ByteArrayMemoryBlock.fromArray(obj); + check(memory, obj, offset, length); + + obj = new byte[112]; + memory = new ByteArrayMemoryBlock(obj, offset, length); + check(memory, obj, offset, length); + } + + @Test + public void OnHeapMemoryBlockTest() { + long[] obj = new long[7]; + long offset = Platform.LONG_ARRAY_OFFSET; + int length = obj.length * 8; + + MemoryBlock memory = new OnHeapMemoryBlock(obj, offset, length); + check(memory, obj, offset, length); + + memory = OnHeapMemoryBlock.fromArray(obj); + check(memory, obj, offset, length); + + obj = new long[14]; + memory = new OnHeapMemoryBlock(obj, offset, length); + check(memory, obj, offset, length); + } + + @Test + public void OffHeapArrayMemoryBlockTest() { + MemoryAllocator memoryAllocator = new UnsafeMemoryAllocator(); + MemoryBlock memory = memoryAllocator.allocate(56); + Object obj = memory.getBaseObject(); + long offset = memory.getBaseOffset(); + int length = 56; + + check(memory, obj, offset, length); + + long address = Platform.allocateMemory(112); + memory = new OffHeapMemoryBlock(address, length); + obj = memory.getBaseObject(); + offset = memory.getBaseOffset(); + check(memory, obj, offset, length); + } +} diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 7c34d419574ef..bad908fcaf136 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -26,6 +26,9 @@ import com.google.common.collect.ImmutableMap; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; import org.junit.Test; import static org.junit.Assert.*; @@ -519,7 +522,8 @@ public void writeToOutputStreamUnderflow() throws IOException { final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8); for (int i = 1; i <= Platform.BYTE_ARRAY_OFFSET; ++i) { - UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i) + new UTF8String( + new ByteArrayMemoryBlock(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i)) .writeTo(outputStream); final ByteBuffer buffer = ByteBuffer.wrap(outputStream.toByteArray(), i, test.length); assertEquals("01234567", StandardCharsets.UTF_8.decode(buffer).toString()); @@ -534,7 +538,7 @@ public void writeToOutputStreamSlice() throws IOException { for (int i = 0; i < test.length; ++i) { for (int j = 0; j < test.length - i; ++j) { - UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET + i, j) + new UTF8String(ByteArrayMemoryBlock.fromArray(test).subBlock(i, j)) .writeTo(outputStream); assertArrayEquals(Arrays.copyOfRange(test, i, i + j), outputStream.toByteArray()); @@ -565,7 +569,7 @@ public void writeToOutputStreamOverflow() throws IOException { for (final long offset : offsets) { try { - fromAddress(test, BYTE_ARRAY_OFFSET + offset, test.length) + new UTF8String(ByteArrayMemoryBlock.fromArray(test).subBlock(offset, test.length)) .writeTo(outputStream); throw new IllegalStateException(Long.toString(offset)); @@ -592,26 +596,25 @@ public void writeToOutputStream() throws IOException { } @Test - public void writeToOutputStreamIntArray() throws IOException { + public void writeToOutputStreamLongArray() throws IOException { // verify that writes work on objects that are not byte arrays - final ByteBuffer buffer = StandardCharsets.UTF_8.encode("大千世界"); + final ByteBuffer buffer = StandardCharsets.UTF_8.encode("3千大千世界"); buffer.position(0); buffer.order(ByteOrder.nativeOrder()); final int length = buffer.limit(); - assertEquals(12, length); + assertEquals(16, length); - final int ints = length / 4; - final int[] array = new int[ints]; + final int longs = length / 8; + final long[] array = new long[longs]; - for (int i = 0; i < ints; ++i) { - array[i] = buffer.getInt(); + for (int i = 0; i < longs; ++i) { + array[i] = buffer.getLong(); } final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - fromAddress(array, Platform.INT_ARRAY_OFFSET, length) - .writeTo(outputStream); - assertEquals("大千世界", outputStream.toString("UTF-8")); + new UTF8String(OnHeapMemoryBlock.fromArray(array)).writeTo(outputStream); + assertEquals("3千大千世界", outputStream.toString("UTF-8")); } @Test diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index d07faf1da1248..8651a639c07f7 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -311,7 +311,7 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { // this could trigger spilling to free some pages. return allocatePage(size, consumer); } - page.pageNumber = pageNumber; + page.setPageNumber(pageNumber); pageTable[pageNumber] = page; if (logger.isTraceEnabled()) { logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired); @@ -323,25 +323,25 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}. */ public void freePage(MemoryBlock page, MemoryConsumer consumer) { - assert (page.pageNumber != MemoryBlock.NO_PAGE_NUMBER) : + assert (page.getPageNumber() != MemoryBlock.NO_PAGE_NUMBER) : "Called freePage() on memory that wasn't allocated with allocatePage()"; - assert (page.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + assert (page.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : "Called freePage() on a memory block that has already been freed"; - assert (page.pageNumber != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) : + assert (page.getPageNumber() != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) : "Called freePage() on a memory block that has already been freed"; - assert(allocatedPages.get(page.pageNumber)); - pageTable[page.pageNumber] = null; + assert(allocatedPages.get(page.getPageNumber())); + pageTable[page.getPageNumber()] = null; synchronized (this) { - allocatedPages.clear(page.pageNumber); + allocatedPages.clear(page.getPageNumber()); } if (logger.isTraceEnabled()) { - logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); + logger.trace("Freed page number {} ({} bytes)", page.getPageNumber(), page.size()); } long pageSize = page.size(); // Clear the page number before passing the block to the MemoryAllocator's free(). // Doing this allows the MemoryAllocator to detect when a TaskMemoryManager-managed // page has been inappropriately directly freed without calling TMM.freePage(). - page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; + page.setPageNumber(MemoryBlock.FREED_IN_TMM_PAGE_NUMBER); memoryManager.tungstenMemoryAllocator().free(page); releaseExecutionMemory(pageSize, consumer); } @@ -363,7 +363,7 @@ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { // relative to the page's base offset; this relative offset will fit in 51 bits. offsetInPage -= page.getBaseOffset(); } - return encodePageNumberAndOffset(page.pageNumber, offsetInPage); + return encodePageNumberAndOffset(page.getPageNumber(), offsetInPage); } @VisibleForTesting @@ -434,7 +434,7 @@ public long cleanUpAllAllocatedMemory() { for (MemoryBlock page : pageTable) { if (page != null) { logger.debug("unreleased page: " + page + " in task " + taskAttemptId); - page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; + page.setPageNumber(MemoryBlock.FREED_IN_TMM_PAGE_NUMBER); memoryManager.tungstenMemoryAllocator().free(page); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index dc36809d8911f..8f49859746b89 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -20,7 +20,6 @@ import java.util.Comparator; import org.apache.spark.memory.MemoryConsumer; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.Sorter; @@ -105,13 +104,7 @@ public void reset() { public void expandPointerArray(LongArray newArray) { assert(newArray.size() > array.size()); - Platform.copyMemory( - array.getBaseObject(), - array.getBaseOffset(), - newArray.getBaseObject(), - newArray.getBaseOffset(), - pos * 8L - ); + MemoryBlock.copyMemory(array.memoryBlock(), newArray.memoryBlock(), pos * 8L); consumer.freeArray(array); array = newArray; usableCapacity = getUsableCapacity(); @@ -180,10 +173,7 @@ public ShuffleSorterIterator getSortedIterator() { PackedRecordPointer.PARTITION_ID_START_BYTE_INDEX, PackedRecordPointer.PARTITION_ID_END_BYTE_INDEX, false, false); } else { - MemoryBlock unused = new MemoryBlock( - array.getBaseObject(), - array.getBaseOffset() + pos * 8L, - (array.size() - pos) * 8L); + MemoryBlock unused = array.memoryBlock().subBlock(pos * 8L, (array.size() - pos) * 8L); LongArray buffer = new LongArray(unused); Sorter sorter = new Sorter<>(new ShuffleSortDataFormat(buffer)); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java index 717bdd79d47ef..254449e95443e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java @@ -17,8 +17,8 @@ package org.apache.spark.shuffle.sort; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.SortDataFormat; final class ShuffleSortDataFormat extends SortDataFormat { @@ -60,13 +60,8 @@ public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) { @Override public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) { - Platform.copyMemory( - src.getBaseObject(), - src.getBaseOffset() + srcPos * 8L, - dst.getBaseObject(), - dst.getBaseOffset() + dstPos * 8L, - length * 8L - ); + MemoryBlock.copyMemory(src.memoryBlock(), srcPos * 8L, + dst.memoryBlock(),dstPos * 8L,length * 8L); } @Override diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 66118f454159b..4fc19b1721518 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -544,7 +544,7 @@ public long spill() throws IOException { // is accessing the current record. We free this page in that caller's next loadNext() // call. for (MemoryBlock page : allocatedPages) { - if (!loaded || page.pageNumber != + if (!loaded || page.getPageNumber() != ((UnsafeInMemorySorter.SortedIterator)upstream).getCurrentPageNumber()) { released += page.size(); freePage(page); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index b3c27d83da172..20a7a8b267438 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -26,7 +26,6 @@ import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.UnsafeAlignedOffset; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -216,12 +215,7 @@ public void expandPointerArray(LongArray newArray) { if (newArray.size() < array.size()) { throw new SparkOutOfMemoryError("Not enough memory to grow pointer array"); } - Platform.copyMemory( - array.getBaseObject(), - array.getBaseOffset(), - newArray.getBaseObject(), - newArray.getBaseOffset(), - pos * 8L); + MemoryBlock.copyMemory(array.memoryBlock(), newArray.memoryBlock(), pos * 8L); consumer.freeArray(array); array = newArray; usableCapacity = getUsableCapacity(); @@ -348,10 +342,7 @@ public UnsafeSorterIterator getSortedIterator() { array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned()); } else { - MemoryBlock unused = new MemoryBlock( - array.getBaseObject(), - array.getBaseOffset() + pos * 8L, - (array.size() - pos) * 8L); + MemoryBlock unused = array.memoryBlock().subBlock(pos * 8L, (array.size() - pos) * 8L); LongArray buffer = new LongArray(unused); Sorter sorter = new Sorter<>(new UnsafeSortDataFormat(buffer)); diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index a0664b30d6cc2..d7d2d0b012bd3 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -76,7 +76,7 @@ public void freeingPageSetsPageNumberToSpecialConstant() { final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); final MemoryBlock dataPage = manager.allocatePage(256, c); c.freePage(dataPage); - Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.pageNumber); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.getPageNumber()); } @Test(expected = AssertionError.class) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 47173b89e91e2..3e56db5ea116a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark._ import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.unsafe.array.LongArray -import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordPointerAndKeyPrefix, UnsafeSortDataFormat} class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { @@ -105,9 +105,8 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { // the form [150000000, 150000001, 150000002, ...., 300000000, 0, 1, 2, ..., 149999999] // that can trigger copyRange() in TimSort.mergeLo() or TimSort.mergeHi() val ref = Array.tabulate[Long](size) { i => if (i < size / 2) size / 2 + i else i } - val buf = new LongArray(MemoryBlock.fromLongArray(ref)) - val tmp = new Array[Long](size/2) - val tmpBuf = new LongArray(MemoryBlock.fromLongArray(tmp)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(ref)) + val tmpBuf = new LongArray(new OnHeapMemoryBlock((size/2) * 8L)) new Sorter(new UnsafeSortDataFormat(tmpBuf)).sort( buf, 0, size, new Comparator[RecordPointerAndKeyPrefix] { diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala index d5956ea32096a..ddf3740e76a7a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala @@ -27,7 +27,7 @@ import com.google.common.primitives.Ints import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.unsafe.array.LongArray -import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock import org.apache.spark.util.collection.Sorter import org.apache.spark.util.random.XORShiftRandom @@ -78,14 +78,14 @@ class RadixSortSuite extends SparkFunSuite with Logging { private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = { val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand } val extended = ref ++ Array.fill[Long](Ints.checkedCast(size))(0) - (ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended))) + (ref.map(i => new JLong(i)), new LongArray(OnHeapMemoryBlock.fromArray(extended))) } private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = { val ref = Array.tabulate[Long](Ints.checkedCast(size * 2)) { i => rand } val extended = ref ++ Array.fill[Long](Ints.checkedCast(size * 2))(0) - (new LongArray(MemoryBlock.fromLongArray(ref)), - new LongArray(MemoryBlock.fromLongArray(extended))) + (new LongArray(OnHeapMemoryBlock.fromArray(ref)), + new LongArray(OnHeapMemoryBlock.fromArray(extended))) } private def collectToArray(array: LongArray, offset: Int, length: Long): Array[Long] = { @@ -110,7 +110,7 @@ class RadixSortSuite extends SparkFunSuite with Logging { } private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) { - val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) + val sortBuffer = new LongArray(new OnHeapMemoryBlock(buf.size() * 8L)) new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort( buf, Ints.checkedCast(lo), Ints.checkedCast(hi), new Comparator[RecordPointerAndKeyPrefix] { override def compare( diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index c78f61ac3ef71..d67e4819b161a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -29,7 +29,7 @@ import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF} import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2} +import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2Block} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashMap @@ -243,8 +243,7 @@ object FeatureHasher extends DefaultParamsReadable[FeatureHasher] { case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed) case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) case s: String => - val utf8 = UTF8String.fromString(s) - hashUnsafeBytes2(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed) + hashUnsafeBytes2Block(UTF8String.fromString(s).getMemoryBlock, seed) case _ => throw new SparkException("FeatureHasher with murmur3 algorithm does not " + s"support type ${term.getClass.getCanonicalName} of input data.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala index 8935c8496cdbb..7b73b286fb91c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -160,7 +160,7 @@ object HashingTF { case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) case s: String => val utf8 = UTF8String.fromString(s) - hashUnsafeBytes(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed) + hashUnsafeBytesBlock(utf8.getMemoryBlock(), seed) case _ => throw new SparkException("HashingTF with murmur3 algorithm does not " + s"support type ${term.getClass.getCanonicalName} of input data.") } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index d18542b188f71..8546c28335536 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -27,6 +27,7 @@ import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -230,7 +231,8 @@ public UTF8String getUTF8String(int ordinal) { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) offsetAndSize; - return UTF8String.fromAddress(baseObject, baseOffset + offset, size); + MemoryBlock mb = MemoryBlock.allocateFromObject(baseObject, baseOffset + offset, size); + return new UTF8String(mb); } @Override diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 71c086029cc5b..29a1411241cf6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -37,6 +37,7 @@ import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -414,7 +415,8 @@ public UTF8String getUTF8String(int ordinal) { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) offsetAndSize; - return UTF8String.fromAddress(baseObject, baseOffset + offset, size); + MemoryBlock mb = MemoryBlock.allocateFromObject(baseObject, baseOffset + offset, size); + return new UTF8String(mb); } @Override diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java index f37ef83ad92b4..883748932ad33 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java @@ -16,7 +16,10 @@ */ package org.apache.spark.sql.catalyst.expressions; +import com.google.common.primitives.Ints; + import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; // scalastyle: off /** @@ -71,13 +74,13 @@ public static long hashLong(long input, long seed) { return fmix(hash); } - public long hashUnsafeWords(Object base, long offset, int length) { - return hashUnsafeWords(base, offset, length, seed); + public long hashUnsafeWordsBlock(MemoryBlock mb) { + return hashUnsafeWordsBlock(mb, seed); } - public static long hashUnsafeWords(Object base, long offset, int length, long seed) { - assert (length % 8 == 0) : "lengthInBytes must be a multiple of 8 (word-aligned)"; - long hash = hashBytesByWords(base, offset, length, seed); + public static long hashUnsafeWordsBlock(MemoryBlock mb, long seed) { + assert (mb.size() % 8 == 0) : "lengthInBytes must be a multiple of 8 (word-aligned)"; + long hash = hashBytesByWordsBlock(mb, seed); return fmix(hash); } @@ -85,26 +88,32 @@ public long hashUnsafeBytes(Object base, long offset, int length) { return hashUnsafeBytes(base, offset, length, seed); } - public static long hashUnsafeBytes(Object base, long offset, int length, long seed) { + public static long hashUnsafeBytesBlock(MemoryBlock mb, long seed) { + long offset = 0; + long length = mb.size(); assert (length >= 0) : "lengthInBytes cannot be negative"; - long hash = hashBytesByWords(base, offset, length, seed); + long hash = hashBytesByWordsBlock(mb, seed); long end = offset + length; offset += length & -8; if (offset + 4L <= end) { - hash ^= (Platform.getInt(base, offset) & 0xFFFFFFFFL) * PRIME64_1; + hash ^= (mb.getInt(offset) & 0xFFFFFFFFL) * PRIME64_1; hash = Long.rotateLeft(hash, 23) * PRIME64_2 + PRIME64_3; offset += 4L; } while (offset < end) { - hash ^= (Platform.getByte(base, offset) & 0xFFL) * PRIME64_5; + hash ^= (mb.getByte(offset) & 0xFFL) * PRIME64_5; hash = Long.rotateLeft(hash, 11) * PRIME64_1; offset++; } return fmix(hash); } + public static long hashUnsafeBytes(Object base, long offset, int length, long seed) { + return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, length), seed); + } + private static long fmix(long hash) { hash ^= hash >>> 33; hash *= PRIME64_2; @@ -114,30 +123,31 @@ private static long fmix(long hash) { return hash; } - private static long hashBytesByWords(Object base, long offset, int length, long seed) { - long end = offset + length; + private static long hashBytesByWordsBlock(MemoryBlock mb, long seed) { + long offset = 0; + long length = mb.size(); long hash; if (length >= 32) { - long limit = end - 32; + long limit = length - 32; long v1 = seed + PRIME64_1 + PRIME64_2; long v2 = seed + PRIME64_2; long v3 = seed; long v4 = seed - PRIME64_1; do { - v1 += Platform.getLong(base, offset) * PRIME64_2; + v1 += mb.getLong(offset) * PRIME64_2; v1 = Long.rotateLeft(v1, 31); v1 *= PRIME64_1; - v2 += Platform.getLong(base, offset + 8) * PRIME64_2; + v2 += mb.getLong(offset + 8) * PRIME64_2; v2 = Long.rotateLeft(v2, 31); v2 *= PRIME64_1; - v3 += Platform.getLong(base, offset + 16) * PRIME64_2; + v3 += mb.getLong(offset + 16) * PRIME64_2; v3 = Long.rotateLeft(v3, 31); v3 *= PRIME64_1; - v4 += Platform.getLong(base, offset + 24) * PRIME64_2; + v4 += mb.getLong(offset + 24) * PRIME64_2; v4 = Long.rotateLeft(v4, 31); v4 *= PRIME64_1; @@ -178,9 +188,9 @@ private static long hashBytesByWords(Object base, long offset, int length, long hash += length; - long limit = end - 8; + long limit = length - 8; while (offset <= limit) { - long k1 = Platform.getLong(base, offset); + long k1 = mb.getLong(offset); hash ^= Long.rotateLeft(k1 * PRIME64_2, 31) * PRIME64_1; hash = Long.rotateLeft(hash, 27) * PRIME64_1 + PRIME64_4; offset += 8L; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index b702422ed7a1d..b76b64ab5096f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 +import org.apache.spark.unsafe.memory.MemoryBlock import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -360,10 +361,8 @@ abstract class HashExpression[E] extends Expression { } protected def genHashString(input: String, result: String): String = { - val baseObject = s"$input.getBaseObject()" - val baseOffset = s"$input.getBaseOffset()" - val numBytes = s"$input.numBytes()" - s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" + val mb = s"$input.getMemoryBlock()" + s"$result = $hasherClassName.hashUnsafeBytesBlock($mb, $result);" } protected def genHashForMap( @@ -465,6 +464,8 @@ abstract class InterpretedHashFunction { protected def hashUnsafeBytes(base: AnyRef, offset: Long, length: Int, seed: Long): Long + protected def hashUnsafeBytesBlock(base: MemoryBlock, seed: Long): Long + /** * Computes hash of a given `value` of type `dataType`. The caller needs to check the validity * of input `value`. @@ -490,8 +491,7 @@ abstract class InterpretedHashFunction { case c: CalendarInterval => hashInt(c.months, hashLong(c.microseconds, seed)) case a: Array[Byte] => hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed) - case s: UTF8String => - hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed) + case s: UTF8String => hashUnsafeBytesBlock(s.getMemoryBlock(), seed) case array: ArrayData => val elementType = dataType match { @@ -578,9 +578,15 @@ object Murmur3HashFunction extends InterpretedHashFunction { Murmur3_x86_32.hashLong(l, seed.toInt) } - override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + override protected def hashUnsafeBytes( + base: AnyRef, offset: Long, len: Int, seed: Long): Long = { Murmur3_x86_32.hashUnsafeBytes(base, offset, len, seed.toInt) } + + override protected def hashUnsafeBytesBlock( + base: MemoryBlock, seed: Long): Long = { + Murmur3_x86_32.hashUnsafeBytesBlock(base, seed.toInt) + } } /** @@ -605,9 +611,14 @@ object XxHash64Function extends InterpretedHashFunction { override protected def hashLong(l: Long, seed: Long): Long = XXH64.hashLong(l, seed) - override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + override protected def hashUnsafeBytes( + base: AnyRef, offset: Long, len: Int, seed: Long): Long = { XXH64.hashUnsafeBytes(base, offset, len, seed) } + + override protected def hashUnsafeBytesBlock(base: MemoryBlock, seed: Long): Long = { + XXH64.hashUnsafeBytesBlock(base, seed) + } } /** @@ -714,10 +725,8 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """ override protected def genHashString(input: String, result: String): String = { - val baseObject = s"$input.getBaseObject()" - val baseOffset = s"$input.getBaseOffset()" - val numBytes = s"$input.numBytes()" - s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes);" + val mb = s"$input.getMemoryBlock()" + s"$result = $hasherClassName.hashUnsafeBytesBlock($mb);" } override protected def genHashForArray( @@ -805,10 +814,14 @@ object HiveHashFunction extends InterpretedHashFunction { HiveHasher.hashLong(l) } - override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + override protected def hashUnsafeBytes( + base: AnyRef, offset: Long, len: Int, seed: Long): Long = { HiveHasher.hashUnsafeBytes(base, offset, len) } + override protected def hashUnsafeBytesBlock( + base: MemoryBlock, seed: Long): Long = HiveHasher.hashUnsafeBytesBlock(base) + private val HIVE_DECIMAL_MAX_PRECISION = 38 private val HIVE_DECIMAL_MAX_SCALE = 38 diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java index b67c6f3e6e85e..8ffc1d7c24d61 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java @@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.UTF8String; import org.junit.Assert; import org.junit.Test; @@ -53,7 +55,7 @@ public void testKnownStringAndIntInputs() { for (int i = 0; i < inputs.length; i++) { UTF8String s = UTF8String.fromString("val_" + inputs[i]); - int hash = HiveHasher.hashUnsafeBytes(s.getBaseObject(), s.getBaseOffset(), s.numBytes()); + int hash = HiveHasher.hashUnsafeBytesBlock(s.getMemoryBlock()); Assert.assertEquals(expected[i], ((31 * inputs[i]) + hash)); } } @@ -89,13 +91,13 @@ public void randomizedStressTestBytes() { int byteArrSize = rand.nextInt(100) * 8; byte[] bytes = new byte[byteArrSize]; rand.nextBytes(bytes); + MemoryBlock mb = ByteArrayMemoryBlock.fromArray(bytes); Assert.assertEquals( - HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), - HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + HiveHasher.hashUnsafeBytesBlock(mb), + HiveHasher.hashUnsafeBytesBlock(mb)); - hashcodes.add(HiveHasher.hashUnsafeBytes( - bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hashcodes.add(HiveHasher.hashUnsafeBytesBlock(mb)); } // A very loose bound. @@ -112,13 +114,13 @@ public void randomizedStressTestPaddedStrings() { byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8); byte[] paddedBytes = new byte[byteArrSize]; System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); + MemoryBlock mb = ByteArrayMemoryBlock.fromArray(paddedBytes); Assert.assertEquals( - HiveHasher.hashUnsafeBytes(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), - HiveHasher.hashUnsafeBytes(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + HiveHasher.hashUnsafeBytesBlock(mb), + HiveHasher.hashUnsafeBytesBlock(mb)); - hashcodes.add(HiveHasher.hashUnsafeBytes( - paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hashcodes.add(HiveHasher.hashUnsafeBytesBlock(mb)); } // A very loose bound. diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java index 1baee91b3439c..cd8bce623c5df 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java @@ -24,6 +24,8 @@ import java.util.Set; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.junit.Assert; import org.junit.Test; @@ -142,13 +144,13 @@ public void randomizedStressTestBytes() { int byteArrSize = rand.nextInt(100) * 8; byte[] bytes = new byte[byteArrSize]; rand.nextBytes(bytes); + MemoryBlock mb = ByteArrayMemoryBlock.fromArray(bytes); Assert.assertEquals( - hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), - hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hasher.hashUnsafeWordsBlock(mb), + hasher.hashUnsafeWordsBlock(mb)); - hashcodes.add(hasher.hashUnsafeWords( - bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hashcodes.add(hasher.hashUnsafeWordsBlock(mb)); } // A very loose bound. @@ -165,13 +167,13 @@ public void randomizedStressTestPaddedStrings() { byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8); byte[] paddedBytes = new byte[byteArrSize]; System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); + MemoryBlock mb = ByteArrayMemoryBlock.fromArray(paddedBytes); Assert.assertEquals( - hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), - hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hasher.hashUnsafeWordsBlock(mb), + hasher.hashUnsafeWordsBlock(mb)); - hashcodes.add(hasher.hashUnsafeWords( - paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hashcodes.add(hasher.hashUnsafeWordsBlock(mb)); } // A very loose bound. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 754c26579ff08..4733f36174f42 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.OffHeapMemoryBlock; import org.apache.spark.unsafe.types.UTF8String; /** @@ -206,7 +207,7 @@ public byte[] getBytes(int rowId, int count) { @Override protected UTF8String getBytesAsUTF8String(int rowId, int count) { - return UTF8String.fromAddress(null, data + rowId, count); + return new UTF8String(new OffHeapMemoryBlock(data + rowId, count)); } // diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index f8e37e995a17f..227a16f7e69e9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -25,6 +25,7 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.arrow.ArrowUtils; import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.memory.OffHeapMemoryBlock; import org.apache.spark.unsafe.types.UTF8String; /** @@ -377,9 +378,10 @@ final UTF8String getUTF8String(int rowId) { if (stringResult.isSet == 0) { return null; } else { - return UTF8String.fromAddress(null, + return new UTF8String(new OffHeapMemoryBlock( stringResult.buffer.memoryAddress() + stringResult.start, - stringResult.end - stringResult.start); + stringResult.end - stringResult.start + )); } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala index 50ae26a3ff9d9..470b93efd1974 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.benchmark import java.util.{Arrays, Comparator} import org.apache.spark.unsafe.array.LongArray -import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock import org.apache.spark.util.Benchmark import org.apache.spark.util.collection.Sorter import org.apache.spark.util.collection.unsafe.sort._ @@ -36,7 +36,7 @@ import org.apache.spark.util.random.XORShiftRandom class SortBenchmark extends BenchmarkBase { private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) { - val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) + val sortBuffer = new LongArray(new OnHeapMemoryBlock(buf.size() * 8L)) new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort( buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] { override def compare( @@ -50,8 +50,8 @@ class SortBenchmark extends BenchmarkBase { private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = { val ref = Array.tabulate[Long](size * 2) { i => rand } val extended = ref ++ Array.fill[Long](size * 2)(0) - (new LongArray(MemoryBlock.fromLongArray(ref)), - new LongArray(MemoryBlock.fromLongArray(extended))) + (new LongArray(OnHeapMemoryBlock.fromArray(ref)), + new LongArray(OnHeapMemoryBlock.fromArray(extended))) } ignore("sort") { @@ -60,7 +60,7 @@ class SortBenchmark extends BenchmarkBase { val benchmark = new Benchmark("radix sort " + size, size) benchmark.addTimerCase("reference TimSort key prefix array") { timer => val array = Array.tabulate[Long](size * 2) { i => rand.nextLong } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) timer.startTiming() referenceKeyPrefixSort(buf, 0, size, PrefixComparators.BINARY) timer.stopTiming() @@ -78,7 +78,7 @@ class SortBenchmark extends BenchmarkBase { array(i) = rand.nextLong & 0xff i += 1 } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) timer.startTiming() RadixSort.sort(buf, size, 0, 7, false, false) timer.stopTiming() @@ -90,7 +90,7 @@ class SortBenchmark extends BenchmarkBase { array(i) = rand.nextLong & 0xffff i += 1 } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) timer.startTiming() RadixSort.sort(buf, size, 0, 7, false, false) timer.stopTiming() @@ -102,7 +102,7 @@ class SortBenchmark extends BenchmarkBase { array(i) = rand.nextLong i += 1 } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) timer.startTiming() RadixSort.sort(buf, size, 0, 7, false, false) timer.stopTiming() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala index ffda33cf906c5..25ee95daa034c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala @@ -22,13 +22,13 @@ import java.io.File import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock import org.apache.spark.util.Utils class RowQueueSuite extends SparkFunSuite { test("in-memory queue") { - val page = MemoryBlock.fromLongArray(new Array[Long](1<<10)) + val page = new OnHeapMemoryBlock((1<<10) * 8L) val queue = new InMemoryRowQueue(page, 1) { override def close() {} } From f2ac0879561cde63ed4eb759f5efa0a5ce393a22 Mon Sep 17 00:00:00 2001 From: Yogesh Garg Date: Thu, 5 Apr 2018 19:55:42 -0700 Subject: [PATCH 0569/2461] [SPARK-23870][ML] Forward RFormula handleInvalid Param to VectorAssembler to handle invalid values in non-string columns ## What changes were proposed in this pull request? `handleInvalid` Param was forwarded to the VectorAssembler used by RFormula. ## How was this patch tested? added a test and ran all tests for RFormula and VectorAssembler Author: Yogesh Garg Closes #20970 from yogeshg/spark_23562. --- .../apache/spark/ml/feature/RFormula.scala | 1 + .../spark/ml/feature/RFormulaSuite.scala | 23 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 22e7b8bbf1ff5..e214765e3307f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -278,6 +278,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) encoderStages += new VectorAssembler(uid) .setInputCols(encodedTerms.toArray) .setOutputCol($(featuresCol)) + .setHandleInvalid($(handleInvalid)) encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap) encoderStages += new ColumnPruner(tempColumns.toSet) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 27d570f0b68ad..a250331efeb1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.SparkException import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite @@ -592,4 +593,26 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { assert(features.toArray === a +: b.toArray) } } + + test("SPARK-23562 RFormula handleInvalid should handle invalid values in non-string columns.") { + val d1 = Seq( + (1001L, "a"), + (1002L, "b")).toDF("id1", "c1") + val d2 = Seq[(java.lang.Long, String)]( + (20001L, "x"), + (20002L, "y"), + (null, null)).toDF("id2", "c2") + val dataset = d1.crossJoin(d2) + + def get_output(mode: String): DataFrame = { + val formula = new RFormula().setFormula("c1 ~ id2").setHandleInvalid(mode) + formula.fit(dataset).transform(dataset).select("features", "label") + } + + assert(intercept[SparkException](get_output("error").collect()) + .getMessage.contains("Encountered null while assembling a row")) + assert(get_output("skip").count() == 4) + assert(get_output("keep").count() == 6) + } + } From d65e531b44a388fed25d3cbf28fdce5a2d0598e6 Mon Sep 17 00:00:00 2001 From: JiahuiJiang Date: Thu, 5 Apr 2018 20:06:08 -0700 Subject: [PATCH 0570/2461] [SPARK-23823][SQL] Keep origin in transformExpression Fixes https://issues.apache.org/jira/browse/SPARK-23823 Keep origin for all the methods using transformExpression ## What changes were proposed in this pull request? Keep origin in transformExpression ## How was this patch tested? Manually tested that this fixes https://issues.apache.org/jira/browse/SPARK-23823 and columns have correct origins after Analyzer.analyze Author: JiahuiJiang Author: Jiahui Jiang Closes #20961 from JiahuiJiang/jj/keep-origin. --- .../spark/sql/catalyst/plans/QueryPlan.scala | 6 ++- .../sql/catalyst/plans/QueryPlanSuite.scala | 42 +++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index ddf2cbf2ab911..64cb8c726772f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} @@ -103,7 +103,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT var changed = false @inline def transformExpression(e: Expression): Expression = { - val newE = f(e) + val newE = CurrentOrigin.withOrigin(e.origin) { + f(e) + } if (newE.fastEquals(e)) { e } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala new file mode 100644 index 0000000000000..27914ef5565c0 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.plans +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.types.IntegerType + +class QueryPlanSuite extends SparkFunSuite { + + test("origin remains the same after mapExpressions (SPARK-23823)") { + CurrentOrigin.setPosition(0, 0) + val column = AttributeReference("column", IntegerType)(NamedExpression.newExprId) + val query = plans.DslLogicalPlan(plans.table("table")).select(column) + CurrentOrigin.reset() + + val mappedQuery = query mapExpressions { + case _: Expression => Literal(1) + } + + val mappedOrigin = mappedQuery.expressions.apply(0).origin + assert(mappedOrigin == Origin.apply(Some(0), Some(0))) + } + +} From 249007e37f51f00d14e596692aeac87fbc10b520 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 5 Apr 2018 20:19:25 -0700 Subject: [PATCH 0571/2461] [SPARK-19724][SQL] create a managed table with an existed default table should throw an exception ## What changes were proposed in this pull request? This PR is to finish https://github.com/apache/spark/pull/17272 This JIRA is a follow up work after SPARK-19583 As we discussed in that PR The following DDL for a managed table with an existed default location should throw an exception: CREATE TABLE ... (PARTITIONED BY ...) AS SELECT ... CREATE TABLE ... (PARTITIONED BY ...) Currently there are some situations which are not consist with above logic: CREATE TABLE ... (PARTITIONED BY ...) succeed with an existed default location situation: for both hive/datasource(with HiveExternalCatalog/InMemoryCatalog) CREATE TABLE ... (PARTITIONED BY ...) AS SELECT ... situation: hive table succeed with an existed default location This PR is going to make above two situations consist with the logic that it should throw an exception with an existed default location. ## How was this patch tested? unit test added Author: Gengliang Wang Closes #20886 from gengliangwang/pr-17272. --- docs/sql-programming-guide.md | 1 + .../sql/catalyst/catalog/SessionCatalog.scala | 23 ++++++- .../apache/spark/sql/internal/SQLConf.scala | 11 +++ .../command/createDataSourceTables.scala | 5 +- .../spark/sql/StatisticsCollectionSuite.scala | 7 ++ .../sql/execution/command/DDLSuite.scala | 67 +++++++++++++++++++ 6 files changed, 110 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 2b393f30d1435..9822d669050d5 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1809,6 +1809,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 64e7ca11270b4..52ed89ef8d781 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -289,6 +289,7 @@ class SessionCatalog( def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { val db = formatDatabaseName(tableDefinition.identifier.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableDefinition.identifier.table) + val tableIdentifier = TableIdentifier(table, Some(db)) validateName(table) val newTableDefinition = if (tableDefinition.storage.locationUri.isDefined @@ -298,15 +299,33 @@ class SessionCatalog( makeQualifiedPath(tableDefinition.storage.locationUri.get) tableDefinition.copy( storage = tableDefinition.storage.copy(locationUri = Some(qualifiedTableLocation)), - identifier = TableIdentifier(table, Some(db))) + identifier = tableIdentifier) } else { - tableDefinition.copy(identifier = TableIdentifier(table, Some(db))) + tableDefinition.copy(identifier = tableIdentifier) } requireDbExists(db) + if (!ignoreIfExists) { + validateTableLocation(newTableDefinition) + } externalCatalog.createTable(newTableDefinition, ignoreIfExists) } + def validateTableLocation(table: CatalogTable): Unit = { + // SPARK-19724: the default location of a managed table should be non-existent or empty. + if (table.tableType == CatalogTableType.MANAGED && + !conf.allowCreatingManagedTableUsingNonemptyLocation) { + val tableLocation = + new Path(table.storage.locationUri.getOrElse(defaultTablePath(table.identifier))) + val fs = tableLocation.getFileSystem(hadoopConf) + + if (fs.exists(tableLocation) && fs.listStatus(tableLocation).nonEmpty) { + throw new AnalysisException(s"Can not create the managed table('${table.identifier}')" + + s". The associated location('${tableLocation.toString}') already exists.") + } + } + } + /** * Alter the metadata of an existing metastore table identified by `tableDefinition`. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 13f31a6b2eb93..1c8ab9c62623e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1159,6 +1159,14 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION = + buildConf("spark.sql.allowCreatingManagedTableUsingNonemptyLocation") + .internal() + .doc("When this option is set to true, creating managed tables with nonempty location " + + "is allowed. Otherwise, an analysis exception is thrown. ") + .booleanConf + .createWithDefault(false) + val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE = buildConf("spark.sql.streaming.continuous.executorQueueSize") .internal() @@ -1581,6 +1589,9 @@ class SQLConf extends Serializable with Logging { def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) + def allowCreatingManagedTableUsingNonemptyLocation: Boolean = + getConf(ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION) + def partitionOverwriteMode: PartitionOverwriteMode.Value = PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index e9747769dfcfc..f7c3e9b019258 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -167,7 +167,7 @@ case class CreateDataSourceTableAsSelectCommand( sparkSession, table, table.storage.locationUri, child, SaveMode.Append, tableExists = true) } else { assert(table.schema.isEmpty) - + sparkSession.sessionState.catalog.validateTableLocation(table) val tableLocation = if (table.tableType == CatalogTableType.MANAGED) { Some(sessionState.catalog.defaultTablePath(table.identifier)) } else { @@ -181,7 +181,8 @@ case class CreateDataSourceTableAsSelectCommand( // the schema of df). It is important since the nullability may be changed by the relation // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). schema = result.schema) - sessionState.catalog.createTable(newTable, ignoreIfExists = false) + // Table location is already validated. No need to check it again during table creation. + sessionState.catalog.createTable(newTable, ignoreIfExists = true) result match { case fs: HadoopFsRelation if table.partitionColumnNames.nonEmpty && diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index ed4ea0231f1a7..14a565863d66c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.io.File + import scala.collection.mutable import org.apache.spark.sql.catalyst.TableIdentifier @@ -26,6 +28,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData.ArrayData import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** @@ -242,6 +245,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("change stats after set location command") { val table = "change_stats_set_location_table" + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier(table))) Seq(false, true).foreach { autoUpdate => withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) { withTable(table) { @@ -269,6 +273,9 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared assert(fetched3.get.sizeInBytes == fetched1.get.sizeInBytes) } else { checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None) + // SPARK-19724: clean up the previous table location. + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 4df8fbfe1c0db..4304d0b6f6b16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -180,6 +180,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { private val escapedIdentifier = "`(.+)`".r + private def dataSource: String = { + if (isUsingHiveMetastore) { + "HIVE" + } else { + "PARQUET" + } + } protected def normalizeCatalogTable(table: CatalogTable): CatalogTable = table private def normalizeSerdeProp(props: Map[String, String]): Map[String, String] = { @@ -365,6 +372,66 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("CTAS a managed table with the existing empty directory") { + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab1"))) + try { + tableLoc.mkdir() + withTable("tab1") { + sql(s"CREATE TABLE tab1 USING ${dataSource} AS SELECT 1, 'a'") + checkAnswer(spark.table("tab1"), Row(1, "a")) + } + } finally { + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) + } + } + + test("create a managed table with the existing empty directory") { + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab1"))) + try { + tableLoc.mkdir() + withTable("tab1") { + sql(s"CREATE TABLE tab1 (col1 int, col2 string) USING ${dataSource}") + sql("INSERT INTO tab1 VALUES (1, 'a')") + checkAnswer(spark.table("tab1"), Row(1, "a")) + } + } finally { + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) + } + } + + test("create a managed table with the existing non-empty directory") { + withTable("tab1") { + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab1"))) + try { + // create an empty hidden file + tableLoc.mkdir() + val hiddenGarbageFile = new File(tableLoc.getCanonicalPath, ".garbage") + hiddenGarbageFile.createNewFile() + val exMsg = "Can not create the managed table('`tab1`'). The associated location" + val exMsgWithDefaultDB = + "Can not create the managed table('`default`.`tab1`'). The associated location" + var ex = intercept[AnalysisException] { + sql(s"CREATE TABLE tab1 USING ${dataSource} AS SELECT 1, 'a'") + }.getMessage + if (isUsingHiveMetastore) { + assert(ex.contains(exMsgWithDefaultDB)) + } else { + assert(ex.contains(exMsg)) + } + + ex = intercept[AnalysisException] { + sql(s"CREATE TABLE tab1 (col1 int, col2 string) USING ${dataSource}") + }.getMessage + assert(ex.contains(exMsgWithDefaultDB)) + } finally { + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) + } + } + } + private def checkSchemaInCreatedDataSourceTable( path: File, userSpecifiedSchema: Option[String], From 6ade5cbb498f6c6ea38779b97f2325d5cf5013f2 Mon Sep 17 00:00:00 2001 From: Daniel Sakuma Date: Fri, 6 Apr 2018 13:37:08 +0800 Subject: [PATCH 0572/2461] [MINOR][DOC] Fix some typos and grammar issues ## What changes were proposed in this pull request? Easy fix in the documentation. ## How was this patch tested? N/A Closes #20948 Author: Daniel Sakuma Closes #20928 from dsakuma/fix_typo_configuration_docs. --- docs/README.md | 2 +- docs/_plugins/include_example.rb | 2 +- docs/building-spark.md | 2 +- docs/cloud-integration.md | 4 +-- docs/configuration.md | 20 ++++++------ docs/css/pygments-default.css | 2 +- docs/graphx-programming-guide.md | 4 +-- docs/job-scheduling.md | 4 +-- docs/ml-advanced.md | 2 +- docs/ml-classification-regression.md | 6 ++-- docs/ml-collaborative-filtering.md | 2 +- docs/ml-features.md | 2 +- docs/ml-migration-guides.md | 2 +- docs/ml-tuning.md | 2 +- docs/mllib-clustering.md | 2 +- docs/mllib-collaborative-filtering.md | 4 +-- docs/mllib-data-types.md | 2 +- docs/mllib-dimensionality-reduction.md | 2 +- docs/mllib-evaluation-metrics.md | 2 +- docs/mllib-feature-extraction.md | 2 +- docs/mllib-isotonic-regression.md | 4 +-- docs/mllib-linear-methods.md | 2 +- docs/mllib-optimization.md | 4 +-- docs/monitoring.md | 4 +-- docs/quick-start.md | 6 ++-- docs/rdd-programming-guide.md | 2 +- docs/running-on-kubernetes.md | 4 +-- docs/running-on-mesos.md | 12 +++---- docs/running-on-yarn.md | 2 +- docs/security.md | 2 +- docs/spark-standalone.md | 2 +- docs/sparkr.md | 6 ++-- docs/sql-programming-guide.md | 32 +++++++++---------- docs/storage-openstack-swift.md | 2 +- docs/streaming-flume-integration.md | 6 ++-- docs/streaming-kafka-0-8-integration.md | 10 +++--- docs/streaming-programming-guide.md | 26 +++++++-------- .../structured-streaming-kafka-integration.md | 2 +- .../structured-streaming-programming-guide.md | 8 ++--- docs/submitting-applications.md | 2 +- docs/tuning.md | 2 +- python/README.md | 2 +- sql/README.md | 2 +- 43 files changed, 107 insertions(+), 107 deletions(-) diff --git a/docs/README.md b/docs/README.md index 225bb1b2040de..9eac4ba35c458 100644 --- a/docs/README.md +++ b/docs/README.md @@ -5,7 +5,7 @@ here with the Spark source code. You can also find documentation specific to rel Spark at http://spark.apache.org/documentation.html. Read on to learn more about viewing documentation in plain text (i.e., markdown) or building the -documentation yourself. Why build it yourself? So that you have the docs that corresponds to +documentation yourself. Why build it yourself? So that you have the docs that correspond to whichever version of Spark you currently have checked out of revision control. ## Prerequisites diff --git a/docs/_plugins/include_example.rb b/docs/_plugins/include_example.rb index 6ea1d438f529e..1e91f12518e0b 100644 --- a/docs/_plugins/include_example.rb +++ b/docs/_plugins/include_example.rb @@ -48,7 +48,7 @@ def render(context) begin code = File.open(@file).read.encode("UTF-8") rescue => e - # We need to explicitly exit on execptions here because Jekyll will silently swallow + # We need to explicitly exit on exceptions here because Jekyll will silently swallow # them, leading to silent build failures (see https://github.com/jekyll/jekyll/issues/5104) puts(e) puts(e.backtrace) diff --git a/docs/building-spark.md b/docs/building-spark.md index c391255a91596..0236bb05849ad 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -113,7 +113,7 @@ Note: Flume support is deprecated as of Spark 2.3.0. ## Building submodules individually -It's possible to build Spark sub-modules using the `mvn -pl` option. +It's possible to build Spark submodules using the `mvn -pl` option. For instance, you can build the Spark Streaming module using: diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md index c150d9efc06ff..ac1c336988930 100644 --- a/docs/cloud-integration.md +++ b/docs/cloud-integration.md @@ -27,13 +27,13 @@ description: Introduction to cloud storage support in Apache Spark SPARK_VERSION All major cloud providers offer persistent data storage in *object stores*. These are not classic "POSIX" file systems. In order to store hundreds of petabytes of data without any single points of failure, -object stores replace the classic filesystem directory tree +object stores replace the classic file system directory tree with a simpler model of `object-name => data`. To enable remote access, operations on objects are usually offered as (slow) HTTP REST operations. Spark can read and write data in object stores through filesystem connectors implemented in Hadoop or provided by the infrastructure suppliers themselves. -These connectors make the object stores look *almost* like filesystems, with directories and files +These connectors make the object stores look *almost* like file systems, with directories and files and the classic operations on them such as list, delete and rename. diff --git a/docs/configuration.md b/docs/configuration.md index 2eb6a77434ea6..4d4d0c58dd07d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -558,7 +558,7 @@ Apart from these, the following properties are also available, and may be useful @@ -1288,7 +1288,7 @@ Apart from these, the following properties are also available, and may be useful @@ -1513,7 +1513,7 @@ Apart from these, the following properties are also available, and may be useful @@ -1722,7 +1722,7 @@ Apart from these, the following properties are also available, and may be useful When spark.task.reaper.enabled = true, this setting specifies a timeout after which the executor JVM will kill itself if a killed task has not stopped running. The default value, -1, disables this mechanism and prevents the executor from self-destructing. The purpose - of this setting is to act as a safety-net to prevent runaway uncancellable tasks from rendering + of this setting is to act as a safety-net to prevent runaway noncancellable tasks from rendering an executor unusable. @@ -1915,8 +1915,8 @@ showDF(properties, numRows = 200, truncate = FALSE) @@ -1971,7 +1971,7 @@ showDF(properties, numRows = 200, truncate = FALSE) @@ -1980,7 +1980,7 @@ showDF(properties, numRows = 200, truncate = FALSE) @@ -2178,7 +2178,7 @@ Spark's classpath for each application. In a Spark cluster running on YARN, thes files are set cluster-wide, and cannot safely be changed by the application. The better choice is to use spark hadoop properties in the form of `spark.hadoop.*`. -They can be considered as same as normal spark properties which can be set in `$SPARK_HOME/conf/spark-defalut.conf` +They can be considered as same as normal spark properties which can be set in `$SPARK_HOME/conf/spark-default.conf` In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For instance, Spark allows you to simply create an empty conf and set spark/spark hadoop properties. diff --git a/docs/css/pygments-default.css b/docs/css/pygments-default.css index 6247cd8396cf1..a4d583b366603 100644 --- a/docs/css/pygments-default.css +++ b/docs/css/pygments-default.css @@ -5,7 +5,7 @@ To generate this, I had to run But first I had to install pygments via easy_install pygments I had to override the conflicting bootstrap style rules by linking to -this stylesheet lower in the html than the bootstap css. +this stylesheet lower in the html than the bootstrap css. Also, I was thrown off for a while at first when I was using markdown code block inside my {% highlight scala %} ... {% endhighlight %} tags diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 5c97a248df4bc..35293348e3f3d 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -491,7 +491,7 @@ val joinedGraph = graph.joinVertices(uniqueCosts)( The more general [`outerJoinVertices`][Graph.outerJoinVertices] behaves similarly to `joinVertices` except that the user defined `map` function is applied to all vertices and can change the vertex property type. Because not all vertices may have a matching value in the input RDD the `map` -function takes an `Option` type. For example, we can setup a graph for PageRank by initializing +function takes an `Option` type. For example, we can set up a graph for PageRank by initializing vertex properties with their `outDegree`. @@ -969,7 +969,7 @@ A vertex is part of a triangle when it has two adjacent vertices with an edge be # Examples Suppose I want to build a graph from some text files, restrict the graph -to important relationships and users, run page-rank on the sub-graph, and +to important relationships and users, run page-rank on the subgraph, and then finally return attributes associated with the top users. I can do all of this in just a few lines with GraphX: diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index e6d881639a13b..da90342406c84 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -23,7 +23,7 @@ run tasks and store data for that application. If multiple users need to share y different options to manage allocation, depending on the cluster manager. The simplest option, available on all cluster managers, is _static partitioning_ of resources. With -this approach, each application is given a maximum amount of resources it can use, and holds onto them +this approach, each application is given a maximum amount of resources it can use and holds onto them for its whole duration. This is the approach used in Spark's [standalone](spark-standalone.html) and [YARN](running-on-yarn.html) modes, as well as the [coarse-grained Mesos mode](running-on-mesos.html#mesos-run-modes). @@ -230,7 +230,7 @@ properties: * `minShare`: Apart from an overall weight, each pool can be given a _minimum shares_ (as a number of CPU cores) that the administrator would like it to have. The fair scheduler always attempts to meet all active pools' minimum shares before redistributing extra resources according to the weights. - The `minShare` property can therefore be another way to ensure that a pool can always get up to a + The `minShare` property can, therefore, be another way to ensure that a pool can always get up to a certain number of resources (e.g. 10 cores) quickly without giving it a high priority for the rest of the cluster. By default, each pool's `minShare` is 0. diff --git a/docs/ml-advanced.md b/docs/ml-advanced.md index 2747f2df7cb10..375957e92cc4c 100644 --- a/docs/ml-advanced.md +++ b/docs/ml-advanced.md @@ -77,7 +77,7 @@ Quasi-Newton methods in this case. This fallback is currently always enabled for L1 regularization is applied (i.e. $\alpha = 0$), there exists an analytical solution and either Cholesky or Quasi-Newton solver may be used. When $\alpha > 0$ no analytical solution exists and we instead use the Quasi-Newton solver to find the coefficients iteratively. -In order to make the normal equation approach efficient, `WeightedLeastSquares` requires that the number of features be no more than 4096. For larger problems, use L-BFGS instead. +In order to make the normal equation approach efficient, `WeightedLeastSquares` requires that the number of features is no more than 4096. For larger problems, use L-BFGS instead. ## Iteratively reweighted least squares (IRLS) diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index ddd2f4b49ca07..d660655e193eb 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -420,7 +420,7 @@ Refer to the [R API docs](api/R/spark.svmLinear.html) for more details. [OneVsRest](http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest) is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. It is also known as "One-vs-All." -`OneVsRest` is implemented as an `Estimator`. For the base classifier it takes instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. +`OneVsRest` is implemented as an `Estimator`. For the base classifier, it takes instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. Predictions are done by evaluating each binary classifier and the index of the most confident classifier is output as label. @@ -908,7 +908,7 @@ Refer to the [R API docs](api/R/spark.survreg.html) for more details. belongs to the family of regression algorithms. Formally isotonic regression is a problem where given a finite set of real numbers `$Y = {y_1, y_2, ..., y_n}$` representing observed responses and `$X = {x_1, x_2, ..., x_n}$` the unknown response values to be fitted -finding a function that minimises +finding a function that minimizes `\begin{equation} f(x) = \sum_{i=1}^n w_i (y_i - x_i)^2 @@ -927,7 +927,7 @@ We implement a which uses an approach to [parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). The training input is a DataFrame which contains three columns -label, features and weight. Additionally IsotonicRegression algorithm has one +label, features and weight. Additionally, IsotonicRegression algorithm has one optional parameter called $isotonic$ defaulting to true. This argument specifies if the isotonic regression is isotonic (monotonically increasing) or antitonic (monotonically decreasing). diff --git a/docs/ml-collaborative-filtering.md b/docs/ml-collaborative-filtering.md index 58f2d4b531e70..8b0f287dc39ad 100644 --- a/docs/ml-collaborative-filtering.md +++ b/docs/ml-collaborative-filtering.md @@ -35,7 +35,7 @@ but the ids must be within the integer value range. ### Explicit vs. implicit feedback -The standard approach to matrix factorization based collaborative filtering treats +The standard approach to matrix factorization-based collaborative filtering treats the entries in the user-item matrix as *explicit* preferences given by the user to the item, for example, users giving ratings to movies. diff --git a/docs/ml-features.md b/docs/ml-features.md index 3370eb3893272..7aed2341584fc 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1174,7 +1174,7 @@ for more details on the API. ## SQLTransformer `SQLTransformer` implements the transformations which are defined by SQL statement. -Currently we only support SQL syntax like `"SELECT ... FROM __THIS__ ..."` +Currently, we only support SQL syntax like `"SELECT ... FROM __THIS__ ..."` where `"__THIS__"` represents the underlying table of the input dataset. The select clause specifies the fields, constants, and expressions to display in the output, and can be any select clause that Spark SQL supports. Users can also diff --git a/docs/ml-migration-guides.md b/docs/ml-migration-guides.md index f4b0df58cf63b..e4736411fb5fe 100644 --- a/docs/ml-migration-guides.md +++ b/docs/ml-migration-guides.md @@ -347,7 +347,7 @@ rather than using the old parameter class `Strategy`. These new training method separate classification and regression, and they replace specialized parameter types with simple `String` types. -Examples of the new, recommended `trainClassifier` and `trainRegressor` are given in the +Examples of the new recommended `trainClassifier` and `trainRegressor` are given in the [Decision Trees Guide](mllib-decision-tree.html#examples). ## From 0.9 to 1.0 diff --git a/docs/ml-tuning.md b/docs/ml-tuning.md index 54d9cd21909df..028bfec465bab 100644 --- a/docs/ml-tuning.md +++ b/docs/ml-tuning.md @@ -103,7 +103,7 @@ Refer to the [`CrossValidator` Python docs](api/python/pyspark.ml.html#pyspark.m In addition to `CrossValidator` Spark also offers `TrainValidationSplit` for hyper-parameter tuning. `TrainValidationSplit` only evaluates each combination of parameters once, as opposed to k times in - the case of `CrossValidator`. It is therefore less expensive, + the case of `CrossValidator`. It is, therefore, less expensive, but will not produce as reliable results when the training dataset is not sufficiently large. Unlike `CrossValidator`, `TrainValidationSplit` creates a single (training, test) dataset pair. diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index df2be92d860e4..dc6b095f5d59b 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -42,7 +42,7 @@ The following code snippets can be executed in `spark-shell`. In the following example after loading and parsing data, we use the [`KMeans`](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) object to cluster the data into two clusters. The number of desired clusters is passed to the algorithm. We then compute Within -Set Sum of Squared Error (WSSSE). You can reduce this error measure by increasing *k*. In fact the +Set Sum of Squared Error (WSSSE). You can reduce this error measure by increasing *k*. In fact, the optimal *k* is usually one where there is an "elbow" in the WSSSE graph. Refer to the [`KMeans` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) and [`KMeansModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.KMeansModel) for details on the API. diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 76a00f18b3b90..b2300028e151b 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -31,7 +31,7 @@ following parameters: ### Explicit vs. implicit feedback -The standard approach to matrix factorization based collaborative filtering treats +The standard approach to matrix factorization-based collaborative filtering treats the entries in the user-item matrix as *explicit* preferences given by the user to the item, for example, users giving ratings to movies. @@ -60,7 +60,7 @@ best parameter learned from a sampled subset to the full dataset and expect simi
    -In the following example we load rating data. Each row consists of a user, a product and a rating. +In the following example, we load rating data. Each row consists of a user, a product and a rating. We use the default [ALS.train()](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS$) method which assumes ratings are explicit. We evaluate the recommendation model by measuring the Mean Squared Error of rating prediction. diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 35cee3275e3b5..5066bb29387dc 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -350,7 +350,7 @@ which is a tuple of `(Int, Int, Matrix)`. ***Note*** The underlying RDDs of a distributed matrix must be deterministic, because we cache the matrix size. -In general the use of non-deterministic RDDs can lead to errors. +In general, the use of non-deterministic RDDs can lead to errors. ### RowMatrix diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index a72680d52a26c..4e6b4530942f1 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -91,7 +91,7 @@ The same code applies to `IndexedRowMatrix` if `U` is defined as an [Principal component analysis (PCA)](http://en.wikipedia.org/wiki/Principal_component_analysis) is a statistical method to find a rotation such that the first coordinate has the largest variance -possible, and each succeeding coordinate in turn has the largest variance possible. The columns of +possible, and each succeeding coordinate, in turn, has the largest variance possible. The columns of the rotation matrix are called principal components. PCA is used widely in dimensionality reduction. `spark.mllib` supports PCA for tall-and-skinny matrices stored in row-oriented format and any Vectors. diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index 7f277543d2e9a..d9dbbab4840a3 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -13,7 +13,7 @@ of the model on some criteria, which depends on the application and its requirem suite of metrics for the purpose of evaluating the performance of machine learning models. Specific machine learning algorithms fall under broader types of machine learning applications like classification, -regression, clustering, etc. Each of these types have well established metrics for performance evaluation and those +regression, clustering, etc. Each of these types have well-established metrics for performance evaluation and those metrics that are currently available in `spark.mllib` are detailed in this section. ## Classification model evaluation diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 8b89296b14cdd..bb29f65c0322f 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -105,7 +105,7 @@ p(w_i | w_j ) = \frac{\exp(u_{w_i}^{\top}v_{w_j})}{\sum_{l=1}^{V} \exp(u_l^{\top \]` where $V$ is the vocabulary size. -The skip-gram model with softmax is expensive because the cost of computing $\log p(w_i | w_j)$ +The skip-gram model with softmax is expensive because the cost of computing $\log p(w_i | w_j)$ is proportional to $V$, which can be easily in order of millions. To speed up training of Word2Vec, we used hierarchical softmax, which reduced the complexity of computing of $\log p(w_i | w_j)$ to $O(\log(V))$ diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index ca84551506b2b..99cab98c690c6 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -9,7 +9,7 @@ displayTitle: Regression - RDD-based API belongs to the family of regression algorithms. Formally isotonic regression is a problem where given a finite set of real numbers `$Y = {y_1, y_2, ..., y_n}$` representing observed responses and `$X = {x_1, x_2, ..., x_n}$` the unknown response values to be fitted -finding a function that minimises +finding a function that minimizes `\begin{equation} f(x) = \sum_{i=1}^n w_i (y_i - x_i)^2 @@ -28,7 +28,7 @@ best fitting the original data points. which uses an approach to [parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). The training input is an RDD of tuples of three double values that represent -label, feature and weight in this order. Additionally IsotonicRegression algorithm has one +label, feature and weight in this order. Additionally, IsotonicRegression algorithm has one optional parameter called $isotonic$ defaulting to true. This argument specifies if the isotonic regression is isotonic (monotonically increasing) or antitonic (monotonically decreasing). diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 034e89e25000e..73f6e206ca543 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -425,7 +425,7 @@ We create our model by initializing the weights to zero and register the streams testing then start the job. Printing predictions alongside true labels lets us easily see the result. -Finally we can save text files with data to the training or testing folders. +Finally, we can save text files with data to the training or testing folders. Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label and `x1,x2,x3` are the features. Anytime a text file is placed in `args(0)` the model will update. Anytime a text file is placed in `args(1)` you will see predictions. diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index 14d76a6e41e23..04758903da89c 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -121,7 +121,7 @@ computation of the sum of the partial results from each worker machine is perfor standard spark routines. If the fraction of points `miniBatchFraction` is set to 1 (default), then the resulting step in -each iteration is exact (sub)gradient descent. In this case there is no randomness and no +each iteration is exact (sub)gradient descent. In this case, there is no randomness and no variance in the used step directions. On the other extreme, if `miniBatchFraction` is chosen very small, such that only a single point is sampled, i.e. `$|S|=$ miniBatchFraction $\cdot n = 1$`, then the algorithm is equivalent to @@ -135,7 +135,7 @@ algorithm in the family of quasi-Newton methods to solve the optimization proble quadratic without evaluating the second partial derivatives of the objective function to construct the Hessian matrix. The Hessian matrix is approximated by previous gradient evaluations, so there is no vertical scalability issue (the number of training features) when computing the Hessian matrix -explicitly in Newton's method. As a result, L-BFGS often achieves rapider convergence compared with +explicitly in Newton's method. As a result, L-BFGS often achieves more rapid convergence compared with other first-order optimization. ### Choosing an Optimization Method diff --git a/docs/monitoring.md b/docs/monitoring.md index 01736c77b0979..6eaf33135744d 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -214,7 +214,7 @@ incomplete attempt or the final successful attempt. 2. Incomplete applications are only updated intermittently. The time between updates is defined by the interval between checks for changed files (`spark.history.fs.update.interval`). -On larger clusters the update interval may be set to large values. +On larger clusters, the update interval may be set to large values. The way to view a running application is actually to view its own web UI. 3. Applications which exited without registering themselves as completed will be listed @@ -422,7 +422,7 @@ configuration property. If, say, users wanted to set the metrics namespace to the name of the application, they can set the `spark.metrics.namespace` property to a value like `${spark.app.name}`. This value is then expanded appropriately by Spark and is used as the root namespace of the metrics system. -Non driver and executor metrics are never prefixed with `spark.app.id`, nor does the +Non-driver and executor metrics are never prefixed with `spark.app.id`, nor does the `spark.metrics.namespace` property have any such affect on such metrics. Spark's metrics are decoupled into different diff --git a/docs/quick-start.md b/docs/quick-start.md index 07c520cbee6be..f1a2096cd4dbd 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -11,11 +11,11 @@ This tutorial provides a quick introduction to using Spark. We will first introd interactive shell (in Python or Scala), then show how to write applications in Java, Scala, and Python. -To follow along with this guide, first download a packaged release of Spark from the +To follow along with this guide, first, download a packaged release of Spark from the [Spark website](http://spark.apache.org/downloads.html). Since we won't be using HDFS, you can download a package for any version of Hadoop. -Note that, before Spark 2.0, the main programming interface of Spark was the Resilient Distributed Dataset (RDD). After Spark 2.0, RDDs are replaced by Dataset, which is strongly-typed like an RDD, but with richer optimizations under the hood. The RDD interface is still supported, and you can get a more complete reference at the [RDD programming guide](rdd-programming-guide.html). However, we highly recommend you to switch to use Dataset, which has better performance than RDD. See the [SQL programming guide](sql-programming-guide.html) to get more information about Dataset. +Note that, before Spark 2.0, the main programming interface of Spark was the Resilient Distributed Dataset (RDD). After Spark 2.0, RDDs are replaced by Dataset, which is strongly-typed like an RDD, but with richer optimizations under the hood. The RDD interface is still supported, and you can get a more detailed reference at the [RDD programming guide](rdd-programming-guide.html). However, we highly recommend you to switch to use Dataset, which has better performance than RDD. See the [SQL programming guide](sql-programming-guide.html) to get more information about Dataset. # Interactive Analysis with the Spark Shell @@ -47,7 +47,7 @@ scala> textFile.first() // First item in this Dataset res1: String = # Apache Spark {% endhighlight %} -Now let's transform this Dataset to a new one. We call `filter` to return a new Dataset with a subset of the items in the file. +Now let's transform this Dataset into a new one. We call `filter` to return a new Dataset with a subset of the items in the file. {% highlight scala %} scala> val linesWithSpark = textFile.filter(line => line.contains("Spark")) diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index 2e29aef7f21a2..b6424090d2fea 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -818,7 +818,7 @@ The behavior of the above code is undefined, and may not work as intended. To ex The variables within the closure sent to each executor are now copies and thus, when **counter** is referenced within the `foreach` function, it's no longer the **counter** on the driver node. There is still a **counter** in the memory of the driver node but this is no longer visible to the executors! The executors only see the copy from the serialized closure. Thus, the final value of **counter** will still be zero since all operations on **counter** were referencing the value within the serialized closure. -In local mode, in some circumstances the `foreach` function will actually execute within the same JVM as the driver and will reference the same original **counter**, and may actually update it. +In local mode, in some circumstances, the `foreach` function will actually execute within the same JVM as the driver and will reference the same original **counter**, and may actually update it. To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#accumulators). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 9c4644947c911..e9e1f3e280609 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -17,7 +17,7 @@ container images and entrypoints.** * A runnable distribution of Spark 2.3 or above. * A running Kubernetes cluster at version >= 1.6 with access configured to it using [kubectl](https://kubernetes.io/docs/user-guide/prereqs/). If you do not already have a working Kubernetes cluster, -you may setup a test cluster on your local machine using +you may set up a test cluster on your local machine using [minikube](https://kubernetes.io/docs/getting-started-guides/minikube/). * We recommend using the latest release of minikube with the DNS addon enabled. * Be aware that the default minikube configuration is not enough for running Spark applications. @@ -221,7 +221,7 @@ that allows driver pods to create pods and services under the default Kubernetes [RBAC](https://kubernetes.io/docs/admin/authorization/rbac/) policies. Sometimes users may need to specify a custom service account that has the right role granted. Spark on Kubernetes supports specifying a custom service account to be used by the driver pod through the configuration property -`spark.kubernetes.authenticate.driver.serviceAccountName=`. For example to make the driver pod +`spark.kubernetes.authenticate.driver.serviceAccountName=`. For example, to make the driver pod use the `spark` service account, a user simply adds the following option to the `spark-submit` command: ``` diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 8e58892e2689f..3c2a1501ca692 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -90,7 +90,7 @@ Depending on your deployment environment you may wish to create a single set of Framework credentials may be specified in a variety of ways depending on your deployment environment and security requirements. The most simple way is to specify the `spark.mesos.principal` and `spark.mesos.secret` values directly in your Spark configuration. Alternatively you may specify these values indirectly by instead specifying `spark.mesos.principal.file` and `spark.mesos.secret.file`, these settings point to files containing the principal and secret. These files must be plaintext files in UTF-8 encoding. Combined with appropriate file ownership and mode/ACLs this provides a more secure way to specify these credentials. -Additionally if you prefer to use environment variables you can specify all of the above via environment variables instead, the environment variable names are simply the configuration settings uppercased with `.` replaced with `_` e.g. `SPARK_MESOS_PRINCIPAL`. +Additionally, if you prefer to use environment variables you can specify all of the above via environment variables instead, the environment variable names are simply the configuration settings uppercased with `.` replaced with `_` e.g. `SPARK_MESOS_PRINCIPAL`. ### Credential Specification Preference Order @@ -225,7 +225,7 @@ details and default values. Executors are brought up eagerly when the application starts, until `spark.cores.max` is reached. If you don't set `spark.cores.max`, the Spark application will consume all resources offered to it by Mesos, -so we of course urge you to set this variable in any sort of +so we, of course, urge you to set this variable in any sort of multi-tenant cluster, including one which runs multiple concurrent Spark applications. @@ -233,14 +233,14 @@ The scheduler will start executors round-robin on the offers Mesos gives it, but there are no spread guarantees, as Mesos does not provide such guarantees on the offer stream. -In this mode spark executors will honor port allocation if such is -provided from the user. Specifically if the user defines +In this mode Spark executors will honor port allocation if such is +provided from the user. Specifically, if the user defines `spark.blockManager.port` in Spark configuration, the mesos scheduler will check the available offers for a valid port range containing the port numbers. If no such range is available it will not launch any task. If no restriction is imposed on port numbers by the user, ephemeral ports are used as usual. This port honouring implementation -implies one task per host if the user defines a port. In the future network +implies one task per host if the user defines a port. In the future network, isolation shall be supported. The benefit of coarse-grained mode is much lower startup overhead, but @@ -486,7 +486,7 @@ See the [configuration page](configuration.html) for information on Spark config
    - + @@ -1797,6 +1798,23 @@ Apart from these, the following properties are also available, and may be useful Lower bound for the number of executors if dynamic allocation is enabled. + + + + + From 2a24c481da3f30b510deb62e5cf21c9463cf250c Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Tue, 24 Apr 2018 09:25:41 -0700 Subject: [PATCH 0680/2461] [SPARK-23975][ML] Allow Clustering to take Arrays of Double as input features ## What changes were proposed in this pull request? - Multiple possible input types is added in validateAndTransformSchema() and computeCost() while checking column type - Add if statement in transform() to support array type as featuresCol - Add the case statement in fit() while selecting columns from dataset These changes will be applied to KMeans first, then to other clustering method ## How was this patch tested? unit test is added Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21081 from ludatabricks/SPARK-23975. --- .../apache/spark/ml/clustering/KMeans.scala | 32 +++++++--- .../apache/spark/ml/util/DatasetUtils.scala | 63 +++++++++++++++++++ .../spark/ml/clustering/KMeansSuite.scala | 38 +++++++++++ 3 files changed, 126 insertions(+), 7 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 1ad157a695a7d..d475c726e6f08 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -33,8 +33,8 @@ import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} -import org.apache.spark.sql.functions.{col, udf} -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.VersionUtils.majorVersion @@ -86,13 +86,24 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") def getInitSteps: Int = $(initSteps) + /** + * Validates the input schema. + * @param schema input schema + */ + private[clustering] def validateSchema(schema: StructType): Unit = { + val typeCandidates = List( new VectorUDT, + new ArrayType(DoubleType, false), + new ArrayType(FloatType, false)) + + SchemaUtils.checkColumnTypes(schema, $(featuresCol), typeCandidates) + } /** * Validates and transforms the input schema. * @param schema input schema * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + validateSchema(schema) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } } @@ -125,8 +136,11 @@ class KMeansModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) + val predictUDF = udf((vector: Vector) => predict(vector)) - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + + dataset.withColumn($(predictionCol), + predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } @Since("1.5.0") @@ -146,8 +160,10 @@ class KMeansModel private[ml] ( // TODO: Replace the temp fix when we have proper evaluators defined for clustering. @Since("2.0.0") def computeCost(dataset: Dataset[_]): Double = { - SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) - val data: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { + validateSchema(dataset.schema) + + val data: RDD[OldVector] = dataset.select(DatasetUtils.columnToVector(dataset, getFeaturesCol)) + .rdd.map { case Row(point: Vector) => OldVectors.fromML(point) } parentModel.computeCost(data) @@ -335,7 +351,9 @@ class KMeans @Since("1.5.0") ( transformSchema(dataset.schema, logging = true) val handlePersistence = dataset.storageLevel == StorageLevel.NONE - val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { + val instances: RDD[OldVector] = dataset.select( + DatasetUtils.columnToVector(dataset, getFeaturesCol)) + .rdd.map { case Row(point: Vector) => OldVectors.fromML(point) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala new file mode 100644 index 0000000000000..52619cb65489a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.ml.linalg.{Vectors, VectorUDT} +import org.apache.spark.sql.{Column, Dataset} +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType} + + +private[spark] object DatasetUtils { + + /** + * Cast a column in a Dataset to Vector type. + * + * The supported data types of the input column are + * - Vector + * - float/double type Array. + * + * Note: The returned column does not have Metadata. + * + * @param dataset input DataFrame + * @param colName column name. + * @return Vector column + */ + def columnToVector(dataset: Dataset[_], colName: String): Column = { + val columnDataType = dataset.schema(colName).dataType + columnDataType match { + case _: VectorUDT => col(colName) + case fdt: ArrayType => + val transferUDF = fdt.elementType match { + case _: FloatType => udf(f = (vector: Seq[Float]) => { + val inputArray = Array.fill[Double](vector.size)(0.0) + vector.indices.foreach(idx => inputArray(idx) = vector(idx).toDouble) + Vectors.dense(inputArray) + }) + case _: DoubleType => udf((vector: Seq[Double]) => { + Vectors.dense(vector.toArray) + }) + case other => + throw new IllegalArgumentException(s"Array[$other] column cannot be cast to Vector") + } + transferUDF(col(colName)) + case other => + throw new IllegalArgumentException(s"$other column cannot be cast to Vector") + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 77c9d482d95b6..5445ebe5c95eb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -30,6 +30,8 @@ import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType} private[clustering] case class TestRow(features: Vector) @@ -199,6 +201,42 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(e.getCause.getMessage.contains("Cosine distance is not defined")) } + test("KMean with Array input") { + val featuresColNameD = "array_double_features" + val featuresColNameF = "array_float_features" + + val doubleUDF = udf { (features: Vector) => + val featureArray = Array.fill[Double](features.size)(0.0) + features.foreachActive((idx, value) => featureArray(idx) = value.toFloat) + featureArray + } + val floatUDF = udf { (features: Vector) => + val featureArray = Array.fill[Float](features.size)(0.0f) + features.foreachActive((idx, value) => featureArray(idx) = value.toFloat) + featureArray + } + + val newdatasetD = dataset.withColumn(featuresColNameD, doubleUDF(col("features"))) + .drop("features") + val newdatasetF = dataset.withColumn(featuresColNameF, floatUDF(col("features"))) + .drop("features") + assert(newdatasetD.schema(featuresColNameD).dataType.equals(new ArrayType(DoubleType, false))) + assert(newdatasetF.schema(featuresColNameF).dataType.equals(new ArrayType(FloatType, false))) + + val kmeansD = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameD).setSeed(1) + val kmeansF = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameF).setSeed(1) + val modelD = kmeansD.fit(newdatasetD) + val modelF = kmeansF.fit(newdatasetF) + val transformedD = modelD.transform(newdatasetD) + val transformedF = modelF.transform(newdatasetF) + + val predictDifference = transformedD.select("prediction") + .except(transformedF.select("prediction")) + assert(predictDifference.count() == 0) + assert(modelD.computeCost(newdatasetD) == modelF.computeCost(newdatasetF) ) + } + + test("read/write") { def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = { assert(model.clusterCenters === model2.clusterCenters) From ce7ba2e98e0a3b038e881c271b5905058c43155b Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Tue, 24 Apr 2018 09:57:09 -0700 Subject: [PATCH 0681/2461] [SPARK-23807][BUILD] Add Hadoop 3.1 profile with relevant POM fix ups MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? 1. Adds a `hadoop-3.1` profile build depending on the hadoop-3.1 artifacts. 1. In the hadoop-cloud module, adds an explicit hadoop-3.1 profile which switches from explicitly pulling in cloud connectors (hadoop-openstack, hadoop-aws, hadoop-azure) to depending on the hadoop-cloudstorage POM artifact, which pulls these in, has pre-excluded things like hadoop-common, and stays up to date with new connectors (hadoop-azuredatalake, hadoop-allyun). Goal: it becomes the Hadoop projects homework of keeping this clean, and the spark project doesn't need to handle new hadoop releases adding more dependencies. 1. the hadoop-cloud/hadoop-3.1 profile also declares support for jetty-ajax and jetty-util to ensure that these jars get into the distribution jar directory when needed by unshaded libraries. 1. Increases the curator and zookeeper versions to match those in hadoop-3, fixing spark core to build in sbt with the hadoop-3 dependencies. ## How was this patch tested? * Everything this has been built and tested against both ASF Hadoop branch-3.1 and hadoop trunk. * spark-shell was used to create connectors to all the stores and verify that file IO could take place. The spark hive-1.2.1 JAR has problems here, as it's version check logic fails for Hadoop versions > 2. This can be avoided with either of * The hadoop JARs built to declare their version as Hadoop 2.11 `mvn install -DskipTests -DskipShade -Ddeclared.hadoop.version=2.11` . This is safe for local test runs, not for deployment (HDFS is very strict about cross-version deployment). * A modified version of spark hive whose version check switch statement is happy with hadoop 3. I've done both, with maven and SBT. Three issues surfaced 1. A spark-core test failure —fixed in SPARK-23787. 1. SBT only: Zookeeper not being found in spark-core. Somehow curator 2.12.0 triggers some slightly different dependency resolution logic from previous versions, and Ivy was missing zookeeper.jar entirely. This patch adds the explicit declaration for all spark profiles, setting the ZK version = 3.4.9 for hadoop-3.1 1. Marking jetty-utils as provided in spark was stopping hadoop-azure from being able to instantiate the azure wasb:// client; it was using jetty-util-ajax, which could then not find a class in jetty-util. Author: Steve Loughran Closes #20923 from steveloughran/cloud/SPARK-23807-hadoop-31. --- assembly/pom.xml | 8 ++ core/pom.xml | 6 + dev/deps/spark-deps-hadoop-3.1 | 221 +++++++++++++++++++++++++++++++++ dev/test-dependencies.sh | 1 + hadoop-cloud/pom.xml | 83 ++++++++++++- pom.xml | 9 ++ 6 files changed, 327 insertions(+), 1 deletion(-) create mode 100644 dev/deps/spark-deps-hadoop-3.1 diff --git a/assembly/pom.xml b/assembly/pom.xml index a207dae5a74ff..9608c96fd5369 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -254,6 +254,14 @@ spark-hadoop-cloud_${scala.binary.version} ${project.version} + + + org.eclipse.jetty + jetty-util + ${hadoop.deps.scope} + diff --git a/core/pom.xml b/core/pom.xml index 9258a856028a0..093a9869b6dd7 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -95,6 +95,12 @@ org.apache.curator curator-recipes + + + org.apache.zookeeper + zookeeper + diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 new file mode 100644 index 0000000000000..97ad65a4096cb --- /dev/null +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -0,0 +1,221 @@ +HikariCP-java7-2.4.12.jar +JavaEWAH-0.3.2.jar +RoaringBitmap-0.5.11.jar +ST4-4.0.4.jar +accessors-smart-1.2.jar +activation-1.1.1.jar +aircompressor-0.8.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar +antlr4-runtime-4.7.jar +aopalliance-1.0.jar +aopalliance-repackaged-2.4.0-b34.jar +apache-log4j-extras-1.2.17.jar +arpack_combined_all-0.1.jar +arrow-format-0.8.0.jar +arrow-memory-0.8.0.jar +arrow-vector-0.8.0.jar +automaton-1.11-8.jar +avro-1.7.7.jar +avro-ipc-1.7.7.jar +avro-mapred-1.7.7-hadoop2.jar +base64-2.3.8.jar +bcprov-jdk15on-1.58.jar +bonecp-0.8.0.RELEASE.jar +breeze-macros_2.11-0.13.2.jar +breeze_2.11-0.13.2.jar +calcite-avatica-1.2.0-incubating.jar +calcite-core-1.2.0-incubating.jar +calcite-linq4j-1.2.0-incubating.jar +chill-java-0.8.4.jar +chill_2.11-0.8.4.jar +commons-beanutils-1.9.3.jar +commons-cli-1.2.jar +commons-codec-1.10.jar +commons-collections-3.2.2.jar +commons-compiler-3.0.8.jar +commons-compress-1.4.1.jar +commons-configuration2-2.1.1.jar +commons-crypto-1.0.0.jar +commons-daemon-1.0.13.jar +commons-dbcp-1.4.jar +commons-httpclient-3.1.jar +commons-io-2.4.jar +commons-lang-2.6.jar +commons-lang3-3.5.jar +commons-logging-1.1.3.jar +commons-math3-3.4.1.jar +commons-net-3.1.jar +commons-pool-1.5.4.jar +compress-lzf-1.0.3.jar +core-1.1.2.jar +curator-client-2.12.0.jar +curator-framework-2.12.0.jar +curator-recipes-2.12.0.jar +datanucleus-api-jdo-3.2.6.jar +datanucleus-core-3.2.10.jar +datanucleus-rdbms-3.2.9.jar +derby-10.12.1.1.jar +dnsjava-2.1.7.jar +ehcache-3.3.1.jar +eigenbase-properties-1.1.5.jar +flatbuffers-1.2.0-3f79e055.jar +generex-1.0.1.jar +geronimo-jcache_1.0_spec-1.0-alpha-1.jar +gson-2.2.4.jar +guava-14.0.1.jar +guice-4.0.jar +guice-servlet-4.0.jar +hadoop-annotations-3.1.0.jar +hadoop-auth-3.1.0.jar +hadoop-client-3.1.0.jar +hadoop-common-3.1.0.jar +hadoop-hdfs-client-3.1.0.jar +hadoop-mapreduce-client-common-3.1.0.jar +hadoop-mapreduce-client-core-3.1.0.jar +hadoop-mapreduce-client-jobclient-3.1.0.jar +hadoop-yarn-api-3.1.0.jar +hadoop-yarn-client-3.1.0.jar +hadoop-yarn-common-3.1.0.jar +hadoop-yarn-registry-3.1.0.jar +hadoop-yarn-server-common-3.1.0.jar +hadoop-yarn-server-web-proxy-3.1.0.jar +hk2-api-2.4.0-b34.jar +hk2-locator-2.4.0-b34.jar +hk2-utils-2.4.0-b34.jar +hppc-0.7.2.jar +htrace-core4-4.1.0-incubating.jar +httpclient-4.5.4.jar +httpcore-4.4.8.jar +ivy-2.4.0.jar +jackson-annotations-2.6.7.jar +jackson-core-2.6.7.jar +jackson-core-asl-1.9.13.jar +jackson-databind-2.6.7.1.jar +jackson-dataformat-yaml-2.6.7.jar +jackson-jaxrs-base-2.7.8.jar +jackson-jaxrs-json-provider-2.7.8.jar +jackson-mapper-asl-1.9.13.jar +jackson-module-jaxb-annotations-2.6.7.jar +jackson-module-paranamer-2.7.9.jar +jackson-module-scala_2.11-2.6.7.1.jar +janino-3.0.8.jar +java-xmlbuilder-1.1.jar +javassist-3.18.1-GA.jar +javax.annotation-api-1.2.jar +javax.inject-1.jar +javax.inject-2.4.0-b34.jar +javax.servlet-api-3.1.0.jar +javax.ws.rs-api-2.0.1.jar +javolution-5.5.1.jar +jaxb-api-2.2.11.jar +jcip-annotations-1.0-1.jar +jcl-over-slf4j-1.7.16.jar +jdo-api-3.0.1.jar +jersey-client-2.22.2.jar +jersey-common-2.22.2.jar +jersey-container-servlet-2.22.2.jar +jersey-container-servlet-core-2.22.2.jar +jersey-guava-2.22.2.jar +jersey-media-jaxb-2.22.2.jar +jersey-server-2.22.2.jar +jets3t-0.9.4.jar +jetty-webapp-9.3.20.v20170531.jar +jetty-xml-9.3.20.v20170531.jar +jline-2.12.1.jar +joda-time-2.9.3.jar +jodd-core-3.5.2.jar +jpam-1.1.jar +json-smart-2.3.jar +json4s-ast_2.11-3.5.3.jar +json4s-core_2.11-3.5.3.jar +json4s-jackson_2.11-3.5.3.jar +json4s-scalap_2.11-3.5.3.jar +jsp-api-2.1.jar +jsr305-1.3.9.jar +jta-1.1.jar +jtransforms-2.4.0.jar +jul-to-slf4j-1.7.16.jar +kerb-admin-1.0.1.jar +kerb-client-1.0.1.jar +kerb-common-1.0.1.jar +kerb-core-1.0.1.jar +kerb-crypto-1.0.1.jar +kerb-identity-1.0.1.jar +kerb-server-1.0.1.jar +kerb-simplekdc-1.0.1.jar +kerb-util-1.0.1.jar +kerby-asn1-1.0.1.jar +kerby-config-1.0.1.jar +kerby-pkix-1.0.1.jar +kerby-util-1.0.1.jar +kerby-xdr-1.0.1.jar +kryo-shaded-3.0.3.jar +kubernetes-client-3.0.0.jar +kubernetes-model-2.0.0.jar +leveldbjni-all-1.8.jar +libfb303-0.9.3.jar +libthrift-0.9.3.jar +log4j-1.2.17.jar +logging-interceptor-3.8.1.jar +lz4-java-1.4.0.jar +machinist_2.11-0.6.1.jar +macro-compat_2.11-1.1.1.jar +mesos-1.4.0-shaded-protobuf.jar +metrics-core-3.1.5.jar +metrics-graphite-3.1.5.jar +metrics-json-3.1.5.jar +metrics-jvm-3.1.5.jar +minlog-1.3.0.jar +mssql-jdbc-6.2.1.jre7.jar +netty-3.9.9.Final.jar +netty-all-4.1.17.Final.jar +nimbus-jose-jwt-4.41.1.jar +objenesis-2.1.jar +okhttp-2.7.5.jar +okhttp-3.8.1.jar +okio-1.13.0.jar +opencsv-2.3.jar +orc-core-1.4.3-nohive.jar +orc-mapreduce-1.4.3-nohive.jar +oro-2.0.8.jar +osgi-resource-locator-1.0.1.jar +paranamer-2.8.jar +parquet-column-1.8.2.jar +parquet-common-1.8.2.jar +parquet-encoding-1.8.2.jar +parquet-format-2.3.1.jar +parquet-hadoop-1.8.2.jar +parquet-hadoop-bundle-1.6.0.jar +parquet-jackson-1.8.2.jar +protobuf-java-2.5.0.jar +py4j-0.10.6.jar +pyrolite-4.13.jar +re2j-1.1.jar +scala-compiler-2.11.8.jar +scala-library-2.11.8.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.8.jar +scala-xml_2.11-1.0.5.jar +shapeless_2.11-2.3.2.jar +slf4j-api-1.7.16.jar +slf4j-log4j12-1.7.16.jar +snakeyaml-1.15.jar +snappy-0.2.jar +snappy-java-1.1.7.1.jar +spire-macros_2.11-0.13.0.jar +spire_2.11-0.13.0.jar +stax-api-1.0.1.jar +stax2-api-3.1.4.jar +stream-2.7.0.jar +stringtemplate-3.2.1.jar +super-csv-2.2.0.jar +token-provider-1.0.1.jar +univocity-parsers-2.5.9.jar +validation-api-1.1.0.Final.jar +woodstox-core-5.0.3.jar +xbean-asm5-shaded-4.4.jar +xz-1.0.jar +zjsonpatch-0.3.0.jar +zookeeper-3.4.9.jar +zstd-jni-1.3.2-2.jar diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 3bf7618e1ea96..2fbd6b5e98f7f 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -34,6 +34,7 @@ MVN="build/mvn" HADOOP_PROFILES=( hadoop-2.6 hadoop-2.7 + hadoop-3.1 ) # We'll switch the version to a temp. one, publish POMs using that new version, then switch back to diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index 8e424b1c50236..2c39a7df0146e 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -38,7 +38,32 @@ hadoop-cloud + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.hadoop + hadoop-client + ${hadoop.version} + provided + + + + hadoop-3.1 + + + + org.apache.hadoop + hadoop-cloud-storage + ${hadoop.version} + ${hadoop.deps.scope} + + + org.apache.hadoop + hadoop-common + + + org.codehaus.jackson + jackson-mapper-asl + + + com.fasterxml.jackson.core + jackson-core + + + com.google.guava + guava + + + + + + org.eclipse.jetty + jetty-util + ${hadoop.deps.scope} + + + org.eclipse.jetty + jetty-util-ajax + ${jetty.version} + ${hadoop.deps.scope} + + + + diff --git a/pom.xml b/pom.xml index 0a711f287a53f..88e77ff874748 100644 --- a/pom.xml +++ b/pom.xml @@ -2671,6 +2671,15 @@ + + hadoop-3.1 + + 3.1.0 + 2.12.0 + 3.4.9 + + + yarn From 83013752e3cfcbc3edeef249439ac20b143eeabc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 24 Apr 2018 10:40:25 -0700 Subject: [PATCH 0682/2461] [SPARK-23455][ML] Default Params in ML should be saved separately in metadata ## What changes were proposed in this pull request? We save ML's user-supplied params and default params as one entity in metadata. During loading the saved models, we set all the loaded params into created ML model instances as user-supplied params. It causes some problems, e.g., if we strictly disallow some params to be set at the same time, a default param can fail the param check because it is treated as user-supplied param after loading. The loaded default params should not be set as user-supplied params. We should save ML default params separately in metadata. For backward compatibility, when loading metadata, if it is a metadata file from previous Spark, we shouldn't raise error if we can't find the default param field. ## How was this patch tested? Pass existing tests and added tests. Author: Liang-Chi Hsieh Closes #20633 from viirya/save-ml-default-params. --- .../DecisionTreeClassifier.scala | 2 +- .../ml/classification/GBTClassifier.scala | 4 +- .../spark/ml/classification/LinearSVC.scala | 2 +- .../classification/LogisticRegression.scala | 2 +- .../MultilayerPerceptronClassifier.scala | 2 +- .../spark/ml/classification/NaiveBayes.scala | 2 +- .../spark/ml/classification/OneVsRest.scala | 4 +- .../RandomForestClassifier.scala | 4 +- .../spark/ml/clustering/BisectingKMeans.scala | 2 +- .../spark/ml/clustering/GaussianMixture.scala | 2 +- .../apache/spark/ml/clustering/KMeans.scala | 2 +- .../org/apache/spark/ml/clustering/LDA.scala | 4 +- .../feature/BucketedRandomProjectionLSH.scala | 2 +- .../apache/spark/ml/feature/Bucketizer.scala | 24 ---- .../spark/ml/feature/ChiSqSelector.scala | 2 +- .../spark/ml/feature/CountVectorizer.scala | 2 +- .../org/apache/spark/ml/feature/IDF.scala | 2 +- .../org/apache/spark/ml/feature/Imputer.scala | 2 +- .../spark/ml/feature/MaxAbsScaler.scala | 2 +- .../apache/spark/ml/feature/MinHashLSH.scala | 2 +- .../spark/ml/feature/MinMaxScaler.scala | 2 +- .../ml/feature/OneHotEncoderEstimator.scala | 2 +- .../org/apache/spark/ml/feature/PCA.scala | 2 +- .../ml/feature/QuantileDiscretizer.scala | 24 ---- .../apache/spark/ml/feature/RFormula.scala | 6 +- .../spark/ml/feature/StandardScaler.scala | 2 +- .../spark/ml/feature/StringIndexer.scala | 2 +- .../spark/ml/feature/VectorIndexer.scala | 2 +- .../apache/spark/ml/feature/Word2Vec.scala | 2 +- .../org/apache/spark/ml/fpm/FPGrowth.scala | 2 +- .../org/apache/spark/ml/param/params.scala | 13 +- .../apache/spark/ml/recommendation/ALS.scala | 2 +- .../ml/regression/AFTSurvivalRegression.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../spark/ml/regression/GBTRegressor.scala | 4 +- .../GeneralizedLinearRegression.scala | 2 +- .../ml/regression/IsotonicRegression.scala | 2 +- .../ml/regression/LinearRegression.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 4 +- .../spark/ml/tuning/CrossValidator.scala | 6 +- .../ml/tuning/TrainValidationSplit.scala | 6 +- .../org/apache/spark/ml/util/ReadWrite.scala | 130 ++++++++++++------ .../spark/ml/util/DefaultReadWriteTest.scala | 73 +++++++++- project/MimaExcludes.scala | 6 + 44 files changed, 223 insertions(+), 147 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 771cd4fe91dcf..57797d1cc4978 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -279,7 +279,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica val root = loadTreeNodes(path, metadata, sparkSession, isClassification = true) val model = new DecisionTreeClassificationModel(metadata.uid, root.asInstanceOf[ClassificationNode], numFeatures, numClasses) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index c0255103bc313..0aa24f0a3cfcc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -379,14 +379,14 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { case (treeMetadata, root) => val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 8f950cd28c3aa..80c537e1e0eb2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -377,7 +377,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] { val Row(coefficients: Vector, intercept: Double) = data.select("coefficients", "intercept").head() val model = new LinearSVCModel(metadata.uid, coefficients, intercept) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index ee4b01058c75c..e426263910f26 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1270,7 +1270,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { numClasses, isMultinomial) } - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index af2e4699924e5..57ba47e596a97 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -388,7 +388,7 @@ object MultilayerPerceptronClassificationModel val weights = data.getAs[Vector](1) val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 0293e03d47435..45fb585ed2262 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -407,7 +407,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { .head() val model = new NaiveBayesModel(metadata.uid, pi, theta) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 5348d882cfd67..7df53a6b8ad10 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -289,7 +289,7 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] { DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc) } val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models) - DefaultParamsReader.getAndSetParams(ovrModel, metadata) + metadata.getAndSetParams(ovrModel) ovrModel.set("classifier", classifier) ovrModel } @@ -484,7 +484,7 @@ object OneVsRest extends MLReadable[OneVsRest] { override def load(path: String): OneVsRest = { val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className) val ovr = new OneVsRest(metadata.uid) - DefaultParamsReader.getAndSetParams(ovr, metadata) + metadata.getAndSetParams(ovr) ovr.setClassifier(classifier) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index bb972e9706fc1..f1ef26a07d3f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -319,14 +319,14 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica case (treeMetadata, root) => val tree = new DecisionTreeClassificationModel(treeMetadata.uid, root.asInstanceOf[ClassificationNode], numFeatures, numClasses) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } require(numTrees == trees.length, s"RandomForestClassificationModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") val model = new RandomForestClassificationModel(metadata.uid, trees, numFeatures, numClasses) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index f7c422dc0faea..addc12ac52ec1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -193,7 +193,7 @@ object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] { val dataPath = new Path(path, "data").toString val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath) val model = new BisectingKMeansModel(metadata.uid, mllibModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index f19ad7a5a6938..b5804900c0358 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -233,7 +233,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { } val model = new GaussianMixtureModel(metadata.uid, weights, gaussians) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index d475c726e6f08..de61c9c089a36 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -280,7 +280,7 @@ object KMeansModel extends MLReadable[KMeansModel] { sparkSession.read.parquet(dataPath).as[OldData].head().clusterCenters } val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 4bab670cc159f..47077230fac0a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -366,7 +366,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM private object LDAParams { /** - * Equivalent to [[DefaultParamsReader.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]] + * Equivalent to [[Metadata.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]] * formats saved with Spark 1.6, which differ from the formats in Spark 2.0+. * * @param model [[LDA]] or [[LDAModel]] instance. This instance will be modified with @@ -391,7 +391,7 @@ private object LDAParams { s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") } case _ => // 2.0+ - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala index 41eaaf9679914..a906e954fecd5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala @@ -238,7 +238,7 @@ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProject val model = new BucketedRandomProjectionLSHModel(metadata.uid, randUnitVectors.rowIter.toArray) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index f49c410cbcfe2..f99649f7fa164 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -217,8 +217,6 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String override def copy(extra: ParamMap): Bucketizer = { defaultCopy[Bucketizer](extra).setParent(parent) } - - override def write: MLWriter = new Bucketizer.BucketizerWriter(this) } @Since("1.6.0") @@ -296,28 +294,6 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { } } - - private[Bucketizer] class BucketizerWriter(instance: Bucketizer) extends MLWriter { - - override protected def saveImpl(path: String): Unit = { - // SPARK-23377: The default params will be saved and loaded as user-supplied params. - // Once `inputCols` is set, the default value of `outputCol` param causes the error - // when checking exclusive params. As a temporary to fix it, we skip the default value - // of `outputCol` if `inputCols` is set when saving the metadata. - // TODO: If we modify the persistence mechanism later to better handle default params, - // we can get rid of this. - var paramWithoutOutputCol: Option[JValue] = None - if (instance.isSet(instance.inputCols)) { - val params = instance.extractParamMap().toSeq - val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) => - p.name -> parse(p.jsonEncode(v)) - }.toList - paramWithoutOutputCol = Some(render(jsonParams)) - } - DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol) - } - } - @Since("1.6.0") override def load(path: String): Bucketizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 16abc4949dea3..dbfb199ccd58f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -334,7 +334,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] { val selectedFeatures = data.getAs[Seq[Int]](0).toArray val oldModel = new feature.ChiSqSelectorModel(selectedFeatures) val model = new ChiSqSelectorModel(metadata.uid, oldModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 9e0ed437e7bfc..10c48c3f52085 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -363,7 +363,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] { .head() val vocabulary = data.getAs[Seq[String]](0).toArray val model = new CountVectorizerModel(metadata.uid, vocabulary) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 46a0730f5ddb8..58897cca4e5c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -182,7 +182,7 @@ object IDFModel extends MLReadable[IDFModel] { .select("idf") .head() val model = new IDFModel(metadata.uid, new feature.IDFModel(OldVectors.fromML(idf))) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 730ee9fc08db8..1c074e204ad99 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -262,7 +262,7 @@ object ImputerModel extends MLReadable[ImputerModel] { val dataPath = new Path(path, "data").toString val surrogateDF = sqlContext.read.parquet(dataPath) val model = new ImputerModel(metadata.uid, surrogateDF) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala index 85f9732f79f67..90eceb0d61b40 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -172,7 +172,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] { .select("maxAbs") .head() val model = new MaxAbsScalerModel(metadata.uid, maxAbs) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index 556848e45532d..a67a3b0abbc1f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -205,7 +205,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] { .map(tuple => (tuple(0), tuple(1))).toArray val model = new MinHashLSHModel(metadata.uid, randCoefficients) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index f648deced54cd..2e0ae4af66f06 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -243,7 +243,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { .select("originalMin", "originalMax") .head() val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index bd1e3426c8780..4a44f3186538d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -386,7 +386,7 @@ object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] { .head() val categorySizes = data.getAs[Seq[Int]](0).toArray val model = new OneHotEncoderModel(metadata.uid, categorySizes) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 4143d864d7930..8172491a517d1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -220,7 +220,7 @@ object PCAModel extends MLReadable[PCAModel] { new PCAModel(metadata.uid, pc.asML, Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]) } - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 3b4c25478fb1d..56e2c543d100a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -253,35 +253,11 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("1.6.0") override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra) - - override def write: MLWriter = new QuantileDiscretizer.QuantileDiscretizerWriter(this) } @Since("1.6.0") object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { - private[QuantileDiscretizer] - class QuantileDiscretizerWriter(instance: QuantileDiscretizer) extends MLWriter { - - override protected def saveImpl(path: String): Unit = { - // SPARK-23377: The default params will be saved and loaded as user-supplied params. - // Once `inputCols` is set, the default value of `outputCol` param causes the error - // when checking exclusive params. As a temporary to fix it, we skip the default value - // of `outputCol` if `inputCols` is set when saving the metadata. - // TODO: If we modify the persistence mechanism later to better handle default params, - // we can get rid of this. - var paramWithoutOutputCol: Option[JValue] = None - if (instance.isSet(instance.inputCols)) { - val params = instance.extractParamMap().toSeq - val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) => - p.name -> parse(p.jsonEncode(v)) - }.toList - paramWithoutOutputCol = Some(render(jsonParams)) - } - DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol) - } - } - @Since("1.6.0") override def load(path: String): QuantileDiscretizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index e214765e3307f..55e595eee6ffb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -446,7 +446,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] { val model = new RFormulaModel(metadata.uid, resolvedRFormula, pipelineModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } @@ -510,7 +510,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] { val columnsToPrune = data.getAs[Seq[String]](0).toSet val pruner = new ColumnPruner(metadata.uid, columnsToPrune) - DefaultParamsReader.getAndSetParams(pruner, metadata) + metadata.getAndSetParams(pruner) pruner } } @@ -602,7 +602,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite val prefixesToRewrite = data.getAs[Map[String, String]](1) val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite) - DefaultParamsReader.getAndSetParams(rewriter, metadata) + metadata.getAndSetParams(rewriter) rewriter } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 8f125d8fd51d2..91b0707dec3f3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -212,7 +212,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { .select("std", "mean") .head() val model = new StandardScalerModel(metadata.uid, std, mean) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 67cdb097217a2..a833d8b270cf1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -315,7 +315,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { .head() val labels = data.getAs[Seq[String]](0).toArray val model = new StringIndexerModel(metadata.uid, labels) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index e6ec4e2e36ff0..0e7396a621dbd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -537,7 +537,7 @@ object VectorIndexerModel extends MLReadable[VectorIndexerModel] { val numFeatures = data.getAs[Int](0) val categoryMaps = data.getAs[Map[Int, Map[Double, Int]]](1) val model = new VectorIndexerModel(metadata.uid, numFeatures, categoryMaps) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index fe3306e1e50d6..fc9996d69ba72 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -410,7 +410,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] { } val model = new Word2VecModel(metadata.uid, oldModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 3d041fc80eb7f..0bf405d9abf9d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -335,7 +335,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] { val dataPath = new Path(path, "data").toString val frequentItems = sparkSession.read.parquet(dataPath) val model = new FPGrowthModel(metadata.uid, frequentItems) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 9a83a5882ce29..e6c347ed17c15 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -865,10 +865,10 @@ trait Params extends Identifiable with Serializable { } /** Internal param map for user-supplied values. */ - private val paramMap: ParamMap = ParamMap.empty + private[ml] val paramMap: ParamMap = ParamMap.empty /** Internal param map for default values. */ - private val defaultParamMap: ParamMap = ParamMap.empty + private[ml] val defaultParamMap: ParamMap = ParamMap.empty /** Validates that the input param belongs to this instance. */ private def shouldOwn(param: Param[_]): Unit = { @@ -905,6 +905,15 @@ trait Params extends Identifiable with Serializable { } } +private[ml] object Params { + /** + * Sets a default param value for a `Params`. + */ + private[ml] final def setDefault[T](params: Params, param: Param[T], value: T): Unit = { + params.defaultParamMap.put(param -> value) + } +} + /** * :: DeveloperApi :: * Java-friendly wrapper for [[Params]]. diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 81a8f50761e0e..a23f9552b9e5f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -529,7 +529,7 @@ object ALSModel extends MLReadable[ALSModel] { val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 4b46c3831d75f..7c6ec2a8419fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -423,7 +423,7 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] .head() val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 5cef5c9f21f1e..8bcf0793a64c1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -282,7 +282,7 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode val root = loadTreeNodes(path, metadata, sparkSession, isClassification = false) val model = new DecisionTreeRegressionModel(metadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 834aaa0e362d1..8598e808c4946 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -311,7 +311,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { case (treeMetadata, root) => val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } @@ -319,7 +319,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { s" trees based on metadata but found ${trees.length} trees.") val model = new GBTRegressionModel(metadata.uid, trees, treeWeights, numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 4c3f1431d5077..e030a40cb19be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -1146,7 +1146,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr val model = new GeneralizedLinearRegressionModel(metadata.uid, coefficients, intercept) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index 8faab52ea474b..b046897ab2b7e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -308,7 +308,7 @@ object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { val model = new IsotonicRegressionModel( metadata.uid, new MLlibIsotonicRegressionModel(boundaries, predictions, isotonic)) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 9cdd3a051e719..f1d9a4453deaa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -799,7 +799,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { new LinearRegressionModel(metadata.uid, coefficients, intercept, scale) } - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 7f77398ba2a22..4509f85aafd12 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -276,14 +276,14 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } require(numTrees == trees.length, s"RandomForestRegressionModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") val model = new RandomForestRegressionModel(metadata.uid, trees, numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index c2826dcc08634..5e916cc4a9fdd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -234,8 +234,7 @@ object CrossValidator extends MLReadable[CrossValidator] { .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) - DefaultParamsReader.getAndSetParams(cv, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(cv, skipParams = Option(List("estimatorParamMaps"))) cv } } @@ -424,8 +423,7 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) - DefaultParamsReader.getAndSetParams(model, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(model, skipParams = Option(List("estimatorParamMaps"))) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 8d1b9a8ddab59..13369c4df7180 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -228,8 +228,7 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] { .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) - DefaultParamsReader.getAndSetParams(tvs, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(tvs, skipParams = Option(List("estimatorParamMaps"))) tvs } } @@ -407,8 +406,7 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) - DefaultParamsReader.getAndSetParams(model, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(model, skipParams = Option(List("estimatorParamMaps"))) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 7edcd498678cc..72a60e04360d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -39,7 +39,7 @@ import org.apache.spark.ml.feature.RFormulaModel import org.apache.spark.ml.param.{ParamPair, Params} import org.apache.spark.ml.tuning.ValidatorParams import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.util.Utils +import org.apache.spark.util.{Utils, VersionUtils} /** * Trait for `MLWriter` and `MLReader`. @@ -421,6 +421,7 @@ private[ml] object DefaultParamsWriter { * - timestamp * - sparkVersion * - uid + * - defaultParamMap * - paramMap * - (optionally, extra metadata) * @@ -453,15 +454,20 @@ private[ml] object DefaultParamsWriter { paramMap: Option[JValue] = None): String = { val uid = instance.uid val cls = instance.getClass.getName - val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] + val params = instance.paramMap.toSeq + val defaultParams = instance.defaultParamMap.toSeq val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) }.toList)) + val jsonDefaultParams = render(defaultParams.map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList) val basicMetadata = ("class" -> cls) ~ ("timestamp" -> System.currentTimeMillis()) ~ ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ - ("paramMap" -> jsonParams) + ("paramMap" -> jsonParams) ~ + ("defaultParamMap" -> jsonDefaultParams) val metadata = extraMetadata match { case Some(jObject) => basicMetadata ~ jObject @@ -488,7 +494,7 @@ private[ml] class DefaultParamsReader[T] extends MLReader[T] { val cls = Utils.classForName(metadata.className) val instance = cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params] - DefaultParamsReader.getAndSetParams(instance, metadata) + metadata.getAndSetParams(instance) instance.asInstanceOf[T] } } @@ -499,6 +505,8 @@ private[ml] object DefaultParamsReader { * All info from metadata file. * * @param params paramMap, as a `JValue` + * @param defaultParams defaultParamMap, as a `JValue`. For metadata file prior to Spark 2.4, + * this is `JNothing`. * @param metadata All metadata, including the other fields * @param metadataJson Full metadata file String (for debugging) */ @@ -508,27 +516,90 @@ private[ml] object DefaultParamsReader { timestamp: Long, sparkVersion: String, params: JValue, + defaultParams: JValue, metadata: JValue, metadataJson: String) { + + private def getValueFromParams(params: JValue): Seq[(String, JValue)] = { + params match { + case JObject(pairs) => pairs + case _ => + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: $metadataJson.") + } + } + /** * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name. * This can be useful for getting a Param value before an instance of `Params` - * is available. + * is available. This will look up `params` first, if not existing then looking up + * `defaultParams`. */ def getParamValue(paramName: String): JValue = { implicit val format = DefaultFormats - params match { + + // Looking up for `params` first. + var pairs = getValueFromParams(params) + var foundPairs = pairs.filter { case (pName, jsonValue) => + pName == paramName + } + if (foundPairs.length == 0) { + // Looking up for `defaultParams` then. + pairs = getValueFromParams(defaultParams) + foundPairs = pairs.filter { case (pName, jsonValue) => + pName == paramName + } + } + assert(foundPairs.length == 1, s"Expected one instance of Param '$paramName' but found" + + s" ${foundPairs.length} in JSON Params: " + pairs.map(_.toString).mkString(", ")) + + foundPairs.map(_._2).head + } + + /** + * Extract Params from metadata, and set them in the instance. + * This works if all Params (except params included by `skipParams` list) implement + * [[org.apache.spark.ml.param.Param.jsonDecode()]]. + * + * @param skipParams The params included in `skipParams` won't be set. This is useful if some + * params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]] + * and need special handling. + */ + def getAndSetParams( + instance: Params, + skipParams: Option[List[String]] = None): Unit = { + setParams(instance, skipParams, isDefault = false) + + // For metadata file prior to Spark 2.4, there is no default section. + val (major, minor) = VersionUtils.majorMinorVersion(sparkVersion) + if (major > 2 || (major == 2 && minor >= 4)) { + setParams(instance, skipParams, isDefault = true) + } + } + + private def setParams( + instance: Params, + skipParams: Option[List[String]], + isDefault: Boolean): Unit = { + implicit val format = DefaultFormats + val paramsToSet = if (isDefault) defaultParams else params + paramsToSet match { case JObject(pairs) => - val values = pairs.filter { case (pName, jsonValue) => - pName == paramName - }.map(_._2) - assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" + - s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", ")) - values.head + pairs.foreach { case (paramName, jsonValue) => + if (skipParams == None || !skipParams.get.contains(paramName)) { + val param = instance.getParam(paramName) + val value = param.jsonDecode(compact(render(jsonValue))) + if (isDefault) { + Params.setDefault(instance, param, value) + } else { + instance.set(param, value) + } + } + } case _ => throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: $metadataJson.") + s"Cannot recognize JSON metadata: ${metadataJson}.") } } } @@ -561,43 +632,14 @@ private[ml] object DefaultParamsReader { val uid = (metadata \ "uid").extract[String] val timestamp = (metadata \ "timestamp").extract[Long] val sparkVersion = (metadata \ "sparkVersion").extract[String] + val defaultParams = metadata \ "defaultParamMap" val params = metadata \ "paramMap" if (expectedClassName.nonEmpty) { require(className == expectedClassName, s"Error loading metadata: Expected class name" + s" $expectedClassName but found class name $className") } - Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr) - } - - /** - * Extract Params from metadata, and set them in the instance. - * This works if all Params (except params included by `skipParams` list) implement - * [[org.apache.spark.ml.param.Param.jsonDecode()]]. - * - * @param skipParams The params included in `skipParams` won't be set. This is useful if some - * params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]] - * and need special handling. - * TODO: Move to [[Metadata]] method - */ - def getAndSetParams( - instance: Params, - metadata: Metadata, - skipParams: Option[List[String]] = None): Unit = { - implicit val format = DefaultFormats - metadata.params match { - case JObject(pairs) => - pairs.foreach { case (paramName, jsonValue) => - if (skipParams == None || !skipParams.get.contains(paramName)) { - val param = instance.getParam(paramName) - val value = param.jsonDecode(compact(render(jsonValue))) - instance.set(param, value) - } - } - case _ => - throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") - } + Metadata(className, uid, timestamp, sparkVersion, params, defaultParams, metadata, metadataStr) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 4da95e74434ee..4d9e664850c12 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -19,9 +19,10 @@ package org.apache.spark.ml.util import java.io.{File, IOException} +import org.json4s.JNothing import org.scalatest.Suite -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -129,6 +130,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => class MyParams(override val uid: String) extends Params with MLWritable { final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc") + final val shouldNotSetIfSetintParamWithDefault: IntParam = + new IntParam(this, "shouldNotSetIfSetintParamWithDefault", "doc") final val intParam: IntParam = new IntParam(this, "intParam", "doc") final val floatParam: FloatParam = new FloatParam(this, "floatParam", "doc") final val doubleParam: DoubleParam = new DoubleParam(this, "doubleParam", "doc") @@ -150,6 +153,13 @@ class MyParams(override val uid: String) extends Params with MLWritable { set(doubleArrayParam -> Array(8.0, 9.0)) set(stringArrayParam -> Array("10", "11")) + def checkExclusiveParams(): Unit = { + if (isSet(shouldNotSetIfSetintParamWithDefault) && isSet(intParamWithDefault)) { + throw new SparkException("intParamWithDefault and shouldNotSetIfSetintParamWithDefault " + + "shouldn't be set at the same time") + } + } + override def copy(extra: ParamMap): Params = defaultCopy(extra) override def write: MLWriter = new DefaultParamsWriter(this) @@ -169,4 +179,65 @@ class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext val myParams = new MyParams("my_params") testDefaultReadWrite(myParams) } + + test("default param shouldn't become user-supplied param after persistence") { + val myParams = new MyParams("my_params") + myParams.set(myParams.shouldNotSetIfSetintParamWithDefault, 1) + myParams.checkExclusiveParams() + val loadedMyParams = testDefaultReadWrite(myParams) + loadedMyParams.checkExclusiveParams() + assert(loadedMyParams.getDefault(loadedMyParams.intParamWithDefault) == + myParams.getDefault(myParams.intParamWithDefault)) + + loadedMyParams.set(myParams.intParamWithDefault, 1) + intercept[SparkException] { + loadedMyParams.checkExclusiveParams() + } + } + + test("User-supplied value for default param should be kept after persistence") { + val myParams = new MyParams("my_params") + myParams.set(myParams.intParamWithDefault, 100) + val loadedMyParams = testDefaultReadWrite(myParams) + assert(loadedMyParams.get(myParams.intParamWithDefault).get == 100) + } + + test("Read metadata without default field prior to 2.4") { + // default params are saved in `paramMap` field in metadata file prior to Spark 2.4. + val metadata = """{"class":"org.apache.spark.ml.util.MyParams", + |"timestamp":1518852502761,"sparkVersion":"2.3.0", + |"uid":"my_params", + |"paramMap":{"intParamWithDefault":0}}""".stripMargin + val parsedMetadata = DefaultParamsReader.parseMetadata(metadata) + val myParams = new MyParams("my_params") + assert(!myParams.isSet(myParams.intParamWithDefault)) + parsedMetadata.getAndSetParams(myParams) + + // The behavior prior to Spark 2.4, default params are set in loaded ML instance. + assert(myParams.isSet(myParams.intParamWithDefault)) + } + + test("Should raise error when read metadata without default field after Spark 2.4") { + val myParams = new MyParams("my_params") + + val metadata1 = """{"class":"org.apache.spark.ml.util.MyParams", + |"timestamp":1518852502761,"sparkVersion":"2.4.0", + |"uid":"my_params", + |"paramMap":{"intParamWithDefault":0}}""".stripMargin + val parsedMetadata1 = DefaultParamsReader.parseMetadata(metadata1) + val err1 = intercept[IllegalArgumentException] { + parsedMetadata1.getAndSetParams(myParams) + } + assert(err1.getMessage().contains("Cannot recognize JSON metadata")) + + val metadata2 = """{"class":"org.apache.spark.ml.util.MyParams", + |"timestamp":1518852502761,"sparkVersion":"3.0.0", + |"uid":"my_params", + |"paramMap":{"intParamWithDefault":0}}""".stripMargin + val parsedMetadata2 = DefaultParamsReader.parseMetadata(metadata2) + val err2 = intercept[IllegalArgumentException] { + parsedMetadata2.getAndSetParams(myParams) + } + assert(err2.getMessage().contains("Cannot recognize JSON metadata")) + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a87fa68422c34..7d0e88ee20c3f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -62,6 +62,12 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.cacheSize"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel"), + // [SPARK-23455][ML] Default Params in ML should be saved separately in metadata + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.paramMap"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$paramMap_="), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.defaultParamMap"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$defaultParamMap_="), + // [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.LeafNode"), ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"), From 379bffa0525a4343f8c10e51ed192031922f9874 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 24 Apr 2018 11:02:22 -0700 Subject: [PATCH 0683/2461] [SPARK-23990][ML] Instruments logging improvements - ML regression package ## What changes were proposed in this pull request? Instruments logging improvements - ML regression package I add an `OptionalInstrument` class which used in `WeightLeastSquares` and `IterativelyReweightedLeastSquares`. ## How was this patch tested? N/A Author: WeichenXu Closes #21078 from WeichenXu123/inst_reg. --- .../classification/LogisticRegression.scala | 4 +- .../IterativelyReweightedLeastSquares.scala | 18 +++-- .../spark/ml/optim/WeightedLeastSquares.scala | 32 +++++---- .../ml/regression/AFTSurvivalRegression.scala | 2 +- .../GeneralizedLinearRegression.scala | 14 ++-- .../ml/regression/LinearRegression.scala | 22 +++--- .../spark/ml/tree/impl/RandomForest.scala | 2 + .../spark/ml/util/Instrumentation.scala | 68 ++++++++++++++++++- 8 files changed, 125 insertions(+), 37 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index e426263910f26..06ca37bc75146 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -500,7 +500,7 @@ class LogisticRegression @Since("1.2.0") ( if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val instr = Instrumentation.create(this, instances) + val instr = Instrumentation.create(this, dataset) instr.logParams(regParam, elasticNetParam, standardization, threshold, maxIter, tol, fitIntercept) @@ -816,7 +816,7 @@ class LogisticRegression @Since("1.2.0") ( if (state == null) { val msg = s"${optimizer.getClass.getName} failed." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala index 6961b45f55e4d..572b8cf0051b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.optim -import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.util.OptionalInstrumentation import org.apache.spark.rdd.RDD /** @@ -61,9 +61,12 @@ private[ml] class IterativelyReweightedLeastSquares( val fitIntercept: Boolean, val regParam: Double, val maxIter: Int, - val tol: Double) extends Logging with Serializable { + val tol: Double) extends Serializable { - def fit(instances: RDD[OffsetInstance]): IterativelyReweightedLeastSquaresModel = { + def fit( + instances: RDD[OffsetInstance], + instr: OptionalInstrumentation = OptionalInstrumentation.create( + classOf[IterativelyReweightedLeastSquares])): IterativelyReweightedLeastSquaresModel = { var converged = false var iter = 0 @@ -83,7 +86,8 @@ private[ml] class IterativelyReweightedLeastSquares( // Estimate new model model = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, - standardizeFeatures = false, standardizeLabel = false).fit(newInstances) + standardizeFeatures = false, standardizeLabel = false) + .fit(newInstances, instr = instr) // Check convergence val oldCoefficients = oldModel.coefficients @@ -96,14 +100,14 @@ private[ml] class IterativelyReweightedLeastSquares( if (maxTol < tol) { converged = true - logInfo(s"IRLS converged in $iter iterations.") + instr.logInfo(s"IRLS converged in $iter iterations.") } - logInfo(s"Iteration $iter : relative tolerance = $maxTol") + instr.logInfo(s"Iteration $iter : relative tolerance = $maxTol") iter = iter + 1 if (iter == maxIter) { - logInfo(s"IRLS reached the max number of iterations: $maxIter.") + instr.logInfo(s"IRLS reached the max number of iterations: $maxIter.") } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index c5c9c8eb2bd29..1b7c15f1f0a8c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.optim -import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.util.OptionalInstrumentation import org.apache.spark.rdd.RDD /** @@ -81,13 +81,11 @@ private[ml] class WeightedLeastSquares( val standardizeLabel: Boolean, val solverType: WeightedLeastSquares.Solver = WeightedLeastSquares.Auto, val maxIter: Int = 100, - val tol: Double = 1e-6) extends Logging with Serializable { + val tol: Double = 1e-6 + ) extends Serializable { import WeightedLeastSquares._ require(regParam >= 0.0, s"regParam cannot be negative: $regParam") - if (regParam == 0.0) { - logWarning("regParam is zero, which might cause numerical instability and overfitting.") - } require(elasticNetParam >= 0.0 && elasticNetParam <= 1.0, s"elasticNetParam must be in [0, 1]: $elasticNetParam") require(maxIter >= 0, s"maxIter must be a positive integer: $maxIter") @@ -96,10 +94,17 @@ private[ml] class WeightedLeastSquares( /** * Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s. */ - def fit(instances: RDD[Instance]): WeightedLeastSquaresModel = { + def fit( + instances: RDD[Instance], + instr: OptionalInstrumentation = OptionalInstrumentation.create(classOf[WeightedLeastSquares]) + ): WeightedLeastSquaresModel = { + if (regParam == 0.0) { + instr.logWarning("regParam is zero, which might cause numerical instability and overfitting.") + } + val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_)) summary.validate() - logInfo(s"Number of instances: ${summary.count}.") + instr.logInfo(s"Number of instances: ${summary.count}.") val k = if (fitIntercept) summary.k + 1 else summary.k val numFeatures = summary.k val triK = summary.triK @@ -114,11 +119,12 @@ private[ml] class WeightedLeastSquares( if (rawBStd == 0) { if (fitIntercept || rawBBar == 0.0) { if (rawBBar == 0.0) { - logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + - s"and the intercept will all be zero; as a result, training is not needed.") + instr.logWarning(s"Mean and standard deviation of the label are zero, so the " + + s"coefficients and the intercept will all be zero; as a result, training is not " + + s"needed.") } else { - logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + - s"zeros and the intercept will be the mean of the label; as a result, " + + instr.logWarning(s"The standard deviation of the label is zero, so the coefficients " + + s"will be zeros and the intercept will be the mean of the label; as a result, " + s"training is not needed.") } val coefficients = new DenseVector(Array.ofDim(numFeatures)) @@ -128,7 +134,7 @@ private[ml] class WeightedLeastSquares( } else { require(!(regParam > 0.0 && standardizeLabel), "The standard deviation of the label is " + "zero. Model cannot be regularized with standardization=true") - logWarning(s"The standard deviation of the label is zero. Consider setting " + + instr.logWarning(s"The standard deviation of the label is zero. Consider setting " + s"fitIntercept=true.") } } @@ -256,7 +262,7 @@ private[ml] class WeightedLeastSquares( // if Auto solver is used and Cholesky fails due to singular AtA, then fall back to // Quasi-Newton solver. case _: SingularMatrixException if solverType == WeightedLeastSquares.Auto => - logWarning("Cholesky solver failed due to singular covariance matrix. " + + instr.logWarning("Cholesky solver failed due to singular covariance matrix. " + "Retrying with Quasi-Newton solver.") // ab and aa were modified in place, so reconstruct them val _aa = getAtA(aaBarValues, aBarValues) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 7c6ec2a8419fd..e27a96e1f5dfc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -237,7 +237,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S if (!$(fitIntercept) && (0 until numFeatures).exists { i => featuresStd(i) == 0.0 && featuresSummarizer.mean(i) != 0.0 }) { - logWarning("Fitting AFTSurvivalRegressionModel without intercept on dataset with " + + instr.logWarning("Fitting AFTSurvivalRegressionModel without intercept on dataset with " + "constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero " + "columns. This behavior is different from R survival::survreg.") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index e030a40cb19be..143c8a3548b1f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -404,7 +404,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val } val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true) - val wlsModel = optimizer.fit(instances) + val wlsModel = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) val model = copyValues( new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept) .setParent(this)) @@ -418,10 +418,11 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val OffsetInstance(label, weight, offset, features) } // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). - val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam)) + val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam), + instr = OptionalInstrumentation.create(instr)) val optimizer = new IterativelyReweightedLeastSquares(initialModel, familyAndLink.reweightFunc, $(fitIntercept), $(regParam), $(maxIter), $(tol)) - val irlsModel = optimizer.fit(instances) + val irlsModel = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) val model = copyValues( new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept) .setParent(this)) @@ -492,7 +493,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine def initialize( instances: RDD[OffsetInstance], fitIntercept: Boolean, - regParam: Double): WeightedLeastSquaresModel = { + regParam: Double, + instr: OptionalInstrumentation = OptionalInstrumentation.create( + classOf[GeneralizedLinearRegression]) + ): WeightedLeastSquaresModel = { val newInstances = instances.map { instance => val mu = family.initialize(instance.label, instance.weight) val eta = predict(mu) - instance.offset @@ -501,7 +505,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine // TODO: Make standardizeFeatures and standardizeLabel configurable. val initialModel = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true) - .fit(newInstances) + .fit(newInstances, instr) initialModel } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f1d9a4453deaa..c45ade94a4e33 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -339,7 +339,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = $(elasticNetParam), $(standardization), true, solverType = WeightedLeastSquares.Auto, maxIter = $(maxIter), tol = $(tol)) - val model = optimizer.fit(instances) + val model = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) // When it is trained by WeightedLeastSquares, training summary does not // attach returned model. val lrModel = copyValues(new LinearRegressionModel(uid, model.coefficients, model.intercept)) @@ -378,6 +378,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val yMean = ySummarizer.mean(0) val rawYStd = math.sqrt(ySummarizer.variance(0)) + + instr.logNumExamples(ySummarizer.count) + instr.logNamedValue(Instrumentation.loggerTags.meanOfLabels, yMean) + instr.logNamedValue(Instrumentation.loggerTags.varianceOfLabels, rawYStd) + if (rawYStd == 0.0) { if ($(fitIntercept) || yMean == 0.0) { // If the rawYStd==0 and fitIntercept==true, then the intercept is yMean with @@ -385,11 +390,12 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String // Also, if yMean==0 and rawYStd==0, all the coefficients are zero regardless of // the fitIntercept. if (yMean == 0.0) { - logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + - s"and the intercept will all be zero; as a result, training is not needed.") + instr.logWarning(s"Mean and standard deviation of the label are zero, so the " + + s"coefficients and the intercept will all be zero; as a result, training is not " + + s"needed.") } else { - logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + - s"zeros and the intercept will be the mean of the label; as a result, " + + instr.logWarning(s"The standard deviation of the label is zero, so the coefficients " + + s"will be zeros and the intercept will be the mean of the label; as a result, " + s"training is not needed.") } if (handlePersistence) instances.unpersist() @@ -415,7 +421,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } else { require($(regParam) == 0.0, "The standard deviation of the label is zero. " + "Model cannot be regularized.") - logWarning(s"The standard deviation of the label is zero. " + + instr.logWarning(s"The standard deviation of the label is zero. " + "Consider setting fitIntercept=true.") } } @@ -430,7 +436,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String if (!$(fitIntercept) && (0 until numFeatures).exists { i => featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) { - logWarning("Fitting LinearRegressionModel without intercept on dataset with " + + instr.logWarning("Fitting LinearRegressionModel without intercept on dataset with " + "constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero " + "columns. This behavior is the same as R glmnet but different from LIBSVM.") } @@ -522,7 +528,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } if (state == null) { val msg = s"${optimizer.getClass.getName} failed." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 056a94b351f79..905870178e549 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -108,9 +108,11 @@ private[spark] object RandomForest extends Logging { case Some(instrumentation) => instrumentation.logNumFeatures(metadata.numFeatures) instrumentation.logNumClasses(metadata.numClasses) + instrumentation.logNumExamples(metadata.numExamples) case None => logInfo("numFeatures: " + metadata.numFeatures) logInfo("numClasses: " + metadata.numClasses) + logInfo("numExamples: " + metadata.numExamples) } // Find the splits and the corresponding bins (interval between the splits) using a sample diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index e694bc27b2f1e..3247c394dfa64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -19,6 +19,8 @@ package org.apache.spark.ml.util import java.util.UUID +import scala.reflect.ClassTag + import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -40,7 +42,8 @@ import org.apache.spark.sql.Dataset * @tparam E the type of the estimator */ private[spark] class Instrumentation[E <: Estimator[_]] private ( - estimator: E, dataset: RDD[_]) extends Logging { + val estimator: E, + val dataset: RDD[_]) extends Logging { private val id = UUID.randomUUID() private val prefix = { @@ -103,6 +106,10 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( logNamedValue(Instrumentation.loggerTags.numClasses, num) } + def logNumExamples(num: Long): Unit = { + logNamedValue(Instrumentation.loggerTags.numExamples, num) + } + /** * Logs the value with customized name field. */ @@ -114,6 +121,10 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( log(compact(render(name -> value))) } + def logNamedValue(name: String, value: Double): Unit = { + log(compact(render(name -> value))) + } + /** * Logs the successful completion of the training session. */ @@ -131,6 +142,8 @@ private[spark] object Instrumentation { val numFeatures = "numFeatures" val numClasses = "numClasses" val numExamples = "numExamples" + val meanOfLabels = "meanOfLabels" + val varianceOfLabels = "varianceOfLabels" } /** @@ -150,3 +163,56 @@ private[spark] object Instrumentation { } } + +/** + * A small wrapper that contains an optional `Instrumentation` object. + * Provide some log methods, if the containing `Instrumentation` object is defined, + * will log via it, otherwise will log via common logger. + */ +private[spark] class OptionalInstrumentation private( + val instrumentation: Option[Instrumentation[_ <: Estimator[_]]], + val className: String) extends Logging { + + protected override def logName: String = className + + override def logInfo(msg: => String) { + instrumentation match { + case Some(instr) => instr.logInfo(msg) + case None => super.logInfo(msg) + } + } + + override def logWarning(msg: => String) { + instrumentation match { + case Some(instr) => instr.logWarning(msg) + case None => super.logWarning(msg) + } + } + + override def logError(msg: => String) { + instrumentation match { + case Some(instr) => instr.logError(msg) + case None => super.logError(msg) + } + } +} + +private[spark] object OptionalInstrumentation { + + /** + * Creates an `OptionalInstrumentation` object from an existing `Instrumentation` object. + */ + def create[E <: Estimator[_]](instr: Instrumentation[E]): OptionalInstrumentation = { + new OptionalInstrumentation(Some(instr), + instr.estimator.getClass.getName.stripSuffix("$")) + } + + /** + * Creates an `OptionalInstrumentation` object from a `Class` object. + * The created `OptionalInstrumentation` object will log messages via common logger and use the + * specified class name as logger name. + */ + def create(clazz: Class[_]): OptionalInstrumentation = { + new OptionalInstrumentation(None, clazz.getName.stripSuffix("$")) + } +} From 7b1e6523af3c96043aa8d2763e5f18b6e2781c3d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 24 Apr 2018 14:33:33 -0700 Subject: [PATCH 0684/2461] [SPARK-24056][SS] Make consumer creation lazy in Kafka source for Structured streaming ## What changes were proposed in this pull request? Currently, the driver side of the Kafka source (i.e. KafkaMicroBatchReader) eagerly creates a consumer as soon as the Kafk aMicroBatchReader is created. However, we create dummy KafkaMicroBatchReader to get the schema and immediately stop it. Its better to make the consumer creation lazy, it will be created on the first attempt to fetch offsets using the KafkaOffsetReader. ## How was this patch tested? Existing unit tests Author: Tathagata Das Closes #21134 from tdas/SPARK-24056. --- .../sql/kafka010/KafkaOffsetReader.scala | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 551641cfdbca8..82066697cb95a 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -75,7 +75,17 @@ private[kafka010] class KafkaOffsetReader( * A KafkaConsumer used in the driver to query the latest Kafka offsets. This only queries the * offsets and never commits them. */ - protected var consumer = createConsumer() + @volatile protected var _consumer: Consumer[Array[Byte], Array[Byte]] = null + + protected def consumer: Consumer[Array[Byte], Array[Byte]] = synchronized { + assert(Thread.currentThread().isInstanceOf[UninterruptibleThread]) + if (_consumer == null) { + val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams) + newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId()) + _consumer = consumerStrategy.createConsumer(newKafkaParams) + } + _consumer + } private val maxOffsetFetchAttempts = readerOptions.getOrElse("fetchOffset.numRetries", "3").toInt @@ -95,9 +105,7 @@ private[kafka010] class KafkaOffsetReader( * Closes the connection to Kafka, and cleans up state. */ def close(): Unit = { - runUninterruptibly { - consumer.close() - } + if (_consumer != null) runUninterruptibly { stopConsumer() } kafkaReaderThread.shutdown() } @@ -304,19 +312,14 @@ private[kafka010] class KafkaOffsetReader( } } - /** - * Create a consumer using the new generated group id. We always use a new consumer to avoid - * just using a broken consumer to retry on Kafka errors, which likely will fail again. - */ - private def createConsumer(): Consumer[Array[Byte], Array[Byte]] = synchronized { - val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams) - newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId()) - consumerStrategy.createConsumer(newKafkaParams) + private def stopConsumer(): Unit = synchronized { + assert(Thread.currentThread().isInstanceOf[UninterruptibleThread]) + if (_consumer != null) _consumer.close() } private def resetConsumer(): Unit = synchronized { - consumer.close() - consumer = createConsumer() + stopConsumer() + _consumer = null // will automatically get reinitialized again } } From d6c26d1c9a8f747a3e0d281a27ea9eb4d92102e5 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 24 Apr 2018 17:06:03 -0700 Subject: [PATCH 0685/2461] [SPARK-24038][SS] Refactor continuous writing to its own class ## What changes were proposed in this pull request? Refactor continuous writing to its own class. See WIP https://github.com/jose-torres/spark/pull/13 for the overall direction this is going, but I think this PR is very isolated and necessary anyway. ## How was this patch tested? existing unit tests - refactoring only Author: Jose Torres Closes #21116 from jose-torres/SPARK-24038. --- .../datasources/v2/DataSourceV2Strategy.scala | 4 + .../datasources/v2/WriteToDataSourceV2.scala | 74 +---------- .../continuous/ContinuousExecution.scala | 2 +- .../WriteToContinuousDataSource.scala | 31 +++++ .../WriteToContinuousDataSourceExec.scala | 124 ++++++++++++++++++ 5 files changed, 165 insertions(+), 70 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 1ac9572de6412..c2a31442d2be5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { @@ -32,6 +33,9 @@ object DataSourceV2Strategy extends Strategy { case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil + case WriteToContinuousDataSource(writer, query) => + WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index e80b44c1cdc66..ea283ed77efda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -65,25 +65,10 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e s"The input RDD has ${messages.length} partitions.") try { - val runTask = writer match { - // This case means that we're doing continuous processing. In microbatch streaming, the - // StreamWriter is wrapped in a MicroBatchWriter, which is executed as a normal batch. - case w: StreamWriter => - EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), - sparkContext.env) - .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) - - (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.runContinuous(writeTask, context, iter) - case _ => - (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator) - } - sparkContext.runJob( rdd, - runTask, + (context: TaskContext, iter: Iterator[InternalRow]) => + DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator), rdd.partitions.indices, (index, message: WriterCommitMessage) => { messages(index) = message @@ -91,14 +76,10 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e } ) - if (!writer.isInstanceOf[StreamWriter]) { - logInfo(s"Data source writer $writer is committing.") - writer.commit(messages) - logInfo(s"Data source writer $writer committed.") - } + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") } catch { - case _: InterruptedException if writer.isInstanceOf[StreamWriter] => - // Interruption is how continuous queries are ended, so accept and ignore the exception. case cause: Throwable => logError(s"Data source writer $writer is aborting.") try { @@ -111,8 +92,6 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e } logError(s"Data source writer $writer aborted.") cause match { - // Do not wrap interruption exceptions that will be handled by streaming specially. - case _ if StreamExecution.isInterruptionException(cause) => throw cause // Only wrap non fatal exceptions. case NonFatal(e) => throw new SparkException("Writing job aborted.", e) case _ => throw cause @@ -168,49 +147,6 @@ object DataWritingSparkTask extends Logging { logError(s"Writer for stage $stageId, task $partId.$attemptId aborted.") }) } - - def runContinuous( - writeTask: DataWriterFactory[InternalRow], - context: TaskContext, - iter: Iterator[InternalRow]): WriterCommitMessage = { - val epochCoordinator = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), - SparkEnv.get) - val currentMsg: WriterCommitMessage = null - var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - do { - var dataWriter: DataWriter[InternalRow] = null - // write the data and commit this writer. - Utils.tryWithSafeFinallyAndFailureCallbacks(block = { - try { - dataWriter = writeTask.createDataWriter( - context.partitionId(), context.attemptNumber(), currentEpoch) - while (iter.hasNext) { - dataWriter.write(iter.next()) - } - logInfo(s"Writer for partition ${context.partitionId()} is committing.") - val msg = dataWriter.commit() - logInfo(s"Writer for partition ${context.partitionId()} committed.") - epochCoordinator.send( - CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) - ) - currentEpoch += 1 - } catch { - case _: InterruptedException => - // Continuous shutdown always involves an interrupt. Just finish the task. - } - })(catchBlock = { - // If there is an error, abort this writer. We enter this callback in the middle of - // rethrowing an exception, so runContinuous will stop executing at this point. - logError(s"Writer for partition ${context.partitionId()} is aborting.") - if (dataWriter != null) dataWriter.abort() - logError(s"Writer for partition ${context.partitionId()} aborted.") - }) - } while (!context.isInterrupted()) - - currentMsg - } } class InternalRowDataWriterFactory( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 951d694355ec5..f58146ac42398 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -199,7 +199,7 @@ class ContinuousExecution( triggerLogicalPlan.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) - val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) + val withSink = WriteToContinuousDataSource(writer, triggerLogicalPlan) val reader = withSink.collect { case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala new file mode 100644 index 0000000000000..943c731a70529 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter + +/** + * The logical plan for writing data in a continuous stream. + */ +case class WriteToContinuousDataSource( + writer: StreamWriter, query: LogicalPlan) extends LogicalPlan { + override def children: Seq[LogicalPlan] = Seq(query) + override def output: Seq[Attribute] = Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala new file mode 100644 index 0000000000000..ba88ae1af469a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import scala.util.control.NonFatal + +import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.v2.{DataWritingSparkTask, InternalRowDataWriterFactory} +import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo} +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.util.Utils + +/** + * The physical plan for writing data into a continuous processing [[StreamWriter]]. + */ +case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPlan) + extends SparkPlan with Logging { + override def children: Seq[SparkPlan] = Seq(query) + override def output: Seq[Attribute] = Nil + + override protected def doExecute(): RDD[InternalRow] = { + val writerFactory = writer match { + case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() + case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) + } + + val rdd = query.execute() + + logInfo(s"Start processing data source writer: $writer. " + + s"The input RDD has ${rdd.getNumPartitions} partitions.") + // Let the epoch coordinator know how many partitions the write RDD has. + EpochCoordinatorRef.get( + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + sparkContext.env) + .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) + + try { + // Force the RDD to run so continuous processing starts; no data is actually being collected + // to the driver, as ContinuousWriteRDD outputs nothing. + sparkContext.runJob( + rdd, + (context: TaskContext, iter: Iterator[InternalRow]) => + WriteToContinuousDataSourceExec.run(writerFactory, context, iter), + rdd.partitions.indices) + } catch { + case _: InterruptedException => + // Interruption is how continuous queries are ended, so accept and ignore the exception. + case cause: Throwable => + cause match { + // Do not wrap interruption exceptions that will be handled by streaming specially. + case _ if StreamExecution.isInterruptionException(cause) => throw cause + // Only wrap non fatal exceptions. + case NonFatal(e) => throw new SparkException("Writing job aborted.", e) + case _ => throw cause + } + } + + sparkContext.emptyRDD + } +} + +object WriteToContinuousDataSourceExec extends Logging { + def run( + writeTask: DataWriterFactory[InternalRow], + context: TaskContext, + iter: Iterator[InternalRow]): Unit = { + val epochCoordinator = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + SparkEnv.get) + var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + do { + var dataWriter: DataWriter[InternalRow] = null + // write the data and commit this writer. + Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + try { + dataWriter = writeTask.createDataWriter( + context.partitionId(), context.attemptNumber(), currentEpoch) + while (iter.hasNext) { + dataWriter.write(iter.next()) + } + logInfo(s"Writer for partition ${context.partitionId()} is committing.") + val msg = dataWriter.commit() + logInfo(s"Writer for partition ${context.partitionId()} committed.") + epochCoordinator.send( + CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) + ) + currentEpoch += 1 + } catch { + case _: InterruptedException => + // Continuous shutdown always involves an interrupt. Just finish the task. + } + })(catchBlock = { + // If there is an error, abort this writer. We enter this callback in the middle of + // rethrowing an exception, so runContinuous will stop executing at this point. + logError(s"Writer for partition ${context.partitionId()} is aborting.") + if (dataWriter != null) dataWriter.abort() + logError(s"Writer for partition ${context.partitionId()} aborted.") + }) + } while (!context.isInterrupted()) + } +} From 5fea17b3befc50aef59b799711d03b9552f21b19 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Wed, 25 Apr 2018 11:19:08 +0900 Subject: [PATCH 0686/2461] [SPARK-23821][SQL] Collection function: flatten ## What changes were proposed in this pull request? This PR adds a new collection function that transforms an array of arrays into a single array. The PR comprises: - An expression for flattening array structure - Flatten function - A wrapper for PySpark ## How was this patch tested? New tests added into: - CollectionExpressionsSuite - DataFrameFunctionsSuite ## Codegen examples ### Primitive type ``` val df = Seq( Seq(Seq(1, 2), Seq(4, 5)), Seq(null, Seq(1)) ).toDF("i") df.filter($"i".isNotNull || $"i".isNull).select(flatten($"i")).debugCodegen ``` Result: ``` /* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 034 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 035 */ null : (inputadapter_row.getArray(0)); /* 036 */ /* 037 */ boolean filter_value = true; /* 038 */ /* 039 */ if (!(!inputadapter_isNull)) { /* 040 */ filter_value = inputadapter_isNull; /* 041 */ } /* 042 */ if (!filter_value) continue; /* 043 */ /* 044 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 045 */ /* 046 */ boolean project_isNull = inputadapter_isNull; /* 047 */ ArrayData project_value = null; /* 048 */ /* 049 */ if (!inputadapter_isNull) { /* 050 */ for (int z = 0; !project_isNull && z < inputadapter_value.numElements(); z++) { /* 051 */ project_isNull |= inputadapter_value.isNullAt(z); /* 052 */ } /* 053 */ if (!project_isNull) { /* 054 */ long project_numElements = 0; /* 055 */ for (int z = 0; z < inputadapter_value.numElements(); z++) { /* 056 */ project_numElements += inputadapter_value.getArray(z).numElements(); /* 057 */ } /* 058 */ if (project_numElements > 2147483632) { /* 059 */ throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + /* 060 */ project_numElements + " elements due to exceeding the array size limit 2147483632."); /* 061 */ } /* 062 */ /* 063 */ long project_size = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( /* 064 */ project_numElements, /* 065 */ 4); /* 066 */ if (project_size > 2147483632) { /* 067 */ throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + /* 068 */ project_size + " bytes of data due to exceeding the limit 2147483632" + /* 069 */ " bytes for UnsafeArrayData."); /* 070 */ } /* 071 */ /* 072 */ byte[] project_array = new byte[(int)project_size]; /* 073 */ UnsafeArrayData project_tempArrayData = new UnsafeArrayData(); /* 074 */ Platform.putLong(project_array, 16, project_numElements); /* 075 */ project_tempArrayData.pointTo(project_array, 16, (int)project_size); /* 076 */ int project_counter = 0; /* 077 */ for (int k = 0; k < inputadapter_value.numElements(); k++) { /* 078 */ ArrayData arr = inputadapter_value.getArray(k); /* 079 */ for (int l = 0; l < arr.numElements(); l++) { /* 080 */ if (arr.isNullAt(l)) { /* 081 */ project_tempArrayData.setNullAt(project_counter); /* 082 */ } else { /* 083 */ project_tempArrayData.setInt( /* 084 */ project_counter, /* 085 */ arr.getInt(l) /* 086 */ ); /* 087 */ } /* 088 */ project_counter++; /* 089 */ } /* 090 */ } /* 091 */ project_value = project_tempArrayData; /* 092 */ /* 093 */ } /* 094 */ /* 095 */ } ``` ### Non-primitive type ``` val df = Seq( Seq(Seq("a", "b"), Seq(null, "d")), Seq(null, Seq("a")) ).toDF("s") df.filter($"s".isNotNull || $"s".isNull).select(flatten($"s")).debugCodegen ``` Result: ``` /* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 034 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 035 */ null : (inputadapter_row.getArray(0)); /* 036 */ /* 037 */ boolean filter_value = true; /* 038 */ /* 039 */ if (!(!inputadapter_isNull)) { /* 040 */ filter_value = inputadapter_isNull; /* 041 */ } /* 042 */ if (!filter_value) continue; /* 043 */ /* 044 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 045 */ /* 046 */ boolean project_isNull = inputadapter_isNull; /* 047 */ ArrayData project_value = null; /* 048 */ /* 049 */ if (!inputadapter_isNull) { /* 050 */ for (int z = 0; !project_isNull && z < inputadapter_value.numElements(); z++) { /* 051 */ project_isNull |= inputadapter_value.isNullAt(z); /* 052 */ } /* 053 */ if (!project_isNull) { /* 054 */ long project_numElements = 0; /* 055 */ for (int z = 0; z < inputadapter_value.numElements(); z++) { /* 056 */ project_numElements += inputadapter_value.getArray(z).numElements(); /* 057 */ } /* 058 */ if (project_numElements > 2147483632) { /* 059 */ throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + /* 060 */ project_numElements + " elements due to exceeding the array size limit 2147483632."); /* 061 */ } /* 062 */ /* 063 */ Object[] project_arrayObject = new Object[(int)project_numElements]; /* 064 */ int project_counter = 0; /* 065 */ for (int k = 0; k < inputadapter_value.numElements(); k++) { /* 066 */ ArrayData arr = inputadapter_value.getArray(k); /* 067 */ for (int l = 0; l < arr.numElements(); l++) { /* 068 */ project_arrayObject[project_counter] = arr.getUTF8String(l); /* 069 */ project_counter++; /* 070 */ } /* 071 */ } /* 072 */ project_value = new org.apache.spark.sql.catalyst.util.GenericArrayData(project_arrayObject); /* 073 */ /* 074 */ } /* 075 */ /* 076 */ } ``` Author: mn-mikke Closes #20938 from mn-mikke/feature/array-api-flatten-to-master. --- python/pyspark/sql/functions.py | 17 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 176 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 95 ++++++++++ .../org/apache/spark/sql/functions.scala | 8 + .../spark/sql/DataFrameFunctionsSuite.scala | 79 ++++++++ 6 files changed, 376 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index da32ab25cad0c..de53b48b6f3b4 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2191,6 +2191,23 @@ def reverse(col): return Column(sc._jvm.functions.reverse(_to_java_column(col))) +@since(2.4) +def flatten(col): + """ + Collection function: creates a single array from an array of arrays. + If a structure of nested arrays is deeper than two levels, + only one level of nesting is removed. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data']) + >>> df.select(flatten(df.data).alias('r')).collect() + [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.flatten(_to_java_column(col))) + + @since(2.3) def map_keys(col): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index c41f16c61d7a2..6afcf309bd690 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -413,6 +413,7 @@ object FunctionRegistry { expression[ArrayMax]("array_max"), expression[Reverse]("reverse"), expression[Concat]("concat"), + expression[Flatten]("flatten"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c16793bda028e..bc71b5f34ce4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -883,3 +883,179 @@ case class Concat(children: Seq[Expression]) extends Expression { override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" } + +/** + * Transforms an array of arrays into a single array. + */ +@ExpressionDescription( + usage = "_FUNC_(arrayOfArrays) - Transforms an array of arrays into a single array.", + examples = """ + Examples: + > SELECT _FUNC_(array(array(1, 2), array(3, 4)); + [1,2,3,4] + """, + since = "2.4.0") +case class Flatten(child: Expression) extends UnaryExpression { + + private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] + + override def nullable: Boolean = child.nullable || childDataType.containsNull + + override def dataType: DataType = childDataType.elementType + + lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case ArrayType(_: ArrayType, _) => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure( + s"The argument should be an array of arrays, " + + s"but '${child.sql}' is of ${child.dataType.simpleString} type." + ) + } + + override def nullSafeEval(child: Any): Any = { + val elements = child.asInstanceOf[ArrayData].toObjectArray(dataType) + + if (elements.contains(null)) { + null + } else { + val arrayData = elements.map(_.asInstanceOf[ArrayData]) + val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements()) + if (numberOfElements > MAX_ARRAY_LENGTH) { + throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + + s"$numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + } + val flattenedData = new Array(numberOfElements.toInt) + var position = 0 + for (ad <- arrayData) { + val arr = ad.toObjectArray(elementType) + Array.copy(arr, 0, flattenedData, position, arr.length) + position += arr.length + } + new GenericArrayData(flattenedData) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val code = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForFlattenOfPrimitiveElements(ctx, c, ev.value) + } else { + genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value) + } + if (childDataType.containsNull) nullElementsProtection(ev, c, code) else code + }) + } + + private def nullElementsProtection( + ev: ExprCode, + childVariableName: String, + coreLogic: String): String = { + s""" + |for (int z = 0; !${ev.isNull} && z < $childVariableName.numElements(); z++) { + | ${ev.isNull} |= $childVariableName.isNullAt(z); + |} + |if (!${ev.isNull}) { + | $coreLogic + |} + """.stripMargin + } + + private def genCodeForNumberOfElements( + ctx: CodegenContext, + childVariableName: String) : (String, String) = { + val variableName = ctx.freshName("numElements") + val code = s""" + |long $variableName = 0; + |for (int z = 0; z < $childVariableName.numElements(); z++) { + | $variableName += $childVariableName.getArray(z).numElements(); + |} + |if ($variableName > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + + | $variableName + " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |} + """.stripMargin + (code, variableName) + } + + private def genCodeForFlattenOfPrimitiveElements( + ctx: CodegenContext, + childVariableName: String, + arrayDataName: String): String = { + val arrayName = ctx.freshName("array") + val arraySizeName = ctx.freshName("size") + val counter = ctx.freshName("counter") + val tempArrayDataName = ctx.freshName("tempArrayData") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) + + val unsafeArraySizeInBytes = s""" + |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElemName, + | ${elementType.defaultSize}); + |if ($arraySizeName > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + + | $arraySizeName + " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH" + + | " bytes for UnsafeArrayData."); + |} + """.stripMargin + val baseOffset = Platform.BYTE_ARRAY_OFFSET + + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + + s""" + |$numElemCode + |$unsafeArraySizeInBytes + |byte[] $arrayName = new byte[(int)$arraySizeName]; + |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData(); + |Platform.putLong($arrayName, $baseOffset, $numElemName); + |$tempArrayDataName.pointTo($arrayName, $baseOffset, (int)$arraySizeName); + |int $counter = 0; + |for (int k = 0; k < $childVariableName.numElements(); k++) { + | ArrayData arr = $childVariableName.getArray(k); + | for (int l = 0; l < arr.numElements(); l++) { + | if (arr.isNullAt(l)) { + | $tempArrayDataName.setNullAt($counter); + | } else { + | $tempArrayDataName.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue("arr", elementType, "l")} + | ); + | } + | $counter++; + | } + |} + |$arrayDataName = $tempArrayDataName; + """.stripMargin + } + + private def genCodeForFlattenOfNonPrimitiveElements( + ctx: CodegenContext, + childVariableName: String, + arrayDataName: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayName = ctx.freshName("arrayObject") + val counter = ctx.freshName("counter") + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) + + s""" + |$numElemCode + |Object[] $arrayName = new Object[(int)$numElemName]; + |int $counter = 0; + |for (int k = 0; k < $childVariableName.numElements(); k++) { + | ArrayData arr = $childVariableName.getArray(k); + | for (int l = 0; l < arr.numElements(); l++) { + | $arrayName[$counter] = ${CodeGenerator.getValue("arr", elementType, "l")}; + | $counter++; + | } + |} + |$arrayDataName = new $genericArrayClass($arrayName); + """.stripMargin + } + + override def prettyName: String = "flatten" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 43c5dda2e4a48..b49fa76b2a781 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -280,4 +280,99 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f"))) } + + test("Flatten") { + // Primitive-type test cases + val intArrayType = ArrayType(ArrayType(IntegerType)) + + // Main test cases (primitive type) + val aim1 = Literal.create(Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6)), intArrayType) + val aim2 = Literal.create(Seq(Seq(1, 2, 3)), intArrayType) + + checkEvaluation(Flatten(aim1), Seq(1, 2, 3, 4, 5, 6)) + checkEvaluation(Flatten(aim2), Seq(1, 2, 3)) + + // Test cases with an empty array (primitive type) + val aie1 = Literal.create(Seq(Seq.empty, Seq(1, 2), Seq(3, 4)), intArrayType) + val aie2 = Literal.create(Seq(Seq(1, 2), Seq.empty, Seq(3, 4)), intArrayType) + val aie3 = Literal.create(Seq(Seq(1, 2), Seq(3, 4), Seq.empty), intArrayType) + val aie4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), intArrayType) + val aie5 = Literal.create(Seq(Seq.empty), intArrayType) + val aie6 = Literal.create(Seq.empty, intArrayType) + + checkEvaluation(Flatten(aie1), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie2), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie3), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie4), Seq.empty) + checkEvaluation(Flatten(aie5), Seq.empty) + checkEvaluation(Flatten(aie6), Seq.empty) + + // Test cases with null elements (primitive type) + val ain1 = Literal.create(Seq(Seq(null, null, null), Seq(4, null)), intArrayType) + val ain2 = Literal.create(Seq(Seq(null, 2, null), Seq(null, null)), intArrayType) + val ain3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), intArrayType) + + checkEvaluation(Flatten(ain1), Seq(null, null, null, 4, null)) + checkEvaluation(Flatten(ain2), Seq(null, 2, null, null, null)) + checkEvaluation(Flatten(ain3), Seq(null, null, null, null)) + + // Test cases with a null array (primitive type) + val aia1 = Literal.create(Seq(null, Seq(1, 2)), intArrayType) + val aia2 = Literal.create(Seq(Seq(1, 2), null), intArrayType) + val aia3 = Literal.create(Seq(null), intArrayType) + val aia4 = Literal.create(null, intArrayType) + + checkEvaluation(Flatten(aia1), null) + checkEvaluation(Flatten(aia2), null) + checkEvaluation(Flatten(aia3), null) + checkEvaluation(Flatten(aia4), null) + + // Non-primitive-type test cases + val strArrayType = ArrayType(ArrayType(StringType)) + val arrArrayType = ArrayType(ArrayType(ArrayType(StringType))) + + // Main test cases (non-primitive type) + val asm1 = Literal.create(Seq(Seq("a"), Seq("b", "c"), Seq("d", "e", "f")), strArrayType) + val asm2 = Literal.create(Seq(Seq("a", "b")), strArrayType) + val asm3 = Literal.create(Seq(Seq(Seq("a", "b"), Seq("c")), Seq(Seq("d", "e"))), arrArrayType) + + checkEvaluation(Flatten(asm1), Seq("a", "b", "c", "d", "e", "f")) + checkEvaluation(Flatten(asm2), Seq("a", "b")) + checkEvaluation(Flatten(asm3), Seq(Seq("a", "b"), Seq("c"), Seq("d", "e"))) + + // Test cases with an empty array (non-primitive type) + val ase1 = Literal.create(Seq(Seq.empty, Seq("a", "b"), Seq("c", "d")), strArrayType) + val ase2 = Literal.create(Seq(Seq("a", "b"), Seq.empty, Seq("c", "d")), strArrayType) + val ase3 = Literal.create(Seq(Seq("a", "b"), Seq("c", "d"), Seq.empty), strArrayType) + val ase4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), strArrayType) + val ase5 = Literal.create(Seq(Seq.empty), strArrayType) + val ase6 = Literal.create(Seq.empty, strArrayType) + + checkEvaluation(Flatten(ase1), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase2), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase3), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase4), Seq.empty) + checkEvaluation(Flatten(ase5), Seq.empty) + checkEvaluation(Flatten(ase6), Seq.empty) + + // Test cases with null elements (non-primitive type) + val asn1 = Literal.create(Seq(Seq(null, null, "c"), Seq(null, null)), strArrayType) + val asn2 = Literal.create(Seq(Seq(null, null, null), Seq("d", null)), strArrayType) + val asn3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), strArrayType) + + checkEvaluation(Flatten(asn1), Seq(null, null, "c", null, null)) + checkEvaluation(Flatten(asn2), Seq(null, null, null, "d", null)) + checkEvaluation(Flatten(asn3), Seq(null, null, null, null)) + + // Test cases with a null array (non-primitive type) + val asa1 = Literal.create(Seq(null, Seq("a", "b")), strArrayType) + val asa2 = Literal.create(Seq(Seq("a", "b"), null), strArrayType) + val asa3 = Literal.create(Seq(null), strArrayType) + val asa4 = Literal.create(null, strArrayType) + + checkEvaluation(Flatten(asa1), null) + checkEvaluation(Flatten(asa2), null) + checkEvaluation(Flatten(asa3), null) + checkEvaluation(Flatten(asa4), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index bea8c0e445002..d2f057310f89b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3340,6 +3340,14 @@ object functions { */ def reverse(e: Column): Column = withExpr { Reverse(e.expr) } + /** + * Creates a single array from an array of arrays. If a structure of nested arrays is deeper than + * two levels, only one level of nesting is removed. + * @group collection_funcs + * @since 2.4.0 + */ + def flatten(e: Column): Column = withExpr { Flatten(e.expr) } + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 25e5cd60dd236..03605c30036a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -691,6 +691,85 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } } + test("flatten function") { + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // to switch codeGen on + val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") + + // Test cases with a primitive type + val intDF = Seq( + (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))), + (Seq(Seq(1, 2))), + (Seq(Seq(1), Seq.empty)), + (Seq(Seq.empty, Seq(1))), + (Seq(Seq.empty, Seq.empty)), + (Seq(Seq(1), null)), + (Seq(null, Seq(1))), + (Seq(null, null)) + ).toDF("i") + + val intDFResult = Seq( + Row(Seq(1, 2, 3, 4, 5, 6)), + Row(Seq(1, 2)), + Row(Seq(1)), + Row(Seq(1)), + Row(Seq.empty), + Row(null), + Row(null), + Row(null)) + + checkAnswer(intDF.select(flatten($"i")), intDFResult) + checkAnswer(intDF.filter(dummyFilter($"i"))select(flatten($"i")), intDFResult) + checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult) + checkAnswer( + oneRowDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"), + Seq(Row(Seq(1, 2, 3, null, 5, 6, null)))) + + // Test cases with non-primitive types + val strDF = Seq( + (Seq(Seq("a", "b"), Seq("c"), Seq("d", "e", "f"))), + (Seq(Seq("a", "b"))), + (Seq(Seq("a", null), Seq(null, "b"), Seq(null, null))), + (Seq(Seq("a"), Seq.empty)), + (Seq(Seq.empty, Seq("a"))), + (Seq(Seq.empty, Seq.empty)), + (Seq(Seq("a"), null)), + (Seq(null, Seq("a"))), + (Seq(null, null)) + ).toDF("s") + + val strDFResult = Seq( + Row(Seq("a", "b", "c", "d", "e", "f")), + Row(Seq("a", "b")), + Row(Seq("a", null, null, "b", null, null)), + Row(Seq("a")), + Row(Seq("a")), + Row(Seq.empty), + Row(null), + Row(null), + Row(null)) + + checkAnswer(strDF.select(flatten($"s")), strDFResult) + checkAnswer(strDF.filter(dummyFilter($"s")).select(flatten($"s")), strDFResult) + checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult) + checkAnswer( + oneRowDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"), + Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3))))) + + // Error test cases + intercept[AnalysisException] { + oneRowDF.select(flatten($"arr")) + } + intercept[AnalysisException] { + oneRowDF.select(flatten($"i")) + } + intercept[AnalysisException] { + oneRowDF.select(flatten($"s")) + } + intercept[AnalysisException] { + oneRowDF.selectExpr("flatten(null)") + } + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 64e8408e6fa2d74929601b01a29771738f6d8c65 Mon Sep 17 00:00:00 2001 From: liutang123 Date: Wed, 25 Apr 2018 18:10:51 +0800 Subject: [PATCH 0687/2461] [SPARK-24012][SQL] Union of map and other compatible column ## What changes were proposed in this pull request? Union of map and other compatible column result in unresolved operator 'Union; exception Reproduction `spark-sql>select map(1,2), 'str' union all select map(1,2,3,null), 1` Output: ``` Error in query: unresolved operator 'Union;; 'Union :- Project [map(1, 2) AS map(1, 2)#106, str AS str#107] : +- OneRowRelation$ +- Project [map(1, cast(2 as int), 3, cast(null as int)) AS map(1, CAST(2 AS INT), 3, CAST(NULL AS INT))#109, 1 AS 1#108] +- OneRowRelation$ ``` So, we should cast part of columns to be compatible when appropriate. ## How was this patch tested? Added a test (query union of map and other columns) to SQLQueryTestSuite's union.sql. Author: liutang123 Closes #21100 from liutang123/SPARK-24012. --- .../sql/catalyst/analysis/TypeCoercion.scala | 8 ++++ .../test/resources/sql-tests/inputs/union.sql | 11 +++++ .../resources/sql-tests/results/union.sql.out | 42 ++++++++++++++----- 3 files changed, 51 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index cfcbd8db559a3..25bad28a2a209 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -112,6 +112,14 @@ object TypeCoercion { StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) })) + case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) => + findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2)) + + case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) => + val keyType = findTightestCommonType(kt1, kt2) + val valueType = findTightestCommonType(vt1, vt2) + Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2)) + case _ => None } diff --git a/sql/core/src/test/resources/sql-tests/inputs/union.sql b/sql/core/src/test/resources/sql-tests/inputs/union.sql index e57d69eaad033..6da1b9b49b226 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/union.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/union.sql @@ -35,6 +35,17 @@ FROM (SELECT col AS col SELECT col FROM p3) T1) T2; +-- SPARK-24012 Union of map and other compatible columns. +SELECT map(1, 2), 'str' +UNION ALL +SELECT map(1, 2, 3, NULL), 1; + +-- SPARK-24012 Union of array and other compatible columns. +SELECT array(1, 2), 'str' +UNION ALL +SELECT array(1, 2, 3, NULL), 1; + + -- Clean-up DROP VIEW IF EXISTS t1; DROP VIEW IF EXISTS t2; diff --git a/sql/core/src/test/resources/sql-tests/results/union.sql.out b/sql/core/src/test/resources/sql-tests/results/union.sql.out index d123b7fdbe0cf..b023df825d814 100644 --- a/sql/core/src/test/resources/sql-tests/results/union.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/union.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 14 +-- Number of queries: 16 -- !query 0 @@ -105,23 +105,29 @@ struct -- !query 9 -DROP VIEW IF EXISTS t1 +SELECT map(1, 2), 'str' +UNION ALL +SELECT map(1, 2, 3, NULL), 1 -- !query 9 schema -struct<> +struct,str:string> -- !query 9 output - +{1:2,3:null} 1 +{1:2} str -- !query 10 -DROP VIEW IF EXISTS t2 +SELECT array(1, 2), 'str' +UNION ALL +SELECT array(1, 2, 3, NULL), 1 -- !query 10 schema -struct<> +struct,str:string> -- !query 10 output - +[1,2,3,null] 1 +[1,2] str -- !query 11 -DROP VIEW IF EXISTS p1 +DROP VIEW IF EXISTS t1 -- !query 11 schema struct<> -- !query 11 output @@ -129,7 +135,7 @@ struct<> -- !query 12 -DROP VIEW IF EXISTS p2 +DROP VIEW IF EXISTS t2 -- !query 12 schema struct<> -- !query 12 output @@ -137,8 +143,24 @@ struct<> -- !query 13 -DROP VIEW IF EXISTS p3 +DROP VIEW IF EXISTS p1 -- !query 13 schema struct<> -- !query 13 output + + +-- !query 14 +DROP VIEW IF EXISTS p2 +-- !query 14 schema +struct<> +-- !query 14 output + + + +-- !query 15 +DROP VIEW IF EXISTS p3 +-- !query 15 schema +struct<> +-- !query 15 output + From 20ca208bcda6f22fe7d9fb54144de435b4237536 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 25 Apr 2018 19:06:18 +0800 Subject: [PATCH 0688/2461] [SPARK-23880][SQL] Do not trigger any jobs for caching data ## What changes were proposed in this pull request? This pr fixed code so that `cache` could prevent any jobs from being triggered. For example, in the current master, an operation below triggers a actual job; ``` val df = spark.range(10000000000L) .filter('id > 1000) .orderBy('id.desc) .cache() ``` This triggers a job while the cache should be lazy. The problem is that, when creating `InMemoryRelation`, we build the RDD, which calls `SparkPlan.execute` and may trigger jobs, like sampling job for range partitioner, or broadcast job. This pr removed the code to build a cached `RDD` in the constructor of `InMemoryRelation` and added `CachedRDDBuilder` to lazily build the `RDD` in `InMemoryRelation`. Then, the first call of `CachedRDDBuilder.cachedColumnBuffers` triggers a job to materialize the cache in `InMemoryTableScanExec` . ## How was this patch tested? Added tests in `CachedTableSuite`. Author: Takeshi Yamamuro Closes #21018 from maropu/SPARK-23880. --- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../spark/sql/execution/CacheManager.scala | 14 +- .../execution/columnar/InMemoryRelation.scala | 155 ++++++++++-------- .../columnar/InMemoryTableScanExec.scala | 10 +- .../apache/spark/sql/CachedTableSuite.scala | 36 +++- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../columnar/InMemoryColumnarQuerySuite.scala | 6 +- .../spark/sql/hive/CachedTableSuite.scala | 2 +- 8 files changed, 133 insertions(+), 94 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 917168162b236..cd4def71e6f3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2933,7 +2933,7 @@ class Dataset[T] private[sql]( */ def storageLevel: StorageLevel = { sparkSession.sharedState.cacheManager.lookupCachedData(this).map { cachedData => - cachedData.cachedRepresentation.storageLevel + cachedData.cachedRepresentation.cacheBuilder.storageLevel }.getOrElse(StorageLevel.NONE) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index a8794be7280c7..93bf91e56f1bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -71,7 +71,7 @@ class CacheManager extends Logging { /** Clears all cached tables. */ def clearCache(): Unit = writeLock { - cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist()) + cachedData.asScala.foreach(_.cachedRepresentation.cacheBuilder.clearCache()) cachedData.clear() } @@ -119,7 +119,7 @@ class CacheManager extends Logging { while (it.hasNext) { val cd = it.next() if (cd.plan.find(_.sameResult(plan)).isDefined) { - cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking) + cd.cachedRepresentation.cacheBuilder.clearCache(blocking) it.remove() } } @@ -138,16 +138,14 @@ class CacheManager extends Logging { while (it.hasNext) { val cd = it.next() if (condition(cd.plan)) { - cd.cachedRepresentation.cachedColumnBuffers.unpersist() + cd.cachedRepresentation.cacheBuilder.clearCache() // Remove the cache entry before we create a new one, so that we can have a different // physical plan. it.remove() + val plan = spark.sessionState.executePlan(cd.plan).executedPlan val newCache = InMemoryRelation( - useCompression = cd.cachedRepresentation.useCompression, - batchSize = cd.cachedRepresentation.batchSize, - storageLevel = cd.cachedRepresentation.storageLevel, - child = spark.sessionState.executePlan(cd.plan).executedPlan, - tableName = cd.cachedRepresentation.tableName, + cacheBuilder = cd.cachedRepresentation + .cacheBuilder.copy(cachedPlan = plan)(_cachedColumnBuffers = null), logicalPlan = cd.plan) needToRecache += cd.copy(cachedRepresentation = newCache) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index a7ba9b86a176f..da35a4734e65a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -32,19 +32,6 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.LongAccumulator -object InMemoryRelation { - def apply( - useCompression: Boolean, - batchSize: Int, - storageLevel: StorageLevel, - child: SparkPlan, - tableName: Option[String], - logicalPlan: LogicalPlan): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)( - statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) -} - - /** * CachedBatch is a cached batch of rows. * @@ -55,58 +42,41 @@ object InMemoryRelation { private[columnar] case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) -case class InMemoryRelation( - output: Seq[Attribute], +case class CachedRDDBuilder( useCompression: Boolean, batchSize: Int, storageLevel: StorageLevel, - @transient child: SparkPlan, + @transient cachedPlan: SparkPlan, tableName: Option[String])( - @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, - val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics, - override val outputOrdering: Seq[SortOrder]) - extends logical.LeafNode with MultiInstanceRelation { - - override protected def innerChildren: Seq[SparkPlan] = Seq(child) - - override def doCanonicalize(): logical.LogicalPlan = - copy(output = output.map(QueryPlan.normalizeExprId(_, child.output)), - storageLevel = StorageLevel.NONE, - child = child.canonicalized, - tableName = None)( - _cachedColumnBuffers, - sizeInBytesStats, - statsOfPlanToCache, - outputOrdering) + @transient private var _cachedColumnBuffers: RDD[CachedBatch] = null) { - override def producedAttributes: AttributeSet = outputSet - - @transient val partitionStatistics = new PartitionStatistics(output) + val sizeInBytesStats: LongAccumulator = cachedPlan.sqlContext.sparkContext.longAccumulator - override def computeStats(): Statistics = { - if (sizeInBytesStats.value == 0L) { - // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache. - // Note that we should drop the hint info here. We may cache a plan whose root node is a hint - // node. When we lookup the cache with a semantically same plan without hint info, the plan - // returned by cache lookup should not have hint info. If we lookup the cache with a - // semantically same plan with a different hint info, `CacheManager.useCachedData` will take - // care of it and retain the hint info in the lookup input plan. - statsOfPlanToCache.copy(hints = HintInfo()) - } else { - Statistics(sizeInBytes = sizeInBytesStats.value.longValue) + def cachedColumnBuffers: RDD[CachedBatch] = { + if (_cachedColumnBuffers == null) { + synchronized { + if (_cachedColumnBuffers == null) { + _cachedColumnBuffers = buildBuffers() + } + } } + _cachedColumnBuffers } - // If the cached column buffers were not passed in, we calculate them in the constructor. - // As in Spark, the actual work of caching is lazy. - if (_cachedColumnBuffers == null) { - buildBuffers() + def clearCache(blocking: Boolean = true): Unit = { + if (_cachedColumnBuffers != null) { + synchronized { + if (_cachedColumnBuffers != null) { + _cachedColumnBuffers.unpersist(blocking) + _cachedColumnBuffers = null + } + } + } } - private def buildBuffers(): Unit = { - val output = child.output - val cached = child.execute().mapPartitionsInternal { rowIterator => + private def buildBuffers(): RDD[CachedBatch] = { + val output = cachedPlan.output + val cached = cachedPlan.execute().mapPartitionsInternal { rowIterator => new Iterator[CachedBatch] { def next(): CachedBatch = { val columnBuilders = output.map { attribute => @@ -154,32 +124,77 @@ case class InMemoryRelation( cached.setName( tableName.map(n => s"In-memory table $n") - .getOrElse(StringUtils.abbreviate(child.toString, 1024))) - _cachedColumnBuffers = cached + .getOrElse(StringUtils.abbreviate(cachedPlan.toString, 1024))) + cached + } +} + +object InMemoryRelation { + + def apply( + useCompression: Boolean, + batchSize: Int, + storageLevel: StorageLevel, + child: SparkPlan, + tableName: Option[String], + logicalPlan: LogicalPlan): InMemoryRelation = { + val cacheBuilder = CachedRDDBuilder(useCompression, batchSize, storageLevel, child, tableName)() + new InMemoryRelation(child.output, cacheBuilder)( + statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) + } + + def apply(cacheBuilder: CachedRDDBuilder, logicalPlan: LogicalPlan): InMemoryRelation = { + new InMemoryRelation(cacheBuilder.cachedPlan.output, cacheBuilder)( + statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) + } +} + +case class InMemoryRelation( + output: Seq[Attribute], + @transient cacheBuilder: CachedRDDBuilder)( + statsOfPlanToCache: Statistics, + override val outputOrdering: Seq[SortOrder]) + extends logical.LeafNode with MultiInstanceRelation { + + override protected def innerChildren: Seq[SparkPlan] = Seq(cachedPlan) + + override def doCanonicalize(): logical.LogicalPlan = + copy(output = output.map(QueryPlan.normalizeExprId(_, cachedPlan.output)), + cacheBuilder)( + statsOfPlanToCache, + outputOrdering) + + override def producedAttributes: AttributeSet = outputSet + + @transient val partitionStatistics = new PartitionStatistics(output) + + def cachedPlan: SparkPlan = cacheBuilder.cachedPlan + + override def computeStats(): Statistics = { + if (cacheBuilder.sizeInBytesStats.value == 0L) { + // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache. + // Note that we should drop the hint info here. We may cache a plan whose root node is a hint + // node. When we lookup the cache with a semantically same plan without hint info, the plan + // returned by cache lookup should not have hint info. If we lookup the cache with a + // semantically same plan with a different hint info, `CacheManager.useCachedData` will take + // care of it and retain the hint info in the lookup input plan. + statsOfPlanToCache.copy(hints = HintInfo()) + } else { + Statistics(sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue) + } } def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { - InMemoryRelation( - newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache, outputOrdering) + InMemoryRelation(newOutput, cacheBuilder)(statsOfPlanToCache, outputOrdering) } override def newInstance(): this.type = { new InMemoryRelation( output.map(_.newInstance()), - useCompression, - batchSize, - storageLevel, - child, - tableName)( - _cachedColumnBuffers, - sizeInBytesStats, + cacheBuilder)( statsOfPlanToCache, outputOrdering).asInstanceOf[this.type] } - def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers - - override protected def otherCopyArgs: Seq[AnyRef] = - Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache) + override protected def otherCopyArgs: Seq[AnyRef] = Seq(statsOfPlanToCache) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index e73e1378d52e3..ea315fb71617c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -154,7 +154,7 @@ case class InMemoryTableScanExec( private def updateAttribute(expr: Expression): Expression = { // attributes can be pruned so using relation's output. // E.g., relation.output is [id, item] but this scan's output can be [item] only. - val attrMap = AttributeMap(relation.child.output.zip(relation.output)) + val attrMap = AttributeMap(relation.cachedPlan.output.zip(relation.output)) expr.transform { case attr: Attribute => attrMap.getOrElse(attr, attr) } @@ -163,16 +163,16 @@ case class InMemoryTableScanExec( // The cached version does not change the outputPartitioning of the original SparkPlan. // But the cached version could alias output, so we need to replace output. override def outputPartitioning: Partitioning = { - relation.child.outputPartitioning match { + relation.cachedPlan.outputPartitioning match { case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning] - case _ => relation.child.outputPartitioning + case _ => relation.cachedPlan.outputPartitioning } } // The cached version does not change the outputOrdering of the original SparkPlan. // But the cached version could alias output, so we need to replace output. override def outputOrdering: Seq[SortOrder] = - relation.child.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) + relation.cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) // Keeps relation's partition statistics because we don't serialize relation. private val stats = relation.partitionStatistics @@ -252,7 +252,7 @@ case class InMemoryTableScanExec( // within the map Partitions closure. val schema = stats.schema val schemaIndex = schema.zipWithIndex - val buffers = relation.cachedColumnBuffers + val buffers = relation.cacheBuilder.cachedColumnBuffers buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) => val partitionFilter = newPredicate( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 669e5f2bf4e65..81b7e18773f81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -22,6 +22,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import org.apache.spark.CleanerListener +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} @@ -52,7 +53,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val plan = spark.table(tableName).queryExecution.sparkPlan plan.collect { case InMemoryTableScanExec(_, _, relation) => - relation.cachedColumnBuffers.id + relation.cacheBuilder.cachedColumnBuffers.id case _ => fail(s"Table $tableName is not cached\n" + plan) }.head @@ -78,7 +79,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext private def getNumInMemoryTablesRecursively(plan: SparkPlan): Int = { plan.collect { case InMemoryTableScanExec(_, _, relation) => - getNumInMemoryTablesRecursively(relation.child) + 1 + getNumInMemoryTablesRecursively(relation.cachedPlan) + 1 }.sum } @@ -200,7 +201,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext spark.catalog.cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { spark.table("testData").queryExecution.withCachedData.collect { - case r @ InMemoryRelation(_, _, _, _, _: InMemoryTableScanExec, _) => r + case r: InMemoryRelation if r.cachedPlan.isInstanceOf[InMemoryTableScanExec] => r }.size } @@ -367,12 +368,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val toBeCleanedAccIds = new HashSet[Long] val accId1 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.sizeInBytesStats.id + case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id }.head toBeCleanedAccIds += accId1 val accId2 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.sizeInBytesStats.id + case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id }.head toBeCleanedAccIds += accId2 @@ -794,4 +795,29 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext } } } + + private def checkIfNoJobTriggered[T](f: => T): T = { + var numJobTrigered = 0 + val jobListener = new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + numJobTrigered += 1 + } + } + sparkContext.addSparkListener(jobListener) + try { + val result = f + sparkContext.listenerBus.waitUntilEmpty(10000L) + assert(numJobTrigered === 0) + result + } finally { + sparkContext.removeSparkListener(jobListener) + } + } + + test("SPARK-23880 table cache should be lazy and don't trigger any jobs") { + val cachedData = checkIfNoJobTriggered { + spark.range(1002).filter('id > 1000).orderBy('id.desc).cache() + } + assert(cachedData.collect === Seq(1001)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 40915a102bab0..f0dfe6b76f7ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -194,7 +194,7 @@ class PlannerSuite extends SharedSQLContext { test("CollectLimit can appear in the middle of a plan when caching is used") { val query = testData.select('key, 'value).limit(2).cache() val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation] - assert(planned.child.isInstanceOf[CollectLimitExec]) + assert(planned.cachedPlan.isInstanceOf[CollectLimitExec]) } test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 9b7b316211d30..863703b15f4f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -45,8 +45,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None, data.logicalPlan) - assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel) - inMemoryRelation.cachedColumnBuffers.collect().head match { + assert(inMemoryRelation.cacheBuilder.cachedColumnBuffers.getStorageLevel == storageLevel) + inMemoryRelation.cacheBuilder.cachedColumnBuffers.collect().head match { case _: CachedBatch => case other => fail(s"Unexpected cached batch type: ${other.getClass.getName}") } @@ -337,7 +337,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(cached, expectedAnswer) // Check that the right size was calculated. - assert(cached.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize) + assert(cached.cacheBuilder.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize) } test("access primitive-type columns in CachedBatch without whole stage codegen") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 48ab4eb9a6178..569f00c053e5f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -38,7 +38,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto val plan = table(tableName).queryExecution.sparkPlan plan.collect { case InMemoryTableScanExec(_, _, relation) => - relation.cachedColumnBuffers.id + relation.cacheBuilder.cachedColumnBuffers.id case _ => fail(s"Table $tableName is not cached\n" + plan) }.head From 396938ef02c70468e1695872f96b1e9aff28b7ea Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 25 Apr 2018 12:21:55 -0700 Subject: [PATCH 0689/2461] [SPARK-24050][SS] Calculate input / processing rates correctly for DataSourceV2 streaming sources ## What changes were proposed in this pull request? In some streaming queries, the input and processing rates are not calculated at all (shows up as zero) because MicroBatchExecution fails to associated metrics from the executed plan of a trigger with the sources in the logical plan of the trigger. The way this executed-plan-leaf-to-logical-source attribution works is as follows. With V1 sources, there was no way to identify which execution plan leaves were generated by a streaming source. So did a best-effort attempt to match logical and execution plan leaves when the number of leaves were same. In cases where the number of leaves is different, we just give up and report zero rates. An example where this may happen is as follows. ``` val cachedStaticDF = someStaticDF.union(anotherStaticDF).cache() val streamingInputDF = ... val query = streamingInputDF.join(cachedStaticDF).writeStream.... ``` In this case, the `cachedStaticDF` has multiple logical leaves, but in the trigger's execution plan it only has leaf because a cached subplan is represented as a single InMemoryTableScanExec leaf. This leads to a mismatch in the number of leaves causing the input rates to be computed as zero. With DataSourceV2, all inputs are represented in the executed plan using `DataSourceV2ScanExec`, each of which has a reference to the associated logical `DataSource` and `DataSourceReader`. So its easy to associate the metrics to the original streaming sources. In this PR, the solution is as follows. If all the streaming sources in a streaming query as v2 sources, then use a new code path where the execution-metrics-to-source mapping is done directly. Otherwise we fall back to existing mapping logic. ## How was this patch tested? - New unit tests using V2 memory source - Existing unit tests using V1 source Author: Tathagata Das Closes #21126 from tdas/SPARK-24050. --- .../kafka010/KafkaMicroBatchSourceSuite.scala | 9 +- .../streaming/ProgressReporter.scala | 146 +++++++++++++----- .../sql/streaming/StreamingQuerySuite.scala | 134 +++++++++++++++- 3 files changed, 245 insertions(+), 44 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index e017fd9b84d21..d2d04b68de6ab 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -563,7 +563,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { ) } - test("ensure stream-stream self-join generates only one offset in offset log") { + test("ensure stream-stream self-join generates only one offset in log and correct metrics") { val topic = newTopic() testUtils.createTopic(topic, partitions = 2) require(testUtils.getLatestOffsets(Set(topic)).size === 2) @@ -587,7 +587,12 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { AddKafkaData(Set(topic), 1, 2), CheckAnswer((1, 1, 1), (2, 2, 2)), AddKafkaData(Set(topic), 6, 3), - CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 6, 6)) + CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 6, 6)), + AssertOnQuery { q => + assert(q.availableOffsets.iterator.size == 1) + assert(q.recentProgress.map(_.numInputRows).sum == 4) + true + } ) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index d1e5be9c12762..16ad3ef9a3d4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent import org.apache.spark.util.Clock @@ -141,7 +143,7 @@ trait ProgressReporter extends Logging { } logDebug(s"Execution stats: $executionStats") - val sourceProgress = sources.map { source => + val sourceProgress = sources.distinct.map { source => val numRecords = executionStats.inputRows.getOrElse(source, 0L) new SourceProgress( description = source.toString, @@ -207,62 +209,126 @@ trait ProgressReporter extends Logging { return ExecutionStats(Map.empty, stateOperators, watermarkTimestamp) } - // We want to associate execution plan leaves to sources that generate them, so that we match - // the their metrics (e.g. numOutputRows) to the sources. To do this we do the following. - // Consider the translation from the streaming logical plan to the final executed plan. - // - // streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan - // - // 1. We keep track of streaming sources associated with each leaf in the trigger's logical plan - // - Each logical plan leaf will be associated with a single streaming source. - // - There can be multiple logical plan leaves associated with a streaming source. - // - There can be leaves not associated with any streaming source, because they were - // generated from a batch source (e.g. stream-batch joins) - // - // 2. Assuming that the executed plan has same number of leaves in the same order as that of - // the trigger logical plan, we associate executed plan leaves with corresponding - // streaming sources. - // - // 3. For each source, we sum the metrics of the associated execution plan leaves. - // - val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) => - logicalPlan.collectLeaves().map { leaf => leaf -> source } + val numInputRows = extractSourceToNumInputRows() + + val eventTimeStats = lastExecution.executedPlan.collect { + case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 => + val stats = e.eventTimeStats.value + Map( + "max" -> stats.max, + "min" -> stats.min, + "avg" -> stats.avg.toLong).mapValues(formatTimestamp) + }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp + + ExecutionStats(numInputRows, stateOperators, eventTimeStats) + } + + /** Extract number of input sources for each streaming source in plan */ + private def extractSourceToNumInputRows(): Map[BaseStreamingSource, Long] = { + + import java.util.IdentityHashMap + import scala.collection.JavaConverters._ + + def sumRows(tuples: Seq[(BaseStreamingSource, Long)]): Map[BaseStreamingSource, Long] = { + tuples.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source } - val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming - val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves() - val numInputRows: Map[BaseStreamingSource, Long] = + + val onlyDataSourceV2Sources = { + // Check whether the streaming query's logical plan has only V2 data sources + val allStreamingLeaves = + logicalPlan.collect { case s: StreamingExecutionRelation => s } + allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReader] } + } + + if (onlyDataSourceV2Sources) { + // DataSourceV2ScanExec is the execution plan leaf that is responsible for reading data + // from a V2 source and has a direct reference to the V2 source that generated it. Each + // DataSourceV2ScanExec records the number of rows it has read using SQLMetrics. However, + // just collecting all DataSourceV2ScanExec nodes and getting the metric is not correct as + // a DataSourceV2ScanExec instance may be referred to in the execution plan from two (or + // even multiple times) points and considering it twice will leads to double counting. We + // can't dedup them using their hashcode either because two different instances of + // DataSourceV2ScanExec can have the same hashcode but account for separate sets of + // records read, and deduping them to consider only one of them would be undercounting the + // records read. Therefore the right way to do this is to consider the unique instances of + // DataSourceV2ScanExec (using their identity hash codes) and get metrics from them. + // Hence we calculate in the following way. + // + // 1. Collect all the unique DataSourceV2ScanExec instances using IdentityHashMap. + // + // 2. Extract the source and the number of rows read from the DataSourceV2ScanExec instanes. + // + // 3. Multiple DataSourceV2ScanExec instance may refer to the same source (can happen with + // self-unions or self-joins). Add up the number of rows for each unique source. + val uniqueStreamingExecLeavesMap = + new IdentityHashMap[DataSourceV2ScanExec, DataSourceV2ScanExec]() + + lastExecution.executedPlan.collectLeaves().foreach { + case s: DataSourceV2ScanExec if s.reader.isInstanceOf[BaseStreamingSource] => + uniqueStreamingExecLeavesMap.put(s, s) + case _ => + } + + val sourceToInputRowsTuples = + uniqueStreamingExecLeavesMap.values.asScala.map { execLeaf => + val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) + val source = execLeaf.reader.asInstanceOf[BaseStreamingSource] + source -> numRows + }.toSeq + logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t")) + sumRows(sourceToInputRowsTuples) + } else { + + // Since V1 source do not generate execution plan leaves that directly link with source that + // generated it, we can only do a best-effort association between execution plan leaves to the + // sources. This is known to fail in a few cases, see SPARK-24050. + // + // We want to associate execution plan leaves to sources that generate them, so that we match + // the their metrics (e.g. numOutputRows) to the sources. To do this we do the following. + // Consider the translation from the streaming logical plan to the final executed plan. + // + // streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan + // + // 1. We keep track of streaming sources associated with each leaf in trigger's logical plan + // - Each logical plan leaf will be associated with a single streaming source. + // - There can be multiple logical plan leaves associated with a streaming source. + // - There can be leaves not associated with any streaming source, because they were + // generated from a batch source (e.g. stream-batch joins) + // + // 2. Assuming that the executed plan has same number of leaves in the same order as that of + // the trigger logical plan, we associate executed plan leaves with corresponding + // streaming sources. + // + // 3. For each source, we sum the metrics of the associated execution plan leaves. + // + val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) => + logicalPlan.collectLeaves().map { leaf => leaf -> source } + } + val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming + val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves() if (allLogicalPlanLeaves.size == allExecPlanLeaves.size) { val execLeafToSource = allLogicalPlanLeaves.zip(allExecPlanLeaves).flatMap { case (lp, ep) => logicalPlanLeafToSource.get(lp).map { source => ep -> source } } - val sourceToNumInputRows = execLeafToSource.map { case (execLeaf, source) => + val sourceToInputRowsTuples = execLeafToSource.map { case (execLeaf, source) => val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) source -> numRows } - sourceToNumInputRows.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source + sumRows(sourceToInputRowsTuples) } else { if (!metricWarningLogged) { def toString[T](seq: Seq[T]): String = s"(size = ${seq.size}), ${seq.mkString(", ")}" + logWarning( "Could not report metrics as number leaves in trigger logical plan did not match that" + - s" of the execution plan:\n" + - s"logical plan leaves: ${toString(allLogicalPlanLeaves)}\n" + - s"execution plan leaves: ${toString(allExecPlanLeaves)}\n") + s" of the execution plan:\n" + + s"logical plan leaves: ${toString(allLogicalPlanLeaves)}\n" + + s"execution plan leaves: ${toString(allExecPlanLeaves)}\n") metricWarningLogged = true } Map.empty } - - val eventTimeStats = lastExecution.executedPlan.collect { - case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 => - val stats = e.eventTimeStats.value - Map( - "max" -> stats.max, - "min" -> stats.min, - "avg" -> stats.avg.toLong).mapValues(formatTimestamp) - }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp - - ExecutionStats(numInputRows, stateOperators, eventTimeStats) + } } /** Records the duration of running `body` for the next query progress update. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 20942ed93897c..390d67d1feb27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -466,7 +466,17 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } - test("input row calculation with mixed batch and streaming sources") { + test("input row calculation with same V1 source used twice in self-join") { + val streamingTriggerDF = spark.createDataset(1 to 10).toDF + val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") + + val progress = getFirstProgress(streamingInputDF.join(streamingInputDF, "value")) + assert(progress.numInputRows === 20) // data is read multiple times in self-joins + assert(progress.sources.size === 1) + assert(progress.sources(0).numInputRows === 20) + } + + test("input row calculation with mixed batch and streaming V1 sources") { val streamingTriggerDF = spark.createDataset(1 to 10).toDF val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") val staticInputDF = spark.createDataFrame(Seq(1 -> "1", 2 -> "2")).toDF("value", "anotherValue") @@ -479,7 +489,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.sources(0).numInputRows === 10) } - test("input row calculation with trigger input DF having multiple leaves") { + test("input row calculation with trigger input DF having multiple leaves in V1 source") { val streamingTriggerDF = spark.createDataset(1 to 5).toDF.union(spark.createDataset(6 to 10).toDF) require(streamingTriggerDF.logicalPlan.collectLeaves().size > 1) @@ -492,6 +502,121 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.sources(0).numInputRows === 10) } + test("input row calculation with same V2 source used twice in self-union") { + val streamInput = MemoryStream[Int] + + testStream(streamInput.toDF().union(streamInput.toDF()), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + CheckAnswer(1, 1, 2, 2, 3, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 6) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 6) + true + } + ) + } + + test("input row calculation with same V2 source used twice in self-join") { + val streamInput = MemoryStream[Int] + val df = streamInput.toDF() + testStream(df.join(df, "value"), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + CheckAnswer(1, 2, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 6) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 6) + true + } + ) + } + + test("input row calculation with trigger having data for only one of two V2 sources") { + val streamInput1 = MemoryStream[Int] + val streamInput2 = MemoryStream[Int] + + testStream(streamInput1.toDF().union(streamInput2.toDF()), useV2Sink = true)( + AddData(streamInput1, 1, 2, 3), + CheckLastBatch(1, 2, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sources.length == 2) + assert(lastProgress.get.sources(0).numInputRows == 3) + assert(lastProgress.get.sources(1).numInputRows == 0) + true + }, + AddData(streamInput2, 4, 5), + CheckLastBatch(4, 5), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 2) + assert(lastProgress.get.sources.length == 2) + assert(lastProgress.get.sources(0).numInputRows == 0) + assert(lastProgress.get.sources(1).numInputRows == 2) + true + } + ) + } + + test("input row calculation with mixed batch and streaming V2 sources") { + + val streamInput = MemoryStream[Int] + val staticInputDF = spark.createDataFrame(Seq(1 -> "1", 2 -> "2")).toDF("value", "anotherValue") + + testStream(streamInput.toDF().join(staticInputDF, "value"), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + AssertOnQuery { q => + q.processAllAvailable() + + // The number of leaves in the trigger's logical plan should be same as the executed plan. + require( + q.lastExecution.logical.collectLeaves().length == + q.lastExecution.executedPlan.collectLeaves().length) + + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 3) + true + } + ) + + val streamInput2 = MemoryStream[Int] + val staticInputDF2 = staticInputDF.union(staticInputDF).cache() + + testStream(streamInput2.toDF().join(staticInputDF2, "value"), useV2Sink = true)( + AddData(streamInput2, 1, 2, 3), + AssertOnQuery { q => + q.processAllAvailable() + // The number of leaves in the trigger's logical plan should be different from + // the executed plan. The static input will have two leaves in the logical plan + // (due to the union), but will be converted to a single leaf in the executed plan + // (due to the caching, the cached subplan is replaced by a single InMemoryTableScanExec). + require( + q.lastExecution.logical.collectLeaves().length != + q.lastExecution.executedPlan.collectLeaves().length) + + // Despite the mismatch in total number of leaves in the logical and executed plans, + // we should be able to attribute streaming input metrics to the streaming sources. + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 3) + true + } + ) + } + testQuietly("StreamExecution metadata garbage collection") { val inputData = MemoryStream[Int] val mapped = inputData.toDS().map(6 / _) @@ -733,6 +858,11 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } + /** Returns the last query progress from query.recentProgress where numInputRows is positive */ + def getLastProgressWithData(q: StreamingQuery): Option[StreamingQueryProgress] = { + q.recentProgress.filter(_.numInputRows > 0).lastOption + } + /** * A [[StreamAction]] to test the behavior of `StreamingQuery.awaitTermination()`. * From ac4ca7c4dd3ff666ec70aeb26ac84cffa557ee12 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 25 Apr 2018 13:42:44 -0700 Subject: [PATCH 0690/2461] [SPARK-24012][SQL][TEST][FOLLOWUP] add unit test ## What changes were proposed in this pull request? a followup of https://github.com/apache/spark/pull/21100 ## How was this patch tested? N/A Author: Wenchen Fan Closes #21154 from cloud-fan/test. --- .../catalyst/analysis/TypeCoercionSuite.scala | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index fd6a3121663ed..1cc431aaf0a60 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -429,6 +429,24 @@ class TypeCoercionSuite extends AnalysisTest { Some(StructType(Seq(StructField("a", IntegerType), StructField("B", IntegerType)))), isSymmetric = false) } + + widenTest( + ArrayType(IntegerType, containsNull = true), + ArrayType(IntegerType, containsNull = false), + Some(ArrayType(IntegerType, containsNull = true))) + + widenTest( + MapType(IntegerType, StringType, valueContainsNull = true), + MapType(IntegerType, StringType, valueContainsNull = false), + Some(MapType(IntegerType, StringType, valueContainsNull = true))) + + widenTest( + new StructType() + .add("arr", ArrayType(IntegerType, containsNull = true), nullable = false), + new StructType() + .add("arr", ArrayType(IntegerType, containsNull = false), nullable = true), + Some(new StructType() + .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true))) } test("wider common type for decimal and array") { From 95a651339ec39d5753e849e578ad715be0d7c83e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 26 Apr 2018 09:12:38 +0800 Subject: [PATCH 0691/2461] [SPARK-24069][R] Add array_min / array_max functions ## What changes were proposed in this pull request? This PR proposes to add array_max and array_min in R side too. array_max: ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_array(df$gear, df$am, df$carb)) head(select(mutated, array_max(mutated$v1))) ``` ``` array_max(v1) 1 4 2 4 3 4 4 3 5 3 6 3 ``` array_min: ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) head(select(mutated, array_min(mutated$v1))) ``` ``` array_min(v1) 1 6 2 6 3 4 4 6 5 8 6 6 ``` ## How was this patch tested? Unit tests were added in `R/pkg/tests/fulltests/test_sparkSQL.R` and manually tested. Documentation was manually built and verified. Author: hyukjinkwon Closes #21142 from HyukjinKwon/sparkr_array_min_array_max. --- R/pkg/NAMESPACE | 2 ++ R/pkg/R/functions.R | 27 +++++++++++++++++++++++++++ R/pkg/R/generics.R | 8 ++++++++ R/pkg/tests/fulltests/test_sparkSQL.R | 9 ++++++++- 4 files changed, 45 insertions(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 55dec177ea853..f36d462a83cb0 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -201,6 +201,8 @@ exportMethods("%<=>%", "approxCountDistinct", "approxQuantile", "array_contains", + "array_max", + "array_min", "array_position", "asc", "ascii", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 7b3aa05074563..ec4bd4e73c7e5 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -206,6 +206,7 @@ NULL #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) +#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) #' head(select(tmp, array_position(tmp$v1, 21))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) @@ -2992,6 +2993,32 @@ setMethod("array_contains", column(jc) }) +#' @details +#' \code{array_max}: Returns the maximum value of the array. +#' +#' @rdname column_collection_functions +#' @aliases array_max array_max,Column-method +#' @note array_max since 2.4.0 +setMethod("array_max", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_max", x@jc) + column(jc) + }) + +#' @details +#' \code{array_min}: Returns the minimum value of the array. +#' +#' @rdname column_collection_functions +#' @aliases array_min array_min,Column-method +#' @note array_min since 2.4.0 +setMethod("array_min", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_min", x@jc) + column(jc) + }) + #' @details #' \code{array_position}: Locates the position of the first occurrence of the given value #' in the given array. Returns NA if either of the arguments are NA. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index f30ac9e4295e4..562d3399ee9c8 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -757,6 +757,14 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_max", function(x) { standardGeneric("array_max") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_min", function(x) { standardGeneric("array_min") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index a384997830276..8cc2db7a140f9 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1479,11 +1479,18 @@ test_that("column functions", { df5 <- createDataFrame(list(list(a = "010101"))) expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") - # Test array_contains(), array_position(), element_at() and sort_array() + # Test array_contains(), array_max(), array_min(), array_position(), element_at() + # and sort_array() df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] expect_equal(result, c(TRUE, FALSE)) + result <- collect(select(df, array_max(df[[1]])))[[1]] + expect_equal(result, c(3, 6)) + + result <- collect(select(df, array_min(df[[1]])))[[1]] + expect_equal(result, c(1, 4)) + result <- collect(select(df, array_position(df[[1]], 1L)))[[1]] expect_equal(result, c(1, 0)) From 3f1e999d3d215bb3b867bcd83ec5c799448ec730 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 26 Apr 2018 09:14:24 +0800 Subject: [PATCH 0692/2461] [SPARK-23849][SQL] Tests for samplingRatio of json datasource ## What changes were proposed in this pull request? Added the `samplingRatio` option to the `json()` method of PySpark DataFrame Reader. Improving existing tests for Scala API according to review of the PR: https://github.com/apache/spark/pull/20959 ## How was this patch tested? Added new test for PySpark, updated 2 existing tests according to reviews of https://github.com/apache/spark/pull/20959 and added new negative test Author: Maxim Gekk Closes #21056 from MaxGekk/json-sampling. --- python/pyspark/sql/readwriter.py | 7 ++- python/pyspark/sql/tests.py | 8 +++ .../apache/spark/sql/DataFrameReader.scala | 2 + .../datasources/json/JsonSuite.scala | 63 ++++++++++--------- .../datasources/json/TestJsonData.scala | 12 ++++ 5 files changed, 61 insertions(+), 31 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 6bd79bc2f43e5..df176c579fc8b 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -176,7 +176,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None, lineSep=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -239,6 +239,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, including tab and line feed characters) or not. :param lineSep: defines the line separator that should be used for parsing. If None is set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. + :param samplingRatio: defines fraction of input JSON objects used for schema inferring. + If None is set, it uses the default value, ``1.0``. >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes @@ -256,7 +258,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, + samplingRatio=samplingRatio) if isinstance(path, basestring): path = [path] if type(path) == list: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4e99c8e3c6b10..98fa1b54b0a17 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3018,6 +3018,14 @@ def test_sort_with_nulls_order(self): df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(), [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)]) + def test_json_sampling_ratio(self): + rdd = self.spark.sparkContext.range(0, 100, 1, 1) \ + .map(lambda x: '{"a":0.1}' if x == 1 else '{"a":%s}' % str(x)) + schema = self.spark.read.option('inferSchema', True) \ + .option('samplingRatio', 0.5) \ + .json(rdd).schema + self.assertEquals(schema, StructType([StructField("a", LongType(), True)])) + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index d640fdc530ce2..b44552f0eb17b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -374,6 +374,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * per file *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator * that should be used for parsing.
  • + *
  • `samplingRatio` (default is 1.0): defines fraction of input JSON objects used + * for schema inferring.
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 70aee561ff0f6..a58dff827b92d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2128,38 +2128,43 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } } - test("SPARK-23849: schema inferring touches less data if samplingRation < 1.0") { - val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, - 57, 62, 68, 72) - withTempPath { path => - val writer = Files.newBufferedWriter(Paths.get(path.getAbsolutePath), - StandardCharsets.UTF_8, StandardOpenOption.CREATE_NEW) - for (i <- 0 until 100) { - if (predefinedSample.contains(i)) { - writer.write(s"""{"f1":${i.toString}}""" + "\n") - } else { - writer.write(s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n") - } - } - writer.close() + test("SPARK-23849: schema inferring touches less data if samplingRatio < 1.0") { + // Set default values for the DataSource parameters to make sure + // that whole test file is mapped to only one partition. This will guarantee + // reliable sampling of the input file. + withSQLConf( + "spark.sql.files.maxPartitionBytes" -> (128 * 1024 * 1024).toString, + "spark.sql.files.openCostInBytes" -> (4 * 1024 * 1024).toString + )(withTempPath { path => + val ds = sampledTestData.coalesce(1) + ds.write.text(path.getAbsolutePath) + val readback = spark.read.option("samplingRatio", 0.1).json(path.getCanonicalPath) + + assert(readback.schema == new StructType().add("f1", LongType)) + }) + } - val ds = spark.read.option("samplingRatio", 0.1).json(path.getCanonicalPath) - assert(ds.schema == new StructType().add("f1", LongType)) - } + test("SPARK-23849: usage of samplingRatio while parsing a dataset of strings") { + val ds = sampledTestData.coalesce(1) + val readback = spark.read.option("samplingRatio", 0.1).json(ds) + + assert(readback.schema == new StructType().add("f1", LongType)) } - test("SPARK-23849: usage of samplingRation while parsing of dataset of strings") { - val dstr = spark.sparkContext.parallelize(0 until 100, 1).map { i => - val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, - 57, 62, 68, 72) - if (predefinedSample.contains(i)) { - s"""{"f1":${i.toString}}""" + "\n" - } else { - s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n" - } - }.toDS() - val ds = spark.read.option("samplingRatio", 0.1).json(dstr) + test("SPARK-23849: samplingRatio is out of the range (0, 1.0]") { + val ds = spark.range(0, 100, 1, 1).map(_.toString) + + val errorMsg0 = intercept[IllegalArgumentException] { + spark.read.option("samplingRatio", -1).json(ds) + }.getMessage + assert(errorMsg0.contains("samplingRatio (-1.0) should be greater than 0")) + + val errorMsg1 = intercept[IllegalArgumentException] { + spark.read.option("samplingRatio", 0).json(ds) + }.getMessage + assert(errorMsg1.contains("samplingRatio (0.0) should be greater than 0")) - assert(ds.schema == new StructType().add("f1", LongType)) + val sampled = spark.read.option("samplingRatio", 1.0).json(ds) + assert(sampled.count() == ds.count()) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index 13084ba4a7f04..6e9559edf8ec2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -233,4 +233,16 @@ private[json] trait TestJsonData { spark.createDataset(spark.sparkContext.parallelize("""{"a":123}""" :: Nil))(Encoders.STRING) def empty: Dataset[String] = spark.emptyDataset(Encoders.STRING) + + def sampledTestData: Dataset[String] = { + spark.range(0, 100, 1).map { index => + val predefinedSample = Set[Long](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, + 57, 62, 68, 72) + if (predefinedSample.contains(index)) { + s"""{"f1":${index.toString}}""" + } else { + s"""{"f1":${(index.toDouble + 0.1).toString}}""" + } + }(Encoders.STRING) + } } From 58c55cb4a6d72d72df908e37aa63f617b3cc5587 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 26 Apr 2018 12:19:20 +0900 Subject: [PATCH 0693/2461] [SPARK-23902][SQL] Add roundOff flag to months_between ## What changes were proposed in this pull request? HIVE-15511 introduced the `roundOff` flag in order to disable the rounding to 8 digits which is performed in `months_between`. Since this can be a computational intensive operation, skipping it may improve performances when the rounding is not needed. ## How was this patch tested? modified existing UT Author: Marco Gaido Closes #21008 from mgaido91/SPARK-23902. --- python/pyspark/sql/functions.py | 10 +++- .../expressions/datetimeExpressions.scala | 33 +++++++---- .../sql/catalyst/util/DateTimeUtils.scala | 32 ++++------ .../expressions/DateExpressionsSuite.scala | 59 +++++++++++-------- .../catalyst/util/DateTimeUtilsSuite.scala | 30 +++++++--- .../org/apache/spark/sql/functions.scala | 13 +++- .../apache/spark/sql/DateFunctionsSuite.scala | 7 +++ 7 files changed, 118 insertions(+), 66 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index de53b48b6f3b4..38ae41a5dafe6 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1088,16 +1088,20 @@ def add_months(start, months): @since(1.5) -def months_between(date1, date2): +def months_between(date1, date2, roundOff=True): """ Returns the number of months between date1 and date2. + Unless `roundOff` is set to `False`, the result is rounded off to 8 digits. >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2']) >>> df.select(months_between(df.date1, df.date2).alias('months')).collect() - [Row(months=3.9495967...)] + [Row(months=3.94959677)] + >>> df.select(months_between(df.date1, df.date2, False).alias('months')).collect() + [Row(months=3.9495967741935485)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2))) + return Column(sc._jvm.functions.months_between( + _to_java_column(date1), _to_java_column(date2), roundOff)) @since(2.2) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index b9b2cd5bdb9f0..d882d06cfd625 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1156,38 +1156,49 @@ case class AddMonths(startDate: Expression, numMonths: Expression) */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(timestamp1, timestamp2) - Returns number of months between `timestamp1` and `timestamp2`.", + usage = """ + _FUNC_(timestamp1, timestamp2[, roundOff]) - Returns number of months between `timestamp1` and `timestamp2`. + The result is rounded to 8 decimal places by default. Set roundOff=false otherwise."""", examples = """ Examples: > SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30'); 3.94959677 + > SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30', false); + 3.9495967741935485 """, since = "1.5.0") // scalastyle:on line.size.limit -case class MonthsBetween(date1: Expression, date2: Expression, timeZoneId: Option[String] = None) - extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { +case class MonthsBetween( + date1: Expression, + date2: Expression, + roundOff: Expression, + timeZoneId: Option[String] = None) + extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + + def this(date1: Expression, date2: Expression) = this(date1, date2, Literal.TrueLiteral, None) - def this(date1: Expression, date2: Expression) = this(date1, date2, None) + def this(date1: Expression, date2: Expression, roundOff: Expression) = + this(date1, date2, roundOff, None) - override def left: Expression = date1 - override def right: Expression = date2 + override def children: Seq[Expression] = Seq(date1, date2, roundOff) - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType) + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType, BooleanType) override def dataType: DataType = DoubleType override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) - override def nullSafeEval(t1: Any, t2: Any): Any = { - DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long], timeZone) + override def nullSafeEval(t1: Any, t2: Any, roundOff: Any): Any = { + DateTimeUtils.monthsBetween( + t1.asInstanceOf[Long], t2.asInstanceOf[Long], roundOff.asInstanceOf[Boolean], timeZone) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (l, r) => { - s"""$dtu.monthsBetween($l, $r, $tz)""" + defineCodeGen(ctx, ev, (d1, d2, roundOff) => { + s"""$dtu.monthsBetween($d1, $d2, $roundOff, $tz)""" }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index fa69b8af62c85..4b00a61c6cf91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -870,24 +870,14 @@ object DateTimeUtils { * If time1 and time2 having the same day of month, or both are the last day of month, * it returns an integer (time under a day will be ignored). * - * Otherwise, the difference is calculated based on 31 days per month, and rounding to - * 8 digits. + * Otherwise, the difference is calculated based on 31 days per month. + * If `roundOff` is set to true, the result is rounded to 8 decimal places. */ - def monthsBetween(time1: SQLTimestamp, time2: SQLTimestamp): Double = { - monthsBetween(time1, time2, defaultTimeZone()) - } - - /** - * Returns number of months between time1 and time2. time1 and time2 are expressed in - * microseconds since 1.1.1970. - * - * If time1 and time2 having the same day of month, or both are the last day of month, - * it returns an integer (time under a day will be ignored). - * - * Otherwise, the difference is calculated based on 31 days per month, and rounding to - * 8 digits. - */ - def monthsBetween(time1: SQLTimestamp, time2: SQLTimestamp, timeZone: TimeZone): Double = { + def monthsBetween( + time1: SQLTimestamp, + time2: SQLTimestamp, + roundOff: Boolean, + timeZone: TimeZone): Double = { val millis1 = time1 / 1000L val millis2 = time2 / 1000L val date1 = millisToDays(millis1, timeZone) @@ -906,8 +896,12 @@ object DateTimeUtils { val timeInDay2 = millis2 - daysToMillis(date2, timeZone) val timesBetween = (timeInDay1 - timeInDay2).toDouble / MILLIS_PER_DAY val diff = (months1 - months2).toDouble + (dayInMonth1 - dayInMonth2 + timesBetween) / 31.0 - // rounding to 8 digits - math.round(diff * 1e8) / 1e8 + if (roundOff) { + // rounding to 8 digits + math.round(diff * 1e8) / 1e8 + } else { + diff + } } // Thursday = 0 since 1970/Jan/01 => Thursday diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 080ec487cfa6a..63b24fb9eb13a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -464,34 +464,47 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { MonthsBetween( Literal(new Timestamp(sdf.parse("1997-02-28 10:30:00").getTime)), Literal(new Timestamp(sdf.parse("1996-10-30 00:00:00").getTime)), - timeZoneId), - 3.94959677) + Literal.TrueLiteral, + timeZoneId = timeZoneId), 3.94959677) checkEvaluation( MonthsBetween( - Literal(new Timestamp(sdf.parse("2015-01-30 11:52:00").getTime)), - Literal(new Timestamp(sdf.parse("2015-01-30 11:50:00").getTime)), - timeZoneId), - 0.0) - checkEvaluation( - MonthsBetween( - Literal(new Timestamp(sdf.parse("2015-01-31 00:00:00").getTime)), - Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), - timeZoneId), - -2.0) - checkEvaluation( - MonthsBetween( - Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), - Literal(new Timestamp(sdf.parse("2015-02-28 00:00:00").getTime)), - timeZoneId), - 1.0) + Literal(new Timestamp(sdf.parse("1997-02-28 10:30:00").getTime)), + Literal(new Timestamp(sdf.parse("1996-10-30 00:00:00").getTime)), + Literal.FalseLiteral, + timeZoneId = timeZoneId), 3.9495967741935485) + + Seq(Literal.FalseLiteral, Literal.TrueLiteral). foreach { roundOff => + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-01-30 11:52:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-01-30 11:50:00").getTime)), + roundOff, + timeZoneId = timeZoneId), 0.0) + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-01-31 00:00:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), + roundOff, + timeZoneId = timeZoneId), -2.0) + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-02-28 00:00:00").getTime)), + roundOff, + timeZoneId = timeZoneId), 1.0) + } val t = Literal(Timestamp.valueOf("2015-03-31 22:00:00")) val tnull = Literal.create(null, TimestampType) - checkEvaluation(MonthsBetween(t, tnull, timeZoneId), null) - checkEvaluation(MonthsBetween(tnull, t, timeZoneId), null) - checkEvaluation(MonthsBetween(tnull, tnull, timeZoneId), null) + checkEvaluation(MonthsBetween(t, tnull, Literal.TrueLiteral, timeZoneId = timeZoneId), null) + checkEvaluation(MonthsBetween(tnull, t, Literal.TrueLiteral, timeZoneId = timeZoneId), null) + checkEvaluation( + MonthsBetween(tnull, tnull, Literal.TrueLiteral, timeZoneId = timeZoneId), null) + checkEvaluation( + MonthsBetween(t, t, Literal.create(null, BooleanType), timeZoneId = timeZoneId), null) checkConsistencyBetweenInterpretedAndCodegen( - (time1: Expression, time2: Expression) => MonthsBetween(time1, time2, timeZoneId), - TimestampType, TimestampType) + (time1: Expression, time2: Expression, roundOff: Expression) => + MonthsBetween(time1, time2, roundOff, timeZoneId = timeZoneId), + TimestampType, TimestampType, BooleanType) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 625ff38943fa3..cbf6106697f30 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -490,24 +490,36 @@ class DateTimeUtilsSuite extends SparkFunSuite { c1.set(1997, 1, 28, 10, 30, 0) val c2 = Calendar.getInstance() c2.set(1996, 9, 30, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 3.94959677) - c2.set(2000, 1, 28, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36) - c2.set(2000, 1, 29, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36) - c2.set(1996, 2, 31, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 11) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, true, c1.getTimeZone) === 3.94959677) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, false, c1.getTimeZone) + === 3.9495967741935485) + Seq(true, false).foreach { roundOff => + c2.set(2000, 1, 28, 0, 0, 0) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === -36) + c2.set(2000, 1, 29, 0, 0, 0) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === -36) + c2.set(1996, 2, 31, 0, 0, 0) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === 11) + } val c3 = Calendar.getInstance(TimeZonePST) c3.set(2000, 1, 28, 16, 0, 0) val c4 = Calendar.getInstance(TimeZonePST) c4.set(1997, 1, 28, 16, 0, 0) assert( - monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, TimeZonePST) + monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, true, TimeZonePST) === 36.0) assert( - monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, TimeZoneGMT) + monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, true, TimeZoneGMT) === 35.90322581) + assert( + monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, false, TimeZoneGMT) + === 35.903225806451616) } test("from UTC timestamp") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d2f057310f89b..f1587cd032adc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2691,11 +2691,22 @@ object functions { /** * Returns number of months between dates `date1` and `date2`. + * The result is rounded off to 8 digits. * @group datetime_funcs * @since 1.5.0 */ def months_between(date1: Column, date2: Column): Column = withExpr { - MonthsBetween(date1.expr, date2.expr) + new MonthsBetween(date1.expr, date2.expr) + } + + /** + * Returns number of months between dates `date1` and `date2`. If `roundOff` is set to true, the + * result is rounded off to 8 digits; it is not rounded otherwise. + * @group datetime_funcs + * @since 2.4.0 + */ + def months_between(date1: Column, date2: Column, roundOff: Boolean): Column = withExpr { + MonthsBetween(date1.expr, date2.expr, lit(roundOff).expr) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 6bbf38516cdf6..f712baa7a9134 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -327,6 +327,13 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { val df = Seq((t1, d1, s1), (t2, d2, s2)).toDF("t", "d", "s") checkAnswer(df.select(months_between(col("t"), col("d"))), Seq(Row(-10.0), Row(7.0))) checkAnswer(df.selectExpr("months_between(t, s)"), Seq(Row(0.5), Row(-0.5))) + checkAnswer(df.selectExpr("months_between(t, s, true)"), Seq(Row(0.5), Row(-0.5))) + Seq(true, false).foreach { roundOff => + checkAnswer(df.select(months_between(col("t"), col("d"), roundOff)), + Seq(Row(-10.0), Row(7.0))) + checkAnswer(df.withColumn("r", lit(false)).selectExpr("months_between(t, s, r)"), + Seq(Row(0.5), Row(-0.5))) + } } test("function last_day") { From cd10f9df8284ee8a5d287b2cd204c70b8ba87f5e Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 26 Apr 2018 13:37:13 +0900 Subject: [PATCH 0694/2461] [SPARK-23916][SQL] Add array_join function ## What changes were proposed in this pull request? The PR adds the SQL function `array_join`. The behavior of the function is based on Presto's one. The function accepts an `array` of `string` which is to be joined, a `string` which is the delimiter to use between the items of the first argument and optionally a `string` which is used to replace `null` values. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21011 from mgaido91/SPARK-23916. --- python/pyspark/sql/functions.py | 21 +++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 169 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 35 ++++ .../org/apache/spark/sql/functions.scala | 19 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 23 +++ 6 files changed, 268 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 38ae41a5dafe6..ad4bd6f5089e9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1834,6 +1834,27 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@ignore_unicode_prefix +@since(2.4) +def array_join(col, delimiter, null_replacement=None): + """ + Concatenates the elements of `column` using the `delimiter`. Null values are replaced with + `null_replacement` if set, otherwise they are ignored. + + >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data']) + >>> df.select(array_join(df.data, ",").alias("joined")).collect() + [Row(joined=u'a,b,c'), Row(joined=u'a')] + >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect() + [Row(joined=u'a,b,c'), Row(joined=u'a,NULL')] + """ + sc = SparkContext._active_spark_context + if null_replacement is None: + return Column(sc._jvm.functions.array_join(_to_java_column(col), delimiter)) + else: + return Column(sc._jvm.functions.array_join( + _to_java_column(col), delimiter, null_replacement)) + + @since(1.5) @ignore_unicode_prefix def concat(*cols): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 6afcf309bd690..6bc7b4e4f7cb3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -401,6 +401,7 @@ object FunctionRegistry { // collection functions expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), + expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index bc71b5f34ce4a..90223b9126555 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -378,6 +378,175 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Creates a String containing all the elements of the input array separated by the delimiter. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, delimiter[, nullReplacement]) - Concatenates the elements of the given array + using the delimiter and an optional string to replace nulls. If no value is set for + nullReplacement, any null value is filtered.""", + examples = """ + Examples: + > SELECT _FUNC_(array('hello', 'world'), ' '); + hello world + > SELECT _FUNC_(array('hello', null ,'world'), ' '); + hello world + > SELECT _FUNC_(array('hello', null ,'world'), ' ', ','); + hello , world + """, since = "2.4.0") +case class ArrayJoin( + array: Expression, + delimiter: Expression, + nullReplacement: Option[Expression]) extends Expression with ExpectsInputTypes { + + def this(array: Expression, delimiter: Expression) = this(array, delimiter, None) + + def this(array: Expression, delimiter: Expression, nullReplacement: Expression) = + this(array, delimiter, Some(nullReplacement)) + + override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) { + Seq(ArrayType(StringType), StringType, StringType) + } else { + Seq(ArrayType(StringType), StringType) + } + + override def children: Seq[Expression] = if (nullReplacement.isDefined) { + Seq(array, delimiter, nullReplacement.get) + } else { + Seq(array, delimiter) + } + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = { + val arrayEval = array.eval(input) + if (arrayEval == null) return null + val delimiterEval = delimiter.eval(input) + if (delimiterEval == null) return null + val nullReplacementEval = nullReplacement.map(_.eval(input)) + if (nullReplacementEval.contains(null)) return null + + val buffer = new UTF8StringBuilder() + var firstItem = true + val nullHandling = nullReplacementEval match { + case Some(rep) => (prependDelimiter: Boolean) => { + if (!prependDelimiter) { + buffer.append(delimiterEval.asInstanceOf[UTF8String]) + } + buffer.append(rep.asInstanceOf[UTF8String]) + true + } + case None => (_: Boolean) => false + } + arrayEval.asInstanceOf[ArrayData].foreach(StringType, (_, item) => { + if (item == null) { + if (nullHandling(firstItem)) { + firstItem = false + } + } else { + if (!firstItem) { + buffer.append(delimiterEval.asInstanceOf[UTF8String]) + } + buffer.append(item.asInstanceOf[UTF8String]) + firstItem = false + } + }) + buffer.build() + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val code = nullReplacement match { + case Some(replacement) => + val replacementGen = replacement.genCode(ctx) + val nullHandling = (buffer: String, delimiter: String, firstItem: String) => { + s""" + |if (!$firstItem) { + | $buffer.append($delimiter); + |} + |$buffer.append(${replacementGen.value}); + |$firstItem = false; + """.stripMargin + } + val execCode = if (replacement.nullable) { + ctx.nullSafeExec(replacement.nullable, replacementGen.isNull) { + genCodeForArrayAndDelimiter(ctx, ev, nullHandling) + } + } else { + genCodeForArrayAndDelimiter(ctx, ev, nullHandling) + } + s""" + |${replacementGen.code} + |$execCode + """.stripMargin + case None => genCodeForArrayAndDelimiter(ctx, ev, + (_: String, _: String, _: String) => "// nulls are ignored") + } + if (nullable) { + ev.copy( + s""" + |boolean ${ev.isNull} = true; + |UTF8String ${ev.value} = null; + |$code + """.stripMargin) + } else { + ev.copy( + s""" + |UTF8String ${ev.value} = null; + |$code + """.stripMargin, FalseLiteral) + } + } + + private def genCodeForArrayAndDelimiter( + ctx: CodegenContext, + ev: ExprCode, + nullEval: (String, String, String) => String): String = { + val arrayGen = array.genCode(ctx) + val delimiterGen = delimiter.genCode(ctx) + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val i = ctx.freshName("i") + val firstItem = ctx.freshName("firstItem") + val resultCode = + s""" + |$bufferClass $buffer = new $bufferClass(); + |boolean $firstItem = true; + |for (int $i = 0; $i < ${arrayGen.value}.numElements(); $i ++) { + | if (${arrayGen.value}.isNullAt($i)) { + | ${nullEval(buffer, delimiterGen.value, firstItem)} + | } else { + | if (!$firstItem) { + | $buffer.append(${delimiterGen.value}); + | } + | $buffer.append(${CodeGenerator.getValue(arrayGen.value, StringType, i)}); + | $firstItem = false; + | } + |} + |${ev.value} = $buffer.build();""".stripMargin + + if (array.nullable || delimiter.nullable) { + arrayGen.code + ctx.nullSafeExec(array.nullable, arrayGen.isNull) { + delimiterGen.code + ctx.nullSafeExec(delimiter.nullable, delimiterGen.isNull) { + s""" + |${ev.isNull} = false; + |$resultCode""".stripMargin + } + } + } else { + s""" + |${arrayGen.code} + |${delimiterGen.code} + |$resultCode""".stripMargin + } + } + + override def dataType: DataType = StringType + +} + /** * Returns the minimum value in the array. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index b49fa76b2a781..7048d93fd5649 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -106,6 +106,41 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + test("ArrayJoin") { + def testArrays( + arrays: Seq[Expression], + nullReplacement: Option[Expression], + expected: Seq[String]): Unit = { + assert(arrays.length == expected.length) + arrays.zip(expected).foreach { case (arr, exp) => + checkEvaluation(ArrayJoin(arr, Literal(","), nullReplacement), exp) + } + } + + val arrays = Seq(Literal.create(Seq[String]("a", "b"), ArrayType(StringType)), + Literal.create(Seq[String]("a", null, "b"), ArrayType(StringType)), + Literal.create(Seq[String](null), ArrayType(StringType)), + Literal.create(Seq[String]("a", "b", null), ArrayType(StringType)), + Literal.create(Seq[String](null, "a", "b"), ArrayType(StringType)), + Literal.create(Seq[String]("a"), ArrayType(StringType))) + + val withoutNullReplacement = Seq("a,b", "a,b", "", "a,b", "a,b", "a") + val withNullReplacement = Seq("a,b", "a,NULL,b", "NULL", "a,b,NULL", "NULL,a,b", "a") + testArrays(arrays, None, withoutNullReplacement) + testArrays(arrays, Some(Literal("NULL")), withNullReplacement) + + checkEvaluation(ArrayJoin( + Literal.create(null, ArrayType(StringType)), Literal(","), None), null) + checkEvaluation(ArrayJoin( + Literal.create(Seq[String](null), ArrayType(StringType)), + Literal.create(null, StringType), + None), null) + checkEvaluation(ArrayJoin( + Literal.create(Seq[String](null), ArrayType(StringType)), + Literal(","), + Some(Literal.create(null, StringType))), null) + } + test("Array Min") { checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11) checkEvaluation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f1587cd032adc..25afaacc38d6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3039,6 +3039,25 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Concatenates the elements of `column` using the `delimiter`. Null values are replaced with + * `nullReplacement`. + * @group collection_funcs + * @since 2.4.0 + */ + def array_join(column: Column, delimiter: String, nullReplacement: String): Column = withExpr { + ArrayJoin(column.expr, Literal(delimiter), Some(Literal(nullReplacement))) + } + + /** + * Concatenates the elements of `column` using the `delimiter`. + * @group collection_funcs + * @since 2.4.0 + */ + def array_join(column: Column, delimiter: String): Column = withExpr { + ArrayJoin(column.expr, Literal(delimiter), None) + } + /** * Concatenates multiple input columns together into a single column. * The function works with strings, binary and compatible array columns. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 03605c30036a3..c216d1322a06c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -413,6 +413,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("array_join function") { + val df = Seq( + (Seq[String]("a", "b"), ","), + (Seq[String]("a", null, "b"), ","), + (Seq.empty[String], ",") + ).toDF("x", "delimiter") + + checkAnswer( + df.select(array_join(df("x"), ";")), + Seq(Row("a;b"), Row("a;b"), Row("")) + ) + checkAnswer( + df.select(array_join(df("x"), ";", "NULL")), + Seq(Row("a;b"), Row("a;NULL;b"), Row("")) + ) + checkAnswer( + df.selectExpr("array_join(x, delimiter)"), + Seq(Row("a,b"), Row("a,b"), Row(""))) + checkAnswer( + df.selectExpr("array_join(x, delimiter, 'NULL')"), + Seq(Row("a,b"), Row("a,NULL,b"), Row(""))) + } + test("array_min function") { val df = Seq( Seq[Option[Int]](Some(1), Some(3), Some(2)), From ffaf0f9fd407aeba7006f3d785ea8a0e51187357 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 26 Apr 2018 13:27:33 +0800 Subject: [PATCH 0695/2461] [SPARK-24062][THRIFT SERVER] Fix SASL encryption cannot enabled issue in thrift server ## What changes were proposed in this pull request? For the details of the exception please see [SPARK-24062](https://issues.apache.org/jira/browse/SPARK-24062). The issue is: Spark on Yarn stores SASL secret in current UGI's credentials, this credentials will be distributed to AM and executors, so that executors and drive share the same secret to communicate. But STS/Hive library code will refresh the current UGI by UGI's loginFromKeytab() after Spark application is started, this will create a new UGI in the current driver's context with empty tokens and secret keys, so secret key is lost in the current context's UGI, that's why Spark driver throws secret key not found exception. In Spark 2.2 code, Spark also stores this secret key in SecurityManager's class variable, so even UGI is refreshed, the secret is still existed in the object, so STS with SASL can still be worked in Spark 2.2. But in Spark 2.3, we always search key from current UGI, which makes it fail to work in Spark 2.3. To fix this issue, there're two possible solutions: 1. Fix in STS/Hive library, when a new UGI is refreshed, copy the secret key from original UGI to the new one. The difficulty is that some codes to refresh the UGI is existed in Hive library, which makes us hard to change the code. 2. Roll back the logics in SecurityManager to match Spark 2.2, so that this issue can be fixed. 2nd solution seems a simple one. So I will propose a PR with 2nd solution. ## How was this patch tested? Verified in local cluster. CC vanzin tgravescs please help to review. Thanks! Author: jerryshao Closes #21138 from jerryshao/SPARK-24062. --- .../main/scala/org/apache/spark/SecurityManager.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 09ec8932353a0..dbfd5a514c189 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -89,6 +89,7 @@ private[spark] class SecurityManager( setViewAclsGroups(sparkConf.get("spark.ui.view.acls.groups", "")); setModifyAclsGroups(sparkConf.get("spark.modify.acls.groups", "")); + private var secretKey: String = _ logInfo("SecurityManager: authentication " + (if (authOn) "enabled" else "disabled") + "; ui acls " + (if (aclsOn) "enabled" else "disabled") + "; users with view permissions: " + viewAcls.toString() + @@ -321,6 +322,12 @@ private[spark] class SecurityManager( val creds = UserGroupInformation.getCurrentUser().getCredentials() Option(creds.getSecretKey(SECRET_LOOKUP_KEY)) .map { bytes => new String(bytes, UTF_8) } + // Secret key may not be found in current UGI's credentials. + // This happens when UGI is refreshed in the driver side by UGI's loginFromKeytab but not + // copy secret key from original UGI to the new one. This exists in ThriftServer's Hive + // logic. So as a workaround, storing secret key in a local variable to make it visible + // in different context. + .orElse(Option(secretKey)) .orElse(Option(sparkConf.getenv(ENV_AUTH_SECRET))) .orElse(sparkConf.getOption(SPARK_AUTH_SECRET_CONF)) .getOrElse { @@ -364,8 +371,8 @@ private[spark] class SecurityManager( rnd.nextBytes(secretBytes) val creds = new Credentials() - val secretStr = HashCodes.fromBytes(secretBytes).toString() - creds.addSecretKey(SECRET_LOOKUP_KEY, secretStr.getBytes(UTF_8)) + secretKey = HashCodes.fromBytes(secretBytes).toString() + creds.addSecretKey(SECRET_LOOKUP_KEY, secretKey.getBytes(UTF_8)) UserGroupInformation.getCurrentUser().addCredentials(creds) } From d1eb8d3ddc877958512194cc8f5dd8119b41bed0 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 25 Apr 2018 23:24:05 -0700 Subject: [PATCH 0696/2461] [SPARK-24094][SS][MINOR] Change description strings of v2 streaming sources to reflect the change ## What changes were proposed in this pull request? This makes it easy to understand at runtime which version is running. Great for debugging production issues. ## How was this patch tested? Not necessary. Author: Tathagata Das Closes #21160 from tdas/SPARK-24094. --- .../org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala | 2 +- .../streaming/sources/RateStreamMicroBatchReader.scala | 2 +- .../apache/spark/sql/execution/streaming/sources/socket.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index 2ed49ba3f5495..cbe655f9bff1f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -169,7 +169,7 @@ private[kafka010] class KafkaMicroBatchReader( kafkaOffsetReader.close() } - override def toString(): String = s"Kafka[$kafkaOffsetReader]" + override def toString(): String = s"KafkaV2[$kafkaOffsetReader]" /** * Read initial partition offsets from the checkpoint, or decide the offsets and write them to diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index 6cf8520fc544f..f54291bea6678 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -177,7 +177,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: override def stop(): Unit = {} - override def toString: String = s"MicroBatchRateSource[rowsPerSecond=$rowsPerSecond, " + + override def toString: String = s"RateStreamV2[rowsPerSecond=$rowsPerSecond, " + s"rampUpTimeSeconds=$rampUpTimeSeconds, " + s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 5aae46b463398..90f4a5ba4234d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -214,7 +214,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR } } - override def toString: String = s"TextSocket[host: $host, port: $port]" + override def toString: String = s"TextSocketV2[host: $host, port: $port]" } class TextSocketSourceProvider extends DataSourceV2 From ce2f919f8df1b794ceaa23e1a59d5d541ed47bf5 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 26 Apr 2018 19:07:13 +0800 Subject: [PATCH 0697/2461] [SPARK-23799][SQL][FOLLOW-UP] FilterEstimation.evaluateInSet produces wrong stats for STRING ## What changes were proposed in this pull request? `colStat.min` AND `colStat.max` are empty for string type. Thus, `evaluateInSet` should not return zero when either `colStat.min` or `colStat.max`. ## How was this patch tested? Added a test case. Author: gatorsmile Closes #21147 from gatorsmile/cached. --- .../logical/statsEstimation/FilterEstimation.scala | 12 ++++++++---- .../statsEstimation/FilterEstimationSuite.scala | 12 ++++++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 263c9ba60d145..5a3eeefaedb18 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -392,13 +392,13 @@ case class FilterEstimation(plan: Filter) extends Logging { val dataType = attr.dataType var newNdv = ndv - if (ndv.toDouble == 0 || colStat.min.isEmpty || colStat.max.isEmpty) { - return Some(0.0) - } - // use [min, max] to filter the original hSet dataType match { case _: NumericType | BooleanType | DateType | TimestampType => + if (ndv.toDouble == 0 || colStat.min.isEmpty || colStat.max.isEmpty) { + return Some(0.0) + } + val statsInterval = ValueInterval(colStat.min, colStat.max, dataType).asInstanceOf[NumericValueInterval] val validQuerySet = hSet.filter { v => @@ -422,6 +422,10 @@ case class FilterEstimation(plan: Filter) extends Logging { // We assume the whole set since there is no min/max information for String/Binary type case StringType | BinaryType => + if (ndv.toDouble == 0) { + return Some(0.0) + } + newNdv = ndv.min(BigInt(hSet.size)) if (update) { val newStats = colStat.copy(distinctCount = Some(newNdv), nullCount = Some(0)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 16cb5d032cf57..47bfa62569583 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -368,6 +368,18 @@ class FilterEstimationSuite extends StatsEstimationTestBase { expectedRowCount = 0) } + test("evaluateInSet with string") { + validateEstimatedStats( + Filter(InSet(attrString, Set("A0")), + StatsTestPlan(Seq(attrString), 10, + AttributeMap(Seq(attrString -> + ColumnStat(distinctCount = Some(10), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)))))), + Seq(attrString -> ColumnStat(distinctCount = Some(1), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))), + expectedRowCount = 1) + } + test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), From 4f1e38649ebc7710850b7c40e6fb355775e7bb7f Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 26 Apr 2018 14:21:22 -0700 Subject: [PATCH 0698/2461] [SPARK-24057][PYTHON] put the real data type in the AssertionError message ## What changes were proposed in this pull request? Print out the data type in the AssertionError message to make it more meaningful. ## How was this patch tested? I manually tested the changed code on my local, but didn't add any test. Author: Huaxin Gao Closes #21159 from huaxingao/spark-24057. --- python/pyspark/sql/types.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 1f6534836d64a..3cd7a2ef115af 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -289,7 +289,8 @@ def __init__(self, elementType, containsNull=True): >>> ArrayType(StringType(), False) == ArrayType(StringType()) False """ - assert isinstance(elementType, DataType), "elementType should be DataType" + assert isinstance(elementType, DataType),\ + "elementType %s should be an instance of %s" % (elementType, DataType) self.elementType = elementType self.containsNull = containsNull @@ -343,8 +344,10 @@ def __init__(self, keyType, valueType, valueContainsNull=True): ... == MapType(StringType(), FloatType())) False """ - assert isinstance(keyType, DataType), "keyType should be DataType" - assert isinstance(valueType, DataType), "valueType should be DataType" + assert isinstance(keyType, DataType),\ + "keyType %s should be an instance of %s" % (keyType, DataType) + assert isinstance(valueType, DataType),\ + "valueType %s should be an instance of %s" % (valueType, DataType) self.keyType = keyType self.valueType = valueType self.valueContainsNull = valueContainsNull @@ -402,8 +405,9 @@ def __init__(self, name, dataType, nullable=True, metadata=None): ... == StructField("f2", StringType(), True)) False """ - assert isinstance(dataType, DataType), "dataType should be DataType" - assert isinstance(name, basestring), "field name should be string" + assert isinstance(dataType, DataType),\ + "dataType %s should be an instance of %s" % (dataType, DataType) + assert isinstance(name, basestring), "field name %s should be string" % (name) if not isinstance(name, str): name = name.encode('utf-8') self.name = name From f7435bec6a9348cfbbe26b13c230c08545d16067 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 26 Apr 2018 15:11:42 -0700 Subject: [PATCH 0699/2461] [SPARK-24044][PYTHON] Explicitly print out skipped tests from unittest module ## What changes were proposed in this pull request? This PR proposes to remove duplicated dependency checking logics and also print out skipped tests from unittests. For example, as below: ``` Skipped tests in pyspark.sql.tests with pypy: test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' ... Skipped tests in pyspark.sql.tests with python3: test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.' ... ``` Currently, it's not printed out in the console. I think we should better print out skipped tests in the console. ## How was this patch tested? Manually tested. Also, fortunately, Jenkins has good environment to test the skipped output. Author: hyukjinkwon Closes #21107 from HyukjinKwon/skipped-tests-print. --- python/pyspark/ml/tests.py | 16 +++-- python/pyspark/mllib/tests.py | 4 +- python/pyspark/sql/tests.py | 51 +++++++------ python/pyspark/streaming/tests.py | 4 +- python/pyspark/tests.py | 12 +--- python/run-tests.py | 115 +++++++++++++----------------- 6 files changed, 98 insertions(+), 104 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 2ec0be60e9fa9..093593132e56d 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -2136,17 +2136,23 @@ class ImageReaderTest2(PySparkTestCase): @classmethod def setUpClass(cls): super(ImageReaderTest2, cls).setUpClass() + cls.hive_available = True # Note that here we enable Hive's support. cls.spark = None try: cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() except py4j.protocol.Py4JError: cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") + cls.hive_available = False except TypeError: cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") - cls.spark = HiveContext._createForTesting(cls.sc) + cls.hive_available = False + if cls.hive_available: + cls.spark = HiveContext._createForTesting(cls.sc) + + def setUp(self): + if not self.hive_available: + self.skipTest("Hive is not available.") @classmethod def tearDownClass(cls): @@ -2662,6 +2668,6 @@ def testDefaultFitMultiple(self): if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() + unittest.main(verbosity=2) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 1037bab7f1088..14d788b0bef60 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -1767,9 +1767,9 @@ def test_pca(self): if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() + unittest.main(verbosity=2) if not _have_scipy: print("NOTE: SciPy tests were skipped as it does not seem to be installed") sc.stop() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 98fa1b54b0a17..6b28c557a803e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3096,23 +3096,28 @@ def setUpClass(cls): filename_pattern = ( "sql/core/target/scala-*/test-classes/org/apache/spark/sql/" "TestQueryExecutionListener.class") - if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)): - raise unittest.SkipTest( + cls.has_listener = bool(glob.glob(os.path.join(SPARK_HOME, filename_pattern))) + + if cls.has_listener: + # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration. + cls.spark = SparkSession.builder \ + .master("local[4]") \ + .appName(cls.__name__) \ + .config( + "spark.sql.queryExecutionListeners", + "org.apache.spark.sql.TestQueryExecutionListener") \ + .getOrCreate() + + def setUp(self): + if not self.has_listener: + raise self.skipTest( "'org.apache.spark.sql.TestQueryExecutionListener' is not " "available. Will skip the related tests.") - # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration. - cls.spark = SparkSession.builder \ - .master("local[4]") \ - .appName(cls.__name__) \ - .config( - "spark.sql.queryExecutionListeners", - "org.apache.spark.sql.TestQueryExecutionListener") \ - .getOrCreate() - @classmethod def tearDownClass(cls): - cls.spark.stop() + if hasattr(cls, "spark"): + cls.spark.stop() def tearDown(self): self.spark._jvm.OnSuccessCall.clear() @@ -3196,18 +3201,22 @@ class HiveContextSQLTests(ReusedPySparkTestCase): def setUpClass(cls): ReusedPySparkTestCase.setUpClass() cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + cls.hive_available = True try: cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() except py4j.protocol.Py4JError: - cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") + cls.hive_available = False except TypeError: - cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") + cls.hive_available = False os.unlink(cls.tempdir.name) - cls.spark = HiveContext._createForTesting(cls.sc) - cls.testData = [Row(key=i, value=str(i)) for i in range(100)] - cls.df = cls.sc.parallelize(cls.testData).toDF() + if cls.hive_available: + cls.spark = HiveContext._createForTesting(cls.sc) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + cls.df = cls.sc.parallelize(cls.testData).toDF() + + def setUp(self): + if not self.hive_available: + self.skipTest("Hive is not available.") @classmethod def tearDownClass(cls): @@ -5316,6 +5325,6 @@ def test_invalid_args(self): if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() + unittest.main(verbosity=2) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 103940923dd4d..d77f1baa1f344 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1590,11 +1590,11 @@ def search_kinesis_asl_assembly_jar(): sys.stderr.write("[Running %s]\n" % (testcase)) tests = unittest.TestLoader().loadTestsFromTestCase(testcase) if xmlrunner: - result = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=3).run(tests) + result = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2).run(tests) if not result.wasSuccessful(): failed = True else: - result = unittest.TextTestRunner(verbosity=3).run(tests) + result = unittest.TextTestRunner(verbosity=2).run(tests) if not result.wasSuccessful(): failed = True sys.exit(failed) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 9111dbbed5929..8392d7f29af53 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2353,15 +2353,7 @@ def test_statcounter_array(self): if __name__ == "__main__": from pyspark.tests import * - if not _have_scipy: - print("NOTE: Skipping SciPy tests as it does not seem to be installed") - if not _have_numpy: - print("NOTE: Skipping NumPy tests as it does not seem to be installed") if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() - if not _have_scipy: - print("NOTE: SciPy tests were skipped as it does not seem to be installed") - if not _have_numpy: - print("NOTE: NumPy tests were skipped as it does not seem to be installed") + unittest.main(verbosity=2) diff --git a/python/run-tests.py b/python/run-tests.py index 6b41b5ee22814..f408fc5082b3d 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -32,6 +32,7 @@ else: import queue as Queue from distutils.version import LooseVersion +from multiprocessing import Manager # Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module @@ -50,6 +51,7 @@ def print_red(text): print('\033[31m' + text + '\033[0m') +SKIPPED_TESTS = Manager().dict() LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log") FAILURE_REPORTING_LOCK = Lock() LOGGER = logging.getLogger() @@ -109,8 +111,34 @@ def run_individual_python_test(test_name, pyspark_python): # this code is invoked from a thread other than the main thread. os._exit(-1) else: - per_test_output.close() - LOGGER.info("Finished test(%s): %s (%is)", pyspark_python, test_name, duration) + skipped_counts = 0 + try: + per_test_output.seek(0) + # Here expects skipped test output from unittest when verbosity level is + # 2 (or --verbose option is enabled). + decoded_lines = map(lambda line: line.decode(), iter(per_test_output)) + skipped_tests = list(filter( + lambda line: re.search('test_.* \(pyspark\..*\) ... skipped ', line), + decoded_lines)) + skipped_counts = len(skipped_tests) + if skipped_counts > 0: + key = (pyspark_python, test_name) + SKIPPED_TESTS[key] = skipped_tests + per_test_output.close() + except: + import traceback + print_red("\nGot an exception while trying to store " + "skipped test output:\n%s" % traceback.format_exc()) + # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if + # this code is invoked from a thread other than the main thread. + os._exit(-1) + if skipped_counts != 0: + LOGGER.info( + "Finished test(%s): %s (%is) ... %s tests were skipped", pyspark_python, test_name, + duration, skipped_counts) + else: + LOGGER.info( + "Finished test(%s): %s (%is)", pyspark_python, test_name, duration) def get_default_python_executables(): @@ -152,65 +180,17 @@ def parse_opts(): return opts -def _check_dependencies(python_exec, modules_to_test): - if "COVERAGE_PROCESS_START" in os.environ: - # Make sure if coverage is installed. - try: - subprocess_check_output( - [python_exec, "-c", "import coverage"], - stderr=open(os.devnull, 'w')) - except: - print_red("Coverage is not installed in Python executable '%s' " - "but 'COVERAGE_PROCESS_START' environment variable is set, " - "exiting." % python_exec) - sys.exit(-1) - - # If we should test 'pyspark-sql', it checks if PyArrow and Pandas are installed and - # explicitly prints out. See SPARK-23300. - if pyspark_sql in modules_to_test: - # TODO(HyukjinKwon): Relocate and deduplicate these version specifications. - minimum_pyarrow_version = '0.8.0' - minimum_pandas_version = '0.19.2' - - try: - pyarrow_version = subprocess_check_output( - [python_exec, "-c", "import pyarrow; print(pyarrow.__version__)"], - universal_newlines=True, - stderr=open(os.devnull, 'w')).strip() - if LooseVersion(pyarrow_version) >= LooseVersion(minimum_pyarrow_version): - LOGGER.info("Will test PyArrow related features against Python executable " - "'%s' in '%s' module." % (python_exec, pyspark_sql.name)) - else: - LOGGER.warning( - "Will skip PyArrow related features against Python executable " - "'%s' in '%s' module. PyArrow >= %s is required; however, PyArrow " - "%s was found." % ( - python_exec, pyspark_sql.name, minimum_pyarrow_version, pyarrow_version)) - except: - LOGGER.warning( - "Will skip PyArrow related features against Python executable " - "'%s' in '%s' module. PyArrow >= %s is required; however, PyArrow " - "was not found." % (python_exec, pyspark_sql.name, minimum_pyarrow_version)) - - try: - pandas_version = subprocess_check_output( - [python_exec, "-c", "import pandas; print(pandas.__version__)"], - universal_newlines=True, - stderr=open(os.devnull, 'w')).strip() - if LooseVersion(pandas_version) >= LooseVersion(minimum_pandas_version): - LOGGER.info("Will test Pandas related features against Python executable " - "'%s' in '%s' module." % (python_exec, pyspark_sql.name)) - else: - LOGGER.warning( - "Will skip Pandas related features against Python executable " - "'%s' in '%s' module. Pandas >= %s is required; however, Pandas " - "%s was found." % ( - python_exec, pyspark_sql.name, minimum_pandas_version, pandas_version)) - except: - LOGGER.warning( - "Will skip Pandas related features against Python executable " - "'%s' in '%s' module. Pandas >= %s is required; however, Pandas " - "was not found." % (python_exec, pyspark_sql.name, minimum_pandas_version)) +def _check_coverage(python_exec): + # Make sure if coverage is installed. + try: + subprocess_check_output( + [python_exec, "-c", "import coverage"], + stderr=open(os.devnull, 'w')) + except: + print_red("Coverage is not installed in Python executable '%s' " + "but 'COVERAGE_PROCESS_START' environment variable is set, " + "exiting." % python_exec) + sys.exit(-1) def main(): @@ -237,9 +217,10 @@ def main(): task_queue = Queue.PriorityQueue() for python_exec in python_execs: - # Check if the python executable has proper dependencies installed to run tests - # for given modules properly. - _check_dependencies(python_exec, modules_to_test) + # Check if the python executable has coverage installed when 'COVERAGE_PROCESS_START' + # environmental variable is set. + if "COVERAGE_PROCESS_START" in os.environ: + _check_coverage(python_exec) python_implementation = subprocess_check_output( [python_exec, "-c", "import platform; print(platform.python_implementation())"], @@ -281,6 +262,12 @@ def process_queue(task_queue): total_duration = time.time() - start_time LOGGER.info("Tests passed in %i seconds", total_duration) + for key, lines in sorted(SKIPPED_TESTS.items()): + pyspark_python, test_name = key + LOGGER.info("\nSkipped tests in %s with %s:" % (test_name, pyspark_python)) + for line in lines: + LOGGER.info(" %s" % line.rstrip()) + if __name__ == "__main__": main() From 9ee9fcf5223efdf7543161b7bc99131111876b92 Mon Sep 17 00:00:00 2001 From: zhoukang Date: Thu, 26 Apr 2018 15:38:11 -0700 Subject: [PATCH 0700/2461] [SPARK-24083][YARN] Log stacktrace for uncaught exception ## What changes were proposed in this pull request? Log stacktrace for uncaught exception ## How was this patch tested? UT and manually test Author: zhoukang Closes #21151 from caneGuy/zhoukang/log-stacktrace. --- .../scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index d04989e138f83..650840045361c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -308,7 +308,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends logError("Uncaught exception: ", e) finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION, - "Uncaught exception: " + e) + "Uncaught exception: " + StringUtils.stringifyException(e)) } } From 8aa1d7b0ede5115297541d29eab4ce5f4fe905cb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 27 Apr 2018 11:00:41 +0800 Subject: [PATCH 0701/2461] [SPARK-23355][SQL] convertMetastore should not ignore table properties ## What changes were proposed in this pull request? Previously, SPARK-22158 fixed for `USING hive` syntax. This PR aims to fix for `STORED AS` syntax. Although the test case covers ORC part, the patch considers both `convertMetastoreOrc` and `convertMetastoreParquet`. ## How was this patch tested? Pass newly added test cases. Author: Dongjoon Hyun Closes #20522 from dongjoon-hyun/SPARK-22158-2. --- .../spark/sql/hive/HiveStrategies.scala | 17 +++- .../sql/hive/CompressionCodecSuite.scala | 7 +- .../sql/hive/execution/HiveDDLSuite.scala | 81 +++++++++++++++++++ 3 files changed, 97 insertions(+), 8 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 8df05cbb20361..a0c197b06ddab 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -186,15 +186,28 @@ case class RelationConversions( serde.contains("orc") && conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) } + // Return true for Apache ORC and Hive ORC-related configuration names. + // Note that Spark doesn't support configurations like `hive.merge.orcfile.stripe.level`. + private def isOrcProperty(key: String) = + key.startsWith("orc.") || key.contains(".orc.") + + private def isParquetProperty(key: String) = + key.startsWith("parquet.") || key.contains(".parquet.") + private def convert(relation: HiveTableRelation): LogicalRelation = { val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + + // Consider table and storage properties. For properties existing in both sides, storage + // properties will supersede table properties. if (serde.contains("parquet")) { - val options = relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> + val options = relation.tableMeta.properties.filterKeys(isParquetProperty) ++ + relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) sessionCatalog.metastoreCatalog .convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet") } else { - val options = relation.tableMeta.storage.properties + val options = relation.tableMeta.properties.filterKeys(isOrcProperty) ++ + relation.tableMeta.storage.properties if (conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "native") { sessionCatalog.metastoreCatalog.convertToLogicalRelation( relation, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala index d10a6f25c64fc..4550d350f6db2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -268,12 +268,7 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo compressionCodecs = compressCodecs, tableCompressionCodecs = compressCodecs) { case (tableCodec, sessionCodec, realCodec, tableSize) => - // For non-partitioned table and when convertMetastore is true, Expect session-level - // take effect, and in other cases expect table-level take effect - // TODO: It should always be table-level taking effect when the bug(SPARK-22926) - // is fixed - val expectCodec = - if (convertMetastore && !isPartitioned) sessionCodec else tableCodec.get + val expectCodec = tableCodec.get assert(expectCodec == realCodec) assert(checkTableSize( format, expectCodec, isPartitioned, convertMetastore, usingCTAS, tableSize)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index c85db78c732de..daac6af9b557f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.hive.HiveUtils.{CONVERT_METASTORE_ORC, CONVERT_METAS import org.apache.spark.sql.hive.orc.OrcFileOperator import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} +import org.apache.spark.sql.internal.SQLConf.ORC_IMPLEMENTATION import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -2144,6 +2145,86 @@ class HiveDDLSuite } } + private def getReader(path: String): org.apache.orc.Reader = { + val conf = spark.sessionState.newHadoopConf() + val files = org.apache.spark.sql.execution.datasources.orc.OrcUtils.listOrcFiles(path, conf) + assert(files.length == 1) + val file = files.head + val fs = file.getFileSystem(conf) + val readerOptions = org.apache.orc.OrcFile.readerOptions(conf).filesystem(fs) + org.apache.orc.OrcFile.createReader(file, readerOptions) + } + + test("SPARK-23355 convertMetastoreOrc should not ignore table properties - STORED AS") { + Seq("native", "hive").foreach { orcImpl => + withSQLConf(ORC_IMPLEMENTATION.key -> orcImpl, CONVERT_METASTORE_ORC.key -> "true") { + withTable("t") { + withTempPath { path => + sql( + s""" + |CREATE TABLE t(id int) STORED AS ORC + |TBLPROPERTIES ( + | orc.compress 'ZLIB', + | orc.compress.size '1001', + | orc.row.index.stride '2002', + | hive.exec.orc.default.block.size '3003', + | hive.exec.orc.compression.strategy 'COMPRESSION') + |LOCATION '${path.toURI}' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(DDLUtils.isHiveTable(table)) + assert(table.storage.serde.get.contains("orc")) + val properties = table.properties + assert(properties.get("orc.compress") == Some("ZLIB")) + assert(properties.get("orc.compress.size") == Some("1001")) + assert(properties.get("orc.row.index.stride") == Some("2002")) + assert(properties.get("hive.exec.orc.default.block.size") == Some("3003")) + assert(properties.get("hive.exec.orc.compression.strategy") == Some("COMPRESSION")) + assert(spark.table("t").collect().isEmpty) + + sql("INSERT INTO t SELECT 1") + checkAnswer(spark.table("t"), Row(1)) + val maybeFile = path.listFiles().find(_.getName.startsWith("part")) + + val reader = getReader(maybeFile.head.getCanonicalPath) + assert(reader.getCompressionKind.name === "ZLIB") + assert(reader.getCompressionSize == 1001) + assert(reader.getRowIndexStride == 2002) + } + } + } + } + } + + test("SPARK-23355 convertMetastoreParquet should not ignore table properties - STORED AS") { + withSQLConf(CONVERT_METASTORE_PARQUET.key -> "true") { + withTable("t") { + withTempPath { path => + sql( + s""" + |CREATE TABLE t(id int) STORED AS PARQUET + |TBLPROPERTIES ( + | parquet.compression 'GZIP' + |) + |LOCATION '${path.toURI}' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(DDLUtils.isHiveTable(table)) + assert(table.storage.serde.get.contains("parquet")) + val properties = table.properties + assert(properties.get("parquet.compression") == Some("GZIP")) + assert(spark.table("t").collect().isEmpty) + + sql("INSERT INTO t SELECT 1") + checkAnswer(spark.table("t"), Row(1)) + val maybeFile = path.listFiles().find(_.getName.startsWith("part")) + + assertCompression(maybeFile, "parquet", "GZIP") + } + } + } + } + test("load command for non local invalid path validation") { withTable("tbl") { sql("CREATE TABLE tbl(i INT, j STRING)") From 109935fc5d8b3d381bb1b09a4a570040a0a1846f Mon Sep 17 00:00:00 2001 From: eric-maynard Date: Fri, 27 Apr 2018 15:25:07 +0800 Subject: [PATCH 0702/2461] [SPARK-23830][YARN] added check to ensure main method is found ## What changes were proposed in this pull request? When a user specifies the wrong class -- or, in fact, a class instead of an object -- Spark throws an NPE which is not useful for debugging. This was reported in [SPARK-23830](https://issues.apache.org/jira/browse/SPARK-23830). This PR adds a check to ensure the main method was found and logs a useful error in the event that it's null. ## How was this patch tested? * Unit tests + Manual testing * The scope of the changes is very limited Author: eric-maynard Author: Eric Maynard Closes #21168 from eric-maynard/feature/SPARK-23830. --- .../spark/deploy/yarn/ApplicationMaster.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 650840045361c..595077e7e809f 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} -import java.lang.reflect.InvocationTargetException +import java.lang.reflect.{InvocationTargetException, Modifier} import java.net.{Socket, URI, URL} import java.security.PrivilegedExceptionAction import java.util.concurrent.{TimeoutException, TimeUnit} @@ -675,9 +675,14 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends val userThread = new Thread { override def run() { try { - mainMethod.invoke(null, userArgs.toArray) - finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) - logDebug("Done running users class") + if (!Modifier.isStatic(mainMethod.getModifiers)) { + logError(s"Could not find static main method in object ${args.userClass}") + finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_EXCEPTION_USER_CLASS) + } else { + mainMethod.invoke(null, userArgs.toArray) + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + logDebug("Done running user class") + } } catch { case e: InvocationTargetException => e.getCause match { From 2824f12b8bac5d86a82339d4dfb4d2625e978a15 Mon Sep 17 00:00:00 2001 From: Patrick McGloin Date: Fri, 27 Apr 2018 23:04:14 +0800 Subject: [PATCH 0703/2461] [SPARK-23565][SS] New error message for structured streaming sources assertion ## What changes were proposed in this pull request? A more informative message to tell you why a structured streaming query cannot continue if you have added more sources, than there are in the existing checkpoint offsets. ## How was this patch tested? I added a Unit Test. Author: Patrick McGloin Closes #20946 from patrickmcgloin/master. --- .../org/apache/spark/sql/execution/streaming/OffsetSeq.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 73945b39b8967..787174481ff08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -39,7 +39,9 @@ case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[OffsetSeqMet * cannot be serialized). */ def toStreamProgress(sources: Seq[BaseStreamingSource]): StreamProgress = { - assert(sources.size == offsets.size) + assert(sources.size == offsets.size, s"There are [${offsets.size}] sources in the " + + s"checkpoint offsets and now there are [${sources.size}] sources requested by the query. " + + s"Cannot continue.") new StreamProgress ++ sources.zip(offsets).collect { case (s, Some(o)) => (s, o) } } From 3fd297af6dc568357c97abf86760c570309d6597 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 27 Apr 2018 11:43:29 -0700 Subject: [PATCH 0704/2461] [SPARK-24085][SQL] Query returns UnsupportedOperationException when scalar subquery is present in partitioning expression ## What changes were proposed in this pull request? In this case, the partition pruning happens before the planning phase of scalar subquery expressions. For scalar subquery expressions, the planning occurs late in the cycle (after the physical planning) in "PlanSubqueries" just before execution. Currently we try to execute the scalar subquery expression as part of partition pruning and fail as it implements Unevaluable. The fix attempts to ignore the Subquery expressions from partition pruning computation. Another option can be to somehow plan the subqueries before the partition pruning. Since this may not be a commonly occuring expression, i am opting for a simpler fix. Repro ``` SQL CREATE TABLE test_prc_bug ( id_value string ) partitioned by (id_type string) location '/tmp/test_prc_bug' stored as parquet; insert into test_prc_bug values ('1','a'); insert into test_prc_bug values ('2','a'); insert into test_prc_bug values ('3','b'); insert into test_prc_bug values ('4','b'); select * from test_prc_bug where id_type = (select 'b'); ``` ## How was this patch tested? Added test in SubquerySuite and hive/SQLQuerySuite Author: Dilip Biswal Closes #21174 from dilipbiswal/spark-24085. --- .../datasources/FileSourceStrategy.scala | 5 ++- .../PruneFileSourcePartitions.scala | 4 ++- .../org/apache/spark/sql/SubquerySuite.scala | 15 +++++++++ .../sql/hive/execution/SQLQuerySuite.scala | 31 +++++++++++++++++++ 4 files changed, 53 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 16b22717b8d92..0a568d6b8adce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -76,7 +76,10 @@ object FileSourceStrategy extends Strategy with Logging { fsRelation.partitionSchema, fsRelation.sparkSession.sessionState.analyzer.resolver) val partitionSet = AttributeSet(partitionColumns) val partitionKeyFilters = - ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) + ExpressionSet(normalizedFilters + .filterNot(SubqueryExpression.hasSubquery(_)) + .filter(_.references.subsetOf(partitionSet))) + logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}") val dataColumns = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 3b830accb83f0..16b2367bfdd5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -55,7 +55,9 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { partitionSchema, sparkSession.sessionState.analyzer.resolver) val partitionSet = AttributeSet(partitionColumns) val partitionKeyFilters = - ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) + ExpressionSet(normalizedFilters + .filterNot(SubqueryExpression.hasSubquery(_)) + .filter(_.references.subsetOf(partitionSet))) if (partitionKeyFilters.nonEmpty) { val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 31e8b0e8dede0..acef62d81ee12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -955,4 +955,19 @@ class SubquerySuite extends QueryTest with SharedSQLContext { // before the fix this would throw AnalysisException spark.range(10).where("(id,id) in (select id, null from range(3))").count } + + test("SPARK-24085 scalar subquery in partitioning expression") { + withTable("parquet_part") { + Seq("1" -> "a", "2" -> "a", "3" -> "b", "4" -> "b") + .toDF("id_value", "id_type") + .write + .mode(SaveMode.Overwrite) + .partitionBy("id_type") + .format("parquet") + .saveAsTable("parquet_part") + checkAnswer( + sql("SELECT * FROM parquet_part WHERE id_type = (SELECT 'b')"), + Row("3", "b") :: Row("4", "b") :: Nil) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 73f83d593bbfb..704a410b6a37b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2156,4 +2156,35 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + test("SPARK-24085 scalar subquery in partitioning expression") { + Seq("orc", "parquet").foreach { format => + Seq(true, false).foreach { isConverted => + withSQLConf( + HiveUtils.CONVERT_METASTORE_ORC.key -> s"$isConverted", + HiveUtils.CONVERT_METASTORE_PARQUET.key -> s"$isConverted", + "hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTable(format) { + withTempPath { tempDir => + sql( + s""" + |CREATE TABLE ${format} (id_value string) + |PARTITIONED BY (id_type string) + |LOCATION '${tempDir.toURI}' + |STORED AS ${format} + """.stripMargin) + sql(s"insert into $format values ('1','a')") + sql(s"insert into $format values ('2','a')") + sql(s"insert into $format values ('3','b')") + sql(s"insert into $format values ('4','b')") + checkAnswer( + sql(s"SELECT * FROM $format WHERE id_type = (SELECT 'b')"), + Row("3", "b") :: Row("4", "b") :: Nil) + } + } + } + } + } + } + } From 8614edd445264007144caa6743a8c2ca2b5082e0 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 27 Apr 2018 14:14:28 -0700 Subject: [PATCH 0705/2461] [SPARK-24104] SQLAppStatusListener overwrites metrics onDriverAccumUpdates instead of updating them ## What changes were proposed in this pull request? Event `SparkListenerDriverAccumUpdates` may happen multiple times in a query - e.g. every `FileSourceScanExec` and `BroadcastExchangeExec` call `postDriverMetricUpdates`. In Spark 2.2 `SQLListener` updated the map with new values. `SQLAppStatusListener` overwrites it. Unless `update` preserved it in the KV store (dependant on `exec.lastWriteTime`), only the metrics from the last operator that does `postDriverMetricUpdates` are preserved. ## How was this patch tested? Unit test added. Author: Juliusz Sompolski Closes #21171 from juliuszsompolski/SPARK-24104. --- .../execution/ui/SQLAppStatusListener.scala | 2 +- .../ui/SQLAppStatusListenerSuite.scala | 24 +++++++++++++++---- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 2b6bb48467eb3..d254af400a7cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -289,7 +289,7 @@ class SQLAppStatusListener( private def onDriverAccumUpdates(event: SparkListenerDriverAccumUpdates): Unit = { val SparkListenerDriverAccumUpdates(executionId, accumUpdates) = event Option(liveExecutions.get(executionId)).foreach { exec => - exec.driverAccumUpdates = accumUpdates.toMap + exec.driverAccumUpdates = exec.driverAccumUpdates ++ accumUpdates update(exec) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index f3f08839c1d3a..02df45d1b7989 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -443,7 +443,8 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with val oldCount = statusStore.executionsList().size val expectedAccumValue = 12345 - val physicalPlan = MyPlan(sqlContext.sparkContext, expectedAccumValue) + val expectedAccumValue2 = 54321 + val physicalPlan = MyPlan(sqlContext.sparkContext, expectedAccumValue, expectedAccumValue2) val dummyQueryExecution = new QueryExecution(spark, LocalRelation()) { override lazy val sparkPlan = physicalPlan override lazy val executedPlan = physicalPlan @@ -466,10 +467,14 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with val execId = statusStore.executionsList().last.executionId val metrics = statusStore.executionMetrics(execId) val driverMetric = physicalPlan.metrics("dummy") + val driverMetric2 = physicalPlan.metrics("dummy2") val expectedValue = SQLMetrics.stringValue(driverMetric.metricType, Seq(expectedAccumValue)) + val expectedValue2 = SQLMetrics.stringValue(driverMetric2.metricType, Seq(expectedAccumValue2)) assert(metrics.contains(driverMetric.id)) assert(metrics(driverMetric.id) === expectedValue) + assert(metrics.contains(driverMetric2.id)) + assert(metrics(driverMetric2.id) === expectedValue2) } test("roundtripping SparkListenerDriverAccumUpdates through JsonProtocol (SPARK-18462)") { @@ -562,20 +567,31 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with * A dummy [[org.apache.spark.sql.execution.SparkPlan]] that updates a [[SQLMetrics]] * on the driver. */ -private case class MyPlan(sc: SparkContext, expectedValue: Long) extends LeafExecNode { +private case class MyPlan(sc: SparkContext, expectedValue: Long, expectedValue2: Long) + extends LeafExecNode { + override def sparkContext: SparkContext = sc override def output: Seq[Attribute] = Seq() override val metrics: Map[String, SQLMetric] = Map( - "dummy" -> SQLMetrics.createMetric(sc, "dummy")) + "dummy" -> SQLMetrics.createMetric(sc, "dummy"), + "dummy2" -> SQLMetrics.createMetric(sc, "dummy2")) override def doExecute(): RDD[InternalRow] = { longMetric("dummy") += expectedValue + longMetric("dummy2") += expectedValue2 + + // postDriverMetricUpdates may happen multiple time in a query. + // (normally from different operators, but for the sake of testing, from one operator) + SQLMetrics.postDriverMetricUpdates( + sc, + sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY), + Seq(metrics("dummy"))) SQLMetrics.postDriverMetricUpdates( sc, sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY), - metrics.values.toSeq) + Seq(metrics("dummy2"))) sc.emptyRDD } } From 1fb46f30f83e4751169ff288ad406f26b7c11f7e Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Sat, 28 Apr 2018 09:55:56 +0800 Subject: [PATCH 0706/2461] [SPARK-23688][SS] Refactor tests away from rate source ## What changes were proposed in this pull request? Replace rate source with memory source in continuous mode test suite. Keep using "rate" source if the tests intend to put data periodically in background, or need to put short source name to load, since "memory" doesn't have provider for source. ## How was this patch tested? Ran relevant test suite from IDE. Author: Jungtaek Lim Closes #21152 from HeartSaVioR/SPARK-23688. --- .../continuous/ContinuousSuite.scala | 163 +++++++----------- 1 file changed, 61 insertions(+), 102 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index c318b951ff992..5f222e7885994 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -75,73 +75,50 @@ class ContinuousSuite extends ContinuousSuiteBase { } test("map") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .map(r => r.getLong(0) * 2) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().map(_.getInt(0) * 2) - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 40, 2).map(Row(_)))) + testStream(df)( + AddData(input, 0, 1), + CheckAnswer(0, 2), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer(0, 2, 4, 6, 8)) } test("flatMap") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .flatMap(r => Seq(0, r.getLong(0), r.getLong(0) * 2)) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().flatMap(r => Seq(0, r.getInt(0), r.getInt(0) * 2)) - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 20).flatMap(n => Seq(0, n, n * 2)).map(Row(_)))) + testStream(df)( + AddData(input, 0, 1), + CheckAnswer((0 to 1).flatMap(n => Seq(0, n, n * 2)): _*), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer((0 to 4).flatMap(n => Seq(0, n, n * 2)): _*)) } test("filter") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .where('value > 5) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().where('value > 2) - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(6, 20).map(Row(_)))) + testStream(df)( + AddData(input, 0, 1), + CheckAnswer(), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer(3, 4)) } test("deduplicate") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .dropDuplicates() + val input = ContinuousMemoryStream[Int] + val df = input.toDF().dropDuplicates() val except = intercept[AnalysisException] { - testStream(df, useV2Sink = true)(StartStream(longContinuousTrigger)) + testStream(df)(StartStream()) } assert(except.message.contains( @@ -149,15 +126,11 @@ class ContinuousSuite extends ContinuousSuiteBase { } test("timestamp") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select(current_timestamp()) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().select(current_timestamp()) val except = intercept[AnalysisException] { - testStream(df, useV2Sink = true)(StartStream(longContinuousTrigger)) + testStream(df)(StartStream()) } assert(except.message.contains( @@ -165,58 +138,43 @@ class ContinuousSuite extends ContinuousSuiteBase { } test("subquery alias") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .createOrReplaceTempView("rate") - val test = spark.sql("select value from rate where value > 5") + val input = ContinuousMemoryStream[Int] + input.toDF().createOrReplaceTempView("memory") + val test = spark.sql("select value from memory where value > 2") - testStream(test, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(6, 20).map(Row(_)))) + testStream(test)( + AddData(input, 0, 1), + CheckAnswer(), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer(3, 4)) } test("repeatedly restart") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) + val input = ContinuousMemoryStream[Int] + val df = input.toDF() - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 10).map(Row(_))), + testStream(df)( + StartStream(), + AddData(input, 0, 1), + CheckAnswer(0, 1), StopStream, - StartStream(longContinuousTrigger), + StartStream(), StopStream, - StartStream(longContinuousTrigger), + StartStream(), StopStream, - StartStream(longContinuousTrigger), - AwaitEpoch(2), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 20).map(Row(_))), + StartStream(), + StopStream, + AddData(input, 2, 3), + StartStream(), + CheckAnswer(0, 1, 2, 3), StopStream) } test("task failure kills the query") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) + val input = ContinuousMemoryStream[Int] + val df = input.toDF() // Get an arbitrary task from this query to kill. It doesn't matter which one. var taskId: Long = -1 @@ -227,9 +185,9 @@ class ContinuousSuite extends ContinuousSuiteBase { } spark.sparkContext.addSparkListener(listener) try { - testStream(df, useV2Sink = true)( + testStream(df)( StartStream(Trigger.Continuous(100)), - Execute(waitForRateSourceTriggers(_, 2)), + AddData(input, 0, 1, 2, 3), Execute { _ => // Wait until a task is started, then kill its first attempt. eventually(timeout(streamingTimeout)) { @@ -252,6 +210,7 @@ class ContinuousSuite extends ContinuousSuiteBase { .option("rowsPerSecond", "2") .load() .select('value) + val query = df.writeStream .format("memory") .queryName("noharness") From ad94e8592b2e8f4c1bdbd958e110797c6658af84 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 28 Apr 2018 10:47:43 +0800 Subject: [PATCH 0707/2461] [SPARK-23736][SQL][FOLLOWUP] Error message should contains SQL types ## What changes were proposed in this pull request? In the error messages we should return the SQL types (like `string` rather than the internal types like `StringType`). ## How was this patch tested? added UT Author: Marco Gaido Closes #21181 from mgaido91/SPARK-23736_followup. --- .../sql/catalyst/expressions/collectionOperations.scala | 5 +++-- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 5 +++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 90223b9126555..6d63a531e3b74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -863,8 +863,9 @@ case class Concat(children: Seq[Expression]) extends Expression { val childTypes = children.map(_.dataType) if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) { return TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + - s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + s"input to function $prettyName should have been ${StringType.simpleString}," + + s" ${BinaryType.simpleString} or ${ArrayType.simpleString}, but it's " + + childTypes.map(_.simpleString).mkString("[", ", ", "]")) } TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index c216d1322a06c..470a1c8e331ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -712,6 +712,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { intercept[AnalysisException] { df.selectExpr("concat(i1, array(i1, i2))") } + + val e = intercept[AnalysisException] { + df.selectExpr("concat(map(1, 2), map(3, 4))") + } + assert(e.getMessage.contains("string, binary or array")) } test("flatten function") { From 4df51361a5ff1fba20524f1b580f4049b328ed32 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 28 Apr 2018 16:57:41 +0800 Subject: [PATCH 0708/2461] [SPARK-22732][SS][FOLLOW-UP] Fix MemorySinkV2 toString error ## What changes were proposed in this pull request? Fix `MemorySinkV2` toString() error ## How was this patch tested? N/A Author: Yuming Wang Closes #21170 from wangyum/SPARK-22732. --- .../spark/sql/execution/streaming/sources/memoryV2.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 5f58246083bb2..d871d37ad37c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -96,7 +96,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { case _ => throw new IllegalArgumentException( - s"Output mode $outputMode is not supported by MemorySink") + s"Output mode $outputMode is not supported by MemorySinkV2") } } else { logDebug(s"Skipping already committed batch: $batchId") @@ -107,7 +107,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { batches.clear() } - override def toString(): String = "MemorySink" + override def toString(): String = "MemorySinkV2" } case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} @@ -175,7 +175,7 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode) /** - * Used to query the data that has been written into a [[MemorySink]]. + * Used to query the data that has been written into a [[MemorySinkV2]]. */ case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode { private val sizePerRow = output.map(_.dataType.defaultSize).sum From bd14da6fd5a77cc03efff193a84ffccbe892cc13 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 29 Apr 2018 11:25:31 +0800 Subject: [PATCH 0709/2461] [SPARK-23094][SPARK-23723][SPARK-23724][SQL] Support custom encoding for json files ## What changes were proposed in this pull request? I propose new option for JSON datasource which allows to specify encoding (charset) of input and output files. Here is an example of using of the option: ``` spark.read.schema(schema) .option("multiline", "true") .option("encoding", "UTF-16LE") .json(fileName) ``` If the option is not specified, charset auto-detection mechanism is used by default. The option can be used for saving datasets to jsons. Currently Spark is able to save datasets into json files in `UTF-8` charset only. The changes allow to save data in any supported charset. Here is the approximate list of supported charsets by Oracle Java SE: https://docs.oracle.com/javase/8/docs/technotes/guides/intl/encoding.doc.html . An user can specify the charset of output jsons via the charset option like `.option("charset", "UTF-16BE")`. By default the output charset is still `UTF-8` to keep backward compatibility. The solution has the following restrictions for per-line mode (`multiline = false`): - If charset is different from UTF-8, the lineSep option must be specified. The option required because Hadoop LineReader cannot detect the line separator correctly. Here is the ticket for solving the issue: https://issues.apache.org/jira/browse/SPARK-23725 - Encoding with [BOM](https://en.wikipedia.org/wiki/Byte_order_mark) are not supported. For example, the `UTF-16` and `UTF-32` encodings are blacklisted. The problem can be solved by https://github.com/MaxGekk/spark-1/pull/2 ## How was this patch tested? I added the following tests: - reads an json file in `UTF-16LE` encoding with BOM in `multiline` mode - read json file by using charset auto detection (`UTF-32BE` with BOM) - read json file using of user's charset (`UTF-16LE`) - saving in `UTF-32BE` and read the result by standard library (not by Spark) - checking that default charset is `UTF-8` - handling wrong (unsupported) charset Author: Maxim Gekk Author: Maxim Gekk Closes #20937 from MaxGekk/json-encoding-line-sep. --- python/pyspark/sql/readwriter.py | 15 +- python/pyspark/sql/tests.py | 7 + .../sql/people_array_utf16le.json | Bin 0 -> 182 bytes .../catalyst/json/CreateJacksonParser.scala | 49 +++- .../spark/sql/catalyst/json/JSONOptions.scala | 39 ++- .../sql/catalyst/json/JacksonParser.scala | 10 +- .../apache/spark/sql/DataFrameReader.scala | 3 + .../apache/spark/sql/DataFrameWriter.scala | 8 +- .../datasources/json/JsonDataSource.scala | 60 +++-- .../datasources/json/JsonFileFormat.scala | 10 +- .../datasources/text/TextOptions.scala | 18 +- .../src/test/resources/test-data/utf16LE.json | Bin 0 -> 98 bytes .../resources/test-data/utf16WithBOM.json | Bin 0 -> 200 bytes .../resources/test-data/utf32BEWithBOM.json | Bin 0 -> 172 bytes .../datasources/json/JsonBenchmarks.scala | 179 +++++++++++++ .../datasources/json/JsonSuite.scala | 245 +++++++++++++++++- 16 files changed, 599 insertions(+), 44 deletions(-) create mode 100644 python/test_support/sql/people_array_utf16le.json create mode 100644 sql/core/src/test/resources/test-data/utf16LE.json create mode 100644 sql/core/src/test/resources/test-data/utf16WithBOM.json create mode 100644 sql/core/src/test/resources/test-data/utf32BEWithBOM.json create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index df176c579fc8b..6811fa6b3b156 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -176,7 +176,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None, + encoding=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -237,6 +238,10 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param allowUnquotedControlChars: allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not. + :param encoding: allows to forcibly set one of standard basic or extended encoding for + the JSON files. For example UTF-16BE, UTF-32LE. If None is set, + the encoding of input JSON will be detected automatically + when the multiLine option is set to ``true``. :param lineSep: defines the line separator that should be used for parsing. If None is set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. :param samplingRatio: defines fraction of input JSON objects used for schema inferring. @@ -259,7 +264,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, - samplingRatio=samplingRatio) + samplingRatio=samplingRatio, encoding=encoding) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -752,7 +757,7 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) @since(1.4) def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None, - lineSep=None): + lineSep=None, encoding=None): """Saves the content of the :class:`DataFrame` in JSON format (`JSON Lines text format or newline-delimited JSON `_) at the specified path. @@ -776,6 +781,8 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. + :param encoding: specifies encoding (charset) of saved json files. If None is set, + the default UTF-8 charset will be used. :param lineSep: defines the line separator that should be used for writing. If None is set, it uses the default value, ``\\n``. @@ -784,7 +791,7 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm self.mode(mode) self._set_opts( compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat, - lineSep=lineSep) + lineSep=lineSep, encoding=encoding) self._jwrite.json(path) @since(1.4) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6b28c557a803e..e0cd2aa41a2d0 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -685,6 +685,13 @@ def test_multiline_json(self): multiLine=True) self.assertEqual(people1.collect(), people_array.collect()) + def test_encoding_json(self): + people_array = self.spark.read\ + .json("python/test_support/sql/people_array_utf16le.json", + multiLine=True, encoding="UTF-16LE") + expected = [Row(age=30, name=u'Andy'), Row(age=19, name=u'Justin')] + self.assertEqual(people_array.collect(), expected) + def test_linesep_json(self): df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",") expected = [Row(_corrupt_record=None, name=u'Michael'), diff --git a/python/test_support/sql/people_array_utf16le.json b/python/test_support/sql/people_array_utf16le.json new file mode 100644 index 0000000000000000000000000000000000000000..9c657fa30ac9c651076ff8aa3676baa400b121fb GIT binary patch literal 182 zcma!M;9^h!!fGfDVk require(sep.nonEmpty, "'lineSep' cannot be an empty string.") sep } - // Note that the option 'lineSep' uses a different default value in read and write. - val lineSeparatorInRead: Option[Array[Byte]] = - lineSeparator.map(_.getBytes(StandardCharsets.UTF_8)) - // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8. + + /** + * Standard encoding (charset) name. For example UTF-8, UTF-16LE and UTF-32BE. + * If the encoding is not specified (None), it will be detected automatically + * when the multiLine option is set to `true`. + */ + val encoding: Option[String] = parameters.get("encoding") + .orElse(parameters.get("charset")).map { enc => + // The following encodings are not supported in per-line mode (multiline is false) + // because they cause some problems in reading files with BOM which is supposed to + // present in the files with such encodings. After splitting input files by lines, + // only the first lines will have the BOM which leads to impossibility for reading + // the rest lines. Besides of that, the lineSep option must have the BOM in such + // encodings which can never present between lines. + val blacklist = Seq(Charset.forName("UTF-16"), Charset.forName("UTF-32")) + val isBlacklisted = blacklist.contains(Charset.forName(enc)) + require(multiLine || !isBlacklisted, + s"""The ${enc} encoding must not be included in the blacklist when multiLine is disabled: + | ${blacklist.mkString(", ")}""".stripMargin) + + val isLineSepRequired = !(multiLine == false && + Charset.forName(enc) != StandardCharsets.UTF_8 && lineSeparator.isEmpty) + require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding") + + enc + } + + val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => + lineSep.getBytes(encoding.getOrElse("UTF-8")) + } val lineSeparatorInWrite: String = lineSeparator.getOrElse("\n") /** Sets config options on a Jackson [[JsonFactory]]. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 7f6956994f31f..a5a4a13eb608b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.json -import java.io.ByteArrayOutputStream +import java.io.{ByteArrayOutputStream, CharConversionException} import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -361,6 +361,14 @@ class JacksonParser( // For such records, all fields other than the field configured by // `columnNameOfCorruptRecord` are set to `null`. throw BadRecordException(() => recordLiteral(record), () => None, e) + case e: CharConversionException if options.encoding.isEmpty => + val msg = + """JSON parser cannot handle a character in its input. + |Specifying encoding as an input option explicitly might help to resolve the issue. + |""".stripMargin + e.getMessage + val wrappedCharException = new CharConversionException(msg) + wrappedCharException.initCause(e) + throw BadRecordException(() => recordLiteral(record), () => None, wrappedCharException) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b44552f0eb17b..6b2ea6c06d3ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -372,6 +372,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * `java.text.SimpleDateFormat`. This applies to timestamp type. *
  • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
  • + *
  • `encoding` (by default it is not set): allows to forcibly set one of standard basic + * or extended encoding for the JSON files. For example UTF-16BE, UTF-32LE. If the encoding + * is not specified and `multiLine` is set to `true`, it will be detected automatically.
  • *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator * that should be used for parsing.
  • *
  • `samplingRatio` (default is 1.0): defines fraction of input JSON objects used diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index bbc063148a72c..e183fa6f9542b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -518,8 +518,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • - *
  • `lineSep` (default `\n`): defines the line separator that should - * be used for writing.
  • + *
  • `encoding` (by default it is not set): specifies encoding (charset) of saved json + * files. If it is not set, the UTF-8 charset will be used.
  • + *
  • `lineSep` (default `\n`): defines the line separator that should be used for writing.
  • * * * @since 1.4.0 @@ -589,8 +590,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `compression` (default `null`): compression codec to use when saving to file. This can be * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`).
  • - *
  • `lineSep` (default `\n`): defines the line separator that should - * be used for writing.
  • + *
  • `lineSep` (default `\n`): defines the line separator that should be used for writing.
  • * * * @since 1.6.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 5769c09c9a1d9..983a5f0dcade2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -31,11 +31,11 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.spark.TaskContext import org.apache.spark.input.{PortableDataStream, StreamInputFormat} import org.apache.spark.rdd.{BinaryFileRDD, RDD} -import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} +import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.text.{TextFileFormat, TextOptions} +import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -92,26 +92,30 @@ object TextInputJsonDataSource extends JsonDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): StructType = { - val json: Dataset[String] = createBaseDataset( - sparkSession, inputPaths, parsedOptions.lineSeparator) + val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions) + inferFromDataset(json, parsedOptions) } def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = { val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions) - val rdd: RDD[UTF8String] = sampled.queryExecution.toRdd.map(_.getUTF8String(0)) - JsonInferSchema.infer(rdd, parsedOptions, CreateJacksonParser.utf8String) + val rdd: RDD[InternalRow] = sampled.queryExecution.toRdd + val rowParser = parsedOptions.encoding.map { enc => + CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow) + }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow)) + + JsonInferSchema.infer(rdd, parsedOptions, rowParser) } private def createBaseDataset( sparkSession: SparkSession, inputPaths: Seq[FileStatus], - lineSeparator: Option[String]): Dataset[String] = { - val textOptions = lineSeparator.map { lineSep => - Map(TextOptions.LINE_SEPARATOR -> lineSep) - }.getOrElse(Map.empty[String, String]) - + parsedOptions: JSONOptions): Dataset[String] = { val paths = inputPaths.map(_.getPath.toString) + val textOptions = Map.empty[String, String] ++ + parsedOptions.encoding.map("encoding" -> _) ++ + parsedOptions.lineSeparator.map("lineSep" -> _) + sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, @@ -129,8 +133,12 @@ object TextInputJsonDataSource extends JsonDataSource { schema: StructType): Iterator[InternalRow] = { val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + val textParser = parser.options.encoding + .map(enc => CreateJacksonParser.text(enc, _: JsonFactory, _: Text)) + .getOrElse(CreateJacksonParser.text(_: JsonFactory, _: Text)) + val safeParser = new FailureSafeParser[Text]( - input => parser.parse(input, CreateJacksonParser.text, textToUTF8String), + input => parser.parse(input, textParser, textToUTF8String), parser.options.parseMode, schema, parser.options.columnNameOfCorruptRecord) @@ -153,7 +161,11 @@ object MultiLineJsonDataSource extends JsonDataSource { parsedOptions: JSONOptions): StructType = { val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths) val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions) - JsonInferSchema.infer(sampled, parsedOptions, createParser) + val parser = parsedOptions.encoding + .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) + .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) + + JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) } private def createBaseRdd( @@ -175,11 +187,18 @@ object MultiLineJsonDataSource extends JsonDataSource { .values } - private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { - val path = new Path(record.getPath()) - CreateJacksonParser.inputStream( - jsonFactory, - CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, path)) + private def dataToInputStream(dataStream: PortableDataStream): InputStream = { + val path = new Path(dataStream.getPath()) + CodecStreams.createInputStreamWithCloseResource(dataStream.getConfiguration, path) + } + + private def createParser(jsonFactory: JsonFactory, stream: PortableDataStream): JsonParser = { + CreateJacksonParser.inputStream(jsonFactory, dataToInputStream(stream)) + } + + private def createParser(enc: String, jsonFactory: JsonFactory, + stream: PortableDataStream): JsonParser = { + CreateJacksonParser.inputStream(enc, jsonFactory, dataToInputStream(stream)) } override def readFile( @@ -194,9 +213,12 @@ object MultiLineJsonDataSource extends JsonDataSource { UTF8String.fromBytes(ByteStreams.toByteArray(inputStream)) } } + val streamParser = parser.options.encoding + .map(enc => CreateJacksonParser.inputStream(enc, _: JsonFactory, _: InputStream)) + .getOrElse(CreateJacksonParser.inputStream(_: JsonFactory, _: InputStream)) val safeParser = new FailureSafeParser[InputStream]( - input => parser.parse(input, CreateJacksonParser.inputStream, partitionedFileString), + input => parser.parse[InputStream](input, streamParser, partitionedFileString), parser.options.parseMode, schema, parser.options.columnNameOfCorruptRecord) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 0862c746fffad..3b04510d29695 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.json +import java.nio.charset.{Charset, StandardCharsets} + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} @@ -151,7 +153,13 @@ private[json] class JsonOutputWriter( context: TaskAttemptContext) extends OutputWriter with Logging { - private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path)) + private val encoding = options.encoding match { + case Some(charsetName) => Charset.forName(charsetName) + case None => StandardCharsets.UTF_8 + } + + private val writer = CodecStreams.createOutputStreamWriter( + context, new Path(path), encoding) // create the Generator without separator inserted between 2 records private[this] val gen = new JacksonGenerator(dataSchema, writer, options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala index 5c1a35434f7b5..e4e201995faa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.text -import java.nio.charset.StandardCharsets +import java.nio.charset.{Charset, StandardCharsets} import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs} @@ -41,13 +41,18 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti */ val wholeText = parameters.getOrElse(WHOLETEXT, "false").toBoolean - private val lineSeparator: Option[String] = parameters.get(LINE_SEPARATOR).map { sep => - require(sep.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.") - sep + val encoding: Option[String] = parameters.get(ENCODING) + + val lineSeparator: Option[String] = parameters.get(LINE_SEPARATOR).map { lineSep => + require(lineSep.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.") + + lineSep } + // Note that the option 'lineSep' uses a different default value in read and write. - val lineSeparatorInRead: Option[Array[Byte]] = - lineSeparator.map(_.getBytes(StandardCharsets.UTF_8)) + val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => + lineSep.getBytes(encoding.map(Charset.forName(_)).getOrElse(StandardCharsets.UTF_8)) + } val lineSeparatorInWrite: Array[Byte] = lineSeparatorInRead.getOrElse("\n".getBytes(StandardCharsets.UTF_8)) } @@ -55,5 +60,6 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti private[datasources] object TextOptions { val COMPRESSION = "compression" val WHOLETEXT = "wholetext" + val ENCODING = "encoding" val LINE_SEPARATOR = "lineSep" } diff --git a/sql/core/src/test/resources/test-data/utf16LE.json b/sql/core/src/test/resources/test-data/utf16LE.json new file mode 100644 index 0000000000000000000000000000000000000000..ce4117fd299dfcbc7089e7c0530098bfcaf5a27e GIT binary patch literal 98 zcmbi20w;GhFpeJpqLd9J2PYe#WR62N(?$cl`!==KvkHk Roq(bsb5ek+xfp7J7y!-s4k`cu literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/test-data/utf16WithBOM.json b/sql/core/src/test/resources/test-data/utf16WithBOM.json new file mode 100644 index 0000000000000000000000000000000000000000..cf4d29328b860ffe8288edea437222c6d432a100 GIT binary patch literal 200 zcmezWFPedufr~)_3ac5E7}6Lr8HyN+8A=%Z7!nzB8B&2_RzU2`kO36W1j;Be=m6C# tG2{T{G1WN%ML{N{09DiiRT68y3qw9bDMLB|(}RGj@}XvfOpXPc4*2#AY;xCDs(fH)C|bAdP&h(YSCptLiP&H!SNdXPSl f9+12a5Gz30IY1hupBVF;plV@mNP(JB3#7RK --jars + */ +object JSONBenchmarks { + val conf = new SparkConf() + + val spark = SparkSession.builder + .master("local[1]") + .appName("benchmark-json-datasource") + .config(conf) + .getOrCreate() + import spark.implicits._ + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + + def schemaInferring(rowsNum: Int): Unit = { + val benchmark = new Benchmark("JSON schema inferring", rowsNum) + + withTempPath { path => + // scalastyle:off println + benchmark.out.println("Preparing data for benchmarking ...") + // scalastyle:on println + + spark.sparkContext.range(0, rowsNum, 1) + .map(_ => "a") + .toDF("fieldA") + .write + .option("encoding", "UTF-8") + .json(path.getAbsolutePath) + + benchmark.addCase("No encoding", 3) { _ => + spark.read.json(path.getAbsolutePath) + } + + benchmark.addCase("UTF-8 is set", 3) { _ => + spark.read + .option("encoding", "UTF-8") + .json(path.getAbsolutePath) + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + JSON schema inferring: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + No encoding 38902 / 39282 2.6 389.0 1.0X + UTF-8 is set 56959 / 57261 1.8 569.6 0.7X + */ + benchmark.run() + } + } + + def perlineParsing(rowsNum: Int): Unit = { + val benchmark = new Benchmark("JSON per-line parsing", rowsNum) + + withTempPath { path => + // scalastyle:off println + benchmark.out.println("Preparing data for benchmarking ...") + // scalastyle:on println + + spark.sparkContext.range(0, rowsNum, 1) + .map(_ => "a") + .toDF("fieldA") + .write.json(path.getAbsolutePath) + val schema = new StructType().add("fieldA", StringType) + + benchmark.addCase("No encoding", 3) { _ => + spark.read + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + benchmark.addCase("UTF-8 is set", 3) { _ => + spark.read + .option("encoding", "UTF-8") + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + JSON per-line parsing: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + No encoding 25947 / 26188 3.9 259.5 1.0X + UTF-8 is set 46319 / 46417 2.2 463.2 0.6X + */ + benchmark.run() + } + } + + def perlineParsingOfWideColumn(rowsNum: Int): Unit = { + val benchmark = new Benchmark("JSON parsing of wide lines", rowsNum) + + withTempPath { path => + // scalastyle:off println + benchmark.out.println("Preparing data for benchmarking ...") + // scalastyle:on println + + spark.sparkContext.range(0, rowsNum, 1) + .map { i => + val s = "abcdef0123456789ABCDEF" * 20 + s"""{"a":"$s","b": $i,"c":"$s","d":$i,"e":"$s","f":$i,"x":"$s","y":$i,"z":"$s"}""" + } + .toDF().write.text(path.getAbsolutePath) + val schema = new StructType() + .add("a", StringType).add("b", LongType) + .add("c", StringType).add("d", LongType) + .add("e", StringType).add("f", LongType) + .add("x", StringType).add("y", LongType) + .add("z", StringType) + + benchmark.addCase("No encoding", 3) { _ => + spark.read + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + benchmark.addCase("UTF-8 is set", 3) { _ => + spark.read + .option("encoding", "UTF-8") + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + JSON parsing of wide lines: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + No encoding 45543 / 45660 0.2 4554.3 1.0X + UTF-8 is set 65737 / 65957 0.2 6573.7 0.7X + */ + benchmark.run() + } + } + + def main(args: Array[String]): Unit = { + schemaInferring(100 * 1000 * 1000) + perlineParsing(100 * 1000 * 1000) + perlineParsingOfWideColumn(10 * 1000 * 1000) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index a58dff827b92d..0db688fec9a67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.execution.datasources.json -import java.io.{File, StringWriter} -import java.nio.charset.StandardCharsets +import java.io.{File, FileOutputStream, StringWriter} +import java.nio.charset.{StandardCharsets, UnsupportedCharsetException} import java.nio.file.{Files, Paths, StandardOpenOption} import java.sql.{Date, Timestamp} import java.util.Locale @@ -48,6 +48,10 @@ class TestFileFilter extends PathFilter { class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ + def testFile(fileName: String): String = { + Thread.currentThread().getContextClassLoader.getResource(fileName).toString + } + test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) { assert(expected.getClass == actual.getClass, @@ -2167,4 +2171,241 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val sampled = spark.read.option("samplingRatio", 1.0).json(ds) assert(sampled.count() == ds.count()) } + + test("SPARK-23723: json in UTF-16 with BOM") { + val fileName = "test-data/utf16WithBOM.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .option("encoding", "UTF-16") + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"), Row("Doug", "Rood"))) + } + + test("SPARK-23723: multi-line json in UTF-32BE with BOM") { + val fileName = "test-data/utf32BEWithBOM.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + test("SPARK-23723: Use user's encoding in reading of multi-line json in UTF-16LE") { + val fileName = "test-data/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .options(Map("encoding" -> "UTF-16LE")) + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + test("SPARK-23723: Unsupported encoding name") { + val invalidCharset = "UTF-128" + val exception = intercept[UnsupportedCharsetException] { + spark.read + .options(Map("encoding" -> invalidCharset, "lineSep" -> "\n")) + .json(testFile("test-data/utf16LE.json")) + .count() + } + + assert(exception.getMessage.contains(invalidCharset)) + } + + test("SPARK-23723: checking that the encoding option is case agnostic") { + val fileName = "test-data/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .options(Map("encoding" -> "uTf-16lE")) + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + + test("SPARK-23723: specified encoding is not matched to actual encoding") { + val fileName = "test-data/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val exception = intercept[SparkException] { + spark.read.schema(schema) + .option("mode", "FAILFAST") + .option("multiline", "true") + .options(Map("encoding" -> "UTF-16BE")) + .json(testFile(fileName)) + .count() + } + val errMsg = exception.getMessage + + assert(errMsg.contains("Malformed records are detected in record parsing")) + } + + def checkEncoding(expectedEncoding: String, pathToJsonFiles: String, + expectedContent: String): Unit = { + val jsonFiles = new File(pathToJsonFiles) + .listFiles() + .filter(_.isFile) + .filter(_.getName.endsWith("json")) + val actualContent = jsonFiles.map { file => + new String(Files.readAllBytes(file.toPath), expectedEncoding) + }.mkString.trim + + assert(actualContent == expectedContent) + } + + test("SPARK-23723: save json in UTF-32BE") { + val encoding = "UTF-32BE" + withTempPath { path => + val df = spark.createDataset(Seq(("Dog", 42))) + df.write + .options(Map("encoding" -> encoding, "lineSep" -> "\n")) + .json(path.getCanonicalPath) + + checkEncoding( + expectedEncoding = encoding, + pathToJsonFiles = path.getCanonicalPath, + expectedContent = """{"_1":"Dog","_2":42}""") + } + } + + test("SPARK-23723: save json in default encoding - UTF-8") { + withTempPath { path => + val df = spark.createDataset(Seq(("Dog", 42))) + df.write.json(path.getCanonicalPath) + + checkEncoding( + expectedEncoding = "UTF-8", + pathToJsonFiles = path.getCanonicalPath, + expectedContent = """{"_1":"Dog","_2":42}""") + } + } + + test("SPARK-23723: wrong output encoding") { + val encoding = "UTF-128" + val exception = intercept[UnsupportedCharsetException] { + withTempPath { path => + val df = spark.createDataset(Seq((0))) + df.write + .options(Map("encoding" -> encoding, "lineSep" -> "\n")) + .json(path.getCanonicalPath) + } + } + + assert(exception.getMessage == encoding) + } + + test("SPARK-23723: read back json in UTF-16LE") { + val options = Map("encoding" -> "UTF-16LE", "lineSep" -> "\n") + withTempPath { path => + val ds = spark.createDataset(Seq(("a", 1), ("b", 2), ("c", 3))).repartition(2) + ds.write.options(options).json(path.getCanonicalPath) + + val readBack = spark + .read + .options(options) + .json(path.getCanonicalPath) + + checkAnswer(readBack.toDF(), ds.toDF()) + } + } + + def checkReadJson(lineSep: String, encoding: String, inferSchema: Boolean, id: Int): Unit = { + test(s"SPARK-23724: checks reading json in ${encoding} #${id}") { + val schema = new StructType().add("f1", StringType).add("f2", IntegerType) + withTempPath { path => + val records = List(("a", 1), ("b", 2)) + val data = records + .map(rec => s"""{"f1":"${rec._1}", "f2":${rec._2}}""".getBytes(encoding)) + .reduce((a1, a2) => a1 ++ lineSep.getBytes(encoding) ++ a2) + val os = new FileOutputStream(path) + os.write(data) + os.close() + val reader = if (inferSchema) { + spark.read + } else { + spark.read.schema(schema) + } + val readBack = reader + .option("encoding", encoding) + .option("lineSep", lineSep) + .json(path.getCanonicalPath) + checkAnswer(readBack, records.map(rec => Row(rec._1, rec._2))) + } + } + } + + // scalastyle:off nonascii + List( + (0, "|", "UTF-8", false), + (1, "^", "UTF-16BE", true), + (2, "::", "ISO-8859-1", true), + (3, "!!!@3", "UTF-32LE", false), + (4, 0x1E.toChar.toString, "UTF-8", true), + (5, "아", "UTF-32BE", false), + (6, "куку", "CP1251", true), + (7, "sep", "utf-8", false), + (8, "\r\n", "UTF-16LE", false), + (9, "\r\n", "utf-16be", true), + (10, "\u000d\u000a", "UTF-32BE", false), + (11, "\u000a\u000d", "UTF-8", true), + (12, "===", "US-ASCII", false), + (13, "$^+", "utf-32le", true) + ).foreach { + case (testNum, sep, encoding, inferSchema) => checkReadJson(sep, encoding, inferSchema, testNum) + } + // scalastyle:on nonascii + + test("SPARK-23724: lineSep should be set if encoding if different from UTF-8") { + val encoding = "UTF-16LE" + val exception = intercept[IllegalArgumentException] { + spark.read + .options(Map("encoding" -> encoding)) + .json(testFile("test-data/utf16LE.json")) + .count() + } + + assert(exception.getMessage.contains( + s"""The lineSep option must be specified for the $encoding encoding""")) + } + + private val badJson = "\u0000\u0000\u0000A\u0001AAA" + + test("SPARK-23094: permissively read JSON file with leading nulls when multiLine is enabled") { + withTempPath { tempDir => + val path = tempDir.getAbsolutePath + Seq(badJson + """{"a":1}""").toDS().write.text(path) + val expected = s"""${badJson}{"a":1}\n""" + val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) + val df = spark.read.format("json") + .option("mode", "PERMISSIVE") + .option("multiLine", true) + .option("encoding", "UTF-8") + .schema(schema).load(path) + checkAnswer(df, Row(null, expected)) + } + } + + test("SPARK-23094: permissively read JSON file with leading nulls when multiLine is disabled") { + withTempPath { tempDir => + val path = tempDir.getAbsolutePath + Seq(badJson, """{"a":1}""").toDS().write.text(path) + val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) + val df = spark.read.format("json") + .option("mode", "PERMISSIVE") + .option("multiLine", false) + .option("encoding", "UTF-8") + .schema(schema).load(path) + checkAnswer(df, Seq(Row(1, null), Row(null, badJson))) + } + } + + test("SPARK-23094: permissively parse a dataset contains JSON with leading nulls") { + checkAnswer( + spark.read.option("mode", "PERMISSIVE").option("encoding", "UTF-8").json(Seq(badJson).toDS()), + Row(badJson)) + } } From 56f501e1c0cec3be7d13008bd2c0182ec83ed2a2 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 30 Apr 2018 09:40:46 +0800 Subject: [PATCH 0710/2461] [MINOR][DOCS] Fix a broken link for Arrow's supported types in the programming guide ## What changes were proposed in this pull request? This PR fixes a broken link for Arrow's supported types in the programming guide. ## How was this patch tested? Manually tested via `SKIP_API=1 jekyll watch`. "Supported SQL Types" here in https://spark.apache.org/docs/latest/sql-programming-guide.html#enabling-for-conversion-tofrom-pandas is broken. It should be https://spark.apache.org/docs/latest/sql-programming-guide.html#supported-sql-types Author: hyukjinkwon Closes #21191 from HyukjinKwon/minor-arrow-link. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index e8ff1470970f7..836ce990205a9 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1703,7 +1703,7 @@ Using the above optimizations with Arrow will produce the same results as when A enabled. Note that even with Arrow, `toPandas()` results in the collection of all records in the DataFrame to the driver program and should be done on a small subset of the data. Not all Spark data types are currently supported and an error can be raised if a column has an unsupported type, -see [Supported SQL Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`, +see [Supported SQL Types](#supported-sql-types). If an error occurs during `createDataFrame()`, Spark will fall back to create the DataFrame without Arrow. ## Pandas UDFs (a.k.a. Vectorized UDFs) From 3121b411f748859ed3ed1c97cbc21e6ae980a35c Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 30 Apr 2018 09:45:22 +0800 Subject: [PATCH 0711/2461] [SPARK-23846][SQL] The samplingRatio option for CSV datasource ## What changes were proposed in this pull request? I propose to support the `samplingRatio` option for schema inferring of CSV datasource similar to the same option of JSON datasource: https://github.com/apache/spark/blob/b14993e1fcb68e1c946a671c6048605ab4afdf58/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala#L49-L50 ## How was this patch tested? Added 2 tests for json and 2 tests for csv datasources. The tests checks that only subset of input dataset is used for schema inferring. Author: Maxim Gekk Author: Maxim Gekk Closes #20959 from MaxGekk/csv-sampling. --- python/pyspark/sql/readwriter.py | 7 ++- python/pyspark/sql/tests.py | 7 +++ .../apache/spark/sql/DataFrameReader.scala | 1 + .../datasources/csv/CSVDataSource.scala | 6 ++- .../datasources/csv/CSVOptions.scala | 3 ++ .../execution/datasources/csv/CSVUtils.scala | 28 +++++++++++ .../execution/datasources/csv/CSVSuite.scala | 47 ++++++++++++++++++- .../datasources/csv/TestCsvData.scala | 36 ++++++++++++++ 8 files changed, 129 insertions(+), 6 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 6811fa6b3b156..9899eb5058b82 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -345,7 +345,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, - columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None): + columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, + samplingRatio=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -428,6 +429,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non the quote character. If None is set, the default value is escape character when escape and quote characters are different, ``\0`` otherwise. + :param samplingRatio: defines fraction of rows used for schema inferring. + If None is set, it uses the default value, ``1.0``. >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes @@ -446,7 +449,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, - charToEscapeQuoteEscaping=charToEscapeQuoteEscaping) + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio) if isinstance(path, basestring): path = [path] if type(path) == list: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e0cd2aa41a2d0..bc3eaf16b4de7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3033,6 +3033,13 @@ def test_json_sampling_ratio(self): .json(rdd).schema self.assertEquals(schema, StructType([StructField("a", LongType(), True)])) + def test_csv_sampling_ratio(self): + rdd = self.spark.sparkContext.range(0, 100, 1, 1) \ + .map(lambda x: '0.1' if x == 1 else str(x)) + schema = self.spark.read.option('inferSchema', True)\ + .csv(rdd, samplingRatio=0.5).schema + self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)])) + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 6b2ea6c06d3ae..53f44888ebaff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -539,6 +539,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `header` (default `false`): uses the first line as names of columns.
  • *
  • `inferSchema` (default `false`): infers the input schema automatically from data. It * requires one extra pass over the data.
  • + *
  • `samplingRatio` (default is 1.0): defines fraction of rows used for schema inferring.
  • *
  • `ignoreLeadingWhiteSpace` (default `false`): a flag indicating whether or not leading * whitespaces from values being read should be skipped.
  • *
  • `ignoreTrailingWhiteSpace` (default `false`): a flag indicating whether or not trailing diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 4870d75fc5f08..bc1f4ab3bb053 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -161,7 +161,8 @@ object TextInputCSVDataSource extends CSVDataSource { val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) - val tokenRDD = csv.rdd.mapPartitions { iter => + val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions) + val tokenRDD = sampled.rdd.mapPartitions { iter => val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) val linesWithoutHeader = CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) @@ -235,7 +236,8 @@ object MultiLineCSVDataSource extends CSVDataSource { parsedOptions.headerFlag, new CsvParser(parsedOptions.asParserSettings)) } - CSVInferSchema.infer(tokenRDD, header, parsedOptions) + val sampled = CSVUtils.sample(tokenRDD, parsedOptions) + CSVInferSchema.infer(sampled, header, parsedOptions) case None => // If the first row could not be read, just return the empty schema. StructType(Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index c16790630ce17..2ec0fc605a84b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -150,6 +150,9 @@ class CSVOptions( val isCommentSet = this.comment != '\u0000' + val samplingRatio = + parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + def asWriterSettings: CsvWriterSettings = { val writerSettings = new CsvWriterSettings() val format = writerSettings.getFormat diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 72b053d2092ca..31464f1bcc68e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.execution.datasources.csv +import org.apache.spark.input.PortableDataStream +import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset +import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -131,4 +134,29 @@ object CSVUtils { schema.foreach(field => verifyType(field.dataType)) } + /** + * Sample CSV dataset as configured by `samplingRatio`. + */ + def sample(csv: Dataset[String], options: CSVOptions): Dataset[String] = { + require(options.samplingRatio > 0, + s"samplingRatio (${options.samplingRatio}) should be greater than 0") + if (options.samplingRatio > 0.99) { + csv + } else { + csv.sample(withReplacement = false, options.samplingRatio, 1) + } + } + + /** + * Sample CSV RDD as configured by `samplingRatio`. + */ + def sample(csv: RDD[Array[String]], options: CSVOptions): RDD[Array[String]] = { + require(options.samplingRatio > 0, + s"samplingRatio (${options.samplingRatio}) should be greater than 0") + if (options.samplingRatio > 0.99) { + csv + } else { + csv.sample(withReplacement = false, options.samplingRatio, 1) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 4398e547d9217..461abdd96d3f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -30,12 +30,11 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.functions.{col, regexp_replace} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ -class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { +class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with TestCsvData { import testImplicits._ private val carsFile = "test-data/cars.csv" @@ -1279,4 +1278,48 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Row("0,2013-111-11 12:13:14") :: Row(null) :: Nil ) } + + test("SPARK-23846: schema inferring touches less data if samplingRatio < 1.0") { + // Set default values for the DataSource parameters to make sure + // that whole test file is mapped to only one partition. This will guarantee + // reliable sampling of the input file. + withSQLConf( + "spark.sql.files.maxPartitionBytes" -> (128 * 1024 * 1024).toString, + "spark.sql.files.openCostInBytes" -> (4 * 1024 * 1024).toString + )(withTempPath { path => + val ds = sampledTestData.coalesce(1) + ds.write.text(path.getAbsolutePath) + + val readback = spark.read + .option("inferSchema", true).option("samplingRatio", 0.1) + .csv(path.getCanonicalPath) + assert(readback.schema == new StructType().add("_c0", IntegerType)) + }) + } + + test("SPARK-23846: usage of samplingRatio while parsing a dataset of strings") { + val ds = sampledTestData.coalesce(1) + val readback = spark.read + .option("inferSchema", true).option("samplingRatio", 0.1) + .csv(ds) + + assert(readback.schema == new StructType().add("_c0", IntegerType)) + } + + test("SPARK-23846: samplingRatio is out of the range (0, 1.0]") { + val ds = spark.range(0, 100, 1, 1).map(_.toString) + + val errorMsg0 = intercept[IllegalArgumentException] { + spark.read.option("inferSchema", true).option("samplingRatio", -1).csv(ds) + }.getMessage + assert(errorMsg0.contains("samplingRatio (-1.0) should be greater than 0")) + + val errorMsg1 = intercept[IllegalArgumentException] { + spark.read.option("inferSchema", true).option("samplingRatio", 0).csv(ds) + }.getMessage + assert(errorMsg1.contains("samplingRatio (0.0) should be greater than 0")) + + val sampled = spark.read.option("inferSchema", true).option("samplingRatio", 1.0).csv(ds) + assert(sampled.count() == ds.count()) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala new file mode 100644 index 0000000000000..3e20cc47dca2c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import org.apache.spark.sql.{Dataset, Encoders, SparkSession} + +private[csv] trait TestCsvData { + protected def spark: SparkSession + + def sampledTestData: Dataset[String] = { + spark.range(0, 100, 1).map { index => + val predefinedSample = Set[Long](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, + 57, 62, 68, 72) + if (predefinedSample.contains(index)) { + index.toString + } else { + (index.toDouble + 0.1).toString + } + }(Encoders.STRING) + } +} From b42ad165bb93c96cc5be9ed05b5026f9baafdfa2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 30 Apr 2018 09:13:32 -0700 Subject: [PATCH 0712/2461] [SPARK-24072][SQL] clearly define pushed filters ## What changes were proposed in this pull request? filters like parquet row group filter, which is actually pushed to the data source but still to be evaluated by Spark, should also count as `pushedFilters`. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #21143 from cloud-fan/step1. --- .../SupportsPushDownCatalystFilters.java | 11 +++- .../v2/reader/SupportsPushDownFilters.java | 10 ++- .../datasources/v2/DataSourceV2Relation.scala | 63 +++++++++++-------- .../datasources/v2/DataSourceV2ScanExec.scala | 1 + .../datasources/v2/DataSourceV2Strategy.scala | 4 +- .../v2/DataSourceV2StringFormat.scala | 19 ++---- .../v2/PushDownOperatorsToDataSource.scala | 6 +- .../continuous/ContinuousSuite.scala | 2 +- 8 files changed, 68 insertions(+), 48 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index 290d614805ac7..4543c143a9aca 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -39,7 +39,16 @@ public interface SupportsPushDownCatalystFilters extends DataSourceReader { Expression[] pushCatalystFilters(Expression[] filters); /** - * Returns the catalyst filters that are pushed in {@link #pushCatalystFilters(Expression[])}. + * Returns the catalyst filters that are pushed to the data source via + * {@link #pushCatalystFilters(Expression[])}. + * + * There are 3 kinds of filters: + * 1. pushable filters which don't need to be evaluated again after scanning. + * 2. pushable filters which still need to be evaluated after scanning, e.g. parquet + * row group filter. + * 3. non-pushable filters. + * Both case 1 and 2 should be considered as pushed filters and should be returned by this method. + * * It's possible that there is no filters in the query and * {@link #pushCatalystFilters(Expression[])} is never called, empty array should be returned for * this case. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index 1cff024232a44..b6a90a3d0b681 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -37,7 +37,15 @@ public interface SupportsPushDownFilters extends DataSourceReader { Filter[] pushFilters(Filter[] filters); /** - * Returns the filters that are pushed in {@link #pushFilters(Filter[])}. + * Returns the filters that are pushed to the data source via {@link #pushFilters(Filter[])}. + * + * There are 3 kinds of filters: + * 1. pushable filters which don't need to be evaluated again after scanning. + * 2. pushable filters which still need to be evaluated after scanning, e.g. parquet + * row group filter. + * 3. non-pushable filters. + * Both case 1 and 2 should be considered as pushed filters and should be returned by this method. + * * It's possible that there is no filters in the query and {@link #pushFilters(Filter[])} * is never called, empty array should be returned for this case. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 2b282ffae2390..90fb5a14c9fc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -54,9 +54,12 @@ case class DataSourceV2Relation( private lazy val v2Options: DataSourceOptions = makeV2Options(options) + // postScanFilters: filters that need to be evaluated after the scan. + // pushedFilters: filters that will be pushed down and evaluated in the underlying data sources. + // Note: postScanFilters and pushedFilters can overlap, e.g. the parquet row group filter. lazy val ( reader: DataSourceReader, - unsupportedFilters: Seq[Expression], + postScanFilters: Seq[Expression], pushedFilters: Seq[Expression]) = { val newReader = userSpecifiedSchema match { case Some(s) => @@ -67,14 +70,16 @@ case class DataSourceV2Relation( DataSourceV2Relation.pushRequiredColumns(newReader, projection.toStructType) - val (remainingFilters, pushedFilters) = filters match { + val (postScanFilters, pushedFilters) = filters match { case Some(filterSeq) => DataSourceV2Relation.pushFilters(newReader, filterSeq) case _ => (Nil, Nil) } + logInfo(s"Post-Scan Filters: ${postScanFilters.mkString(",")}") + logInfo(s"Pushed Filters: ${pushedFilters.mkString(", ")}") - (newReader, remainingFilters, pushedFilters) + (newReader, postScanFilters, pushedFilters) } override def doCanonicalize(): LogicalPlan = { @@ -121,6 +126,8 @@ case class StreamingDataSourceV2Relation( override def simpleString: String = "Streaming RelationV2 " + metadataString + override def pushedFilters: Seq[Expression] = Nil + override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance())) // TODO: unify the equal/hashCode implementation for all data source v2 query plans. @@ -217,31 +224,35 @@ object DataSourceV2Relation { reader: DataSourceReader, filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { reader match { - case catalystFilterSupport: SupportsPushDownCatalystFilters => - ( - catalystFilterSupport.pushCatalystFilters(filters.toArray), - catalystFilterSupport.pushedCatalystFilters() - ) - - case filterSupport: SupportsPushDownFilters => - // A map from original Catalyst expressions to corresponding translated data source - // filters. If a predicate is not in this map, it means it cannot be pushed down. - val translatedMap: Map[Expression, Filter] = filters.flatMap { p => - DataSourceStrategy.translateFilter(p).map(f => p -> f) - }.toMap - - // Catalyst predicate expressions that cannot be converted to data source filters. - val nonConvertiblePredicates = filters.filterNot(translatedMap.contains) - - // Data source filters that cannot be pushed down. An unhandled filter means - // the data source cannot guarantee the rows returned can pass the filter. - // As a result we must return it so Spark can plan an extra filter operator. - val unhandledFilters = filterSupport.pushFilters(translatedMap.values.toArray).toSet - val (unhandledPredicates, pushedPredicates) = translatedMap.partition { case (_, f) => - unhandledFilters.contains(f) + case r: SupportsPushDownCatalystFilters => + val postScanFilters = r.pushCatalystFilters(filters.toArray) + val pushedFilters = r.pushedCatalystFilters() + (postScanFilters, pushedFilters) + + case r: SupportsPushDownFilters => + // A map from translated data source filters to original catalyst filter expressions. + val translatedFilterToExpr = scala.collection.mutable.HashMap.empty[Filter, Expression] + // Catalyst filter expression that can't be translated to data source filters. + val untranslatableExprs = scala.collection.mutable.ArrayBuffer.empty[Expression] + + for (filterExpr <- filters) { + val translated = DataSourceStrategy.translateFilter(filterExpr) + if (translated.isDefined) { + translatedFilterToExpr(translated.get) = filterExpr + } else { + untranslatableExprs += filterExpr + } } - (nonConvertiblePredicates ++ unhandledPredicates.keys, pushedPredicates.keys.toSeq) + // Data source filters that need to be evaluated again after scanning. which means + // the data source cannot guarantee the rows returned can pass these filters. + // As a result we must return it so Spark can plan an extra filter operator. + val postScanFilters = + r.pushFilters(translatedFilterToExpr.keys.toArray).map(translatedFilterToExpr) + // The filters which are marked as pushed to this data source + val pushedFilters = r.pushedFilters().map(translatedFilterToExpr) + + (untranslatableExprs ++ postScanFilters, pushedFilters) case _ => (filters, Nil) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 3a5e7bf89e142..41bdda47c8c3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -41,6 +41,7 @@ case class DataSourceV2ScanExec( output: Seq[AttributeReference], @transient source: DataSourceV2, @transient options: Map[String, String], + @transient pushedFilters: Seq[Expression], @transient reader: DataSourceReader) extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index c2a31442d2be5..1b7c639f10f98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -25,10 +25,10 @@ import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDat object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case r: DataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.options, r.reader) :: Nil + DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil case r: StreamingDataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.options, r.reader) :: Nil + DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala index aed55a429bfd7..693e67dcd108e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala @@ -19,11 +19,9 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.commons.lang3.StringUtils -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.DataSourceV2 -import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.util.Utils /** @@ -49,16 +47,9 @@ trait DataSourceV2StringFormat { def options: Map[String, String] /** - * The created data source reader. Here we use it to get the filters that has been pushed down - * so far, itself doesn't take part in the equals/hashCode. + * The filters which have been pushed to the data source. */ - def reader: DataSourceReader - - private lazy val filters = reader match { - case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet - case s: SupportsPushDownFilters => s.pushedFilters().toSet - case _ => Set.empty - } + def pushedFilters: Seq[Expression] private def sourceName: String = source match { case registered: DataSourceRegister => registered.shortName() @@ -68,8 +59,8 @@ trait DataSourceV2StringFormat { def metadataString: String = { val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] - if (filters.nonEmpty) { - entries += "Filters" -> filters.mkString("[", ", ", "]") + if (pushedFilters.nonEmpty) { + entries += "Filters" -> pushedFilters.mkString("[", ", ", "]") } // TODO: we should only display some standard options like path, table, etc. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index f23d228567241..9293d4f831bff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -57,9 +57,9 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { projection = projection.asInstanceOf[Seq[AttributeReference]], filters = Some(filters)) - // Add a Filter for any filters that could not be pushed - val unpushedFilter = newRelation.unsupportedFilters.reduceLeftOption(And) - val filtered = unpushedFilter.map(Filter(_, newRelation)).getOrElse(newRelation) + // Add a Filter for any filters that need to be evaluated after scan. + val postScanFilterCond = newRelation.postScanFilters.reduceLeftOption(And) + val filtered = postScanFilterCond.map(Filter(_, newRelation)).getOrElse(newRelation) // Add a Project to ensure the output matches the required projection if (newRelation.output != projectAttrs) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 5f222e7885994..cd1704ac2fdad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -41,7 +41,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, _, _, r: RateStreamContinuousReader) => r + case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 From 007ae6878f4b4defe1f08114212fa7289fc9ee4a Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Mon, 30 Apr 2018 13:40:03 -0700 Subject: [PATCH 0713/2461] [SPARK-24003][CORE] Add support to provide spark.executor.extraJavaOptions in terms of App Id and/or Executor Id's ## What changes were proposed in this pull request? Added support to specify the 'spark.executor.extraJavaOptions' value in terms of the `{{APP_ID}}` and/or `{{EXECUTOR_ID}}`, `{{APP_ID}}` will be replaced by Application Id and `{{EXECUTOR_ID}}` will be replaced by Executor Id while starting the executor. ## How was this patch tested? I have verified this by checking the executor process command and gc logs. I verified the same in different deployment modes(Standalone, YARN, Mesos) client and cluster modes. Author: Devaraj K Closes #21088 from devaraj-kavali/SPARK-24003. --- .../spark/deploy/worker/ExecutorRunner.scala | 8 ++++++-- .../main/scala/org/apache/spark/util/Utils.scala | 15 +++++++++++++++ docs/configuration.md | 5 +++++ .../k8s/features/BasicExecutorFeatureStep.scala | 4 +++- .../MesosCoarseGrainedSchedulerBackend.scala | 4 +++- .../mesos/MesosFineGrainedSchedulerBackend.scala | 4 +++- .../org/apache/spark/deploy/yarn/Client.scala | 8 ++++++-- .../spark/deploy/yarn/ExecutorRunnable.scala | 3 ++- 8 files changed, 43 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index d4d8521cc8204..dc6a3076a5113 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import com.google.common.io.Files import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} +import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEndpointRef @@ -142,7 +142,11 @@ private[deploy] class ExecutorRunner( private def fetchAndRunExecutor() { try { // Launch the process - val builder = CommandUtils.buildProcessBuilder(appDesc.command, new SecurityManager(conf), + val subsOpts = appDesc.command.javaOpts.map { + Utils.substituteAppNExecIds(_, appId, execId.toString) + } + val subsCommand = appDesc.command.copy(javaOpts = subsOpts) + val builder = CommandUtils.buildProcessBuilder(subsCommand, new SecurityManager(conf), memory, sparkHome.getAbsolutePath, substituteVariables) val command = builder.command() val formattedCommand = command.asScala.mkString("\"", "\" \"", "\"") diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index d2be93226e2a2..dcad1b914038f 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2689,6 +2689,21 @@ private[spark] object Utils extends Logging { s"k8s://$resolvedURL" } + + /** + * Replaces all the {{EXECUTOR_ID}} occurrences with the Executor Id + * and {{APP_ID}} occurrences with the App Id. + */ + def substituteAppNExecIds(opt: String, appId: String, execId: String): String = { + opt.replace("{{APP_ID}}", appId).replace("{{EXECUTOR_ID}}", execId) + } + + /** + * Replaces all the {{APP_ID}} occurrences with the App Id. + */ + def substituteAppId(opt: String, appId: String): String = { + opt.replace("{{APP_ID}}", appId) + } } private[util] object CallerContext extends Logging { diff --git a/docs/configuration.md b/docs/configuration.md index fb02d7ea1d4ea..8a1aacef85760 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -328,6 +328,11 @@ Apart from these, the following properties are also available, and may be useful Note that it is illegal to set Spark properties or maximum heap size (-Xmx) settings with this option. Spark properties should be set using a SparkConf object or the spark-defaults.conf file used with the spark-submit script. Maximum heap size settings can be set with spark.executor.memory. + + The following symbols, if present will be interpolated: {{APP_ID}} will be replaced by + application ID and {{EXECUTOR_ID}} will be replaced by executor ID. For example, to enable + verbose gc logging to a file named for the executor ID of the app in /tmp, pass a 'value' of: + -verbose:gc -Xloggc:/tmp/{{APP_ID}}-{{EXECUTOR_ID}}.gc
  • diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index d22097587aafe..529069d3b8a0c 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -89,7 +89,9 @@ private[spark] class BasicExecutorFeatureStep( val executorExtraJavaOptionsEnv = kubernetesConf .get(EXECUTOR_JAVA_OPTIONS) .map { opts => - val delimitedOpts = Utils.splitCommandString(opts) + val subsOpts = Utils.substituteAppNExecIds(opts, kubernetesConf.appId, + kubernetesConf.roleSpecificConf.executorId) + val delimitedOpts = Utils.splitCommandString(subsOpts) delimitedOpts.zipWithIndex.map { case (opt, index) => new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 53f5f61cca486..9b75e4c98344a 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -227,7 +227,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( environment.addVariables( Environment.Variable.newBuilder().setName("SPARK_EXECUTOR_CLASSPATH").setValue(cp).build()) } - val extraJavaOpts = conf.get("spark.executor.extraJavaOptions", "") + val extraJavaOpts = conf.getOption("spark.executor.extraJavaOptions").map { + Utils.substituteAppNExecIds(_, appId, taskId) + }.getOrElse("") // Set the environment variable through a command prefix // to append to the existing value of the variable diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index d6d939d246109..71a70ff048ccc 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -111,7 +111,9 @@ private[spark] class MesosFineGrainedSchedulerBackend( environment.addVariables( Environment.Variable.newBuilder().setName("SPARK_EXECUTOR_CLASSPATH").setValue(cp).build()) } - val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").getOrElse("") + val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").map { + Utils.substituteAppNExecIds(_, appId, execId) + }.getOrElse("") val prefixEnv = sc.conf.getOption("spark.executor.extraLibraryPath").map { p => Utils.libraryPathEnvPrefix(Seq(p)) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 5763c3dbc5a8a..bafb129032b49 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -892,7 +892,9 @@ private[spark] class Client( // Include driver-specific java options if we are launching a driver if (isClusterMode) { sparkConf.get(DRIVER_JAVA_OPTIONS).foreach { opts => - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + javaOpts ++= Utils.splitCommandString(opts) + .map(Utils.substituteAppId(_, appId.toString)) + .map(YarnSparkHadoopUtil.escapeForShell) } val libraryPaths = Seq(sparkConf.get(DRIVER_LIBRARY_PATH), sys.props.get("spark.driver.libraryPath")).flatten @@ -914,7 +916,9 @@ private[spark] class Client( s"(was '$opts'). Use spark.yarn.am.memory instead." throw new SparkException(msg) } - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + javaOpts ++= Utils.splitCommandString(opts) + .map(Utils.substituteAppId(_, appId.toString)) + .map(YarnSparkHadoopUtil.escapeForShell) } sparkConf.get(AM_LIBRARY_PATH).foreach { paths => prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths)))) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index ab08698035c98..a2a18cdff65af 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -141,7 +141,8 @@ private[yarn] class ExecutorRunnable( // Set extra Java options for the executor, if defined sparkConf.get(EXECUTOR_JAVA_OPTIONS).foreach { opts => - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + val subsOpt = Utils.substituteAppNExecIds(opts, appId, executorId) + javaOpts ++= Utils.splitCommandString(subsOpt).map(YarnSparkHadoopUtil.escapeForShell) } sparkConf.get(EXECUTOR_LIBRARY_PATH).foreach { p => prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p)))) From b857fb549f3bf4e6f289ba11f3903db0a3696dec Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 1 May 2018 09:06:23 +0800 Subject: [PATCH 0714/2461] [SPARK-23853][PYSPARK][TEST] Run Hive-related PySpark tests only for `-Phive` ## What changes were proposed in this pull request? When `PyArrow` or `Pandas` are not available, the corresponding PySpark tests are skipped automatically. Currently, PySpark tests fail when we are not using `-Phive`. This PR aims to skip Hive related PySpark tests when `-Phive` is not given. **BEFORE** ```bash $ build/mvn -DskipTests clean package $ python/run-tests.py --python-executables python2.7 --modules pyspark-sql File "/Users/dongjoon/spark/python/pyspark/sql/readwriter.py", line 295, in pyspark.sql.readwriter.DataFrameReader.table ... IllegalArgumentException: u"Error while instantiating 'org.apache.spark.sql.hive.HiveExternalCatalog':" ********************************************************************** 1 of 3 in pyspark.sql.readwriter.DataFrameReader.table ***Test Failed*** 1 failures. ``` **AFTER** ```bash $ build/mvn -DskipTests clean package $ python/run-tests.py --python-executables python2.7 --modules pyspark-sql ... Tests passed in 138 seconds Skipped tests in pyspark.sql.tests with python2.7: ... test_hivecontext (pyspark.sql.tests.HiveSparkSubmitTests) ... skipped 'Hive is not available.' ``` ## How was this patch tested? This is a test-only change. First, this should pass the Jenkins. Then, manually do the following. ```bash build/mvn -DskipTests clean package python/run-tests.py --python-executables python2.7 --modules pyspark-sql ``` Author: Dongjoon Hyun Closes #21141 from dongjoon-hyun/SPARK-23853. --- python/pyspark/sql/readwriter.py | 2 +- python/pyspark/sql/tests.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 9899eb5058b82..448a4732001b5 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -979,7 +979,7 @@ def _test(): globs = pyspark.sql.readwriter.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') try: - spark = SparkSession.builder.enableHiveSupport().getOrCreate() + spark = SparkSession.builder.getOrCreate() except py4j.protocol.Py4JError: spark = SparkSession(sc) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index bc3eaf16b4de7..cc6acfdb07d99 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3043,6 +3043,26 @@ def test_csv_sampling_ratio(self): class HiveSparkSubmitTests(SparkSubmitTests): + @classmethod + def setUpClass(cls): + # get a SparkContext to check for availability of Hive + sc = SparkContext('local[4]', cls.__name__) + cls.hive_available = True + try: + sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + except py4j.protocol.Py4JError: + cls.hive_available = False + except TypeError: + cls.hive_available = False + finally: + # we don't need this SparkContext for the test + sc.stop() + + def setUp(self): + super(HiveSparkSubmitTests, self).setUp() + if not self.hive_available: + self.skipTest("Hive is not available.") + def test_hivecontext(self): # This test checks that HiveContext is using Hive metastore (SPARK-16224). # It sets a metastore url and checks if there is a derby dir created by From 7bbec0dced35aeed79c1a24b6f7a1e0a3508b0fb Mon Sep 17 00:00:00 2001 From: wangyanlin01 Date: Tue, 1 May 2018 16:22:52 +0800 Subject: [PATCH 0715/2461] [SPARK-24061][SS] Add TypedFilter support for continuous processing ## What changes were proposed in this pull request? Add TypedFilter support for continuous processing application. ## How was this patch tested? unit tests Author: wangyanlin01 Closes #21136 from yanlin-Lynn/SPARK-24061. --- .../UnsupportedOperationChecker.scala | 3 ++- .../analysis/UnsupportedOperationsSuite.scala | 23 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index ff9d6d7a7dded..d3d6c636c4ba8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -345,7 +345,8 @@ object UnsupportedOperationChecker { plan.foreachUp { implicit subPlan => subPlan match { case (_: Project | _: Filter | _: MapElements | _: MapPartitions | - _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias) => + _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias | + _: TypedFilter) => case node if node.nodeName == "StreamingRelationV2" => case node => throwError(s"Continuous processing does not support ${node.nodeName} operations.") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 60d1351fda264..cb487c8893541 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -621,6 +621,13 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Append, expectedMsgs = Seq("monotonically_increasing_id")) + assertSupportedForContinuousProcessing( + "TypedFilter", TypedFilter( + null, + null, + null, + null, + new TestStreamingRelationV2(attribute)), OutputMode.Append()) /* ======================================================================================= @@ -771,6 +778,16 @@ class UnsupportedOperationsSuite extends SparkFunSuite { } } + /** Assert that the logical plan is supported for continuous procsssing mode */ + def assertSupportedForContinuousProcessing( + name: String, + plan: LogicalPlan, + outputMode: OutputMode): Unit = { + test(s"continuous processing - $name: supported") { + UnsupportedOperationChecker.checkForContinuous(plan, outputMode) + } + } + /** * Assert that the logical plan is not supported inside a streaming plan. * @@ -840,4 +857,10 @@ class UnsupportedOperationsSuite extends SparkFunSuite { def this(attribute: Attribute) = this(Seq(attribute)) override def isStreaming: Boolean = true } + + case class TestStreamingRelationV2(output: Seq[Attribute]) extends LeafNode { + def this(attribute: Attribute) = this(Seq(attribute)) + override def isStreaming: Boolean = true + override def nodeName: String = "StreamingRelationV2" + } } From 6782359a04356e4cde32940861bf2410ef37f445 Mon Sep 17 00:00:00 2001 From: Bounkong Khamphousone Date: Tue, 1 May 2018 08:28:21 -0700 Subject: [PATCH 0716/2461] [SPARK-23941][MESOS] Mesos task failed on specific spark app name MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Shell escaped the name passed to spark-submit and change how conf attributes are shell escaped. ## How was this patch tested? This test has been tested manually with Hive-on-spark with mesos or with the use case described in the issue with the sparkPi application with a custom name which contains illegal shell characters. With this PR, hive-on-spark on mesos works like a charm with hive 3.0.0-SNAPSHOT. I state that this contribution is my original work and that I license the work to the project under the project’s open source license Author: Bounkong Khamphousone Closes #21014 from tiboun/fix/SPARK-23941. --- .../spark/scheduler/cluster/mesos/MesosClusterScheduler.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index d224a7325820a..b36f46456f9a5 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -530,9 +530,9 @@ private[spark] class MesosClusterScheduler( .filter { case (key, _) => !replicatedOptionsBlacklist.contains(key) } .toMap (defaultConf ++ driverConf).foreach { case (key, value) => - options ++= Seq("--conf", s""""$key=${shellEscape(value)}"""".stripMargin) } + options ++= Seq("--conf", s"${key}=${value}") } - options + options.map(shellEscape) } /** From e15850be6e0210614a734a307f5b83bdf44e2456 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 2 May 2018 10:55:01 +0800 Subject: [PATCH 0717/2461] [SPARK-24131][PYSPARK] Add majorMinorVersion API to PySpark for determining Spark versions ## What changes were proposed in this pull request? We need to determine Spark major and minor versions in PySpark. We can add a `majorMinorVersion` API to PySpark which is similar to the Scala API in `VersionUtils.majorMinorVersion`. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh Closes #21203 from viirya/SPARK-24131. --- python/pyspark/util.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 49afc13640332..04df835bf6717 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -16,6 +16,7 @@ # limitations under the License. # +import re import sys import inspect from py4j.protocol import Py4JJavaError @@ -61,6 +62,26 @@ def _get_argspec(f): return argspec +def majorMinorVersion(version): + """ + Get major and minor version numbers for given Spark version string. + + >>> version = "2.4.0" + >>> majorMinorVersion(version) + (2, 4) + + >>> version = "abc" + >>> majorMinorVersion(version) is None + True + + """ + m = re.search('^(\d+)\.(\d+)(\..*)?$', version) + if m is None: + return None + else: + return (int(m.group(1)), int(m.group(2))) + + if __name__ == "__main__": import doctest (failure_count, test_count) = doctest.testmod() From 9215ee7a16b57c56ae927d65e024cf7afe542cbb Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 2 May 2018 10:41:34 +0200 Subject: [PATCH 0718/2461] [SPARK-23976][CORE] Detect length overflow in UTF8String.concat()/ByteArray.concat() ## What changes were proposed in this pull request? This PR detects length overflow if total elements in inputs are not acceptable. For example, when the three inputs has `0x7FFF_FF00`, `0x7FFF_FF00`, and `0xE00`, we should detect length overflow since we cannot allocate such a large structure on `byte[]`. On the other hand, the current algorithm can allocate the result structure with `0x1000`-byte length due to integer sum overflow. ## How was this patch tested? Existing UTs. If we would create UTs, we need large heap (6-8GB). It may make test environment unstable. If it is necessary to create UTs, I will create them. Author: Kazuaki Ishizaki Closes #21064 from kiszk/SPARK-23976. --- .../org/apache/spark/unsafe/types/ByteArray.java | 12 +++++++----- .../org/apache/spark/unsafe/types/UTF8String.java | 8 ++++---- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index c03caf0076f61..ecd7c19f2c634 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -17,10 +17,12 @@ package org.apache.spark.unsafe.types; -import org.apache.spark.unsafe.Platform; - import java.util.Arrays; +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.Platform; + public final class ByteArray { public static final byte[] EMPTY_BYTE = new byte[0]; @@ -77,17 +79,17 @@ public static byte[] subStringSQL(byte[] bytes, int pos, int len) { public static byte[] concat(byte[]... inputs) { // Compute the total length of the result - int totalLength = 0; + long totalLength = 0; for (int i = 0; i < inputs.length; i++) { if (inputs[i] != null) { - totalLength += inputs[i].length; + totalLength += (long)inputs[i].length; } else { return null; } } // Allocate a new byte array, and copy the inputs one by one into it - final byte[] result = new byte[totalLength]; + final byte[] result = new byte[Ints.checkedCast(totalLength)]; int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].length; diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index e9b3d9b045af5..e91fc4391425c 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -29,8 +29,8 @@ import com.esotericsoftware.kryo.KryoSerializable; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; - import com.google.common.primitives.Ints; + import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; @@ -877,17 +877,17 @@ public UTF8String lpad(int len, UTF8String pad) { */ public static UTF8String concat(UTF8String... inputs) { // Compute the total length of the result. - int totalLength = 0; + long totalLength = 0; for (int i = 0; i < inputs.length; i++) { if (inputs[i] != null) { - totalLength += inputs[i].numBytes; + totalLength += (long)inputs[i].numBytes; } else { return null; } } // Allocate a new byte array, and copy the inputs one by one into it. - final byte[] result = new byte[totalLength]; + final byte[] result = new byte[Ints.checkedCast(totalLength)]; int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].numBytes; From 152eaf6ae698cd0df7f5a5be3f17ee46e0be929d Mon Sep 17 00:00:00 2001 From: WangJinhai02 Date: Wed, 2 May 2018 22:40:14 +0800 Subject: [PATCH 0719/2461] [SPARK-24107][CORE] ChunkedByteBuffer.writeFully method has not reset the limit value MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit JIRA Issue: https://issues.apache.org/jira/browse/SPARK-24107?jql=text%20~%20%22ChunkedByteBuffer%22 ChunkedByteBuffer.writeFully method has not reset the limit value. When chunks larger than bufferWriteChunkSize, such as 80 * 1024 * 1024 larger than config.BUFFER_WRITE_CHUNK_SIZE(64 * 1024 * 1024),only while once, will lost 16 * 1024 * 1024 byte Author: WangJinhai02 Closes #21175 from manbuyun/bugfix-ChunkedByteBuffer. --- .../spark/util/io/ChunkedByteBuffer.scala | 13 +++++++++---- .../spark/io/ChunkedByteBufferSuite.scala | 17 +++++++++++++++-- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 7367af7888bd8..3ae8dfcc1cb66 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -63,10 +63,15 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { */ def writeFully(channel: WritableByteChannel): Unit = { for (bytes <- getChunks()) { - while (bytes.remaining() > 0) { - val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) - bytes.limit(bytes.position() + ioSize) - channel.write(bytes) + val curChunkLimit = bytes.limit() + while (bytes.hasRemaining) { + try { + val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) + bytes.limit(bytes.position() + ioSize) + channel.write(bytes) + } finally { + bytes.limit(curChunkLimit) + } } } } diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala index 3b798e36b0499..2107559572d78 100644 --- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala @@ -21,11 +21,12 @@ import java.nio.ByteBuffer import com.google.common.io.ByteStreams -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SharedSparkContext, SparkFunSuite} +import org.apache.spark.internal.config import org.apache.spark.network.util.ByteArrayWritableChannel import org.apache.spark.util.io.ChunkedByteBuffer -class ChunkedByteBufferSuite extends SparkFunSuite { +class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext { test("no chunks") { val emptyChunkedByteBuffer = new ChunkedByteBuffer(Array.empty[ByteBuffer]) @@ -56,6 +57,18 @@ class ChunkedByteBufferSuite extends SparkFunSuite { assert(chunkedByteBuffer.getChunks().head.position() === 0) } + test("SPARK-24107: writeFully() write buffer which is larger than bufferWriteChunkSize") { + try { + sc.conf.set(config.BUFFER_WRITE_CHUNK_SIZE, 32L * 1024L * 1024L) + val chunkedByteBuffer = new ChunkedByteBuffer(Array(ByteBuffer.allocate(40 * 1024 * 1024))) + val byteArrayWritableChannel = new ByteArrayWritableChannel(chunkedByteBuffer.size.toInt) + chunkedByteBuffer.writeFully(byteArrayWritableChannel) + assert(byteArrayWritableChannel.length() === chunkedByteBuffer.size) + } finally { + sc.conf.remove(config.BUFFER_WRITE_CHUNK_SIZE) + } + } + test("toArray()") { val empty = ByteBuffer.wrap(Array.empty[Byte]) val bytes = ByteBuffer.wrap(Array.tabulate(8)(_.toByte)) From 8dbf56c055218ff0f3fabae84b63c022f43afbfd Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 2 May 2018 11:58:55 -0700 Subject: [PATCH 0720/2461] [SPARK-24013][SQL] Remove unneeded compress in ApproximatePercentile ## What changes were proposed in this pull request? `ApproximatePercentile` contains a workaround logic to compress the samples since at the beginning `QuantileSummaries` was ignoring the compression threshold. This problem was fixed in SPARK-17439, but the workaround logic was not removed. So we are compressing the samples many more times than needed: this could lead to critical performance degradation. This can create serious performance issues in queries like: ``` select approx_percentile(id, array(0.1)) from range(10000000) ``` ## How was this patch tested? added UT Author: Marco Gaido Closes #21133 from mgaido91/SPARK-24013. --- .../aggregate/ApproximatePercentile.scala | 33 ++++--------------- .../sql/catalyst/util/QuantileSummaries.scala | 11 ++++--- .../sql/ApproximatePercentileQuerySuite.scala | 13 ++++++++ 3 files changed, 26 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index a45854a3b5146..f1bbbdabb41f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -206,27 +206,15 @@ object ApproximatePercentile { * with limited memory. PercentileDigest is backed by [[QuantileSummaries]]. * * @param summaries underlying probabilistic data structure [[QuantileSummaries]]. - * @param isCompressed An internal flag from class [[QuantileSummaries]] to indicate whether the - * underlying quantileSummaries is compressed. */ - class PercentileDigest( - private var summaries: QuantileSummaries, - private var isCompressed: Boolean) { - - // Trigger compression if the QuantileSummaries's buffer length exceeds - // compressThresHoldBufferLength. The buffer length can be get by - // quantileSummaries.sampled.length - private[this] final val compressThresHoldBufferLength: Int = { - // Max buffer length after compression. - val maxBufferLengthAfterCompression: Int = (1 / summaries.relativeError).toInt * 2 - // A safe upper bound for buffer length before compression - maxBufferLengthAfterCompression * 2 - } + class PercentileDigest(private var summaries: QuantileSummaries) { def this(relativeError: Double) = { - this(new QuantileSummaries(defaultCompressThreshold, relativeError), isCompressed = true) + this(new QuantileSummaries(defaultCompressThreshold, relativeError, compressed = true)) } + private[sql] def isCompressed: Boolean = summaries.compressed + /** Returns compressed object of [[QuantileSummaries]] */ def quantileSummaries: QuantileSummaries = { if (!isCompressed) compress() @@ -236,14 +224,6 @@ object ApproximatePercentile { /** Insert an observation value into the PercentileDigest data structure. */ def add(value: Double): Unit = { summaries = summaries.insert(value) - // The result of QuantileSummaries.insert is un-compressed - isCompressed = false - - // Currently, QuantileSummaries ignores the construction parameter compressThresHold, - // which may cause QuantileSummaries to occupy unbounded memory. We have to hack around here - // to make sure QuantileSummaries doesn't occupy infinite memory. - // TODO: Figure out why QuantileSummaries ignores construction parameter compressThresHold - if (summaries.sampled.length >= compressThresHoldBufferLength) compress() } /** In-place merges in another PercentileDigest. */ @@ -280,7 +260,6 @@ object ApproximatePercentile { private final def compress(): Unit = { summaries = summaries.compress() - isCompressed = true } } @@ -335,8 +314,8 @@ object ApproximatePercentile { sampled(i) = Stats(value, g, delta) i += 1 } - val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count) - new PercentileDigest(summary, isCompressed = true) + val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count, true) + new PercentileDigest(summary) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index b013add9c9778..3190e511e2cb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -40,12 +40,14 @@ import org.apache.spark.sql.catalyst.util.QuantileSummaries.Stats * See the G-K article for more details. * @param count the count of all the elements *inserted in the sampled buffer* * (excluding the head buffer) + * @param compressed whether the statistics have been compressed */ class QuantileSummaries( val compressThreshold: Int, val relativeError: Double, val sampled: Array[Stats] = Array.empty, - val count: Long = 0L) extends Serializable { + val count: Long = 0L, + var compressed: Boolean = false) extends Serializable { // a buffer of latest samples seen so far private val headSampled: ArrayBuffer[Double] = ArrayBuffer.empty @@ -60,6 +62,7 @@ class QuantileSummaries( */ def insert(x: Double): QuantileSummaries = { headSampled += x + compressed = false if (headSampled.size >= defaultHeadSize) { val result = this.withHeadBufferInserted if (result.sampled.length >= compressThreshold) { @@ -135,11 +138,11 @@ class QuantileSummaries( assert(inserted.count == count + headSampled.size) val compressed = compressImmut(inserted.sampled, mergeThreshold = 2 * relativeError * inserted.count) - new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count) + new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count, true) } private def shallowCopy: QuantileSummaries = { - new QuantileSummaries(compressThreshold, relativeError, sampled, count) + new QuantileSummaries(compressThreshold, relativeError, sampled, count, compressed) } /** @@ -163,7 +166,7 @@ class QuantileSummaries( val res = (sampled ++ other.sampled).sortBy(_.value) val comp = compressImmut(res, mergeThreshold = 2 * relativeError * count) new QuantileSummaries( - other.compressThreshold, other.relativeError, comp, other.count + count) + other.compressThreshold, other.relativeError, comp, other.count + count, true) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala index 137c5bea2abb9..d635912cf7205 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -279,4 +280,16 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(query, expected) } } + + test("SPARK-24013: unneeded compress can cause performance issues with sorted input") { + val buffer = new PercentileDigest(1.0D / ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY) + var compressCounts = 0 + (1 to 10000000).foreach { i => + buffer.add(i) + if (buffer.isCompressed) compressCounts += 1 + } + assert(compressCounts > 0) + buffer.quantileSummaries + assert(buffer.isCompressed) + } } From 8bd27025b7cf0b44726b6f4020d294ef14dbbb7e Mon Sep 17 00:00:00 2001 From: Ala Luszczak Date: Wed, 2 May 2018 12:43:19 -0700 Subject: [PATCH 0721/2461] [SPARK-24133][SQL] Check for integer overflows when resizing WritableColumnVectors ## What changes were proposed in this pull request? `ColumnVector`s store string data in one big byte array. Since the array size is capped at just under Integer.MAX_VALUE, a single `ColumnVector` cannot store more than 2GB of string data. But since the Parquet files commonly contain large blobs stored as strings, and `ColumnVector`s by default carry 4096 values, it's entirely possible to go past that limit. In such cases a negative capacity is requested from `WritableColumnVector.reserve()`. The call succeeds (requested capacity is smaller than already allocated capacity), and consequently `java.lang.ArrayIndexOutOfBoundsException` is thrown when the reader actually attempts to put the data into the array. This change introduces a simple check for integer overflow to `WritableColumnVector.reserve()` which should help catch the error earlier and provide more informative exception. Additionally, the error message in `WritableColumnVector.throwUnsupportedException()` was corrected, as it previously encouraged users to increase rather than reduce the batch size. ## How was this patch tested? New units tests were added. Author: Ala Luszczak Closes #21206 from ala/overflow-reserve. --- .../vectorized/WritableColumnVector.java | 21 ++++++++++++------- .../vectorized/ColumnarBatchSuite.scala | 7 +++++++ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 5275e4a91eac0..b0e119d658cb4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -81,7 +81,9 @@ public void close() { } public void reserve(int requiredCapacity) { - if (requiredCapacity > capacity) { + if (requiredCapacity < 0) { + throwUnsupportedException(requiredCapacity, null); + } else if (requiredCapacity > capacity) { int newCapacity = (int) Math.min(MAX_CAPACITY, requiredCapacity * 2L); if (requiredCapacity <= newCapacity) { try { @@ -96,13 +98,16 @@ public void reserve(int requiredCapacity) { } private void throwUnsupportedException(int requiredCapacity, Throwable cause) { - String message = "Cannot reserve additional contiguous bytes in the vectorized reader " + - "(requested = " + requiredCapacity + " bytes). As a workaround, you can disable the " + - "vectorized reader, or increase the vectorized reader batch size. For parquet file " + - "format, refer to " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + " and " + - SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() + "; for orc file format, refer to " + - SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + " and " + - SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() + "."; + String message = "Cannot reserve additional contiguous bytes in the vectorized reader (" + + (requiredCapacity >= 0 ? "requested " + requiredCapacity + " bytes" : "integer overflow") + + "). As a workaround, you can reduce the vectorized reader batch size, or disable the " + + "vectorized reader. For parquet file format, refer to " + + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() + + " (default " + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().defaultValueString() + + ") and " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + "; for orc file format, " + + "refer to " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() + + " (default " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().defaultValueString() + + ") and " + SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + "."; throw new RuntimeException(message, cause); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 772f687526008..f57f07b498261 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -1333,4 +1333,11 @@ class ColumnarBatchSuite extends SparkFunSuite { column.close() } + + testVector("WritableColumnVector.reserve(): requested capacity is negative", 1024, ByteType) { + column => + val ex = intercept[RuntimeException] { column.reserve(-1) } + assert(ex.getMessage.contains( + "Cannot reserve additional contiguous bytes in the vectorized reader (integer overflow)")) + } } From 504c9cfd21ef45a13d9428fef3b197dcbf6786cd Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 2 May 2018 13:49:15 -0700 Subject: [PATCH 0722/2461] [SPARK-24123][SQL] Fix precision issues in monthsBetween with more than 8 digits ## What changes were proposed in this pull request? SPARK-23902 introduced the ability to retrieve more than 8 digits in `monthsBetween`. Unfortunately, current implementation can cause precision loss in such a case. This was causing also a flaky UT. This PR mirrors Hive's implementation in order to avoid precision loss also when more than 8 digits are returned. ## How was this patch tested? running 10000000 times the flaky UT Author: Marco Gaido Closes #21196 from mgaido91/SPARK-24123. --- .../spark/sql/catalyst/util/DateTimeUtils.scala | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 4b00a61c6cf91..d2fe15c48c6dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -888,14 +888,19 @@ object DateTimeUtils { val months1 = year1 * 12 + monthInYear1 val months2 = year2 * 12 + monthInYear2 + val monthDiff = (months1 - months2).toDouble + if (dayInMonth1 == dayInMonth2 || ((daysToMonthEnd1 == 0) && (daysToMonthEnd2 == 0))) { - return (months1 - months2).toDouble + return monthDiff } - // milliseconds is enough for 8 digits precision on the right side - val timeInDay1 = millis1 - daysToMillis(date1, timeZone) - val timeInDay2 = millis2 - daysToMillis(date2, timeZone) - val timesBetween = (timeInDay1 - timeInDay2).toDouble / MILLIS_PER_DAY - val diff = (months1 - months2).toDouble + (dayInMonth1 - dayInMonth2 + timesBetween) / 31.0 + // using milliseconds can cause precision loss with more than 8 digits + // we follow Hive's implementation which uses seconds + val secondsInDay1 = (millis1 - daysToMillis(date1, timeZone)) / 1000L + val secondsInDay2 = (millis2 - daysToMillis(date2, timeZone)) / 1000L + val secondsDiff = (dayInMonth1 - dayInMonth2) * SECONDS_PER_DAY + secondsInDay1 - secondsInDay2 + // 2678400D is the number of seconds in 31 days + // every month is considered to be 31 days long in this function + val diff = monthDiff + secondsDiff / 2678400D if (roundOff) { // rounding to 8 digits math.round(diff * 1e8) / 1e8 From 5be8aab14468e55b1049a0c83f02dcec0651162f Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 2 May 2018 13:53:10 -0700 Subject: [PATCH 0723/2461] [SPARK-23923][SQL] Add cardinality function ## What changes were proposed in this pull request? The PR adds the SQL function `cardinality`. The behavior of the function is based on Presto's one. The function returns the length of the array or map stored in the column as `int` while the Presto version returns the value as `BigInt` (`long` in Spark). The discussions regarding the difference of return type are [here](https://github.com/apache/spark/pull/21031#issuecomment-381284638) and [there](https://github.com/apache/spark/pull/21031#discussion_r181622107). ## How was this patch tested? Added UTs Author: Kazuaki Ishizaki Closes #21031 from kiszk/SPARK-23923. --- .../spark/sql/catalyst/analysis/FunctionRegistry.scala | 1 + .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 6bc7b4e4f7cb3..3ffbc9c8069fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -409,6 +409,7 @@ object FunctionRegistry { expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[Size]("size"), + expression[Size]("cardinality"), expression[SortArray]("sort_array"), expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 470a1c8e331ba..a5163accb1bb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -341,6 +341,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("size(a)"), Seq(Row(2), Row(0), Row(3), Row(-1)) ) + + checkAnswer( + df.selectExpr("cardinality(a)"), + Seq(Row(2L), Row(0L), Row(3L), Row(-1L)) + ) } test("map size function") { From e4c91c089a701117af82f585d14d8afc5245fc64 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 2 May 2018 16:12:21 -0700 Subject: [PATCH 0724/2461] [SPARK-24111][SQL] Add the TPCDS v2.7 (latest) queries in TPCDSQueryBenchmark ## What changes were proposed in this pull request? This pr added the TPCDS v2.7 (latest) queries in `TPCDSQueryBenchmark`. These query files have been added in `SPARK-23167`. ## How was this patch tested? Manually checked. Author: Takeshi Yamamuro Closes #21177 from maropu/AddTpcdsV2_7InBenchmark. --- .../benchmark/TPCDSQueryBenchmark.scala | 52 +++++++++++++------ 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala index 69247d7f4e9aa..abe61a2c2b9c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala @@ -58,10 +58,13 @@ object TPCDSQueryBenchmark extends Logging { }.toMap } - def tpcdsAll(dataLocation: String, queries: Seq[String]): Unit = { - val tableSizes = setupTables(dataLocation) + def runTpcdsQueries( + queryLocation: String, + queries: Seq[String], + tableSizes: Map[String, Long], + nameSuffix: String = ""): Unit = { queries.foreach { name => - val queryString = resourceToString(s"tpcds/$name.sql", + val queryString = resourceToString(s"$queryLocation/$name.sql", classLoader = Thread.currentThread().getContextClassLoader) // This is an indirect hack to estimate the size of each query's input by traversing the @@ -78,7 +81,7 @@ object TPCDSQueryBenchmark extends Logging { } val numRows = queryRelations.map(tableSizes.getOrElse(_, 0L)).sum val benchmark = new Benchmark(s"TPCDS Snappy", numRows, 5) - benchmark.addCase(name) { i => + benchmark.addCase(s"$name$nameSuffix") { _ => spark.sql(queryString).collect() } logInfo(s"\n\n===== TPCDS QUERY BENCHMARK OUTPUT FOR $name =====\n") @@ -87,10 +90,20 @@ object TPCDSQueryBenchmark extends Logging { } } + def filterQueries( + origQueries: Seq[String], + args: TPCDSQueryBenchmarkArguments): Seq[String] = { + if (args.queryFilter.nonEmpty) { + origQueries.filter(args.queryFilter.contains) + } else { + origQueries + } + } + def main(args: Array[String]): Unit = { val benchmarkArgs = new TPCDSQueryBenchmarkArguments(args) - // List of all TPC-DS queries + // List of all TPC-DS v1.4 queries val tpcdsQueries = Seq( "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14a", "q14b", "q15", "q16", "q17", "q18", "q19", "q20", @@ -103,20 +116,25 @@ object TPCDSQueryBenchmark extends Logging { "q81", "q82", "q83", "q84", "q85", "q86", "q87", "q88", "q89", "q90", "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99") + // This list only includes TPC-DS v2.7 queries that are different from v1.4 ones + val tpcdsQueriesV2_7 = Seq( + "q5a", "q6", "q10a", "q11", "q12", "q14", "q14a", "q18a", + "q20", "q22", "q22a", "q24", "q27a", "q34", "q35", "q35a", "q36a", "q47", "q49", + "q51a", "q57", "q64", "q67a", "q70a", "q72", "q74", "q75", "q77a", "q78", + "q80a", "q86a", "q98") + // If `--query-filter` defined, filters the queries that this option selects - val queriesToRun = if (benchmarkArgs.queryFilter.nonEmpty) { - val queries = tpcdsQueries.filter { case queryName => - benchmarkArgs.queryFilter.contains(queryName) - } - if (queries.isEmpty) { - throw new RuntimeException( - s"Empty queries to run. Bad query name filter: ${benchmarkArgs.queryFilter}") - } - queries - } else { - tpcdsQueries + val queriesV1_4ToRun = filterQueries(tpcdsQueries, benchmarkArgs) + val queriesV2_7ToRun = filterQueries(tpcdsQueriesV2_7, benchmarkArgs) + + if ((queriesV1_4ToRun ++ queriesV2_7ToRun).isEmpty) { + throw new RuntimeException( + s"Empty queries to run. Bad query name filter: ${benchmarkArgs.queryFilter}") } - tpcdsAll(benchmarkArgs.dataLocation, queries = queriesToRun) + val tableSizes = setupTables(benchmarkArgs.dataLocation) + runTpcdsQueries(queryLocation = "tpcds", queries = queriesV1_4ToRun, tableSizes) + runTpcdsQueries(queryLocation = "tpcds-v2.7.0", queries = queriesV2_7ToRun, tableSizes, + nameSuffix = "-v2.7") } } From bf4352ca6c96dfab16b286c54720685e32b216f1 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 3 May 2018 09:28:14 +0800 Subject: [PATCH 0725/2461] [SPARK-24110][THRIFT-SERVER] Avoid UGI.loginUserFromKeytab in STS ## What changes were proposed in this pull request? Spark ThriftServer will call UGI.loginUserFromKeytab twice in initialization. This is unnecessary and will cause various potential problems, like Hadoop IPC failure after 7 days, or RM failover issue and so on. So here we need to remove all the unnecessary login logics and make sure UGI in the context never be created again. Note this is actually a HS2 issue, If later on we upgrade supported Hive version, the issue may already be fixed in Hive side. ## How was this patch tested? Local verification in secure cluster. Author: jerryshao Closes #21178 from jerryshao/SPARK-24110. --- .../hive/service/auth/HiveAuthFactory.java | 62 +++++++++++++++++-- .../thriftserver/SparkSQLCLIService.scala | 20 +++++- 2 files changed, 75 insertions(+), 7 deletions(-) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java index c5ade65283045..10000f12ab329 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java @@ -18,6 +18,8 @@ package org.apache.hive.service.auth; import java.io.IOException; +import java.lang.reflect.Field; +import java.lang.reflect.Method; import java.net.InetSocketAddress; import java.net.UnknownHostException; import java.util.ArrayList; @@ -26,6 +28,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import javax.net.ssl.SSLServerSocket; import javax.security.auth.login.LoginException; @@ -92,7 +95,30 @@ public String getAuthName() { public static final String HS2_PROXY_USER = "hive.server2.proxy.user"; public static final String HS2_CLIENT_TOKEN = "hiveserver2ClientToken"; - public HiveAuthFactory(HiveConf conf) throws TTransportException { + private static Field keytabFile = null; + private static Method getKeytab = null; + static { + Class clz = UserGroupInformation.class; + try { + keytabFile = clz.getDeclaredField("keytabFile"); + keytabFile.setAccessible(true); + } catch (NoSuchFieldException nfe) { + LOG.debug("Cannot find private field \"keytabFile\" in class: " + + UserGroupInformation.class.getCanonicalName(), nfe); + keytabFile = null; + } + + try { + getKeytab = clz.getDeclaredMethod("getKeytab"); + getKeytab.setAccessible(true); + } catch(NoSuchMethodException nme) { + LOG.debug("Cannot find private method \"getKeytab\" in class:" + + UserGroupInformation.class.getCanonicalName(), nme); + getKeytab = null; + } + } + + public HiveAuthFactory(HiveConf conf) throws TTransportException, IOException { this.conf = conf; transportMode = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_TRANSPORT_MODE); authTypeStr = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_AUTHENTICATION); @@ -107,9 +133,16 @@ public HiveAuthFactory(HiveConf conf) throws TTransportException { authTypeStr = AuthTypes.NONE.getAuthName(); } if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())) { - saslServer = ShimLoader.getHadoopThriftAuthBridge() - .createServer(conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB), - conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL)); + String principal = conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL); + String keytab = conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB); + if (needUgiLogin(UserGroupInformation.getCurrentUser(), + SecurityUtil.getServerPrincipal(principal, "0.0.0.0"), keytab)) { + saslServer = ShimLoader.getHadoopThriftAuthBridge().createServer(principal, keytab); + } else { + // Using the default constructor to avoid unnecessary UGI login. + saslServer = new HadoopThriftAuthBridge.Server(); + } + // start delegation token manager try { // rawStore is only necessary for DBTokenStore @@ -362,4 +395,25 @@ public static void verifyProxyAccess(String realUser, String proxyUser, String i } } + public static boolean needUgiLogin(UserGroupInformation ugi, String principal, String keytab) { + return null == ugi || !ugi.hasKerberosCredentials() || !ugi.getUserName().equals(principal) || + !Objects.equals(keytab, getKeytabFromUgi()); + } + + private static String getKeytabFromUgi() { + synchronized (UserGroupInformation.class) { + try { + if (keytabFile != null) { + return (String) keytabFile.get(null); + } else if (getKeytab != null) { + return (String) getKeytab.invoke(UserGroupInformation.getCurrentUser()); + } else { + return null; + } + } catch (Exception e) { + LOG.debug("Fail to get keytabFile path via reflection", e); + return null; + } + } + } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index ad1f5eb9ca3a7..1335e16e35882 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -27,7 +27,7 @@ import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.shims.Utils -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.security.{SecurityUtil, UserGroupInformation} import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.hive.service.Service.STATE import org.apache.hive.service.auth.HiveAuthFactory @@ -52,8 +52,22 @@ private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, sqlContext: SQLC if (UserGroupInformation.isSecurityEnabled) { try { - HiveAuthFactory.loginFromKeytab(hiveConf) - sparkServiceUGI = Utils.getUGI() + val principal = hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL) + val keyTabFile = hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB) + if (principal.isEmpty || keyTabFile.isEmpty) { + throw new IOException( + "HiveServer2 Kerberos principal or keytab is not correctly configured") + } + + val originalUgi = UserGroupInformation.getCurrentUser + sparkServiceUGI = if (HiveAuthFactory.needUgiLogin(originalUgi, + SecurityUtil.getServerPrincipal(principal, "0.0.0.0"), keyTabFile)) { + HiveAuthFactory.loginFromKeytab(hiveConf) + Utils.getUGI() + } else { + originalUgi + } + setSuperField(this, "serviceUGI", sparkServiceUGI) } catch { case e @ (_: IOException | _: LoginException) => From c9bfd1c6f8d16890ea1e5bc2bcb654a3afb32591 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 3 May 2018 15:15:05 +0800 Subject: [PATCH 0726/2461] [SPARK-23489][SQL][TEST] HiveExternalCatalogVersionsSuite should verify the downloaded file ## What changes were proposed in this pull request? Although [SPARK-22654](https://issues.apache.org/jira/browse/SPARK-22654) made `HiveExternalCatalogVersionsSuite` download from Apache mirrors three times, it has been flaky because it didn't verify the downloaded file. Some Apache mirrors terminate the downloading abnormally, the *corrupted* file shows the following errors. ``` gzip: stdin: not in gzip format tar: Child returned status 1 tar: Error is not recoverable: exiting now 22:46:32.700 WARN org.apache.spark.sql.hive.HiveExternalCatalogVersionsSuite: ===== POSSIBLE THREAD LEAK IN SUITE o.a.s.sql.hive.HiveExternalCatalogVersionsSuite, thread names: Keep-Alive-Timer ===== *** RUN ABORTED *** java.io.IOException: Cannot run program "./bin/spark-submit" (in directory "/tmp/test-spark/spark-2.2.0"): error=2, No such file or directory ``` This has been reported weirdly in two ways. For example, the above case is reported as Case 2 `no failures`. - Case 1. [Test Result (1 failure / +1)](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.7/4389/) - Case 2. [Test Result (no failures)](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-maven-hadoop-2.6/4811/) This PR aims to make `HiveExternalCatalogVersionsSuite` more robust by verifying the downloaded `tgz` file by extracting and checking the existence of `bin/spark-submit`. If it turns out that the file is empty or corrupted, `HiveExternalCatalogVersionsSuite` will do retry logic like the download failure. ## How was this patch tested? Pass the Jenkins. Author: Dongjoon Hyun Closes #21210 from dongjoon-hyun/SPARK-23489. --- .../HiveExternalCatalogVersionsSuite.scala | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 6ca58e68d31eb..ea86ab9772bc7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -67,7 +67,21 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { logInfo(s"Downloading Spark $version from $url") try { getFileFromUrl(url, path, filename) - return + val downloaded = new File(sparkTestingDir, filename).getCanonicalPath + val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath + + Seq("mkdir", targetDir).! + val exitCode = Seq("tar", "-xzf", downloaded, "-C", targetDir, "--strip-components=1").! + Seq("rm", downloaded).! + + // For a corrupted file, `tar` returns non-zero values. However, we also need to check + // the extracted file because `tar` returns 0 for empty file. + val sparkSubmit = new File(sparkTestingDir, s"spark-$version/bin/spark-submit") + if (exitCode == 0 && sparkSubmit.exists()) { + return + } else { + Seq("rm", "-rf", targetDir).! + } } catch { case ex: Exception => logWarning(s"Failed to download Spark $version from $url", ex) } @@ -75,20 +89,6 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { fail(s"Unable to download Spark $version") } - - private def downloadSpark(version: String): Unit = { - tryDownloadSpark(version, sparkTestingDir.getCanonicalPath) - - val downloaded = new File(sparkTestingDir, s"spark-$version-bin-hadoop2.7.tgz").getCanonicalPath - val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath - - Seq("mkdir", targetDir).! - - Seq("tar", "-xzf", downloaded, "-C", targetDir, "--strip-components=1").! - - Seq("rm", downloaded).! - } - private def genDataDir(name: String): String = { new File(tmpDataDir, name).getCanonicalPath } @@ -161,7 +161,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { PROCESS_TABLES.testingVersions.zipWithIndex.foreach { case (version, index) => val sparkHome = new File(sparkTestingDir, s"spark-$version") if (!sparkHome.exists()) { - downloadSpark(version) + tryDownloadSpark(version, sparkTestingDir.getCanonicalPath) } val args = Seq( From 417ad92502e714da71552f64d0e1257d2fd5d3d0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 3 May 2018 19:27:01 +0800 Subject: [PATCH 0727/2461] [SPARK-23715][SQL] the input of to/from_utc_timestamp can not have timezone ## What changes were proposed in this pull request? `from_utc_timestamp` assumes its input is in UTC timezone and shifts it to the specified timezone. When the timestamp contains timezone(e.g. `2018-03-13T06:18:23+00:00`), Spark breaks the semantic and respect the timezone in the string. This is not what user expects and the result is different from Hive/Impala. `to_utc_timestamp` has the same problem. More details please refer to the JIRA ticket. This PR fixes this by returning null if the input timestamp contains timezone. ## How was this patch tested? new tests Author: Wenchen Fan Closes #21169 from cloud-fan/from_utc_timezone. --- docs/sql-programming-guide.md | 13 +- .../sql/catalyst/analysis/TypeCoercion.scala | 30 +++- .../expressions/datetimeExpressions.scala | 42 ++++++ .../sql/catalyst/util/DateTimeUtils.scala | 22 ++- .../apache/spark/sql/internal/SQLConf.scala | 7 + .../catalyst/analysis/TypeCoercionSuite.scala | 12 +- .../resources/sql-tests/inputs/datetime.sql | 33 +++++ .../sql-tests/results/datetime.sql.out | 135 +++++++++++++++++- .../apache/spark/sql/DateFunctionsSuite.scala | 8 ++ 9 files changed, 283 insertions(+), 19 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 836ce990205a9..075b953a0898e 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1805,12 +1805,13 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unable to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. - - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. - - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. - - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. - - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. - - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. - - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. + - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. + - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. + - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. + - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. + - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. ## Upgrading From Spark SQL 2.2 to 2.3 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 25bad28a2a209..b2817b0538a7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -59,7 +59,7 @@ object TypeCoercion { IfCoercion :: StackCoercion :: Division :: - ImplicitTypeCasts :: + new ImplicitTypeCasts(conf) :: DateTimeOperations :: WindowFrameCoercion :: Nil @@ -776,12 +776,33 @@ object TypeCoercion { /** * Casts types according to the expected input types for [[Expression]]s. */ - object ImplicitTypeCasts extends TypeCoercionRule { + class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { + + private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING) + override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e + // Special rules for `from/to_utc_timestamp`. These 2 functions assume the input timestamp + // string is in a specific timezone, so the string itself should not contain timezone. + // TODO: We should move the type coercion logic to expressions instead of a central + // place to put all the rules. + case e: FromUTCTimestamp if e.left.dataType == StringType => + if (rejectTzInString) { + e.copy(left = StringToTimestampWithoutTimezone(e.left)) + } else { + e.copy(left = Cast(e.left, TimestampType)) + } + + case e: ToUTCTimestamp if e.left.dataType == StringType => + if (rejectTzInString) { + e.copy(left = StringToTimestampWithoutTimezone(e.left)) + } else { + e.copy(left = Cast(e.left, TimestampType)) + } + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => findTightestCommonType(left.dataType, right.dataType).map { commonType => if (b.inputType.acceptsType(commonType)) { @@ -798,7 +819,7 @@ object TypeCoercion { case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => // If we cannot do the implicit cast, just use the original input. - implicitCast(in, expected).getOrElse(in) + ImplicitTypeCasts.implicitCast(in, expected).getOrElse(in) } e.withNewChildren(children) @@ -814,6 +835,9 @@ object TypeCoercion { } e.withNewChildren(children) } + } + + object ImplicitTypeCasts { /** * Given an expected data type, try to cast the expression and return the cast expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index d882d06cfd625..76aa61415a11f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1016,6 +1016,48 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S } } +/** + * A special expression used to convert the string input of `to/from_utc_timestamp` to timestamp, + * which requires the timestamp string to not have timezone information, otherwise null is returned. + */ +case class StringToTimestampWithoutTimezone(child: Expression, timeZoneId: Option[String] = None) + extends UnaryExpression with TimeZoneAwareExpression with ExpectsInputTypes { + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + override def dataType: DataType = TimestampType + override def nullable: Boolean = true + override def toString: String = child.toString + override def sql: String = child.sql + + override def nullSafeEval(input: Any): Any = { + DateTimeUtils.stringToTimestamp( + input.asInstanceOf[UTF8String], timeZone, rejectTzInString = true).orNull + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val tz = ctx.addReferenceObj("timeZone", timeZone) + val longOpt = ctx.freshName("longOpt") + val eval = child.genCode(ctx) + val code = s""" + |${eval.code} + |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = true; + |${CodeGenerator.JAVA_LONG} ${ev.value} = ${CodeGenerator.defaultValue(TimestampType)}; + |if (!${eval.isNull}) { + | scala.Option $longOpt = $dtu.stringToTimestamp(${eval.value}, $tz, true); + | if ($longOpt.isDefined()) { + | ${ev.value} = ((Long) $longOpt.get()).longValue(); + | ${ev.isNull} = false; + | } + |} + """.stripMargin + ev.copy(code = code) + } +} + /** * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index d2fe15c48c6dd..e646da0659e85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -296,10 +296,28 @@ object DateTimeUtils { * `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]+[h]h:[m]m` */ def stringToTimestamp(s: UTF8String): Option[SQLTimestamp] = { - stringToTimestamp(s, defaultTimeZone()) + stringToTimestamp(s, defaultTimeZone(), rejectTzInString = false) } def stringToTimestamp(s: UTF8String, timeZone: TimeZone): Option[SQLTimestamp] = { + stringToTimestamp(s, timeZone, rejectTzInString = false) + } + + /** + * Converts a timestamp string to microseconds from the unix epoch, w.r.t. the given timezone. + * Returns None if the input string is not a valid timestamp format. + * + * @param s the input timestamp string. + * @param timeZone the timezone of the timestamp string, will be ignored if the timestamp string + * already contains timezone information and `forceTimezone` is false. + * @param rejectTzInString if true, rejects timezone in the input string, i.e., if the + * timestamp string contains timezone, like `2000-10-10 00:00:00+00:00`, + * return None. + */ + def stringToTimestamp( + s: UTF8String, + timeZone: TimeZone, + rejectTzInString: Boolean): Option[SQLTimestamp] = { if (s == null) { return None } @@ -417,6 +435,8 @@ object DateTimeUtils { return None } + if (tz.isDefined && rejectTzInString) return None + val c = if (tz.isEmpty) { Calendar.getInstance(timeZone) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3729bd5293eca..3942240c442b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1208,6 +1208,13 @@ object SQLConf { .stringConf .createWithDefault("") + val REJECT_TIMEZONE_IN_STRING = buildConf("spark.sql.function.rejectTimezoneInString") + .internal() + .doc("If true, `to_utc_timestamp` and `from_utc_timestamp` return null if the input string " + + "contains a timezone part, e.g. `2000-10-10 00:00:00+00:00`.") + .booleanConf + .createWithDefault(true) + object PartitionOverwriteMode extends Enumeration { val STATIC, DYNAMIC = Value } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 1cc431aaf0a60..0acd3b490447d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -524,11 +524,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for expressions that implement ExpectsInputTypes") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeUnaryExpression(Literal.create(null, NullType)), AnyTypeUnaryExpression(Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeUnaryExpression(Literal.create(null, NullType)), NumericTypeUnaryExpression(Literal.create(null, DoubleType))) } @@ -536,11 +536,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for binary operators") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType))) } @@ -823,7 +823,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("type coercion for CaseKeyWhen") { - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) @@ -1275,7 +1275,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("SPARK-17117 null type coercion in divide") { - val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts) + val rules = Seq(FunctionArgumentConversion, Division, new ImplicitTypeCasts(conf)) val nullLit = Literal.create(null, NullType) ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index 547c2bef02b24..4950a4b7a4e5a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -27,3 +27,36 @@ select current_date = current_date(), current_timestamp = current_timestamp(), a select a, b from ttf2 order by a, current_date; select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15'); + +select from_utc_timestamp('2015-07-24 00:00:00', 'PST'); + +select from_utc_timestamp('2015-01-24 00:00:00', 'PST'); + +select from_utc_timestamp(null, 'PST'); + +select from_utc_timestamp('2015-07-24 00:00:00', null); + +select from_utc_timestamp(null, null); + +select from_utc_timestamp(cast(0 as timestamp), 'PST'); + +select from_utc_timestamp(cast('2015-01-24' as date), 'PST'); + +select to_utc_timestamp('2015-07-24 00:00:00', 'PST'); + +select to_utc_timestamp('2015-01-24 00:00:00', 'PST'); + +select to_utc_timestamp(null, 'PST'); + +select to_utc_timestamp('2015-07-24 00:00:00', null); + +select to_utc_timestamp(null, null); + +select to_utc_timestamp(cast(0 as timestamp), 'PST'); + +select to_utc_timestamp(cast('2015-01-24' as date), 'PST'); + +-- SPARK-23715: the input of to/from_utc_timestamp can not have timezone +select from_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST'); + +select to_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST'); diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index 4e1cfa6e48c1c..9eede305dbdcc 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 26 -- !query 0 @@ -82,9 +82,138 @@ struct 1 2 2 3 + -- !query 9 select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15') --- !query 3 schema +-- !query 9 schema struct --- !query 3 output +-- !query 9 output 5 3 5 NULL 4 + + +-- !query 10 +select from_utc_timestamp('2015-07-24 00:00:00', 'PST') +-- !query 10 schema +struct +-- !query 10 output +2015-07-23 17:00:00 + + +-- !query 11 +select from_utc_timestamp('2015-01-24 00:00:00', 'PST') +-- !query 11 schema +struct +-- !query 11 output +2015-01-23 16:00:00 + + +-- !query 12 +select from_utc_timestamp(null, 'PST') +-- !query 12 schema +struct +-- !query 12 output +NULL + + +-- !query 13 +select from_utc_timestamp('2015-07-24 00:00:00', null) +-- !query 13 schema +struct +-- !query 13 output +NULL + + +-- !query 14 +select from_utc_timestamp(null, null) +-- !query 14 schema +struct +-- !query 14 output +NULL + + +-- !query 15 +select from_utc_timestamp(cast(0 as timestamp), 'PST') +-- !query 15 schema +struct +-- !query 15 output +1969-12-31 08:00:00 + + +-- !query 16 +select from_utc_timestamp(cast('2015-01-24' as date), 'PST') +-- !query 16 schema +struct +-- !query 16 output +2015-01-23 16:00:00 + + +-- !query 17 +select to_utc_timestamp('2015-07-24 00:00:00', 'PST') +-- !query 17 schema +struct +-- !query 17 output +2015-07-24 07:00:00 + + +-- !query 18 +select to_utc_timestamp('2015-01-24 00:00:00', 'PST') +-- !query 18 schema +struct +-- !query 18 output +2015-01-24 08:00:00 + + +-- !query 19 +select to_utc_timestamp(null, 'PST') +-- !query 19 schema +struct +-- !query 19 output +NULL + + +-- !query 20 +select to_utc_timestamp('2015-07-24 00:00:00', null) +-- !query 20 schema +struct +-- !query 20 output +NULL + + +-- !query 21 +select to_utc_timestamp(null, null) +-- !query 21 schema +struct +-- !query 21 output +NULL + + +-- !query 22 +select to_utc_timestamp(cast(0 as timestamp), 'PST') +-- !query 22 schema +struct +-- !query 22 output +1970-01-01 00:00:00 + + +-- !query 23 +select to_utc_timestamp(cast('2015-01-24' as date), 'PST') +-- !query 23 schema +struct +-- !query 23 output +2015-01-24 08:00:00 + + +-- !query 24 +select from_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST') +-- !query 24 schema +struct +-- !query 24 output +NULL + + +-- !query 25 +select to_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST') +-- !query 25 schema +struct +-- !query 25 output +NULL diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index f712baa7a9134..237412aa692e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -23,6 +23,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.unsafe.types.CalendarInterval @@ -696,4 +697,11 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(Timestamp.valueOf("2015-07-25 07:00:00")))) } + test("SPARK-23715: to/from_utc_timestamp can retain the previous behavior") { + withSQLConf(SQLConf.REJECT_TIMEZONE_IN_STRING.key -> "false") { + checkAnswer( + sql("SELECT from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')"), + Row(Timestamp.valueOf("2000-10-09 18:00:00"))) + } + } } From 991b526992bcf1dc1268578b650916569b12f583 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 3 May 2018 19:56:30 +0800 Subject: [PATCH 0728/2461] [SPARK-24166][SQL] InMemoryTableScanExec should not access SQLConf at executor side ## What changes were proposed in this pull request? This PR is extracted from https://github.com/apache/spark/pull/21190 , to make it easier to backport. `InMemoryTableScanExec#createAndDecompressColumn` is executed inside `rdd.map`, we can't access `conf.offHeapColumnVectorEnabled` there. ## How was this patch tested? it's tested in #21190 Author: Wenchen Fan Closes #21223 from cloud-fan/minor1. --- .../execution/columnar/InMemoryTableScanExec.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index ea315fb71617c..0b4dd76c7d860 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -78,10 +78,12 @@ case class InMemoryTableScanExec( private lazy val columnarBatchSchema = new StructType(columnIndices.map(i => relationSchema(i))) - private def createAndDecompressColumn(cachedColumnarBatch: CachedBatch): ColumnarBatch = { + private def createAndDecompressColumn( + cachedColumnarBatch: CachedBatch, + offHeapColumnVectorEnabled: Boolean): ColumnarBatch = { val rowCount = cachedColumnarBatch.numRows val taskContext = Option(TaskContext.get()) - val columnVectors = if (!conf.offHeapColumnVectorEnabled || taskContext.isEmpty) { + val columnVectors = if (!offHeapColumnVectorEnabled || taskContext.isEmpty) { OnHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) } else { OffHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) @@ -101,10 +103,13 @@ case class InMemoryTableScanExec( private lazy val inputRDD: RDD[InternalRow] = { val buffers = filteredCachedBatches() + val offHeapColumnVectorEnabled = conf.offHeapColumnVectorEnabled if (supportsBatch) { // HACK ALERT: This is actually an RDD[ColumnarBatch]. // We're taking advantage of Scala's type erasure here to pass these batches along. - buffers.map(createAndDecompressColumn).asInstanceOf[RDD[InternalRow]] + buffers + .map(createAndDecompressColumn(_, offHeapColumnVectorEnabled)) + .asInstanceOf[RDD[InternalRow]] } else { val numOutputRows = longMetric("numOutputRows") From 96a50016bb0fb1cc57823a6706bff2467d671efd Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 3 May 2018 23:36:09 +0800 Subject: [PATCH 0729/2461] [SPARK-24169][SQL] JsonToStructs should not access SQLConf at executor side ## What changes were proposed in this pull request? This PR is extracted from #21190 , to make it easier to backport. `JsonToStructs` can be serialized to executors and evaluate, we should not call `SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)` in the body. ## How was this patch tested? tested in #21190 Author: Wenchen Fan Closes #21226 from cloud-fan/minor4. --- .../catalyst/analysis/FunctionRegistry.scala | 4 +- .../expressions/jsonExpressions.scala | 16 ++-- .../expressions/JsonExpressionsSuite.scala | 76 +++++++++---------- .../org/apache/spark/sql/functions.scala | 2 +- .../sql-tests/results/json-functions.sql.out | 4 +- 5 files changed, 54 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3ffbc9c8069fd..51bb6b0abe408 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -534,7 +534,9 @@ object FunctionRegistry { // Otherwise, find a constructor method that matches the number of arguments, and use that. val params = Seq.fill(expressions.size)(classOf[Expression]) val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse { - val validParametersCount = constructors.map(_.getParameterCount).distinct.sorted + val validParametersCount = constructors + .filter(_.getParameterTypes.forall(_ == classOf[Expression])) + .map(_.getParameterCount).distinct.sorted val expectedNumberOfParameters = if (validParametersCount.length == 1) { validParametersCount.head.toString } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index fdd672c416a03..34161f0f03f4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -514,11 +514,10 @@ case class JsonToStructs( schema: DataType, options: Map[String, String], child: Expression, - timeZoneId: Option[String] = None) + timeZoneId: Option[String], + forceNullableSchema: Boolean) extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { - val forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA) - // The JSON input data might be missing certain fields. We force the nullability // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder // can generate incorrect files if values are missing in columns declared as non-nullable. @@ -532,14 +531,21 @@ case class JsonToStructs( schema = JsonExprUtils.validateSchemaLiteral(schema), options = Map.empty[String, String], child = child, - timeZoneId = None) + timeZoneId = None, + forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) def this(child: Expression, schema: Expression, options: Expression) = this( schema = JsonExprUtils.validateSchemaLiteral(schema), options = JsonExprUtils.convertToMapData(options), child = child, - timeZoneId = None) + timeZoneId = None, + forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) + + // Used in `org.apache.spark.sql.functions` + def this(schema: DataType, options: Map[String, String], child: Expression) = + this(schema, options, child, timeZoneId = None, + forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) override def checkInputDataTypes(): TypeCheckResult = nullableSchema match { case _: StructType | ArrayType(_: StructType, _) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 7812319756eae..00e97637eee7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -392,7 +392,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val jsonData = """{"a": 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true), InternalRow(1) ) } @@ -401,13 +401,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val jsonData = """{"a" 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true), null ) // Other modes should still return `null`. checkEvaluation( - JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId), + JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId, true), null ) } @@ -416,62 +416,62 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val input = """[{"a": 1}, {"a": 2}]""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: InternalRow(2) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=object, schema=array, output=array of single row") { val input = """{"a": 1}""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=empty array, schema=array, output=empty array") { val input = "[ ]" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=empty object, schema=array, output=array of single row with null") { val input = "{ }" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(null) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=array of single object, schema=struct, output=single row") { val input = """[{"a": 1}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(1) - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=array, schema=struct, output=null") { val input = """[{"a": 1}, {"a": 2}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=empty array, schema=struct, output=null") { val input = """[]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=empty object, schema=struct, output=single row with null") { val input = """{ }""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(null) - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json null input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId), + JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId, true), null ) } @@ -479,7 +479,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("SPARK-20549: from_json bad UTF-8") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(badJson), gmtId), + JsonToStructs(schema, Map.empty, Literal(badJson), gmtId, true), null) } @@ -491,14 +491,14 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with c.set(2016, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 123) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId, true), InternalRow(c.getTimeInMillis * 1000L) ) // The result doesn't change because the json string includes timezone string ("Z" here), // which means the string represents the timestamp string in the timezone regardless of // the timeZoneId parameter. checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST")), + JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST"), true), InternalRow(c.getTimeInMillis * 1000L) ) @@ -512,7 +512,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with schema, Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"), Literal(jsonData2), - Option(tz.getID)), + Option(tz.getID), + true), InternalRow(c.getTimeInMillis * 1000L) ) checkEvaluation( @@ -521,7 +522,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", DateTimeUtils.TIMEZONE_OPTION -> tz.getID), Literal(jsonData2), - gmtId), + gmtId, + true), InternalRow(c.getTimeInMillis * 1000L) ) } @@ -530,7 +532,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("SPARK-19543: from_json empty input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId), + JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId, true), null ) } @@ -685,27 +687,23 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("from_json missing fields") { for (forceJsonNullableSchema <- Seq(false, true)) { - withSQLConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA.key -> forceJsonNullableSchema.toString) { - val input = - """{ - | "a": 1, - | "c": "foo" - |} - |""".stripMargin - val jsonSchema = new StructType() - .add("a", LongType, nullable = false) - .add("b", StringType, nullable = false) - .add("c", StringType, nullable = false) - val output = InternalRow(1L, null, UTF8String.fromString("foo")) - checkEvaluation( - JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId), - output - ) - val schema = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId) - .dataType - val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema - assert(schemaToCompare == schema) - } + val input = + """{ + | "a": 1, + | "c": "foo" + |} + |""".stripMargin + val jsonSchema = new StructType() + .add("a", LongType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + val output = InternalRow(1L, null, UTF8String.fromString("foo")) + val expr = JsonToStructs( + jsonSchema, Map.empty, Literal.create(input, StringType), gmtId, forceJsonNullableSchema) + checkEvaluation(expr, output) + val schema = expr.dataType + val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema + assert(schemaToCompare == schema) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 25afaacc38d6f..d2e22fa355514 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3179,7 +3179,7 @@ object functions { * @since 2.2.0 */ def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr { - JsonToStructs(schema, options, e.expr) + new JsonToStructs(schema, options, e.expr) } /** diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 581dddc89d0bb..14a69128ffb41 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -129,7 +129,7 @@ select to_json() struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function to_json. Expected: one of 1, 2 and 3; Found: 0; line 1 pos 7 +Invalid number of arguments for function to_json. Expected: one of 1 and 2; Found: 0; line 1 pos 7 -- !query 13 @@ -225,7 +225,7 @@ select from_json() struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function from_json. Expected: one of 2, 3 and 4; Found: 0; line 1 pos 7 +Invalid number of arguments for function from_json. Expected: one of 2 and 3; Found: 0; line 1 pos 7 -- !query 22 From 94641fe6cc68e5977dd8663b8f232a287a783acb Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 3 May 2018 10:59:18 -0500 Subject: [PATCH 0730/2461] [SPARK-23433][CORE] Late zombie task completions update all tasksets Fetch failure lead to multiple tasksets which are active for a given stage. While there is only one "active" version of the taskset, the earlier attempts can still have running tasks, which can complete successfully. So a task completion needs to update every taskset so that it knows the partition is completed. That way the final active taskset does not try to submit another task for the same partition, and so that it knows when it is completed and when it should be marked as a "zombie". Added a regression test. Author: Imran Rashid Closes #21131 from squito/SPARK-23433. --- .../spark/scheduler/TaskSchedulerImpl.scala | 14 +++ .../spark/scheduler/TaskSetManager.scala | 20 +++- .../scheduler/TaskSchedulerImplSuite.scala | 104 ++++++++++++++++++ 3 files changed, 137 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 0c11806b3981b..8e97b3da33820 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -689,6 +689,20 @@ private[spark] class TaskSchedulerImpl( } } + /** + * Marks the task has completed in all TaskSetManagers for the given stage. + * + * After stage failure and retry, there may be multiple TaskSetManagers for the stage. + * If an earlier attempt of a stage completes a task, we should ensure that the later attempts + * do not also submit those same tasks. That also means that a task completion from an earlier + * attempt can lead to the entire stage getting marked as successful. + */ + private[scheduler] def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int) = { + taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm => + tsm.markPartitionCompleted(partitionId) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 8a96a7692f614..195fc8025e4b5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -73,6 +73,8 @@ private[spark] class TaskSetManager( val ser = env.closureSerializer.newInstance() val tasks = taskSet.tasks + private[scheduler] val partitionToIndex = tasks.zipWithIndex + .map { case (t, idx) => t.partitionId -> idx }.toMap val numTasks = tasks.length val copiesRunning = new Array[Int](numTasks) @@ -153,7 +155,7 @@ private[spark] class TaskSetManager( private[scheduler] val speculatableTasks = new HashSet[Int] // Task index, start and finish time for each task attempt (indexed by task ID) - private val taskInfos = new HashMap[Long, TaskInfo] + private[scheduler] val taskInfos = new HashMap[Long, TaskInfo] // Use a MedianHeap to record durations of successful tasks so we know when to launch // speculative tasks. This is only used when speculation is enabled, to avoid the overhead @@ -754,6 +756,9 @@ private[spark] class TaskSetManager( logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id + " because task " + index + " has already completed successfully") } + // There may be multiple tasksets for this stage -- we let all of them know that the partition + // was completed. This may result in some of the tasksets getting completed. + sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId) // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not // "deserialize" the value when holding a lock to avoid blocking other threads. So we call @@ -764,6 +769,19 @@ private[spark] class TaskSetManager( maybeFinishTaskSet() } + private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = { + partitionToIndex.get(partitionId).foreach { index => + if (!successful(index)) { + tasksSuccessful += 1 + successful(index) = true + if (tasksSuccessful == numTasks) { + isZombie = true + } + maybeFinishTaskSet() + } + } + } + /** * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the * DAG Scheduler. diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 6003899bb7bef..33f2ea1c94e75 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -917,4 +917,108 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B taskScheduler.initialize(new FakeSchedulerBackend) } } + + test("Completions in zombie tasksets update status of non-zombie taskset") { + val taskScheduler = setupSchedulerWithMockTaskSetBlacklist() + val valueSer = SparkEnv.get.serializer.newInstance() + + def completeTaskSuccessfully(tsm: TaskSetManager, partition: Int): Unit = { + val indexInTsm = tsm.partitionToIndex(partition) + val matchingTaskInfo = tsm.taskAttempts.flatten.filter(_.index == indexInTsm).head + val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq()) + tsm.handleSuccessfulTask(matchingTaskInfo.taskId, result) + } + + // Submit a task set, have it fail with a fetch failed, and then re-submit the task attempt, + // two times, so we have three active task sets for one stage. (For this to really happen, + // you'd need the previous stage to also get restarted, and then succeed, in between each + // attempt, but that happens outside what we're mocking here.) + val zombieAttempts = (0 until 2).map { stageAttempt => + val attempt = FakeTask.createTaskSet(10, stageAttemptId = stageAttempt) + taskScheduler.submitTasks(attempt) + val tsm = taskScheduler.taskSetManagerForAttempt(0, stageAttempt).get + val offers = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) } + taskScheduler.resourceOffers(offers) + assert(tsm.runningTasks === 10) + // fail attempt + tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED, + FetchFailed(null, 0, 0, 0, "fetch failed")) + // the attempt is a zombie, but the tasks are still running (this could be true even if + // we actively killed those tasks, as killing is best-effort) + assert(tsm.isZombie) + assert(tsm.runningTasks === 9) + tsm + } + + // we've now got 2 zombie attempts, each with 9 tasks still active. Submit the 3rd attempt for + // the stage, but this time with insufficient resources so not all tasks are active. + + val finalAttempt = FakeTask.createTaskSet(10, stageAttemptId = 2) + taskScheduler.submitTasks(finalAttempt) + val finalTsm = taskScheduler.taskSetManagerForAttempt(0, 2).get + val offers = (0 until 5).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) } + val finalAttemptLaunchedPartitions = taskScheduler.resourceOffers(offers).flatten.map { task => + finalAttempt.tasks(task.index).partitionId + }.toSet + assert(finalTsm.runningTasks === 5) + assert(!finalTsm.isZombie) + + // We simulate late completions from our zombie tasksets, corresponding to all the pending + // partitions in our final attempt. This means we're only waiting on the tasks we've already + // launched. + val finalAttemptPendingPartitions = (0 until 10).toSet.diff(finalAttemptLaunchedPartitions) + finalAttemptPendingPartitions.foreach { partition => + completeTaskSuccessfully(zombieAttempts(0), partition) + } + + // If there is another resource offer, we shouldn't run anything. Though our final attempt + // used to have pending tasks, now those tasks have been completed by zombie attempts. The + // remaining tasks to compute are already active in the non-zombie attempt. + assert( + taskScheduler.resourceOffers(IndexedSeq(WorkerOffer("exec-1", "host-1", 1))).flatten.isEmpty) + + val remainingTasks = finalAttemptLaunchedPartitions.toIndexedSeq.sorted + + // finally, if we finish the remaining partitions from a mix of tasksets, all attempts should be + // marked as zombie. + // for each of the remaining tasks, find the tasksets with an active copy of the task, and + // finish the task. + remainingTasks.foreach { partition => + val tsm = if (partition == 0) { + // we failed this task on both zombie attempts, this one is only present in the latest + // taskset + finalTsm + } else { + // should be active in every taskset. We choose a zombie taskset just to make sure that + // we transition the active taskset correctly even if the final completion comes + // from a zombie. + zombieAttempts(partition % 2) + } + completeTaskSuccessfully(tsm, partition) + } + + assert(finalTsm.isZombie) + + // no taskset has completed all of its tasks, so no updates to the blacklist tracker yet + verify(blacklist, never).updateBlacklistForSuccessfulTaskSet(anyInt(), anyInt(), anyObject()) + + // finally, lets complete all the tasks. We simulate failures in attempt 1, but everything + // else succeeds, to make sure we get the right updates to the blacklist in all cases. + (zombieAttempts ++ Seq(finalTsm)).foreach { tsm => + val stageAttempt = tsm.taskSet.stageAttemptId + tsm.runningTasksSet.foreach { index => + if (stageAttempt == 1) { + tsm.handleFailedTask(tsm.taskInfos(index).taskId, TaskState.FAILED, TaskResultLost) + } else { + val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq()) + tsm.handleSuccessfulTask(tsm.taskInfos(index).taskId, result) + } + } + + // we update the blacklist for the stage attempts with all successful tasks. Even though + // some tasksets had failures, we still consider them all successful from a blacklisting + // perspective, as the failures weren't from a problem w/ the tasks themselves. + verify(blacklist).updateBlacklistForSuccessfulTaskSet(meq(0), meq(stageAttempt), anyObject()) + } + } } From e3201e165e41f076ec72175af246d12c0da529cf Mon Sep 17 00:00:00 2001 From: maryannxue Date: Thu, 3 May 2018 17:05:02 -0700 Subject: [PATCH 0731/2461] [SPARK-24035][SQL] SQL syntax for Pivot ## What changes were proposed in this pull request? Add SQL support for Pivot according to Pivot grammar defined by Oracle (https://docs.oracle.com/database/121/SQLRF/img_text/pivot_clause.htm) with some simplifications, based on our existing functionality and limitations for Pivot at the backend: 1. For pivot_for_clause (https://docs.oracle.com/database/121/SQLRF/img_text/pivot_for_clause.htm), the column list form is not supported, which means the pivot column can only be one single column. 2. For pivot_in_clause (https://docs.oracle.com/database/121/SQLRF/img_text/pivot_in_clause.htm), the sub-query form and "ANY" is not supported (this is only supported by Oracle for XML anyway). 3. For pivot_in_clause, aliases for the constant values are not supported. The code changes are: 1. Add parser support for Pivot. Note that according to https://docs.oracle.com/database/121/SQLRF/statements_10002.htm#i2076542, Pivot cannot be used together with lateral views in the from clause. This restriction has been implemented in the Parser rule. 2. Infer group-by expressions: group-by expressions are not explicitly specified in SQL Pivot clause and need to be deduced based on this rule: https://docs.oracle.com/database/121/SQLRF/statements_10002.htm#CHDFAFIE, so we have to post-fix it at query analysis stage. 3. Override Pivot.resolved as "false": for the reason mentioned in [2] and the fact that output attributes change after Pivot being replaced by Project or Aggregate, we avoid resolving parent references until after Pivot has been resolved and replaced. 4. Verify aggregate expressions: only aggregate expressions with or without aliases can appear in the first part of the Pivot clause, and this check is performed as analysis stage. ## How was this patch tested? A new test suite PivotSuite is added. Author: maryannxue Closes #21187 from maryannxue/spark-24035. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 12 +- .../sql/catalyst/analysis/Analyzer.scala | 35 +++- .../sql/catalyst/parser/AstBuilder.scala | 20 +- .../plans/logical/basicLogicalOperators.scala | 27 ++- .../parser/TableIdentifierParserSuite.scala | 6 +- .../spark/sql/RelationalGroupedDataset.scala | 2 +- .../test/resources/sql-tests/inputs/pivot.sql | 113 ++++++++++ .../resources/sql-tests/results/pivot.sql.out | 194 ++++++++++++++++++ 8 files changed, 386 insertions(+), 23 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/pivot.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/pivot.sql.out diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 5fa75fe348e68..f7f921ec22c35 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -398,7 +398,7 @@ hintStatement ; fromClause - : FROM relation (',' relation)* lateralView* + : FROM relation (',' relation)* (pivotClause | lateralView*)? ; aggregation @@ -413,6 +413,10 @@ groupingSet | expression ; +pivotClause + : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn=identifier IN '(' pivotValues+=constant (',' pivotValues+=constant)* ')' ')' + ; + lateralView : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)? ; @@ -725,7 +729,7 @@ nonReserved | ADD | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | LAST | FIRST | AFTER | MAP | ARRAY | STRUCT - | LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER + | PIVOT | LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | GLOBAL | TEMPORARY | OPTIONS | GROUPING | CUBE | ROLLUP @@ -745,7 +749,7 @@ nonReserved | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | RECOVER | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION | LOCAL | INPATH | ASC | DESC | LIMIT | RENAME | SETS - | AT | NULLS | OVERWRITE | ALL | ALTER | AS | BETWEEN | BY | CREATE | DELETE + | AT | NULLS | OVERWRITE | ALL | ANY | ALTER | AS | BETWEEN | BY | CREATE | DELETE | DESCRIBE | DROP | EXISTS | FALSE | FOR | GROUP | IN | INSERT | INTO | IS |LIKE | NULL | ORDER | OUTER | TABLE | TRUE | WITH | RLIKE | AND | CASE | CAST | DISTINCT | DIV | ELSE | END | FUNCTION | INTERVAL | MACRO | OR | STRATIFY | THEN @@ -760,6 +764,7 @@ FROM: 'FROM'; ADD: 'ADD'; AS: 'AS'; ALL: 'ALL'; +ANY: 'ANY'; DISTINCT: 'DISTINCT'; WHERE: 'WHERE'; GROUP: 'GROUP'; @@ -805,6 +810,7 @@ RIGHT: 'RIGHT'; FULL: 'FULL'; NATURAL: 'NATURAL'; ON: 'ON'; +PIVOT: 'PIVOT'; LATERAL: 'LATERAL'; WINDOW: 'WINDOW'; OVER: 'OVER'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e821e96522f7c..dfdcdbc1eb2c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -275,9 +275,9 @@ class Analyzer( case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) => g.copy(aggregations = assignAliases(g.aggregations)) - case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) - if child.resolved && hasUnresolvedAlias(groupByExprs) => - Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child) + case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child) + if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) => + Pivot(Some(assignAliases(groupByOpt.get)), pivotColumn, pivotValues, aggregates, child) case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) => Project(assignAliases(projectList), child) @@ -504,9 +504,20 @@ class Analyzer( object ResolvePivot extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved) - | !p.groupByExprs.forall(_.resolved) | !p.pivotColumn.resolved => p - case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => + case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved) + || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) + || !p.pivotColumn.resolved => p + case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => + // Check all aggregate expressions. + aggregates.foreach { e => + if (!isAggregateExpression(e)) { + throw new AnalysisException( + s"Aggregate expression required for pivot, found '$e'") + } + } + // Group-by expressions coming from SQL are implicit and need to be deduced. + val groupByExprs = groupByExprsOpt.getOrElse( + (child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq) val singleAgg = aggregates.size == 1 def outputName(value: Literal, aggregate: Expression): String = { val utf8Value = Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) @@ -568,16 +579,20 @@ class Analyzer( // TODO: Don't construct the physical container until after analysis. case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId) } - if (filteredAggregate.fastEquals(aggregate)) { - throw new AnalysisException( - s"Aggregate expression required for pivot, found '$aggregate'") - } Alias(filteredAggregate, outputName(value, aggregate))() } } Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child) } } + + private def isAggregateExpression(expr: Expression): Boolean = { + expr match { + case Alias(e, _) => isAggregateExpression(e) + case AggregateExpression(_, _, _, _) => true + case _ => false + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index bdc357d54a878..64eed23884584 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -503,7 +503,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val join = right.optionalMap(left)(Join(_, _, Inner, None)) withJoinRelations(join, relation) } - ctx.lateralView.asScala.foldLeft(from)(withGenerate) + if (ctx.pivotClause() != null) { + withPivot(ctx.pivotClause, from) + } else { + ctx.lateralView.asScala.foldLeft(from)(withGenerate) + } } /** @@ -614,6 +618,20 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging plan } + /** + * Add a [[Pivot]] to a logical plan. + */ + private def withPivot( + ctx: PivotClauseContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val aggregates = Option(ctx.aggregates).toSeq + .flatMap(_.namedExpression.asScala) + .map(typedVisit[Expression]) + val pivotColumn = UnresolvedAttribute.quoted(ctx.pivotColumn.getText) + val pivotValues = ctx.pivotValues.asScala.map(typedVisit[Expression]).map(Literal.apply) + Pivot(None, pivotColumn, pivotValues, aggregates, query) + } + /** * Add a [[Generate]] (Lateral View) to a logical plan. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 10df504795430..3bf32ef7884e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -686,17 +686,34 @@ case class GroupingSets( override lazy val resolved: Boolean = false } +/** + * A constructor for creating a pivot, which will later be converted to a [[Project]] + * or an [[Aggregate]] during the query analysis. + * + * @param groupByExprsOpt A sequence of group by expressions. This field should be None if coming + * from SQL, in which group by expressions are not explicitly specified. + * @param pivotColumn The pivot column. + * @param pivotValues A sequence of values for the pivot column. + * @param aggregates The aggregation expressions, each with or without an alias. + * @param child Child operator + */ case class Pivot( - groupByExprs: Seq[NamedExpression], + groupByExprsOpt: Option[Seq[NamedExpression]], pivotColumn: Expression, pivotValues: Seq[Literal], aggregates: Seq[Expression], child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match { - case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) - case _ => pivotValues.flatMap{ value => - aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)()) + override lazy val resolved = false // Pivot will be replaced after being resolved. + override def output: Seq[Attribute] = { + val pivotAgg = aggregates match { + case agg :: Nil => + pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) + case _ => + pivotValues.flatMap { value => + aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)()) + } } + groupByExprsOpt.getOrElse(Seq.empty).map(_.toAttribute) ++ pivotAgg } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index cc80a41df998d..89903c2825125 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -41,12 +41,12 @@ class TableIdentifierParserSuite extends SparkFunSuite { "sort", "sorted", "ssl", "statistics", "stored", "streamtable", "string", "struct", "tables", "tblproperties", "temporary", "terminated", "tinyint", "touch", "transactions", "unarchive", "undo", "uniontype", "unlock", "unset", "unsigned", "uri", "use", "utc", "utctimestamp", - "view", "while", "year", "work", "transaction", "write", "isolation", "level", - "snapshot", "autocommit", "all", "alter", "array", "as", "authorization", "between", "bigint", + "view", "while", "year", "work", "transaction", "write", "isolation", "level", "snapshot", + "autocommit", "all", "any", "alter", "array", "as", "authorization", "between", "bigint", "binary", "boolean", "both", "by", "create", "cube", "current_date", "current_timestamp", "cursor", "date", "decimal", "delete", "describe", "double", "drop", "exists", "external", "false", "fetch", "float", "for", "grant", "group", "grouping", "import", "in", - "insert", "int", "into", "is", "lateral", "like", "local", "none", "null", + "insert", "int", "into", "is", "pivot", "lateral", "like", "local", "none", "null", "of", "order", "out", "outer", "partition", "percent", "procedure", "range", "reads", "revoke", "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger", "true", "truncate", "update", "user", "values", "with", "regexp", "rlike", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 7147798d99533..6c2be3610ae30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -73,7 +73,7 @@ class RelationalGroupedDataset protected[sql]( case RelationalGroupedDataset.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) Dataset.ofRows( - df.sparkSession, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) + df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.logicalPlan)) } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql new file mode 100644 index 0000000000000..01dea6c81c11b --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql @@ -0,0 +1,113 @@ +create temporary view courseSales as select * from values + ("dotNET", 2012, 10000), + ("Java", 2012, 20000), + ("dotNET", 2012, 5000), + ("dotNET", 2013, 48000), + ("Java", 2013, 30000) + as courseSales(course, year, earnings); + +create temporary view years as select * from values + (2012, 1), + (2013, 2) + as years(y, s); + +-- pivot courses +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +); + +-- pivot years with no subquery +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +); + +-- pivot courses with multiple aggregations +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), avg(earnings) + FOR course IN ('dotNET', 'Java') +); + +-- pivot with no group by column +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +); + +-- pivot with no group by column and with multiple aggregations on different columns +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), min(year) + FOR course IN ('dotNET', 'Java') +); + +-- pivot on join query with multiple group by columns +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR s IN (1, 2) +); + +-- pivot on join query with multiple aggregations on different columns +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings), min(s) + FOR course IN ('dotNET', 'Java') +); + +-- pivot on join query with multiple columns in one aggregation +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings * s) + FOR course IN ('dotNET', 'Java') +); + +-- pivot with aliases and projection +SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012, 2013) +); + +-- pivot years with non-aggregate function +SELECT * FROM courseSales +PIVOT ( + abs(earnings) + FOR year IN (2012, 2013) +); + +-- pivot with unresolvable columns +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +); diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out new file mode 100644 index 0000000000000..85e3488990e20 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out @@ -0,0 +1,194 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 13 + + +-- !query 0 +create temporary view courseSales as select * from values + ("dotNET", 2012, 10000), + ("Java", 2012, 20000), + ("dotNET", 2012, 5000), + ("dotNET", 2013, 48000), + ("Java", 2013, 30000) + as courseSales(course, year, earnings) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +create temporary view years as select * from values + (2012, 1), + (2013, 2) + as years(y, s) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +) +-- !query 2 schema +struct +-- !query 2 output +2012 15000 20000 +2013 48000 30000 + + +-- !query 3 +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +) +-- !query 3 schema +struct +-- !query 3 output +Java 20000 30000 +dotNET 15000 48000 + + +-- !query 4 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), avg(earnings) + FOR course IN ('dotNET', 'Java') +) +-- !query 4 schema +struct +-- !query 4 output +2012 15000 7500.0 20000 20000.0 +2013 48000 48000.0 30000 30000.0 + + +-- !query 5 +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +) +-- !query 5 schema +struct +-- !query 5 output +63000 50000 + + +-- !query 6 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), min(year) + FOR course IN ('dotNET', 'Java') +) +-- !query 6 schema +struct +-- !query 6 output +63000 2012 50000 2012 + + +-- !query 7 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR s IN (1, 2) +) +-- !query 7 schema +struct +-- !query 7 output +Java 2012 20000 NULL +Java 2013 NULL 30000 +dotNET 2012 15000 NULL +dotNET 2013 NULL 48000 + + +-- !query 8 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings), min(s) + FOR course IN ('dotNET', 'Java') +) +-- !query 8 schema +struct +-- !query 8 output +2012 15000 1 20000 1 +2013 48000 2 30000 2 + + +-- !query 9 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings * s) + FOR course IN ('dotNET', 'Java') +) +-- !query 9 schema +struct +-- !query 9 output +2012 15000 20000 +2013 96000 60000 + + +-- !query 10 +SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012, 2013) +) +-- !query 10 schema +struct<2012_s:bigint,2013_s:bigint,2012_a:double,2013_a:double,c:string> +-- !query 10 output +15000 48000 7500.0 48000.0 dotNET +20000 30000 20000.0 30000.0 Java + + +-- !query 11 +SELECT * FROM courseSales +PIVOT ( + abs(earnings) + FOR year IN (2012, 2013) +) +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +Aggregate expression required for pivot, found 'abs(earnings#x)'; + + +-- !query 12 +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +) +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve '`year`' given input columns: [__auto_generated_subquery_name.course, __auto_generated_subquery_name.earnings]; line 4 pos 0 From e646ae67f2e793204bc819ab2b90815214c2bbf3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 3 May 2018 17:27:13 -0700 Subject: [PATCH 0732/2461] [SPARK-24168][SQL] WindowExec should not access SQLConf at executor side ## What changes were proposed in this pull request? This PR is extracted from #21190 , to make it easier to backport. `WindowExec#createBoundOrdering` is called on executor side, so we can't use `conf.sessionLocalTimezone` there. ## How was this patch tested? tested in #21190 Author: Wenchen Fan Closes #21225 from cloud-fan/minor3. --- .../spark/sql/execution/window/WindowExec.scala | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 800a2ea3f3996..626f39d9e95cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -112,9 +112,11 @@ case class WindowExec( * * @param frame to evaluate. This can either be a Row or Range frame. * @param bound with respect to the row. + * @param timeZone the session local timezone for time related calculations. * @return a bound ordering object. */ - private[this] def createBoundOrdering(frame: FrameType, bound: Expression): BoundOrdering = { + private[this] def createBoundOrdering( + frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = { (frame, bound) match { case (RowFrame, CurrentRow) => RowBoundOrdering(0) @@ -144,7 +146,7 @@ case class WindowExec( val boundExpr = (expr.dataType, boundOffset.dataType) match { case (DateType, IntegerType) => DateAdd(expr, boundOffset) case (TimestampType, CalendarIntervalType) => - TimeAdd(expr, boundOffset, Some(conf.sessionLocalTimeZone)) + TimeAdd(expr, boundOffset, Some(timeZone)) case (a, b) if a== b => Add(expr, boundOffset) } val bound = newMutableProjection(boundExpr :: Nil, child.output) @@ -197,6 +199,7 @@ case class WindowExec( // Map the groups to a (unbound) expression and frame factory pair. var numExpressions = 0 + val timeZone = conf.sessionLocalTimeZone framedFunctions.toSeq.map { case (key, (expressions, functionSeq)) => val ordinal = numExpressions @@ -237,7 +240,7 @@ case class WindowExec( new UnboundedPrecedingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, upper)) + createBoundOrdering(frameType, upper, timeZone)) } // Shrinking Frame. @@ -246,7 +249,7 @@ case class WindowExec( new UnboundedFollowingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, lower)) + createBoundOrdering(frameType, lower, timeZone)) } // Moving Frame. @@ -255,8 +258,8 @@ case class WindowExec( new SlidingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, lower), - createBoundOrdering(frameType, upper)) + createBoundOrdering(frameType, lower, timeZone), + createBoundOrdering(frameType, upper, timeZone)) } } From 0c23e254c38d4a9210939e1e1b0074278568abed Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 4 May 2018 09:27:14 +0800 Subject: [PATCH 0733/2461] [SPARK-24167][SQL] ParquetFilters should not access SQLConf at executor side ## What changes were proposed in this pull request? This PR is extracted from #21190 , to make it easier to backport. `ParquetFilters` is used in the file scan function, which is executed in executor side, so we can't call `conf.parquetFilterPushDownDate` there. ## How was this patch tested? it's tested in #21190 Author: Wenchen Fan Closes #21224 from cloud-fan/minor2. --- .../datasources/parquet/ParquetFileFormat.scala | 3 ++- .../datasources/parquet/ParquetFilters.scala | 15 +++++++-------- .../datasources/parquet/ParquetFilterSuite.scala | 10 ++++++---- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index d8f47eec952de..d1f9e11ed4225 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -342,6 +342,7 @@ class ParquetFileFormat sparkSession.sessionState.conf.parquetFilterPushDown // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) + val pushDownDate = sqlConf.parquetFilterPushDownDate (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) @@ -352,7 +353,7 @@ class ParquetFileFormat // Collects all converted Parquet filter predicates. Notice that not all predicates can be // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` // is used here. - .flatMap(ParquetFilters.createFilter(requiredSchema, _)) + .flatMap(new ParquetFilters(pushDownDate).createFilter(requiredSchema, _)) .reduceOption(FilterApi.and) } else { None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index ccc8306866d68..310626197a763 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -25,14 +25,13 @@ import org.apache.parquet.io.api.Binary import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources import org.apache.spark.sql.types._ /** * Some utility function to convert Spark data source filters to Parquet filters. */ -private[parquet] object ParquetFilters { +private[parquet] class ParquetFilters(pushDownDate: Boolean) { private def dateToDays(date: Date): SQLDate = { DateTimeUtils.fromJavaDate(date) @@ -59,7 +58,7 @@ private[parquet] object ParquetFilters { (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.eq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) @@ -85,7 +84,7 @@ private[parquet] object ParquetFilters { (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.notEq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) @@ -108,7 +107,7 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.lt( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) @@ -131,7 +130,7 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.ltEq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) @@ -154,7 +153,7 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.gt( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) @@ -177,7 +176,7 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.gtEq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 1d3476e747046..667e0b1760e3d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -55,6 +55,8 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} */ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { + private lazy val parquetFilters = new ParquetFilters(conf.parquetFilterPushDownDate) + override def beforeEach(): Unit = { super.beforeEach() // Note that there are many tests here that require record-level filtering set to be true. @@ -99,7 +101,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assert(selectedFilters.nonEmpty, "No filter is pushed down") selectedFilters.foreach { pred => - val maybeFilter = ParquetFilters.createFilter(df.schema, pred) + val maybeFilter = parquetFilters.createFilter(df.schema, pred) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`) maybeFilter.exists(_.getClass === filterClass) @@ -517,7 +519,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex lt(intColumn("a"), 10: Integer), gt(doubleColumn("c"), 1.5: java.lang.Double))) ) { - ParquetFilters.createFilter( + parquetFilters.createFilter( schema, sources.And( sources.LessThan("a", 10), @@ -525,7 +527,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } assertResult(None) { - ParquetFilters.createFilter( + parquetFilters.createFilter( schema, sources.And( sources.LessThan("a", 10), @@ -533,7 +535,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } assertResult(None) { - ParquetFilters.createFilter( + parquetFilters.createFilter( schema, sources.Not( sources.And( From 7f1b6b182e3cf3cbf29399e7bfbe03fa869e0bc8 Mon Sep 17 00:00:00 2001 From: Arun Mahadevan Date: Fri, 4 May 2018 16:02:21 +0800 Subject: [PATCH 0734/2461] [SPARK-24136][SS] Fix MemoryStreamDataReader.next to skip sleeping if record is available ## What changes were proposed in this pull request? Avoid unnecessary sleep (10 ms) in each invocation of MemoryStreamDataReader.next. ## How was this patch tested? Ran ContinuousSuite from IDE. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Arun Mahadevan Closes #21207 from arunmahadevan/memorystream. --- .../streaming/sources/ContinuousMemoryStream.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index c28919b8b729b..a8fca3c19a2d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -183,11 +183,10 @@ class ContinuousMemoryStreamDataReader( private var current: Option[Row] = None override def next(): Boolean = { - current = None + current = getRecord while (current.isEmpty) { Thread.sleep(10) - current = endpoint.askSync[Option[Row]]( - GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset))) + current = getRecord } currentOffset += 1 true @@ -199,6 +198,10 @@ class ContinuousMemoryStreamDataReader( override def getOffset: ContinuousMemoryStreamPartitionOffset = ContinuousMemoryStreamPartitionOffset(partition, currentOffset) + + private def getRecord: Option[Row] = + endpoint.askSync[Option[Row]]( + GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset))) } case class ContinuousMemoryStreamOffset(partitionNums: Map[Int, Int]) From 4d5de4d303a773b1c18c350072344bd7efca9fc4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 4 May 2018 19:20:15 +0800 Subject: [PATCH 0735/2461] [SPARK-23697][CORE] LegacyAccumulatorWrapper should define isZero correctly ## What changes were proposed in this pull request? It's possible that Accumulators of Spark 1.x may no longer work with Spark 2.x. This is because `LegacyAccumulatorWrapper.isZero` may return wrong answer if `AccumulableParam` doesn't define equals/hashCode. This PR fixes this by using reference equality check in `LegacyAccumulatorWrapper.isZero`. ## How was this patch tested? a new test Author: Wenchen Fan Closes #21229 from cloud-fan/accumulator. --- .../org/apache/spark/util/AccumulatorV2.scala | 6 ++++-- .../spark/util/AccumulatorV2Suite.scala | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 0f84ea9752cf5..2bc84953a56eb 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -486,7 +486,9 @@ class LegacyAccumulatorWrapper[R, T]( param: org.apache.spark.AccumulableParam[R, T]) extends AccumulatorV2[T, R] { private[spark] var _value = initialValue // Current value on driver - override def isZero: Boolean = _value == param.zero(initialValue) + @transient private lazy val _zero = param.zero(initialValue) + + override def isZero: Boolean = _value.asInstanceOf[AnyRef].eq(_zero.asInstanceOf[AnyRef]) override def copy(): LegacyAccumulatorWrapper[R, T] = { val acc = new LegacyAccumulatorWrapper(initialValue, param) @@ -495,7 +497,7 @@ class LegacyAccumulatorWrapper[R, T]( } override def reset(): Unit = { - _value = param.zero(initialValue) + _value = _zero } override def add(v: T): Unit = _value = param.addAccumulator(_value, v) diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala index a04644d57ed88..fe0a9a471a651 100644 --- a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala +++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import org.apache.spark._ +import org.apache.spark.serializer.JavaSerializer class AccumulatorV2Suite extends SparkFunSuite { @@ -162,4 +163,22 @@ class AccumulatorV2Suite extends SparkFunSuite { assert(acc3.isZero) assert(acc3.value === "") } + + test("LegacyAccumulatorWrapper with AccumulatorParam that has no equals/hashCode") { + class MyData(val i: Int) extends Serializable + val param = new AccumulatorParam[MyData] { + override def zero(initialValue: MyData): MyData = new MyData(0) + override def addInPlace(r1: MyData, r2: MyData): MyData = new MyData(r1.i + r2.i) + } + + val acc = new LegacyAccumulatorWrapper(new MyData(0), param) + acc.metadata = AccumulatorMetadata( + AccumulatorContext.newId(), + Some("test"), + countFailedValues = false) + AccumulatorContext.register(acc) + + val ser = new JavaSerializer(new SparkConf).newInstance() + ser.serialize(acc) + } } From d04806a23c1843a7f0dcc4fa236ed1b40ae113a5 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Fri, 4 May 2018 13:29:47 -0700 Subject: [PATCH 0736/2461] =?UTF-8?q?[SPARK-24124]=20Spark=20history=20ser?= =?UTF-8?q?ver=20should=20create=20spark.history.store.=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …path and set permissions properly ## What changes were proposed in this pull request? Spark history server should create spark.history.store.path and set permissions properly. Note createdDirectories doesn't do anything if the directories are already created. This does not stomp on the permissions if the user had manually created the directory before the history server could. ## How was this patch tested? Manually tested in a 100 node cluster. Ensured directories created with proper permissions. Ensured restarted worked apps/temp directories worked as apps were read. Author: Thomas Graves Closes #21234 from tgravescs/SPARK-24124. --- .../apache/spark/deploy/history/FsHistoryProvider.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 56db9359e033f..bf1eeb0c1bf59 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -18,6 +18,8 @@ package org.apache.spark.deploy.history import java.io.{File, FileNotFoundException, IOException} +import java.nio.file.Files +import java.nio.file.attribute.PosixFilePermissions import java.util.{Date, ServiceLoader} import java.util.concurrent.{ExecutorService, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} @@ -130,8 +132,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // Visible for testing. private[history] val listing: KVStore = storePath.map { path => - require(path.isDirectory(), s"Configured store directory ($path) does not exist.") - val dbPath = new File(path, "listing.ldb") + val perms = PosixFilePermissions.fromString("rwx------") + val dbPath = Files.createDirectories(new File(path, "listing.ldb").toPath(), + PosixFilePermissions.asFileAttribute(perms)).toFile() + val metadata = new FsHistoryProviderMetadata(CURRENT_LISTING_VERSION, AppStatusStore.CURRENT_VERSION, logDir.toString()) From af4dc50280ffcdeda208ef2dc5f8b843389732e5 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 4 May 2018 14:14:40 -0700 Subject: [PATCH 0737/2461] [SPARK-24039][SS] Do continuous processing writes with multiple compute() calls ## What changes were proposed in this pull request? Do continuous processing writes with multiple compute() calls. The current strategy (before this PR) is hacky; we just call next() on an iterator which has already returned hasNext = false, knowing that all the nodes we whitelist handle this properly. This will have to be changed before we can support more complex query plans. (In particular, I have a WIP https://github.com/jose-torres/spark/pull/13 which should be able to support aggregates in a single partition with minimal additional work.) Most of the changes here are just refactoring to accommodate the new model. The behavioral changes are: * The writer now calls prev.compute(split, context) once per epoch within the epoch loop. * ContinuousDataSourceRDD now spawns a ContinuousQueuedDataReader which is shared across multiple calls to compute() for the same partition. ## How was this patch tested? existing unit tests Author: Jose Torres Closes #21200 from jose-torres/noAggr. --- .../datasources/v2/DataSourceV2ScanExec.scala | 6 +- .../continuous/ContinuousDataSourceRDD.scala | 114 +++++++++ .../ContinuousDataSourceRDDIter.scala | 222 ------------------ .../ContinuousQueuedDataReader.scala | 211 +++++++++++++++++ .../continuous/ContinuousWriteRDD.scala | 90 +++++++ .../WriteToContinuousDataSourceExec.scala | 57 +---- .../ContinuousQueuedDataReaderSuite.scala | 167 +++++++++++++ 7 files changed, 592 insertions(+), 275 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 41bdda47c8c3e..77cb707340b0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -96,7 +96,11 @@ case class DataSourceV2ScanExec( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) .askSync[Unit](SetReaderPartitions(readerFactories.size)) - new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories) + new ContinuousDataSourceRDD( + sparkContext, + sqlContext.conf.continuousStreamingExecutorQueueSize, + sqlContext.conf.continuousStreamingExecutorPollIntervalMs, + readerFactories) .asInstanceOf[RDD[InternalRow]] case r: SupportsScanColumnarBatch if r.enableBatchRead() => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala new file mode 100644 index 0000000000000..0a3b9dcccb6c5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset} +import org.apache.spark.util.{NextIterator, ThreadUtils} + +class ContinuousDataSourceRDDPartition( + val index: Int, + val readerFactory: DataReaderFactory[UnsafeRow]) + extends Partition with Serializable { + + // This is semantically a lazy val - it's initialized once the first time a call to + // ContinuousDataSourceRDD.compute() needs to access it, so it can be shared across + // all compute() calls for a partition. This ensures that one compute() picks up where the + // previous one ended. + // We don't make it actually a lazy val because it needs input which isn't available here. + // This will only be initialized on the executors. + private[continuous] var queueReader: ContinuousQueuedDataReader = _ +} + +/** + * The bottom-most RDD of a continuous processing read task. Wraps a [[ContinuousQueuedDataReader]] + * to read from the remote source, and polls that queue for incoming rows. + * + * Note that continuous processing calls compute() multiple times, and the same + * [[ContinuousQueuedDataReader]] instance will/must be shared between each call for the same split. + */ +class ContinuousDataSourceRDD( + sc: SparkContext, + dataQueueSize: Int, + epochPollIntervalMs: Long, + @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]]) + extends RDD[UnsafeRow](sc, Nil) { + + override protected def getPartitions: Array[Partition] = { + readerFactories.zipWithIndex.map { + case (readerFactory, index) => new ContinuousDataSourceRDDPartition(index, readerFactory) + }.toArray + } + + /** + * Initialize the shared reader for this partition if needed, then read rows from it until + * it returns null to signal the end of the epoch. + */ + override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + // If attempt number isn't 0, this is a task retry, which we don't support. + if (context.attemptNumber() != 0) { + throw new ContinuousTaskRetryException() + } + + val readerForPartition = { + val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition] + if (partition.queueReader == null) { + partition.queueReader = + new ContinuousQueuedDataReader( + partition.readerFactory, context, dataQueueSize, epochPollIntervalMs) + } + + partition.queueReader + } + + new NextIterator[UnsafeRow] { + override def getNext(): UnsafeRow = { + readerForPartition.next() match { + case null => + finished = true + null + case row => row + } + } + + override def close(): Unit = {} + } + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + split.asInstanceOf[ContinuousDataSourceRDDPartition].readerFactory.preferredLocations() + } +} + +object ContinuousDataSourceRDD { + private[continuous] def getContinuousReader( + reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = { + reader match { + case r: ContinuousDataReader[UnsafeRow] => r + case wrapped: RowToUnsafeDataReader => + wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]] + case _ => + throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala deleted file mode 100644 index 06754f01657d3..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ /dev/null @@ -1,222 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.continuous - -import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit} -import java.util.concurrent.atomic.AtomicBoolean - -import scala.collection.JavaConverters._ - -import org.apache.spark._ -import org.apache.spark.internal.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset} -import org.apache.spark.util.ThreadUtils - -class ContinuousDataSourceRDD( - sc: SparkContext, - sqlContext: SQLContext, - @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]]) - extends RDD[UnsafeRow](sc, Nil) { - - private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize - private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs - - override protected def getPartitions: Array[Partition] = { - readerFactories.zipWithIndex.map { - case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) - }.toArray - } - - override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - // If attempt number isn't 0, this is a task retry, which we don't support. - if (context.attemptNumber() != 0) { - throw new ContinuousTaskRetryException() - } - - val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]] - .readerFactory.createDataReader() - - val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY) - - // This queue contains two types of messages: - // * (null, null) representing an epoch boundary. - // * (row, off) containing a data row and its corresponding PartitionOffset. - val queue = new ArrayBlockingQueue[(UnsafeRow, PartitionOffset)](dataQueueSize) - - val epochPollFailed = new AtomicBoolean(false) - val epochPollExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( - s"epoch-poll--$coordinatorId--${context.partitionId()}") - val epochPollRunnable = new EpochPollRunnable(queue, context, epochPollFailed) - epochPollExecutor.scheduleWithFixedDelay( - epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) - - // Important sequencing - we must get start offset before the data reader thread begins - val startOffset = ContinuousDataSourceRDD.getBaseReader(reader).getOffset - - val dataReaderFailed = new AtomicBoolean(false) - val dataReaderThread = new DataReaderThread(reader, queue, context, dataReaderFailed) - dataReaderThread.setDaemon(true) - dataReaderThread.start() - - context.addTaskCompletionListener(_ => { - dataReaderThread.interrupt() - epochPollExecutor.shutdown() - }) - - val epochEndpoint = EpochCoordinatorRef.get(coordinatorId, SparkEnv.get) - new Iterator[UnsafeRow] { - private val POLL_TIMEOUT_MS = 1000 - - private var currentEntry: (UnsafeRow, PartitionOffset) = _ - private var currentOffset: PartitionOffset = startOffset - private var currentEpoch = - context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - override def hasNext(): Boolean = { - while (currentEntry == null) { - if (context.isInterrupted() || context.isCompleted()) { - currentEntry = (null, null) - } - if (dataReaderFailed.get()) { - throw new SparkException("data read failed", dataReaderThread.failureReason) - } - if (epochPollFailed.get()) { - throw new SparkException("epoch poll failed", epochPollRunnable.failureReason) - } - currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS) - } - - currentEntry match { - // epoch boundary marker - case (null, null) => - epochEndpoint.send(ReportPartitionOffset( - context.partitionId(), - currentEpoch, - currentOffset)) - currentEpoch += 1 - currentEntry = null - false - // real row - case (_, offset) => - currentOffset = offset - true - } - } - - override def next(): UnsafeRow = { - if (currentEntry == null) throw new NoSuchElementException("No current row was set") - val r = currentEntry._1 - currentEntry = null - r - } - } - } - - override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readerFactory.preferredLocations() - } -} - -case class EpochPackedPartitionOffset(epoch: Long) extends PartitionOffset - -class EpochPollRunnable( - queue: BlockingQueue[(UnsafeRow, PartitionOffset)], - context: TaskContext, - failedFlag: AtomicBoolean) - extends Thread with Logging { - private[continuous] var failureReason: Throwable = _ - - private val epochEndpoint = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) - private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - override def run(): Unit = { - try { - val newEpoch = epochEndpoint.askSync[Long](GetCurrentEpoch) - for (i <- currentEpoch to newEpoch - 1) { - queue.put((null, null)) - logDebug(s"Sent marker to start epoch ${i + 1}") - } - currentEpoch = newEpoch - } catch { - case t: Throwable => - failureReason = t - failedFlag.set(true) - throw t - } - } -} - -class DataReaderThread( - reader: DataReader[UnsafeRow], - queue: BlockingQueue[(UnsafeRow, PartitionOffset)], - context: TaskContext, - failedFlag: AtomicBoolean) - extends Thread( - s"continuous-reader--${context.partitionId()}--" + - s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") { - private[continuous] var failureReason: Throwable = _ - - override def run(): Unit = { - TaskContext.setTaskContext(context) - val baseReader = ContinuousDataSourceRDD.getBaseReader(reader) - try { - while (!context.isInterrupted && !context.isCompleted()) { - if (!reader.next()) { - // Check again, since reader.next() might have blocked through an incoming interrupt. - if (!context.isInterrupted && !context.isCompleted()) { - throw new IllegalStateException( - "Continuous reader reported no elements! Reader should have blocked waiting.") - } else { - return - } - } - - queue.put((reader.get().copy(), baseReader.getOffset)) - } - } catch { - case _: InterruptedException if context.isInterrupted() => - // Continuous shutdown always involves an interrupt; do nothing and shut down quietly. - - case t: Throwable => - failureReason = t - failedFlag.set(true) - // Don't rethrow the exception in this thread. It's not needed, and the default Spark - // exception handler will kill the executor. - } finally { - reader.close() - } - } -} - -object ContinuousDataSourceRDD { - private[continuous] def getBaseReader(reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = { - reader match { - case r: ContinuousDataReader[UnsafeRow] => r - case wrapped: RowToUnsafeDataReader => - wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]] - case _ => - throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala new file mode 100644 index 0000000000000..01a999f6505fc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.io.Closeable +import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean + +import scala.util.control.NonFatal + +import org.apache.spark.{Partition, SparkEnv, SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} +import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset +import org.apache.spark.util.ThreadUtils + +/** + * A wrapper for a continuous processing data reader, including a reading queue and epoch markers. + * + * This will be instantiated once per partition - successive calls to compute() in the + * [[ContinuousDataSourceRDD]] will reuse the same reader. This is required to get continuity of + * offsets across epochs. Each compute() should call the next() method here until null is returned. + */ +class ContinuousQueuedDataReader( + factory: DataReaderFactory[UnsafeRow], + context: TaskContext, + dataQueueSize: Int, + epochPollIntervalMs: Long) extends Closeable { + private val reader = factory.createDataReader() + + // Important sequencing - we must get our starting point before the provider threads start running + private var currentOffset: PartitionOffset = + ContinuousDataSourceRDD.getContinuousReader(reader).getOffset + private var currentEpoch: Long = + context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + /** + * The record types in the read buffer. + */ + sealed trait ContinuousRecord + case object EpochMarker extends ContinuousRecord + case class ContinuousRow(row: UnsafeRow, offset: PartitionOffset) extends ContinuousRecord + + private val queue = new ArrayBlockingQueue[ContinuousRecord](dataQueueSize) + + private val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY) + private val epochCoordEndpoint = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) + + private val epochMarkerExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( + s"epoch-poll--$coordinatorId--${context.partitionId()}") + private val epochMarkerGenerator = new EpochMarkerGenerator + epochMarkerExecutor.scheduleWithFixedDelay( + epochMarkerGenerator, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) + + private val dataReaderThread = new DataReaderThread + dataReaderThread.setDaemon(true) + dataReaderThread.start() + + context.addTaskCompletionListener(_ => { + this.close() + }) + + private def shouldStop() = { + context.isInterrupted() || context.isCompleted() + } + + /** + * Return the next UnsafeRow to be read in the current epoch, or null if the epoch is done. + * + * After returning null, the [[ContinuousDataSourceRDD]] compute() for the following epoch + * will call next() again to start getting rows. + */ + def next(): UnsafeRow = { + val POLL_TIMEOUT_MS = 1000 + var currentEntry: ContinuousRecord = null + + while (currentEntry == null) { + if (shouldStop()) { + // Force the epoch to end here. The writer will notice the context is interrupted + // or completed and not start a new one. This makes it possible to achieve clean + // shutdown of the streaming query. + // TODO: The obvious generalization of this logic to multiple stages won't work. It's + // invalid to send an epoch marker from the bottom of a task if all its child tasks + // haven't sent one. + currentEntry = EpochMarker + } else { + if (dataReaderThread.failureReason != null) { + throw new SparkException("Data read failed", dataReaderThread.failureReason) + } + if (epochMarkerGenerator.failureReason != null) { + throw new SparkException( + "Epoch marker generation failed", + epochMarkerGenerator.failureReason) + } + currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS) + } + } + + currentEntry match { + case EpochMarker => + epochCoordEndpoint.send(ReportPartitionOffset( + context.partitionId(), currentEpoch, currentOffset)) + currentEpoch += 1 + null + case ContinuousRow(row, offset) => + currentOffset = offset + row + } + } + + override def close(): Unit = { + dataReaderThread.interrupt() + epochMarkerExecutor.shutdown() + } + + /** + * The data component of [[ContinuousQueuedDataReader]]. Pushes (row, offset) to the queue when + * a new row arrives to the [[DataReader]]. + */ + class DataReaderThread extends Thread( + s"continuous-reader--${context.partitionId()}--" + + s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") with Logging { + @volatile private[continuous] var failureReason: Throwable = _ + + override def run(): Unit = { + TaskContext.setTaskContext(context) + val baseReader = ContinuousDataSourceRDD.getContinuousReader(reader) + try { + while (!shouldStop()) { + if (!reader.next()) { + // Check again, since reader.next() might have blocked through an incoming interrupt. + if (!shouldStop()) { + throw new IllegalStateException( + "Continuous reader reported no elements! Reader should have blocked waiting.") + } else { + return + } + } + + queue.put(ContinuousRow(reader.get().copy(), baseReader.getOffset)) + } + } catch { + case _: InterruptedException => + // Continuous shutdown always involves an interrupt; do nothing and shut down quietly. + logInfo(s"shutting down interrupted data reader thread $getName") + + case NonFatal(t) => + failureReason = t + logWarning("data reader thread failed", t) + // If we throw from this thread, we may kill the executor. Let the parent thread handle + // it. + + case t: Throwable => + failureReason = t + throw t + } finally { + reader.close() + } + } + } + + /** + * The epoch marker component of [[ContinuousQueuedDataReader]]. Populates the queue with + * EpochMarker when a new epoch marker arrives. + */ + class EpochMarkerGenerator extends Runnable with Logging { + @volatile private[continuous] var failureReason: Throwable = _ + + private val epochCoordEndpoint = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) + // Note that this is *not* the same as the currentEpoch in [[ContinuousDataQueuedReader]]! That + // field represents the epoch wrt the data being processed. The currentEpoch here is just a + // counter to ensure we send the appropriate number of markers if we fall behind the driver. + private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + override def run(): Unit = { + try { + val newEpoch = epochCoordEndpoint.askSync[Long](GetCurrentEpoch) + // It's possible to fall more than 1 epoch behind if a GetCurrentEpoch RPC ends up taking + // a while. We catch up by injecting enough epoch markers immediately to catch up. This will + // result in some epochs being empty for this partition, but that's fine. + for (i <- currentEpoch to newEpoch - 1) { + queue.put(EpochMarker) + logDebug(s"Sent marker to start epoch ${i + 1}") + } + currentEpoch = newEpoch + } catch { + case t: Throwable => + failureReason = t + throw t + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala new file mode 100644 index 0000000000000..91f1576581511 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark.{Partition, SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.util.Utils + +/** + * The RDD writing to a sink in continuous processing. + * + * Within each task, we repeatedly call prev.compute(). Each resulting iterator contains the data + * to be written for one epoch, which we commit and forward to the driver. + * + * We keep repeating prev.compute() and writing new epochs until the query is shut down. + */ +class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactory[InternalRow]) + extends RDD[Unit](prev) { + + override val partitioner = prev.partitioner + + override def getPartitions: Array[Partition] = prev.partitions + + override def compute(split: Partition, context: TaskContext): Iterator[Unit] = { + val epochCoordinator = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + SparkEnv.get) + var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + while (!context.isInterrupted() && !context.isCompleted()) { + var dataWriter: DataWriter[InternalRow] = null + // write the data and commit this writer. + Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + try { + val dataIterator = prev.compute(split, context) + dataWriter = writeTask.createDataWriter( + context.partitionId(), context.attemptNumber(), currentEpoch) + while (dataIterator.hasNext) { + dataWriter.write(dataIterator.next()) + } + logInfo(s"Writer for partition ${context.partitionId()} " + + s"in epoch $currentEpoch is committing.") + val msg = dataWriter.commit() + epochCoordinator.send( + CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) + ) + logInfo(s"Writer for partition ${context.partitionId()} " + + s"in epoch $currentEpoch committed.") + currentEpoch += 1 + } catch { + case _: InterruptedException => + // Continuous shutdown always involves an interrupt. Just finish the task. + } + })(catchBlock = { + // If there is an error, abort this writer. We enter this callback in the middle of + // rethrowing an exception, so compute() will stop executing at this point. + logError(s"Writer for partition ${context.partitionId()} is aborting.") + if (dataWriter != null) dataWriter.abort() + logError(s"Writer for partition ${context.partitionId()} aborted.") + }) + } + + Iterator() + } + + override def clearDependencies() { + super.clearDependencies() + prev = null + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala index ba88ae1af469a..e0af3a2f1b85d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -46,24 +46,19 @@ case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPla case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) } - val rdd = query.execute() + val rdd = new ContinuousWriteRDD(query.execute(), writerFactory) logInfo(s"Start processing data source writer: $writer. " + - s"The input RDD has ${rdd.getNumPartitions} partitions.") - // Let the epoch coordinator know how many partitions the write RDD has. + s"The input RDD has ${rdd.partitions.length} partitions.") EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), - sparkContext.env) + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + sparkContext.env) .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) try { // Force the RDD to run so continuous processing starts; no data is actually being collected // to the driver, as ContinuousWriteRDD outputs nothing. - sparkContext.runJob( - rdd, - (context: TaskContext, iter: Iterator[InternalRow]) => - WriteToContinuousDataSourceExec.run(writerFactory, context, iter), - rdd.partitions.indices) + rdd.collect() } catch { case _: InterruptedException => // Interruption is how continuous queries are ended, so accept and ignore the exception. @@ -80,45 +75,3 @@ case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPla sparkContext.emptyRDD } } - -object WriteToContinuousDataSourceExec extends Logging { - def run( - writeTask: DataWriterFactory[InternalRow], - context: TaskContext, - iter: Iterator[InternalRow]): Unit = { - val epochCoordinator = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), - SparkEnv.get) - var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - do { - var dataWriter: DataWriter[InternalRow] = null - // write the data and commit this writer. - Utils.tryWithSafeFinallyAndFailureCallbacks(block = { - try { - dataWriter = writeTask.createDataWriter( - context.partitionId(), context.attemptNumber(), currentEpoch) - while (iter.hasNext) { - dataWriter.write(iter.next()) - } - logInfo(s"Writer for partition ${context.partitionId()} is committing.") - val msg = dataWriter.commit() - logInfo(s"Writer for partition ${context.partitionId()} committed.") - epochCoordinator.send( - CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) - ) - currentEpoch += 1 - } catch { - case _: InterruptedException => - // Continuous shutdown always involves an interrupt. Just finish the task. - } - })(catchBlock = { - // If there is an error, abort this writer. We enter this callback in the middle of - // rethrowing an exception, so runContinuous will stop executing at this point. - logError(s"Writer for partition ${context.partitionId()} is aborting.") - if (dataWriter != null) dataWriter.abort() - logError(s"Writer for partition ${context.partitionId()} aborted.") - }) - } while (!context.isInterrupted()) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala new file mode 100644 index 0000000000000..e755625d09e0f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.continuous + +import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} + +import org.mockito.{ArgumentCaptor, Matchers} +import org.mockito.Mockito._ +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark.{SparkEnv, SparkFunSuite, TaskContext} +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.{DataType, IntegerType} + +class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { + case class LongPartitionOffset(offset: Long) extends PartitionOffset + + val coordinatorId = s"${getClass.getSimpleName}-epochCoordinatorIdForUnitTest" + val startEpoch = 0 + + var epochEndpoint: RpcEndpointRef = _ + + override def beforeEach(): Unit = { + super.beforeEach() + epochEndpoint = EpochCoordinatorRef.create( + mock[StreamWriter], + mock[ContinuousReader], + mock[ContinuousExecution], + coordinatorId, + startEpoch, + spark, + SparkEnv.get) + } + + override def afterEach(): Unit = { + SparkEnv.get.rpcEnv.stop(epochEndpoint) + epochEndpoint = null + super.afterEach() + } + + + private val mockContext = mock[TaskContext] + when(mockContext.getLocalProperty(ContinuousExecution.START_EPOCH_KEY)) + .thenReturn(startEpoch.toString) + when(mockContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)) + .thenReturn(coordinatorId) + + /** + * Set up a ContinuousQueuedDataReader for testing. The blocking queue can be used to send + * rows to the wrapped data reader. + */ + private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = { + val queue = new ArrayBlockingQueue[UnsafeRow](1024) + val factory = new DataReaderFactory[UnsafeRow] { + override def createDataReader() = new ContinuousDataReader[UnsafeRow] { + var index = -1 + var curr: UnsafeRow = _ + + override def next() = { + curr = queue.take() + index += 1 + true + } + + override def get = curr + + override def getOffset = LongPartitionOffset(index) + + override def close() = {} + } + } + val reader = new ContinuousQueuedDataReader( + factory, + mockContext, + dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize, + epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs) + + (queue, reader) + } + + private def unsafeRow(value: Int) = { + UnsafeProjection.create(Array(IntegerType : DataType))( + new GenericInternalRow(Array(value: Any))) + } + + test("basic data read") { + val (input, reader) = setup() + + input.add(unsafeRow(12345)) + assert(reader.next().getInt(0) == 12345) + } + + test("basic epoch marker") { + val (input, reader) = setup() + + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + } + + test("new rows after markers") { + val (input, reader) = setup() + + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + assert(reader.next() == null) + assert(reader.next() == null) + input.add(unsafeRow(11111)) + input.add(unsafeRow(22222)) + assert(reader.next().getInt(0) == 11111) + assert(reader.next().getInt(0) == 22222) + } + + test("new markers after rows") { + val (input, reader) = setup() + + input.add(unsafeRow(11111)) + input.add(unsafeRow(22222)) + assert(reader.next().getInt(0) == 11111) + assert(reader.next().getInt(0) == 22222) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + assert(reader.next() == null) + assert(reader.next() == null) + } + + test("alternating markers and rows") { + val (input, reader) = setup() + + input.add(unsafeRow(11111)) + assert(reader.next().getInt(0) == 11111) + input.add(unsafeRow(22222)) + assert(reader.next().getInt(0) == 22222) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + input.add(unsafeRow(33333)) + assert(reader.next().getInt(0) == 33333) + input.add(unsafeRow(44444)) + assert(reader.next().getInt(0) == 44444) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + } +} From 47b5b68528c154d32b3f40f388918836d29462b8 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 4 May 2018 16:35:24 -0700 Subject: [PATCH 0738/2461] [SPARK-24157][SS] Enabled no-data batches in MicroBatchExecution for streaming aggregation and deduplication. ## What changes were proposed in this pull request? This PR enables the MicroBatchExecution to run no-data batches if some SparkPlan requires running another batch to output results based on updated watermark / processing time. In this PR, I have enabled streaming aggregations and streaming deduplicates to automatically run addition batch even if new data is available. See https://issues.apache.org/jira/browse/SPARK-24156 for more context. Major changes/refactoring done in this PR. - Refactoring MicroBatchExecution - A major point of confusion in MicroBatchExecution control flow was always (at least to me) was that `populateStartOffsets` internally called `constructNextBatch` which was not obvious from just the name "populateStartOffsets" and made the control flow from the main trigger execution loop very confusing (main loop in `runActivatedStream` called `constructNextBatch` but only if `populateStartOffsets` hadn't already called it). Instead, the refactoring makes it cleaner. - `populateStartOffsets` only the updates `availableOffsets` and `committedOffsets`. Does not call `constructNextBatch`. - Main loop in `runActivatedStream` calls `constructNextBatch` which returns true or false reflecting whether the next batch is ready for executing. This method is now idempotent; if a batch has already been constructed, then it will always return true until the batch has been executed. - If next batch is ready then we call `runBatch` or sleep. - That's it. - Refactoring watermark management logic - This has been refactored out from `MicroBatchExecution` in a separate class to simplify `MicroBatchExecution`. - New method `shouldRunAnotherBatch` in `IncrementalExecution` - This returns true if there is any stateful operation in the last execution plan that requires another batch for state cleanup, etc. This is used to decide whether to construct a batch or not in `constructNextBatch`. - Changes to stream testing framework - Many tests used CheckLastBatch to validate answers. This assumed that there will be no more batches after the last set of input has been processed, so the last batch is the one that has output corresponding to the last input. This is not true anymore. To account for that, I made two changes. - `CheckNewAnswer` is a new test action that verifies the new rows generated since the last time the answer was checked by `CheckAnswer`, `CheckNewAnswer` or `CheckLastBatch`. This is agnostic to how many batches occurred between the last check and now. To do make this easier, I added a common trait between MemorySink and MemorySinkV2 to abstract out some common methods. - `assertNumStateRows` has been updated in the same way to be agnostic to batches while checking what the total rows and how many state rows were updated (sums up updates since the last check). ## How was this patch tested? - Changes made to existing tests - Tests have been changed in one of the following patterns. - Tests where the last input was given again to force another batch to be executed and state cleaned up / output generated, they were simplified by removing the extra input. - Tests using aggregation+watermark where CheckLastBatch were replaced with CheckNewAnswer to make them batch agnostic. - New tests added to check whether the flag works for streaming aggregation and deduplication Author: Tathagata Das Closes #21220 from tdas/SPARK-24157. --- .../apache/spark/sql/internal/SQLConf.scala | 11 + .../streaming/IncrementalExecution.scala | 10 + .../streaming/MicroBatchExecution.scala | 231 ++++++++---------- .../streaming/WatermarkTracker.scala | 73 ++++++ .../sql/execution/streaming/memory.scala | 17 +- .../streaming/sources/memoryV2.scala | 8 +- .../streaming/statefulOperators.scala | 16 ++ .../sources/ForeachWriterSuite.scala | 8 +- .../streaming/EventTimeWatermarkSuite.scala | 112 ++++----- .../sql/streaming/FileStreamSinkSuite.scala | 7 +- .../sql/streaming/StateStoreMetricsTest.scala | 52 +++- .../spark/sql/streaming/StreamTest.scala | 56 ++++- ...cala => StreamingDeduplicationSuite.scala} | 94 +++---- .../sql/streaming/StreamingJoinSuite.scala | 18 +- .../sql/streaming/StreamingQuerySuite.scala | 2 +- 15 files changed, 450 insertions(+), 265 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala rename sql/core/src/test/scala/org/apache/spark/sql/streaming/{DeduplicateSuite.scala => StreamingDeduplicationSuite.scala} (80%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3942240c442b2..895e150756567 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -919,6 +919,14 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(10000L) + val STREAMING_NO_DATA_MICRO_BATCHES_ENABLED = + buildConf("spark.sql.streaming.noDataMicroBatchesEnabled") + .doc( + "Whether streaming micro-batch engine will execute batches without data " + + "for eager state management for stateful streaming queries.") + .booleanConf + .createWithDefault(true) + val STREAMING_METRICS_ENABLED = buildConf("spark.sql.streaming.metricsEnabled") .doc("Whether Dropwizard/Codahale metrics will be reported for active streaming queries.") @@ -1313,6 +1321,9 @@ class SQLConf extends Serializable with Logging { def streamingNoDataProgressEventInterval: Long = getConf(STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL) + def streamingNoDataMicroBatchesEnabled: Boolean = + getConf(STREAMING_NO_DATA_MICRO_BATCHES_ENABLED) + def streamingMetricsEnabled: Boolean = getConf(STREAMING_METRICS_ENABLED) def streamingProgressRetention: Int = getConf(STREAMING_PROGRESS_RETENTION) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 1a83c884d55bd..c480b96626f84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -143,4 +143,14 @@ class IncrementalExecution( /** No need assert supported, as this check has already been done */ override def assertSupported(): Unit = { } + + /** + * Should the MicroBatchExecution run another batch based on this execution and the current + * updated metadata. + */ + def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + executedPlan.collect { + case p: StateStoreWriter => p.shouldRunAnotherBatch(newMetadata) + }.exists(_ == true) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 6e231970f4a22..6709e7052f005 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -61,6 +61,8 @@ class MicroBatchExecution( case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger") } + private val watermarkTracker = new WatermarkTracker() + override lazy val logicalPlan: LogicalPlan = { assert(queryExecutionThread eq Thread.currentThread, "logicalPlan must be initialized in QueryExecutionThread " + @@ -128,40 +130,55 @@ class MicroBatchExecution( * Repeatedly attempts to run batches as data arrives. */ protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { - triggerExecutor.execute(() => { - startTrigger() + val noDataBatchesEnabled = + sparkSessionForStream.sessionState.conf.streamingNoDataMicroBatchesEnabled + + triggerExecutor.execute(() => { if (isActive) { + var currentBatchIsRunnable = false // Whether the current batch is runnable / has been run + var currentBatchHasNewData = false // Whether the current batch had new data + + startTrigger() + reportTimeTaken("triggerExecution") { + // We'll do this initialization only once every start / restart if (currentBatchId < 0) { - // We'll do this initialization only once populateStartOffsets(sparkSessionForStream) - sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) - logDebug(s"Stream running from $committedOffsets to $availableOffsets") - } else { - constructNextBatch() + logInfo(s"Stream started from $committedOffsets") } - if (dataAvailable) { - currentStatus = currentStatus.copy(isDataAvailable = true) - updateStatusMessage("Processing new data") + + // Set this before calling constructNextBatch() so any Spark jobs executed by sources + // while getting new data have the correct description + sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) + + // Try to construct the next batch. This will return true only if the next batch is + // ready and runnable. Note that the current batch may be runnable even without + // new data to process as `constructNextBatch` may decide to run a batch for + // state cleanup, etc. `isNewDataAvailable` will be updated to reflect whether new data + // is available or not. + currentBatchIsRunnable = constructNextBatch(noDataBatchesEnabled) + + // Remember whether the current batch has data or not. This will be required later + // for bookkeeping after running the batch, when `isNewDataAvailable` will have changed + // to false as the batch would have already processed the available data. + currentBatchHasNewData = isNewDataAvailable + + currentStatus = currentStatus.copy(isDataAvailable = isNewDataAvailable) + if (currentBatchIsRunnable) { + if (currentBatchHasNewData) updateStatusMessage("Processing new data") + else updateStatusMessage("No new data but cleaning up state") runBatch(sparkSessionForStream) + } else { + updateStatusMessage("Waiting for data to arrive") } } - // Report trigger as finished and construct progress object. - finishTrigger(dataAvailable) - if (dataAvailable) { - // Update committed offsets. - commitLog.add(currentBatchId) - committedOffsets ++= availableOffsets - logDebug(s"batch ${currentBatchId} committed") - // We'll increase currentBatchId after we complete processing current batch's data - currentBatchId += 1 - sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) - } else { - currentStatus = currentStatus.copy(isDataAvailable = false) - updateStatusMessage("Waiting for data to arrive") - Thread.sleep(pollingDelayMs) - } + + finishTrigger(currentBatchHasNewData) // Must be outside reportTimeTaken so it is recorded + + // If the current batch has been executed, then increment the batch id, else there was + // no data to execute the batch + if (currentBatchIsRunnable) currentBatchId += 1 else Thread.sleep(pollingDelayMs) } updateStatusMessage("Waiting for next trigger") isActive @@ -211,6 +228,7 @@ class MicroBatchExecution( OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf) offsetSeqMetadata = OffsetSeqMetadata( metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) + watermarkTracker.setWatermark(metadata.batchWatermarkMs) } /* identify the current batch id: if commit log indicates we successfully processed the @@ -235,7 +253,6 @@ class MicroBatchExecution( currentBatchId = latestCommittedBatchId + 1 committedOffsets ++= availableOffsets // Construct a new batch be recomputing availableOffsets - constructNextBatch() } else if (latestCommittedBatchId < latestBatchId - 1) { logWarning(s"Batch completion log latest batch id is " + s"${latestCommittedBatchId}, which is not trailing " + @@ -243,19 +260,18 @@ class MicroBatchExecution( } case None => logInfo("no commit log present") } - logDebug(s"Resuming at batch $currentBatchId with committed offsets " + + logInfo(s"Resuming at batch $currentBatchId with committed offsets " + s"$committedOffsets and available offsets $availableOffsets") case None => // We are starting this stream for the first time. logInfo(s"Starting new streaming query.") currentBatchId = 0 - constructNextBatch() } } /** * Returns true if there is any new data available to be processed. */ - private def dataAvailable: Boolean = { + private def isNewDataAvailable: Boolean = { availableOffsets.exists { case (source, available) => committedOffsets @@ -266,93 +282,63 @@ class MicroBatchExecution( } /** - * Queries all of the sources to see if any new data is available. When there is new data the - * batchId counter is incremented and a new log entry is written with the newest offsets. + * Attempts to construct a batch according to: + * - Availability of new data + * - Need for timeouts and state cleanups in stateful operators + * + * Returns true only if the next batch should be executed. + * + * Here is the high-level logic on how this constructs the next batch. + * - Check each source whether new data is available + * - Updated the query's metadata and check using the last execution whether there is any need + * to run another batch (for state clean up, etc.) + * - If either of the above is true, then construct the next batch by committing to the offset + * log that range of offsets that the next batch will process. */ - private def constructNextBatch(): Unit = { - // Check to see what new data is available. - val hasNewData = { - awaitProgressLock.lock() - try { - // Generate a map from each unique source to the next available offset. - val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map { - case s: Source => - updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("getOffset") { - (s, s.getOffset) - } - case s: MicroBatchReader => - updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("setOffsetRange") { - // Once v1 streaming source execution is gone, we can refactor this away. - // For now, we set the range here to get the source to infer the available end offset, - // get that offset, and then set the range again when we later execute. - s.setOffsetRange( - toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), - Optional.empty()) - } - - val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() } - (s, Option(currentOffset)) - }.toMap - availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) - - if (dataAvailable) { - true - } else { - noNewData = true - false + private def constructNextBatch(noDataBatchesEnables: Boolean): Boolean = withProgressLocked { + // If new data is already available that means this method has already been called before + // and it must have already committed the offset range of next batch to the offset log. + // Hence do nothing, just return true. + if (isNewDataAvailable) return true + + // Generate a map from each unique source to the next available offset. + val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map { + case s: Source => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("getOffset") { + (s, s.getOffset) } - } finally { - awaitProgressLock.unlock() - } - } - if (hasNewData) { - var batchWatermarkMs = offsetSeqMetadata.batchWatermarkMs - // Update the eventTime watermarks if we find any in the plan. - if (lastExecution != null) { - lastExecution.executedPlan.collect { - case e: EventTimeWatermarkExec => e - }.zipWithIndex.foreach { - case (e, index) if e.eventTimeStats.value.count > 0 => - logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") - val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs - val prevWatermarkMs = watermarkMsMap.get(index) - if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { - watermarkMsMap.put(index, newWatermarkMs) - } - - // Populate 0 if we haven't seen any data yet for this watermark node. - case (_, index) => - if (!watermarkMsMap.isDefinedAt(index)) { - watermarkMsMap.put(index, 0) - } + case s: MicroBatchReader => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("setOffsetRange") { + // Once v1 streaming source execution is gone, we can refactor this away. + // For now, we set the range here to get the source to infer the available end offset, + // get that offset, and then set the range again when we later execute. + s.setOffsetRange( + toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), + Optional.empty()) } - // Update the global watermark to the minimum of all watermark nodes. - // This is the safest option, because only the global watermark is fault-tolerant. Making - // it the minimum of all individual watermarks guarantees it will never advance past where - // any individual watermark operator would be if it were in a plan by itself. - if(!watermarkMsMap.isEmpty) { - val newWatermarkMs = watermarkMsMap.minBy(_._2)._2 - if (newWatermarkMs > batchWatermarkMs) { - logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") - batchWatermarkMs = newWatermarkMs - } else { - logDebug( - s"Event time didn't move: $newWatermarkMs < " + - s"$batchWatermarkMs") - } - } - } - offsetSeqMetadata = offsetSeqMetadata.copy( - batchWatermarkMs = batchWatermarkMs, - batchTimestampMs = triggerClock.getTimeMillis()) // Current batch timestamp in milliseconds + val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() } + (s, Option(currentOffset)) + }.toMap + availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) + + // Update the query metadata + offsetSeqMetadata = offsetSeqMetadata.copy( + batchWatermarkMs = watermarkTracker.currentWatermark, + batchTimestampMs = triggerClock.getTimeMillis()) + + // Check whether next batch should be constructed + val lastExecutionRequiresAnotherBatch = noDataBatchesEnables && + Option(lastExecution).exists(_.shouldRunAnotherBatch(offsetSeqMetadata)) + val shouldConstructNextBatch = isNewDataAvailable || lastExecutionRequiresAnotherBatch + if (shouldConstructNextBatch) { + // Commit the next batch offset range to the offset log updateStatusMessage("Writing offsets to log") reportTimeTaken("walCommit") { - assert(offsetLog.add( - currentBatchId, + assert(offsetLog.add(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata)), s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId") logInfo(s"Committed offsets for batch $currentBatchId. " + @@ -373,7 +359,7 @@ class MicroBatchExecution( reader.commit(reader.deserializeOffset(off.json)) } } else { - throw new IllegalStateException(s"batch $currentBatchId doesn't exist") + throw new IllegalStateException(s"batch ${currentBatchId - 1} doesn't exist") } } @@ -384,15 +370,12 @@ class MicroBatchExecution( commitLog.purge(currentBatchId - minLogEntriesToMaintain) } } + noNewData = false } else { - awaitProgressLock.lock() - try { - // Wake up any threads that are waiting for the stream to progress. - awaitProgressLockCondition.signalAll() - } finally { - awaitProgressLock.unlock() - } + noNewData = true + awaitProgressLockCondition.signalAll() } + shouldConstructNextBatch } /** @@ -400,6 +383,8 @@ class MicroBatchExecution( * @param sparkSessionToRunBatch Isolated [[SparkSession]] to run this batch with. */ private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = { + logDebug(s"Running batch $currentBatchId") + // Request unprocessed data from all sources. newData = reportTimeTaken("getBatch") { availableOffsets.flatMap { @@ -513,17 +498,17 @@ class MicroBatchExecution( } } - awaitProgressLock.lock() - try { - // Wake up any threads that are waiting for the stream to progress. + withProgressLocked { + commitLog.add(currentBatchId) + committedOffsets ++= availableOffsets awaitProgressLockCondition.signalAll() - } finally { - awaitProgressLock.unlock() } + watermarkTracker.updateWatermark(lastExecution.executedPlan) + logDebug(s"Completed batch ${currentBatchId}") } /** Execute a function while locking the stream from making an progress */ - private[sql] def withProgressLocked(f: => Unit): Unit = { + private[sql] def withProgressLocked[T](f: => T): T = { awaitProgressLock.lock() try { f diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala new file mode 100644 index 0000000000000..80865669558dd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.SparkPlan + +class WatermarkTracker extends Logging { + private val operatorToWatermarkMap = mutable.HashMap[Int, Long]() + private var watermarkMs: Long = 0 + private var updated = false + + def setWatermark(newWatermarkMs: Long): Unit = synchronized { + watermarkMs = newWatermarkMs + } + + def updateWatermark(executedPlan: SparkPlan): Unit = synchronized { + val watermarkOperators = executedPlan.collect { + case e: EventTimeWatermarkExec => e + } + if (watermarkOperators.isEmpty) return + + + watermarkOperators.zipWithIndex.foreach { + case (e, index) if e.eventTimeStats.value.count > 0 => + logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") + val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs + val prevWatermarkMs = operatorToWatermarkMap.get(index) + if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { + operatorToWatermarkMap.put(index, newWatermarkMs) + } + + // Populate 0 if we haven't seen any data yet for this watermark node. + case (_, index) => + if (!operatorToWatermarkMap.isDefinedAt(index)) { + operatorToWatermarkMap.put(index, 0) + } + } + + // Update the global watermark to the minimum of all watermark nodes. + // This is the safest option, because only the global watermark is fault-tolerant. Making + // it the minimum of all individual watermarks guarantees it will never advance past where + // any individual watermark operator would be if it were in a plan by itself. + val newWatermarkMs = operatorToWatermarkMap.minBy(_._2)._2 + if (newWatermarkMs > watermarkMs) { + logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") + watermarkMs = newWatermarkMs + updated = true + } else { + logDebug(s"Event time didn't move: $newWatermarkMs < $watermarkMs") + updated = false + } + } + + def currentWatermark: Long = synchronized { watermarkMs } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 628923d367ce7..22258274c70c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -222,11 +222,20 @@ class MemoryStreamDataReaderFactory(records: Array[UnsafeRow]) } } +/** A common trait for MemorySinks with methods used for testing */ +trait MemorySinkBase extends BaseStreamingSink { + def allData: Seq[Row] + def latestBatchData: Seq[Row] + def dataSinceBatch(sinceBatchId: Long): Seq[Row] + def latestBatchId: Option[Long] +} + /** * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink with Logging { +class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink + with MemorySinkBase with Logging { private case class AddedData(batchId: Long, data: Array[Row]) @@ -236,7 +245,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { - batches.map(_.data).flatten + batches.flatMap(_.data) } def latestBatchId: Option[Long] = synchronized { @@ -245,6 +254,10 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi def latestBatchData: Seq[Row] = synchronized { batches.lastOption.toSeq.flatten(_.data) } + def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized { + batches.filter(_.batchId > sinceBatchId).flatMap(_.data) + } + def toDebugString: String = synchronized { batches.map { case AddedData(batchId, data) => val dataStr = try data.mkString(" ") catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index d871d37ad37c1..0d6c239274dd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} -import org.apache.spark.sql.execution.streaming.Sink +import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter @@ -39,7 +39,7 @@ import org.apache.spark.sql.types.StructType * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { +class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkBase with Logging { override def createStreamWriter( queryId: String, schema: StructType, @@ -67,6 +67,10 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { batches.lastOption.toSeq.flatten(_.data) } + def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized { + batches.filter(_.batchId > sinceBatchId).flatMap(_.data) + } + def toDebugString: String = synchronized { batches.map { case AddedData(batchId, data) => val dataStr = try data.mkString(" ") catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index c9354ac0ec78a..1691a6320a526 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -126,6 +126,12 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => name -> SQLMetrics.createTimingMetric(sparkContext, desc) }.toMap } + + /** + * Should the MicroBatchExecution run another batch based on this stateful operator and the + * current updated metadata. + */ + def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = false } /** An operator that supports watermark. */ @@ -388,6 +394,12 @@ case class StateStoreSaveExec( ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } } + + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + (outputMode.contains(Append) || outputMode.contains(Update)) && + eventTimeWatermark.isDefined && + newMetadata.batchWatermarkMs > eventTimeWatermark.get + } } /** Physical operator for executing streaming Deduplicate. */ @@ -454,6 +466,10 @@ case class StreamingDeduplicateExec( override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning + + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + } } object StreamingDeduplicateExec { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala index 03bf71b3f4b78..e60c339bc9cc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala @@ -211,14 +211,12 @@ class ForeachWriterSuite extends StreamTest with SharedSQLContext with BeforeAnd try { inputData.addData(10, 11, 12) query.processAllAvailable() - inputData.addData(25) // Advance watermark to 15 seconds - query.processAllAvailable() inputData.addData(25) // Evict items less than previous watermark query.processAllAvailable() // There should be 3 batches and only does the last batch contain a value. val allEvents = ForeachWriterSuite.allEvents() - assert(allEvents.size === 3) + assert(allEvents.size === 4) val expectedEvents = Seq( Seq( ForeachWriterSuite.Open(partition = 0, version = 0), @@ -230,6 +228,10 @@ class ForeachWriterSuite extends StreamTest with SharedSQLContext with BeforeAnd ), Seq( ForeachWriterSuite.Open(partition = 0, version = 2), + ForeachWriterSuite.Close(None) + ), + Seq( + ForeachWriterSuite.Open(partition = 0, version = 3), ForeachWriterSuite.Process(value = 3), ForeachWriterSuite.Close(None) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index d6bef9ce07379..7e8fde1ff8e56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matchers with Logging { @@ -137,20 +138,12 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche assert(e.get("watermark") === formatTimestamp(5)) }, AddData(inputData2, 25), - CheckAnswer(), - assertEventStats { e => - assert(e.get("max") === formatTimestamp(25)) - assert(e.get("min") === formatTimestamp(25)) - assert(e.get("avg") === formatTimestamp(25)) - assert(e.get("watermark") === formatTimestamp(5)) - }, - AddData(inputData2, 25), CheckAnswer((10, 3)), assertEventStats { e => assert(e.get("max") === formatTimestamp(25)) assert(e.get("min") === formatTimestamp(25)) assert(e.get("avg") === formatTimestamp(25)) - assert(e.get("watermark") === formatTimestamp(15)) + assert(e.get("watermark") === formatTimestamp(5)) } ) } @@ -167,15 +160,12 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(windowedAggregation)( AddData(inputData, 10, 11, 12, 13, 14, 15), - CheckLastBatch(), + CheckNewAnswer(), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch(), - assertNumStateRows(3), - AddData(inputData, 25), // Emit items less than watermark and drop their state - CheckLastBatch((10, 5)), + CheckNewAnswer((10, 5)), assertNumStateRows(2), AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckLastBatch(), + CheckNewAnswer(), assertNumStateRows(2) ) } @@ -193,15 +183,15 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(windowedAggregation, OutputMode.Update)( AddData(inputData, 10, 11, 12, 13, 14, 15), - CheckLastBatch((10, 5), (15, 1)), + CheckNewAnswer((10, 5), (15, 1)), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch((25, 1)), - assertNumStateRows(3), + CheckNewAnswer((25, 1)), + assertNumStateRows(2), AddData(inputData, 10, 25), // Ignore 10 as its less than watermark - CheckLastBatch((25, 2)), + CheckNewAnswer((25, 2)), assertNumStateRows(2), AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckLastBatch(), + CheckNewAnswer(), assertNumStateRows(2) ) } @@ -251,56 +241,25 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(df)( AddData(inputData, 10, 11, 12, 13, 14, 15), - CheckLastBatch(), + CheckAnswer(), AddData(inputData, 25), // Advance watermark to 15 seconds - StopStream, - StartStream(), - CheckLastBatch(), - AddData(inputData, 25), // Evict items less than previous watermark. - CheckLastBatch((10, 5)), + CheckAnswer((10, 5)), StopStream, AssertOnQuery { q => // purge commit and clear the sink - val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) + 1L + val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) q.commitLog.purge(commit) q.sink.asInstanceOf[MemorySink].clear() true }, StartStream(), - CheckLastBatch((10, 5)), // Recompute last batch and re-evict timestamp 10 - AddData(inputData, 30), // Advance watermark to 20 seconds - CheckLastBatch(), + AddData(inputData, 10, 27, 30), // Advance watermark to 20 seconds, 10 should be ignored + CheckAnswer((15, 1)), StopStream, - StartStream(), // Watermark should still be 15 seconds - AddData(inputData, 17), - CheckLastBatch(), // We still do not see next batch - AddData(inputData, 30), // Advance watermark to 20 seconds - CheckLastBatch(), - AddData(inputData, 30), // Evict items less than previous watermark. - CheckLastBatch((15, 2)) // Ensure we see next window - ) - } - - test("dropping old data") { - val inputData = MemoryStream[Int] - - val windowedAggregation = inputData.toDF() - .withColumn("eventTime", $"value".cast("timestamp")) - .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) - - testStream(windowedAggregation)( - AddData(inputData, 10, 11, 12), - CheckAnswer(), - AddData(inputData, 25), // Advance watermark to 15 seconds - CheckAnswer(), - AddData(inputData, 25), // Evict items less than previous watermark. - CheckAnswer((10, 3)), - AddData(inputData, 10), // 10 is later than 15 second watermark - CheckAnswer((10, 3)), - AddData(inputData, 25), - CheckAnswer((10, 3)) // Should not emit an incorrect partial result. + StartStream(), + AddData(inputData, 17), // Watermark should still be 20 seconds, 17 should be ignored + CheckAnswer((15, 1)), + AddData(inputData, 40), // Advance watermark to 30 seconds, emit first data 25 + CheckNewAnswer((25, 2)) ) } @@ -421,8 +380,6 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche AddData(inputData, 10), CheckAnswer(), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckAnswer(), - AddData(inputData, 25), // Evict items less than previous watermark. CheckAnswer((10, 1)) ) } @@ -501,8 +458,35 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche } } + test("test no-data flag") { + val flagKey = SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key + + def testWithFlag(flag: Boolean): Unit = withClue(s"with $flagKey = $flag") { + val inputData = MemoryStream[Int] + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + StartStream(additionalConfs = Map(flagKey -> flag.toString)), + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckNewAnswer(), + AddData(inputData, 25), // Advance watermark to 15 seconds + // Check if there is new answer if flag is set, no new answer otherwise + if (flag) CheckNewAnswer((10, 5)) else CheckNewAnswer() + ) + } + + testWithFlag(true) + testWithFlag(false) + } + private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q => - val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get + q.processAllAvailable() + val progressWithData = q.recentProgress.lastOption.get assert(progressWithData.stateOperators(0).numRowsTotal === numTotalRows) true } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index cf41d7e0e4fe1..ed53def556cb8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -279,13 +279,10 @@ class FileStreamSinkSuite extends StreamTest { check() // nothing emitted yet addTimestamp(104, 123) // watermark = 90 before this, watermark = 123 - 10 = 113 after this - check() // nothing emitted yet + check((100L, 105L) -> 2L) // no-data-batch emits results on 100-105, addTimestamp(140) // wm = 113 before this, emit results on 100-105, wm = 130 after this - check((100L, 105L) -> 2L) - - addTimestamp(150) // wm = 130s before this, emit results on 120-125, wm = 150 after this - check((100L, 105L) -> 2L, (120L, 125L) -> 1L) + check((100L, 105L) -> 2L, (120L, 125L) -> 1L) // no-data-batch emits results on 120-125 } finally { if (query != null) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala index 368c4604dfca8..e45f9d3e2e97b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala @@ -17,20 +17,58 @@ package org.apache.spark.sql.streaming +import org.apache.spark.sql.execution.streaming.StreamExecution + trait StateStoreMetricsTest extends StreamTest { + private var lastCheckedRecentProgressIndex = -1 + private var lastQuery: StreamExecution = null + + override def beforeEach(): Unit = { + super.beforeEach() + lastCheckedRecentProgressIndex = -1 + } + def assertNumStateRows(total: Seq[Long], updated: Seq[Long]): AssertOnQuery = AssertOnQuery(s"Check total state rows = $total, updated state rows = $updated") { q => - val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get - assert( - progressWithData.stateOperators.map(_.numRowsTotal) === total, - "incorrect total rows") - assert( - progressWithData.stateOperators.map(_.numRowsUpdated) === updated, - "incorrect updates rows") + val recentProgress = q.recentProgress + require(recentProgress.nonEmpty, "No progress made, cannot check num state rows") + require(recentProgress.length < spark.sessionState.conf.streamingProgressRetention, + "This test assumes that all progresses are present in q.recentProgress but " + + "some may have been dropped due to retention limits") + + if (q.ne(lastQuery)) lastCheckedRecentProgressIndex = -1 + lastQuery = q + + val numStateOperators = recentProgress.last.stateOperators.length + val progressesSinceLastCheck = recentProgress + .slice(lastCheckedRecentProgressIndex + 1, recentProgress.length) + .filter(_.stateOperators.length == numStateOperators) + + val allNumUpdatedRowsSinceLastCheck = + progressesSinceLastCheck.map(_.stateOperators.map(_.numRowsUpdated)) + + lazy val debugString = "recent progresses:\n" + + progressesSinceLastCheck.map(_.prettyJson).mkString("\n\n") + + val numTotalRows = recentProgress.last.stateOperators.map(_.numRowsTotal) + assert(numTotalRows === total, s"incorrect total rows, $debugString") + + val numUpdatedRows = arraySum(allNumUpdatedRowsSinceLastCheck, numStateOperators) + assert(numUpdatedRows === updated, s"incorrect updates rows, $debugString") + + lastCheckedRecentProgressIndex = recentProgress.length - 1 true } def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = assertNumStateRows(Seq(total), Seq(updated)) + + def arraySum(arraySeq: Seq[Array[Long]], arrayLength: Int): Seq[Long] = { + if (arraySeq.isEmpty) return Seq.fill(arrayLength)(0L) + + assert(arraySeq.forall(_.length == arrayLength), + "Arrays are of different lengths:\n" + arraySeq.map(_.toSeq).mkString("\n")) + (0 until arrayLength).map { index => arraySeq.map(_.apply(index)).sum } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index af0268fa47871..9d139a927bea5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -192,7 +193,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be case class CheckAnswerRowsContains(expectedAnswer: Seq[Row], lastOnly: Boolean = false) extends StreamAction with StreamMustBeRunning { override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" - private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer" + private def operatorName = if (lastOnly) "CheckLastBatchContains" else "CheckAnswerContains" } case class CheckAnswerRowsByFunc( @@ -202,6 +203,23 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc" } + case class CheckNewAnswerRows(expectedAnswer: Seq[Row]) + extends StreamAction with StreamMustBeRunning { + override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" + + private def operatorName = "CheckNewAnswer" + } + + object CheckNewAnswer { + def apply(): CheckNewAnswerRows = CheckNewAnswerRows(Seq.empty) + + def apply[A: Encoder](data: A, moreData: A*): CheckNewAnswerRows = { + val encoder = encoderFor[A] + val toExternalRow = RowEncoder(encoder.schema).resolveAndBind() + CheckNewAnswerRows((data +: moreData).map(d => toExternalRow.fromRow(encoder.toRow(d)))) + } + } + /** Stops the stream. It must currently be running. */ case object StopStream extends StreamAction with StreamMustBeRunning @@ -435,13 +453,24 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be """.stripMargin) } - def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = { + var lastFetchedMemorySinkLastBatchId: Long = -1 + + def fetchStreamAnswer( + currentStream: StreamExecution, + lastOnly: Boolean = false, + sinceLastFetchOnly: Boolean = false) = { + verify( + !(lastOnly && sinceLastFetchOnly), "both lastOnly and sinceLastFetchOnly cannot be true") verify(currentStream != null, "stream not running") // Block until all data added has been processed for all the source awaiting.foreach { case (sourceIndex, offset) => failAfter(streamingTimeout) { currentStream.awaitOffset(sourceIndex, offset) + // Make sure all processing including no-data-batches have been executed + if (!currentStream.triggerClock.isInstanceOf[StreamManualClock]) { + currentStream.processAllAvailable() + } } } @@ -463,14 +492,21 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } } - val (latestBatchData, allData) = sink match { - case s: MemorySink => (s.latestBatchData, s.allData) - case s: MemorySinkV2 => (s.latestBatchData, s.allData) - } - try if (lastOnly) latestBatchData else allData catch { + val rows = try { + if (sinceLastFetchOnly) { + if (sink.latestBatchId.getOrElse(-1L) < lastFetchedMemorySinkLastBatchId) { + failTest("MemorySink was probably cleared since last fetch. Use CheckAnswer instead.") + } + sink.dataSinceBatch(lastFetchedMemorySinkLastBatchId) + } else { + if (lastOnly) sink.latestBatchData else sink.allData + } + } catch { case e: Exception => failTest("Exception while getting data from sink", e) } + lastFetchedMemorySinkLastBatchId = sink.latestBatchId.getOrElse(-1L) + rows } def executeAction(action: StreamAction): Unit = { @@ -704,6 +740,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } catch { case e: Throwable => failTest(e.toString) } + + case CheckNewAnswerRows(expectedAnswer) => + val sparkAnswer = fetchStreamAnswer(currentStream, sinceLastFetchOnly = true) + QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach { + error => failTest(error) + } } pos += 1 } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala similarity index 80% rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala index 0088b64d6195e..42ffd472eb843 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala @@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplicateExec} import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf -class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { +class StreamingDeduplicationSuite extends StateStoreMetricsTest with BeforeAndAfterAll { import testImplicits._ @@ -97,28 +98,20 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { testStream(result, Append)( AddData(inputData, (1 to 5).flatMap(_ => (10 to 15)): _*), - CheckLastBatch(10 to 15: _*), + CheckAnswer(10 to 15: _*), assertNumStateRows(total = 6, updated = 6), - AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch(25), - assertNumStateRows(total = 7, updated = 1), - - AddData(inputData, 25), // Drop states less than watermark - CheckLastBatch(), - assertNumStateRows(total = 1, updated = 0), + AddData(inputData, 25), // Advance watermark to 15 secs, no-data-batch drops rows <= 15 + CheckNewAnswer(25), + assertNumStateRows(total = 1, updated = 1), AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckLastBatch(), + CheckNewAnswer(), assertNumStateRows(total = 1, updated = 0), - AddData(inputData, 45), // Advance watermark to 35 seconds - CheckLastBatch(45), - assertNumStateRows(total = 2, updated = 1), - - AddData(inputData, 45), // Drop states less than watermark - CheckLastBatch(), - assertNumStateRows(total = 1, updated = 0) + AddData(inputData, 45), // Advance watermark to 35 seconds, no-data-batch drops row 25 + CheckNewAnswer(45), + assertNumStateRows(total = 1, updated = 1) ) } @@ -141,33 +134,20 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { assertNumStateRows(total = Seq(2L, 6L), updated = Seq(2L, 6L)), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch(), - // states in aggregate in [10, 14), [15, 20) and [25, 30) (3 windows) - // states in deduplicate is 10 to 15 and 25 - assertNumStateRows(total = Seq(3L, 7L), updated = Seq(1L, 1L)), - - AddData(inputData, 25), // Emit items less than watermark and drop their state - CheckLastBatch((10 -> 5)), // 5 items (10 to 14) after deduplicate - // states in aggregate in [15, 20) and [25, 30) (2 windows, note aggregate uses the end of - // window to evict items, so [15, 20) is still in the state store) - // states in deduplicate is 25 - assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)), + CheckLastBatch((10 -> 5)), // 5 items (10 to 14) after deduplicate, emitted with no-data-batch + // states in aggregate in [15, 20) and [25, 30); no-data-batch removed [10, 14) + // states in deduplicate is 25, no-data-batch removed 10 to 14 + assertNumStateRows(total = Seq(2L, 1L), updated = Seq(1L, 1L)), AddData(inputData, 10), // Should not emit anything as data less than watermark CheckLastBatch(), assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)), AddData(inputData, 40), // Advance watermark to 30 seconds - CheckLastBatch(), - // states in aggregate in [15, 20), [25, 30) and [40, 45) - // states in deduplicate is 25 and 40, - assertNumStateRows(total = Seq(3L, 2L), updated = Seq(1L, 1L)), - - AddData(inputData, 40), // Emit items less than watermark and drop their state CheckLastBatch((15 -> 1), (25 -> 1)), - // states in aggregate in [40, 45) - // states in deduplicate is 40, - assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L)) + // states in aggregate is [40, 45); no-data-batch removed [15, 20) and [25, 30) + // states in deduplicate is 40; no-data-batch removed 25 + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)) ) } @@ -260,13 +240,13 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { .select($"id") testStream(df)( AddData(input, 1 -> 1, 1 -> 1, 1 -> 2), - CheckLastBatch(1, 2), + CheckAnswer(1, 2), AddData(input, 1 -> 1, 2 -> 3, 2 -> 4), - CheckLastBatch(3, 4), + CheckNewAnswer(3, 4), AddData(input, 1 -> 0, 1 -> 1, 3 -> 5, 3 -> 6), // Drop (1 -> 0, 1 -> 1) due to watermark - CheckLastBatch(5, 6), + CheckNewAnswer(5, 6), AddData(input, 1 -> 0, 4 -> 7), // Drop (1 -> 0) due to watermark - CheckLastBatch(7) + CheckNewAnswer(7) ) } @@ -279,7 +259,37 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { .select($"id", $"time".cast("long")) testStream(df)( AddData(input, 1 -> 1, 1 -> 2, 2 -> 2), - CheckLastBatch(1 -> 1, 2 -> 2) + CheckAnswer(1 -> 1, 2 -> 2) ) } + + test("test no-data flag") { + val flagKey = SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key + + def testWithFlag(flag: Boolean): Unit = withClue(s"with $flagKey = $flag") { + val inputData = MemoryStream[Int] + val result = inputData.toDS() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .dropDuplicates() + .select($"eventTime".cast("long").as[Long]) + + testStream(result, Append)( + StartStream(additionalConfs = Map(flagKey -> flag.toString)), + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckAnswer(10, 11, 12, 13, 14, 15), + assertNumStateRows(total = 6, updated = 6), + + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckNewAnswer(25), + { // State should have been cleaned if flag is set, otherwise should not have been cleaned + if (flag) assertNumStateRows(total = 1, updated = 1) + else assertNumStateRows(total = 7, updated = 1) + } + ) + } + + testWithFlag(true) + testWithFlag(false) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 11bdd13942dcb..da8f9608c1e9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -192,7 +192,7 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with CheckLastBatch((1, 5, 11)), AddData(rightInput, (1, 10)), CheckLastBatch(), // no match as neither 5, nor 10 from leftTime is less than rightTime 10 - 5 - assertNumStateRows(total = 3, updated = 1), + assertNumStateRows(total = 3, updated = 3), // Increase event time watermark to 20s by adding data with time = 30s on both inputs AddData(leftInput, (1, 3), (1, 30)), @@ -276,14 +276,14 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with CheckAnswer(), AddData(rightInput, (1, 14), (1, 15), (1, 25), (1, 26), (1, 30), (1, 31)), CheckLastBatch((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)), - assertNumStateRows(total = 7, updated = 6), + assertNumStateRows(total = 7, updated = 7), // If rightTime = 60, then it matches only leftTime = [50, 65] AddData(rightInput, (1, 60)), CheckLastBatch(), // matches with nothing on the left AddData(leftInput, (1, 49), (1, 50), (1, 65), (1, 66)), CheckLastBatch((1, 50, 60), (1, 65, 60)), - assertNumStateRows(total = 12, updated = 4), + assertNumStateRows(total = 12, updated = 5), // Event time watermark = min(left: 66 - delay 20 = 46, right: 60 - delay 30 = 30) = 30 // Left state value watermark = 30 - 10 = slightly less than 20 (since condition has <=) @@ -573,7 +573,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with // nulls won't show up until the next batch after the watermark advances. MultiAddData(leftInput, 21)(rightInput, 22), CheckLastBatch(), - assertNumStateRows(total = 12, updated = 2), + assertNumStateRows(total = 12, updated = 12), AddData(leftInput, 22), CheckLastBatch(Row(22, 30, 44, 66), Row(1, 10, 2, null), Row(2, 10, 4, null)), assertNumStateRows(total = 3, updated = 1) @@ -591,7 +591,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with // nulls won't show up until the next batch after the watermark advances. MultiAddData(leftInput, 21)(rightInput, 22), CheckLastBatch(), - assertNumStateRows(total = 12, updated = 2), + assertNumStateRows(total = 12, updated = 12), AddData(leftInput, 22), CheckLastBatch(Row(22, 30, 44, 66), Row(6, 10, null, 18), Row(7, 10, null, 21)), assertNumStateRows(total = 3, updated = 1) @@ -630,7 +630,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with CheckLastBatch((1, 1, 5, 10)), AddData(rightInput, (1, 11)), CheckLastBatch(), // no match as left time is too low - assertNumStateRows(total = 5, updated = 1), + assertNumStateRows(total = 5, updated = 5), // Increase event time watermark to 20s by adding data with time = 30s on both inputs AddData(leftInput, (1, 7), (1, 30)), @@ -668,7 +668,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), - assertNumStateRows(total = 5, updated = 2), + assertNumStateRows(total = 5, updated = 5), // 1...3 added, but 20 and 21 not added AddData(rightInput, 20), CheckLastBatch( Row(20, 30, 40, 60)), @@ -678,7 +678,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)), MultiAddData(leftInput, 70)(rightInput, 71), CheckLastBatch(), - assertNumStateRows(total = 6, updated = 2), + assertNumStateRows(total = 6, updated = 6), // all inputs added since last check AddData(rightInput, 70), CheckLastBatch((70, 80, 140, 210)), assertNumStateRows(total = 3, updated = 1), @@ -687,7 +687,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with CheckLastBatch(), MultiAddData(leftInput, 1000)(rightInput, 1001), CheckLastBatch(), - assertNumStateRows(total = 8, updated = 2), + assertNumStateRows(total = 8, updated = 5), // 101...103 added, but 1000 and 1001 not added AddData(rightInput, 1000), CheckLastBatch( Row(1000, 1010, 2000, 3000), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 390d67d1feb27..0cb2375e0a49a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -334,7 +334,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.sources.length === 1) assert(progress.sources(0).description contains "MemoryStream") - assert(progress.sources(0).startOffset === null) + assert(progress.sources(0).startOffset === "0") assert(progress.sources(0).endOffset !== null) assert(progress.sources(0).processedRowsPerSecond === 4.0) // 2 rows processed in 500 ms From dd4b1b9c7ccad3363a6a21524aed047fcd282f68 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 6 May 2018 10:25:01 +0800 Subject: [PATCH 0739/2461] [SPARK-24185][SPARKR][SQL] add flatten function to SparkR ## What changes were proposed in this pull request? add array flatten function to SparkR ## How was this patch tested? Unit tests were added in R/pkg/tests/fulltests/test_sparkSQL.R Author: Huaxin Gao Closes #21244 from huaxingao/spark-24185. --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 14 ++++++++++++++ R/pkg/R/generics.R | 4 ++++ R/pkg/tests/fulltests/test_sparkSQL.R | 6 ++++++ 4 files changed, 25 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index f36d462a83cb0..8cd00352d1956 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -258,6 +258,7 @@ exportMethods("%<=>%", "expr", "factorial", "first", + "flatten", "floor", "format_number", "format_string", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index ec4bd4e73c7e5..0ec99d19e21e4 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -208,6 +208,7 @@ NULL #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) #' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) #' head(select(tmp, array_position(tmp$v1, 21))) +#' head(select(tmp, flatten(tmp$v1))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) @@ -3035,6 +3036,19 @@ setMethod("array_position", column(jc) }) +#' @details +#' \code{flatten}: Transforms an array of arrays into a single array. +#' +#' @rdname column_collection_functions +#' @aliases flatten flatten,Column-method +#' @note flatten since 2.4.0 +setMethod("flatten", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "flatten", x@jc) + column(jc) + }) + #' @details #' \code{map_keys}: Returns an unordered array containing the keys of the map. #' diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 562d3399ee9c8..4ef12d19b3575 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -918,6 +918,10 @@ setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) #' @name NULL setGeneric("expr", function(x) { standardGeneric("expr") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("flatten", function(x) { standardGeneric("flatten") }) + #' @rdname column_datetime_diff_functions #' @name NULL setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 8cc2db7a140f9..3a8866bf2a88a 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1502,6 +1502,12 @@ test_that("column functions", { result <- collect(select(df, sort_array(df[[1]])))[[1]] expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) + # Test flattern + df <- createDataFrame(list(list(list(list(1L, 2L), list(3L, 4L))), + list(list(list(5L, 6L), list(7L, 8L))))) + result <- collect(select(df, flatten(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L, 4L), list(5L, 6L, 7L, 8L))) + # Test map_keys(), map_values() and element_at() df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) result <- collect(select(df, map_keys(df$map)))[[1]] From f38ea00e83099a5ae8d3afdec2e896e43c2db612 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 6 May 2018 20:41:32 -0700 Subject: [PATCH 0740/2461] [SPARK-24017][SQL] Refactor ExternalCatalog to be an interface ## What changes were proposed in this pull request? This refactors the external catalog to be an interface. It can be easier for the future work in the catalog federation. After the refactoring, `ExternalCatalog` is much cleaner without mixing the listener event generation logic. ## How was this patch tested? The existing tests Author: gatorsmile Closes #21122 from gatorsmile/refactorExternalCatalog. --- .../catalyst/catalog/ExternalCatalog.scala | 134 ++------ .../catalog/ExternalCatalogWithListener.scala | 298 ++++++++++++++++++ .../catalyst/catalog/InMemoryCatalog.scala | 26 +- .../catalog/ExternalCatalogEventSuite.scala | 2 +- .../spark/sql/internal/SharedState.scala | 9 +- .../sql/hive/thriftserver/SparkSQLEnv.scala | 2 +- .../spark/sql/hive/HiveExternalCatalog.scala | 26 +- .../spark/sql/hive/HiveSessionCatalog.scala | 4 +- .../sql/hive/HiveSessionStateBuilder.scala | 7 +- .../sql/hive/execution/SaveAsHiveFile.scala | 2 +- .../apache/spark/sql/hive/test/TestHive.scala | 11 +- .../HiveExternalSessionCatalogSuite.scala | 2 +- .../sql/hive/HiveSchemaInferenceSuite.scala | 3 +- .../sql/hive/HiveSessionStateSuite.scala | 3 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 5 +- .../spark/sql/hive/ShowCreateTableSuite.scala | 3 +- .../spark/sql/hive/client/VersionsSuite.scala | 4 +- .../sql/hive/execution/HiveDDLSuite.scala | 6 +- .../sql/hive/execution/SQLQuerySuite.scala | 3 +- .../sql/hive/test/TestHiveSingleton.scala | 2 +- 20 files changed, 384 insertions(+), 168 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 45b4f013620c1..1a145c24d78cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.catalyst.catalog -import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException} +import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchPartitionException, NoSuchTableException} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.StructType -import org.apache.spark.util.ListenerBus /** * Interface for the system catalog (of functions, partitions, tables, and databases). @@ -31,10 +30,13 @@ import org.apache.spark.util.ListenerBus * * Implementations should throw [[NoSuchDatabaseException]] when databases don't exist. */ -abstract class ExternalCatalog - extends ListenerBus[ExternalCatalogEventListener, ExternalCatalogEvent] { +trait ExternalCatalog { import CatalogTypes.TablePartitionSpec + // -------------------------------------------------------------------------- + // Utils + // -------------------------------------------------------------------------- + protected def requireDbExists(db: String): Unit = { if (!databaseExists(db)) { throw new NoSuchDatabaseException(db) @@ -63,22 +65,9 @@ abstract class ExternalCatalog // Databases // -------------------------------------------------------------------------- - final def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { - val db = dbDefinition.name - postToAll(CreateDatabasePreEvent(db)) - doCreateDatabase(dbDefinition, ignoreIfExists) - postToAll(CreateDatabaseEvent(db)) - } + def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit - protected def doCreateDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit - - final def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { - postToAll(DropDatabasePreEvent(db)) - doDropDatabase(db, ignoreIfNotExists, cascade) - postToAll(DropDatabaseEvent(db)) - } - - protected def doDropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit + def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit /** * Alter a database whose name matches the one specified in `dbDefinition`, @@ -87,14 +76,7 @@ abstract class ExternalCatalog * Note: If the underlying implementation does not support altering a certain field, * this becomes a no-op. */ - final def alterDatabase(dbDefinition: CatalogDatabase): Unit = { - val db = dbDefinition.name - postToAll(AlterDatabasePreEvent(db)) - doAlterDatabase(dbDefinition) - postToAll(AlterDatabaseEvent(db)) - } - - protected def doAlterDatabase(dbDefinition: CatalogDatabase): Unit + def alterDatabase(dbDefinition: CatalogDatabase): Unit def getDatabase(db: String): CatalogDatabase @@ -110,41 +92,15 @@ abstract class ExternalCatalog // Tables // -------------------------------------------------------------------------- - final def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { - val db = tableDefinition.database - val name = tableDefinition.identifier.table - val tableDefinitionWithVersion = - tableDefinition.copy(createVersion = org.apache.spark.SPARK_VERSION) - postToAll(CreateTablePreEvent(db, name)) - doCreateTable(tableDefinitionWithVersion, ignoreIfExists) - postToAll(CreateTableEvent(db, name)) - } - - protected def doCreateTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit - - final def dropTable( - db: String, - table: String, - ignoreIfNotExists: Boolean, - purge: Boolean): Unit = { - postToAll(DropTablePreEvent(db, table)) - doDropTable(db, table, ignoreIfNotExists, purge) - postToAll(DropTableEvent(db, table)) - } + def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit - protected def doDropTable( + def dropTable( db: String, table: String, ignoreIfNotExists: Boolean, purge: Boolean): Unit - final def renameTable(db: String, oldName: String, newName: String): Unit = { - postToAll(RenameTablePreEvent(db, oldName, newName)) - doRenameTable(db, oldName, newName) - postToAll(RenameTableEvent(db, oldName, newName)) - } - - protected def doRenameTable(db: String, oldName: String, newName: String): Unit + def renameTable(db: String, oldName: String, newName: String): Unit /** * Alter a table whose database and name match the ones specified in `tableDefinition`, assuming @@ -154,15 +110,7 @@ abstract class ExternalCatalog * Note: If the underlying implementation does not support altering a certain field, * this becomes a no-op. */ - final def alterTable(tableDefinition: CatalogTable): Unit = { - val db = tableDefinition.database - val name = tableDefinition.identifier.table - postToAll(AlterTablePreEvent(db, name, AlterTableKind.TABLE)) - doAlterTable(tableDefinition) - postToAll(AlterTableEvent(db, name, AlterTableKind.TABLE)) - } - - protected def doAlterTable(tableDefinition: CatalogTable): Unit + def alterTable(tableDefinition: CatalogTable): Unit /** * Alter the data schema of a table identified by the provided database and table name. The new @@ -173,22 +121,10 @@ abstract class ExternalCatalog * @param table Name of table to alter schema for * @param newDataSchema Updated data schema to be used for the table. */ - final def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit = { - postToAll(AlterTablePreEvent(db, table, AlterTableKind.DATASCHEMA)) - doAlterTableDataSchema(db, table, newDataSchema) - postToAll(AlterTableEvent(db, table, AlterTableKind.DATASCHEMA)) - } - - protected def doAlterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit + def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. */ - final def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit = { - postToAll(AlterTablePreEvent(db, table, AlterTableKind.STATS)) - doAlterTableStats(db, table, stats) - postToAll(AlterTableEvent(db, table, AlterTableKind.STATS)) - } - - protected def doAlterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit + def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit def getTable(db: String, table: String): CatalogTable @@ -340,49 +276,17 @@ abstract class ExternalCatalog // Functions // -------------------------------------------------------------------------- - final def createFunction(db: String, funcDefinition: CatalogFunction): Unit = { - val name = funcDefinition.identifier.funcName - postToAll(CreateFunctionPreEvent(db, name)) - doCreateFunction(db, funcDefinition) - postToAll(CreateFunctionEvent(db, name)) - } + def createFunction(db: String, funcDefinition: CatalogFunction): Unit - protected def doCreateFunction(db: String, funcDefinition: CatalogFunction): Unit + def dropFunction(db: String, funcName: String): Unit - final def dropFunction(db: String, funcName: String): Unit = { - postToAll(DropFunctionPreEvent(db, funcName)) - doDropFunction(db, funcName) - postToAll(DropFunctionEvent(db, funcName)) - } + def alterFunction(db: String, funcDefinition: CatalogFunction): Unit - protected def doDropFunction(db: String, funcName: String): Unit - - final def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = { - val name = funcDefinition.identifier.funcName - postToAll(AlterFunctionPreEvent(db, name)) - doAlterFunction(db, funcDefinition) - postToAll(AlterFunctionEvent(db, name)) - } - - protected def doAlterFunction(db: String, funcDefinition: CatalogFunction): Unit - - final def renameFunction(db: String, oldName: String, newName: String): Unit = { - postToAll(RenameFunctionPreEvent(db, oldName, newName)) - doRenameFunction(db, oldName, newName) - postToAll(RenameFunctionEvent(db, oldName, newName)) - } - - protected def doRenameFunction(db: String, oldName: String, newName: String): Unit + def renameFunction(db: String, oldName: String, newName: String): Unit def getFunction(db: String, funcName: String): CatalogFunction def functionExists(db: String, funcName: String): Boolean def listFunctions(db: String, pattern: String): Seq[String] - - override protected def doPostEvent( - listener: ExternalCatalogEventListener, - event: ExternalCatalogEvent): Unit = { - listener.onEvent(event) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala new file mode 100644 index 0000000000000..2f009be5816fa --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ListenerBus + +/** + * Wraps an ExternalCatalog to provide listener events. + */ +class ExternalCatalogWithListener(delegate: ExternalCatalog) + extends ExternalCatalog + with ListenerBus[ExternalCatalogEventListener, ExternalCatalogEvent] { + import CatalogTypes.TablePartitionSpec + + def unwrapped: ExternalCatalog = delegate + + override protected def doPostEvent( + listener: ExternalCatalogEventListener, + event: ExternalCatalogEvent): Unit = { + listener.onEvent(event) + } + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { + val db = dbDefinition.name + postToAll(CreateDatabasePreEvent(db)) + delegate.createDatabase(dbDefinition, ignoreIfExists) + postToAll(CreateDatabaseEvent(db)) + } + + override def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { + postToAll(DropDatabasePreEvent(db)) + delegate.dropDatabase(db, ignoreIfNotExists, cascade) + postToAll(DropDatabaseEvent(db)) + } + + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = { + val db = dbDefinition.name + postToAll(AlterDatabasePreEvent(db)) + delegate.alterDatabase(dbDefinition) + postToAll(AlterDatabaseEvent(db)) + } + + override def getDatabase(db: String): CatalogDatabase = { + delegate.getDatabase(db) + } + + override def databaseExists(db: String): Boolean = { + delegate.databaseExists(db) + } + + override def listDatabases(): Seq[String] = { + delegate.listDatabases() + } + + override def listDatabases(pattern: String): Seq[String] = { + delegate.listDatabases(pattern) + } + + override def setCurrentDatabase(db: String): Unit = { + delegate.setCurrentDatabase(db) + } + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + override def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { + val db = tableDefinition.database + val name = tableDefinition.identifier.table + val tableDefinitionWithVersion = + tableDefinition.copy(createVersion = org.apache.spark.SPARK_VERSION) + postToAll(CreateTablePreEvent(db, name)) + delegate.createTable(tableDefinitionWithVersion, ignoreIfExists) + postToAll(CreateTableEvent(db, name)) + } + + override def dropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean, + purge: Boolean): Unit = { + postToAll(DropTablePreEvent(db, table)) + delegate.dropTable(db, table, ignoreIfNotExists, purge) + postToAll(DropTableEvent(db, table)) + } + + override def renameTable(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameTablePreEvent(db, oldName, newName)) + delegate.renameTable(db, oldName, newName) + postToAll(RenameTableEvent(db, oldName, newName)) + } + + override def alterTable(tableDefinition: CatalogTable): Unit = { + val db = tableDefinition.database + val name = tableDefinition.identifier.table + postToAll(AlterTablePreEvent(db, name, AlterTableKind.TABLE)) + delegate.alterTable(tableDefinition) + postToAll(AlterTableEvent(db, name, AlterTableKind.TABLE)) + } + + override def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit = { + postToAll(AlterTablePreEvent(db, table, AlterTableKind.DATASCHEMA)) + delegate.alterTableDataSchema(db, table, newDataSchema) + postToAll(AlterTableEvent(db, table, AlterTableKind.DATASCHEMA)) + } + + override def alterTableStats( + db: String, + table: String, + stats: Option[CatalogStatistics]): Unit = { + postToAll(AlterTablePreEvent(db, table, AlterTableKind.STATS)) + delegate.alterTableStats(db, table, stats) + postToAll(AlterTableEvent(db, table, AlterTableKind.STATS)) + } + + override def getTable(db: String, table: String): CatalogTable = { + delegate.getTable(db, table) + } + + override def tableExists(db: String, table: String): Boolean = { + delegate.tableExists(db, table) + } + + override def listTables(db: String): Seq[String] = { + delegate.listTables(db) + } + + override def listTables(db: String, pattern: String): Seq[String] = { + delegate.listTables(db, pattern) + } + + override def loadTable( + db: String, + table: String, + loadPath: String, + isOverwrite: Boolean, + isSrcLocal: Boolean): Unit = { + delegate.loadTable(db, table, loadPath, isOverwrite, isSrcLocal) + } + + override def loadPartition( + db: String, + table: String, + loadPath: String, + partition: TablePartitionSpec, + isOverwrite: Boolean, + inheritTableSpecs: Boolean, + isSrcLocal: Boolean): Unit = { + delegate.loadPartition( + db, table, loadPath, partition, isOverwrite, inheritTableSpecs, isSrcLocal) + } + + override def loadDynamicPartitions( + db: String, + table: String, + loadPath: String, + partition: TablePartitionSpec, + replace: Boolean, + numDP: Int): Unit = { + delegate.loadDynamicPartitions(db, table, loadPath, partition, replace, numDP) + } + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + override def createPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition], + ignoreIfExists: Boolean): Unit = { + delegate.createPartitions(db, table, parts, ignoreIfExists) + } + + override def dropPartitions( + db: String, + table: String, + partSpecs: Seq[TablePartitionSpec], + ignoreIfNotExists: Boolean, + purge: Boolean, + retainData: Boolean): Unit = { + delegate.dropPartitions(db, table, partSpecs, ignoreIfNotExists, purge, retainData) + } + + override def renamePartitions( + db: String, + table: String, + specs: Seq[TablePartitionSpec], + newSpecs: Seq[TablePartitionSpec]): Unit = { + delegate.renamePartitions(db, table, specs, newSpecs) + } + + override def alterPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition]): Unit = { + delegate.alterPartitions(db, table, parts) + } + + override def getPartition( + db: String, + table: String, + spec: TablePartitionSpec): CatalogTablePartition = { + delegate.getPartition(db, table, spec) + } + + override def getPartitionOption( + db: String, + table: String, + spec: TablePartitionSpec): Option[CatalogTablePartition] = { + delegate.getPartitionOption(db, table, spec) + } + + override def listPartitionNames( + db: String, + table: String, + partialSpec: Option[TablePartitionSpec] = None): Seq[String] = { + delegate.listPartitionNames(db, table, partialSpec) + } + + override def listPartitions( + db: String, + table: String, + partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = { + delegate.listPartitions(db, table, partialSpec) + } + + override def listPartitionsByFilter( + db: String, + table: String, + predicates: Seq[Expression], + defaultTimeZoneId: String): Seq[CatalogTablePartition] = { + delegate.listPartitionsByFilter(db, table, predicates, defaultTimeZoneId) + } + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + override def createFunction(db: String, funcDefinition: CatalogFunction): Unit = { + val name = funcDefinition.identifier.funcName + postToAll(CreateFunctionPreEvent(db, name)) + delegate.createFunction(db, funcDefinition) + postToAll(CreateFunctionEvent(db, name)) + } + + override def dropFunction(db: String, funcName: String): Unit = { + postToAll(DropFunctionPreEvent(db, funcName)) + delegate.dropFunction(db, funcName) + postToAll(DropFunctionEvent(db, funcName)) + } + + override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = { + val name = funcDefinition.identifier.funcName + postToAll(AlterFunctionPreEvent(db, name)) + delegate.alterFunction(db, funcDefinition) + postToAll(AlterFunctionEvent(db, name)) + } + + override def renameFunction(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameFunctionPreEvent(db, oldName, newName)) + delegate.renameFunction(db, oldName, newName) + postToAll(RenameFunctionEvent(db, oldName, newName)) + } + + override def getFunction(db: String, funcName: String): CatalogFunction = { + delegate.getFunction(db, funcName) + } + + override def functionExists(db: String, funcName: String): Boolean = { + delegate.functionExists(db, funcName) + } + + override def listFunctions(db: String, pattern: String): Seq[String] = { + delegate.listFunctions(db, pattern) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 8eacfa058bd52..741dc46b07382 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -98,7 +98,7 @@ class InMemoryCatalog( // Databases // -------------------------------------------------------------------------- - override protected def doCreateDatabase( + override def createDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = synchronized { if (catalog.contains(dbDefinition.name)) { @@ -119,7 +119,7 @@ class InMemoryCatalog( } } - override protected def doDropDatabase( + override def dropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = synchronized { @@ -152,7 +152,7 @@ class InMemoryCatalog( } } - override def doAlterDatabase(dbDefinition: CatalogDatabase): Unit = synchronized { + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = synchronized { requireDbExists(dbDefinition.name) catalog(dbDefinition.name).db = dbDefinition } @@ -180,7 +180,7 @@ class InMemoryCatalog( // Tables // -------------------------------------------------------------------------- - override protected def doCreateTable( + override def createTable( tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = synchronized { assert(tableDefinition.identifier.database.isDefined) @@ -221,7 +221,7 @@ class InMemoryCatalog( } } - override protected def doDropTable( + override def dropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -264,7 +264,7 @@ class InMemoryCatalog( } } - override protected def doRenameTable( + override def renameTable( db: String, oldName: String, newName: String): Unit = synchronized { @@ -294,7 +294,7 @@ class InMemoryCatalog( catalog(db).tables.remove(oldName) } - override def doAlterTable(tableDefinition: CatalogTable): Unit = synchronized { + override def alterTable(tableDefinition: CatalogTable): Unit = synchronized { assert(tableDefinition.identifier.database.isDefined) val db = tableDefinition.identifier.database.get requireTableExists(db, tableDefinition.identifier.table) @@ -303,7 +303,7 @@ class InMemoryCatalog( catalog(db).tables(tableDefinition.identifier.table).table = newTableDefinition } - override def doAlterTableDataSchema( + override def alterTableDataSchema( db: String, table: String, newDataSchema: StructType): Unit = synchronized { @@ -313,7 +313,7 @@ class InMemoryCatalog( catalog(db).tables(table).table = origTable.copy(schema = newSchema) } - override def doAlterTableStats( + override def alterTableStats( db: String, table: String, stats: Option[CatalogStatistics]): Unit = synchronized { @@ -564,24 +564,24 @@ class InMemoryCatalog( // Functions // -------------------------------------------------------------------------- - override protected def doCreateFunction(db: String, func: CatalogFunction): Unit = synchronized { + override def createFunction(db: String, func: CatalogFunction): Unit = synchronized { requireDbExists(db) requireFunctionNotExists(db, func.identifier.funcName) catalog(db).functions.put(func.identifier.funcName, func) } - override protected def doDropFunction(db: String, funcName: String): Unit = synchronized { + override def dropFunction(db: String, funcName: String): Unit = synchronized { requireFunctionExists(db, funcName) catalog(db).functions.remove(funcName) } - override protected def doAlterFunction(db: String, func: CatalogFunction): Unit = synchronized { + override def alterFunction(db: String, func: CatalogFunction): Unit = synchronized { requireDbExists(db) requireFunctionExists(db, func.identifier.funcName) catalog(db).functions.put(func.identifier.funcName, func) } - override protected def doRenameFunction( + override def renameFunction( db: String, oldName: String, newName: String): Unit = synchronized { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala index 1acbe34d9a075..2fcaeca34db3f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala @@ -36,7 +36,7 @@ class ExternalCatalogEventSuite extends SparkFunSuite { private def testWithCatalog( name: String)( f: (ExternalCatalog, Seq[ExternalCatalogEvent] => Unit) => Unit): Unit = test(name) { - val catalog = newCatalog + val catalog = new ExternalCatalogWithListener(newCatalog) val recorder = mutable.Buffer.empty[ExternalCatalogEvent] catalog.addListener(new ExternalCatalogEventListener { override def onEvent(event: ExternalCatalogEvent): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index baea4ceebf8e3..5b6160e2b408f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -99,7 +99,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { /** * A catalog that interacts with external systems. */ - lazy val externalCatalog: ExternalCatalog = { + lazy val externalCatalog: ExternalCatalogWithListener = { val externalCatalog = SharedState.reflect[ExternalCatalog, SparkConf, Configuration]( SharedState.externalCatalogClassName(sparkContext.conf), sparkContext.conf, @@ -117,14 +117,17 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { externalCatalog.createDatabase(defaultDbDefinition, ignoreIfExists = true) } + // Wrap to provide catalog events + val wrapped = new ExternalCatalogWithListener(externalCatalog) + // Make sure we propagate external catalog events to the spark listener bus - externalCatalog.addListener(new ExternalCatalogEventListener { + wrapped.addListener(new ExternalCatalogEventListener { override def onEvent(event: ExternalCatalogEvent): Unit = { sparkContext.listenerBus.post(event) } }) - externalCatalog + wrapped } /** diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index cbd75ad12d430..8980bcf885589 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -50,7 +50,7 @@ private[hive] object SparkSQLEnv extends Logging { sqlContext = sparkSession.sqlContext val metadataHive = sparkSession - .sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + .sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 28c340a176d91..011a3ba553cb2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -158,13 +158,13 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Databases // -------------------------------------------------------------------------- - override protected def doCreateDatabase( + override def createDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = withClient { client.createDatabase(dbDefinition, ignoreIfExists) } - override protected def doDropDatabase( + override def dropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = withClient { @@ -177,7 +177,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * * Note: As of now, this only supports altering database properties! */ - override def doAlterDatabase(dbDefinition: CatalogDatabase): Unit = withClient { + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = withClient { val existingDb = getDatabase(dbDefinition.name) if (existingDb.properties == dbDefinition.properties) { logWarning(s"Request to alter database ${dbDefinition.name} is a no-op because " + @@ -211,7 +211,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Tables // -------------------------------------------------------------------------- - override protected def doCreateTable( + override def createTable( tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = withClient { assert(tableDefinition.identifier.database.isDefined) @@ -480,7 +480,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - override protected def doDropTable( + override def dropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -489,7 +489,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.dropTable(db, table, ignoreIfNotExists, purge) } - override protected def doRenameTable( + override def renameTable( db: String, oldName: String, newName: String): Unit = withClient { @@ -540,7 +540,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * Note: As of now, this doesn't support altering table schema, partition column names and bucket * specification. We will ignore them even if users do specify different values for these fields. */ - override def doAlterTable(tableDefinition: CatalogTable): Unit = withClient { + override def alterTable(tableDefinition: CatalogTable): Unit = withClient { assert(tableDefinition.identifier.database.isDefined) val db = tableDefinition.identifier.database.get requireTableExists(db, tableDefinition.identifier.table) @@ -624,7 +624,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * data schema should not have conflict column names with the existing partition columns, and * should still contain all the existing data columns. */ - override def doAlterTableDataSchema( + override def alterTableDataSchema( db: String, table: String, newDataSchema: StructType): Unit = withClient { @@ -656,7 +656,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. */ - override def doAlterTableStats( + override def alterTableStats( db: String, table: String, stats: Option[CatalogStatistics]): Unit = withClient { @@ -1208,7 +1208,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Functions // -------------------------------------------------------------------------- - override protected def doCreateFunction( + override def createFunction( db: String, funcDefinition: CatalogFunction): Unit = withClient { requireDbExists(db) @@ -1221,12 +1221,12 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.createFunction(db, funcDefinition.copy(identifier = functionIdentifier)) } - override protected def doDropFunction(db: String, name: String): Unit = withClient { + override def dropFunction(db: String, name: String): Unit = withClient { requireFunctionExists(db, name) client.dropFunction(db, name) } - override protected def doAlterFunction( + override def alterFunction( db: String, funcDefinition: CatalogFunction): Unit = withClient { requireDbExists(db) val functionName = funcDefinition.identifier.funcName.toLowerCase(Locale.ROOT) @@ -1235,7 +1235,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.alterFunction(db, funcDefinition.copy(identifier = functionIdentifier)) } - override protected def doRenameFunction( + override def renameFunction( db: String, oldName: String, newName: String): Unit = withClient { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index e5aff3b99d0b9..94ddeae1bf547 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -30,7 +30,7 @@ import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, Gener import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, ExternalCatalog, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper @@ -39,7 +39,7 @@ import org.apache.spark.sql.types.{DecimalType, DoubleType} private[sql] class HiveSessionCatalog( - externalCatalogBuilder: () => HiveExternalCatalog, + externalCatalogBuilder: () => ExternalCatalog, globalTempViewManagerBuilder: () => GlobalTempViewManager, val metastoreCatalog: HiveMetastoreCatalog, functionRegistry: FunctionRegistry, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 40b9bb51ca9a0..2882672f327c4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.Analyzer +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlanner @@ -35,14 +36,14 @@ import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLo class HiveSessionStateBuilder(session: SparkSession, parentState: Option[SessionState] = None) extends BaseSessionStateBuilder(session, parentState) { - private def externalCatalog: HiveExternalCatalog = - session.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] + private def externalCatalog: ExternalCatalogWithListener = session.sharedState.externalCatalog /** * Create a Hive aware resource loader. */ override protected lazy val resourceLoader: HiveSessionResourceLoader = { - new HiveSessionResourceLoader(session, () => externalCatalog.client) + new HiveSessionResourceLoader( + session, () => externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client) } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index 6a7b25b36d9a5..e0f7375387d24 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -122,7 +122,7 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { allSupportedHiveVersions) val externalCatalog = sparkSession.sharedState.externalCatalog - val hiveVersion = externalCatalog.asInstanceOf[HiveExternalCatalog].client.version + val hiveVersion = externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.version val stagingDir = hadoopConf.get("hive.exec.stagingdir", ".hive-staging") val scratchDir = hadoopConf.get("hive.exec.scratchdir", "/tmp/hive") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 965aea2b61456..ee3f99ab7e9bb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -35,6 +35,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, ExternalCatalogWithListener} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.CacheTableCommand @@ -83,11 +84,11 @@ private[hive] class TestHiveSharedState( hiveClient: Option[HiveClient] = None) extends SharedState(sc) { - override lazy val externalCatalog: TestHiveExternalCatalog = { - new TestHiveExternalCatalog( + override lazy val externalCatalog: ExternalCatalogWithListener = { + new ExternalCatalogWithListener(new TestHiveExternalCatalog( sc.conf, sc.hadoopConfiguration, - hiveClient) + hiveClient)) } } @@ -208,7 +209,9 @@ private[hive] class TestHiveSparkSession( new TestHiveSessionStateBuilder(this, parentSessionState).build() } - lazy val metadataHive: HiveClient = sharedState.externalCatalog.client.newSession() + lazy val metadataHive: HiveClient = { + sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.newSession() + } override def newSession(): TestHiveSparkSession = { new TestHiveSparkSession(sc, Some(sharedState), None, loadTestTables) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala index 285f35b0b0eac..fd5f47e428239 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala @@ -26,7 +26,7 @@ class HiveExternalSessionCatalogSuite extends SessionCatalogSuite with TestHiveS private val externalCatalog = { val catalog = spark.sharedState.externalCatalog - catalog.asInstanceOf[HiveExternalCatalog].client.reset() + catalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.reset() catalog } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala index f2d27671094d7..51a48a20daaa2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala @@ -50,7 +50,8 @@ class HiveSchemaInferenceSuite FileStatusCache.resetForTesting() } - private val externalCatalog = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] + private val externalCatalog = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] private val client = externalCatalog.client // Return a copy of the given schema with all field names converted to lower case. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala index ecc09cdcdbeaf..a3579862c9e59 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala @@ -44,7 +44,8 @@ class HiveSessionStateSuite extends SessionStateSuite with TestHiveSingleton { val conf = sparkSession.sparkContext.hadoopConfiguration val oldValue = conf.get(ConfVars.METASTORECONNECTURLKEY.varname) sparkSession.cloneSession() - sparkSession.sharedState.externalCatalog.client.newSession() + sparkSession.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.newSession() val newValue = conf.get(ConfVars.METASTORECONNECTURLKEY.varname) assert(oldValue == newValue, "cloneSession and then newSession should not affect the Derby directory") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 079fe45860544..aa5b531992613 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -354,7 +354,7 @@ object SetMetastoreURLTest extends Logging { // HiveExternalCatalog is used when Hive support is enabled. val actualMetastoreURL = - spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client .getConf("javax.jdo.option.ConnectionURL", "this_is_a_wrong_URL") logInfo(s"javax.jdo.option.ConnectionURL is $actualMetastoreURL") @@ -780,7 +780,8 @@ object SPARK_18360 { val defaultDbLocation = spark.catalog.getDatabase("default").locationUri assert(new Path(defaultDbLocation) == new Path(spark.sharedState.warehousePath)) - val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val hiveClient = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client try { val tableMeta = CatalogTable( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala index fad81c7e9474e..473bbced41b31 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala @@ -289,7 +289,8 @@ class ShowCreateTableSuite extends QueryTest with SQLTestUtils with TestHiveSing } private def createRawHiveTable(ddl: String): Unit = { - hiveContext.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client.runSqlHive(ddl) + hiveContext.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.runSqlHive(ddl) } private def checkCreateTable(table: String): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 6176273c88db1..dc96ec416afd8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -134,8 +134,8 @@ class VersionsSuite extends SparkFunSuite with Logging { client = buildClient(version, hadoopConf, HiveUtils.formatTimeVarsForHiveClient(hadoopConf)) if (versionSpark != null) versionSpark.reset() versionSpark = TestHiveVersion(client) - assert(versionSpark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client - .version.fullVersion.startsWith(version)) + assert(versionSpark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.version.fullVersion.startsWith(version)) } def table(database: String, tableName: String): CatalogTable = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index daac6af9b557f..0341c3b378918 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1355,7 +1355,8 @@ class HiveDDLSuite val indexName = tabName + "_index" withTable(tabName) { // Spark SQL does not support creating index. Thus, we have to use Hive client. - val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val client = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client sql(s"CREATE TABLE $tabName(a int)") try { @@ -1393,7 +1394,8 @@ class HiveDDLSuite val tabName = "tab1" withTable(tabName) { // Spark SQL does not support creating skewed table. Thus, we have to use Hive client. - val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val client = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client client.runSqlHive( s""" |CREATE Table $tabName(col1 int, col2 int) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 704a410b6a37b..828c18a770c80 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2099,7 +2099,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { Seq("orc", "parquet").foreach { format => test(s"SPARK-18355 Read data from a hive table with a new column - $format") { - val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val client = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client Seq("true", "false").foreach { value => withSQLConf( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala index d3fff37c3424d..d50bf0b8fd603 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala @@ -30,7 +30,7 @@ trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll { protected val spark: SparkSession = TestHive.sparkSession protected val hiveContext: TestHiveContext = TestHive protected val hiveClient: HiveClient = - spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client protected override def afterAll(): Unit = { try { From a634d66ce767bd5e1d8553d1a2c32e2b1a80f642 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 7 May 2018 13:00:18 +0800 Subject: [PATCH 0741/2461] [SPARK-24126][PYSPARK] Use build-specific temp directory for pyspark tests. This avoids polluting and leaving garbage behind in /tmp, and allows the usual build tools to clean up any leftover files. Author: Marcelo Vanzin Closes #21198 from vanzin/SPARK-24126. --- python/pyspark/sql/tests.py | 4 ++-- python/pyspark/streaming/tests.py | 6 ++++-- python/pyspark/tests.py | 33 ++++++++++++++++++++----------- python/run-tests.py | 29 +++++++++++++++++++++++++-- 4 files changed, 54 insertions(+), 18 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index cc6acfdb07d99..16aa9378ad8ee 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3092,8 +3092,8 @@ def test_hivecontext(self): |print(hive_context.sql("show databases").collect()) """) proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local-cluster[1,1,1024]", - "--driver-class-path", hive_site_dir, script], + self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", + "--driver-class-path", hive_site_dir, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index d77f1baa1f344..e4a428a0b27e7 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -63,7 +63,7 @@ def setUpClass(cls): class_name = cls.__name__ conf = SparkConf().set("spark.default.parallelism", 1) cls.sc = SparkContext(appName=class_name, conf=conf) - cls.sc.setCheckpointDir("/tmp") + cls.sc.setCheckpointDir(tempfile.mkdtemp()) @classmethod def tearDownClass(cls): @@ -1549,7 +1549,9 @@ def search_kinesis_asl_assembly_jar(): kinesis_jar_present = True jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar) - os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars + existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") + jars_args = "--jars %s" % jars + os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args]) testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, StreamingListenerTests] diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 8392d7f29af53..7b8ce2c6b799f 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1951,7 +1951,12 @@ class SparkSubmitTests(unittest.TestCase): def setUp(self): self.programDir = tempfile.mkdtemp() - self.sparkSubmit = os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit") + tmp_dir = tempfile.gettempdir() + self.sparkSubmit = [ + os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit"), + "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + ] def tearDown(self): shutil.rmtree(self.programDir) @@ -2017,7 +2022,7 @@ def test_single_script(self): |sc = SparkContext() |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect()) """) - proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) + proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 4, 6]", out.decode('utf-8')) @@ -2033,7 +2038,7 @@ def test_script_with_local_functions(self): |sc = SparkContext() |print(sc.parallelize([1, 2, 3]).map(foo).collect()) """) - proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) + proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[3, 6, 9]", out.decode('utf-8')) @@ -2051,7 +2056,7 @@ def test_module_dependency(self): |def myfunc(x): | return x + 1 """) - proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, script], + proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) @@ -2070,7 +2075,7 @@ def test_module_dependency_on_cluster(self): |def myfunc(x): | return x + 1 """) - proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, "--master", + proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, "--master", "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) out, err = proc.communicate() @@ -2087,8 +2092,10 @@ def test_package_dependency(self): |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) self.create_spark_package("a:mylib:0.1") - proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", - "file:" + self.programDir, script], stdout=subprocess.PIPE) + proc = subprocess.Popen( + self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, script], + stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 3, 4]", out.decode('utf-8')) @@ -2103,9 +2110,11 @@ def test_package_dependency_on_cluster(self): |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) self.create_spark_package("a:mylib:0.1") - proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", - "file:" + self.programDir, "--master", - "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) + proc = subprocess.Popen( + self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, "--master", "local-cluster[1,1,1024]", + script], + stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 3, 4]", out.decode('utf-8')) @@ -2124,7 +2133,7 @@ def test_single_script_on_cluster(self): # this will fail if you have different spark.executor.memory # in conf/spark-defaults.conf proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local-cluster[1,1,1024]", script], + self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) @@ -2144,7 +2153,7 @@ def test_user_configuration(self): | sc.stop() """) proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local", script], + self.sparkSubmit + ["--master", "local", script], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) out, err = proc.communicate() diff --git a/python/run-tests.py b/python/run-tests.py index f408fc5082b3d..4c90926cfa350 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -22,11 +22,13 @@ from optparse import OptionParser import os import re +import shutil import subprocess import sys import tempfile from threading import Thread, Lock import time +import uuid if sys.version < '3': import Queue else: @@ -68,7 +70,7 @@ def print_red(text): raise Exception("Cannot find assembly build directory, please build Spark first.") -def run_individual_python_test(test_name, pyspark_python): +def run_individual_python_test(target_dir, test_name, pyspark_python): env = dict(os.environ) env.update({ 'SPARK_DIST_CLASSPATH': SPARK_DIST_CLASSPATH, @@ -77,6 +79,23 @@ def run_individual_python_test(test_name, pyspark_python): 'PYSPARK_PYTHON': which(pyspark_python), 'PYSPARK_DRIVER_PYTHON': which(pyspark_python) }) + + # Create a unique temp directory under 'target/' for each run. The TMPDIR variable is + # recognized by the tempfile module to override the default system temp directory. + tmp_dir = os.path.join(target_dir, str(uuid.uuid4())) + while os.path.isdir(tmp_dir): + tmp_dir = os.path.join(target_dir, str(uuid.uuid4())) + os.mkdir(tmp_dir) + env["TMPDIR"] = tmp_dir + + # Also override the JVM's temp directory by setting driver and executor options. + spark_args = [ + "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "pyspark-shell" + ] + env["PYSPARK_SUBMIT_ARGS"] = " ".join(spark_args) + LOGGER.info("Starting test(%s): %s", pyspark_python, test_name) start_time = time.time() try: @@ -84,6 +103,7 @@ def run_individual_python_test(test_name, pyspark_python): retcode = subprocess.Popen( [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], stderr=per_test_output, stdout=per_test_output, env=env).wait() + shutil.rmtree(tmp_dir, ignore_errors=True) except: LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python) # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if @@ -238,6 +258,11 @@ def main(): priority = 100 task_queue.put((priority, (python_exec, test_goal))) + # Create the target directory before starting tasks to avoid races. + target_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'target')) + if not os.path.isdir(target_dir): + os.mkdir(target_dir) + def process_queue(task_queue): while True: try: @@ -245,7 +270,7 @@ def process_queue(task_queue): except Queue.Empty: break try: - run_individual_python_test(test_goal, python_exec) + run_individual_python_test(target_dir, test_goal, python_exec) finally: task_queue.task_done() From 889f6cc10cbd7781df04f468674a61f0ac5a870b Mon Sep 17 00:00:00 2001 From: jinxing Date: Mon, 7 May 2018 14:16:27 +0800 Subject: [PATCH 0742/2461] [SPARK-24143] filter empty blocks when convert mapstatus to (blockId, size) pair MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In current code(`MapOutputTracker.convertMapStatuses`), mapstatus are converted to (blockId, size) pair for all blocks – no matter the block is empty or not, which result in OOM when there are lots of consecutive empty blocks, especially when adaptive execution is enabled. (blockId, size) pair is only used in `ShuffleBlockFetcherIterator` to control shuffle-read and only non-empty block request is sent. Can we just filter out the empty blocks in MapOutputTracker.convertMapStatuses and save memory? ## How was this patch tested? not added yet. Author: jinxing Closes #21212 from jinxing64/SPARK-24143. --- .../org/apache/spark/MapOutputTracker.scala | 31 +++++++++------- .../storage/ShuffleBlockFetcherIterator.scala | 35 +++++++++++-------- .../apache/spark/MapOutputTrackerSuite.scala | 31 +++++++++++++++- .../BlockStoreShuffleReaderSuite.scala | 2 +- .../ShuffleBlockFetcherIteratorSuite.scala | 19 +++++----- 5 files changed, 80 insertions(+), 38 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 195fd4f818b36..73646051f264c 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -22,7 +22,7 @@ import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolE import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} +import scala.collection.mutable.{HashMap, HashSet, ListBuffer, Map} import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.reflect.ClassTag @@ -282,7 +282,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging // For testing def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) } @@ -296,7 +296,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. */ def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] /** * Deletes map output status information for the specified shuffle stage. @@ -632,9 +632,10 @@ private[spark] class MapOutputTrackerMaster( } } + // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") shuffleStatuses.get(shuffleId) match { case Some (shuffleStatus) => @@ -642,7 +643,7 @@ private[spark] class MapOutputTrackerMaster( MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) } case None => - Seq.empty + Iterator.empty } } @@ -669,8 +670,9 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr /** Remembers which map output locations are currently being fetched on an executor. */ private val fetching = new HashSet[Int] + // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") val statuses = getStatuses(shuffleId) try { @@ -841,6 +843,7 @@ private[spark] object MapOutputTracker extends Logging { * Given an array of map statuses and a range of map output partitions, returns a sequence that, * for each block manager ID, lists the shuffle block IDs and corresponding shuffle block sizes * stored at that block manager. + * Note that empty blocks are filtered in the result. * * If any of the statuses is null (indicating a missing location due to a failed mapper), * throws a FetchFailedException. @@ -857,22 +860,24 @@ private[spark] object MapOutputTracker extends Logging { shuffleId: Int, startPartition: Int, endPartition: Int, - statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { assert (statuses != null) - val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]] - for ((status, mapId) <- statuses.zipWithIndex) { + val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]] + for ((status, mapId) <- statuses.iterator.zipWithIndex) { if (status == null) { val errorMessage = s"Missing an output location for shuffle $shuffleId" logError(errorMessage) throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage) } else { for (part <- startPartition until endPartition) { - splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) += - ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part))) + val size = status.getSizeForBlock(part) + if (size != 0) { + splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, mapId, part), size)) + } } } } - - splitsByAddress.toSeq + splitsByAddress.iterator } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index dd9df74689a13..6971efd2504c2 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -48,7 +48,9 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * @param blockManager [[BlockManager]] for reading local blocks * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in - * order to throttle the memory usage. + * order to throttle the memory usage. Note that zero-sized blocks are + * already excluded, which happened in + * [[MapOutputTracker.convertMapStatuses]]. * @param streamWrapper A function to wrap the returned input stream. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. @@ -62,7 +64,7 @@ final class ShuffleBlockFetcherIterator( context: TaskContext, shuffleClient: ShuffleClient, blockManager: BlockManager, - blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])], streamWrapper: (BlockId, InputStream) => InputStream, maxBytesInFlight: Long, maxReqsInFlight: Int, @@ -74,8 +76,8 @@ final class ShuffleBlockFetcherIterator( import ShuffleBlockFetcherIterator._ /** - * Total number of blocks to fetch. This can be smaller than the total number of blocks - * in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]]. + * Total number of blocks to fetch. This should be equal to the total number of blocks + * in [[blocksByAddress]] because we already filter out zero-sized blocks in [[blocksByAddress]]. * * This should equal localBlocks.size + remoteBlocks.size. */ @@ -267,13 +269,16 @@ final class ShuffleBlockFetcherIterator( // at most maxBytesInFlight in order to limit the amount of data in flight. val remoteRequests = new ArrayBuffer[FetchRequest] - // Tracks total number of blocks (including zero sized blocks) - var totalBlocks = 0 for ((address, blockInfos) <- blocksByAddress) { - totalBlocks += blockInfos.size if (address.executorId == blockManager.blockManagerId.executorId) { - // Filter out zero-sized blocks - localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) + blockInfos.find(_._2 <= 0) match { + case Some((blockId, size)) if size < 0 => + throw new BlockException(blockId, "Negative block size " + size) + case Some((blockId, size)) if size == 0 => + throw new BlockException(blockId, "Zero-sized blocks should be excluded.") + case None => // do nothing. + } + localBlocks ++= blockInfos.map(_._1) numBlocksToFetch += localBlocks.size } else { val iterator = blockInfos.iterator @@ -281,14 +286,15 @@ final class ShuffleBlockFetcherIterator( var curBlocks = new ArrayBuffer[(BlockId, Long)] while (iterator.hasNext) { val (blockId, size) = iterator.next() - // Skip empty blocks - if (size > 0) { + if (size < 0) { + throw new BlockException(blockId, "Negative block size " + size) + } else if (size == 0) { + throw new BlockException(blockId, "Zero-sized blocks should be excluded.") + } else { curBlocks += ((blockId, size)) remoteBlocks += blockId numBlocksToFetch += 1 curRequestSize += size - } else if (size < 0) { - throw new BlockException(blockId, "Negative block size " + size) } if (curRequestSize >= targetRequestSize || curBlocks.size >= maxBlocksInFlightPerAddress) { @@ -306,7 +312,8 @@ final class ShuffleBlockFetcherIterator( } } } - logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks") + logInfo(s"Getting $numBlocksToFetch non-empty blocks including ${localBlocks.size}" + + s" local blocks and ${remoteBlocks.size} remote blocks") remoteRequests } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 50b8ea754d8d9..21f481d477242 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -147,7 +147,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("a", "hostA", 1000), Array(1000L))) slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) @@ -298,4 +298,33 @@ class MapOutputTrackerSuite extends SparkFunSuite { } } + test("zero-sized blocks should be excluded when getMapSizesByExecutorId") { + val rpcEnv = createRpcEnv("test") + val tracker = newTrackerMaster() + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + tracker.registerShuffle(10, 2) + + val size0 = MapStatus.decompressSize(MapStatus.compressSize(0L)) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) + tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), + Array(size0, size1000, size0, size10000))) + tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), + Array(size10000, size0, size1000, size0))) + assert(tracker.containsShuffle(10)) + assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === + Seq( + (BlockManagerId("a", "hostA", 1000), + Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))), + (BlockManagerId("b", "hostB", 1000), + Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000))) + ) + ) + + tracker.unregisterShuffle(10) + tracker.stop() + rpcEnv.shutdown() + } + } diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index dba1172d5fdbd..2d8a83c6fabed 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -108,7 +108,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) (shuffleBlockId, byteOutputStream.size().toLong) } - Seq((localBlockManagerId, shuffleBlockIdsAndSizes)) + Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).toIterator } // Create a mocked shuffle handle to pass into HashShuffleReader. diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 692ae3bf597e0..cefebfa51b8b9 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -99,7 +99,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq), (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq) - ) + ).toIterator val iterator = new ShuffleBlockFetcherIterator( TaskContext.empty(), @@ -176,7 +176,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -244,7 +244,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -310,7 +310,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -378,7 +378,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (localBmId, localBlockLengths), (remoteBmId, remoteBlockLengths) - ) + ).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -437,7 +437,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -495,7 +495,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) - def fetchShuffleBlock(blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { + def fetchShuffleBlock( + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here. @@ -513,14 +514,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)) + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)).toIterator fetchShuffleBlock(blocksByAddress1) // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch // shuffle block to disk. assert(tempFileManager == null) val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)).toIterator fetchShuffleBlock(blocksByAddress2) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. From 7564a9a70695dac2f0b5f51493d37cbc93691663 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 7 May 2018 15:22:23 +0900 Subject: [PATCH 0743/2461] [SPARK-23921][SQL] Add array_sort function ## What changes were proposed in this pull request? The PR adds the SQL function `array_sort`. The behavior of the function is based on Presto's one. The function sorts the input array in ascending order. The elements of the input array must be orderable. Null elements will be placed at the end of the returned array. ## How was this patch tested? Added UTs Author: Kazuaki Ishizaki Closes #21021 from kiszk/SPARK-23921. --- python/pyspark/sql/functions.py | 26 +- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 240 ++++++++++++++---- .../CollectionExpressionsSuite.scala | 34 ++- .../org/apache/spark/sql/functions.scala | 12 + .../spark/sql/DataFrameFunctionsSuite.scala | 34 ++- 6 files changed, 292 insertions(+), 55 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ad4bd6f5089e9..bd55b5f73b4d0 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2183,20 +2183,38 @@ def array_max(col): def sort_array(col, asc=True): """ Collection function: sorts the input array in ascending or descending order according - to the natural ordering of the array elements. + to the natural ordering of the array elements. Null elements will be placed at the beginning + of the returned array in ascending order or at the end of the returned array in descending + order. :param col: name of column or expression - >>> df = spark.createDataFrame([([2, 1, 3],),([1],),([],)], ['data']) + >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) >>> df.select(sort_array(df.data).alias('r')).collect() - [Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])] + [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])] >>> df.select(sort_array(df.data, asc=False).alias('r')).collect() - [Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])] + [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) +@since(2.4) +def array_sort(col): + """ + Collection function: sorts the input array in ascending order. The elements of the input array + must be orderable. Null elements will be placed at the end of the returned array. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) + >>> df.select(array_sort(df.data).alias('r')).collect() + [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_sort(_to_java_column(col))) + + @since(1.5) @ignore_unicode_prefix def reverse(col): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 51bb6b0abe408..01776b85e6f53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -403,6 +403,7 @@ object FunctionRegistry { expression[ArrayContains]("array_contains"), expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), + expression[ArraySort]("array_sort"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[ElementAt]("element_at"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6d63a531e3b74..23c09bc3b49d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -20,6 +20,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ @@ -119,47 +120,16 @@ case class MapValues(child: Expression) } /** - * Sorts the input array in ascending / descending order according to the natural ordering of - * the array elements and returns it. + * Common base class for [[SortArray]] and [[ArraySort]]. */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order according to the natural ordering of the array elements.", - examples = """ - Examples: - > SELECT _FUNC_(array('b', 'd', 'c', 'a'), true); - ["a","b","c","d"] - """) -// scalastyle:on line.size.limit -case class SortArray(base: Expression, ascendingOrder: Expression) - extends BinaryExpression with ExpectsInputTypes with CodegenFallback { +trait ArraySortLike extends ExpectsInputTypes { + protected def arrayExpression: Expression - def this(e: Expression) = this(e, Literal(true)) - - override def left: Expression = base - override def right: Expression = ascendingOrder - override def dataType: DataType = base.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) - - override def checkInputDataTypes(): TypeCheckResult = base.dataType match { - case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => - ascendingOrder match { - case Literal(_: Boolean, BooleanType) => - TypeCheckResult.TypeCheckSuccess - case _ => - TypeCheckResult.TypeCheckFailure( - "Sort order in second argument requires a boolean literal.") - } - case ArrayType(dt, _) => - TypeCheckResult.TypeCheckFailure( - s"$prettyName does not support sorting array of type ${dt.simpleString}") - case _ => - TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") - } + protected def nullOrder: NullOrder @transient private lazy val lt: Comparator[Any] = { - val ordering = base.dataType match { + val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -170,9 +140,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) if (o1 == null && o2 == null) { 0 } else if (o1 == null) { - -1 + nullOrder } else if (o2 == null) { - 1 + -nullOrder } else { ordering.compare(o1, o2) } @@ -182,7 +152,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) @transient private lazy val gt: Comparator[Any] = { - val ordering = base.dataType match { + val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -193,9 +163,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) if (o1 == null && o2 == null) { 0 } else if (o1 == null) { - 1 + -nullOrder } else if (o2 == null) { - -1 + nullOrder } else { -ordering.compare(o1, o2) } @@ -203,18 +173,200 @@ case class SortArray(base: Expression, ascendingOrder: Expression) } } - override def nullSafeEval(array: Any, ascending: Any): Any = { - val elementType = base.dataType.asInstanceOf[ArrayType].elementType + def elementType: DataType = arrayExpression.dataType.asInstanceOf[ArrayType].elementType + def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull + + def sortEval(array: Any, ascending: Boolean): Any = { val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) if (elementType != NullType) { - java.util.Arrays.sort(data, if (ascending.asInstanceOf[Boolean]) lt else gt) + java.util.Arrays.sort(data, if (ascending) lt else gt) } new GenericArrayData(data.asInstanceOf[Array[Any]]) } + def sortCodegen(ctx: CodegenContext, ev: ExprCode, base: String, order: String): String = { + val arrayData = classOf[ArrayData].getName + val genericArrayData = classOf[GenericArrayData].getName + val unsafeArrayData = classOf[UnsafeArrayData].getName + val array = ctx.freshName("array") + val c = ctx.freshName("c") + if (elementType == NullType) { + s"${ev.value} = $base.copy();" + } else { + val elementTypeTerm = ctx.addReferenceObj("elementTypeTerm", elementType) + val sortOrder = ctx.freshName("sortOrder") + val o1 = ctx.freshName("o1") + val o2 = ctx.freshName("o2") + val jt = CodeGenerator.javaType(elementType) + val comp = if (CodeGenerator.isPrimitiveType(elementType)) { + val bt = CodeGenerator.boxedType(elementType) + val v1 = ctx.freshName("v1") + val v2 = ctx.freshName("v2") + s""" + |$jt $v1 = (($bt) $o1).${jt}Value(); + |$jt $v2 = (($bt) $o2).${jt}Value(); + |int $c = ${ctx.genComp(elementType, v1, v2)}; + """.stripMargin + } else { + s"int $c = ${ctx.genComp(elementType, s"(($jt) $o1)", s"(($jt) $o2)")};" + } + val nonNullPrimitiveAscendingSort = + if (CodeGenerator.isPrimitiveType(elementType) && !containsNull) { + val javaType = CodeGenerator.javaType(elementType) + val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType) + s""" + |if ($order) { + | $javaType[] $array = $base.to${primitiveTypeName}Array(); + | java.util.Arrays.sort($array); + | ${ev.value} = $unsafeArrayData.fromPrimitiveArray($array); + |} else + """.stripMargin + } else { + "" + } + s""" + |$nonNullPrimitiveAscendingSort + |{ + | Object[] $array = $base.toObjectArray($elementTypeTerm); + | final int $sortOrder = $order ? 1 : -1; + | java.util.Arrays.sort($array, new java.util.Comparator() { + | @Override public int compare(Object $o1, Object $o2) { + | if ($o1 == null && $o2 == null) { + | return 0; + | } else if ($o1 == null) { + | return $sortOrder * $nullOrder; + | } else if ($o2 == null) { + | return -$sortOrder * $nullOrder; + | } + | $comp + | return $sortOrder * $c; + | } + | }); + | ${ev.value} = new $genericArrayData($array); + |} + """.stripMargin + } + } + +} + +object ArraySortLike { + type NullOrder = Int + // Least: place null element at the first of the array for ascending order + // Greatest: place null element at the end of the array for ascending order + object NullOrder { + val Least: NullOrder = -1 + val Greatest: NullOrder = 1 + } +} + +/** + * Sorts the input array in ascending / descending order according to the natural ordering of + * the array elements and returns it. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order + according to the natural ordering of the array elements. Null elements will be placed + at the beginning of the returned array in ascending order or at the end of the returned + array in descending order. + """, + examples = """ + Examples: + > SELECT _FUNC_(array('b', 'd', null, 'c', 'a'), true); + [null,"a","b","c","d"] + """) +// scalastyle:on line.size.limit +case class SortArray(base: Expression, ascendingOrder: Expression) + extends BinaryExpression with ArraySortLike { + + def this(e: Expression) = this(e, Literal(true)) + + override def left: Expression = base + override def right: Expression = ascendingOrder + override def dataType: DataType = base.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) + + override def arrayExpression: Expression = base + override def nullOrder: NullOrder = NullOrder.Least + + override def checkInputDataTypes(): TypeCheckResult = base.dataType match { + case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => + ascendingOrder match { + case Literal(_: Boolean, BooleanType) => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure( + "Sort order in second argument requires a boolean literal.") + } + case ArrayType(dt, _) => + val dtSimple = dt.simpleString + TypeCheckResult.TypeCheckFailure( + s"$prettyName does not support sorting array of type $dtSimple which is not orderable") + case _ => + TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") + } + + override def nullSafeEval(array: Any, ascending: Any): Any = { + sortEval(array, ascending.asInstanceOf[Boolean]) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (b, order) => sortCodegen(ctx, ev, b, order)) + } + override def prettyName: String = "sort_array" } + +/** + * Sorts the input array in ascending order according to the natural ordering of + * the array elements and returns it. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(array) - Sorts the input array in ascending order. The elements of the input array must + be orderable. Null elements will be placed at the end of the returned array. + """, + examples = """ + Examples: + > SELECT _FUNC_(array('b', 'd', null, 'c', 'a')); + ["a","b","c","d",null] + """, + since = "2.4.0") +// scalastyle:on line.size.limit +case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLike { + + override def dataType: DataType = child.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + override def arrayExpression: Expression = child + override def nullOrder: NullOrder = NullOrder.Greatest + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => + TypeCheckResult.TypeCheckSuccess + case ArrayType(dt, _) => + val dtSimple = dt.simpleString + TypeCheckResult.TypeCheckFailure( + s"$prettyName does not support sorting array of type $dtSimple which is not orderable") + case _ => + TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") + } + + override def nullSafeEval(array: Any): Any = { + sortEval(array, true) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => sortCodegen(ctx, ev, c, "true")) + } + + override def prettyName: String = "array_sort" +} + /** * Returns a reversed string or an array with reverse order of elements. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 7048d93fd5649..749374f1a14a1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -61,28 +61,58 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType)) val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType)) - val a4 = Literal.create(Seq(null, null), ArrayType(NullType)) + val d1 = new Decimal().set(10) + val d2 = new Decimal().set(100) + val a4 = Literal.create(Seq(d2, d1), ArrayType(DecimalType(10, 0))) + val a5 = Literal.create(Seq(null, null), ArrayType(NullType)) checkEvaluation(new SortArray(a0), Seq(1, 2, 3)) checkEvaluation(new SortArray(a1), Seq[Integer]()) checkEvaluation(new SortArray(a2), Seq("a", "b")) checkEvaluation(new SortArray(a3), Seq(null, "a", "b")) + checkEvaluation(new SortArray(a4), Seq(d1, d2)) checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3)) checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]()) checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b")) checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b")) + checkEvaluation(SortArray(a4, Literal(true)), Seq(d1, d2)) checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1)) checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]()) checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a")) checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null)) + checkEvaluation(SortArray(a4, Literal(false)), Seq(d2, d1)) checkEvaluation(Literal.create(null, ArrayType(StringType)), null) - checkEvaluation(new SortArray(a4), Seq(null, null)) + checkEvaluation(new SortArray(a5), Seq(null, null)) val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS) checkEvaluation(new SortArray(arrayStruct), Seq(create_row(1), create_row(2))) + + val typeAA = ArrayType(ArrayType(IntegerType)) + val aa1 = Array[java.lang.Integer](1, 2) + val aa2 = Array[java.lang.Integer](3, null, 4) + val arrayArray = Literal.create(Seq(aa2, aa1), typeAA) + + checkEvaluation(new SortArray(arrayArray), Seq(aa1, aa2)) + + val typeAAS = ArrayType(ArrayType(StructType(StructField("a", IntegerType) :: Nil))) + val aas1 = Array(create_row(1)) + val aas2 = Array(create_row(2)) + val arrayArrayStruct = Literal.create(Seq(aas2, aas1), typeAAS) + + checkEvaluation(new SortArray(arrayArrayStruct), Seq(aas1, aas2)) + + checkEvaluation(ArraySort(a0), Seq(1, 2, 3)) + checkEvaluation(ArraySort(a1), Seq[Integer]()) + checkEvaluation(ArraySort(a2), Seq("a", "b")) + checkEvaluation(ArraySort(a3), Seq("a", "b", null)) + checkEvaluation(ArraySort(a4), Seq(d1, d2)) + checkEvaluation(ArraySort(a5), Seq(null, null)) + checkEvaluation(ArraySort(arrayStruct), Seq(create_row(1), create_row(2))) + checkEvaluation(ArraySort(arrayArray), Seq(aa1, aa2)) + checkEvaluation(ArraySort(arrayArrayStruct), Seq(aas1, aas2)) } test("Array contains") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d2e22fa355514..10b6dcc0608c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3093,6 +3093,15 @@ object functions { ElementAt(column.expr, Literal(value)) } + /** + * Sorts the input array in ascending order. The elements of the input array must be orderable. + * Null elements will be placed at the end of the returned array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_sort(e: Column): Column = withExpr { ArraySort(e.expr) } + /** * Creates a new row for each element in the given array or map column. * @@ -3332,6 +3341,7 @@ object functions { /** * Sorts the input array for the given column in ascending order, * according to the natural ordering of the array elements. + * Null elements will be placed at the beginning of the returned array. * * @group collection_funcs * @since 1.5.0 @@ -3341,6 +3351,8 @@ object functions { /** * Sorts the input array for the given column in ascending or descending order, * according to the natural ordering of the array elements. + * Null elements will be placed at the beginning of the returned array in ascending order or + * at the end of the returned array in descending order. * * @group collection_funcs * @since 1.5.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index a5163accb1bb3..ae21cbc802d0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -276,7 +276,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } - test("sort_array function") { + test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), (Array.empty[Int], Array.empty[String]), @@ -286,28 +286,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.select(sort_array($"a"), sort_array($"b")), Seq( Row(Seq(1, 2, 3), Seq("a", "b", "c")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) checkAnswer( df.select(sort_array($"a", false), sort_array($"b", false)), Seq( Row(Seq(3, 2, 1), Seq("c", "b", "a")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) checkAnswer( df.selectExpr("sort_array(a)", "sort_array(b)"), Seq( Row(Seq(1, 2, 3), Seq("a", "b", "c")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) checkAnswer( df.selectExpr("sort_array(a, true)", "sort_array(b, false)"), Seq( Row(Seq(1, 2, 3), Seq("c", "b", "a")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) @@ -324,6 +324,30 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(intercept[AnalysisException] { df3.selectExpr("sort_array(a)").collect() }.getMessage().contains("only supports array input")) + + checkAnswer( + df.select(array_sort($"a"), array_sort($"b")), + Seq( + Row(Seq(1, 2, 3), Seq("a", "b", "c")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + checkAnswer( + df.selectExpr("array_sort(a)", "array_sort(b)"), + Seq( + Row(Seq(1, 2, 3), Seq("a", "b", "c")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + + checkAnswer( + df2.selectExpr("array_sort(a)"), + Seq(Row(Seq[Seq[Int]](Seq(1), Seq(2), Seq(2, 4), null))) + ) + + assert(intercept[AnalysisException] { + df3.selectExpr("array_sort(a)").collect() + }.getMessage().contains("only supports array input")) } test("array size function") { From d2aa859b4faeda03e32a7574dd0c5b4ed367fae4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 7 May 2018 14:34:03 +0800 Subject: [PATCH 0744/2461] [SPARK-24160] ShuffleBlockFetcherIterator should fail if it receives zero-size blocks ## What changes were proposed in this pull request? This patch modifies `ShuffleBlockFetcherIterator` so that the receipt of zero-size blocks is treated as an error. This is done as a preventative measure to guard against a potential source of data loss bugs. In the shuffle layer, we guarantee that zero-size blocks will never be requested (a block containing zero records is always 0 bytes in size and is marked as empty such that it will never be legitimately requested by executors). However, the existing code does not fully take advantage of this invariant in the shuffle-read path: the existing code did not explicitly check whether blocks are non-zero-size. Additionally, our decompression and deserialization streams treat zero-size inputs as empty streams rather than errors (EOF might actually be treated as "end-of-stream" in certain layers (longstanding behavior dating to earliest versions of Spark) and decompressors like Snappy may be tolerant to zero-size inputs). As a result, if some other bug causes legitimate buffers to be replaced with zero-sized buffers (due to corruption on either the send or receive sides) then this would translate into silent data loss rather than an explicit fail-fast error. This patch addresses this problem by adding a `buf.size != 0` check. See code comments for pointers to tests which guarantee the invariants relied on here. ## How was this patch tested? Existing tests (which required modifications, since some were creating empty buffers in mocks). I also added a test to make sure we fail on zero-size blocks. To test that the zero-size blocks are indeed a potential corruption source, I manually ran a workload in `spark-shell` with a modified build which replaces all buffers with zero-size buffers in the receive path. Author: Josh Rosen Closes #21219 from JoshRosen/SPARK-24160. --- .../storage/ShuffleBlockFetcherIterator.scala | 19 +++++ .../ShuffleBlockFetcherIteratorSuite.scala | 71 +++++++++++++------ 2 files changed, 70 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 6971efd2504c2..b31862323a895 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -414,6 +414,25 @@ final class ShuffleBlockFetcherIterator( logDebug("Number of requests in flight " + reqsInFlight) } + if (buf.size == 0) { + // We will never legitimately receive a zero-size block. All blocks with zero records + // have zero size and all zero-size blocks have no records (and hence should never + // have been requested in the first place). This statement relies on behaviors of the + // shuffle writers, which are guaranteed by the following test cases: + // + // - BypassMergeSortShuffleWriterSuite: "write with some empty partitions" + // - UnsafeShuffleWriterSuite: "writeEmptyIterator" + // - DiskBlockObjectWriterSuite: "commit() and close() without ever opening or writing" + // + // There is not an explicit test for SortShuffleWriter but the underlying APIs that + // uses are shared by the UnsafeShuffleWriter (both writers use DiskBlockObjectWriter + // which returns a zero-size from commitAndGet() in case no records were written + // since the last call. + val msg = s"Received a zero-size buffer for block $blockId from $address " + + s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)" + throwFetchFailedException(blockId, address, new IOException(msg)) + } + val in = try { buf.createInputStream() } catch { diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index cefebfa51b8b9..8e9374b768adc 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -65,12 +65,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } // Create a mock managed buffer for testing - def createMockManagedBuffer(): ManagedBuffer = { + def createMockManagedBuffer(size: Int = 1): ManagedBuffer = { val mockManagedBuffer = mock(classOf[ManagedBuffer]) val in = mock(classOf[InputStream]) when(in.read(any())).thenReturn(1) when(in.read(any(), any(), any())).thenReturn(1) when(mockManagedBuffer.createInputStream()).thenReturn(in) + when(mockManagedBuffer.size()).thenReturn(size) mockManagedBuffer } @@ -269,6 +270,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } } + private def mockCorruptBuffer(size: Long = 1L): ManagedBuffer = { + val corruptStream = mock(classOf[InputStream]) + when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) + val corruptBuffer = mock(classOf[ManagedBuffer]) + when(corruptBuffer.size()).thenReturn(size) + when(corruptBuffer.createInputStream()).thenReturn(corruptStream) + corruptBuffer + } + test("retry corrupt blocks") { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) @@ -284,11 +294,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) - - val corruptStream = mock(classOf[InputStream]) - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) - val corruptBuffer = mock(classOf[ManagedBuffer]) - when(corruptBuffer.createInputStream()).thenReturn(corruptStream) val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) val transfer = mock(classOf[BlockTransferService]) @@ -301,7 +306,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) listener.onBlockFetchSuccess( ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer) sem.release() @@ -339,7 +344,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Future { // Return the first block, and then fail. listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) sem.release() } } @@ -353,11 +358,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("big blocks are not checked for corruption") { - val corruptStream = mock(classOf[InputStream]) - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) - val corruptBuffer = mock(classOf[ManagedBuffer]) - when(corruptBuffer.createInputStream()).thenReturn(corruptStream) - doReturn(10000L).when(corruptBuffer).size() + val corruptBuffer = mockCorruptBuffer(10000L) val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) @@ -413,11 +414,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) - val corruptStream = mock(classOf[InputStream]) - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) - val corruptBuffer = mock(classOf[ManagedBuffer]) - when(corruptBuffer.createInputStream()).thenReturn(corruptStream) - val transfer = mock(classOf[BlockTransferService]) when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { @@ -428,9 +424,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) listener.onBlockFetchSuccess( - ShuffleBlockId(0, 2, 0).toString, corruptBuffer) + ShuffleBlockId(0, 2, 0).toString, mockCorruptBuffer()) sem.release() } } @@ -527,4 +523,39 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // shuffle block to disk. assert(tempFileManager != null) } + + test("fail zero-size blocks") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer() + ) + + val transfer = createMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0))) + + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + + val taskContext = TaskContext.empty() + val iterator = new ShuffleBlockFetcherIterator( + taskContext, + transfer, + blockManager, + blocksByAddress, + (_, in) => in, + 48 * 1024 * 1024, + Int.MaxValue, + Int.MaxValue, + Int.MaxValue, + true) + + // All blocks fetched return zero length and should trigger a receive-side error: + val e = intercept[FetchFailedException] { iterator.next() } + assert(e.getMessage.contains("Received a zero-size buffer")) + } } From c5981976f1d514a3ad8a684b9a21cebe38b786fa Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Mon, 7 May 2018 14:45:14 +0800 Subject: [PATCH 0745/2461] [SPARK-23775][TEST] Make DataFrameRangeSuite not flaky ## What changes were proposed in this pull request? DataFrameRangeSuite.test("Cancelling stage in a query with Range.") stays sometimes in an infinite loop and times out the build. There were multiple issues with the test: 1. The first valid stageId is zero when the test started alone and not in a suite and the following code waits until timeout: ``` eventually(timeout(10.seconds), interval(1.millis)) { assert(DataFrameRangeSuite.stageToKill > 0) } ``` 2. The `DataFrameRangeSuite.stageToKill` was overwritten by the task's thread after the reset which ended up in canceling the same stage 2 times. This caused the infinite wait. This PR solves this mentioned flakyness by removing the shared `DataFrameRangeSuite.stageToKill` and using `onTaskStart` where stage ID is provided. In order to make sure cancelStage called for all stages `waitUntilEmpty` is called on `ListenerBus`. In [PR20888](https://github.com/apache/spark/pull/20888) this tried to get solved by: * Stopping the executor thread with `wait` * Wait for all `cancelStage` called * Kill the executor thread by setting `SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL` but the thread killing left the shared `SparkContext` sometimes in a state where further jobs can't be submitted. As a result DataFrameRangeSuite.test("Cancelling stage in a query with Range.") test passed properly but the next test inside the suite was hanging. ## How was this patch tested? Existing unit test executed 10k times. Author: Gabor Somogyi Closes #21214 from gaborgsomogyi/SPARK-23775_1. --- .../spark/sql/DataFrameRangeSuite.scala | 24 +++++++------------ 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index 57a930dfaf320..b0b46640ff317 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -23,8 +23,8 @@ import scala.util.Random import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkException, TaskContext} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} +import org.apache.spark.SparkException +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -153,23 +153,17 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall test("Cancelling stage in a query with Range.") { val listener = new SparkListener { - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - eventually(timeout(10.seconds), interval(1.millis)) { - assert(DataFrameRangeSuite.stageToKill > 0) - } - sparkContext.cancelStage(DataFrameRangeSuite.stageToKill) + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + sparkContext.cancelStage(taskStart.stageId) } } sparkContext.addSparkListener(listener) for (codegen <- Seq(true, false)) { withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { - DataFrameRangeSuite.stageToKill = -1 val ex = intercept[SparkException] { - spark.range(0, 100000000000L, 1, 1).map { x => - DataFrameRangeSuite.stageToKill = TaskContext.get().stageId() - x - }.toDF("id").agg(sum("id")).collect() + spark.range(0, 100000000000L, 1, 1) + .toDF("id").agg(sum("id")).collect() } ex.getCause() match { case null => @@ -180,6 +174,8 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") } } + // Wait until all ListenerBus events consumed to make sure cancelStage called for all stages + sparkContext.listenerBus.waitUntilEmpty(20.seconds.toMillis) eventually(timeout(20.seconds)) { assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) } @@ -204,7 +200,3 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } } } - -object DataFrameRangeSuite { - @volatile var stageToKill = -1 -} From f06528015d5856d6dc5cce00309bc2ae985e080f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 7 May 2018 15:42:10 +0800 Subject: [PATCH 0746/2461] [SPARK-24160][FOLLOWUP] Fix compilation failure ## What changes were proposed in this pull request? SPARK-24160 is causing a compilation failure (after SPARK-24143 was merged). This fixes the issue. ## How was this patch tested? building successfully Author: Marco Gaido Closes #21256 from mgaido91/SPARK-24160_FOLLOWUP. --- .../apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 8e9374b768adc..a2997dbd1b1ac 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -546,7 +546,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT taskContext, transfer, blockManager, - blocksByAddress, + blocksByAddress.toIterator, (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, From e35ad3caddeaa4b0d4c8524dcfb9e9f56dc7fe3d Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 7 May 2018 16:57:37 +0900 Subject: [PATCH 0747/2461] [SPARK-23930][SQL] Add slice function ## What changes were proposed in this pull request? The PR add the `slice` function. The behavior of the function is based on Presto's one. The function slices an array according to the requested start index and length. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21040 from mgaido91/SPARK-23930. --- python/pyspark/sql/functions.py | 13 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/codegen/CodeGenerator.scala | 34 ++++ .../expressions/collectionOperations.scala | 163 ++++++++++++++---- .../CollectionExpressionsSuite.scala | 28 +++ .../expressions/ExpressionEvalHelper.scala | 6 + .../expressions/ObjectExpressionsSuite.scala | 1 - .../org/apache/spark/sql/functions.scala | 10 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 16 ++ 9 files changed, 233 insertions(+), 39 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index bd55b5f73b4d0..ac3c79766702c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1834,6 +1834,19 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(2.4) +def slice(x, start, length): + """ + Collection function: returns an array containing all the elements in `x` from index `start` + (or starting from the end if `start` is negative) with the specified `length`. + >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x']) + >>> df.select(slice(df.x, 2, 2).alias("sliced")).collect() + [Row(sliced=[2, 3]), Row(sliced=[5])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.slice(_to_java_column(x), start, length)) + + @ignore_unicode_prefix @since(2.4) def array_join(col, delimiter, null_replacement=None): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 01776b85e6f53..87b0911e150c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -410,6 +410,7 @@ object FunctionRegistry { expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[Size]("size"), + expression[Slice]("slice"), expression[Size]("cardinality"), expression[SortArray]("sort_array"), expression[ArrayMin]("array_min"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index cf0a91ff00626..4dda525294259 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -42,6 +42,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types._ import org.apache.spark.util.{ParentClassLoader, Utils} @@ -730,6 +731,39 @@ class CodegenContext { """.stripMargin } + /** + * Generates code creating a [[UnsafeArrayData]]. + * + * @param arrayName name of the array to create + * @param numElements code representing the number of elements the array should contain + * @param elementType data type of the elements in the array + * @param additionalErrorMessage string to include in the error message + */ + def createUnsafeArray( + arrayName: String, + numElements: String, + elementType: DataType, + additionalErrorMessage: String): String = { + val arraySize = freshName("size") + val arrayBytes = freshName("arrayBytes") + + s""" + |long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElements, + | ${elementType.defaultSize}); + |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful try create array with " + $arraySize + + | " bytes of data due to exceeding the limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} bytes for UnsafeArrayData." + + | "$additionalErrorMessage"); + |} + |byte[] $arrayBytes = new byte[(int)$arraySize]; + |UnsafeArrayData $arrayName = new UnsafeArrayData(); + |Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements); + |$arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize); + """.stripMargin + } + /** * Generates code to do null safe execution, i.e. only execute the code when the input is not * null by adding null check if necessary. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 23c09bc3b49d7..12b9ab2b272ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -530,6 +529,129 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(x, start, length) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def children: Seq[Expression] = Seq(x, start, length) + + lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { + val startInt = startVal.asInstanceOf[Int] + val lengthInt = lengthVal.asInstanceOf[Int] + val arr = xVal.asInstanceOf[ArrayData] + val startIndex = if (startInt == 0) { + throw new RuntimeException( + s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") + } else if (startInt < 0) { + startInt + arr.numElements() + } else { + startInt - 1 + } + if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + + "length must be greater than or equal to 0.") + } + // startIndex can be negative if start is negative and its absolute value is greater than the + // number of elements in the array + if (startIndex < 0 || startIndex >= arr.numElements()) { + return new GenericArrayData(Array.empty[AnyRef]) + } + val data = arr.toSeq[AnyRef](elementType) + new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (x, start, length) => { + val startIdx = ctx.freshName("startIdx") + val resLength = ctx.freshName("resLength") + val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + s""" + |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; + |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + |if ($start == 0) { + | throw new RuntimeException("Unexpected value for start in function $prettyName: " + | + "SQL array indices start at 1."); + |} else if ($start < 0) { + | $startIdx = $start + $x.numElements(); + |} else { + | // arrays in SQL are 1-based instead of 0-based + | $startIdx = $start - 1; + |} + |if ($length < 0) { + | throw new RuntimeException("Unexpected value for length in function $prettyName: " + | + "length must be greater than or equal to 0."); + |} else if ($length > $x.numElements() - $startIdx) { + | $resLength = $x.numElements() - $startIdx; + |} else { + | $resLength = $length; + |} + |${genCodeForResult(ctx, ev, x, startIdx, resLength)} + """.stripMargin + }) + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + startIdx: String, + resLength: String): String = { + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx") + if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + s""" + |Object[] $values; + |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { + | $values = new Object[0]; + |} else { + | $values = new Object[$resLength]; + | for (int $i = 0; $i < $resLength; $i ++) { + | $values[$i] = $getValue; + | } + |} + |${ev.value} = new $arrayClass($values); + """.stripMargin + } else { + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + s""" + |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { + | $resLength = 0; + |} + |${ctx.createUnsafeArray(values, resLength, elementType, s" $prettyName failed.")} + |for (int $i = 0; $i < $resLength; $i ++) { + | if ($inputArray.isNullAt($i + $startIdx)) { + | $values.setNullAt($i); + | } else { + | $values.set$primitiveValueTypeName($i, $getValue); + | } + |} + |${ev.value} = $values; + """.stripMargin + } + } +} + /** * Creates a String containing all the elements of the input array separated by the delimiter. */ @@ -1127,24 +1249,11 @@ case class Concat(children: Seq[Expression]) extends Expression { } private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { - val arrayName = ctx.freshName("array") - val arraySizeName = ctx.freshName("size") val counter = ctx.freshName("counter") val arrayData = ctx.freshName("arrayData") val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) - val unsafeArraySizeInBytes = s""" - |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( - | $numElemName, - | ${elementType.defaultSize}); - |if ($arraySizeName > $MAX_ARRAY_LENGTH) { - | throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName + - | " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" + - | " for UnsafeArrayData."); - |} - """.stripMargin - val baseOffset = Platform.BYTE_ARRAY_OFFSET val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) s""" @@ -1152,11 +1261,7 @@ case class Concat(children: Seq[Expression]) extends Expression { | public ArrayData concat($javaType[] args) { | ${nullArgumentProtection()} | $numElemCode - | $unsafeArraySizeInBytes - | byte[] $arrayName = new byte[(int)$arraySizeName]; - | UnsafeArrayData $arrayData = new UnsafeArrayData(); - | Platform.putLong($arrayName, $baseOffset, $numElemName); - | $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName); + | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} | int $counter = 0; | for (int y = 0; y < ${children.length}; y++) { | for (int z = 0; z < args[y].numElements(); z++) { @@ -1308,34 +1413,16 @@ case class Flatten(child: Expression) extends UnaryExpression { ctx: CodegenContext, childVariableName: String, arrayDataName: String): String = { - val arrayName = ctx.freshName("array") - val arraySizeName = ctx.freshName("size") val counter = ctx.freshName("counter") val tempArrayDataName = ctx.freshName("tempArrayData") val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) - val unsafeArraySizeInBytes = s""" - |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( - | $numElemName, - | ${elementType.defaultSize}); - |if ($arraySizeName > $MAX_ARRAY_LENGTH) { - | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + - | $arraySizeName + " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH" + - | " bytes for UnsafeArrayData."); - |} - """.stripMargin - val baseOffset = Platform.BYTE_ARRAY_OFFSET - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) s""" |$numElemCode - |$unsafeArraySizeInBytes - |byte[] $arrayName = new byte[(int)$arraySizeName]; - |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData(); - |Platform.putLong($arrayName, $baseOffset, $numElemName); - |$tempArrayDataName.pointTo($arrayName, $baseOffset, (int)$arraySizeName); + |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, s" $prettyName failed.")} |int $counter = 0; |for (int k = 0; k < $childVariableName.numElements(); k++) { | ArrayData arr = $childVariableName.getArray(k); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 749374f1a14a1..a2851d071c7c6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -136,6 +136,34 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + test("Slice") { + val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType)) + val a2 = Literal.create(Seq[String]("", null, "a", "b"), ArrayType(StringType)) + val a3 = Literal.create(Seq(1, 2, null, 4), ArrayType(IntegerType)) + + checkEvaluation(Slice(a0, Literal(1), Literal(2)), Seq(1, 2)) + checkEvaluation(Slice(a0, Literal(-3), Literal(2)), Seq(4, 5)) + checkEvaluation(Slice(a0, Literal(4), Literal(10)), Seq(4, 5, 6)) + checkEvaluation(Slice(a0, Literal(-1), Literal(2)), Seq(6)) + checkExceptionInExpression[RuntimeException](Slice(a0, Literal(1), Literal(-1)), + "Unexpected value for length") + checkExceptionInExpression[RuntimeException](Slice(a0, Literal(0), Literal(1)), + "Unexpected value for start") + checkEvaluation(Slice(a0, Literal(-20), Literal(1)), Seq.empty[Int]) + checkEvaluation(Slice(a1, Literal(-20), Literal(1)), Seq.empty[String]) + checkEvaluation(Slice(a0, Literal.create(null, IntegerType), Literal(2)), null) + checkEvaluation(Slice(a0, Literal(2), Literal.create(null, IntegerType)), null) + checkEvaluation(Slice(Literal.create(null, ArrayType(IntegerType)), Literal(1), Literal(2)), + null) + + checkEvaluation(Slice(a1, Literal(1), Literal(2)), Seq("a", "b")) + checkEvaluation(Slice(a2, Literal(1), Literal(2)), Seq("", null)) + checkEvaluation(Slice(a0, Literal(10), Literal(1)), Seq.empty[Int]) + checkEvaluation(Slice(a1, Literal(10), Literal(1)), Seq.empty[String]) + checkEvaluation(Slice(a3, Literal(2), Literal(3)), Seq(2, null, 4)) + } + test("ArrayJoin") { def testArrays( arrays: Seq[Expression], diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index b4bf6d7107d7e..a22e9d4655e8c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -104,6 +104,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } + protected def checkExceptionInExpression[T <: Throwable : ClassTag]( + expression: => Expression, + expectedErrMsg: String): Unit = { + checkExceptionInExpression[T](expression, InternalRow.empty, expectedErrMsg) + } + protected def checkExceptionInExpression[T <: Throwable : ClassTag]( expression: => Expression, inputRow: InternalRow, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 730b36c32333c..77ca640f2e0bd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -223,7 +223,6 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal.fromObject(new java.util.LinkedList[Int]), Map("nonexisting" -> Literal(1))) checkExceptionInExpression[Exception](initializeWithNonexistingMethod, - InternalRow.fromSeq(Seq()), """A method named "nonexisting" is not declared in any enclosing class """ + "nor any supertype") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 10b6dcc0608c2..8f9e4ae18b3f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3039,6 +3039,16 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Returns an array containing all the elements in `x` from index `start` (or starting from the + * end if `start` is negative) with the specified `length`. + * @group collection_funcs + * @since 2.4.0 + */ + def slice(x: Column, start: Int, length: Int): Column = withExpr { + Slice(x.expr, Literal(start), Literal(length)) + } + /** * Concatenates the elements of `column` using the `delimiter`. Null values are replaced with * `nullReplacement`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index ae21cbc802d0a..ecce06f4c0755 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -442,6 +442,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("slice function") { + val df = Seq( + Seq(1, 2, 3), + Seq(4, 5) + ).toDF("x") + + val answer = Seq(Row(Seq(2, 3)), Row(Seq(5))) + + checkAnswer(df.select(slice(df("x"), 2, 2)), answer) + checkAnswer(df.selectExpr("slice(x, 2, 2)"), answer) + + val answerNegative = Seq(Row(Seq(3)), Row(Seq(5))) + checkAnswer(df.select(slice(df("x"), -1, 1)), answerNegative) + checkAnswer(df.selectExpr("slice(x, -1, 1)"), answerNegative) + } + test("array_join function") { val df = Seq( (Seq[String]("a", "b"), ","), From 4e861db5f149e10fd8dfe6b3c1484821a590b1e8 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 7 May 2018 11:21:22 +0200 Subject: [PATCH 0748/2461] [SPARK-16406][SQL] Improve performance of LogicalPlan.resolve ## What changes were proposed in this pull request? `LogicalPlan.resolve(...)` uses linear searches to find an attribute matching a name. This is fine in normal cases, but gets problematic when you try to resolve a large number of columns on a plan with a large number of attributes. This PR adds an indexing structure to `resolve(...)` in order to find potential matches quicker. This PR improves the reference resolution time for the following code by 4x (11.8s -> 2.4s): ``` scala val n = 4000 val values = (1 to n).map(_.toString).mkString(", ") val columns = (1 to n).map("column" + _).mkString(", ") val query = s""" |SELECT $columns |FROM VALUES ($values) T($columns) |WHERE 1=2 AND 1 IN ($columns) |GROUP BY $columns |ORDER BY $columns |""".stripMargin spark.time(sql(query)) ``` ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #14083 from hvanhovell/SPARK-16406. --- .../sql/catalyst/expressions/package.scala | 86 ++++++++++++++ .../catalyst/plans/logical/LogicalPlan.scala | 108 ++---------------- 2 files changed, 93 insertions(+), 101 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 1a48995358af7..8a06daa37132d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql.catalyst +import java.util.Locale + import com.google.common.collect.Maps +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -138,6 +142,88 @@ package object expressions { def indexOf(exprId: ExprId): Int = { Option(exprIdToOrdinal.get(exprId)).getOrElse(-1) } + + private def unique[T](m: Map[T, Seq[Attribute]]): Map[T, Seq[Attribute]] = { + m.mapValues(_.distinct).map(identity) + } + + /** Map to use for direct case insensitive attribute lookups. */ + @transient private lazy val direct: Map[String, Seq[Attribute]] = { + unique(attrs.groupBy(_.name.toLowerCase(Locale.ROOT))) + } + + /** Map to use for qualified case insensitive attribute lookups. */ + @transient private val qualified: Map[(String, String), Seq[Attribute]] = { + val grouped = attrs.filter(_.qualifier.isDefined).groupBy { a => + (a.qualifier.get.toLowerCase(Locale.ROOT), a.name.toLowerCase(Locale.ROOT)) + } + unique(grouped) + } + + /** Perform attribute resolution given a name and a resolver. */ + def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = { + // Collect matching attributes given a name and a lookup. + def collectMatches(name: String, candidates: Option[Seq[Attribute]]): Seq[Attribute] = { + candidates.toSeq.flatMap(_.collect { + case a if resolver(a.name, name) => a.withName(name) + }) + } + + // Find matches for the given name assuming that the 1st part is a qualifier (i.e. table name, + // alias, or subquery alias) and the 2nd part is the actual name. This returns a tuple of + // matched attributes and a list of parts that are to be resolved. + // + // For example, consider an example where "a" is the table name, "b" is the column name, + // and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b", + // and the second element will be List("c"). + val matches = nameParts match { + case qualifier +: name +: nestedFields => + val key = (qualifier.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT)) + val attributes = collectMatches(name, qualified.get(key)).filter { a => + resolver(qualifier, a.qualifier.get) + } + (attributes, nestedFields) + case all => + (Nil, all) + } + + // If none of attributes match `table.column` pattern, we try to resolve it as a column. + val (candidates, nestedFields) = matches match { + case (Seq(), _) => + val name = nameParts.head + val attributes = collectMatches(name, direct.get(name.toLowerCase(Locale.ROOT))) + (attributes, nameParts.tail) + case _ => matches + } + + def name = UnresolvedAttribute(nameParts).name + candidates match { + case Seq(a) if nestedFields.nonEmpty => + // One match, but we also need to extract the requested nested field. + // The foldLeft adds ExtractValues for every remaining parts of the identifier, + // and aliased it with the last part of the name. + // For example, consider "a.b.c", where "a" is resolved to an existing attribute. + // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final + // expression as "c". + val fieldExprs = nestedFields.foldLeft(a: Expression) { (e, name) => + ExtractValue(e, Literal(name), resolver) + } + Some(Alias(fieldExprs, nestedFields.last)()) + + case Seq(a) => + // One match, no nested fields, use it. + Some(a) + + case Seq() => + // No matches. + None + + case ambiguousReferences => + // More than one match. + val referenceNames = ambiguousReferences.map(_.qualifiedName).mkString(", ") + throw new AnalysisException(s"Reference '$name' is ambiguous, could be: $referenceNames.") + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 42034403d6d03..e487693927ab6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -86,6 +86,10 @@ abstract class LogicalPlan } } + private[this] lazy val childAttributes = AttributeSeq(children.flatMap(_.output)) + + private[this] lazy val outputAttributes = AttributeSeq(output) + /** * Optionally resolves the given strings to a [[NamedExpression]] using the input from all child * nodes of this LogicalPlan. The attribute is expressed as @@ -94,7 +98,7 @@ abstract class LogicalPlan def resolveChildren( nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = - resolve(nameParts, children.flatMap(_.output), resolver) + childAttributes.resolve(nameParts, resolver) /** * Optionally resolves the given strings to a [[NamedExpression]] based on the output of this @@ -104,7 +108,7 @@ abstract class LogicalPlan def resolve( nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = - resolve(nameParts, output, resolver) + outputAttributes.resolve(nameParts, resolver) /** * Given an attribute name, split it to name parts by dot, but @@ -114,105 +118,7 @@ abstract class LogicalPlan def resolveQuoted( name: String, resolver: Resolver): Option[NamedExpression] = { - resolve(UnresolvedAttribute.parseAttributeName(name), output, resolver) - } - - /** - * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. - * - * This assumes `name` has multiple parts, where the 1st part is a qualifier - * (i.e. table name, alias, or subquery alias). - * See the comment above `candidates` variable in resolve() for semantics the returned data. - */ - private def resolveAsTableColumn( - nameParts: Seq[String], - resolver: Resolver, - attribute: Attribute): Option[(Attribute, List[String])] = { - assert(nameParts.length > 1) - if (attribute.qualifier.exists(resolver(_, nameParts.head))) { - // At least one qualifier matches. See if remaining parts match. - val remainingParts = nameParts.tail - resolveAsColumn(remainingParts, resolver, attribute) - } else { - None - } - } - - /** - * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. - * - * Different from resolveAsTableColumn, this assumes `name` does NOT start with a qualifier. - * See the comment above `candidates` variable in resolve() for semantics the returned data. - */ - private def resolveAsColumn( - nameParts: Seq[String], - resolver: Resolver, - attribute: Attribute): Option[(Attribute, List[String])] = { - if (resolver(attribute.name, nameParts.head)) { - Option((attribute.withName(nameParts.head), nameParts.tail.toList)) - } else { - None - } - } - - /** Performs attribute resolution given a name and a sequence of possible attributes. */ - protected def resolve( - nameParts: Seq[String], - input: Seq[Attribute], - resolver: Resolver): Option[NamedExpression] = { - - // A sequence of possible candidate matches. - // Each candidate is a tuple. The first element is a resolved attribute, followed by a list - // of parts that are to be resolved. - // For example, consider an example where "a" is the table name, "b" is the column name, - // and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b", - // and the second element will be List("c"). - var candidates: Seq[(Attribute, List[String])] = { - // If the name has 2 or more parts, try to resolve it as `table.column` first. - if (nameParts.length > 1) { - input.flatMap { option => - resolveAsTableColumn(nameParts, resolver, option) - } - } else { - Seq.empty - } - } - - // If none of attributes match `table.column` pattern, we try to resolve it as a column. - if (candidates.isEmpty) { - candidates = input.flatMap { candidate => - resolveAsColumn(nameParts, resolver, candidate) - } - } - - def name = UnresolvedAttribute(nameParts).name - - candidates.distinct match { - // One match, no nested fields, use it. - case Seq((a, Nil)) => Some(a) - - // One match, but we also need to extract the requested nested field. - case Seq((a, nestedFields)) => - // The foldLeft adds ExtractValues for every remaining parts of the identifier, - // and aliased it with the last part of the name. - // For example, consider "a.b.c", where "a" is resolved to an existing attribute. - // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final - // expression as "c". - val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) => - ExtractValue(expr, Literal(fieldName), resolver)) - Some(Alias(fieldExprs, nestedFields.last)()) - - // No matches. - case Seq() => - logTrace(s"Could not find $name in ${input.mkString(", ")}") - None - - // More than one match. - case ambiguousReferences => - val referenceNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ") - throw new AnalysisException( - s"Reference '$name' is ambiguous, could be: $referenceNames.") - } + outputAttributes.resolve(UnresolvedAttribute.parseAttributeName(name), resolver) } /** From d83e9637246b05eea202add07a168688f6c0481b Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Mon, 7 May 2018 17:54:39 +0200 Subject: [PATCH 0749/2461] [SPARK-24043][SQL] Interpreted Predicate should initialize nondeterministic expressions ## What changes were proposed in this pull request? When creating an InterpretedPredicate instance, initialize any Nondeterministic expressions in the expression tree to avoid java.lang.IllegalArgumentException on later call to eval(). ## How was this patch tested? - sbt SQL tests - python SQL tests - new unit test Author: Bruce Robbins Closes #21144 from bersprockets/interpretedpredicate. --- .../spark/sql/catalyst/expressions/predicates.scala | 8 ++++++++ .../spark/sql/catalyst/expressions/PredicateSuite.scala | 6 ++++++ 2 files changed, 14 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index e195ec17f3bcf..f8c6dc4e6adc9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -36,6 +36,14 @@ object InterpretedPredicate { case class InterpretedPredicate(expression: Expression) extends BasePredicate { override def eval(r: InternalRow): Boolean = expression.eval(r).asInstanceOf[Boolean] + + override def initialize(partitionIndex: Int): Unit = { + super.initialize(partitionIndex) + expression.foreach { + case n: Nondeterministic => n.initialize(partitionIndex) + case _ => + } + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 1bfd180ae4393..ac76b17ef4761 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -449,4 +449,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(EqualNullSafe(Literal(null, DoubleType), Literal(-1.0d)), false) checkEvaluation(EqualNullSafe(Literal(-1.0d), Literal(null, DoubleType)), false) } + + test("Interpreted Predicate should initialize nondeterministic expressions") { + val interpreted = InterpretedPredicate.create(LessThan(Rand(7), Literal(1.0))) + interpreted.initialize(0) + assert(interpreted.eval(new UnsafeRow())) + } } From 56a52e0a58fc82ea69e47d0d8c4f905565be7c8b Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Mon, 7 May 2018 14:47:58 -0700 Subject: [PATCH 0750/2461] [SPARK-15750][MLLIB][PYSPARK] Constructing FPGrowth fails when no numPartitions specified in pyspark ## What changes were proposed in this pull request? Change FPGrowth from private to private[spark]. If no numPartitions is specified, then default value -1 is used. But -1 is only valid in the construction function of FPGrowth, but not in setNumPartitions. So I make this change and use the constructor directly rather than using set method. ## How was this patch tested? Unit test is added Author: Jeff Zhang Closes #13493 from zjffdu/SPARK-15750. --- .../spark/mllib/api/python/PythonMLLibAPI.scala | 5 +---- .../scala/org/apache/spark/mllib/fpm/FPGrowth.scala | 2 +- python/pyspark/mllib/tests.py | 12 ++++++++++++ 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index b32d3f252ae59..db3f074ecfbac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -572,10 +572,7 @@ private[python] class PythonMLLibAPI extends Serializable { data: JavaRDD[java.lang.Iterable[Any]], minSupport: Double, numPartitions: Int): FPGrowthModel[Any] = { - val fpg = new FPGrowth() - .setMinSupport(minSupport) - .setNumPartitions(numPartitions) - + val fpg = new FPGrowth(minSupport, numPartitions) val model = fpg.run(data.rdd.map(_.asScala.toArray)) new FPGrowthModelWrapper(model) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index f6b1143272d16..4f2b7e6f0764e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -162,7 +162,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { * */ @Since("1.3.0") -class FPGrowth private ( +class FPGrowth private[spark] ( private var minSupport: Double, private var numPartitions: Int) extends Logging with Serializable { diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 14d788b0bef60..4c2ce137e331c 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -57,6 +57,7 @@ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT from pyspark.mllib.linalg.distributed import RowMatrix from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD +from pyspark.mllib.fpm import FPGrowth from pyspark.mllib.recommendation import Rating from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD from pyspark.mllib.random import RandomRDDs @@ -1762,6 +1763,17 @@ def test_pca(self): self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1]) +class FPGrowthTest(MLlibTestCase): + + def test_fpgrowth(self): + data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] + rdd = self.sc.parallelize(data, 2) + model1 = FPGrowth.train(rdd, 0.6, 2) + # use default data partition number when numPartitions is not specified + model2 = FPGrowth.train(rdd, 0.6) + self.assertEqual(sorted(model1.freqItemsets().collect()), + sorted(model2.freqItemsets().collect())) + if __name__ == "__main__": from pyspark.mllib.tests import * if not _have_scipy: From 1c9c5de951ed86290bcd7d8edaab952b8cacd290 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 7 May 2018 14:52:14 -0700 Subject: [PATCH 0751/2461] [SPARK-23291][SPARK-23291][R][FOLLOWUP] Update SparkR migration note for ## What changes were proposed in this pull request? This PR fixes the migration note for SPARK-23291 since it's going to backport to 2.3.1. See the discussion in https://issues.apache.org/jira/browse/SPARK-23291 ## How was this patch tested? N/A Author: hyukjinkwon Closes #21249 from HyukjinKwon/SPARK-23291. --- docs/sparkr.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sparkr.md b/docs/sparkr.md index 7fabab5d38f16..4faad2c4c1824 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -664,6 +664,6 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma - For `summary`, option for statistics to compute has been added. Its output is changed from that from `describe`. - A warning can be raised if versions of SparkR package and the Spark JVM do not match. -## Upgrading to Spark 2.4.0 +## Upgrading to SparkR 2.3.1 and above - - The `start` parameter of `substr` method was wrongly subtracted by one, previously. In other words, the index specified by `start` parameter was considered as 0-base. This can lead to inconsistent substring results and also does not match with the behaviour with `substr` in R. It has been fixed so the `start` parameter of `substr` method is now 1-base, e.g., therefore to get the same result as `substr(df$a, 2, 5)`, it should be changed to `substr(df$a, 1, 4)`. + - In SparkR 2.3.0 and earlier, the `start` parameter of `substr` method was wrongly subtracted by one and considered as 0-based. This can lead to inconsistent substring results and also does not match with the behaviour with `substr` in R. In version 2.3.1 and later, it has been fixed so the `start` parameter of `substr` method is now 1-base. As an example, `substr(lit('abcdef'), 2, 4))` would result to `abc` in SparkR 2.3.0, and the result would be `bcd` in SparkR 2.3.1. From f48bd6bdc5aefd9ec43e2d0ee648d17add7ef554 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 7 May 2018 14:55:41 -0700 Subject: [PATCH 0752/2461] [SPARK-22885][ML][TEST] ML test for StructuredStreaming: spark.ml.tuning ## What changes were proposed in this pull request? ML test for StructuredStreaming: spark.ml.tuning ## How was this patch tested? N/A Author: WeichenXu Closes #20261 from WeichenXu123/ml_stream_tuning_test. --- .../spark/ml/tuning/CrossValidatorSuite.scala | 15 +++++++++++---- .../ml/tuning/TrainValidationSplitSuite.scala | 15 +++++++++++---- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 15dade2627090..e6ee7220d2279 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -25,17 +25,17 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, MulticlassClassificationEvaluator, RegressionEvaluator} import org.apache.spark.ml.feature.HashingTF -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} -import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType class CrossValidatorSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -66,6 +66,13 @@ class CrossValidatorSuite assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) assert(cvModel.avgMetrics.length === lrParamMaps.length) + + val result = cvModel.transform(dataset).select("prediction").as[Double].collect() + testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), cvModel, "prediction") { + rows => + val result2 = rows.map(_.getDouble(0)) + assert(result === result2) + } } test("cross validation with linear regression") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 9024342d9c831..cd76acf9c67bc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -24,17 +24,17 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest} import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} -import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType class TrainValidationSplitSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -64,6 +64,13 @@ class TrainValidationSplitSuite assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) assert(tvsModel.validationMetrics.length === lrParamMaps.length) + + val result = tvsModel.transform(dataset).select("prediction").as[Double].collect() + testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), tvsModel, "prediction") { + rows => + val result2 = rows.map(_.getDouble(0)) + assert(result === result2) + } } test("train validation with linear regression") { From 76ecd095024a658bf68e5db658e4416565b30c17 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 7 May 2018 14:57:14 -0700 Subject: [PATCH 0753/2461] [SPARK-20114][ML] spark.ml parity for sequential pattern mining - PrefixSpan ## What changes were proposed in this pull request? PrefixSpan API for spark.ml. New implementation instead of #20810 ## How was this patch tested? TestSuite added. Author: WeichenXu Closes #20973 from WeichenXu123/prefixSpan2. --- .../org/apache/spark/ml/fpm/PrefixSpan.scala | 96 +++++++++++++ .../apache/spark/mllib/fpm/PrefixSpan.scala | 3 +- .../apache/spark/ml/fpm/PrefixSpanSuite.scala | 136 ++++++++++++++++++ 3 files changed, 233 insertions(+), 2 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala new file mode 100644 index 0000000000000..02168fee16dbf --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.fpm + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan} +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType} + +/** + * :: Experimental :: + * A parallel PrefixSpan algorithm to mine frequent sequential patterns. + * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns + * Efficiently by Prefix-Projected Pattern Growth + * (see here). + * + * @see Sequential Pattern Mining + * (Wikipedia) + */ +@Since("2.4.0") +@Experimental +object PrefixSpan { + + /** + * :: Experimental :: + * Finds the complete set of frequent sequential patterns in the input sequences of itemsets. + * + * @param dataset A dataset or a dataframe containing a sequence column which is + * {{{Seq[Seq[_]]}}} type + * @param sequenceCol the name of the sequence column in dataset, rows with nulls in this column + * are ignored + * @param minSupport the minimal support level of the sequential pattern, any pattern that + * appears more than (minSupport * size-of-the-dataset) times will be output + * (recommended value: `0.1`). + * @param maxPatternLength the maximal length of the sequential pattern + * (recommended value: `10`). + * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the + * internal storage format) allowed in a projected database before + * local processing. If a projected database exceeds this size, another + * iteration of distributed prefix growth is run + * (recommended value: `32000000`). + * @return A `DataFrame` that contains columns of sequence and corresponding frequency. + * The schema of it will be: + * - `sequence: Seq[Seq[T]]` (T is the item type) + * - `freq: Long` + */ + @Since("2.4.0") + def findFrequentSequentialPatterns( + dataset: Dataset[_], + sequenceCol: String, + minSupport: Double, + maxPatternLength: Int, + maxLocalProjDBSize: Long): DataFrame = { + + val inputType = dataset.schema(sequenceCol).dataType + require(inputType.isInstanceOf[ArrayType] && + inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType], + s"The input column must be ArrayType and the array element type must also be ArrayType, " + + s"but got $inputType.") + + + val data = dataset.select(sequenceCol) + val sequences = data.where(col(sequenceCol).isNotNull).rdd + .map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray) + + val mllibPrefixSpan = new mllibPrefixSpan() + .setMinSupport(minSupport) + .setMaxPatternLength(maxPatternLength) + .setMaxLocalProjDBSize(maxLocalProjDBSize) + + val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq)) + val schema = StructType(Seq( + StructField("sequence", dataset.schema(sequenceCol).dataType, nullable = false), + StructField("freq", LongType, nullable = false))) + val freqSequences = dataset.sparkSession.createDataFrame(rows, schema) + + freqSequences + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 3f8d65a378e2c..7aed2f3bd8a61 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -49,8 +49,7 @@ import org.apache.spark.storage.StorageLevel * * @param minSupport the minimal support level of the sequential pattern, any pattern that appears * more than (minSupport * size-of-the-dataset) times will be output - * @param maxPatternLength the maximal length of the sequential pattern, any pattern that appears - * less than maxPatternLength will be output + * @param maxPatternLength the maximal length of the sequential pattern * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the internal * storage format) allowed in a projected database before local * processing. If a projected database exceeds this size, another diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala new file mode 100644 index 0000000000000..9e538696cbcf7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.fpm + +import org.apache.spark.ml.util.MLTest +import org.apache.spark.sql.DataFrame + +class PrefixSpanSuite extends MLTest { + + import testImplicits._ + + override def beforeAll(): Unit = { + super.beforeAll() + } + + test("PrefixSpan projections with multiple partial starts") { + val smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence") + val result = PrefixSpan.findFrequentSequentialPatterns(smallDataset, "sequence", + minSupport = 1.0, maxPatternLength = 2, maxLocalProjDBSize = 32000000) + .as[(Seq[Seq[Int]], Long)].collect() + val expected = Array( + (Seq(Seq(1)), 1L), + (Seq(Seq(1, 2)), 1L), + (Seq(Seq(1), Seq(1)), 1L), + (Seq(Seq(1), Seq(2)), 1L), + (Seq(Seq(1), Seq(3)), 1L), + (Seq(Seq(1, 3)), 1L), + (Seq(Seq(2)), 1L), + (Seq(Seq(2, 3)), 1L), + (Seq(Seq(2), Seq(1)), 1L), + (Seq(Seq(2), Seq(2)), 1L), + (Seq(Seq(2), Seq(3)), 1L), + (Seq(Seq(3)), 1L)) + compareResults[Int](expected, result) + } + + /* + To verify expected results for `smallTestData`, create file "prefixSpanSeqs2" with content + (format = (transactionID, idxInTransaction, numItemsinItemset, itemset)): + 1 1 2 1 2 + 1 2 1 3 + 2 1 1 1 + 2 2 2 3 2 + 2 3 2 1 2 + 3 1 2 1 2 + 3 2 1 5 + 4 1 1 6 + In R, run: + library("arulesSequences") + prefixSpanSeqs = read_baskets("prefixSpanSeqs", info = c("sequenceID","eventID","SIZE")) + freqItemSeq = cspade(prefixSpanSeqs, + parameter = 0.5, maxlen = 5 )) + resSeq = as(freqItemSeq, "data.frame") + resSeq + + sequence support + 1 <{1}> 0.75 + 2 <{2}> 0.75 + 3 <{3}> 0.50 + 4 <{1},{3}> 0.50 + 5 <{1,2}> 0.75 + */ + val smallTestData = Seq( + Seq(Seq(1, 2), Seq(3)), + Seq(Seq(1), Seq(3, 2), Seq(1, 2)), + Seq(Seq(1, 2), Seq(5)), + Seq(Seq(6))) + + val smallTestDataExpectedResult = Array( + (Seq(Seq(1)), 3L), + (Seq(Seq(2)), 3L), + (Seq(Seq(3)), 2L), + (Seq(Seq(1), Seq(3)), 2L), + (Seq(Seq(1, 2)), 3L) + ) + + test("PrefixSpan Integer type, variable-size itemsets") { + val df = smallTestData.toDF("sequence") + val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", + minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + .as[(Seq[Seq[Int]], Long)].collect() + + compareResults[Int](smallTestDataExpectedResult, result) + } + + test("PrefixSpan input row with nulls") { + val df = (smallTestData :+ null).toDF("sequence") + val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", + minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + .as[(Seq[Seq[Int]], Long)].collect() + + compareResults[Int](smallTestDataExpectedResult, result) + } + + test("PrefixSpan String type, variable-size itemsets") { + val intToString = (1 to 6).zip(Seq("a", "b", "c", "d", "e", "f")).toMap + val df = smallTestData + .map(seq => seq.map(itemSet => itemSet.map(intToString))) + .toDF("sequence") + val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", + minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + .as[(Seq[Seq[String]], Long)].collect() + + val expected = smallTestDataExpectedResult.map { case (seq, freq) => + (seq.map(itemSet => itemSet.map(intToString)), freq) + } + compareResults[String](expected, result) + } + + private def compareResults[Item]( + expectedValue: Array[(Seq[Seq[Item]], Long)], + actualValue: Array[(Seq[Seq[Item]], Long)]): Unit = { + val expectedSet = expectedValue.map { x => + (x._1.map(_.toSet), x._2) + }.toSet + val actualSet = actualValue.map { x => + (x._1.map(_.toSet), x._2) + }.toSet + assert(expectedSet === actualSet) + } +} + From 0d63eb8888d17df747fb41d7ba254718bb7af3ae Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Mon, 7 May 2018 20:08:41 -0700 Subject: [PATCH 0754/2461] [SPARK-23975][ML] Add support of array input for all clustering methods ## What changes were proposed in this pull request? Add support for all of the clustering methods ## How was this patch tested? unit tests added Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21195 from ludatabricks/SPARK-23975-1. --- .../spark/ml/clustering/BisectingKMeans.scala | 21 ++++---- .../spark/ml/clustering/GaussianMixture.scala | 12 +++-- .../apache/spark/ml/clustering/KMeans.scala | 31 +++--------- .../org/apache/spark/ml/clustering/LDA.scala | 9 ++-- .../apache/spark/ml/util/DatasetUtils.scala | 13 ++++- .../apache/spark/ml/util/SchemaUtils.scala | 16 ++++++- .../ml/clustering/BisectingKMeansSuite.scala | 21 +++++++- .../ml/clustering/GaussianMixtureSuite.scala | 21 +++++++- .../spark/ml/clustering/KMeansSuite.scala | 48 ++++++------------- .../apache/spark/ml/clustering/LDASuite.scala | 20 +++++++- .../apache/spark/ml/util/MLTestingUtils.scala | 23 ++++++++- 11 files changed, 147 insertions(+), 88 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index addc12ac52ec1..438e53ba6197c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -22,17 +22,15 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} -import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{IntegerType, StructType} @@ -75,7 +73,7 @@ private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } } @@ -113,7 +111,8 @@ class BisectingKMeansModel private[ml] ( override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val predictUDF = udf((vector: Vector) => predict(vector)) - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + dataset.withColumn($(predictionCol), + predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } @Since("2.0.0") @@ -132,9 +131,9 @@ class BisectingKMeansModel private[ml] ( */ @Since("2.0.0") def computeCost(dataset: Dataset[_]): Double = { - SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) - val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } - parentModel.computeCost(data.map(OldVectors.fromML)) + SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol) + val data = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) + parentModel.computeCost(data) } @Since("2.0.0") @@ -260,9 +259,7 @@ class BisectingKMeans @Since("2.0.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): BisectingKMeansModel = { transformSchema(dataset.schema, logging = true) - val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { - case Row(point: Vector) => OldVectors.fromML(point) - } + val rdd = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) val instr = Instrumentation.create(this, rdd) instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index b5804900c0358..88d618c3a03a8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatr Vector => OldVector, Vectors => OldVectors} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{IntegerType, StructType} @@ -63,7 +63,7 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) val schemaWithPredictionCol = SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) SchemaUtils.appendColumn(schemaWithPredictionCol, $(probabilityCol), new VectorUDT) } @@ -109,8 +109,9 @@ class GaussianMixtureModel private[ml] ( transformSchema(dataset.schema, logging = true) val predUDF = udf((vector: Vector) => predict(vector)) val probUDF = udf((vector: Vector) => predictProbability(vector)) - dataset.withColumn($(predictionCol), predUDF(col($(featuresCol)))) - .withColumn($(probabilityCol), probUDF(col($(featuresCol)))) + dataset + .withColumn($(predictionCol), predUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) + .withColumn($(probabilityCol), probUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } @Since("2.0.0") @@ -340,7 +341,8 @@ class GaussianMixture @Since("2.0.0") ( val sc = dataset.sparkSession.sparkContext val numClusters = $(k) - val instances: RDD[Vector] = dataset.select(col($(featuresCol))).rdd.map { + val instances: RDD[Vector] = dataset + .select(DatasetUtils.columnToVector(dataset, getFeaturesCol)).rdd.map { case Row(features: Vector) => features }.cache() diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index de61c9c089a36..97f246fbfd859 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model, PipelineStage} -import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ @@ -34,7 +34,7 @@ import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.udf -import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType} +import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.VersionUtils.majorVersion @@ -86,24 +86,13 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") def getInitSteps: Int = $(initSteps) - /** - * Validates the input schema. - * @param schema input schema - */ - private[clustering] def validateSchema(schema: StructType): Unit = { - val typeCandidates = List( new VectorUDT, - new ArrayType(DoubleType, false), - new ArrayType(FloatType, false)) - - SchemaUtils.checkColumnTypes(schema, $(featuresCol), typeCandidates) - } /** * Validates and transforms the input schema. * @param schema input schema * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - validateSchema(schema) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } } @@ -160,12 +149,8 @@ class KMeansModel private[ml] ( // TODO: Replace the temp fix when we have proper evaluators defined for clustering. @Since("2.0.0") def computeCost(dataset: Dataset[_]): Double = { - validateSchema(dataset.schema) - - val data: RDD[OldVector] = dataset.select(DatasetUtils.columnToVector(dataset, getFeaturesCol)) - .rdd.map { - case Row(point: Vector) => OldVectors.fromML(point) - } + SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol) + val data = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) parentModel.computeCost(data) } @@ -351,11 +336,7 @@ class KMeans @Since("1.5.0") ( transformSchema(dataset.schema, logging = true) val handlePersistence = dataset.storageLevel == StorageLevel.NONE - val instances: RDD[OldVector] = dataset.select( - DatasetUtils.columnToVector(dataset, getFeaturesCol)) - .rdd.map { - case Row(point: Vector) => OldVectors.fromML(point) - } + val instances = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) if (handlePersistence) { instances.persist(StorageLevel.MEMORY_AND_DISK) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 47077230fac0a..afe599cd167cb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -43,7 +43,7 @@ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, StructType} import org.apache.spark.util.PeriodicCheckpointer import org.apache.spark.util.VersionUtils @@ -345,7 +345,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM s" must be >= 1. Found value: $getTopicConcentration") } } - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT) } @@ -461,7 +461,8 @@ abstract class LDAModel private[ml] ( val transformer = oldLocalModel.getTopicDistributionMethod val t = udf { (v: Vector) => transformer(OldVectors.fromML(v)).asML } - dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF() + dataset.withColumn($(topicDistributionCol), + t(DatasetUtils.columnToVector(dataset, getFeaturesCol))).toDF() } else { logWarning("LDAModel.transform was called without any output columns. Set an output column" + " such as topicDistributionCol to produce results.") @@ -938,7 +939,7 @@ object LDA extends MLReadable[LDA] { featuresCol: String): RDD[(Long, OldVector)] = { dataset .withColumn("docId", monotonically_increasing_id()) - .select("docId", featuresCol) + .select(col("docId"), DatasetUtils.columnToVector(dataset, featuresCol)) .rdd .map { case Row(docId: Long, features: Vector) => (docId, OldVectors.fromML(features)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala index 52619cb65489a..6af4b3ebc2cc2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala @@ -17,8 +17,10 @@ package org.apache.spark.ml.util -import org.apache.spark.ml.linalg.{Vectors, VectorUDT} -import org.apache.spark.sql.{Column, Dataset} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Column, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType} @@ -60,4 +62,11 @@ private[spark] object DatasetUtils { throw new IllegalArgumentException(s"$other column cannot be cast to Vector") } } + + def columnToOldVector(dataset: Dataset[_], colName: String): RDD[OldVector] = { + dataset.select(columnToVector(dataset, colName)) + .rdd.map { + case Row(point: Vector) => OldVectors.fromML(point) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 334410c9620de..d9a3f85ef9a24 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -17,7 +17,8 @@ package org.apache.spark.ml.util -import org.apache.spark.sql.types.{DataType, NumericType, StructField, StructType} +import org.apache.spark.ml.linalg.VectorUDT +import org.apache.spark.sql.types._ /** @@ -101,4 +102,17 @@ private[spark] object SchemaUtils { require(!schema.fieldNames.contains(col.name), s"Column ${col.name} already exists.") StructType(schema.fields :+ col) } + + /** + * Check whether the given column in the schema is one of the supporting vector type: Vector, + * Array[Float]. Array[Double] + * @param schema input schema + * @param colName column name + */ + def validateVectorCompatibleColumn(schema: StructType, colName: String): Unit = { + val typeCandidates = List( new VectorUDT, + new ArrayType(DoubleType, false), + new ArrayType(FloatType, false)) + checkColumnTypes(schema, colName, typeCandidates) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 02880f96ae6d9..f3ff2afcad2cd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -17,13 +17,16 @@ package org.apache.spark.ml.clustering +import scala.language.existentials + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.clustering.DistanceMeasure import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Dataset +import org.apache.spark.sql.{DataFrame, Dataset} class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -182,6 +185,22 @@ class BisectingKMeansSuite model.clusterCenters.forall(Vectors.norm(_, 2) == 1.0) } + + test("BisectingKMeans with Array input") { + def trainAndComputeCost(dataset: Dataset[_]): Double = { + val model = new BisectingKMeans().setK(k).setMaxIter(1).setSeed(1).fit(dataset) + model.computeCost(dataset) + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val trueCost = trainAndComputeCost(newDataset) + val doubleArrayCost = trainAndComputeCost(newDatasetD) + val floatArrayCost = trainAndComputeCost(newDatasetF) + + // checking the cost is fine enough as a sanity check + assert(trueCost ~== doubleArrayCost absTol 1e-6) + assert(trueCost ~== floatArrayCost absTol 1e-6) + } } object BisectingKMeansSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 08b800b7e4183..d0d461a42711a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.clustering +import scala.language.existentials + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.ml.param.ParamMap @@ -24,8 +26,7 @@ import org.apache.spark.ml.stat.distribution.MultivariateGaussian import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Dataset, Row} - +import org.apache.spark.sql.{DataFrame, Dataset, Row} class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -256,6 +257,22 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext val expectedMatrix = GaussianMixture.unpackUpperTriangularMatrix(4, triangularValues) assert(symmetricMatrix === expectedMatrix) } + + test("GaussianMixture with Array input") { + def trainAndComputlogLikelihood(dataset: Dataset[_]): Double = { + val model = new GaussianMixture().setK(k).setMaxIter(1).setSeed(1).fit(dataset) + model.summary.logLikelihood + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val trueLikelihood = trainAndComputlogLikelihood(newDataset) + val doubleLikelihood = trainAndComputlogLikelihood(newDatasetD) + val floatLikelihood = trainAndComputlogLikelihood(newDatasetF) + + // checking the cost is fine enough as a sanity check + assert(trueLikelihood ~== doubleLikelihood absTol 1e-6) + assert(trueLikelihood ~== floatLikelihood absTol 1e-6) + } } object GaussianMixtureSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 5445ebe5c95eb..680a7c2034083 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.clustering +import scala.language.existentials import scala.util.Random import org.dmg.pmml.{ClusteringModel, PMML} @@ -25,13 +26,11 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util._ -import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, - KMeansModel => MLlibKMeansModel} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType} private[clustering] case class TestRow(features: Vector) @@ -202,38 +201,19 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR } test("KMean with Array input") { - val featuresColNameD = "array_double_features" - val featuresColNameF = "array_float_features" - - val doubleUDF = udf { (features: Vector) => - val featureArray = Array.fill[Double](features.size)(0.0) - features.foreachActive((idx, value) => featureArray(idx) = value.toFloat) - featureArray - } - val floatUDF = udf { (features: Vector) => - val featureArray = Array.fill[Float](features.size)(0.0f) - features.foreachActive((idx, value) => featureArray(idx) = value.toFloat) - featureArray + def trainAndComputeCost(dataset: Dataset[_]): Double = { + val model = new KMeans().setK(k).setMaxIter(1).setSeed(1).fit(dataset) + model.computeCost(dataset) } - val newdatasetD = dataset.withColumn(featuresColNameD, doubleUDF(col("features"))) - .drop("features") - val newdatasetF = dataset.withColumn(featuresColNameF, floatUDF(col("features"))) - .drop("features") - assert(newdatasetD.schema(featuresColNameD).dataType.equals(new ArrayType(DoubleType, false))) - assert(newdatasetF.schema(featuresColNameF).dataType.equals(new ArrayType(FloatType, false))) - - val kmeansD = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameD).setSeed(1) - val kmeansF = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameF).setSeed(1) - val modelD = kmeansD.fit(newdatasetD) - val modelF = kmeansF.fit(newdatasetF) - val transformedD = modelD.transform(newdatasetD) - val transformedF = modelF.transform(newdatasetF) - - val predictDifference = transformedD.select("prediction") - .except(transformedF.select("prediction")) - assert(predictDifference.count() == 0) - assert(modelD.computeCost(newdatasetD) == modelF.computeCost(newdatasetF) ) + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val trueCost = trainAndComputeCost(newDataset) + val doubleArrayCost = trainAndComputeCost(newDatasetD) + val floatArrayCost = trainAndComputeCost(newDatasetF) + + // checking the cost is fine enough as a sanity check + assert(trueCost ~== doubleArrayCost absTol 1e-6) + assert(trueCost ~== floatArrayCost absTol 1e-6) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index e73bbc18d76bd..8d728f063dd8c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.clustering +import scala.language.existentials + import org.apache.hadoop.fs.Path import org.apache.spark.SparkFunSuite @@ -26,7 +28,6 @@ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql._ - object LDASuite { def generateLDAData( spark: SparkSession, @@ -323,4 +324,21 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead assert(model.getOptimizer === optimizer) } } + + test("LDA with Array input") { + def trainAndLogLikelihoodAndPerplexity(dataset: Dataset[_]): (Double, Double) = { + val model = new LDA().setK(k).setOptimizer("online").setMaxIter(1).setSeed(1).fit(dataset) + (model.logLikelihood(dataset), model.logPerplexity(dataset)) + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val (ll, lp) = trainAndLogLikelihoodAndPerplexity(newDataset) + val (llD, lpD) = trainAndLogLikelihoodAndPerplexity(newDatasetD) + val (llF, lpF) = trainAndLogLikelihoodAndPerplexity(newDatasetF) + // TODO: need to compare the results once we fix the seed issue for LDA (SPARK-22210) + assert(llD <= 0.0 && llD != Double.NegativeInfinity) + assert(llF <= 0.0 && llF != Double.NegativeInfinity) + assert(lpD >= 0.0 && lpD != Double.NegativeInfinity) + assert(lpF >= 0.0 && lpF != Double.NegativeInfinity) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index c328d81b4bc3a..5e72b4d864c1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml._ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.feature.{Instance, LabeledPoint} -import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol} import org.apache.spark.ml.recommendation.{ALS, ALSModel} @@ -247,4 +247,25 @@ object MLTestingUtils extends SparkFunSuite { } models.sliding(2).foreach { case Seq(m1, m2) => modelEquals(m1, m2)} } + + /** + * Helper function for testing different input types for "features" column. Given a DataFrame, + * generate three output DataFrames: one having vector "features" column with float precision, + * one having double array "features" column with float precision, and one having float array + * "features" column. + */ + def generateArrayFeatureDataset(dataset: Dataset[_], + featuresColName: String = "features"): (Dataset[_], Dataset[_], Dataset[_]) = { + val toFloatVectorUDF = udf { (features: Vector) => + Vectors.dense(features.toArray.map(_.toFloat.toDouble))} + val toDoubleArrayUDF = udf { (features: Vector) => features.toArray} + val toFloatArrayUDF = udf { (features: Vector) => features.toArray.map(_.toFloat)} + val newDataset = dataset.withColumn(featuresColName, toFloatVectorUDF(col(featuresColName))) + val newDatasetD = newDataset.withColumn(featuresColName, toDoubleArrayUDF(col(featuresColName))) + val newDatasetF = newDataset.withColumn(featuresColName, toFloatArrayUDF(col(featuresColName))) + assert(newDataset.schema(featuresColName).dataType.equals(new VectorUDT)) + assert(newDatasetD.schema(featuresColName).dataType.equals(new ArrayType(DoubleType, false))) + assert(newDatasetF.schema(featuresColName).dataType.equals(new ArrayType(FloatType, false))) + (newDataset, newDatasetD, newDatasetF) + } } From cd12c5c3ecf28f7b04f566c2057f9b65eb456b7d Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Tue, 8 May 2018 12:21:33 +0800 Subject: [PATCH 0755/2461] [SPARK-24128][SQL] Mention configuration option in implicit CROSS JOIN error ## What changes were proposed in this pull request? Mention `spark.sql.crossJoin.enabled` in error message when an implicit `CROSS JOIN` is detected. ## How was this patch tested? `CartesianProductSuite` and `JoinSuite`. Author: Henry Robinson Closes #21201 from henryr/spark-24128. --- R/pkg/tests/fulltests/test_sparkSQL.R | 4 ++-- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 6 ++++-- .../catalyst/optimizer/CheckCartesianProductsSuite.scala | 2 +- .../src/test/scala/org/apache/spark/sql/JoinSuite.scala | 4 ++-- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 3a8866bf2a88a..43725e0ebd3bf 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -2210,8 +2210,8 @@ test_that("join(), crossJoin() and merge() on a DataFrame", { expect_equal(count(where(join(df, df2), df$name == df2$name)), 3) # cartesian join expect_error(tryCatch(count(join(df, df2)), error = function(e) { stop(e) }), - paste0(".*(org.apache.spark.sql.AnalysisException: Detected cartesian product for", - " INNER join between logical plans).*")) + paste0(".*(org.apache.spark.sql.AnalysisException: Detected implicit cartesian", + " product for INNER join between logical plans).*")) joined <- crossJoin(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 45f13956a0a85..bfa61116a6658 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1182,12 +1182,14 @@ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper { case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _) if isCartesianProduct(j) => throw new AnalysisException( - s"""Detected cartesian product for ${j.joinType.sql} join between logical plans + s"""Detected implicit cartesian product for ${j.joinType.sql} join between logical plans |${left.treeString(false).trim} |and |${right.treeString(false).trim} |Join condition is missing or trivial. - |Use the CROSS JOIN syntax to allow cartesian products between these relations.""" + |Either: use the CROSS JOIN syntax to allow cartesian products between these + |relations, or: enable implicit cartesian products by setting the configuration + |variable spark.sql.crossJoin.enabled=true""" .stripMargin) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala index 21220b38968e8..788fedb3c8e8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala @@ -56,7 +56,7 @@ class CheckCartesianProductsSuite extends PlanTest { val thrownException = the [AnalysisException] thrownBy { performCartesianProductCheck(joinType) } - assert(thrownException.message.contains("Detected cartesian product")) + assert(thrownException.message.contains("Detected implicit cartesian product")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 771e1186e63ab..8fa747465cb1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -239,7 +239,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(2, 2, 1, null) :: Row(2, 2, 2, 2) :: Nil) } - assert(e.getMessage.contains("Detected cartesian product for INNER join " + + assert(e.getMessage.contains("Detected implicit cartesian product for INNER join " + "between logical plans")) } } @@ -611,7 +611,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val e = intercept[Exception] { checkAnswer(sql(query), Nil); } - assert(e.getMessage.contains("Detected cartesian product")) + assert(e.getMessage.contains("Detected implicit cartesian product")) } cartesianQueries.foreach(checkCartesianDetection) From 05eb19b6e09065265358eec2db2ff3b42806dfc9 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 8 May 2018 14:32:04 +0800 Subject: [PATCH 0756/2461] [SPARK-24188][CORE] Restore "/version" API endpoint. It was missing the jax-rs annotation. Author: Marcelo Vanzin Closes #21245 from vanzin/SPARK-24188. Change-Id: Ib338e34b363d7c729cc92202df020dc51033b719 --- .../org/apache/spark/status/api/v1/ApiRootResource.scala | 1 + .../org/apache/spark/deploy/history/HistoryServerSuite.scala | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 7127397f6205c..d121068718b8a 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -49,6 +49,7 @@ private[v1] class ApiRootResource extends ApiRequestContext { @Path("applications/{appId}") def application(): Class[OneApplicationResource] = classOf[OneApplicationResource] + @GET @Path("version") def version(): VersionInfo = new VersionInfo(org.apache.spark.SPARK_VERSION) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 87f12f303cd5e..a871b1c717837 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -296,6 +296,11 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers all (siteRelativeLinks) should startWith (uiRoot) } + test("/version api endpoint") { + val response = getUrl("version") + assert(response.contains(SPARK_VERSION)) + } + test("ajax rendered relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val uiRoot = "/testwebproxybase" System.setProperty("spark.ui.proxyBase", uiRoot) From e17567ca78dbb416039c17da212957c8955bfa65 Mon Sep 17 00:00:00 2001 From: yucai Date: Tue, 8 May 2018 11:34:27 +0200 Subject: [PATCH 0757/2461] [SPARK-24076][SQL] Use different seed in HashAggregate to avoid hash conflict ## What changes were proposed in this pull request? HashAggregate uses the same hash algorithm and seed as ShuffleExchange, it may lead to bad hash conflict when shuffle.partitions=8192*n. Considering below example: ``` SET spark.sql.shuffle.partitions=8192; INSERT OVERWRITE TABLE target_xxx SELECT item_id, auct_end_dt FROM from source_xxx GROUP BY item_id, auct_end_dt; ``` In the shuffle stage, if user sets the shuffle.partition = 8192, all tuples in the same partition will meet the following relationship: ``` hash(tuple x) = hash(tuple y) + n * 8192 ``` Then in the next HashAggregate stage, all tuples from the same partition need be put into a 16K BytesToBytesMap (unsafeRowAggBuffer). Here, the HashAggregate uses the same hash algorithm on the same expression as shuffle, and uses the same seed, and 16K = 8192 * 2, so actually, all tuples in the same parititon will only be hashed to 2 different places in the BytesToBytesMap. It is bad hash conflict. With BytesToBytesMap growing, the conflict will always exist. Before change: hash_conflict After change: no_hash_conflict ## How was this patch tested? Unit tests and production cases. Author: yucai Closes #21149 from yucai/SPARK-24076. --- .../spark/sql/execution/aggregate/HashAggregateExec.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index a5dc6ebf2b0f2..6a8ec4f722aea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -755,7 +755,10 @@ case class HashAggregateExec( } // generate hash code for key - val hashExpr = Murmur3Hash(groupingExpressions, 42) + // SPARK-24076: HashAggregate uses the same hash algorithm on the same expressions + // as ShuffleExchange, it may lead to bad hash conflict when shuffle.partitions=8192*n, + // pick a different seed to avoid this conflict + val hashExpr = Murmur3Hash(groupingExpressions, 48) val hashEval = BindReferences.bindReference(hashExpr, child.output).genCode(ctx) val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, From b54bbe57b33b00063596cd9588fa2461745ed571 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 8 May 2018 21:22:54 +0800 Subject: [PATCH 0758/2461] [SPARK-24131][PYSPARK][FOLLOWUP] Add majorMinorVersion API to PySpark for determining Spark versions ## What changes were proposed in this pull request? More close to Scala API behavior when can't parse input by throwing exception. Add tests. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh Closes #21211 from viirya/SPARK-24131-followup. --- python/pyspark/tests.py | 4 ++++ python/pyspark/util.py | 37 ++++++++++++++++++++++--------------- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 7b8ce2c6b799f..498d6b57e4353 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2312,6 +2312,10 @@ def test_py4j_exception_message(self): self.assertTrue('NullPointerException' in _exception_message(context.exception)) + def test_parsing_version_string(self): + from pyspark.util import VersionUtils + self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced")) + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 04df835bf6717..59cc2a6329350 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -62,24 +62,31 @@ def _get_argspec(f): return argspec -def majorMinorVersion(version): +class VersionUtils(object): """ - Get major and minor version numbers for given Spark version string. - - >>> version = "2.4.0" - >>> majorMinorVersion(version) - (2, 4) + Provides utility method to determine Spark versions with given input string. + """ + @staticmethod + def majorMinorVersion(sparkVersion): + """ + Given a Spark version string, return the (major version number, minor version number). + E.g., for 2.0.1-SNAPSHOT, return (2, 0). - >>> version = "abc" - >>> majorMinorVersion(version) is None - True + >>> sparkVersion = "2.4.0" + >>> VersionUtils.majorMinorVersion(sparkVersion) + (2, 4) + >>> sparkVersion = "2.3.0-SNAPSHOT" + >>> VersionUtils.majorMinorVersion(sparkVersion) + (2, 3) - """ - m = re.search('^(\d+)\.(\d+)(\..*)?$', version) - if m is None: - return None - else: - return (int(m.group(1)), int(m.group(2))) + """ + m = re.search('^(\d+)\.(\d+)(\..*)?$', sparkVersion) + if m is not None: + return (int(m.group(1)), int(m.group(2))) + else: + raise ValueError("Spark tried to parse '%s' as a Spark" % sparkVersion + + " version string, but it could not find the major and minor" + + " version numbers.") if __name__ == "__main__": From 2f6fe7d679a878ffd103cac6f06081c5b3888744 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 8 May 2018 21:24:35 +0800 Subject: [PATCH 0759/2461] [SPARK-23094][SPARK-23723][SPARK-23724][SQL][FOLLOW-UP] Support custom encoding for json files ## What changes were proposed in this pull request? This is to add a test case to check the behaviors when users write json in the specified UTF-16/UTF-32 encoding with multiline off. ## How was this patch tested? N/A Author: gatorsmile Closes #21254 from gatorsmile/followupSPARK-23094. --- .../spark/sql/catalyst/json/JSONOptions.scala | 9 +++++---- .../datasources/json/JsonSuite.scala | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 5f130af606e19..2579374e3f4e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -110,11 +110,12 @@ private[sql] class JSONOptions( val blacklist = Seq(Charset.forName("UTF-16"), Charset.forName("UTF-32")) val isBlacklisted = blacklist.contains(Charset.forName(enc)) require(multiLine || !isBlacklisted, - s"""The ${enc} encoding must not be included in the blacklist when multiLine is disabled: - | ${blacklist.mkString(", ")}""".stripMargin) + s"""The $enc encoding in the blacklist is not allowed when multiLine is disabled. + |Blacklist: ${blacklist.mkString(", ")}""".stripMargin) + + val isLineSepRequired = + multiLine || Charset.forName(enc) == StandardCharsets.UTF_8 || lineSeparator.nonEmpty - val isLineSepRequired = !(multiLine == false && - Charset.forName(enc) != StandardCharsets.UTF_8 && lineSeparator.isEmpty) require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding") enc diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 0db688fec9a67..4b3921c61a000 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2313,6 +2313,25 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } } + test("SPARK-23723: write json in UTF-16/32 with multiline off") { + Seq("UTF-16", "UTF-32").foreach { encoding => + withTempPath { path => + val ds = spark.createDataset(Seq( + ("a", 1), ("b", 2), ("c", 3)) + ).repartition(2) + val e = intercept[IllegalArgumentException] { + ds.write + .option("encoding", encoding) + .option("multiline", "false") + .format("json").mode("overwrite") + .save(path.getCanonicalPath) + }.getMessage + assert(e.contains( + s"$encoding encoding in the blacklist is not allowed when multiLine is disabled")) + } + } + } + def checkReadJson(lineSep: String, encoding: String, inferSchema: Boolean, id: Int): Unit = { test(s"SPARK-23724: checks reading json in ${encoding} #${id}") { val schema = new StructType().add("f1", StringType).add("f2", IntegerType) From 487faf17ab96c8edb729501dfb1ff82f7b2c6031 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 8 May 2018 23:43:02 +0800 Subject: [PATCH 0760/2461] [SPARK-24117][SQL] Unified the getSizePerRow ## What changes were proposed in this pull request? This pr unified the `getSizePerRow` because `getSizePerRow` is used in many places. For example: 1. [LocalRelation.scala#L80](https://github.com/wangyum/spark/blob/f70f46d1e5bc503e9071707d837df618b7696d32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala#L80) 2. [SizeInBytesOnlyStatsPlanVisitor.scala#L36](https://github.com/apache/spark/blob/76b8b840ddc951ee6203f9cccd2c2b9671c1b5e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala#L36) ## How was this patch tested? Exist tests Author: Yuming Wang Closes #21189 from wangyum/SPARK-24117. --- .../sql/catalyst/plans/logical/LocalRelation.scala | 3 ++- .../logical/statsEstimation/EstimationUtils.scala | 14 ++++++++------ .../SizeInBytesOnlyStatsPlanVisitor.scala | 4 ++-- .../spark/sql/execution/streaming/memory.scala | 10 ++++------ .../sql/execution/streaming/sources/memoryV2.scala | 3 ++- .../spark/sql/StatisticsCollectionSuite.scala | 2 +- .../sql/execution/streaming/MemorySinkSuite.scala | 4 ++-- 7 files changed, 21 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 720d42ab409a0..8c4828a4cef23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { @@ -77,7 +78,7 @@ case class LocalRelation( } override def computeStats(): Statistics = - Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length) + Statistics(sizeInBytes = EstimationUtils.getSizePerRow(output) * data.length) def toSQL(inlineTableName: String): String = { require(data.nonEmpty) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 0f147f0ffb135..211a2a0717371 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.math.BigDecimal.RoundingMode @@ -25,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.{DecimalType, _} - object EstimationUtils { /** Check if each plan has rowCount in its statistics. */ @@ -73,13 +71,12 @@ object EstimationUtils { AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _))) } - def getOutputSize( + def getSizePerRow( attributes: Seq[Attribute], - outputRowCount: BigInt, attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = { // We assign a generic overhead for a Row object, the actual overhead is different for different // Row format. - val sizePerRow = 8 + attributes.map { attr => + 8 + attributes.map { attr => if (attrStats.get(attr).map(_.avgLen.isDefined).getOrElse(false)) { attr.dataType match { case StringType => @@ -92,10 +89,15 @@ object EstimationUtils { attr.dataType.defaultSize } }.sum + } + def getOutputSize( + attributes: Seq[Attribute], + outputRowCount: BigInt, + attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = { // Output size can't be zero, or sizeInBytes of BinaryNode will also be zero // (simple computation of statistics returns product of children). - if (outputRowCount > 0) outputRowCount * sizePerRow else 1 + if (outputRowCount > 0) outputRowCount * getSizePerRow(attributes, attrStats) else 1 } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala index 85f67c7d66075..ee43f9126386b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -33,8 +33,8 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { private def visitUnaryNode(p: UnaryNode): Statistics = { // There should be some overhead in Row object, the size should not be zero when there is // no columns, this help to prevent divide-by-zero error. - val childRowSize = p.child.output.map(_.dataType.defaultSize).sum + 8 - val outputRowSize = p.output.map(_.dataType.defaultSize).sum + 8 + val childRowSize = EstimationUtils.getSizePerRow(p.child.output) + val outputRowSize = EstimationUtils.getSizePerRow(p.output) // Assume there will be the same number of rows as child has. var sizeInBytes = (p.child.stats.sizeInBytes * outputRowSize) / childRowSize if (sizeInBytes == 0) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 22258274c70c1..6720cdd24b1b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -24,23 +24,21 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, ListBuffer} -import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} -import org.apache.spark.sql.streaming.{OutputMode, Trigger} +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils - object MemoryStream { protected val currentBlockId = new AtomicInteger(0) protected val memoryStreamId = new AtomicInteger(0) @@ -307,7 +305,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode { def this(sink: MemorySink) = this(sink, sink.schema.toAttributes) - private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum + private val sizePerRow = EstimationUtils.getSizePerRow(sink.schema.toAttributes) override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 0d6c239274dd8..468313bfe8c3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} @@ -182,7 +183,7 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode) * Used to query the data that has been written into a [[MemorySinkV2]]. */ case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode { - private val sizePerRow = output.map(_.dataType.defaultSize).sum + private val sizePerRow = EstimationUtils.getSizePerRow(output) override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index b91712f4cc25d..60fa951e23178 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -50,7 +50,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}") - assert(sizes.head === BigInt(96), + assert(sizes.head === BigInt(128), s"expected exact size 96 for table 'test', got: ${sizes.head}") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index e8420eee7fe9d..3bc36ce55d902 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -220,11 +220,11 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { sink.addBatch(0, 1 to 3) plan.invalidateStatsCache() - assert(plan.stats.sizeInBytes === 12) + assert(plan.stats.sizeInBytes === 36) sink.addBatch(1, 4 to 6) plan.invalidateStatsCache() - assert(plan.stats.sizeInBytes === 24) + assert(plan.stats.sizeInBytes === 72) } ignore("stress test") { From e3de6ab30d52890eb08578e55eb4a5d2b4e7aa35 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 9 May 2018 08:32:20 +0800 Subject: [PATCH 0761/2461] [SPARK-24068] Propagating DataFrameReader's options to Text datasource on schema inferring ## What changes were proposed in this pull request? While reading CSV or JSON files, DataFrameReader's options are converted to Hadoop's parameters, for example there: https://github.com/apache/spark/blob/branch-2.3/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala#L302 but the options are not propagated to Text datasource on schema inferring, for instance: https://github.com/apache/spark/blob/branch-2.3/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala#L184-L188 The PR proposes propagation of user's options to Text datasource on scheme inferring in similar way as user's options are converted to Hadoop parameters if schema is specified. ## How was this patch tested? The changes were tested manually by using https://github.com/twitter/hadoop-lzo: ``` hadoop-lzo> mvn clean package hadoop-lzo> ln -s ./target/hadoop-lzo-0.4.21-SNAPSHOT.jar ./hadoop-lzo.jar ``` Create 2 test files in JSON and CSV format and compress them: ```shell $ cat test.csv col1|col2 a|1 $ lzop test.csv $ cat test.json {"col1":"a","col2":1} $ lzop test.json ``` Run `spark-shell` with hadoop-lzo: ``` bin/spark-shell --jars ~/hadoop-lzo/hadoop-lzo.jar ``` reading compressed CSV and JSON without schema: ```scala spark.read.option("io.compression.codecs", "com.hadoop.compression.lzo.LzopCodec").option("inferSchema",true).option("header",true).option("sep","|").csv("test.csv.lzo").show() +----+----+ |col1|col2| +----+----+ | a| 1| +----+----+ ``` ```scala spark.read.option("io.compression.codecs", "com.hadoop.compression.lzo.LzopCodec").option("multiLine", true).json("test.json.lzo").printSchema root |-- col1: string (nullable = true) |-- col2: long (nullable = true) ``` Author: Maxim Gekk Author: Maxim Gekk Closes #21182 from MaxGekk/text-options. --- .../apache/spark/sql/catalyst/json/JSONOptions.scala | 2 +- .../sql/execution/datasources/csv/CSVDataSource.scala | 6 ++++-- .../sql/execution/datasources/csv/CSVOptions.scala | 2 +- .../spark/sql/execution/datasources/csv/CSVUtils.scala | 2 -- .../execution/datasources/csv/UnivocityParser.scala | 2 -- .../execution/datasources/json/JsonDataSource.scala | 10 ++++++---- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 2579374e3f4e1..2ff12acb2946f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.util._ * Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]]. */ private[sql] class JSONOptions( - @transient private val parameters: CaseInsensitiveMap[String], + @transient val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String) extends Logging with Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index bc1f4ab3bb053..dc54d182651b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -185,7 +185,8 @@ object TextInputCSVDataSource extends CSVDataSource { DataSource.apply( sparkSession, paths = paths, - className = classOf[TextFileFormat].getName + className = classOf[TextFileFormat].getName, + options = options.parameters ).resolveRelation(checkFilesExist = false)) .select("value").as[String](Encoders.STRING) } else { @@ -250,7 +251,8 @@ object MultiLineCSVDataSource extends CSVDataSource { options: CSVOptions): RDD[PortableDataStream] = { val paths = inputPaths.map(_.getPath) val name = paths.mkString(",") - val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) + val job = Job.getInstance(sparkSession.sessionState.newHadoopConfWithOptions( + options.parameters)) FileInputFormat.setInputPaths(job, paths: _*) val conf = job.getConfiguration diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 2ec0fc605a84b..ed2dc65a47914 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -27,7 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util._ class CSVOptions( - @transient private val parameters: CaseInsensitiveMap[String], + @transient val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String) extends Logging with Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 31464f1bcc68e..9dae41b63e810 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql.execution.datasources.csv -import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset -import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 3d6cc30f2ba83..99557a1ceb0c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources.csv import java.io.InputStream import java.math.BigDecimal -import java.text.NumberFormat -import java.util.Locale import scala.util.Try import scala.util.control.NonFatal diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 983a5f0dcade2..ba83df0efebd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -121,7 +121,7 @@ object TextInputJsonDataSource extends JsonDataSource { sparkSession, paths = paths, className = classOf[TextFileFormat].getName, - options = textOptions + options = parsedOptions.parameters ).resolveRelation(checkFilesExist = false)) .select("value").as(Encoders.STRING) } @@ -159,7 +159,7 @@ object MultiLineJsonDataSource extends JsonDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): StructType = { - val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths) + val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths, parsedOptions) val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions) val parser = parsedOptions.encoding .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) @@ -170,9 +170,11 @@ object MultiLineJsonDataSource extends JsonDataSource { private def createBaseRdd( sparkSession: SparkSession, - inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = { + inputPaths: Seq[FileStatus], + parsedOptions: JSONOptions): RDD[PortableDataStream] = { val paths = inputPaths.map(_.getPath) - val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) + val job = Job.getInstance(sparkSession.sessionState.newHadoopConfWithOptions( + parsedOptions.parameters)) val conf = job.getConfiguration val name = paths.mkString(",") FileInputFormat.setInputPaths(job, paths: _*) From 9498e528d21e286e496da6ea9bf9c7ad73a7b5bd Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 9 May 2018 08:39:46 +0800 Subject: [PATCH 0762/2461] [SPARK-23355][SQL][DOC][FOLLOWUP] Add migration doc for TBLPROPERTIES ## What changes were proposed in this pull request? In Apache Spark 2.4, [SPARK-23355](https://issues.apache.org/jira/browse/SPARK-23355) fixes a bug which ignores table properties during convertMetastore for tables created by STORED AS ORC/PARQUET. For some Parquet tables having table properties like TBLPROPERTIES (parquet.compression 'NONE'), it was ignored by default before Apache Spark 2.4. After upgrading cluster, Spark will write uncompressed file which is different from Apache Spark 2.3 and old. This PR adds a migration note for that. ## How was this patch tested? N/A Author: Dongjoon Hyun Closes #21269 from dongjoon-hyun/SPARK-23355-DOC. --- docs/sql-programming-guide.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 075b953a0898e..c521f3cb51e58 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1812,6 +1812,8 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. + - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. + ## Upgrading From Spark SQL 2.2 to 2.3 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. From 7e7350285dc22764f599671d874617c0eea093e5 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Tue, 8 May 2018 21:20:58 -0700 Subject: [PATCH 0763/2461] [SPARK-24132][ML] Instrumentation improvement for classification ## What changes were proposed in this pull request? - Add OptionalInstrumentation as argument for getNumClasses in ml.classification.Classifier - Change the function call for getNumClasses in train() in ml.classification.DecisionTreeClassifier, ml.classification.RandomForestClassifier, and ml.classification.NaiveBayes - Modify the instrumentation creation in ml.classification.LinearSVC - Change the log call in ml.classification.OneVsRest and ml.classification.LinearSVC ## How was this patch tested? Manual. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21204 from ludatabricks/SPARK-23686. --- .../spark/ml/classification/DecisionTreeClassifier.scala | 9 ++++++--- .../org/apache/spark/ml/classification/LinearSVC.scala | 9 ++++++--- .../org/apache/spark/ml/classification/NaiveBayes.scala | 3 ++- .../org/apache/spark/ml/classification/OneVsRest.scala | 4 ++-- .../spark/ml/classification/RandomForestClassifier.scala | 4 +++- 5 files changed, 19 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 57797d1cc4978..c9786f1f7ceb1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -97,9 +97,11 @@ class DecisionTreeClassifier @Since("1.4.0") ( override def setSeed(value: Long): this.type = set(seed, value) override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = { + val instr = Instrumentation.create(this, dataset) val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = getNumClasses(dataset) + instr.logNumClasses(numClasses) if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + @@ -110,8 +112,8 @@ class DecisionTreeClassifier @Since("1.4.0") ( val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) val strategy = getOldStrategy(categoricalFeatures, numClasses) - val instr = Instrumentation.create(this, oldDataset) - instr.logParams(params: _*) + instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, + cacheNodeIds, checkpointInterval, impurity, seed) val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", seed = $(seed), instr = Some(instr), parentUID = Some(uid)) @@ -125,7 +127,8 @@ class DecisionTreeClassifier @Since("1.4.0") ( private[ml] def train(data: RDD[LabeledPoint], oldStrategy: OldStrategy): DecisionTreeClassificationModel = { val instr = Instrumentation.create(this, data) - instr.logParams(params: _*) + instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, + cacheNodeIds, checkpointInterval, impurity, seed) val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID = Some(uid)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 80c537e1e0eb2..38eb04556b775 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -170,7 +170,7 @@ class LinearSVC @Since("2.2.0") ( Instance(label, weight, features) } - val instr = Instrumentation.create(this, instances) + val instr = Instrumentation.create(this, dataset) instr.logParams(regParam, maxIter, fitIntercept, tol, standardization, threshold, aggregationDepth) @@ -187,6 +187,9 @@ class LinearSVC @Since("2.2.0") ( (new MultivariateOnlineSummarizer, new MultiClassSummarizer) )(seqOp, combOp, $(aggregationDepth)) } + instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count) + instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString) + instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString) val histogram = labelSummarizer.histogram val numInvalid = labelSummarizer.countInvalid @@ -209,7 +212,7 @@ class LinearSVC @Since("2.2.0") ( if (numInvalid != 0) { val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " + s"Found $numInvalid invalid labels." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } @@ -246,7 +249,7 @@ class LinearSVC @Since("2.2.0") ( bcFeaturesStd.destroy(blocking = false) if (state == null) { val msg = s"${optimizer.getClass.getName} failed." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 45fb585ed2262..1dde18d2d1a31 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -126,8 +126,10 @@ class NaiveBayes @Since("1.5.0") ( private[spark] def trainWithLabelCheck( dataset: Dataset[_], positiveLabel: Boolean): NaiveBayesModel = { + val instr = Instrumentation.create(this, dataset) if (positiveLabel && isDefined(thresholds)) { val numClasses = getNumClasses(dataset) + instr.logNumClasses(numClasses) require($(thresholds).length == numClasses, this.getClass.getSimpleName + ".train() called with non-matching numClasses and thresholds.length." + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") @@ -146,7 +148,6 @@ class NaiveBayes @Since("1.5.0") ( } } - val instr = Instrumentation.create(this, dataset) instr.logParams(labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol, probabilityCol, modelType, smoothing, thresholds) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 7df53a6b8ad10..3474b61e40136 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -366,7 +366,7 @@ final class OneVsRest @Since("1.4.0") ( transformSchema(dataset.schema) val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, predictionCol, parallelism) + instr.logParams(labelCol, featuresCol, predictionCol, parallelism, rawPredictionCol) instr.logNamedValue("classifier", $(classifier).getClass.getCanonicalName) // determine number of classes either from metadata if provided, or via computation. @@ -383,7 +383,7 @@ final class OneVsRest @Since("1.4.0") ( getClassifier match { case _: HasWeightCol => true case c => - logWarning(s"weightCol is ignored, as it is not supported by $c now.") + instr.logWarning(s"weightCol is ignored, as it is not supported by $c now.") false } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index f1ef26a07d3f8..040db3b94b041 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -116,6 +116,7 @@ class RandomForestClassifier @Since("1.4.0") ( set(featureSubsetStrategy, value) override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = { + val instr = Instrumentation.create(this, dataset) val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = getNumClasses(dataset) @@ -130,7 +131,6 @@ class RandomForestClassifier @Since("1.4.0") ( val strategy = super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) - val instr = Instrumentation.create(this, oldDataset) instr.logParams(labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol, impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval) @@ -141,6 +141,8 @@ class RandomForestClassifier @Since("1.4.0") ( val numFeatures = oldDataset.first().features.size val m = new RandomForestClassificationModel(uid, trees, numFeatures, numClasses) + instr.logNumClasses(numClasses) + instr.logNumFeatures(numFeatures) instr.logSuccess(m) m } From cac9b1dea1bb44fa42abf77829c05bf93f70cf20 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 9 May 2018 12:27:32 +0800 Subject: [PATCH 0764/2461] [SPARK-23972][BUILD][SQL] Update Parquet to 1.10.0. ## What changes were proposed in this pull request? This updates Parquet to 1.10.0 and updates the vectorized path for buffer management changes. Parquet 1.10.0 uses ByteBufferInputStream instead of byte arrays in encoders. This allows Parquet to break allocations into smaller chunks that are better for garbage collection. ## How was this patch tested? Existing Parquet tests. Running in production at Netflix for about 3 months. Author: Ryan Blue Closes #21070 from rdblue/SPARK-23972-update-parquet-to-1.10.0. --- dev/deps/spark-deps-hadoop-2.6 | 12 +- dev/deps/spark-deps-hadoop-2.7 | 12 +- dev/deps/spark-deps-hadoop-3.1 | 12 +- docs/sql-programming-guide.md | 2 +- pom.xml | 8 +- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../SpecificParquetRecordReaderBase.java | 2 +- .../parquet/VectorizedColumnReader.java | 39 ++-- .../parquet/VectorizedPlainValuesReader.java | 166 +++++++++++------- .../parquet/VectorizedRleValuesReader.java | 163 ++++++++--------- .../datasources/parquet/ParquetOptions.scala | 5 +- .../describe-part-after-analyze.sql.out | 12 +- .../columnar/InMemoryColumnarQuerySuite.scala | 4 +- 13 files changed, 241 insertions(+), 198 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index c3d1dd444b506..f479c13f00be6 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -162,13 +162,13 @@ orc-mapreduce-1.4.3-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar -parquet-column-1.8.2.jar -parquet-common-1.8.2.jar -parquet-encoding-1.8.2.jar -parquet-format-2.3.1.jar -parquet-hadoop-1.8.2.jar +parquet-column-1.10.0.jar +parquet-common-1.10.0.jar +parquet-encoding-1.10.0.jar +parquet-format-2.4.0.jar +parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar -parquet-jackson-1.8.2.jar +parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar py4j-0.10.6.jar pyrolite-4.13.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 290867035f91d..e7c4599cb5003 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -163,13 +163,13 @@ orc-mapreduce-1.4.3-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar -parquet-column-1.8.2.jar -parquet-common-1.8.2.jar -parquet-encoding-1.8.2.jar -parquet-format-2.3.1.jar -parquet-hadoop-1.8.2.jar +parquet-column-1.10.0.jar +parquet-common-1.10.0.jar +parquet-encoding-1.10.0.jar +parquet-format-2.4.0.jar +parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar -parquet-jackson-1.8.2.jar +parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar py4j-0.10.6.jar pyrolite-4.13.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 97ad65a4096cb..3447cd7395d95 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -181,13 +181,13 @@ orc-mapreduce-1.4.3-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar -parquet-column-1.8.2.jar -parquet-common-1.8.2.jar -parquet-encoding-1.8.2.jar -parquet-format-2.3.1.jar -parquet-hadoop-1.8.2.jar +parquet-column-1.10.0.jar +parquet-common-1.10.0.jar +parquet-encoding-1.10.0.jar +parquet-format-2.4.0.jar +parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar -parquet-jackson-1.8.2.jar +parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar py4j-0.10.6.jar pyrolite-4.13.jar diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index c521f3cb51e58..3e8946e424237 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -964,7 +964,7 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession Sets the compression codec used when writing Parquet files. If either `compression` or `parquet.compression` is specified in the table-specific options/properties, the precedence would be `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`. Acceptable values include: - none, uncompressed, snappy, gzip, lzo. + none, uncompressed, snappy, gzip, lzo, brotli, lz4, zstd. diff --git a/pom.xml b/pom.xml index 88e77ff874748..6e37e518d86e4 100644 --- a/pom.xml +++ b/pom.xml @@ -129,7 +129,7 @@ 1.2.110.12.1.1 - 1.8.2 + 1.10.01.4.3nohive1.6.0 @@ -1778,6 +1778,12 @@ parquet-hadoop${parquet.version}${parquet.deps.scope} + + + commons-pool + commons-pool + + org.apache.parquet diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 895e150756567..b00edca97cd44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -345,7 +345,7 @@ object SQLConf { "snappy, gzip, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) - .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo")) + .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo", "lz4", "brotli", "zstd")) .createWithDefault("snappy") val PARQUET_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.filterPushdown") diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index e65cd252c3ddf..10d6ed85a4080 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -293,7 +293,7 @@ protected static IntIterator createRLEIterator( return new RLEIntIterator( new RunLengthBitPackingHybridDecoder( BytesUtils.getWidthFromMaxInt(maxLevel), - new ByteArrayInputStream(bytes.toByteArray()))); + bytes.toInputStream())); } catch (IOException e) { throw new IOException("could not read levels in page for col " + descriptor, e); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 72f1d024b08ce..d5969b55eef96 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -21,6 +21,8 @@ import java.util.Arrays; import java.util.TimeZone; +import org.apache.parquet.bytes.ByteBufferInputStream; +import org.apache.parquet.bytes.BytesInput; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Dictionary; @@ -388,7 +390,8 @@ private void decodeDictionaryIds( * is guaranteed that num is smaller than the number of values left in the current page. */ - private void readBooleanBatch(int rowId, int num, WritableColumnVector column) { + private void readBooleanBatch(int rowId, int num, WritableColumnVector column) + throws IOException { if (column.dataType() != DataTypes.BooleanType) { throw constructConvertNotSupportedException(descriptor, column); } @@ -396,7 +399,7 @@ private void readBooleanBatch(int rowId, int num, WritableColumnVector column) { num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } - private void readIntBatch(int rowId, int num, WritableColumnVector column) { + private void readIntBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType || @@ -414,7 +417,7 @@ private void readIntBatch(int rowId, int num, WritableColumnVector column) { } } - private void readLongBatch(int rowId, int num, WritableColumnVector column) { + private void readLongBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. if (column.dataType() == DataTypes.LongType || DecimalType.is64BitDecimalType(column.dataType()) || @@ -434,7 +437,7 @@ private void readLongBatch(int rowId, int num, WritableColumnVector column) { } } - private void readFloatBatch(int rowId, int num, WritableColumnVector column) { + private void readFloatBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: support implicit cast to double? if (column.dataType() == DataTypes.FloatType) { @@ -445,7 +448,7 @@ private void readFloatBatch(int rowId, int num, WritableColumnVector column) { } } - private void readDoubleBatch(int rowId, int num, WritableColumnVector column) { + private void readDoubleBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions if (column.dataType() == DataTypes.DoubleType) { @@ -456,7 +459,7 @@ private void readDoubleBatch(int rowId, int num, WritableColumnVector column) { } } - private void readBinaryBatch(int rowId, int num, WritableColumnVector column) { + private void readBinaryBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; @@ -556,7 +559,7 @@ public Void visit(DataPageV2 dataPageV2) { }); } - private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) throws IOException { + private void initDataReader(Encoding dataEncoding, ByteBufferInputStream in) throws IOException { this.endOfPageValueCount = valuesRead + pageValueCount; if (dataEncoding.usesDictionary()) { this.dataColumn = null; @@ -581,7 +584,7 @@ private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) thr } try { - dataColumn.initFromPage(pageValueCount, bytes, offset); + dataColumn.initFromPage(pageValueCount, in); } catch (IOException e) { throw new IOException("could not read page in col " + descriptor, e); } @@ -602,12 +605,11 @@ private void readPageV1(DataPageV1 page) throws IOException { this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader); this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader); try { - byte[] bytes = page.getBytes().toByteArray(); - rlReader.initFromPage(pageValueCount, bytes, 0); - int next = rlReader.getNextOffset(); - dlReader.initFromPage(pageValueCount, bytes, next); - next = dlReader.getNextOffset(); - initDataReader(page.getValueEncoding(), bytes, next); + BytesInput bytes = page.getBytes(); + ByteBufferInputStream in = bytes.toInputStream(); + rlReader.initFromPage(pageValueCount, in); + dlReader.initFromPage(pageValueCount, in); + initDataReader(page.getValueEncoding(), in); } catch (IOException e) { throw new IOException("could not read page " + page + " in col " + descriptor, e); } @@ -619,12 +621,13 @@ private void readPageV2(DataPageV2 page) throws IOException { page.getRepetitionLevels(), descriptor); int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); - this.defColumn = new VectorizedRleValuesReader(bitWidth); + // do not read the length from the stream. v2 pages handle dividing the page bytes. + this.defColumn = new VectorizedRleValuesReader(bitWidth, false); this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumn); - this.defColumn.initFromBuffer( - this.pageValueCount, page.getDefinitionLevels().toByteArray()); + this.defColumn.initFromPage( + this.pageValueCount, page.getDefinitionLevels().toInputStream()); try { - initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0); + initDataReader(page.getDataEncoding(), page.getData().toInputStream()); } catch (IOException e) { throw new IOException("could not read page " + page + " in col " + descriptor, e); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index 5b75f719339fb..aacefacfc1c1a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -20,34 +20,30 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; +import org.apache.parquet.bytes.ByteBufferInputStream; +import org.apache.parquet.io.ParquetDecodingException; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; -import org.apache.spark.unsafe.Platform; import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.io.api.Binary; +import org.apache.spark.unsafe.Platform; /** * An implementation of the Parquet PLAIN decoder that supports the vectorized interface. */ public class VectorizedPlainValuesReader extends ValuesReader implements VectorizedValuesReader { - private byte[] buffer; - private int offset; - private int bitOffset; // Only used for booleans. - private ByteBuffer byteBuffer; // used to wrap the byte array buffer + private ByteBufferInputStream in = null; - private static final boolean bigEndianPlatform = - ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); + // Only used for booleans. + private int bitOffset; + private byte currentByte = 0; public VectorizedPlainValuesReader() { } @Override - public void initFromPage(int valueCount, byte[] bytes, int offset) throws IOException { - this.buffer = bytes; - this.offset = offset + Platform.BYTE_ARRAY_OFFSET; - if (bigEndianPlatform) { - byteBuffer = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN); - } + public void initFromPage(int valueCount, ByteBufferInputStream in) throws IOException { + this.in = in; } @Override @@ -63,115 +59,157 @@ public final void readBooleans(int total, WritableColumnVector c, int rowId) { } } + private ByteBuffer getBuffer(int length) { + try { + return in.slice(length).order(ByteOrder.LITTLE_ENDIAN); + } catch (IOException e) { + throw new ParquetDecodingException("Failed to read " + length + " bytes", e); + } + } + @Override public final void readIntegers(int total, WritableColumnVector c, int rowId) { - c.putIntsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); - offset += 4 * total; + int requiredBytes = total * 4; + ByteBuffer buffer = getBuffer(requiredBytes); + + if (buffer.hasArray()) { + int offset = buffer.arrayOffset() + buffer.position(); + c.putIntsLittleEndian(rowId, total, buffer.array(), offset); + } else { + for (int i = 0; i < total; i += 1) { + c.putInt(rowId + i, buffer.getInt()); + } + } } @Override public final void readLongs(int total, WritableColumnVector c, int rowId) { - c.putLongsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); - offset += 8 * total; + int requiredBytes = total * 8; + ByteBuffer buffer = getBuffer(requiredBytes); + + if (buffer.hasArray()) { + int offset = buffer.arrayOffset() + buffer.position(); + c.putLongsLittleEndian(rowId, total, buffer.array(), offset); + } else { + for (int i = 0; i < total; i += 1) { + c.putLong(rowId + i, buffer.getLong()); + } + } } @Override public final void readFloats(int total, WritableColumnVector c, int rowId) { - c.putFloats(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); - offset += 4 * total; + int requiredBytes = total * 4; + ByteBuffer buffer = getBuffer(requiredBytes); + + if (buffer.hasArray()) { + int offset = buffer.arrayOffset() + buffer.position(); + c.putFloats(rowId, total, buffer.array(), offset); + } else { + for (int i = 0; i < total; i += 1) { + c.putFloat(rowId + i, buffer.getFloat()); + } + } } @Override public final void readDoubles(int total, WritableColumnVector c, int rowId) { - c.putDoubles(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); - offset += 8 * total; + int requiredBytes = total * 8; + ByteBuffer buffer = getBuffer(requiredBytes); + + if (buffer.hasArray()) { + int offset = buffer.arrayOffset() + buffer.position(); + c.putDoubles(rowId, total, buffer.array(), offset); + } else { + for (int i = 0; i < total; i += 1) { + c.putDouble(rowId + i, buffer.getDouble()); + } + } } @Override public final void readBytes(int total, WritableColumnVector c, int rowId) { - for (int i = 0; i < total; i++) { - // Bytes are stored as a 4-byte little endian int. Just read the first byte. - // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride. - c.putByte(rowId + i, Platform.getByte(buffer, offset)); - offset += 4; + // Bytes are stored as a 4-byte little endian int. Just read the first byte. + // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride. + int requiredBytes = total * 4; + ByteBuffer buffer = getBuffer(requiredBytes); + + for (int i = 0; i < total; i += 1) { + c.putByte(rowId + i, buffer.get()); + // skip the next 3 bytes + buffer.position(buffer.position() + 3); } } @Override public final boolean readBoolean() { - byte b = Platform.getByte(buffer, offset); - boolean v = (b & (1 << bitOffset)) != 0; + // TODO: vectorize decoding and keep boolean[] instead of currentByte + if (bitOffset == 0) { + try { + currentByte = (byte) in.read(); + } catch (IOException e) { + throw new ParquetDecodingException("Failed to read a byte", e); + } + } + + boolean v = (currentByte & (1 << bitOffset)) != 0; bitOffset += 1; if (bitOffset == 8) { bitOffset = 0; - offset++; } return v; } @Override public final int readInteger() { - int v = Platform.getInt(buffer, offset); - if (bigEndianPlatform) { - v = java.lang.Integer.reverseBytes(v); - } - offset += 4; - return v; + return getBuffer(4).getInt(); } @Override public final long readLong() { - long v = Platform.getLong(buffer, offset); - if (bigEndianPlatform) { - v = java.lang.Long.reverseBytes(v); - } - offset += 8; - return v; + return getBuffer(8).getLong(); } @Override public final byte readByte() { - return (byte)readInteger(); + return (byte) readInteger(); } @Override public final float readFloat() { - float v; - if (!bigEndianPlatform) { - v = Platform.getFloat(buffer, offset); - } else { - v = byteBuffer.getFloat(offset - Platform.BYTE_ARRAY_OFFSET); - } - offset += 4; - return v; + return getBuffer(4).getFloat(); } @Override public final double readDouble() { - double v; - if (!bigEndianPlatform) { - v = Platform.getDouble(buffer, offset); - } else { - v = byteBuffer.getDouble(offset - Platform.BYTE_ARRAY_OFFSET); - } - offset += 8; - return v; + return getBuffer(8).getDouble(); } @Override public final void readBinary(int total, WritableColumnVector v, int rowId) { for (int i = 0; i < total; i++) { int len = readInteger(); - int start = offset; - offset += len; - v.putByteArray(rowId + i, buffer, start - Platform.BYTE_ARRAY_OFFSET, len); + ByteBuffer buffer = getBuffer(len); + if (buffer.hasArray()) { + v.putByteArray(rowId + i, buffer.array(), buffer.arrayOffset() + buffer.position(), len); + } else { + byte[] bytes = new byte[len]; + buffer.get(bytes); + v.putByteArray(rowId + i, bytes); + } } } @Override public final Binary readBinary(int len) { - Binary result = Binary.fromConstantByteArray(buffer, offset - Platform.BYTE_ARRAY_OFFSET, len); - offset += len; - return result; + ByteBuffer buffer = getBuffer(len); + if (buffer.hasArray()) { + return Binary.fromConstantByteArray( + buffer.array(), buffer.arrayOffset() + buffer.position(), len); + } else { + byte[] bytes = new byte[len]; + buffer.get(bytes); + return Binary.fromConstantByteArray(bytes); + } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index fc7fa70c39419..fe3d31ae8e746 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet; import org.apache.parquet.Preconditions; +import org.apache.parquet.bytes.ByteBufferInputStream; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.column.values.bitpacking.BytePacker; @@ -27,6 +28,9 @@ import org.apache.spark.sql.execution.vectorized.WritableColumnVector; +import java.io.IOException; +import java.nio.ByteBuffer; + /** * A values reader for Parquet's run-length encoded data. This is based off of the version in * parquet-mr with these changes: @@ -49,9 +53,7 @@ private enum MODE { } // Encoded data. - private byte[] in; - private int end; - private int offset; + private ByteBufferInputStream in; // bit/byte width of decoded data and utility to batch unpack them. private int bitWidth; @@ -70,45 +72,40 @@ private enum MODE { // If true, the bit width is fixed. This decoder is used in different places and this also // controls if we need to read the bitwidth from the beginning of the data stream. private final boolean fixedWidth; + private final boolean readLength; public VectorizedRleValuesReader() { - fixedWidth = false; + this.fixedWidth = false; + this.readLength = false; } public VectorizedRleValuesReader(int bitWidth) { - fixedWidth = true; + this.fixedWidth = true; + this.readLength = bitWidth != 0; + init(bitWidth); + } + + public VectorizedRleValuesReader(int bitWidth, boolean readLength) { + this.fixedWidth = true; + this.readLength = readLength; init(bitWidth); } @Override - public void initFromPage(int valueCount, byte[] page, int start) { - this.offset = start; - this.in = page; + public void initFromPage(int valueCount, ByteBufferInputStream in) throws IOException { + this.in = in; if (fixedWidth) { - if (bitWidth != 0) { + // initialize for repetition and definition levels + if (readLength) { int length = readIntLittleEndian(); - this.end = this.offset + length; + this.in = in.sliceStream(length); } } else { - this.end = page.length; - if (this.end != this.offset) init(page[this.offset++] & 255); - } - if (bitWidth == 0) { - // 0 bit width, treat this as an RLE run of valueCount number of 0's. - this.mode = MODE.RLE; - this.currentCount = valueCount; - this.currentValue = 0; - } else { - this.currentCount = 0; + // initialize for values + if (in.available() > 0) { + init(in.read()); + } } - } - - // Initialize the reader from a buffer. This is used for the V2 page encoding where the - // definition are in its own buffer. - public void initFromBuffer(int valueCount, byte[] data) { - this.offset = 0; - this.in = data; - this.end = data.length; if (bitWidth == 0) { // 0 bit width, treat this as an RLE run of valueCount number of 0's. this.mode = MODE.RLE; @@ -129,11 +126,6 @@ private void init(int bitWidth) { this.packer = Packer.LITTLE_ENDIAN.newBytePacker(bitWidth); } - @Override - public int getNextOffset() { - return this.end; - } - @Override public boolean readBoolean() { return this.readInteger() != 0; @@ -182,7 +174,7 @@ public void readIntegers( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -217,7 +209,7 @@ public void readBooleans( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -251,7 +243,7 @@ public void readBytes( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -285,7 +277,7 @@ public void readShorts( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -321,7 +313,7 @@ public void readLongs( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -355,7 +347,7 @@ public void readFloats( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -389,7 +381,7 @@ public void readDoubles( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -423,7 +415,7 @@ public void readBinarys( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -462,7 +454,7 @@ public void readIntegers( WritableColumnVector nulls, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -559,12 +551,12 @@ public Binary readBinary(int len) { /** * Reads the next varint encoded int. */ - private int readUnsignedVarInt() { + private int readUnsignedVarInt() throws IOException { int value = 0; int shift = 0; int b; do { - b = in[offset++] & 255; + b = in.read(); value |= (b & 0x7F) << shift; shift += 7; } while ((b & 0x80) != 0); @@ -574,35 +566,32 @@ private int readUnsignedVarInt() { /** * Reads the next 4 byte little endian int. */ - private int readIntLittleEndian() { - int ch4 = in[offset] & 255; - int ch3 = in[offset + 1] & 255; - int ch2 = in[offset + 2] & 255; - int ch1 = in[offset + 3] & 255; - offset += 4; + private int readIntLittleEndian() throws IOException { + int ch4 = in.read(); + int ch3 = in.read(); + int ch2 = in.read(); + int ch1 = in.read(); return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4 << 0)); } /** * Reads the next byteWidth little endian int. */ - private int readIntLittleEndianPaddedOnBitWidth() { + private int readIntLittleEndianPaddedOnBitWidth() throws IOException { switch (bytesWidth) { case 0: return 0; case 1: - return in[offset++] & 255; + return in.read(); case 2: { - int ch2 = in[offset] & 255; - int ch1 = in[offset + 1] & 255; - offset += 2; + int ch2 = in.read(); + int ch1 = in.read(); return (ch1 << 8) + ch2; } case 3: { - int ch3 = in[offset] & 255; - int ch2 = in[offset + 1] & 255; - int ch1 = in[offset + 2] & 255; - offset += 3; + int ch3 = in.read(); + int ch2 = in.read(); + int ch1 = in.read(); return (ch1 << 16) + (ch2 << 8) + (ch3 << 0); } case 4: { @@ -619,32 +608,36 @@ private int ceil8(int value) { /** * Reads the next group. */ - private void readNextGroup() { - int header = readUnsignedVarInt(); - this.mode = (header & 1) == 0 ? MODE.RLE : MODE.PACKED; - switch (mode) { - case RLE: - this.currentCount = header >>> 1; - this.currentValue = readIntLittleEndianPaddedOnBitWidth(); - return; - case PACKED: - int numGroups = header >>> 1; - this.currentCount = numGroups * 8; - int bytesToRead = ceil8(this.currentCount * this.bitWidth); - - if (this.currentBuffer.length < this.currentCount) { - this.currentBuffer = new int[this.currentCount]; - } - currentBufferIdx = 0; - int valueIndex = 0; - for (int byteIndex = offset; valueIndex < this.currentCount; byteIndex += this.bitWidth) { - this.packer.unpack8Values(in, byteIndex, this.currentBuffer, valueIndex); - valueIndex += 8; - } - offset += bytesToRead; - return; - default: - throw new ParquetDecodingException("not a valid mode " + this.mode); + private void readNextGroup() { + try { + int header = readUnsignedVarInt(); + this.mode = (header & 1) == 0 ? MODE.RLE : MODE.PACKED; + switch (mode) { + case RLE: + this.currentCount = header >>> 1; + this.currentValue = readIntLittleEndianPaddedOnBitWidth(); + return; + case PACKED: + int numGroups = header >>> 1; + this.currentCount = numGroups * 8; + + if (this.currentBuffer.length < this.currentCount) { + this.currentBuffer = new int[this.currentCount]; + } + currentBufferIdx = 0; + int valueIndex = 0; + while (valueIndex < this.currentCount) { + // values are bit packed 8 at a time, so reading bitWidth will always work + ByteBuffer buffer = in.slice(bitWidth); + this.packer.unpack8Values(buffer, buffer.position(), this.currentBuffer, valueIndex); + valueIndex += 8; + } + return; + default: + throw new ParquetDecodingException("not a valid mode " + this.mode); + } + } catch (IOException e) { + throw new ParquetDecodingException("Failed to read from input stream", e); } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index f36a89a4c3c5f..9cfc30725f03a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -81,7 +81,10 @@ object ParquetOptions { "uncompressed" -> CompressionCodecName.UNCOMPRESSED, "snappy" -> CompressionCodecName.SNAPPY, "gzip" -> CompressionCodecName.GZIP, - "lzo" -> CompressionCodecName.LZO) + "lzo" -> CompressionCodecName.LZO, + "lz4" -> CompressionCodecName.LZ4, + "brotli" -> CompressionCodecName.BROTLI, + "zstd" -> CompressionCodecName.ZSTD) def getParquetCompressionCodecName(name: String): String = { shortParquetCompressionCodecNames(name).name() diff --git a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out index 51dac111029e8..58ed201e2a60f 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out @@ -89,7 +89,7 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 -Partition Statistics 1067 bytes, 3 rows +Partition Statistics 1121 bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -122,7 +122,7 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 -Partition Statistics 1067 bytes, 3 rows +Partition Statistics 1121 bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -147,7 +147,7 @@ Database default Table t Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 -Partition Statistics 1080 bytes, 4 rows +Partition Statistics 1098 bytes, 4 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -180,7 +180,7 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 -Partition Statistics 1067 bytes, 3 rows +Partition Statistics 1121 bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -205,7 +205,7 @@ Database default Table t Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 -Partition Statistics 1080 bytes, 4 rows +Partition Statistics 1098 bytes, 4 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -230,7 +230,7 @@ Database default Table t Partition Values [ds=2017-09-01, hr=5] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-09-01/hr=5 -Partition Statistics 1054 bytes, 2 rows +Partition Statistics 1144 bytes, 2 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 863703b15f4f1..efc2f20a907f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -503,7 +503,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { case plan: InMemoryRelation => plan }.head // InMemoryRelation's stats is file size before the underlying RDD is materialized - assert(inMemoryRelation.computeStats().sizeInBytes === 740) + assert(inMemoryRelation.computeStats().sizeInBytes === 800) // InMemoryRelation's stats is updated after materializing RDD dfFromFile.collect() @@ -516,7 +516,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { // Even CBO enabled, InMemoryRelation's stats keeps as the file size before table's stats // is calculated - assert(inMemoryRelation2.computeStats().sizeInBytes === 740) + assert(inMemoryRelation2.computeStats().sizeInBytes === 800) // InMemoryRelation's stats should be updated after calculating stats of the table // clear cache to simulate a fresh environment From 6ea582e36ab0a2e4e01340f6fc8cfb8d493d567d Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 9 May 2018 09:15:16 -0700 Subject: [PATCH 0765/2461] [SPARK-24181][SQL] Better error message for writing sorted data ## What changes were proposed in this pull request? The exception message should clearly distinguish sorting and bucketing in `save` and `jdbc` write. When a user tries to write a sorted data using save or insertInto, it will throw an exception with message that `s"'$operation' does not support bucketing right now""`. We should throw `s"'$operation' does not support sortBy right now""` instead. ## How was this patch tested? More tests in `DataFrameReaderWriterSuite.scala` Author: DB Tsai Closes #21235 from dbtsai/fixException. --- .../apache/spark/sql/DataFrameWriter.scala | 12 ++++++--- .../sql/sources/BucketedWriteSuite.scala | 27 ++++++++++++++++--- .../sql/test/DataFrameReaderWriterSuite.scala | 16 +++++++++-- 3 files changed, 46 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index e183fa6f9542b..90bea2d676e22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -330,8 +330,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } private def getBucketSpec: Option[BucketSpec] = { - if (sortColumnNames.isDefined) { - require(numBuckets.isDefined, "sortBy must be used together with bucketBy") + if (sortColumnNames.isDefined && numBuckets.isEmpty) { + throw new AnalysisException("sortBy must be used together with bucketBy") } numBuckets.map { n => @@ -340,8 +340,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } private def assertNotBucketed(operation: String): Unit = { - if (numBuckets.isDefined || sortColumnNames.isDefined) { - throw new AnalysisException(s"'$operation' does not support bucketing right now") + if (getBucketSpec.isDefined) { + if (sortColumnNames.isEmpty) { + throw new AnalysisException(s"'$operation' does not support bucketBy right now") + } else { + throw new AnalysisException(s"'$operation' does not support bucketBy and sortBy right now") + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 93f3efe2ccc4a..5ff1ea84d9a7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -60,7 +60,10 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { test("specify sorting columns without bucketing columns") { val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") - intercept[IllegalArgumentException](df.write.sortBy("j").saveAsTable("tt")) + val e = intercept[AnalysisException] { + df.write.sortBy("j").saveAsTable("tt") + } + assert(e.getMessage == "sortBy must be used together with bucketBy;") } test("sorting by non-orderable column") { @@ -74,7 +77,16 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { val e = intercept[AnalysisException] { df.write.bucketBy(2, "i").parquet("/tmp/path") } - assert(e.getMessage == "'save' does not support bucketing right now;") + assert(e.getMessage == "'save' does not support bucketBy right now;") + } + + test("write bucketed and sorted data using save()") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + + val e = intercept[AnalysisException] { + df.write.bucketBy(2, "i").sortBy("i").parquet("/tmp/path") + } + assert(e.getMessage == "'save' does not support bucketBy and sortBy right now;") } test("write bucketed data using insertInto()") { @@ -83,7 +95,16 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { val e = intercept[AnalysisException] { df.write.bucketBy(2, "i").insertInto("tt") } - assert(e.getMessage == "'insertInto' does not support bucketing right now;") + assert(e.getMessage == "'insertInto' does not support bucketBy right now;") + } + + test("write bucketed and sorted data using insertInto()") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + + val e = intercept[AnalysisException] { + df.write.bucketBy(2, "i").sortBy("i").insertInto("tt") + } + assert(e.getMessage == "'insertInto' does not support bucketBy and sortBy right now;") } private lazy val df = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 14b1feb2adc20..b65058fffd339 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -276,7 +276,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be assert(LastOptions.parameters("doubleOpt") == "6.7") } - test("check jdbc() does not support partitioning or bucketing") { + test("check jdbc() does not support partitioning, bucketBy or sortBy") { val df = spark.read.text(Utils.createTempDir(namePrefix = "text").getCanonicalPath) var w = df.write.partitionBy("value") @@ -287,7 +287,19 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be w = df.write.bucketBy(2, "value") e = intercept[AnalysisException](w.jdbc(null, null, null)) - Seq("jdbc", "bucketing").foreach { s => + Seq("jdbc", "does not support bucketBy right now").foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + + w = df.write.sortBy("value") + e = intercept[AnalysisException](w.jdbc(null, null, null)) + Seq("sortBy must be used together with bucketBy").foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + + w = df.write.bucketBy(2, "value").sortBy("value") + e = intercept[AnalysisException](w.jdbc(null, null, null)) + Seq("jdbc", "does not support bucketBy and sortBy right now").foreach { s => assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } From 94155d0395324a012db2fc8a57edb3cd90b61e96 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 9 May 2018 10:34:57 -0700 Subject: [PATCH 0766/2461] [MINOR][ML][DOC] Improved Naive Bayes user guide explanation ## What changes were proposed in this pull request? This copies the material from the spark.mllib user guide page for Naive Bayes to the spark.ml user guide page. I also improved the wording and organization slightly. ## How was this patch tested? Built docs locally. Author: Joseph K. Bradley Closes #21272 from jkbradley/nb-doc-update. --- docs/ml-classification-regression.md | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index d660655e193eb..b3d109039da4d 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -455,11 +455,29 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classificat ## Naive Bayes [Naive Bayes classifiers](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) are a family of simple -probabilistic classifiers based on applying Bayes' theorem with strong (naive) independence -assumptions between the features. The `spark.ml` implementation currently supports both [multinomial -naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html) +probabilistic, multiclass classifiers based on applying Bayes' theorem with strong (naive) independence +assumptions between every pair of features. + +Naive Bayes can be trained very efficiently. With a single pass over the training data, +it computes the conditional probability distribution of each feature given each label. +For prediction, it applies Bayes' theorem to compute the conditional probability distribution +of each label given an observation. + +MLlib supports both [multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). -More information can be found in the section on [Naive Bayes in MLlib](mllib-naive-bayes.html#naive-bayes-sparkmllib). + +*Input data*: +These models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). +Within that context, each observation is a document and each feature represents a term. +A feature's value is the frequency of the term (in multinomial Naive Bayes) or +a zero or one indicating whether the term was found in the document (in Bernoulli Naive Bayes). +Feature values must be *non-negative*. The model type is selected with an optional parameter +"multinomial" or "bernoulli" with "multinomial" as the default. +For document classification, the input feature vectors should usually be sparse vectors. +Since the training data is only used once, it is not necessary to cache it. + +[Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by +setting the parameter $\lambda$ (default to $1.0$). **Examples** From cc613b552e753d03cb62661591de59e1c8d82c74 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 13 Apr 2018 14:28:24 -0700 Subject: [PATCH 0767/2461] [PYSPARK] Update py4j to version 0.10.7. --- LICENSE | 2 +- bin/pyspark | 6 +- bin/pyspark2.cmd | 2 +- core/pom.xml | 2 +- .../org/apache/spark/SecurityManager.scala | 12 +-- .../api/python/PythonGatewayServer.scala | 50 ++++++--- .../apache/spark/api/python/PythonRDD.scala | 29 +++-- .../apache/spark/api/python/PythonUtils.scala | 2 +- .../api/python/PythonWorkerFactory.scala | 20 ++-- .../apache/spark/deploy/PythonRunner.scala | 12 ++- .../spark/internal/config/package.scala | 5 + .../spark/security/SocketAuthHelper.scala | 101 ++++++++++++++++++ .../scala/org/apache/spark/util/Utils.scala | 13 ++- .../security/SocketAuthHelperSuite.scala | 97 +++++++++++++++++ dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- dev/deps/spark-deps-hadoop-3.1 | 2 +- dev/run-pip-tests | 2 +- python/README.md | 2 +- python/docs/Makefile | 2 +- python/lib/py4j-0.10.6-src.zip | Bin 80352 -> 0 bytes python/lib/py4j-0.10.7-src.zip | Bin 0 -> 42437 bytes python/pyspark/context.py | 4 +- python/pyspark/daemon.py | 21 +++- python/pyspark/java_gateway.py | 93 +++++++++------- python/pyspark/rdd.py | 21 ++-- python/pyspark/sql/dataframe.py | 12 +-- python/pyspark/worker.py | 7 +- python/setup.py | 2 +- .../org/apache/spark/deploy/yarn/Client.scala | 2 +- .../spark/deploy/yarn/YarnClusterSuite.scala | 2 +- sbin/spark-config.sh | 2 +- .../scala/org/apache/spark/sql/Dataset.scala | 6 +- 33 files changed, 417 insertions(+), 120 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala create mode 100644 core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala delete mode 100644 python/lib/py4j-0.10.6-src.zip create mode 100644 python/lib/py4j-0.10.7-src.zip diff --git a/LICENSE b/LICENSE index c2b0d72663b55..820f14dbdeed0 100644 --- a/LICENSE +++ b/LICENSE @@ -263,7 +263,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.6 - http://py4j.sourceforge.net/) + (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.7 - http://py4j.sourceforge.net/) (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) (BSD licence) sbt and sbt-launch-lib.bash (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) diff --git a/bin/pyspark b/bin/pyspark index dd286277c1fc1..5d5affb1f97c3 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -25,14 +25,14 @@ source "${SPARK_HOME}"/bin/load-spark-env.sh export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]" # In Spark 2.0, IPYTHON and IPYTHON_OPTS are removed and pyspark fails to launch if either option -# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython +# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython # to use IPython and set PYSPARK_DRIVER_PYTHON_OPTS to pass options when starting the Python driver # (e.g. PYSPARK_DRIVER_PYTHON_OPTS='notebook'). This supports full customization of the IPython # and executor Python executables. # Fail noisily if removed options are set if [[ -n "$IPYTHON" || -n "$IPYTHON_OPTS" ]]; then - echo "Error in pyspark startup:" + echo "Error in pyspark startup:" echo "IPYTHON and IPYTHON_OPTS are removed in Spark 2.0+. Remove these from the environment and set PYSPARK_DRIVER_PYTHON and PYSPARK_DRIVER_PYTHON_OPTS instead." exit 1 fi @@ -57,7 +57,7 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 663670f2fddaf..15fa910c277b3 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( ) set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH% -set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.6-src.zip;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.7-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py diff --git a/core/pom.xml b/core/pom.xml index 093a9869b6dd7..220522d3a8296 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -350,7 +350,7 @@ net.sf.py4j py4j - 0.10.6 + 0.10.7 org.apache.spark diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index dbfd5a514c189..b87476322573d 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -17,15 +17,10 @@ package org.apache.spark -import java.lang.{Byte => JByte} import java.net.{Authenticator, PasswordAuthentication} import java.nio.charset.StandardCharsets.UTF_8 -import java.security.{KeyStore, SecureRandom} -import java.security.cert.X509Certificate import javax.net.ssl._ -import com.google.common.hash.HashCodes -import com.google.common.io.Files import org.apache.hadoop.io.Text import org.apache.hadoop.security.{Credentials, UserGroupInformation} @@ -365,13 +360,8 @@ private[spark] class SecurityManager( return } - val rnd = new SecureRandom() - val length = sparkConf.getInt("spark.authenticate.secretBitLength", 256) / JByte.SIZE - val secretBytes = new Array[Byte](length) - rnd.nextBytes(secretBytes) - + secretKey = Utils.createSecret(sparkConf) val creds = new Credentials() - secretKey = HashCodes.fromBytes(secretBytes).toString() creds.addSecretKey(SECRET_LOOKUP_KEY, secretKey.getBytes(UTF_8)) UserGroupInformation.getCurrentUser().addCredentials(creds) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala index 11f2432575d84..9ddc4a4910180 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala @@ -17,26 +17,39 @@ package org.apache.spark.api.python -import java.io.DataOutputStream -import java.net.Socket +import java.io.{DataOutputStream, File, FileOutputStream} +import java.net.InetAddress +import java.nio.charset.StandardCharsets.UTF_8 +import java.nio.file.Files import py4j.GatewayServer +import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.util.Utils /** - * Process that starts a Py4J GatewayServer on an ephemeral port and communicates the bound port - * back to its caller via a callback port specified by the caller. + * Process that starts a Py4J GatewayServer on an ephemeral port. * * This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py). */ private[spark] object PythonGatewayServer extends Logging { initializeLogIfNecessary(true) - def main(args: Array[String]): Unit = Utils.tryOrExit { - // Start a GatewayServer on an ephemeral port - val gatewayServer: GatewayServer = new GatewayServer(null, 0) + def main(args: Array[String]): Unit = { + val secret = Utils.createSecret(new SparkConf()) + + // Start a GatewayServer on an ephemeral port. Make sure the callback client is configured + // with the same secret, in case the app needs callbacks from the JVM to the underlying + // python processes. + val localhost = InetAddress.getLoopbackAddress() + val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder() + .authToken(secret) + .javaPort(0) + .javaAddress(localhost) + .callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret) + .build() + gatewayServer.start() val boundPort: Int = gatewayServer.getListeningPort if (boundPort == -1) { @@ -46,15 +59,24 @@ private[spark] object PythonGatewayServer extends Logging { logDebug(s"Started PythonGatewayServer on port $boundPort") } - // Communicate the bound port back to the caller via the caller-specified callback port - val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST") - val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt - logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort") - val callbackSocket = new Socket(callbackHost, callbackPort) - val dos = new DataOutputStream(callbackSocket.getOutputStream) + // Communicate the connection information back to the python process by writing the + // information in the requested file. This needs to match the read side in java_gateway.py. + val connectionInfoPath = new File(sys.env("_PYSPARK_DRIVER_CONN_INFO_PATH")) + val tmpPath = Files.createTempFile(connectionInfoPath.getParentFile().toPath(), + "connection", ".info").toFile() + + val dos = new DataOutputStream(new FileOutputStream(tmpPath)) dos.writeInt(boundPort) + + val secretBytes = secret.getBytes(UTF_8) + dos.writeInt(secretBytes.length) + dos.write(secretBytes, 0, secretBytes.length) dos.close() - callbackSocket.close() + + if (!tmpPath.renameTo(connectionInfoPath)) { + logError(s"Unable to write connection information to $connectionInfoPath.") + System.exit(1) + } // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies: while (System.in.read() != -1) { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index f6293c0dc5091..a1ee2f7d1b119 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -38,6 +38,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.input.PortableDataStream import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util._ @@ -107,6 +108,12 @@ private[spark] object PythonRDD extends Logging { // remember the broadcasts sent to each worker private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() + // Authentication helper used when serving iterator data. + private lazy val authHelper = { + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + new SocketAuthHelper(conf) + } + def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = { synchronized { workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) @@ -129,12 +136,13 @@ private[spark] object PythonRDD extends Logging { * (effectively a collect()), but allows you to run on a certain subset of partitions, * or to enable local execution. * - * @return the port number of a local socket which serves the data collected from this job. + * @return 2-tuple (as a Java array) with the port number of a local socket which serves the + * data collected from this job, and the secret for authentication. */ def runJob( sc: SparkContext, rdd: JavaRDD[Array[Byte]], - partitions: JArrayList[Int]): Int = { + partitions: JArrayList[Int]): Array[Any] = { type ByteArray = Array[Byte] type UnrolledPartition = Array[ByteArray] val allPartitions: Array[UnrolledPartition] = @@ -147,13 +155,14 @@ private[spark] object PythonRDD extends Logging { /** * A helper function to collect an RDD as an iterator, then serve it via socket. * - * @return the port number of a local socket which serves the data collected from this job. + * @return 2-tuple (as a Java array) with the port number of a local socket which serves the + * data collected from this job, and the secret for authentication. */ - def collectAndServe[T](rdd: RDD[T]): Int = { + def collectAndServe[T](rdd: RDD[T]): Array[Any] = { serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}") } - def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = { + def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = { serveIterator(rdd.toLocalIterator, s"serve toLocalIterator") } @@ -384,8 +393,11 @@ private[spark] object PythonRDD extends Logging { * and send them into this connection. * * The thread will terminate after all the data are sent or any exceptions happen. + * + * @return 2-tuple (as a Java array) with the port number of a local socket which serves the + * data collected from this job, and the secret for authentication. */ - def serveIterator[T](items: Iterator[T], threadName: String): Int = { + def serveIterator(items: Iterator[_], threadName: String): Array[Any] = { val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) // Close the socket if no connection in 15 seconds serverSocket.setSoTimeout(15000) @@ -395,11 +407,14 @@ private[spark] object PythonRDD extends Logging { override def run() { try { val sock = serverSocket.accept() + authHelper.authClient(sock) + val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) Utils.tryWithSafeFinally { writeIteratorToStream(items, out) } { out.close() + sock.close() } } catch { case NonFatal(e) => @@ -410,7 +425,7 @@ private[spark] object PythonRDD extends Logging { } }.start() - serverSocket.getLocalPort + Array(serverSocket.getLocalPort, authHelper.secret) } private def getMergedConf(confAsMap: java.util.HashMap[String, String], diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 92e228a9dd10c..27a5e19f96a14 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -32,7 +32,7 @@ private[spark] object PythonUtils { val pythonPath = new ArrayBuffer[String] for (sparkHome <- sys.env.get("SPARK_HOME")) { pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator) - pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.6-src.zip").mkString(File.separator) + pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.7-src.zip").mkString(File.separator) } pythonPath ++= SparkContext.jarOfObject(this) pythonPath.mkString(File.pathSeparator) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 2340580b54f67..6afa37aa36fd3 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -27,6 +27,7 @@ import scala.collection.mutable import org.apache.spark._ import org.apache.spark.internal.Logging +import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util.{RedirectThread, Utils} private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String]) @@ -67,6 +68,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String value }.getOrElse("pyspark.worker") + private val authHelper = new SocketAuthHelper(SparkEnv.get.conf) + var daemon: Process = null val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) var daemonPort: Int = 0 @@ -108,6 +111,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String if (pid < 0) { throw new IllegalStateException("Python daemon failed to launch worker with code " + pid) } + + authHelper.authToServer(socket) daemonWorkers.put(socket, pid) socket } @@ -145,25 +150,24 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String workerEnv.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") + workerEnv.put("PYTHON_WORKER_FACTORY_PORT", serverSocket.getLocalPort.toString) + workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret) val worker = pb.start() // Redirect worker stdout and stderr redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream) - // Tell the worker our port - val out = new OutputStreamWriter(worker.getOutputStream, StandardCharsets.UTF_8) - out.write(serverSocket.getLocalPort + "\n") - out.flush() - - // Wait for it to connect to our socket + // Wait for it to connect to our socket, and validate the auth secret. serverSocket.setSoTimeout(10000) + try { val socket = serverSocket.accept() + authHelper.authClient(socket) simpleWorkers.put(socket, worker) return socket } catch { case e: Exception => - throw new SparkException("Python worker did not connect back in time", e) + throw new SparkException("Python worker failed to connect back.", e) } } finally { if (serverSocket != null) { @@ -187,6 +191,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) + workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") daemon = pb.start() @@ -218,7 +223,6 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String // Redirect daemon stdout and stderr redirectStreamsToStderr(in, daemon.getErrorStream) - } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 7aca305783a7f..1b7e031ee0678 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy import java.io.File -import java.net.URI +import java.net.{InetAddress, URI} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -39,6 +39,7 @@ object PythonRunner { val pyFiles = args(1) val otherArgs = args.slice(2, args.length) val sparkConf = new SparkConf() + val secret = Utils.createSecret(sparkConf) val pythonExec = sparkConf.get(PYSPARK_DRIVER_PYTHON) .orElse(sparkConf.get(PYSPARK_PYTHON)) .orElse(sys.env.get("PYSPARK_DRIVER_PYTHON")) @@ -51,7 +52,13 @@ object PythonRunner { // Launch a Py4J gateway server for the process to connect to; this will let it see our // Java system properties and such - val gatewayServer = new py4j.GatewayServer(null, 0) + val localhost = InetAddress.getLoopbackAddress() + val gatewayServer = new py4j.GatewayServer.GatewayServerBuilder() + .authToken(secret) + .javaPort(0) + .javaAddress(localhost) + .callbackClient(py4j.GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret) + .build() val thread = new Thread(new Runnable() { override def run(): Unit = Utils.logUncaughtExceptions { gatewayServer.start() @@ -82,6 +89,7 @@ object PythonRunner { // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) + env.put("PYSPARK_GATEWAY_SECRET", secret) // pass conf spark.pyspark.python to python process, the only way to pass info to // python process is through environment variable. sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _)) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 6bb98c37b4479..82f0a04e94b1c 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -352,6 +352,11 @@ package object config { .regexConf .createOptional + private[spark] val AUTH_SECRET_BIT_LENGTH = + ConfigBuilder("spark.authenticate.secretBitLength") + .intConf + .createWithDefault(256) + private[spark] val NETWORK_AUTH_ENABLED = ConfigBuilder("spark.authenticate") .booleanConf diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala new file mode 100644 index 0000000000000..d15e7937b0523 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.security + +import java.io.{DataInputStream, DataOutputStream, InputStream} +import java.net.Socket +import java.nio.charset.StandardCharsets.UTF_8 + +import org.apache.spark.SparkConf +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.util.Utils + +/** + * A class that can be used to add a simple authentication protocol to socket-based communication. + * + * The protocol is simple: an auth secret is written to the socket, and the other side checks the + * secret and writes either "ok" or "err" to the output. If authentication fails, the socket is + * not expected to be valid anymore. + * + * There's no secrecy, so this relies on the sockets being either local or somehow encrypted. + */ +private[spark] class SocketAuthHelper(conf: SparkConf) { + + val secret = Utils.createSecret(conf) + + /** + * Read the auth secret from the socket and compare to the expected value. Write the reply back + * to the socket. + * + * If authentication fails, this method will close the socket. + * + * @param s The client socket. + * @throws IllegalArgumentException If authentication fails. + */ + def authClient(s: Socket): Unit = { + // Set the socket timeout while checking the auth secret. Reset it before returning. + val currentTimeout = s.getSoTimeout() + try { + s.setSoTimeout(10000) + val clientSecret = readUtf8(s) + if (secret == clientSecret) { + writeUtf8("ok", s) + } else { + writeUtf8("err", s) + JavaUtils.closeQuietly(s) + } + } finally { + s.setSoTimeout(currentTimeout) + } + } + + /** + * Authenticate with a server by writing the auth secret and checking the server's reply. + * + * If authentication fails, this method will close the socket. + * + * @param s The socket connected to the server. + * @throws IllegalArgumentException If authentication fails. + */ + def authToServer(s: Socket): Unit = { + writeUtf8(secret, s) + + val reply = readUtf8(s) + if (reply != "ok") { + JavaUtils.closeQuietly(s) + throw new IllegalArgumentException("Authentication failed.") + } + } + + protected def readUtf8(s: Socket): String = { + val din = new DataInputStream(s.getInputStream()) + val len = din.readInt() + val bytes = new Array[Byte](len) + din.readFully(bytes) + new String(bytes, UTF_8) + } + + protected def writeUtf8(str: String, s: Socket): Unit = { + val bytes = str.getBytes(UTF_8) + val dout = new DataOutputStream(s.getOutputStream()) + dout.writeInt(bytes.length) + dout.write(bytes, 0, bytes.length) + dout.flush() + } + +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index dcad1b914038f..13adaa921dc23 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import java.io._ +import java.lang.{Byte => JByte} import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo} import java.lang.reflect.InvocationTargetException import java.math.{MathContext, RoundingMode} @@ -26,11 +27,11 @@ import java.nio.ByteBuffer import java.nio.channels.{Channels, FileChannel} import java.nio.charset.StandardCharsets import java.nio.file.Files +import java.security.SecureRandom import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean import java.util.zip.GZIPInputStream -import javax.net.ssl.HttpsURLConnection import scala.annotation.tailrec import scala.collection.JavaConverters._ @@ -44,6 +45,7 @@ import scala.util.matching.Regex import _root_.io.netty.channel.unix.Errors.NativeIoException import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} +import com.google.common.hash.HashCodes import com.google.common.io.{ByteStreams, Files => GFiles} import com.google.common.net.InetAddresses import org.apache.commons.lang3.SystemUtils @@ -2704,6 +2706,15 @@ private[spark] object Utils extends Logging { def substituteAppId(opt: String, appId: String): String = { opt.replace("{{APP_ID}}", appId) } + + def createSecret(conf: SparkConf): String = { + val bits = conf.get(AUTH_SECRET_BIT_LENGTH) + val rnd = new SecureRandom() + val secretBytes = new Array[Byte](bits / JByte.SIZE) + rnd.nextBytes(secretBytes) + HashCodes.fromBytes(secretBytes).toString() + } + } private[util] object CallerContext extends Logging { diff --git a/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala new file mode 100644 index 0000000000000..e57cb701b6284 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.security + +import java.io.Closeable +import java.net._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config._ +import org.apache.spark.util.Utils + +class SocketAuthHelperSuite extends SparkFunSuite { + + private val conf = new SparkConf() + private val authHelper = new SocketAuthHelper(conf) + + test("successful auth") { + Utils.tryWithResource(new ServerThread()) { server => + Utils.tryWithResource(server.createClient()) { client => + authHelper.authToServer(client) + server.close() + server.join() + assert(server.error == null) + assert(server.authenticated) + } + } + } + + test("failed auth") { + Utils.tryWithResource(new ServerThread()) { server => + Utils.tryWithResource(server.createClient()) { client => + val badHelper = new SocketAuthHelper(new SparkConf().set(AUTH_SECRET_BIT_LENGTH, 128)) + intercept[IllegalArgumentException] { + badHelper.authToServer(client) + } + server.close() + server.join() + assert(server.error != null) + assert(!server.authenticated) + } + } + } + + private class ServerThread extends Thread with Closeable { + + private val ss = new ServerSocket() + ss.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0)) + + @volatile var error: Exception = _ + @volatile var authenticated = false + + setDaemon(true) + start() + + def createClient(): Socket = { + new Socket(InetAddress.getLoopbackAddress(), ss.getLocalPort()) + } + + override def run(): Unit = { + var clientConn: Socket = null + try { + clientConn = ss.accept() + authHelper.authClient(clientConn) + authenticated = true + } catch { + case e: Exception => + error = e + } finally { + Option(clientConn).foreach(_.close()) + } + } + + override def close(): Unit = { + try { + ss.close() + } finally { + interrupt() + } + } + + } + +} diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index f479c13f00be6..f552b81fde9f4 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -170,7 +170,7 @@ parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar -py4j-0.10.6.jar +py4j-0.10.7.jar pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index e7c4599cb5003..024b1fca717df 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -171,7 +171,7 @@ parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar -py4j-0.10.6.jar +py4j-0.10.7.jar pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 3447cd7395d95..938de7bc06663 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -189,7 +189,7 @@ parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar -py4j-0.10.6.jar +py4j-0.10.7.jar pyrolite-4.13.jar re2j-1.1.jar scala-compiler-2.11.8.jar diff --git a/dev/run-pip-tests b/dev/run-pip-tests index 1321c2be4c192..7271d1014e4ae 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -89,7 +89,7 @@ for python in "${PYTHON_EXECS[@]}"; do source "$VIRTUALENV_PATH"/bin/activate fi # Upgrade pip & friends if using virutal env - if [ ! -n "USE_CONDA" ]; then + if [ ! -n "$USE_CONDA" ]; then pip install --upgrade pip pypandoc wheel numpy fi diff --git a/python/README.md b/python/README.md index 2e0112da58b94..c020d84b01ffd 100644 --- a/python/README.md +++ b/python/README.md @@ -29,4 +29,4 @@ The Python packaging for Spark is not intended to replace all of the other use c ## Python Requirements -At its core PySpark depends on Py4J (currently version 0.10.6), but some additional sub-packages have their own extra requirements for some features (including numpy, pandas, and pyarrow). +At its core PySpark depends on Py4J (currently version 0.10.7), but some additional sub-packages have their own extra requirements for some features (including numpy, pandas, and pyarrow). diff --git a/python/docs/Makefile b/python/docs/Makefile index 09898f29950ed..b8e079483c90c 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -7,7 +7,7 @@ SPHINXBUILD ?= sphinx-build PAPER ?= BUILDDIR ?= _build -export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.6-src.zip) +export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.7-src.zip) # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) diff --git a/python/lib/py4j-0.10.6-src.zip b/python/lib/py4j-0.10.6-src.zip deleted file mode 100644 index 2f8edcc0c0b886669460642aa650a1b642367439..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 80352 zcmafZ1B@`;wq@J4ZM*wx+qP}nwr$(CZQC}!wmtuSnLBspB`;ISE+lnMcBQhk0v&QGOsvq7S?-^UssrKY{S?;d*)&wieEMdUW<4|230^y}AYwy4HW6Nl6u= zoaJ5-LR@-QR$A^vl6rDZB|J`!nwFAQGA2%Ke42Kgo=QPnnre2Al2#%n?2d5qNx#vg zLS!W4-9uaZ{$8F;zKm8{9bO_C`oE>MI*Aq63km=LhxLD@WoTezWpC%`{QpU7MO`Xx ziw&XoQw?Flz@#5B5@*v;0sWj>WS)p39@G;8D1fR(qEen@k{AUt3$zi z;{dihzth9IJM2krlN>HW3%ZU@f)zh#x}qz1I%iKBe-(PUFCT?dhT0aD|MeuiGxy4R>IH0`W_ZMf4)j+w+$g z&DpJ|DXf0EK#%-2Xw7m&$!gLmp+z{0+|j)9=N-Q-kysDt#|9;^#5;m^kYh*zh&4#| z4df7CH$H;4FfyEB7=l^*M|dV=3w8Kc42583#*=jtV%zL*ueu&U2Oo z&)ds{ur9liY+7?~5wuIlOP?FV65|eTeC@3ZEeknNdgMEe>Yr-~rnGgV#w^%xwuzI{{u3=pNbvv`QwzoNZN_kuvWq3t?@y$ z8MEXqc4>Y2dgo}^sFQb_jN``W;U}$r=9)Y3S(74`h~W{;EmYZKnY^q{u~4q{4VQf~ z79KOih$5^^(v5a|aY3h~Q(t8JFvuJMhg%~`X$_cd@_nb7q3Pt3mWyq?5X=1Cj-+d>H*(Nq>;;*OGB2~ z6}$*}Tnw0Z$!e_-f+JI8IeHLiZWRdIV|J<#VaX-l9)ZP_)z(gWnlBX7X-&o)PRc&< zKOTvBXrbPmD>-L= z`cZbtSy-iQOV||H$!Y_4TiFHsY&F^43qL=mKdxUEPMtU`d(E{95DdDFx%`kp<6>h6 z#S4|1euLb|Jb~|)-B;LQ%_k-gCC-eYUZluQgEXN2=B+z*FU4ld<*LphdGKbhOaGWD zJ*gL?wLTtS-$re89X)Dv*)w3Y=s!~ySv#&?=eaJ+n$M`ks-Q7b-^}BK#+ELLM*q4? zzks;u+q55ct(g4>p8kby<~YP%=coVxzPbPaDF22hBWnv2TW2Q|M^_Wae`CSF!0uo1 z-@xu#`y%P6HRZRq4`{)gA;nzTHCwqvc|~jGVl{f}!l+g2NTZH@q<}!kx?S`GNk;m;nclK2jw~ zLSc`rMZB#`+SWu@jR>*IxhIv>J77;TLtm|dR+R}ET4PTt10Lz#=z$QCuf$}BB%@tC zwatRA;3~_2F&aaf+Te*<7ur#?_x}v1l{r2iA|ziF@?!8en1 z0+yUUz;*&d6itxsh^nKNH%F=JwVFeo={PWkG99)Ej#hIr`S{v^z7Coe;@DNGtt95y z_)cN9NK%NtQ!>4}8q1?h5+aHZ`(SP9qpoAG zp|pO?h_NlcHeZ^L?@WKo@Xbs#0MhQ~q!?NYyOrcBY0sUy6RXp(QPdlD0?Ed>-U7(u zFB51^`qfQV+$imO-#`mYU4RnMoxcI>XkFPQ%(Bm#W;-?yH)njvkwS zb9KgUYiobmy1eW{%z_!8F=R9wi)_mqCoE8czP@cTC(LyM2ypy#DnqSW@i#M#x|%?EQ7L~}E!QUlDq)X4OqO7?x5 z9sCQHB{GA$z}DKWKFje3nu5Bv1+45-_3_2QJ*$^sIYjTKZZ~4;-}~~q*MQ$GjcOuX zC^TyH{mNsx$Nb4$j6J{;SX>b6lrm~!&7X~;@V+H~ZXsNRS;Ha*t1VXq#P^*s%|wt$ z6jU%>%TLBW&Kt)Yd~#`+FzrrSDP9@Xsh-AgS8)uRU={?oBCedNqePLlNcqXYzi$i7 z%2}`ZKc@EncP;Tu#(PgSNT<9E#W>*dMSoZQ^s-Rwp5Kd-h z-XWRydEzVVYG!ZK+8jO5_s~>wf+FX$_m29Zu+IP3U}LytElYKIfk-2acnT}62p3uw zv+t9RQ|gpS$?JU7w*C>bke}#Rn1cG=88$+ygmnt{ei-`8s9^}}LbWE5gb4F~l-KwB zCTSRa<462Kdc<>S(4K39~X>J1BqRJETL>by{2Hkqof4uyPm) z24-Q)&~AJc56Z?}XSFQG=bP#Vi7e;ErG}p2mu;LG@Dww*L zCwP;<%9KE8qz+f3AyB#6w(TMnJ|Vz?d9)(SU(op?8ead#DMy{;eYsSc;1F#x*lJ!C z2CR~rVu}{{HT-vnv;YQ6^+2Mb)9I^oz?f7b z17Iqfu6OUnRtXUWj{WcV6cbBNHwjTYWK4sh3atk=sFHQg>8KWmWO+&(HL#RK3nFJ6 z3(tj#_Yo-2Fa3HCuSEm&)CgMyfi%WHX0c?aqR$A)zb(S!ERUswrmB=4hb5?~Hkl0b;BN8)z0 zHg^Qs;|E>fZn)76JSB~DVuNgTmUfB#IqMeh z6vxngU2+c`kU4|#I`@W-iepJl3|@sN7P8zX-Pnj;akhHD2ZWwKWKU96v#dWF^cItU z%aS4PI>Gk_?AwvoAcbdU;83eb{q;5{LZ+Yez-lM#BO77aSw>%gSVP*LiSa=?`OhR6 z>Yx32=Jyj2=2ch8f%usJ%=l9_%I0$VyUwJ_TtFHV5AYsaL}fn%T zb7qzq5c42X49xd&DK~U-tqf0iQ@q`hQLK|iRzYzP!`ts*_GvTSFZ zH{BpjK;lBcWc4L{l?r4WXN~BNrG=KdD5fR;-mGYN8(J3Gr zt5B8|oMeIkWsEorBX$M%VcnsBgpR%0G_M0cSuf;tj*0v3B?XpXt zG}z7a+yG4#ehaCMCx%&G9LInwL!GSXQk3jY&)B|SKZEyM4ed6!dF|xK!SDgXrIBo9 zT$|HJa>vLydtmJHo^b?$>XSxd-4@FB}`L(=h^|tk}qSaat#V(=3aGd`WDp zy0f^*mHCrN0Ws>ML5{1X-EdZE7(sRG!86)6_!Cbtp+4iIeshH22(O&+#jA5)w4=V? zyBT#9VR+shhEWFwY1d8Ak#Cqfwy9yiYP(7%>IY9t;gjw){ukOIK1A&rr=9yx zBsm7GZiWRfdF^V7YaB!)d?i{Oem8?C2WirE z-_9scgUE^|FZ6M_n;rS}e#^bD;d4~B3jXFGye&AQSKut`<)c85S2i8nXB11FW%WU! zt|b#aI^_B|26Bohn`1G&kfxc{->1sIgOeH^uiU+YoJ(k>C+#*C->J)Afj^JXH1e#~ zTz~$kh4P+q44JH6tVhl<%&#yU@7t2mi=3FL2%5R@+;yr$s%pz{lWkhmD1D^J&+ww8 z-I6|{nD->^BDl%8SZM`#NI~_}IGeHAy!@LwVmS^6>nGe z3@cH_wsy`JZEuF&Tr7PzHLt*P{jASW1vXE8wMy5yfaF2L3{|XTd8<`hOQ7sfNA<{5k8_!YBj_7H`zuvV;uW7hTrtL_|DYLr2KQc#5WKD=BoA-g zFUPrd_}TR|9w&o8?OW(yB~dn~3p+=Hh~_l$graf_`UqbS;y>9%st6lM|RTj6uJGEp(Lhp}dmhT`Jz+OIkLg+f=*N8J$NAG`k?&t0&Rz^nd zPb|DWycx3q)@fAre`I=m-_s1n?e3(zt=;1Y`w+@PW+T39t&}UUZVfxE>mB{zw;Px4 zy;WQr6QfL#VYgzpi;5??@05ncs5S{rsao)2gM8KG&c0K~g^=9Z@yHYTqS)=T9Xs zvQi3j^>TD0jxz0zmngwhHUj{aoV@J-bSKE2Z)4UL8$>)8-%+zil)oOgez9fGUzBh! zB%(jAiHaXO9qZMibWBFJ4rd7?Gf6KY&T%)JR9RI9)RU>;2BEiy8mbh?YGxVJU&nl= zU}A!j$0kk9*4EV4?pI8_nY*}oj%m_?WQW6OC3JU?C0Xq1vgktW6;e<2clxCDd|zJ^ z#0FFM8IhF*3Wx>pY>t=Q34QyCe4QVh?R3Xit_kWO#N?1PSJk6+jBf}dA} zRCKBqRnN8Kg?>HHebxtOauS0t>aYxYoeUMsbpa19Tc0Z)Dr*FelymGA%mF$n#|wwN z=Y)g$@MQJ%^hVR0va^rXnHA+m*ds+*d_C+u-Rf4|UAumr?|g~7z_U@I^^IO!I%E!3 zGC?dBJW8|fCMcdH9X&d^U=MD)4>nz^V%q_(8n&o-)kD6-5^j$93@_nE?d~jm1#gG) z3?J~+8#$P})Wy~PLInnuDs4+!t)<2`43iKPp*L3-%q0nl&dpQD4 zVr)getD`+z5c($5Cs+&HJ$7viXN9o?zdj0fC+%UVAvxq78gZ=aIG1qh@c5g3@9%@6 zMk$G;^VXa-e&3)Hrl=>6~IE<&={1?0W&q7cV3WYMmHlZ?7wYpO$4TK0Us<66 zD!x|x=1%QiBAA$W_}eit8GG$AF-~5(S{+V-abENe)~KXw1K=B<%wv&uz+(qMwDjPz zfWOB5r8Xu}Gx|QiF0hk2SmCY^(#&GUOkV36XK9|xrqcMr1SSV$v*Y{y*0Wx6AAP+| zYQRKvJ{`IS1)7P6bL8uygFE_r=C;2Wlw-rkr-}^s<|WVDf>FdlmDCXkfHi^`{mIt- z{FdgAav&!}m4{g34Hx|2u`q!gKY=Vik;FDBCq(=$e(Y2bUq*<4Z1xoI1@_?%CX^S# zxr{KN-0kY2YxQ$c~tyfThY__XGrb&DzM|)YA(yI%w+a+YdZ#( z@9T%a?XiH8r~6tSsfJ5y=+MY_7{ptNMXx`c!s@fSz99Rd1+-zb+gYoQ-13DCSr0(swB40%H{K6nKqsta;wwM zB0BVw5CYXJE6ltzcbHf~h9aXS@;qy2AJgf!(k8&Uo{f?zQm3=AILmZECaOgR_A&T% zLUzUH&$s64jjze4Ekn_{L{5SIZY+|niknyt-37dQ&~;kVCA+YOoBg)h*yUCMcf#%R z=cv7gw^s(W_ciZ ziELhk>M-5fV}wPqe>P^Ddf8nqo-VGmwtJh(##=-2wlrO#6Yg&Q zXQFv(Dt~n=+|(;lEmvhdpk1FN^ZZ{dvov^Z*ez-0uNQx?jM^68nb(v)MP^@Mz<=TY zU)Go^*?s~87yuv<@*f59|FFh(HueVpM?3r<_~I7Tnb;);gx)hs(H*HyD#S~oU@A}b zf)HRNpnze(8Y1!x>z=I5ifij7^dnGr#!*+_EL-&W^MZ!H_tUYXIR`p|6ijqC>T>Cy zLm7swnLsy3Ev6x4U3(}0ZlsC>Mf35JK*_soBQQwYbz*V?fhIAOyLWvy#eSyBfC!k{ zLQbea)G7!S>UO&NQ25I=xrT(1Rg;O7Z9$x~lnRtl#ycCzlm^IhqG(EXD%Eq)df}-; zCJZl81QuyJ=ZFf6vX6Wp=_ZBoznPoydpHOnapf$uB&S4a^{sFV2!;KBU;^-~PH3yw z5a`hFJbR)1ed&Ze++FZVX?*Tn-e~p`c)Ue`?EwXE)cr$71Lkt4p&rZDleX#Y4rsOC z0WzMM+gA)$lMVc!Dx82C$Foic;;UvKSBAy(Z3aCh>xj)P(jz0dig#Ch=T+=%LW)!C zzKx8<(?xAi+)f9Kg)LulFs{vUj`ZUxN8qHei*lZ7u~*y|^$*40|N3H|5E9Gz*V=V5 z277Ko)JoWgMxY2GktT$~u9YPHu3-MYl9+POo~Yh?x5jdgTQCxLFc>hoq)`#SrCU>T z3CMclF?zer-eH0IRn!ZDZ(B=w%2o3qbz}_fbPw_P+0)EccJ6Nm=QRIHZuQjGI8yfc zM4w#c60a!Y%Z0ZniPg=*Wl|{OS0wzzk#gj3v^WcLmjIYGX7L$uo_MqUlIL(9>QOy+ zK>RATPvBdw4nXrC)YUyVOY|hqvA#wC#EiP-jB3%|f@$Je`mxXVdS%$9n}$YO+fr_h zadYV20oZ|_p}*tlpoeT<-5la8h;Q~UqDTQ_4$*Glzeq2{NCP2m3)>|6 zb{(!`9dazIA49j3u-VL=lssbJ=p);~<`wtDv0lzAiL&4mP?=!Pa*JneH`cqd`VJ4p zC_WW1K`AT>a`LN@R@$#Kc7l{Idn~K?YCFyUNy78KbO5@-cM16rxWoQWef`Tcp!87` zy!20fEg}E_ApX0)Of75;tSvnMBgy;^u7PW{ov1BQ1iv*sMs0jcbgg32c_NN;ohbnV z3&I}WltV2kVk@UAorbQ+qu{UGobPzjjfb6%wJ^a#axa#=zV{iJ53tldT0%9@v@DpP z?jDVwW0HxH5s?Q`)EfSJ)FyfT2~sS5`yo=~FJjAaP-ZZeOi_uUwPH;wnP2v8_2pxP zbQ*E!jAA~C4NS5LdaR5QpyXh7>7h92w}X*a`a}|=nwsR<{Y(q3f__i%zD8Kd(HOak z<;u{NN^6@chZZ1Rc1bOB(lR7_QVvklK^p&>7^IpJm3i^Nk{XJ7>K112PtB9pdy^)w zce)vs9r=!X){ZKfM3RC6$$Q+cOQIC;3n{d-eiX@NI!MM9UfB;Tl#D ze-*E8{iwgYT?^`kq8Z4!Il#coxrUN19xVZrsaui7e~=Al9zH!X-SOK{MPLHj<5XIK zT2Unjv!m0SO;NI+8~9U}IUrIz%y>m-U3Q^LKiA1z9PWZo)-AQu12XNXppo{HRz}K0*);O zph^{MDEO(Q4C(>VF7bq9CbzLrOkms^mjX9JQmva{58$@p@&nb-`rk%mnSu^^UPFuA}Jk2SHU+sCt1nZO;JHqxEK^rXX)% zVSfIzB1Ur{n3GP>ppnTyD?JO|kl6C*$ zLXb&c`d?T%#wEQV{ES5PLAKLv0*_l%uXu5T-6AlRv4TZ=S7d+IoIW<%0?Y}{(azb7 zJWfJ~)gT0#8Y}AwbR`J#Ag47e%uF*Lv1>xXBTio1?-UsV(FjN;d3gwNq08wDPq^8T zSc473#czC0?W;&Em1uQqxq}B;=wrRE4BZAN?>YuWDWph7?qCROlG=_S7c|)Zy@|I3 zE+ZCmxQK0mWr>u!stw+sc|IdN*G(4lD-R0!V*laS3FKasDe7kB)7E*|@N6%9XU z^mJ90I{Cy?h6E#%-fuD=8?<5GQlMfWi$6*KL*=_rc6;UH^|yBf_s@gdi%-_@&l9TG zaUiZ&=b*HHLT)IJ{fR~9?|<{|+%}|ns=xvO0Eq$sVElV^urzQr&@-~Lwl**@PwBM)bb!17sU4vUPClTMJd_zr(4rZClsb zCQB+r>U76UQYPG?w@ozQhfkWvpbYhEUztOh?J4Bhcqo|p{(Z&Z`n0aa4#H3Q>5z!W1)wA)Csb}b5HAFY2Fl2B(#IE zXk=@Oan0JlPYrw%Xi~aw#08;uc7U9vhiGqS+7E@}17Hm}bx)n4cdt-e#5RB z$2@GONPYgX1DM!=A#1pTZ}R&I59H@JL{mYItp|XR{`!DslA|*FWCU^1aZecY+9F+R za03z3MIHM1+VA-K1tb!#z|?@9&BD~x%_^KGe}bv1F=PXIVWO-ng88R?^ntF2Y)^?u z1ylz?R0T2z3?=G>YqLb~f$>WECj-*smfFJ$`6G>2jX>uXiWxZ+K6x@p8pLOw?v*x+ z`&ipdLl`Hsfxun5;2~ukpd|pOX}k)?#$T#fI-B*0cP=~zi_A| z7(B&mB4T5bQGKu|sxwL2lK@O70}0`{e&bH?Iegcw>|9tGtasyb<8qH>_GR4XT!h)9xO)w-enyzW0DaR9|N5DI!O z&sz~X5y`SPkLV=S`QYNBw^d;!10@$ovQ&x-O}&?b>JmjHcSDrM=y7KQQb20+Ket7( zjla{D)`Xr}f!p~Dk1Z)A_Ul%(03^>8ZzcsSMWzQrj5-C#R|N`xeq#dz%Ln94_6R?^ z6_j9iP6`0uxVi?*8SZrZN|^fZ&~;Tey<%Jw#G)96=@8j_>JUbVvkQvpjs7ABj4`23 zfVD(WqDTGJTxRJ-F8}BU{3C?$0^9#4G^DC_$C09yJ{|ufR(+L9eIrjC2^Fb+x&bPUHE(%vVSoybxsasJ|7Ai*`YCY7))oTGFazTEg3%fB7!Lr)D*RVzhKWW zsDo@C`#tSzOav%TG+AF|0^r61Bv2++6o3khwjr6= zu+t_A$(zGg2-56U)f~@v*PW2*Y58RN&dO9yhaRTSzq1jJTW&5t5@SI_&b>EYMwmZY zk(Grx;IVvcJ166XQf;x>w_O99z11AhRxaRLRck!9K2S^G!Md3Md!Vwtg_Mg}54w*Pr3p3}y#nCS12~LxFY= z4I7qvXcO7^@OpI^d}5! zw4!L#`vEczEn6u>JyACk(aGev`CJl|A)4oax*|C+Kwm_@3xq&+HCm2hO6&BVi0;se zqCLNS5>)M!shC~kCjw?pseaUxk|1Sm_2XmyoYsFNah7bcqTPiVoA5^%Og2iBza+o(ceM%ENcj z)Kn2);Ssx;nb=vM$eEZK@XME+oK@P5oVft!Zv>}2q8w(r`cvr#r;>k zYar73#j=a9Z8Bc}yl6;N(NsdoaYZ*`2Wv2Xnh&Ku1TD?{?^=I_r#CW1hRsNJ-me%_ z1i2{}$kh=11&1a35xksSEO?xeF9;N>YkD3P&$s%{F~Q*%SErd!4saA#YP|poj=!dg z+VRvRScT09M9{~FntlS6wJ>N4#nfFsE8ds>A@LJqX{G4j0|%X2aQNL$3!!{OYj`aG zWiVsY76AL99E4C5wRp!@^R_^ctyHoW8T=Ul&ixl$kvrHqib9HNQ$olJ^n-!QHqmhu z9+w}Gu2yD+2f7m-ed{m!dSPD-BZwpnze=h{WpEq;N*Zf*8g)L9x`GmbAmel5HX1C4(7tciiCf} z>&UdHzv82O;fEjjt?j7VU$NOx^p+@Z7Y{%%9m-}filb^48B{`|LLJI&8OZsd9r_W0 zO(V*fYfK$uwB`D7%QgPixX0p;mKMt}{0XD@ZVS`}@BoHU$Tu6>U74&jY!mxw!H1)Z zZtVP=xY>N11Ux+Z_EgJCcS?_t+EO`}x>N6M4zC>H#5b1`I9{M!`Rym`#W%zZtqou* zY4pR-w!iyhuY_F@2FZ{*V$9v|y0tO}QiaSea<3n+%qkpkE%_L?F{+7!@MIW4E&A5k zVcIbhCK}ju|0)Rc4)Axy}V+TtX)i;s(m4nwy$rqAxm5FA|4DXKQqtHMVXm;H$wgtbtM`w}w4es~`EAXwq_1mGEWPABXYRpB33mwlG#kd#t+< zdG-FnBF7>lx_s@`yOb-X82q^8`jIS^D8$?Miyq`e2SgQEAf{pvfiTk-c6t5zJP zu-Y$E1u46oUs^b)YeaN;RAQ3?a8oUfzr8N$sn6zWhdRJL#Q`r`C+yeQSt0BymRf0D zZySZjE5_K9y4TG!CG@_n)PDDmcSz~&_ysVda!mmiWP4KwqmFMsVlQRnLknmnj0qdS zpSNevdvjO#XWc(mG?6p#!9I{f?8``R8+2V`a5)D_s;-Nnk9=|+>tj|#)Q zM&b2!8rk?9$|5&I%Ltwc2frt$KFA#0PPUASx;2yA8T}_!=B{9NlGb%uQaK7R4)S3| zc_kD32`NNqnoKM%h8}5Z)-h=@q(X7AXR2Z=HlXd8>oU&^%9^w0NVb!nW#U|rcCs|S z3s$UC&}hH9Y-;W1@O^lwL4SGUshhfAOqY4wY7erebY|Y%-Ob2D?UGIux%F*p0GVS#1Zt;^8KweJAqgtvpq*K6X5<)C(?Jql zE7Lr*aqqXCIeKqh>jVB|vRh-}qt;E?D|zK_b!RlCh%|cKwT|zw8$()rSngJ}VxvtP z#+<+RG{xtm=bo>P>Nv{u5HG3m0o3Ry=%YF;{o+r?l36*?-qjf~=BFLwNqrFgh|NT(*6uKlx_c%~+G3 z3%&*~wd;2M(xyBIyo|l$<`0W4D;K;s6P7k!%&zvXm-FWt9b3hn-+xG)e_3mJYA4Ce z! zD$#2jTmYU@p|5=n$e{|JMjUczXuAtDjoDLq*mf5arEniAPA66Gb^NCNmA$Wn5_ zsnS=iEafNB_>NP|rq`6%b#i6F#)#{w8iWY4ue=5JEVX}+fa=zO&FDcPkR7=m(M5b8OY zfK^A&nI1HqbkrG>0q0@o%ppwfP4e1ibRfJ;dChGTA+!io%&W zf!Sq=$4REv393fQ01br3gIQdmkDAeoFGD$~JiFXyB?oZ!|5?!4M9&peF-;>D{%F}|=u)yiOsV&tPrVmH_*vI@v1 zg3x!56aSFhigdY&HmJU96J5=)VddU|9GtOc4a&JN^bG}^l`$8s5QefN6#jt*_Pdpv z;GdjKQ(n*;!DQDbSiZo%@YDMI+vCn1ixyq&KJSgbNHFr>l`!48*WZ`5Rfpe`h35+f zPEOw5OL(xm_V1?+zu?ciy`@$kFAoRSV6{2lx1Z13ouRz2I{Mk7sj8?v-xn=zH({@y zpA!bG-0Wxi+Jd1Ag}UF5!>4w?kfU=|+vMGai;JkD-MEE(zh|_!YPT>tzGgZ<-ybrp zFLy`GXF6HA&z^&ckw1(*&mvh{@O8BcyNW6UpD&sz*R|$#_j!>&l&Y2AfXz!(Du79G zp{b4Zk&{$YPWEA8o3{G7{TugY`Dufbx)d-Fn-}0b86-x3IxfNi5s_mt*30Lg`WJem zkn`&&AB|;V{a74n#fS-`-l4QTRzz-b%l2oMB9Jzu(*ZqOSrM>-0QXHJt$IluOEXRf z`+lum)TL1&%bYxlfQ=4!YdXfS_SpxiH4-FZ4t>UUC581j{0G8#pQ6dT=;Qs zk^6ov3Lol{KCaG|Wscb}-P6OS36nmd1bSUjJ4;^AAQFMemGgqDkQ zz>h-~x2;V708BJCezX=6qOxr9DU_p)ItkK%gh7s)17OTVS*M7xyD_+9o+H5Ld{{vC zs|k7G0LuF@p`vM}8Wt9?k|6-Bp)*u_NU;d07PD#(hxeldddY{1(w|s1@|&}9c(^$7 z9xDTT0w1RwJNF|we;{DRVxZ3eP={53r2nFaCt9qy<_DYjPkgk%9#zO$BF!wm0J{3i zXUs)?F}J8lRJ6VjMf8izApM~)C<>Y|xUKYK`73tz?G_I2=}%LrriQ+jmZu>;>YbUo z>L#IXuTlvC0AJn+6Rm;h?p6ubtp#OC=`59Sg5z$Q3^7pF#R2+B*&<8-H;Asn!X8oQ zxenY1gs1dYabnrnw5mAr5yU36ItbGrGnq00n;LK`?vaK>t$Sly`sMhi`61Q z*9TB<5~w*$0aKrj6itE+Ebo!QSb_%+M(_&9(#O;BfZG+61TIB_v=gWX%LnGzAj}4V zD&Qms2W#DT+t1!jA24~fkIf-2zbVD!j~K<*)sNW|do1Z4yHY*> zU6FoAh^A@V-%#nx$E&5eYro}Cc)(#AlNOQGS+uk})T7WsM#oRA)aSZ6UT9jKH~m>l zO;ff%0#FBP{8+a{Rk%fBRr=(RJV;}jKm;GMXq3_l{NL`Clof~8jBoB;(Y8h|YbzJT z{3(O9$;h6gfaKs>DNfyoBl!vUn$g;?h?WqZ2#ERg@xA;GQuzhVVqD`3 zhas51{QBX{d+|W}is2aQBZu9p5A|{^_}x_S+x5}`8qmrd@LeS4?Ujcl+#$lO_l#}?if zH1xi@yfAaJ4d}37YxP(*IfNX6N0rXEe#ACI^$mIS2(2lt3V#Ze!8mJHP_Ju}l*DDP6Hax(>OAHlcp5mMRjKRJnrQ(0%@)9j7c0 z4pI#ei7h9}19J5sRQ#8qM)VrQL2w-`$T4fKI;F;5!ohd_^pipuP2=GU#Qu>5PzBvr z*`U{`>lx;xAmzy|UAI0B`LNiN0WJlyO1b(xL6qce*a8Wn_J-m#vN!3M+WSXN4l6tW z_sIxNM1(OzulZS^hRDD87$aJLKy1ch?Jx)`diw`Q<3rOv)$4Ow8-OOT*5R0x*DTGV zCVLauCPQdY8tP>bTM%)_l2jD@0dG;M6dIMEbM^FFw1?URXAb51`@f6$kUX-P+rtQi zQ@h)V4lo*RE7NtL&?X$g3LN1Html6@gNLs{xr5-(qs${!$-2 zkKM>6y2Dr-q&_YsQMS+VT(e1;uGUS5yl-L)E*MtDEEZow)=YVipgFw05a6NpDO~#T zK=#E{;yzBUJGmGZLO6(ZfkIRF@enlQRS`_0Q?5tUs>5jlS-) zLCYLd<@M43+p6H(HNS351NF`T>lUAB0R#>}2Moz%iRVjh6`(X29A(tq%wKyMY!06;3J0DWe!gm5~4?Wt*V;Huf0Ne_Um26YPJRTkMY92ywM zKjp^F44rbke`KnU)u|dO$I3C<;@FP?hR6`VKQf$x#DZxf(eFbMzrk`;>TvO32faj1 zrBfn!0^Kq^BfW(72p1CpLvyy`v45AGYqCu-)K;KsRyl8M(jaP9@STtQA)7jyVMG^# z!8^0aQ>`px!f3)0>VYE7?YzdydI;Vo35y9S<;$joVK*`kI)}K*NwEZQkBZJ4-m_Vs z3Uy7e7TjzbHqxvof%|ISRY3Wo3SO?Efx>~w3@EB9u9Hq^ANgtamURx&a+I$t9;VztpKu; zWTZ(nlXzN#!p=z5S(*|m#LHmM?L;?&i`I|UygZK`#jmE!DNEl?chq>`Lck`oLVH6I zdI&U?;FkXR5E+_c^QoksE9j^^>=A<9r1)j1Q_?(}IN9v}2GBtOE{GOof>ahM?Im>W z>;?t9dsG!7YWGDBWrKs4)lXNzZbY7eH6o|0S`VbUZ~G(=w79BdE!QS$(*atenHC>0 zj>O|N1X2+k?HZ0sQOl_Wxop;15vmAMy3YH$F zfqIE-Wu+IPB_ib1&&%oQg~HRvVbzZ*Rw=NNL+++Q$VV9)R4F~MMjbkeS}ZC-ZYS*q zKwvnI8LZi-joj=$PA}>11zEjpaxcAp3lc9F5bh2;YYdE1n3{v!-|3d%oTEQb3-*w- z8thqu?orsqg77V%ObvBH_cmg;v3+XsmWo$NrTSrVfg|IBQ)q^JAjYv_H}O#bj~e+2 zKs0fRHl{3SU~?=-H+VLMfj>@kBrfFXR%{P3X>eK5phZS9;(cn;%UUjSW;6g z$JAmKj}vN>2kO;0OTDwvg^J-(@0D+>GofIaQ96fHT`_ZquO7qhF*2ZXgjrRjWKu&a zv^>r0A~L8-wYSX@5cg)o)wlO0;6Y*(@2)}j9PN+?|8Svn3>)_vUetn@)4?|(tPM1+ zhsV4XJl{dvn4M~8>PWlADm$~(L`B|lqe5l2%(6QZoby$X!h`@`z>T=N&Q@u)`EoU3 zJvc=D6XW5N)7-yxoP2B0o2N8?N+|EIO=N~``yCMalE9K%ZUx&YKVPLAy?%TWk($}Vc;wp+6B(R>0rYxcCvMl*dZJc<{z0^|M`q5=bmAqY*z&q78miz z2j#f&x{8m`P*#p=tM(23v`W~49KdN%lsSpeDsQNnAaI+8D}g)G%A~h>Ra5r4Hz)Bs%u)l=T$qjMf4 z2Irx???)idJ|`>m=q=@qBKT$yEdCHn5e=a6d8c;FBI1MA_EZa`#>tdJJ?A7{L60J= zle&uAm8yyul892L0t&+|w3v#KYd#YC7_{O%T2mbPr&&ySM%PMatAx)1jpVAzn47~H z1^`{sGeIU}Z%Dw8$u@RB-FS@L>=__k%`(Y94dM56{BuV-d!ex&3rcAx%U{EK&F#pS zC;ijo?VFMO>Q$fiB6n^@IvK)3=lvaTS88fR+mU1ngNwFQP2IwC4N*BaOM!1v%{!Z4+}pSg9$@_8x*lJK69Jmp8i21%c%z5bNmt*@>|v*R=C0}M^8J>PJGDu!*9ROY{*Aoa z^)6)i&|`UPox!JL`Mj4*7|ni3A4H4Axt|zV9Ss?5#Y_!XThm|LQ4z~kS^1Mu)B1*j z+x1k$9RK)uy9}wi1#V1V{0xWdTwUXF!?--rVqVOGxG{2BF5wjcE;+C}89@W=X-8N0hc}pX0n^dx& zWJRIMJWHpv11g*BZ*xw*0nuk{t__n_=BJGKFWyUb+WI@bH92cZXGgh-@fleDqYEgD zBD+J4;J&zbW69Hp>3Se8TwgrKF(T@tkqIBo1!%|IC{8^R$9&^iR+U=jw3ETMjG@F#K4ij>ouB>-80i5@aIfWgNd-A zzHspRdKg_Dzq|UnsiWISsLz_Zcg-Xwe&`_p70B)PRt6W*>B;%4{uRMC+S%LZ1P4)< zxEc)UfHer9VOv&A;;~zcW=G7sZYlM_T|8l5DPs`&IK&LMU{D%7OT>qB7IxzVE10ZV z1UpiL9=am#E;1Pd1E{ArYrXS8m_h<|bNvJ9FXej7dLxyYL^G~j*^;3ip5`&)411Tp zNGU>ji12mOOQI~zW-|zCPxW`!H9Xkol;ysAGWXt*;W=@DjQ>PHh*J@azl-DVx@PpQ z9kyd;eX>xK8}E{-o%=3nb@K*I4?E~BS-XPitb&}bc@icnzEyZN!2EduDzn7~((6rA zo-`m|3M%7PI&?`f&j73mpgbO(tBNYa7R@x7mF6?n5=KfbGof-PfFp5uNW7 z=w8OWP=>~t52v03eP+)+;eKV6>7(eDTHP%CSvPl=_QU=C{ii0i!|+)uYbG^6;=Q$m zK>fGDvzp`3Zw#*+u8!||8RYOOrxPcC@xY6Y)b!UG5^wx%+}Ete*5e$fhhp4rH=&#L zTtR!t2S?0m+BRuz3!Zegfk)|rWu|AJpyh!zFPjDYaMLEu{Q=Zef%oDk0$7pT- zn4h9;;!yP0I`lPGeK&=mW4g7{Tjz@{1)gE5p|6V6t3$CUK;)c2t5#Xaj70{4&ehbZWuXAI{eB zkp5tX_xlA-1RjE+G8`q4@KlG%A)hiN7VsF4tiU5dG?`0`j>HE3r5HTAY7%pZE*UqZ z3L(?+Ghvg}clpQ)e>ul<)Q{T@IAIOKJFP-xM_RKnN?c2@?eRN3Dxlz!B8S!iKd|SV z*lzHQ5Kz}AmMdLp(2UpfKr_xRIyLU~E+HV zQ6z_JcJbm(zjtpn5`Tb$us(jU?Wz?1Y*bsd!FImb(An2U+k5su$hH3FdsviZHx&#i z__tDP@APOec#u|M(EoRu)O9&kv^%uk8?TWuU(rCo^J<%$%=oiWAT~YRX<#}bO1aG~ z-(r(q2=Gpen4X6;0j7)&?#{Nr{`}TT(#TWl(X>7r_k`A3RCOQgvDJiC{rk7qaJ2X- z>O&VAgdZVAEkp{dAfkc-MgQI2w>GzNBnf`^ujqj`!hjcor3vrF26yOPnxbWnMN&&r z_Bb>Opb0d|wm<+z1EP68|M#nWRb{;zASKzJ-9gM)BvF<1%F4=jBC&dp@I-T>U75bt zRDrx?2WTJ>BE=~MQn^y_V2_DU?M%%fP$5WHkDQoZLiW#1!E03 z4S=~|TX#XXtDwD`vHf#oc}gxl==e~vm!aFdY@|@2G2&=^`XnXd=!leTUpl% zM}{K-JH@1$8qNXG7zVMTuM?rbsmH$YqRE5U0(9|4ebNFSZ_ry%c*@(^{FVeAj4hG; zeMbedE?^d==Q-RH7QI=yW3xomH#gO!E|J-0x`pY0#^$2aV{RSK`)vFur?GRC439T57B?| zyrG6*)x*?g=8r@0?i~JP1H{@Ohwh5oK&FI@3bK8_9o!D;nL77`d^u5P=vKkUh{0<+ z>AfD<5V#Nm&UPCJVVGP9=yI|R6cwxI04&!t?5lcV9N-YMVXiO{8T{B)=wUqq`yk{y z6a!fmBTS4a;RD&U*In=N!6~{+7&_8zzJ(G{AkK(U+#kjagaIBKFDrTQiQ0WQt((>6 z`x?yto)?x~O5q=k*4`Za&3#mSFHX5*SS4JqmDrmb3K42-2&oxa5kizi186^jEzQW} ztl_#D#7F7P#?03H-8PTC@{^rkziCBIyDI}oX4lQobK5A4@X5A9k3@~g&bBj4Ow?aU z(o|*fmtQ!S0H{0HmQaIn-!I{YOWGBANa(;YEr1ISrZ36X3>rJs=FlrUpIz2h;L%VP zBW39xwcvGh9DPX=wHsvJ?4zgfuFkgEFO**f=?gq~_$_VlJgpHr^V>(?8QDUXN2;e*Lg%r`s;bKvbxI= z#q^SAJXCMcfpiU<&rZ9Bxm(qj6xO^V?9eN(r?P+|^lu zScP%F^ZDELg4~80jMlxgS~vn(qbV6wgctr63FB(LlzVv85_!uYEikhNb5!*h4Hhu0 z3WkKwdI2{J$Ga=8YqX<~CbJPGp>kzH+tpjwCYo7OOfh)((i}9(VCh}OTNyBkP0GHn)~IOvj)d3|jW<-%!$rAFH|3-XG6&$4U^g9QjW;v%Wau#F z3a*Zgns4e!#d$!zk(A?0*K=;PiU+hbYcM@P5WWk+Q^EH_y0B0Wo#M^%t&nEkUlva= zZn)uCZVZtEPGyX4IOXL^Y5w5j=(JehS95bU{FK1ferF>l>_P^~w-Y%dd6;(Pi3+$i z3M;bVN3NeG@X*!%l4wM*ofR3j!3yB=Zj8@h{Snz>&BA=qIi2!Zam09FRC0CMK5 zP%=}7I0Vm!%o3VfN=qVk!1Xa6@g9L zgJ<^a3TB9}V;bBNoZn}4bDd(aG#hak@Or_LK)%o$>R%oXDYb6zo+Gto7Vq=M0~GgW zf7*EB2|xIssEt5@gtIKk)YUO0m`m?UZ~bNl;f#VuGNz5Ojo(f;8yLi}Ka^DfyZUdkq#fdoP%`rTr94QWPl7Z{rD2|pP-J%4 z0`Ster|{-?<5XZCn5mXc-NLX${MyH2UhHKw-`bxuO*CQA&4(#M@*(*Ypo-o-0LQG& zuUn{56eE?3BiAXEZP*T+$zU)DJD)r|DNfN^=?bb8U~Gc)0c6Ichav_UfWiS<9`0iI zKIya_3WPv{{MlaPZJ4%a>zkW9tghsRBg=5&U<=j(Dg*{}iqLJZ*UDumizmxVddLh!fF zi*Pq$i&|IBCVTj+-pEp{DcALEwfp#Kx`!TE+eNz2tGDH5F3WL+#U(tK646Q|oJsVC z42D9xr^m|R<4oM<5uN)R3@4j9QCrz9?-)&s~g3OYceHN6v!fnymw)3+F~ zOBWIh+xWOCFG<&;VF9}%nyLBh@ub51Fbr=QbOm0< z4eVfDav&$=I<&mh;NzGj+KmulICMf9jI9o~lNF-35cA?rJ1qGitFHudamsb#cgV?j zCml0g;i7gd$hTFsP+1^FY1P1P~m6A(%1rq=49F_WJZCE|X;9(iL$yjRTGBY05!d#NeqDELU z9IR?*ypdybWS?87DSzTt1lND7dpU^qyB8ykj!^f$2v`A^V^d*LNJ^3jB~;Q?I}PHZ zaOplO^<&MRSi+Wa?Uv!0t0Qmx(!$iX@#Hvg*D3pmjc;R6m zWt=I2)+!w4O@7 zcjF@NUYchMecHL=fE4gfDU<>GLN^GHvv>M-mpnO46uAn#s6sshBoJd+ew2JwY!wgO zVwJbrA`M4taS-`L`5Sz9RCHn4+=B2#O(w#a9&OXrnsR?96dn}6(9CV{qK6eDI(%$J zhyM}ir3&HJq%#|vrsl`E8NHOYkQzbm=Qx`o?o9Hjh7V`$(Y5Kj0<#y78tDp&!UZ5h zJtKU0%Ti283mXvgj!LsU|A5UodT6h_`0ahD*nL09-wbmKZvJK?1hHy39;GD z!q(9V)6S4BaERoti~sF=cY85Qore_Yg!e3w8iexQ+f3T+j~}l`5H%LE>`TY58~wA* z;mb6wKUi8|+q{6*8ME+udU|r|=udD@+L(F1DlI%Q1IyStqK;}X?|3EnCw0i7W>Ly- z%}5;%2#U(oU2na^lVc-qA3t*=IwyCZ|KP|q;8mR#n_|*37In%Hj43x+M~@j&$Xmne z=_QMBX`U+c*j1O8pb|0!-|W?uV>?`)orqCZL+(+%Cr|MYhf^AENO&RUr)gA_O+9xH zC$=`?(+d`AE zoq#JJFVG-I(A$M!__W*YA~TP;0a^D@#W_esUujOF7xNpol$LVZwAy8qXuhK&dc{!* z>HDieTL4Zy(#L%7dxt-Mj_{7G*y-6mD9ZI{pu+>@6>c!1@7V(^G2r zC4h%-l^(k9z(R6W^S?k%OGswn&_7*|!MS+i_ftXYOHFngOQLr>2r+9F!GBwI{4GQq-d5#Xb0r|{V!{yU z%l_rf#iT55m;Ep{cvtO)zu_&>-i0TowzrL9hP93Wm2wHS6a2jig|7=oP-j;*Hw&ND z$i|P4;C76XS3j+&V;pAK-Ry|Tlw>yxyQQVOYb0s(3$iJkY@Ju-)vEPrsa0V;6ScuD%`<*@nVk_c2YoDhp_ zra*PXZSEwGiZk>UupWaK)g{D@a*jngKu)vB2!!RYXki*w7{{NQK@`nNAln)Q5b_l8 zx6&u4;_x1*Pj*UV4eTNMWTiybpj!3G_q7fMW+9n@I|)%Jx?R1Yq;LW&qI(x9zQ4MiOmQ_kNO(&h72ji~S0rO7ykBw=K28r?OBz?oC;`@W^O)U%19T%5cN{QM z?4zfi_wIG=o45J?i+U-0rMF;Uc#~KWr|nRzSZG(iW!QhyScbZ~aI_xvEC z^t4l_A6`G*c@rSTH)x=4IEd4?$xtnalBz)%?u3XX_=nCY+iU=fM`K<_#U-dPS2&YF zc74lw(d~i4PNziOaPYFFc$=MD5I9xB90NrtD+y`D4}!IHlphIiE$}4fo*wKU9DaA; zypUBTi`WHOk`p?Y;`XoKy9oD%>jVR&iY<_DWG$ck6eUJXX%V>fPtnY78%` zB=c2JFw?6`fNe+?5+W{tB5otlO7r6JB3c3p{I%vYL}n17&raHs|1q!$dQlTO)c`_Y=0qw%p0y2%#r{lK0{B2#jc~-(tkDMG&$4 ziRK|C6 zE-Kx-IZY|itY`t#tUwaYjh$X;u*@(%$Q9BEM)AbF;Jp7SH0O5ueWYeM4@htADUO1G zEmQ~@re8RcKm+**(ktVZAV)^23IRGiD!`keYa911nt8z$!P1rz1I|`ZGjp@C`T(Ss zdU0soY>L(GyjZN43ov|#b9&m7gD$#oc~E!T{AUkS!9dm=tRGmU@U zwxR;?FQGNvGiO&{tXDMRC(*f)jmv*MB7H_Te4-NSgf2LE!*_j;dGX~>DhE$oQwRXY z5f~YDMf8vX!ODQcEl4m^{stOKVD2m@r0R!Gr!Gms?R?q5A(20qZnz@aCk+!=z4RJ)u6rNYH^baR?z~x$XiZ;*Bvd&yH!_EB;@6-dcC!G zcpMA^L-B3wknsA+n=nd`j0T)g)>Q?APppgo*w-cIon(UK?CG&`UdFMXtVDpEvo8{m zBSccLVfcX$8bkAn*gyugu@rOVlxLr~nUTlBzyoU?+t7CbLu8@eHmTrH0Hi|OP@HVJw295k z=M}lm^*wOx0HJy+t4)nYR4Q6Mr%JH`%+?GLagm%v{GhGTkMC|qUKZg}WbS@Z4Xy@O ztuqpLThn9kqnX}dTF$NpbhtrPv4jHk@~7D*FAFSvxeG1dORK0o8#XGzV%LUYJF(?J^pWdTeIeFRA2bU#?yuiq;G4S=K{uz$` zx?F})XGk_JHHca3ZM;y2|5M#ihgzZY^N)MgRLgISu)Mck&7p$xv`Xfh!GFOBZ|(^M@L$iKjMl7o1kg~q=3P{ z(a*Mr`de2`+VIx~tp|KoBo97iApyGC@&>zlsoYL*jOR#YArKYD7L!PaNg1f@`Y$Tc zpDz`@f7$Vr#=H|bwK0F~DUN}flG>Ca|9KRpLWaM~UFMAmOK0RXz;?htv<~<|%3pK6 z-rWrANtjsb!)SsHsuB1$)g!mCIy0_)V2H@@@YyGsw)DeB^&!*j*J=7ZFi8340kXA` z;_?8K%7}FH*eB}4NbOsA#zmPb&Rq%`Ow1uyNGBx(j)YLc6WRtMw0&T;!K>qMk0C$B zCcS}7E3F#v5a&139;l-Tt+%WtFk|e0yZ8Jci7%Q8^77h+R9c^AYFk3kWd|^%!qUPX$sF+mU=W=&KPVotfWC%0w^AQjn_8 zt5@ljGjgv4>UH@I5#N4Szap?%cDZ;u*6k)=jGoTPB{xsKAJ&96{Rz`_`V(zQ zn#kYJtS1974fd`*B&(Qc*)2uxk$kI&1;kW0$E~_sLUC|{D%?b&it0Vp&W+5AI;&|YO3cC3Y2&#tRwO|dZn9H&|# zo7Q1*i(RnnyG@-4=n8W79P)`=PODTQ!AnedeYKt%whR-X2p^cCi@fo>YX5>VJNURt z>G>gukYr5TExLqbuix}GzQy-7j=5H_hldB>Z|BC7rSg&R-L|yXPu}!?2WOuTC=0h6 z(=Z>68Y_G241*0K-Io4qe@wULJoYBE$FN*;`hEHF>CPL22yGl>#@d)ctfz0R{qc8U zI<-pU9;ljO)RJ7?!rr_+k03tQo=zO)|&-ReTqlB>V)+FIjAjm@jq} zbNXPv*mU2CxIsD0jnb_1Rqtcb1|uYALg^@gJK%e~sK;DXBRet0?n$|!Oyk^Fo{EFF zEvMoTUXx0O8tMjRvay0?Ng?fvx%tiTovYglb3ie-^4*bUDN7Zp`dJZy7Nrz0u{tkd zX>$vxMid&$kX7kv296mC%p#ZI=vzY_4cc1NVr(}q!(OS}1_Olnk^lx%$xTA5@`i+D z_nnN)WbRGojA3m%576e=as5#VMbHrpEhj9FFrU_E(SfJ2Nq|U+4cPUFBXTLIja63| z{6;s?pA$JrnjfK1E2VyN$KoArpidlGb<|{jZiediMu@kbmkU9$EfqM`%X-Q{QnO__ zsUFWSFD;Uz`lWVhXoKh!3L7>Y)al~m~*;A9UN&*r%)vmPl;GBSy=WX4nvuvPXQ-?D7i3K3u zhwNeQh@N5z9BH#63g8#zto_ZwFCd%Xl8V%^dS6N&$W0T}@ZZN6Gum_YZ&{Hx4?9A$ z1Ow%6wdNhXi>H5yoo3X1bU%hQY4f-Ut_s;N!>wpR=|D<4p6jX(x7AW1Tr)Q;9m0c^ z5$*}(By48mu1tKTe$0wiX!_wQTR z5c6vYZ;ydu16U%erfyfzXcVllCP=ot~sGcO69X`8H zJct*mth6q=qsk?SFx8_L5aV<)CU8SKyE8)N&48ny&g+}%sGd&iCN5;3DfdAypH#$2 zq+xdxL`z=cbtdgKMDpB&quR`hSKDr7eXt6`J)fj}j2GKDblDH@mUmxm4 zpkvRIE@uU}fO9inVdO6-Q{2Njm$%70US@!Y3gWFo^9JTEw719ns>f-E z7`x>>+6rbQ$0bfQzQXNA>x~RQt<1?8_I$ox8o-bcjbB!{$)ZfMDRL7UhzucU<>tK; zbreG#aFBN-vzwdlzxHW}X5!1UZLF#0`7(CKEzfng?Z|^I$kIaar-H``m@4&@>~_ZM zW;MTohpQ8mb;IogWxeT(NmG#_Lr?YEP(vuIM9#&HVYpm_YSyd`JLhmwvR1R>^-VcJ zPe{1=*ufelK^Z%($17K5RrzPb2P`}@N&$jy;SK|yV^}@!5fhtD4!OzutMT zXZ25z2!-l>KtjPu-K^@FSv~+D&ogmKA}0vX8w);Nv>DE5-ROmkPXqQU9lUyys%gI$ zc}5v#8-^h>X6Z5g$O+kck=sQti(k)XFCgH1_D>HCt2Pu*pw?cIuMe^~zojP+vZB_r zDstp$^WBMKPLlj(V7c1BZlX7X*I)Sss>$|zih!9>m>gg$Jo|p}*~D;oGqATcIii4< z1p6cIL)6R2ff%82$UzxJm&moHg!j70O_1WnMFQ~K;w==1=Ia+^dXcf*JOJ$>J0Ns$ zfAza-`=>Y3==Ib7pWlqwbI+jHWIe_NGb8K>zCY&cWmKy>GZf7modsmS{?lX~Mgr4%>fVCH8GMhHm@B%ly~fkLf_33uoSffav<3Qlb@3 zcxV}~RI%1G)1#r9h%0sLJmsJ*+^hFbTKXK2hsONPz;E*9D87)vM0Y^U83Fq?;zdL? z7Nk&_ccfVIu}Kzu?3$79(7SeeNBjg8K5H_MSB7H_4>PQ%x~=#N4w=E4BOM?ywfZmV z@Lyf_Ecz1H<%*`Iwk8m*GDeR~*))%Rk9g|@YG(NFd!y?AB3#T#wdpV`un=PvRa7SM zt|#Q*dkcrLjP^eq$1Uz+12wm|$s3tuz=#bKs)CcAaTdJP4yeF+EmQ3pRzUEcJf(Tvu z#oMY{WZe6FMl(e{hU%)c& z5rlk4M=n@Dq2)*|cLzWTK}q2O1)BJqXATz;2OQ>BRJGdR<8<9XZFs*x&NjO@2q(I+ zhc{~~UT+#FwQmW+)>f@2w8}9733bfLt@Os<+ARl%7^DE^eOP?Ip3GfKdE{(yFEJuw z^9jVT({HT0hpHS*U+Y`^dvjF2rKgq@61+;&*+pui*@6QHwGFw_KFp{A@QSo^4i2uu ztn=csDDFRM5%+OcL{hyeRR+EfHk7OJfIDbPN6N}(#h1LEy_>(4iI8&w zTC3}u8pm1EH>>%)FtX(xuhhNAIx#`UryV|8EkhtkFOyg}JiB^>yO+zV>J>qxVla?8 zvy>JsOvwTmfcZ)3{NI&E7+#bsP^Nid7Cgm=XZzvG2)S}hfHR5_3?sZ8|vS`X!Ii_mUK1UpH%`O-Rf3a~pq z=HIN#@?|}R1YA^I5@QGQ8mrD#wfg&KU(e?qgBy2B+i-Me{@90~)W2Vql-D96;TM?v z*|fm;SD7DR;Z3<3Uz>TAS0!Kv8GdG~8tPBiSF_TLSNZmz^Eq`Q$BdN}3c=tFxf)>G z?9eaA=F9HiKkLN18)b`?MK^9rpnm zroHzdL85RdJ2&x0&4EiODe)_{80hZs6@0a&H)eHHAw%!{W5@Pp2Wp5BmN*xHG_q=~ zT0)@1b@5mUb{dWut`MavPT%Oj1u4!joYrIO>#c9Lz0^z<{tGXwswV?vVp^_e<7=yu z3Al--Pman@;1OAxF|Kzzoj%=q`x_9V`C`@VzJB9BpOd;8&S!?juWi(ZgwG5c^ER$m ziI@xB!U$JPuO_9T-Lv}tRzvxo=@wE!nhN|FJm74_1a19dIS0mWc(t6b7bzz41@1%; z;UxwZ zN0x%_w=e?$Wd*kZ2YW^#i@*>QEyo8O1|X|*L{<$NP~k?jLz)@WATjpM7VtZry6a>Fy91+IbiJYAjUk!?6~4~KsRMrz_&Jm+QsK!8v4!!bfLA zx|@DLw=19V!VNb2VvuNRxb{8^iy)j`R1UF_+kq0DK7rv(DzHKG_0_duaohp1t3ml7 z_RW7Hb+!NEC!?^?4k??KJ!NzL!$ilT?nD;O21PV4uHofgI#N>zra}NySCAw~Eq`FGzbRz}*~0)%QEPCt`yNE5-`;`^#JyQYl-kBS#_gRxRw z3`*QJb6_&f`%_~&HpmGGln)aMDoTjYkmrRD`4Pd*cRs1m$ygj2tsT1|i^38$j41sY z>2m6&Jf>7|cUG@CF4w?O1BOkia!M}T1AmQESix>_3M&aMyLB@ZW0Ps(*KpJgKsp$D ztt=_0V1|U{+6C_fmqOcPlsbS7Ka@bjLShmCL<-W~L59ZltlJxCno(X3(HnJ_-5mxH zgk{hP?fM!G`wp9-lBfA!fNq3R1jT`)f4V$OU;cb{vj6Sy?ELg#??`VexD*;}8S1LM^W(xe`@oK$bKA~Q``#=5S&;DP&@H)U2$<6YIHG2p@fjT%D zE&}en6=aVXx;qZg^rog-GbzEQPhHc{=>v4$fyc&817%4%d+a(x?lMc8_j_e?E!)s8 z_vUm`^{qSNMe)m%Uw3|a`fJZG7L(X*4+wjY*!$Zx6HG+dMR7gQu3z~1IgZM^i%N1j zLg+$6u?q(OjWoVq6zj!*`nbw%cg<~rAG%a2lpNO074e`y;}Qbq+9mEdb7{neZ$V}) zzn%mm!A+JxfRn9%{0V9y(QsBxOddC^pG>$rScqny9Km~LRjQ~(9Ra8jkyX(UKWVj~ zq_8!&6n(+-tC8FBumv5FggOb<#Xv`>-?YDA4B!kzT=PTMOKy*v>FM>nS!G3rYe1d{Rc$DyxOY*c62<-%h4&%^J^*Nk2Ot4IDu^~36nh#}C_d>_bZm}? z$%&E#!pOmzftgh!4e-OndOfIy#^Z142~GZc@`A_CYa1VMI(FybW?9UF_XbbRY;Dp9 z#d{3sV7vd8B!Fr6xy~8Aq7Em_=^H}Z8AMHm=-z+_*`xvvDe-FhFH)};PMVt1v7-Bb zI^9E`n}3Cy?c$UqFss`UxK~EZlDTgZl&h zlXgV3GlQ(q`2S?HjZ8LP$ss7LK_UA|1VlVSTUP@gs=1Xq$^z@t$codwX+MW@Px{3$ z&4kc=1O6kNYs-3Z+jzD4qG!QuAY~5sNgsyocaPEAm!Kqs%=I1IF3Sb)-G-@$ z2O#@iPpTmFz?<)Ev388H+p(MSZ3SxfZfAMXNq1FmKp1yzzsb)znlH{IJ<=w^X%}<3 z9GQbz76CpHP^lxba$J1=F8^E=cjdTj>5(%=tTwM zO};m|Y_@dWLLS(PR(TH~vfoU&Z%ed!Q8F8SG6DXk_aiM8FS-(V8Yy$oqV%Lb^8cg_ZPtr4=v* z-H1CHmi#w;4`!fmX4_FD0^#=@|@ z)K}q-E#}Q?>w6m-Zq*}4iZ`#yK-_A?*4nC3NvDAz2B8b>hFjmCKUuyZdM(A0ikx!d z$$eN>oXCi(D?qFl)v_D!0YAG6+ zPj|fFp#A-|&YCW@>Lt@}`o3mAAX^%GA>tJ^j6?(3EVa!QGh{}DCsx+sy9%#FOv_y6 zj#$Hmn2m}8a3}2cdarjLKYbH{m|y%zHQF8wVRNYKDQ&{U&KD}zSWgD1LHEnkl=coj ztij`+eV%VGe$NbM^IAB>7#u7ihYJ0506`85vxr0<75SvAy<@$5-e+6Kd+)^yb`sG^ zPtssH4+x`x(@_|_q1??p#l9j;uU(<@;Y`@H`KM zp%-maN{i8nVU$9CJN;Z1l=jR3k5r+X865glB|{8K*tIW5t6o0ZVq7?G7k531h3CW< z@Ld#jbT5rVSfS7W480D(t~tAsfc;1r7u`R|6pzWCSwiAOl#(8S^gLu}Ls(wtq}&t? z83{@O!Hor}001XLr6v=rEGM>+AD;|M8Gj{=sP_$XY5qXr#>N_4rJYNQJpmKsX>cPI zo_n4MUF4CX?Fv)(s^=|FK=cCTv7R|jeG8ohD`Wz8WlK&4c|Pkgh9zg|PuojN_VCw+IJfo=$l(<9DHI`@j3eCviZLX zQkGa%4n9UPt%^P!c45U#{y{+f83>`BNYCo;o5D|ZdW4@LUfR=s?Jxnf5BEeKb^(Hp ze^Qy_{LLeG*hJpQ&K*?0gx*O_bT5X|3`%p2%@xMq{00x0K5rv|wmOHfphBXoVbEk$ zQ2m4kw&3TwQf2C_BXTA~z#ewFqvw-Po2G#vFs>WQGT0=P7@u)jj+Otiz_=z%lr*Uf z2MqTUs*g7!J!r|H^z@ZL8M(1W>@$Otk_3_%0v?2YV`vh@NB08PV+#udVtB^N`vj{B zM!V>(wpm{QQ^bMglX|>zkSv0YkkS|`GwK&3;G*UC9{j;ppuqice`1~3sPEQ}#1_7U zMj#m>i5P91al`uU$Djh=Q06nN+S%qJitF7ie)-k^^Q`)LMIz}Fcimxy_8G!B4#O3U zt07lS_pJnb6DYgL7$bp3uqlxPyUvSYo0pS`-^+Hg**_`QuaR6KHG)Bey#tVJTa>O_ zwr$(CZQHhX*|zPfUADc;wr$(SF5NnPUq|25FFHCSXGTUw##)&fbB#H2{A2!KimBnO z17FP?LId~Rcuw0#l_SGMWwXGm7)2e_3v(SQ%qI=EqFT)-q*jk_6lw6NFQ$? zOkBpCzx}pobBEjUI8+4Wm?_f8(rPo1G9;-N89^G^RHoE~-L<#M8*Mg7uGB(=x?nRyOr9345_k0gBiQh|0ZW2R3k*Y9MG; zGNBwOLs11J*?_Xb5+?Y7*Z**Ff4sH38DaP^D)pzQQ|{T@ znC%E{9t{%0#>Mw<`Tuk*wAj|tBtqBz?=bLYZ0jFW`TGuEQ2F*68{5=>WAx=$l%C)-mVixa)jnWPhW3|D1+Xf;&SYd-D!U-R_YYk5V zf{b-{{6lS|dFKI%&bU;XDGBQ_Z>T^`g0q0xWJ{(d!IlYv5&==k1*7TVv+$E6-z~Bi z7Ghr{^S;Pp0#^iR)>)bpAq~OoMwSvn&4$$-F_zEhiqV!aUhT2Qqil&$rq6gmo4#QR z;;~GVK6SE1Op!^#kZv`vc(%Qd<;59LR;hlHRtIH}b^^^@pq}X8(LQ&rEZYD(sVeF+ z#LI-f5&_Mx@3AF-x}lT(BUQp0*W6`czT*~f+)3%t0ZVHhVLZIN+|5;VTfpT_(sY(k ziNQ2n@zhYq`GRI^yNnks5#cA*FK>LNx#EEqX+@lCBLTmw($T>jh`@kkxqIKB^GC7XnyA%asy zcgbbiz5%rJJLCwdXz(1x7LxIDT{)=ufR(_b72lGK_Xe&3)+2tS8Z^_&_cC(b%XcGw z0j33>w3uqvde(Yv%idJ8V@!>KdU(2ioxJJKd~tj$3xN?sdK3uwjp+u>JsGg1_Wk!)6X@A~XfItV`jb_P=AiHt-neV^0fD-+&_q^%M;P$za zldXentX{)8yaA?t@w`Pt>~{UbQ>(Y5p?>XhRyTG;xRqT=TH1iEy9=k?$~5Hi zJt%8ifmU}9v*d^KX{POf<#rzX_h@>t#dCoYr&9yK(EM75*G@3IugV+F_M(JcF>bL% ztaTvK+D1D`@B8<5RN??!H1p-uxNs1u-R9fGU5ZwO|M+UhFMB%OpHg~@UK}{-2iziy zTWQx@+u{NkM;9>F=C-}`+ltN_GNx#sd0Rlb?Vp9Sbmea&?UgRhGO23s5ayld_iLL_ zdGN)B4_7Xr2s365D7r6b-A8t&Ab{lecOGgs_xm<~<#-Olb|@4g^}qnN3Q?=mzx>4T zIQHZ8M3L7|Bkb(Fue^v+?wQDK*!3lKKqCazc@)n45kwF=8A~?-RD%H!&FJk<)DPH+YukiXt6=Ke&{YWug?cH&?$6csDEs#a4lHx5 zg6q<*9h8D>5Q#e5p$Z@ZvHyDRc9#S46VCc}q5Ra{&H zj8+4q8m>*JTQhe6GF*j8oD^)Ykz7>5Ji+aG1ee8D(`3DwHRJV}CjJC`sNdjVj&PEu zV|EehBZ&zC^}=!7>F@wMo?clz^APT^*yvu-lOaKG6&n=J?Bu2VtNr+b3zXDA)j3;* zO7u!twyWsNJns4O$QxNBN;x1Lot9Dy3N=wnnD!MqW8}MIPT6g zdFzLiuub~I`XP30rRvlIX?48V1S+XM*WoN_A&Hr-qSuA0c7&_=0PjKB2R(W4HH~fa z7-*{AS?-!+YXx9di_yXCew=MYjk=#Kw~4rs-MsWniL8QC7WVjlGo-GwNUoI>6;djy z^}?}B`YaTCwYTkheVj+fOlSI*uswRCg_<6x~dNaRB9mro#xsh=8JeMO6~Zs=zm1ONWm0saUBEy3%jK)`@G#eFTdLWsBTj zXmvaSFVU^sL80!_%aR!PbUC#nvmj2k=niC=Ru;O*V;Tm-kWGD-lH%7F*Z%S%Pzmjw zw&=KIYl{4{$Wg&<4M&#yx?z1UD$6xw z4ex--$tBQsHT;b`XjQdaD^aQiTSI0`te~O(wl7Y(BqBYBh;<-jTclj35mG;qKhc;b zEsqLdb0v}lFO-vR%y$4#?59I;y|b*4!rKymyf#He#G0&w{4 zNfaMZZ3HS3B49!Ryvf3PyAfj7M9bjE0KRuQkv}wm6a=9qwnOb5 z{fhJDJK+wcq+!%XM?Vfi=`69!P#V;DLMysU?I|8I7wOTQ4)u`089XHFKgVhGE0pLP3B|ZZIqI%%x7E8|-mJHb@Bs&rhVG%OV#gjTHUZyJkfsv*pnz@dTF04T;~qj@pWyaZh6JlrH~59riQ3`Xy~ zQ~^6DXB4T^;VxQ4wBYGWCbI&BA%Ka@&Azw8``R`%j#KSV{iU_ELf&R(qM$;Z`3qP| zH=9s9a}BBz-}9aUXQqR2Dezz#!U1v}NgMHK{c093)U_!m`{ zO2|&S0E;n=;P+LbN%Y)R1ZM9<F^r#I2CPPdY>`-T=2z;RW;X!^G{*;{ z;&1e_{g^&Q&6t1}`iV$dg)DHW=$Oww{%n!MQKS?VqNYTqrwlgioZHY5JnBjqEj zVr>FTmy0tCtobDlgDFRV@woxk5~(EoM=_!Oe*9R(_|h;D;cfw1t-3`u&UZ8+jtPN? zIVLvG6dMq=*(7RsjTg&a*vVBMJ{Nz#(g@=9RSNshUGlabAF~Jyn0V7ykASL> zB13n=kepo3Ujk_M*_nxh$4Hh4@3@j{1gI;_TdhL^DL4_ zLzX@KpJhuv)_UAh2q=-}X3@P$;CAi)J^b$$0B=4Z7A|OScy2z)B0Kz7zEyf^jpE># zC_XkLlXXVp1ID1Rd1pgF2wUSv=9PrMNFVdc#Tg{M{eJJKQRdymqvDUo+n4V2rzp=>Ok<2{MMVB@$7E@r)r3%VzJofZ+je+1%JPA|3@RXnWRppCjks^*ip z{Z06g?n0}D1<_i?bse1idhIP4@@fW>9cw7zgv}yJ+%PHY!&>Y7DD|#JK@eHy6OV^c zY6-}jP@-QuQAltX<&x5N1Z-Yb)TCN6FEk>b=>hs9Le00|A0C_Fh)gDJ(OYRT)YP;h zLI^rHS39Wgjzm{TCvW>2fW-v#C0cU0h;-MNDf1BiqhEZ+Uj`t??xG+|$p5bETt5*! zy>&Z%^;K^=)0*->`868N*r0JcQ|*9EL7QA#(Kr;S?vSXemXVxBb(NTV>iMU2pxnYx zK4ZZom!dm>GO1|$tT_gx*g0%)AsijJ zLTFOY8WRf7C@>b7KBJF8Ut8Wg59OpoK!C#>;9a{n+RD-Poe;I+H(iF*iZe^{2{Zv# zs79`p7R^$ecz%7J%3u4tKepM2sMG!Qag3<5C(wq$D4W_Y>BZDp?3-Re(v5Y~|Zk3j> zIfGX12`On=R_&yJD9)(e;V8gK`IeByh4(A%3b{!(=ASiE)7Mg6qOjajU({$Lh`=w< zZbc%2he#jC3vv^j^hpj^hl=2ULr?N&JoTuJ_po%Bw@eyN4)ql;WXZow-f=;YN#UJ6 z>Z-3OsM>m}1`tjixG*YS)z|(eh|%=O&(G|eL>?cIM-B0H=~d?5n8&W{Imi80L4P=b zxAA`kq~qWhMmD|sKClWug2;AT5(2;1c3B(Rr;zlD(K13$Q`eFCP~qJv8S~MB4*B>z zLvqZD^P$!LrA1VOsT8%Fk3mDS0*W6X=x_&;jc=WQPvQ46E{kEf(rI|^#{h8mF@rH< z3gy2g7>-}RLTY0w2*Qw??~0Ihk9mEBD_U7ix^9ezXqpAbIJA-j5EDuUR@j`WAS=TG zO8`jsVJ$Ri>u*S>GKw;hi(h@~w0LY!Dt|6a?c`u{@4&P4dsII9R@dR^KwVeffVZ4X zme}lL=*9HufOEUPT?*v_9hC~s7fGoz)N)Bx!RPY?b?Ptt%)%fmJ$HN}T77=~1(L;; zjd;tiOiFIBk(MX@|H38w!@Ly!gKygm3IG6s2mk>8pM?h}OYXRjbR!ZLlG9 zU8p0BiE43d<9SlRyv`EI)(d0;$T~s+(bFojH6oQKs<3^%_y{MUm~xMyRnw(PPOoIK z56n3|?;nF!x)3iJ4)AsQJ?*1wQReb`J(t=I%Q3he#=?sECB)Q4GgH3p%1+ubaFYI1S`sDJ&UFzLypB_AG16{LTHjFGqifzhxA2I0I=5SG>N^)P;rN^ zGp;Ux@wcE~mt|Ql;KuA8I&b<`UxXR0W6SaT>8-SfO}AOMafKI9V9hn;@`vZxjJvA1 zG*e>Wt&u&wG4YDEAqFH18gy-F&NX-FHXA)4eHxTgF-2J9Q*3woJiK3DlWveL*f$}q zQk=(XPi%Y4N_BllI|Nzg-D0;kz=b!Q70rlFq1xa;2nD5SNdQ=AVhT2+9#j~{Gy&xC zdAjUJDNKXCR;LdEa9n*`W#EC?iVB;1*2tMdruNq`8E-bc_7l;8%;dnB6lTIjNPwFK1xIEi>GIHSDJToJ!@9eO6tQ>38Psdx#Wc+D?>oFA|xN^taE|gNyo|P3QPP6g;PW5mg*H;2NHM!sbM1*13 z#e>sje~5|lVthi<%t`J=urHaNdV&`dY!ZM1AdV(ORsi-k`CN=Rr{K9P&v;D~c&8sDvpH8ur#x1P+x^QO#l?Y# zhF(5yQ$hUlWsPk6X+}D-S7@#BjtE)KA{|Y_4m)5l)>H0#QW^UnmO_a84$)-J3gtLJ zZ|JmcSW+);gbhcYvi$=uZdY6#fwNbwb==^>5?*!bdnPnFR`QV({^roBe&ULxvLWf4 zGC&$=S4zP+8z;h+EirKVI8WEi)rbRPE^4|Nx~ZDEctj~puq`L2vWh=GdCFfWaQB4G zf61TYOJ{R7c(w?{hd)IRcgljuA7aTu#`Ma*%(`a%#;T@R;05Y2yFMCf^2(QpBdIpm z0QFBzYZauVZ&*#H3Y&=3Z{&hj>$;Y=zK4$~vfhC~lZ53<+JhTv2Oao*QgRfI5}j3D zB&Es!5<0Qxe4aM@{C!<0wFXTr@>RfzIjym|?o_B&o9L#qoYf`^F z9!{F2yEw#kwO#IrSTZ2Y;^o5MKmB9m&>40qW<}1#@ zaF>vLA0uPPm+t(ygGR-^?!If^J7!yOq_I8Sj+?sTjH|W5BXZj|< z8ThqdMr>RmQW{I#M*&?vRoP)9TWJH@Om~Ti{gYZy{lr;LY0bp@`u?&%al&noY~KN+ zl}GkNqslvQeQu5y4!Iqe`Ba_);+oJQ@H)ZB}e9Iwk2U!u_W+ z@K$U1{!k`Xu8Yh!J%_!wk&}z=yX+z52u^_vSk*O)?`e3ZbEDxERI*dkrig!KV{7Os zce8EMj>98WcTH#WG3{mOSl=-6^!UORQn$Tg9GV!!n9G}U&ZJ<@F7l}&62UE=`U~)% z=nn<(&+Q*r&VM}l{~n<=w%AD}W`d@AsI4P7aobaM3kGt=GRNegRhOZ*nm+W~T; z5lW{bg9tboPM%%M)CsRfzVgCbyqC@hJczn^1O|``g1gk?=)8{-?qY_xsBsDVIuN~Y zZnjv*r_N+h+5m1eYyA)i%BM;sO{rNY?9O1?!de89E@a3>>Wc-$_5S<@Cd5OqjY5B7 z%e82UL8x`YOmW<}wW}3(wh~My4|~^yvmLP(M5ueapJe74$j6XY20h(+U8liFY$_tm z_JV;{=88v?j+sZXBt#25B{(lYRbk@8aIt@nx7B>{Rb!OaeV|t$!sSe%jMEkA+o$I9 z4_hqOo5+s>i&7<_l#z0gW3yFCBN%Kr3KJKBG>C(W22v$M#1NGT5fK~zxUtEUX0#mc z#ST=1Izw>E6Lgalv31wV)Rs_loM%neg&Ot6R5%+JZnH?2i=Y!}7U&#z?b@>&%7XjC zZf%t?ts69MuaKVt#=OcZujVQ`A?zfo1TE)IvK5@LK&mq)!21Hp79c{~NMH#pMAW@e zDByxiMKy^Mr_juGzZkzFj~|YlmWH7V4U4@0@!cVMiN9NIMw>v8igPN{@~=x63MPe< zHC13_z*sc8z5c+=o10~U7k_L;RIDi@Zla3<<8t*X)xSeS<7H9sK!%;Kh>C4(R_|Gr zNtc9l;RhBXi*_;#r0LvyAHi6uHfUGN9Cjb*wR7p-k^>%?h{HTUOaVP3k+nq1sK_;b z?+{q?VwG9SA=ORS0Z{MWm^twRZRlCCP4c83|8v`Jmlf6?kGbh_&k;e!g67Ce2&s9T zQxEwy&d|kt0OEt1D4vTgh8j{L6%>FS2FP6xemPVEVgHb2Ieb^2Y~s_G>^dHC8;V_P zYr065JvRAAoWuQmxW%iQr`%>I;Nw7=79>;!HU1X{1ZP5 zUb(FJ?%0F@i_KX}RVOBhEe2Tqz$a_?olQ?xM*Hw;KghKncw0`w#~Ve>0r2pxa`nie z`7o!K1w3uvs!vCFT)0P=e|^KPXIl@*wzT_xL1VFBD|JY>^<8RWgUQ>yOR5+{>612- z_i)<`d#dU*azp2){kV0oZlVY?n)C!65yL;L;=uOK((a1Pb&JGp@5aWz_PH6`EbRMS zqjTwf8a+hsq4VZ;$EK5;x4U!Qi@yDtzQMwC1N%<8(Gh}DO1_#2H=>jTUo_ogiZMwF zA&)WH0U^DQ1fp}mC=b`rIf?WwOLGTgQy7K*SFJ_}dGz@ZaWJ^FB$3?j_&@xM=tOZ= zs)d1nclJ*{_rmy%7vYsz*jf2^ZEC`Lt_RKof_8{g?K42gIa+`eJ|8;j|Q)6<@!D@Dc2=)jL;^`Mm zDM6H5acyx4YnhwnQR^P96?PvR5~0KIFf(y7yBRJ6XBIAZ<8E)X^x~v&{Q~}X{4B(# zkohNktNg@Jynn}NBM)a&eJ5u}OFMJ@ek2=7C&qJ#uAmxhr>bC@d#QGECzrk}*pL{2SEl{F1#qZY3|3He3~QM@Ey zNj3y8h*$U?iy<3YLS?KL{T)-9NSF>u_d$-yozxrp)4&M7268(Cwa~&AOck+Pj+RcUsGtp>P?#+>QLXewa*x|T>3GOcjNKSXYd0jAKem2}i%>#a$05l&I1wf6Rp^i#DhcCmsDy8?qISqs_TEtM^A@Nohp%>=D z^SVOS!==qVC5E1^EG&9Z7Geuo?)*aW!^k{dgx^k&fiGEHQHrJG+=iOoTR6Eg8^ZIt z5UWHZxlHes;{_haxec~l$gta!mJXSVA1nmR_3lU3zuVIo+WNymer?1zp4H7Oxk;PB zktCI|R-EcyKSl}7-CO&+5n5i0T6<zLTk=tEuC^rqlm` z+q>eicG}=b-1(%gyI0L|uB;;=Z*)uHwa*#(g>;g*ifV1f#G0g}h(be{4?zd0HK)7f zw`0d50YpNoD3_hlg{6vz1i^y!c@5TU{Cv8DS1&Ow+TcF!o6h6+h`;LSoHd5ZLZ0cj zT%XQZFJ)R1m8zN^44sR+wk#_0sqRrFnmHCu4SE|XxU#79sP;zlr2WWrKxm}Eau_k2 zN$n%zurQSrCx`-_YcbTwor@a(E)K(Yay%B}!DLd~-*L+|*^j!yqb%&&H7A{+LJJIj zEPndq`;}QcAPe{a`pJ_f*-uPsa!&-`_!bYziW)I*h;&!piQAw*wb|0jz1J9ZV$R>0 zD^pv$Et{>+FR!Q9m7~|w;rZR@ahL~qILnmyOR{b%NgZ92tTF#BlDdI^iE8rEc{;#w z6VSg&sVS2p`al!IC(}q9@b;cG&e-FR)WNf0D#$iQ!PV1@yq>In&Zj$D7hhH)gfB7LzZ^nC1y4uI1!(12P1+Qn%-;b?V6Vq;S$`5sI zd)7gkVY#9}YyrAsWa3tHOROjg5vxbOSB(nNDyxd2$cc}a$cLWC3P8Iw>_H!)vKtK| z$gLZ5Qxh0CA$P=2ItSlh-M|-WP+^u_%@EHM9$1w;K6t~C@EaoYd=)XEo_Wjj!qSn( z&tIQj&zF&1U72+!>Nk2iGOLG?pD#xjD>nD{fdRVQJ#Bvs(`Ke-#)hUqJoPuzzu5V+ zf9vb;F#C&C+VRjTN~~G#G5npi?|V86XlG%lbEo;KVHoE60Br%-kGas#!p>3WQEvlo z>(MOqI$ti(8<$_bwBxo;VXk*%!!Zm++Xpp8?*CD2 zhEIoU_QaP=9N+9z@B{P)s#u~|Y60-e2_ti$cngNNlhP~a^78812BEZfMo9#Q zD-SR_n6ntP&_ooG#A22*f?!kzCG{U8LtS-ZX5rC52nFph8D0how>{Kh6ui`rc*CKf zw3ntxx2~5mWTXZgY1E6fkOY9iB9MVRNRQ1gcPpXZ6CpJvurzVX=%ryP1-fckv)ird z{^@Bq+2-vxE@ zLHWKj_x%*8olz~c#eRYD>h3u3Qc<=`}s9*v#!eliGVr_&p1fw`V88S^LOrzDx z&Zvu#RCYp>2)C}1i=IM6(_H8QDlDJS10u5B&NN1s0+c}8D6AG62S)2w0ajRDEw-F9 zT{v$7K8D;T^oHbRsC4NRuhHqS|IR&*QGnRHQ8BYNkYa|(J+S2wHeRVI!=RX%T$|Za zaoKJ8YvMErnQ|$C^KE=|u^^l8kd;wh#~bWD)uF)@*6ZZH59Y!^yG)?`nc}qu@{^r> z;2_q7!_2(Q*Ch($SS!%H@UfuWpEkWaCFk_DO>89AZl1$y$MTJ0n(g#iNa8SPTivBV zH24F@n?x&qI)Q0T+A>B)v;(^y=~sHbA^1eNV8K0EL3$g=T{6I*0GBOR=|(_Pp-J88 z3P-%*+y#S+dfCnMRa05Gbws@X=|lCL`R9 zQt*+@Nqd|zM%=rHZOBSZoD)tCYw;|rz53lE1Rc%^JHldebwq+a631O4Uabc%th)q&w^RefcCaUAaAWG0zy}YW%j@NM?4}Y>QQIC(3!?~7nFZP2&W1(}BU-HNF*G7R?^@Tynje#h!%M;- zHtgK=`EMbJtUh?}DYgZbND*-8nw|Tn)S{1c45N96=Ww6FI?NY|35sMnxYrNNgIMhT zW|nnZ7%b>V`QTRSz*d~yJ@?i1DV$>{B*GvI%)s;^DtB;3s-OG^dupbD*yL@GaGo<~w^nKVv~(8~yi--b8e5 zx`b-CkrSryV3mU`0jcny+@3>@a3VW;xdNgUZ(N8Sr56iIsmm}B@wY&irn>(hL3~`N zIZFnNP4S!UIYPIL50EDQiDvI#UmMhfsgb2o;c!YV6h$wa`Yuo^@^=vZUm&1y*ZRUM zzQ{m$*VJ(@H=OGMcT$X%ylY|n7g7%RgQJ@_FaEX*H@8|Z5r5;>4$s$6y@Ey-C;OPW zw1dT`$GwmcR03Iw2(MeFv;`Dh53 z2Yu}dqT@L?a+D;lGBAy(WUvmAW%N=lGrfzPQvLzDc|FM5Wp9$u?GL0)Nzm;J;KqaJ z?B6{{TIA~xE=3ICt^QlOUB{CaM44!LNj%$!#}U)~?*IMQnhCFKt#j?;Y~`@F%gkUl z{jq4Fimmq1QZ<)zy6(qvQIN`X6Na1$mcg)=k~aJsFxOwZ1cEvh0eOaEj5apLLt71` z3`f#s1&m(~FPM=0-QwXD+9GH~@v*&ZThIKS+0K$;T$w<8&#hmC^J(227N&WD0ye#| ziN`MaEaR}F{|gJk!#uUKo*@1hClqh z_&d-}s||+x>u_8sL~TjNe<@Vh6Oh%2 zTrY&Q|NOj~D%Y-r_zyoi!S?lYdJH$$*Gi)CK^yFfVARoVg04tFh&KC8OXCU#mZmQP z)H2##%O-NIjg8YlMZ?8WF2=7vH3DrpF>N=?=2+ZCHCqZyS*I8Gv-o9HM&fBeRF78i zsYdlIY=ARdN^i!)=tPqK7=3abdSlBwmr^&zUD$5p4DVf<`&(#(xz`;n$9-X&T2@1r zj{uQZvJ8CAfzRAJQv$UCA^de;#{gc)lP@l(3t}$n?2bOVjo)qCewR}1IlA3}<@)HE=*AW?trdlepzK0sG8HOxMFSCHRFTaBBk>iK;;{k(0dlTJ~3 zwJbIVgH~$uSD0Kke_gi4wknAF3AD>VMbSo3Hv8WJ9V`@m-Z6HB%BTb?GQTHp9Jn%I zq_>$_fCS+1!95~E?L9MT?kqaWX6Po5(Zp=5yf{&O9-Ppg0rS^nbDg@ zmL=^GRE2v!Z8u|eIf9`eon~9*m3USF3E9wvwzpqNS65&%3=Bj9xx9y7tjk7`^**;d zyp>a|uxa|7?9*f*%l1Ycpt9sQ`4a@l($b6uAK`mnxy6Cp^s1C$4G2TE+ zNyrlVmO2^D?kY^<2(JS|%%Qn!pu=o%84h)M^usKx4>r3KxAl?hIvS)g)*GUOq)n+n z&3pzZ0Dz?^xM!D}-h*trkrtwL?7x(0r>4>!0FAa0#~TL`yuu-f&kH3dtDY%OiV7?c zKh3bH1R1)7E`^k@Rc_Y^D`#t%Dm&M%PzMY0 zw>?$;N^`|7VY`5zbrYVbFCV!D(ONtvuMp3aNxa<@D72}QqF>>uomQ|RBmWPI*+*fy z(8@wex&EM8q!t0X0dy34k^=8R{n$d1-_V9vI^ViFt*z1~#7@c7s3;T+6_QW+s#84a^N7j;cM!D_JGQC>;b zMciX1#(N?M*QZ>qAmp0$3s~ot(MXiSJJDo-IJ-O;R+Dme7WRlm0L34%iAFs;DJ?Kt zinfmmT&4U!M!|5QnSTz{jlh!M^1&u zYzFIy4GRMY^?q_fOup?Lfq+5O@{OcCqt$CSe<-ARj@pU`~xWfgW9G+nu={2OoN8NcV{3ROX~CKJ7x0#Sn(e&Tb#Fa~Tbe z^M?u)iO;hZqfS8NP4$WcZ}*2ggGNyVS;q{S#HM~Afkjs@jK*fFJP7P+!K*f#u23cy zb8YHw)5hxWyyd$VGViQ0(mETm=Uld*CavLuY zJ?sv(KG}Qkh&;)XsKS)-P-`!)#LQ3k17|THJfhUj4KT?D(`M z(IJXW0){DQZ>@urTFttxL(NbzzyF))LSs|Y`~GvG^`im+VEy|-v$XqBTwH8_YUQRT z|BK+Vr6KL~(`MEEpl+}Uonq*5KE}*P(lSpaOC!IJGY<_$l*VM6XaF!8`Q;-3Kmv%6 zm?9<1y%cTAgVwRb7ogf_5yPaV8Xu9!hLhqVogu*Ky(M5mrcEFZ6Iu4VwPOcwAyrZE zNK%AD>7eY?y}YZ6DjyMNmANk07>Drfj~e&y#c5VsELyf-%gB$hJjh*%c^x_Kc`@4= znoM}4uzJXV+6@mT@fBJLmt)7z9b_``n2(!NkFj%T>8qrX0ceyf-B;HX)x*MU2E-;h zjeU^2fu&mr=Aj78qioS>GyZ$nnR3J}c9gdEi+%oL~My~tw z<=0Ic$g=J`xS3{CNXSXW;I+ax1t(yU0QfG|%amTQsnr|CkXuJOk5cQJMZ!-o_sTwFK(blIxXzY2W>p zBcu#4lP4$wz?_x^lYChQQo^6CJ~L*H*EwcV+7vRnziZu9kF|l#_$t?)Lb0^8g zwv!M32-C0D4DKB%3;&Y0c2*0yw4-v>s$A&ZsPffA!%gEc#FsPHcS%mF9P31^dP>6& z-LrW>8f{-o&XE~*l_hb-BdagHtf%B~ox9`UVLiHB{Qib{RSbu$HtIB*~4icv0a^quEp+LR< z-SSPY)_nL84Yyc6txWrY_pn+2sejJFTkYGn^*Hlw?=^Mv?{youJwPi02mqiA0ssK_ z-{Hi{(A7}i(9zM*<6oNF|FCe!sL9%|3m|m8t5Mv}54N}D}oocILJ<2bs^ z@C7k}qon7oLOoLuF6GKk=8Q8Bztyl53r%ZgqW)p3Na)F<@Qj~_`5yQ8ofEe+M zM+M?U0i{>z7OWPYd0+!-GK-VyJ7Z@IT9&uzp%pZC*N?FIHNuj~FPf)#nK>Yp(>)C8 zm2QU|jJrw}*DOF2sOgRUEV@1jZ>K@z_6ZxPbtKMd7Z94?pn^n;fgbym&!Ro?o9*y% zz9e(JP7`?+%(Yn&%Vl3QR|@%N?ql#k8P+*7U?~f^ajvdcU#!?s-wlf(iSIs(MOfp= zKNp|s;jtA2k%ov&0nDZF7U*^ukjya23+0R+4VvGolb`$ox}X)7#v`{DcPv)Q!yDih zoDotsr2`s`=FAxFdYD#%=l(n!tNH2uheJSKezI^!H|2F{fIfyZ3Zs zUnf&*uPHO+G+`xKQvVDT4m#?iQ>$DnoYgUM4g;P8+5X*cI6|$nqCRNx7RrB?kl>n1 z7t^$6yfQ@kdgb-{V1o;nJ4i*ZKlqxos#bgkxbj5!G48B0x9yy>&0ZnPY|^CHvi_;QGi~{w(+V|H(^pDh{NG zWFt}5e|pnyIsSjL@_)>4V?!Goqo3}ue@*dxLxq%R6C;0(6j_yrqgib;VG)UK zl<^o3KOo?braYjC@lNv`g!GoffJ>z7j|v8(B##Z`P~g_Z`gI=x{zbHt!3+5d#3Vfs zYQg`Vv5-y@iNAeZkbW^K=+ZoqjtQeVs#btPKIr!Z_!`jwou^xvNr-oeRZueylDw88>KC$LZaht+*?@Kg*M?)`+ZmmxpA9%sGSo!FeHyAM%xQg_L0lN86eZLwQs0+MSY zH&K651|1Qd>xD8W2ZuV~71thi3_iqOfN&%@Jg~nbhByGQTGl~y+sxwC0`9LKZv~^GjkJeWkQ7{ zh_?G3XmZ8@30ho@sJ!3u)78FsQL}X*#I)Mjp141hI&(+jgE>qwgPd+3T;*b)_YQSm z&&CWeuBVnLNw`{qS(59UBS>*Z2hc8kHMp72Wa{5u3H*T_`9Gu!5cBC z)?jgLqCMa;EaxWlN&)uG46f^0w5Sy^E&=(E2m}daSGY`ea zsIki91#-OgWjZ4_1s8Pfb54pS#uS0bW5wn9!X&X262HQMo=gSo@mVczPfPc1CBI7x zb(^aDwq*(56l*B=Ig7WuJZ+p4r}+qP}n z&aCv4wry0}wr$(isonjh-_v_vSO15I6>CO}Id0N>8}KhsY?ln07Ro>Gk}5YAR6hY+ zvdN-9ChP@oF_;X=W#zB3*pbQQ?Od{zfS57TG?L6_x~Tkvd6i1Gml#^IAEuYlKfNij zBOyo}zyhC)7^m<&IgKjP{Q<=pWFbR7hSiB#y%^k_az7 zmwC7ChL^!4=#&%Ve9s(VRp*Wdy<+N>P7N$QFSBE(c`>)1Xb9g|7pB_8Fqe>&=%>Qz zVnee~sUXTC0O@Z6^rb<(mSey!K6i1A;2(!>DD`G$na2`kN)Pe_<@1u9K{Zk+3u~)p zJ*arZIs0Vr9Po7VJ-95^4=3GuPC!7ZK$RSM-vtI2nP-fv;=DuF+ti5W^pp?y#I1aR zcaw0j(e>+_@T2)pEkC5cChx2rlqj-9w#&6wuXL7{TBX}eKcJP{O_*Ipni((`_o@|Z zhUuuB@=747x1SN5_es*0?Re@o{u+<4OS(@4r6FWNl#Dne_D(uEsN+x|WlcdhY)~ls zz}T_Cs$jx3ieyoNX-Q+%=Xmj;YJH-1BgS14Y?A5yCLx2^CCu_`iGQDZvUqj07;bSg zxEVFYBfWZxSy*)c@dXv~xh@(6c@6u;2`0gXEAwTQiq1sEMbY-+2+8aPttmTmxOCG((3r7EyZaB$UJ~5zSeYXeojdCtskqKh*AzIcNB$o5;MO=Q?<_icVi&vv-zvOC%(^ZC) z-mlMJi+jQ}5mJQtYo}d(s&QPj=cM;cp@gpg@)6mr6aX2-6dXEV@KM3-LTFM-!%!Ou zqcpOgEXi=+JHStLp?88xhEk$dN8s6;%1ONAxusGLRbSiLRnI+wP5miC8%0Ctwbqf< zzg(!@wN8DkUH@EUanuA$Lb@w08>+8z@*7EgiEz(&M3LXJ|h2wwGFH&FU7i_Yt> z*ASUu)T-%tDsHc?evDJ(DGH2f`-6oxN4`loJt?t?&+3k zSX1hN!n7oGV}-c}7MLH14`@jlZcO+53{^|0E9L~^jN?(sR!fM)5l=S!kFaE5rK{Sg z;q<$mq7WyvY4&)a<%~<~R40~GOWY8av3;@pvvE$>%vo9Hr&=+yF`dfMk~;L3brWA{ z-}n9{4BU^d_zDh2GW44_Zfgn&s)I8@*aG z-m}e@)<(CtFo*G;uwY${@rJ&>ipP{CojOO&B~IJ<=Luj;%}RyI}cs?4|V$@(aT{y5le=sRlo ziL;9b1*eG+&6y-T#7-zLyol{yx%-( zJfZME;QyHkSFts;SN~lxF(U&3;rut5(8TgTDa`*N5ngF%%U`k~`OVbU^-CAjsBHmt z;YLD633O#UQ)MrN!w^l7CJpWmmS0^q{B(DdN@tk_7KC=QAVy%#_&5{bIQ3D7IBr*{ z0S!l}!DGNQv&woxq$z?2F|-LBsyQuF^~w=LITfzi$X}3XYRp zFo4epTKwg*?}((+egG?0)LGEW6geHgcpfjhf=)OI@BL)^7c-GMPUdsL#GMLFk!2_a z2ZEg>m_QKVV|53mn~oWth)~^d&+Q;i&gXfVM8W5Ip0soIX#zBbhNk3OjOJ6b*22vx z(zc7b*!DsWnN`&1Esn^{$UiH(b9b2z3s8Zm5CCKZxTR)${Q!(C_elg?Dv;`%`NQ{gY7j9?P+&ZL0ae-pfu2HaxV(Fr%vM5@#6|qYcdm>_X zJcY5~GrM5~7CJpW8*}b+OJ{t)#7XD{&lS^HTM(*VV&YGTH<84)vq|?3A*;ThlVPlR z|7Q@b+;yRWmJJjfi@dFyGG^8LWw596P(eBJDqPZS*D&2NUSkF;46-Ykp_*}0i!GN> z1%#`W?{s2PPfhW(iYU80#Z*E#Kd}mUj7=8!Q4Q`uG5I%3@P^&F>cS4hmwO0bqIWPQ zP&u-KKEmO6;O3##0*=W===!nj`Z3pLX0^l4^6G+MYrOLp*c+nO>5rJljhMur`@jhs zM%(I}<)!&xYm>m#VbqV)kxBTRHq6>8V<>O^kwuHC2HlEPtaP$YSZ@kXW ztNYqagHNj#FW4`eJyn9Kd9TKEE%P{it>=!?dhU%cdr6Onuu*yN zGASx1jO{y|u$3P-o6S4#ztr5zGV=`s-Jxdf6D?0B@m>^Gw2l&xaAnzK@HR097s8)e zNo;!`J}NkV4g~5moR%76)Fhf+5>%1Fb7E{0EqT@*@zlw4JdnW6Lcdr1l@`Z$1P;FS zT;X#b)XL^vU&l^!FfT~E5T;c3ROAK{v>z7gvE#&ly4+;YQHttd_m07@NpFs(&s=hF z61S5;5Y`J8D7P^A{eN`=4Sk~=1}IFZDL#XA`v0$@)_>Kqpn(2MGUk^Ctcphn1VmQ> z1Vr%Pd@RfjT}<8oe^*HVgLCYX*V1`ItYPPgdU&=H=2QU+SgNJa%d0s`X6^o5OVO>t zv`vAKv&bMAC|H>3A@14zl3?RH9hf0F8p(BV2FHGxJTNEcI;Z>ibdP+^t6OxqZ}7>y zVR+SWqH6O-^DrX~jimQZH4bYkLL<$KN^el?%#!Y!v#y`h4y!ku@!> zFKhLgB{C^~oz3winlMIkzhf4-nU}pJ|3FlCKc=8p+Q`Z|0)E>G*8?AK#BNvIasm2_ zH~&nVzeL}72CtZYD5>p$n;F}w3>vS)c|=-J@1YiF|H3BAFN6Dz)t*tlb5%E_FHP(h z6p^jxbZ*hD;9!2C!Ds03@7eEsUl*>wmx+68eLGisM?(fbMbrlKDT+*5Wq*yh^ksDA zZ_)IF(Ou#O1 zz`jFb`3rUT#R^SjvKd@bUQO#m1fM=$X`%+`#`Lsw_rbrbCf2`6@}aLJ@MLs%?l$=? z%I5i9N*)<}$pZb|>~Osyyj(A6R4caoGTA>xQ++w}AAW%gnX$B#(mj!X9%O-JqsaXU z;J?*5A0d3hO?(=-bG9)roXOJ}ExUSGynHpbf>Z{dYD3`)$PUXZfQ&uKuM21C%+LQk zei;n48`eV#gwsbvc-bM)Uq3O&NX5mNXUJEETpY)TC2e8GUl6r9Fo1wYN|?uczHIdSXT_27gU0wwmxjaEKBB@ed>zKZeyKL|)qj$|9= z`R)DApIg;b*eGo^Xy&_3e=QcD!>!sm#^lDq!5Qcq3hx6?0$U?lmh0Ks`HTB++$Y8I z@Ed>Y&e{n^@Ic|t<-^xR>n(8voWh*7@qQSegr}YvC&30g8jGM&j6n_ zPslO>7OZ9FOyI@E5FbbCIKdQy-}txio?ZT^KQ>tn`&%S*ju`(G-pUG>t`OT8jSIO! zecUqI`X-4x+B~I1|J}kZ!2P)_Ye)|2I8jpZk=fm@><4O#kXk!#YlRZ$A1lYshsMIk zo*F?{RO6H+plq`j{Fq~?#SH})KzdK9dwOXPQMbZem91x{K zBL77N4XsoQ?E~=7LKt9t=5S6Rkc@cpJr+4~Lqo`ek-<8v>m5#zq0G6A%{f416>WnRtVVhq&?Z#4KE|+<& zJ;f%-Nc<6;jilymOj7sp{1_n$$bdA`)o7ZRZEkoUs-mEJj!f~(v83>dup{xTFbl1#bu(7EREGm@yrHksV|lxcpW9QfRR|qm7 zX1>E5(86hDc`uxfyfc47C=C!G^V^bZCyKscvLPbb_Y#kFaSqdlwM6Xj*dZDm6Nd*X z7q2a(c8?Yvd7(y?cX1`CUlmOnD7y!T%;SOZ>v)ki^Xc4tg8$ zPBj!N=Hu49hS-{3KW!pzAd?#(vk5IntG>i`z93N88eq`eRYJwqLLC@%z&HKNfQJD~ zcIdB-tgdvF&J(aSFo$F~*T@#K$K^>=Vfo9GcfOAt=MXuWOP^{Mv?Vct;^d9mY-2ap zGQ@6w{mC4{lVHwrKfw}->7ds&x~jG>SlA5t6D7r->1c?c#ffn@E83N`zoK1b`+$y zcv$+^wJuc1tioWDHe{)9|AXO=7=E;h{A8V-M3?AYKs5u}$HDuMjnA$3zDN6WikF@P znto^xL_VZbXZt)AdYKA_x{qdP(ZboMWEYi@gt&!7D2Ur)1aTPTzXy0MZqfD0u_vb-6_%SxbL_h2xop=NiK>^Jw(#E=elbc=E7$}@1!zYLhX*h?{)C+R zTz+KRP=Q6VkuTz3`v;J5{_FCK!Aj{sMOiZebD3W8)bD;$T+62|B|Urk#1T1uo#|$j zpisQIa>>)3KrYeg1`hC7lLOQ2nSYqlY|+O{^r*dFD0^wUtzkJH1xPYce5i8l`<+K{ zQzP|4QBIxjy>xjZ6<7D-Yq1FvKf=w=b6a%TMCR*8UqNhUVP++e5a2}?s8egpcJ)TO z!Y@i_Z`B!nw+(|eg`i`=xxj-=jutv2$E=B16&%GEOOF3KOmB>@maUnjC9?L=ml}AZ z$v2%`Crw>p(7Gkq2sdCBV|z1X7FmYTXHsEM##BuEnWf3>b4>)=d}8Lz;S+4HY{-yg zC?$#_V@yfx>Dhl=7s>P|o?G_UXMwV8g8GDF{)QRxFYQezph`4GI}2X##(491B%LAyHz27LzY z57X@jvlNu@m)}&#$tu4fK`SsTCvZoCIjTI}KKYvHBF_JdFvhVr^5Z0=Ri^7!Kr$8U zv{cF6OkNs01fOYU*4TO`r@-}>LZ%C!e5q2G1;hA<#dK7G3;>uCIifArs zb5&?HmM@KxdY=^@y?(D|OgdC7ICfXO|Oh!1*(sF($uojeg7}3X-G948j7Zrp`)z58#jXz{>rsSlI;q z>t|wCR|Flhlcei!hIV7w7rYYfWd{5YNyd}0<|^NyhL|^*D(&VuFNvqOG95hPDw1#{ z#pRS9jzAK@*s`qhb|a=^ra7Tnt(0{d1dA0+IHjeyXs`=oh`af5G{w@ueFmKn*Kkk} zH=KgvW*pH~Q116n{y!rcn;!N}w%!$Ho1&MNU~V+Y)7#y9K@S&v&SpFUQS#=-KRx?+ zVU>e-cnde}dX(x&j!Fs8Rfh;lmo=Vg$Dd-ROVt)x#Lsp`olxzf7fmo#DWZEhH?NHO zky?)w$Rlbl1Au3w^aVZ;lsZp|{pt0g==UFeG&ojS3DrW9_$&9PYuM}aq%4HacgB>Dnn-Qa^# zXn5DjU&XGw<&@5FtIe2kw{-g2*{hp^{_g_;JYyBO_U93>tWI{+Rro> ztZs;fq`}d4U})39rO|vY%@i?Tqk^+&!6aKM!tn*q~_%mD@0h_DNy2Ck`rPiaJy@V)wzi^ z-P=OcEiTF5$hHQh>_yhd?ra4-j>JWV_svoW)|BPY^%jZsE_BF|`4EC04kIL;2aBNN z_UvKmQvl%#6#d3>u@}Z528%cftl_#1)wgvs7fv2?n$u$_YeyMjp4g~Kk>a|!RK?bv zMcX;d#Fw68qLyzJ^H?^Cd#u#RgX85*U{Etvduo<+mG49rM6Rms{+_WtJR_hUvD;rf zBiTzW%fE}z+$X?!4FK_T!P(|t&nZE%p#E+u&kirzC;$X7gKGVS_mLJ5+w=|kR{bY^ zn;xqjGivKc=MOQP6P4@ge6rvzwK&6<4Yy{HCKz___(f9pWl8|={LZdk^P#&!S*PAo^i0O+?F{mOW%$9oBU18t+d z%}pmIapJc9!;2_BV*NPbi=(L6+n42c6K8)n0y3sQs8d$N{Y<;tp?tlc$6AwcSqevy zuPc7u3jsluckc8uvEZDQ?FGW4%@|nZb+icQ_m-*Cd6wGCG#q@b5X&YM!-XZreLHjy z`fLt^jz;g-Jj2IVcIt+j{YHb{qu!G$3>Lku#Uk1y;BhRJ7sf^Xyi$5HdtS3%@pj4g zbayx#BA^i)UgCFaF{^|l&y}|{8k=+a@IpKMdV`E(=Th{aZGC*T%PRXs<{fS1&O1`0 zbW;7lB)-rEEC`>AL`8)fVV?Il{)`1t*RnC;t9^6{K3;B|0h1b+G|@OxjY87~8G7}j zH%wDNgJ>o9Y!BJ#Gev(sOf!Nh$nMS6-Y<84#&a8_v1I+|Ko=3k_Ja=skHpNxL+hyw zs!CKq?#%?|O&_qfrHzjRcT?c(SOX!6PNUxHU!b)d4qrXNJD(i6gUxf9VR0I~`McRc zCzKI(pi>7-CL?}xRWFOHvUySs#wXDFdHw==z=;a@#kerBYoDe?#7DOshUZRJvt_$+ zHZ0@5<-^WU=eJ&}XQI;^0{<|gpTfwyiwLXj7(MtSdcIT{W7A3P+JDImex8iXS7R6+KM*Y5(HN$xK&F~2 zkveULMj5e>k&;we@r0z&ut#g_GxSlA3u!=`J7u;aURBgB(q{2nnMMbT>htA%v<|~; zwpls4ndEphFl>!>=v|zrH`;tmZ(qN*eQl)r!h{q%r zH%6ISo%4J>W6CyRJ#2Y7YPl9$`vpgv`HiJ%+s^J!kgY}6!YWsidj)|zz_nc0nln|4 zZNftu*{R&JF`)Il+_XUzj;yPzcw2#O3~Im`H-dZ3bqebb07kq8giaIe+&A?x^!H%hK*e>-5b2k*>pTP2b6iFFni0im5%vCaX)mc63tO(8%*KP^MEjuOxc z&99C@#U5P{pr6OB(LuX;`&T4NFUM76G`|Q3^-JPuB4kwyqY&OZgiM(ZKdJM~MY9aO zAb<0Hzea|5>~j?OT~cn4XL7Kdre#DCI_W}$5_4395`^`h#1)9G)qv^AMkAY>qfY!g zHl}ZD1l_iX(7{QdGtF*J=sKc6lNg-5WqsYM3A97Sd}LUB)UZ?ue$A|uIm8BiJ^T?q zU~@tDqC;;v56EvKA#+(f%fI2QG={z&kA{$9LaQ_#hF_mqQ#p~w2W)_7(sQ^MQYV-G zdrg^k%^}yQca9)W*-6L)(C?LgzDV4|>SV~@w$Mh^BH@d8GUS-xLhgMXzUq_rNGklM z!+=d?xjI{lchxc4IDWC*usbV}^H8#tut%WsG;%4G!9f3cTmMt%>4&fGP7%7$8M+@A zhSzu51U0VHJxG^`kPt)pcRn2@_BKkf9z!0$l?mc01CFFDZ0H_3seJ=QHQ$-1?8vce zn3c37MSH%~QlR=2x*bQJVCWfwxoaP!sy7t zZ!)^8|4&{&=MHMGO&TMWX^C#Xqa>7Uo538jDY135A&)b-C8q$2CHfA{9kzL=6ymJF z7R8+d(@Q;&v7Z$Zq()Zw2FqozCROKup~()tE^xR_+S%W8?-h8xom-KJeE) zY~recJ3^F7o{B}UU&2|0t#XP6E$1FFCnBFjWjF+VxCU>My>qnonO>$+7}=)DQ?bnc zblyIpB)<%$omxez5Rm{OgHkkK$UfHS1CT$f=$B_NaXNt<-`~d;i7a{hJI4gBP9fYu zn;iob+wPE=M0(_&3BvUnWw{9kN)r8E&BYzodTQd`@ee}ec)#oVqD!g3y$5bKC(r{V zrG&5tW-{gYen~Gj*)BgzGKfn>5pa@kZ0pLqMRn0)&{4+*lhFKRp#mC6;F?jmO1|)> z<1nIO!S%=h8m5|<_g;A^Zx>T`XAifMJE#_;e8f3Th=5cW6W`o0fa&zq@nnY@Z6_X9 zzcLnu^e4H6)!Nxc-`!oh>Se2ME@$iT82?ppj^f~KXx8xpZ2(}{r*zpm3C zV>giKP@z7)#@&4ivkiJZ^BJDMshWhI3~9N ztI};GS-j8kw8zNd$hM?pay}Y?ML$RqGDqT6WH8n_FV;DkL)z#{EP%L=p^>59cw+eO zs-SwU-bXB%bL5a_Jnl&T?h?aj>XeO4l_e@+R|wrRU_w0-#qwr$s4~oJ>XN&Y*NYqL ze-EszugYuQ?wNjQ7$TZuq#+!48|C)>gyn8c)n!VZy!0a`{7B@V^}NE5KbAM=ej6=; zuOHtqg1%{gJj0UMu*?WSm*Kab=dTM_VQfP4m!+zP8*n9Kf5Cn1xc6}TMJ$FN#Wi`S z=@&>G-g$I#&6{2ri(Lf=IXqF6GIjf;giTGm_|Cqu=1_1>CDk^jW}04! zEfH0j$iU8F&yef;vt|jMd`bLZa6RBFy5OBmgcGVm4LXOv+GZN9dLG;kPAn0Iu)Hb0 zY$0?B(Z$P;;QU}|eKth0Kf!mSA+Rb zY?gEi6_T9Lvbu+|X`gwrSvnb=NtW|Hl{vO zgd8g)fv`PK>xcv?f&%4T0pBixup^gV+~*g34tl-- zc7pWw@)EE}Yk%A3#Wj#HT^B-~URNp6!%RirIrDp4rK>BUD5$fXCrFQtRL zA2E&4g*z#1i4$>Ix=N1H8Vc*A)EQ2~|BPY&(~e;uw8T|>%?4c^LiP|AZ@v1IJ$B4` zKrS@bgwQJ_T6a`s!l?#AeaA6%=Xz#DG`6iPH!j@GKqkRzGa-QNK>)n-*8ZypyD5I# zaY#LA6fKOGs)-V#W=8ygwyU1U-dp&F%NsO|!_qzn%znEQx?8uF1UHVBKul1{PY=w$ zI$qhG?xLKqyYy~ME#xEWweohW9Bdj9YS*`M>;gxG_B^?9UyJH&z8UP`HNg#qt&z%h zRj2p7S}SlWN~>c@XYpB+n!C~Gi$rQ%-6U4bG-bo)SWM8{yO73~!|YMk>H(OfJa>!O zf$$yScU}2t(YXrc)CD&Ni&)`qFyfBd^Mlt%-PuEvpzvP}H8-SWYlBg#~j3gOGrpfCQ-LIT^S$f+HJREKU_U-SwcYxZg+0X3{RkDC9ex& z`Ej8Aw5_?&54ORoN3MCf&UcegDZZPAB}i;Cv}1)ApLgX!-9ENzYwQ~9pZ3g__d(K< zyYrEenL^IgXHQPJNWTrX%cWQzD7##toP(I^ksBo}UU>Al z)$9RBa4exyoLIT~GVT5N^}`(DfX{7_v5NjL9Sjn+(XPCrrgHUBj@YCRjV02H zkFTqHL;tsA5kE;j4B*9@N3PQ#GwA&BvvPJ6fFV?xhL34*>Juho@za^cRAl3tHR)>G%x~=LfJamcH3E6 zYu)jq~dv~iaMBUAN*gWI_hA=~`!ymSPi zoed?l>eSz^7LqU8iR2TM;a({skBWONq-^EDjA#O?1IxJv%w5RWIZl`VO4piHm7b|r z!%=RurWvfepH|Jk+bxMLEbHoN%@dE}Tw&FoC!6flQGRrYVX?%NlG@}8iPJ~oc;Ion zKChU9HdgX5N;QMoK-avl-cH?I>o)3^PuIS(&8Ryr*mlKRo7^u_NnD~Wi%WE@*xI*- z*60#nV4pV3=LeD1S~3>;aJMhwZ@}) zsc;aWR>_T@vA{+NPk! zlin&g>VN9nMSVeNn#Qz9z8F1JV-}SUJ{}csw>HS_%YFZ0t`FZH7BQ?YvTb|q8z7~u zs{GOp>mZ14YjCBQdp0SFrfiv^AWN6M`?RXB82^bVvAw7V-TzY$+QI!_b<_Wa^;?+Q zIQ$3F|35HFODTPJ8w_w^H=Z!WU6i5*d7^!NjD$^7@i) zO5Cz7xGlp=vh(s$pE6QX?a^lEv2i? zJp%|oTR&}F`qlGG#xoona8rDbeVCwjnqbc1q2K&*z2h)BuVl%588i{Yx*QRTnsk-9 z{6E;pdQR(J^G|LUx(ampBlYb27d}M+*q<}l<3%7V=DLHrgBh;#TEbS9wJEHOjjj^m z8F_k`xEAO*!P~ysdqIBaL_F1S^>SQ!MX9Xt5WdtjX4jpPdTxbKLj{c>dq$%CadG2h zP7YeKt6lf%-*#aEr-cuAi6ue;8nIcqql0S%U95Us0)joa;$TxOAV=`+g@rk56y+KE}Kw z2p^?BT_YFo0VvW}*-bnQ3TP<>x&4VEz4T){MBwI|cUpGYpWP<89vyrJO&>uZhM8p@PZ zjD9{6l1-FR%l}%EuSggqyx`y8-s(ja9C66qz=Ka31WE*L_v6CrBPi^C^gMd?9As5P zW#?JTtic$qUV0%w()?b_W6(LKDMbpSTrUb?N~W00q9KeSfj}Ku1KPWY8RQJEyFGnE z$cYggrZ!PNyY{)78Y3T)v6il{y`x32=a}Q}tnB2wZFd)40s>&9rtQN!K%|4n{Wdg_ z+lTRJDdEIyMV@2CNsLr-G@9;G|waiE1(on{DK)%pH1_#y2iN<~iC zr7f)mEM=1g?|^C~VXU%|(sPteUDR<(+EVTu?8q!pgf;oo(_JFr(xw45mPf^hc)Ytv z)Q1r=hZiWs?jxQ+o#Lc!U!+fJW0qWIp{F|^f{FFsl~x?AFTLr4Rd3PX+mhc@r6c%b z6#aTPn4{A%_3MPGB7-RLD{oHFiyeg=uC?yO0U{`(>tF|t`Yu{#e=vtE7m7O+FuJ1d z^inQ%5C=a-QZV0F(T|YN91{pN|4z@&IfzLIj=c(9oL_4dBYR>HFK%KYtM`9^lGcz_ z?11dHH9|&M<0&}g_kywnLu8wM{mhgBQ&Ns^1TQ)&Byum}!XjYr?t0v6HeN8Vz2N_# zenp<>&#_KA<^UU2QVM?}lC*&&6l2;YZ<;(KuzPf@i0SN9PLXq!(brH5Sm%JfP{5GJ z+-Ks-3q9u*CA0{z<`85mUcgkO&7c8H7f9^r-~IKTRYzby+M8-G+`Pf!7LoAr&<03Y zBgB2Vr_3!(jLS0G$=Ub2CWMl*7^Y zHB@J^HOe6?oT;E*FO&yKWK%`_$>G1Neza;)p5unqdMzJDfI4@4M&Dl;S_3gsj)*)j zEWDUgqts2GiOUV7*8(V3RW<2x8t=G6jC%ICAKT8^oLfbnETFGOV993x z!jT#Law0!{Dtc*HF6thxqc-x|Uv!tVg)Clw;iH$pqnC%) z4Nhuhkk3Oq;D+#9_-CI3(ldtBlC2PTdh~#eyZTAe@@vz%oFeW<{ws+bX ze%4d9ekGoEybCI{cqU1nY3koL3f1S0A1rOySJ~pszPxt^6N7(U)|M+2SI$un+`^bz zCN$}!!K=VYTRN$#s8R)dor%$9!s)YPyIC!>2`3UF&db{W>EjeB{RUEp?gE4LLa)1%vR(fYnD-TFu^h@*#&4WfP0 zlIPd*7p8mWhOTT?cl$%Rp_&RL%U(O)K6>HIusbeB`Ks3|h{2V5+Sh1*P6)HJgK(KD z+C95j4indwH2^ZRVVc3ptMS?l?$#~YM*cc-6CP@nwst{3yP*_QS2rIfz#BK-n3c~x z{utOUe?M49hGtvALajWaw&@i)tvm6z_F{I%?8jF2c_~^z-lBQN3&7!`c8G3@?2Lka z#7|c!tdu2(f#@MDPNQjQjC@Mr%6z;-*+(Zh_NJUId1sizRg;?gX6FPCE-3q4Uepi1 z!sv2LXnJ(l_v>tI)*C#PKXYVoxSvYVSM+gmmdO{qR$9Eeq{z5mC|*Z=Zx$n2vTUHt3wVnhG;?6$4p|I>fbklK#@!aohijJ{#NG?l7iX1}y3 z+3c`@goJlMsa3<~aXY!J(SfrX#QM+lnxuBe6`QQYZYtNw$K96uiSq$REFB9JC}$FU zM-aF-x5=6XPgKE60jqWSKfPp=rZD)EFKW`a)I^43x0P&VptIVdl-0m`iN-jthbrc5 zK|c`+Y_h6#fIGIL%Pb2-JlHg#-Cwz+_%F(5ii6G2*Yp8KqVKtg@UnePC_1YG*kepJ zGn|eC_4=D6u^23Tao=8PnJTyrWh^u%xFxyQn-Dl)@SV>x7Px~)FH^KmxzUBR-I zTcq?OkuCK-AFKg~?~-+x-?1@o+qNwN*T)usIuUQ>$qsoT~m|7{Qf_1 z=KtD13)u3Mw*U6e6YPJ#f1FMKpZ*88)OMmX1^#JlYB4XPh2j^eeL0|-rT{5IXMzi* zg)~VVlBlk&9oF-oFL}60mIydu&v;GF1fvJzl(1(`E z*y7Is`~CBL1o;`%Fd>&ED`?S;ITS*ulTMa2)o-HtOhzQ%a7;4^8}G_m9?76<&$~@) zcC~!vaZSy}`zvEaSb~RLsvlvEt3ev;rpE;aoz-gJes()c4>nB%x<8iZpqFtWGvO&5 z)+Z~`WNw!_>n6_Bt?SX|cboPGyLH&G5X|I(mCHog&7c|T(KpxQtwW&AN5vq_5^fzg z`w3bm(YywjJ%rk{3Sk@IdSCbBt@g(jlc$5;3_#c2v^EZFPDzOz_Z=9#>)qR*V;$u- z9aqFv-t{b)*F;h!LJHV;7}RGji=~oTH$EBli?O(Tqz_k#yKbvicwniH+MLSo-eSLD zY;t0eis~kY^5Yl45@jUYNjKdna^rt}V-$YT_BVAgmeZ7X$fdZV}XXJv^7B2 ze&`ypJG7pV>G!%i;Psxs|Dcuy%N>HA%_mRL{-Nai-s>T~3Y<;FQubhDA{jOZYSSgW zO8Mr7H8DB)nFHZ`4I|0f)0n2+`ofS*2H?0|IbJRt|_II_0P!9@GlzwckLY)8|VL*L{qG${a=a3 zZ>5%SF?=+wgiNpl1e}G&wUJeD*`vPE&;ll?B-Vxt_1`Ei+?$mf1q|_)tm1)BqbQ!w zY4>Zm0R~JQ0|cf5v^VhMLl*MV5ej~YzL_c~Y1{3Df7O*F{?vX!qyis1mwzfIYIUr< zOW<<#T8T0$IGBH`$Gx>PQW9g5)AZ{X>Lh3yG;U*V;@g!>6o9E0UF4$Z%ue_pxK0PGIaKAD(;>EsX6nlQ%}S?oDrt{% ztruLUyLJ=n&Ac$*?ua<&Pk94PvQoz>V^XkaS5R;I?itX9b^Qg`INqs`(GklW$+2R0 zyH2O9+Vu+YWMrs&l&in?lniDmxw7t;H_mcCNPalm+H!Q`YHKMoz8sjmfCs|W4{@{% zlx^;YPuQ8goWDP=hNj-aQ)?c`3{2^pUAGNH($Zn{+DTa_EGjk5OP*~8(_duQIlX5- zOH#gb--KhrU?6F9x=k?W<-BF2WR{!RG`$w9W1mUpSzFLsH(OMTo^@LB$z|i@xJ>?9 z_fhvZOd2fo0f$z%8N|{@md2q0WvQZ%V#mV8bJ6)ZxaQY$bLa-|i}@p|r&DxieoIf_>7ACCt`j5(w+k~b>@Cj*oPcKX`8N@i!Dc-Gg_0>> zNxd`L>*=z=7z6NDij+Sw)GS1XkHoj%e0`w~&@ zoVEY1SYfr>9Zz*8?ci|g&T;0XdydoK7**~kGTw9beVvsT)@DBZ^tNvBYXd``Z)nZKi2XoVA%Qrce9Gr)N14-?^ zqaq(?Cnw~;t-L;VBMqq*hP?CB)?7Fgr%nlB{F&?#&mVgpc=AdxVEB27|;9-LmG^xqUv{O@^F|GR{d7$4zvKGI^Zyz{ zVryz^@8qd(V`})Hf{*{w8hjI*}}rwU!pfft#av_FOD=?_wkTlKK?2W&)zA5FIFxG z_m-P`WiYqTHg4}`(zPa@IrAe#)ZHw_T$W)C`bN%1(+~acWID>uaB@&**?e0k3ZSHK ziN}}X7I&Dzlv`=ir`~;`yWWo0*@)4{oR(N{7$N9_mdBHpph@9%ZL=R;Pb)b#mvDxLl2RfJ5*Rw5Ou}1YOE3=T{({; zZz4>apH?yY79{IS-&Qz&xK&&(m+`5l_4#(v)q1>fz#ljDa4?sN-cY*wZPt(Jv1%53 z(~RWGm~(P;A6J66-bcEM0Oe{|AK9bb*cM%=wr#NT6b`pOBk901F2NAIz2~S^uKZ2P z3Q;PxRyV`&PAhUdm_O^dK)#CnE>)=e7$^RvPV3_2<@WjgNdJ8&{4=n--}q+6;K2?D zUeRKEeR&yk$67E+JgMRuT}=@_Vu2sUhD>1X%#5i&sO!i2cM6Ag(Zh-JqkHFUdeGb( z;fLv1t^%r~th!~1j$r1b#yA>*tvkh{qJ7Tzv>dxh!9%Iik2iJ3s1CIDvPU z-wy!k@D#um)4?MyuIR;>w+q%-Sj;z%+FUPvSK+kyD_QB-R&g?ZM}#FTx3M03n>-DC zP}rH_&7XNJ8LnTqWJB^dIVHn6@Pkzy?fxJwb8f(y;w@? z414TEfp=RPW%8V1frZkomPj5fhqHc(56ZZ|(G3hp{B-ig;XS+Gt-s%pM9ozmUi-5{ZWbhmVOBi%@ebSmB5DBWGsEhXLEA>6}PKhW>~ z^xng>pXdB>c;DG;*38~Bv(|d)6D8yX$Tr1`w&%dx=eR7)7Ul+5$6kr_uXQP-S_%?H zPRs(onLB49nq21noW^QeYl#WG8yXrgjO1|~#BD_iep`r>tcWysA1)DeNC|V+IsxaL zJo2r!A>ADNt86HtJ8}U|T;XZ58f>~Qr({_77xjL)LnMM=bOr=dt=d?;hLF7|2$$Cs zB+ehV$25ST=Gc#6(BUkKOuC+@H<>E6!;7n|amIZn1Zuy zqstW0@=NrT^HnoPYM^-DPosi*`FerrcO+p2sPnbq!$G;Hsy%Dy5~>p1ZK;HM*2zXb z{fVjKU5Lmm#Y@JI_1?xT_vk!KQ4C}L2uaXgAV~d*IxF$Zq$u^4HzobNY-0FCg%ONw zZ%eR|xGhesd4ug{yZHLbgtV@lwAh)I1c1Td#i{O4v1y!Mu&dv`G7V(zwG`%aO0aNi z_>wHsB<6Hscd%CBkM?XZPsg-W9(~zHjT+3+1_|D z4m1Qf-x$=LQsR)%6&u-!`@1=sdRP<_W_KDDX03Nnp9)8% zN7STZ*DEC0V~2$fz}c$Crz=z~TvXrDxfN#ie_E~O^H{gL?^ilnTU%L@$W;Jh6t9rw zS2fR|y(VuWk=!m#z)13-Lw?D7aynB zV#$BGG=9&vB2BFGt>%Jo-x(ZsuXseiyF{#5?JMw}iQZXcOzBvR))45lHQL>z>5JaV zfzpgth1o_5yCZWfJq$Rw=ExS?{a^#2G5taIe&lH0THeo(-hct4onE9zLaKg)x(N)rX+>f_{W-lxd zc7%22&ysUVtT$l=Zzm2U?!%oDY6^5vx>6o^yHX`|+k#`hs&P}Yvv@@8eo1_f0ZcPx zT&u_}tv^wX=_&fMtd}V-Qotilg#nv9c%%sAkSw zRBdoTcMjbYl^TUAPUyV0v_1`_M$Q6I1Da$Q8$aTjCzRYQ45hFWr(J6gI+pGvTL=}h zI|!DOLU#%X%PqGJR8EgP84Cd}M8;R@D1)0%SbiQXd*2L;xa-xjDlB^&-RP?trg-so zq`5Ow>}9S>BE+sdQClMG&hjbfnz6zMs`20zUbQ1H*k^sLP&Lpu6ydBe;kr=xq6(1n z0p^XtAK3s8F?&{NzZ9qqbb5;i6rUxuFJ0Bw<{eoaCsSC@tY5(Xt#YektA@zQw6n_s zcCEn6<=P=0$b_p=s2FAJ>RSI@6(4#vD6Hq218wCTy=#h&n#!kMpAZuozGE9lNKw$R zv0f--2bR;?X#FlcIQKQfSZlJ?+v}?WbKM~VYr0*b%EmbRQ7v&|v z$Q$$q`*Z!%qIQhHq-gC7rF6Cl7l8yY$NhY~gQODXWOf~=zis;gZR^P>RFcE6$>jyL z4ZT8b6T6Oo5ke~~Gi9Db82>h412BR|@&$OcRU{@Di+T!lFOB_m`Lq{FA*NU^u)|jh zyH;GcwPfba2ffV(%jT1acxrP~818!!_Osz#ccc&U!K*f)QVW#6#Ek3qeu_A?R>qvT zlsONi3t?**RC^AA;9ST?w$2~ON?J-)#fiydxTD9*?qVW%5MQX<>3d|0Cq4&DIyG)H zp@4TIwk>^+8geb=U%AQVHD9mAgPM6=m_jh8jdpO(;qR1>L1d%P5Nqs{-5DENMnKX-1t*of?aYM0_A~?aI?_FiWK)WfCdf@^mOF!^-g@Ujmq^w=exql_E3prwfgQs$fb?tyaI@%)D>|xQPe5Hv>_Gmx>P>959M6FQLoRpl@Hterv4>7 z`Kr~Rj|TaOTc_vb$_~xjUs(1wST|OB-IU5tc)PW|Obdnz5G~!RwGiaHyVWMpy-b(> zVt-!t0unFhCQ{>`uj61%IA~ET_J~SV&AGOT#DomDd^^~3w!enGW-596^u1N}J+td>oq71_kA$Ppy4_xH?Z zS>PYg2)c7O9T1LE6{qsr%*Dy#UX27RO=-uaZY}O!O8`>{2HY3NcP$@I!gwbEslDcq z8j@w^M8(qF8&?V(fHZ%ep$z&4?h6~^BEiS%i@nc!QDxgSW^68by@FEso;1fc;rixm z(bH<=az~x#3F5a_iWyqUP^GEfShze8EY9a7@%$R#70`}WOrc&_J8`Wb8Ql8spN>=%i zatt&tMxs$j*O9{ZzupK&l-S@EIUw00l|zlh0iENQRuVekAG_PZ$J00~g>vaP=TKQs z7_6~<6Cqn7>i4!SFN#@tyo9|46D?c0Yorc2zRyQ&sl~?BEoP`FsuC2PSX8!-3;m1G z*)UA*0MJa1ExZLzu}Y>;LcaW)sV*EX-j~LXtpr?dFA*e~%=GtdcG;gf5PN9VqYT~3 z(!7W6?0){a|69-pHINlG7E474bU0P$G!y|*^4d!K_h0KvqdjqPNvA^j0ztN*KDWNM z;78}JqpyNLl#$Kiw(T{VY|bvuw56d)Jk2S2tCv_N8vsLH@R9=GG3PcjG*4r0Adw!w zA3}&)E^?u?$d`G*Oh3&|L9XqUD3wJLDR&Sypd1T?Z}c+(pA$qdIAkboda!Q@#qcs} zLKPcVuYeOehG4lGHYdaGt|n7zod$~>d+lT8k8U$RVp%EMuEr>gC_~9$nOx8mj15Sa1l$WNZ4zqlTBcY9yOR;}w8ncp=g%DBl6b2V|CtT%4B;LoDK z1Y?59xk(yEa!;?nMJj;H;5uIznb)-J z3upu|DlqPfNU@E^uJF1Im}#1C^N1Ccr63Hc20s(Z>x9&)$wbOJRjbKDw4JS19gK}- z6{UJ?SyfP`1{V+LORDGNe!kjL3*7pHs2hFRYq-mpJcBeuXd4S@Y!2w!lWRlPqvh;Y z*9*?P7#R1HdCjj+H(XhlvNo(Bh1ryr(eao(5j5e=T#BvoUv!t0Y`s-Gl_GZ~h!JUfSi?DHyDWhK|ite+L z>)Rg6HZYwjpMt`3SGq{N3}?okbt^=}I;4 zQjVx zvEhcM>Nyvp{dWF0LCWmx7-a0tY}9nlomVCJ!co!N2Q;cx(GVHYuZ@h1jj_#9?U$Oe z>0a1XZQ#rBc*`c2<7=(#a)rko_GHf;UmN?LzjGX*Q`sD;%b+ytP4;y-gv?y@JINb~ zGn&=g!$1%MKH4h-?(#!VDIZO@h-0J=hQicN80!^e10DqV>OA5XiGRkBmyGuH!zwQk z*uD-9#D#&azi5#QD}CCY+=$e>CN#FjTu}M#sD`u{h=e4$=lk^A0~Wze)!02Se8&gT z{rF-KtD;f8A?>vWrrPr;Uk`ubky(o)>Y zfo!cKi0{s4KK+9G&8Y1y-Y(=I4Xx;!kp{@I(rS}A(;FUfG|wEdz{0~OP0lyz=iw+uu4Q0Ep{2R_8a(I#g>Za4z2E+uFTM5L8N^>kwfA70sMC@6A9pX|Gz~ zH$uDlzAn=Nn=Mgxq(}>N5|JV!+y0F-Hns&h8_q{W>BO=72>-Tti9m2OQldMfGpDsd zN;GvAWqmugt*O<$6mCLZ@0pa_%1iKhB zbkXmX($Yhsf^=0PT`bbmL0u=p)DQ%8Qzq0Z<=4w5^9v8k8K`eFJ&1eE(@NyHiBi|@ z3T&0%!9WTY?@aHUCG|o7Ooib4PM1||gJZ`~xCb!lu);|y zuqM~&C-~>{sG3EdVbo_jF_Mvm9|%^zqDY;GUjVbI5WuZ|P&xSH+4p;6-ZEUu959Cl z9NeJ^P+$azb<@*<33FiXvYH$SYp!(|-nBR=w{3criawuC6ToOJ#rGi0aHW!rMTj77 zIHsUgK{%XdrO8e)&GVo(pH|4{^|Hl8gZU;76VaWU$u-I`BfZHI`b|(eyRF{D`cqbn zla6Ip%cQ_8h#F3hh`DX?%ViA9dwEwTmz!2DdcFAs$!-Vr;P81PySqxPzW5dXgT9uF z(`?iI-kP27(4^3O!dGl`-K90}|O;iK34OE*jl|Hr4;HpR037|wD zF4j*~xBjrt8&hL|=*n*$^zFTtvIc~}Hg2exY<2Ey=z}|A!mOg&s|OZSmhG}Wd%H~u z(49M`tyR7Mo`*1Nw@d1{UBb<}NmMD;BogGBW%qVadPufo!?vQ~yDyxU@LUL;#YpEPy4xmjh%TWLc!!^f@Q6v8(319&CPNmNIw!QzOud);nW@RA%T9lhrw>i$tA0T?D!W~)25(`vQzX-O!AYvjk6BdGEckLG$a2BoTgHE$x!|`3%bFjuP5v?rq>}43r~hte z&?ya%x(>SZut7=H;pwTTU`6(H5Vim{YxF1D+!UtoTE0@n+a?IiRCEwcUH}pB+_Qt~atY4#`>Q(insX zgcxd7=WS5kA>L_?yc0NE9H)F0AGi)i3=U9~#qzO2WI);U__!LC_u=y z*qL!bd7v07@L@CYRgs0TMQ-Eat#s?5VzyucA0wDKp>VJ3VQi#0K6G^zy0!VWwG=0K zH|!;Y%Q4G4UU5@=!5LCc(>l!MT?p|7;Zejr6O;izQ`S4wV{n`AZx z$!Nz?l-hbd0fsHOVwaJ^72z3u1;|~cI%5k640DnxKcGDowB4!z3M9S5q)!I-Y_7Rz z-=#f(%xc6aNU4C=bme6;u?kglKoxP(&b|=JRuQ6{`24Yx4o3L}(>KMP;mCro#adx_ z{Z1;=TE$AK^se>hCHYgw_l)z;xbH6H!61)n8GwBcT>}bVT5Ex7Kn1KYcd+R)vX@$B zbl!tW_3ORb7MaFeZ{`|(0YrnxA`aR8A=8h9WZ4YTpeC`cp6Iz$AmN(Y-J^1!@qiNf zPN6hkSpeZ;{BS@(SWh3{ZwxJr9;OmQfS!z@rII2v5NHOVF8;&yLv8%VxEc6D(Wr+& z`o%FaN5*U}IV}N`YYw9`^DDy&1x8BpPyZwIr!suS&Ds|Y}OCyPrYQ6KUrPDDM@p4%R zWhE8HAl&0T(d*e_Yvetm)IzAGHYy?ixTiv0#{GBVcMH}TgviZDxcA{Wttbtmu(Tq< zZLw)YqLBws@)vutd1i9U^o)7UI z<^q90bkh4Je0D{WTCOO}AYIbwa1tS?oLX;(JZjKFQ}KwfFEaDp@Cyk;pM2NlA>i&R zx>0>ItC-Xkp{2qSIOE8mElhM->L8I6PQ?p_i92=SiYp+BR)g$PtqgJgwij)i9oIhg zD|C}B_t*7{BF%ob*Y?N0HL+Gv-rdI2?Ly>tNqJ-t(ffr(I-8gwBikT~bCxSd!tl_pneEWtc}vYq_=roYA}jOt zmB?QlODh?^W`;m8~5ctB-lB7uOW+u|46zJ+TX#9H|=}zbfxMs#U*g zD>IWQcb&Bra`PGbr7}&#Way#$>2u}9A$_1Mx=_neU^ZM)p~0h z5EN*-AaNI!d9qkGi1!d;8H2k-@fKj98lK&yqQP2HJdU!U>g)xY5?cW`uhf$7E^A)+ zWkpnb?5!O*pMW4@?z6nbtHq2qAVQ9tw|v+2R+zKb8{K>y_4Cr1`{I;FZ?Vg~(`fy3 z)5NOTL3EFPY{v4V@~BnnY#5cBxZ0h@-dnwd!<%Yg+by&g(m^d2tGRb@FQ{o7F9^#P zP{@+IeJO}t_)}dR4+lSeiC0KJSaiYSV1nO6=*V~D`_|Y^*mDDqA6umx=Qv*mx^Pg# z#(kND$BOKskLjJ0h`)(m%j{8xf4sm4t|yS3yg7;&<@FR;t9xg_S1Bg)?u~}B%n_p*WSTaxTPTueJMu3U z@$qws7$bugnM{WAL0i5Oi<^2IgCbYeDJdLZ*i+Ft0=I-1BQEgBjgqx1PDSUK)$%@ZO~N`m#unzLoNT zN!low6rK}u@a>W#K>~VmngR{U>gTEa#qU!zC~sE85y`!V?0f~ohs09s+b_9s+1!gu zz+vM=m7ydM#`dhr2Z5NZnP?Q8^Wa5pp%}Bbt(<}WyxVtyg8^EUsde#_{ znxI|7>1#1+8t#?|pjMR88EzKnI25|1j%6Jo}CLM*8dK#Zb6#=D2a z)_V93S$)&db zE3Ho=C}pyCUb`XVsGLi5Yvnjw`6PMS(C%>E*|Wo~;v`h>mc7N*5-h*ZO9_(|<#V-? zb!oA0T6G3C({C`kWi7ukY=qJ)4T@xMn5xUek@V3P1zmgUdIo-K)e$fX47nTc(-wHxb3An8NNhl-T`yxezG2(nym7jRj|}lEoJ3C zKS&#yi=nQ`PT{iD`0n(5I;Ft+36IGl*)%GA%zG1S@j2+0)xLEFHin>iS$NJdI4BM4 z8Xc6r*HCyO$kjZ4qZbWXubOU&DZQZX{rkOE;OmamXI$e=M{h zzU&3FHP1!mC{!$%46$ii8IbHbTp7q~)xB)horFMF%C;O|wwl!8@H67@0OuHFWXiI` z0N1h2#PEt$9rSL-PYxZkC*|pX<^tq!r@OT%^)faQE{cZI=F^J9P&LCh& z@sIze96hyol8~FrtE_OEJTqLIbwIL9J;Zw@~z%csf z1|5nw*r0lReW-*HDm=bzZ(p!sv}&H|R7Ne;TE0ZZK_mEHiL`)wuQ`Zd4Tp&3@>MUT zcSv-@lrdw^t|49XsG}e;HF#)9!ixh*9nY?Wg*7F|s9XI8{QE1>Xac~|my|DslRod~ z2&xrJI{qu4kDe#zyxJ_ItPK~Jyp`iGtw)>}iU+<92h232-+HTxlTcQSN=|xo>(OoLj(#_sRF>ULnZhVy-Ja^%*QgLL$rf%zQ=*XDE=rHo zfe2rEVnnp#=;*{-D37sC*R9PQvvk#$eoT-_KEi#C#IzD-bXOE4CQ-|(8sz!5>A1kx z+muXCD$md+q1`pO&(&EL1U4Glb|qY^mZ5rPp0rGH5%sxrdRjDtNCq)PC%8;P;6c2a zw%k(aWnzD?LrDOoMiZAb9O6Zfgf@~Qs{R)n{}uk5%bk^X{adFm8WBEJxw>7-I`j)d zr+UD|F?p=ZTawRoIpvU8mE!lZJ-5ZlRKmF+LPfJ#Hw6^?@HEhi=LqZoi)cii`k?bl zfd$vMENyDELw3E9Xh(TB`%?qk*@Bj@T4VGx6a~+kCbVRg847P7ao$o{|1sWVaJU@q zg`N?xD?+P-#t~*x)O)Ak`|!|R?WZd`Ph{4nD!`tXGazj9hZXPtVKLXJ+zZHOL~Pul zffy8WL?fy2pnP^2K@LA`zI=)d20|B#l2efvp{Sm@+5BN#oKCi=CLcaPXz-Zx08Yf% zhK~y}6`a*ot?m2z`t2vJa>t~-))U7dNlsV%VUI60exw5YnIRLW-2A>|k-m&4}1+*9P*R*mO5+bqH%);~Ob8rvS@OnC<93fLX z2#yd?Wd2CdBM;P!IT|@EQS3ytGWM_3L3}(S1X~EJCCL-Ra`w_vdjKaK~d5`p<<`*IP;jYkI#oXdEQ z252Ch4xK`gu*mBQQ=*U9?R4!J+iFT>r*v_$nuJ6|90gh~$SHQGlImj=m7(0-0sBZo z0nFUHG0ykQmPLxm>9hOqg(A_7RlKsibDMPK(s8}~wk*DmcibO2;Ca9`l5d;AC}-m7 zXYx12iz}ZsIjPyP;*-qA9V)XEDb}5tXDC+K5O(yi%c2H#I&0(}C%`lZN=5I*v5tk{ zq@=}DEd>Q5a^!VU6 z0eT2jyCZ&ojU=@mGw1so=xNLbUn6loe{qang#Po;!@(3vKuz`p0^*P8nb&DseXr*U zuX6)=oxFMG@Lg{ISdRb>=>Q*U)irTD1X+Cy!1sp>;qUc`qak49(Aw(nO>c~V9yFD%jd=4P}5q0u!)CbWr8*4lJ-$Z>dtO+hj{U$&*)Kd!x2=z%@7XT@+f84gd zg(*Nj^o^k%z>a?-|H29n$kZ`5y_kUQs3^cP68;G?D&F77)|NK9|FrfOm=9r{3fGx_ zV*t!iL?9rfCz$qxe_{TkEc|ycBskn2Y%_`d?M?|3?0WIUeflZys)rGXPA%#PXv{ z9K8Jt*~ZS=-df+<;&0T4GW)-9G{IT|%SwQwK>=i&5uRAXB>FGZzn(sSgZ{53{`cVp z`xod#qxblY;{O7DICX#Xi$V4$;Dalmd0&D*47rC3{Yfhep920rhFXbrujf7>UBV3E zM}I&_c+yxsLo0m~OI@UcfHr9a{Lb8u__Il7;){KEep`rxU+_hT0I$|o%U z49X`A|Et^l$5v03jUO{~)I4GMN$2?2c6zvjr&6bn@lhI|;QulHb)O2K{+i-vwaLd6 zC(S=o{BwZ+PYeoxhnjVh{6o-B$_(|adK=KHp`)6c7KBxRWqNi~rkBKx;e@*mfRLSo#J&m+@ z%#?-pUzq+i?&9|tp9Ts%W>mrWFO2^lHt>6tPY;qGQx;(TO!7a#r)>^W^E)!(;d)%BS!@o+|&h=Ktq?@acl?F@6rs hU-*xgcL30TeIEe_EK-1g$N+yGfcJL@+K0FH{{dOjytV)U diff --git a/python/lib/py4j-0.10.7-src.zip b/python/lib/py4j-0.10.7-src.zip new file mode 100644 index 0000000000000000000000000000000000000000..128e321078793f41154613544ab7016878d5617e GIT binary patch literal 42437 zcmagEV~j39x2FBHZQHhcw{6?DZQHhX+qP}nHg?>CvkN2FC`PIKdTjbLAX`34>8PU>@i^KjzbxRdnn zB+RF;R|mYL%{iA-2?6_Z zU7!jh>L$P>>K^$-VC%T%aRpM%5D^BJeMvMC>o3P@>OwN1h1~0HkiLZyal6>nXCpXU zA|vKIC@@Rx`$TDd0~n@W!fuGfSr1hzmQ`yWq~XanN(?kcRyPuEDIz*KgU(MqualOl z+=_@y&<$zZq}dBEC{hsI0*=$S>pvAplua&-aa%mWL7=*UfT}-@@j_Ht;JH{)Oxoc! zj|`hX=uZWvajH_HY&+Ntg%azO8S(7ITriEi#GVB6GiJsJ2C+7>xf*q=k?k2_+wg@< zrLWtr*9~{zL4pZOPDKo&Fx&H&n9Vt@r>SiIdBBeXHR#Rq#3|}BsbNL9i##!W3g;ca zEm7DH7{`Vsa3niI_E2NUfk-vT4h8sbC%9C-4{_w!sF2How@~5F^|w)J+V;{nTjry zxbfjP18pM4NSH|&=62wWSMu~AoJD2JQ(!w~R8;EtEe4LZgdq;0JCPRKNkb z|4+q^#QbsUKV+RG8#wC_iq?c+y3ARM7W?$ReEoBD9JI;1O{Q@ZjEIv~e+#W0gzQOC zE2M}>mKN%qu`E6|=QwD$`i9FsSxe6u5+qSJW|>C&z4+i$vZ*ifeOMIEz{9N(<@5%u zHU(khTf6NFod9N?L(#J){c%UN;%WFb+}H1JZP5`ozj1t^Gl%%AwRx2^!BwchFMwJO z>j)lHhpy{d+9m>}bSXMT)464g=JY@mXYtLEI$bDe&c&;-Dr3RTowZjsf;4~J+&bvd zB1hVO=E=Zxsz%`V5?RzT*3ytwP6Z$0Uv5UMyA<_SNTHD_@?3pLboUCx?J;{b$ncaB zAJ3p-s%jf&eXSSD>GUQOPG=S0gmP@1yswUphw6AS)>U~w;xifo9~;?!_}3-;uC;$T zAozDC)|axyy*?gEc^&DOlNL>`K@bIml~+c3U}w`s}ngJPJQQra!J< z7EYZxt$HoA3lI&vO}PD0z~bZL1|~dhlY%zGIDzb4}z5eUA zEN3yJ9;b@VLUS`u02WueBo_1QA@c(2ZeZJf*tKG=@V|2)TO?2B$NBAs84~~?vH$>3 z|2qeaZ7fagT%1jv+)SPRI}6Y_&$Rw){U2Fy&1>zv#hJACO;bM^Uf_ndGCO0pl2+-X z%BMZyvDV<+%F@xXL6?<8m}wlfoTS3jS@C(hh5#6jZu}={b~{u~9Vr3~m%q1TL0C6@ zwbzIKXqA{=75V1ydHSq>H?d&A=%G&68j5=DyBg)wLv^f`W=q1{FrF^0IfUS$J_Xwq z<=T=EnlbUBIQc+{;y=S@df3%%*QQCdYx?-voihLYUE5P9m6%>r&CwFlHzqEAMh2L~ z31d2})$Ev$J=3H*mVf-L){#`4I{mA?vCoqa_O+>EN`eyr$yjXB*urzD89S0)85V~n z=S%uP2CFZ9No!DFI&-N;Utg4U9q!zM4;(x10W2PX2h~IgeXd$V);1bWVrA+VHP$VF z+n?{$tJ>cX8ptEojgLonWkt8>(qY@yBBMNPr`l$ueF{lV=hQE)#!8YnlT`_jplajf zQlg8tY5}gk`YX%kpFkIUEn&1dxGa8{mUo43n>%wto6Ua?5zp*hexfvE+w{20qx7-R3IO(1)UnqZec6Dk7!00=$_SwuKdc?*pmp5OB%}Kr zv6TmG>6YHs)%{#~`|UjQ@#gmS$=c8D$=2}+OhqRzU9_oVm@irP zsN_g;aTJxwYTL3Slh>C6#M>`fw_0C^;9Ib>vpcP-JXIsA4{p(4&Q`pJg2L&3ooU(@ zy{x^sTs-?em-NeAb%9pzugD~fRr$*)X0S2&bS2ET(U9x4>H=4m;JWum_LUPRZFNik zHX6}j%|cD%t0m5{=nw?r;N#&Dp-gJwTHE0Y-@@Bf!5Ne=(Rl zYI2$Gm&|5BV%Fi}&w&iwaONeq|7koGj)q44#WL6*@T86w3%f=}-tS1_R&bZt@nH(l zJSZ3Ru6p-KU1_o*3lxn4UgxSY;~(UgoDWy8B#a-%nb&f%hzM`om9$45ZxX^fAC7kg zXI6%}5hsTrt(|F#3 zwK}d2^Q6$B8)#-72?Xa?<8fPw1X!qFFsF|kW~lKy?JsC6rHqY}RQ{1cHCw5K7XmK5 zYA$c2j3}k57pC#Iif5=o-b&x7PPovjt@26_f)w^|o_Rs6@`Od7l-oSD*0g1-7X>$v z+W+PA`0=_~YV}GWb8Yj;gU$Y7BJ#52ksF7sBoLrQ&3fnMHE}Z@s)%(yjajE!ZQ#_8 z(rx*MhT3MeI1HHO;J_%0$>jzt?Q-Ydp1wN$BrB;$9B)`o*~O429jfm?(L=N&ktr2aQx!?dT` zN`c*(d{Ab#iLRPYefN)&wxXgHSRMaEnJmz-g&`C)Y&1(<8PUKU(-PH2fJD~cEw1!t zW>K~;Dfl1v+nj^IfGkscH(8psqxaTPk8^a^EA9-BT||(r`T>Jf)heQ{ydW&okScM8 za-?&J0fuIA+zS&=;kPjb3A@d6Y&vDZ=wT$yso*ph4@HCQI57Q5YC%Ni|eq^0{9ar>J1HzAV7w^Ezy`tt?kJq zEXP8M{&Y=AzAJ#dkyNAHx*P2zzTrQ{k;g8Ld3#t;85bbqBN-4;9P=+pU^+`iM)4}<}I#IYg z3GWQ*EF|oN&Ot5rG93anmX(i9m<%SkJ86qCiLIS&D71WVc_4LE#+y=3{D=Gi|nPm~W$T_IPrnphZ4+w{Wev!O(ZEB;u zkM35Kn{VW={4TT*xPQk}Itl~Bub@0BdQs=DTUKq85<+;A8nKB+2QR}l<8dV2rk~(s zT0ZJ~d8@Ad;a#^jT|tegDXlY_&3N5Z#D^Hzp_BmE=MGA+kUx$KcHsyHP1IbqHK7JgoShdU6m-E}MsvS$ zNTIt&g(*D-Z25^yUohuR52^o5JBY09%uf19qN63s&y?+q{{TqtkJ7e`qR`fE$fL)d z1n-&!y7JYwPN;@sZM4;h`kkaC4sfddJLtah&D3pyNUFJKTAOx@N@9;p*~@f!%nHKu$FbTo8!X&$u+$eJ#JN$=>q)*w3i0-${nfgS zXR|?wCkhQqVR5hf(;QezAMNfsVjCbY5FdYi4-?WY>cW@dvj?iC@D` zZ+MiT##BD*afvwl)4ax5N@6EK3IX=}k-0g5i6$X_rFPL)VjED;M*6Qu-lmW$%8ZZ# zsh73)OtXEl9^!=xtPd=?p|NWb1S0(6?F^mUQBr12{X)%0+9x|biPaK(;vgnoBiVM; ze%R>-UsB8C(3t6a`4Tyo#8E!BLl#3D>+&PM-$05ze8vRl+sG}ReUNiR9_Oz5e9fp> zX!L$mid_{R5n+~FNs1!gt|QvwRkwCp&HnBOK>xQbDfcWCGqZKR6&D41j}Hygjm?wfZS|IN9=vxOhYz17fo z29W+`mmvA^;A{VoFEVd6(2W(geLm)&Lnmf85wVr<*aLYZ!+xPm`dvljTa3x(=Nm?5 zumCPC6l+wz9KO;Ex{3y~o=dW3>22$%j?6P}3q#0;`4mR@Axv1=vIO&_2WUqvPKmzd zI1$u9wTTVm8tK@VL?*~<8-U?=0t;GMjVSWV;jrY^>L1X5Nr!M4=w<;kAaZ)3RDusa z?c_D_(6cTUHI9dxX50dqO%|(-Imb}yzfwBE%3qSUV{NXYaYaSp*Vo??zmKan2rCb4 z9((;*ITH1+FN{pW+BAHC*{32uu!5mXPN%bjU+!(!`2C;M_5&Uvs)d1yiuqtpXXO*U zXP||N)dl)W`U=cZIm$Ry#9F-<$1-yD$%?k-lz93c_=8xFJzrX6jHUUS7vBWM_$nr7 z?~mD7CBwvb3#t&p@FD^?CSL+lJ9i_Oy>J>iJVPS$9%ZrU-uN7mQyQALO25$Ow9-ht^r&{C@*{B}1j7+k?ngnC2)+>esdc zeB4nK9N5jwLq3u^H87?Kkd3clJP;y_Jj=s$h=)7$cx#Ht$>=524MGI>ZTVlJg6~UI zY_8n~_PE^@mBf(tV{fdp1Tmw=*=vP3fJs&j4_nz5%U?ksk1$$n z;ZZ0QB_mf0^;Un<&r8b%qP_tTbSZP@>1y^X3K~V1_1FLUb{k_FrqY!g(vD)y(Go4< z&j;=|>U`2LJiDBMu8?z%&YPdC?U^;SmX-qJ(ZZyCCpfpyJ%g!=J%+%_FEul)%;2zI zVqmRo4W4CJtPY<~GS;ML?`V|n7SzMn<>Sp%kt{{#cG4^pIckO<;0U5uaY73nj692G zSS5y;Fzklk1fd=p=q?qFD5-eez&3DL(fvA#imuNiGylraa+w@5T(&X$A0=~Kyq`R& zd6OPxS1t9GaI~#0CA$wm^-y|qmcZUG$KqqY_EzXVkQTqKyBRL_7O-yAyc~^jZBed~ z`gzP<@-Lce&Nx;G;K6<();!`!B0phyP8wo@6q1dBcg*4+Bg?D9Uv3aCP8jQM??kX( z8nRj8vCG%k`e4l|Rm5n|4Xi`bD5>NcNMimi_x7Tqg>a*(q`faQfM)lC@O@Xv1z8J6 zl85~y9jVN8@J(Zb=YM;T|NVS@IIBCT8_1Xi5q(R6=Z=@>+xE-O)9dkeck)cr%j?dL zI~<(n2P*4{&Z@S-FOe)f`bp!%%k3B0X-%3_E-OE>Fbh!fN@GkBh%z0=1Yhv=K}}yB z$c6A&L#Z^J3uxhYJw)VY61|(ILJ)YBaSN3JZ5mHcVt<3fRTYy7 zzTncV1O)Zk($7e37h&~ym-*-n5B(!WJ-p!yFv%B2=Cu(=sy-TOqQLW7gBjv^gMzJg z{T4P#?Azlc^bxTCYmyz?*+$W}6i>0ot;83WP^1-rwB9~xUm^vNU7K|*i(7D-eMxtw9(-WwMhj{TCSp=r9=8y07I|I>czux3|U?y9R zl4f0+|F0#=7TNr2XR(LNSgtQ8HfpP?iNukS!2m zL7HeS+YKYI480WkrN+zj3X|cGWBYleiD(&%DZxI8NDrPFGo*b>SSB{i&N$>Ux_%P> zj2n{4ARao!V*am0#{AmS7Q$NkMc2j5%*573y*Y8tqzNX!5AVD9Jhh)FlwUI&zuHbw zY~@D0WiMxpa)xSLp@W6Gc5K;B+fs5gjH}o*JJ0=wTb1k@;-=a|5EneN~c zUZmcYg};&25S~#V+&=|e?VYXUt6#w2z#=_;X<6HxmIZEK&kRPRgv`Ss#HgT%$!ks5 zo)q*Eoew83n3;_3C=aTruTBI(iS&9pfe$5t)YjIvJ?z&U&Ap3H(Fx{6Iaya; zp4?ntwpOpi*`W~_*yyd#fMOf1xn$Y=HmBdqiMkr0A(j4CviYS4AI}A#>(9)<>z^lC z>qxM2oMTCS-5Fxw&7SpZg*(0AYJri4BXSsyp(2=YKf|Dv$cVxVeAGROIOcN10fZnY ziffYYvtAD5ngom?$7^K7mI@0_dY4$!4IXh5_(eW9dcy?* zHkZtl0H-;pPp)YHcD7eT%0h0Icx$)$VV9r`mvCQyXiwgGecGn(d>qny8Fj>$XQGLo zojU!3$`!8J{&BNszw2L8yu%lW}g zBtBsrZ?q66>7Ie#`1@xMKY3^G8$Y@T?HC@u3=vw}EkdKPuq)Kw zQ9HFQSfI`dQ!#d zu06-gXT`gQZxipkNgO%aMNRB&_QmbB#V*A@;%LN=!nzk+-6A7tb6MUJr!!$N%4x@L zNXo>woxm7nR$3a>-k9&vzdc1Ka)l^GPi^84-P9m@u^AWdg(=u-bnNq}ET&nO@7*^y zGWGJ_DUz(9NxrXYlLzboko~)ED?#-8FaiNIah^w_LVerd3>n_;-9Q53(Sj&do%K~9 zJsY#_mMt!m2!9=&e?{;VRx$S7AvR}RE7VdeHfxVj|D2>{-{j$Ro`(m#ukBw+v*ZfL z|B!uuaPr9?Tf2hM&_~+~93)A3jmwIkP8W6PlSJY|VT#$Gy^1M2cUJdbjGMXGZQa|E zL&n>?Kkiy+b(cS=?j4?7<&SUpQY0q~zCF{ELEN_mtqH^Ec`I@*xUW_ek?rXAtdd80 zYeMNuInCdfkCXSxBSN@nI9}?O;UC`#eT{020by-}yjA7V=A&S=`fd3i4x$RX0XM!f z2c6}VbE-^w?PdJLp-e*%buLZ#9F+fEUc!K%Ux7!5T7E7b8~6?C|D6b2o85YTy~=;l zg}DG)-1_}!V3ygXgYK3 zz9DgB)pR0tTL||owE}gN>CToawE?P}IEIRYTJ0RHUSz6}8Pi(~kyVD?C9=Y@>?7Y- zrb%)9U)E;A9xftid^sx}=_zq~eJlI|VqyOu*g%4+6T0d(M0$)nuU_Z?KYC$L4_5** zTHia@H`={KULR2q2OyyvjeyY6z`49>=*P14ysEuzXmMKIx3P&thL|m?`{{s*h}BCj z=CuXxkwF602)r~-QSMVM&WgvP!J)+aKR=ukA`5U)+hT1khnNK|2DvcxdB zwUXrD6|CP^QZt^}6SaGv);O+lOD2*IMnh)TbZU~f3>zA5K{+pcCLj0NJ8Urjih3c4 zZ5wGX`D$L|j?AH*?jc@(2in=n&i&1h+~!}Yt)AK%C#pW*n3JnKk`*Na`G^)}@w!>~ zEJ`JUio~CIGS2*s78fC&5`cMQHoq~~i4XfPMK0H&KJ{}4f;+6N$dGb-0dw$hoX>4AV}^Zaa5U z@`!U|fMO5#x40jk?Q&jOjFo_p+7xS+MvEGf%Z+IwH>8XGjT5(Z`OF*5h(qWyc z6SREUb6M3-$9euwGQQ8HBk&c0Yv_mI9nSy1zLYTfi|BE!z7xGAhUmYh&!j_Og`r(cHc!lXt~(`2Xi3z=mwKozO=9g_rQ6UIbrkY- zoBN$Ww(+pju@){=Na4-;x9@!h_5(a^kB&$kEIk|cr@Ke<=a_V2WJL5q46R0>9<52i zV1f+Wz+s3C<%`5>9E=5=HA_r#XsuX_TK1P?TVwecF@shDCbO7das!Kef&n{o1UMyx zLuM!*=IvnQl_81LxTYp$c0bEfyP)3-qOTE7YBW~9V!1MGrP9W>%CQAlk3&k^f~*YL zfs7N(Y>+mfCKkD7M0H*wsHBFno~DJx=TqzC_1?6}`<;G9bw{D&o~@%wHi@)gQQQ6- zg2pj^K&EET(C$nV_(>s5-GQUP4Pslc4asUoWVnV6G(gq6+aUViZr6fFp;#tLUM>g- zOP-OGt7l8#WZG6#@gEeUnTJo$EDwS6onzm;^7|?rj&{9!0u(3Y>(-C7h5iQ6j=+Uq8stvPN zNT+=EJP)vUIjX}eBYe387{vz;6+m%h%iuWBQ&GAeo=e1iUd15H89!Z>rA}Qe-=RTq&NvcM zR^MOp*K^Wz=`Tz(pr=W$!XM4>2VwjOstMxAbNGT^dB>UEl50@LATckD*Wtu6>;Ft& zQUj26+5b3U_WxP`+XB zG;~ySQnB&s*q6`~`V@18n3_xB3q3T1RM8VHgp zF#pG+;I<*%OBD_P02TuPF#o%0SsA(+>KohJ*q9o-SlZh;|38j*tz&Dy#fkiLqd(x> zC!xJ2(fN6RXoEg(w$t^L#2T>!hQ1?IV#}pQBh9BMcS!Tnw>Nz!TC9glA-6@V03vGZ zWX73;#g{`=!eh+f=P^LO!75h=zrM9lh4DL_CfBxgonyMBN~}SD%q(rn6L#A~3vu|Q zbqvN>zxI_ibW4hqOgyW^F4>+U)JSZMab*-o*Vw{(Gf7KAYm{h;@l(&MvX%<%g1%9RV!8>ILE@7>2rWD_-^ZV4mKY=c-_eN3> zc4rU7ReFg2cBb=CI6eU0fLjMM@R(D#!q}jbwBkSPrg_ZEeu~@|5Ldu5EAF7}0!H3* zV9kQv`7rZFGTkb^H#clVD=?T!Sr`haP+swi(W5oRdNIvI!^((utmx$)+TgP{!npdY zhw#9&WZb`v!ILW{eOGKX0mexeTkF{f+;w_Lon{bCDVI(|T1E&8Cz)|QqZYb`C-kPU zpZGv=enUJJ?9_Sy1QlQaWG*!-yH8FS9~1wCIjm+7An>S|26AU@b=n~D&DXmj>SVlqLsK~OcptN|m*I+5CJQ34SB zlK#oSjQFMYh(duVlT~A|xrJgTPQ_2&EYb#vnWuZ@&Eh_`HdC=dzYrws5+9;SEOP1(RwWH)DF;%( zbP6z0-hjW%wjAgI(G+>I>MZa=-{os{_oGdF)-(jb#UM7Y01l`x1A{9L7c0_GYgFqW2IUpr9@PU;;Pmi?dNs>5t$P>zJW-{YkA(9$eCDW;!a^vO}B5$c?|7bv`tN~2{0Hg$F7`u7)B@vmysjgZSTznLR8qV_oj zu*G~BSX77R)L{s{x$0n@cePZ+Od<wr5Ixb8)(cT=A|cl5 z4y||FYu@*?p9wLL0`X*hl_|iT6Ya_boE{_GKwxJh0Ahs{vO)+-*)kp-N*ymghV6Q(Inw&h7&@X%Ib z-`yG)e-%hRx^4YnU<+wA1iwGis~OxL)Lf))-ZLb^OPNmMT!-PYtp}@R_S%$zqKS4xvoJ4wRPAfv&ph@p^y$$C$tr zQ)%Z6d5B3qlf2Z3Y=|-sE6o%KSBZ3(+Fw-wUFySk(bQBCf8h~_xw-gRpXiymImpYG zyu5Y#jeY>+P+Ol}aT}IhisH6F2q+O|)2I1nO~w6Jy;~6S`NguUpIr)m|GZdebkS5| z$#F$DQU_ZILAo!M0VEym{O?+Sg_jQsWv1;&&fi~g=12-NZqTbCgbPk9jw1wl`8bGp zV?R)6G`Ebu)V$vsJI92FU)-JM!nq*P+-dcKsJH=Is_Ms6li(G$Bap!#AL<5))HWhu zEtFGt`E2-K28Se1Or@1#{|+2=YatMJJ1vFtk!%pK1D3%}%vt~rL%E1ysOkw$uNG}V zpj&C=EwThN0IvNPJkdM2Im$xHYBM6J3XFq+$~LiaR9@E~(5_Y%#RvKmT?3mhhI$b{ zOk>DoO#e#iM->QMK`L4s4O$H^n#U1YwwCK1g>-*jNPItQ$XX=pXto4CQ_l6vc2j8; z0VoYsR&qMhG9-3bZC-K`=>Hgh0-@rNIUOv8hZTuB)=5$6PyZxF`6CWL@>|={biU$p zpc$-C-!2}2VLMdJVUk{!8!~ggPKNEu-2G6#^}lo;+JazZ19dH z9<3~wVFeOL3EUTG3J?HB(I__?I$c?8wCof6=^=-si|!l(TzEPBT!g&5`wrC0%6H0- zP&(4NmwHq0ZH}*;5hOR4lDOVr-1+S%>%})DjI9mes_6{F&vw82W3NPAk%lQyy5cO| z?|QYehSG&BuJW%RuPmyZ@Gbe6x3OwTg9zl9!7T`r|O+`m-X3*$&p)c#mxt@^5{Bi0HBCh#o)37K|ghtlDhH zLB1{&tiwFMV~DG19s#y2EI7#=h>&E>ps?_x-K}rYU|5r`F7k459d<$G9Y>8cK^@=8 z6qy;%Yc3>pMasI3Hd*h1+@RQfZNJ8j+*ZQ9<*GGjDV)yBR6*))=a)9_=^8QpUuyA5 zLHMbb#@}AojI?J9^+R2dp5nk4?Guh`oa|5zRV(fEuD6ZC;}sK}NxkdlnGy!SRvQ2N z$2;VVc7g)fQTe7oOY*&`gHflqAMuwmilGJc5~jorkk8w*=e@Zrg0t=)Yucz8gb-h- zA&zC_w+;HPF?qVWFtFkfyL1Eo@kF=gvh(51V2bh!-ML)E*mNzqGBp9UX(C=D-J90QkcX-hp@y~zH&NQvNxEM*FF8HD zJ1YwO-lGTxx{d7oj%87sVP%BRM1$Xx zQy=7x?k8KuMcrB{?Mwlas&iMcJIU*MtZAGDm?7nxxBNmOPo-T zsk@zZ8*l{KiC>%oZs{(bJDcalb519;SD>#JkWk9cA{gYlx9B0xpRIFd6zpdmv(q{Y zti?!gN!y#3tbp6;Vuz_gRY`*jIq7DWo0&Mr)OC@?*2=UFZ9V$!X8vQ8t@VL?GTX1Q z^3&+0?v=a>w0ba^QAQa*?%E`D*pDGEJ}h^uS+mon4`a>WdzlgNGw{sUMt2-#c}kSj z_yTEm6!cLa7XRAQpUv;s_F0jaNzerDmOjLVPPhEYP_R6L2s|h&7Ft@Jwuz|zbFb&m z;mv0}=FFrRU(C_ROMA9Ktytw01CxR^bu>WH3wATE6i)>X_DiaS$-Bm?W?(CJD(ieT zQrjVU8-@;>lRiI#+b}22D|+6SbE9M(q_y5&SI5ulb-n}sB&z+T5!6LYsVGGI4H6XL za5PIPM2#b6?(}Z?eLkF?f=2ai)y$XR#%1eoM|dZI+~1Ts|F}giL_mQUul<`9SjnQ2QFYpVoD|era2t3sJ_=ar1}O zj*T0^hZ$Q3KXzAV*W2aujGn#X&i{Y@mj25f|A#O_)usvns5t`tk30Uq#x!$77gKjb z&;K6L{x9zMKOI}=EwQAZS$!sH{gohiu&KBv_bF=M1hP%V!nRaO?mu-fxX8lD=#Yc} z%FOJCpO>q2PY?-esT}E>p+8iSpk`iPUY*`~H+c&v4X-ZpGLM`Sry!0jUOlh+Z^cX3 zZgUSFsnm*R#1mGkla_Qh$zDVc6XI+&G#=E&Z*sem_v}~k(!wdq5pJh zuSD}&_QG;!wKd6|k}9Q^VKk#nne~^AvL>wtNS0y9dcD2LE49+R6fz@K(W@G>PD?I9 zS^stD0IwZV&fTb}Q823HMAO!ZR+PRwRWxE|?nxF;dde6W2jQW1%Oq>3)_V6JDOQ-& z7y+rouc-L!U@*MA{_bAD9z&a{!Vl9Bb}3Y7)s8geRLPK|N!17^#}{r?>ZMXgi8RtW z85pFJ5t^cvKut2&dQRB3C=>&EgA=HwEYRadYiucOQpt6dQPLVFWHNf5|zxrYf z$5lw~S zAM?!pr4GoHqX6Mr2;zCbAyT1}`-~01n{dzaH(1vlgaI!Lj!2dy~$5kjUJ~XAs%Jm0FZ(AUIX{hL+4Wn#q zhz3`@&8Y|%FE6i`i|2!{6JM_fj4ykqV2UGmia_5FL)YKm2@6j*Cl5Yei~{3q_-PeG z*N(rlbKqUA5`Ncxt@83k)IFBp>O-+s3O2pVO9i`IIN0}h1dSg9vPK_2AXJR>9skJ8{L>O(bPLM zJ1QgJx^NQgONvdU2p@`B0Ld11L~H=S0i~#@A|l7Kw1W{-c%gpLEnp8E{J9D}HeR@CgX}CcZeEGiM0wGrZ^SbHhfz1`B!xo-r z+ITf#!goNVcNseO;`zBnXU|-p9p_J=Zu&fV0*ZuvRn|8U%T(3JYc95u zV9PfO#mc~Ujss$#lMpd>nAjt;KH9=HFRUJzS9G}CP&M#d4G1hQkWoAN9=mx-D?9_7KGexvX0(EDWywpcSw_W(SDD%w-i`^?SCW+iR!g_wTP z8KpmU`9&aC2Ma5A4y-Y=;kSqcxIm2&yE?ktI(|16{~|1umKGmsAx&$k(SY`W6x9M1 zX`qR6%7IMOgS4P`l#apdd03#w_LqE<0)JGrN-wOBI&}}*M>uS$?TORWkw>X zTjrutUBX&IumET%S%^6lfv4k(cuCa@Jm6>qm5i4IVrZWx1yUpq2!S)tcMc}XU*ANn z#sO|be&OBR*3L`llq&V^i>)OYlP zkJi#ZolW~h%N*JAcsw72YzjP8FY)**motgf?n-3Ak!rl zL^Blo&iLW96=Ow$7RkbN^7eY|alSx-!nQ(C5taf~3u6Lz?(?CNB?h|}fq~a!L+Tv% z?4=q)0rCwq1J-u{sVV)ccJcX($;K=a)F01ZmJggUUt#3%iQ6cf2kll#j2k zFXz|4u?NBcCTUIkOlkXQ>9-(9;RWmtZxFNZ$LRljQq%jen!{1CWaAQs-q8?7z9z2Z zB^II8XG8HKk9mO$Jmu1|XcPu~+$$+7j4T-4-P>X9kX`oH4T%R(m$0Ouc#Y!70d+zr zRsbBv1x;1UEt#bSh?83|yKanA5T6K%1@sAh1D9$7!WQvu74k=6SbhZd(an2_!1{~p zjSc8ykF+L+_?CR?RpGbmWdb!}R5%enSr;5sDJ(sp1N{&*U(FWP)ECavjt`ddaC#3w zTm-OXmq3XI0#CSOvg*oPxja;7eC0ChPdRzGq0dS{$x@*mM9V2ulywsLyR}03N_Thj&~?wYHC*_M>oAfzhNsG|VQ& zCgDmet7-&=+oe&hG%3mA^$%RJg4>f~ljZ{gL5TkmL6sah$9E%AI-!UFwe=fwh0ZOd^_O>An<4upV_eQ@RR#2g1x)h3}9iY89 zz<1L~R)bL}x!H2fk-`iLxiavHfRlmzI}#bii1b2Lfu(8Gi)~U|+P*dtV2NO%y|KDx zUp|?H6Qu#%I+f&KG2YFOnOb#?Dglw9{uQx(8_~e#ycJOnTjhzQHj2HE*;=fOjM1Na zkCV2ZaTO+zXB|#1#J14dn2a16ODv-i4AT?N#-$w6yn?UcN}U{fxFf{_^;9(7t6IJzXZg-J2Er=4D8gF1;E zgQJBKaj}6`$dMW4$_ETYB|?MlatNBdx`1Xu%hIrdL4O7VkZTQQ9n`}zqGaScC{B3d zt(O@p1#=g|q5zkxG*X5wTqFvId8wetv_?J#g{Fd4BiZjunXo~7Rw{Wi=n{QuJ?UU7 zga*?xJXvKKeGnNY2j0^Jhk*NUu`1^Ij}MNOqZzcSTAH;B)u#4l!L6yO4quoK@fWgF=TpKU-=s%d026H2~G$?`kSV7-Rf(4Gi|1?-ucRbrh$G$FC8 zwOG>rzrt%qx0`%!s3PuS;xZh7LH{(_p{;w zZ=d_(1)Sf=WOn#?RfDx+4tEVDwoTKG>V9CgeY+>3h{;l28+r30cA8+V8kQLmlgPZD zlAshzG2USjG>S4REflEdKFR7hc(cWvcx8g{?Yx5C%BFZ>%00q!6v(ngANUhfT);7OY5! z*I*6mlUx?5>XRY(8_!W)rh7-vbYIrU?bNZH?5As-; zAO}#AG`UlKofccd1JEw3I0WJ=JdssR0b7k zw7~T_K-UZ*VNwDZo2xqLTGQ$}ACk;mmu9qOi8Cb^EZX)E{ekIX;eqVW5pCwY+r+qa zN5{0UWs-VWO~kokpG9$C=ufv5A!-@*+Bs=^#4G3?5r83U!(1GFpy#LuI%nYTT zVd~TYt;&==K?A(19Yb%6(5#3mRj`x{Q7`7ynHHC(eQ(7AGTgK6{5bUk#nFtW;A1nx zFT{I3y`=S-D?Hr~Mrc0Xa-N=tw~3~`J!rNEFnCg;+U^xep}IEn?;G#~B)HigHtRNK zY1&kENdbAUQErIUWHk?TB#OY&06?-aBPf5VOfJ|imBuAzlT5V7rCm^Op&)76p=>py zxx?ZHfIE~|Ni^NiXak3EV3qDYHE)f9zOBRR4}&C?sgwsBJ357&SS4sK=R;ixAN97d zNJl_7jn2MrQL+u5h&dnGgf+s$3Sr{GQnGRDLdeRD)E7r$|C>7ZY3T$T7mU{WwN0 zSrT>xq|q1eTa*0oD0zB@EM>pfc#gg%2(x7=|Zn?5zyD+H(z<3 zsbij*ScK!An2WV^Vu$9zsRFLx5)QY|$uvBAOL?mZVVfWW{}E3S4WL6PH><5Wsk6la zSxVnm(QSo#-8$mq|MryAEKzC92#ABvzMw}Drp;Yy^-9%c0?}fnV?Ko_1X{#m?3$0A z)r~!hWB+uM$;0WS*ld+DP8|dHIdgL`!@W)CZq5Fgv!f@3&tzNGa5o;KWP2t|r!7sE zhlcQbI{vvMoyF0(nkP#61)hHc0=LW4uTO@j$J>k3{OZk+YSTMQ({2u7#O1t3k=k+k z8*9>y0UlkTUPyaBzi|j(7poi3bhX-Pt{L0N==>r&8=k!yUKutqTVhA>&Xw3PcAoTa zdZcca>*b=eD49u;;{g|gwKjG&C(-^6>GxWbc3`QE%^MB2+TKQKzZs9|I{yQUw2R5+ zcJ0}(cjD_oL+=Y>sCCZkK~o0{X&58P;9Qh|xLJQY9NiAf|dOy#SoGWNzj&~9bc1PqUcA$U1xOnET>h&eA3EIARg zM}SGqC9=XJhg3lYx+!N=@q(uDBAPb_-CEw1FN_D61*Y_{j;dm2tJ!NRxg7MH{2i0= zyL!Ui^muyIT@Xks3c|88g?YokGXYDsp$fmO7Y`_vQIw>L%mkx+7U)LX#-YTMcP=xF zsj^(5VCf*fq;|&MC}+^PP(X=MPTsQE9NFd)XeI`>`e$m#hT4Pcg&i(554+mapab0f zIEBr3QbmAjwuS&grd-n`JL&5i(?{5;nU>M?b){;{$i2oQH@7VO!j*3{Ws@h+ShSWC%Y!X1mzv^=L z$Gc@nX-{xt`a*U%S{LdD58x)%iJs=gEND)sq3^yAbqogEjNpofuXKv~PKK$><{`06 z9@EXhRo&1--Qkg}+Md$@H-zdJvLq6A16$mZR@=j-T{U&n6gajt(z;32scBx8sw(m< zrlN{$v%k$b`3gjzvAI@E7MUM<<4+8>owWXeZ%yo9(P?UKVtfYXcIisiqRj76B6tME zyTj+n!*m^p3)dGLpC`d4>0@ z&1n+LVC7N}EHGxcF4tauT#p{IjI_lN1$SG~HE$-vZ3Dh6NyI=(-9AzPH$Kbh$XKtj z80J1%z;{`!x8GbAY+2EbF*HJI{cBdFMSmc!9SOW}ruo>tGA#ms&H}ZV2wUO{ufM$+ zMOVk~uD)#P=zbo`vmx$XD~VwzdI&%Ta{EoL>x=00ieMX+QujH$LewR$u19p` z04AVeTV76+iL*sh{Piv%o4O_5IjNOV2z~NrhP64I4IZ20b7%`&QN#2!OBTV7)S!p{ zQiHqNXA}&eo<5i09eu+977MxhLKM%xMrp^I=b|ylp*|w z9loi0X_ROAYzCX!bN!uV4G;D?ZhGkE*n_vmcugE2<3HUI;!*^Yzb46F`)2gM9kypy zeY#N7TkqPnUHd*sb@NWy5q8k053B{#u^bt%c@d^6xl?#q#{88JN;D=0((Cg9yl6na z5|j_G_0}cDJO>PmI6Q2$E%iRAYw~EEYrBtAZ*$JXyCK-p0@gEkOr_ym%+&dYjqYXC z3+K=}-sHqH(8p8V6Yf{lxxSoXsWsKYUv+bvZ$I4M-+yjGJKR1iW!0qNN4&R|5QzUS zc+qei`i7dXrfrMXwctrd8+d9oXved65(^#Z#mh_|e+Kta))eFHI`$kVRguayJ4UPX$NUs^ z6NjR|qN6W%>)Toc9kZ>u?j~QXDew$K4Sjj4nVTvH9**+BHV)w;H#!KzMLdrwo{%|F zHdh(u;@WDiWip_u*Pi}n4|&d#v{t2#CUK;yaVm&M=mKng`Z~|MbU?@DAI{eBkp2vf z_xlwt1fKVzBp?Nl@KkS;BR)_|EZ~VXS%XJ{-DECtcO*9O&<&o_Hi0?pE;%=({3O%y zBViNO_xW@de>umC1y8#zIAImSJ03(yPqNtrXWR&|?MYKTDxlzsb`C8AeqfLEvEJZu zIH0bN@Yg!ipc!w*fo7atbQ;|2$E}l1PD3#@VS4gqcKi3b5}g^~hA`12m8HqXjFH6H zW*}`Mw>=9zNaugr`|{>Cjx5jr^(m^Mi7?=SV5y^LVuIUr4^7cBZIRTHl-&+l0Th9v zSQZGtC_og~>+in#uFRKL0i@)scM#ncN#uRJeCO{mErJDrt!u1!6>(nz($vMv*%^ZL zLw9CaPdZ8h47HzM9Xx-L4HCq8xMC+?e0{LrstbjW8NtBszncmkgioaFSkRrX zsBDV`2ubUmBNI`|4j0&n;2&#Q_n+^boo%(Mrqh49m1MnKwe0lpz6W2!!1O|!1{76` zt@ZGU+{L7Zb9_fv13{O{d$I6VCstOqL^3jIPB-V<5a1;SQfLc@hKBHSxr*X(FlhOaeMATQYt zTFi;IjNlN=xhwPpa4tXf3?SMJ{j5M^1pHOMP`(06D}-z|y*>os2FF7?fMU@{!b6ig zjFn^X0quLDHnQ=s4JszNrHwHIYLR0eRm!VZ38YhHo*c)SLOZqlOPEK9@=stI-` ziF+q4Sxy_z3B8&eT>x(@0CT`yG!)Y8`tq&1`I%Pql4)0}U2jPK zQZ?%_I!ejPp1eGN@$!85?a85~%L9PqGGp(oSC^0eOOY?EZ%#VWWFum9xv@5vUV#T= z>VJBAaymRcIDdJ1JUlu$JKKA9(AwR@;{4?4NdX&CoPqv=+7nBlb+X1k|06$A{JwXW zL$gm{3V}O0=sHNW!O4-}mT4m(nn@iIuwhlCD6)4iAg_DTwOpbx3$#e)4lz288nR)l z{T4KiS|??>$H~Y+VK2j)U@U#P-jTYd~`9LB;!O2Dx<$DN8z6>LZ5mp-)n27;TZ# zBxQBM`&J!j^+!l&=4J(_M!16406MYSL8oN`r&cQliLnDB4Kd&gRqhLvZ$Pz6!K!a( z^Bd9yVWg=*4<8jM-GScuy&I5-xgY`zP_#|b&2{2c?RL^j}0$GkB|R2!T6>( zpgg#!fMF`SN~=f!Vv~A|Lq`~v>iuF`kLwl7ht7qO7vbX8nsOwLL;)|d#(%Vp-aF<> zHz@Oia_syV77{Ghz#&y5@hVW0D^dfJwFrwg9FA;ffIr}2GriQS;m~5Yl3OWwLld|3 zDL6Zue?DcmXQgppFU%PHAwhuC^s+aw&4$%oZBuSwgBG1ryI+7t0S`6);NJ1HJHGhg z2hXVPY6y-Gh=*$aIE2WE;ZHVPyshQXU4NUF2}@eFY(*T`ZU9wI0}j_^<^W?KUlhLGxilmy>N;QIY$uf#rIJbyY1y1sEcb_X^n)__3>whvl;R zpa*qF4AglTZ)(_IzmZLA-Sr+H9HP5~q0{!}TPWd$;`kn+0AK`R81A|8vXZyBSSE+l zx>;?$ufgc=cwsrD0{-!206hkOaUT}nNl?ldW(n78B`!G)1q?S9P{Q=A2q7x#g4TWn zcZ2%PjEw--gLo^w(SQS*{=J3$O;&#Oy7h3{RT%&RZZ|{EZJ{vUH(Lsw-}OOu=I>cz zp#DUX)_i+@`Gs?+f!(=w7BMI8I}{rQt4$sf+AxF?a!?`kCAs5(Q3Dm;^vcd>m-QQP z`YMZ&at0c;oa<;i`hqOE8$fmUF;Zq%Lv;5G<+_%DPr+|-QOHuke<0v?0fVVP`7i{3 z+Ef!wGldEO@WJK}j!U>U-0$*sXrxV3ANZeyF+^K&xV94n@ek@;37Y{QF{RN;kh7Ek zZW``}!S1dk*wzK$R^|kk-WU`xGXsx&vLtt7U{{OL+1c}bn%@ZJHLvoB44YT;W@YU? zBZB;(P%|DAdIX^^#H-#VMN^M7>fA`z@ckSxVz|9keMyqgiVQk%UV<}$u}p@+gXnB= zKU0eo8HCdzk4VYCf{3Q}j*2cEf*L zg>DsA0({mBxN6wnZE;nj?U}4UjTlyx`zPA_-?}o<$eLn`*dCW=qfsbJuPR>C2vEdr zl>-#MxK(@WcNenx;z)Xn&ut)^SP3W^3K0l@Ftkhl;&H(cBmh2=#09a78aR4}w`_R1 z77n%LgwZeT;Ui!A;KcAjV}Qve=%Ng@M(r4gA<(KX4>c>;ZQex>bOSz4I+AASI@5d_ z=P#sI#*xWrljJ$4zjbMynZ?5hE6 zFDNQ1->7;8FA7-}L50$!e4^AG6|vuucy%J%hDv&8Q7+SUIjI6*AsiB{rlYL!VrJeD zjdD`K)v=&D4d5scx`Ee|au)1*4&hc4f!;Z~nI0Svz6-%q!S_O{u#^v#;>GekH_W&{ zFCHVTNW)Ms7&8p4%NSi3%gdG0nZn1>Y4KQE&CS*D6AM_M2ODQ+Cz8Cr{JP?Z#Dy7E z>LKz3`n8}&mJutwvEhp*_!WX`#%LEIT7}fnFE}!6@;vU!Lm4oU6zyxHf?YlP%s^H9 zOQH}_yj?{?Q-zpJ?Rwy`4C!B-P{tXynzw2LzCk`{z1olITB^K7#1s5yYmf(wH&NgC;eZJd+4By?H z9%m!lEZF}Qc{4jd(G1)Fd-2%6D|c$Qb4Rm(gb7rU7Z4RGBvj7wbCftjcelC{e^4|C z5?r_$#9-~Xx zASo~G#-j(8;m$;R@gS+rN2p?o8OH`gK;3zKRX5ps}^yp8MZA;-91sGoZS{9FW=vY zT_Zy2qI5CSyg&rzYNM@ZhSWlEx8EJgR~J=C24S^B9-0Jf5&zNMFBn;*$TTZcUm}Tz z#mC|YuII3WKB+y+Mlx&qM1V-?TfoDbTZ~gD*%W6L#Vp{Fn_Ev34*RSxy~Z@0tTpFJ z?osr0ye!9+!pKLQ?Rk8oTv^BngLtFmw7~p*a#$FIbWDvqb2++dei%cKY#=dXe>8^P zJ{IO%`*WsF+*J7J9nA zzP`m=vA*;Z*`1=qo>HnDV+c1Jr;vd%_hwgRU=Xm52m1ny$Fsq-O1;#t!e5D`}$#fjEG~w;eND^I9h)Gz+rAnpkyD*H7 zVIY|3quIeS859nq1ZjS5n*6e{)y-ACg^u1EC;JHBmaPR#!ip=l&YWA9FV#vM;QMx$BoA& z2W<>`qIk@+*Nr?mn2Of|!@BOZ0zNf@j7~nh2D&!?sf~3{^pw@aQ__=N^0euy_deqo z%ddan2@D_bZ%SCCBszG>zL3Rqj`7$LI*~GHDw9>p&ZK(O=iGs9*|39yt3e)Hj?B zH9Dv9Ij`Tr*$y<}36Kz(r?`jPL}$ScAzBEkWF3%(s8?B55hNvr@oODm^oWJ8B5l_W znlCAMIqA~ir5?HVy|FGTr$WRrOzX>vxG}aQNjeHNl?YJOs&+O}*)~V^xpbQHCvHV> z{kOW8gJ`{b@uAT`Ki(GsE8uYG)kz9TN%BAmmBc8i6ugY=X?${`K_zD7PM3akiZG|v{sy>N~)e0t4W zma??K=?guhJkH+vN1*|2f(Q#J#=Bn0R~o|D5;LRS@;bSxID-tDr^p(EIt3Z2^`CKl zSVtN%BAAiK`HH5phbf#g`*afHtI4a4UO_!}j8|2We zrnK|RKtUtHLVmrXJ74nP5V(95I97&w20(#jS$>p!RczirTw;}M#vU3DdbcsD&IRe) zqoNDb<`Q&H)L^0$)1zg2v!-l8355s6I5l&N0pXpB5p6z}qRsym=%qr!(WEmQ;;+WX zXES;!ZH8uQxu4@~hPX1xry4#CemJ}=SLKv|XS>p7FYYzc6%vIDnhf=f@Zlv(F(EB% zL`eL>Znp@vq8K?XO{7)|R4(D`qAIHxCDZl*t(pA*mMF0nx@}L0%`O&}j&_(Z{NGyi&<(v;Q0f~8oXx-gf+^LWwXt+KYlzPIjFIaWnVge-RPffHeZIB3GkKA z<_WZpSjAd9`V-ufHfEl$N()cSjb&^dQAO36x4jY|nA+q}vnXkOXQU2?1B%MjT|c(` zT}o3(o5*bWC*7%|GK%RR~v>nZ+Wb4tSv2`|L_H1&$Ispkjf#MVaK{N?522=^Hv z^^uDO1iI2u%h;&N!aS|mzCr)zl z^Z&q#dF$(%Qp$s=>D{~r1{P&M*A*^Wn>zjvsO&8-hrs&5!KSCw@JrAhzE!&GzH7Gk z>hWvyzd%h(hs?G^|8%*1g*6F|2Ir>-dqqR*)l#1|5-z+C#JTSjc^6DjsTT%3AGday$OY{3rA21#smc=-m8(tm>9wB7$vWM zT2aSmnA`4VFI-9iVvzz=Ly%?)lHGPY%Zm^jCuh}(OAR$hOImo^7F-5TK!_XD*{5(O z4+};UPx0ssrr?lj)DUuLHSoD^-axWWK&C;tB#TNDhHSy3*NM#_9d?=P{`~ltou8ll z(z6&%xs7XI2s6GM?2o6HH|2~_pgGmt4#n1h%$uVerUI?uHJpr&nc)vJGsPcfcgLYh zlfBk4cvbv0vDkYFS?4+s#X0NC(8?J#=V`HNUUcc2b4AC@*TZ6u^FuG@H3TU@NI$1^ zT~wTOoRvPvELY_$(MUi=nmIItg^T%mY3B2;Y6j`mQnF}H3T%Z%N{en^B0^Mz!<4yu z2{ENl!gNY!st|bs%Bd7uOs5>tIlN@?obh@nV|3#|fz$XbN_CRH*4$1=WGx=KB9m?2 zRu2KmYo$ERFRDv`USK45p%)D1&O{jVNz=r1TOkxqY6KA=0yWuIAZQ`aAh^fw*EP4K zBG>O}ZgDmi-PnE1Ey~8Cn`$+;+|_s*SWRUFZYMr;a+1lI?*(=+)pV>+lZxO4K;j~BAazuh4&^Vb}xTWJww0O3FEBxDz6l z;66R0oEAY_+#B;UDxOUNLpW^Q)clt9qT2&~wobgTVdG^<@isfJIq*4#F^19|>Cmai zk8)~hD}SYPYa36V+|z^ogTrqRoRhxFmlu2J3%n27#4$5Nw;ogj$%5GPy~Uje;fT^I z%;q=tK|^9Bcs1xAPj2DBWPM3cIPh0V?Wf)sOnYd)k_?4KA$m{&RCaUNjMYX-0HuMY za;pqsvcpgk1lHNb2K+B`hvDT-5VxB`Afn@MqkAa8i*++FyibFpCpSsa z5ry1~bl=)cU6~$4%lku4j>L1c!BO}yn(@rlnTLP8?}BS$zk@%C^abk4&fv#+9XKew ziV}Xf(}hW~^}Uioyp9)K_oZg4Z3n#AZ9RPZR+!o7xYpVzQgxoSdX_JtE^rk$e zY)&$?ip!!vk}D2{^Ch6<8uq=FrVUgF=~H#c;1lM2Ap>4()|b@#85Hq0^*#w%`U%r~ zTZR^|htN?*$@}nm2sUQcZ}GvfiQW+wAxKnVsInw-v~A?GY3n&;U!#HInJj~$h6u;m zQda<#lr7y!-l>5RkXGgRsv95qbVx3oFU!z{=lmm{CI~?P9nMm4n6^1Uk^i6G-Bf0n zKkA0-L3%O@u@LPt%GWA;Hp+GsOA0_55XFmR$wFo>*|A4PKzN+3E+pHC3JI4HTaDA9 zPr$xmy0c!m0V3eHM+|22%*gvnNi5OA-yY?CMI;fQ2hvCqh4V)dznUt5mbI;?CPTMp z3qfXhCrM`Lq3UfCLUO2kL)q5d;^$w2udq1gZ6YF|LteN9WUL!5=8N^T%sMh;g%2qtb2tFCe~joH;eIjH#Mlf zxe>d;Ok$QalP?9i#9S_w@fh0HS%=`0N}6K!MP*DeyU8#MnxAEw6-c7FzSA=erWuCM zxpaL&FP@ke?Dzi&&AFBS5UCmV12RN-ilZQ&5fws)=@*V9&_F(f3!0;W8=}AvEy6D2? zLEUZhj~+sZ0ood7DX(Xk-`-p>-0gZc2^=xGwW1KVpWuaHGC)y+Vt?%wd&Rkt9!P`9 zcNarBxkAsSRHKwG*A)nXIOWwc)EgItpxhO3{t{Z#J+pW9#d<|OK8W^>Y+U~95$Q9! z;S-fmCv+kH0lw>d%!@C7Vu3DkO<@DjkF=3dS40mP5UfmdxC9Ai%3nZ33Cx{kAj5v> zlkG4(ZsyAdHp#4LGd7RKvYv-1*nsIY23fMEB!!dt8X?h*yiUo@H99$Gt{)__MaV9^ zf_c3*%coS+v4#TGxO6iEs@P9b1W$qjI*6yBZfru`G;#38q{+jiPJjYY9bqJB!EUUs zXeJ+5`pd(E=TA2#dG)t4)~L7jl9%M<<4bfQceXjC%N?REMLWOa^wmh*pbj7KM954` zi`I*YX|JmXF)+_=z3Re}h3waTHqwD5eh0VNy${{FJMY@Z*)b7&a+~87 z23CT|;*2WB*Kqsi)8B@>&~`oc`_g#Vb1d9B?_$e%)BBuc^uqIsWMqrH(p2u*!Nz3W z+nrNo!}~Bb(%p`(HMP5tvEvhS3y;1yu-6tIhtQw()#KM;5GxrC7{=V?9)?e>i~qH+ zOUygT0Lj_YW92+RzMrf_fSj{05|ATAQn2Cn10OU9pzYdZbXYx+JHD4}3MW_}(0M`^I|%Tck6G>DXLTPYA502myz3q-?+fDhcr$YWvPHESK) z(02hN$N|ik;e0^=L}dda<}Q~O_s@J@lIvXGYmOZtR83{IsnLjPHeo$pft&e-!!cb+ zmK_ln$w|bIX&n9d?t0{*&zB-|_ls)qW?lzXtSb8lE1i=1@%h z19p1w<@1C6^TU(lE%;$5jO_5&d*AL2pYI(%8}2{fJ3~=!vys(1Zrd5*?{={CT?4NZ z=3|S`iznoPNbazYUpq$l?B)J`1Xs#i{!tHOa4b5ld|nU)67>rabWW38y@%|owY5C@ zfrBuBUhf`SansX~kl3bP`@&Y5X#=zI>`aNv@Rk;kK&X8khg)8n(xmc zDMk{C3I7EnytyY3$Oj0}$kkA`Fl-8bXH^tv;i~r~G{@Pl;GOuku5Ki%^PTybTL1j> z&x^TPKKRRkUG;h(^`h3v)O6wseQRqi&2G%ErK(jctylGQaEjk08dBd%IuMuLO$J#v z^1O*793zd<=-t4TPK#Rm;#L~#tAT%t*5Y?%js6PKHV(TuCluXsGJ*f2cSkuTMPk)_ z_JS7M(1pp|b&Cqd#qHmI#2L$Jf}|#Zkc5AupKT5Gx2_tr;jhiH9`Ig~Joto#1gK`S zf9&e1aydC;JVwfzf=6L&PZH@c382l+|B|CcRn}@6T18}*TB(|WK1h9_x3}{(*1X4b z02_jq^$Eqe!V^<2ZRUD6WM3Eg&B6CT7XL2Ce!f)a{qv5eH0G7asg3z-PjL*?l+>ma z`Om#56*Bx??lLb-m{}&L0k$3fp>@FbQvRCj_3nCDPeR~RfjN|~?;%?oDK6h&QW=ph9{WUn9I1T^&$tMw;@lw~ zAuxwrA&s;I90?r>PiPy6(Dsql1}~4lIR<=+O?m?vR$4XS9h~1xd!V)=wBE9o!1S^I z&EB(vx?%X92y`*N=Y>UH@I5#N57-X*YEb}u=` z+T08F#pr3sI=Ol34QVz9Q>Umo=}(xZ)1PQb(nS7lWY1e5{iSSZ#(iwa8y0`HyzN>M} zm4Y=qJos)qH=Zn&kA&~Gsl9sqI*Rr9lL*1>Rx|{nQGI2vo8hcMmfO;C?T_itoIkb+ zwJ}WBYN)^;Wo7{;|@nP!=WYFx%+W*zF|jQW3RbZ z!{f(gxU%*KD=ox2QlBs9WLjxtrv}cTXqf(e=x8d z77s(8N}&J>{RKw&MWBlI_>;m5vDdq@fXqB_n1T`z*yg z5CYIH%cLY2ta`-3!4zM|S~QHqqnqfDiTpLqkC0NCP^jFtcv~B&Hb+)7HkqHBuDZPu zo?Fl1hFY;DRdcGAb=m=ZwPiV}9?dT=?IA}a)U1XFp0v)8*<-%qWXX_jVBNq2-6Fn4 z+)60RvGnH10r}MVmiz_YHSnW+4|CADl)pd!%Dm4nPEY>%eIl_kK2bvNK{K;64hQ6QI@;%Ybhi^@RO3pIDYYX7V>TB@j6lmyqXcjXhrnr> zrL5P^3HY1d)Q#H922!ha2vksv5%D@?H*-f+CQIOvn-x(2k9-*}&Km@N0lQv-?+-St-!&NRo8S zv_-jCy-mx7aq8si>cv$I=p7QQ>+^YsZn(xBI&w?5T{o`$9Xd%qh}r&PDMt?=*%mPg z;lW7YP0NULXD;Ouoj<&4HKMd0D%lOA)@-%DejU5U(L`iEsil`MW9c+fN3-t=Wqb3# z?S=#|`{;n=l>08wbw+#S>bk#Wo*0)ZCpl{QzU$rVvud?o1Xw|S`;e?4n0ZTVJ+&8* zR%=)eW*eT;5Lc%qPe|@wQp4`sZZB$WtMi&@HF>kT8=BOr>GUxQG}X}&sW6obwm+`Lx^XZ`DLj(Gc%^o-Ir- ztBw#GEY`WqS=nYS{_^qpDF{M`$LGV7FTOt5KOY`G&2EXkEOt{8S1={Dg^2$gE8<^n z@&jEyo=MP$g>k9reYVO-mkUR^K4dAopqr>ADFVBbj zC&%9&oSv)HFPmNzzb!B1Z$p$~2m_d*sGIDmq4rL{Q3_`-p-WGm;f;a#bP8fs)W+^4 zFUZ-!safMY_xKKc3dJU$cp!47>)qH7D9#;KPZGrrpWY=>#*5)rN+4ZPp)6vX>QM`b zaXO0|IB1>S8X5O`z;R*c^>uYrPp5Sg=S<5aLZO#Was{OlpVz`av|t03{TV>?Dh)wOcVe^9S01qgv7pJGrhHSyNsO)GeHl8Ya+vtC~Q_q zF<*b>PbZW1#JSzK$vhrr!hQ7ut)le?<}LI@i1}5I(+*GUmg8tkC6eryIMH<4L!Bxj zRFLdriZ{ca+bnPLvcgRk-Ig`0C5+k`LeR?1dpqhVhT7l&ktMU5o9@5%_J~H}%d>5) zspk1IhS)9lb+>KHgC)q)LeMM0;{;5VdP;UX<8`x|U&F)I2@>An@`3HZ)I~U1WXRBm zzc$nm5>AoZbz>MVSD>0TYs1dDvnY7p?2>+6PS8gbZa!A9c2&>~PwVl@rDs;2>+k`C z(~Q7T&|%)4fM*CefKD4#IldD6-SvF3#w=OL3L@!|p%ZtB&Tim&UMjEx)YM*7rg%MH z_F@tb3%668EZ9Gvf**)*3g0>?p31CdHe0y2-|JcZ6F|gJCKP~LoYc*#o|)-G1LTt@ z?pov};`yn;r;8@c8OD1=j6i4GE`OCnNwt?M5cMq?= z@|aYE?fDb|F=LnDYSP`6~%-6}Qrf0W@i}rUHv%{Fb=%rx`d_MB}mbJ+8K@eT2b`3m-g22MaM;Yik zj=Z7tc0@N;rl}HE2yj!SL8a{sule5TX^?j@eG$1FUGqENaCcldDo#qGDM5w_?~2-N z|ACdLz}>iYTPL38zwiEnA0B4#+#k{LR2G^7$yfkikHAK+G8d`!?c9MA;`Kg`2mfm<+Q?ro!x+o^R3bczR3x z1f_&)GLKh=V-7bn%rv{L_zNDZ!JH!x6L8mpCt%PA;`HfoPR6dSnvH zdF*?{TPKcYhX20vQT_kuEas%zbO0I{tTDeVDie6s6Y}q!g~M1z`xigs_UvMA)B7J> zmn!@t9a`Ot=hFFGoMoNgJhaED!T5s7pYp_0F6ckUahF~Bu!C;6>a32mm2U_SLRY*G zs!2}_I;1ou^?}*LmV7WR?@8;9CE*S~)VxZkAUZDn;%!weGTw$hqnV-}Lv>YJXa*=( ztwvYfbR$BjTws!3ep$^I73MGX)UY;FSJB0gYrN;42tux`BL`!U&~l`vy8}%LK}nqh zip233&ulIp9Pr;;Q9f^T9;fT3eix<=+Gh0z;XpUm@J5ZLVV#rOw*+Bp%VQK;<(Pnk zI_Bh3di`(hmV;dkQVr&PSbV>p%w0=)8*-&xs!;>rrLSjn63$Pp5rbJ2ce~saoW7@EQNEgpH#;jL zDHGSZ>!A-elsDr6SJ0G>l$FhjFL^zCH-9SwA?F0NR@c`x_OqmKR`YpbWXoHgse6w# zWJAWq9zI$vLrsugCNXh%cJ%^xFPBx-D}or#pd+T^V})G?)Zp*vmVr!^%S7VDDNhQGUPQ@oj29$ zhflwp&pYO9+$wFu(XIJoAAVB*eo+$QM?}Ie5Ek2%z{p@3ADqJLay7m(<0{{ja5~8F zGh5Y=nY(^7D@}iuZ~r-;QzdfDSV;;Mj3tq)0k+K!{c>!+?Edg+CtfAD|H$Uj8L*8u zH6jja-MoG6BlTg!EDNp<{Xw>g4;A2Yn@%{609wO@0A5P~p$-Z8CSI?(<`POu{7Ov* zs=Ms6LvAkJ<8ltBq>;_sIS+!Oz!QtV&xUU2|4ciP?h(MgvH#%@Z zim{i|dTf2Y_06`InyJEn<7HKuc7RMw%k^x0WmPf(H_`OTo%u05B1_Z9)o!QLr(17- z141-kteV|dul?t9Qdh(I%&_>i4I+_fo?&C&#`P)@0-{S8;fm?iq%^d9R{x)BDBm;P zLMli@fgghhoUNFktzRtXz}OAnEa&TmjcFmTHE_I6@s-bUae^2$vFRbo5H(7XPXdp{ zxK-|mD&A3^*uO`hQv>z&=Od2x02V+8r_`Ffo>f(`b7@AnGxB#a)w*B8bO6K_JPBOe zGYZ569+}8HzUj~b8KNUfYfc7bxkS^X88LMdGZb)9$Qgr*1~;TxRAajPh%E(D_2R0! z2AiamM*=gc_#?Q8hfbbA6s%`x@tqhY@DAC|JA7aG|XOaV8W&mO^( zJt>zHBV)3&J>=QA6vhO{6((c@+u)Av{L|BfUFCR_Y{%YU2%q}EWkM8Y`O@IJow|c)?mNu?v1>qcolY(t^Lj#iQo>?#O z7PR~6l3I^!!I67Q{L2c={zwKj8uVfsZ^=m{O5xz-_{*)tcf-^|9xq_8%+~QHN@-}f zp*$+Z+Zr|_TCsq7?LfoJia=lJ8<#y6OZcz1xdDi0*<}h52n|>RQ}F4K$70?LCe^!{ zc`*kr77L>^7oZ{V+PDPm#-wV-%X)zjWTkmnD}39RNSVxUgW@|_7ci~hNrZqII`OMI zFp<-3Y9u>uEIkIq`lBz1Y4gvuAn`keN7b<;5E{F_Mivy@yexC$Kb+Fa1YZ^`LXK?J zI}BkJPv77ON1!rrrh(lM;i;UFxECx-Q*>#u`@Bu9Pjd8%_loQK%@yU6qyWQ3wVd9P zWIG8sSoS`KvTMTFEASmP@9;G-(DjVKL4X*6b`|y?^G1bR5zeBIe#-P}O-od8;-j9V zd_y9kd<6SUi^Vzc-w+}HtCORH&o1iOXFp>2q3^EeiXeAMq0F3-2v4MTHYDaGrsWy_ zB;aG={?yK!PBls0gGuX01-YsL_i}+a z4<0;tSnTgTfBwbZ{x`+hi-Y~cFAw(xy(I7`4FOC*-x631k(+21<`u3PLD&2~t|+P; z-Lp}X?bNRi?65Faos*JT%EdUbTIY>4b6Fuf9_kG3f-xA@lJehbw;^S!SdcPZ63P(E za7RoDa-j`%26sVb7Ca{9%Nx`=D^yUC$~o6K!B|T%7>VjiW1_>>xDxA zgeUl0&nFdndy7Y@btq_vvs$967j+Isrk{GK+yq*9YZc7nat(3}oUln%PRW;k;LmXi zGuSOoVJ4x8y>5nL^)p5MG^4rz@OeY8lqKa9Rg!Qaa8XpjrO<{krG#M%909A?mY4(p zkpf^w01;i!y1jv>8Rg{=-GX=7?_>bcaOOCngM~hZeRrB6*{u0raNOufQ7f)F`lrj? z^ySZIC;Q(F&(2Q|_Kx(jf^VleEko|J8@KFAr(IRN!UH4DO}g1>gqW6pd$d|Fm`|t| z@BXu2{MG-9k4gedBsa_N)~q4?1bX>o+z$9vR{-NNbaxzD(~Fua&7=g2K5<1u?;9{a z1YF~F1DSI=d+ftReoIRWO~0~m;5Nq2y*Zs!3igh;v;6${mz|%V{L=ID86{TRYlPi+ z?EUSU0p{@kbbu9<>lc*=fN)gaUQ~iU3Na)N1(BHZUzM}u3h4eGnYn^#};JPiYUouB+ASZ2=I3Hw?9EGBrMS?=gIYnwZaL% z6WgL$Cr9v}T17VMu1C<+h{&qwMy_*B_JgDGAF~z-$B9$oiuPD41 zA#lb(hX`mS;78@4jRwV@ZYmU?G(;i$MZn-hH3jkfU_(L-;7IR&KShS_RJX?cZ|Raw zj+F9($6kaRA8^`s=W1wK%!1bjcg-wq(izBm3}|D!|B)nsX+Or!8NH(JOqks_gtjw? ziVD%a23NjG1sqc1)%0JaUN0Op6{T&(IE*yzMxTo!rR;g9umA#vhaWO!#Ut<^HhOhH zatL=KmEo_PfQ_96P#jyBs0Vk4;O_434#C~s-GX}{!6jI54IbPzxI4jZaCaT_k*&LL z@2!2i_tjKSbx%!w{nu$ZJ^wlLf3H7`0C*fPp>f)rq@NK2Xl|62$$!Tp28~|64V&Vp z*|#T%vtR__J&XPhZ{BO1fEG&nz|@coRivDNTWmM$gS?x46J?SPmv;&MEX^~Y0@|v* zw7>QSeVzt*pe@KceXZ~eTVE;8rTAVEW}`JWN4Ux`3RSiCH7^XG2)B zjY3kpA>Zuuv7Q6HEPg!=_*}F=zK%z+&xkLc5nSkj(L|3<)kySns|N&^$^5}p%9S}A zCJlH-j2#!=JCf9Xhi+Cf%3$9gut%ha70r=!M5_5jY21EQpQ0NMhMuXps+bXLIMz#- z{A$jd6^u50ZFTgxuo+D3eeP3-Z&yL(E^<-L07fkKFRz*T!C(mfbN$O6* zn2&ugwZDQ}cS#-xi_xkA4$a+InG(d{z?2y2JSYcV#+A8DP-BqS^A3hd{s`3cA`<6xs7OA?@TER#2})$U>+fW=4P zM(db#zMpsI?RO&9joyWQfUGgT8m47OlPIci)@uQh+N<4wGzj39)rCq^q5-bJIzby5 zwo}hSlna$}WlV%^1;95{(ZmKsp6pS)P5I?7E83uwP$8Ct{Nc@YHZ8(_S^Gf&mPo)~n zX#unMoKIlY?gvj1TI#h>HW~C{$x@-xAjt^mVuSN70t23b7RmE4B@O`|R-I4Q3^lcg z%y35ps4ATcz75w0!^b6XLVD|nGdl(ZZjCwmdoI#WbMhjsrF!})xAl1>-&cCvp67F% z$En>q`wx;6zkl@}e~kQb!-;f=4d>^tg=-wn0p8-bhb^pXIp}-dx5sj!#`hhg?L0np zgzvk^-bhZ)56rR)lftqr!I&m&rP&kk$uDtKUU(3ls4wmSFcAdy2*`nWKN;Lq>&0xX z;5o2Cn!Amit@b_JZ0y}kf+wxz0GXX$Y?+(G6tI%)0l1g zj3WHD{K*kh7F3$6YfO!39cE{oyMnzVUu5czvt{_0u9Rn7y>azGmtm|6?9}m;qREt^ zWmJw#y|l`msLLoqXn+lBL{{gAdhOY#3)zj^)vxwrWt~CC%bz#J0_Gk`O)DK9)wsL4 z`buNK^aTUHJPBkQ-mBR%%*j9HUp}?X)gGF7A5(HRvZVJp_4D$%*PY8?0Z?=VNrW<3 zC9j1D@B9b!3BV`@aIKlaXGcCj>%)%IiCPoqw|%)v<|s#J#O+c^#LFE8Q_bZStzC;s z!}=vTB*=I1wlgQgE}()BgfS%jiOtG;VNrtp)ck#zSj;ajo}}bj5t5Vw@-Cp=^&YxD z$g|z{BoYSeEkBpTbE_>CeCBjjtkb+Cs~+|io!dtluhqty>U9shdGNx*s@kUf34wi# zVrBwbQVNPP$;KFNeIfZHxh(`ZTRg3=_HN1pR;K_5qCI#pQzff?l!`Yi-)SvDIS;ta zb|g~+P|A%TD8ioiTu5FkofOR|W+Y2bzOWGPzu>I=xg;tT2Mie++7jGYf*BU9+QEc} zBu?8q*kpde8mnW8>lgtjhHGP{8mCoB>DwhGlY`_ybQ^^hu6X5(@@lMh1$=B4u5tAF z?ddjHBKd$0AIuZ2Wn1skhqRJL6ijO?j!Hhi^$LWK}qJ8$i~+T zhOH5QDJO7)F}sgXYf+sivq&vHKqg`=cW#;$vW`oAyeQ0DZ>_AuUXS8*Jt8Z;tjc|M zuWYr>raY(m)Q1c~|0nQgQ-h^m?Hw7X%+vtb)zk}imwL56(@nk}_jiHL?|SucecfD( zAs4@xM~aBd%)?2)(o6{F8X>xw?U=X;QhrG?PewHaj{e7(Xvik$o|(sgs=mE5H3_Q4 zlN^OuDy%d~W=U0$ohb+m^y*}UGTF)NNBUUyK0$)rBeVDkOJRu4ydpYajV1IA<)1}hOCSZRG|&J5Ri3{oM4VlnJ~^2= z+5bxc*arUh0OelJ3-2fcpn-RsbzGTde*cpa(7i8^F+cie8YbHca_S@+9 z-(SeqaT1$aNZh+LQidf|3?Ch=66O4wm+^yf9-s_f`ovPo+@XI9#MtEda&{hO+LjeAkkhy{qk5zk2uW53j z&8mxe^YC+0^b$YlF*kakiU3~tPC8!kE-`(6w2T0;fH`m%VKdNJ_2QH__S)yc5@#JT z5M?kC%3Cx$OZaX7+9===1QOmI^CGj#J}5QXDZAXU6~JX)@A<($0W=?&o5oU01nxSs z24HBjD5Z<^!XYai^g&5|23dRIhZbn{;>NyS!XUn8LF#-71HT7CCib5=c=1C8^c?md z8`r}Qeuvd-nXu+d>+162gLAT)T`_L-vg?(ogkXUBg|zYY+W3wtb!ql4^LPxq>kZZ! zN^5%q4LAw5zki#7zt7f;w{E~J_Fz!SB-A{AVD5zhq7(Uw&9fBnvjm(GKUx1YY^H>b@>Nc>63EZi6=VHOhvum{HjW zLA$v#unM%{Y_jo@8J?>bt52Ugb-ak+iF%<0k4Z$`XGdxAdd(VhZ2!Ssf|v2k5(8*0 z6jlKQZcItMDeeJ-$_|v1ZAEFC5Ln>*aakpE!CFums&yGu#kAhI;kqv@pGnA$H?4i# z*h6!gOi?}EpsmQ+A&ZO|e*1h#Biglrk1hWsC9a$-R%*{7X0Npke5~bc=%(e-F_!Tp z#|*$AkqQEpxApJdG|84sC}WgK&^Sttr$JhE(NyhDJ4|5npx~!XENiP}`eNyCG+KJ_ zkpFDum2`nuJjigl(6hUOP=x*!_jyIp@{Y%!*HO-7ys$bIsyAbPm0({4M*P$44}qU# zg^1liIUc*~wy615hmJ51T>K26g{U{}4E&>Q?y>O4z69 zxP1uP$bm5s`2s!TdWKoYaXDv#)iXV+Zh+6s;CkH$NNz5}*F+S$^Z;i2p21 z>ou9|VW=!tv^<6q;Ri^S=^}`|Yy^=JR00^UT&T$)q3aWaavKpJI0Y@ck3zS@E+SOk z0|MVRGg?fLJd_e4wsW7kg@f^`54?AvsJ!}vTU;9^INDN7y=S&_vLU(M#wxLD&uXyK z;2amohE7hQ+#^?^n2suyzY8HDAGbH#T~ZC(grt9?@T}l5Z!hiYi53)rN~`g?#fC|^ zFWOu0*w;WHs&5ha?rgcwhBRJRZN9WzVP)A(6T~O0_<_OnQm|JBx)8!9VIs;4U;N8C z63?;n^~9o39$H9xx*%#Osj?Tga=zK1iF+s>p}!2A#eJ;(6~$%;KqydiADkb z`Yg(rV-l8EB9`~43m-u0vC(T<6Gut@eDf@PV;d~mjR{`lJpLsOXG}k?%2ak(_cxYk zX>z0Nkb!%N_X+=I)f~)+aEu;`Tkm^c9~;ugSXN?^T}S&T|h$MG6~7YykGQ8 zjDD}9)!8m>|1m$F%3{fr_7s)^GsLy#5VA=3=q!Fgh$unP?u$J%F1;6=Z%{jn*}Zca zf>Z(F0sIZ|bi(GIqc?HH^KS~WQgYpUWrytNtHp2AUD>V{?O#5?!;HNV!%NepM4}mE z5#&)KR^Rq~zE84`rd-sEtc%gBGub+Jl$j%|L-U84Tu8Cv@J))@>|^cwvF|`T*X__B zmi2+%;rc%d=+JfhVhi?q!F z8`Wrio^|GMr%5E7PG7(75!F~c?R$@2!hr%`a;7&zEcOVN1G?ju&g+Dn&5 z$ZEAdc2!?yhiF+&OK)5?Ajlj}@%8z4$x<=C>0u4(f@n}e0ZUes&RTPIlo zyXnjlRi#KS3AXcxB&UK|WLoJ)d_P<~5v^dS7M|b?&b3qYu+k$IXsg8Np&niNMSPBd zbZOTf9=-w&?4MoX&rxeOnpz|p43mD_5)Dl*6j5_b6>5!jMNfgtowD8K_;3hcz#fD% zq8$u>S6R*~XM`6I$<&6b*>WB0XiNe%Z#*Vsn=o$X`!Whui>D^tIt^@fVRg4eMP`+OeZw6%&fc%@AXeR}z|8#iZ2*!bKPDBc z#XQb(7lO`7j{Z(0&Lj!=g@?zly;<5;VCM$(<#zZW5iR)Z(3Zc#Lk07rV^W$}WlsWD zd27Cc=jsSI9b^RF6uby{u_AZlPcO)r+|^*HmaHq;2wClx_VlLwj71l6=R9LcmF3m8 zmftB2B6$)N1AWIxM0B&52ym%b-ibZs6$gdd2Z=V&77ozxG;)3tq;908IN7jR_Py{t zKzYHRsICWg`m)g|w5qM|^7;l^SjA~{ibY& zI-@BjQ4whJhb-cAp6P#XyS|^N5RPM6Ir-Art8q>_<2ea^o)I=zITZ{OLp;j5QnC2# zY$R$%H}=k8Q@_14KBxnK#jMf$_51>14x;&>n|NtIr}G*1Jn}aEo@>Ks7qkL8HNIf& z_DF7}W<*=H%{ z?G?KBJ`*}{j!@kqI7ws1^U;S&sO3?Rp3xhOFT*@eFb|{H1XX(!6MZ`gnj9bi?=qLR=vB^U@gxYKf3~K^d{}RhHAOeM&hr-efEkJ zU6dz!f0tvA*4?G;HDstF@1?0oDg%OHiz|_>(praSTHxdf5!#`wmdn%yvX zM0#~NgXu&F>vsRFeXJt5h~(99NN1!XB41O*(3SzNVJ!}u;8M9L6YQ_x+S2{-1qs(| zPMgD|LgtcUB1XIE1bR$FjUy7cd4BxjLVGYIp$hg2kOUhr49dtFnGequNSzR9F-+vg z11}He6F53dpgMVdNHMvc#1z+-9nN?P4X(3R+i%E2$~Q61cEJWJ{yG%3Vre#|O|I_g zHg_Kd==*1D#dk8qsfd;?iW{+BhC=hz2`QorHX}IDbnzGwF8T1PG5KT`k?IW514D;I zQgCQfp0%yzM_cK%Z8^Ogx#Q=jKV|mxb99uFzE6SonNitK?KP^F3iGbc|SZRHR=#<+vtZ*0nhXu+I;}F!O;IF>50iCbB)f`PJzo%p}D>0V5 zvU8t$6C(FI!S+L(`-1n-O5s%!;!z9vz(uB3i`pqqp ztJT#Zul69(K$S=C1--E1^-C-#$qY+DHXl}0Jscj4^im@*Ug=SMovG4ugt{%Vt7?Fc ztqI+-IoqOI{ck?Kt1}mX$052CWx$hZ2;S`i*HAA^DiZfG9vYu%yfM{8c@@Xw_r~S! zeWR6!Pz@CC%S(Pwd0DFYv_eC*nvJ?8p+R4t4XTWdUnX{uPC%$)E7|-re7`mvw*1*Z zN(j)PKrPr;1M6*EX{3twlna@MCV{!iA>vw1MBq8Mrm-AaW8Y%>Haha!madnAf-!9= zllMJAyDiNgs4+z|*?8f@5p#7=#{oGs8}5^_dnV8bRL^^o>;ms5z;WJ9tL3IYnt>)b zImDuz*0vp&d&r7i zR4&MsL|yWZyJ@SRwMwoE51W}3Mib1Yq}EC0Ci4yEztS^*{OaN@#10Km008WJiU#FB zRjO__R__Ane>137Xeq?4aG`b_YoQKH>hi1;`OqLfOpz(nh@=A)T;ai)8CAII&`T54 zxnA%6#lK-%@DG2eU`mmlSjgb+op!t1+Jh>0C!aIv6>1N>-NM$T%@*{%D|Q@EWbxdI zMUo2q7E>L~PX7cps!4=-&Tzc5?(C+E-T$=R&gZ-zsv@s@Cz*-?Hyn$<$8iG(t3#c{ z(&Xdxr8|NKz*((dFLf_Z#~;4Tx;P6h+=zWvl3_DT7z5n-v>H%x9A>(VuP7X7xX>Cl z(QMbu7oJapH(i@89G+!A;-T)|K#N1PM0E>d6O?L33r-R>?pV>8ZfG-TFug{9F)pQJ zi?Gh4S#S5hemXm&T%nqEsz+a>If~UAUH6{)*zp$a9AcAuf!|b%9A10)V^VS)%boyM zEabbc48U3kSF{0ZyUZk}9*`^K8WHh zTZJvEH=DPCz;#?I0HTmSYvSbDqF1NOju7A%I6#tZ!f*FZ4X5*$Vx30QNrsC1UmT(A zEV@%qLdtsh`4|OAvAK(kD>@Tq;h*f?-I1)htg2l(;RiCjp73}4I_1@o;RnmTAsh4I zk-VWs31@E0QlB~XprjNsgMKma%q2vZJE1SHzs>BJ?XZ5Zfzk);M}YDVX|K{3q%S=B z@5pmiXc!Iobr1s8{ce+BpwpN*?|cCTD-2*=_RL}qx9!LCa!q?7dUuzz1acM0pTUSn zGLmYO8mliMRQ$qd2?8TbUdpuZl z;Zm?whIT&eHc}i&}Ec-vdsFpsc4KYsqCQla%AEfvG}Yn9rD1Xb2=CWdbt++d-2 zx7kmLCEP6>1+Z6bl1aQ}YH?tG2;V!AD1G@+S6l^3wzhrwJqYzh4j*)v2}28t1hwR@ z*ghz7QV;G6H+&uMAumfT>6be%0b@Y$pcRd?cOz|FlY(K6^YKVujMygUqi33A8n2v= zNBgJ)z2@dtQr6<1tMW=xfw}CyHj3)@F zL~*5}*3jCjK4)S7a|Yx1`?;3j&}$>9Om01=8Q=*1o?=!4qEK4^VtaWMbl`h zqNiCTkDlo{0oV5M`rlY+!26diPe+Bz<_lJMCc>@4{3~iDY38vWZ>wlL2d-wOSaCLhZX#Y!Y z-pJ^$?-?192$E2(g*xq5J z)_`}j+~29N(0@_Q>}@O@Ke<`BzK8jr8MkdGBy-eEuuce`$n2*ni5||K14DRDWUr zp=$s4ru{?xr(6B+)U5xc{y$CY|4#q!O2EIi!Ph%lixZ_iONfiWKVoR0jay-(NQG LCzn+IpRNA`+u!Y8 literal 0 HcmV?d00001 diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 7c664966ed74e..dbb463f6005a1 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -998,8 +998,8 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): # by runJob() in order to avoid having to pass a Python lambda into # SparkContext#runJob. mappedRDD = rdd.mapPartitions(partitionFunc) - port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions) - return list(_load_from_socket(port, mappedRDD._jrdd_deserializer)) + sock_info = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions) + return list(_load_from_socket(sock_info, mappedRDD._jrdd_deserializer)) def show_profiles(self): """ Print the profile stats to stdout """ diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 7bed5216eabf3..ebdd665e349c5 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -29,7 +29,7 @@ from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT from pyspark.worker import main as worker_main -from pyspark.serializers import read_int, write_int +from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer def compute_real_exit_code(exit_code): @@ -40,7 +40,7 @@ def compute_real_exit_code(exit_code): return 1 -def worker(sock): +def worker(sock, authenticated): """ Called by a worker process after the fork(). """ @@ -56,6 +56,18 @@ def worker(sock): # otherwise writes also cause a seek that makes us miss data on the read side. infile = os.fdopen(os.dup(sock.fileno()), "rb", 65536) outfile = os.fdopen(os.dup(sock.fileno()), "wb", 65536) + + if not authenticated: + client_secret = UTF8Deserializer().loads(infile) + if os.environ["PYTHON_WORKER_FACTORY_SECRET"] == client_secret: + write_with_length("ok".encode("utf-8"), outfile) + outfile.flush() + else: + write_with_length("err".encode("utf-8"), outfile) + outfile.flush() + sock.close() + return 1 + exit_code = 0 try: worker_main(infile, outfile) @@ -153,8 +165,11 @@ def handle_sigterm(*args): write_int(os.getpid(), outfile) outfile.flush() outfile.close() + authenticated = False while True: - code = worker(sock) + code = worker(sock, authenticated) + if code == 0: + authenticated = True if not reuse or code: # wait for closing try: diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 3e704fe9bf6ec..0afbe9dc6aa3e 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -21,16 +21,19 @@ import select import signal import shlex +import shutil import socket import platform +import tempfile +import time from subprocess import Popen, PIPE if sys.version >= '3': xrange = range -from py4j.java_gateway import java_import, JavaGateway, GatewayClient +from py4j.java_gateway import java_import, JavaGateway, GatewayParameters from pyspark.find_spark_home import _find_spark_home -from pyspark.serializers import read_int +from pyspark.serializers import read_int, write_with_length, UTF8Deserializer def launch_gateway(conf=None): @@ -41,6 +44,7 @@ def launch_gateway(conf=None): """ if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) + gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"] else: SPARK_HOME = _find_spark_home() # Launch the Py4j gateway using Spark's run command so that we pick up the @@ -59,40 +63,40 @@ def launch_gateway(conf=None): ]) command = command + shlex.split(submit_args) - # Start a socket that will be used by PythonGatewayServer to communicate its port to us - callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - callback_socket.bind(('127.0.0.1', 0)) - callback_socket.listen(1) - callback_host, callback_port = callback_socket.getsockname() - env = dict(os.environ) - env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host - env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port) - - # Launch the Java gateway. - # We open a pipe to stdin so that the Java gateway can die when the pipe is broken - if not on_windows: - # Don't send ctrl-c / SIGINT to the Java gateway: - def preexec_func(): - signal.signal(signal.SIGINT, signal.SIG_IGN) - proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env) - else: - # preexec_fn not supported on Windows - proc = Popen(command, stdin=PIPE, env=env) - - gateway_port = None - # We use select() here in order to avoid blocking indefinitely if the subprocess dies - # before connecting - while gateway_port is None and proc.poll() is None: - timeout = 1 # (seconds) - readable, _, _ = select.select([callback_socket], [], [], timeout) - if callback_socket in readable: - gateway_connection = callback_socket.accept()[0] - # Determine which ephemeral port the server started on: - gateway_port = read_int(gateway_connection.makefile(mode="rb")) - gateway_connection.close() - callback_socket.close() - if gateway_port is None: - raise Exception("Java gateway process exited before sending the driver its port number") + # Create a temporary directory where the gateway server should write the connection + # information. + conn_info_dir = tempfile.mkdtemp() + try: + fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir) + os.close(fd) + os.unlink(conn_info_file) + + env = dict(os.environ) + env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file + + # Launch the Java gateway. + # We open a pipe to stdin so that the Java gateway can die when the pipe is broken + if not on_windows: + # Don't send ctrl-c / SIGINT to the Java gateway: + def preexec_func(): + signal.signal(signal.SIGINT, signal.SIG_IGN) + proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env) + else: + # preexec_fn not supported on Windows + proc = Popen(command, stdin=PIPE, env=env) + + # Wait for the file to appear, or for the process to exit, whichever happens first. + while not proc.poll() and not os.path.isfile(conn_info_file): + time.sleep(0.1) + + if not os.path.isfile(conn_info_file): + raise Exception("Java gateway process exited before sending its port number") + + with open(conn_info_file, "rb") as info: + gateway_port = read_int(info) + gateway_secret = UTF8Deserializer().loads(info) + finally: + shutil.rmtree(conn_info_dir) # In Windows, ensure the Java child processes do not linger after Python has exited. # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when @@ -111,7 +115,9 @@ def killChild(): atexit.register(killChild) # Connect to the gateway - gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True) + gateway = JavaGateway( + gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret, + auto_convert=True)) # Import the classes used by PySpark java_import(gateway.jvm, "org.apache.spark.SparkConf") @@ -126,3 +132,16 @@ def killChild(): java_import(gateway.jvm, "scala.Tuple2") return gateway + + +def do_server_auth(conn, auth_secret): + """ + Performs the authentication protocol defined by the SocketAuthHelper class on the given + file-like object 'conn'. + """ + write_with_length(auth_secret.encode("utf-8"), conn) + conn.flush() + reply = UTF8Deserializer().loads(conn) + if reply != "ok": + conn.close() + raise Exception("Unexpected reply from iterator server.") diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4b44f76747264..d5a237a5b2855 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -39,9 +39,11 @@ else: from itertools import imap as map, ifilter as filter +from pyspark.java_gateway import do_server_auth from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ - PickleSerializer, pack_long, AutoBatchedSerializer + PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \ + UTF8Deserializer from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_full_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -136,7 +138,8 @@ def _parse_memory(s): return int(float(s[:-1]) * units[s[-1].lower()]) -def _load_from_socket(port, serializer): +def _load_from_socket(sock_info, serializer): + port, auth_secret = sock_info sock = None # Support for both IPv4 and IPv6. # On most of IPv6-ready systems, IPv6 will take precedence. @@ -156,8 +159,12 @@ def _load_from_socket(port, serializer): # The RDD materialization time is unpredicable, if we set a timeout for socket reading # operation, it will very possibly fail. See SPARK-18281. sock.settimeout(None) + + sockfile = sock.makefile("rwb", 65536) + do_server_auth(sockfile, auth_secret) + # The socket will be automatically closed when garbage-collected. - return serializer.load_stream(sock.makefile("rb", 65536)) + return serializer.load_stream(sockfile) def ignore_unicode_prefix(f): @@ -822,8 +829,8 @@ def collect(self): to be small, as all the data is loaded into the driver's memory. """ with SCCallSiteSync(self.context) as css: - port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) - return list(_load_from_socket(port, self._jrdd_deserializer)) + sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) + return list(_load_from_socket(sock_info, self._jrdd_deserializer)) def reduce(self, f): """ @@ -2380,8 +2387,8 @@ def toLocalIterator(self): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ with SCCallSiteSync(self.context) as css: - port = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd()) - return _load_from_socket(port, self._jrdd_deserializer) + sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd()) + return _load_from_socket(sock_info, self._jrdd_deserializer) def _prepare_for_python_RDD(sc, command): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 16f8e52dead7b..213dc158f9328 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -463,8 +463,8 @@ def collect(self): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - port = self._jdf.collectToPython() - return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) + sock_info = self._jdf.collectToPython() + return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) @ignore_unicode_prefix @since(2.0) @@ -477,8 +477,8 @@ def toLocalIterator(self): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - port = self._jdf.toPythonIterator() - return _load_from_socket(port, BatchedSerializer(PickleSerializer())) + sock_info = self._jdf.toPythonIterator() + return _load_from_socket(sock_info, BatchedSerializer(PickleSerializer())) @ignore_unicode_prefix @since(1.3) @@ -2087,8 +2087,8 @@ def _collectAsArrow(self): .. note:: Experimental. """ with SCCallSiteSync(self._sc) as css: - port = self._jdf.collectAsArrowToPython() - return list(_load_from_socket(port, ArrowSerializer())) + sock_info = self._jdf.collectAsArrowToPython() + return list(_load_from_socket(sock_info, ArrowSerializer())) ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index a1a4336b1e8de..8bb63fcc7ff9c 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -27,6 +27,7 @@ from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry +from pyspark.java_gateway import do_server_auth from pyspark.taskcontext import TaskContext from pyspark.files import SparkFiles from pyspark.rdd import PythonEvalType @@ -301,9 +302,11 @@ def process(): if __name__ == '__main__': - # Read a local port to connect to from stdin - java_port = int(sys.stdin.readline()) + # Read information about how to connect back to the JVM from the environment. + java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) + auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(("127.0.0.1", java_port)) sock_file = sock.makefile("rwb", 65536) + do_server_auth(sock_file, auth_secret) main(sock_file, sock_file) diff --git a/python/setup.py b/python/setup.py index 794ceceae3008..d309e0564530a 100644 --- a/python/setup.py +++ b/python/setup.py @@ -201,7 +201,7 @@ def _supports_symlinks(): 'pyspark.examples.src.main.python': ['*.py', '*/*.py']}, scripts=scripts, license='http://www.apache.org/licenses/LICENSE-2.0', - install_requires=['py4j==0.10.6'], + install_requires=['py4j==0.10.7'], setup_requires=['pypandoc'], extras_require={ 'ml': ['numpy>=1.7'], diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index bafb129032b49..134b3e5fef11a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1152,7 +1152,7 @@ private[spark] class Client( val pyArchivesFile = new File(pyLibPath, "pyspark.zip") require(pyArchivesFile.exists(), s"$pyArchivesFile not found; cannot run pyspark application in YARN mode.") - val py4jFile = new File(pyLibPath, "py4j-0.10.6-src.zip") + val py4jFile = new File(pyLibPath, "py4j-0.10.7-src.zip") require(py4jFile.exists(), s"$py4jFile not found; cannot run pyspark application in YARN mode.") Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath()) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index a129be7c06b53..59b0f29e37d84 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -265,7 +265,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { // needed locations. val sparkHome = sys.props("spark.test.home") val pythonPath = Seq( - s"$sparkHome/python/lib/py4j-0.10.6-src.zip", + s"$sparkHome/python/lib/py4j-0.10.7-src.zip", s"$sparkHome/python") val extraEnvVars = Map( "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index bac154e10ae62..bf3da18c3706e 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -28,6 +28,6 @@ export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}/conf"}" # Add the PySpark classes to the PYTHONPATH: if [ -z "${PYSPARK_PYTHONPATH_SET}" ]; then export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}" - export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:${PYTHONPATH}" + export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:${PYTHONPATH}" export PYSPARK_PYTHONPATH_SET=1 fi diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index cd4def71e6f3b..d518e07bfb62c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3187,7 +3187,7 @@ class Dataset[T] private[sql]( EvaluatePython.javaToPython(rdd) } - private[sql] def collectToPython(): Int = { + private[sql] def collectToPython(): Array[Any] = { EvaluatePython.registerPicklers() withAction("collectToPython", queryExecution) { plan => val toJava: (Any) => Any = EvaluatePython.toJava(_, schema) @@ -3200,7 +3200,7 @@ class Dataset[T] private[sql]( /** * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. */ - private[sql] def collectAsArrowToPython(): Int = { + private[sql] def collectAsArrowToPython(): Array[Any] = { withAction("collectAsArrowToPython", queryExecution) { plan => val iter: Iterator[Array[Byte]] = toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) @@ -3208,7 +3208,7 @@ class Dataset[T] private[sql]( } } - private[sql] def toPythonIterator(): Int = { + private[sql] def toPythonIterator(): Array[Any] = { withNewExecutionId { PythonRDD.toLocalIteratorAndServe(javaToPython.rdd) } From 628c7b517969c4a7ccb26ea67ab3dd61266073ca Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 17 Apr 2018 13:29:43 -0700 Subject: [PATCH 0768/2461] [SPARKR] Match pyspark features in SparkR communication protocol. --- R/pkg/R/client.R | 4 +- R/pkg/R/deserialize.R | 10 +++- R/pkg/R/sparkR.R | 39 +++++++++++-- R/pkg/inst/worker/daemon.R | 4 +- R/pkg/inst/worker/worker.R | 5 +- .../org/apache/spark/api/r/RAuthHelper.scala | 38 +++++++++++++ .../org/apache/spark/api/r/RBackend.scala | 43 +++++++++++++-- .../spark/api/r/RBackendAuthHandler.scala | 55 +++++++++++++++++++ .../org/apache/spark/api/r/RRunner.scala | 35 ++++++++---- .../org/apache/spark/deploy/RRunner.scala | 6 +- 10 files changed, 210 insertions(+), 29 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala create mode 100644 core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 9d82814211bc5..7244cc9f9e38e 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -19,7 +19,7 @@ # Creates a SparkR client connection object # if one doesn't already exist -connectBackend <- function(hostname, port, timeout) { +connectBackend <- function(hostname, port, timeout, authSecret) { if (exists(".sparkRcon", envir = .sparkREnv)) { if (isOpen(.sparkREnv[[".sparkRCon"]])) { cat("SparkRBackend client connection already exists\n") @@ -29,7 +29,7 @@ connectBackend <- function(hostname, port, timeout) { con <- socketConnection(host = hostname, port = port, server = FALSE, blocking = TRUE, open = "wb", timeout = timeout) - + doServerAuth(con, authSecret) assign(".sparkRCon", con, envir = .sparkREnv) con } diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index a90f7d381026b..cb03f1667629f 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -60,14 +60,18 @@ readTypedObject <- function(con, type) { stop(paste("Unsupported type for deserialization", type))) } -readString <- function(con) { - stringLen <- readInt(con) - raw <- readBin(con, raw(), stringLen, endian = "big") +readStringData <- function(con, len) { + raw <- readBin(con, raw(), len, endian = "big") string <- rawToChar(raw) Encoding(string) <- "UTF-8" string } +readString <- function(con) { + stringLen <- readInt(con) + readStringData(con, stringLen) +} + readInt <- function(con) { readBin(con, integer(), n = 1, endian = "big") } diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index a480ac606f10d..38ee79477996f 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -158,6 +158,10 @@ sparkR.sparkContext <- function( " please use the --packages commandline instead", sep = ",")) } backendPort <- existingPort + authSecret <- Sys.getenv("SPARKR_BACKEND_AUTH_SECRET") + if (nchar(authSecret) == 0) { + stop("Auth secret not provided in environment.") + } } else { path <- tempfile(pattern = "backend_port") submitOps <- getClientModeSparkSubmitOpts( @@ -186,16 +190,27 @@ sparkR.sparkContext <- function( monitorPort <- readInt(f) rLibPath <- readString(f) connectionTimeout <- readInt(f) + + # Don't use readString() so that we can provide a useful + # error message if the R and Java versions are mismatched. + authSecretLen = readInt(f) + if (length(authSecretLen) == 0 || authSecretLen == 0) { + stop("Unexpected EOF in JVM connection data. Mismatched versions?") + } + authSecret <- readStringData(f, authSecretLen) close(f) file.remove(path) if (length(backendPort) == 0 || backendPort == 0 || length(monitorPort) == 0 || monitorPort == 0 || - length(rLibPath) != 1) { + length(rLibPath) != 1 || length(authSecret) == 0) { stop("JVM failed to launch") } - assign(".monitorConn", - socketConnection(port = monitorPort, timeout = connectionTimeout), - envir = .sparkREnv) + + monitorConn <- socketConnection(port = monitorPort, blocking = TRUE, + timeout = connectionTimeout, open = "wb") + doServerAuth(monitorConn, authSecret) + + assign(".monitorConn", monitorConn, envir = .sparkREnv) assign(".backendLaunched", 1, envir = .sparkREnv) if (rLibPath != "") { assign(".libPath", rLibPath, envir = .sparkREnv) @@ -205,7 +220,7 @@ sparkR.sparkContext <- function( .sparkREnv$backendPort <- backendPort tryCatch({ - connectBackend("localhost", backendPort, timeout = connectionTimeout) + connectBackend("localhost", backendPort, timeout = connectionTimeout, authSecret = authSecret) }, error = function(err) { stop("Failed to connect JVM\n") @@ -687,3 +702,17 @@ sparkCheckInstall <- function(sparkHome, master, deployMode) { NULL } } + +# Utility function for sending auth data over a socket and checking the server's reply. +doServerAuth <- function(con, authSecret) { + if (nchar(authSecret) == 0) { + stop("Auth secret not provided.") + } + writeString(con, authSecret) + flush(con) + reply <- readString(con) + if (reply != "ok") { + close(con) + stop("Unexpected reply from server.") + } +} diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index 2e31dc5f728cd..fb9db63b07cd0 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -28,7 +28,9 @@ suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) inputCon <- socketConnection( - port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout) + port = port, open = "wb", blocking = TRUE, timeout = connectionTimeout) + +SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET")) # Waits indefinitely for a socket connecion by default. selectTimeout <- NULL diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index 00789d815bba8..ba458d2b9ddfb 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -100,9 +100,12 @@ suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) inputCon <- socketConnection( - port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout) + port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout) +SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET")) + outputCon <- socketConnection( port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout) +SparkR:::doServerAuth(outputCon, Sys.getenv("SPARKR_WORKER_SECRET")) # read the index of the current partition inside the RDD partition <- SparkR:::readInt(inputCon) diff --git a/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala new file mode 100644 index 0000000000000..ac6826a9ec774 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{DataInputStream, DataOutputStream} +import java.net.Socket + +import org.apache.spark.SparkConf +import org.apache.spark.security.SocketAuthHelper + +private[spark] class RAuthHelper(conf: SparkConf) extends SocketAuthHelper(conf) { + + override protected def readUtf8(s: Socket): String = { + SerDe.readString(new DataInputStream(s.getInputStream())) + } + + override protected def writeUtf8(str: String, s: Socket): Unit = { + val out = s.getOutputStream() + SerDe.writeString(new DataOutputStream(out), str) + out.flush() + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 2d1152a036449..3b2e809408e0f 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -17,8 +17,8 @@ package org.apache.spark.api.r -import java.io.{DataOutputStream, File, FileOutputStream, IOException} -import java.net.{InetAddress, InetSocketAddress, ServerSocket} +import java.io.{DataInputStream, DataOutputStream, File, FileOutputStream, IOException} +import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket} import java.util.concurrent.TimeUnit import io.netty.bootstrap.ServerBootstrap @@ -32,6 +32,8 @@ import io.netty.handler.timeout.ReadTimeoutHandler import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.util.Utils /** * Netty-based backend server that is used to communicate between R and Java. @@ -45,7 +47,7 @@ private[spark] class RBackend { /** Tracks JVM objects returned to R for this RBackend instance. */ private[r] val jvmObjectTracker = new JVMObjectTracker - def init(): Int = { + def init(): (Int, RAuthHelper) = { val conf = new SparkConf() val backendConnectionTimeout = conf.getInt( "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) @@ -53,6 +55,7 @@ private[spark] class RBackend { conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS)) val workerGroup = bossGroup val handler = new RBackendHandler(this) + val authHelper = new RAuthHelper(conf) bootstrap = new ServerBootstrap() .group(bossGroup, workerGroup) @@ -71,13 +74,16 @@ private[spark] class RBackend { new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) .addLast("decoder", new ByteArrayDecoder()) .addLast("readTimeoutHandler", new ReadTimeoutHandler(backendConnectionTimeout)) + .addLast(new RBackendAuthHandler(authHelper.secret)) .addLast("handler", handler) } }) channelFuture = bootstrap.bind(new InetSocketAddress("localhost", 0)) channelFuture.syncUninterruptibly() - channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() + + val port = channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() + (port, authHelper) } def run(): Unit = { @@ -116,7 +122,7 @@ private[spark] object RBackend extends Logging { val sparkRBackend = new RBackend() try { // bind to random port - val boundPort = sparkRBackend.init() + val (boundPort, authHelper) = sparkRBackend.init() val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) val listenPort = serverSocket.getLocalPort() // Connection timeout is set by socket client. To make it configurable we will pass the @@ -133,6 +139,7 @@ private[spark] object RBackend extends Logging { dos.writeInt(listenPort) SerDe.writeString(dos, RUtils.rPackages.getOrElse("")) dos.writeInt(backendConnectionTimeout) + SerDe.writeString(dos, authHelper.secret) dos.close() f.renameTo(new File(path)) @@ -144,12 +151,35 @@ private[spark] object RBackend extends Logging { val buf = new Array[Byte](1024) // shutdown JVM if R does not connect back in 10 seconds serverSocket.setSoTimeout(10000) + + // Wait for the R process to connect back, ignoring any failed auth attempts. Allow + // a max number of connection attempts to avoid looping forever. try { - val inSocket = serverSocket.accept() + var remainingAttempts = 10 + var inSocket: Socket = null + while (inSocket == null) { + inSocket = serverSocket.accept() + try { + authHelper.authClient(inSocket) + } catch { + case e: Exception => + remainingAttempts -= 1 + if (remainingAttempts == 0) { + val msg = "Too many failed authentication attempts." + logError(msg) + throw new IllegalStateException(msg) + } + logInfo("Client connection failed authentication.") + inSocket = null + } + } + serverSocket.close() + // wait for the end of socket, closed if R process die inSocket.getInputStream().read(buf) } finally { + serverSocket.close() sparkRBackend.close() System.exit(0) } @@ -165,4 +195,5 @@ private[spark] object RBackend extends Logging { } System.exit(0) } + } diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala new file mode 100644 index 0000000000000..4162e4a6c7476 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{ByteArrayOutputStream, DataOutputStream} +import java.nio.charset.StandardCharsets.UTF_8 + +import io.netty.channel.{Channel, ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +/** + * Authentication handler for connections from the R process. + */ +private class RBackendAuthHandler(secret: String) + extends SimpleChannelInboundHandler[Array[Byte]] with Logging { + + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { + // The R code adds a null terminator to serialized strings, so ignore it here. + val clientSecret = new String(msg, 0, msg.length - 1, UTF_8) + try { + require(secret == clientSecret, "Auth secret mismatch.") + ctx.pipeline().remove(this) + writeReply("ok", ctx.channel()) + } catch { + case e: Exception => + logInfo("Authentication failure.", e) + writeReply("err", ctx.channel()) + ctx.close() + } + } + + private def writeReply(reply: String, chan: Channel): Unit = { + val out = new ByteArrayOutputStream() + SerDe.writeString(new DataOutputStream(out), reply) + chan.writeAndFlush(out.toByteArray()) + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index 88118392003e8..e7fdc3963945a 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -74,14 +74,19 @@ private[spark] class RRunner[U]( // the socket used to send out the input of task serverSocket.setSoTimeout(10000) - val inSocket = serverSocket.accept() - startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex) - - // the socket used to receive the output of task - val outSocket = serverSocket.accept() - val inputStream = new BufferedInputStream(outSocket.getInputStream) - dataStream = new DataInputStream(inputStream) - serverSocket.close() + dataStream = try { + val inSocket = serverSocket.accept() + RRunner.authHelper.authClient(inSocket) + startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex) + + // the socket used to receive the output of task + val outSocket = serverSocket.accept() + RRunner.authHelper.authClient(outSocket) + val inputStream = new BufferedInputStream(outSocket.getInputStream) + new DataInputStream(inputStream) + } finally { + serverSocket.close() + } try { return new Iterator[U] { @@ -315,6 +320,11 @@ private[r] object RRunner { private[this] var errThread: BufferedStreamThread = _ private[this] var daemonChannel: DataOutputStream = _ + private lazy val authHelper = { + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + new RAuthHelper(conf) + } + /** * Start a thread to print the process's stderr to ours */ @@ -349,6 +359,7 @@ private[r] object RRunner { pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString) pb.environment().put("SPARKR_SPARKFILES_ROOT_DIR", SparkFiles.getRootDirectory()) pb.environment().put("SPARKR_IS_RUNNING_ON_WORKER", "TRUE") + pb.environment().put("SPARKR_WORKER_SECRET", authHelper.secret) pb.redirectErrorStream(true) // redirect stderr into stdout val proc = pb.start() val errThread = startStdoutThread(proc) @@ -370,8 +381,12 @@ private[r] object RRunner { // the socket used to send out the input of task serverSocket.setSoTimeout(10000) val sock = serverSocket.accept() - daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) - serverSocket.close() + try { + authHelper.authClient(sock) + daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + } finally { + serverSocket.close() + } } try { daemonChannel.writeInt(port) diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index 6eb53a8252205..e86b362639e57 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -68,10 +68,13 @@ object RRunner { // Java system properties etc. val sparkRBackend = new RBackend() @volatile var sparkRBackendPort = 0 + @volatile var sparkRBackendSecret: String = null val initialized = new Semaphore(0) val sparkRBackendThread = new Thread("SparkR backend") { override def run() { - sparkRBackendPort = sparkRBackend.init() + val (port, authHelper) = sparkRBackend.init() + sparkRBackendPort = port + sparkRBackendSecret = authHelper.secret initialized.release() sparkRBackend.run() } @@ -91,6 +94,7 @@ object RRunner { env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(",")) env.put("R_PROFILE_USER", Seq(rPackageDir(0), "SparkR", "profile", "general.R").mkString(File.separator)) + env.put("SPARKR_BACKEND_AUTH_SECRET", sparkRBackendSecret) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize val process = builder.start() From 7aaa148f593470b2c32221b69097b8b54524eb74 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 9 May 2018 11:09:19 -0700 Subject: [PATCH 0769/2461] [SPARK-14682][ML] Provide evaluateEachIteration method or equivalent for spark.ml GBTs ## What changes were proposed in this pull request? Provide evaluateEachIteration method or equivalent for spark.ml GBTs. ## How was this patch tested? UT. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: WeichenXu Closes #21097 from WeichenXu123/GBTeval. --- .../ml/classification/GBTClassifier.scala | 15 +++++++++ .../spark/ml/regression/GBTRegressor.scala | 17 +++++++++- .../org/apache/spark/ml/tree/treeParams.scala | 6 +++- .../classification/GBTClassifierSuite.scala | 29 ++++++++++++++++- .../ml/regression/GBTRegressorSuite.scala | 32 +++++++++++++++++-- 5 files changed, 94 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 0aa24f0a3cfcc..3fb6d1e4e4f3e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -334,6 +334,21 @@ class GBTClassificationModel private[ml]( // hard coded loss, which is not meant to be changed in the model private val loss = getOldLossType + /** + * Method to compute error or loss for every iteration of gradient boosting. + * + * @param dataset Dataset for validation. + */ + @Since("2.4.0") + def evaluateEachIteration(dataset: Dataset[_]): Array[Double] = { + val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { + case Row(label: Double, features: Vector) => LabeledPoint(label, features) + } + GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, loss, + OldAlgo.Classification + ) + } + @Since("2.0.0") override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 8598e808c4946..d7e054bf55ef6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -34,7 +34,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ /** @@ -269,6 +269,21 @@ class GBTRegressionModel private[ml]( new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights) } + /** + * Method to compute error or loss for every iteration of gradient boosting. + * + * @param dataset Dataset for validation. + * @param loss The loss function used to compute error. Supported options: squared, absolute + */ + @Since("2.4.0") + def evaluateEachIteration(dataset: Dataset[_], loss: String): Array[Double] = { + val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { + case Row(label: Double, features: Vector) => LabeledPoint(label, features) + } + GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, + convertToOldLossType(loss), OldAlgo.Regression) + } + @Since("2.0.0") override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 81b6222acc7ce..ec8868bb42cbb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -579,7 +579,11 @@ private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams /** (private[ml]) Convert new loss to old loss. */ override private[ml] def getOldLossType: OldLoss = { - getLossType match { + convertToOldLossType(getLossType) + } + + private[ml] def convertToOldLossType(loss: String): OldLoss = { + loss match { case "squared" => OldSquaredError case "absolute" => OldAbsoluteError case _ => diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index f0ee5496f9d1d..e20de196d65ca 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.RegressionLeafNode -import org.apache.spark.ml.tree.impl.TreeTests +import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} @@ -365,6 +365,33 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { assert(mostImportantFeature !== mostIF) } + test("model evaluateEachIteration") { + val gbt = new GBTClassifier() + .setSeed(1L) + .setMaxDepth(2) + .setMaxIter(3) + .setLossType("logistic") + val model3 = gbt.fit(trainData.toDF) + val model1 = new GBTClassificationModel("gbt-cls-model-test1", + model3.trees.take(1), model3.treeWeights.take(1), model3.numFeatures, model3.numClasses) + val model2 = new GBTClassificationModel("gbt-cls-model-test2", + model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures, model3.numClasses) + + val evalArr = model3.evaluateEachIteration(validationData.toDF) + val remappedValidationData = validationData.map( + x => new LabeledPoint((x.label * 2) - 1, x.features)) + val lossErr1 = GradientBoostedTrees.computeError(remappedValidationData, + model1.trees, model1.treeWeights, model1.getOldLossType) + val lossErr2 = GradientBoostedTrees.computeError(remappedValidationData, + model2.trees, model2.treeWeights, model2.getOldLossType) + val lossErr3 = GradientBoostedTrees.computeError(remappedValidationData, + model3.trees, model3.treeWeights, model3.getOldLossType) + + assert(evalArr(0) ~== lossErr1 relTol 1E-3) + assert(evalArr(1) ~== lossErr2 relTol 1E-3) + assert(evalArr(2) ~== lossErr3 relTol 1E-3) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index fad11d078250f..773f6d2c542fe 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -20,8 +20,9 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.tree.impl.TreeTests +import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -201,7 +202,34 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { assert(mostImportantFeature !== mostIF) } - + test("model evaluateEachIteration") { + for (lossType <- GBTRegressor.supportedLossTypes) { + val gbt = new GBTRegressor() + .setSeed(1L) + .setMaxDepth(2) + .setMaxIter(3) + .setLossType(lossType) + val model3 = gbt.fit(trainData.toDF) + val model1 = new GBTRegressionModel("gbt-reg-model-test1", + model3.trees.take(1), model3.treeWeights.take(1), model3.numFeatures) + val model2 = new GBTRegressionModel("gbt-reg-model-test2", + model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures) + + for (evalLossType <- GBTRegressor.supportedLossTypes) { + val evalArr = model3.evaluateEachIteration(validationData.toDF, evalLossType) + val lossErr1 = GradientBoostedTrees.computeError(validationData, + model1.trees, model1.treeWeights, model1.convertToOldLossType(evalLossType)) + val lossErr2 = GradientBoostedTrees.computeError(validationData, + model2.trees, model2.treeWeights, model2.convertToOldLossType(evalLossType)) + val lossErr3 = GradientBoostedTrees.computeError(validationData, + model3.trees, model3.treeWeights, model3.convertToOldLossType(evalLossType)) + + assert(evalArr(0) ~== lossErr1 relTol 1E-3) + assert(evalArr(1) ~== lossErr2 relTol 1E-3) + assert(evalArr(2) ~== lossErr3 relTol 1E-3) + } + } + } ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load From fd1179c17273283d32f275d5cd5f97aaa2aca1f7 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 9 May 2018 11:32:17 -0700 Subject: [PATCH 0770/2461] [SPARK-24214][SS] Fix toJSON for StreamingRelationV2/StreamingExecutionRelation/ContinuousExecutionRelation ## What changes were proposed in this pull request? We should overwrite "otherCopyArgs" to provide the SparkSession parameter otherwise TreeNode.toJSON cannot get the full constructor parameter list. ## How was this patch tested? The new unit test. Author: Shixiong Zhu Closes #21275 from zsxwing/SPARK-24214. --- .../execution/streaming/StreamingRelation.scala | 3 +++ .../spark/sql/streaming/StreamingQuerySuite.scala | 15 +++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index f02d3a2c3733f..24195b5657e8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -66,6 +66,7 @@ case class StreamingExecutionRelation( output: Seq[Attribute])(session: SparkSession) extends LeafNode with MultiInstanceRelation { + override def otherCopyArgs: Seq[AnyRef] = session :: Nil override def isStreaming: Boolean = true override def toString: String = source.toString @@ -97,6 +98,7 @@ case class StreamingRelationV2( output: Seq[Attribute], v1Relation: Option[StreamingRelation])(session: SparkSession) extends LeafNode with MultiInstanceRelation { + override def otherCopyArgs: Seq[AnyRef] = session :: Nil override def isStreaming: Boolean = true override def toString: String = sourceName @@ -116,6 +118,7 @@ case class ContinuousExecutionRelation( output: Seq[Attribute])(session: SparkSession) extends LeafNode with MultiInstanceRelation { + override def otherCopyArgs: Seq[AnyRef] = session :: Nil override def isStreaming: Boolean = true override def toString: String = source.toString diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 0cb2375e0a49a..57986999bf861 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -831,6 +831,21 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi CheckLastBatch(("A", 1))) } + test("StreamingRelationV2/StreamingExecutionRelation/ContinuousExecutionRelation.toJSON " + + "should not fail") { + val df = spark.readStream.format("rate").load() + assert(df.logicalPlan.toJSON.contains("StreamingRelationV2")) + + testStream(df)( + AssertOnQuery(_.logicalPlan.toJSON.contains("StreamingExecutionRelation")) + ) + + testStream(df, useV2Sink = true)( + StartStream(trigger = Trigger.Continuous(100)), + AssertOnQuery(_.logicalPlan.toJSON.contains("ContinuousExecutionRelation")) + ) + } + /** Create a streaming DF that only execute one batch in which it returns the given static DF */ private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = { require(!triggerDF.isStreaming) From 9e3bb313682bf88d0c81427167ee341698d29b02 Mon Sep 17 00:00:00 2001 From: wuyi Date: Wed, 9 May 2018 15:44:36 -0700 Subject: [PATCH 0771/2461] [SPARK-24141][CORE] Fix bug in CoarseGrainedSchedulerBackend.killExecutors ## What changes were proposed in this pull request? In method *CoarseGrainedSchedulerBackend.killExecutors()*, `numPendingExecutors` should add `executorsToKill.size` rather than `knownExecutors.size` if we do not adjust target number of executors. ## How was this patch tested? N/A Author: wuyi Closes #21209 from Ngone51/SPARK-24141. --- .../spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 5627a557a12f3..d8794e8e551aa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -633,7 +633,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } doRequestTotalExecutors(requestedTotalExecutors) } else { - numPendingExecutors += knownExecutors.size + numPendingExecutors += executorsToKill.size Future.successful(true) } From 9341c951e85ff29714cbee302053872a6a4223da Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Wed, 9 May 2018 19:56:03 -0700 Subject: [PATCH 0772/2461] [SPARK-23852][SQL] Add test that fails if PARQUET-1217 is not fixed ## What changes were proposed in this pull request? Add a new test that triggers if PARQUET-1217 - a predicate pushdown bug - is not fixed in Spark's Parquet dependency. ## How was this patch tested? New unit test passes. Author: Henry Robinson Closes #21284 from henryr/spark-23852. --- .../test/resources/test-data/parquet-1217.parquet | Bin 0 -> 321 bytes .../datasources/parquet/ParquetFilterSuite.scala | 10 ++++++++++ 2 files changed, 10 insertions(+) create mode 100644 sql/core/src/test/resources/test-data/parquet-1217.parquet diff --git a/sql/core/src/test/resources/test-data/parquet-1217.parquet b/sql/core/src/test/resources/test-data/parquet-1217.parquet new file mode 100644 index 0000000000000000000000000000000000000000..eb2dc4f79907019ffe4edaa9adb951924aad1c5d GIT binary patch literal 321 zcmbu5(MrQG7=^Pml#v^~@DB~-BFI`Q)R6|)b+DV=$S%Yc=L=*_hK0@zhq6oG%h&Ne zT(Vd2z~P6(;p68ti 0").count() === 2) + } } class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] { From 62d01391fee77eedd75b4e3f475ede8b9f0df0c2 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 9 May 2018 21:48:54 -0700 Subject: [PATCH 0773/2461] [SPARK-24073][SQL] Rename DataReaderFactory to InputPartition. ## What changes were proposed in this pull request? Renames: * `DataReaderFactory` to `InputPartition` * `DataReader` to `InputPartitionReader` * `createDataReaderFactories` to `planInputPartitions` * `createUnsafeDataReaderFactories` to `planUnsafeInputPartitions` * `createBatchDataReaderFactories` to `planBatchInputPartitions` This fixes the changes in SPARK-23219, which renamed ReadTask to DataReaderFactory. The intent of that change was to make the read and write API match (write side uses DataWriterFactory), but the underlying problem is that the two classes are not equivalent. ReadTask/DataReader function as Iterable/Iterator. One InputPartition is a specific partition of the data to be read, in contrast to DataWriterFactory where the same factory instance is used in all write tasks. InputPartition's purpose is to manage the lifecycle of the associated reader, which is now called InputPartitionReader, with an explicit create operation to mirror the close operation. This was no longer clear from the API because DataReaderFactory appeared to be more generic than it is and it isn't clear why a set of them is produced for a read. ## How was this patch tested? Existing tests, which have been updated to use the new name. Author: Ryan Blue Closes #21145 from rdblue/SPARK-24073-revert-data-reader-factory-rename. --- .../sql/kafka010/KafkaContinuousReader.scala | 20 ++--- .../sql/kafka010/KafkaMicroBatchReader.scala | 21 ++--- .../sql/kafka010/KafkaSourceProvider.scala | 3 +- .../kafka010/KafkaMicroBatchSourceSuite.scala | 2 +- .../sql/sources/v2/MicroBatchReadSupport.java | 2 +- ...ory.java => ContinuousInputPartition.java} | 8 +- .../sources/v2/reader/DataSourceReader.java | 10 +-- ...ReaderFactory.java => InputPartition.java} | 16 ++-- ...aReader.java => InputPartitionReader.java} | 4 +- .../v2/reader/SupportsReportPartitioning.java | 2 +- .../v2/reader/SupportsScanColumnarBatch.java | 10 +-- .../v2/reader/SupportsScanUnsafeRow.java | 8 +- .../partitioning/ClusteredDistribution.java | 4 +- .../v2/reader/partitioning/Distribution.java | 6 +- .../v2/reader/partitioning/Partitioning.java | 4 +- ...va => ContinuousInputPartitionReader.java} | 6 +- .../v2/reader/streaming/ContinuousReader.java | 10 +-- .../v2/reader/streaming/MicroBatchReader.java | 2 +- .../datasources/v2/DataSourceRDD.scala | 11 +-- .../datasources/v2/DataSourceV2ScanExec.scala | 46 +++++------ .../continuous/ContinuousDataSourceRDD.scala | 22 +++--- .../ContinuousQueuedDataReader.scala | 13 ++-- .../ContinuousRateStreamSource.scala | 20 ++--- .../sql/execution/streaming/memory.scala | 12 +-- .../sources/ContinuousMemoryStream.scala | 18 ++--- .../sources/RateStreamMicroBatchReader.scala | 15 ++-- .../execution/streaming/sources/socket.scala | 29 +++---- .../sources/v2/JavaAdvancedDataSourceV2.java | 22 +++--- .../sql/sources/v2/JavaBatchDataSourceV2.java | 12 +-- .../v2/JavaPartitionAwareDataSource.java | 12 +-- .../v2/JavaSchemaRequiredDataSource.java | 4 +- .../sources/v2/JavaSimpleDataSourceV2.java | 18 ++--- .../sources/v2/JavaUnsafeRowDataSourceV2.java | 16 ++-- .../sources/RateStreamProviderSuite.scala | 12 +-- .../sql/sources/v2/DataSourceV2Suite.scala | 78 ++++++++++--------- .../sources/v2/SimpleWritableDataSource.scala | 14 ++-- .../sql/streaming/StreamingQuerySuite.scala | 11 +-- .../ContinuousQueuedDataReaderSuite.scala | 8 +- .../sources/StreamingDataSourceV2Suite.scala | 4 +- 39 files changed, 272 insertions(+), 263 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{ContinuousDataReaderFactory.java => ContinuousInputPartition.java} (78%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{DataReaderFactory.java => InputPartition.java} (81%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{DataReader.java => InputPartitionReader.java} (92%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/{ContinuousDataReader.java => ContinuousInputPartitionReader.java} (84%) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index f26c134c2f6e9..88abf8a8dd027 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType /** @@ -86,7 +86,7 @@ class KafkaContinuousReader( KafkaSourceOffset(JsonUtils.partitionOffsets(json)) } - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { import scala.collection.JavaConverters._ val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) @@ -108,7 +108,7 @@ class KafkaContinuousReader( case (topicPartition, start) => KafkaContinuousDataReaderFactory( topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) - .asInstanceOf[DataReaderFactory[UnsafeRow]] + .asInstanceOf[InputPartition[UnsafeRow]] }.asJava } @@ -161,18 +161,18 @@ case class KafkaContinuousDataReaderFactory( startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousDataReaderFactory[UnsafeRow] { + failOnDataLoss: Boolean) extends ContinuousInputPartition[UnsafeRow] { - override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[UnsafeRow] = { + override def createContinuousReader(offset: PartitionOffset): InputPartitionReader[UnsafeRow] = { val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset] require(kafkaOffset.topicPartition == topicPartition, s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}") - new KafkaContinuousDataReader( + new KafkaContinuousInputPartitionReader( topicPartition, kafkaOffset.partitionOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) } - override def createDataReader(): KafkaContinuousDataReader = { - new KafkaContinuousDataReader( + override def createPartitionReader(): KafkaContinuousInputPartitionReader = { + new KafkaContinuousInputPartitionReader( topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) } } @@ -187,12 +187,12 @@ case class KafkaContinuousDataReaderFactory( * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets * are skipped. */ -class KafkaContinuousDataReader( +class KafkaContinuousInputPartitionReader( topicPartition: TopicPartition, startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { + failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[UnsafeRow] { private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false) private val converter = new KafkaRecordToUnsafeRowConverter diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index cbe655f9bff1f..8a377738ea782 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset} import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.UninterruptibleThread @@ -101,7 +101,7 @@ private[kafka010] class KafkaMicroBatchReader( } } - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { // Find the new partitions, and get their earliest offsets val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) @@ -146,7 +146,7 @@ private[kafka010] class KafkaMicroBatchReader( new KafkaMicroBatchDataReaderFactory( range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) } - factories.map(_.asInstanceOf[DataReaderFactory[UnsafeRow]]).asJava + factories.map(_.asInstanceOf[InputPartition[UnsafeRow]]).asJava } override def getStartOffset: Offset = { @@ -299,27 +299,28 @@ private[kafka010] class KafkaMicroBatchReader( } } -/** A [[DataReaderFactory]] for reading Kafka data in a micro-batch streaming query. */ +/** A [[InputPartition]] for reading Kafka data in a micro-batch streaming query. */ private[kafka010] case class KafkaMicroBatchDataReaderFactory( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends DataReaderFactory[UnsafeRow] { + reuseKafkaConsumer: Boolean) extends InputPartition[UnsafeRow] { override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray - override def createDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader( - offsetRange, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) + override def createPartitionReader(): InputPartitionReader[UnsafeRow] = + new KafkaMicroBatchInputPartitionReader(offsetRange, executorKafkaParams, pollTimeoutMs, + failOnDataLoss, reuseKafkaConsumer) } -/** A [[DataReader]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] case class KafkaMicroBatchDataReader( +/** A [[InputPartitionReader]] for reading Kafka data in a micro-batch streaming query. */ +private[kafka010] case class KafkaMicroBatchInputPartitionReader( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends DataReader[UnsafeRow] with Logging { + reuseKafkaConsumer: Boolean) extends InputPartitionReader[UnsafeRow] with Logging { private val consumer = KafkaDataConsumer.acquire( offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 36b9f0466566b..d225c1ea6b7f1 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSessio import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -149,7 +150,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } /** - * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.ContinuousDataReader]] to read + * Creates a [[ContinuousInputPartitionReader]] to read * Kafka data in a continuous streaming query. */ override def createContinuousReader( diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index d2d04b68de6ab..871f9700cd1db 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -678,7 +678,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))), Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L))) ) - val factories = reader.createUnsafeRowReaderFactories().asScala + val factories = reader.planUnsafeInputPartitions().asScala .map(_.asInstanceOf[KafkaMicroBatchDataReaderFactory]) withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") { assert(factories.size == numPartitionsGenerated) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java index 209ffa7a0b9fa..7f4a2c9593c76 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java @@ -34,7 +34,7 @@ public interface MicroBatchReadSupport extends DataSourceV2 { * streaming query. * * The execution engine will create a micro-batch reader at the start of a streaming query, - * alternate calls to setOffsetRange and createDataReaderFactories for each batch to process, and + * alternate calls to setOffsetRange and planInputPartitions for each batch to process, and * then call stop() when the execution is complete. Note that a single query may have multiple * executions due to restart or failure recovery. * diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java similarity index 78% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java index a61697649c43e..c24f3b21eade1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java @@ -21,15 +21,15 @@ import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset; /** - * A mix-in interface for {@link DataReaderFactory}. Continuous data reader factories can - * implement this interface to provide creating {@link DataReader} with particular offset. + * A mix-in interface for {@link InputPartition}. Continuous input partitions can + * implement this interface to provide creating {@link InputPartitionReader} with particular offset. */ @InterfaceStability.Evolving -public interface ContinuousDataReaderFactory extends DataReaderFactory { +public interface ContinuousInputPartition extends InputPartition { /** * Create a DataReader with particular offset as its startOffset. * * @param offset offset want to set as the DataReader's startOffset. */ - DataReader createDataReaderWithOffset(PartitionOffset offset); + InputPartitionReader createContinuousReader(PartitionOffset offset); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java index a470bccc5aad2..f898c296e4245 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java @@ -31,8 +31,8 @@ * {@link ReadSupport#createReader(DataSourceOptions)} or * {@link ReadSupportWithSchema#createReader(StructType, DataSourceOptions)}. * It can mix in various query optimization interfaces to speed up the data scan. The actual scan - * logic is delegated to {@link DataReaderFactory}s that are returned by - * {@link #createDataReaderFactories()}. + * logic is delegated to {@link InputPartition}s that are returned by + * {@link #planInputPartitions()}. * * There are mainly 3 kinds of query optimizations: * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column @@ -65,8 +65,8 @@ public interface DataSourceReader { StructType readSchema(); /** - * Returns a list of reader factories. Each factory is responsible for creating a data reader to - * output data for one RDD partition. That means the number of factories returned here is same as + * Returns a list of read tasks. Each task is responsible for creating a data reader to + * output data for one RDD partition. That means the number of tasks returned here is same as * the number of RDD partitions this scan outputs. * * Note that, this may not be a full scan if the data source reader mixes in other optimization @@ -76,5 +76,5 @@ public interface DataSourceReader { * If this method fails (by throwing an exception), the action would fail and no Spark job was * submitted. */ - List> createDataReaderFactories(); + List> planInputPartitions(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java similarity index 81% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java index 32e98e8f5d8bd..c581e3b5d0047 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java @@ -22,20 +22,20 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A reader factory returned by {@link DataSourceReader#createDataReaderFactories()} and is + * An input partition returned by {@link DataSourceReader#planInputPartitions()} and is * responsible for creating the actual data reader. The relationship between - * {@link DataReaderFactory} and {@link DataReader} + * {@link InputPartition} and {@link InputPartitionReader} * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. * - * Note that, the reader factory will be serialized and sent to executors, then the data reader - * will be created on executors and do the actual reading. So {@link DataReaderFactory} must be - * serializable and {@link DataReader} doesn't need to be. + * Note that input partitions will be serialized and sent to executors, then the partition reader + * will be created on executors and do the actual reading. So {@link InputPartition} must be + * serializable and {@link InputPartitionReader} doesn't need to be. */ @InterfaceStability.Evolving -public interface DataReaderFactory extends Serializable { +public interface InputPartition extends Serializable { /** - * The preferred locations where the data reader returned by this reader factory can run faster, + * The preferred locations where the data reader returned by this partition can run faster, * but Spark does not guarantee to run the data reader on these locations. * The implementations should make sure that it can be run on any location. * The location is a string representing the host name. @@ -57,5 +57,5 @@ default String[] preferredLocations() { * If this method fails (by throwing an exception), the corresponding Spark task would fail and * get retried until hitting the maximum retry times. */ - DataReader createDataReader(); + InputPartitionReader createPartitionReader(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java similarity index 92% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java index bb9790a1c819e..1b7051f1ad0af 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java @@ -23,7 +23,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data reader returned by {@link DataReaderFactory#createDataReader()} and is responsible for + * A data reader returned by {@link InputPartition#createPartitionReader()} and is responsible for * outputting data for a RDD partition. * * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data @@ -31,7 +31,7 @@ * readers that mix in {@link SupportsScanUnsafeRow}. */ @InterfaceStability.Evolving -public interface DataReader extends Closeable { +public interface InputPartitionReader extends Closeable { /** * Proceed to next record, returns false if there is no more records. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index 607628746e873..6b60da7c4dc1d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -24,7 +24,7 @@ * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report data partitioning and try to avoid shuffle at Spark side. * - * Note that, when the reader creates exactly one {@link DataReaderFactory}, Spark may avoid + * Note that, when the reader creates exactly one {@link InputPartition}, Spark may avoid * adding a shuffle even if the reader does not implement this interface. */ @InterfaceStability.Evolving diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java index 2e5cfa78511f0..0faf81db24605 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java @@ -30,22 +30,22 @@ @InterfaceStability.Evolving public interface SupportsScanColumnarBatch extends DataSourceReader { @Override - default List> createDataReaderFactories() { + default List> planInputPartitions() { throw new IllegalStateException( - "createDataReaderFactories not supported by default within SupportsScanColumnarBatch."); + "planInputPartitions not supported by default within SupportsScanColumnarBatch."); } /** - * Similar to {@link DataSourceReader#createDataReaderFactories()}, but returns columnar data + * Similar to {@link DataSourceReader#planInputPartitions()}, but returns columnar data * in batches. */ - List> createBatchDataReaderFactories(); + List> planBatchInputPartitions(); /** * Returns true if the concrete data source reader can read data in batch according to the scan * properties like required columns, pushes filters, etc. It's possible that the implementation * can only support some certain columns with certain types. Users can overwrite this method and - * {@link #createDataReaderFactories()} to fallback to normal read path under some conditions. + * {@link #planInputPartitions()} to fallback to normal read path under some conditions. */ default boolean enableBatchRead() { return true; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java index 9cd749e8e4ce9..f2220f6d31093 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java @@ -33,14 +33,14 @@ public interface SupportsScanUnsafeRow extends DataSourceReader { @Override - default List> createDataReaderFactories() { + default List> planInputPartitions() { throw new IllegalStateException( - "createDataReaderFactories not supported by default within SupportsScanUnsafeRow"); + "planInputPartitions not supported by default within SupportsScanUnsafeRow"); } /** - * Similar to {@link DataSourceReader#createDataReaderFactories()}, + * Similar to {@link DataSourceReader#planInputPartitions()}, * but returns data in unsafe row format. */ - List> createUnsafeRowReaderFactories(); + List> planUnsafeInputPartitions(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java index 2d0ee50212b56..38ca5fc6387b2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java @@ -18,12 +18,12 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataReader; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; /** * A concrete implementation of {@link Distribution}. Represents a distribution where records that * share the same values for the {@link #clusteredColumns} will be produced by the same - * {@link DataReader}. + * {@link InputPartitionReader}. */ @InterfaceStability.Evolving public class ClusteredDistribution implements Distribution { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java index f6b111fdf220d..d2ee9518d628f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java @@ -18,13 +18,13 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataReader; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; /** * An interface to represent data distribution requirement, which specifies how the records should - * be distributed among the data partitions(one {@link DataReader} outputs data for one partition). + * be distributed among the data partitions(one {@link InputPartitionReader} outputs data for one partition). * Note that this interface has nothing to do with the data ordering inside one - * partition(the output records of a single {@link DataReader}). + * partition(the output records of a single {@link InputPartitionReader}). * * The instance of this interface is created and provided by Spark, then consumed by * {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java index 309d9e5de0a0f..f460f6bfe3bb9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; +import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning; /** @@ -31,7 +31,7 @@ public interface Partitioning { /** - * Returns the number of partitions(i.e., {@link DataReaderFactory}s) the data source outputs. + * Returns the number of partitions(i.e., {@link InputPartition}s) the data source outputs. */ int numPartitions(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java similarity index 84% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java index 47d26440841fd..7b0ba0bbdda90 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java @@ -18,13 +18,13 @@ package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataReader; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; /** - * A variation on {@link DataReader} for use with streaming in continuous processing mode. + * A variation on {@link InputPartitionReader} for use with streaming in continuous processing mode. */ @InterfaceStability.Evolving -public interface ContinuousDataReader extends DataReader { +public interface ContinuousInputPartitionReader extends InputPartitionReader { /** * Get the offset of the current record, or the start offset if no records have been read. * diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java index 7fe7f00ac2fa8..716c5c0e9e15a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java @@ -27,7 +27,7 @@ * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to allow reading in a continuous processing mode stream. * - * Implementations must ensure each reader factory output is a {@link ContinuousDataReader}. + * Implementations must ensure each partition reader is a {@link ContinuousInputPartitionReader}. * * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. @@ -35,7 +35,7 @@ @InterfaceStability.Evolving public interface ContinuousReader extends BaseStreamingSource, DataSourceReader { /** - * Merge partitioned offsets coming from {@link ContinuousDataReader} instances for each + * Merge partitioned offsets coming from {@link ContinuousInputPartitionReader} instances for each * partition to a single global offset. */ Offset mergeOffsets(PartitionOffset[] offsets); @@ -47,7 +47,7 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceReader Offset deserializeOffset(String json); /** - * Set the desired start offset for reader factories created from this reader. The scan will + * Set the desired start offset for partitions created from this reader. The scan will * start from the first record after the provided offset, or from an implementation-defined * inferred starting point if no offset is provided. */ @@ -61,8 +61,8 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceReader Offset getStartOffset(); /** - * The execution engine will call this method in every epoch to determine if new reader - * factories need to be generated, which may be required if for example the underlying + * The execution engine will call this method in every epoch to determine if new input + * partitions need to be generated, which may be required if for example the underlying * source system has had partitions added or removed. * * If true, the query will be shut down and restarted with a new reader. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java index 67ebde30d61a9..0159c731762d9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java @@ -33,7 +33,7 @@ @InterfaceStability.Evolving public interface MicroBatchReader extends DataSourceReader, BaseStreamingSource { /** - * Set the desired offset range for reader factories created from this reader. Reader factories + * Set the desired offset range for input partitions created from this reader. Partition readers * will generate only data within (`start`, `end`]; that is, from the first record after `start` * to the record with offset `end`. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index f85971be394b1..1a6b32429313a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -22,14 +22,14 @@ import scala.reflect.ClassTag import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.InputPartition -class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: DataReaderFactory[T]) +class DataSourceRDDPartition[T : ClassTag](val index: Int, val inputPartition: InputPartition[T]) extends Partition with Serializable class DataSourceRDD[T: ClassTag]( sc: SparkContext, - @transient private val readerFactories: Seq[DataReaderFactory[T]]) + @transient private val readerFactories: Seq[InputPartition[T]]) extends RDD[T](sc, Nil) { override protected def getPartitions: Array[Partition] = { @@ -39,7 +39,8 @@ class DataSourceRDD[T: ClassTag]( } override def compute(split: Partition, context: TaskContext): Iterator[T] = { - val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.createDataReader() + val reader = split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition + .createPartitionReader() context.addTaskCompletionListener(_ => reader.close()) val iter = new Iterator[T] { private[this] var valuePrepared = false @@ -63,6 +64,6 @@ class DataSourceRDD[T: ClassTag]( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.preferredLocations() + split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 77cb707340b0f..c6a7684bf6ab0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -59,13 +59,13 @@ case class DataSourceV2ScanExec( } override def outputPartitioning: physical.Partitioning = reader match { - case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchReaderFactories.size == 1 => + case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchPartitions.size == 1 => SinglePartition - case r: SupportsScanColumnarBatch if !r.enableBatchRead() && readerFactories.size == 1 => + case r: SupportsScanColumnarBatch if !r.enableBatchRead() && partitions.size == 1 => SinglePartition - case r if !r.isInstanceOf[SupportsScanColumnarBatch] && readerFactories.size == 1 => + case r if !r.isInstanceOf[SupportsScanColumnarBatch] && partitions.size == 1 => SinglePartition case s: SupportsReportPartitioning => @@ -75,19 +75,19 @@ case class DataSourceV2ScanExec( case _ => super.outputPartitioning } - private lazy val readerFactories: Seq[DataReaderFactory[UnsafeRow]] = reader match { - case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories().asScala + private lazy val partitions: Seq[InputPartition[UnsafeRow]] = reader match { + case r: SupportsScanUnsafeRow => r.planUnsafeInputPartitions().asScala case _ => - reader.createDataReaderFactories().asScala.map { - new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow] + reader.planInputPartitions().asScala.map { + new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[UnsafeRow] } } - private lazy val batchReaderFactories: Seq[DataReaderFactory[ColumnarBatch]] = reader match { + private lazy val batchPartitions: Seq[InputPartition[ColumnarBatch]] = reader match { case r: SupportsScanColumnarBatch if r.enableBatchRead() => assert(!reader.isInstanceOf[ContinuousReader], "continuous stream reader does not support columnar read yet.") - r.createBatchDataReaderFactories().asScala + r.planBatchInputPartitions().asScala } private lazy val inputRDD: RDD[InternalRow] = reader match { @@ -95,19 +95,18 @@ case class DataSourceV2ScanExec( EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) - .askSync[Unit](SetReaderPartitions(readerFactories.size)) + .askSync[Unit](SetReaderPartitions(partitions.size)) new ContinuousDataSourceRDD( sparkContext, sqlContext.conf.continuousStreamingExecutorQueueSize, sqlContext.conf.continuousStreamingExecutorPollIntervalMs, - readerFactories) - .asInstanceOf[RDD[InternalRow]] + partitions).asInstanceOf[RDD[InternalRow]] case r: SupportsScanColumnarBatch if r.enableBatchRead() => - new DataSourceRDD(sparkContext, batchReaderFactories).asInstanceOf[RDD[InternalRow]] + new DataSourceRDD(sparkContext, batchPartitions).asInstanceOf[RDD[InternalRow]] case _ => - new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]] + new DataSourceRDD(sparkContext, partitions).asInstanceOf[RDD[InternalRow]] } override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) @@ -132,19 +131,22 @@ case class DataSourceV2ScanExec( } } -class RowToUnsafeRowDataReaderFactory(rowReaderFactory: DataReaderFactory[Row], schema: StructType) - extends DataReaderFactory[UnsafeRow] { +class RowToUnsafeRowInputPartition(partition: InputPartition[Row], schema: StructType) + extends InputPartition[UnsafeRow] { - override def preferredLocations: Array[String] = rowReaderFactory.preferredLocations + override def preferredLocations: Array[String] = partition.preferredLocations - override def createDataReader: DataReader[UnsafeRow] = { - new RowToUnsafeDataReader( - rowReaderFactory.createDataReader, RowEncoder.apply(schema).resolveAndBind()) + override def createPartitionReader: InputPartitionReader[UnsafeRow] = { + new RowToUnsafeInputPartitionReader( + partition.createPartitionReader, RowEncoder.apply(schema).resolveAndBind()) } } -class RowToUnsafeDataReader(val rowReader: DataReader[Row], encoder: ExpressionEncoder[Row]) - extends DataReader[UnsafeRow] { +class RowToUnsafeInputPartitionReader( + val rowReader: InputPartitionReader[Row], + encoder: ExpressionEncoder[Row]) + + extends InputPartitionReader[UnsafeRow] { override def next: Boolean = rowReader.next diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala index 0a3b9dcccb6c5..a7ccce10b0cee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -21,14 +21,14 @@ import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeInputPartitionReader} import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, PartitionOffset} import org.apache.spark.util.{NextIterator, ThreadUtils} class ContinuousDataSourceRDDPartition( val index: Int, - val readerFactory: DataReaderFactory[UnsafeRow]) + val inputPartition: InputPartition[UnsafeRow]) extends Partition with Serializable { // This is semantically a lazy val - it's initialized once the first time a call to @@ -51,12 +51,12 @@ class ContinuousDataSourceRDD( sc: SparkContext, dataQueueSize: Int, epochPollIntervalMs: Long, - @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]]) + @transient private val readerFactories: Seq[InputPartition[UnsafeRow]]) extends RDD[UnsafeRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { readerFactories.zipWithIndex.map { - case (readerFactory, index) => new ContinuousDataSourceRDDPartition(index, readerFactory) + case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition) }.toArray } @@ -75,7 +75,7 @@ class ContinuousDataSourceRDD( if (partition.queueReader == null) { partition.queueReader = new ContinuousQueuedDataReader( - partition.readerFactory, context, dataQueueSize, epochPollIntervalMs) + partition.inputPartition, context, dataQueueSize, epochPollIntervalMs) } partition.queueReader @@ -96,17 +96,17 @@ class ContinuousDataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[ContinuousDataSourceRDDPartition].readerFactory.preferredLocations() + split.asInstanceOf[ContinuousDataSourceRDDPartition].inputPartition.preferredLocations() } } object ContinuousDataSourceRDD { private[continuous] def getContinuousReader( - reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = { + reader: InputPartitionReader[UnsafeRow]): ContinuousInputPartitionReader[_] = { reader match { - case r: ContinuousDataReader[UnsafeRow] => r - case wrapped: RowToUnsafeDataReader => - wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]] + case r: ContinuousInputPartitionReader[UnsafeRow] => r + case wrapped: RowToUnsafeInputPartitionReader => + wrapped.rowReader.asInstanceOf[ContinuousInputPartitionReader[Row]] case _ => throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala index 01a999f6505fc..d8645576c2052 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -18,15 +18,14 @@ package org.apache.spark.sql.execution.streaming.continuous import java.io.Closeable -import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit} -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.{ArrayBlockingQueue, TimeUnit} import scala.util.control.NonFatal -import org.apache.spark.{Partition, SparkEnv, SparkException, TaskContext} +import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset import org.apache.spark.util.ThreadUtils @@ -38,11 +37,11 @@ import org.apache.spark.util.ThreadUtils * offsets across epochs. Each compute() should call the next() method here until null is returned. */ class ContinuousQueuedDataReader( - factory: DataReaderFactory[UnsafeRow], + partition: InputPartition[UnsafeRow], context: TaskContext, dataQueueSize: Int, epochPollIntervalMs: Long) extends Closeable { - private val reader = factory.createDataReader() + private val reader = partition.createPartitionReader() // Important sequencing - we must get our starting point before the provider threads start running private var currentOffset: PartitionOffset = @@ -132,7 +131,7 @@ class ContinuousQueuedDataReader( /** * The data component of [[ContinuousQueuedDataReader]]. Pushes (row, offset) to the queue when - * a new row arrives to the [[DataReader]]. + * a new row arrives to the [[InputPartitionReader]]. */ class DataReaderThread extends Thread( s"continuous-reader--${context.partitionId()}--" + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 2f0de2612c150..8d25d9ccc43d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeM import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType case class RateStreamPartitionOffset( @@ -67,7 +67,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) override def getStartOffset(): Offset = offset - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + override def planInputPartitions(): java.util.List[InputPartition[Row]] = { val partitionStartMap = offset match { case off: RateStreamOffset => off.partitionToValueAndRunTimeMs case off => @@ -91,7 +91,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) i, numPartitions, perPartitionRate) - .asInstanceOf[DataReaderFactory[Row]] + .asInstanceOf[InputPartition[Row]] }.asJava } @@ -119,13 +119,13 @@ case class RateStreamContinuousDataReaderFactory( partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousDataReaderFactory[Row] { + extends ContinuousInputPartition[Row] { - override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[Row] = { + override def createContinuousReader(offset: PartitionOffset): InputPartitionReader[Row] = { val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset] require(rateStreamOffset.partition == partitionIndex, s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}") - new RateStreamContinuousDataReader( + new RateStreamContinuousInputPartitionReader( rateStreamOffset.currentValue, rateStreamOffset.currentTimeMs, partitionIndex, @@ -133,18 +133,18 @@ case class RateStreamContinuousDataReaderFactory( rowsPerSecond) } - override def createDataReader(): DataReader[Row] = - new RateStreamContinuousDataReader( + override def createPartitionReader(): InputPartitionReader[Row] = + new RateStreamContinuousInputPartitionReader( startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) } -class RateStreamContinuousDataReader( +class RateStreamContinuousInputPartitionReader( startValue: Long, startTimeMs: Long, partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousDataReader[Row] { + extends ContinuousInputPartitionReader[Row] { private var nextReadTime: Long = startTimeMs private val readTimeIncrement: Long = (1000 / rowsPerSecond).toLong diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 6720cdd24b1b2..daa2963220aef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -139,7 +139,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) if (endOffset.offset == -1) null else endOffset } - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { synchronized { // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) val startOrdinal = startOffset.offset.toInt + 1 @@ -156,7 +156,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) newBlocks.map { block => - new MemoryStreamDataReaderFactory(block).asInstanceOf[DataReaderFactory[UnsafeRow]] + new MemoryStreamDataReaderFactory(block).asInstanceOf[InputPartition[UnsafeRow]] }.asJava } } @@ -202,9 +202,9 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) class MemoryStreamDataReaderFactory(records: Array[UnsafeRow]) - extends DataReaderFactory[UnsafeRow] { - override def createDataReader(): DataReader[UnsafeRow] = { - new DataReader[UnsafeRow] { + extends InputPartition[UnsafeRow] { + override def createPartitionReader(): InputPartitionReader[UnsafeRow] = { + new InputPartitionReader[UnsafeRow] { private var currentIndex = -1 override def next(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index a8fca3c19a2d2..fef792eab69d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -34,8 +34,8 @@ import org.apache.spark.sql.{Encoder, Row, SQLContext} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions} -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.RpcUtils @@ -99,7 +99,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) ) } - override def createDataReaderFactories(): ju.List[DataReaderFactory[Row]] = { + override def planInputPartitions(): ju.List[InputPartition[Row]] = { synchronized { val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id" endpointRef = @@ -108,7 +108,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) startOffset.partitionNums.map { case (part, index) => new ContinuousMemoryStreamDataReaderFactory( - endpointName, part, index): DataReaderFactory[Row] + endpointName, part, index): InputPartition[Row] }.toList.asJava } } @@ -160,9 +160,9 @@ object ContinuousMemoryStream { class ContinuousMemoryStreamDataReaderFactory( driverEndpointName: String, partition: Int, - startOffset: Int) extends DataReaderFactory[Row] { - override def createDataReader: ContinuousMemoryStreamDataReader = - new ContinuousMemoryStreamDataReader(driverEndpointName, partition, startOffset) + startOffset: Int) extends InputPartition[Row] { + override def createPartitionReader: ContinuousMemoryStreamInputPartitionReader = + new ContinuousMemoryStreamInputPartitionReader(driverEndpointName, partition, startOffset) } /** @@ -170,10 +170,10 @@ class ContinuousMemoryStreamDataReaderFactory( * * Polls the driver endpoint for new records. */ -class ContinuousMemoryStreamDataReader( +class ContinuousMemoryStreamInputPartitionReader( driverEndpointName: String, partition: Int, - startOffset: Int) extends ContinuousDataReader[Row] { + startOffset: Int) extends ContinuousInputPartitionReader[Row] { private val endpoint = RpcUtils.makeDriverRef( driverEndpointName, SparkEnv.get.conf, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index f54291bea6678..723cc3ad5bb89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -134,7 +134,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: LongOffset(json.toLong) } - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + override def planInputPartitions(): java.util.List[InputPartition[Row]] = { val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") @@ -169,7 +169,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: (0 until numPartitions).map { p => new RateStreamMicroBatchDataReaderFactory( p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) - : DataReaderFactory[Row] + : InputPartition[Row] }.toList.asJava } @@ -188,19 +188,20 @@ class RateStreamMicroBatchDataReaderFactory( rangeStart: Long, rangeEnd: Long, localStartTimeMs: Long, - relativeMsPerValue: Double) extends DataReaderFactory[Row] { + relativeMsPerValue: Double) extends InputPartition[Row] { - override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader( - partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) + override def createPartitionReader(): InputPartitionReader[Row] = + new RateStreamMicroBatchInputPartitionReader(partitionId, numPartitions, rangeStart, rangeEnd, + localStartTimeMs, relativeMsPerValue) } -class RateStreamMicroBatchDataReader( +class RateStreamMicroBatchInputPartitionReader( partitionId: Int, numPartitions: Int, rangeStart: Long, rangeEnd: Long, localStartTimeMs: Long, - relativeMsPerValue: Double) extends DataReader[Row] { + relativeMsPerValue: Double) extends InputPartitionReader[Row] { private var count = 0 override def next(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 90f4a5ba4234d..8240e06d4ab72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.LongOffset import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} @@ -140,7 +140,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR } } - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def planInputPartitions(): JList[InputPartition[Row]] = { assert(startOffset != null && endOffset != null, "start offset and end offset should already be set before create read tasks.") @@ -165,21 +165,22 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR (0 until numPartitions).map { i => val slice = slices(i) - new DataReaderFactory[Row] { - override def createDataReader(): DataReader[Row] = new DataReader[Row] { - private var currentIdx = -1 + new InputPartition[Row] { + override def createPartitionReader(): InputPartitionReader[Row] = + new InputPartitionReader[Row] { + private var currentIdx = -1 + + override def next(): Boolean = { + currentIdx += 1 + currentIdx < slice.size + } - override def next(): Boolean = { - currentIdx += 1 - currentIdx < slice.size - } + override def get(): Row = { + Row(slice(currentIdx)._1, slice(currentIdx)._2) + } - override def get(): Row = { - Row(slice(currentIdx)._1, slice(currentIdx)._2) + override def close(): Unit = {} } - - override def close(): Unit = {} - } } }.toList.asJava } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 172e5d5eebcbe..714638e500c94 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -79,8 +79,8 @@ public Filter[] pushedFilters() { } @Override - public List> createDataReaderFactories() { - List> res = new ArrayList<>(); + public List> planInputPartitions() { + List> res = new ArrayList<>(); Integer lowerBound = null; for (Filter filter : filters) { @@ -94,33 +94,33 @@ public List> createDataReaderFactories() { } if (lowerBound == null) { - res.add(new JavaAdvancedDataReaderFactory(0, 5, requiredSchema)); - res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema)); + res.add(new JavaAdvancedInputPartition(0, 5, requiredSchema)); + res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema)); } else if (lowerBound < 4) { - res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 5, requiredSchema)); - res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema)); + res.add(new JavaAdvancedInputPartition(lowerBound + 1, 5, requiredSchema)); + res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema)); } else if (lowerBound < 9) { - res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 10, requiredSchema)); + res.add(new JavaAdvancedInputPartition(lowerBound + 1, 10, requiredSchema)); } return res; } } - static class JavaAdvancedDataReaderFactory implements DataReaderFactory, DataReader { + static class JavaAdvancedInputPartition implements InputPartition, InputPartitionReader { private int start; private int end; private StructType requiredSchema; - JavaAdvancedDataReaderFactory(int start, int end, StructType requiredSchema) { + JavaAdvancedInputPartition(int start, int end, StructType requiredSchema) { this.start = start; this.end = end; this.requiredSchema = requiredSchema; } @Override - public DataReader createDataReader() { - return new JavaAdvancedDataReaderFactory(start - 1, end, requiredSchema); + public InputPartitionReader createPartitionReader() { + return new JavaAdvancedInputPartition(start - 1, end, requiredSchema); } @Override diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java index c55093768105b..97d6176d02559 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -42,14 +42,14 @@ public StructType readSchema() { } @Override - public List> createBatchDataReaderFactories() { + public List> planBatchInputPartitions() { return java.util.Arrays.asList( - new JavaBatchDataReaderFactory(0, 50), new JavaBatchDataReaderFactory(50, 90)); + new JavaBatchInputPartition(0, 50), new JavaBatchInputPartition(50, 90)); } } - static class JavaBatchDataReaderFactory - implements DataReaderFactory, DataReader { + static class JavaBatchInputPartition + implements InputPartition, InputPartitionReader { private int start; private int end; @@ -59,13 +59,13 @@ static class JavaBatchDataReaderFactory private OnHeapColumnVector j; private ColumnarBatch batch; - JavaBatchDataReaderFactory(int start, int end) { + JavaBatchInputPartition(int start, int end) { this.start = start; this.end = end; } @Override - public DataReader createDataReader() { + public InputPartitionReader createPartitionReader() { this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); ColumnVector[] vectors = new ColumnVector[2]; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index 32fad59b97ff6..e49c8cf8b9e16 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -43,10 +43,10 @@ public StructType readSchema() { } @Override - public List> createDataReaderFactories() { + public List> planInputPartitions() { return java.util.Arrays.asList( - new SpecificDataReaderFactory(new int[]{1, 1, 3}, new int[]{4, 4, 6}), - new SpecificDataReaderFactory(new int[]{2, 4, 4}, new int[]{6, 2, 2})); + new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}), + new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2})); } @Override @@ -73,12 +73,12 @@ public boolean satisfy(Distribution distribution) { } } - static class SpecificDataReaderFactory implements DataReaderFactory, DataReader { + static class SpecificInputPartition implements InputPartition, InputPartitionReader { private int[] i; private int[] j; private int current = -1; - SpecificDataReaderFactory(int[] i, int[] j) { + SpecificInputPartition(int[] i, int[] j) { assert i.length == j.length; this.i = i; this.j = j; @@ -101,7 +101,7 @@ public void close() throws IOException { } @Override - public DataReader createDataReader() { + public InputPartitionReader createPartitionReader() { return this; } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index 048d078dfaac4..80eeffd95f83b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -24,7 +24,7 @@ import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; +import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.types.StructType; public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema { @@ -42,7 +42,7 @@ public StructType readSchema() { } @Override - public List> createDataReaderFactories() { + public List> planInputPartitions() { return java.util.Collections.emptyList(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 96f55b8a76811..8522a63898a3b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -25,8 +25,8 @@ import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.reader.DataReader; -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; +import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.types.StructType; @@ -41,25 +41,25 @@ public StructType readSchema() { } @Override - public List> createDataReaderFactories() { + public List> planInputPartitions() { return java.util.Arrays.asList( - new JavaSimpleDataReaderFactory(0, 5), - new JavaSimpleDataReaderFactory(5, 10)); + new JavaSimpleInputPartition(0, 5), + new JavaSimpleInputPartition(5, 10)); } } - static class JavaSimpleDataReaderFactory implements DataReaderFactory, DataReader { + static class JavaSimpleInputPartition implements InputPartition, InputPartitionReader { private int start; private int end; - JavaSimpleDataReaderFactory(int start, int end) { + JavaSimpleInputPartition(int start, int end) { this.start = start; this.end = end; } @Override - public DataReader createDataReader() { - return new JavaSimpleDataReaderFactory(start - 1, end); + public InputPartitionReader createPartitionReader() { + return new JavaSimpleInputPartition(start - 1, end); } @Override diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java index c3916e0b370b5..3ad8e7a0104ce 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java @@ -38,20 +38,20 @@ public StructType readSchema() { } @Override - public List> createUnsafeRowReaderFactories() { + public List> planUnsafeInputPartitions() { return java.util.Arrays.asList( - new JavaUnsafeRowDataReaderFactory(0, 5), - new JavaUnsafeRowDataReaderFactory(5, 10)); + new JavaUnsafeRowInputPartition(0, 5), + new JavaUnsafeRowInputPartition(5, 10)); } } - static class JavaUnsafeRowDataReaderFactory - implements DataReaderFactory, DataReader { + static class JavaUnsafeRowInputPartition + implements InputPartition, InputPartitionReader { private int start; private int end; private UnsafeRow row; - JavaUnsafeRowDataReaderFactory(int start, int end) { + JavaUnsafeRowInputPartition(int start, int end) { this.start = start; this.end = end; this.row = new UnsafeRow(2); @@ -59,8 +59,8 @@ static class JavaUnsafeRowDataReaderFactory } @Override - public DataReader createDataReader() { - return new JavaUnsafeRowDataReaderFactory(start - 1, end); + public InputPartitionReader createPartitionReader() { + return new JavaUnsafeRowInputPartition(start - 1, end); } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index ff14ec38e66a8..39a010f970ce5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -142,9 +142,9 @@ class RateSourceSuite extends StreamTest { val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() + val tasks = reader.planInputPartitions() assert(tasks.size == 1) - val dataReader = tasks.get(0).createDataReader() + val dataReader = tasks.get(0).createPartitionReader() val data = ArrayBuffer[Row]() while (dataReader.next()) { data.append(dataReader.get()) @@ -159,11 +159,11 @@ class RateSourceSuite extends StreamTest { val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() + val tasks = reader.planInputPartitions() assert(tasks.size == 11) val readData = tasks.asScala - .map(_.createDataReader()) + .map(_.createPartitionReader()) .flatMap { reader => val buf = scala.collection.mutable.ListBuffer[Row]() while (reader.next()) buf.append(reader.get()) @@ -304,7 +304,7 @@ class RateSourceSuite extends StreamTest { val reader = new RateStreamContinuousReader( new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) reader.setStartOffset(Optional.empty()) - val tasks = reader.createDataReaderFactories() + val tasks = reader.planInputPartitions() assert(tasks.size == 2) val data = scala.collection.mutable.ListBuffer[Row]() @@ -314,7 +314,7 @@ class RateSourceSuite extends StreamTest { .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] + val r = t.createPartitionReader().asInstanceOf[RateStreamContinuousInputPartitionReader] for (rowIndex <- 0 to 9) { r.next() data.append(r.get()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index e0a53272cd222..505a3f3465c02 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -346,8 +346,8 @@ class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { - java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5)) + override def planInputPartitions(): JList[InputPartition[Row]] = { + java.util.Arrays.asList(new SimpleInputPartition(0, 5)) } } @@ -359,20 +359,21 @@ class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { - java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5), new SimpleDataReaderFactory(5, 10)) + override def planInputPartitions(): JList[InputPartition[Row]] = { + java.util.Arrays.asList(new SimpleInputPartition(0, 5), new SimpleInputPartition(5, 10)) } } override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class SimpleDataReaderFactory(start: Int, end: Int) - extends DataReaderFactory[Row] - with DataReader[Row] { +class SimpleInputPartition(start: Int, end: Int) + extends InputPartition[Row] + with InputPartitionReader[Row] { private var current = start - 1 - override def createDataReader(): DataReader[Row] = new SimpleDataReaderFactory(start, end) + override def createPartitionReader(): InputPartitionReader[Row] = + new SimpleInputPartition(start, end) override def next(): Boolean = { current += 1 @@ -413,21 +414,21 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { requiredSchema } - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def planInputPartitions(): JList[InputPartition[Row]] = { val lowerBound = filters.collect { case GreaterThan("i", v: Int) => v }.headOption - val res = new ArrayList[DataReaderFactory[Row]] + val res = new ArrayList[InputPartition[Row]] if (lowerBound.isEmpty) { - res.add(new AdvancedDataReaderFactory(0, 5, requiredSchema)) - res.add(new AdvancedDataReaderFactory(5, 10, requiredSchema)) + res.add(new AdvancedInputPartition(0, 5, requiredSchema)) + res.add(new AdvancedInputPartition(5, 10, requiredSchema)) } else if (lowerBound.get < 4) { - res.add(new AdvancedDataReaderFactory(lowerBound.get + 1, 5, requiredSchema)) - res.add(new AdvancedDataReaderFactory(5, 10, requiredSchema)) + res.add(new AdvancedInputPartition(lowerBound.get + 1, 5, requiredSchema)) + res.add(new AdvancedInputPartition(5, 10, requiredSchema)) } else if (lowerBound.get < 9) { - res.add(new AdvancedDataReaderFactory(lowerBound.get + 1, 10, requiredSchema)) + res.add(new AdvancedInputPartition(lowerBound.get + 1, 10, requiredSchema)) } res @@ -437,13 +438,13 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class AdvancedDataReaderFactory(start: Int, end: Int, requiredSchema: StructType) - extends DataReaderFactory[Row] with DataReader[Row] { +class AdvancedInputPartition(start: Int, end: Int, requiredSchema: StructType) + extends InputPartition[Row] with InputPartitionReader[Row] { private var current = start - 1 - override def createDataReader(): DataReader[Row] = { - new AdvancedDataReaderFactory(start, end, requiredSchema) + override def createPartitionReader(): InputPartitionReader[Row] = { + new AdvancedInputPartition(start, end, requiredSchema) } override def close(): Unit = {} @@ -468,24 +469,24 @@ class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader with SupportsScanUnsafeRow { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createUnsafeRowReaderFactories(): JList[DataReaderFactory[UnsafeRow]] = { - java.util.Arrays.asList(new UnsafeRowDataReaderFactory(0, 5), - new UnsafeRowDataReaderFactory(5, 10)) + override def planUnsafeInputPartitions(): JList[InputPartition[UnsafeRow]] = { + java.util.Arrays.asList(new UnsafeRowInputPartitionReader(0, 5), + new UnsafeRowInputPartitionReader(5, 10)) } } override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class UnsafeRowDataReaderFactory(start: Int, end: Int) - extends DataReaderFactory[UnsafeRow] with DataReader[UnsafeRow] { +class UnsafeRowInputPartitionReader(start: Int, end: Int) + extends InputPartition[UnsafeRow] with InputPartitionReader[UnsafeRow] { private val row = new UnsafeRow(2) row.pointTo(new Array[Byte](8 * 3), 8 * 3) private var current = start - 1 - override def createDataReader(): DataReader[UnsafeRow] = this + override def createPartitionReader(): InputPartitionReader[UnsafeRow] = this override def next(): Boolean = { current += 1 @@ -503,7 +504,7 @@ class UnsafeRowDataReaderFactory(start: Int, end: Int) class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { class Reader(val readSchema: StructType) extends DataSourceReader { - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = + override def planInputPartitions(): JList[InputPartition[Row]] = java.util.Collections.emptyList() } @@ -516,16 +517,17 @@ class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader with SupportsScanColumnarBatch { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createBatchDataReaderFactories(): JList[DataReaderFactory[ColumnarBatch]] = { - java.util.Arrays.asList(new BatchDataReaderFactory(0, 50), new BatchDataReaderFactory(50, 90)) + override def planBatchInputPartitions(): JList[InputPartition[ColumnarBatch]] = { + java.util.Arrays.asList( + new BatchInputPartitionReader(0, 50), new BatchInputPartitionReader(50, 90)) } } override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class BatchDataReaderFactory(start: Int, end: Int) - extends DataReaderFactory[ColumnarBatch] with DataReader[ColumnarBatch] { +class BatchInputPartitionReader(start: Int, end: Int) + extends InputPartition[ColumnarBatch] with InputPartitionReader[ColumnarBatch] { private final val BATCH_SIZE = 20 private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) @@ -534,7 +536,7 @@ class BatchDataReaderFactory(start: Int, end: Int) private var current = start - override def createDataReader(): DataReader[ColumnarBatch] = this + override def createPartitionReader(): InputPartitionReader[ColumnarBatch] = this override def next(): Boolean = { i.reset() @@ -568,11 +570,11 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader with SupportsReportPartitioning { override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def planInputPartitions(): JList[InputPartition[Row]] = { // Note that we don't have same value of column `a` across partitions. java.util.Arrays.asList( - new SpecificDataReaderFactory(Array(1, 1, 3), Array(4, 4, 6)), - new SpecificDataReaderFactory(Array(2, 4, 4), Array(6, 2, 2))) + new SpecificInputPartitionReader(Array(1, 1, 3), Array(4, 4, 6)), + new SpecificInputPartitionReader(Array(2, 4, 4), Array(6, 2, 2))) } override def outputPartitioning(): Partitioning = new MyPartitioning @@ -590,14 +592,14 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class SpecificDataReaderFactory(i: Array[Int], j: Array[Int]) - extends DataReaderFactory[Row] - with DataReader[Row] { +class SpecificInputPartitionReader(i: Array[Int], j: Array[Int]) + extends InputPartition[Row] + with InputPartitionReader[Row] { assert(i.length == j.length) private var current = -1 - override def createDataReader(): DataReader[Row] = this + override def createPartitionReader(): InputPartitionReader[Row] = this override def next(): Boolean = { current += 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index a5007fa321359..694bb3b95b0f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceReader} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -45,7 +45,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS class Reader(path: String, conf: Configuration) extends DataSourceReader { override def readSchema(): StructType = schema - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def planInputPartitions(): JList[InputPartition[Row]] = { val dataPath = new Path(path) val fs = dataPath.getFileSystem(conf) if (fs.exists(dataPath)) { @@ -54,9 +54,9 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS name.startsWith("_") || name.startsWith(".") }.map { f => val serializableConf = new SerializableConfiguration(conf) - new SimpleCSVDataReaderFactory( + new SimpleCSVInputPartitionReader( f.getPath.toUri.toString, - serializableConf): DataReaderFactory[Row] + serializableConf): InputPartition[Row] }.toList.asJava } else { Collections.emptyList() @@ -156,14 +156,14 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } } -class SimpleCSVDataReaderFactory(path: String, conf: SerializableConfiguration) - extends DataReaderFactory[Row] with DataReader[Row] { +class SimpleCSVInputPartitionReader(path: String, conf: SerializableConfiguration) + extends InputPartition[Row] with InputPartitionReader[Row] { @transient private var lines: Iterator[String] = _ @transient private var currentLine: String = _ @transient private var inputStream: FSDataInputStream = _ - override def createDataReader(): DataReader[Row] = { + override def createPartitionReader(): InputPartitionReader[Row] = { val filePath = new Path(path) val fs = filePath.getFileSystem(conf.value) inputStream = fs.open(filePath) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 57986999bf861..dcf6cb5d609ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.InputPartition import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType @@ -227,10 +227,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } // getBatch should take 100 ms the first time it is called - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { synchronized { clock.waitTillTime(1350) - super.createUnsafeRowReaderFactories() + super.planUnsafeInputPartitions() } } } @@ -290,13 +290,14 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AdvanceManualClock(100), // time = 1150 to unblock getEndOffset AssertClockTime(1150), - AssertStreamExecThreadIsWaitingForTime(1350), // will block on createReadTasks that needs 1350 + // will block on planInputPartitions that needs 1350 + AssertStreamExecThreadIsWaitingForTime(1350), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - AdvanceManualClock(200), // time = 1350 to unblock createReadTasks + AdvanceManualClock(200), // time = 1350 to unblock planInputPartitions AssertClockTime(1350), AssertStreamExecThreadIsWaitingForTime(1500), // will block on map task that needs 1500 AssertOnQuery(_.status.isDataAvailable === true), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index e755625d09e0f..f47d3ec8ae025 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -27,8 +27,8 @@ import org.apache.spark.{SparkEnv, SparkFunSuite, TaskContext} import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv} import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{DataType, IntegerType} @@ -72,8 +72,8 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { */ private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = { val queue = new ArrayBlockingQueue[UnsafeRow](1024) - val factory = new DataReaderFactory[UnsafeRow] { - override def createDataReader() = new ContinuousDataReader[UnsafeRow] { + val factory = new InputPartition[UnsafeRow] { + override def createPartitionReader() = new ContinuousInputPartitionReader[UnsafeRow] { var index = -1 var curr: UnsafeRow = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index af4618bed5456..c1a28b9bc75ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.InputPartition import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} @@ -44,7 +44,7 @@ case class FakeReader() extends MicroBatchReader with ContinuousReader { def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) def setStartOffset(start: Optional[Offset]): Unit = {} - def createDataReaderFactories(): java.util.ArrayList[DataReaderFactory[Row]] = { + def planInputPartitions(): java.util.ArrayList[InputPartition[Row]] = { throw new IllegalStateException("fake source - cannot actually read") } } From e3d434994733ae16e7e1424fb6de2d22b1a13f99 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 10 May 2018 13:36:52 +0800 Subject: [PATCH 0774/2461] [SPARK-22279][SQL] Enable `convertMetastoreOrc` by default ## What changes were proposed in this pull request? We reverted `spark.sql.hive.convertMetastoreOrc` at https://github.com/apache/spark/pull/20536 because we should not ignore the table-specific compression conf. Now, it's resolved via [SPARK-23355](https://github.com/apache/spark/commit/8aa1d7b0ede5115297541d29eab4ce5f4fe905cb). ## How was this patch tested? Pass the Jenkins. Author: Dongjoon Hyun Closes #21186 from dongjoon-hyun/SPARK-24112. --- docs/sql-programming-guide.md | 3 ++- .../src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3e8946e424237..3f79ed6422205 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1017,7 +1017,7 @@ the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is also - + @@ -1813,6 +1813,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. + - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 10c9603745379..bb134bbe68bd9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -105,11 +105,10 @@ private[spark] object HiveUtils extends Logging { .createWithDefault(false) val CONVERT_METASTORE_ORC = buildConf("spark.sql.hive.convertMetastoreOrc") - .internal() .doc("When set to true, the built-in ORC reader and writer are used to process " + "ORC tables created by using the HiveQL syntax, instead of Hive serde.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val HIVE_METASTORE_SHARED_PREFIXES = buildConf("spark.sql.hive.metastore.sharedPrefixes") .doc("A comma separated list of class prefixes that should be loaded using the classloader " + From 94d671448240c8f6da11d2523ba9e4ae5b56a410 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 10 May 2018 20:38:52 +0900 Subject: [PATCH 0775/2461] [SPARK-23907][SQL] Add regr_* functions ## What changes were proposed in this pull request? The PR introduces regr_slope, regr_intercept, regr_r2, regr_sxx, regr_syy, regr_sxy, regr_avgx, regr_avgy, regr_count. The implementation of this functions mirrors Hive's one in HIVE-15978. ## How was this patch tested? added UT (values compared with Hive) Author: Marco Gaido Closes #21054 from mgaido91/SPARK-23907. --- .../catalyst/analysis/FunctionRegistry.scala | 9 + .../expressions/aggregate/Average.scala | 47 +++-- .../aggregate/CentralMomentAgg.scala | 60 +++--- .../catalyst/expressions/aggregate/Corr.scala | 52 ++--- .../expressions/aggregate/Count.scala | 47 +++-- .../expressions/aggregate/Covariance.scala | 36 ++-- .../expressions/aggregate/regression.scala | 190 ++++++++++++++++++ .../org/apache/spark/sql/functions.scala | 172 ++++++++++++++++ .../sql-tests/inputs/udaf-regrfunctions.sql | 56 ++++++ .../results/udaf-regrfunctions.sql.out | 93 +++++++++ .../spark/sql/DataFrameAggregateSuite.scala | 71 ++++++- 11 files changed, 721 insertions(+), 112 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala create mode 100644 sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 87b0911e150c5..087d000a9db70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -299,6 +299,15 @@ object FunctionRegistry { expression[CollectList]("collect_list"), expression[CollectSet]("collect_set"), expression[CountMinSketchAgg]("count_min_sketch"), + expression[RegrCount]("regr_count"), + expression[RegrSXX]("regr_sxx"), + expression[RegrSYY]("regr_syy"), + expression[RegrAvgX]("regr_avgx"), + expression[RegrAvgY]("regr_avgy"), + expression[RegrSXY]("regr_sxy"), + expression[RegrSlope]("regr_slope"), + expression[RegrR2]("regr_r2"), + expression[RegrIntercept]("regr_intercept"), // string functions expression[Ascii]("ascii"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 708bdbfc36058..a133bc2361eb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -23,24 +23,12 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -@ExpressionDescription( - usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.") -case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { - - override def prettyName: String = "avg" - - override def children: Seq[Expression] = child :: Nil +abstract class AverageLike(child: Expression) extends DeclarativeAggregate { override def nullable: Boolean = true - // Return data type. override def dataType: DataType = resultType - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function average") - private lazy val resultType = child.dataType match { case DecimalType.Fixed(p, s) => DecimalType.bounded(p + 4, s + 4) @@ -62,14 +50,6 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit /* count = */ Literal(0L) ) - override lazy val updateExpressions = Seq( - /* sum = */ - Add( - sum, - Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), - /* count = */ If(IsNull(child), count, count + 1L) - ) - override lazy val mergeExpressions = Seq( /* sum = */ sum.left + sum.right, /* count = */ count.left + count.right @@ -85,4 +65,29 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit case _ => Cast(sum, resultType) / Cast(count, resultType) } + + protected def updateExpressionsDef: Seq[Expression] = Seq( + /* sum = */ + Add( + sum, + Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), + /* count = */ If(IsNull(child), count, count + 1L) + ) + + override lazy val updateExpressions = updateExpressionsDef +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.") +case class Average(child: Expression) + extends AverageLike(child) with ImplicitCastInputTypes { + + override def prettyName: String = "avg" + + override def children: Seq[Expression] = child :: Nil + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function average") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 572d29caf5bc9..6bbb083f1e18e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -67,35 +67,7 @@ abstract class CentralMomentAgg(child: Expression) override val initialValues: Seq[Expression] = Array.fill(momentOrder + 1)(Literal(0.0)) - override val updateExpressions: Seq[Expression] = { - val newN = n + Literal(1.0) - val delta = child - avg - val deltaN = delta / newN - val newAvg = avg + deltaN - val newM2 = m2 + delta * (delta - deltaN) - - val delta2 = delta * delta - val deltaN2 = deltaN * deltaN - val newM3 = if (momentOrder >= 3) { - m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2) - } else { - Literal(0.0) - } - val newM4 = if (momentOrder >= 4) { - m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 + - delta * (delta * delta2 - deltaN * deltaN2) - } else { - Literal(0.0) - } - - trimHigherOrder(Seq( - If(IsNull(child), n, newN), - If(IsNull(child), avg, newAvg), - If(IsNull(child), m2, newM2), - If(IsNull(child), m3, newM3), - If(IsNull(child), m4, newM4) - )) - } + override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef override val mergeExpressions: Seq[Expression] = { @@ -128,6 +100,36 @@ abstract class CentralMomentAgg(child: Expression) trimHigherOrder(Seq(newN, newAvg, newM2, newM3, newM4)) } + + protected def updateExpressionsDef: Seq[Expression] = { + val newN = n + Literal(1.0) + val delta = child - avg + val deltaN = delta / newN + val newAvg = avg + deltaN + val newM2 = m2 + delta * (delta - deltaN) + + val delta2 = delta * delta + val deltaN2 = deltaN * deltaN + val newM3 = if (momentOrder >= 3) { + m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2) + } else { + Literal(0.0) + } + val newM4 = if (momentOrder >= 4) { + m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 + + delta * (delta * delta2 - deltaN * deltaN2) + } else { + Literal(0.0) + } + + trimHigherOrder(Seq( + If(IsNull(child), n, newN), + If(IsNull(child), avg, newAvg), + If(IsNull(child), m2, newM2), + If(IsNull(child), m3, newM3), + If(IsNull(child), m4, newM4) + )) + } } // Compute the population standard deviation of a column diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index 95a4a0d5af634..3cdef72c1f2c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -22,17 +22,13 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ /** - * Compute Pearson correlation between two expressions. + * Base class for computing Pearson correlation between two expressions. * When applied on empty data (i.e., count is zero), it returns NULL. * * Definition of Pearson correlation can be found at * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs.") -// scalastyle:on line.size.limit -case class Corr(x: Expression, y: Expression) +abstract class PearsonCorrelation(x: Expression, y: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { override def children: Seq[Expression] = Seq(x, y) @@ -51,7 +47,26 @@ case class Corr(x: Expression, y: Expression) override val initialValues: Seq[Expression] = Array.fill(6)(Literal(0.0)) - override val updateExpressions: Seq[Expression] = { + override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef + + override val mergeExpressions: Seq[Expression] = { + val n1 = n.left + val n2 = n.right + val newN = n1 + n2 + val dx = xAvg.right - xAvg.left + val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) + val dy = yAvg.right - yAvg.left + val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) + val newXAvg = xAvg.left + dxN * n2 + val newYAvg = yAvg.left + dyN * n2 + val newCk = ck.left + ck.right + dx * dyN * n1 * n2 + val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2 + val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2 + + Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk) + } + + protected def updateExpressionsDef: Seq[Expression] = { val newN = n + Literal(1.0) val dx = x - xAvg val dxN = dx / newN @@ -73,24 +88,15 @@ case class Corr(x: Expression, y: Expression) If(isNull, yMk, newYMk) ) } +} - override val mergeExpressions: Seq[Expression] = { - - val n1 = n.left - val n2 = n.right - val newN = n1 + n2 - val dx = xAvg.right - xAvg.left - val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) - val dy = yAvg.right - yAvg.left - val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) - val newXAvg = xAvg.left + dxN * n2 - val newYAvg = yAvg.left + dyN * n2 - val newCk = ck.left + ck.right + dx * dyN * n1 * n2 - val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2 - val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2 - Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk) - } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs.") +// scalastyle:on line.size.limit +case class Corr(x: Expression, y: Expression) + extends PearsonCorrelation(x, y) { override val evaluateExpression: Expression = { If(n === Literal(0.0), Literal.create(null, DoubleType), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 1990f2f2f0722..40582d0abd762 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -21,24 +21,16 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = """ - _FUNC_(*) - Returns the total number of retrieved rows, including rows containing null. - - _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-null. - - _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-null. - """) -// scalastyle:on line.size.limit -case class Count(children: Seq[Expression]) extends DeclarativeAggregate { - +/** + * Base class for all counting aggregators. + */ +abstract class CountLike extends DeclarativeAggregate { override def nullable: Boolean = false // Return data type. override def dataType: DataType = LongType - private lazy val count = AttributeReference("count", LongType, nullable = false)() + protected lazy val count = AttributeReference("count", LongType, nullable = false)() override lazy val aggBufferAttributes = count :: Nil @@ -46,6 +38,27 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { /* count = */ Literal(0L) ) + override lazy val mergeExpressions = Seq( + /* count = */ count.left + count.right + ) + + override lazy val evaluateExpression = count + + override def defaultResult: Option[Literal] = Option(Literal(0L)) +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(*) - Returns the total number of retrieved rows, including rows containing null. + + _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-null. + + _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-null. + """) +// scalastyle:on line.size.limit +case class Count(children: Seq[Expression]) extends CountLike { + override lazy val updateExpressions = { val nullableChildren = children.filter(_.nullable) if (nullableChildren.isEmpty) { @@ -58,14 +71,6 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { ) } } - - override lazy val mergeExpressions = Seq( - /* count = */ count.left + count.right - ) - - override lazy val evaluateExpression = count - - override def defaultResult: Option[Literal] = Option(Literal(0L)) } object Count { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index fc6c34baafdd1..72a7c62b328ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -42,23 +42,7 @@ abstract class Covariance(x: Expression, y: Expression) override val initialValues: Seq[Expression] = Array.fill(4)(Literal(0.0)) - override lazy val updateExpressions: Seq[Expression] = { - val newN = n + Literal(1.0) - val dx = x - xAvg - val dy = y - yAvg - val dyN = dy / newN - val newXAvg = xAvg + dx / newN - val newYAvg = yAvg + dyN - val newCk = ck + dx * (y - newYAvg) - - val isNull = IsNull(x) || IsNull(y) - Seq( - If(isNull, n, newN), - If(isNull, xAvg, newXAvg), - If(isNull, yAvg, newYAvg), - If(isNull, ck, newCk) - ) - } + override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef override val mergeExpressions: Seq[Expression] = { @@ -75,6 +59,24 @@ abstract class Covariance(x: Expression, y: Expression) Seq(newN, newXAvg, newYAvg, newCk) } + + protected def updateExpressionsDef: Seq[Expression] = { + val newN = n + Literal(1.0) + val dx = x - xAvg + val dy = y - yAvg + val dyN = dy / newN + val newXAvg = xAvg + dx / newN + val newYAvg = yAvg + dyN + val newCk = ck + dx * (y - newYAvg) + + val isNull = IsNull(x) || IsNull(y) + Seq( + If(isNull, n, newN), + If(isNull, xAvg, newXAvg), + If(isNull, yAvg, newYAvg), + If(isNull, ck, newCk) + ) + } } @ExpressionDescription( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala new file mode 100644 index 0000000000000..d8f4505588ff2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{AbstractDataType, DoubleType} + +/** + * Base trait for all regression functions. + */ +trait RegrLike extends AggregateFunction with ImplicitCastInputTypes { + def y: Expression + def x: Expression + + override def children: Seq[Expression] = Seq(y, x) + override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) + + protected def updateIfNotNull(exprs: Seq[Expression]): Seq[Expression] = { + assert(aggBufferAttributes.length == exprs.length) + val nullableChildren = children.filter(_.nullable) + if (nullableChildren.isEmpty) { + exprs + } else { + exprs.zip(aggBufferAttributes).map { case (e, a) => + If(nullableChildren.map(IsNull).reduce(Or), a, e) + } + } + } +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the number of non-null pairs.", + since = "2.4.0") +case class RegrCount(y: Expression, x: Expression) + extends CountLike with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(Seq(count + 1L)) + + override def prettyName: String = "regr_count" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrSXX(y: Expression, x: Expression) + extends CentralMomentAgg(x) with RegrLike { + + override protected def momentOrder = 2 + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), m2) + } + + override def prettyName: String = "regr_sxx" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrSYY(y: Expression, x: Expression) + extends CentralMomentAgg(y) with RegrLike { + + override protected def momentOrder = 2 + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), m2) + } + + override def prettyName: String = "regr_syy" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the average of x. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrAvgX(y: Expression, x: Expression) + extends AverageLike(x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override def prettyName: String = "regr_avgx" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the average of y. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrAvgY(y: Expression, x: Expression) + extends AverageLike(y) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override def prettyName: String = "regr_avgy" +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the covariance of y and x multiplied for the number of items in the dataset. Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrSXY(y: Expression, x: Expression) + extends Covariance(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), ck) + } + + override def prettyName: String = "regr_sxy" +} + + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the slope of the linear regression line. Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrSlope(y: Expression, x: Expression) + extends PearsonCorrelation(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), ck / yMk) + } + + override def prettyName: String = "regr_slope" +} + + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the coefficient of determination (also called R-squared or goodness of fit) for the regression line. Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrR2(y: Expression, x: Expression) + extends PearsonCorrelation(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), + If(xMk === Literal(0.0), Literal(1.0), ck * ck / yMk / xMk)) + } + + override def prettyName: String = "regr_r2" +} + + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the y-intercept of the linear regression line. Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrIntercept(y: Expression, x: Expression) + extends PearsonCorrelation(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), + xAvg - (ck / yMk) * yAvg) + } + + override def prettyName: String = "regr_intercept" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8f9e4ae18b3f1..28cf705eb9700 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -775,6 +775,178 @@ object functions { */ def var_pop(columnName: String): Column = var_pop(Column(columnName)) + /** + * Aggregate function: returns the number of non-null pairs. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_count(y: Column, x: Column): Column = withAggregateFunction { + RegrCount(y.expr, x.expr) + } + + /** + * Aggregate function: returns the number of non-null pairs. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_count(y: String, x: String): Column = regr_count(Column(y), Column(x)) + + /** + * Aggregate function: returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_sxx(y: Column, x: Column): Column = withAggregateFunction { + RegrSXX(y.expr, x.expr) + } + + /** + * Aggregate function: returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_sxx(y: String, x: String): Column = regr_sxx(Column(y), Column(x)) + + /** + * Aggregate function: returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_syy(y: Column, x: Column): Column = withAggregateFunction { + RegrSYY(y.expr, x.expr) + } + + /** + * Aggregate function: returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_syy(y: String, x: String): Column = regr_syy(Column(y), Column(x)) + + /** + * Aggregate function: returns the average of y. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_avgy(y: Column, x: Column): Column = withAggregateFunction { + RegrAvgY(y.expr, x.expr) + } + + /** + * Aggregate function: returns the average of y. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_avgy(y: String, x: String): Column = regr_avgy(Column(y), Column(x)) + + /** + * Aggregate function: returns the average of x. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_avgx(y: Column, x: Column): Column = withAggregateFunction { + RegrAvgX(y.expr, x.expr) + } + + /** + * Aggregate function: returns the average of x. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_avgx(y: String, x: String): Column = regr_avgx(Column(y), Column(x)) + + /** + * Aggregate function: returns the covariance of y and x multiplied for the number of items in + * the dataset. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_sxy(y: Column, x: Column): Column = withAggregateFunction { + RegrSXY(y.expr, x.expr) + } + + /** + * Aggregate function: returns the covariance of y and x multiplied for the number of items in + * the dataset. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_sxy(y: String, x: String): Column = regr_sxy(Column(y), Column(x)) + + /** + * Aggregate function: returns the slope of the linear regression line. Any pair with a NULL is + * ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_slope(y: Column, x: Column): Column = withAggregateFunction { + RegrSlope(y.expr, x.expr) + } + + /** + * Aggregate function: returns the slope of the linear regression line. Any pair with a NULL is + * ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_slope(y: String, x: String): Column = regr_slope(Column(y), Column(x)) + + /** + * Aggregate function: returns the coefficient of determination (also called R-squared or + * goodness of fit) for the regression line. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_r2(y: Column, x: Column): Column = withAggregateFunction { + RegrR2(y.expr, x.expr) + } + + /** + * Aggregate function: returns the coefficient of determination (also called R-squared or + * goodness of fit) for the regression line. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_r2(y: String, x: String): Column = regr_r2(Column(y), Column(x)) + + /** + * Aggregate function: returns the y-intercept of the linear regression line. Any pair with a + * NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_intercept(y: Column, x: Column): Column = withAggregateFunction { + RegrIntercept(y.expr, x.expr) + } + + /** + * Aggregate function: returns the y-intercept of the linear regression line. Any pair with a + * NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_intercept(y: String, x: String): Column = regr_intercept(Column(y), Column(x)) + + + ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql new file mode 100644 index 0000000000000..92c7e26e3add2 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql @@ -0,0 +1,56 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES + (101, 1, 1, 1), + (201, 2, 1, 1), + (301, 3, 1, 1), + (401, 4, 1, 11), + (501, 5, 1, null), + (601, 6, null, 1), + (701, 6, null, null), + (102, 1, 2, 2), + (202, 2, 1, 2), + (302, 3, 2, 1), + (402, 4, 2, 12), + (502, 5, 2, null), + (602, 6, null, 2), + (702, 6, null, null), + (103, 1, 3, 3), + (203, 2, 1, 3), + (303, 3, 3, 1), + (403, 4, 3, 13), + (503, 5, 3, null), + (603, 6, null, 3), + (703, 6, null, null), + (104, 1, 4, 4), + (204, 2, 1, 4), + (304, 3, 4, 1), + (404, 4, 4, 14), + (504, 5, 4, null), + (604, 6, null, 4), + (704, 6, null, null), + (800, 7, 1, 1) +as t1(id, px, y, x); + +select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x), + regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x), + regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x) +from t1 group by px order by px; + + +select id, regr_count(y,x) over (partition by px) from t1 order by id; diff --git a/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out new file mode 100644 index 0000000000000..d7d009a64bf84 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out @@ -0,0 +1,93 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 3 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES + (101, 1, 1, 1), + (201, 2, 1, 1), + (301, 3, 1, 1), + (401, 4, 1, 11), + (501, 5, 1, null), + (601, 6, null, 1), + (701, 6, null, null), + (102, 1, 2, 2), + (202, 2, 1, 2), + (302, 3, 2, 1), + (402, 4, 2, 12), + (502, 5, 2, null), + (602, 6, null, 2), + (702, 6, null, null), + (103, 1, 3, 3), + (203, 2, 1, 3), + (303, 3, 3, 1), + (403, 4, 3, 13), + (503, 5, 3, null), + (603, 6, null, 3), + (703, 6, null, null), + (104, 1, 4, 4), + (204, 2, 1, 4), + (304, 3, 4, 1), + (404, 4, 4, 14), + (504, 5, 4, null), + (604, 6, null, 4), + (704, 6, null, null), + (800, 7, 1, 1) +as t1(id, px, y, x) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x), + regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x), + regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x) +from t1 group by px order by px +-- !query 1 schema +struct +-- !query 1 output +1 1.25 1.25 1.0 1.6666666666666667 1.25 4 1.0 0.0 1.0 5.0 5.0 5.0 2.5 2.5 4 +2 1.25 0.0 NULL 0.0 0.0 4 0.0 1.0 1.0 5.0 0.0 0.0 2.5 1.0 4 +3 0.0 1.25 NULL 0.0 0.0 4 NULL NULL NULL 0.0 5.0 0.0 1.0 2.5 4 +4 1.25 1.25 1.0 1.6666666666666667 1.25 4 1.0 -10.0 1.0 5.0 5.0 5.0 12.5 2.5 4 +5 NULL 1.25 NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL 0 +6 1.25 NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL 0 +7 0.0 0.0 NaN NaN 0.0 1 NULL NULL NULL 0.0 0.0 0.0 1.0 1.0 1 + + +-- !query 2 +select id, regr_count(y,x) over (partition by px) from t1 order by id +-- !query 2 schema +struct +-- !query 2 output +101 4 +102 4 +103 4 +104 4 +201 4 +202 4 +203 4 +204 4 +301 4 +302 4 +303 4 +304 4 +401 4 +402 4 +403 4 +404 4 +501 0 +502 0 +503 0 +504 0 +601 0 +602 0 +603 0 +604 0 +701 0 +702 0 +703 0 +704 0 +800 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index e7776e36702ad..4337fb2290fbc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -36,6 +36,8 @@ case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Doub class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { import testImplicits._ + val absTol = 1e-8 + test("groupBy") { checkAnswer( testData2.groupBy("a").agg(sum($"b")), @@ -416,7 +418,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("moments") { - val absTol = 1e-8 val sparkVariance = testData2.agg(variance('a)) checkAggregatesWithTol(sparkVariance, Row(4.0 / 5.0), absTol) @@ -686,4 +687,72 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-23907: regression functions") { + val emptyTableData = Seq.empty[(Double, Double)].toDF("a", "b") + val correlatedData = Seq[(Double, Double)]((2, 3), (3, 4), (7.5, 8.2), (10.3, 12)) + .toDF("a", "b") + val correlatedDataWithNull = Seq[(java.lang.Double, java.lang.Double)]( + (2.0, 3.0), (3.0, null), (7.5, 8.2), (10.3, 12.0)).toDF("a", "b") + checkAnswer(testData2.groupBy().agg(regr_count("a", "b")), Seq(Row(6))) + checkAnswer(testData3.groupBy().agg(regr_count("a", "b")), Seq(Row(1))) + checkAnswer(emptyTableData.groupBy().agg(regr_count("a", "b")), Seq(Row(0))) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_sxx("a", "b")), Row(1.5), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_sxx("a", "b")), Row(0.0), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_sxx("a", "b")), Row(null), absTol) + checkAggregatesWithTol(testData2.groupBy().agg(regr_syy("b", "a")), Row(1.5), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_syy("b", "a")), Row(0.0), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_syy("b", "a")), Row(null), absTol) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_avgx("a", "b")), Row(1.5), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_avgx("a", "b")), Row(2.0), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_avgx("a", "b")), Row(null), absTol) + checkAggregatesWithTol(testData2.groupBy().agg(regr_avgy("b", "a")), Row(1.5), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_avgy("b", "a")), Row(2.0), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_avgy("b", "a")), Row(null), absTol) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_sxy("a", "b")), Row(0.0), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_sxy("a", "b")), Row(0.0), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_sxy("a", "b")), Row(null), absTol) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_slope("a", "b")), Row(0.0), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_slope("a", "b")), Row(null), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_slope("a", "b")), Row(null), absTol) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_r2("a", "b")), Row(0.0), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_r2("a", "b")), Row(null), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_r2("a", "b")), Row(null), absTol) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_intercept("a", "b")), Row(2.0), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_intercept("a", "b")), Row(null), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_intercept("a", "b")), + Row(null), absTol) + + + checkAggregatesWithTol(correlatedData.groupBy().agg( + regr_count("a", "b"), + regr_avgx("a", "b"), + regr_avgy("a", "b"), + regr_sxx("a", "b"), + regr_syy("a", "b"), + regr_sxy("a", "b"), + regr_slope("a", "b"), + regr_r2("a", "b"), + regr_intercept("a", "b")), + Row(4, 6.8, 5.7, 51.28, 45.38, 48.06, 0.937207488, 0.992556013, -0.67301092), + absTol) + checkAggregatesWithTol(correlatedDataWithNull.groupBy().agg( + regr_count("a", "b"), + regr_avgx("a", "b"), + regr_avgy("a", "b"), + regr_sxx("a", "b"), + regr_syy("a", "b"), + regr_sxy("a", "b"), + regr_slope("a", "b"), + regr_r2("a", "b"), + regr_intercept("a", "b")), + Row(3, 7.73333333, 6.6, 40.82666666, 35.66, 37.98, 0.93027433, 0.99079694, -0.59412149), + absTol) + } } From f4fed0512101a67d9dae50ace11d3940b910e05e Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 10 May 2018 09:44:49 -0700 Subject: [PATCH 0776/2461] [SPARK-24171] Adding a note for non-deterministic functions ## What changes were proposed in this pull request? I propose to add a clear statement for functions like `collect_list()` about non-deterministic behavior of such functions. The behavior must be taken into account by user while creating and running queries. Author: Maxim Gekk Closes #21228 from MaxGekk/deterministic-comments. --- R/pkg/R/functions.R | 11 +++++ python/pyspark/sql/functions.py | 18 ++++++++ .../MonotonicallyIncreasingID.scala | 1 + .../spark/sql/catalyst/expressions/misc.scala | 5 +- .../expressions/randomExpressions.scala | 8 ++-- .../org/apache/spark/sql/functions.scala | 46 +++++++++++++++++-- 6 files changed, 81 insertions(+), 8 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 0ec99d19e21e4..04d0e4620b28a 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -805,6 +805,8 @@ setMethod("factorial", #' #' The function by default returns the first values it sees. It will return the first non-missing #' value it sees when na.rm is set to true. If all values are missing, then NA is returned. +#' Note: the function is non-deterministic because its results depends on order of rows which +#' may be non-deterministic after a shuffle. #' #' @param na.rm a logical value indicating whether NA values should be stripped #' before the computation proceeds. @@ -948,6 +950,8 @@ setMethod("kurtosis", #' #' The function by default returns the last values it sees. It will return the last non-missing #' value it sees when na.rm is set to true. If all values are missing, then NA is returned. +#' Note: the function is non-deterministic because its results depends on order of rows which +#' may be non-deterministic after a shuffle. #' #' @param x column to compute on. #' @param na.rm a logical value indicating whether NA values should be stripped @@ -1201,6 +1205,7 @@ setMethod("minute", #' 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. #' This is equivalent to the MONOTONICALLY_INCREASING_ID function in SQL. #' The method should be used with no argument. +#' Note: the function is non-deterministic because its result depends on partition IDs. #' #' @rdname column_nonaggregate_functions #' @aliases monotonically_increasing_id monotonically_increasing_id,missing-method @@ -2584,6 +2589,7 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' @details #' \code{rand}: Generates a random column with independent and identically distributed (i.i.d.) #' samples from U[0.0, 1.0]. +#' Note: the function is non-deterministic in general case. #' #' @rdname column_nonaggregate_functions #' @param seed a random seed. Can be missing. @@ -2612,6 +2618,7 @@ setMethod("rand", signature(seed = "numeric"), #' @details #' \code{randn}: Generates a column with independent and identically distributed (i.i.d.) samples #' from the standard normal distribution. +#' Note: the function is non-deterministic in general case. #' #' @rdname column_nonaggregate_functions #' @aliases randn randn,missing-method @@ -3188,6 +3195,8 @@ setMethod("create_map", #' @details #' \code{collect_list}: Creates a list of objects with duplicates. +#' Note: the function is non-deterministic because the order of collected results depends +#' on order of rows which may be non-deterministic after a shuffle. #' #' @rdname column_aggregate_functions #' @aliases collect_list collect_list,Column-method @@ -3207,6 +3216,8 @@ setMethod("collect_list", #' @details #' \code{collect_set}: Creates a list of objects with duplicate elements eliminated. +#' Note: the function is non-deterministic because the order of collected results depends +#' on order of rows which may be non-deterministic after a shuffle. #' #' @rdname column_aggregate_functions #' @aliases collect_set collect_set,Column-method diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ac3c79766702c..f5a584152b4f6 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -152,6 +152,9 @@ def _(): _collect_list_doc = """ Aggregate function: returns a list of objects with duplicates. + .. note:: The function is non-deterministic because the order of collected results depends + on order of rows which may be non-deterministic after a shuffle. + >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',)) >>> df2.agg(collect_list('age')).collect() [Row(collect_list(age)=[2, 5, 5])] @@ -159,6 +162,9 @@ def _(): _collect_set_doc = """ Aggregate function: returns a set of objects with duplicate elements eliminated. + .. note:: The function is non-deterministic because the order of collected results depends + on order of rows which may be non-deterministic after a shuffle. + >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',)) >>> df2.agg(collect_set('age')).collect() [Row(collect_set(age)=[5, 2])] @@ -401,6 +407,9 @@ def first(col, ignorenulls=False): The function by default returns the first values it sees. It will return the first non-null value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + + .. note:: The function is non-deterministic because its results depends on order of rows which + may be non-deterministic after a shuffle. """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.first(_to_java_column(col), ignorenulls) @@ -489,6 +498,9 @@ def last(col, ignorenulls=False): The function by default returns the last values it sees. It will return the last non-null value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + + .. note:: The function is non-deterministic because its results depends on order of rows + which may be non-deterministic after a shuffle. """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.last(_to_java_column(col), ignorenulls) @@ -504,6 +516,8 @@ def monotonically_increasing_id(): within each partition in the lower 33 bits. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records. + .. note:: The function is non-deterministic because its result depends on partition IDs. + As an example, consider a :class:`DataFrame` with two partitions, each with 3 records. This expression would return the following IDs: 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. @@ -536,6 +550,8 @@ def rand(seed=None): """Generates a random column with independent and identically distributed (i.i.d.) samples from U[0.0, 1.0]. + .. note:: The function is non-deterministic in general case. + >>> df.withColumn('rand', rand(seed=42) * 3).collect() [Row(age=2, name=u'Alice', rand=1.1568609015300986), Row(age=5, name=u'Bob', rand=1.403379671529166)] @@ -554,6 +570,8 @@ def randn(seed=None): """Generates a column with independent and identically distributed (i.i.d.) samples from the standard normal distribution. + .. note:: The function is non-deterministic in general case. + >>> df.withColumn('randn', randn(seed=42)).collect() [Row(age=2, name=u'Alice', randn=-0.7556247885860078), Row(age=5, name=u'Bob', randn=-0.0861619008451133)] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index ad1e7bdb31987..9f0779642271d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.types.{DataType, LongType} puts the partition ID in the upper 31 bits, and the lower 33 bits represent the record number within each partition. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records. + The function is non-deterministic because its result depends on partition IDs. """) case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 7eda65a867028..b7834696cafc3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -117,12 +117,13 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_() - Returns an universally unique identifier (UUID) string. The value is returned as a canonical UUID 36-character string.", + usage = """_FUNC_() - Returns an universally unique identifier (UUID) string. The value is returned as a canonical UUID 36-character string.""", examples = """ Examples: > SELECT _FUNC_(); 46707d92-02f4-4817-8116-a4c3b23e6266 - """) + """, + note = "The function is non-deterministic.") // scalastyle:on line.size.limit case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Stateful { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 70186053617f8..2653b28f6c3bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -68,7 +68,8 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful 0.8446490682263027 > SELECT _FUNC_(null); 0.8446490682263027 - """) + """, + note = "The function is non-deterministic in general case.") // scalastyle:on line.size.limit case class Rand(child: Expression) extends RDG { @@ -96,7 +97,7 @@ object Rand { /** Generate a random column with i.i.d. values drawn from the standard normal distribution. */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) values drawn from the standard normal distribution.", + usage = """_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) values drawn from the standard normal distribution.""", examples = """ Examples: > SELECT _FUNC_(); @@ -105,7 +106,8 @@ object Rand { 1.1164209726833079 > SELECT _FUNC_(null); 1.1164209726833079 - """) + """, + note = "The function is non-deterministic in general case.") // scalastyle:on line.size.limit case class Randn(child: Expression) extends RDG { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 28cf705eb9700..225de0051d6fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -283,6 +283,9 @@ object functions { /** * Aggregate function: returns a list of objects with duplicates. * + * @note The function is non-deterministic because the order of collected results depends + * on order of rows which may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.6.0 */ @@ -291,6 +294,9 @@ object functions { /** * Aggregate function: returns a list of objects with duplicates. * + * @note The function is non-deterministic because the order of collected results depends + * on order of rows which may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.6.0 */ @@ -299,6 +305,9 @@ object functions { /** * Aggregate function: returns a set of objects with duplicate elements eliminated. * + * @note The function is non-deterministic because the order of collected results depends + * on order of rows which may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.6.0 */ @@ -307,6 +316,9 @@ object functions { /** * Aggregate function: returns a set of objects with duplicate elements eliminated. * + * @note The function is non-deterministic because the order of collected results depends + * on order of rows which may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.6.0 */ @@ -422,6 +434,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 2.0.0 */ @@ -435,6 +450,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 2.0.0 */ @@ -448,6 +466,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.3.0 */ @@ -459,6 +480,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.3.0 */ @@ -535,6 +559,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 2.0.0 */ @@ -548,6 +575,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 2.0.0 */ @@ -561,6 +591,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.3.0 */ @@ -572,6 +605,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.3.0 */ @@ -1344,7 +1380,7 @@ object functions { * Generate a random column with independent and identically distributed (i.i.d.) samples * from U[0.0, 1.0]. * - * @note This is indeterministic when data partitions are not fixed. + * @note The function is non-deterministic in general case. * * @group normal_funcs * @since 1.4.0 @@ -1355,6 +1391,8 @@ object functions { * Generate a random column with independent and identically distributed (i.i.d.) samples * from U[0.0, 1.0]. * + * @note The function is non-deterministic in general case. + * * @group normal_funcs * @since 1.4.0 */ @@ -1364,7 +1402,7 @@ object functions { * Generate a column with independent and identically distributed (i.i.d.) samples from * the standard normal distribution. * - * @note This is indeterministic when data partitions are not fixed. + * @note The function is non-deterministic in general case. * * @group normal_funcs * @since 1.4.0 @@ -1375,6 +1413,8 @@ object functions { * Generate a column with independent and identically distributed (i.i.d.) samples from * the standard normal distribution. * + * @note The function is non-deterministic in general case. + * * @group normal_funcs * @since 1.4.0 */ @@ -1383,7 +1423,7 @@ object functions { /** * Partition ID. * - * @note This is indeterministic because it depends on data partitioning and task scheduling. + * @note This is non-deterministic because it depends on data partitioning and task scheduling. * * @group normal_funcs * @since 1.6.0 From 6282fc64e32fc2f70e79ace14efd4922e4535dbb Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 10 May 2018 11:36:41 -0700 Subject: [PATCH 0777/2461] [SPARK-24137][K8S] Mount local directories as empty dir volumes. ## What changes were proposed in this pull request? Drastically improves performance and won't cause Spark applications to fail because they write too much data to the Docker image's specific file system. The file system's directories that back emptydir volumes are generally larger and more performant. ## How was this patch tested? Has been in use via the prototype version of Kubernetes support, but lost in the transition to here. Author: mcheah Closes #21238 from mccheah/mount-local-dirs. --- .../scala/org/apache/spark/SparkConf.scala | 5 +- .../k8s/features/LocalDirsFeatureStep.scala | 77 ++++++++++++ .../k8s/submit/KubernetesDriverBuilder.scala | 10 +- .../k8s/KubernetesExecutorBuilder.scala | 9 +- .../features/LocalDirsFeatureStepSuite.scala | 111 ++++++++++++++++++ .../submit/KubernetesDriverBuilderSuite.scala | 13 +- .../k8s/KubernetesExecutorBuilderSuite.scala | 12 +- 7 files changed, 223 insertions(+), 14 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 129956e9f9ffa..dab409572646f 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -454,8 +454,9 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria */ private[spark] def validateSettings() { if (contains("spark.local.dir")) { - val msg = "In Spark 1.0 and later spark.local.dir will be overridden by the value set by " + - "the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone and LOCAL_DIRS in YARN)." + val msg = "Note that spark.local.dir will be overridden by the value set by " + + "the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS" + + " in YARN)." logWarning(msg) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala new file mode 100644 index 0000000000000..70b307303d149 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import java.nio.file.Paths +import java.util.UUID + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, VolumeBuilder, VolumeMountBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, SparkPod} + +private[spark] class LocalDirsFeatureStep( + conf: KubernetesConf[_ <: KubernetesRoleSpecificConf], + defaultLocalDir: String = s"/var/data/spark-${UUID.randomUUID}") + extends KubernetesFeatureConfigStep { + + // Cannot use Utils.getConfiguredLocalDirs because that will default to the Java system + // property - we want to instead default to mounting an emptydir volume that doesn't already + // exist in the image. + // We could make utils.getConfiguredLocalDirs opinionated about Kubernetes, as it is already + // a bit opinionated about YARN and Mesos. + private val resolvedLocalDirs = Option(conf.sparkConf.getenv("SPARK_LOCAL_DIRS")) + .orElse(conf.getOption("spark.local.dir")) + .getOrElse(defaultLocalDir) + .split(",") + + override def configurePod(pod: SparkPod): SparkPod = { + val localDirVolumes = resolvedLocalDirs + .zipWithIndex + .map { case (localDir, index) => + new VolumeBuilder() + .withName(s"spark-local-dir-${index + 1}") + .withNewEmptyDir() + .endEmptyDir() + .build() + } + val localDirVolumeMounts = localDirVolumes + .zip(resolvedLocalDirs) + .map { case (localDirVolume, localDirPath) => + new VolumeMountBuilder() + .withName(localDirVolume.getName) + .withMountPath(localDirPath) + .build() + } + val podWithLocalDirVolumes = new PodBuilder(pod.pod) + .editSpec() + .addToVolumes(localDirVolumes: _*) + .endSpec() + .build() + val containerWithLocalDirVolumeMounts = new ContainerBuilder(pod.container) + .addNewEnv() + .withName("SPARK_LOCAL_DIRS") + .withValue(resolvedLocalDirs.mkString(",")) + .endEnv() + .addToVolumeMounts(localDirVolumeMounts: _*) + .build() + SparkPod(podWithLocalDirVolumes, containerWithLocalDirVolumeMounts) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index c7579ed8cb689..10b0154466a3a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} -import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} private[spark] class KubernetesDriverBuilder( provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep = @@ -29,14 +29,18 @@ private[spark] class KubernetesDriverBuilder( new DriverServiceFeatureStep(_), provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => MountSecretsFeatureStep) = - new MountSecretsFeatureStep(_)) { + new MountSecretsFeatureStep(_), + provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) + => LocalDirsFeatureStep = + new LocalDirsFeatureStep(_)) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]): KubernetesDriverSpec = { val baseFeatures = Seq( provideBasicStep(kubernetesConf), provideCredentialsStep(kubernetesConf), - provideServiceStep(kubernetesConf)) + provideServiceStep(kubernetesConf), + provideLocalDirsStep(kubernetesConf)) val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) } else baseFeatures diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index 22568fe7ea3be..d8f63d57574fb 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -17,18 +17,21 @@ package org.apache.spark.scheduler.cluster.k8s import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} private[spark] class KubernetesExecutorBuilder( provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep = new BasicExecutorFeatureStep(_), provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => MountSecretsFeatureStep = - new MountSecretsFeatureStep(_)) { + new MountSecretsFeatureStep(_), + provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) + => LocalDirsFeatureStep = + new LocalDirsFeatureStep(_)) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { - val baseFeatures = Seq(provideBasicStep(kubernetesConf)) + val baseFeatures = Seq(provideBasicStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) } else baseFeatures diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala new file mode 100644 index 0000000000000..91e184b84b86e --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.{EnvVarBuilder, VolumeBuilder, VolumeMountBuilder} +import org.mockito.Mockito +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} + +class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { + private val defaultLocalDir = "/var/data/default-local-dir" + private var sparkConf: SparkConf = _ + private var kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf] = _ + + before { + val realSparkConf = new SparkConf(false) + sparkConf = Mockito.spy(realSparkConf) + kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, + "app-name", + "main", + Seq.empty), + "resource", + "app-id", + Map.empty, + Map.empty, + Map.empty, + Map.empty) + } + + test("Resolve to default local dir if neither env nor configuration are set") { + Mockito.doReturn(null).when(sparkConf).get("spark.local.dir") + Mockito.doReturn(null).when(sparkConf).getenv("SPARK_LOCAL_DIRS") + val stepUnderTest = new LocalDirsFeatureStep(kubernetesConf, defaultLocalDir) + val configuredPod = stepUnderTest.configurePod(SparkPod.initialPod()) + assert(configuredPod.pod.getSpec.getVolumes.size === 1) + assert(configuredPod.pod.getSpec.getVolumes.get(0) === + new VolumeBuilder() + .withName(s"spark-local-dir-1") + .withNewEmptyDir() + .endEmptyDir() + .build()) + assert(configuredPod.container.getVolumeMounts.size === 1) + assert(configuredPod.container.getVolumeMounts.get(0) === + new VolumeMountBuilder() + .withName(s"spark-local-dir-1") + .withMountPath(defaultLocalDir) + .build()) + assert(configuredPod.container.getEnv.size === 1) + assert(configuredPod.container.getEnv.get(0) === + new EnvVarBuilder() + .withName("SPARK_LOCAL_DIRS") + .withValue(defaultLocalDir) + .build()) + } + + test("Use configured local dirs split on comma if provided.") { + Mockito.doReturn("/var/data/my-local-dir-1,/var/data/my-local-dir-2") + .when(sparkConf).getenv("SPARK_LOCAL_DIRS") + val stepUnderTest = new LocalDirsFeatureStep(kubernetesConf, defaultLocalDir) + val configuredPod = stepUnderTest.configurePod(SparkPod.initialPod()) + assert(configuredPod.pod.getSpec.getVolumes.size === 2) + assert(configuredPod.pod.getSpec.getVolumes.get(0) === + new VolumeBuilder() + .withName(s"spark-local-dir-1") + .withNewEmptyDir() + .endEmptyDir() + .build()) + assert(configuredPod.pod.getSpec.getVolumes.get(1) === + new VolumeBuilder() + .withName(s"spark-local-dir-2") + .withNewEmptyDir() + .endEmptyDir() + .build()) + assert(configuredPod.container.getVolumeMounts.size === 2) + assert(configuredPod.container.getVolumeMounts.get(0) === + new VolumeMountBuilder() + .withName(s"spark-local-dir-1") + .withMountPath("/var/data/my-local-dir-1") + .build()) + assert(configuredPod.container.getVolumeMounts.get(1) === + new VolumeMountBuilder() + .withName(s"spark-local-dir-2") + .withMountPath("/var/data/my-local-dir-2") + .build()) + assert(configuredPod.container.getEnv.size === 1) + assert(configuredPod.container.getEnv.get(0) === + new EnvVarBuilder() + .withName("SPARK_LOCAL_DIRS") + .withValue("/var/data/my-local-dir-1,/var/data/my-local-dir-2") + .build()) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index 161f9afe7bba9..a511d254d2175 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -18,13 +18,14 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf} -import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} class KubernetesDriverBuilderSuite extends SparkFunSuite { private val BASIC_STEP_TYPE = "basic" private val CREDENTIALS_STEP_TYPE = "credentials" private val SERVICE_STEP_TYPE = "service" + private val LOCAL_DIRS_STEP_TYPE = "local-dirs" private val SECRETS_STEP_TYPE = "mount-secrets" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( @@ -36,6 +37,9 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val serviceStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SERVICE_STEP_TYPE, classOf[DriverServiceFeatureStep]) + private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep]) + private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) @@ -44,7 +48,8 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { _ => basicFeatureStep, _ => credentialsStep, _ => serviceStep, - _ => secretsStep) + _ => secretsStep, + _ => localDirsStep) test("Apply fundamental steps all the time.") { val conf = KubernetesConf( @@ -64,7 +69,8 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, CREDENTIALS_STEP_TYPE, - SERVICE_STEP_TYPE) + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE) } test("Apply secrets step if secrets are present.") { @@ -86,6 +92,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { BASIC_STEP_TYPE, CREDENTIALS_STEP_TYPE, SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, SECRETS_STEP_TYPE) } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index f5270623f8acc..9ee86b5a423a9 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -20,20 +20,24 @@ import io.fabric8.kubernetes.api.model.PodBuilder import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} class KubernetesExecutorBuilderSuite extends SparkFunSuite { private val BASIC_STEP_TYPE = "basic" private val SECRETS_STEP_TYPE = "mount-secrets" + private val LOCAL_DIRS_STEP_TYPE = "local-dirs" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep]) private val mountSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep]) private val builderUnderTest = new KubernetesExecutorBuilder( _ => basicFeatureStep, - _ => mountSecretsStep) + _ => mountSecretsStep, + _ => localDirsStep) test("Basic steps are consistently applied.") { val conf = KubernetesConf( @@ -46,7 +50,8 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty) - validateStepTypesApplied(builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE) } test("Apply secrets step if secrets are present.") { @@ -63,6 +68,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, SECRETS_STEP_TYPE) } From 3e2600538ee477ffe3f23fba57719e035219550b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Thu, 10 May 2018 14:26:38 -0700 Subject: [PATCH 0778/2461] [SPARK-19181][CORE] Fixing flaky "SparkListenerSuite.local metrics" MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Sometimes "SparkListenerSuite.local metrics" test fails because the average of executorDeserializeTime is too short. As squito suggested to avoid these situations in one of the task a reference introduced to an object implementing a custom Externalizable.readExternal which sleeps 1ms before returning. ## How was this patch tested? With unit tests (and checking the effect of this change to the average with a much larger sleep time). Author: “attilapiros” Author: Attila Zsolt Piros <2017933+attilapiros@users.noreply.github.com> Closes #21280 from attilapiros/SPARK-19181. --- .../spark/scheduler/SparkListenerSuite.scala | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index da6ecb82c7e42..fa47a52bbbc47 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.scheduler +import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.util.concurrent.Semaphore import scala.collection.JavaConverters._ @@ -294,10 +295,13 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val listener = new SaveStageAndTaskInfo sc.addSparkListener(listener) sc.addSparkListener(new StatsReportListener) - // just to make sure some of the tasks take a noticeable amount of time + // just to make sure some of the tasks and their deserialization take a noticeable + // amount of time + val slowDeserializable = new SlowDeserializable val w = { i: Int => if (i == 0) { Thread.sleep(100) + slowDeserializable.use() } i } @@ -583,3 +587,12 @@ private class FirehoseListenerThatAcceptsSparkConf(conf: SparkConf) extends Spar case _ => } } + +private class SlowDeserializable extends Externalizable { + + override def writeExternal(out: ObjectOutput): Unit = { } + + override def readExternal(in: ObjectInput): Unit = Thread.sleep(1) + + def use(): Unit = { } +} From d3c426a5b02abdec49ff45df12a8f11f9e473a88 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 10 May 2018 14:41:55 -0700 Subject: [PATCH 0779/2461] [SPARK-10878][CORE] Fix race condition when multiple clients resolves artifacts at the same time ## What changes were proposed in this pull request? When multiple clients attempt to resolve artifacts via the `--packages` parameter, they could run into race condition when they each attempt to modify the dummy `org.apache.spark-spark-submit-parent-default.xml` file created in the default ivy cache dir. This PR changes the behavior to encode UUID in the dummy module descriptor so each client will operate on a different resolution file in the ivy cache dir. In addition, this patch changes the behavior of when and which resolution files are cleaned to prevent accumulation of resolution files in the default ivy cache dir. Since this PR is a successor of #18801, close #18801. Many codes were ported from #18801. **Many efforts were put here. I think this PR should credit to Victsm .** ## How was this patch tested? added UT into `SparkSubmitUtilsSuite` Author: Kazuaki Ishizaki Closes #21251 from kiszk/SPARK-10878. --- .../org/apache/spark/deploy/SparkSubmit.scala | 42 ++++++++++++++----- .../spark/deploy/SparkSubmitUtilsSuite.scala | 15 +++++++ 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 427c797755b84..087e9c31a9c9a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -22,6 +22,7 @@ import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowab import java.net.URL import java.security.PrivilegedExceptionAction import java.text.ParseException +import java.util.UUID import scala.annotation.tailrec import scala.collection.mutable.{ArrayBuffer, HashMap, Map} @@ -1204,7 +1205,33 @@ private[spark] object SparkSubmitUtils { /** A nice function to use in tests as well. Values are dummy strings. */ def getModuleDescriptor: DefaultModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance( - ModuleRevisionId.newInstance("org.apache.spark", "spark-submit-parent", "1.0")) + // Include UUID in module name, so multiple clients resolving maven coordinate at the same time + // do not modify the same resolution file concurrently. + ModuleRevisionId.newInstance("org.apache.spark", + s"spark-submit-parent-${UUID.randomUUID.toString}", + "1.0")) + + /** + * Clear ivy resolution from current launch. The resolution file is usually at + * ~/.ivy2/org.apache.spark-spark-submit-parent-$UUID-default.xml, + * ~/.ivy2/resolved-org.apache.spark-spark-submit-parent-$UUID-1.0.xml, and + * ~/.ivy2/resolved-org.apache.spark-spark-submit-parent-$UUID-1.0.properties. + * Since each launch will have its own resolution files created, delete them after + * each resolution to prevent accumulation of these files in the ivy cache dir. + */ + private def clearIvyResolutionFiles( + mdId: ModuleRevisionId, + ivySettings: IvySettings, + ivyConfName: String): Unit = { + val currentResolutionFiles = Seq( + s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml", + s"resolved-${mdId.getOrganisation}-${mdId.getName}-${mdId.getRevision}.xml", + s"resolved-${mdId.getOrganisation}-${mdId.getName}-${mdId.getRevision}.properties" + ) + currentResolutionFiles.foreach { filename => + new File(ivySettings.getDefaultCache, filename).delete() + } + } /** * Resolves any dependencies that were supplied through maven coordinates @@ -1255,14 +1282,6 @@ private[spark] object SparkSubmitUtils { // A Module descriptor must be specified. Entries are dummy strings val md = getModuleDescriptor - // clear ivy resolution from previous launches. The resolution file is usually at - // ~/.ivy2/org.apache.spark-spark-submit-parent-default.xml. In between runs, this file - // leads to confusion with Ivy when the files can no longer be found at the repository - // declared in that file/ - val mdId = md.getModuleRevisionId - val previousResolution = new File(ivySettings.getDefaultCache, - s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml") - if (previousResolution.exists) previousResolution.delete md.setDefaultConf(ivyConfName) @@ -1283,7 +1302,10 @@ private[spark] object SparkSubmitUtils { packagesDirectory.getAbsolutePath + File.separator + "[organization]_[artifact]-[revision](-[classifier]).[ext]", retrieveOptions.setConfs(Array(ivyConfName))) - resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) + val paths = resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) + val mdId = md.getModuleRevisionId + clearIvyResolutionFiles(mdId, ivySettings, ivyConfName) + paths } finally { System.setOut(sysOut) } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index eb8c203ae7751..a0f09891787e0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -256,4 +256,19 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(jarPath.indexOf("mydep") >= 0, "should find dependency") } } + + test("SPARK-10878: test resolution files cleaned after resolving artifact") { + val main = new MavenCoordinate("my.great.lib", "mylib", "0.1") + + IvyTestUtils.withRepository(main, None, None) { repo => + val ivySettings = SparkSubmitUtils.buildIvySettings(Some(repo), Some(tempIvyPath)) + val jarPath = SparkSubmitUtils.resolveMavenCoordinates( + main.toString, + ivySettings, + isTest = true) + val r = """.*org.apache.spark-spark-submit-parent-.*""".r + assert(!ivySettings.getDefaultCache.listFiles.map(_.getName) + .exists(r.findFirstIn(_).isDefined), "resolution files should be cleaned") + } + } } From a4206d58e05ab9ed6f01fee57e18dee65cbc4efc Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 11 May 2018 09:01:40 +0800 Subject: [PATCH 0780/2461] [SPARK-22938][SQL][FOLLOWUP] Assert that SQLConf.get is accessed only on the driver ## What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/20136 . #20136 didn't really work because in the test, we are using local backend, which shares the driver side `SparkEnv`, so `SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER` doesn't work. This PR changes the check to `TaskContext.get != null`, and move the check to `SQLConf.get`, and fix all the places that violate this check: * `InMemoryTableScanExec#createAndDecompressColumn` is executed inside `rdd.map`, we can't access `conf.offHeapColumnVectorEnabled` there. https://github.com/apache/spark/pull/21223 merged * `DataType#sameType` may be executed in executor side, for things like json schema inference, so we can't call `conf.caseSensitiveAnalysis` there. This contributes to most of the code changes, as we need to add `caseSensitive` parameter to a lot of methods. * `ParquetFilters` is used in the file scan function, which is executed in executor side, so we can't can't call `conf.parquetFilterPushDownDate` there. https://github.com/apache/spark/pull/21224 merged * `WindowExec#createBoundOrdering` is called on executor side, so we can't use `conf.sessionLocalTimezone` there. https://github.com/apache/spark/pull/21225 merged * `JsonToStructs` can be serialized to executors and evaluate, we should not call `SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)` in the body. https://github.com/apache/spark/pull/21226 merged ## How was this patch tested? existing test Author: Wenchen Fan Closes #21190 from cloud-fan/minor. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 5 +- .../analysis/ResolveInlineTables.scala | 4 +- .../sql/catalyst/analysis/TypeCoercion.scala | 156 ++++++++++-------- .../apache/spark/sql/internal/SQLConf.scala | 16 +- .../org/apache/spark/sql/types/DataType.scala | 8 +- .../catalyst/analysis/TypeCoercionSuite.scala | 70 ++++---- .../org/apache/spark/sql/SparkSession.scala | 21 ++- .../datasources/PartitioningUtils.scala | 5 +- .../datasources/json/JsonInferSchema.scala | 39 +++-- .../datasources/json/JsonSuite.scala | 4 +- 10 files changed, 188 insertions(+), 140 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 90bda2a72ad82..94b0561529e71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -260,7 +261,9 @@ trait CheckAnalysis extends PredicateHelper { // Check if the data types match. dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) => // SPARK-18058: we shall not care about the nullability of columns - if (TypeCoercion.findWiderTypeForTwo(dt1.asNullable, dt2.asNullable).isEmpty) { + val widerType = TypeCoercion.findWiderTypeForTwo( + dt1.asNullable, dt2.asNullable, SQLConf.get.caseSensitiveAnalysis) + if (widerType.isEmpty) { failAnalysis( s""" |${operator.nodeName} can only be performed on tables with the compatible diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index f2df3e132629f..4eb6e642b1c37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -83,7 +83,9 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with Cas // For each column, traverse all the values and find a common data type and nullability. val fields = table.rows.transpose.zip(table.names).map { case (column, name) => val inputTypes = column.map(_.dataType) - val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse { + val wideType = TypeCoercion.findWiderTypeWithoutStringPromotion( + inputTypes, conf.caseSensitiveAnalysis) + val tpe = wideType.getOrElse { table.failAnalysis(s"incompatible types found in column $name for inline table") } StructField(name, tpe, nullable = column.exists(_.nullable)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index b2817b0538a7f..a7ba201509b78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -48,18 +48,18 @@ object TypeCoercion { def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] = InConversion(conf) :: - WidenSetOperationTypes :: + WidenSetOperationTypes(conf) :: PromoteStrings(conf) :: DecimalPrecision :: BooleanEquality :: - FunctionArgumentConversion :: + FunctionArgumentConversion(conf) :: ConcatCoercion(conf) :: EltCoercion(conf) :: - CaseWhenCoercion :: - IfCoercion :: + CaseWhenCoercion(conf) :: + IfCoercion(conf) :: StackCoercion :: Division :: - new ImplicitTypeCasts(conf) :: + ImplicitTypeCasts(conf) :: DateTimeOperations :: WindowFrameCoercion :: Nil @@ -83,7 +83,10 @@ object TypeCoercion { * with primitive types, because in that case the precision and scale of the result depends on * the operation. Those rules are implemented in [[DecimalPrecision]]. */ - val findTightestCommonType: (DataType, DataType) => Option[DataType] = { + def findTightestCommonType( + left: DataType, + right: DataType, + caseSensitive: Boolean): Option[DataType] = (left, right) match { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) @@ -102,22 +105,32 @@ object TypeCoercion { case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => Some(TimestampType) - case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) => - Some(StructType(fields1.zip(fields2).map { case (f1, f2) => - // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType - // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. - // - Different names: use f1.name - // - Different nullabilities: `nullable` is true iff one of them is nullable. - val dataType = findTightestCommonType(f1.dataType, f2.dataType).get - StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) - })) + case (t1 @ StructType(fields1), t2 @ StructType(fields2)) => + val isSameType = if (caseSensitive) { + DataType.equalsIgnoreNullability(t1, t2) + } else { + DataType.equalsIgnoreCaseAndNullability(t1, t2) + } + + if (isSameType) { + Some(StructType(fields1.zip(fields2).map { case (f1, f2) => + // Since t1 is same type of t2, two StructTypes have the same DataType + // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. + // - Different names: use f1.name + // - Different nullabilities: `nullable` is true iff one of them is nullable. + val dataType = findTightestCommonType(f1.dataType, f2.dataType, caseSensitive).get + StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) + })) + } else { + None + } case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) => - findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2)) + findTightestCommonType(et1, et2, caseSensitive).map(ArrayType(_, hasNull1 || hasNull2)) case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) => - val keyType = findTightestCommonType(kt1, kt2) - val valueType = findTightestCommonType(vt1, vt2) + val keyType = findTightestCommonType(kt1, kt2, caseSensitive) + val valueType = findTightestCommonType(vt1, vt2, caseSensitive) Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2)) case _ => None @@ -172,13 +185,14 @@ object TypeCoercion { * i.e. the main difference with [[findTightestCommonType]] is that here we allow some * loss of precision when widening decimal and double, and promotion to string. */ - def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = { - findTightestCommonType(t1, t2) + def findWiderTypeForTwo(t1: DataType, t2: DataType, caseSensitive: Boolean): Option[DataType] = { + findTightestCommonType(t1, t2, caseSensitive) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse(stringPromotion(t1, t2)) .orElse((t1, t2) match { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) + findWiderTypeForTwo(et1, et2, caseSensitive) + .map(ArrayType(_, containsNull1 || containsNull2)) case _ => None }) } @@ -193,7 +207,8 @@ object TypeCoercion { case _ => false } - private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { + private def findWiderCommonType( + types: Seq[DataType], caseSensitive: Boolean): Option[DataType] = { // findWiderTypeForTwo doesn't satisfy the associative law, i.e. (a op b) op c may not equal // to a op (b op c). This is only a problem for StringType or nested StringType in ArrayType. // Excluding these types, findWiderTypeForTwo satisfies the associative law. For instance, @@ -201,7 +216,7 @@ object TypeCoercion { val (stringTypes, nonStringTypes) = types.partition(hasStringType(_)) (stringTypes.distinct ++ nonStringTypes).foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findWiderTypeForTwo(d, c) + case Some(d) => findWiderTypeForTwo(d, c, caseSensitive) case _ => None }) } @@ -213,20 +228,22 @@ object TypeCoercion { */ private[analysis] def findWiderTypeWithoutStringPromotionForTwo( t1: DataType, - t2: DataType): Option[DataType] = { - findTightestCommonType(t1, t2) + t2: DataType, + caseSensitive: Boolean): Option[DataType] = { + findTightestCommonType(t1, t2, caseSensitive) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse((t1, t2) match { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeWithoutStringPromotionForTwo(et1, et2) + findWiderTypeWithoutStringPromotionForTwo(et1, et2, caseSensitive) .map(ArrayType(_, containsNull1 || containsNull2)) case _ => None }) } - def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { + def findWiderTypeWithoutStringPromotion( + types: Seq[DataType], caseSensitive: Boolean): Option[DataType] = { types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c) + case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c, caseSensitive) case None => None }) } @@ -279,29 +296,32 @@ object TypeCoercion { * * This rule is only applied to Union/Except/Intersect */ - object WidenSetOperationTypes extends Rule[LogicalPlan] { + case class WidenSetOperationTypes(conf: SQLConf) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case s @ SetOperation(left, right) if s.childrenResolved && left.output.length == right.output.length && !s.resolved => - val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) + val newChildren: Seq[LogicalPlan] = + buildNewChildrenWithWiderTypes(left :: right :: Nil, conf.caseSensitiveAnalysis) assert(newChildren.length == 2) s.makeCopy(Array(newChildren.head, newChildren.last)) case s: Union if s.childrenResolved && s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => - val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children) + val newChildren: Seq[LogicalPlan] = + buildNewChildrenWithWiderTypes(s.children, conf.caseSensitiveAnalysis) s.makeCopy(Array(newChildren)) } /** Build new children with the widest types for each attribute among all the children */ - private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = { + private def buildNewChildrenWithWiderTypes( + children: Seq[LogicalPlan], caseSensitive: Boolean): Seq[LogicalPlan] = { require(children.forall(_.output.length == children.head.output.length)) // Get a sequence of data types, each of which is the widest type of this specific attribute // in all the children val targetTypes: Seq[DataType] = - getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType]()) + getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType](), caseSensitive) if (targetTypes.nonEmpty) { // Add an extra Project if the targetTypes are different from the original types. @@ -316,18 +336,19 @@ object TypeCoercion { @tailrec private def getWidestTypes( children: Seq[LogicalPlan], attrIndex: Int, - castedTypes: mutable.Queue[DataType]): Seq[DataType] = { + castedTypes: mutable.Queue[DataType], + caseSensitive: Boolean): Seq[DataType] = { // Return the result after the widen data types have been found for all the children if (attrIndex >= children.head.output.length) return castedTypes.toSeq // For the attrIndex-th attribute, find the widest type - findWiderCommonType(children.map(_.output(attrIndex).dataType)) match { + findWiderCommonType(children.map(_.output(attrIndex).dataType), caseSensitive) match { // If unable to find an appropriate widen type for this column, return an empty Seq case None => Seq.empty[DataType] // Otherwise, record the result in the queue and find the type for the next column case Some(widenType) => castedTypes.enqueue(widenType) - getWidestTypes(children, attrIndex + 1, castedTypes) + getWidestTypes(children, attrIndex + 1, castedTypes, caseSensitive) } } @@ -432,7 +453,7 @@ object TypeCoercion { val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => findCommonTypeForBinaryComparison(l.dataType, r.dataType, conf) - .orElse(findTightestCommonType(l.dataType, r.dataType)) + .orElse(findTightestCommonType(l.dataType, r.dataType, conf.caseSensitiveAnalysis)) } // The number of columns/expressions must match between LHS and RHS of an @@ -461,7 +482,7 @@ object TypeCoercion { } case i @ In(a, b) if b.exists(_.dataType != a.dataType) => - findWiderCommonType(i.children.map(_.dataType)) match { + findWiderCommonType(i.children.map(_.dataType), conf.caseSensitiveAnalysis) match { case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) case None => i } @@ -515,7 +536,7 @@ object TypeCoercion { /** * This ensure that the types for various functions are as expected. */ - object FunctionArgumentConversion extends TypeCoercionRule { + case class FunctionArgumentConversion(conf: SQLConf) extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. @@ -523,7 +544,7 @@ object TypeCoercion { case a @ CreateArray(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderCommonType(types) match { + findWiderCommonType(types, conf.caseSensitiveAnalysis) match { case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) case None => a } @@ -531,7 +552,7 @@ object TypeCoercion { case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && !haveSameType(children) => val types = children.map(_.dataType) - findWiderCommonType(types) match { + findWiderCommonType(types, conf.caseSensitiveAnalysis) match { case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType))) case None => c } @@ -542,7 +563,7 @@ object TypeCoercion { m.keys } else { val types = m.keys.map(_.dataType) - findWiderCommonType(types) match { + findWiderCommonType(types, conf.caseSensitiveAnalysis) match { case Some(finalDataType) => m.keys.map(Cast(_, finalDataType)) case None => m.keys } @@ -552,7 +573,7 @@ object TypeCoercion { m.values } else { val types = m.values.map(_.dataType) - findWiderCommonType(types) match { + findWiderCommonType(types, conf.caseSensitiveAnalysis) match { case Some(finalDataType) => m.values.map(Cast(_, finalDataType)) case None => m.values } @@ -580,7 +601,7 @@ object TypeCoercion { // compatible with every child column. case c @ Coalesce(es) if !haveSameType(es) => val types = es.map(_.dataType) - findWiderCommonType(types) match { + findWiderCommonType(types, conf.caseSensitiveAnalysis) match { case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) case None => c } @@ -590,14 +611,14 @@ object TypeCoercion { // string.g case g @ Greatest(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderTypeWithoutStringPromotion(types) match { + findWiderTypeWithoutStringPromotion(types, conf.caseSensitiveAnalysis) match { case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) case None => g } case l @ Least(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderTypeWithoutStringPromotion(types) match { + findWiderTypeWithoutStringPromotion(types, conf.caseSensitiveAnalysis) match { case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) case None => l } @@ -637,11 +658,11 @@ object TypeCoercion { /** * Coerces the type of different branches of a CASE WHEN statement to a common type. */ - object CaseWhenCoercion extends TypeCoercionRule { + case class CaseWhenCoercion(conf: SQLConf) extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => - val maybeCommonType = findWiderCommonType(c.valueTypes) + val maybeCommonType = findWiderCommonType(c.valueTypes, conf.caseSensitiveAnalysis) maybeCommonType.map { commonType => var changed = false val newBranches = c.branches.map { case (condition, value) => @@ -668,16 +689,17 @@ object TypeCoercion { /** * Coerces the type of different branches of If statement to a common type. */ - object IfCoercion extends TypeCoercionRule { + case class IfCoercion(conf: SQLConf) extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => - findWiderTypeForTwo(left.dataType, right.dataType).map { widestType => - val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) - val newRight = if (right.dataType == widestType) right else Cast(right, widestType) - If(pred, newLeft, newRight) + findWiderTypeForTwo(left.dataType, right.dataType, conf.caseSensitiveAnalysis).map { + widestType => + val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) + val newRight = if (right.dataType == widestType) right else Cast(right, widestType) + If(pred, newLeft, newRight) }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. case If(Literal(null, NullType), left, right) => If(Literal.create(null, BooleanType), left, right) @@ -776,12 +798,11 @@ object TypeCoercion { /** * Casts types according to the expected input types for [[Expression]]s. */ - class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { + case class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING) - override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -804,17 +825,18 @@ object TypeCoercion { } case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - findTightestCommonType(left.dataType, right.dataType).map { commonType => - if (b.inputType.acceptsType(commonType)) { - // If the expression accepts the tightest common type, cast to that. - val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) - val newRight = if (right.dataType == commonType) right else Cast(right, commonType) - b.withNewChildren(Seq(newLeft, newRight)) - } else { - // Otherwise, don't do anything with the expression. - b - } - }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. + findTightestCommonType(left.dataType, right.dataType, conf.caseSensitiveAnalysis).map { + commonType => + if (b.inputType.acceptsType(commonType)) { + // If the expression accepts the tightest common type, cast to that. + val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) + val newRight = if (right.dataType == commonType) right else Cast(right, commonType) + b.withNewChildren(Seq(newLeft, newRight)) + } else { + // Otherwise, don't do anything with the expression. + b + } + }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b00edca97cd44..0b1965c438e27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,7 +27,7 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit @@ -107,7 +107,13 @@ object SQLConf { * run tests in parallel. At the time this feature was implemented, this was a no-op since we * run unit tests (that does not involve SparkSession) in serial order. */ - def get: SQLConf = confGetter.get()() + def get: SQLConf = { + if (Utils.isTesting && TaskContext.get != null) { + // we're accessing it during task execution, fail. + throw new IllegalStateException("SQLConf should only be created and accessed on the driver.") + } + confGetter.get()() + } val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() @@ -1274,12 +1280,6 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ - if (Utils.isTesting && SparkEnv.get != null) { - // assert that we're only accessing it on the driver. - assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, - "SQLConf should only be created and accessed on the driver.") - } - /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 0bef11659fc9e..4ee12db9c10ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -81,11 +81,7 @@ abstract class DataType extends AbstractDataType { * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). */ private[spark] def sameType(other: DataType): Boolean = - if (SQLConf.get.caseSensitiveAnalysis) { - DataType.equalsIgnoreNullability(this, other) - } else { - DataType.equalsIgnoreCaseAndNullability(this, other) - } + DataType.equalsIgnoreNullability(this, other) /** * Returns the same data type but set all nullability fields are true @@ -218,7 +214,7 @@ object DataType { /** * Compares two types, ignoring nullability of ArrayType, MapType, StructType. */ - private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { + private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { (left, right) match { case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => equalsIgnoreNullability(leftElementType, rightElementType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 0acd3b490447d..f73e045685ee1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -128,17 +128,17 @@ class TypeCoercionSuite extends AnalysisTest { } private def checkWidenType( - widenFunc: (DataType, DataType) => Option[DataType], + widenFunc: (DataType, DataType, Boolean) => Option[DataType], t1: DataType, t2: DataType, expected: Option[DataType], isSymmetric: Boolean = true): Unit = { - var found = widenFunc(t1, t2) + var found = widenFunc(t1, t2, conf.caseSensitiveAnalysis) assert(found == expected, s"Expected $expected as wider common type for $t1 and $t2, found $found") // Test both directions to make sure the widening is symmetric. if (isSymmetric) { - found = widenFunc(t2, t1) + found = widenFunc(t2, t1, conf.caseSensitiveAnalysis) assert(found == expected, s"Expected $expected as wider common type for $t2 and $t1, found $found") } @@ -524,11 +524,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for expressions that implement ExpectsInputTypes") { import TypeCoercionSuite._ - ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(TypeCoercion.ImplicitTypeCasts(conf), AnyTypeUnaryExpression(Literal.create(null, NullType)), AnyTypeUnaryExpression(Literal.create(null, NullType))) - ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(TypeCoercion.ImplicitTypeCasts(conf), NumericTypeUnaryExpression(Literal.create(null, NullType)), NumericTypeUnaryExpression(Literal.create(null, DoubleType))) } @@ -536,17 +536,17 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for binary operators") { import TypeCoercionSuite._ - ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(TypeCoercion.ImplicitTypeCasts(conf), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) - ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(TypeCoercion.ImplicitTypeCasts(conf), NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType))) } test("coalesce casts") { - val rule = TypeCoercion.FunctionArgumentConversion + val rule = TypeCoercion.FunctionArgumentConversion(conf) val intLit = Literal(1) val longLit = Literal.create(1L) @@ -606,7 +606,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("CreateArray casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateArray(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -616,7 +616,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateArray(Literal(1.0) :: Literal(1) :: Literal("a") @@ -626,7 +626,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal("a"), StringType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateArray(Literal.create(null, DecimalType(5, 3)) :: Literal(1) :: Nil), @@ -634,7 +634,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(1).cast(DecimalType(13, 3)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateArray(Literal.create(null, DecimalType(5, 3)) :: Literal.create(null, DecimalType(22, 10)) :: Literal.create(null, DecimalType(38, 38)) @@ -647,7 +647,7 @@ class TypeCoercionSuite extends AnalysisTest { test("CreateMap casts") { // type coercion for map keys - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateMap(Literal(1) :: Literal("a") :: Literal.create(2.0, FloatType) @@ -658,7 +658,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal.create(2.0, FloatType), FloatType) :: Literal("b") :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateMap(Literal.create(null, DecimalType(5, 3)) :: Literal("a") :: Literal.create(2.0, FloatType) @@ -670,7 +670,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal("b") :: Nil)) // type coercion for map values - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateMap(Literal(1) :: Literal("a") :: Literal(2) @@ -681,7 +681,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(2) :: Cast(Literal(3.0), StringType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateMap(Literal(1) :: Literal.create(null, DecimalType(38, 0)) :: Literal(2) @@ -693,7 +693,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) :: Nil)) // type coercion for both map keys and values - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateMap(Literal(1) :: Literal("a") :: Literal(2.0) @@ -708,7 +708,7 @@ class TypeCoercionSuite extends AnalysisTest { test("greatest/least cast") { for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), operator(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -717,7 +717,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), operator(Literal(1L) :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) @@ -726,7 +726,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal(1), DecimalType(22, 0)) :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), operator(Literal(1.0) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) @@ -735,7 +735,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(10, 5)).cast(DoubleType) :: Literal(1).cast(DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), operator(Literal.create(null, DecimalType(15, 0)) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) @@ -744,7 +744,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(20, 5)) :: Literal(1).cast(DecimalType(20, 5)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), operator(Literal.create(2L, LongType) :: Literal(1) :: Literal.create(null, DecimalType(10, 5)) @@ -757,25 +757,25 @@ class TypeCoercionSuite extends AnalysisTest { } test("nanvl casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)), NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)), NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)), NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)), NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType))) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)), NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType))) } test("type coercion for If") { - val rule = TypeCoercion.IfCoercion + val rule = TypeCoercion.IfCoercion(conf) val intLit = Literal(1) val doubleLit = Literal(1.0) val trueLit = Literal.create(true, BooleanType) @@ -823,20 +823,20 @@ class TypeCoercionSuite extends AnalysisTest { } test("type coercion for CaseKeyWhen") { - ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(TypeCoercion.ImplicitTypeCasts(conf), CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) - ruleTest(TypeCoercion.CaseWhenCoercion, + ruleTest(TypeCoercion.CaseWhenCoercion(conf), CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) ) - ruleTest(TypeCoercion.CaseWhenCoercion, + ruleTest(TypeCoercion.CaseWhenCoercion(conf), CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))), CaseWhen(Seq((Literal(true), Literal(1.2))), Cast(Literal.create(1, DecimalType(7, 2)), DoubleType)) ) - ruleTest(TypeCoercion.CaseWhenCoercion, + ruleTest(TypeCoercion.CaseWhenCoercion(conf), CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))), CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))), Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2))) @@ -1085,7 +1085,7 @@ class TypeCoercionSuite extends AnalysisTest { private val timeZoneResolver = ResolveTimeZone(new SQLConf) private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = { - timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan)) + timeZoneResolver(TypeCoercion.WidenSetOperationTypes(conf)(plan)) } test("WidenSetOperationTypes for except and intersect") { @@ -1256,7 +1256,7 @@ class TypeCoercionSuite extends AnalysisTest { test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " + "in aggregation function like sum") { - val rules = Seq(FunctionArgumentConversion, Division) + val rules = Seq(FunctionArgumentConversion(conf), Division) // Casts Integer to Double ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType)))) // Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will @@ -1275,7 +1275,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("SPARK-17117 null type coercion in divide") { - val rules = Seq(FunctionArgumentConversion, Division, new ImplicitTypeCasts(conf)) + val rules = Seq(FunctionArgumentConversion(conf), Division, ImplicitTypeCasts(conf)) val nullLit = Literal.create(null, NullType) ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index c502e583a55c5..e2a1a57c7dd4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} +import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -898,6 +898,7 @@ object SparkSession extends Logging { * @since 2.0.0 */ def getOrCreate(): SparkSession = synchronized { + assertOnDriver() // Get the session from current thread's active session. var session = activeThreadSession.get() if ((session ne null) && !session.sparkContext.isStopped) { @@ -1022,14 +1023,20 @@ object SparkSession extends Logging { * * @since 2.2.0 */ - def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) + def getActiveSession: Option[SparkSession] = { + assertOnDriver() + Option(activeThreadSession.get) + } /** * Returns the default SparkSession that is returned by the builder. * * @since 2.2.0 */ - def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + def getDefaultSession: Option[SparkSession] = { + assertOnDriver() + Option(defaultSession.get) + } /** * Returns the currently active SparkSession, otherwise the default one. If there is no default @@ -1062,6 +1069,14 @@ object SparkSession extends Logging { } } + private def assertOnDriver(): Unit = { + if (Utils.isTesting && TaskContext.get != null) { + // we're accessing it during task execution, fail. + throw new IllegalStateException( + "SparkSession should only be created and accessed on the driver.") + } + } + /** * Helper method to create an instance of `SessionState` based on `className` from conf. * The result is either `SessionState` or a Hive based `SessionState`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index f9a24806953e6..1edf27619ad7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils @@ -521,6 +522,8 @@ object PartitioningUtils { private val findWiderTypeForPartitionColumn: (DataType, DataType) => DataType = { case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => StringType case (DoubleType, LongType) | (LongType, DoubleType) => StringType - case (t1, t2) => TypeCoercion.findWiderTypeForTwo(t1, t2).getOrElse(StringType) + case (t1, t2) => + TypeCoercion.findWiderTypeForTwo( + t1, t2, SQLConf.get.caseSensitiveAnalysis).getOrElse(StringType) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index a270a6451d5dd..e0424b7478122 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -44,6 +45,7 @@ private[sql] object JsonInferSchema { createParser: (JsonFactory, T) => JsonParser): StructType = { val parseMode = configOptions.parseMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord + val caseSensitive = SQLConf.get.caseSensitiveAnalysis // perform schema inference on each row and merge afterwards val rootType = json.mapPartitions { iter => @@ -53,7 +55,7 @@ private[sql] object JsonInferSchema { try { Utils.tryWithResource(createParser(factory, row)) { parser => parser.nextToken() - Some(inferField(parser, configOptions)) + Some(inferField(parser, configOptions, caseSensitive)) } } catch { case e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match { @@ -68,7 +70,7 @@ private[sql] object JsonInferSchema { } } }.fold(StructType(Nil))( - compatibleRootType(columnNameOfCorruptRecord, parseMode)) + compatibleRootType(columnNameOfCorruptRecord, parseMode, caseSensitive)) canonicalizeType(rootType) match { case Some(st: StructType) => st @@ -98,14 +100,15 @@ private[sql] object JsonInferSchema { /** * Infer the type of a json document from the parser's token stream */ - private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { + private def inferField( + parser: JsonParser, configOptions: JSONOptions, caseSensitive: Boolean): DataType = { import com.fasterxml.jackson.core.JsonToken._ parser.getCurrentToken match { case null | VALUE_NULL => NullType case FIELD_NAME => parser.nextToken() - inferField(parser, configOptions) + inferField(parser, configOptions, caseSensitive) case VALUE_STRING if parser.getTextLength < 1 => // Zero length strings and nulls have special handling to deal @@ -122,7 +125,7 @@ private[sql] object JsonInferSchema { while (nextUntil(parser, END_OBJECT)) { builder += StructField( parser.getCurrentName, - inferField(parser, configOptions), + inferField(parser, configOptions, caseSensitive), nullable = true) } val fields: Array[StructField] = builder.result() @@ -137,7 +140,7 @@ private[sql] object JsonInferSchema { var elementType: DataType = NullType while (nextUntil(parser, END_ARRAY)) { elementType = compatibleType( - elementType, inferField(parser, configOptions)) + elementType, inferField(parser, configOptions, caseSensitive), caseSensitive) } ArrayType(elementType) @@ -243,13 +246,14 @@ private[sql] object JsonInferSchema { */ private def compatibleRootType( columnNameOfCorruptRecords: String, - parseMode: ParseMode): (DataType, DataType) => DataType = { + parseMode: ParseMode, + caseSensitive: Boolean): (DataType, DataType) => DataType = { // Since we support array of json objects at the top level, // we need to check the element type and find the root level data type. case (ArrayType(ty1, _), ty2) => - compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) + compatibleRootType(columnNameOfCorruptRecords, parseMode, caseSensitive)(ty1, ty2) case (ty1, ArrayType(ty2, _)) => - compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) + compatibleRootType(columnNameOfCorruptRecords, parseMode, caseSensitive)(ty1, ty2) // Discard null/empty documents case (struct: StructType, NullType) => struct case (NullType, struct: StructType) => struct @@ -259,7 +263,7 @@ private[sql] object JsonInferSchema { withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode) // If we get anything else, we call compatibleType. // Usually, when we reach here, ty1 and ty2 are two StructTypes. - case (ty1, ty2) => compatibleType(ty1, ty2) + case (ty1, ty2) => compatibleType(ty1, ty2, caseSensitive) } private[this] val emptyStructFieldArray = Array.empty[StructField] @@ -267,8 +271,8 @@ private[sql] object JsonInferSchema { /** * Returns the most general data type for two given data types. */ - def compatibleType(t1: DataType, t2: DataType): DataType = { - TypeCoercion.findTightestCommonType(t1, t2).getOrElse { + def compatibleType(t1: DataType, t2: DataType, caseSensitive: Boolean): DataType = { + TypeCoercion.findTightestCommonType(t1, t2, caseSensitive).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. (t1, t2) match { // Double support larger range than fixed decimal, DecimalType.Maximum should be enough @@ -303,7 +307,8 @@ private[sql] object JsonInferSchema { val f2Name = fields2(f2Idx).name val comp = f1Name.compareTo(f2Name) if (comp == 0) { - val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType) + val dataType = compatibleType( + fields1(f1Idx).dataType, fields2(f2Idx).dataType, caseSensitive) newFields.add(StructField(f1Name, dataType, nullable = true)) f1Idx += 1 f2Idx += 1 @@ -326,15 +331,17 @@ private[sql] object JsonInferSchema { StructType(newFields.toArray(emptyStructFieldArray)) case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => - ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) + ArrayType( + compatibleType(elementType1, elementType2, caseSensitive), + containsNull1 || containsNull2) // The case that given `DecimalType` is capable of given `IntegralType` is handled in // `findTightestCommonTypeOfTwo`. Both cases below will be executed only when // the given `DecimalType` is not capable of the given `IntegralType`. case (t1: IntegralType, t2: DecimalType) => - compatibleType(DecimalType.forType(t1), t2) + compatibleType(DecimalType.forType(t1), t2, caseSensitive) case (t1: DecimalType, t2: IntegralType) => - compatibleType(t1, DecimalType.forType(t2)) + compatibleType(t1, DecimalType.forType(t2), caseSensitive) // strings and every string is a Json object. case (_, _) => StringType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 4b3921c61a000..34d23ee53220d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -122,10 +122,10 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Get compatible type") { def checkDataType(t1: DataType, t2: DataType, expected: DataType) { - var actual = compatibleType(t1, t2) + var actual = compatibleType(t1, t2, conf.caseSensitiveAnalysis) assert(actual == expected, s"Expected $expected as the most general data type for $t1 and $t2, found $actual") - actual = compatibleType(t2, t1) + actual = compatibleType(t2, t1, conf.caseSensitiveAnalysis) assert(actual == expected, s"Expected $expected as the most general data type for $t1 and $t2, found $actual") } From 75cf369c742e7c7b68f384d123447c97be95c9f0 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Fri, 11 May 2018 09:05:35 +0800 Subject: [PATCH 0781/2461] [SPARK-24197][SPARKR][SQL] Adding array_sort function to SparkR ## What changes were proposed in this pull request? The PR adds array_sort function to SparkR. ## How was this patch tested? Tests added into R/pkg/tests/fulltests/test_sparkSQL.R ## Example ``` > df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L)))) > head(collect(select(df, array_sort(df[[1]])))) ``` Result: ``` array_sort(_1) 1 1, 2, 3, NA 2 4, 5, 6, NA, NA ``` Author: Marek Novotny Closes #21294 from mn-mikke/SPARK-24197. --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 21 ++++++++++++++++++--- R/pkg/R/generics.R | 4 ++++ R/pkg/tests/fulltests/test_sparkSQL.R | 13 +++++++++---- 4 files changed, 32 insertions(+), 7 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 8cd00352d1956..5f8209689a559 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -204,6 +204,7 @@ exportMethods("%<=>%", "array_max", "array_min", "array_position", + "array_sort", "asc", "ascii", "asin", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 04d0e4620b28a..1f97054443e1b 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -207,7 +207,7 @@ NULL #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) #' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) -#' head(select(tmp, array_position(tmp$v1, 21))) +#' head(select(tmp, array_position(tmp$v1, 21), array_sort(tmp$v1))) #' head(select(tmp, flatten(tmp$v1))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) @@ -3043,6 +3043,20 @@ setMethod("array_position", column(jc) }) +#' @details +#' \code{array_sort}: Sorts the input array in ascending order. The elements of the input array +#' must be orderable. NA elements will be placed at the end of the returned array. +#' +#' @rdname column_collection_functions +#' @aliases array_sort array_sort,Column-method +#' @note array_sort since 2.4.0 +setMethod("array_sort", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_sort", x@jc) + column(jc) + }) + #' @details #' \code{flatten}: Transforms an array of arrays into a single array. #' @@ -3125,8 +3139,9 @@ setMethod("size", }) #' @details -#' \code{sort_array}: Sorts the input array in ascending or descending order according -#' to the natural ordering of the array elements. +#' \code{sort_array}: Sorts the input array in ascending or descending order according to +#' the natural ordering of the array elements. NA elements will be placed at the beginning of +#' the returned array in ascending order or at the end of the returned array in descending order. #' #' @rdname column_collection_functions #' @param asc a logical flag indicating the sorting order. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 4ef12d19b3575..5faa51eef3abd 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -769,6 +769,10 @@ setGeneric("array_min", function(x) { standardGeneric("array_min") }) #' @name NULL setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_sort", function(x) { standardGeneric("array_sort") }) + #' @rdname column_string_functions #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 43725e0ebd3bf..b8bfded0ebf2d 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1479,8 +1479,7 @@ test_that("column functions", { df5 <- createDataFrame(list(list(a = "010101"))) expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") - # Test array_contains(), array_max(), array_min(), array_position(), element_at() - # and sort_array() + # Test array_contains(), array_max(), array_min(), array_position() and element_at() df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] expect_equal(result, c(TRUE, FALSE)) @@ -1497,10 +1496,16 @@ test_that("column functions", { result <- collect(select(df, element_at(df[[1]], 1L)))[[1]] expect_equal(result, c(1, 6)) + # Test array_sort() and sort_array() + df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L)))) + + result <- collect(select(df, array_sort(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L, NA), list(4L, 5L, 6L, NA, NA))) + result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]] - expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L))) + expect_equal(result, list(list(3L, 2L, 1L, NA), list(6L, 5L, 4L, NA, NA))) result <- collect(select(df, sort_array(df[[1]])))[[1]] - expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) + expect_equal(result, list(list(NA, 1L, 2L, 3L), list(NA, NA, 4L, 5L, 6L))) # Test flattern df <- createDataFrame(list(list(list(list(1L, 2L), list(3L, 4L))), From 54032682b910dc5089af27d2c7b6efe55700f034 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 11 May 2018 17:40:35 +0800 Subject: [PATCH 0782/2461] [SPARK-24182][YARN] Improve error message when client AM fails. Instead of always throwing a generic exception when the AM fails, print a generic error and throw the exception with the YARN diagnostics containing the reason for the failure. There was an issue with YARN sometimes providing a generic diagnostic message, even though the AM provides a failure reason when unregistering. That was happening because the AM was registering too late, and if errors happened before the registration, YARN would just create a generic "ExitCodeException" which wasn't very helpful. Since most errors in this path are a result of not being able to connect to the driver, this change modifies the AM registration a bit so that the AM is registered before the connection to the driver is established. That way, errors are properly propagated through YARN back to the driver. As part of that, I also removed the code that retried connections to the driver from the client AM. At that point, the driver should already be up and waiting for connections, so it's unlikely that retrying would help - and in case it does, that means a flaky network, which would mean problems would probably show up again. The effect of that is that connection-related errors are reported back to the driver much faster now (through the YARN report). One thing to note is that there seems to be a race on the YARN side that causes a report to be sent to the client without the corresponding diagnostics string from the AM; the diagnostics are available later from the RM web page. For that reason, the generic error messages are kept in the Spark scheduler code, to help guide users to a way of debugging their failure. Also of note is that if YARN's max attempts configuration is lower than Spark's, Spark will not unregister the AM with a proper diagnostics message. Unfortunately there seems to be no way to unregister the AM and still allow further re-attempts to happen. Testing: - existing unit tests - some of our integration tests - hardcoded an invalid driver address in the code and verified the error in the shell. e.g. ``` scala> 18/05/04 15:09:34 ERROR cluster.YarnClientSchedulerBackend: YARN application has exited unexpectedly with state FAILED! Check the YARN application logs for more details. 18/05/04 15:09:34 ERROR cluster.YarnClientSchedulerBackend: Diagnostics message: Uncaught exception: org.apache.spark.SparkException: Exception thrown in awaitResult: Caused by: java.io.IOException: Failed to connect to localhost/127.0.0.1:1234 ``` Author: Marcelo Vanzin Closes #21243 from vanzin/SPARK-24182. --- docs/running-on-yarn.md | 5 +- .../spark/deploy/yarn/ApplicationMaster.scala | 103 +++++++----------- .../org/apache/spark/deploy/yarn/Client.scala | 43 +++++--- .../spark/deploy/yarn/YarnRMClient.scala | 29 +++-- .../cluster/YarnClientSchedulerBackend.scala | 35 ++++-- 5 files changed, 112 insertions(+), 103 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index ceda8a3ae2403..c9e68c3bfd056 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -133,9 +133,8 @@ To use a custom metrics.properties for the application master and executors, upd diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 595077e7e809f..3d6ee50b070a3 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -346,7 +346,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends synchronized { if (!finished) { val inShutdown = ShutdownHookManager.inShutdown() - if (registered) { + if (registered || !isClusterMode) { exitCode = code finalStatus = status } else { @@ -389,37 +389,40 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } private def registerAM( + host: String, + port: Int, _sparkConf: SparkConf, - _rpcEnv: RpcEnv, - driverRef: RpcEndpointRef, - uiAddress: Option[String]) = { + uiAddress: Option[String]): Unit = { val appId = client.getAttemptId().getApplicationId().toString() val attemptId = client.getAttemptId().getAttemptId().toString() val historyAddress = ApplicationMaster .getHistoryServerAddress(_sparkConf, yarnConf, appId, attemptId) - val driverUrl = RpcEndpointAddress( - _sparkConf.get("spark.driver.host"), - _sparkConf.get("spark.driver.port").toInt, + client.register(host, port, yarnConf, _sparkConf, uiAddress, historyAddress) + registered = true + } + + private def createAllocator(driverRef: RpcEndpointRef, _sparkConf: SparkConf): Unit = { + val appId = client.getAttemptId().getApplicationId().toString() + val driverUrl = RpcEndpointAddress(driverRef.address.host, driverRef.address.port, CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString // Before we initialize the allocator, let's log the information about how executors will // be run up front, to avoid printing this out for every single executor being launched. // Use placeholders for information that changes such as executor IDs. logInfo { - val executorMemory = sparkConf.get(EXECUTOR_MEMORY).toInt - val executorCores = sparkConf.get(EXECUTOR_CORES) - val dummyRunner = new ExecutorRunnable(None, yarnConf, sparkConf, driverUrl, "", + val executorMemory = _sparkConf.get(EXECUTOR_MEMORY).toInt + val executorCores = _sparkConf.get(EXECUTOR_CORES) + val dummyRunner = new ExecutorRunnable(None, yarnConf, _sparkConf, driverUrl, "", "", executorMemory, executorCores, appId, securityMgr, localResources) dummyRunner.launchContextDebugInfo() } - allocator = client.register(driverUrl, - driverRef, + allocator = client.createAllocator( yarnConf, _sparkConf, - uiAddress, - historyAddress, + driverUrl, + driverRef, securityMgr, localResources) @@ -434,15 +437,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends reporterThread = launchReporterThread() } - /** - * @return An [[RpcEndpoint]] that communicates with the driver's scheduler backend. - */ - private def createSchedulerRef(host: String, port: String): RpcEndpointRef = { - rpcEnv.setupEndpointRef( - RpcAddress(host, port.toInt), - YarnSchedulerBackend.ENDPOINT_NAME) - } - private def runDriver(): Unit = { addAmIpFilter(None) userClassThread = startUserApplication() @@ -456,11 +450,16 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends Duration(totalWaitTime, TimeUnit.MILLISECONDS)) if (sc != null) { rpcEnv = sc.env.rpcEnv - val driverRef = createSchedulerRef( - sc.getConf.get("spark.driver.host"), - sc.getConf.get("spark.driver.port")) - registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl)) - registered = true + + val userConf = sc.getConf + val host = userConf.get("spark.driver.host") + val port = userConf.get("spark.driver.port").toInt + registerAM(host, port, userConf, sc.ui.map(_.webUrl)) + + val driverRef = rpcEnv.setupEndpointRef( + RpcAddress(host, port), + YarnSchedulerBackend.ENDPOINT_NAME) + createAllocator(driverRef, userConf) } else { // Sanity check; should never happen in normal operation, since sc should only be null // if the user app did not create a SparkContext. @@ -486,10 +485,18 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends val amCores = sparkConf.get(AM_CORES) rpcEnv = RpcEnv.create("sparkYarnAM", hostname, hostname, -1, sparkConf, securityMgr, amCores, true) - val driverRef = waitForSparkDriver() + + // The client-mode AM doesn't listen for incoming connections, so report an invalid port. + registerAM(hostname, -1, sparkConf, sparkConf.getOption("spark.driver.appUIAddress")) + + // The driver should be up and listening, so unlike cluster mode, just try to connect to it + // with no waiting or retrying. + val (driverHost, driverPort) = Utils.parseHostPort(args.userArgs(0)) + val driverRef = rpcEnv.setupEndpointRef( + RpcAddress(driverHost, driverPort), + YarnSchedulerBackend.ENDPOINT_NAME) addAmIpFilter(Some(driverRef)) - registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption("spark.driver.appUIAddress")) - registered = true + createAllocator(driverRef, sparkConf) // In client mode the actor will stop the reporter thread. reporterThread.join() @@ -600,40 +607,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } } - private def waitForSparkDriver(): RpcEndpointRef = { - logInfo("Waiting for Spark driver to be reachable.") - var driverUp = false - val hostport = args.userArgs(0) - val (driverHost, driverPort) = Utils.parseHostPort(hostport) - - // Spark driver should already be up since it launched us, but we don't want to - // wait forever, so wait 100 seconds max to match the cluster mode setting. - val totalWaitTimeMs = sparkConf.get(AM_MAX_WAIT_TIME) - val deadline = System.currentTimeMillis + totalWaitTimeMs - - while (!driverUp && !finished && System.currentTimeMillis < deadline) { - try { - val socket = new Socket(driverHost, driverPort) - socket.close() - logInfo("Driver now available: %s:%s".format(driverHost, driverPort)) - driverUp = true - } catch { - case e: Exception => - logError("Failed to connect to driver at %s:%s, retrying ...". - format(driverHost, driverPort)) - Thread.sleep(100L) - } - } - - if (!driverUp) { - throw new SparkException("Failed to connect to driver!") - } - - sparkConf.set("spark.driver.host", driverHost) - sparkConf.set("spark.driver.port", driverPort.toString) - createSchedulerRef(driverHost, driverPort.toString) - } - /** Add the Yarn IP filter that is required for properly securing the UI. */ private def addAmIpFilter(driver: Option[RpcEndpointRef]) = { val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 134b3e5fef11a..7225ff03dc34e 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1019,8 +1019,7 @@ private[spark] class Client( appId: ApplicationId, returnOnRunning: Boolean = false, logApplicationReport: Boolean = true, - interval: Long = sparkConf.get(REPORT_INTERVAL)): - (YarnApplicationState, FinalApplicationStatus) = { + interval: Long = sparkConf.get(REPORT_INTERVAL)): YarnAppReport = { var lastState: YarnApplicationState = null while (true) { Thread.sleep(interval) @@ -1031,11 +1030,13 @@ private[spark] class Client( case e: ApplicationNotFoundException => logError(s"Application $appId not found.") cleanupStagingDir(appId) - return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED) + return YarnAppReport(YarnApplicationState.KILLED, FinalApplicationStatus.KILLED, None) case NonFatal(e) => - logError(s"Failed to contact YARN for application $appId.", e) + val msg = s"Failed to contact YARN for application $appId." + logError(msg, e) // Don't necessarily clean up staging dir because status is unknown - return (YarnApplicationState.FAILED, FinalApplicationStatus.FAILED) + return YarnAppReport(YarnApplicationState.FAILED, FinalApplicationStatus.FAILED, + Some(msg)) } val state = report.getYarnApplicationState @@ -1073,14 +1074,14 @@ private[spark] class Client( } if (state == YarnApplicationState.FINISHED || - state == YarnApplicationState.FAILED || - state == YarnApplicationState.KILLED) { + state == YarnApplicationState.FAILED || + state == YarnApplicationState.KILLED) { cleanupStagingDir(appId) - return (state, report.getFinalApplicationStatus) + return createAppReport(report) } if (returnOnRunning && state == YarnApplicationState.RUNNING) { - return (state, report.getFinalApplicationStatus) + return createAppReport(report) } lastState = state @@ -1129,16 +1130,17 @@ private[spark] class Client( throw new SparkException(s"Application $appId finished with status: $state") } } else { - val (yarnApplicationState, finalApplicationStatus) = monitorApplication(appId) - if (yarnApplicationState == YarnApplicationState.FAILED || - finalApplicationStatus == FinalApplicationStatus.FAILED) { + val YarnAppReport(appState, finalState, diags) = monitorApplication(appId) + if (appState == YarnApplicationState.FAILED || finalState == FinalApplicationStatus.FAILED) { + diags.foreach { err => + logError(s"Application diagnostics message: $err") + } throw new SparkException(s"Application $appId finished with failed status") } - if (yarnApplicationState == YarnApplicationState.KILLED || - finalApplicationStatus == FinalApplicationStatus.KILLED) { + if (appState == YarnApplicationState.KILLED || finalState == FinalApplicationStatus.KILLED) { throw new SparkException(s"Application $appId is killed") } - if (finalApplicationStatus == FinalApplicationStatus.UNDEFINED) { + if (finalState == FinalApplicationStatus.UNDEFINED) { throw new SparkException(s"The final status of application $appId is undefined") } } @@ -1477,6 +1479,12 @@ private object Client extends Logging { uri.startsWith(s"$LOCAL_SCHEME:") } + def createAppReport(report: ApplicationReport): YarnAppReport = { + val diags = report.getDiagnostics() + val diagsOpt = if (diags != null && diags.nonEmpty) Some(diags) else None + YarnAppReport(report.getYarnApplicationState(), report.getFinalApplicationStatus(), diagsOpt) + } + } private[spark] class YarnClusterApplication extends SparkApplication { @@ -1491,3 +1499,8 @@ private[spark] class YarnClusterApplication extends SparkApplication { } } + +private[spark] case class YarnAppReport( + appState: YarnApplicationState, + finalState: FinalApplicationStatus, + diagnostics: Option[String]) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 17234b120ae13..b59dcf158d87c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -42,23 +42,20 @@ private[spark] class YarnRMClient extends Logging { /** * Registers the application master with the RM. * + * @param driverHost Host name where driver is running. + * @param driverPort Port where driver is listening. * @param conf The Yarn configuration. * @param sparkConf The Spark configuration. * @param uiAddress Address of the SparkUI. * @param uiHistoryAddress Address of the application on the History Server. - * @param securityMgr The security manager. - * @param localResources Map with information about files distributed via YARN's cache. */ def register( - driverUrl: String, - driverRef: RpcEndpointRef, + driverHost: String, + driverPort: Int, conf: YarnConfiguration, sparkConf: SparkConf, uiAddress: Option[String], - uiHistoryAddress: String, - securityMgr: SecurityManager, - localResources: Map[String, LocalResource] - ): YarnAllocator = { + uiHistoryAddress: String): Unit = { amClient = AMRMClient.createAMRMClient() amClient.init(conf) amClient.start() @@ -70,10 +67,19 @@ private[spark] class YarnRMClient extends Logging { logInfo("Registering the ApplicationMaster") synchronized { - amClient.registerApplicationMaster(driverRef.address.host, driverRef.address.port, - trackingUrl) + amClient.registerApplicationMaster(driverHost, driverPort, trackingUrl) registered = true } + } + + def createAllocator( + conf: YarnConfiguration, + sparkConf: SparkConf, + driverUrl: String, + driverRef: RpcEndpointRef, + securityMgr: SecurityManager, + localResources: Map[String, LocalResource]): YarnAllocator = { + require(registered, "Must register AM before creating allocator.") new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr, localResources, new SparkRackResolver()) } @@ -88,6 +94,9 @@ private[spark] class YarnRMClient extends Logging { if (registered) { amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress) } + if (amClient != null) { + amClient.stop() + } } /** Returns the attempt ID. */ diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 06e54a2eaf95a..f1a8df00f9c5b 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.yarn.api.records.YarnApplicationState import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.deploy.yarn.{Client, ClientArguments} +import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnAppReport} import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.launcher.SparkAppHandle @@ -75,13 +75,23 @@ private[spark] class YarnClientSchedulerBackend( val monitorInterval = conf.get(CLIENT_LAUNCH_MONITOR_INTERVAL) assert(client != null && appId.isDefined, "Application has not been submitted yet!") - val (state, _) = client.monitorApplication(appId.get, returnOnRunning = true, - interval = monitorInterval) // blocking + val YarnAppReport(state, _, diags) = client.monitorApplication(appId.get, + returnOnRunning = true, interval = monitorInterval) if (state == YarnApplicationState.FINISHED || - state == YarnApplicationState.FAILED || - state == YarnApplicationState.KILLED) { - throw new SparkException("Yarn application has already ended! " + - "It might have been killed or unable to launch application master.") + state == YarnApplicationState.FAILED || + state == YarnApplicationState.KILLED) { + val genericMessage = "The YARN application has already ended! " + + "It might have been killed or the Application Master may have failed to start. " + + "Check the YARN application logs for more details." + val exceptionMsg = diags match { + case Some(msg) => + logError(genericMessage) + msg + + case None => + genericMessage + } + throw new SparkException(exceptionMsg) } if (state == YarnApplicationState.RUNNING) { logInfo(s"Application ${appId.get} has started running.") @@ -100,8 +110,13 @@ private[spark] class YarnClientSchedulerBackend( override def run() { try { - val (state, _) = client.monitorApplication(appId.get, logApplicationReport = false) - logError(s"Yarn application has already exited with state $state!") + val YarnAppReport(_, state, diags) = + client.monitorApplication(appId.get, logApplicationReport = true) + logError(s"YARN application has exited unexpectedly with state $state! " + + "Check the YARN application logs for more details.") + diags.foreach { err => + logError(s"Diagnostics message: $err") + } allowInterrupt = false sc.stop() } catch { @@ -124,7 +139,7 @@ private[spark] class YarnClientSchedulerBackend( private def asyncMonitorApplication(): MonitorThread = { assert(client != null && appId.isDefined, "Application has not been submitted yet!") val t = new MonitorThread - t.setName("Yarn application state monitor") + t.setName("YARN application state monitor") t.setDaemon(true) t } From 928845a42230a2c0a318011002a54ad871468b2e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 11 May 2018 10:00:28 -0700 Subject: [PATCH 0783/2461] [SPARK-24172][SQL] we should not apply operator pushdown to data source v2 many times ## What changes were proposed in this pull request? In `PushDownOperatorsToDataSource`, we use `transformUp` to match `PhysicalOperation` and apply pushdown. This is problematic if we have multiple `Filter` and `Project` above the data source v2 relation. e.g. for a query ``` Project Filter DataSourceV2Relation ``` The pattern match will be triggered twice and we will do operator pushdown twice. This is unnecessary, we can use `mapChildren` to only apply pushdown once. ## How was this patch tested? existing test Author: Wenchen Fan Closes #21230 from cloud-fan/step2. --- .../v2/PushDownOperatorsToDataSource.scala | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 9293d4f831bff..e894f8afd6762 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -23,17 +23,10 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project import org.apache.spark.sql.catalyst.rules.Rule object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { - override def apply( - plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan match { // PhysicalOperation guarantees that filters are deterministic; no need to check - case PhysicalOperation(project, newFilters, relation : DataSourceV2Relation) => - // merge the filters - val filters = relation.filters match { - case Some(existing) => - existing ++ newFilters - case _ => - newFilters - } + case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => + assert(relation.filters.isEmpty, "data source v2 should do push down only once.") val projectAttrs = project.map(_.toAttribute) val projectSet = AttributeSet(project.flatMap(_.references)) @@ -67,5 +60,7 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { } else { filtered } + + case other => other.mapChildren(apply) } } From 92f6f52ff0ce47e046656ca8bed7d7bfbbb42dcb Mon Sep 17 00:00:00 2001 From: aditkumar Date: Fri, 11 May 2018 14:42:23 -0500 Subject: [PATCH 0784/2461] [MINOR][DOCS] Documenting months_between direction ## What changes were proposed in this pull request? It's useful to know what relationship between date1 and date2 results in a positive number. Author: aditkumar Author: Adit Kumar Closes #20787 from aditkumar/master. --- R/pkg/R/functions.R | 6 +++++- python/pyspark/sql/functions.py | 7 +++++-- .../catalyst/expressions/datetimeExpressions.scala | 14 +++++++++++--- .../spark/sql/catalyst/util/DateTimeUtils.scala | 8 ++++---- .../scala/org/apache/spark/sql/functions.scala | 7 ++++++- 5 files changed, 31 insertions(+), 11 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 1f97054443e1b..4964594284aa0 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1912,6 +1912,7 @@ setMethod("atan2", signature(y = "Column"), #' @details #' \code{datediff}: Returns the number of days from \code{y} to \code{x}. +#' If \code{y} is later than \code{x} then the result is positive. #' #' @rdname column_datetime_diff_functions #' @aliases datediff datediff,Column-method @@ -1971,7 +1972,10 @@ setMethod("levenshtein", signature(y = "Column"), }) #' @details -#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}. +#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}. +#' If \code{y} is later than \code{x}, then the result is positive. If \code{y} and \code{x} +#' are on the same day of month, or both are the last day of month, time of day will be ignored. +#' Otherwise, the difference is calculated based on 31 days per month, and rounded to 8 digits. #' #' @rdname column_datetime_diff_functions #' @aliases months_between months_between,Column-method diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f5a584152b4f6..b62748e9a2d6c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1108,8 +1108,11 @@ def add_months(start, months): @since(1.5) def months_between(date1, date2, roundOff=True): """ - Returns the number of months between date1 and date2. - Unless `roundOff` is set to `False`, the result is rounded off to 8 digits. + Returns number of months between dates date1 and date2. + If date1 is later than date2, then the result is positive. + If date1 and date2 are on the same day of month, or both are the last day of month, + returns an integer (time of day will be ignored). + The result is rounded off to 8 digits unless `roundOff` is set to `False`. >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2']) >>> df.select(months_between(df.date1, df.date2).alias('months')).collect() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 76aa61415a11f..03422fecb3209 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1194,13 +1194,21 @@ case class AddMonths(startDate: Expression, numMonths: Expression) } /** - * Returns number of months between dates date1 and date2. + * Returns number of months between times `timestamp1` and `timestamp2`. + * If `timestamp1` is later than `timestamp2`, then the result is positive. + * If `timestamp1` and `timestamp2` are on the same day of month, or both + * are the last day of month, time of day will be ignored. Otherwise, the + * difference is calculated based on 31 days per month, and rounded to + * 8 digits unless roundOff=false. */ // scalastyle:off line.size.limit @ExpressionDescription( usage = """ - _FUNC_(timestamp1, timestamp2[, roundOff]) - Returns number of months between `timestamp1` and `timestamp2`. - The result is rounded to 8 decimal places by default. Set roundOff=false otherwise."""", + _FUNC_(timestamp1, timestamp2[, roundOff]) - If `timestamp1` is later than `timestamp2`, then the result + is positive. If `timestamp1` and `timestamp2` are on the same day of month, or both + are the last day of month, time of day will be ignored. Otherwise, the difference is + calculated based on 31 days per month, and rounded to 8 digits unless roundOff=false. + """, examples = """ Examples: > SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30'); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index e646da0659e85..80f15053005ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -885,13 +885,13 @@ object DateTimeUtils { /** * Returns number of months between time1 and time2. time1 and time2 are expressed in - * microseconds since 1.1.1970. + * microseconds since 1.1.1970. If time1 is later than time2, the result is positive. * - * If time1 and time2 having the same day of month, or both are the last day of month, - * it returns an integer (time under a day will be ignored). + * If time1 and time2 are on the same day of month, or both are the last day of month, + * returns, time of day will be ignored. * * Otherwise, the difference is calculated based on 31 days per month. - * If `roundOff` is set to true, the result is rounded to 8 decimal places. + * The result is rounded to 8 decimal places if `roundOff` is set to true. */ def monthsBetween( time1: SQLTimestamp, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 225de0051d6fa..e7f866ddca681 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2903,7 +2903,12 @@ object functions { /** * Returns number of months between dates `date1` and `date2`. - * The result is rounded off to 8 digits. + * If `date1` is later than `date2`, then the result is positive. + * If `date1` and `date2` are on the same day of month, or both are the last day of month, + * time of day will be ignored. + * + * Otherwise, the difference is calculated based on 31 days per month, and rounded to + * 8 digits. * @group datetime_funcs * @since 1.5.0 */ From f27a035daf705766d3445e5c6a99867c11c552b0 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Fri, 11 May 2018 17:00:51 -0700 Subject: [PATCH 0785/2461] [SPARKR] Require Java 8 for SparkR This change updates the SystemRequirements and also includes a runtime check if the JVM is being launched by R. The runtime check is done by querying `java -version` ## How was this patch tested? Tested on a Mac and Windows machine Author: Shivaram Venkataraman Closes #21278 from shivaram/sparkr-skip-solaris. --- R/pkg/DESCRIPTION | 1 + R/pkg/R/client.R | 35 +++++++++++++++++++++++++++++++++++ R/pkg/R/sparkR.R | 1 + R/pkg/R/utils.R | 4 ++-- 4 files changed, 39 insertions(+), 2 deletions(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 855eb5bf77f16..f52d785e05cdd 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -13,6 +13,7 @@ Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), License: Apache License (== 2.0) URL: http://www.apache.org/ http://spark.apache.org/ BugReports: http://spark.apache.org/contributing.html +SystemRequirements: Java (== 8) Depends: R (>= 3.0), methods diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 7244cc9f9e38e..e9295e05872bd 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -60,6 +60,40 @@ generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, pack combinedArgs } +checkJavaVersion <- function() { + javaBin <- "java" + javaHome <- Sys.getenv("JAVA_HOME") + javaReqs <- utils::packageDescription(utils::packageName(), fields=c("SystemRequirements")) + sparkJavaVersion <- as.numeric(tail(strsplit(javaReqs, "[(=)]")[[1]], n = 1L)) + if (javaHome != "") { + javaBin <- file.path(javaHome, "bin", javaBin) + } + + # If java is missing from PATH, we get an error in Unix and a warning in Windows + javaVersionOut <- tryCatch( + launchScript(javaBin, "-version", wait = TRUE, stdout = TRUE, stderr = TRUE), + error = function(e) { + stop("Java version check failed. Please make sure Java is installed", + " and set JAVA_HOME to point to the installation directory.", e) + }, + warning = function(w) { + stop("Java version check failed. Please make sure Java is installed", + " and set JAVA_HOME to point to the installation directory.", w) + }) + javaVersionFilter <- Filter( + function(x) { + grepl("java version", x) + }, javaVersionOut) + + javaVersionStr <- strsplit(javaVersionFilter[[1]], "[\"]")[[1L]][2] + # javaVersionStr is of the form 1.8.0_92. + # Extract 8 from it to compare to sparkJavaVersion + javaVersionNum <- as.integer(strsplit(javaVersionStr, "[.]")[[1L]][2]) + if (javaVersionNum != sparkJavaVersion) { + stop(paste("Java version", sparkJavaVersion, "is required for this package; found version:", javaVersionStr)) + } +} + launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { sparkSubmitBinName <- determineSparkSubmitBin() if (sparkHome != "") { @@ -67,6 +101,7 @@ launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { } else { sparkSubmitBin <- sparkSubmitBinName } + combinedArgs <- generateSparkSubmitArgs(args, sparkHome, jars, sparkSubmitOpts, packages) cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n") invisible(launchScript(sparkSubmitBin, combinedArgs)) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 38ee79477996f..d6a2d08f9c218 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -167,6 +167,7 @@ sparkR.sparkContext <- function( submitOps <- getClientModeSparkSubmitOpts( Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"), sparkEnvirMap) + checkJavaVersion() launchBackend( args = path, sparkHome = sparkHome, diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index f1b5ecaa017df..c3501977e64bc 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -746,7 +746,7 @@ varargsToJProperties <- function(...) { props } -launchScript <- function(script, combinedArgs, wait = FALSE) { +launchScript <- function(script, combinedArgs, wait = FALSE, stdout = "", stderr = "") { if (.Platform$OS.type == "windows") { scriptWithArgs <- paste(script, combinedArgs, sep = " ") # on Windows, intern = F seems to mean output to the console. (documentation on this is missing) @@ -756,7 +756,7 @@ launchScript <- function(script, combinedArgs, wait = FALSE) { # stdout = F means discard output # stdout = "" means to its console (default) # Note that the console of this child process might not be the same as the running R process. - system2(script, combinedArgs, stdout = "", wait = wait) + system2(script, combinedArgs, stdout = stdout, wait = wait, stderr = stderr) } } From e3dabdf6ef210fb9f4337e305feb9c4983a57350 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 12 May 2018 12:15:36 +0800 Subject: [PATCH 0786/2461] [SPARK-23907] Removes regr_* functions in functions.scala ## What changes were proposed in this pull request? This patch removes the various regr_* functions in functions.scala. They are so uncommon that I don't think they deserve real estate in functions.scala. We can consider adding them later if more users need them. ## How was this patch tested? Removed the associated test case as well. Author: Reynold Xin Closes #21309 from rxin/SPARK-23907. --- .../org/apache/spark/sql/functions.scala | 171 ------------------ .../spark/sql/DataFrameAggregateSuite.scala | 68 ------- 2 files changed, 239 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e7f866ddca681..3c9ace407a58e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -811,177 +811,6 @@ object functions { */ def var_pop(columnName: String): Column = var_pop(Column(columnName)) - /** - * Aggregate function: returns the number of non-null pairs. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_count(y: Column, x: Column): Column = withAggregateFunction { - RegrCount(y.expr, x.expr) - } - - /** - * Aggregate function: returns the number of non-null pairs. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_count(y: String, x: String): Column = regr_count(Column(y), Column(x)) - - /** - * Aggregate function: returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_sxx(y: Column, x: Column): Column = withAggregateFunction { - RegrSXX(y.expr, x.expr) - } - - /** - * Aggregate function: returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_sxx(y: String, x: String): Column = regr_sxx(Column(y), Column(x)) - - /** - * Aggregate function: returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_syy(y: Column, x: Column): Column = withAggregateFunction { - RegrSYY(y.expr, x.expr) - } - - /** - * Aggregate function: returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_syy(y: String, x: String): Column = regr_syy(Column(y), Column(x)) - - /** - * Aggregate function: returns the average of y. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_avgy(y: Column, x: Column): Column = withAggregateFunction { - RegrAvgY(y.expr, x.expr) - } - - /** - * Aggregate function: returns the average of y. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_avgy(y: String, x: String): Column = regr_avgy(Column(y), Column(x)) - - /** - * Aggregate function: returns the average of x. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_avgx(y: Column, x: Column): Column = withAggregateFunction { - RegrAvgX(y.expr, x.expr) - } - - /** - * Aggregate function: returns the average of x. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_avgx(y: String, x: String): Column = regr_avgx(Column(y), Column(x)) - - /** - * Aggregate function: returns the covariance of y and x multiplied for the number of items in - * the dataset. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_sxy(y: Column, x: Column): Column = withAggregateFunction { - RegrSXY(y.expr, x.expr) - } - - /** - * Aggregate function: returns the covariance of y and x multiplied for the number of items in - * the dataset. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_sxy(y: String, x: String): Column = regr_sxy(Column(y), Column(x)) - - /** - * Aggregate function: returns the slope of the linear regression line. Any pair with a NULL is - * ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_slope(y: Column, x: Column): Column = withAggregateFunction { - RegrSlope(y.expr, x.expr) - } - - /** - * Aggregate function: returns the slope of the linear regression line. Any pair with a NULL is - * ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_slope(y: String, x: String): Column = regr_slope(Column(y), Column(x)) - - /** - * Aggregate function: returns the coefficient of determination (also called R-squared or - * goodness of fit) for the regression line. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_r2(y: Column, x: Column): Column = withAggregateFunction { - RegrR2(y.expr, x.expr) - } - - /** - * Aggregate function: returns the coefficient of determination (also called R-squared or - * goodness of fit) for the regression line. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_r2(y: String, x: String): Column = regr_r2(Column(y), Column(x)) - - /** - * Aggregate function: returns the y-intercept of the linear regression line. Any pair with a - * NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_intercept(y: Column, x: Column): Column = withAggregateFunction { - RegrIntercept(y.expr, x.expr) - } - - /** - * Aggregate function: returns the y-intercept of the linear regression line. Any pair with a - * NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_intercept(y: String, x: String): Column = regr_intercept(Column(y), Column(x)) - - ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 4337fb2290fbc..96c28961e5aaf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -687,72 +687,4 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } } } - - test("SPARK-23907: regression functions") { - val emptyTableData = Seq.empty[(Double, Double)].toDF("a", "b") - val correlatedData = Seq[(Double, Double)]((2, 3), (3, 4), (7.5, 8.2), (10.3, 12)) - .toDF("a", "b") - val correlatedDataWithNull = Seq[(java.lang.Double, java.lang.Double)]( - (2.0, 3.0), (3.0, null), (7.5, 8.2), (10.3, 12.0)).toDF("a", "b") - checkAnswer(testData2.groupBy().agg(regr_count("a", "b")), Seq(Row(6))) - checkAnswer(testData3.groupBy().agg(regr_count("a", "b")), Seq(Row(1))) - checkAnswer(emptyTableData.groupBy().agg(regr_count("a", "b")), Seq(Row(0))) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_sxx("a", "b")), Row(1.5), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_sxx("a", "b")), Row(0.0), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_sxx("a", "b")), Row(null), absTol) - checkAggregatesWithTol(testData2.groupBy().agg(regr_syy("b", "a")), Row(1.5), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_syy("b", "a")), Row(0.0), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_syy("b", "a")), Row(null), absTol) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_avgx("a", "b")), Row(1.5), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_avgx("a", "b")), Row(2.0), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_avgx("a", "b")), Row(null), absTol) - checkAggregatesWithTol(testData2.groupBy().agg(regr_avgy("b", "a")), Row(1.5), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_avgy("b", "a")), Row(2.0), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_avgy("b", "a")), Row(null), absTol) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_sxy("a", "b")), Row(0.0), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_sxy("a", "b")), Row(0.0), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_sxy("a", "b")), Row(null), absTol) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_slope("a", "b")), Row(0.0), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_slope("a", "b")), Row(null), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_slope("a", "b")), Row(null), absTol) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_r2("a", "b")), Row(0.0), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_r2("a", "b")), Row(null), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_r2("a", "b")), Row(null), absTol) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_intercept("a", "b")), Row(2.0), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_intercept("a", "b")), Row(null), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_intercept("a", "b")), - Row(null), absTol) - - - checkAggregatesWithTol(correlatedData.groupBy().agg( - regr_count("a", "b"), - regr_avgx("a", "b"), - regr_avgy("a", "b"), - regr_sxx("a", "b"), - regr_syy("a", "b"), - regr_sxy("a", "b"), - regr_slope("a", "b"), - regr_r2("a", "b"), - regr_intercept("a", "b")), - Row(4, 6.8, 5.7, 51.28, 45.38, 48.06, 0.937207488, 0.992556013, -0.67301092), - absTol) - checkAggregatesWithTol(correlatedDataWithNull.groupBy().agg( - regr_count("a", "b"), - regr_avgx("a", "b"), - regr_avgy("a", "b"), - regr_sxx("a", "b"), - regr_syy("a", "b"), - regr_sxy("a", "b"), - regr_slope("a", "b"), - regr_r2("a", "b"), - regr_intercept("a", "b")), - Row(3, 7.73333333, 6.6, 40.82666666, 35.66, 37.98, 0.93027433, 0.99079694, -0.59412149), - absTol) - } } From 5902125ac7ad25a0cb7aa3d98825c8290ee33c12 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Sat, 12 May 2018 19:21:42 +0800 Subject: [PATCH 0787/2461] [SPARK-24198][SPARKR][SQL] Adding slice function to SparkR ## What changes were proposed in this pull request? The PR adds the `slice` function to SparkR. The function returns a subset of consecutive elements from the given array. ``` > df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) > tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) > head(select(tmp, slice(tmp$v1, 2L, 2L))) ``` ``` slice(v1, 2, 2) 1 6, 110 2 6, 110 3 4, 93 4 6, 110 5 8, 175 6 6, 105 ``` ## How was this patch tested? A test added into R/pkg/tests/fulltests/test_sparkSQL.R Author: Marek Novotny Closes #21298 from mn-mikke/SPARK-24198. --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 17 +++++++++++++++++ R/pkg/R/generics.R | 4 ++++ R/pkg/tests/fulltests/test_sparkSQL.R | 5 +++++ 4 files changed, 27 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 5f8209689a559..c575fe255f57a 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -352,6 +352,7 @@ exportMethods("%<=>%", "sinh", "size", "skewness", + "slice", "sort_array", "soundex", "spark_partition_id", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 4964594284aa0..77d70cb5d19e6 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -212,6 +212,7 @@ NULL #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) +#' head(select(tmp, slice(tmp$v1, 2L, 2L))) #' head(select(tmp, sort_array(tmp$v1))) #' head(select(tmp, sort_array(tmp$v1, asc = FALSE))) #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl)) @@ -3142,6 +3143,22 @@ setMethod("size", column(jc) }) +#' @details +#' \code{slice}: Returns an array containing all the elements in x from the index start +#' (or starting from the end if start is negative) with the specified length. +#' +#' @rdname column_collection_functions +#' @param start an index indicating the first element occuring in the result. +#' @param length a number of consecutive elements choosen to the result. +#' @aliases slice slice,Column-method +#' @note slice since 2.4.0 +setMethod("slice", + signature(x = "Column"), + function(x, start, length) { + jc <- callJStatic("org.apache.spark.sql.functions", "slice", x@jc, start, length) + column(jc) + }) + #' @details #' \code{sort_array}: Sorts the input array in ascending or descending order according to #' the natural ordering of the array elements. NA elements will be placed at the beginning of diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 5faa51eef3abd..fbc4113e2becc 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1194,6 +1194,10 @@ setGeneric("size", function(x) { standardGeneric("size") }) #' @name NULL setGeneric("skewness", function(x) { standardGeneric("skewness") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("slice", function(x, start, length) { standardGeneric("slice") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index b8bfded0ebf2d..2a550b9efb506 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1507,6 +1507,11 @@ test_that("column functions", { result <- collect(select(df, sort_array(df[[1]])))[[1]] expect_equal(result, list(list(NA, 1L, 2L, 3L), list(NA, NA, 4L, 5L, 6L))) + # Test slice() + df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(4L, 5L)))) + result <- collect(select(df, slice(df[[1]], 2L, 2L)))[[1]] + expect_equal(result, list(list(2L, 3L), list(5L))) + # Test flattern df <- createDataFrame(list(list(list(list(1L, 2L), list(3L, 4L))), list(list(list(5L, 6L), list(7L, 8L))))) From 348ddfd20f5b88777014f18a6374f33ee9b12731 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 12 May 2018 08:35:14 -0500 Subject: [PATCH 0788/2461] [BUILD] Close stale PRs Closes https://github.com/apache/spark/pull/20458 Closes https://github.com/apache/spark/pull/20530 Closes https://github.com/apache/spark/pull/20557 Closes https://github.com/apache/spark/pull/20966 Closes https://github.com/apache/spark/pull/20857 Closes https://github.com/apache/spark/pull/19694 Closes https://github.com/apache/spark/pull/18227 Closes https://github.com/apache/spark/pull/20683 Closes https://github.com/apache/spark/pull/20881 Closes https://github.com/apache/spark/pull/20347 Closes https://github.com/apache/spark/pull/20825 Closes https://github.com/apache/spark/pull/20078 Closes https://github.com/apache/spark/pull/21281 Closes https://github.com/apache/spark/pull/19951 Closes https://github.com/apache/spark/pull/20905 Closes https://github.com/apache/spark/pull/20635 Author: Sean Owen Closes #21303 from srowen/ClosePRs. From 32acfa78c60465efc03ae01e022614ad91345b1c Mon Sep 17 00:00:00 2001 From: Cody Allen Date: Sat, 12 May 2018 14:35:40 -0500 Subject: [PATCH 0789/2461] Improve implicitNotFound message for Encoder The `implicitNotFound` message for `Encoder` doesn't mention the name of the type for which it can't find an encoder. Furthermore, it covers up the fact that `Encoder` is the name of the relevant type class. Hopefully this new message provides a little more specific type detail while still giving the general message about which types are supported. ## What changes were proposed in this pull request? Augment the existing message to mention that it's looking for an `Encoder` and what the type of the encoder is. For example instead of: ``` Unable to find encoder for type stored in a Dataset. Primitive types (Int, String, etc) and Product types (case classes) are supported by importing spark.implicits._ Support for serializing other types will be added in future releases. ``` return this message: ``` Unable to find encoder for type Exception. An implicit Encoder[Exception] is needed to store Exception instances in a Dataset. Primitive types (Int, String, etc) and Product types (ca se classes) are supported by importing spark.implicits._ Support for serializing other types will be added in future releases. ``` ## How was this patch tested? It was tested manually in the Scala REPL, since triggering this in a test would cause a compilation error. ``` scala> implicitly[Encoder[Exception]] :51: error: Unable to find encoder for type Exception. An implicit Encoder[Exception] is needed to store Exception instances in a Dataset. Primitive types (Int, String, etc) and Product types (ca se classes) are supported by importing spark.implicits._ Support for serializing other types will be added in future releases. implicitly[Encoder[Exception]] ^ ``` Author: Cody Allen Closes #20869 from ceedubs/encoder-implicit-msg. --- .../src/main/scala/org/apache/spark/sql/Encoder.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index ccdb6bc5d4b7c..7b02317b8538f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -68,10 +68,10 @@ import org.apache.spark.sql.types._ */ @Experimental @InterfaceStability.Evolving -@implicitNotFound("Unable to find encoder for type stored in a Dataset. Primitive types " + - "(Int, String, etc) and Product types (case classes) are supported by importing " + - "spark.implicits._ Support for serializing other types will be added in future " + - "releases.") +@implicitNotFound("Unable to find encoder for type ${T}. An implicit Encoder[${T}] is needed to " + + "store ${T} instances in a Dataset. Primitive types (Int, String, etc) and Product types (case " + + "classes) are supported by importing spark.implicits._ Support for serializing other types " + + "will be added in future releases.") trait Encoder[T] extends Serializable { /** Returns the schema of encoding this type of object as a Row. */ From 0d210ec8b610e4b0570ce730f3987dc86787c663 Mon Sep 17 00:00:00 2001 From: Kelley Robinson Date: Sun, 13 May 2018 13:19:03 -0700 Subject: [PATCH 0790/2461] [SPARK-24262][PYTHON] Fix typo in UDF type match error message ## What changes were proposed in this pull request? Updates `functon` to `function`. This was called out in holdenk's PyCon 2018 conference talk. Didn't see any existing PR's for this. holdenk happy to fix the Pandas.Series bug too but will need a bit more guidance. Author: Kelley Robinson Closes #21304 from robinske/master. --- python/pyspark/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 8bb63fcc7ff9c..5d2e58bef6466 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -82,7 +82,7 @@ def wrap_scalar_pandas_udf(f, return_type): def verify_result_length(*a): result = f(*a) if not hasattr(result, "__len__"): - raise TypeError("Return type of the user-defined functon should be " + raise TypeError("Return type of the user-defined function should be " "Pandas.Series, but is {}".format(type(result))) if len(result) != len(a[0]): raise RuntimeError("Result vector from pandas_udf was not the required length: " From 2fa33649d96394ae630a092a9f7e1261d1893f6e Mon Sep 17 00:00:00 2001 From: Fan Donglai Date: Sun, 13 May 2018 18:10:00 -0500 Subject: [PATCH 0791/2461] Update StreamingKMeans.scala MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? I think the ‘n_t+t’ in the following code may be wrong, it shoud be ‘n_t+1’ that means is the number of points to the cluster after it finish the no.t+1 min-batch. *
    * $$ * \begin{align} * c_t+1 &= [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t] \\ * n_t+t &= n_t * a + m_t * \end{align} * $$ *
    Author: Fan Donglai Closes #21179 from ddna1021/master. --- .../org/apache/spark/mllib/clustering/StreamingKMeans.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 3ca75e8cdb97a..7a5e520d5818e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -43,7 +43,7 @@ import org.apache.spark.util.random.XORShiftRandom * $$ * \begin{align} * c_t+1 &= [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t] \\ - * n_t+t &= n_t * a + m_t + * n_t+1 &= n_t * a + m_t * \end{align} * $$ * From 3f0e801c11e600ed28491924e550d3ba93f19c19 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Mon, 14 May 2018 09:48:54 +0800 Subject: [PATCH 0792/2461] [SPARK-24186][R][SQL] change reverse and concat to collection functions in R ## What changes were proposed in this pull request? reverse and concat are already in functions.R as column string functions. Since now these two functions are categorized as collection functions in scala and python, we will do the same in R. ## How was this patch tested? Add test in test_sparkSQL.R Author: Huaxin Gao Closes #21307 from huaxingao/spark_24186. --- R/pkg/R/functions.R | 35 ++++++++++++++------------- R/pkg/R/generics.R | 4 +-- R/pkg/tests/fulltests/test_sparkSQL.R | 17 +++++++++++-- 3 files changed, 35 insertions(+), 21 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 77d70cb5d19e6..fcb3521f901ea 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -208,7 +208,7 @@ NULL #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) #' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) #' head(select(tmp, array_position(tmp$v1, 21), array_sort(tmp$v1))) -#' head(select(tmp, flatten(tmp$v1))) +#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) @@ -218,7 +218,10 @@ NULL #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl)) #' head(select(tmp3, map_keys(tmp3$v3))) #' head(select(tmp3, map_values(tmp3$v3))) -#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))} +#' head(select(tmp3, element_at(tmp3$v3, "Valiant"))) +#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$hp)) +#' head(select(tmp4, concat(tmp4$v4, tmp4$v5))) +#' head(select(tmp, concat(df$mpg, df$cyl, df$hp)))} NULL #' Window functions for Column operations @@ -1260,9 +1263,9 @@ setMethod("quarter", }) #' @details -#' \code{reverse}: Reverses the string column and returns it as a new string column. +#' \code{reverse}: Returns a reversed string or an array with reverse order of elements. #' -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @aliases reverse reverse,Column-method #' @note reverse since 1.5.0 setMethod("reverse", @@ -2055,20 +2058,10 @@ setMethod("countDistinct", #' @details #' \code{concat}: Concatenates multiple input columns together into a single column. -#' If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. +#' The function works with strings, binary and compatible array columns. #' -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @aliases concat concat,Column-method -#' @examples -#' -#' \dontrun{ -#' # concatenate strings -#' tmp <- mutate(df, s1 = concat(df$Class, df$Sex), -#' s2 = concat(df$Class, df$Sex, df$Age), -#' s3 = concat(df$Class, df$Sex, df$Age, df$Class), -#' s4 = concat_ws("_", df$Class, df$Sex), -#' s5 = concat_ws("+", df$Class, df$Sex, df$Age, df$Survived)) -#' head(tmp)} #' @note concat since 1.5.0 setMethod("concat", signature(x = "Column"), @@ -2409,6 +2402,13 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' @param sep separator to use. #' @rdname column_string_functions #' @aliases concat_ws concat_ws,character,Column-method +#' @examples +#' +#' \dontrun{ +#' # concatenate strings +#' tmp <- mutate(df, s1 = concat_ws("_", df$Class, df$Sex), +#' s2 = concat_ws("+", df$Class, df$Sex, df$Age, df$Survived)) +#' head(tmp)} #' @note concat_ws since 1.5.0 setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { @@ -3063,7 +3063,8 @@ setMethod("array_sort", }) #' @details -#' \code{flatten}: Transforms an array of arrays into a single array. +#' \code{flatten}: Creates a single array from an array of arrays. +#' If a structure of nested arrays is deeper than two levels, only one level of nesting is removed. #' #' @rdname column_collection_functions #' @aliases flatten flatten,Column-method diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index fbc4113e2becc..61da30badac4e 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -817,7 +817,7 @@ setGeneric("collect_set", function(x) { standardGeneric("collect_set") }) #' @rdname column setGeneric("column", function(x) { standardGeneric("column") }) -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @name NULL setGeneric("concat", function(x, ...) { standardGeneric("concat") }) @@ -1134,7 +1134,7 @@ setGeneric("regexp_replace", #' @name NULL setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") }) -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @name NULL setGeneric("reverse", function(x) { standardGeneric("reverse") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 2a550b9efb506..13b55ac6e6e3c 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1479,7 +1479,7 @@ test_that("column functions", { df5 <- createDataFrame(list(list(a = "010101"))) expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") - # Test array_contains(), array_max(), array_min(), array_position() and element_at() + # Test array_contains(), array_max(), array_min(), array_position(), element_at() and reverse() df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] expect_equal(result, c(TRUE, FALSE)) @@ -1496,6 +1496,13 @@ test_that("column functions", { result <- collect(select(df, element_at(df[[1]], 1L)))[[1]] expect_equal(result, c(1, 6)) + result <- collect(select(df, reverse(df[[1]])))[[1]] + expect_equal(result, list(list(3L, 2L, 1L), list(4L, 5L, 6L))) + + df2 <- createDataFrame(list(list("abc"))) + result <- collect(select(df2, reverse(df2[[1]])))[[1]] + expect_equal(result, "cba") + # Test array_sort() and sort_array() df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L)))) @@ -1512,7 +1519,13 @@ test_that("column functions", { result <- collect(select(df, slice(df[[1]], 2L, 2L)))[[1]] expect_equal(result, list(list(2L, 3L), list(5L))) - # Test flattern + # Test concat() + df <- createDataFrame(list(list(list(1L, 2L, 3L), list(4L, 5L, 6L)), + list(list(7L, 8L, 9L), list(10L, 11L, 12L)))) + result <- collect(select(df, concat(df[[1]], df[[2]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L, 4L, 5L, 6L), list(7L, 8L, 9L, 10L, 11L, 12L))) + + # Test flatten() df <- createDataFrame(list(list(list(list(1L, 2L), list(3L, 4L))), list(list(list(5L, 6L), list(7L, 8L))))) result <- collect(select(df, flatten(df[[1]])))[[1]] From 7a2d4895c75d4c232c377876b61c05a083eab3c8 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 14 May 2018 10:01:06 +0800 Subject: [PATCH 0793/2461] [SPARK-17916][SQL] Fix empty string being parsed as null when nullValue is set. ## What changes were proposed in this pull request? I propose to bump version of uniVocity parser up to 2.6.3 where quoted empty strings are replaced by the empty value (passed to `setEmptyValue`) instead of `null` values as in the current version 2.5.9: https://github.com/uniVocity/univocity-parsers/blob/v2.6.3/src/main/java/com/univocity/parsers/csv/CsvParser.java#L125 Empty value for writer is set to `""`. So, empty string in dataframe/dataset is stored as empty quoted string `""`. Empty value for reader is set to empty string (zero size). In this way, saved empty quoted string will be read as just empty string. Please, look at the tests for more details. Here are main changes made in [2.6.0](https://github.com/uniVocity/univocity-parsers/releases/tag/v2.6.0), [2.6.1](https://github.com/uniVocity/univocity-parsers/releases/tag/v2.6.1), [2.6.2](https://github.com/uniVocity/univocity-parsers/releases/tag/v2.6.2), [2.6.3](https://github.com/uniVocity/univocity-parsers/releases/tag/v2.6.3): - CSV parser now parses quoted values ~30% faster - CSV format detection process has option provide a list of possible delimiters, in order of priority ( i.e. settings.detectFormatAutomatically( '-', '.');) - https://github.com/uniVocity/univocity-parsers/issues/214 - Implemented trim quoted values support - https://github.com/uniVocity/univocity-parsers/issues/230 - NullPointer when stopping parser when nothing is parsed - https://github.com/uniVocity/univocity-parsers/issues/219 - Concurrency issue when calling stopParsing() - https://github.com/uniVocity/univocity-parsers/issues/231 Closes #20068 ## How was this patch tested? Added tests from the PR https://github.com/apache/spark/pull/20068 Author: Maxim Gekk Closes #21273 from MaxGekk/univocity-2.6. --- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- dev/deps/spark-deps-hadoop-3.1 | 2 +- sql/core/pom.xml | 2 +- .../datasources/csv/CSVOptions.scala | 3 +- .../datasources/csv/CSVBenchmarks.scala | 80 +++++++++++++++++++ .../execution/datasources/csv/CSVSuite.scala | 46 +++++++++++ 7 files changed, 132 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index f552b81fde9f4..e710e26348117 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -190,7 +190,7 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.5.9.jar +univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 024b1fca717df..97ad17a9ff7b1 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -191,7 +191,7 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.5.9.jar +univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 938de7bc06663..e21bfef8c4291 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -211,7 +211,7 @@ stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar token-provider-1.0.1.jar -univocity-parsers-2.5.9.jar +univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar woodstox-core-5.0.3.jar xbean-asm5-shaded-4.4.jar diff --git a/sql/core/pom.xml b/sql/core/pom.xml index ef41837f89d68..f270c70fbfcf0 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -38,7 +38,7 @@ com.univocity univocity-parsers - 2.5.9 + 2.6.3 jar diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index ed2dc65a47914..1066d156acd74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -164,7 +164,7 @@ class CSVOptions( writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite) writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite) writerSettings.setNullValue(nullValue) - writerSettings.setEmptyValue(nullValue) + writerSettings.setEmptyValue("\"\"") writerSettings.setSkipEmptyLines(true) writerSettings.setQuoteAllFields(quoteAll) writerSettings.setQuoteEscapingEnabled(escapeQuotes) @@ -185,6 +185,7 @@ class CSVOptions( settings.setInputBufferSize(inputBufferSize) settings.setMaxColumns(maxColumns) settings.setNullValue(nullValue) + settings.setEmptyValue("") settings.setMaxCharsPerColumn(maxCharsPerColumn) settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER) settings diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala new file mode 100644 index 0000000000000..d442ba7e59c61 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.csv + +import java.io.File + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{Column, Row, SparkSession} +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.types._ +import org.apache.spark.util.{Benchmark, Utils} + +/** + * Benchmark to measure CSV read/write performance. + * To run this: + * spark-submit --class --jars + */ +object CSVBenchmarks { + val conf = new SparkConf() + + val spark = SparkSession.builder + .master("local[1]") + .appName("benchmark-csv-datasource") + .config(conf) + .getOrCreate() + import spark.implicits._ + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def quotedValuesBenchmark(rowsNum: Int, numIters: Int): Unit = { + val benchmark = new Benchmark(s"Parsing quoted values", rowsNum) + + withTempPath { path => + val str = (0 until 10000).map(i => s""""$i"""").mkString(",") + + spark.range(rowsNum) + .map(_ => str) + .write.option("header", true) + .csv(path.getAbsolutePath) + + val schema = new StructType().add("value", StringType) + val ds = spark.read.option("header", true).schema(schema).csv(path.getAbsolutePath) + + benchmark.addCase(s"One quoted string", numIters) { _ => + ds.filter((_: Row) => true).count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + Parsing quoted values: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + One quoted string 30273 / 30549 0.0 605451.2 1.0X + */ + benchmark.run() + } + } + + def main(args: Array[String]): Unit = { + quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 461abdd96d3f3..07e6c74b14d0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1322,4 +1322,50 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te val sampled = spark.read.option("inferSchema", true).option("samplingRatio", 1.0).csv(ds) assert(sampled.count() == ds.count()) } + + test("SPARK-17916: An empty string should not be coerced to null when nullValue is passed.") { + val litNull: String = null + val df = Seq( + (1, "John Doe"), + (2, ""), + (3, "-"), + (4, litNull) + ).toDF("id", "name") + + // Checks for new behavior where an empty string is not coerced to null when `nullValue` is + // set to anything but an empty string literal. + withTempPath { path => + df.write + .option("nullValue", "-") + .csv(path.getAbsolutePath) + val computed = spark.read + .option("nullValue", "-") + .schema(df.schema) + .csv(path.getAbsolutePath) + val expected = Seq( + (1, "John Doe"), + (2, ""), + (3, litNull), + (4, litNull) + ).toDF("id", "name") + + checkAnswer(computed, expected) + } + // Keeps the old behavior where empty string us coerced to nullValue is not passed. + withTempPath { path => + df.write + .csv(path.getAbsolutePath) + val computed = spark.read + .schema(df.schema) + .csv(path.getAbsolutePath) + val expected = Seq( + (1, "John Doe"), + (2, litNull), + (3, "-"), + (4, litNull) + ).toDF("id", "name") + + checkAnswer(computed, expected) + } + } } From b6c50d7820aafab172835633fb0b35899e93146b Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 14 May 2018 10:57:10 +0800 Subject: [PATCH 0794/2461] [SPARK-24228][SQL] Fix Java lint errors ## What changes were proposed in this pull request? This PR fixes the following Java lint errors due to importing unimport classes ``` $ dev/lint-java Using `mvn` from path: /usr/bin/mvn Checkstyle checks failed at following occurrences: [ERROR] src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java:[25] (sizes) LineLength: Line is longer than 100 characters (found 109). [ERROR] src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java:[38] (sizes) LineLength: Line is longer than 100 characters (found 102). [ERROR] src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java:[21,8] (imports) UnusedImports: Unused import - java.io.ByteArrayInputStream. [ERROR] src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java:[29,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform. [ERROR] src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java:[110] (sizes) LineLength: Line is longer than 100 characters (found 101). ``` With this PR ``` $ dev/lint-java Using `mvn` from path: /usr/bin/mvn Checkstyle checks passed. ``` ## How was this patch tested? Existing UTs. Also manually run checkstyles against these two files. Author: Kazuaki Ishizaki Closes #21301 from kiszk/SPARK-24228. --- .../datasources/parquet/SpecificParquetRecordReaderBase.java | 1 - .../datasources/parquet/VectorizedPlainValuesReader.java | 1 - .../sql/sources/v2/reader/partitioning/Distribution.java | 3 ++- .../sql/sources/v2/reader/streaming/ContinuousReader.java | 4 ++-- .../apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java | 3 ++- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 10d6ed85a4080..daedfd7e78f5f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.parquet; -import java.io.ByteArrayInputStream; import java.io.File; import java.io.IOException; import java.lang.reflect.InvocationTargetException; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index aacefacfc1c1a..c62dc3d86386e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -26,7 +26,6 @@ import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.io.api.Binary; -import org.apache.spark.unsafe.Platform; /** * An implementation of the Parquet PLAIN decoder that supports the vectorized interface. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java index d2ee9518d628f..5e32ba6952e1c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java @@ -22,7 +22,8 @@ /** * An interface to represent data distribution requirement, which specifies how the records should - * be distributed among the data partitions(one {@link InputPartitionReader} outputs data for one partition). + * be distributed among the data partitions (one {@link InputPartitionReader} outputs data for one + * partition). * Note that this interface has nothing to do with the data ordering inside one * partition(the output records of a single {@link InputPartitionReader}). * diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java index 716c5c0e9e15a..6e960bedf8020 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java @@ -35,8 +35,8 @@ @InterfaceStability.Evolving public interface ContinuousReader extends BaseStreamingSource, DataSourceReader { /** - * Merge partitioned offsets coming from {@link ContinuousInputPartitionReader} instances for each - * partition to a single global offset. + * Merge partitioned offsets coming from {@link ContinuousInputPartitionReader} instances + * for each partition to a single global offset. */ Offset mergeOffsets(PartitionOffset[] offsets); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 714638e500c94..445cb29f5ee3a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -107,7 +107,8 @@ public List> planInputPartitions() { } } - static class JavaAdvancedInputPartition implements InputPartition, InputPartitionReader { + static class JavaAdvancedInputPartition implements InputPartition, + InputPartitionReader { private int start; private int end; private StructType requiredSchema; From 1430fa80e37762e31cc5adc74cd609c215d84b6e Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 14 May 2018 10:49:12 -0700 Subject: [PATCH 0795/2461] [SPARK-24263][R] SparkR java check breaks with openjdk ## What changes were proposed in this pull request? Change text to grep for. ## How was this patch tested? manual test Author: Felix Cheung Closes #21314 from felixcheung/openjdkver. --- R/pkg/R/client.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index e9295e05872bd..14a17c600b17f 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -82,7 +82,7 @@ checkJavaVersion <- function() { }) javaVersionFilter <- Filter( function(x) { - grepl("java version", x) + grepl(" version", x) }, javaVersionOut) javaVersionStr <- strsplit(javaVersionFilter[[1]], "[\"]")[[1L]][2] From c26f673252c2cbbccf8c395ba6d4ab80c098d60e Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 14 May 2018 11:37:57 -0700 Subject: [PATCH 0796/2461] [SPARK-24246][SQL] Improve AnalysisException by setting the cause when it's available ## What changes were proposed in this pull request? If there is an exception, it's better to set it as the cause of AnalysisException since the exception may contain useful debug information. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #21297 from zsxwing/SPARK-24246. --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 6 +++--- .../spark/sql/catalyst/analysis/ResolveInlineTables.scala | 2 +- .../org/apache/spark/sql/catalyst/analysis/package.scala | 5 +++++ .../org/apache/spark/sql/execution/datasources/rules.scala | 2 +- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index dfdcdbc1eb2c7..3eaa9ecf5d075 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -676,13 +676,13 @@ class Analyzer( try { catalog.lookupRelation(tableIdentWithDb) } catch { - case _: NoSuchTableException => - u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}") + case e: NoSuchTableException => + u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}", e) // If the database is defined and that database is not found, throw an AnalysisException. // Note that if the database is not defined, it is possible we are looking up a temp view. case e: NoSuchDatabaseException => u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}, the " + - s"database ${e.db} doesn't exist.") + s"database ${e.db} doesn't exist.", e) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 4eb6e642b1c37..31ba9d792024b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -105,7 +105,7 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with Cas castedExpr.eval() } catch { case NonFatal(ex) => - table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}") + table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}", ex) } }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala index 7731336d247db..354a3fa0602a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -41,6 +41,11 @@ package object analysis { def failAnalysis(msg: String): Nothing = { throw new AnalysisException(msg, t.origin.line, t.origin.startPosition) } + + /** Fails the analysis at the point where a specific tree node was parsed. */ + def failAnalysis(msg: String, cause: Throwable): Nothing = { + throw new AnalysisException(msg, t.origin.line, t.origin.startPosition, cause = Some(cause)) + } } /** Catches any AnalysisExceptions thrown by `f` and attaches `t`'s position if any. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 0dea767840ed3..cab00251622b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -61,7 +61,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { case _: ClassNotFoundException => u case e: Exception => // the provider is valid, but failed to create a logical plan - u.failAnalysis(e.getMessage) + u.failAnalysis(e.getMessage, e) } } } From 075d678c8844614910b50abca07282bde31ef7e0 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Mon, 14 May 2018 13:35:54 -0700 Subject: [PATCH 0797/2461] [SPARK-24155][ML] Instrumentation improvements for clustering ## What changes were proposed in this pull request? changed the instrument for all of the clustering methods ## How was this patch tested? N/A Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21218 from ludatabricks/SPARK-23686-1. --- .../org/apache/spark/ml/clustering/BisectingKMeans.scala | 7 +++++-- .../org/apache/spark/ml/clustering/GaussianMixture.scala | 5 ++++- .../main/scala/org/apache/spark/ml/clustering/KMeans.scala | 4 +++- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 438e53ba6197c..1ad4e097246a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -261,8 +261,9 @@ class BisectingKMeans @Since("2.0.0") ( transformSchema(dataset.schema, logging = true) val rdd = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) - val instr = Instrumentation.create(this, rdd) - instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize) + val instr = Instrumentation.create(this, dataset) + instr.logParams(featuresCol, predictionCol, k, maxIter, seed, + minDivisibleClusterSize, distanceMeasure) val bkm = new MLlibBisectingKMeans() .setK($(k)) @@ -275,6 +276,8 @@ class BisectingKMeans @Since("2.0.0") ( val summary = new BisectingKMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(Some(summary)) + // TODO: need to extend logNamedValue to support Array + instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]")) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 88d618c3a03a8..3091bb5a2e54c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -352,7 +352,7 @@ class GaussianMixture @Since("2.0.0") ( s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" + s" matrix is quadratic in the number of features.") - val instr = Instrumentation.create(this, instances) + val instr = Instrumentation.create(this, dataset) instr.logParams(featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol) instr.logNumFeatures(numFeatures) @@ -425,6 +425,9 @@ class GaussianMixture @Since("2.0.0") ( val summary = new GaussianMixtureSummary(model.transform(dataset), $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood) model.setSummary(Some(summary)) + instr.logNamedValue("logLikelihood", logLikelihood) + // TODO: need to extend logNamedValue to support Array + instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]")) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 97f246fbfd859..e72d7f9485e6a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -342,7 +342,7 @@ class KMeans @Since("1.5.0") ( instances.persist(StorageLevel.MEMORY_AND_DISK) } - val instr = Instrumentation.create(this, instances) + val instr = Instrumentation.create(this, dataset) instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure, maxIter, seed, tol) val algo = new MLlibKMeans() @@ -359,6 +359,8 @@ class KMeans @Since("1.5.0") ( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(Some(summary)) + // TODO: need to extend logNamedValue to support Array + instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]")) instr.logSuccess(model) if (handlePersistence) { instances.unpersist() From 8cd83acf4075d369bfcf9e703760d4946ef15f00 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 14 May 2018 14:05:42 -0700 Subject: [PATCH 0798/2461] [SPARK-24027][SQL] Support MapType with StringType for keys as the root type by from_json ## What changes were proposed in this pull request? Currently, the from_json function support StructType or ArrayType as the root type. The PR allows to specify MapType(StringType, DataType) as the root type additionally to mentioned types. For example: ```scala import org.apache.spark.sql.types._ val schema = MapType(StringType, IntegerType) val in = Seq("""{"a": 1, "b": 2, "c": 3}""").toDS() in.select(from_json($"value", schema, Map[String, String]())).collect() ``` ``` res1: Array[org.apache.spark.sql.Row] = Array([Map(a -> 1, b -> 2, c -> 3)]) ``` ## How was this patch tested? It was checked by new tests for the map type with integer type and struct type as value types. Also roundtrip tests like from_json(to_json) and to_json(from_json) for MapType are added. Author: Maxim Gekk Author: Maxim Gekk Closes #21108 from MaxGekk/from_json-map-type. --- python/pyspark/sql/functions.py | 10 ++- .../expressions/jsonExpressions.scala | 10 ++- .../sql/catalyst/json/JacksonParser.scala | 18 ++++- .../org/apache/spark/sql/functions.scala | 29 ++++---- .../apache/spark/sql/JsonFunctionsSuite.scala | 66 +++++++++++++++++++ 5 files changed, 113 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b62748e9a2d6c..6866c1cf9f882 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2095,12 +2095,13 @@ def json_tuple(col, *fields): return Column(jc) +@ignore_unicode_prefix @since(2.1) def from_json(col, schema, options={}): """ - Parses a column containing a JSON string into a :class:`StructType` or :class:`ArrayType` - of :class:`StructType`\\s with the specified schema. Returns `null`, in the case of an - unparseable string. + Parses a column containing a JSON string into a :class:`MapType` with :class:`StringType` + as keys type, :class:`StructType` or :class:`ArrayType` of :class:`StructType`\\s with + the specified schema. Returns `null`, in the case of an unparseable string. :param col: string column in json format :param schema: a StructType or ArrayType of StructType to use when parsing the json column. @@ -2117,6 +2118,9 @@ def from_json(col, schema, options={}): [Row(json=Row(a=1))] >>> df.select(from_json(df.value, "a INT").alias("json")).collect() [Row(json=Row(a=1))] + >>> schema = MapType(StringType(), IntegerType()) + >>> df.select(from_json(df.value, schema).alias("json")).collect() + [Row(json={u'a': 1})] >>> data = [(1, '''[{"a": 1}]''')] >>> schema = ArrayType(StructType([StructField("a", IntegerType())])) >>> df = spark.createDataFrame(data, ("key", "value")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 34161f0f03f4a..04a4eb0ffc032 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -548,7 +548,7 @@ case class JsonToStructs( forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) override def checkInputDataTypes(): TypeCheckResult = nullableSchema match { - case _: StructType | ArrayType(_: StructType, _) => + case _: StructType | ArrayType(_: StructType, _) | _: MapType => super.checkInputDataTypes() case _ => TypeCheckResult.TypeCheckFailure( s"Input schema ${nullableSchema.simpleString} must be a struct or an array of structs.") @@ -558,6 +558,7 @@ case class JsonToStructs( lazy val rowSchema = nullableSchema match { case st: StructType => st case ArrayType(st: StructType, _) => st + case mt: MapType => mt } // This converts parsed rows to the desired output by the given schema. @@ -567,6 +568,8 @@ case class JsonToStructs( (rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null case ArrayType(_: StructType, _) => (rows: Seq[InternalRow]) => new GenericArrayData(rows) + case _: MapType => + (rows: Seq[InternalRow]) => rows.head.getMap(0) } @transient @@ -613,6 +616,11 @@ case class JsonToStructs( } override def inputTypes: Seq[AbstractDataType] = StringType :: Nil + + override def sql: String = schema match { + case _: MapType => "entries" + case _ => super.sql + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index a5a4a13eb608b..c3a4ca8f64bf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -36,7 +36,7 @@ import org.apache.spark.util.Utils * Constructs a parser for a given schema that translates a json string to an [[InternalRow]]. */ class JacksonParser( - schema: StructType, + schema: DataType, val options: JSONOptions) extends Logging { import JacksonUtils._ @@ -57,7 +57,14 @@ class JacksonParser( * to a value according to a desired schema. This is a wrapper for the method * `makeConverter()` to handle a row wrapped with an array. */ - private def makeRootConverter(st: StructType): JsonParser => Seq[InternalRow] = { + private def makeRootConverter(dt: DataType): JsonParser => Seq[InternalRow] = { + dt match { + case st: StructType => makeStructRootConverter(st) + case mt: MapType => makeMapRootConverter(mt) + } + } + + private def makeStructRootConverter(st: StructType): JsonParser => Seq[InternalRow] = { val elementConverter = makeConverter(st) val fieldConverters = st.map(_.dataType).map(makeConverter).toArray (parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, st) { @@ -87,6 +94,13 @@ class JacksonParser( } } + private def makeMapRootConverter(mt: MapType): JsonParser => Seq[InternalRow] = { + val fieldConverter = makeConverter(mt.valueType) + (parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, mt) { + case START_OBJECT => Seq(InternalRow(convertMap(parser, fieldConverter))) + } + } + /** * Create a converter which converts the JSON documents held by the `JsonParser` * to a value according to a desired schema. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3c9ace407a58e..b71dfdad8aa9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3231,9 +3231,9 @@ object functions { from_json(e, schema.asInstanceOf[DataType], options) /** - * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -3263,9 +3263,9 @@ object functions { from_json(e, schema, options.asScala.toMap) /** - * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -3292,8 +3292,9 @@ object functions { from_json(e, schema, Map.empty[String, String]) /** - * Parses a column containing a JSON string into a `StructType` or `ArrayType` of `StructType`s - * with the specified schema. Returns `null`, in the case of an unparseable string. + * Parses a column containing a JSON string into a `MapType` with `StringType` as keys type, + * `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -3305,9 +3306,9 @@ object functions { from_json(e, schema, Map.empty[String, String]) /** - * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string as a json string. In Spark 2.1, @@ -3322,9 +3323,9 @@ object functions { } /** - * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string as a json string, it could be a diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 00d2acc4a1d8a..055e1fc5640f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -326,4 +326,70 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { assert(errMsg4.getMessage.startsWith( "A type of keys and values in map() must be string, but got")) } + + test("SPARK-24027: from_json - map") { + val in = Seq("""{"a": 1, "b": 2, "c": 3}""").toDS() + val schema = + """ + |{ + | "type" : "map", + | "keyType" : "string", + | "valueType" : "integer", + | "valueContainsNull" : true + |} + """.stripMargin + val out = in.select(from_json($"value", schema, Map[String, String]())) + + assert(out.columns.head == "entries") + checkAnswer(out, Row(Map("a" -> 1, "b" -> 2, "c" -> 3))) + } + + test("SPARK-24027: from_json - map") { + val in = Seq("""{"a": {"b": 1}}""").toDS() + val schema = MapType(StringType, new StructType().add("b", IntegerType), true) + val out = in.select(from_json($"value", schema)) + + checkAnswer(out, Row(Map("a" -> Row(1)))) + } + + test("SPARK-24027: from_json - map>") { + val in = Seq("""{"a": {"b": 1}}""").toDS() + val schema = MapType(StringType, MapType(StringType, IntegerType)) + val out = in.select(from_json($"value", schema)) + + checkAnswer(out, Row(Map("a" -> Map("b" -> 1)))) + } + + test("SPARK-24027: roundtrip - from_json -> to_json - map") { + val json = """{"a":1,"b":2,"c":3}""" + val schema = MapType(StringType, IntegerType, true) + val out = Seq(json).toDS().select(to_json(from_json($"value", schema))) + + checkAnswer(out, Row(json)) + } + + test("SPARK-24027: roundtrip - to_json -> from_json - map") { + val in = Seq(Map("a" -> 1)).toDF() + val schema = MapType(StringType, IntegerType, true) + val out = in.select(from_json(to_json($"value"), schema)) + + checkAnswer(out, in) + } + + test("SPARK-24027: from_json - wrong map") { + val in = Seq("""{"a" 1}""").toDS() + val schema = MapType(StringType, IntegerType) + val out = in.select(from_json($"value", schema, Map[String, String]())) + + checkAnswer(out, Row(null)) + } + + test("SPARK-24027: from_json of a map with unsupported key type") { + val schema = MapType(StructType(StructField("f", IntegerType) :: Nil), StringType) + + checkAnswer(Seq("""{{"f": 1}: "a"}""").toDS().select(from_json($"value", schema)), + Row(null)) + checkAnswer(Seq("""{"{"f": 1}": "a"}""").toDS().select(from_json($"value", schema)), + Row(null)) + } } From 061e0084ce19c1384ba271a97a0aa1f87abe879d Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Mon, 14 May 2018 14:35:08 -0700 Subject: [PATCH 0799/2461] [SPARK-23852][SQL] Add withSQLConf(...) to test case ## What changes were proposed in this pull request? Add a `withSQLConf(...)` wrapper to force Parquet filter pushdown for a test that relies on it. ## How was this patch tested? Test passes Author: Henry Robinson Closes #21323 from henryr/spark-23582. --- .../datasources/parquet/ParquetFilterSuite.scala | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 4d0ecdef60986..90da7eb8c4fb5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -650,13 +650,15 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } test("SPARK-23852: Broken Parquet push-down for partially-written stats") { - // parquet-1217.parquet contains a single column with values -1, 0, 1, 2 and null. - // The row-group statistics include null counts, but not min and max values, which - // triggers PARQUET-1217. - val df = readResourceParquetFile("test-data/parquet-1217.parquet") + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + // parquet-1217.parquet contains a single column with values -1, 0, 1, 2 and null. + // The row-group statistics include null counts, but not min and max values, which + // triggers PARQUET-1217. + val df = readResourceParquetFile("test-data/parquet-1217.parquet") - // Will return 0 rows if PARQUET-1217 is not fixed. - assert(df.where("col > 0").count() === 2) + // Will return 0 rows if PARQUET-1217 is not fixed. + assert(df.where("col > 0").count() === 2) + } } } From 9059f1ee6ae13c8636c9b7fdbb708a349256fb8e Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 14 May 2018 19:20:25 -0700 Subject: [PATCH 0800/2461] [SPARK-23780][R] Failed to use googleVis library with new SparkR ## What changes were proposed in this pull request? change generic to get it to work with googleVis also fix lintr ## How was this patch tested? manual test, unit tests Author: Felix Cheung Closes #21315 from felixcheung/googvis. --- R/pkg/R/client.R | 5 +++-- R/pkg/R/generics.R | 2 +- R/pkg/R/sparkR.R | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 14a17c600b17f..4c87f64e7f0e1 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -63,7 +63,7 @@ generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, pack checkJavaVersion <- function() { javaBin <- "java" javaHome <- Sys.getenv("JAVA_HOME") - javaReqs <- utils::packageDescription(utils::packageName(), fields=c("SystemRequirements")) + javaReqs <- utils::packageDescription(utils::packageName(), fields = c("SystemRequirements")) sparkJavaVersion <- as.numeric(tail(strsplit(javaReqs, "[(=)]")[[1]], n = 1L)) if (javaHome != "") { javaBin <- file.path(javaHome, "bin", javaBin) @@ -90,7 +90,8 @@ checkJavaVersion <- function() { # Extract 8 from it to compare to sparkJavaVersion javaVersionNum <- as.integer(strsplit(javaVersionStr, "[.]")[[1L]][2]) if (javaVersionNum != sparkJavaVersion) { - stop(paste("Java version", sparkJavaVersion, "is required for this package; found version:", javaVersionStr)) + stop(paste("Java version", sparkJavaVersion, "is required for this package; found version:", + javaVersionStr)) } } diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 61da30badac4e..3ea181157b644 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -624,7 +624,7 @@ setGeneric("summarize", function(x, ...) { standardGeneric("summarize") }) #' @rdname summary setGeneric("summary", function(object, ...) { standardGeneric("summary") }) -setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) +setGeneric("toJSON", function(x, ...) { standardGeneric("toJSON") }) setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index d6a2d08f9c218..f7c1663d32c96 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -194,7 +194,7 @@ sparkR.sparkContext <- function( # Don't use readString() so that we can provide a useful # error message if the R and Java versions are mismatched. - authSecretLen = readInt(f) + authSecretLen <- readInt(f) if (length(authSecretLen) == 0 || authSecretLen == 0) { stop("Unexpected EOF in JVM connection data. Mismatched versions?") } From e29176fd7dbcef04a29c4922ba655d58144fed24 Mon Sep 17 00:00:00 2001 From: Goun Na Date: Tue, 15 May 2018 14:11:20 +0800 Subject: [PATCH 0801/2461] [SPARK-23627][SQL] Provide isEmpty in Dataset ## What changes were proposed in this pull request? This PR adds isEmpty() in DataSet ## How was this patch tested? Unit tests added Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Goun Na Author: goungoun Closes #20800 from goungoun/SPARK-23627. --- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 10 ++++++++++ .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 8 ++++++++ 2 files changed, 18 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d518e07bfb62c..f001f16e1d5ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -511,6 +511,16 @@ class Dataset[T] private[sql]( */ def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation] + /** + * Returns true if the `Dataset` is empty. + * + * @group basic + * @since 2.4.0 + */ + def isEmpty: Boolean = withAction("isEmpty", limit(1).groupBy().count().queryExecution) { plan => + plan.executeCollect().head.getLong(0) == 0 + } + /** * Returns true if this Dataset contains one or more sources that continuously * return data as it arrives. A Dataset that reads data from a streaming source diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index e0f4d2ba685e1..d477d78dc14e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1425,6 +1425,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } } + test("SPARK-23627: provide isEmpty in DataSet") { + val ds1 = spark.emptyDataset[Int] + val ds2 = Seq(1, 2, 3).toDS() + + assert(ds1.isEmpty == true) + assert(ds2.isEmpty == false) + } + test("SPARK-22472: add null check for top-level primitive values") { // If the primitive values are from Option, we need to do runtime null check. val ds = Seq(Some(1), None).toDS().as[Int] From 80c6d35a3edbfb2e053c7d6650e2f725c36af53e Mon Sep 17 00:00:00 2001 From: maryannxue Date: Mon, 14 May 2018 23:34:42 -0700 Subject: [PATCH 0802/2461] [SPARK-24035][SQL] SQL syntax for Pivot - fix antlr warning ## What changes were proposed in this pull request? 1. Change antlr rule to fix the warning. 2. Add PIVOT/LATERAL check in AstBuilder with a more meaningful error message. ## How was this patch tested? 1. Add a counter case in `PlanParserSuite.test("lateral view")` Author: maryannxue Closes #21324 from maryannxue/spark-24035-fix. --- .../org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../apache/spark/sql/catalyst/parser/AstBuilder.scala | 3 +++ .../spark/sql/catalyst/parser/PlanParserSuite.scala | 10 ++++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index f7f921ec22c35..7c54851097af3 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -398,7 +398,7 @@ hintStatement ; fromClause - : FROM relation (',' relation)* (pivotClause | lateralView*)? + : FROM relation (',' relation)* lateralView* pivotClause? ; aggregation diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 64eed23884584..b9ece295c2510 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -504,6 +504,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging withJoinRelations(join, relation) } if (ctx.pivotClause() != null) { + if (!ctx.lateralView.isEmpty) { + throw new ParseException("LATERAL cannot be used together with PIVOT in FROM clause", ctx) + } withPivot(ctx.pivotClause, from) } else { ctx.lateralView.asScala.foldLeft(from)(withGenerate) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 812bfdd7bb885..fb51376c6163f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -318,6 +318,16 @@ class PlanParserSuite extends AnalysisTest { assertEqual( "select * from t lateral view posexplode(x) posexpl as x, y", expected) + + intercept( + """select * + |from t + |lateral view explode(x) expl + |pivot ( + | sum(x) + | FOR y IN ('a', 'b') + |)""".stripMargin, + "LATERAL cannot be used together with PIVOT in FROM clause") } test("joins") { From 4a2b15f0af400c71b7f20b2048f38a8b74d43dfa Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 15 May 2018 16:04:17 +0800 Subject: [PATCH 0803/2461] [SPARK-24241][SUBMIT] Do not fail fast when dynamic resource allocation enabled with 0 executor ## What changes were proposed in this pull request? ``` ~/spark-2.3.0-bin-hadoop2.7$ bin/spark-sql --num-executors 0 --conf spark.dynamicAllocation.enabled=true Java HotSpot(TM) 64-Bit Server VM warning: ignoring option PermSize=1024m; support was removed in 8.0 Java HotSpot(TM) 64-Bit Server VM warning: ignoring option MaxPermSize=1024m; support was removed in 8.0 Error: Number of executors must be a positive number Run with --help for usage help or --verbose for debug output ``` Actually, we could start up with min executor number with 0 before if dynamically ## How was this patch tested? ut added Author: Kent Yao Closes #21290 from yaooqinn/SPARK-24241. --- .../spark/deploy/SparkSubmitArguments.scala | 7 +++++-- .../spark/deploy/SparkSubmitSuite.scala | 20 +++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 0733fdb72cafb..fed4e0a5069c3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -36,7 +36,6 @@ import org.apache.spark.launcher.SparkSubmitArgumentsParser import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.Utils - /** * Parses and encapsulates arguments from the spark-submit script. * The env argument is used for testing. @@ -76,6 +75,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var proxyUser: String = null var principal: String = null var keytab: String = null + private var dynamicAllocationEnabled: Boolean = false // Standalone cluster mode only var supervise: Boolean = false @@ -198,6 +198,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S queue = Option(queue).orElse(sparkProperties.get("spark.yarn.queue")).orNull keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull + dynamicAllocationEnabled = + sparkProperties.get("spark.dynamicAllocation.enabled").exists("true".equalsIgnoreCase) // Try to set main class from JAR if no --class argument is given if (mainClass == null && !isPython && !isR && primaryResource != null) { @@ -274,7 +276,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S if (totalExecutorCores != null && Try(totalExecutorCores.toInt).getOrElse(-1) <= 0) { error("Total executor cores must be a positive number") } - if (numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) { + if (!dynamicAllocationEnabled && + numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) { error("Number of executors must be a positive number") } if (pyFiles != null && !isPython) { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 7451e07b25a1f..43286953e4383 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -180,6 +180,26 @@ class SparkSubmitSuite appArgs.toString should include ("thequeue") } + test("SPARK-24241: do not fail fast if executor num is 0 when dynamic allocation is enabled") { + val clArgs1 = Seq( + "--name", "myApp", + "--class", "Foo", + "--num-executors", "0", + "--conf", "spark.dynamicAllocation.enabled=true", + "thejar.jar") + new SparkSubmitArguments(clArgs1) + + val clArgs2 = Seq( + "--name", "myApp", + "--class", "Foo", + "--num-executors", "0", + "--conf", "spark.dynamicAllocation.enabled=false", + "thejar.jar") + + val e = intercept[SparkException](new SparkSubmitArguments(clArgs2)) + assert(e.getMessage.contains("Number of executors must be a positive number")) + } + test("specify deploy mode through configuration") { val clArgs = Seq( "--master", "yarn", From d610d2a3f57ca551f72cb4e5dfed78f27be62eec Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 15 May 2018 22:06:58 +0800 Subject: [PATCH 0804/2461] [SPARK-24259][SQL] ArrayWriter for Arrow produces wrong output ## What changes were proposed in this pull request? Right now `ArrayWriter` used to output Arrow data for array type, doesn't do `clear` or `reset` after each batch. It produces wrong output. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21312 from viirya/SPARK-24259. --- python/pyspark/sql/tests.py | 20 +++++++++++++++++++ .../sql/execution/arrow/ArrowWriter.scala | 8 ++++++++ 2 files changed, 28 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 16aa9378ad8ee..a1b6db71782bb 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4680,6 +4680,26 @@ def test_supported_types(self): self.assertPandasEqual(expected2, result2) self.assertPandasEqual(expected3, result3) + def test_array_type_correct(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col + + df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id") + + output_schema = StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('arr', ArrayType(LongType()))]) + + udf = pandas_udf( + lambda pdf: pdf, + output_schema, + PandasUDFType.GROUPED_MAP + ) + + result = df.groupby('id').apply(udf).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(udf.func).reset_index(drop=True) + self.assertPandasEqual(expected, result) + def test_register_grouped_map_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 22b63513548fe..66888fce7f9f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -133,6 +133,14 @@ private[arrow] abstract class ArrowFieldWriter { valueVector match { case fixedWidthVector: BaseFixedWidthVector => fixedWidthVector.reset() case variableWidthVector: BaseVariableWidthVector => variableWidthVector.reset() + case listVector: ListVector => + // Manual "reset" the underlying buffer. + // TODO: When we upgrade to Arrow 0.10.0, we can simply remove this and call + // `listVector.reset()`. + val buffers = listVector.getBuffers(false) + buffers.foreach(buf => buf.setZero(0, buf.capacity())) + listVector.setValueCount(0) + listVector.setLastSet(0) case _ => } count = 0 From 3fabbc576203c7fd63808a259adafc5c3cea1838 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 15 May 2018 10:25:29 -0700 Subject: [PATCH 0805/2461] [SPARK-24040][SS] Support single partition aggregates in continuous processing. ## What changes were proposed in this pull request? Support aggregates with exactly 1 partition in continuous processing. A few small tweaks are needed to make this work: * Replace currentEpoch tracking with an ThreadLocal. This means that current epoch is scoped to a task rather than a node, but I think that's sustainable even once we add shuffle. * Add a new testing-only flag to disable the UnsupportedOperationChecker whitelist of allowed continuous processing nodes. I think this is preferable to writing a pile of custom logic to enforce that there is in fact only 1 partition; we plan to support multi-partition aggregates before the next Spark release, so we'd just have to tear that logic back out. * Restart continuous processing queries from the first available uncommitted epoch, rather than one that's guaranteed to be unused. This is required for stateful operators to overwrite partial state from the previous attempt at the epoch, and there was no specific motivation for the original strategy. In another PR before stabilizing the StreamWriter API, we'll need to narrow down and document more precise semantic guarantees for the epoch IDs. * We need a single-partition ContinuousMemoryStream. The way MemoryStream is constructed means it can't be a text option like it is for rate source, unfortunately. ## How was this patch tested? new unit tests Author: Jose Torres Closes #21239 from jose-torres/withAggr. --- .../UnsupportedOperationChecker.scala | 1 + .../continuous/ContinuousExecution.scala | 11 +-- .../ContinuousQueuedDataReader.scala | 7 +- .../continuous/ContinuousWriteRDD.scala | 18 +++-- .../streaming/continuous/EpochTracker.scala | 58 +++++++++++++++ .../sources/ContinuousMemoryStream.scala | 14 ++-- .../streaming/state/StateStoreRDD.scala | 10 ++- .../sql/streaming/StreamingQueryManager.scala | 4 +- .../ContinuousAggregationSuite.scala | 72 +++++++++++++++++++ .../ContinuousQueuedDataReaderSuite.scala | 1 + 10 files changed, 167 insertions(+), 29 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index d3d6c636c4ba8..2bed41672fe33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index f58146ac42398..0e7d1019b9c8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -122,16 +122,7 @@ class ContinuousExecution( s"Batch $latestEpochId was committed without end epoch offsets!") } committedOffsets = nextOffsets.toStreamProgress(sources) - - // Get to an epoch ID that has definitely never been sent to a sink before. Since sink - // commit happens between offset log write and commit log write, this means an epoch ID - // which is not in the offset log. - val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse { - throw new IllegalStateException( - s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" + - s"an element.") - } - currentBatchId = latestOffsetEpoch + 1 + currentBatchId = latestEpochId + 1 logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets") nextOffsets diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala index d8645576c2052..f38577b6a9f16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -46,8 +46,6 @@ class ContinuousQueuedDataReader( // Important sequencing - we must get our starting point before the provider threads start running private var currentOffset: PartitionOffset = ContinuousDataSourceRDD.getContinuousReader(reader).getOffset - private var currentEpoch: Long = - context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong /** * The record types in the read buffer. @@ -115,8 +113,7 @@ class ContinuousQueuedDataReader( currentEntry match { case EpochMarker => epochCoordEndpoint.send(ReportPartitionOffset( - context.partitionId(), currentEpoch, currentOffset)) - currentEpoch += 1 + context.partitionId(), EpochTracker.getCurrentEpoch.get, currentOffset)) null case ContinuousRow(row, offset) => currentOffset = offset @@ -184,7 +181,7 @@ class ContinuousQueuedDataReader( private val epochCoordEndpoint = EpochCoordinatorRef.get( context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) - // Note that this is *not* the same as the currentEpoch in [[ContinuousDataQueuedReader]]! That + // Note that this is *not* the same as the currentEpoch in [[ContinuousWriteRDD]]! That // field represents the epoch wrt the data being processed. The currentEpoch here is just a // counter to ensure we send the appropriate number of markers if we fall behind the driver. private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala index 91f1576581511..ef5f0da1e7cc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -45,7 +45,8 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor val epochCoordinator = EpochCoordinatorRef.get( context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) - var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + EpochTracker.initializeCurrentEpoch( + context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong) while (!context.isInterrupted() && !context.isCompleted()) { var dataWriter: DataWriter[InternalRow] = null @@ -54,19 +55,24 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor try { val dataIterator = prev.compute(split, context) dataWriter = writeTask.createDataWriter( - context.partitionId(), context.attemptNumber(), currentEpoch) + context.partitionId(), + context.attemptNumber(), + EpochTracker.getCurrentEpoch.get) while (dataIterator.hasNext) { dataWriter.write(dataIterator.next()) } logInfo(s"Writer for partition ${context.partitionId()} " + - s"in epoch $currentEpoch is committing.") + s"in epoch ${EpochTracker.getCurrentEpoch.get} is committing.") val msg = dataWriter.commit() epochCoordinator.send( - CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) + CommitPartitionEpoch( + context.partitionId(), + EpochTracker.getCurrentEpoch.get, + msg) ) logInfo(s"Writer for partition ${context.partitionId()} " + - s"in epoch $currentEpoch committed.") - currentEpoch += 1 + s"in epoch ${EpochTracker.getCurrentEpoch.get} committed.") + EpochTracker.incrementCurrentEpoch() } catch { case _: InterruptedException => // Continuous shutdown always involves an interrupt. Just finish the task. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala new file mode 100644 index 0000000000000..bc0ae428d4521 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.concurrent.atomic.AtomicLong + +/** + * Tracks the current continuous processing epoch within a task. Call + * EpochTracker.getCurrentEpoch to get the current epoch. + */ +object EpochTracker { + // The current epoch. Note that this is a shared reference; ContinuousWriteRDD.compute() will + // update the underlying AtomicLong as it finishes epochs. Other code should only read the value. + private val currentEpoch: ThreadLocal[AtomicLong] = new ThreadLocal[AtomicLong] { + override def initialValue() = new AtomicLong(-1) + } + + /** + * Get the current epoch for the current task, or None if the task has no current epoch. + */ + def getCurrentEpoch: Option[Long] = { + currentEpoch.get().get() match { + case n if n < 0 => None + case e => Some(e) + } + } + + /** + * Increment the current epoch for this task thread. Should be called by [[ContinuousWriteRDD]] + * between epochs. + */ + def incrementCurrentEpoch(): Unit = { + currentEpoch.get().incrementAndGet() + } + + /** + * Initialize the current epoch for this task thread. Should be called by [[ContinuousWriteRDD]] + * at the beginning of a task. + */ + def initializeCurrentEpoch(startEpoch: Long): Unit = { + currentEpoch.get().set(startEpoch) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index fef792eab69d5..4daafa65850de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -47,10 +47,9 @@ import org.apache.spark.util.RpcUtils * ContinuousMemoryStreamDataReader instances to poll. It returns the record at the specified * offset within the list, or null if that offset doesn't yet have a record. */ -class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) +class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2) extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { private implicit val formats = Serialization.formats(NoTypeHints) - private val NUM_PARTITIONS = 2 protected val logicalPlan = StreamingRelationV2(this, "memory", Map(), attributes, None)(sqlContext.sparkSession) @@ -58,7 +57,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) // ContinuousReader implementation @GuardedBy("this") - private val records = Seq.fill(NUM_PARTITIONS)(new ListBuffer[A]) + private val records = Seq.fill(numPartitions)(new ListBuffer[A]) @GuardedBy("this") private var startOffset: ContinuousMemoryStreamOffset = _ @@ -69,17 +68,17 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) def addData(data: TraversableOnce[A]): Offset = synchronized { // Distribute data evenly among partition lists. data.toSeq.zipWithIndex.map { - case (item, index) => records(index % NUM_PARTITIONS) += item + case (item, index) => records(index % numPartitions) += item } // The new target offset is the offset where all records in all partitions have been processed. - ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, records(i).size)).toMap) + ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap) } override def setStartOffset(start: Optional[Offset]): Unit = synchronized { // Inferred initial offset is position 0 in each partition. startOffset = start.orElse { - ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, 0)).toMap) + ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap) }.asInstanceOf[ContinuousMemoryStreamOffset] } @@ -152,6 +151,9 @@ object ContinuousMemoryStream { def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) + + def singlePartition[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = + new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, 1) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 01d8e75980993..3f11b8f79943c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -23,6 +23,7 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.execution.streaming.continuous.EpochTracker import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration @@ -71,8 +72,15 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( StateStoreId(checkpointLocation, operatorId, partition.index), queryRunId) + // If we're in continuous processing mode, we should get the store version for the current + // epoch rather than the one at planning time. + val currentVersion = EpochTracker.getCurrentEpoch match { + case None => storeVersion + case Some(value) => value + } + store = StateStore.get( - storeProviderId, keySchema, valueSchema, indexOrdinal, storeVersion, + storeProviderId, keySchema, valueSchema, indexOrdinal, currentVersion, storeConf, hadoopConfBroadcast.value.value) val inputIter = dataRDD.iterator(partition, ctxt) storeUpdateFunction(store, inputIter) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 7cefd03e43bc3..97da2b1325f58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -242,7 +242,9 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo (sink, trigger) match { case (v2Sink: StreamWriteSupport, trigger: ContinuousTrigger) => - UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) + if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { + UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) + } new StreamingQueryWrapper(new ContinuousExecution( sparkSession, userSpecifiedName.orNull, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala new file mode 100644 index 0000000000000..b7ef637f5270e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.continuous + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming.OutputMode + +class ContinuousAggregationSuite extends ContinuousSuiteBase { + import testImplicits._ + + test("not enabled") { + val ex = intercept[AnalysisException] { + val input = ContinuousMemoryStream.singlePartition[Int] + testStream(input.toDF().agg(max('value)), OutputMode.Complete)() + } + + assert(ex.getMessage.contains("Continuous processing does not support Aggregate operations")) + } + + test("basic") { + withSQLConf(("spark.sql.streaming.unsupportedOperationCheck", "false")) { + val input = ContinuousMemoryStream.singlePartition[Int] + + testStream(input.toDF().agg(max('value)), OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + } + + test("repeated restart") { + withSQLConf(("spark.sql.streaming.unsupportedOperationCheck", "false")) { + val input = ContinuousMemoryStream.singlePartition[Int] + + testStream(input.toDF().agg(max('value)), OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + StartStream(), + StopStream, + StartStream(), + StopStream, + StartStream(), + AddData(input, 0), + CheckAnswer(2), + AddData(input, 5), + CheckAnswer(5)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index f47d3ec8ae025..e663fa8312da4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -51,6 +51,7 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { startEpoch, spark, SparkEnv.get) + EpochTracker.initializeCurrentEpoch(0) } override def afterEach(): Unit = { From 6b94420f6c672683678a54404e6341a0b9ab3c24 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Tue, 15 May 2018 14:16:31 -0700 Subject: [PATCH 0806/2461] [SPARK-24231][PYSPARK][ML] Provide Python API for evaluateEachIteration for spark.ml GBTs ## What changes were proposed in this pull request? Add evaluateEachIteration for GBTClassification and GBTRegressionModel ## How was this patch tested? doctest Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21335 from ludatabricks/SPARK-14682. --- python/pyspark/ml/classification.py | 15 +++++++++++++++ python/pyspark/ml/regression.py | 18 ++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index ec17653a1adf9..424ecfd89b060 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1222,6 +1222,10 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol True >>> model.trees [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] + >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)], + ... ["indexed", "features"]) + >>> model.evaluateEachIteration(validation) + [0.25..., 0.23..., 0.21..., 0.19..., 0.18...] .. versionadded:: 1.4.0 """ @@ -1319,6 +1323,17 @@ def trees(self): """Trees in this ensemble. Warning: These have null parent Estimators.""" return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))] + @since("2.4.0") + def evaluateEachIteration(self, dataset): + """ + Method to compute error or loss for every iteration of gradient boosting. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + """ + return self._call_java("evaluateEachIteration", dataset) + @inherit_doc class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 9a66d87d7f211..dd0b62f184d26 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -1056,6 +1056,10 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, True >>> model.trees [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] + >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0))], + ... ["label", "features"]) + >>> model.evaluateEachIteration(validation, "squared") + [0.0, 0.0, 0.0, 0.0, 0.0] .. versionadded:: 1.4.0 """ @@ -1156,6 +1160,20 @@ def trees(self): """Trees in this ensemble. Warning: These have null parent Estimators.""" return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))] + @since("2.4.0") + def evaluateEachIteration(self, dataset, loss): + """ + Method to compute error or loss for every iteration of gradient boosting. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + :param loss: + The loss function used to compute error. + Supported options: squared, absolute + """ + return self._call_java("evaluateEachIteration", dataset, loss) + @inherit_doc class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, From 8a13c5096898f95d1dfcedaf5d31205a1cbf0a19 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 15 May 2018 16:50:09 -0700 Subject: [PATCH 0807/2461] [SPARK-24058][ML][PYSPARK] Default Params in ML should be saved separately: Python API ## What changes were proposed in this pull request? See SPARK-23455 for reference. Now default params in ML are saved separately in metadata file in Scala. We must change it for Python for Spark 2.4.0 as well in order to keep them in sync. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21153 from viirya/SPARK-24058. --- python/pyspark/ml/tests.py | 38 ++++++++++++++++++++++++++++++++++++++ python/pyspark/ml/util.py | 30 ++++++++++++++++++++++++++++-- 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 093593132e56d..0dde0db9e3339 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1595,6 +1595,44 @@ def test_default_read_write(self): self.assertEqual(lr.uid, lr3.uid) self.assertEqual(lr.extractParamMap(), lr3.extractParamMap()) + def test_default_read_write_default_params(self): + lr = LogisticRegression() + self.assertFalse(lr.isSet(lr.getParam("threshold"))) + + lr.setMaxIter(50) + lr.setThreshold(.75) + + # `threshold` is set by user, default param `predictionCol` is not set by user. + self.assertTrue(lr.isSet(lr.getParam("threshold"))) + self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) + self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) + + writer = DefaultParamsWriter(lr) + metadata = json.loads(writer._get_metadata_to_save(lr, self.sc)) + self.assertTrue("defaultParamMap" in metadata) + + reader = DefaultParamsReadable.read() + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + reader.getAndSetParams(lr, loadedMetadata) + + self.assertTrue(lr.isSet(lr.getParam("threshold"))) + self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) + self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) + + # manually create metadata without `defaultParamMap` section. + del metadata['defaultParamMap'] + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"): + reader.getAndSetParams(lr, loadedMetadata) + + # Prior to 2.4.0, metadata doesn't have `defaultParamMap`. + metadata['sparkVersion'] = '2.3.0' + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + reader.getAndSetParams(lr, loadedMetadata) + class LDATest(SparkSessionTestCase): diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index a486c6a3fdeb5..9fa85664939b8 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -30,6 +30,7 @@ from pyspark import SparkContext, since from pyspark.ml.common import inherit_doc from pyspark.sql import SparkSession +from pyspark.util import VersionUtils def _jvm(): @@ -396,6 +397,7 @@ def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None): - sparkVersion - uid - paramMap + - defaultParamMap (since 2.4.0) - (optionally, extra metadata) :param extraMetadata: Extra metadata to be saved at same level as uid, paramMap, etc. :param paramMap: If given, this is saved in the "paramMap" field. @@ -417,15 +419,24 @@ def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None): """ uid = instance.uid cls = instance.__module__ + '.' + instance.__class__.__name__ - params = instance.extractParamMap() + + # User-supplied param values + params = instance._paramMap jsonParams = {} if paramMap is not None: jsonParams = paramMap else: for p in params: jsonParams[p.name] = params[p] + + # Default param values + jsonDefaultParams = {} + for p in instance._defaultParamMap: + jsonDefaultParams[p.name] = instance._defaultParamMap[p] + basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 1000)), - "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams} + "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams, + "defaultParamMap": jsonDefaultParams} if extraMetadata is not None: basicMetadata.update(extraMetadata) return json.dumps(basicMetadata, separators=[',', ':']) @@ -523,11 +534,26 @@ def getAndSetParams(instance, metadata): """ Extract Params from metadata, and set them in the instance. """ + # Set user-supplied param values for paramName in metadata['paramMap']: param = instance.getParam(paramName) paramValue = metadata['paramMap'][paramName] instance.set(param, paramValue) + # Set default param values + majorAndMinorVersions = VersionUtils.majorMinorVersion(metadata['sparkVersion']) + major = majorAndMinorVersions[0] + minor = majorAndMinorVersions[1] + + # For metadata file prior to Spark 2.4, there is no default section. + if major > 2 or (major == 2 and minor >= 4): + assert 'defaultParamMap' in metadata, "Error loading metadata: Expected " + \ + "`defaultParamMap` section not found" + + for paramName in metadata['defaultParamMap']: + paramValue = metadata['defaultParamMap'][paramName] + instance._setDefault(**{paramName: paramValue}) + @staticmethod def loadParamsInstance(path, sc): """ From 943493b165185c5362c8350dd355276cc458aad0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 16 May 2018 22:01:24 +0800 Subject: [PATCH 0808/2461] =?UTF-8?q?Revert=20"[SPARK-22938][SQL][FOLLOWUP?= =?UTF-8?q?]=20Assert=20that=20SQLConf.get=20is=20acces=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …sed only on the driver" This reverts commit a4206d58e05ab9ed6f01fee57e18dee65cbc4efc. This is from https://github.com/apache/spark/pull/21299 and to ease the review of it. Author: Wenchen Fan Closes #21341 from cloud-fan/revert. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 5 +- .../analysis/ResolveInlineTables.scala | 4 +- .../sql/catalyst/analysis/TypeCoercion.scala | 156 ++++++++---------- .../apache/spark/sql/internal/SQLConf.scala | 16 +- .../org/apache/spark/sql/types/DataType.scala | 8 +- .../catalyst/analysis/TypeCoercionSuite.scala | 70 ++++---- .../org/apache/spark/sql/SparkSession.scala | 21 +-- .../datasources/PartitioningUtils.scala | 5 +- .../datasources/json/JsonInferSchema.scala | 39 ++--- .../datasources/json/JsonSuite.scala | 4 +- 10 files changed, 140 insertions(+), 188 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 94b0561529e71..90bda2a72ad82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -261,9 +260,7 @@ trait CheckAnalysis extends PredicateHelper { // Check if the data types match. dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) => // SPARK-18058: we shall not care about the nullability of columns - val widerType = TypeCoercion.findWiderTypeForTwo( - dt1.asNullable, dt2.asNullable, SQLConf.get.caseSensitiveAnalysis) - if (widerType.isEmpty) { + if (TypeCoercion.findWiderTypeForTwo(dt1.asNullable, dt2.asNullable).isEmpty) { failAnalysis( s""" |${operator.nodeName} can only be performed on tables with the compatible diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 31ba9d792024b..71ed75454cd4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -83,9 +83,7 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with Cas // For each column, traverse all the values and find a common data type and nullability. val fields = table.rows.transpose.zip(table.names).map { case (column, name) => val inputTypes = column.map(_.dataType) - val wideType = TypeCoercion.findWiderTypeWithoutStringPromotion( - inputTypes, conf.caseSensitiveAnalysis) - val tpe = wideType.getOrElse { + val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse { table.failAnalysis(s"incompatible types found in column $name for inline table") } StructField(name, tpe, nullable = column.exists(_.nullable)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index a7ba201509b78..b2817b0538a7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -48,18 +48,18 @@ object TypeCoercion { def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] = InConversion(conf) :: - WidenSetOperationTypes(conf) :: + WidenSetOperationTypes :: PromoteStrings(conf) :: DecimalPrecision :: BooleanEquality :: - FunctionArgumentConversion(conf) :: + FunctionArgumentConversion :: ConcatCoercion(conf) :: EltCoercion(conf) :: - CaseWhenCoercion(conf) :: - IfCoercion(conf) :: + CaseWhenCoercion :: + IfCoercion :: StackCoercion :: Division :: - ImplicitTypeCasts(conf) :: + new ImplicitTypeCasts(conf) :: DateTimeOperations :: WindowFrameCoercion :: Nil @@ -83,10 +83,7 @@ object TypeCoercion { * with primitive types, because in that case the precision and scale of the result depends on * the operation. Those rules are implemented in [[DecimalPrecision]]. */ - def findTightestCommonType( - left: DataType, - right: DataType, - caseSensitive: Boolean): Option[DataType] = (left, right) match { + val findTightestCommonType: (DataType, DataType) => Option[DataType] = { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) @@ -105,32 +102,22 @@ object TypeCoercion { case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => Some(TimestampType) - case (t1 @ StructType(fields1), t2 @ StructType(fields2)) => - val isSameType = if (caseSensitive) { - DataType.equalsIgnoreNullability(t1, t2) - } else { - DataType.equalsIgnoreCaseAndNullability(t1, t2) - } - - if (isSameType) { - Some(StructType(fields1.zip(fields2).map { case (f1, f2) => - // Since t1 is same type of t2, two StructTypes have the same DataType - // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. - // - Different names: use f1.name - // - Different nullabilities: `nullable` is true iff one of them is nullable. - val dataType = findTightestCommonType(f1.dataType, f2.dataType, caseSensitive).get - StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) - })) - } else { - None - } + case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) => + Some(StructType(fields1.zip(fields2).map { case (f1, f2) => + // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType + // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. + // - Different names: use f1.name + // - Different nullabilities: `nullable` is true iff one of them is nullable. + val dataType = findTightestCommonType(f1.dataType, f2.dataType).get + StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) + })) case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) => - findTightestCommonType(et1, et2, caseSensitive).map(ArrayType(_, hasNull1 || hasNull2)) + findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2)) case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) => - val keyType = findTightestCommonType(kt1, kt2, caseSensitive) - val valueType = findTightestCommonType(vt1, vt2, caseSensitive) + val keyType = findTightestCommonType(kt1, kt2) + val valueType = findTightestCommonType(vt1, vt2) Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2)) case _ => None @@ -185,14 +172,13 @@ object TypeCoercion { * i.e. the main difference with [[findTightestCommonType]] is that here we allow some * loss of precision when widening decimal and double, and promotion to string. */ - def findWiderTypeForTwo(t1: DataType, t2: DataType, caseSensitive: Boolean): Option[DataType] = { - findTightestCommonType(t1, t2, caseSensitive) + def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = { + findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse(stringPromotion(t1, t2)) .orElse((t1, t2) match { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeForTwo(et1, et2, caseSensitive) - .map(ArrayType(_, containsNull1 || containsNull2)) + findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) case _ => None }) } @@ -207,8 +193,7 @@ object TypeCoercion { case _ => false } - private def findWiderCommonType( - types: Seq[DataType], caseSensitive: Boolean): Option[DataType] = { + private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { // findWiderTypeForTwo doesn't satisfy the associative law, i.e. (a op b) op c may not equal // to a op (b op c). This is only a problem for StringType or nested StringType in ArrayType. // Excluding these types, findWiderTypeForTwo satisfies the associative law. For instance, @@ -216,7 +201,7 @@ object TypeCoercion { val (stringTypes, nonStringTypes) = types.partition(hasStringType(_)) (stringTypes.distinct ++ nonStringTypes).foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findWiderTypeForTwo(d, c, caseSensitive) + case Some(d) => findWiderTypeForTwo(d, c) case _ => None }) } @@ -228,22 +213,20 @@ object TypeCoercion { */ private[analysis] def findWiderTypeWithoutStringPromotionForTwo( t1: DataType, - t2: DataType, - caseSensitive: Boolean): Option[DataType] = { - findTightestCommonType(t1, t2, caseSensitive) + t2: DataType): Option[DataType] = { + findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse((t1, t2) match { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeWithoutStringPromotionForTwo(et1, et2, caseSensitive) + findWiderTypeWithoutStringPromotionForTwo(et1, et2) .map(ArrayType(_, containsNull1 || containsNull2)) case _ => None }) } - def findWiderTypeWithoutStringPromotion( - types: Seq[DataType], caseSensitive: Boolean): Option[DataType] = { + def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c, caseSensitive) + case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c) case None => None }) } @@ -296,32 +279,29 @@ object TypeCoercion { * * This rule is only applied to Union/Except/Intersect */ - case class WidenSetOperationTypes(conf: SQLConf) extends Rule[LogicalPlan] { + object WidenSetOperationTypes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case s @ SetOperation(left, right) if s.childrenResolved && left.output.length == right.output.length && !s.resolved => - val newChildren: Seq[LogicalPlan] = - buildNewChildrenWithWiderTypes(left :: right :: Nil, conf.caseSensitiveAnalysis) + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) assert(newChildren.length == 2) s.makeCopy(Array(newChildren.head, newChildren.last)) case s: Union if s.childrenResolved && s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => - val newChildren: Seq[LogicalPlan] = - buildNewChildrenWithWiderTypes(s.children, conf.caseSensitiveAnalysis) + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children) s.makeCopy(Array(newChildren)) } /** Build new children with the widest types for each attribute among all the children */ - private def buildNewChildrenWithWiderTypes( - children: Seq[LogicalPlan], caseSensitive: Boolean): Seq[LogicalPlan] = { + private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = { require(children.forall(_.output.length == children.head.output.length)) // Get a sequence of data types, each of which is the widest type of this specific attribute // in all the children val targetTypes: Seq[DataType] = - getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType](), caseSensitive) + getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType]()) if (targetTypes.nonEmpty) { // Add an extra Project if the targetTypes are different from the original types. @@ -336,19 +316,18 @@ object TypeCoercion { @tailrec private def getWidestTypes( children: Seq[LogicalPlan], attrIndex: Int, - castedTypes: mutable.Queue[DataType], - caseSensitive: Boolean): Seq[DataType] = { + castedTypes: mutable.Queue[DataType]): Seq[DataType] = { // Return the result after the widen data types have been found for all the children if (attrIndex >= children.head.output.length) return castedTypes.toSeq // For the attrIndex-th attribute, find the widest type - findWiderCommonType(children.map(_.output(attrIndex).dataType), caseSensitive) match { + findWiderCommonType(children.map(_.output(attrIndex).dataType)) match { // If unable to find an appropriate widen type for this column, return an empty Seq case None => Seq.empty[DataType] // Otherwise, record the result in the queue and find the type for the next column case Some(widenType) => castedTypes.enqueue(widenType) - getWidestTypes(children, attrIndex + 1, castedTypes, caseSensitive) + getWidestTypes(children, attrIndex + 1, castedTypes) } } @@ -453,7 +432,7 @@ object TypeCoercion { val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => findCommonTypeForBinaryComparison(l.dataType, r.dataType, conf) - .orElse(findTightestCommonType(l.dataType, r.dataType, conf.caseSensitiveAnalysis)) + .orElse(findTightestCommonType(l.dataType, r.dataType)) } // The number of columns/expressions must match between LHS and RHS of an @@ -482,7 +461,7 @@ object TypeCoercion { } case i @ In(a, b) if b.exists(_.dataType != a.dataType) => - findWiderCommonType(i.children.map(_.dataType), conf.caseSensitiveAnalysis) match { + findWiderCommonType(i.children.map(_.dataType)) match { case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) case None => i } @@ -536,7 +515,7 @@ object TypeCoercion { /** * This ensure that the types for various functions are as expected. */ - case class FunctionArgumentConversion(conf: SQLConf) extends TypeCoercionRule { + object FunctionArgumentConversion extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. @@ -544,7 +523,7 @@ object TypeCoercion { case a @ CreateArray(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) case None => a } @@ -552,7 +531,7 @@ object TypeCoercion { case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && !haveSameType(children) => val types = children.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType))) case None => c } @@ -563,7 +542,7 @@ object TypeCoercion { m.keys } else { val types = m.keys.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => m.keys.map(Cast(_, finalDataType)) case None => m.keys } @@ -573,7 +552,7 @@ object TypeCoercion { m.values } else { val types = m.values.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => m.values.map(Cast(_, finalDataType)) case None => m.values } @@ -601,7 +580,7 @@ object TypeCoercion { // compatible with every child column. case c @ Coalesce(es) if !haveSameType(es) => val types = es.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) case None => c } @@ -611,14 +590,14 @@ object TypeCoercion { // string.g case g @ Greatest(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderTypeWithoutStringPromotion(types, conf.caseSensitiveAnalysis) match { + findWiderTypeWithoutStringPromotion(types) match { case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) case None => g } case l @ Least(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderTypeWithoutStringPromotion(types, conf.caseSensitiveAnalysis) match { + findWiderTypeWithoutStringPromotion(types) match { case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) case None => l } @@ -658,11 +637,11 @@ object TypeCoercion { /** * Coerces the type of different branches of a CASE WHEN statement to a common type. */ - case class CaseWhenCoercion(conf: SQLConf) extends TypeCoercionRule { + object CaseWhenCoercion extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => - val maybeCommonType = findWiderCommonType(c.valueTypes, conf.caseSensitiveAnalysis) + val maybeCommonType = findWiderCommonType(c.valueTypes) maybeCommonType.map { commonType => var changed = false val newBranches = c.branches.map { case (condition, value) => @@ -689,17 +668,16 @@ object TypeCoercion { /** * Coerces the type of different branches of If statement to a common type. */ - case class IfCoercion(conf: SQLConf) extends TypeCoercionRule { + object IfCoercion extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => - findWiderTypeForTwo(left.dataType, right.dataType, conf.caseSensitiveAnalysis).map { - widestType => - val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) - val newRight = if (right.dataType == widestType) right else Cast(right, widestType) - If(pred, newLeft, newRight) + findWiderTypeForTwo(left.dataType, right.dataType).map { widestType => + val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) + val newRight = if (right.dataType == widestType) right else Cast(right, widestType) + If(pred, newLeft, newRight) }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. case If(Literal(null, NullType), left, right) => If(Literal.create(null, BooleanType), left, right) @@ -798,11 +776,12 @@ object TypeCoercion { /** * Casts types according to the expected input types for [[Expression]]s. */ - case class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { + class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING) - override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -825,18 +804,17 @@ object TypeCoercion { } case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - findTightestCommonType(left.dataType, right.dataType, conf.caseSensitiveAnalysis).map { - commonType => - if (b.inputType.acceptsType(commonType)) { - // If the expression accepts the tightest common type, cast to that. - val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) - val newRight = if (right.dataType == commonType) right else Cast(right, commonType) - b.withNewChildren(Seq(newLeft, newRight)) - } else { - // Otherwise, don't do anything with the expression. - b - } - }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. + findTightestCommonType(left.dataType, right.dataType).map { commonType => + if (b.inputType.acceptsType(commonType)) { + // If the expression accepts the tightest common type, cast to that. + val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) + val newRight = if (right.dataType == commonType) right else Cast(right, commonType) + b.withNewChildren(Seq(newLeft, newRight)) + } else { + // Otherwise, don't do anything with the expression. + b + } + }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0b1965c438e27..b00edca97cd44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,7 +27,7 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path -import org.apache.spark.TaskContext +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit @@ -107,13 +107,7 @@ object SQLConf { * run tests in parallel. At the time this feature was implemented, this was a no-op since we * run unit tests (that does not involve SparkSession) in serial order. */ - def get: SQLConf = { - if (Utils.isTesting && TaskContext.get != null) { - // we're accessing it during task execution, fail. - throw new IllegalStateException("SQLConf should only be created and accessed on the driver.") - } - confGetter.get()() - } + def get: SQLConf = confGetter.get()() val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() @@ -1280,6 +1274,12 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ + if (Utils.isTesting && SparkEnv.get != null) { + // assert that we're only accessing it on the driver. + assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, + "SQLConf should only be created and accessed on the driver.") + } + /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 4ee12db9c10ca..0bef11659fc9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -81,7 +81,11 @@ abstract class DataType extends AbstractDataType { * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). */ private[spark] def sameType(other: DataType): Boolean = - DataType.equalsIgnoreNullability(this, other) + if (SQLConf.get.caseSensitiveAnalysis) { + DataType.equalsIgnoreNullability(this, other) + } else { + DataType.equalsIgnoreCaseAndNullability(this, other) + } /** * Returns the same data type but set all nullability fields are true @@ -214,7 +218,7 @@ object DataType { /** * Compares two types, ignoring nullability of ArrayType, MapType, StructType. */ - private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { + private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { (left, right) match { case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => equalsIgnoreNullability(leftElementType, rightElementType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index f73e045685ee1..0acd3b490447d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -128,17 +128,17 @@ class TypeCoercionSuite extends AnalysisTest { } private def checkWidenType( - widenFunc: (DataType, DataType, Boolean) => Option[DataType], + widenFunc: (DataType, DataType) => Option[DataType], t1: DataType, t2: DataType, expected: Option[DataType], isSymmetric: Boolean = true): Unit = { - var found = widenFunc(t1, t2, conf.caseSensitiveAnalysis) + var found = widenFunc(t1, t2) assert(found == expected, s"Expected $expected as wider common type for $t1 and $t2, found $found") // Test both directions to make sure the widening is symmetric. if (isSymmetric) { - found = widenFunc(t2, t1, conf.caseSensitiveAnalysis) + found = widenFunc(t2, t1) assert(found == expected, s"Expected $expected as wider common type for $t2 and $t1, found $found") } @@ -524,11 +524,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for expressions that implement ExpectsInputTypes") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeUnaryExpression(Literal.create(null, NullType)), AnyTypeUnaryExpression(Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeUnaryExpression(Literal.create(null, NullType)), NumericTypeUnaryExpression(Literal.create(null, DoubleType))) } @@ -536,17 +536,17 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for binary operators") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType))) } test("coalesce casts") { - val rule = TypeCoercion.FunctionArgumentConversion(conf) + val rule = TypeCoercion.FunctionArgumentConversion val intLit = Literal(1) val longLit = Literal.create(1L) @@ -606,7 +606,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("CreateArray casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -616,7 +616,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal(1.0) :: Literal(1) :: Literal("a") @@ -626,7 +626,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal("a"), StringType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal.create(null, DecimalType(5, 3)) :: Literal(1) :: Nil), @@ -634,7 +634,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(1).cast(DecimalType(13, 3)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal.create(null, DecimalType(5, 3)) :: Literal.create(null, DecimalType(22, 10)) :: Literal.create(null, DecimalType(38, 38)) @@ -647,7 +647,7 @@ class TypeCoercionSuite extends AnalysisTest { test("CreateMap casts") { // type coercion for map keys - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal("a") :: Literal.create(2.0, FloatType) @@ -658,7 +658,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal.create(2.0, FloatType), FloatType) :: Literal("b") :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal.create(null, DecimalType(5, 3)) :: Literal("a") :: Literal.create(2.0, FloatType) @@ -670,7 +670,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal("b") :: Nil)) // type coercion for map values - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal("a") :: Literal(2) @@ -681,7 +681,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(2) :: Cast(Literal(3.0), StringType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal.create(null, DecimalType(38, 0)) :: Literal(2) @@ -693,7 +693,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) :: Nil)) // type coercion for both map keys and values - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal("a") :: Literal(2.0) @@ -708,7 +708,7 @@ class TypeCoercionSuite extends AnalysisTest { test("greatest/least cast") { for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -717,7 +717,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1L) :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) @@ -726,7 +726,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal(1), DecimalType(22, 0)) :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1.0) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) @@ -735,7 +735,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(10, 5)).cast(DoubleType) :: Literal(1).cast(DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal.create(null, DecimalType(15, 0)) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) @@ -744,7 +744,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(20, 5)) :: Literal(1).cast(DecimalType(20, 5)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal.create(2L, LongType) :: Literal(1) :: Literal.create(null, DecimalType(10, 5)) @@ -757,25 +757,25 @@ class TypeCoercionSuite extends AnalysisTest { } test("nanvl casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)), NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)), NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)), NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)), NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType))) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)), NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType))) } test("type coercion for If") { - val rule = TypeCoercion.IfCoercion(conf) + val rule = TypeCoercion.IfCoercion val intLit = Literal(1) val doubleLit = Literal(1.0) val trueLit = Literal.create(true, BooleanType) @@ -823,20 +823,20 @@ class TypeCoercionSuite extends AnalysisTest { } test("type coercion for CaseKeyWhen") { - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) - ruleTest(TypeCoercion.CaseWhenCoercion(conf), + ruleTest(TypeCoercion.CaseWhenCoercion, CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) ) - ruleTest(TypeCoercion.CaseWhenCoercion(conf), + ruleTest(TypeCoercion.CaseWhenCoercion, CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))), CaseWhen(Seq((Literal(true), Literal(1.2))), Cast(Literal.create(1, DecimalType(7, 2)), DoubleType)) ) - ruleTest(TypeCoercion.CaseWhenCoercion(conf), + ruleTest(TypeCoercion.CaseWhenCoercion, CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))), CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))), Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2))) @@ -1085,7 +1085,7 @@ class TypeCoercionSuite extends AnalysisTest { private val timeZoneResolver = ResolveTimeZone(new SQLConf) private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = { - timeZoneResolver(TypeCoercion.WidenSetOperationTypes(conf)(plan)) + timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan)) } test("WidenSetOperationTypes for except and intersect") { @@ -1256,7 +1256,7 @@ class TypeCoercionSuite extends AnalysisTest { test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " + "in aggregation function like sum") { - val rules = Seq(FunctionArgumentConversion(conf), Division) + val rules = Seq(FunctionArgumentConversion, Division) // Casts Integer to Double ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType)))) // Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will @@ -1275,7 +1275,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("SPARK-17117 null type coercion in divide") { - val rules = Seq(FunctionArgumentConversion(conf), Division, ImplicitTypeCasts(conf)) + val rules = Seq(FunctionArgumentConversion, Division, new ImplicitTypeCasts(conf)) val nullLit = Literal.create(null, NullType) ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index e2a1a57c7dd4d..c502e583a55c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} +import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -898,7 +898,6 @@ object SparkSession extends Logging { * @since 2.0.0 */ def getOrCreate(): SparkSession = synchronized { - assertOnDriver() // Get the session from current thread's active session. var session = activeThreadSession.get() if ((session ne null) && !session.sparkContext.isStopped) { @@ -1023,20 +1022,14 @@ object SparkSession extends Logging { * * @since 2.2.0 */ - def getActiveSession: Option[SparkSession] = { - assertOnDriver() - Option(activeThreadSession.get) - } + def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) /** * Returns the default SparkSession that is returned by the builder. * * @since 2.2.0 */ - def getDefaultSession: Option[SparkSession] = { - assertOnDriver() - Option(defaultSession.get) - } + def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) /** * Returns the currently active SparkSession, otherwise the default one. If there is no default @@ -1069,14 +1062,6 @@ object SparkSession extends Logging { } } - private def assertOnDriver(): Unit = { - if (Utils.isTesting && TaskContext.get != null) { - // we're accessing it during task execution, fail. - throw new IllegalStateException( - "SparkSession should only be created and accessed on the driver.") - } - } - /** * Helper method to create an instance of `SessionState` based on `className` from conf. * The result is either `SessionState` or a Hive based `SessionState`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 1edf27619ad7b..f9a24806953e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils @@ -522,8 +521,6 @@ object PartitioningUtils { private val findWiderTypeForPartitionColumn: (DataType, DataType) => DataType = { case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => StringType case (DoubleType, LongType) | (LongType, DoubleType) => StringType - case (t1, t2) => - TypeCoercion.findWiderTypeForTwo( - t1, t2, SQLConf.get.caseSensitiveAnalysis).getOrElse(StringType) + case (t1, t2) => TypeCoercion.findWiderTypeForTwo(t1, t2).getOrElse(StringType) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index e0424b7478122..a270a6451d5dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -45,7 +44,6 @@ private[sql] object JsonInferSchema { createParser: (JsonFactory, T) => JsonParser): StructType = { val parseMode = configOptions.parseMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord - val caseSensitive = SQLConf.get.caseSensitiveAnalysis // perform schema inference on each row and merge afterwards val rootType = json.mapPartitions { iter => @@ -55,7 +53,7 @@ private[sql] object JsonInferSchema { try { Utils.tryWithResource(createParser(factory, row)) { parser => parser.nextToken() - Some(inferField(parser, configOptions, caseSensitive)) + Some(inferField(parser, configOptions)) } } catch { case e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match { @@ -70,7 +68,7 @@ private[sql] object JsonInferSchema { } } }.fold(StructType(Nil))( - compatibleRootType(columnNameOfCorruptRecord, parseMode, caseSensitive)) + compatibleRootType(columnNameOfCorruptRecord, parseMode)) canonicalizeType(rootType) match { case Some(st: StructType) => st @@ -100,15 +98,14 @@ private[sql] object JsonInferSchema { /** * Infer the type of a json document from the parser's token stream */ - private def inferField( - parser: JsonParser, configOptions: JSONOptions, caseSensitive: Boolean): DataType = { + private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { import com.fasterxml.jackson.core.JsonToken._ parser.getCurrentToken match { case null | VALUE_NULL => NullType case FIELD_NAME => parser.nextToken() - inferField(parser, configOptions, caseSensitive) + inferField(parser, configOptions) case VALUE_STRING if parser.getTextLength < 1 => // Zero length strings and nulls have special handling to deal @@ -125,7 +122,7 @@ private[sql] object JsonInferSchema { while (nextUntil(parser, END_OBJECT)) { builder += StructField( parser.getCurrentName, - inferField(parser, configOptions, caseSensitive), + inferField(parser, configOptions), nullable = true) } val fields: Array[StructField] = builder.result() @@ -140,7 +137,7 @@ private[sql] object JsonInferSchema { var elementType: DataType = NullType while (nextUntil(parser, END_ARRAY)) { elementType = compatibleType( - elementType, inferField(parser, configOptions, caseSensitive), caseSensitive) + elementType, inferField(parser, configOptions)) } ArrayType(elementType) @@ -246,14 +243,13 @@ private[sql] object JsonInferSchema { */ private def compatibleRootType( columnNameOfCorruptRecords: String, - parseMode: ParseMode, - caseSensitive: Boolean): (DataType, DataType) => DataType = { + parseMode: ParseMode): (DataType, DataType) => DataType = { // Since we support array of json objects at the top level, // we need to check the element type and find the root level data type. case (ArrayType(ty1, _), ty2) => - compatibleRootType(columnNameOfCorruptRecords, parseMode, caseSensitive)(ty1, ty2) + compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) case (ty1, ArrayType(ty2, _)) => - compatibleRootType(columnNameOfCorruptRecords, parseMode, caseSensitive)(ty1, ty2) + compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) // Discard null/empty documents case (struct: StructType, NullType) => struct case (NullType, struct: StructType) => struct @@ -263,7 +259,7 @@ private[sql] object JsonInferSchema { withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode) // If we get anything else, we call compatibleType. // Usually, when we reach here, ty1 and ty2 are two StructTypes. - case (ty1, ty2) => compatibleType(ty1, ty2, caseSensitive) + case (ty1, ty2) => compatibleType(ty1, ty2) } private[this] val emptyStructFieldArray = Array.empty[StructField] @@ -271,8 +267,8 @@ private[sql] object JsonInferSchema { /** * Returns the most general data type for two given data types. */ - def compatibleType(t1: DataType, t2: DataType, caseSensitive: Boolean): DataType = { - TypeCoercion.findTightestCommonType(t1, t2, caseSensitive).getOrElse { + def compatibleType(t1: DataType, t2: DataType): DataType = { + TypeCoercion.findTightestCommonType(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. (t1, t2) match { // Double support larger range than fixed decimal, DecimalType.Maximum should be enough @@ -307,8 +303,7 @@ private[sql] object JsonInferSchema { val f2Name = fields2(f2Idx).name val comp = f1Name.compareTo(f2Name) if (comp == 0) { - val dataType = compatibleType( - fields1(f1Idx).dataType, fields2(f2Idx).dataType, caseSensitive) + val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType) newFields.add(StructField(f1Name, dataType, nullable = true)) f1Idx += 1 f2Idx += 1 @@ -331,17 +326,15 @@ private[sql] object JsonInferSchema { StructType(newFields.toArray(emptyStructFieldArray)) case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => - ArrayType( - compatibleType(elementType1, elementType2, caseSensitive), - containsNull1 || containsNull2) + ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) // The case that given `DecimalType` is capable of given `IntegralType` is handled in // `findTightestCommonTypeOfTwo`. Both cases below will be executed only when // the given `DecimalType` is not capable of the given `IntegralType`. case (t1: IntegralType, t2: DecimalType) => - compatibleType(DecimalType.forType(t1), t2, caseSensitive) + compatibleType(DecimalType.forType(t1), t2) case (t1: DecimalType, t2: IntegralType) => - compatibleType(t1, DecimalType.forType(t2), caseSensitive) + compatibleType(t1, DecimalType.forType(t2)) // strings and every string is a Json object. case (_, _) => StringType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 34d23ee53220d..4b3921c61a000 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -122,10 +122,10 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Get compatible type") { def checkDataType(t1: DataType, t2: DataType, expected: DataType) { - var actual = compatibleType(t1, t2, conf.caseSensitiveAnalysis) + var actual = compatibleType(t1, t2) assert(actual == expected, s"Expected $expected as the most general data type for $t1 and $t2, found $actual") - actual = compatibleType(t2, t1, conf.caseSensitiveAnalysis) + actual = compatibleType(t2, t1) assert(actual == expected, s"Expected $expected as the most general data type for $t1 and $t2, found $actual") } From 6fb7d6c4f71be0007942f7d1fc3099f1bcf8c52b Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 17 May 2018 00:40:39 +0800 Subject: [PATCH 0809/2461] [SPARK-24275][SQL] Revise doc comments in InputPartition ## What changes were proposed in this pull request? In #21145, DataReaderFactory is renamed to InputPartition. This PR is to revise wording in the comments to make it more clear. ## How was this patch tested? None Author: Gengliang Wang Closes #21326 from gengliangwang/revise_reader_comments. --- .../spark/sql/sources/v2/ReadSupport.java | 2 +- .../sql/sources/v2/ReadSupportWithSchema.java | 2 +- .../spark/sql/sources/v2/WriteSupport.java | 2 +- .../sql/sources/v2/reader/DataSourceReader.java | 16 ++++++++-------- .../sql/sources/v2/reader/InputPartition.java | 17 +++++++++-------- .../sql/sources/v2/writer/DataSourceWriter.java | 6 +++--- .../sources/v2/writer/DataWriterFactory.java | 2 +- 7 files changed, 24 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java index 0ea4dc6b5def3..b2526ded53d92 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java @@ -30,7 +30,7 @@ public interface ReadSupport extends DataSourceV2 { /** * Creates a {@link DataSourceReader} to scan the data from this data source. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. * * @param options the options for the returned data source reader, which is an immutable diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java index 3801402268af1..f31659904cc53 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java @@ -35,7 +35,7 @@ public interface ReadSupportWithSchema extends DataSourceV2 { /** * Create a {@link DataSourceReader} to scan the data from this data source. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. * * @param schema the full schema of this data source reader. Full schema usually maps to the diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java index cab56453816cc..83aeec0c47853 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java @@ -35,7 +35,7 @@ public interface WriteSupport extends DataSourceV2 { * Creates an optional {@link DataSourceWriter} to save the data to this data source. Data * sources can return None if there is no writing needed to be done according to the save mode. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. * * @param jobId A unique string for the writing job. It's possible that there are many writing diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java index f898c296e4245..36a3e542b5a11 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java @@ -31,7 +31,7 @@ * {@link ReadSupport#createReader(DataSourceOptions)} or * {@link ReadSupportWithSchema#createReader(StructType, DataSourceOptions)}. * It can mix in various query optimization interfaces to speed up the data scan. The actual scan - * logic is delegated to {@link InputPartition}s that are returned by + * logic is delegated to {@link InputPartition}s, which are returned by * {@link #planInputPartitions()}. * * There are mainly 3 kinds of query optimizations: @@ -45,8 +45,8 @@ * only one of them would be respected, according to the priority list from high to low: * {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}. * - * If an exception was throw when applying any of these query optimizations, the action would fail - * and no Spark job was submitted. + * If an exception was throw when applying any of these query optimizations, the action will fail + * and no Spark job will be submitted. * * Spark first applies all operator push-down optimizations that this data source supports. Then * Spark collects information this data source reported for further optimizations. Finally Spark @@ -59,21 +59,21 @@ public interface DataSourceReader { * Returns the actual schema of this data source reader, which may be different from the physical * schema of the underlying storage, as column pruning or other optimizations may happen. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ StructType readSchema(); /** - * Returns a list of read tasks. Each task is responsible for creating a data reader to - * output data for one RDD partition. That means the number of tasks returned here is same as - * the number of RDD partitions this scan outputs. + * Returns a list of {@link InputPartition}s. Each {@link InputPartition} is responsible for + * creating a data reader to output data of one RDD partition. The number of input partitions + * returned here is the same as the number of RDD partitions this scan outputs. * * Note that, this may not be a full scan if the data source reader mixes in other optimization * interfaces like column pruning, filter push-down, etc. These optimizations are applied before * Spark issues the scan request. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ List> planInputPartitions(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java index c581e3b5d0047..3524481784fea 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java @@ -23,13 +23,14 @@ /** * An input partition returned by {@link DataSourceReader#planInputPartitions()} and is - * responsible for creating the actual data reader. The relationship between - * {@link InputPartition} and {@link InputPartitionReader} + * responsible for creating the actual data reader of one RDD partition. + * The relationship between {@link InputPartition} and {@link InputPartitionReader} * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. * - * Note that input partitions will be serialized and sent to executors, then the partition reader - * will be created on executors and do the actual reading. So {@link InputPartition} must be - * serializable and {@link InputPartitionReader} doesn't need to be. + * Note that {@link InputPartition}s will be serialized and sent to executors, then + * {@link InputPartitionReader}s will be created on executors to do the actual reading. So + * {@link InputPartition} must be serializable while {@link InputPartitionReader} doesn't need to + * be. */ @InterfaceStability.Evolving public interface InputPartition extends Serializable { @@ -41,10 +42,10 @@ public interface InputPartition extends Serializable { * The location is a string representing the host name. * * Note that if a host name cannot be recognized by Spark, it will be ignored as it was not in - * the returned locations. By default this method returns empty string array, which means this - * task has no location preference. + * the returned locations. The default return value is empty string array, which means this + * input partition's reader has no location preference. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ default String[] preferredLocations() { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index 0a0fd8db58035..0030a9f05dba7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -34,8 +34,8 @@ * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. * - * If an exception was throw when applying any of these writing optimizations, the action would fail - * and no Spark job was submitted. + * If an exception was throw when applying any of these writing optimizations, the action will fail + * and no Spark job will be submitted. * * The writing procedure is: * 1. Create a writer factory by {@link #createWriterFactory()}, serialize and send it to all the @@ -58,7 +58,7 @@ public interface DataSourceWriter { /** * Creates a writer factory which will be serialized and sent to executors. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ DataWriterFactory createWriterFactory(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index c2c2ab73257e8..7527bcc0c4027 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -35,7 +35,7 @@ public interface DataWriterFactory extends Serializable { /** * Returns a data writer to do the actual writing work. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. * * @param partitionId A unique id of the RDD partition that the returned writer will process. From 8e60a16b73490007fe1c480d77cc09d760f0a02b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 16 May 2018 13:34:54 -0700 Subject: [PATCH 0810/2461] [SPARK-23601][BUILD][FOLLOW-UP] Keep md5 checksums for nexus artifacts. The repository.apache.org server still requires md5 checksums or it won't publish the staging repo. Author: Marcelo Vanzin Closes #21338 from vanzin/SPARK-23601. --- dev/create-release/release-build.sh | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index c00b00b845401..5faa3d3260a56 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -371,11 +371,18 @@ if [[ "$1" == "publish-release" ]]; then find . -type f |grep -v \.jar |grep -v \.pom | xargs rm echo "Creating hash and signature files" - # this must have .asc and .sha1 - it really doesn't like anything else there + # this must have .asc, .md5 and .sha1 - it really doesn't like anything else there for file in $(find . -type f) do echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --output $file.asc \ --detach-sig --armour $file; + if [ $(command -v md5) ]; then + # Available on OS X; -q to keep only hash + md5 -q $file > $file.md5 + else + # Available on Linux; cut to keep only hash + md5sum $file | cut -f1 -d' ' > $file.md5 + fi sha1sum $file | cut -f1 -d' ' > $file.sha1 done From 991726f31a8d182ed6d5b0e59185d97c0c5c532f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 16 May 2018 14:55:02 -0700 Subject: [PATCH 0811/2461] [SPARK-24158][SS] Enable no-data batches for streaming joins ## What changes were proposed in this pull request? This is a continuation of the larger task of enabling zero-data batches for more eager state cleanup. This PR enables it for stream-stream joins. ## How was this patch tested? - Updated join tests. Additionally, updated them to not use `CheckLastBatch` anywhere to set good precedence for future. Author: Tathagata Das Closes #21253 from tdas/SPARK-24158. --- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../StreamingSymmetricHashJoinExec.scala | 14 +- .../spark/sql/streaming/StreamTest.scala | 15 +- .../sql/streaming/StreamingJoinSuite.scala | 217 +++++++++--------- 4 files changed, 130 insertions(+), 118 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 82b4eb9fba242..37a0b9d6c8728 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -361,7 +361,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case Join(left, right, _, _) if left.isStreaming && right.isStreaming => throw new AnalysisException( - "Stream stream joins without equality predicate is not supported", plan = Some(plan)) + "Stream-stream join without equality predicate is not supported", plan = Some(plan)) case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index fa7c8ee906ecd..afa664eb76525 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -187,6 +187,17 @@ case class StreamingSymmetricHashJoinExec( s"${getClass.getSimpleName} should not take $x as the JoinType") } + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + val watermarkUsedForStateCleanup = + stateWatermarkPredicates.left.nonEmpty || stateWatermarkPredicates.right.nonEmpty + + // Latest watermark value is more than that used in this previous executed plan + val watermarkHasChanged = + eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + + watermarkUsedForStateCleanup && watermarkHasChanged + } + protected override def doExecute(): RDD[InternalRow] = { val stateStoreCoord = sqlContext.sessionState.streamingQueryManager.stateStoreCoordinator val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) @@ -319,8 +330,7 @@ case class StreamingSymmetricHashJoinExec( // outer join) if possible. In all cases, nothing needs to be outputted, hence the removal // needs to be done greedily by immediately consuming the returned iterator. val cleanupIter = joinType match { - case Inner => - leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() + case Inner => leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() case LeftOuter => rightSideJoiner.removeOldState() case RightOuter => leftSideJoiner.removeOldState() case _ => throwBadJoinTypeException() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 9d139a927bea5..f348dac1319cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -199,15 +199,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be case class CheckAnswerRowsByFunc( globalCheckFunction: Seq[Row] => Unit, lastOnly: Boolean) extends StreamAction with StreamMustBeRunning { - override def toString: String = s"$operatorName" - private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc" + override def toString: String = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc" } case class CheckNewAnswerRows(expectedAnswer: Seq[Row]) extends StreamAction with StreamMustBeRunning { - override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" - - private def operatorName = "CheckNewAnswer" + override def toString: String = s"CheckNewAnswer: ${expectedAnswer.mkString(",")}" } object CheckNewAnswer { @@ -218,6 +215,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be val toExternalRow = RowEncoder(encoder.schema).resolveAndBind() CheckNewAnswerRows((data +: moreData).map(d => toExternalRow.fromRow(encoder.toRow(d)))) } + + def apply(rows: Row*): CheckNewAnswerRows = CheckNewAnswerRows(rows) } /** Stops the stream. It must currently be running. */ @@ -747,7 +746,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be error => failTest(error) } } - pos += 1 } try { @@ -761,8 +759,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be currentStream.asInstanceOf[MicroBatchExecution].withProgressLocked { actns.foreach(executeAction) } + pos += 1 - case action: StreamAction => executeAction(action) + case action: StreamAction => + executeAction(action) + pos += 1 } if (streamThreadDeathCause != null) { failTest("Stream Thread Died", streamThreadDeathCause) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index da8f9608c1e9c..1f62357e6d09e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -62,20 +62,20 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(input1, 1), CheckAnswer(), AddData(input2, 1, 10), // 1 arrived on input1 first, then input2, should join - CheckLastBatch((1, 2, 3)), + CheckNewAnswer((1, 2, 3)), AddData(input1, 10), // 10 arrived on input2 first, then input1, should join - CheckLastBatch((10, 20, 30)), + CheckNewAnswer((10, 20, 30)), AddData(input2, 1), // another 1 in input2 should join with 1 input1 - CheckLastBatch((1, 2, 3)), + CheckNewAnswer((1, 2, 3)), StopStream, StartStream(), AddData(input1, 1), // multiple 1s should be kept in state causing multiple (1, 2, 3) - CheckLastBatch((1, 2, 3), (1, 2, 3)), + CheckNewAnswer((1, 2, 3), (1, 2, 3)), StopStream, StartStream(), AddData(input1, 100), AddData(input2, 100), - CheckLastBatch((100, 200, 300)) + CheckNewAnswer((100, 200, 300)) ) } @@ -97,25 +97,25 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( AddData(input1, 1), - CheckLastBatch(), + CheckNewAnswer(), AddData(input2, 1), - CheckLastBatch((1, 10, 2, 3)), + CheckNewAnswer((1, 10, 2, 3)), StopStream, StartStream(), AddData(input1, 25), - CheckLastBatch(), + CheckNewAnswer(), StopStream, StartStream(), AddData(input2, 25), - CheckLastBatch((25, 30, 50, 75)), + CheckNewAnswer((25, 30, 50, 75)), AddData(input1, 1), - CheckLastBatch((1, 10, 2, 3)), // State for 1 still around as there is no watermark + CheckNewAnswer((1, 10, 2, 3)), // State for 1 still around as there is no watermark StopStream, StartStream(), AddData(input1, 5), - CheckLastBatch(), + CheckNewAnswer(), AddData(input2, 5), - CheckLastBatch((5, 10, 10, 15)) // No filter by any watermark + CheckNewAnswer((5, 10, 10, 15)) // No filter by any watermark ) } @@ -142,27 +142,27 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with assertNumStateRows(total = 1, updated = 1), AddData(input2, 1), - CheckLastBatch((1, 10, 2, 3)), + CheckAnswer((1, 10, 2, 3)), assertNumStateRows(total = 2, updated = 1), StopStream, StartStream(), AddData(input1, 25), - CheckLastBatch(), // since there is only 1 watermark operator, the watermark should be 15 - assertNumStateRows(total = 3, updated = 1), + CheckNewAnswer(), // watermark = 15, no-data-batch should remove 2 rows having window=[0,10] + assertNumStateRows(total = 1, updated = 1), AddData(input2, 25), - CheckLastBatch((25, 30, 50, 75)), // watermark = 15 should remove 2 rows having window=[0,10] + CheckNewAnswer((25, 30, 50, 75)), assertNumStateRows(total = 2, updated = 1), StopStream, StartStream(), AddData(input2, 1), - CheckLastBatch(), // Should not join as < 15 removed - assertNumStateRows(total = 2, updated = 0), // row not add as 1 < state key watermark = 15 + CheckNewAnswer(), // Should not join as < 15 removed + assertNumStateRows(total = 2, updated = 0), // row not add as 1 < state key watermark = 15 AddData(input1, 5), - CheckLastBatch(), // Should not join or add to state as < 15 got filtered by watermark + CheckNewAnswer(), // Same reason as above assertNumStateRows(total = 2, updated = 0) ) } @@ -189,42 +189,39 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(leftInput, (1, 5)), CheckAnswer(), AddData(rightInput, (1, 11)), - CheckLastBatch((1, 5, 11)), + CheckNewAnswer((1, 5, 11)), AddData(rightInput, (1, 10)), - CheckLastBatch(), // no match as neither 5, nor 10 from leftTime is less than rightTime 10 - 5 + CheckNewAnswer(), // no match as leftTime 5 is not < rightTime 10 - 5 assertNumStateRows(total = 3, updated = 3), // Increase event time watermark to 20s by adding data with time = 30s on both inputs AddData(leftInput, (1, 3), (1, 30)), - CheckLastBatch((1, 3, 10), (1, 3, 11)), + CheckNewAnswer((1, 3, 10), (1, 3, 11)), assertNumStateRows(total = 5, updated = 2), AddData(rightInput, (0, 30)), - CheckLastBatch(), - assertNumStateRows(total = 6, updated = 1), + CheckNewAnswer(), // event time watermark: max event time - 10 ==> 30 - 10 = 20 + // so left side going to only receive data where leftTime > 20 // right side state constraint: 20 < leftTime < rightTime - 5 ==> rightTime > 25 - - // Run another batch with event time = 25 to clear right state where rightTime <= 25 - AddData(rightInput, (0, 30)), - CheckLastBatch(), - assertNumStateRows(total = 5, updated = 1), // removed (1, 11) and (1, 10), added (0, 30) + // right state where rightTime <= 25 will be cleared, (1, 11) and (1, 10) removed + assertNumStateRows(total = 4, updated = 1), // New data to right input should match with left side (1, 3) and (1, 5), as left state should // not be cleared. But rows rightTime <= 20 should be filtered due to event time watermark and // state rows with rightTime <= 25 should be removed from state. // (1, 20) ==> filtered by event time watermark = 20 // (1, 21) ==> passed filter, matched with left (1, 3) and (1, 5), not added to state - // as state watermark = 25 + // as 21 < state watermark = 25 // (1, 28) ==> passed filter, matched with left (1, 3) and (1, 5), added to state AddData(rightInput, (1, 20), (1, 21), (1, 28)), - CheckLastBatch((1, 3, 21), (1, 5, 21), (1, 3, 28), (1, 5, 28)), - assertNumStateRows(total = 6, updated = 1), + CheckNewAnswer((1, 3, 21), (1, 5, 21), (1, 3, 28), (1, 5, 28)), + assertNumStateRows(total = 5, updated = 1), // New data to left input with leftTime <= 20 should be filtered due to event time watermark AddData(leftInput, (1, 20), (1, 21)), - CheckLastBatch((1, 21, 28)), - assertNumStateRows(total = 7, updated = 1) + CheckNewAnswer((1, 21, 28)), + assertNumStateRows(total = 6, updated = 1) ) } @@ -275,38 +272,39 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(leftInput, (1, 20)), CheckAnswer(), AddData(rightInput, (1, 14), (1, 15), (1, 25), (1, 26), (1, 30), (1, 31)), - CheckLastBatch((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)), + CheckNewAnswer((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)), assertNumStateRows(total = 7, updated = 7), // If rightTime = 60, then it matches only leftTime = [50, 65] AddData(rightInput, (1, 60)), - CheckLastBatch(), // matches with nothing on the left + CheckNewAnswer(), // matches with nothing on the left AddData(leftInput, (1, 49), (1, 50), (1, 65), (1, 66)), - CheckLastBatch((1, 50, 60), (1, 65, 60)), - assertNumStateRows(total = 12, updated = 5), + CheckNewAnswer((1, 50, 60), (1, 65, 60)), // Event time watermark = min(left: 66 - delay 20 = 46, right: 60 - delay 30 = 30) = 30 // Left state value watermark = 30 - 10 = slightly less than 20 (since condition has <=) // Should drop < 20 from left, i.e., none // Right state value watermark = 30 - 5 = slightly less than 25 (since condition has <=) // Should drop < 25 from the right, i.e., 14 and 15 - AddData(leftInput, (1, 30), (1, 31)), // 30 should not be processed or added to stat - CheckLastBatch((1, 31, 26), (1, 31, 30), (1, 31, 31)), - assertNumStateRows(total = 11, updated = 1), // 12 - 2 removed + 1 added + assertNumStateRows(total = 10, updated = 5), // 12 - 2 removed + + AddData(leftInput, (1, 30), (1, 31)), // 30 should not be processed or added to state + CheckNewAnswer((1, 31, 26), (1, 31, 30), (1, 31, 31)), + assertNumStateRows(total = 11, updated = 1), // only 31 added // Advance the watermark AddData(rightInput, (1, 80)), - CheckLastBatch(), - assertNumStateRows(total = 12, updated = 1), - + CheckNewAnswer(), // Event time watermark = min(left: 66 - delay 20 = 46, right: 80 - delay 30 = 50) = 46 // Left state value watermark = 46 - 10 = slightly less than 36 (since condition has <=) // Should drop < 36 from left, i.e., 20, 31 (30 was not added) // Right state value watermark = 46 - 5 = slightly less than 41 (since condition has <=) // Should drop < 41 from the right, i.e., 25, 26, 30, 31 - AddData(rightInput, (1, 50)), - CheckLastBatch((1, 49, 50), (1, 50, 50)), - assertNumStateRows(total = 7, updated = 1) // 12 - 6 removed + 1 added + assertNumStateRows(total = 6, updated = 1), // 12 - 6 removed + + AddData(rightInput, (1, 46), (1, 50)), // 46 should not be processed or added to state + CheckNewAnswer((1, 49, 50), (1, 50, 50)), + assertNumStateRows(total = 7, updated = 1) // 50 added ) } @@ -322,7 +320,7 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with input1.addData(1) q.awaitTermination(10000) } - assert(e.toString.contains("Stream stream joins without equality predicate is not supported")) + assert(e.toString.contains("Stream-stream join without equality predicate is not supported")) } test("stream stream self join") { @@ -404,10 +402,11 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(input1, 1, 5), AddData(input2, 1, 5, 10), AddData(input3, 5, 10), - CheckLastBatch((5, 10, 5, 15, 5, 25))) + CheckNewAnswer((5, 10, 5, 15, 5, 25))) } } + class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { import testImplicits._ @@ -465,13 +464,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), // The left rows with leftValue <= 4 should generate their outer join row now and // not get added to the state. - CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)), + CheckNewAnswer(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)), assertNumStateRows(total = 4, updated = 4), // We shouldn't get more outer join rows when the watermark advances. MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + CheckNewAnswer(), AddData(rightInput, 20), - CheckLastBatch((20, 30, 40, "60")) + CheckNewAnswer((20, 30, 40, "60")) ) } @@ -492,15 +491,15 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), - // The right rows with value <= 7 should never be added to the state. - CheckLastBatch(Row(3, 10, 6, "9")), + // The right rows with rightValue <= 7 should never be added to the state. + CheckNewAnswer(Row(3, 10, 6, "9")), // rightValue = 9 > 7 hence joined and added to state assertNumStateRows(total = 4, updated = 4), // When the watermark advances, we get the outer join rows just as we would if they // were added but didn't match the full join condition. - MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + MultiAddData(leftInput, 20)(rightInput, 21), // watermark = 10, no-data-batch computes nulls + CheckNewAnswer(Row(4, 10, 8, null), Row(5, 10, 10, null)), AddData(rightInput, 20), - CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, 8, null), Row(5, 10, 10, null)) + CheckNewAnswer(Row(20, 30, 40, "60")) ) } @@ -521,15 +520,15 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), - // The left rows with value <= 4 should never be added to the state. - CheckLastBatch(Row(3, 10, 6, "9")), + // The left rows with leftValue <= 4 should never be added to the state. + CheckNewAnswer(Row(3, 10, 6, "9")), // leftValue = 7 > 4 hence joined and added to state assertNumStateRows(total = 4, updated = 4), // When the watermark advances, we get the outer join rows just as we would if they // were added but didn't match the full join condition. - MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + MultiAddData(leftInput, 20)(rightInput, 21), // watermark = 10, no-data-batch computes nulls + CheckNewAnswer(Row(4, 10, null, "12"), Row(5, 10, null, "15")), AddData(rightInput, 20), - CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, null, "12"), Row(5, 10, null, "15")) + CheckNewAnswer(Row(20, 30, 40, "60")) ) } @@ -552,13 +551,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), // The right rows with rightValue <= 7 should generate their outer join row now and // not get added to the state. - CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")), + CheckNewAnswer(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")), assertNumStateRows(total = 4, updated = 4), // We shouldn't get more outer join rows when the watermark advances. MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + CheckNewAnswer(), AddData(rightInput, 20), - CheckLastBatch((20, 30, 40, "60")) + CheckNewAnswer((20, 30, 40, "60")) ) } @@ -568,14 +567,14 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // Test inner part of the join. MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), - CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), - // Old state doesn't get dropped until the batch *after* it gets introduced, so the - // nulls won't show up until the next batch after the watermark advances. - MultiAddData(leftInput, 21)(rightInput, 22), - CheckLastBatch(), - assertNumStateRows(total = 12, updated = 12), + CheckNewAnswer((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), + + MultiAddData(leftInput, 21)(rightInput, 22), // watermark = 11, no-data-batch computes nulls + CheckNewAnswer(Row(1, 10, 2, null), Row(2, 10, 4, null)), + assertNumStateRows(total = 2, updated = 12), + AddData(leftInput, 22), - CheckLastBatch(Row(22, 30, 44, 66), Row(1, 10, 2, null), Row(2, 10, 4, null)), + CheckNewAnswer(Row(22, 30, 44, 66)), assertNumStateRows(total = 3, updated = 1) ) } @@ -586,14 +585,14 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // Test inner part of the join. MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), - CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), - // Old state doesn't get dropped until the batch *after* it gets introduced, so the - // nulls won't show up until the next batch after the watermark advances. - MultiAddData(leftInput, 21)(rightInput, 22), - CheckLastBatch(), - assertNumStateRows(total = 12, updated = 12), + CheckNewAnswer((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), + + MultiAddData(leftInput, 21)(rightInput, 22), // watermark = 11, no-data-batch computes nulls + CheckNewAnswer(Row(6, 10, null, 18), Row(7, 10, null, 21)), + assertNumStateRows(total = 2, updated = 12), + AddData(leftInput, 22), - CheckLastBatch(Row(22, 30, 44, 66), Row(6, 10, null, 18), Row(7, 10, null, 21)), + CheckNewAnswer(Row(22, 30, 44, 66)), assertNumStateRows(total = 3, updated = 1) ) } @@ -627,21 +626,18 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(leftInput, (1, 5), (3, 5)), CheckAnswer(), AddData(rightInput, (1, 10), (2, 5)), - CheckLastBatch((1, 1, 5, 10)), + CheckNewAnswer((1, 1, 5, 10)), AddData(rightInput, (1, 11)), - CheckLastBatch(), // no match as left time is too low + CheckNewAnswer(), // no match as left time is too low assertNumStateRows(total = 5, updated = 5), // Increase event time watermark to 20s by adding data with time = 30s on both inputs AddData(leftInput, (1, 7), (1, 30)), - CheckLastBatch((1, 1, 7, 10), (1, 1, 7, 11)), + CheckNewAnswer((1, 1, 7, 10), (1, 1, 7, 11)), assertNumStateRows(total = 7, updated = 2), - AddData(rightInput, (0, 30)), - CheckLastBatch(), - assertNumStateRows(total = 8, updated = 1), - AddData(rightInput, (0, 30)), - CheckLastBatch(outerResult), - assertNumStateRows(total = 3, updated = 1) + AddData(rightInput, (0, 30)), // watermark = 30 - 10 = 20, no-data-batch computes nulls + CheckNewAnswer(outerResult), + assertNumStateRows(total = 2, updated = 1) ) } } @@ -665,36 +661,41 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // leftValue <= 10 should generate outer join rows even though it matches right keys MultiAddData(leftInput, 1, 2, 3)(rightInput, 1, 2, 3), - CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), - MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), - assertNumStateRows(total = 5, updated = 5), // 1...3 added, but 20 and 21 not added + CheckNewAnswer(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), + assertNumStateRows(total = 3, updated = 3), // only right 1, 2, 3 added + + MultiAddData(leftInput, 20)(rightInput, 21), // watermark = 10, no-data-batch cleared < 10 + CheckNewAnswer(), + assertNumStateRows(total = 2, updated = 2), // only 20 and 21 left in state + AddData(rightInput, 20), - CheckLastBatch( - Row(20, 30, 40, 60)), + CheckNewAnswer(Row(20, 30, 40, 60)), assertNumStateRows(total = 3, updated = 1), + // leftValue and rightValue both satisfying condition should not generate outer join rows - MultiAddData(leftInput, 40, 41)(rightInput, 40, 41), - CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)), - MultiAddData(leftInput, 70)(rightInput, 71), - CheckLastBatch(), - assertNumStateRows(total = 6, updated = 6), // all inputs added since last check + MultiAddData(leftInput, 40, 41)(rightInput, 40, 41), // watermark = 31 + CheckNewAnswer((40, 50, 80, 120), (41, 50, 82, 123)), + assertNumStateRows(total = 4, updated = 4), // only left 40, 41 + right 40,41 left in state + + MultiAddData(leftInput, 70)(rightInput, 71), // watermark = 60 + CheckNewAnswer(), + assertNumStateRows(total = 2, updated = 2), // only 70, 71 left in state + AddData(rightInput, 70), - CheckLastBatch((70, 80, 140, 210)), + CheckNewAnswer((70, 80, 140, 210)), assertNumStateRows(total = 3, updated = 1), + // rightValue between 300 and 1000 should generate outer join rows even though it matches left - MultiAddData(leftInput, 101, 102, 103)(rightInput, 101, 102, 103), - CheckLastBatch(), + MultiAddData(leftInput, 101, 102, 103)(rightInput, 101, 102, 103), // watermark = 91 + CheckNewAnswer(), + assertNumStateRows(total = 6, updated = 3), // only 101 - 103 left in state + MultiAddData(leftInput, 1000)(rightInput, 1001), - CheckLastBatch(), - assertNumStateRows(total = 8, updated = 5), // 101...103 added, but 1000 and 1001 not added - AddData(rightInput, 1000), - CheckLastBatch( - Row(1000, 1010, 2000, 3000), + CheckNewAnswer( Row(101, 110, 202, null), Row(102, 110, 204, null), Row(103, 110, 206, null)), - assertNumStateRows(total = 3, updated = 1) + assertNumStateRows(total = 2, updated = 2) ) } } From bfd75cdfb22a8c2fb005da597621e1ccd3990e82 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Wed, 16 May 2018 17:54:06 -0700 Subject: [PATCH 0812/2461] [SPARK-22210][ML] Add seed for LDA variationalTopicInference ## What changes were proposed in this pull request? - Add seed parameter for variationalTopicInference - Add seed for calling variationalTopicInference in submitMiniBatch - Add var seed in LDAModel so that it can take the seed from LDA and use it for the function call of variationalTopicInference in logLikelihoodBound, topicDistributions, getTopicDistributionMethod, and topicDistribution. ## How was this patch tested? Check the test result in mllib.clustering.LDASuite to make sure the result is repeatable with the seed. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21183 from ludatabricks/SPARK-22210. --- .../org/apache/spark/ml/clustering/LDA.scala | 6 ++- .../spark/mllib/clustering/LDAModel.scala | 34 ++++++++++++--- .../spark/mllib/clustering/LDAOptimizer.scala | 42 +++++++++++-------- .../apache/spark/ml/clustering/LDASuite.scala | 6 +++ 4 files changed, 64 insertions(+), 24 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index afe599cd167cb..fed42c959b5ef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -569,10 +569,14 @@ abstract class LDAModel private[ml] ( class LocalLDAModel private[ml] ( uid: String, vocabSize: Int, - @Since("1.6.0") override private[clustering] val oldLocalModel: OldLocalLDAModel, + private[clustering] val oldLocalModel_ : OldLocalLDAModel, sparkSession: SparkSession) extends LDAModel(uid, vocabSize, sparkSession) { + override private[clustering] def oldLocalModel: OldLocalLDAModel = { + oldLocalModel_.setSeed(getSeed) + } + @Since("1.6.0") override def copy(extra: ParamMap): LocalLDAModel = { val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sparkSession) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index b8a6e94248421..f915062d77389 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -32,7 +32,7 @@ import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.util.BoundedPriorityQueue +import org.apache.spark.util.{BoundedPriorityQueue, Utils} /** * Latent Dirichlet Allocation (LDA) model. @@ -194,6 +194,8 @@ class LocalLDAModel private[spark] ( override protected[spark] val gammaShape: Double = 100) extends LDAModel with Serializable { + private var seed: Long = Utils.random.nextLong() + @Since("1.3.0") override def k: Int = topics.numCols @@ -216,6 +218,21 @@ class LocalLDAModel private[spark] ( override protected def formatVersion = "1.0" + /** + * Random seed for cluster initialization. + */ + @Since("2.4.0") + def getSeed: Long = seed + + /** + * Set the random seed for cluster initialization. + */ + @Since("2.4.0") + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } + @Since("1.5.0") override def save(sc: SparkContext, path: String): Unit = { LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration, @@ -298,6 +315,7 @@ class LocalLDAModel private[spark] ( // by topic (columns of lambda) val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t val ElogbetaBc = documents.sparkContext.broadcast(Elogbeta) + val gammaSeed = this.seed // Sum bound components for each document: // component for prob(tokens) + component for prob(document-topic distribution) @@ -306,7 +324,7 @@ class LocalLDAModel private[spark] ( val localElogbeta = ElogbetaBc.value var docBound = 0.0D val (gammad: BDV[Double], _, _) = OnlineLDAOptimizer.variationalTopicInference( - termCounts, exp(localElogbeta), brzAlpha, gammaShape, k) + termCounts, exp(localElogbeta), brzAlpha, gammaShape, k, gammaSeed + id) val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad) // E[log p(doc | theta, beta)] @@ -352,6 +370,7 @@ class LocalLDAModel private[spark] ( val docConcentrationBrz = this.docConcentration.asBreeze val gammaShape = this.gammaShape val k = this.k + val gammaSeed = this.seed documents.map { case (id: Long, termCounts: Vector) => if (termCounts.numNonzeros == 0) { @@ -362,7 +381,8 @@ class LocalLDAModel private[spark] ( expElogbetaBc.value, docConcentrationBrz, gammaShape, - k) + k, + gammaSeed + id) (id, Vectors.dense(normalize(gamma, 1.0).toArray)) } } @@ -376,6 +396,7 @@ class LocalLDAModel private[spark] ( val docConcentrationBrz = this.docConcentration.asBreeze val gammaShape = this.gammaShape val k = this.k + val gammaSeed = this.seed (termCounts: Vector) => if (termCounts.numNonzeros == 0) { @@ -386,7 +407,8 @@ class LocalLDAModel private[spark] ( expElogbeta, docConcentrationBrz, gammaShape, - k) + k, + gammaSeed) Vectors.dense(normalize(gamma, 1.0).toArray) } } @@ -403,6 +425,7 @@ class LocalLDAModel private[spark] ( */ @Since("2.0.0") def topicDistribution(document: Vector): Vector = { + val gammaSeed = this.seed val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.asBreeze.toDenseMatrix.t).t) if (document.numNonzeros == 0) { Vectors.zeros(this.k) @@ -412,7 +435,8 @@ class LocalLDAModel private[spark] ( expElogbeta, this.docConcentration.asBreeze, gammaShape, - this.k) + this.k, + gammaSeed) Vectors.dense(normalize(gamma, 1.0).toArray) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 693a2a31f026b..f8e5f3ed76457 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -464,6 +465,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging { val alpha = this.alpha.asBreeze val gammaShape = this.gammaShape val optimizeDocConcentration = this.optimizeDocConcentration + val seed = randomGenerator.nextLong() // If and only if optimizeDocConcentration is set true, // we calculate logphat in the same pass as other statistics. // No calculation of loghat happens otherwise. @@ -473,20 +475,21 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging { None } - val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitions { docs => - val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0) - - val stat = BDM.zeros[Double](k, vocabSize) - val logphatPartOption = logphatPartOptionBase() - var nonEmptyDocCount: Long = 0L - nonEmptyDocs.foreach { case (_, termCounts: Vector) => - nonEmptyDocCount += 1 - val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference( - termCounts, expElogbetaBc.value, alpha, gammaShape, k) - stat(::, ids) := stat(::, ids) + sstats - logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad)) - } - Iterator((stat, logphatPartOption, nonEmptyDocCount)) + val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitionsWithIndex { + (index, docs) => + val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0) + + val stat = BDM.zeros[Double](k, vocabSize) + val logphatPartOption = logphatPartOptionBase() + var nonEmptyDocCount: Long = 0L + nonEmptyDocs.foreach { case (_, termCounts: Vector) => + nonEmptyDocCount += 1 + val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, expElogbetaBc.value, alpha, gammaShape, k, seed + index) + stat(::, ids) := stat(::, ids) + sstats + logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad)) + } + Iterator((stat, logphatPartOption, nonEmptyDocCount)) } val elementWiseSum = ( @@ -578,7 +581,8 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging { } override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { - new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta, gammaShape) + new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta) + .setSeed(randomGenerator.nextLong()) } } @@ -605,18 +609,20 @@ private[clustering] object OnlineLDAOptimizer { expElogbeta: BDM[Double], alpha: breeze.linalg.Vector[Double], gammaShape: Double, - k: Int): (BDV[Double], BDM[Double], List[Int]) = { + k: Int, + seed: Long): (BDV[Double], BDM[Double], List[Int]) = { val (ids: List[Int], cts: Array[Double]) = termCounts match { case v: DenseVector => ((0 until v.size).toList, v.values) case v: SparseVector => (v.indices.toList, v.values) } // Initialize the variational distribution q(theta|gamma) for the mini-batch + val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(seed)) val gammad: BDV[Double] = - new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K + new Gamma(gammaShape, 1.0 / gammaShape)(randBasis).samplesVector(k) // K val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K val expElogbetad = expElogbeta(ids, ::).toDenseMatrix // ids * K - val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids + val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids var meanGammaChange = 1D val ctsVector = new BDV[Double](cts) // ids diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 8d728f063dd8c..4d848205034c0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -253,6 +253,12 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val lda = new LDA() testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, LDASuite.allParamSettings, checkModelData) + + // Make sure the result is deterministic after saving and loading the model + val model = lda.fit(dataset) + val model2 = testDefaultReadWrite(model) + assert(model.logLikelihood(dataset) ~== model2.logLikelihood(dataset) absTol 1e-6) + assert(model.logPerplexity(dataset) ~== model2.logPerplexity(dataset) absTol 1e-6) } test("read/write DistributedLDAModel") { From 9a641e7f721d01d283afb09dccefaf32972d3c04 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 17 May 2018 12:07:58 +0800 Subject: [PATCH 0813/2461] [SPARK-21945][YARN][PYTHON] Make --py-files work with PySpark shell in Yarn client mode ## What changes were proposed in this pull request? ### Problem When we run _PySpark shell with Yarn client mode_, specified `--py-files` are not recognised in _driver side_. Here are the steps I took to check: ```bash $ cat /home/spark/tmp.py def testtest(): return 1 ``` ```bash $ ./bin/pyspark --master yarn --deploy-mode client --py-files /home/spark/tmp.py ``` ```python >>> def test(): ... import tmp ... return tmp.testtest() ... >>> spark.range(1).rdd.map(lambda _: test()).collect() # executor side [1] >>> test() # driver side Traceback (most recent call last): File "", line 1, in File "", line 2, in test ImportError: No module named tmp ``` ### How did it happen? Unlike Yarn cluster and client mode with Spark submit, when Yarn client mode with PySpark shell specifically, 1. It first runs Python shell via: https://github.com/apache/spark/blob/3cb82047f2f51af553df09b9323796af507d36f8/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java#L158 as pointed out by tgravescs in the JIRA. 2. this triggers shell.py and submit another application to launch a py4j gateway: https://github.com/apache/spark/blob/209b9361ac8a4410ff797cff1115e1888e2f7e66/python/pyspark/java_gateway.py#L45-L60 3. it runs a Py4J gateway: https://github.com/apache/spark/blob/3cb82047f2f51af553df09b9323796af507d36f8/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala#L425 4. it copies (or downloads) --py-files into local temp directory: https://github.com/apache/spark/blob/3cb82047f2f51af553df09b9323796af507d36f8/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala#L365-L376 and then these files are set up to `spark.submit.pyFiles` 5. Py4J JVM is launched and then the Python paths are set via: https://github.com/apache/spark/blob/7013eea11cb32b1e0038dc751c485da5c94a484b/python/pyspark/context.py#L209-L216 However, these are not actually set because those files were copied into a tmp directory in 4. whereas this code path looks for `SparkFiles.getRootDirectory` where the files are stored only when `SparkContext.addFile()` is called. In other cluster mode, `spark.files` are set via: https://github.com/apache/spark/blob/3cb82047f2f51af553df09b9323796af507d36f8/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala#L554-L555 and those files are explicitly added via: https://github.com/apache/spark/blob/ecb8b383af1cf1b67f3111c148229e00c9c17c40/core/src/main/scala/org/apache/spark/SparkContext.scala#L395 So we are fine in other modes. In case of Yarn client and cluster with _submit_, these are manually being handled. In particular https://github.com/apache/spark/pull/6360 added most of the logics. In this case, the Python path looks manually set via, for example, `deploy.PythonRunner`. We don't use `spark.files` here. ### How does the PR fix the problem? I tried to make an isolated approach as possible as I can: simply copy py file or zip files into `SparkFiles.getRootDirectory()` in driver side if not existing. Another possible way is to set `spark.files` but it does unnecessary stuff together and sounds a bit invasive. **Before** ```python >>> def test(): ... import tmp ... return tmp.testtest() ... >>> spark.range(1).rdd.map(lambda _: test()).collect() [1] >>> test() Traceback (most recent call last): File "", line 1, in File "", line 2, in test ImportError: No module named tmp ``` **After** ```python >>> def test(): ... import tmp ... return tmp.testtest() ... >>> spark.range(1).rdd.map(lambda _: test()).collect() [1] >>> test() 1 ``` ## How was this patch tested? I manually tested in standalone and yarn cluster with PySpark shell. .zip and .py files were also tested with the similar steps above. It's difficult to add a test. Author: hyukjinkwon Closes #21267 from HyukjinKwon/SPARK-21945. --- python/pyspark/context.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index dbb463f6005a1..ede3b6af0a8cf 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -211,9 +211,21 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, for path in self._conf.get("spark.submit.pyFiles", "").split(","): if path != "": (dirname, filename) = os.path.split(path) - if filename[-4:].lower() in self.PACKAGE_EXTENSIONS: - self._python_includes.append(filename) - sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) + try: + filepath = os.path.join(SparkFiles.getRootDirectory(), filename) + if not os.path.exists(filepath): + # In case of YARN with shell mode, 'spark.submit.pyFiles' files are + # not added via SparkContext.addFile. Here we check if the file exists, + # try to copy and then add it to the path. See SPARK-21945. + shutil.copyfile(path, filepath) + if filename[-4:].lower() in self.PACKAGE_EXTENSIONS: + self._python_includes.append(filename) + sys.path.insert(1, filepath) + except Exception: + warnings.warn( + "Failed to add file [%s] speficied in 'spark.submit.pyFiles' to " + "Python path:\n %s" % (path, "\n ".join(sys.path)), + RuntimeWarning) # Create a temporary directory inside spark.local.dir: local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf()) From 3e66350c2477a456560302b7738c9d122d5d9c43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florent=20P=C3=A9pin?= Date: Thu, 17 May 2018 13:31:14 +0900 Subject: [PATCH 0814/2461] [SPARK-23925][SQL] Add array_repeat collection function MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? The PR adds a new collection function, array_repeat. As there already was a function repeat with the same signature, with the only difference being the expected return type (String instead of Array), the new function is called array_repeat to distinguish. The behaviour of the function is based on Presto's one. The function creates an array containing a given element repeated the requested number of times. ## How was this patch tested? New unit tests added into: - CollectionExpressionsSuite - DataFrameFunctionsSuite Author: Florent Pépin Author: Florent Pépin Closes #21208 from pepinoflo/SPARK-23925. --- python/pyspark/sql/functions.py | 14 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 149 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 18 +++ .../org/apache/spark/sql/functions.scala | 20 +++ .../spark/sql/DataFrameFunctionsSuite.scala | 76 +++++++++ 6 files changed, 278 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 6866c1cf9f882..925ac34196f4c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2329,6 +2329,20 @@ def map_values(col): return Column(sc._jvm.functions.map_values(_to_java_column(col))) +@ignore_unicode_prefix +@since(2.4) +def array_repeat(col, count): + """ + Collection function: creates an array containing a column repeated count times. + + >>> df = spark.createDataFrame([('ab',)], ['data']) + >>> df.select(array_repeat(df.data, 3).alias('r')).collect() + [Row(r=[u'ab', u'ab', u'ab'])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_repeat(_to_java_column(col), count)) + + # ---------------------------- User Defined Function ---------------------------------- class PandasUDFType(object): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 087d000a9db70..9c370599bc0df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -427,6 +427,7 @@ object FunctionRegistry { expression[Reverse]("reverse"), expression[Concat]("concat"), expression[Flatten]("flatten"), + expression[ArrayRepeat]("array_repeat"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 12b9ab2b272ab..2a4e42d4ba316 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1468,3 +1468,152 @@ case class Flatten(child: Expression) extends UnaryExpression { override def prettyName: String = "flatten" } + +/** + * Returns the array containing the given input value (left) count (right) times. + */ +@ExpressionDescription( + usage = "_FUNC_(element, count) - Returns the array containing element count times.", + examples = """ + Examples: + > SELECT _FUNC_('123', 2); + ['123', '123'] + """, + since = "2.4.0") +case class ArrayRepeat(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + + private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + override def dataType: ArrayType = ArrayType(left.dataType, left.nullable) + + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType) + + override def nullable: Boolean = right.nullable + + override def eval(input: InternalRow): Any = { + val count = right.eval(input) + if (count == null) { + null + } else { + if (count.asInstanceOf[Int] > MAX_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to create array with $count elements " + + s"due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + } + val element = left.eval(input) + new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element)) + } + } + + override def prettyName: String = "array_repeat" + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val leftGen = left.genCode(ctx) + val rightGen = right.genCode(ctx) + val element = leftGen.value + val count = rightGen.value + val et = dataType.elementType + + val coreLogic = if (CodeGenerator.isPrimitiveType(et)) { + genCodeForPrimitiveElement(ctx, et, element, count, leftGen.isNull, ev.value) + } else { + genCodeForNonPrimitiveElement(ctx, element, count, leftGen.isNull, ev.value) + } + val resultCode = nullElementsProtection(ev, rightGen.isNull, coreLogic) + + ev.copy(code = + s""" + |boolean ${ev.isNull} = false; + |${leftGen.code} + |${rightGen.code} + |${CodeGenerator.javaType(dataType)} ${ev.value} = + | ${CodeGenerator.defaultValue(dataType)}; + |$resultCode + """.stripMargin) + } + + private def nullElementsProtection( + ev: ExprCode, + rightIsNull: String, + coreLogic: String): String = { + if (nullable) { + s""" + |if ($rightIsNull) { + | ${ev.isNull} = true; + |} else { + | ${coreLogic} + |} + """.stripMargin + } else { + coreLogic + } + } + + private def genCodeForNumberOfElements(ctx: CodegenContext, count: String): (String, String) = { + val numElements = ctx.freshName("numElements") + val numElementsCode = + s""" + |int $numElements = 0; + |if ($count > 0) { + | $numElements = $count; + |} + |if ($numElements > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to create array with " + $numElements + + | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |} + """.stripMargin + + (numElements, numElementsCode) + } + + private def genCodeForPrimitiveElement( + ctx: CodegenContext, + elementType: DataType, + element: String, + count: String, + leftIsNull: String, + arrayDataName: String): String = { + val tempArrayDataName = ctx.freshName("tempArrayData") + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val errorMessage = s" $prettyName failed." + val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count) + + s""" + |$numElemCode + |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, errorMessage)} + |if (!$leftIsNull) { + | for (int k = 0; k < $tempArrayDataName.numElements(); k++) { + | $tempArrayDataName.set$primitiveValueTypeName(k, $element); + | } + |} else { + | for (int k = 0; k < $tempArrayDataName.numElements(); k++) { + | $tempArrayDataName.setNullAt(k); + | } + |} + |$arrayDataName = $tempArrayDataName; + """.stripMargin + } + + private def genCodeForNonPrimitiveElement( + ctx: CodegenContext, + element: String, + count: String, + leftIsNull: String, + arrayDataName: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayName = ctx.freshName("arrayObject") + val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count) + + s""" + |$numElemCode + |Object[] $arrayName = new Object[(int)$numElemName]; + |if (!$leftIsNull) { + | for (int k = 0; k < $numElemName; k++) { + | $arrayName[k] = $element; + | } + |} + |$arrayDataName = new $genericArrayClass($arrayName); + """.stripMargin + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index a2851d071c7c6..57fc5f75dbca7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -468,4 +468,22 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Flatten(asa3), null) checkEvaluation(Flatten(asa4), null) } + + test("ArrayRepeat") { + val intArray = Literal.create(Seq(1, 2), ArrayType(IntegerType)) + val strArray = Literal.create(Seq("hi", "hola"), ArrayType(StringType)) + + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(0)), Seq()) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(-1)), Seq()) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(1)), Seq("hi")) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(2)), Seq("hi", "hi")) + checkEvaluation(ArrayRepeat(Literal(true), Literal(2)), Seq(true, true)) + checkEvaluation(ArrayRepeat(Literal(1), Literal(2)), Seq(1, 1)) + checkEvaluation(ArrayRepeat(Literal(3.2), Literal(2)), Seq(3.2, 3.2)) + checkEvaluation(ArrayRepeat(Literal(null), Literal(2)), Seq[String](null, null)) + checkEvaluation(ArrayRepeat(Literal(null, IntegerType), Literal(2)), Seq[Integer](null, null)) + checkEvaluation(ArrayRepeat(intArray, Literal(2)), Seq(Seq(1, 2), Seq(1, 2))) + checkEvaluation(ArrayRepeat(strArray, Literal(2)), Seq(Seq("hi", "hola"), Seq("hi", "hola"))) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b71dfdad8aa9b..550571a61a036 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3447,6 +3447,26 @@ object functions { */ def flatten(e: Column): Column = withExpr { Flatten(e.expr) } + /** + * Creates an array containing the left argument repeated the number of times given by the + * right argument. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_repeat(left: Column, right: Column): Column = withExpr { + ArrayRepeat(left.expr, right.expr) + } + + /** + * Creates an array containing the left argument repeated the number of times given by the + * right argument. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_repeat(e: Column, count: Int): Column = array_repeat(e, lit(count)) + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index ecce06f4c0755..e26565cd153b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -843,6 +843,82 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } } + test("array_repeat function") { + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // to switch codeGen on + val strDF = Seq( + ("hi", 2), + (null, 2) + ).toDF("a", "b") + + val strDFTwiceResult = Seq( + Row(Seq("hi", "hi")), + Row(Seq(null, null)) + ) + + checkAnswer(strDF.select(array_repeat($"a", 2)), strDFTwiceResult) + checkAnswer(strDF.filter(dummyFilter($"a")).select(array_repeat($"a", 2)), strDFTwiceResult) + checkAnswer(strDF.select(array_repeat($"a", $"b")), strDFTwiceResult) + checkAnswer(strDF.filter(dummyFilter($"a")).select(array_repeat($"a", $"b")), strDFTwiceResult) + checkAnswer(strDF.selectExpr("array_repeat(a, 2)"), strDFTwiceResult) + checkAnswer(strDF.selectExpr("array_repeat(a, b)"), strDFTwiceResult) + + val intDF = { + val schema = StructType(Seq( + StructField("a", IntegerType), + StructField("b", IntegerType))) + val data = Seq( + Row(3, 2), + Row(null, 2) + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + + val intDFTwiceResult = Seq( + Row(Seq(3, 3)), + Row(Seq(null, null)) + ) + + checkAnswer(intDF.select(array_repeat($"a", 2)), intDFTwiceResult) + checkAnswer(intDF.filter(dummyFilter($"a")).select(array_repeat($"a", 2)), intDFTwiceResult) + checkAnswer(intDF.select(array_repeat($"a", $"b")), intDFTwiceResult) + checkAnswer(intDF.filter(dummyFilter($"a")).select(array_repeat($"a", $"b")), intDFTwiceResult) + checkAnswer(intDF.selectExpr("array_repeat(a, 2)"), intDFTwiceResult) + checkAnswer(intDF.selectExpr("array_repeat(a, b)"), intDFTwiceResult) + + val nullCountDF = { + val schema = StructType(Seq( + StructField("a", StringType), + StructField("b", IntegerType))) + val data = Seq( + Row("hi", null), + Row(null, null) + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + + checkAnswer( + nullCountDF.select(array_repeat($"a", $"b")), + Seq( + Row(null), + Row(null) + ) + ) + + // Error test cases + val invalidTypeDF = Seq(("hi", "1")).toDF("a", "b") + + intercept[AnalysisException] { + invalidTypeDF.select(array_repeat($"a", $"b")) + } + intercept[AnalysisException] { + invalidTypeDF.select(array_repeat($"a", lit("1"))) + } + intercept[AnalysisException] { + invalidTypeDF.selectExpr("array_repeat(a, 1.0)") + } + + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 6c35865d949a8b46f654cd53c7e5f3288def18d0 Mon Sep 17 00:00:00 2001 From: Artem Rudoy Date: Thu, 17 May 2018 18:49:46 +0800 Subject: [PATCH 0815/2461] [SPARK-22371][CORE] Return None instead of throwing an exception when an accumulator is garbage collected. ## What changes were proposed in this pull request? There's a period of time when an accumulator has been garbage collected, but hasn't been removed from AccumulatorContext.originals by ContextCleaner. When an update is received for such accumulator it will throw an exception and kill the whole job. This can happen when a stage completes, but there're still running tasks from other attempts, speculation etc. Since AccumulatorContext.get() returns an option we can just return None in such case. ## How was this patch tested? Unit test. Author: Artem Rudoy Closes #21114 from artemrd/SPARK-22371. --- .../org/apache/spark/util/AccumulatorV2.scala | 14 +++++++++----- .../scala/org/apache/spark/AccumulatorSuite.scala | 6 ++---- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 2bc84953a56eb..3b469a69437b9 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -24,6 +24,7 @@ import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicLong import org.apache.spark.{InternalAccumulator, SparkContext, TaskContext} +import org.apache.spark.internal.Logging import org.apache.spark.scheduler.AccumulableInfo private[spark] case class AccumulatorMetadata( @@ -211,7 +212,7 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { /** * An internal class used to track accumulators by Spark itself. */ -private[spark] object AccumulatorContext { +private[spark] object AccumulatorContext extends Logging { /** * This global map holds the original accumulator objects that are created on the driver. @@ -258,13 +259,16 @@ private[spark] object AccumulatorContext { * Returns the [[AccumulatorV2]] registered with the given ID, if any. */ def get(id: Long): Option[AccumulatorV2[_, _]] = { - Option(originals.get(id)).map { ref => - // Since we are storing weak references, we must check whether the underlying data is valid. + val ref = originals.get(id) + if (ref eq null) { + None + } else { + // Since we are storing weak references, warn when the underlying data is not valid. val acc = ref.get if (acc eq null) { - throw new IllegalStateException(s"Attempted to access garbage collected accumulator $id") + logWarning(s"Attempted to access garbage collected accumulator $id") } - acc + Option(acc) } } diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 3990ee1ec326d..5d0ffd92647bc 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -209,10 +209,8 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex System.gc() assert(ref.get.isEmpty) - // Getting a garbage collected accum should throw error - intercept[IllegalStateException] { - AccumulatorContext.get(accId) - } + // Getting a garbage collected accum should return None. + assert(AccumulatorContext.get(accId).isEmpty) // Getting a normal accumulator. Note: this has to be separate because referencing an // accumulator above in an `assert` would keep it from being garbage collected. From 6ec05826d7b0a512847e2522564e01256c8d192d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 17 May 2018 20:42:40 +0800 Subject: [PATCH 0816/2461] [SPARK-24107][CORE][FOLLOWUP] ChunkedByteBuffer.writeFully method has not reset the limit value ## What changes were proposed in this pull request? According to the discussion in https://github.com/apache/spark/pull/21175 , this PR proposes 2 improvements: 1. add comments to explain why we call `limit` to write out `ByteBuffer` with slices. 2. remove the `try ... finally` ## How was this patch tested? existing tests Author: Wenchen Fan Closes #21327 from cloud-fan/minor. --- .../spark/util/io/ChunkedByteBuffer.scala | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 3ae8dfcc1cb66..700ce56466c35 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -63,15 +63,19 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { */ def writeFully(channel: WritableByteChannel): Unit = { for (bytes <- getChunks()) { - val curChunkLimit = bytes.limit() + val originalLimit = bytes.limit() while (bytes.hasRemaining) { - try { - val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) - bytes.limit(bytes.position() + ioSize) - channel.write(bytes) - } finally { - bytes.limit(curChunkLimit) - } + // If `bytes` is an on-heap ByteBuffer, the Java NIO API will copy it to a temporary direct + // ByteBuffer when writing it out. This temporary direct ByteBuffer is cached per thread. + // Its size has no limit and can keep growing if it sees a larger input ByteBuffer. This may + // cause significant native memory leak, if a large direct ByteBuffer is allocated and + // cached, as it's never released until thread exits. Here we write the `bytes` with + // fixed-size slices to limit the size of the cached direct ByteBuffer. + // Please refer to http://www.evanjones.ca/java-bytebuffer-leak.html for more details. + val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) + bytes.limit(bytes.position() + ioSize) + channel.write(bytes) + bytes.limit(originalLimit) } } } From 69350aa2f0a7aee4dcb1067f073b61a0b9f9cb51 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 17 May 2018 20:45:32 +0800 Subject: [PATCH 0817/2461] [SPARK-23922][SQL] Add arrays_overlap function ## What changes were proposed in this pull request? The PR adds the function `arrays_overlap`. This function returns `true` if the input arrays contain a non-null common element; if not, it returns `null` if any of the arrays contains a `null` element, `false` otherwise. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21028 from mgaido91/SPARK-23922. --- python/pyspark/sql/functions.py | 15 + .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 267 +++++++++++++++++- .../CollectionExpressionsSuite.scala | 66 +++++ .../org/apache/spark/sql/functions.scala | 11 + .../spark/sql/DataFrameFunctionsSuite.scala | 29 ++ 6 files changed, 388 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 925ac34196f4c..8490081facc5a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1855,6 +1855,21 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(2.4) +def arrays_overlap(a1, a2): + """ + Collection function: returns true if the arrays contain any common non-null element; if not, + returns null if both the arrays are non-empty and any of them contains a null element; returns + false otherwise. + + >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y']) + >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect() + [Row(overlap=True), Row(overlap=False)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.arrays_overlap(_to_java_column(a1), _to_java_column(a2))) + + @since(2.4) def slice(x, start, length): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 9c370599bc0df..867c2d5eab53d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -410,6 +410,7 @@ object FunctionRegistry { // collection functions expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), + expression[ArraysOverlap]("arrays_overlap"), expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), expression[ArraySort]("array_sort"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2a4e42d4ba316..c82db839438ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -18,15 +18,51 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator +import scala.collection.mutable + import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +/** + * Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit + * casting. + */ +trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression + with ImplicitCastInputTypes { + + @transient protected lazy val elementType: DataType = + inputTypes.head.asInstanceOf[ArrayType].elementType + + override def inputTypes: Seq[AbstractDataType] = { + (left.dataType, right.dataType) match { + case (ArrayType(e1, hasNull1), ArrayType(e2, hasNull2)) => + TypeCoercion.findTightestCommonType(e1, e2) match { + case Some(dt) => Seq(ArrayType(dt, hasNull1), ArrayType(dt, hasNull2)) + case _ => Seq.empty + } + case _ => Seq.empty + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (ArrayType(e1, _), ArrayType(e2, _)) if e1.sameType(e2) => + TypeCheckResult.TypeCheckSuccess + case _ => TypeCheckResult.TypeCheckFailure(s"input to function $prettyName should have " + + s"been two ${ArrayType.simpleString}s with same element type, but it's " + + s"[${left.dataType.simpleString}, ${right.dataType.simpleString}]") + } + } +} + + /** * Given an array or map, returns its size. Returns -1 if null. */ @@ -529,6 +565,235 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Checks if the two arrays contain at least one common element. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least a non-null element present also in a2. If the arrays have no common element and they are both non-empty and either of them contains a null element null is returned, false otherwise.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(3, 4, 5)); + true + """, since = "2.4.0") +// scalastyle:off line.size.limit +case class ArraysOverlap(left: Expression, right: Expression) + extends BinaryArrayExpressionWithImplicitCast { + + override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + if (RowOrdering.isOrderable(elementType)) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"${elementType.simpleString} cannot be used in comparison.") + } + case failure => failure + } + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + @transient private lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } + + @transient private lazy val doEvaluation = if (elementTypeSupportEquals) { + fastEval _ + } else { + bruteForceEval _ + } + + override def dataType: DataType = BooleanType + + override def nullable: Boolean = { + left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull || + right.dataType.asInstanceOf[ArrayType].containsNull + } + + override def nullSafeEval(a1: Any, a2: Any): Any = { + doEvaluation(a1.asInstanceOf[ArrayData], a2.asInstanceOf[ArrayData]) + } + + /** + * A fast implementation which puts all the elements from the smaller array in a set + * and then performs a lookup on it for each element of the bigger one. + * This eval mode works only for data types which implements properly the equals method. + */ + private def fastEval(arr1: ArrayData, arr2: ArrayData): Any = { + var hasNull = false + val (bigger, smaller) = if (arr1.numElements() > arr2.numElements()) { + (arr1, arr2) + } else { + (arr2, arr1) + } + if (smaller.numElements() > 0) { + val smallestSet = new mutable.HashSet[Any] + smaller.foreach(elementType, (_, v) => + if (v == null) { + hasNull = true + } else { + smallestSet += v + }) + bigger.foreach(elementType, (_, v1) => + if (v1 == null) { + hasNull = true + } else if (smallestSet.contains(v1)) { + return true + } + ) + } + if (hasNull) { + null + } else { + false + } + } + + /** + * A slower evaluation which performs a nested loop and supports all the data types. + */ + private def bruteForceEval(arr1: ArrayData, arr2: ArrayData): Any = { + var hasNull = false + if (arr1.numElements() > 0 && arr2.numElements() > 0) { + arr1.foreach(elementType, (_, v1) => + if (v1 == null) { + hasNull = true + } else { + arr2.foreach(elementType, (_, v2) => + if (v2 == null) { + hasNull = true + } else if (ordering.equiv(v1, v2)) { + return true + } + ) + }) + } + if (hasNull) { + null + } else { + false + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (a1, a2) => { + val smaller = ctx.freshName("smallerArray") + val bigger = ctx.freshName("biggerArray") + val comparisonCode = if (elementTypeSupportEquals) { + fastCodegen(ctx, ev, smaller, bigger) + } else { + bruteForceCodegen(ctx, ev, smaller, bigger) + } + s""" + |ArrayData $smaller; + |ArrayData $bigger; + |if ($a1.numElements() > $a2.numElements()) { + | $bigger = $a1; + | $smaller = $a2; + |} else { + | $smaller = $a1; + | $bigger = $a2; + |} + |if ($smaller.numElements() > 0) { + | $comparisonCode + |} + """.stripMargin + }) + } + + /** + * Code generation for a fast implementation which puts all the elements from the smaller array + * in a set and then performs a lookup on it for each element of the bigger one. + * It works only for data types which implements properly the equals method. + */ + private def fastCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = { + val i = ctx.freshName("i") + val getFromSmaller = CodeGenerator.getValue(smaller, elementType, i) + val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) + val javaElementClass = CodeGenerator.boxedType(elementType) + val javaSet = classOf[java.util.HashSet[_]].getName + val set = ctx.freshName("set") + val addToSetFromSmallerCode = nullSafeElementCodegen( + smaller, i, s"$set.add($getFromSmaller);", s"${ev.isNull} = true;") + val elementIsInSetCode = nullSafeElementCodegen( + bigger, + i, + s""" + |if ($set.contains($getFromBigger)) { + | ${ev.isNull} = false; + | ${ev.value} = true; + | break; + |} + """.stripMargin, + s"${ev.isNull} = true;") + s""" + |$javaSet<$javaElementClass> $set = new $javaSet<$javaElementClass>(); + |for (int $i = 0; $i < $smaller.numElements(); $i ++) { + | $addToSetFromSmallerCode + |} + |for (int $i = 0; $i < $bigger.numElements(); $i ++) { + | $elementIsInSetCode + |} + """.stripMargin + } + + /** + * Code generation for a slower evaluation which performs a nested loop and supports all the data types. + */ + private def bruteForceCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = { + val i = ctx.freshName("i") + val j = ctx.freshName("j") + val getFromSmaller = CodeGenerator.getValue(smaller, elementType, j) + val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) + val compareValues = nullSafeElementCodegen( + smaller, + j, + s""" + |if (${ctx.genEqual(elementType, getFromSmaller, getFromBigger)}) { + | ${ev.isNull} = false; + | ${ev.value} = true; + |} + """.stripMargin, + s"${ev.isNull} = true;") + val isInSmaller = nullSafeElementCodegen( + bigger, + i, + s""" + |for (int $j = 0; $j < $smaller.numElements() && !${ev.value}; $j ++) { + | $compareValues + |} + """.stripMargin, + s"${ev.isNull} = true;") + s""" + |for (int $i = 0; $i < $bigger.numElements() && !${ev.value}; $i ++) { + | $isInSmaller + |} + """.stripMargin + } + + def nullSafeElementCodegen( + arrayVar: String, + index: String, + code: String, + isNullCode: String): String = { + if (inputTypes.exists(_.asInstanceOf[ArrayType].containsNull)) { + s""" + |if ($arrayVar.isNullAt($index)) { + | $isNullCode + |} else { + | $code + |} + """.stripMargin + } else { + code + } + } + + override def prettyName: String = "arrays_overlap" +} + /** * Slices an array according to the requested start index and length */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 57fc5f75dbca7..6ae1ac18c4dc4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -136,6 +136,72 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + test("ArraysOverlap") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq(4, 5, 3), ArrayType(IntegerType)) + val a2 = Literal.create(Seq(null, 5, 6), ArrayType(IntegerType)) + val a3 = Literal.create(Seq(7, 8), ArrayType(IntegerType)) + val a4 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a5 = Literal.create(Seq[String]("", "abc"), ArrayType(StringType)) + val a6 = Literal.create(Seq[String]("def", "ghi"), ArrayType(StringType)) + + val emptyIntArray = Literal.create(Seq.empty[Int], ArrayType(IntegerType)) + + checkEvaluation(ArraysOverlap(a0, a1), true) + checkEvaluation(ArraysOverlap(a0, a2), null) + checkEvaluation(ArraysOverlap(a1, a2), true) + checkEvaluation(ArraysOverlap(a1, a3), false) + checkEvaluation(ArraysOverlap(a0, emptyIntArray), false) + checkEvaluation(ArraysOverlap(a2, emptyIntArray), false) + checkEvaluation(ArraysOverlap(emptyIntArray, a2), false) + + checkEvaluation(ArraysOverlap(a4, a5), true) + checkEvaluation(ArraysOverlap(a4, a6), null) + checkEvaluation(ArraysOverlap(a5, a6), false) + + // null handling + checkEvaluation(ArraysOverlap(emptyIntArray, a2), false) + checkEvaluation(ArraysOverlap( + emptyIntArray, Literal.create(Seq(null), ArrayType(IntegerType))), false) + checkEvaluation(ArraysOverlap(Literal.create(null, ArrayType(IntegerType)), a0), null) + checkEvaluation(ArraysOverlap(a0, Literal.create(null, ArrayType(IntegerType))), null) + checkEvaluation(ArraysOverlap( + Literal.create(Seq(null), ArrayType(IntegerType)), + Literal.create(Seq(null), ArrayType(IntegerType))), null) + + // arrays of binaries + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), + ArrayType(BinaryType)) + + checkEvaluation(ArraysOverlap(b0, b1), true) + checkEvaluation(ArraysOverlap(b0, b2), false) + + // arrays of complex data types + val aa0 = Literal.create(Seq[Array[String]](Array[String]("a", "b"), Array[String]("c", "d")), + ArrayType(ArrayType(StringType))) + val aa1 = Literal.create(Seq[Array[String]](Array[String]("e", "f"), Array[String]("a", "b")), + ArrayType(ArrayType(StringType))) + val aa2 = Literal.create(Seq[Array[String]](Array[String]("b", "a"), Array[String]("f", "g")), + ArrayType(ArrayType(StringType))) + + checkEvaluation(ArraysOverlap(aa0, aa1), true) + checkEvaluation(ArraysOverlap(aa0, aa2), false) + + // null handling with complex datatypes + val emptyBinaryArray = Literal.create(Seq.empty[Array[Byte]], ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) + checkEvaluation(ArraysOverlap(emptyBinaryArray, b0), false) + checkEvaluation(ArraysOverlap(b0, emptyBinaryArray), false) + checkEvaluation(ArraysOverlap(emptyBinaryArray, arrayWithBinaryNull), false) + checkEvaluation(ArraysOverlap(arrayWithBinaryNull, emptyBinaryArray), false) + checkEvaluation(ArraysOverlap(arrayWithBinaryNull, b0), null) + checkEvaluation(ArraysOverlap(b0, arrayWithBinaryNull), null) + } + test("Slice") { val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType)) val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 550571a61a036..2a8fe583b83bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3085,6 +3085,17 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Returns `true` if `a1` and `a2` have at least one non-null element in common. If not and both + * the arrays are non-empty and any of them contains a `null`, it returns `null`. It returns + * `false` otherwise. + * @group collection_funcs + * @since 2.4.0 + */ + def arrays_overlap(a1: Column, a2: Column): Column = withExpr { + ArraysOverlap(a1.expr, a2.expr) + } + /** * Returns an array containing all the elements in `x` from index `start` (or starting from the * end if `start` is negative) with the specified `length`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index e26565cd153b4..d08982a138bc5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -442,6 +442,35 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("arrays_overlap function") { + val df = Seq( + (Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), Some(10))), + (Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), None)), + (Seq[Option[Int]](Some(3), Some(2)), Seq[Option[Int]](Some(1), Some(2))) + ).toDF("a", "b") + + val answer = Seq(Row(false), Row(null), Row(true)) + + checkAnswer(df.select(arrays_overlap(df("a"), df("b"))), answer) + checkAnswer(df.selectExpr("arrays_overlap(a, b)"), answer) + + checkAnswer( + Seq((Seq(1, 2, 3), Seq(2.0, 2.5))).toDF("a", "b").selectExpr("arrays_overlap(a, b)"), + Row(true)) + + intercept[AnalysisException] { + sql("select arrays_overlap(array(1, 2, 3), array('a', 'b', 'c'))") + } + + intercept[AnalysisException] { + sql("select arrays_overlap(null, null)") + } + + intercept[AnalysisException] { + sql("select arrays_overlap(map(1, 2), map(3, 4))") + } + } + test("slice function") { val df = Seq( Seq(1, 2, 3), From 8a837bf4f3f2758f7825d2362cf9de209026651a Mon Sep 17 00:00:00 2001 From: jinxing Date: Thu, 17 May 2018 22:29:18 +0800 Subject: [PATCH 0818/2461] [SPARK-24193] create TakeOrderedAndProjectExec only when the limit number is below spark.sql.execution.topKSortFallbackThreshold. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Physical plan of `select colA from t order by colB limit M` is `TakeOrderedAndProject`; Currently `TakeOrderedAndProject` sorts data in memory, see https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala#L158 We can add a config – if the number of limit (M) is too big, we can sort by disk. Thus memory issue can be resolved. ## How was this patch tested? Test added Author: jinxing Closes #21252 from jinxing64/SPARK-24193. --- .../org/apache/spark/sql/internal/SQLConf.scala | 11 +++++++++++ .../apache/spark/sql/execution/SparkStrategies.scala | 12 ++++++++---- .../apache/spark/sql/execution/PlannerSuite.scala | 12 ++++++++++++ 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b00edca97cd44..2a673c6ce8f4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1253,6 +1253,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val TOP_K_SORT_FALLBACK_THRESHOLD = + buildConf("spark.sql.execution.topKSortFallbackThreshold") + .internal() + .doc("In SQL queries with a SORT followed by a LIMIT like " + + "'SELECT x FROM t ORDER BY y LIMIT m', if m is under this threshold, do a top-K sort" + + " in memory, otherwise do a global sort which spills to disk if necessary.") + .intConf + .createWithDefault(Int.MaxValue) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1424,6 +1433,8 @@ class SQLConf extends Serializable with Logging { def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION) + def topKSortFallbackThreshold: Int = getConf(TOP_K_SORT_FALLBACK_THRESHOLD) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 37a0b9d6c8728..b97a87a122406 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -66,9 +66,11 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object SpecialLimits extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ReturnAnswer(rootPlan) => rootPlan match { - case Limit(IntegerLiteral(limit), Sort(order, true, child)) => + case Limit(IntegerLiteral(limit), Sort(order, true, child)) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil - case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) => + case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil case Limit(IntegerLiteral(limit), child) => // With whole stage codegen, Spark releases resources only when all the output data of the @@ -79,9 +81,11 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { CollectLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil case other => planLater(other) :: Nil } - case Limit(IntegerLiteral(limit), Sort(order, true, child)) => + case Limit(IntegerLiteral(limit), Sort(order, true, child)) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil - case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) => + case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil case _ => Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index f0dfe6b76f7ae..a375f881c7d63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -197,6 +197,18 @@ class PlannerSuite extends SharedSQLContext { assert(planned.cachedPlan.isInstanceOf[CollectLimitExec]) } + test("TakeOrderedAndProjectExec appears only when number of limit is below the threshold.") { + withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "1000") { + val query0 = testData.select('value).orderBy('key).limit(100) + val planned0 = query0.queryExecution.executedPlan + assert(planned0.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isDefined) + + val query1 = testData.select('value).orderBy('key).limit(2000) + val planned1 = query1.queryExecution.executedPlan + assert(planned1.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isEmpty) + } + } + test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") { val query = testData.select('key, 'value).sort('key.desc).cache() assert(query.queryExecution.optimizedPlan.isInstanceOf[InMemoryRelation]) From a7a9b1837808b281f47643490abcf054f6de7b50 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Thu, 17 May 2018 11:13:16 -0700 Subject: [PATCH 0819/2461] [SPARK-24115] Have logging pass through instrumentation class. ## What changes were proposed in this pull request? Fixes to tuning instrumentation. ## How was this patch tested? Existing tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Bago Amirbekian Closes #21340 from MrBago/tunning-instrumentation. --- .../org/apache/spark/ml/tuning/CrossValidator.scala | 10 +++++----- .../apache/spark/ml/tuning/TrainValidationSplit.scala | 10 +++++----- .../org/apache/spark/ml/util/Instrumentation.scala | 7 +++++++ 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 5e916cc4a9fdd..f327f37bad204 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -144,7 +144,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) => val trainingDataset = sparkSession.createDataFrame(training, schema).cache() val validationDataset = sparkSession.createDataFrame(validation, schema).cache() - logDebug(s"Train split $splitIndex with multiple sets of parameters.") + instr.logDebug(s"Train split $splitIndex with multiple sets of parameters.") // Fit models in a Future for training in parallel val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => @@ -155,7 +155,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) } // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) - logDebug(s"Got metric $metric for model trained with $paramMap.") + instr.logDebug(s"Got metric $metric for model trained with $paramMap.") metric } (executionContext) } @@ -169,12 +169,12 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) foldMetrics }.transpose.map(_.sum / $(numFolds)) // Calculate average metric over all splits - logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") + instr.logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") val (bestMetric, bestIndex) = if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) else metrics.zipWithIndex.minBy(_._1) - logInfo(s"Best set of parameters:\n${epm(bestIndex)}") - logInfo(s"Best cross-validation metric: $bestMetric.") + instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}") + instr.logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] instr.logSuccess(bestModel) copyValues(new CrossValidatorModel(uid, bestModel, metrics) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 13369c4df7180..14d6a69c36747 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -143,7 +143,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St } else None // Fit models in a Future for training in parallel - logDebug(s"Train split with multiple sets of parameters.") + instr.logDebug(s"Train split with multiple sets of parameters.") val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => Future[Double] { val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] @@ -153,7 +153,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St } // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) - logDebug(s"Got metric $metric for model trained with $paramMap.") + instr.logDebug(s"Got metric $metric for model trained with $paramMap.") metric } (executionContext) } @@ -165,12 +165,12 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St trainingDataset.unpersist() validationDataset.unpersist() - logInfo(s"Train validation split metrics: ${metrics.toSeq}") + instr.logInfo(s"Train validation split metrics: ${metrics.toSeq}") val (bestMetric, bestIndex) = if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) else metrics.zipWithIndex.minBy(_._1) - logInfo(s"Best set of parameters:\n${epm(bestIndex)}") - logInfo(s"Best train validation split metric: $bestMetric.") + instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}") + instr.logInfo(s"Best train validation split metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] instr.logSuccess(bestModel) copyValues(new TrainValidationSplitModel(uid, bestModel, metrics) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index 3247c394dfa64..467130b37c16e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -58,6 +58,13 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( s" storageLevel=${dataset.getStorageLevel}") } + /** + * Logs a debug message with a prefix that uniquely identifies the training session. + */ + override def logDebug(msg: => String): Unit = { + super.logDebug(prefix + msg) + } + /** * Logs a warning message with a prefix that uniquely identifies the training session. */ From 439c69511812776cb4b82956547ce958d0669c52 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Thu, 17 May 2018 13:42:10 -0700 Subject: [PATCH 0820/2461] [SPARK-24114] Add instrumentation to FPGrowth. ## What changes were proposed in this pull request? Have FPGrowth keep track of model training using the Instrumentation class. ## How was this patch tested? manually Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Bago Amirbekian Closes #21344 from MrBago/fpgrowth-instr. --- mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 0bf405d9abf9d..d7fbe28ae7a64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -161,6 +161,8 @@ class FPGrowth @Since("2.2.0") ( private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = { val handlePersistence = dataset.storageLevel == StorageLevel.NONE + val instr = Instrumentation.create(this, dataset) + instr.logParams(params: _*) val data = dataset.select($(itemsCol)) val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[Any](0).toArray) val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport)) @@ -183,7 +185,9 @@ class FPGrowth @Since("2.2.0") ( items.unpersist() } - copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) + val model = copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) + instr.logSuccess(model) + model } @Since("2.2.0") From d4a0895c628ca854895c3c35c46ed990af36ec61 Mon Sep 17 00:00:00 2001 From: Sandor Murakozi Date: Thu, 17 May 2018 16:33:06 -0700 Subject: [PATCH 0821/2461] [SPARK-22884][ML] ML tests for StructuredStreaming: spark.ml.clustering ## What changes were proposed in this pull request? Converting clustering tests to also check code with structured streaming, using the ML testing infrastructure implemented in SPARK-22882. This PR is a new version of https://github.com/apache/spark/pull/20319 Author: Sandor Murakozi Author: Joseph K. Bradley Closes #21358 from jkbradley/smurakozi-SPARK-22884. --- .../ml/clustering/BisectingKMeansSuite.scala | 41 ++++++++++--------- .../ml/clustering/GaussianMixtureSuite.scala | 22 ++++------ .../spark/ml/clustering/KMeansSuite.scala | 31 +++++++------- .../apache/spark/ml/clustering/LDASuite.scala | 21 ++++------ 4 files changed, 50 insertions(+), 65 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index f3ff2afcad2cd..81842afbddbbb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -19,17 +19,18 @@ package org.apache.spark.ml.clustering import scala.language.existentials -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.clustering.DistanceMeasure -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.Dataset -class BisectingKMeansSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + +class BisectingKMeansSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ final val k = 5 @transient var dataset: Dataset[_] = _ @@ -68,10 +69,13 @@ class BisectingKMeansSuite // Verify fit does not fail on very sparse data val model = bkm.fit(sparseDataset) - val result = model.transform(sparseDataset) - val numClusters = result.select("prediction").distinct().collect().length - // Verify we hit the edge case - assert(numClusters < k && numClusters > 1) + + testTransformerByGlobalCheckFunc[Tuple1[Vector]](sparseDataset.toDF(), model, "prediction") { + rows => + val numClusters = rows.distinct.length + // Verify we hit the edge case + assert(numClusters < k && numClusters > 1) + } } test("setter/getter") { @@ -104,19 +108,16 @@ class BisectingKMeansSuite val bkm = new BisectingKMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) val model = bkm.fit(dataset) assert(model.clusterCenters.length === k) - - val transformed = model.transform(dataset) - val expectedColumns = Array("features", predictionColName) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) - } - val clusters = - transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet - assert(clusters.size === k) - assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataset.toDF(), model, + "features", predictionColName) { rows => + val clusters = rows.map(_.getAs[Int](predictionColName)).toSet + assert(clusters.size === k) + assert(clusters === Set(0, 1, 2, 3, 4)) + } + // Check validity of model summary val numRows = dataset.count() assert(model.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index d0d461a42711a..0b91f502f615b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -23,16 +23,15 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.stat.distribution.MultivariateGaussian -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{Dataset, Row} -class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { - import testImplicits._ +class GaussianMixtureSuite extends MLTest with DefaultReadWriteTest { + import GaussianMixtureSuite._ + import testImplicits._ final val k = 5 private val seed = 538009335 @@ -119,15 +118,10 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(model.weights.length === k) assert(model.gaussians.length === k) - val transformed = model.transform(dataset) - val expectedColumns = Array("features", predictionColName, probabilityColName) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) - } - // Check prediction matches the highest probability, and probabilities sum to one. - transformed.select(predictionColName, probabilityColName).collect().foreach { - case Row(pred: Int, prob: Vector) => + testTransformer[Tuple1[Vector]](dataset.toDF(), model, + "features", predictionColName, probabilityColName) { + case Row(_, pred: Int, prob: Vector) => val probArray = prob.toArray val predFromProb = probArray.zipWithIndex.maxBy(_._1)._2 assert(pred === predFromProb) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 680a7c2034083..2569e7a432ca4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -22,20 +22,21 @@ import scala.util.Random import org.dmg.pmml.{ClusteringModel, PMML} -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils, PMMLReadWriteTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} +import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, + KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} private[clustering] case class TestRow(features: Vector) -class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest - with PMMLReadWriteTest { +class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTest { + + import testImplicits._ final val k = 5 @transient var dataset: Dataset[_] = _ @@ -109,15 +110,13 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR val model = kmeans.fit(dataset) assert(model.clusterCenters.length === k) - val transformed = model.transform(dataset) - val expectedColumns = Array("features", predictionColName) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataset.toDF(), model, + "features", predictionColName) { rows => + val clusters = rows.map(_.getAs[Int](predictionColName)).toSet + assert(clusters.size === k) + assert(clusters === Set(0, 1, 2, 3, 4)) } - val clusters = - transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet - assert(clusters.size === k) - assert(clusters === Set(0, 1, 2, 3, 4)) + assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) @@ -149,9 +148,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR model.setFeaturesCol(featuresColName).setPredictionCol(predictionColName) val transformed = model.transform(dataset.withColumnRenamed("features", featuresColName)) - Seq(featuresColName, predictionColName).foreach { column => - assert(transformed.columns.contains(column)) - } + assert(transformed.schema.fieldNames.toSet === Set(featuresColName, predictionColName)) assert(model.getFeaturesCol == featuresColName) assert(model.getPredictionCol == predictionColName) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 4d848205034c0..096b5416899e1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -21,11 +21,9 @@ import scala.language.existentials import org.apache.hadoop.fs.Path -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql._ object LDASuite { @@ -61,7 +59,7 @@ object LDASuite { } -class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class LDASuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -186,16 +184,11 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead assert(model.topicsMatrix.numCols === k) assert(!model.isDistributed) - // transform() - val transformed = model.transform(dataset) - val expectedColumns = Array("features", lda.getTopicDistributionCol) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) - } - transformed.select(lda.getTopicDistributionCol).collect().foreach { r => - val topicDistribution = r.getAs[Vector](0) - assert(topicDistribution.size === k) - assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0)) + testTransformer[Tuple1[Vector]](dataset.toDF(), model, + "features", lda.getTopicDistributionCol) { + case Row(_, topicDistribution: Vector) => + assert(topicDistribution.size === k) + assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0)) } // logLikelihood, logPerplexity From 7b2dca5b12164b787ec4e8e7e9f92c60a3f9563e Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 18 May 2018 15:32:29 +0800 Subject: [PATCH 0822/2461] [SPARK-24277][SQL] Code clean up in SQL module: HadoopMapReduceCommitProtocol ## What changes were proposed in this pull request? In HadoopMapReduceCommitProtocol and FileFormatWriter, there are unnecessary settings in hadoop configuration. Also clean up some code in SQL module. ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21329 from gengliangwang/codeCleanWrite. --- .../io/HadoopMapReduceCommitProtocol.scala | 15 +++------------ .../datasources/orc/OrcColumnVector.java | 6 +----- .../parquet/VectorizedRleValuesReader.java | 4 ++-- .../org/apache/spark/sql/api/r/SQLUtils.scala | 2 +- .../spark/sql/execution/command/views.scala | 10 ++++------ .../execution/datasources/FileFormatWriter.scala | 11 +++++------ .../sql/execution/ui/SQLAppStatusListener.scala | 2 +- 7 files changed, 17 insertions(+), 33 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 3e60c50ada59b..163511b7ffa3a 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -145,18 +145,9 @@ class HadoopMapReduceCommitProtocol( } override def setupJob(jobContext: JobContext): Unit = { - // Setup IDs - val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0) - val taskId = new TaskID(jobId, TaskType.MAP, 0) - val taskAttemptId = new TaskAttemptID(taskId, 0) - - // Set up the configuration object - jobContext.getConfiguration.set("mapreduce.job.id", jobId.toString) - jobContext.getConfiguration.set("mapreduce.task.id", taskAttemptId.getTaskID.toString) - jobContext.getConfiguration.set("mapreduce.task.attempt.id", taskAttemptId.toString) - jobContext.getConfiguration.setBoolean("mapreduce.task.ismap", true) - jobContext.getConfiguration.setInt("mapreduce.task.partition", 0) - + // Create a dummy [[TaskAttemptContextImpl]] with configuration to get [[OutputCommitter]] + // instance on Spark driver. Note that the job/task/attampt id doesn't matter here. + val taskAttemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val taskAttemptContext = new TaskAttemptContextImpl(jobContext.getConfiguration, taskAttemptId) committer = setupCommitter(taskAttemptContext) committer.setupJob(jobContext) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index 12f4d658b1868..fcf73e8d7ae6c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -47,11 +47,7 @@ public class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVecto OrcColumnVector(DataType type, ColumnVector vector) { super(type); - if (type instanceof TimestampType) { - isTimestamp = true; - } else { - isTimestamp = false; - } + isTimestamp = type instanceof TimestampType; baseData = vector; if (vector instanceof LongColumnVector) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index fe3d31ae8e746..de0d65a1e0906 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -571,7 +571,7 @@ private int readIntLittleEndian() throws IOException { int ch3 = in.read(); int ch2 = in.read(); int ch1 = in.read(); - return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4 << 0)); + return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4)); } /** @@ -592,7 +592,7 @@ private int readIntLittleEndianPaddedOnBitWidth() throws IOException { int ch3 = in.read(); int ch2 = in.read(); int ch1 = in.read(); - return (ch1 << 16) + (ch2 << 8) + (ch3 << 0); + return (ch1 << 16) + (ch2 << 8) + (ch3); } case 4: { return readIntLittleEndian(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index af20764f9a968..265a84b39a425 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -188,7 +188,7 @@ private[sql] object SQLUtils extends Logging { dataType match { case 's' => // Read StructType for DataFrame - val fields = SerDe.readList(dis, jvmObjectTracker = null).asInstanceOf[Array[Object]] + val fields = SerDe.readList(dis, jvmObjectTracker = null) Row.fromSeq(fields) case _ => null } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 5172f32ec7b9c..6373584b10e35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -410,12 +410,10 @@ object ViewHelper { } // Detect cyclic references from subqueries. - plan.expressions.foreach { expr => - expr match { - case s: SubqueryExpression => - checkCyclicViewReference(s.plan, path, viewIdent) - case _ => // Do nothing. - } + plan.expressions.foreach { + case s: SubqueryExpression => + checkCyclicViewReference(s.plan, path, viewIdent) + case _ => // Do nothing. } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 401597f967218..681bb1df6bbae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -244,18 +244,17 @@ object FileFormatWriter extends Logging { iterator: Iterator[InternalRow]): WriteTaskResult = { val jobId = SparkHadoopWriterUtils.createJobID(new Date, sparkStageId) - val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) - val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber) // Set up the attempt context required to use in the output committer. val taskAttemptContext: TaskAttemptContext = { + val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) + val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber) // Set up the configuration object val hadoopConf = description.serializableHadoopConf.value hadoopConf.set("mapreduce.job.id", jobId.toString) hadoopConf.set("mapreduce.task.id", taskAttemptId.getTaskID.toString) hadoopConf.set("mapreduce.task.attempt.id", taskAttemptId.toString) hadoopConf.setBoolean("mapreduce.task.ismap", true) - hadoopConf.setInt("mapreduce.task.partition", 0) new TaskAttemptContextImpl(hadoopConf, taskAttemptId) } @@ -378,7 +377,7 @@ object FileFormatWriter extends Logging { dataSchema = description.dataColumns.toStructType, context = taskAttemptContext) - statsTrackers.map(_.newFile(currentPath)) + statsTrackers.foreach(_.newFile(currentPath)) } override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { @@ -429,10 +428,10 @@ object FileFormatWriter extends Logging { committer: FileCommitProtocol) extends ExecuteWriteTask { /** Flag saying whether or not the data to be written out is partitioned. */ - val isPartitioned = desc.partitionColumns.nonEmpty + private val isPartitioned = desc.partitionColumns.nonEmpty /** Flag saying whether or not the data to be written out is bucketed. */ - val isBucketed = desc.bucketIdExpression.isDefined + private val isBucketed = desc.bucketIdExpression.isDefined assert(isPartitioned || isBucketed, s"""DynamicPartitionWriteTask should be used for writing out data that's either diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index d254af400a7cf..2c4d0bcf103ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -170,7 +170,7 @@ class SQLAppStatusListener( .filter { case (id, _) => metricIds.contains(id) } .groupBy(_._1) .map { case (id, values) => - id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2).toSeq) + id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2)) } // Check the execution again for whether the aggregated metrics data has been calculated. From 0cf59fcbe3799dd3c4469cbf8cd842d668a76f34 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 18 May 2018 09:53:24 -0700 Subject: [PATCH 0823/2461] [SPARK-24303][PYTHON] Update cloudpickle to v0.4.4 ## What changes were proposed in this pull request? cloudpickle 0.4.4 is released - https://github.com/cloudpipe/cloudpickle/releases/tag/v0.4.4 There's no invasive change - the main difference is that we are now able to pickle the root logger, which fix is pretty isolated. ## How was this patch tested? Jenkins tests. Author: hyukjinkwon Closes #21350 from HyukjinKwon/SPARK-24303. --- python/pyspark/cloudpickle.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index ea845b98b3db2..88519d7311fcc 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -272,7 +272,7 @@ def save_memoryview(self, obj): if not PY3: def save_buffer(self, obj): self.save(str(obj)) - dispatch[buffer] = save_buffer + dispatch[buffer] = save_buffer # noqa: F821 'buffer' was removed in Python 3 def save_unsupported(self, obj): raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj)) @@ -801,10 +801,10 @@ def save_ellipsis(self, obj): def save_not_implemented(self, obj): self.save_reduce(_gen_not_implemented, ()) - if PY3: - dispatch[io.TextIOWrapper] = save_file - else: + try: # Python 2 dispatch[file] = save_file + except NameError: # Python 3 + dispatch[io.TextIOWrapper] = save_file dispatch[type(Ellipsis)] = save_ellipsis dispatch[type(NotImplemented)] = save_not_implemented @@ -819,6 +819,11 @@ def save_logger(self, obj): dispatch[logging.Logger] = save_logger + def save_root_logger(self, obj): + self.save_reduce(logging.getLogger, (), obj=obj) + + dispatch[logging.RootLogger] = save_root_logger + """Special functions for Add-on libraries""" def inject_addons(self): """Plug in system. Register additional pickling functions if modules already loaded""" From 7696b9de0df6e9eb85a74bdb404409da693cf65e Mon Sep 17 00:00:00 2001 From: Soham Aurangabadkar Date: Fri, 18 May 2018 10:29:34 -0700 Subject: [PATCH 0824/2461] [SPARK-20538][SQL] Wrap Dataset.reduce with withNewRddExecutionId. ## What changes were proposed in this pull request? Wrap Dataset.reduce with `withNewExecutionId`. Author: Soham Aurangabadkar Closes #21316 from sohama4/dataset_reduce_withexecutionid. --- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f001f16e1d5ee..32267eb0300f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1617,7 +1617,9 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def reduce(func: (T, T) => T): T = rdd.reduce(func) + def reduce(func: (T, T) => T): T = withNewRDDExecutionId { + rdd.reduce(func) + } /** * :: Experimental :: From 807ba44cb742c5f7c22bdf6bfe2cf814be85398e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 18 May 2018 10:35:43 -0700 Subject: [PATCH 0825/2461] [SPARK-24159][SS] Enable no-data micro batches for streaming mapGroupswithState ## What changes were proposed in this pull request? Enabled no-data batches in flatMapGroupsWithState in following two cases. - When ProcessingTime timeout is used, then we always run a batch every trigger interval. - When event-time watermark is defined, then the user may be doing arbitrary logic against the watermark value even if timeouts are not set. In such cases, it's best to run batches whenever the watermark has changed, irrespective of whether timeouts (i.e. event-time timeout) have been explicitly enabled. ## How was this patch tested? updated tests Author: Tathagata Das Closes #21345 from tdas/SPARK-24159. --- .../FlatMapGroupsWithStateExec.scala | 17 ++- .../FlatMapGroupsWithStateSuite.scala | 120 ++++++++++-------- 2 files changed, 80 insertions(+), 57 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 80769d728b8f1..8e82cccbc8fa3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -97,6 +97,18 @@ case class FlatMapGroupsWithStateExec( override def keyExpressions: Seq[Attribute] = groupingAttributes + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + timeoutConf match { + case ProcessingTimeTimeout => + true // Always run batches to process timeouts + case EventTimeTimeout => + // Process another non-data batch only if the watermark has changed in this executed plan + eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + case _ => + false + } + } + override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver @@ -126,7 +138,6 @@ case class FlatMapGroupsWithStateExec( case _ => iter } - // Generate a iterator that returns the rows grouped by the grouping function // Note that this code ensures that the filtering for timeout occurs only after // all the data has been processed. This is to ensure that the timeout information of all @@ -194,11 +205,11 @@ case class FlatMapGroupsWithStateExec( throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") } - val timingOutKeys = store.getRange(None, None).filter { rowPair => + val timingOutPairs = store.getRange(None, None).filter { rowPair => val timeoutTimestamp = getTimeoutTimestamp(rowPair.value) timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold } - timingOutKeys.flatMap { rowPair => + timingOutPairs.flatMap { rowPair => callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true) } } else Iterator.empty diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index b1416bff87ee7..988c8e6753e25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -615,20 +615,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Update)( AddData(inputData, "a"), - CheckLastBatch(("a", "1")), + CheckNewAnswer(("a", "1")), assertNumStateRows(total = 1, updated = 1), AddData(inputData, "a", "b"), - CheckLastBatch(("a", "2"), ("b", "1")), + CheckNewAnswer(("a", "2"), ("b", "1")), assertNumStateRows(total = 2, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a - CheckLastBatch(("b", "2")), + CheckNewAnswer(("b", "2")), assertNumStateRows(total = 1, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and - CheckLastBatch(("a", "1"), ("c", "1")), + CheckNewAnswer(("a", "1"), ("c", "1")), assertNumStateRows(total = 3, updated = 2) ) } @@ -657,15 +657,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) testStream(result, Update)( AddData(inputData, "a", "a", "b"), - CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")), + CheckNewAnswer(("a", "1"), ("a", "2"), ("b", "1")), StopStream, StartStream(), AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a - CheckLastBatch(("b", "2")), + CheckNewAnswer(("b", "2")), StopStream, StartStream(), AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and - CheckLastBatch(("a", "1"), ("c", "1")) + CheckNewAnswer(("a", "1"), ("c", "1")) ) } @@ -694,22 +694,22 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Complete)( AddData(inputData, "a"), - CheckLastBatch(("a", 1)), + CheckNewAnswer(("a", 1)), AddData(inputData, "a", "b"), // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1 - CheckLastBatch(("a", 2), ("b", 1)), + CheckNewAnswer(("a", 2), ("b", 1)), StopStream, StartStream(), AddData(inputData, "a", "b"), // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ; // so increment a and b by 1 - CheckLastBatch(("a", 3), ("b", 2)), + CheckNewAnswer(("a", 3), ("b", 2)), StopStream, StartStream(), AddData(inputData, "a", "c"), // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ; // so increment a and c by 1 - CheckLastBatch(("a", 4), ("b", 2), ("c", 1)) + CheckNewAnswer(("a", 4), ("b", 2), ("c", 1)) ) } @@ -729,8 +729,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } test("flatMapGroupsWithState - streaming with processing time timeout") { - // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + // Function to maintain the count as state and set the proc. time timeout delay of 10 seconds. + // It returns the count if changed, or -1 if the state was removed by timeout. val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } assertCannotGetWatermark { state.getCurrentWatermarkMs() } @@ -757,17 +757,17 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), AddData(inputData, "a"), AdvanceManualClock(1 * 1000), - CheckLastBatch(("a", "1")), + CheckNewAnswer(("a", "1")), assertNumStateRows(total = 1, updated = 1), AddData(inputData, "b"), AdvanceManualClock(1 * 1000), - CheckLastBatch(("b", "1")), + CheckNewAnswer(("b", "1")), assertNumStateRows(total = 2, updated = 1), AddData(inputData, "b"), AdvanceManualClock(10 * 1000), - CheckLastBatch(("a", "-1"), ("b", "2")), + CheckNewAnswer(("a", "-1"), ("b", "2")), assertNumStateRows(total = 1, updated = 2), StopStream, @@ -775,38 +775,42 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest AddData(inputData, "c"), AdvanceManualClock(11 * 1000), - CheckLastBatch(("b", "-1"), ("c", "1")), + CheckNewAnswer(("b", "-1"), ("c", "1")), assertNumStateRows(total = 1, updated = 2), - AddData(inputData, "c"), - AdvanceManualClock(20 * 1000), - CheckLastBatch(("c", "2")), - assertNumStateRows(total = 1, updated = 1) + AdvanceManualClock(12 * 1000), + AssertOnQuery { _ => clock.getTimeMillis() == 35000 }, + Execute { q => + failAfter(streamingTimeout) { + while (q.lastProgress.timestamp != "1970-01-01T00:00:35.000Z") { + Thread.sleep(1) + } + } + }, + CheckNewAnswer(("c", "-1")), + assertNumStateRows(total = 0, updated = 0) ) } test("flatMapGroupsWithState - streaming with event time timeout + watermark") { - // Function to maintain the max event time - // Returns the max event time in the state, or -1 if the state was removed by timeout + // Function to maintain the max event time as state and set the timeout timestamp based on the + // current max event time seen. It returns the max event time in the state, or -1 if the state + // was removed by timeout. val stateFunc = (key: String, values: Iterator[(String, Long)], state: GroupState[Long]) => { assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } assertCanGetWatermark { state.getCurrentWatermarkMs() >= -1 } - val timeoutDelay = 5 - if (key != "a") { - Iterator.empty + val timeoutDelaySec = 5 + if (state.hasTimedOut) { + state.remove() + Iterator((key, -1)) } else { - if (state.hasTimedOut) { - state.remove() - Iterator((key, -1)) - } else { - val valuesSeq = values.toSeq - val maxEventTime = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L)) - val timeoutTimestampMs = maxEventTime + timeoutDelay - state.update(maxEventTime) - state.setTimeoutTimestamp(timeoutTimestampMs * 1000) - Iterator((key, maxEventTime.toInt)) - } + val valuesSeq = values.toSeq + val maxEventTimeSec = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L)) + val timeoutTimestampSec = maxEventTimeSec + timeoutDelaySec + state.update(maxEventTimeSec) + state.setTimeoutTimestamp(timeoutTimestampSec * 1000) + Iterator((key, maxEventTimeSec.toInt)) } } val inputData = MemoryStream[(String, Int)] @@ -819,15 +823,23 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest .flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc) testStream(result, Update)( - StartStream(Trigger.ProcessingTime("1 second")), - AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), // Set timeout timestamp of ... - CheckLastBatch(("a", 15)), // "a" to 15 + 5 = 20s, watermark to 5s + StartStream(), + + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), + // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. + CheckNewAnswer(("a", 15)), // Output = max event time of a + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" - CheckLastBatch(), // No output as data should get filtered by watermark - AddData(inputData, ("dummy", 35)), // Set watermark = 35 - 10 = 25s - CheckLastBatch(), // No output as no data for "a" - AddData(inputData, ("a", 24)), // Add data older than watermark, should be ignored - CheckLastBatch(("a", -1)) // State for "a" should timeout and emit -1 + CheckNewAnswer(), // No output as data should get filtered by watermark + + AddData(inputData, ("a", 10)), // Add data newer than watermark for "a" + CheckNewAnswer(("a", 15)), // Max event time is still the same + // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15. + // Watermark is still 5 as max event time for all data is still 15. + + AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a" + // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is 20. + CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1 ) } @@ -856,20 +868,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Update)( AddData(inputData, "a"), - CheckLastBatch(("a", "1")), + CheckNewAnswer(("a", "1")), assertNumStateRows(total = 1, updated = 1), AddData(inputData, "a", "b"), - CheckLastBatch(("a", "2"), ("b", "1")), + CheckNewAnswer(("a", "2"), ("b", "1")), assertNumStateRows(total = 2, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1 - CheckLastBatch(("a", "-1"), ("b", "2")), + CheckNewAnswer(("a", "-1"), ("b", "2")), assertNumStateRows(total = 1, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 - CheckLastBatch(("a", "1"), ("c", "1")), + CheckNewAnswer(("a", "1"), ("c", "1")), assertNumStateRows(total = 3, updated = 2) ) } @@ -920,15 +932,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Update)( setFailInTask(false), AddData(inputData, "a"), - CheckLastBatch(("a", 1L)), + CheckNewAnswer(("a", 1L)), AddData(inputData, "a"), - CheckLastBatch(("a", 2L)), + CheckNewAnswer(("a", 2L)), setFailInTask(true), AddData(inputData, "a"), ExpectFailure[SparkException](), // task should fail but should not increment count setFailInTask(false), StartStream(), - CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count + CheckNewAnswer(("a", 3L)) // task should not fail, and should show correct count ) } @@ -938,7 +950,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest val result = inputData.toDS.groupByKey(x => x).mapGroupsWithState(stateFunc) testStream(result, Update)( AddData(inputData, "a"), - CheckLastBatch("a"), + CheckNewAnswer("a"), AssertOnQuery(_.lastExecution.executedPlan.outputPartitioning === UnknownPartitioning(0)) ) } @@ -1000,7 +1012,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), AddData(inputData, ("a", 1L)), AdvanceManualClock(1 * 1000), - CheckLastBatch(("a", "1")) + CheckNewAnswer(("a", "1")) ) } } From ed7ba7db8fa344ff182b72d23ae458e711f63432 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 18 May 2018 11:14:22 -0700 Subject: [PATCH 0826/2461] [SPARK-23850][SQL] Add separate config for SQL options redaction. The old code was relying on a core configuration and extended its default value to include things that redact desired things in the app's environment. Instead, add a SQL-specific option for which options to redact, and apply both the core and SQL-specific rules when redacting the options in the save command. This is a little sub-optimal since it adds another config, but it retains the current default behavior. While there I also fixed a typo and a couple of minor config API usage issues in the related redaction option that SQL already had. Tested with existing unit tests, plus checking the env page on a shell UI. Author: Marcelo Vanzin Closes #21158 from vanzin/SPARK-23850. --- .../spark/internal/config/package.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 24 +++++++++++++++++-- .../sql/execution/DataSourceScanExec.scala | 2 +- .../spark/sql/execution/QueryExecution.scala | 2 +- .../SaveIntoDataSourceCommand.scala | 5 ++-- .../SaveIntoDataSourceCommandSuite.scala | 3 --- 6 files changed, 27 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 82f0a04e94b1c..a54b091a64d50 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -342,7 +342,7 @@ package object config { "a property key or value, the value is redacted from the environment UI and various logs " + "like YARN and event logs.") .regexConf - .createWithDefault("(?i)secret|password|url|user|username".r) + .createWithDefault("(?i)secret|password".r) private[spark] val STRING_REDACTION_PATTERN = ConfigBuilder("spark.redaction.string.regex") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2a673c6ce8f4a..53a50305348fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1155,8 +1155,17 @@ object SQLConf { .booleanConf .createWithDefault(true) + val SQL_OPTIONS_REDACTION_PATTERN = + buildConf("spark.sql.redaction.options.regex") + .doc("Regex to decide which keys in a Spark SQL command's options map contain sensitive " + + "information. The values of options whose names that match this regex will be redacted " + + "in the explain output. This redaction is applied on top of the global redaction " + + s"configuration defined by ${SECRET_REDACTION_PATTERN.key}.") + .regexConf + .createWithDefault("(?i)url".r) + val SQL_STRING_REDACTION_PATTERN = - ConfigBuilder("spark.sql.redaction.string.regex") + buildConf("spark.sql.redaction.string.regex") .doc("Regex to decide which parts of strings produced by Spark contain sensitive " + "information. When this regex matches a string part, that string part is replaced by a " + "dummy value. This is currently used to redact the output of SQL explain commands. " + @@ -1429,7 +1438,7 @@ class SQLConf extends Serializable with Logging { def fileCompressionFactor: Double = getConf(FILE_COMRESSION_FACTOR) - def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader) + def stringRedactionPattern: Option[Regex] = getConf(SQL_STRING_REDACTION_PATTERN) def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION) @@ -1738,6 +1747,17 @@ class SQLConf extends Serializable with Logging { }.toSeq } + /** + * Redacts the given option map according to the description of SQL_OPTIONS_REDACTION_PATTERN. + */ + def redactOptions(options: Map[String, String]): Map[String, String] = { + val regexes = Seq( + getConf(SQL_OPTIONS_REDACTION_PATTERN), + SECRET_REDACTION_PATTERN.readFrom(reader)) + + regexes.foldLeft(options.toSeq) { case (opts, r) => Utils.redact(Some(r), opts) }.toMap + } + /** * Return whether a given key is set in this [[SQLConf]]. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 08ff33afbba3d..61c14fee09337 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -69,7 +69,7 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport { * Shorthand for calling redactString() without specifying redacting rules */ private def redact(text: String): String = { - Utils.redact(sqlContext.sessionState.conf.stringRedationPattern, text) + Utils.redact(sqlContext.sessionState.conf.stringRedactionPattern, text) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 15379a0663f7d..3112b306c365e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -225,7 +225,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { * Redact the sensitive information in the given string. */ private def withRedaction(message: String): String = { - Utils.redact(sparkSession.sessionState.conf.stringRedationPattern, message) + Utils.redact(sparkSession.sessionState.conf.stringRedactionPattern, message) } /** A special namespace for commands that can be used to debug query execution. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala index 568e953a5db66..00b1b5dedb593 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.SparkEnv import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.CreatableRelationProvider -import org.apache.spark.util.Utils /** * Saves the results of `query` in to a data source. @@ -50,7 +49,7 @@ case class SaveIntoDataSourceCommand( } override def simpleString: String = { - val redacted = Utils.redact(SparkEnv.get.conf, options.toSeq).toMap + val redacted = SQLConf.get.redactOptions(options) s"SaveIntoDataSourceCommand ${dataSource}, ${redacted}, ${mode}" } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala index 4b3ca8e60cab6..a1da3ec43eae3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala @@ -23,9 +23,6 @@ import org.apache.spark.sql.test.SharedSQLContext class SaveIntoDataSourceCommandSuite extends SharedSQLContext { - override protected def sparkConf: SparkConf = super.sparkConf - .set("spark.redaction.regex", "(?i)password|url") - test("simpleString is redacted") { val URL = "connection.url" val PASS = "123" From 1c4553d67de8089e8aa84bc736faa11f21615a6a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 18 May 2018 12:51:09 -0700 Subject: [PATCH 0827/2461] Revert "[SPARK-24277][SQL] Code clean up in SQL module: HadoopMapReduceCommitProtocol" This reverts commit 7b2dca5b12164b787ec4e8e7e9f92c60a3f9563e. --- .../io/HadoopMapReduceCommitProtocol.scala | 15 ++++++++++++--- .../datasources/orc/OrcColumnVector.java | 6 +++++- .../parquet/VectorizedRleValuesReader.java | 4 ++-- .../org/apache/spark/sql/api/r/SQLUtils.scala | 2 +- .../spark/sql/execution/command/views.scala | 10 ++++++---- .../execution/datasources/FileFormatWriter.scala | 11 ++++++----- .../sql/execution/ui/SQLAppStatusListener.scala | 2 +- 7 files changed, 33 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 163511b7ffa3a..3e60c50ada59b 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -145,9 +145,18 @@ class HadoopMapReduceCommitProtocol( } override def setupJob(jobContext: JobContext): Unit = { - // Create a dummy [[TaskAttemptContextImpl]] with configuration to get [[OutputCommitter]] - // instance on Spark driver. Note that the job/task/attampt id doesn't matter here. - val taskAttemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + // Setup IDs + val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0) + val taskId = new TaskID(jobId, TaskType.MAP, 0) + val taskAttemptId = new TaskAttemptID(taskId, 0) + + // Set up the configuration object + jobContext.getConfiguration.set("mapreduce.job.id", jobId.toString) + jobContext.getConfiguration.set("mapreduce.task.id", taskAttemptId.getTaskID.toString) + jobContext.getConfiguration.set("mapreduce.task.attempt.id", taskAttemptId.toString) + jobContext.getConfiguration.setBoolean("mapreduce.task.ismap", true) + jobContext.getConfiguration.setInt("mapreduce.task.partition", 0) + val taskAttemptContext = new TaskAttemptContextImpl(jobContext.getConfiguration, taskAttemptId) committer = setupCommitter(taskAttemptContext) committer.setupJob(jobContext) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index fcf73e8d7ae6c..12f4d658b1868 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -47,7 +47,11 @@ public class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVecto OrcColumnVector(DataType type, ColumnVector vector) { super(type); - isTimestamp = type instanceof TimestampType; + if (type instanceof TimestampType) { + isTimestamp = true; + } else { + isTimestamp = false; + } baseData = vector; if (vector instanceof LongColumnVector) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index de0d65a1e0906..fe3d31ae8e746 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -571,7 +571,7 @@ private int readIntLittleEndian() throws IOException { int ch3 = in.read(); int ch2 = in.read(); int ch1 = in.read(); - return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4)); + return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4 << 0)); } /** @@ -592,7 +592,7 @@ private int readIntLittleEndianPaddedOnBitWidth() throws IOException { int ch3 = in.read(); int ch2 = in.read(); int ch1 = in.read(); - return (ch1 << 16) + (ch2 << 8) + (ch3); + return (ch1 << 16) + (ch2 << 8) + (ch3 << 0); } case 4: { return readIntLittleEndian(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 265a84b39a425..af20764f9a968 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -188,7 +188,7 @@ private[sql] object SQLUtils extends Logging { dataType match { case 's' => // Read StructType for DataFrame - val fields = SerDe.readList(dis, jvmObjectTracker = null) + val fields = SerDe.readList(dis, jvmObjectTracker = null).asInstanceOf[Array[Object]] Row.fromSeq(fields) case _ => null } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 6373584b10e35..5172f32ec7b9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -410,10 +410,12 @@ object ViewHelper { } // Detect cyclic references from subqueries. - plan.expressions.foreach { - case s: SubqueryExpression => - checkCyclicViewReference(s.plan, path, viewIdent) - case _ => // Do nothing. + plan.expressions.foreach { expr => + expr match { + case s: SubqueryExpression => + checkCyclicViewReference(s.plan, path, viewIdent) + case _ => // Do nothing. + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 681bb1df6bbae..401597f967218 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -244,17 +244,18 @@ object FileFormatWriter extends Logging { iterator: Iterator[InternalRow]): WriteTaskResult = { val jobId = SparkHadoopWriterUtils.createJobID(new Date, sparkStageId) + val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) + val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber) // Set up the attempt context required to use in the output committer. val taskAttemptContext: TaskAttemptContext = { - val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) - val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber) // Set up the configuration object val hadoopConf = description.serializableHadoopConf.value hadoopConf.set("mapreduce.job.id", jobId.toString) hadoopConf.set("mapreduce.task.id", taskAttemptId.getTaskID.toString) hadoopConf.set("mapreduce.task.attempt.id", taskAttemptId.toString) hadoopConf.setBoolean("mapreduce.task.ismap", true) + hadoopConf.setInt("mapreduce.task.partition", 0) new TaskAttemptContextImpl(hadoopConf, taskAttemptId) } @@ -377,7 +378,7 @@ object FileFormatWriter extends Logging { dataSchema = description.dataColumns.toStructType, context = taskAttemptContext) - statsTrackers.foreach(_.newFile(currentPath)) + statsTrackers.map(_.newFile(currentPath)) } override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { @@ -428,10 +429,10 @@ object FileFormatWriter extends Logging { committer: FileCommitProtocol) extends ExecuteWriteTask { /** Flag saying whether or not the data to be written out is partitioned. */ - private val isPartitioned = desc.partitionColumns.nonEmpty + val isPartitioned = desc.partitionColumns.nonEmpty /** Flag saying whether or not the data to be written out is bucketed. */ - private val isBucketed = desc.bucketIdExpression.isDefined + val isBucketed = desc.bucketIdExpression.isDefined assert(isPartitioned || isBucketed, s"""DynamicPartitionWriteTask should be used for writing out data that's either diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 2c4d0bcf103ff..d254af400a7cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -170,7 +170,7 @@ class SQLAppStatusListener( .filter { case (id, _) => metricIds.contains(id) } .groupBy(_._1) .map { case (id, values) => - id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2)) + id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2).toSeq) } // Check the execution again for whether the aggregated metrics data has been calculated. From 7f82c4a47e94ee4f544dee8bb71b99534e919769 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 18 May 2018 12:54:19 -0700 Subject: [PATCH 0828/2461] [SPARK-24312][SQL] Upgrade to 2.3.3 for Hive Metastore Client 2.3 ## What changes were proposed in this pull request? Hive 2.3.3 was [released on April 3rd](https://issues.apache.org/jira/secure/ReleaseNote.jspa?version=12342162&styleName=Text&projectId=12310843). This PR aims to upgrade Hive Metastore Client 2.3 from 2.3.2 to 2.3.3. ## How was this patch tested? Pass the Jenkins with the existing tests. Author: Dongjoon Hyun Closes #21359 from dongjoon-hyun/SPARK-24312. --- docs/sql-programming-guide.md | 4 ++-- .../src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala | 2 +- .../apache/spark/sql/hive/client/IsolatedClientLoader.scala | 2 +- .../main/scala/org/apache/spark/sql/hive/client/package.scala | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3f79ed6422205..b93d8531d9efe 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1214,7 +1214,7 @@ The following options can be used to configure the version of Hive that is used
    @@ -2237,7 +2237,7 @@ referencing a singleton. Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently, Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 2.3.2. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). +(from 0.12.0 to 2.3.3. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index bb134bbe68bd9..cd321d41f43e8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -62,7 +62,7 @@ private[spark] object HiveUtils extends Logging { val HIVE_METASTORE_VERSION = buildConf("spark.sql.hive.metastore.version") .doc("Version of the Hive metastore. Available options are " + - s"0.12.0 through 2.3.2.") + s"0.12.0 through 2.3.3.") .stringConf .createWithDefault(builtinHiveVersion) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index c2690ec32b9e7..2f34f69b5cf48 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -98,7 +98,7 @@ private[hive] object IsolatedClientLoader extends Logging { case "2.0" | "2.0.0" | "2.0.1" => hive.v2_0 case "2.1" | "2.1.0" | "2.1.1" => hive.v2_1 case "2.2" | "2.2.0" => hive.v2_2 - case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" => hive.v2_3 + case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" | "2.3.3" => hive.v2_3 } private def downloadVersion( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 681ee9200f02b..25e9886fa6576 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -75,7 +75,7 @@ package object client { exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) - case object v2_3 extends HiveVersion("2.3.2", + case object v2_3 extends HiveVersion("2.3.3", exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) From 3159ee085b23e2e9f1657d80b7ae3efe82b5edb9 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 18 May 2018 13:04:00 -0700 Subject: [PATCH 0829/2461] [SPARK-24149][YARN] Retrieve all federated namespaces tokens ## What changes were proposed in this pull request? Hadoop 3 introduces HDFS federation. This means that multiple namespaces are allowed on the same HDFS cluster. In Spark, we need to ask the delegation token for all the namenodes (for each namespace), otherwise accessing any other namespace different from the default one (for which we already fetch the delegation token) fails. The PR adds the automatic discovery of all the namenodes related to all the namespaces available according to the configs in hdfs-site.xml. ## How was this patch tested? manual tests in dockerized env Author: Marco Gaido Closes #21216 from mgaido91/SPARK-24149. --- docs/running-on-yarn.md | 9 ++- .../deploy/yarn/YarnSparkHadoopUtil.scala | 24 ++++++- .../yarn/YarnSparkHadoopUtilSuite.scala | 65 ++++++++++++++++++- 3 files changed, 93 insertions(+), 5 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index c9e68c3bfd056..4dbcbeafbbd9d 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -424,9 +424,12 @@ To use a custom metrics.properties for the application master and executors, upd Standard Kerberos support in Spark is covered in the [Security](security.html#kerberos) page. -In YARN mode, when accessing Hadoop file systems, aside from the service hosting the user's home -directory, Spark will also automatically obtain delegation tokens for the service hosting the -staging directory of the Spark application. +In YARN mode, when accessing Hadoop filesystems, Spark will automatically obtain delegation tokens +for: + +- the filesystem hosting the staging directory of the Spark application (which is the default + filesystem if `spark.yarn.stagingDir` is not set); +- if Hadoop federation is enabled, all the federated filesystems in the configuration. If an application needs to interact with other secure Hadoop filesystems, their URIs need to be explicitly provided to Spark at launch time. This is done by listing them in the diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 8eda6cb1277c5..7250e58b6c49a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -200,7 +200,29 @@ object YarnSparkHadoopUtil { .map(new Path(_).getFileSystem(hadoopConf)) .getOrElse(FileSystem.get(hadoopConf)) - filesystemsToAccess + stagingFS + // Add the list of available namenodes for all namespaces in HDFS federation. + // If ViewFS is enabled, this is skipped as ViewFS already handles delegation tokens for its + // namespaces. + val hadoopFilesystems = if (stagingFS.getScheme == "viewfs") { + Set.empty + } else { + val nameservices = hadoopConf.getTrimmedStrings("dfs.nameservices") + // Retrieving the filesystem for the nameservices where HA is not enabled + val filesystemsWithoutHA = nameservices.flatMap { ns => + Option(hadoopConf.get(s"dfs.namenode.rpc-address.$ns")).map { nameNode => + new Path(s"hdfs://$nameNode").getFileSystem(hadoopConf) + } + } + // Retrieving the filesystem for the nameservices where HA is enabled + val filesystemsWithHA = nameservices.flatMap { ns => + Option(hadoopConf.get(s"dfs.ha.namenodes.$ns")).map { _ => + new Path(s"hdfs://$ns").getFileSystem(hadoopConf) + } + } + (filesystemsWithoutHA ++ filesystemsWithHA).toSet + } + + filesystemsToAccess ++ hadoopFilesystems + stagingFS } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index f21353aa007c8..61c0c43f7c04f 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -21,7 +21,8 @@ import java.io.{File, IOException} import java.nio.charset.StandardCharsets import com.google.common.io.{ByteStreams, Files} -import org.apache.hadoop.io.Text +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.hadoop.yarn.api.records.ApplicationAccessType import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.Matchers @@ -141,4 +142,66 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging } + test("SPARK-24149: retrieve all namenodes from HDFS") { + val sparkConf = new SparkConf() + val basicFederationConf = new Configuration() + basicFederationConf.set("fs.defaultFS", "hdfs://localhost:8020") + basicFederationConf.set("dfs.nameservices", "ns1,ns2") + basicFederationConf.set("dfs.namenode.rpc-address.ns1", "localhost:8020") + basicFederationConf.set("dfs.namenode.rpc-address.ns2", "localhost:8021") + val basicFederationExpected = Set( + new Path("hdfs://localhost:8020").getFileSystem(basicFederationConf), + new Path("hdfs://localhost:8021").getFileSystem(basicFederationConf)) + val basicFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess( + sparkConf, basicFederationConf) + basicFederationResult should be (basicFederationExpected) + + // when viewfs is enabled, namespaces are handled by it, so we don't need to take care of them + val viewFsConf = new Configuration() + viewFsConf.addResource(basicFederationConf) + viewFsConf.set("fs.defaultFS", "viewfs://clusterX/") + viewFsConf.set("fs.viewfs.mounttable.clusterX.link./home", "hdfs://localhost:8020/") + val viewFsExpected = Set(new Path("viewfs://clusterX/").getFileSystem(viewFsConf)) + YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, viewFsConf) should be (viewFsExpected) + + // invalid config should not throw NullPointerException + val invalidFederationConf = new Configuration() + invalidFederationConf.addResource(basicFederationConf) + invalidFederationConf.unset("dfs.namenode.rpc-address.ns2") + val invalidFederationExpected = Set( + new Path("hdfs://localhost:8020").getFileSystem(invalidFederationConf)) + val invalidFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess( + sparkConf, invalidFederationConf) + invalidFederationResult should be (invalidFederationExpected) + + // no namespaces defined, ie. old case + val noFederationConf = new Configuration() + noFederationConf.set("fs.defaultFS", "hdfs://localhost:8020") + val noFederationExpected = Set( + new Path("hdfs://localhost:8020").getFileSystem(noFederationConf)) + val noFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, noFederationConf) + noFederationResult should be (noFederationExpected) + + // federation and HA enabled + val federationAndHAConf = new Configuration() + federationAndHAConf.set("fs.defaultFS", "hdfs://clusterXHA") + federationAndHAConf.set("dfs.nameservices", "clusterXHA,clusterYHA") + federationAndHAConf.set("dfs.ha.namenodes.clusterXHA", "x-nn1,x-nn2") + federationAndHAConf.set("dfs.ha.namenodes.clusterYHA", "y-nn1,y-nn2") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterXHA.x-nn1", "localhost:8020") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterXHA.x-nn2", "localhost:8021") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterYHA.y-nn1", "localhost:8022") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterYHA.y-nn2", "localhost:8023") + federationAndHAConf.set("dfs.client.failover.proxy.provider.clusterXHA", + "org.apache.hadoop.hdfs.server.namenode.ha.ConfiguredFailoverProxyProvider") + federationAndHAConf.set("dfs.client.failover.proxy.provider.clusterYHA", + "org.apache.hadoop.hdfs.server.namenode.ha.ConfiguredFailoverProxyProvider") + + val federationAndHAExpected = Set( + new Path("hdfs://clusterXHA").getFileSystem(federationAndHAConf), + new Path("hdfs://clusterYHA").getFileSystem(federationAndHAConf)) + val federationAndHAResult = YarnSparkHadoopUtil.hadoopFSsToAccess( + sparkConf, federationAndHAConf) + federationAndHAResult should be (federationAndHAExpected) + } } From a53ea70c1d8903cdff051edf667b0127c8131a09 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 18 May 2018 13:38:36 -0700 Subject: [PATCH 0830/2461] [SPARK-23856][SQL] Add an option `queryTimeout` in JDBCOptions ## What changes were proposed in this pull request? This pr added an option `queryTimeout` for the number of seconds the the driver will wait for a Statement object to execute. ## How was this patch tested? Added tests in `JDBCSuite`. Author: Takeshi Yamamuro Closes #21173 from maropu/SPARK-23856. --- docs/sql-programming-guide.md | 11 +++++++++++ .../org/apache/spark/sql/DataFrameReader.scala | 3 ++- .../datasources/jdbc/JDBCOptions.scala | 5 +++++ .../execution/datasources/jdbc/JDBCRDD.scala | 3 +++ .../jdbc/JdbcRelationProvider.scala | 2 +- .../execution/datasources/jdbc/JdbcUtils.scala | 16 +++++++++++++--- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 16 ++++++++++++++++ .../apache/spark/sql/jdbc/JDBCWriteSuite.scala | 18 ++++++++++++++++++ 8 files changed, 69 insertions(+), 5 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index b93d8531d9efe..f1ed316341b95 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1338,6 +1338,17 @@ the following case-insensitive options: + + + + +
    `spark.mllib` modelPMML model
    spark.mllib modelPMML model
    This configuration limits the number of remote requests to fetch blocks at any given point. When the number of hosts in the cluster increase, it might lead to very large number - of in-bound connections to one or more nodes, causing the workers to fail under load. + of inbound connections to one or more nodes, causing the workers to fail under load. By allowing it to limit the number of fetch requests, this scenario can be mitigated.
    4194304 (4 MB) The estimated cost to open a file, measured by the number of bytes could be scanned at the same - time. This is used when putting multiple files into a partition. It is better to over estimate, + time. This is used when putting multiple files into a partition. It is better to overestimate, then the partitions with small files will be faster than partitions with bigger files.
    0.8 for KUBERNETES mode; 0.8 for YARN mode; 0.0 for standalone mode and Mesos coarse-grained mode The minimum ratio of registered resources (registered resources / total expected resources) - (resources are executors in yarn mode and Kubernetes mode, CPU cores in standalone mode and Mesos coarsed-grained + (resources are executors in yarn mode and Kubernetes mode, CPU cores in standalone mode and Mesos coarse-grained mode ['spark.cores.max' value is total expected resources for Mesos coarse-grained mode] ) to wait for before scheduling begins. Specified as a double between 0.0 and 1.0. Regardless of whether the minimum ratio of resources has been reached, @@ -1634,7 +1634,7 @@ Apart from these, the following properties are also available, and may be useful false (Experimental) If set to "true", Spark will blacklist the executor immediately when a fetch - failure happenes. If external shuffle service is enabled, then the whole node will be + failure happens. If external shuffle service is enabled, then the whole node will be blacklisted.
    spark.streaming.receiver.writeAheadLog.enable false - Enable write ahead logs for receivers. All the input data received through receivers - will be saved to write ahead logs that will allow it to be recovered after driver failures. + Enable write-ahead logs for receivers. All the input data received through receivers + will be saved to write-ahead logs that will allow it to be recovered after driver failures. See the deployment guide in the Spark Streaming programing guide for more details. spark.streaming.driver.writeAheadLog.closeFileAfterWrite false - Whether to close the file after writing a write ahead log record on the driver. Set this to 'true' + Whether to close the file after writing a write-ahead log record on the driver. Set this to 'true' when you want to use S3 (or any file system that does not support flushing) for the metadata WAL on the driver. spark.streaming.receiver.writeAheadLog.closeFileAfterWrite false - Whether to close the file after writing a write ahead log record on the receivers. Set this to 'true' + Whether to close the file after writing a write-ahead log record on the receivers. Set this to 'true' when you want to use S3 (or any file system that does not support flushing) for the data WAL on the receivers. spark.mesos.constraints (none) - Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. This setting + Attribute-based constraints on mesos resource offers. By default, all resource offers will be accepted. This setting applies only to executors. Refer to Mesos Attributes & Resources for more information on attributes.
      diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index e07759a4dba87..ceda8a3ae2403 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -418,7 +418,7 @@ To use a custom metrics.properties for the application master and executors, upd - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. - In `cluster` mode, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `client` mode, only the Spark executors do. -- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. +- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example, you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. - The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. # Kerberos diff --git a/docs/security.md b/docs/security.md index 3e5607a9a0d67..8c0c66fb5a285 100644 --- a/docs/security.md +++ b/docs/security.md @@ -374,7 +374,7 @@ replaced with one of the above namespaces.
    ${ns}.enabledAlgorithms None - A comma separated list of ciphers. The specified ciphers must be supported by JVM. + A comma-separated list of ciphers. The specified ciphers must be supported by JVM.
    The reference list of protocols can be found in the "JSSE Cipher Suite Names" section of the Java security guide. The list for Java 8 can be found at diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 8fa643abf1373..f06e72a387df1 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -338,7 +338,7 @@ worker during one single schedule iteration. # Monitoring and Logging -Spark's standalone mode offers a web-based user interface to monitor the cluster. The master and each worker has its own web UI that shows cluster and job statistics. By default you can access the web UI for the master at port 8080. The port can be changed either in the configuration file or via command-line options. +Spark's standalone mode offers a web-based user interface to monitor the cluster. The master and each worker has its own web UI that shows cluster and job statistics. By default, you can access the web UI for the master at port 8080. The port can be changed either in the configuration file or via command-line options. In addition, detailed log output for each job is also written to the work directory of each slave node (`SPARK_HOME/work` by default). You will see two files for each job, `stdout` and `stderr`, with all output it wrote to its console. diff --git a/docs/sparkr.md b/docs/sparkr.md index 2909247e79e95..7fabab5d38f16 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -107,7 +107,7 @@ The following Spark driver properties can be set in `sparkConfig` with `sparkR.s With a `SparkSession`, applications can create `SparkDataFrame`s from a local R data frame, from a [Hive table](sql-programming-guide.html#hive-tables), or from other [data sources](sql-programming-guide.html#data-sources). ### From local data frames -The simplest way to create a data frame is to convert a local R data frame into a SparkDataFrame. Specifically we can use `as.DataFrame` or `createDataFrame` and pass in the local R data frame to create a SparkDataFrame. As an example, the following creates a `SparkDataFrame` based using the `faithful` dataset from R. +The simplest way to create a data frame is to convert a local R data frame into a SparkDataFrame. Specifically, we can use `as.DataFrame` or `createDataFrame` and pass in the local R data frame to create a SparkDataFrame. As an example, the following creates a `SparkDataFrame` based using the `faithful` dataset from R.
    {% highlight r %} @@ -169,7 +169,7 @@ df <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings {% endhighlight %}
    -The data sources API can also be used to save out SparkDataFrames into multiple file formats. For example we can save the SparkDataFrame from the previous example +The data sources API can also be used to save out SparkDataFrames into multiple file formats. For example, we can save the SparkDataFrame from the previous example to a Parquet file using `write.df`.
    @@ -241,7 +241,7 @@ head(filter(df, df$waiting < 50)) ### Grouping, Aggregation -SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below +SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example, we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below
    {% highlight r %} diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 9822d669050d5..55d35b9dd31db 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -165,7 +165,7 @@ In addition to simple column references and expressions, Datasets also have a ri
    -In Python it's possible to access a DataFrame's columns either by attribute +In Python, it's possible to access a DataFrame's columns either by attribute (`df.age`) or by indexing (`df['age']`). While the former is convenient for interactive data exploration, users are highly encouraged to use the latter form, which is future proof and won't break with column names that @@ -278,7 +278,7 @@ the bytes back into an object. Spark SQL supports two different methods for converting existing RDDs into Datasets. The first method uses reflection to infer the schema of an RDD that contains specific types of objects. This -reflection based approach leads to more concise code and works well when you already know the schema +reflection-based approach leads to more concise code and works well when you already know the schema while writing your Spark application. The second method for creating Datasets is through a programmatic interface that allows you to @@ -1243,7 +1243,7 @@ The following options can be used to configure the version of Hive that is used
    com.mysql.jdbc,
    org.postgresql,
    com.microsoft.sqlserver,
    oracle.jdbc

    - A comma separated list of class prefixes that should be loaded using the classloader that is + A comma-separated list of class prefixes that should be loaded using the classloader that is shared between Spark SQL and a specific version of Hive. An example of classes that should be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need to be shared are those that interact with classes that are already shared. For example, @@ -1441,7 +1441,7 @@ SELECT * FROM resultTable # Performance Tuning -For some workloads it is possible to improve performance by either caching data in memory, or by +For some workloads, it is possible to improve performance by either caching data in memory, or by turning on some experimental options. ## Caching Data In Memory @@ -1804,7 +1804,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.3 to 2.4 - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. - - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unabled to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. + - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unable to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. @@ -1966,11 +1966,11 @@ working with timestamps in `pandas_udf`s to get the best performance, see - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. All the arithmetic operations are affected by the change, ie. addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`) and positive module (`pmod`). - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible. - - In PySpark, `df.replace` does not allow to omit `value` when `to_replace` is not a dictionary. Previously, `value` could be omitted in the other cases and had `None` by default, which is counterintuitive and error prone. + - In PySpark, `df.replace` does not allow to omit `value` when `to_replace` is not a dictionary. Previously, `value` could be omitted in the other cases and had `None` by default, which is counterintuitive and error-prone. ## Upgrading From Spark SQL 2.1 to 2.2 - - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. + - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time-consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. - Since Spark 2.2.1 and 2.3.0, the schema is always inferred at runtime when the data source tables have the columns that exist in both partition schema and data schema. The inferred schema does not have the partitioned columns. When reading the table, Spark respects the partition values of these overlapping columns instead of the values stored in the data source files. In 2.2.0 and 2.1.x release, the inferred schema is partitioned but the data of the table is invisible to users (i.e., the result set is empty). @@ -2013,7 +2013,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 1.5 to 1.6 - - From Spark 1.6, by default the Thrift server runs in multi-session mode. Which means each JDBC/ODBC + - From Spark 1.6, by default, the Thrift server runs in multi-session mode. Which means each JDBC/ODBC connection owns a copy of their own SQL configuration and temporary function registry. Cached tables are still shared though. If you prefer to run the Thrift server in the old single-session mode, please set option `spark.sql.hive.thriftServer.singleSession` to `true`. You may either add @@ -2161,7 +2161,7 @@ been renamed to `DataFrame`. This is primarily because DataFrames no longer inhe directly, but instead provide most of the functionality that RDDs provide though their own implementation. DataFrames can still be converted to RDDs by calling the `.rdd` method. -In Scala there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for +In Scala, there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for some use cases. It is still recommended that users update their code to use `DataFrame` instead. Java and Python users will need to update their code. @@ -2170,11 +2170,11 @@ Java and Python users will need to update their code. Prior to Spark 1.3 there were separate Java compatible classes (`JavaSQLContext` and `JavaSchemaRDD`) that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users of either language should use `SQLContext` and `DataFrame`. In general these classes try to -use types that are usable from both languages (i.e. `Array` instead of language specific collections). +use types that are usable from both languages (i.e. `Array` instead of language-specific collections). In some cases where no common type exists (e.g., for passing in closures or Maps) function overloading is used instead. -Additionally the Java specific types API has been removed. Users of both Scala and Java should +Additionally, the Java specific types API has been removed. Users of both Scala and Java should use the classes present in `org.apache.spark.sql.types` to describe schema programmatically. @@ -2231,7 +2231,7 @@ referencing a singleton. ## Compatibility with Apache Hive Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. -Currently Hive SerDes and UDFs are based on Hive 1.2.1, +Currently, Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore (from 0.12.0 to 2.3.2. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). @@ -2323,10 +2323,10 @@ A handful of Hive optimizations are not yet included in Spark. Some of these (su less important due to Spark SQL's in-memory computational model. Others are slotted for future releases of Spark SQL. -* Block level bitmap indexes and virtual columns (used to build indexes) -* Automatically determine the number of reducers for joins and groupbys: Currently in Spark SQL, you +* Block-level bitmap indexes and virtual columns (used to build indexes) +* Automatically determine the number of reducers for joins and groupbys: Currently, in Spark SQL, you need to control the degree of parallelism post-shuffle using "`SET spark.sql.shuffle.partitions=[num_tasks];`". -* Meta-data only query: For queries that can be answered by using only meta data, Spark SQL still +* Meta-data only query: For queries that can be answered by using only metadata, Spark SQL still launches tasks to compute the result. * Skew data flag: Spark SQL does not follow the skew data flags in Hive. * `STREAMTABLE` hint in join: Spark SQL does not follow the `STREAMTABLE` hint. @@ -2983,6 +2983,6 @@ does not exactly match standard floating point semantics. Specifically: - NaN = NaN returns true. - - In aggregations all NaN values are grouped together. + - In aggregations, all NaN values are grouped together. - NaN is treated as a normal value in join keys. - NaN values go last when in ascending order, larger than any other numeric value. diff --git a/docs/storage-openstack-swift.md b/docs/storage-openstack-swift.md index 1dd54719b21aa..dacaa3438d489 100644 --- a/docs/storage-openstack-swift.md +++ b/docs/storage-openstack-swift.md @@ -39,7 +39,7 @@ For example, for Maven support, add the following to the pom.xml fi # Configuration Parameters Create core-site.xml and place it inside Spark's conf directory. -The main category of parameters that should be configured are the authentication parameters +The main category of parameters that should be configured is the authentication parameters required by Keystone. The following table contains a list of Keystone mandatory parameters. PROVIDER can be diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index 257a4f7d4f3ca..a1b6942ffe0a4 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -17,7 +17,7 @@ Choose a machine in your cluster such that - Flume can be configured to push data to a port on that machine. -Due to the push model, the streaming application needs to be up, with the receiver scheduled and listening on the chosen port, for Flume to be able push data. +Due to the push model, the streaming application needs to be up, with the receiver scheduled and listening on the chosen port, for Flume to be able to push data. #### Configuring Flume Configure Flume agent to send data to an Avro sink by having the following in the configuration file. @@ -100,7 +100,7 @@ Choose a machine that will run the custom sink in a Flume agent. The rest of the #### Configuring Flume Configuring Flume on the chosen machine requires the following two steps. -1. **Sink JARs**: Add the following JARs to Flume's classpath (see [Flume's documentation](https://flume.apache.org/documentation.html) to see how) in the machine designated to run the custom sink . +1. **Sink JARs**: Add the following JARs to Flume's classpath (see [Flume's documentation](https://flume.apache.org/documentation.html) to see how) in the machine designated to run the custom sink. (i) *Custom sink JAR*: Download the JAR corresponding to the following artifact (or [direct link](http://search.maven.org/remotecontent?filepath=org/apache/spark/spark-streaming-flume-sink_{{site.SCALA_BINARY_VERSION}}/{{site.SPARK_VERSION_SHORT}}/spark-streaming-flume-sink_{{site.SCALA_BINARY_VERSION}}-{{site.SPARK_VERSION_SHORT}}.jar)). @@ -128,7 +128,7 @@ Configuring Flume on the chosen machine requires the following two steps. agent.sinks.spark.port = agent.sinks.spark.channel = memoryChannel - Also make sure that the upstream Flume pipeline is configured to send the data to the Flume agent running this sink. + Also, make sure that the upstream Flume pipeline is configured to send the data to the Flume agent running this sink. See the [Flume's documentation](https://flume.apache.org/documentation.html) for more information about configuring Flume agents. diff --git a/docs/streaming-kafka-0-8-integration.md b/docs/streaming-kafka-0-8-integration.md index 9f0671da2ee31..becf217738d26 100644 --- a/docs/streaming-kafka-0-8-integration.md +++ b/docs/streaming-kafka-0-8-integration.md @@ -10,7 +10,7 @@ Here we explain how to configure Spark Streaming to receive data from Kafka. The ## Approach 1: Receiver-based Approach This approach uses a Receiver to receive the data. The Receiver is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming processes the data. -However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. +However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write-Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write-ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write-Ahead Logs. Next, we discuss how to use this approach in your streaming application. @@ -55,11 +55,11 @@ Next, we discuss how to use this approach in your streaming application. **Points to remember:** - - Topic partitions in Kafka does not correlate to partitions of RDDs generated in Spark Streaming. So increasing the number of topic-specific partitions in the `KafkaUtils.createStream()` only increases the number of threads using which topics that are consumed within a single receiver. It does not increase the parallelism of Spark in processing the data. Refer to the main document for more information on that. + - Topic partitions in Kafka do not correlate to partitions of RDDs generated in Spark Streaming. So increasing the number of topic-specific partitions in the `KafkaUtils.createStream()` only increases the number of threads using which topics that are consumed within a single receiver. It does not increase the parallelism of Spark in processing the data. Refer to the main document for more information on that. - Multiple Kafka input DStreams can be created with different groups and topics for parallel receiving of data using multiple receivers. - - If you have enabled Write Ahead Logs with a replicated file system like HDFS, the received data is already being replicated in the log. Hence, the storage level in storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER` (that is, use + - If you have enabled Write-Ahead Logs with a replicated file system like HDFS, the received data is already being replicated in the log. Hence, the storage level in storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER` (that is, use `KafkaUtils.createStream(..., StorageLevel.MEMORY_AND_DISK_SER)`). 3. **Deploying:** As with any Spark applications, `spark-submit` is used to launch your application. However, the details are slightly different for Scala/Java applications and Python applications. @@ -80,9 +80,9 @@ This approach has the following advantages over the receiver-based approach (i.e - *Simplified Parallelism:* No need to create multiple input Kafka streams and union them. With `directStream`, Spark Streaming will create as many RDD partitions as there are Kafka partitions to consume, which will all read data from Kafka in parallel. So there is a one-to-one mapping between Kafka and RDD partitions, which is easier to understand and tune. -- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. +- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write-Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write-Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write-Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. -- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semantics of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). +- *Exactly-once semantics:* The first approach uses Kafka's high-level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with-write-ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semantics of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). Note that one disadvantage of this approach is that it does not update offsets in Zookeeper, hence Zookeeper-based Kafka monitoring tools will not show progress. However, you can access the offsets processed by this approach in each batch and update Zookeeper yourself (see below). diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index ffda36d64a770..c30959263cdfa 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1461,7 +1461,7 @@ Note that the connections in the pool should be lazily created on demand and tim *** ## DataFrame and SQL Operations -You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SparkSession using the SparkContext that the StreamingContext is using. Furthermore this has to done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SparkSession. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL. +You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SparkSession using the SparkContext that the StreamingContext is using. Furthermore, this has to done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SparkSession. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL.

    @@ -2010,10 +2010,10 @@ To run a Spark Streaming applications, you need to have the following. + *Mesos* - [Marathon](https://github.com/mesosphere/marathon) has been used to achieve this with Mesos. -- *Configuring write ahead logs* - Since Spark 1.2, - we have introduced _write ahead logs_ for achieving strong +- *Configuring write-ahead logs* - Since Spark 1.2, + we have introduced _write-ahead logs_ for achieving strong fault-tolerance guarantees. If enabled, all the data received from a receiver gets written into - a write ahead log in the configuration checkpoint directory. This prevents data loss on driver + a write-ahead log in the configuration checkpoint directory. This prevents data loss on driver recovery, thus ensuring zero data loss (discussed in detail in the [Fault-tolerance Semantics](#fault-tolerance-semantics) section). This can be enabled by setting the [configuration parameter](configuration.html#spark-streaming) @@ -2021,15 +2021,15 @@ To run a Spark Streaming applications, you need to have the following. come at the cost of the receiving throughput of individual receivers. This can be corrected by running [more receivers in parallel](#level-of-parallelism-in-data-receiving) to increase aggregate throughput. Additionally, it is recommended that the replication of the - received data within Spark be disabled when the write ahead log is enabled as the log is already + received data within Spark be disabled when the write-ahead log is enabled as the log is already stored in a replicated storage system. This can be done by setting the storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER`. While using S3 (or any file system that - does not support flushing) for _write ahead logs_, please remember to enable + does not support flushing) for _write-ahead logs_, please remember to enable `spark.streaming.driver.writeAheadLog.closeFileAfterWrite` and `spark.streaming.receiver.writeAheadLog.closeFileAfterWrite`. See [Spark Streaming Configuration](configuration.html#spark-streaming) for more details. - Note that Spark will not encrypt data written to the write ahead log when I/O encryption is - enabled. If encryption of the write ahead log data is desired, it should be stored in a file + Note that Spark will not encrypt data written to the write-ahead log when I/O encryption is + enabled. If encryption of the write-ahead log data is desired, it should be stored in a file system that supports encryption natively. - *Setting the max receiving rate* - If the cluster resources is not large enough for the streaming @@ -2284,9 +2284,9 @@ Having bigger blockinterval means bigger blocks. A high value of `spark.locality - Instead of relying on batchInterval and blockInterval, you can define the number of partitions by calling `inputDstream.repartition(n)`. This reshuffles the data in RDD randomly to create n number of partitions. Yes, for greater parallelism. Though comes at the cost of a shuffle. An RDD's processing is scheduled by driver's jobscheduler as a job. At a given point of time only one job is active. So, if one job is executing the other jobs are queued. -- If you have two dstreams there will be two RDDs formed and there will be two jobs created which will be scheduled one after the another. To avoid this, you can union two dstreams. This will ensure that a single unionRDD is formed for the two RDDs of the dstreams. This unionRDD is then considered as a single job. However the partitioning of the RDDs is not impacted. +- If you have two dstreams there will be two RDDs formed and there will be two jobs created which will be scheduled one after the another. To avoid this, you can union two dstreams. This will ensure that a single unionRDD is formed for the two RDDs of the dstreams. This unionRDD is then considered as a single job. However, the partitioning of the RDDs is not impacted. -- If the batch processing time is more than batchinterval then obviously the receiver's memory will start filling up and will end up in throwing exceptions (most probably BlockNotFoundException). Currently there is no way to pause the receiver. Using SparkConf configuration `spark.streaming.receiver.maxRate`, rate of receiver can be limited. +- If the batch processing time is more than batchinterval then obviously the receiver's memory will start filling up and will end up in throwing exceptions (most probably BlockNotFoundException). Currently, there is no way to pause the receiver. Using SparkConf configuration `spark.streaming.receiver.maxRate`, rate of receiver can be limited. *************************************************************************************************** @@ -2388,7 +2388,7 @@ then besides these losses, all of the past data that was received and replicated lost. This will affect the results of the stateful transformations. To avoid this loss of past received data, Spark 1.2 introduced _write -ahead logs_ which save the received data to fault-tolerant storage. With the [write ahead logs +ahead logs_ which save the received data to fault-tolerant storage. With the [write-ahead logs enabled](#deploying-applications) and reliable receivers, there is zero data loss. In terms of semantics, it provides an at-least once guarantee. The following table summarizes the semantics under failures: @@ -2402,7 +2402,7 @@ The following table summarizes the semantics under failures:
    Spark 1.1 or earlier, OR
    - Spark 1.2 or later without write ahead logs + Spark 1.2 or later without write-ahead logs
    Buffered data lost with unreliable receivers
    @@ -2416,7 +2416,7 @@ The following table summarizes the semantics under failures:
    Spark 1.2 or later with write ahead logsSpark 1.2 or later with write-ahead logs Zero data loss with reliable receivers
    At-least once semantics diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index 5647ec6bc5797..71fd5b10cc407 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -15,7 +15,7 @@ For Scala/Java applications using SBT/Maven project definitions, link your appli For Python applications, you need to add this above library and its dependencies when deploying your application. See the [Deploying](#deploying) subsection below. -For experimenting on `spark-shell`, you need to add this above library and its dependencies too when invoking `spark-shell`. Also see the [Deploying](#deploying) subsection below. +For experimenting on `spark-shell`, you need to add this above library and its dependencies too when invoking `spark-shell`. Also, see the [Deploying](#deploying) subsection below. ## Reading Data from Kafka diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 9a83f157452ad..602a4c70848e7 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -8,7 +8,7 @@ title: Structured Streaming Programming Guide {:toc} # Overview -Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* +Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write-Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* Internally, by default, Structured Streaming queries are processed using a *micro-batch processing* engine, which processes data streams as a series of small batch jobs thereby achieving end-to-end latencies as low as 100 milliseconds and exactly-once fault-tolerance guarantees. However, since Spark 2.3, we have introduced a new low-latency processing mode called **Continuous Processing**, which can achieve end-to-end latencies as low as 1 millisecond with at-least-once guarantees. Without changing the Dataset/DataFrame operations in your queries, you will be able to choose the mode based on your application requirements. @@ -479,7 +479,7 @@ detail in the [Window Operations](#window-operations-on-event-time) section. ## Fault Tolerance Semantics Delivering end-to-end exactly-once semantics was one of key goals behind the design of Structured Streaming. To achieve that, we have designed the Structured Streaming sources, the sinks and the execution engine to reliably track the exact progress of the processing so that it can handle any kind of failure by restarting and/or reprocessing. Every streaming source is assumed to have offsets (similar to Kafka offsets, or Kinesis sequence numbers) -to track the read position in the stream. The engine uses checkpointing and write ahead logs to record the offset range of the data being processed in each trigger. The streaming sinks are designed to be idempotent for handling reprocessing. Together, using replayable sources and idempotent sinks, Structured Streaming can ensure **end-to-end exactly-once semantics** under any failure. +to track the read position in the stream. The engine uses checkpointing and write-ahead logs to record the offset range of the data being processed in each trigger. The streaming sinks are designed to be idempotent for handling reprocessing. Together, using replayable sources and idempotent sinks, Structured Streaming can ensure **end-to-end exactly-once semantics** under any failure. # API using Datasets and DataFrames Since Spark 2.0, DataFrames and Datasets can represent static, bounded data, as well as streaming, unbounded data. Similar to static Datasets/DataFrames, you can use the common entry point `SparkSession` @@ -690,7 +690,7 @@ These examples generate streaming DataFrames that are untyped, meaning that the By default, Structured Streaming from file based sources requires you to specify the schema, rather than rely on Spark to infer it automatically. This restriction ensures a consistent schema will be used for the streaming query, even in the case of failures. For ad-hoc use cases, you can reenable schema inference by setting `spark.sql.streaming.schemaInference` to `true`. -Partition discovery does occur when subdirectories that are named `/key=value/` are present and listing will automatically recurse into these directories. If these columns appear in the user provided schema, they will be filled in by Spark based on the path of the file being read. The directories that make up the partitioning scheme must be present when the query starts and must remain static. For example, it is okay to add `/data/year=2016/` when `/data/year=2015/` was present, but it is invalid to change the partitioning column (i.e. by creating the directory `/data/date=2016-04-17/`). +Partition discovery does occur when subdirectories that are named `/key=value/` are present and listing will automatically recurse into these directories. If these columns appear in the user-provided schema, they will be filled in by Spark based on the path of the file being read. The directories that make up the partitioning scheme must be present when the query starts and must remain static. For example, it is okay to add `/data/year=2016/` when `/data/year=2015/` was present, but it is invalid to change the partitioning column (i.e. by creating the directory `/data/date=2016-04-17/`). ## Operations on streaming DataFrames/Datasets You can apply all kinds of operations on streaming DataFrames/Datasets – ranging from untyped, SQL-like operations (e.g. `select`, `where`, `groupBy`), to typed RDD-like operations (e.g. `map`, `filter`, `flatMap`). See the [SQL programming guide](sql-programming-guide.html) for more details. Let’s take a look at a few example operations that you can use. @@ -2661,7 +2661,7 @@ sql("SET spark.sql.streaming.metricsEnabled=true") All queries started in the SparkSession after this configuration has been enabled will report metrics through Dropwizard to whatever [sinks](monitoring.html#metrics) have been configured (e.g. Ganglia, Graphite, JMX, etc.). ## Recovering from Failures with Checkpointing -In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. This checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries). +In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write-ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. This checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries).
    diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index a3643bf0838a1..77aa083c4a584 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -177,7 +177,7 @@ The master URL passed to Spark can be in one of the following formats: # Loading Configuration from a File The `spark-submit` script can load default [Spark configuration values](configuration.html) from a -properties file and pass them on to your application. By default it will read options +properties file and pass them on to your application. By default, it will read options from `conf/spark-defaults.conf` in the Spark directory. For more detail, see the section on [loading default configurations](configuration.html#loading-default-configurations). diff --git a/docs/tuning.md b/docs/tuning.md index fc27713f28d46..912c39879be8f 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -196,7 +196,7 @@ To further tune garbage collection, we first need to understand some basic infor * A simplified description of the garbage collection procedure: When Eden is full, a minor GC is run on Eden and objects that are alive from Eden and Survivor1 are copied to Survivor2. The Survivor regions are swapped. If an object is old - enough or Survivor2 is full, it is moved to Old. Finally when Old is close to full, a full GC is invoked. + enough or Survivor2 is full, it is moved to Old. Finally, when Old is close to full, a full GC is invoked. The goal of GC tuning in Spark is to ensure that only long-lived RDDs are stored in the Old generation and that the Young generation is sufficiently sized to store short-lived objects. This will help avoid full GCs to collect diff --git a/python/README.md b/python/README.md index 3f17fdb98a081..2e0112da58b94 100644 --- a/python/README.md +++ b/python/README.md @@ -22,7 +22,7 @@ This packaging is currently experimental and may change in future versions (alth Using PySpark requires the Spark JARs, and if you are building this from source please see the builder instructions at ["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). -The Python packaging for Spark is not intended to replace all of the other use cases. This Python packaged version of Spark is suitable for interacting with an existing cluster (be it Spark standalone, YARN, or Mesos) - but does not contain the tools required to setup your own standalone Spark cluster. You can download the full version of Spark from the [Apache Spark downloads page](http://spark.apache.org/downloads.html). +The Python packaging for Spark is not intended to replace all of the other use cases. This Python packaged version of Spark is suitable for interacting with an existing cluster (be it Spark standalone, YARN, or Mesos) - but does not contain the tools required to set up your own standalone Spark cluster. You can download the full version of Spark from the [Apache Spark downloads page](http://spark.apache.org/downloads.html). **NOTE:** If you are using this with a Spark standalone cluster you must ensure that the version (including minor version) matches or you may experience odd errors. diff --git a/sql/README.md b/sql/README.md index fe1d352050c09..70cc7c637b58d 100644 --- a/sql/README.md +++ b/sql/README.md @@ -6,7 +6,7 @@ This module provides support for executing relational queries expressed in eithe Spark SQL is broken up into four subprojects: - Catalyst (sql/catalyst) - An implementation-agnostic framework for manipulating trees of relational operators and expressions. - Execution (sql/core) - A query planner / execution engine for translating Catalyst's logical query plans into Spark RDDs. This component also includes a new public interface, SQLContext, that allows users to execute SQL or LINQ statements against existing RDDs and Parquet files. - - Hive Support (sql/hive) - Includes an extension of SQLContext called HiveContext that allows users to write queries using a subset of HiveQL and access data from a Hive Metastore using Hive SerDes. There are also wrappers that allows users to run queries that include Hive UDFs, UDAFs, and UDTFs. + - Hive Support (sql/hive) - Includes an extension of SQLContext called HiveContext that allows users to write queries using a subset of HiveQL and access data from a Hive Metastore using Hive SerDes. There are also wrappers that allow users to run queries that include Hive UDFs, UDAFs, and UDTFs. - HiveServer and CLI support (sql/hive-thriftserver) - Includes support for the SQL CLI (bin/spark-sql) and a HiveServer2 (for JDBC/ODBC) compatible server. Running `sql/create-docs.sh` generates SQL documentation for built-in functions under `sql/site`. From 94524019315ad463f9bc13c107131091d17c6af9 Mon Sep 17 00:00:00 2001 From: Yuchen Huo Date: Fri, 6 Apr 2018 08:35:20 -0700 Subject: [PATCH 0573/2461] [SPARK-23822][SQL] Improve error message for Parquet schema mismatches ## What changes were proposed in this pull request? This pull request tries to improve the error message for spark while reading parquet files with different schemas, e.g. One with a STRING column and the other with a INT column. A new ParquetSchemaColumnConvertNotSupportedException is added to replace the old UnsupportedOperationException. The Exception is again wrapped in FileScanRdd.scala to throw a more a general QueryExecutionException with the actual parquet file name which trigger the exception. ## How was this patch tested? Unit tests added to check the new exception and verify the error messages. Also manually tested with two parquet with different schema to check the error message. screen shot 2018-03-30 at 4 03 04 pm Author: Yuchen Huo Closes #20953 from yuchenhuo/SPARK-23822. --- ...emaColumnConvertNotSupportedException.java | 62 +++++++++++++++++++ .../parquet/VectorizedColumnReader.java | 38 ++++++++---- .../execution/QueryExecutionException.scala | 3 +- .../execution/datasources/FileScanRDD.scala | 21 ++++++- .../parquet/ParquetSchemaSuite.scala | 55 ++++++++++++++++ 5 files changed, 166 insertions(+), 13 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java new file mode 100644 index 0000000000000..82a1169cbe7ae --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * Exception thrown when the parquet reader find column type mismatches. + */ +@InterfaceStability.Unstable +public class SchemaColumnConvertNotSupportedException extends RuntimeException { + + /** + * Name of the column which cannot be converted. + */ + private String column; + /** + * Physical column type in the actual parquet file. + */ + private String physicalType; + /** + * Logical column type in the parquet schema the parquet reader use to parse all files. + */ + private String logicalType; + + public String getColumn() { + return column; + } + + public String getPhysicalType() { + return physicalType; + } + + public String getLogicalType() { + return logicalType; + } + + public SchemaColumnConvertNotSupportedException( + String column, + String physicalType, + String logicalType) { + super(); + this.column = column; + this.physicalType = physicalType; + this.logicalType = logicalType; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 47dd625f4b154..72f1d024b08ce 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet; import java.io.IOException; +import java.util.Arrays; import java.util.TimeZone; import org.apache.parquet.bytes.BytesUtils; @@ -31,6 +32,7 @@ import org.apache.parquet.schema.PrimitiveType; import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DecimalType; @@ -231,6 +233,18 @@ private boolean shouldConvertTimestamps() { return convertTz != null && !convertTz.equals(UTC); } + /** + * Helper function to construct exception for parquet schema mismatch. + */ + private SchemaColumnConvertNotSupportedException constructConvertNotSupportedException( + ColumnDescriptor descriptor, + WritableColumnVector column) { + return new SchemaColumnConvertNotSupportedException( + Arrays.toString(descriptor.getPath()), + descriptor.getType().toString(), + column.dataType().toString()); + } + /** * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`. */ @@ -261,7 +275,7 @@ private void decodeDictionaryIds( } } } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } break; @@ -282,7 +296,7 @@ private void decodeDictionaryIds( } } } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } break; @@ -321,7 +335,7 @@ private void decodeDictionaryIds( } } } else { - throw new UnsupportedOperationException(); + throw constructConvertNotSupportedException(descriptor, column); } break; case BINARY: @@ -360,7 +374,7 @@ private void decodeDictionaryIds( } } } else { - throw new UnsupportedOperationException(); + throw constructConvertNotSupportedException(descriptor, column); } break; @@ -375,7 +389,9 @@ private void decodeDictionaryIds( */ private void readBooleanBatch(int rowId, int num, WritableColumnVector column) { - assert(column.dataType() == DataTypes.BooleanType); + if (column.dataType() != DataTypes.BooleanType) { + throw constructConvertNotSupportedException(descriptor, column); + } defColumn.readBooleans( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } @@ -394,7 +410,7 @@ private void readIntBatch(int rowId, int num, WritableColumnVector column) { defColumn.readShorts( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -414,7 +430,7 @@ private void readLongBatch(int rowId, int num, WritableColumnVector column) { } } } else { - throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -425,7 +441,7 @@ private void readFloatBatch(int rowId, int num, WritableColumnVector column) { defColumn.readFloats( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else { - throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -436,7 +452,7 @@ private void readDoubleBatch(int rowId, int num, WritableColumnVector column) { defColumn.readDoubles( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -471,7 +487,7 @@ private void readBinaryBatch(int rowId, int num, WritableColumnVector column) { } } } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -510,7 +526,7 @@ private void readFixedLenByteArrayBatch( } } } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala index 16806c620635f..cffd97baea6a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala @@ -17,4 +17,5 @@ package org.apache.spark.sql.execution -class QueryExecutionException(message: String) extends Exception(message) +class QueryExecutionException(message: String, cause: Throwable = null) + extends Exception(message, cause) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 835ce98462477..28c36b6020d33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -21,11 +21,14 @@ import java.io.{FileNotFoundException, IOException} import scala.collection.mutable +import org.apache.parquet.io.ParquetDecodingException + import org.apache.spark.{Partition => RDDPartition, TaskContext, TaskKilledException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{InputFileBlockHolder, RDD} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.NextIterator @@ -179,7 +182,23 @@ class FileScanRDD( currentIterator = readCurrentFile() } - hasNext + try { + hasNext + } catch { + case e: SchemaColumnConvertNotSupportedException => + val message = "Parquet column cannot be converted in " + + s"file ${currentFile.filePath}. Column: ${e.getColumn}, " + + s"Expected: ${e.getLogicalType}, Found: ${e.getPhysicalType}" + throw new QueryExecutionException(message, e) + case e: ParquetDecodingException => + if (e.getMessage.contains("Can not read value at")) { + val message = "Encounter error while reading parquet files. " + + "One possible cause: Parquet column cannot be converted in the " + + "corresponding files. Details: " + throw new QueryExecutionException(message, e) + } + throw e + } } else { currentFile = null InputFileBlockHolder.unset() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 2cd2a600f2b97..9d3dfae348beb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -20,10 +20,13 @@ package org.apache.spark.sql.execution.datasources.parquet import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import org.apache.parquet.io.ParquetDecodingException import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.execution.QueryExecutionException +import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -382,6 +385,58 @@ class ParquetSchemaSuite extends ParquetSchemaTest { } } + // ======================================= + // Tests for parquet schema mismatch error + // ======================================= + def testSchemaMismatch(path: String, vectorizedReaderEnabled: Boolean): SparkException = { + import testImplicits._ + + var e: SparkException = null + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorizedReaderEnabled.toString) { + // Create two parquet files with different schemas in the same folder + Seq(("bcd", 2)).toDF("a", "b").coalesce(1).write.mode("overwrite").parquet(s"$path/parquet") + Seq((1, "abc")).toDF("a", "b").coalesce(1).write.mode("append").parquet(s"$path/parquet") + + e = intercept[SparkException] { + spark.read.parquet(s"$path/parquet").collect() + } + } + e + } + + test("schema mismatch failure error message for parquet reader") { + withTempPath { dir => + val e = testSchemaMismatch(dir.getCanonicalPath, vectorizedReaderEnabled = false) + val expectedMessage = "Encounter error while reading parquet files. " + + "One possible cause: Parquet column cannot be converted in the corresponding " + + "files. Details:" + assert(e.getCause.isInstanceOf[QueryExecutionException]) + assert(e.getCause.getCause.isInstanceOf[ParquetDecodingException]) + assert(e.getCause.getMessage.startsWith(expectedMessage)) + } + } + + test("schema mismatch failure error message for parquet vectorized reader") { + withTempPath { dir => + val e = testSchemaMismatch(dir.getCanonicalPath, vectorizedReaderEnabled = true) + assert(e.getCause.isInstanceOf[QueryExecutionException]) + assert(e.getCause.getCause.isInstanceOf[SchemaColumnConvertNotSupportedException]) + + // Check if the physical type is reporting correctly + val errMsg = e.getCause.getMessage + assert(errMsg.startsWith("Parquet column cannot be converted in file")) + val file = errMsg.substring("Parquet column cannot be converted in file ".length, + errMsg.indexOf(". ")) + val col = spark.read.parquet(file).schema.fields.filter(_.name.equals("a")) + assert(col.length == 1) + if (col(0).dataType == StringType) { + assert(errMsg.contains("Column: [a], Expected: IntegerType, Found: BINARY")) + } else { + assert(errMsg.endsWith("Column: [a], Expected: StringType, Found: INT32")) + } + } + } + // ======================================================= // Tests for converting Parquet LIST to Catalyst ArrayType // ======================================================= From d766ea2ff2bf59afbd631d3cc2e43bebfccdebed Mon Sep 17 00:00:00 2001 From: Li Jin Date: Sat, 7 Apr 2018 00:15:54 +0800 Subject: [PATCH 0574/2461] [SPARK-23861][SQL][DOC] Clarify default window frame with and without orderBy clause ## What changes were proposed in this pull request? Add docstring to clarify default window frame boundaries with and without orderBy clause ## How was this patch tested? Manually generate doc and check. Author: Li Jin Closes #20978 from icexelloss/SPARK-23861-window-doc. --- python/pyspark/sql/window.py | 4 ++++ .../main/scala/org/apache/spark/sql/expressions/Window.scala | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index e667fba099fb9..d19ced954f04e 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -44,6 +44,10 @@ class Window(object): >>> # PARTITION BY country ORDER BY date RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING >>> window = Window.orderBy("date").partitionBy("country").rangeBetween(-3, 3) + .. note:: When ordering is not defined, an unbounded window frame (rowFrame, + unboundedPreceding, unboundedFollowing) is used by default. When ordering is defined, + a growing window frame (rangeFrame, unboundedPreceding, currentRow) is used by default. + .. note:: Experimental .. versionadded:: 1.4 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index 1caa243f8d118..cd819bab1b14c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -33,6 +33,10 @@ import org.apache.spark.sql.catalyst.expressions._ * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) * }}} * + * @note When ordering is not defined, an unbounded window frame (rowFrame, unboundedPreceding, + * unboundedFollowing) is used by default. When ordering is defined, a growing window frame + * (rangeFrame, unboundedPreceding, currentRow) is used by default. + * * @since 1.4.0 */ @InterfaceStability.Stable From c926acf719a6deb9d884a0f19bde075c312bfe5a Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 6 Apr 2018 18:42:14 +0200 Subject: [PATCH 0575/2461] [SPARK-23882][CORE] UTF8StringSuite.writeToOutputStreamUnderflow() is not expected to be supported ## What changes were proposed in this pull request? This PR excludes an existing UT [`writeToOutputStreamUnderflow()`](https://github.com/apache/spark/blob/master/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java#L519-L532) in `UTF8StringSuite`. As discussed [here](https://github.com/apache/spark/pull/19222#discussion_r177692142), the behavior of this test looks surprising. This test seems to access metadata area of the JVM object where is reserved by `Platform.BYTE_ARRAY_OFFSET`. This test is introduced thru #16089 by NathanHowell. More specifically, [the commit](https://github.com/apache/spark/pull/16089/commits/27c102deb1701fe62f776fe4da61dac959270b73) `Improve test coverage of UTFString.write` introduced this UT. However, I cannot find any discussion about this UT. I think that it would be good to exclude this UT. ```java public void writeToOutputStreamUnderflow() throws IOException { // offset underflow is apparently supported? final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8); for (int i = 1; i <= Platform.BYTE_ARRAY_OFFSET; ++i) { new UTF8String( new ByteArrayMemoryBlock(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i)) .writeTo(outputStream); final ByteBuffer buffer = ByteBuffer.wrap(outputStream.toByteArray(), i, test.length); assertEquals("01234567", StandardCharsets.UTF_8.decode(buffer).toString()); outputStream.reset(); } } ``` ## How was this patch tested? Existing UTs Author: Kazuaki Ishizaki Closes #20995 from kiszk/SPARK-23882. --- .../spark/unsafe/types/UTF8StringSuite.java | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index bad908fcaf136..652c40a35527f 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -515,22 +515,6 @@ public void soundex() { assertEquals(fromString("世界千世").soundex(), fromString("世界千世")); } - @Test - public void writeToOutputStreamUnderflow() throws IOException { - // offset underflow is apparently supported? - final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8); - - for (int i = 1; i <= Platform.BYTE_ARRAY_OFFSET; ++i) { - new UTF8String( - new ByteArrayMemoryBlock(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i)) - .writeTo(outputStream); - final ByteBuffer buffer = ByteBuffer.wrap(outputStream.toByteArray(), i, test.length); - assertEquals("01234567", StandardCharsets.UTF_8.decode(buffer).toString()); - outputStream.reset(); - } - } - @Test public void writeToOutputStreamSlice() throws IOException { final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); From d23a805f975f209f273db2b52de3f336be17d873 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Fri, 6 Apr 2018 10:09:55 -0700 Subject: [PATCH 0576/2461] [SPARK-23859][ML] Initial PR for Instrumentation improvements: UUID and logging levels ## What changes were proposed in this pull request? Initial PR for Instrumentation improvements: UUID and logging levels. This PR takes over #20837 Closes #20837 ## How was this patch tested? Manual. Author: Bago Amirbekian Author: WeichenXu Closes #20982 from WeichenXu123/better-instrumentation. --- .../classification/LogisticRegression.scala | 15 ++++--- .../spark/ml/util/Instrumentation.scala | 40 +++++++++++++++---- 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 3ae4db3f3f965..ee4b01058c75c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -517,6 +517,9 @@ class LogisticRegression @Since("1.2.0") ( (new MultivariateOnlineSummarizer, new MultiClassSummarizer) )(seqOp, combOp, $(aggregationDepth)) } + instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count) + instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString) + instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString) val histogram = labelSummarizer.histogram val numInvalid = labelSummarizer.countInvalid @@ -560,15 +563,15 @@ class LogisticRegression @Since("1.2.0") ( if (numInvalid != 0) { val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " + s"Found $numInvalid invalid labels." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } val isConstantLabel = histogram.count(_ != 0.0) == 1 if ($(fitIntercept) && isConstantLabel && !usingBoundConstrainedOptimization) { - logWarning(s"All labels are the same value and fitIntercept=true, so the coefficients " + - s"will be zeros. Training is not needed.") + instr.logWarning(s"All labels are the same value and fitIntercept=true, so the " + + s"coefficients will be zeros. Training is not needed.") val constantLabelIndex = Vectors.dense(histogram).argmax val coefMatrix = new SparseMatrix(numCoefficientSets, numFeatures, new Array[Int](numCoefficientSets + 1), Array.empty[Int], Array.empty[Double], @@ -581,7 +584,7 @@ class LogisticRegression @Since("1.2.0") ( (coefMatrix, interceptVec, Array.empty[Double]) } else { if (!$(fitIntercept) && isConstantLabel) { - logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " + + instr.logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " + s"dangerous ground, so the algorithm may not converge.") } @@ -590,7 +593,7 @@ class LogisticRegression @Since("1.2.0") ( if (!$(fitIntercept) && (0 until numFeatures).exists { i => featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) { - logWarning("Fitting LogisticRegressionModel without intercept on dataset with " + + instr.logWarning("Fitting LogisticRegressionModel without intercept on dataset with " + "constant nonzero column, Spark MLlib outputs zero coefficients for constant " + "nonzero columns. This behavior is the same as R glmnet but different from LIBSVM.") } @@ -708,7 +711,7 @@ class LogisticRegression @Since("1.2.0") ( (_initialModel.interceptVector.size == numCoefficientSets) && (_initialModel.getFitIntercept == $(fitIntercept)) if (!modelIsValid) { - logWarning(s"Initial coefficients will be ignored! Its dimensions " + + instr.logWarning(s"Initial coefficients will be ignored! Its dimensions " + s"(${providedCoefs.numRows}, ${providedCoefs.numCols}) did not match the " + s"expected size ($numCoefficientSets, $numFeatures)") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index 7c46f45c59717..e694bc27b2f1e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.util -import java.util.concurrent.atomic.AtomicLong +import java.util.UUID import org.json4s._ import org.json4s.JsonDSL._ @@ -42,7 +42,7 @@ import org.apache.spark.sql.Dataset private[spark] class Instrumentation[E <: Estimator[_]] private ( estimator: E, dataset: RDD[_]) extends Logging { - private val id = Instrumentation.counter.incrementAndGet() + private val id = UUID.randomUUID() private val prefix = { val className = estimator.getClass.getSimpleName s"$className-${estimator.uid}-${dataset.hashCode()}-$id: " @@ -56,12 +56,31 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( } /** - * Logs a message with a prefix that uniquely identifies the training session. + * Logs a warning message with a prefix that uniquely identifies the training session. */ - def log(msg: String): Unit = { - logInfo(prefix + msg) + override def logWarning(msg: => String): Unit = { + super.logWarning(prefix + msg) } + /** + * Logs a error message with a prefix that uniquely identifies the training session. + */ + override def logError(msg: => String): Unit = { + super.logError(prefix + msg) + } + + /** + * Logs an info message with a prefix that uniquely identifies the training session. + */ + override def logInfo(msg: => String): Unit = { + super.logInfo(prefix + msg) + } + + /** + * Alias for logInfo, see above. + */ + def log(msg: String): Unit = logInfo(msg) + /** * Logs the value of the given parameters for the estimator being used in this session. */ @@ -77,11 +96,11 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( } def logNumFeatures(num: Long): Unit = { - log(compact(render("numFeatures" -> num))) + logNamedValue(Instrumentation.loggerTags.numFeatures, num) } def logNumClasses(num: Long): Unit = { - log(compact(render("numClasses" -> num))) + logNamedValue(Instrumentation.loggerTags.numClasses, num) } /** @@ -107,7 +126,12 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( * Some common methods for logging information about a training session. */ private[spark] object Instrumentation { - private val counter = new AtomicLong(0) + + object loggerTags { + val numFeatures = "numFeatures" + val numClasses = "numClasses" + val numExamples = "numExamples" + } /** * Creates an instrumentation object for a training session. From b6935ffb4dfb1d9fdf36ba402ac07bd02978c012 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 6 Apr 2018 10:23:26 -0700 Subject: [PATCH 0577/2461] [SPARK-10399][SPARK-23879][HOTFIX] Fix Java lint errors ## What changes were proposed in this pull request? This PR fixes the following errors in [Java lint](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Compile/job/spark-master-lint/7717/console) after #19222 has been merged. These errors were pointed by ueshin . ``` [ERROR] src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java:[57] (sizes) LineLength: Line is longer than 100 characters (found 106). [ERROR] src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java:[26,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java:[23,10] (modifier) ModifierOrder: 'public' modifier out of order with the JLS suggestions. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[64,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[69,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[74,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[79,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[84,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[89,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[94,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[99,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[104,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[109,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[114,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[119,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[124,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[129,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[60,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[65,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[70,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[75,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[80,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[85,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[90,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[95,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[100,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[105,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[110,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[115,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[120,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[125,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java:[114,16] (modifier) ModifierOrder: 'static' modifier out of order with the JLS suggestions. [ERROR] src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java:[20,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform. [ERROR] src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java:[30,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.memory.MemoryBlock. [ERROR] src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java:[126,15] (naming) MethodName: Method name 'ByteArrayMemoryBlockTest' must match pattern '^[a-z][a-z0-9][a-zA-Z0-9_]*$'. [ERROR] src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java:[143,15] (naming) MethodName: Method name 'OnHeapMemoryBlockTest' must match pattern '^[a-z][a-z0-9][a-zA-Z0-9_]*$'. [ERROR] src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java:[160,15] (naming) MethodName: Method name 'OffHeapArrayMemoryBlockTest' must match pattern '^[a-z][a-z0-9][a-zA-Z0-9_]*$'. [ERROR] src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java:[19,8] (imports) UnusedImports: Unused import - com.google.common.primitives.Ints. [ERROR] src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java:[21,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform. [ERROR] src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java:[20,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform. ``` ## How was this patch tested? Existing UTs Author: Kazuaki Ishizaki Closes #20991 from kiszk/SPARK-10399-jlint. --- .../sql/catalyst/expressions/HiveHasher.java | 1 - .../spark/unsafe/array/ByteArrayMethods.java | 4 +-- .../unsafe/memory/ByteArrayMemoryBlock.java | 28 +++++++++---------- .../unsafe/memory/HeapMemoryAllocator.java | 2 -- .../spark/unsafe/memory/MemoryBlock.java | 2 +- .../unsafe/memory/OffHeapMemoryBlock.java | 2 +- .../unsafe/memory/OnHeapMemoryBlock.java | 28 +++++++++---------- .../spark/unsafe/memory/MemoryBlockSuite.java | 6 ++-- .../spark/unsafe/types/UTF8StringSuite.java | 1 - .../spark/sql/catalyst/expressions/XXH64.java | 3 -- .../catalyst/expressions/HiveHasherSuite.java | 1 - 11 files changed, 35 insertions(+), 43 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java index 5d905943a3aa7..c34e36903a93e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; /** diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index c334c9651cf6b..4bc9955090fd7 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -54,7 +54,7 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { * @return true if the arrays are equal, false otherwise */ public static boolean arrayEqualsBlock( - MemoryBlock leftBase, long leftOffset, MemoryBlock rightBase, long rightOffset, final long length) { + MemoryBlock leftBase, long leftOffset, MemoryBlock rightBase, long rightOffset, long length) { return arrayEquals(leftBase.getBaseObject(), leftBase.getBaseOffset() + leftOffset, rightBase.getBaseObject(), rightBase.getBaseOffset() + rightOffset, length); } @@ -64,7 +64,7 @@ public static boolean arrayEqualsBlock( * @return true if the arrays are equal, false otherwise */ public static boolean arrayEquals( - Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) { + Object leftBase, long leftOffset, Object rightBase, long rightOffset, long length) { int i = 0; // check if starts align and we can get both offsets to be aligned diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java index 99a9868a49a79..9f238632bc87a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java @@ -57,72 +57,72 @@ public static ByteArrayMemoryBlock fromArray(final byte[] array) { } @Override - public final int getInt(long offset) { + public int getInt(long offset) { return Platform.getInt(array, this.offset + offset); } @Override - public final void putInt(long offset, int value) { + public void putInt(long offset, int value) { Platform.putInt(array, this.offset + offset, value); } @Override - public final boolean getBoolean(long offset) { + public boolean getBoolean(long offset) { return Platform.getBoolean(array, this.offset + offset); } @Override - public final void putBoolean(long offset, boolean value) { + public void putBoolean(long offset, boolean value) { Platform.putBoolean(array, this.offset + offset, value); } @Override - public final byte getByte(long offset) { + public byte getByte(long offset) { return array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)]; } @Override - public final void putByte(long offset, byte value) { + public void putByte(long offset, byte value) { array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)] = value; } @Override - public final short getShort(long offset) { + public short getShort(long offset) { return Platform.getShort(array, this.offset + offset); } @Override - public final void putShort(long offset, short value) { + public void putShort(long offset, short value) { Platform.putShort(array, this.offset + offset, value); } @Override - public final long getLong(long offset) { + public long getLong(long offset) { return Platform.getLong(array, this.offset + offset); } @Override - public final void putLong(long offset, long value) { + public void putLong(long offset, long value) { Platform.putLong(array, this.offset + offset, value); } @Override - public final float getFloat(long offset) { + public float getFloat(long offset) { return Platform.getFloat(array, this.offset + offset); } @Override - public final void putFloat(long offset, float value) { + public void putFloat(long offset, float value) { Platform.putFloat(array, this.offset + offset, value); } @Override - public final double getDouble(long offset) { + public double getDouble(long offset) { return Platform.getDouble(array, this.offset + offset); } @Override - public final void putDouble(long offset, double value) { + public void putDouble(long offset, double value) { Platform.putDouble(array, this.offset + offset, value); } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index acf28fd7ee59b..36caf80888cda 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -23,8 +23,6 @@ import java.util.LinkedList; import java.util.Map; -import org.apache.spark.unsafe.Platform; - /** * A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array. */ diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index b086941108522..ca7213bbf92da 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -111,7 +111,7 @@ public final void fill(byte value) { /** * Instantiate MemoryBlock for given object type with new offset */ - public final static MemoryBlock allocateFromObject(Object obj, long offset, long length) { + public static final MemoryBlock allocateFromObject(Object obj, long offset, long length) { MemoryBlock mb = null; if (obj instanceof byte[]) { byte[] array = (byte[])obj; diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java index f90f62bf21dcb..3431b08980eb8 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java @@ -20,7 +20,7 @@ import org.apache.spark.unsafe.Platform; public class OffHeapMemoryBlock extends MemoryBlock { - static public final OffHeapMemoryBlock NULL = new OffHeapMemoryBlock(0, 0); + public static final OffHeapMemoryBlock NULL = new OffHeapMemoryBlock(0, 0); public OffHeapMemoryBlock(long address, long size) { super(null, address, size); diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java index 12f67c7bd593e..ee42bc27c9c5f 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java @@ -61,72 +61,72 @@ public static OnHeapMemoryBlock fromArray(final long[] array, long size) { } @Override - public final int getInt(long offset) { + public int getInt(long offset) { return Platform.getInt(array, this.offset + offset); } @Override - public final void putInt(long offset, int value) { + public void putInt(long offset, int value) { Platform.putInt(array, this.offset + offset, value); } @Override - public final boolean getBoolean(long offset) { + public boolean getBoolean(long offset) { return Platform.getBoolean(array, this.offset + offset); } @Override - public final void putBoolean(long offset, boolean value) { + public void putBoolean(long offset, boolean value) { Platform.putBoolean(array, this.offset + offset, value); } @Override - public final byte getByte(long offset) { + public byte getByte(long offset) { return Platform.getByte(array, this.offset + offset); } @Override - public final void putByte(long offset, byte value) { + public void putByte(long offset, byte value) { Platform.putByte(array, this.offset + offset, value); } @Override - public final short getShort(long offset) { + public short getShort(long offset) { return Platform.getShort(array, this.offset + offset); } @Override - public final void putShort(long offset, short value) { + public void putShort(long offset, short value) { Platform.putShort(array, this.offset + offset, value); } @Override - public final long getLong(long offset) { + public long getLong(long offset) { return Platform.getLong(array, this.offset + offset); } @Override - public final void putLong(long offset, long value) { + public void putLong(long offset, long value) { Platform.putLong(array, this.offset + offset, value); } @Override - public final float getFloat(long offset) { + public float getFloat(long offset) { return Platform.getFloat(array, this.offset + offset); } @Override - public final void putFloat(long offset, float value) { + public void putFloat(long offset, float value) { Platform.putFloat(array, this.offset + offset, value); } @Override - public final double getDouble(long offset) { + public double getDouble(long offset) { return Platform.getDouble(array, this.offset + offset); } @Override - public final void putDouble(long offset, double value) { + public void putDouble(long offset, double value) { Platform.putDouble(array, this.offset + offset, value); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java index 47f05c928f2e5..5d5fdc1c55a75 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java @@ -123,7 +123,7 @@ private void check(MemoryBlock memory, Object obj, long offset, int length) { } @Test - public void ByteArrayMemoryBlockTest() { + public void testByteArrayMemoryBlock() { byte[] obj = new byte[56]; long offset = Platform.BYTE_ARRAY_OFFSET; int length = obj.length; @@ -140,7 +140,7 @@ public void ByteArrayMemoryBlockTest() { } @Test - public void OnHeapMemoryBlockTest() { + public void testOnHeapMemoryBlock() { long[] obj = new long[7]; long offset = Platform.LONG_ARRAY_OFFSET; int length = obj.length * 8; @@ -157,7 +157,7 @@ public void OnHeapMemoryBlockTest() { } @Test - public void OffHeapArrayMemoryBlockTest() { + public void testOffHeapArrayMemoryBlock() { MemoryAllocator memoryAllocator = new UnsafeMemoryAllocator(); MemoryBlock memory = memoryAllocator.allocate(56); Object obj = memory.getBaseObject(); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 652c40a35527f..2c08535a16465 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -27,7 +27,6 @@ import com.google.common.collect.ImmutableMap; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; -import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; import org.junit.Test; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java index 883748932ad33..fe727f6011cbf 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java @@ -16,9 +16,6 @@ */ package org.apache.spark.sql.catalyst.expressions; -import com.google.common.primitives.Ints; - -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; // scalastyle: off diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java index 8ffc1d7c24d61..76930f9368514 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.UTF8String; From e998250588de0df250e2800278da4d3e3705c259 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 6 Apr 2018 11:51:36 -0700 Subject: [PATCH 0578/2461] [SPARK-23828][ML][PYTHON] PySpark StringIndexerModel should have constructor from labels ## What changes were proposed in this pull request? The Scala StringIndexerModel has an alternate constructor that will create the model from an array of label strings. Add the corresponding Python API: model = StringIndexerModel.from_labels(["a", "b", "c"]) ## How was this patch tested? Add doctest and unit test. Author: Huaxin Gao Closes #20968 from huaxingao/spark-23828. --- python/pyspark/ml/feature.py | 88 ++++++++++++++++++++++++++---------- python/pyspark/ml/tests.py | 41 ++++++++++++++++- 2 files changed, 104 insertions(+), 25 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index fcb0dfc563720..5a3e0dd655150 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2342,9 +2342,38 @@ def mean(self): return self._call_java("mean") +class _StringIndexerParams(JavaParams, HasHandleInvalid, HasInputCol, HasOutputCol): + """ + Params for :py:attr:`StringIndexer` and :py:attr:`StringIndexerModel`. + """ + + stringOrderType = Param(Params._dummy(), "stringOrderType", + "How to order labels of string column. The first label after " + + "ordering is assigned an index of 0. Supported options: " + + "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.", + typeConverter=TypeConverters.toString) + + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " + + "or NULL values) in features and label column of string type. " + + "Options are 'skip' (filter out rows with invalid data), " + + "error (throw an error), or 'keep' (put invalid data " + + "in a special additional bucket, at index numLabels).", + typeConverter=TypeConverters.toString) + + def __init__(self, *args): + super(_StringIndexerParams, self).__init__(*args) + self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc") + + @since("2.3.0") + def getStringOrderType(self): + """ + Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'. + """ + return self.getOrDefault(self.stringOrderType) + + @inherit_doc -class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable, - JavaMLWritable): +class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLWritable): """ A label indexer that maps a string column of labels to an ML column of label indices. If the input column is numeric, we cast it to string and index the string values. @@ -2388,23 +2417,16 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), ... key=lambda x: x[0]) [(0, 2.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 0.0)] + >>> fromlabelsModel = StringIndexerModel.from_labels(["a", "b", "c"], + ... inputCol="label", outputCol="indexed", handleInvalid="error") + >>> result = fromlabelsModel.transform(stringIndDf) + >>> sorted(set([(i[0], i[1]) for i in result.select(result.id, result.indexed).collect()]), + ... key=lambda x: x[0]) + [(0, 0.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 2.0)] .. versionadded:: 1.4.0 """ - stringOrderType = Param(Params._dummy(), "stringOrderType", - "How to order labels of string column. The first label after " + - "ordering is assigned an index of 0. Supported options: " + - "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.", - typeConverter=TypeConverters.toString) - - handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " + - "or NULL values) in features and label column of string type. " + - "Options are 'skip' (filter out rows with invalid data), " + - "error (throw an error), or 'keep' (put invalid data " + - "in a special additional bucket, at index numLabels).", - typeConverter=TypeConverters.toString) - @keyword_only def __init__(self, inputCol=None, outputCol=None, handleInvalid="error", stringOrderType="frequencyDesc"): @@ -2414,7 +2436,6 @@ def __init__(self, inputCol=None, outputCol=None, handleInvalid="error", """ super(StringIndexer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid) - self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -2440,21 +2461,33 @@ def setStringOrderType(self, value): """ return self._set(stringOrderType=value) - @since("2.3.0") - def getStringOrderType(self): - """ - Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'. - """ - return self.getOrDefault(self.stringOrderType) - -class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): +class StringIndexerModel(JavaModel, _StringIndexerParams, JavaMLReadable, JavaMLWritable): """ Model fitted by :py:class:`StringIndexer`. .. versionadded:: 1.4.0 """ + @classmethod + @since("2.4.0") + def from_labels(cls, labels, inputCol, outputCol=None, handleInvalid=None): + """ + Construct the model directly from an array of label strings, + requires an active SparkContext. + """ + sc = SparkContext._active_spark_context + java_class = sc._gateway.jvm.java.lang.String + jlabels = StringIndexerModel._new_java_array(labels, java_class) + model = StringIndexerModel._create_from_java_class( + "org.apache.spark.ml.feature.StringIndexerModel", jlabels) + model.setInputCol(inputCol) + if outputCol is not None: + model.setOutputCol(outputCol) + if handleInvalid is not None: + model.setHandleInvalid(handleInvalid) + return model + @property @since("1.5.0") def labels(self): @@ -2463,6 +2496,13 @@ def labels(self): """ return self._call_java("labels") + @since("2.4.0") + def setHandleInvalid(self, value): + """ + Sets the value of :py:attr:`handleInvalid`. + """ + return self._set(handleInvalid=value) + @inherit_doc class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index c2c4861e2aff4..4ce54547eab09 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -800,6 +800,43 @@ def test_string_indexer_handle_invalid(self): expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)] self.assertEqual(actual2, expected2) + def test_string_indexer_from_labels(self): + model = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label", + outputCol="indexed", handleInvalid="keep") + self.assertEqual(model.labels, ["a", "b", "c"]) + + df1 = self.spark.createDataFrame([ + (0, "a"), + (1, "c"), + (2, None), + (3, "b"), + (4, "b")], ["id", "label"]) + + result1 = model.transform(df1) + actual1 = result1.select("id", "indexed").collect() + expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=2.0), Row(id=2, indexed=3.0), + Row(id=3, indexed=1.0), Row(id=4, indexed=1.0)] + self.assertEqual(actual1, expected1) + + model_empty_labels = StringIndexerModel.from_labels( + [], inputCol="label", outputCol="indexed", handleInvalid="keep") + actual2 = model_empty_labels.transform(df1).select("id", "indexed").collect() + expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=0.0), Row(id=2, indexed=0.0), + Row(id=3, indexed=0.0), Row(id=4, indexed=0.0)] + self.assertEqual(actual2, expected2) + + # Test model with default settings can transform + model_default = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label") + df2 = self.spark.createDataFrame([ + (0, "a"), + (1, "c"), + (2, "b"), + (3, "b"), + (4, "b")], ["id", "label"]) + transformed_list = model_default.transform(df2)\ + .select(model_default.getOrDefault(model_default.outputCol)).collect() + self.assertEqual(len(transformed_list), 5) + class HasInducedError(Params): @@ -2097,9 +2134,11 @@ def test_java_params(self): ParamTests.check_params(self, cls(), check_params_exist=False) # Additional classes that need explicit construction - from pyspark.ml.feature import CountVectorizerModel + from pyspark.ml.feature import CountVectorizerModel, StringIndexerModel ParamTests.check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'), check_params_exist=False) + ParamTests.check_params(self, StringIndexerModel.from_labels(['a', 'b'], 'input'), + check_params_exist=False) def _squared_distance(a, b): From 6ab134ca7d8f7802a6d196929513cc02b9b4d35d Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 6 Apr 2018 15:00:13 -0700 Subject: [PATCH 0579/2461] [SPARK-21898][ML][FOLLOWUP] Fix Scala 2.12 build. ## What changes were proposed in this pull request? This is a follow-up pr of #19108 which broke Scala 2.12 build. ``` [error] /Users/ueshin/workspace/apache-spark/spark/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala:86: overloaded method value test with alternatives: [error] (dataset: org.apache.spark.sql.DataFrame,sampleCol: String,cdf: org.apache.spark.api.java.function.Function[java.lang.Double,java.lang.Double])org.apache.spark.sql.DataFrame [error] (dataset: org.apache.spark.sql.DataFrame,sampleCol: String,cdf: scala.Double => scala.Double)org.apache.spark.sql.DataFrame [error] cannot be applied to (org.apache.spark.sql.DataFrame, String, scala.Double => java.lang.Double) [error] test(dataset, sampleCol, (x: Double) => cdf.call(x)) [error] ^ [error] one error found ``` ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #20994 from ueshin/issues/SPARK-21898/fix_scala-2.12. --- .../scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala index 8d80e7768cb6e..c62d7463288f7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala @@ -83,7 +83,8 @@ object KolmogorovSmirnovTest { @Since("2.4.0") def test(dataset: DataFrame, sampleCol: String, cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = { - test(dataset, sampleCol, (x: Double) => cdf.call(x)) + val f: Double => Double = x => cdf.call(x) + test(dataset, sampleCol, f) } /** From 2c1fe647575e97e28b2232478ca86847d113e185 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 8 Apr 2018 12:09:06 +0800 Subject: [PATCH 0580/2461] [SPARK-23847][PYTHON][SQL] Add asc_nulls_first, asc_nulls_last to PySpark ## What changes were proposed in this pull request? Column.scala and Functions.scala have asc_nulls_first, asc_nulls_last, desc_nulls_first and desc_nulls_last. Add the corresponding python APIs in column.py and functions.py ## How was this patch tested? Add doctest Author: Huaxin Gao Closes #20962 from huaxingao/spark-23847. --- python/pyspark/sql/column.py | 56 +++++++++++++++++-- python/pyspark/sql/functions.py | 13 +++++ python/pyspark/sql/tests.py | 14 +++++ .../scala/org/apache/spark/sql/Column.scala | 4 +- .../org/apache/spark/sql/functions.scala | 2 +- 5 files changed, 82 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 922c7cf288f8f..e7dec11c69b57 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -447,24 +447,72 @@ def isin(self, *cols): # order _asc_doc = """ - Returns a sort expression based on the ascending order of the given column name + Returns a sort expression based on ascending order of the column. >>> from pyspark.sql import Row - >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"]) >>> df.select(df.name).orderBy(df.name.asc()).collect() [Row(name=u'Alice'), Row(name=u'Tom')] """ + _asc_nulls_first_doc = """ + Returns a sort expression based on ascending order of the column, and null values + return before non-null values. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) + >>> df.select(df.name).orderBy(df.name.asc_nulls_first()).collect() + [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')] + + .. versionadded:: 2.4 + """ + _asc_nulls_last_doc = """ + Returns a sort expression based on ascending order of the column, and null values + appear after non-null values. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) + >>> df.select(df.name).orderBy(df.name.asc_nulls_last()).collect() + [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)] + + .. versionadded:: 2.4 + """ _desc_doc = """ - Returns a sort expression based on the descending order of the given column name. + Returns a sort expression based on the descending order of the column. >>> from pyspark.sql import Row - >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"]) >>> df.select(df.name).orderBy(df.name.desc()).collect() [Row(name=u'Tom'), Row(name=u'Alice')] """ + _desc_nulls_first_doc = """ + Returns a sort expression based on the descending order of the column, and null values + appear before non-null values. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) + >>> df.select(df.name).orderBy(df.name.desc_nulls_first()).collect() + [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')] + + .. versionadded:: 2.4 + """ + _desc_nulls_last_doc = """ + Returns a sort expression based on the descending order of the column, and null values + appear after non-null values. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) + >>> df.select(df.name).orderBy(df.name.desc_nulls_last()).collect() + [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)] + + .. versionadded:: 2.4 + """ asc = ignore_unicode_prefix(_unary_op("asc", _asc_doc)) + asc_nulls_first = ignore_unicode_prefix(_unary_op("asc_nulls_first", _asc_nulls_first_doc)) + asc_nulls_last = ignore_unicode_prefix(_unary_op("asc_nulls_last", _asc_nulls_last_doc)) desc = ignore_unicode_prefix(_unary_op("desc", _desc_doc)) + desc_nulls_first = ignore_unicode_prefix(_unary_op("desc_nulls_first", _desc_nulls_first_doc)) + desc_nulls_last = ignore_unicode_prefix(_unary_op("desc_nulls_last", _desc_nulls_last_doc)) _isNull_doc = """ True if the current expression is null. diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ad3e37c872628..1b192680f0795 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -138,6 +138,17 @@ def _(): 'bitwiseNOT': 'Computes bitwise not.', } +_functions_2_4 = { + 'asc_nulls_first': 'Returns a sort expression based on the ascending order of the given' + + ' column name, and null values return before non-null values.', + 'asc_nulls_last': 'Returns a sort expression based on the ascending order of the given' + + ' column name, and null values appear after non-null values.', + 'desc_nulls_first': 'Returns a sort expression based on the descending order of the given' + + ' column name, and null values appear before non-null values.', + 'desc_nulls_last': 'Returns a sort expression based on the descending order of the given' + + ' column name, and null values appear after non-null values', +} + _collect_list_doc = """ Aggregate function: returns a list of objects with duplicates. @@ -250,6 +261,8 @@ def _(): globals()[_name] = since(2.1)(_create_function(_name, _doc)) for _name, _message in _functions_deprecated.items(): globals()[_name] = _wrap_deprecated_function(globals()[_name], _message) +for _name, _doc in _functions_2_4.items(): + globals()[_name] = since(2.4)(_create_function(_name, _doc)) del _name, _doc diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5181053a0d318..dd04ffb4ed393 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2991,6 +2991,20 @@ def test_create_dateframe_from_pandas_with_dst(self): os.environ['TZ'] = orig_env_tz time.tzset() + def test_2_4_functions(self): + from pyspark.sql import functions + + df = self.spark.createDataFrame( + [('Tom', 80), (None, 60), ('Alice', 50)], ["name", "height"]) + df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect() + [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')] + df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect() + [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)] + df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect() + [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')] + df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect() + [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)] + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 92988680871a4..ad0efbae89830 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -1083,10 +1083,10 @@ class Column(val expr: Expression) extends Logging { * and null values return before non-null values. * {{{ * // Scala: sort a DataFrame by age column in ascending order and null values appearing first. - * df.sort(df("age").asc_nulls_last) + * df.sort(df("age").asc_nulls_first) * * // Java - * df.sort(df.col("age").asc_nulls_last()); + * df.sort(df.col("age").asc_nulls_first()); * }}} * * @group expr_ops diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c9ca9a8996344..c658f25ced053 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -132,7 +132,7 @@ object functions { * Returns a sort expression based on ascending order of the column, * and null values return before non-null values. * {{{ - * df.sort(asc_nulls_last("dept"), desc("age")) + * df.sort(asc_nulls_first("dept"), desc("age")) * }}} * * @group sort_funcs From 6a734575a80e6b4ec4963206254451f05d64b742 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 7 Apr 2018 21:44:32 -0700 Subject: [PATCH 0581/2461] [SPARK-23849][SQL] Tests for the samplingRatio option of JSON datasource ## What changes were proposed in this pull request? Proposed tests checks that only subset of input dataset is touched during schema inferring. Author: Maxim Gekk Closes #20963 from MaxGekk/json-sampling-tests. --- .../datasources/json/JsonSuite.scala | 37 ++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 10bac0554484a..70aee561ff0f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.nio.charset.StandardCharsets -import java.nio.file.Files +import java.nio.file.{Files, Paths, StandardOpenOption} import java.sql.{Date, Timestamp} import java.util.Locale @@ -2127,4 +2127,39 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(df.schema === expectedSchema) } } + + test("SPARK-23849: schema inferring touches less data if samplingRation < 1.0") { + val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, + 57, 62, 68, 72) + withTempPath { path => + val writer = Files.newBufferedWriter(Paths.get(path.getAbsolutePath), + StandardCharsets.UTF_8, StandardOpenOption.CREATE_NEW) + for (i <- 0 until 100) { + if (predefinedSample.contains(i)) { + writer.write(s"""{"f1":${i.toString}}""" + "\n") + } else { + writer.write(s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n") + } + } + writer.close() + + val ds = spark.read.option("samplingRatio", 0.1).json(path.getCanonicalPath) + assert(ds.schema == new StructType().add("f1", LongType)) + } + } + + test("SPARK-23849: usage of samplingRation while parsing of dataset of strings") { + val dstr = spark.sparkContext.parallelize(0 until 100, 1).map { i => + val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, + 57, 62, 68, 72) + if (predefinedSample.contains(i)) { + s"""{"f1":${i.toString}}""" + "\n" + } else { + s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n" + } + }.toDS() + val ds = spark.read.option("samplingRatio", 0.1).json(dstr) + + assert(ds.schema == new StructType().add("f1", LongType)) + } } From 710a68cec27a94c2df10d8b4022a755a94a5443b Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 8 Apr 2018 20:26:31 +0200 Subject: [PATCH 0582/2461] [SPARK-23892][TEST] Improve converge and fix lint error in UTF8String-related tests ## What changes were proposed in this pull request? This PR improves test coverage in `UTF8StringSuite` and code efficiency in `UTF8StringPropertyCheckSuite`. This PR also fixes lint-java issue in `UTF8StringSuite` reported at [here](https://github.com/apache/spark/pull/20995#issuecomment-379325527) ```[ERROR] src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java:[28,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform.``` ## How was this patch tested? Existing UT Author: Kazuaki Ishizaki Closes #21000 from kiszk/SPARK-23892. --- .../java/org/apache/spark/unsafe/types/UTF8StringSuite.java | 5 ++--- .../spark/unsafe/types/UTF8StringPropertyCheckSuite.scala | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 2c08535a16465..42dda30480702 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -25,7 +25,6 @@ import java.util.*; import com.google.common.collect.ImmutableMap; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; import org.junit.Test; @@ -53,8 +52,8 @@ private static void checkBasic(String str, int len) { assertTrue(s1.contains(s2)); assertTrue(s2.contains(s1)); - assertTrue(s1.startsWith(s1)); - assertTrue(s1.endsWith(s1)); + assertTrue(s1.startsWith(s2)); + assertTrue(s1.endsWith(s2)); } @Test diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala index 62d4176d00f94..48004e812a8bf 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala @@ -164,7 +164,7 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty def padding(origin: String, pad: String, length: Int, isLPad: Boolean): String = { if (length <= 0) return "" if (length <= origin.length) { - if (length <= 0) "" else origin.substring(0, length) + origin.substring(0, length) } else { if (pad.length == 0) return origin val toPad = length - origin.length From 8d40a79a077a30024a8ef921781b68f6f7e542d1 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 8 Apr 2018 20:40:27 +0200 Subject: [PATCH 0583/2461] [SPARK-23893][CORE][SQL] Avoid possible integer overflow in multiplication ## What changes were proposed in this pull request? This PR avoids possible overflow at an operation `long = (long)(int * int)`. The multiplication of large positive integer values may set one to MSB. This leads to a negative value in long while we expected a positive value (e.g. `0111_0000_0000_0000 * 0000_0000_0000_0010`). This PR performs long cast before the multiplication to avoid this situation. ## How was this patch tested? Existing UTs Author: Kazuaki Ishizaki Closes #21002 from kiszk/SPARK-23893. --- .../util/collection/unsafe/sort/UnsafeInMemorySorter.java | 2 +- .../util/collection/unsafe/sort/UnsafeSortDataFormat.java | 2 +- .../src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala | 2 +- .../scala/org/apache/spark/InternalAccumulatorSuite.scala | 2 +- .../apache/spark/deploy/history/FsHistoryProviderSuite.scala | 4 ++-- .../test/scala/org/apache/spark/util/JsonProtocolSuite.scala | 4 ++-- .../src/test/scala/org/apache/spark/sql/HashBenchmark.scala | 2 +- .../scala/org/apache/spark/sql/HashByteArrayBenchmark.scala | 3 ++- .../org/apache/spark/sql/UnsafeProjectionBenchmark.scala | 2 +- .../columnar/compression/CompressionSchemeBenchmark.scala | 4 ++-- .../sql/execution/vectorized/ColumnarBatchBenchmark.scala | 2 +- 11 files changed, 15 insertions(+), 14 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 20a7a8b267438..717823ebbd320 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -124,7 +124,7 @@ public UnsafeInMemorySorter( int initialSize, boolean canUseRadixSort) { this(consumer, memoryManager, recordComparator, prefixComparator, - consumer.allocateArray(initialSize * 2), canUseRadixSort); + consumer.allocateArray(initialSize * 2L), canUseRadixSort); } public UnsafeInMemorySorter( diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java index d9f84d10e9051..37772f41caa87 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java @@ -84,7 +84,7 @@ public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int @Override public LongArray allocate(int length) { - assert (length * 2 <= buffer.size()) : + assert (length * 2L <= buffer.size()) : "the buffer is smaller than required: " + buffer.size() + " < " + (length * 2); return buffer; } diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index c9ed12f4e1bd4..13db4985b0b80 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -90,7 +90,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi // Otherwise, interpolate the number of partitions we need to try, but overestimate it // by 50%. We also cap the estimation in the end. if (results.size == 0) { - numPartsToTry = partsScanned * 4 + numPartsToTry = partsScanned * 4L } else { // the left side of max is >=1 whenever partsScanned >= 2 numPartsToTry = Math.max(1, diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala index 8d7be77f51fe9..62824a5bec9d1 100644 --- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -135,7 +135,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { // This job runs 2 stages, and we're in the second stage. Therefore, any task attempt // ID that's < 2 * numPartitions belongs to the first attempt of this stage. val taskContext = TaskContext.get() - val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2 + val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2L if (isFirstStageAttempt) { throw new FetchFailedException( SparkEnv.get.blockManager.blockManagerId, diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index fde5f25bce456..0ba57bf4563c1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -382,8 +382,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val log = newLogFile("downloadApp1", Some(s"attempt$i"), inProgress = false) writeFile(log, true, None, SparkListenerApplicationStart( - "downloadApp1", Some("downloadApp1"), 5000 * i, "test", Some(s"attempt$i")), - SparkListenerApplicationEnd(5001 * i) + "downloadApp1", Some("downloadApp1"), 5000L * i, "test", Some(s"attempt$i")), + SparkListenerApplicationEnd(5001L * i) ) log } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 4abbb8e7894f5..74b72d940eeef 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -317,7 +317,7 @@ class JsonProtocolSuite extends SparkFunSuite { test("SparkListenerJobStart backward compatibility") { // Prior to Spark 1.2.0, SparkListenerJobStart did not have a "Stage Infos" property. val stageIds = Seq[Int](1, 2, 3, 4) - val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400, x * 500)) + val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400L, x * 500L)) val dummyStageInfos = stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown")) val jobStart = SparkListenerJobStart(10, jobSubmissionTime, stageInfos, properties) @@ -331,7 +331,7 @@ class JsonProtocolSuite extends SparkFunSuite { // Prior to Spark 1.3.0, SparkListenerJobStart did not have a "Submission Time" property. // Also, SparkListenerJobEnd did not have a "Completion Time" property. val stageIds = Seq[Int](1, 2, 3, 4) - val stageInfos = stageIds.map(x => makeStageInfo(x * 10, x * 20, x * 30, x * 40, x * 50)) + val stageInfos = stageIds.map(x => makeStageInfo(x * 10, x * 20, x * 30, x * 40L, x * 50L)) val jobStart = SparkListenerJobStart(11, jobSubmissionTime, stageInfos, properties) val oldStartEvent = JsonProtocol.jobStartToJson(jobStart) .removeField({ _._1 == "Submission Time"}) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala index 2d94b66a1e122..9a89e6290e695 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala @@ -40,7 +40,7 @@ object HashBenchmark { safeProjection(encoder.toRow(generator().asInstanceOf[Row])).copy() ).toArray - val benchmark = new Benchmark("Hash For " + name, iters * numRows) + val benchmark = new Benchmark("Hash For " + name, iters * numRows.toLong) benchmark.addCase("interpreted version") { _: Int => var sum = 0 for (_ <- 0L until iters) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala index 2a753a0c84ed5..f6c8111f5bc57 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala @@ -36,7 +36,8 @@ object HashByteArrayBenchmark { bytes } - val benchmark = new Benchmark("Hash byte arrays with length " + length, iters * numArrays) + val benchmark = + new Benchmark("Hash byte arrays with length " + length, iters * numArrays.toLong) benchmark.addCase("Murmur3_x86_32") { _: Int => var sum = 0L for (_ <- 0L until iters) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala index 769addf3b29e6..6c63769945312 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala @@ -38,7 +38,7 @@ object UnsafeProjectionBenchmark { val iters = 1024 * 16 val numRows = 1024 * 16 - val benchmark = new Benchmark("unsafe projection", iters * numRows) + val benchmark = new Benchmark("unsafe projection", iters * numRows.toLong) val schema1 = new StructType().add("l", LongType, false) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala index 9005ec93e786e..619b76fabdd5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala @@ -77,7 +77,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { count: Int, tpe: NativeColumnType[T], input: ByteBuffer): Unit = { - val benchmark = new Benchmark(name, iters * count) + val benchmark = new Benchmark(name, iters * count.toLong) schemes.filter(_.supports(tpe)).foreach { scheme => val (compressFunc, compressionRatio, buf) = prepareEncodeInternal(count, tpe, scheme, input) @@ -101,7 +101,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { count: Int, tpe: NativeColumnType[T], input: ByteBuffer): Unit = { - val benchmark = new Benchmark(name, iters * count) + val benchmark = new Benchmark(name, iters * count.toLong) schemes.filter(_.supports(tpe)).foreach { scheme => val (compressFunc, _, buf) = prepareEncodeInternal(count, tpe, scheme, input) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 1f31aa45a1220..8aeb06d428951 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -295,7 +295,7 @@ object ColumnarBatchBenchmark { def booleanAccess(iters: Int): Unit = { val count = 8 * 1024 - val benchmark = new Benchmark("Boolean Read/Write", iters * count) + val benchmark = new Benchmark("Boolean Read/Write", iters * count.toLong) benchmark.addCase("Bitset") { i: Int => { val b = new BitSet(count) var sum = 0L From 32471ba0af52b59141b44a8375025b6a7eafae70 Mon Sep 17 00:00:00 2001 From: Nolan Emirot Date: Mon, 9 Apr 2018 08:04:02 -0500 Subject: [PATCH 0584/2461] Fix typo in Python docstring kinesis example ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Nolan Emirot Closes #20990 from emirot/kinesis_stream_example_typo. --- .../src/main/python/examples/streaming/kinesis_wordcount_asl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py b/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py index 4d7fc9a549bfb..49794faab88c4 100644 --- a/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py +++ b/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py @@ -34,7 +34,7 @@ $ export AWS_SECRET_KEY= # run the example - $ bin/spark-submit -jar external/kinesis-asl/target/scala-*/\ + $ bin/spark-submit -jars external/kinesis-asl/target/scala-*/\ spark-streaming-kinesis-asl-assembly_*.jar \ external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \ myAppName mySparkStream https://kinesis.us-east-1.amazonaws.com From d81f29ecafe8fc9816e36087e3b8acdc93d6cc1b Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Mon, 9 Apr 2018 10:19:22 -0700 Subject: [PATCH 0585/2461] [SPARK-23881][CORE][TEST] Fix flaky test JobCancellationSuite."interruptible iterator of shuffle reader" ## What changes were proposed in this pull request? The test case JobCancellationSuite."interruptible iterator of shuffle reader" has been flaky because `KillTask` event is handled asynchronously, so it can happen that the semaphore is released but the task is still running. Actually we only have to check if the total number of processed elements is less than the input elements number, so we know the task get cancelled. ## How was this patch tested? The new test case still fails without the purposed patch, and succeeded in current master. Author: Xingbo Jiang Closes #20993 from jiangxb1987/JobCancellationSuite. --- .../apache/spark/JobCancellationSuite.scala | 31 +++++++++++-------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 3b793bb231cf3..61da4138896cd 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -332,13 +332,15 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft import JobCancellationSuite._ sc = new SparkContext("local[2]", "test interruptible iterator") + // Increase the number of elements to be proceeded to avoid this test being flaky. + val numElements = 10000 val taskCompletedSem = new Semaphore(0) sc.addSparkListener(new SparkListener { override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { // release taskCancelledSemaphore when cancelTasks event has been posted if (stageCompleted.stageInfo.stageId == 1) { - taskCancelledSemaphore.release(1000) + taskCancelledSemaphore.release(numElements) } } @@ -349,28 +351,31 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft } }) - val f = sc.parallelize(1 to 1000).map { i => (i, i) } + // Explicitly disable interrupt task thread on cancelling tasks, so the task thread can only be + // interrupted by `InterruptibleIterator`. + sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false") + + val f = sc.parallelize(1 to numElements).map { i => (i, i) } .repartitionAndSortWithinPartitions(new HashPartitioner(1)) .mapPartitions { iter => taskStartedSemaphore.release() iter }.foreachAsync { x => - if (x._1 >= 10) { - // This block of code is partially executed. It will be blocked when x._1 >= 10 and the - // next iteration will be cancelled if the source iterator is interruptible. Then in this - // case, the maximum num of increment would be 10(|1...10|) - taskCancelledSemaphore.acquire() - } + // Block this code from being executed, until the job get cancelled. In this case, if the + // source iterator is interruptible, the max number of increment should be under + // `numElements`. + taskCancelledSemaphore.acquire() executionOfInterruptibleCounter.getAndIncrement() } taskStartedSemaphore.acquire() // Job is cancelled when: // 1. task in reduce stage has been started, guaranteed by previous line. - // 2. task in reduce stage is blocked after processing at most 10 records as - // taskCancelledSemaphore is not released until cancelTasks event is posted - // After job being cancelled, task in reduce stage will be cancelled and no more iteration are - // executed. + // 2. task in reduce stage is blocked as taskCancelledSemaphore is not released until + // JobCancelled event is posted. + // After job being cancelled, task in reduce stage will be cancelled asynchronously, thus + // partial of the inputs should not get processed (It's very unlikely that Spark can process + // 10000 elements between JobCancelled is posted and task is really killed). f.cancel() val e = intercept[SparkException](f.get()).getCause @@ -378,7 +383,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft // Make sure tasks are indeed completed. taskCompletedSem.acquire() - assert(executionOfInterruptibleCounter.get() <= 10) + assert(executionOfInterruptibleCounter.get() < numElements) } def testCount() { From 10f45bb8233e6ac838dd4f053052c8556f5b54bd Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 9 Apr 2018 11:31:21 -0700 Subject: [PATCH 0586/2461] [SPARK-23816][CORE] Killed tasks should ignore FetchFailures. SPARK-19276 ensured that FetchFailures do not get swallowed by other layers of exception handling, but it also meant that a killed task could look like a fetch failure. This is particularly a problem with speculative execution, where we expect to kill tasks as they are reading shuffle data. The fix is to ensure that we always check for killed tasks first. Added a new unit test which fails before the fix, ran it 1k times to check for flakiness. Full suite of tests on jenkins. Author: Imran Rashid Closes #20987 from squito/SPARK-23816. --- .../org/apache/spark/executor/Executor.scala | 26 +++--- .../apache/spark/executor/ExecutorSuite.scala | 92 +++++++++++++++---- 2 files changed, 88 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index dcec3ec21b546..c325222b764b8 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -480,6 +480,19 @@ private[spark] class Executor( execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) } catch { + case t: TaskKilledException => + logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}") + setTaskFinishedAndClearInterruptStatus() + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) + + case _: InterruptedException | NonFatal(_) if + task != null && task.reasonIfKilled.isDefined => + val killReason = task.reasonIfKilled.getOrElse("unknown reason") + logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") + setTaskFinishedAndClearInterruptStatus() + execBackend.statusUpdate( + taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason))) + case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) => val reason = task.context.fetchFailed.get.toTaskFailedReason if (!t.isInstanceOf[FetchFailedException]) { @@ -494,19 +507,6 @@ private[spark] class Executor( setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - case t: TaskKilledException => - logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}") - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) - - case _: InterruptedException | NonFatal(_) if - task != null && task.reasonIfKilled.isDefined => - val killReason = task.reasonIfKilled.getOrElse("unknown reason") - logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate( - taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason))) - case CausedBy(cDE: CommitDeniedException) => val reason = cDE.toTaskCommitDeniedReason setTaskFinishedAndClearInterruptStatus() diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 105a178f2d94e..1a7bebe2c53cd 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -22,6 +22,7 @@ import java.lang.Thread.UncaughtExceptionHandler import java.nio.ByteBuffer import java.util.Properties import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable.Map import scala.concurrent.duration._ @@ -139,7 +140,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug // the fetch failure. The executor should still tell the driver that the task failed due to a // fetch failure, not a generic exception from user code. val inputRDD = new FetchFailureThrowingRDD(sc) - val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false) + val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false, interrupt = false) val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() val task = new ResultTask( @@ -173,17 +174,48 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug } test("SPARK-19276: OOMs correctly handled with a FetchFailure") { + val (failReason, uncaughtExceptionHandler) = testFetchFailureHandling(true) + assert(failReason.isInstanceOf[ExceptionFailure]) + val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable]) + verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture()) + assert(exceptionCaptor.getAllValues.size === 1) + assert(exceptionCaptor.getAllValues().get(0).isInstanceOf[OutOfMemoryError]) + } + + test("SPARK-23816: interrupts are not masked by a FetchFailure") { + // If killing the task causes a fetch failure, we still treat it as a task that was killed, + // as the fetch failure could easily be caused by interrupting the thread. + val (failReason, _) = testFetchFailureHandling(false) + assert(failReason.isInstanceOf[TaskKilled]) + } + + /** + * Helper for testing some cases where a FetchFailure should *not* get sent back, because its + * superceded by another error, either an OOM or intentionally killing a task. + * @param oom if true, throw an OOM after the FetchFailure; else, interrupt the task after the + * FetchFailure + */ + private def testFetchFailureHandling( + oom: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = { // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it // may be a false positive. And we should call the uncaught exception handler. + // SPARK-23816 also handle interrupts the same way, as killing an obsolete speculative task + // does not represent a real fetch failure. val conf = new SparkConf().setMaster("local").setAppName("executor suite test") sc = new SparkContext(conf) val serializer = SparkEnv.get.closureSerializer.newInstance() val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size - // Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat - // the fetch failure as a false positive, and just do normal OOM handling. + // Submit a job where a fetch failure is thrown, but then there is an OOM or interrupt. We + // should treat the fetch failure as a false positive, and do normal OOM or interrupt handling. val inputRDD = new FetchFailureThrowingRDD(sc) - val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true) + if (!oom) { + // we are trying to setup a case where a task is killed after a fetch failure -- this + // is just a helper to coordinate between the task thread and this thread that will + // kill the task + ExecutorSuiteHelper.latches = new ExecutorSuiteHelper() + } + val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = oom, interrupt = !oom) val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() val task = new ResultTask( @@ -200,15 +232,8 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug val serTask = serializer.serialize(task) val taskDescription = createFakeTaskDescription(serTask) - val (failReason, uncaughtExceptionHandler) = - runTaskGetFailReasonAndExceptionHandler(taskDescription) - // make sure the task failure just looks like a OOM, not a fetch failure - assert(failReason.isInstanceOf[ExceptionFailure]) - val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable]) - verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture()) - assert(exceptionCaptor.getAllValues.size === 1) - assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError]) - } + runTaskGetFailReasonAndExceptionHandler(taskDescription, killTask = !oom) + } test("Gracefully handle error in task deserialization") { val conf = new SparkConf @@ -257,22 +282,39 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug } private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = { - runTaskGetFailReasonAndExceptionHandler(taskDescription)._1 + runTaskGetFailReasonAndExceptionHandler(taskDescription, false)._1 } private def runTaskGetFailReasonAndExceptionHandler( - taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = { + taskDescription: TaskDescription, + killTask: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = { val mockBackend = mock[ExecutorBackend] val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler] var executor: Executor = null + val timedOut = new AtomicBoolean(false) try { executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true, uncaughtExceptionHandler = mockUncaughtExceptionHandler) // the task will be launched in a dedicated worker thread executor.launchTask(mockBackend, taskDescription) + if (killTask) { + val killingThread = new Thread("kill-task") { + override def run(): Unit = { + // wait to kill the task until it has thrown a fetch failure + if (ExecutorSuiteHelper.latches.latch1.await(10, TimeUnit.SECONDS)) { + // now we can kill the task + executor.killAllTasks(true, "Killed task, eg. because of speculative execution") + } else { + timedOut.set(true) + } + } + } + killingThread.start() + } eventually(timeout(5.seconds), interval(10.milliseconds)) { assert(executor.numRunningTasks === 0) } + assert(!timedOut.get(), "timed out waiting to be ready to kill tasks") } finally { if (executor != null) { executor.stop() @@ -282,8 +324,9 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) orderedMock.verify(mockBackend) .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture()) + val finalState = if (killTask) TaskState.KILLED else TaskState.FAILED orderedMock.verify(mockBackend) - .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture()) + .statusUpdate(meq(0L), meq(finalState), statusCaptor.capture()) // first statusUpdate for RUNNING has empty data assert(statusCaptor.getAllValues().get(0).remaining() === 0) // second update is more interesting @@ -321,7 +364,8 @@ class SimplePartition extends Partition { class FetchFailureHidingRDD( sc: SparkContext, val input: FetchFailureThrowingRDD, - throwOOM: Boolean) extends RDD[Int](input) { + throwOOM: Boolean, + interrupt: Boolean) extends RDD[Int](input) { override def compute(split: Partition, context: TaskContext): Iterator[Int] = { val inItr = input.compute(split, context) try { @@ -330,6 +374,15 @@ class FetchFailureHidingRDD( case t: Throwable => if (throwOOM) { throw new OutOfMemoryError("OOM while handling another exception") + } else if (interrupt) { + // make sure our test is setup correctly + assert(TaskContext.get().asInstanceOf[TaskContextImpl].fetchFailed.isDefined) + // signal our test is ready for the task to get killed + ExecutorSuiteHelper.latches.latch1.countDown() + // then wait for another thread in the test to kill the task -- this latch + // is never actually decremented, we just wait to get killed. + ExecutorSuiteHelper.latches.latch2.await(10, TimeUnit.SECONDS) + throw new IllegalStateException("timed out waiting to be interrupted") } else { throw new RuntimeException("User Exception that hides the original exception", t) } @@ -352,6 +405,11 @@ private class ExecutorSuiteHelper { @volatile var testFailedReason: TaskFailedReason = _ } +// helper for coordinating killing tasks +private object ExecutorSuiteHelper { + var latches: ExecutorSuiteHelper = null +} + private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable { def writeExternal(out: ObjectOutput): Unit = {} def readExternal(in: ObjectInput): Unit = { From 7c1654e2159662e7e663ba141719d755002f770a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 9 Apr 2018 11:54:35 -0700 Subject: [PATCH 0587/2461] [SPARK-22856][SQL] Add wrappers for codegen output and nullability ## What changes were proposed in this pull request? The codegen output of `Expression`, aka `ExprCode`, now encapsulates only strings of output value (`value`) and nullability (`isNull`). It makes difficulty for us to know what the output really is. I think it is better if we can add wrappers for the value and nullability that let us to easily know that. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #20043 from viirya/SPARK-22856. --- .../catalyst/expressions/BoundAttribute.scala | 4 +- .../sql/catalyst/expressions/Expression.scala | 16 ++-- .../MonotonicallyIncreasingID.scala | 4 +- .../expressions/SparkPartitionID.scala | 4 +- .../sql/catalyst/expressions/arithmetic.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 16 ++-- .../expressions/codegen/CodegenFallback.scala | 2 +- .../expressions/codegen/ExprValue.scala | 76 +++++++++++++++++++ .../codegen/GenerateMutableProjection.scala | 6 +- .../codegen/GenerateSafeProjection.scala | 19 +++-- .../codegen/GenerateUnsafeProjection.scala | 8 +- .../expressions/collectionOperations.scala | 4 +- .../expressions/complexTypeCreator.scala | 8 +- .../expressions/conditionalExpressions.scala | 3 +- .../expressions/datetimeExpressions.scala | 3 +- .../sql/catalyst/expressions/generators.scala | 4 +- .../spark/sql/catalyst/expressions/hash.scala | 4 +- .../catalyst/expressions/inputFileBlock.scala | 8 +- .../sql/catalyst/expressions/literals.scala | 24 +++--- .../spark/sql/catalyst/expressions/misc.scala | 5 +- .../expressions/nullExpressions.scala | 25 ++++-- .../expressions/objects/objects.scala | 28 ++++--- .../sql/catalyst/expressions/predicates.scala | 10 +-- .../expressions/randomExpressions.scala | 6 +- .../expressions/CodeGenerationSuite.scala | 6 +- .../expressions/codegen/ExprValueSuite.scala | 46 +++++++++++ .../sql/execution/ColumnarBatchScan.scala | 10 ++- .../spark/sql/execution/ExpandExec.scala | 5 +- .../spark/sql/execution/GenerateExec.scala | 13 ++-- .../sql/execution/WholeStageCodegenExec.scala | 13 ++-- .../aggregate/HashAggregateExec.scala | 3 +- .../aggregate/HashMapGenerator.scala | 5 +- .../execution/basicPhysicalOperators.scala | 6 +- .../joins/BroadcastHashJoinExec.scala | 6 +- .../execution/joins/SortMergeJoinExec.scala | 8 +- 35 files changed, 294 insertions(+), 120 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 89ffbb0016916..5021a567592e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types._ /** @@ -76,7 +76,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) | ${CodeGenerator.defaultValue(dataType)} : ($value); """.stripMargin) } else { - ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false") + ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 38caf67d465d8..7a5e49cb5206b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -104,7 +104,9 @@ abstract class Expression extends TreeNode[Expression] { }.getOrElse { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val eval = doGenCode(ctx, ExprCode("", isNull, value)) + val eval = doGenCode(ctx, ExprCode("", + VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), + VariableValue(value, CodeGenerator.javaType(dataType)))) reduceCodeSize(ctx, eval) if (eval.code.nonEmpty) { // Add `this` in the comment. @@ -118,10 +120,10 @@ abstract class Expression extends TreeNode[Expression] { private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { // TODO: support whole stage codegen too if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { - val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") { + val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull - eval.isNull = globalIsNull + eval.isNull = GlobalValue(globalIsNull, CodeGenerator.JAVA_BOOLEAN) s"$globalIsNull = $localIsNull;" } else { "" @@ -140,7 +142,7 @@ abstract class Expression extends TreeNode[Expression] { |} """.stripMargin) - eval.value = newValue + eval.value = VariableValue(newValue, javaType) eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" } } @@ -446,7 +448,7 @@ abstract class UnaryExpression extends Expression { boolean ${ev.isNull} = false; ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - $resultCode""", isNull = "false") + $resultCode""", isNull = FalseLiteral) } } } @@ -546,7 +548,7 @@ abstract class BinaryExpression extends Expression { ${leftGen.code} ${rightGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - $resultCode""", isNull = "false") + $resultCode""", isNull = FalseLiteral) } } } @@ -690,7 +692,7 @@ abstract class TernaryExpression extends Expression { ${midGen.code} ${rightGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - $resultCode""", isNull = "false") + $resultCode""", isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index dd523d312e3b4..ad1e7bdb31987 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types.{DataType, LongType} /** @@ -73,7 +73,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { ev.copy(code = s""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; - $countTerm++;""", isNull = "false") + $countTerm++;""", isNull = FalseLiteral) } override def prettyName: String = "monotonically_increasing_id" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index cc6a769d032d3..787bcaf5e81de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -47,6 +47,6 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", - isNull = "false") + isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 508bdd5050b54..478ff3a7c1011 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -601,7 +601,8 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) + ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), + CodeGenerator.JAVA_BOOLEAN) val evals = evalChildren.map(eval => s""" |${eval.code} @@ -680,7 +681,8 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) + ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), + CodeGenerator.JAVA_BOOLEAN) val evals = evalChildren.map(eval => s""" |${eval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 84b1e3fbda876..c9c60ef1be640 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -56,16 +56,17 @@ import org.apache.spark.util.{ParentClassLoader, Utils} * @param value A term for a (possibly primitive) value of the result of the evaluation. Not * valid if `isNull` is set to `true`. */ -case class ExprCode(var code: String, var isNull: String, var value: String) +case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue) object ExprCode { def forNullValue(dataType: DataType): ExprCode = { val defaultValueLiteral = CodeGenerator.defaultValue(dataType, typedNull = true) - ExprCode(code = "", isNull = "true", value = defaultValueLiteral) + ExprCode(code = "", isNull = TrueLiteral, + value = LiteralValue(defaultValueLiteral, CodeGenerator.javaType(dataType))) } - def forNonNullValue(value: String): ExprCode = { - ExprCode(code = "", isNull = "false", value = value) + def forNonNullValue(value: ExprValue): ExprCode = { + ExprCode(code = "", isNull = FalseLiteral, value = value) } } @@ -77,7 +78,7 @@ object ExprCode { * @param value A term for a value of a common sub-expression. Not valid if `isNull` * is set to `true`. */ -case class SubExprEliminationState(isNull: String, value: String) +case class SubExprEliminationState(isNull: ExprValue, value: ExprValue) /** * Codes and common subexpressions mapping used for subexpression elimination. @@ -330,7 +331,7 @@ class CodegenContext { case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" case _ => s"$value = $initCode;" } - ExprCode(code, "false", value) + ExprCode(code, FalseLiteral, GlobalValue(value, javaType(dataType))) } def declareMutableStates(): String = { @@ -1003,7 +1004,8 @@ class CodegenContext { // at least two nodes) as the cost of doing it is expected to be low. subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - val state = SubExprEliminationState(isNull, value) + val state = SubExprEliminationState(GlobalValue(isNull, JAVA_BOOLEAN), + GlobalValue(value, javaType(expr.dataType))) subExprEliminationExprs ++= e.map(_ -> state).toMap } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index e12420bb5dfdd..a91989e129664 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -59,7 +59,7 @@ trait CodegenFallback extends Expression { $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); $javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; - """, isNull = "false") + """, isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala new file mode 100644 index 0000000000000..df5f1c58b1b2d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import scala.language.implicitConversions + +import org.apache.spark.sql.types.DataType + +// An abstraction that represents the evaluation result of [[ExprCode]]. +abstract class ExprValue { + + val javaType: String + + // Whether we can directly access the evaluation value anywhere. + // For example, a variable created outside a method can not be accessed inside the method. + // For such cases, we may need to pass the evaluation as parameter. + val canDirectAccess: Boolean + + def isPrimitive: Boolean = CodeGenerator.isPrimitiveType(javaType) +} + +object ExprValue { + implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString +} + +// A literal evaluation of [[ExprCode]]. +class LiteralValue(val value: String, val javaType: String) extends ExprValue { + override def toString: String = value + override val canDirectAccess: Boolean = true +} + +object LiteralValue { + def apply(value: String, javaType: String): LiteralValue = new LiteralValue(value, javaType) + def unapply(literal: LiteralValue): Option[(String, String)] = + Some((literal.value, literal.javaType)) +} + +// A variable evaluation of [[ExprCode]]. +case class VariableValue( + val variableName: String, + val javaType: String) extends ExprValue { + override def toString: String = variableName + override val canDirectAccess: Boolean = false +} + +// A statement evaluation of [[ExprCode]]. +case class StatementValue( + val statement: String, + val javaType: String, + val canDirectAccess: Boolean = false) extends ExprValue { + override def toString: String = statement +} + +// A global variable evaluation of [[ExprCode]]. +case class GlobalValue(val value: String, val javaType: String) extends ExprValue { + override def toString: String = value + override val canDirectAccess: Boolean = true +} + +case object TrueLiteral extends LiteralValue("true", "boolean") +case object FalseLiteral extends LiteralValue("false", "boolean") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index d35fd8ecb4d63..3ae0b54c754cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -59,7 +59,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination) // 4-tuples: (code for projection, isNull variable name, value variable name, column index) - val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map { + val projectionCodes: Seq[(String, ExprValue, String, Int)] = exprVals.zip(index).map { case (ev, i) => val e = expressions(i) val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value") @@ -69,7 +69,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP |${ev.code} |$isNull = ${ev.isNull}; |$value = ${ev.value}; - """.stripMargin, isNull, value, i) + """.stripMargin, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), value, i) } else { (s""" |${ev.code} @@ -83,7 +83,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val updates = validExpr.zip(projectionCodes).map { case (e, (_, isNull, value, i)) => - val ev = ExprCode("", isNull, value) + val ev = ExprCode("", isNull, GlobalValue(value, CodeGenerator.javaType(e.dataType))) CodeGenerator.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index f92f70ee71fef..a30a0b22cd305 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -53,7 +53,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val rowClass = classOf[GenericInternalRow].getName val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => - val converter = convertToSafe(ctx, CodeGenerator.getValue(tmpInput, dt, i.toString), dt) + val converter = convertToSafe(ctx, + StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString), + CodeGenerator.javaType(dt)), dt) s""" if (!$tmpInput.isNullAt($i)) { ${converter.code} @@ -74,7 +76,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] |final InternalRow $output = new $rowClass($values); """.stripMargin - ExprCode(code, "false", output) + ExprCode(code, FalseLiteral, VariableValue(output, "InternalRow")) } private def createCodeForArray( @@ -89,8 +91,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val index = ctx.freshName("index") val arrayClass = classOf[GenericArrayData].getName - val elementConverter = convertToSafe( - ctx, CodeGenerator.getValue(tmpInput, elementType, index), elementType) + val elementConverter = convertToSafe(ctx, + StatementValue(CodeGenerator.getValue(tmpInput, elementType, index), + CodeGenerator.javaType(elementType)), elementType) val code = s""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); @@ -104,7 +107,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final ArrayData $output = new $arrayClass($values); """ - ExprCode(code, "false", output) + ExprCode(code, FalseLiteral, VariableValue(output, "ArrayData")) } private def createCodeForMap( @@ -125,19 +128,19 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value}); """ - ExprCode(code, "false", output) + ExprCode(code, FalseLiteral, VariableValue(output, "MapData")) } @tailrec private def convertToSafe( ctx: CodegenContext, - input: String, + input: ExprValue, dataType: DataType): ExprCode = dataType match { case s: StructType => createCodeForStruct(ctx, input, s) case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) - case _ => ExprCode("", "false", input) + case _ => ExprCode("", FalseLiteral, input) } protected def create(expressions: Seq[Expression]): Projection = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index ab2254cd9f70a..4a4d76313a543 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -52,7 +52,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - ExprCode("", s"$tmpInput.isNullAt($i)", CodeGenerator.getValue(tmpInput, dt, i.toString)) + ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)", CodeGenerator.JAVA_BOOLEAN), + StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString), + CodeGenerator.javaType(dt))) } val rowWriterClass = classOf[UnsafeRowWriter].getName @@ -334,7 +336,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $evalSubexpr $writeExpressions """ - ExprCode(code, "false", s"$rowWriter.getRow()") + // `rowWriter` is declared as a class field, so we can access it directly in methods. + ExprCode(code, FalseLiteral, StatementValue(s"$rowWriter.getRow()", "UnsafeRow", + canDirectAccess = true)) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index beb84694c44e8..91188da8b0bd3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -20,7 +20,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -55,7 +55,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType boolean ${ev.isNull} = false; ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : - (${childGen.value}).numElements();""", isNull = "false") + (${childGen.value}).numElements();""", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 85facdad43db7..49a8d12057188 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -64,8 +64,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression { GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( code = preprocess + assigns + postprocess, - value = arrayData, - isNull = "false") + value = VariableValue(arrayData, CodeGenerator.javaType(dataType)), + isNull = FalseLiteral) } override def prettyName: String = "array" @@ -378,7 +378,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc |$valuesCode |final InternalRow ${ev.value} = new $rowClass($values); |$values = null; - """.stripMargin, isNull = "false") + """.stripMargin, isNull = FalseLiteral) } override def prettyName: String = "named_struct" @@ -394,7 +394,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) - ExprCode(code = eval.code, isNull = "false", value = eval.value) + ExprCode(code = eval.code, isNull = FalseLiteral, value = eval.value) } override def prettyName: String = "named_struct_unsafe" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index f4e9619bac59d..409c0b6b79b81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -191,7 +191,8 @@ case class CaseWhen( // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, // We won't go on anymore on the computation. val resultState = ctx.freshName("caseWhenResultState") - ev.value = ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value) + ev.value = GlobalValue(ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value), + CodeGenerator.javaType(dataType)) // these blocks are meant to be inside a // do { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 1ae4e5a2f716b..49dd988b4b53c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -813,7 +813,8 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ val df = classOf[DateFormat].getName if (format.foldable) { if (formatter == null) { - ExprCode("", "true", "(UTF8String) null") + ExprCode("", TrueLiteral, LiteralValue("(UTF8String) null", + CodeGenerator.javaType(dataType))) } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val t = left.genCode(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 4f4d49166e88c..3af4bfebad45e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ @@ -218,7 +218,7 @@ case class Stack(children: Seq[Expression]) extends Generator { s""" |$code |$wrapperClass ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData); - """.stripMargin, isNull = "false") + """.stripMargin, isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index b76b64ab5096f..df29c38d64d3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -270,7 +270,7 @@ abstract class HashExpression[E] extends Expression { protected def computeHash(value: Any, dataType: DataType, seed: E): E override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = "false" + ev.isNull = FalseLiteral val childrenHash = children.map { child => val childGen = child.genCode(ctx) @@ -644,7 +644,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = "false" + ev.isNull = FalseLiteral val childHash = ctx.freshName("childHash") val childrenHash = children.map { child => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 07785e7448586..2a3cc580273ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -43,7 +43,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getInputFilePath();", isNull = "false") + s"$className.getInputFilePath();", isNull = FalseLiteral) } } @@ -66,7 +66,7 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getStartOffset();", isNull = "false") + s"$className.getStartOffset();", isNull = FalseLiteral) } } @@ -89,6 +89,6 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getLength();", isNull = "false") + s"$className.getLength();", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 7395609a04ba5..742a650eb445d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -283,36 +283,36 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { } else { dataType match { case BooleanType | IntegerType | DateType => - ExprCode.forNonNullValue(value.toString) + ExprCode.forNonNullValue(LiteralValue(value.toString, javaType)) case FloatType => value.asInstanceOf[Float] match { case v if v.isNaN => - ExprCode.forNonNullValue("Float.NaN") + ExprCode.forNonNullValue(LiteralValue("Float.NaN", javaType)) case Float.PositiveInfinity => - ExprCode.forNonNullValue("Float.POSITIVE_INFINITY") + ExprCode.forNonNullValue(LiteralValue("Float.POSITIVE_INFINITY", javaType)) case Float.NegativeInfinity => - ExprCode.forNonNullValue("Float.NEGATIVE_INFINITY") + ExprCode.forNonNullValue(LiteralValue("Float.NEGATIVE_INFINITY", javaType)) case _ => - ExprCode.forNonNullValue(s"${value}F") + ExprCode.forNonNullValue(LiteralValue(s"${value}F", javaType)) } case DoubleType => value.asInstanceOf[Double] match { case v if v.isNaN => - ExprCode.forNonNullValue("Double.NaN") + ExprCode.forNonNullValue(LiteralValue("Double.NaN", javaType)) case Double.PositiveInfinity => - ExprCode.forNonNullValue("Double.POSITIVE_INFINITY") + ExprCode.forNonNullValue(LiteralValue("Double.POSITIVE_INFINITY", javaType)) case Double.NegativeInfinity => - ExprCode.forNonNullValue("Double.NEGATIVE_INFINITY") + ExprCode.forNonNullValue(LiteralValue("Double.NEGATIVE_INFINITY", javaType)) case _ => - ExprCode.forNonNullValue(s"${value}D") + ExprCode.forNonNullValue(LiteralValue(s"${value}D", javaType)) } case ByteType | ShortType => - ExprCode.forNonNullValue(s"($javaType)$value") + ExprCode.forNonNullValue(LiteralValue(s"($javaType)$value", javaType)) case TimestampType | LongType => - ExprCode.forNonNullValue(s"${value}L") + ExprCode.forNonNullValue(LiteralValue(s"${value}L", javaType)) case _ => val constRef = ctx.addReferenceObj("literal", value, javaType) - ExprCode.forNonNullValue(constRef) + ExprCode.forNonNullValue(GlobalValue(constRef, javaType)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index a390f8ef7fd9a..7081a5e096d56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -91,7 +91,8 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa ExprCode(code = s"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); - |}""".stripMargin, isNull = "true", value = "null") + |}""".stripMargin, isNull = TrueLiteral, + value = LiteralValue("null", CodeGenerator.javaType(dataType))) } override def sql: String = s"assert_true(${child.sql})" @@ -150,7 +151,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" + s"${randomSeed.get}L + partitionIndex);") ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", - isNull = "false") + isNull = FalseLiteral) } override def freshCopy(): Uuid = Uuid(randomSeed) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index b35fa72e95d1e..55b6e346be82a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -72,7 +72,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) + ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), + CodeGenerator.JAVA_BOOLEAN) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => @@ -235,7 +236,7 @@ case class IsNaN(child: Expression) extends UnaryExpression ev.copy(code = s""" ${eval.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = "false") + ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral) } } } @@ -320,7 +321,12 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - ExprCode(code = eval.code, isNull = "false", value = eval.isNull) + val value = if (eval.isNull.isInstanceOf[LiteralValue]) { + LiteralValue(eval.isNull, CodeGenerator.JAVA_BOOLEAN) + } else { + VariableValue(eval.isNull, CodeGenerator.JAVA_BOOLEAN) + } + ExprCode(code = eval.code, isNull = FalseLiteral, value = value) } override def sql: String = s"(${child.sql} IS NULL)" @@ -346,7 +352,14 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - ExprCode(code = eval.code, isNull = "false", value = s"(!(${eval.isNull}))") + val value = if (eval.isNull == TrueLiteral) { + FalseLiteral + } else if (eval.isNull == FalseLiteral) { + TrueLiteral + } else { + StatementValue(s"(!(${eval.isNull}))", CodeGenerator.javaType(dataType)) + } + ExprCode(code = eval.code, isNull = FalseLiteral, value = value) } override def sql: String = s"(${child.sql} IS NOT NULL)" @@ -441,6 +454,6 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate | $codes |} while (false); |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; - """.stripMargin, isNull = "false") + """.stripMargin, isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 9252425f86473..b2cca3178cd2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -61,13 +61,13 @@ trait InvokeLike extends Expression with NonSQLExpression { * @param ctx a [[CodegenContext]] * @return (code to prepare arguments, argument string, result of argument null check) */ - def prepareArguments(ctx: CodegenContext): (String, String, String) = { + def prepareArguments(ctx: CodegenContext): (String, String, ExprValue) = { val resultIsNull = if (needNullCheck) { val resultIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "resultIsNull") - resultIsNull + GlobalValue(resultIsNull, CodeGenerator.JAVA_BOOLEAN) } else { - "false" + FalseLiteral } val argValues = arguments.map { e => val argValue = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "argValue") @@ -244,7 +244,7 @@ case class StaticInvoke( val prepareIsNull = if (nullable) { s"boolean ${ev.isNull} = $resultIsNull;" } else { - ev.isNull = "false" + ev.isNull = FalseLiteral "" } @@ -546,7 +546,7 @@ case class WrapOption(child: Expression, optType: DataType) ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); """ - ev.copy(code = code, isNull = "false") + ev.copy(code = code, isNull = FalseLiteral) } } @@ -568,7 +568,13 @@ case class LambdaVariable( } override def genCode(ctx: CodegenContext): ExprCode = { - ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false") + val isNullValue = if (nullable) { + VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN) + } else { + FalseLiteral + } + ExprCode(code = "", value = VariableValue(value, CodeGenerator.javaType(dataType)), + isNull = isNullValue) } // This won't be called as `genCode` is overrided, just overriding it to make @@ -840,7 +846,7 @@ case class MapObjects private( // Make a copy of the data if it's unsafe-backed def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = s"$value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value" - val genFunctionValue = lambdaFunction.dataType match { + val genFunctionValue: String = lambdaFunction.dataType match { case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value) case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) @@ -1343,7 +1349,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) |$childrenCode |final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); """.stripMargin - ev.copy(code = code, isNull = "false") + ev.copy(code = code, isNull = FalseLiteral) } } @@ -1538,7 +1544,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) throw new NullPointerException($errMsgField); } """ - ev.copy(code = code, isNull = "false", value = childGen.value) + ev.copy(code = code, isNull = FalseLiteral, value = childGen.value) } } @@ -1589,7 +1595,7 @@ case class GetExternalRowField( final Object ${ev.value} = ${row.value}.get($index); """ - ev.copy(code = code, isNull = "false") + ev.copy(code = code, isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 4b85d9adbe311..e195ec17f3bcf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -21,7 +21,7 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -405,7 +405,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with if (${eval1.value}) { ${eval2.code} ${ev.value} = ${eval2.value}; - }""", isNull = "false") + }""", isNull = FalseLiteral) } else { ev.copy(code = s""" ${eval1.code} @@ -461,7 +461,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P // The result should be `true`, if any of them is `true` whenever the other is null or not. if (!left.nullable && !right.nullable) { - ev.isNull = "false" + ev.isNull = FalseLiteral ev.copy(code = s""" ${eval1.code} boolean ${ev.value} = true; @@ -469,7 +469,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P if (!${eval1.value}) { ${eval2.code} ${ev.value} = ${eval2.value}; - }""", isNull = "false") + }""", isNull = FalseLiteral) } else { ev.copy(code = s""" ${eval1.code} @@ -615,7 +615,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value) ev.copy(code = eval1.code + eval2.code + s""" boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) || - (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = "false") + (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index f36633867316e..70186053617f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -83,7 +83,7 @@ case class Rand(child: Expression) extends RDG { s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", - isNull = "false") + isNull = FalseLiteral) } override def freshCopy(): Rand = Rand(child) @@ -120,7 +120,7 @@ case class Randn(child: Expression) extends RDG { s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", - isNull = "false") + isNull = FalseLiteral) } override def freshCopy(): Randn = Randn(child) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 398b6767654fa..8e83b35c3809c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -448,6 +448,8 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val ref = BoundReference(0, IntegerType, true) val add1 = Add(ref, ref) val add2 = Add(add1, add1) + val dummy = SubExprEliminationState(VariableValue("dummy", "boolean"), + VariableValue("dummy", "boolean")) // raw testing of basic functionality { @@ -457,7 +459,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { ctx.subExprEliminationExprs += ref -> SubExprEliminationState(e.isNull, e.value) assert(ctx.subExprEliminationExprs.contains(ref)) // call withSubExprEliminationExprs - ctx.withSubExprEliminationExprs(Map(add1 -> SubExprEliminationState("dummy", "dummy"))) { + ctx.withSubExprEliminationExprs(Map(add1 -> dummy)) { assert(ctx.subExprEliminationExprs.contains(add1)) assert(!ctx.subExprEliminationExprs.contains(ref)) Seq.empty @@ -475,7 +477,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE assert(ctx.subExprEliminationExprs.contains(add1)) // call withSubExprEliminationExprs - ctx.withSubExprEliminationExprs(Map(ref -> SubExprEliminationState("dummy", "dummy"))) { + ctx.withSubExprEliminationExprs(Map(ref -> dummy)) { assert(ctx.subExprEliminationExprs.contains(ref)) assert(!ctx.subExprEliminationExprs.contains(add1)) Seq.empty diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala new file mode 100644 index 0000000000000..c8f4cff7db48d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite + +class ExprValueSuite extends SparkFunSuite { + + test("TrueLiteral and FalseLiteral should be LiteralValue") { + val trueLit = TrueLiteral + val falseLit = FalseLiteral + + assert(trueLit.value == "true") + assert(falseLit.value == "false") + + assert(trueLit.isPrimitive) + assert(falseLit.isPrimitive) + + trueLit match { + case LiteralValue(value, javaType) => + assert(value == "true" && javaType == "boolean") + case _ => fail() + } + + falseLit match { + case LiteralValue(value, javaType) => + assert(value == "false" && javaType == "boolean") + case _ => fail() + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 392906a022903..434214a10e1e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, VariableValue} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -51,7 +51,11 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { nullable: Boolean): ExprCode = { val javaType = CodeGenerator.javaType(dataType) val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal) - val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" } + val isNullVar = if (nullable) { + VariableValue(ctx.freshName("isNull"), CodeGenerator.JAVA_BOOLEAN) + } else { + FalseLiteral + } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" val code = s"${ctx.registerComment(str)}\n" + (if (nullable) { @@ -62,7 +66,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } else { s"$javaType $valueVar = $value;" }).trim - ExprCode(code, isNullVar, valueVar) + ExprCode(code, isNullVar, VariableValue(valueVar, javaType)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 12ae1ea4a7c13..0d9a62cace62a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, VariableValue} import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -157,7 +157,8 @@ case class ExpandExec( |${CodeGenerator.javaType(firstExpr.dataType)} $value = | ${CodeGenerator.defaultValue(firstExpr.dataType)}; """.stripMargin - ExprCode(code, isNull, value) + ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), + VariableValue(value, CodeGenerator.javaType(firstExpr.dataType))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 384f0398a1ec0..85c5ebfdaa689 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} @@ -170,9 +170,10 @@ case class GenerateExec( // Add position val position = if (e.position) { if (outer) { - Seq(ExprCode("", s"$index == -1", index)) + Seq(ExprCode("", StatementValue(s"$index == -1", CodeGenerator.JAVA_BOOLEAN), + VariableValue(index, CodeGenerator.JAVA_INT))) } else { - Seq(ExprCode("", "false", index)) + Seq(ExprCode("", FalseLiteral, VariableValue(index, CodeGenerator.JAVA_INT))) } } else { Seq.empty @@ -315,9 +316,11 @@ case class GenerateExec( |boolean $isNull = ${checks.mkString(" || ")}; |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter; """.stripMargin - ExprCode(code, isNull, value) + ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), + VariableValue(value, javaType)) } else { - ExprCode(s"$javaType $value = $getter;", "false", value) + ExprCode(s"$javaType $value = $getter;", FalseLiteral, + VariableValue(value, javaType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 6ddaacfee1a40..805ff3cf001ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -111,7 +111,7 @@ trait CodegenSupport extends SparkPlan { private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = { if (row != null) { - ExprCode("", "false", row) + ExprCode("", FalseLiteral, VariableValue(row, "UnsafeRow")) } else { if (colVars.nonEmpty) { val colExprs = output.zipWithIndex.map { case (attr, i) => @@ -126,10 +126,10 @@ trait CodegenSupport extends SparkPlan { |$evaluateInputs |${ev.code.trim} """.stripMargin.trim - ExprCode(code, "false", ev.value) + ExprCode(code, FalseLiteral, ev.value) } else { // There is no columns - ExprCode("", "false", "unsafeRow") + ExprCode("", FalseLiteral, VariableValue("unsafeRow", "UnsafeRow")) } } } @@ -241,15 +241,16 @@ trait CodegenSupport extends SparkPlan { parameters += s"$paramType $paramName" val paramIsNull = if (!attributes(i).nullable) { // Use constant `false` without passing `isNull` for non-nullable variable. - "false" + FalseLiteral } else { val isNull = ctx.freshName(s"exprIsNull_$i") arguments += ev.isNull parameters += s"boolean $isNull" - isNull + VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN) } - paramVars += ExprCode("", paramIsNull, paramName) + paramVars += ExprCode("", paramIsNull, + VariableValue(paramName, CodeGenerator.javaType(attributes(i).dataType))) } (arguments, parameters, paramVars) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 1926e9373bc55..8f7f10243d4cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -194,7 +194,8 @@ case class HashAggregateExec( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, isNull, value) + ExprCode(ev.code + initVars, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), + GlobalValue(value, CodeGenerator.javaType(e.dataType))) } val initBufVar = evaluateVariables(bufVars) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 6b60b414ffe5f..4978954271311 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, GlobalValue} import org.apache.spark.sql.types._ /** @@ -54,7 +54,8 @@ abstract class HashMapGenerator( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, isNull, value) + ExprCode(ev.code + initVars, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), + GlobalValue(value, CodeGenerator.javaType(e.dataType))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 4707022f74547..cab7081400ce9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -24,7 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, ExpressionCanonicalizer} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType @@ -192,7 +192,7 @@ case class FilterExec(condition: Expression, child: SparkPlan) // generate better code (remove dead branches). val resultVars = input.zipWithIndex.map { case (ev, i) => if (notNullAttributes.contains(child.output(i).exprId)) { - ev.isNull = "false" + ev.isNull = FalseLiteral } ev } @@ -368,7 +368,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number") val value = ctx.freshName("value") - val ev = ExprCode("", "false", value) + val ev = ExprCode("", FalseLiteral, VariableValue(value, CodeGenerator.JAVA_LONG)) val BigInt = classOf[java.math.BigInteger].getName // Inline mutable state since not many Range operations in a task diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 487d6a2383318..fa62a32d51f3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -192,7 +192,8 @@ case class BroadcastHashJoinExec( | $value = ${ev.value}; |} """.stripMargin - ExprCode(code, isNull, value) + ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), + VariableValue(value, CodeGenerator.javaType(a.dataType))) } } } @@ -487,7 +488,8 @@ case class BroadcastHashJoinExec( s"$existsVar = true;" } - val resultVar = input ++ Seq(ExprCode("", "false", existsVar)) + val resultVar = input ++ Seq(ExprCode("", FalseLiteral, + VariableValue(existsVar, CodeGenerator.JAVA_BOOLEAN))) if (broadcastRelation.value.keyIsUnique) { s""" |// generate join key for stream side diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 5a511b30e4fd9..b61acb8d4fda9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, VariableValue} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, @@ -531,11 +531,13 @@ case class SortMergeJoinExec( |boolean $isNull = false; |$javaType $value = $defaultValue; """.stripMargin - (ExprCode(code, isNull, value), leftVarsDecl) + (ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), + VariableValue(value, CodeGenerator.javaType(a.dataType))), leftVarsDecl) } else { val code = s"$value = $valueCode;" val leftVarsDecl = s"""$javaType $value = $defaultValue;""" - (ExprCode(code, "false", value), leftVarsDecl) + (ExprCode(code, FalseLiteral, + VariableValue(value, CodeGenerator.javaType(a.dataType))), leftVarsDecl) } }.unzip } From 252468a744b95082400ba9e8b2e3b3d9d50ab7fa Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 9 Apr 2018 12:18:07 -0700 Subject: [PATCH 0588/2461] [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes ## What changes were proposed in this pull request? API: ``` trait ClassificationNode extends Node def getLabelCount(label: Int): Double trait RegressionNode extends Node def getCount(): Double def getSum(): Double def getSquareSum(): Double // turn LeafNode to be trait trait LeafNode extends Node { def prediction: Double def impurity: Double ... } class ClassificationLeafNode extends ClassificationNode with LeafNode class RegressionLeafNode extends RegressionNode with LeafNode // turn InternalNode to be trait trait InternalNode extends Node{ def gain: Double def leftChild: Node def rightChild: Node def split: Split ... } class ClassificationInternalNode extends ClassificationNode with InternalNode override def leftChild: ClassificationNode override def rightChild: ClassificationNode class RegressionInternalNode extends RegressionNode with InternalNode override val leftChild: RegressionNode override val rightChild: RegressionNode class DecisionTreeClassificationModel override val rootNode: ClassificationNode class DecisionTreeRegressionModel override val rootNode: RegressionNode ``` Closes #17466 ## How was this patch tested? UT will be added soon. Author: WeichenXu Author: jkbradley Closes #20786 from WeichenXu123/tree_stat_api_2. --- .../DecisionTreeClassifier.scala | 14 +- .../ml/classification/GBTClassifier.scala | 6 +- .../RandomForestClassifier.scala | 6 +- .../ml/regression/DecisionTreeRegressor.scala | 13 +- .../spark/ml/regression/GBTRegressor.scala | 6 +- .../ml/regression/RandomForestRegressor.scala | 6 +- .../scala/org/apache/spark/ml/tree/Node.scala | 247 ++++++++++++++---- .../spark/ml/tree/impl/RandomForest.scala | 10 +- .../org/apache/spark/ml/tree/treeModels.scala | 36 ++- .../DecisionTreeClassifierSuite.scala | 31 ++- .../classification/GBTClassifierSuite.scala | 4 +- .../RandomForestClassifierSuite.scala | 5 +- .../DecisionTreeRegressorSuite.scala | 14 + .../ml/tree/impl/RandomForestSuite.scala | 22 +- .../apache/spark/ml/tree/impl/TreeTests.scala | 12 +- project/MimaExcludes.scala | 9 +- 16 files changed, 333 insertions(+), 108 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 65cce697d8202..771cd4fe91dcf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -165,7 +165,7 @@ object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifi @Since("1.4.0") class DecisionTreeClassificationModel private[ml] ( @Since("1.4.0")override val uid: String, - @Since("1.4.0")override val rootNode: Node, + @Since("1.4.0")override val rootNode: ClassificationNode, @Since("1.6.0")override val numFeatures: Int, @Since("1.5.0")override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel] @@ -178,7 +178,7 @@ class DecisionTreeClassificationModel private[ml] ( * Construct a decision tree classification model. * @param rootNode Root node of tree, with other nodes attached. */ - private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) = + private[ml] def this(rootNode: ClassificationNode, numFeatures: Int, numClasses: Int) = this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses) override def predict(features: Vector): Double = { @@ -276,8 +276,9 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] - val root = loadTreeNodes(path, metadata, sparkSession) - val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses) + val root = loadTreeNodes(path, metadata, sparkSession, isClassification = true) + val model = new DecisionTreeClassificationModel(metadata.uid, + root.asInstanceOf[ClassificationNode], numFeatures, numClasses) DefaultParamsReader.getAndSetParams(model, metadata) model } @@ -292,9 +293,10 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica require(oldModel.algo == OldAlgo.Classification, s"Cannot convert non-classification DecisionTreeModel (old API) to" + s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}") - val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = true) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc") // Can't infer number of features from old model, so default to -1 - new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1) + new DecisionTreeClassificationModel(uid, + rootNode.asInstanceOf[ClassificationNode], numFeatures, -1) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index cd44489f618b2..c0255103bc313 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -371,14 +371,14 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { override def load(path: String): GBTClassificationModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false) val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int] val numTrees = (metadata.metadata \ numTreesKey).extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => - val tree = - new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + val tree = new DecisionTreeRegressionModel(treeMetadata.uid, + root.asInstanceOf[RegressionNode], numFeatures) DefaultParamsReader.getAndSetParams(tree, treeMetadata) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 78a4972adbdbb..bb972e9706fc1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -310,15 +310,15 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica override def load(path: String): RandomForestClassificationModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, true) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] val trees: Array[DecisionTreeClassificationModel] = treesData.map { case (treeMetadata, root) => - val tree = - new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses) + val tree = new DecisionTreeClassificationModel(treeMetadata.uid, + root.asInstanceOf[ClassificationNode], numFeatures, numClasses) DefaultParamsReader.getAndSetParams(tree, treeMetadata) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index ad154fcd010cc..5cef5c9f21f1e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -160,7 +160,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor @Since("1.4.0") class DecisionTreeRegressionModel private[ml] ( override val uid: String, - override val rootNode: Node, + override val rootNode: RegressionNode, override val numFeatures: Int) extends PredictionModel[Vector, DecisionTreeRegressionModel] with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with Serializable { @@ -175,7 +175,7 @@ class DecisionTreeRegressionModel private[ml] ( * Construct a decision tree regression model. * @param rootNode Root node of tree, with other nodes attached. */ - private[ml] def this(rootNode: Node, numFeatures: Int) = + private[ml] def this(rootNode: RegressionNode, numFeatures: Int) = this(Identifiable.randomUID("dtr"), rootNode, numFeatures) override def predict(features: Vector): Double = { @@ -279,8 +279,9 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] - val root = loadTreeNodes(path, metadata, sparkSession) - val model = new DecisionTreeRegressionModel(metadata.uid, root, numFeatures) + val root = loadTreeNodes(path, metadata, sparkSession, isClassification = false) + val model = new DecisionTreeRegressionModel(metadata.uid, + root.asInstanceOf[RegressionNode], numFeatures) DefaultParamsReader.getAndSetParams(model, metadata) model } @@ -295,8 +296,8 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode require(oldModel.algo == OldAlgo.Regression, s"Cannot convert non-regression DecisionTreeModel (old API) to" + s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}") - val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = false) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr") - new DecisionTreeRegressionModel(uid, rootNode, numFeatures) + new DecisionTreeRegressionModel(uid, rootNode.asInstanceOf[RegressionNode], numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 6569ff2a5bfc1..834aaa0e362d1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -302,15 +302,15 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { override def load(path: String): GBTRegressionModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => - val tree = - new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + val tree = new DecisionTreeRegressionModel(treeMetadata.uid, + root.asInstanceOf[RegressionNode], numFeatures) DefaultParamsReader.getAndSetParams(tree, treeMetadata) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 2d594460c2475..7f77398ba2a22 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -269,13 +269,13 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode override def load(path: String): RandomForestRegressionModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => - val tree = - new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + val tree = new DecisionTreeRegressionModel(treeMetadata.uid, + root.asInstanceOf[RegressionNode], numFeatures) DefaultParamsReader.getAndSetParams(tree, treeMetadata) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index d30be452a436e..0242bc76698d0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -17,14 +17,16 @@ package org.apache.spark.ml.tree +import org.apache.spark.annotation.Since import org.apache.spark.ml.linalg.Vector import org.apache.spark.mllib.tree.impurity.ImpurityCalculator -import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} +import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, + Node => OldNode, Predict => OldPredict} /** * Decision tree node interface. */ -sealed abstract class Node extends Serializable { +sealed trait Node extends Serializable { // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree // code into the new API and deprecate the old API. SPARK-3727 @@ -84,35 +86,86 @@ private[ml] object Node { /** * Create a new Node from the old Node format, recursively creating child nodes as needed. */ - def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = { + def fromOld( + oldNode: OldNode, + categoricalFeatures: Map[Int, Int], + isClassification: Boolean): Node = { if (oldNode.isLeaf) { // TODO: Once the implementation has been moved to this API, then include sufficient // statistics here. - new LeafNode(prediction = oldNode.predict.predict, - impurity = oldNode.impurity, impurityStats = null) + if (isClassification) { + new ClassificationLeafNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, impurityStats = null) + } else { + new RegressionLeafNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, impurityStats = null) + } } else { val gain = if (oldNode.stats.nonEmpty) { oldNode.stats.get.gain } else { 0.0 } - new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity, - gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures), - rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures), - split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) + if (isClassification) { + new ClassificationInternalNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, gain = gain, + leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, true) + .asInstanceOf[ClassificationNode], + rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, true) + .asInstanceOf[ClassificationNode], + split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) + } else { + new RegressionInternalNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, gain = gain, + leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, false) + .asInstanceOf[RegressionNode], + rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, false) + .asInstanceOf[RegressionNode], + split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) + } } } } -/** - * Decision tree leaf node. - * @param prediction Prediction this node makes - * @param impurity Impurity measure at this node (for training data) - */ -class LeafNode private[ml] ( - override val prediction: Double, - override val impurity: Double, - override private[ml] val impurityStats: ImpurityCalculator) extends Node { +@Since("2.4.0") +sealed trait ClassificationNode extends Node { + + /** + * Get count of training examples for specified label in this node + * @param label label number in the range [0, numClasses) + */ + @Since("2.4.0") + def getLabelCount(label: Int): Double = { + require(label >= 0 && label < impurityStats.stats.length, + "label should be in the range between 0 (inclusive) " + + s"and ${impurityStats.stats.length} (exclusive).") + impurityStats.stats(label) + } +} + +@Since("2.4.0") +sealed trait RegressionNode extends Node { + + /** Number of training data points in this node */ + @Since("2.4.0") + def getCount: Double = impurityStats.stats(0) + + /** Sum over training data points of the labels in this node */ + @Since("2.4.0") + def getSum: Double = impurityStats.stats(1) + + /** Sum over training data points of the square of the labels in this node */ + @Since("2.4.0") + def getSumOfSquares: Double = impurityStats.stats(2) +} + +@Since("2.4.0") +sealed trait LeafNode extends Node { + + /** Prediction this node makes. */ + def prediction: Double + + def impurity: Double override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)" @@ -135,32 +188,58 @@ class LeafNode private[ml] ( override private[ml] def maxSplitFeatureIndex(): Int = -1 +} + +/** + * Decision tree leaf node for classification. + */ +@Since("2.4.0") +class ClassificationLeafNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + override private[ml] val impurityStats: ImpurityCalculator) + extends ClassificationNode with LeafNode { + override private[tree] def deepCopy(): Node = { - new LeafNode(prediction, impurity, impurityStats) + new ClassificationLeafNode(prediction, impurity, impurityStats) } } /** - * Internal Decision Tree node. - * @param prediction Prediction this node would make if it were a leaf node - * @param impurity Impurity measure at this node (for training data) - * @param gain Information gain value. Values less than 0 indicate missing values; - * this quirk will be removed with future updates. - * @param leftChild Left-hand child node - * @param rightChild Right-hand child node - * @param split Information about the test used to split to the left or right child. + * Decision tree leaf node for regression. */ -class InternalNode private[ml] ( +@Since("2.4.0") +class RegressionLeafNode private[ml] ( override val prediction: Double, override val impurity: Double, - val gain: Double, - val leftChild: Node, - val rightChild: Node, - val split: Split, - override private[ml] val impurityStats: ImpurityCalculator) extends Node { + override private[ml] val impurityStats: ImpurityCalculator) + extends RegressionNode with LeafNode { - // Note to developers: The constructor argument impurityStats should be reconsidered before we - // make the constructor public. We may be able to improve the representation. + override private[tree] def deepCopy(): Node = { + new RegressionLeafNode(prediction, impurity, impurityStats) + } +} + +/** + * Internal Decision Tree node. + */ +@Since("2.4.0") +sealed trait InternalNode extends Node { + + /** + * Information gain value. Values less than 0 indicate missing values; + * this quirk will be removed with future updates. + */ + def gain: Double + + /** Left-hand child node */ + def leftChild: Node + + /** Right-hand child node */ + def rightChild: Node + + /** Information about the test used to split to the left or right child. */ + def split: Split override def toString: String = { s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" @@ -205,11 +284,6 @@ class InternalNode private[ml] ( math.max(split.featureIndex, math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex())) } - - override private[tree] def deepCopy(): Node = { - new InternalNode(prediction, impurity, gain, leftChild.deepCopy(), rightChild.deepCopy(), - split, impurityStats) - } } private object InternalNode { @@ -240,6 +314,57 @@ private object InternalNode { } } +/** + * Internal Decision Tree node for regression. + */ +@Since("2.4.0") +class ClassificationInternalNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + override val gain: Double, + override val leftChild: ClassificationNode, + override val rightChild: ClassificationNode, + override val split: Split, + override private[ml] val impurityStats: ImpurityCalculator) + extends ClassificationNode with InternalNode { + + // Note to developers: The constructor argument impurityStats should be reconsidered before we + // make the constructor public. We may be able to improve the representation. + + override private[tree] def deepCopy(): Node = { + new ClassificationInternalNode(prediction, impurity, gain, + leftChild.deepCopy().asInstanceOf[ClassificationNode], + rightChild.deepCopy().asInstanceOf[ClassificationNode], + split, impurityStats) + } +} + +/** + * Internal Decision Tree node for regression. + */ +@Since("2.4.0") +class RegressionInternalNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + override val gain: Double, + override val leftChild: RegressionNode, + override val rightChild: RegressionNode, + override val split: Split, + override private[ml] val impurityStats: ImpurityCalculator) + extends RegressionNode with InternalNode { + + // Note to developers: The constructor argument impurityStats should be reconsidered before we + // make the constructor public. We may be able to improve the representation. + + override private[tree] def deepCopy(): Node = { + new RegressionInternalNode(prediction, impurity, gain, + leftChild.deepCopy().asInstanceOf[RegressionNode], + rightChild.deepCopy().asInstanceOf[RegressionNode], + split, impurityStats) + } +} + + /** * Version of a node used in learning. This uses vars so that we can modify nodes as we split the * tree by adding children, etc. @@ -265,30 +390,52 @@ private[tree] class LearningNode( var isLeaf: Boolean, var stats: ImpurityStats) extends Serializable { - def toNode: Node = toNode(prune = true) + def toNode(isClassification: Boolean): Node = toNode(isClassification, prune = true) + + def toClassificationNode(prune: Boolean = true): ClassificationNode = { + toNode(true, prune).asInstanceOf[ClassificationNode] + } + + def toRegressionNode(prune: Boolean = true): RegressionNode = { + toNode(false, prune).asInstanceOf[RegressionNode] + } /** * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children. */ - def toNode(prune: Boolean = true): Node = { + def toNode(isClassification: Boolean, prune: Boolean): Node = { if (!leftChild.isEmpty || !rightChild.isEmpty) { assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null, "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.") - (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match { + (leftChild.get.toNode(isClassification, prune), + rightChild.get.toNode(isClassification, prune)) match { case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction => - new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator) + if (isClassification) { + new ClassificationLeafNode(l.prediction, stats.impurity, stats.impurityCalculator) + } else { + new RegressionLeafNode(l.prediction, stats.impurity, stats.impurityCalculator) + } case (l, r) => - new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, - l, r, split.get, stats.impurityCalculator) + if (isClassification) { + new ClassificationInternalNode(stats.impurityCalculator.predict, stats.impurity, + stats.gain, l.asInstanceOf[ClassificationNode], r.asInstanceOf[ClassificationNode], + split.get, stats.impurityCalculator) + } else { + new RegressionInternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, + l.asInstanceOf[RegressionNode], r.asInstanceOf[RegressionNode], + split.get, stats.impurityCalculator) + } } } else { - if (stats.valid) { - new LeafNode(stats.impurityCalculator.predict, stats.impurity, + // Here we want to keep same behavior with the old mllib.DecisionTreeModel + val impurity = if (stats.valid) stats.impurity else -1.0 + if (isClassification) { + new ClassificationLeafNode(stats.impurityCalculator.predict, impurity, stats.impurityCalculator) } else { - // Here we want to keep same behavior with the old mllib.DecisionTreeModel - new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) + new RegressionLeafNode(stats.impurityCalculator.predict, impurity, + stats.impurityCalculator) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 16f32d76b9984..056a94b351f79 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -224,23 +224,23 @@ private[spark] object RandomForest extends Logging { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures, - strategy.getNumClasses) + new DecisionTreeClassificationModel(uid, rootNode.toClassificationNode(prune), + numFeatures, strategy.getNumClasses) } } else { topNodes.map { rootNode => - new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures) + new DecisionTreeRegressionModel(uid, rootNode.toRegressionNode(prune), numFeatures) } } case None => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures, + new DecisionTreeClassificationModel(rootNode.toClassificationNode(prune), numFeatures, strategy.getNumClasses) } } else { topNodes.map(rootNode => - new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures)) + new DecisionTreeRegressionModel(rootNode.toRegressionNode(prune), numFeatures)) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 4aa4c3617e7fd..f027b14f1d476 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -219,8 +219,10 @@ private[ml] object TreeEnsembleModel { importances.changeValue(feature, scaledGain, _ + scaledGain) computeFeatureImportance(n.leftChild, importances) computeFeatureImportance(n.rightChild, importances) - case n: LeafNode => + case _: LeafNode => // do nothing + case _ => + throw new IllegalArgumentException(s"Unknown node type: ${node.getClass.toString}") } } @@ -317,6 +319,8 @@ private[ml] object DecisionTreeModelReadWrite { (Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats, -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))), id) + case _ => + throw new IllegalArgumentException(s"Unknown node type: ${node.getClass.toString}") } } @@ -327,7 +331,7 @@ private[ml] object DecisionTreeModelReadWrite { def loadTreeNodes( path: String, metadata: DefaultParamsReader.Metadata, - sparkSession: SparkSession): Node = { + sparkSession: SparkSession, isClassification: Boolean): Node = { import sparkSession.implicits._ implicit val format = DefaultFormats @@ -339,7 +343,7 @@ private[ml] object DecisionTreeModelReadWrite { val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath).as[NodeData] - buildTreeFromNodes(data.collect(), impurityType) + buildTreeFromNodes(data.collect(), impurityType, isClassification) } /** @@ -348,7 +352,8 @@ private[ml] object DecisionTreeModelReadWrite { * @param impurityType Impurity type for this tree * @return Root node of reconstructed tree */ - def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = { + def buildTreeFromNodes(data: Array[NodeData], impurityType: String, + isClassification: Boolean): Node = { // Load all nodes, sorted by ID. val nodes = data.sortBy(_.id) // Sanity checks; could remove @@ -364,10 +369,21 @@ private[ml] object DecisionTreeModelReadWrite { val node = if (n.leftChild != -1) { val leftChild = finalNodes(n.leftChild) val rightChild = finalNodes(n.rightChild) - new InternalNode(n.prediction, n.impurity, n.gain, leftChild, rightChild, - n.split.getSplit, impurityStats) + if (isClassification) { + new ClassificationInternalNode(n.prediction, n.impurity, n.gain, + leftChild.asInstanceOf[ClassificationNode], rightChild.asInstanceOf[ClassificationNode], + n.split.getSplit, impurityStats) + } else { + new RegressionInternalNode(n.prediction, n.impurity, n.gain, + leftChild.asInstanceOf[RegressionNode], rightChild.asInstanceOf[RegressionNode], + n.split.getSplit, impurityStats) + } } else { - new LeafNode(n.prediction, n.impurity, impurityStats) + if (isClassification) { + new ClassificationLeafNode(n.prediction, n.impurity, impurityStats) + } else { + new RegressionLeafNode(n.prediction, n.impurity, impurityStats) + } } finalNodes(n.id) = node } @@ -421,7 +437,8 @@ private[ml] object EnsembleModelReadWrite { path: String, sql: SparkSession, className: String, - treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = { + treeClassName: String, + isClassification: Boolean): (Metadata, Array[(Metadata, Node)], Array[Double]) = { import sql.implicits._ implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className) @@ -449,7 +466,8 @@ private[ml] object EnsembleModelReadWrite { val rootNodesRDD: RDD[(Int, Node)] = nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map { case (treeID: Int, nodeData: Iterable[NodeData]) => - treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType) + treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes( + nodeData.toArray, impurityType, isClassification) } val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect() (metadata, treesMetadata.zip(rootNodes), treesWeights) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 2930f4900d50e..d3dbb4e754d3d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.tree.ClassificationLeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} @@ -61,7 +61,8 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new DecisionTreeClassifier) - val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2) + val model = new DecisionTreeClassificationModel("dtc", + new ClassificationLeafNode(0.0, 0.0, null), 1, 2) ParamsSuite.checkParams(model) } @@ -375,6 +376,32 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { testDefaultReadWrite(model) } + + test("label/impurity stats") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) + val rdd = sc.parallelize(arr) + val df = TreeTests.setMetadata(rdd, Map.empty[Int, Int], 2) + val dt1 = new DecisionTreeClassifier() + .setImpurity("entropy") + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val model1 = dt1.fit(df) + + val rootNode1 = model1.rootNode + assert(Array(rootNode1.getLabelCount(0), rootNode1.getLabelCount(1)) === Array(2.0, 1.0)) + + val dt2 = new DecisionTreeClassifier() + .setImpurity("gini") + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val model2 = dt2.fit(df) + + val rootNode2 = model2.rootNode + assert(Array(rootNode2.getLabelCount(0), rootNode2.getLabelCount(1)) === Array(2.0, 1.0)) + } } private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 57796069f6052..f0ee5496f9d1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.tree.RegressionLeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ @@ -69,7 +69,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new GBTClassifier) val model = new GBTClassificationModel("gbtc", - Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null), 1)), + Array(new DecisionTreeRegressionModel("dtr", new RegressionLeafNode(0.0, 0.0, null), 1)), Array(1.0), 1, 2) ParamsSuite.checkParams(model) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index ba4a9cf082785..3062aa9f3d274 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.tree.ClassificationLeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} @@ -71,7 +71,8 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new RandomForestClassifier) val model = new RandomForestClassificationModel("rfc", - Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)), 2, 2) + Array(new DecisionTreeClassificationModel("dtc", + new ClassificationLeafNode(0.0, 0.0, null), 1, 2)), 2, 2) ParamsSuite.checkParams(model) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 29a438396516b..9ae27339b11d5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -191,6 +191,20 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest { TreeTests.allParamSettings ++ Map("maxDepth" -> 0), TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData) } + + test("label/impurity stats") { + val categoricalFeatures = Map(0 -> 2, 1 -> 2) + val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) + val dtr = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(8) + val model = dtr.fit(df) + val statInfo = model.rootNode + + assert(statInfo.getCount == 1000.0 && statInfo.getSum == 600.0 + && statInfo.getSumOfSquares == 600.0) + } } private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 743dacf146fe7..4dbbd75d2466d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -340,8 +340,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.stats.impurity > 0.0) // set impurity and predict for child nodes - assert(topNode.leftChild.get.toNode.prediction === 0.0) - assert(topNode.rightChild.get.toNode.prediction === 1.0) + assert(topNode.leftChild.get.toNode(isClassification = true).prediction === 0.0) + assert(topNode.rightChild.get.toNode(isClassification = true).prediction === 1.0) assert(topNode.leftChild.get.stats.impurity === 0.0) assert(topNode.rightChild.get.stats.impurity === 0.0) } @@ -382,8 +382,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.stats.impurity > 0.0) // set impurity and predict for child nodes - assert(topNode.leftChild.get.toNode.prediction === 0.0) - assert(topNode.rightChild.get.toNode.prediction === 1.0) + assert(topNode.leftChild.get.toNode(isClassification = true).prediction === 0.0) + assert(topNode.rightChild.get.toNode(isClassification = true).prediction === 1.0) assert(topNode.leftChild.get.stats.impurity === 0.0) assert(topNode.rightChild.get.stats.impurity === 0.0) } @@ -582,18 +582,18 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { left right */ val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0)) - val left = new LeafNode(0.0, leftImp.calculate(), leftImp) + val left = new ClassificationLeafNode(0.0, leftImp.calculate(), leftImp) val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0)) - val right = new LeafNode(2.0, rightImp.calculate(), rightImp) + val right = new ClassificationLeafNode(2.0, rightImp.calculate(), rightImp) - val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5)) + val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5), true) val parentImp = parent.impurityStats val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0)) - val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp) + val left2 = new ClassificationLeafNode(0.0, left2Imp.calculate(), left2Imp) - val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0)) + val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0), true) val grandImp = grandParent.impurityStats // Test feature importance computed at different subtrees. @@ -618,8 +618,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // Forest consisting of (full tree) + (internal node with 2 leafs) val trees = Array(parent, grandParent).map { root => - new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3) - .asInstanceOf[DecisionTreeModel] + new DecisionTreeClassificationModel(root.asInstanceOf[ClassificationNode], + numFeatures = 2, numClasses = 3).asInstanceOf[DecisionTreeModel] } val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2) val tree2norm = feature0importance + feature1importance diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index b6894b30b0c2b..3f03d909d4a4c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -159,7 +159,7 @@ private[ml] object TreeTests extends SparkFunSuite { * @param split Split for parent node * @return Parent node with children attached */ - def buildParentNode(left: Node, right: Node, split: Split): Node = { + def buildParentNode(left: Node, right: Node, split: Split, isClassification: Boolean): Node = { val leftImp = left.impurityStats val rightImp = right.impurityStats val parentImp = leftImp.copy.add(rightImp) @@ -168,7 +168,15 @@ private[ml] object TreeTests extends SparkFunSuite { val gain = parentImp.calculate() - (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate()) val pred = parentImp.predict - new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp) + if (isClassification) { + new ClassificationInternalNode(pred, parentImp.calculate(), gain, + left.asInstanceOf[ClassificationNode], right.asInstanceOf[ClassificationNode], + split, parentImp) + } else { + new RegressionInternalNode(pred, parentImp.calculate(), gain, + left.asInstanceOf[RegressionNode], right.asInstanceOf[RegressionNode], + split, parentImp) + } } /** diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 1b6d1dec69d49..b37b4d51775e8 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -55,7 +55,14 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numRddBlocksById"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.memUsedByRdd"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.cacheSize"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel") + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel"), + + // [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.LeafNode"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.Node"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this") ) // Exclude rules for 2.3.x From 61b724724cc4a18818774ecaaa5a45b70fdb8dae Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 9 Apr 2018 14:07:33 -0700 Subject: [PATCH 0589/2461] [INFRA] Close stale PRs. Closes #20957 Closes #20792 From f94f3624ea81053653a06560808cb71f510c6828 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Mon, 9 Apr 2018 21:07:28 -0700 Subject: [PATCH 0590/2461] [SPARK-23947][SQL] Add hashUTF8String convenience method to hasher classes ## What changes were proposed in this pull request? Add `hashUTF8String()` to the hasher classes to allow Spark SQL codegen to generate cleaner code for hashing `UTF8String`s. No change in behavior otherwise. Although with the introduction of SPARK-10399, the code size for hashing `UTF8String` is already smaller, it's still good to extract a separate function in the hasher classes so that the generated code can stay clean. ## How was this patch tested? Existing tests. Author: Kris Mok Closes #21016 from rednaxelafx/hashutf8. --- .../apache/spark/sql/catalyst/expressions/HiveHasher.java | 5 +++++ .../java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java | 7 ++++++- .../org/apache/spark/sql/catalyst/expressions/XXH64.java | 5 +++++ .../org/apache/spark/sql/catalyst/expressions/hash.scala | 6 ++---- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java index c34e36903a93e..62b75ae8aa01d 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; /** * Simulates Hive's hashing function from Hive v1.2.1 @@ -51,4 +52,8 @@ public static int hashUnsafeBytesBlock(MemoryBlock mb) { public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) { return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes)); } + + public static int hashUTF8String(UTF8String str) { + return hashUnsafeBytesBlock(str.getMemoryBlock()); + } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index f372b19fac119..aff6e93d647fe 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -20,6 +20,7 @@ import com.google.common.primitives.Ints; import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; /** * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. @@ -82,6 +83,10 @@ public static int hashUnsafeBytesBlock(MemoryBlock base, int seed) { return fmix(h1, lengthInBytes); } + public static int hashUTF8String(UTF8String str, int seed) { + return hashUnsafeBytesBlock(str.getMemoryBlock(), seed); + } + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); } @@ -91,7 +96,7 @@ public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, } public static int hashUnsafeBytes2Block(MemoryBlock base, int seed) { - // This is compatible with original and another implementations. + // This is compatible with original and other implementations. // Use this method for new components after Spark 2.3. int lengthInBytes = Ints.checkedCast(base.size()); assert (lengthInBytes >= 0) : "lengthInBytes cannot be negative"; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java index fe727f6011cbf..8e9c0a2e9dc81 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; // scalastyle: off /** @@ -107,6 +108,10 @@ public static long hashUnsafeBytesBlock(MemoryBlock mb, long seed) { return fmix(hash); } + public static long hashUTF8String(UTF8String str, long seed) { + return hashUnsafeBytesBlock(str.getMemoryBlock(), seed); + } + public static long hashUnsafeBytes(Object base, long offset, int length, long seed) { return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, length), seed); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index df29c38d64d3d..ef790338bdd27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -361,8 +361,7 @@ abstract class HashExpression[E] extends Expression { } protected def genHashString(input: String, result: String): String = { - val mb = s"$input.getMemoryBlock()" - s"$result = $hasherClassName.hashUnsafeBytesBlock($mb, $result);" + s"$result = $hasherClassName.hashUTF8String($input, $result);" } protected def genHashForMap( @@ -725,8 +724,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """ override protected def genHashString(input: String, result: String): String = { - val mb = s"$input.getMemoryBlock()" - s"$result = $hasherClassName.hashUnsafeBytesBlock($mb);" + s"$result = $hasherClassName.hashUTF8String($input);" } override protected def genHashForArray( From 64988841540464e261b0cbaede43058e7bd36261 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 9 Apr 2018 21:49:49 -0700 Subject: [PATCH 0591/2461] [SPARK-23898][SQL] Simplify add & subtract code generation ## What changes were proposed in this pull request? Code generation for the `Add` and `Subtract` expressions was not done using the `BinaryArithmetic.doCodeGen` method because these expressions also support `CalendarInterval`. This leads to a bit of duplication. This PR gets rid of that duplication by adding `calendarIntervalMethod` to `BinaryArithmetic` and doing the code generation for `CalendarInterval` in `BinaryArithmetic` instead. ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #21005 from hvanhovell/SPARK-23898. --- .../sql/catalyst/expressions/arithmetic.scala | 50 ++++++++----------- 1 file changed, 20 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 478ff3a7c1011..defd6f3cd8849 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -43,7 +43,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression private lazy val numeric = TypeUtils.getNumeric(dataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") + case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => { val originValue = ctx.freshName("origin") // codegen would fail to compile if we just write (-($c)) @@ -52,7 +52,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression ${CodeGenerator.javaType(dt)} $originValue = (${CodeGenerator.javaType(dt)})($eval); ${ev.value} = (${CodeGenerator.javaType(dt)})(-($originValue)); """}) - case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") + case _: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") } protected override def nullSafeEval(input: Any): Any = { @@ -104,7 +104,7 @@ case class Abs(child: Expression) private lazy val numeric = TypeUtils.getNumeric(dataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => + case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.abs()") case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dt)})(java.lang.Math.abs($c))") @@ -117,15 +117,21 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { override def dataType: DataType = left.dataType - override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess + override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess /** Name of the function for this expression on a [[Decimal]] type. */ def decimalMethod: String = sys.error("BinaryArithmetics must override either decimalMethod or genCode") + /** Name of the function for this expression on a [[CalendarInterval]] type. */ + def calendarIntervalMethod: String = + sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode") + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => + case _: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") + case CalendarIntervalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$calendarIntervalMethod($eval2)") // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => defineCodeGen(ctx, ev, @@ -152,6 +158,10 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" + override def decimalMethod: String = "$plus" + + override def calendarIntervalMethod: String = "add" + private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -161,18 +171,6 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { numeric.plus(input1, input2) } } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)") - case ByteType | ShortType => - defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") - case CalendarIntervalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)") - case _ => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") - } } @ExpressionDescription( @@ -188,6 +186,10 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti override def symbol: String = "-" + override def decimalMethod: String = "$minus" + + override def calendarIntervalMethod: String = "subtract" + private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -197,18 +199,6 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti numeric.minus(input1, input2) } } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)") - case ByteType | ShortType => - defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") - case CalendarIntervalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)") - case _ => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") - } } @ExpressionDescription( @@ -416,7 +406,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "pmod" - protected def checkTypesInternal(t: DataType) = + protected def checkTypesInternal(t: DataType): TypeCheckResult = TypeUtils.checkForNumericExpr(t, "pmod") override def inputType: AbstractDataType = NumericType From 95034af69623bb8be5b9f5eabf50980bdeca48e6 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Tue, 10 Apr 2018 08:51:35 -0500 Subject: [PATCH 0592/2461] [SPARK-23841][ML] NodeIdCache should unpersist the last cached nodeIdsForInstances ## What changes were proposed in this pull request? unpersist the last cached nodeIdsForInstances in `deleteAllCheckpoints` ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #20956 from zhengruifeng/NodeIdCache_cleanup. --- .../scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala index a7c5f489dea86..5b14a63ada4ef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -95,7 +95,7 @@ private[spark] class NodeIdCache( splits: Array[Array[Split]]): Unit = { if (prevNodeIdsForInstances != null) { // Unpersist the previous one if one exists. - prevNodeIdsForInstances.unpersist() + prevNodeIdsForInstances.unpersist(false) } prevNodeIdsForInstances = nodeIdsForInstances @@ -166,9 +166,13 @@ private[spark] class NodeIdCache( } } } + if (nodeIdsForInstances != null) { + // Unpersist current one if one exists. + nodeIdsForInstances.unpersist(false) + } if (prevNodeIdsForInstances != null) { // Unpersist the previous one if one exists. - prevNodeIdsForInstances.unpersist() + prevNodeIdsForInstances.unpersist(false) } } } From 3323b156f9c0beb0b3c2b724a6faddc6ffdfe99a Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 10 Apr 2018 17:32:00 +0200 Subject: [PATCH 0593/2461] [SPARK-23864][SQL] Add unsafe object writing to UnsafeWriter ## What changes were proposed in this pull request? This PR moves writing of `UnsafeRow`, `UnsafeArrayData` & `UnsafeMapData` out of the `GenerateUnsafeProjection`/`InterpretedUnsafeProjection` classes into the `UnsafeWriter` interface. This cleans up the code a little bit, and it should also result in less byte code for the code generated path. ## How was this patch tested? Existing tests Author: Herman van Hovell Closes #20986 from hvanhovell/SPARK-23864. --- .../expressions/codegen/UnsafeWriter.java | 72 ++-- .../InterpretedUnsafeProjection.scala | 46 +-- .../codegen/GenerateUnsafeProjection.scala | 322 ++++++++---------- .../spark/sql/types/UserDefinedType.scala | 10 + 4 files changed, 204 insertions(+), 246 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index de0eb6dbb76be..2781655002000 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -16,6 +16,9 @@ */ package org.apache.spark.sql.catalyst.expressions.codegen; +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; +import org.apache.spark.sql.catalyst.expressions.UnsafeMapData; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; @@ -103,21 +106,7 @@ protected final void zeroOutPaddingBytes(int numBytes) { public abstract void write(int ordinal, Decimal input, int precision, int scale); public final void write(int ordinal, UTF8String input) { - final int numBytes = input.numBytes(); - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - input.writeToMemory(getBuffer(), cursor()); - - setOffsetAndSize(ordinal, numBytes); - - // move the cursor forward. - increaseCursor(roundedSize); + writeUnalignedBytes(ordinal, input.getBaseObject(), input.getBaseOffset(), input.numBytes()); } public final void write(int ordinal, byte[] input) { @@ -125,20 +114,19 @@ public final void write(int ordinal, byte[] input) { } public final void write(int ordinal, byte[] input, int offset, int numBytes) { - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length); + writeUnalignedBytes(ordinal, input, Platform.BYTE_ARRAY_OFFSET + offset, numBytes); + } - // grow the global buffer before writing data. + private void writeUnalignedBytes( + int ordinal, + Object baseObject, + long baseOffset, + int numBytes) { + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); grow(roundedSize); - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - Platform.copyMemory( - input, Platform.BYTE_ARRAY_OFFSET + offset, getBuffer(), cursor(), numBytes); - + Platform.copyMemory(baseObject, baseOffset, getBuffer(), cursor(), numBytes); setOffsetAndSize(ordinal, numBytes); - - // move the cursor forward. increaseCursor(roundedSize); } @@ -156,6 +144,40 @@ public final void write(int ordinal, CalendarInterval input) { increaseCursor(16); } + public final void write(int ordinal, UnsafeRow row) { + writeAlignedBytes(ordinal, row.getBaseObject(), row.getBaseOffset(), row.getSizeInBytes()); + } + + public final void write(int ordinal, UnsafeMapData map) { + writeAlignedBytes(ordinal, map.getBaseObject(), map.getBaseOffset(), map.getSizeInBytes()); + } + + public final void write(UnsafeArrayData array) { + // Unsafe arrays both can be written as a regular array field or as part of a map. This makes + // updating the offset and size dependent on the code path, this is why we currently do not + // provide an method for writing unsafe arrays that also updates the size and offset. + int numBytes = array.getSizeInBytes(); + grow(numBytes); + Platform.copyMemory( + array.getBaseObject(), + array.getBaseOffset(), + getBuffer(), + cursor(), + numBytes); + increaseCursor(numBytes); + } + + private void writeAlignedBytes( + int ordinal, + Object baseObject, + long baseOffset, + int numBytes) { + grow(numBytes); + Platform.copyMemory(baseObject, baseOffset, getBuffer(), cursor(), numBytes); + setOffsetAndSize(ordinal, numBytes); + increaseCursor(numBytes); + } + protected final void writeBoolean(long offset, boolean value) { Platform.putBoolean(getBuffer(), offset, value); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index b31466f5c92d1..6d69d69b1c802 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -173,21 +173,17 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { val rowWriter = new UnsafeRowWriter(writer, numFields) val structWriter = generateStructWriter(rowWriter, fields) (v, i) => { - val previousCursor = writer.cursor() v.getStruct(i, fields.length) match { case row: UnsafeRow => - writeUnsafeData( - rowWriter, - row.getBaseObject, - row.getBaseOffset, - row.getSizeInBytes) + writer.write(i, row) case row => + val previousCursor = writer.cursor() // Nested struct. We don't know where this will start because a row can be // variable length, so we need to update the offsets and zero out the bit mask. rowWriter.resetRowWriter() structWriter.apply(row) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } - writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case ArrayType(elementType, containsNull) => @@ -214,15 +210,12 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { valueType, valueContainsNull) (v, i) => { - val previousCursor = writer.cursor() v.getMap(i) match { case map: UnsafeMapData => - writeUnsafeData( - valueArrayWriter, - map.getBaseObject, - map.getBaseOffset, - map.getSizeInBytes) + writer.write(i, map) case map => + val previousCursor = writer.cursor() + // preserve 8 bytes to write the key array numBytes later. valueArrayWriter.grow(8) valueArrayWriter.increaseCursor(8) @@ -237,8 +230,8 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { // Write the values. writeArray(valueArrayWriter, valueWriter, map.valueArray()) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } - writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case udt: UserDefinedType[_] => @@ -318,11 +311,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { elementWriter: (SpecializedGetters, Int) => Unit, array: ArrayData): Unit = array match { case unsafe: UnsafeArrayData => - writeUnsafeData( - arrayWriter, - unsafe.getBaseObject, - unsafe.getBaseOffset, - unsafe.getSizeInBytes) + arrayWriter.write(unsafe) case _ => val numElements = array.numElements() arrayWriter.initialize(numElements) @@ -332,23 +321,4 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { i += 1 } } - - /** - * Write an opaque block of data to the buffer. This is used to copy - * [[UnsafeRow]], [[UnsafeArrayData]] and [[UnsafeMapData]] objects. - */ - private def writeUnsafeData( - writer: UnsafeWriter, - baseObject: AnyRef, - baseOffset: Long, - sizeInBytes: Int) : Unit = { - writer.grow(sizeInBytes) - Platform.copyMemory( - baseObject, - baseOffset, - writer.getBuffer, - writer.cursor, - sizeInBytes) - writer.increaseCursor(sizeInBytes) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 4a4d76313a543..2fb441ac4500e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -32,14 +32,13 @@ import org.apache.spark.sql.types._ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] { /** Returns true iff we support this data type. */ - def canSupport(dataType: DataType): Boolean = dataType match { + def canSupport(dataType: DataType): Boolean = UserDefinedType.sqlType(dataType) match { case NullType => true - case t: AtomicType => true + case _: AtomicType => true case _: CalendarIntervalType => true case t: StructType => t.forall(field => canSupport(field.dataType)) case t: ArrayType if canSupport(t.elementType) => true case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true - case udt: UserDefinedType[_] => canSupport(udt.sqlType) case _ => false } @@ -47,6 +46,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private def writeStructToBuffer( ctx: CodegenContext, input: String, + index: String, fieldTypes: Seq[DataType], rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. @@ -60,15 +60,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val rowWriterClass = classOf[UnsafeRowWriter].getName val structRowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});") - + val previousCursor = ctx.freshName("previousCursor") s""" - final InternalRow $tmpInput = $input; - if ($tmpInput instanceof UnsafeRow) { - ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", structRowWriter)} - } else { - ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)} - } - """ + |final InternalRow $tmpInput = $input; + |if ($tmpInput instanceof UnsafeRow) { + | $rowWriter.write($index, (UnsafeRow) $tmpInput); + |} else { + | // Remember the current cursor so that we can calculate how many bytes are + | // written later. + | final int $previousCursor = $rowWriter.cursor(); + | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)} + | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); + |} + """.stripMargin } private def writeExpressionsToBuffer( @@ -95,10 +99,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val writeFields = inputs.zip(inputTypes).zipWithIndex.map { case ((input, dataType), index) => - val dt = dataType match { - case udt: UserDefinedType[_] => udt.sqlType - case other => other - } + val dt = UserDefinedType.sqlType(dataType) val setNull = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => @@ -106,58 +107,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});" case _ => s"$rowWriter.setNullAt($index);" } - val previousCursor = ctx.freshName("previousCursor") - - val writeField = dt match { - case t: StructType => - s""" - // Remember the current cursor so that we can calculate how many bytes are - // written later. - final int $previousCursor = $rowWriter.cursor(); - ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), rowWriter)} - $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case a @ ArrayType(et, _) => - s""" - // Remember the current cursor so that we can calculate how many bytes are - // written later. - final int $previousCursor = $rowWriter.cursor(); - ${writeArrayToBuffer(ctx, input.value, et, rowWriter)} - $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case m @ MapType(kt, vt, _) => - s""" - // Remember the current cursor so that we can calculate how many bytes are - // written later. - final int $previousCursor = $rowWriter.cursor(); - ${writeMapToBuffer(ctx, input.value, kt, vt, rowWriter)} - $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case t: DecimalType => - s"$rowWriter.write($index, ${input.value}, ${t.precision}, ${t.scale});" - - case NullType => "" - - case _ => s"$rowWriter.write($index, ${input.value});" - } + val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter) if (input.isNull == "false") { s""" - ${input.code} - ${writeField.trim} - """ + |${input.code} + |${writeField.trim} + """.stripMargin } else { s""" - ${input.code} - if (${input.isNull}) { - ${setNull.trim} - } else { - ${writeField.trim} - } - """ + |${input.code} + |if (${input.isNull}) { + | ${setNull.trim} + |} else { + | ${writeField.trim} + |} + """.stripMargin } } @@ -171,11 +136,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro funcName = "writeFields", arguments = Seq("InternalRow" -> row)) } - s""" - $resetWriter - $writeFieldsCode - """.trim + |$resetWriter + |$writeFieldsCode + """.stripMargin } // TODO: if the nullability of array element is correct, we can use it to save null check. @@ -189,10 +153,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") - val et = elementType match { - case udt: UserDefinedType[_] => udt.sqlType - case other => other - } + val et = UserDefinedType.sqlType(elementType) val jt = CodeGenerator.javaType(et) @@ -205,106 +166,100 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", v => s"$v = new $arrayWriterClass($rowWriter, $elementOrOffsetSize);") - val previousCursor = ctx.freshName("previousCursor") val element = CodeGenerator.getValue(tmpInput, et, index) - val writeElement = et match { - case t: StructType => - s""" - final int $previousCursor = $arrayWriter.cursor(); - ${writeStructToBuffer(ctx, element, t.map(_.dataType), arrayWriter)} - $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case a @ ArrayType(et, _) => - s""" - final int $previousCursor = $arrayWriter.cursor(); - ${writeArrayToBuffer(ctx, element, et, arrayWriter)} - $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case m @ MapType(kt, vt, _) => - s""" - final int $previousCursor = $arrayWriter.cursor(); - ${writeMapToBuffer(ctx, element, kt, vt, arrayWriter)} - $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case t: DecimalType => - s"$arrayWriter.write($index, $element, ${t.precision}, ${t.scale});" - - case NullType => "" - - case _ => s"$arrayWriter.write($index, $element);" - } - val primitiveTypeName = - if (CodeGenerator.isPrimitiveType(jt)) CodeGenerator.primitiveTypeName(et) else "" s""" - final ArrayData $tmpInput = $input; - if ($tmpInput instanceof UnsafeArrayData) { - ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", arrayWriter)} - } else { - final int $numElements = $tmpInput.numElements(); - $arrayWriter.initialize($numElements); - - for (int $index = 0; $index < $numElements; $index++) { - if ($tmpInput.isNullAt($index)) { - $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); - } else { - $writeElement - } - } - } - """ + |final ArrayData $tmpInput = $input; + |if ($tmpInput instanceof UnsafeArrayData) { + | $rowWriter.write((UnsafeArrayData) $tmpInput); + |} else { + | final int $numElements = $tmpInput.numElements(); + | $arrayWriter.initialize($numElements); + | + | for (int $index = 0; $index < $numElements; $index++) { + | if ($tmpInput.isNullAt($index)) { + | $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); + | } else { + | ${writeElement(ctx, element, index, et, arrayWriter)} + | } + | } + |} + """.stripMargin } // TODO: if the nullability of value element is correct, we can use it to save null check. private def writeMapToBuffer( ctx: CodegenContext, input: String, + index: String, keyType: DataType, valueType: DataType, rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val tmpCursor = ctx.freshName("tmpCursor") + val previousCursor = ctx.freshName("previousCursor") // Writes out unsafe map according to the format described in `UnsafeMapData`. s""" - final MapData $tmpInput = $input; - if ($tmpInput instanceof UnsafeMapData) { - ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", rowWriter)} - } else { - // preserve 8 bytes to write the key array numBytes later. - $rowWriter.grow(8); - $rowWriter.increaseCursor(8); + |final MapData $tmpInput = $input; + |if ($tmpInput instanceof UnsafeMapData) { + | $rowWriter.write($index, (UnsafeMapData) $tmpInput); + |} else { + | // Remember the current cursor so that we can calculate how many bytes are + | // written later. + | final int $previousCursor = $rowWriter.cursor(); + | + | // preserve 8 bytes to write the key array numBytes later. + | $rowWriter.grow(8); + | $rowWriter.increaseCursor(8); + | + | // Remember the current cursor so that we can write numBytes of key array later. + | final int $tmpCursor = $rowWriter.cursor(); + | + | ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} + | + | // Write the numBytes of key array into the first 8 bytes. + | Platform.putLong( + | $rowWriter.getBuffer(), + | $tmpCursor - 8, + | $rowWriter.cursor() - $tmpCursor); + | + | ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} + | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); + |} + """.stripMargin + } - // Remember the current cursor so that we can write numBytes of key array later. - final int $tmpCursor = $rowWriter.cursor(); + private def writeElement( + ctx: CodegenContext, + input: String, + index: String, + dt: DataType, + writer: String): String = dt match { + case t: StructType => + writeStructToBuffer(ctx, input, index, t.map(_.dataType), writer) + + case ArrayType(et, _) => + val previousCursor = ctx.freshName("previousCursor") + s""" + |// Remember the current cursor so that we can calculate how many bytes are + |// written later. + |final int $previousCursor = $writer.cursor(); + |${writeArrayToBuffer(ctx, input, et, writer)} + |$writer.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); + """.stripMargin - ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} - // Write the numBytes of key array into the first 8 bytes. - Platform.putLong($rowWriter.getBuffer(), $tmpCursor - 8, $rowWriter.cursor() - $tmpCursor); + case MapType(kt, vt, _) => + writeMapToBuffer(ctx, input, index, kt, vt, writer) - ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} - } - """ - } + case DecimalType.Fixed(precision, scale) => + s"$writer.write($index, $input, $precision, $scale);" - /** - * If the input is already in unsafe format, we don't need to go through all elements/fields, - * we can directly write it. - */ - private def writeUnsafeData(ctx: CodegenContext, input: String, rowWriter: String) = { - val sizeInBytes = ctx.freshName("sizeInBytes") - s""" - final int $sizeInBytes = $input.getSizeInBytes(); - // grow the global buffer before writing data. - $rowWriter.grow($sizeInBytes); - $input.writeToMemory($rowWriter.getBuffer(), $rowWriter.cursor()); - $rowWriter.increaseCursor($sizeInBytes); - """ + case NullType => "" + + case _ => s"$writer.write($index, $input);" } def createCode( @@ -332,10 +287,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val code = s""" - $rowWriter.reset(); - $evalSubexpr - $writeExpressions - """ + |$rowWriter.reset(); + |$evalSubexpr + |$writeExpressions + """.stripMargin // `rowWriter` is declared as a class field, so we can access it directly in methods. ExprCode(code, FalseLiteral, StatementValue(s"$rowWriter.getRow()", "UnsafeRow", canDirectAccess = true)) @@ -363,38 +318,39 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val ctx = newCodeGenContext() val eval = createCode(ctx, expressions, subexpressionEliminationEnabled) - val codeBody = s""" - public java.lang.Object generate(Object[] references) { - return new SpecificUnsafeProjection(references); - } - - class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} { - - private Object[] references; - ${ctx.declareMutableStates()} - - public SpecificUnsafeProjection(Object[] references) { - this.references = references; - ${ctx.initMutableStates()} - } - - public void initialize(int partitionIndex) { - ${ctx.initPartition()} - } - - // Scala.Function1 need this - public java.lang.Object apply(java.lang.Object row) { - return apply((InternalRow) row); - } - - public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { - ${eval.code.trim} - return ${eval.value}; - } - - ${ctx.declareAddedFunctions()} - } - """ + val codeBody = + s""" + |public java.lang.Object generate(Object[] references) { + | return new SpecificUnsafeProjection(references); + |} + | + |class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} { + | + | private Object[] references; + | ${ctx.declareMutableStates()} + | + | public SpecificUnsafeProjection(Object[] references) { + | this.references = references; + | ${ctx.initMutableStates()} + | } + | + | public void initialize(int partitionIndex) { + | ${ctx.initPartition()} + | } + | + | // Scala.Function1 need this + | public java.lang.Object apply(java.lang.Object row) { + | return apply((InternalRow) row); + | } + | + | public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { + | ${eval.code.trim} + | return ${eval.value}; + | } + | + | ${ctx.declareAddedFunctions()} + |} + """.stripMargin val code = CodeFormatter.stripOverlappingComments( new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 5a944e763e099..6af16e2dba105 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -97,6 +97,16 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa override def catalogString: String = sqlType.simpleString } +private[spark] object UserDefinedType { + /** + * Get the sqlType of a (potential) [[UserDefinedType]]. + */ + def sqlType(dt: DataType): DataType = dt match { + case udt: UserDefinedType[_] => udt.sqlType + case _ => dt + } +} + /** * The user defined type in Python. * From e179658914963de472120a81621396706584c949 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 10 Apr 2018 09:33:09 -0700 Subject: [PATCH 0594/2461] [SPARK-19724][SQL][FOLLOW-UP] Check location of managed table when ignoreIfExists is true ## What changes were proposed in this pull request? In the PR #20886, I mistakenly check the table location only when `ignoreIfExists` is false, which was following the original deprecated PR. That was wrong. When `ignoreIfExists` is true and the target table doesn't exist, we should also check the table location. In other word, **`ignoreIfExists` has nothing to do with table location validation**. This is a follow-up PR to fix the mistake. ## How was this patch tested? Add one unit test. Author: Gengliang Wang Closes #21001 from gengliangwang/SPARK-19724-followup. --- .../spark/sql/catalyst/catalog/SessionCatalog.scala | 11 +++++++++-- .../execution/command/createDataSourceTables.scala | 2 +- .../apache/spark/sql/execution/command/DDLSuite.scala | 9 +++++++++ .../spark/sql/hive/execution/HiveDDLSuite.scala | 2 +- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 52ed89ef8d781..c390337c03ff5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -286,7 +286,10 @@ class SessionCatalog( * Create a metastore table in the database specified in `tableDefinition`. * If no such database is specified, create it in the current database. */ - def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { + def createTable( + tableDefinition: CatalogTable, + ignoreIfExists: Boolean, + validateLocation: Boolean = true): Unit = { val db = formatDatabaseName(tableDefinition.identifier.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableDefinition.identifier.table) val tableIdentifier = TableIdentifier(table, Some(db)) @@ -305,7 +308,11 @@ class SessionCatalog( } requireDbExists(db) - if (!ignoreIfExists) { + if (tableExists(newTableDefinition.identifier)) { + if (!ignoreIfExists) { + throw new TableAlreadyExistsException(db = db, table = table) + } + } else if (validateLocation) { validateTableLocation(newTableDefinition) } externalCatalog.createTable(newTableDefinition, ignoreIfExists) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index f7c3e9b019258..f6ef433f2ce15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -182,7 +182,7 @@ case class CreateDataSourceTableAsSelectCommand( // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). schema = result.schema) // Table location is already validated. No need to check it again during table creation. - sessionState.catalog.createTable(newTable, ignoreIfExists = true) + sessionState.catalog.createTable(newTable, ignoreIfExists = false, validateLocation = false) result match { case fs: HadoopFsRelation if table.partitionColumnNames.nonEmpty && diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 4304d0b6f6b16..cbd7f9d6f67be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -425,6 +425,15 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql(s"CREATE TABLE tab1 (col1 int, col2 string) USING ${dataSource}") }.getMessage assert(ex.contains(exMsgWithDefaultDB)) + + // Always check location of managed table, with or without (IF NOT EXISTS) + withTable("tab2") { + sql(s"CREATE TABLE tab2 (col1 int, col2 string) USING ${dataSource}") + ex = intercept[AnalysisException] { + sql(s"CREATE TABLE IF NOT EXISTS tab1 LIKE tab2") + }.getMessage + assert(ex.contains(exMsgWithDefaultDB)) + } } finally { waitForTasksToFinish() Utils.deleteRecursively(tableLoc) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index db76ec9d084cb..c85db78c732de 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1461,7 +1461,7 @@ class HiveDDLSuite assert(e2.getMessage.contains(forbiddenPrefix + "foo")) val e3 = intercept[AnalysisException] { - sql(s"CREATE TABLE tbl (a INT) TBLPROPERTIES ('${forbiddenPrefix}foo'='anything')") + sql(s"CREATE TABLE tbl2 (a INT) TBLPROPERTIES ('${forbiddenPrefix}foo'='anything')") } assert(e3.getMessage.contains(forbiddenPrefix + "foo")) } From adb222b957f327a69929b8f16fa5ebc071fa99e3 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 10 Apr 2018 11:18:14 -0700 Subject: [PATCH 0595/2461] [SPARK-23751][ML][PYSPARK] Kolmogorov-Smirnoff test Python API in pyspark.ml ## What changes were proposed in this pull request? Kolmogorov-Smirnoff test Python API in `pyspark.ml` **Note** API with `CDF` is a little difficult to support in python. We can add it in following PR. ## How was this patch tested? doctest Author: WeichenXu Closes #20904 from WeichenXu123/ks-test-py. --- .../spark/ml/stat/KolmogorovSmirnovTest.scala | 29 +-- python/pyspark/ml/stat.py | 181 ++++++++++++------ 2 files changed, 138 insertions(+), 72 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala index c62d7463288f7..af8ff64d33ffe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala @@ -24,7 +24,7 @@ import org.apache.spark.api.java.function.Function import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.stat.{Statistics => OldStatistics} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.col /** @@ -59,7 +59,7 @@ object KolmogorovSmirnovTest { * distribution of the sample data and the theoretical distribution we can provide a test for the * the null hypothesis that the sample data comes from that theoretical distribution. * - * @param dataset a `DataFrame` containing the sample of data to test + * @param dataset A `Dataset` or a `DataFrame` containing the sample of data to test * @param sampleCol Name of sample column in dataset, of any numerical type * @param cdf a `Double => Double` function to calculate the theoretical CDF at a given value * @return DataFrame containing the test result for the input sampled data. @@ -68,10 +68,10 @@ object KolmogorovSmirnovTest { * - `statistic: Double` */ @Since("2.4.0") - def test(dataset: DataFrame, sampleCol: String, cdf: Double => Double): DataFrame = { + def test(dataset: Dataset[_], sampleCol: String, cdf: Double => Double): DataFrame = { val spark = dataset.sparkSession - val rdd = getSampleRDD(dataset, sampleCol) + val rdd = getSampleRDD(dataset.toDF(), sampleCol) val testResult = OldStatistics.kolmogorovSmirnovTest(rdd, cdf) spark.createDataFrame(Seq(KolmogorovSmirnovTestResult( testResult.pValue, testResult.statistic))) @@ -81,10 +81,11 @@ object KolmogorovSmirnovTest { * Java-friendly version of `test(dataset: DataFrame, sampleCol: String, cdf: Double => Double)` */ @Since("2.4.0") - def test(dataset: DataFrame, sampleCol: String, - cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = { - val f: Double => Double = x => cdf.call(x) - test(dataset, sampleCol, f) + def test( + dataset: Dataset[_], + sampleCol: String, + cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = { + test(dataset, sampleCol, (x: Double) => cdf.call(x)) } /** @@ -92,10 +93,11 @@ object KolmogorovSmirnovTest { * distribution equality. Currently supports the normal distribution, taking as parameters * the mean and standard deviation. * - * @param dataset a `DataFrame` containing the sample of data to test + * @param dataset A `Dataset` or a `DataFrame` containing the sample of data to test * @param sampleCol Name of sample column in dataset, of any numerical type * @param distName a `String` name for a theoretical distribution, currently only support "norm". - * @param params `Double*` specifying the parameters to be used for the theoretical distribution + * @param params `Double*` specifying the parameters to be used for the theoretical distribution. + * For "norm" distribution, the parameters includes mean and variance. * @return DataFrame containing the test result for the input sampled data. * This DataFrame will contain a single Row with the following fields: * - `pValue: Double` @@ -103,10 +105,13 @@ object KolmogorovSmirnovTest { */ @Since("2.4.0") @varargs - def test(dataset: DataFrame, sampleCol: String, distName: String, params: Double*): DataFrame = { + def test( + dataset: Dataset[_], + sampleCol: String, distName: String, + params: Double*): DataFrame = { val spark = dataset.sparkSession - val rdd = getSampleRDD(dataset, sampleCol) + val rdd = getSampleRDD(dataset.toDF(), sampleCol) val testResult = OldStatistics.kolmogorovSmirnovTest(rdd, distName, params: _*) spark.createDataFrame(Seq(KolmogorovSmirnovTestResult( testResult.pValue, testResult.statistic))) diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index 0eeb5e528434a..93d0f4fd9148f 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -32,32 +32,6 @@ class ChiSquareTest(object): The null hypothesis is that the occurrence of the outcomes is statistically independent. - :param dataset: - DataFrame of categorical labels and categorical features. - Real-valued features will be treated as categorical for each distinct value. - :param featuresCol: - Name of features column in dataset, of type `Vector` (`VectorUDT`). - :param labelCol: - Name of label column in dataset, of any numerical type. - :return: - DataFrame containing the test result for every feature against the label. - This DataFrame will contain a single Row with the following fields: - - `pValues: Vector` - - `degreesOfFreedom: Array[Int]` - - `statistics: Vector` - Each of these fields has one value per feature. - - >>> from pyspark.ml.linalg import Vectors - >>> from pyspark.ml.stat import ChiSquareTest - >>> dataset = [[0, Vectors.dense([0, 0, 1])], - ... [0, Vectors.dense([1, 0, 1])], - ... [1, Vectors.dense([2, 1, 1])], - ... [1, Vectors.dense([3, 1, 1])]] - >>> dataset = spark.createDataFrame(dataset, ["label", "features"]) - >>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label') - >>> chiSqResult.select("degreesOfFreedom").collect()[0] - Row(degreesOfFreedom=[3, 1, 0]) - .. versionadded:: 2.2.0 """ @@ -66,6 +40,32 @@ class ChiSquareTest(object): def test(dataset, featuresCol, labelCol): """ Perform a Pearson's independence test using dataset. + + :param dataset: + DataFrame of categorical labels and categorical features. + Real-valued features will be treated as categorical for each distinct value. + :param featuresCol: + Name of features column in dataset, of type `Vector` (`VectorUDT`). + :param labelCol: + Name of label column in dataset, of any numerical type. + :return: + DataFrame containing the test result for every feature against the label. + This DataFrame will contain a single Row with the following fields: + - `pValues: Vector` + - `degreesOfFreedom: Array[Int]` + - `statistics: Vector` + Each of these fields has one value per feature. + + >>> from pyspark.ml.linalg import Vectors + >>> from pyspark.ml.stat import ChiSquareTest + >>> dataset = [[0, Vectors.dense([0, 0, 1])], + ... [0, Vectors.dense([1, 0, 1])], + ... [1, Vectors.dense([2, 1, 1])], + ... [1, Vectors.dense([3, 1, 1])]] + >>> dataset = spark.createDataFrame(dataset, ["label", "features"]) + >>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label') + >>> chiSqResult.select("degreesOfFreedom").collect()[0] + Row(degreesOfFreedom=[3, 1, 0]) """ sc = SparkContext._active_spark_context javaTestObj = _jvm().org.apache.spark.ml.stat.ChiSquareTest @@ -85,40 +85,6 @@ class Correlation(object): which is fairly costly. Cache the input Dataset before calling corr with `method = 'spearman'` to avoid recomputing the common lineage. - :param dataset: - A dataset or a dataframe. - :param column: - The name of the column of vectors for which the correlation coefficient needs - to be computed. This must be a column of the dataset, and it must contain - Vector objects. - :param method: - String specifying the method to use for computing correlation. - Supported: `pearson` (default), `spearman`. - :return: - A dataframe that contains the correlation matrix of the column of vectors. This - dataframe contains a single row and a single column of name - '$METHODNAME($COLUMN)'. - - >>> from pyspark.ml.linalg import Vectors - >>> from pyspark.ml.stat import Correlation - >>> dataset = [[Vectors.dense([1, 0, 0, -2])], - ... [Vectors.dense([4, 5, 0, 3])], - ... [Vectors.dense([6, 7, 0, 8])], - ... [Vectors.dense([9, 0, 0, 1])]] - >>> dataset = spark.createDataFrame(dataset, ['features']) - >>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0] - >>> print(str(pearsonCorr).replace('nan', 'NaN')) - DenseMatrix([[ 1. , 0.0556..., NaN, 0.4004...], - [ 0.0556..., 1. , NaN, 0.9135...], - [ NaN, NaN, 1. , NaN], - [ 0.4004..., 0.9135..., NaN, 1. ]]) - >>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0] - >>> print(str(spearmanCorr).replace('nan', 'NaN')) - DenseMatrix([[ 1. , 0.1054..., NaN, 0.4 ], - [ 0.1054..., 1. , NaN, 0.9486... ], - [ NaN, NaN, 1. , NaN], - [ 0.4 , 0.9486... , NaN, 1. ]]) - .. versionadded:: 2.2.0 """ @@ -127,6 +93,40 @@ class Correlation(object): def corr(dataset, column, method="pearson"): """ Compute the correlation matrix with specified method using dataset. + + :param dataset: + A Dataset or a DataFrame. + :param column: + The name of the column of vectors for which the correlation coefficient needs + to be computed. This must be a column of the dataset, and it must contain + Vector objects. + :param method: + String specifying the method to use for computing correlation. + Supported: `pearson` (default), `spearman`. + :return: + A DataFrame that contains the correlation matrix of the column of vectors. This + DataFrame contains a single row and a single column of name + '$METHODNAME($COLUMN)'. + + >>> from pyspark.ml.linalg import Vectors + >>> from pyspark.ml.stat import Correlation + >>> dataset = [[Vectors.dense([1, 0, 0, -2])], + ... [Vectors.dense([4, 5, 0, 3])], + ... [Vectors.dense([6, 7, 0, 8])], + ... [Vectors.dense([9, 0, 0, 1])]] + >>> dataset = spark.createDataFrame(dataset, ['features']) + >>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0] + >>> print(str(pearsonCorr).replace('nan', 'NaN')) + DenseMatrix([[ 1. , 0.0556..., NaN, 0.4004...], + [ 0.0556..., 1. , NaN, 0.9135...], + [ NaN, NaN, 1. , NaN], + [ 0.4004..., 0.9135..., NaN, 1. ]]) + >>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0] + >>> print(str(spearmanCorr).replace('nan', 'NaN')) + DenseMatrix([[ 1. , 0.1054..., NaN, 0.4 ], + [ 0.1054..., 1. , NaN, 0.9486... ], + [ NaN, NaN, 1. , NaN], + [ 0.4 , 0.9486... , NaN, 1. ]]) """ sc = SparkContext._active_spark_context javaCorrObj = _jvm().org.apache.spark.ml.stat.Correlation @@ -134,6 +134,67 @@ def corr(dataset, column, method="pearson"): return _java2py(sc, javaCorrObj.corr(*args)) +class KolmogorovSmirnovTest(object): + """ + .. note:: Experimental + + Conduct the two-sided Kolmogorov Smirnov (KS) test for data sampled from a continuous + distribution. + + By comparing the largest difference between the empirical cumulative + distribution of the sample data and the theoretical distribution we can provide a test for the + the null hypothesis that the sample data comes from that theoretical distribution. + + .. versionadded:: 2.4.0 + + """ + @staticmethod + @since("2.4.0") + def test(dataset, sampleCol, distName, *params): + """ + Conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability distribution + equality. Currently supports the normal distribution, taking as parameters the mean and + standard deviation. + + :param dataset: + a Dataset or a DataFrame containing the sample of data to test. + :param sampleCol: + Name of sample column in dataset, of any numerical type. + :param distName: + a `string` name for a theoretical distribution, currently only support "norm". + :param params: + a list of `Double` values specifying the parameters to be used for the theoretical + distribution. For "norm" distribution, the parameters includes mean and variance. + :return: + A DataFrame that contains the Kolmogorov-Smirnov test result for the input sampled data. + This DataFrame will contain a single Row with the following fields: + - `pValue: Double` + - `statistic: Double` + + >>> from pyspark.ml.stat import KolmogorovSmirnovTest + >>> dataset = [[-1.0], [0.0], [1.0]] + >>> dataset = spark.createDataFrame(dataset, ['sample']) + >>> ksResult = KolmogorovSmirnovTest.test(dataset, 'sample', 'norm', 0.0, 1.0).first() + >>> round(ksResult.pValue, 3) + 1.0 + >>> round(ksResult.statistic, 3) + 0.175 + >>> dataset = [[2.0], [3.0], [4.0]] + >>> dataset = spark.createDataFrame(dataset, ['sample']) + >>> ksResult = KolmogorovSmirnovTest.test(dataset, 'sample', 'norm', 3.0, 1.0).first() + >>> round(ksResult.pValue, 3) + 1.0 + >>> round(ksResult.statistic, 3) + 0.175 + """ + sc = SparkContext._active_spark_context + javaTestObj = _jvm().org.apache.spark.ml.stat.KolmogorovSmirnovTest + dataset = _py2java(sc, dataset) + params = [float(param) for param in params] + return _java2py(sc, javaTestObj.test(dataset, sampleCol, distName, + _jvm().PythonUtils.toSeq(params))) + + if __name__ == "__main__": import doctest import pyspark.ml.stat From 4f1e8b9bb7d795d4ca3d5cd5dcc0f9419e52dfae Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 10 Apr 2018 15:41:45 -0700 Subject: [PATCH 0596/2461] [SPARK-23871][ML][PYTHON] add python api for VectorAssembler handleInvalid ## What changes were proposed in this pull request? add python api for VectorAssembler handleInvalid ## How was this patch tested? Add doctest Author: Huaxin Gao Closes #21003 from huaxingao/spark-23871. --- .../spark/ml/feature/VectorAssembler.scala | 12 +++--- python/pyspark/ml/feature.py | 42 ++++++++++++++++--- 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 6bf4aa38b1fcb..4061154b39c14 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -71,12 +71,12 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) */ @Since("2.4.0") override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", - """Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with - |invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the - |output). Column lengths are taken from the size of ML Attribute Group, which can be set using - |`VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred - |from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'. - |""".stripMargin.replaceAll("\n", " "), + """Param for how to handle invalid data (NULL and NaN values). Options are 'skip' (filter out + |rows with invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN + |in the output). Column lengths are taken from the size of ML Attribute Group, which can be + |set using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also + |be inferred from first rows of the data since it is safe to do so but only in case of 'error' + |or 'skip'.""".stripMargin.replaceAll("\n", " "), ParamValidators.inArray(VectorAssembler.supportedHandleInvalids)) setDefault(handleInvalid, VectorAssembler.ERROR_INVALID) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 5a3e0dd655150..cdda30cfab482 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2701,7 +2701,8 @@ def setParams(self, inputCol=None, outputCol=None): @inherit_doc -class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadable, JavaMLWritable): +class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, HasHandleInvalid, JavaMLReadable, + JavaMLWritable): """ A feature transformer that merges multiple columns into a vector column. @@ -2719,25 +2720,56 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadabl >>> loadedAssembler = VectorAssembler.load(vectorAssemblerPath) >>> loadedAssembler.transform(df).head().freqs == vecAssembler.transform(df).head().freqs True + >>> dfWithNullsAndNaNs = spark.createDataFrame( + ... [(1.0, 2.0, None), (3.0, float("nan"), 4.0), (5.0, 6.0, 7.0)], ["a", "b", "c"]) + >>> vecAssembler2 = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features", + ... handleInvalid="keep") + >>> vecAssembler2.transform(dfWithNullsAndNaNs).show() + +---+---+----+-------------+ + | a| b| c| features| + +---+---+----+-------------+ + |1.0|2.0|null|[1.0,2.0,NaN]| + |3.0|NaN| 4.0|[3.0,NaN,4.0]| + |5.0|6.0| 7.0|[5.0,6.0,7.0]| + +---+---+----+-------------+ + ... + >>> vecAssembler2.setParams(handleInvalid="skip").transform(dfWithNullsAndNaNs).show() + +---+---+---+-------------+ + | a| b| c| features| + +---+---+---+-------------+ + |5.0|6.0|7.0|[5.0,6.0,7.0]| + +---+---+---+-------------+ + ... .. versionadded:: 1.4.0 """ + handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data (NULL " + + "and NaN values). Options are 'skip' (filter out rows with invalid " + + "data), 'error' (throw an error), or 'keep' (return relevant number " + + "of NaN in the output). Column lengths are taken from the size of ML " + + "Attribute Group, which can be set using `VectorSizeHint` in a " + + "pipeline before `VectorAssembler`. Column lengths can also be " + + "inferred from first rows of the data since it is safe to do so but " + + "only in case of 'error' or 'skip').", + typeConverter=TypeConverters.toString) + @keyword_only - def __init__(self, inputCols=None, outputCol=None): + def __init__(self, inputCols=None, outputCol=None, handleInvalid="error"): """ - __init__(self, inputCols=None, outputCol=None) + __init__(self, inputCols=None, outputCol=None, handleInvalid="error") """ super(VectorAssembler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid) + self._setDefault(handleInvalid="error") kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.4.0") - def setParams(self, inputCols=None, outputCol=None): + def setParams(self, inputCols=None, outputCol=None, handleInvalid="error"): """ - setParams(self, inputCols=None, outputCol=None) + setParams(self, inputCols=None, outputCol=None, handleInvalid="error") Sets params for this VectorAssembler. """ kwargs = self._input_kwargs From 7c7570d466a8ded51e580eb6a28583bd9a9c5337 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Tue, 10 Apr 2018 17:26:06 -0700 Subject: [PATCH 0597/2461] [SPARK-23944][ML] Add the set method for the two LSHModel ## What changes were proposed in this pull request? Add two set method for LSHModel in LSH.scala, BucketedRandomProjectionLSH.scala, and MinHashLSH.scala ## How was this patch tested? New test for the param setup was added into - BucketedRandomProjectionLSHSuite.scala - MinHashLSHSuite.scala Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21015 from ludatabricks/SPARK-23944. --- .../spark/ml/feature/BucketedRandomProjectionLSH.scala | 8 ++++++++ .../src/main/scala/org/apache/spark/ml/feature/LSH.scala | 6 ++++++ .../scala/org/apache/spark/ml/feature/MinHashLSH.scala | 8 ++++++++ .../ml/feature/BucketedRandomProjectionLSHSuite.scala | 8 ++++++++ .../org/apache/spark/ml/feature/MinHashLSHSuite.scala | 8 ++++++++ 5 files changed, 38 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala index 36a46ca6ff4b7..41eaaf9679914 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala @@ -73,6 +73,14 @@ class BucketedRandomProjectionLSHModel private[ml]( private[ml] val randUnitVectors: Array[Vector]) extends LSHModel[BucketedRandomProjectionLSHModel] with BucketedRandomProjectionLSHParams { + /** @group setParam */ + @Since("2.4.0") + override def setInputCol(value: String): this.type = super.set(inputCol, value) + + /** @group setParam */ + @Since("2.4.0") + override def setOutputCol(value: String): this.type = super.set(outputCol, value) + @Since("2.1.0") override protected[ml] val hashFunction: Vector => Array[Vector] = { key: Vector => { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala index 1c9f47a0b201d..a70931f783f45 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala @@ -65,6 +65,12 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] extends Model[T] with LSHParams with MLWritable { self: T => + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + /** * The hash function of LSH, mapping an input feature vector to multiple hash vectors. * @return The mapping of LSH function. diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index 145422a059196..556848e45532d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -51,6 +51,14 @@ class MinHashLSHModel private[ml]( private[ml] val randCoefficients: Array[(Int, Int)]) extends LSHModel[MinHashLSHModel] { + /** @group setParam */ + @Since("2.4.0") + override def setInputCol(value: String): this.type = super.set(inputCol, value) + + /** @group setParam */ + @Since("2.4.0") + override def setOutputCol(value: String): this.type = super.set(outputCol, value) + @Since("2.1.0") override protected[ml] val hashFunction: Vector => Array[Vector] = { elems: Vector => { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala index ed9a39d8d1512..9b823259b1deb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala @@ -48,6 +48,14 @@ class BucketedRandomProjectionLSHSuite extends MLTest with DefaultReadWriteTest ParamsSuite.checkParams(model) } + test("setters") { + val model = new BucketedRandomProjectionLSHModel("brp", Array(Vectors.dense(0.0, 1.0))) + .setInputCol("testkeys") + .setOutputCol("testvalues") + assert(model.getInputCol === "testkeys") + assert(model.getOutputCol === "testvalues") + } + test("BucketedRandomProjectionLSH: default params") { val brp = new BucketedRandomProjectionLSH assert(brp.getNumHashTables === 1.0) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala index 96df68dbdf053..3da0fb7da01ae 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala @@ -43,6 +43,14 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa ParamsSuite.checkParams(model) } + test("setters") { + val model = new MinHashLSHModel("mh", randCoefficients = Array((1, 0))) + .setInputCol("testkeys") + .setOutputCol("testvalues") + assert(model.getInputCol === "testkeys") + assert(model.getOutputCol === "testvalues") + } + test("MinHashLSH: default params") { val rp = new MinHashLSH assert(rp.getNumHashTables === 1.0) From c7622befdadfea725797d76e820e3dfc76fec927 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 11 Apr 2018 19:42:09 +0800 Subject: [PATCH 0598/2461] [SPARK-23847][FOLLOWUP][PYTHON][SQL] Actually test [desc|acs]_nulls_[first|last] functions in PySpark ## What changes were proposed in this pull request? There was a mistake in `tests.py` missing `assertEquals`. ## How was this patch tested? Fixed tests. Author: hyukjinkwon Closes #21035 from HyukjinKwon/SPARK-23847. --- python/pyspark/sql/tests.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index dd04ffb4ed393..96c2a776a5049 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2991,19 +2991,23 @@ def test_create_dateframe_from_pandas_with_dst(self): os.environ['TZ'] = orig_env_tz time.tzset() - def test_2_4_functions(self): + def test_sort_with_nulls_order(self): from pyspark.sql import functions df = self.spark.createDataFrame( [('Tom', 80), (None, 60), ('Alice', 50)], ["name", "height"]) - df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect() - [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')] - df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect() - [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)] - df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect() - [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')] - df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect() - [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)] + self.assertEquals( + df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect(), + [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')]) + self.assertEquals( + df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect(), + [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)]) + self.assertEquals( + df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect(), + [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')]) + self.assertEquals( + df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(), + [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)]) class HiveSparkSubmitTests(SparkSubmitTests): From 87611bba222a95158fc5b638a566bdf47346da8e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 11 Apr 2018 19:44:01 +0800 Subject: [PATCH 0599/2461] [MINOR][DOCS] Fix R documentation generation instruction for roxygen2 ## What changes were proposed in this pull request? This PR proposes to fix `roxygen2` to `5.0.1` in `docs/README.md` for SparkR documentation generation. If I use higher version and creates the doc, it shows the diff below. Not a big deal but it bothered me. ```diff diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 855eb5bf77f..159fca61e06 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION -57,6 +57,6 Collate: 'types.R' 'utils.R' 'window.R' -RoxygenNote: 5.0.1 +RoxygenNote: 6.0.1 VignetteBuilder: knitr NeedsCompilation: no ``` ## How was this patch tested? Manually tested. I met this every time I set the new environment for Spark dev but I have kept forgetting to fix it. Author: hyukjinkwon Closes #21020 from HyukjinKwon/minor-r-doc. --- docs/README.md | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/docs/README.md b/docs/README.md index 9eac4ba35c458..dbea4d64c4298 100644 --- a/docs/README.md +++ b/docs/README.md @@ -22,10 +22,13 @@ $ sudo gem install jekyll jekyll-redirect-from pygments.rb $ sudo pip install Pygments # Following is needed only for generating API docs $ sudo pip install sphinx pypandoc mkdocs -$ sudo Rscript -e 'install.packages(c("knitr", "devtools", "roxygen2", "testthat", "rmarkdown"), repos="http://cran.stat.ucla.edu/")' +$ sudo Rscript -e 'install.packages(c("knitr", "devtools", "testthat", "rmarkdown"), repos="http://cran.stat.ucla.edu/")' +$ sudo Rscript -e 'devtools::install_version("roxygen2", version = "5.0.1", repos="http://cran.stat.ucla.edu/")' ``` -(Note: If you are on a system with both Ruby 1.9 and Ruby 2.0 you may need to replace gem with gem2.0) +Note: If you are on a system with both Ruby 1.9 and Ruby 2.0 you may need to replace gem with gem2.0. + +Note: Other versions of roxygen2 might work in SparkR documentation generation but `RoxygenNote` field in `$SPARK_HOME/R/pkg/DESCRIPTION` is 5.0.1, which is updated if the version is mismatched. ## Generating the Documentation HTML @@ -62,12 +65,12 @@ $ PRODUCTION=1 jekyll build ## API Docs (Scaladoc, Javadoc, Sphinx, roxygen2, MkDocs) -You can build just the Spark scaladoc and javadoc by running `build/sbt unidoc` from the `SPARK_HOME` directory. +You can build just the Spark scaladoc and javadoc by running `build/sbt unidoc` from the `$SPARK_HOME` directory. Similarly, you can build just the PySpark docs by running `make html` from the -`SPARK_HOME/python/docs` directory. Documentation is only generated for classes that are listed as -public in `__init__.py`. The SparkR docs can be built by running `SPARK_HOME/R/create-docs.sh`, and -the SQL docs can be built by running `SPARK_HOME/sql/create-docs.sh` +`$SPARK_HOME/python/docs` directory. Documentation is only generated for classes that are listed as +public in `__init__.py`. The SparkR docs can be built by running `$SPARK_HOME/R/create-docs.sh`, and +the SQL docs can be built by running `$SPARK_HOME/sql/create-docs.sh` after [building Spark](https://github.com/apache/spark#building-spark) first. When you run `jekyll build` in the `docs` directory, it will also copy over the scaladoc and javadoc for the various From c604d659e19c1b2704cdf8c8ea97edaf50d8cb6b Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 11 Apr 2018 20:11:03 +0800 Subject: [PATCH 0600/2461] [SPARK-23951][SQL] Use actual java class instead of string representation. ## What changes were proposed in this pull request? This PR slightly refactors the newly added `ExprValue` API by quite a bit. The following changes are introduced: 1. `ExprValue` now uses the actual class instead of the class name as its type. This should give some more flexibility with generating code in the future. 2. Renamed `StatementValue` to `SimpleExprValue`. The statement concept is broader then an expression (untyped and it cannot be on the right hand side of an assignment), and this was not really what we were using it for. I have added a top level `JavaCode` trait that can be used in the future to reinstate (no pun intended) a statement a-like code fragment. 3. Added factory methods to the `JavaCode` companion object to make it slightly less verbose to create `JavaCode`/`ExprValue` objects. This is also what makes the diff quite large. 4. Added one more factory method to `ExprCode` to make it easier to create code-less expressions. ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #21026 from hvanhovell/SPARK-23951. --- .../sql/catalyst/expressions/Expression.scala | 10 +- .../sql/catalyst/expressions/arithmetic.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 35 +++- .../expressions/codegen/ExprValue.scala | 76 -------- .../codegen/GenerateMutableProjection.scala | 36 ++-- .../codegen/GenerateSafeProjection.scala | 25 +-- .../codegen/GenerateUnsafeProjection.scala | 11 +- .../expressions/codegen/javaCode.scala | 166 ++++++++++++++++++ .../expressions/complexTypeCreator.scala | 2 +- .../expressions/conditionalExpressions.scala | 5 +- .../expressions/datetimeExpressions.scala | 3 +- .../sql/catalyst/expressions/literals.scala | 27 +-- .../spark/sql/catalyst/expressions/misc.scala | 2 +- .../expressions/nullExpressions.scala | 20 +-- .../expressions/objects/objects.scala | 7 +- .../expressions/CodeGenerationSuite.scala | 5 +- .../expressions/codegen/ExprValueSuite.scala | 14 +- .../sql/execution/ColumnarBatchScan.scala | 6 +- .../spark/sql/execution/ExpandExec.scala | 8 +- .../spark/sql/execution/GenerateExec.scala | 15 +- .../sql/execution/WholeStageCodegenExec.scala | 11 +- .../aggregate/HashAggregateExec.scala | 6 +- .../aggregate/HashMapGenerator.scala | 8 +- .../execution/basicPhysicalOperators.scala | 2 +- .../joins/BroadcastHashJoinExec.scala | 9 +- .../execution/joins/SortMergeJoinExec.scala | 12 +- 26 files changed, 315 insertions(+), 212 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 7a5e49cb5206b..97dff6ae88299 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -104,9 +104,9 @@ abstract class Expression extends TreeNode[Expression] { }.getOrElse { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val eval = doGenCode(ctx, ExprCode("", - VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), - VariableValue(value, CodeGenerator.javaType(dataType)))) + val eval = doGenCode(ctx, ExprCode( + JavaCode.isNullVariable(isNull), + JavaCode.variable(value, dataType))) reduceCodeSize(ctx, eval) if (eval.code.nonEmpty) { // Add `this` in the comment. @@ -123,7 +123,7 @@ abstract class Expression extends TreeNode[Expression] { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull - eval.isNull = GlobalValue(globalIsNull, CodeGenerator.JAVA_BOOLEAN) + eval.isNull = JavaCode.isNullGlobal(globalIsNull) s"$globalIsNull = $localIsNull;" } else { "" @@ -142,7 +142,7 @@ abstract class Expression extends TreeNode[Expression] { |} """.stripMargin) - eval.value = VariableValue(newValue, javaType) + eval.value = JavaCode.variable(newValue, dataType) eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index defd6f3cd8849..9212c3de1f814 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -591,8 +591,7 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), - CodeGenerator.JAVA_BOOLEAN) + ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)) val evals = evalChildren.map(eval => s""" |${eval.code} @@ -671,8 +670,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), - CodeGenerator.JAVA_BOOLEAN) + ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)) val evals = evalChildren.map(eval => s""" |${eval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c9c60ef1be640..0abfc9fa4c465 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -59,10 +59,12 @@ import org.apache.spark.util.{ParentClassLoader, Utils} case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue) object ExprCode { + def apply(isNull: ExprValue, value: ExprValue): ExprCode = { + ExprCode(code = "", isNull, value) + } + def forNullValue(dataType: DataType): ExprCode = { - val defaultValueLiteral = CodeGenerator.defaultValue(dataType, typedNull = true) - ExprCode(code = "", isNull = TrueLiteral, - value = LiteralValue(defaultValueLiteral, CodeGenerator.javaType(dataType))) + ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) } def forNonNullValue(value: ExprValue): ExprCode = { @@ -331,7 +333,7 @@ class CodegenContext { case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" case _ => s"$value = $initCode;" } - ExprCode(code, FalseLiteral, GlobalValue(value, javaType(dataType))) + ExprCode(code, FalseLiteral, JavaCode.global(value, dataType)) } def declareMutableStates(): String = { @@ -1004,8 +1006,9 @@ class CodegenContext { // at least two nodes) as the cost of doing it is expected to be low. subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - val state = SubExprEliminationState(GlobalValue(isNull, JAVA_BOOLEAN), - GlobalValue(value, javaType(expr.dataType))) + val state = SubExprEliminationState( + JavaCode.isNullGlobal(isNull), + JavaCode.global(value, expr.dataType)) subExprEliminationExprs ++= e.map(_ -> state).toMap } } @@ -1479,6 +1482,26 @@ object CodeGenerator extends Logging { case _ => "Object" } + def javaClass(dt: DataType): Class[_] = dt match { + case BooleanType => java.lang.Boolean.TYPE + case ByteType => java.lang.Byte.TYPE + case ShortType => java.lang.Short.TYPE + case IntegerType | DateType => java.lang.Integer.TYPE + case LongType | TimestampType => java.lang.Long.TYPE + case FloatType => java.lang.Float.TYPE + case DoubleType => java.lang.Double.TYPE + case _: DecimalType => classOf[Decimal] + case BinaryType => classOf[Array[Byte]] + case StringType => classOf[UTF8String] + case CalendarIntervalType => classOf[CalendarInterval] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayData] + case _: MapType => classOf[MapData] + case udt: UserDefinedType[_] => javaClass(udt.sqlType) + case ObjectType(cls) => cls + case _ => classOf[Object] + } + /** * Returns the boxed type in Java. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala deleted file mode 100644 index df5f1c58b1b2d..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.codegen - -import scala.language.implicitConversions - -import org.apache.spark.sql.types.DataType - -// An abstraction that represents the evaluation result of [[ExprCode]]. -abstract class ExprValue { - - val javaType: String - - // Whether we can directly access the evaluation value anywhere. - // For example, a variable created outside a method can not be accessed inside the method. - // For such cases, we may need to pass the evaluation as parameter. - val canDirectAccess: Boolean - - def isPrimitive: Boolean = CodeGenerator.isPrimitiveType(javaType) -} - -object ExprValue { - implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString -} - -// A literal evaluation of [[ExprCode]]. -class LiteralValue(val value: String, val javaType: String) extends ExprValue { - override def toString: String = value - override val canDirectAccess: Boolean = true -} - -object LiteralValue { - def apply(value: String, javaType: String): LiteralValue = new LiteralValue(value, javaType) - def unapply(literal: LiteralValue): Option[(String, String)] = - Some((literal.value, literal.javaType)) -} - -// A variable evaluation of [[ExprCode]]. -case class VariableValue( - val variableName: String, - val javaType: String) extends ExprValue { - override def toString: String = variableName - override val canDirectAccess: Boolean = false -} - -// A statement evaluation of [[ExprCode]]. -case class StatementValue( - val statement: String, - val javaType: String, - val canDirectAccess: Boolean = false) extends ExprValue { - override def toString: String = statement -} - -// A global variable evaluation of [[ExprCode]]. -case class GlobalValue(val value: String, val javaType: String) extends ExprValue { - override def toString: String = value - override val canDirectAccess: Boolean = true -} - -case object TrueLiteral extends LiteralValue("true", "boolean") -case object FalseLiteral extends LiteralValue("false", "boolean") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 3ae0b54c754cf..33d14329ec95c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -52,43 +52,45 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP expressions: Seq[Expression], useSubexprElimination: Boolean): MutableProjection = { val ctx = newCodeGenContext() - val (validExpr, index) = expressions.zipWithIndex.filter { + val validExpr = expressions.zipWithIndex.filter { case (NoOp, _) => false case _ => true - }.unzip - val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination) + } + val exprVals = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination) // 4-tuples: (code for projection, isNull variable name, value variable name, column index) - val projectionCodes: Seq[(String, ExprValue, String, Int)] = exprVals.zip(index).map { - case (ev, i) => - val e = expressions(i) - val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value") - if (e.nullable) { + val projectionCodes: Seq[(String, String)] = validExpr.zip(exprVals).map { + case ((e, i), ev) => + val value = JavaCode.global( + ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value"), + e.dataType) + val (code, isNull) = if (e.nullable) { val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "isNull") (s""" |${ev.code} |$isNull = ${ev.isNull}; |$value = ${ev.value}; - """.stripMargin, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), value, i) + """.stripMargin, JavaCode.isNullGlobal(isNull)) } else { (s""" |${ev.code} |$value = ${ev.value}; - """.stripMargin, ev.isNull, value, i) + """.stripMargin, FalseLiteral) } + val update = CodeGenerator.updateColumn( + "mutableRow", + e.dataType, + i, + ExprCode(isNull, value), + e.nullable) + (code, update) } // Evaluate all the subexpressions. val evalSubexpr = ctx.subexprFunctions.mkString("\n") - val updates = validExpr.zip(projectionCodes).map { - case (e, (_, isNull, value, i)) => - val ev = ExprCode("", isNull, GlobalValue(value, CodeGenerator.javaType(e.dataType))) - CodeGenerator.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) - } - val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1)) - val allUpdates = ctx.splitExpressionsWithCurrentInputs(updates) + val allUpdates = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._2)) val codeBody = s""" public java.lang.Object generate(Object[] references) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index a30a0b22cd305..01c350e9dbf69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst.expressions.codegen import scala.annotation.tailrec +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ /** @@ -53,9 +54,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val rowClass = classOf[GenericInternalRow].getName val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => - val converter = convertToSafe(ctx, - StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString), - CodeGenerator.javaType(dt)), dt) + val converter = convertToSafe( + ctx, + JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt), + dt) s""" if (!$tmpInput.isNullAt($i)) { ${converter.code} @@ -76,7 +78,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] |final InternalRow $output = new $rowClass($values); """.stripMargin - ExprCode(code, FalseLiteral, VariableValue(output, "InternalRow")) + ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[InternalRow])) } private def createCodeForArray( @@ -91,9 +93,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val index = ctx.freshName("index") val arrayClass = classOf[GenericArrayData].getName - val elementConverter = convertToSafe(ctx, - StatementValue(CodeGenerator.getValue(tmpInput, elementType, index), - CodeGenerator.javaType(elementType)), elementType) + val elementConverter = convertToSafe( + ctx, + JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType), + elementType) val code = s""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); @@ -107,7 +110,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final ArrayData $output = new $arrayClass($values); """ - ExprCode(code, FalseLiteral, VariableValue(output, "ArrayData")) + ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[ArrayData])) } private def createCodeForMap( @@ -128,7 +131,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value}); """ - ExprCode(code, FalseLiteral, VariableValue(output, "MapData")) + ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[MapData])) } @tailrec @@ -140,7 +143,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) - case _ => ExprCode("", FalseLiteral, input) + case _ => ExprCode(FalseLiteral, input) } protected def create(expressions: Seq[Expression]): Projection = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 2fb441ac4500e..01b4d6c4529bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -52,9 +52,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)", CodeGenerator.JAVA_BOOLEAN), - StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString), - CodeGenerator.javaType(dt))) + ExprCode( + JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)"), + JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt)) } val rowWriterClass = classOf[UnsafeRowWriter].getName @@ -109,7 +109,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter) - if (input.isNull == "false") { + if (input.isNull == FalseLiteral) { s""" |${input.code} |${writeField.trim} @@ -292,8 +292,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro |$writeExpressions """.stripMargin // `rowWriter` is declared as a class field, so we can access it directly in methods. - ExprCode(code, FalseLiteral, StatementValue(s"$rowWriter.getRow()", "UnsafeRow", - canDirectAccess = true)) + ExprCode(code, FalseLiteral, JavaCode.expression(s"$rowWriter.getRow()", classOf[UnsafeRow])) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala new file mode 100644 index 0000000000000..74ff018488863 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import java.lang.{Boolean => JBool} + +import scala.language.{existentials, implicitConversions} + +import org.apache.spark.sql.types.{BooleanType, DataType} + +/** + * Trait representing an opaque fragments of java code. + */ +trait JavaCode { + def code: String + override def toString: String = code +} + +/** + * Utility functions for creating [[JavaCode]] fragments. + */ +object JavaCode { + /** + * Create a java literal. + */ + def literal(v: String, dataType: DataType): LiteralValue = dataType match { + case BooleanType if v == "true" => TrueLiteral + case BooleanType if v == "false" => FalseLiteral + case _ => new LiteralValue(v, CodeGenerator.javaClass(dataType)) + } + + /** + * Create a default literal. This is null for reference types, false for boolean types and + * -1 for other primitive types. + */ + def defaultLiteral(dataType: DataType): LiteralValue = { + new LiteralValue( + CodeGenerator.defaultValue(dataType, typedNull = true), + CodeGenerator.javaClass(dataType)) + } + + /** + * Create a local java variable. + */ + def variable(name: String, dataType: DataType): VariableValue = { + variable(name, CodeGenerator.javaClass(dataType)) + } + + /** + * Create a local java variable. + */ + def variable(name: String, javaClass: Class[_]): VariableValue = { + VariableValue(name, javaClass) + } + + /** + * Create a local isNull variable. + */ + def isNullVariable(name: String): VariableValue = variable(name, BooleanType) + + /** + * Create a global java variable. + */ + def global(name: String, dataType: DataType): GlobalValue = { + global(name, CodeGenerator.javaClass(dataType)) + } + + /** + * Create a global java variable. + */ + def global(name: String, javaClass: Class[_]): GlobalValue = { + GlobalValue(name, javaClass) + } + + /** + * Create a global isNull variable. + */ + def isNullGlobal(name: String): GlobalValue = global(name, BooleanType) + + /** + * Create an expression fragment. + */ + def expression(code: String, dataType: DataType): SimpleExprValue = { + expression(code, CodeGenerator.javaClass(dataType)) + } + + /** + * Create an expression fragment. + */ + def expression(code: String, javaClass: Class[_]): SimpleExprValue = { + SimpleExprValue(code, javaClass) + } + + /** + * Create a isNull expression fragment. + */ + def isNullExpression(code: String): SimpleExprValue = { + expression(code, BooleanType) + } +} + +/** + * A typed java fragment that must be a valid java expression. + */ +trait ExprValue extends JavaCode { + def javaType: Class[_] + def isPrimitive: Boolean = javaType.isPrimitive +} + +object ExprValue { + implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString +} + + +/** + * A java expression fragment. + */ +case class SimpleExprValue(expr: String, javaType: Class[_]) extends ExprValue { + override def code: String = s"($expr)" +} + +/** + * A local variable java expression. + */ +case class VariableValue(variableName: String, javaType: Class[_]) extends ExprValue { + override def code: String = variableName +} + +/** + * A global variable java expression. + */ +case class GlobalValue(value: String, javaType: Class[_]) extends ExprValue { + override def code: String = value +} + +/** + * A literal java expression. + */ +class LiteralValue(val value: String, val javaType: Class[_]) extends ExprValue with Serializable { + override def code: String = value + + override def equals(arg: Any): Boolean = arg match { + case l: LiteralValue => l.javaType == javaType && l.value == value + case _ => false + } + + override def hashCode(): Int = value.hashCode() * 31 + javaType.hashCode() +} + +case object TrueLiteral extends LiteralValue("true", JBool.TYPE) +case object FalseLiteral extends LiteralValue("false", JBool.TYPE) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 49a8d12057188..67876a8565488 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -64,7 +64,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( code = preprocess + assigns + postprocess, - value = VariableValue(arrayData, CodeGenerator.javaType(dataType)), + value = JavaCode.variable(arrayData, dataType), isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 409c0b6b79b81..205d77f6a9acf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -191,8 +191,9 @@ case class CaseWhen( // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, // We won't go on anymore on the computation. val resultState = ctx.freshName("caseWhenResultState") - ev.value = GlobalValue(ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value), - CodeGenerator.javaType(dataType)) + ev.value = JavaCode.global( + ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value), + dataType) // these blocks are meant to be inside a // do { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 49dd988b4b53c..32fdb13afbbfa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -813,8 +813,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ val df = classOf[DateFormat].getName if (format.foldable) { if (formatter == null) { - ExprCode("", TrueLiteral, LiteralValue("(UTF8String) null", - CodeGenerator.javaType(dataType))) + ExprCode.forNullValue(StringType) } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val t = left.genCode(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 742a650eb445d..246025b82d59e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -281,38 +281,41 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { if (value == null) { ExprCode.forNullValue(dataType) } else { + def toExprCode(code: String): ExprCode = { + ExprCode.forNonNullValue(JavaCode.literal(code, dataType)) + } dataType match { case BooleanType | IntegerType | DateType => - ExprCode.forNonNullValue(LiteralValue(value.toString, javaType)) + toExprCode(value.toString) case FloatType => value.asInstanceOf[Float] match { case v if v.isNaN => - ExprCode.forNonNullValue(LiteralValue("Float.NaN", javaType)) + toExprCode("Float.NaN") case Float.PositiveInfinity => - ExprCode.forNonNullValue(LiteralValue("Float.POSITIVE_INFINITY", javaType)) + toExprCode("Float.POSITIVE_INFINITY") case Float.NegativeInfinity => - ExprCode.forNonNullValue(LiteralValue("Float.NEGATIVE_INFINITY", javaType)) + toExprCode("Float.NEGATIVE_INFINITY") case _ => - ExprCode.forNonNullValue(LiteralValue(s"${value}F", javaType)) + toExprCode(s"${value}F") } case DoubleType => value.asInstanceOf[Double] match { case v if v.isNaN => - ExprCode.forNonNullValue(LiteralValue("Double.NaN", javaType)) + toExprCode("Double.NaN") case Double.PositiveInfinity => - ExprCode.forNonNullValue(LiteralValue("Double.POSITIVE_INFINITY", javaType)) + toExprCode("Double.POSITIVE_INFINITY") case Double.NegativeInfinity => - ExprCode.forNonNullValue(LiteralValue("Double.NEGATIVE_INFINITY", javaType)) + toExprCode("Double.NEGATIVE_INFINITY") case _ => - ExprCode.forNonNullValue(LiteralValue(s"${value}D", javaType)) + toExprCode(s"${value}D") } case ByteType | ShortType => - ExprCode.forNonNullValue(LiteralValue(s"($javaType)$value", javaType)) + ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value", dataType)) case TimestampType | LongType => - ExprCode.forNonNullValue(LiteralValue(s"${value}L", javaType)) + toExprCode(s"${value}L") case _ => val constRef = ctx.addReferenceObj("literal", value, javaType) - ExprCode.forNonNullValue(GlobalValue(constRef, javaType)) + ExprCode.forNonNullValue(JavaCode.global(constRef, dataType)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 7081a5e096d56..7eda65a867028 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -92,7 +92,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); |}""".stripMargin, isNull = TrueLiteral, - value = LiteralValue("null", CodeGenerator.javaType(dataType))) + value = JavaCode.defaultLiteral(dataType)) } override def sql: String = s"assert_true(${child.sql})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 55b6e346be82a..0787342bce6bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -72,8 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), - CodeGenerator.JAVA_BOOLEAN) + ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => @@ -321,12 +320,7 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - val value = if (eval.isNull.isInstanceOf[LiteralValue]) { - LiteralValue(eval.isNull, CodeGenerator.JAVA_BOOLEAN) - } else { - VariableValue(eval.isNull, CodeGenerator.JAVA_BOOLEAN) - } - ExprCode(code = eval.code, isNull = FalseLiteral, value = value) + ExprCode(code = eval.code, isNull = FalseLiteral, value = eval.isNull) } override def sql: String = s"(${child.sql} IS NULL)" @@ -352,12 +346,10 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - val value = if (eval.isNull == TrueLiteral) { - FalseLiteral - } else if (eval.isNull == FalseLiteral) { - TrueLiteral - } else { - StatementValue(s"(!(${eval.isNull}))", CodeGenerator.javaType(dataType)) + val value = eval.isNull match { + case TrueLiteral => FalseLiteral + case FalseLiteral => TrueLiteral + case v => JavaCode.isNullExpression(s"!$v") } ExprCode(code = eval.code, isNull = FalseLiteral, value = value) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index b2cca3178cd2a..50e90ca550807 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -65,7 +65,7 @@ trait InvokeLike extends Expression with NonSQLExpression { val resultIsNull = if (needNullCheck) { val resultIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "resultIsNull") - GlobalValue(resultIsNull, CodeGenerator.JAVA_BOOLEAN) + JavaCode.isNullGlobal(resultIsNull) } else { FalseLiteral } @@ -569,12 +569,11 @@ case class LambdaVariable( override def genCode(ctx: CodegenContext): ExprCode = { val isNullValue = if (nullable) { - VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN) + JavaCode.isNullVariable(isNull) } else { FalseLiteral } - ExprCode(code = "", value = VariableValue(value, CodeGenerator.javaType(dataType)), - isNull = isNullValue) + ExprCode(value = JavaCode.variable(value, dataType), isNull = isNullValue) } // This won't be called as `genCode` is overrided, just overriding it to make diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 8e83b35c3809c..f7c023111ff59 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -448,8 +448,9 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val ref = BoundReference(0, IntegerType, true) val add1 = Add(ref, ref) val add2 = Add(add1, add1) - val dummy = SubExprEliminationState(VariableValue("dummy", "boolean"), - VariableValue("dummy", "boolean")) + val dummy = SubExprEliminationState( + JavaCode.variable("dummy", BooleanType), + JavaCode.variable("dummy", BooleanType)) // raw testing of basic functionality { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala index c8f4cff7db48d..378b8bc055e34 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.BooleanType class ExprValueSuite extends SparkFunSuite { @@ -31,16 +32,7 @@ class ExprValueSuite extends SparkFunSuite { assert(trueLit.isPrimitive) assert(falseLit.isPrimitive) - trueLit match { - case LiteralValue(value, javaType) => - assert(value == "true" && javaType == "boolean") - case _ => fail() - } - - falseLit match { - case LiteralValue(value, javaType) => - assert(value == "false" && javaType == "boolean") - case _ => fail() - } + assert(trueLit === JavaCode.literal("true", BooleanType)) + assert(falseLit === JavaCode.literal("false", BooleanType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 434214a10e1e3..fc3dbc1c5591b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, VariableValue} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -52,7 +52,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { val javaType = CodeGenerator.javaType(dataType) val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal) val isNullVar = if (nullable) { - VariableValue(ctx.freshName("isNull"), CodeGenerator.JAVA_BOOLEAN) + JavaCode.isNullVariable(ctx.freshName("isNull")) } else { FalseLiteral } @@ -66,7 +66,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } else { s"$javaType $valueVar = $value;" }).trim - ExprCode(code, isNullVar, VariableValue(valueVar, javaType)) + ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 0d9a62cace62a..e4812f3d338fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, VariableValue} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -157,8 +157,10 @@ case class ExpandExec( |${CodeGenerator.javaType(firstExpr.dataType)} $value = | ${CodeGenerator.defaultValue(firstExpr.dataType)}; """.stripMargin - ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), - VariableValue(value, CodeGenerator.javaType(firstExpr.dataType))) + ExprCode( + code, + JavaCode.isNullVariable(isNull), + JavaCode.variable(value, firstExpr.dataType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 85c5ebfdaa689..f40c50df74ccb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} +import org.apache.spark.sql.types._ /** * For lazy computing, be sure the generator.terminate() called in the very last @@ -170,10 +170,11 @@ case class GenerateExec( // Add position val position = if (e.position) { if (outer) { - Seq(ExprCode("", StatementValue(s"$index == -1", CodeGenerator.JAVA_BOOLEAN), - VariableValue(index, CodeGenerator.JAVA_INT))) + Seq(ExprCode( + JavaCode.isNullExpression(s"$index == -1"), + JavaCode.variable(index, IntegerType))) } else { - Seq(ExprCode("", FalseLiteral, VariableValue(index, CodeGenerator.JAVA_INT))) + Seq(ExprCode(FalseLiteral, JavaCode.variable(index, IntegerType))) } } else { Seq.empty @@ -316,11 +317,9 @@ case class GenerateExec( |boolean $isNull = ${checks.mkString(" || ")}; |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter; """.stripMargin - ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), - VariableValue(value, javaType)) + ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, dt)) } else { - ExprCode(s"$javaType $value = $getter;", FalseLiteral, - VariableValue(value, javaType)) + ExprCode(s"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 805ff3cf001ba..828b51fa199de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -111,7 +111,7 @@ trait CodegenSupport extends SparkPlan { private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = { if (row != null) { - ExprCode("", FalseLiteral, VariableValue(row, "UnsafeRow")) + ExprCode.forNonNullValue(JavaCode.variable(row, classOf[UnsafeRow])) } else { if (colVars.nonEmpty) { val colExprs = output.zipWithIndex.map { case (attr, i) => @@ -128,8 +128,8 @@ trait CodegenSupport extends SparkPlan { """.stripMargin.trim ExprCode(code, FalseLiteral, ev.value) } else { - // There is no columns - ExprCode("", FalseLiteral, VariableValue("unsafeRow", "UnsafeRow")) + // There are no columns + ExprCode.forNonNullValue(JavaCode.variable("unsafeRow", classOf[UnsafeRow])) } } } @@ -246,11 +246,10 @@ trait CodegenSupport extends SparkPlan { val isNull = ctx.freshName(s"exprIsNull_$i") arguments += ev.isNull parameters += s"boolean $isNull" - VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN) + JavaCode.isNullVariable(isNull) } - paramVars += ExprCode("", paramIsNull, - VariableValue(paramName, CodeGenerator.javaType(attributes(i).dataType))) + paramVars += ExprCode(paramIsNull, JavaCode.variable(paramName, attributes(i).dataType)) } (arguments, parameters, paramVars) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 8f7f10243d4cc..a5dc6ebf2b0f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -194,8 +194,10 @@ case class HashAggregateExec( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), - GlobalValue(value, CodeGenerator.javaType(e.dataType))) + ExprCode( + ev.code + initVars, + JavaCode.isNullGlobal(isNull), + JavaCode.global(value, e.dataType)) } val initBufVar = evaluateVariables(bufVars) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 4978954271311..de2d630de3fdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, GlobalValue} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ /** @@ -54,8 +54,10 @@ abstract class HashMapGenerator( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), - GlobalValue(value, CodeGenerator.javaType(e.dataType))) + ExprCode( + ev.code + initVars, + JavaCode.isNullGlobal(isNull), + JavaCode.global(value, e.dataType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index cab7081400ce9..1edfdc888afd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -368,7 +368,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number") val value = ctx.freshName("value") - val ev = ExprCode("", FalseLiteral, VariableValue(value, CodeGenerator.JAVA_LONG)) + val ev = ExprCode.forNonNullValue(JavaCode.variable(value, LongType)) val BigInt = classOf[java.math.BigInteger].getName // Inline mutable state since not many Range operations in a task diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index fa62a32d51f3e..6fa716d9fadee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.{BooleanType, LongType} import org.apache.spark.util.TaskCompletionListener /** @@ -192,8 +192,7 @@ case class BroadcastHashJoinExec( | $value = ${ev.value}; |} """.stripMargin - ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), - VariableValue(value, CodeGenerator.javaType(a.dataType))) + ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)) } } } @@ -488,8 +487,8 @@ case class BroadcastHashJoinExec( s"$existsVar = true;" } - val resultVar = input ++ Seq(ExprCode("", FalseLiteral, - VariableValue(existsVar, CodeGenerator.JAVA_BOOLEAN))) + val resultVar = input ++ Seq(ExprCode.forNonNullValue( + JavaCode.variable(existsVar, BooleanType))) if (broadcastRelation.value.keyIsUnique) { s""" |// generate join key for stream side diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index b61acb8d4fda9..d8261f0f33b61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -22,11 +22,10 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, VariableValue} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, -ExternalAppendOnlyUnsafeRowArray, RowIterator, SparkPlan} +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.util.collection.BitSet @@ -531,13 +530,12 @@ case class SortMergeJoinExec( |boolean $isNull = false; |$javaType $value = $defaultValue; """.stripMargin - (ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), - VariableValue(value, CodeGenerator.javaType(a.dataType))), leftVarsDecl) + (ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)), + leftVarsDecl) } else { val code = s"$value = $valueCode;" val leftVarsDecl = s"""$javaType $value = $defaultValue;""" - (ExprCode(code, FalseLiteral, - VariableValue(value, CodeGenerator.javaType(a.dataType))), leftVarsDecl) + (ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl) } }.unzip } From 271c891b91917d660d1f6b995de397c47c7a6058 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Wed, 11 Apr 2018 21:52:48 +0800 Subject: [PATCH 0601/2461] [SPARK-23960][SQL][MINOR] Mark HashAggregateExec.bufVars as transient ## What changes were proposed in this pull request? Mark `HashAggregateExec.bufVars` as transient to avoid it from being serialized. Also manually null out this field at the end of `doProduceWithoutKeys()` to shorten its lifecycle, because it'll no longer be used after that. ## How was this patch tested? Existing tests. Author: Kris Mok Closes #21039 from rednaxelafx/codegen-improve. --- .../spark/sql/execution/aggregate/HashAggregateExec.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index a5dc6ebf2b0f2..965950ed94fe8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -174,8 +174,8 @@ case class HashAggregateExec( } } - // The variables used as aggregation buffer. Only used for aggregation without keys. - private var bufVars: Seq[ExprCode] = _ + // The variables used as aggregation buffer. Only used in codegen for aggregation without keys. + @transient private var bufVars: Seq[ExprCode] = _ private def doProduceWithoutKeys(ctx: CodegenContext): String = { val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") @@ -238,6 +238,8 @@ case class HashAggregateExec( | } """.stripMargin) + bufVars = null // explicitly null this field out to allow the referent to be GC'd sooner + val numOutput = metricTerm(ctx, "numOutputRows") val aggTime = metricTerm(ctx, "aggTime") val beforeAgg = ctx.freshName("beforeAgg") From 653fe02415a537299e15f92b56045569864b6183 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 11 Apr 2018 09:49:25 -0500 Subject: [PATCH 0602/2461] [SPARK-6951][CORE] Speed up parsing of event logs during listing. This change introduces two optimizations to help speed up generation of listing data when parsing events logs. The first one allows the parser to be stopped when enough data to create the listing entry has been read. This is currently the start event plus environment info, to capture UI ACLs. If the end event is needed, the code will skip to the end of the log to try to find that information, instead of parsing the whole log file. Unfortunately this works better with uncompressed logs. Skipping bytes on compressed logs only saves the work of parsing lines and some events, so not a lot of gains are observed. The second optimization deals with in-progress logs. It works in two ways: first, it completely avoids parsing the rest of the log for these apps when enough listing data is read. This, unlike the above, also speeds things up for compressed logs, since only the very beginning of the log has to be read. On top of that, the code that decides whether to re-parse logs to get updated listing data will ignore in-progress applications until they've completed. Both optimizations can be disabled but are enabled by default. I tested this on some fake event logs to see the effect. I created 500 logs of about 60M each (so ~30G uncompressed; each log was 1.7M when compressed with zstd). Below, C = completed, IP = in-progress, the size means the amount of data re-parsed at the end of logs when necessary. ``` none/C none/IP zstd/C zstd/IP On / 16k 2s 2s 22s 2s On / 1m 3s 2s 24s 2s Off 1.1m 1.1m 26s 24s ``` This was with 4 threads on a single local SSD. As expected from the previous explanations, there are considerable gains for in-progress logs, and for uncompressed logs, but not so much when looking at the full compressed log. As a side note, I removed the custom code to get the scan time by creating a file on HDFS; since file mod times are not used to detect changed logs anymore, local time is enough for the current use of the SHS. Author: Marcelo Vanzin Closes #20952 from vanzin/SPARK-6951. --- .../deploy/history/FsHistoryProvider.scala | 251 ++++++++++++------ .../apache/spark/deploy/history/config.scala | 15 ++ .../spark/scheduler/ReplayListenerBus.scala | 11 + .../org/apache/spark/util/ListenerBus.scala | 5 +- .../history/FsHistoryProviderSuite.scala | 78 ++++-- 5 files changed, 264 insertions(+), 96 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index ace6d9e00c838..56db9359e033f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -18,12 +18,13 @@ package org.apache.spark.deploy.history import java.io.{File, FileNotFoundException, IOException} -import java.util.{Date, ServiceLoader, UUID} +import java.util.{Date, ServiceLoader} import java.util.concurrent.{ExecutorService, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.io.Source import scala.util.Try import scala.xml.Node @@ -58,10 +59,10 @@ import org.apache.spark.util.kvstore._ * * == How new and updated attempts are detected == * - * - New attempts are detected in [[checkForLogs]]: the log dir is scanned, and any - * entries in the log dir whose modification time is greater than the last scan time - * are considered new or updated. These are replayed to create a new attempt info entry - * and update or create a matching application info element in the list of applications. + * - New attempts are detected in [[checkForLogs]]: the log dir is scanned, and any entries in the + * log dir whose size changed since the last scan time are considered new or updated. These are + * replayed to create a new attempt info entry and update or create a matching application info + * element in the list of applications. * - Updated attempts are also found in [[checkForLogs]] -- if the attempt's log file has grown, the * attempt is replaced by another one with a larger log size. * @@ -125,6 +126,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private val pendingReplayTasksCount = new java.util.concurrent.atomic.AtomicInteger(0) private val storePath = conf.get(LOCAL_STORE_DIR).map(new File(_)) + private val fastInProgressParsing = conf.get(FAST_IN_PROGRESS_PARSING) // Visible for testing. private[history] val listing: KVStore = storePath.map { path => @@ -402,13 +404,13 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) */ private[history] def checkForLogs(): Unit = { try { - val newLastScanTime = getNewLastScanTime() + val newLastScanTime = clock.getTimeMillis() logDebug(s"Scanning $logDir with lastScanTime==$lastScanTime") val updated = Option(fs.listStatus(new Path(logDir))).map(_.toSeq).getOrElse(Nil) .filter { entry => !entry.isDirectory() && - // FsHistoryProvider generates a hidden file which can't be read. Accidentally + // FsHistoryProvider used to generate a hidden file which can't be read. Accidentally // reading a garbage file is safe, but we would log an error which can be scary to // the end-user. !entry.getPath().getName().startsWith(".") && @@ -417,15 +419,24 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) .filter { entry => try { val info = listing.read(classOf[LogInfo], entry.getPath().toString()) - if (info.fileSize < entry.getLen()) { - // Log size has changed, it should be parsed. - true - } else { + + if (info.appId.isDefined) { // If the SHS view has a valid application, update the time the file was last seen so - // that the entry is not deleted from the SHS listing. - if (info.appId.isDefined) { - listing.write(info.copy(lastProcessed = newLastScanTime)) + // that the entry is not deleted from the SHS listing. Also update the file size, in + // case the code below decides we don't need to parse the log. + listing.write(info.copy(lastProcessed = newLastScanTime, fileSize = entry.getLen())) + } + + if (info.fileSize < entry.getLen()) { + if (info.appId.isDefined && fastInProgressParsing) { + // When fast in-progress parsing is on, we don't need to re-parse when the + // size changes, but we do need to invalidate any existing UIs. + invalidateUI(info.appId.get, info.attemptId) + false + } else { + true } + } else { false } } catch { @@ -449,7 +460,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val tasks = updated.map { entry => try { replayExecutor.submit(new Runnable { - override def run(): Unit = mergeApplicationListing(entry, newLastScanTime) + override def run(): Unit = mergeApplicationListing(entry, newLastScanTime, true) }) } catch { // let the iteration over the updated entries break, since an exception on @@ -542,25 +553,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - private[history] def getNewLastScanTime(): Long = { - val fileName = "." + UUID.randomUUID().toString - val path = new Path(logDir, fileName) - val fos = fs.create(path) - - try { - fos.close() - fs.getFileStatus(path).getModificationTime - } catch { - case e: Exception => - logError("Exception encountered when attempting to update last scan time", e) - lastScanTime.get() - } finally { - if (!fs.delete(path, true)) { - logWarning(s"Error deleting ${path}") - } - } - } - override def writeEventLogs( appId: String, attemptId: Option[String], @@ -607,7 +599,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Replay the given log file, saving the application in the listing db. */ - protected def mergeApplicationListing(fileStatus: FileStatus, scanTime: Long): Unit = { + protected def mergeApplicationListing( + fileStatus: FileStatus, + scanTime: Long, + enableOptimizations: Boolean): Unit = { val eventsFilter: ReplayEventsFilter = { eventString => eventString.startsWith(APPL_START_EVENT_PREFIX) || eventString.startsWith(APPL_END_EVENT_PREFIX) || @@ -616,32 +611,118 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } val logPath = fileStatus.getPath() + val appCompleted = isCompleted(logPath.getName()) + val reparseChunkSize = conf.get(END_EVENT_REPARSE_CHUNK_SIZE) + + // Enable halt support in listener if: + // - app in progress && fast parsing enabled + // - skipping to end event is enabled (regardless of in-progress state) + val shouldHalt = enableOptimizations && + ((!appCompleted && fastInProgressParsing) || reparseChunkSize > 0) + val bus = new ReplayListenerBus() - val listener = new AppListingListener(fileStatus, clock) + val listener = new AppListingListener(fileStatus, clock, shouldHalt) bus.addListener(listener) - replay(fileStatus, bus, eventsFilter = eventsFilter) - - val (appId, attemptId) = listener.applicationInfo match { - case Some(app) => - // Invalidate the existing UI for the reloaded app attempt, if any. See LoadedAppUI for a - // discussion on the UI lifecycle. - synchronized { - activeUIs.get((app.info.id, app.attempts.head.info.attemptId)).foreach { ui => - ui.invalidate() - ui.ui.store.close() + + logInfo(s"Parsing $logPath for listing data...") + Utils.tryWithResource(EventLoggingListener.openEventLog(logPath, fs)) { in => + bus.replay(in, logPath.toString, !appCompleted, eventsFilter) + } + + // If enabled above, the listing listener will halt parsing when there's enough information to + // create a listing entry. When the app is completed, or fast parsing is disabled, we still need + // to replay until the end of the log file to try to find the app end event. Instead of reading + // and parsing line by line, this code skips bytes from the underlying stream so that it is + // positioned somewhere close to the end of the log file. + // + // Because the application end event is written while some Spark subsystems such as the + // scheduler are still active, there is no guarantee that the end event will be the last + // in the log. So, to be safe, the code uses a configurable chunk to be re-parsed at + // the end of the file, and retries parsing the whole log later if the needed data is + // still not found. + // + // Note that skipping bytes in compressed files is still not cheap, but there are still some + // minor gains over the normal log parsing done by the replay bus. + // + // This code re-opens the file so that it knows where it's skipping to. This isn't as cheap as + // just skipping from the current position, but there isn't a a good way to detect what the + // current position is, since the replay listener bus buffers data internally. + val lookForEndEvent = shouldHalt && (appCompleted || !fastInProgressParsing) + if (lookForEndEvent && listener.applicationInfo.isDefined) { + Utils.tryWithResource(EventLoggingListener.openEventLog(logPath, fs)) { in => + val target = fileStatus.getLen() - reparseChunkSize + if (target > 0) { + logInfo(s"Looking for end event; skipping $target bytes from $logPath...") + var skipped = 0L + while (skipped < target) { + skipped += in.skip(target - skipped) } } + val source = Source.fromInputStream(in).getLines() + + // Because skipping may leave the stream in the middle of a line, read the next line + // before replaying. + if (target > 0) { + source.next() + } + + bus.replay(source, logPath.toString, !appCompleted, eventsFilter) + } + } + + logInfo(s"Finished parsing $logPath") + + listener.applicationInfo match { + case Some(app) if !lookForEndEvent || app.attempts.head.info.completed => + // In this case, we either didn't care about the end event, or we found it. So the + // listing data is good. + invalidateUI(app.info.id, app.attempts.head.info.attemptId) addListing(app) - (Some(app.info.id), app.attempts.head.info.attemptId) + listing.write(LogInfo(logPath.toString(), scanTime, Some(app.info.id), + app.attempts.head.info.attemptId, fileStatus.getLen())) + + // For a finished log, remove the corresponding "in progress" entry from the listing DB if + // the file is really gone. + if (appCompleted) { + val inProgressLog = logPath.toString() + EventLoggingListener.IN_PROGRESS + try { + // Fetch the entry first to avoid an RPC when it's already removed. + listing.read(classOf[LogInfo], inProgressLog) + if (!fs.isFile(new Path(inProgressLog))) { + listing.delete(classOf[LogInfo], inProgressLog) + } + } catch { + case _: NoSuchElementException => + } + } + + case Some(_) => + // In this case, the attempt is still not marked as finished but was expected to. This can + // mean the end event is before the configured threshold, so call the method again to + // re-parse the whole log. + logInfo(s"Reparsing $logPath since end event was not found.") + mergeApplicationListing(fileStatus, scanTime, false) case _ => // If the app hasn't written down its app ID to the logs, still record the entry in the // listing db, with an empty ID. This will make the log eligible for deletion if the app // does not make progress after the configured max log age. - (None, None) + listing.write(LogInfo(logPath.toString(), scanTime, None, None, fileStatus.getLen())) + } + } + + /** + * Invalidate an existing UI for a given app attempt. See LoadedAppUI for a discussion on the + * UI lifecycle. + */ + private def invalidateUI(appId: String, attemptId: Option[String]): Unit = { + synchronized { + activeUIs.get((appId, attemptId)).foreach { ui => + ui.invalidate() + ui.ui.store.close() + } } - listing.write(LogInfo(logPath.toString(), scanTime, appId, attemptId, fileStatus.getLen())) } /** @@ -696,29 +777,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - /** - * Replays the events in the specified log file on the supplied `ReplayListenerBus`. - * `ReplayEventsFilter` determines what events are replayed. - */ - private def replay( - eventLog: FileStatus, - bus: ReplayListenerBus, - eventsFilter: ReplayEventsFilter = SELECT_ALL_FILTER): Unit = { - val logPath = eventLog.getPath() - val isCompleted = !logPath.getName().endsWith(EventLoggingListener.IN_PROGRESS) - logInfo(s"Replaying log path: $logPath") - // Note that the eventLog may have *increased* in size since when we grabbed the filestatus, - // and when we read the file here. That is OK -- it may result in an unnecessary refresh - // when there is no update, but will not result in missing an update. We *must* prevent - // an error the other way -- if we report a size bigger (ie later) than the file that is - // actually read, we may never refresh the app. FileStatus is guaranteed to be static - // after it's created, so we get a file size that is no bigger than what is actually read. - Utils.tryWithResource(EventLoggingListener.openEventLog(logPath, fs)) { in => - bus.replay(in, logPath.toString, !isCompleted, eventsFilter) - logInfo(s"Finished parsing $logPath") - } - } - /** * Rebuilds the application state store from its event log. */ @@ -741,8 +799,13 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } replayBus.addListener(listener) try { - replay(eventLog, replayBus) + val path = eventLog.getPath() + logInfo(s"Parsing $path to re-build UI...") + Utils.tryWithResource(EventLoggingListener.openEventLog(path, fs)) { in => + replayBus.replay(in, path.toString(), maybeTruncated = !isCompleted(path.toString())) + } trackingStore.close(false) + logInfo(s"Finished parsing $path") } catch { case e: Exception => Utils.tryLogNonFatalError { @@ -881,6 +944,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + private def isCompleted(name: String): Boolean = { + !name.endsWith(EventLoggingListener.IN_PROGRESS) + } + } private[history] object FsHistoryProvider { @@ -945,11 +1012,17 @@ private[history] class ApplicationInfoWrapper( } -private[history] class AppListingListener(log: FileStatus, clock: Clock) extends SparkListener { +private[history] class AppListingListener( + log: FileStatus, + clock: Clock, + haltEnabled: Boolean) extends SparkListener { private val app = new MutableApplicationInfo() private val attempt = new MutableAttemptInfo(log.getPath().getName(), log.getLen()) + private var gotEnvUpdate = false + private var halted = false + override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { app.id = event.appId.orNull app.name = event.appName @@ -958,6 +1031,8 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends attempt.startTime = new Date(event.time) attempt.lastUpdated = new Date(clock.getTimeMillis()) attempt.sparkUser = event.sparkUser + + checkProgress() } override def onApplicationEnd(event: SparkListenerApplicationEnd): Unit = { @@ -968,11 +1043,18 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends } override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = { - val allProperties = event.environmentDetails("Spark Properties").toMap - attempt.viewAcls = allProperties.get("spark.ui.view.acls") - attempt.adminAcls = allProperties.get("spark.admin.acls") - attempt.viewAclsGroups = allProperties.get("spark.ui.view.acls.groups") - attempt.adminAclsGroups = allProperties.get("spark.admin.acls.groups") + // Only parse the first env update, since any future changes don't have any effect on + // the ACLs set for the UI. + if (!gotEnvUpdate) { + val allProperties = event.environmentDetails("Spark Properties").toMap + attempt.viewAcls = allProperties.get("spark.ui.view.acls") + attempt.adminAcls = allProperties.get("spark.admin.acls") + attempt.viewAclsGroups = allProperties.get("spark.ui.view.acls.groups") + attempt.adminAclsGroups = allProperties.get("spark.admin.acls.groups") + + gotEnvUpdate = true + checkProgress() + } } override def onOtherEvent(event: SparkListenerEvent): Unit = event match { @@ -989,6 +1071,17 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends } } + /** + * Throws a halt exception to stop replay if enough data to create the app listing has been + * read. + */ + private def checkProgress(): Unit = { + if (haltEnabled && !halted && app.id != null && gotEnvUpdate) { + halted = true + throw new HaltReplayException() + } + } + private class MutableApplicationInfo { var id: String = null var name: String = null diff --git a/core/src/main/scala/org/apache/spark/deploy/history/config.scala b/core/src/main/scala/org/apache/spark/deploy/history/config.scala index efdbf672bb52f..25ba9edb9e014 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/config.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/config.scala @@ -49,4 +49,19 @@ private[spark] object config { .intConf .createWithDefault(18080) + val FAST_IN_PROGRESS_PARSING = + ConfigBuilder("spark.history.fs.inProgressOptimization.enabled") + .doc("Enable optimized handling of in-progress logs. This option may leave finished " + + "applications that fail to rename their event logs listed as in-progress.") + .booleanConf + .createWithDefault(true) + + val END_EVENT_REPARSE_CHUNK_SIZE = + ConfigBuilder("spark.history.fs.endEventReparseChunkSize") + .doc("How many bytes to parse at the end of log files looking for the end event. " + + "This is used to speed up generation of application listings by skipping unnecessary " + + "parts of event log files. It can be disabled by setting this config to 0.") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("1m") + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index c9cd662f5709d..226c23733c870 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -115,6 +115,8 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { } } } catch { + case e: HaltReplayException => + // Just stop replay. case _: EOFException if maybeTruncated => case ioe: IOException => throw ioe @@ -124,8 +126,17 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { } } + override protected def isIgnorableException(e: Throwable): Boolean = { + e.isInstanceOf[HaltReplayException] + } + } +/** + * Exception that can be thrown by listeners to halt replay. This is handled by ReplayListenerBus + * only, and will cause errors if thrown when using other bus implementations. + */ +private[spark] class HaltReplayException extends RuntimeException private[spark] object ReplayListenerBus { diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index 76a56298aaebc..b25a731401f23 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -81,7 +81,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { try { doPostEvent(listener, event) } catch { - case NonFatal(e) => + case NonFatal(e) if !isIgnorableException(e) => logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) } finally { if (maybeTimerContext != null) { @@ -97,6 +97,9 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { */ protected def doPostEvent(listener: L, event: E): Unit + /** Allows bus implementations to prevent error logging for certain exceptions. */ + protected def isIgnorableException(e: Throwable): Boolean = false + private[spark] def findListenersByClass[T <: L : ClassTag](): Seq[T] = { val c = implicitly[ClassTag[T]].runtimeClass listeners.asScala.filter(_.getClass == c).map(_.asInstanceOf[T]).toSeq diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 0ba57bf4563c1..77b239489d489 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hdfs.DistributedFileSystem import org.json4s.jackson.JsonMethods._ import org.mockito.Matchers.any -import org.mockito.Mockito.{doReturn, mock, spy, verify} +import org.mockito.Mockito.{mock, spy, verify} import org.scalatest.BeforeAndAfter import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ @@ -151,8 +151,9 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc var mergeApplicationListingCall = 0 override protected def mergeApplicationListing( fileStatus: FileStatus, - lastSeen: Long): Unit = { - super.mergeApplicationListing(fileStatus, lastSeen) + lastSeen: Long, + enableSkipToEnd: Boolean): Unit = { + super.mergeApplicationListing(fileStatus, lastSeen, enableSkipToEnd) mergeApplicationListingCall += 1 } } @@ -256,14 +257,13 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc ) updateAndCheck(provider) { list => - list should not be (null) list.size should be (1) list.head.attempts.size should be (3) list.head.attempts.head.attemptId should be (Some("attempt3")) } val app2Attempt1 = newLogFile("app2", Some("attempt1"), inProgress = false) - writeFile(attempt1, true, None, + writeFile(app2Attempt1, true, None, SparkListenerApplicationStart("app2", Some("app2"), 5L, "test", Some("attempt1")), SparkListenerApplicationEnd(6L) ) @@ -649,8 +649,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc // Add more info to the app log, and trigger the provider to update things. writeFile(appLog, true, None, SparkListenerApplicationStart(appId, Some(appId), 1L, "test", None), - SparkListenerJobStart(0, 1L, Nil, null), - SparkListenerApplicationEnd(5L) + SparkListenerJobStart(0, 1L, Nil, null) ) provider.checkForLogs() @@ -668,11 +667,12 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc test("clean up stale app information") { val storeDir = Utils.createTempDir() val conf = createTestConf().set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) - val provider = spy(new FsHistoryProvider(conf)) + val clock = new ManualClock() + val provider = spy(new FsHistoryProvider(conf, clock)) val appId = "new1" // Write logs for two app attempts. - doReturn(1L).when(provider).getNewLastScanTime() + clock.advance(1) val attempt1 = newLogFile(appId, Some("1"), inProgress = false) writeFile(attempt1, true, None, SparkListenerApplicationStart(appId, Some(appId), 1L, "test", Some("1")), @@ -697,7 +697,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc // Delete the underlying log file for attempt 1 and rescan. The UI should go away, but since // attempt 2 still exists, listing data should be there. - doReturn(2L).when(provider).getNewLastScanTime() + clock.advance(1) attempt1.delete() updateAndCheck(provider) { list => assert(list.size === 1) @@ -708,7 +708,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc assert(provider.getAppUI(appId, None) === None) // Delete the second attempt's log file. Now everything should go away. - doReturn(3L).when(provider).getNewLastScanTime() + clock.advance(1) attempt2.delete() updateAndCheck(provider) { list => assert(list.isEmpty) @@ -718,9 +718,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc test("SPARK-21571: clean up removes invalid history files") { val clock = new ManualClock() val conf = createTestConf().set(MAX_LOG_AGE_S.key, s"2d") - val provider = new FsHistoryProvider(conf, clock) { - override def getNewLastScanTime(): Long = clock.getTimeMillis() - } + val provider = new FsHistoryProvider(conf, clock) // Create 0-byte size inprogress and complete files var logCount = 0 @@ -772,6 +770,54 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc assert(new File(testDir.toURI).listFiles().size === validLogCount) } + test("always find end event for finished apps") { + // Create a log file where the end event is before the configure chunk to be reparsed at + // the end of the file. The correct listing should still be generated. + val log = newLogFile("end-event-test", None, inProgress = false) + writeFile(log, true, None, + Seq( + SparkListenerApplicationStart("end-event-test", Some("end-event-test"), 1L, "test", None), + SparkListenerEnvironmentUpdate(Map( + "Spark Properties" -> Seq.empty, + "JVM Information" -> Seq.empty, + "System Properties" -> Seq.empty, + "Classpath Entries" -> Seq.empty + )), + SparkListenerApplicationEnd(5L) + ) ++ (1 to 1000).map { i => SparkListenerJobStart(i, i, Nil) }: _*) + + val conf = createTestConf().set(END_EVENT_REPARSE_CHUNK_SIZE.key, s"1k") + val provider = new FsHistoryProvider(conf) + updateAndCheck(provider) { list => + assert(list.size === 1) + assert(list(0).attempts.size === 1) + assert(list(0).attempts(0).completed) + } + } + + test("parse event logs with optimizations off") { + val conf = createTestConf() + .set(END_EVENT_REPARSE_CHUNK_SIZE, 0L) + .set(FAST_IN_PROGRESS_PARSING, false) + val provider = new FsHistoryProvider(conf) + + val complete = newLogFile("complete", None, inProgress = false) + writeFile(complete, true, None, + SparkListenerApplicationStart("complete", Some("complete"), 1L, "test", None), + SparkListenerApplicationEnd(5L) + ) + + val incomplete = newLogFile("incomplete", None, inProgress = true) + writeFile(incomplete, true, None, + SparkListenerApplicationStart("incomplete", Some("incomplete"), 1L, "test", None) + ) + + updateAndCheck(provider) { list => + list.size should be (2) + list.count(_.attempts.head.completed) should be (1) + } + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: @@ -815,7 +861,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc private def createTestConf(inMemory: Boolean = false): SparkConf = { val conf = new SparkConf() - .set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) + .set(EVENT_LOG_DIR, testDir.getAbsolutePath()) + .set(FAST_IN_PROGRESS_PARSING, true) if (!inMemory) { conf.set(LOCAL_STORE_DIR, Utils.createTempDir().getAbsolutePath()) @@ -848,4 +895,3 @@ class TestGroupsMappingProvider extends GroupMappingServiceProvider { mappings.get(username).map(Set(_)).getOrElse(Set.empty) } } - From 3cb82047f2f51af553df09b9323796af507d36f8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 11 Apr 2018 10:13:44 -0500 Subject: [PATCH 0603/2461] [SPARK-22941][CORE] Do not exit JVM when submit fails with in-process launcher. The current in-process launcher implementation just calls the SparkSubmit object, which, in case of errors, will more often than not exit the JVM. This is not desirable since this launcher is meant to be used inside other applications, and that would kill the application. The change turns SparkSubmit into a class, and abstracts aways some of the functionality used to print error messages and abort the submission process. The default implementation uses the logging system for messages, and throws exceptions for errors. As part of that I also moved some code that doesn't really belong in SparkSubmit to a better location. The command line invocation of spark-submit now uses a special implementation of the SparkSubmit class that overrides those behaviors to do what is expected from the command line version (print to the terminal, exit the JVM, etc). A lot of the changes are to replace calls to methods such as "printErrorAndExit" with the new API. As part of adding tests for this, I had to fix some small things in the launcher option parser so that things like "--version" can work when used in the launcher library. There is still code that prints directly to the terminal, like all the Ivy-related code in SparkSubmitUtils, and other areas where some re-factoring would help, like the CommandLineUtils class, but I chose to leave those alone to keep this change more focused. Aside from existing and added unit tests, I ran command line tools with a bunch of different arguments to make sure messages and errors behave like before. Author: Marcelo Vanzin Closes #20925 from vanzin/SPARK-22941. --- .../apache/spark/deploy/DependencyUtils.scala | 30 +- .../org/apache/spark/deploy/SparkSubmit.scala | 318 +++++++++--------- .../spark/deploy/SparkSubmitArguments.scala | 90 +++-- .../spark/deploy/worker/DriverWrapper.scala | 4 +- .../apache/spark/util/CommandLineUtils.scala | 18 +- .../spark/launcher/SparkLauncherSuite.java | 37 +- .../spark/deploy/SparkSubmitSuite.scala | 69 ++-- .../rest/StandaloneRestSubmitSuite.scala | 2 +- .../spark/launcher/AbstractLauncher.java | 6 +- .../spark/launcher/InProcessLauncher.java | 14 +- .../launcher/SparkSubmitCommandBuilder.java | 82 +++-- project/MimaExcludes.scala | 7 +- .../deploy/mesos/MesosClusterDispatcher.scala | 10 +- .../MesosClusterDispatcherArguments.scala | 6 +- 14 files changed, 401 insertions(+), 292 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala index fac834a70b893..178bdcfccb603 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala @@ -25,9 +25,10 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.{SecurityManager, SparkConf, SparkException} +import org.apache.spark.internal.Logging import org.apache.spark.util.{MutableURLClassLoader, Utils} -private[deploy] object DependencyUtils { +private[deploy] object DependencyUtils extends Logging { def resolveMavenDependencies( packagesExclusions: String, @@ -75,7 +76,7 @@ private[deploy] object DependencyUtils { def addJarsToClassPath(jars: String, loader: MutableURLClassLoader): Unit = { if (jars != null) { for (jar <- jars.split(",")) { - SparkSubmit.addJarToClasspath(jar, loader) + addJarToClasspath(jar, loader) } } } @@ -151,6 +152,31 @@ private[deploy] object DependencyUtils { }.mkString(",") } + def addJarToClasspath(localJar: String, loader: MutableURLClassLoader): Unit = { + val uri = Utils.resolveURI(localJar) + uri.getScheme match { + case "file" | "local" => + val file = new File(uri.getPath) + if (file.exists()) { + loader.addURL(file.toURI.toURL) + } else { + logWarning(s"Local jar $file does not exist, skipping.") + } + case _ => + logWarning(s"Skip remote jar $uri.") + } + } + + /** + * Merge a sequence of comma-separated file lists, some of which may be null to indicate + * no files, into a single comma-separated string. + */ + def mergeFileLists(lists: String*): String = { + val merged = lists.filterNot(StringUtils.isBlank) + .flatMap(Utils.stringToSeq) + if (merged.nonEmpty) merged.mkString(",") else null + } + private def splitOnFragment(path: String): (URI, Option[String]) = { val uri = Utils.resolveURI(path) val withoutFragment = new URI(uri.getScheme, uri.getSchemeSpecificPart, null) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index eddbedeb1024d..427c797755b84 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -58,7 +58,7 @@ import org.apache.spark.util._ */ private[deploy] object SparkSubmitAction extends Enumeration { type SparkSubmitAction = Value - val SUBMIT, KILL, REQUEST_STATUS = Value + val SUBMIT, KILL, REQUEST_STATUS, PRINT_VERSION = Value } /** @@ -67,78 +67,32 @@ private[deploy] object SparkSubmitAction extends Enumeration { * This program handles setting up the classpath with relevant Spark dependencies and provides * a layer over the different cluster managers and deploy modes that Spark supports. */ -object SparkSubmit extends CommandLineUtils with Logging { +private[spark] class SparkSubmit extends Logging { import DependencyUtils._ + import SparkSubmit._ - // Cluster managers - private val YARN = 1 - private val STANDALONE = 2 - private val MESOS = 4 - private val LOCAL = 8 - private val KUBERNETES = 16 - private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL | KUBERNETES - - // Deploy modes - private val CLIENT = 1 - private val CLUSTER = 2 - private val ALL_DEPLOY_MODES = CLIENT | CLUSTER - - // Special primary resource names that represent shells rather than application jars. - private val SPARK_SHELL = "spark-shell" - private val PYSPARK_SHELL = "pyspark-shell" - private val SPARKR_SHELL = "sparkr-shell" - private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip" - private val R_PACKAGE_ARCHIVE = "rpkg.zip" - - private val CLASS_NOT_FOUND_EXIT_STATUS = 101 - - // Following constants are visible for testing. - private[deploy] val YARN_CLUSTER_SUBMIT_CLASS = - "org.apache.spark.deploy.yarn.YarnClusterApplication" - private[deploy] val REST_CLUSTER_SUBMIT_CLASS = classOf[RestSubmissionClientApp].getName() - private[deploy] val STANDALONE_CLUSTER_SUBMIT_CLASS = classOf[ClientApp].getName() - private[deploy] val KUBERNETES_CLUSTER_SUBMIT_CLASS = - "org.apache.spark.deploy.k8s.submit.KubernetesClientApplication" - - // scalastyle:off println - private[spark] def printVersionAndExit(): Unit = { - printStream.println("""Welcome to - ____ __ - / __/__ ___ _____/ /__ - _\ \/ _ \/ _ `/ __/ '_/ - /___/ .__/\_,_/_/ /_/\_\ version %s - /_/ - """.format(SPARK_VERSION)) - printStream.println("Using Scala %s, %s, %s".format( - Properties.versionString, Properties.javaVmName, Properties.javaVersion)) - printStream.println("Branch %s".format(SPARK_BRANCH)) - printStream.println("Compiled by user %s on %s".format(SPARK_BUILD_USER, SPARK_BUILD_DATE)) - printStream.println("Revision %s".format(SPARK_REVISION)) - printStream.println("Url %s".format(SPARK_REPO_URL)) - printStream.println("Type --help for more information.") - exitFn(0) - } - // scalastyle:on println - - override def main(args: Array[String]): Unit = { + def doSubmit(args: Array[String]): Unit = { // Initialize logging if it hasn't been done yet. Keep track of whether logging needs to // be reset before the application starts. val uninitLog = initializeLogIfNecessary(true, silent = true) - val appArgs = new SparkSubmitArguments(args) + val appArgs = parseArguments(args) if (appArgs.verbose) { - // scalastyle:off println - printStream.println(appArgs) - // scalastyle:on println + logInfo(appArgs.toString) } appArgs.action match { case SparkSubmitAction.SUBMIT => submit(appArgs, uninitLog) case SparkSubmitAction.KILL => kill(appArgs) case SparkSubmitAction.REQUEST_STATUS => requestStatus(appArgs) + case SparkSubmitAction.PRINT_VERSION => printVersion() } } + protected def parseArguments(args: Array[String]): SparkSubmitArguments = { + new SparkSubmitArguments(args) + } + /** * Kill an existing submission using the REST protocol. Standalone and Mesos cluster mode only. */ @@ -156,6 +110,24 @@ object SparkSubmit extends CommandLineUtils with Logging { .requestSubmissionStatus(args.submissionToRequestStatusFor) } + /** Print version information to the log. */ + private def printVersion(): Unit = { + logInfo("""Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ version %s + /_/ + """.format(SPARK_VERSION)) + logInfo("Using Scala %s, %s, %s".format( + Properties.versionString, Properties.javaVmName, Properties.javaVersion)) + logInfo(s"Branch $SPARK_BRANCH") + logInfo(s"Compiled by user $SPARK_BUILD_USER on $SPARK_BUILD_DATE") + logInfo(s"Revision $SPARK_REVISION") + logInfo(s"Url $SPARK_REPO_URL") + logInfo("Type --help for more information.") + } + /** * Submit the application using the provided parameters. * @@ -185,10 +157,7 @@ object SparkSubmit extends CommandLineUtils with Logging { // makes the message printed to the output by the JVM not very helpful. Instead, // detect exceptions with empty stack traces here, and treat them differently. if (e.getStackTrace().length == 0) { - // scalastyle:off println - printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") - // scalastyle:on println - exitFn(1) + error(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") } else { throw e } @@ -210,14 +179,11 @@ object SparkSubmit extends CommandLineUtils with Logging { // to use the legacy gateway if the master endpoint turns out to be not a REST server. if (args.isStandaloneCluster && args.useRest) { try { - // scalastyle:off println - printStream.println("Running Spark using the REST application submission protocol.") - // scalastyle:on println - doRunMain() + logInfo("Running Spark using the REST application submission protocol.") } catch { // Fail over to use the legacy submission gateway case e: SubmitRestConnectionException => - printWarning(s"Master endpoint ${args.master} was not a REST server. " + + logWarning(s"Master endpoint ${args.master} was not a REST server. " + "Falling back to legacy submission gateway instead.") args.useRest = false submit(args, false) @@ -245,19 +211,6 @@ object SparkSubmit extends CommandLineUtils with Logging { args: SparkSubmitArguments, conf: Option[HadoopConfiguration] = None) : (Seq[String], Seq[String], SparkConf, String) = { - try { - doPrepareSubmitEnvironment(args, conf) - } catch { - case e: SparkException => - printErrorAndExit(e.getMessage) - throw e - } - } - - private def doPrepareSubmitEnvironment( - args: SparkSubmitArguments, - conf: Option[HadoopConfiguration] = None) - : (Seq[String], Seq[String], SparkConf, String) = { // Return values val childArgs = new ArrayBuffer[String]() val childClasspath = new ArrayBuffer[String]() @@ -268,7 +221,7 @@ object SparkSubmit extends CommandLineUtils with Logging { val clusterManager: Int = args.master match { case "yarn" => YARN case "yarn-client" | "yarn-cluster" => - printWarning(s"Master ${args.master} is deprecated since 2.0." + + logWarning(s"Master ${args.master} is deprecated since 2.0." + " Please use master \"yarn\" with specified deploy mode instead.") YARN case m if m.startsWith("spark") => STANDALONE @@ -276,7 +229,7 @@ object SparkSubmit extends CommandLineUtils with Logging { case m if m.startsWith("k8s") => KUBERNETES case m if m.startsWith("local") => LOCAL case _ => - printErrorAndExit("Master must either be yarn or start with spark, mesos, k8s, or local") + error("Master must either be yarn or start with spark, mesos, k8s, or local") -1 } @@ -284,7 +237,9 @@ object SparkSubmit extends CommandLineUtils with Logging { var deployMode: Int = args.deployMode match { case "client" | null => CLIENT case "cluster" => CLUSTER - case _ => printErrorAndExit("Deploy mode must be either client or cluster"); -1 + case _ => + error("Deploy mode must be either client or cluster") + -1 } // Because the deprecated way of specifying "yarn-cluster" and "yarn-client" encapsulate both @@ -296,16 +251,16 @@ object SparkSubmit extends CommandLineUtils with Logging { deployMode = CLUSTER args.master = "yarn" case ("yarn-cluster", "client") => - printErrorAndExit("Client deploy mode is not compatible with master \"yarn-cluster\"") + error("Client deploy mode is not compatible with master \"yarn-cluster\"") case ("yarn-client", "cluster") => - printErrorAndExit("Cluster deploy mode is not compatible with master \"yarn-client\"") + error("Cluster deploy mode is not compatible with master \"yarn-client\"") case (_, mode) => args.master = "yarn" } // Make sure YARN is included in our build if we're trying to use it if (!Utils.classIsLoadable(YARN_CLUSTER_SUBMIT_CLASS) && !Utils.isTesting) { - printErrorAndExit( + error( "Could not load YARN classes. " + "This copy of Spark may not have been compiled with YARN support.") } @@ -315,7 +270,7 @@ object SparkSubmit extends CommandLineUtils with Logging { args.master = Utils.checkAndGetK8sMasterUrl(args.master) // Make sure KUBERNETES is included in our build if we're trying to use it if (!Utils.classIsLoadable(KUBERNETES_CLUSTER_SUBMIT_CLASS) && !Utils.isTesting) { - printErrorAndExit( + error( "Could not load KUBERNETES classes. " + "This copy of Spark may not have been compiled with KUBERNETES support.") } @@ -324,23 +279,23 @@ object SparkSubmit extends CommandLineUtils with Logging { // Fail fast, the following modes are not supported or applicable (clusterManager, deployMode) match { case (STANDALONE, CLUSTER) if args.isPython => - printErrorAndExit("Cluster deploy mode is currently not supported for python " + + error("Cluster deploy mode is currently not supported for python " + "applications on standalone clusters.") case (STANDALONE, CLUSTER) if args.isR => - printErrorAndExit("Cluster deploy mode is currently not supported for R " + + error("Cluster deploy mode is currently not supported for R " + "applications on standalone clusters.") case (KUBERNETES, _) if args.isPython => - printErrorAndExit("Python applications are currently not supported for Kubernetes.") + error("Python applications are currently not supported for Kubernetes.") case (KUBERNETES, _) if args.isR => - printErrorAndExit("R applications are currently not supported for Kubernetes.") + error("R applications are currently not supported for Kubernetes.") case (LOCAL, CLUSTER) => - printErrorAndExit("Cluster deploy mode is not compatible with master \"local\"") + error("Cluster deploy mode is not compatible with master \"local\"") case (_, CLUSTER) if isShell(args.primaryResource) => - printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") + error("Cluster deploy mode is not applicable to Spark shells.") case (_, CLUSTER) if isSqlShell(args.mainClass) => - printErrorAndExit("Cluster deploy mode is not applicable to Spark SQL shell.") + error("Cluster deploy mode is not applicable to Spark SQL shell.") case (_, CLUSTER) if isThriftServer(args.mainClass) => - printErrorAndExit("Cluster deploy mode is not applicable to Spark Thrift server.") + error("Cluster deploy mode is not applicable to Spark Thrift server.") case _ => } @@ -493,11 +448,11 @@ object SparkSubmit extends CommandLineUtils with Logging { if (args.isR && clusterManager == YARN) { val sparkRPackagePath = RUtils.localSparkRPackagePath if (sparkRPackagePath.isEmpty) { - printErrorAndExit("SPARK_HOME does not exist for R application in YARN mode.") + error("SPARK_HOME does not exist for R application in YARN mode.") } val sparkRPackageFile = new File(sparkRPackagePath.get, SPARKR_PACKAGE_ARCHIVE) if (!sparkRPackageFile.exists()) { - printErrorAndExit(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.") + error(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.") } val sparkRPackageURI = Utils.resolveURI(sparkRPackageFile.getAbsolutePath).toString @@ -510,7 +465,7 @@ object SparkSubmit extends CommandLineUtils with Logging { val rPackageFile = RPackageUtils.zipRLibraries(new File(RUtils.rPackages.get), R_PACKAGE_ARCHIVE) if (!rPackageFile.exists()) { - printErrorAndExit("Failed to zip all the built R packages.") + error("Failed to zip all the built R packages.") } val rPackageURI = Utils.resolveURI(rPackageFile.getAbsolutePath).toString @@ -521,12 +476,12 @@ object SparkSubmit extends CommandLineUtils with Logging { // TODO: Support distributing R packages with standalone cluster if (args.isR && clusterManager == STANDALONE && !RUtils.rPackages.isEmpty) { - printErrorAndExit("Distributing R packages with standalone cluster is not supported.") + error("Distributing R packages with standalone cluster is not supported.") } // TODO: Support distributing R packages with mesos cluster if (args.isR && clusterManager == MESOS && !RUtils.rPackages.isEmpty) { - printErrorAndExit("Distributing R packages with mesos cluster is not supported.") + error("Distributing R packages with mesos cluster is not supported.") } // If we're running an R app, set the main class to our specific R runner @@ -799,9 +754,7 @@ object SparkSubmit extends CommandLineUtils with Logging { private def setRMPrincipal(sparkConf: SparkConf): Unit = { val shortUserName = UserGroupInformation.getCurrentUser.getShortUserName val key = s"spark.hadoop.${YarnConfiguration.RM_PRINCIPAL}" - // scalastyle:off println - printStream.println(s"Setting ${key} to ${shortUserName}") - // scalastyle:off println + logInfo(s"Setting ${key} to ${shortUserName}") sparkConf.set(key, shortUserName) } @@ -817,16 +770,14 @@ object SparkSubmit extends CommandLineUtils with Logging { sparkConf: SparkConf, childMainClass: String, verbose: Boolean): Unit = { - // scalastyle:off println if (verbose) { - printStream.println(s"Main class:\n$childMainClass") - printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") + logInfo(s"Main class:\n$childMainClass") + logInfo(s"Arguments:\n${childArgs.mkString("\n")}") // sysProps may contain sensitive information, so redact before printing - printStream.println(s"Spark config:\n${Utils.redact(sparkConf.getAll.toMap).mkString("\n")}") - printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}") - printStream.println("\n") + logInfo(s"Spark config:\n${Utils.redact(sparkConf.getAll.toMap).mkString("\n")}") + logInfo(s"Classpath elements:\n${childClasspath.mkString("\n")}") + logInfo("\n") } - // scalastyle:on println val loader = if (sparkConf.get(DRIVER_USER_CLASS_PATH_FIRST)) { @@ -848,23 +799,19 @@ object SparkSubmit extends CommandLineUtils with Logging { mainClass = Utils.classForName(childMainClass) } catch { case e: ClassNotFoundException => - e.printStackTrace(printStream) + logWarning(s"Failed to load $childMainClass.", e) if (childMainClass.contains("thriftserver")) { - // scalastyle:off println - printStream.println(s"Failed to load main class $childMainClass.") - printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") - // scalastyle:on println + logInfo(s"Failed to load main class $childMainClass.") + logInfo("You need to build Spark with -Phive and -Phive-thriftserver.") } - System.exit(CLASS_NOT_FOUND_EXIT_STATUS) + throw new SparkUserAppException(CLASS_NOT_FOUND_EXIT_STATUS) case e: NoClassDefFoundError => - e.printStackTrace(printStream) + logWarning(s"Failed to load $childMainClass: ${e.getMessage()}") if (e.getMessage.contains("org/apache/hadoop/hive")) { - // scalastyle:off println - printStream.println(s"Failed to load hive class.") - printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") - // scalastyle:on println + logInfo(s"Failed to load hive class.") + logInfo("You need to build Spark with -Phive and -Phive-thriftserver.") } - System.exit(CLASS_NOT_FOUND_EXIT_STATUS) + throw new SparkUserAppException(CLASS_NOT_FOUND_EXIT_STATUS) } val app: SparkApplication = if (classOf[SparkApplication].isAssignableFrom(mainClass)) { @@ -872,7 +819,7 @@ object SparkSubmit extends CommandLineUtils with Logging { } else { // SPARK-4170 if (classOf[scala.App].isAssignableFrom(mainClass)) { - printWarning("Subclasses of scala.App may not work correctly. Use a main() method instead.") + logWarning("Subclasses of scala.App may not work correctly. Use a main() method instead.") } new JavaMainApplication(mainClass) } @@ -891,29 +838,90 @@ object SparkSubmit extends CommandLineUtils with Logging { app.start(childArgs.toArray, sparkConf) } catch { case t: Throwable => - findCause(t) match { - case SparkUserAppException(exitCode) => - System.exit(exitCode) - - case t: Throwable => - throw t - } + throw findCause(t) } } - private[deploy] def addJarToClasspath(localJar: String, loader: MutableURLClassLoader) { - val uri = Utils.resolveURI(localJar) - uri.getScheme match { - case "file" | "local" => - val file = new File(uri.getPath) - if (file.exists()) { - loader.addURL(file.toURI.toURL) - } else { - printWarning(s"Local jar $file does not exist, skipping.") + /** Throw a SparkException with the given error message. */ + private def error(msg: String): Unit = throw new SparkException(msg) + +} + + +/** + * This entry point is used by the launcher library to start in-process Spark applications. + */ +private[spark] object InProcessSparkSubmit { + + def main(args: Array[String]): Unit = { + val submit = new SparkSubmit() + submit.doSubmit(args) + } + +} + +object SparkSubmit extends CommandLineUtils with Logging { + + // Cluster managers + private val YARN = 1 + private val STANDALONE = 2 + private val MESOS = 4 + private val LOCAL = 8 + private val KUBERNETES = 16 + private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL | KUBERNETES + + // Deploy modes + private val CLIENT = 1 + private val CLUSTER = 2 + private val ALL_DEPLOY_MODES = CLIENT | CLUSTER + + // Special primary resource names that represent shells rather than application jars. + private val SPARK_SHELL = "spark-shell" + private val PYSPARK_SHELL = "pyspark-shell" + private val SPARKR_SHELL = "sparkr-shell" + private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip" + private val R_PACKAGE_ARCHIVE = "rpkg.zip" + + private val CLASS_NOT_FOUND_EXIT_STATUS = 101 + + // Following constants are visible for testing. + private[deploy] val YARN_CLUSTER_SUBMIT_CLASS = + "org.apache.spark.deploy.yarn.YarnClusterApplication" + private[deploy] val REST_CLUSTER_SUBMIT_CLASS = classOf[RestSubmissionClientApp].getName() + private[deploy] val STANDALONE_CLUSTER_SUBMIT_CLASS = classOf[ClientApp].getName() + private[deploy] val KUBERNETES_CLUSTER_SUBMIT_CLASS = + "org.apache.spark.deploy.k8s.submit.KubernetesClientApplication" + + override def main(args: Array[String]): Unit = { + val submit = new SparkSubmit() { + self => + + override protected def parseArguments(args: Array[String]): SparkSubmitArguments = { + new SparkSubmitArguments(args) { + override protected def logInfo(msg: => String): Unit = self.logInfo(msg) + + override protected def logWarning(msg: => String): Unit = self.logWarning(msg) } - case _ => - printWarning(s"Skip remote jar $uri.") + } + + override protected def logInfo(msg: => String): Unit = printMessage(msg) + + override protected def logWarning(msg: => String): Unit = printMessage(s"Warning: $msg") + + override def doSubmit(args: Array[String]): Unit = { + try { + super.doSubmit(args) + } catch { + case e: SparkUserAppException => + exitFn(e.exitCode) + case e: SparkException => + printErrorAndExit(e.getMessage()) + } + } + } + + submit.doSubmit(args) } /** @@ -962,17 +970,6 @@ object SparkSubmit extends CommandLineUtils with Logging { res == SparkLauncher.NO_RESOURCE } - /** - * Merge a sequence of comma-separated file lists, some of which may be null to indicate - * no files, into a single comma-separated string. - */ - private[deploy] def mergeFileLists(lists: String*): String = { - val merged = lists.filterNot(StringUtils.isBlank) - .flatMap(_.split(",")) - .mkString(",") - if (merged == "") null else merged - } - } /** Provides utility functions to be used inside SparkSubmit. */ @@ -1000,12 +997,12 @@ private[spark] object SparkSubmitUtils { override def toString: String = s"$groupId:$artifactId:$version" } -/** - * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided - * in the format `groupId:artifactId:version` or `groupId/artifactId:version`. - * @param coordinates Comma-delimited string of maven coordinates - * @return Sequence of Maven coordinates - */ + /** + * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided + * in the format `groupId:artifactId:version` or `groupId/artifactId:version`. + * @param coordinates Comma-delimited string of maven coordinates + * @return Sequence of Maven coordinates + */ def extractMavenCoordinates(coordinates: String): Seq[MavenCoordinate] = { coordinates.split(",").map { p => val splits = p.replace("/", ":").split(":") @@ -1304,6 +1301,13 @@ private[spark] object SparkSubmitUtils { rule } + def parseSparkConfProperty(pair: String): (String, String) = { + pair.split("=", 2).toSeq match { + case Seq(k, v) => (k, v) + case _ => throw new SparkException(s"Spark config without '=': $pair") + } + } + } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 8e7070593687b..0733fdb72cafb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -29,7 +29,9 @@ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.io.Source import scala.util.Try +import org.apache.spark.{SparkException, SparkUserAppException} import org.apache.spark.deploy.SparkSubmitAction._ +import org.apache.spark.internal.Logging import org.apache.spark.launcher.SparkSubmitArgumentsParser import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.Utils @@ -40,7 +42,7 @@ import org.apache.spark.util.Utils * The env argument is used for testing. */ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, String] = sys.env) - extends SparkSubmitArgumentsParser { + extends SparkSubmitArgumentsParser with Logging { var master: String = null var deployMode: String = null var executorMemory: String = null @@ -85,8 +87,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S /** Default properties present in the currently defined defaults file. */ lazy val defaultSparkProperties: HashMap[String, String] = { val defaultProperties = new HashMap[String, String]() - // scalastyle:off println - if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") + if (verbose) { + logInfo(s"Using properties file: $propertiesFile") + } Option(propertiesFile).foreach { filename => val properties = Utils.getPropertiesFromFile(filename) properties.foreach { case (k, v) => @@ -95,21 +98,16 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S // Property files may contain sensitive information, so redact before printing if (verbose) { Utils.redact(properties).foreach { case (k, v) => - SparkSubmit.printStream.println(s"Adding default property: $k=$v") + logInfo(s"Adding default property: $k=$v") } } } - // scalastyle:on println defaultProperties } // Set parameters from command line arguments - try { - parse(args.asJava) - } catch { - case e: IllegalArgumentException => - SparkSubmit.printErrorAndExit(e.getMessage()) - } + parse(args.asJava) + // Populate `sparkProperties` map from properties file mergeDefaultSparkProperties() // Remove keys that don't start with "spark." from `sparkProperties`. @@ -141,7 +139,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S sparkProperties.foreach { case (k, v) => if (!k.startsWith("spark.")) { sparkProperties -= k - SparkSubmit.printWarning(s"Ignoring non-spark config property: $k=$v") + logWarning(s"Ignoring non-spark config property: $k=$v") } } } @@ -215,10 +213,10 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } } catch { case _: Exception => - SparkSubmit.printErrorAndExit(s"Cannot load main class from JAR $primaryResource") + error(s"Cannot load main class from JAR $primaryResource") } case _ => - SparkSubmit.printErrorAndExit( + error( s"Cannot load main class from JAR $primaryResource with URI $uriScheme. " + "Please specify a class through --class.") } @@ -248,6 +246,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case SUBMIT => validateSubmitArguments() case KILL => validateKillArguments() case REQUEST_STATUS => validateStatusRequestArguments() + case PRINT_VERSION => } } @@ -256,62 +255,61 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S printUsageAndExit(-1) } if (primaryResource == null) { - SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python or R file)") + error("Must specify a primary resource (JAR or Python or R file)") } if (mainClass == null && SparkSubmit.isUserJar(primaryResource)) { - SparkSubmit.printErrorAndExit("No main class set in JAR; please specify one with --class") + error("No main class set in JAR; please specify one with --class") } if (driverMemory != null && Try(JavaUtils.byteStringAsBytes(driverMemory)).getOrElse(-1L) <= 0) { - SparkSubmit.printErrorAndExit("Driver Memory must be a positive number") + error("Driver memory must be a positive number") } if (executorMemory != null && Try(JavaUtils.byteStringAsBytes(executorMemory)).getOrElse(-1L) <= 0) { - SparkSubmit.printErrorAndExit("Executor Memory cores must be a positive number") + error("Executor memory must be a positive number") } if (executorCores != null && Try(executorCores.toInt).getOrElse(-1) <= 0) { - SparkSubmit.printErrorAndExit("Executor cores must be a positive number") + error("Executor cores must be a positive number") } if (totalExecutorCores != null && Try(totalExecutorCores.toInt).getOrElse(-1) <= 0) { - SparkSubmit.printErrorAndExit("Total executor cores must be a positive number") + error("Total executor cores must be a positive number") } if (numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) { - SparkSubmit.printErrorAndExit("Number of executors must be a positive number") + error("Number of executors must be a positive number") } if (pyFiles != null && !isPython) { - SparkSubmit.printErrorAndExit("--py-files given but primary resource is not a Python script") + error("--py-files given but primary resource is not a Python script") } if (master.startsWith("yarn")) { val hasHadoopEnv = env.contains("HADOOP_CONF_DIR") || env.contains("YARN_CONF_DIR") if (!hasHadoopEnv && !Utils.isTesting) { - throw new Exception(s"When running with master '$master' " + + error(s"When running with master '$master' " + "either HADOOP_CONF_DIR or YARN_CONF_DIR must be set in the environment.") } } if (proxyUser != null && principal != null) { - SparkSubmit.printErrorAndExit("Only one of --proxy-user or --principal can be provided.") + error("Only one of --proxy-user or --principal can be provided.") } } private def validateKillArguments(): Unit = { if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { - SparkSubmit.printErrorAndExit( - "Killing submissions is only supported in standalone or Mesos mode!") + error("Killing submissions is only supported in standalone or Mesos mode!") } if (submissionToKill == null) { - SparkSubmit.printErrorAndExit("Please specify a submission to kill.") + error("Please specify a submission to kill.") } } private def validateStatusRequestArguments(): Unit = { if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { - SparkSubmit.printErrorAndExit( + error( "Requesting submission statuses is only supported in standalone or Mesos mode!") } if (submissionToRequestStatusFor == null) { - SparkSubmit.printErrorAndExit("Please specify a submission to request status for.") + error("Please specify a submission to request status for.") } } @@ -368,7 +366,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case DEPLOY_MODE => if (value != "client" && value != "cluster") { - SparkSubmit.printErrorAndExit("--deploy-mode must be either \"client\" or \"cluster\"") + error("--deploy-mode must be either \"client\" or \"cluster\"") } deployMode = value @@ -405,14 +403,14 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case KILL_SUBMISSION => submissionToKill = value if (action != null) { - SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $KILL.") + error(s"Action cannot be both $action and $KILL.") } action = KILL case STATUS => submissionToRequestStatusFor = value if (action != null) { - SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $REQUEST_STATUS.") + error(s"Action cannot be both $action and $REQUEST_STATUS.") } action = REQUEST_STATUS @@ -444,7 +442,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S repositories = value case CONF => - val (confName, confValue) = SparkSubmit.parseSparkConfProperty(value) + val (confName, confValue) = SparkSubmitUtils.parseSparkConfProperty(value) sparkProperties(confName) = confValue case PROXY_USER => @@ -463,15 +461,15 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S verbose = true case VERSION => - SparkSubmit.printVersionAndExit() + action = SparkSubmitAction.PRINT_VERSION case USAGE_ERROR => printUsageAndExit(1) case _ => - throw new IllegalArgumentException(s"Unexpected argument '$opt'.") + error(s"Unexpected argument '$opt'.") } - true + action != SparkSubmitAction.PRINT_VERSION } /** @@ -482,7 +480,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S */ override protected def handleUnknown(opt: String): Boolean = { if (opt.startsWith("-")) { - SparkSubmit.printErrorAndExit(s"Unrecognized option '$opt'.") + error(s"Unrecognized option '$opt'.") } primaryResource = @@ -501,20 +499,18 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = { - // scalastyle:off println - val outStream = SparkSubmit.printStream if (unknownParam != null) { - outStream.println("Unknown/unsupported param " + unknownParam) + logInfo("Unknown/unsupported param " + unknownParam) } val command = sys.env.get("_SPARK_CMD_USAGE").getOrElse( """Usage: spark-submit [options] [app arguments] |Usage: spark-submit --kill [submission ID] --master [spark://...] |Usage: spark-submit --status [submission ID] --master [spark://...] |Usage: spark-submit run-example [options] example-class [example args]""".stripMargin) - outStream.println(command) + logInfo(command) val mem_mb = Utils.DEFAULT_DRIVER_MEM_MB - outStream.println( + logInfo( s""" |Options: | --master MASTER_URL spark://host:port, mesos://host:port, yarn, @@ -596,12 +592,11 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S ) if (SparkSubmit.isSqlShell(mainClass)) { - outStream.println("CLI options:") - outStream.println(getSqlShellOptions()) + logInfo("CLI options:") + logInfo(getSqlShellOptions()) } - // scalastyle:on println - SparkSubmit.exitFn(exitCode) + throw new SparkUserAppException(exitCode) } /** @@ -655,4 +650,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S System.setErr(currentErr) } } + + private def error(msg: String): Unit = throw new SparkException(msg) + } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index 3f71237164a15..8d6a2b80ef5f2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -25,7 +25,7 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.{DependencyUtils, SparkHadoopUtil, SparkSubmit} import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEnv -import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} +import org.apache.spark.util._ /** * Utility object for launching driver programs such that they share fate with the Worker process. @@ -93,7 +93,7 @@ object DriverWrapper extends Logging { val jars = { val jarsProp = sys.props.get("spark.jars").orNull if (!StringUtils.isBlank(resolvedMavenCoordinates)) { - SparkSubmit.mergeFileLists(jarsProp, resolvedMavenCoordinates) + DependencyUtils.mergeFileLists(jarsProp, resolvedMavenCoordinates) } else { jarsProp } diff --git a/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala b/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala index d73901686b705..4b6602b50aa1c 100644 --- a/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala @@ -33,24 +33,14 @@ private[spark] trait CommandLineUtils { private[spark] var printStream: PrintStream = System.err // scalastyle:off println - - private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) + private[spark] def printMessage(str: String): Unit = printStream.println(str) + // scalastyle:on println private[spark] def printErrorAndExit(str: String): Unit = { - printStream.println("Error: " + str) - printStream.println("Run with --help for usage help or --verbose for debug output") + printMessage("Error: " + str) + printMessage("Run with --help for usage help or --verbose for debug output") exitFn(1) } - // scalastyle:on println - - private[spark] def parseSparkConfProperty(pair: String): (String, String) = { - pair.split("=", 2).toSeq match { - case Seq(k, v) => (k, v) - case _ => printErrorAndExit(s"Spark config without '=': $pair") - throw new SparkException(s"Spark config without '=': $pair") - } - } - def main(args: Array[String]): Unit } diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 2225591a4ff75..6a1a38c1a54f4 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -109,7 +109,7 @@ public void testChildProcLauncher() throws Exception { .addSparkArg(opts.CONF, String.format("%s=-Dfoo=ShouldBeOverriddenBelow", SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS)) .setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, - "-Dfoo=bar -Dtest.appender=childproc") + "-Dfoo=bar -Dtest.appender=console") .setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path")) .addSparkArg(opts.CLASS, "ShouldBeOverriddenBelow") .setMainClass(SparkLauncherTestApp.class.getName()) @@ -192,6 +192,41 @@ private void inProcessLauncherTestImpl() throws Exception { } } + @Test + public void testInProcessLauncherDoesNotKillJvm() throws Exception { + SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); + List wrongArgs = Arrays.asList( + new String[] { "--unknown" }, + new String[] { opts.DEPLOY_MODE, "invalid" }); + + for (String[] args : wrongArgs) { + InProcessLauncher launcher = new InProcessLauncher() + .setAppResource(SparkLauncher.NO_RESOURCE); + switch (args.length) { + case 2: + launcher.addSparkArg(args[0], args[1]); + break; + + case 1: + launcher.addSparkArg(args[0]); + break; + + default: + fail("FIXME: invalid test."); + } + + SparkAppHandle handle = launcher.startApplication(); + waitFor(handle); + assertEquals(SparkAppHandle.State.FAILED, handle.getState()); + } + + // Run --version, which is useless as a use case, but should succeed and not exit the JVM. + // The expected state is "LOST" since "--version" doesn't report state back to the handle. + SparkAppHandle handle = new InProcessLauncher().addSparkArg(opts.VERSION).startApplication(); + waitFor(handle); + assertEquals(SparkAppHandle.State.LOST, handle.getState()); + } + public static class SparkLauncherTestApp { public static void main(String[] args) throws Exception { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 0d7c342a5eacd..7451e07b25a1f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -42,6 +42,7 @@ import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.launcher.SparkLauncher import org.apache.spark.scheduler.EventLoggingListener import org.apache.spark.util.{CommandLineUtils, ResetSystemProperties, Utils} @@ -109,6 +110,8 @@ class SparkSubmitSuite private val emptyIvySettings = File.createTempFile("ivy", ".xml") FileUtils.write(emptyIvySettings, "", StandardCharsets.UTF_8) + private val submit = new SparkSubmit() + override def beforeEach() { super.beforeEach() } @@ -128,13 +131,16 @@ class SparkSubmitSuite } test("handle binary specified but not class") { - testPrematureExit(Array("foo.jar"), "No main class") + val jar = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) + testPrematureExit(Array(jar.toString()), "No main class") } test("handles arguments with --key=val") { val clArgs = Seq( "--jars=one.jar,two.jar,three.jar", - "--name=myApp") + "--name=myApp", + "--class=org.FooBar", + SparkLauncher.NO_RESOURCE) val appArgs = new SparkSubmitArguments(clArgs) appArgs.jars should include regex (".*one.jar,.*two.jar,.*three.jar") appArgs.name should be ("myApp") @@ -182,7 +188,7 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, conf, _) = prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) appArgs.deployMode should be ("client") conf.get("spark.submit.deployMode") should be ("client") @@ -192,11 +198,11 @@ class SparkSubmitSuite "--master", "yarn", "--deploy-mode", "cluster", "--conf", "spark.submit.deployMode=client", - "-class", "org.SomeClass", + "--class", "org.SomeClass", "thejar.jar" ) val appArgs1 = new SparkSubmitArguments(clArgs1) - val (_, _, conf1, _) = prepareSubmitEnvironment(appArgs1) + val (_, _, conf1, _) = submit.prepareSubmitEnvironment(appArgs1) appArgs1.deployMode should be ("cluster") conf1.get("spark.submit.deployMode") should be ("cluster") @@ -210,7 +216,7 @@ class SparkSubmitSuite val appArgs2 = new SparkSubmitArguments(clArgs2) appArgs2.deployMode should be (null) - val (_, _, conf2, _) = prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) appArgs2.deployMode should be ("client") conf2.get("spark.submit.deployMode") should be ("client") } @@ -233,7 +239,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") childArgsStr should include ("--class org.SomeClass") childArgsStr should include ("--arg arg1 --arg arg2") @@ -276,7 +282,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (4) @@ -322,7 +328,7 @@ class SparkSubmitSuite "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) appArgs.useRest = useRest - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") if (useRest) { childArgsStr should endWith ("thejar.jar org.SomeClass arg1 arg2") @@ -359,7 +365,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) @@ -381,7 +387,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) @@ -403,7 +409,7 @@ class SparkSubmitSuite "/home/thejar.jar", "arg1") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) val childArgsMap = childArgs.grouped(2).map(a => a(0) -> a(1)).toMap childArgsMap.get("--primary-java-resource") should be (Some("file:/home/thejar.jar")) @@ -428,7 +434,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (_, _, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) conf.get("spark.executor.memory") should be ("5g") conf.get("spark.master") should be ("yarn") conf.get("spark.submit.deployMode") should be ("cluster") @@ -441,12 +447,12 @@ class SparkSubmitSuite val clArgs1 = Seq("--class", "org.apache.spark.repl.Main", "spark-shell") val appArgs1 = new SparkSubmitArguments(clArgs1) - val (_, _, conf1, _) = prepareSubmitEnvironment(appArgs1) + val (_, _, conf1, _) = submit.prepareSubmitEnvironment(appArgs1) conf1.get(UI_SHOW_CONSOLE_PROGRESS) should be (true) val clArgs2 = Seq("--class", "org.SomeClass", "thejar.jar") val appArgs2 = new SparkSubmitArguments(clArgs2) - val (_, _, conf2, _) = prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) assert(!conf2.contains(UI_SHOW_CONSOLE_PROGRESS)) } @@ -625,7 +631,7 @@ class SparkSubmitSuite "--files", files, "thejar.jar") val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) appArgs.jars should be (Utils.resolveURIs(jars)) appArgs.files should be (Utils.resolveURIs(files)) conf.get("spark.jars") should be (Utils.resolveURIs(jars + ",thejar.jar")) @@ -640,7 +646,7 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) appArgs2.files should be (Utils.resolveURIs(files)) appArgs2.archives should fullyMatch regex ("file:/archive1,file:.*#archive3") conf2.get("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) @@ -656,7 +662,7 @@ class SparkSubmitSuite "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val (_, _, conf3, _) = SparkSubmit.prepareSubmitEnvironment(appArgs3) + val (_, _, conf3, _) = submit.prepareSubmitEnvironment(appArgs3) appArgs3.pyFiles should be (Utils.resolveURIs(pyFiles)) conf3.get("spark.submit.pyFiles") should be ( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) @@ -708,7 +714,7 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) conf.get("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar")) conf.get("spark.files") should be(Utils.resolveURIs(files)) @@ -725,7 +731,7 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) conf2.get("spark.yarn.dist.files") should be(Utils.resolveURIs(files)) conf2.get("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives)) @@ -740,7 +746,7 @@ class SparkSubmitSuite "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val (_, _, conf3, _) = SparkSubmit.prepareSubmitEnvironment(appArgs3) + val (_, _, conf3, _) = submit.prepareSubmitEnvironment(appArgs3) conf3.get("spark.submit.pyFiles") should be( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) @@ -757,7 +763,7 @@ class SparkSubmitSuite "hdfs:///tmp/mister.py" ) val appArgs4 = new SparkSubmitArguments(clArgs4) - val (_, _, conf4, _) = SparkSubmit.prepareSubmitEnvironment(appArgs4) + val (_, _, conf4, _) = submit.prepareSubmitEnvironment(appArgs4) // Should not format python path for yarn cluster mode conf4.get("spark.submit.pyFiles") should be(Utils.resolveURIs(remotePyFiles)) } @@ -778,17 +784,17 @@ class SparkSubmitSuite } test("SPARK_CONF_DIR overrides spark-defaults.conf") { - forConfDir(Map("spark.executor.memory" -> "2.3g")) { path => + forConfDir(Map("spark.executor.memory" -> "3g")) { path => val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val args = Seq( "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", "--master", "local", unusedJar.toString) - val appArgs = new SparkSubmitArguments(args, Map("SPARK_CONF_DIR" -> path)) + val appArgs = new SparkSubmitArguments(args, env = Map("SPARK_CONF_DIR" -> path)) assert(appArgs.propertiesFile != null) assert(appArgs.propertiesFile.startsWith(path)) - appArgs.executorMemory should be ("2.3g") + appArgs.executorMemory should be ("3g") } } @@ -809,6 +815,9 @@ class SparkSubmitSuite val archive1 = File.createTempFile("archive1", ".zip", tmpArchiveDir) val archive2 = File.createTempFile("archive2", ".zip", tmpArchiveDir) + val tempPyFile = File.createTempFile("tmpApp", ".py") + tempPyFile.deleteOnExit() + val args = Seq( "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), "--name", "testApp", @@ -818,10 +827,10 @@ class SparkSubmitSuite "--files", s"${tmpFileDir.getAbsolutePath}/tmpFile*", "--py-files", s"${tmpPyFileDir.getAbsolutePath}/tmpPy*", "--archives", s"${tmpArchiveDir.getAbsolutePath}/*.zip", - jar2.toString) + tempPyFile.toURI().toString()) val appArgs = new SparkSubmitArguments(args) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) conf.get("spark.yarn.dist.jars").split(",").toSet should be (Set(jar1.toURI.toString, jar2.toURI.toString)) conf.get("spark.yarn.dist.files").split(",").toSet should be @@ -947,7 +956,7 @@ class SparkSubmitSuite ) val appArgs = new SparkSubmitArguments(args) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf)) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs, conf = Some(hadoopConf)) // All the resources should still be remote paths, so that YARN client will not upload again. conf.get("spark.yarn.dist.jars") should be (tmpJarPath) @@ -1007,7 +1016,7 @@ class SparkSubmitSuite ) ++ forceDownloadArgs ++ Seq(s"s3a://$mainResource") val appArgs = new SparkSubmitArguments(args) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf)) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs, conf = Some(hadoopConf)) val jars = conf.get("spark.yarn.dist.jars").split(",").toSet @@ -1058,7 +1067,7 @@ class SparkSubmitSuite "hello") val exception = intercept[SparkException] { - SparkSubmit.main(args) + submit.doSubmit(args) } assert(exception.getMessage() === "hello") diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index e505bc018857d..54c168a8218f3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -445,7 +445,7 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { "--class", mainClass, mainJar) ++ appArgs val args = new SparkSubmitArguments(commandLineArgs) - val (_, _, sparkConf, _) = SparkSubmit.prepareSubmitEnvironment(args) + val (_, _, sparkConf, _) = new SparkSubmit().prepareSubmitEnvironment(args) new RestSubmissionClient("spark://host:port").constructSubmitRequest( mainJar, mainClass, appArgs, sparkConf.getAll.toMap, Map.empty) } diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java index 44e69fc45dffa..4e02843480e8f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java @@ -139,7 +139,7 @@ public T setMainClass(String mainClass) { public T addSparkArg(String arg) { SparkSubmitOptionParser validator = new ArgumentValidator(false); validator.parse(Arrays.asList(arg)); - builder.sparkArgs.add(arg); + builder.userArgs.add(arg); return self(); } @@ -187,8 +187,8 @@ public T addSparkArg(String name, String value) { } } else { validator.parse(Arrays.asList(name, value)); - builder.sparkArgs.add(name); - builder.sparkArgs.add(value); + builder.userArgs.add(name); + builder.userArgs.add(value); } return self(); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java index 6d726b4a69a86..688e1f763c205 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java @@ -89,10 +89,18 @@ Method findSparkSubmit() throws IOException { } Class sparkSubmit; + // SPARK-22941: first try the new SparkSubmit interface that has better error handling, + // but fall back to the old interface in case someone is mixing & matching launcher and + // Spark versions. try { - sparkSubmit = cl.loadClass("org.apache.spark.deploy.SparkSubmit"); - } catch (Exception e) { - throw new IOException("Cannot find SparkSubmit; make sure necessary jars are available.", e); + sparkSubmit = cl.loadClass("org.apache.spark.deploy.InProcessSparkSubmit"); + } catch (Exception e1) { + try { + sparkSubmit = cl.loadClass("org.apache.spark.deploy.SparkSubmit"); + } catch (Exception e2) { + throw new IOException("Cannot find SparkSubmit; make sure necessary jars are available.", + e2); + } } Method main; diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index e0ef22d7d5058..5cb6457bf5c21 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -88,8 +88,9 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { SparkLauncher.NO_RESOURCE); } - final List sparkArgs; - private final boolean isAppResourceReq; + final List userArgs; + private final List parsedArgs; + private final boolean requiresAppResource; private final boolean isExample; /** @@ -99,17 +100,27 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { */ private boolean allowsMixedArguments; + /** + * This constructor is used when creating a user-configurable launcher. It allows the + * spark-submit argument list to be modified after creation. + */ SparkSubmitCommandBuilder() { - this.sparkArgs = new ArrayList<>(); - this.isAppResourceReq = true; + this.requiresAppResource = true; this.isExample = false; + this.parsedArgs = new ArrayList<>(); + this.userArgs = new ArrayList<>(); } + /** + * This constructor is used when invoking spark-submit; it parses and validates arguments + * provided by the user on the command line. + */ SparkSubmitCommandBuilder(List args) { this.allowsMixedArguments = false; - this.sparkArgs = new ArrayList<>(); + this.parsedArgs = new ArrayList<>(); boolean isExample = false; List submitArgs = args; + this.userArgs = Collections.emptyList(); if (args.size() > 0) { switch (args.get(0)) { @@ -131,21 +142,21 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { } this.isExample = isExample; - OptionParser parser = new OptionParser(); + OptionParser parser = new OptionParser(true); parser.parse(submitArgs); - this.isAppResourceReq = parser.isAppResourceReq; - } else { + this.requiresAppResource = parser.requiresAppResource; + } else { this.isExample = isExample; - this.isAppResourceReq = false; + this.requiresAppResource = false; } } @Override public List buildCommand(Map env) throws IOException, IllegalArgumentException { - if (PYSPARK_SHELL.equals(appResource) && isAppResourceReq) { + if (PYSPARK_SHELL.equals(appResource) && requiresAppResource) { return buildPySparkShellCommand(env); - } else if (SPARKR_SHELL.equals(appResource) && isAppResourceReq) { + } else if (SPARKR_SHELL.equals(appResource) && requiresAppResource) { return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); @@ -154,9 +165,19 @@ public List buildCommand(Map env) List buildSparkSubmitArgs() { List args = new ArrayList<>(); - SparkSubmitOptionParser parser = new SparkSubmitOptionParser(); + OptionParser parser = new OptionParser(false); + final boolean requiresAppResource; + + // If the user args array is not empty, we need to parse it to detect exactly what + // the user is trying to run, so that checks below are correct. + if (!userArgs.isEmpty()) { + parser.parse(userArgs); + requiresAppResource = parser.requiresAppResource; + } else { + requiresAppResource = this.requiresAppResource; + } - if (!allowsMixedArguments && isAppResourceReq) { + if (!allowsMixedArguments && requiresAppResource) { checkArgument(appResource != null, "Missing application resource."); } @@ -208,15 +229,16 @@ List buildSparkSubmitArgs() { args.add(join(",", pyFiles)); } - if (isAppResourceReq) { - checkArgument(!isExample || mainClass != null, "Missing example class name."); + if (isExample) { + checkArgument(mainClass != null, "Missing example class name."); } + if (mainClass != null) { args.add(parser.CLASS); args.add(mainClass); } - args.addAll(sparkArgs); + args.addAll(parsedArgs); if (appResource != null) { args.add(appResource); } @@ -399,7 +421,12 @@ private List findExamplesJars() { private class OptionParser extends SparkSubmitOptionParser { - boolean isAppResourceReq = true; + boolean requiresAppResource = true; + private final boolean errorOnUnknownArgs; + + OptionParser(boolean errorOnUnknownArgs) { + this.errorOnUnknownArgs = errorOnUnknownArgs; + } @Override protected boolean handle(String opt, String value) { @@ -443,23 +470,23 @@ protected boolean handle(String opt, String value) { break; case KILL_SUBMISSION: case STATUS: - isAppResourceReq = false; - sparkArgs.add(opt); - sparkArgs.add(value); + requiresAppResource = false; + parsedArgs.add(opt); + parsedArgs.add(value); break; case HELP: case USAGE_ERROR: - isAppResourceReq = false; - sparkArgs.add(opt); + requiresAppResource = false; + parsedArgs.add(opt); break; case VERSION: - isAppResourceReq = false; - sparkArgs.add(opt); + requiresAppResource = false; + parsedArgs.add(opt); break; default: - sparkArgs.add(opt); + parsedArgs.add(opt); if (value != null) { - sparkArgs.add(value); + parsedArgs.add(value); } break; } @@ -483,12 +510,13 @@ protected boolean handleUnknown(String opt) { mainClass = className; appResource = SparkLauncher.NO_RESOURCE; return false; - } else { + } else if (errorOnUnknownArgs) { checkArgument(!opt.startsWith("-"), "Unrecognized option: %s", opt); checkState(appResource == null, "Found unrecognized argument but resource is already set."); appResource = opt; return false; } + return true; } @Override diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b37b4d51775e8..a87fa68422c34 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,12 +36,17 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-22941][core] Do not exit JVM when submit fails with in-process launcher. + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printWarning"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.parseSparkConfProperty"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printVersionAndExit"), + // [SPARK-23412][ML] Add cosine distance measure to BisectingKmeans ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.org$apache$spark$ml$param$shared$HasDistanceMeasure$_setter_$distanceMeasure_="), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.getDistanceMeasure"), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.distanceMeasure"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.BisectingKMeansModel#SaveLoadV1_0.load"), - + // [SPARK-20659] Remove StorageStatus, or make it private ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOffHeapStorageMemory"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOffHeapStorageMemory"), diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index aa378c9d340f1..ccf33e8d4283c 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.mesos import java.util.concurrent.CountDownLatch -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.mesos.config._ import org.apache.spark.deploy.mesos.ui.MesosClusterUI import org.apache.spark.deploy.rest.mesos.MesosRestServer @@ -100,7 +100,13 @@ private[mesos] object MesosClusterDispatcher Thread.setDefaultUncaughtExceptionHandler(new SparkUncaughtExceptionHandler) Utils.initDaemon(log) val conf = new SparkConf - val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf) + val dispatcherArgs = try { + new MesosClusterDispatcherArguments(args, conf) + } catch { + case e: SparkException => + printErrorAndExit(e.getMessage()) + null + } conf.setMaster(dispatcherArgs.masterUrl) conf.setAppName(dispatcherArgs.name) dispatcherArgs.zookeeperUrl.foreach { z => diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala index 096bb4e1af688..267a4283db9e6 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -21,6 +21,7 @@ import scala.annotation.tailrec import scala.collection.mutable import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkSubmitUtils import org.apache.spark.util.{IntParam, Utils} private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: SparkConf) { @@ -95,9 +96,8 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: parse(tail) case ("--conf") :: value :: tail => - val pair = MesosClusterDispatcher. - parseSparkConfProperty(value) - confProperties(pair._1) = pair._2 + val (k, v) = SparkSubmitUtils.parseSparkConfProperty(value) + confProperties(k) = v parse(tail) case ("--help") :: tail => From 75a183071c4ed2e407c930edfdf721779662b3ee Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 11 Apr 2018 09:59:38 -0700 Subject: [PATCH 0604/2461] [SPARK-22883] ML test for StructuredStreaming: spark.ml.feature, I-M ## What changes were proposed in this pull request? Adds structured streaming tests using testTransformer for these suites: * IDF * Imputer * Interaction * MaxAbsScaler * MinHashLSH * MinMaxScaler * NGram ## How was this patch tested? It is a bunch of tests! Author: Joseph K. Bradley Closes #20964 from jkbradley/SPARK-22883-part2. --- .../apache/spark/ml/feature/IDFSuite.scala | 14 +++--- .../spark/ml/feature/ImputerSuite.scala | 31 ++++++++++--- .../spark/ml/feature/InteractionSuite.scala | 46 ++++++++++--------- .../spark/ml/feature/MaxAbsScalerSuite.scala | 14 +++--- .../spark/ml/feature/MinHashLSHSuite.scala | 25 ++++++++-- .../spark/ml/feature/MinMaxScalerSuite.scala | 14 +++--- .../apache/spark/ml/feature/NGramSuite.scala | 2 +- 7 files changed, 89 insertions(+), 57 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index 005edf73d29be..cdd62be43b54c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel} import org.apache.spark.mllib.linalg.VectorImplicits._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row -class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class IDFSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -57,7 +55,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead Vectors.dense(0.0, 1.0, 2.0, 3.0), Vectors.sparse(numOfFeatures, Array(1), Array(1.0)) ) - val numOfData = data.size + val numOfData = data.length val idf = Vectors.dense(Array(0, 3, 1, 2).map { x => math.log((numOfData + 1.0) / (x + 1.0)) }) @@ -72,7 +70,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead MLTestingUtils.checkCopyAndUids(idfEst, idfModel) - idfModel.transform(df).select("idfValue", "expected").collect().foreach { + testTransformer[(Vector, Vector)](df, idfModel, "idfValue", "expected") { case Row(x: Vector, y: Vector) => assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } @@ -85,7 +83,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead Vectors.dense(0.0, 1.0, 2.0, 3.0), Vectors.sparse(numOfFeatures, Array(1), Array(1.0)) ) - val numOfData = data.size + val numOfData = data.length val idf = Vectors.dense(Array(0, 3, 1, 2).map { x => if (x > 0) math.log((numOfData + 1.0) / (x + 1.0)) else 0 }) @@ -99,7 +97,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead .setMinDocFreq(1) .fit(df) - idfModel.transform(df).select("idfValue", "expected").collect().foreach { + testTransformer[(Vector, Vector)](df, idfModel, "idfValue", "expected") { case Row(x: Vector, y: Vector) => assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala index c08b35b419266..75f63a623e6d8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -16,13 +16,12 @@ */ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.SparkException +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class ImputerSuite extends MLTest with DefaultReadWriteTest { test("Imputer for Double with default missing Value NaN") { val df = spark.createDataFrame( Seq( @@ -76,6 +75,28 @@ class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with Default ImputerSuite.iterateStrategyTest(imputer, df) } + test("Imputer should work with Structured Streaming") { + val localSpark = spark + import localSpark.implicits._ + val df = Seq[(java.lang.Double, Double)]( + (4.0, 4.0), + (10.0, 10.0), + (10.0, 10.0), + (Double.NaN, 8.0), + (null, 8.0) + ).toDF("value", "expected_mean_value") + val imputer = new Imputer() + .setInputCols(Array("value")) + .setOutputCols(Array("out")) + .setStrategy("mean") + val model = imputer.fit(df) + testTransformer[(java.lang.Double, Double)](df, model, "expected_mean_value", "out") { + case Row(exp: java.lang.Double, out: Double) => + assert((exp.isNaN && out.isNaN) || (exp == out), + s"Imputed values differ. Expected: $exp, actual: $out") + } + } + test("Imputer throws exception when surrogate cannot be computed") { val df = spark.createDataFrame( Seq( (0, Double.NaN, 1.0, 1.0), @@ -164,8 +185,6 @@ object ImputerSuite { * @param df DataFrame with columns "id", "value", "expected_mean", "expected_median" */ def iterateStrategyTest(imputer: Imputer, df: DataFrame): Unit = { - val inputCols = imputer.getInputCols - Seq("mean", "median").foreach { strategy => imputer.setStrategy(strategy) val model = imputer.fit(df) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala index 54f059e5f143e..eea31fc7ae3f2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -19,15 +19,15 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col -class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class InteractionSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -63,9 +63,9 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def test("numeric interaction") { val data = Seq( - (2, Vectors.dense(3.0, 4.0)), - (1, Vectors.dense(1.0, 5.0)) - ).toDF("a", "b") + (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0)) + ).toDF("a", "b", "expected") val groupAttr = new AttributeGroup( "b", Array[Attribute]( @@ -73,14 +73,15 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def NumericAttribute.defaultAttr.withName("bar"))) val df = data.select( col("a").as("a", NumericAttribute.defaultAttr.toMetadata()), - col("b").as("b", groupAttr.toMetadata())) + col("b").as("b", groupAttr.toMetadata()), + col("expected")) val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + testTransformer[(Int, Vector, Vector)](df, trans, "features", "expected") { + case Row(features: Vector, expected: Vector) => + assert(features === expected) + } + val res = trans.transform(df) - val expected = Seq( - (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), - (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0)) - ).toDF("a", "b", "features") - assert(res.collect() === expected.collect()) val attrs = AttributeGroup.fromStructField(res.schema("features")) val expectedAttrs = new AttributeGroup( "features", @@ -92,9 +93,9 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def test("nominal interaction") { val data = Seq( - (2, Vectors.dense(3.0, 4.0)), - (1, Vectors.dense(1.0, 5.0)) - ).toDF("a", "b") + (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0)) + ).toDF("a", "b", "expected") val groupAttr = new AttributeGroup( "b", Array[Attribute]( @@ -103,14 +104,15 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def val df = data.select( col("a").as( "a", NominalAttribute.defaultAttr.withValues(Array("up", "down", "left")).toMetadata()), - col("b").as("b", groupAttr.toMetadata())) + col("b").as("b", groupAttr.toMetadata()), + col("expected")) val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + testTransformer[(Int, Vector, Vector)](df, trans, "features", "expected") { + case Row(features: Vector, expected: Vector) => + assert(features === expected) + } + val res = trans.transform(df) - val expected = Seq( - (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), - (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0)) - ).toDF("a", "b", "features") - assert(res.collect() === expected.collect()) val attrs = AttributeGroup.fromStructField(res.schema("features")) val expectedAttrs = new AttributeGroup( "features", diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala index 918da4f9388d4..8dd0f0cb91e37 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala @@ -14,15 +14,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.sql.Row -class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class MaxAbsScalerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -45,9 +44,10 @@ class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De .setOutputCol("scaled") val model = scaler.fit(df) - model.transform(df).select("expected", "scaled").collect() - .foreach { case Row(vector1: Vector, vector2: Vector) => - assert(vector1.equals(vector2), s"MaxAbsScaler ut error: $vector2 should be $vector1") + testTransformer[(Vector, Vector)](df, model, "expected", "scaled") { + case Row(expectedVec: Vector, actualVec: Vector) => + assert(expectedVec === actualVec, + s"MaxAbsScaler error: Expected $expectedVec but computed $actualVec") } MLTestingUtils.checkCopyAndUids(scaler, model) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala index 3da0fb7da01ae..1c2956cb82908 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Dataset +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.sql.{Dataset, Row} -class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + +class MinHashLSHSuite extends MLTest with DefaultReadWriteTest { @transient var dataset: Dataset[_] = _ @@ -175,4 +174,20 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa assert(precision == 1.0) assert(recall >= 0.7) } + + test("MinHashLSHModel.transform should work with Structured Streaming") { + val localSpark = spark + import localSpark.implicits._ + + val model = new MinHashLSHModel("mh", randCoefficients = Array((1, 0))) + model.set(model.inputCol, "keys") + testTransformer[Tuple1[Vector]](dataset.toDF(), model, "keys", model.getOutputCol) { + case Row(_: Vector, output: Seq[_]) => + assert(output.length === model.randCoefficients.length) + // no AND-amplification yet: SPARK-18450, so each hash output is of length 1 + output.foreach { + case hashOutput: Vector => assert(hashOutput.size === 1) + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index 51db74eb739ca..2d965f2ca2c54 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -17,13 +17,11 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.sql.Row -class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class MinMaxScalerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -48,9 +46,9 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De .setMax(5) val model = scaler.fit(df) - model.transform(df).select("expected", "scaled").collect() - .foreach { case Row(vector1: Vector, vector2: Vector) => - assert(vector1.equals(vector2), "Transformed vector is different with expected.") + testTransformer[(Vector, Vector)](df, model, "expected", "scaled") { + case Row(vector1: Vector, vector2: Vector) => + assert(vector1 === vector2, "Transformed vector is different with expected.") } MLTestingUtils.checkCopyAndUids(scaler, model) @@ -114,7 +112,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De val model = scaler.fit(df) model.transform(df).select("expected", "scaled").collect() .foreach { case Row(vector1: Vector, vector2: Vector) => - assert(vector1.equals(vector2), "Transformed vector is different with expected.") + assert(vector1 === vector2, "Transformed vector is different with expected.") } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index e5956ee9942aa..201a335e0d7be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -84,7 +84,7 @@ class NGramSuite extends MLTest with DefaultReadWriteTest { def testNGram(t: NGram, dataFrame: DataFrame): Unit = { testTransformer[(Seq[String], Seq[String])](dataFrame, t, "nGrams", "wantedNGrams") { - case Row(actualNGrams : Seq[String], wantedNGrams: Seq[String]) => + case Row(actualNGrams : Seq[_], wantedNGrams: Seq[_]) => assert(actualNGrams === wantedNGrams) } } From 9d960de0814a1128318676cc2e91f447cdf0137f Mon Sep 17 00:00:00 2001 From: JBauerKogentix <37910022+JBauerKogentix@users.noreply.github.com> Date: Wed, 11 Apr 2018 15:52:13 -0700 Subject: [PATCH 0605/2461] typo rawPredicition changed to rawPrediction MultilayerPerceptronClassifier had 4 occurrences ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: JBauerKogentix <37910022+JBauerKogentix@users.noreply.github.com> Closes #21030 from JBauerKogentix/patch-1. --- python/pyspark/ml/classification.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index fbbe3d0307c81..ec17653a1adf9 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1543,12 +1543,12 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, solver="l-bfgs", initialWeights=None, probabilityCol="probability", - rawPredicitionCol="rawPrediction"): + rawPredictionCol="rawPrediction"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \ solver="l-bfgs", initialWeights=None, probabilityCol="probability", \ - rawPredicitionCol="rawPrediction") + rawPredictionCol="rawPrediction") """ super(MultilayerPerceptronClassifier, self).__init__() self._java_obj = self._new_java_obj( @@ -1562,12 +1562,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, solver="l-bfgs", initialWeights=None, probabilityCol="probability", - rawPredicitionCol="rawPrediction"): + rawPredictionCol="rawPrediction"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \ solver="l-bfgs", initialWeights=None, probabilityCol="probability", \ - rawPredicitionCol="rawPrediction"): + rawPredictionCol="rawPrediction"): Sets params for MultilayerPerceptronClassifier. """ kwargs = self._input_kwargs From e904dfaf0d16f9fa0cc4d2f46a3dec1b1d77de75 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 11 Apr 2018 17:04:34 -0700 Subject: [PATCH 0606/2461] Revert "[SPARK-23960][SQL][MINOR] Mark HashAggregateExec.bufVars as transient" This reverts commit 271c891b91917d660d1f6b995de397c47c7a6058. --- .../spark/sql/execution/aggregate/HashAggregateExec.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 965950ed94fe8..a5dc6ebf2b0f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -174,8 +174,8 @@ case class HashAggregateExec( } } - // The variables used as aggregation buffer. Only used in codegen for aggregation without keys. - @transient private var bufVars: Seq[ExprCode] = _ + // The variables used as aggregation buffer. Only used for aggregation without keys. + private var bufVars: Seq[ExprCode] = _ private def doProduceWithoutKeys(ctx: CodegenContext): String = { val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") @@ -238,8 +238,6 @@ case class HashAggregateExec( | } """.stripMargin) - bufVars = null // explicitly null this field out to allow the referent to be GC'd sooner - val numOutput = metricTerm(ctx, "numOutputRows") val aggTime = metricTerm(ctx, "aggTime") val beforeAgg = ctx.freshName("beforeAgg") From 6a2289ecf020a99cd9b3bcea7da5e78fb4e0303a Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 12 Apr 2018 15:58:04 +0800 Subject: [PATCH 0607/2461] [SPARK-23962][SQL][TEST] Fix race in currentExecutionIds(). SQLMetricsTestUtils.currentExecutionIds() was racing with the listener bus, which lead to some flaky tests. We should wait till the listener bus is empty. I tested by adding some Thread.sleep()s in SQLAppStatusListener, which reproduced the exceptions I saw on Jenkins. With this change, they went away. Author: Imran Rashid Closes #21041 from squito/SPARK-23962. --- .../apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 534d8bb629b8c..dcc540fc4f109 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -34,6 +34,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils { import testImplicits._ protected def currentExecutionIds(): Set[Long] = { + spark.sparkContext.listenerBus.waitUntilEmpty(10000) statusStore.executionsList.map(_.executionId).toSet } From 0b19122d434e39eb117ccc3174a0688c9c874d48 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 12 Apr 2018 22:21:30 +0800 Subject: [PATCH 0608/2461] [SPARK-23762][SQL] UTF8StringBuffer uses MemoryBlock ## What changes were proposed in this pull request? This PR tries to use `MemoryBlock` in `UTF8StringBuffer`. In general, there are two advantages to use `MemoryBlock`. 1. Has clean API calls rather than using a Java array or `PlatformMemory` 2. Improve runtime performance of memory access instead of using `Object`. ## How was this patch tested? Added `UTF8StringBufferSuite` Author: Kazuaki Ishizaki Closes #20871 from kiszk/SPARK-23762. --- .../codegen/UTF8StringBuilder.java | 35 +++++++--------- .../codegen/UTF8StringBuilderSuite.scala | 42 +++++++++++++++++++ 2 files changed, 56 insertions(+), 21 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java index f0f66bae245fd..f8000d78cd1b6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java @@ -19,6 +19,8 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.UTF8String; /** @@ -29,43 +31,34 @@ public class UTF8StringBuilder { private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; - private byte[] buffer; - private int cursor = Platform.BYTE_ARRAY_OFFSET; + private ByteArrayMemoryBlock buffer; + private int length = 0; public UTF8StringBuilder() { // Since initial buffer size is 16 in `StringBuilder`, we set the same size here - this.buffer = new byte[16]; + this.buffer = new ByteArrayMemoryBlock(16); } // Grows the buffer by at least `neededSize` private void grow(int neededSize) { - if (neededSize > ARRAY_MAX - totalSize()) { + if (neededSize > ARRAY_MAX - length) { throw new UnsupportedOperationException( "Cannot grow internal buffer by size " + neededSize + " because the size after growing " + "exceeds size limitation " + ARRAY_MAX); } - final int length = totalSize() + neededSize; - if (buffer.length < length) { - int newLength = length < ARRAY_MAX / 2 ? length * 2 : ARRAY_MAX; - final byte[] tmp = new byte[newLength]; - Platform.copyMemory( - buffer, - Platform.BYTE_ARRAY_OFFSET, - tmp, - Platform.BYTE_ARRAY_OFFSET, - totalSize()); + final int requestedSize = length + neededSize; + if (buffer.size() < requestedSize) { + int newLength = requestedSize < ARRAY_MAX / 2 ? requestedSize * 2 : ARRAY_MAX; + final ByteArrayMemoryBlock tmp = new ByteArrayMemoryBlock(newLength); + MemoryBlock.copyMemory(buffer, tmp, length); buffer = tmp; } } - private int totalSize() { - return cursor - Platform.BYTE_ARRAY_OFFSET; - } - public void append(UTF8String value) { grow(value.numBytes()); - value.writeToMemory(buffer, cursor); - cursor += value.numBytes(); + value.writeToMemory(buffer.getByteArray(), length + Platform.BYTE_ARRAY_OFFSET); + length += value.numBytes(); } public void append(String value) { @@ -73,6 +66,6 @@ public void append(String value) { } public UTF8String build() { - return UTF8String.fromBytes(buffer, 0, totalSize()); + return UTF8String.fromBytes(buffer.getByteArray(), 0, length); } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala new file mode 100644 index 0000000000000..1b25a4b191f86 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.unsafe.types.UTF8String + +class UTF8StringBuilderSuite extends SparkFunSuite { + + test("basic test") { + val sb = new UTF8StringBuilder() + assert(sb.build() === UTF8String.EMPTY_UTF8) + + sb.append("") + assert(sb.build() === UTF8String.EMPTY_UTF8) + + sb.append("abcd") + assert(sb.build() === UTF8String.fromString("abcd")) + + sb.append(UTF8String.fromString("1234")) + assert(sb.build() === UTF8String.fromString("abcd1234")) + + // expect to grow an internal buffer + sb.append(UTF8String.fromString("efgijk567890")) + assert(sb.build() === UTF8String.fromString("abcd1234efgijk567890")) + } +} From 0f93b91a71444a1a938acfd8ea2191c54fb0187c Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 12 Apr 2018 15:47:42 -0600 Subject: [PATCH 0609/2461] [SPARK-23751][FOLLOW-UP] fix build for scala-2.12 ## What changes were proposed in this pull request? fix build for scala-2.12 ## How was this patch tested? Manual. Author: WeichenXu Closes #21051 from WeichenXu123/fix_build212. --- .../scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala index af8ff64d33ffe..adf8145726711 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala @@ -85,7 +85,7 @@ object KolmogorovSmirnovTest { dataset: Dataset[_], sampleCol: String, cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = { - test(dataset, sampleCol, (x: Double) => cdf.call(x)) + test(dataset, sampleCol, (x: Double) => cdf.call(x).toDouble) } /** From 682002b6da844ed11324ee5ff4d00fc0294c0b31 Mon Sep 17 00:00:00 2001 From: Patrick Pisciuneri Date: Fri, 13 Apr 2018 09:45:27 +0800 Subject: [PATCH 0610/2461] [SPARK-23867][SCHEDULER] use droppedCount in logWarning MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Get the count of dropped events for output in log message. ## How was this patch tested? The fix is pretty trivial, but `./dev/run-tests` were run and were successful. Please review http://spark.apache.org/contributing.html before opening a pull request. vanzin cloud-fan The contribution is my original work and I license the work to the project under the project’s open source license. Author: Patrick Pisciuneri Closes #20977 from phpisciuneri/fix-log-warning. --- .../main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala index 7e14938acd8e0..c1fedd63f6a90 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala @@ -166,7 +166,7 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi val prevLastReportTimestamp = lastReportTimestamp lastReportTimestamp = System.currentTimeMillis() val previous = new java.util.Date(prevLastReportTimestamp) - logWarning(s"Dropped $droppedEvents events from $name since $previous.") + logWarning(s"Dropped $droppedCount events from $name since $previous.") } } } From 14291b061b9b40eadbf4ed442f9a5021b8e09597 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 12 Apr 2018 20:00:25 -0700 Subject: [PATCH 0611/2461] [SPARK-23748][SS] Fix SS continuous process doesn't support SubqueryAlias issue ## What changes were proposed in this pull request? Current SS continuous doesn't support processing on temp table or `df.as("xxx")`, SS will throw an exception as LogicalPlan not supported, details described in [here](https://issues.apache.org/jira/browse/SPARK-23748). So here propose to add this support. ## How was this patch tested? new UT. Author: jerryshao Closes #21017 from jerryshao/SPARK-23748. --- .../UnsupportedOperationChecker.scala | 2 +- .../continuous/ContinuousSuite.scala | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index b55043c270644..ff9d6d7a7dded 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -345,7 +345,7 @@ object UnsupportedOperationChecker { plan.foreachUp { implicit subPlan => subPlan match { case (_: Project | _: Filter | _: MapElements | _: MapPartitions | - _: DeserializeToObject | _: SerializeFromObject) => + _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias) => case node if node.nodeName == "StreamingRelationV2" => case node => throwError(s"Continuous processing does not support ${node.nodeName} operations.") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index f5884b9c8de12..ef74efef156d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -171,6 +171,25 @@ class ContinuousSuite extends ContinuousSuiteBase { "Continuous processing does not support current time operations.")) } + test("subquery alias") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .createOrReplaceTempView("rate") + val test = spark.sql("select value from rate where value > 5") + + testStream(test, useV2Sink = true)( + StartStream(longContinuousTrigger), + AwaitEpoch(0), + Execute(waitForRateSourceTriggers(_, 2)), + IncrementEpoch(), + Execute(waitForRateSourceTriggers(_, 4)), + IncrementEpoch(), + CheckAnswerRowsContains(scala.Range(6, 20).map(Row(_)))) + } + test("repeatedly restart") { val df = spark.readStream .format("rate") From ab7b961a4fe96ca02b8352d16b0fa80c972b67fc Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 13 Apr 2018 11:28:13 +0800 Subject: [PATCH 0612/2461] [SPARK-23942][PYTHON][SQL] Makes collect in PySpark as action for a query executor listener ## What changes were proposed in this pull request? This PR proposes to add `collect` to a query executor as an action. Seems `collect` / `collect` with Arrow are not recognised via `QueryExecutionListener` as an action. For example, if we have a custom listener as below: ```scala package org.apache.spark.sql import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.util.QueryExecutionListener class TestQueryExecutionListener extends QueryExecutionListener with Logging { override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { logError("Look at me! I'm 'onSuccess'") } override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { } } ``` and set `spark.sql.queryExecutionListeners` to `org.apache.spark.sql.TestQueryExecutionListener` Other operations in PySpark or Scala side seems fine: ```python >>> sql("SELECT * FROM range(1)").show() ``` ``` 18/04/09 17:02:04 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess' +---+ | id| +---+ | 0| +---+ ``` ```scala scala> sql("SELECT * FROM range(1)").collect() ``` ``` 18/04/09 16:58:41 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess' res1: Array[org.apache.spark.sql.Row] = Array([0]) ``` but .. **Before** ```python >>> sql("SELECT * FROM range(1)").collect() ``` ``` [Row(id=0)] ``` ```python >>> spark.conf.set("spark.sql.execution.arrow.enabled", "true") >>> sql("SELECT * FROM range(1)").toPandas() ``` ``` id 0 0 ``` **After** ```python >>> sql("SELECT * FROM range(1)").collect() ``` ``` 18/04/09 16:57:58 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess' [Row(id=0)] ``` ```python >>> spark.conf.set("spark.sql.execution.arrow.enabled", "true") >>> sql("SELECT * FROM range(1)").toPandas() ``` ``` 18/04/09 17:53:26 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess' id 0 0 ``` ## How was this patch tested? I have manually tested as described above and unit test was added. Author: hyukjinkwon Closes #21007 from HyukjinKwon/SPARK-23942. --- python/pyspark/sql/tests.py | 87 ++++++++++++++++--- .../scala/org/apache/spark/sql/Dataset.scala | 20 +++-- .../sql/TestQueryExecutionListener.scala | 44 ++++++++++ 3 files changed, 134 insertions(+), 17 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 96c2a776a5049..4e99c8e3c6b10 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -186,16 +186,12 @@ def __init__(self, key, value): self.value = value -class ReusedSQLTestCase(ReusedPySparkTestCase): - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.spark = SparkSession(cls.sc) - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - cls.spark.stop() +class SQLTestUtils(object): + """ + This util assumes the instance of this to have 'spark' attribute, having a spark session. + It is usually used with 'ReusedSQLTestCase' class but can be used if you feel sure the + the implementation of this class has 'spark' attribute. + """ @contextmanager def sql_conf(self, pairs): @@ -204,6 +200,7 @@ def sql_conf(self, pairs): `value` to the configuration `key` and then restores it back when it exits. """ assert isinstance(pairs, dict), "pairs should be a dictionary." + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." keys = pairs.keys() new_values = pairs.values() @@ -219,6 +216,18 @@ def sql_conf(self, pairs): else: self.spark.conf.set(key, old_value) + +class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils): + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + cls.spark.stop() + def assertPandasEqual(self, expected, result): msg = ("DataFrames are not equal: " + "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) + @@ -3066,6 +3075,64 @@ def test_sparksession_with_stopped_sparkcontext(self): sc.stop() +class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils): + # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is + # static and immutable. This can't be set or unset, for example, via `spark.conf`. + + @classmethod + def setUpClass(cls): + import glob + from pyspark.find_spark_home import _find_spark_home + + SPARK_HOME = _find_spark_home() + filename_pattern = ( + "sql/core/target/scala-*/test-classes/org/apache/spark/sql/" + "TestQueryExecutionListener.class") + if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)): + raise unittest.SkipTest( + "'org.apache.spark.sql.TestQueryExecutionListener' is not " + "available. Will skip the related tests.") + + # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration. + cls.spark = SparkSession.builder \ + .master("local[4]") \ + .appName(cls.__name__) \ + .config( + "spark.sql.queryExecutionListeners", + "org.apache.spark.sql.TestQueryExecutionListener") \ + .getOrCreate() + + @classmethod + def tearDownClass(cls): + cls.spark.stop() + + def tearDown(self): + self.spark._jvm.OnSuccessCall.clear() + + def test_query_execution_listener_on_collect(self): + self.assertFalse( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should not be called before 'collect'") + self.spark.sql("SELECT * FROM range(1)").collect() + self.assertTrue( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should be called after 'collect'") + + @unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) + def test_query_execution_listener_on_collect_with_arrow(self): + with self.sql_conf({"spark.sql.execution.arrow.enabled": True}): + self.assertFalse( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should not be " + "called before 'toPandas'") + self.spark.sql("SELECT * FROM range(1)").toPandas() + self.assertTrue( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should be called after 'toPandas'") + + class SparkSessionTests(PySparkTestCase): # This test is separate because it's closely related with session's start and stop. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 0aee1d7be5788..917168162b236 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3189,10 +3189,10 @@ class Dataset[T] private[sql]( private[sql] def collectToPython(): Int = { EvaluatePython.registerPicklers() - withNewExecutionId { + withAction("collectToPython", queryExecution) { plan => val toJava: (Any) => Any = EvaluatePython.toJava(_, schema) - val iter = new SerDeUtil.AutoBatchedPickler( - queryExecution.executedPlan.executeCollect().iterator.map(toJava)) + val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler( + plan.executeCollect().iterator.map(toJava)) PythonRDD.serveIterator(iter, "serve-DataFrame") } } @@ -3201,8 +3201,9 @@ class Dataset[T] private[sql]( * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. */ private[sql] def collectAsArrowToPython(): Int = { - withNewExecutionId { - val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable) + withAction("collectAsArrowToPython", queryExecution) { plan => + val iter: Iterator[Array[Byte]] = + toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) PythonRDD.serveIterator(iter, "serve-Arrow") } } @@ -3311,14 +3312,19 @@ class Dataset[T] private[sql]( } /** Convert to an RDD of ArrowPayload byte arrays */ - private[sql] def toArrowPayload: RDD[ArrowPayload] = { + private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = { val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone - queryExecution.toRdd.mapPartitionsInternal { iter => + plan.execute().mapPartitionsInternal { iter => val context = TaskContext.get() ArrowConverters.toPayloadIterator( iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context) } } + + // This is only used in tests, for now. + private[sql] def toArrowPayload: RDD[ArrowPayload] = { + toArrowPayload(queryExecution.executedPlan) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala new file mode 100644 index 0000000000000..d2a6358ee822b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.util.QueryExecutionListener + + +class TestQueryExecutionListener extends QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + OnSuccessCall.isOnSuccessCalled.set(true) + } + + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { } +} + +/** + * This has a variable to check if `onSuccess` is actually called or not. Currently, this is for + * the test case in PySpark. See SPARK-23942. + */ +object OnSuccessCall { + val isOnSuccessCalled = new AtomicBoolean(false) + + def isCalled(): Boolean = isOnSuccessCalled.get() + + def clear(): Unit = isOnSuccessCalled.set(false) +} From 1018be44d6c52cf18e14d84160850063f0e60a1d Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 12 Apr 2018 22:30:59 -0700 Subject: [PATCH 0613/2461] [SPARK-23971] Should not leak Spark sessions across test suites ## What changes were proposed in this pull request? Many suites currently leak Spark sessions (sometimes with stopped SparkContexts) via the thread-local active Spark session and default Spark session. We should attempt to clean these up and detect when this happens to improve the reproducibility of tests. ## How was this patch tested? Existing tests Author: Eric Liang Closes #21058 from ericl/clear-session. --- .../org/apache/spark/SharedSparkSession.java | 9 ++++++-- .../org/apache/spark/sql/SparkSession.scala | 23 +++++++++++++++++-- .../apache/spark/sql/SessionStateSuite.scala | 2 ++ .../spark/sql/test/SharedSparkSession.scala | 22 ++++++++++++++---- 4 files changed, 47 insertions(+), 9 deletions(-) diff --git a/mllib/src/test/java/org/apache/spark/SharedSparkSession.java b/mllib/src/test/java/org/apache/spark/SharedSparkSession.java index 43779878890db..35a250955b282 100644 --- a/mllib/src/test/java/org/apache/spark/SharedSparkSession.java +++ b/mllib/src/test/java/org/apache/spark/SharedSparkSession.java @@ -42,7 +42,12 @@ public void setUp() throws IOException { @After public void tearDown() { - spark.stop(); - spark = null; + try { + spark.stop(); + spark = null; + } finally { + SparkSession.clearDefaultSession(); + SparkSession.clearActiveSession(); + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index b107492fbb330..c502e583a55c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ExecutionListenerManager -import org.apache.spark.util.Utils +import org.apache.spark.util.{CallSite, Utils} /** @@ -81,6 +81,9 @@ class SparkSession private( @transient private[sql] val extensions: SparkSessionExtensions) extends Serializable with Closeable with Logging { self => + // The call site where this SparkSession was constructed. + private val creationSite: CallSite = Utils.getCallSite() + private[sql] def this(sc: SparkContext) { this(sc, None, None, new SparkSessionExtensions) } @@ -763,7 +766,7 @@ class SparkSession private( @InterfaceStability.Stable -object SparkSession { +object SparkSession extends Logging { /** * Builder for [[SparkSession]]. @@ -1090,4 +1093,20 @@ object SparkSession { } } + private[spark] def cleanupAnyExistingSession(): Unit = { + val session = getActiveSession.orElse(getDefaultSession) + if (session.isDefined) { + logWarning( + s"""An existing Spark session exists as the active or default session. + |This probably means another suite leaked it. Attempting to stop it before continuing. + |This existing Spark session was created at: + | + |${session.get.creationSite.longForm} + | + """.stripMargin) + session.get.stop() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index 4efae4c46c2e1..7d1366092d1e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -44,6 +44,8 @@ class SessionStateSuite extends SparkFunSuite { if (activeSession != null) { activeSession.stop() activeSession = null + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() } super.afterAll() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index e758c865b908f..8968dbf36d507 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -60,6 +60,7 @@ trait SharedSparkSession protected implicit def sqlContext: SQLContext = _spark.sqlContext protected def createSparkSession: TestSparkSession = { + SparkSession.cleanupAnyExistingSession() new TestSparkSession(sparkConf) } @@ -92,11 +93,22 @@ trait SharedSparkSession * Stop the underlying [[org.apache.spark.SparkContext]], if any. */ protected override def afterAll(): Unit = { - super.afterAll() - if (_spark != null) { - _spark.sessionState.catalog.reset() - _spark.stop() - _spark = null + try { + super.afterAll() + } finally { + try { + if (_spark != null) { + try { + _spark.sessionState.catalog.reset() + } finally { + _spark.stop() + _spark = null + } + } + } finally { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } } } From 4b07036799b01894826b47c73142fe282c607a57 Mon Sep 17 00:00:00 2001 From: Fangshi Li Date: Fri, 13 Apr 2018 13:46:34 +0800 Subject: [PATCH 0614/2461] [SPARK-23815][CORE] Spark writer dynamic partition overwrite mode may fail to write output on multi level partition ## What changes were proposed in this pull request? Spark introduced new writer mode to overwrite only related partitions in SPARK-20236. While we are using this feature in our production cluster, we found a bug when writing multi-level partitions on HDFS. A simple test case to reproduce this issue: val df = Seq(("1","2","3")).toDF("col1", "col2","col3") df.write.partitionBy("col1","col2").mode("overwrite").save("/my/hdfs/location") If HDFS location "/my/hdfs/location" does not exist, there will be no output. This seems to be caused by the job commit change in SPARK-20236 in HadoopMapReduceCommitProtocol. In the commit job process, the output has been written into staging dir /my/hdfs/location/.spark-staging.xxx/col1=1/col2=2, and then the code calls fs.rename to rename /my/hdfs/location/.spark-staging.xxx/col1=1/col2=2 to /my/hdfs/location/col1=1/col2=2. However, in our case the operation will fail on HDFS because /my/hdfs/location/col1=1 does not exists. HDFS rename can not create directory for more than one level. This does not happen in the new unit test added with SPARK-20236 which uses local file system. We are proposing a fix. When cleaning current partition dir /my/hdfs/location/col1=1/col2=2 before the rename op, if the delete op fails (because /my/hdfs/location/col1=1/col2=2 may not exist), we call mkdirs op to create the parent dir /my/hdfs/location/col1=1 (if the parent dir does not exist) so the following rename op can succeed. Reference: in official HDFS document(https://hadoop.apache.org/docs/stable/hadoop-project-dist/hadoop-common/filesystem/filesystem.html), the rename command has precondition "dest must be root, or have a parent that exists" ## How was this patch tested? We have tested this patch on our production cluster and it fixed the problem Author: Fangshi Li Closes #20931 from fangshil/master. --- .../internal/io/HadoopMapReduceCommitProtocol.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 6d20ef1f98a3c..3e60c50ada59b 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -186,7 +186,17 @@ class HadoopMapReduceCommitProtocol( logDebug(s"Clean up default partition directories for overwriting: $partitionPaths") for (part <- partitionPaths) { val finalPartPath = new Path(path, part) - fs.delete(finalPartPath, true) + if (!fs.delete(finalPartPath, true) && !fs.exists(finalPartPath.getParent)) { + // According to the official hadoop FileSystem API spec, delete op should assume + // the destination is no longer present regardless of return value, thus we do not + // need to double check if finalPartPath exists before rename. + // Also in our case, based on the spec, delete returns false only when finalPartPath + // does not exist. When this happens, we need to take action if parent of finalPartPath + // also does not exist(e.g. the scenario described on SPARK-23815), because + // FileSystem API spec on rename op says the rename dest(finalPartPath) must have + // a parent that exists, otherwise we may get unexpected result on the rename. + fs.mkdirs(finalPartPath.getParent) + } fs.rename(new Path(stagingDir, part), finalPartPath) } } From 0323e61465ee747c9a57a70e9d6108876499546e Mon Sep 17 00:00:00 2001 From: yucai Date: Fri, 13 Apr 2018 00:00:04 -0700 Subject: [PATCH 0615/2461] [SPARK-23905][SQL] Add UDF weekday ## What changes were proposed in this pull request? Add UDF weekday ## How was this patch tested? A new test Author: yucai Closes #21009 from yucai/SPARK-23905. --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/datetimeExpressions.scala | 55 +++++++++++++++---- .../expressions/DateExpressionsSuite.scala | 11 ++++ .../resources/sql-tests/inputs/datetime.sql | 2 + .../sql-tests/results/datetime.sql.out | 9 ++- 5 files changed, 67 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 747016beb06e7..131b958239e41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -395,6 +395,7 @@ object FunctionRegistry { expression[TruncTimestamp]("date_trunc"), expression[UnixTimestamp]("unix_timestamp"), expression[DayOfWeek]("dayofweek"), + expression[WeekDay]("weekday"), expression[WeekOfYear]("weekofyear"), expression[Year]("year"), expression[TimeWindow]("window"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 32fdb13afbbfa..b9b2cd5bdb9f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -426,36 +426,71 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa """, since = "2.3.0") // scalastyle:on line.size.limit -case class DayOfWeek(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class DayOfWeek(child: Expression) extends DayWeek { - override def inputTypes: Seq[AbstractDataType] = Seq(DateType) - - override def dataType: DataType = IntegerType + override protected def nullSafeEval(date: Any): Any = { + cal.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) + cal.get(Calendar.DAY_OF_WEEK) + } - @transient private lazy val c = { - Calendar.getInstance(DateTimeUtils.getTimeZone("UTC")) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, time => { + val cal = classOf[Calendar].getName + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val c = "calDayOfWeek" + ctx.addImmutableStateIfNotExists(cal, c, + v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""") + s""" + $c.setTimeInMillis($time * 1000L * 3600L * 24L); + ${ev.value} = $c.get($cal.DAY_OF_WEEK); + """ + }) } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(date) - Returns the day of the week for date/timestamp (0 = Monday, 1 = Tuesday, ..., 6 = Sunday).", + examples = """ + Examples: + > SELECT _FUNC_('2009-07-30'); + 3 + """, + since = "2.4.0") +// scalastyle:on line.size.limit +case class WeekDay(child: Expression) extends DayWeek { override protected def nullSafeEval(date: Any): Any = { - c.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) - c.get(Calendar.DAY_OF_WEEK) + cal.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) + (cal.get(Calendar.DAY_OF_WEEK) + 5 ) % 7 } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - val c = "calDayOfWeek" + val c = "calWeekDay" ctx.addImmutableStateIfNotExists(cal, c, v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""") s""" $c.setTimeInMillis($time * 1000L * 3600L * 24L); - ${ev.value} = $c.get($cal.DAY_OF_WEEK); + ${ev.value} = ($c.get($cal.DAY_OF_WEEK) + 5) % 7; """ }) } } +abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + @transient protected lazy val cal: Calendar = { + Calendar.getInstance(DateTimeUtils.getTimeZone("UTC")) + } +} + // scalastyle:off line.size.limit @ExpressionDescription( usage = "_FUNC_(date) - Returns the week of the year of the given date. A week is considered to start on a Monday and week 1 is the first week with >3 days.", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 786266a2c13c0..080ec487cfa6a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -211,6 +211,17 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(DayOfWeek, DateType) } + test("WeekDay") { + checkEvaluation(WeekDay(Literal.create(null, DateType)), null) + checkEvaluation(WeekDay(Literal(d)), 2) + checkEvaluation(WeekDay(Cast(Literal(sdfDate.format(d)), DateType, gmtId)), 2) + checkEvaluation(WeekDay(Cast(Literal(ts), DateType, gmtId)), 4) + checkEvaluation(WeekDay(Cast(Literal("2011-05-06"), DateType, gmtId)), 4) + checkEvaluation(WeekDay(Literal(new Date(sdf.parse("2017-05-27 13:10:15").getTime))), 5) + checkEvaluation(WeekDay(Literal(new Date(sdf.parse("1582-10-15 13:10:15").getTime))), 4) + checkConsistencyBetweenInterpretedAndCodegen(WeekDay, DateType) + } + test("WeekOfYear") { checkEvaluation(WeekOfYear(Literal.create(null, DateType)), null) checkEvaluation(WeekOfYear(Literal(d)), 15) diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index adea2bfa82cd3..547c2bef02b24 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -25,3 +25,5 @@ create temporary view ttf2 as select * from values select current_date = current_date(), current_timestamp = current_timestamp(), a, b from ttf2; select a, b from ttf2 order by a, current_date; + +select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15'); diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index bbb6851e69c7e..4e1cfa6e48c1c 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 9 +-- Number of queries: 10 -- !query 0 @@ -81,3 +81,10 @@ struct -- !query 8 output 1 2 2 3 + +-- !query 9 +select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15') +-- !query 3 schema +struct +-- !query 3 output +5 3 5 NULL 4 From a83ae0d9bc1b8f4909b9338370efe4020079bea7 Mon Sep 17 00:00:00 2001 From: mcheah Date: Fri, 13 Apr 2018 08:43:58 -0700 Subject: [PATCH 0616/2461] [SPARK-22839][K8S] Refactor to unify driver and executor pod builder APIs ## What changes were proposed in this pull request? Breaks down the construction of driver pods and executor pods in a way that uses a common abstraction for both spark-submit creating the driver and KubernetesClusterSchedulerBackend creating the executor. Encourages more code reuse and is more legible than the older approach. The high-level design is discussed in more detail on the JIRA ticket. This pull request is the implementation of that design with some minor changes in the implementation details. No user-facing behavior should break as a result of this change. ## How was this patch tested? Migrated all unit tests from the old submission steps architecture to the new architecture. Integration tests should not have to change and pass given that this shouldn't change any outward behavior. Author: mcheah Closes #20910 from mccheah/spark-22839-incremental. --- .../org/apache/spark/deploy/k8s/Config.scala | 2 +- .../spark/deploy/k8s/KubernetesConf.scala | 184 +++++++++++++ .../deploy/k8s/KubernetesDriverSpec.scala} | 25 +- .../spark/deploy/k8s/KubernetesUtils.scala | 11 - .../deploy/k8s/MountSecretsBootstrap.scala | 72 ----- ...ConfigurationStep.scala => SparkPod.scala} | 24 +- .../k8s/features/BasicDriverFeatureStep.scala | 136 ++++++++++ .../features/BasicExecutorFeatureStep.scala | 179 +++++++++++++ ...iverKubernetesCredentialsFeatureStep.scala | 216 +++++++++++++++ .../features/DriverServiceFeatureStep.scala | 97 +++++++ .../KubernetesFeatureConfigStep.scala | 71 +++++ .../features/MountSecretsFeatureStep.scala | 62 +++++ .../k8s/submit/DriverConfigOrchestrator.scala | 145 ----------- .../submit/KubernetesClientApplication.scala | 80 +++--- .../k8s/submit/KubernetesDriverBuilder.scala | 56 ++++ .../k8s/submit/KubernetesDriverSpec.scala | 47 ---- .../steps/BasicDriverConfigurationStep.scala | 163 ------------ .../steps/DependencyResolutionStep.scala | 61 ----- .../DriverKubernetesCredentialsStep.scala | 245 ------------------ .../submit/steps/DriverMountSecretsStep.scala | 38 --- .../steps/DriverServiceBootstrapStep.scala | 104 -------- .../cluster/k8s/ExecutorPodFactory.scala | 227 ---------------- .../k8s/KubernetesClusterManager.scala | 12 +- .../KubernetesClusterSchedulerBackend.scala | 20 +- .../k8s/KubernetesExecutorBuilder.scala | 41 +++ .../deploy/k8s/KubernetesConfSuite.scala | 175 +++++++++++++ .../BasicDriverFeatureStepSuite.scala | 153 +++++++++++ .../BasicExecutorFeatureStepSuite.scala | 179 +++++++++++++ ...bernetesCredentialsFeatureStepSuite.scala} | 101 +++++--- .../DriverServiceFeatureStepSuite.scala | 227 ++++++++++++++++ .../KubernetesFeaturesTestUtils.scala | 61 +++++ .../MountSecretsFeatureStepSuite.scala} | 29 ++- .../spark/deploy/k8s/submit/ClientSuite.scala | 216 +++++++-------- .../DriverConfigOrchestratorSuite.scala | 131 ---------- .../submit/KubernetesDriverBuilderSuite.scala | 102 ++++++++ .../BasicDriverConfigurationStepSuite.scala | 122 --------- .../steps/DependencyResolutionStepSuite.scala | 69 ----- .../DriverServiceBootstrapStepSuite.scala | 180 ------------- .../cluster/k8s/ExecutorPodFactorySuite.scala | 195 -------------- ...bernetesClusterSchedulerBackendSuite.scala | 37 ++- .../k8s/KubernetesExecutorBuilderSuite.scala | 75 ++++++ 41 files changed, 2289 insertions(+), 2081 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala rename resource-managers/kubernetes/core/src/{test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala => main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala} (57%) delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala rename resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/{submit/steps/DriverConfigurationStep.scala => SparkPod.scala} (64%) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/{submit/steps/DriverKubernetesCredentialsStepSuite.scala => features/DriverKubernetesCredentialsFeatureStepSuite.scala} (67%) create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/{submit/steps/DriverMountSecretsStepSuite.scala => features/MountSecretsFeatureStepSuite.scala} (64%) delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 82f6c714f3555..4086970ffb256 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -167,5 +167,5 @@ private[spark] object Config extends Logging { val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets." - val KUBERNETES_DRIVER_ENV_KEY = "spark.kubernetes.driverEnv." + val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala new file mode 100644 index 0000000000000..77b634ddfabcc --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{LocalObjectReference, LocalObjectReferenceBuilder, Pod} + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.{JavaMainAppResource, MainAppResource} +import org.apache.spark.internal.config.ConfigEntry + +private[spark] sealed trait KubernetesRoleSpecificConf + +/* + * Structure containing metadata for Kubernetes logic that builds a Spark driver. + */ +private[spark] case class KubernetesDriverSpecificConf( + mainAppResource: Option[MainAppResource], + mainClass: String, + appName: String, + appArgs: Seq[String]) extends KubernetesRoleSpecificConf + +/* + * Structure containing metadata for Kubernetes logic that builds a Spark executor. + */ +private[spark] case class KubernetesExecutorSpecificConf( + executorId: String, + driverPod: Pod) + extends KubernetesRoleSpecificConf + +/** + * Structure containing metadata for Kubernetes logic to build Spark pods. + */ +private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( + sparkConf: SparkConf, + roleSpecificConf: T, + appResourceNamePrefix: String, + appId: String, + roleLabels: Map[String, String], + roleAnnotations: Map[String, String], + roleSecretNamesToMountPaths: Map[String, String], + roleEnvs: Map[String, String]) { + + def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE) + + def sparkJars(): Seq[String] = sparkConf + .getOption("spark.jars") + .map(str => str.split(",").toSeq) + .getOrElse(Seq.empty[String]) + + def sparkFiles(): Seq[String] = sparkConf + .getOption("spark.files") + .map(str => str.split(",").toSeq) + .getOrElse(Seq.empty[String]) + + def imagePullPolicy(): String = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) + + def imagePullSecrets(): Seq[LocalObjectReference] = { + sparkConf + .get(IMAGE_PULL_SECRETS) + .map(_.split(",")) + .getOrElse(Array.empty[String]) + .map(_.trim) + .map { secret => + new LocalObjectReferenceBuilder().withName(secret).build() + } + } + + def nodeSelector(): Map[String, String] = + KubernetesUtils.parsePrefixedKeyValuePairs(sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) + + def get[T](config: ConfigEntry[T]): T = sparkConf.get(config) + + def get(conf: String): String = sparkConf.get(conf) + + def get(conf: String, defaultValue: String): String = sparkConf.get(conf, defaultValue) + + def getOption(key: String): Option[String] = sparkConf.getOption(key) +} + +private[spark] object KubernetesConf { + def createDriverConf( + sparkConf: SparkConf, + appName: String, + appResourceNamePrefix: String, + appId: String, + mainAppResource: Option[MainAppResource], + mainClass: String, + appArgs: Array[String]): KubernetesConf[KubernetesDriverSpecificConf] = { + val sparkConfWithMainAppJar = sparkConf.clone() + mainAppResource.foreach { + case JavaMainAppResource(res) => + val previousJars = sparkConf + .getOption("spark.jars") + .map(_.split(",")) + .getOrElse(Array.empty) + if (!previousJars.contains(res)) { + sparkConfWithMainAppJar.setJars(previousJars ++ Seq(res)) + } + } + + val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_LABEL_PREFIX) + require(!driverCustomLabels.contains(SPARK_APP_ID_LABEL), "Label with key " + + s"$SPARK_APP_ID_LABEL is not allowed as it is reserved for Spark bookkeeping " + + "operations.") + require(!driverCustomLabels.contains(SPARK_ROLE_LABEL), "Label with key " + + s"$SPARK_ROLE_LABEL is not allowed as it is reserved for Spark bookkeeping " + + "operations.") + val driverLabels = driverCustomLabels ++ Map( + SPARK_APP_ID_LABEL -> appId, + SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) + val driverAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) + val driverSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_SECRETS_PREFIX) + val driverEnvs = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_ENV_PREFIX) + + KubernetesConf( + sparkConfWithMainAppJar, + KubernetesDriverSpecificConf(mainAppResource, mainClass, appName, appArgs), + appResourceNamePrefix, + appId, + driverLabels, + driverAnnotations, + driverSecretNamesToMountPaths, + driverEnvs) + } + + def createExecutorConf( + sparkConf: SparkConf, + executorId: String, + appId: String, + driverPod: Pod): KubernetesConf[KubernetesExecutorSpecificConf] = { + val executorCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_LABEL_PREFIX) + require( + !executorCustomLabels.contains(SPARK_APP_ID_LABEL), + s"Custom executor labels cannot contain $SPARK_APP_ID_LABEL as it is reserved for Spark.") + require( + !executorCustomLabels.contains(SPARK_EXECUTOR_ID_LABEL), + s"Custom executor labels cannot contain $SPARK_EXECUTOR_ID_LABEL as it is reserved for" + + " Spark.") + require( + !executorCustomLabels.contains(SPARK_ROLE_LABEL), + s"Custom executor labels cannot contain $SPARK_ROLE_LABEL as it is reserved for Spark.") + val executorLabels = Map( + SPARK_EXECUTOR_ID_LABEL -> executorId, + SPARK_APP_ID_LABEL -> appId, + SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ + executorCustomLabels + val executorAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) + val executorSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) + val executorEnv = sparkConf.getExecutorEnv.toMap + + KubernetesConf( + sparkConf.clone(), + KubernetesExecutorSpecificConf(executorId, driverPod), + sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX), + appId, + executorLabels, + executorAnnotations, + executorSecrets, + executorEnv) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala similarity index 57% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala index cf41b22e241af..0c5ae022f4070 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala @@ -16,21 +16,16 @@ */ package org.apache.spark.deploy.k8s -import io.fabric8.kubernetes.api.model.LocalObjectReference +import io.fabric8.kubernetes.api.model.HasMetadata -import org.apache.spark.SparkFunSuite - -class KubernetesUtilsTest extends SparkFunSuite { - - test("testParseImagePullSecrets") { - val noSecrets = KubernetesUtils.parseImagePullSecrets(None) - assert(noSecrets === Nil) - - val oneSecret = KubernetesUtils.parseImagePullSecrets(Some("imagePullSecret")) - assert(oneSecret === new LocalObjectReference("imagePullSecret") :: Nil) - - val commaSeparatedSecrets = KubernetesUtils.parseImagePullSecrets(Some("s1, s2 , s3,s4")) - assert(commaSeparatedSecrets.map(_.getName) === "s1" :: "s2" :: "s3" :: "s4" :: Nil) - } +private[spark] case class KubernetesDriverSpec( + pod: SparkPod, + driverKubernetesResources: Seq[HasMetadata], + systemProperties: Map[String, String]) +private[spark] object KubernetesDriverSpec { + def initialSpec(initialProps: Map[String, String]): KubernetesDriverSpec = KubernetesDriverSpec( + SparkPod.initialPod(), + Seq.empty, + initialProps) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 5b2bb819cdb14..ee629068ad90d 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -37,17 +37,6 @@ private[spark] object KubernetesUtils { sparkConf.getAllWithPrefix(prefix).toMap } - /** - * Parses comma-separated list of imagePullSecrets into K8s-understandable format - */ - def parseImagePullSecrets(imagePullSecrets: Option[String]): List[LocalObjectReference] = { - imagePullSecrets match { - case Some(secretsCommaSeparated) => - secretsCommaSeparated.split(',').map(_.trim).map(new LocalObjectReference(_)).toList - case None => Nil - } - } - def requireNandDefined(opt1: Option[_], opt2: Option[_], errMessage: String): Unit = { opt1.foreach { _ => require(opt2.isEmpty, errMessage) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala deleted file mode 100644 index c35e7db51d407..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s - -import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBuilder} - -/** - * Bootstraps a driver or executor container or an init-container with needed secrets mounted. - */ -private[spark] class MountSecretsBootstrap(secretNamesToMountPaths: Map[String, String]) { - - /** - * Add new secret volumes for the secrets specified in secretNamesToMountPaths into the given pod. - * - * @param pod the pod into which the secret volumes are being added. - * @return the updated pod with the secret volumes added. - */ - def addSecretVolumes(pod: Pod): Pod = { - var podBuilder = new PodBuilder(pod) - secretNamesToMountPaths.keys.foreach { name => - podBuilder = podBuilder - .editOrNewSpec() - .addNewVolume() - .withName(secretVolumeName(name)) - .withNewSecret() - .withSecretName(name) - .endSecret() - .endVolume() - .endSpec() - } - - podBuilder.build() - } - - /** - * Mounts Kubernetes secret volumes of the secrets specified in secretNamesToMountPaths into the - * given container. - * - * @param container the container into which the secret volumes are being mounted. - * @return the updated container with the secrets mounted. - */ - def mountSecrets(container: Container): Container = { - var containerBuilder = new ContainerBuilder(container) - secretNamesToMountPaths.foreach { case (name, path) => - containerBuilder = containerBuilder - .addNewVolumeMount() - .withName(secretVolumeName(name)) - .withMountPath(path) - .endVolumeMount() - } - - containerBuilder.build() - } - - private def secretVolumeName(secretName: String): String = { - secretName + "-volume" - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala similarity index 64% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala index 17614e040e587..345dd117fd35f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala @@ -14,17 +14,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit.steps +package org.apache.spark.deploy.k8s -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBuilder} -/** - * Represents a step in configuring the Spark driver pod. - */ -private[spark] trait DriverConfigurationStep { +private[spark] case class SparkPod(pod: Pod, container: Container) - /** - * Apply some transformation to the previous state of the driver to add a new feature to it. - */ - def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec +private[spark] object SparkPod { + def initialPod(): SparkPod = { + SparkPod( + new PodBuilder() + .withNewMetadata() + .endMetadata() + .withNewSpec() + .endSpec() + .build(), + new ContainerBuilder().build()) + } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala new file mode 100644 index 0000000000000..07bdccbe0479d --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} + +import org.apache.spark.SparkException +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.config._ +import org.apache.spark.launcher.SparkLauncher + +private[spark] class BasicDriverFeatureStep( + conf: KubernetesConf[KubernetesDriverSpecificConf]) + extends KubernetesFeatureConfigStep { + + private val driverPodName = conf + .get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse(s"${conf.appResourceNamePrefix}-driver") + + private val driverContainerImage = conf + .get(DRIVER_CONTAINER_IMAGE) + .getOrElse(throw new SparkException("Must specify the driver container image")) + + // CPU settings + private val driverCpuCores = conf.get("spark.driver.cores", "1") + private val driverLimitCores = conf.get(KUBERNETES_DRIVER_LIMIT_CORES) + + // Memory settings + private val driverMemoryMiB = conf.get(DRIVER_MEMORY) + private val memoryOverheadMiB = conf + .get(DRIVER_MEMORY_OVERHEAD) + .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) + private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB + + override def configurePod(pod: SparkPod): SparkPod = { + val driverCustomEnvs = conf.roleEnvs + .toSeq + .map { env => + new EnvVarBuilder() + .withName(env._1) + .withValue(env._2) + .build() + } + + val driverCpuQuantity = new QuantityBuilder(false) + .withAmount(driverCpuCores) + .build() + val driverMemoryQuantity = new QuantityBuilder(false) + .withAmount(s"${driverMemoryWithOverheadMiB}Mi") + .build() + val maybeCpuLimitQuantity = driverLimitCores.map { limitCores => + ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) + } + + val driverContainer = new ContainerBuilder(pod.container) + .withName(DRIVER_CONTAINER_NAME) + .withImage(driverContainerImage) + .withImagePullPolicy(conf.imagePullPolicy()) + .addAllToEnv(driverCustomEnvs.asJava) + .addNewEnv() + .withName(ENV_DRIVER_BIND_ADDRESS) + .withValueFrom(new EnvVarSourceBuilder() + .withNewFieldRef("v1", "status.podIP") + .build()) + .endEnv() + .withNewResources() + .addToRequests("cpu", driverCpuQuantity) + .addToLimits(maybeCpuLimitQuantity.toMap.asJava) + .addToRequests("memory", driverMemoryQuantity) + .addToLimits("memory", driverMemoryQuantity) + .endResources() + .addToArgs("driver") + .addToArgs("--properties-file", SPARK_CONF_PATH) + .addToArgs("--class", conf.roleSpecificConf.mainClass) + // The user application jar is merged into the spark.jars list and managed through that + // property, so there is no need to reference it explicitly here. + .addToArgs(SparkLauncher.NO_RESOURCE) + .addToArgs(conf.roleSpecificConf.appArgs: _*) + .build() + + val driverPod = new PodBuilder(pod.pod) + .editOrNewMetadata() + .withName(driverPodName) + .addToLabels(conf.roleLabels.asJava) + .addToAnnotations(conf.roleAnnotations.asJava) + .endMetadata() + .withNewSpec() + .withRestartPolicy("Never") + .withNodeSelector(conf.nodeSelector().asJava) + .addToImagePullSecrets(conf.imagePullSecrets(): _*) + .endSpec() + .build() + SparkPod(driverPod, driverContainer) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = { + val additionalProps = mutable.Map( + KUBERNETES_DRIVER_POD_NAME.key -> driverPodName, + "spark.app.id" -> conf.appId, + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> conf.appResourceNamePrefix, + KUBERNETES_DRIVER_SUBMIT_CHECK.key -> "true") + + val resolvedSparkJars = KubernetesUtils.resolveFileUrisAndPath( + conf.sparkJars()) + val resolvedSparkFiles = KubernetesUtils.resolveFileUrisAndPath( + conf.sparkFiles()) + if (resolvedSparkJars.nonEmpty) { + additionalProps.put("spark.jars", resolvedSparkJars.mkString(",")) + } + if (resolvedSparkFiles.nonEmpty) { + additionalProps.put("spark.files", resolvedSparkFiles.mkString(",")) + } + additionalProps.toMap + } + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala new file mode 100644 index 0000000000000..d22097587aafe --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, ContainerPortBuilder, EnvVar, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} + +import org.apache.spark.SparkException +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} +import org.apache.spark.rpc.RpcEndpointAddress +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.util.Utils + +private[spark] class BasicExecutorFeatureStep( + kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]) + extends KubernetesFeatureConfigStep { + + // Consider moving some of these fields to KubernetesConf or KubernetesExecutorSpecificConf + private val executorExtraClasspath = kubernetesConf.get(EXECUTOR_CLASS_PATH) + private val executorContainerImage = kubernetesConf + .get(EXECUTOR_CONTAINER_IMAGE) + .getOrElse(throw new SparkException("Must specify the executor container image")) + private val blockManagerPort = kubernetesConf + .sparkConf + .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) + + private val executorPodNamePrefix = kubernetesConf.appResourceNamePrefix + + private val driverUrl = RpcEndpointAddress( + kubernetesConf.get("spark.driver.host"), + kubernetesConf.sparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString + private val executorMemoryMiB = kubernetesConf.get(EXECUTOR_MEMORY) + private val executorMemoryString = kubernetesConf.get( + EXECUTOR_MEMORY.key, EXECUTOR_MEMORY.defaultValueString) + + private val memoryOverheadMiB = kubernetesConf + .get(EXECUTOR_MEMORY_OVERHEAD) + .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt, + MEMORY_OVERHEAD_MIN_MIB)) + private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB + + private val executorCores = kubernetesConf.sparkConf.getInt("spark.executor.cores", 1) + private val executorCoresRequest = + if (kubernetesConf.sparkConf.contains(KUBERNETES_EXECUTOR_REQUEST_CORES)) { + kubernetesConf.get(KUBERNETES_EXECUTOR_REQUEST_CORES).get + } else { + executorCores.toString + } + private val executorLimitCores = kubernetesConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) + + override def configurePod(pod: SparkPod): SparkPod = { + val name = s"$executorPodNamePrefix-exec-${kubernetesConf.roleSpecificConf.executorId}" + + // hostname must be no longer than 63 characters, so take the last 63 characters of the pod + // name as the hostname. This preserves uniqueness since the end of name contains + // executorId + val hostname = name.substring(Math.max(0, name.length - 63)) + val executorMemoryQuantity = new QuantityBuilder(false) + .withAmount(s"${executorMemoryWithOverhead}Mi") + .build() + val executorCpuQuantity = new QuantityBuilder(false) + .withAmount(executorCoresRequest) + .build() + val executorExtraClasspathEnv = executorExtraClasspath.map { cp => + new EnvVarBuilder() + .withName(ENV_CLASSPATH) + .withValue(cp) + .build() + } + val executorExtraJavaOptionsEnv = kubernetesConf + .get(EXECUTOR_JAVA_OPTIONS) + .map { opts => + val delimitedOpts = Utils.splitCommandString(opts) + delimitedOpts.zipWithIndex.map { + case (opt, index) => + new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() + } + }.getOrElse(Seq.empty[EnvVar]) + val executorEnv = (Seq( + (ENV_DRIVER_URL, driverUrl), + (ENV_EXECUTOR_CORES, executorCores.toString), + (ENV_EXECUTOR_MEMORY, executorMemoryString), + (ENV_APPLICATION_ID, kubernetesConf.appId), + // This is to set the SPARK_CONF_DIR to be /opt/spark/conf + (ENV_SPARK_CONF_DIR, SPARK_CONF_DIR_INTERNAL), + (ENV_EXECUTOR_ID, kubernetesConf.roleSpecificConf.executorId)) ++ + kubernetesConf.roleEnvs) + .map(env => new EnvVarBuilder() + .withName(env._1) + .withValue(env._2) + .build() + ) ++ Seq( + new EnvVarBuilder() + .withName(ENV_EXECUTOR_POD_IP) + .withValueFrom(new EnvVarSourceBuilder() + .withNewFieldRef("v1", "status.podIP") + .build()) + .build() + ) ++ executorExtraJavaOptionsEnv ++ executorExtraClasspathEnv.toSeq + val requiredPorts = Seq( + (BLOCK_MANAGER_PORT_NAME, blockManagerPort)) + .map { case (name, port) => + new ContainerPortBuilder() + .withName(name) + .withContainerPort(port) + .build() + } + + val executorContainer = new ContainerBuilder(pod.container) + .withName("executor") + .withImage(executorContainerImage) + .withImagePullPolicy(kubernetesConf.imagePullPolicy()) + .withNewResources() + .addToRequests("memory", executorMemoryQuantity) + .addToLimits("memory", executorMemoryQuantity) + .addToRequests("cpu", executorCpuQuantity) + .endResources() + .addAllToEnv(executorEnv.asJava) + .withPorts(requiredPorts.asJava) + .addToArgs("executor") + .build() + val containerWithLimitCores = executorLimitCores.map { limitCores => + val executorCpuLimitQuantity = new QuantityBuilder(false) + .withAmount(limitCores) + .build() + new ContainerBuilder(executorContainer) + .editResources() + .addToLimits("cpu", executorCpuLimitQuantity) + .endResources() + .build() + }.getOrElse(executorContainer) + val driverPod = kubernetesConf.roleSpecificConf.driverPod + val executorPod = new PodBuilder(pod.pod) + .editOrNewMetadata() + .withName(name) + .withLabels(kubernetesConf.roleLabels.asJava) + .withAnnotations(kubernetesConf.roleAnnotations.asJava) + .withOwnerReferences() + .addNewOwnerReference() + .withController(true) + .withApiVersion(driverPod.getApiVersion) + .withKind(driverPod.getKind) + .withName(driverPod.getMetadata.getName) + .withUid(driverPod.getMetadata.getUid) + .endOwnerReference() + .endMetadata() + .editOrNewSpec() + .withHostname(hostname) + .withRestartPolicy("Never") + .withNodeSelector(kubernetesConf.nodeSelector().asJava) + .addToImagePullSecrets(kubernetesConf.imagePullSecrets(): _*) + .endSpec() + .build() + SparkPod(executorPod, containerWithLimitCores) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala new file mode 100644 index 0000000000000..ff5ad6673b309 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import java.io.File +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ + +import com.google.common.io.{BaseEncoding, Files} +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, Secret, SecretBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +private[spark] class DriverKubernetesCredentialsFeatureStep(kubernetesConf: KubernetesConf[_]) + extends KubernetesFeatureConfigStep { + // TODO clean up this class, and credentials in general. See also SparkKubernetesClientFactory. + // We should use a struct to hold all creds-related fields. A lot of the code is very repetitive. + + private val maybeMountedOAuthTokenFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX") + private val maybeMountedClientKeyFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX") + private val maybeMountedClientCertFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX") + private val maybeMountedCaCertFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX") + private val driverServiceAccount = kubernetesConf.get(KUBERNETES_SERVICE_ACCOUNT_NAME) + + private val oauthTokenBase64 = kubernetesConf + .getOption(s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX") + .map { token => + BaseEncoding.base64().encode(token.getBytes(StandardCharsets.UTF_8)) + } + + private val caCertDataBase64 = safeFileConfToBase64( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", + "Driver CA cert file") + private val clientKeyDataBase64 = safeFileConfToBase64( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX", + "Driver client key file") + private val clientCertDataBase64 = safeFileConfToBase64( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX", + "Driver client cert file") + + // TODO decide whether or not to apply this step entirely in the caller, i.e. the builder. + private val shouldMountSecret = oauthTokenBase64.isDefined || + caCertDataBase64.isDefined || + clientKeyDataBase64.isDefined || + clientCertDataBase64.isDefined + + private val driverCredentialsSecretName = + s"${kubernetesConf.appResourceNamePrefix}-kubernetes-credentials" + + override def configurePod(pod: SparkPod): SparkPod = { + if (!shouldMountSecret) { + pod.copy( + pod = driverServiceAccount.map { account => + new PodBuilder(pod.pod) + .editOrNewSpec() + .withServiceAccount(account) + .withServiceAccountName(account) + .endSpec() + .build() + }.getOrElse(pod.pod)) + } else { + val driverPodWithMountedKubernetesCredentials = + new PodBuilder(pod.pod) + .editOrNewSpec() + .addNewVolume() + .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) + .withNewSecret().withSecretName(driverCredentialsSecretName).endSecret() + .endVolume() + .endSpec() + .build() + + val driverContainerWithMountedSecretVolume = + new ContainerBuilder(pod.container) + .addNewVolumeMount() + .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) + .withMountPath(DRIVER_CREDENTIALS_SECRETS_BASE_DIR) + .endVolumeMount() + .build() + SparkPod(driverPodWithMountedKubernetesCredentials, driverContainerWithMountedSecretVolume) + } + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = { + val resolvedMountedOAuthTokenFile = resolveSecretLocation( + maybeMountedOAuthTokenFile, + oauthTokenBase64, + DRIVER_CREDENTIALS_OAUTH_TOKEN_PATH) + val resolvedMountedClientKeyFile = resolveSecretLocation( + maybeMountedClientKeyFile, + clientKeyDataBase64, + DRIVER_CREDENTIALS_CLIENT_KEY_PATH) + val resolvedMountedClientCertFile = resolveSecretLocation( + maybeMountedClientCertFile, + clientCertDataBase64, + DRIVER_CREDENTIALS_CLIENT_CERT_PATH) + val resolvedMountedCaCertFile = resolveSecretLocation( + maybeMountedCaCertFile, + caCertDataBase64, + DRIVER_CREDENTIALS_CA_CERT_PATH) + + val redactedTokens = kubernetesConf.sparkConf.getAll + .filter(_._1.endsWith(OAUTH_TOKEN_CONF_SUFFIX)) + .toMap + .mapValues( _ => "") + redactedTokens ++ + resolvedMountedCaCertFile.map { file => + Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) ++ + resolvedMountedClientKeyFile.map { file => + Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) ++ + resolvedMountedClientCertFile.map { file => + Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) ++ + resolvedMountedOAuthTokenFile.map { file => + Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) + } + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { + if (shouldMountSecret) { + Seq(createCredentialsSecret()) + } else { + Seq.empty + } + } + + private def safeFileConfToBase64(conf: String, fileType: String): Option[String] = { + kubernetesConf.getOption(conf) + .map(new File(_)) + .map { file => + require(file.isFile, String.format("%s provided at %s does not exist or is not a file.", + fileType, file.getAbsolutePath)) + BaseEncoding.base64().encode(Files.toByteArray(file)) + } + } + + /** + * Resolve a Kubernetes secret data entry from an optional client credential used by the + * driver to talk to the Kubernetes API server. + * + * @param userSpecifiedCredential the optional user-specified client credential. + * @param secretName name of the Kubernetes secret storing the client credential. + * @return a secret data entry in the form of a map from the secret name to the secret data, + * which may be empty if the user-specified credential is empty. + */ + private def resolveSecretData( + userSpecifiedCredential: Option[String], + secretName: String): Map[String, String] = { + userSpecifiedCredential.map { valueBase64 => + Map(secretName -> valueBase64) + }.getOrElse(Map.empty[String, String]) + } + + private def resolveSecretLocation( + mountedUserSpecified: Option[String], + valueMountedFromSubmitter: Option[String], + mountedCanonicalLocation: String): Option[String] = { + mountedUserSpecified.orElse(valueMountedFromSubmitter.map { _ => + mountedCanonicalLocation + }) + } + + private def createCredentialsSecret(): Secret = { + val allSecretData = + resolveSecretData( + clientKeyDataBase64, + DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME) ++ + resolveSecretData( + clientCertDataBase64, + DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME) ++ + resolveSecretData( + caCertDataBase64, + DRIVER_CREDENTIALS_CA_CERT_SECRET_NAME) ++ + resolveSecretData( + oauthTokenBase64, + DRIVER_CREDENTIALS_OAUTH_TOKEN_SECRET_NAME) + + new SecretBuilder() + .withNewMetadata() + .withName(driverCredentialsSecretName) + .endMetadata() + .withData(allSecretData.asJava) + .build() + } + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala new file mode 100644 index 0000000000000..f2d7bbd08f305 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{HasMetadata, ServiceBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.{Clock, SystemClock} + +private[spark] class DriverServiceFeatureStep( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf], + clock: Clock = new SystemClock) + extends KubernetesFeatureConfigStep with Logging { + import DriverServiceFeatureStep._ + + require(kubernetesConf.getOption(DRIVER_BIND_ADDRESS_KEY).isEmpty, + s"$DRIVER_BIND_ADDRESS_KEY is not supported in Kubernetes mode, as the driver's bind " + + "address is managed and set to the driver pod's IP address.") + require(kubernetesConf.getOption(DRIVER_HOST_KEY).isEmpty, + s"$DRIVER_HOST_KEY is not supported in Kubernetes mode, as the driver's hostname will be " + + "managed via a Kubernetes service.") + + private val preferredServiceName = s"${kubernetesConf.appResourceNamePrefix}$DRIVER_SVC_POSTFIX" + private val resolvedServiceName = if (preferredServiceName.length <= MAX_SERVICE_NAME_LENGTH) { + preferredServiceName + } else { + val randomServiceId = clock.getTimeMillis() + val shorterServiceName = s"spark-$randomServiceId$DRIVER_SVC_POSTFIX" + logWarning(s"Driver's hostname would preferably be $preferredServiceName, but this is " + + s"too long (must be <= $MAX_SERVICE_NAME_LENGTH characters). Falling back to use " + + s"$shorterServiceName as the driver service's name.") + shorterServiceName + } + + private val driverPort = kubernetesConf.sparkConf.getInt( + "spark.driver.port", DEFAULT_DRIVER_PORT) + private val driverBlockManagerPort = kubernetesConf.sparkConf.getInt( + org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key, DEFAULT_BLOCKMANAGER_PORT) + + override def configurePod(pod: SparkPod): SparkPod = pod + + override def getAdditionalPodSystemProperties(): Map[String, String] = { + val driverHostname = s"$resolvedServiceName.${kubernetesConf.namespace()}.svc" + Map(DRIVER_HOST_KEY -> driverHostname, + "spark.driver.port" -> driverPort.toString, + org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key -> + driverBlockManagerPort.toString) + } + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { + val driverService = new ServiceBuilder() + .withNewMetadata() + .withName(resolvedServiceName) + .endMetadata() + .withNewSpec() + .withClusterIP("None") + .withSelector(kubernetesConf.roleLabels.asJava) + .addNewPort() + .withName(DRIVER_PORT_NAME) + .withPort(driverPort) + .withNewTargetPort(driverPort) + .endPort() + .addNewPort() + .withName(BLOCK_MANAGER_PORT_NAME) + .withPort(driverBlockManagerPort) + .withNewTargetPort(driverBlockManagerPort) + .endPort() + .endSpec() + .build() + Seq(driverService) + } +} + +private[spark] object DriverServiceFeatureStep { + val DRIVER_BIND_ADDRESS_KEY = org.apache.spark.internal.config.DRIVER_BIND_ADDRESS.key + val DRIVER_HOST_KEY = org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key + val DRIVER_SVC_POSTFIX = "-driver-svc" + val MAX_SERVICE_NAME_LENGTH = 63 +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala new file mode 100644 index 0000000000000..4c1be3bb13293 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.HasMetadata + +import org.apache.spark.deploy.k8s.SparkPod + +/** + * A collection of functions that together represent a "feature" in pods that are launched for + * Spark drivers and executors. + */ +private[spark] trait KubernetesFeatureConfigStep { + + /** + * Apply modifications on the given pod in accordance to this feature. This can include attaching + * volumes, adding environment variables, and adding labels/annotations. + *

    + * Note that we should return a SparkPod that keeps all of the properties of the passed SparkPod + * object. So this is correct: + *

    +   * {@code val configuredPod = new PodBuilder(pod.pod)
    +   *     .editSpec()
    +   *     ...
    +   *     .build()
    +   *   val configuredContainer = new ContainerBuilder(pod.container)
    +   *     ...
    +   *     .build()
    +   *   SparkPod(configuredPod, configuredContainer)
    +   *  }
    +   * 
    + * This is incorrect: + *
    +   * {@code val configuredPod = new PodBuilder() // Loses the original state
    +   *     .editSpec()
    +   *     ...
    +   *     .build()
    +   *   val configuredContainer = new ContainerBuilder() // Loses the original state
    +   *     ...
    +   *     .build()
    +   *   SparkPod(configuredPod, configuredContainer)
    +   *  }
    +   * 
    + */ + def configurePod(pod: SparkPod): SparkPod + + /** + * Return any system properties that should be set on the JVM in accordance to this feature. + */ + def getAdditionalPodSystemProperties(): Map[String, String] + + /** + * Return any additional Kubernetes resources that should be added to support this feature. Only + * applicable when creating the driver in cluster mode. + */ + def getAdditionalKubernetesResources(): Seq[HasMetadata] +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala new file mode 100644 index 0000000000000..97fa9499b2edb --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, VolumeBuilder, VolumeMountBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesRoleSpecificConf, SparkPod} + +private[spark] class MountSecretsFeatureStep( + kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val addedVolumes = kubernetesConf + .roleSecretNamesToMountPaths + .keys + .map(secretName => + new VolumeBuilder() + .withName(secretVolumeName(secretName)) + .withNewSecret() + .withSecretName(secretName) + .endSecret() + .build()) + val podWithVolumes = new PodBuilder(pod.pod) + .editOrNewSpec() + .addToVolumes(addedVolumes.toSeq: _*) + .endSpec() + .build() + val addedVolumeMounts = kubernetesConf + .roleSecretNamesToMountPaths + .map { + case (secretName, mountPath) => + new VolumeMountBuilder() + .withName(secretVolumeName(secretName)) + .withMountPath(mountPath) + .build() + } + val containerWithMounts = new ContainerBuilder(pod.container) + .addToVolumeMounts(addedVolumeMounts.toSeq: _*) + .build() + SparkPod(podWithVolumes, containerWithMounts) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty + + private def secretVolumeName(secretName: String): String = s"$secretName-volume" +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala deleted file mode 100644 index b4d3f04a1bc32..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.steps._ -import org.apache.spark.launcher.SparkLauncher -import org.apache.spark.util.SystemClock -import org.apache.spark.util.Utils - -/** - * Figures out and returns the complete ordered list of needed DriverConfigurationSteps to - * configure the Spark driver pod. The returned steps will be applied one by one in the given - * order to produce a final KubernetesDriverSpec that is used in KubernetesClientApplication - * to construct and create the driver pod. - */ -private[spark] class DriverConfigOrchestrator( - kubernetesAppId: String, - kubernetesResourceNamePrefix: String, - mainAppResource: Option[MainAppResource], - appName: String, - mainClass: String, - appArgs: Array[String], - sparkConf: SparkConf) { - - // The resource name prefix is derived from the Spark application name, making it easy to connect - // the names of the Kubernetes resources from e.g. kubectl or the Kubernetes dashboard to the - // application the user submitted. - - private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) - - def getAllConfigurationSteps: Seq[DriverConfigurationStep] = { - val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_DRIVER_LABEL_PREFIX) - require(!driverCustomLabels.contains(SPARK_APP_ID_LABEL), "Label with key " + - s"$SPARK_APP_ID_LABEL is not allowed as it is reserved for Spark bookkeeping " + - "operations.") - require(!driverCustomLabels.contains(SPARK_ROLE_LABEL), "Label with key " + - s"$SPARK_ROLE_LABEL is not allowed as it is reserved for Spark bookkeeping " + - "operations.") - - val secretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_DRIVER_SECRETS_PREFIX) - - val allDriverLabels = driverCustomLabels ++ Map( - SPARK_APP_ID_LABEL -> kubernetesAppId, - SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) - - val initialSubmissionStep = new BasicDriverConfigurationStep( - kubernetesAppId, - kubernetesResourceNamePrefix, - allDriverLabels, - imagePullPolicy, - appName, - mainClass, - appArgs, - sparkConf) - - val serviceBootstrapStep = new DriverServiceBootstrapStep( - kubernetesResourceNamePrefix, - allDriverLabels, - sparkConf, - new SystemClock) - - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - sparkConf, kubernetesResourceNamePrefix) - - val additionalMainAppJar = if (mainAppResource.nonEmpty) { - val mayBeResource = mainAppResource.get match { - case JavaMainAppResource(resource) if resource != SparkLauncher.NO_RESOURCE => - Some(resource) - case _ => None - } - mayBeResource - } else { - None - } - - val sparkJars = sparkConf.getOption("spark.jars") - .map(_.split(",")) - .getOrElse(Array.empty[String]) ++ - additionalMainAppJar.toSeq - val sparkFiles = sparkConf.getOption("spark.files") - .map(_.split(",")) - .getOrElse(Array.empty[String]) - - // TODO(SPARK-23153): remove once submission client local dependencies are supported. - if (existSubmissionLocalFiles(sparkJars) || existSubmissionLocalFiles(sparkFiles)) { - throw new SparkException("The Kubernetes mode does not yet support referencing application " + - "dependencies in the local file system.") - } - - val dependencyResolutionStep = if (sparkJars.nonEmpty || sparkFiles.nonEmpty) { - Seq(new DependencyResolutionStep( - sparkJars, - sparkFiles)) - } else { - Nil - } - - val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { - Seq(new DriverMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) - } else { - Nil - } - - Seq( - initialSubmissionStep, - serviceBootstrapStep, - kubernetesCredentialsStep) ++ - dependencyResolutionStep ++ - mountSecretsStep - } - - private def existSubmissionLocalFiles(files: Seq[String]): Boolean = { - files.exists { uri => - Utils.resolveURI(uri).getScheme == "file" - } - } - - private def existNonContainerLocalFiles(files: Seq[String]): Boolean = { - files.exists { uri => - Utils.resolveURI(uri).getScheme != "local" - } - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index e16d1add600b2..a97f5650fb869 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -27,12 +27,10 @@ import scala.util.control.NonFatal import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkApplication +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory -import org.apache.spark.deploy.k8s.submit.steps.DriverConfigurationStep import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.ConfigBuilder import org.apache.spark.util.Utils /** @@ -43,9 +41,9 @@ import org.apache.spark.util.Utils * @param driverArgs arguments to the driver */ private[spark] case class ClientArguments( - mainAppResource: Option[MainAppResource], - mainClass: String, - driverArgs: Array[String]) + mainAppResource: Option[MainAppResource], + mainClass: String, + driverArgs: Array[String]) private[spark] object ClientArguments { @@ -80,8 +78,9 @@ private[spark] object ClientArguments { * watcher that monitors and logs the application status. Waits for the application to terminate if * spark.kubernetes.submission.waitAppCompletion is true. * - * @param submissionSteps steps that collectively configure the driver - * @param sparkConf the submission client Spark configuration + * @param builder Responsible for building the base driver pod based on a composition of + * implemented features. + * @param kubernetesConf application configuration * @param kubernetesClient the client to talk to the Kubernetes API server * @param waitForAppCompletion a flag indicating whether the client should wait for the application * to complete @@ -89,31 +88,21 @@ private[spark] object ClientArguments { * @param watcher a watcher that monitors and logs the application status */ private[spark] class Client( - submissionSteps: Seq[DriverConfigurationStep], - sparkConf: SparkConf, + builder: KubernetesDriverBuilder, + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf], kubernetesClient: KubernetesClient, waitForAppCompletion: Boolean, appName: String, watcher: LoggingPodStatusWatcher, kubernetesResourceNamePrefix: String) extends Logging { - /** - * Run command that initializes a DriverSpec that will be updated after each - * DriverConfigurationStep in the sequence that is passed in. The final KubernetesDriverSpec - * will be used to build the Driver Container, Driver Pod, and Kubernetes Resources - */ def run(): Unit = { - var currentDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf) - // submissionSteps contain steps necessary to take, to resolve varying - // client arguments that are passed in, created by orchestrator - for (nextStep <- submissionSteps) { - currentDriverSpec = nextStep.configureDriver(currentDriverSpec) - } + val resolvedDriverSpec = builder.buildFromFeatures(kubernetesConf) val configMapName = s"$kubernetesResourceNamePrefix-driver-conf-map" - val configMap = buildConfigMap(configMapName, currentDriverSpec.driverSparkConf) + val configMap = buildConfigMap(configMapName, resolvedDriverSpec.systemProperties) // The include of the ENV_VAR for "SPARK_CONF_DIR" is to allow for the // Spark command builder to pickup on the Java Options present in the ConfigMap - val resolvedDriverContainer = new ContainerBuilder(currentDriverSpec.driverContainer) + val resolvedDriverContainer = new ContainerBuilder(resolvedDriverSpec.pod.container) .addNewEnv() .withName(ENV_SPARK_CONF_DIR) .withValue(SPARK_CONF_DIR_INTERNAL) @@ -123,7 +112,7 @@ private[spark] class Client( .withMountPath(SPARK_CONF_DIR_INTERNAL) .endVolumeMount() .build() - val resolvedDriverPod = new PodBuilder(currentDriverSpec.driverPod) + val resolvedDriverPod = new PodBuilder(resolvedDriverSpec.pod.pod) .editSpec() .addToContainers(resolvedDriverContainer) .addNewVolume() @@ -141,12 +130,10 @@ private[spark] class Client( .watch(watcher)) { _ => val createdDriverPod = kubernetesClient.pods().create(resolvedDriverPod) try { - if (currentDriverSpec.otherKubernetesResources.nonEmpty) { - val otherKubernetesResources = - currentDriverSpec.otherKubernetesResources ++ Seq(configMap) - addDriverOwnerReference(createdDriverPod, otherKubernetesResources) - kubernetesClient.resourceList(otherKubernetesResources: _*).createOrReplace() - } + val otherKubernetesResources = + resolvedDriverSpec.driverKubernetesResources ++ Seq(configMap) + addDriverOwnerReference(createdDriverPod, otherKubernetesResources) + kubernetesClient.resourceList(otherKubernetesResources: _*).createOrReplace() } catch { case NonFatal(e) => kubernetesClient.pods().delete(createdDriverPod) @@ -180,20 +167,17 @@ private[spark] class Client( } // Build a Config Map that will house spark conf properties in a single file for spark-submit - private def buildConfigMap(configMapName: String, conf: SparkConf): ConfigMap = { + private def buildConfigMap(configMapName: String, conf: Map[String, String]): ConfigMap = { val properties = new Properties() - conf.getAll.foreach { case (k, v) => + conf.foreach { case (k, v) => properties.setProperty(k, v) } val propertiesWriter = new StringWriter() properties.store(propertiesWriter, s"Java properties built from Kubernetes config map with name: $configMapName") - - val namespace = conf.get(KUBERNETES_NAMESPACE) new ConfigMapBuilder() .withNewMetadata() .withName(configMapName) - .withNamespace(namespace) .endMetadata() .addToData(SPARK_CONF_FILE_NAME, propertiesWriter.toString) .build() @@ -211,7 +195,7 @@ private[spark] class KubernetesClientApplication extends SparkApplication { } private def run(clientArguments: ClientArguments, sparkConf: SparkConf): Unit = { - val namespace = sparkConf.get(KUBERNETES_NAMESPACE) + val appName = sparkConf.getOption("spark.app.name").getOrElse("spark") // For constructing the app ID, we can't use the Spark application name, as the app ID is going // to be added as a label to group resources belonging to the same application. Label values are // considerably restrictive, e.g. must be no longer than 63 characters in length. So we generate @@ -219,10 +203,19 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val kubernetesAppId = s"spark-${UUID.randomUUID().toString.replaceAll("-", "")}" val launchTime = System.currentTimeMillis() val waitForAppCompletion = sparkConf.get(WAIT_FOR_APP_COMPLETION) - val appName = sparkConf.getOption("spark.app.name").getOrElse("spark") val kubernetesResourceNamePrefix = { s"$appName-$launchTime".toLowerCase.replaceAll("\\.", "-") } + val kubernetesConf = KubernetesConf.createDriverConf( + sparkConf, + appName, + kubernetesResourceNamePrefix, + kubernetesAppId, + clientArguments.mainAppResource, + clientArguments.mainClass, + clientArguments.driverArgs) + val builder = new KubernetesDriverBuilder + val namespace = kubernetesConf.namespace() // The master URL has been checked for validity already in SparkSubmit. // We just need to get rid of the "k8s://" prefix here. val master = sparkConf.get("spark.master").substring("k8s://".length) @@ -230,15 +223,6 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val watcher = new LoggingPodStatusWatcherImpl(kubernetesAppId, loggingInterval) - val orchestrator = new DriverConfigOrchestrator( - kubernetesAppId, - kubernetesResourceNamePrefix, - clientArguments.mainAppResource, - appName, - clientArguments.mainClass, - clientArguments.driverArgs, - sparkConf) - Utils.tryWithResource(SparkKubernetesClientFactory.createKubernetesClient( master, Some(namespace), @@ -247,8 +231,8 @@ private[spark] class KubernetesClientApplication extends SparkApplication { None, None)) { kubernetesClient => val client = new Client( - orchestrator.getAllConfigurationSteps, - sparkConf, + builder, + kubernetesConf, kubernetesClient, waitForAppCompletion, appName, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala new file mode 100644 index 0000000000000..c7579ed8cb689 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, MountSecretsFeatureStep} + +private[spark] class KubernetesDriverBuilder( + provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep = + new BasicDriverFeatureStep(_), + provideCredentialsStep: (KubernetesConf[KubernetesDriverSpecificConf]) + => DriverKubernetesCredentialsFeatureStep = + new DriverKubernetesCredentialsFeatureStep(_), + provideServiceStep: (KubernetesConf[KubernetesDriverSpecificConf]) => DriverServiceFeatureStep = + new DriverServiceFeatureStep(_), + provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => MountSecretsFeatureStep) = + new MountSecretsFeatureStep(_)) { + + def buildFromFeatures( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]): KubernetesDriverSpec = { + val baseFeatures = Seq( + provideBasicStep(kubernetesConf), + provideCredentialsStep(kubernetesConf), + provideServiceStep(kubernetesConf)) + val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) + } else baseFeatures + + var spec = KubernetesDriverSpec.initialSpec(kubernetesConf.sparkConf.getAll.toMap) + for (feature <- allFeatures) { + val configuredPod = feature.configurePod(spec.pod) + val addedSystemProperties = feature.getAdditionalPodSystemProperties() + val addedResources = feature.getAdditionalKubernetesResources() + spec = KubernetesDriverSpec( + configuredPod, + spec.driverKubernetesResources ++ addedResources, + spec.systemProperties ++ addedSystemProperties) + } + spec + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala deleted file mode 100644 index db13f09387ef9..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit - -import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, HasMetadata, Pod, PodBuilder} - -import org.apache.spark.SparkConf - -/** - * Represents the components and characteristics of a Spark driver. The driver can be considered - * as being comprised of the driver pod itself, any other Kubernetes resources that the driver - * pod depends on, and the SparkConf that should be supplied to the Spark application. The driver - * container should be operated on via the specific field of this case class as opposed to trying - * to edit the container directly on the pod. The driver container should be attached at the - * end of executing all submission steps. - */ -private[spark] case class KubernetesDriverSpec( - driverPod: Pod, - driverContainer: Container, - otherKubernetesResources: Seq[HasMetadata], - driverSparkConf: SparkConf) - -private[spark] object KubernetesDriverSpec { - def initialSpec(initialSparkConf: SparkConf): KubernetesDriverSpec = { - KubernetesDriverSpec( - // Set new metadata and a new spec so that submission steps can use - // PodBuilder#editMetadata() and/or PodBuilder#editSpec() safely. - new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build(), - new ContainerBuilder().build(), - Seq.empty[HasMetadata], - initialSparkConf.clone()) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala deleted file mode 100644 index fcb1db8008053..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ /dev/null @@ -1,163 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model._ - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.KubernetesUtils -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.internal.config.{DRIVER_CLASS_PATH, DRIVER_MEMORY, DRIVER_MEMORY_OVERHEAD} -import org.apache.spark.launcher.SparkLauncher - -/** - * Performs basic configuration for the driver pod. - */ -private[spark] class BasicDriverConfigurationStep( - kubernetesAppId: String, - resourceNamePrefix: String, - driverLabels: Map[String, String], - imagePullPolicy: String, - appName: String, - mainClass: String, - appArgs: Array[String], - sparkConf: SparkConf) extends DriverConfigurationStep { - - private val driverPodName = sparkConf - .get(KUBERNETES_DRIVER_POD_NAME) - .getOrElse(s"$resourceNamePrefix-driver") - - private val driverExtraClasspath = sparkConf.get(DRIVER_CLASS_PATH) - - private val driverContainerImage = sparkConf - .get(DRIVER_CONTAINER_IMAGE) - .getOrElse(throw new SparkException("Must specify the driver container image")) - - private val imagePullSecrets = sparkConf.get(IMAGE_PULL_SECRETS) - - // CPU settings - private val driverCpuCores = sparkConf.getOption("spark.driver.cores").getOrElse("1") - private val driverLimitCores = sparkConf.get(KUBERNETES_DRIVER_LIMIT_CORES) - - // Memory settings - private val driverMemoryMiB = sparkConf.get(DRIVER_MEMORY) - private val memoryOverheadMiB = sparkConf - .get(DRIVER_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) - private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val driverExtraClasspathEnv = driverExtraClasspath.map { classPath => - new EnvVarBuilder() - .withName(ENV_CLASSPATH) - .withValue(classPath) - .build() - } - - val driverCustomAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) - require(!driverCustomAnnotations.contains(SPARK_APP_NAME_ANNOTATION), - s"Annotation with key $SPARK_APP_NAME_ANNOTATION is not allowed as it is reserved for" + - " Spark bookkeeping operations.") - - val driverCustomEnvs = sparkConf.getAllWithPrefix(KUBERNETES_DRIVER_ENV_KEY).toSeq - .map { env => - new EnvVarBuilder() - .withName(env._1) - .withValue(env._2) - .build() - } - - val driverAnnotations = driverCustomAnnotations ++ Map(SPARK_APP_NAME_ANNOTATION -> appName) - - val nodeSelector = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) - - val driverCpuQuantity = new QuantityBuilder(false) - .withAmount(driverCpuCores) - .build() - val driverMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${driverMemoryWithOverheadMiB}Mi") - .build() - val maybeCpuLimitQuantity = driverLimitCores.map { limitCores => - ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) - } - - val driverContainerWithoutArgs = new ContainerBuilder(driverSpec.driverContainer) - .withName(DRIVER_CONTAINER_NAME) - .withImage(driverContainerImage) - .withImagePullPolicy(imagePullPolicy) - .addAllToEnv(driverCustomEnvs.asJava) - .addToEnv(driverExtraClasspathEnv.toSeq: _*) - .addNewEnv() - .withName(ENV_DRIVER_BIND_ADDRESS) - .withValueFrom(new EnvVarSourceBuilder() - .withNewFieldRef("v1", "status.podIP") - .build()) - .endEnv() - .withNewResources() - .addToRequests("cpu", driverCpuQuantity) - .addToRequests("memory", driverMemoryQuantity) - .addToLimits("memory", driverMemoryQuantity) - .addToLimits(maybeCpuLimitQuantity.toMap.asJava) - .endResources() - .addToArgs("driver") - .addToArgs("--properties-file", SPARK_CONF_PATH) - .addToArgs("--class", mainClass) - // The user application jar is merged into the spark.jars list and managed through that - // property, so there is no need to reference it explicitly here. - .addToArgs(SparkLauncher.NO_RESOURCE) - - val driverContainer = appArgs.toList match { - case "" :: Nil | Nil => driverContainerWithoutArgs.build() - case _ => driverContainerWithoutArgs.addToArgs(appArgs: _*).build() - } - - val parsedImagePullSecrets = KubernetesUtils.parseImagePullSecrets(imagePullSecrets) - - val baseDriverPod = new PodBuilder(driverSpec.driverPod) - .editOrNewMetadata() - .withName(driverPodName) - .addToLabels(driverLabels.asJava) - .addToAnnotations(driverAnnotations.asJava) - .endMetadata() - .withNewSpec() - .withRestartPolicy("Never") - .withNodeSelector(nodeSelector.asJava) - .withImagePullSecrets(parsedImagePullSecrets.asJava) - .endSpec() - .build() - - val resolvedSparkConf = driverSpec.driverSparkConf.clone() - .setIfMissing(KUBERNETES_DRIVER_POD_NAME, driverPodName) - .set("spark.app.id", kubernetesAppId) - .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, resourceNamePrefix) - // to set the config variables to allow client-mode spark-submit from driver - .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) - - driverSpec.copy( - driverPod = baseDriverPod, - driverSparkConf = resolvedSparkConf, - driverContainer = driverContainer) - } - -} - diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala deleted file mode 100644 index 43de329f239ad..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.File - -import io.fabric8.kubernetes.api.model.ContainerBuilder - -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.KubernetesUtils -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -/** - * Step that configures the classpath, spark.jars, and spark.files for the driver given that the - * user may provide remote files or files with local:// schemes. - */ -private[spark] class DependencyResolutionStep( - sparkJars: Seq[String], - sparkFiles: Seq[String]) extends DriverConfigurationStep { - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val resolvedSparkJars = KubernetesUtils.resolveFileUrisAndPath(sparkJars) - val resolvedSparkFiles = KubernetesUtils.resolveFileUrisAndPath(sparkFiles) - - val sparkConf = driverSpec.driverSparkConf.clone() - if (resolvedSparkJars.nonEmpty) { - sparkConf.set("spark.jars", resolvedSparkJars.mkString(",")) - } - if (resolvedSparkFiles.nonEmpty) { - sparkConf.set("spark.files", resolvedSparkFiles.mkString(",")) - } - val resolvedDriverContainer = if (resolvedSparkJars.nonEmpty) { - new ContainerBuilder(driverSpec.driverContainer) - .addNewEnv() - .withName(ENV_MOUNTED_CLASSPATH) - .withValue(resolvedSparkJars.mkString(File.pathSeparator)) - .endEnv() - .build() - } else { - driverSpec.driverContainer - } - - driverSpec.copy( - driverContainer = resolvedDriverContainer, - driverSparkConf = sparkConf) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala deleted file mode 100644 index 2424e63999a82..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala +++ /dev/null @@ -1,245 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.File -import java.nio.charset.StandardCharsets - -import scala.collection.JavaConverters._ -import scala.language.implicitConversions - -import com.google.common.io.{BaseEncoding, Files} -import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder, Secret, SecretBuilder} - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -/** - * Mounts Kubernetes credentials into the driver pod. The driver will use such mounted credentials - * to request executors. - */ -private[spark] class DriverKubernetesCredentialsStep( - submissionSparkConf: SparkConf, - kubernetesResourceNamePrefix: String) extends DriverConfigurationStep { - - private val maybeMountedOAuthTokenFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX") - private val maybeMountedClientKeyFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX") - private val maybeMountedClientCertFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX") - private val maybeMountedCaCertFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX") - private val driverServiceAccount = submissionSparkConf.get(KUBERNETES_SERVICE_ACCOUNT_NAME) - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val driverSparkConf = driverSpec.driverSparkConf.clone() - - val oauthTokenBase64 = submissionSparkConf - .getOption(s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX") - .map { token => - BaseEncoding.base64().encode(token.getBytes(StandardCharsets.UTF_8)) - } - val caCertDataBase64 = safeFileConfToBase64( - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", - "Driver CA cert file") - val clientKeyDataBase64 = safeFileConfToBase64( - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX", - "Driver client key file") - val clientCertDataBase64 = safeFileConfToBase64( - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX", - "Driver client cert file") - - val driverSparkConfWithCredentialsLocations = setDriverPodKubernetesCredentialLocations( - driverSparkConf, - oauthTokenBase64, - caCertDataBase64, - clientKeyDataBase64, - clientCertDataBase64) - - val kubernetesCredentialsSecret = createCredentialsSecret( - oauthTokenBase64, - caCertDataBase64, - clientKeyDataBase64, - clientCertDataBase64) - - val driverPodWithMountedKubernetesCredentials = kubernetesCredentialsSecret.map { secret => - new PodBuilder(driverSpec.driverPod) - .editOrNewSpec() - .addNewVolume() - .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) - .withNewSecret().withSecretName(secret.getMetadata.getName).endSecret() - .endVolume() - .endSpec() - .build() - }.getOrElse( - driverServiceAccount.map { account => - new PodBuilder(driverSpec.driverPod) - .editOrNewSpec() - .withServiceAccount(account) - .withServiceAccountName(account) - .endSpec() - .build() - }.getOrElse(driverSpec.driverPod) - ) - - val driverContainerWithMountedSecretVolume = kubernetesCredentialsSecret.map { _ => - new ContainerBuilder(driverSpec.driverContainer) - .addNewVolumeMount() - .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) - .withMountPath(DRIVER_CREDENTIALS_SECRETS_BASE_DIR) - .endVolumeMount() - .build() - }.getOrElse(driverSpec.driverContainer) - - driverSpec.copy( - driverPod = driverPodWithMountedKubernetesCredentials, - otherKubernetesResources = - driverSpec.otherKubernetesResources ++ kubernetesCredentialsSecret.toSeq, - driverSparkConf = driverSparkConfWithCredentialsLocations, - driverContainer = driverContainerWithMountedSecretVolume) - } - - private def createCredentialsSecret( - driverOAuthTokenBase64: Option[String], - driverCaCertDataBase64: Option[String], - driverClientKeyDataBase64: Option[String], - driverClientCertDataBase64: Option[String]): Option[Secret] = { - val allSecretData = - resolveSecretData( - driverClientKeyDataBase64, - DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME) ++ - resolveSecretData( - driverClientCertDataBase64, - DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME) ++ - resolveSecretData( - driverCaCertDataBase64, - DRIVER_CREDENTIALS_CA_CERT_SECRET_NAME) ++ - resolveSecretData( - driverOAuthTokenBase64, - DRIVER_CREDENTIALS_OAUTH_TOKEN_SECRET_NAME) - - if (allSecretData.isEmpty) { - None - } else { - Some(new SecretBuilder() - .withNewMetadata() - .withName(s"$kubernetesResourceNamePrefix-kubernetes-credentials") - .endMetadata() - .withData(allSecretData.asJava) - .build()) - } - } - - private def setDriverPodKubernetesCredentialLocations( - driverSparkConf: SparkConf, - driverOauthTokenBase64: Option[String], - driverCaCertDataBase64: Option[String], - driverClientKeyDataBase64: Option[String], - driverClientCertDataBase64: Option[String]): SparkConf = { - val resolvedMountedOAuthTokenFile = resolveSecretLocation( - maybeMountedOAuthTokenFile, - driverOauthTokenBase64, - DRIVER_CREDENTIALS_OAUTH_TOKEN_PATH) - val resolvedMountedClientKeyFile = resolveSecretLocation( - maybeMountedClientKeyFile, - driverClientKeyDataBase64, - DRIVER_CREDENTIALS_CLIENT_KEY_PATH) - val resolvedMountedClientCertFile = resolveSecretLocation( - maybeMountedClientCertFile, - driverClientCertDataBase64, - DRIVER_CREDENTIALS_CLIENT_CERT_PATH) - val resolvedMountedCaCertFile = resolveSecretLocation( - maybeMountedCaCertFile, - driverCaCertDataBase64, - DRIVER_CREDENTIALS_CA_CERT_PATH) - - val sparkConfWithCredentialLocations = driverSparkConf - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", - resolvedMountedCaCertFile) - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX", - resolvedMountedClientKeyFile) - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX", - resolvedMountedClientCertFile) - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX", - resolvedMountedOAuthTokenFile) - - // Redact all OAuth token values - sparkConfWithCredentialLocations - .getAll - .filter(_._1.endsWith(OAUTH_TOKEN_CONF_SUFFIX)).map(_._1) - .foreach { - sparkConfWithCredentialLocations.set(_, "") - } - sparkConfWithCredentialLocations - } - - private def safeFileConfToBase64(conf: String, fileType: String): Option[String] = { - submissionSparkConf.getOption(conf) - .map(new File(_)) - .map { file => - require(file.isFile, String.format("%s provided at %s does not exist or is not a file.", - fileType, file.getAbsolutePath)) - BaseEncoding.base64().encode(Files.toByteArray(file)) - } - } - - private def resolveSecretLocation( - mountedUserSpecified: Option[String], - valueMountedFromSubmitter: Option[String], - mountedCanonicalLocation: String): Option[String] = { - mountedUserSpecified.orElse(valueMountedFromSubmitter.map { _ => - mountedCanonicalLocation - }) - } - - /** - * Resolve a Kubernetes secret data entry from an optional client credential used by the - * driver to talk to the Kubernetes API server. - * - * @param userSpecifiedCredential the optional user-specified client credential. - * @param secretName name of the Kubernetes secret storing the client credential. - * @return a secret data entry in the form of a map from the secret name to the secret data, - * which may be empty if the user-specified credential is empty. - */ - private def resolveSecretData( - userSpecifiedCredential: Option[String], - secretName: String): Map[String, String] = { - userSpecifiedCredential.map { valueBase64 => - Map(secretName -> valueBase64) - }.getOrElse(Map.empty[String, String]) - } - - private implicit def augmentSparkConf(sparkConf: SparkConf): OptionSettableSparkConf = { - new OptionSettableSparkConf(sparkConf) - } -} - -private class OptionSettableSparkConf(sparkConf: SparkConf) { - def setOption(configEntry: String, option: Option[String]): SparkConf = { - option.foreach { opt => - sparkConf.set(configEntry, opt) - } - sparkConf - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala deleted file mode 100644 index 91e9a9f211335..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import org.apache.spark.deploy.k8s.MountSecretsBootstrap -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -/** - * A driver configuration step for mounting user-specified secrets onto user-specified paths. - * - * @param bootstrap a utility actually handling mounting of the secrets. - */ -private[spark] class DriverMountSecretsStep( - bootstrap: MountSecretsBootstrap) extends DriverConfigurationStep { - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val pod = bootstrap.addSecretVolumes(driverSpec.driverPod) - val container = bootstrap.mountSecrets(driverSpec.driverContainer) - driverSpec.copy( - driverPod = pod, - driverContainer = container - ) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala deleted file mode 100644 index 34af7cde6c1a9..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.ServiceBuilder - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.internal.Logging -import org.apache.spark.util.Clock - -/** - * Allows the driver to be reachable by executor pods through a headless service. The service's - * ports should correspond to the ports that the executor will reach the pod at for RPC. - */ -private[spark] class DriverServiceBootstrapStep( - resourceNamePrefix: String, - driverLabels: Map[String, String], - sparkConf: SparkConf, - clock: Clock) extends DriverConfigurationStep with Logging { - - import DriverServiceBootstrapStep._ - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - require(sparkConf.getOption(DRIVER_BIND_ADDRESS_KEY).isEmpty, - s"$DRIVER_BIND_ADDRESS_KEY is not supported in Kubernetes mode, as the driver's bind " + - "address is managed and set to the driver pod's IP address.") - require(sparkConf.getOption(DRIVER_HOST_KEY).isEmpty, - s"$DRIVER_HOST_KEY is not supported in Kubernetes mode, as the driver's hostname will be " + - "managed via a Kubernetes service.") - - val preferredServiceName = s"$resourceNamePrefix$DRIVER_SVC_POSTFIX" - val resolvedServiceName = if (preferredServiceName.length <= MAX_SERVICE_NAME_LENGTH) { - preferredServiceName - } else { - val randomServiceId = clock.getTimeMillis() - val shorterServiceName = s"spark-$randomServiceId$DRIVER_SVC_POSTFIX" - logWarning(s"Driver's hostname would preferably be $preferredServiceName, but this is " + - s"too long (must be <= $MAX_SERVICE_NAME_LENGTH characters). Falling back to use " + - s"$shorterServiceName as the driver service's name.") - shorterServiceName - } - - val driverPort = sparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT) - val driverBlockManagerPort = sparkConf.getInt( - org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key, DEFAULT_BLOCKMANAGER_PORT) - val driverService = new ServiceBuilder() - .withNewMetadata() - .withName(resolvedServiceName) - .endMetadata() - .withNewSpec() - .withClusterIP("None") - .withSelector(driverLabels.asJava) - .addNewPort() - .withName(DRIVER_PORT_NAME) - .withPort(driverPort) - .withNewTargetPort(driverPort) - .endPort() - .addNewPort() - .withName(BLOCK_MANAGER_PORT_NAME) - .withPort(driverBlockManagerPort) - .withNewTargetPort(driverBlockManagerPort) - .endPort() - .endSpec() - .build() - - val namespace = sparkConf.get(KUBERNETES_NAMESPACE) - val driverHostname = s"${driverService.getMetadata.getName}.$namespace.svc" - val resolvedSparkConf = driverSpec.driverSparkConf.clone() - .set(DRIVER_HOST_KEY, driverHostname) - .set("spark.driver.port", driverPort.toString) - .set( - org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, driverBlockManagerPort) - - driverSpec.copy( - driverSparkConf = resolvedSparkConf, - otherKubernetesResources = driverSpec.otherKubernetesResources ++ Seq(driverService)) - } -} - -private[spark] object DriverServiceBootstrapStep { - val DRIVER_BIND_ADDRESS_KEY = org.apache.spark.internal.config.DRIVER_BIND_ADDRESS.key - val DRIVER_HOST_KEY = org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key - val DRIVER_SVC_POSTFIX = "-driver-svc" - val MAX_SERVICE_NAME_LENGTH = 63 -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala deleted file mode 100644 index 8607d6fba3234..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ /dev/null @@ -1,227 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.scheduler.cluster.k8s - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model._ - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} -import org.apache.spark.util.Utils - -/** - * A factory class for bootstrapping and creating executor pods with the given bootstrapping - * components. - * - * @param sparkConf Spark configuration - * @param mountSecretsBootstrap an optional component for mounting user-specified secrets onto - * user-specified paths into the executor container - */ -private[spark] class ExecutorPodFactory( - sparkConf: SparkConf, - mountSecretsBootstrap: Option[MountSecretsBootstrap]) { - - private val executorExtraClasspath = sparkConf.get(EXECUTOR_CLASS_PATH) - - private val executorLabels = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_EXECUTOR_LABEL_PREFIX) - require( - !executorLabels.contains(SPARK_APP_ID_LABEL), - s"Custom executor labels cannot contain $SPARK_APP_ID_LABEL as it is reserved for Spark.") - require( - !executorLabels.contains(SPARK_EXECUTOR_ID_LABEL), - s"Custom executor labels cannot contain $SPARK_EXECUTOR_ID_LABEL as it is reserved for" + - " Spark.") - require( - !executorLabels.contains(SPARK_ROLE_LABEL), - s"Custom executor labels cannot contain $SPARK_ROLE_LABEL as it is reserved for Spark.") - - private val executorAnnotations = - KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) - private val nodeSelector = - KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_NODE_SELECTOR_PREFIX) - - private val executorContainerImage = sparkConf - .get(EXECUTOR_CONTAINER_IMAGE) - .getOrElse(throw new SparkException("Must specify the executor container image")) - private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) - private val imagePullSecrets = sparkConf.get(IMAGE_PULL_SECRETS) - private val blockManagerPort = sparkConf - .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) - - private val executorPodNamePrefix = sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX) - - private val executorMemoryMiB = sparkConf.get(EXECUTOR_MEMORY) - private val executorMemoryString = sparkConf.get( - EXECUTOR_MEMORY.key, EXECUTOR_MEMORY.defaultValueString) - - private val memoryOverheadMiB = sparkConf - .get(EXECUTOR_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt, - MEMORY_OVERHEAD_MIN_MIB)) - private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB - - private val executorCores = sparkConf.getInt("spark.executor.cores", 1) - private val executorCoresRequest = if (sparkConf.contains(KUBERNETES_EXECUTOR_REQUEST_CORES)) { - sparkConf.get(KUBERNETES_EXECUTOR_REQUEST_CORES).get - } else { - executorCores.toString - } - private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) - - /** - * Configure and construct an executor pod with the given parameters. - */ - def createExecutorPod( - executorId: String, - applicationId: String, - driverUrl: String, - executorEnvs: Seq[(String, String)], - driverPod: Pod, - nodeToLocalTaskCount: Map[String, Int]): Pod = { - val name = s"$executorPodNamePrefix-exec-$executorId" - - val parsedImagePullSecrets = KubernetesUtils.parseImagePullSecrets(imagePullSecrets) - - // hostname must be no longer than 63 characters, so take the last 63 characters of the pod - // name as the hostname. This preserves uniqueness since the end of name contains - // executorId - val hostname = name.substring(Math.max(0, name.length - 63)) - val resolvedExecutorLabels = Map( - SPARK_EXECUTOR_ID_LABEL -> executorId, - SPARK_APP_ID_LABEL -> applicationId, - SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ - executorLabels - val executorMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${executorMemoryWithOverhead}Mi") - .build() - val executorCpuQuantity = new QuantityBuilder(false) - .withAmount(executorCoresRequest) - .build() - val executorExtraClasspathEnv = executorExtraClasspath.map { cp => - new EnvVarBuilder() - .withName(ENV_CLASSPATH) - .withValue(cp) - .build() - } - val executorExtraJavaOptionsEnv = sparkConf - .get(EXECUTOR_JAVA_OPTIONS) - .map { opts => - val delimitedOpts = Utils.splitCommandString(opts) - delimitedOpts.zipWithIndex.map { - case (opt, index) => - new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() - } - }.getOrElse(Seq.empty[EnvVar]) - val executorEnv = (Seq( - (ENV_DRIVER_URL, driverUrl), - (ENV_EXECUTOR_CORES, executorCores.toString), - (ENV_EXECUTOR_MEMORY, executorMemoryString), - (ENV_APPLICATION_ID, applicationId), - // This is to set the SPARK_CONF_DIR to be /opt/spark/conf - (ENV_SPARK_CONF_DIR, SPARK_CONF_DIR_INTERNAL), - (ENV_EXECUTOR_ID, executorId)) ++ executorEnvs) - .map(env => new EnvVarBuilder() - .withName(env._1) - .withValue(env._2) - .build() - ) ++ Seq( - new EnvVarBuilder() - .withName(ENV_EXECUTOR_POD_IP) - .withValueFrom(new EnvVarSourceBuilder() - .withNewFieldRef("v1", "status.podIP") - .build()) - .build() - ) ++ executorExtraJavaOptionsEnv ++ executorExtraClasspathEnv.toSeq - val requiredPorts = Seq( - (BLOCK_MANAGER_PORT_NAME, blockManagerPort)) - .map { case (name, port) => - new ContainerPortBuilder() - .withName(name) - .withContainerPort(port) - .build() - } - - val executorContainer = new ContainerBuilder() - .withName("executor") - .withImage(executorContainerImage) - .withImagePullPolicy(imagePullPolicy) - .withNewResources() - .addToRequests("memory", executorMemoryQuantity) - .addToLimits("memory", executorMemoryQuantity) - .addToRequests("cpu", executorCpuQuantity) - .endResources() - .addAllToEnv(executorEnv.asJava) - .withPorts(requiredPorts.asJava) - .addToArgs("executor") - .build() - - val executorPod = new PodBuilder() - .withNewMetadata() - .withName(name) - .withLabels(resolvedExecutorLabels.asJava) - .withAnnotations(executorAnnotations.asJava) - .withOwnerReferences() - .addNewOwnerReference() - .withController(true) - .withApiVersion(driverPod.getApiVersion) - .withKind(driverPod.getKind) - .withName(driverPod.getMetadata.getName) - .withUid(driverPod.getMetadata.getUid) - .endOwnerReference() - .endMetadata() - .withNewSpec() - .withHostname(hostname) - .withRestartPolicy("Never") - .withNodeSelector(nodeSelector.asJava) - .withImagePullSecrets(parsedImagePullSecrets.asJava) - .endSpec() - .build() - - val containerWithLimitCores = executorLimitCores.map { limitCores => - val executorCpuLimitQuantity = new QuantityBuilder(false) - .withAmount(limitCores) - .build() - new ContainerBuilder(executorContainer) - .editResources() - .addToLimits("cpu", executorCpuLimitQuantity) - .endResources() - .build() - }.getOrElse(executorContainer) - - val (maybeSecretsMountedPod, maybeSecretsMountedContainer) = - mountSecretsBootstrap.map { bootstrap => - (bootstrap.addSecretVolumes(executorPod), bootstrap.mountSecrets(containerWithLimitCores)) - }.getOrElse((executorPod, containerWithLimitCores)) - - - new PodBuilder(maybeSecretsMountedPod) - .editSpec() - .addToContainers(maybeSecretsMountedContainer) - .endSpec() - .build() - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index ff5f6801da2a3..0ea80dfbc0d97 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -21,7 +21,7 @@ import java.io.File import io.fabric8.kubernetes.client.Config import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap, SparkKubernetesClientFactory} +import org.apache.spark.deploy.k8s.{KubernetesUtils, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging @@ -48,12 +48,6 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit scheduler: TaskScheduler): SchedulerBackend = { val executorSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( sc.conf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) - val mountSecretBootstrap = if (executorSecretNamesToMountPaths.nonEmpty) { - Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths)) - } else { - None - } - val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( KUBERNETES_MASTER_INTERNAL_URL, Some(sc.conf.get(KUBERNETES_NAMESPACE)), @@ -62,8 +56,6 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - val executorPodFactory = new ExecutorPodFactory(sc.conf, mountSecretBootstrap) - val allocatorExecutor = ThreadUtils .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( @@ -71,7 +63,7 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit new KubernetesClusterSchedulerBackend( scheduler.asInstanceOf[TaskSchedulerImpl], sc.env.rpcEnv, - executorPodFactory, + new KubernetesExecutorBuilder, kubernetesClient, allocatorExecutor, requestExecutorsService) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index 9de4b16c30d3c..d86664c81071b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -32,6 +32,7 @@ import scala.concurrent.{ExecutionContext, Future} import org.apache.spark.SparkException import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.KubernetesConf import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv} import org.apache.spark.scheduler.{ExecutorExited, SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils} @@ -40,7 +41,7 @@ import org.apache.spark.util.Utils private[spark] class KubernetesClusterSchedulerBackend( scheduler: TaskSchedulerImpl, rpcEnv: RpcEnv, - executorPodFactory: ExecutorPodFactory, + executorBuilder: KubernetesExecutorBuilder, kubernetesClient: KubernetesClient, allocatorExecutor: ScheduledExecutorService, requestExecutorsService: ExecutorService) @@ -115,14 +116,19 @@ private[spark] class KubernetesClusterSchedulerBackend( for (_ <- 0 until math.min( currentTotalExpectedExecutors - runningExecutorsToPods.size, podAllocationSize)) { val executorId = EXECUTOR_ID_COUNTER.incrementAndGet().toString - val executorPod = executorPodFactory.createExecutorPod( + val executorConf = KubernetesConf.createExecutorConf( + conf, executorId, applicationId(), - driverUrl, - conf.getExecutorEnv, - driverPod, - currentNodeToLocalTaskCount) - executorsToAllocate(executorId) = executorPod + driverPod) + val executorPod = executorBuilder.buildFromFeatures(executorConf) + val podWithAttachedContainer = new PodBuilder(executorPod.pod) + .editOrNewSpec() + .addToContainers(executorPod.container) + .endSpec() + .build() + + executorsToAllocate(executorId) = podWithAttachedContainer logInfo( s"Requesting a new executor, total executors is now ${runningExecutorsToPods.size}") } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala new file mode 100644 index 0000000000000..22568fe7ea3be --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, MountSecretsFeatureStep} + +private[spark] class KubernetesExecutorBuilder( + provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep = + new BasicExecutorFeatureStep(_), + provideSecretsStep: + (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => MountSecretsFeatureStep = + new MountSecretsFeatureStep(_)) { + + def buildFromFeatures( + kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { + val baseFeatures = Seq(provideBasicStep(kubernetesConf)) + val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) + } else baseFeatures + var executorPod = SparkPod.initialPod() + for (feature <- allFeatures) { + executorPod = feature.configurePod(executorPod) + } + executorPod + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala new file mode 100644 index 0000000000000..f10202f7a3546 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{LocalObjectReferenceBuilder, PodBuilder} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.JavaMainAppResource + +class KubernetesConfSuite extends SparkFunSuite { + + private val APP_NAME = "test-app" + private val RESOURCE_NAME_PREFIX = "prefix" + private val APP_ID = "test-id" + private val MAIN_CLASS = "test-class" + private val APP_ARGS = Array("arg1", "arg2") + private val CUSTOM_LABELS = Map( + "customLabel1Key" -> "customLabel1Value", + "customLabel2Key" -> "customLabel2Value") + private val CUSTOM_ANNOTATIONS = Map( + "customAnnotation1Key" -> "customAnnotation1Value", + "customAnnotation2Key" -> "customAnnotation2Value") + private val SECRET_NAMES_TO_MOUNT_PATHS = Map( + "secret1" -> "/mnt/secrets/secret1", + "secret2" -> "/mnt/secrets/secret2") + private val CUSTOM_ENVS = Map( + "customEnvKey1" -> "customEnvValue1", + "customEnvKey2" -> "customEnvValue2") + private val DRIVER_POD = new PodBuilder().build() + private val EXECUTOR_ID = "executor-id" + + test("Basic driver translated fields.") { + val sparkConf = new SparkConf(false) + val conf = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + None, + MAIN_CLASS, + APP_ARGS) + assert(conf.appId === APP_ID) + assert(conf.sparkConf.getAll.toMap === sparkConf.getAll.toMap) + assert(conf.appResourceNamePrefix === RESOURCE_NAME_PREFIX) + assert(conf.roleSpecificConf.appName === APP_NAME) + assert(conf.roleSpecificConf.mainAppResource.isEmpty) + assert(conf.roleSpecificConf.mainClass === MAIN_CLASS) + assert(conf.roleSpecificConf.appArgs === APP_ARGS) + } + + test("Creating driver conf with and without the main app jar influences spark.jars") { + val sparkConf = new SparkConf(false) + .setJars(Seq("local:///opt/spark/jar1.jar")) + val mainAppJar = Some(JavaMainAppResource("local:///opt/spark/main.jar")) + val kubernetesConfWithMainJar = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + mainAppJar, + MAIN_CLASS, + APP_ARGS) + assert(kubernetesConfWithMainJar.sparkConf.get("spark.jars") + .split(",") + === Array("local:///opt/spark/jar1.jar", "local:///opt/spark/main.jar")) + val kubernetesConfWithoutMainJar = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + None, + MAIN_CLASS, + APP_ARGS) + assert(kubernetesConfWithoutMainJar.sparkConf.get("spark.jars").split(",") + === Array("local:///opt/spark/jar1.jar")) + } + + test("Resolve driver labels, annotations, secret mount paths, and envs.") { + val sparkConf = new SparkConf(false) + CUSTOM_LABELS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$key", value) + } + CUSTOM_ANNOTATIONS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$key", value) + } + SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$key", value) + } + CUSTOM_ENVS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_ENV_PREFIX$key", value) + } + + val conf = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + None, + MAIN_CLASS, + APP_ARGS) + assert(conf.roleLabels === Map( + SPARK_APP_ID_LABEL -> APP_ID, + SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) ++ + CUSTOM_LABELS) + assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) + assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + assert(conf.roleEnvs === CUSTOM_ENVS) + } + + test("Basic executor translated fields.") { + val conf = KubernetesConf.createExecutorConf( + new SparkConf(false), + EXECUTOR_ID, + APP_ID, + DRIVER_POD) + assert(conf.roleSpecificConf.executorId === EXECUTOR_ID) + assert(conf.roleSpecificConf.driverPod === DRIVER_POD) + } + + test("Image pull secrets.") { + val conf = KubernetesConf.createExecutorConf( + new SparkConf(false) + .set(IMAGE_PULL_SECRETS, "my-secret-1,my-secret-2 "), + EXECUTOR_ID, + APP_ID, + DRIVER_POD) + assert(conf.imagePullSecrets() === + Seq( + new LocalObjectReferenceBuilder().withName("my-secret-1").build(), + new LocalObjectReferenceBuilder().withName("my-secret-2").build())) + } + + test("Set executor labels, annotations, and secrets") { + val sparkConf = new SparkConf(false) + CUSTOM_LABELS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_LABEL_PREFIX$key", value) + } + CUSTOM_ANNOTATIONS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_ANNOTATION_PREFIX$key", value) + } + SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_SECRETS_PREFIX$key", value) + } + + val conf = KubernetesConf.createExecutorConf( + sparkConf, + EXECUTOR_ID, + APP_ID, + DRIVER_POD) + assert(conf.roleLabels === Map( + SPARK_EXECUTOR_ID_LABEL -> EXECUTOR_ID, + SPARK_APP_ID_LABEL -> APP_ID, + SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ CUSTOM_LABELS) + assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) + assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + } + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala new file mode 100644 index 0000000000000..eee85b8baa730 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.LocalObjectReferenceBuilder + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +class BasicDriverFeatureStepSuite extends SparkFunSuite { + + private val APP_ID = "spark-app-id" + private val RESOURCE_NAME_PREFIX = "spark" + private val DRIVER_LABELS = Map("labelkey" -> "labelvalue") + private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" + private val APP_NAME = "spark-test" + private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" + private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") + private val CUSTOM_ANNOTATION_KEY = "customAnnotation" + private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" + private val DRIVER_ANNOTATIONS = Map(CUSTOM_ANNOTATION_KEY -> CUSTOM_ANNOTATION_VALUE) + private val DRIVER_CUSTOM_ENV1 = "customDriverEnv1" + private val DRIVER_CUSTOM_ENV2 = "customDriverEnv2" + private val DRIVER_ENVS = Map( + DRIVER_CUSTOM_ENV1 -> DRIVER_CUSTOM_ENV1, + DRIVER_CUSTOM_ENV2 -> DRIVER_CUSTOM_ENV2) + private val TEST_IMAGE_PULL_SECRETS = Seq("my-secret-1", "my-secret-2") + private val TEST_IMAGE_PULL_SECRET_OBJECTS = + TEST_IMAGE_PULL_SECRETS.map { secret => + new LocalObjectReferenceBuilder().withName(secret).build() + } + + test("Check the pod respects all configurations from the user.") { + val sparkConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod") + .set("spark.driver.cores", "2") + .set(KUBERNETES_DRIVER_LIMIT_CORES, "4") + .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "256M") + .set(org.apache.spark.internal.config.DRIVER_MEMORY_OVERHEAD, 200L) + .set(CONTAINER_IMAGE, "spark-driver:latest") + .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, + APP_NAME, + MAIN_CLASS, + APP_ARGS), + RESOURCE_NAME_PREFIX, + APP_ID, + DRIVER_LABELS, + DRIVER_ANNOTATIONS, + Map.empty, + DRIVER_ENVS) + + val featureStep = new BasicDriverFeatureStep(kubernetesConf) + val basePod = SparkPod.initialPod() + val configuredPod = featureStep.configurePod(basePod) + + assert(configuredPod.container.getName === DRIVER_CONTAINER_NAME) + assert(configuredPod.container.getImage === "spark-driver:latest") + assert(configuredPod.container.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY) + + assert(configuredPod.container.getEnv.size === 3) + val envs = configuredPod.container + .getEnv + .asScala + .map(env => (env.getName, env.getValue)) + .toMap + assert(envs(DRIVER_CUSTOM_ENV1) === DRIVER_ENVS(DRIVER_CUSTOM_ENV1)) + assert(envs(DRIVER_CUSTOM_ENV2) === DRIVER_ENVS(DRIVER_CUSTOM_ENV2)) + + assert(configuredPod.pod.getSpec().getImagePullSecrets.asScala === + TEST_IMAGE_PULL_SECRET_OBJECTS) + + assert(configuredPod.container.getEnv.asScala.exists(envVar => + envVar.getName.equals(ENV_DRIVER_BIND_ADDRESS) && + envVar.getValueFrom.getFieldRef.getApiVersion.equals("v1") && + envVar.getValueFrom.getFieldRef.getFieldPath.equals("status.podIP"))) + + val resourceRequirements = configuredPod.container.getResources + val requests = resourceRequirements.getRequests.asScala + assert(requests("cpu").getAmount === "2") + assert(requests("memory").getAmount === "456Mi") + val limits = resourceRequirements.getLimits.asScala + assert(limits("memory").getAmount === "456Mi") + assert(limits("cpu").getAmount === "4") + + val driverPodMetadata = configuredPod.pod.getMetadata + assert(driverPodMetadata.getName === "spark-driver-pod") + assert(driverPodMetadata.getLabels.asScala === DRIVER_LABELS) + assert(driverPodMetadata.getAnnotations.asScala === DRIVER_ANNOTATIONS) + assert(configuredPod.pod.getSpec.getRestartPolicy === "Never") + + val expectedSparkConf = Map( + KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", + "spark.app.id" -> APP_ID, + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, + "spark.kubernetes.submitInDriver" -> "true") + assert(featureStep.getAdditionalPodSystemProperties() === expectedSparkConf) + } + + test("Additional system properties resolve jars and set cluster-mode confs.") { + val allJars = Seq("local:///opt/spark/jar1.jar", "hdfs:///opt/spark/jar2.jar") + val allFiles = Seq("https://localhost:9000/file1.txt", "local:///opt/spark/file2.txt") + val sparkConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod") + .setJars(allJars) + .set("spark.files", allFiles.mkString(",")) + .set(CONTAINER_IMAGE, "spark-driver:latest") + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, + APP_NAME, + MAIN_CLASS, + APP_ARGS), + RESOURCE_NAME_PREFIX, + APP_ID, + DRIVER_LABELS, + DRIVER_ANNOTATIONS, + Map.empty, + Map.empty) + val step = new BasicDriverFeatureStep(kubernetesConf) + val additionalProperties = step.getAdditionalPodSystemProperties() + val expectedSparkConf = Map( + KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", + "spark.app.id" -> APP_ID, + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, + "spark.kubernetes.submitInDriver" -> "true", + "spark.jars" -> "/opt/spark/jar1.jar,hdfs:///opt/spark/jar2.jar", + "spark.files" -> "https://localhost:9000/file1.txt,/opt/spark/file2.txt") + assert(additionalProperties === expectedSparkConf) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala new file mode 100644 index 0000000000000..a764f7630b5c8 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model._ +import org.mockito.MockitoAnnotations +import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.rpc.RpcEndpointAddress +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend + +class BasicExecutorFeatureStepSuite + extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { + + private val APP_ID = "app-id" + private val DRIVER_HOSTNAME = "localhost" + private val DRIVER_PORT = 7098 + private val DRIVER_ADDRESS = RpcEndpointAddress( + DRIVER_HOSTNAME, + DRIVER_PORT.toInt, + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + private val DRIVER_POD_NAME = "driver-pod" + + private val DRIVER_POD_UID = "driver-uid" + private val RESOURCE_NAME_PREFIX = "base" + private val EXECUTOR_IMAGE = "executor-image" + private val LABELS = Map("label1key" -> "label1value") + private val ANNOTATIONS = Map("annotation1key" -> "annotation1value") + private val TEST_IMAGE_PULL_SECRETS = Seq("my-1secret-1", "my-secret-2") + private val TEST_IMAGE_PULL_SECRET_OBJECTS = + TEST_IMAGE_PULL_SECRETS.map { secret => + new LocalObjectReferenceBuilder().withName(secret).build() + } + private val DRIVER_POD = new PodBuilder() + .withNewMetadata() + .withName(DRIVER_POD_NAME) + .withUid(DRIVER_POD_UID) + .endMetadata() + .withNewSpec() + .withNodeName("some-node") + .endSpec() + .withNewStatus() + .withHostIP("192.168.99.100") + .endStatus() + .build() + private var baseConf: SparkConf = _ + + before { + MockitoAnnotations.initMocks(this) + baseConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME) + .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, RESOURCE_NAME_PREFIX) + .set(CONTAINER_IMAGE, EXECUTOR_IMAGE) + .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) + .set("spark.driver.host", DRIVER_HOSTNAME) + .set("spark.driver.port", DRIVER_PORT.toString) + .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) + } + + test("basic executor pod has reasonable defaults") { + val step = new BasicExecutorFeatureStep( + KubernetesConf( + baseConf, + KubernetesExecutorSpecificConf("1", DRIVER_POD), + RESOURCE_NAME_PREFIX, + APP_ID, + LABELS, + ANNOTATIONS, + Map.empty, + Map.empty)) + val executor = step.configurePod(SparkPod.initialPod()) + + // The executor pod name and default labels. + assert(executor.pod.getMetadata.getName === s"$RESOURCE_NAME_PREFIX-exec-1") + assert(executor.pod.getMetadata.getLabels.asScala === LABELS) + assert(executor.pod.getSpec.getImagePullSecrets.asScala === TEST_IMAGE_PULL_SECRET_OBJECTS) + + // There is exactly 1 container with no volume mounts and default memory limits. + // Default memory limit is 1024M + 384M (minimum overhead constant). + assert(executor.container.getImage === EXECUTOR_IMAGE) + assert(executor.container.getVolumeMounts.isEmpty) + assert(executor.container.getResources.getLimits.size() === 1) + assert(executor.container.getResources + .getLimits.get("memory").getAmount === "1408Mi") + + // The pod has no node selector, volumes. + assert(executor.pod.getSpec.getNodeSelector.isEmpty) + assert(executor.pod.getSpec.getVolumes.isEmpty) + + checkEnv(executor, Map()) + checkOwnerReferences(executor.pod, DRIVER_POD_UID) + } + + test("executor pod hostnames get truncated to 63 characters") { + val conf = baseConf.clone() + val longPodNamePrefix = "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple" + + val step = new BasicExecutorFeatureStep( + KubernetesConf( + conf, + KubernetesExecutorSpecificConf("1", DRIVER_POD), + longPodNamePrefix, + APP_ID, + LABELS, + ANNOTATIONS, + Map.empty, + Map.empty)) + assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63) + } + + test("classpath and extra java options get translated into environment variables") { + val conf = baseConf.clone() + conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") + conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") + + val step = new BasicExecutorFeatureStep( + KubernetesConf( + conf, + KubernetesExecutorSpecificConf("1", DRIVER_POD), + RESOURCE_NAME_PREFIX, + APP_ID, + LABELS, + ANNOTATIONS, + Map.empty, + Map("qux" -> "quux"))) + val executor = step.configurePod(SparkPod.initialPod()) + + checkEnv(executor, + Map("SPARK_JAVA_OPT_0" -> "foo=bar", + ENV_CLASSPATH -> "bar=baz", + "qux" -> "quux")) + checkOwnerReferences(executor.pod, DRIVER_POD_UID) + } + + // There is always exactly one controller reference, and it points to the driver pod. + private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { + assert(executor.getMetadata.getOwnerReferences.size() === 1) + assert(executor.getMetadata.getOwnerReferences.get(0).getUid === driverPodUid) + assert(executor.getMetadata.getOwnerReferences.get(0).getController === true) + } + + // Check that the expected environment variables are present. + private def checkEnv(executorPod: SparkPod, additionalEnvVars: Map[String, String]): Unit = { + val defaultEnvs = Map( + ENV_EXECUTOR_ID -> "1", + ENV_DRIVER_URL -> DRIVER_ADDRESS.toString, + ENV_EXECUTOR_CORES -> "1", + ENV_EXECUTOR_MEMORY -> "1g", + ENV_APPLICATION_ID -> APP_ID, + ENV_SPARK_CONF_DIR -> SPARK_CONF_DIR_INTERNAL, + ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars + + assert(executorPod.container.getEnv.size() === defaultEnvs.size) + val mapEnvs = executorPod.container.getEnv.asScala.map { + x => (x.getName, x.getValue) + }.toMap + assert(defaultEnvs === mapEnvs) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala similarity index 67% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala index 64553d25883bb..9f817d3bfc79a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala @@ -14,34 +14,35 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit.steps +package org.apache.spark.deploy.k8s.features import java.io.File -import scala.collection.JavaConverters._ - import com.google.common.base.Charsets import com.google.common.io.{BaseEncoding, Files} import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, Secret} +import org.mockito.{Mock, MockitoAnnotations} import org.scalatest.BeforeAndAfter +import scala.collection.JavaConverters._ import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec import org.apache.spark.util.Utils -class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndAfter { +class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { private val KUBERNETES_RESOURCE_NAME_PREFIX = "spark" + private val APP_ID = "k8s-app" private var credentialsTempDirectory: File = _ - private val BASE_DRIVER_SPEC = new KubernetesDriverSpec( - driverPod = new PodBuilder().build(), - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) + private val BASE_DRIVER_POD = SparkPod.initialPod() + + @Mock + private var driverSpecificConf: KubernetesDriverSpecificConf = _ before { + MockitoAnnotations.initMocks(this) credentialsTempDirectory = Utils.createTempDir() } @@ -50,13 +51,19 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA } test("Don't set any credentials") { - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - new SparkConf(false), KUBERNETES_RESOURCE_NAME_PREFIX) - val preparedDriverSpec = kubernetesCredentialsStep.configureDriver(BASE_DRIVER_SPEC) - assert(preparedDriverSpec.driverPod === BASE_DRIVER_SPEC.driverPod) - assert(preparedDriverSpec.driverContainer === BASE_DRIVER_SPEC.driverContainer) - assert(preparedDriverSpec.otherKubernetesResources.isEmpty) - assert(preparedDriverSpec.driverSparkConf.getAll.isEmpty) + val kubernetesConf = KubernetesConf( + new SparkConf(false), + driverSpecificConf, + KUBERNETES_RESOURCE_NAME_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) + val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) + assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) + assert(kubernetesCredentialsStep.getAdditionalPodSystemProperties().isEmpty) + assert(kubernetesCredentialsStep.getAdditionalKubernetesResources().isEmpty) } test("Only set credentials that are manually mounted.") { @@ -73,14 +80,23 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA .set( s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", "/mnt/secrets/my-ca.pem") + val kubernetesConf = KubernetesConf( + submissionSparkConf, + driverSpecificConf, + KUBERNETES_RESOURCE_NAME_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - submissionSparkConf, KUBERNETES_RESOURCE_NAME_PREFIX) - val preparedDriverSpec = kubernetesCredentialsStep.configureDriver(BASE_DRIVER_SPEC) - assert(preparedDriverSpec.driverPod === BASE_DRIVER_SPEC.driverPod) - assert(preparedDriverSpec.driverContainer === BASE_DRIVER_SPEC.driverContainer) - assert(preparedDriverSpec.otherKubernetesResources.isEmpty) - assert(preparedDriverSpec.driverSparkConf.getAll.toMap === submissionSparkConf.getAll.toMap) + val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) + assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) + assert(kubernetesCredentialsStep.getAdditionalKubernetesResources().isEmpty) + val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() + resolvedProperties.foreach { case (propKey, propValue) => + assert(submissionSparkConf.get(propKey) === propValue) + } } test("Mount credentials from the submission client as a secret.") { @@ -100,10 +116,17 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA .set( s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", caCertFile.getAbsolutePath) - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - submissionSparkConf, KUBERNETES_RESOURCE_NAME_PREFIX) - val preparedDriverSpec = kubernetesCredentialsStep.configureDriver( - BASE_DRIVER_SPEC.copy(driverSparkConf = submissionSparkConf)) + val kubernetesConf = KubernetesConf( + submissionSparkConf, + driverSpecificConf, + KUBERNETES_RESOURCE_NAME_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) + val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) + val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() val expectedSparkConf = Map( s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX" -> "", s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX" -> @@ -113,16 +136,13 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" -> DRIVER_CREDENTIALS_CLIENT_CERT_PATH, s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" -> - DRIVER_CREDENTIALS_CA_CERT_PATH, - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX" -> - clientKeyFile.getAbsolutePath, - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" -> - clientCertFile.getAbsolutePath, - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" -> - caCertFile.getAbsolutePath) - assert(preparedDriverSpec.driverSparkConf.getAll.toMap === expectedSparkConf) - assert(preparedDriverSpec.otherKubernetesResources.size === 1) - val credentialsSecret = preparedDriverSpec.otherKubernetesResources.head.asInstanceOf[Secret] + DRIVER_CREDENTIALS_CA_CERT_PATH) + assert(resolvedProperties === expectedSparkConf) + assert(kubernetesCredentialsStep.getAdditionalKubernetesResources().size === 1) + val credentialsSecret = kubernetesCredentialsStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Secret] assert(credentialsSecret.getMetadata.getName === s"$KUBERNETES_RESOURCE_NAME_PREFIX-kubernetes-credentials") val decodedSecretData = credentialsSecret.getData.asScala.map { data => @@ -134,12 +154,13 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME -> "key", DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME -> "cert") assert(decodedSecretData === expectedSecretData) - val driverPodVolumes = preparedDriverSpec.driverPod.getSpec.getVolumes.asScala + val driverPod = kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) + val driverPodVolumes = driverPod.pod.getSpec.getVolumes.asScala assert(driverPodVolumes.size === 1) assert(driverPodVolumes.head.getName === DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) assert(driverPodVolumes.head.getSecret != null) assert(driverPodVolumes.head.getSecret.getSecretName === credentialsSecret.getMetadata.getName) - val driverContainerVolumeMount = preparedDriverSpec.driverContainer.getVolumeMounts.asScala + val driverContainerVolumeMount = driverPod.container.getVolumeMounts.asScala assert(driverContainerVolumeMount.size === 1) assert(driverContainerVolumeMount.head.getName === DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) assert(driverContainerVolumeMount.head.getMountPath === DRIVER_CREDENTIALS_SECRETS_BASE_DIR) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala new file mode 100644 index 0000000000000..c299d56865ec0 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.Service +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Mockito.when +import org.scalatest.BeforeAndAfter +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.util.Clock + +class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { + + private val SHORT_RESOURCE_NAME_PREFIX = + "a" * (DriverServiceFeatureStep.MAX_SERVICE_NAME_LENGTH - + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX.length) + + private val LONG_RESOURCE_NAME_PREFIX = + "a" * (DriverServiceFeatureStep.MAX_SERVICE_NAME_LENGTH - + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX.length + 1) + private val DRIVER_LABELS = Map( + "label1key" -> "label1value", + "label2key" -> "label2value") + + @Mock + private var clock: Clock = _ + + private var sparkConf: SparkConf = _ + + before { + MockitoAnnotations.initMocks(this) + sparkConf = new SparkConf(false) + } + + test("Headless service has a port for the driver RPC and the block manager.") { + sparkConf = sparkConf + .set("spark.driver.port", "9000") + .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080) + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + SHORT_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty)) + assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod()) + assert(configurationStep.getAdditionalKubernetesResources().size === 1) + assert(configurationStep.getAdditionalKubernetesResources().head.isInstanceOf[Service]) + val driverService = configurationStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Service] + verifyService( + 9000, + 8080, + s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}", + driverService) + } + + test("Hostname and ports are set according to the service name.") { + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf + .set("spark.driver.port", "9000") + .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080) + .set(KUBERNETES_NAMESPACE, "my-namespace"), + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + SHORT_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty)) + val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX + val expectedHostName = s"$expectedServiceName.my-namespace.svc" + val additionalProps = configurationStep.getAdditionalPodSystemProperties() + verifySparkConfHostNames(additionalProps, expectedHostName) + } + + test("Ports should resolve to defaults in SparkConf and in the service.") { + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + SHORT_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty)) + val resolvedService = configurationStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Service] + verifyService( + DEFAULT_DRIVER_PORT, + DEFAULT_BLOCKMANAGER_PORT, + s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}", + resolvedService) + val additionalProps = configurationStep.getAdditionalPodSystemProperties() + assert(additionalProps("spark.driver.port") === DEFAULT_DRIVER_PORT.toString) + assert(additionalProps(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key) + === DEFAULT_BLOCKMANAGER_PORT.toString) + } + + test("Long prefixes should switch to using a generated name.") { + when(clock.getTimeMillis()).thenReturn(10000) + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf.set(KUBERNETES_NAMESPACE, "my-namespace"), + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + LONG_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty), + clock) + val driverService = configurationStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Service] + val expectedServiceName = s"spark-10000${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}" + assert(driverService.getMetadata.getName === expectedServiceName) + val expectedHostName = s"$expectedServiceName.my-namespace.svc" + val additionalProps = configurationStep.getAdditionalPodSystemProperties() + verifySparkConfHostNames(additionalProps, expectedHostName) + } + + test("Disallow bind address and driver host to be set explicitly.") { + try { + new DriverServiceFeatureStep( + KubernetesConf( + sparkConf.set(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS, "host"), + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + LONG_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty), + clock) + fail("The driver bind address should not be allowed.") + } catch { + case e: Throwable => + assert(e.getMessage === + s"requirement failed: ${DriverServiceFeatureStep.DRIVER_BIND_ADDRESS_KEY} is" + + " not supported in Kubernetes mode, as the driver's bind address is managed" + + " and set to the driver pod's IP address.") + } + sparkConf.remove(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS) + sparkConf.set(org.apache.spark.internal.config.DRIVER_HOST_ADDRESS, "host") + try { + new DriverServiceFeatureStep( + KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + LONG_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty), + clock) + fail("The driver host address should not be allowed.") + } catch { + case e: Throwable => + assert(e.getMessage === + s"requirement failed: ${DriverServiceFeatureStep.DRIVER_HOST_KEY} is" + + " not supported in Kubernetes mode, as the driver's hostname will be managed via" + + " a Kubernetes service.") + } + } + + private def verifyService( + driverPort: Int, + blockManagerPort: Int, + expectedServiceName: String, + service: Service): Unit = { + assert(service.getMetadata.getName === expectedServiceName) + assert(service.getSpec.getClusterIP === "None") + assert(service.getSpec.getSelector.asScala === DRIVER_LABELS) + assert(service.getSpec.getPorts.size() === 2) + val driverServicePorts = service.getSpec.getPorts.asScala + assert(driverServicePorts.head.getName === DRIVER_PORT_NAME) + assert(driverServicePorts.head.getPort.intValue() === driverPort) + assert(driverServicePorts.head.getTargetPort.getIntVal === driverPort) + assert(driverServicePorts(1).getName === BLOCK_MANAGER_PORT_NAME) + assert(driverServicePorts(1).getPort.intValue() === blockManagerPort) + assert(driverServicePorts(1).getTargetPort.getIntVal === blockManagerPort) + } + + private def verifySparkConfHostNames( + driverSparkConf: Map[String, String], expectedHostName: String): Unit = { + assert(driverSparkConf( + org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key) === expectedHostName) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala new file mode 100644 index 0000000000000..27bff74ce38af --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.{HasMetadata, PodBuilder, SecretBuilder} +import org.mockito.Matchers +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer + +import org.apache.spark.deploy.k8s.SparkPod + +object KubernetesFeaturesTestUtils { + + def getMockConfigStepForStepType[T <: KubernetesFeatureConfigStep]( + stepType: String, stepClass: Class[T]): T = { + val mockStep = mock(stepClass) + when(mockStep.getAdditionalKubernetesResources()).thenReturn( + getSecretsForStepType(stepType)) + + when(mockStep.getAdditionalPodSystemProperties()) + .thenReturn(Map(stepType -> stepType)) + when(mockStep.configurePod(Matchers.any(classOf[SparkPod]))) + .thenAnswer(new Answer[SparkPod]() { + override def answer(invocation: InvocationOnMock): SparkPod = { + val originalPod = invocation.getArgumentAt(0, classOf[SparkPod]) + val configuredPod = new PodBuilder(originalPod.pod) + .editOrNewMetadata() + .addToLabels(stepType, stepType) + .endMetadata() + .build() + SparkPod(configuredPod, originalPod.container) + } + }) + mockStep + } + + def getSecretsForStepType[T <: KubernetesFeatureConfigStep](stepType: String) + : Seq[HasMetadata] = { + Seq(new SecretBuilder() + .withNewMetadata() + .withName(stepType) + .endMetadata() + .build()) + } + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala similarity index 64% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala index 960d0bda1d011..9d02f56cc206d 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala @@ -14,29 +14,38 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit.steps +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.PodBuilder import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SecretVolumeUtils, SparkPod} -class DriverMountSecretsStepSuite extends SparkFunSuite { +class MountSecretsFeatureStepSuite extends SparkFunSuite { private val SECRET_FOO = "foo" private val SECRET_BAR = "bar" private val SECRET_MOUNT_PATH = "/etc/secrets/driver" test("mounts all given secrets") { - val baseDriverSpec = KubernetesDriverSpec.initialSpec(new SparkConf(false)) + val baseDriverPod = SparkPod.initialPod() val secretNamesToMountPaths = Map( SECRET_FOO -> SECRET_MOUNT_PATH, SECRET_BAR -> SECRET_MOUNT_PATH) + val sparkConf = new SparkConf(false) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesExecutorSpecificConf("1", new PodBuilder().build()), + "resource-name-prefix", + "app-id", + Map.empty, + Map.empty, + secretNamesToMountPaths, + Map.empty) - val mountSecretsBootstrap = new MountSecretsBootstrap(secretNamesToMountPaths) - val mountSecretsStep = new DriverMountSecretsStep(mountSecretsBootstrap) - val configuredDriverSpec = mountSecretsStep.configureDriver(baseDriverSpec) - val driverPodWithSecretsMounted = configuredDriverSpec.driverPod - val driverContainerWithSecretsMounted = configuredDriverSpec.driverContainer + val step = new MountSecretsFeatureStep(kubernetesConf) + val driverPodWithSecretsMounted = step.configurePod(baseDriverPod).pod + val driverContainerWithSecretsMounted = step.configurePod(baseDriverPod).container Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach { volumeName => assert(SecretVolumeUtils.podHasVolume(driverPodWithSecretsMounted, volumeName)) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index 6a501592f42a3..c1b203e03a357 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -16,22 +16,17 @@ */ package org.apache.spark.deploy.k8s.submit -import scala.collection.JavaConverters._ - -import com.google.common.collect.Iterables import io.fabric8.kubernetes.api.model._ import io.fabric8.kubernetes.client.{KubernetesClient, Watch} import io.fabric8.kubernetes.client.dsl.{MixedOperation, NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable, PodResource} import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} import org.mockito.Mockito.{doReturn, verify, when} -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfter import org.scalatest.mockito.MockitoSugar._ import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.steps.DriverConfigurationStep class ClientSuite extends SparkFunSuite with BeforeAndAfter { @@ -39,6 +34,74 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { private val DRIVER_POD_API_VERSION = "v1" private val DRIVER_POD_KIND = "pod" private val KUBERNETES_RESOURCE_PREFIX = "resource-example" + private val POD_NAME = "driver" + private val CONTAINER_NAME = "container" + private val APP_ID = "app-id" + private val APP_NAME = "app" + private val MAIN_CLASS = "main" + private val APP_ARGS = Seq("arg1", "arg2") + private val RESOLVED_JAVA_OPTIONS = Map( + "conf1key" -> "conf1value", + "conf2key" -> "conf2value") + private val BUILT_DRIVER_POD = + new PodBuilder() + .withNewMetadata() + .withName(POD_NAME) + .endMetadata() + .withNewSpec() + .withHostname("localhost") + .endSpec() + .build() + private val BUILT_DRIVER_CONTAINER = new ContainerBuilder().withName(CONTAINER_NAME).build() + private val ADDITIONAL_RESOURCES = Seq( + new SecretBuilder().withNewMetadata().withName("secret").endMetadata().build()) + + private val BUILT_KUBERNETES_SPEC = KubernetesDriverSpec( + SparkPod(BUILT_DRIVER_POD, BUILT_DRIVER_CONTAINER), + ADDITIONAL_RESOURCES, + RESOLVED_JAVA_OPTIONS) + + private val FULL_EXPECTED_CONTAINER = new ContainerBuilder(BUILT_DRIVER_CONTAINER) + .addNewEnv() + .withName(ENV_SPARK_CONF_DIR) + .withValue(SPARK_CONF_DIR_INTERNAL) + .endEnv() + .addNewVolumeMount() + .withName(SPARK_CONF_VOLUME) + .withMountPath(SPARK_CONF_DIR_INTERNAL) + .endVolumeMount() + .build() + private val FULL_EXPECTED_POD = new PodBuilder(BUILT_DRIVER_POD) + .editSpec() + .addToContainers(FULL_EXPECTED_CONTAINER) + .addNewVolume() + .withName(SPARK_CONF_VOLUME) + .withNewConfigMap().withName(s"$KUBERNETES_RESOURCE_PREFIX-driver-conf-map").endConfigMap() + .endVolume() + .endSpec() + .build() + + private val POD_WITH_OWNER_REFERENCE = new PodBuilder(FULL_EXPECTED_POD) + .editMetadata() + .withUid(DRIVER_POD_UID) + .endMetadata() + .withApiVersion(DRIVER_POD_API_VERSION) + .withKind(DRIVER_POD_KIND) + .build() + + private val ADDITIONAL_RESOURCES_WITH_OWNER_REFERENCES = ADDITIONAL_RESOURCES.map { secret => + new SecretBuilder(secret) + .editMetadata() + .addNewOwnerReference() + .withName(POD_NAME) + .withApiVersion(DRIVER_POD_API_VERSION) + .withKind(DRIVER_POD_KIND) + .withController(true) + .withUid(DRIVER_POD_UID) + .endOwnerReference() + .endMetadata() + .build() + } private type ResourceList = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ HasMetadata, Boolean] @@ -56,113 +119,86 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { @Mock private var loggingPodStatusWatcher: LoggingPodStatusWatcher = _ + @Mock + private var driverBuilder: KubernetesDriverBuilder = _ + @Mock private var resourceList: ResourceList = _ - private val submissionSteps = Seq(FirstTestConfigurationStep, SecondTestConfigurationStep) + private var kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf] = _ + + private var sparkConf: SparkConf = _ private var createdPodArgumentCaptor: ArgumentCaptor[Pod] = _ private var createdResourcesArgumentCaptor: ArgumentCaptor[HasMetadata] = _ - private var createdContainerArgumentCaptor: ArgumentCaptor[Container] = _ before { MockitoAnnotations.initMocks(this) + sparkConf = new SparkConf(false) + kubernetesConf = KubernetesConf[KubernetesDriverSpecificConf]( + sparkConf, + KubernetesDriverSpecificConf(None, MAIN_CLASS, APP_NAME, APP_ARGS), + KUBERNETES_RESOURCE_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) + when(driverBuilder.buildFromFeatures(kubernetesConf)).thenReturn(BUILT_KUBERNETES_SPEC) when(kubernetesClient.pods()).thenReturn(podOperations) - when(podOperations.withName(FirstTestConfigurationStep.podName)).thenReturn(namedPods) + when(podOperations.withName(POD_NAME)).thenReturn(namedPods) createdPodArgumentCaptor = ArgumentCaptor.forClass(classOf[Pod]) createdResourcesArgumentCaptor = ArgumentCaptor.forClass(classOf[HasMetadata]) - when(podOperations.create(createdPodArgumentCaptor.capture())).thenAnswer(new Answer[Pod] { - override def answer(invocation: InvocationOnMock): Pod = { - new PodBuilder(invocation.getArgumentAt(0, classOf[Pod])) - .editMetadata() - .withUid(DRIVER_POD_UID) - .endMetadata() - .withApiVersion(DRIVER_POD_API_VERSION) - .withKind(DRIVER_POD_KIND) - .build() - } - }) - when(podOperations.withName(FirstTestConfigurationStep.podName)).thenReturn(namedPods) + when(podOperations.create(FULL_EXPECTED_POD)).thenReturn(POD_WITH_OWNER_REFERENCE) when(namedPods.watch(loggingPodStatusWatcher)).thenReturn(mock[Watch]) doReturn(resourceList) .when(kubernetesClient) .resourceList(createdResourcesArgumentCaptor.capture()) } - test("The client should configure the pod with the submission steps.") { + test("The client should configure the pod using the builder.") { val submissionClient = new Client( - submissionSteps, - new SparkConf(false), + driverBuilder, + kubernetesConf, kubernetesClient, false, "spark", loggingPodStatusWatcher, KUBERNETES_RESOURCE_PREFIX) submissionClient.run() - val createdPod = createdPodArgumentCaptor.getValue - assert(createdPod.getMetadata.getName === FirstTestConfigurationStep.podName) - assert(createdPod.getMetadata.getLabels.asScala === - Map(FirstTestConfigurationStep.labelKey -> FirstTestConfigurationStep.labelValue)) - assert(createdPod.getMetadata.getAnnotations.asScala === - Map(SecondTestConfigurationStep.annotationKey -> - SecondTestConfigurationStep.annotationValue)) - assert(createdPod.getSpec.getContainers.size() === 1) - assert(createdPod.getSpec.getContainers.get(0).getName === - SecondTestConfigurationStep.containerName) + verify(podOperations).create(FULL_EXPECTED_POD) } test("The client should create Kubernetes resources") { - val EXAMPLE_JAVA_OPTS = "-XX:+HeapDumpOnOutOfMemoryError -XX:+PrintGCDetails" - val EXPECTED_JAVA_OPTS = "-XX\\:+HeapDumpOnOutOfMemoryError -XX\\:+PrintGCDetails" val submissionClient = new Client( - submissionSteps, - new SparkConf(false) - .set(org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS, EXAMPLE_JAVA_OPTS), + driverBuilder, + kubernetesConf, kubernetesClient, false, "spark", loggingPodStatusWatcher, KUBERNETES_RESOURCE_PREFIX) submissionClient.run() - val createdPod = createdPodArgumentCaptor.getValue val otherCreatedResources = createdResourcesArgumentCaptor.getAllValues assert(otherCreatedResources.size === 2) - val secrets = otherCreatedResources.toArray - .filter(_.isInstanceOf[Secret]).map(_.asInstanceOf[Secret]) + val secrets = otherCreatedResources.toArray.filter(_.isInstanceOf[Secret]).toSeq + assert(secrets === ADDITIONAL_RESOURCES_WITH_OWNER_REFERENCES) val configMaps = otherCreatedResources.toArray .filter(_.isInstanceOf[ConfigMap]).map(_.asInstanceOf[ConfigMap]) assert(secrets.nonEmpty) - val secret = secrets.head - assert(secret.getMetadata.getName === FirstTestConfigurationStep.secretName) - assert(secret.getData.asScala === - Map(FirstTestConfigurationStep.secretKey -> FirstTestConfigurationStep.secretData)) - val ownerReference = Iterables.getOnlyElement(secret.getMetadata.getOwnerReferences) - assert(ownerReference.getName === createdPod.getMetadata.getName) - assert(ownerReference.getKind === DRIVER_POD_KIND) - assert(ownerReference.getUid === DRIVER_POD_UID) - assert(ownerReference.getApiVersion === DRIVER_POD_API_VERSION) assert(configMaps.nonEmpty) val configMap = configMaps.head assert(configMap.getMetadata.getName === s"$KUBERNETES_RESOURCE_PREFIX-driver-conf-map") assert(configMap.getData.containsKey(SPARK_CONF_FILE_NAME)) - assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains(EXPECTED_JAVA_OPTS)) - assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains( - "spark.custom-conf=custom-conf-value")) - val driverContainer = Iterables.getOnlyElement(createdPod.getSpec.getContainers) - assert(driverContainer.getName === SecondTestConfigurationStep.containerName) - val driverEnv = driverContainer.getEnv.asScala.head - assert(driverEnv.getName === ENV_SPARK_CONF_DIR) - assert(driverEnv.getValue === SPARK_CONF_DIR_INTERNAL) - val driverMount = driverContainer.getVolumeMounts.asScala.head - assert(driverMount.getName === SPARK_CONF_VOLUME) - assert(driverMount.getMountPath === SPARK_CONF_DIR_INTERNAL) + assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains("conf1key=conf1value")) + assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains("conf2key=conf2value")) } test("Waiting for app completion should stall on the watcher") { val submissionClient = new Client( - submissionSteps, - new SparkConf(false), + driverBuilder, + kubernetesConf, kubernetesClient, true, "spark", @@ -171,56 +207,4 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { submissionClient.run() verify(loggingPodStatusWatcher).awaitCompletion() } - -} - -private object FirstTestConfigurationStep extends DriverConfigurationStep { - - val podName = "test-pod" - val secretName = "test-secret" - val labelKey = "first-submit" - val labelValue = "true" - val secretKey = "secretKey" - val secretData = "secretData" - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val modifiedPod = new PodBuilder(driverSpec.driverPod) - .editMetadata() - .withName(podName) - .addToLabels(labelKey, labelValue) - .endMetadata() - .build() - val additionalResource = new SecretBuilder() - .withNewMetadata() - .withName(secretName) - .endMetadata() - .addToData(secretKey, secretData) - .build() - driverSpec.copy( - driverPod = modifiedPod, - otherKubernetesResources = driverSpec.otherKubernetesResources ++ Seq(additionalResource)) - } -} - -private object SecondTestConfigurationStep extends DriverConfigurationStep { - val annotationKey = "second-submit" - val annotationValue = "submitted" - val sparkConfKey = "spark.custom-conf" - val sparkConfValue = "custom-conf-value" - val containerName = "driverContainer" - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val modifiedPod = new PodBuilder(driverSpec.driverPod) - .editMetadata() - .addToAnnotations(annotationKey, annotationValue) - .endMetadata() - .build() - val resolvedSparkConf = driverSpec.driverSparkConf.clone().set(sparkConfKey, sparkConfValue) - val modifiedContainer = new ContainerBuilder(driverSpec.driverContainer) - .withName(containerName) - .build() - driverSpec.copy( - driverPod = modifiedPod, - driverSparkConf = resolvedSparkConf, - driverContainer = modifiedContainer) - } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala deleted file mode 100644 index df34d2dbcb5be..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit - -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.submit.steps._ - -class DriverConfigOrchestratorSuite extends SparkFunSuite { - - private val DRIVER_IMAGE = "driver-image" - private val IC_IMAGE = "init-container-image" - private val APP_ID = "spark-app-id" - private val KUBERNETES_RESOURCE_PREFIX = "example-prefix" - private val APP_NAME = "spark" - private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" - private val APP_ARGS = Array("arg1", "arg2") - private val SECRET_FOO = "foo" - private val SECRET_BAR = "bar" - private val SECRET_MOUNT_PATH = "/etc/secrets/driver" - - test("Base submission steps with a main app resource.") { - val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE) - val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Some(mainAppResource), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep], - classOf[DependencyResolutionStep]) - } - - test("Base submission steps without a main app resource.") { - val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE) - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Option.empty, - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep]) - } - - test("Submission steps with driver secrets to mount") { - val sparkConf = new SparkConf(false) - .set(CONTAINER_IMAGE, DRIVER_IMAGE) - .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) - .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) - val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Some(mainAppResource), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep], - classOf[DependencyResolutionStep], - classOf[DriverMountSecretsStep]) - } - - test("Submission using client local dependencies") { - val sparkConf = new SparkConf(false) - .set(CONTAINER_IMAGE, DRIVER_IMAGE) - var orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Some(JavaMainAppResource("file:///var/apps/jars/main.jar")), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - assertThrows[SparkException] { - orchestrator.getAllConfigurationSteps - } - - sparkConf.set("spark.files", "/path/to/file1,/path/to/file2") - orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Some(JavaMainAppResource("local:///var/apps/jars/main.jar")), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - assertThrows[SparkException] { - orchestrator.getAllConfigurationSteps - } - } - - private def validateStepTypes( - orchestrator: DriverConfigOrchestrator, - types: Class[_ <: DriverConfigurationStep]*): Unit = { - val steps = orchestrator.getAllConfigurationSteps - assert(steps.size === types.size) - assert(steps.map(_.getClass) === types) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala new file mode 100644 index 0000000000000..161f9afe7bba9 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep} + +class KubernetesDriverBuilderSuite extends SparkFunSuite { + + private val BASIC_STEP_TYPE = "basic" + private val CREDENTIALS_STEP_TYPE = "credentials" + private val SERVICE_STEP_TYPE = "service" + private val SECRETS_STEP_TYPE = "mount-secrets" + + private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + BASIC_STEP_TYPE, classOf[BasicDriverFeatureStep]) + + private val credentialsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + CREDENTIALS_STEP_TYPE, classOf[DriverKubernetesCredentialsFeatureStep]) + + private val serviceStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + SERVICE_STEP_TYPE, classOf[DriverServiceFeatureStep]) + + private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + + private val builderUnderTest: KubernetesDriverBuilder = + new KubernetesDriverBuilder( + _ => basicFeatureStep, + _ => credentialsStep, + _ => serviceStep, + _ => secretsStep) + + test("Apply fundamental steps all the time.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + None, + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE) + } + + test("Apply secrets step if secrets are present.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + None, + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map("secret" -> "secretMountPath"), + Map.empty) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + SECRETS_STEP_TYPE) + } + + private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) + : Unit = { + assert(resolvedSpec.systemProperties.size === stepTypes.size) + stepTypes.foreach { stepType => + assert(resolvedSpec.pod.pod.getMetadata.getLabels.get(stepType) === stepType) + assert(resolvedSpec.driverKubernetesResources.containsSlice( + KubernetesFeaturesTestUtils.getSecretsForStepType(stepType))) + assert(resolvedSpec.systemProperties(stepType) === stepType) + } + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala deleted file mode 100644 index ee450fff8d376..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -class BasicDriverConfigurationStepSuite extends SparkFunSuite { - - private val APP_ID = "spark-app-id" - private val RESOURCE_NAME_PREFIX = "spark" - private val DRIVER_LABELS = Map("labelkey" -> "labelvalue") - private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" - private val APP_NAME = "spark-test" - private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" - private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") - private val CUSTOM_ANNOTATION_KEY = "customAnnotation" - private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" - private val DRIVER_CUSTOM_ENV_KEY1 = "customDriverEnv1" - private val DRIVER_CUSTOM_ENV_KEY2 = "customDriverEnv2" - - test("Set all possible configurations from the user.") { - val sparkConf = new SparkConf() - .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod") - .set(org.apache.spark.internal.config.DRIVER_CLASS_PATH, "/opt/spark/spark-examples.jar") - .set("spark.driver.cores", "2") - .set(KUBERNETES_DRIVER_LIMIT_CORES, "4") - .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "256M") - .set(org.apache.spark.internal.config.DRIVER_MEMORY_OVERHEAD, 200L) - .set(CONTAINER_IMAGE, "spark-driver:latest") - .set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$CUSTOM_ANNOTATION_KEY", CUSTOM_ANNOTATION_VALUE) - .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY1", "customDriverEnv1") - .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY2", "customDriverEnv2") - .set(IMAGE_PULL_SECRETS, "imagePullSecret1, imagePullSecret2") - - val submissionStep = new BasicDriverConfigurationStep( - APP_ID, - RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - CONTAINER_IMAGE_PULL_POLICY, - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - val basePod = new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build() - val baseDriverSpec = KubernetesDriverSpec( - driverPod = basePod, - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) - val preparedDriverSpec = submissionStep.configureDriver(baseDriverSpec) - - assert(preparedDriverSpec.driverContainer.getName === DRIVER_CONTAINER_NAME) - assert(preparedDriverSpec.driverContainer.getImage === "spark-driver:latest") - assert(preparedDriverSpec.driverContainer.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY) - - assert(preparedDriverSpec.driverContainer.getEnv.size === 4) - val envs = preparedDriverSpec.driverContainer - .getEnv - .asScala - .map(env => (env.getName, env.getValue)) - .toMap - assert(envs(ENV_CLASSPATH) === "/opt/spark/spark-examples.jar") - assert(envs(DRIVER_CUSTOM_ENV_KEY1) === "customDriverEnv1") - assert(envs(DRIVER_CUSTOM_ENV_KEY2) === "customDriverEnv2") - - assert(preparedDriverSpec.driverContainer.getEnv.asScala.exists(envVar => - envVar.getName.equals(ENV_DRIVER_BIND_ADDRESS) && - envVar.getValueFrom.getFieldRef.getApiVersion.equals("v1") && - envVar.getValueFrom.getFieldRef.getFieldPath.equals("status.podIP"))) - - val resourceRequirements = preparedDriverSpec.driverContainer.getResources - val requests = resourceRequirements.getRequests.asScala - assert(requests("cpu").getAmount === "2") - assert(requests("memory").getAmount === "456Mi") - val limits = resourceRequirements.getLimits.asScala - assert(limits("memory").getAmount === "456Mi") - assert(limits("cpu").getAmount === "4") - - val driverPodMetadata = preparedDriverSpec.driverPod.getMetadata - assert(driverPodMetadata.getName === "spark-driver-pod") - assert(driverPodMetadata.getLabels.asScala === DRIVER_LABELS) - val expectedAnnotations = Map( - CUSTOM_ANNOTATION_KEY -> CUSTOM_ANNOTATION_VALUE, - SPARK_APP_NAME_ANNOTATION -> APP_NAME) - assert(driverPodMetadata.getAnnotations.asScala === expectedAnnotations) - - val driverPodSpec = preparedDriverSpec.driverPod.getSpec - assert(driverPodSpec.getRestartPolicy === "Never") - assert(driverPodSpec.getImagePullSecrets.size() === 2) - assert(driverPodSpec.getImagePullSecrets.get(0).getName === "imagePullSecret1") - assert(driverPodSpec.getImagePullSecrets.get(1).getName === "imagePullSecret2") - - val resolvedSparkConf = preparedDriverSpec.driverSparkConf.getAll.toMap - val expectedSparkConf = Map( - KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", - "spark.app.id" -> APP_ID, - KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, - "spark.kubernetes.submitInDriver" -> "true") - assert(resolvedSparkConf === expectedSparkConf) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala deleted file mode 100644 index ca43fc97dc991..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.File - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -class DependencyResolutionStepSuite extends SparkFunSuite { - - private val SPARK_JARS = Seq( - "apps/jars/jar1.jar", - "local:///var/apps/jars/jar2.jar") - - private val SPARK_FILES = Seq( - "apps/files/file1.txt", - "local:///var/apps/files/file2.txt") - - test("Added dependencies should be resolved in Spark configuration and environment") { - val dependencyResolutionStep = new DependencyResolutionStep( - SPARK_JARS, - SPARK_FILES) - val driverPod = new PodBuilder().build() - val baseDriverSpec = KubernetesDriverSpec( - driverPod = driverPod, - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) - val preparedDriverSpec = dependencyResolutionStep.configureDriver(baseDriverSpec) - assert(preparedDriverSpec.driverPod === driverPod) - assert(preparedDriverSpec.otherKubernetesResources.isEmpty) - val resolvedSparkJars = preparedDriverSpec.driverSparkConf.get("spark.jars").split(",").toSet - val expectedResolvedSparkJars = Set( - "apps/jars/jar1.jar", - "/var/apps/jars/jar2.jar") - assert(resolvedSparkJars === expectedResolvedSparkJars) - val resolvedSparkFiles = preparedDriverSpec.driverSparkConf.get("spark.files").split(",").toSet - val expectedResolvedSparkFiles = Set( - "apps/files/file1.txt", - "/var/apps/files/file2.txt") - assert(resolvedSparkFiles === expectedResolvedSparkFiles) - val driverEnv = preparedDriverSpec.driverContainer.getEnv.asScala - assert(driverEnv.size === 1) - assert(driverEnv.head.getName === ENV_MOUNTED_CLASSPATH) - val resolvedDriverClasspath = driverEnv.head.getValue.split(File.pathSeparator).toSet - val expectedResolvedDriverClasspath = expectedResolvedSparkJars - assert(resolvedDriverClasspath === expectedResolvedDriverClasspath) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala deleted file mode 100644 index 78c8c3ba1afbd..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.Service -import org.mockito.{Mock, MockitoAnnotations} -import org.mockito.Mockito.when -import org.scalatest.BeforeAndAfter - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.util.Clock - -class DriverServiceBootstrapStepSuite extends SparkFunSuite with BeforeAndAfter { - - private val SHORT_RESOURCE_NAME_PREFIX = - "a" * (DriverServiceBootstrapStep.MAX_SERVICE_NAME_LENGTH - - DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX.length) - - private val LONG_RESOURCE_NAME_PREFIX = - "a" * (DriverServiceBootstrapStep.MAX_SERVICE_NAME_LENGTH - - DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX.length + 1) - private val DRIVER_LABELS = Map( - "label1key" -> "label1value", - "label2key" -> "label2value") - - @Mock - private var clock: Clock = _ - - private var sparkConf: SparkConf = _ - - before { - MockitoAnnotations.initMocks(this) - sparkConf = new SparkConf(false) - } - - test("Headless service has a port for the driver RPC and the block manager.") { - val configurationStep = new DriverServiceBootstrapStep( - SHORT_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf - .set("spark.driver.port", "9000") - .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080), - clock) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - assert(resolvedDriverSpec.otherKubernetesResources.size === 1) - assert(resolvedDriverSpec.otherKubernetesResources.head.isInstanceOf[Service]) - val driverService = resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service] - verifyService( - 9000, - 8080, - s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}", - driverService) - } - - test("Hostname and ports are set according to the service name.") { - val configurationStep = new DriverServiceBootstrapStep( - SHORT_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf - .set("spark.driver.port", "9000") - .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080) - .set(KUBERNETES_NAMESPACE, "my-namespace"), - clock) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + - DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX - val expectedHostName = s"$expectedServiceName.my-namespace.svc" - verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName) - } - - test("Ports should resolve to defaults in SparkConf and in the service.") { - val configurationStep = new DriverServiceBootstrapStep( - SHORT_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf, - clock) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - verifyService( - DEFAULT_DRIVER_PORT, - DEFAULT_BLOCKMANAGER_PORT, - s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}", - resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service]) - assert(resolvedDriverSpec.driverSparkConf.get("spark.driver.port") === - DEFAULT_DRIVER_PORT.toString) - assert(resolvedDriverSpec.driverSparkConf.get( - org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT) === DEFAULT_BLOCKMANAGER_PORT) - } - - test("Long prefixes should switch to using a generated name.") { - val configurationStep = new DriverServiceBootstrapStep( - LONG_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf.set(KUBERNETES_NAMESPACE, "my-namespace"), - clock) - when(clock.getTimeMillis()).thenReturn(10000) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - val driverService = resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service] - val expectedServiceName = s"spark-10000${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}" - assert(driverService.getMetadata.getName === expectedServiceName) - val expectedHostName = s"$expectedServiceName.my-namespace.svc" - verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName) - } - - test("Disallow bind address and driver host to be set explicitly.") { - val configurationStep = new DriverServiceBootstrapStep( - LONG_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf.set(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS, "host"), - clock) - try { - configurationStep.configureDriver(KubernetesDriverSpec.initialSpec(sparkConf)) - fail("The driver bind address should not be allowed.") - } catch { - case e: Throwable => - assert(e.getMessage === - s"requirement failed: ${DriverServiceBootstrapStep.DRIVER_BIND_ADDRESS_KEY} is" + - " not supported in Kubernetes mode, as the driver's bind address is managed" + - " and set to the driver pod's IP address.") - } - sparkConf.remove(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS) - sparkConf.set(org.apache.spark.internal.config.DRIVER_HOST_ADDRESS, "host") - try { - configurationStep.configureDriver(KubernetesDriverSpec.initialSpec(sparkConf)) - fail("The driver host address should not be allowed.") - } catch { - case e: Throwable => - assert(e.getMessage === - s"requirement failed: ${DriverServiceBootstrapStep.DRIVER_HOST_KEY} is" + - " not supported in Kubernetes mode, as the driver's hostname will be managed via" + - " a Kubernetes service.") - } - } - - private def verifyService( - driverPort: Int, - blockManagerPort: Int, - expectedServiceName: String, - service: Service): Unit = { - assert(service.getMetadata.getName === expectedServiceName) - assert(service.getSpec.getClusterIP === "None") - assert(service.getSpec.getSelector.asScala === DRIVER_LABELS) - assert(service.getSpec.getPorts.size() === 2) - val driverServicePorts = service.getSpec.getPorts.asScala - assert(driverServicePorts.head.getName === DRIVER_PORT_NAME) - assert(driverServicePorts.head.getPort.intValue() === driverPort) - assert(driverServicePorts.head.getTargetPort.getIntVal === driverPort) - assert(driverServicePorts(1).getName === BLOCK_MANAGER_PORT_NAME) - assert(driverServicePorts(1).getPort.intValue() === blockManagerPort) - assert(driverServicePorts(1).getTargetPort.getIntVal === blockManagerPort) - } - - private def verifySparkConfHostNames( - driverSparkConf: SparkConf, expectedHostName: String): Unit = { - assert(driverSparkConf.get( - org.apache.spark.internal.config.DRIVER_HOST_ADDRESS) === expectedHostName) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala deleted file mode 100644 index d73df20f0f956..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.scheduler.cluster.k8s - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model._ -import org.mockito.MockitoAnnotations -import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.MountSecretsBootstrap - -class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { - - private val driverPodName: String = "driver-pod" - private val driverPodUid: String = "driver-uid" - private val executorPrefix: String = "base" - private val executorImage: String = "executor-image" - private val imagePullSecrets: String = "imagePullSecret1, imagePullSecret2" - private val driverPod = new PodBuilder() - .withNewMetadata() - .withName(driverPodName) - .withUid(driverPodUid) - .endMetadata() - .withNewSpec() - .withNodeName("some-node") - .endSpec() - .withNewStatus() - .withHostIP("192.168.99.100") - .endStatus() - .build() - private var baseConf: SparkConf = _ - - before { - MockitoAnnotations.initMocks(this) - baseConf = new SparkConf() - .set(KUBERNETES_DRIVER_POD_NAME, driverPodName) - .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) - .set(CONTAINER_IMAGE, executorImage) - .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) - .set(IMAGE_PULL_SECRETS, imagePullSecrets) - } - - test("basic executor pod has reasonable defaults") { - val factory = new ExecutorPodFactory(baseConf, None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - // The executor pod name and default labels. - assert(executor.getMetadata.getName === s"$executorPrefix-exec-1") - assert(executor.getMetadata.getLabels.size() === 3) - assert(executor.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) === "1") - - // There is exactly 1 container with no volume mounts and default memory limits and requests. - // Default memory limit/request is 1024M + 384M (minimum overhead constant). - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getImage === executorImage) - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.isEmpty) - assert(executor.getSpec.getContainers.get(0).getResources.getLimits.size() === 1) - assert(executor.getSpec.getContainers.get(0).getResources - .getRequests.get("memory").getAmount === "1408Mi") - assert(executor.getSpec.getContainers.get(0).getResources - .getLimits.get("memory").getAmount === "1408Mi") - assert(executor.getSpec.getImagePullSecrets.size() === 2) - assert(executor.getSpec.getImagePullSecrets.get(0).getName === "imagePullSecret1") - assert(executor.getSpec.getImagePullSecrets.get(1).getName === "imagePullSecret2") - - // The pod has no node selector, volumes. - assert(executor.getSpec.getNodeSelector.isEmpty) - assert(executor.getSpec.getVolumes.isEmpty) - - checkEnv(executor, Map()) - checkOwnerReferences(executor, driverPodUid) - } - - test("executor core request specification") { - var factory = new ExecutorPodFactory(baseConf, None) - var executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount - === "1") - - val conf = baseConf.clone() - - conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "0.1") - factory = new ExecutorPodFactory(conf, None) - executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount - === "0.1") - - conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "100m") - factory = new ExecutorPodFactory(conf, None) - conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "100m") - executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount - === "100m") - } - - test("executor pod hostnames get truncated to 63 characters") { - val conf = baseConf.clone() - conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, - "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple") - - val factory = new ExecutorPodFactory(conf, None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - assert(executor.getSpec.getHostname.length === 63) - } - - test("classpath and extra java options get translated into environment variables") { - val conf = baseConf.clone() - conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") - conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") - - val factory = new ExecutorPodFactory(conf, None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)]("qux" -> "quux"), driverPod, Map[String, Int]()) - - checkEnv(executor, - Map("SPARK_JAVA_OPT_0" -> "foo=bar", - ENV_CLASSPATH -> "bar=baz", - "qux" -> "quux")) - checkOwnerReferences(executor, driverPodUid) - } - - test("executor secrets get mounted") { - val conf = baseConf.clone() - - val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1")) - val factory = new ExecutorPodFactory(conf, Some(secretsBootstrap)) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.size() === 1) - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.get(0).getName - === "secret1-volume") - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.get(0) - .getMountPath === "/var/secret1") - - // check volume mounted. - assert(executor.getSpec.getVolumes.size() === 1) - assert(executor.getSpec.getVolumes.get(0).getSecret.getSecretName === "secret1") - - checkOwnerReferences(executor, driverPodUid) - } - - // There is always exactly one controller reference, and it points to the driver pod. - private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { - assert(executor.getMetadata.getOwnerReferences.size() === 1) - assert(executor.getMetadata.getOwnerReferences.get(0).getUid === driverPodUid) - assert(executor.getMetadata.getOwnerReferences.get(0).getController === true) - } - - // Check that the expected environment variables are present. - private def checkEnv(executor: Pod, additionalEnvVars: Map[String, String]): Unit = { - val defaultEnvs = Map( - ENV_EXECUTOR_ID -> "1", - ENV_DRIVER_URL -> "dummy", - ENV_EXECUTOR_CORES -> "1", - ENV_EXECUTOR_MEMORY -> "1g", - ENV_APPLICATION_ID -> "dummy", - ENV_SPARK_CONF_DIR -> SPARK_CONF_DIR_INTERNAL, - ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars - - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getEnv.size() === defaultEnvs.size) - val mapEnvs = executor.getSpec.getContainers.get(0).getEnv.asScala.map { - x => (x.getName, x.getValue) - }.toMap - assert(defaultEnvs === mapEnvs) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index b2f26f205a329..96065e83f069c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -18,11 +18,12 @@ package org.apache.spark.scheduler.cluster.k8s import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit} -import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder, PodList} +import io.fabric8.kubernetes.api.model.{ContainerBuilder, DoneablePod, Pod, PodBuilder, PodList} import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} import io.fabric8.kubernetes.client.Watcher.Action import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NonNamespaceOperation, PodResource} -import org.mockito.{AdditionalAnswers, ArgumentCaptor, Mock, MockitoAnnotations} +import org.hamcrest.{BaseMatcher, Description, Matcher} +import org.mockito.{AdditionalAnswers, ArgumentCaptor, Matchers, Mock, MockitoAnnotations} import org.mockito.Matchers.{any, eq => mockitoEq} import org.mockito.Mockito.{doNothing, never, times, verify, when} import org.scalatest.BeforeAndAfter @@ -31,6 +32,7 @@ import scala.collection.JavaConverters._ import scala.concurrent.Future import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.rpc._ @@ -47,8 +49,6 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private val SPARK_DRIVER_HOST = "localhost" private val SPARK_DRIVER_PORT = 7077 private val POD_ALLOCATION_INTERVAL = "1m" - private val DRIVER_URL = RpcEndpointAddress( - SPARK_DRIVER_HOST, SPARK_DRIVER_PORT, CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString private val FIRST_EXECUTOR_POD = new PodBuilder() .withNewMetadata() .withName("pod1") @@ -94,7 +94,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private var requestExecutorsService: ExecutorService = _ @Mock - private var executorPodFactory: ExecutorPodFactory = _ + private var executorBuilder: KubernetesExecutorBuilder = _ @Mock private var kubernetesClient: KubernetesClient = _ @@ -399,7 +399,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn new KubernetesClusterSchedulerBackend( taskSchedulerImpl, rpcEnv, - executorPodFactory, + executorBuilder, kubernetesClient, allocatorExecutor, requestExecutorsService) { @@ -428,13 +428,22 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) .endMetadata() .build() - when(executorPodFactory.createExecutorPod( - executorId.toString, - APP_ID, - DRIVER_URL, - sparkConf.getExecutorEnv, - driverPod, - Map.empty)).thenReturn(resolvedPod) - resolvedPod + val resolvedContainer = new ContainerBuilder().build() + when(executorBuilder.buildFromFeatures(Matchers.argThat( + new BaseMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { + override def matches(argument: scala.Any) + : Boolean = { + argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] && + argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] + .roleSpecificConf.executorId == executorId.toString + } + + override def describeTo(description: Description): Unit = {} + }))).thenReturn(SparkPod(resolvedPod, resolvedContainer)) + new PodBuilder(resolvedPod) + .editSpec() + .addToContainers(resolvedContainer) + .endSpec() + .build() } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala new file mode 100644 index 0000000000000..f5270623f8acc --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.PodBuilder + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep} + +class KubernetesExecutorBuilderSuite extends SparkFunSuite { + private val BASIC_STEP_TYPE = "basic" + private val SECRETS_STEP_TYPE = "mount-secrets" + + private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep]) + private val mountSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + + private val builderUnderTest = new KubernetesExecutorBuilder( + _ => basicFeatureStep, + _ => mountSecretsStep) + + test("Basic steps are consistently applied.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesExecutorSpecificConf( + "executor-id", new PodBuilder().build()), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty) + validateStepTypesApplied(builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE) + } + + test("Apply secrets step if secrets are present.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesExecutorSpecificConf( + "executor-id", new PodBuilder().build()), + "prefix", + "appId", + Map.empty, + Map.empty, + Map("secret" -> "secretMountPath"), + Map.empty) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + SECRETS_STEP_TYPE) + } + + private def validateStepTypesApplied(resolvedPod: SparkPod, stepTypes: String*): Unit = { + assert(resolvedPod.pod.getMetadata.getLabels.size === stepTypes.size) + stepTypes.foreach { stepType => + assert(resolvedPod.pod.getMetadata.getLabels.get(stepType) === stepType) + } + } +} From 4dfd746de3f4346ed0c2191f8523a7e6cc9f064d Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sat, 14 Apr 2018 00:22:38 +0800 Subject: [PATCH 0617/2461] [SPARK-23896][SQL] Improve PartitioningAwareFileIndex ## What changes were proposed in this pull request? Currently `PartitioningAwareFileIndex` accepts an optional parameter `userPartitionSchema`. If provided, it will combine the inferred partition schema with the parameter. However, 1. to get `userPartitionSchema`, we need to combine inferred partition schema with `userSpecifiedSchema` 2. to get the inferred partition schema, we have to create a temporary file index. Only after that, a final version of `PartitioningAwareFileIndex` can be created. This can be improved by passing `userSpecifiedSchema` to `PartitioningAwareFileIndex`. With the improvement, we can reduce redundant code and avoid parsing the file partition twice. ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21004 from gengliangwang/PartitioningAwareFileIndex. --- .../datasources/CatalogFileIndex.scala | 2 +- .../execution/datasources/DataSource.scala | 133 ++++++++---------- .../datasources/InMemoryFileIndex.scala | 8 +- .../PartitioningAwareFileIndex.scala | 54 ++++--- .../streaming/MetadataLogFileIndex.scala | 10 +- .../datasources/FileSourceStrategySuite.scala | 2 +- .../hive/PartitionedTablePerfStatsSuite.scala | 2 +- 7 files changed, 103 insertions(+), 108 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index 4046396d0e614..a66a07673e25f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -85,7 +85,7 @@ class CatalogFileIndex( sparkSession, new Path(baseLocation.get), fileStatusCache, partitionSpec, Option(timeNs)) } else { new InMemoryFileIndex( - sparkSession, rootPaths, table.storage.properties, partitionSchema = None) + sparkSession, rootPaths, table.storage.properties, userSpecifiedSchema = None) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b84ea769808f9..f16d824201e77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -23,7 +23,6 @@ import scala.collection.JavaConverters._ import scala.language.{existentials, implicitConversions} import scala.util.{Failure, Success, Try} -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil @@ -103,24 +102,6 @@ case class DataSource( bucket.sortColumnNames, "in the sort definition", equality) } - /** - * In the read path, only managed tables by Hive provide the partition columns properly when - * initializing this class. All other file based data sources will try to infer the partitioning, - * and then cast the inferred types to user specified dataTypes if the partition columns exist - * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510, or - * inconsistent data types as reported in SPARK-21463. - * @param fileIndex A FileIndex that will perform partition inference - * @return The PartitionSchema resolved from inference and cast according to `userSpecifiedSchema` - */ - private def combineInferredAndUserSpecifiedPartitionSchema(fileIndex: FileIndex): StructType = { - val resolved = fileIndex.partitionSchema.map { partitionField => - // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred - userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse( - partitionField) - } - StructType(resolved) - } - /** * Get the schema of the given FileFormat, if provided by `userSpecifiedSchema`, or try to infer * it. In the read path, only managed tables by Hive provide the partition columns properly when @@ -140,31 +121,26 @@ case class DataSource( * be any further inference in any triggers. * * @param format the file format object for this DataSource - * @param fileStatusCache the shared cache for file statuses to speed up listing + * @param fileIndex optional [[InMemoryFileIndex]] for getting partition schema and file list * @return A pair of the data schema (excluding partition columns) and the schema of the partition * columns. */ private def getOrInferFileFormatSchema( format: FileFormat, - fileStatusCache: FileStatusCache = NoopCache): (StructType, StructType) = { - // the operations below are expensive therefore try not to do them if we don't need to, e.g., + fileIndex: Option[InMemoryFileIndex] = None): (StructType, StructType) = { + // The operations below are expensive therefore try not to do them if we don't need to, e.g., // in streaming mode, we have already inferred and registered partition columns, we will // never have to materialize the lazy val below - lazy val tempFileIndex = { - val allPaths = caseInsensitiveOptions.get("path") ++ paths - val hadoopConf = sparkSession.sessionState.newHadoopConf() - val globbedPaths = allPaths.toSeq.flatMap { path => - val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(hadoopConf) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(fs, qualified) - }.toArray - new InMemoryFileIndex(sparkSession, globbedPaths, options, None, fileStatusCache) + lazy val tempFileIndex = fileIndex.getOrElse { + val globbedPaths = + checkAndGlobPathIfNecessary(checkEmptyGlobPath = false, checkFilesExist = false) + createInMemoryFileIndex(globbedPaths) } + val partitionSchema = if (partitionColumns.isEmpty) { // Try to infer partitioning, because no DataSource in the read path provides the partitioning // columns properly unless it is a Hive DataSource - combineInferredAndUserSpecifiedPartitionSchema(tempFileIndex) + tempFileIndex.partitionSchema } else { // maintain old behavior before SPARK-18510. If userSpecifiedSchema is empty used inferred // partitioning @@ -356,13 +332,7 @@ case class DataSource( caseInsensitiveOptions.get("path").toSeq ++ paths, sparkSession.sessionState.newHadoopConf()) => val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head) - val tempFileCatalog = new MetadataLogFileIndex(sparkSession, basePath, None) - val fileCatalog = if (userSpecifiedSchema.nonEmpty) { - val partitionSchema = combineInferredAndUserSpecifiedPartitionSchema(tempFileCatalog) - new MetadataLogFileIndex(sparkSession, basePath, Option(partitionSchema)) - } else { - tempFileCatalog - } + val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath, userSpecifiedSchema) val dataSchema = userSpecifiedSchema.orElse { format.inferSchema( sparkSession, @@ -384,24 +354,23 @@ case class DataSource( // This is a non-streaming file based datasource. case (format: FileFormat, _) => - val allPaths = caseInsensitiveOptions.get("path") ++ paths - val hadoopConf = sparkSession.sessionState.newHadoopConf() - val globbedPaths = allPaths.flatMap( - DataSource.checkAndGlobPathIfNecessary(hadoopConf, _, checkFilesExist)).toArray - - val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) - val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format, fileStatusCache) - - val fileCatalog = if (sparkSession.sqlContext.conf.manageFilesourcePartitions && - catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog) { + val globbedPaths = + checkAndGlobPathIfNecessary(checkEmptyGlobPath = true, checkFilesExist = checkFilesExist) + val useCatalogFileIndex = sparkSession.sqlContext.conf.manageFilesourcePartitions && + catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog && + catalogTable.get.partitionColumnNames.nonEmpty + val (fileCatalog, dataSchema, partitionSchema) = if (useCatalogFileIndex) { val defaultTableSize = sparkSession.sessionState.conf.defaultSizeInBytes - new CatalogFileIndex( + val index = new CatalogFileIndex( sparkSession, catalogTable.get, catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(defaultTableSize)) + (index, catalogTable.get.dataSchema, catalogTable.get.partitionSchema) } else { - new InMemoryFileIndex( - sparkSession, globbedPaths, options, Some(partitionSchema), fileStatusCache) + val index = createInMemoryFileIndex(globbedPaths) + val (resultDataSchema, resultPartitionSchema) = + getOrInferFileFormatSchema(format, Some(index)) + (index, resultDataSchema, resultPartitionSchema) } HadoopFsRelation( @@ -552,6 +521,40 @@ case class DataSource( sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") } } + + /** Returns an [[InMemoryFileIndex]] that can be used to get partition schema and file list. */ + private def createInMemoryFileIndex(globbedPaths: Seq[Path]): InMemoryFileIndex = { + val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) + new InMemoryFileIndex( + sparkSession, globbedPaths, options, userSpecifiedSchema, fileStatusCache) + } + + /** + * Checks and returns files in all the paths. + */ + private def checkAndGlobPathIfNecessary( + checkEmptyGlobPath: Boolean, + checkFilesExist: Boolean): Seq[Path] = { + val allPaths = caseInsensitiveOptions.get("path") ++ paths + val hadoopConf = sparkSession.sessionState.newHadoopConf() + allPaths.flatMap { path => + val hdfsPath = new Path(path) + val fs = hdfsPath.getFileSystem(hadoopConf) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + val globPath = SparkHadoopUtil.get.globPathIfNecessary(fs, qualified) + + if (checkEmptyGlobPath && globPath.isEmpty) { + throw new AnalysisException(s"Path does not exist: $qualified") + } + + // Sufficient to check head of the globPath seq for non-glob scenario + // Don't need to check once again if files exist in streaming mode + if (checkFilesExist && !fs.exists(globPath.head)) { + throw new AnalysisException(s"Path does not exist: ${globPath.head}") + } + globPath + }.toSeq + } } object DataSource extends Logging { @@ -699,30 +702,6 @@ object DataSource extends Logging { locationUri = path.map(CatalogUtils.stringToURI), properties = optionsWithoutPath) } - /** - * If `path` is a file pattern, return all the files that match it. Otherwise, return itself. - * If `checkFilesExist` is `true`, also check the file existence. - */ - private def checkAndGlobPathIfNecessary( - hadoopConf: Configuration, - path: String, - checkFilesExist: Boolean): Seq[Path] = { - val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(hadoopConf) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - val globPath = SparkHadoopUtil.get.globPathIfNecessary(fs, qualified) - - if (globPath.isEmpty) { - throw new AnalysisException(s"Path does not exist: $qualified") - } - // Sufficient to check head of the globPath seq for non-glob scenario - // Don't need to check once again if files exist in streaming mode - if (checkFilesExist && !fs.exists(globPath.head)) { - throw new AnalysisException(s"Path does not exist: ${globPath.head}") - } - globPath - } - /** * Called before writing into a FileFormat based data source to make sure the * supplied schema is not empty. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 318ada0ceefc5..739d1f456e3ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -41,17 +41,17 @@ import org.apache.spark.util.SerializableConfiguration * @param rootPathsSpecified the list of root table paths to scan (some of which might be * filtered out later) * @param parameters as set of options to control discovery - * @param partitionSchema an optional partition schema that will be use to provide types for the - * discovered partitions + * @param userSpecifiedSchema an optional user specified schema that will be use to provide + * types for the discovered partitions */ class InMemoryFileIndex( sparkSession: SparkSession, rootPathsSpecified: Seq[Path], parameters: Map[String, String], - partitionSchema: Option[StructType], + userSpecifiedSchema: Option[StructType], fileStatusCache: FileStatusCache = NoopCache) extends PartitioningAwareFileIndex( - sparkSession, parameters, partitionSchema, fileStatusCache) { + sparkSession, parameters, userSpecifiedSchema, fileStatusCache) { // Filter out streaming metadata dirs or files such as "/.../_spark_metadata" (the metadata dir) // or "/.../_spark_metadata/0" (a file in the metadata dir). `rootPathsSpecified` might contain diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index 6b6f6388d54e8..cc8af7b92c454 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -34,13 +34,13 @@ import org.apache.spark.sql.types.{StringType, StructType} * It provides the necessary methods to parse partition data based on a set of files. * * @param parameters as set of options to control partition discovery - * @param userPartitionSchema an optional partition schema that will be use to provide types for - * the discovered partitions + * @param userSpecifiedSchema an optional user specified schema that will be use to provide + * types for the discovered partitions */ abstract class PartitioningAwareFileIndex( sparkSession: SparkSession, parameters: Map[String, String], - userPartitionSchema: Option[StructType], + userSpecifiedSchema: Option[StructType], fileStatusCache: FileStatusCache = NoopCache) extends FileIndex with Logging { import PartitioningAwareFileIndex.BASE_PATH_PARAM @@ -126,35 +126,32 @@ abstract class PartitioningAwareFileIndex( val caseInsensitiveOptions = CaseInsensitiveMap(parameters) val timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone) - - userPartitionSchema match { + val inferredPartitionSpec = PartitioningUtils.parsePartitions( + leafDirs, + typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, + basePaths = basePaths, + timeZoneId = timeZoneId) + userSpecifiedSchema match { case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => - val spec = PartitioningUtils.parsePartitions( - leafDirs, - typeInference = false, - basePaths = basePaths, - timeZoneId = timeZoneId) + val userPartitionSchema = + combineInferredAndUserSpecifiedPartitionSchema(inferredPartitionSpec) - // Without auto inference, all of value in the `row` should be null or in StringType, // we need to cast into the data type that user specified. def castPartitionValuesToUserSchema(row: InternalRow) = { InternalRow((0 until row.numFields).map { i => + val dt = inferredPartitionSpec.partitionColumns.fields(i).dataType Cast( - Literal.create(row.getUTF8String(i), StringType), - userProvidedSchema.fields(i).dataType, + Literal.create(row.get(i, dt), dt), + userPartitionSchema.fields(i).dataType, Option(timeZoneId)).eval() }: _*) } - PartitionSpec(userProvidedSchema, spec.partitions.map { part => + PartitionSpec(userPartitionSchema, inferredPartitionSpec.partitions.map { part => part.copy(values = castPartitionValuesToUserSchema(part.values)) }) case _ => - PartitioningUtils.parsePartitions( - leafDirs, - typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, - basePaths = basePaths, - timeZoneId = timeZoneId) + inferredPartitionSpec } } @@ -236,6 +233,25 @@ abstract class PartitioningAwareFileIndex( val name = path.getName !((name.startsWith("_") && !name.contains("=")) || name.startsWith(".")) } + + /** + * In the read path, only managed tables by Hive provide the partition columns properly when + * initializing this class. All other file based data sources will try to infer the partitioning, + * and then cast the inferred types to user specified dataTypes if the partition columns exist + * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510, or + * inconsistent data types as reported in SPARK-21463. + * @param spec A partition inference result + * @return The PartitionSchema resolved from inference and cast according to `userSpecifiedSchema` + */ + private def combineInferredAndUserSpecifiedPartitionSchema(spec: PartitionSpec): StructType = { + val equality = sparkSession.sessionState.conf.resolver + val resolved = spec.partitionColumns.map { partitionField => + // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred + userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse( + partitionField) + } + StructType(resolved) + } } object PartitioningAwareFileIndex { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala index 1da703cefd8ea..5cacdd070b735 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala @@ -30,14 +30,14 @@ import org.apache.spark.sql.types.StructType * A [[FileIndex]] that generates the list of files to processing by reading them from the * metadata log files generated by the [[FileStreamSink]]. * - * @param userPartitionSchema an optional partition schema that will be use to provide types for - * the discovered partitions + * @param userSpecifiedSchema an optional user specified schema that will be use to provide + * types for the discovered partitions */ class MetadataLogFileIndex( sparkSession: SparkSession, path: Path, - userPartitionSchema: Option[StructType]) - extends PartitioningAwareFileIndex(sparkSession, Map.empty, userPartitionSchema) { + userSpecifiedSchema: Option[StructType]) + extends PartitioningAwareFileIndex(sparkSession, Map.empty, userSpecifiedSchema) { private val metadataDirectory = new Path(path, FileStreamSink.metadataDir) logInfo(s"Reading streaming file log from $metadataDirectory") @@ -51,7 +51,7 @@ class MetadataLogFileIndex( } override protected val leafDirToChildrenFiles: Map[Path, Array[FileStatus]] = { - allFilesFromLog.toArray.groupBy(_.getPath.getParent) + allFilesFromLog.groupBy(_.getPath.getParent) } override def rootPaths: Seq[Path] = path :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index c1d61b843d899..8764f0c42cf9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -401,7 +401,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi sparkSession = spark, rootPathsSpecified = Seq(new Path(tempDir)), parameters = Map.empty[String, String], - partitionSchema = None) + userSpecifiedSchema = None) // This should not fail. fileCatalog.listLeafFiles(Seq(new Path(tempDir))) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala index 1a86c604d5da3..3af163af0968c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala @@ -419,7 +419,7 @@ class PartitionedTablePerfStatsSuite HiveCatalogMetrics.reset() spark.read.load(dir.getAbsolutePath) assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 1) - assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 1) + assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 0) } } } From 25892f3cc9dcb938220be8020a5b9a17c92dbdbe Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 14 Apr 2018 01:01:00 +0800 Subject: [PATCH 0618/2461] [SPARK-23375][SQL] Eliminate unneeded Sort in Optimizer ## What changes were proposed in this pull request? Added a new rule to remove Sort operation when its child is already sorted. For instance, this simple code: ``` spark.sparkContext.parallelize(Seq(("a", "b"))).toDF("a", "b").registerTempTable("table1") val df = sql(s"""SELECT b | FROM ( | SELECT a, b | FROM table1 | ORDER BY a | ) t | ORDER BY a""".stripMargin) df.explain(true) ``` before the PR produces this plan: ``` == Parsed Logical Plan == 'Sort ['a ASC NULLS FIRST], true +- 'Project ['b] +- 'SubqueryAlias t +- 'Sort ['a ASC NULLS FIRST], true +- 'Project ['a, 'b] +- 'UnresolvedRelation `table1` == Analyzed Logical Plan == b: string Project [b#7] +- Sort [a#6 ASC NULLS FIRST], true +- Project [b#7, a#6] +- SubqueryAlias t +- Sort [a#6 ASC NULLS FIRST], true +- Project [a#6, b#7] +- SubqueryAlias table1 +- Project [_1#3 AS a#6, _2#4 AS b#7] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false) AS _2#4] +- ExternalRDD [obj#2] == Optimized Logical Plan == Project [b#7] +- Sort [a#6 ASC NULLS FIRST], true +- Project [b#7, a#6] +- Sort [a#6 ASC NULLS FIRST], true +- Project [_1#3 AS a#6, _2#4 AS b#7] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._2, true, false) AS _2#4] +- ExternalRDD [obj#2] == Physical Plan == *(3) Project [b#7] +- *(3) Sort [a#6 ASC NULLS FIRST], true, 0 +- Exchange rangepartitioning(a#6 ASC NULLS FIRST, 200) +- *(2) Project [b#7, a#6] +- *(2) Sort [a#6 ASC NULLS FIRST], true, 0 +- Exchange rangepartitioning(a#6 ASC NULLS FIRST, 200) +- *(1) Project [_1#3 AS a#6, _2#4 AS b#7] +- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._2, true, false) AS _2#4] +- Scan ExternalRDDScan[obj#2] ``` while after the PR produces: ``` == Parsed Logical Plan == 'Sort ['a ASC NULLS FIRST], true +- 'Project ['b] +- 'SubqueryAlias t +- 'Sort ['a ASC NULLS FIRST], true +- 'Project ['a, 'b] +- 'UnresolvedRelation `table1` == Analyzed Logical Plan == b: string Project [b#7] +- Sort [a#6 ASC NULLS FIRST], true +- Project [b#7, a#6] +- SubqueryAlias t +- Sort [a#6 ASC NULLS FIRST], true +- Project [a#6, b#7] +- SubqueryAlias table1 +- Project [_1#3 AS a#6, _2#4 AS b#7] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false) AS _2#4] +- ExternalRDD [obj#2] == Optimized Logical Plan == Project [b#7] +- Sort [a#6 ASC NULLS FIRST], true +- Project [_1#3 AS a#6, _2#4 AS b#7] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._2, true, false) AS _2#4] +- ExternalRDD [obj#2] == Physical Plan == *(2) Project [b#7] +- *(2) Sort [a#6 ASC NULLS FIRST], true, 0 +- Exchange rangepartitioning(a#6 ASC NULLS FIRST, 5) +- *(1) Project [_1#3 AS a#6, _2#4 AS b#7] +- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._2, true, false) AS _2#4] +- Scan ExternalRDDScan[obj#2] ``` this means that an unnecessary sort operation is not performed after the PR. ## How was this patch tested? added UT Author: Marco Gaido Closes #20560 from mgaido91/SPARK-23375. --- .../sql/catalyst/optimizer/Optimizer.scala | 12 +++ .../catalyst/plans/logical/LogicalPlan.scala | 9 ++ .../plans/logical/basicLogicalOperators.scala | 23 ++-- .../optimizer/RemoveRedundantSortsSuite.scala | 101 ++++++++++++++++++ .../spark/sql/execution/CacheManager.scala | 4 +- .../spark/sql/execution/ExistingRDD.scala | 2 +- .../execution/columnar/InMemoryRelation.scala | 17 +-- .../spark/sql/ConfigBehaviorSuite.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 15 ++- .../columnar/InMemoryColumnarQuerySuite.scala | 14 +-- 10 files changed, 175 insertions(+), 24 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9a1bbc675e397..5fb59ef350b8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -138,6 +138,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) operatorOptimizationBatch) :+ Batch("Join Reorder", Once, CostBasedJoinReorder) :+ + Batch("Remove Redundant Sorts", Once, + RemoveRedundantSorts) :+ Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :+ Batch("Object Expressions Optimization", fixedPoint, @@ -733,6 +735,16 @@ object EliminateSorts extends Rule[LogicalPlan] { } } +/** + * Removes Sort operation if the child is already sorted + */ +object RemoveRedundantSorts extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) => + child + } +} + /** * Removes filters that can be evaluated trivially. This can be done through the following ways: * 1) by eliding the filter for cases where it will always evaluate to `true`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index c8ccd9bd03994..42034403d6d03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -219,6 +219,11 @@ abstract class LogicalPlan * Refreshes (or invalidates) any metadata/data cached in the plan recursively. */ def refresh(): Unit = children.foreach(_.refresh()) + + /** + * Returns the output ordering that this plan generates. + */ + def outputOrdering: Seq[SortOrder] = Nil } /** @@ -274,3 +279,7 @@ abstract class BinaryNode extends LogicalPlan { override final def children: Seq[LogicalPlan] = Seq(left, right) } + +abstract class OrderPreservingUnaryNode extends UnaryNode { + override final def outputOrdering: Seq[SortOrder] = child.outputOrdering +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index a4fca790dd086..10df504795430 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -43,11 +43,12 @@ case class ReturnAnswer(child: LogicalPlan) extends UnaryNode { * This node is inserted at the top of a subquery when it is optimized. This makes sure we can * recognize a subquery as such, and it allows us to write subquery aware transformations. */ -case class Subquery(child: LogicalPlan) extends UnaryNode { +case class Subquery(child: LogicalPlan) extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = child.output } -case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { +case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) + extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) override def maxRows: Option[Long] = child.maxRows @@ -125,7 +126,7 @@ case class Generate( } case class Filter(condition: Expression, child: LogicalPlan) - extends UnaryNode with PredicateHelper { + extends OrderPreservingUnaryNode with PredicateHelper { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = child.maxRows @@ -469,6 +470,7 @@ case class Sort( child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = child.maxRows + override def outputOrdering: Seq[SortOrder] = order } /** Factory for constructing new `Range` nodes. */ @@ -522,6 +524,15 @@ case class Range( override def computeStats(): Statistics = { Statistics(sizeInBytes = LongType.defaultSize * numElements) } + + override def outputOrdering: Seq[SortOrder] = { + val order = if (step > 0) { + Ascending + } else { + Descending + } + output.map(a => SortOrder(a, order)) + } } case class Aggregate( @@ -728,7 +739,7 @@ object Limit { * * See [[Limit]] for more information. */ -case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { +case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = { limitExpr match { @@ -744,7 +755,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN * * See [[Limit]] for more information. */ -case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { +case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = child.output override def maxRowsPerPartition: Option[Long] = { @@ -764,7 +775,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo case class SubqueryAlias( alias: String, child: LogicalPlan) - extends UnaryNode { + extends OrderPreservingUnaryNode { override def doCanonicalize(): LogicalPlan = child.canonicalized diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala new file mode 100644 index 0000000000000..2319ab8046e56 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL} + +class RemoveRedundantSortsSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Remove Redundant Sorts", Once, + RemoveRedundantSorts) :: + Batch("Collapse Project", Once, + CollapseProject) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + test("remove redundant order by") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) + val unnecessaryReordered = orderedPlan.select('a).orderBy('a.asc, 'b.desc_nullsFirst) + val optimized = Optimize.execute(unnecessaryReordered.analyze) + val correctAnswer = orderedPlan.select('a).analyze + comparePlans(Optimize.execute(optimized), correctAnswer) + } + + test("do not remove sort if the order is different") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) + val reorderedDifferently = orderedPlan.select('a).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(reorderedDifferently.analyze) + val correctAnswer = reorderedDifferently.analyze + comparePlans(optimized, correctAnswer) + } + + test("filters don't affect order") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) + val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(filteredAndReordered.analyze) + val correctAnswer = orderedPlan.where('a > Literal(10)).analyze + comparePlans(optimized, correctAnswer) + } + + test("limits don't affect order") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) + val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(filteredAndReordered.analyze) + val correctAnswer = orderedPlan.limit(Literal(10)).analyze + comparePlans(optimized, correctAnswer) + } + + test("range is already sorted") { + val inputPlan = Range(1L, 1000L, 1, 10) + val orderedPlan = inputPlan.orderBy('id.asc) + val optimized = Optimize.execute(orderedPlan.analyze) + val correctAnswer = inputPlan.analyze + comparePlans(optimized, correctAnswer) + + val reversedPlan = inputPlan.orderBy('id.desc) + val reversedOptimized = Optimize.execute(reversedPlan.analyze) + val reversedCorrectAnswer = reversedPlan.analyze + comparePlans(reversedOptimized, reversedCorrectAnswer) + + val negativeStepInputPlan = Range(10L, 1L, -1, 10) + val negativeStepOrderedPlan = negativeStepInputPlan.orderBy('id.desc) + val negativeStepOptimized = Optimize.execute(negativeStepOrderedPlan.analyze) + val negativeStepCorrectAnswer = negativeStepInputPlan.analyze + comparePlans(negativeStepOptimized, negativeStepCorrectAnswer) + } + + test("sort should not be removed when there is a node which doesn't guarantee any order") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc) + val groupedAndResorted = orderedPlan.groupBy('a)(sum('a)).orderBy('a.asc) + val optimized = Optimize.execute(groupedAndResorted.analyze) + val correctAnswer = groupedAndResorted.analyze + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index d68aeb275afda..a8794be7280c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -99,7 +99,7 @@ class CacheManager extends Logging { sparkSession.sessionState.conf.columnBatchSize, storageLevel, sparkSession.sessionState.executePlan(planToCache).executedPlan, tableName, - planToCache.stats) + planToCache) cachedData.add(CachedData(planToCache, inMemoryRelation)) } } @@ -148,7 +148,7 @@ class CacheManager extends Logging { storageLevel = cd.cachedRepresentation.storageLevel, child = spark.sessionState.executePlan(cd.plan).executedPlan, tableName = cd.cachedRepresentation.tableName, - statsOfPlanToCache = cd.plan.stats) + logicalPlan = cd.plan) needToRecache += cd.copy(cachedRepresentation = newCache) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index f3555508185fe..be50a1571a2ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -125,7 +125,7 @@ case class LogicalRDD( output: Seq[Attribute], rdd: RDD[InternalRow], outputPartitioning: Partitioning = UnknownPartitioning(0), - outputOrdering: Seq[SortOrder] = Nil, + override val outputOrdering: Seq[SortOrder] = Nil, override val isStreaming: Boolean = false)(session: SparkSession) extends LeafNode with MultiInstanceRelation { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 2579046e30708..a7ba9b86a176f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, Statistics} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.storage.StorageLevel import org.apache.spark.util.LongAccumulator @@ -39,9 +39,9 @@ object InMemoryRelation { storageLevel: StorageLevel, child: SparkPlan, tableName: Option[String], - statsOfPlanToCache: Statistics): InMemoryRelation = + logicalPlan: LogicalPlan): InMemoryRelation = new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)( - statsOfPlanToCache = statsOfPlanToCache) + statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) } @@ -64,7 +64,8 @@ case class InMemoryRelation( tableName: Option[String])( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics) + statsOfPlanToCache: Statistics, + override val outputOrdering: Seq[SortOrder]) extends logical.LeafNode with MultiInstanceRelation { override protected def innerChildren: Seq[SparkPlan] = Seq(child) @@ -76,7 +77,8 @@ case class InMemoryRelation( tableName = None)( _cachedColumnBuffers, sizeInBytesStats, - statsOfPlanToCache) + statsOfPlanToCache, + outputOrdering) override def producedAttributes: AttributeSet = outputSet @@ -159,7 +161,7 @@ case class InMemoryRelation( def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { InMemoryRelation( newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache) + _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache, outputOrdering) } override def newInstance(): this.type = { @@ -172,7 +174,8 @@ case class InMemoryRelation( tableName)( _cachedColumnBuffers, sizeInBytesStats, - statsOfPlanToCache).asInstanceOf[this.type] + statsOfPlanToCache, + outputOrdering).asInstanceOf[this.type] } def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala index cee85ec8af04d..949505e449fd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala @@ -39,7 +39,7 @@ class ConfigBehaviorSuite extends QueryTest with SharedSQLContext { def computeChiSquareTest(): Double = { val n = 10000 // Trigger a sort - val data = spark.range(0, n, 1, 1).sort('id) + val data = spark.range(0, n, 1, 1).sort('id.desc) .selectExpr("SPARK_PARTITION_ID() pid", "id").as[(Int, Long)].collect() // Compute histogram for the number of records per partition post sort diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index f8b26f5b28cc7..40915a102bab0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition, Sort} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} @@ -197,6 +197,19 @@ class PlannerSuite extends SharedSQLContext { assert(planned.child.isInstanceOf[CollectLimitExec]) } + test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") { + val query = testData.select('key, 'value).sort('key.desc).cache() + assert(query.queryExecution.optimizedPlan.isInstanceOf[InMemoryRelation]) + val resorted = query.sort('key.desc) + assert(resorted.queryExecution.optimizedPlan.collect { case s: Sort => s}.isEmpty) + assert(resorted.select('key).collect().map(_.getInt(0)).toSeq == + (1 to 100).reverse) + // with a different order, the sort is needed + val sortedAsc = query.sort('key) + assert(sortedAsc.queryExecution.optimizedPlan.collect { case s: Sort => s}.size == 1) + assert(sortedAsc.select('key).collect().map(_.getInt(0)).toSeq == (1 to 100)) + } + test("PartitioningCollection") { withTempView("normal", "small", "tiny") { testData.createOrReplaceTempView("normal") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 26b63e8e8490f..9b7b316211d30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.{FilterExec, LocalTableScanExec, WholeStageCodegenExec} import org.apache.spark.sql.functions._ @@ -42,7 +43,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val storageLevel = MEMORY_ONLY val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None, - data.logicalPlan.stats) + data.logicalPlan) assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel) inMemoryRelation.cachedColumnBuffers.collect().head match { @@ -119,7 +120,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("simple columnar query") { val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None, - testData.logicalPlan.stats) + testData.logicalPlan) checkAnswer(scan, testData.collect().toSeq) } @@ -138,7 +139,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val logicalPlan = testData.select('value, 'key).logicalPlan val plan = spark.sessionState.executePlan(logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None, - logicalPlan.stats) + logicalPlan) checkAnswer(scan, testData.collect().map { case Row(key: Int, value: String) => value -> key @@ -155,7 +156,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None, - testData.logicalPlan.stats) + testData.logicalPlan) checkAnswer(scan, testData.collect().toSeq) checkAnswer(scan, testData.collect().toSeq) @@ -329,7 +330,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-17549: cached table size should be correctly calculated") { val data = spark.sparkContext.parallelize(1 to 10, 5).toDF() val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan - val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None, data.logicalPlan.stats) + val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None, data.logicalPlan) // Materialize the data. val expectedAnswer = data.collect() @@ -455,7 +456,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-22249: buildFilter should not throw exception when In contains an empty list") { val attribute = AttributeReference("a", IntegerType)() val localTableScanExec = LocalTableScanExec(Seq(attribute), Nil) - val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None, null) + val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None, + LocalRelation(Seq(attribute), Nil)) val tableScanExec = InMemoryTableScanExec(Seq(attribute), Seq(In(attribute, Nil)), testRelation) assert(tableScanExec.partitionFilters.isEmpty) From 558f31b31c73b7e9f26f56498b54cf53997b59b8 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Fri, 13 Apr 2018 14:05:04 -0700 Subject: [PATCH 0619/2461] [SPARK-23963][SQL] Properly handle large number of columns in query on text-based Hive table ## What changes were proposed in this pull request? TableReader would get disproportionately slower as the number of columns in the query increased. I fixed the way TableReader was looking up metadata for each column in the row. Previously, it had been looking up this data in linked lists, accessing each linked list by an index (column number). Now it looks up this data in arrays, where indexing by column number works better. ## How was this patch tested? Manual testing All sbt unit tests python sql tests Author: Bruce Robbins Closes #21043 from bersprockets/tabreadfix. --- .../src/main/scala/org/apache/spark/sql/hive/TableReader.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index cc8907a0bbc93..b5444a4217924 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -381,7 +381,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) => soi.getStructFieldRef(attr.name) -> ordinal - }.unzip + }.toArray.unzip /** * Builds specific unwrappers ahead of time according to object inspector From cbb41a0c5b01579c85f06ef42cc0585fbef216c5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 13 Apr 2018 16:31:39 -0700 Subject: [PATCH 0620/2461] [SPARK-23966][SS] Refactoring all checkpoint file writing logic in a common CheckpointFileManager interface ## What changes were proposed in this pull request? Checkpoint files (offset log files, state store files) in Structured Streaming must be written atomically such that no partial files are generated (would break fault-tolerance guarantees). Currently, there are 3 locations which try to do this individually, and in some cases, incorrectly. 1. HDFSOffsetMetadataLog - This uses a FileManager interface to use any implementation of `FileSystem` or `FileContext` APIs. It preferably loads `FileContext` implementation as FileContext of HDFS has atomic renames. 1. HDFSBackedStateStore (aka in-memory state store) - Writing a version.delta file - This uses FileSystem APIs only to perform a rename. This is incorrect as rename is not atomic in HDFS FileSystem implementation. - Writing a snapshot file - Same as above. #### Current problems: 1. State Store behavior is incorrect - HDFS FileSystem implementation does not have atomic rename. 1. Inflexible - Some file systems provide mechanisms other than write-to-temp-file-and-rename for writing atomically and more efficiently. For example, with S3 you can write directly to the final file and it will be made visible only when the entire file is written and closed correctly. Any failure can be made to terminate the writing without making any partial files visible in S3. The current code does not abstract out this mechanism enough that it can be customized. #### Solution: 1. Introduce a common interface that all 3 cases above can use to write checkpoint files atomically. 2. This interface must provide the necessary interfaces that allow customization of the write-and-rename mechanism. This PR does that by introducing the interface `CheckpointFileManager` and modifying `HDFSMetadataLog` and `HDFSBackedStateStore` to use the interface. Similar to earlier `FileManager`, there are implementations based on `FileSystem` and `FileContext` APIs, and the latter implementation is preferred to make it work correctly with HDFS. The key method this interface has is `createAtomic(path, overwrite)` which returns a `CancellableFSDataOutputStream` that has the method `cancel()`. All users of this method need to either call `close()` to successfully write the file, or `cancel()` in case of an error. ## How was this patch tested? New tests in `CheckpointFileManagerSuite` and slightly modified existing tests. Author: Tathagata Das Closes #21048 from tdas/SPARK-23966. --- .../apache/spark/sql/internal/SQLConf.scala | 7 + .../streaming/CheckpointFileManager.scala | 349 ++++++++++++++++++ .../execution/streaming/HDFSMetadataLog.scala | 229 +----------- .../state/HDFSBackedStateStoreProvider.scala | 120 +++--- .../streaming/state/StateStore.scala | 4 +- .../CheckpointFileManagerSuite.scala | 192 ++++++++++ .../CompactibleFileStreamLogSuite.scala | 5 - .../streaming/HDFSMetadataLogSuite.scala | 116 +----- .../streaming/state/StateStoreSuite.scala | 58 ++- 9 files changed, 678 insertions(+), 402 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1c8ab9c62623e..0dc47bfe075d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -930,6 +930,13 @@ object SQLConf { .intConf .createWithDefault(100) + val STREAMING_CHECKPOINT_FILE_MANAGER_CLASS = + buildConf("spark.sql.streaming.checkpointFileManagerClass") + .doc("The class used to write checkpoint files atomically. This class must be a subclass " + + "of the interface CheckpointFileManager.") + .internal() + .stringConf + val NDV_MAX_ERROR = buildConf("spark.sql.statistics.ndv.maxError") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala new file mode 100644 index 0000000000000..606ba250ad9d2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.io.{FileNotFoundException, IOException, OutputStream} +import java.util.{EnumSet, UUID} + +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ +import org.apache.hadoop.fs.local.{LocalFs, RawLocalFs} +import org.apache.hadoop.fs.permission.FsPermission + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.streaming.CheckpointFileManager.RenameHelperMethods +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + +/** + * An interface to abstract out all operation related to streaming checkpoints. Most importantly, + * the key operation this interface provides is `createAtomic(path, overwrite)` which returns a + * `CancellableFSDataOutputStream`. This method is used by [[HDFSMetadataLog]] and + * [[org.apache.spark.sql.execution.streaming.state.StateStore StateStore]] implementations + * to write a complete checkpoint file atomically (i.e. no partial file will be visible), with or + * without overwrite. + * + * This higher-level interface above the Hadoop FileSystem is necessary because + * different implementation of FileSystem/FileContext may have different combination of operations + * to provide the desired atomic guarantees (e.g. write-to-temp-file-and-rename, + * direct-write-and-cancel-on-failure) and this abstraction allow different implementations while + * keeping the usage simple (`createAtomic` -> `close` or `cancel`). + */ +trait CheckpointFileManager { + + import org.apache.spark.sql.execution.streaming.CheckpointFileManager._ + + /** + * Create a file and make its contents available atomically after the output stream is closed. + * + * @param path Path to create + * @param overwriteIfPossible If true, then the implementations must do a best-effort attempt to + * overwrite the file if it already exists. It should not throw + * any exception if the file exists. However, if false, then the + * implementation must not overwrite if the file alraedy exists and + * must throw `FileAlreadyExistsException` in that case. + */ + def createAtomic(path: Path, overwriteIfPossible: Boolean): CancellableFSDataOutputStream + + /** Open a file for reading, or throw exception if it does not exist. */ + def open(path: Path): FSDataInputStream + + /** List the files in a path that match a filter. */ + def list(path: Path, filter: PathFilter): Array[FileStatus] + + /** List all the files in a path. */ + def list(path: Path): Array[FileStatus] = { + list(path, new PathFilter { override def accept(path: Path): Boolean = true }) + } + + /** Make directory at the give path and all its parent directories as needed. */ + def mkdirs(path: Path): Unit + + /** Whether path exists */ + def exists(path: Path): Boolean + + /** Recursively delete a path if it exists. Should not throw exception if file doesn't exist. */ + def delete(path: Path): Unit + + /** Is the default file system this implementation is operating on the local file system. */ + def isLocal: Boolean +} + +object CheckpointFileManager extends Logging { + + /** + * Additional methods in CheckpointFileManager implementations that allows + * [[RenameBasedFSDataOutputStream]] get atomicity by write-to-temp-file-and-rename + */ + sealed trait RenameHelperMethods { self => CheckpointFileManager + /** Create a file with overwrite. */ + def createTempFile(path: Path): FSDataOutputStream + + /** + * Rename a file. + * + * @param srcPath Source path to rename + * @param dstPath Destination path to rename to + * @param overwriteIfPossible If true, then the implementations must do a best-effort attempt to + * overwrite the file if it already exists. It should not throw + * any exception if the file exists. However, if false, then the + * implementation must not overwrite if the file alraedy exists and + * must throw `FileAlreadyExistsException` in that case. + */ + def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit + } + + /** + * An interface to add the cancel() operation to [[FSDataOutputStream]]. This is used + * mainly by `CheckpointFileManager.createAtomic` to write a file atomically. + * + * @see [[CheckpointFileManager]]. + */ + abstract class CancellableFSDataOutputStream(protected val underlyingStream: OutputStream) + extends FSDataOutputStream(underlyingStream, null) { + /** Cancel the `underlyingStream` and ensure that the output file is not generated. */ + def cancel(): Unit + } + + /** + * An implementation of [[CancellableFSDataOutputStream]] that writes a file atomically by writing + * to a temporary file and then renames. + */ + sealed class RenameBasedFSDataOutputStream( + fm: CheckpointFileManager with RenameHelperMethods, + finalPath: Path, + tempPath: Path, + overwriteIfPossible: Boolean) + extends CancellableFSDataOutputStream(fm.createTempFile(tempPath)) { + + def this(fm: CheckpointFileManager with RenameHelperMethods, path: Path, overwrite: Boolean) = { + this(fm, path, generateTempPath(path), overwrite) + } + + logInfo(s"Writing atomically to $finalPath using temp file $tempPath") + @volatile private var terminated = false + + override def close(): Unit = synchronized { + try { + if (terminated) return + underlyingStream.close() + try { + fm.renameTempFile(tempPath, finalPath, overwriteIfPossible) + } catch { + case fe: FileAlreadyExistsException => + logWarning( + s"Failed to rename temp file $tempPath to $finalPath because file exists", fe) + if (!overwriteIfPossible) throw fe + } + logInfo(s"Renamed temp file $tempPath to $finalPath") + } finally { + terminated = true + } + } + + override def cancel(): Unit = synchronized { + try { + if (terminated) return + underlyingStream.close() + fm.delete(tempPath) + } catch { + case NonFatal(e) => + logWarning(s"Error cancelling write to $finalPath", e) + } finally { + terminated = true + } + } + } + + + /** Create an instance of [[CheckpointFileManager]] based on the path and configuration. */ + def create(path: Path, hadoopConf: Configuration): CheckpointFileManager = { + val fileManagerClass = hadoopConf.get( + SQLConf.STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key) + if (fileManagerClass != null) { + return Utils.classForName(fileManagerClass) + .getConstructor(classOf[Path], classOf[Configuration]) + .newInstance(path, hadoopConf) + .asInstanceOf[CheckpointFileManager] + } + try { + // Try to create a manager based on `FileContext` because HDFS's `FileContext.rename() + // gives atomic renames, which is what we rely on for the default implementation + // `CheckpointFileManager.createAtomic`. + new FileContextBasedCheckpointFileManager(path, hadoopConf) + } catch { + case e: UnsupportedFileSystemException => + logWarning( + "Could not use FileContext API for managing Structured Streaming checkpoint files at " + + s"$path. Using FileSystem API instead for managing log files. If the implementation " + + s"of FileSystem.rename() is not atomic, then the correctness and fault-tolerance of" + + s"your Structured Streaming is not guaranteed.") + new FileSystemBasedCheckpointFileManager(path, hadoopConf) + } + } + + private def generateTempPath(path: Path): Path = { + val tc = org.apache.spark.TaskContext.get + val tid = if (tc != null) ".TID" + tc.taskAttemptId else "" + new Path(path.getParent, s".${path.getName}.${UUID.randomUUID}${tid}.tmp") + } +} + + +/** An implementation of [[CheckpointFileManager]] using Hadoop's [[FileSystem]] API. */ +class FileSystemBasedCheckpointFileManager(path: Path, hadoopConf: Configuration) + extends CheckpointFileManager with RenameHelperMethods with Logging { + + import CheckpointFileManager._ + + protected val fs = path.getFileSystem(hadoopConf) + + override def list(path: Path, filter: PathFilter): Array[FileStatus] = { + fs.listStatus(path, filter) + } + + override def mkdirs(path: Path): Unit = { + fs.mkdirs(path, FsPermission.getDirDefault) + } + + override def createTempFile(path: Path): FSDataOutputStream = { + fs.create(path, true) + } + + override def createAtomic( + path: Path, + overwriteIfPossible: Boolean): CancellableFSDataOutputStream = { + new RenameBasedFSDataOutputStream(this, path, overwriteIfPossible) + } + + override def open(path: Path): FSDataInputStream = { + fs.open(path) + } + + override def exists(path: Path): Boolean = { + try + return fs.getFileStatus(path) != null + catch { + case e: FileNotFoundException => + return false + } + } + + override def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit = { + if (!overwriteIfPossible && fs.exists(dstPath)) { + throw new FileAlreadyExistsException( + s"Failed to rename $srcPath to $dstPath as destination already exists") + } + + if (!fs.rename(srcPath, dstPath)) { + // FileSystem.rename() returning false is very ambiguous as it can be for many reasons. + // This tries to make a best effort attempt to return the most appropriate exception. + if (fs.exists(dstPath)) { + if (!overwriteIfPossible) { + throw new FileAlreadyExistsException(s"Failed to rename as $dstPath already exists") + } + } else if (!fs.exists(srcPath)) { + throw new FileNotFoundException(s"Failed to rename as $srcPath was not found") + } else { + val msg = s"Failed to rename temp file $srcPath to $dstPath as rename returned false" + logWarning(msg) + throw new IOException(msg) + } + } + } + + override def delete(path: Path): Unit = { + try { + fs.delete(path, true) + } catch { + case e: FileNotFoundException => + logInfo(s"Failed to delete $path as it does not exist") + // ignore if file has already been deleted + } + } + + override def isLocal: Boolean = fs match { + case _: LocalFileSystem | _: RawLocalFileSystem => true + case _ => false + } +} + + +/** An implementation of [[CheckpointFileManager]] using Hadoop's [[FileContext]] API. */ +class FileContextBasedCheckpointFileManager(path: Path, hadoopConf: Configuration) + extends CheckpointFileManager with RenameHelperMethods with Logging { + + import CheckpointFileManager._ + + private val fc = if (path.toUri.getScheme == null) { + FileContext.getFileContext(hadoopConf) + } else { + FileContext.getFileContext(path.toUri, hadoopConf) + } + + override def list(path: Path, filter: PathFilter): Array[FileStatus] = { + fc.util.listStatus(path, filter) + } + + override def mkdirs(path: Path): Unit = { + fc.mkdir(path, FsPermission.getDirDefault, true) + } + + override def createTempFile(path: Path): FSDataOutputStream = { + import CreateFlag._ + import Options._ + fc.create( + path, EnumSet.of(CREATE, OVERWRITE), CreateOpts.checksumParam(ChecksumOpt.createDisabled())) + } + + override def createAtomic( + path: Path, + overwriteIfPossible: Boolean): CancellableFSDataOutputStream = { + new RenameBasedFSDataOutputStream(this, path, overwriteIfPossible) + } + + override def open(path: Path): FSDataInputStream = { + fc.open(path) + } + + override def exists(path: Path): Boolean = { + fc.util.exists(path) + } + + override def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit = { + import Options.Rename._ + fc.rename(srcPath, dstPath, if (overwriteIfPossible) OVERWRITE else NONE) + } + + + override def delete(path: Path): Unit = { + try { + fc.delete(path, true) + } catch { + case e: FileNotFoundException => + // ignore if file has already been deleted + } + } + + override def isLocal: Boolean = fc.getDefaultFileSystem match { + case _: LocalFs | _: RawLocalFs => true // LocalFs = RawLocalFs + ChecksumFs + case _ => false + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 00bc215a5dc8c..bd0a46115ceb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -57,10 +57,10 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: require(implicitly[ClassTag[T]].runtimeClass != classOf[Seq[_]], "Should not create a log with type Seq, use Arrays instead - see SPARK-17372") - import HDFSMetadataLog._ - val metadataPath = new Path(path) - protected val fileManager = createFileManager() + + protected val fileManager = + CheckpointFileManager.create(metadataPath, sparkSession.sessionState.newHadoopConf) if (!fileManager.exists(metadataPath)) { fileManager.mkdirs(metadataPath) @@ -109,84 +109,31 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: require(metadata != null, "'null' metadata cannot written to a metadata log") get(batchId).map(_ => false).getOrElse { // Only write metadata when the batch has not yet been written - writeBatch(batchId, metadata) + writeBatchToFile(metadata, batchIdToPath(batchId)) true } } - private def writeTempBatch(metadata: T): Option[Path] = { - while (true) { - val tempPath = new Path(metadataPath, s".${UUID.randomUUID.toString}.tmp") - try { - val output = fileManager.create(tempPath) - try { - serialize(metadata, output) - return Some(tempPath) - } finally { - output.close() - } - } catch { - case e: FileAlreadyExistsException => - // Failed to create "tempPath". There are two cases: - // 1. Someone is creating "tempPath" too. - // 2. This is a restart. "tempPath" has already been created but not moved to the final - // batch file (not committed). - // - // For both cases, the batch has not yet been committed. So we can retry it. - // - // Note: there is a potential risk here: if HDFSMetadataLog A is running, people can use - // the same metadata path to create "HDFSMetadataLog" and fail A. However, this is not a - // big problem because it requires the attacker must have the permission to write the - // metadata path. In addition, the old Streaming also have this issue, people can create - // malicious checkpoint files to crash a Streaming application too. - } - } - None - } - - /** - * Write a batch to a temp file then rename it to the batch file. + /** Write a batch to a temp file then rename it to the batch file. * * There may be multiple [[HDFSMetadataLog]] using the same metadata path. Although it is not a * valid behavior, we still need to prevent it from destroying the files. */ - private def writeBatch(batchId: Long, metadata: T): Unit = { - val tempPath = writeTempBatch(metadata).getOrElse( - throw new IllegalStateException(s"Unable to create temp batch file $batchId")) + private def writeBatchToFile(metadata: T, path: Path): Unit = { + val output = fileManager.createAtomic(path, overwriteIfPossible = false) try { - // Try to commit the batch - // It will fail if there is an existing file (someone has committed the batch) - logDebug(s"Attempting to write log #${batchIdToPath(batchId)}") - fileManager.rename(tempPath, batchIdToPath(batchId)) - - // SPARK-17475: HDFSMetadataLog should not leak CRC files - // If the underlying filesystem didn't rename the CRC file, delete it. - val crcPath = new Path(tempPath.getParent(), s".${tempPath.getName()}.crc") - if (fileManager.exists(crcPath)) fileManager.delete(crcPath) + serialize(metadata, output) + output.close() } catch { case e: FileAlreadyExistsException => - // If "rename" fails, it means some other "HDFSMetadataLog" has committed the batch. - // So throw an exception to tell the user this is not a valid behavior. + output.cancel() + // If next batch file already exists, then another concurrently running query has + // written it. throw new ConcurrentModificationException( - s"Multiple HDFSMetadataLog are using $path", e) - } finally { - fileManager.delete(tempPath) - } - } - - /** - * @return the deserialized metadata in a batch file, or None if file not exist. - * @throws IllegalArgumentException when path does not point to a batch file. - */ - def get(batchFile: Path): Option[T] = { - if (fileManager.exists(batchFile)) { - if (isBatchFile(batchFile)) { - get(pathToBatchId(batchFile)) - } else { - throw new IllegalArgumentException(s"File ${batchFile} is not a batch file!") - } - } else { - None + s"Multiple streaming queries are concurrently using $path", e) + case e: Throwable => + output.cancel() + throw e } } @@ -219,7 +166,7 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: (endId.isEmpty || batchId <= endId.get) && (startId.isEmpty || batchId >= startId.get) }.sorted - verifyBatchIds(batchIds, startId, endId) + HDFSMetadataLog.verifyBatchIds(batchIds, startId, endId) batchIds.map(batchId => (batchId, get(batchId))).filter(_._2.isDefined).map { case (batchId, metadataOption) => @@ -280,19 +227,6 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: } } - private def createFileManager(): FileManager = { - val hadoopConf = sparkSession.sessionState.newHadoopConf() - try { - new FileContextManager(metadataPath, hadoopConf) - } catch { - case e: UnsupportedFileSystemException => - logWarning("Could not use FileContext API for managing metadata log files at path " + - s"$metadataPath. Using FileSystem API instead for managing log files. The log may be " + - s"inconsistent under failures.") - new FileSystemManager(metadataPath, hadoopConf) - } - } - /** * Parse the log version from the given `text` -- will throw exception when the parsed version * exceeds `maxSupportedVersion`, or when `text` is malformed (such as "xyz", "v", "v-1", @@ -327,135 +261,6 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: object HDFSMetadataLog { - /** A simple trait to abstract out the file management operations needed by HDFSMetadataLog. */ - trait FileManager { - - /** List the files in a path that match a filter. */ - def list(path: Path, filter: PathFilter): Array[FileStatus] - - /** Make directory at the give path and all its parent directories as needed. */ - def mkdirs(path: Path): Unit - - /** Whether path exists */ - def exists(path: Path): Boolean - - /** Open a file for reading, or throw exception if it does not exist. */ - def open(path: Path): FSDataInputStream - - /** Create path, or throw exception if it already exists */ - def create(path: Path): FSDataOutputStream - - /** - * Atomically rename path, or throw exception if it cannot be done. - * Should throw FileNotFoundException if srcPath does not exist. - * Should throw FileAlreadyExistsException if destPath already exists. - */ - def rename(srcPath: Path, destPath: Path): Unit - - /** Recursively delete a path if it exists. Should not throw exception if file doesn't exist. */ - def delete(path: Path): Unit - } - - /** - * Default implementation of FileManager using newer FileContext API. - */ - class FileContextManager(path: Path, hadoopConf: Configuration) extends FileManager { - private val fc = if (path.toUri.getScheme == null) { - FileContext.getFileContext(hadoopConf) - } else { - FileContext.getFileContext(path.toUri, hadoopConf) - } - - override def list(path: Path, filter: PathFilter): Array[FileStatus] = { - fc.util.listStatus(path, filter) - } - - override def rename(srcPath: Path, destPath: Path): Unit = { - fc.rename(srcPath, destPath) - } - - override def mkdirs(path: Path): Unit = { - fc.mkdir(path, FsPermission.getDirDefault, true) - } - - override def open(path: Path): FSDataInputStream = { - fc.open(path) - } - - override def create(path: Path): FSDataOutputStream = { - fc.create(path, EnumSet.of(CreateFlag.CREATE)) - } - - override def exists(path: Path): Boolean = { - fc.util().exists(path) - } - - override def delete(path: Path): Unit = { - try { - fc.delete(path, true) - } catch { - case e: FileNotFoundException => - // ignore if file has already been deleted - } - } - } - - /** - * Implementation of FileManager using older FileSystem API. Note that this implementation - * cannot provide atomic renaming of paths, hence can lead to consistency issues. This - * should be used only as a backup option, when FileContextManager cannot be used. - */ - class FileSystemManager(path: Path, hadoopConf: Configuration) extends FileManager { - private val fs = path.getFileSystem(hadoopConf) - - override def list(path: Path, filter: PathFilter): Array[FileStatus] = { - fs.listStatus(path, filter) - } - - /** - * Rename a path. Note that this implementation is not atomic. - * @throws FileNotFoundException if source path does not exist. - * @throws FileAlreadyExistsException if destination path already exists. - * @throws IOException if renaming fails for some unknown reason. - */ - override def rename(srcPath: Path, destPath: Path): Unit = { - if (!fs.exists(srcPath)) { - throw new FileNotFoundException(s"Source path does not exist: $srcPath") - } - if (fs.exists(destPath)) { - throw new FileAlreadyExistsException(s"Destination path already exists: $destPath") - } - if (!fs.rename(srcPath, destPath)) { - throw new IOException(s"Failed to rename $srcPath to $destPath") - } - } - - override def mkdirs(path: Path): Unit = { - fs.mkdirs(path, FsPermission.getDirDefault) - } - - override def open(path: Path): FSDataInputStream = { - fs.open(path) - } - - override def create(path: Path): FSDataOutputStream = { - fs.create(path, false) - } - - override def exists(path: Path): Boolean = { - fs.exists(path) - } - - override def delete(path: Path): Unit = { - try { - fs.delete(path, true) - } catch { - case e: FileNotFoundException => - // ignore if file has already been deleted - } - } - } - /** * Verify if batchIds are continuous and between `startId` and `endId`. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 3f5002a4e6937..df722b953228b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.streaming.state -import java.io.{DataInputStream, DataOutputStream, FileNotFoundException, IOException} +import java.io._ import java.nio.channels.ClosedChannelException import java.util.Locale @@ -27,13 +27,16 @@ import scala.util.Random import scala.util.control.NonFatal import com.google.common.io.ByteStreams +import org.apache.commons.io.IOUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs._ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.io.LZ4CompressionCodec import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.CheckpointFileManager +import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream import org.apache.spark.sql.types.StructType import org.apache.spark.util.{SizeEstimator, Utils} @@ -87,10 +90,10 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit case object ABORTED extends STATE private val newVersion = version + 1 - private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") - private lazy val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true)) @volatile private var state: STATE = UPDATING - @volatile private var finalDeltaFile: Path = null + private val finalDeltaFile: Path = deltaFile(newVersion) + private lazy val deltaFileStream = fm.createAtomic(finalDeltaFile, overwriteIfPossible = true) + private lazy val compressedStream = compressStream(deltaFileStream) override def id: StateStoreId = HDFSBackedStateStoreProvider.this.stateStoreId @@ -103,14 +106,14 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit val keyCopy = key.copy() val valueCopy = value.copy() mapToUpdate.put(keyCopy, valueCopy) - writeUpdateToDeltaFile(tempDeltaFileStream, keyCopy, valueCopy) + writeUpdateToDeltaFile(compressedStream, keyCopy, valueCopy) } override def remove(key: UnsafeRow): Unit = { verify(state == UPDATING, "Cannot remove after already committed or aborted") val prevValue = mapToUpdate.remove(key) if (prevValue != null) { - writeRemoveToDeltaFile(tempDeltaFileStream, key) + writeRemoveToDeltaFile(compressedStream, key) } } @@ -126,8 +129,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit verify(state == UPDATING, "Cannot commit after already committed or aborted") try { - finalizeDeltaFile(tempDeltaFileStream) - finalDeltaFile = commitUpdates(newVersion, mapToUpdate, tempDeltaFile) + commitUpdates(newVersion, mapToUpdate, compressedStream) state = COMMITTED logInfo(s"Committed version $newVersion for $this to file $finalDeltaFile") newVersion @@ -140,23 +142,14 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** Abort all the updates made on this store. This store will not be usable any more. */ override def abort(): Unit = { - verify(state == UPDATING || state == ABORTED, "Cannot abort after already committed") - try { + // This if statement is to ensure that files are deleted only if there are changes to the + // StateStore. We have two StateStores for each task, one which is used only for reading, and + // the other used for read+write. We don't want the read-only to delete state files. + if (state == UPDATING) { + state = ABORTED + cancelDeltaFile(compressedStream, deltaFileStream) + } else { state = ABORTED - if (tempDeltaFileStream != null) { - tempDeltaFileStream.close() - } - if (tempDeltaFile != null) { - fs.delete(tempDeltaFile, true) - } - } catch { - case c: ClosedChannelException => - // This can happen when underlying file output stream has been closed before the - // compression stream. - logDebug(s"Error aborting version $newVersion into $this", c) - - case e: Exception => - logWarning(s"Error aborting version $newVersion into $this", e) } logInfo(s"Aborted version $newVersion for $this") } @@ -212,7 +205,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit this.valueSchema = valueSchema this.storeConf = storeConf this.hadoopConf = hadoopConf - fs.mkdirs(baseDir) + fm.mkdirs(baseDir) } override def stateStoreId: StateStoreId = stateStoreId_ @@ -251,31 +244,15 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit private lazy val loadedMaps = new mutable.HashMap[Long, MapType] private lazy val baseDir = stateStoreId.storeCheckpointLocation() - private lazy val fs = baseDir.getFileSystem(hadoopConf) + private lazy val fm = CheckpointFileManager.create(baseDir, hadoopConf) private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) - /** Commit a set of updates to the store with the given new version */ - private def commitUpdates(newVersion: Long, map: MapType, tempDeltaFile: Path): Path = { + private def commitUpdates(newVersion: Long, map: MapType, output: DataOutputStream): Unit = { synchronized { - val finalDeltaFile = deltaFile(newVersion) - - // scalastyle:off - // Renaming a file atop an existing one fails on HDFS - // (http://hadoop.apache.org/docs/stable/hadoop-project-dist/hadoop-common/filesystem/filesystem.html). - // Hence we should either skip the rename step or delete the target file. Because deleting the - // target file will break speculation, skipping the rename step is the only choice. It's still - // semantically correct because Structured Streaming requires rerunning a batch should - // generate the same output. (SPARK-19677) - // scalastyle:on - if (fs.exists(finalDeltaFile)) { - fs.delete(tempDeltaFile, true) - } else if (!fs.rename(tempDeltaFile, finalDeltaFile)) { - throw new IOException(s"Failed to rename $tempDeltaFile to $finalDeltaFile") - } + finalizeDeltaFile(output) loadedMaps.put(newVersion, map) - finalDeltaFile } } @@ -365,7 +342,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit val fileToRead = deltaFile(version) var input: DataInputStream = null val sourceStream = try { - fs.open(fileToRead) + fm.open(fileToRead) } catch { case f: FileNotFoundException => throw new IllegalStateException( @@ -412,12 +389,12 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } private def writeSnapshotFile(version: Long, map: MapType): Unit = { - val fileToWrite = snapshotFile(version) - val tempFile = - new Path(fileToWrite.getParent, s"${fileToWrite.getName}.temp-${Random.nextLong}") + val targetFile = snapshotFile(version) + var rawOutput: CancellableFSDataOutputStream = null var output: DataOutputStream = null - Utils.tryWithSafeFinally { - output = compressStream(fs.create(tempFile, false)) + try { + rawOutput = fm.createAtomic(targetFile, overwriteIfPossible = true) + output = compressStream(rawOutput) val iter = map.entrySet().iterator() while(iter.hasNext) { val entry = iter.next() @@ -429,16 +406,34 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit output.write(valueBytes) } output.writeInt(-1) - } { - if (output != null) output.close() + output.close() + } catch { + case e: Throwable => + cancelDeltaFile(compressedStream = output, rawStream = rawOutput) + throw e } - if (fs.exists(fileToWrite)) { - // Skip rename if the file is alreayd created. - fs.delete(tempFile, true) - } else if (!fs.rename(tempFile, fileToWrite)) { - throw new IOException(s"Failed to rename $tempFile to $fileToWrite") + logInfo(s"Written snapshot file for version $version of $this at $targetFile") + } + + /** + * Try to cancel the underlying stream and safely close the compressed stream. + * + * @param compressedStream the compressed stream. + * @param rawStream the underlying stream which needs to be cancelled. + */ + private def cancelDeltaFile( + compressedStream: DataOutputStream, + rawStream: CancellableFSDataOutputStream): Unit = { + try { + if (rawStream != null) rawStream.cancel() + IOUtils.closeQuietly(compressedStream) + } catch { + case e: FSError if e.getCause.isInstanceOf[IOException] => + // Closing the compressedStream causes the stream to write/flush flush data into the + // rawStream. Since the rawStream is already closed, there may be errors. + // Usually its an IOException. However, Hadoop's RawLocalFileSystem wraps + // IOException into FSError. } - logInfo(s"Written snapshot file for version $version of $this at $fileToWrite") } private def readSnapshotFile(version: Long): Option[MapType] = { @@ -447,7 +442,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit var input: DataInputStream = null try { - input = decompressStream(fs.open(fileToRead)) + input = decompressStream(fm.open(fileToRead)) var eof = false while (!eof) { @@ -508,7 +503,6 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit case None => // The last map is not loaded, probably some other instance is in charge } - } } catch { case NonFatal(e) => @@ -534,7 +528,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } val filesToDelete = files.filter(_.version < earliestFileToRetain.version) filesToDelete.foreach { f => - fs.delete(f.path, true) + fm.delete(f.path) } logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this: " + filesToDelete.mkString(", ")) @@ -576,7 +570,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** Fetch all the files that back the store */ private def fetchFiles(): Seq[StoreFile] = { val files: Seq[FileStatus] = try { - fs.listStatus(baseDir) + fm.list(baseDir) } catch { case _: java.io.FileNotFoundException => Seq.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index d1d9f95cb0977..7eb68c21569ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -459,7 +459,6 @@ object StateStore extends Logging { private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized { val env = SparkEnv.get if (env != null) { - logInfo("Env is not null") val isDriver = env.executorId == SparkContext.DRIVER_IDENTIFIER || env.executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER @@ -467,13 +466,12 @@ object StateStore extends Logging { // as SparkContext + SparkEnv may have been restarted. Hence, when running in driver, // always recreate the reference. if (isDriver || _coordRef == null) { - logInfo("Getting StateStoreCoordinatorRef") + logDebug("Getting StateStoreCoordinatorRef") _coordRef = StateStoreCoordinatorRef.forExecutor(env) } logInfo(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}") Some(_coordRef) } else { - logInfo("Env is null") _coordRef = null None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala new file mode 100644 index 0000000000000..fe59cb25d5005 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.io._ +import java.net.URI + +import scala.util.Random + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.Utils + +abstract class CheckpointFileManagerTests extends SparkFunSuite { + + def createManager(path: Path): CheckpointFileManager + + test("mkdirs, list, createAtomic, open, delete, exists") { + withTempPath { p => + val basePath = new Path(p.getAbsolutePath) + val fm = createManager(basePath) + // Mkdirs + val dir = new Path(s"$basePath/dir/subdir/subsubdir") + assert(!fm.exists(dir)) + fm.mkdirs(dir) + assert(fm.exists(dir)) + fm.mkdirs(dir) + + // List + val acceptAllFilter = new PathFilter { + override def accept(path: Path): Boolean = true + } + val rejectAllFilter = new PathFilter { + override def accept(path: Path): Boolean = false + } + assert(fm.list(basePath, acceptAllFilter).exists(_.getPath.getName == "dir")) + assert(fm.list(basePath, rejectAllFilter).length === 0) + + // Create atomic without overwrite + var path = new Path(s"$dir/file") + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = false).cancel() + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = false).close() + assert(fm.exists(path)) + quietly { + intercept[IOException] { + // should throw exception since file exists and overwrite is false + fm.createAtomic(path, overwriteIfPossible = false).close() + } + } + + // Create atomic with overwrite if possible + path = new Path(s"$dir/file2") + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = true).cancel() + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = true).close() + assert(fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = true).close() // should not throw exception + + // Open and delete + fm.open(path).close() + fm.delete(path) + assert(!fm.exists(path)) + intercept[IOException] { + fm.open(path) + } + fm.delete(path) // should not throw exception + } + } + + protected def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } +} + +class CheckpointFileManagerSuite extends SparkFunSuite with SharedSparkSession { + + test("CheckpointFileManager.create() should pick up user-specified class from conf") { + withSQLConf( + SQLConf.STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key -> + classOf[CreateAtomicTestManager].getName) { + val fileManager = + CheckpointFileManager.create(new Path("/"), spark.sessionState.newHadoopConf) + assert(fileManager.isInstanceOf[CreateAtomicTestManager]) + } + } + + test("CheckpointFileManager.create() should fallback from FileContext to FileSystem") { + import CheckpointFileManagerSuiteFileSystem.scheme + spark.conf.set(s"fs.$scheme.impl", classOf[CheckpointFileManagerSuiteFileSystem].getName) + quietly { + withTempDir { temp => + val metadataLog = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") + assert(metadataLog.add(0, "batch0")) + assert(metadataLog.getLatest() === Some(0 -> "batch0")) + assert(metadataLog.get(0) === Some("batch0")) + assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0")) + + val metadataLog2 = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") + assert(metadataLog2.get(0) === Some("batch0")) + assert(metadataLog2.getLatest() === Some(0 -> "batch0")) + assert(metadataLog2.get(None, Some(0)) === Array(0 -> "batch0")) + } + } + } +} + +class FileContextBasedCheckpointFileManagerSuite extends CheckpointFileManagerTests { + override def createManager(path: Path): CheckpointFileManager = { + new FileContextBasedCheckpointFileManager(path, new Configuration()) + } +} + +class FileSystemBasedCheckpointFileManagerSuite extends CheckpointFileManagerTests { + override def createManager(path: Path): CheckpointFileManager = { + new FileSystemBasedCheckpointFileManager(path, new Configuration()) + } +} + + +/** A fake implementation to test different characteristics of CheckpointFileManager interface */ +class CreateAtomicTestManager(path: Path, hadoopConf: Configuration) + extends FileSystemBasedCheckpointFileManager(path, hadoopConf) { + + import CheckpointFileManager._ + + override def createAtomic(path: Path, overwrite: Boolean): CancellableFSDataOutputStream = { + if (CreateAtomicTestManager.shouldFailInCreateAtomic) { + CreateAtomicTestManager.cancelCalledInCreateAtomic = false + } + val originalOut = super.createAtomic(path, overwrite) + + new CancellableFSDataOutputStream(originalOut) { + override def close(): Unit = { + if (CreateAtomicTestManager.shouldFailInCreateAtomic) { + throw new IOException("Copy failed intentionally") + } + super.close() + } + + override def cancel(): Unit = { + CreateAtomicTestManager.cancelCalledInCreateAtomic = true + originalOut.cancel() + } + } + } +} + +object CreateAtomicTestManager { + @volatile var shouldFailInCreateAtomic = false + @volatile var cancelCalledInCreateAtomic = false +} + + +/** + * CheckpointFileManagerSuiteFileSystem to test fallback of the CheckpointFileManager + * from FileContext to FileSystem API. + */ +private class CheckpointFileManagerSuiteFileSystem extends RawLocalFileSystem { + import CheckpointFileManagerSuiteFileSystem.scheme + + override def getUri: URI = { + URI.create(s"$scheme:///") + } +} + +private object CheckpointFileManagerSuiteFileSystem { + val scheme = s"CheckpointFileManagerSuiteFileSystem${math.abs(Random.nextInt)}" +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala index 12eaf63415081..ec961a9ecb592 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -22,15 +22,10 @@ import java.nio.charset.StandardCharsets._ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.streaming.FakeFileSystem._ import org.apache.spark.sql.test.SharedSQLContext class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext { - /** To avoid caching of FS objects */ - override protected def sparkConf = - super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") - import CompactibleFileStreamLog._ /** -- testing of `object CompactibleFileStreamLog` begins -- */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 4677769c12a35..9268306ce4275 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -17,46 +17,22 @@ package org.apache.spark.sql.execution.streaming -import java.io.{File, FileNotFoundException, IOException} -import java.net.URI +import java.io.File import java.util.ConcurrentModificationException import scala.language.implicitConversions -import scala.util.Random -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs._ import org.scalatest.concurrent.Waiters._ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.execution.streaming.FakeFileSystem._ -import org.apache.spark.sql.execution.streaming.HDFSMetadataLog.{FileContextManager, FileManager, FileSystemManager} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.UninterruptibleThread class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { - /** To avoid caching of FS objects */ - override protected def sparkConf = - super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") - private implicit def toOption[A](a: A): Option[A] = Option(a) - test("FileManager: FileContextManager") { - withTempDir { temp => - val path = new Path(temp.getAbsolutePath) - testFileManager(path, new FileContextManager(path, new Configuration)) - } - } - - test("FileManager: FileSystemManager") { - withTempDir { temp => - val path = new Path(temp.getAbsolutePath) - testFileManager(path, new FileSystemManager(path, new Configuration)) - } - } - test("HDFSMetadataLog: basic") { withTempDir { temp => val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir @@ -82,26 +58,6 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } } - testQuietly("HDFSMetadataLog: fallback from FileContext to FileSystem") { - spark.conf.set( - s"fs.$scheme.impl", - classOf[FakeFileSystem].getName) - withTempDir { temp => - val metadataLog = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") - assert(metadataLog.add(0, "batch0")) - assert(metadataLog.getLatest() === Some(0 -> "batch0")) - assert(metadataLog.get(0) === Some("batch0")) - assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0")) - - - val metadataLog2 = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") - assert(metadataLog2.get(0) === Some("batch0")) - assert(metadataLog2.getLatest() === Some(0 -> "batch0")) - assert(metadataLog2.get(None, Some(0)) === Array(0 -> "batch0")) - - } - } - test("HDFSMetadataLog: purge") { withTempDir { temp => val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) @@ -121,7 +77,8 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { // There should be exactly one file, called "2", in the metadata directory. // This check also tests for regressions of SPARK-17475 - val allFiles = new File(metadataLog.metadataPath.toString).listFiles().toSeq + val allFiles = new File(metadataLog.metadataPath.toString).listFiles() + .filter(!_.getName.startsWith(".")).toSeq assert(allFiles.size == 1) assert(allFiles(0).getName() == "2") } @@ -172,7 +129,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } } - test("HDFSMetadataLog: metadata directory collision") { + testQuietly("HDFSMetadataLog: metadata directory collision") { withTempDir { temp => val waiter = new Waiter val maxBatchId = 100 @@ -206,60 +163,6 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } } - /** Basic test case for [[FileManager]] implementation. */ - private def testFileManager(basePath: Path, fm: FileManager): Unit = { - // Mkdirs - val dir = new Path(s"$basePath/dir/subdir/subsubdir") - assert(!fm.exists(dir)) - fm.mkdirs(dir) - assert(fm.exists(dir)) - fm.mkdirs(dir) - - // List - val acceptAllFilter = new PathFilter { - override def accept(path: Path): Boolean = true - } - val rejectAllFilter = new PathFilter { - override def accept(path: Path): Boolean = false - } - assert(fm.list(basePath, acceptAllFilter).exists(_.getPath.getName == "dir")) - assert(fm.list(basePath, rejectAllFilter).length === 0) - - // Create - val path = new Path(s"$dir/file") - assert(!fm.exists(path)) - fm.create(path).close() - assert(fm.exists(path)) - intercept[IOException] { - fm.create(path) - } - - // Open and delete - fm.open(path).close() - fm.delete(path) - assert(!fm.exists(path)) - intercept[IOException] { - fm.open(path) - } - fm.delete(path) // should not throw exception - - // Rename - val path1 = new Path(s"$dir/file1") - val path2 = new Path(s"$dir/file2") - fm.create(path1).close() - assert(fm.exists(path1)) - fm.rename(path1, path2) - intercept[FileNotFoundException] { - fm.rename(path1, path2) - } - val path3 = new Path(s"$dir/file3") - fm.create(path3).close() - assert(fm.exists(path3)) - intercept[FileAlreadyExistsException] { - fm.rename(path2, path3) - } - } - test("verifyBatchIds") { import HDFSMetadataLog.verifyBatchIds verifyBatchIds(Seq(1L, 2L, 3L), Some(1L), Some(3L)) @@ -277,14 +180,3 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { intercept[IllegalStateException](verifyBatchIds(Seq(1, 2, 4, 5), Some(1L), Some(5L))) } } - -/** FakeFileSystem to test fallback of the HDFSMetadataLog from FileContext to FileSystem API */ -class FakeFileSystem extends RawLocalFileSystem { - override def getUri: URI = { - URI.create(s"$scheme:///") - } -} - -object FakeFileSystem { - val scheme = s"HDFSMetadataLogSuite${math.abs(Random.nextInt)}" -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index c843b65020d8c..73f8705060402 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{File, IOException} import java.net.URI import java.util.UUID -import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import scala.collection.mutable @@ -28,17 +27,17 @@ import scala.util.Random import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} +import org.apache.hadoop.fs._ import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.LocalSparkContext._ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} +import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.count import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -138,7 +137,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(getData(provider, 19) === Set("a" -> 19)) } - test("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") { + testQuietly("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") { val conf = new Configuration() conf.set("fs.fake.impl", classOf[RenameLikeHDFSFileSystem].getName) conf.set("fs.defaultFS", "fake:///") @@ -344,7 +343,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } } - test("SPARK-18342: commit fails when rename fails") { + testQuietly("SPARK-18342: commit fails when rename fails") { import RenameReturnsFalseFileSystem._ val dir = scheme + "://" + newDir() val conf = new Configuration() @@ -366,7 +365,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] def numTempFiles: Int = { if (deltaFileDir.exists) { - deltaFileDir.listFiles.map(_.getName).count(n => n.contains("temp") && !n.startsWith(".")) + deltaFileDir.listFiles.map(_.getName).count(n => n.endsWith(".tmp")) } else 0 } @@ -471,6 +470,43 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } } + test("error writing [version].delta cancels the output stream") { + + val hadoopConf = new Configuration() + hadoopConf.set( + SQLConf.STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key, + classOf[CreateAtomicTestManager].getName) + val remoteDir = Utils.createTempDir().getAbsolutePath + + val provider = newStoreProvider( + opId = Random.nextInt, partition = 0, dir = remoteDir, hadoopConf = hadoopConf) + + // Disable failure of output stream and generate versions + CreateAtomicTestManager.shouldFailInCreateAtomic = false + for (version <- 1 to 10) { + val store = provider.getStore(version - 1) + put(store, version.toString, version) // update "1" -> 1, "2" -> 2, ... + store.commit() + } + val version10Data = (1L to 10).map(_.toString).map(x => x -> x).toSet + + CreateAtomicTestManager.cancelCalledInCreateAtomic = false + val store = provider.getStore(10) + // Fail commit for next version and verify that reloading resets the files + CreateAtomicTestManager.shouldFailInCreateAtomic = true + put(store, "11", 11) + val e = intercept[IllegalStateException] { quietly { store.commit() } } + assert(e.getCause.isInstanceOf[IOException]) + CreateAtomicTestManager.shouldFailInCreateAtomic = false + + // Abort commit for next version and verify that reloading resets the files + CreateAtomicTestManager.cancelCalledInCreateAtomic = false + val store2 = provider.getStore(10) + put(store2, "11", 11) + store2.abort() + assert(CreateAtomicTestManager.cancelCalledInCreateAtomic) + } + override def newStoreProvider(): HDFSBackedStateStoreProvider = { newStoreProvider(opId = Random.nextInt(), partition = 0) } @@ -720,6 +756,14 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] * this provider */ def getData(storeProvider: ProviderClass, version: Int): Set[(String, Int)] + + protected def testQuietly(name: String)(f: => Unit): Unit = { + test(name) { + quietly { + f + } + } + } } object StateStoreTestsHelper { From 73f28530d6f6dd8aba758ea818c456cf911a5f41 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 14 Apr 2018 08:59:04 +0800 Subject: [PATCH 0621/2461] [SPARK-23979][SQL] MultiAlias should not be a CodegenFallback ## What changes were proposed in this pull request? Just found `MultiAlias` is a `CodegenFallback`. It should not be as looks like `MultiAlias` won't be evaluated. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #21065 from viirya/multialias-without-codegenfallback. --- .../org/apache/spark/sql/catalyst/analysis/unresolved.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index a65f58fa61ff4..71e23175168e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.parser.ParserUtils import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode} import org.apache.spark.sql.catalyst.trees.TreeNode @@ -335,7 +335,7 @@ case class UnresolvedRegex(regexPattern: String, table: Option[String], caseSens * @param names the names to be associated with each output of computing [[child]]. */ case class MultiAlias(child: Expression, names: Seq[String]) - extends UnaryExpression with NamedExpression with CodegenFallback { + extends UnaryExpression with NamedExpression with Unevaluable { override def name: String = throw new UnresolvedException(this, "name") From c0964935d614bf345535439bce01cbd0e60c86aa Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Mon, 16 Apr 2018 12:01:42 +0800 Subject: [PATCH 0622/2461] [SPARK-23956][YARN] Use effective RPC port in AM registration ## What changes were proposed in this pull request? We propose not to hard-code the RPC port in the AM registration. ## How was this patch tested? Tested application reports from a pseudo-distributed cluster ``` 18/04/10 14:56:21 INFO Client: client token: N/A diagnostics: N/A ApplicationMaster host: localhost ApplicationMaster RPC port: 58338 queue: default start time: 1523397373659 final status: UNDEFINED tracking URL: http://localhost:8088/proxy/application_1523370127531_0016/ ``` Author: Gera Shegalov Closes #21047 from gerashegalov/gera/am-to-rm-nmhost. --- .../scala/org/apache/spark/deploy/yarn/YarnRMClient.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index c1ae12aabb8cc..17234b120ae13 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -29,7 +29,6 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.util.Utils /** * Handles registering and unregistering the application with the YARN ResourceManager. @@ -71,7 +70,8 @@ private[spark] class YarnRMClient extends Logging { logInfo("Registering the ApplicationMaster") synchronized { - amClient.registerApplicationMaster(Utils.localHostName(), 0, trackingUrl) + amClient.registerApplicationMaster(driverRef.address.host, driverRef.address.port, + trackingUrl) registered = true } new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr, From 69310220319163bac18c9ee69d7da6d92227253b Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 15 Apr 2018 21:45:55 -0700 Subject: [PATCH 0623/2461] [SPARK-23917][SQL] Add array_max function ## What changes were proposed in this pull request? The PR adds the SQL function `array_max`. It takes an array as argument and returns the maximum value in it. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21024 from mgaido91/SPARK-23917. --- python/pyspark/sql/functions.py | 15 ++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/expressions/arithmetic.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 17 +++++ .../expressions/collectionOperations.scala | 68 ++++++++++++++++++- .../CollectionExpressionsSuite.scala | 10 +++ .../org/apache/spark/sql/functions.scala | 8 +++ .../spark/sql/DataFrameFunctionsSuite.scala | 14 ++++ 8 files changed, 133 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1b192680f0795..f3492ae42639c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2080,6 +2080,21 @@ def size(col): return Column(sc._jvm.functions.size(_to_java_column(col))) +@since(2.4) +def array_max(col): + """ + Collection function: returns the maximum value of the array. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) + >>> df.select(array_max(df.data).alias('max')).collect() + [Row(max=3), Row(max=10)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_max(_to_java_column(col))) + + @since(1.5) def sort_array(col, asc=True): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 131b958239e41..05bfa2dd45340 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -409,6 +409,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), + expression[ArrayMax]("array_max"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 9212c3de1f814..942dfd4292610 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -674,11 +674,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { val evals = evalChildren.map(eval => s""" |${eval.code} - |if (!${eval.isNull} && (${ev.isNull} || - | ${ctx.genGreater(dataType, eval.value, ev.value)})) { - | ${ev.isNull} = false; - | ${ev.value} = ${eval.value}; - |} + |${ctx.reassignIfGreater(dataType, ev, eval)} """.stripMargin ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 0abfc9fa4c465..c86c5beded9d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -699,6 +699,23 @@ class CodegenContext { case _ => s"(${genComp(dataType, c1, c2)}) > 0" } + /** + * Generates code for updating `partialResult` if `item` is greater than it. + * + * @param dataType data type of the expressions + * @param partialResult `ExprCode` representing the partial result which has to be updated + * @param item `ExprCode` representing the new expression to evaluate for the result + */ + def reassignIfGreater(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = { + s""" + |if (!${item.isNull} && (${partialResult.isNull} || + | ${genGreater(dataType, item.value, partialResult.value)})) { + | ${partialResult.isNull} = false; + | ${partialResult.value} = ${item.value}; + |} + """.stripMargin + } + /** * Generates code to do null safe execution, i.e. only execute the code when the input is not * null by adding null check if necessary. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 91188da8b0bd3..e2614a179aad8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -21,7 +21,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ /** @@ -287,3 +287,69 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Returns the maximum value in the array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns the maximum value in the array. NULL elements are skipped.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 20, null, 3)); + 20 + """, since = "2.4.0") +case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def nullable: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + } else { + typeCheckResult + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val javaType = CodeGenerator.javaType(dataType) + val i = ctx.freshName("i") + val item = ExprCode("", + isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), + value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) + ev.copy(code = + s""" + |${childGen.code} + |boolean ${ev.isNull} = true; + |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${childGen.isNull}) { + | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) { + | ${ctx.reassignIfGreater(dataType, ev, item)} + | } + |} + """.stripMargin) + } + + override protected def nullSafeEval(input: Any): Any = { + var max: Any = null + input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => + if (item != null && (max == null || ordering.gt(item, max))) { + max = item + } + ) + max + } + + override def dataType: DataType = child.dataType match { + case ArrayType(dt, _) => dt + case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") + } + + override def prettyName: String = "array_max" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 020687e4b3a27..a2384019533b7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -105,4 +105,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + + test("Array max") { + checkEvaluation(ArrayMax(Literal.create(Seq(1, 10, 2), ArrayType(IntegerType))), 10) + checkEvaluation( + ArrayMax(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "abc") + checkEvaluation(ArrayMax(Literal.create(Seq(null), ArrayType(LongType))), null) + checkEvaluation(ArrayMax(Literal.create(null, ArrayType(StringType))), null) + checkEvaluation( + ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c658f25ced053..daf407926dca4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3300,6 +3300,14 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } + /** + * Returns the maximum value in the array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) } + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 50e475984f458..5d5d92c84df6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -413,6 +413,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("array_max function") { + val df = Seq( + Seq[Option[Int]](Some(1), Some(3), Some(2)), + Seq.empty[Option[Int]], + Seq[Option[Int]](None), + Seq[Option[Int]](None, Some(1), Some(-100)) + ).toDF("a") + + val answer = Seq(Row(3), Row(null), Row(null), Row(1)) + + checkAnswer(df.select(array_max(df("a"))), answer) + checkAnswer(df.selectExpr("array_max(a)"), answer) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 083cf223569b7896e35ff1d53a73498a4971b28d Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 16 Apr 2018 23:50:50 +0800 Subject: [PATCH 0624/2461] [SPARK-21033][CORE][FOLLOW-UP] Update Spillable ## What changes were proposed in this pull request? Update ```scala SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue) ``` to ```scala SparkEnv.get.conf.get(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD) ``` because of `SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD`'s default value is `Integer.MAX_VALUE`: https://github.com/apache/spark/blob/c99fc9ad9b600095baba003053dbf84304ca392b/core/src/main/scala/org/apache/spark/internal/config/package.scala#L503-L511 ## How was this patch tested? N/A Author: Yuming Wang Closes #21077 from wangyum/SPARK-21033. --- .../org/apache/spark/util/collection/Spillable.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 8183f825592c0..81457b53cd814 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -19,6 +19,7 @@ package org.apache.spark.util.collection import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager} /** @@ -41,7 +42,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) protected def forceSpill(): Boolean // Number of elements read from input since last spill - protected def elementsRead: Long = _elementsRead + protected def elementsRead: Int = _elementsRead // Called by subclasses every time a record is read // It's used for checking spilling frequency @@ -54,15 +55,15 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) // Force this collection to spill when there are this many elements in memory // For testing only - private[this] val numElementsForceSpillThreshold: Long = - SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue) + private[this] val numElementsForceSpillThreshold: Int = + SparkEnv.get.conf.get(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD) // Threshold for this collection's size in bytes before we start tracking its memory usage // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0 @volatile private[this] var myMemoryThreshold = initialMemoryThreshold // Number of elements read from input since last spill - private[this] var _elementsRead = 0L + private[this] var _elementsRead = 0 // Number of bytes spilled in total @volatile private[this] var _memoryBytesSpilled = 0L From 5003736ad60c3231bb18264c9561646c08379170 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Mon, 16 Apr 2018 11:27:30 -0500 Subject: [PATCH 0625/2461] [SPARK-9312][ML] Add RawPrediction, numClasses, and numFeatures for OneVsRestModel add RawPrediction as output column add numClasses and numFeatures to OneVsRestModel ## What changes were proposed in this pull request? - Add two val numClasses and numFeatures in OneVsRestModel so that we can inherit from Classifier in the future - Add rawPrediction output column in transform, the prediction label in calculated by the rawPrediciton like raw2prediction ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21044 from ludatabricks/SPARK-9312. --- .../spark/ml/classification/OneVsRest.scala | 56 +++++++++++++++---- .../ml/classification/OneVsRestSuite.scala | 7 ++- 2 files changed, 51 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index f04fde2cbbca1..5348d882cfd67 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -32,7 +32,7 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol} import org.apache.spark.ml.util._ @@ -55,7 +55,7 @@ private[ml] trait ClassifierTypeTrait { /** * Params for [[OneVsRest]]. */ -private[ml] trait OneVsRestParams extends PredictorParams +private[ml] trait OneVsRestParams extends ClassifierParams with ClassifierTypeTrait with HasWeightCol { /** @@ -138,6 +138,14 @@ final class OneVsRestModel private[ml] ( @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]]) extends Model[OneVsRestModel] with OneVsRestParams with MLWritable { + require(models.nonEmpty, "OneVsRestModel requires at least one model for one class") + + @Since("2.4.0") + val numClasses: Int = models.length + + @Since("2.4.0") + val numFeatures: Int = models.head.numFeatures + /** @group setParam */ @Since("2.1.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) @@ -146,6 +154,10 @@ final class OneVsRestModel private[ml] ( @Since("2.1.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group setParam */ + @Since("2.4.0") + def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value) + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType) @@ -181,6 +193,7 @@ final class OneVsRestModel private[ml] ( val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) => predictions + ((index, prediction(1))) } + model.setFeaturesCol($(featuresCol)) val transformedDataset = model.transform(df).select(columns: _*) val updatedDataset = transformedDataset @@ -195,15 +208,34 @@ final class OneVsRestModel private[ml] ( newDataset.unpersist() } - // output the index of the classifier with highest confidence as prediction - val labelUDF = udf { (predictions: Map[Int, Double]) => - predictions.maxBy(_._2)._1.toDouble - } + if (getRawPredictionCol != "") { + val numClass = models.length - // output label and label metadata as prediction - aggregatedDataset - .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata) - .drop(accColName) + // output the RawPrediction as vector + val rawPredictionUDF = udf { (predictions: Map[Int, Double]) => + val predArray = Array.fill[Double](numClass)(0.0) + predictions.foreach { case (idx, value) => predArray(idx) = value } + Vectors.dense(predArray) + } + + // output the index of the classifier with highest confidence as prediction + val labelUDF = udf { (rawPredictions: Vector) => rawPredictions.argmax.toDouble } + + // output confidence as raw prediction, label and label metadata as prediction + aggregatedDataset + .withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName))) + .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata) + .drop(accColName) + } else { + // output the index of the classifier with highest confidence as prediction + val labelUDF = udf { (predictions: Map[Int, Double]) => + predictions.maxBy(_._2)._1.toDouble + } + // output label and label metadata as prediction + aggregatedDataset + .withColumn(getPredictionCol, labelUDF(col(accColName)), labelMetadata) + .drop(accColName) + } } @Since("1.4.1") @@ -297,6 +329,10 @@ final class OneVsRest @Since("1.4.0") ( @Since("1.5.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group setParam */ + @Since("2.4.0") + def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value) + /** * The implementation of parallel one vs. rest runs the classification for * each class in a separate threads. diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 11e88367108b4..2c3417c7e4028 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -72,11 +72,12 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { .setClassifier(new LogisticRegression) assert(ova.getLabelCol === "label") assert(ova.getPredictionCol === "prediction") + assert(ova.getRawPredictionCol === "rawPrediction") val ovaModel = ova.fit(dataset) MLTestingUtils.checkCopyAndUids(ova, ovaModel) - assert(ovaModel.models.length === numClasses) + assert(ovaModel.numClasses === numClasses) val transformedDataset = ovaModel.transform(dataset) // check for label metadata in prediction col @@ -179,6 +180,7 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { val dataset2 = dataset.select(col("label").as("y"), col("features").as("fea")) ovaModel.setFeaturesCol("fea") ovaModel.setPredictionCol("pred") + ovaModel.setRawPredictionCol("") val transformedDataset = ovaModel.transform(dataset2) val outputFields = transformedDataset.schema.fieldNames.toSet assert(outputFields === Set("y", "fea", "pred")) @@ -190,7 +192,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { val ovr = new OneVsRest() .setClassifier(logReg) val output = ovr.fit(dataset).transform(dataset) - assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction")) + assert(output.schema.fieldNames.toSet + === Set("label", "features", "prediction", "rawPrediction")) } test("SPARK-21306: OneVsRest should support setWeightCol") { From 04614820e103feeae91299dc90dba1dd628fd485 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 16 Apr 2018 11:31:24 -0500 Subject: [PATCH 0626/2461] [SPARK-21088][ML] CrossValidator, TrainValidationSplit support collect all models when fitting: Python API ## What changes were proposed in this pull request? Add python API for collecting sub-models during CrossValidator/TrainValidationSplit fitting. ## How was this patch tested? UT added. Author: WeichenXu Closes #19627 from WeichenXu123/expose-model-list-py. --- .../spark/ml/tuning/CrossValidator.scala | 11 ++ .../ml/tuning/TrainValidationSplit.scala | 11 ++ .../ml/param/_shared_params_code_gen.py | 5 + python/pyspark/ml/param/shared.py | 24 ++++ python/pyspark/ml/tests.py | 78 +++++++++++++ python/pyspark/ml/tuning.py | 107 +++++++++++++----- python/pyspark/ml/util.py | 4 + 7 files changed, 211 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index a0b507d2e718c..c2826dcc08634 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -270,6 +270,17 @@ class CrossValidatorModel private[ml] ( this } + // A Python-friendly auxiliary method + private[tuning] def setSubModels(subModels: JList[JList[Model[_]]]) + : CrossValidatorModel = { + _subModels = if (subModels != null) { + Some(subModels.asScala.toArray.map(_.asScala.toArray)) + } else { + None + } + this + } + /** * @return submodels represented in two dimension array. The index of outer array is the * fold index, and the index of inner array corresponds to the ordering of diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 88ff0dfd75e96..8d1b9a8ddab59 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -262,6 +262,17 @@ class TrainValidationSplitModel private[ml] ( this } + // A Python-friendly auxiliary method + private[tuning] def setSubModels(subModels: JList[Model[_]]) + : TrainValidationSplitModel = { + _subModels = if (subModels != null) { + Some(subModels.asScala.toArray) + } else { + None + } + this + } + /** * @return submodels represented in array. The index of array corresponds to the ordering of * estimatorParamMaps diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index db951d81de1e7..6e9e0a34cdfde 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -157,6 +157,11 @@ def get$Name(self): "TypeConverters.toInt"), ("parallelism", "the number of threads to use when running parallel algorithms (>= 1).", "1", "TypeConverters.toInt"), + ("collectSubModels", "Param for whether to collect a list of sub-models trained during " + + "tuning. If set to false, then only the single best sub-model will be available after " + + "fitting. If set to true, then all sub-models will be available. Warning: For large " + + "models, collecting all sub-models can cause OOMs on the Spark driver.", + "False", "TypeConverters.toBoolean"), ("loss", "the loss function to be optimized.", None, "TypeConverters.toString")] code = [] diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 474c38764e5a1..08408ee8fbfcc 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -655,6 +655,30 @@ def getParallelism(self): return self.getOrDefault(self.parallelism) +class HasCollectSubModels(Params): + """ + Mixin for param collectSubModels: Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver. + """ + + collectSubModels = Param(Params._dummy(), "collectSubModels", "Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver.", typeConverter=TypeConverters.toBoolean) + + def __init__(self): + super(HasCollectSubModels, self).__init__() + self._setDefault(collectSubModels=False) + + def setCollectSubModels(self, value): + """ + Sets the value of :py:attr:`collectSubModels`. + """ + return self._set(collectSubModels=value) + + def getCollectSubModels(self): + """ + Gets the value of collectSubModels or its default value. + """ + return self.getOrDefault(self.collectSubModels) + + class HasLoss(Params): """ Mixin for param loss: the loss function to be optimized. diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 4ce54547eab09..2ec0be60e9fa9 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1018,6 +1018,50 @@ def test_parallel_evaluation(self): cvParallelModel = cv.fit(dataset) self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics) + def test_expose_sub_models(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + + numFolds = 3 + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + numFolds=numFolds, collectSubModels=True) + + def checkSubModels(subModels): + self.assertEqual(len(subModels), numFolds) + for i in range(numFolds): + self.assertEqual(len(subModels[i]), len(grid)) + + cvModel = cv.fit(dataset) + checkSubModels(cvModel.subModels) + + # Test the default value for option "persistSubModel" to be "true" + testSubPath = temp_path + "/testCrossValidatorSubModels" + savingPathWithSubModels = testSubPath + "cvModel3" + cvModel.save(savingPathWithSubModels) + cvModel3 = CrossValidatorModel.load(savingPathWithSubModels) + checkSubModels(cvModel3.subModels) + cvModel4 = cvModel3.copy() + checkSubModels(cvModel4.subModels) + + savingPathWithoutSubModels = testSubPath + "cvModel2" + cvModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) + cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels) + self.assertEqual(cvModel2.subModels, None) + + for i in range(numFolds): + for j in range(len(grid)): + self.assertEqual(cvModel.subModels[i][j].uid, cvModel3.subModels[i][j].uid) + def test_save_load_nested_estimator(self): temp_path = tempfile.mkdtemp() dataset = self.spark.createDataFrame( @@ -1186,6 +1230,40 @@ def test_parallel_evaluation(self): tvsParallelModel = tvs.fit(dataset) self.assertEqual(tvsSerialModel.validationMetrics, tvsParallelModel.validationMetrics) + def test_expose_sub_models(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + collectSubModels=True) + tvsModel = tvs.fit(dataset) + self.assertEqual(len(tvsModel.subModels), len(grid)) + + # Test the default value for option "persistSubModel" to be "true" + testSubPath = temp_path + "/testTrainValidationSplitSubModels" + savingPathWithSubModels = testSubPath + "cvModel3" + tvsModel.save(savingPathWithSubModels) + tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels) + self.assertEqual(len(tvsModel3.subModels), len(grid)) + tvsModel4 = tvsModel3.copy() + self.assertEqual(len(tvsModel4.subModels), len(grid)) + + savingPathWithoutSubModels = testSubPath + "cvModel2" + tvsModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) + tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels) + self.assertEqual(tvsModel2.subModels, None) + + for i in range(len(grid)): + self.assertEqual(tvsModel.subModels[i].uid, tvsModel3.subModels[i].uid) + def test_save_load_nested_estimator(self): # This tests saving and loading the trained model only. # Save/load for TrainValidationSplit will be added later: SPARK-13786 diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 545e24ca05aa5..0c8029f293cfe 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -24,7 +24,7 @@ from pyspark.ml import Estimator, Model from pyspark.ml.common import _py2java from pyspark.ml.param import Params, Param, TypeConverters -from pyspark.ml.param.shared import HasParallelism, HasSeed +from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed from pyspark.ml.util import * from pyspark.ml.wrapper import JavaParams from pyspark.sql.functions import rand @@ -33,7 +33,7 @@ 'TrainValidationSplitModel'] -def _parallelFitTasks(est, train, eva, validation, epm): +def _parallelFitTasks(est, train, eva, validation, epm, collectSubModel): """ Creates a list of callables which can be called from different threads to fit and evaluate an estimator in parallel. Each callable returns an `(index, metric)` pair. @@ -43,14 +43,15 @@ def _parallelFitTasks(est, train, eva, validation, epm): :param eva: Evaluator, used to compute `metric` :param validation: DataFrame, validation data set, used for evaluation. :param epm: Sequence of ParamMap, params maps to be used during fitting & evaluation. - :return: (int, float), an index into `epm` and the associated metric value. + :param collectSubModel: Whether to collect sub model. + :return: (int, float, subModel), an index into `epm` and the associated metric value. """ modelIter = est.fitMultiple(train, epm) def singleTask(): index, model = next(modelIter) metric = eva.evaluate(model.transform(validation, epm[index])) - return index, metric + return index, metric, model if collectSubModel else None return [singleTask] * len(epm) @@ -194,7 +195,8 @@ def _to_java_impl(self): return java_estimator, java_epms, java_evaluator -class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable): +class CrossValidator(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels, + MLReadable, MLWritable): """ K-fold cross validation performs model selection by splitting the dataset into a set of @@ -233,10 +235,10 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLW @keyword_only def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, - seed=None, parallelism=1): + seed=None, parallelism=1, collectSubModels=False): """ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ - seed=None, parallelism=1) + seed=None, parallelism=1, collectSubModels=False) """ super(CrossValidator, self).__init__() self._setDefault(numFolds=3, parallelism=1) @@ -246,10 +248,10 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF @keyword_only @since("1.4.0") def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, - seed=None, parallelism=1): + seed=None, parallelism=1, collectSubModels=False): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ - seed=None, parallelism=1): + seed=None, parallelism=1, collectSubModels=False): Sets params for cross validator. """ kwargs = self._input_kwargs @@ -282,6 +284,10 @@ def _fit(self, dataset): metrics = [0.0] * numModels pool = ThreadPool(processes=min(self.getParallelism(), numModels)) + subModels = None + collectSubModelsParam = self.getCollectSubModels() + if collectSubModelsParam: + subModels = [[None for j in range(numModels)] for i in range(nFolds)] for i in range(nFolds): validateLB = i * h @@ -290,9 +296,12 @@ def _fit(self, dataset): validation = df.filter(condition).cache() train = df.filter(~condition).cache() - tasks = _parallelFitTasks(est, train, eva, validation, epm) - for j, metric in pool.imap_unordered(lambda f: f(), tasks): + tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam) + for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks): metrics[j] += (metric / nFolds) + if collectSubModelsParam: + subModels[i][j] = subModel + validation.unpersist() train.unpersist() @@ -301,7 +310,7 @@ def _fit(self, dataset): else: bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) - return self._copyValues(CrossValidatorModel(bestModel, metrics)) + return self._copyValues(CrossValidatorModel(bestModel, metrics, subModels)) @since("1.4.0") def copy(self, extra=None): @@ -345,9 +354,11 @@ def _from_java(cls, java_stage): numFolds = java_stage.getNumFolds() seed = java_stage.getSeed() parallelism = java_stage.getParallelism() + collectSubModels = java_stage.getCollectSubModels() # Create a new instance of this stage. py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, - numFolds=numFolds, seed=seed, parallelism=parallelism) + numFolds=numFolds, seed=seed, parallelism=parallelism, + collectSubModels=collectSubModels) py_stage._resetUid(java_stage.uid()) return py_stage @@ -367,6 +378,7 @@ def _to_java(self): _java_obj.setSeed(self.getSeed()) _java_obj.setNumFolds(self.getNumFolds()) _java_obj.setParallelism(self.getParallelism()) + _java_obj.setCollectSubModels(self.getCollectSubModels()) return _java_obj @@ -381,13 +393,15 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable): .. versionadded:: 1.4.0 """ - def __init__(self, bestModel, avgMetrics=[]): + def __init__(self, bestModel, avgMetrics=[], subModels=None): super(CrossValidatorModel, self).__init__() #: best model from cross validation self.bestModel = bestModel #: Average cross-validation metrics for each paramMap in #: CrossValidator.estimatorParamMaps, in the corresponding order. self.avgMetrics = avgMetrics + #: sub model list from cross validation + self.subModels = subModels def _transform(self, dataset): return self.bestModel.transform(dataset) @@ -399,6 +413,7 @@ def copy(self, extra=None): and some extra params. This copies the underlying bestModel, creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. + It does not copy the extra Params into the subModels. :param extra: Extra parameters to copy to the new instance :return: Copy of this instance @@ -407,7 +422,8 @@ def copy(self, extra=None): extra = dict() bestModel = self.bestModel.copy(extra) avgMetrics = self.avgMetrics - return CrossValidatorModel(bestModel, avgMetrics) + subModels = self.subModels + return CrossValidatorModel(bestModel, avgMetrics, subModels) @since("2.3.0") def write(self): @@ -426,13 +442,17 @@ def _from_java(cls, java_stage): Given a Java CrossValidatorModel, create and return a Python wrapper of it. Used for ML persistence. """ - bestModel = JavaParams._from_java(java_stage.bestModel()) estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage) py_stage = cls(bestModel=bestModel).setEstimator(estimator) py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator) + if java_stage.hasSubModels(): + py_stage.subModels = [[JavaParams._from_java(sub_model) + for sub_model in fold_sub_models] + for fold_sub_models in java_stage.subModels()] + py_stage._resetUid(java_stage.uid()) return py_stage @@ -454,10 +474,16 @@ def _to_java(self): _java_obj.set("evaluator", evaluator) _java_obj.set("estimator", estimator) _java_obj.set("estimatorParamMaps", epms) + + if self.subModels is not None: + java_sub_models = [[sub_model._to_java() for sub_model in fold_sub_models] + for fold_sub_models in self.subModels] + _java_obj.setSubModels(java_sub_models) return _java_obj -class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable): +class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels, + MLReadable, MLWritable): """ .. note:: Experimental @@ -492,10 +518,10 @@ class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadabl @keyword_only def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, - parallelism=1, seed=None): + parallelism=1, collectSubModels=False, seed=None): """ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\ - parallelism=1, seed=None) + parallelism=1, collectSubModels=False, seed=None) """ super(TrainValidationSplit, self).__init__() self._setDefault(trainRatio=0.75, parallelism=1) @@ -505,10 +531,10 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trai @since("2.0.0") @keyword_only def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, - parallelism=1, seed=None): + parallelism=1, collectSubModels=False, seed=None): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\ - parallelism=1, seed=None): + parallelism=1, collectSubModels=False, seed=None): Sets params for the train validation split. """ kwargs = self._input_kwargs @@ -541,11 +567,19 @@ def _fit(self, dataset): validation = df.filter(condition).cache() train = df.filter(~condition).cache() - tasks = _parallelFitTasks(est, train, eva, validation, epm) + subModels = None + collectSubModelsParam = self.getCollectSubModels() + if collectSubModelsParam: + subModels = [None for i in range(numModels)] + + tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam) pool = ThreadPool(processes=min(self.getParallelism(), numModels)) metrics = [None] * numModels - for j, metric in pool.imap_unordered(lambda f: f(), tasks): + for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks): metrics[j] = metric + if collectSubModelsParam: + subModels[j] = subModel + train.unpersist() validation.unpersist() @@ -554,7 +588,7 @@ def _fit(self, dataset): else: bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) - return self._copyValues(TrainValidationSplitModel(bestModel, metrics)) + return self._copyValues(TrainValidationSplitModel(bestModel, metrics, subModels)) @since("2.0.0") def copy(self, extra=None): @@ -598,9 +632,11 @@ def _from_java(cls, java_stage): trainRatio = java_stage.getTrainRatio() seed = java_stage.getSeed() parallelism = java_stage.getParallelism() + collectSubModels = java_stage.getCollectSubModels() # Create a new instance of this stage. py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, - trainRatio=trainRatio, seed=seed, parallelism=parallelism) + trainRatio=trainRatio, seed=seed, parallelism=parallelism, + collectSubModels=collectSubModels) py_stage._resetUid(java_stage.uid()) return py_stage @@ -620,7 +656,7 @@ def _to_java(self): _java_obj.setTrainRatio(self.getTrainRatio()) _java_obj.setSeed(self.getSeed()) _java_obj.setParallelism(self.getParallelism()) - + _java_obj.setCollectSubModels(self.getCollectSubModels()) return _java_obj @@ -633,12 +669,14 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable): .. versionadded:: 2.0.0 """ - def __init__(self, bestModel, validationMetrics=[]): + def __init__(self, bestModel, validationMetrics=[], subModels=None): super(TrainValidationSplitModel, self).__init__() - #: best model from cross validation + #: best model from train validation split self.bestModel = bestModel #: evaluated validation metrics self.validationMetrics = validationMetrics + #: sub models from train validation split + self.subModels = subModels def _transform(self, dataset): return self.bestModel.transform(dataset) @@ -651,6 +689,7 @@ def copy(self, extra=None): creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. And, this creates a shallow copy of the validationMetrics. + It does not copy the extra Params into the subModels. :param extra: Extra parameters to copy to the new instance :return: Copy of this instance @@ -659,7 +698,8 @@ def copy(self, extra=None): extra = dict() bestModel = self.bestModel.copy(extra) validationMetrics = list(self.validationMetrics) - return TrainValidationSplitModel(bestModel, validationMetrics) + subModels = self.subModels + return TrainValidationSplitModel(bestModel, validationMetrics, subModels) @since("2.3.0") def write(self): @@ -687,6 +727,10 @@ def _from_java(cls, java_stage): py_stage = cls(bestModel=bestModel).setEstimator(estimator) py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator) + if java_stage.hasSubModels(): + py_stage.subModels = [JavaParams._from_java(sub_model) + for sub_model in java_stage.subModels()] + py_stage._resetUid(java_stage.uid()) return py_stage @@ -708,6 +752,11 @@ def _to_java(self): _java_obj.set("evaluator", evaluator) _java_obj.set("estimator", estimator) _java_obj.set("estimatorParamMaps", epms) + + if self.subModels is not None: + java_sub_models = [sub_model._to_java() for sub_model in self.subModels] + _java_obj.setSubModels(java_sub_models) + return _java_obj diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index c3c47bd79459a..a486c6a3fdeb5 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -169,6 +169,10 @@ def overwrite(self): self._jwrite.overwrite() return self + def option(self, key, value): + self._jwrite.option(key, value) + return self + def context(self, sqlContext): """ Sets the SQL context to use for saving. From fd990a908b94d1c90c4ca604604f35a13b453d44 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 16 Apr 2018 22:45:57 +0200 Subject: [PATCH 0627/2461] [SPARK-23873][SQL] Use accessors in interpreted LambdaVariable ## What changes were proposed in this pull request? Currently, interpreted execution of `LambdaVariable` just uses `InternalRow.get` to access element. We should use specified accessors if possible. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #20981 from viirya/SPARK-23873. --- .../spark/sql/catalyst/InternalRow.scala | 26 ++++++++++++- .../catalyst/expressions/BoundAttribute.scala | 22 ++--------- .../expressions/objects/objects.scala | 8 +++- .../expressions/ExpressionEvalHelper.scala | 4 +- .../expressions/ObjectExpressionsSuite.scala | 38 ++++++++++++++++++- 5 files changed, 75 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 29110640d64f2..274d75e680f03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} -import org.apache.spark.sql.types.{DataType, Decimal, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** @@ -119,4 +119,28 @@ object InternalRow { case v: MapData => v.copy() case _ => value } + + /** + * Returns an accessor for an `InternalRow` with given data type. The returned accessor + * actually takes a `SpecializedGetters` input because it can be generalized to other classes + * that implements `SpecializedGetters` (e.g., `ArrayData`) too. + */ + def getAccessor(dataType: DataType): (SpecializedGetters, Int) => Any = dataType match { + case BooleanType => (input, ordinal) => input.getBoolean(ordinal) + case ByteType => (input, ordinal) => input.getByte(ordinal) + case ShortType => (input, ordinal) => input.getShort(ordinal) + case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal) + case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal) + case FloatType => (input, ordinal) => input.getFloat(ordinal) + case DoubleType => (input, ordinal) => input.getDouble(ordinal) + case StringType => (input, ordinal) => input.getUTF8String(ordinal) + case BinaryType => (input, ordinal) => input.getBinary(ordinal) + case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal) + case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale) + case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size) + case _: ArrayType => (input, ordinal) => input.getArray(ordinal) + case _: MapType => (input, ordinal) => input.getMap(ordinal) + case u: UserDefinedType[_] => getAccessor(u.sqlType) + case _ => (input, ordinal) => input.get(ordinal, dataType) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 5021a567592e0..4cc84b27d9eb0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -33,28 +33,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]" + private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType) + // Use special getter for primitive types (for UnsafeRow) override def eval(input: InternalRow): Any = { - if (input.isNullAt(ordinal)) { + if (nullable && input.isNullAt(ordinal)) { null } else { - dataType match { - case BooleanType => input.getBoolean(ordinal) - case ByteType => input.getByte(ordinal) - case ShortType => input.getShort(ordinal) - case IntegerType | DateType => input.getInt(ordinal) - case LongType | TimestampType => input.getLong(ordinal) - case FloatType => input.getFloat(ordinal) - case DoubleType => input.getDouble(ordinal) - case StringType => input.getUTF8String(ordinal) - case BinaryType => input.getBinary(ordinal) - case CalendarIntervalType => input.getInterval(ordinal) - case t: DecimalType => input.getDecimal(ordinal, t.precision, t.scale) - case t: StructType => input.getStruct(ordinal, t.size) - case _: ArrayType => input.getArray(ordinal) - case _: MapType => input.getMap(ordinal) - case _ => input.get(ordinal, dataType) - } + accessor(input, ordinal) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 50e90ca550807..77802e89e942b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -560,11 +560,17 @@ case class LambdaVariable( dataType: DataType, nullable: Boolean = true) extends LeafExpression with NonSQLExpression { + private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType) + // Interpreted execution of `LambdaVariable` always get the 0-index element from input row. override def eval(input: InternalRow): Any = { assert(input.numFields == 1, "The input row of interpreted LambdaVariable should have only 1 field.") - input.get(0, dataType) + if (nullable && input.isNullAt(0)) { + null + } else { + accessor(input, 0) + } } override def genCode(ctx: CodegenContext): ExprCode = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a5ecd1b68fac4..b4bf6d7107d7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -70,7 +70,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { * Check the equality between result of expression and expected value, it will handle * Array[Byte], Spread[Double], MapData and Row. */ - protected def checkResult(result: Any, expected: Any, dataType: DataType): Boolean = { + protected def checkResult(result: Any, expected: Any, exprDataType: DataType): Boolean = { + val dataType = UserDefinedType.sqlType(exprDataType) + (result, expected) match { case (result: Array[Byte], expected: Array[Byte]) => java.util.Arrays.equals(result, expected) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index b1bc67dfac1b5..b0188b0098def 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -21,13 +21,14 @@ import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ import scala.reflect.ClassTag +import scala.util.Random import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} -import org.apache.spark.sql.Row +import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util._ @@ -381,6 +382,39 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(decodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) } } + + test("LambdaVariable should support interpreted execution") { + def genSchema(dt: DataType): Seq[StructType] = { + Seq(StructType(StructField("col_1", dt, nullable = false) :: Nil), + StructType(StructField("col_1", dt, nullable = true) :: Nil)) + } + + val elementTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, + DoubleType, DecimalType.USER_DEFAULT, StringType, BinaryType, DateType, TimestampType, + CalendarIntervalType, new ExamplePointUDT()) + val arrayTypes = elementTypes.flatMap { elementType => + Seq(ArrayType(elementType, containsNull = false), ArrayType(elementType, containsNull = true)) + } + val mapTypes = elementTypes.flatMap { elementType => + Seq(MapType(elementType, elementType, false), MapType(elementType, elementType, true)) + } + val structTypes = elementTypes.flatMap { elementType => + Seq(StructType(StructField("col1", elementType, false) :: Nil), + StructType(StructField("col1", elementType, true) :: Nil)) + } + + val testTypes = elementTypes ++ arrayTypes ++ mapTypes ++ structTypes + val random = new Random(100) + testTypes.foreach { dt => + genSchema(dt).map { schema => + val row = RandomDataGenerator.randomRow(random, schema) + val rowConverter = RowEncoder(schema) + val internalRow = rowConverter.toRow(row) + val lambda = LambdaVariable("dummy", "dummuIsNull", schema(0).dataType, schema(0).nullable) + checkEvaluationWithoutCodegen(lambda, internalRow.get(0, schema(0).dataType), internalRow) + } + } + } } class TestBean extends Serializable { From 14844a62c025e7299029d7452b8c4003bc221ac8 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 17 Apr 2018 17:55:35 +0900 Subject: [PATCH 0628/2461] [SPARK-23918][SQL] Add array_min function ## What changes were proposed in this pull request? The PR adds the SQL function `array_min`. It takes an array as argument and returns the minimum value in it. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21025 from mgaido91/SPARK-23918. --- python/pyspark/sql/functions.py | 17 ++++- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/expressions/arithmetic.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 17 +++++ .../expressions/collectionOperations.scala | 64 +++++++++++++++++++ .../CollectionExpressionsSuite.scala | 10 +++ .../org/apache/spark/sql/functions.scala | 8 +++ .../spark/sql/DataFrameFunctionsSuite.scala | 14 ++++ 8 files changed, 131 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f3492ae42639c..6ca22b610843d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2080,6 +2080,21 @@ def size(col): return Column(sc._jvm.functions.size(_to_java_column(col))) +@since(2.4) +def array_min(col): + """ + Collection function: returns the minimum value of the array. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) + >>> df.select(array_min(df.data).alias('min')).collect() + [Row(min=1), Row(min=-1)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_min(_to_java_column(col))) + + @since(2.4) def array_max(col): """ @@ -2108,7 +2123,7 @@ def sort_array(col, asc=True): [Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])] >>> df.select(sort_array(df.data, asc=False).alias('r')).collect() [Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])] - """ + """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 05bfa2dd45340..4dd1ca509bf2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -409,6 +409,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), + expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), CreateStruct.registryEntry, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 942dfd4292610..d4e322d23b95b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -595,11 +595,7 @@ case class Least(children: Seq[Expression]) extends Expression { val evals = evalChildren.map(eval => s""" |${eval.code} - |if (!${eval.isNull} && (${ev.isNull} || - | ${ctx.genGreater(dataType, ev.value, eval.value)})) { - | ${ev.isNull} = false; - | ${ev.value} = ${eval.value}; - |} + |${ctx.reassignIfSmaller(dataType, ev, eval)} """.stripMargin ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c86c5beded9d0..d97611c98ac91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -699,6 +699,23 @@ class CodegenContext { case _ => s"(${genComp(dataType, c1, c2)}) > 0" } + /** + * Generates code for updating `partialResult` if `item` is smaller than it. + * + * @param dataType data type of the expressions + * @param partialResult `ExprCode` representing the partial result which has to be updated + * @param item `ExprCode` representing the new expression to evaluate for the result + */ + def reassignIfSmaller(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = { + s""" + |if (!${item.isNull} && (${partialResult.isNull} || + | ${genGreater(dataType, partialResult.value, item.value)})) { + | ${partialResult.isNull} = false; + | ${partialResult.value} = ${item.value}; + |} + """.stripMargin + } + /** * Generates code for updating `partialResult` if `item` is greater than it. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e2614a179aad8..7c87777eed47a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -288,6 +288,70 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Returns the minimum value in the array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns the minimum value in the array. NULL elements are skipped.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 20, null, 3)); + 1 + """, since = "2.4.0") +case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def nullable: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + } else { + typeCheckResult + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val javaType = CodeGenerator.javaType(dataType) + val i = ctx.freshName("i") + val item = ExprCode("", + isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), + value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) + ev.copy(code = + s""" + |${childGen.code} + |boolean ${ev.isNull} = true; + |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${childGen.isNull}) { + | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) { + | ${ctx.reassignIfSmaller(dataType, ev, item)} + | } + |} + """.stripMargin) + } + + override protected def nullSafeEval(input: Any): Any = { + var min: Any = null + input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => + if (item != null && (min == null || ordering.lt(item, min))) { + min = item + } + ) + min + } + + override def dataType: DataType = child.dataType match { + case ArrayType(dt, _) => dt + case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") + } + + override def prettyName: String = "array_min" +} /** * Returns the maximum value in the array. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index a2384019533b7..5a31e3a30edd6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -106,6 +106,16 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + test("Array Min") { + checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11) + checkEvaluation( + ArrayMin(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "") + checkEvaluation(ArrayMin(Literal.create(Seq(null), ArrayType(LongType))), null) + checkEvaluation(ArrayMin(Literal.create(null, ArrayType(StringType))), null) + checkEvaluation( + ArrayMin(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 0.1234) + } + test("Array max") { checkEvaluation(ArrayMax(Literal.create(Seq(1, 10, 2), ArrayType(IntegerType))), 10) checkEvaluation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index daf407926dca4..642ac056bb809 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3300,6 +3300,14 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } + /** + * Returns the minimum value in the array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_min(e: Column): Column = withExpr { ArrayMin(e.expr) } + /** * Returns the maximum value in the array. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 5d5d92c84df6d..636e86baedf6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -413,6 +413,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("array_min function") { + val df = Seq( + Seq[Option[Int]](Some(1), Some(3), Some(2)), + Seq.empty[Option[Int]], + Seq[Option[Int]](None), + Seq[Option[Int]](None, Some(1), Some(-100)) + ).toDF("a") + + val answer = Seq(Row(1), Row(null), Row(null), Row(-100)) + + checkAnswer(df.select(array_min(df("a"))), answer) + checkAnswer(df.selectExpr("array_min(a)"), answer) + } + test("array_max function") { val df = Seq( Seq[Option[Int]](Some(1), Some(3), Some(2)), From 1cc66a072b7fd3bf140fa41596f6b18f8d1bd7b9 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 17 Apr 2018 01:59:38 -0700 Subject: [PATCH 0629/2461] [SPARK-23687][SS] Add a memory source for continuous processing. ## What changes were proposed in this pull request? Add a memory source for continuous processing. Note that only one of the ContinuousSuite tests is migrated to minimize the diff here. I'll submit a second PR for SPARK-23688 to change the rest and get rid of waitForRateSourceTriggers. ## How was this patch tested? unit test Author: Jose Torres Closes #20828 from jose-torres/continuousMemory. --- .../continuous/ContinuousExecution.scala | 5 +- .../sql/execution/streaming/memory.scala | 59 +++-- .../sources/ContinuousMemoryStream.scala | 211 ++++++++++++++++++ .../spark/sql/streaming/StreamTest.scala | 4 +- .../continuous/ContinuousSuite.scala | 31 +-- 5 files changed, 266 insertions(+), 44 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 1758b3844bd62..951d694355ec5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} +import org.apache.spark.sql.sources.v2 import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} @@ -317,8 +318,10 @@ class ContinuousExecution( synchronized { if (queryExecutionThread.isAlive) { commitLog.add(epoch) - val offset = offsetLog.get(epoch).get.offsets(0).get + val offset = + continuousSources(0).deserializeOffset(offsetLog.get(epoch).get.offsets(0).get.json) committedOffsets ++= Seq(continuousSources(0) -> offset) + continuousSources(0).commit(offset.asInstanceOf[v2.reader.streaming.Offset]) } else { return } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 352d4ce9fbcaa..628923d367ce7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -24,17 +24,19 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, ListBuffer} +import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} -import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.streaming.{OutputMode, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -47,16 +49,43 @@ object MemoryStream { new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) } +/** + * A base class for memory stream implementations. Supports adding data and resetting. + */ +abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends BaseStreamingSource { + protected val encoder = encoderFor[A] + protected val attributes = encoder.schema.toAttributes + + def toDS(): Dataset[A] = { + Dataset[A](sqlContext.sparkSession, logicalPlan) + } + + def toDF(): DataFrame = { + Dataset.ofRows(sqlContext.sparkSession, logicalPlan) + } + + def addData(data: A*): Offset = { + addData(data.toTraversable) + } + + def readSchema(): StructType = encoder.schema + + protected def logicalPlan: LogicalPlan + + def addData(data: TraversableOnce[A]): Offset +} + /** * A [[Source]] that produces value stored in memory as they are added by the user. This [[Source]] * is intended for use in unit tests as it can only replay data when the object is still * available. */ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) - extends MicroBatchReader with SupportsScanUnsafeRow with Logging { - protected val encoder = encoderFor[A] - private val attributes = encoder.schema.toAttributes - protected val logicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) + extends MemoryStreamBase[A](sqlContext) + with MicroBatchReader with SupportsScanUnsafeRow with Logging { + + protected val logicalPlan: LogicalPlan = + StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) protected val output = logicalPlan.output /** @@ -70,7 +99,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) protected var currentOffset: LongOffset = new LongOffset(-1) @GuardedBy("this") - private var startOffset = new LongOffset(-1) + protected var startOffset = new LongOffset(-1) @GuardedBy("this") private var endOffset = new LongOffset(-1) @@ -82,18 +111,6 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) @GuardedBy("this") protected var lastOffsetCommitted : LongOffset = new LongOffset(-1) - def toDS(): Dataset[A] = { - Dataset(sqlContext.sparkSession, logicalPlan) - } - - def toDF(): DataFrame = { - Dataset.ofRows(sqlContext.sparkSession, logicalPlan) - } - - def addData(data: A*): Offset = { - addData(data.toTraversable) - } - def addData(data: TraversableOnce[A]): Offset = { val objects = data.toSeq val rows = objects.iterator.map(d => encoder.toRow(d).copy().asInstanceOf[UnsafeRow]).toArray @@ -114,8 +131,6 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } } - override def readSchema(): StructType = encoder.schema - override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) override def getStartOffset: OffsetV2 = synchronized { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala new file mode 100644 index 0000000000000..c28919b8b729b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.{util => ju} +import java.util.Optional +import java.util.concurrent.atomic.AtomicInteger +import javax.annotation.concurrent.GuardedBy + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ListBuffer + +import org.json4s.NoTypeHints +import org.json4s.jackson.Serialization + +import org.apache.spark.SparkEnv +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.sql.{Encoder, Row, SQLContext} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions} +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.RpcUtils + +/** + * The overall strategy here is: + * * ContinuousMemoryStream maintains a list of records for each partition. addData() will + * distribute records evenly-ish across partitions. + * * RecordEndpoint is set up as an endpoint for executor-side + * ContinuousMemoryStreamDataReader instances to poll. It returns the record at the specified + * offset within the list, or null if that offset doesn't yet have a record. + */ +class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) + extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { + private implicit val formats = Serialization.formats(NoTypeHints) + private val NUM_PARTITIONS = 2 + + protected val logicalPlan = + StreamingRelationV2(this, "memory", Map(), attributes, None)(sqlContext.sparkSession) + + // ContinuousReader implementation + + @GuardedBy("this") + private val records = Seq.fill(NUM_PARTITIONS)(new ListBuffer[A]) + + @GuardedBy("this") + private var startOffset: ContinuousMemoryStreamOffset = _ + + private val recordEndpoint = new RecordEndpoint() + @volatile private var endpointRef: RpcEndpointRef = _ + + def addData(data: TraversableOnce[A]): Offset = synchronized { + // Distribute data evenly among partition lists. + data.toSeq.zipWithIndex.map { + case (item, index) => records(index % NUM_PARTITIONS) += item + } + + // The new target offset is the offset where all records in all partitions have been processed. + ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, records(i).size)).toMap) + } + + override def setStartOffset(start: Optional[Offset]): Unit = synchronized { + // Inferred initial offset is position 0 in each partition. + startOffset = start.orElse { + ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, 0)).toMap) + }.asInstanceOf[ContinuousMemoryStreamOffset] + } + + override def getStartOffset: Offset = synchronized { + startOffset + } + + override def deserializeOffset(json: String): ContinuousMemoryStreamOffset = { + ContinuousMemoryStreamOffset(Serialization.read[Map[Int, Int]](json)) + } + + override def mergeOffsets(offsets: Array[PartitionOffset]): ContinuousMemoryStreamOffset = { + ContinuousMemoryStreamOffset( + offsets.map { + case ContinuousMemoryStreamPartitionOffset(part, num) => (part, num) + }.toMap + ) + } + + override def createDataReaderFactories(): ju.List[DataReaderFactory[Row]] = { + synchronized { + val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id" + endpointRef = + recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) + + startOffset.partitionNums.map { + case (part, index) => + new ContinuousMemoryStreamDataReaderFactory( + endpointName, part, index): DataReaderFactory[Row] + }.toList.asJava + } + } + + override def stop(): Unit = { + if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef) + } + + override def commit(end: Offset): Unit = {} + + // ContinuousReadSupport implementation + // This is necessary because of how StreamTest finds the source for AddDataMemory steps. + def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): ContinuousReader = { + this + } + + /** + * Endpoint for executors to poll for records. + */ + private class RecordEndpoint extends ThreadSafeRpcEndpoint { + override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case GetRecord(ContinuousMemoryStreamPartitionOffset(part, index)) => + ContinuousMemoryStream.this.synchronized { + val buf = records(part) + val record = if (buf.size <= index) None else Some(buf(index)) + + context.reply(record.map(Row(_))) + } + } + } +} + +object ContinuousMemoryStream { + case class GetRecord(offset: ContinuousMemoryStreamPartitionOffset) + protected val memoryStreamId = new AtomicInteger(0) + + def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = + new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) +} + +/** + * Data reader factory for continuous memory stream. + */ +class ContinuousMemoryStreamDataReaderFactory( + driverEndpointName: String, + partition: Int, + startOffset: Int) extends DataReaderFactory[Row] { + override def createDataReader: ContinuousMemoryStreamDataReader = + new ContinuousMemoryStreamDataReader(driverEndpointName, partition, startOffset) +} + +/** + * Data reader for continuous memory stream. + * + * Polls the driver endpoint for new records. + */ +class ContinuousMemoryStreamDataReader( + driverEndpointName: String, + partition: Int, + startOffset: Int) extends ContinuousDataReader[Row] { + private val endpoint = RpcUtils.makeDriverRef( + driverEndpointName, + SparkEnv.get.conf, + SparkEnv.get.rpcEnv) + + private var currentOffset = startOffset + private var current: Option[Row] = None + + override def next(): Boolean = { + current = None + while (current.isEmpty) { + Thread.sleep(10) + current = endpoint.askSync[Option[Row]]( + GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset))) + } + currentOffset += 1 + true + } + + override def get(): Row = current.get + + override def close(): Unit = {} + + override def getOffset: ContinuousMemoryStreamPartitionOffset = + ContinuousMemoryStreamPartitionOffset(partition, currentOffset) +} + +case class ContinuousMemoryStreamOffset(partitionNums: Map[Int, Int]) + extends Offset { + private implicit val formats = Serialization.formats(NoTypeHints) + override def json(): String = Serialization.write(partitionNums) +} + +case class ContinuousMemoryStreamPartitionOffset(partition: Int, numProcessed: Int) + extends PartitionOffset diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 00741d660dd2d..af0268fa47871 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -99,7 +99,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be * been processed. */ object AddData { - def apply[A](source: MemoryStream[A], data: A*): AddDataMemory[A] = + def apply[A](source: MemoryStreamBase[A], data: A*): AddDataMemory[A] = AddDataMemory(source, data) } @@ -131,7 +131,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def runAction(): Unit } - case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData { + case class AddDataMemory[A](source: MemoryStreamBase[A], data: Seq[A]) extends AddData { override def toString: String = s"AddData to $source: ${data.mkString(",")}" override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index ef74efef156d5..c318b951ff992 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.test.TestSparkSession @@ -53,32 +54,24 @@ class ContinuousSuiteBase extends StreamTest { // A continuous trigger that will only fire the initial time for the duration of a test. // This allows clean testing with manual epoch advancement. protected val longContinuousTrigger = Trigger.Continuous("1 hour") + + override protected val defaultTrigger = Trigger.Continuous(100) + override protected val defaultUseV2Sink = true } class ContinuousSuite extends ContinuousSuiteBase { import testImplicits._ - test("basic rate source") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) + test("basic") { + val input = ContinuousMemoryStream[Int] - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 10).map(Row(_))), + testStream(input.toDF())( + AddData(input, 0, 1, 2), + CheckAnswer(0, 1, 2), StopStream, - StartStream(longContinuousTrigger), - AwaitEpoch(2), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 20).map(Row(_))), - StopStream) + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(0, 1, 2, 3, 4, 5)) } test("map") { From 05ae74778a10fbdd7f2cbf7742de7855966b7d35 Mon Sep 17 00:00:00 2001 From: Efim Poberezkin Date: Tue, 17 Apr 2018 04:13:17 -0700 Subject: [PATCH 0630/2461] [SPARK-23747][STRUCTURED STREAMING] Add EpochCoordinator unit tests ## What changes were proposed in this pull request? Unit tests for EpochCoordinator that test correct sequencing of committed epochs. Several tests are ignored since they test functionality implemented in SPARK-23503 which is not yet merged, otherwise they fail. Author: Efim Poberezkin Closes #20983 from efimpoberezkin/pr/EpochCoordinator-tests. --- .../continuous/EpochCoordinatorSuite.scala | 224 ++++++++++++++++++ 1 file changed, 224 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala new file mode 100644 index 0000000000000..99e30561f81d5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.continuous + +import org.mockito.InOrder +import org.mockito.Matchers.{any, eq => eqTo} +import org.mockito.Mockito._ +import org.scalatest.BeforeAndAfterEach +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark._ +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.LocalSparkSession +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.test.TestSparkSession + +class EpochCoordinatorSuite + extends SparkFunSuite + with LocalSparkSession + with MockitoSugar + with BeforeAndAfterEach { + + private var epochCoordinator: RpcEndpointRef = _ + + private var writer: StreamWriter = _ + private var query: ContinuousExecution = _ + private var orderVerifier: InOrder = _ + + override def beforeEach(): Unit = { + val reader = mock[ContinuousReader] + writer = mock[StreamWriter] + query = mock[ContinuousExecution] + orderVerifier = inOrder(writer, query) + + spark = new TestSparkSession() + + epochCoordinator + = EpochCoordinatorRef.create(writer, reader, query, "test", 1, spark, SparkEnv.get) + } + + test("single epoch") { + setWriterPartitions(3) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + commitPartitionEpoch(2, 1) + reportPartitionOffset(0, 1) + reportPartitionOffset(1, 1) + + // Here and in subsequent tests this is called to make a synchronous call to EpochCoordinator + // so that mocks would have been acted upon by the time verification happens + makeSynchronousCall() + + verifyCommit(1) + } + + test("single epoch, all but one writer partition has committed") { + setWriterPartitions(3) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + reportPartitionOffset(0, 1) + reportPartitionOffset(1, 1) + + makeSynchronousCall() + + verifyNoCommitFor(1) + } + + test("single epoch, all but one reader partition has reported an offset") { + setWriterPartitions(3) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + commitPartitionEpoch(2, 1) + reportPartitionOffset(0, 1) + + makeSynchronousCall() + + verifyNoCommitFor(1) + } + + test("consequent epochs, messages for epoch (k + 1) arrive after messages for epoch k") { + setWriterPartitions(2) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + reportPartitionOffset(0, 1) + reportPartitionOffset(1, 1) + + commitPartitionEpoch(0, 2) + commitPartitionEpoch(1, 2) + reportPartitionOffset(0, 2) + reportPartitionOffset(1, 2) + + makeSynchronousCall() + + verifyCommitsInOrderOf(List(1, 2)) + } + + ignore("consequent epochs, a message for epoch k arrives after messages for epoch (k + 1)") { + setWriterPartitions(2) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + reportPartitionOffset(0, 1) + + commitPartitionEpoch(0, 2) + commitPartitionEpoch(1, 2) + reportPartitionOffset(0, 2) + reportPartitionOffset(1, 2) + + // Message that arrives late + reportPartitionOffset(1, 1) + + makeSynchronousCall() + + verifyCommitsInOrderOf(List(1, 2)) + } + + ignore("several epochs, messages arrive in order 1 -> 3 -> 4 -> 2") { + setWriterPartitions(1) + setReaderPartitions(1) + + commitPartitionEpoch(0, 1) + reportPartitionOffset(0, 1) + + commitPartitionEpoch(0, 3) + reportPartitionOffset(0, 3) + + commitPartitionEpoch(0, 4) + reportPartitionOffset(0, 4) + + commitPartitionEpoch(0, 2) + reportPartitionOffset(0, 2) + + makeSynchronousCall() + + verifyCommitsInOrderOf(List(1, 2, 3, 4)) + } + + ignore("several epochs, messages arrive in order 1 -> 3 -> 5 -> 4 -> 2") { + setWriterPartitions(1) + setReaderPartitions(1) + + commitPartitionEpoch(0, 1) + reportPartitionOffset(0, 1) + + commitPartitionEpoch(0, 3) + reportPartitionOffset(0, 3) + + commitPartitionEpoch(0, 5) + reportPartitionOffset(0, 5) + + commitPartitionEpoch(0, 4) + reportPartitionOffset(0, 4) + + commitPartitionEpoch(0, 2) + reportPartitionOffset(0, 2) + + makeSynchronousCall() + + verifyCommitsInOrderOf(List(1, 2, 3, 4, 5)) + } + + private def setWriterPartitions(numPartitions: Int): Unit = { + epochCoordinator.askSync[Unit](SetWriterPartitions(numPartitions)) + } + + private def setReaderPartitions(numPartitions: Int): Unit = { + epochCoordinator.askSync[Unit](SetReaderPartitions(numPartitions)) + } + + private def commitPartitionEpoch(partitionId: Int, epoch: Long): Unit = { + val dummyMessage: WriterCommitMessage = mock[WriterCommitMessage] + epochCoordinator.send(CommitPartitionEpoch(partitionId, epoch, dummyMessage)) + } + + private def reportPartitionOffset(partitionId: Int, epoch: Long): Unit = { + val dummyOffset: PartitionOffset = mock[PartitionOffset] + epochCoordinator.send(ReportPartitionOffset(partitionId, epoch, dummyOffset)) + } + + private def makeSynchronousCall(): Unit = { + epochCoordinator.askSync[Long](GetCurrentEpoch) + } + + private def verifyCommit(epoch: Long): Unit = { + orderVerifier.verify(writer).commit(eqTo(epoch), any()) + orderVerifier.verify(query).commit(epoch) + } + + private def verifyNoCommitFor(epoch: Long): Unit = { + verify(writer, never()).commit(eqTo(epoch), any()) + verify(query, never()).commit(epoch) + } + + private def verifyCommitsInOrderOf(epochs: Seq[Long]): Unit = { + epochs.foreach(verifyCommit) + } +} From 30ffb53cad84283b4f7694bfd60bdd7e1101b04e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 17 Apr 2018 15:09:36 +0200 Subject: [PATCH 0631/2461] [SPARK-23875][SQL] Add IndexedSeq wrapper for ArrayData ## What changes were proposed in this pull request? We don't have a good way to sequentially access `UnsafeArrayData` with a common interface such as `Seq`. An example is `MapObject` where we need to access several sequence collection types together. But `UnsafeArrayData` doesn't implement `ArrayData.array`. Calling `toArray` will copy the entire array. We can provide an `IndexedSeq` wrapper for `ArrayData`, so we can avoid copying the entire array. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #20984 from viirya/SPARK-23875. --- .../expressions/objects/objects.scala | 2 +- .../spark/sql/catalyst/util/ArrayData.scala | 30 +++++- .../util/ArrayDataIndexedSeqSuite.scala | 100 ++++++++++++++++++ 3 files changed, 130 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 77802e89e942b..72b202b3a5020 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -708,7 +708,7 @@ case class MapObjects private( } } case ArrayType(et, _) => - _.asInstanceOf[ArrayData].array + _.asInstanceOf[ArrayData].toSeq[Any](et) } private lazy val mapElements: Seq[_] => Any = customCollectionCls match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala index 9beef41d639f3..2cf59d567c08c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.util import scala.reflect.ClassTag +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, UnsafeArrayData} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types._ object ArrayData { def toArrayData(input: Any): ArrayData = input match { @@ -42,6 +43,9 @@ abstract class ArrayData extends SpecializedGetters with Serializable { def array: Array[Any] + def toSeq[T](dataType: DataType): IndexedSeq[T] = + new ArrayDataIndexedSeq[T](this, dataType) + def setNullAt(i: Int): Unit def update(i: Int, value: Any): Unit @@ -164,3 +168,27 @@ abstract class ArrayData extends SpecializedGetters with Serializable { } } } + +/** + * Implements an `IndexedSeq` interface for `ArrayData`. Notice that if the original `ArrayData` + * is a primitive array and contains null elements, it is better to ask for `IndexedSeq[Any]`, + * instead of `IndexedSeq[Int]`, in order to keep the null elements. + */ +class ArrayDataIndexedSeq[T](arrayData: ArrayData, dataType: DataType) extends IndexedSeq[T] { + + private val accessor: (SpecializedGetters, Int) => Any = InternalRow.getAccessor(dataType) + + override def apply(idx: Int): T = + if (0 <= idx && idx < arrayData.numElements()) { + if (arrayData.isNullAt(idx)) { + null.asInstanceOf[T] + } else { + accessor(arrayData, idx).asInstanceOf[T] + } + } else { + throw new IndexOutOfBoundsException( + s"Index $idx must be between 0 and the length of the ArrayData.") + } + + override def length: Int = arrayData.numElements() +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala new file mode 100644 index 0000000000000..6400898343ae7 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeArrayData, UnsafeProjection} +import org.apache.spark.sql.types._ + +class ArrayDataIndexedSeqSuite extends SparkFunSuite { + private def compArray(arrayData: ArrayData, elementDt: DataType, array: Array[Any]): Unit = { + assert(arrayData.numElements == array.length) + array.zipWithIndex.map { case (e, i) => + if (e != null) { + elementDt match { + // For NaN, etc. + case FloatType | DoubleType => assert(arrayData.get(i, elementDt).equals(e)) + case _ => assert(arrayData.get(i, elementDt) === e) + } + } else { + assert(arrayData.isNullAt(i)) + } + } + + val seq = arrayData.toSeq[Any](elementDt) + array.zipWithIndex.map { case (e, i) => + if (e != null) { + elementDt match { + // For Nan, etc. + case FloatType | DoubleType => assert(seq(i).equals(e)) + case _ => assert(seq(i) === e) + } + } else { + assert(seq(i) == null) + } + } + + intercept[IndexOutOfBoundsException] { + seq(-1) + }.getMessage().contains("must be between 0 and the length of the ArrayData.") + + intercept[IndexOutOfBoundsException] { + seq(seq.length) + }.getMessage().contains("must be between 0 and the length of the ArrayData.") + } + + private def testArrayData(): Unit = { + val elementTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, + DoubleType, DecimalType.USER_DEFAULT, StringType, BinaryType, DateType, TimestampType, + CalendarIntervalType, new ExamplePointUDT()) + val arrayTypes = elementTypes.flatMap { elementType => + Seq(ArrayType(elementType, containsNull = false), ArrayType(elementType, containsNull = true)) + } + val random = new Random(100) + arrayTypes.foreach { dt => + val schema = StructType(StructField("col_1", dt, nullable = false) :: Nil) + val row = RandomDataGenerator.randomRow(random, schema) + val rowConverter = RowEncoder(schema) + val internalRow = rowConverter.toRow(row) + + val unsafeRowConverter = UnsafeProjection.create(schema) + val safeRowConverter = FromUnsafeProjection(schema) + + val unsafeRow = unsafeRowConverter(internalRow) + val safeRow = safeRowConverter(unsafeRow) + + val genericArrayData = safeRow.getArray(0).asInstanceOf[GenericArrayData] + val unsafeArrayData = unsafeRow.getArray(0).asInstanceOf[UnsafeArrayData] + + val elementType = dt.elementType + test("ArrayDataIndexedSeq - UnsafeArrayData - " + dt.toString) { + compArray(unsafeArrayData, elementType, unsafeArrayData.toArray[Any](elementType)) + } + + test("ArrayDataIndexedSeq - GenericArrayData - " + dt.toString) { + compArray(genericArrayData, elementType, genericArrayData.toArray[Any](elementType)) + } + } + } + + testArrayData() +} From 0a9172a05e604a4a94adbb9208c8c02362afca00 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 17 Apr 2018 21:45:20 +0800 Subject: [PATCH 0632/2461] [SPARK-23835][SQL] Add not-null check to Tuples' arguments deserialization ## What changes were proposed in this pull request? There was no check on nullability for arguments of `Tuple`s. This could lead to have weird behavior when a null value had to be deserialized into a non-nullable Scala object: in those cases, the `null` got silently transformed in a valid value (like `-1` for `Int`), corresponding to the default value we are using in the SQL codebase. This situation was very likely to happen when deserializing to a Tuple of primitive Scala types (like Double, Int, ...). The PR adds the `AssertNotNull` to arguments of tuples which have been asked to be converted to non-nullable types. ## How was this patch tested? added UT Author: Marco Gaido Closes #20976 from mgaido91/SPARK-23835. --- .../sql/kafka010/KafkaContinuousSinkSuite.scala | 6 +++--- .../apache/spark/sql/kafka010/KafkaSinkSuite.scala | 2 +- .../spark/sql/catalyst/ScalaReflection.scala | 14 +++++++------- .../spark/sql/catalyst/ScalaReflectionSuite.scala | 12 +++++++++++- .../scala/org/apache/spark/sql/DatasetSuite.scala | 5 +++++ 5 files changed, 27 insertions(+), 12 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala index fc890a0cfdac3..ddfc0c1a4be2d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -79,7 +79,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { val reader = createKafkaReader(topic) .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") - .as[(Int, Int)] + .as[(Option[Int], Int)] .map(_._2) try { @@ -119,7 +119,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { val reader = createKafkaReader(topic) .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") - .as[(Int, Int)] + .as[(Option[Int], Int)] .map(_._2) try { @@ -167,7 +167,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { val reader = createKafkaReader(topic) .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .selectExpr("CAST(key AS INT)", "CAST(value AS INT)") - .as[(Int, Int)] + .as[(Option[Int], Int)] .map(_._2) try { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 42f8b4c7657e2..7079ac6453ffc 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -138,7 +138,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { val reader = createKafkaReader(topic) .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") - .as[(Int, Int)] + .as[(Option[Int], Int)] .map(_._2) try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 1aae3aea3a31a..e4274aaa9727e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -382,22 +382,22 @@ object ScalaReflection extends ScalaReflection { val clsName = getClassNameFromType(fieldType) val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath // For tuples, we based grab the inner fields by ordinal instead of name. - if (cls.getName startsWith "scala.Tuple") { + val constructor = if (cls.getName startsWith "scala.Tuple") { deserializerFor( fieldType, Some(addToPathOrdinal(i, dataType, newTypePath)), newTypePath) } else { - val constructor = deserializerFor( + deserializerFor( fieldType, Some(addToPath(fieldName, dataType, newTypePath)), newTypePath) + } - if (!nullable) { - AssertNotNull(constructor, newTypePath) - } else { - constructor - } + if (!nullable) { + AssertNotNull(constructor, newTypePath) + } else { + constructor } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 8c3db48a01f12..353b8344658f2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, SpecificInternalRow, UpCast} import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -365,4 +365,14 @@ class ScalaReflectionSuite extends SparkFunSuite { StructField("_2", NullType, nullable = true))), nullable = true)) } + + test("SPARK-23835: add null check to non-nullable types in Tuples") { + def numberOfCheckedArguments(deserializer: Expression): Int = { + assert(deserializer.isInstanceOf[NewInstance]) + deserializer.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull]) + } + assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]) == 2) + assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1) + assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 9b745befcb611..e0f4d2ba685e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1453,6 +1453,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val group2 = cached.groupBy("x").agg(min(col("z")) as "value") checkAnswer(group1.union(group2), Row(4, 5) :: Row(1, 2) :: Row(4, 6) :: Row(1, 3) :: Nil) } + + test("SPARK-23835: null primitive data type should throw NullPointerException") { + val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS() + intercept[NullPointerException](ds.as[(Int, Int)].collect()) + } } case class TestDataUnion(x: Int, y: Int, z: Int) From ed4101d29f50d54fd7846421e4c00e9ecd3599d0 Mon Sep 17 00:00:00 2001 From: jinxing Date: Tue, 17 Apr 2018 21:52:33 +0800 Subject: [PATCH 0633/2461] [SPARK-22676] Avoid iterating all partition paths when spark.sql.hive.verifyPartitionPath=true MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In current code, it will scanning all partition paths when spark.sql.hive.verifyPartitionPath=true. e.g. table like below: ``` CREATE TABLE `test`( `id` int, `age` int, `name` string) PARTITIONED BY ( `A` string, `B` string) load data local inpath '/tmp/data0' into table test partition(A='00', B='00') load data local inpath '/tmp/data1' into table test partition(A='01', B='01') load data local inpath '/tmp/data2' into table test partition(A='10', B='10') load data local inpath '/tmp/data3' into table test partition(A='11', B='11') ``` If I query with SQL – "select * from test where A='00' and B='01' ", current code will scan all partition paths including '/data/A=00/B=00', '/data/A=00/B=00', '/data/A=01/B=01', '/data/A=10/B=10', '/data/A=11/B=11'. It costs much time and memory cost. This pr proposes to avoid iterating all partition paths. Add a config `spark.files.ignoreMissingFiles` and ignore the `file not found` when `getPartitions/compute`(for hive table scan). This is much like the logic brought by `spark.sql.files.ignoreMissingFiles`(which is for datasource scan). ## How was this patch tested? UT Author: jinxing Closes #19868 from jinxing64/SPARK-22676. --- .../spark/internal/config/package.scala | 6 ++ .../org/apache/spark/rdd/HadoopRDD.scala | 43 +++++++++--- .../org/apache/spark/rdd/NewHadoopRDD.scala | 45 ++++++++---- .../scala/org/apache/spark/FileSuite.scala | 69 ++++++++++++++++++- .../apache/spark/sql/internal/SQLConf.scala | 3 +- .../spark/sql/hive/QueryPartitionSuite.scala | 40 +++++++++++ 6 files changed, 181 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 407545aa4a47a..99d779fb600e8 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -301,6 +301,12 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val IGNORE_MISSING_FILES = ConfigBuilder("spark.files.ignoreMissingFiles") + .doc("Whether to ignore missing files. If true, the Spark jobs will continue to run when " + + "encountering missing files and the contents that have been read will still be returned.") + .booleanConf + .createWithDefault(false) + private[spark] val APP_CALLER_CONTEXT = ConfigBuilder("spark.log.callerContext") .stringConf .createOptional diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 2480559a41b7a..44895abc7bd4d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import java.io.IOException +import java.io.{FileNotFoundException, IOException} import java.text.SimpleDateFormat import java.util.{Date, Locale} @@ -28,6 +28,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.mapred._ import org.apache.hadoop.mapred.lib.CombineFileSplit import org.apache.hadoop.mapreduce.TaskType +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.util.ReflectionUtils import org.apache.spark._ @@ -134,6 +135,8 @@ class HadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) + private val ignoreMissingFiles = sparkContext.conf.get(IGNORE_MISSING_FILES) + private val ignoreEmptySplits = sparkContext.conf.get(HADOOP_RDD_IGNORE_EMPTY_SPLITS) // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. @@ -197,17 +200,24 @@ class HadoopRDD[K, V]( val jobConf = getJobConf() // add the credentials here as this can be called before SparkContext initialized SparkHadoopUtil.get.addCredentials(jobConf) - val allInputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions) - val inputSplits = if (ignoreEmptySplits) { - allInputSplits.filter(_.getLength > 0) - } else { - allInputSplits - } - val array = new Array[Partition](inputSplits.size) - for (i <- 0 until inputSplits.size) { - array(i) = new HadoopPartition(id, i, inputSplits(i)) + try { + val allInputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions) + val inputSplits = if (ignoreEmptySplits) { + allInputSplits.filter(_.getLength > 0) + } else { + allInputSplits + } + val array = new Array[Partition](inputSplits.size) + for (i <- 0 until inputSplits.size) { + array(i) = new HadoopPartition(id, i, inputSplits(i)) + } + array + } catch { + case e: InvalidInputException if ignoreMissingFiles => + logWarning(s"${jobConf.get(FileInputFormat.INPUT_DIR)} doesn't exist and no" + + s" partitions returned from this path.", e) + Array.empty[Partition] } - array } override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { @@ -256,6 +266,12 @@ class HadoopRDD[K, V]( try { inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: ${split.inputSplit}", e) + finished = true + null + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e case e: IOException if ignoreCorruptFiles => logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e) finished = true @@ -276,6 +292,11 @@ class HadoopRDD[K, V]( try { finished = !reader.next(key, value) } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: ${split.inputSplit}", e) + finished = true + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e case e: IOException if ignoreCorruptFiles => logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e) finished = true diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index e4dd1b6a82498..ff66a04859d10 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import java.io.IOException +import java.io.{FileNotFoundException, IOException} import java.text.SimpleDateFormat import java.util.{Date, Locale} @@ -28,7 +28,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} +import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileInputFormat, FileSplit, InvalidInputException} import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl} import org.apache.spark._ @@ -90,6 +90,8 @@ class NewHadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) + private val ignoreMissingFiles = sparkContext.conf.get(IGNORE_MISSING_FILES) + private val ignoreEmptySplits = sparkContext.conf.get(HADOOP_RDD_IGNORE_EMPTY_SPLITS) def getConf: Configuration = { @@ -124,17 +126,25 @@ class NewHadoopRDD[K, V]( configurable.setConf(_conf) case _ => } - val allRowSplits = inputFormat.getSplits(new JobContextImpl(_conf, jobId)).asScala - val rawSplits = if (ignoreEmptySplits) { - allRowSplits.filter(_.getLength > 0) - } else { - allRowSplits - } - val result = new Array[Partition](rawSplits.size) - for (i <- 0 until rawSplits.size) { - result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + try { + val allRowSplits = inputFormat.getSplits(new JobContextImpl(_conf, jobId)).asScala + val rawSplits = if (ignoreEmptySplits) { + allRowSplits.filter(_.getLength > 0) + } else { + allRowSplits + } + val result = new Array[Partition](rawSplits.size) + for (i <- 0 until rawSplits.size) { + result(i) = + new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + result + } catch { + case e: InvalidInputException if ignoreMissingFiles => + logWarning(s"${_conf.get(FileInputFormat.INPUT_DIR)} doesn't exist and no" + + s" partitions returned from this path.", e) + Array.empty[Partition] } - result } override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { @@ -189,6 +199,12 @@ class NewHadoopRDD[K, V]( _reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) _reader } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: ${split.serializableHadoopSplit}", e) + finished = true + null + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e case e: IOException if ignoreCorruptFiles => logWarning( s"Skipped the rest content in the corrupted file: ${split.serializableHadoopSplit}", @@ -213,6 +229,11 @@ class NewHadoopRDD[K, V]( try { finished = !reader.nextKeyValue } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: ${split.serializableHadoopSplit}", e) + finished = true + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e case e: IOException if ignoreCorruptFiles => logWarning( s"Skipped the rest content in the corrupted file: ${split.serializableHadoopSplit}", diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 55a9122cf9026..a441b9c8ab97a 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -23,6 +23,7 @@ import java.util.zip.GZIPOutputStream import scala.io.Source +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io._ import org.apache.hadoop.io.compress.DefaultCodec @@ -32,7 +33,7 @@ import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInp import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} import org.apache.spark.internal.config._ -import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD} +import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -596,4 +597,70 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { actualPartitionNum = 5, expectedPartitionNum = 2) } + + test("spark.files.ignoreMissingFiles should work both HadoopRDD and NewHadoopRDD") { + // "file not found" can happen both when getPartitions or compute in HadoopRDD/NewHadoopRDD, + // We test both cases here. + + val deletedPath = new Path(tempDir.getAbsolutePath, "test-data-1") + val fs = deletedPath.getFileSystem(new Configuration()) + fs.delete(deletedPath, true) + intercept[FileNotFoundException](fs.open(deletedPath)) + + def collectRDDAndDeleteFileBeforeCompute(newApi: Boolean): Array[_] = { + val dataPath = new Path(tempDir.getAbsolutePath, "test-data-2") + val writer = new OutputStreamWriter(new FileOutputStream(new File(dataPath.toString))) + writer.write("hello\n") + writer.write("world\n") + writer.close() + val rdd = if (newApi) { + sc.newAPIHadoopFile(dataPath.toString, classOf[NewTextInputFormat], + classOf[LongWritable], classOf[Text]) + } else { + sc.textFile(dataPath.toString) + } + rdd.partitions + fs.delete(dataPath, true) + // Exception happens when initialize record reader in HadoopRDD/NewHadoopRDD.compute + // because partitions' info already cached. + rdd.collect() + } + + // collect HadoopRDD and NewHadoopRDD when spark.files.ignoreMissingFiles=false by default. + sc = new SparkContext("local", "test") + intercept[org.apache.hadoop.mapred.InvalidInputException] { + // Exception happens when HadoopRDD.getPartitions + sc.textFile(deletedPath.toString).collect() + } + + var e = intercept[SparkException] { + collectRDDAndDeleteFileBeforeCompute(false) + } + assert(e.getCause.isInstanceOf[java.io.FileNotFoundException]) + + intercept[org.apache.hadoop.mapreduce.lib.input.InvalidInputException] { + // Exception happens when NewHadoopRDD.getPartitions + sc.newAPIHadoopFile(deletedPath.toString, classOf[NewTextInputFormat], + classOf[LongWritable], classOf[Text]).collect + } + + e = intercept[SparkException] { + collectRDDAndDeleteFileBeforeCompute(true) + } + assert(e.getCause.isInstanceOf[java.io.FileNotFoundException]) + + sc.stop() + + // collect HadoopRDD and NewHadoopRDD when spark.files.ignoreMissingFiles=true. + val conf = new SparkConf().set(IGNORE_MISSING_FILES, true) + sc = new SparkContext("local", "test", conf) + assert(sc.textFile(deletedPath.toString).collect().isEmpty) + + assert(collectRDDAndDeleteFileBeforeCompute(false).isEmpty) + + assert(sc.newAPIHadoopFile(deletedPath.toString, classOf[NewTextInputFormat], + classOf[LongWritable], classOf[Text]).collect().isEmpty) + + assert(collectRDDAndDeleteFileBeforeCompute(true).isEmpty) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0dc47bfe075d0..3729bd5293eca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -437,7 +437,8 @@ object SQLConf { val HIVE_VERIFY_PARTITION_PATH = buildConf("spark.sql.hive.verifyPartitionPath") .doc("When true, check all the partition paths under the table\'s root directory " + - "when reading data stored in HDFS.") + "when reading data stored in HDFS. This configuration will be deprecated in the future " + + "releases and replaced by spark.files.ignoreMissingFiles.") .booleanConf .createWithDefault(false) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index b2dc401ce1efc..78156b17fb43b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -23,6 +23,7 @@ import java.sql.Timestamp import com.google.common.io.Files import org.apache.hadoop.fs.FileSystem +import org.apache.spark.internal.config._ import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf @@ -70,6 +71,45 @@ class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingl } } + test("Replace spark.sql.hive.verifyPartitionPath by spark.files.ignoreMissingFiles") { + withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "false")) { + sparkContext.conf.set(IGNORE_MISSING_FILES.key, "true") + val testData = sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.createOrReplaceTempView("testData") + + val tmpDir = Files.createTempDir() + // create the table for test + sql(s"CREATE TABLE table_with_partition(key int,value string) " + + s"PARTITIONED by (ds string) location '${tmpDir.toURI}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + + "SELECT key,value FROM testData") + + // test for the exist path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toDF.collect ++ testData.toDF.collect + ++ testData.toDF.collect ++ testData.toDF.collect) + + // delete the path of one partition + tmpDir.listFiles + .find { f => f.isDirectory && f.getName().startsWith("ds=") } + .foreach { f => Utils.deleteRecursively(f) } + + // test for after delete the path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) + + sql("DROP TABLE IF EXISTS table_with_partition") + sql("DROP TABLE IF EXISTS createAndInsertTest") + } + } + test("SPARK-21739: Cast expression should initialize timezoneId") { withTable("table_with_timestamp_partition") { sql("CREATE TABLE table_with_timestamp_partition(value int) PARTITIONED BY (ts TIMESTAMP)") From 3990daaf3b6ca2c5a9f7790030096262efb12cb2 Mon Sep 17 00:00:00 2001 From: jinxing Date: Tue, 17 Apr 2018 08:55:01 -0500 Subject: [PATCH 0634/2461] [SPARK-23948] Trigger mapstage's job listener in submitMissingTasks ## What changes were proposed in this pull request? SparkContext submitted a map stage from `submitMapStage` to `DAGScheduler`, `markMapStageJobAsFinished` is called only in (https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L933 and https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L1314); But think about below scenario: 1. stage0 and stage1 are all `ShuffleMapStage` and stage1 depends on stage0; 2. We submit stage1 by `submitMapStage`; 3. When stage 1 running, `FetchFailed` happened, stage0 and stage1 got resubmitted as stage0_1 and stage1_1; 4. When stage0_1 running, speculated tasks in old stage1 come as succeeded, but stage1 is not inside `runningStages`. So even though all splits(including the speculated tasks) in stage1 succeeded, job listener in stage1 will not be called; 5. stage0_1 finished, stage1_1 starts running. When `submitMissingTasks`, there is no missing tasks. But in current code, job listener is not triggered. We should call the job listener for map stage in `5`. ## How was this patch tested? Not added yet. Author: jinxing Closes #21019 from jinxing64/SPARK-23948. --- .../apache/spark/scheduler/DAGScheduler.scala | 33 ++++++------ .../spark/scheduler/DAGSchedulerSuite.scala | 52 +++++++++++++++++++ 2 files changed, 70 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8c46a84323392..78b6b34b5d2bb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1092,17 +1092,16 @@ class DAGScheduler( // the stage as completed here in case there are no tasks to run markStageAsFinished(stage, None) - val debugString = stage match { + stage match { case stage: ShuffleMapStage => - s"Stage ${stage} is actually done; " + - s"(available: ${stage.isAvailable}," + - s"available outputs: ${stage.numAvailableOutputs}," + - s"partitions: ${stage.numPartitions})" + logDebug(s"Stage ${stage} is actually done; " + + s"(available: ${stage.isAvailable}," + + s"available outputs: ${stage.numAvailableOutputs}," + + s"partitions: ${stage.numPartitions})") + markMapStageJobsAsFinished(stage) case stage : ResultStage => - s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})" + logDebug(s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})") } - logDebug(debugString) - submitWaitingChildStages(stage) } } @@ -1307,13 +1306,7 @@ class DAGScheduler( shuffleStage.findMissingPartitions().mkString(", ")) submitStage(shuffleStage) } else { - // Mark any map-stage jobs waiting on this stage as finished - if (shuffleStage.mapStageJobs.nonEmpty) { - val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep) - for (job <- shuffleStage.mapStageJobs) { - markMapStageJobAsFinished(job, stats) - } - } + markMapStageJobsAsFinished(shuffleStage) submitWaitingChildStages(shuffleStage) } } @@ -1433,6 +1426,16 @@ class DAGScheduler( } } + private[scheduler] def markMapStageJobsAsFinished(shuffleStage: ShuffleMapStage): Unit = { + // Mark any map-stage jobs waiting on this stage as finished + if (shuffleStage.isAvailable && shuffleStage.mapStageJobs.nonEmpty) { + val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep) + for (job <- shuffleStage.mapStageJobs) { + markMapStageJobAsFinished(job, stats) + } + } + } + /** * Responds to an executor being lost. This is called inside the event loop, so it assumes it can * modify the scheduler's internal state. Use executorLost() to post a loss event from outside. diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index d812b5bd92c1b..8b6ec37625eec 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -2146,6 +2146,58 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assertDataStructuresEmpty() } + test("Trigger mapstage's job listener in submitMissingTasks") { + val rdd1 = new MyRDD(sc, 2, Nil) + val dep1 = new ShuffleDependency(rdd1, new HashPartitioner(2)) + val rdd2 = new MyRDD(sc, 2, List(dep1), tracker = mapOutputTracker) + val dep2 = new ShuffleDependency(rdd2, new HashPartitioner(2)) + + val listener1 = new SimpleListener + val listener2 = new SimpleListener + + submitMapStage(dep1, listener1) + submitMapStage(dep2, listener2) + + // Complete the stage0. + assert(taskSets(0).stageId === 0) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", rdd1.partitions.length)), + (Success, makeMapStatus("hostB", rdd1.partitions.length)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(listener1.results.size === 1) + + // When attempting stage1, trigger a fetch failure. + assert(taskSets(1).stageId === 1) + complete(taskSets(1), Seq( + (Success, makeMapStatus("hostC", rdd2.partitions.length)), + (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null))) + scheduler.resubmitFailedStages() + // Stage1 listener should not have a result yet + assert(listener2.results.size === 0) + + // Speculative task succeeded in stage1. + runEvent(makeCompletionEvent( + taskSets(1).tasks(1), + Success, + makeMapStatus("hostD", rdd2.partitions.length))) + // stage1 listener still should not have a result, though there's no missing partitions + // in it. Because stage1 has been failed and is not inside `runningStages` at this moment. + assert(listener2.results.size === 0) + + // Stage0 should now be running as task set 2; make its task succeed + assert(taskSets(2).stageId === 0) + complete(taskSets(2), Seq( + (Success, makeMapStatus("hostC", rdd2.partitions.length)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === + Set(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + + // After stage0 is finished, stage1 will be submitted and found there is no missing + // partitions in it. Then listener got triggered. + assert(listener2.results.size === 1) + assertDataStructuresEmpty() + } + /** * In this test, we run a map stage where one of the executors fails but we still receive a * "zombie" complete message from that executor. We want to make sure the stage is not reported From f39e82ce150b6a7ea038e6858ba7adbaba3cad88 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 18 Apr 2018 00:35:44 +0800 Subject: [PATCH 0635/2461] [SPARK-23986][SQL] freshName can generate non-unique names ## What changes were proposed in this pull request? We are using `CodegenContext.freshName` to get a unique name for any new variable we are adding. Unfortunately, this method currently fails to create a unique name when we request more than one instance of variables with starting name `name1` and an instance with starting name `name11`. The PR changes the way a new name is generated by `CodegenContext.freshName` so that we generate unique names in this scenario too. ## How was this patch tested? added UT Author: Marco Gaido Closes #21080 from mgaido91/SPARK-23986. --- .../catalyst/expressions/codegen/CodeGenerator.scala | 11 +++-------- .../catalyst/expressions/CodeGenerationSuite.scala | 10 ++++++++++ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d97611c98ac91..f6b6775923ac6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -572,14 +572,9 @@ class CodegenContext { } else { s"${freshNamePrefix}_$name" } - if (freshNameIds.contains(fullName)) { - val id = freshNameIds(fullName) - freshNameIds(fullName) = id + 1 - s"$fullName$id" - } else { - freshNameIds += fullName -> 1 - fullName - } + val id = freshNameIds.getOrElse(fullName, 0) + freshNameIds(fullName) = id + 1 + s"${fullName}_$id" } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index f7c023111ff59..5b71becee2de0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -489,4 +489,14 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(!ctx.subExprEliminationExprs.contains(ref)) } } + + test("SPARK-23986: freshName can generate duplicated names") { + val ctx = new CodegenContext + val names1 = ctx.freshName("myName1") :: ctx.freshName("myName1") :: + ctx.freshName("myName11") :: Nil + assert(names1.distinct.length == 3) + val names2 = ctx.freshName("a") :: ctx.freshName("a") :: + ctx.freshName("a_1") :: ctx.freshName("a_0") :: Nil + assert(names2.distinct.length == 4) + } } From 1ca3c50fefb34532c78427fa74872db3ecbf7ba2 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 17 Apr 2018 10:11:08 -0700 Subject: [PATCH 0636/2461] [SPARK-21741][ML][PYSPARK] Python API for DataFrame-based multivariate summarizer ## What changes were proposed in this pull request? Python API for DataFrame-based multivariate summarizer. ## How was this patch tested? doctest added. Author: WeichenXu Closes #20695 from WeichenXu123/py_summarizer. --- python/pyspark/ml/stat.py | 193 +++++++++++++++++++++++++++++++++++++- 1 file changed, 192 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index 93d0f4fd9148f..a06ab31a7a56a 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -19,7 +19,9 @@ from pyspark import since, SparkContext from pyspark.ml.common import _java2py, _py2java -from pyspark.ml.wrapper import _jvm +from pyspark.ml.wrapper import JavaWrapper, _jvm +from pyspark.sql.column import Column, _to_seq +from pyspark.sql.functions import lit class ChiSquareTest(object): @@ -195,6 +197,195 @@ def test(dataset, sampleCol, distName, *params): _jvm().PythonUtils.toSeq(params))) +class Summarizer(object): + """ + .. note:: Experimental + + Tools for vectorized statistics on MLlib Vectors. + The methods in this package provide various statistics for Vectors contained inside DataFrames. + This class lets users pick the statistics they would like to extract for a given column. + + >>> from pyspark.ml.stat import Summarizer + >>> from pyspark.sql import Row + >>> from pyspark.ml.linalg import Vectors + >>> summarizer = Summarizer.metrics("mean", "count") + >>> df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)), + ... Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF() + >>> df.select(summarizer.summary(df.features, df.weight)).show(truncate=False) + +-----------------------------------+ + |aggregate_metrics(features, weight)| + +-----------------------------------+ + |[[1.0,1.0,1.0], 1] | + +-----------------------------------+ + + >>> df.select(summarizer.summary(df.features)).show(truncate=False) + +--------------------------------+ + |aggregate_metrics(features, 1.0)| + +--------------------------------+ + |[[1.0,1.5,2.0], 2] | + +--------------------------------+ + + >>> df.select(Summarizer.mean(df.features, df.weight)).show(truncate=False) + +--------------+ + |mean(features)| + +--------------+ + |[1.0,1.0,1.0] | + +--------------+ + + >>> df.select(Summarizer.mean(df.features)).show(truncate=False) + +--------------+ + |mean(features)| + +--------------+ + |[1.0,1.5,2.0] | + +--------------+ + + + .. versionadded:: 2.4.0 + + """ + @staticmethod + @since("2.4.0") + def mean(col, weightCol=None): + """ + return a column of mean summary + """ + return Summarizer._get_single_metric(col, weightCol, "mean") + + @staticmethod + @since("2.4.0") + def variance(col, weightCol=None): + """ + return a column of variance summary + """ + return Summarizer._get_single_metric(col, weightCol, "variance") + + @staticmethod + @since("2.4.0") + def count(col, weightCol=None): + """ + return a column of count summary + """ + return Summarizer._get_single_metric(col, weightCol, "count") + + @staticmethod + @since("2.4.0") + def numNonZeros(col, weightCol=None): + """ + return a column of numNonZero summary + """ + return Summarizer._get_single_metric(col, weightCol, "numNonZeros") + + @staticmethod + @since("2.4.0") + def max(col, weightCol=None): + """ + return a column of max summary + """ + return Summarizer._get_single_metric(col, weightCol, "max") + + @staticmethod + @since("2.4.0") + def min(col, weightCol=None): + """ + return a column of min summary + """ + return Summarizer._get_single_metric(col, weightCol, "min") + + @staticmethod + @since("2.4.0") + def normL1(col, weightCol=None): + """ + return a column of normL1 summary + """ + return Summarizer._get_single_metric(col, weightCol, "normL1") + + @staticmethod + @since("2.4.0") + def normL2(col, weightCol=None): + """ + return a column of normL2 summary + """ + return Summarizer._get_single_metric(col, weightCol, "normL2") + + @staticmethod + def _check_param(featuresCol, weightCol): + if weightCol is None: + weightCol = lit(1.0) + if not isinstance(featuresCol, Column) or not isinstance(weightCol, Column): + raise TypeError("featureCol and weightCol should be a Column") + return featuresCol, weightCol + + @staticmethod + def _get_single_metric(col, weightCol, metric): + col, weightCol = Summarizer._check_param(col, weightCol) + return Column(JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer." + metric, + col._jc, weightCol._jc)) + + @staticmethod + @since("2.4.0") + def metrics(*metrics): + """ + Given a list of metrics, provides a builder that it turns computes metrics from a column. + + See the documentation of [[Summarizer]] for an example. + + The following metrics are accepted (case sensitive): + - mean: a vector that contains the coefficient-wise mean. + - variance: a vector tha contains the coefficient-wise variance. + - count: the count of all vectors seen. + - numNonzeros: a vector with the number of non-zeros for each coefficients + - max: the maximum for each coefficient. + - min: the minimum for each coefficient. + - normL2: the Euclidian norm for each coefficient. + - normL1: the L1 norm of each coefficient (sum of the absolute values). + + :param metrics: + metrics that can be provided. + :return: + an object of :py:class:`pyspark.ml.stat.SummaryBuilder` + + Note: Currently, the performance of this interface is about 2x~3x slower then using the RDD + interface. + """ + sc = SparkContext._active_spark_context + js = JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer.metrics", + _to_seq(sc, metrics)) + return SummaryBuilder(js) + + +class SummaryBuilder(JavaWrapper): + """ + .. note:: Experimental + + A builder object that provides summary statistics about a given column. + + Users should not directly create such builders, but instead use one of the methods in + :py:class:`pyspark.ml.stat.Summarizer` + + .. versionadded:: 2.4.0 + + """ + def __init__(self, jSummaryBuilder): + super(SummaryBuilder, self).__init__(jSummaryBuilder) + + @since("2.4.0") + def summary(self, featuresCol, weightCol=None): + """ + Returns an aggregate object that contains the summary of the column with the requested + metrics. + + :param featuresCol: + a column that contains features Vector object. + :param weightCol: + a column that contains weight value. Default weight is 1.0. + :return: + an aggregate column that contains the statistics. The exact content of this + structure is determined during the creation of the builder. + """ + featuresCol, weightCol = Summarizer._check_param(featuresCol, weightCol) + return Column(self._java_obj.summary(featuresCol._jc, weightCol._jc)) + + if __name__ == "__main__": import doctest import pyspark.ml.stat From 5fccdae18911793967b315c02c058eb737e46174 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 17 Apr 2018 21:08:42 -0500 Subject: [PATCH 0637/2461] [SPARK-22968][DSTREAM] Throw an exception on partition revoking issue ## What changes were proposed in this pull request? Kafka partitions can be revoked when new consumers joined in the consumer group to rebalance the partitions. But current Spark Kafka connector code makes sure there's no partition revoking scenarios, so trying to get latest offset from revoked partitions will throw exceptions as JIRA mentioned. Partition revoking happens when new consumer joined the consumer group, which means different streaming apps are trying to use same group id. This is fundamentally not correct, different apps should use different consumer group. So instead of throwing an confused exception from Kafka, improve the exception message by identifying revoked partition and directly throw an meaningful exception when partition is revoked. Besides, this PR also fixes bugs in `DirectKafkaWordCount`, this example simply cannot be worked without the fix. ``` 8/01/05 09:48:27 INFO internals.ConsumerCoordinator: Revoking previously assigned partitions [kssh-7, kssh-4, kssh-3, kssh-6, kssh-5, kssh-0, kssh-2, kssh-1] for group use_a_separate_group_id_for_each_stream 18/01/05 09:48:27 INFO internals.AbstractCoordinator: (Re-)joining group use_a_separate_group_id_for_each_stream 18/01/05 09:48:27 INFO internals.AbstractCoordinator: Successfully joined group use_a_separate_group_id_for_each_stream with generation 4 18/01/05 09:48:27 INFO internals.ConsumerCoordinator: Setting newly assigned partitions [kssh-7, kssh-4, kssh-6, kssh-5] for group use_a_separate_group_id_for_each_stream ``` ## How was this patch tested? This is manually verified in local cluster, unfortunately I'm not sure how to simulate it in UT, so propose the PR without UT added. Author: jerryshao Closes #21038 from jerryshao/SPARK-22968. --- .../streaming/DirectKafkaWordCount.scala | 17 +++++++++++++---- .../kafka010/DirectKafkaInputDStream.scala | 12 ++++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala index def06026bde96..2082fb71afdf1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -18,6 +18,9 @@ // scalastyle:off println package org.apache.spark.examples.streaming +import org.apache.kafka.clients.consumer.ConsumerConfig +import org.apache.kafka.common.serialization.StringDeserializer + import org.apache.spark.SparkConf import org.apache.spark.streaming._ import org.apache.spark.streaming.kafka010._ @@ -26,18 +29,20 @@ import org.apache.spark.streaming.kafka010._ * Consumes messages from one or more topics in Kafka and does wordcount. * Usage: DirectKafkaWordCount * is a list of one or more Kafka brokers + * is a consumer group name to consume from topics * is a list of one or more kafka topics to consume from * * Example: * $ bin/run-example streaming.DirectKafkaWordCount broker1-host:port,broker2-host:port \ - * topic1,topic2 + * consumer-group topic1,topic2 */ object DirectKafkaWordCount { def main(args: Array[String]) { - if (args.length < 2) { + if (args.length < 3) { System.err.println(s""" |Usage: DirectKafkaWordCount | is a list of one or more Kafka brokers + | is a consumer group name to consume from topics | is a list of one or more kafka topics to consume from | """.stripMargin) @@ -46,7 +51,7 @@ object DirectKafkaWordCount { StreamingExamples.setStreamingLogLevels() - val Array(brokers, topics) = args + val Array(brokers, groupId, topics) = args // Create context with 2 second batch interval val sparkConf = new SparkConf().setAppName("DirectKafkaWordCount") @@ -54,7 +59,11 @@ object DirectKafkaWordCount { // Create direct kafka stream with brokers and topics val topicsSet = topics.split(",").toSet - val kafkaParams = Map[String, String]("metadata.broker.list" -> brokers) + val kafkaParams = Map[String, Object]( + ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG -> brokers, + ConsumerConfig.GROUP_ID_CONFIG -> groupId, + ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG -> classOf[StringDeserializer], + ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[StringDeserializer]) val messages = KafkaUtils.createDirectStream[String, String]( ssc, LocationStrategies.PreferConsistent, diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 215b7cab703fb..c3221481556f5 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -190,8 +190,20 @@ private[spark] class DirectKafkaInputDStream[K, V]( // make sure new partitions are reflected in currentOffsets val newPartitions = parts.diff(currentOffsets.keySet) + + // Check if there's any partition been revoked because of consumer rebalance. + val revokedPartitions = currentOffsets.keySet.diff(parts) + if (revokedPartitions.nonEmpty) { + throw new IllegalStateException(s"Previously tracked partitions " + + s"${revokedPartitions.mkString("[", ",", "]")} been revoked by Kafka because of consumer " + + s"rebalance. This is mostly due to another stream with same group id joined, " + + s"please check if there're different streaming application misconfigure to use same " + + s"group id. Fundamentally different stream should use different group id") + } + // position for new partitions determined by auto.offset.reset if no commit currentOffsets = currentOffsets ++ newPartitions.map(tp => tp -> c.position(tp)).toMap + // don't want to consume messages, so pause c.pause(newPartitions.asJava) // find latest available offsets From 1e3b8762a854a07c317f69fba7fa1a7bcdc58ff3 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Wed, 18 Apr 2018 10:36:41 +0800 Subject: [PATCH 0638/2461] [SPARK-21479][SQL] Outer join filter pushdown in null supplying table when condition is on one of the joined columns ## What changes were proposed in this pull request? Added `TransitPredicateInOuterJoin` optimization rule that transits constraints from the preserved side of an outer join to the null-supplying side. The constraints of the join operator will remain unchanged. ## How was this patch tested? Added 3 tests in `InferFiltersFromConstraintsSuite`. Author: maryannxue Closes #20816 from maryannxue/spark-21479. --- .../sql/catalyst/optimizer/Optimizer.scala | 42 +++++++++++++++++-- .../plans/logical/QueryPlanConstraints.scala | 25 +++++++++-- .../InferFiltersFromConstraintsSuite.scala | 36 ++++++++++++++++ 3 files changed, 96 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5fb59ef350b8b..913354e4df0e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -637,8 +637,11 @@ object CollapseWindow extends Rule[LogicalPlan] { * constraints. These filters are currently inserted to the existing conditions in the Filter * operators and on either side of Join operators. * - * Note: While this optimization is applicable to all types of join, it primarily benefits Inner and - * LeftSemi joins. + * In addition, for left/right outer joins, infer predicate from the preserved side of the Join + * operator and push the inferred filter over to the null-supplying side. For example, if the + * preserved side has constraints of the form 'a > 5' and the join condition is 'a = b', in + * which 'b' is an attribute from the null-supplying side, a [[Filter]] operator of 'b > 5' will + * be applied to the null-supplying side. */ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper { @@ -671,11 +674,42 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe val newConditionOpt = conditionOpt match { case Some(condition) => val newFilters = additionalConstraints -- splitConjunctivePredicates(condition) - if (newFilters.nonEmpty) Option(And(newFilters.reduce(And), condition)) else None + if (newFilters.nonEmpty) Option(And(newFilters.reduce(And), condition)) else conditionOpt case None => additionalConstraints.reduceOption(And) } - if (newConditionOpt.isDefined) Join(left, right, joinType, newConditionOpt) else join + // Infer filter for left/right outer joins + val newLeftOpt = joinType match { + case RightOuter if newConditionOpt.isDefined => + val inferredConstraints = left.getRelevantConstraints( + left.constraints + .union(right.constraints) + .union(splitConjunctivePredicates(newConditionOpt.get).toSet)) + val newFilters = inferredConstraints + .filterNot(left.constraints.contains) + .reduceLeftOption(And) + newFilters.map(Filter(_, left)) + case _ => None + } + val newRightOpt = joinType match { + case LeftOuter if newConditionOpt.isDefined => + val inferredConstraints = right.getRelevantConstraints( + right.constraints + .union(left.constraints) + .union(splitConjunctivePredicates(newConditionOpt.get).toSet)) + val newFilters = inferredConstraints + .filterNot(right.constraints.contains) + .reduceLeftOption(And) + newFilters.map(Filter(_, right)) + case _ => None + } + + if ((newConditionOpt.isDefined && (newConditionOpt ne conditionOpt)) + || newLeftOpt.isDefined || newRightOpt.isDefined) { + Join(newLeftOpt.getOrElse(left), newRightOpt.getOrElse(right), joinType, newConditionOpt) + } else { + join + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 046848875548b..a29f3d29236c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -41,9 +41,7 @@ trait QueryPlanConstraints { self: LogicalPlan => * example, if this set contains the expression `a = 2` then that expression is guaranteed to * evaluate to `true` for all rows produced. */ - lazy val constraints: ExpressionSet = ExpressionSet(allConstraints.filter { c => - c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic - }) + lazy val constraints: ExpressionSet = ExpressionSet(allConstraints.filter(selfReferenceOnly)) /** * This method can be overridden by any child class of QueryPlan to specify a set of constraints @@ -55,6 +53,23 @@ trait QueryPlanConstraints { self: LogicalPlan => */ protected def validConstraints: Set[Expression] = Set.empty + /** + * Returns an [[ExpressionSet]] that contains an additional set of constraints, such as + * equality constraints and `isNotNull` constraints, etc., and that only contains references + * to this [[LogicalPlan]] node. + */ + def getRelevantConstraints(constraints: Set[Expression]): ExpressionSet = { + val allRelevantConstraints = + if (conf.constraintPropagationEnabled) { + constraints + .union(inferAdditionalConstraints(constraints)) + .union(constructIsNotNullConstraints(constraints)) + } else { + constraints + } + ExpressionSet(allRelevantConstraints.filter(selfReferenceOnly)) + } + /** * Infers a set of `isNotNull` constraints from null intolerant expressions as well as * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this @@ -120,4 +135,8 @@ trait QueryPlanConstraints { self: LogicalPlan => destination: Attribute): Set[Expression] = constraints.map(_ transform { case e: Expression if e.semanticEquals(source) => destination }) + + private def selfReferenceOnly(e: Expression): Boolean = { + e.references.nonEmpty && e.references.subsetOf(outputSet) && e.deterministic + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index f78c2356e35a5..e068f51044589 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -204,4 +204,40 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) } + + test("SPARK-21479: Outer join after-join filters push down to null-supplying side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y, LeftOuter, condition).where("x.a".attr === 2).analyze + val left = x.where(IsNotNull('a) && 'a === 2) + val right = y.where(IsNotNull('a) && 'a === 2) + val correctAnswer = left.join(right, LeftOuter, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("SPARK-21479: Outer join pre-existing filters push down to null-supplying side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y.where("y.a".attr > 5), RightOuter, condition).analyze + val left = x.where(IsNotNull('a) && 'a > 5) + val right = y.where(IsNotNull('a) && 'a > 5) + val correctAnswer = left.join(right, RightOuter, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("SPARK-21479: Outer join no filter push down to preserved side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y.where("y.a".attr === 1), LeftOuter, condition).analyze + val left = x + val right = y.where(IsNotNull('a) && 'a === 1) + val correctAnswer = left.join(right, LeftOuter, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } } From 310a8cd06299e434d94a1e391a6eb62944112446 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 18 Apr 2018 11:51:10 +0800 Subject: [PATCH 0639/2461] [SPARK-23341][SQL] define some standard options for data source v2 ## What changes were proposed in this pull request? Each data source implementation can define its own options and teach its users how to set them. Spark doesn't have any restrictions about what options a data source should or should not have. It's possible that some options are very common and many data sources use them. However different data sources may define the common options(key and meaning) differently, which is quite confusing to end users. This PR defines some standard options that data sources can optionally adopt: path, table and database. ## How was this patch tested? a new test case. Author: Wenchen Fan Closes #20535 from cloud-fan/options. --- .../sql/sources/v2/DataSourceOptions.java | 100 ++++++++++++++++++ .../apache/spark/sql/DataFrameReader.scala | 14 ++- .../sources/v2/DataSourceOptionsSuite.scala | 25 +++++ 3 files changed, 135 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java index c32053580f016..83df3be747085 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java @@ -17,16 +17,61 @@ package org.apache.spark.sql.sources.v2; +import java.io.IOException; import java.util.HashMap; import java.util.Locale; import java.util.Map; import java.util.Optional; +import java.util.stream.Stream; + +import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.spark.annotation.InterfaceStability; /** * An immutable string-to-string map in which keys are case-insensitive. This is used to represent * data source options. + * + * Each data source implementation can define its own options and teach its users how to set them. + * Spark doesn't have any restrictions about what options a data source should or should not have. + * Instead Spark defines some standard options that data sources can optionally adopt. It's possible + * that some options are very common and many data sources use them. However different data + * sources may define the common options(key and meaning) differently, which is quite confusing to + * end users. + * + * The standard options defined by Spark: + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
    Option keyOption value
    pathA path string of the data files/directories, like + * path1, /absolute/file2, path3/*. The path can + * either be relative or absolute, points to either file or directory, and can contain + * wildcards. This option is commonly used by file-based data sources.
    pathsA JSON array style paths string of the data files/directories, like + * ["path1", "/absolute/file2"]. The format of each path is same as the + * path option, plus it should follow JSON string literal format, e.g. quotes + * should be escaped, pa\"th means pa"th. + *
    tableA table name string representing the table name directly without any interpretation. + * For example, db.tbl means a table called db.tbl, not a table called tbl + * inside database db. `t*b.l` means a table called `t*b.l`, not t*b.l.
    databaseA database name string representing the database name directly without any + * interpretation, which is very similar to the table name option.
    */ @InterfaceStability.Evolving public class DataSourceOptions { @@ -97,4 +142,59 @@ public double getDouble(String key, double defaultValue) { return keyLowerCasedMap.containsKey(lcaseKey) ? Double.parseDouble(keyLowerCasedMap.get(lcaseKey)) : defaultValue; } + + /** + * The option key for singular path. + */ + public static final String PATH_KEY = "path"; + + /** + * The option key for multiple paths. + */ + public static final String PATHS_KEY = "paths"; + + /** + * The option key for table name. + */ + public static final String TABLE_KEY = "table"; + + /** + * The option key for database name. + */ + public static final String DATABASE_KEY = "database"; + + /** + * Returns all the paths specified by both the singular path option and the multiple + * paths option. + */ + public String[] paths() { + String[] singularPath = + get(PATH_KEY).map(s -> new String[]{s}).orElseGet(() -> new String[0]); + Optional pathsStr = get(PATHS_KEY); + if (pathsStr.isPresent()) { + ObjectMapper objectMapper = new ObjectMapper(); + try { + String[] paths = objectMapper.readValue(pathsStr.get(), String[].class); + return Stream.of(singularPath, paths).flatMap(Stream::of).toArray(String[]::new); + } catch (IOException e) { + return singularPath; + } + } else { + return singularPath; + } + } + + /** + * Returns the value of the table name option. + */ + public Optional tableName() { + return get(TABLE_KEY); + } + + /** + * Returns the value of the database name option. + */ + public Optional databaseName() { + return get(DATABASE_KEY); + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index ae3ba1690f696..d640fdc530ce2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -21,6 +21,8 @@ import java.util.{Locale, Properties} import scala.collection.JavaConverters._ +import com.fasterxml.jackson.databind.ObjectMapper + import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability import org.apache.spark.api.java.JavaRDD @@ -34,7 +36,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.sources.v2.{DataSourceV2, ReadSupport, ReadSupportWithSchema} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -171,7 +173,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @since 1.4.0 */ def load(path: String): DataFrame = { - option("path", path).load(Seq.empty: _*) // force invocation of `load(...varargs...)` + // force invocation of `load(...varargs...)` + option(DataSourceOptions.PATH_KEY, path).load(Seq.empty: _*) } /** @@ -193,10 +196,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { if (ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema]) { val sessionOptions = DataSourceV2Utils.extractSessionConfigs( ds = ds, conf = sparkSession.sessionState.conf) + val pathsOption = { + val objectMapper = new ObjectMapper() + DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray) + } Dataset.ofRows(sparkSession, DataSourceV2Relation.create( - ds, extraOptions.toMap ++ sessionOptions, + ds, extraOptions.toMap ++ sessionOptions + pathsOption, userSpecifiedSchema = userSpecifiedSchema)) - } else { loadV1Source(paths: _*) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala index 31dfc55b23361..cfa69a86de1a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala @@ -79,4 +79,29 @@ class DataSourceOptionsSuite extends SparkFunSuite { options.getDouble("foo", 0.1d) } } + + test("standard options") { + val options = new DataSourceOptions(Map( + DataSourceOptions.PATH_KEY -> "abc", + DataSourceOptions.TABLE_KEY -> "tbl").asJava) + + assert(options.paths().toSeq == Seq("abc")) + assert(options.tableName().get() == "tbl") + assert(!options.databaseName().isPresent) + } + + test("standard options with both singular path and multi-paths") { + val options = new DataSourceOptions(Map( + DataSourceOptions.PATH_KEY -> "abc", + DataSourceOptions.PATHS_KEY -> """["c", "d"]""").asJava) + + assert(options.paths().toSeq == Seq("abc", "c", "d")) + } + + test("standard options with only multi-paths") { + val options = new DataSourceOptions(Map( + DataSourceOptions.PATHS_KEY -> """["c", "d\"e"]""").asJava) + + assert(options.paths().toSeq == Seq("c", "d\"e")) + } } From cce469435d61bda5893d9aa6cfdf7ea46fa717df Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 17 Apr 2018 21:03:57 -0700 Subject: [PATCH 0640/2461] [SPARK-24002][SQL] Task not serializable caused by org.apache.parquet.io.api.Binary$ByteBufferBackedBinary.getBytes ## What changes were proposed in this pull request? ``` Py4JJavaError: An error occurred while calling o153.sql. : org.apache.spark.SparkException: Job aborted. at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:223) at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:189) at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult$lzycompute(commands.scala:70) at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult(commands.scala:68) at org.apache.spark.sql.execution.command.ExecutedCommandExec.executeCollect(commands.scala:79) at org.apache.spark.sql.Dataset$$anonfun$6.apply(Dataset.scala:190) at org.apache.spark.sql.Dataset$$anonfun$6.apply(Dataset.scala:190) at org.apache.spark.sql.Dataset$$anonfun$59.apply(Dataset.scala:3021) at org.apache.spark.sql.execution.SQLExecution$.withCustomExecutionEnv(SQLExecution.scala:89) at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:127) at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3020) at org.apache.spark.sql.Dataset.(Dataset.scala:190) at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:74) at org.apache.spark.sql.SparkSession.sql(SparkSession.scala:646) at sun.reflect.GeneratedMethodAccessor153.invoke(Unknown Source) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244) at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380) at py4j.Gateway.invoke(Gateway.java:293) at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132) at py4j.commands.CallCommand.execute(CallCommand.java:79) at py4j.GatewayConnection.run(GatewayConnection.java:226) at java.lang.Thread.run(Thread.java:748) Caused by: org.apache.spark.SparkException: Exception thrown in Future.get: at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.doExecuteBroadcast(BroadcastExchangeExec.scala:190) at org.apache.spark.sql.execution.InputAdapter.doExecuteBroadcast(WholeStageCodegenExec.scala:267) at org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec.doConsume(BroadcastNestedLoopJoinExec.scala:530) at org.apache.spark.sql.execution.CodegenSupport$class.consume(WholeStageCodegenExec.scala:155) at org.apache.spark.sql.execution.ProjectExec.consume(basicPhysicalOperators.scala:37) at org.apache.spark.sql.execution.ProjectExec.doConsume(basicPhysicalOperators.scala:69) at org.apache.spark.sql.execution.CodegenSupport$class.consume(WholeStageCodegenExec.scala:155) at org.apache.spark.sql.execution.FilterExec.consume(basicPhysicalOperators.scala:144) ... at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:190) ... 23 more Caused by: java.util.concurrent.ExecutionException: org.apache.spark.SparkException: Task not serializable at java.util.concurrent.FutureTask.report(FutureTask.java:122) at java.util.concurrent.FutureTask.get(FutureTask.java:206) at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.doExecuteBroadcast(BroadcastExchangeExec.scala:179) ... 276 more Caused by: org.apache.spark.SparkException: Task not serializable at org.apache.spark.util.ClosureCleaner$.ensureSerializable(ClosureCleaner.scala:340) at org.apache.spark.util.ClosureCleaner$.org$apache$spark$util$ClosureCleaner$$clean(ClosureCleaner.scala:330) at org.apache.spark.util.ClosureCleaner$.clean(ClosureCleaner.scala:156) at org.apache.spark.SparkContext.clean(SparkContext.scala:2380) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndex$1.apply(RDD.scala:850) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndex$1.apply(RDD.scala:849) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112) at org.apache.spark.rdd.RDD.withScope(RDD.scala:371) at org.apache.spark.rdd.RDD.mapPartitionsWithIndex(RDD.scala:849) at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:417) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:123) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:118) at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$3.apply(SparkPlan.scala:152) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:149) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:118) at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.prepareShuffleDependency(ShuffleExchangeExec.scala:89) at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$$anonfun$doExecute$1.apply(ShuffleExchangeExec.scala:125) at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$$anonfun$doExecute$1.apply(ShuffleExchangeExec.scala:116) at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52) at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.doExecute(ShuffleExchangeExec.scala:116) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:123) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:118) at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$3.apply(SparkPlan.scala:152) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:149) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:118) at org.apache.spark.sql.execution.InputAdapter.inputRDDs(WholeStageCodegenExec.scala:271) at org.apache.spark.sql.execution.aggregate.HashAggregateExec.inputRDDs(HashAggregateExec.scala:181) at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:414) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:123) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:118) at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$3.apply(SparkPlan.scala:152) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:149) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:118) at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:61) at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:70) at org.apache.spark.sql.execution.SparkPlan.executeCollectResult(SparkPlan.scala:264) at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec$$anon$1$$anonfun$call$1.apply(BroadcastExchangeExec.scala:93) at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec$$anon$1$$anonfun$call$1.apply(BroadcastExchangeExec.scala:81) at org.apache.spark.sql.execution.SQLExecution$.withExecutionId(SQLExecution.scala:150) at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec$$anon$1.call(BroadcastExchangeExec.scala:80) at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec$$anon$1.call(BroadcastExchangeExec.scala:76) at java.util.concurrent.FutureTask.run(FutureTask.java:266) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) ... 1 more Caused by: java.nio.BufferUnderflowException at java.nio.HeapByteBuffer.get(HeapByteBuffer.java:151) at java.nio.ByteBuffer.get(ByteBuffer.java:715) at org.apache.parquet.io.api.Binary$ByteBufferBackedBinary.getBytes(Binary.java:405) at org.apache.parquet.io.api.Binary$ByteBufferBackedBinary.getBytesUnsafe(Binary.java:414) at org.apache.parquet.io.api.Binary$ByteBufferBackedBinary.writeObject(Binary.java:484) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at java.io.ObjectStreamClass.invokeWriteObject(ObjectStreamClass.java:1128) at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1496) ``` The Parquet filters are serializable but not thread safe. SparkPlan.prepare() could be called in different threads (BroadcastExchange will call it in a thread pool). Thus, we could serialize the same Parquet filter at the same time. This is not easily reproduced. The fix is to avoid serializing these Parquet filters in the driver. This PR is to avoid serializing these Parquet filters by moving the parquet filter generation from the driver to executors. ## How was this patch tested? Having two queries one is a 1000-line SQL query and a 3000-line SQL query. Need to run at least one hour with a heavy write workload to reproduce once. Author: gatorsmile Closes #21086 from gatorsmile/taskNotSerializable. --- .../parquet/ParquetFileFormat.scala | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 476bd02374364..d8f47eec952de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -321,19 +321,6 @@ class ParquetFileFormat SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, sparkSession.sessionState.conf.isParquetINT96AsTimestamp) - // Try to push down filters when filter push-down is enabled. - val pushed = - if (sparkSession.sessionState.conf.parquetFilterPushDown) { - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can be - // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` - // is used here. - .flatMap(ParquetFilters.createFilter(requiredSchema, _)) - .reduceOption(FilterApi.and) - } else { - None - } - val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) @@ -351,12 +338,26 @@ class ParquetFileFormat val timestampConversion: Boolean = sparkSession.sessionState.conf.isParquetINT96TimestampConversion val capacity = sqlConf.parquetVectorizedReaderBatchSize + val enableParquetFilterPushDown: Boolean = + sparkSession.sessionState.conf.parquetFilterPushDown // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) + // Try to push down filters when filter push-down is enabled. + val pushed = if (enableParquetFilterPushDown) { + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(ParquetFilters.createFilter(requiredSchema, _)) + .reduceOption(FilterApi.and) + } else { + None + } + val fileSplit = new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty) From f81fa478ff990146e2a8e463ac252271448d96f5 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Wed, 18 Apr 2018 18:41:55 +0900 Subject: [PATCH 0641/2461] [SPARK-23926][SQL] Extending reverse function to support ArrayType arguments ## What changes were proposed in this pull request? This PR extends `reverse` functions to be able to operate over array columns and covers: - Introduction of `Reverse` expression that represents logic for reversing arrays and also strings - Removal of `StringReverse` expression - A wrapper for PySpark ## How was this patch tested? New tests added into: - CollectionExpressionsSuite - DataFrameFunctionsSuite ## Codegen examples ### Primitive type ``` val df = Seq( Seq(1, 3, 4, 2), null ).toDF("i") df.filter($"i".isNotNull || $"i".isNull).select(reverse($"i")).debugCodegen ``` Result: ``` /* 032 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 033 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 034 */ null : (inputadapter_row.getArray(0)); /* 035 */ /* 036 */ boolean filter_value = true; /* 037 */ /* 038 */ if (!(!inputadapter_isNull)) { /* 039 */ filter_value = inputadapter_isNull; /* 040 */ } /* 041 */ if (!filter_value) continue; /* 042 */ /* 043 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 044 */ /* 045 */ boolean project_isNull = inputadapter_isNull; /* 046 */ ArrayData project_value = null; /* 047 */ /* 048 */ if (!inputadapter_isNull) { /* 049 */ final int project_length = inputadapter_value.numElements(); /* 050 */ project_value = inputadapter_value.copy(); /* 051 */ for(int k = 0; k < project_length / 2; k++) { /* 052 */ int l = project_length - k - 1; /* 053 */ boolean isNullAtK = project_value.isNullAt(k); /* 054 */ boolean isNullAtL = project_value.isNullAt(l); /* 055 */ if(!isNullAtK) { /* 056 */ int el = project_value.getInt(k); /* 057 */ if(!isNullAtL) { /* 058 */ project_value.setInt(k, project_value.getInt(l)); /* 059 */ } else { /* 060 */ project_value.setNullAt(k); /* 061 */ } /* 062 */ project_value.setInt(l, el); /* 063 */ } else if (!isNullAtL) { /* 064 */ project_value.setInt(k, project_value.getInt(l)); /* 065 */ project_value.setNullAt(l); /* 066 */ } /* 067 */ } /* 068 */ /* 069 */ } ``` ### Non-primitive type ``` val df = Seq( Seq("a", "c", "d", "b"), null ).toDF("s") df.filter($"s".isNotNull || $"s".isNull).select(reverse($"s")).debugCodegen ``` Result: ``` /* 032 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 033 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 034 */ null : (inputadapter_row.getArray(0)); /* 035 */ /* 036 */ boolean filter_value = true; /* 037 */ /* 038 */ if (!(!inputadapter_isNull)) { /* 039 */ filter_value = inputadapter_isNull; /* 040 */ } /* 041 */ if (!filter_value) continue; /* 042 */ /* 043 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 044 */ /* 045 */ boolean project_isNull = inputadapter_isNull; /* 046 */ ArrayData project_value = null; /* 047 */ /* 048 */ if (!inputadapter_isNull) { /* 049 */ final int project_length = inputadapter_value.numElements(); /* 050 */ project_value = new org.apache.spark.sql.catalyst.util.GenericArrayData(new Object[project_length]); /* 051 */ for(int k = 0; k < project_length; k++) { /* 052 */ int l = project_length - k - 1; /* 053 */ project_value.update(k, inputadapter_value.getUTF8String(l)); /* 054 */ } /* 055 */ /* 056 */ } ``` Author: mn-mikke Closes #21034 from mn-mikke/feature/array-api-reverse-to-master. --- python/pyspark/sql/functions.py | 20 +++- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../expressions/collectionOperations.scala | 88 +++++++++++++++++ .../expressions/stringExpressions.scala | 20 ---- .../CollectionExpressionsSuite.scala | 44 +++++++++ .../expressions/StringExpressionsSuite.scala | 6 +- .../org/apache/spark/sql/functions.scala | 15 ++- .../spark/sql/DataFrameFunctionsSuite.scala | 94 +++++++++++++++++++ 8 files changed, 256 insertions(+), 33 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 6ca22b610843d..d3bb0a5d6b36a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1414,7 +1414,6 @@ def hash(*cols): 'uppercase. Words are delimited by whitespace.', 'lower': 'Converts a string column to lower case.', 'upper': 'Converts a string column to upper case.', - 'reverse': 'Reverses the string column and returns it as a new string column.', 'ltrim': 'Trim the spaces from left end for the specified string value.', 'rtrim': 'Trim the spaces from right end for the specified string value.', 'trim': 'Trim the spaces from both ends for the specified string column.', @@ -2128,6 +2127,25 @@ def sort_array(col, asc=True): return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) +@since(1.5) +@ignore_unicode_prefix +def reverse(col): + """ + Collection function: returns a reversed string or an array with reverse order of elements. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([('Spark SQL',)], ['data']) + >>> df.select(reverse(df.data).alias('s')).collect() + [Row(s=u'LQS krapS')] + >>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data']) + >>> df.select(reverse(df.data).alias('r')).collect() + [Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.reverse(_to_java_column(col))) + + @since(2.3) def map_keys(col): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 4dd1ca509bf2c..38c874ad948e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -336,7 +336,6 @@ object FunctionRegistry { expression[RegExpReplace]("regexp_replace"), expression[StringRepeat]("repeat"), expression[StringReplace]("replace"), - expression[StringReverse]("reverse"), expression[RLike]("rlike"), expression[StringRPad]("rpad"), expression[StringTrimRight]("rtrim"), @@ -411,6 +410,7 @@ object FunctionRegistry { expression[SortArray]("sort_array"), expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), + expression[Reverse]("reverse"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 7c87777eed47a..76b71f5b86074 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Given an array or map, returns its size. Returns -1 if null. @@ -212,6 +213,93 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def prettyName: String = "sort_array" } +/** + * Returns a reversed string or an array with reverse order of elements. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns a reversed string or an array with reverse order of elements.", + examples = """ + Examples: + > SELECT _FUNC_('Spark SQL'); + LQS krapS + > SELECT _FUNC_(array(2, 1, 4, 3)); + [3, 4, 1, 2] + """, + since = "1.5.0", + note = "Reverse logic for arrays is available since 2.4.0." +) +case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + // Input types are utilized by type coercion in ImplicitTypeCasts. + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType)) + + override def dataType: DataType = child.dataType + + lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + override def nullSafeEval(input: Any): Any = input match { + case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse) + case s: UTF8String => s.reverse() + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => dataType match { + case _: StringType => stringCodeGen(ev, c) + case _: ArrayType => arrayCodeGen(ctx, ev, c) + }) + } + + private def stringCodeGen(ev: ExprCode, childName: String): String = { + s"${ev.value} = ($childName).reverse();" + } + + private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = { + val length = ctx.freshName("length") + val javaElementType = CodeGenerator.javaType(elementType) + val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) + + val initialization = if (isPrimitiveType) { + s"$childName.copy()" + } else { + s"new ${classOf[GenericArrayData].getName()}(new Object[$length])" + } + + val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length + + val swapAssigments = if (isPrimitiveType) { + val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType) + val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index) + s"""|boolean isNullAtK = ${ev.value}.isNullAt(k); + |boolean isNullAtL = ${ev.value}.isNullAt(l); + |if(!isNullAtK) { + | $javaElementType el = ${getCall("k")}; + | if(!isNullAtL) { + | ${ev.value}.$setFunc(k, ${getCall("l")}); + | } else { + | ${ev.value}.setNullAt(k); + | } + | ${ev.value}.$setFunc(l, el); + |} else if (!isNullAtL) { + | ${ev.value}.$setFunc(k, ${getCall("l")}); + | ${ev.value}.setNullAt(l); + |}""".stripMargin + } else { + s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});" + } + + s""" + |final int $length = $childName.numElements(); + |${ev.value} = $initialization; + |for(int k = 0; k < $numberOfIterations; k++) { + | int l = $length - k - 1; + | $swapAssigments + |} + """.stripMargin + } + + override def prettyName: String = "reverse" +} + /** * Checks if the array (left) has the element (right) */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 22fbb8998ed89..5a02ca0d6862c 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1504,26 +1504,6 @@ case class StringRepeat(str: Expression, times: Expression) } } -/** - * Returns the reversed given string. - */ -@ExpressionDescription( - usage = "_FUNC_(str) - Returns the reversed given string.", - examples = """ - Examples: - > SELECT _FUNC_('Spark SQL'); - LQS krapS - """) -case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression { - override def convert(v: UTF8String): UTF8String = v.reverse() - - override def prettyName: String = "reverse" - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).reverse()") - } -} - /** * Returns a string consisting of n spaces. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 5a31e3a30edd6..517639dbc7232 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -125,4 +125,48 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation( ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123) } + + test("Reverse") { + // Primitive-type elements + val ai0 = Literal.create(Seq(2, 1, 4, 3), ArrayType(IntegerType)) + val ai1 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) + val ai2 = Literal.create(Seq(null, 1, null, 3), ArrayType(IntegerType)) + val ai3 = Literal.create(Seq(2, null, 4, null), ArrayType(IntegerType)) + val ai4 = Literal.create(Seq(null, null, null), ArrayType(IntegerType)) + val ai5 = Literal.create(Seq(1), ArrayType(IntegerType)) + val ai6 = Literal.create(Seq.empty, ArrayType(IntegerType)) + val ai7 = Literal.create(null, ArrayType(IntegerType)) + + checkEvaluation(Reverse(ai0), Seq(3, 4, 1, 2)) + checkEvaluation(Reverse(ai1), Seq(3, 1, 2)) + checkEvaluation(Reverse(ai2), Seq(3, null, 1, null)) + checkEvaluation(Reverse(ai3), Seq(null, 4, null, 2)) + checkEvaluation(Reverse(ai4), Seq(null, null, null)) + checkEvaluation(Reverse(ai5), Seq(1)) + checkEvaluation(Reverse(ai6), Seq.empty) + checkEvaluation(Reverse(ai7), null) + + // Non-primitive-type elements + val as0 = Literal.create(Seq("b", "a", "d", "c"), ArrayType(StringType)) + val as1 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType)) + val as2 = Literal.create(Seq(null, "a", null, "c"), ArrayType(StringType)) + val as3 = Literal.create(Seq("b", null, "d", null), ArrayType(StringType)) + val as4 = Literal.create(Seq(null, null, null), ArrayType(StringType)) + val as5 = Literal.create(Seq("a"), ArrayType(StringType)) + val as6 = Literal.create(Seq.empty, ArrayType(StringType)) + val as7 = Literal.create(null, ArrayType(StringType)) + val aa = Literal.create( + Seq(Seq("a", "b"), Seq("c", "d"), Seq("e")), + ArrayType(ArrayType(StringType))) + + checkEvaluation(Reverse(as0), Seq("c", "d", "a", "b")) + checkEvaluation(Reverse(as1), Seq("c", "a", "b")) + checkEvaluation(Reverse(as2), Seq("c", null, "a", null)) + checkEvaluation(Reverse(as3), Seq(null, "d", null, "b")) + checkEvaluation(Reverse(as4), Seq(null, null, null)) + checkEvaluation(Reverse(as5), Seq("a")) + checkEvaluation(Reverse(as6), Seq.empty) + checkEvaluation(Reverse(as7), null) + checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b"))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 9a1a4da074ce3..f1a6f9b8889fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -629,9 +629,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("REVERSE") { val s = 'a.string.at(0) val row1 = create_row("abccc") - checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1) - checkEvaluation(StringReverse(s), "cccba", row1) - checkEvaluation(StringReverse(Literal.create(null, StringType)), null, row1) + checkEvaluation(Reverse(Literal("abccc")), "cccba", row1) + checkEvaluation(Reverse(s), "cccba", row1) + checkEvaluation(Reverse(Literal.create(null, StringType)), null, row1) } test("SPACE") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 642ac056bb809..a55a800f48245 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2464,14 +2464,6 @@ object functions { StringRepeat(str.expr, lit(n).expr) } - /** - * Reverses the string column and returns it as a new string column. - * - * @group string_funcs - * @since 1.5.0 - */ - def reverse(str: Column): Column = withExpr { StringReverse(str.expr) } - /** * Trim the spaces from right end for the specified string value. * @@ -3316,6 +3308,13 @@ object functions { */ def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) } + /** + * Returns a reversed string or an array with reverse order of elements. + * @group collection_funcs + * @since 1.5.0 + */ + def reverse(e: Column): Column = withExpr { Reverse(e.expr) } + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 636e86baedf6f..74c42f2599dca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -441,6 +441,100 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df.selectExpr("array_max(a)"), answer) } + test("reverse function") { + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on + + // String test cases + val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i") + + checkAnswer( + oneRowDF.select(reverse('s)), + Seq(Row("krapS")) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(s)"), + Seq(Row("krapS")) + ) + checkAnswer( + oneRowDF.select(reverse('i)), + Seq(Row("5123")) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(i)"), + Seq(Row("5123")) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(null)"), + Seq(Row(null)) + ) + + // Array test cases (primitive-type elements) + val idf = Seq( + Seq(1, 9, 8, 7), + Seq(5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + checkAnswer( + idf.select(reverse('i)), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.filter(dummyFilter('i)).select(reverse('i)), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.selectExpr("reverse(i)"), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(array(1, null, 2, null))"), + Seq(Row(Seq(null, 2, null, 1))) + ) + checkAnswer( + oneRowDF.filter(dummyFilter('i)).selectExpr("reverse(array(1, null, 2, null))"), + Seq(Row(Seq(null, 2, null, 1))) + ) + + // Array test cases (non-primitive-type elements) + val sdf = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("s") + + checkAnswer( + sdf.select(reverse('s)), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + sdf.filter(dummyFilter('s)).select(reverse('s)), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + sdf.selectExpr("reverse(s)"), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(array(array(1, 2), array(3, 4)))"), + Seq(Row(Seq(Seq(3, 4), Seq(1, 2)))) + ) + checkAnswer( + oneRowDF.filter(dummyFilter('s)).selectExpr("reverse(array(array(1, 2), array(3, 4)))"), + Seq(Row(Seq(Seq(3, 4), Seq(1, 2)))) + ) + + // Error test cases + intercept[AnalysisException] { + oneRowDF.selectExpr("reverse(struct(1, 'a'))") + } + intercept[AnalysisException] { + oneRowDF.selectExpr("reverse(map(1, 'a'))") + } + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From f09a9e9418c1697d198de18f340b1288f5eb025c Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 18 Apr 2018 08:22:05 -0700 Subject: [PATCH 0642/2461] [SPARK-24007][SQL] EqualNullSafe for FloatType and DoubleType might generate a wrong result by codegen. ## What changes were proposed in this pull request? `EqualNullSafe` for `FloatType` and `DoubleType` might generate a wrong result by codegen. ```scala scala> val df = Seq((Some(-1.0d), None), (None, Some(-1.0d))).toDF() df: org.apache.spark.sql.DataFrame = [_1: double, _2: double] scala> df.show() +----+----+ | _1| _2| +----+----+ |-1.0|null| |null|-1.0| +----+----+ scala> df.filter("_1 <=> _2").show() +----+----+ | _1| _2| +----+----+ |-1.0|null| |null|-1.0| +----+----+ ``` The result should be empty but the result remains two rows. ## How was this patch tested? Added a test. Author: Takuya UESHIN Closes #21094 from ueshin/issues/SPARK-24007/equalnullsafe. --- .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 6 ++++-- .../spark/sql/catalyst/expressions/PredicateSuite.scala | 7 +++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f6b6775923ac6..cf0a91ff00626 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -582,8 +582,10 @@ class CodegenContext { */ def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match { case BinaryType => s"java.util.Arrays.equals($c1, $c2)" - case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2" - case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2" + case FloatType => + s"((java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2)" + case DoubleType => + s"((java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2)" case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2" case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)" case array: ArrayType => genComp(array, c1, c2) + " == 0" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 8a8f8e10225fa..1bfd180ae4393 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -442,4 +442,11 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx) assert(ctx.inlinedMutableStates.isEmpty) } + + test("SPARK-24007: EqualNullSafe for FloatType and DoubleType might generate a wrong result") { + checkEvaluation(EqualNullSafe(Literal(null, FloatType), Literal(-1.0f)), false) + checkEvaluation(EqualNullSafe(Literal(-1.0f), Literal(null, FloatType)), false) + checkEvaluation(EqualNullSafe(Literal(null, DoubleType), Literal(-1.0d)), false) + checkEvaluation(EqualNullSafe(Literal(-1.0d), Literal(null, DoubleType)), false) + } } From a9066478f6d98c3ae634c3bb9b09ee20bd60e111 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 19 Apr 2018 00:05:47 +0200 Subject: [PATCH 0643/2461] [SPARK-23875][SQL][FOLLOWUP] Add IndexedSeq wrapper for ArrayData ## What changes were proposed in this pull request? Use specified accessor in `ArrayData.foreach` and `toArray`. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #21099 from viirya/SPARK-23875-followup. --- .../org/apache/spark/sql/catalyst/util/ArrayData.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala index 2cf59d567c08c..104b428614849 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala @@ -141,28 +141,29 @@ abstract class ArrayData extends SpecializedGetters with Serializable { def toArray[T: ClassTag](elementType: DataType): Array[T] = { val size = numElements() + val accessor = InternalRow.getAccessor(elementType) val values = new Array[T](size) var i = 0 while (i < size) { if (isNullAt(i)) { values(i) = null.asInstanceOf[T] } else { - values(i) = get(i, elementType).asInstanceOf[T] + values(i) = accessor(this, i).asInstanceOf[T] } i += 1 } values } - // todo: specialize this. def foreach(elementType: DataType, f: (Int, Any) => Unit): Unit = { val size = numElements() + val accessor = InternalRow.getAccessor(elementType) var i = 0 while (i < size) { if (isNullAt(i)) { f(i, null) } else { - f(i, get(i, elementType)) + f(i, accessor(this, i)) } i += 1 } From 0c94e48bc50717e1627c0d2acd5382d9adc73c97 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Wed, 18 Apr 2018 16:37:41 -0700 Subject: [PATCH 0644/2461] [SPARK-23775][TEST] Make DataFrameRangeSuite not flaky ## What changes were proposed in this pull request? DataFrameRangeSuite.test("Cancelling stage in a query with Range.") stays sometimes in an infinite loop and times out the build. There were multiple issues with the test: 1. The first valid stageId is zero when the test started alone and not in a suite and the following code waits until timeout: ``` eventually(timeout(10.seconds), interval(1.millis)) { assert(DataFrameRangeSuite.stageToKill > 0) } ``` 2. The `DataFrameRangeSuite.stageToKill` was overwritten by the task's thread after the reset which ended up in canceling the same stage 2 times. This caused the infinite wait. This PR solves this mentioned flakyness by removing the shared `DataFrameRangeSuite.stageToKill` and using `wait` and `CountDownLatch` for synhronization. ## How was this patch tested? Existing unit test. Author: Gabor Somogyi Closes #20888 from gaborgsomogyi/SPARK-23775. --- .../spark/sql/DataFrameRangeSuite.scala | 78 +++++++++++-------- 1 file changed, 45 insertions(+), 33 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index 57a930dfaf320..a0fd74088ce8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -17,14 +17,16 @@ package org.apache.spark.sql +import java.util.concurrent.{CountDownLatch, TimeUnit} + import scala.concurrent.duration._ import scala.math.abs import scala.util.Random import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkException, TaskContext} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -152,39 +154,53 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } test("Cancelling stage in a query with Range.") { - val listener = new SparkListener { - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - eventually(timeout(10.seconds), interval(1.millis)) { - assert(DataFrameRangeSuite.stageToKill > 0) + // Save and restore the value because SparkContext is shared + val savedInterruptOnCancel = sparkContext + .getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) + + try { + sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true") + + for (codegen <- Seq(true, false)) { + // This countdown latch used to make sure with all the stages cancelStage called in listener + val latch = new CountDownLatch(2) + + val listener = new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + sparkContext.cancelStage(taskStart.stageId) + latch.countDown() + } } - sparkContext.cancelStage(DataFrameRangeSuite.stageToKill) - } - } - sparkContext.addSparkListener(listener) - for (codegen <- Seq(true, false)) { - withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { - DataFrameRangeSuite.stageToKill = -1 - val ex = intercept[SparkException] { - spark.range(0, 100000000000L, 1, 1).map { x => - DataFrameRangeSuite.stageToKill = TaskContext.get().stageId() - x - }.toDF("id").agg(sum("id")).collect() + sparkContext.addSparkListener(listener) + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { + val ex = intercept[SparkException] { + sparkContext.range(0, 10000L, numSlices = 10).mapPartitions { x => + x.synchronized { + x.wait() + } + x + }.toDF("id").agg(sum("id")).collect() + } + ex.getCause() match { + case null => + assert(ex.getMessage().contains("cancelled")) + case cause: SparkException => + assert(cause.getMessage().contains("cancelled")) + case cause: Throwable => + fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") + } } - ex.getCause() match { - case null => - assert(ex.getMessage().contains("cancelled")) - case cause: SparkException => - assert(cause.getMessage().contains("cancelled")) - case cause: Throwable => - fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") + latch.await(20, TimeUnit.SECONDS) + eventually(timeout(20.seconds)) { + assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) } + sparkContext.removeSparkListener(listener) } - eventually(timeout(20.seconds)) { - assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) - } + } finally { + sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, + savedInterruptOnCancel) } - sparkContext.removeSparkListener(listener) } test("SPARK-20430 Initialize Range parameters in a driver side") { @@ -204,7 +220,3 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } } } - -object DataFrameRangeSuite { - @volatile var stageToKill = -1 -} From 8bb0df2c65355dfdcd28e362ff661c6c7ebc99c0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 19 Apr 2018 10:00:57 +0800 Subject: [PATCH 0645/2461] [SPARK-24014][PYSPARK] Add onStreamingStarted method to StreamingListener ## What changes were proposed in this pull request? The `StreamingListener` in PySpark side seems to be lack of `onStreamingStarted` method. This patch adds it and a test for it. This patch also includes a trivial doc improvement for `createDirectStream`. Original PR is #21057. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21098 from viirya/SPARK-24014. --- python/pyspark/streaming/kafka.py | 3 ++- python/pyspark/streaming/listener.py | 6 ++++++ python/pyspark/streaming/tests.py | 7 +++++++ 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index fdb9308604489..ed2e0e7d10fa2 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -104,7 +104,8 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None, :param topics: list of topic_name to consume. :param kafkaParams: Additional params for Kafka. :param fromOffsets: Per-topic/partition Kafka offsets defining the (inclusive) starting - point of the stream. + point of the stream (a dictionary mapping `TopicAndPartition` to + integers). :param keyDecoder: A function used to decode key (default is utf8_decoder). :param valueDecoder: A function used to decode value (default is utf8_decoder). :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can assess diff --git a/python/pyspark/streaming/listener.py b/python/pyspark/streaming/listener.py index b830797f5c0a0..d4ecc215aea99 100644 --- a/python/pyspark/streaming/listener.py +++ b/python/pyspark/streaming/listener.py @@ -23,6 +23,12 @@ class StreamingListener(object): def __init__(self): pass + def onStreamingStarted(self, streamingStarted): + """ + Called when the streaming has been started. + """ + pass + def onReceiverStarted(self, receiverStarted): """ Called when a receiver has been started diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 7dde7c0928c08..103940923dd4d 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -507,6 +507,10 @@ def __init__(self): self.batchInfosCompleted = [] self.batchInfosStarted = [] self.batchInfosSubmitted = [] + self.streamingStartedTime = [] + + def onStreamingStarted(self, streamingStarted): + self.streamingStartedTime.append(streamingStarted.time) def onBatchSubmitted(self, batchSubmitted): self.batchInfosSubmitted.append(batchSubmitted.batchInfo()) @@ -530,9 +534,12 @@ def func(dstream): batchInfosSubmitted = batch_collector.batchInfosSubmitted batchInfosStarted = batch_collector.batchInfosStarted batchInfosCompleted = batch_collector.batchInfosCompleted + streamingStartedTime = batch_collector.streamingStartedTime self.wait_for(batchInfosCompleted, 4) + self.assertEqual(len(streamingStartedTime), 1) + self.assertGreaterEqual(len(batchInfosSubmitted), 4) for info in batchInfosSubmitted: self.assertGreaterEqual(info.batchTime().milliseconds(), 0) From d5bec48b9cb225c19b43935c07b24090c51cacce Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Apr 2018 11:59:17 +0900 Subject: [PATCH 0646/2461] [SPARK-23919][SQL] Add array_position function ## What changes were proposed in this pull request? The PR adds the SQL function `array_position`. The behavior of the function is based on Presto's one. The function returns the position of the first occurrence of the element in array x (or 0 if not found) using 1-based index as BigInt. ## How was this patch tested? Added UTs Author: Kazuaki Ishizaki Closes #21037 from kiszk/SPARK-23919. --- python/pyspark/sql/functions.py | 17 ++++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 56 +++++++++++++++++++ .../CollectionExpressionsSuite.scala | 22 ++++++++ .../org/apache/spark/sql/functions.scala | 14 +++++ .../spark/sql/DataFrameFunctionsSuite.scala | 34 +++++++++++ 6 files changed, 144 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d3bb0a5d6b36a..36dcabc6766d8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1845,6 +1845,23 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(2.4) +def array_position(col, value): + """ + Collection function: Locates the position of the first occurrence of the given value + in the given array. Returns null if either of the arguments are null. + + .. note:: The position is not zero based, but 1 based index. Returns 0 if the given + value could not be found in the array. + + >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data']) + >>> df.select(array_position(df.data, "a")).collect() + [Row(array_position(data, a)=3), Row(array_position(data, a)=0)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_position(_to_java_column(col), value)) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 38c874ad948e1..74095fe697b6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -402,6 +402,7 @@ object FunctionRegistry { // collection functions expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), + expression[ArrayPosition]("array_position"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[MapKeys]("map_keys"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 76b71f5b86074..e6a05f535cb1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -505,3 +505,59 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast override def prettyName: String = "array_max" } + + +/** + * Returns the position of the first occurrence of element in the given array as long. + * Returns 0 if the given value could not be found in the array. Returns null if either of + * the arguments are null + * + * NOTE: that this is not zero based, but 1-based index. The first element in the array has + * index 1. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, element) - Returns the (1-based) index of the first element of the array as long. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(3, 2, 1), 1); + 3 + """, + since = "2.4.0") +case class ArrayPosition(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = LongType + override def inputTypes: Seq[AbstractDataType] = + Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) + + override def nullSafeEval(arr: Any, value: Any): Any = { + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == value) { + return (i + 1).toLong + } + ) + 0L + } + + override def prettyName: String = "array_position" + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (arr, value) => { + val pos = ctx.freshName("arrayPosition") + val i = ctx.freshName("i") + val getValue = CodeGenerator.getValue(arr, right.dataType, i) + s""" + |int $pos = 0; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | if (!$arr.isNullAt($i) && ${ctx.genEqual(right.dataType, value, getValue)}) { + | $pos = $i + 1; + | break; + | } + |} + |${ev.value} = (long) $pos; + """.stripMargin + }) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 517639dbc7232..916cd3bb4cca5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -169,4 +169,26 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Reverse(as7), null) checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b"))) } + + test("Array Position") { + val a0 = Literal.create(Seq(1, null, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a2 = Literal.create(Seq(null), ArrayType(LongType)) + val a3 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayPosition(a0, Literal(3)), 4L) + checkEvaluation(ArrayPosition(a0, Literal(1)), 1L) + checkEvaluation(ArrayPosition(a0, Literal(0)), 0L) + checkEvaluation(ArrayPosition(a0, Literal.create(null, IntegerType)), null) + + checkEvaluation(ArrayPosition(a1, Literal("")), 2L) + checkEvaluation(ArrayPosition(a1, Literal("a")), 0L) + checkEvaluation(ArrayPosition(a1, Literal.create(null, StringType)), null) + + checkEvaluation(ArrayPosition(a2, Literal(1L)), 0L) + checkEvaluation(ArrayPosition(a2, Literal.create(null, LongType)), null) + + checkEvaluation(ArrayPosition(a3, Literal("")), null) + checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a55a800f48245..3a09ec4f1982e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3038,6 +3038,20 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Locates the position of the first occurrence of the value in the given array as long. + * Returns null if either of the arguments are null. + * + * @note The position is not zero based, but 1 based index. Returns 0 if value + * could not be found in array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_position(column: Column, value: Any): Column = withExpr { + ArrayPosition(column.expr, Literal(value)) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 74c42f2599dca..13161e7e24cfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -535,6 +535,40 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } } + test("array position function") { + val df = Seq( + (Seq[Int](1, 2), "x"), + (Seq[Int](), "x") + ).toDF("a", "b") + + checkAnswer( + df.select(array_position(df("a"), 1)), + Seq(Row(1L), Row(0L)) + ) + checkAnswer( + df.selectExpr("array_position(a, 1)"), + Seq(Row(1L), Row(0L)) + ) + + checkAnswer( + df.select(array_position(df("a"), null)), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.selectExpr("array_position(a, null)"), + Seq(Row(null), Row(null)) + ) + + checkAnswer( + df.selectExpr("array_position(array(array(1), null)[0], 1)"), + Seq(Row(1L), Row(1L)) + ) + checkAnswer( + df.selectExpr("array_position(array(1, null), array(1, null)[0])"), + Seq(Row(1L), Row(1L)) + ) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 46bb2b5129833cc5829089bf1174a76cb7b81741 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Apr 2018 21:00:10 +0900 Subject: [PATCH 0647/2461] [SPARK-23924][SQL] Add element_at function ## What changes were proposed in this pull request? The PR adds the SQL function `element_at`. The behavior of the function is based on Presto's one. This function returns element of array at given index in value if column is array, or returns value for the given key in value if column is map. ## How was this patch tested? Added UTs Author: Kazuaki Ishizaki Closes #21053 from kiszk/SPARK-23924. --- python/pyspark/sql/functions.py | 24 ++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 104 ++++++++++++++++++ .../expressions/complexTypeExtractors.scala | 64 +++++++---- .../CollectionExpressionsSuite.scala | 48 ++++++++ .../org/apache/spark/sql/functions.scala | 11 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 48 ++++++++ 7 files changed, 276 insertions(+), 24 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 36dcabc6766d8..1be68f2a4a448 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1862,6 +1862,30 @@ def array_position(col, value): return Column(sc._jvm.functions.array_position(_to_java_column(col), value)) +@ignore_unicode_prefix +@since(2.4) +def element_at(col, extraction): + """ + Collection function: Returns element of array at given index in extraction if col is array. + Returns value for the given key in extraction if col is map. + + :param col: name of column containing array or map + :param extraction: index to check for in array or key to check for in map + + .. note:: The position is not zero based, but 1 based index. + + >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) + >>> df.select(element_at(df.data, 1)).collect() + [Row(element_at(data, 1)=u'a'), Row(element_at(data, 1)=None)] + + >>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},), ({},)], ['data']) + >>> df.select(element_at(df.data, "a")).collect() + [Row(element_at(data, a)=1.0), Row(element_at(data, a)=None)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.element_at(_to_java_column(col), extraction)) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 74095fe697b6a..a44f2d5272b8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -405,6 +405,7 @@ object FunctionRegistry { expression[ArrayPosition]("array_position"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), + expression[ElementAt]("element_at"), expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[Size]("size"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e6a05f535cb1c..dba426e999dda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -561,3 +561,107 @@ case class ArrayPosition(left: Expression, right: Expression) }) } } + +/** + * Returns the value of index `right` in Array `left` or the value for key `right` in Map `left`. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0, + accesses elements from the last to the first. Returns NULL if the index exceeds the length + of the array. + + _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), 2); + 2 + > SELECT _FUNC_(map(1, 'a', 2, 'b'), 2); + "b" + """, + since = "2.4.0") +case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { + + override def dataType: DataType = left.dataType match { + case ArrayType(elementType, _) => elementType + case MapType(_, valueType, _) => valueType + } + + override def inputTypes: Seq[AbstractDataType] = { + Seq(TypeCollection(ArrayType, MapType), + left.dataType match { + case _: ArrayType => IntegerType + case _: MapType => left.dataType.asInstanceOf[MapType].keyType + } + ) + } + + override def nullable: Boolean = true + + override def nullSafeEval(value: Any, ordinal: Any): Any = { + left.dataType match { + case _: ArrayType => + val array = value.asInstanceOf[ArrayData] + val index = ordinal.asInstanceOf[Int] + if (array.numElements() < math.abs(index)) { + null + } else { + val idx = if (index == 0) { + throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") + } else if (index > 0) { + index - 1 + } else { + array.numElements() + index + } + if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) { + null + } else { + array.get(idx, dataType) + } + } + case _: MapType => + getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + left.dataType match { + case _: ArrayType => + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val index = ctx.freshName("elementAtIndex") + val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($eval1.isNullAt($index)) { + | ${ev.isNull} = true; + |} else + """.stripMargin + } else { + "" + } + s""" + |int $index = (int) $eval2; + |if ($eval1.numElements() < Math.abs($index)) { + | ${ev.isNull} = true; + |} else { + | if ($index == 0) { + | throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1"); + | } else if ($index > 0) { + | $index--; + | } else { + | $index += $eval1.numElements(); + | } + | $nullCheck + | { + | ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; + | } + |} + """.stripMargin + }) + case _: MapType => + doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType]) + } + } + + override def prettyName: String = "element_at" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 6cdad19168dce..3fba52d745453 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -268,31 +268,12 @@ case class GetArrayItem(child: Expression, ordinal: Expression) } /** - * Returns the value of key `key` in Map `child`. - * - * We need to do type checking here as `key` expression maybe unresolved. + * Common base class for [[GetMapValue]] and [[ElementAt]]. */ -case class GetMapValue(child: Expression, key: Expression) - extends BinaryExpression with ImplicitCastInputTypes with ExtractValue with NullIntolerant { - - private def keyType = child.dataType.asInstanceOf[MapType].keyType - - // We have done type checking for child in `ExtractValue`, so only need to check the `key`. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) - - override def toString: String = s"$child[$key]" - override def sql: String = s"${child.sql}[${key.sql}]" - - override def left: Expression = child - override def right: Expression = key - - /** `Null` is returned for invalid ordinals. */ - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType +abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes { // todo: current search is O(n), improve it. - protected override def nullSafeEval(value: Any, ordinal: Any): Any = { + def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = { val map = value.asInstanceOf[MapData] val length = map.numElements() val keys = map.keyArray() @@ -315,14 +296,15 @@ case class GetMapValue(child: Expression, key: Expression) } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + def doGetValueGenCode(ctx: CodegenContext, ev: ExprCode, mapType: MapType): ExprCode = { val index = ctx.freshName("index") val length = ctx.freshName("length") val keys = ctx.freshName("keys") val found = ctx.freshName("found") val key = ctx.freshName("key") val values = ctx.freshName("values") - val nullCheck = if (child.dataType.asInstanceOf[MapType].valueContainsNull) { + val keyType = mapType.keyType + val nullCheck = if (mapType.valueContainsNull) { s" || $values.isNullAt($index)" } else { "" @@ -354,3 +336,37 @@ case class GetMapValue(child: Expression, key: Expression) }) } } + +/** + * Returns the value of key `key` in Map `child`. + * + * We need to do type checking here as `key` expression maybe unresolved. + */ +case class GetMapValue(child: Expression, key: Expression) + extends GetMapValueUtil with ExtractValue with NullIntolerant { + + private def keyType = child.dataType.asInstanceOf[MapType].keyType + + // We have done type checking for child in `ExtractValue`, so only need to check the `key`. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) + + override def toString: String = s"$child[$key]" + override def sql: String = s"${child.sql}[${key.sql}]" + + override def left: Expression = child + override def right: Expression = key + + /** `Null` is returned for invalid ordinals. */ + override def nullable: Boolean = true + + override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType + + // todo: current search is O(n), improve it. + override def nullSafeEval(value: Any, ordinal: Any): Any = { + getValueEval(value, ordinal, keyType) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + doGetValueGenCode(ctx, ev, child.dataType.asInstanceOf[MapType]) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 916cd3bb4cca5..7d8fe211858b2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -191,4 +191,52 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayPosition(a3, Literal("")), null) checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null) } + + test("elementAt") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a2 = Literal.create(Seq(null), ArrayType(LongType)) + val a3 = Literal.create(null, ArrayType(StringType)) + + intercept[Exception] { + checkEvaluation(ElementAt(a0, Literal(0)), null) + }.getMessage.contains("SQL array indices start at 1") + intercept[Exception] { checkEvaluation(ElementAt(a0, Literal(1.1)), null) } + checkEvaluation(ElementAt(a0, Literal(4)), null) + checkEvaluation(ElementAt(a0, Literal(-4)), null) + + checkEvaluation(ElementAt(a0, Literal(1)), 1) + checkEvaluation(ElementAt(a0, Literal(2)), 2) + checkEvaluation(ElementAt(a0, Literal(3)), 3) + checkEvaluation(ElementAt(a0, Literal(-3)), 1) + checkEvaluation(ElementAt(a0, Literal(-2)), 2) + checkEvaluation(ElementAt(a0, Literal(-1)), 3) + + checkEvaluation(ElementAt(a1, Literal(1)), null) + checkEvaluation(ElementAt(a1, Literal(2)), "") + checkEvaluation(ElementAt(a1, Literal(-2)), null) + checkEvaluation(ElementAt(a1, Literal(-1)), "") + + checkEvaluation(ElementAt(a2, Literal(1)), null) + + checkEvaluation(ElementAt(a3, Literal(1)), null) + + + val m0 = + Literal.create(Map("a" -> "1", "b" -> "2", "c" -> null), MapType(StringType, StringType)) + val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) + val m2 = Literal.create(null, MapType(StringType, StringType)) + + checkEvaluation(ElementAt(m0, Literal(1.0)), null) + + checkEvaluation(ElementAt(m0, Literal("d")), null) + + checkEvaluation(ElementAt(m1, Literal("a")), null) + + checkEvaluation(ElementAt(m0, Literal("a")), "1") + checkEvaluation(ElementAt(m0, Literal("b")), "2") + checkEvaluation(ElementAt(m0, Literal("c")), null) + + checkEvaluation(ElementAt(m2, Literal("a")), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3a09ec4f1982e..9c8580378303e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3052,6 +3052,17 @@ object functions { ArrayPosition(column.expr, Literal(value)) } + /** + * Returns element of array at given index in value if column is array. Returns value for + * the given key in value if column is map. + * + * @group collection_funcs + * @since 2.4.0 + */ + def element_at(column: Column, value: Any): Column = withExpr { + ElementAt(column.expr, Literal(value)) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 13161e7e24cfe..7c976c1b7f915 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -569,6 +569,54 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("element_at function") { + val df = Seq( + (Seq[String]("1", "2", "3")), + (Seq[String](null, "")), + (Seq[String]()) + ).toDF("a") + + intercept[Exception] { + checkAnswer( + df.select(element_at(df("a"), 0)), + Seq(Row(null), Row(null), Row(null)) + ) + }.getMessage.contains("SQL array indices start at 1") + intercept[Exception] { + checkAnswer( + df.select(element_at(df("a"), 1.1)), + Seq(Row(null), Row(null), Row(null)) + ) + } + checkAnswer( + df.select(element_at(df("a"), 4)), + Seq(Row(null), Row(null), Row(null)) + ) + + checkAnswer( + df.select(element_at(df("a"), 1)), + Seq(Row("1"), Row(null), Row(null)) + ) + checkAnswer( + df.select(element_at(df("a"), -1)), + Seq(Row("3"), Row(""), Row(null)) + ) + + checkAnswer( + df.selectExpr("element_at(a, 4)"), + Seq(Row(null), Row(null), Row(null)) + ) + + checkAnswer( + df.selectExpr("element_at(a, 1)"), + Seq(Row("1"), Row(null), Row(null)) + ) + checkAnswer( + df.selectExpr("element_at(a, -1)"), + Seq(Row("3"), Row(""), Row(null)) + ) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 1b08c4393cf48e21fea9914d130d8d3bf544061d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 19 Apr 2018 14:38:26 +0200 Subject: [PATCH 0648/2461] [SPARK-23584][SQL] NewInstance should support interpreted execution ## What changes were proposed in this pull request? This pr supported interpreted mode for `NewInstance`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Takeshi Yamamuro Closes #20778 from maropu/SPARK-23584. --- .../spark/sql/catalyst/ScalaReflection.scala | 13 +++++++ .../expressions/objects/objects.scala | 28 +++++++++++++-- .../expressions/ObjectExpressionsSuite.scala | 36 +++++++++++++++++++ 3 files changed, 75 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index e4274aaa9727e..818cc2fb1e8a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.catalyst +import java.lang.reflect.Constructor + +import org.apache.commons.lang3.reflect.ConstructorUtils + import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ @@ -781,6 +785,15 @@ object ScalaReflection extends ScalaReflection { } } + /** + * Finds an accessible constructor with compatible parameters. This is a more flexible search + * than the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible + * matching constructor is returned. Otherwise, it returns `None`. + */ + def findConstructor(cls: Class[_], paramTypes: Seq[Class[_]]): Option[Constructor[_]] = { + Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*)) + } + /** * Whether the fields of the given type is defined entirely by its constructor parameters. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 72b202b3a5020..1645bd7d57b1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -449,8 +449,32 @@ case class NewInstance( childrenResolved && !needOuterPointer } - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + @transient private lazy val constructor: (Seq[AnyRef]) => Any = { + val paramTypes = ScalaReflection.expressionJavaClasses(arguments) + val getConstructor = (paramClazz: Seq[Class[_]]) => { + ScalaReflection.findConstructor(cls, paramClazz).getOrElse { + sys.error(s"Couldn't find a valid constructor on $cls") + } + } + outerPointer.map { p => + val outerObj = p() + val d = outerObj.getClass +: paramTypes + val c = getConstructor(outerObj.getClass +: paramTypes) + (args: Seq[AnyRef]) => { + c.newInstance(outerObj +: args: _*) + } + }.getOrElse { + val c = getConstructor(paramTypes) + (args: Seq[AnyRef]) => { + c.newInstance(args: _*) + } + } + } + + override def eval(input: InternalRow): Any = { + val argValues = arguments.map(_.eval(input)) + constructor(argValues.map(_.asInstanceOf[AnyRef])) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = CodeGenerator.javaType(dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index b0188b0098def..bf805f4f29ac5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -47,6 +47,20 @@ class InvokeTargetSubClass extends InvokeTargetClass { override def binOp(e1: Int, e2: Double): Double = e1 - e2 } +// Tests for NewInstance +class Outer extends Serializable { + class Inner(val value: Int) { + override def hashCode(): Int = super.hashCode() + override def equals(other: Any): Boolean = { + if (other.isInstanceOf[Inner]) { + value == other.asInstanceOf[Inner].value + } else { + false + } + } + } +} + class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-16622: The returned value of the called method in Invoke can be null") { @@ -383,6 +397,27 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-23584 NewInstance should support interpreted execution") { + // Normal case test + val newInst1 = NewInstance( + cls = classOf[GenericArrayData], + arguments = Literal.fromObject(List(1, 2, 3)) :: Nil, + propagateNull = false, + dataType = ArrayType(IntegerType), + outerPointer = None) + checkObjectExprEvaluation(newInst1, new GenericArrayData(List(1, 2, 3))) + + // Inner class case test + val outerObj = new Outer() + val newInst2 = NewInstance( + cls = classOf[outerObj.Inner], + arguments = Literal(1) :: Nil, + propagateNull = false, + dataType = ObjectType(classOf[outerObj.Inner]), + outerPointer = Some(() => outerObj)) + checkObjectExprEvaluation(newInst2, new outerObj.Inner(1)) + } + test("LambdaVariable should support interpreted execution") { def genSchema(dt: DataType): Seq[StructType] = { Seq(StructType(StructField("col_1", dt, nullable = false) :: Nil), @@ -421,6 +456,7 @@ class TestBean extends Serializable { private var x: Int = 0 def setX(i: Int): Unit = x = i + def setNonPrimitive(i: AnyRef): Unit = assert(i != null, "this setter should not be called with null.") } From e13416502f814b04d59bb650953a0114332d163a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 19 Apr 2018 14:42:50 +0200 Subject: [PATCH 0649/2461] [SPARK-23588][SQL] CatalystToExternalMap should support interpreted execution ## What changes were proposed in this pull request? This pr supported interpreted mode for `CatalystToExternalMap`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Takeshi Yamamuro Closes #20979 from maropu/SPARK-23588. --- .../expressions/objects/objects.scala | 39 +++++++++++++++++-- .../expressions/ObjectExpressionsSuite.scala | 34 +++++++++++++--- 2 files changed, 63 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 1645bd7d57b1d..bc17d1229420a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -28,12 +28,12 @@ import scala.util.Try import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -1033,8 +1033,39 @@ case class CatalystToExternalMap private( override def children: Seq[Expression] = keyLambdaFunction :: valueLambdaFunction :: inputData :: Nil - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + private lazy val inputMapType = inputData.dataType.asInstanceOf[MapType] + + private lazy val keyConverter = + CatalystTypeConverters.createToScalaConverter(inputMapType.keyType) + private lazy val valueConverter = + CatalystTypeConverters.createToScalaConverter(inputMapType.valueType) + + private def newMapBuilder(): Builder[AnyRef, AnyRef] = { + val clazz = Utils.classForName(collClass.getCanonicalName + "$") + val module = clazz.getField("MODULE$").get(null) + val method = clazz.getMethod("newBuilder") + method.invoke(module).asInstanceOf[Builder[AnyRef, AnyRef]] + } + + override def eval(input: InternalRow): Any = { + val result = inputData.eval(input).asInstanceOf[MapData] + if (result != null) { + val builder = newMapBuilder() + builder.sizeHint(result.numElements()) + val keyArray = result.keyArray() + val valueArray = result.valueArray() + var i = 0 + while (i < result.numElements()) { + val key = keyConverter(keyArray.get(i, inputMapType.keyType)) + val value = valueConverter(valueArray.get(i, inputMapType.valueType)) + builder += Tuple2(key, value) + i += 1 + } + builder.result() + } else { + null + } + } override def dataType: DataType = ObjectType(collClass) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index bf805f4f29ac5..bcd035c1eba0b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -27,12 +27,14 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone -import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer, UnresolvedDeserializer} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils.{SQLDate, SQLTimestamp} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -162,9 +164,10 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { "fromPrimitiveArray", ObjectType(classOf[Array[Int]]), Array[Int](1, 2, 3), UnsafeArrayData.fromPrimitiveArray(Array[Int](1, 2, 3))), (DateTimeUtils.getClass, ObjectType(classOf[Date]), - "toJavaDate", ObjectType(classOf[SQLDate]), 77777, DateTimeUtils.toJavaDate(77777)), + "toJavaDate", ObjectType(classOf[DateTimeUtils.SQLDate]), 77777, + DateTimeUtils.toJavaDate(77777)), (DateTimeUtils.getClass, ObjectType(classOf[Timestamp]), - "toJavaTimestamp", ObjectType(classOf[SQLTimestamp]), + "toJavaTimestamp", ObjectType(classOf[DateTimeUtils.SQLTimestamp]), 88888888.toLong, DateTimeUtils.toJavaTimestamp(88888888)) ).foreach { case (cls, dataType, methodName, argType, arg, expected) => checkObjectExprEvaluation(StaticInvoke(cls, dataType, methodName, @@ -450,6 +453,25 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + + implicit private def mapIntStrEncoder = ExpressionEncoder[Map[Int, String]]() + + test("SPARK-23588 CatalystToExternalMap should support interpreted execution") { + // To get a resolved `CatalystToExternalMap` expression, we build a deserializer plan + // with dummy input, resolve the plan by the analyzer, and replace the dummy input + // with a literal for tests. + val unresolvedDeser = UnresolvedDeserializer(encoderFor[Map[Int, String]].deserializer) + val dummyInputPlan = LocalRelation('value.map(MapType(IntegerType, StringType))) + val plan = Project(Alias(unresolvedDeser, "none")() :: Nil, dummyInputPlan) + + val analyzedPlan = SimpleAnalyzer.execute(plan) + val Alias(toMapExpr: CatalystToExternalMap, _) = analyzedPlan.expressions.head + + // Replaces the dummy input with a literal for tests here + val data = Map[Int, String](0 -> "v0", 1 -> "v1", 2 -> null, 3 -> "v3") + val deserializer = toMapExpr.copy(inputData = Literal.create(data)) + checkObjectExprEvaluation(deserializer, expected = data) + } } class TestBean extends Serializable { From 9e10f69df52abde2de5d93435bab54e97dd59d9c Mon Sep 17 00:00:00 2001 From: jinxing Date: Thu, 19 Apr 2018 21:07:21 +0800 Subject: [PATCH 0650/2461] [SPARK-22676][FOLLOW-UP] fix code style for test. ## What changes were proposed in this pull request? This pr address comments in https://github.com/apache/spark/pull/19868 ; Fix the code style for `org.apache.spark.sql.hive.QueryPartitionSuite` by using: `withTempView`, `withTempDir`, `withTable`... Author: jinxing Closes #21091 from jinxing64/SPARK-22676-FOLLOW-UP. --- .../spark/sql/hive/QueryPartitionSuite.scala | 109 +++++++----------- 1 file changed, 41 insertions(+), 68 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 78156b17fb43b..1e396553c9c52 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -33,80 +33,53 @@ import org.apache.spark.util.Utils class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import spark.implicits._ - test("SPARK-5068: query data when path doesn't exist") { - withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "true")) { - val testData = sparkContext.parallelize( - (1 to 10).map(i => TestData(i, i.toString))).toDF() - testData.createOrReplaceTempView("testData") - - val tmpDir = Files.createTempDir() - // create the table for test - sql(s"CREATE TABLE table_with_partition(key int,value string) " + - s"PARTITIONED by (ds string) location '${tmpDir.toURI}' ") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + - "SELECT key,value FROM testData") - - // test for the exist path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect - ++ testData.toDF.collect ++ testData.toDF.collect) - - // delete the path of one partition - tmpDir.listFiles - .find { f => f.isDirectory && f.getName().startsWith("ds=") } - .foreach { f => Utils.deleteRecursively(f) } - - // test for after delete the path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) + private def queryWhenPathNotExist(): Unit = { + withTempView("testData") { + withTable("table_with_partition", "createAndInsertTest") { + withTempDir { tmpDir => + val testData = sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.createOrReplaceTempView("testData") + + // create the table for test + sql(s"CREATE TABLE table_with_partition(key int,value string) " + + s"PARTITIONED by (ds string) location '${tmpDir.toURI}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + + "SELECT key,value FROM testData") + + // test for the exist path + checkAnswer(sql("select key,value from table_with_partition"), + testData.union(testData).union(testData).union(testData)) + + // delete the path of one partition + tmpDir.listFiles + .find { f => f.isDirectory && f.getName().startsWith("ds=") } + .foreach { f => Utils.deleteRecursively(f) } + + // test for after delete the path + checkAnswer(sql("select key,value from table_with_partition"), + testData.union(testData).union(testData)) + } + } + } + } - sql("DROP TABLE IF EXISTS table_with_partition") - sql("DROP TABLE IF EXISTS createAndInsertTest") + test("SPARK-5068: query data when path doesn't exist") { + withSQLConf(SQLConf.HIVE_VERIFY_PARTITION_PATH.key -> "true") { + queryWhenPathNotExist() } } test("Replace spark.sql.hive.verifyPartitionPath by spark.files.ignoreMissingFiles") { - withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "false")) { + withSQLConf(SQLConf.HIVE_VERIFY_PARTITION_PATH.key -> "false") { sparkContext.conf.set(IGNORE_MISSING_FILES.key, "true") - val testData = sparkContext.parallelize( - (1 to 10).map(i => TestData(i, i.toString))).toDF() - testData.createOrReplaceTempView("testData") - - val tmpDir = Files.createTempDir() - // create the table for test - sql(s"CREATE TABLE table_with_partition(key int,value string) " + - s"PARTITIONED by (ds string) location '${tmpDir.toURI}' ") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + - "SELECT key,value FROM testData") - - // test for the exist path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect - ++ testData.toDF.collect ++ testData.toDF.collect) - - // delete the path of one partition - tmpDir.listFiles - .find { f => f.isDirectory && f.getName().startsWith("ds=") } - .foreach { f => Utils.deleteRecursively(f) } - - // test for after delete the path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) - - sql("DROP TABLE IF EXISTS table_with_partition") - sql("DROP TABLE IF EXISTS createAndInsertTest") + queryWhenPathNotExist() } } From d96c3e33cc2a95de8e15e1a2ddf50a8d0cc66dd2 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 19 Apr 2018 21:21:22 +0800 Subject: [PATCH 0651/2461] [SPARK-21811][SQL] Fix the inconsistency behavior when finding the widest common type ## What changes were proposed in this pull request? Currently we find the wider common type by comparing the two types from left to right, this can be a problem when you have two data types which don't have a common type but each can be promoted to StringType. For instance, if you have a table with the schema: [c1: date, c2: string, c3: int] The following succeeds: SELECT coalesce(c1, c2, c3) FROM table While the following produces an exception: SELECT coalesce(c1, c3, c2) FROM table This is only a issue when the seq of dataTypes contains `StringType` and all the types can do string promotion. close #19033 ## How was this patch tested? Add test in `TypeCoercionSuite` Author: Xingbo Jiang Closes #21074 from jiangxb1987/typeCoercion. --- docs/sql-programming-guide.md | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 24 +++++++++++++++---- .../catalyst/analysis/TypeCoercionSuite.scala | 13 ++++++++++ 3 files changed, 34 insertions(+), 5 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 55d35b9dd31db..e8ff1470970f7 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1810,7 +1810,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. - + - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. ## Upgrading From Spark SQL 2.2 to 2.3 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index ec7e7761dc4c2..281f206e8d59e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -175,11 +175,27 @@ object TypeCoercion { }) } + /** + * Whether the data type contains StringType. + */ + def hasStringType(dt: DataType): Boolean = dt match { + case StringType => true + case ArrayType(et, _) => hasStringType(et) + // Add StructType if we support string promotion for struct fields in the future. + case _ => false + } + private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { - types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findWiderTypeForTwo(d, c) - case None => None - }) + // findWiderTypeForTwo doesn't satisfy the associative law, i.e. (a op b) op c may not equal + // to a op (b op c). This is only a problem for StringType or nested StringType in ArrayType. + // Excluding these types, findWiderTypeForTwo satisfies the associative law. For instance, + // (TimestampType, IntegerType, StringType) should have StringType as the wider common type. + val (stringTypes, nonStringTypes) = types.partition(hasStringType(_)) + (stringTypes.distinct ++ nonStringTypes).foldLeft[Option[DataType]](Some(NullType))((r, c) => + r match { + case Some(d) => findWiderTypeForTwo(d, c) + case _ => None + }) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 8ac49dc05e3cf..fd6a3121663ed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -539,6 +539,9 @@ class TypeCoercionSuite extends AnalysisTest { val floatLit = Literal.create(1.0f, FloatType) val timestampLit = Literal.create("2017-04-12", TimestampType) val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000")) + val tsArrayLit = Literal(Array(new Timestamp(System.currentTimeMillis()))) + val strArrayLit = Literal(Array("c")) + val intArrayLit = Literal(Array(1)) ruleTest(rule, Coalesce(Seq(doubleLit, intLit, floatLit)), @@ -572,6 +575,16 @@ class TypeCoercionSuite extends AnalysisTest { Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit)), Coalesce(Seq(Cast(nullLit, StringType), Cast(floatNullLit, StringType), Cast(doubleLit, StringType), Cast(stringLit, StringType)))) + + ruleTest(rule, + Coalesce(Seq(timestampLit, intLit, stringLit)), + Coalesce(Seq(Cast(timestampLit, StringType), Cast(intLit, StringType), + Cast(stringLit, StringType)))) + + ruleTest(rule, + Coalesce(Seq(tsArrayLit, intArrayLit, strArrayLit)), + Coalesce(Seq(Cast(tsArrayLit, ArrayType(StringType)), + Cast(intArrayLit, ArrayType(StringType)), Cast(strArrayLit, ArrayType(StringType))))) } test("CreateArray casts") { From 0deaa5251326a32a3d2d2b8851193ca926303972 Mon Sep 17 00:00:00 2001 From: wuyi Date: Thu, 19 Apr 2018 09:00:33 -0500 Subject: [PATCH 0652/2461] [SPARK-24021][CORE] fix bug in BlacklistTracker's updateBlacklistForFetchFailure MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? There‘s a miswrite in BlacklistTracker's updateBlacklistForFetchFailure: ``` val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(exec, HashSet[String]()) blacklistedExecsOnNode += exec ``` where first **exec** should be **host**. ## How was this patch tested? adjust existed test. Author: wuyi Closes #21104 from Ngone51/SPARK-24021. --- .../scala/org/apache/spark/scheduler/BlacklistTracker.scala | 2 +- .../org/apache/spark/scheduler/BlacklistTrackerSuite.scala | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index 952598f6de19d..30cf75d43ee09 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -210,7 +210,7 @@ private[scheduler] class BlacklistTracker ( updateNextExpiryTime() killBlacklistedExecutor(exec) - val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(exec, HashSet[String]()) + val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(host, HashSet[String]()) blacklistedExecsOnNode += exec } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index 06d7afaaff55c..96c8404327e24 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -574,6 +574,9 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M verify(allocationClientMock, never).killExecutors(any(), any(), any(), any()) verify(allocationClientMock, never).killExecutorsOnHost(any()) + assert(blacklist.nodeToBlacklistedExecs.contains("hostA")) + assert(blacklist.nodeToBlacklistedExecs("hostA").contains("1")) + // Enable auto-kill. Blacklist an executor and make sure killExecutors is called. conf.set(config.BLACKLIST_KILL_ENABLED, true) blacklist = new BlacklistTracker(listenerBusMock, conf, Some(allocationClientMock), clock) @@ -589,6 +592,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) assert(blacklist.nextExpiryTime === 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) assert(blacklist.nodeIdToBlacklistExpiryTime.isEmpty) + assert(blacklist.nodeToBlacklistedExecs.contains("hostA")) + assert(blacklist.nodeToBlacklistedExecs("hostA").contains("1")) // Enable external shuffle service to see if all the executors on this node will be killed. conf.set(config.SHUFFLE_SERVICE_ENABLED, true) From 6e19f7683fc73fabe7cdaac4eb1982d2e3e607b7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 19 Apr 2018 17:54:53 +0200 Subject: [PATCH 0653/2461] [SPARK-23989][SQL] exchange should copy data before non-serialized shuffle ## What changes were proposed in this pull request? In Spark SQL, we usually reuse the `UnsafeRow` instance and need to copy the data when a place buffers non-serialized objects. Shuffle may buffer objects if we don't make it to the bypass merge shuffle or unsafe shuffle. `ShuffleExchangeExec.needToCopyObjectsBeforeShuffle` misses the case that, if `spark.sql.shuffle.partitions` is large enough, we could fail to run unsafe shuffle and go with the non-serialized shuffle. This bug is very hard to hit since users wouldn't set such a large number of partitions(16 million) for Spark SQL exchange. TODO: test ## How was this patch tested? todo. Author: Wenchen Fan Closes #21101 from cloud-fan/shuffle. --- .../exchange/ShuffleExchangeExec.scala | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 4d95ee34f30de..b89203719541b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -153,12 +153,9 @@ object ShuffleExchangeExec { * See SPARK-2967, SPARK-4479, and SPARK-7375 for more discussion of this issue. * * @param partitioner the partitioner for the shuffle - * @param serializer the serializer that will be used to write rows * @return true if rows should be copied before being shuffled, false otherwise */ - private def needToCopyObjectsBeforeShuffle( - partitioner: Partitioner, - serializer: Serializer): Boolean = { + private def needToCopyObjectsBeforeShuffle(partitioner: Partitioner): Boolean = { // Note: even though we only use the partitioner's `numPartitions` field, we require it to be // passed instead of directly passing the number of partitions in order to guard against // corner-cases where a partitioner constructed with `numPartitions` partitions may output @@ -167,22 +164,24 @@ object ShuffleExchangeExec { val shuffleManager = SparkEnv.get.shuffleManager val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + val numParts = partitioner.numPartitions if (sortBasedShuffleOn) { - val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] - if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) { + if (numParts <= bypassMergeThreshold) { // If we're using the original SortShuffleManager and the number of output partitions is // sufficiently small, then Spark will fall back to the hash-based shuffle write path, which // doesn't buffer deserialized records. // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass. false - } else if (serializer.supportsRelocationOfSerializedObjects) { + } else if (numParts <= SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) { // SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize individual records // prior to sorting them. This optimization is only applied in cases where shuffle // dependency does not specify an aggregator or ordering and the record serializer has - // certain properties. If this optimization is enabled, we can safely avoid the copy. + // certain properties and the number of partitions doesn't exceed the limitation. If this + // optimization is enabled, we can safely avoid the copy. // - // Exchange never configures its ShuffledRDDs with aggregators or key orderings, so we only - // need to check whether the optimization is enabled and supported by our serializer. + // Exchange never configures its ShuffledRDDs with aggregators or key orderings, and the + // serializer in Spark SQL always satisfy the properties, so we only need to check whether + // the number of partitions exceeds the limitation. false } else { // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory, so we must @@ -298,7 +297,7 @@ object ShuffleExchangeExec { rdd } - if (needToCopyObjectsBeforeShuffle(part, serializer)) { + if (needToCopyObjectsBeforeShuffle(part)) { newRdd.mapPartitionsInternal { iter => val getPartitionKey = getPartitionKeyExtractor() iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } From a471880afbeafd4ef54c15a97e72ea7ff784a88d Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Thu, 19 Apr 2018 09:40:20 -0700 Subject: [PATCH 0654/2461] [SPARK-24026][ML] Add Power Iteration Clustering to spark.ml ## What changes were proposed in this pull request? This PR adds PowerIterationClustering as a Transformer to spark.ml. In the transform method, it calls spark.mllib's PowerIterationClustering.run() method and transforms the return value assignments (the Kmeans output of the pseudo-eigenvector) as a DataFrame (id: LongType, cluster: IntegerType). This PR is copied and modified from https://github.com/apache/spark/pull/15770 The primary author is wangmiao1981 ## How was this patch tested? This PR has 2 types of tests: * Copies of tests from spark.mllib's PIC tests * New tests specific to the spark.ml APIs Author: wm624@hotmail.com Author: wangmiao1981 Author: Joseph K. Bradley Closes #21090 from jkbradley/wangmiao1981-pic. --- .../clustering/PowerIterationClustering.scala | 256 ++++++++++++++++++ .../PowerIterationClusteringSuite.scala | 238 ++++++++++++++++ 2 files changed, 494 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala new file mode 100644 index 0000000000000..2c30a1d9aa947 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.clustering.{PowerIterationClustering => MLlibPowerIterationClustering} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types._ + +/** + * Common params for PowerIterationClustering + */ +private[clustering] trait PowerIterationClusteringParams extends Params with HasMaxIter + with HasPredictionCol { + + /** + * The number of clusters to create (k). Must be > 1. Default: 2. + * @group param + */ + @Since("2.4.0") + final val k = new IntParam(this, "k", "The number of clusters to create. " + + "Must be > 1.", ParamValidators.gt(1)) + + /** @group getParam */ + @Since("2.4.0") + def getK: Int = $(k) + + /** + * Param for the initialization algorithm. This can be either "random" to use a random vector + * as vertex properties, or "degree" to use a normalized sum of similarities with other vertices. + * Default: random. + * @group expertParam + */ + @Since("2.4.0") + final val initMode = { + val allowedParams = ParamValidators.inArray(Array("random", "degree")) + new Param[String](this, "initMode", "The initialization algorithm. This can be either " + + "'random' to use a random vector as vertex properties, or 'degree' to use a normalized sum " + + "of similarities with other vertices. Supported options: 'random' and 'degree'.", + allowedParams) + } + + /** @group expertGetParam */ + @Since("2.4.0") + def getInitMode: String = $(initMode) + + /** + * Param for the name of the input column for vertex IDs. + * Default: "id" + * @group param + */ + @Since("2.4.0") + val idCol = new Param[String](this, "idCol", "Name of the input column for vertex IDs.", + (value: String) => value.nonEmpty) + + setDefault(idCol, "id") + + /** @group getParam */ + @Since("2.4.0") + def getIdCol: String = getOrDefault(idCol) + + /** + * Param for the name of the input column for neighbors in the adjacency list representation. + * Default: "neighbors" + * @group param + */ + @Since("2.4.0") + val neighborsCol = new Param[String](this, "neighborsCol", + "Name of the input column for neighbors in the adjacency list representation.", + (value: String) => value.nonEmpty) + + setDefault(neighborsCol, "neighbors") + + /** @group getParam */ + @Since("2.4.0") + def getNeighborsCol: String = $(neighborsCol) + + /** + * Param for the name of the input column for neighbors in the adjacency list representation. + * Default: "similarities" + * @group param + */ + @Since("2.4.0") + val similaritiesCol = new Param[String](this, "similaritiesCol", + "Name of the input column for neighbors in the adjacency list representation.", + (value: String) => value.nonEmpty) + + setDefault(similaritiesCol, "similarities") + + /** @group getParam */ + @Since("2.4.0") + def getSimilaritiesCol: String = $(similaritiesCol) + + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnTypes(schema, $(idCol), Seq(IntegerType, LongType)) + SchemaUtils.checkColumnTypes(schema, $(neighborsCol), + Seq(ArrayType(IntegerType, containsNull = false), + ArrayType(LongType, containsNull = false))) + SchemaUtils.checkColumnTypes(schema, $(similaritiesCol), + Seq(ArrayType(FloatType, containsNull = false), + ArrayType(DoubleType, containsNull = false))) + SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) + } +} + +/** + * :: Experimental :: + * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by + * Lin and Cohen. From the abstract: + * PIC finds a very low-dimensional embedding of a dataset using truncated power + * iteration on a normalized pair-wise similarity matrix of the data. + * + * PIC takes an affinity matrix between items (or vertices) as input. An affinity matrix + * is a symmetric matrix whose entries are non-negative similarities between items. + * PIC takes this matrix (or graph) as an adjacency matrix. Specifically, each input row includes: + * - `idCol`: vertex ID + * - `neighborsCol`: neighbors of vertex in `idCol` + * - `similaritiesCol`: non-negative weights (similarities) of edges between the vertex + * in `idCol` and each neighbor in `neighborsCol` + * PIC returns a cluster assignment for each input vertex. It appends a new column `predictionCol` + * containing the cluster assignment in `[0,k)` for each row (vertex). + * + * Notes: + * - [[PowerIterationClustering]] is a transformer with an expensive [[transform]] operation. + * Transform runs the iterative PIC algorithm to cluster the whole input dataset. + * - Input validation: This validates that similarities are non-negative but does NOT validate + * that the input matrix is symmetric. + * + * @see + * Spectral clustering (Wikipedia) + */ +@Since("2.4.0") +@Experimental +class PowerIterationClustering private[clustering] ( + @Since("2.4.0") override val uid: String) + extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable { + + setDefault( + k -> 2, + maxIter -> 20, + initMode -> "random") + + @Since("2.4.0") + def this() = this(Identifiable.randomUID("PowerIterationClustering")) + + /** @group setParam */ + @Since("2.4.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("2.4.0") + def setK(value: Int): this.type = set(k, value) + + /** @group expertSetParam */ + @Since("2.4.0") + def setInitMode(value: String): this.type = set(initMode, value) + + /** @group setParam */ + @Since("2.4.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + @Since("2.4.0") + def setIdCol(value: String): this.type = set(idCol, value) + + /** @group setParam */ + @Since("2.4.0") + def setNeighborsCol(value: String): this.type = set(neighborsCol, value) + + /** @group setParam */ + @Since("2.4.0") + def setSimilaritiesCol(value: String): this.type = set(similaritiesCol, value) + + @Since("2.4.0") + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) + + val sparkSession = dataset.sparkSession + val idColValue = $(idCol) + val rdd: RDD[(Long, Long, Double)] = + dataset.select( + col($(idCol)).cast(LongType), + col($(neighborsCol)).cast(ArrayType(LongType, containsNull = false)), + col($(similaritiesCol)).cast(ArrayType(DoubleType, containsNull = false)) + ).rdd.flatMap { + case Row(id: Long, nbrs: Seq[_], sims: Seq[_]) => + require(nbrs.size == sims.size, s"The length of the neighbor ID list must be " + + s"equal to the the length of the neighbor similarity list. Row for ID " + + s"$idColValue=$id has neighbor ID list of length ${nbrs.length} but similarity list " + + s"of length ${sims.length}.") + nbrs.asInstanceOf[Seq[Long]].zip(sims.asInstanceOf[Seq[Double]]).map { + case (nbr, similarity) => (id, nbr, similarity) + } + } + val algorithm = new MLlibPowerIterationClustering() + .setK($(k)) + .setInitializationMode($(initMode)) + .setMaxIterations($(maxIter)) + val model = algorithm.run(rdd) + + val predictionsRDD: RDD[Row] = model.assignments.map { assignment => + Row(assignment.id, assignment.cluster) + } + + val predictionsSchema = StructType(Seq( + StructField($(idCol), LongType, nullable = false), + StructField($(predictionCol), IntegerType, nullable = false))) + val predictions = { + val uncastPredictions = sparkSession.createDataFrame(predictionsRDD, predictionsSchema) + dataset.schema($(idCol)).dataType match { + case _: LongType => + uncastPredictions + case otherType => + uncastPredictions.select(col($(idCol)).cast(otherType).alias($(idCol))) + } + } + + dataset.join(predictions, $(idCol)) + } + + @Since("2.4.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + @Since("2.4.0") + override def copy(extra: ParamMap): PowerIterationClustering = defaultCopy(extra) +} + +@Since("2.4.0") +object PowerIterationClustering extends DefaultParamsReadable[PowerIterationClustering] { + + @Since("2.4.0") + override def load(path: String): PowerIterationClustering = super.load(path) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala new file mode 100644 index 0000000000000..65328df17baff --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import scala.collection.mutable + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types._ + + +class PowerIterationClusteringSuite extends SparkFunSuite + with MLlibTestSparkContext with DefaultReadWriteTest { + + @transient var data: Dataset[_] = _ + final val r1 = 1.0 + final val n1 = 10 + final val r2 = 4.0 + final val n2 = 40 + + override def beforeAll(): Unit = { + super.beforeAll() + + data = PowerIterationClusteringSuite.generatePICData(spark, r1, r2, n1, n2) + } + + test("default parameters") { + val pic = new PowerIterationClustering() + + assert(pic.getK === 2) + assert(pic.getMaxIter === 20) + assert(pic.getInitMode === "random") + assert(pic.getPredictionCol === "prediction") + assert(pic.getIdCol === "id") + assert(pic.getNeighborsCol === "neighbors") + assert(pic.getSimilaritiesCol === "similarities") + } + + test("parameter validation") { + intercept[IllegalArgumentException] { + new PowerIterationClustering().setK(1) + } + intercept[IllegalArgumentException] { + new PowerIterationClustering().setInitMode("no_such_a_mode") + } + intercept[IllegalArgumentException] { + new PowerIterationClustering().setIdCol("") + } + intercept[IllegalArgumentException] { + new PowerIterationClustering().setNeighborsCol("") + } + intercept[IllegalArgumentException] { + new PowerIterationClustering().setSimilaritiesCol("") + } + } + + test("power iteration clustering") { + val n = n1 + n2 + + val model = new PowerIterationClustering() + .setK(2) + .setMaxIter(40) + val result = model.transform(data) + + val predictions = Array.fill(2)(mutable.Set.empty[Long]) + result.select("id", "prediction").collect().foreach { + case Row(id: Long, cluster: Integer) => predictions(cluster) += id + } + assert(predictions.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) + + val result2 = new PowerIterationClustering() + .setK(2) + .setMaxIter(10) + .setInitMode("degree") + .transform(data) + val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) + result2.select("id", "prediction").collect().foreach { + case Row(id: Long, cluster: Integer) => predictions2(cluster) += id + } + assert(predictions2.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) + } + + test("supported input types") { + val model = new PowerIterationClustering() + .setK(2) + .setMaxIter(1) + + def runTest(idType: DataType, neighborType: DataType, similarityType: DataType): Unit = { + val typedData = data.select( + col("id").cast(idType).alias("id"), + col("neighbors").cast(ArrayType(neighborType, containsNull = false)).alias("neighbors"), + col("similarities").cast(ArrayType(similarityType, containsNull = false)) + .alias("similarities") + ) + model.transform(typedData).collect() + } + + for (idType <- Seq(IntegerType, LongType)) { + runTest(idType, LongType, DoubleType) + } + for (neighborType <- Seq(IntegerType, LongType)) { + runTest(LongType, neighborType, DoubleType) + } + for (similarityType <- Seq(FloatType, DoubleType)) { + runTest(LongType, LongType, similarityType) + } + } + + test("invalid input: wrong types") { + val model = new PowerIterationClustering() + .setK(2) + .setMaxIter(1) + intercept[IllegalArgumentException] { + val typedData = data.select( + col("id").cast(DoubleType).alias("id"), + col("neighbors"), + col("similarities") + ) + model.transform(typedData) + } + intercept[IllegalArgumentException] { + val typedData = data.select( + col("id"), + col("neighbors").cast(ArrayType(DoubleType, containsNull = false)).alias("neighbors"), + col("similarities") + ) + model.transform(typedData) + } + intercept[IllegalArgumentException] { + val typedData = data.select( + col("id"), + col("neighbors"), + col("neighbors").alias("similarities") + ) + model.transform(typedData) + } + } + + test("invalid input: negative similarity") { + val model = new PowerIterationClustering() + .setMaxIter(1) + val badData = spark.createDataFrame(Seq( + (0, Array(1), Array(-1.0)), + (1, Array(0), Array(-1.0)) + )).toDF("id", "neighbors", "similarities") + val msg = intercept[SparkException] { + model.transform(badData) + }.getCause.getMessage + assert(msg.contains("Similarity must be nonnegative")) + } + + test("invalid input: mismatched lengths for neighbor and similarity arrays") { + val model = new PowerIterationClustering() + .setMaxIter(1) + val badData = spark.createDataFrame(Seq( + (0, Array(1), Array(0.5)), + (1, Array(0, 2), Array(0.5)), + (2, Array(1), Array(0.5)) + )).toDF("id", "neighbors", "similarities") + val msg = intercept[SparkException] { + model.transform(badData) + }.getCause.getMessage + assert(msg.contains("The length of the neighbor ID list must be equal to the the length of " + + "the neighbor similarity list.")) + assert(msg.contains(s"Row for ID ${model.getIdCol}=1")) + } + + test("read/write") { + val t = new PowerIterationClustering() + .setK(4) + .setMaxIter(100) + .setInitMode("degree") + .setIdCol("test_id") + .setNeighborsCol("myNeighborsCol") + .setSimilaritiesCol("mySimilaritiesCol") + .setPredictionCol("test_prediction") + testDefaultReadWrite(t) + } +} + +object PowerIterationClusteringSuite { + + /** Generates a circle of points. */ + private def genCircle(r: Double, n: Int): Array[(Double, Double)] = { + Array.tabulate(n) { i => + val theta = 2.0 * math.Pi * i / n + (r * math.cos(theta), r * math.sin(theta)) + } + } + + /** Computes Gaussian similarity. */ + private def sim(x: (Double, Double), y: (Double, Double)): Double = { + val dist2 = (x._1 - y._1) * (x._1 - y._1) + (x._2 - y._2) * (x._2 - y._2) + math.exp(-dist2 / 2.0) + } + + def generatePICData( + spark: SparkSession, + r1: Double, + r2: Double, + n1: Int, + n2: Int): DataFrame = { + // Generate two circles following the example in the PIC paper. + val n = n1 + n2 + val points = genCircle(r1, n1) ++ genCircle(r2, n2) + + val rows = for (i <- 1 until n) yield { + val neighbors = for (j <- 0 until i) yield { + j.toLong + } + val similarities = for (j <- 0 until i) yield { + sim(points(i), points(j)) + } + (i.toLong, neighbors.toArray, similarities.toArray) + } + + spark.createDataFrame(rows).toDF("id", "neighbors", "similarities") + } + +} From 9ea8d3d31b75246bf61118ac7934bc92c18b5f19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Thu, 19 Apr 2018 18:55:59 +0200 Subject: [PATCH 0655/2461] [SPARK-22362][SQL] Add unit test for Window Aggregate Functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Improving the test coverage of window functions focusing on missing test for window aggregate functions. No new UDAF test is added as it has been tested already. ## How was this patch tested? Only new tests were added, automated tests were executed. Author: “attilapiros” Author: Attila Zsolt Piros <2017933+attilapiros@users.noreply.github.com> Closes #20046 from attilapiros/SPARK-22362. --- .../resources/sql-tests/inputs/window.sql | 10 +- .../sql-tests/results/window.sql.out | 30 +- .../sql/DataFrameWindowFunctionsSuite.scala | 266 ++++++++++++++++++ 3 files changed, 294 insertions(+), 12 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql index c4bea34ec4cf3..cda4db4b449fe 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/window.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql @@ -76,7 +76,15 @@ ntile(2) OVER w AS ntile, row_number() OVER w AS row_number, var_pop(val) OVER w AS var_pop, var_samp(val) OVER w AS var_samp, -approx_count_distinct(val) OVER w AS approx_count_distinct +approx_count_distinct(val) OVER w AS approx_count_distinct, +covar_pop(val, val_long) OVER w AS covar_pop, +corr(val, val_long) OVER w AS corr, +stddev_samp(val) OVER w AS stddev_samp, +stddev_pop(val) OVER w AS stddev_pop, +collect_list(val) OVER w AS collect_list, +collect_set(val) OVER w AS collect_set, +skewness(val_double) OVER w AS skewness, +kurtosis(val_double) OVER w AS kurtosis FROM testData WINDOW w AS (PARTITION BY cate ORDER BY val) ORDER BY cate, val; diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out index 133458ae9303b..4afbcd62853dc 100644 --- a/sql/core/src/test/resources/sql-tests/results/window.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -273,22 +273,30 @@ ntile(2) OVER w AS ntile, row_number() OVER w AS row_number, var_pop(val) OVER w AS var_pop, var_samp(val) OVER w AS var_samp, -approx_count_distinct(val) OVER w AS approx_count_distinct +approx_count_distinct(val) OVER w AS approx_count_distinct, +covar_pop(val, val_long) OVER w AS covar_pop, +corr(val, val_long) OVER w AS corr, +stddev_samp(val) OVER w AS stddev_samp, +stddev_pop(val) OVER w AS stddev_pop, +collect_list(val) OVER w AS collect_list, +collect_set(val) OVER w AS collect_set, +skewness(val_double) OVER w AS skewness, +kurtosis(val_double) OVER w AS kurtosis FROM testData WINDOW w AS (PARTITION BY cate ORDER BY val) ORDER BY cate, val -- !query 17 schema -struct +struct,collect_set:array,skewness:double,kurtosis:double> -- !query 17 output -NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0 -3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1 -NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0 -1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 1 2 0.0 0.0 1 -1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 2 3 0.0 0.0 1 -2 a 2 1 1 3 4 1.3333333333333333 0.5773502691896258 NULL 1 NULL 2 2 2 4 3 1.0 1.0 2 4 0.22222222222222224 0.33333333333333337 2 -1 b 1 1 1 1 1 1.0 NaN 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NaN 1 -2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2 -3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3 +NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0 NULL NULL NULL NULL [] [] NULL NULL +3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1 0.0 NaN NaN 0.0 [3] [3] NaN NaN +NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0 NULL NULL NULL NULL [] [] NaN NaN +1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 1 2 0.0 0.0 1 0.0 NULL 0.0 0.0 [1,1] [1] 0.7071067811865476 -1.5 +1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 2 3 0.0 0.0 1 0.0 NULL 0.0 0.0 [1,1] [1] 0.7071067811865476 -1.5 +2 a 2 1 1 3 4 1.3333333333333333 0.5773502691896258 NULL 1 NULL 2 2 2 4 3 1.0 1.0 2 4 0.22222222222222224 0.33333333333333337 2 4.772185885555555E8 1.0 0.5773502691896258 0.4714045207910317 [1,1,2] [1,2] 1.1539890888012805 -0.6672217220327235 +1 b 1 1 1 1 1 1.0 NaN 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NaN 1 NULL NULL NaN 0.0 [1] [1] NaN NaN +2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2 0.0 NaN 0.7071067811865476 0.5 [1,2] [1,2] 0.0 -2.0000000000000013 +3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3 5.3687091175E8 1.0 1.0 0.816496580927726 [1,2,3] [1,2,3] 0.7057890433107311 -1.4999999999999984 -- !query 18 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 281147835abde..3ea398aad7375 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import scala.collection.mutable + import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} import org.apache.spark.sql.functions._ @@ -86,6 +88,236 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { assert(e.message.contains("requires window to be ordered")) } + test("corr, covar_pop, stddev_pop functions in specific window") { + val df = Seq( + ("a", "p1", 10.0, 20.0), + ("b", "p1", 20.0, 10.0), + ("c", "p2", 20.0, 20.0), + ("d", "p2", 20.0, 20.0), + ("e", "p3", 0.0, 0.0), + ("f", "p3", 6.0, 12.0), + ("g", "p3", 6.0, 12.0), + ("h", "p3", 8.0, 16.0), + ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2") + checkAnswer( + df.select( + $"key", + corr("value1", "value2").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + covar_pop("value1", "value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_pop("value1") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_pop("value1") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_pop("value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_pop("value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + + // As stddev_pop(expr) = sqrt(var_pop(expr)) + // the "stddev_pop" column can be calculated from the "var_pop" column. + // + // As corr(expr1, expr2) = covar_pop(expr1, expr2) / (stddev_pop(expr1) * stddev_pop(expr2)) + // the "corr" column can be calculated from the "covar_pop" and the two "stddev_pop" columns. + Seq( + Row("a", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0), + Row("b", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0), + Row("c", null, 0.0, 0.0, 0.0, 0.0, 0.0), + Row("d", null, 0.0, 0.0, 0.0, 0.0, 0.0), + Row("e", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("f", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("g", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("h", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("i", Double.NaN, 0.0, 0.0, 0.0, 0.0, 0.0))) + } + + test("covar_samp, var_samp (variance), stddev_samp (stddev) functions in specific window") { + val df = Seq( + ("a", "p1", 10.0, 20.0), + ("b", "p1", 20.0, 10.0), + ("c", "p2", 20.0, 20.0), + ("d", "p2", 20.0, 20.0), + ("e", "p3", 0.0, 0.0), + ("f", "p3", 6.0, 12.0), + ("g", "p3", 6.0, 12.0), + ("h", "p3", 8.0, 16.0), + ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2") + checkAnswer( + df.select( + $"key", + covar_samp("value1", "value2").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_samp("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + variance("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_samp("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)) + ), + Seq( + Row("a", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), + Row("b", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), + Row("c", 0.0, 0.0, 0.0, 0.0, 0.0 ), + Row("d", 0.0, 0.0, 0.0, 0.0, 0.0 ), + Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("i", Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN))) + } + + test("collect_list in ascending ordered window") { + val df = Seq( + ("a", "p1", "1"), + ("b", "p1", "2"), + ("c", "p1", "2"), + ("d", "p1", null), + ("e", "p1", "3"), + ("f", "p2", "10"), + ("g", "p2", "11"), + ("h", "p3", "20"), + ("i", "p4", null)).toDF("key", "partition", "value") + checkAnswer( + df.select( + $"key", + sort_array( + collect_list("value").over(Window.partitionBy($"partition").orderBy($"value") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))), + Seq( + Row("a", Array("1", "2", "2", "3")), + Row("b", Array("1", "2", "2", "3")), + Row("c", Array("1", "2", "2", "3")), + Row("d", Array("1", "2", "2", "3")), + Row("e", Array("1", "2", "2", "3")), + Row("f", Array("10", "11")), + Row("g", Array("10", "11")), + Row("h", Array("20")), + Row("i", Array()))) + } + + test("collect_list in descending ordered window") { + val df = Seq( + ("a", "p1", "1"), + ("b", "p1", "2"), + ("c", "p1", "2"), + ("d", "p1", null), + ("e", "p1", "3"), + ("f", "p2", "10"), + ("g", "p2", "11"), + ("h", "p3", "20"), + ("i", "p4", null)).toDF("key", "partition", "value") + checkAnswer( + df.select( + $"key", + sort_array( + collect_list("value").over(Window.partitionBy($"partition").orderBy($"value".desc) + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))), + Seq( + Row("a", Array("1", "2", "2", "3")), + Row("b", Array("1", "2", "2", "3")), + Row("c", Array("1", "2", "2", "3")), + Row("d", Array("1", "2", "2", "3")), + Row("e", Array("1", "2", "2", "3")), + Row("f", Array("10", "11")), + Row("g", Array("10", "11")), + Row("h", Array("20")), + Row("i", Array()))) + } + + test("collect_set in window") { + val df = Seq( + ("a", "p1", "1"), + ("b", "p1", "2"), + ("c", "p1", "2"), + ("d", "p1", "3"), + ("e", "p1", "3"), + ("f", "p2", "10"), + ("g", "p2", "11"), + ("h", "p3", "20")).toDF("key", "partition", "value") + checkAnswer( + df.select( + $"key", + sort_array( + collect_set("value").over(Window.partitionBy($"partition").orderBy($"value") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))), + Seq( + Row("a", Array("1", "2", "3")), + Row("b", Array("1", "2", "3")), + Row("c", Array("1", "2", "3")), + Row("d", Array("1", "2", "3")), + Row("e", Array("1", "2", "3")), + Row("f", Array("10", "11")), + Row("g", Array("10", "11")), + Row("h", Array("20")))) + } + + test("skewness and kurtosis functions in window") { + val df = Seq( + ("a", "p1", 1.0), + ("b", "p1", 1.0), + ("c", "p1", 2.0), + ("d", "p1", 2.0), + ("e", "p1", 3.0), + ("f", "p1", 3.0), + ("g", "p1", 3.0), + ("h", "p2", 1.0), + ("i", "p2", 2.0), + ("j", "p2", 5.0)).toDF("key", "partition", "value") + checkAnswer( + df.select( + $"key", + skewness("value").over(Window.partitionBy("partition").orderBy($"key") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + kurtosis("value").over(Window.partitionBy("partition").orderBy($"key") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + // results are checked by scipy.stats.skew() and scipy.stats.kurtosis() + Seq( + Row("a", -0.27238010581457267, -1.506920415224914), + Row("b", -0.27238010581457267, -1.506920415224914), + Row("c", -0.27238010581457267, -1.506920415224914), + Row("d", -0.27238010581457267, -1.506920415224914), + Row("e", -0.27238010581457267, -1.506920415224914), + Row("f", -0.27238010581457267, -1.506920415224914), + Row("g", -0.27238010581457267, -1.506920415224914), + Row("h", 0.5280049792181881, -1.5000000000000013), + Row("i", 0.5280049792181881, -1.5000000000000013), + Row("j", 0.5280049792181881, -1.5000000000000013))) + } + + test("aggregation function on invalid column") { + val df = Seq((1, "1")).toDF("key", "value") + val e = intercept[AnalysisException]( + df.select($"key", count("invalid").over())) + assert(e.message.contains("cannot resolve '`invalid`' given input columns: [key, value]")) + } + + test("numerical aggregate functions on string column") { + val df = Seq((1, "a", "b")).toDF("key", "value1", "value2") + checkAnswer( + df.select($"key", + var_pop("value1").over(), + variance("value1").over(), + stddev_pop("value1").over(), + stddev("value1").over(), + sum("value1").over(), + mean("value1").over(), + avg("value1").over(), + corr("value1", "value2").over(), + covar_pop("value1", "value2").over(), + covar_samp("value1", "value2").over(), + skewness("value1").over(), + kurtosis("value1").over()), + Seq(Row(1, null, null, null, null, null, null, null, null, null, null, null, null))) + } + test("statistical functions") { val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)). toDF("key", "value") @@ -232,6 +464,40 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Row("b", 2, null, null, null, null, null, null))) } + test("last/first on descending ordered window") { + val nullStr: String = null + val df = Seq( + ("a", 0, nullStr), + ("a", 1, "x"), + ("a", 2, "y"), + ("a", 3, "z"), + ("a", 4, "v"), + ("b", 1, "k"), + ("b", 2, "l"), + ("b", 3, nullStr)). + toDF("key", "order", "value") + val window = Window.partitionBy($"key").orderBy($"order".desc) + checkAnswer( + df.select( + $"key", + $"order", + first($"value").over(window), + first($"value", ignoreNulls = false).over(window), + first($"value", ignoreNulls = true).over(window), + last($"value").over(window), + last($"value", ignoreNulls = false).over(window), + last($"value", ignoreNulls = true).over(window)), + Seq( + Row("a", 0, "v", "v", "v", null, null, "x"), + Row("a", 1, "v", "v", "v", "x", "x", "x"), + Row("a", 2, "v", "v", "v", "y", "y", "y"), + Row("a", 3, "v", "v", "v", "z", "z", "z"), + Row("a", 4, "v", "v", "v", "v", "v", "v"), + Row("b", 1, null, null, "l", "k", "k", "k"), + Row("b", 2, null, null, "l", "l", "l", "l"), + Row("b", 3, null, null, null, null, null, null))) + } + test("SPARK-12989 ExtractWindowExpressions treats alias as regular attribute") { val src = Seq((0, 3, 5)).toDF("a", "b", "c") .withColumn("Data", struct("a", "b")) From e55953b0bf2a80b34127ba123417ee54955a6064 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Thu, 19 Apr 2018 15:06:27 -0700 Subject: [PATCH 0656/2461] [SPARK-24022][TEST] Make SparkContextSuite not flaky ## What changes were proposed in this pull request? SparkContextSuite.test("Cancelling stages/jobs with custom reasons.") could stay in an infinite loop because of the problem found and fixed in [SPARK-23775](https://issues.apache.org/jira/browse/SPARK-23775). This PR solves this mentioned flakyness by removing shared variable usages when cancel happens in a loop and using wait and CountDownLatch for synhronization. ## How was this patch tested? Existing unit test. Author: Gabor Somogyi Closes #21105 from gaborgsomogyi/SPARK-24022. --- .../org/apache/spark/SparkContextSuite.scala | 61 ++++++++----------- 1 file changed, 26 insertions(+), 35 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index b30bd74812b36..ce9f2be1c02dd 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark import java.io.File import java.net.{MalformedURLException, URI} import java.nio.charset.StandardCharsets -import java.util.concurrent.{Semaphore, TimeUnit} +import java.util.concurrent.{CountDownLatch, Semaphore, TimeUnit} import scala.concurrent.duration._ @@ -498,45 +498,36 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu test("Cancelling stages/jobs with custom reasons.") { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true") val REASON = "You shall not pass" - val slices = 10 - val listener = new SparkListener { - override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { - if (SparkContextSuite.cancelStage) { - eventually(timeout(10.seconds)) { - assert(SparkContextSuite.isTaskStarted) + for (cancelWhat <- Seq("stage", "job")) { + // This countdown latch used to make sure stage or job canceled in listener + val latch = new CountDownLatch(1) + + val listener = cancelWhat match { + case "stage" => + new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + sc.cancelStage(taskStart.stageId, REASON) + latch.countDown() + } } - sc.cancelStage(taskStart.stageId, REASON) - SparkContextSuite.cancelStage = false - SparkContextSuite.semaphore.release(slices) - } - } - - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - if (SparkContextSuite.cancelJob) { - eventually(timeout(10.seconds)) { - assert(SparkContextSuite.isTaskStarted) + case "job" => + new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + sc.cancelJob(jobStart.jobId, REASON) + latch.countDown() + } } - sc.cancelJob(jobStart.jobId, REASON) - SparkContextSuite.cancelJob = false - SparkContextSuite.semaphore.release(slices) - } } - } - sc.addSparkListener(listener) - - for (cancelWhat <- Seq("stage", "job")) { - SparkContextSuite.semaphore.drainPermits() - SparkContextSuite.isTaskStarted = false - SparkContextSuite.cancelStage = (cancelWhat == "stage") - SparkContextSuite.cancelJob = (cancelWhat == "job") + sc.addSparkListener(listener) val ex = intercept[SparkException] { - sc.range(0, 10000L, numSlices = slices).mapPartitions { x => - SparkContextSuite.isTaskStarted = true - // Block waiting for the listener to cancel the stage or job. - SparkContextSuite.semaphore.acquire() + sc.range(0, 10000L, numSlices = 10).mapPartitions { x => + x.synchronized { + x.wait() + } x }.count() } @@ -550,9 +541,11 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") } + latch.await(20, TimeUnit.SECONDS) eventually(timeout(20.seconds)) { assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) } + sc.removeSparkListener(listener) } } @@ -637,8 +630,6 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } object SparkContextSuite { - @volatile var cancelJob = false - @volatile var cancelStage = false @volatile var isTaskStarted = false @volatile var taskKilled = false @volatile var taskSucceeded = false From b3fde5a41ee625141b9d21ce32ea68c082449430 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 20 Apr 2018 12:06:41 +0800 Subject: [PATCH 0657/2461] [SPARK-23877][SQL] Use filter predicates to prune partitions in metadata-only queries ## What changes were proposed in this pull request? This updates the OptimizeMetadataOnlyQuery rule to use filter expressions when listing partitions, if there are filter nodes in the logical plan. This avoids listing all partitions for large tables on the driver. This also fixes a minor bug where the partitions returned from fsRelation cannot be serialized without hitting a stack level too deep error. This is caused by serializing a stream to executors, where the stream is a recursive structure. If the stream is too long, the serialization stack reaches the maximum level of depth. The fix is to create a LocalRelation using an Array instead of the incoming Seq. ## How was this patch tested? Existing tests for metadata-only queries. Author: Ryan Blue Closes #20988 from rdblue/SPARK-23877-metadata-only-push-filters. --- .../execution/OptimizeMetadataOnlyQuery.scala | 94 +++++++++++++------ .../OptimizeHiveMetadataOnlyQuerySuite.scala | 68 ++++++++++++++ 2 files changed, 132 insertions(+), 30 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index dc4aff9f12580..acbd4becb8549 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -49,9 +49,9 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic } plan.transform { - case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(partAttrs, relation)) => + case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(_, attrs, filters, rel)) => // We only apply this optimization when only partitioned attributes are scanned. - if (a.references.subsetOf(partAttrs)) { + if (a.references.subsetOf(attrs)) { val aggFunctions = aggExprs.flatMap(_.collect { case agg: AggregateExpression => agg }) @@ -67,7 +67,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic }) } if (isAllDistinctAgg) { - a.withNewChildren(Seq(replaceTableScanWithPartitionMetadata(child, relation))) + a.withNewChildren(Seq(replaceTableScanWithPartitionMetadata(child, rel, filters))) } else { a } @@ -98,14 +98,27 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic */ private def replaceTableScanWithPartitionMetadata( child: LogicalPlan, - relation: LogicalPlan): LogicalPlan = { + relation: LogicalPlan, + partFilters: Seq[Expression]): LogicalPlan = { + // this logic comes from PruneFileSourcePartitions. it ensures that the filter names match the + // relation's schema. PartitionedRelation ensures that the filters only reference partition cols + val relFilters = partFilters.map { e => + e transform { + case a: AttributeReference => + a.withName(relation.output.find(_.semanticEquals(a)).get.name) + } + } + child transform { case plan if plan eq relation => relation match { case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, isStreaming) => val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) - val partitionData = fsRelation.location.listFiles(Nil, Nil) - LocalRelation(partAttrs, partitionData.map(_.values), isStreaming) + val partitionData = fsRelation.location.listFiles(relFilters, Nil) + // partition data may be a stream, which can cause serialization to hit stack level too + // deep exceptions because it is a recursive structure in memory. converting to array + // avoids the problem. + LocalRelation(partAttrs, partitionData.map(_.values).toArray, isStreaming) case relation: HiveTableRelation => val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) @@ -113,12 +126,21 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic CaseInsensitiveMap(relation.tableMeta.storage.properties) val timeZoneId = caseInsensitiveProperties.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(SQLConf.get.sessionLocalTimeZone) - val partitionData = catalog.listPartitions(relation.tableMeta.identifier).map { p => + val partitions = if (partFilters.nonEmpty) { + catalog.listPartitionsByFilter(relation.tableMeta.identifier, relFilters) + } else { + catalog.listPartitions(relation.tableMeta.identifier) + } + + val partitionData = partitions.map { p => InternalRow.fromSeq(partAttrs.map { attr => Cast(Literal(p.spec(attr.name)), attr.dataType, Option(timeZoneId)).eval() }) } - LocalRelation(partAttrs, partitionData) + // partition data may be a stream, which can cause serialization to hit stack level too + // deep exceptions because it is a recursive structure in memory. converting to array + // avoids the problem. + LocalRelation(partAttrs, partitionData.toArray) case _ => throw new IllegalStateException(s"unrecognized table scan node: $relation, " + @@ -129,35 +151,47 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic /** * A pattern that finds the partitioned table relation node inside the given plan, and returns a - * pair of the partition attributes and the table relation node. + * pair of the partition attributes, partition filters, and the table relation node. * * It keeps traversing down the given plan tree if there is a [[Project]] or [[Filter]] with * deterministic expressions, and returns result after reaching the partitioned table relation * node. */ - object PartitionedRelation { - - def unapply(plan: LogicalPlan): Option[(AttributeSet, LogicalPlan)] = plan match { - case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) - if fsRelation.partitionSchema.nonEmpty => - val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) - Some((AttributeSet(partAttrs), l)) - - case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty => - val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) - Some((AttributeSet(partAttrs), relation)) - - case p @ Project(projectList, child) if projectList.forall(_.deterministic) => - unapply(child).flatMap { case (partAttrs, relation) => - if (p.references.subsetOf(partAttrs)) Some((p.outputSet, relation)) else None - } + object PartitionedRelation extends PredicateHelper { + + def unapply( + plan: LogicalPlan): Option[(AttributeSet, AttributeSet, Seq[Expression], LogicalPlan)] = { + plan match { + case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) + if fsRelation.partitionSchema.nonEmpty => + val partAttrs = AttributeSet(getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l)) + Some((partAttrs, partAttrs, Nil, l)) + + case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty => + val partAttrs = AttributeSet( + getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation)) + Some((partAttrs, partAttrs, Nil, relation)) + + case p @ Project(projectList, child) if projectList.forall(_.deterministic) => + unapply(child).flatMap { case (partAttrs, attrs, filters, relation) => + if (p.references.subsetOf(attrs)) { + Some((partAttrs, p.outputSet, filters, relation)) + } else { + None + } + } - case f @ Filter(condition, child) if condition.deterministic => - unapply(child).flatMap { case (partAttrs, relation) => - if (f.references.subsetOf(partAttrs)) Some((partAttrs, relation)) else None - } + case f @ Filter(condition, child) if condition.deterministic => + unapply(child).flatMap { case (partAttrs, attrs, filters, relation) => + if (f.references.subsetOf(partAttrs)) { + Some((partAttrs, attrs, splitConjunctivePredicates(condition) ++ filters, relation)) + } else { + None + } + } - case _ => None + case _ => None + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala new file mode 100644 index 0000000000000..95f192f0e40e2 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.metrics.source.HiveCatalogMetrics +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, Project, SubqueryAlias} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleton + with BeforeAndAfter with SQLTestUtils { + + import spark.implicits._ + + before { + sql("CREATE TABLE metadata_only (id bigint, data string) PARTITIONED BY (part int)") + (0 to 10).foreach(p => sql(s"ALTER TABLE metadata_only ADD PARTITION (part=$p)")) + } + + test("SPARK-23877: validate metadata-only query pushes filters to metastore") { + withTable("metadata_only") { + val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount + + // verify the number of matching partitions + assert(sql("SELECT DISTINCT part FROM metadata_only WHERE part < 5").collect().length === 5) + + // verify that the partition predicate was pushed down to the metastore + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount === 5) + } + } + + test("SPARK-23877: filter on projected expression") { + withTable("metadata_only") { + val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount + + // verify the matching partitions + val partitions = spark.internalCreateDataFrame(Distinct(Filter(($"x" < 5).expr, + Project(Seq(($"part" + 1).as("x").expr.asInstanceOf[NamedExpression]), + spark.table("metadata_only").logicalPlan.asInstanceOf[SubqueryAlias].child))) + .queryExecution.toRdd, StructType(Seq(StructField("x", IntegerType)))) + + checkAnswer(partitions, Seq(1, 2, 3, 4).toDF("x")) + + // verify that the partition predicate was not pushed down to the metastore + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount == 11) + } + } +} From e6b466084c26fbb9b9e50dd5cc8b25da7533ac72 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Fri, 20 Apr 2018 14:58:11 +0900 Subject: [PATCH 0658/2461] [SPARK-23736][SQL] Extending the concat function to support array columns ## What changes were proposed in this pull request? The PR adds a logic for easy concatenation of multiple array columns and covers: - Concat expression has been extended to support array columns - A Python wrapper ## How was this patch tested? New tests added into: - CollectionExpressionsSuite - DataFrameFunctionsSuite - typeCoercion/native/concat.sql ## Codegen examples ### Primitive-type elements ``` val df = Seq( (Seq(1 ,2), Seq(3, 4)), (Seq(1, 2, 3), null) ).toDF("a", "b") df.filter('a.isNotNull).select(concat('a, 'b)).debugCodegen() ``` Result: ``` /* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 034 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 035 */ null : (inputadapter_row.getArray(0)); /* 036 */ /* 037 */ if (!(!inputadapter_isNull)) continue; /* 038 */ /* 039 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 040 */ /* 041 */ ArrayData[] project_args = new ArrayData[2]; /* 042 */ /* 043 */ if (!false) { /* 044 */ project_args[0] = inputadapter_value; /* 045 */ } /* 046 */ /* 047 */ boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1); /* 048 */ ArrayData inputadapter_value1 = inputadapter_isNull1 ? /* 049 */ null : (inputadapter_row.getArray(1)); /* 050 */ if (!inputadapter_isNull1) { /* 051 */ project_args[1] = inputadapter_value1; /* 052 */ } /* 053 */ /* 054 */ ArrayData project_value = new Object() { /* 055 */ public ArrayData concat(ArrayData[] args) { /* 056 */ for (int z = 0; z < 2; z++) { /* 057 */ if (args[z] == null) return null; /* 058 */ } /* 059 */ /* 060 */ long project_numElements = 0L; /* 061 */ for (int z = 0; z < 2; z++) { /* 062 */ project_numElements += args[z].numElements(); /* 063 */ } /* 064 */ if (project_numElements > 2147483632) { /* 065 */ throw new RuntimeException("Unsuccessful try to concat arrays with " + project_numElements + /* 066 */ " elements due to exceeding the array size limit 2147483632."); /* 067 */ } /* 068 */ /* 069 */ long project_size = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( /* 070 */ project_numElements, /* 071 */ 4); /* 072 */ if (project_size > 2147483632) { /* 073 */ throw new RuntimeException("Unsuccessful try to concat arrays with " + project_size + /* 074 */ " bytes of data due to exceeding the limit 2147483632 bytes" + /* 075 */ " for UnsafeArrayData."); /* 076 */ } /* 077 */ /* 078 */ byte[] project_array = new byte[(int)project_size]; /* 079 */ UnsafeArrayData project_arrayData = new UnsafeArrayData(); /* 080 */ Platform.putLong(project_array, 16, project_numElements); /* 081 */ project_arrayData.pointTo(project_array, 16, (int)project_size); /* 082 */ int project_counter = 0; /* 083 */ for (int y = 0; y < 2; y++) { /* 084 */ for (int z = 0; z < args[y].numElements(); z++) { /* 085 */ if (args[y].isNullAt(z)) { /* 086 */ project_arrayData.setNullAt(project_counter); /* 087 */ } else { /* 088 */ project_arrayData.setInt( /* 089 */ project_counter, /* 090 */ args[y].getInt(z) /* 091 */ ); /* 092 */ } /* 093 */ project_counter++; /* 094 */ } /* 095 */ } /* 096 */ return project_arrayData; /* 097 */ } /* 098 */ }.concat(project_args); /* 099 */ boolean project_isNull = project_value == null; ``` ### Non-primitive-type elements ``` val df = Seq( (Seq("aa" ,"bb"), Seq("ccc", "ddd")), (Seq("x", "y"), null) ).toDF("a", "b") df.filter('a.isNotNull).select(concat('a, 'b)).debugCodegen() ``` Result: ``` /* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 034 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 035 */ null : (inputadapter_row.getArray(0)); /* 036 */ /* 037 */ if (!(!inputadapter_isNull)) continue; /* 038 */ /* 039 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 040 */ /* 041 */ ArrayData[] project_args = new ArrayData[2]; /* 042 */ /* 043 */ if (!false) { /* 044 */ project_args[0] = inputadapter_value; /* 045 */ } /* 046 */ /* 047 */ boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1); /* 048 */ ArrayData inputadapter_value1 = inputadapter_isNull1 ? /* 049 */ null : (inputadapter_row.getArray(1)); /* 050 */ if (!inputadapter_isNull1) { /* 051 */ project_args[1] = inputadapter_value1; /* 052 */ } /* 053 */ /* 054 */ ArrayData project_value = new Object() { /* 055 */ public ArrayData concat(ArrayData[] args) { /* 056 */ for (int z = 0; z < 2; z++) { /* 057 */ if (args[z] == null) return null; /* 058 */ } /* 059 */ /* 060 */ long project_numElements = 0L; /* 061 */ for (int z = 0; z < 2; z++) { /* 062 */ project_numElements += args[z].numElements(); /* 063 */ } /* 064 */ if (project_numElements > 2147483632) { /* 065 */ throw new RuntimeException("Unsuccessful try to concat arrays with " + project_numElements + /* 066 */ " elements due to exceeding the array size limit 2147483632."); /* 067 */ } /* 068 */ /* 069 */ Object[] project_arrayObjects = new Object[(int)project_numElements]; /* 070 */ int project_counter = 0; /* 071 */ for (int y = 0; y < 2; y++) { /* 072 */ for (int z = 0; z < args[y].numElements(); z++) { /* 073 */ project_arrayObjects[project_counter] = args[y].getUTF8String(z); /* 074 */ project_counter++; /* 075 */ } /* 076 */ } /* 077 */ return new org.apache.spark.sql.catalyst.util.GenericArrayData(project_arrayObjects); /* 078 */ } /* 079 */ }.concat(project_args); /* 080 */ boolean project_isNull = project_value == null; ``` Author: mn-mikke Closes #20858 from mn-mikke/feature/array-api-concat_arrays-to-master. --- .../spark/unsafe/array/ByteArrayMethods.java | 6 +- python/pyspark/sql/functions.py | 34 +-- .../catalyst/expressions/UnsafeArrayData.java | 10 + .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 8 + .../expressions/collectionOperations.scala | 220 +++++++++++++++++- .../expressions/stringExpressions.scala | 81 ------- .../CollectionExpressionsSuite.scala | 41 ++++ .../org/apache/spark/sql/functions.scala | 20 +- .../inputs/typeCoercion/native/concat.sql | 62 +++++ .../typeCoercion/native/concat.sql.out | 78 +++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 74 ++++++ .../sql/execution/command/DDLSuite.scala | 4 +- 13 files changed, 529 insertions(+), 111 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index 4bc9955090fd7..ef0f78d95d1ee 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -33,7 +33,11 @@ public static long nextPowerOf2(long num) { } public static int roundNumberOfBytesToNearestWord(int numBytes) { - int remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` + return (int)roundNumberOfBytesToNearestWord((long)numBytes); + } + + public static long roundNumberOfBytesToNearestWord(long numBytes) { + long remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` if (remainder == 0) { return numBytes; } else { diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1be68f2a4a448..da32ab25cad0c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1425,21 +1425,6 @@ def hash(*cols): del _name, _doc -@since(1.5) -@ignore_unicode_prefix -def concat(*cols): - """ - Concatenates multiple input columns together into a single column. - If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. - - >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) - >>> df.select(concat(df.s, df.d).alias('s')).collect() - [Row(s=u'abcd123')] - """ - sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column))) - - @since(1.5) @ignore_unicode_prefix def concat_ws(sep, *cols): @@ -1845,6 +1830,25 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(1.5) +@ignore_unicode_prefix +def concat(*cols): + """ + Concatenates multiple input columns together into a single column. + The function works with strings, binary and compatible array columns. + + >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) + >>> df.select(concat(df.s, df.d).alias('s')).collect() + [Row(s=u'abcd123')] + + >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c']) + >>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect() + [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column))) + + @since(2.4) def array_position(col, value): """ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 8546c28335536..d5d934bc91cab 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -56,9 +56,19 @@ public final class UnsafeArrayData extends ArrayData { public static int calculateHeaderPortionInBytes(int numFields) { + return (int)calculateHeaderPortionInBytes((long)numFields); + } + + public static long calculateHeaderPortionInBytes(long numFields) { return 8 + ((numFields + 63)/ 64) * 8; } + public static long calculateSizeOfUnderlyingByteArray(long numFields, int elementSize) { + long size = UnsafeArrayData.calculateHeaderPortionInBytes(numFields) + + ByteArrayMethods.roundNumberOfBytesToNearestWord(numFields * elementSize); + return size; + } + private Object baseObject; private long baseOffset; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index a44f2d5272b8e..c41f16c61d7a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -308,7 +308,6 @@ object FunctionRegistry { expression[BitLength]("bit_length"), expression[Length]("char_length"), expression[Length]("character_length"), - expression[Concat]("concat"), expression[ConcatWs]("concat_ws"), expression[Decode]("decode"), expression[Elt]("elt"), @@ -413,6 +412,7 @@ object FunctionRegistry { expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), expression[Reverse]("reverse"), + expression[Concat]("concat"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 281f206e8d59e..cfcbd8db559a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -520,6 +520,14 @@ object TypeCoercion { case None => a } + case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && + !haveSameType(children) => + val types = children.map(_.dataType) + findWiderCommonType(types) match { + case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType))) + case None => c + } + case m @ CreateMap(children) if m.keys.length == m.values.length && (!haveSameType(m.keys) || !haveSameType(m.values)) => val newKeys = if (haveSameType(m.keys)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index dba426e999dda..c16793bda028e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -23,7 +23,9 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.types.{ByteArray, UTF8String} /** * Given an array or map, returns its size. Returns -1 if null. @@ -665,3 +667,219 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti override def prettyName: String = "element_at" } + +/** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + examples = """ + Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + | [1,2,3,4,5,6] + """) +case class Concat(children: Seq[Expression]) extends Expression { + + private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + val allowedTypes = Seq(StringType, BinaryType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess + } else { + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) { + return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + + s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") + } + } + + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + + lazy val javaType: String = CodeGenerator.javaType(dataType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = dataType match { + case BinaryType => + val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) + ByteArray.concat(inputs: _*) + case StringType => + val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) + UTF8String.concat(inputs : _*) + case ArrayType(elementType, _) => + val inputs = children.toStream.map(_.eval(input)) + if (inputs.contains(null)) { + null + } else { + val arrayData = inputs.map(_.asInstanceOf[ArrayData]) + val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) + if (numberOfElements > MAX_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" + + s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + } + val finalData = new Array[AnyRef](numberOfElements.toInt) + var position = 0 + for(ad <- arrayData) { + val arr = ad.toObjectArray(elementType) + Array.copy(arr, 0, finalData, position, arr.length) + position += arr.length + } + new GenericArrayData(finalData) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val evals = children.map(_.genCode(ctx)) + val args = ctx.freshName("args") + + val inputs = evals.zipWithIndex.map { case (eval, index) => + s""" + ${eval.code} + if (!${eval.isNull}) { + $args[$index] = ${eval.value}; + } + """ + } + + val (concatenator, initCode) = dataType match { + case BinaryType => + (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") + case StringType => + ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") + case ArrayType(elementType, _) => + val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForPrimitiveArrays(ctx, elementType) + } else { + genCodeForNonPrimitiveArrays(ctx, elementType) + } + (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") + } + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = inputs, + funcName = "valueConcat", + extraArguments = (s"$javaType[]", args) :: Nil) + ev.copy(s""" + $initCode + $codes + $javaType ${ev.value} = $concatenator.concat($args); + boolean ${ev.isNull} = ${ev.value} == null; + """) + } + + private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { + val numElements = ctx.freshName("numElements") + val code = s""" + |long $numElements = 0L; + |for (int z = 0; z < ${children.length}; z++) { + | $numElements += args[z].numElements(); + |} + |if ($numElements > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + + | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |} + """.stripMargin + + (code, numElements) + } + + private def nullArgumentProtection() : String = { + if (nullable) { + s""" + |for (int z = 0; z < ${children.length}; z++) { + | if (args[z] == null) return null; + |} + """.stripMargin + } else { + "" + } + } + + private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + val arrayName = ctx.freshName("array") + val arraySizeName = ctx.freshName("size") + val counter = ctx.freshName("counter") + val arrayData = ctx.freshName("arrayData") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) + + val unsafeArraySizeInBytes = s""" + |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElemName, + | ${elementType.defaultSize}); + |if ($arraySizeName > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName + + | " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" + + | " for UnsafeArrayData."); + |} + """.stripMargin + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + + s""" + |new Object() { + | public ArrayData concat($javaType[] args) { + | ${nullArgumentProtection()} + | $numElemCode + | $unsafeArraySizeInBytes + | byte[] $arrayName = new byte[(int)$arraySizeName]; + | UnsafeArrayData $arrayData = new UnsafeArrayData(); + | Platform.putLong($arrayName, $baseOffset, $numElemName); + | $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName); + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < args[y].numElements(); z++) { + | if (args[y].isNullAt(z)) { + | $arrayData.setNullAt($counter); + | } else { + | $arrayData.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue(s"args[y]", elementType, "z")} + | ); + | } + | $counter++; + | } + | } + | return $arrayData; + | } + |}""".stripMargin.stripPrefix("\n") + } + + private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayData = ctx.freshName("arrayObjects") + val counter = ctx.freshName("counter") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) + + s""" + |new Object() { + | public ArrayData concat($javaType[] args) { + | ${nullArgumentProtection()} + | $numElemCode + | Object[] $arrayData = new Object[(int)$numElemName]; + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < args[y].numElements(); z++) { + | $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")}; + | $counter++; + | } + | } + | return new $genericArrayClass($arrayData); + | } + |}""".stripMargin.stripPrefix("\n") + } + + override def toString: String = s"concat(${children.mkString(", ")})" + + override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 5a02ca0d6862c..ea005a26a4c8b 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -36,87 +36,6 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} //////////////////////////////////////////////////////////////////////////////////////////////////// -/** - * An expression that concatenates multiple inputs into a single output. - * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. - * If any input is null, concat returns null. - */ -@ExpressionDescription( - usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of str1, str2, ..., strN.", - examples = """ - Examples: - > SELECT _FUNC_('Spark', 'SQL'); - SparkSQL - """) -case class Concat(children: Seq[Expression]) extends Expression { - - private lazy val isBinaryMode: Boolean = dataType == BinaryType - - override def checkInputDataTypes(): TypeCheckResult = { - if (children.isEmpty) { - TypeCheckResult.TypeCheckSuccess - } else { - val childTypes = children.map(_.dataType) - if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { - return TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName should have StringType or BinaryType, but it's " + - childTypes.map(_.simpleString).mkString("[", ", ", "]")) - } - TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") - } - } - - override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) - - override def nullable: Boolean = children.exists(_.nullable) - override def foldable: Boolean = children.forall(_.foldable) - - override def eval(input: InternalRow): Any = { - if (isBinaryMode) { - val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) - ByteArray.concat(inputs: _*) - } else { - val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) - UTF8String.concat(inputs : _*) - } - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val evals = children.map(_.genCode(ctx)) - val args = ctx.freshName("args") - - val inputs = evals.zipWithIndex.map { case (eval, index) => - s""" - ${eval.code} - if (!${eval.isNull}) { - $args[$index] = ${eval.value}; - } - """ - } - - val (concatenator, initCode) = if (isBinaryMode) { - (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") - } else { - ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") - } - val codes = ctx.splitExpressionsWithCurrentInputs( - expressions = inputs, - funcName = "valueConcat", - extraArguments = (s"${CodeGenerator.javaType(dataType)}[]", args) :: Nil) - ev.copy(s""" - $initCode - $codes - ${CodeGenerator.javaType(dataType)} ${ev.value} = $concatenator.concat($args); - boolean ${ev.isNull} = ${ev.value} == null; - """) - } - - override def toString: String = s"concat(${children.mkString(", ")})" - - override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" -} - - /** * An expression that concatenates multiple input strings or array of strings into a single string, * using a given separator (the first child). diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 7d8fe211858b2..43c5dda2e4a48 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -239,4 +239,45 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ElementAt(m2, Literal("a")), null) } + + test("Concat") { + // Primitive-type elements + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val ai1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType)) + val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType)) + val ai4 = Literal.create(null, ArrayType(IntegerType)) + + checkEvaluation(Concat(Seq(ai0)), Seq(1, 2, 3)) + checkEvaluation(Concat(Seq(ai0, ai1)), Seq(1, 2, 3)) + checkEvaluation(Concat(Seq(ai1, ai0)), Seq(1, 2, 3)) + checkEvaluation(Concat(Seq(ai0, ai0)), Seq(1, 2, 3, 1, 2, 3)) + checkEvaluation(Concat(Seq(ai0, ai2)), Seq(1, 2, 3, 4, null, 5)) + checkEvaluation(Concat(Seq(ai0, ai3, ai2)), Seq(1, 2, 3, null, null, 4, null, 5)) + checkEvaluation(Concat(Seq(ai4)), null) + checkEvaluation(Concat(Seq(ai0, ai4)), null) + checkEvaluation(Concat(Seq(ai4, ai0)), null) + + // Non-primitive-type elements + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType)) + val as1 = Literal.create(Seq.empty[String], ArrayType(StringType)) + val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType)) + val as3 = Literal.create(Seq(null, null), ArrayType(StringType)) + val as4 = Literal.create(null, ArrayType(StringType)) + + val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")), ArrayType(ArrayType(StringType))) + val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")), ArrayType(ArrayType(StringType))) + + checkEvaluation(Concat(Seq(as0)), Seq("a", "b", "c")) + checkEvaluation(Concat(Seq(as0, as1)), Seq("a", "b", "c")) + checkEvaluation(Concat(Seq(as1, as0)), Seq("a", "b", "c")) + checkEvaluation(Concat(Seq(as0, as0)), Seq("a", "b", "c", "a", "b", "c")) + checkEvaluation(Concat(Seq(as0, as2)), Seq("a", "b", "c", "d", null, "e")) + checkEvaluation(Concat(Seq(as0, as3, as2)), Seq("a", "b", "c", null, null, "d", null, "e")) + checkEvaluation(Concat(Seq(as4)), null) + checkEvaluation(Concat(Seq(as0, as4)), null) + checkEvaluation(Concat(Seq(as4, as0)), null) + + checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f"))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 9c8580378303e..bea8c0e445002 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2228,16 +2228,6 @@ object functions { */ def base64(e: Column): Column = withExpr { Base64(e.expr) } - /** - * Concatenates multiple input columns together into a single column. - * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. - * - * @group string_funcs - * @since 1.5.0 - */ - @scala.annotation.varargs - def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) } - /** * Concatenates multiple input string columns together into a single string column, * using the given separator. @@ -3038,6 +3028,16 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + * + * @group collection_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) } + /** * Locates the position of the first occurrence of the value in the given array as long. * Returns null if either of the arguments are null. diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql index 0beebec5702fd..db00a18f2e7e9 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql @@ -91,3 +91,65 @@ FROM ( encode(string(id + 3), 'utf-8') col4 FROM range(10) ); + +CREATE TEMPORARY VIEW various_arrays AS SELECT * FROM VALUES ( + array(true, false), array(true), + array(2Y, 1Y), array(3Y, 4Y), + array(2S, 1S), array(3S, 4S), + array(2, 1), array(3, 4), + array(2L, 1L), array(3L, 4L), + array(9223372036854775809, 9223372036854775808), array(9223372036854775808, 9223372036854775809), + array(2.0D, 1.0D), array(3.0D, 4.0D), + array(float(2.0), float(1.0)), array(float(3.0), float(4.0)), + array(date '2016-03-14', date '2016-03-13'), array(date '2016-03-12', date '2016-03-11'), + array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + array(timestamp '2016-11-11 20:54:00.000'), + array('a', 'b'), array('c', 'd'), + array(array('a', 'b'), array('c', 'd')), array(array('e'), array('f')), + array(struct('a', 1), struct('b', 2)), array(struct('c', 3), struct('d', 4)), + array(map('a', 1), map('b', 2)), array(map('c', 3), map('d', 4)) +) AS various_arrays( + boolean_array1, boolean_array2, + tinyint_array1, tinyint_array2, + smallint_array1, smallint_array2, + int_array1, int_array2, + bigint_array1, bigint_array2, + decimal_array1, decimal_array2, + double_array1, double_array2, + float_array1, float_array2, + date_array1, data_array2, + timestamp_array1, timestamp_array2, + string_array1, string_array2, + array_array1, array_array2, + struct_array1, struct_array2, + map_array1, map_array2 +); + +-- Concatenate arrays of the same type +SELECT + (boolean_array1 || boolean_array2) boolean_array, + (tinyint_array1 || tinyint_array2) tinyint_array, + (smallint_array1 || smallint_array2) smallint_array, + (int_array1 || int_array2) int_array, + (bigint_array1 || bigint_array2) bigint_array, + (decimal_array1 || decimal_array2) decimal_array, + (double_array1 || double_array2) double_array, + (float_array1 || float_array2) float_array, + (date_array1 || data_array2) data_array, + (timestamp_array1 || timestamp_array2) timestamp_array, + (string_array1 || string_array2) string_array, + (array_array1 || array_array2) array_array, + (struct_array1 || struct_array2) struct_array, + (map_array1 || map_array2) map_array +FROM various_arrays; + +-- Concatenate arrays of different types +SELECT + (tinyint_array1 || smallint_array2) ts_array, + (smallint_array1 || int_array2) si_array, + (int_array1 || bigint_array2) ib_array, + (double_array1 || float_array2) df_array, + (string_array1 || data_array2) std_array, + (timestamp_array1 || string_array2) tst_array, + (string_array1 || int_array2) sti_array +FROM various_arrays; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out index 09729fdc2ec32..62befc5ca0f15 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out @@ -237,3 +237,81 @@ struct 78910 891011 9101112 + + +-- !query 11 +CREATE TEMPORARY VIEW various_arrays AS SELECT * FROM VALUES ( + array(true, false), array(true), + array(2Y, 1Y), array(3Y, 4Y), + array(2S, 1S), array(3S, 4S), + array(2, 1), array(3, 4), + array(2L, 1L), array(3L, 4L), + array(9223372036854775809, 9223372036854775808), array(9223372036854775808, 9223372036854775809), + array(2.0D, 1.0D), array(3.0D, 4.0D), + array(float(2.0), float(1.0)), array(float(3.0), float(4.0)), + array(date '2016-03-14', date '2016-03-13'), array(date '2016-03-12', date '2016-03-11'), + array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + array(timestamp '2016-11-11 20:54:00.000'), + array('a', 'b'), array('c', 'd'), + array(array('a', 'b'), array('c', 'd')), array(array('e'), array('f')), + array(struct('a', 1), struct('b', 2)), array(struct('c', 3), struct('d', 4)), + array(map('a', 1), map('b', 2)), array(map('c', 3), map('d', 4)) +) AS various_arrays( + boolean_array1, boolean_array2, + tinyint_array1, tinyint_array2, + smallint_array1, smallint_array2, + int_array1, int_array2, + bigint_array1, bigint_array2, + decimal_array1, decimal_array2, + double_array1, double_array2, + float_array1, float_array2, + date_array1, data_array2, + timestamp_array1, timestamp_array2, + string_array1, string_array2, + array_array1, array_array2, + struct_array1, struct_array2, + map_array1, map_array2 +) +-- !query 11 schema +struct<> +-- !query 11 output + + + +-- !query 12 +SELECT + (boolean_array1 || boolean_array2) boolean_array, + (tinyint_array1 || tinyint_array2) tinyint_array, + (smallint_array1 || smallint_array2) smallint_array, + (int_array1 || int_array2) int_array, + (bigint_array1 || bigint_array2) bigint_array, + (decimal_array1 || decimal_array2) decimal_array, + (double_array1 || double_array2) double_array, + (float_array1 || float_array2) float_array, + (date_array1 || data_array2) data_array, + (timestamp_array1 || timestamp_array2) timestamp_array, + (string_array1 || string_array2) string_array, + (array_array1 || array_array2) array_array, + (struct_array1 || struct_array2) struct_array, + (map_array1 || map_array2) map_array +FROM various_arrays +-- !query 12 schema +struct,tinyint_array:array,smallint_array:array,int_array:array,bigint_array:array,decimal_array:array,double_array:array,float_array:array,data_array:array,timestamp_array:array,string_array:array,array_array:array>,struct_array:array>,map_array:array>> +-- !query 12 output +[true,false,true] [2,1,3,4] [2,1,3,4] [2,1,3,4] [2,1,3,4] [9223372036854775809,9223372036854775808,9223372036854775808,9223372036854775809] [2.0,1.0,3.0,4.0] [2.0,1.0,3.0,4.0] [2016-03-14,2016-03-13,2016-03-12,2016-03-11] [2016-11-15 20:54:00.0,2016-11-12 20:54:00.0,2016-11-11 20:54:00.0] ["a","b","c","d"] [["a","b"],["c","d"],["e"],["f"]] [{"col1":"a","col2":1},{"col1":"b","col2":2},{"col1":"c","col2":3},{"col1":"d","col2":4}] [{"a":1},{"b":2},{"c":3},{"d":4}] + + +-- !query 13 +SELECT + (tinyint_array1 || smallint_array2) ts_array, + (smallint_array1 || int_array2) si_array, + (int_array1 || bigint_array2) ib_array, + (double_array1 || float_array2) df_array, + (string_array1 || data_array2) std_array, + (timestamp_array1 || string_array2) tst_array, + (string_array1 || int_array2) sti_array +FROM various_arrays +-- !query 13 schema +struct,si_array:array,ib_array:array,df_array:array,std_array:array,tst_array:array,sti_array:array> +-- !query 13 output +[2,1,3,4] [2,1,3,4] [2,1,3,4] [2.0,1.0,3.0,4.0] ["a","b","2016-03-12","2016-03-11"] ["2016-11-15 20:54:00","2016-11-12 20:54:00","c","d"] ["a","b","3","4"] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 7c976c1b7f915..25e5cd60dd236 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -617,6 +617,80 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("concat function - arrays") { + val nseqi : Seq[Int] = null + val nseqs : Seq[String] = null + val df = Seq( + + (Seq(1), Seq(2, 3), Seq(5L, 6L), nseqi, Seq("a", "b", "c"), Seq("d", "e"), Seq("f"), nseqs), + (Seq(1, 0), Seq.empty[Int], Seq(2L), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs) + ).toDF("i1", "i2", "i3", "in", "s1", "s2", "s3", "sn") + + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codeGen on + + // Simple test cases + checkAnswer( + df.selectExpr("array(1, 2, 3L)"), + Seq(Row(Seq(1L, 2L, 3L)), Row(Seq(1L, 2L, 3L))) + ) + + checkAnswer ( + df.select(concat($"i1", $"s1")), + Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a"))) + ) + checkAnswer( + df.select(concat($"i1", $"i2", $"i3")), + Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) + ) + checkAnswer( + df.filter(dummyFilter($"i1")).select(concat($"i1", $"i2", $"i3")), + Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) + ) + checkAnswer( + df.selectExpr("concat(array(1, null), i2, i3)"), + Seq(Row(Seq(1, null, 2, 3, 5, 6)), Row(Seq(1, null, 2))) + ) + checkAnswer( + df.select(concat($"s1", $"s2", $"s3")), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + checkAnswer( + df.selectExpr("concat(s1, s2, s3)"), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + checkAnswer( + df.filter(dummyFilter($"s1"))select(concat($"s1", $"s2", $"s3")), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + + // Null test cases + checkAnswer( + df.select(concat($"i1", $"in")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"in", $"i1")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"s1", $"sn")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"sn", $"s1")), + Seq(Row(null), Row(null)) + ) + + // Type error test cases + intercept[AnalysisException] { + df.selectExpr("concat(i1, i2, null)") + } + + intercept[AnalysisException] { + df.selectExpr("concat(i1, array(i1, i2))") + } + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index cbd7f9d6f67be..3998ceca38b30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1742,8 +1742,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql("DESCRIBE FUNCTION 'concat'"), Row("Class: org.apache.spark.sql.catalyst.expressions.Concat") :: Row("Function: concat") :: - Row("Usage: concat(str1, str2, ..., strN) - " + - "Returns the concatenation of str1, str2, ..., strN.") :: Nil + Row("Usage: concat(col1, col2, ..., colN) - " + + "Returns the concatenation of col1, col2, ..., colN.") :: Nil ) // extended mode checkAnswer( From 074a7f90536493b607e8e74bcebf3a27ea49a49d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 20 Apr 2018 14:43:47 +0200 Subject: [PATCH 0659/2461] [SPARK-23588][SQL][FOLLOW-UP] Resolve a map builder method per execution in CatalystToExternalMap ## What changes were proposed in this pull request? This pr is a follow-up pr of #20979 and fixes code to resolve a map builder method per execution instead of per row in `CatalystToExternalMap`. ## How was this patch tested? Existing tests. Author: Takeshi Yamamuro Closes #21112 from maropu/SPARK-23588-FOLLOWUP. --- .../sql/catalyst/expressions/objects/objects.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index bc17d1229420a..32c1f34ef97a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1040,11 +1040,13 @@ case class CatalystToExternalMap private( private lazy val valueConverter = CatalystTypeConverters.createToScalaConverter(inputMapType.valueType) - private def newMapBuilder(): Builder[AnyRef, AnyRef] = { + private lazy val (newMapBuilderMethod, moduleField) = { val clazz = Utils.classForName(collClass.getCanonicalName + "$") - val module = clazz.getField("MODULE$").get(null) - val method = clazz.getMethod("newBuilder") - method.invoke(module).asInstanceOf[Builder[AnyRef, AnyRef]] + (clazz.getMethod("newBuilder"), clazz.getField("MODULE$").get(null)) + } + + private def newMapBuilder(): Builder[AnyRef, AnyRef] = { + newMapBuilderMethod.invoke(moduleField).asInstanceOf[Builder[AnyRef, AnyRef]] } override def eval(input: InternalRow): Any = { From 0dd97f6ea4affde1531dec1bec004b7ab18c6965 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 20 Apr 2018 15:02:27 +0200 Subject: [PATCH 0660/2461] [SPARK-23595][SQL] ValidateExternalType should support interpreted execution ## What changes were proposed in this pull request? This pr supported interpreted mode for `ValidateExternalType`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Takeshi Yamamuro Closes #20757 from maropu/SPARK-23595. --- .../spark/sql/catalyst/ScalaReflection.scala | 13 +++++++ .../sql/catalyst/encoders/RowEncoder.scala | 2 +- .../expressions/objects/objects.scala | 34 ++++++++++++++++--- .../expressions/ObjectExpressionsSuite.scala | 33 ++++++++++++++++-- 4 files changed, 74 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 818cc2fb1e8a8..f9acc208b715e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -846,6 +846,19 @@ object ScalaReflection extends ScalaReflection { } } + def javaBoxedType(dt: DataType): Class[_] = dt match { + case _: DecimalType => classOf[Decimal] + case BinaryType => classOf[Array[Byte]] + case StringType => classOf[UTF8String] + case CalendarIntervalType => classOf[CalendarInterval] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayType] + case _: MapType => classOf[MapType] + case udt: UserDefinedType[_] => javaBoxedType(udt.sqlType) + case ObjectType(cls) => cls + case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object]) + } + def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = { if (arguments != Nil) { arguments.map(e => dataTypeJavaClass(e.dataType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 789750fd408f2..3340789398f9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 32c1f34ef97a5..f1ffcaec8a484 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import org.apache.spark.util.Utils /** @@ -1672,13 +1673,36 @@ case class ValidateExternalType(child: Expression, expected: DataType) override def nullable: Boolean = child.nullable - override def dataType: DataType = RowEncoder.externalDataTypeForInput(expected) - - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override val dataType: DataType = RowEncoder.externalDataTypeForInput(expected) private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" + private lazy val checkType: (Any) => Boolean = expected match { + case _: DecimalType => + (value: Any) => { + value.isInstanceOf[java.math.BigDecimal] || value.isInstanceOf[scala.math.BigDecimal] || + value.isInstanceOf[Decimal] + } + case _: ArrayType => + (value: Any) => { + value.getClass.isArray || value.isInstanceOf[Seq[_]] + } + case _ => + val dataTypeClazz = ScalaReflection.javaBoxedType(dataType) + (value: Any) => { + dataTypeClazz.isInstance(value) + } + } + + override def eval(input: InternalRow): Any = { + val result = child.eval(input) + if (checkType(result)) { + result + } else { + throw new RuntimeException(s"${result.getClass.getName}$errMsg") + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the type doesn't match. @@ -1691,7 +1715,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal]) .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ") case _: ArrayType => - s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()" + s"$obj.getClass().isArray() || $obj instanceof ${classOf[Seq[_]].getName}" case _ => s"$obj instanceof ${CodeGenerator.boxedType(dataType)}" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index bcd035c1eba0b..7136af8934486 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, Generic import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class InvokeTargetClass extends Serializable { def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0 @@ -296,7 +296,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0") Seq((Row(1), 1), (Row(3), 3)).foreach { case (input, expected) => - checkEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input))) + checkObjectExprEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input))) } // If an input row or a field are null, a runtime exception will be thrown @@ -472,6 +472,35 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val deserializer = toMapExpr.copy(inputData = Literal.create(data)) checkObjectExprEvaluation(deserializer, expected = data) } + + test("SPARK-23595 ValidateExternalType should support interpreted execution") { + val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) + Seq( + (true, BooleanType), + (2.toByte, ByteType), + (5.toShort, ShortType), + (23, IntegerType), + (61L, LongType), + (1.0f, FloatType), + (10.0, DoubleType), + ("abcd".getBytes, BinaryType), + ("abcd", StringType), + (BigDecimal.valueOf(10), DecimalType.IntDecimal), + (CalendarInterval.fromString("interval 3 day"), CalendarIntervalType), + (java.math.BigDecimal.valueOf(10), DecimalType.BigIntDecimal), + (Array(3, 2, 1), ArrayType(IntegerType)) + ).foreach { case (input, dt) => + val validateType = ValidateExternalType( + GetExternalRowField(inputObject, index = 0, fieldName = "c0"), dt) + checkObjectExprEvaluation(validateType, input, InternalRow.fromSeq(Seq(Row(input)))) + } + + checkExceptionInExpression[RuntimeException]( + ValidateExternalType( + GetExternalRowField(inputObject, index = 0, fieldName = "c0"), DoubleType), + InternalRow.fromSeq(Seq(Row(1))), + "java.lang.Integer is not a valid external type for schema of double") + } } class TestBean extends Serializable { From 1d758dc73b54e802fdc92be204185fe7414e6553 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 20 Apr 2018 10:23:01 -0700 Subject: [PATCH 0661/2461] Revert "[SPARK-23775][TEST] Make DataFrameRangeSuite not flaky" This reverts commit 0c94e48bc50717e1627c0d2acd5382d9adc73c97. --- .../spark/sql/DataFrameRangeSuite.scala | 78 ++++++++----------- 1 file changed, 33 insertions(+), 45 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index a0fd74088ce8b..57a930dfaf320 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.sql -import java.util.concurrent.{CountDownLatch, TimeUnit} - import scala.concurrent.duration._ import scala.math.abs import scala.util.Random import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -154,53 +152,39 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } test("Cancelling stage in a query with Range.") { - // Save and restore the value because SparkContext is shared - val savedInterruptOnCancel = sparkContext - .getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) - - try { - sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true") - - for (codegen <- Seq(true, false)) { - // This countdown latch used to make sure with all the stages cancelStage called in listener - val latch = new CountDownLatch(2) - - val listener = new SparkListener { - override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { - sparkContext.cancelStage(taskStart.stageId) - latch.countDown() - } + val listener = new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + eventually(timeout(10.seconds), interval(1.millis)) { + assert(DataFrameRangeSuite.stageToKill > 0) } + sparkContext.cancelStage(DataFrameRangeSuite.stageToKill) + } + } - sparkContext.addSparkListener(listener) - withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { - val ex = intercept[SparkException] { - sparkContext.range(0, 10000L, numSlices = 10).mapPartitions { x => - x.synchronized { - x.wait() - } - x - }.toDF("id").agg(sum("id")).collect() - } - ex.getCause() match { - case null => - assert(ex.getMessage().contains("cancelled")) - case cause: SparkException => - assert(cause.getMessage().contains("cancelled")) - case cause: Throwable => - fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") - } + sparkContext.addSparkListener(listener) + for (codegen <- Seq(true, false)) { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { + DataFrameRangeSuite.stageToKill = -1 + val ex = intercept[SparkException] { + spark.range(0, 100000000000L, 1, 1).map { x => + DataFrameRangeSuite.stageToKill = TaskContext.get().stageId() + x + }.toDF("id").agg(sum("id")).collect() } - latch.await(20, TimeUnit.SECONDS) - eventually(timeout(20.seconds)) { - assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + ex.getCause() match { + case null => + assert(ex.getMessage().contains("cancelled")) + case cause: SparkException => + assert(cause.getMessage().contains("cancelled")) + case cause: Throwable => + fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") } - sparkContext.removeSparkListener(listener) } - } finally { - sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, - savedInterruptOnCancel) + eventually(timeout(20.seconds)) { + assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } } + sparkContext.removeSparkListener(listener) } test("SPARK-20430 Initialize Range parameters in a driver side") { @@ -220,3 +204,7 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } } } + +object DataFrameRangeSuite { + @volatile var stageToKill = -1 +} From 32b4bcd6d31b92b179a15f9886779fc5f96404b5 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Sat, 21 Apr 2018 23:14:58 +0800 Subject: [PATCH 0662/2461] [SPARK-24029][CORE] Set SO_REUSEADDR on listen sockets. This allows sockets to be bound even if there are sockets from a previous application that are still pending closure. It avoids bind issues when, for example, re-starting the SHS. Don't enable the option on Windows though. The following page explains some odd behavior that this option can have there: https://msdn.microsoft.com/en-us/library/windows/desktop/ms740621%28v=vs.85%29.aspx I intentionally ignored server sockets that always bind to ephemeral ports, since those don't benefit from this option. Author: Marcelo Vanzin Closes #21110 from vanzin/SPARK-24029. --- .../java/org/apache/spark/network/server/TransportServer.java | 4 +++- .../org/apache/spark/deploy/rest/RestSubmissionServer.scala | 1 + core/src/main/scala/org/apache/spark/ui/JettyUtils.scala | 1 + 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java index 0719fa7647bcc..612750972c4bb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -32,6 +32,7 @@ import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.socket.SocketChannel; +import org.apache.commons.lang3.SystemUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -98,7 +99,8 @@ private void init(String hostToBind, int portToBind) { .group(bossGroup, workerGroup) .channel(NettyUtils.getServerChannelClass(ioMode)) .option(ChannelOption.ALLOCATOR, allocator) - .childOption(ChannelOption.ALLOCATOR, allocator); + .childOption(ChannelOption.ALLOCATOR, allocator) + .childOption(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS); this.metrics = new NettyMemoryMetrics( allocator, conf.getModuleName() + "-server", conf); diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala index e88195d95f270..3d99d085408c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -94,6 +94,7 @@ private[spark] abstract class RestSubmissionServer( new HttpConnectionFactory()) connector.setHost(host) connector.setPort(startPort) + connector.setReuseAddress(!Utils.isWindows) server.addConnector(connector) val mainHandler = new ServletContextHandler diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 0e8a6307de6a8..d6a025a6f12da 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -344,6 +344,7 @@ private[spark] object JettyUtils extends Logging { connectionFactories: _*) connector.setPort(port) connector.setHost(hostName) + connector.setReuseAddress(!Utils.isWindows) // Currently we only use "SelectChannelConnector" // Limit the max acceptor number to 8 so that we don't waste a lot of threads From 7bc853d08973a6bd839ad2222911eb0a0f413677 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 21 Apr 2018 10:45:12 -0700 Subject: [PATCH 0663/2461] [SPARK-24033][SQL] Fix Mismatched of Window Frame specifiedwindowframe(RowFrame, -1, -1) ## What changes were proposed in this pull request? When the OffsetWindowFunction's frame is `UnaryMinus(Literal(1))` but the specified window frame has been simplified to `Literal(-1)` by some optimizer rules e.g., `ConstantFolding`. Thus, they do not match and cause the following error: ``` org.apache.spark.sql.AnalysisException: Window Frame specifiedwindowframe(RowFrame, -1, -1) must match the required frame specifiedwindowframe(RowFrame, -1, -1); at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.failAnalysis(CheckAnalysis.scala:41) at org.apache.spark.sql.catalyst.analysis.Analyzer.failAnalysis(Analyzer.scala:91) at ``` ## How was this patch tested? Added a test Author: gatorsmile Closes #21115 from gatorsmile/fixLag. --- .../catalyst/expressions/windowExpressions.scala | 5 ++++- .../spark/sql/DataFrameWindowFramesSuite.scala | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 78895f1c2f6f5..9fe2fb2b95e4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -342,7 +342,10 @@ abstract class OffsetWindowFunction override lazy val frame: WindowFrame = { val boundary = direction match { case Ascending => offset - case Descending => UnaryMinus(offset) + case Descending => UnaryMinus(offset) match { + case e: Expression if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType) + case o => o + } } SpecifiedWindowFrame(RowFrame, boundary, boundary) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala index 0ee9b0edc02b2..2a0b2b85e10a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala @@ -402,4 +402,18 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSQLContext { Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: Row(10, 6000) :: Nil) } + + test("SPARK-24033: Analysis Failure of OffsetWindowFunction") { + val ds = Seq((1, 1), (1, 2), (1, 3), (2, 1), (2, 2)).toDF("n", "i") + val res = + Row(1, 1, null) :: Row (1, 2, 1) :: Row(1, 3, 2) :: Row(2, 1, null) :: Row(2, 2, 1) :: Nil + checkAnswer( + ds.withColumn("m", + lead("i", -1).over(Window.partitionBy("n").orderBy("i").rowsBetween(-1, -1))), + res) + checkAnswer( + ds.withColumn("m", + lag("i", 1).over(Window.partitionBy("n").orderBy("i").rowsBetween(-1, -1))), + res) + } } From c48085aa91c60615a4de3b391f019f46f3fcdbe3 Mon Sep 17 00:00:00 2001 From: Mykhailo Shtelma Date: Sat, 21 Apr 2018 23:33:57 -0700 Subject: [PATCH 0664/2461] [SPARK-23799][SQL] FilterEstimation.evaluateInSet produces devision by zero in a case of empty table with analyzed statistics >What changes were proposed in this pull request? During evaluation of IN conditions, if the source data frame, is represented by a plan, that uses hive table with columns, which were previously analysed, and the plan has conditions for these fields, that cannot be satisfied (which leads us to an empty data frame), FilterEstimation.evaluateInSet method produces NumberFormatException and ClassCastException. In order to fix this bug, method FilterEstimation.evaluateInSet at first checks, if distinct count is not zero, and also checks if colStat.min and colStat.max are defined, and only in this case proceeds with the calculation. If at least one of the conditions is not satisfied, zero is returned. >How was this patch tested? In order to test the PR two tests were implemented: one in FilterEstimationSuite, that tests the plan with the statistics that violates the conditions mentioned above, and another one in StatisticsCollectionSuite, that test the whole process of analysis/optimisation of the query, that leads to the problems, mentioned in the first section. Author: Mykhailo Shtelma Author: smikesh Closes #21052 from mshtelma/filter_estimation_evaluateInSet_Bugs. --- .../statsEstimation/FilterEstimation.scala | 4 +++ .../FilterEstimationSuite.scala | 11 ++++++++ .../spark/sql/StatisticsCollectionSuite.scala | 28 +++++++++++++++++++ 3 files changed, 43 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 0538c9d88584b..263c9ba60d145 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -392,6 +392,10 @@ case class FilterEstimation(plan: Filter) extends Logging { val dataType = attr.dataType var newNdv = ndv + if (ndv.toDouble == 0 || colStat.min.isEmpty || colStat.max.isEmpty) { + return Some(0.0) + } + // use [min, max] to filter the original hSet dataType match { case _: NumericType | BooleanType | DateType | TimestampType => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 43440d51dede6..16cb5d032cf57 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -357,6 +357,17 @@ class FilterEstimationSuite extends StatsEstimationTestBase { expectedRowCount = 3) } + test("evaluateInSet with all zeros") { + validateEstimatedStats( + Filter(InSet(attrString, Set(3, 4, 5)), + StatsTestPlan(Seq(attrString), 0, + AttributeMap(Seq(attrString -> + ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(0), avgLen = Some(0), maxLen = Some(0)))))), + Seq(attrString -> ColumnStat(distinctCount = Some(0))), + expectedRowCount = 0) + } + test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 14a565863d66c..b91712f4cc25d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -382,4 +382,32 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } } + + test("Simple queries must be working, if CBO is turned on") { + withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { + withTable("TBL1", "TBL") { + import org.apache.spark.sql.functions._ + val df = spark.range(1000L).select('id, + 'id * 2 as "FLD1", + 'id * 12 as "FLD2", + lit("aaa") + 'id as "fld3") + df.write + .mode(SaveMode.Overwrite) + .bucketBy(10, "id", "FLD1", "FLD2") + .sortBy("id", "FLD1", "FLD2") + .saveAsTable("TBL") + sql("ANALYZE TABLE TBL COMPUTE STATISTICS ") + sql("ANALYZE TABLE TBL COMPUTE STATISTICS FOR COLUMNS ID, FLD1, FLD2, FLD3") + val df2 = spark.sql( + """ + |SELECT t1.id, t1.fld1, t1.fld2, t1.fld3 + |FROM tbl t1 + |JOIN tbl t2 on t1.id=t2.id + |WHERE t1.fld3 IN (-123.23,321.23) + """.stripMargin) + df2.createTempView("TBL2") + sql("SELECT * FROM tbl2 WHERE fld3 IN ('qqq', 'qwe') ").queryExecution.executedPlan + } + } + } } From c3a86faa53c9e49efd595802adc38a6d412ce681 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 23 Apr 2018 10:45:25 +0800 Subject: [PATCH 0665/2461] [SPARK-10399][SPARK-23879][FOLLOWUP][CORE] Free unused off-heap memory in MemoryBlockSuite ## What changes were proposed in this pull request? As viirya pointed out [here](https://github.com/apache/spark/pull/19222#discussion_r179910484), this PR explicitly frees unused off-heap memory in `MemoryBlockSuite` ## How was this patch tested? Existing UTs Author: Kazuaki Ishizaki Closes #21117 from kiszk/SPARK-10399-free-offheap. --- .../java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java index 5d5fdc1c55a75..ef5ff8ee70ec0 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java @@ -120,6 +120,8 @@ private void check(MemoryBlock memory, Object obj, long offset, int length) { } catch (Exception expected) { Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); } + + memory.setPageNumber(MemoryBlock.NO_PAGE_NUMBER); } @Test @@ -165,11 +167,13 @@ public void testOffHeapArrayMemoryBlock() { int length = 56; check(memory, obj, offset, length); + memoryAllocator.free(memory); long address = Platform.allocateMemory(112); memory = new OffHeapMemoryBlock(address, length); obj = memory.getBaseObject(); offset = memory.getBaseOffset(); check(memory, obj, offset, length); + Platform.freeMemory(address); } } From f70f46d1e5bc503e9071707d837df618b7696d32 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Apr 2018 20:18:50 +0800 Subject: [PATCH 0666/2461] [SPARK-23877][SQL][FOLLOWUP] use PhysicalOperation to simplify the handling of Project and Filter over partitioned relation ## What changes were proposed in this pull request? A followup of https://github.com/apache/spark/pull/20988 `PhysicalOperation` can collect Project and Filters over a certain plan and substitute the alias with the original attributes in the bottom plan. We can use it in `OptimizeMetadataOnlyQuery` rule to handle the Project and Filter over partitioned relation. ## How was this patch tested? existing test Author: Wenchen Fan Closes #21111 from cloud-fan/refactor. --- .../plans/logical/LocalRelation.scala | 6 ++ .../sql/execution/LocalTableScanExec.scala | 3 + .../execution/OptimizeMetadataOnlyQuery.scala | 58 ++++++------------- .../OptimizeHiveMetadataOnlyQuerySuite.scala | 16 ++++- 4 files changed, 39 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index b05508db786ad..720d42ab409a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -43,6 +43,12 @@ object LocalRelation { } } +/** + * Logical plan node for scanning data from a local collection. + * + * @param data The local collection holding the data. It doesn't need to be sent to executors + * and then doesn't need to be serializable. + */ case class LocalRelation( output: Seq[Attribute], data: Seq[InternalRow] = Nil, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index 514ad7018d8c7..448eb703eacde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -25,6 +25,9 @@ import org.apache.spark.sql.execution.metric.SQLMetrics /** * Physical plan node for scanning data from a local collection. + * + * `Seq` may not be serializable and ideally we should not send `rows` and `unsafeRows` + * to the executors. Thus marking them as transient. */ case class LocalTableScanExec( output: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index acbd4becb8549..3ca03ab2939aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.{HiveTableRelation, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} @@ -49,9 +50,13 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic } plan.transform { - case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(_, attrs, filters, rel)) => + case a @ Aggregate(_, aggExprs, child @ PhysicalOperation( + projectList, filters, PartitionedRelation(partAttrs, rel))) => // We only apply this optimization when only partitioned attributes are scanned. - if (a.references.subsetOf(attrs)) { + if (AttributeSet((projectList ++ filters).flatMap(_.references)).subsetOf(partAttrs)) { + // The project list and filters all only refer to partition attributes, which means the + // the Aggregator operator can also only refer to partition attributes, and filters are + // all partition filters. This is a metadata only query we can optimize. val aggFunctions = aggExprs.flatMap(_.collect { case agg: AggregateExpression => agg }) @@ -102,7 +107,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic partFilters: Seq[Expression]): LogicalPlan = { // this logic comes from PruneFileSourcePartitions. it ensures that the filter names match the // relation's schema. PartitionedRelation ensures that the filters only reference partition cols - val relFilters = partFilters.map { e => + val normalizedFilters = partFilters.map { e => e transform { case a: AttributeReference => a.withName(relation.output.find(_.semanticEquals(a)).get.name) @@ -114,11 +119,8 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic relation match { case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, isStreaming) => val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) - val partitionData = fsRelation.location.listFiles(relFilters, Nil) - // partition data may be a stream, which can cause serialization to hit stack level too - // deep exceptions because it is a recursive structure in memory. converting to array - // avoids the problem. - LocalRelation(partAttrs, partitionData.map(_.values).toArray, isStreaming) + val partitionData = fsRelation.location.listFiles(normalizedFilters, Nil) + LocalRelation(partAttrs, partitionData.map(_.values), isStreaming) case relation: HiveTableRelation => val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) @@ -127,7 +129,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic val timeZoneId = caseInsensitiveProperties.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(SQLConf.get.sessionLocalTimeZone) val partitions = if (partFilters.nonEmpty) { - catalog.listPartitionsByFilter(relation.tableMeta.identifier, relFilters) + catalog.listPartitionsByFilter(relation.tableMeta.identifier, normalizedFilters) } else { catalog.listPartitions(relation.tableMeta.identifier) } @@ -137,10 +139,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic Cast(Literal(p.spec(attr.name)), attr.dataType, Option(timeZoneId)).eval() }) } - // partition data may be a stream, which can cause serialization to hit stack level too - // deep exceptions because it is a recursive structure in memory. converting to array - // avoids the problem. - LocalRelation(partAttrs, partitionData.toArray) + LocalRelation(partAttrs, partitionData) case _ => throw new IllegalStateException(s"unrecognized table scan node: $relation, " + @@ -151,44 +150,21 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic /** * A pattern that finds the partitioned table relation node inside the given plan, and returns a - * pair of the partition attributes, partition filters, and the table relation node. - * - * It keeps traversing down the given plan tree if there is a [[Project]] or [[Filter]] with - * deterministic expressions, and returns result after reaching the partitioned table relation - * node. + * pair of the partition attributes and the table relation node. */ object PartitionedRelation extends PredicateHelper { - def unapply( - plan: LogicalPlan): Option[(AttributeSet, AttributeSet, Seq[Expression], LogicalPlan)] = { + def unapply(plan: LogicalPlan): Option[(AttributeSet, LogicalPlan)] = { plan match { case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) - if fsRelation.partitionSchema.nonEmpty => + if fsRelation.partitionSchema.nonEmpty => val partAttrs = AttributeSet(getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l)) - Some((partAttrs, partAttrs, Nil, l)) + Some((partAttrs, l)) case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty => val partAttrs = AttributeSet( getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation)) - Some((partAttrs, partAttrs, Nil, relation)) - - case p @ Project(projectList, child) if projectList.forall(_.deterministic) => - unapply(child).flatMap { case (partAttrs, attrs, filters, relation) => - if (p.references.subsetOf(attrs)) { - Some((partAttrs, p.outputSet, filters, relation)) - } else { - None - } - } - - case f @ Filter(condition, child) if condition.deterministic => - unapply(child).flatMap { case (partAttrs, attrs, filters, relation) => - if (f.references.subsetOf(partAttrs)) { - Some((partAttrs, attrs, splitConjunctivePredicates(condition) ++ filters, relation)) - } else { - None - } - } + Some((partAttrs, relation)) case _ => None } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala index 95f192f0e40e2..1e525c46a9cfb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, Project, SubqueryAlias} import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_METADATA_ONLY import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -32,13 +33,22 @@ class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleto import spark.implicits._ - before { + override def beforeAll(): Unit = { + super.beforeAll() sql("CREATE TABLE metadata_only (id bigint, data string) PARTITIONED BY (part int)") (0 to 10).foreach(p => sql(s"ALTER TABLE metadata_only ADD PARTITION (part=$p)")) } + override protected def afterAll(): Unit = { + try { + sql("DROP TABLE IF EXISTS metadata_only") + } finally { + super.afterAll() + } + } + test("SPARK-23877: validate metadata-only query pushes filters to metastore") { - withTable("metadata_only") { + withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") { val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount // verify the number of matching partitions @@ -50,7 +60,7 @@ class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleto } test("SPARK-23877: filter on projected expression") { - withTable("metadata_only") { + withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") { val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount // verify the matching partitions From d87d30e4fe9c9e91c462351e9f744a830db8d6fc Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Apr 2018 20:21:01 +0800 Subject: [PATCH 0667/2461] [SPARK-23564][SQL] infer additional filters from constraints for join's children ## What changes were proposed in this pull request? The existing query constraints framework has 2 steps: 1. propagate constraints bottom up. 2. use constraints to infer additional filters for better data pruning. For step 2, it mostly helps with Join, because we can connect the constraints from children to the join condition and infer powerful filters to prune the data of the join sides. e.g., the left side has constraints `a = 1`, the join condition is `left.a = right.a`, then we can infer `right.a = 1` to the right side and prune the right side a lot. However, the current logic of inferring filters from constraints for Join is pretty weak. It infers the filters from Join's constraints. Some joins like left semi/anti exclude output from right side and the right side constraints will be lost here. This PR propose to check the left and right constraints individually, expand the constraints with join condition and add filters to children of join directly, instead of adding to the join condition. This reverts https://github.com/apache/spark/pull/20670 , covers https://github.com/apache/spark/pull/20717 and https://github.com/apache/spark/pull/20816 This is inspired by the original PRs and the tests are all from these PRs. Thanks to the authors mgaido91 maryannxue KaiXinXiaoLei ! ## How was this patch tested? new tests Author: Wenchen Fan Closes #21083 from cloud-fan/join. --- .../sql/catalyst/optimizer/Optimizer.scala | 97 +++++++++---------- .../plans/logical/QueryPlanConstraints.scala | 95 ++++++++---------- .../InferFiltersFromConstraintsSuite.scala | 53 +++++++--- 3 files changed, 124 insertions(+), 121 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 913354e4df0e6..f00d40d11f23f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -637,13 +637,11 @@ object CollapseWindow extends Rule[LogicalPlan] { * constraints. These filters are currently inserted to the existing conditions in the Filter * operators and on either side of Join operators. * - * In addition, for left/right outer joins, infer predicate from the preserved side of the Join - * operator and push the inferred filter over to the null-supplying side. For example, if the - * preserved side has constraints of the form 'a > 5' and the join condition is 'a = b', in - * which 'b' is an attribute from the null-supplying side, a [[Filter]] operator of 'b > 5' will - * be applied to the null-supplying side. + * Note: While this optimization is applicable to a lot of types of join, it primarily benefits + * Inner and LeftSemi joins. */ -object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper { +object InferFiltersFromConstraints extends Rule[LogicalPlan] + with PredicateHelper with ConstraintHelper { def apply(plan: LogicalPlan): LogicalPlan = { if (SQLConf.get.constraintPropagationEnabled) { @@ -664,53 +662,52 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe } case join @ Join(left, right, joinType, conditionOpt) => - // Only consider constraints that can be pushed down completely to either the left or the - // right child - val constraints = join.allConstraints.filter { c => - c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet) - } - // Remove those constraints that are already enforced by either the left or the right child - val additionalConstraints = constraints -- (left.constraints ++ right.constraints) - val newConditionOpt = conditionOpt match { - case Some(condition) => - val newFilters = additionalConstraints -- splitConjunctivePredicates(condition) - if (newFilters.nonEmpty) Option(And(newFilters.reduce(And), condition)) else conditionOpt - case None => - additionalConstraints.reduceOption(And) - } - // Infer filter for left/right outer joins - val newLeftOpt = joinType match { - case RightOuter if newConditionOpt.isDefined => - val inferredConstraints = left.getRelevantConstraints( - left.constraints - .union(right.constraints) - .union(splitConjunctivePredicates(newConditionOpt.get).toSet)) - val newFilters = inferredConstraints - .filterNot(left.constraints.contains) - .reduceLeftOption(And) - newFilters.map(Filter(_, left)) - case _ => None - } - val newRightOpt = joinType match { - case LeftOuter if newConditionOpt.isDefined => - val inferredConstraints = right.getRelevantConstraints( - right.constraints - .union(left.constraints) - .union(splitConjunctivePredicates(newConditionOpt.get).toSet)) - val newFilters = inferredConstraints - .filterNot(right.constraints.contains) - .reduceLeftOption(And) - newFilters.map(Filter(_, right)) - case _ => None - } + joinType match { + // For inner join, we can infer additional filters for both sides. LeftSemi is kind of an + // inner join, it just drops the right side in the final output. + case _: InnerLike | LeftSemi => + val allConstraints = getAllConstraints(left, right, conditionOpt) + val newLeft = inferNewFilter(left, allConstraints) + val newRight = inferNewFilter(right, allConstraints) + join.copy(left = newLeft, right = newRight) - if ((newConditionOpt.isDefined && (newConditionOpt ne conditionOpt)) - || newLeftOpt.isDefined || newRightOpt.isDefined) { - Join(newLeftOpt.getOrElse(left), newRightOpt.getOrElse(right), joinType, newConditionOpt) - } else { - join + // For right outer join, we can only infer additional filters for left side. + case RightOuter => + val allConstraints = getAllConstraints(left, right, conditionOpt) + val newLeft = inferNewFilter(left, allConstraints) + join.copy(left = newLeft) + + // For left join, we can only infer additional filters for right side. + case LeftOuter | LeftAnti => + val allConstraints = getAllConstraints(left, right, conditionOpt) + val newRight = inferNewFilter(right, allConstraints) + join.copy(right = newRight) + + case _ => join } } + + private def getAllConstraints( + left: LogicalPlan, + right: LogicalPlan, + conditionOpt: Option[Expression]): Set[Expression] = { + val baseConstraints = left.constraints.union(right.constraints) + .union(conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil).toSet) + baseConstraints.union(inferAdditionalConstraints(baseConstraints)) + } + + private def inferNewFilter(plan: LogicalPlan, constraints: Set[Expression]): LogicalPlan = { + val newPredicates = constraints + .union(constructIsNotNullConstraints(constraints, plan.output)) + .filter { c => + c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic + } -- plan.constraints + if (newPredicates.isEmpty) { + plan + } else { + Filter(newPredicates.reduce(And), plan) + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index a29f3d29236c7..cc352c59dff80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -20,29 +20,28 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ -trait QueryPlanConstraints { self: LogicalPlan => +trait QueryPlanConstraints extends ConstraintHelper { self: LogicalPlan => /** - * An [[ExpressionSet]] that contains an additional set of constraints, such as equality - * constraints and `isNotNull` constraints, etc. + * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For + * example, if this set contains the expression `a = 2` then that expression is guaranteed to + * evaluate to `true` for all rows produced. */ - lazy val allConstraints: ExpressionSet = { + lazy val constraints: ExpressionSet = { if (conf.constraintPropagationEnabled) { - ExpressionSet(validConstraints - .union(inferAdditionalConstraints(validConstraints)) - .union(constructIsNotNullConstraints(validConstraints))) + ExpressionSet( + validConstraints + .union(inferAdditionalConstraints(validConstraints)) + .union(constructIsNotNullConstraints(validConstraints, output)) + .filter { c => + c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic + } + ) } else { ExpressionSet(Set.empty) } } - /** - * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For - * example, if this set contains the expression `a = 2` then that expression is guaranteed to - * evaluate to `true` for all rows produced. - */ - lazy val constraints: ExpressionSet = ExpressionSet(allConstraints.filter(selfReferenceOnly)) - /** * This method can be overridden by any child class of QueryPlan to specify a set of constraints * based on the given operator's constraint propagation logic. These constraints are then @@ -52,30 +51,42 @@ trait QueryPlanConstraints { self: LogicalPlan => * See [[Canonicalize]] for more details. */ protected def validConstraints: Set[Expression] = Set.empty +} + +trait ConstraintHelper { /** - * Returns an [[ExpressionSet]] that contains an additional set of constraints, such as - * equality constraints and `isNotNull` constraints, etc., and that only contains references - * to this [[LogicalPlan]] node. + * Infers an additional set of constraints from a given set of equality constraints. + * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an + * additional constraint of the form `b = 5`. */ - def getRelevantConstraints(constraints: Set[Expression]): ExpressionSet = { - val allRelevantConstraints = - if (conf.constraintPropagationEnabled) { - constraints - .union(inferAdditionalConstraints(constraints)) - .union(constructIsNotNullConstraints(constraints)) - } else { - constraints - } - ExpressionSet(allRelevantConstraints.filter(selfReferenceOnly)) + def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { + var inferredConstraints = Set.empty[Expression] + constraints.foreach { + case eq @ EqualTo(l: Attribute, r: Attribute) => + val candidateConstraints = constraints - eq + inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) + inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) + case _ => // No inference + } + inferredConstraints -- constraints } + private def replaceConstraints( + constraints: Set[Expression], + source: Expression, + destination: Attribute): Set[Expression] = constraints.map(_ transform { + case e: Expression if e.semanticEquals(source) => destination + }) + /** * Infers a set of `isNotNull` constraints from null intolerant expressions as well as * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this * returns a constraint of the form `isNotNull(a)` */ - private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { + def constructIsNotNullConstraints( + constraints: Set[Expression], + output: Seq[Attribute]): Set[Expression] = { // First, we propagate constraints from the null intolerant expressions. var isNotNullConstraints: Set[Expression] = constraints.flatMap(inferIsNotNullConstraints) @@ -111,32 +122,4 @@ trait QueryPlanConstraints { self: LogicalPlan => case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) case _ => Seq.empty[Attribute] } - - /** - * Infers an additional set of constraints from a given set of equality constraints. - * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an - * additional constraint of the form `b = 5`. - */ - private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { - var inferredConstraints = Set.empty[Expression] - constraints.foreach { - case eq @ EqualTo(l: Attribute, r: Attribute) => - val candidateConstraints = constraints - eq - inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) - inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) - case _ => // No inference - } - inferredConstraints -- constraints - } - - private def replaceConstraints( - constraints: Set[Expression], - source: Expression, - destination: Attribute): Set[Expression] = constraints.map(_ transform { - case e: Expression if e.semanticEquals(source) => destination - }) - - private def selfReferenceOnly(e: Expression): Boolean = { - e.references.nonEmpty && e.references.subsetOf(outputSet) && e.deterministic - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index e068f51044589..e4671f0d1cce6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -35,11 +35,25 @@ class InferFiltersFromConstraintsSuite extends PlanTest { InferFiltersFromConstraints, CombineFilters, SimplifyBinaryComparison, - BooleanSimplification) :: Nil + BooleanSimplification, + PruneFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + private def testConstraintsAfterJoin( + x: LogicalPlan, + y: LogicalPlan, + expectedLeft: LogicalPlan, + expectedRight: LogicalPlan, + joinType: JoinType) = { + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y, joinType, condition).analyze + val correctAnswer = expectedLeft.join(expectedRight, joinType, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + test("filter: filter out constraints in condition") { val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze val correctAnswer = testRelation @@ -196,13 +210,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val condition = Some("x.a".attr === "y.a".attr) - val originalQuery = x.join(y, LeftSemi, condition).analyze - val left = x.where(IsNotNull('a)) - val right = y.where(IsNotNull('a)) - val correctAnswer = left.join(right, LeftSemi, condition).analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) + testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftSemi) } test("SPARK-21479: Outer join after-join filters push down to null-supplying side") { @@ -232,12 +240,27 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("SPARK-21479: Outer join no filter push down to preserved side") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val condition = Some("x.a".attr === "y.a".attr) - val originalQuery = x.join(y.where("y.a".attr === 1), LeftOuter, condition).analyze - val left = x - val right = y.where(IsNotNull('a) && 'a === 1) - val correctAnswer = left.join(right, LeftOuter, condition).analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) + testConstraintsAfterJoin( + x, y.where("a".attr === 1), + x, y.where(IsNotNull('a) && 'a === 1), + LeftOuter) + } + + test("SPARK-23564: left anti join should filter out null join keys on right side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftAnti) + } + + test("SPARK-23564: left outer join should filter out null join keys on right side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftOuter) + } + + test("SPARK-23564: right outer join should filter out null join keys on left side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter) } } From afbdf427302aba858f95205ecef7667f412b2a6a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 23 Apr 2018 14:28:28 +0200 Subject: [PATCH 0668/2461] [SPARK-23589][SQL] ExternalMapToCatalyst should support interpreted execution ## What changes were proposed in this pull request? This pr supported interpreted mode for `ExternalMapToCatalyst`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Takeshi Yamamuro Closes #20980 from maropu/SPARK-23589. --- .../expressions/objects/objects.scala | 60 +++++++++- .../expressions/ObjectExpressionsSuite.scala | 108 +++++++++++++++++- 2 files changed, 165 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index f1ffcaec8a484..9c7e76467d153 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1255,8 +1255,64 @@ case class ExternalMapToCatalyst private( override def dataType: MapType = MapType( keyConverter.dataType, valueConverter.dataType, valueContainsNull = valueConverter.nullable) - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + private lazy val mapCatalystConverter: Any => (Array[Any], Array[Any]) = child.dataType match { + case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[java.util.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + val iter = data.entrySet().iterator() + var i = 0 + while (iter.hasNext) { + val entry = iter.next() + val (key, value) = (entry.getKey, entry.getValue) + keys(i) = if (key != null) { + keyConverter.eval(InternalRow.fromSeq(key :: Nil)) + } else { + throw new RuntimeException("Cannot use null as map key!") + } + values(i) = if (value != null) { + valueConverter.eval(InternalRow.fromSeq(value :: Nil)) + } else { + null + } + i += 1 + } + (keys, values) + } + + case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[scala.collection.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + var i = 0 + for ((key, value) <- data) { + keys(i) = if (key != null) { + keyConverter.eval(InternalRow.fromSeq(key :: Nil)) + } else { + throw new RuntimeException("Cannot use null as map key!") + } + values(i) = if (value != null) { + valueConverter.eval(InternalRow.fromSeq(value :: Nil)) + } else { + null + } + i += 1 + } + (keys, values) + } + } + + override def eval(input: InternalRow): Any = { + val result = child.eval(input) + if (result != null) { + val (keys, values) = mapCatalystConverter(result) + new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values)) + } else { + null + } + } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputMap = child.genCode(ctx) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 7136af8934486..730b36c32333c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -21,12 +21,13 @@ import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag import scala.util.Random import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.{RandomDataGenerator, Row} -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer, UnresolvedDeserializer} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.encoders._ @@ -501,6 +502,111 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { InternalRow.fromSeq(Seq(Row(1))), "java.lang.Integer is not a valid external type for schema of double") } + + private def javaMapSerializerFor( + keyClazz: Class[_], + valueClazz: Class[_])(inputObject: Expression): Expression = { + + def kvSerializerFor(inputObject: Expression, clazz: Class[_]): Expression = clazz match { + case c if c == classOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case c if c == classOf[java.lang.String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil, + returnNullable = false) + } + + ExternalMapToCatalyst( + inputObject, + ObjectType(keyClazz), + kvSerializerFor(_, keyClazz), + keyNullable = true, + ObjectType(valueClazz), + kvSerializerFor(_, valueClazz), + valueNullable = true + ) + } + + private def scalaMapSerializerFor[T: TypeTag, U: TypeTag](inputObject: Expression): Expression = { + import org.apache.spark.sql.catalyst.ScalaReflection._ + + val curId = new java.util.concurrent.atomic.AtomicInteger() + + def kvSerializerFor[V: TypeTag](inputObject: Expression): Expression = + localTypeOf[V].dealias match { + case t if t <:< localTypeOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case t if t <:< localTypeOf[String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil, + returnNullable = false) + case _ => + inputObject + } + + ExternalMapToCatalyst( + inputObject, + dataTypeFor[T], + kvSerializerFor[T], + keyNullable = !localTypeOf[T].typeSymbol.asClass.isPrimitive, + dataTypeFor[U], + kvSerializerFor[U], + valueNullable = !localTypeOf[U].typeSymbol.asClass.isPrimitive + ) + } + + test("SPARK-23589 ExternalMapToCatalyst should support interpreted execution") { + // Simple test + val scalaMap = scala.collection.Map[Int, String](0 -> "v0", 1 -> "v1", 2 -> null, 3 -> "v3") + val javaMap = new java.util.HashMap[java.lang.Integer, java.lang.String]() { + { + put(0, "v0") + put(1, "v1") + put(2, null) + put(3, "v3") + } + } + val expected = CatalystTypeConverters.convertToCatalyst(scalaMap) + + // Java Map + val serializer1 = javaMapSerializerFor(classOf[java.lang.Integer], classOf[java.lang.String])( + Literal.fromObject(javaMap)) + checkEvaluation(serializer1, expected) + + // Scala Map + val serializer2 = scalaMapSerializerFor[Int, String](Literal.fromObject(scalaMap)) + checkEvaluation(serializer2, expected) + + // NULL key test + val scalaMapHasNullKey = scala.collection.Map[java.lang.Integer, String]( + null.asInstanceOf[java.lang.Integer] -> "v0", new java.lang.Integer(1) -> "v1") + val javaMapHasNullKey = new java.util.HashMap[java.lang.Integer, java.lang.String]() { + { + put(null, "v0") + put(1, "v1") + } + } + + // Java Map + val serializer3 = + javaMapSerializerFor(classOf[java.lang.Integer], classOf[java.lang.String])( + Literal.fromObject(javaMapHasNullKey)) + checkExceptionInExpression[RuntimeException]( + serializer3, EmptyRow, "Cannot use null as map key!") + + // Scala Map + val serializer4 = scalaMapSerializerFor[java.lang.Integer, String]( + Literal.fromObject(scalaMapHasNullKey)) + + checkExceptionInExpression[RuntimeException]( + serializer4, EmptyRow, "Cannot use null as map key!") + } } class TestBean extends Serializable { From 293a0f29e314dc532cec2048a7c6bc00e31de472 Mon Sep 17 00:00:00 2001 From: Teng Peng Date: Mon, 23 Apr 2018 10:29:47 -0700 Subject: [PATCH 0669/2461] [Spark-24024][ML] Fix poisson deviance calculations in GLM to handle y = 0 ## What changes were proposed in this pull request? It is reported by Spark users that the deviance calculation for poisson regression does not handle y = 0. Thus, the correct model summary cannot be obtained. The user has confirmed the the issue is in ``` override def deviance(y: Double, mu: Double, weight: Double): Double = { 2.0 * weight * (y * math.log(y / mu) - (y - mu)) } when y = 0. ``` The user also mentioned there are many other places he believe we should check the same thing. However, no other changes are needed, including Gamma distribution. ## How was this patch tested? Add a comparison with R deviance calculation to the existing unit test. Author: Teng Peng Closes #21125 from tengpeng/Spark24024GLM. --- .../ml/regression/GeneralizedLinearRegression.scala | 10 +++++----- .../regression/GeneralizedLinearRegressionSuite.scala | 10 ++++++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 9f1f2405c428e..4c3f1431d5077 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -471,6 +471,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine private[regression] val epsilon: Double = 1E-16 + private[regression] def ylogy(y: Double, mu: Double): Double = { + if (y == 0) 0.0 else y * math.log(y / mu) + } + /** * Wrapper of family and link combination used in the model. */ @@ -725,10 +729,6 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def variance(mu: Double): Double = mu * (1.0 - mu) - private def ylogy(y: Double, mu: Double): Double = { - if (y == 0) 0.0 else y * math.log(y / mu) - } - override def deviance(y: Double, mu: Double, weight: Double): Double = { 2.0 * weight * (ylogy(y, mu) + ylogy(1.0 - y, 1.0 - mu)) } @@ -783,7 +783,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def variance(mu: Double): Double = mu override def deviance(y: Double, mu: Double, weight: Double): Double = { - 2.0 * weight * (y * math.log(y / mu) - (y - mu)) + 2.0 * weight * (ylogy(y, mu) - (y - mu)) } override def aic( diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index d5bcbb221783e..997c50157dcda 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -493,11 +493,20 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest } [1] -0.0457441 -0.6833928 [1] 1.8121235 -0.1747493 -0.5815417 + + R code for deivance calculation: + data = cbind(y=c(0,1,0,0,0,1), x1=c(18, 12, 15, 13, 15, 16), x2=c(1,0,0,2,1,1)) + summary(glm(y~x1+x2, family=poisson, data=data.frame(data)))$deviance + [1] 3.70055 + summary(glm(y~x1+x2-1, family=poisson, data=data.frame(data)))$deviance + [1] 3.809296 */ val expected = Seq( Vectors.dense(0.0, -0.0457441, -0.6833928), Vectors.dense(1.8121235, -0.1747493, -0.5815417)) + val residualDeviancesR = Array(3.809296, 3.70055) + import GeneralizedLinearRegression._ var idx = 0 @@ -510,6 +519,7 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " + s"$link link and fitIntercept = $fitIntercept (with zero values).") + assert(model.summary.deviance ~== residualDeviancesR(idx) absTol 1E-3) idx += 1 } } From 448d248f897fa39cfc82d71a3d6b67e6470f8a02 Mon Sep 17 00:00:00 2001 From: liuzhaokun Date: Mon, 23 Apr 2018 13:56:11 -0500 Subject: [PATCH 0670/2461] [SPARK-21168] KafkaRDD should always set kafka clientId. [https://issues.apache.org/jira/browse/SPARK-21168](https://issues.apache.org/jira/browse/SPARK-21168) There are no a number of other places that a client ID should be set,and I think we should use consumer.clientId in the clientId method,because the fetch request will be used by the same consumer behind. Author: liuzhaokun Closes #19887 from liu-zhaokun/master1205. --- .../main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index 5ea52b6ad36a0..791cf0efaf888 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -191,6 +191,7 @@ class KafkaRDD[ private def fetchBatch: Iterator[MessageAndOffset] = { val req = new FetchRequestBuilder() + .clientId(consumer.clientId) .addFetch(part.topic, part.partition, requestOffset, kc.config.fetchMessageMaxBytes) .build() val resp = consumer.fetch(req) From 770add81c3474e754867d7105031a5eaf27159bd Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 23 Apr 2018 13:20:32 -0700 Subject: [PATCH 0671/2461] [SPARK-23004][SS] Ensure StateStore.commit is called only once in a streaming aggregation task MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? A structured streaming query with a streaming aggregation can throw the following error in rare cases.  ``` java.lang.IllegalStateException: Cannot commit after already committed or aborted at org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider.org$apache$spark$sql$execution$streaming$state$HDFSBackedStateStoreProvider$$verify(HDFSBackedStateStoreProvider.scala:643) at org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider$HDFSBackedStateStore.commit(HDFSBackedStateStoreProvider.scala:135) at org.apache.spark.sql.execution.streaming.StateStoreSaveExec$$anonfun$doExecute$3$$anon$2$$anonfun$hasNext$2.apply$mcV$sp(statefulOperators.scala:359) at org.apache.spark.sql.execution.streaming.StateStoreWriter$class.timeTakenMs(statefulOperators.scala:102) at org.apache.spark.sql.execution.streaming.StateStoreSaveExec.timeTakenMs(statefulOperators.scala:251) at org.apache.spark.sql.execution.streaming.StateStoreSaveExec$$anonfun$doExecute$3$$anon$2.hasNext(statefulOperators.scala:359) at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.processInputs(ObjectAggregationIterator.scala:188) at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.(ObjectAggregationIterator.scala:78) at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:114) at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:105) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndexInternal$1$$anonfun$apply$24.apply(RDD.scala:830) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndexInternal$1$$anonfun$apply$24.apply(RDD.scala:830) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:42) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:336) ``` This can happen when the following conditions are accidentally hit.  - Streaming aggregation with aggregation function that is a subset of [`TypedImperativeAggregation`](https://github.com/apache/spark/blob/76b8b840ddc951ee6203f9cccd2c2b9671c1b5e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala#L473) (for example, `collect_set`, `collect_list`, `percentile`, etc.).  - Query running in `update}` mode - After the shuffle, a partition has exactly 128 records.  This causes StateStore.commit to be called twice. See the [JIRA](https://issues.apache.org/jira/browse/SPARK-23004) for a more detailed explanation. The solution is to use `NextIterator` or `CompletionIterator`, each of which has a flag to prevent the "onCompletion" task from being called more than once. In this PR, I chose to implement using `NextIterator`. ## How was this patch tested? Added unit test that I have confirm will fail without the fix. Author: Tathagata Das Closes #21124 from tdas/SPARK-23004. --- .../streaming/statefulOperators.scala | 40 +++++++++---------- .../streaming/StreamingAggregationSuite.scala | 25 ++++++++++++ 2 files changed, 44 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index b9b07a2e688f9..c9354ac0ec78a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -340,37 +340,35 @@ case class StateStoreSaveExec( // Update and output modified rows from the StateStore. case Some(Update) => - val updatesStartTimeNs = System.nanoTime - - new Iterator[InternalRow] { - + new NextIterator[InternalRow] { // Filter late date using watermark if specified private[this] val baseIterator = watermarkPredicateForData match { case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) case None => iter } + private val updatesStartTimeNs = System.nanoTime - override def hasNext: Boolean = { - if (!baseIterator.hasNext) { - allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) - - // Remove old aggregates if watermark specified - allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } - commitTimeMs += timeTakenMs { store.commit() } - setStoreMetrics(store) - false + override protected def getNext(): InternalRow = { + if (baseIterator.hasNext) { + val row = baseIterator.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key, row) + numOutputRows += 1 + numUpdatedStateRows += 1 + row } else { - true + finished = true + null } } - override def next(): InternalRow = { - val row = baseIterator.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key, row) - numOutputRows += 1 - numUpdatedStateRows += 1 - row + override protected def close(): Unit = { + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + + // Remove old aggregates if watermark specified + allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } + commitTimeMs += timeTakenMs { store.commit() } + setStoreMetrics(store) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 1cae8cb8d47f1..382da13430781 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -536,6 +536,31 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } + test("SPARK-23004: Ensure that TypedImperativeAggregate functions do not throw errors") { + // See the JIRA SPARK-23004 for more details. In short, this test reproduces the error + // by ensuring the following. + // - A streaming query with a streaming aggregation. + // - Aggregation function 'collect_list' that is a subclass of TypedImperativeAggregate. + // - Post shuffle partition has exactly 128 records (i.e. the threshold at which + // ObjectHashAggregateExec falls back to sort-based aggregation). This is done by having a + // micro-batch with 128 records that shuffle to a single partition. + // This test throws the exact error reported in SPARK-23004 without the corresponding fix. + withSQLConf("spark.sql.shuffle.partitions" -> "1") { + val input = MemoryStream[Int] + val df = input.toDF().toDF("value") + .selectExpr("value as group", "value") + .groupBy("group") + .agg(collect_list("value")) + testStream(df, outputMode = OutputMode.Update)( + AddData(input, (1 to spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*), + AssertOnQuery { q => + q.processAllAvailable() + true + } + ) + } + } + /** Add blocks of data to the `BlockRDDBackedSource`. */ case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { override def addData(query: Option[StreamExecution]): (Source, Offset) = { From e82cb68349b785c1b35bcfb85bff3a8ec2c93fee Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 23 Apr 2018 13:23:02 -0700 Subject: [PATCH 0672/2461] [SPARK-11237][ML] Add pmml export for k-means in Spark ML ## What changes were proposed in this pull request? Adding PMML export to Spark ML's KMeans Model. ## How was this patch tested? New unit test for Spark ML PMML export based on the old Spark MLlib unit test. Author: Holden Karau Closes #20907 from holdenk/SPARK-11237-Add-PMML-Export-for-KMeans. --- .../org.apache.spark.ml.util.MLFormatRegister | 4 +- .../apache/spark/ml/clustering/KMeans.scala | 75 ++++++++++++------- .../ml/regression/LinearRegression.scala | 2 +- .../spark/ml/clustering/KMeansSuite.scala | 32 +++++++- 4 files changed, 83 insertions(+), 30 deletions(-) diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister index 5e5484fd8784d..f14431d50feec 100644 --- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister @@ -1,2 +1,4 @@ org.apache.spark.ml.regression.InternalLinearRegressionModelWriter -org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter \ No newline at end of file +org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter +org.apache.spark.ml.clustering.InternalKMeansModelWriter +org.apache.spark.ml.clustering.PMMLKMeansModelWriter \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 987a4285ebad4..1ad157a695a7d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -17,11 +17,13 @@ package org.apache.spark.ml.clustering +import scala.collection.mutable + import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.{Estimator, Model, PipelineStage} import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -30,7 +32,7 @@ import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.storage.StorageLevel @@ -103,8 +105,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") class KMeansModel private[ml] ( @Since("1.5.0") override val uid: String, - private val parentModel: MLlibKMeansModel) - extends Model[KMeansModel] with KMeansParams with MLWritable { + private[clustering] val parentModel: MLlibKMeansModel) + extends Model[KMeansModel] with KMeansParams with GeneralMLWritable { @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { @@ -152,14 +154,14 @@ class KMeansModel private[ml] ( } /** - * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance. + * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance. * * For [[KMeansModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. * */ @Since("1.6.0") - override def write: MLWriter = new KMeansModel.KMeansModelWriter(this) + override def write: GeneralMLWriter = new GeneralMLWriter(this) private var trainingSummary: Option[KMeansSummary] = None @@ -185,6 +187,47 @@ class KMeansModel private[ml] ( } } +/** Helper class for storing model data */ +private case class ClusterData(clusterIdx: Int, clusterCenter: Vector) + + +/** A writer for KMeans that handles the "internal" (or default) format */ +private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "internal" + override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel" + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val instance = stage.asInstanceOf[KMeansModel] + val sc = sparkSession.sparkContext + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: cluster centers + val data: Array[ClusterData] = instance.clusterCenters.zipWithIndex.map { + case (center, idx) => + ClusterData(idx, center) + } + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath) + } +} + +/** A writer for KMeans that handles the "pmml" format */ +private class PMMLKMeansModelWriter extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "pmml" + override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel" + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val instance = stage.asInstanceOf[KMeansModel] + val sc = sparkSession.sparkContext + instance.parentModel.toPMML(sc, path) + } +} + + @Since("1.6.0") object KMeansModel extends MLReadable[KMeansModel] { @@ -194,30 +237,12 @@ object KMeansModel extends MLReadable[KMeansModel] { @Since("1.6.0") override def load(path: String): KMeansModel = super.load(path) - /** Helper class for storing model data */ - private case class Data(clusterIdx: Int, clusterCenter: Vector) - /** * We store all cluster centers in a single row and use this class to store model data by * Spark 1.6 and earlier. A model can be loaded from such older data for backward compatibility. */ private case class OldData(clusterCenters: Array[OldVector]) - /** [[MLWriter]] instance for [[KMeansModel]] */ - private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter { - - override protected def saveImpl(path: String): Unit = { - // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) - // Save model data: cluster centers - val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) => - Data(idx, center) - } - val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath) - } - } - private class KMeansModelReader extends MLReader[KMeansModel] { /** Checked against metadata when loading model */ @@ -232,7 +257,7 @@ object KMeansModel extends MLReadable[KMeansModel] { val dataPath = new Path(path, "data").toString val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) { - val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data] + val data: Dataset[ClusterData] = sparkSession.read.parquet(dataPath).as[ClusterData] data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML) } else { // Loads KMeansModel stored with the old format used by Spark 1.6 and earlier. diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f67d9d831f327..9cdd3a051e719 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -746,7 +746,7 @@ private class InternalLinearRegressionModelWriter /** A writer for LinearRegression that handles the "pmml" format */ private class PMMLLinearRegressionModelWriter - extends MLWriterFormat with MLFormatRegister { + extends MLWriterFormat with MLFormatRegister { override def format(): String = "pmml" diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 32830b39407ad..77c9d482d95b6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -19,17 +19,22 @@ package org.apache.spark.ml.clustering import scala.util.Random +import org.dmg.pmml.{ClusteringModel, PMML} + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, + KMeansModel => MLlibKMeansModel} +import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} private[clustering] case class TestRow(features: Vector) -class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest + with PMMLReadWriteTest { final val k = 5 @transient var dataset: Dataset[_] = _ @@ -202,6 +207,27 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, KMeansSuite.allParamSettings, checkModelData) } + + test("pmml export") { + val clusterCenters = Array( + MLlibVectors.dense(1.0, 2.0, 6.0), + MLlibVectors.dense(1.0, 3.0, 0.0), + MLlibVectors.dense(1.0, 4.0, 6.0)) + val oldKmeansModel = new MLlibKMeansModel(clusterCenters) + val kmeansModel = new KMeansModel("", oldKmeansModel) + def checkModel(pmml: PMML): Unit = { + // Check the header descripiton is what we expect + assert(pmml.getHeader.getDescription === "k-means clustering") + // check that the number of fields match the single vector size + assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size) + // This verify that there is a model attached to the pmml object and the model is a clustering + // one. It also verifies that the pmml model has the same number of clusters of the spark + // model. + val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel] + assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length) + } + testPMMLWrite(sc, kmeansModel, checkModel) + } } object KMeansSuite { From c8f3ac69d176bd10b8de1c147b6903a247943d51 Mon Sep 17 00:00:00 2001 From: wuyi Date: Mon, 23 Apr 2018 15:35:45 -0500 Subject: [PATCH 0673/2461] [SPARK-23888][CORE] correct the comment of hasAttemptOnHost() TaskSetManager.hasAttemptOnHost had a misleading comment. The comment said that it only checked for running tasks, but really it checked for any tasks that might have run in the past as well. This updates to line up with the implementation. Author: wuyi Closes #20998 from Ngone51/SPARK-23888. --- .../main/scala/org/apache/spark/scheduler/TaskSetManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index d958658527f6d..8a96a7692f614 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -287,7 +287,7 @@ private[spark] class TaskSetManager( None } - /** Check whether a task is currently running an attempt on a given host */ + /** Check whether a task once ran an attempt on a given host */ private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = { taskAttempts(taskIndex).exists(_.host == host) } From 428b903859c3d8873045fdcfffdebe24fc6e027f Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 24 Apr 2018 09:10:29 +0800 Subject: [PATCH 0674/2461] [SPARK-24029][CORE] Follow up: set SO_REUSEADDR on the server socket. "childOption" is for the remote connections, not for the server socket that actually listens for incoming connections. Author: Marcelo Vanzin Closes #21132 from vanzin/SPARK-24029.2. --- .../java/org/apache/spark/network/server/TransportServer.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java index 612750972c4bb..60f51125c07fd 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -99,8 +99,8 @@ private void init(String hostToBind, int portToBind) { .group(bossGroup, workerGroup) .channel(NettyUtils.getServerChannelClass(ioMode)) .option(ChannelOption.ALLOCATOR, allocator) - .childOption(ChannelOption.ALLOCATOR, allocator) - .childOption(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS); + .option(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS) + .childOption(ChannelOption.ALLOCATOR, allocator); this.metrics = new NettyMemoryMetrics( allocator, conf.getModuleName() + "-server", conf); From 281c1ca0dc96b0441a60c32df3d16fbb1c61e99f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 24 Apr 2018 10:11:09 +0800 Subject: [PATCH 0675/2461] [SPARK-23973][SQL] Remove consecutive Sorts ## What changes were proposed in this pull request? In SPARK-23375 we introduced the ability of removing `Sort` operation during query optimization if the data is already sorted. In this follow-up we remove also a `Sort` which is followed by another `Sort`: in this case the first sort is not needed and can be safely removed. The PR starts from henryr's comment: https://github.com/apache/spark/pull/20560#discussion_r180601594. So credit should be given to him. ## How was this patch tested? added UT Author: Marco Gaido Closes #21072 from mgaido91/SPARK-23973. --- .../sql/catalyst/optimizer/Optimizer.scala | 21 +++++++- .../optimizer/RemoveRedundantSortsSuite.scala | 51 ++++++++++++++++--- 2 files changed, 63 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f00d40d11f23f..45f13956a0a85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -767,12 +767,29 @@ object EliminateSorts extends Rule[LogicalPlan] { } /** - * Removes Sort operation if the child is already sorted + * Removes redundant Sort operation. This can happen: + * 1) if the child is already sorted + * 2) if there is another Sort operator separated by 0...n Project/Filter operators */ object RemoveRedundantSorts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) => child + case s @ Sort(_, _, child) => s.copy(child = recursiveRemoveSort(child)) + } + + def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = plan match { + case Sort(_, _, child) => recursiveRemoveSort(child) + case other if canEliminateSort(other) => + other.withNewChildren(other.children.map(recursiveRemoveSort)) + case _ => plan + } + + def canEliminateSort(plan: LogicalPlan): Boolean = plan match { + case p: Project => p.projectList.forall(_.deterministic) + case f: Filter => f.condition.deterministic + case _: ResolvedHint => true + case _ => false } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala index 2319ab8046e56..dae5e6f3ee3dd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala @@ -17,16 +17,12 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} -import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL} class RemoveRedundantSortsSuite extends PlanTest { @@ -42,15 +38,15 @@ class RemoveRedundantSortsSuite extends PlanTest { test("remove redundant order by") { val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) - val unnecessaryReordered = orderedPlan.select('a).orderBy('a.asc, 'b.desc_nullsFirst) + val unnecessaryReordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst) val optimized = Optimize.execute(unnecessaryReordered.analyze) - val correctAnswer = orderedPlan.select('a).analyze + val correctAnswer = orderedPlan.limit(2).select('a).analyze comparePlans(Optimize.execute(optimized), correctAnswer) } test("do not remove sort if the order is different") { val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) - val reorderedDifferently = orderedPlan.select('a).orderBy('a.asc, 'b.desc) + val reorderedDifferently = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc) val optimized = Optimize.execute(reorderedDifferently.analyze) val correctAnswer = reorderedDifferently.analyze comparePlans(optimized, correctAnswer) @@ -72,6 +68,14 @@ class RemoveRedundantSortsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("different sorts are not simplified if limit is in between") { + val orderedPlan = testRelation.select('a, 'b).orderBy('b.desc).limit(Literal(10)) + .orderBy('a.asc) + val optimized = Optimize.execute(orderedPlan.analyze) + val correctAnswer = orderedPlan.analyze + comparePlans(optimized, correctAnswer) + } + test("range is already sorted") { val inputPlan = Range(1L, 1000L, 1, 10) val orderedPlan = inputPlan.orderBy('id.asc) @@ -98,4 +102,37 @@ class RemoveRedundantSortsSuite extends PlanTest { val correctAnswer = groupedAndResorted.analyze comparePlans(optimized, correctAnswer) } + + test("remove two consecutive sorts") { + val orderedTwice = testRelation.orderBy('a.asc).orderBy('b.desc) + val optimized = Optimize.execute(orderedTwice.analyze) + val correctAnswer = testRelation.orderBy('b.desc).analyze + comparePlans(optimized, correctAnswer) + } + + test("remove sorts separated by Filter/Project operators") { + val orderedTwiceWithProject = testRelation.orderBy('a.asc).select('b).orderBy('b.desc) + val optimizedWithProject = Optimize.execute(orderedTwiceWithProject.analyze) + val correctAnswerWithProject = testRelation.select('b).orderBy('b.desc).analyze + comparePlans(optimizedWithProject, correctAnswerWithProject) + + val orderedTwiceWithFilter = + testRelation.orderBy('a.asc).where('b > Literal(0)).orderBy('b.desc) + val optimizedWithFilter = Optimize.execute(orderedTwiceWithFilter.analyze) + val correctAnswerWithFilter = testRelation.where('b > Literal(0)).orderBy('b.desc).analyze + comparePlans(optimizedWithFilter, correctAnswerWithFilter) + + val orderedTwiceWithBoth = + testRelation.orderBy('a.asc).select('b).where('b > Literal(0)).orderBy('b.desc) + val optimizedWithBoth = Optimize.execute(orderedTwiceWithBoth.analyze) + val correctAnswerWithBoth = + testRelation.select('b).where('b > Literal(0)).orderBy('b.desc).analyze + comparePlans(optimizedWithBoth, correctAnswerWithBoth) + + val orderedThrice = orderedTwiceWithBoth.select(('b + 1).as('c)).orderBy('c.asc) + val optimizedThrice = Optimize.execute(orderedThrice.analyze) + val correctAnswerThrice = testRelation.select('b).where('b > Literal(0)) + .select(('b + 1).as('c)).orderBy('c.asc).analyze + comparePlans(optimizedThrice, correctAnswerThrice) + } } From c303b1b6766a3dc5961713f98f62cd7d7ac7972a Mon Sep 17 00:00:00 2001 From: seancxmao Date: Tue, 24 Apr 2018 16:16:07 +0800 Subject: [PATCH 0676/2461] [MINOR][DOCS] Fix comments of SQLExecution#withExecutionId ## What changes were proposed in this pull request? Fix comment. Change `BroadcastHashJoin.broadcastFuture` to `BroadcastExchangeExec.relationFuture`: https://github.com/apache/spark/blob/d28d5732ae205771f1f443b15b10e64dcffb5ff0/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala#L66 ## How was this patch tested? N/A Author: seancxmao Closes #21113 from seancxmao/SPARK-13136. --- .../scala/org/apache/spark/sql/execution/SQLExecution.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index e991da7df0bde..2c5102b1e5ee7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -88,7 +88,7 @@ object SQLExecution { /** * Wrap an action with a known executionId. When running a different action in a different * thread from the original one, this method can be used to connect the Spark jobs in this action - * with the known executionId, e.g., `BroadcastHashJoin.broadcastFuture`. + * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`. */ def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = { val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) From 87e8a572be14381da9081365d9aa2cbf3253a32c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 24 Apr 2018 16:18:20 +0800 Subject: [PATCH 0677/2461] [SPARK-24054][R] Add array_position function / element_at functions ## What changes were proposed in this pull request? This PR proposes to add array_position and element_at in R side too. array_position: ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_array(df$gear, df$am, df$carb)) head(select(mutated, array_position(mutated$v1, 1))) ``` ``` array_position(v1, 1.0) 1 2 2 2 3 2 4 3 5 0 6 3 ``` element_at: ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) head(select(mutated, element_at(mutated$v1, 1))) ``` ``` element_at(v1, 1.0) 1 21.0 2 21.0 3 22.8 4 21.4 5 18.7 6 18.1 ``` ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_map(df$model, df$cyl)) head(select(mutated, element_at(mutated$v1, "Valiant"))) ``` ``` element_at(v3, Valiant) 1 NA 2 NA 3 NA 4 NA 5 NA 6 6 ``` ## How was this patch tested? Unit tests were added in `R/pkg/tests/fulltests/test_sparkSQL.R` and manually tested. Documentation was manually built and verified. Author: hyukjinkwon Closes #21130 from HyukjinKwon/sparkr_array_position_element_at. --- R/pkg/NAMESPACE | 2 ++ R/pkg/R/functions.R | 42 +++++++++++++++++++++++++-- R/pkg/R/generics.R | 8 +++++ R/pkg/tests/fulltests/test_sparkSQL.R | 13 +++++++-- 4 files changed, 61 insertions(+), 4 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 190c50ea10482..55dec177ea853 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -201,6 +201,7 @@ exportMethods("%<=>%", "approxCountDistinct", "approxQuantile", "array_contains", + "array_position", "asc", "ascii", "asin", @@ -245,6 +246,7 @@ exportMethods("%<=>%", "decode", "dense_rank", "desc", + "element_at", "encode", "endsWith", "exp", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index a527426b19674..7b3aa05074563 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -189,6 +189,11 @@ NULL #' the map or array of maps. #' \item \code{from_json}: it is the column containing the JSON string. #' } +#' @param value A value to compute on. +#' \itemize{ +#' \item \code{array_contains}: a value to be checked if contained in the column. +#' \item \code{array_position}: a value to locate in the given array. +#' } #' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains #' additional named properties to control how it is converted, accepts the same #' options as the JSON data source. @@ -201,6 +206,7 @@ NULL #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) +#' head(select(tmp, array_position(tmp$v1, 21))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) @@ -208,7 +214,8 @@ NULL #' head(select(tmp, sort_array(tmp$v1, asc = FALSE))) #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl)) #' head(select(tmp3, map_keys(tmp3$v3))) -#' head(select(tmp3, map_values(tmp3$v3)))} +#' head(select(tmp3, map_values(tmp3$v3))) +#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))} NULL #' Window functions for Column operations @@ -2975,7 +2982,6 @@ setMethod("row_number", #' \code{array_contains}: Returns null if the array is null, true if the array contains #' the value, and false otherwise. #' -#' @param value a value to be checked if contained in the column #' @rdname column_collection_functions #' @aliases array_contains array_contains,Column-method #' @note array_contains since 1.6.0 @@ -2986,6 +2992,22 @@ setMethod("array_contains", column(jc) }) +#' @details +#' \code{array_position}: Locates the position of the first occurrence of the given value +#' in the given array. Returns NA if either of the arguments are NA. +#' Note: The position is not zero based, but 1 based index. Returns 0 if the given +#' value could not be found in the array. +#' +#' @rdname column_collection_functions +#' @aliases array_position array_position,Column-method +#' @note array_position since 2.4.0 +setMethod("array_position", + signature(x = "Column", value = "ANY"), + function(x, value) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_position", x@jc, value) + column(jc) + }) + #' @details #' \code{map_keys}: Returns an unordered array containing the keys of the map. #' @@ -3012,6 +3034,22 @@ setMethod("map_values", column(jc) }) +#' @details +#' \code{element_at}: Returns element of array at given index in \code{extraction} if +#' \code{x} is array. Returns value for the given key in \code{extraction} if \code{x} is map. +#' Note: The position is not zero based, but 1 based index. +#' +#' @param extraction index to check for in array or key to check for in map +#' @rdname column_collection_functions +#' @aliases element_at element_at,Column-method +#' @note element_at since 2.4.0 +setMethod("element_at", + signature(x = "Column", extraction = "ANY"), + function(x, extraction) { + jc <- callJStatic("org.apache.spark.sql.functions", "element_at", x@jc, extraction) + column(jc) + }) + #' @details #' \code{explode}: Creates a new row for each element in the given array or map column. #' diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 974beff1a3d76..f30ac9e4295e4 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -757,6 +757,10 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) + #' @rdname column_string_functions #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -886,6 +890,10 @@ setGeneric("decode", function(x, charset) { standardGeneric("decode") }) #' @name NULL setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("element_at", function(x, extraction) { standardGeneric("element_at") }) + #' @rdname column_string_functions #' @name NULL setGeneric("encode", function(x, charset) { standardGeneric("encode") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 7105469ffc242..a384997830276 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1479,17 +1479,23 @@ test_that("column functions", { df5 <- createDataFrame(list(list(a = "010101"))) expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") - # Test array_contains() and sort_array() + # Test array_contains(), array_position(), element_at() and sort_array() df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] expect_equal(result, c(TRUE, FALSE)) + result <- collect(select(df, array_position(df[[1]], 1L)))[[1]] + expect_equal(result, c(1, 0)) + + result <- collect(select(df, element_at(df[[1]], 1L)))[[1]] + expect_equal(result, c(1, 6)) + result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]] expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L))) result <- collect(select(df, sort_array(df[[1]])))[[1]] expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) - # Test map_keys() and map_values() + # Test map_keys(), map_values() and element_at() df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) result <- collect(select(df, map_keys(df$map)))[[1]] expect_equal(result, list(list("x", "y"))) @@ -1497,6 +1503,9 @@ test_that("column functions", { result <- collect(select(df, map_values(df$map)))[[1]] expect_equal(result, list(list(1, 2))) + result <- collect(select(df, element_at(df$map, "y")))[[1]] + expect_equal(result, 2) + # Test that stats::lag is working expect_equal(length(lag(ldeaths, 12)), 72) From 4926a7c2f0a47b562f99dbb4f1ca17adb3192061 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 24 Apr 2018 17:52:05 +0200 Subject: [PATCH 0678/2461] [SPARK-23589][SQL][FOLLOW-UP] Reuse InternalRow in ExternalMapToCatalyst eval ## What changes were proposed in this pull request? This pr is a follow-up of #20980 and fixes code to reuse `InternalRow` for converting input keys/values in `ExternalMapToCatalyst` eval. ## How was this patch tested? Existing tests. Author: Takeshi Yamamuro Closes #21137 from maropu/SPARK-23589-FOLLOWUP. --- .../expressions/objects/objects.scala | 92 ++++++++++--------- 1 file changed, 50 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 9c7e76467d153..f974fd81fc788 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1255,53 +1255,61 @@ case class ExternalMapToCatalyst private( override def dataType: MapType = MapType( keyConverter.dataType, valueConverter.dataType, valueContainsNull = valueConverter.nullable) - private lazy val mapCatalystConverter: Any => (Array[Any], Array[Any]) = child.dataType match { - case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => - (input: Any) => { - val data = input.asInstanceOf[java.util.Map[Any, Any]] - val keys = new Array[Any](data.size) - val values = new Array[Any](data.size) - val iter = data.entrySet().iterator() - var i = 0 - while (iter.hasNext) { - val entry = iter.next() - val (key, value) = (entry.getKey, entry.getValue) - keys(i) = if (key != null) { - keyConverter.eval(InternalRow.fromSeq(key :: Nil)) - } else { - throw new RuntimeException("Cannot use null as map key!") - } - values(i) = if (value != null) { - valueConverter.eval(InternalRow.fromSeq(value :: Nil)) - } else { - null + private lazy val mapCatalystConverter: Any => (Array[Any], Array[Any]) = { + val rowBuffer = InternalRow.fromSeq(Array[Any](1)) + def rowWrapper(data: Any): InternalRow = { + rowBuffer.update(0, data) + rowBuffer + } + + child.dataType match { + case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[java.util.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + val iter = data.entrySet().iterator() + var i = 0 + while (iter.hasNext) { + val entry = iter.next() + val (key, value) = (entry.getKey, entry.getValue) + keys(i) = if (key != null) { + keyConverter.eval(rowWrapper(key)) + } else { + throw new RuntimeException("Cannot use null as map key!") + } + values(i) = if (value != null) { + valueConverter.eval(rowWrapper(value)) + } else { + null + } + i += 1 } - i += 1 + (keys, values) } - (keys, values) - } - case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => - (input: Any) => { - val data = input.asInstanceOf[scala.collection.Map[Any, Any]] - val keys = new Array[Any](data.size) - val values = new Array[Any](data.size) - var i = 0 - for ((key, value) <- data) { - keys(i) = if (key != null) { - keyConverter.eval(InternalRow.fromSeq(key :: Nil)) - } else { - throw new RuntimeException("Cannot use null as map key!") - } - values(i) = if (value != null) { - valueConverter.eval(InternalRow.fromSeq(value :: Nil)) - } else { - null + case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[scala.collection.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + var i = 0 + for ((key, value) <- data) { + keys(i) = if (key != null) { + keyConverter.eval(rowWrapper(key)) + } else { + throw new RuntimeException("Cannot use null as map key!") + } + values(i) = if (value != null) { + valueConverter.eval(rowWrapper(value)) + } else { + null + } + i += 1 } - i += 1 + (keys, values) } - (keys, values) - } + } } override def eval(input: InternalRow): Any = { From 55c4ca88a3b093ee197a8689631be8d1fac1f10f Mon Sep 17 00:00:00 2001 From: Julien Cuquemelle Date: Tue, 24 Apr 2018 10:56:55 -0500 Subject: [PATCH 0679/2461] [SPARK-22683][CORE] Add a executorAllocationRatio parameter to throttle the parallelism of the dynamic allocation ## What changes were proposed in this pull request? By default, the dynamic allocation will request enough executors to maximize the parallelism according to the number of tasks to process. While this minimizes the latency of the job, with small tasks this setting can waste a lot of resources due to executor allocation overhead, as some executor might not even do any work. This setting allows to set a ratio that will be used to reduce the number of target executors w.r.t. full parallelism. The number of executors computed with this setting is still fenced by `spark.dynamicAllocation.maxExecutors` and `spark.dynamicAllocation.minExecutors` ## How was this patch tested? Units tests and runs on various actual workloads on a Yarn Cluster Author: Julien Cuquemelle Closes #19881 from jcuquemelle/AddTaskPerExecutorSlot. --- .../spark/ExecutorAllocationManager.scala | 24 +++++++++++--- .../spark/internal/config/package.scala | 4 +++ .../ExecutorAllocationManagerSuite.scala | 33 +++++++++++++++++++ docs/configuration.md | 18 ++++++++++ 4 files changed, 74 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 189d91333c045..aa363eeffffb8 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -26,7 +26,7 @@ import scala.util.control.{ControlThrowable, NonFatal} import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.{DYN_ALLOCATION_MAX_EXECUTORS, DYN_ALLOCATION_MIN_EXECUTORS} +import org.apache.spark.internal.config._ import org.apache.spark.metrics.source.Source import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMaster @@ -69,6 +69,10 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} * spark.dynamicAllocation.maxExecutors - Upper bound on the number of executors * spark.dynamicAllocation.initialExecutors - Number of executors to start with * + * spark.dynamicAllocation.executorAllocationRatio - + * This is used to reduce the parallelism of the dynamic allocation that can waste + * resources when tasks are small + * * spark.dynamicAllocation.schedulerBacklogTimeout (M) - * If there are backlogged tasks for this duration, add new executors * @@ -116,9 +120,12 @@ private[spark] class ExecutorAllocationManager( // TODO: The default value of 1 for spark.executor.cores works right now because dynamic // allocation is only supported for YARN and the default number of cores per executor in YARN is // 1, but it might need to be attained differently for different cluster managers - private val tasksPerExecutor = + private val tasksPerExecutorForFullParallelism = conf.getInt("spark.executor.cores", 1) / conf.getInt("spark.task.cpus", 1) + private val executorAllocationRatio = + conf.get(DYN_ALLOCATION_EXECUTOR_ALLOCATION_RATIO) + validateSettings() // Number of executors to add in the next round @@ -209,8 +216,13 @@ private[spark] class ExecutorAllocationManager( throw new SparkException("Dynamic allocation of executors requires the external " + "shuffle service. You may enable this through spark.shuffle.service.enabled.") } - if (tasksPerExecutor == 0) { - throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.") + if (tasksPerExecutorForFullParallelism == 0) { + throw new SparkException("spark.executor.cores must not be < spark.task.cpus.") + } + + if (executorAllocationRatio > 1.0 || executorAllocationRatio <= 0.0) { + throw new SparkException( + "spark.dynamicAllocation.executorAllocationRatio must be > 0 and <= 1.0") } } @@ -273,7 +285,9 @@ private[spark] class ExecutorAllocationManager( */ private def maxNumExecutorsNeeded(): Int = { val numRunningOrPendingTasks = listener.totalPendingTasks + listener.totalRunningTasks - (numRunningOrPendingTasks + tasksPerExecutor - 1) / tasksPerExecutor + math.ceil(numRunningOrPendingTasks * executorAllocationRatio / + tasksPerExecutorForFullParallelism) + .toInt } private def totalRunningTasks(): Int = synchronized { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 99d779fb600e8..6bb98c37b4479 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -126,6 +126,10 @@ package object config { private[spark] val DYN_ALLOCATION_MAX_EXECUTORS = ConfigBuilder("spark.dynamicAllocation.maxExecutors").intConf.createWithDefault(Int.MaxValue) + private[spark] val DYN_ALLOCATION_EXECUTOR_ALLOCATION_RATIO = + ConfigBuilder("spark.dynamicAllocation.executorAllocationRatio") + .doubleConf.createWithDefault(1.0) + private[spark] val LOCALITY_WAIT = ConfigBuilder("spark.locality.wait") .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("3s") diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 9807d1269e3d4..3cfb0a9feb32b 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -145,6 +145,39 @@ class ExecutorAllocationManagerSuite assert(numExecutorsToAdd(manager) === 1) } + def testAllocationRatio(cores: Int, divisor: Double, expected: Int): Unit = { + val conf = new SparkConf() + .setMaster("myDummyLocalExternalClusterManager") + .setAppName("test-executor-allocation-manager") + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.dynamicAllocation.testing", "true") + .set("spark.dynamicAllocation.maxExecutors", "15") + .set("spark.dynamicAllocation.minExecutors", "3") + .set("spark.dynamicAllocation.executorAllocationRatio", divisor.toString) + .set("spark.executor.cores", cores.toString) + val sc = new SparkContext(conf) + contexts += sc + var manager = sc.executorAllocationManager.get + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 20))) + for (i <- 0 to 5) { + addExecutors(manager) + } + assert(numExecutorsTarget(manager) === expected) + sc.stop() + } + + test("executionAllocationRatio is correctly handled") { + testAllocationRatio(1, 0.5, 10) + testAllocationRatio(1, 1.0/3.0, 7) + testAllocationRatio(2, 1.0/3.0, 4) + testAllocationRatio(1, 0.385, 8) + + // max/min executors capping + testAllocationRatio(1, 1.0, 15) // should be 20 but capped by max + testAllocationRatio(4, 1.0/3.0, 3) // should be 2 but elevated by min + } + + test("add executors capped by num pending tasks") { sc = createSparkContext(0, 10, 0) val manager = sc.executorAllocationManager.get diff --git a/docs/configuration.md b/docs/configuration.md index 4d4d0c58dd07d..fb02d7ea1d4ea 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1753,6 +1753,7 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.minExecutors, spark.dynamicAllocation.maxExecutors, and spark.dynamicAllocation.initialExecutors + spark.dynamicAllocation.executorAllocationRatio
    spark.dynamicAllocation.executorAllocationRatio1 + By default, the dynamic allocation will request enough executors to maximize the + parallelism according to the number of tasks to process. While this minimizes the + latency of the job, with small tasks this setting can waste a lot of resources due to + executor allocation overhead, as some executor might not even do any work. + This setting allows to set a ratio that will be used to reduce the number of + executors w.r.t. full parallelism. + Defaults to 1.0 to give maximum parallelism. + 0.5 will divide the target number of executors by 2 + The target number of executors computed by the dynamicAllocation can still be overriden + by the spark.dynamicAllocation.minExecutors and + spark.dynamicAllocation.maxExecutors settings +
    spark.dynamicAllocation.schedulerBacklogTimeout 1s
    Property NameDefaultMeaning
    spark.sql.orc.implhivenative The name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4. `hive` means the ORC library in Hive 1.2.1.
    spark.yarn.am.waitTime 100s - In cluster mode, time for the YARN Application Master to wait for the - SparkContext to be initialized. In client mode, time for the YARN Application Master to wait - for the driver to connect to it. + Only used in cluster mode. Time for the YARN Application Master to wait for the + SparkContext to be initialized.
    1.2.1 Version of the Hive metastore. Available - options are 0.12.0 through 2.3.2. + options are 0.12.0 through 2.3.3.
    queryTimeout + The number of seconds the driver will wait for a Statement object to execute to the given + number of seconds. Zero means there is no limit. In the write path, this option depends on + how JDBC drivers implement the API setQueryTimeout, e.g., the h2 JDBC driver + checks the timeout of each query instead of an entire JDBC batch. + It defaults to 0. +
    fetchsize diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 53f44888ebaff..917f0cb221412 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -257,7 +257,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @param connectionProperties JDBC database connection arguments, a list of arbitrary string * tag/value. Normally at least a "user" and "password" property * should be included. "fetchsize" can be used to control the - * number of rows per fetch. + * number of rows per fetch and "queryTimeout" can be used to wait + * for a Statement object to execute to the given number of seconds. * @since 1.4.0 */ def jdbc( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index b4e5d169066d9..a73a97c06fe5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -89,6 +89,10 @@ class JDBCOptions( // the number of partitions val numPartitions = parameters.get(JDBC_NUM_PARTITIONS).map(_.toInt) + // the number of seconds the driver will wait for a Statement object to execute to the given + // number of seconds. Zero means there is no limit. + val queryTimeout = parameters.getOrElse(JDBC_QUERY_TIMEOUT, "0").toInt + // ------------------------------------------------------------ // Optional parameters only for reading // ------------------------------------------------------------ @@ -160,6 +164,7 @@ object JDBCOptions { val JDBC_LOWER_BOUND = newOption("lowerBound") val JDBC_UPPER_BOUND = newOption("upperBound") val JDBC_NUM_PARTITIONS = newOption("numPartitions") + val JDBC_QUERY_TIMEOUT = newOption("queryTimeout") val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize") val JDBC_TRUNCATE = newOption("truncate") val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 05326210f3242..0bab3689e5d0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -57,6 +57,7 @@ object JDBCRDD extends Logging { try { val statement = conn.prepareStatement(dialect.getSchemaQuery(table)) try { + statement.setQueryTimeout(options.queryTimeout) val rs = statement.executeQuery() try { JdbcUtils.getSchema(rs, dialect, alwaysNullable = true) @@ -281,6 +282,7 @@ private[jdbc] class JDBCRDD( val statement = conn.prepareStatement(sql) logInfo(s"Executing sessionInitStatement: $sql") try { + statement.setQueryTimeout(options.queryTimeout) statement.execute() } finally { statement.close() @@ -298,6 +300,7 @@ private[jdbc] class JDBCRDD( stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) stmt.setFetchSize(options.fetchSize) + stmt.setQueryTimeout(options.queryTimeout) rs = stmt.executeQuery() val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index cc506e51bd0c6..f8c5677ea0f2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -73,7 +73,7 @@ class JdbcRelationProvider extends CreatableRelationProvider saveTable(df, tableSchema, isCaseSensitive, options) } else { // Otherwise, do not truncate the table, instead drop and recreate it - dropTable(conn, options.table) + dropTable(conn, options.table, options) createTable(conn, df, options) saveTable(df, Some(df.schema), isCaseSensitive, options) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index e6dc2fda4eb1b..433443007cfd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -76,6 +76,7 @@ object JdbcUtils extends Logging { Try { val statement = conn.prepareStatement(dialect.getTableExistsQuery(options.table)) try { + statement.setQueryTimeout(options.queryTimeout) statement.executeQuery() } finally { statement.close() @@ -86,9 +87,10 @@ object JdbcUtils extends Logging { /** * Drops a table from the JDBC database. */ - def dropTable(conn: Connection, table: String): Unit = { + def dropTable(conn: Connection, table: String, options: JDBCOptions): Unit = { val statement = conn.createStatement try { + statement.setQueryTimeout(options.queryTimeout) statement.executeUpdate(s"DROP TABLE $table") } finally { statement.close() @@ -102,6 +104,7 @@ object JdbcUtils extends Logging { val dialect = JdbcDialects.get(options.url) val statement = conn.createStatement try { + statement.setQueryTimeout(options.queryTimeout) statement.executeUpdate(dialect.getTruncateQuery(options.table)) } finally { statement.close() @@ -254,6 +257,7 @@ object JdbcUtils extends Logging { try { val statement = conn.prepareStatement(dialect.getSchemaQuery(options.table)) try { + statement.setQueryTimeout(options.queryTimeout) Some(getSchema(statement.executeQuery(), dialect)) } catch { case _: SQLException => None @@ -596,7 +600,8 @@ object JdbcUtils extends Logging { insertStmt: String, batchSize: Int, dialect: JdbcDialect, - isolationLevel: Int): Iterator[Byte] = { + isolationLevel: Int, + options: JDBCOptions): Iterator[Byte] = { val conn = getConnection() var committed = false @@ -637,6 +642,9 @@ object JdbcUtils extends Logging { try { var rowCount = 0 + + stmt.setQueryTimeout(options.queryTimeout) + while (iterator.hasNext) { val row = iterator.next() var i = 0 @@ -819,7 +827,8 @@ object JdbcUtils extends Logging { case _ => df } repartitionedDF.rdd.foreachPartition(iterator => savePartition( - getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel) + getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, + options) ) } @@ -841,6 +850,7 @@ object JdbcUtils extends Logging { val sql = s"CREATE TABLE $table ($strSchema) $createTableOptions" val statement = conn.createStatement try { + statement.setQueryTimeout(options.queryTimeout) statement.executeUpdate(sql) } finally { statement.close() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5238adce4a699..bc2aca65e803f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1190,4 +1190,20 @@ class JDBCSuite extends SparkFunSuite assert(sql("select * from people_view").schema === schema) } } + + test("SPARK-23856 Spark jdbc setQueryTimeout option") { + val numJoins = 100 + val longRunningQuery = + s"SELECT t0.NAME AS c0, ${(1 to numJoins).map(i => s"t$i.NAME AS c$i").mkString(", ")} " + + s"FROM test.people t0 ${(1 to numJoins).map(i => s"join test.people t$i").mkString(" ")}" + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("dbtable", s"($longRunningQuery)") + .option("queryTimeout", 1) + .load() + val errMsg = intercept[SparkException] { + df.collect() + }.getMessage + assert(errMsg.contains("Statement was canceled or the session timed out")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 1985b1dc82879..1c2c92d1f0737 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -515,4 +515,22 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { }.getMessage assert(e.contains("NULL not allowed for column \"NAME\"")) } + + ignore("SPARK-23856 Spark jdbc setQueryTimeout option") { + // The behaviour of the option `queryTimeout` depends on how JDBC drivers implement the API + // `setQueryTimeout`. For example, in the h2 JDBC driver, `executeBatch` invokes multiple + // INSERT queries in a batch and `setQueryTimeout` means that the driver checks the timeout + // of each query. In the PostgreSQL JDBC driver, `setQueryTimeout` means that the driver + // checks the timeout of an entire batch in a driver side. So, the test below fails because + // this test suite depends on the h2 JDBC driver and the JDBC write path internally + // uses `executeBatch`. + val errMsg = intercept[SparkException] { + spark.range(10000000L).selectExpr("id AS k", "id AS v").coalesce(1).write + .mode(SaveMode.Overwrite) + .option("queryTimeout", 1) + .option("batchsize", Int.MaxValue) + .jdbc(url1, "TEST.TIMEOUTTEST", properties) + }.getMessage + assert(errMsg.contains("Statement was canceled or the session timed out")) + } } From 710e4e81a8efc1aacc14283fb57bc8786146f885 Mon Sep 17 00:00:00 2001 From: Arun Mahadevan Date: Fri, 18 May 2018 14:37:01 -0700 Subject: [PATCH 0831/2461] [SPARK-24308][SQL] Handle DataReaderFactory to InputPartition rename in left over classes ## What changes were proposed in this pull request? SPARK-24073 renames DataReaderFactory -> InputPartition and DataReader -> InputPartitionReader. Some classes still reflects the old name and causes confusion. This patch renames the left over classes to reflect the new interface and fixes a few comments. ## How was this patch tested? Existing unit tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Arun Mahadevan Closes #21355 from arunmahadevan/SPARK-24308. --- .../spark/sql/kafka010/KafkaContinuousReader.scala | 6 +++--- .../spark/sql/kafka010/KafkaMicroBatchReader.scala | 4 ++-- .../sql/kafka010/KafkaMicroBatchSourceSuite.scala | 2 +- .../sources/v2/reader/ContinuousInputPartition.java | 4 ++-- .../spark/sql/sources/v2/reader/InputPartition.java | 6 +++--- .../sql/sources/v2/reader/InputPartitionReader.java | 6 +++--- .../sql/execution/datasources/v2/DataSourceRDD.scala | 6 +++--- .../continuous/ContinuousRateStreamSource.scala | 4 ++-- .../spark/sql/execution/streaming/memory.scala | 4 ++-- .../streaming/sources/ContinuousMemoryStream.scala | 12 ++++++------ .../sources/RateStreamMicroBatchReader.scala | 4 ++-- .../streaming/sources/RateStreamProviderSuite.scala | 2 +- 12 files changed, 30 insertions(+), 30 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 88abf8a8dd027..badaa69cc303c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -106,7 +106,7 @@ class KafkaContinuousReader( startOffsets.toSeq.map { case (topicPartition, start) => - KafkaContinuousDataReaderFactory( + KafkaContinuousInputPartition( topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) .asInstanceOf[InputPartition[UnsafeRow]] }.asJava @@ -146,7 +146,7 @@ class KafkaContinuousReader( } /** - * A data reader factory for continuous Kafka processing. This will be serialized and transformed + * An input partition for continuous Kafka processing. This will be serialized and transformed * into a full reader on executors. * * @param topicPartition The (topic, partition) pair this task is responsible for. @@ -156,7 +156,7 @@ class KafkaContinuousReader( * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets * are skipped. */ -case class KafkaContinuousDataReaderFactory( +case class KafkaContinuousInputPartition( topicPartition: TopicPartition, startOffset: Long, kafkaParams: ju.Map[String, Object], diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index 8a377738ea782..64ba98762788c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -143,7 +143,7 @@ private[kafka010] class KafkaMicroBatchReader( // Generate factories based on the offset ranges val factories = offsetRanges.map { range => - new KafkaMicroBatchDataReaderFactory( + new KafkaMicroBatchInputPartition( range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) } factories.map(_.asInstanceOf[InputPartition[UnsafeRow]]).asJava @@ -300,7 +300,7 @@ private[kafka010] class KafkaMicroBatchReader( } /** A [[InputPartition]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] case class KafkaMicroBatchDataReaderFactory( +private[kafka010] case class KafkaMicroBatchInputPartition( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 871f9700cd1db..c6412eac97dba 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -679,7 +679,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L))) ) val factories = reader.planUnsafeInputPartitions().asScala - .map(_.asInstanceOf[KafkaMicroBatchDataReaderFactory]) + .map(_.asInstanceOf[KafkaMicroBatchInputPartition]) withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") { assert(factories.size == numPartitionsGenerated) factories.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java index c24f3b21eade1..dcb87715d0b6f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java @@ -27,9 +27,9 @@ @InterfaceStability.Evolving public interface ContinuousInputPartition extends InputPartition { /** - * Create a DataReader with particular offset as its startOffset. + * Create an input partition reader with particular offset as its startOffset. * - * @param offset offset want to set as the DataReader's startOffset. + * @param offset offset want to set as the input partition reader's startOffset. */ InputPartitionReader createContinuousReader(PartitionOffset offset); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java index 3524481784fea..f53687e113ae0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java @@ -36,8 +36,8 @@ public interface InputPartition extends Serializable { /** - * The preferred locations where the data reader returned by this partition can run faster, - * but Spark does not guarantee to run the data reader on these locations. + * The preferred locations where the input partition reader returned by this partition can run faster, + * but Spark does not guarantee to run the input partition reader on these locations. * The implementations should make sure that it can be run on any location. * The location is a string representing the host name. * @@ -53,7 +53,7 @@ default String[] preferredLocations() { } /** - * Returns a data reader to do the actual reading work. + * Returns an input partition reader to do the actual reading work. * * If this method fails (by throwing an exception), the corresponding Spark task would fail and * get retried until hitting the maximum retry times. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java index 1b7051f1ad0af..f0d808536207a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java @@ -23,11 +23,11 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data reader returned by {@link InputPartition#createPartitionReader()} and is responsible for + * An input partition reader returned by {@link InputPartition#createPartitionReader()} and is responsible for * outputting data for a RDD partition. * - * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data - * source readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for data source + * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal input + * partition readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for input partition * readers that mix in {@link SupportsScanUnsafeRow}. */ @InterfaceStability.Evolving diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 1a6b32429313a..8d6fb3820d420 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -29,12 +29,12 @@ class DataSourceRDDPartition[T : ClassTag](val index: Int, val inputPartition: I class DataSourceRDD[T: ClassTag]( sc: SparkContext, - @transient private val readerFactories: Seq[InputPartition[T]]) + @transient private val inputPartitions: Seq[InputPartition[T]]) extends RDD[T](sc, Nil) { override protected def getPartitions: Array[Partition] = { - readerFactories.zipWithIndex.map { - case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) + inputPartitions.zipWithIndex.map { + case (inputPartition, index) => new DataSourceRDDPartition(index, inputPartition) }.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 8d25d9ccc43d3..516a563bdcc7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -85,7 +85,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) val start = partitionStartMap(i) // Have each partition advance by numPartitions each row, with starting points staggered // by their partition index. - RateStreamContinuousDataReaderFactory( + RateStreamContinuousInputPartition( start.value, start.runTimeMs, i, @@ -113,7 +113,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) } -case class RateStreamContinuousDataReaderFactory( +case class RateStreamContinuousInputPartition( startValue: Long, startTimeMs: Long, partitionIndex: Int, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index daa2963220aef..b137f98045c5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -156,7 +156,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) newBlocks.map { block => - new MemoryStreamDataReaderFactory(block).asInstanceOf[InputPartition[UnsafeRow]] + new MemoryStreamInputPartition(block).asInstanceOf[InputPartition[UnsafeRow]] }.asJava } } @@ -201,7 +201,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } -class MemoryStreamDataReaderFactory(records: Array[UnsafeRow]) +class MemoryStreamInputPartition(records: Array[UnsafeRow]) extends InputPartition[UnsafeRow] { override def createPartitionReader(): InputPartitionReader[UnsafeRow] = { new InputPartitionReader[UnsafeRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index 4daafa65850de..d1c3498450096 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -44,8 +44,8 @@ import org.apache.spark.util.RpcUtils * * ContinuousMemoryStream maintains a list of records for each partition. addData() will * distribute records evenly-ish across partitions. * * RecordEndpoint is set up as an endpoint for executor-side - * ContinuousMemoryStreamDataReader instances to poll. It returns the record at the specified - * offset within the list, or null if that offset doesn't yet have a record. + * ContinuousMemoryStreamInputPartitionReader instances to poll. It returns the record at + * the specified offset within the list, or null if that offset doesn't yet have a record. */ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2) extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { @@ -106,7 +106,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa startOffset.partitionNums.map { case (part, index) => - new ContinuousMemoryStreamDataReaderFactory( + new ContinuousMemoryStreamInputPartition( endpointName, part, index): InputPartition[Row] }.toList.asJava } @@ -157,9 +157,9 @@ object ContinuousMemoryStream { } /** - * Data reader factory for continuous memory stream. + * An input partition for continuous memory stream. */ -class ContinuousMemoryStreamDataReaderFactory( +class ContinuousMemoryStreamInputPartition( driverEndpointName: String, partition: Int, startOffset: Int) extends InputPartition[Row] { @@ -168,7 +168,7 @@ class ContinuousMemoryStreamDataReaderFactory( } /** - * Data reader for continuous memory stream. + * An input partition reader for continuous memory stream. * * Polls the driver endpoint for new records. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index 723cc3ad5bb89..fbff8db987110 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -167,7 +167,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: } (0 until numPartitions).map { p => - new RateStreamMicroBatchDataReaderFactory( + new RateStreamMicroBatchInputPartition( p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) : InputPartition[Row] }.toList.asJava @@ -182,7 +182,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" } -class RateStreamMicroBatchDataReaderFactory( +class RateStreamMicroBatchInputPartition( partitionId: Int, numPartitions: Int, rangeStart: Long, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index 39a010f970ce5..bf72e5c99689f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -309,7 +309,7 @@ class RateSourceSuite extends StreamTest { val data = scala.collection.mutable.ListBuffer[Row]() tasks.asScala.foreach { - case t: RateStreamContinuousDataReaderFactory => + case t: RateStreamContinuousInputPartition => val startTimeMs = reader.getStartOffset() .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) From 434d74e337465d77fa49ab65e2b5461e5ff7b5c7 Mon Sep 17 00:00:00 2001 From: Efim Poberezkin Date: Fri, 18 May 2018 16:54:39 -0700 Subject: [PATCH 0832/2461] [SPARK-23503][SS] Enforce sequencing of committed epochs for Continuous Execution ## What changes were proposed in this pull request? Made changes to EpochCoordinator so that it enforces a commit order. In case a message for epoch n is lost and epoch (n + 1) is ready for commit before epoch n is, epoch (n + 1) will wait for epoch n to be committed first. ## How was this patch tested? Existing tests in ContinuousSuite and EpochCoordinatorSuite. Author: Efim Poberezkin Closes #20936 from efimpoberezkin/pr/sequence-commited-epochs. --- .../continuous/EpochCoordinator.scala | 69 +++++++++++++++---- .../continuous/EpochCoordinatorSuite.scala | 6 +- 2 files changed, 58 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index cc6808065c0cd..8877ebeb26735 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -137,30 +137,71 @@ private[continuous] class EpochCoordinator( private val partitionOffsets = mutable.Map[(Long, Int), PartitionOffset]() + private var lastCommittedEpoch = startEpoch - 1 + // Remembers epochs that have to wait for previous epochs to be committed first. + private val epochsWaitingToBeCommitted = mutable.HashSet.empty[Long] + private def resolveCommitsAtEpoch(epoch: Long) = { - val thisEpochCommits = - partitionCommits.collect { case ((e, _), msg) if e == epoch => msg } + val thisEpochCommits = findPartitionCommitsForEpoch(epoch) val nextEpochOffsets = partitionOffsets.collect { case ((e, _), o) if e == epoch => o } if (thisEpochCommits.size == numWriterPartitions && nextEpochOffsets.size == numReaderPartitions) { - logDebug(s"Epoch $epoch has received commits from all partitions. Committing globally.") - // Sequencing is important here. We must commit to the writer before recording the commit - // in the query, or we will end up dropping the commit if we restart in the middle. - writer.commit(epoch, thisEpochCommits.toArray) - query.commit(epoch) - - // Cleanup state from before this epoch, now that we know all partitions are forever past it. - for (k <- partitionCommits.keys.filter { case (e, _) => e < epoch }) { - partitionCommits.remove(k) - } - for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) { - partitionOffsets.remove(k) + + // Check that last committed epoch is the previous one for sequencing of committed epochs. + // If not, add the epoch being currently processed to epochs waiting to be committed, + // otherwise commit it. + if (lastCommittedEpoch != epoch - 1) { + logDebug(s"Epoch $epoch has received commits from all partitions " + + s"and is waiting for epoch ${epoch - 1} to be committed first.") + epochsWaitingToBeCommitted.add(epoch) + } else { + commitEpoch(epoch, thisEpochCommits) + lastCommittedEpoch = epoch + + // Commit subsequent epochs that are waiting to be committed. + var nextEpoch = lastCommittedEpoch + 1 + while (epochsWaitingToBeCommitted.contains(nextEpoch)) { + val nextEpochCommits = findPartitionCommitsForEpoch(nextEpoch) + commitEpoch(nextEpoch, nextEpochCommits) + + epochsWaitingToBeCommitted.remove(nextEpoch) + lastCommittedEpoch = nextEpoch + nextEpoch += 1 + } + + // Cleanup state from before last committed epoch, + // now that we know all partitions are forever past it. + for (k <- partitionCommits.keys.filter { case (e, _) => e < lastCommittedEpoch }) { + partitionCommits.remove(k) + } + for (k <- partitionOffsets.keys.filter { case (e, _) => e < lastCommittedEpoch }) { + partitionOffsets.remove(k) + } } } } + /** + * Collect per-partition commits for an epoch. + */ + private def findPartitionCommitsForEpoch(epoch: Long): Iterable[WriterCommitMessage] = { + partitionCommits.collect { case ((e, _), msg) if e == epoch => msg } + } + + /** + * Commit epoch to the offset log. + */ + private def commitEpoch(epoch: Long, messages: Iterable[WriterCommitMessage]): Unit = { + logDebug(s"Epoch $epoch has received commits from all partitions " + + s"and is ready to be committed. Committing epoch $epoch.") + // Sequencing is important here. We must commit to the writer before recording the commit + // in the query, or we will end up dropping the commit if we restart in the middle. + writer.commit(epoch, messages.toArray) + query.commit(epoch) + } + override def receive: PartialFunction[Any, Unit] = { // If we just drop these messages, we won't do any writes to the query. The lame duck tasks // won't shed errors or anything. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala index 99e30561f81d5..82836dced9df7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -120,7 +120,7 @@ class EpochCoordinatorSuite verifyCommitsInOrderOf(List(1, 2)) } - ignore("consequent epochs, a message for epoch k arrives after messages for epoch (k + 1)") { + test("consequent epochs, a message for epoch k arrives after messages for epoch (k + 1)") { setWriterPartitions(2) setReaderPartitions(2) @@ -141,7 +141,7 @@ class EpochCoordinatorSuite verifyCommitsInOrderOf(List(1, 2)) } - ignore("several epochs, messages arrive in order 1 -> 3 -> 4 -> 2") { + test("several epochs, messages arrive in order 1 -> 3 -> 4 -> 2") { setWriterPartitions(1) setReaderPartitions(1) @@ -162,7 +162,7 @@ class EpochCoordinatorSuite verifyCommitsInOrderOf(List(1, 2, 3, 4)) } - ignore("several epochs, messages arrive in order 1 -> 3 -> 5 -> 4 -> 2") { + test("several epochs, messages arrive in order 1 -> 3 -> 5 -> 4 -> 2") { setWriterPartitions(1) setReaderPartitions(1) From dd37529a8dada6ed8a49b8ce50875268f6a20cba Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 19 May 2018 18:51:02 +0800 Subject: [PATCH 0833/2461] [SPARK-24250][SQL] support accessing SQLConf inside tasks ## What changes were proposed in this pull request? Previously in #20136 we decided to forbid tasks to access `SQLConf`, because it doesn't work and always give you the default conf value. In #21190 we fixed the check and all the places that violate it. Currently the pattern of accessing configs at the executor side is: read the configs at the driver side, then access the variables holding the config values in the RDD closure, so that they will be serialized to the executor side. Something like ``` val someConf = conf.getXXX child.execute().mapPartitions { if (someConf == ...) ... ... } ``` However, this pattern is hard to apply if the config needs to be propagated via a long call stack. An example is `DataType.sameType`, and see how many changes were made in #21190 . When it comes to code generation, it's even worse. I tried it locally and we need to change a ton of files to propagate configs to code generators. This PR proposes to allow tasks to access `SQLConf`. The idea is, we can save all the SQL configs to job properties when an SQL execution is triggered. At executor side we rebuild the `SQLConf` from job properties. ## How was this patch tested? a new test suite Author: Wenchen Fan Closes #21299 from cloud-fan/config. --- .../org/apache/spark/TaskContextImpl.scala | 2 + .../spark/sql/internal/ReadOnlySQLConf.scala | 66 +++++++++++++++++++ .../apache/spark/sql/internal/SQLConf.scala | 21 +++--- .../org/apache/spark/sql/SparkSession.scala | 21 +++++- .../spark/sql/execution/SQLExecution.scala | 50 ++++++++++---- .../execution/basicPhysicalOperators.scala | 2 +- .../datasources/json/JsonDataSource.scala | 16 ++--- .../exchange/BroadcastExchangeExec.scala | 2 +- .../internal/ExecutorSideSQLConfSuite.scala | 66 +++++++++++++++++++ 9 files changed, 210 insertions(+), 36 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index cccd3ea457ba4..0791fe856ef15 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -178,4 +178,6 @@ private[spark] class TaskContextImpl( private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException + // TODO: shall we publish it and define it in `TaskContext`? + private[spark] def getLocalProperties(): Properties = localProperties } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala new file mode 100644 index 0000000000000..19f67236c8979 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import java.util.{Map => JMap} + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader} + +/** + * A readonly SQLConf that will be created by tasks running at the executor side. It reads the + * configs from the local properties which are propagated from driver to executors. + */ +class ReadOnlySQLConf(context: TaskContext) extends SQLConf { + + @transient override val settings: JMap[String, String] = { + context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]] + } + + @transient override protected val reader: ConfigReader = { + new ConfigReader(new TaskContextConfigProvider(context)) + } + + override protected def setConfWithCheck(key: String, value: String): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def unsetConf(key: String): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def unsetConf(entry: ConfigEntry[_]): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def clear(): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def clone(): SQLConf = { + throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") + } + + override def copy(entries: (ConfigEntry[_], Any)*): SQLConf = { + throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") + } +} + +class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider { + override def get(key: String): Option[String] = Option(context.getLocalProperty(key)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 53a50305348fa..643e4c686f58d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,13 +27,12 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator -import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -107,7 +106,13 @@ object SQLConf { * run tests in parallel. At the time this feature was implemented, this was a no-op since we * run unit tests (that does not involve SparkSession) in serial order. */ - def get: SQLConf = confGetter.get()() + def get: SQLConf = { + if (TaskContext.get != null) { + new ReadOnlySQLConf(TaskContext.get()) + } else { + confGetter.get()() + } + } val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() @@ -1292,17 +1297,11 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ - if (Utils.isTesting && SparkEnv.get != null) { - // assert that we're only accessing it on the driver. - assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, - "SQLConf should only be created and accessed on the driver.") - } - /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) - @transient private val reader = new ConfigReader(settings) + @transient protected val reader = new ConfigReader(settings) /** ************************ Spark SQL Params/Hints ******************* */ @@ -1765,7 +1764,7 @@ class SQLConf extends Serializable with Logging { settings.containsKey(key) } - private def setConfWithCheck(key: String, value: String): Unit = { + protected def setConfWithCheck(key: String, value: String): Unit = { settings.put(key, value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index c502e583a55c5..e2a1a57c7dd4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} +import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -898,6 +898,7 @@ object SparkSession extends Logging { * @since 2.0.0 */ def getOrCreate(): SparkSession = synchronized { + assertOnDriver() // Get the session from current thread's active session. var session = activeThreadSession.get() if ((session ne null) && !session.sparkContext.isStopped) { @@ -1022,14 +1023,20 @@ object SparkSession extends Logging { * * @since 2.2.0 */ - def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) + def getActiveSession: Option[SparkSession] = { + assertOnDriver() + Option(activeThreadSession.get) + } /** * Returns the default SparkSession that is returned by the builder. * * @since 2.2.0 */ - def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + def getDefaultSession: Option[SparkSession] = { + assertOnDriver() + Option(defaultSession.get) + } /** * Returns the currently active SparkSession, otherwise the default one. If there is no default @@ -1062,6 +1069,14 @@ object SparkSession extends Logging { } } + private def assertOnDriver(): Unit = { + if (Utils.isTesting && TaskContext.get != null) { + // we're accessing it during task execution, fail. + throw new IllegalStateException( + "SparkSession should only be created and accessed on the driver.") + } + } + /** * Helper method to create an instance of `SessionState` based on `className` from conf. * The result is either `SessionState` or a Hive based `SessionState`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 2c5102b1e5ee7..032525a08ccdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -68,16 +68,18 @@ object SQLExecution { // sparkContext.getCallSite() would first try to pick up any call site that was previously // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on // streaming queries would give us call site like "run at :0" - val callSite = sparkSession.sparkContext.getCallSite() + val callSite = sc.getCallSite() - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( - executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) - try { - body - } finally { - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + withSQLConfPropagated(sparkSession) { + sc.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) + try { + body + } finally { + sc.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) + } } } finally { executionIdToQueryExecution.remove(executionId) @@ -90,13 +92,37 @@ object SQLExecution { * thread from the original one, this method can be used to connect the Spark jobs in this action * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`. */ - def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = { + def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = { + val sc = sparkSession.sparkContext val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + withSQLConfPropagated(sparkSession) { + try { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) + body + } finally { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + } + } + } + + def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = { + val sc = sparkSession.sparkContext + // Set all the specified SQL configs to local properties, so that they can be available at + // the executor side. + val allConfigs = sparkSession.sessionState.conf.getAllConfs + val originalLocalProps = allConfigs.collect { + case (key, value) if key.startsWith("spark") => + val originalValue = sc.getLocalProperty(key) + sc.setLocalProperty(key, value) + (key, originalValue) + } + try { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) body } finally { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + for ((key, value) <- originalLocalProps) { + sc.setLocalProperty(key, value) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 1edfdc888afd8..d54bfbfc14f5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -629,7 +629,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { + SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { val beforeCollect = System.nanoTime() // Note that we use .executeCollect() because we don't want to convert data to Scala types val rows: Array[InternalRow] = child.executeCollect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index ba83df0efebd0..3b6df45e949e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -104,22 +105,19 @@ object TextInputJsonDataSource extends JsonDataSource { CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow) }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow)) - JsonInferSchema.infer(rdd, parsedOptions, rowParser) + SQLExecution.withSQLConfPropagated(json.sparkSession) { + JsonInferSchema.infer(rdd, parsedOptions, rowParser) + } } private def createBaseDataset( sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): Dataset[String] = { - val paths = inputPaths.map(_.getPath.toString) - val textOptions = Map.empty[String, String] ++ - parsedOptions.encoding.map("encoding" -> _) ++ - parsedOptions.lineSeparator.map("lineSep" -> _) - sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, - paths = paths, + paths = inputPaths.map(_.getPath.toString), className = classOf[TextFileFormat].getName, options = parsedOptions.parameters ).resolveRelation(checkFilesExist = false)) @@ -165,7 +163,9 @@ object MultiLineJsonDataSource extends JsonDataSource { .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) - JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) + SQLExecution.withSQLConfPropagated(sparkSession) { + JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) + } } private def createBaseRdd( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index daea6c39624d6..9e0ec9481b0de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -69,7 +69,7 @@ case class BroadcastExchangeExec( Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { + SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { try { val beforeCollect = System.nanoTime() // Use executeCollect/executeCollectIterator to avoid conversion to Scala types diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala new file mode 100644 index 0000000000000..404d6313ab92c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.test.SQLTestUtils + +class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { + import testImplicits._ + + protected var spark: SparkSession = null + + // Create a new [[SparkSession]] running in local-cluster mode. + override def beforeAll(): Unit = { + super.beforeAll() + spark = SparkSession.builder() + .master("local-cluster[2,1,1024]") + .appName("testing") + .getOrCreate() + } + + override def afterAll(): Unit = { + spark.stop() + spark = null + } + + test("ReadonlySQLConf is correctly created at the executor side") { + SQLConf.get.setConfString("spark.sql.x", "a") + try { + val checks = spark.range(10).mapPartitions { it => + val conf = SQLConf.get + Iterator(conf.isInstanceOf[ReadOnlySQLConf] && conf.getConfString("spark.sql.x") == "a") + }.collect() + assert(checks.forall(_ == true)) + } finally { + SQLConf.get.unsetConf("spark.sql.x") + } + } + + test("case-sensitive config should work for json schema inference") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTempPath { path => + val pathString = path.getCanonicalPath + spark.range(10).select('id.as("ID")).write.json(pathString) + spark.range(10).write.mode("append").json(pathString) + assert(spark.read.json(pathString).columns.toSet == Set("id", "ID")) + } + } + } +} From 000e25ae7950ff005d4bbe4fffed410e5947075c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 20 May 2018 16:13:42 +0800 Subject: [PATCH 0834/2461] Revert "[SPARK-24250][SQL] support accessing SQLConf inside tasks" This reverts commit dd37529a8dada6ed8a49b8ce50875268f6a20cba. --- .../org/apache/spark/TaskContextImpl.scala | 2 - .../spark/sql/internal/ReadOnlySQLConf.scala | 66 ------------------- .../apache/spark/sql/internal/SQLConf.scala | 21 +++--- .../org/apache/spark/sql/SparkSession.scala | 21 +----- .../spark/sql/execution/SQLExecution.scala | 50 ++++---------- .../execution/basicPhysicalOperators.scala | 2 +- .../datasources/json/JsonDataSource.scala | 16 ++--- .../exchange/BroadcastExchangeExec.scala | 2 +- .../internal/ExecutorSideSQLConfSuite.scala | 66 ------------------- 9 files changed, 36 insertions(+), 210 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 0791fe856ef15..cccd3ea457ba4 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -178,6 +178,4 @@ private[spark] class TaskContextImpl( private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException - // TODO: shall we publish it and define it in `TaskContext`? - private[spark] def getLocalProperties(): Properties = localProperties } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala deleted file mode 100644 index 19f67236c8979..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.internal - -import java.util.{Map => JMap} - -import org.apache.spark.{TaskContext, TaskContextImpl} -import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader} - -/** - * A readonly SQLConf that will be created by tasks running at the executor side. It reads the - * configs from the local properties which are propagated from driver to executors. - */ -class ReadOnlySQLConf(context: TaskContext) extends SQLConf { - - @transient override val settings: JMap[String, String] = { - context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]] - } - - @transient override protected val reader: ConfigReader = { - new ConfigReader(new TaskContextConfigProvider(context)) - } - - override protected def setConfWithCheck(key: String, value: String): Unit = { - throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") - } - - override def unsetConf(key: String): Unit = { - throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") - } - - override def unsetConf(entry: ConfigEntry[_]): Unit = { - throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") - } - - override def clear(): Unit = { - throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") - } - - override def clone(): SQLConf = { - throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") - } - - override def copy(entries: (ConfigEntry[_], Any)*): SQLConf = { - throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") - } -} - -class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider { - override def get(key: String): Option[String] = Option(context.getLocalProperty(key)) -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 643e4c686f58d..53a50305348fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,12 +27,13 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path -import org.apache.spark.TaskContext +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator +import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -106,13 +107,7 @@ object SQLConf { * run tests in parallel. At the time this feature was implemented, this was a no-op since we * run unit tests (that does not involve SparkSession) in serial order. */ - def get: SQLConf = { - if (TaskContext.get != null) { - new ReadOnlySQLConf(TaskContext.get()) - } else { - confGetter.get()() - } - } + def get: SQLConf = confGetter.get()() val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() @@ -1297,11 +1292,17 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ + if (Utils.isTesting && SparkEnv.get != null) { + // assert that we're only accessing it on the driver. + assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, + "SQLConf should only be created and accessed on the driver.") + } + /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) - @transient protected val reader = new ConfigReader(settings) + @transient private val reader = new ConfigReader(settings) /** ************************ Spark SQL Params/Hints ******************* */ @@ -1764,7 +1765,7 @@ class SQLConf extends Serializable with Logging { settings.containsKey(key) } - protected def setConfWithCheck(key: String, value: String): Unit = { + private def setConfWithCheck(key: String, value: String): Unit = { settings.put(key, value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index e2a1a57c7dd4d..c502e583a55c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} +import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -898,7 +898,6 @@ object SparkSession extends Logging { * @since 2.0.0 */ def getOrCreate(): SparkSession = synchronized { - assertOnDriver() // Get the session from current thread's active session. var session = activeThreadSession.get() if ((session ne null) && !session.sparkContext.isStopped) { @@ -1023,20 +1022,14 @@ object SparkSession extends Logging { * * @since 2.2.0 */ - def getActiveSession: Option[SparkSession] = { - assertOnDriver() - Option(activeThreadSession.get) - } + def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) /** * Returns the default SparkSession that is returned by the builder. * * @since 2.2.0 */ - def getDefaultSession: Option[SparkSession] = { - assertOnDriver() - Option(defaultSession.get) - } + def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) /** * Returns the currently active SparkSession, otherwise the default one. If there is no default @@ -1069,14 +1062,6 @@ object SparkSession extends Logging { } } - private def assertOnDriver(): Unit = { - if (Utils.isTesting && TaskContext.get != null) { - // we're accessing it during task execution, fail. - throw new IllegalStateException( - "SparkSession should only be created and accessed on the driver.") - } - } - /** * Helper method to create an instance of `SessionState` based on `className` from conf. * The result is either `SessionState` or a Hive based `SessionState`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 032525a08ccdb..2c5102b1e5ee7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -68,18 +68,16 @@ object SQLExecution { // sparkContext.getCallSite() would first try to pick up any call site that was previously // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on // streaming queries would give us call site like "run at :0" - val callSite = sc.getCallSite() + val callSite = sparkSession.sparkContext.getCallSite() - withSQLConfPropagated(sparkSession) { - sc.listenerBus.post(SparkListenerSQLExecutionStart( - executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) - try { - body - } finally { - sc.listenerBus.post(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) - } + sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) + try { + body + } finally { + sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) } } finally { executionIdToQueryExecution.remove(executionId) @@ -92,37 +90,13 @@ object SQLExecution { * thread from the original one, this method can be used to connect the Spark jobs in this action * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`. */ - def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = { - val sc = sparkSession.sparkContext + def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = { val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - withSQLConfPropagated(sparkSession) { - try { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) - body - } finally { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) - } - } - } - - def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = { - val sc = sparkSession.sparkContext - // Set all the specified SQL configs to local properties, so that they can be available at - // the executor side. - val allConfigs = sparkSession.sessionState.conf.getAllConfs - val originalLocalProps = allConfigs.collect { - case (key, value) if key.startsWith("spark") => - val originalValue = sc.getLocalProperty(key) - sc.setLocalProperty(key, value) - (key, originalValue) - } - try { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) body } finally { - for ((key, value) <- originalLocalProps) { - sc.setLocalProperty(key, value) - } + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index d54bfbfc14f5f..1edfdc888afd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -629,7 +629,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { + SQLExecution.withExecutionId(sparkContext, executionId) { val beforeCollect = System.nanoTime() // Note that we use .executeCollect() because we don't want to convert data to Scala types val rows: Array[InternalRow] = child.executeCollect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 3b6df45e949e8..ba83df0efebd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -34,7 +34,6 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} -import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -105,19 +104,22 @@ object TextInputJsonDataSource extends JsonDataSource { CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow) }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow)) - SQLExecution.withSQLConfPropagated(json.sparkSession) { - JsonInferSchema.infer(rdd, parsedOptions, rowParser) - } + JsonInferSchema.infer(rdd, parsedOptions, rowParser) } private def createBaseDataset( sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): Dataset[String] = { + val paths = inputPaths.map(_.getPath.toString) + val textOptions = Map.empty[String, String] ++ + parsedOptions.encoding.map("encoding" -> _) ++ + parsedOptions.lineSeparator.map("lineSep" -> _) + sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, - paths = inputPaths.map(_.getPath.toString), + paths = paths, className = classOf[TextFileFormat].getName, options = parsedOptions.parameters ).resolveRelation(checkFilesExist = false)) @@ -163,9 +165,7 @@ object MultiLineJsonDataSource extends JsonDataSource { .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) - SQLExecution.withSQLConfPropagated(sparkSession) { - JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) - } + JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) } private def createBaseRdd( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 9e0ec9481b0de..daea6c39624d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -69,7 +69,7 @@ case class BroadcastExchangeExec( Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { + SQLExecution.withExecutionId(sparkContext, executionId) { try { val beforeCollect = System.nanoTime() // Use executeCollect/executeCollectIterator to avoid conversion to Scala types diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala deleted file mode 100644 index 404d6313ab92c..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.internal - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.test.SQLTestUtils - -class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { - import testImplicits._ - - protected var spark: SparkSession = null - - // Create a new [[SparkSession]] running in local-cluster mode. - override def beforeAll(): Unit = { - super.beforeAll() - spark = SparkSession.builder() - .master("local-cluster[2,1,1024]") - .appName("testing") - .getOrCreate() - } - - override def afterAll(): Unit = { - spark.stop() - spark = null - } - - test("ReadonlySQLConf is correctly created at the executor side") { - SQLConf.get.setConfString("spark.sql.x", "a") - try { - val checks = spark.range(10).mapPartitions { it => - val conf = SQLConf.get - Iterator(conf.isInstanceOf[ReadOnlySQLConf] && conf.getConfString("spark.sql.x") == "a") - }.collect() - assert(checks.forall(_ == true)) - } finally { - SQLConf.get.unsetConf("spark.sql.x") - } - } - - test("case-sensitive config should work for json schema inference") { - withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { - withTempPath { path => - val pathString = path.getCanonicalPath - spark.range(10).select('id.as("ID")).write.json(pathString) - spark.range(10).write.mode("append").json(pathString) - assert(spark.read.json(pathString).columns.toSet == Set("id", "ID")) - } - } - } -} From 8eac621229b50e15bea550a751593bba0bf8b20c Mon Sep 17 00:00:00 2001 From: Stavros Date: Sun, 20 May 2018 18:15:04 -0500 Subject: [PATCH 0835/2461] [SPARK-23857][MESOS] remove keytab check in mesos cluster mode at first submit time ## What changes were proposed in this pull request? - Removes the check for the keytab when we are running in mesos cluster mode. - Keeps the check for client mode since in cluster mode we eventually launch the driver within the cluster in client mode. In the latter case we want to have the check done when the container starts, the keytab should be checked if it exists within the container's local filesystem. ## How was this patch tested? This was manually tested by running spark submit in mesos cluster mode. Author: Stavros Closes #20967 from skonto/fix_mesos_keytab_susbmit. --- core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 087e9c31a9c9a..4baf032f0e9c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -310,6 +310,7 @@ private[spark] class SparkSubmit extends Logging { val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER val isStandAloneCluster = clusterManager == STANDALONE && deployMode == CLUSTER val isKubernetesCluster = clusterManager == KUBERNETES && deployMode == CLUSTER + val isMesosClient = clusterManager == MESOS && deployMode == CLIENT if (!isMesosCluster && !isStandAloneCluster) { // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files @@ -337,7 +338,7 @@ private[spark] class SparkSubmit extends Logging { val targetDir = Utils.createTempDir() // assure a keytab is available from any place in a JVM - if (clusterManager == YARN || clusterManager == LOCAL || clusterManager == MESOS) { + if (clusterManager == YARN || clusterManager == LOCAL || isMesosClient) { if (args.principal != null) { if (args.keytab != null) { require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist") From f32b7faf7c4b5d2ac45a2db96935f67d1b629ca2 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 21 May 2018 09:47:52 +0800 Subject: [PATCH 0836/2461] [MINOR][PROJECT-INFRA] Check if 'original_head' variable is defined in clean_up at merge script ## What changes were proposed in this pull request? This PR proposes to check if global variable exists or not in clean_up. This can happen when it fails at: https://github.com/apache/spark/blob/7013eea11cb32b1e0038dc751c485da5c94a484b/dev/merge_spark_pr.py#L423 I found this (It was my environment problem) but the error message took me a while to debug. ## How was this patch tested? Manually tested: **Before** ``` git rev-parse --abbrev-ref HEAD fatal: Not a git repository (or any of the parent directories): .git Traceback (most recent call last): File "./dev/merge_spark_pr_jira.py", line 517, in clean_up() File "./dev/merge_spark_pr_jira.py", line 104, in clean_up print("Restoring head pointer to %s" % original_head) NameError: global name 'original_head' is not defined ``` **After** ``` git rev-parse --abbrev-ref HEAD fatal: Not a git repository (or any of the parent directories): .git Traceback (most recent call last): File "./dev/merge_spark_pr.py", line 516, in main() File "./dev/merge_spark_pr.py", line 424, in main original_head = get_current_ref() File "./dev/merge_spark_pr.py", line 412, in get_current_ref ref = run_cmd("git rev-parse --abbrev-ref HEAD").strip() File "./dev/merge_spark_pr.py", line 94, in run_cmd return subprocess.check_output(cmd.split(" ")) File "/usr/local/Cellar/python2/2.7.14_3/Frameworks/Python.framework/Versions/2.7/lib/python2.7/subprocess.py", line 219, in check_output raise CalledProcessError(retcode, cmd, output=output) subprocess.CalledProcessError: Command '['git', 'rev-parse', '--abbrev-ref', 'HEAD']' returned non-zero exit status 128 ``` Author: hyukjinkwon Closes #21349 from HyukjinKwon/minor-merge-script. --- dev/merge_spark_pr.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 5ea205fbed4aa..7f46a1c8f6a7c 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -101,14 +101,15 @@ def continue_maybe(prompt): def clean_up(): - print("Restoring head pointer to %s" % original_head) - run_cmd("git checkout %s" % original_head) + if 'original_head' in globals(): + print("Restoring head pointer to %s" % original_head) + run_cmd("git checkout %s" % original_head) - branches = run_cmd("git branch").replace(" ", "").split("\n") + branches = run_cmd("git branch").replace(" ", "").split("\n") - for branch in filter(lambda x: x.startswith(BRANCH_PREFIX), branches): - print("Deleting local branch %s" % branch) - run_cmd("git branch -D %s" % branch) + for branch in filter(lambda x: x.startswith(BRANCH_PREFIX), branches): + print("Deleting local branch %s" % branch) + run_cmd("git branch -D %s" % branch) # merge the requested PR and return the merge hash From 6d7d45a1af078edd9e4ed027e735d6096482179c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 21 May 2018 15:39:35 +0800 Subject: [PATCH 0837/2461] [SPARK-24242][SQL] RangeExec should have correct outputOrdering and outputPartitioning ## What changes were proposed in this pull request? Logical `Range` node has been added with `outputOrdering` recently. It's used to eliminate redundant `Sort` during optimization. However, this `outputOrdering` doesn't not propagate to physical `RangeExec` node. We also add correct `outputPartitioning` to `RangeExec` node. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21291 from viirya/SPARK-24242. --- python/pyspark/sql/tests.py | 4 +-- .../execution/basicPhysicalOperators.scala | 14 ++++++++++ .../spark/sql/ConfigBehaviorSuite.scala | 4 ++- .../spark/sql/execution/PlannerSuite.scala | 27 ++++++++++++++++++- .../execution/WholeStageCodegenSuite.scala | 4 +-- .../sql/execution/debug/DebuggingSuite.scala | 7 +++-- 6 files changed, 52 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a1b6db71782bb..c7bd8f01b907f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -5239,8 +5239,8 @@ def test_complex_groupby(self): expected2 = df.groupby().agg(sum(df.v)) # groupby one column and one sql expression - result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)) - expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)) + result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)).orderBy(df.id, df.v % 2) + expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)).orderBy(df.id, df.v % 2) # groupby one python UDF result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 1edfdc888afd8..2df81d09c58e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -345,6 +345,20 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) override val output: Seq[Attribute] = range.output + override def outputOrdering: Seq[SortOrder] = range.outputOrdering + + override def outputPartitioning: Partitioning = { + if (numElements > 0) { + if (numSlices == 1) { + SinglePartition + } else { + RangePartitioning(outputOrdering, numSlices) + } + } else { + UnknownPartitioning(0) + } + } + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala index 949505e449fd7..276496be3d62c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala @@ -39,7 +39,9 @@ class ConfigBehaviorSuite extends QueryTest with SharedSQLContext { def computeChiSquareTest(): Double = { val n = 10000 // Trigger a sort - val data = spark.range(0, n, 1, 1).sort('id.desc) + // Range has range partitioning in its output now. To have a range shuffle, we + // need to run a repartition first. + val data = spark.range(0, n, 1, 1).repartition(10).sort('id.desc) .selectExpr("SPARK_PARTITION_ID() pid", "id").as[(Int, Long)].collect() // Compute histogram for the number of records per partition post sort diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index a375f881c7d63..b2aba8e72c5db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition, Sort} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} @@ -633,6 +633,31 @@ class PlannerSuite extends SharedSQLContext { requiredOrdering = Seq(orderingA, orderingB), shouldHaveSort = true) } + + test("SPARK-24242: RangeExec should have correct output ordering and partitioning") { + val df = spark.range(10) + val rangeExec = df.queryExecution.executedPlan.collect { + case r: RangeExec => r + } + val range = df.queryExecution.optimizedPlan.collect { + case r: Range => r + } + assert(rangeExec.head.outputOrdering == range.head.outputOrdering) + assert(rangeExec.head.outputPartitioning == + RangePartitioning(rangeExec.head.outputOrdering, df.rdd.getNumPartitions)) + + val rangeInOnePartition = spark.range(1, 10, 1, 1) + val rangeExecInOnePartition = rangeInOnePartition.queryExecution.executedPlan.collect { + case r: RangeExec => r + } + assert(rangeExecInOnePartition.head.outputPartitioning == SinglePartition) + + val rangeInZeroPartition = spark.range(-10, -9, -20, 1) + val rangeExecInZeroPartition = rangeInZeroPartition.queryExecution.executedPlan.collect { + case r: RangeExec => r + } + assert(rangeExecInZeroPartition.head.outputPartitioning == UnknownPartitioning(0)) + } } // Used for unit-testing EnsureRequirements diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 9180a22c260f1..b714dcd5269fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -51,12 +51,12 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } test("Aggregate with grouping keys should be included in WholeStageCodegen") { - val df = spark.range(3).groupBy("id").count().orderBy("id") + val df = spark.range(3).groupBy(col("id") * 2).count().orderBy(col("id") * 2) val plan = df.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) - assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1))) + assert(df.collect() === Array(Row(0, 1), Row(2, 1), Row(4, 1))) } test("BroadcastHashJoin should be included in WholeStageCodegen") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index adcaf2d76519f..8251ff159e05f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.debug import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData.TestData @@ -33,14 +34,16 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext { } test("debugCodegen") { - val res = codegenString(spark.range(10).groupBy("id").count().queryExecution.executedPlan) + val res = codegenString(spark.range(10).groupBy(col("id") * 2).count() + .queryExecution.executedPlan) assert(res.contains("Subtree 1 / 2")) assert(res.contains("Subtree 2 / 2")) assert(res.contains("Object[]")) } test("debugCodegenStringSeq") { - val res = codegenStringSeq(spark.range(10).groupBy("id").count().queryExecution.executedPlan) + val res = codegenStringSeq(spark.range(10).groupBy(col("id") * 2).count() + .queryExecution.executedPlan) assert(res.length == 2) assert(res.forall{ case (subtree, code) => subtree.contains("Range") && code.contains("Object[]")}) From e480eccd9754b4900c3e2c2036d69130a262cffe Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 21 May 2018 15:42:04 +0800 Subject: [PATCH 0838/2461] [SPARK-24323][SQL] Fix lint-java errors ## What changes were proposed in this pull request? This PR fixes the following errors reported by `lint-java` ``` % dev/lint-java Using `mvn` from path: /usr/bin/mvn Checkstyle checks failed at following occurrences: [ERROR] src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java:[39] (sizes) LineLength: Line is longer than 100 characters (found 104). [ERROR] src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java:[26] (sizes) LineLength: Line is longer than 100 characters (found 110). [ERROR] src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java:[30] (sizes) LineLength: Line is longer than 100 characters (found 104). ``` ## How was this patch tested? Run `lint-java` manually. Author: Kazuaki Ishizaki Closes #21374 from kiszk/SPARK-24323. --- .../spark/sql/sources/v2/reader/InputPartition.java | 4 ++-- .../spark/sql/sources/v2/reader/InputPartitionReader.java | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java index f53687e113ae0..f2038d0de3ffe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java @@ -36,8 +36,8 @@ public interface InputPartition extends Serializable { /** - * The preferred locations where the input partition reader returned by this partition can run faster, - * but Spark does not guarantee to run the input partition reader on these locations. + * The preferred locations where the input partition reader returned by this partition can run + * faster, but Spark does not guarantee to run the input partition reader on these locations. * The implementations should make sure that it can be run on any location. * The location is a string representing the host name. * diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java index f0d808536207a..33fa7be4c1b20 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java @@ -23,12 +23,12 @@ import org.apache.spark.annotation.InterfaceStability; /** - * An input partition reader returned by {@link InputPartition#createPartitionReader()} and is responsible for - * outputting data for a RDD partition. + * An input partition reader returned by {@link InputPartition#createPartitionReader()} and is + * responsible for outputting data for a RDD partition. * * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal input - * partition readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for input partition - * readers that mix in {@link SupportsScanUnsafeRow}. + * partition readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for input + * partition readers that mix in {@link SupportsScanUnsafeRow}. */ @InterfaceStability.Evolving public interface InputPartitionReader extends Closeable { From a6e883feb3b78232ad5cf636f7f7d5e825183041 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Mon, 21 May 2018 23:14:03 +0900 Subject: [PATCH 0839/2461] [SPARK-23935][SQL] Adding map_entries function ## What changes were proposed in this pull request? This PR adds `map_entries` function that returns an unordered array of all entries in the given map. ## How was this patch tested? New tests added into: - `CollectionExpressionSuite` - `DataFrameFunctionsSuite` ## CodeGen examples ### Primitive types ``` val df = Seq(Map(1 -> 5, 2 -> 6)).toDF("m") df.filter('m.isNotNull).select(map_entries('m)).debugCodegen ``` Result: ``` /* 042 */ boolean project_isNull_0 = false; /* 043 */ /* 044 */ ArrayData project_value_0 = null; /* 045 */ /* 046 */ final int project_numElements_0 = inputadapter_value_0.numElements(); /* 047 */ final ArrayData project_keys_0 = inputadapter_value_0.keyArray(); /* 048 */ final ArrayData project_values_0 = inputadapter_value_0.valueArray(); /* 049 */ /* 050 */ final long project_size_0 = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( /* 051 */ project_numElements_0, /* 052 */ 32); /* 053 */ if (project_size_0 > 2147483632) { /* 054 */ final Object[] project_internalRowArray_0 = new Object[project_numElements_0]; /* 055 */ for (int z = 0; z < project_numElements_0; z++) { /* 056 */ project_internalRowArray_0[z] = new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(new Object[]{project_keys_0.getInt(z), project_values_0.getInt(z)}); /* 057 */ } /* 058 */ project_value_0 = new org.apache.spark.sql.catalyst.util.GenericArrayData(project_internalRowArray_0); /* 059 */ /* 060 */ } else { /* 061 */ final byte[] project_arrayBytes_0 = new byte[(int)project_size_0]; /* 062 */ UnsafeArrayData project_unsafeArrayData_0 = new UnsafeArrayData(); /* 063 */ Platform.putLong(project_arrayBytes_0, 16, project_numElements_0); /* 064 */ project_unsafeArrayData_0.pointTo(project_arrayBytes_0, 16, (int)project_size_0); /* 065 */ /* 066 */ final int project_structsOffset_0 = UnsafeArrayData.calculateHeaderPortionInBytes(project_numElements_0) + project_numElements_0 * 8; /* 067 */ UnsafeRow project_unsafeRow_0 = new UnsafeRow(2); /* 068 */ for (int z = 0; z < project_numElements_0; z++) { /* 069 */ long offset = project_structsOffset_0 + z * 24L; /* 070 */ project_unsafeArrayData_0.setLong(z, (offset << 32) + 24L); /* 071 */ project_unsafeRow_0.pointTo(project_arrayBytes_0, 16 + offset, 24); /* 072 */ project_unsafeRow_0.setInt(0, project_keys_0.getInt(z)); /* 073 */ project_unsafeRow_0.setInt(1, project_values_0.getInt(z)); /* 074 */ } /* 075 */ project_value_0 = project_unsafeArrayData_0; /* 076 */ /* 077 */ } ``` ### Non-primitive types ``` val df = Seq(Map("a" -> "foo", "b" -> null)).toDF("m") df.filter('m.isNotNull).select(map_entries('m)).debugCodegen ``` Result: ``` /* 042 */ boolean project_isNull_0 = false; /* 043 */ /* 044 */ ArrayData project_value_0 = null; /* 045 */ /* 046 */ final int project_numElements_0 = inputadapter_value_0.numElements(); /* 047 */ final ArrayData project_keys_0 = inputadapter_value_0.keyArray(); /* 048 */ final ArrayData project_values_0 = inputadapter_value_0.valueArray(); /* 049 */ /* 050 */ final Object[] project_internalRowArray_0 = new Object[project_numElements_0]; /* 051 */ for (int z = 0; z < project_numElements_0; z++) { /* 052 */ project_internalRowArray_0[z] = new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(new Object[]{project_keys_0.getUTF8String(z), project_values_0.getUTF8String(z)}); /* 053 */ } /* 054 */ project_value_0 = new org.apache.spark.sql.catalyst.util.GenericArrayData(project_internalRowArray_0); ``` Author: Marek Novotny Closes #21236 from mn-mikke/feature/array-api-map_entries-to-master. --- python/pyspark/sql/functions.py | 20 +++ .../sql/catalyst/expressions/UnsafeRow.java | 2 + .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/codegen/CodeGenerator.scala | 34 ++++ .../expressions/collectionOperations.scala | 153 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 23 +++ .../expressions/ExpressionEvalHelper.scala | 3 + .../org/apache/spark/sql/functions.scala | 7 + .../spark/sql/DataFrameFunctionsSuite.scala | 44 +++++ 9 files changed, 287 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 8490081facc5a..fbc8a2d038f8f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2344,6 +2344,26 @@ def map_values(col): return Column(sc._jvm.functions.map_values(_to_java_column(col))) +@since(2.4) +def map_entries(col): + """ + Collection function: Returns an unordered array of all entries in the given map. + + :param col: name of column or expression + + >>> from pyspark.sql.functions import map_entries + >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data") + >>> df.select(map_entries("data").alias("entries")).show() + +----------------+ + | entries| + +----------------+ + |[[1, a], [2, b]]| + +----------------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_entries(_to_java_column(col))) + + @ignore_unicode_prefix @since(2.4) def array_repeat(col, count): diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 29a1411241cf6..469b0e60cc9a2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -62,6 +62,8 @@ */ public final class UnsafeRow extends InternalRow implements Externalizable, KryoSerializable { + public static final int WORD_SIZE = 8; + ////////////////////////////////////////////////////////////////////////////// // Static methods ////////////////////////////////////////////////////////////////////////////// diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 867c2d5eab53d..1134a8866dc13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -419,6 +419,7 @@ object FunctionRegistry { expression[ElementAt]("element_at"), expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), + expression[MapEntries]("map_entries"), expression[Size]("size"), expression[Slice]("slice"), expression[Size]("cardinality"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4dda525294259..d382d9aace109 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -764,6 +764,40 @@ class CodegenContext { """.stripMargin } + /** + * Generates code creating a [[UnsafeArrayData]]. The generated code executes + * a provided fallback when the size of backing array would exceed the array size limit. + * @param arrayName a name of the array to create + * @param numElements a piece of code representing the number of elements the array should contain + * @param elementSize a size of an element in bytes + * @param bodyCode a function generating code that fills up the [[UnsafeArrayData]] + * and getting the backing array as a parameter + * @param fallbackCode a piece of code executed when the array size limit is exceeded + */ + def createUnsafeArrayWithFallback( + arrayName: String, + numElements: String, + elementSize: Int, + bodyCode: String => String, + fallbackCode: String): String = { + val arraySize = freshName("size") + val arrayBytes = freshName("arrayBytes") + s""" + |final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElements, + | $elementSize); + |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | $fallbackCode + |} else { + | final byte[] $arrayBytes = new byte[(int)$arraySize]; + | UnsafeArrayData $arrayName = new UnsafeArrayData(); + | Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements); + | $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize); + | ${bodyCode(arrayBytes)} + |} + """.stripMargin + } + /** * Generates code to do null safe execution, i.e. only execute the code when the input is not * null by adding null check if necessary. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c82db839438ed..8d763dca5243e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -154,6 +155,158 @@ case class MapValues(child: Expression) override def prettyName: String = "map_values" } +/** + * Returns an unordered array of all entries in the given map. + */ +@ExpressionDescription( + usage = "_FUNC_(map) - Returns an unordered array of all entries in the given map.", + examples = """ + Examples: + > SELECT _FUNC_(map(1, 'a', 2, 'b')); + [(1,"a"),(2,"b")] + """, + since = "2.4.0") +case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(MapType) + + lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType] + + override def dataType: DataType = { + ArrayType( + StructType( + StructField("key", childDataType.keyType, false) :: + StructField("value", childDataType.valueType, childDataType.valueContainsNull) :: + Nil), + false) + } + + override protected def nullSafeEval(input: Any): Any = { + val childMap = input.asInstanceOf[MapData] + val keys = childMap.keyArray() + val values = childMap.valueArray() + val length = childMap.numElements() + val resultData = new Array[AnyRef](length) + var i = 0; + while (i < length) { + val key = keys.get(i, childDataType.keyType) + val value = values.get(i, childDataType.valueType) + val row = new GenericInternalRow(Array[Any](key, value)) + resultData.update(i, row) + i += 1 + } + new GenericArrayData(resultData) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val numElements = ctx.freshName("numElements") + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType) + val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) + val code = if (isKeyPrimitive && isValuePrimitive) { + genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements) + } else { + genCodeForAnyElements(ctx, keys, values, ev.value, numElements) + } + s""" + |final int $numElements = $c.numElements(); + |final ArrayData $keys = $c.keyArray(); + |final ArrayData $values = $c.valueArray(); + |$code + """.stripMargin + }) + } + + private def getKey(varName: String) = CodeGenerator.getValue(varName, childDataType.keyType, "z") + + private def getValue(varName: String) = { + CodeGenerator.getValue(varName, childDataType.valueType, "z") + } + + private def genCodeForPrimitiveElements( + ctx: CodegenContext, + keys: String, + values: String, + arrayData: String, + numElements: String): String = { + val unsafeRow = ctx.freshName("unsafeRow") + val unsafeArrayData = ctx.freshName("unsafeArrayData") + val structsOffset = ctx.freshName("structsOffset") + val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes" + + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val wordSize = UnsafeRow.WORD_SIZE + val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2 + val structSizeAsLong = structSize + "L" + val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) + val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) + + val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});" + val valueAssignmentChecked = if (childDataType.valueContainsNull) { + s""" + |if ($values.isNullAt(z)) { + | $unsafeRow.setNullAt(1); + |} else { + | $valueAssignment + |} + """.stripMargin + } else { + valueAssignment + } + + val assignmentLoop = (byteArray: String) => + s""" + |final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize; + |UnsafeRow $unsafeRow = new UnsafeRow(2); + |for (int z = 0; z < $numElements; z++) { + | long offset = $structsOffset + z * $structSizeAsLong; + | $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong); + | $unsafeRow.pointTo($byteArray, $baseOffset + offset, $structSize); + | $unsafeRow.set$keyTypeName(0, ${getKey(keys)}); + | $valueAssignmentChecked + |} + |$arrayData = $unsafeArrayData; + """.stripMargin + + ctx.createUnsafeArrayWithFallback( + unsafeArrayData, + numElements, + structSize + wordSize, + assignmentLoop, + genCodeForAnyElements(ctx, keys, values, arrayData, numElements)) + } + + private def genCodeForAnyElements( + ctx: CodegenContext, + keys: String, + values: String, + arrayData: String, + numElements: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val rowClass = classOf[GenericInternalRow].getName + val data = ctx.freshName("internalRowArray") + + val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) + val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) { + s"$values.isNullAt(z) ? null : (Object)${getValue(values)}" + } else { + getValue(values) + } + + s""" + |final Object[] $data = new Object[$numElements]; + |for (int z = 0; z < $numElements; z++) { + | $data[z] = new $rowClass(new Object[]{${getKey(keys)}, $getValueWithCheck}); + |} + |$arrayData = new $genericArrayClass($data); + """.stripMargin + } + + override def prettyName: String = "map_entries" +} + /** * Common base class for [[SortArray]] and [[ArraySort]]. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 6ae1ac18c4dc4..71ff96bb722e2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -56,6 +57,28 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapValues(m2), null) } + test("MapEntries") { + def r(values: Any*): InternalRow = create_row(values: _*) + + // Primitive-type keys/values + val mi0 = Literal.create(Map(1 -> 1, 2 -> null, 3 -> 2), MapType(IntegerType, IntegerType)) + val mi1 = Literal.create(Map[Int, Int](), MapType(IntegerType, IntegerType)) + val mi2 = Literal.create(null, MapType(IntegerType, IntegerType)) + + checkEvaluation(MapEntries(mi0), Seq(r(1, 1), r(2, null), r(3, 2))) + checkEvaluation(MapEntries(mi1), Seq.empty) + checkEvaluation(MapEntries(mi2), null) + + // Non-primitive-type keys/values + val ms0 = Literal.create(Map("a" -> "c", "b" -> null), MapType(StringType, StringType)) + val ms1 = Literal.create(Map[Int, Int](), MapType(StringType, StringType)) + val ms2 = Literal.create(null, MapType(StringType, StringType)) + + checkEvaluation(MapEntries(ms0), Seq(r("a", "c"), r("b", null))) + checkEvaluation(MapEntries(ms1), Seq.empty) + checkEvaluation(MapEntries(ms2), null) + } + test("Sort Array") { val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a22e9d4655e8c..c2a44e0d33b18 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -98,6 +98,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { if (expected.isNaN) result.isNaN else expected == result case (result: Float, expected: Float) => if (expected.isNaN) result.isNaN else expected == result + case (result: UnsafeRow, expected: GenericInternalRow) => + val structType = exprDataType.asInstanceOf[StructType] + result.toSeq(structType) == expected.toSeq(structType) case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema) case _ => result == expected diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 2a8fe583b83bc..5ab9cb3fb86a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3492,6 +3492,13 @@ object functions { */ def map_values(e: Column): Column = withExpr { MapValues(e.expr) } + /** + * Returns an unordered array of all entries in the given map. + * @group collection_funcs + * @since 2.4.0 + */ + def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) } + // scalastyle:off line.size.limit // scalastyle:off parameter.number diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index d08982a138bc5..df23e07e441a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -405,6 +405,50 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("map_entries") { + val dummyFilter = (c: Column) => c.isNotNull || c.isNull + + // Primitive-type elements + val idf = Seq( + Map[Int, Int](1 -> 100, 2 -> 200, 3 -> 300), + Map[Int, Int](), + null + ).toDF("m") + val iExpected = Seq( + Row(Seq(Row(1, 100), Row(2, 200), Row(3, 300))), + Row(Seq.empty), + Row(null) + ) + + checkAnswer(idf.select(map_entries('m)), iExpected) + checkAnswer(idf.selectExpr("map_entries(m)"), iExpected) + checkAnswer(idf.filter(dummyFilter('m)).select(map_entries('m)), iExpected) + checkAnswer( + spark.range(1).selectExpr("map_entries(map(1, null, 2, null))"), + Seq(Row(Seq(Row(1, null), Row(2, null))))) + checkAnswer( + spark.range(1).filter(dummyFilter('id)).selectExpr("map_entries(map(1, null, 2, null))"), + Seq(Row(Seq(Row(1, null), Row(2, null))))) + + // Non-primitive-type elements + val sdf = Seq( + Map[String, String]("a" -> "f", "b" -> "o", "c" -> "o"), + Map[String, String]("a" -> null, "b" -> null), + Map[String, String](), + null + ).toDF("m") + val sExpected = Seq( + Row(Seq(Row("a", "f"), Row("b", "o"), Row("c", "o"))), + Row(Seq(Row("a", null), Row("b", null))), + Row(Seq.empty), + Row(null) + ) + + checkAnswer(sdf.select(map_entries('m)), sExpected) + checkAnswer(sdf.selectExpr("map_entries(m)"), sExpected) + checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected) + } + test("array contains function") { val df = Seq( (Seq[Int](1, 2), "x"), From 03e90f65bfdad376400a4ae4df31a82c05ed4d4b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 22 May 2018 00:19:18 +0800 Subject: [PATCH 0840/2461] [SPARK-24250][SQL] support accessing SQLConf inside tasks re-submit https://github.com/apache/spark/pull/21299 which broke build. A few new commits are added to fix the SQLConf problem in `JsonSchemaInference.infer`, and prevent us to access `SQLConf` in DAGScheduler event loop thread. ## What changes were proposed in this pull request? Previously in #20136 we decided to forbid tasks to access `SQLConf`, because it doesn't work and always give you the default conf value. In #21190 we fixed the check and all the places that violate it. Currently the pattern of accessing configs at the executor side is: read the configs at the driver side, then access the variables holding the config values in the RDD closure, so that they will be serialized to the executor side. Something like ``` val someConf = conf.getXXX child.execute().mapPartitions { if (someConf == ...) ... ... } ``` However, this pattern is hard to apply if the config needs to be propagated via a long call stack. An example is `DataType.sameType`, and see how many changes were made in #21190 . When it comes to code generation, it's even worse. I tried it locally and we need to change a ton of files to propagate configs to code generators. This PR proposes to allow tasks to access `SQLConf`. The idea is, we can save all the SQL configs to job properties when an SQL execution is triggered. At executor side we rebuild the `SQLConf` from job properties. ## How was this patch tested? a new test suite Author: Wenchen Fan Closes #21376 from cloud-fan/config. --- .../org/apache/spark/TaskContextImpl.scala | 2 + .../apache/spark/scheduler/DAGScheduler.scala | 2 +- .../org/apache/spark/util/EventLoop.scala | 3 +- .../spark/sql/internal/ReadOnlySQLConf.scala | 66 +++++++++++++++++++ .../apache/spark/sql/internal/SQLConf.scala | 33 ++++++---- .../org/apache/spark/sql/SparkSession.scala | 21 +++++- .../spark/sql/execution/SQLExecution.scala | 54 +++++++++++---- .../execution/basicPhysicalOperators.scala | 2 +- .../datasources/json/JsonDataSource.scala | 16 ++--- .../datasources/json/JsonInferSchema.scala | 15 +++-- .../exchange/BroadcastExchangeExec.scala | 2 +- .../internal/ExecutorSideSQLConfSuite.scala | 66 +++++++++++++++++++ 12 files changed, 239 insertions(+), 43 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index cccd3ea457ba4..0791fe856ef15 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -178,4 +178,6 @@ private[spark] class TaskContextImpl( private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException + // TODO: shall we publish it and define it in `TaskContext`? + private[spark] def getLocalProperties(): Properties = localProperties } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 78b6b34b5d2bb..5f2d16d03165f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -206,7 +206,7 @@ class DAGScheduler( private val messageScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message") - private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) + private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) /** diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala index 3ea9139e11027..651ea4996f6cb 100644 --- a/core/src/main/scala/org/apache/spark/util/EventLoop.scala +++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala @@ -37,7 +37,8 @@ private[spark] abstract class EventLoop[E](name: String) extends Logging { private val stopped = new AtomicBoolean(false) - private val eventThread = new Thread(name) { + // Exposed for testing. + private[spark] val eventThread = new Thread(name) { setDaemon(true) override def run(): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala new file mode 100644 index 0000000000000..19f67236c8979 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import java.util.{Map => JMap} + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader} + +/** + * A readonly SQLConf that will be created by tasks running at the executor side. It reads the + * configs from the local properties which are propagated from driver to executors. + */ +class ReadOnlySQLConf(context: TaskContext) extends SQLConf { + + @transient override val settings: JMap[String, String] = { + context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]] + } + + @transient override protected val reader: ConfigReader = { + new ConfigReader(new TaskContextConfigProvider(context)) + } + + override protected def setConfWithCheck(key: String, value: String): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def unsetConf(key: String): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def unsetConf(entry: ConfigEntry[_]): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def clear(): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def clone(): SQLConf = { + throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") + } + + override def copy(entries: (ConfigEntry[_], Any)*): SQLConf = { + throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") + } +} + +class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider { + override def get(key: String): Option[String] = Option(context.getLocalProperty(key)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 53a50305348fa..a2fb3c64844b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,7 +27,7 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit @@ -95,7 +95,9 @@ object SQLConf { /** * Returns the active config object within the current scope. If there is an active SparkSession, - * the proper SQLConf associated with the thread's session is used. + * the proper SQLConf associated with the thread's active session is used. If it's called from + * tasks in the executor side, a SQLConf will be created from job local properties, which are set + * and propagated from the driver side. * * The way this works is a little bit convoluted, due to the fact that config was added initially * only for physical plans (and as a result not in sql/catalyst module). @@ -107,7 +109,22 @@ object SQLConf { * run tests in parallel. At the time this feature was implemented, this was a no-op since we * run unit tests (that does not involve SparkSession) in serial order. */ - def get: SQLConf = confGetter.get()() + def get: SQLConf = { + if (TaskContext.get != null) { + new ReadOnlySQLConf(TaskContext.get()) + } else { + if (Utils.isTesting && SparkContext.getActive.isDefined) { + // DAGScheduler event loop thread does not have an active SparkSession, the `confGetter` + // will return `fallbackConf` which is unexpected. Here we prevent it from happening. + val schedulerEventLoopThread = + SparkContext.getActive.get.dagScheduler.eventProcessLoop.eventThread + if (schedulerEventLoopThread.getId == Thread.currentThread().getId) { + throw new RuntimeException("Cannot get SQLConf inside scheduler event loop thread.") + } + } + confGetter.get()() + } + } val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() @@ -1292,17 +1309,11 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ - if (Utils.isTesting && SparkEnv.get != null) { - // assert that we're only accessing it on the driver. - assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, - "SQLConf should only be created and accessed on the driver.") - } - /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) - @transient private val reader = new ConfigReader(settings) + @transient protected val reader = new ConfigReader(settings) /** ************************ Spark SQL Params/Hints ******************* */ @@ -1765,7 +1776,7 @@ class SQLConf extends Serializable with Logging { settings.containsKey(key) } - private def setConfWithCheck(key: String, value: String): Unit = { + protected def setConfWithCheck(key: String, value: String): Unit = { settings.put(key, value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index c502e583a55c5..e2a1a57c7dd4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} +import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -898,6 +898,7 @@ object SparkSession extends Logging { * @since 2.0.0 */ def getOrCreate(): SparkSession = synchronized { + assertOnDriver() // Get the session from current thread's active session. var session = activeThreadSession.get() if ((session ne null) && !session.sparkContext.isStopped) { @@ -1022,14 +1023,20 @@ object SparkSession extends Logging { * * @since 2.2.0 */ - def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) + def getActiveSession: Option[SparkSession] = { + assertOnDriver() + Option(activeThreadSession.get) + } /** * Returns the default SparkSession that is returned by the builder. * * @since 2.2.0 */ - def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + def getDefaultSession: Option[SparkSession] = { + assertOnDriver() + Option(defaultSession.get) + } /** * Returns the currently active SparkSession, otherwise the default one. If there is no default @@ -1062,6 +1069,14 @@ object SparkSession extends Logging { } } + private def assertOnDriver(): Unit = { + if (Utils.isTesting && TaskContext.get != null) { + // we're accessing it during task execution, fail. + throw new IllegalStateException( + "SparkSession should only be created and accessed on the driver.") + } + } + /** * Helper method to create an instance of `SessionState` based on `className` from conf. * The result is either `SessionState` or a Hive based `SessionState`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 2c5102b1e5ee7..439932b0cc3ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -68,16 +68,18 @@ object SQLExecution { // sparkContext.getCallSite() would first try to pick up any call site that was previously // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on // streaming queries would give us call site like "run at :0" - val callSite = sparkSession.sparkContext.getCallSite() + val callSite = sc.getCallSite() - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( - executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) - try { - body - } finally { - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + withSQLConfPropagated(sparkSession) { + sc.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) + try { + body + } finally { + sc.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) + } } } finally { executionIdToQueryExecution.remove(executionId) @@ -90,13 +92,41 @@ object SQLExecution { * thread from the original one, this method can be used to connect the Spark jobs in this action * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`. */ - def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = { + def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = { + val sc = sparkSession.sparkContext val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + withSQLConfPropagated(sparkSession) { + try { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) + body + } finally { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + } + } + } + + /** + * Wrap an action with specified SQL configs. These configs will be propagated to the executor + * side via job local properties. + */ + def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = { + val sc = sparkSession.sparkContext + // Set all the specified SQL configs to local properties, so that they can be available at + // the executor side. + val allConfigs = sparkSession.sessionState.conf.getAllConfs + val originalLocalProps = allConfigs.collect { + case (key, value) if key.startsWith("spark") => + val originalValue = sc.getLocalProperty(key) + sc.setLocalProperty(key, value) + (key, originalValue) + } + try { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) body } finally { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + for ((key, value) <- originalLocalProps) { + sc.setLocalProperty(key, value) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 2df81d09c58e7..9434ceb7cd16c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -643,7 +643,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { + SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { val beforeCollect = System.nanoTime() // Note that we use .executeCollect() because we don't want to convert data to Scala types val rows: Array[InternalRow] = child.executeCollect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index ba83df0efebd0..3b6df45e949e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -104,22 +105,19 @@ object TextInputJsonDataSource extends JsonDataSource { CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow) }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow)) - JsonInferSchema.infer(rdd, parsedOptions, rowParser) + SQLExecution.withSQLConfPropagated(json.sparkSession) { + JsonInferSchema.infer(rdd, parsedOptions, rowParser) + } } private def createBaseDataset( sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): Dataset[String] = { - val paths = inputPaths.map(_.getPath.toString) - val textOptions = Map.empty[String, String] ++ - parsedOptions.encoding.map("encoding" -> _) ++ - parsedOptions.lineSeparator.map("lineSep" -> _) - sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, - paths = paths, + paths = inputPaths.map(_.getPath.toString), className = classOf[TextFileFormat].getName, options = parsedOptions.parameters ).resolveRelation(checkFilesExist = false)) @@ -165,7 +163,9 @@ object MultiLineJsonDataSource extends JsonDataSource { .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) - JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) + SQLExecution.withSQLConfPropagated(sparkSession) { + JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) + } } private def createBaseRdd( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index a270a6451d5dd..e7eed95a560a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -45,8 +45,9 @@ private[sql] object JsonInferSchema { val parseMode = configOptions.parseMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord - // perform schema inference on each row and merge afterwards - val rootType = json.mapPartitions { iter => + // In each RDD partition, perform schema inference on each row and merge afterwards. + val typeMerger = compatibleRootType(columnNameOfCorruptRecord, parseMode) + val mergedTypesFromPartitions = json.mapPartitions { iter => val factory = new JsonFactory() configOptions.setJacksonOptions(factory) iter.flatMap { row => @@ -66,9 +67,13 @@ private[sql] object JsonInferSchema { s"Parse Mode: ${FailFastMode.name}.", e) } } - } - }.fold(StructType(Nil))( - compatibleRootType(columnNameOfCorruptRecord, parseMode)) + }.reduceOption(typeMerger).toIterator + } + + // Here we get RDD local iterator then fold, instead of calling `RDD.fold` directly, because + // `RDD.fold` will run the fold function in DAGScheduler event loop thread, which may not have + // active SparkSession and `SQLConf.get` may point to the wrong configs. + val rootType = mergedTypesFromPartitions.toLocalIterator.fold(StructType(Nil))(typeMerger) canonicalizeType(rootType) match { case Some(st: StructType) => st diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index daea6c39624d6..9e0ec9481b0de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -69,7 +69,7 @@ case class BroadcastExchangeExec( Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { + SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { try { val beforeCollect = System.nanoTime() // Use executeCollect/executeCollectIterator to avoid conversion to Scala types diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala new file mode 100644 index 0000000000000..3dd0712e02448 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.test.SQLTestUtils + +class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { + import testImplicits._ + + protected var spark: SparkSession = null + + // Create a new [[SparkSession]] running in local-cluster mode. + override def beforeAll(): Unit = { + super.beforeAll() + spark = SparkSession.builder() + .master("local-cluster[2,1,1024]") + .appName("testing") + .getOrCreate() + } + + override def afterAll(): Unit = { + spark.stop() + spark = null + } + + test("ReadOnlySQLConf is correctly created at the executor side") { + SQLConf.get.setConfString("spark.sql.x", "a") + try { + val checks = spark.range(10).mapPartitions { it => + val conf = SQLConf.get + Iterator(conf.isInstanceOf[ReadOnlySQLConf] && conf.getConfString("spark.sql.x") == "a") + }.collect() + assert(checks.forall(_ == true)) + } finally { + SQLConf.get.unsetConf("spark.sql.x") + } + } + + test("case-sensitive config should work for json schema inference") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTempPath { path => + val pathString = path.getCanonicalPath + spark.range(10).select('id.as("ID")).write.json(pathString) + spark.range(10).write.mode("append").json(pathString) + assert(spark.read.json(pathString).columns.toSet == Set("id", "ID")) + } + } + } +} From a33dcf4a0bbe20dce6f1e1e6c2e1c3828291fb3d Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 21 May 2018 12:58:05 -0700 Subject: [PATCH 0841/2461] [SPARK-24234][SS] Reader for continuous processing shuffle ## What changes were proposed in this pull request? Read RDD for continuous processing shuffle, as well as the initial RPC-based row receiver. https://docs.google.com/document/d/1IL4kJoKrZWeyIhklKUJqsW-yEN7V7aL05MmM65AYOfE/edit#heading=h.8t3ci57f7uii ## How was this patch tested? new unit tests Author: Jose Torres Closes #21337 from jose-torres/readerRddMaster. --- .../shuffle/ContinuousShuffleReadRDD.scala | 61 ++++++ .../shuffle/ContinuousShuffleReader.scala | 32 +++ .../shuffle/UnsafeRowReceiver.scala | 75 +++++++ .../shuffle/ContinuousShuffleReadSuite.scala | 184 ++++++++++++++++++ 4 files changed, 352 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala new file mode 100644 index 0000000000000..270b1a5c28dee --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import java.util.UUID + +import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.NextIterator + +case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Partition { + // Initialized only on the executor, and only once even as we call compute() multiple times. + lazy val (reader: ContinuousShuffleReader, endpoint) = { + val env = SparkEnv.get.rpcEnv + val receiver = new UnsafeRowReceiver(queueSize, env) + val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver) + TaskContext.get().addTaskCompletionListener { ctx => + env.stop(endpoint) + } + (receiver, endpoint) + } +} + +/** + * RDD at the map side of each continuous processing shuffle task. Upstream tasks send their + * shuffle output to the wrapped receivers in partitions of this RDD; each of the RDD's tasks + * poll from their receiver until an epoch marker is sent. + */ +class ContinuousShuffleReadRDD( + sc: SparkContext, + numPartitions: Int, + queueSize: Int = 1024) + extends RDD[UnsafeRow](sc, Nil) { + + override protected def getPartitions: Array[Partition] = { + (0 until numPartitions).map { partIndex => + ContinuousShuffleReadPartition(partIndex, queueSize) + }.toArray + } + + override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + split.asInstanceOf[ContinuousShuffleReadPartition].reader.read() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala new file mode 100644 index 0000000000000..42631c90ebc55 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +/** + * Trait for reading from a continuous processing shuffle. + */ +trait ContinuousShuffleReader { + /** + * Returns an iterator over the incoming rows in an epoch. Implementations should block waiting + * for new rows to arrive, and end the iterator once they've received epoch markers from all + * shuffle writers. + */ + def read(): Iterator[UnsafeRow] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala new file mode 100644 index 0000000000000..b8adbb743c6c2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.NextIterator + +/** + * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker. + */ +private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable +private[shuffle] case class ReceiverRow(row: UnsafeRow) extends UnsafeRowReceiverMessage +private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessage + +/** + * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle + * writers will send rows here, with continuous shuffle readers polling for new rows as needed. + * + * TODO: Support multiple source tasks. We need to output a single epoch marker once all + * source tasks have sent one. + */ +private[shuffle] class UnsafeRowReceiver( + queueSize: Int, + override val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging { + // Note that this queue will be drained from the main task thread and populated in the RPC + // response thread. + private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) + + // Exposed for testing to determine if the endpoint gets stopped on task end. + private[shuffle] val stopped = new AtomicBoolean(false) + + override def onStop(): Unit = { + stopped.set(true) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case r: UnsafeRowReceiverMessage => + queue.put(r) + context.reply(()) + } + + override def read(): Iterator[UnsafeRow] = { + new NextIterator[UnsafeRow] { + override def getNext(): UnsafeRow = queue.take() match { + case ReceiverRow(r) => r + case ReceiverEpochMarker() => + finished = true + null + } + + override def close(): Unit = {} + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala new file mode 100644 index 0000000000000..b25e75b3b37a6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.{DataType, IntegerType} + +class ContinuousShuffleReadSuite extends StreamTest { + + private def unsafeRow(value: Int) = { + UnsafeProjection.create(Array(IntegerType : DataType))( + new GenericInternalRow(Array(value: Any))) + } + + private def send(endpoint: RpcEndpointRef, messages: UnsafeRowReceiverMessage*) = { + messages.foreach(endpoint.askSync[Unit](_)) + } + + // In this unit test, we emulate that we're in the task thread where + // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context + // thread local to be set. + var ctx: TaskContextImpl = _ + + override def beforeEach(): Unit = { + super.beforeEach() + ctx = TaskContext.empty() + TaskContext.setTaskContext(ctx) + } + + override def afterEach(): Unit = { + ctx.markTaskCompleted(None) + TaskContext.unset() + ctx = null + super.afterEach() + } + + test("receiver stopped with row last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverEpochMarker(), + ReceiverRow(unsafeRow(111)) + ) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) + } + } + + test("receiver stopped with marker last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + endpoint.askSync[Unit](ReceiverEpochMarker()) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) + } + } + + test("one epoch") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(unsafeRow(111)), + ReceiverRow(unsafeRow(222)), + ReceiverRow(unsafeRow(333)), + ReceiverEpochMarker() + ) + + val iter = rdd.compute(rdd.partitions(0), ctx) + assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333)) + } + + test("multiple epochs") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(unsafeRow(111)), + ReceiverEpochMarker(), + ReceiverRow(unsafeRow(222)), + ReceiverRow(unsafeRow(333)), + ReceiverEpochMarker() + ) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(firstEpoch.toSeq.map(_.getInt(0)) == Seq(111)) + + val secondEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333)) + } + + test("empty epochs") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverEpochMarker(), + ReceiverEpochMarker(), + ReceiverRow(unsafeRow(111)), + ReceiverEpochMarker(), + ReceiverEpochMarker(), + ReceiverEpochMarker() + ) + + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + + val thirdEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(thirdEpoch.toSeq.map(_.getInt(0)) == Seq(111)) + + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + } + + test("multiple partitions") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5) + // Send all data before processing to ensure there's no crossover. + for (p <- rdd.partitions) { + val part = p.asInstanceOf[ContinuousShuffleReadPartition] + // Send index for identification. + send( + part.endpoint, + ReceiverRow(unsafeRow(part.index)), + ReceiverEpochMarker() + ) + } + + for (p <- rdd.partitions) { + val part = p.asInstanceOf[ContinuousShuffleReadPartition] + val iter = rdd.compute(part, ctx) + assert(iter.next().getInt(0) == part.index) + assert(!iter.hasNext) + } + } + + test("blocks waiting for new rows") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + + val readRowThread = new Thread { + override def run(): Unit = { + // set the non-inheritable thread local + TaskContext.setTaskContext(ctx) + val epoch = rdd.compute(rdd.partitions(0), ctx) + epoch.next().getInt(0) + } + } + + try { + readRowThread.start() + eventually(timeout(streamingTimeout)) { + assert(readRowThread.getState == Thread.State.WAITING) + } + } finally { + readRowThread.interrupt() + readRowThread.join() + } + } +} From ffaefe755e20cb94e27f07b233615a4bbb476679 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 21 May 2018 13:05:17 -0700 Subject: [PATCH 0842/2461] [SPARK-7132][ML] Add fit with validation set to spark.ml GBT ## What changes were proposed in this pull request? Add fit with validation set to spark.ml GBT ## How was this patch tested? Will add later. Author: WeichenXu Closes #21129 from WeichenXu123/gbt_fit_validation. --- .../ml/classification/GBTClassifier.scala | 38 ++++++++++++--- .../ml/param/shared/SharedParamsCodeGen.scala | 5 +- .../spark/ml/param/shared/sharedParams.scala | 17 +++++++ .../spark/ml/regression/GBTRegressor.scala | 31 ++++++++++-- .../org/apache/spark/ml/tree/treeParams.scala | 41 +++++++++++----- .../classification/GBTClassifierSuite.scala | 46 ++++++++++++++++++ .../ml/regression/GBTRegressorSuite.scala | 48 ++++++++++++++++++- project/MimaExcludes.scala | 13 ++++- 8 files changed, 213 insertions(+), 26 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 3fb6d1e4e4f3e..337133a2e2326 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -146,12 +146,21 @@ class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) + /** @group setParam */ + @Since("2.4.0") + def setValidationIndicatorCol(value: String): this.type = { + set(validationIndicatorCol, value) + } + override protected def train(dataset: Dataset[_]): GBTClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + + val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty + // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports // 2 classes now. This lets us provide a more precise error message. - val oldDataset: RDD[LabeledPoint] = + val convert2LabeledPoint = (dataset: Dataset[_]) => { dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => require(label == 0 || label == 1, s"GBTClassifier was given" + @@ -159,7 +168,18 @@ class GBTClassifier @Since("1.4.0") ( s" GBTClassifier currently only supports binary classification.") LabeledPoint(label, features) } - val numFeatures = oldDataset.first().features.size + } + + val (trainDataset, validationDataset) = if (withValidation) { + ( + convert2LabeledPoint(dataset.filter(not(col($(validationIndicatorCol))))), + convert2LabeledPoint(dataset.filter(col($(validationIndicatorCol)))) + ) + } else { + (convert2LabeledPoint(dataset), null) + } + + val numFeatures = trainDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) val numClasses = 2 @@ -169,15 +189,21 @@ class GBTClassifier @Since("1.4.0") ( s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") } - val instr = Instrumentation.create(this, oldDataset) + val instr = Instrumentation.create(this, dataset) instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, - seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) + seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy, + validationIndicatorCol) instr.logNumFeatures(numFeatures) instr.logNumClasses(numClasses) - val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, - $(seed), $(featureSubsetStrategy)) + val (baseLearners, learnerWeights) = if (withValidation) { + GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy, + $(seed), $(featureSubsetStrategy)) + } else { + GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) + } + val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) instr.logSuccess(m) m diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index b9c3170cc3c28..7e08675f834da 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -95,7 +95,10 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("distanceMeasure", "The distance measure. Supported options: 'euclidean'" + " and 'cosine'", Some("org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN"), isValid = "(value: String) => " + - "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)") + "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)"), + ParamDesc[String]("validationIndicatorCol", "name of the column that indicates whether " + + "each row is for training or for validation. False indicates training; true indicates " + + "validation.") ) val code = genSharedParams(params) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 282ea6ebcbf7f..5928a0749f738 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -523,4 +523,21 @@ trait HasDistanceMeasure extends Params { /** @group getParam */ final def getDistanceMeasure: String = $(distanceMeasure) } + +/** + * Trait for shared param validationIndicatorCol. This trait may be changed or + * removed between minor versions. + */ +@DeveloperApi +trait HasValidationIndicatorCol extends Params { + + /** + * Param for name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.. + * @group param + */ + final val validationIndicatorCol: Param[String] = new Param[String](this, "validationIndicatorCol", "name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.") + + /** @group getParam */ + final def getValidationIndicatorCol: String = $(validationIndicatorCol) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index d7e054bf55ef6..eb8b3c001436a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -145,21 +145,42 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) override def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) + /** @group setParam */ + @Since("2.4.0") + def setValidationIndicatorCol(value: String): this.type = { + set(validationIndicatorCol, value) + } + override protected def train(dataset: Dataset[_]): GBTRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) - val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) - val numFeatures = oldDataset.first().features.size + + val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty + + val (trainDataset, validationDataset) = if (withValidation) { + ( + extractLabeledPoints(dataset.filter(not(col($(validationIndicatorCol))))), + extractLabeledPoints(dataset.filter(col($(validationIndicatorCol)))) + ) + } else { + (extractLabeledPoints(dataset), null) + } + val numFeatures = trainDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) - val instr = Instrumentation.create(this, oldDataset) + val instr = Instrumentation.create(this, dataset) instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) instr.logNumFeatures(numFeatures) - val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, - $(seed), $(featureSubsetStrategy)) + val (baseLearners, learnerWeights) = if (withValidation) { + GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy, + $(seed), $(featureSubsetStrategy)) + } else { + GradientBoostedTrees.run(trainDataset, boostingStrategy, + $(seed), $(featureSubsetStrategy)) + } val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) instr.logSuccess(m) m diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index ec8868bb42cbb..00157fe63af41 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -21,6 +21,7 @@ import java.util.Locale import scala.util.Try +import org.apache.spark.annotation.Since import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -460,18 +461,34 @@ private[ml] trait RandomForestRegressorParams * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize { - - /* TODO: Add this doc when we add this param. SPARK-7132 - * Threshold for stopping early when runWithValidation is used. - * If the error rate on the validation input changes by less than the validationTol, - * then learning will stop early (before [[numIterations]]). - * This parameter is ignored when run is used. - * (default = 1e-5) +private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize + with HasValidationIndicatorCol { + + /** + * Threshold for stopping early when fit with validation is used. + * (This parameter is ignored when fit without validation is used.) + * The decision to stop early is decided based on this logic: + * If the current loss on the validation set is greater than 0.01, the diff + * of validation error is compared to relative tolerance which is + * validationTol * (current loss on the validation set). + * If the current loss on the validation set is less than or equal to 0.01, + * the diff of validation error is compared to absolute tolerance which is + * validationTol * 0.01. * @group param + * @see validationIndicatorCol */ - // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "") - // validationTol -> 1e-5 + @Since("2.4.0") + final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", + "Threshold for stopping early when fit with validation is used." + + "If the error rate on the validation input changes by less than the validationTol," + + "then learning will stop early (before `maxIter`)." + + "This parameter is ignored when fit without validation is used.", + ParamValidators.gtEq(0.0) + ) + + /** @group getParam */ + @Since("2.4.0") + final def getValidationTol: Double = $(validationTol) /** * @deprecated This method is deprecated and will be removed in 3.0.0. @@ -497,7 +514,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") def setStepSize(value: Double): this.type = set(stepSize, value) - setDefault(maxIter -> 20, stepSize -> 0.1) + setDefault(maxIter -> 20, stepSize -> 0.1, validationTol -> 0.01) setDefault(featureSubsetStrategy -> "all") @@ -507,7 +524,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS oldAlgo: OldAlgo.Algo): OldBoostingStrategy = { val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance) // NOTE: The old API does not support "seed" so we ignore it. - new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize) + new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize, getValidationTol) } /** Get old Gradient Boosting Loss type */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index e20de196d65ca..e6d2a8e2b900e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.loss.LogLoss import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.lit import org.apache.spark.util.Utils /** @@ -392,6 +393,51 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { assert(evalArr(2) ~== lossErr3 relTol 1E-3) } + test("runWithValidation stops early and performs better on a validation dataset") { + val validationIndicatorCol = "validationIndicator" + val trainDF = trainData.toDF().withColumn(validationIndicatorCol, lit(false)) + val validationDF = validationData.toDF().withColumn(validationIndicatorCol, lit(true)) + + val numIter = 20 + for (lossType <- GBTClassifier.supportedLossTypes) { + val gbt = new GBTClassifier() + .setSeed(123) + .setMaxDepth(2) + .setLossType(lossType) + .setMaxIter(numIter) + val modelWithoutValidation = gbt.fit(trainDF) + + gbt.setValidationIndicatorCol(validationIndicatorCol) + val modelWithValidation = gbt.fit(trainDF.union(validationDF)) + + assert(modelWithoutValidation.numTrees === numIter) + // early stop + assert(modelWithValidation.numTrees < numIter) + + val (errorWithoutValidation, errorWithValidation) = { + val remappedRdd = validationData.map(x => new LabeledPoint(2 * x.label - 1, x.features)) + (GradientBoostedTrees.computeError(remappedRdd, modelWithoutValidation.trees, + modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType), + GradientBoostedTrees.computeError(remappedRdd, modelWithValidation.trees, + modelWithValidation.treeWeights, modelWithValidation.getOldLossType)) + } + assert(errorWithValidation < errorWithoutValidation) + + val evaluationArray = GradientBoostedTrees + .evaluateEachIteration(validationData, modelWithoutValidation.trees, + modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType, + OldAlgo.Classification) + assert(evaluationArray.length === numIter) + assert(evaluationArray(modelWithValidation.numTrees) > + evaluationArray(modelWithValidation.numTrees - 1)) + var i = 1 + while (i < modelWithValidation.numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 773f6d2c542fe..b145c7a3dc952 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.lit import org.apache.spark.util.Utils /** @@ -231,7 +232,52 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { } } - ///////////////////////////////////////////////////////////////////////////// + test("runWithValidation stops early and performs better on a validation dataset") { + val validationIndicatorCol = "validationIndicator" + val trainDF = trainData.toDF().withColumn(validationIndicatorCol, lit(false)) + val validationDF = validationData.toDF().withColumn(validationIndicatorCol, lit(true)) + + val numIter = 20 + for (lossType <- GBTRegressor.supportedLossTypes) { + val gbt = new GBTRegressor() + .setSeed(123) + .setMaxDepth(2) + .setLossType(lossType) + .setMaxIter(numIter) + val modelWithoutValidation = gbt.fit(trainDF) + + gbt.setValidationIndicatorCol(validationIndicatorCol) + val modelWithValidation = gbt.fit(trainDF.union(validationDF)) + + assert(modelWithoutValidation.numTrees === numIter) + // early stop + assert(modelWithValidation.numTrees < numIter) + + val errorWithoutValidation = GradientBoostedTrees.computeError(validationData, + modelWithoutValidation.trees, modelWithoutValidation.treeWeights, + modelWithoutValidation.getOldLossType) + val errorWithValidation = GradientBoostedTrees.computeError(validationData, + modelWithValidation.trees, modelWithValidation.treeWeights, + modelWithValidation.getOldLossType) + + assert(errorWithValidation < errorWithoutValidation) + + val evaluationArray = GradientBoostedTrees + .evaluateEachIteration(validationData, modelWithoutValidation.trees, + modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType, + OldAlgo.Regression) + assert(evaluationArray.length === numIter) + assert(evaluationArray(modelWithValidation.numTrees) > + evaluationArray(modelWithValidation.numTrees - 1)) + var i = 1 + while (i < modelWithValidation.numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } + } + } + + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7d0e88ee20c3f..6bae4d147d4ac 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -73,7 +73,18 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"), ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.Node"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this") + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this"), + + // [SPARK-7132][ML] Add fit with validation set to spark.ml GBT + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol") ) // Exclude rules for 2.3.x From b550b2a1a159941c7327973182f16004a6bf179d Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 21 May 2018 14:21:05 -0700 Subject: [PATCH 0843/2461] [SPARK-24325] Tests for Hadoop's LinesReader ## What changes were proposed in this pull request? The tests cover basic functionality of [Hadoop LinesReader](https://github.com/apache/spark/blob/8d79113b812a91073d2c24a3a9ad94cc3b90b24a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala#L42). In particular, the added tests check: - A split slices a line or delimiter - A split slices two consecutive lines and cover a delimiter between the lines - Two splits slice a line and there are no duplicates - Internal buffer size (`io.file.buffer.size`) is less than line length - Constrain of maximum line length - `mapreduce.input.linerecordreader.line.maxlength` Author: Maxim Gekk Closes #21377 from MaxGekk/line-reader-tests. --- .../HadoopFileLinesReaderSuite.scala | 137 ++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala new file mode 100644 index 0000000000000..a39a25be262a6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.File +import java.nio.charset.StandardCharsets +import java.nio.file.Files + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.test.SharedSQLContext + +class HadoopFileLinesReaderSuite extends SharedSQLContext { + def getLines( + path: File, + text: String, + ranges: Seq[(Long, Long)], + delimiter: Option[String] = None, + conf: Option[Configuration] = None): Seq[String] = { + val delimOpt = delimiter.map(_.getBytes(StandardCharsets.UTF_8)) + Files.write(path.toPath, text.getBytes(StandardCharsets.UTF_8)) + + val lines = ranges.map { case (start, length) => + val file = PartitionedFile(InternalRow.empty, path.getCanonicalPath, start, length) + val hadoopConf = conf.getOrElse(spark.sparkContext.hadoopConfiguration) + val reader = new HadoopFileLinesReader(file, delimOpt, hadoopConf) + + reader.map(_.toString) + }.flatten + + lines + } + + test("A split ends at the delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 1), (1, 3))) + assert(lines == Seq("a", "b")) + } + } + + test("A split cuts the delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 2), (2, 2))) + assert(lines == Seq("a", "b")) + } + } + + test("A split ends at the end of the delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 3), (3, 1))) + assert(lines == Seq("a", "b")) + } + } + + test("A split covers two lines") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 4), (4, 1))) + assert(lines == Seq("a", "b")) + } + } + + test("A split ends at the custom delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a^_^b", ranges = Seq((0, 1), (1, 4)), Some("^_^")) + assert(lines == Seq("a", "b")) + } + } + + test("A split slices the custom delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a^_^b", ranges = Seq((0, 2), (2, 3)), Some("^_^")) + assert(lines == Seq("a", "b")) + } + } + + test("The first split covers the first line and the custom delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a^_^b", ranges = Seq((0, 4), (4, 1)), Some("^_^")) + assert(lines == Seq("a", "b")) + } + } + + test("A split cuts the first line") { + withTempPath { path => + val lines = getLines(path, text = "abc,def", ranges = Seq((0, 1)), Some(",")) + assert(lines == Seq("abc")) + } + } + + test("The split cuts both lines") { + withTempPath { path => + val lines = getLines(path, text = "abc,def", ranges = Seq((2, 2)), Some(",")) + assert(lines == Seq("def")) + } + } + + test("io.file.buffer.size is less than line length") { + val conf = spark.sparkContext.hadoopConfiguration + conf.set("io.file.buffer.size", "2") + withTempPath { path => + val lines = getLines(path, text = "abcdef\n123456", ranges = Seq((4, 4), (8, 5))) + assert(lines == Seq("123456")) + } + } + + test("line cannot be longer than line.maxlength") { + val conf = spark.sparkContext.hadoopConfiguration + conf.set("mapreduce.input.linerecordreader.line.maxlength", "5") + withTempPath { path => + val lines = getLines(path, text = "abcdef\n1234", ranges = Seq((0, 15))) + assert(lines == Seq("1234")) + } + } + + test("default delimiter is 0xd or 0xa or 0xd0xa") { + withTempPath { path => + val lines = getLines(path, text = "1\r2\n3\r\n4", ranges = Seq((0, 3), (3, 5))) + assert(lines == Seq("1", "2", "3", "4")) + } + } +} From 32447079e9d0fa9f7e180b94ecac19091b6af1ab Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 21 May 2018 16:26:39 -0700 Subject: [PATCH 0844/2461] [SPARK-24309][CORE] AsyncEventQueue should stop on interrupt. EventListeners can interrupt the event queue thread. In particular, when the EventLoggingListener writes to hdfs, hdfs can interrupt the thread. When there is an interrupt, the queue should be removed and stop accepting any more events. Before this change, the queue would continue to take more events (till it was full), and then would not stop when the application was complete because the PoisonPill couldn't be added. Added a unit test which failed before this change. Author: Imran Rashid Closes #21356 from squito/SPARK-24309. --- .../spark/scheduler/AsyncEventQueue.scala | 41 ++++++++------ .../spark/scheduler/LiveListenerBus.scala | 2 +- .../org/apache/spark/util/ListenerBus.scala | 18 +++++++ .../spark/scheduler/SparkListenerSuite.scala | 54 +++++++++++++++++++ 4 files changed, 98 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala index c1fedd63f6a90..e2b6df4600590 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala @@ -34,7 +34,11 @@ import org.apache.spark.util.Utils * Delivery will only begin when the `start()` method is called. The `stop()` method should be * called when no more events need to be delivered. */ -private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveListenerBusMetrics) +private class AsyncEventQueue( + val name: String, + conf: SparkConf, + metrics: LiveListenerBusMetrics, + bus: LiveListenerBus) extends SparkListenerBus with Logging { @@ -81,23 +85,18 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi } private def dispatch(): Unit = LiveListenerBus.withinListenerThread.withValue(true) { - try { - var next: SparkListenerEvent = eventQueue.take() - while (next != POISON_PILL) { - val ctx = processingTime.time() - try { - super.postToAll(next) - } finally { - ctx.stop() - } - eventCount.decrementAndGet() - next = eventQueue.take() + var next: SparkListenerEvent = eventQueue.take() + while (next != POISON_PILL) { + val ctx = processingTime.time() + try { + super.postToAll(next) + } finally { + ctx.stop() } eventCount.decrementAndGet() - } catch { - case ie: InterruptedException => - logInfo(s"Stopping listener queue $name.", ie) + next = eventQueue.take() } + eventCount.decrementAndGet() } override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = { @@ -130,7 +129,11 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi eventCount.incrementAndGet() eventQueue.put(POISON_PILL) } - dispatchThread.join() + // this thread might be trying to stop itself as part of error handling -- we can't join + // in that case. + if (Thread.currentThread() != dispatchThread) { + dispatchThread.join() + } } def post(event: SparkListenerEvent): Unit = { @@ -187,6 +190,12 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi true } + override def removeListenerOnError(listener: SparkListenerInterface): Unit = { + // the listener failed in an unrecoverably way, we want to remove it from the entire + // LiveListenerBus (potentially stopping a queue if it is empty) + bus.removeListener(listener) + } + } private object AsyncEventQueue { diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index ba6387a8f08ad..d135190d1e919 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -102,7 +102,7 @@ private[spark] class LiveListenerBus(conf: SparkConf) { queue.addListener(listener) case None => - val newQueue = new AsyncEventQueue(queue, conf, metrics) + val newQueue = new AsyncEventQueue(queue, conf, metrics, this) newQueue.addListener(listener) if (started.get()) { newQueue.start(sparkContext) diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index b25a731401f23..d4474a90b26f1 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -60,6 +60,15 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { } } + /** + * This can be overriden by subclasses if there is any extra cleanup to do when removing a + * listener. In particular AsyncEventQueues can clean up queues in the LiveListenerBus. + */ + def removeListenerOnError(listener: L): Unit = { + removeListener(listener) + } + + /** * Post the event to all registered listeners. The `postToAll` caller should guarantee calling * `postToAll` in the same thread for all events. @@ -80,7 +89,16 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { } try { doPostEvent(listener, event) + if (Thread.interrupted()) { + // We want to throw the InterruptedException right away so we can associate the interrupt + // with this listener, as opposed to waiting for a queue.take() etc. to detect it. + throw new InterruptedException() + } } catch { + case ie: InterruptedException => + logError(s"Interrupted while posting to ${Utils.getFormattedClassName(listener)}. " + + s"Removing that listener.", ie) + removeListenerOnError(listener) case NonFatal(e) if !isIgnorableException(e) => logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) } finally { diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index fa47a52bbbc47..6ffd1e84f7adb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -489,6 +489,48 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match assert(bus.findListenersByClass[BasicJobCounter]().isEmpty) } + Seq(true, false).foreach { throwInterruptedException => + val suffix = if (throwInterruptedException) "throw interrupt" else "set Thread interrupted" + test(s"interrupt within listener is handled correctly: $suffix") { + val conf = new SparkConf(false) + .set(LISTENER_BUS_EVENT_QUEUE_CAPACITY, 5) + val bus = new LiveListenerBus(conf) + val counter1 = new BasicJobCounter() + val counter2 = new BasicJobCounter() + val interruptingListener1 = new InterruptingListener(throwInterruptedException) + val interruptingListener2 = new InterruptingListener(throwInterruptedException) + bus.addToSharedQueue(counter1) + bus.addToSharedQueue(interruptingListener1) + bus.addToStatusQueue(counter2) + bus.addToEventLogQueue(interruptingListener2) + assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE, EVENT_LOG_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 2) + assert(bus.findListenersByClass[InterruptingListener]().size === 2) + + bus.start(mockSparkContext, mockMetricsSystem) + + // after we post one event, both interrupting listeners should get removed, and the + // event log queue should be removed + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 2) + assert(bus.findListenersByClass[InterruptingListener]().size === 0) + assert(counter1.count === 1) + assert(counter2.count === 1) + + // posting more events should be fine, they'll just get processed from the OK queue. + (0 until 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(counter1.count === 6) + assert(counter2.count === 6) + + // Make sure stopping works -- this requires putting a poison pill in all active queues, which + // would fail if our interrupted queue was still active, as its queue would be full. + bus.stop() + } + } + /** * Assert that the given list of numbers has an average that is greater than zero. */ @@ -547,6 +589,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { throw new Exception } } + /** + * A simple listener that interrupts on job end. + */ + private class InterruptingListener(val throwInterruptedException: Boolean) extends SparkListener { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + if (throwInterruptedException) { + throw new InterruptedException("got interrupted") + } else { + Thread.currentThread().interrupt() + } + } + } } // These classes can't be declared inside of the SparkListenerSuite class because we don't want From 84d31aa5d453620d462f1fdd90206c676a8395cd Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 21 May 2018 18:11:05 -0700 Subject: [PATCH 0845/2461] [SPARK-24209][SHS] Automatic retrieve proxyBase from Knox headers ## What changes were proposed in this pull request? The PR retrieves the proxyBase automatically from the header `X-Forwarded-Context` (if available). This is the header used by Knox to inform the proxied service about the base path. This provides 0-configuration support for Knox gateway (instead of having to properly set `spark.ui.proxyBase`) and it allows to access directly SHS when it is proxied by Knox. In the previous scenario, indeed, after setting `spark.ui.proxyBase`, direct access to SHS was not working fine (due to bad link generated). ## How was this patch tested? added UT + manual tests Author: Marco Gaido Closes #21268 from mgaido91/SPARK-24209. --- .../spark/deploy/history/HistoryPage.scala | 17 +-- .../spark/deploy/history/HistoryServer.scala | 2 +- .../deploy/master/ui/ApplicationPage.scala | 4 +- .../spark/deploy/master/ui/MasterPage.scala | 2 +- .../spark/deploy/worker/ui/LogPage.scala | 2 +- .../spark/deploy/worker/ui/WorkerPage.scala | 2 +- .../scala/org/apache/spark/ui/UIUtils.scala | 109 ++++++++++-------- .../apache/spark/ui/env/EnvironmentPage.scala | 2 +- .../ui/exec/ExecutorThreadDumpPage.scala | 2 +- .../apache/spark/ui/exec/ExecutorsTab.scala | 6 +- .../apache/spark/ui/jobs/AllJobsPage.scala | 4 +- .../apache/spark/ui/jobs/AllStagesPage.scala | 4 +- .../org/apache/spark/ui/jobs/JobPage.scala | 5 +- .../org/apache/spark/ui/jobs/PoolPage.scala | 4 +- .../org/apache/spark/ui/jobs/PoolTable.scala | 9 +- .../org/apache/spark/ui/jobs/StagePage.scala | 8 +- .../org/apache/spark/ui/jobs/StageTable.scala | 12 +- .../org/apache/spark/ui/storage/RDDPage.scala | 7 +- .../apache/spark/ui/storage/StoragePage.scala | 20 +++- .../deploy/history/HistoryServerSuite.scala | 24 ++++ .../spark/ui/storage/StoragePageSuite.scala | 7 +- .../spark/deploy/mesos/ui/DriverPage.scala | 6 +- .../deploy/mesos/ui/MesosClusterPage.scala | 2 +- .../sql/execution/ui/AllExecutionsPage.scala | 38 +++--- .../sql/execution/ui/ExecutionPage.scala | 30 ++--- .../thriftserver/ui/ThriftServerPage.scala | 17 +-- .../ui/ThriftServerSessionPage.scala | 9 +- .../apache/spark/streaming/ui/BatchPage.scala | 21 +++- .../spark/streaming/ui/StreamingPage.scala | 12 +- 29 files changed, 232 insertions(+), 155 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 6fc12d721e6f1..32667ddf5c7ea 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -37,8 +37,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val lastUpdatedTime = parent.getLastUpdatedTime() val providerConfig = parent.getProviderConfig() val content = - ++ - + ++ +
    - UIUtils.basicSparkPage(content, "History Server", true) + UIUtils.basicSparkPage(request, content, "History Server", true) } - private def makePageLink(showIncomplete: Boolean): String = { - UIUtils.prependBaseUri("/?" + "showIncomplete=" + showIncomplete) + private def makePageLink(request: HttpServletRequest, showIncomplete: Boolean): String = { + UIUtils.prependBaseUri(request, "/?" + "showIncomplete=" + showIncomplete) } private def isApplicationCompleted(appInfo: ApplicationInfo): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 611fa563a7cd9..a9a4d5a4ec6a2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -87,7 +87,7 @@ class HistoryServer( if (!loadAppUi(appId, None) && (!attemptId.isDefined || !loadAppUi(appId, attemptId))) { val msg =
    Application {appId} not found.
    res.setStatus(HttpServletResponse.SC_NOT_FOUND) - UIUtils.basicSparkPage(msg, "Not Found").foreach { n => + UIUtils.basicSparkPage(req, msg, "Not Found").foreach { n => res.getWriter().write(n.toString) } return diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index f699c75085fe1..fad4e46dc035d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -40,7 +40,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") .getOrElse(state.completedApps.find(_.id == appId).orNull) if (app == null) { val msg =
    No running application with ID {appId}
    - return UIUtils.basicSparkPage(msg, "Not Found") + return UIUtils.basicSparkPage(request, msg, "Not Found") } val executorHeaders = Seq("ExecutorID", "Worker", "Cores", "Memory", "State", "Logs") @@ -127,7 +127,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") } ; - UIUtils.basicSparkPage(content, "Application: " + app.desc.name) + UIUtils.basicSparkPage(request, content, "Application: " + app.desc.name) } private def executorRow(executor: ExecutorDesc): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index c629937606b51..b8afe203fbfa2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -215,7 +215,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } ; - UIUtils.basicSparkPage(content, "Spark Master at " + state.uri) + UIUtils.basicSparkPage(request, content, "Spark Master at " + state.uri) } private def workerRow(worker: WorkerInfo): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 2f5a5642d3cab..4fca9342c0378 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -118,7 +118,7 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with - UIUtils.basicSparkPage(content, logType + " log page for " + pageName) + UIUtils.basicSparkPage(request, content, logType + " log page for " + pageName) } /** Get the part of the log files given the offset and desired length of bytes */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 8b98ae56fc108..aa4e28d213e2b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -135,7 +135,7 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { } ; - UIUtils.basicSparkPage(content, "Spark Worker at %s:%s".format( + UIUtils.basicSparkPage(request, content, "Spark Worker at %s:%s".format( workerState.host, workerState.port)) } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 02cf19e00ecde..5d015b0531ef6 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.ui import java.net.URLDecoder import java.text.SimpleDateFormat import java.util.{Date, Locale, TimeZone} +import javax.servlet.http.HttpServletRequest import scala.util.control.NonFatal import scala.xml._ @@ -148,60 +149,71 @@ private[spark] object UIUtils extends Logging { } // Yarn has to go through a proxy so the base uri is provided and has to be on all links - def uiRoot: String = { + def uiRoot(request: HttpServletRequest): String = { + // Knox uses X-Forwarded-Context to notify the application the base path + val knoxBasePath = Option(request.getHeader("X-Forwarded-Context")) // SPARK-11484 - Use the proxyBase set by the AM, if not found then use env. sys.props.get("spark.ui.proxyBase") .orElse(sys.env.get("APPLICATION_WEB_PROXY_BASE")) + .orElse(knoxBasePath) .getOrElse("") } - def prependBaseUri(basePath: String = "", resource: String = ""): String = { - uiRoot + basePath + resource + def prependBaseUri( + request: HttpServletRequest, + basePath: String = "", + resource: String = ""): String = { + uiRoot(request) + basePath + resource } - def commonHeaderNodes: Seq[Node] = { + def commonHeaderNodes(request: HttpServletRequest): Seq[Node] = { - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + } - def vizHeaderNodes: Seq[Node] = { - - - - - + def vizHeaderNodes(request: HttpServletRequest): Seq[Node] = { + + + + + } - def dataTablesHeaderNodes: Seq[Node] = { + def dataTablesHeaderNodes(request: HttpServletRequest): Seq[Node] = { + + href={prependBaseUri(request, "/static/dataTables.bootstrap.css")} type="text/css"/> - - - - - - - + href={prependBaseUri(request, "/static/jsonFormatter.min.css")} type="text/css"/> + + + + + + } /** Returns a spark page with correctly formatted headers */ def headerSparkPage( + request: HttpServletRequest, title: String, content: => Seq[Node], activeTab: SparkUITab, @@ -214,25 +226,26 @@ private[spark] object UIUtils extends Logging { val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..." val header = activeTab.headerTabs.map { tab =>
  • - {tab.name} + {tab.name}
  • } val helpButton: Seq[Node] = helpText.map(tooltip(_, "bottom")).getOrElse(Seq.empty) - {commonHeaderNodes} - {if (showVisualization) vizHeaderNodes else Seq.empty} - {if (useDataTables) dataTablesHeaderNodes else Seq.empty} - + {commonHeaderNodes(request)} + {if (showVisualization) vizHeaderNodes(request) else Seq.empty} + {if (useDataTables) dataTablesHeaderNodes(request) else Seq.empty} + {appName} - {title}
    }.getOrElse(Text("Error fetching thread dump")) - UIUtils.headerSparkPage(s"Thread dump for executor $executorId", content, parent) + UIUtils.headerSparkPage(request, s"Thread dump for executor $executorId", content, parent) } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 843486f4a70d2..d5a60f52cbb0f 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -49,12 +49,12 @@ private[ui] class ExecutorsPage(
    {
    ++ - ++ - ++ + ++ + ++ }
    - UIUtils.headerSparkPage("Executors", content, parent, useDataTables = true) + UIUtils.headerSparkPage(request, "Executors", content, parent, useDataTables = true) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 2b0f4acbac72a..f651fe97c2cd5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -248,7 +248,7 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We jobs, tableHeaderId, jobTag, - UIUtils.prependBaseUri(parent.basePath), + UIUtils.prependBaseUri(request, parent.basePath), "jobs", // subPath parameterOtherTable, killEnabled, @@ -407,7 +407,7 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We val helpText = """A job is triggered by an action, like count() or saveAsTextFile().""" + " Click on a job to see information about the stages of tasks inside it." - UIUtils.headerSparkPage("Spark Jobs", content, parent, helpText = Some(helpText)) + UIUtils.headerSparkPage(request, "Spark Jobs", content, parent, helpText = Some(helpText)) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index 4658aa1cea3f1..f672ce0ec6a68 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -66,7 +66,7 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { ++
    - {poolTable.toNodeSeq} + {poolTable.toNodeSeq(request)}
    } else { Seq.empty[Node] @@ -74,7 +74,7 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { val content = summary ++ poolsDescription ++ tables.flatten.flatten - UIUtils.headerSparkPage("Stages for All Jobs", content, parent) + UIUtils.headerSparkPage(request, "Stages for All Jobs", content, parent) } private def summaryAndTableForStatus( diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 46f2a76cc651b..55444a2c0c9ab 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -195,7 +195,7 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP

    No information to display for job {jobId}

    return UIUtils.headerSparkPage( - s"Details for Job $jobId", content, parent) + request, s"Details for Job $jobId", content, parent) } val isComplete = jobData.status != JobExecutionStatus.RUNNING val stages = jobData.stageIds.map { stageId => @@ -413,6 +413,7 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP {failedStagesTable.toNodeSeq} } - UIUtils.headerSparkPage(s"Details for Job $jobId", content, parent, showVisualization = true) + UIUtils.headerSparkPage( + request, s"Details for Job $jobId", content, parent, showVisualization = true) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index a3e1f13782e30..22a40101e33df 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -49,7 +49,7 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { "stages/pool", parent.isFairScheduler, parent.killEnabled, false) val poolTable = new PoolTable(Map(pool -> uiPool), parent) - var content =

    Summary

    ++ poolTable.toNodeSeq + var content =

    Summary

    ++ poolTable.toNodeSeq(request) if (activeStages.nonEmpty) { content ++= } - UIUtils.headerSparkPage("Fair Scheduler Pool: " + poolName, content, parent) + UIUtils.headerSparkPage(request, "Fair Scheduler Pool: " + poolName, content, parent) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala index 5dfce858dec07..96b5f72393070 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala @@ -18,6 +18,7 @@ package org.apache.spark.ui.jobs import java.net.URLEncoder +import javax.servlet.http.HttpServletRequest import scala.xml.Node @@ -28,7 +29,7 @@ import org.apache.spark.ui.UIUtils /** Table showing list of pools */ private[ui] class PoolTable(pools: Map[Schedulable, PoolData], parent: StagesTab) { - def toNodeSeq: Seq[Node] = { + def toNodeSeq(request: HttpServletRequest): Seq[Node] = { @@ -39,15 +40,15 @@ private[ui] class PoolTable(pools: Map[Schedulable, PoolData], parent: StagesTab - {pools.map { case (s, p) => poolRow(s, p) }} + {pools.map { case (s, p) => poolRow(request, s, p) }}
    Pool NameSchedulingMode
    } - private def poolRow(s: Schedulable, p: PoolData): Seq[Node] = { + private def poolRow(request: HttpServletRequest, s: Schedulable, p: PoolData): Seq[Node] = { val activeStages = p.stageIds.size val href = "%s/stages/pool?poolname=%s" - .format(UIUtils.prependBaseUri(parent.basePath), URLEncoder.encode(p.name, "UTF-8")) + .format(UIUtils.prependBaseUri(request, parent.basePath), URLEncoder.encode(p.name, "UTF-8"))
    {p.name} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index ac83de10f9237..2575914121c39 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -112,7 +112,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We

    No information to display for Stage {stageId} (Attempt {stageAttemptId})

    - return UIUtils.headerSparkPage(stageHeader, content, parent) + return UIUtils.headerSparkPage(request, stageHeader, content, parent) } val localitySummary = store.localitySummary(stageData.stageId, stageData.attemptId) @@ -125,7 +125,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We

    Summary Metrics

    No tasks have started yet

    Tasks

    No tasks have started yet - return UIUtils.headerSparkPage(stageHeader, content, parent) + return UIUtils.headerSparkPage(request, stageHeader, content, parent) } val storedTasks = store.taskCount(stageData.stageId, stageData.attemptId) @@ -282,7 +282,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val (taskTable, taskTableHTML) = try { val _taskTable = new TaskPagedTable( stageData, - UIUtils.prependBaseUri(parent.basePath) + + UIUtils.prependBaseUri(request, parent.basePath) + s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", currentTime, pageSize = taskPageSize, @@ -498,7 +498,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
    {taskTableHTML ++ jsForScrollingDownToTaskTable}
    - UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true) + UIUtils.headerSparkPage(request, stageHeader, content, parent, showVisualization = true) } def makeTimeline(tasks: Seq[TaskData], currentTime: Long): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 18a4926f2f6c0..b8b20db1fa407 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -92,7 +92,8 @@ private[ui] class StageTableBase( stageSortColumn, stageSortDesc, isFailedStage, - parameterOtherTable + parameterOtherTable, + request ).table(page) } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => @@ -147,7 +148,8 @@ private[ui] class StagePagedTable( sortColumn: String, desc: Boolean, isFailedStage: Boolean, - parameterOtherTable: Iterable[String]) extends PagedTable[StageTableRowData] { + parameterOtherTable: Iterable[String], + request: HttpServletRequest) extends PagedTable[StageTableRowData] { override def tableId: String = stageTag + "-table" @@ -161,7 +163,7 @@ private[ui] class StagePagedTable( override def pageNumberFormField: String = stageTag + ".page" - val parameterPath = UIUtils.prependBaseUri(basePath) + s"/$subPath/?" + + val parameterPath = UIUtils.prependBaseUri(request, basePath) + s"/$subPath/?" + parameterOtherTable.mkString("&") override val dataSource = new StageDataSource( @@ -288,7 +290,7 @@ private[ui] class StagePagedTable( {if (isFairScheduler) { + .format(UIUtils.prependBaseUri(request, basePath), data.schedulingPool)}> {data.schedulingPool} @@ -346,7 +348,7 @@ private[ui] class StagePagedTable( } private def makeDescription(s: v1.StageData, descriptionOption: Option[String]): Seq[Node] = { - val basePathUri = UIUtils.prependBaseUri(basePath) + val basePathUri = UIUtils.prependBaseUri(request, basePath) val killLink = if (killEnabled) { val confirm = diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 2674b9291203a..238cd31433660 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -53,7 +53,7 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web } catch { case _: NoSuchElementException => // Rather than crashing, render an "RDD Not Found" page - return UIUtils.headerSparkPage("RDD Not Found", Seq.empty[Node], parent) + return UIUtils.headerSparkPage(request, "RDD Not Found", Seq.empty[Node], parent) } // Worker table @@ -72,7 +72,7 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web } val blockTableHTML = try { val _blockTable = new BlockPagedTable( - UIUtils.prependBaseUri(parent.basePath) + s"/storage/rdd/?id=${rddId}", + UIUtils.prependBaseUri(request, parent.basePath) + s"/storage/rdd/?id=${rddId}", rddStorageInfo.partitions.get, blockPageSize, blockSortColumn, @@ -145,7 +145,8 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web {blockTableHTML ++ jsForScrollingDownToBlockTable} ; - UIUtils.headerSparkPage("RDD Storage Info for " + rddStorageInfo.name, content, parent) + UIUtils.headerSparkPage( + request, "RDD Storage Info for " + rddStorageInfo.name, content, parent) } /** Header fields for the worker table */ diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 68d946574a37b..3eb546e336e99 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -31,11 +31,14 @@ import org.apache.spark.util.Utils private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { - val content = rddTable(store.rddList()) ++ receiverBlockTables(store.streamBlocksList()) - UIUtils.headerSparkPage("Storage", content, parent) + val content = rddTable(request, store.rddList()) ++ + receiverBlockTables(store.streamBlocksList()) + UIUtils.headerSparkPage(request, "Storage", content, parent) } - private[storage] def rddTable(rdds: Seq[v1.RDDStorageInfo]): Seq[Node] = { + private[storage] def rddTable( + request: HttpServletRequest, + rdds: Seq[v1.RDDStorageInfo]): Seq[Node] = { if (rdds.isEmpty) { // Don't show the rdd table if there is no RDD persisted. Nil @@ -49,7 +52,11 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends
    - {UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))} + {UIUtils.listingTable( + rddHeader, + rddRow(request, _: v1.RDDStorageInfo), + rdds, + id = Some("storage-by-rdd-table"))}
    } @@ -66,12 +73,13 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends "Size on Disk") /** Render an HTML row representing an RDD */ - private def rddRow(rdd: v1.RDDStorageInfo): Seq[Node] = { + private def rddRow(request: HttpServletRequest, rdd: v1.RDDStorageInfo): Seq[Node] = { // scalastyle:off {rdd.id} - + {rdd.name} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index a871b1c717837..11b29121739a4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -36,6 +36,7 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods import org.json4s.jackson.JsonMethods._ +import org.mockito.Mockito._ import org.openqa.selenium.WebDriver import org.openqa.selenium.htmlunit.HtmlUnitDriver import org.scalatest.{BeforeAndAfter, Matchers} @@ -281,6 +282,29 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers getContentAndCode("foobar")._1 should be (HttpServletResponse.SC_NOT_FOUND) } + test("automatically retrieve uiRoot from request through Knox") { + assert(sys.props.get("spark.ui.proxyBase").isEmpty, + "spark.ui.proxyBase is defined but it should not for this UT") + assert(sys.env.get("APPLICATION_WEB_PROXY_BASE").isEmpty, + "APPLICATION_WEB_PROXY_BASE is defined but it should not for this UT") + val page = new HistoryPage(server) + val requestThroughKnox = mock[HttpServletRequest] + val knoxBaseUrl = "/gateway/default/sparkhistoryui" + when(requestThroughKnox.getHeader("X-Forwarded-Context")).thenReturn(knoxBaseUrl) + val responseThroughKnox = page.render(requestThroughKnox) + + val urlsThroughKnox = responseThroughKnox \\ "@href" map (_.toString) + val siteRelativeLinksThroughKnox = urlsThroughKnox filter (_.startsWith("/")) + all (siteRelativeLinksThroughKnox) should startWith (knoxBaseUrl) + + val directRequest = mock[HttpServletRequest] + val directResponse = page.render(directRequest) + + val directUrls = directResponse \\ "@href" map (_.toString) + val directSiteRelativeLinks = directUrls filter (_.startsWith("/")) + all (directSiteRelativeLinks) should not startWith (knoxBaseUrl) + } + test("static relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val uiRoot = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("/testwebproxybase") val page = new HistoryPage(server) diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala index a71521c91d2f2..cdc7f541b9552 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui.storage +import javax.servlet.http.HttpServletRequest + import org.mockito.Mockito._ import org.apache.spark.SparkFunSuite @@ -29,6 +31,7 @@ class StoragePageSuite extends SparkFunSuite { val storageTab = mock(classOf[StorageTab]) when(storageTab.basePath).thenReturn("http://localhost:4040") val storagePage = new StoragePage(storageTab, null) + val request = mock(classOf[HttpServletRequest]) test("rddTable") { val rdd1 = new RDDStorageInfo(1, @@ -61,7 +64,7 @@ class StoragePageSuite extends SparkFunSuite { None, None) - val xmlNodes = storagePage.rddTable(Seq(rdd1, rdd2, rdd3)) + val xmlNodes = storagePage.rddTable(request, Seq(rdd1, rdd2, rdd3)) val headers = Seq( "ID", @@ -94,7 +97,7 @@ class StoragePageSuite extends SparkFunSuite { } test("empty rddTable") { - assert(storagePage.rddTable(Seq.empty).isEmpty) + assert(storagePage.rddTable(request, Seq.empty).isEmpty) } test("streamBlockStorageLevelDescriptionAndSize") { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala index 022191d0070fd..91f64141e5318 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -39,7 +39,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")

    Cannot find driver {driverId}

    - return UIUtils.basicSparkPage(content, s"Details for Job $driverId") + return UIUtils.basicSparkPage(request, content, s"Details for Job $driverId") } val driverState = state.get val driverHeaders = Seq("Driver property", "Value") @@ -68,7 +68,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") retryHeaders, retryRow, Iterable.apply(driverState.description.retryState)) val content =

    Driver state information for driver id {driverId}

    - Back to Drivers + Back to Drivers

    Driver state: {driverState.state}

    @@ -87,7 +87,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")
    ; - UIUtils.basicSparkPage(content, s"Details for Job $driverId") + UIUtils.basicSparkPage(request, content, s"Details for Job $driverId") } private def launchedRow(submissionState: Option[MesosClusterSubmissionState]): Seq[Node] = { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala index 88a6614d51384..c53285331ea68 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -62,7 +62,7 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( {retryTable} ; - UIUtils.basicSparkPage(content, "Spark Drivers for Mesos cluster") + UIUtils.basicSparkPage(request, content, "Spark Drivers for Mesos cluster") } private def queuedRow(submission: MesosDriverDescription): Seq[Node] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index 582528777f90e..bf46bc4cf904d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -58,21 +58,21 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L _content ++= new RunningExecutionTable( parent, s"Running Queries (${running.size})", currentTime, - running.sortBy(_.submissionTime).reverse).toNodeSeq + running.sortBy(_.submissionTime).reverse).toNodeSeq(request) } if (completed.nonEmpty) { _content ++= new CompletedExecutionTable( parent, s"Completed Queries (${completed.size})", currentTime, - completed.sortBy(_.submissionTime).reverse).toNodeSeq + completed.sortBy(_.submissionTime).reverse).toNodeSeq(request) } if (failed.nonEmpty) { _content ++= new FailedExecutionTable( parent, s"Failed Queries (${failed.size})", currentTime, - failed.sortBy(_.submissionTime).reverse).toNodeSeq + failed.sortBy(_.submissionTime).reverse).toNodeSeq(request) } _content } @@ -111,7 +111,7 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L } - UIUtils.headerSparkPage("SQL", summary ++ content, parent, Some(5000)) + UIUtils.headerSparkPage(request, "SQL", summary ++ content, parent, Some(5000)) } } @@ -133,7 +133,10 @@ private[ui] abstract class ExecutionTable( protected def header: Seq[String] - protected def row(currentTime: Long, executionUIData: SQLExecutionUIData): Seq[Node] = { + protected def row( + request: HttpServletRequest, + currentTime: Long, + executionUIData: SQLExecutionUIData): Seq[Node] = { val submissionTime = executionUIData.submissionTime val duration = executionUIData.completionTime.map(_.getTime()).getOrElse(currentTime) - submissionTime @@ -141,7 +144,7 @@ private[ui] abstract class ExecutionTable( def jobLinks(status: JobExecutionStatus): Seq[Node] = { executionUIData.jobs.flatMap { case (jobId, jobStatus) => if (jobStatus == status) { - [{jobId.toString}] + [{jobId.toString}] } else { None } @@ -153,7 +156,7 @@ private[ui] abstract class ExecutionTable( {executionUIData.executionId.toString} - {descriptionCell(executionUIData)} + {descriptionCell(request, executionUIData)} {UIUtils.formatDate(submissionTime)} @@ -179,7 +182,9 @@ private[ui] abstract class ExecutionTable( } - private def descriptionCell(execution: SQLExecutionUIData): Seq[Node] = { + private def descriptionCell( + request: HttpServletRequest, + execution: SQLExecutionUIData): Seq[Node] = { val details = if (execution.details != null && execution.details.nonEmpty) { +details @@ -192,27 +197,28 @@ private[ui] abstract class ExecutionTable( } val desc = if (execution.description != null && execution.description.nonEmpty) { - {execution.description} + {execution.description} } else { - {execution.executionId} + {execution.executionId} }
    {desc} {details}
    } - def toNodeSeq: Seq[Node] = { + def toNodeSeq(request: HttpServletRequest): Seq[Node] = {

    {tableName}

    {UIUtils.listingTable[SQLExecutionUIData]( - header, row(currentTime, _), executionUIDatas, id = Some(tableId))} + header, row(request, currentTime, _), executionUIDatas, id = Some(tableId))}
    } - private def jobURL(jobId: Long): String = - "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), jobId) + private def jobURL(request: HttpServletRequest, jobId: Long): String = + "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) - private def executionURL(executionID: Long): String = - s"${UIUtils.prependBaseUri(parent.basePath)}/${parent.prefix}/execution?id=$executionID" + private def executionURL(request: HttpServletRequest, executionID: Long): String = + s"${UIUtils.prependBaseUri( + request, parent.basePath)}/${parent.prefix}/execution?id=$executionID" } private[ui] class RunningExecutionTable( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index e0554f0c4d337..282f7b4bb5a58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -49,7 +49,7 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
  • {label} {jobs.toSeq.sorted.map { jobId => - {jobId.toString}  + {jobId.toString}  }}
  • } else { @@ -77,27 +77,31 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging val graph = sqlStore.planGraph(executionId) summary ++ - planVisualization(metrics, graph) ++ + planVisualization(request, metrics, graph) ++ physicalPlanDescription(executionUIData.physicalPlanDescription) }.getOrElse {
    No information to display for query {executionId}
    } - UIUtils.headerSparkPage(s"Details for Query $executionId", content, parent, Some(5000)) + UIUtils.headerSparkPage( + request, s"Details for Query $executionId", content, parent, Some(5000)) } - private def planVisualizationResources: Seq[Node] = { + private def planVisualizationResources(request: HttpServletRequest): Seq[Node] = { // scalastyle:off - - - - - + + + + + // scalastyle:on } - private def planVisualization(metrics: Map[Long, String], graph: SparkPlanGraph): Seq[Node] = { + private def planVisualization( + request: HttpServletRequest, + metrics: Map[Long, String], + graph: SparkPlanGraph): Seq[Node] = { val metadata = graph.allNodes.flatMap { node => val nodeId = s"plan-meta-data-${node.id}"
    {node.desc}
    @@ -112,13 +116,13 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
    {graph.allNodes.size.toString}
    {metadata} - {planVisualizationResources} + {planVisualizationResources(request)} } - private def jobURL(jobId: Long): String = - "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), jobId) + private def jobURL(request: HttpServletRequest, jobId: Long): String = + "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) private def physicalPlanDescription(physicalPlanDescription: String): Seq[Node] = {
    diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index f517bffccdf31..0950b30126773 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -47,10 +47,10 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" {listener.getOnlineSessionNum} session(s) are online, running {listener.getTotalRunning} SQL statement(s) ++ - generateSessionStatsTable() ++ - generateSQLStatsTable() + generateSessionStatsTable(request) ++ + generateSQLStatsTable(request) } - UIUtils.headerSparkPage("JDBC/ODBC Server", content, parent, Some(5000)) + UIUtils.headerSparkPage(request, "JDBC/ODBC Server", content, parent, Some(5000)) } /** Generate basic stats of the thrift server program */ @@ -67,7 +67,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" } /** Generate stats of batch statements of the thrift server program */ - private def generateSQLStatsTable(): Seq[Node] = { + private def generateSQLStatsTable(request: HttpServletRequest): Seq[Node] = { val numStatement = listener.getExecutionList.size val table = if (numStatement > 0) { val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", @@ -76,7 +76,8 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => - + [{id}] } @@ -138,7 +139,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" } /** Generate stats of batch sessions of the thrift server program */ - private def generateSessionStatsTable(): Seq[Node] = { + private def generateSessionStatsTable(request: HttpServletRequest): Seq[Node] = { val sessionList = listener.getSessionList val numBatches = sessionList.size val table = if (numBatches > 0) { @@ -146,8 +147,8 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", "Total Execute") def generateDataRow(session: SessionInfo): Seq[Node] = { - val sessionLink = "%s/%s/session?id=%s" - .format(UIUtils.prependBaseUri(parent.basePath), parent.prefix, session.sessionId) + val sessionLink = "%s/%s/session?id=%s".format( + UIUtils.prependBaseUri(request, parent.basePath), parent.prefix, session.sessionId) {session.userName} {session.ip} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 5cd2fdf6437c2..c884aa0ecbdf8 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -56,9 +56,9 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) Session created at {formatDate(sessionStat.startTimestamp)}, Total run {sessionStat.totalExecution} SQL ++ - generateSQLStatsTable(sessionStat.sessionId) + generateSQLStatsTable(request, sessionStat.sessionId) } - UIUtils.headerSparkPage("JDBC/ODBC Session", content, parent, Some(5000)) + UIUtils.headerSparkPage(request, "JDBC/ODBC Session", content, parent, Some(5000)) } /** Generate basic stats of the thrift server program */ @@ -75,7 +75,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) } /** Generate stats of batch statements of the thrift server program */ - private def generateSQLStatsTable(sessionID: String): Seq[Node] = { + private def generateSQLStatsTable(request: HttpServletRequest, sessionID: String): Seq[Node] = { val executionList = listener.getExecutionList .filter(_.sessionId == sessionID) val numStatement = executionList.size @@ -86,7 +86,8 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => - + [{id}] } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 6748dd4ec48e3..ca9da6139649a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -47,6 +47,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } private def generateJobRow( + request: HttpServletRequest, outputOpData: OutputOperationUIData, outputOpDescription: Seq[Node], formattedOutputOpDuration: String, @@ -54,7 +55,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { isFirstRow: Boolean, jobIdWithData: SparkJobIdWithUIData): Seq[Node] = { if (jobIdWithData.jobData.isDefined) { - generateNormalJobRow(outputOpData, outputOpDescription, formattedOutputOpDuration, + generateNormalJobRow(request, outputOpData, outputOpDescription, formattedOutputOpDuration, numSparkJobRowsInOutputOp, isFirstRow, jobIdWithData.jobData.get) } else { generateDroppedJobRow(outputOpData, outputOpDescription, formattedOutputOpDuration, @@ -89,6 +90,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { * one cell, we use "rowspan" for the first row of an output op. */ private def generateNormalJobRow( + request: HttpServletRequest, outputOpData: OutputOperationUIData, outputOpDescription: Seq[Node], formattedOutputOpDuration: String, @@ -106,7 +108,8 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { dropWhile(_.failureReason == None).take(1). // get the first info that contains failure flatMap(info => info.failureReason).headOption.getOrElse("") val formattedDuration = duration.map(d => SparkUIUtils.formatDuration(d)).getOrElse("-") - val detailUrl = s"${SparkUIUtils.prependBaseUri(parent.basePath)}/jobs/job?id=${sparkJob.jobId}" + val detailUrl = s"${SparkUIUtils.prependBaseUri( + request, parent.basePath)}/jobs/job?id=${sparkJob.jobId}" // In the first row, output op id and its information needs to be shown. In other rows, these // cells will be taken up due to "rowspan". @@ -196,6 +199,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } private def generateOutputOpIdRow( + request: HttpServletRequest, outputOpData: OutputOperationUIData, sparkJobs: Seq[SparkJobIdWithUIData]): Seq[Node] = { val formattedOutputOpDuration = @@ -212,6 +216,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } else { val firstRow = generateJobRow( + request, outputOpData, description, formattedOutputOpDuration, @@ -221,6 +226,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { val tailRows = sparkJobs.tail.map { sparkJob => generateJobRow( + request, outputOpData, description, formattedOutputOpDuration, @@ -278,7 +284,9 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { /** * Generate the job table for the batch. */ - private def generateJobTable(batchUIData: BatchUIData): Seq[Node] = { + private def generateJobTable( + request: HttpServletRequest, + batchUIData: BatchUIData): Seq[Node] = { val outputOpIdToSparkJobIds = batchUIData.outputOpIdSparkJobIdPairs.groupBy(_.outputOpId). map { case (outputOpId, outputOpIdAndSparkJobIds) => // sort SparkJobIds for each OutputOpId @@ -301,7 +309,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { { outputOpWithJobs.map { case (outputOpData, sparkJobs) => - generateOutputOpIdRow(outputOpData, sparkJobs) + generateOutputOpIdRow(request, outputOpData, sparkJobs) } } @@ -364,9 +372,10 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") {
    - val content = summary ++ generateJobTable(batchUIData) + val content = summary ++ generateJobTable(request, batchUIData) - SparkUIUtils.headerSparkPage(s"Details of batch at $formattedBatchTime", content, parent) + SparkUIUtils.headerSparkPage( + request, s"Details of batch at $formattedBatchTime", content, parent) } def generateInputMetadataTable(inputMetadatas: Seq[(Int, String)]): Seq[Node] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 3a176f64cdd60..4ce661bc1144e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -148,7 +148,7 @@ private[ui] class StreamingPage(parent: StreamingTab) /** Render the page */ def render(request: HttpServletRequest): Seq[Node] = { - val resources = generateLoadResources() + val resources = generateLoadResources(request) val basicInfo = generateBasicInfo() val content = resources ++ basicInfo ++ @@ -156,17 +156,17 @@ private[ui] class StreamingPage(parent: StreamingTab) generateStatTable() ++ generateBatchListTables() } - SparkUIUtils.headerSparkPage("Streaming Statistics", content, parent, Some(5000)) + SparkUIUtils.headerSparkPage(request, "Streaming Statistics", content, parent, Some(5000)) } /** * Generate html that will load css/js files for StreamingPage */ - private def generateLoadResources(): Seq[Node] = { + private def generateLoadResources(request: HttpServletRequest): Seq[Node] = { // scalastyle:off - - - + + + // scalastyle:on } From 952e4d1c830c4eb3dfd522be3d292dd02d8c9065 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Tue, 22 May 2018 19:12:30 +0800 Subject: [PATCH 0846/2461] [SPARK-24321][SQL] Extract common code from Divide/Remainder to a base trait ## What changes were proposed in this pull request? Extract common code from `Divide`/`Remainder` to a new base trait, `DivModLike`. Further refactoring to make `Pmod` work with `DivModLike` is to be done as a separate task. ## How was this patch tested? Existing tests in `ArithmeticExpressionSuite` covers the functionality. Author: Kris Mok Closes #21367 from rednaxelafx/catalyst-divmod. --- .../sql/catalyst/expressions/arithmetic.scala | 145 ++++++------------ 1 file changed, 51 insertions(+), 94 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index d4e322d23b95b..efd4e992c8eec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -220,30 +220,12 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) } -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "expr1 _FUNC_ expr2 - Returns `expr1`/`expr2`. It always performs floating point division.", - examples = """ - Examples: - > SELECT 3 _FUNC_ 2; - 1.5 - > SELECT 2L _FUNC_ 2L; - 1.0 - """) -// scalastyle:on line.size.limit -case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { - - override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) +// Common base trait for Divide and Remainder, since these two classes are almost identical +trait DivModLike extends BinaryArithmetic { - override def symbol: String = "/" - override def decimalMethod: String = "$div" override def nullable: Boolean = true - private lazy val div: (Any, Any) => Any = dataType match { - case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div - } - - override def eval(input: InternalRow): Any = { + final override def eval(input: InternalRow): Any = { val input2 = right.eval(input) if (input2 == null || input2 == 0) { null @@ -252,13 +234,15 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic if (input1 == null) { null } else { - div(input1, input2) + evalOperation(input1, input2) } } } + def evalOperation(left: Any, right: Any): Any + /** - * Special case handling due to division by 0 => null. + * Special case handling due to division/remainder by 0 => null. */ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval1 = left.genCode(ctx) @@ -269,7 +253,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic s"${eval2.value} == 0" } val javaType = CodeGenerator.javaType(dataType) - val divide = if (dataType.isInstanceOf[DecimalType]) { + val operation = if (dataType.isInstanceOf[DecimalType]) { s"${eval1.value}.$decimalMethod(${eval2.value})" } else { s"($javaType)(${eval1.value} $symbol ${eval2.value})" @@ -283,7 +267,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ${ev.isNull} = true; } else { ${eval1.code} - ${ev.value} = $divide; + ${ev.value} = $operation; }""") } else { ev.copy(code = s""" @@ -297,13 +281,38 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic if (${eval1.isNull}) { ${ev.isNull} = true; } else { - ${ev.value} = $divide; + ${ev.value} = $operation; } }""") } } } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "expr1 _FUNC_ expr2 - Returns `expr1`/`expr2`. It always performs floating point division.", + examples = """ + Examples: + > SELECT 3 _FUNC_ 2; + 1.5 + > SELECT 2L _FUNC_ 2L; + 1.0 + """) +// scalastyle:on line.size.limit +case class Divide(left: Expression, right: Expression) extends DivModLike { + + override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) + + override def symbol: String = "/" + override def decimalMethod: String = "$div" + + private lazy val div: (Any, Any) => Any = dataType match { + case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div + } + + override def evalOperation(left: Any, right: Any): Any = div(left, right) +} + @ExpressionDescription( usage = "expr1 _FUNC_ expr2 - Returns the remainder after `expr1`/`expr2`.", examples = """ @@ -313,82 +322,30 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic > SELECT MOD(2, 1.8); 0.2 """) -case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { +case class Remainder(left: Expression, right: Expression) extends DivModLike { override def inputType: AbstractDataType = NumericType override def symbol: String = "%" override def decimalMethod: String = "remainder" - override def nullable: Boolean = true - private lazy val integral = dataType match { - case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] - case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] + private lazy val mod: (Any, Any) => Any = dataType match { + // special cases to make float/double primitive types faster + case DoubleType => + (left, right) => left.asInstanceOf[Double] % right.asInstanceOf[Double] + case FloatType => + (left, right) => left.asInstanceOf[Float] % right.asInstanceOf[Float] + + // catch-all cases + case i: IntegralType => + val integral = i.integral.asInstanceOf[Integral[Any]] + (left, right) => integral.rem(left, right) + case i: FractionalType => // should only be DecimalType for now + val integral = i.asIntegral.asInstanceOf[Integral[Any]] + (left, right) => integral.rem(left, right) } - override def eval(input: InternalRow): Any = { - val input2 = right.eval(input) - if (input2 == null || input2 == 0) { - null - } else { - val input1 = left.eval(input) - if (input1 == null) { - null - } else { - input1 match { - case d: Double => d % input2.asInstanceOf[java.lang.Double] - case f: Float => f % input2.asInstanceOf[java.lang.Float] - case _ => integral.rem(input1, input2) - } - } - } - } - - /** - * Special case handling for x % 0 ==> null. - */ - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val eval1 = left.genCode(ctx) - val eval2 = right.genCode(ctx) - val isZero = if (dataType.isInstanceOf[DecimalType]) { - s"${eval2.value}.isZero()" - } else { - s"${eval2.value} == 0" - } - val javaType = CodeGenerator.javaType(dataType) - val remainder = if (dataType.isInstanceOf[DecimalType]) { - s"${eval1.value}.$decimalMethod(${eval2.value})" - } else { - s"($javaType)(${eval1.value} $symbol ${eval2.value})" - } - if (!left.nullable && !right.nullable) { - ev.copy(code = s""" - ${eval2.code} - boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - if ($isZero) { - ${ev.isNull} = true; - } else { - ${eval1.code} - ${ev.value} = $remainder; - }""") - } else { - ev.copy(code = s""" - ${eval2.code} - boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - if (${eval2.isNull} || $isZero) { - ${ev.isNull} = true; - } else { - ${eval1.code} - if (${eval1.isNull}) { - ${ev.isNull} = true; - } else { - ${ev.value} = $remainder; - } - }""") - } - } + override def evalOperation(left: Any, right: Any): Any = mod(left, right) } @ExpressionDescription( From 82fb5bfa770b0325d4f377dd38d89869007c6111 Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Tue, 22 May 2018 21:02:17 +0800 Subject: [PATCH 0847/2461] [SPARK-20087][CORE] Attach accumulators / metrics to 'TaskKilled' end reason ## What changes were proposed in this pull request? The ultimate goal is for listeners to onTaskEnd to receive metrics when a task is killed intentionally, since the data is currently just thrown away. This is already done for ExceptionFailure, so this just copies the same approach. ## How was this patch tested? Updated existing tests. This is a rework of https://github.com/apache/spark/pull/17422, all credits should go to noodle-fb Author: Xianjin YE Author: Charles Lewis Closes #21165 from advancedxy/SPARK-20087. --- .../org/apache/spark/TaskEndReason.scala | 8 ++- .../org/apache/spark/executor/Executor.scala | 55 ++++++++++++------- .../apache/spark/scheduler/DAGScheduler.scala | 6 +- .../spark/scheduler/TaskSetManager.scala | 8 ++- .../org/apache/spark/util/JsonProtocol.scala | 9 ++- .../spark/scheduler/DAGSchedulerSuite.scala | 18 ++++-- project/MimaExcludes.scala | 5 ++ 7 files changed, 78 insertions(+), 31 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index a76283e33fa65..33901bc8380e9 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -212,9 +212,15 @@ case object TaskResultLost extends TaskFailedReason { * Task was killed intentionally and needs to be rescheduled. */ @DeveloperApi -case class TaskKilled(reason: String) extends TaskFailedReason { +case class TaskKilled( + reason: String, + accumUpdates: Seq[AccumulableInfo] = Seq.empty, + private[spark] val accums: Seq[AccumulatorV2[_, _]] = Nil) + extends TaskFailedReason { + override def toErrorString: String = s"TaskKilled ($reason)" override def countTowardsTaskFailures: Boolean = false + } /** diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index c325222b764b8..b1856ff0f3247 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -287,6 +287,28 @@ private[spark] class Executor( notifyAll() } + /** + * Utility function to: + * 1. Report executor runtime and JVM gc time if possible + * 2. Collect accumulator updates + * 3. Set the finished flag to true and clear current thread's interrupt status + */ + private def collectAccumulatorsAndResetStatusOnFailure(taskStartTime: Long) = { + // Report executor runtime and JVM gc time + Option(task).foreach(t => { + t.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStartTime) + t.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) + }) + + // Collect latest accumulator values to report back to the driver + val accums: Seq[AccumulatorV2[_, _]] = + Option(task).map(_.collectAccumulatorUpdates(taskFailed = true)).getOrElse(Seq.empty) + val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None)) + + setTaskFinishedAndClearInterruptStatus() + (accums, accUpdates) + } + override def run(): Unit = { threadId = Thread.currentThread.getId Thread.currentThread.setName(threadName) @@ -300,7 +322,7 @@ private[spark] class Executor( val ser = env.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) - var taskStart: Long = 0 + var taskStartTime: Long = 0 var taskStartCpu: Long = 0 startGCTime = computeTotalGcTime() @@ -336,7 +358,7 @@ private[spark] class Executor( } // Run the actual task and measure its runtime. - taskStart = System.currentTimeMillis() + taskStartTime = System.currentTimeMillis() taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) { threadMXBean.getCurrentThreadCpuTime } else 0L @@ -396,11 +418,11 @@ private[spark] class Executor( // Deserialization happens in two parts: first, we deserialize a Task object, which // includes the Partition. Second, Task.run() deserializes the RDD and function to be run. task.metrics.setExecutorDeserializeTime( - (taskStart - deserializeStartTime) + task.executorDeserializeTime) + (taskStartTime - deserializeStartTime) + task.executorDeserializeTime) task.metrics.setExecutorDeserializeCpuTime( (taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime) // We need to subtract Task.run()'s deserialization time to avoid double-counting - task.metrics.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime) + task.metrics.setExecutorRunTime((taskFinish - taskStartTime) - task.executorDeserializeTime) task.metrics.setExecutorCpuTime( (taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime) task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) @@ -482,16 +504,19 @@ private[spark] class Executor( } catch { case t: TaskKilledException => logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}") - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) + + val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime) + val serializedTK = ser.serialize(TaskKilled(t.reason, accUpdates, accums)) + execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK) case _: InterruptedException | NonFatal(_) if task != null && task.reasonIfKilled.isDefined => val killReason = task.reasonIfKilled.getOrElse("unknown reason") logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate( - taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason))) + + val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime) + val serializedTK = ser.serialize(TaskKilled(killReason, accUpdates, accums)) + execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK) case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) => val reason = task.context.fetchFailed.get.toTaskFailedReason @@ -524,17 +549,7 @@ private[spark] class Executor( // the task failure would not be ignored if the shutdown happened because of premption, // instead of an app issue). if (!ShutdownHookManager.inShutdown()) { - // Collect latest accumulator values to report back to the driver - val accums: Seq[AccumulatorV2[_, _]] = - if (task != null) { - task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart) - task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) - task.collectAccumulatorUpdates(taskFailed = true) - } else { - Seq.empty - } - - val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None)) + val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime) val serializedTaskEndReason = { try { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5f2d16d03165f..ea7bfd7d7a68d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1210,7 +1210,7 @@ class DAGScheduler( case _ => updateAccumulators(event) } - case _: ExceptionFailure => updateAccumulators(event) + case _: ExceptionFailure | _: TaskKilled => updateAccumulators(event) case _ => } postTaskEnd(event) @@ -1414,13 +1414,13 @@ class DAGScheduler( case commitDenied: TaskCommitDenied => // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits - case exceptionFailure: ExceptionFailure => + case _: ExceptionFailure | _: TaskKilled => // Nothing left to do, already handled above for accumulator updates. case TaskResultLost => // Do nothing here; the TaskScheduler handles these failures and resubmits the task. - case _: ExecutorLostFailure | _: TaskKilled | UnknownReason => + case _: ExecutorLostFailure | UnknownReason => // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler // will abort the job. } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 195fc8025e4b5..a18c66596852a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -851,13 +851,19 @@ private[spark] class TaskSetManager( } ef.exception + case tk: TaskKilled => + // TaskKilled might have accumulator updates + accumUpdates = tk.accums + logWarning(failureReason) + None + case e: ExecutorLostFailure if !e.exitCausedByApp => logInfo(s"Task $tid failed because while it was being computed, its executor " + "exited for a reason unrelated to the task. Not counting this failure towards the " + "maximum number of failures for the task.") None - case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others + case e: TaskFailedReason => // TaskResultLost and others logWarning(failureReason) None } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 40383fe05026b..50c6461373dee 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -407,7 +407,9 @@ private[spark] object JsonProtocol { ("Exit Caused By App" -> exitCausedByApp) ~ ("Loss Reason" -> reason.map(_.toString)) case taskKilled: TaskKilled => - ("Kill Reason" -> taskKilled.reason) + val accumUpdates = JArray(taskKilled.accumUpdates.map(accumulableInfoToJson).toList) + ("Kill Reason" -> taskKilled.reason) ~ + ("Accumulator Updates" -> accumUpdates) case _ => emptyJson } ("Reason" -> reason) ~ json @@ -917,7 +919,10 @@ private[spark] object JsonProtocol { case `taskKilled` => val killReason = jsonOption(json \ "Kill Reason") .map(_.extract[String]).getOrElse("unknown reason") - TaskKilled(killReason) + val accumUpdates = jsonOption(json \ "Accumulator Updates") + .map(_.extract[List[JValue]].map(accumulableInfoFromJson)) + .getOrElse(Seq[AccumulableInfo]()) + TaskKilled(killReason, accumUpdates) case `taskCommitDenied` => // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON // de/serialization logic was not added until 1.5.1. To provide backward compatibility diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 8b6ec37625eec..2987170bf5026 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1852,7 +1852,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assertDataStructuresEmpty() } - test("accumulators are updated on exception failures") { + test("accumulators are updated on exception failures and task killed") { val acc1 = AccumulatorSuite.createLongAccum("ingenieur") val acc2 = AccumulatorSuite.createLongAccum("boulanger") val acc3 = AccumulatorSuite.createLongAccum("agriculteur") @@ -1868,15 +1868,24 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi val accUpdate3 = new LongAccumulator accUpdate3.metadata = acc3.metadata accUpdate3.setValue(18) - val accumUpdates = Seq(accUpdate1, accUpdate2, accUpdate3) - val accumInfo = accumUpdates.map(AccumulatorSuite.makeInfo) + + val accumUpdates1 = Seq(accUpdate1, accUpdate2) + val accumInfo1 = accumUpdates1.map(AccumulatorSuite.makeInfo) val exceptionFailure = new ExceptionFailure( new SparkException("fondue?"), - accumInfo).copy(accums = accumUpdates) + accumInfo1).copy(accums = accumUpdates1) submit(new MyRDD(sc, 1, Nil), Array(0)) runEvent(makeCompletionEvent(taskSets.head.tasks.head, exceptionFailure, "result")) + assert(AccumulatorContext.get(acc1.id).get.value === 15L) assert(AccumulatorContext.get(acc2.id).get.value === 13L) + + val accumUpdates2 = Seq(accUpdate3) + val accumInfo2 = accumUpdates2.map(AccumulatorSuite.makeInfo) + + val taskKilled = new TaskKilled( "test", accumInfo2, accums = accumUpdates2) + runEvent(makeCompletionEvent(taskSets.head.tasks.head, taskKilled, "result")) + assert(AccumulatorContext.get(acc3.id).get.value === 18L) } @@ -2497,6 +2506,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi val accumUpdates = reason match { case Success => task.metrics.accumulators() case ef: ExceptionFailure => ef.accums + case tk: TaskKilled => tk.accums case _ => Seq.empty } CompletionEvent(task, reason, result, accumUpdates ++ extraAccumUpdates, taskInfo) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 6bae4d147d4ac..4f6d5ff898681 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,11 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-20087][CORE] Attach accumulators / metrics to 'TaskKilled' end reason + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.this"), + // [SPARK-22941][core] Do not exit JVM when submit fails with in-process launcher. ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printWarning"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.parseSparkConfProperty"), From a4470bc78ca5f5a090b6831a7cdca88274eb9afc Mon Sep 17 00:00:00 2001 From: Jake Charland Date: Tue, 22 May 2018 08:06:15 -0500 Subject: [PATCH 0848/2461] [SPARK-21673] Use the correct sandbox environment variable set by Mesos ## What changes were proposed in this pull request? This change changes spark behavior to use the correct environment variable set by Mesos in the container on startup. Author: Jake Charland Closes #18894 from jakecharland/MesosSandbox. --- core/src/main/scala/org/apache/spark/util/Utils.scala | 8 ++++---- docs/configuration.md | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 13adaa921dc23..f9191a59c1655 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -810,15 +810,15 @@ private[spark] object Utils extends Logging { conf.getenv("SPARK_EXECUTOR_DIRS").split(File.pathSeparator) } else if (conf.getenv("SPARK_LOCAL_DIRS") != null) { conf.getenv("SPARK_LOCAL_DIRS").split(",") - } else if (conf.getenv("MESOS_DIRECTORY") != null && !shuffleServiceEnabled) { + } else if (conf.getenv("MESOS_SANDBOX") != null && !shuffleServiceEnabled) { // Mesos already creates a directory per Mesos task. Spark should use that directory // instead so all temporary files are automatically cleaned up when the Mesos task ends. // Note that we don't want this if the shuffle service is enabled because we want to // continue to serve shuffle files after the executors that wrote them have already exited. - Array(conf.getenv("MESOS_DIRECTORY")) + Array(conf.getenv("MESOS_SANDBOX")) } else { - if (conf.getenv("MESOS_DIRECTORY") != null && shuffleServiceEnabled) { - logInfo("MESOS_DIRECTORY available but not using provided Mesos sandbox because " + + if (conf.getenv("MESOS_SANDBOX") != null && shuffleServiceEnabled) { + logInfo("MESOS_SANDBOX available but not using provided Mesos sandbox because " + "spark.shuffle.service.enabled is enabled.") } // In non-Yarn mode (or for the driver in yarn-client mode), we cannot trust the user diff --git a/docs/configuration.md b/docs/configuration.md index 8a1aacef85760..fd2670cba2125 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -208,7 +208,7 @@ of the most common options to set are: stored on disk. This should be on a fast, local disk in your system. It can also be a comma-separated list of multiple directories on different disks. - NOTE: In Spark 1.0 and later this will be overridden by SPARK_LOCAL_DIRS (Standalone, Mesos) or + NOTE: In Spark 1.0 and later this will be overridden by SPARK_LOCAL_DIRS (Standalone), MESOS_SANDBOX (Mesos) or LOCAL_DIRS (YARN) environment variables set by the cluster manager. From d3d18073152cab4408464d1417ec644d939cfdf7 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 22 May 2018 21:08:49 +0800 Subject: [PATCH 0849/2461] [SPARK-24313][SQL] Fix collection operations' interpreted evaluation for complex types ## What changes were proposed in this pull request? The interpreted evaluation of several collection operations works only for simple datatypes. For complex data types, for instance, `array_contains` it returns always `false`. The list of the affected functions is `array_contains`, `array_position`, `element_at` and `GetMapValue`. The PR fixes the behavior for all the datatypes. ## How was this patch tested? added UT Author: Marco Gaido Closes #21361 from mgaido91/SPARK-24313. --- .../expressions/collectionOperations.scala | 41 ++++++++++++---- .../expressions/complexTypeExtractors.scala | 19 +++++-- .../CollectionExpressionsSuite.scala | 49 ++++++++++++++++++- .../optimizer/complexTypesSuite.scala | 13 +++++ .../org/apache/spark/sql/DataFrameSuite.scala | 5 ++ 5 files changed, 113 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8d763dca5243e..7da4c3cc6b9fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -657,6 +657,9 @@ case class ArrayContains(left: Expression, right: Expression) override def dataType: DataType = BooleanType + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(right.dataType) + override def inputTypes: Seq[AbstractDataType] = right.dataType match { case NullType => Seq.empty case _ => left.dataType match { @@ -673,7 +676,7 @@ case class ArrayContains(left: Expression, right: Expression) TypeCheckResult.TypeCheckFailure( "Arguments must be an array followed by a value of same type as the array members") } else { - TypeCheckResult.TypeCheckSuccess + TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") } } @@ -686,7 +689,7 @@ case class ArrayContains(left: Expression, right: Expression) arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => if (v == null) { hasNull = true - } else if (v == value) { + } else if (ordering.equiv(v, value)) { return true } ) @@ -735,11 +738,7 @@ case class ArraysOverlap(left: Expression, right: Expression) override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match { case TypeCheckResult.TypeCheckSuccess => - if (RowOrdering.isOrderable(elementType)) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure(s"${elementType.simpleString} cannot be used in comparison.") - } + TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName") case failure => failure } @@ -1391,13 +1390,24 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast case class ArrayPosition(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(right.dataType) + override def dataType: DataType = LongType override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") + } + } + override def nullSafeEval(arr: Any, value: Any): Any = { arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => - if (v == value) { + if (v != null && ordering.equiv(v, value)) { return (i + 1).toLong } ) @@ -1446,6 +1456,9 @@ case class ArrayPosition(left: Expression, right: Expression) since = "2.4.0") case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType) + override def dataType: DataType = left.dataType match { case ArrayType(elementType, _) => elementType case MapType(_, valueType, _) => valueType @@ -1460,6 +1473,16 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti ) } + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] => + TypeUtils.checkForOrderingExpr( + left.dataType.asInstanceOf[MapType].keyType, s"function $prettyName") + case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess + } + } + override def nullable: Boolean = true override def nullSafeEval(value: Any, ordinal: Any): Any = { @@ -1484,7 +1507,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti } } case _: MapType => - getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType) + getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType, ordering) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 3fba52d745453..99671d5b863c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -273,7 +273,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes { // todo: current search is O(n), improve it. - def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = { + def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = { val map = value.asInstanceOf[MapData] val length = map.numElements() val keys = map.keyArray() @@ -282,7 +282,7 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy var i = 0 var found = false while (i < length && !found) { - if (keys.get(i, keyType) == ordinal) { + if (ordering.equiv(keys.get(i, keyType), ordinal)) { found = true } else { i += 1 @@ -345,8 +345,19 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy case class GetMapValue(child: Expression, key: Expression) extends GetMapValueUtil with ExtractValue with NullIntolerant { + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(keyType) + private def keyType = child.dataType.asInstanceOf[MapType].keyType + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(keyType, s"function $prettyName") + } + } + // We have done type checking for child in `ExtractValue`, so only need to check the `key`. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) @@ -363,7 +374,7 @@ case class GetMapValue(child: Expression, key: Expression) // todo: current search is O(n), improve it. override def nullSafeEval(value: Any, ordinal: Any): Any = { - getValueEval(value, ordinal, keyType) + getValueEval(value, ordinal, keyType, ordering) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 71ff96bb722e2..3fc0b08c56e02 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -157,6 +157,33 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) + + // binary + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), + ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), + ArrayType(BinaryType)) + val be = Literal.create(Array[Byte](1, 2), BinaryType) + val nullBinary = Literal.create(null, BinaryType) + + checkEvaluation(ArrayContains(b0, be), true) + checkEvaluation(ArrayContains(b1, be), false) + checkEvaluation(ArrayContains(b0, nullBinary), null) + checkEvaluation(ArrayContains(b2, be), null) + checkEvaluation(ArrayContains(b3, be), true) + + // complex data types + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) + checkEvaluation(ArrayContains(aa0, aae), true) + checkEvaluation(ArrayContains(aa1, aae), false) } test("ArraysOverlap") { @@ -372,6 +399,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayPosition(a3, Literal("")), null) checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null) + + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) + checkEvaluation(ArrayPosition(aa0, aae), 1L) + checkEvaluation(ArrayPosition(aa1, aae), 0L) } test("elementAt") { @@ -409,7 +444,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) val m2 = Literal.create(null, MapType(StringType, StringType)) - checkEvaluation(ElementAt(m0, Literal(1.0)), null) + assert(ElementAt(m0, Literal(1.0)).checkInputDataTypes().isFailure) checkEvaluation(ElementAt(m0, Literal("d")), null) @@ -420,6 +455,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ElementAt(m0, Literal("c")), null) checkEvaluation(ElementAt(m2, Literal("a")), null) + + // test binary type as keys + val mb0 = Literal.create( + Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"), + MapType(BinaryType, StringType)) + val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType)) + + checkEvaluation(ElementAt(mb0, Literal(Array[Byte](1, 2, 3))), null) + + checkEvaluation(ElementAt(mb1, Literal(Array[Byte](1, 2))), null) + checkEvaluation(ElementAt(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2") + checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null) } test("Concat") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 633d86d495581..5452e72b38647 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -439,4 +439,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select('c as 'sCol2, 'a as 'sCol1) checkRule(originalQuery, correctAnswer) } + + test("SPARK-24313: support binary type as map keys in GetMapValue") { + val mb0 = Literal.create( + Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"), + MapType(BinaryType, StringType)) + val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType)) + + checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](1, 2, 3))), null) + + checkEvaluation(GetMapValue(mb1, Literal(Array[Byte](1, 2))), null) + checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2") + checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 60e84e6ee7504..1cc8cb3874c9b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2265,4 +2265,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val df = spark.range(1).select($"id", new Column(Uuid())) checkAnswer(df, df.collect()) } + + test("SPARK-24313: access map with binary keys") { + val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1)) + checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1)) + } } From fc743f7b30902bad1da36131087bb922c17a048e Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 22 May 2018 08:20:59 -0500 Subject: [PATCH 0850/2461] [SPARK-20120][SQL][FOLLOW-UP] Better way to support spark-sql silent mode. ## What changes were proposed in this pull request? `spark-sql` silent mode will broken if`SPARK_HOME/jars` missing `kubernetes-model-2.0.0.jar`. This pr use `sc.setLogLevel ()` to implement silent mode. ## How was this patch tested? manual tests ``` build/sbt -Phive -Phive-thriftserver package export SPARK_PREPEND_CLASSES=true ./bin/spark-sql -S ``` Author: Yuming Wang Closes #20274 from wangyum/SPARK-20120-FOLLOW-UP. --- .../spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 084f8200102ba..d9fd3ebd3c65d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -35,7 +35,7 @@ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.apache.log4j.{Level, Logger} +import org.apache.log4j.Level import org.apache.thrift.transport.TSocket import org.apache.spark.SparkConf @@ -300,10 +300,6 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { private val console = new SessionState.LogHelper(LOG) - if (sessionState.getIsSilent) { - Logger.getRootLogger.setLevel(Level.WARN) - } - private val isRemoteMode = { SparkSQLCLIDriver.isRemoteMode(sessionState) } @@ -315,6 +311,9 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { // because the Hive unit tests do not go through the main() code path. if (!isRemoteMode) { SparkSQLEnv.init() + if (sessionState.getIsSilent) { + SparkSQLEnv.sparkContext.setLogLevel(Level.WARN.toString) + } } else { // Hive 1.2 + not supported in CLI throw new RuntimeException("Remote operations not supported") From 8086acc2f676a04ce6255a621ffae871bd09ceea Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 22 May 2018 22:07:32 +0800 Subject: [PATCH 0851/2461] [SPARK-24244][SQL] Passing only required columns to the CSV parser ## What changes were proposed in this pull request? uniVocity parser allows to specify only required column names or indexes for [parsing](https://www.univocity.com/pages/parsers-tutorial) like: ``` // Here we select only the columns by their indexes. // The parser just skips the values in other columns parserSettings.selectIndexes(4, 0, 1); CsvParser parser = new CsvParser(parserSettings); ``` In this PR, I propose to extract indexes from required schema and pass them into the CSV parser. Benchmarks show the following improvements in parsing of 1000 columns: ``` Select 100 columns out of 1000: x1.76 Select 1 column out of 1000: x2 ``` **Note**: Comparing to current implementation, the changes can return different result for malformed rows in the `DROPMALFORMED` and `FAILFAST` modes if only subset of all columns is requested. To have previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. ## How was this patch tested? It was tested by new test which selects 3 columns out of 15, by existing tests and by new benchmarks. Author: Maxim Gekk Closes #21296 from MaxGekk/csv-column-pruning. --- docs/sql-programming-guide.md | 1 + .../apache/spark/sql/internal/SQLConf.scala | 7 +++ .../datasources/csv/CSVOptions.scala | 3 ++ .../datasources/csv/UnivocityParser.scala | 26 ++++++----- .../datasources/csv/CSVBenchmarks.scala | 42 ++++++++++++++++++ .../execution/datasources/csv/CSVSuite.scala | 43 ++++++++++++++++--- 6 files changed, 104 insertions(+), 18 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index f1ed316341b95..fc26562ff33da 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1825,6 +1825,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. + - In version 2.3 and earlier, CSV rows are considered as malformed if at least one column value in the row is malformed. CSV parser dropped such rows in the DROPMALFORMED mode or outputs an error in the FAILFAST mode. Since Spark 2.4, CSV row is considered as malformed only when it contains malformed column values requested from CSV datasource, other values can be ignored. As an example, CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selection of the id column consists of a row with one column value 1234 but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a2fb3c64844b5..d0478d6ad250b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1295,6 +1295,13 @@ object SQLConf { object Replaced { val MAPREDUCE_JOB_REDUCES = "mapreduce.job.reduces" } + + val CSV_PARSER_COLUMN_PRUNING = buildConf("spark.sql.csv.parser.columnPruning.enabled") + .internal() + .doc("If it is set to true, column names of the requested schema are passed to CSV parser. " + + "Other column values can be ignored during parsing even if they are malformed.") + .booleanConf + .createWithDefault(true) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 1066d156acd74..dd41aee0f2ebc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -25,6 +25,7 @@ import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf class CSVOptions( @transient val parameters: CaseInsensitiveMap[String], @@ -80,6 +81,8 @@ class CSVOptions( } } + private[csv] val columnPruning = SQLConf.get.getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING) + val delimiter = CSVUtils.toChar( parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) val parseMode: ParseMode = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 99557a1ceb0c8..4f00cc5eb3f39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -34,10 +34,10 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class UnivocityParser( - schema: StructType, + dataSchema: StructType, requiredSchema: StructType, val options: CSVOptions) extends Logging { - require(requiredSchema.toSet.subsetOf(schema.toSet), + require(requiredSchema.toSet.subsetOf(dataSchema.toSet), "requiredSchema should be the subset of schema.") def this(schema: StructType, options: CSVOptions) = this(schema, schema, options) @@ -45,9 +45,17 @@ class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any - private val tokenizer = new CsvParser(options.asParserSettings) + private val tokenizer = { + val parserSetting = options.asParserSettings + if (options.columnPruning && requiredSchema.length < dataSchema.length) { + val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))) + parserSetting.selectIndexes(tokenIndexArr: _*) + } + new CsvParser(parserSetting) + } + private val schema = if (options.columnPruning) requiredSchema else dataSchema - private val row = new GenericInternalRow(requiredSchema.length) + private val row = new GenericInternalRow(schema.length) // Retrieve the raw record string. private def getCurrentInput: UTF8String = { @@ -73,11 +81,8 @@ class UnivocityParser( // Each input token is placed in each output row's position by mapping these. In this case, // // output row - ["A", 2] - private val valueConverters: Array[ValueConverter] = + private val valueConverters: Array[ValueConverter] = { schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray - - private val tokenIndexArr: Array[Int] = { - requiredSchema.map(f => schema.indexOf(f)).toArray } /** @@ -210,9 +215,8 @@ class UnivocityParser( } else { try { var i = 0 - while (i < requiredSchema.length) { - val from = tokenIndexArr(i) - row(i) = valueConverters(from).apply(tokens(from)) + while (i < schema.length) { + row(i) = valueConverters(i).apply(tokens(i)) i += 1 } row diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala index d442ba7e59c61..ec788df00aa92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala @@ -74,7 +74,49 @@ object CSVBenchmarks { } } + def multiColumnsBenchmark(rowsNum: Int): Unit = { + val colsNum = 1000 + val benchmark = new Benchmark(s"Wide rows with $colsNum columns", rowsNum) + + withTempPath { path => + val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) + val schema = StructType(fields) + val values = (0 until colsNum).map(i => i.toString).mkString(",") + val columnNames = schema.fieldNames + + spark.range(rowsNum) + .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*) + .write.option("header", true) + .csv(path.getAbsolutePath) + + val ds = spark.read.schema(schema).csv(path.getAbsolutePath) + + benchmark.addCase(s"Select $colsNum columns", 3) { _ => + ds.select("*").filter((row: Row) => true).count() + } + val cols100 = columnNames.take(100).map(Column(_)) + benchmark.addCase(s"Select 100 columns", 3) { _ => + ds.select(cols100: _*).filter((row: Row) => true).count() + } + benchmark.addCase(s"Select one column", 3) { _ => + ds.select($"col1").filter((row: Row) => true).count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + Wide rows with 1000 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + Select 1000 columns 76910 / 78065 0.0 76909.8 1.0X + Select 100 columns 28625 / 32884 0.0 28625.1 2.7X + Select one column 22498 / 22669 0.0 22497.8 3.4X + */ + benchmark.run() + } + } + def main(args: Array[String]): Unit = { quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3) + multiColumnsBenchmark(rowsNum = 1000 * 1000) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 07e6c74b14d0d..5f9f799a6c466 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -260,14 +260,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } test("test for DROPMALFORMED parsing mode") { - Seq(false, true).foreach { multiLine => - val cars = spark.read - .format("csv") - .option("multiLine", multiLine) - .options(Map("header" -> "true", "mode" -> "dropmalformed")) - .load(testFile(carsFile)) + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "false") { + Seq(false, true).foreach { multiLine => + val cars = spark.read + .format("csv") + .option("multiLine", multiLine) + .options(Map("header" -> "true", "mode" -> "dropmalformed")) + .load(testFile(carsFile)) - assert(cars.select("year").collect().size === 2) + assert(cars.select("year").collect().size === 2) + } } } @@ -1368,4 +1370,31 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te checkAnswer(computed, expected) } } + + test("SPARK-24244: Select a subset of all columns") { + withTempPath { path => + import collection.JavaConverters._ + val schema = new StructType() + .add("f1", IntegerType).add("f2", IntegerType).add("f3", IntegerType) + .add("f4", IntegerType).add("f5", IntegerType).add("f6", IntegerType) + .add("f7", IntegerType).add("f8", IntegerType).add("f9", IntegerType) + .add("f10", IntegerType).add("f11", IntegerType).add("f12", IntegerType) + .add("f13", IntegerType).add("f14", IntegerType).add("f15", IntegerType) + + val odf = spark.createDataFrame(List( + Row(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), + Row(-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15) + ).asJava, schema) + odf.write.csv(path.getCanonicalPath) + val idf = spark.read + .schema(schema) + .csv(path.getCanonicalPath) + .select('f15, 'f10, 'f5) + + checkAnswer( + idf, + List(Row(15, 10, 5), Row(-15, -10, -5)) + ) + } + } } From f9f055afa47412eec8228c843b34a90decb9be43 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 23 May 2018 01:50:22 +0800 Subject: [PATCH 0852/2461] [SPARK-24121][SQL] Add API for handling expression code generation ## What changes were proposed in this pull request? This patch tries to implement this [proposal](https://github.com/apache/spark/pull/19813#issuecomment-354045400) to add an API for handling expression code generation. It should allow us to manipulate how to generate codes for expressions. In details, this adds an new abstraction `CodeBlock` to `JavaCode`. `CodeBlock` holds the code snippet and inputs for generating actual java code. For example, in following java code: ```java int ${variable} = 1; boolean ${isNull} = ${CodeGenerator.defaultValue(BooleanType)}; ``` `variable`, `isNull` are two `VariableValue` and `CodeGenerator.defaultValue(BooleanType)` is a string. They are all inputs to this code block and held by `CodeBlock` representing this code. For codegen, we provide a specified string interpolator `code`, so you can define a code like this: ```scala val codeBlock = code""" |int ${variable} = 1; |boolean ${isNull} = ${CodeGenerator.defaultValue(BooleanType)}; """.stripMargin // Generates actual java code. codeBlock.toString ``` Because those inputs are held separately in `CodeBlock` before generating code, we can safely manipulate them, e.g., replacing statements to aliased variables, etc.. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh Closes #21193 from viirya/SPARK-24121. --- .../catalyst/expressions/BoundAttribute.scala | 5 +- .../spark/sql/catalyst/expressions/Cast.scala | 10 +- .../sql/catalyst/expressions/Expression.scala | 26 ++-- .../MonotonicallyIncreasingID.scala | 3 +- .../sql/catalyst/expressions/ScalaUDF.scala | 3 +- .../sql/catalyst/expressions/SortOrder.scala | 3 +- .../expressions/SparkPartitionID.scala | 3 +- .../sql/catalyst/expressions/TimeWindow.scala | 3 +- .../sql/catalyst/expressions/arithmetic.scala | 13 +- .../expressions/codegen/CodeGenerator.scala | 25 +-- .../expressions/codegen/CodegenFallback.scala | 5 +- .../codegen/GenerateSafeProjection.scala | 7 +- .../codegen/GenerateUnsafeProjection.scala | 5 +- .../expressions/codegen/javaCode.scala | 145 +++++++++++++++++- .../expressions/collectionOperations.scala | 19 +-- .../expressions/complexTypeCreator.scala | 7 +- .../expressions/conditionalExpressions.scala | 5 +- .../expressions/datetimeExpressions.scala | 23 +-- .../expressions/decimalExpressions.scala | 5 +- .../sql/catalyst/expressions/generators.scala | 3 +- .../spark/sql/catalyst/expressions/hash.scala | 5 +- .../catalyst/expressions/inputFileBlock.scala | 14 +- .../expressions/mathExpressions.scala | 5 +- .../spark/sql/catalyst/expressions/misc.scala | 5 +- .../expressions/nullExpressions.scala | 9 +- .../expressions/objects/objects.scala | 48 +++--- .../sql/catalyst/expressions/predicates.scala | 15 +- .../expressions/randomExpressions.scala | 5 +- .../expressions/regexpExpressions.scala | 9 +- .../expressions/stringExpressions.scala | 25 +-- .../ExpressionEvalHelperSuite.scala | 3 +- .../expressions/codegen/CodeBlockSuite.scala | 136 ++++++++++++++++ .../sql/execution/ColumnarBatchScan.scala | 9 +- .../spark/sql/execution/ExpandExec.scala | 3 +- .../spark/sql/execution/GenerateExec.scala | 5 +- .../sql/execution/WholeStageCodegenExec.scala | 15 +- .../aggregate/HashAggregateExec.scala | 7 +- .../aggregate/HashMapGenerator.scala | 3 +- .../joins/BroadcastHashJoinExec.scala | 3 +- .../execution/joins/SortMergeJoinExec.scala | 5 +- .../spark/sql/GeneratorFunctionSuite.scala | 4 +- 41 files changed, 479 insertions(+), 172 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 4cc84b27d9eb0..df3ab05e02c76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -56,13 +57,13 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (nullable) { ev.copy(code = - s""" + code""" |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); |$javaType ${ev.value} = ${ev.isNull} ? | ${CodeGenerator.defaultValue(dataType)} : ($value); """.stripMargin) } else { - ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral) + ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 12330bfa55ab9..699ea53b5df0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -623,8 +624,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) - ev.copy(code = eval.code + - castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)) + + ev.copy(code = + code""" + ${eval.code} + // This comment is added for manually tracking reference of ${eval.value}, ${eval.isNull} + ${castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)} + """) } // The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull` diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 97dff6ae88299..9b9fa41a47d0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -22,6 +22,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -108,9 +109,9 @@ abstract class Expression extends TreeNode[Expression] { JavaCode.isNullVariable(isNull), JavaCode.variable(value, dataType))) reduceCodeSize(ctx, eval) - if (eval.code.nonEmpty) { + if (eval.code.toString.nonEmpty) { // Add `this` in the comment. - eval.copy(code = s"${ctx.registerComment(this.toString)}\n" + eval.code.trim) + eval.copy(code = ctx.registerComment(this.toString) + eval.code) } else { eval } @@ -119,7 +120,7 @@ abstract class Expression extends TreeNode[Expression] { private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { // TODO: support whole stage codegen too - if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { + if (eval.code.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull @@ -136,14 +137,14 @@ abstract class Expression extends TreeNode[Expression] { val funcFullName = ctx.addNewFunction(funcName, s""" |private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) { - | ${eval.code.trim} + | ${eval.code} | $setIsNull | return ${eval.value}; |} """.stripMargin) eval.value = JavaCode.variable(newValue, dataType) - eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" + eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" } } @@ -437,15 +438,14 @@ abstract class UnaryExpression extends Expression { if (nullable) { val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode) - ev.copy(code = s""" + ev.copy(code = code""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; + ev.copy(code = code""" ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = FalseLiteral) @@ -537,14 +537,13 @@ abstract class BinaryExpression extends Expression { } } - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; + ev.copy(code = code""" ${leftGen.code} ${rightGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -681,13 +680,12 @@ abstract class TernaryExpression extends Expression { } } - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval""") } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; + ev.copy(code = code""" ${leftGen.code} ${midGen.code} ${rightGen.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 9f0779642271d..f1da592a76845 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, LongType} /** @@ -72,7 +73,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; $countTerm++;""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index e869258469a97..3e7ca88249737 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.DataType /** @@ -1030,7 +1031,7 @@ case class ScalaUDF( """.stripMargin ev.copy(code = - s""" + code""" |$evalCode |${initArgs.mkString("\n")} |$callFunc diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index ff7c98f714905..2ce9d072c71c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ @@ -181,7 +182,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { } ev.copy(code = childCode.code + - s""" + code""" |long ${ev.value} = 0L; |boolean ${ev.isNull} = ${childCode.isNull}; |if (!${childCode.isNull}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 787bcaf5e81de..9856b37e53fbc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -46,7 +47,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { val idTerm = "partitionId" ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", + ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 6c4a3601c1730..84e38a8b2711e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -164,7 +165,7 @@ case class PreciseTimestampConversion( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) ev.copy(code = eval.code + - s"""boolean ${ev.isNull} = ${eval.isNull}; + code"""boolean ${ev.isNull} = ${eval.isNull}; |${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value}; """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index efd4e992c8eec..fe91e520169b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -259,7 +260,7 @@ trait DivModLike extends BinaryArithmetic { s"($javaType)(${eval1.value} $symbol ${eval2.value})" } if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -270,7 +271,7 @@ trait DivModLike extends BinaryArithmetic { ${ev.value} = $operation; }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -436,7 +437,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { } if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -447,7 +448,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { $result }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -569,7 +570,7 @@ case class Least(children: Seq[Expression]) extends Expression { """.stripMargin, foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes @@ -644,7 +645,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { """.stripMargin, foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d382d9aace109..66315e5906253 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -38,6 +38,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -57,19 +58,19 @@ import org.apache.spark.util.{ParentClassLoader, Utils} * @param value A term for a (possibly primitive) value of the result of the evaluation. Not * valid if `isNull` is set to `true`. */ -case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue) +case class ExprCode(var code: Block, var isNull: ExprValue, var value: ExprValue) object ExprCode { def apply(isNull: ExprValue, value: ExprValue): ExprCode = { - ExprCode(code = "", isNull, value) + ExprCode(code = EmptyBlock, isNull, value) } def forNullValue(dataType: DataType): ExprCode = { - ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) + ExprCode(code = EmptyBlock, isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) } def forNonNullValue(value: ExprValue): ExprCode = { - ExprCode(code = "", isNull = FalseLiteral, value = value) + ExprCode(code = EmptyBlock, isNull = FalseLiteral, value = value) } } @@ -330,9 +331,9 @@ class CodegenContext { def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = { val value = addMutableState(javaType(dataType), variableName) val code = dataType match { - case StringType => s"$value = $initCode.clone();" - case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" - case _ => s"$value = $initCode;" + case StringType => code"$value = $initCode.clone();" + case _: StructType | _: ArrayType | _: MapType => code"$value = $initCode.copy();" + case _ => code"$value = $initCode;" } ExprCode(code, FalseLiteral, JavaCode.global(value, dataType)) } @@ -1056,7 +1057,7 @@ class CodegenContext { val eval = expr.genCode(this) val state = SubExprEliminationState(eval.isNull, eval.value) e.foreach(localSubExprEliminationExprs.put(_, state)) - eval.code.trim + eval.code.toString } SubExprCodes(codes, localSubExprEliminationExprs.toMap) } @@ -1084,7 +1085,7 @@ class CodegenContext { val fn = s""" |private void $fnName(InternalRow $INPUT_ROW) { - | ${eval.code.trim} + | ${eval.code} | $isNull = ${eval.isNull}; | $value = ${eval.value}; |} @@ -1141,7 +1142,7 @@ class CodegenContext { def registerComment( text: => String, placeholderId: String = "", - force: Boolean = false): String = { + force: Boolean = false): Block = { // By default, disable comments in generated code because computing the comments themselves can // be extremely expensive in certain cases, such as deeply-nested expressions which operate over // inputs with wide schemas. For more details on the performance issues that motivated this @@ -1160,9 +1161,9 @@ class CodegenContext { s"// $text" } placeHolderToComments += (name -> comment) - s"/*$name*/" + code"/*$name*/" } else { - "" + EmptyBlock } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index a91989e129664..3f4704d287cbd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ /** * A trait that can be used to provide a fallback mode for expression code generation. @@ -46,7 +47,7 @@ trait CodegenFallback extends Expression { val placeHolder = ctx.registerComment(this.toString) val javaType = CodeGenerator.javaType(this.dataType) if (nullable) { - ev.copy(code = s""" + ev.copy(code = code""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); boolean ${ev.isNull} = $objectTerm == null; @@ -55,7 +56,7 @@ trait CodegenFallback extends Expression { ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); $javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 01c350e9dbf69..39778661d1c48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -22,6 +22,7 @@ import scala.annotation.tailrec import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -71,7 +72,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] arguments = Seq("InternalRow" -> tmpInput, "Object[]" -> values) ) val code = - s""" + code""" |final InternalRow $tmpInput = $input; |final Object[] $values = new Object[${schema.length}]; |$allFields @@ -97,7 +98,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ctx, JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType), elementType) - val code = s""" + val code = code""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); final Object[] $values = new Object[$numElements]; @@ -124,7 +125,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val keyConverter = createCodeForArray(ctx, s"$tmpInput.keyArray()", keyType) val valueConverter = createCodeForArray(ctx, s"$tmpInput.valueArray()", valueType) - val code = s""" + val code = code""" final MapData $tmpInput = $input; ${keyConverter.code} ${valueConverter.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 01b4d6c4529bd..8f2a5a0dce943 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -286,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true) val code = - s""" + code""" |$rowWriter.reset(); |$evalSubexpr |$writeExpressions @@ -343,7 +344,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | } | | public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { - | ${eval.code.trim} + | ${eval.code} | return ${eval.value}; | } | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 74ff018488863..250ce48d059e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import java.lang.{Boolean => JBool} +import scala.collection.mutable.ArrayBuffer import scala.language.{existentials, implicitConversions} import org.apache.spark.sql.types.{BooleanType, DataType} @@ -114,6 +115,147 @@ object JavaCode { } } +/** + * A trait representing a block of java code. + */ +trait Block extends JavaCode { + + // The expressions to be evaluated inside this block. + def exprValues: Set[ExprValue] + + // Returns java code string for this code block. + override def toString: String = _marginChar match { + case Some(c) => code.stripMargin(c).trim + case _ => code.trim + } + + def length: Int = toString.length + + def nonEmpty: Boolean = toString.nonEmpty + + // The leading prefix that should be stripped from each line. + // By default we strip blanks or control characters followed by '|' from the line. + var _marginChar: Option[Char] = Some('|') + + def stripMargin(c: Char): this.type = { + _marginChar = Some(c) + this + } + + def stripMargin: this.type = { + _marginChar = Some('|') + this + } + + // Concatenates this block with other block. + def + (other: Block): Block +} + +object Block { + + val CODE_BLOCK_BUFFER_LENGTH: Int = 512 + + implicit def blocksToBlock(blocks: Seq[Block]): Block = Blocks(blocks) + + implicit class BlockHelper(val sc: StringContext) extends AnyVal { + def code(args: Any*): Block = { + sc.checkLengths(args) + if (sc.parts.length == 0) { + EmptyBlock + } else { + args.foreach { + case _: ExprValue => + case _: Int | _: Long | _: Float | _: Double | _: String => + case _: Block => + case other => throw new IllegalArgumentException( + s"Can not interpolate ${other.getClass.getName} into code block.") + } + + val (codeParts, blockInputs) = foldLiteralArgs(sc.parts, args) + CodeBlock(codeParts, blockInputs) + } + } + } + + // Folds eagerly the literal args into the code parts. + private def foldLiteralArgs(parts: Seq[String], args: Seq[Any]): (Seq[String], Seq[JavaCode]) = { + val codeParts = ArrayBuffer.empty[String] + val blockInputs = ArrayBuffer.empty[JavaCode] + + val strings = parts.iterator + val inputs = args.iterator + val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) + + buf.append(strings.next) + while (strings.hasNext) { + val input = inputs.next + input match { + case _: ExprValue | _: Block => + codeParts += buf.toString + buf.clear + blockInputs += input.asInstanceOf[JavaCode] + case _ => + buf.append(input) + } + buf.append(strings.next) + } + if (buf.nonEmpty) { + codeParts += buf.toString + } + + (codeParts.toSeq, blockInputs.toSeq) + } +} + +/** + * A block of java code. Including a sequence of code parts and some inputs to this block. + * The actual java code is generated by embedding the inputs into the code parts. + */ +case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block { + override lazy val exprValues: Set[ExprValue] = { + blockInputs.flatMap { + case b: Block => b.exprValues + case e: ExprValue => Set(e) + }.toSet + } + + override lazy val code: String = { + val strings = codeParts.iterator + val inputs = blockInputs.iterator + val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) + buf.append(StringContext.treatEscapes(strings.next)) + while (strings.hasNext) { + buf.append(inputs.next) + buf.append(StringContext.treatEscapes(strings.next)) + } + buf.toString + } + + override def + (other: Block): Block = other match { + case c: CodeBlock => Blocks(Seq(this, c)) + case b: Blocks => Blocks(Seq(this) ++ b.blocks) + case EmptyBlock => this + } +} + +case class Blocks(blocks: Seq[Block]) extends Block { + override lazy val exprValues: Set[ExprValue] = blocks.flatMap(_.exprValues).toSet + override lazy val code: String = blocks.map(_.toString).mkString("\n") + + override def + (other: Block): Block = other match { + case c: CodeBlock => Blocks(blocks :+ c) + case b: Blocks => Blocks(blocks ++ b.blocks) + case EmptyBlock => this + } +} + +object EmptyBlock extends Block with Serializable { + override val code: String = "" + override val exprValues: Set[ExprValue] = Set.empty + + override def + (other: Block): Block = other +} + /** * A typed java fragment that must be a valid java expression. */ @@ -123,10 +265,9 @@ trait ExprValue extends JavaCode { } object ExprValue { - implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString + implicit def exprValueToString(exprValue: ExprValue): String = exprValue.code } - /** * A java expression fragment. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 7da4c3cc6b9fa..c28eab71b84fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -91,7 +92,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childGen = child.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = false; ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : @@ -1177,14 +1178,14 @@ case class ArrayJoin( } if (nullable) { ev.copy( - s""" + code""" |boolean ${ev.isNull} = true; |UTF8String ${ev.value} = null; |$code """.stripMargin) } else { ev.copy( - s""" + code""" |UTF8String ${ev.value} = null; |$code """.stripMargin, FalseLiteral) @@ -1269,11 +1270,11 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast val childGen = child.genCode(ctx) val javaType = CodeGenerator.javaType(dataType) val i = ctx.freshName("i") - val item = ExprCode("", + val item = ExprCode(EmptyBlock, isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) ev.copy(code = - s""" + code""" |${childGen.code} |boolean ${ev.isNull} = true; |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -1334,11 +1335,11 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast val childGen = child.genCode(ctx) val javaType = CodeGenerator.javaType(dataType) val i = ctx.freshName("i") - val item = ExprCode("", + val item = ExprCode(EmptyBlock, isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) ev.copy(code = - s""" + code""" |${childGen.code} |boolean ${ev.isNull} = true; |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -1653,7 +1654,7 @@ case class Concat(children: Seq[Expression]) extends Expression { expressions = inputs, funcName = "valueConcat", extraArguments = (s"$javaType[]", args) :: Nil) - ev.copy(s""" + ev.copy(code""" $initCode $codes $javaType ${ev.value} = $concatenator.concat($args); @@ -1963,7 +1964,7 @@ case class ArrayRepeat(left: Expression, right: Expression) val resultCode = nullElementsProtection(ev, rightGen.isNull, coreLogic) ev.copy(code = - s""" + code""" |boolean ${ev.isNull} = false; |${leftGen.code} |${rightGen.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 67876a8565488..a9867aaeb0cfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -63,7 +64,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { val (preprocess, assigns, postprocess, arrayData) = GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( - code = preprocess + assigns + postprocess, + code = code"${preprocess}${assigns}${postprocess}", value = JavaCode.variable(arrayData, dataType), isNull = FalseLiteral) } @@ -219,7 +220,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression { val (preprocessValueData, assignValues, postprocessValueData, valueArrayData) = GenArrayData.genCodeToCreateArrayData(ctx, valueDt, evalValues, false) val code = - s""" + code""" final boolean ${ev.isNull} = false; $preprocessKeyData $assignKeys @@ -373,7 +374,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc extraArguments = "Object[]" -> values :: Nil) ev.copy(code = - s""" + code""" |Object[] $values = new Object[${valExprs.size}]; |$valuesCode |final InternalRow ${ev.value} = new $rowClass($values); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 205d77f6a9acf..77ac6c088022e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ // scalastyle:off line.size.limit @@ -66,7 +67,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi val falseEval = falseValue.genCode(ctx) val code = - s""" + code""" |${condEval.code} |boolean ${ev.isNull} = false; |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -265,7 +266,7 @@ case class CaseWhen( }.mkString) ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 03422fecb3209..e8d85f72f7a7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -27,6 +27,7 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -717,7 +718,7 @@ abstract class UnixTime } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -746,7 +747,7 @@ abstract class UnixTime }) case TimestampType => val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -757,7 +758,7 @@ abstract class UnixTime val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -852,7 +853,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val t = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -1042,7 +1043,7 @@ case class StringToTimestampWithoutTimezone(child: Expression, timeZoneId: Optio val tz = ctx.addReferenceObj("timeZone", timeZone) val longOpt = ctx.freshName("longOpt") val eval = child.genCode(ctx) - val code = s""" + val code = code""" |${eval.code} |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = true; |${CodeGenerator.JAVA_LONG} ${ev.value} = ${CodeGenerator.defaultValue(TimestampType)}; @@ -1090,7 +1091,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) if (right.foldable) { val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { - ev.copy(code = s""" + ev.copy(code = code""" |boolean ${ev.isNull} = true; |long ${ev.value} = 0; """.stripMargin) @@ -1104,7 +1105,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; @@ -1287,7 +1288,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) if (right.foldable) { val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { - ev.copy(code = s""" + ev.copy(code = code""" |boolean ${ev.isNull} = true; |long ${ev.value} = 0; """.stripMargin) @@ -1301,7 +1302,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; @@ -1444,13 +1445,13 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes { val javaType = CodeGenerator.javaType(dataType) if (format.foldable) { if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { val t = instant.genCode(ctx) val truncFuncStr = truncFunc(t.value, truncLevel.toString) - ev.copy(code = s""" + ev.copy(code = code""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index db1579ba28671..04de83343be71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode} import org.apache.spark.sql.types._ /** @@ -72,7 +72,8 @@ case class PromotePrecision(child: Expression) extends UnaryExpression { override def eval(input: InternalRow): Any = child.eval(input) /** Just a simple pass-through for code generation. */ override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx) - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("") + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + ev.copy(EmptyBlock) override def prettyName: String = "promote_precision" override def sql: String = child.sql override lazy val canonicalized: Expression = child.canonicalized diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 3af4bfebad45e..b7c52f1d7b40a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ @@ -215,7 +216,7 @@ case class Stack(children: Seq[Expression]) extends Generator { // Create the collection. val wrapperClass = classOf[mutable.WrappedArray[_]].getName ev.copy(code = - s""" + code""" |$code |$wrapperClass ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData); """.stripMargin, isNull = FalseLiteral) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index ef790338bdd27..cec00b66f873c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -28,6 +28,7 @@ import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -293,7 +294,7 @@ abstract class HashExpression[E] extends Expression { foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |$hashResultType ${ev.value} = $seed; |$codes """.stripMargin) @@ -674,7 +675,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_INT} ${ev.value} = $seed; |${CodeGenerator.JAVA_INT} $childHash = 0; |$codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 2a3cc580273ee..3b0141ad52cc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -42,8 +43,9 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getInputFilePath();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getInputFilePath();", + isNull = FalseLiteral) } } @@ -65,8 +67,8 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getStartOffset();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getStartOffset();", isNull = FalseLiteral) } } @@ -88,7 +90,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getLength();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getLength();", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index bc4cfcec47425..c2e1720259b53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.NumberConverter import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1191,11 +1192,11 @@ abstract class RoundBase(child: Expression, scale: Expression, val javaType = CodeGenerator.javaType(dataType) if (scaleV == null) { // if scale is null, no need to eval its child at all - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${ce.code} boolean ${ev.isNull} = ${ce.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index b7834696cafc3..5d98dac46cf17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -21,6 +21,7 @@ import java.util.UUID import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -88,7 +89,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the value is null or false. val errMsgField = ctx.addReferenceObj("errMsg", errMsg) - ExprCode(code = s"""${eval.code} + ExprCode(code = code"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); |}""".stripMargin, isNull = TrueLiteral, @@ -151,7 +152,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta ctx.addPartitionInitializationStatement(s"$randomGen = " + "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" + s"${randomSeed.get}L + partitionIndex);") - ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", + ev.copy(code = code"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 0787342bce6bc..2eeed3bbb2d91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -111,7 +112,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |do { @@ -232,7 +233,7 @@ case class IsNaN(child: Expression) extends UnaryExpression val eval = child.genCode(ctx) child.dataType match { case DoubleType | FloatType => - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral) @@ -278,7 +279,7 @@ case class NaNvl(left: Expression, right: Expression) val rightGen = right.genCode(ctx) left.dataType match { case DoubleType | FloatType => - ev.copy(code = s""" + ev.copy(code = code""" ${leftGen.code} boolean ${ev.isNull} = false; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -440,7 +441,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate }.mkString) ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_INT} $nonnull = 0; |do { | $codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index f974fd81fc788..2bf4203d0fec3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -269,7 +270,7 @@ case class StaticInvoke( s"${ev.value} = $callFunc;" } - val code = s""" + val code = code""" $argCode $prepareIsNull $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -385,8 +386,7 @@ case class Invoke( """ } - val code = s""" - ${obj.code} + val code = obj.code + code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${obj.isNull}) { @@ -492,7 +492,7 @@ case class NewInstance( s"new $className($argString)" } - val code = s""" + val code = code""" $argCode ${outer.map(_.code).getOrElse("")} final $javaType ${ev.value} = ${ev.isNull} ? @@ -532,9 +532,7 @@ case class UnwrapOption( val javaType = CodeGenerator.javaType(dataType) val inputObject = child.genCode(ctx) - val code = s""" - ${inputObject.code} - + val code = inputObject.code + code""" final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty(); $javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} : (${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get(); @@ -564,9 +562,7 @@ case class WrapOption(child: Expression, optType: DataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputObject = child.genCode(ctx) - val code = s""" - ${inputObject.code} - + val code = inputObject.code + code""" scala.Option ${ev.value} = ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); @@ -935,8 +931,7 @@ case class MapObjects private( ) } - val code = s""" - ${genInputData.code} + val code = genInputData.code + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { @@ -1147,8 +1142,7 @@ case class CatalystToExternalMap private( """ val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();" - val code = s""" - ${genInputData.code} + val code = genInputData.code + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { @@ -1391,9 +1385,8 @@ case class ExternalMapToCatalyst private( val mapCls = classOf[ArrayBasedMapData].getName val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType) val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType) - val code = - s""" - ${inputMap.code} + val code = inputMap.code + + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${inputMap.isNull}) { final int $length = ${inputMap.value}.size(); @@ -1471,7 +1464,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) val schemaField = ctx.addReferenceObj("schema", schema) val code = - s""" + code""" |Object[] $values = new Object[${children.size}]; |$childrenCode |final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); @@ -1499,8 +1492,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) val javaType = CodeGenerator.javaType(dataType) val serialize = s"$serializer.serialize(${input.value}, null).array()" - val code = s""" - ${input.code} + val code = input.code + code""" final $javaType ${ev.value} = ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize; """ @@ -1532,8 +1524,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B val deserialize = s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" - val code = s""" - ${input.code} + val code = input.code + code""" final $javaType ${ev.value} = ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize; """ @@ -1614,9 +1605,8 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp funcName = "initializeJavaBean", extraArguments = beanInstanceJavaType -> javaBeanInstance :: Nil) - val code = - s""" - |${instanceGen.code} + val code = instanceGen.code + + code""" |$beanInstanceJavaType $javaBeanInstance = ${instanceGen.value}; |if (!${instanceGen.isNull}) { | $initializeCode @@ -1664,9 +1654,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) // because errMsgField is used only when the value is null. val errMsgField = ctx.addReferenceObj("errMsg", errMsg) - val code = s""" - ${childGen.code} - + val code = childGen.code + code""" if (${childGen.isNull}) { throw new NullPointerException($errMsgField); } @@ -1709,7 +1697,7 @@ case class GetExternalRowField( // because errMsgField is used only when the field is null. val errMsgField = ctx.addReferenceObj("errMsg", errMsg) val row = child.genCode(ctx) - val code = s""" + val code = code""" ${row.code} if (${row.isNull}) { @@ -1784,7 +1772,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) s"$obj instanceof ${CodeGenerator.boxedType(dataType)}" } - val code = s""" + val code = code""" ${input.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${input.isNull}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f8c6dc4e6adc9..f54103c4fbfba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -22,6 +22,7 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -290,7 +291,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { }.mkString("\n")) ev.copy(code = - s""" + code""" |${valueGen.code} |byte $tmpResult = $HAS_NULL; |if (!${valueGen.isNull}) { @@ -354,7 +355,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with "" } ev.copy(code = - s""" + code""" |${childGen.code} |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false; @@ -406,7 +407,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with // The result should be `false`, if any of them is `false` whenever the other is null or not. if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.value} = false; @@ -415,7 +416,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with ${ev.value} = ${eval2.value}; }""", isNull = FalseLiteral) } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = false; boolean ${ev.value} = false; @@ -470,7 +471,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P // The result should be `true`, if any of them is `true` whenever the other is null or not. if (!left.nullable && !right.nullable) { ev.isNull = FalseLiteral - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.value} = true; @@ -479,7 +480,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P ${ev.value} = ${eval2.value}; }""", isNull = FalseLiteral) } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = false; boolean ${ev.value} = true; @@ -621,7 +622,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp val eval1 = left.genCode(ctx) val eval2 = right.genCode(ctx) val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value) - ev.copy(code = eval1.code + eval2.code + s""" + ev.copy(code = eval1.code + eval2.code + code""" boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) || (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 2653b28f6c3bd..926c2f00d430d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -82,7 +83,7 @@ case class Rand(child: Expression) extends RDG { val rngTerm = ctx.addMutableState(className, "rng") ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = FalseLiteral) } @@ -120,7 +121,7 @@ case class Randn(child: Expression) extends RDG { val rngTerm = ctx.addMutableState(className, "rng") ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index ad0c0791d895f..7b68bb771faf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -23,6 +23,7 @@ import java.util.regex.{MatchResult, Pattern} import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -123,7 +124,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -132,7 +133,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi } """) } else { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) @@ -198,7 +199,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -207,7 +208,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress } """) } else { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index ea005a26a4c8b..9823b2fc5ad97 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -27,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -105,7 +106,7 @@ case class ConcatWs(children: Seq[Expression]) expressions = inputs, funcName = "valueConcatWs", extraArguments = ("UTF8String[]", args) :: Nil) - ev.copy(s""" + ev.copy(code""" UTF8String[] $args = new UTF8String[$numArgs]; ${separator.code} $codes @@ -149,7 +150,7 @@ case class ConcatWs(children: Seq[Expression]) } }.unzip - val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code)) + val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code.toString)) val varargCounts = ctx.splitExpressionsWithCurrentInputs( expressions = varargCount, @@ -176,7 +177,7 @@ case class ConcatWs(children: Seq[Expression]) foldFunctions = _.map(funcCall => s"$idxVararg = $funcCall;").mkString("\n")) ev.copy( - s""" + code""" $codes int $varargNum = ${children.count(_.dataType == StringType) - 1}; int $idxVararg = 0; @@ -288,7 +289,7 @@ case class Elt(children: Seq[Expression]) extends Expression { }.mkString) ev.copy( - s""" + code""" |${index.code} |final int $indexVal = ${index.value}; |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; @@ -654,7 +655,7 @@ case class StringTrim( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -671,7 +672,7 @@ case class StringTrim( } else { ${ev.value} = ${srcString.value}.trim(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -754,7 +755,7 @@ case class StringTrimLeft( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -771,7 +772,7 @@ case class StringTrimLeft( } else { ${ev.value} = ${srcString.value}.trimLeft(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -856,7 +857,7 @@ case class StringTrimRight( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -873,7 +874,7 @@ case class StringTrimRight( } else { ${ev.value} = ${srcString.value}.trimRight(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -1024,7 +1025,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) val substrGen = substr.genCode(ctx) val strGen = str.genCode(ctx) val startGen = start.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" int ${ev.value} = 0; boolean ${ev.isNull} = false; ${startGen.code} @@ -1350,7 +1351,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC val formatter = classOf[java.util.Formatter].getName val sb = ctx.freshName("sb") val stringBuffer = classOf[StringBuffer].getName - ev.copy(code = s""" + ev.copy(code = code""" ${pattern.code} boolean ${ev.isNull} = ${pattern.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala index 64b65e2070ed6..7c7c4cccee253 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -45,7 +46,7 @@ case class BadCodegenExpression() extends LeafExpression { override def eval(input: InternalRow): Any = 10 override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.copy(code = - s""" + code""" |int some_variable = 11; |int ${ev.value} = 10; """.stripMargin) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala new file mode 100644 index 0000000000000..d2c6420eadb20 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.types.{BooleanType, IntegerType} + +class CodeBlockSuite extends SparkFunSuite { + + test("Block interpolates string and ExprValue inputs") { + val isNull = JavaCode.isNullVariable("expr1_isNull") + val stringLiteral = "false" + val code = code"boolean $isNull = $stringLiteral;" + assert(code.toString == "boolean expr1_isNull = false;") + } + + test("Literals are folded into string code parts instead of block inputs") { + val value = JavaCode.variable("expr1", IntegerType) + val intLiteral = 1 + val code = code"int $value = $intLiteral;" + assert(code.asInstanceOf[CodeBlock].blockInputs === Seq(value)) + } + + test("Block.stripMargin") { + val isNull = JavaCode.isNullVariable("expr1_isNull") + val value = JavaCode.variable("expr1", IntegerType) + val code1 = + code""" + |boolean $isNull = false; + |int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin + val expected = + s""" + |boolean expr1_isNull = false; + |int expr1 = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin.trim + assert(code1.toString == expected) + + val code2 = + code""" + >boolean $isNull = false; + >int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin('>') + assert(code2.toString == expected) + } + + test("Block can capture input expr values") { + val isNull = JavaCode.isNullVariable("expr1_isNull") + val value = JavaCode.variable("expr1", IntegerType) + val code = + code""" + |boolean $isNull = false; + |int $value = -1; + """.stripMargin + val exprValues = code.exprValues + assert(exprValues.size == 2) + assert(exprValues === Set(value, isNull)) + } + + test("concatenate blocks") { + val isNull1 = JavaCode.isNullVariable("expr1_isNull") + val value1 = JavaCode.variable("expr1", IntegerType) + val isNull2 = JavaCode.isNullVariable("expr2_isNull") + val value2 = JavaCode.variable("expr2", IntegerType) + val literal = JavaCode.literal("100", IntegerType) + + val code = + code""" + |boolean $isNull1 = false; + |int $value1 = -1;""".stripMargin + + code""" + |boolean $isNull2 = true; + |int $value2 = $literal;""".stripMargin + + val expected = + """ + |boolean expr1_isNull = false; + |int expr1 = -1; + |boolean expr2_isNull = true; + |int expr2 = 100;""".stripMargin.trim + + assert(code.toString == expected) + + val exprValues = code.exprValues + assert(exprValues.size == 5) + assert(exprValues === Set(isNull1, value1, isNull2, value2, literal)) + } + + test("Throws exception when interpolating unexcepted object in code block") { + val obj = Tuple2(1, 1) + val e = intercept[IllegalArgumentException] { + code"$obj" + } + assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}")) + } + + test("replace expr values in code block") { + val expr = JavaCode.expression("1 + 1", IntegerType) + val isNull = JavaCode.isNullVariable("expr1_isNull") + val exprInFunc = JavaCode.variable("expr1", IntegerType) + + val code = + code""" + |callFunc(int $expr) { + | boolean $isNull = false; + | int $exprInFunc = $expr + 1; + |}""".stripMargin + + val aliasedParam = JavaCode.variable("aliased", expr.javaType) + val aliasedInputs = code.asInstanceOf[CodeBlock].blockInputs.map { + case _: SimpleExprValue => aliasedParam + case other => other + } + val aliasedCode = CodeBlock(code.asInstanceOf[CodeBlock].codeParts, aliasedInputs).stripMargin + val expected = + code""" + |callFunc(int $aliasedParam) { + | boolean $isNull = false; + | int $exprInFunc = $aliasedParam + 1; + |}""".stripMargin + assert(aliasedCode.toString == expected.toString) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index fc3dbc1c5591b..48abad9078650 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -58,14 +59,14 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" - val code = s"${ctx.registerComment(str)}\n" + (if (nullable) { - s""" + val code = code"${ctx.registerComment(str)}" + (if (nullable) { + code""" boolean $isNullVar = $columnVar.isNullAt($ordinal); $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value); """ } else { - s"$javaType $valueVar = $value;" - }).trim + code"$javaType $valueVar = $value;" + }) ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index e4812f3d338fb..5b4edf5136e3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -152,7 +153,7 @@ case class ExpandExec( } else { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val code = s""" + val code = code""" |boolean $isNull = true; |${CodeGenerator.javaType(firstExpr.dataType)} $value = | ${CodeGenerator.defaultValue(firstExpr.dataType)}; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index f40c50df74ccb..2549b9e1537a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types._ @@ -313,13 +314,13 @@ case class GenerateExec( if (checks.nonEmpty) { val isNull = ctx.freshName("isNull") val code = - s""" + code""" |boolean $isNull = ${checks.mkString(" || ")}; |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter; """.stripMargin ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, dt)) } else { - ExprCode(s"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt)) + ExprCode(code"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 828b51fa199de..372dc3db36ce6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -122,10 +123,10 @@ trait CodegenSupport extends SparkPlan { ctx.INPUT_ROW = row ctx.currentVars = colVars val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) - val code = s""" + val code = code""" |$evaluateInputs - |${ev.code.trim} - """.stripMargin.trim + |${ev.code} + """.stripMargin ExprCode(code, FalseLiteral, ev.value) } else { // There are no columns @@ -259,8 +260,8 @@ trait CodegenSupport extends SparkPlan { * them to be evaluated twice. */ protected def evaluateVariables(variables: Seq[ExprCode]): String = { - val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n") - variables.foreach(_.code = "") + val evaluate = variables.filter(_.code.nonEmpty).map(_.code.toString).mkString("\n") + variables.foreach(_.code = EmptyBlock) evaluate } @@ -275,8 +276,8 @@ trait CodegenSupport extends SparkPlan { val evaluateVars = new StringBuilder variables.zipWithIndex.foreach { case (ev, i) => if (ev.code != "" && required.contains(attributes(i))) { - evaluateVars.append(ev.code.trim + "\n") - ev.code = "" + evaluateVars.append(ev.code.toString + "\n") + ev.code = EmptyBlock } } evaluateVars.toString() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 6a8ec4f722aea..8c7b2c187cccd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -190,7 +191,7 @@ case class HashAggregateExec( val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") // The initial expression should not access any column val ev = e.genCode(ctx) - val initVars = s""" + val initVars = code""" | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin @@ -773,8 +774,8 @@ case class HashAggregateExec( val findOrInsertRegularHashMap: String = s""" |// generate grouping key - |${unsafeRowKeyCode.code.trim} - |${hashEval.code.trim} + |${unsafeRowKeyCode.code} + |${hashEval.code} |if ($checkFallbackForBytesToBytesMap) { | // try to get the buffer from hash map | $unsafeRowBuffer = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index de2d630de3fdb..e1c85823259b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -50,7 +51,7 @@ abstract class HashMapGenerator( val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") val ev = e.genCode(ctx) val initVars = - s""" + code""" | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 6fa716d9fadee..0da0e8610c392 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} @@ -183,7 +184,7 @@ case class BroadcastHashJoinExec( val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") val javaType = CodeGenerator.javaType(a.dataType) - val code = s""" + val code = code""" |boolean $isNull = true; |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)}; |if ($matched != null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index d8261f0f33b61..f4b9d132122e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ @@ -521,7 +522,7 @@ case class SortMergeJoinExec( if (a.nullable) { val isNull = ctx.freshName("isNull") val code = - s""" + code""" |$isNull = $leftRow.isNullAt($i); |$value = $isNull ? $defaultValue : ($valueCode); """.stripMargin @@ -533,7 +534,7 @@ case class SortMergeJoinExec( (ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)), leftVarsDecl) } else { - val code = s"$value = $valueCode;" + val code = code"$value = $valueCode;" val leftVarsDecl = s"""$javaType $value = $defaultValue;""" (ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 109fcf90a3ec9..8280a3ce39845 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, Generator} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructType} @@ -315,6 +316,7 @@ case class EmptyGenerator() extends Generator { override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val iteratorClass = classOf[Iterator[_]].getName - ev.copy(code = s"$iteratorClass ${ev.value} = $iteratorClass$$.MODULE$$.empty();") + ev.copy(code = + code"$iteratorClass ${ev.value} = $iteratorClass$$.MODULE$$.empty();") } } From bc6ea614ad4c6a323c78f209120287b256a458d3 Mon Sep 17 00:00:00 2001 From: "Vayda, Oleksandr: IT (PRG)" Date: Tue, 22 May 2018 13:01:07 -0700 Subject: [PATCH 0853/2461] [SPARK-24348][SQL] "element_at" error fix ## What changes were proposed in this pull request? ### Fixes a `scala.MatchError` in the `element_at` operation - [SPARK-24348](https://issues.apache.org/jira/browse/SPARK-24348) When calling `element_at` with a wrong first operand type an `AnalysisException` should be thrown instead of `scala.MatchError` *Example:* ```sql select element_at('foo', 1) ``` results in: ``` scala.MatchError: StringType (of class org.apache.spark.sql.types.StringType$) at org.apache.spark.sql.catalyst.expressions.ElementAt.inputTypes(collectionOperations.scala:1469) at org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes$class.checkInputDataTypes(ExpectsInputTypes.scala:44) at org.apache.spark.sql.catalyst.expressions.ElementAt.checkInputDataTypes(collectionOperations.scala:1478) at org.apache.spark.sql.catalyst.expressions.Expression.resolved$lzycompute(Expression.scala:168) at org.apache.spark.sql.catalyst.expressions.Expression.resolved(Expression.scala:168) at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAliases$$anonfun$org$apache$spark$sql$catalyst$analysis$Analyzer$ResolveAliases$$assignAliases$1$$anonfun$apply$3.applyOrElse(Analyzer.scala:256) at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAliases$$anonfun$org$apache$spark$sql$catalyst$analysis$Analyzer$ResolveAliases$$assignAliases$1$$anonfun$apply$3.applyOrElse(Analyzer.scala:252) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:289) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:289) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:288) ``` ## How was this patch tested? unit tests Author: Vayda, Oleksandr: IT (PRG) Closes #21395 from wajda/SPARK-24348-element_at-error-fix. --- .../sql/catalyst/expressions/collectionOperations.scala | 1 + .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c28eab71b84fd..03b3b21a16617 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1470,6 +1470,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti left.dataType match { case _: ArrayType => IntegerType case _: MapType => left.dataType.asInstanceOf[MapType].keyType + case _ => AnyDataType // no match for a wrong 'left' expression type } ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index df23e07e441a0..ec2a569f900d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -756,6 +756,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("element_at(a, -1)"), Seq(Row("3"), Row(""), Row(null)) ) + + val e = intercept[AnalysisException] { + Seq(("a string element", 1)).toDF().selectExpr("element_at(_1, _2)") + } + assert(e.message.contains( + "argument 1 requires (array or map) type, however, '`_1`' is of string type")) } test("concat function - arrays") { From 79e06faa4ef6596c9e2d4be09c74b935064021bb Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Tue, 22 May 2018 13:43:45 -0700 Subject: [PATCH 0854/2461] [SPARK-19185][DSTREAMS] Avoid concurrent use of cached consumers in CachedKafkaConsumer ## What changes were proposed in this pull request? `CachedKafkaConsumer` in the project streaming-kafka-0-10 is designed to maintain a pool of KafkaConsumers that can be reused. However, it was built with the assumption there will be only one thread trying to read the same Kafka TopicPartition at the same time. This assumption is not true all the time and this can inadvertently lead to ConcurrentModificationException. Here is a better way to design this. The consumer pool should be smart enough to avoid concurrent use of a cached consumer. If there is another request for the same TopicPartition as a currently in-use consumer, the pool should automatically return a fresh consumer. - There are effectively two kinds of consumer that may be generated - Cached consumer - this should be returned to the pool at task end - Non-cached consumer - this should be closed at task end - A trait called `KafkaDataConsumer` is introduced to hide this difference from the users of the consumer so that the client code does not have to reason about whether to stop and release. They simply call `val consumer = KafkaDataConsumer.acquire` and then `consumer.release`. - If there is request for a consumer that is in-use, then a new consumer is generated. - If there is request for a consumer which is a task reattempt, then already existing cached consumer will be invalidated and a new consumer is generated. This could fix potential issues if the source of the reattempt is a malfunctioning consumer. - In addition, I renamed the `CachedKafkaConsumer` class to `KafkaDataConsumer` because is a misnomer given that what it returns may or may not be cached. ## How was this patch tested? A new stress test that verifies it is safe to concurrently get consumers for the same TopicPartition from the consumer pool. Author: Gabor Somogyi Closes #20997 from gaborgsomogyi/SPARK-19185. --- .../sql/kafka010/KafkaDataConsumer.scala | 2 +- .../kafka010/CachedKafkaConsumer.scala | 226 ----------- .../kafka010/KafkaDataConsumer.scala | 359 ++++++++++++++++++ .../spark/streaming/kafka010/KafkaRDD.scala | 20 +- .../kafka010/KafkaDataConsumerSuite.scala | 131 +++++++ 5 files changed, 496 insertions(+), 242 deletions(-) delete mode 100644 external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala create mode 100644 external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala create mode 100644 external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala index 48508d057a540..941f0ab177e48 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala @@ -395,7 +395,7 @@ private[kafka010] object KafkaDataConsumer extends Logging { // likely running on a beefy machine that can handle a large number of simultaneously // active consumers. - if (entry.getValue.inUse == false && this.size > capacity) { + if (!entry.getValue.inUse && this.size > capacity) { logWarning( s"KafkaConsumer cache hitting max capacity of $capacity, " + s"removing consumer for ${entry.getKey}") diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala deleted file mode 100644 index aeb8c1dc342b3..0000000000000 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.kafka010 - -import java.{ util => ju } - -import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord, KafkaConsumer } -import org.apache.kafka.common.{ KafkaException, TopicPartition } - -import org.apache.spark.internal.Logging - -/** - * Consumer of single topicpartition, intended for cached reuse. - * Underlying consumer is not threadsafe, so neither is this, - * but processing the same topicpartition and group id in multiple threads is usually bad anyway. - */ -private[kafka010] -class CachedKafkaConsumer[K, V] private( - val groupId: String, - val topic: String, - val partition: Int, - val kafkaParams: ju.Map[String, Object]) extends Logging { - - require(groupId == kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG), - "groupId used for cache key must match the groupId in kafkaParams") - - val topicPartition = new TopicPartition(topic, partition) - - protected val consumer = { - val c = new KafkaConsumer[K, V](kafkaParams) - val tps = new ju.ArrayList[TopicPartition]() - tps.add(topicPartition) - c.assign(tps) - c - } - - // TODO if the buffer was kept around as a random-access structure, - // could possibly optimize re-calculating of an RDD in the same batch - protected var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]() - protected var nextOffset = -2L - - def close(): Unit = consumer.close() - - /** - * Get the record for the given offset, waiting up to timeout ms if IO is necessary. - * Sequential forward access will use buffers, but random access will be horribly inefficient. - */ - def get(offset: Long, timeout: Long): ConsumerRecord[K, V] = { - logDebug(s"Get $groupId $topic $partition nextOffset $nextOffset requested $offset") - if (offset != nextOffset) { - logInfo(s"Initial fetch for $groupId $topic $partition $offset") - seek(offset) - poll(timeout) - } - - if (!buffer.hasNext()) { poll(timeout) } - require(buffer.hasNext(), - s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") - var record = buffer.next() - - if (record.offset != offset) { - logInfo(s"Buffer miss for $groupId $topic $partition $offset") - seek(offset) - poll(timeout) - require(buffer.hasNext(), - s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") - record = buffer.next() - require(record.offset == offset, - s"Got wrong record for $groupId $topic $partition even after seeking to offset $offset " + - s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " + - "spark.streaming.kafka.allowNonConsecutiveOffsets" - ) - } - - nextOffset = offset + 1 - record - } - - /** - * Start a batch on a compacted topic - */ - def compactedStart(offset: Long, timeout: Long): Unit = { - logDebug(s"compacted start $groupId $topic $partition starting $offset") - // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics - if (offset != nextOffset) { - logInfo(s"Initial fetch for compacted $groupId $topic $partition $offset") - seek(offset) - poll(timeout) - } - } - - /** - * Get the next record in the batch from a compacted topic. - * Assumes compactedStart has been called first, and ignores gaps. - */ - def compactedNext(timeout: Long): ConsumerRecord[K, V] = { - if (!buffer.hasNext()) { - poll(timeout) - } - require(buffer.hasNext(), - s"Failed to get records for compacted $groupId $topic $partition after polling for $timeout") - val record = buffer.next() - nextOffset = record.offset + 1 - record - } - - /** - * Rewind to previous record in the batch from a compacted topic. - * @throws NoSuchElementException if no previous element - */ - def compactedPrevious(): ConsumerRecord[K, V] = { - buffer.previous() - } - - private def seek(offset: Long): Unit = { - logDebug(s"Seeking to $topicPartition $offset") - consumer.seek(topicPartition, offset) - } - - private def poll(timeout: Long): Unit = { - val p = consumer.poll(timeout) - val r = p.records(topicPartition) - logDebug(s"Polled ${p.partitions()} ${r.size}") - buffer = r.listIterator - } - -} - -private[kafka010] -object CachedKafkaConsumer extends Logging { - - private case class CacheKey(groupId: String, topic: String, partition: Int) - - // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap - private var cache: ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]] = null - - /** Must be called before get, once per JVM, to configure the cache. Further calls are ignored */ - def init( - initialCapacity: Int, - maxCapacity: Int, - loadFactor: Float): Unit = CachedKafkaConsumer.synchronized { - if (null == cache) { - logInfo(s"Initializing cache $initialCapacity $maxCapacity $loadFactor") - cache = new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]]( - initialCapacity, loadFactor, true) { - override def removeEldestEntry( - entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer[_, _]]): Boolean = { - if (this.size > maxCapacity) { - try { - entry.getValue.consumer.close() - } catch { - case x: KafkaException => - logError("Error closing oldest Kafka consumer", x) - } - true - } else { - false - } - } - } - } - } - - /** - * Get a cached consumer for groupId, assigned to topic and partition. - * If matching consumer doesn't already exist, will be created using kafkaParams. - */ - def get[K, V]( - groupId: String, - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] = - CachedKafkaConsumer.synchronized { - val k = CacheKey(groupId, topic, partition) - val v = cache.get(k) - if (null == v) { - logInfo(s"Cache miss for $k") - logDebug(cache.keySet.toString) - val c = new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams) - cache.put(k, c) - c - } else { - // any given topicpartition should have a consistent key and value type - v.asInstanceOf[CachedKafkaConsumer[K, V]] - } - } - - /** - * Get a fresh new instance, unassociated with the global cache. - * Caller is responsible for closing - */ - def getUncached[K, V]( - groupId: String, - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] = - new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams) - - /** remove consumer for given groupId, topic, and partition, if it exists */ - def remove(groupId: String, topic: String, partition: Int): Unit = { - val k = CacheKey(groupId, topic, partition) - logInfo(s"Removing $k from cache") - val v = CachedKafkaConsumer.synchronized { - cache.remove(k) - } - if (null != v) { - v.close() - logInfo(s"Removed $k from cache") - } - } -} diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala new file mode 100644 index 0000000000000..68c5fe9ab066a --- /dev/null +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010 + +import java.{util => ju} + +import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer} +import org.apache.kafka.common.{KafkaException, TopicPartition} + +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging + +private[kafka010] sealed trait KafkaDataConsumer[K, V] { + /** + * Get the record for the given offset if available. + * + * @param offset the offset to fetch. + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + */ + def get(offset: Long, pollTimeoutMs: Long): ConsumerRecord[K, V] = { + internalConsumer.get(offset, pollTimeoutMs) + } + + /** + * Start a batch on a compacted topic + * + * @param offset the offset to fetch. + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + */ + def compactedStart(offset: Long, pollTimeoutMs: Long): Unit = { + internalConsumer.compactedStart(offset, pollTimeoutMs) + } + + /** + * Get the next record in the batch from a compacted topic. + * Assumes compactedStart has been called first, and ignores gaps. + * + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + */ + def compactedNext(pollTimeoutMs: Long): ConsumerRecord[K, V] = { + internalConsumer.compactedNext(pollTimeoutMs) + } + + /** + * Rewind to previous record in the batch from a compacted topic. + * + * @throws NoSuchElementException if no previous element + */ + def compactedPrevious(): ConsumerRecord[K, V] = { + internalConsumer.compactedPrevious() + } + + /** + * Release this consumer from being further used. Depending on its implementation, + * this consumer will be either finalized, or reset for reuse later. + */ + def release(): Unit + + /** Reference to the internal implementation that this wrapper delegates to */ + def internalConsumer: InternalKafkaConsumer[K, V] +} + + +/** + * A wrapper around Kafka's KafkaConsumer. + * This is not for direct use outside this file. + */ +private[kafka010] class InternalKafkaConsumer[K, V]( + val topicPartition: TopicPartition, + val kafkaParams: ju.Map[String, Object]) extends Logging { + + private[kafka010] val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG) + .asInstanceOf[String] + + private val consumer = createConsumer + + /** indicates whether this consumer is in use or not */ + var inUse = true + + /** indicate whether this consumer is going to be stopped in the next release */ + var markedForClose = false + + // TODO if the buffer was kept around as a random-access structure, + // could possibly optimize re-calculating of an RDD in the same batch + @volatile private var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]() + @volatile private var nextOffset = InternalKafkaConsumer.UNKNOWN_OFFSET + + override def toString: String = { + "InternalKafkaConsumer(" + + s"hash=${Integer.toHexString(hashCode)}, " + + s"groupId=$groupId, " + + s"topicPartition=$topicPartition)" + } + + /** Create a KafkaConsumer to fetch records for `topicPartition` */ + private def createConsumer: KafkaConsumer[K, V] = { + val c = new KafkaConsumer[K, V](kafkaParams) + val topics = ju.Arrays.asList(topicPartition) + c.assign(topics) + c + } + + def close(): Unit = consumer.close() + + /** + * Get the record for the given offset, waiting up to timeout ms if IO is necessary. + * Sequential forward access will use buffers, but random access will be horribly inefficient. + */ + def get(offset: Long, timeout: Long): ConsumerRecord[K, V] = { + logDebug(s"Get $groupId $topicPartition nextOffset $nextOffset requested $offset") + if (offset != nextOffset) { + logInfo(s"Initial fetch for $groupId $topicPartition $offset") + seek(offset) + poll(timeout) + } + + if (!buffer.hasNext()) { + poll(timeout) + } + require(buffer.hasNext(), + s"Failed to get records for $groupId $topicPartition $offset after polling for $timeout") + var record = buffer.next() + + if (record.offset != offset) { + logInfo(s"Buffer miss for $groupId $topicPartition $offset") + seek(offset) + poll(timeout) + require(buffer.hasNext(), + s"Failed to get records for $groupId $topicPartition $offset after polling for $timeout") + record = buffer.next() + require(record.offset == offset, + s"Got wrong record for $groupId $topicPartition even after seeking to offset $offset " + + s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " + + "spark.streaming.kafka.allowNonConsecutiveOffsets" + ) + } + + nextOffset = offset + 1 + record + } + + /** + * Start a batch on a compacted topic + */ + def compactedStart(offset: Long, pollTimeoutMs: Long): Unit = { + logDebug(s"compacted start $groupId $topicPartition starting $offset") + // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics + if (offset != nextOffset) { + logInfo(s"Initial fetch for compacted $groupId $topicPartition $offset") + seek(offset) + poll(pollTimeoutMs) + } + } + + /** + * Get the next record in the batch from a compacted topic. + * Assumes compactedStart has been called first, and ignores gaps. + */ + def compactedNext(pollTimeoutMs: Long): ConsumerRecord[K, V] = { + if (!buffer.hasNext()) { + poll(pollTimeoutMs) + } + require(buffer.hasNext(), + s"Failed to get records for compacted $groupId $topicPartition " + + s"after polling for $pollTimeoutMs") + val record = buffer.next() + nextOffset = record.offset + 1 + record + } + + /** + * Rewind to previous record in the batch from a compacted topic. + * @throws NoSuchElementException if no previous element + */ + def compactedPrevious(): ConsumerRecord[K, V] = { + buffer.previous() + } + + private def seek(offset: Long): Unit = { + logDebug(s"Seeking to $topicPartition $offset") + consumer.seek(topicPartition, offset) + } + + private def poll(timeout: Long): Unit = { + val p = consumer.poll(timeout) + val r = p.records(topicPartition) + logDebug(s"Polled ${p.partitions()} ${r.size}") + buffer = r.listIterator + } + +} + +private[kafka010] case class CacheKey(groupId: String, topicPartition: TopicPartition) + +private[kafka010] object KafkaDataConsumer extends Logging { + + private case class CachedKafkaDataConsumer[K, V](internalConsumer: InternalKafkaConsumer[K, V]) + extends KafkaDataConsumer[K, V] { + assert(internalConsumer.inUse) + override def release(): Unit = KafkaDataConsumer.release(internalConsumer) + } + + private case class NonCachedKafkaDataConsumer[K, V](internalConsumer: InternalKafkaConsumer[K, V]) + extends KafkaDataConsumer[K, V] { + override def release(): Unit = internalConsumer.close() + } + + // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap + private[kafka010] var cache: ju.Map[CacheKey, InternalKafkaConsumer[_, _]] = null + + /** + * Must be called before acquire, once per JVM, to configure the cache. + * Further calls are ignored. + */ + def init( + initialCapacity: Int, + maxCapacity: Int, + loadFactor: Float): Unit = synchronized { + if (null == cache) { + logInfo(s"Initializing cache $initialCapacity $maxCapacity $loadFactor") + cache = new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer[_, _]]( + initialCapacity, loadFactor, true) { + override def removeEldestEntry( + entry: ju.Map.Entry[CacheKey, InternalKafkaConsumer[_, _]]): Boolean = { + + // Try to remove the least-used entry if its currently not in use. + // + // If you cannot remove it, then the cache will keep growing. In the worst case, + // the cache will grow to the max number of concurrent tasks that can run in the executor, + // (that is, number of tasks slots) after which it will never reduce. This is unlikely to + // be a serious problem because an executor with more than 64 (default) tasks slots is + // likely running on a beefy machine that can handle a large number of simultaneously + // active consumers. + + if (entry.getValue.inUse == false && this.size > maxCapacity) { + logWarning( + s"KafkaConsumer cache hitting max capacity of $maxCapacity, " + + s"removing consumer for ${entry.getKey}") + try { + entry.getValue.close() + } catch { + case x: KafkaException => + logError("Error closing oldest Kafka consumer", x) + } + true + } else { + false + } + } + } + } + } + + /** + * Get a cached consumer for groupId, assigned to topic and partition. + * If matching consumer doesn't already exist, will be created using kafkaParams. + * The returned consumer must be released explicitly using [[KafkaDataConsumer.release()]]. + * + * Note: This method guarantees that the consumer returned is not currently in use by anyone + * else. Within this guarantee, this method will make a best effort attempt to re-use consumers by + * caching them and tracking when they are in use. + */ + def acquire[K, V]( + topicPartition: TopicPartition, + kafkaParams: ju.Map[String, Object], + context: TaskContext, + useCache: Boolean): KafkaDataConsumer[K, V] = synchronized { + val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + val key = new CacheKey(groupId, topicPartition) + val existingInternalConsumer = cache.get(key) + + lazy val newInternalConsumer = new InternalKafkaConsumer[K, V](topicPartition, kafkaParams) + + if (context != null && context.attemptNumber >= 1) { + // If this is reattempt at running the task, then invalidate cached consumers if any and + // start with a new one. If prior attempt failures were cache related then this way old + // problematic consumers can be removed. + logDebug(s"Reattempt detected, invalidating cached consumer $existingInternalConsumer") + if (existingInternalConsumer != null) { + // Consumer exists in cache. If its in use, mark it for closing later, or close it now. + if (existingInternalConsumer.inUse) { + existingInternalConsumer.markedForClose = true + } else { + existingInternalConsumer.close() + // Remove the consumer from cache only if it's closed. + // Marked for close consumers will be removed in release function. + cache.remove(key) + } + } + + logDebug("Reattempt detected, new non-cached consumer will be allocated " + + s"$newInternalConsumer") + NonCachedKafkaDataConsumer(newInternalConsumer) + } else if (!useCache) { + // If consumer reuse turned off, then do not use it, return a new consumer + logDebug("Cache usage turned off, new non-cached consumer will be allocated " + + s"$newInternalConsumer") + NonCachedKafkaDataConsumer(newInternalConsumer) + } else if (existingInternalConsumer == null) { + // If consumer is not already cached, then put a new in the cache and return it + logDebug("No cached consumer, new cached consumer will be allocated " + + s"$newInternalConsumer") + cache.put(key, newInternalConsumer) + CachedKafkaDataConsumer(newInternalConsumer) + } else if (existingInternalConsumer.inUse) { + // If consumer is already cached but is currently in use, then return a new consumer + logDebug("Used cached consumer found, new non-cached consumer will be allocated " + + s"$newInternalConsumer") + NonCachedKafkaDataConsumer(newInternalConsumer) + } else { + // If consumer is already cached and is currently not in use, then return that consumer + logDebug(s"Not used cached consumer found, re-using it $existingInternalConsumer") + existingInternalConsumer.inUse = true + // Any given TopicPartition should have a consistent key and value type + CachedKafkaDataConsumer(existingInternalConsumer.asInstanceOf[InternalKafkaConsumer[K, V]]) + } + } + + private def release(internalConsumer: InternalKafkaConsumer[_, _]): Unit = synchronized { + // Clear the consumer from the cache if this is indeed the consumer present in the cache + val key = new CacheKey(internalConsumer.groupId, internalConsumer.topicPartition) + val cachedInternalConsumer = cache.get(key) + if (internalConsumer.eq(cachedInternalConsumer)) { + // The released consumer is the same object as the cached one. + if (internalConsumer.markedForClose) { + internalConsumer.close() + cache.remove(key) + } else { + internalConsumer.inUse = false + } + } else { + // The released consumer is either not the same one as in the cache, or not in the cache + // at all. This may happen if the cache was invalidate while this consumer was being used. + // Just close this consumer. + internalConsumer.close() + logInfo(s"Released a supposedly cached consumer that was not found in the cache " + + s"$internalConsumer") + } + } +} + +private[kafka010] object InternalKafkaConsumer { + private val UNKNOWN_OFFSET = -2L +} diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index 07239eda64d2e..81abc9860bfc3 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -19,8 +19,6 @@ package org.apache.spark.streaming.kafka010 import java.{ util => ju } -import scala.collection.mutable.ArrayBuffer - import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord } import org.apache.kafka.common.TopicPartition @@ -239,26 +237,18 @@ private class KafkaRDDIterator[K, V]( cacheLoadFactor: Float ) extends Iterator[ConsumerRecord[K, V]] { - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - context.addTaskCompletionListener(_ => closeIfNeeded()) - val consumer = if (useConsumerCache) { - CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) - if (context.attemptNumber >= 1) { - // just in case the prior attempt failures were cache related - CachedKafkaConsumer.remove(groupId, part.topic, part.partition) - } - CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams) - } else { - CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams) + val consumer = { + KafkaDataConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) + KafkaDataConsumer.acquire[K, V](part.topicPartition(), kafkaParams, context, useConsumerCache) } var requestOffset = part.fromOffset def closeIfNeeded(): Unit = { - if (!useConsumerCache && consumer != null) { - consumer.close() + if (consumer != null) { + consumer.release() } } diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala new file mode 100644 index 0000000000000..d934c64962adb --- /dev/null +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010 + +import java.util.concurrent.{Executors, TimeUnit} + +import scala.collection.JavaConverters._ +import scala.util.Random + +import org.apache.kafka.clients.consumer.ConsumerConfig._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.serialization.ByteArrayDeserializer +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark._ + +class KafkaDataConsumerSuite extends SparkFunSuite with BeforeAndAfterAll { + private var testUtils: KafkaTestUtils = _ + private val topic = "topic" + Random.nextInt() + private val topicPartition = new TopicPartition(topic, 0) + private val groupId = "groupId" + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils + testUtils.setup() + KafkaDataConsumer.init(16, 64, 0.75f) + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + super.afterAll() + } + + private def getKafkaParams() = Map[String, Object]( + GROUP_ID_CONFIG -> groupId, + BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress, + KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + AUTO_OFFSET_RESET_CONFIG -> "earliest", + ENABLE_AUTO_COMMIT_CONFIG -> "false" + ).asJava + + test("KafkaDataConsumer reuse in case of same groupId and TopicPartition") { + KafkaDataConsumer.cache.clear() + + val kafkaParams = getKafkaParams() + + val consumer1 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]]( + topicPartition, kafkaParams, null, true) + consumer1.release() + + val consumer2 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]]( + topicPartition, kafkaParams, null, true) + consumer2.release() + + assert(KafkaDataConsumer.cache.size() == 1) + val key = new CacheKey(groupId, topicPartition) + val existingInternalConsumer = KafkaDataConsumer.cache.get(key) + assert(existingInternalConsumer.eq(consumer1.internalConsumer)) + assert(existingInternalConsumer.eq(consumer2.internalConsumer)) + } + + test("concurrent use of KafkaDataConsumer") { + val data = (1 to 1000).map(_.toString) + testUtils.createTopic(topic) + testUtils.sendMessages(topic, data.toArray) + + val kafkaParams = getKafkaParams() + + val numThreads = 100 + val numConsumerUsages = 500 + + @volatile var error: Throwable = null + + def consume(i: Int): Unit = { + val useCache = Random.nextBoolean + val taskContext = if (Random.nextBoolean) { + new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null) + } else { + null + } + val consumer = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]]( + topicPartition, kafkaParams, taskContext, useCache) + try { + val rcvd = (0 until data.length).map { offset => + val bytes = consumer.get(offset, 10000).value() + new String(bytes) + } + assert(rcvd == data) + } catch { + case e: Throwable => + error = e + throw e + } finally { + consumer.release() + } + } + + val threadPool = Executors.newFixedThreadPool(numThreads) + try { + val futures = (1 to numConsumerUsages).map { i => + threadPool.submit(new Runnable { + override def run(): Unit = { consume(i) } + }) + } + futures.foreach(_.get(1, TimeUnit.MINUTES)) + assert(error == null) + } finally { + threadPool.shutdown() + } + } +} From 00c13cfad78607fde0787c9d494f0df8ab7051ba Mon Sep 17 00:00:00 2001 From: Seth Fitzsimmons Date: Wed, 23 May 2018 09:14:03 +0800 Subject: [PATCH 0855/2461] Correct reference to Offset class This is a documentation-only correction; `org.apache.spark.sql.sources.v2.reader.Offset` is actually `org.apache.spark.sql.sources.v2.reader.streaming.Offset`. Author: Seth Fitzsimmons Closes #21387 from mojodna/patch-1. --- .../org/apache/spark/sql/execution/streaming/Offset.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java index 80aa5505db991..43ad4b3384ec3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java @@ -19,8 +19,8 @@ /** * This is an internal, deprecated interface. New source implementations should use the - * org.apache.spark.sql.sources.v2.reader.Offset class, which is the one that will be supported - * in the long term. + * org.apache.spark.sql.sources.v2.reader.streaming.Offset class, which is the one that will be + * supported in the long term. * * This class will be removed in a future release. */ From a40ffc656d62372da85e0fa932b67207839e7fde Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 23 May 2018 22:40:52 +0800 Subject: [PATCH 0856/2461] [SPARK-23711][SQL] Add fallback generator for UnsafeProjection ## What changes were proposed in this pull request? Add fallback logic for `UnsafeProjection`. In production we can try to create unsafe projection using codegen implementation. Once any compile error happens, it fallbacks to interpreted implementation. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21106 from viirya/SPARK-23711. --- ...CodeGeneratorWithInterpretedFallback.scala | 73 +++++++++++++++++++ .../InterpretedUnsafeProjection.scala | 5 +- .../sql/catalyst/expressions/Projection.scala | 60 +++++++++------ .../apache/spark/sql/internal/SQLConf.scala | 12 +++ ...eneratorWithInterpretedFallbackSuite.scala | 43 +++++++++++ .../expressions/ExpressionEvalHelper.scala | 52 ++++++------- .../expressions/ObjectExpressionsSuite.scala | 18 +---- .../expressions/UnsafeRowConverterSuite.scala | 52 ++++++++----- .../UnsafeKVExternalSorterSuite.scala | 3 +- 9 files changed, 230 insertions(+), 88 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala new file mode 100644 index 0000000000000..fb25e781e72e4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.codehaus.commons.compiler.CompileException +import org.codehaus.janino.InternalCompilerException + +import org.apache.spark.TaskContext +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + +/** + * Catches compile error during code generation. + */ +object CodegenError { + def unapply(throwable: Throwable): Option[Exception] = throwable match { + case e: InternalCompilerException => Some(e) + case e: CompileException => Some(e) + case _ => None + } +} + +/** + * Defines values for `SQLConf` config of fallback mode. Use for test only. + */ +object CodegenObjectFactoryMode extends Enumeration { + val FALLBACK, CODEGEN_ONLY, NO_CODEGEN = Value +} + +/** + * A codegen object generator which creates objects with codegen path first. Once any compile + * error happens, it can fallbacks to interpreted implementation. In tests, we can use a SQL config + * `SQLConf.CODEGEN_FACTORY_MODE` to control fallback behavior. + */ +abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] { + + def createObject(in: IN): OUT = { + // We are allowed to choose codegen-only or no-codegen modes if under tests. + val config = SQLConf.get.getConf(SQLConf.CODEGEN_FACTORY_MODE) + val fallbackMode = CodegenObjectFactoryMode.withName(config) + + fallbackMode match { + case CodegenObjectFactoryMode.CODEGEN_ONLY if Utils.isTesting => + createCodeGeneratedObject(in) + case CodegenObjectFactoryMode.NO_CODEGEN if Utils.isTesting => + createInterpretedObject(in) + case _ => + try { + createCodeGeneratedObject(in) + } catch { + case CodegenError(_) => createInterpretedObject(in) + } + } + } + + protected def createCodeGeneratedObject(in: IN): OUT + protected def createInterpretedObject(in: IN): OUT +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 6d69d69b1c802..55a5bd380859e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -87,12 +87,11 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe /** * Helper functions for creating an [[InterpretedUnsafeProjection]]. */ -object InterpretedUnsafeProjection extends UnsafeProjectionCreator { - +object InterpretedUnsafeProjection { /** * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. */ - override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { + def createProjection(exprs: Seq[Expression]): UnsafeProjection = { // We need to make sure that we do not reuse stateful expressions. val cleanedExpressions = exprs.map(_.transform { case s: Stateful => s.freshCopy() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 3cd73682188bc..6493f09100577 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -108,7 +108,32 @@ abstract class UnsafeProjection extends Projection { override def apply(row: InternalRow): UnsafeRow } -trait UnsafeProjectionCreator { +/** + * The factory object for `UnsafeProjection`. + */ +object UnsafeProjection + extends CodeGeneratorWithInterpretedFallback[Seq[Expression], UnsafeProjection] { + + override protected def createCodeGeneratedObject(in: Seq[Expression]): UnsafeProjection = { + GenerateUnsafeProjection.generate(in) + } + + override protected def createInterpretedObject(in: Seq[Expression]): UnsafeProjection = { + InterpretedUnsafeProjection.createProjection(in) + } + + protected def toBoundExprs( + exprs: Seq[Expression], + inputSchema: Seq[Attribute]): Seq[Expression] = { + exprs.map(BindReferences.bindReference(_, inputSchema)) + } + + protected def toUnsafeExprs(exprs: Seq[Expression]): Seq[Expression] = { + exprs.map(_ transform { + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + }) + } + /** * Returns an UnsafeProjection for given StructType. * @@ -129,10 +154,7 @@ trait UnsafeProjectionCreator { * Returns an UnsafeProjection for given sequence of bound Expressions. */ def create(exprs: Seq[Expression]): UnsafeProjection = { - val unsafeExprs = exprs.map(_ transform { - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) - }) - createProjection(unsafeExprs) + createObject(toUnsafeExprs(exprs)) } def create(expr: Expression): UnsafeProjection = create(Seq(expr)) @@ -142,34 +164,24 @@ trait UnsafeProjectionCreator { * `inputSchema`. */ def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { - create(exprs.map(BindReferences.bindReference(_, inputSchema))) - } - - /** - * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. - */ - protected def createProjection(exprs: Seq[Expression]): UnsafeProjection -} - -object UnsafeProjection extends UnsafeProjectionCreator { - - override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { - GenerateUnsafeProjection.generate(exprs) + create(toBoundExprs(exprs, inputSchema)) } /** * Same as other create()'s but allowing enabling/disabling subexpression elimination. - * TODO: refactor the plumbing and clean this up. + * The param `subexpressionEliminationEnabled` doesn't guarantee to work. For example, + * when fallbacking to interpreted execution, it is not supported. */ def create( exprs: Seq[Expression], inputSchema: Seq[Attribute], subexpressionEliminationEnabled: Boolean): UnsafeProjection = { - val e = exprs.map(BindReferences.bindReference(_, inputSchema)) - .map(_ transform { - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) - }) - GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled) + val unsafeExprs = toUnsafeExprs(toBoundExprs(exprs, inputSchema)) + try { + GenerateUnsafeProjection.generate(unsafeExprs, subexpressionEliminationEnabled) + } catch { + case CodegenError(_) => InterpretedUnsafeProjection.createProjection(unsafeExprs) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d0478d6ad250b..fb98df587129b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -32,6 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator import org.apache.spark.util.Utils @@ -703,6 +704,17 @@ object SQLConf { .intConf .createWithDefault(100) + val CODEGEN_FACTORY_MODE = buildConf("spark.sql.codegen.factoryMode") + .doc("This config determines the fallback behavior of several codegen generators " + + "during tests. `FALLBACK` means trying codegen first and then fallbacking to " + + "interpreted if any compile error happens. Disabling fallback if `CODEGEN_ONLY`. " + + "`NO_CODEGEN` skips codegen and goes interpreted path always. Note that " + + "this config works only for tests.") + .internal() + .stringConf + .checkValues(CodegenObjectFactoryMode.values.map(_.toString)) + .createWithDefault(CodegenObjectFactoryMode.FALLBACK.toString) + val CODEGEN_FALLBACK = buildConf("spark.sql.codegen.fallback") .internal() .doc("When true, (whole stage) codegen could be temporary disabled for the part of query that" + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala new file mode 100644 index 0000000000000..531ca9a87370a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.PlanTestBase +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{IntegerType, LongType} + +class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanTestBase { + + test("UnsafeProjection with codegen factory mode") { + val input = Seq(LongType, IntegerType) + .zipWithIndex.map(x => BoundReference(x._2, x._1, true)) + + val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) { + val obj = UnsafeProjection.createObject(input) + assert(obj.getClass.getName.contains("GeneratedClass$SpecificUnsafeProjection")) + } + + val noCodegen = CodegenObjectFactoryMode.NO_CODEGEN.toString + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> noCodegen) { + val obj = UnsafeProjection.createObject(input) + assert(obj.isInstanceOf[InterpretedUnsafeProjection]) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index c2a44e0d33b18..14bfa212b5496 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf @@ -40,7 +41,7 @@ import org.apache.spark.util.Utils /** * A few helper functions for expression evaluation testing. Mixin this trait to use them. */ -trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { +trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBase { self: SparkFunSuite => protected def create_row(values: Any*): InternalRow = { @@ -205,39 +206,34 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - checkEvaluationWithUnsafeProjection(expression, expected, inputRow, UnsafeProjection) - checkEvaluationWithUnsafeProjection(expression, expected, inputRow, InterpretedUnsafeProjection) - } - - protected def checkEvaluationWithUnsafeProjection( - expression: Expression, - expected: Any, - inputRow: InternalRow, - factory: UnsafeProjectionCreator): Unit = { - val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow, factory) - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - - if (expected == null) { - if (!unsafeRow.isNullAt(0)) { - val expectedRow = InternalRow(expected, expected) - fail("Incorrect evaluation in unsafe mode: " + - s"$expression, actual: $unsafeRow, expected: $expectedRow$input") - } - } else { - val lit = InternalRow(expected, expected) - val expectedRow = - factory.create(Array(expression.dataType, expression.dataType)).apply(lit) - if (unsafeRow != expectedRow) { - fail("Incorrect evaluation in unsafe mode: " + - s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN) + for (fallbackMode <- modes) { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) { + val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow) + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + + if (expected == null) { + if (!unsafeRow.isNullAt(0)) { + val expectedRow = InternalRow(expected, expected) + fail("Incorrect evaluation in unsafe mode: " + + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + } + } else { + val lit = InternalRow(expected, expected) + val expectedRow = + UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit) + if (unsafeRow != expectedRow) { + fail("Incorrect evaluation in unsafe mode: " + + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + } + } } } } protected def evaluateWithUnsafeProjection( expression: Expression, - inputRow: InternalRow = EmptyRow, - factory: UnsafeProjectionCreator = UnsafeProjection): InternalRow = { + inputRow: InternalRow = EmptyRow): InternalRow = { // SPARK-16489 Explicitly doing code generation twice so code gen will fail if // some expression is reusing variable names across different instances. // This behavior is tested in ExpressionEvalHelperSuite. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 77ca640f2e0bd..20d568c44258f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -81,10 +81,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val structExpected = new GenericArrayData( Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4)))) checkEvaluationWithUnsafeProjection( - structEncoder.serializer.head, - structExpected, - structInputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + structEncoder.serializer.head, structExpected, structInputRow) // test UnsafeArray-backed data val arrayEncoder = ExpressionEncoder[Array[Array[Int]]] @@ -92,10 +89,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val arrayExpected = new GenericArrayData( Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4)))) checkEvaluationWithUnsafeProjection( - arrayEncoder.serializer.head, - arrayExpected, - arrayInputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + arrayEncoder.serializer.head, arrayExpected, arrayInputRow) // test UnsafeMap-backed data val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]] @@ -109,10 +103,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { new GenericArrayData(Array(3, 4)), new GenericArrayData(Array(300, 400))))) checkEvaluationWithUnsafeProjection( - mapEncoder.serializer.head, - mapExpected, - mapInputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + mapEncoder.serializer.head, mapExpected, mapInputRow) } test("SPARK-23582: StaticInvoke should support interpreted execution") { @@ -286,8 +277,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluationWithUnsafeProjection( expr, expected, - inputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + inputRow) } checkEvaluationWithOptimization(expr, expected, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index c07da122cd7b8..5a646d9a850ac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -24,25 +24,30 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, LongType, _} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String -class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { +class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestBase { private def roundedSize(size: Int) = ByteArrayMethods.roundNumberOfBytesToNearestWord(size) - private def testWithFactory( - name: String)( - f: UnsafeProjectionCreator => Unit): Unit = { - test(name) { - f(UnsafeProjection) - f(InterpretedUnsafeProjection) + private def testBothCodegenAndInterpreted(name: String)(f: => Unit): Unit = { + val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN) + for (fallbackMode <- modes) { + test(s"$name with $fallbackMode") { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) { + f + } + } } } - testWithFactory("basic conversion with only primitive types") { factory => + testBothCodegenAndInterpreted("basic conversion with only primitive types") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) val converter = factory.create(fieldTypes) val row = new SpecificInternalRow(fieldTypes) @@ -79,7 +84,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow2.getInt(2) === 2) } - testWithFactory("basic conversion with primitive, string and binary types") { factory => + testBothCodegenAndInterpreted("basic conversion with primitive, string and binary types") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType) val converter = factory.create(fieldTypes) @@ -98,7 +104,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getBinary(2) === "World".getBytes(StandardCharsets.UTF_8)) } - testWithFactory("basic conversion with primitive, string, date and timestamp types") { factory => + testBothCodegenAndInterpreted( + "basic conversion with primitive, string, date and timestamp types") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType) val converter = factory.create(fieldTypes) @@ -127,7 +135,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { (Timestamp.valueOf("2015-06-22 08:10:25")) } - testWithFactory("null handling") { factory => + testBothCodegenAndInterpreted("null handling") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( NullType, BooleanType, @@ -248,7 +257,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } - testWithFactory("NaN canonicalization") { factory => + testBothCodegenAndInterpreted("NaN canonicalization") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(FloatType, DoubleType) val row1 = new SpecificInternalRow(fieldTypes) @@ -263,7 +273,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes) } - testWithFactory("basic conversion with struct type") { factory => + testBothCodegenAndInterpreted("basic conversion with struct type") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( new StructType().add("i", IntegerType), new StructType().add("nest", new StructType().add("l", LongType)) @@ -325,7 +336,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(map.getSizeInBytes == 8 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes) } - testWithFactory("basic conversion with array type") { factory => + testBothCodegenAndInterpreted("basic conversion with array type") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( ArrayType(IntegerType), ArrayType(ArrayType(IntegerType)) @@ -355,7 +367,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size) } - testWithFactory("basic conversion with map type") { factory => + testBothCodegenAndInterpreted("basic conversion with map type") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( MapType(IntegerType, IntegerType), MapType(IntegerType, MapType(IntegerType, IntegerType)) @@ -401,7 +414,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size) } - testWithFactory("basic conversion with struct and array") { factory => + testBothCodegenAndInterpreted("basic conversion with struct and array") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( new StructType().add("arr", ArrayType(IntegerType)), ArrayType(new StructType().add("l", LongType)) @@ -440,7 +454,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } - testWithFactory("basic conversion with struct and map") { factory => + testBothCodegenAndInterpreted("basic conversion with struct and map") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( new StructType().add("map", MapType(IntegerType, IntegerType)), MapType(IntegerType, new StructType().add("l", LongType)) @@ -486,7 +501,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } - testWithFactory("basic conversion with array and map") { factory => + testBothCodegenAndInterpreted("basic conversion with array and map") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( ArrayType(MapType(IntegerType, IntegerType)), MapType(IntegerType, ArrayType(IntegerType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index bf588d3bb7841..c882a9dd2148c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -231,7 +231,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { // Make sure we can successfully create a UnsafeKVExternalSorter with a `BytesToBytesMap` // which has duplicated keys and the number of entries exceeds its capacity. try { - TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, null, null)) + val context = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties(), null) + TaskContext.setTaskContext(context) new UnsafeKVExternalSorter( schema, schema, From df125062c8dac9fee3328d67dd438a456b7a3b74 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 23 May 2018 11:00:23 -0700 Subject: [PATCH 0857/2461] [SPARK-20114][ML][FOLLOW-UP] spark.ml parity for sequential pattern mining - PrefixSpan ## What changes were proposed in this pull request? Change `PrefixSpan` into a class with param setter/getters. This address issues mentioned here: https://github.com/apache/spark/pull/20973#discussion_r186931806 ## How was this patch tested? UT. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: WeichenXu Closes #21393 from WeichenXu123/fix_prefix_span. --- .../org/apache/spark/ml/fpm/PrefixSpan.scala | 127 ++++++++++++++---- .../apache/spark/ml/fpm/PrefixSpanSuite.scala | 28 ++-- 2 files changed, 119 insertions(+), 36 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala index 02168fee16dbf..41716c621ca98 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -18,6 +18,8 @@ package org.apache.spark.ml.fpm import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan} import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.col @@ -29,13 +31,97 @@ import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType} * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns * Efficiently by Prefix-Projected Pattern Growth * (see here). + * This class is not yet an Estimator/Transformer, use `findFrequentSequentialPatterns` method to + * run the PrefixSpan algorithm. * * @see Sequential Pattern Mining * (Wikipedia) */ @Since("2.4.0") @Experimental -object PrefixSpan { +final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params { + + @Since("2.4.0") + def this() = this(Identifiable.randomUID("prefixSpan")) + + /** + * Param for the minimal support level (default: `0.1`). + * Sequential patterns that appear more than (minSupport * size-of-the-dataset) times are + * identified as frequent sequential patterns. + * @group param + */ + @Since("2.4.0") + val minSupport = new DoubleParam(this, "minSupport", "The minimal support level of the " + + "sequential pattern. Sequential pattern that appears more than " + + "(minSupport * size-of-the-dataset)." + + "times will be output.", ParamValidators.gtEq(0.0)) + + /** @group getParam */ + @Since("2.4.0") + def getMinSupport: Double = $(minSupport) + + /** @group setParam */ + @Since("2.4.0") + def setMinSupport(value: Double): this.type = set(minSupport, value) + + /** + * Param for the maximal pattern length (default: `10`). + * @group param + */ + @Since("2.4.0") + val maxPatternLength = new IntParam(this, "maxPatternLength", + "The maximal length of the sequential pattern.", + ParamValidators.gt(0)) + + /** @group getParam */ + @Since("2.4.0") + def getMaxPatternLength: Int = $(maxPatternLength) + + /** @group setParam */ + @Since("2.4.0") + def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value) + + /** + * Param for the maximum number of items (including delimiters used in the internal storage + * format) allowed in a projected database before local processing (default: `32000000`). + * If a projected database exceeds this size, another iteration of distributed prefix growth + * is run. + * @group param + */ + @Since("2.4.0") + val maxLocalProjDBSize = new LongParam(this, "maxLocalProjDBSize", + "The maximum number of items (including delimiters used in the internal storage format) " + + "allowed in a projected database before local processing. If a projected database exceeds " + + "this size, another iteration of distributed prefix growth is run.", + ParamValidators.gt(0)) + + /** @group getParam */ + @Since("2.4.0") + def getMaxLocalProjDBSize: Long = $(maxLocalProjDBSize) + + /** @group setParam */ + @Since("2.4.0") + def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value) + + /** + * Param for the name of the sequence column in dataset (default "sequence"), rows with + * nulls in this column are ignored. + * @group param + */ + @Since("2.4.0") + val sequenceCol = new Param[String](this, "sequenceCol", "The name of the sequence column in " + + "dataset, rows with nulls in this column are ignored.") + + /** @group getParam */ + @Since("2.4.0") + def getSequenceCol: String = $(sequenceCol) + + /** @group setParam */ + @Since("2.4.0") + def setSequenceCol(value: String): this.type = set(sequenceCol, value) + + setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000, + sequenceCol -> "sequence") /** * :: Experimental :: @@ -43,54 +129,39 @@ object PrefixSpan { * * @param dataset A dataset or a dataframe containing a sequence column which is * {{{Seq[Seq[_]]}}} type - * @param sequenceCol the name of the sequence column in dataset, rows with nulls in this column - * are ignored - * @param minSupport the minimal support level of the sequential pattern, any pattern that - * appears more than (minSupport * size-of-the-dataset) times will be output - * (recommended value: `0.1`). - * @param maxPatternLength the maximal length of the sequential pattern - * (recommended value: `10`). - * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the - * internal storage format) allowed in a projected database before - * local processing. If a projected database exceeds this size, another - * iteration of distributed prefix growth is run - * (recommended value: `32000000`). * @return A `DataFrame` that contains columns of sequence and corresponding frequency. * The schema of it will be: * - `sequence: Seq[Seq[T]]` (T is the item type) * - `freq: Long` */ @Since("2.4.0") - def findFrequentSequentialPatterns( - dataset: Dataset[_], - sequenceCol: String, - minSupport: Double, - maxPatternLength: Int, - maxLocalProjDBSize: Long): DataFrame = { - - val inputType = dataset.schema(sequenceCol).dataType + def findFrequentSequentialPatterns(dataset: Dataset[_]): DataFrame = { + val sequenceColParam = $(sequenceCol) + val inputType = dataset.schema(sequenceColParam).dataType require(inputType.isInstanceOf[ArrayType] && inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType], s"The input column must be ArrayType and the array element type must also be ArrayType, " + s"but got $inputType.") - - val data = dataset.select(sequenceCol) - val sequences = data.where(col(sequenceCol).isNotNull).rdd + val data = dataset.select(sequenceColParam) + val sequences = data.where(col(sequenceColParam).isNotNull).rdd .map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray) val mllibPrefixSpan = new mllibPrefixSpan() - .setMinSupport(minSupport) - .setMaxPatternLength(maxPatternLength) - .setMaxLocalProjDBSize(maxLocalProjDBSize) + .setMinSupport($(minSupport)) + .setMaxPatternLength($(maxPatternLength)) + .setMaxLocalProjDBSize($(maxLocalProjDBSize)) val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq)) val schema = StructType(Seq( - StructField("sequence", dataset.schema(sequenceCol).dataType, nullable = false), + StructField("sequence", dataset.schema(sequenceColParam).dataType, nullable = false), StructField("freq", LongType, nullable = false))) val freqSequences = dataset.sparkSession.createDataFrame(rows, schema) freqSequences } + @Since("2.4.0") + override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra) + } diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala index 9e538696cbcf7..2252151af306b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala @@ -29,8 +29,11 @@ class PrefixSpanSuite extends MLTest { test("PrefixSpan projections with multiple partial starts") { val smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence") - val result = PrefixSpan.findFrequentSequentialPatterns(smallDataset, "sequence", - minSupport = 1.0, maxPatternLength = 2, maxLocalProjDBSize = 32000000) + val result = new PrefixSpan() + .setMinSupport(1.0) + .setMaxPatternLength(2) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(smallDataset) .as[(Seq[Seq[Int]], Long)].collect() val expected = Array( (Seq(Seq(1)), 1L), @@ -90,8 +93,11 @@ class PrefixSpanSuite extends MLTest { test("PrefixSpan Integer type, variable-size itemsets") { val df = smallTestData.toDF("sequence") - val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", - minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + val result = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(df) .as[(Seq[Seq[Int]], Long)].collect() compareResults[Int](smallTestDataExpectedResult, result) @@ -99,8 +105,11 @@ class PrefixSpanSuite extends MLTest { test("PrefixSpan input row with nulls") { val df = (smallTestData :+ null).toDF("sequence") - val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", - minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + val result = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(df) .as[(Seq[Seq[Int]], Long)].collect() compareResults[Int](smallTestDataExpectedResult, result) @@ -111,8 +120,11 @@ class PrefixSpanSuite extends MLTest { val df = smallTestData .map(seq => seq.map(itemSet => itemSet.map(intToString))) .toDF("sequence") - val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", - minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + val result = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(df) .as[(Seq[Seq[String]], Long)].collect() val expected = smallTestDataExpectedResult.map { case (seq, freq) => From 5a5a868dc410ad8c97851d7f3f0ea1c9fc1db90c Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Wed, 23 May 2018 11:51:13 -0700 Subject: [PATCH 0858/2461] Revert "[SPARK-24244][SQL] Passing only required columns to the CSV parser" This reverts commit 8086acc2f676a04ce6255a621ffae871bd09ceea. --- docs/sql-programming-guide.md | 1 - .../apache/spark/sql/internal/SQLConf.scala | 7 --- .../datasources/csv/CSVOptions.scala | 3 -- .../datasources/csv/UnivocityParser.scala | 26 +++++------ .../datasources/csv/CSVBenchmarks.scala | 42 ------------------ .../execution/datasources/csv/CSVSuite.scala | 43 +++---------------- 6 files changed, 18 insertions(+), 104 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index fc26562ff33da..f1ed316341b95 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1825,7 +1825,6 @@ working with timestamps in `pandas_udf`s to get the best performance, see - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. - - In version 2.3 and earlier, CSV rows are considered as malformed if at least one column value in the row is malformed. CSV parser dropped such rows in the DROPMALFORMED mode or outputs an error in the FAILFAST mode. Since Spark 2.4, CSV row is considered as malformed only when it contains malformed column values requested from CSV datasource, other values can be ignored. As an example, CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selection of the id column consists of a row with one column value 1234 but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index fb98df587129b..15ba10f604510 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1307,13 +1307,6 @@ object SQLConf { object Replaced { val MAPREDUCE_JOB_REDUCES = "mapreduce.job.reduces" } - - val CSV_PARSER_COLUMN_PRUNING = buildConf("spark.sql.csv.parser.columnPruning.enabled") - .internal() - .doc("If it is set to true, column names of the requested schema are passed to CSV parser. " + - "Other column values can be ignored during parsing even if they are malformed.") - .booleanConf - .createWithDefault(true) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index dd41aee0f2ebc..1066d156acd74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -25,7 +25,6 @@ import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.internal.SQLConf class CSVOptions( @transient val parameters: CaseInsensitiveMap[String], @@ -81,8 +80,6 @@ class CSVOptions( } } - private[csv] val columnPruning = SQLConf.get.getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING) - val delimiter = CSVUtils.toChar( parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) val parseMode: ParseMode = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 4f00cc5eb3f39..99557a1ceb0c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -34,10 +34,10 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class UnivocityParser( - dataSchema: StructType, + schema: StructType, requiredSchema: StructType, val options: CSVOptions) extends Logging { - require(requiredSchema.toSet.subsetOf(dataSchema.toSet), + require(requiredSchema.toSet.subsetOf(schema.toSet), "requiredSchema should be the subset of schema.") def this(schema: StructType, options: CSVOptions) = this(schema, schema, options) @@ -45,17 +45,9 @@ class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any - private val tokenizer = { - val parserSetting = options.asParserSettings - if (options.columnPruning && requiredSchema.length < dataSchema.length) { - val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))) - parserSetting.selectIndexes(tokenIndexArr: _*) - } - new CsvParser(parserSetting) - } - private val schema = if (options.columnPruning) requiredSchema else dataSchema + private val tokenizer = new CsvParser(options.asParserSettings) - private val row = new GenericInternalRow(schema.length) + private val row = new GenericInternalRow(requiredSchema.length) // Retrieve the raw record string. private def getCurrentInput: UTF8String = { @@ -81,8 +73,11 @@ class UnivocityParser( // Each input token is placed in each output row's position by mapping these. In this case, // // output row - ["A", 2] - private val valueConverters: Array[ValueConverter] = { + private val valueConverters: Array[ValueConverter] = schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray + + private val tokenIndexArr: Array[Int] = { + requiredSchema.map(f => schema.indexOf(f)).toArray } /** @@ -215,8 +210,9 @@ class UnivocityParser( } else { try { var i = 0 - while (i < schema.length) { - row(i) = valueConverters(i).apply(tokens(i)) + while (i < requiredSchema.length) { + val from = tokenIndexArr(i) + row(i) = valueConverters(from).apply(tokens(from)) i += 1 } row diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala index ec788df00aa92..d442ba7e59c61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala @@ -74,49 +74,7 @@ object CSVBenchmarks { } } - def multiColumnsBenchmark(rowsNum: Int): Unit = { - val colsNum = 1000 - val benchmark = new Benchmark(s"Wide rows with $colsNum columns", rowsNum) - - withTempPath { path => - val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) - val schema = StructType(fields) - val values = (0 until colsNum).map(i => i.toString).mkString(",") - val columnNames = schema.fieldNames - - spark.range(rowsNum) - .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*) - .write.option("header", true) - .csv(path.getAbsolutePath) - - val ds = spark.read.schema(schema).csv(path.getAbsolutePath) - - benchmark.addCase(s"Select $colsNum columns", 3) { _ => - ds.select("*").filter((row: Row) => true).count() - } - val cols100 = columnNames.take(100).map(Column(_)) - benchmark.addCase(s"Select 100 columns", 3) { _ => - ds.select(cols100: _*).filter((row: Row) => true).count() - } - benchmark.addCase(s"Select one column", 3) { _ => - ds.select($"col1").filter((row: Row) => true).count() - } - - /* - Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz - - Wide rows with 1000 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - Select 1000 columns 76910 / 78065 0.0 76909.8 1.0X - Select 100 columns 28625 / 32884 0.0 28625.1 2.7X - Select one column 22498 / 22669 0.0 22497.8 3.4X - */ - benchmark.run() - } - } - def main(args: Array[String]): Unit = { quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3) - multiColumnsBenchmark(rowsNum = 1000 * 1000) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 5f9f799a6c466..07e6c74b14d0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -260,16 +260,14 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } test("test for DROPMALFORMED parsing mode") { - withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "false") { - Seq(false, true).foreach { multiLine => - val cars = spark.read - .format("csv") - .option("multiLine", multiLine) - .options(Map("header" -> "true", "mode" -> "dropmalformed")) - .load(testFile(carsFile)) + Seq(false, true).foreach { multiLine => + val cars = spark.read + .format("csv") + .option("multiLine", multiLine) + .options(Map("header" -> "true", "mode" -> "dropmalformed")) + .load(testFile(carsFile)) - assert(cars.select("year").collect().size === 2) - } + assert(cars.select("year").collect().size === 2) } } @@ -1370,31 +1368,4 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te checkAnswer(computed, expected) } } - - test("SPARK-24244: Select a subset of all columns") { - withTempPath { path => - import collection.JavaConverters._ - val schema = new StructType() - .add("f1", IntegerType).add("f2", IntegerType).add("f3", IntegerType) - .add("f4", IntegerType).add("f5", IntegerType).add("f6", IntegerType) - .add("f7", IntegerType).add("f8", IntegerType).add("f9", IntegerType) - .add("f10", IntegerType).add("f11", IntegerType).add("f12", IntegerType) - .add("f13", IntegerType).add("f14", IntegerType).add("f15", IntegerType) - - val odf = spark.createDataFrame(List( - Row(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), - Row(-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15) - ).asJava, schema) - odf.write.csv(path.getCanonicalPath) - val idf = spark.read - .schema(schema) - .csv(path.getCanonicalPath) - .select('f15, 'f10, 'f5) - - checkAnswer( - idf, - List(Row(15, 10, 5), Row(-15, -10, -5)) - ) - } - } } From 84557bc9f87885577f22d7e7f2220c0b988014cd Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 23 May 2018 13:02:32 -0700 Subject: [PATCH 0859/2461] [SPARK-24206][SQL] Improve DataSource read benchmark code ## What changes were proposed in this pull request? This pr added benchmark code `DataSourceReadBenchmark` for `orc`, `paruqet`, `csv`, and `json` based on the existing `ParquetReadBenchmark` and `OrcReadBenchmark`. ## How was this patch tested? N/A Author: Takeshi Yamamuro Closes #21266 from maropu/DataSourceReadBenchmark. --- .../benchmark/DataSourceReadBenchmark.scala | 828 ++++++++++++++++++ .../parquet/ParquetReadBenchmark.scala | 339 ------- 2 files changed, 828 insertions(+), 339 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala new file mode 100644 index 0000000000000..fc6d8abc03c09 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala @@ -0,0 +1,828 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.benchmark + +import java.io.File + +import scala.collection.JavaConverters._ +import scala.util.{Random, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.parquet.{SpecificParquetRecordReaderBase, VectorizedParquetRecordReader} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnVector +import org.apache.spark.util.{Benchmark, Utils} + + +/** + * Benchmark to measure data source read performance. + * To run this: + * spark-submit --class + */ +object DataSourceReadBenchmark { + val conf = new SparkConf() + .setAppName("DataSourceReadBenchmark") + .setIfMissing("spark.master", "local[1]") + .setIfMissing("spark.driver.memory", "3g") + .setIfMissing("spark.executor.memory", "3g") + + val spark = SparkSession.builder.config(conf).getOrCreate() + + // Set default configs. Individual cases will change them if necessary. + spark.conf.set(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key, "true") + spark.conf.set(SQLConf.ORC_COPY_BATCH_TO_SPARK.key, "false") + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + private def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = { + val testDf = if (partition.isDefined) { + df.write.partitionBy(partition.get) + } else { + df.write + } + + saveAsCsvTable(testDf, dir.getCanonicalPath + "/csv") + saveAsJsonTable(testDf, dir.getCanonicalPath + "/json") + saveAsParquetTable(testDf, dir.getCanonicalPath + "/parquet") + saveAsOrcTable(testDf, dir.getCanonicalPath + "/orc") + } + + private def saveAsCsvTable(df: DataFrameWriter[Row], dir: String): Unit = { + df.mode("overwrite").option("compression", "gzip").option("header", true).csv(dir) + spark.read.option("header", true).csv(dir).createOrReplaceTempView("csvTable") + } + + private def saveAsJsonTable(df: DataFrameWriter[Row], dir: String): Unit = { + df.mode("overwrite").option("compression", "gzip").json(dir) + spark.read.json(dir).createOrReplaceTempView("jsonTable") + } + + private def saveAsParquetTable(df: DataFrameWriter[Row], dir: String): Unit = { + df.mode("overwrite").option("compression", "snappy").parquet(dir) + spark.read.parquet(dir).createOrReplaceTempView("parquetTable") + } + + private def saveAsOrcTable(df: DataFrameWriter[Row], dir: String): Unit = { + df.mode("overwrite").option("compression", "snappy").orc(dir) + spark.read.orc(dir).createOrReplaceTempView("orcTable") + } + + def numericScanBenchmark(values: Int, dataType: DataType): Unit = { + // Benchmarks running through spark sql. + val sqlBenchmark = new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values) + + // Benchmarks driving reader component directly. + val parquetReaderBenchmark = new Benchmark( + s"Parquet Reader Single ${dataType.sql} Column Scan", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM t1")) + + sqlBenchmark.addCase("SQL CSV") { _ => + spark.sql("select sum(id) from csvTable").collect() + } + + sqlBenchmark.addCase("SQL Json") { _ => + spark.sql("select sum(id) from jsonTable").collect() + } + + sqlBenchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql("select sum(id) from parquetTable").collect() + } + + sqlBenchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(id) from parquetTable").collect() + } + } + + sqlBenchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + + sqlBenchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + } + + sqlBenchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + } + + + /* + Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 15231 / 15267 1.0 968.3 1.0X + SQL Json 8476 / 8498 1.9 538.9 1.8X + SQL Parquet Vectorized 121 / 127 130.0 7.7 125.9X + SQL Parquet MR 1515 / 1543 10.4 96.3 10.1X + SQL ORC Vectorized 164 / 171 95.9 10.4 92.9X + SQL ORC Vectorized with copy 228 / 234 69.0 14.5 66.8X + SQL ORC MR 1297 / 1309 12.1 82.5 11.7X + + + SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 16344 / 16374 1.0 1039.1 1.0X + SQL Json 8634 / 8648 1.8 548.9 1.9X + SQL Parquet Vectorized 172 / 177 91.5 10.9 95.1X + SQL Parquet MR 1744 / 1746 9.0 110.9 9.4X + SQL ORC Vectorized 189 / 194 83.1 12.0 86.4X + SQL ORC Vectorized with copy 244 / 250 64.5 15.5 67.0X + SQL ORC MR 1341 / 1386 11.7 85.3 12.2X + + + SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 17874 / 17875 0.9 1136.4 1.0X + SQL Json 9190 / 9204 1.7 584.3 1.9X + SQL Parquet Vectorized 141 / 160 111.2 9.0 126.4X + SQL Parquet MR 1930 / 2049 8.2 122.7 9.3X + SQL ORC Vectorized 259 / 264 60.7 16.5 69.0X + SQL ORC Vectorized with copy 265 / 272 59.4 16.8 67.5X + SQL ORC MR 1528 / 1569 10.3 97.2 11.7X + + + SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 22812 / 22839 0.7 1450.4 1.0X + SQL Json 12026 / 12054 1.3 764.6 1.9X + SQL Parquet Vectorized 222 / 227 70.8 14.1 102.6X + SQL Parquet MR 2199 / 2204 7.2 139.8 10.4X + SQL ORC Vectorized 331 / 335 47.6 21.0 69.0X + SQL ORC Vectorized with copy 338 / 343 46.6 21.5 67.6X + SQL ORC MR 1618 / 1622 9.7 102.9 14.1X + + + SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 18703 / 18740 0.8 1189.1 1.0X + SQL Json 11779 / 11869 1.3 748.9 1.6X + SQL Parquet Vectorized 143 / 145 110.1 9.1 130.9X + SQL Parquet MR 1954 / 1963 8.0 124.2 9.6X + SQL ORC Vectorized 347 / 355 45.3 22.1 53.8X + SQL ORC Vectorized with copy 356 / 359 44.1 22.7 52.5X + SQL ORC MR 1570 / 1598 10.0 99.8 11.9X + + + SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 23832 / 23838 0.7 1515.2 1.0X + SQL Json 16204 / 16226 1.0 1030.2 1.5X + SQL Parquet Vectorized 242 / 306 65.1 15.4 98.6X + SQL Parquet MR 2462 / 2482 6.4 156.5 9.7X + SQL ORC Vectorized 419 / 451 37.6 26.6 56.9X + SQL ORC Vectorized with copy 426 / 447 36.9 27.1 55.9X + SQL ORC MR 1885 / 1931 8.3 119.8 12.6X + */ + sqlBenchmark.run() + + // Driving the parquet reader in batch mode directly. + val files = SpecificParquetRecordReaderBase.listDirectory(new File(dir, "parquet")).toArray + val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled + val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize + parquetReaderBenchmark.addCase("ParquetReader Vectorized") { _ => + var longSum = 0L + var doubleSum = 0.0 + val aggregateValue: (ColumnVector, Int) => Unit = dataType match { + case ByteType => (col: ColumnVector, i: Int) => longSum += col.getByte(i) + case ShortType => (col: ColumnVector, i: Int) => longSum += col.getShort(i) + case IntegerType => (col: ColumnVector, i: Int) => longSum += col.getInt(i) + case LongType => (col: ColumnVector, i: Int) => longSum += col.getLong(i) + case FloatType => (col: ColumnVector, i: Int) => doubleSum += col.getFloat(i) + case DoubleType => (col: ColumnVector, i: Int) => doubleSum += col.getDouble(i) + } + + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) + try { + reader.initialize(p, ("id" :: Nil).asJava) + val batch = reader.resultBatch() + val col = batch.column(0) + while (reader.nextBatch()) { + val numRows = batch.numRows() + var i = 0 + while (i < numRows) { + if (!col.isNullAt(i)) aggregateValue(col, i) + i += 1 + } + } + } finally { + reader.close() + } + } + } + + // Decoding in vectorized but having the reader return rows. + parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num => + var longSum = 0L + var doubleSum = 0.0 + val aggregateValue: (InternalRow) => Unit = dataType match { + case ByteType => (col: InternalRow) => longSum += col.getByte(0) + case ShortType => (col: InternalRow) => longSum += col.getShort(0) + case IntegerType => (col: InternalRow) => longSum += col.getInt(0) + case LongType => (col: InternalRow) => longSum += col.getLong(0) + case FloatType => (col: InternalRow) => doubleSum += col.getFloat(0) + case DoubleType => (col: InternalRow) => doubleSum += col.getDouble(0) + } + + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) + try { + reader.initialize(p, ("id" :: Nil).asJava) + val batch = reader.resultBatch() + while (reader.nextBatch()) { + val it = batch.rowIterator() + while (it.hasNext) { + val record = it.next() + if (!record.isNullAt(0)) aggregateValue(record) + } + } + } finally { + reader.close() + } + } + } + + /* + Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 187 / 201 84.2 11.9 1.0X + ParquetReader Vectorized -> Row 101 / 103 156.4 6.4 1.9X + + + Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 272 / 288 57.8 17.3 1.0X + ParquetReader Vectorized -> Row 213 / 219 73.7 13.6 1.3X + + + Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 252 / 288 62.5 16.0 1.0X + ParquetReader Vectorized -> Row 232 / 246 67.7 14.8 1.1X + + + Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 415 / 454 37.9 26.4 1.0X + ParquetReader Vectorized -> Row 407 / 432 38.6 25.9 1.0X + + + Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 251 / 302 62.7 16.0 1.0X + ParquetReader Vectorized -> Row 220 / 234 71.5 14.0 1.1X + + + Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 432 / 436 36.4 27.5 1.0X + ParquetReader Vectorized -> Row 414 / 422 38.0 26.4 1.0X + */ + parquetReaderBenchmark.run() + } + } + } + + def intStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Int and String Scan", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql("SELECT CAST(value AS INT) AS c1, CAST(value as STRING) AS c2 FROM t1")) + + benchmark.addCase("SQL CSV") { _ => + spark.sql("select sum(c1), sum(length(c2)) from csvTable").collect() + } + + benchmark.addCase("SQL Json") { _ => + spark.sql("select sum(c1), sum(length(c2)) from jsonTable").collect() + } + + benchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql("select sum(c1), sum(length(c2)) from parquetTable").collect() + } + + benchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(c1), sum(length(c2)) from parquetTable").collect() + } + } + + benchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql("SELECT sum(c1), sum(length(c2)) FROM orcTable").collect() + } + + benchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(c1), sum(length(c2)) FROM orcTable").collect() + } + } + + benchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(c1), sum(length(c2)) FROM orcTable").collect() + } + } + + /* + Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 19172 / 19173 0.5 1828.4 1.0X + SQL Json 12799 / 12873 0.8 1220.6 1.5X + SQL Parquet Vectorized 2558 / 2564 4.1 244.0 7.5X + SQL Parquet MR 4514 / 4583 2.3 430.4 4.2X + SQL ORC Vectorized 2561 / 2697 4.1 244.3 7.5X + SQL ORC Vectorized with copy 3076 / 3110 3.4 293.4 6.2X + SQL ORC MR 4197 / 4283 2.5 400.2 4.6X + */ + benchmark.run() + } + } + } + + def repeatedStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Repeated String", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql("select cast((value % 200) + 10000 as STRING) as c1 from t1")) + + benchmark.addCase("SQL CSV") { _ => + spark.sql("select sum(length(c1)) from csvTable").collect() + } + + benchmark.addCase("SQL Json") { _ => + spark.sql("select sum(length(c1)) from jsonTable").collect() + } + + benchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql("select sum(length(c1)) from parquetTable").collect() + } + + benchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(length(c1)) from parquetTable").collect() + } + } + + benchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql("select sum(length(c1)) from orcTable").collect() + } + + benchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("select sum(length(c1)) from orcTable").collect() + } + } + + benchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(length(c1)) from orcTable").collect() + } + } + + /* + Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 10889 / 10924 1.0 1038.5 1.0X + SQL Json 7903 / 7931 1.3 753.7 1.4X + SQL Parquet Vectorized 777 / 799 13.5 74.1 14.0X + SQL Parquet MR 1682 / 1708 6.2 160.4 6.5X + SQL ORC Vectorized 532 / 534 19.7 50.7 20.5X + SQL ORC Vectorized with copy 742 / 743 14.1 70.7 14.7X + SQL ORC MR 1996 / 2002 5.3 190.4 5.5X + */ + benchmark.run() + } + } + } + + def partitionTableScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Partitioned Table", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT value % 2 AS p, value AS id FROM t1"), Some("p")) + + benchmark.addCase("Data column - CSV") { _ => + spark.sql("select sum(id) from csvTable").collect() + } + + benchmark.addCase("Data column - Json") { _ => + spark.sql("select sum(id) from jsonTable").collect() + } + + benchmark.addCase("Data column - Parquet Vectorized") { _ => + spark.sql("select sum(id) from parquetTable").collect() + } + + benchmark.addCase("Data column - Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(id) from parquetTable").collect() + } + } + + benchmark.addCase("Data column - ORC Vectorized") { _ => + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + + benchmark.addCase("Data column - ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + } + + benchmark.addCase("Data column - ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + } + + benchmark.addCase("Partition column - CSV") { _ => + spark.sql("select sum(p) from csvTable").collect() + } + + benchmark.addCase("Partition column - Json") { _ => + spark.sql("select sum(p) from jsonTable").collect() + } + + benchmark.addCase("Partition column - Parquet Vectorized") { _ => + spark.sql("select sum(p) from parquetTable").collect() + } + + benchmark.addCase("Partition column - Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(p) from parquetTable").collect() + } + } + + benchmark.addCase("Partition column - ORC Vectorized") { _ => + spark.sql("SELECT sum(p) FROM orcTable").collect() + } + + benchmark.addCase("Partition column - ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(p) FROM orcTable").collect() + } + } + + benchmark.addCase("Partition column - ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(p) FROM orcTable").collect() + } + } + + benchmark.addCase("Both columns - CSV") { _ => + spark.sql("select sum(p), sum(id) from csvTable").collect() + } + + benchmark.addCase("Both columns - Json") { _ => + spark.sql("select sum(p), sum(id) from jsonTable").collect() + } + + benchmark.addCase("Both columns - Parquet Vectorized") { _ => + spark.sql("select sum(p), sum(id) from parquetTable").collect() + } + + benchmark.addCase("Both columns - Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(p), sum(id) from parquetTable").collect + } + } + + benchmark.addCase("Both columns - ORC Vectorized") { _ => + spark.sql("SELECT sum(p), sum(id) FROM orcTable").collect() + } + + benchmark.addCase("Both column - ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(p), sum(id) FROM orcTable").collect() + } + } + + benchmark.addCase("Both columns - ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(p), sum(id) FROM orcTable").collect() + } + } + + /* + Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + Data column - CSV 25428 / 25454 0.6 1616.7 1.0X + Data column - Json 12689 / 12774 1.2 806.7 2.0X + Data column - Parquet Vectorized 222 / 231 70.7 14.1 114.3X + Data column - Parquet MR 3355 / 3397 4.7 213.3 7.6X + Data column - ORC Vectorized 332 / 338 47.4 21.1 76.6X + Data column - ORC Vectorized with copy 338 / 341 46.5 21.5 75.2X + Data column - ORC MR 2329 / 2356 6.8 148.0 10.9X + Partition column - CSV 17465 / 17502 0.9 1110.4 1.5X + Partition column - Json 10865 / 10876 1.4 690.8 2.3X + Partition column - Parquet Vectorized 48 / 52 325.4 3.1 526.1X + Partition column - Parquet MR 1695 / 1696 9.3 107.8 15.0X + Partition column - ORC Vectorized 49 / 54 319.9 3.1 517.2X + Partition column - ORC Vectorized with copy 49 / 52 324.1 3.1 524.0X + Partition column - ORC MR 1548 / 1549 10.2 98.4 16.4X + Both columns - CSV 25568 / 25595 0.6 1625.6 1.0X + Both columns - Json 13658 / 13673 1.2 868.4 1.9X + Both columns - Parquet Vectorized 270 / 296 58.3 17.1 94.3X + Both columns - Parquet MR 3501 / 3521 4.5 222.6 7.3X + Both columns - ORC Vectorized 377 / 380 41.7 24.0 67.4X + Both column - ORC Vectorized with copy 447 / 448 35.2 28.4 56.9X + Both columns - ORC MR 2440 / 2446 6.4 155.2 10.4X + */ + benchmark.run() + } + } + } + + def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { + val benchmark = new Benchmark("String with Nulls Scan", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + spark.range(values).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql( + s"SELECT IF(RAND(1) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c1, " + + s"IF(RAND(2) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c2 FROM t1")) + + benchmark.addCase("SQL CSV") { _ => + spark.sql("select sum(length(c2)) from csvTable where c1 is " + + "not NULL and c2 is not NULL").collect() + } + + benchmark.addCase("SQL Json") { _ => + spark.sql("select sum(length(c2)) from jsonTable where c1 is " + + "not NULL and c2 is not NULL").collect() + } + + benchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql("select sum(length(c2)) from parquetTable where c1 is " + + "not NULL and c2 is not NULL").collect() + } + + benchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(length(c2)) from parquetTable where c1 is " + + "not NULL and c2 is not NULL").collect() + } + } + + val files = SpecificParquetRecordReaderBase.listDirectory(new File(dir, "parquet")).toArray + val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled + val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize + benchmark.addCase("ParquetReader Vectorized") { num => + var sum = 0 + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) + try { + reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) + val batch = reader.resultBatch() + while (reader.nextBatch()) { + val rowIterator = batch.rowIterator() + while (rowIterator.hasNext) { + val row = rowIterator.next() + val value = row.getUTF8String(0) + if (!row.isNullAt(0) && !row.isNullAt(1)) sum += value.numBytes() + } + } + } finally { + reader.close() + } + } + } + + benchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql("SELECT SUM(LENGTH(c2)) FROM orcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + + benchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT SUM(LENGTH(c2)) FROM orcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + } + + benchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT SUM(LENGTH(c2)) FROM orcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + } + + /* + Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 13518 / 13529 0.8 1289.2 1.0X + SQL Json 10895 / 10926 1.0 1039.0 1.2X + SQL Parquet Vectorized 1539 / 1581 6.8 146.8 8.8X + SQL Parquet MR 3746 / 3811 2.8 357.3 3.6X + ParquetReader Vectorized 1070 / 1112 9.8 102.0 12.6X + SQL ORC Vectorized 1389 / 1408 7.6 132.4 9.7X + SQL ORC Vectorized with copy 1736 / 1750 6.0 165.6 7.8X + SQL ORC MR 3799 / 3892 2.8 362.3 3.6X + + + String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 10854 / 10892 1.0 1035.2 1.0X + SQL Json 8129 / 8138 1.3 775.3 1.3X + SQL Parquet Vectorized 1053 / 1104 10.0 100.4 10.3X + SQL Parquet MR 2840 / 2854 3.7 270.8 3.8X + ParquetReader Vectorized 978 / 1008 10.7 93.2 11.1X + SQL ORC Vectorized 1312 / 1387 8.0 125.1 8.3X + SQL ORC Vectorized with copy 1764 / 1772 5.9 168.2 6.2X + SQL ORC MR 3435 / 3445 3.1 327.6 3.2X + + + String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 8043 / 8048 1.3 767.1 1.0X + SQL Json 4911 / 4923 2.1 468.4 1.6X + SQL Parquet Vectorized 206 / 209 51.0 19.6 39.1X + SQL Parquet MR 1528 / 1537 6.9 145.8 5.3X + ParquetReader Vectorized 216 / 219 48.6 20.6 37.2X + SQL ORC Vectorized 462 / 466 22.7 44.1 17.4X + SQL ORC Vectorized with copy 568 / 572 18.5 54.2 14.2X + SQL ORC MR 1647 / 1649 6.4 157.1 4.9X + */ + benchmark.run() + } + } + } + + def columnsBenchmark(values: Int, width: Int): Unit = { + val benchmark = new Benchmark(s"Single Column Scan from $width columns", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + val middle = width / 2 + val selectExpr = (1 to width).map(i => s"value as c$i") + spark.range(values).map(_ => Random.nextLong).toDF() + .selectExpr(selectExpr: _*).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT * FROM t1")) + + benchmark.addCase("SQL CSV") { _ => + spark.sql(s"SELECT sum(c$middle) FROM csvTable").collect() + } + + benchmark.addCase("SQL Json") { _ => + spark.sql(s"SELECT sum(c$middle) FROM jsonTable").collect() + } + + benchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql(s"SELECT sum(c$middle) FROM parquetTable").collect() + } + + benchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"SELECT sum(c$middle) FROM parquetTable").collect() + } + } + + benchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql(s"SELECT sum(c$middle) FROM orcTable").collect() + } + + benchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql(s"SELECT sum(c$middle) FROM orcTable").collect() + } + } + + benchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"SELECT sum(c$middle) FROM orcTable").collect() + } + } + + /* + Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + Single Column Scan from 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 3663 / 3665 0.3 3493.2 1.0X + SQL Json 3122 / 3160 0.3 2977.5 1.2X + SQL Parquet Vectorized 40 / 42 26.2 38.2 91.5X + SQL Parquet MR 189 / 192 5.5 180.2 19.4X + SQL ORC Vectorized 48 / 51 21.6 46.2 75.6X + SQL ORC Vectorized with copy 49 / 52 21.4 46.7 74.9X + SQL ORC MR 280 / 289 3.7 267.1 13.1X + + + Single Column Scan from 50 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 11420 / 11505 0.1 10891.1 1.0X + SQL Json 11905 / 12120 0.1 11353.6 1.0X + SQL Parquet Vectorized 50 / 54 20.9 47.8 227.7X + SQL Parquet MR 195 / 199 5.4 185.8 58.6X + SQL ORC Vectorized 61 / 65 17.3 57.8 188.3X + SQL ORC Vectorized with copy 62 / 65 17.0 58.8 185.2X + SQL ORC MR 847 / 865 1.2 807.4 13.5X + + + Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 21278 / 21404 0.0 20292.4 1.0X + SQL Json 22455 / 22625 0.0 21414.7 0.9X + SQL Parquet Vectorized 73 / 75 14.4 69.3 292.8X + SQL Parquet MR 220 / 226 4.8 209.7 96.8X + SQL ORC Vectorized 82 / 86 12.8 78.2 259.4X + SQL ORC Vectorized with copy 82 / 90 12.7 78.7 258.0X + SQL ORC MR 1568 / 1582 0.7 1495.4 13.6X + */ + benchmark.run() + } + } + } + + def main(args: Array[String]): Unit = { + Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { dataType => + numericScanBenchmark(1024 * 1024 * 15, dataType) + } + intStringScanBenchmark(1024 * 1024 * 10) + repeatedStringScanBenchmark(1024 * 1024 * 10) + partitionTableScanBenchmark(1024 * 1024 * 15) + for (fractionOfNulls <- List(0.0, 0.50, 0.95)) { + stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls) + } + for (columnWidth <- List(10, 50, 100)) { + columnsBenchmark(1024 * 1024 * 1, columnWidth) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala deleted file mode 100644 index e43336d947364..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ /dev/null @@ -1,339 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.parquet - -import java.io.File - -import scala.collection.JavaConverters._ -import scala.util.Try - -import org.apache.spark.SparkConf -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.{Benchmark, Utils} - -/** - * Benchmark to measure parquet read performance. - * To run this: - * spark-submit --class --jars - */ -object ParquetReadBenchmark { - val conf = new SparkConf() - conf.set("spark.sql.parquet.compression.codec", "snappy") - - val spark = SparkSession.builder - .master("local[1]") - .appName("test-sql-context") - .config(conf) - .getOrCreate() - - // Set default configs. Individual cases will change them if necessary. - spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") - spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") - - def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - - def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(spark.catalog.dropTempView) - } - - def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) - (keys, values).zipped.foreach(spark.conf.set) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => spark.conf.set(key, value) - case (key, None) => spark.conf.unset(key) - } - } - } - - def intScanBenchmark(values: Int): Unit = { - // Benchmarks running through spark sql. - val sqlBenchmark = new Benchmark("SQL Single Int Column Scan", values) - // Benchmarks driving reader component directly. - val parquetReaderBenchmark = new Benchmark("Parquet Reader Single Int Column Scan", values) - - withTempPath { dir => - withTempTable("t1", "tempTable") { - val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled - val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize - spark.range(values).createOrReplaceTempView("t1") - spark.sql("select cast(id as INT) as id from t1") - .write.parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - sqlBenchmark.addCase("SQL Parquet Vectorized") { iter => - spark.sql("select sum(id) from tempTable").collect() - } - - sqlBenchmark.addCase("SQL Parquet MR") { iter => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(id) from tempTable").collect() - } - } - - val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray - // Driving the parquet reader in batch mode directly. - parquetReaderBenchmark.addCase("ParquetReader Vectorized") { num => - var sum = 0L - files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader( - null, enableOffHeapColumnVector, vectorizedReaderBatchSize) - try { - reader.initialize(p, ("id" :: Nil).asJava) - val batch = reader.resultBatch() - val col = batch.column(0) - while (reader.nextBatch()) { - val numRows = batch.numRows() - var i = 0 - while (i < numRows) { - if (!col.isNullAt(i)) sum += col.getInt(i) - i += 1 - } - } - } finally { - reader.close() - } - } - } - - // Decoding in vectorized but having the reader return rows. - parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num => - var sum = 0L - files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader( - null, enableOffHeapColumnVector, vectorizedReaderBatchSize) - try { - reader.initialize(p, ("id" :: Nil).asJava) - val batch = reader.resultBatch() - while (reader.nextBatch()) { - val it = batch.rowIterator() - while (it.hasNext) { - val record = it.next() - if (!record.isNullAt(0)) sum += record.getInt(0) - } - } - } finally { - reader.close() - } - } - } - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - SQL Single Int Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 215 / 262 73.0 13.7 1.0X - SQL Parquet MR 1946 / 2083 8.1 123.7 0.1X - */ - sqlBenchmark.run() - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Parquet Reader Single Int Column Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - ParquetReader Vectorized 123 / 152 127.8 7.8 1.0X - ParquetReader Vectorized -> Row 165 / 180 95.2 10.5 0.7X - */ - parquetReaderBenchmark.run() - } - } - } - - def intStringScanBenchmark(values: Int): Unit = { - withTempPath { dir => - withTempTable("t1", "tempTable") { - spark.range(values).createOrReplaceTempView("t1") - spark.sql("select cast(id as INT) as c1, cast(id as STRING) as c2 from t1") - .write.parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - val benchmark = new Benchmark("Int and String Scan", values) - - benchmark.addCase("SQL Parquet Vectorized") { iter => - spark.sql("select sum(c1), sum(length(c2)) from tempTable").collect - } - - benchmark.addCase("SQL Parquet MR") { iter => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(c1), sum(length(c2)) from tempTable").collect - } - } - - val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 628 / 720 16.7 59.9 1.0X - SQL Parquet MR 1905 / 2239 5.5 181.7 0.3X - */ - benchmark.run() - } - } - } - - def stringDictionaryScanBenchmark(values: Int): Unit = { - withTempPath { dir => - withTempTable("t1", "tempTable") { - spark.range(values).createOrReplaceTempView("t1") - spark.sql("select cast((id % 200) + 10000 as STRING) as c1 from t1") - .write.parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - val benchmark = new Benchmark("String Dictionary", values) - - benchmark.addCase("SQL Parquet Vectorized") { iter => - spark.sql("select sum(length(c1)) from tempTable").collect - } - - benchmark.addCase("SQL Parquet MR") { iter => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(length(c1)) from tempTable").collect - } - } - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - String Dictionary: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 329 / 337 31.9 31.4 1.0X - SQL Parquet MR 1131 / 1325 9.3 107.8 0.3X - */ - benchmark.run() - } - } - } - - def partitionTableScanBenchmark(values: Int): Unit = { - withTempPath { dir => - withTempTable("t1", "tempTable") { - spark.range(values).createOrReplaceTempView("t1") - spark.sql("select id % 2 as p, cast(id as INT) as id from t1") - .write.partitionBy("p").parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - val benchmark = new Benchmark("Partitioned Table", values) - - benchmark.addCase("Read data column") { iter => - spark.sql("select sum(id) from tempTable").collect - } - - benchmark.addCase("Read partition column") { iter => - spark.sql("select sum(p) from tempTable").collect - } - - benchmark.addCase("Read both columns") { iter => - spark.sql("select sum(p), sum(id) from tempTable").collect - } - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Read data column 191 / 250 82.1 12.2 1.0X - Read partition column 82 / 86 192.4 5.2 2.3X - Read both columns 220 / 248 71.5 14.0 0.9X - */ - benchmark.run() - } - } - } - - def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { - withTempPath { dir => - withTempTable("t1", "tempTable") { - val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled - val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize - spark.range(values).createOrReplaceTempView("t1") - spark.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " + - s"IF(rand(2) < $fractionOfNulls, NULL, cast(id as STRING)) as c2 from t1") - .write.parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - val benchmark = new Benchmark("String with Nulls Scan", values) - - benchmark.addCase("SQL Parquet Vectorized") { iter => - spark.sql("select sum(length(c2)) from tempTable where c1 is " + - "not NULL and c2 is not NULL").collect() - } - - val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray - benchmark.addCase("PR Vectorized") { num => - var sum = 0 - files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader( - null, enableOffHeapColumnVector, vectorizedReaderBatchSize) - try { - reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) - val batch = reader.resultBatch() - while (reader.nextBatch()) { - val rowIterator = batch.rowIterator() - while (rowIterator.hasNext) { - val row = rowIterator.next() - val value = row.getUTF8String(0) - if (!row.isNullAt(0) && !row.isNullAt(1)) sum += value.numBytes() - } - } - } finally { - reader.close() - } - } - } - - /* - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - String with Nulls Scan (0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 1229 / 1648 8.5 117.2 1.0X - PR Vectorized 833 / 846 12.6 79.4 1.5X - - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - String with Nulls Scan (50%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 995 / 1053 10.5 94.9 1.0X - PR Vectorized 732 / 772 14.3 69.8 1.4X - - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - String with Nulls Scan (95%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 326 / 333 32.2 31.1 1.0X - PR Vectorized 190 / 200 55.1 18.2 1.7X - */ - - benchmark.run() - } - } - } - - def main(args: Array[String]): Unit = { - intScanBenchmark(1024 * 1024 * 15) - intStringScanBenchmark(1024 * 1024 * 10) - stringDictionaryScanBenchmark(1024 * 1024 * 10) - partitionTableScanBenchmark(1024 * 1024 * 15) - for (fractionOfNulls <- List(0.0, 0.50, 0.95)) { - stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls) - } - } -} From b7a036b75b8a1d287ac014b85e90d555753064c9 Mon Sep 17 00:00:00 2001 From: jinxing Date: Wed, 23 May 2018 13:12:05 -0700 Subject: [PATCH 0860/2461] [SPARK-24294] Throw SparkException when OOM in BroadcastExchangeExec MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? When OutOfMemoryError thrown from BroadcastExchangeExec, scala.concurrent.Future will hit scala bug – https://github.com/scala/bug/issues/9554, and hang until future timeout: We could wrap the OOM inside SparkException to resolve this issue. ## How was this patch tested? Manually tested. Author: jinxing Closes #21342 from jinxing64/SPARK-24294. --- .../spark/util/SparkFatalException.scala | 27 +++++++++++++++++++ .../util/SparkUncaughtExceptionHandler.scala | 13 ++++++--- .../org/apache/spark/util/ThreadUtils.scala | 2 ++ .../exchange/BroadcastExchangeExec.scala | 13 ++++++--- 4 files changed, 48 insertions(+), 7 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/SparkFatalException.scala diff --git a/core/src/main/scala/org/apache/spark/util/SparkFatalException.scala b/core/src/main/scala/org/apache/spark/util/SparkFatalException.scala new file mode 100644 index 0000000000000..1aa2009fa9b5b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/SparkFatalException.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +/** + * SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we catch + * fatal throwable in {@link scala.concurrent.Future}'s body, and re-throw + * SparkFatalException, which wraps the fatal throwable inside. + * Note that SparkFatalException should only be thrown from a {@link scala.concurrent.Future}, + * which is run by using ThreadUtils.awaitResult. ThreadUtils.awaitResult will catch + * it and re-throw the original exception/error. + */ +private[spark] final class SparkFatalException(val throwable: Throwable) extends Exception diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala index e0f5af5250e7f..1b34fbde38cd6 100644 --- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala +++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala @@ -39,10 +39,15 @@ private[spark] class SparkUncaughtExceptionHandler(val exitOnUncaughtException: // We may have been called from a shutdown hook. If so, we must not call System.exit(). // (If we do, we will deadlock.) if (!ShutdownHookManager.inShutdown()) { - if (exception.isInstanceOf[OutOfMemoryError]) { - System.exit(SparkExitCode.OOM) - } else if (exitOnUncaughtException) { - System.exit(SparkExitCode.UNCAUGHT_EXCEPTION) + exception match { + case _: OutOfMemoryError => + System.exit(SparkExitCode.OOM) + case e: SparkFatalException if e.throwable.isInstanceOf[OutOfMemoryError] => + // SPARK-24294: This is defensive code, in case that SparkFatalException is + // misused and uncaught. + System.exit(SparkExitCode.OOM) + case _ if exitOnUncaughtException => + System.exit(SparkExitCode.UNCAUGHT_EXCEPTION) } } } catch { diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 81aaf79db0c13..165a15c73e7ca 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -200,6 +200,8 @@ private[spark] object ThreadUtils { val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait] awaitable.result(atMost)(awaitPermission) } catch { + case e: SparkFatalException => + throw e.throwable // TimeoutException is thrown in the current thread, so not need to warp the exception. case NonFatal(t) if !t.isInstanceOf[TimeoutException] => throw new SparkException("Exception thrown in awaitResult: ", t) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 9e0ec9481b0de..c55f9b8f1a7fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.exchange import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ +import scala.util.control.NonFatal import org.apache.spark.{broadcast, SparkException} import org.apache.spark.launcher.SparkLauncher @@ -30,7 +31,7 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.joins.HashedRelation import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{SparkFatalException, ThreadUtils} /** * A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of @@ -111,12 +112,18 @@ case class BroadcastExchangeExec( SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) broadcasted } catch { + // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw + // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult + // will catch this exception and re-throw the wrapped fatal throwable. case oe: OutOfMemoryError => - throw new OutOfMemoryError(s"Not enough memory to build and broadcast the table to " + + throw new SparkFatalException( + new OutOfMemoryError(s"Not enough memory to build and broadcast the table to " + s"all worker nodes. As a workaround, you can either disable broadcast by setting " + s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark driver " + s"memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value") - .initCause(oe.getCause) + .initCause(oe.getCause)) + case e if !NonFatal(e) => + throw new SparkFatalException(e) } } }(BroadcastExchangeExec.executionContext) From f4579332931c9bf424d0b6147fad89bd63da26f6 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 23 May 2018 17:21:29 -0700 Subject: [PATCH 0861/2461] [SPARK-23416][SS] Add a specific stop method for ContinuousExecution. ## What changes were proposed in this pull request? Add a specific stop method for ContinuousExecution. The previous StreamExecution.stop() method had a race condition as applied to continuous processing: if the cancellation was round-tripped to the driver too quickly, the generic SparkException it caused would be reported as the query death cause. We earlier decided that SparkException should not be added to the StreamExecution.isInterruptionException() whitelist, so we need to ensure this never happens instead. ## How was this patch tested? Existing tests. I could consistently reproduce the previous flakiness by putting Thread.sleep(1000) between the first job cancellation and thread interruption in StreamExecution.stop(). Author: Jose Torres Closes #21384 from jose-torres/fixKafka. --- .../streaming/MicroBatchExecution.scala | 18 ++++++++++++++++++ .../execution/streaming/StreamExecution.scala | 18 ------------------ .../continuous/ContinuousExecution.scala | 16 ++++++++++++++++ 3 files changed, 34 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 6709e7052f005..7817360810bde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -126,6 +126,24 @@ class MicroBatchExecution( _logicalPlan } + /** + * Signals to the thread executing micro-batches that it should stop running after the next + * batch. This method blocks until the thread stops running. + */ + override def stop(): Unit = { + // Set the state to TERMINATED so that the batching thread knows that it was interrupted + // intentionally + state.set(TERMINATED) + if (queryExecutionThread.isAlive) { + sparkSession.sparkContext.cancelJobGroup(runId.toString) + queryExecutionThread.interrupt() + queryExecutionThread.join() + // microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak + sparkSession.sparkContext.cancelJobGroup(runId.toString) + } + logInfo(s"Query $prettyIdString was stopped") + } + /** * Repeatedly attempts to run batches as data arrives. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 3fc8c7887896a..290de873c5cfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -378,24 +378,6 @@ abstract class StreamExecution( } } - /** - * Signals to the thread executing micro-batches that it should stop running after the next - * batch. This method blocks until the thread stops running. - */ - override def stop(): Unit = { - // Set the state to TERMINATED so that the batching thread knows that it was interrupted - // intentionally - state.set(TERMINATED) - if (queryExecutionThread.isAlive) { - sparkSession.sparkContext.cancelJobGroup(runId.toString) - queryExecutionThread.interrupt() - queryExecutionThread.join() - // microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak - sparkSession.sparkContext.cancelJobGroup(runId.toString) - } - logInfo(s"Query $prettyIdString was stopped") - } - /** * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 0e7d1019b9c8f..d16b24c89ebef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -356,6 +356,22 @@ class ContinuousExecution( } } } + + /** + * Stops the query execution thread to terminate the query. + */ + override def stop(): Unit = { + // Set the state to TERMINATED so that the batching thread knows that it was interrupted + // intentionally + state.set(TERMINATED) + if (queryExecutionThread.isAlive) { + // The query execution thread will clean itself up in the finally clause of runContinuous. + // We just need to interrupt the long running job. + queryExecutionThread.interrupt() + queryExecutionThread.join() + } + logInfo(s"Query $prettyIdString was stopped") + } } object ContinuousExecution { From 230f1441978e14062d3e0b6ba1524acf36e3eafe Mon Sep 17 00:00:00 2001 From: "Vayda, Oleksandr: IT (PRG)" Date: Wed, 23 May 2018 17:22:52 -0700 Subject: [PATCH 0862/2461] [SPARK-24350][SQL] Fixes ClassCastException in the "array_position" function ## What changes were proposed in this pull request? ### Fixes `ClassCastException` in the `array_position` function - [SPARK-24350](https://issues.apache.org/jira/browse/SPARK-24350) When calling `array_position` function with a wrong type of the 1st argument an `AnalysisException` should be thrown instead of `ClassCastException` Example: ```sql select array_position('foo', 'bar') ``` ``` java.lang.ClassCastException: org.apache.spark.sql.types.StringType$ cannot be cast to org.apache.spark.sql.types.ArrayType at org.apache.spark.sql.catalyst.expressions.ArrayPosition.inputTypes(collectionOperations.scala:1398) at org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes$class.checkInputDataTypes(ExpectsInputTypes.scala:44) at org.apache.spark.sql.catalyst.expressions.ArrayPosition.checkInputDataTypes(collectionOperations.scala:1401) at org.apache.spark.sql.catalyst.expressions.Expression.resolved$lzycompute(Expression.scala:168) at org.apache.spark.sql.catalyst.expressions.Expression.resolved(Expression.scala:168) at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAliases$$anonfun$org$apache$spark$sql$catalyst$analysis$Analyzer$ResolveAliases$$assignAliases$1$$anonfun$apply$3.applyOrElse(Analyzer.scala:256) at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAliases$$anonfun$org$apache$spark$sql$catalyst$analysis$Analyzer$ResolveAliases$$assignAliases$1$$anonfun$apply$3.applyOrElse(Analyzer.scala:252) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:289) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:289) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:288) ``` ## How was this patch tested? unit test Author: Vayda, Oleksandr: IT (PRG) Closes #21401 from wajda/SPARK-24350-array_position-error-fix. --- .../catalyst/expressions/collectionOperations.scala | 10 ++++++++-- .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 5 +++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 03b3b21a16617..8a877b02c8191 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1395,8 +1395,14 @@ case class ArrayPosition(left: Expression, right: Expression) TypeUtils.getInterpretedOrdering(right.dataType) override def dataType: DataType = LongType - override def inputTypes: Seq[AbstractDataType] = - Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) + + override def inputTypes: Seq[AbstractDataType] = { + val elementType = left.dataType match { + case t: ArrayType => t.elementType + case _ => AnyDataType + } + Seq(ArrayType, elementType) + } override def checkInputDataTypes(): TypeCheckResult = { super.checkInputDataTypes() match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index ec2a569f900d1..79e743d961af8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -708,6 +708,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_position(array(1, null), array(1, null)[0])"), Seq(Row(1L), Row(1L)) ) + + val e = intercept[AnalysisException] { + Seq(("a string element", "a")).toDF().selectExpr("array_position(_1, _2)") + } + assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type")) } test("element_at function") { From 888340151f737bb68d0e419b1e949f11469881f9 Mon Sep 17 00:00:00 2001 From: sychen Date: Thu, 24 May 2018 11:02:09 +0800 Subject: [PATCH 0863/2461] [SPARK-24257][SQL] LongToUnsafeRowMap calculate the new size may be wrong LongToUnsafeRowMap has a mistake when growing its page array: it blindly grows to `oldSize * 2`, while the new record may be larger than `oldSize * 2`. Then we may have a malformed UnsafeRow when querying this map, whose actual data is smaller than its declared size, and the data is corrupted. Author: sychen Closes #21311 from cxzl25/fix_LongToUnsafeRowMap_page_size. --- .../sql/execution/joins/HashedRelation.scala | 38 +++++++++++-------- .../execution/joins/HashedRelationSuite.scala | 26 ++++++++++++- 2 files changed, 48 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 1465346eb802d..20ce01f4ce8cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -557,7 +557,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap def append(key: Long, row: UnsafeRow): Unit = { val sizeInBytes = row.getSizeInBytes if (sizeInBytes >= (1 << SIZE_BITS)) { - sys.error("Does not support row that is larger than 256M") + throw new UnsupportedOperationException("Does not support row that is larger than 256M") } if (key < minKey) { @@ -567,19 +567,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap maxKey = key } - // There is 8 bytes for the pointer to next value - if (cursor + 8 + row.getSizeInBytes > page.length * 8L + Platform.LONG_ARRAY_OFFSET) { - val used = page.length - if (used >= (1 << 30)) { - sys.error("Can not build a HashedRelation that is larger than 8G") - } - ensureAcquireMemory(used * 8L * 2) - val newPage = new Array[Long](used * 2) - Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET, - cursor - Platform.LONG_ARRAY_OFFSET) - page = newPage - freeMemory(used * 8L) - } + grow(row.getSizeInBytes) // copy the bytes of UnsafeRow val offset = cursor @@ -615,7 +603,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap growArray() } else if (numKeys > array.length / 2 * 0.75) { // The fill ratio should be less than 0.75 - sys.error("Cannot build HashedRelation with more than 1/3 billions unique keys") + throw new UnsupportedOperationException( + "Cannot build HashedRelation with more than 1/3 billions unique keys") } } } else { @@ -626,6 +615,25 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap } } + private def grow(inputRowSize: Int): Unit = { + // There is 8 bytes for the pointer to next value + val neededNumWords = (cursor - Platform.LONG_ARRAY_OFFSET + 8 + inputRowSize + 7) / 8 + if (neededNumWords > page.length) { + if (neededNumWords > (1 << 30)) { + throw new UnsupportedOperationException( + "Can not build a HashedRelation that is larger than 8G") + } + val newNumWords = math.max(neededNumWords, math.min(page.length * 2, 1 << 30)) + ensureAcquireMemory(newNumWords * 8L) + val newPage = new Array[Long](newNumWords.toInt) + Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET, + cursor - Platform.LONG_ARRAY_OFFSET) + val used = page.length + page = newPage + freeMemory(used * 8L) + } + } + private def growArray(): Unit = { var old_array = array val n = array.length diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 51f8c3325fdff..037cc2e3ccad7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.CompactBuffer @@ -254,6 +254,30 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { map.free() } + test("SPARK-24257: insert big values into LongToUnsafeRowMap") { + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val unsafeProj = UnsafeProjection.create(Array[DataType](StringType)) + val map = new LongToUnsafeRowMap(taskMemoryManager, 1) + + val key = 0L + // the page array is initialized with length 1 << 17 (1M bytes), + // so here we need a value larger than 1 << 18 (2M bytes), to trigger the bug + val bigStr = UTF8String.fromString("x" * (1 << 19)) + + map.append(key, unsafeProj(InternalRow(bigStr))) + map.optimize() + + val resultRow = new UnsafeRow(1) + assert(map.getValue(key, resultRow).getUTF8String(0) === bigStr) + map.free() + } + test("Spark-14521") { val ser = new KryoSerializer( (new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance() From 486ecc680e9a0e7b6b3c3a45fb883a61072096fc Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 24 May 2018 11:34:13 +0800 Subject: [PATCH 0864/2461] [SPARK-24322][BUILD] Upgrade Apache ORC to 1.4.4 ## What changes were proposed in this pull request? ORC 1.4.4 includes [nine fixes](https://issues.apache.org/jira/issues/?filter=12342568&jql=project%20%3D%20ORC%20AND%20resolution%20%3D%20Fixed%20AND%20fixVersion%20%3D%201.4.4). One of the issues is about `Timestamp` bug (ORC-306) which occurs when `native` ORC vectorized reader reads ORC column vector's sub-vector `times` and `nanos`. ORC-306 fixes this according to the [original definition](https://github.com/apache/hive/blob/master/storage-api/src/java/org/apache/hadoop/hive/ql/exec/vector/TimestampColumnVector.java#L45-L46) and this PR includes the updated interpretation on ORC column vectors. Note that `hive` ORC reader and ORC MR reader is not affected. ```scala scala> spark.version res0: String = 2.3.0 scala> spark.sql("set spark.sql.orc.impl=native") scala> Seq(java.sql.Timestamp.valueOf("1900-05-05 12:34:56.000789")).toDF().write.orc("/tmp/orc") scala> spark.read.orc("/tmp/orc").show(false) +--------------------------+ |value | +--------------------------+ |1900-05-05 12:34:55.000789| +--------------------------+ ``` This PR aims to update Apache Spark to use it. **FULL LIST** ID | TITLE -- | -- ORC-281 | Fix compiler warnings from clang 5.0 ORC-301 | `extractFileTail` should open a file in `try` statement ORC-304 | Fix TestRecordReaderImpl to not fail with new storage-api ORC-306 | Fix incorrect workaround for bug in java.sql.Timestamp ORC-324 | Add support for ARM and PPC arch ORC-330 | Remove unnecessary Hive artifacts from root pom ORC-332 | Add syntax version to orc_proto.proto ORC-336 | Remove avro and parquet dependency management entries ORC-360 | Implement error checking on subtype fields in Java ## How was this patch tested? Pass the Jenkins. Author: Dongjoon Hyun Closes #21372 from dongjoon-hyun/SPARK_ORC144. --- dev/deps/spark-deps-hadoop-2.6 | 4 ++-- dev/deps/spark-deps-hadoop-2.7 | 4 ++-- dev/deps/spark-deps-hadoop-3.1 | 4 ++-- pom.xml | 2 +- .../sql/execution/datasources/orc/OrcColumnVector.java | 2 +- .../datasources/orc/OrcColumnarBatchReader.java | 2 +- .../sql/execution/datasources/orc/OrcSourceSuite.scala | 9 +++++++++ 7 files changed, 18 insertions(+), 9 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index e710e26348117..723180a14febb 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -157,8 +157,8 @@ objenesis-2.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.3-nohive.jar -orc-mapreduce-1.4.3-nohive.jar +orc-core-1.4.4-nohive.jar +orc-mapreduce-1.4.4-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 97ad17a9ff7b1..ea08a001a1c9b 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -158,8 +158,8 @@ objenesis-2.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.3-nohive.jar -orc-mapreduce-1.4.3-nohive.jar +orc-core-1.4.4-nohive.jar +orc-mapreduce-1.4.4-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index e21bfef8c4291..da874026d7d10 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -176,8 +176,8 @@ okhttp-2.7.5.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.3-nohive.jar -orc-mapreduce-1.4.3-nohive.jar +orc-core-1.4.4-nohive.jar +orc-mapreduce-1.4.4-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/pom.xml b/pom.xml index 6e37e518d86e4..883c096ae1ae9 100644 --- a/pom.xml +++ b/pom.xml @@ -130,7 +130,7 @@ 1.2.1 10.12.1.1 1.10.0 - 1.4.3 + 1.4.4 nohive 1.6.0 9.3.20.v20170531 diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index 12f4d658b1868..9bfad1e83ee7b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -136,7 +136,7 @@ public int getInt(int rowId) { public long getLong(int rowId) { int index = getRowIndex(rowId); if (isTimestamp) { - return timestampData.time[index] * 1000 + timestampData.nanos[index] / 1000; + return timestampData.time[index] * 1000 + timestampData.nanos[index] / 1000 % 1000; } else { return longData.vector[index]; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index dcebdc39f0aa2..a0d9578a377b1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -497,7 +497,7 @@ private void putValues( * Returns the number of micros since epoch from an element of TimestampColumnVector. */ private static long fromTimestampColumnVector(TimestampColumnVector vector, int index) { - return vector.time[index] * 1000L + vector.nanos[index] / 1000L; + return vector.time[index] * 1000 + (vector.nanos[index] / 1000 % 1000); } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 8a3bbd03a26dc..02bfb7197ffc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.orc import java.io.File +import java.sql.Timestamp import java.util.Locale import org.apache.orc.OrcConf.COMPRESS @@ -169,6 +170,14 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { } } } + + test("SPARK-24322 Fix incorrect workaround for bug in java.sql.Timestamp") { + withTempPath { path => + val ts = Timestamp.valueOf("1900-05-05 12:34:56.000789") + Seq(ts).toDF.write.orc(path.getCanonicalPath) + checkAnswer(spark.read.orc(path.getCanonicalPath), Row(ts)) + } + } } class OrcSourceSuite extends OrcSuite with SharedSQLContext { From e108f84f5cd562f070872651bdcf6c02e80dd585 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 24 May 2018 11:42:25 +0800 Subject: [PATCH 0865/2461] [MINOR][CORE] Cleanup unused vals in `DAGScheduler.handleTaskCompletion` ## What changes were proposed in this pull request? Cleanup unused vals in `DAGScheduler.handleTaskCompletion` to reduce the code complexity slightly. ## How was this patch tested? Existing test cases. Author: Xingbo Jiang Closes #21406 from jiangxb1987/handleTaskCompletion. --- .../scala/org/apache/spark/scheduler/DAGScheduler.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index ea7bfd7d7a68d..041eade82d3ca 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1167,9 +1167,7 @@ class DAGScheduler( */ private[scheduler] def handleTaskCompletion(event: CompletionEvent) { val task = event.task - val taskId = event.taskInfo.id val stageId = task.stageId - val taskType = Utils.getFormattedClassName(task) outputCommitCoordinator.taskCompleted( stageId, @@ -1323,7 +1321,7 @@ class DAGScheduler( "tasks in ShuffleMapStages.") } - case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) => + case FetchFailed(bmAddress, shuffleId, mapId, _, failureMessage) => val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleIdToMapStage(shuffleId) @@ -1411,7 +1409,7 @@ class DAGScheduler( } } - case commitDenied: TaskCommitDenied => + case _: TaskCommitDenied => // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits case _: ExceptionFailure | _: TaskKilled => From 8a545822d0cc3a866ef91a94e58ea5c8b1014007 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 24 May 2018 13:21:02 +0800 Subject: [PATCH 0866/2461] [SPARK-24364][SS] Prevent InMemoryFileIndex from failing if file path doesn't exist ## What changes were proposed in this pull request? This PR proposes to follow up https://github.com/apache/spark/pull/15153 and complete SPARK-17599. `FileSystem` operation (`fs.getFileBlockLocations`) can still fail if the file path does not exist. For example see the exception message below: ``` Error occurred while processing: File does not exist: /rel/00171151/input/PJ/part-00136-b6403bac-a240-44f8-a792-fc2e174682b7-c000.csv ... java.io.FileNotFoundException: File does not exist: /rel/00171151/input/PJ/part-00136-b6403bac-a240-44f8-a792-fc2e174682b7-c000.csv ... org.apache.hadoop.hdfs.DistributedFileSystem.getFileBlockLocations(DistributedFileSystem.java:249) at org.apache.hadoop.hdfs.DistributedFileSystem.getFileBlockLocations(DistributedFileSystem.java:229) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex$$anonfun$org$apache$spark$sql$execution$datasources$InMemoryFileIndex$$listLeafFiles$3.apply(InMemoryFileIndex.scala:314) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex$$anonfun$org$apache$spark$sql$execution$datasources$InMemoryFileIndex$$listLeafFiles$3.apply(InMemoryFileIndex.scala:297) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33) at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186) at scala.collection.TraversableLike$class.map(TraversableLike.scala:234) at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:186) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex$.org$apache$spark$sql$execution$datasources$InMemoryFileIndex$$listLeafFiles(InMemoryFileIndex.scala:297) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex$$anonfun$org$apache$spark$sql$execution$datasources$InMemoryFileIndex$$bulkListLeafFiles$1.apply(InMemoryFileIndex.scala:174) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex$$anonfun$org$apache$spark$sql$execution$datasources$InMemoryFileIndex$$bulkListLeafFiles$1.apply(InMemoryFileIndex.scala:173) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48) at scala.collection.TraversableLike$class.map(TraversableLike.scala:234) at scala.collection.AbstractTraversable.map(Traversable.scala:104) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex$.org$apache$spark$sql$execution$datasources$InMemoryFileIndex$$bulkListLeafFiles(InMemoryFileIndex.scala:173) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex.listLeafFiles(InMemoryFileIndex.scala:126) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex.refresh0(InMemoryFileIndex.scala:91) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex.(InMemoryFileIndex.scala:67) at org.apache.spark.sql.execution.datasources.DataSource.tempFileIndex$lzycompute$1(DataSource.scala:161) at org.apache.spark.sql.execution.datasources.DataSource.org$apache$spark$sql$execution$datasources$DataSource$$tempFileIndex$1(DataSource.scala:152) at org.apache.spark.sql.execution.datasources.DataSource.getOrInferFileFormatSchema(DataSource.scala:166) at org.apache.spark.sql.execution.datasources.DataSource.sourceSchema(DataSource.scala:261) at org.apache.spark.sql.execution.datasources.DataSource.sourceInfo$lzycompute(DataSource.scala:94) at org.apache.spark.sql.execution.datasources.DataSource.sourceInfo(DataSource.scala:94) at org.apache.spark.sql.execution.streaming.StreamingRelation$.apply(StreamingRelation.scala:33) at org.apache.spark.sql.streaming.DataStreamReader.load(DataStreamReader.scala:196) at org.apache.spark.sql.streaming.DataStreamReader.load(DataStreamReader.scala:206) at com.hwx.StreamTest$.main(StreamTest.scala:97) at com.hwx.StreamTest.main(StreamTest.scala) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at org.apache.spark.deploy.JavaMainApplication.start(SparkApplication.scala:52) at org.apache.spark.deploy.SparkSubmit$.org$apache$spark$deploy$SparkSubmit$$runMain(SparkSubmit.scala:906) at org.apache.spark.deploy.SparkSubmit$.doRunMain$1(SparkSubmit.scala:197) at org.apache.spark.deploy.SparkSubmit$.submit(SparkSubmit.scala:227) at org.apache.spark.deploy.SparkSubmit$.main(SparkSubmit.scala:136) at org.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala) Caused by: org.apache.hadoop.ipc.RemoteException(java.io.FileNotFoundException): File does not exist: /rel/00171151/input/PJ/part-00136-b6403bac-a240-44f8-a792-fc2e174682b7-c000.csv ... ``` So, it fixes it to make a warning instead. ## How was this patch tested? It's hard to write a test. Manually tested multiple times. Author: hyukjinkwon Closes #21408 from HyukjinKwon/missing-files. --- .../datasources/InMemoryFileIndex.scala | 32 ++++++++++++++----- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 739d1f456e3ec..9d9f8bd5bb58e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -294,9 +294,12 @@ object InMemoryFileIndex extends Logging { if (filter != null) allFiles.filter(f => filter.accept(f.getPath)) else allFiles } - allLeafStatuses.filterNot(status => shouldFilterOut(status.getPath.getName)).map { + val missingFiles = mutable.ArrayBuffer.empty[String] + val filteredLeafStatuses = allLeafStatuses.filterNot( + status => shouldFilterOut(status.getPath.getName)) + val resolvedLeafStatuses = filteredLeafStatuses.flatMap { case f: LocatedFileStatus => - f + Some(f) // NOTE: // @@ -311,14 +314,27 @@ object InMemoryFileIndex extends Logging { // The other constructor of LocatedFileStatus will call FileStatus.getPermission(), // which is very slow on some file system (RawLocalFileSystem, which is launch a // subprocess and parse the stdout). - val locations = fs.getFileBlockLocations(f, 0, f.getLen) - val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, - f.getModificationTime, 0, null, null, null, null, f.getPath, locations) - if (f.isSymlink) { - lfs.setSymlink(f.getSymlink) + try { + val locations = fs.getFileBlockLocations(f, 0, f.getLen) + val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, + f.getModificationTime, 0, null, null, null, null, f.getPath, locations) + if (f.isSymlink) { + lfs.setSymlink(f.getSymlink) + } + Some(lfs) + } catch { + case _: FileNotFoundException => + missingFiles += f.getPath.toString + None } - lfs } + + if (missingFiles.nonEmpty) { + logWarning( + s"the following files were missing during file scan:\n ${missingFiles.mkString("\n ")}") + } + + resolvedLeafStatuses } /** Checks if we should filter out this path name. */ From 4a14dc0aff9cac85390cab94bc183271fa95beef Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 24 May 2018 14:19:32 +0800 Subject: [PATCH 0867/2461] [SPARK-22269][BUILD] Run Java linter via SBT for Jenkins ## What changes were proposed in this pull request? This PR proposes to check Java lint via SBT for Jenkins. It uses the SBT wrapper for checkstyle. I manually tested. If we build the codes once, running this script takes 2 mins at maximum in my local: Test codes: ``` Checkstyle failed at following occurrences: [error] Checkstyle error found in /.../spark/core/src/test/java/test/org/apache/spark/JavaAPISuite.java:82: Line is longer than 100 characters (found 103). [error] 1 issue(s) found in Checkstyle report: /.../spark/core/target/checkstyle-test-report.xml [error] Checkstyle error found in /.../spark/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java:84: Line is longer than 100 characters (found 115). [error] 1 issue(s) found in Checkstyle report: /.../spark/sql/hive/target/checkstyle-test-report.xml ... ``` Main codes: ``` Checkstyle failed at following occurrences: [error] Checkstyle error found in /.../spark/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java:39: Line is longer than 100 characters (found 104). [error] Checkstyle error found in /.../spark/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java:26: Line is longer than 100 characters (found 110). [error] Checkstyle error found in /.../spark/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java:30: Line is longer than 100 characters (found 104). ... ``` ## How was this patch tested? Manually tested. Jenkins build should test this. Author: hyukjinkwon Closes #21399 from HyukjinKwon/SPARK-22269. --- dev/run-tests.py | 5 ++--- dev/sbt-checkstyle | 42 ++++++++++++++++++++++++++++++++++++++++ project/SparkBuild.scala | 14 +++++++++++++- project/plugins.sbt | 8 ++++++++ 4 files changed, 65 insertions(+), 4 deletions(-) create mode 100755 dev/sbt-checkstyle diff --git a/dev/run-tests.py b/dev/run-tests.py index 164c1e2200aa9..5e8c8590b5c34 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -204,7 +204,7 @@ def run_scala_style_checks(): def run_java_style_checks(): set_title_and_block("Running Java style checks", "BLOCK_JAVA_STYLE") - run_cmd([os.path.join(SPARK_HOME, "dev", "lint-java")]) + run_cmd([os.path.join(SPARK_HOME, "dev", "sbt-checkstyle")]) def run_python_style_checks(): @@ -574,8 +574,7 @@ def main(): or f.endswith("checkstyle.xml") or f.endswith("checkstyle-suppressions.xml") for f in changed_files): - # run_java_style_checks() - pass + run_java_style_checks() if not changed_files or any(f.endswith("lint-python") or f.endswith("tox.ini") or f.endswith(".py") diff --git a/dev/sbt-checkstyle b/dev/sbt-checkstyle new file mode 100755 index 0000000000000..8821a7c0e4ccf --- /dev/null +++ b/dev/sbt-checkstyle @@ -0,0 +1,42 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# NOTE: echo "q" is needed because SBT prompts the user for input on encountering a build file +# with failure (either resolution or compilation); the "q" makes SBT quit. +ERRORS=$(echo -e "q\n" \ + | build/sbt \ + -Pkinesis-asl \ + -Pmesos \ + -Pkafka-0-8 \ + -Pkubernetes \ + -Pyarn \ + -Pflume \ + -Phive \ + -Phive-thriftserver \ + checkstyle test:checkstyle \ + | awk '{if($1~/error/)print}' \ +) + +if test ! -z "$ERRORS"; then + echo -e "Checkstyle failed at following occurrences:\n$ERRORS" + exit 1 +else + echo -e "Checkstyle checks passed." +fi + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 7469f11df0294..4cb6495a33b61 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -27,6 +27,7 @@ import sbt._ import sbt.Classpaths.publishTask import sbt.Keys._ import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion +import com.etsy.sbt.checkstyle.CheckstylePlugin.autoImport._ import com.simplytyped.Antlr4Plugin._ import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys} import com.typesafe.tools.mima.plugin.MimaKeys @@ -317,7 +318,7 @@ object SparkBuild extends PomBuild { /* Enable shared settings on all projects */ (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ copyJarsProjects ++ Seq(spark, tools)) .foreach(enable(sharedSettings ++ DependencyOverrides.settings ++ - ExcludedDependencies.settings)) + ExcludedDependencies.settings ++ Checkstyle.settings)) /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) @@ -740,6 +741,17 @@ object Unidoc { ) } +object Checkstyle { + lazy val settings = Seq( + checkstyleSeverityLevel := Some(CheckstyleSeverityLevel.Error), + javaSource in (Compile, checkstyle) := baseDirectory.value / "src/main/java", + javaSource in (Test, checkstyle) := baseDirectory.value / "src/test/java", + checkstyleConfigLocation := CheckstyleConfigLocation.File("dev/checkstyle.xml"), + checkstyleOutputFile := baseDirectory.value / "target/checkstyle-output.xml", + checkstyleOutputFile in Test := baseDirectory.value / "target/checkstyle-output.xml" + ) +} + object CopyDependencies { val copyDeps = TaskKey[Unit]("copyDeps", "Copies needed dependencies to the build directory.") diff --git a/project/plugins.sbt b/project/plugins.sbt index 96bdb9067ae59..ffbd417b0f145 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,3 +1,11 @@ +addSbtPlugin("com.etsy" % "sbt-checkstyle-plugin" % "3.1.1") + +// sbt-checkstyle-plugin uses an old version of checkstyle. Match it to Maven's. +libraryDependencies += "com.puppycrawl.tools" % "checkstyle" % "8.2" + +// checkstyle uses guava 23.0. +libraryDependencies += "com.google.guava" % "guava" % "23.0" + // need to make changes to uptake sbt 1.0 support in "com.eed3si9n" % "sbt-assembly" % "1.14.5" addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") From 3469f5c989e686866051382a3a28b2265619cab9 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Thu, 24 May 2018 20:55:26 +0800 Subject: [PATCH 0868/2461] [SPARK-24230][SQL] Fix SpecificParquetRecordReaderBase with dictionary filters. ## What changes were proposed in this pull request? I missed this commit when preparing #21070. When Parquet is able to filter blocks with dictionary filtering, the expected total value count to be too high in Spark, leading to an error when there were fewer than expected row groups to process. Spark should get the row groups from Parquet to pick up new filter schemes in Parquet like dictionary filtering. ## How was this patch tested? Using in production at Netflix. Added test case for dictionary-filtered blocks. Author: Ryan Blue Closes #21295 from rdblue/SPARK-24230-fix-parquet-block-tracking. --- .../parquet/SpecificParquetRecordReaderBase.java | 6 ++++-- .../datasources/parquet/ParquetQuerySuite.scala | 12 ++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index daedfd7e78f5f..c975e52734e01 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -146,7 +146,8 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont this.sparkSchema = StructType$.MODULE$.fromString(sparkRequestedSchemaString); this.reader = new ParquetFileReader( configuration, footer.getFileMetaData(), file, blocks, requestedSchema.getColumns()); - for (BlockMetaData block : blocks) { + // use the blocks from the reader in case some do not match filters and will not be read + for (BlockMetaData block : reader.getRowGroups()) { this.totalRowCount += block.getRowCount(); } @@ -224,7 +225,8 @@ protected void initialize(String path, List columns) throws IOException this.sparkSchema = new ParquetToSparkSchemaConverter(config).convert(requestedSchema); this.reader = new ParquetFileReader( config, footer.getFileMetaData(), file, blocks, requestedSchema.getColumns()); - for (BlockMetaData block : blocks) { + // use the blocks from the reader in case some do not match filters and will not be read + for (BlockMetaData block : reader.getRowGroups()) { this.totalRowCount += block.getRowCount(); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index e1f094d0a7af3..2b1227faf48a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -879,6 +879,18 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } } + + test("SPARK-24230: filter row group using dictionary") { + withSQLConf(("parquet.filter.dictionary.enabled", "true")) { + // create a table with values from 0, 2, ..., 18 that will be dictionary-encoded + withParquetTable((0 until 100).map(i => ((i * 2) % 20, s"data-$i")), "t") { + // search for a key that is not present so the dictionary filter eliminates all row groups + // Fails without SPARK-24230: + // java.io.IOException: expecting more rows but reached last block. Read 0 out of 50 + checkAnswer(sql("SELECT _2 FROM t WHERE t._1 = 5"), Seq.empty) + } + } + } } object TestingUDT { From 13bedc05c28fcc6e739fb472bd2ee3035fa11648 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 24 May 2018 22:18:58 +0800 Subject: [PATCH 0869/2461] [SPARK-24329][SQL] Test for skipping multi-space lines ## What changes were proposed in this pull request? The PR is a continue of https://github.com/apache/spark/pull/21380 . It checks cases that are handled by the code: https://github.com/apache/spark/blob/e3de6ab30d52890eb08578e55eb4a5d2b4e7aa35/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala#L303-L304 Basically the code skips lines with one or many whitespaces, and lines with comments (see [filterCommentAndEmpty](https://github.com/apache/spark/blob/e3de6ab30d52890eb08578e55eb4a5d2b4e7aa35/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala#L47)) ```scala iter.filter { line => line.trim.nonEmpty && !line.startsWith(options.comment.toString) } ``` Closes #21380 ## How was this patch tested? Added a test for the case described above. Author: Maxim Gekk Author: Maxim Gekk Closes #21394 from MaxGekk/test-for-multi-space-lines. --- .../resources/test-data/comments-whitespaces.csv | 8 ++++++++ .../sql/execution/datasources/csv/CSVSuite.scala | 15 +++++++++++++++ 2 files changed, 23 insertions(+) create mode 100644 sql/core/src/test/resources/test-data/comments-whitespaces.csv diff --git a/sql/core/src/test/resources/test-data/comments-whitespaces.csv b/sql/core/src/test/resources/test-data/comments-whitespaces.csv new file mode 100644 index 0000000000000..2737978f83a5e --- /dev/null +++ b/sql/core/src/test/resources/test-data/comments-whitespaces.csv @@ -0,0 +1,8 @@ +# The file contains comments, whitespaces and empty lines +colA +# empty line + +# the line with a few whitespaces + +# int value with leading and trailing whitespaces + "a" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 07e6c74b14d0d..2bac1a3e2d4c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1368,4 +1368,19 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te checkAnswer(computed, expected) } } + + test("SPARK-24329: skip lines with comments, and one or multiple whitespaces") { + val schema = new StructType().add("colA", StringType) + val ds = spark + .read + .schema(schema) + .option("multiLine", false) + .option("header", true) + .option("comment", "#") + .option("ignoreLeadingWhiteSpace", false) + .option("ignoreTrailingWhiteSpace", false) + .csv(testFile("test-data/comments-whitespaces.csv")) + + checkAnswer(ds, Seq(Row(""" "a" """))) + } } From 0d89943449764e8a578edf4ceb6245158421eb96 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 24 May 2018 23:38:50 +0800 Subject: [PATCH 0870/2461] [SPARK-24378][SQL] Fix date_trunc function incorrect examples ## What changes were proposed in this pull request? Fix `date_trunc` function incorrect examples. ## How was this patch tested? N/A Author: Yuming Wang Closes #21423 from wangyum/SPARK-24378. --- .../expressions/datetimeExpressions.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index e8d85f72f7a7a..08838d2b2c612 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1533,14 +1533,14 @@ case class TruncDate(date: Expression, format: Expression) """, examples = """ Examples: - > SELECT _FUNC_('2015-03-05T09:32:05.359', 'YEAR'); - 2015-01-01T00:00:00 - > SELECT _FUNC_('2015-03-05T09:32:05.359', 'MM'); - 2015-03-01T00:00:00 - > SELECT _FUNC_('2015-03-05T09:32:05.359', 'DD'); - 2015-03-05T00:00:00 - > SELECT _FUNC_('2015-03-05T09:32:05.359', 'HOUR'); - 2015-03-05T09:00:00 + > SELECT _FUNC_('YEAR', '2015-03-05T09:32:05.359'); + 2015-01-01 00:00:00 + > SELECT _FUNC_('MM', '2015-03-05T09:32:05.359'); + 2015-03-01 00:00:00 + > SELECT _FUNC_('DD', '2015-03-05T09:32:05.359'); + 2015-03-05 00:00:00 + > SELECT _FUNC_('HOUR', '2015-03-05T09:32:05.359'); + 2015-03-05 09:00:00 """, since = "2.3.0") // scalastyle:on line.size.limit From 53c06ddabbdf689f8823807445849ad63173676f Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 24 May 2018 13:00:24 -0700 Subject: [PATCH 0871/2461] [SPARK-24332][SS][MESOS] Fix places reading 'spark.network.timeout' as milliseconds ## What changes were proposed in this pull request? This PR replaces `getTimeAsMs` with `getTimeAsSeconds` to fix the issue that reading "spark.network.timeout" using a wrong time unit when the user doesn't specify a time out. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #21382 from zsxwing/fix-network-timeout-conf. --- .../org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala | 2 +- .../scala/org/apache/spark/sql/kafka010/KafkaRelation.scala | 4 +++- .../scala/org/apache/spark/sql/kafka010/KafkaSource.scala | 2 +- .../scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala | 2 +- .../cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala | 2 +- 5 files changed, 7 insertions(+), 5 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index 64ba98762788c..737da2e51b125 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -68,7 +68,7 @@ private[kafka010] class KafkaMicroBatchReader( private val pollTimeoutMs = options.getLong( "kafkaConsumer.pollTimeoutMs", - SparkEnv.get.conf.getTimeAsMs("spark.network.timeout", "120s")) + SparkEnv.get.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L) private val maxOffsetsPerTrigger = Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala index 7103709969c18..c31e6ed3e0903 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala @@ -48,7 +48,9 @@ private[kafka010] class KafkaRelation( private val pollTimeoutMs = sourceOptions.getOrElse( "kafkaConsumer.pollTimeoutMs", - sqlContext.sparkContext.conf.getTimeAsMs("spark.network.timeout", "120s").toString + (sqlContext.sparkContext.conf.getTimeAsSeconds( + "spark.network.timeout", + "120s") * 1000L).toString ).toLong override def schema: StructType = KafkaOffsetReader.kafkaSchema diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 1c7b3a29a861f..101e649727fcf 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -84,7 +84,7 @@ private[kafka010] class KafkaSource( private val pollTimeoutMs = sourceOptions.getOrElse( "kafkaConsumer.pollTimeoutMs", - sc.conf.getTimeAsMs("spark.network.timeout", "120s").toString + (sc.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L).toString ).toLong private val maxOffsetsPerTrigger = diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index 81abc9860bfc3..3efc90fe466b2 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -65,7 +65,7 @@ private[spark] class KafkaRDD[K, V]( // TODO is it necessary to have separate configs for initial poll time vs ongoing poll time? private val pollTimeout = conf.getLong("spark.streaming.kafka.consumer.poll.ms", - conf.getTimeAsMs("spark.network.timeout", "120s")) + conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L) private val cacheInitialCapacity = conf.getInt("spark.streaming.kafka.consumer.cache.initialCapacity", 16) private val cacheMaxCapacity = diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 9b75e4c98344a..d35bea4aca311 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -634,7 +634,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( slave.hostname, externalShufflePort, sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", - s"${sc.conf.getTimeAsMs("spark.network.timeout", "120s")}ms"), + s"${sc.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L}ms"), sc.conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s")) slave.shuffleRegistered = true } From 0fd68cb7278e5fdf106e73b580ee7dd829006386 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 24 May 2018 17:08:52 -0700 Subject: [PATCH 0872/2461] [SPARK-24234][SS] Support multiple row writers in continuous processing shuffle reader. ## What changes were proposed in this pull request? https://docs.google.com/document/d/1IL4kJoKrZWeyIhklKUJqsW-yEN7V7aL05MmM65AYOfE/edit#heading=h.8t3ci57f7uii Support multiple different row writers in continuous processing shuffle reader. Note that having multiple read-side buffers ended up being the natural way to do this. Otherwise it's hard to express the constraint of sending an epoch marker only when all writers have sent one. ## How was this patch tested? new unit tests Author: Jose Torres Closes #21385 from jose-torres/multipleWrite. --- .../shuffle/ContinuousShuffleReadRDD.scala | 21 ++- .../shuffle/UnsafeRowReceiver.scala | 87 ++++++++-- .../shuffle/ContinuousShuffleReadSuite.scala | 163 +++++++++++++++--- 3 files changed, 227 insertions(+), 44 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 270b1a5c28dee..801b28b751bee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -25,11 +25,16 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.NextIterator -case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Partition { +case class ContinuousShuffleReadPartition( + index: Int, + queueSize: Int, + numShuffleWriters: Int, + epochIntervalMs: Long) + extends Partition { // Initialized only on the executor, and only once even as we call compute() multiple times. lazy val (reader: ContinuousShuffleReader, endpoint) = { val env = SparkEnv.get.rpcEnv - val receiver = new UnsafeRowReceiver(queueSize, env) + val receiver = new UnsafeRowReceiver(queueSize, numShuffleWriters, epochIntervalMs, env) val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver) TaskContext.get().addTaskCompletionListener { ctx => env.stop(endpoint) @@ -42,16 +47,24 @@ case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Pa * RDD at the map side of each continuous processing shuffle task. Upstream tasks send their * shuffle output to the wrapped receivers in partitions of this RDD; each of the RDD's tasks * poll from their receiver until an epoch marker is sent. + * + * @param sc the RDD context + * @param numPartitions the number of read partitions for this RDD + * @param queueSize the size of the row buffers to use + * @param numShuffleWriters the number of continuous shuffle writers feeding into this RDD + * @param epochIntervalMs the checkpoint interval of the streaming query */ class ContinuousShuffleReadRDD( sc: SparkContext, numPartitions: Int, - queueSize: Int = 1024) + queueSize: Int = 1024, + numShuffleWriters: Int = 1, + epochIntervalMs: Long = 1000) extends RDD[UnsafeRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { (0 until numPartitions).map { partIndex => - ContinuousShuffleReadPartition(partIndex, queueSize) + ContinuousShuffleReadPartition(partIndex, queueSize, numShuffleWriters, epochIntervalMs) }.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index b8adbb743c6c2..d81f552d56626 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle -import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} +import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean +import scala.collection.mutable + import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -27,10 +29,17 @@ import org.apache.spark.util.NextIterator /** * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker. + * + * Each message comes tagged with writerId, identifying which writer the message is coming + * from. The receiver will only begin the next epoch once all writers have sent an epoch + * marker ending the current epoch. */ -private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable -private[shuffle] case class ReceiverRow(row: UnsafeRow) extends UnsafeRowReceiverMessage -private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessage +private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable { + def writerId: Int +} +private[shuffle] case class ReceiverRow(writerId: Int, row: UnsafeRow) + extends UnsafeRowReceiverMessage +private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends UnsafeRowReceiverMessage /** * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle @@ -41,11 +50,15 @@ private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessa */ private[shuffle] class UnsafeRowReceiver( queueSize: Int, + numShuffleWriters: Int, + epochIntervalMs: Long, override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging { // Note that this queue will be drained from the main task thread and populated in the RPC // response thread. - private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) + private val queues = Array.fill(numShuffleWriters) { + new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) + } // Exposed for testing to determine if the endpoint gets stopped on task end. private[shuffle] val stopped = new AtomicBoolean(false) @@ -56,20 +69,70 @@ private[shuffle] class UnsafeRowReceiver( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case r: UnsafeRowReceiverMessage => - queue.put(r) + queues(r.writerId).put(r) context.reply(()) } override def read(): Iterator[UnsafeRow] = { new NextIterator[UnsafeRow] { - override def getNext(): UnsafeRow = queue.take() match { - case ReceiverRow(r) => r - case ReceiverEpochMarker() => - finished = true - null + // An array of flags for whether each writer ID has gotten an epoch marker. + private val writerEpochMarkersReceived = Array.fill(numShuffleWriters)(false) + + private val executor = Executors.newFixedThreadPool(numShuffleWriters) + private val completion = new ExecutorCompletionService[UnsafeRowReceiverMessage](executor) + + private def completionTask(writerId: Int) = new Callable[UnsafeRowReceiverMessage] { + override def call(): UnsafeRowReceiverMessage = queues(writerId).take() } - override def close(): Unit = {} + // Initialize by submitting tasks to read the first row from each writer. + (0 until numShuffleWriters).foreach(writerId => completion.submit(completionTask(writerId))) + + /** + * In each call to getNext(), we pull the next row available in the completion queue, and then + * submit another task to read the next row from the writer which returned it. + * + * When a writer sends an epoch marker, we note that it's finished and don't submit another + * task for it in this epoch. The iterator is over once all writers have sent an epoch marker. + */ + override def getNext(): UnsafeRow = { + var nextRow: UnsafeRow = null + while (!finished && nextRow == null) { + completion.poll(epochIntervalMs, TimeUnit.MILLISECONDS) match { + case null => + // Try again if the poll didn't wait long enough to get a real result. + // But we should be getting at least an epoch marker every checkpoint interval. + val writerIdsUncommitted = writerEpochMarkersReceived.zipWithIndex.collect { + case (flag, idx) if !flag => idx + } + logWarning( + s"Completion service failed to make progress after $epochIntervalMs ms. Waiting " + + s"for writers $writerIdsUncommitted to send epoch markers.") + + // The completion service guarantees this future will be available immediately. + case future => future.get() match { + case ReceiverRow(writerId, r) => + // Start reading the next element in the queue we just took from. + completion.submit(completionTask(writerId)) + nextRow = r + case ReceiverEpochMarker(writerId) => + // Don't read any more from this queue. If all the writers have sent epoch markers, + // the epoch is over; otherwise we need to loop again to poll from the remaining + // writers. + writerEpochMarkersReceived(writerId) = true + if (writerEpochMarkersReceived.forall(_ == true)) { + finished = true + } + } + } + } + + nextRow + } + + override def close(): Unit = { + executor.shutdownNow() + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index b25e75b3b37a6..2e4d607a403ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -21,7 +21,8 @@ import org.apache.spark.{TaskContext, TaskContextImpl} import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.sql.types.{DataType, IntegerType, StringType} +import org.apache.spark.unsafe.types.UTF8String class ContinuousShuffleReadSuite extends StreamTest { @@ -30,6 +31,11 @@ class ContinuousShuffleReadSuite extends StreamTest { new GenericInternalRow(Array(value: Any))) } + private def unsafeRow(value: String) = { + UnsafeProjection.create(Array(StringType : DataType))( + new GenericInternalRow(Array(UTF8String.fromString(value): Any))) + } + private def send(endpoint: RpcEndpointRef, messages: UnsafeRowReceiverMessage*) = { messages.foreach(endpoint.askSync[Unit](_)) } @@ -57,8 +63,8 @@ class ContinuousShuffleReadSuite extends StreamTest { val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( endpoint, - ReceiverEpochMarker(), - ReceiverRow(unsafeRow(111)) + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(111)) ) ctx.markTaskCompleted(None) @@ -71,8 +77,11 @@ class ContinuousShuffleReadSuite extends StreamTest { test("receiver stopped with marker last") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) - endpoint.askSync[Unit](ReceiverEpochMarker()) + send( + endpoint, + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0) + ) ctx.markTaskCompleted(None) val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader @@ -86,10 +95,10 @@ class ContinuousShuffleReadSuite extends StreamTest { val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( endpoint, - ReceiverRow(unsafeRow(111)), - ReceiverRow(unsafeRow(222)), - ReceiverRow(unsafeRow(333)), - ReceiverEpochMarker() + ReceiverRow(0, unsafeRow(111)), + ReceiverRow(0, unsafeRow(222)), + ReceiverRow(0, unsafeRow(333)), + ReceiverEpochMarker(0) ) val iter = rdd.compute(rdd.partitions(0), ctx) @@ -101,11 +110,11 @@ class ContinuousShuffleReadSuite extends StreamTest { val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( endpoint, - ReceiverRow(unsafeRow(111)), - ReceiverEpochMarker(), - ReceiverRow(unsafeRow(222)), - ReceiverRow(unsafeRow(333)), - ReceiverEpochMarker() + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(222)), + ReceiverRow(0, unsafeRow(333)), + ReceiverEpochMarker(0) ) val firstEpoch = rdd.compute(rdd.partitions(0), ctx) @@ -118,14 +127,15 @@ class ContinuousShuffleReadSuite extends StreamTest { test("empty epochs") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( endpoint, - ReceiverEpochMarker(), - ReceiverEpochMarker(), - ReceiverRow(unsafeRow(111)), - ReceiverEpochMarker(), - ReceiverEpochMarker(), - ReceiverEpochMarker() + ReceiverEpochMarker(0), + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0), + ReceiverEpochMarker(0), + ReceiverEpochMarker(0) ) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) @@ -146,8 +156,8 @@ class ContinuousShuffleReadSuite extends StreamTest { // Send index for identification. send( part.endpoint, - ReceiverRow(unsafeRow(part.index)), - ReceiverEpochMarker() + ReceiverRow(0, unsafeRow(part.index)), + ReceiverEpochMarker(0) ) } @@ -160,25 +170,122 @@ class ContinuousShuffleReadSuite extends StreamTest { } test("blocks waiting for new rows") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val rdd = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, epochIntervalMs = Long.MaxValue) + val epoch = rdd.compute(rdd.partitions(0), ctx) val readRowThread = new Thread { override def run(): Unit = { - // set the non-inheritable thread local - TaskContext.setTaskContext(ctx) - val epoch = rdd.compute(rdd.partitions(0), ctx) - epoch.next().getInt(0) + try { + epoch.next().getInt(0) + } catch { + case _: InterruptedException => // do nothing - expected at test ending + } } } try { readRowThread.start() eventually(timeout(streamingTimeout)) { - assert(readRowThread.getState == Thread.State.WAITING) + assert(readRowThread.getState == Thread.State.TIMED_WAITING) } } finally { readRowThread.interrupt() readRowThread.join() } } + + test("multiple writers") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow("writer0-row0")), + ReceiverRow(1, unsafeRow("writer1-row0")), + ReceiverRow(2, unsafeRow("writer2-row0")), + ReceiverEpochMarker(0), + ReceiverEpochMarker(1), + ReceiverEpochMarker(2) + ) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(firstEpoch.toSeq.map(_.getUTF8String(0).toString).toSet == + Set("writer0-row0", "writer1-row0", "writer2-row0")) + } + + test("epoch only ends when all writers send markers") { + val rdd = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = 3, epochIntervalMs = Long.MaxValue) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow("writer0-row0")), + ReceiverRow(1, unsafeRow("writer1-row0")), + ReceiverRow(2, unsafeRow("writer2-row0")), + ReceiverEpochMarker(0), + ReceiverEpochMarker(2) + ) + + val epoch = rdd.compute(rdd.partitions(0), ctx) + val rows = (0 until 3).map(_ => epoch.next()).toSet + assert(rows.map(_.getUTF8String(0).toString) == + Set("writer0-row0", "writer1-row0", "writer2-row0")) + + // After checking the right rows, block until we get an epoch marker indicating there's no next. + // (Also fail the assertion if for some reason we get a row.) + val readEpochMarkerThread = new Thread { + override def run(): Unit = { + assert(!epoch.hasNext) + } + } + + readEpochMarkerThread.start() + eventually(timeout(streamingTimeout)) { + assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING) + } + + // Send the last epoch marker - now the epoch should finish. + send(endpoint, ReceiverEpochMarker(1)) + eventually(timeout(streamingTimeout)) { + !readEpochMarkerThread.isAlive + } + + // Join to pick up assertion failures. + readEpochMarkerThread.join() + } + + test("writer epochs non aligned") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + // We send multiple epochs for 0, then multiple for 1, then multiple for 2. The receiver should + // collate them as though the markers were aligned in the first place. + send( + endpoint, + ReceiverRow(0, unsafeRow("writer0-row0")), + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow("writer0-row1")), + ReceiverEpochMarker(0), + ReceiverEpochMarker(0), + + ReceiverEpochMarker(1), + ReceiverRow(1, unsafeRow("writer1-row0")), + ReceiverEpochMarker(1), + ReceiverRow(1, unsafeRow("writer1-row1")), + ReceiverEpochMarker(1), + + ReceiverEpochMarker(2), + ReceiverEpochMarker(2), + ReceiverRow(2, unsafeRow("writer2-row0")), + ReceiverEpochMarker(2) + ) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet + assert(firstEpoch == Set("writer0-row0")) + + val secondEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet + assert(secondEpoch == Set("writer0-row1", "writer1-row0")) + + val thirdEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet + assert(thirdEpoch == Set("writer1-row1", "writer2-row0")) + } } From 3b20b34ab72c92d9d20188ed430955e1a94eac9c Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 25 May 2018 11:16:35 +0800 Subject: [PATCH 0873/2461] [SPARK-24367][SQL] Parquet: use JOB_SUMMARY_LEVEL instead of deprecated flag ENABLE_JOB_SUMMARY ## What changes were proposed in this pull request? In current parquet version,the conf ENABLE_JOB_SUMMARY is deprecated. When writing to Parquet files, the warning message ```WARN org.apache.parquet.hadoop.ParquetOutputFormat: Setting parquet.enable.summary-metadata is deprecated, please use parquet.summary.metadata.level``` keeps showing up. From https://github.com/apache/parquet-mr/blame/master/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ParquetOutputFormat.java#L164 we can see that we should use JOB_SUMMARY_LEVEL. ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21411 from gengliangwang/summaryLevel. --- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- .../datasources/parquet/ParquetFileFormat.scala | 10 ++++++---- .../datasources/parquet/ParquetCommitterSuite.scala | 7 ++++++- .../execution/datasources/parquet/ParquetIOSuite.scala | 2 +- .../parquet/ParquetPartitionDiscoverySuite.scala | 2 +- .../datasources/parquet/ParquetQuerySuite.scala | 2 +- .../sql/sources/ParquetHadoopFsRelationSuite.scala | 2 +- 7 files changed, 17 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 15ba10f604510..93d356fef07af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -395,7 +395,7 @@ object SQLConf { .doc("The output committer class used by Parquet. The specified class needs to be a " + "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + "of org.apache.parquet.hadoop.ParquetOutputCommitter. If it is not, then metadata summaries" + - "will never be created, irrespective of the value of parquet.enable.summary-metadata") + "will never be created, irrespective of the value of parquet.summary.metadata.level") .internal() .stringConf .createWithDefault("org.apache.parquet.hadoop.ParquetOutputCommitter") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index d1f9e11ed4225..60fc9ec7e1f82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -34,6 +34,7 @@ import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS import org.apache.parquet.hadoop._ +import org.apache.parquet.hadoop.ParquetOutputFormat.JobSummaryLevel import org.apache.parquet.hadoop.codec.CodecConfig import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType @@ -125,16 +126,17 @@ class ParquetFileFormat conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) // SPARK-15719: Disables writing Parquet summary files by default. - if (conf.get(ParquetOutputFormat.ENABLE_JOB_SUMMARY) == null) { - conf.setBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false) + if (conf.get(ParquetOutputFormat.JOB_SUMMARY_LEVEL) == null + && conf.get(ParquetOutputFormat.ENABLE_JOB_SUMMARY) == null) { + conf.setEnum(ParquetOutputFormat.JOB_SUMMARY_LEVEL, JobSummaryLevel.NONE) } - if (conf.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false) + if (ParquetOutputFormat.getJobSummaryLevel(conf) == JobSummaryLevel.NONE && !classOf[ParquetOutputCommitter].isAssignableFrom(committerClass)) { // output summary is requested, but the class is not a Parquet Committer logWarning(s"Committer $committerClass is not a ParquetOutputCommitter and cannot" + s" create job summaries. " + - s"Set Parquet option ${ParquetOutputFormat.ENABLE_JOB_SUMMARY} to false.") + s"Set Parquet option ${ParquetOutputFormat.JOB_SUMMARY_LEVEL} to NONE.") } new OutputWriterFactory { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala index f3ecc5ced689f..4b2437803d645 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala @@ -91,9 +91,14 @@ class ParquetCommitterSuite extends SparkFunSuite with SQLTestUtils summary: Boolean, check: Boolean): Option[FileStatus] = { var result: Option[FileStatus] = None + val summaryLevel = if (summary) { + "ALL" + } else { + "NONE" + } withSQLConf( SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key -> committer, - ParquetOutputFormat.ENABLE_JOB_SUMMARY -> summary.toString) { + ParquetOutputFormat.JOB_SUMMARY_LEVEL -> summaryLevel) { withTempPath { dest => val df = spark.createDataFrame(Seq((1, "4"), (2, "2"))) val destPath = new Path(dest.toURI) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 0b3e8ca060d87..002c42f23bd64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -543,7 +543,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val hadoopConf = spark.sessionState.newHadoopConfWithOptions(extraOptions) - withSQLConf(ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true") { + withSQLConf(ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL") { withTempPath { dir => val path = s"${dir.getCanonicalPath}/part-r-0.parquet" spark.range(1 << 16).selectExpr("(id % 4) AS i") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index e887c9734a8b8..9966ed94a8392 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -1014,7 +1014,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val path = dir.getCanonicalPath withSQLConf( - ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true", + ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL", "spark.sql.sources.commitProtocolClass" -> classOf[SQLHadoopMapReduceCommitProtocol].getCanonicalName) { spark.range(3).write.parquet(s"$path/p0=0/p1=0") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 2b1227faf48a0..dbf637783e6d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -275,7 +275,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext classOf[SQLHadoopMapReduceCommitProtocol].getCanonicalName, SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "true", - ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true" + ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL" ) { testSchemaMerging(2) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index dce5bb7ddba66..6858bbc441721 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -124,7 +124,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { test("SPARK-8604: Parquet data source should write summary file while doing appending") { withSQLConf( - ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true", + ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL", SQLConf.FILE_COMMIT_PROTOCOL_CLASS.key -> classOf[SQLHadoopMapReduceCommitProtocol].getCanonicalName) { withTempPath { dir => From 64fad0b519cf35b8c0a0dec18dd3df9488a5ed25 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 24 May 2018 21:38:04 -0700 Subject: [PATCH 0874/2461] [SPARK-24244][SPARK-24368][SQL] Passing only required columns to the CSV parser ## What changes were proposed in this pull request? uniVocity parser allows to specify only required column names or indexes for [parsing](https://www.univocity.com/pages/parsers-tutorial) like: ``` // Here we select only the columns by their indexes. // The parser just skips the values in other columns parserSettings.selectIndexes(4, 0, 1); CsvParser parser = new CsvParser(parserSettings); ``` In this PR, I propose to extract indexes from required schema and pass them into the CSV parser. Benchmarks show the following improvements in parsing of 1000 columns: ``` Select 100 columns out of 1000: x1.76 Select 1 column out of 1000: x2 ``` **Note**: Comparing to current implementation, the changes can return different result for malformed rows in the `DROPMALFORMED` and `FAILFAST` modes if only subset of all columns is requested. To have previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. ## How was this patch tested? It was tested by new test which selects 3 columns out of 15, by existing tests and by new benchmarks. Author: Maxim Gekk Author: Maxim Gekk Closes #21415 from MaxGekk/csv-column-pruning2. --- docs/sql-programming-guide.md | 1 + .../apache/spark/sql/internal/SQLConf.scala | 9 ++++ .../apache/spark/sql/DataFrameReader.scala | 1 + .../datasources/csv/CSVFileFormat.scala | 18 ++++++-- .../datasources/csv/CSVOptions.scala | 3 ++ .../datasources/csv/UnivocityParser.scala | 26 ++++++----- .../datasources/csv/CSVBenchmarks.scala | 46 +++++++++++++++++++ .../datasources/csv/CSVInferSchemaSuite.scala | 22 ++++----- .../execution/datasources/csv/CSVSuite.scala | 41 ++++++++++++++--- .../csv/UnivocityParserSuite.scala | 37 +++++++-------- 10 files changed, 152 insertions(+), 52 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index f1ed316341b95..fc26562ff33da 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1825,6 +1825,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. + - In version 2.3 and earlier, CSV rows are considered as malformed if at least one column value in the row is malformed. CSV parser dropped such rows in the DROPMALFORMED mode or outputs an error in the FAILFAST mode. Since Spark 2.4, CSV row is considered as malformed only when it contains malformed column values requested from CSV datasource, other values can be ignored. As an example, CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selection of the id column consists of a row with one column value 1234 but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 93d356fef07af..8d2320d8a6ed7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1307,6 +1307,13 @@ object SQLConf { object Replaced { val MAPREDUCE_JOB_REDUCES = "mapreduce.job.reduces" } + + val CSV_PARSER_COLUMN_PRUNING = buildConf("spark.sql.csv.parser.columnPruning.enabled") + .internal() + .doc("If it is set to true, column names of the requested schema are passed to CSV parser. " + + "Other column values can be ignored during parsing even if they are malformed.") + .booleanConf + .createWithDefault(true) } /** @@ -1664,6 +1671,8 @@ class SQLConf extends Serializable with Logging { def partitionOverwriteMode: PartitionOverwriteMode.Value = PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) + def csvColumnPruning: Boolean = getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 917f0cb221412..ac4580a0919ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -480,6 +480,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { def csv(csvDataset: Dataset[String]): DataFrame = { val parsedOptions: CSVOptions = new CSVOptions( extraOptions.toMap, + sparkSession.sessionState.conf.csvColumnPruning, sparkSession.sessionState.conf.sessionLocalTimeZone) val filteredLines: Dataset[String] = CSVUtils.filterCommentAndEmpty(csvDataset, parsedOptions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index e20977a4ec79f..21279d6daf7ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -41,8 +41,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], path: Path): Boolean = { - val parsedOptions = - new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + val parsedOptions = new CSVOptions( + options, + columnPruning = sparkSession.sessionState.conf.csvColumnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone) val csvDataSource = CSVDataSource(parsedOptions) csvDataSource.isSplitable && super.isSplitable(sparkSession, options, path) } @@ -51,8 +53,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - val parsedOptions = - new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + val parsedOptions = new CSVOptions( + options, + columnPruning = sparkSession.sessionState.conf.csvColumnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone) CSVDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions) } @@ -64,7 +68,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { dataSchema: StructType): OutputWriterFactory = { CSVUtils.verifySchema(dataSchema) val conf = job.getConfiguration - val csvOptions = new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + val csvOptions = new CSVOptions( + options, + columnPruning = sparkSession.sessionState.conf.csvColumnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone) csvOptions.compressionCodec.foreach { codec => CompressionCodecs.setCodecConfiguration(conf, codec) } @@ -97,6 +104,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { val parsedOptions = new CSVOptions( options, + sparkSession.sessionState.conf.csvColumnPruning, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 1066d156acd74..7119189a4e131 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -28,16 +28,19 @@ import org.apache.spark.sql.catalyst.util._ class CSVOptions( @transient val parameters: CaseInsensitiveMap[String], + val columnPruning: Boolean, defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String) extends Logging with Serializable { def this( parameters: Map[String, String], + columnPruning: Boolean, defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String = "") = { this( CaseInsensitiveMap(parameters), + columnPruning, defaultTimeZoneId, defaultColumnNameOfCorruptRecord) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 99557a1ceb0c8..4f00cc5eb3f39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -34,10 +34,10 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class UnivocityParser( - schema: StructType, + dataSchema: StructType, requiredSchema: StructType, val options: CSVOptions) extends Logging { - require(requiredSchema.toSet.subsetOf(schema.toSet), + require(requiredSchema.toSet.subsetOf(dataSchema.toSet), "requiredSchema should be the subset of schema.") def this(schema: StructType, options: CSVOptions) = this(schema, schema, options) @@ -45,9 +45,17 @@ class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any - private val tokenizer = new CsvParser(options.asParserSettings) + private val tokenizer = { + val parserSetting = options.asParserSettings + if (options.columnPruning && requiredSchema.length < dataSchema.length) { + val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))) + parserSetting.selectIndexes(tokenIndexArr: _*) + } + new CsvParser(parserSetting) + } + private val schema = if (options.columnPruning) requiredSchema else dataSchema - private val row = new GenericInternalRow(requiredSchema.length) + private val row = new GenericInternalRow(schema.length) // Retrieve the raw record string. private def getCurrentInput: UTF8String = { @@ -73,11 +81,8 @@ class UnivocityParser( // Each input token is placed in each output row's position by mapping these. In this case, // // output row - ["A", 2] - private val valueConverters: Array[ValueConverter] = + private val valueConverters: Array[ValueConverter] = { schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray - - private val tokenIndexArr: Array[Int] = { - requiredSchema.map(f => schema.indexOf(f)).toArray } /** @@ -210,9 +215,8 @@ class UnivocityParser( } else { try { var i = 0 - while (i < requiredSchema.length) { - val from = tokenIndexArr(i) - row(i) = valueConverters(from).apply(tokens(from)) + while (i < schema.length) { + row(i) = valueConverters(i).apply(tokens(i)) i += 1 } row diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala index d442ba7e59c61..1a3dacb8398e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala @@ -74,7 +74,53 @@ object CSVBenchmarks { } } + def multiColumnsBenchmark(rowsNum: Int): Unit = { + val colsNum = 1000 + val benchmark = new Benchmark(s"Wide rows with $colsNum columns", rowsNum) + + withTempPath { path => + val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) + val schema = StructType(fields) + val values = (0 until colsNum).map(i => i.toString).mkString(",") + val columnNames = schema.fieldNames + + spark.range(rowsNum) + .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*) + .write.option("header", true) + .csv(path.getAbsolutePath) + + val ds = spark.read.schema(schema).csv(path.getAbsolutePath) + + benchmark.addCase(s"Select $colsNum columns", 3) { _ => + ds.select("*").filter((row: Row) => true).count() + } + val cols100 = columnNames.take(100).map(Column(_)) + benchmark.addCase(s"Select 100 columns", 3) { _ => + ds.select(cols100: _*).filter((row: Row) => true).count() + } + benchmark.addCase(s"Select one column", 3) { _ => + ds.select($"col1").filter((row: Row) => true).count() + } + benchmark.addCase(s"count()", 3) { _ => + ds.count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + Wide rows with 1000 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + Select 1000 columns 81091 / 81692 0.0 81090.7 1.0X + Select 100 columns 30003 / 34448 0.0 30003.0 2.7X + Select one column 24792 / 24855 0.0 24792.0 3.3X + count() 24344 / 24642 0.0 24343.8 3.3X + */ + benchmark.run() + } + } + def main(args: Array[String]): Unit = { quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3) + multiColumnsBenchmark(rowsNum = 1000 * 1000) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index 661742087112f..842251be92c18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.types._ class CSVInferSchemaSuite extends SparkFunSuite { test("String fields types are inferred correctly from null types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(CSVInferSchema.inferField(NullType, "", options) == NullType) assert(CSVInferSchema.inferField(NullType, null, options) == NullType) assert(CSVInferSchema.inferField(NullType, "100000000000", options) == LongType) @@ -41,7 +41,7 @@ class CSVInferSchemaSuite extends SparkFunSuite { } test("String fields types are inferred correctly from other types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(CSVInferSchema.inferField(LongType, "1.0", options) == DoubleType) assert(CSVInferSchema.inferField(LongType, "test", options) == StringType) assert(CSVInferSchema.inferField(IntegerType, "1.0", options) == DoubleType) @@ -60,21 +60,21 @@ class CSVInferSchemaSuite extends SparkFunSuite { } test("Timestamp field types are inferred correctly via custom data format") { - var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm"), "GMT") + var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm"), false, "GMT") assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) - options = new CSVOptions(Map("timestampFormat" -> "yyyy"), "GMT") + options = new CSVOptions(Map("timestampFormat" -> "yyyy"), false, "GMT") assert(CSVInferSchema.inferField(TimestampType, "2015", options) == TimestampType) } test("Timestamp field types are inferred correctly from other types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14", options) == StringType) assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10", options) == StringType) assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00", options) == StringType) } test("Boolean fields types are inferred correctly from other types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(CSVInferSchema.inferField(LongType, "Fale", options) == StringType) assert(CSVInferSchema.inferField(DoubleType, "TRUEe", options) == StringType) } @@ -92,12 +92,12 @@ class CSVInferSchemaSuite extends SparkFunSuite { } test("Null fields are handled properly when a nullValue is specified") { - var options = new CSVOptions(Map("nullValue" -> "null"), "GMT") + var options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT") assert(CSVInferSchema.inferField(NullType, "null", options) == NullType) assert(CSVInferSchema.inferField(StringType, "null", options) == StringType) assert(CSVInferSchema.inferField(LongType, "null", options) == LongType) - options = new CSVOptions(Map("nullValue" -> "\\N"), "GMT") + options = new CSVOptions(Map("nullValue" -> "\\N"), false, "GMT") assert(CSVInferSchema.inferField(IntegerType, "\\N", options) == IntegerType) assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType) assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType) @@ -111,12 +111,12 @@ class CSVInferSchemaSuite extends SparkFunSuite { } test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { - val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm"), "GMT") + val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm"), false, "GMT") assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) } test("SPARK-18877: `inferField` on DecimalType should find a common type with `typeSoFar`") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") // 9.03E+12 is Decimal(3, -10) and 1.19E+11 is Decimal(3, -9). assert(CSVInferSchema.inferField(DecimalType(3, -10), "1.19E+11", options) == @@ -134,7 +134,7 @@ class CSVInferSchemaSuite extends SparkFunSuite { test("DoubleType should be infered when user defined nan/inf are provided") { val options = new CSVOptions(Map("nanValue" -> "nan", "negativeInf" -> "-inf", - "positiveInf" -> "inf"), "GMT") + "positiveInf" -> "inf"), false, "GMT") assert(CSVInferSchema.inferField(NullType, "nan", options) == DoubleType) assert(CSVInferSchema.inferField(NullType, "inf", options) == DoubleType) assert(CSVInferSchema.inferField(NullType, "-inf", options) == DoubleType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 2bac1a3e2d4c6..afe10bdc4de26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -260,14 +260,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } test("test for DROPMALFORMED parsing mode") { - Seq(false, true).foreach { multiLine => - val cars = spark.read - .format("csv") - .option("multiLine", multiLine) - .options(Map("header" -> "true", "mode" -> "dropmalformed")) - .load(testFile(carsFile)) + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "false") { + Seq(false, true).foreach { multiLine => + val cars = spark.read + .format("csv") + .option("multiLine", multiLine) + .options(Map("header" -> "true", "mode" -> "dropmalformed")) + .load(testFile(carsFile)) - assert(cars.select("year").collect().size === 2) + assert(cars.select("year").collect().size === 2) + } } } @@ -1383,4 +1385,29 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te checkAnswer(ds, Seq(Row(""" "a" """))) } + + test("SPARK-24244: Select a subset of all columns") { + withTempPath { path => + import collection.JavaConverters._ + val schema = new StructType() + .add("f1", IntegerType).add("f2", IntegerType).add("f3", IntegerType) + .add("f4", IntegerType).add("f5", IntegerType).add("f6", IntegerType) + .add("f7", IntegerType).add("f8", IntegerType).add("f9", IntegerType) + .add("f10", IntegerType).add("f11", IntegerType).add("f12", IntegerType) + .add("f13", IntegerType).add("f14", IntegerType).add("f15", IntegerType) + + val odf = spark.createDataFrame(List( + Row(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), + Row(-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15) + ).asJava, schema) + odf.write.csv(path.getCanonicalPath) + val idf = spark.read + .schema(schema) + .csv(path.getCanonicalPath) + .select('f15, 'f10, 'f5) + + assert(idf.count() == 2) + checkAnswer(idf, List(Row(15, 10, 5), Row(-15, -10, -5))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala index efbf73534bd19..458edb253fb33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala @@ -26,8 +26,9 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class UnivocityParserSuite extends SparkFunSuite { - private val parser = - new UnivocityParser(StructType(Seq.empty), new CSVOptions(Map.empty[String, String], "GMT")) + private val parser = new UnivocityParser( + StructType(Seq.empty), + new CSVOptions(Map.empty[String, String], false, "GMT")) private def assertNull(v: Any) = assert(v == null) @@ -38,7 +39,7 @@ class UnivocityParserSuite extends SparkFunSuite { stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) => val decimalValue = new BigDecimal(decimalVal.toString) - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(parser.makeConverter("_1", decimalType, options = options).apply(strVal) === Decimal(decimalValue, decimalType.precision, decimalType.scale)) } @@ -51,21 +52,21 @@ class UnivocityParserSuite extends SparkFunSuite { // Nullable field with nullValue option. types.foreach { t => // Tests that a custom nullValue. - val nullValueOptions = new CSVOptions(Map("nullValue" -> "-"), "GMT") + val nullValueOptions = new CSVOptions(Map("nullValue" -> "-"), false, "GMT") val converter = parser.makeConverter("_1", t, nullable = true, options = nullValueOptions) assertNull(converter.apply("-")) assertNull(converter.apply(null)) // Tests that the default nullValue is empty string. - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assertNull(parser.makeConverter("_1", t, nullable = true, options = options).apply("")) } // Not nullable field with nullValue option. types.foreach { t => // Casts a null to not nullable field should throw an exception. - val options = new CSVOptions(Map("nullValue" -> "-"), "GMT") + val options = new CSVOptions(Map("nullValue" -> "-"), false, "GMT") val converter = parser.makeConverter("_1", t, nullable = false, options = options) var message = intercept[RuntimeException] { @@ -81,7 +82,7 @@ class UnivocityParserSuite extends SparkFunSuite { // If nullValue is different with empty string, then, empty string should not be casted into // null. Seq(true, false).foreach { b => - val options = new CSVOptions(Map("nullValue" -> "null"), "GMT") + val options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT") val converter = parser.makeConverter("_1", StringType, nullable = b, options = options) assert(converter.apply("") == UTF8String.fromString("")) @@ -89,7 +90,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Throws exception for empty string with non null type") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") val exception = intercept[RuntimeException]{ parser.makeConverter("_1", IntegerType, nullable = false, options = options).apply("") } @@ -97,7 +98,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Types are cast correctly") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(parser.makeConverter("_1", ByteType, options = options).apply("10") == 10) assert(parser.makeConverter("_1", ShortType, options = options).apply("10") == 10) assert(parser.makeConverter("_1", IntegerType, options = options).apply("10") == 10) @@ -107,7 +108,7 @@ class UnivocityParserSuite extends SparkFunSuite { assert(parser.makeConverter("_1", BooleanType, options = options).apply("true") == true) val timestampsOptions = - new CSVOptions(Map("timestampFormat" -> "dd/MM/yyyy hh:mm"), "GMT") + new CSVOptions(Map("timestampFormat" -> "dd/MM/yyyy hh:mm"), false, "GMT") val customTimestamp = "31/01/2015 00:00" val expectedTime = timestampsOptions.timestampFormat.parse(customTimestamp).getTime val castedTimestamp = @@ -116,7 +117,7 @@ class UnivocityParserSuite extends SparkFunSuite { assert(castedTimestamp == expectedTime * 1000L) val customDate = "31/01/2015" - val dateOptions = new CSVOptions(Map("dateFormat" -> "dd/MM/yyyy"), "GMT") + val dateOptions = new CSVOptions(Map("dateFormat" -> "dd/MM/yyyy"), false, "GMT") val expectedDate = dateOptions.dateFormat.parse(customDate).getTime val castedDate = parser.makeConverter("_1", DateType, nullable = true, options = dateOptions) @@ -131,7 +132,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Throws exception for casting an invalid string to Float and Double Types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") val types = Seq(DoubleType, FloatType) val input = Seq("10u000", "abc", "1 2/3") types.foreach { dt => @@ -145,7 +146,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Float NaN values are parsed correctly") { - val options = new CSVOptions(Map("nanValue" -> "nn"), "GMT") + val options = new CSVOptions(Map("nanValue" -> "nn"), false, "GMT") val floatVal: Float = parser.makeConverter( "_1", FloatType, nullable = true, options = options ).apply("nn").asInstanceOf[Float] @@ -156,7 +157,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Double NaN values are parsed correctly") { - val options = new CSVOptions(Map("nanValue" -> "-"), "GMT") + val options = new CSVOptions(Map("nanValue" -> "-"), false, "GMT") val doubleVal: Double = parser.makeConverter( "_1", DoubleType, nullable = true, options = options ).apply("-").asInstanceOf[Double] @@ -165,14 +166,14 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Float infinite values can be parsed") { - val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), "GMT") + val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT") val floatVal1 = parser.makeConverter( "_1", FloatType, nullable = true, options = negativeInfOptions ).apply("max").asInstanceOf[Float] assert(floatVal1 == Float.NegativeInfinity) - val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), "GMT") + val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT") val floatVal2 = parser.makeConverter( "_1", FloatType, nullable = true, options = positiveInfOptions ).apply("max").asInstanceOf[Float] @@ -181,14 +182,14 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Double infinite values can be parsed") { - val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), "GMT") + val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT") val doubleVal1 = parser.makeConverter( "_1", DoubleType, nullable = true, options = negativeInfOptions ).apply("max").asInstanceOf[Double] assert(doubleVal1 == Double.NegativeInfinity) - val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), "GMT") + val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT") val doubleVal2 = parser.makeConverter( "_1", DoubleType, nullable = true, options = positiveInfOptions ).apply("max").asInstanceOf[Double] From fd315f5884c03c6dd21abca178897584dee83f1a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 25 May 2018 12:49:06 -0700 Subject: [PATCH 0875/2461] [MINOR] Add port SSL config in toString and scaladoc ## What changes were proposed in this pull request? SPARK-17874 introduced a new configuration to set the port where SSL services bind to. We missed to update the scaladoc and the `toString` method, though. The PR adds it in the missing places ## How was this patch tested? checked the `toString` output in the logs Author: Marco Gaido Closes #21429 from mgaido91/minor_ssl. --- core/src/main/scala/org/apache/spark/SSLOptions.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 477b01968c6ef..04c38f12acc78 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -128,7 +128,7 @@ private[spark] case class SSLOptions( } /** Returns a string representation of this SSLOptions with all the passwords masked. */ - override def toString: String = s"SSLOptions{enabled=$enabled, " + + override def toString: String = s"SSLOptions{enabled=$enabled, port=$port, " + s"keyStore=$keyStore, keyStorePassword=${keyStorePassword.map(_ => "xxx")}, " + s"trustStore=$trustStore, trustStorePassword=${trustStorePassword.map(_ => "xxx")}, " + s"protocol=$protocol, enabledAlgorithms=$enabledAlgorithms}" @@ -142,6 +142,7 @@ private[spark] object SSLOptions extends Logging { * * The following settings are allowed: * $ - `[ns].enabled` - `true` or `false`, to enable or disable SSL respectively + * $ - `[ns].port` - the port where to bind the SSL server * $ - `[ns].keyStore` - a path to the key-store file; can be relative to the current directory * $ - `[ns].keyStorePassword` - a password to the key-store file * $ - `[ns].keyPassword` - a password to the private key From 1b1528a504febfadf6fe41fd72e657689da50525 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Fri, 25 May 2018 15:42:46 -0700 Subject: [PATCH 0876/2461] [SPARK-24366][SQL] Improving of error messages for type converting ## What changes were proposed in this pull request? Currently, users are getting the following error messages on type conversions: ``` scala.MatchError: test (of class java.lang.String) ``` The message doesn't give any clues to the users where in the schema the error happened. In this PR, I would like to improve the error message like: ``` The value (test) of the type (java.lang.String) cannot be converted to struct ``` ## How was this patch tested? Added tests for converting of wrong values to `struct`, `map`, `array`, `string` and `decimal`. Author: Maxim Gekk Closes #21410 from MaxGekk/type-conv-error. --- .../sql/catalyst/CatalystTypeConverters.scala | 16 +++++++ .../CatalystTypeConvertersSuite.scala | 45 +++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 474ec592201d9..9e9105a157abe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -170,6 +170,9 @@ object CatalystTypeConverters { convertedIterable += elementConverter.toCatalyst(item) } new GenericArrayData(convertedIterable.toArray) + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to an array of ${elementType.catalogString}") } } @@ -206,6 +209,10 @@ object CatalystTypeConverters { scalaValue match { case map: Map[_, _] => ArrayBasedMapData(map, keyFunction, valueFunction) case javaMap: JavaMap[_, _] => ArrayBasedMapData(javaMap, keyFunction, valueFunction) + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + "cannot be converted to a map type with " + + s"key type (${keyType.catalogString}) and value type (${valueType.catalogString})") } } @@ -252,6 +259,9 @@ object CatalystTypeConverters { idx += 1 } new GenericInternalRow(ar) + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to ${structType.catalogString}") } override def toScala(row: InternalRow): Row = { @@ -276,6 +286,9 @@ object CatalystTypeConverters { override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match { case str: String => UTF8String.fromString(str) case utf8: UTF8String => utf8 + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to the string type") } override def toScala(catalystValue: UTF8String): String = if (catalystValue == null) null else catalystValue.toString @@ -309,6 +322,9 @@ object CatalystTypeConverters { case d: JavaBigDecimal => Decimal(d) case d: JavaBigInteger => Decimal(d) case d: Decimal => d + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to ${dataType.catalogString}") } decimal.toPrecision(dataType.precision, dataType.scale) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala index f3702ec92b425..f99af9b84d959 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -94,4 +94,49 @@ class CatalystTypeConvertersSuite extends SparkFunSuite { assert(CatalystTypeConverters.createToCatalystConverter(doubleArrayType)(doubleArray) == doubleGenericArray) } + + test("converting a wrong value to the struct type") { + val structType = new StructType().add("f1", IntegerType) + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(structType)("test") + } + assert(exception.getMessage.contains("The value (test) of the type " + + "(java.lang.String) cannot be converted to struct")) + } + + test("converting a wrong value to the map type") { + val mapType = MapType(StringType, IntegerType, false) + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(mapType)("test") + } + assert(exception.getMessage.contains("The value (test) of the type " + + "(java.lang.String) cannot be converted to a map type with key " + + "type (string) and value type (int)")) + } + + test("converting a wrong value to the array type") { + val arrayType = ArrayType(IntegerType, true) + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(arrayType)("test") + } + assert(exception.getMessage.contains("The value (test) of the type " + + "(java.lang.String) cannot be converted to an array of int")) + } + + test("converting a wrong value to the decimal type") { + val decimalType = DecimalType(10, 0) + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(decimalType)("test") + } + assert(exception.getMessage.contains("The value (test) of the type " + + "(java.lang.String) cannot be converted to decimal(10,0)")) + } + + test("converting a wrong value to the string type") { + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(StringType)(0.1) + } + assert(exception.getMessage.contains("The value (0.1) of the type " + + "(java.lang.Double) cannot be converted to the string type")) + } } From ed1a65448f228776afe2e5c6b1ac4228d2ed2854 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 26 May 2018 20:26:00 +0800 Subject: [PATCH 0877/2461] [SPARK-19112][CORE][FOLLOW-UP] Add missing shortCompressionCodecNames to configuration. ## What changes were proposed in this pull request? Spark provides four codecs: `lz4`, `lzf`, `snappy`, and `zstd`. This pr add missing shortCompressionCodecNames to configuration. ## How was this patch tested? manually tested Author: Yuming Wang Closes #21431 from wangyum/SPARK-19112. --- docs/configuration.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index fd2670cba2125..64af0e98a82f5 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -910,8 +910,8 @@ Apart from these, the following properties are also available, and may be useful lz4 The codec used to compress internal data such as RDD partitions, event log, broadcast variables - and shuffle outputs. By default, Spark provides three codecs: lz4, lzf, - and snappy. You can also use fully qualified class names to specify the codec, + and shuffle outputs. By default, Spark provides four codecs: lz4, lzf, + snappy, and zstd. You can also use fully qualified class names to specify the codec, e.g. org.apache.spark.io.LZ4CompressionCodec, org.apache.spark.io.LZFCompressionCodec, From d440699192f21b14dfb8ec0dc5673537e1003b55 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Sat, 26 May 2018 20:42:23 -0700 Subject: [PATCH 0878/2461] [SPARK-24381][TESTING] Add unit tests for NOT IN subquery around null values ## What changes were proposed in this pull request? This PR adds several unit tests along the `cols NOT IN (subquery)` pathway. There are a scattering of tests here and there which cover this codepath, but there doesn't seem to be a unified unit test of the correctness of null-aware anti joins anywhere. I have also added a brief explanation of how this expression behaves in SubquerySuite. Lastly, I made some clarifying changes in the NOT IN pathway in RewritePredicateSubquery. ## How was this patch tested? Added unit tests! There should be no behavioral change in this PR. Author: Miles Yucht Closes #21425 from mgyucht/spark-24381. --- .../sql/catalyst/optimizer/subquery.scala | 9 +- ...not-in-unit-tests-multi-column-literal.sql | 39 +++++ .../not-in-unit-tests-multi-column.sql | 98 ++++++++++++ ...ot-in-unit-tests-single-column-literal.sql | 42 +++++ .../not-in-unit-tests-single-column.sql | 123 +++++++++++++++ ...in-unit-tests-multi-column-literal.sql.out | 54 +++++++ .../not-in-unit-tests-multi-column.sql.out | 134 ++++++++++++++++ ...n-unit-tests-single-column-literal.sql.out | 69 ++++++++ .../not-in-unit-tests-single-column.sql.out | 149 ++++++++++++++++++ .../typeCoercion/native/concat.sql.out | 2 +- 10 files changed, 714 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql create mode 100644 sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column.sql create mode 100644 sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql create mode 100644 sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 709db6d8bec7d..de89e17e51f1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -116,15 +116,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // (a1,a2,...) = (b1,b2,...) // to // (a1=b1 OR isnull(a1=b1)) AND (a2=b2 OR isnull(a2=b2)) AND ... - val joinConds = splitConjunctivePredicates(joinCond.get) + val baseJoinConds = splitConjunctivePredicates(joinCond.get) + val nullAwareJoinConds = baseJoinConds.map(c => Or(c, IsNull(c))) // After that, add back the correlated join predicate(s) in the subquery // Example: // SELECT ... FROM A WHERE A.A1 NOT IN (SELECT B.B1 FROM B WHERE B.B2 = A.A2 AND B.B3 > 1) // will have the final conditions in the LEFT ANTI as - // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) - val pairs = (joinConds.map(c => Or(c, IsNull(c))) ++ conditions).reduceLeft(And) + // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 1 + val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And) // Deduplicate conflicting attributes if any. - dedupJoin(Join(outerPlan, sub, LeftAnti, Option(pairs))) + dedupJoin(Join(outerPlan, sub, LeftAnti, Option(finalJoinCond))) case (p, predicate) => val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p) Project(p.output, Filter(newCond.get, inputPlan)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql new file mode 100644 index 0000000000000..8eea84f4f5272 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql @@ -0,0 +1,39 @@ +-- Unit tests for simple NOT IN predicate subquery across multiple columns. +-- +-- See not-in-single-column-unit-tests.sql for an introduction. +-- This file has the same test cases as not-in-unit-tests-multi-column.sql with literals instead of +-- subqueries. Small changes have been made to the literals to make them typecheck. + +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, null), + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b); + +-- Case 1 (not possible to write a literal with no rows, so we ignore it.) +-- (subquery is empty -> row is returned) + +-- Cases 2, 3 and 4 are currently broken, so I have commented them out here. +-- Filed https://issues.apache.org/jira/browse/SPARK-24395 to fix and restore these test cases. + + -- Case 5 + -- (one null column with no match -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN ((2, 3.0)); + + -- Case 6 + -- (no null columns with match -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Matches (2, 3.0) + AND (a, b) NOT IN ((2, 3.0)); + + -- Case 7 + -- (no null columns with no match -> row is returned) +SELECT * +FROM m +WHERE b = 5.0 -- Matches (4, 5.0) + AND (a, b) NOT IN ((2, 3.0)); diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column.sql new file mode 100644 index 0000000000000..9f8dc7fca3b94 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column.sql @@ -0,0 +1,98 @@ +-- Unit tests for simple NOT IN predicate subquery across multiple columns. +-- +-- See not-in-single-column-unit-tests.sql for an introduction. +-- +-- Test cases for multi-column ``WHERE a NOT IN (SELECT c FROM r ...)'': +-- | # | does subquery include null? | do filter columns contain null? | a = c? | b = d? | row included in result? | +-- | 1 | empty | * | * | * | yes | +-- | 2 | 1+ row has null for all columns | * | * | * | no | +-- | 3 | no row has null for all columns | (yes, yes) | * | * | no | +-- | 4 | no row has null for all columns | (no, yes) | yes | * | no | +-- | 5 | no row has null for all columns | (no, yes) | no | * | yes | +-- | 6 | no | (no, no) | yes | yes | no | +-- | 7 | no | (no, no) | _ | _ | yes | +-- +-- This can be generalized to include more tests for more columns, but it covers the main cases +-- when there is more than one column. + +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, null), + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b); + +CREATE TEMPORARY VIEW s AS SELECT * FROM VALUES + (null, null), + (0, 1.0), + (2, 3.0), + (4, null) + AS s(c, d); + + -- Case 1 + -- (subquery is empty -> row is returned) +SELECT * +FROM m +WHERE (a, b) NOT IN (SELECT * + FROM s + WHERE d > 5.0) -- Matches no rows +; + + -- Case 2 + -- (subquery contains a row with null in all columns -> row not returned) +SELECT * +FROM m +WHERE (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NULL AND d IS NULL) -- Matches only (null, null) +; + + -- Case 3 + -- (probe-side columns are all null -> row not returned) +SELECT * +FROM m +WHERE a IS NULL AND b IS NULL -- Matches only (null, null) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NOT NULL) -- Matches (0, 1.0), (2, 3.0), (4, null) +; + + -- Case 4 + -- (one column null, other column matches a row in the subquery result -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NOT NULL) -- Matches (0, 1.0), (2, 3.0), (4, null) +; + + -- Case 5 + -- (one null column with no match -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +; + + -- Case 6 + -- (no null columns with match -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Matches (2, 3.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +; + + -- Case 7 + -- (no null columns with no match -> row is returned) +SELECT * +FROM m +WHERE b = 5.0 -- Matches (4, 5.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql new file mode 100644 index 0000000000000..b261363d1dde7 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql @@ -0,0 +1,42 @@ +-- Unit tests for simple NOT IN with a literal expression of a single column +-- +-- More information can be found in not-in-unit-tests-single-column.sql. +-- This file has the same test cases as not-in-unit-tests-single-column.sql with literals instead of +-- subqueries. + +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b); + + -- Uncorrelated NOT IN Subquery test cases + -- Case 1 (not possible to write a literal with no rows, so we ignore it.) + -- (empty subquery -> all rows returned) + + -- Case 2 + -- (subquery includes null -> no rows returned) +SELECT * +FROM m +WHERE a NOT IN (null); + + -- Case 3 + -- (probe column is null -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (2); + + -- Case 4 + -- (probe column matches subquery row -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (2); + + -- Case 5 + -- (probe column does not match subquery row -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (6); diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column.sql new file mode 100644 index 0000000000000..2cc08e10acf67 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column.sql @@ -0,0 +1,123 @@ +-- Unit tests for simple NOT IN predicate subquery across a single column. +-- +-- ``col NOT IN expr'' is quite difficult to reason about. There are many edge cases, some of the +-- rules are confusing to the uninitiated, and precedence and treatment of null values is plain +-- unintuitive. To make this simpler to understand, I've come up with a plain English way of +-- describing the expected behavior of this query. +-- +-- - If the subquery is empty (i.e. returns no rows), the row should be returned, regardless of +-- whether the filtered columns include nulls. +-- - If the subquery contains a result with all columns null, then the row should not be returned. +-- - If for all non-null filter columns there exists a row in the subquery in which each column +-- either +-- 1. is equal to the corresponding filter column or +-- 2. is null +-- then the row should not be returned. (This includes the case where all filter columns are +-- null.) +-- - Otherwise, the row should be returned. +-- +-- Using these rules, we can come up with a set of test cases for single-column and multi-column +-- NOT IN test cases. +-- +-- Test cases for single-column ``WHERE a NOT IN (SELECT c FROM r ...)'': +-- | # | does subquery include null? | is a null? | a = c? | row with a included in result? | +-- | 1 | empty | | | yes | +-- | 2 | yes | | | no | +-- | 3 | no | yes | | no | +-- | 4 | no | no | yes | no | +-- | 5 | no | no | no | yes | +-- +-- There are also some considerations around correlated subqueries. Correlated subqueries can +-- cause cases 2, 3, or 4 to be reduced to case 1 by limiting the number of rows returned by the +-- subquery, so the row from the parent table should always be included in the output. + +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b); + +CREATE TEMPORARY VIEW s AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (6, 7.0) + AS s(c, d); + + -- Uncorrelated NOT IN Subquery test cases + -- Case 1 + -- (empty subquery -> all rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d > 10.0) -- (empty subquery) +; + + -- Case 2 + -- (subquery includes null -> no rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d = 1.0) -- Only matches (null, 1.0) +; + + -- Case 3 + -- (probe column is null -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 3.0) -- Matches (2, 3.0) +; + + -- Case 4 + -- (probe column matches subquery row -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 3.0) -- Matches (2, 3.0) +; + + -- Case 5 + -- (probe column does not match subquery row -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 7.0) -- Matches (6, 7.0) +; + + -- Correlated NOT IN subquery test cases + -- Case 2->1 + -- (subquery had nulls but they are removed by correlated subquery -> all rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +; + + -- Case 3->1 + -- (probe column is null but subquery returns no rows -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +; + + -- Case 4->1 + -- (probe column matches row which is filtered out by correlated subquery -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +; diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out new file mode 100644 index 0000000000000..a16e98af9a417 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out @@ -0,0 +1,54 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 4 + + +-- !query 0 +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, null), + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +-- Case 5 + -- (one null column with no match -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN ((2, 3.0)) +-- !query 1 schema +struct +-- !query 1 output +NULL 1 + + +-- !query 2 +-- Case 6 + -- (no null columns with match -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Matches (2, 3.0) + AND (a, b) NOT IN ((2, 3.0)) +-- !query 2 schema +struct +-- !query 2 output + + + +-- !query 3 +-- Case 7 + -- (no null columns with no match -> row is returned) +SELECT * +FROM m +WHERE b = 5.0 -- Matches (4, 5.0) + AND (a, b) NOT IN ((2, 3.0)) +-- !query 3 schema +struct +-- !query 3 output +4 5 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column.sql.out new file mode 100644 index 0000000000000..aa5f64b8ebf55 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column.sql.out @@ -0,0 +1,134 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 9 + + +-- !query 0 +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, null), + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW s AS SELECT * FROM VALUES + (null, null), + (0, 1.0), + (2, 3.0), + (4, null) + AS s(c, d) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +-- Case 1 + -- (subquery is empty -> row is returned) +SELECT * +FROM m +WHERE (a, b) NOT IN (SELECT * + FROM s + WHERE d > 5.0) -- Matches no rows +-- !query 2 schema +struct +-- !query 2 output +2 3 +4 5 +NULL 1 +NULL NULL + + +-- !query 3 +-- Case 2 + -- (subquery contains a row with null in all columns -> row not returned) +SELECT * +FROM m +WHERE (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NULL AND d IS NULL) -- Matches only (null, null) +-- !query 3 schema +struct +-- !query 3 output + + + +-- !query 4 +-- Case 3 + -- (probe-side columns are all null -> row not returned) +SELECT * +FROM m +WHERE a IS NULL AND b IS NULL -- Matches only (null, null) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NOT NULL) -- Matches (0, 1.0), (2, 3.0), (4, null) +-- !query 4 schema +struct +-- !query 4 output + + + +-- !query 5 +-- Case 4 + -- (one column null, other column matches a row in the subquery result -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NOT NULL) -- Matches (0, 1.0), (2, 3.0), (4, null) +-- !query 5 schema +struct +-- !query 5 output + + + +-- !query 6 +-- Case 5 + -- (one null column with no match -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +-- !query 6 schema +struct +-- !query 6 output +NULL 1 + + +-- !query 7 +-- Case 6 + -- (no null columns with match -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Matches (2, 3.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +-- !query 7 schema +struct +-- !query 7 output + + + +-- !query 8 +-- Case 7 + -- (no null columns with no match -> row is returned) +SELECT * +FROM m +WHERE b = 5.0 -- Matches (4, 5.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +-- !query 8 schema +struct +-- !query 8 output +4 5 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql.out new file mode 100644 index 0000000000000..446447e890449 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql.out @@ -0,0 +1,69 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 5 + + +-- !query 0 +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +-- Uncorrelated NOT IN Subquery test cases + -- Case 1 (not possible to write a literal with no rows, so we ignore it.) + -- (empty subquery -> all rows returned) + + -- Case 2 + -- (subquery includes null -> no rows returned) +SELECT * +FROM m +WHERE a NOT IN (null) +-- !query 1 schema +struct +-- !query 1 output + + + +-- !query 2 +-- Case 3 + -- (probe column is null -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (2) +-- !query 2 schema +struct +-- !query 2 output + + + +-- !query 3 +-- Case 4 + -- (probe column matches subquery row -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (2) +-- !query 3 schema +struct +-- !query 3 output + + + +-- !query 4 +-- Case 5 + -- (probe column does not match subquery row -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (6) +-- !query 4 schema +struct +-- !query 4 output +2 3 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column.sql.out new file mode 100644 index 0000000000000..f58ebeacc2872 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column.sql.out @@ -0,0 +1,149 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 10 + + +-- !query 0 +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW s AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (6, 7.0) + AS s(c, d) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +-- Uncorrelated NOT IN Subquery test cases + -- Case 1 + -- (empty subquery -> all rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d > 10.0) -- (empty subquery) +-- !query 2 schema +struct +-- !query 2 output +2 3 +4 5 +NULL 1 + + +-- !query 3 +-- Case 2 + -- (subquery includes null -> no rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d = 1.0) -- Only matches (null, 1.0) +-- !query 3 schema +struct +-- !query 3 output + + + +-- !query 4 +-- Case 3 + -- (probe column is null -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 3.0) -- Matches (2, 3.0) +-- !query 4 schema +struct +-- !query 4 output + + + +-- !query 5 +-- Case 4 + -- (probe column matches subquery row -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 3.0) -- Matches (2, 3.0) +-- !query 5 schema +struct +-- !query 5 output + + + +-- !query 6 +-- Case 5 + -- (probe column does not match subquery row -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 7.0) -- Matches (6, 7.0) +-- !query 6 schema +struct +-- !query 6 output +2 3 + + +-- !query 7 +-- Correlated NOT IN subquery test cases + -- Case 2->1 + -- (subquery had nulls but they are removed by correlated subquery -> all rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +-- !query 7 schema +struct +-- !query 7 output +2 3 +4 5 +NULL 1 + + +-- !query 8 +-- Case 3->1 + -- (probe column is null but subquery returns no rows -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +-- !query 8 schema +struct +-- !query 8 output +NULL 1 + + +-- !query 9 +-- Case 4->1 + -- (probe column matches row which is filtered out by correlated subquery -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +-- !query 9 schema +struct +-- !query 9 output +2 3 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out index 62befc5ca0f15..be637b66abc86 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 11 +-- Number of queries: 14 -- !query 0 From 672209f2909a95e891f3c779bfb2f0e534239851 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 28 May 2018 10:50:17 +0800 Subject: [PATCH 0879/2461] [SPARK-24334] Fix race condition in ArrowPythonRunner causes unclean shutdown of Arrow memory allocator ## What changes were proposed in this pull request? There is a race condition of closing Arrow VectorSchemaRoot and Allocator in the writer thread of ArrowPythonRunner. The race results in memory leak exception when closing the allocator. This patch removes the closing routine from the TaskCompletionListener and make the writer thread responsible for cleaning up the Arrow memory. This issue be reproduced by this test: ``` def test_memory_leak(self): from pyspark.sql.functions import pandas_udf, col, PandasUDFType, array, lit, explode # Have all data in a single executor thread so it can trigger the race condition easier with self.sql_conf({'spark.sql.shuffle.partitions': 1}): df = self.spark.range(0, 1000) df = df.withColumn('id', array([lit(i) for i in range(0, 300)])) \ .withColumn('id', explode(col('id'))) \ .withColumn('v', array([lit(i) for i in range(0, 1000)])) pandas_udf(df.schema, PandasUDFType.GROUPED_MAP) def foo(pdf): xxx return pdf result = df.groupby('id').apply(foo) with QuietTest(self.sc): with self.assertRaises(py4j.protocol.Py4JJavaError) as context: result.count() self.assertTrue('Memory leaked' not in str(context.exception)) ``` Note: Because of the race condition, the test case cannot reproduce the issue reliably so it's not added to test cases. ## How was this patch tested? Because of the race condition, the bug cannot be unit test easily. So far it has only happens on large amount of data. This is currently tested manually. Author: Li Jin Closes #21397 from icexelloss/SPARK-24334-arrow-memory-leak. --- .../execution/python/ArrowPythonRunner.scala | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 5fcdcddca7d51..01e19bddbfb66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -70,19 +70,13 @@ class ArrowPythonRunner( val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"stdout writer for $pythonExec", 0, Long.MaxValue) - val root = VectorSchemaRoot.create(arrowSchema, allocator) - val arrowWriter = ArrowWriter.create(root) - - context.addTaskCompletionListener { _ => - root.close() - allocator.close() - } - - val writer = new ArrowStreamWriter(root, null, dataOut) - writer.start() Utils.tryWithSafeFinally { + val arrowWriter = ArrowWriter.create(root) + val writer = new ArrowStreamWriter(root, null, dataOut) + writer.start() + while (inputIterator.hasNext) { val nextBatch = inputIterator.next() @@ -94,8 +88,21 @@ class ArrowPythonRunner( writer.writeBatch() arrowWriter.reset() } - } { + // end writes footer to the output stream and doesn't clean any resources. + // It could throw exception if the output stream is closed, so it should be + // in the try block. writer.end() + } { + // If we close root and allocator in TaskCompletionListener, there could be a race + // condition where the writer thread keeps writing to the VectorSchemaRoot while + // it's being closed by the TaskCompletion listener. + // Closing root and allocator here is cleaner because root and allocator is owned + // by the writer thread and is only visible to the writer thread. + // + // If the writer thread is interrupted by TaskCompletionListener, it should either + // (1) in the try block, in which case it will get an InterruptedException when + // performing io, and goes into the finally block or (2) in the finally block, + // in which case it will ignore the interruption and close the resources. root.close() allocator.close() } From de01a8d50c9c3e196591db057d544f5d7b24d95f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 28 May 2018 12:09:44 +0800 Subject: [PATCH 0880/2461] [SPARK-24373][SQL] Add AnalysisBarrier to RelationalGroupedDataset's and KeyValueGroupedDataset's child ## What changes were proposed in this pull request? When we create a `RelationalGroupedDataset` or a `KeyValueGroupedDataset` we set its child to the `logicalPlan` of the `DataFrame` we need to aggregate. Since the `logicalPlan` is already analyzed, we should not analyze it again. But this happens when the new plan of the aggregate is analyzed. The current behavior in most of the cases is likely to produce no harm, but in other cases re-analyzing an analyzed plan can change it, since the analysis is not idempotent. This can cause issues like the one described in the JIRA (missing to find a cached plan). The PR adds an `AnalysisBarrier` to the `logicalPlan` which is used as child of `RelationalGroupedDataset` or a `KeyValueGroupedDataset`. ## How was this patch tested? added UT Author: Marco Gaido Closes #21432 from mgaido91/SPARK-24373. --- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../spark/sql/KeyValueGroupedDataset.scala | 2 +- .../spark/sql/RelationalGroupedDataset.scala | 12 +-- .../spark/sql/GroupedDatasetSuite.scala | 96 +++++++++++++++++++ 4 files changed, 104 insertions(+), 8 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 32267eb0300f5..abb5ae53f4d73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -196,7 +196,7 @@ class Dataset[T] private[sql]( } // Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again. - @transient private val planWithBarrier = AnalysisBarrier(logicalPlan) + @transient private[sql] val planWithBarrier = AnalysisBarrier(logicalPlan) /** * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 6bab21dca0cbd..36f6038aa9485 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -49,7 +49,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( private implicit val kExprEnc = encoderFor(kEncoder) private implicit val vExprEnc = encoderFor(vEncoder) - private def logicalPlan = queryExecution.analyzed + private def logicalPlan = AnalysisBarrier(queryExecution.analyzed) private def sparkSession = queryExecution.sparkSession /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 6c2be3610ae30..c6449cd5a16b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -63,17 +63,17 @@ class RelationalGroupedDataset protected[sql]( groupType match { case RelationalGroupedDataset.GroupByType => Dataset.ofRows( - df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) + df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.planWithBarrier)) case RelationalGroupedDataset.RollupType => Dataset.ofRows( - df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) + df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.planWithBarrier)) case RelationalGroupedDataset.CubeType => Dataset.ofRows( - df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)) + df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.planWithBarrier)) case RelationalGroupedDataset.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) Dataset.ofRows( - df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.logicalPlan)) + df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.planWithBarrier)) } } @@ -433,7 +433,7 @@ class RelationalGroupedDataset protected[sql]( df.exprEnc.schema, groupingAttributes, df.logicalPlan.output, - df.logicalPlan)) + df.planWithBarrier)) } /** @@ -459,7 +459,7 @@ class RelationalGroupedDataset protected[sql]( case other => Alias(other, other.toString)() } val groupingAttributes = groupingNamedExpressions.map(_.toAttribute) - val child = df.logicalPlan + val child = df.planWithBarrier val project = Project(groupingNamedExpressions ++ child.output, child) val output = expr.dataType.asInstanceOf[StructType].toAttributes val plan = FlatMapGroupsInPandas(groupingAttributes, expr, output, project) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala new file mode 100644 index 0000000000000..147c0b61f5017 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.catalyst.plans.logical.AnalysisBarrier +import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{LongType, StructField, StructType} + +class GroupedDatasetSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + private val scalaUDF = udf((x: Long) => { x + 1 }) + private lazy val datasetWithUDF = spark.range(1).toDF("s").select($"s", scalaUDF($"s")) + + private def assertContainsAnalysisBarrier(ds: Dataset[_], atLevel: Int = 1): Unit = { + assert(atLevel >= 0) + var children = Seq(ds.queryExecution.logical) + (1 to atLevel).foreach { _ => + children = children.flatMap(_.children) + } + val barriers = children.collect { + case ab: AnalysisBarrier => ab + } + assert(barriers.nonEmpty, s"Plan does not contain AnalysisBarrier at level $atLevel:\n" + + ds.queryExecution.logical) + } + + test("SPARK-24373: avoid running Analyzer rules twice on RelationalGroupedDataset") { + val groupByDataset = datasetWithUDF.groupBy() + val rollupDataset = datasetWithUDF.rollup("s") + val cubeDataset = datasetWithUDF.cube("s") + val pivotDataset = datasetWithUDF.groupBy().pivot("s", Seq(1, 2)) + datasetWithUDF.cache() + Seq(groupByDataset, rollupDataset, cubeDataset, pivotDataset).foreach { rgDS => + val df = rgDS.count() + assertContainsAnalysisBarrier(df) + assertCached(df) + } + + val flatMapGroupsInRDF = datasetWithUDF.groupBy().flatMapGroupsInR( + Array.emptyByteArray, + Array.emptyByteArray, + Array.empty, + StructType(Seq(StructField("s", LongType)))) + val flatMapGroupsInPandasDF = datasetWithUDF.groupBy().flatMapGroupsInPandas(PythonUDF( + "pyUDF", + null, + StructType(Seq(StructField("s", LongType))), + Seq.empty, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + true)) + Seq(flatMapGroupsInRDF, flatMapGroupsInPandasDF).foreach { df => + assertContainsAnalysisBarrier(df, 2) + assertCached(df) + } + datasetWithUDF.unpersist(true) + } + + test("SPARK-24373: avoid running Analyzer rules twice on KeyValueGroupedDataset") { + val kvDasaset = datasetWithUDF.groupByKey(_.getLong(0)) + datasetWithUDF.cache() + val mapValuesKVDataset = kvDasaset.mapValues(_.getLong(0)).reduceGroups(_ + _) + val keysKVDataset = kvDasaset.keys + val flatMapGroupsKVDataset = kvDasaset.flatMapGroups((k, _) => Seq(k)) + val aggKVDataset = kvDasaset.count() + val otherKVDataset = spark.range(1).groupByKey(_ + 1) + val cogroupKVDataset = kvDasaset.cogroup(otherKVDataset)((k, _, _) => Seq(k)) + Seq((mapValuesKVDataset, 1), + (keysKVDataset, 2), + (flatMapGroupsKVDataset, 2), + (aggKVDataset, 1), + (cogroupKVDataset, 2)).foreach { case (df, analysisBarrierDepth) => + assertContainsAnalysisBarrier(df, analysisBarrierDepth) + assertCached(df) + } + datasetWithUDF.unpersist(true) + } +} From fa2ae9d2019f839647d17932d8fea769e7622777 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 28 May 2018 12:56:05 +0800 Subject: [PATCH 0881/2461] [SPARK-24392][PYTHON] Label pandas_udf as Experimental ## What changes were proposed in this pull request? The pandas_udf functionality was introduced in 2.3.0, but is not completely stable and still evolving. This adds a label to indicate it is still an experimental API. ## How was this patch tested? NA Author: Bryan Cutler Closes #21435 from BryanCutler/arrow-pandas_udf-experimental-SPARK-24392. --- docs/sql-programming-guide.md | 4 ++++ python/pyspark/sql/dataframe.py | 2 ++ python/pyspark/sql/functions.py | 2 ++ python/pyspark/sql/group.py | 2 ++ python/pyspark/sql/session.py | 2 ++ 5 files changed, 12 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index fc26562ff33da..50600861912b1 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1827,6 +1827,10 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. - In version 2.3 and earlier, CSV rows are considered as malformed if at least one column value in the row is malformed. CSV parser dropped such rows in the DROPMALFORMED mode or outputs an error in the FAILFAST mode. Since Spark 2.4, CSV row is considered as malformed only when it contains malformed column values requested from CSV datasource, other values can be ignored. As an example, CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selection of the id column consists of a row with one column value 1234 but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. +## Upgrading From Spark SQL 2.3.0 to 2.3.1 and above + + - As of version 2.3.1 Arrow functionality, including `pandas_udf` and `toPandas()`/`createDataFrame()` with `spark.sql.execution.arrow.enabled` set to `True`, has been marked as experimental. These are still evolving and not currently recommended for use in production. + ## Upgrading From Spark SQL 2.2 to 2.3 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 213dc158f9328..808235ab25440 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1975,6 +1975,8 @@ def toPandas(self): .. note:: This method should only be used if the resulting Pandas's DataFrame is expected to be small, as all the data is loaded into the driver's memory. + .. note:: Usage with spark.sql.execution.arrow.enabled=True is experimental. + >>> df.toPandas() # doctest: +SKIP age name 0 2 Alice diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index fbc8a2d038f8f..efcce25a08e04 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2456,6 +2456,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): :param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`. Default: SCALAR. + .. note:: Experimental + The function type of the UDF can be one of the following: 1. SCALAR diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 3505065b648f2..0906c9c6b329a 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -236,6 +236,8 @@ def apply(self, udf): into memory, so the user should be aware of the potential OOM risk if data is skewed and certain groups are too large to fit in memory. + .. note:: Experimental + :param udf: a grouped map user-defined function returned by :func:`pyspark.sql.functions.pandas_udf`. diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 13d6e2e53dbd0..d675a240172a7 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -584,6 +584,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr .. versionchanged:: 2.1 Added verifySchema. + .. note:: Usage with spark.sql.execution.arrow.enabled=True is experimental. + >>> l = [('Alice', 1)] >>> spark.createDataFrame(l).collect() [Row(_1=u'Alice', _2=1)] From b31b587cd091010337378cf448fd598c37757053 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 29 May 2018 10:35:30 +0800 Subject: [PATCH 0882/2461] [SPARK-19613][SS][TEST] Random.nextString is not safe for directory namePrefix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? `Random.nextString` is good for generating random string data, but it's not proper for directory name prefix in `Utils.createDirectory(tempDir, Random.nextString(10))`. This PR uses more safe directory namePrefix. ```scala scala> scala.util.Random.nextString(10) res0: String = 馨쭔ᎰႻ穚䃈兩㻞藑並 ``` ```scala StateStoreRDDSuite: - versioning and immutability - recovering from files - usage with iterators - only gets and only puts - preferred locations using StateStoreCoordinator *** FAILED *** java.io.IOException: Failed to create a temp directory (under /.../spark/sql/core/target/tmp/StateStoreRDDSuite8712796397908632676) after 10 attempts! at org.apache.spark.util.Utils$.createDirectory(Utils.scala:295) at org.apache.spark.sql.execution.streaming.state.StateStoreRDDSuite$$anonfun$13$$anonfun$apply$6.apply(StateStoreRDDSuite.scala:152) at org.apache.spark.sql.execution.streaming.state.StateStoreRDDSuite$$anonfun$13$$anonfun$apply$6.apply(StateStoreRDDSuite.scala:149) at org.apache.spark.sql.catalyst.util.package$.quietly(package.scala:42) at org.apache.spark.sql.execution.streaming.state.StateStoreRDDSuite$$anonfun$13.apply(StateStoreRDDSuite.scala:149) at org.apache.spark.sql.execution.streaming.state.StateStoreRDDSuite$$anonfun$13.apply(StateStoreRDDSuite.scala:149) ... - distributed test *** FAILED *** java.io.IOException: Failed to create a temp directory (under /.../spark/sql/core/target/tmp/StateStoreRDDSuite8712796397908632676) after 10 attempts! at org.apache.spark.util.Utils$.createDirectory(Utils.scala:295) ``` ## How was this patch tested? Pass the existing tests.StateStoreRDDSuite: Author: Dongjoon Hyun Closes #21446 from dongjoon-hyun/SPARK-19613. --- .../execution/streaming/state/StateStoreRDDSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 65b39f0fbd73d..579a364ebc3e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -55,7 +55,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("versioning and immutability") { withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( increment) @@ -73,7 +73,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } test("recovering from files") { - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString def makeStoreRDD( spark: SparkSession, @@ -101,7 +101,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("usage with iterators - only gets and only puts") { withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString val opId = 0 // Returns an iterator of the incremented value made into the store @@ -149,7 +149,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn quietly { val queryRunId = UUID.randomUUID val opId = 0 - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext @@ -189,7 +189,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn .config(sparkConf.setMaster("local-cluster[2, 1, 1024]")) .getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString val opId = 0 val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)(increment) From 2ced6193b39dc63e5f74138859f2a9d69d3cfd11 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 29 May 2018 10:48:48 +0800 Subject: [PATCH 0883/2461] [SPARK-24377][SPARK SUBMIT] make --py-files work in non pyspark application ## What changes were proposed in this pull request? For some Spark applications, though they're a java program, they require not only jar dependencies, but also python dependencies. One example is Livy remote SparkContext application, this application is actually an embedded REPL for Scala/Python/R, it will not only load in jar dependencies, but also python and R deps, so we should specify not only "--jars", but also "--py-files". Currently for a Spark application, --py-files can only be worked for a pyspark application, so it will not be worked in the above case. So here propose to remove such restriction. Also we tested that "spark.submit.pyFiles" only supports quite limited scenario (client mode with local deps), so here also expand the usage of "spark.submit.pyFiles" to be alternative of --py-files. ## How was this patch tested? UT added. Author: jerryshao Closes #21420 from jerryshao/SPARK-24377. --- .../org/apache/spark/deploy/SparkSubmit.scala | 11 ++--- .../spark/deploy/SparkSubmitArguments.scala | 4 +- .../spark/deploy/SparkSubmitSuite.scala | 46 ++++++++++++++++++- 3 files changed, 49 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 4baf032f0e9c6..a46af26feb061 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -430,18 +430,15 @@ private[spark] class SparkSubmit extends Logging { // Usage: PythonAppRunner
    [app arguments] args.mainClass = "org.apache.spark.deploy.PythonRunner" args.childArgs = ArrayBuffer(localPrimaryResource, localPyFiles) ++ args.childArgs - if (clusterManager != YARN) { - // The YARN backend distributes the primary file differently, so don't merge it. - args.files = mergeFileLists(args.files, args.primaryResource) - } } if (clusterManager != YARN) { // The YARN backend handles python files differently, so don't merge the lists. args.files = mergeFileLists(args.files, args.pyFiles) } - if (localPyFiles != null) { - sparkConf.set("spark.submit.pyFiles", localPyFiles) - } + } + + if (localPyFiles != null) { + sparkConf.set("spark.submit.pyFiles", localPyFiles) } // In YARN mode for an R app, add the SparkR package archive and the R package diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index fed4e0a5069c3..fb232101114b9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -182,6 +182,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull files = Option(files).orElse(sparkProperties.get("spark.files")).orNull + pyFiles = Option(pyFiles).orElse(sparkProperties.get("spark.submit.pyFiles")).orNull ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull ivySettingsPath = sparkProperties.get("spark.jars.ivySettings") packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull @@ -280,9 +281,6 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) { error("Number of executors must be a positive number") } - if (pyFiles != null && !isPython) { - error("--py-files given but primary resource is not a Python script") - } if (master.startsWith("yarn")) { val hasHadoopEnv = env.contains("HADOOP_CONF_DIR") || env.contains("YARN_CONF_DIR") diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 43286953e4383..545c8d0423dc3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -771,9 +771,13 @@ class SparkSubmitSuite PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) // Test remote python files + val hadoopConf = new Configuration() + updateConfWithFakeS3Fs(hadoopConf) val f4 = File.createTempFile("test-submit-remote-python-files", "", tmpDir) + val pyFile1 = File.createTempFile("file1", ".py", tmpDir) + val pyFile2 = File.createTempFile("file2", ".py", tmpDir) val writer4 = new PrintWriter(f4) - val remotePyFiles = "hdfs:///tmp/file1.py,hdfs:///tmp/file2.py" + val remotePyFiles = s"s3a://${pyFile1.getAbsolutePath},s3a://${pyFile2.getAbsolutePath}" writer4.println("spark.submit.pyFiles " + remotePyFiles) writer4.close() val clArgs4 = Seq( @@ -783,7 +787,7 @@ class SparkSubmitSuite "hdfs:///tmp/mister.py" ) val appArgs4 = new SparkSubmitArguments(clArgs4) - val (_, _, conf4, _) = submit.prepareSubmitEnvironment(appArgs4) + val (_, _, conf4, _) = submit.prepareSubmitEnvironment(appArgs4, conf = Some(hadoopConf)) // Should not format python path for yarn cluster mode conf4.get("spark.submit.pyFiles") should be(Utils.resolveURIs(remotePyFiles)) } @@ -1093,6 +1097,44 @@ class SparkSubmitSuite assert(exception.getMessage() === "hello") } + test("support --py-files/spark.submit.pyFiles in non pyspark application") { + val hadoopConf = new Configuration() + updateConfWithFakeS3Fs(hadoopConf) + + val tmpDir = Utils.createTempDir() + val pyFile = File.createTempFile("tmpPy", ".egg", tmpDir) + + val args = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), + "--name", "testApp", + "--master", "yarn", + "--deploy-mode", "client", + "--py-files", s"s3a://${pyFile.getAbsolutePath}", + "spark-internal" + ) + + val appArgs = new SparkSubmitArguments(args) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs, conf = Some(hadoopConf)) + + conf.get(PY_FILES.key) should be (s"s3a://${pyFile.getAbsolutePath}") + conf.get("spark.submit.pyFiles") should (startWith("/")) + + // Verify "spark.submit.pyFiles" + val args1 = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), + "--name", "testApp", + "--master", "yarn", + "--deploy-mode", "client", + "--conf", s"spark.submit.pyFiles=s3a://${pyFile.getAbsolutePath}", + "spark-internal" + ) + + val appArgs1 = new SparkSubmitArguments(args1) + val (_, _, conf1, _) = submit.prepareSubmitEnvironment(appArgs1, conf = Some(hadoopConf)) + + conf1.get(PY_FILES.key) should be (s"s3a://${pyFile.getAbsolutePath}") + conf1.get("spark.submit.pyFiles") should (startWith("/")) + } } object SparkSubmitSuite extends SparkFunSuite with TimeLimits { From 23db600c956cff4d0b20c38ddd2d746de2b535a0 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 28 May 2018 23:23:22 -0700 Subject: [PATCH 0884/2461] [SPARK-24250][SQL][FOLLOW-UP] support accessing SQLConf inside tasks ## What changes were proposed in this pull request? We should not stop users from calling `getActiveSession` and `getDefaultSession` in executors. To not break the existing behaviors, we should simply return None. ## How was this patch tested? N/A Author: Xiao Li Closes #21436 from gatorsmile/followUpSPARK-24250. --- .../org/apache/spark/sql/SparkSession.scala | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index e2a1a57c7dd4d..565042fcf762e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -1021,21 +1021,33 @@ object SparkSession extends Logging { /** * Returns the active SparkSession for the current thread, returned by the builder. * + * @note Return None, when calling this function on executors + * * @since 2.2.0 */ def getActiveSession: Option[SparkSession] = { - assertOnDriver() - Option(activeThreadSession.get) + if (TaskContext.get != null) { + // Return None when running on executors. + None + } else { + Option(activeThreadSession.get) + } } /** * Returns the default SparkSession that is returned by the builder. * + * @note Return None, when calling this function on executors + * * @since 2.2.0 */ def getDefaultSession: Option[SparkSession] = { - assertOnDriver() - Option(defaultSession.get) + if (TaskContext.get != null) { + // Return None when running on executors. + None + } else { + Option(defaultSession.get) + } } /** From aca65c63cb12073eb193fe08998994c60acb8b58 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Tue, 29 May 2018 20:10:59 +0800 Subject: [PATCH 0885/2461] [SPARK-23991][DSTREAMS] Fix data loss when WAL write fails in allocateBlocksToBatch When blocks tried to get allocated to a batch and WAL write fails then the blocks will be removed from the received block queue. This fact simply produces data loss because the next allocation will not find the mentioned blocks in the queue. In this PR blocks will be removed from the received queue only if WAL write succeded. Additional unit test. Author: Gabor Somogyi Closes #21430 from gaborgsomogyi/SPARK-23991. Change-Id: I5ead84f0233f0c95e6d9f2854ac2ff6906f6b341 --- .../scheduler/ReceivedBlockTracker.scala | 3 +- .../streaming/ReceivedBlockTrackerSuite.scala | 47 ++++++++++++++++++- 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index dacff69d55dd2..cf4324578ea87 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -112,10 +112,11 @@ private[streaming] class ReceivedBlockTracker( def allocateBlocksToBatch(batchTime: Time): Unit = synchronized { if (lastAllocatedBatchTime == null || batchTime > lastAllocatedBatchTime) { val streamIdToBlocks = streamIds.map { streamId => - (streamId, getReceivedBlockQueue(streamId).dequeueAll(x => true)) + (streamId, getReceivedBlockQueue(streamId).clone()) }.toMap val allocatedBlocks = AllocatedBlocks(streamIdToBlocks) if (writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks))) { + streamIds.foreach(getReceivedBlockQueue(_).clear()) timeToAllocatedBlocks.put(batchTime, allocatedBlocks) lastAllocatedBatchTime = batchTime } else { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index 4fa236bd39663..fd7e00b1de25f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -26,10 +26,12 @@ import scala.language.{implicitConversions, postfixOps} import scala.util.Random import org.apache.hadoop.conf.Configuration +import org.mockito.Matchers.any +import org.mockito.Mockito.{doThrow, reset, spy} import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult @@ -115,6 +117,47 @@ class ReceivedBlockTrackerSuite tracker2.stop() } + test("block allocation to batch should not loose blocks from received queue") { + val tracker1 = spy(createTracker()) + tracker1.isWriteAheadLogEnabled should be (true) + tracker1.getUnallocatedBlocks(streamId) shouldEqual Seq.empty + + // Add blocks + val blockInfos = generateBlockInfos() + blockInfos.map(tracker1.addBlock) + tracker1.getUnallocatedBlocks(streamId) shouldEqual blockInfos + + // Try to allocate the blocks to a batch and verify that it's failing + // The blocks should stay in the received queue when WAL write failing + doThrow(new RuntimeException("Not able to write BatchAllocationEvent")) + .when(tracker1).writeToLog(any(classOf[BatchAllocationEvent])) + val errMsg = intercept[RuntimeException] { + tracker1.allocateBlocksToBatch(1) + } + assert(errMsg.getMessage === "Not able to write BatchAllocationEvent") + tracker1.getUnallocatedBlocks(streamId) shouldEqual blockInfos + tracker1.getBlocksOfBatch(1) shouldEqual Map.empty + tracker1.getBlocksOfBatchAndStream(1, streamId) shouldEqual Seq.empty + + // Allocate the blocks to a batch and verify that all of them have been allocated + reset(tracker1) + tracker1.allocateBlocksToBatch(2) + tracker1.getUnallocatedBlocks(streamId) shouldEqual Seq.empty + tracker1.hasUnallocatedReceivedBlocks should be (false) + tracker1.getBlocksOfBatch(2) shouldEqual Map(streamId -> blockInfos) + tracker1.getBlocksOfBatchAndStream(2, streamId) shouldEqual blockInfos + + tracker1.stop() + + // Recover from WAL to see the correctness + val tracker2 = createTracker(recoverFromWriteAheadLog = true) + tracker2.getUnallocatedBlocks(streamId) shouldEqual Seq.empty + tracker2.hasUnallocatedReceivedBlocks should be (false) + tracker2.getBlocksOfBatch(2) shouldEqual Map(streamId -> blockInfos) + tracker2.getBlocksOfBatchAndStream(2, streamId) shouldEqual blockInfos + tracker2.stop() + } + test("recovery and cleanup with write ahead logs") { val manualClock = new ManualClock // Set the time increment level to twice the rotation interval so that every increment creates @@ -312,7 +355,7 @@ class ReceivedBlockTrackerSuite recoverFromWriteAheadLog: Boolean = false, clock: Clock = new SystemClock): ReceivedBlockTracker = { val cpDirOption = if (setCheckpointDir) Some(checkpointDirectory.toString) else None - val tracker = new ReceivedBlockTracker( + var tracker = new ReceivedBlockTracker( conf, hadoopConf, Seq(streamId), clock, recoverFromWriteAheadLog, cpDirOption) allReceivedBlockTrackers += tracker tracker From 900bc1f7dc5a7c013b473fceab1c4052ade74a2f Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 29 May 2018 10:22:18 -0700 Subject: [PATCH 0886/2461] [SPARK-24371][SQL] Added isInCollection in DataFrame API for Scala and Java. ## What changes were proposed in this pull request? Implemented **`isInCollection `** in DataFrame API for both Scala and Java, so users can do ```scala val profileDF = Seq( Some(1), Some(2), Some(3), Some(4), Some(5), Some(6), Some(7), None ).toDF("profileID") val validUsers: Seq[Any] = Seq(6, 7.toShort, 8L, "3") val result = profileDF.withColumn("isValid", $"profileID". isInCollection(validUsers)) result.show(10) """ +---------+-------+ |profileID|isValid| +---------+-------+ | 1| false| | 2| false| | 3| true| | 4| false| | 5| false| | 6| true| | 7| true| | null| null| +---------+-------+ """.stripMargin ``` ## How was this patch tested? Several unit tests are added. Author: DB Tsai Closes #21416 from dbtsai/optimize-set. --- .../sql/catalyst/optimizer/expressions.scala | 1 - .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 19 ++++++ .../spark/sql/ColumnExpressionSuite.scala | 62 ++++++++++++++++++- 4 files changed, 81 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 1c0b7bd806801..1d363b8146e3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -21,7 +21,6 @@ import scala.collection.immutable.HashSet import scala.collection.mutable.{ArrayBuffer, Stack} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.aggregate._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index e487693927ab6..c486ad700f362 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -78,7 +78,7 @@ abstract class LogicalPlan schema.map { field => resolve(field.name :: Nil, resolver).map { case a: AttributeReference => a - case other => sys.error(s"can not handle nested schema yet... plan $this") + case _ => sys.error(s"can not handle nested schema yet... plan $this") }.getOrElse { throw new AnalysisException( s"Unable to resolve ${field.name} given [${output.map(_.name).mkString(", ")}]") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index ad0efbae89830..b3e59f53ee3de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.apache.spark.annotation.InterfaceStability @@ -786,6 +787,24 @@ class Column(val expr: Expression) extends Logging { @scala.annotation.varargs def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) } + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the provided collection. + * + * @group expr_ops + * @since 2.4.0 + */ + def isInCollection(values: scala.collection.Iterable[_]): Column = isin(values.toSeq: _*) + + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the provided collection. + * + * @group java_expr_ops + * @since 2.4.0 + */ + def isInCollection(values: java.lang.Iterable[_]): Column = isInCollection(values.asScala) + /** * SQL like expression. Returns a boolean column based on a SQL LIKE match. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 7c45be21961d3..2182bd7eadd63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import java.util.Locale + +import scala.collection.JavaConverters._ + import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} import org.scalatest.Matchers._ @@ -390,11 +394,67 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { checkAnswer(df.filter($"b".isin("z", "y")), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) + // Auto casting should work with mixture of different types in collections + checkAnswer(df.filter($"a".isin(1.toShort, "2")), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isin("3", 2.toLong)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isin(3, "1")), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") - intercept[AnalysisException] { + val e = intercept[AnalysisException] { df2.filter($"a".isin($"b")) } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + } + + test("isInCollection: Scala Collection") { + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + // Test with different types of collections + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + + val e = intercept[AnalysisException] { + df2.filter($"a".isInCollection(Seq($"b"))) + } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + } + + test("isInCollection: Java Collection") { + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + // Test with different types of collections + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet.asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList.asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + + val e = intercept[AnalysisException] { + df2.filter($"a".isInCollection(Seq($"b").asJava)) + } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } } test("&&") { From f48938800e6dc3880441f160dd93856b9f86874e Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 30 May 2018 09:32:33 +0800 Subject: [PATCH 0887/2461] [SPARK-24365][SQL] Add Data Source write benchmark ## What changes were proposed in this pull request? Add Data Source write benchmark. So that it would be easier to measure the writer performance. Author: Gengliang Wang Closes #21409 from gengliangwang/parquetWriteBenchmark. --- .../benchmark/DataSourceWriteBenchmark.scala | 149 ++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala new file mode 100644 index 0000000000000..2d2cdebd067c1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Benchmark + +/** + * Benchmark to measure data source write performance. + * By default it measures 4 data source format: Parquet, ORC, JSON, CSV: + * spark-submit --class + * To measure specified formats, run it with arguments: + * spark-submit --class format1 [format2] [...] + */ +object DataSourceWriteBenchmark { + val conf = new SparkConf() + .setAppName("DataSourceWriteBenchmark") + .setIfMissing("spark.master", "local[1]") + .set("spark.sql.parquet.compression.codec", "snappy") + .set("spark.sql.orc.compression.codec", "snappy") + + val spark = SparkSession.builder.config(conf).getOrCreate() + + // Set default configs. Individual cases will change them if necessary. + spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + + val tempTable = "temp" + val numRows = 1024 * 1024 * 15 + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withTable(tableNames: String*)(f: => Unit): Unit = { + try f finally { + tableNames.foreach { name => + spark.sql(s"DROP TABLE IF EXISTS $name") + } + } + } + + def writeNumeric(table: String, format: String, benchmark: Benchmark, dataType: String): Unit = { + spark.sql(s"create table $table(id $dataType) using $format") + benchmark.addCase(s"Output Single $dataType Column") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS $dataType) AS c1 FROM $tempTable") + } + } + + def writeIntString(table: String, format: String, benchmark: Benchmark): Unit = { + spark.sql(s"CREATE TABLE $table(c1 INT, c2 STRING) USING $format") + benchmark.addCase("Output Int and String Column") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS INT) AS " + + s"c1, CAST(id AS STRING) AS c2 FROM $tempTable") + } + } + + def writePartition(table: String, format: String, benchmark: Benchmark): Unit = { + spark.sql(s"CREATE TABLE $table(p INT, id INT) USING $format PARTITIONED BY (p)") + benchmark.addCase("Output Partitions") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS INT) AS id," + + s" CAST(id % 2 AS INT) AS p FROM $tempTable") + } + } + + def writeBucket(table: String, format: String, benchmark: Benchmark): Unit = { + spark.sql(s"CREATE TABLE $table(c1 INT, c2 INT) USING $format CLUSTERED BY (c2) INTO 2 BUCKETS") + benchmark.addCase("Output Buckets") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS INT) AS " + + s"c1, CAST(id AS INT) AS c2 FROM $tempTable") + } + } + + def main(args: Array[String]): Unit = { + val tableInt = "tableInt" + val tableDouble = "tableDouble" + val tableIntString = "tableIntString" + val tablePartition = "tablePartition" + val tableBucket = "tableBucket" + val formats: Seq[String] = if (args.isEmpty) { + Seq("Parquet", "ORC", "JSON", "CSV") + } else { + args + } + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + Parquet writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 1815 / 1932 8.7 115.4 1.0X + Output Single Double Column 1877 / 1878 8.4 119.3 1.0X + Output Int and String Column 6265 / 6543 2.5 398.3 0.3X + Output Partitions 4067 / 4457 3.9 258.6 0.4X + Output Buckets 5608 / 5820 2.8 356.6 0.3X + + ORC writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 1201 / 1239 13.1 76.3 1.0X + Output Single Double Column 1542 / 1600 10.2 98.0 0.8X + Output Int and String Column 6495 / 6580 2.4 412.9 0.2X + Output Partitions 3648 / 3842 4.3 231.9 0.3X + Output Buckets 5022 / 5145 3.1 319.3 0.2X + + JSON writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 1988 / 2093 7.9 126.4 1.0X + Output Single Double Column 2854 / 2911 5.5 181.4 0.7X + Output Int and String Column 6467 / 6653 2.4 411.1 0.3X + Output Partitions 4548 / 5055 3.5 289.1 0.4X + Output Buckets 5664 / 5765 2.8 360.1 0.4X + + CSV writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 3025 / 3190 5.2 192.3 1.0X + Output Single Double Column 3575 / 3634 4.4 227.3 0.8X + Output Int and String Column 7313 / 7399 2.2 464.9 0.4X + Output Partitions 5105 / 5190 3.1 324.6 0.6X + Output Buckets 6986 / 6992 2.3 444.1 0.4X + */ + withTempTable(tempTable) { + spark.range(numRows).createOrReplaceTempView(tempTable) + formats.foreach { format => + withTable(tableInt, tableDouble, tableIntString, tablePartition, tableBucket) { + val benchmark = new Benchmark(s"$format writer benchmark", numRows) + writeNumeric(tableInt, format, benchmark, "Int") + writeNumeric(tableDouble, format, benchmark, "Double") + writeIntString(tableIntString, format, benchmark) + writePartition(tablePartition, format, benchmark) + writeBucket(tableBucket, format, benchmark) + benchmark.run() + } + } + } + } +} From a4be981c0476bb613c660b70a370f671c8b3ffee Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Tue, 29 May 2018 23:26:39 -0700 Subject: [PATCH 0888/2461] [SPARK-24331][SPARKR][SQL] Adding arrays_overlap, array_repeat, map_entries to SparkR ## What changes were proposed in this pull request? The PR adds functions `arrays_overlap`, `array_repeat`, `map_entries` to SparkR. ## How was this patch tested? Tests added into R/pkg/tests/fulltests/test_sparkSQL.R ## Examples ### arrays_overlap ``` df <- createDataFrame(list(list(list(1L, 2L), list(3L, 1L)), list(list(1L, 2L), list(3L, 4L)), list(list(1L, NA), list(3L, 4L)))) collect(select(df, arrays_overlap(df[[1]], df[[2]]))) ``` ``` arrays_overlap(_1, _2) 1 TRUE 2 FALSE 3 NA ``` ### array_repeat ``` df <- createDataFrame(list(list("a", 3L), list("b", 2L))) collect(select(df, array_repeat(df[[1]], df[[2]]))) ``` ``` array_repeat(_1, _2) 1 a, a, a 2 b, b ``` ``` collect(select(df, array_repeat(df[[1]], 2L))) ``` ``` array_repeat(_1, 2) 1 a, a 2 b, b ``` ### map_entries ``` df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) collect(select(df, map_entries(df$map))) ``` ``` map_entries(map) 1 x, 1, y, 2 ``` Author: Marek Novotny Closes #21434 from mn-mikke/SPARK-24331. --- R/pkg/NAMESPACE | 3 ++ R/pkg/R/DataFrame.R | 2 + R/pkg/R/functions.R | 58 ++++++++++++++++++++++++--- R/pkg/R/generics.R | 12 ++++++ R/pkg/tests/fulltests/test_sparkSQL.R | 22 +++++++++- 5 files changed, 91 insertions(+), 6 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index c575fe255f57a..73a33af4dd48b 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -204,7 +204,9 @@ exportMethods("%<=>%", "array_max", "array_min", "array_position", + "array_repeat", "array_sort", + "arrays_overlap", "asc", "ascii", "asin", @@ -302,6 +304,7 @@ exportMethods("%<=>%", "lower", "lpad", "ltrim", + "map_entries", "map_keys", "map_values", "max", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index a1c9495b0795e..70eb7a874b75c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2297,6 +2297,8 @@ setMethod("rename", setClassUnion("characterOrColumn", c("character", "Column")) +setClassUnion("numericOrColumn", c("numeric", "Column")) + #' Arrange Rows by Variables #' #' Sort a SparkDataFrame by the specified column(s). diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index fcb3521f901ea..abc91aeeb4825 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -189,6 +189,7 @@ NULL #' the map or array of maps. #' \item \code{from_json}: it is the column containing the JSON string. #' } +#' @param y Column to compute on. #' @param value A value to compute on. #' \itemize{ #' \item \code{array_contains}: a value to be checked if contained in the column. @@ -207,7 +208,7 @@ NULL #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) #' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) -#' head(select(tmp, array_position(tmp$v1, 21), array_sort(tmp$v1))) +#' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1))) #' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) @@ -216,11 +217,10 @@ NULL #' head(select(tmp, sort_array(tmp$v1))) #' head(select(tmp, sort_array(tmp$v1, asc = FALSE))) #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl)) -#' head(select(tmp3, map_keys(tmp3$v3))) -#' head(select(tmp3, map_values(tmp3$v3))) +#' head(select(tmp3, map_entries(tmp3$v3), map_keys(tmp3$v3), map_values(tmp3$v3))) #' head(select(tmp3, element_at(tmp3$v3, "Valiant"))) -#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$hp)) -#' head(select(tmp4, concat(tmp4$v4, tmp4$v5))) +#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp)) +#' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5))) #' head(select(tmp, concat(df$mpg, df$cyl, df$hp)))} NULL @@ -3048,6 +3048,26 @@ setMethod("array_position", column(jc) }) +#' @details +#' \code{array_repeat}: Creates an array containing \code{x} repeated the number of times +#' given by \code{count}. +#' +#' @param count a Column or constant determining the number of repetitions. +#' @rdname column_collection_functions +#' @aliases array_repeat array_repeat,Column,numericOrColumn-method +#' @note array_repeat since 2.4.0 +setMethod("array_repeat", + signature(x = "Column", count = "numericOrColumn"), + function(x, count) { + if (class(count) == "Column") { + count <- count@jc + } else { + count <- as.integer(count) + } + jc <- callJStatic("org.apache.spark.sql.functions", "array_repeat", x@jc, count) + column(jc) + }) + #' @details #' \code{array_sort}: Sorts the input array in ascending order. The elements of the input array #' must be orderable. NA elements will be placed at the end of the returned array. @@ -3062,6 +3082,21 @@ setMethod("array_sort", column(jc) }) +#' @details +#' \code{arrays_overlap}: Returns true if the input arrays have at least one non-null element in +#' common. If not and both arrays are non-empty and any of them contains a null, it returns null. +#' It returns false otherwise. +#' +#' @rdname column_collection_functions +#' @aliases arrays_overlap arrays_overlap,Column-method +#' @note arrays_overlap since 2.4.0 +setMethod("arrays_overlap", + signature(x = "Column", y = "Column"), + function(x, y) { + jc <- callJStatic("org.apache.spark.sql.functions", "arrays_overlap", x@jc, y@jc) + column(jc) + }) + #' @details #' \code{flatten}: Creates a single array from an array of arrays. #' If a structure of nested arrays is deeper than two levels, only one level of nesting is removed. @@ -3076,6 +3111,19 @@ setMethod("flatten", column(jc) }) +#' @details +#' \code{map_entries}: Returns an unordered array of all entries in the given map. +#' +#' @rdname column_collection_functions +#' @aliases map_entries map_entries,Column-method +#' @note map_entries since 2.4.0 +setMethod("map_entries", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "map_entries", x@jc) + column(jc) + }) + #' @details #' \code{map_keys}: Returns an unordered array containing the keys of the map. #' diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 3ea181157b644..8894cb1c5b92f 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -769,10 +769,18 @@ setGeneric("array_min", function(x) { standardGeneric("array_min") }) #' @name NULL setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_repeat", function(x, count) { standardGeneric("array_repeat") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_sort", function(x) { standardGeneric("array_sort") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("arrays_overlap", function(x, y) { standardGeneric("arrays_overlap") }) + #' @rdname column_string_functions #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -1034,6 +1042,10 @@ setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) #' @name NULL setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("map_entries", function(x) { standardGeneric("map_entries") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("map_keys", function(x) { standardGeneric("map_keys") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 13b55ac6e6e3c..16c1fd5a065eb 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1503,6 +1503,21 @@ test_that("column functions", { result <- collect(select(df2, reverse(df2[[1]])))[[1]] expect_equal(result, "cba") + # Test array_repeat() + df <- createDataFrame(list(list("a", 3L), list("b", 2L))) + result <- collect(select(df, array_repeat(df[[1]], df[[2]])))[[1]] + expect_equal(result, list(list("a", "a", "a"), list("b", "b"))) + + result <- collect(select(df, array_repeat(df[[1]], 2L)))[[1]] + expect_equal(result, list(list("a", "a"), list("b", "b"))) + + # Test arrays_overlap() + df <- createDataFrame(list(list(list(1L, 2L), list(3L, 1L)), + list(list(1L, 2L), list(3L, 4L)), + list(list(1L, NA), list(3L, 4L)))) + result <- collect(select(df, arrays_overlap(df[[1]], df[[2]])))[[1]] + expect_equal(result, c(TRUE, FALSE, NA)) + # Test array_sort() and sort_array() df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L)))) @@ -1531,8 +1546,13 @@ test_that("column functions", { result <- collect(select(df, flatten(df[[1]])))[[1]] expect_equal(result, list(list(1L, 2L, 3L, 4L), list(5L, 6L, 7L, 8L))) - # Test map_keys(), map_values() and element_at() + # Test map_entries(), map_keys(), map_values() and element_at() df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) + result <- collect(select(df, map_entries(df$map)))[[1]] + expected_entries <- list(listToStruct(list(key = "x", value = 1)), + listToStruct(list(key = "y", value = 2))) + expect_equal(result, list(expected_entries)) + result <- collect(select(df, map_keys(df$map)))[[1]] expect_equal(result, list(list("x", "y"))) From 0ebb0c0d4dd3e192464dc5e0e6f01efa55b945ed Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Wed, 30 May 2018 18:11:33 +0800 Subject: [PATCH 0889/2461] [SPARK-23754][PYTHON] Re-raising StopIteration in client code ## What changes were proposed in this pull request? Make sure that `StopIteration`s raised in users' code do not silently interrupt processing by spark, but are raised as exceptions to the users. The users' functions are wrapped in `safe_iter` (in `shuffle.py`), which re-raises `StopIteration`s as `RuntimeError`s ## How was this patch tested? Unit tests, making sure that the exceptions are indeed raised. I am not sure how to check whether a `Py4JJavaError` contains my exception, so I simply looked for the exception message in the java exception's `toString`. Can you propose a better way? ## License This is my original work, licensed in the same way as spark Author: e-dorigatti Author: edorigatti Closes #21383 from e-dorigatti/fix_spark_23754. --- python/pyspark/rdd.py | 18 ++++++++++--- python/pyspark/shuffle.py | 7 ++--- python/pyspark/sql/tests.py | 16 +++++++++++ python/pyspark/sql/udf.py | 14 ++++++++-- python/pyspark/tests.py | 53 +++++++++++++++++++++++++++++++++++++ python/pyspark/util.py | 28 +++++++++++++++++--- 6 files changed, 125 insertions(+), 11 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d5a237a5b2855..14d9128502ab0 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -53,6 +53,7 @@ from pyspark.shuffle import Aggregator, ExternalMerger, \ get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync +from pyspark.util import fail_on_stopiteration __all__ = ["RDD"] @@ -339,7 +340,7 @@ def map(self, f, preservesPartitioning=False): [('a', 1), ('b', 1), ('c', 1)] """ def func(_, iterator): - return map(f, iterator) + return map(fail_on_stopiteration(f), iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -354,7 +355,7 @@ def flatMap(self, f, preservesPartitioning=False): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(s, iterator): - return chain.from_iterable(map(f, iterator)) + return chain.from_iterable(map(fail_on_stopiteration(f), iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -417,7 +418,7 @@ def filter(self, f): [2, 4] """ def func(iterator): - return filter(f, iterator) + return filter(fail_on_stopiteration(f), iterator) return self.mapPartitions(func, True) def distinct(self, numPartitions=None): @@ -798,6 +799,8 @@ def foreach(self, f): >>> def f(x): print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) """ + f = fail_on_stopiteration(f) + def processPartition(iterator): for x in iterator: f(x) @@ -847,6 +850,8 @@ def reduce(self, f): ... ValueError: Can not reduce() empty RDD """ + f = fail_on_stopiteration(f) + def func(iterator): iterator = iter(iterator) try: @@ -918,6 +923,8 @@ def fold(self, zeroValue, op): >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) 15 """ + op = fail_on_stopiteration(op) + def func(iterator): acc = zeroValue for obj in iterator: @@ -950,6 +957,9 @@ def aggregate(self, zeroValue, seqOp, combOp): >>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp) (0, 0) """ + seqOp = fail_on_stopiteration(seqOp) + combOp = fail_on_stopiteration(combOp) + def func(iterator): acc = zeroValue for obj in iterator: @@ -1643,6 +1653,8 @@ def reduceByKeyLocally(self, func): >>> sorted(rdd.reduceByKeyLocally(add).items()) [('a', 2), ('b', 1)] """ + func = fail_on_stopiteration(func) + def reducePartition(iterator): m = {} for k, v in iterator: diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 02c773302e9da..bd0ac0039ffe1 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -28,6 +28,7 @@ import pyspark.heapq3 as heapq from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ CompressedSerializer, AutoBatchedSerializer +from pyspark.util import fail_on_stopiteration try: @@ -94,9 +95,9 @@ class Aggregator(object): """ def __init__(self, createCombiner, mergeValue, mergeCombiners): - self.createCombiner = createCombiner - self.mergeValue = mergeValue - self.mergeCombiners = mergeCombiners + self.createCombiner = fail_on_stopiteration(createCombiner) + self.mergeValue = fail_on_stopiteration(mergeValue) + self.mergeCombiners = fail_on_stopiteration(mergeCombiners) class SimpleAggregator(Aggregator): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index c7bd8f01b907f..a2450932e303d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -900,6 +900,22 @@ def __call__(self, x): self.assertEqual(f, f_.func) self.assertEqual(return_type, f_.returnType) + def test_stopiteration_in_udf(self): + # test for SPARK-23754 + from pyspark.sql.functions import udf + from py4j.protocol import Py4JJavaError + + def foo(x): + raise StopIteration() + + with self.assertRaises(Py4JJavaError) as cm: + self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show() + + self.assertIn( + "Caught StopIteration thrown from user's code; failing the task", + cm.exception.java_exception.toString() + ) + def test_validate_column_types(self): from pyspark.sql.functions import udf, to_json from pyspark.sql.column import _to_java_column diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 9dbe49b831cef..c8fb49d7c2b65 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -25,7 +25,7 @@ from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\ to_arrow_type, to_arrow_schema -from pyspark.util import _get_argspec +from pyspark.util import _get_argspec, fail_on_stopiteration __all__ = ["UDFRegistration"] @@ -157,7 +157,17 @@ def _create_judf(self): spark = SparkSession.builder.getOrCreate() sc = spark.sparkContext - wrapped_func = _wrap_function(sc, self.func, self.returnType) + func = fail_on_stopiteration(self.func) + + # for pandas UDFs the worker needs to know if the function takes + # one or two arguments, but the signature is lost when wrapping with + # fail_on_stopiteration, so we store it here + if self.evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): + func._argspec = _get_argspec(self.func) + + wrapped_func = _wrap_function(sc, func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( self._name, wrapped_func, jdt, self.evalType, self.deterministic) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 498d6b57e4353..3b37cc028c1b7 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -161,6 +161,37 @@ def gen_gs(N, step=1): self.assertEqual(k, len(vs)) self.assertEqual(list(range(k)), list(vs)) + def test_stopiteration_is_raised(self): + + def stopit(*args, **kwargs): + raise StopIteration() + + def legit_create_combiner(x): + return [x] + + def legit_merge_value(x, y): + return x.append(y) or x + + def legit_merge_combiners(x, y): + return x.extend(y) or x + + data = [(x % 2, x) for x in range(100)] + + # wrong create combiner + m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge value + m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge combiners + m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) + class SorterTests(unittest.TestCase): def test_in_memory_sort(self): @@ -1246,6 +1277,28 @@ def test_pipe_unicode(self): result = rdd.pipe('cat').collect() self.assertEqual(data, result) + def test_stopiteration_in_client_code(self): + + def stopit(*x): + raise StopIteration() + + seq_rdd = self.sc.parallelize(range(10)) + keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) + + self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit) + self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit) + self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit) + self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit) + + # the exception raised is non-deterministic + self.assertRaises((Py4JJavaError, RuntimeError), + seq_rdd.aggregate, 0, stopit, lambda *x: 1) + self.assertRaises((Py4JJavaError, RuntimeError), + seq_rdd.aggregate, 0, lambda *x: 1, stopit) + class ProfilerTests(PySparkTestCase): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 59cc2a6329350..e95a9b523393f 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -53,11 +53,16 @@ def _get_argspec(f): """ Get argspec of a function. Supports both Python 2 and Python 3. """ - # `getargspec` is deprecated since python3.0 (incompatible with function annotations). - # See SPARK-23569. - if sys.version_info[0] < 3: + + if hasattr(f, '_argspec'): + # only used for pandas UDF: they wrap the user function, losing its signature + # workers need this signature, so UDF saves it here + argspec = f._argspec + elif sys.version_info[0] < 3: argspec = inspect.getargspec(f) else: + # `getargspec` is deprecated since python3.0 (incompatible with function annotations). + # See SPARK-23569. argspec = inspect.getfullargspec(f) return argspec @@ -89,6 +94,23 @@ def majorMinorVersion(sparkVersion): " version numbers.") +def fail_on_stopiteration(f): + """ + Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError' + prevents silent loss of data when 'f' is used in a for loop + """ + def wrapper(*args, **kwargs): + try: + return f(*args, **kwargs) + except StopIteration as exc: + raise RuntimeError( + "Caught StopIteration thrown from user's code; failing the task", + exc + ) + + return wrapper + + if __name__ == "__main__": import doctest (failure_count, test_count) = doctest.testmod() From 9e7bad0edd9f6c59c0af21c95e5df98cf82150d3 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 30 May 2018 05:18:18 -0700 Subject: [PATCH 0890/2461] [SPARK-24419][BUILD] Upgrade SBT to 0.13.17 with Scala 2.10.7 for JDK9+ ## What changes were proposed in this pull request? Upgrade SBT to 0.13.17 with Scala 2.10.7 for JDK9+ ## How was this patch tested? Existing tests Author: DB Tsai Closes #21458 from dbtsai/sbt. --- project/build.properties | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/build.properties b/project/build.properties index b19518fd7aa1c..d03985d980ec8 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=0.13.16 +sbt.version=0.13.17 From 1e46f92f956a00d04d47340489b6125d44dbd47b Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 31 May 2018 00:23:25 +0800 Subject: [PATCH 0891/2461] [SPARK-24369][SQL] Correct handling for multiple distinct aggregations having the same argument set ## What changes were proposed in this pull request? This pr fixed an issue when having multiple distinct aggregations having the same argument set, e.g., ``` scala>: paste val df = sql( s"""SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) | FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y) """.stripMargin) java.lang.RuntimeException You hit a query analyzer bug. Please report your query to Spark user mailing list. ``` The root cause is that `RewriteDistinctAggregates` can't detect multiple distinct aggregations if they have the same argument set. This pr modified code so that `RewriteDistinctAggregates` could count the number of aggregate expressions with `isDistinct=true`. ## How was this patch tested? Added tests in `DataFrameAggregateSuite`. Author: Takeshi Yamamuro Closes #21443 from maropu/SPARK-24369. --- .../optimizer/RewriteDistinctAggregates.scala | 7 ++++--- .../apache/spark/sql/execution/SparkStrategies.scala | 2 +- .../src/test/resources/sql-tests/inputs/group-by.sql | 6 +++++- .../test/resources/sql-tests/results/group-by.sql.out | 11 ++++++++++- 4 files changed, 20 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 4448ace7105a4..bc898ab0dc723 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -115,7 +115,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Extract distinct aggregate expressions. - val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e => + val distincgAggExpressions = aggExpressions.filter(_.isDistinct) + val distinctAggGroups = distincgAggExpressions.groupBy { e => val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet if (unfoldableChildren.nonEmpty) { // Only expand the unfoldable children @@ -132,7 +133,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Aggregation strategy can handle queries with a single distinct group. - if (distinctAggGroups.size > 1) { + if (distincgAggExpressions.size > 1) { // Create the attributes for the grouping id and the group by clause. val gid = AttributeReference("gid", IntegerType, nullable = false)() val groupByMap = a.groupingExpressions.collect { @@ -151,7 +152,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Setup unique distinct aggregate children. - val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct + val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index b97a87a122406..b9452b58657a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -386,7 +386,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { aggregateExpressions.partition(_.isDistinct) if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { // This is a sanity check. We should not reach here when we have multiple distinct - // column sets. Our MultipleDistinctRewriter should take care this case. + // column sets. Our `RewriteDistinctAggregates` should take care this case. sys.error("You hit a query analyzer bug. Please report your query to " + "Spark user mailing list.") } diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index c5070b734d521..2c18d6aaabdba 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -68,4 +68,8 @@ SELECT 1 from ( FROM (select 1 as x) a WHERE false ) b -where b.z != b.z +where b.z != b.z; + +-- SPARK-24369 multiple distinct aggregations having the same argument set +SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) + FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y); diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index c1abc6dff754b..581aa1754ce14 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 27 -- !query 0 @@ -241,3 +241,12 @@ where b.z != b.z struct<1:int> -- !query 25 output + + +-- !query 26 +SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) + FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y) +-- !query 26 schema +struct +-- !query 26 output +1.0 1.0 3 From b142157dcc7f595eea93d66dda8b1d169a38d95c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 30 May 2018 10:33:34 -0700 Subject: [PATCH 0892/2461] [SPARK-24384][PYTHON][SPARK SUBMIT] Add .py files correctly into PythonRunner in submit with client mode in spark-submit ## What changes were proposed in this pull request? In client side before context initialization specifically, .py file doesn't work in client side before context initialization when the application is a Python file. See below: ``` $ cat /home/spark/tmp.py def testtest(): return 1 ``` This works: ``` $ cat app.py import pyspark pyspark.sql.SparkSession.builder.getOrCreate() import tmp print("************************%s" % tmp.testtest()) $ ./bin/spark-submit --master yarn --deploy-mode client --py-files /home/spark/tmp.py app.py ... ************************1 ``` but this doesn't: ``` $ cat app.py import pyspark import tmp pyspark.sql.SparkSession.builder.getOrCreate() print("************************%s" % tmp.testtest()) $ ./bin/spark-submit --master yarn --deploy-mode client --py-files /home/spark/tmp.py app.py Traceback (most recent call last): File "/home/spark/spark/app.py", line 2, in import tmp ImportError: No module named tmp ``` ### How did it happen? In client mode specifically, the paths are being added into PythonRunner as are: https://github.com/apache/spark/blob/628c7b517969c4a7ccb26ea67ab3dd61266073ca/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala#L430 https://github.com/apache/spark/blob/628c7b517969c4a7ccb26ea67ab3dd61266073ca/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala#L49-L88 The problem here is, .py file shouldn't be added as are since `PYTHONPATH` expects a directory or an archive like zip or egg. ### How does this PR fix? We shouldn't simply just add its parent directory because other files in the parent directory could also be added into the `PYTHONPATH` in client mode before context initialization. Therefore, we copy .py files into a temp directory for .py files and add it to `PYTHONPATH`. ## How was this patch tested? Unit tests are added and manually tested in both standalond and yarn client modes with submit. Author: hyukjinkwon Closes #21426 from HyukjinKwon/SPARK-24384. --- .../apache/spark/deploy/PythonRunner.scala | 29 ++++++++++++++++++- .../spark/deploy/yarn/YarnClusterSuite.scala | 15 ++++------ 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 1b7e031ee0678..ccb30e205ca40 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy import java.io.File import java.net.{InetAddress, URI} +import java.nio.file.Files import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -48,7 +49,7 @@ object PythonRunner { // Format python file paths before adding them to the PYTHONPATH val formattedPythonFile = formatPath(pythonFile) - val formattedPyFiles = formatPaths(pyFiles) + val formattedPyFiles = resolvePyFiles(formatPaths(pyFiles)) // Launch a Py4J gateway server for the process to connect to; this will let it see our // Java system properties and such @@ -153,4 +154,30 @@ object PythonRunner { .map { p => formatPath(p, testWindows) } } + /** + * Resolves the ".py" files. ".py" file should not be added as is because PYTHONPATH does + * not expect a file. This method creates a temporary directory and puts the ".py" files + * if exist in the given paths. + */ + private def resolvePyFiles(pyFiles: Array[String]): Array[String] = { + lazy val dest = Utils.createTempDir(namePrefix = "localPyFiles") + pyFiles.flatMap { pyFile => + // In case of client with submit, the python paths should be set before context + // initialization because the context initialization can be done later. + // We will copy the local ".py" files because ".py" file shouldn't be added + // alone but its parent directory in PYTHONPATH. See SPARK-24384. + if (pyFile.endsWith(".py")) { + val source = new File(pyFile) + if (source.exists() && source.isFile && source.canRead) { + Files.copy(source.toPath, new File(dest, source.getName).toPath) + Some(dest.getAbsolutePath) + } else { + // Don't have to add it if it doesn't exist or isn't readable. + None + } + } else { + Some(pyFile) + } + }.distinct + } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 59b0f29e37d84..3b78b88de778d 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -271,16 +271,11 @@ class YarnClusterSuite extends BaseYarnClusterSuite { "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), "PYTHONPATH" -> pythonPath.mkString(File.pathSeparator)) ++ extraEnv - val moduleDir = - if (clientMode) { - // In client-mode, .py files added with --py-files are not visible in the driver. - // This is something that the launcher library would have to handle. - tempDir - } else { - val subdir = new File(tempDir, "pyModules") - subdir.mkdir() - subdir - } + val moduleDir = { + val subdir = new File(tempDir, "pyModules") + subdir.mkdir() + subdir + } val pyModule = new File(moduleDir, "mod1.py") Files.write(TEST_PYMODULE, pyModule, StandardCharsets.UTF_8) From ec6f971dc57bcdc0ad65ac1987b6f0c1801157f4 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 30 May 2018 11:04:09 -0700 Subject: [PATCH 0893/2461] [SPARK-23161][PYSPARK][ML] Add missing APIs to Python GBTClassifier ## What changes were proposed in this pull request? Add featureSubsetStrategy in GBTClassifier and GBTRegressor. Also make GBTClassificationModel inherit from JavaClassificationModel instead of prediction model so it will have numClasses. ## How was this patch tested? Add tests in doctest Author: Huaxin Gao Closes #21413 from huaxingao/spark-23161. --- python/pyspark/ml/classification.py | 35 ++++++++++++--- python/pyspark/ml/regression.py | 70 ++++++++++++++++++----------- 2 files changed, 74 insertions(+), 31 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 424ecfd89b060..1754c48937a62 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1131,6 +1131,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return RandomForestClassificationModel(java_model) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + class RandomForestClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): @@ -1193,6 +1200,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> si_model = stringIndexer.fit(df) >>> td = si_model.transform(df) >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42) + >>> gbt.getFeatureSubsetStrategy() + 'all' >>> model = gbt.fit(td) >>> model.featureImportances SparseVector(1, {0: 1.0}) @@ -1226,6 +1235,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol ... ["indexed", "features"]) >>> model.evaluateEachIteration(validation) [0.25..., 0.23..., 0.21..., 0.19..., 0.18...] + >>> model.numClasses + 2 .. versionadded:: 1.4.0 """ @@ -1244,19 +1255,22 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", - maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0): + maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, + featureSubsetStrategy="all"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0) + lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \ + featureSubsetStrategy="all") """ super(GBTClassifier, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.GBTClassifier", self.uid) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0) + lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0, + featureSubsetStrategy="all") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -1265,12 +1279,14 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0): + lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, + featureSubsetStrategy="all"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0) + lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \ + featureSubsetStrategy="all") Sets params for Gradient Boosted Tree Classification. """ kwargs = self._input_kwargs @@ -1293,8 +1309,15 @@ def getLossType(self): """ return self.getOrDefault(self.lossType) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + -class GBTClassificationModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, +class GBTClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): """ Model fitted by GBTClassifier. diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index dd0b62f184d26..dba0e57b01a0b 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -602,6 +602,14 @@ class TreeEnsembleParams(DecisionTreeParams): "used for learning each decision tree, in range (0, 1].", typeConverter=TypeConverters.toFloat) + supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"] + + featureSubsetStrategy = \ + Param(Params._dummy(), "featureSubsetStrategy", + "The number of features to consider for splits at each tree node. Supported " + + "options: " + ", ".join(supportedFeatureSubsetStrategies) + ", (0.0-1.0], [1-n].", + typeConverter=TypeConverters.toString) + def __init__(self): super(TreeEnsembleParams, self).__init__() @@ -619,6 +627,22 @@ def getSubsamplingRate(self): """ return self.getOrDefault(self.subsamplingRate) + @since("1.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + + .. note:: Deprecated in 2.4.0 and will be removed in 3.0.0. + """ + return self._set(featureSubsetStrategy=value) + + @since("1.4.0") + def getFeatureSubsetStrategy(self): + """ + Gets the value of featureSubsetStrategy or its default value. + """ + return self.getOrDefault(self.featureSubsetStrategy) + class TreeRegressorParams(Params): """ @@ -654,14 +678,8 @@ class RandomForestParams(TreeEnsembleParams): Private class to track supported random forest parameters. """ - supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"] numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).", typeConverter=TypeConverters.toInt) - featureSubsetStrategy = \ - Param(Params._dummy(), "featureSubsetStrategy", - "The number of features to consider for splits at each tree node. Supported " + - "options: " + ", ".join(supportedFeatureSubsetStrategies) + ", (0.0-1.0], [1-n].", - typeConverter=TypeConverters.toString) def __init__(self): super(RandomForestParams, self).__init__() @@ -680,20 +698,6 @@ def getNumTrees(self): """ return self.getOrDefault(self.numTrees) - @since("1.4.0") - def setFeatureSubsetStrategy(self, value): - """ - Sets the value of :py:attr:`featureSubsetStrategy`. - """ - return self._set(featureSubsetStrategy=value) - - @since("1.4.0") - def getFeatureSubsetStrategy(self): - """ - Gets the value of featureSubsetStrategy or its default value. - """ - return self.getOrDefault(self.featureSubsetStrategy) - class GBTParams(TreeEnsembleParams): """ @@ -981,6 +985,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return RandomForestRegressionModel(java_model) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + class RandomForestRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): @@ -1029,6 +1040,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, >>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42) >>> print(gbt.getImpurity()) variance + >>> print(gbt.getFeatureSubsetStrategy()) + all >>> model = gbt.fit(df) >>> model.featureImportances SparseVector(1, {0: 1.0}) @@ -1079,20 +1092,20 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, - impurity="variance"): + impurity="variance", featureSubsetStrategy="all"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \ - impurity="variance") + impurity="variance", featureSubsetStrategy="all") """ super(GBTRegressor, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, - impurity="variance") + impurity="variance", featureSubsetStrategy="all") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -1102,13 +1115,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, - impuriy="variance"): + impuriy="variance", featureSubsetStrategy="all"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \ - impurity="variance") + impurity="variance", featureSubsetStrategy="all") Sets params for Gradient Boosted Tree Regression. """ kwargs = self._input_kwargs @@ -1131,6 +1144,13 @@ def getLossType(self): """ return self.getOrDefault(self.lossType) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): """ From 1b36f148891ac41ef36a40366f87dd5405cb3751 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 30 May 2018 11:18:04 -0700 Subject: [PATCH 0894/2461] [SPARK-23901][SQL] Add masking functions ## What changes were proposed in this pull request? The PR adds the masking function as they are described in Hive's documentation: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-DataMaskingFunctions. This means that only `string`s are accepted as parameter for the masking functions. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21246 from mgaido91/SPARK-23901. --- .../expressions/MaskExpressionsUtils.java | 80 +++ .../catalyst/analysis/FunctionRegistry.scala | 8 + .../expressions/maskExpressions.scala | 569 ++++++++++++++++++ .../expressions/MaskExpressionsSuite.scala | 236 ++++++++ .../org/apache/spark/sql/functions.scala | 119 ++++ .../spark/sql/DataFrameFunctionsSuite.scala | 107 ++++ 6 files changed, 1119 insertions(+) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java new file mode 100644 index 0000000000000..05879902a4ed9 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +/** + * Contains all the Utils methods used in the masking expressions. + */ +public class MaskExpressionsUtils { + static final int UNMASKED_VAL = -1; + + /** + * Returns the masking character for {@param c} or {@param c} is it should not be masked. + * @param c the character to transform + * @param maskedUpperChar the character to use instead of a uppercase letter + * @param maskedLowerChar the character to use instead of a lowercase letter + * @param maskedDigitChar the character to use instead of a digit + * @param maskedOtherChar the character to use instead of a any other character + * @return masking character for {@param c} + */ + public static int transformChar( + final int c, + int maskedUpperChar, + int maskedLowerChar, + int maskedDigitChar, + int maskedOtherChar) { + switch(Character.getType(c)) { + case Character.UPPERCASE_LETTER: + if(maskedUpperChar != UNMASKED_VAL) { + return maskedUpperChar; + } + break; + + case Character.LOWERCASE_LETTER: + if(maskedLowerChar != UNMASKED_VAL) { + return maskedLowerChar; + } + break; + + case Character.DECIMAL_DIGIT_NUMBER: + if(maskedDigitChar != UNMASKED_VAL) { + return maskedDigitChar; + } + break; + + default: + if(maskedOtherChar != UNMASKED_VAL) { + return maskedOtherChar; + } + break; + } + + return c; + } + + /** + * Returns the replacement char to use according to the {@param rep} specified by the user and + * the {@param def} default. + */ + public static int getReplacementChar(String rep, int def) { + if (rep != null && rep.length() > 0) { + return rep.codePointAt(0); + } + return def; + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 1134a8866dc13..23a4a440fac23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -432,6 +432,14 @@ object FunctionRegistry { expression[ArrayRepeat]("array_repeat"), CreateStruct.registryEntry, + // mask functions + expression[Mask]("mask"), + expression[MaskFirstN]("mask_first_n"), + expression[MaskLastN]("mask_last_n"), + expression[MaskShowFirstN]("mask_show_first_n"), + expression[MaskShowLastN]("mask_show_last_n"), + expression[MaskHash]("mask_hash"), + // misc functions expression[AssertTrue]("assert_true"), expression[Crc32]("crc32"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala new file mode 100644 index 0000000000000..276a57266a6e0 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala @@ -0,0 +1,569 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.commons.codec.digest.DigestUtils + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.MaskExpressionsUtils._ +import org.apache.spark.sql.catalyst.expressions.MaskLike._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + + +trait MaskLike { + def upper: String + def lower: String + def digit: String + + protected lazy val upperReplacement: Int = getReplacementChar(upper, defaultMaskedUppercase) + protected lazy val lowerReplacement: Int = getReplacementChar(lower, defaultMaskedLowercase) + protected lazy val digitReplacement: Int = getReplacementChar(digit, defaultMaskedDigit) + + protected val maskUtilsClassName: String = classOf[MaskExpressionsUtils].getName + + def inputStringLengthCode(inputString: String, length: String): String = { + s"${CodeGenerator.JAVA_INT} $length = $inputString.codePointCount(0, $inputString.length());" + } + + def appendMaskedToStringBuilderCode( + ctx: CodegenContext, + sb: String, + inputString: String, + offset: String, + numChars: String): String = { + val i = ctx.freshName("i") + val codePoint = ctx.freshName("codePoint") + s""" + |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) { + | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset); + | $sb.appendCodePoint($maskUtilsClassName.transformChar($codePoint, + | $upperReplacement, $lowerReplacement, + | $digitReplacement, $defaultMaskedOther)); + | $offset += Character.charCount($codePoint); + |} + """.stripMargin + } + + def appendUnchangedToStringBuilderCode( + ctx: CodegenContext, + sb: String, + inputString: String, + offset: String, + numChars: String): String = { + val i = ctx.freshName("i") + val codePoint = ctx.freshName("codePoint") + s""" + |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) { + | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset); + | $sb.appendCodePoint($codePoint); + | $offset += Character.charCount($codePoint); + |} + """.stripMargin + } + + def appendMaskedToStringBuilder( + sb: java.lang.StringBuilder, + inputString: String, + startOffset: Int, + numChars: Int): Int = { + var offset = startOffset + (1 to numChars) foreach { _ => + val codePoint = inputString.codePointAt(offset) + sb.appendCodePoint(transformChar( + codePoint, + upperReplacement, + lowerReplacement, + digitReplacement, + defaultMaskedOther)) + offset += Character.charCount(codePoint) + } + offset + } + + def appendUnchangedToStringBuilder( + sb: java.lang.StringBuilder, + inputString: String, + startOffset: Int, + numChars: Int): Int = { + var offset = startOffset + (1 to numChars) foreach { _ => + val codePoint = inputString.codePointAt(offset) + sb.appendCodePoint(codePoint) + offset += Character.charCount(codePoint) + } + offset + } +} + +trait MaskLikeWithN extends MaskLike { + def n: Int + protected lazy val charCount: Int = if (n < 0) 0 else n +} + +/** + * Utils for mask operations. + */ +object MaskLike { + val defaultCharCount = 4 + val defaultMaskedUppercase: Int = 'X' + val defaultMaskedLowercase: Int = 'x' + val defaultMaskedDigit: Int = 'n' + val defaultMaskedOther: Int = MaskExpressionsUtils.UNMASKED_VAL + + def extractCharCount(e: Expression): Int = e match { + case Literal(i, IntegerType | NullType) => + if (i == null) defaultCharCount else i.asInstanceOf[Int] + case Literal(_, dt) => throw new AnalysisException("Expected literal expression of type " + + s"${IntegerType.simpleString}, but got literal of ${dt.simpleString}") + case other => throw new AnalysisException(s"Expected literal expression, but got ${other.sql}") + } + + def extractReplacement(e: Expression): String = e match { + case Literal(s, StringType | NullType) => if (s == null) null else s.toString + case Literal(_, dt) => throw new AnalysisException("Expected literal expression of type " + + s"${StringType.simpleString}, but got literal of ${dt.simpleString}") + case other => throw new AnalysisException(s"Expected literal expression, but got ${other.sql}") + } +} + +/** + * Masks the input string. Additional parameters can be set to change the masking chars for + * uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, upper[, lower[, digit]]]) - Masks str. By default, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("abcd-EFGH-8765-4321", "U", "l", "#"); + llll-UUUU-####-#### + """) +// scalastyle:on line.size.limit +case class Mask(child: Expression, upper: String, lower: String, digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLike { + + def this(child: Expression) = this(child, null.asInstanceOf[String], null, null) + + def this(child: Expression, upper: Expression) = + this(child, extractReplacement(upper), null, null) + + def this(child: Expression, upper: Expression, lower: Expression) = + this(child, extractReplacement(upper), extractReplacement(lower), null) + + def this(child: Expression, upper: Expression, lower: Expression, digit: Expression) = + this(child, extractReplacement(upper), extractReplacement(lower), extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val sb = new java.lang.StringBuilder(length) + appendMaskedToStringBuilder(sb, str, 0, length) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |StringBuilder $sb = new StringBuilder($length); + |${CodeGenerator.JAVA_INT} $offset = 0; + |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, length)} + |${ev.value} = UTF8String.fromString($sb.toString()); + """.stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) +} + +/** + * Masks the first N chars of the input string. N defaults to 4. Additional parameters can be set + * to change the masking chars for uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks the first n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("1234-5678-8765-4321", 4); + nnnn-5678-8765-4321 + """) +// scalastyle:on line.size.limit +case class MaskFirstN( + child: Expression, + n: Int, + upper: String, + lower: String, + digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { + + def this(child: Expression) = + this(child, defaultCharCount, null, null, null) + + def this(child: Expression, n: Expression) = + this(child, extractCharCount(n), null, null, null) + + def this(child: Expression, n: Expression, upper: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), null, null) + + def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) + + def this( + child: Expression, + n: Expression, + upper: Expression, + lower: Expression, + digit: Expression) = + this(child, + extractCharCount(n), + extractReplacement(upper), + extractReplacement(lower), + extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val endOfMask = if (charCount > length) length else charCount + val sb = new java.lang.StringBuilder(length) + val offset = appendMaskedToStringBuilder(sb, str, 0, endOfMask) + appendUnchangedToStringBuilder(sb, str, offset, length - endOfMask) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + val endOfMask = ctx.freshName("endOfMask") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $endOfMask = $charCount > $length ? $length : $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)} + |${appendUnchangedToStringBuilderCode( + ctx, sb, inputString, offset, s"$length - $endOfMask")} + |${ev.value} = UTF8String.fromString($sb.toString()); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_first_n" +} + +/** + * Masks the last N chars of the input string. N defaults to 4. Additional parameters can be set + * to change the masking chars for uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks the last n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("1234-5678-8765-4321", 4); + 1234-5678-8765-nnnn + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class MaskLastN( + child: Expression, + n: Int, + upper: String, + lower: String, + digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { + + def this(child: Expression) = + this(child, defaultCharCount, null, null, null) + + def this(child: Expression, n: Expression) = + this(child, extractCharCount(n), null, null, null) + + def this(child: Expression, n: Expression, upper: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), null, null) + + def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) + + def this( + child: Expression, + n: Expression, + upper: Expression, + lower: Expression, + digit: Expression) = + this(child, + extractCharCount(n), + extractReplacement(upper), + extractReplacement(lower), + extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val startOfMask = if (charCount >= length) 0 else length - charCount + val sb = new java.lang.StringBuilder(length) + val offset = appendUnchangedToStringBuilder(sb, str, 0, startOfMask) + appendMaskedToStringBuilder(sb, str, offset, length - startOfMask) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + val startOfMask = ctx.freshName("startOfMask") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $startOfMask = $charCount >= $length ? + | 0 : $length - $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)} + |${appendMaskedToStringBuilderCode( + ctx, sb, inputString, offset, s"$length - $startOfMask")} + |${ev.value} = UTF8String.fromString($sb.toString()); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_last_n" +} + +/** + * Masks all but the first N chars of the input string. N defaults to 4. Additional parameters can + * be set to change the masking chars for uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks all but the first n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("1234-5678-8765-4321", 4); + 1234-nnnn-nnnn-nnnn + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class MaskShowFirstN( + child: Expression, + n: Int, + upper: String, + lower: String, + digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { + + def this(child: Expression) = + this(child, defaultCharCount, null, null, null) + + def this(child: Expression, n: Expression) = + this(child, extractCharCount(n), null, null, null) + + def this(child: Expression, n: Expression, upper: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), null, null) + + def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) + + def this( + child: Expression, + n: Expression, + upper: Expression, + lower: Expression, + digit: Expression) = + this(child, + extractCharCount(n), + extractReplacement(upper), + extractReplacement(lower), + extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val startOfMask = if (charCount > length) length else charCount + val sb = new java.lang.StringBuilder(length) + val offset = appendUnchangedToStringBuilder(sb, str, 0, startOfMask) + appendMaskedToStringBuilder(sb, str, offset, length - startOfMask) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + val startOfMask = ctx.freshName("startOfMask") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $startOfMask = $charCount > $length ? $length : $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)} + |${appendMaskedToStringBuilderCode( + ctx, sb, inputString, offset, s"$length - $startOfMask")} + |${ev.value} = UTF8String.fromString($sb.toString()); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_show_first_n" +} + +/** + * Masks all but the last N chars of the input string. N defaults to 4. Additional parameters can + * be set to change the masking chars for uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks all but the last n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("1234-5678-8765-4321", 4); + nnnn-nnnn-nnnn-4321 + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class MaskShowLastN( + child: Expression, + n: Int, + upper: String, + lower: String, + digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { + + def this(child: Expression) = + this(child, defaultCharCount, null, null, null) + + def this(child: Expression, n: Expression) = + this(child, extractCharCount(n), null, null, null) + + def this(child: Expression, n: Expression, upper: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), null, null) + + def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) + + def this( + child: Expression, + n: Expression, + upper: Expression, + lower: Expression, + digit: Expression) = + this(child, + extractCharCount(n), + extractReplacement(upper), + extractReplacement(lower), + extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val endOfMask = if (charCount >= length) 0 else length - charCount + val sb = new java.lang.StringBuilder(length) + val offset = appendMaskedToStringBuilder(sb, str, 0, endOfMask) + appendUnchangedToStringBuilder(sb, str, offset, length - endOfMask) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + val endOfMask = ctx.freshName("endOfMask") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $endOfMask = $charCount >= $length ? 0 : $length - $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)} + |${appendUnchangedToStringBuilderCode( + ctx, sb, inputString, offset, s"$length - $endOfMask")} + |${ev.value} = UTF8String.fromString($sb.toString()); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_show_last_n" +} + +/** + * Returns a hashed value based on str. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str) - Returns a hashed value based on str. The hash is consistent and can be used to join masked values together across tables.", + examples = """ + Examples: + > SELECT _FUNC_("abcd-EFGH-8765-4321"); + 60c713f5ec6912229d2060df1c322776 + """) +// scalastyle:on line.size.limit +case class MaskHash(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + + override def nullSafeEval(input: Any): Any = { + UTF8String.fromString(DigestUtils.md5Hex(input.asInstanceOf[UTF8String].toString)) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val digestUtilsClass = classOf[DigestUtils].getName.stripSuffix("$") + s""" + |${ev.value} = UTF8String.fromString($digestUtilsClass.md5Hex($input.toString())); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_hash" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala new file mode 100644 index 0000000000000..4d69dc32ace82 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.types.{IntegerType, StringType} + +class MaskExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("mask") { + checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), "U", "l", "#"), "llll-UUUU-####-####") + checkEvaluation( + new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U"), Literal("l"), Literal("#")), + "llll-UUUU-####-####") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U"), Literal("l")), + "llll-UUUU-nnnn-nnnn") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U")), "xxxx-UUUU-nnnn-nnnn") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321")), "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(new Mask(Literal(null, StringType)), null) + checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), null, "l", "#"), "llll-XXXX-####-####") + checkEvaluation(new Mask( + Literal("abcd-EFGH-8765-4321"), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("Upper")), + "xxxx-UUUU-nnnn-nnnn") + checkEvaluation(new Mask(Literal("")), "") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("")), "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), "", "", ""), "xxxx-XXXX-nnnn-nnnn") + // scalastyle:off nonascii + checkEvaluation(Mask(Literal("Ul9U"), "\u2200", null, null), "\u2200xn\u2200") + checkEvaluation(new Mask(Literal("Hello World, こんにちは, 𠀋"), Literal("あ"), Literal("𡈽")), + "あ𡈽𡈽𡈽𡈽 あ𡈽𡈽𡈽𡈽, こんにちは, 𠀋") + // scalastyle:on nonascii + intercept[AnalysisException] { + checkEvaluation(new Mask(Literal(""), Literal(1)), "") + } + } + + test("mask_first_n") { + checkEvaluation(MaskFirstN(Literal("aB3d-EFGH-8765"), 6, "U", "l", "#"), + "lU#l-UFGH-8765") + checkEvaluation(new MaskFirstN( + Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), + "llll-UFGH-8765-4321") + checkEvaluation( + new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), + "llll-UFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), + "xxxx-UFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6)), + "xxxx-XFGH-8765-4321") + intercept[AnalysisException] { + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") + } + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321")), "xxxx-EFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal(null, StringType)), null) + checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), + "llll-EFGH-8765-4321") + checkEvaluation(new MaskFirstN( + Literal("abcd-EFGH-8765-4321"), + Literal(null, IntegerType), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "xxxx-EFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), + "xxxx-UFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal("")), "") + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), + "xxxx-EFGH-8765-4321") + checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), + "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), + "abcd-EFGH-8765-4321") + // scalastyle:off nonascii + checkEvaluation(MaskFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U") + checkEvaluation(new MaskFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)), + "あ, 𠀋, Xxxxo World") + // scalastyle:on nonascii + } + + test("mask_last_n") { + checkEvaluation(MaskLastN(Literal("abcd-EFGH-aB3d"), 6, "U", "l", "#"), + "abcd-EFGU-lU#l") + checkEvaluation(new MaskLastN( + Literal("abcd-EFGH-8765"), Literal(6), Literal("U"), Literal("l"), Literal("#")), + "abcd-EFGU-####") + checkEvaluation( + new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6), Literal("U"), Literal("l")), + "abcd-EFGU-nnnn") + checkEvaluation( + new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6), Literal("U")), + "abcd-EFGU-nnnn") + checkEvaluation( + new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6)), + "abcd-EFGX-nnnn") + intercept[AnalysisException] { + checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765"), Literal("U")), "") + } + checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321")), "abcd-EFGH-8765-nnnn") + checkEvaluation(new MaskLastN(Literal(null, StringType)), null) + checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), + "abcd-EFGH-8765-nnnn") + checkEvaluation(new MaskLastN( + Literal("abcd-EFGH-8765-4321"), + Literal(null, IntegerType), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "abcd-EFGH-8765-nnnn") + checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321"), Literal(12), Literal("Upper")), + "abcd-EFUU-nnnn-nnnn") + checkEvaluation(new MaskLastN(Literal("")), "") + checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321"), Literal(16), Literal("")), + "abcx-XXXX-nnnn-nnnn") + checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), + "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), + "abcd-EFGH-8765-4321") + // scalastyle:off nonascii + checkEvaluation(MaskLastN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200") + checkEvaluation(new MaskLastN(Literal("あ, 𠀋, Hello World あ 𠀋"), Literal(10)), + "あ, 𠀋, Hello Xxxxx あ 𠀋") + // scalastyle:on nonascii + } + + test("mask_show_first_n") { + checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-aB3d"), 6, "U", "l", "#"), + "abcd-EUUU-####-lU#l") + checkEvaluation(new MaskShowFirstN( + Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), + "abcd-EUUU-####-####") + checkEvaluation( + new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), + "abcd-EUUU-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), + "abcd-EUUU-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6)), + "abcd-EXXX-nnnn-nnnn") + intercept[AnalysisException] { + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") + } + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321")), "abcd-XXXX-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN(Literal(null, StringType)), null) + checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), + "abcd-UUUU-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN( + Literal("abcd-EFGH-8765-4321"), + Literal(null, IntegerType), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "abcd-XXXX-nnnn-nnnn") + checkEvaluation( + new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), + "abcd-EUUU-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN(Literal("")), "") + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), + "abcd-XXXX-nnnn-nnnn") + checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), + "abcd-EFGH-8765-4321") + checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), + "xxxx-XXXX-nnnn-nnnn") + // scalastyle:off nonascii + checkEvaluation(MaskShowFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200") + checkEvaluation(new MaskShowFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)), + "あ, 𠀋, Hellx Xxxxx") + // scalastyle:on nonascii + } + + test("mask_show_last_n") { + checkEvaluation(MaskShowLastN(Literal("aB3d-EFGH-8765"), 6, "U", "l", "#"), + "lU#l-UUUH-8765") + checkEvaluation(new MaskShowLastN( + Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), + "llll-UUUU-###5-4321") + checkEvaluation( + new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), + "llll-UUUU-nnn5-4321") + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), + "xxxx-UUUU-nnn5-4321") + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6)), + "xxxx-XXXX-nnn5-4321") + intercept[AnalysisException] { + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") + } + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321")), "xxxx-XXXX-nnnn-4321") + checkEvaluation(new MaskShowLastN(Literal(null, StringType)), null) + checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), + "llll-UUUU-nnnn-4321") + checkEvaluation(new MaskShowLastN( + Literal("abcd-EFGH-8765-4321"), + Literal(null, IntegerType), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "xxxx-XXXX-nnnn-4321") + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), + "xxxx-UUUU-nnn5-4321") + checkEvaluation(new MaskShowLastN(Literal("")), "") + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), + "xxxx-XXXX-nnnn-4321") + checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), + "abcd-EFGH-8765-4321") + checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), + "xxxx-XXXX-nnnn-nnnn") + // scalastyle:off nonascii + checkEvaluation(MaskShowLastN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U") + checkEvaluation(new MaskShowLastN(Literal("あ, 𠀋, Hello World"), Literal(10)), + "あ, 𠀋, Xello World") + // scalastyle:on nonascii + } + + test("mask_hash") { + checkEvaluation(MaskHash(Literal("abcd-EFGH-8765-4321")), "60c713f5ec6912229d2060df1c322776") + checkEvaluation(MaskHash(Literal("")), "d41d8cd98f00b204e9800998ecf8427e") + checkEvaluation(MaskHash(Literal(null, StringType)), null) + // scalastyle:off nonascii + checkEvaluation(MaskHash(Literal("\u2200x9U")), "f1243ef123d516b1f32a3a75309e5711") + // scalastyle:on nonascii + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5ab9cb3fb86a5..443ba2aa3757d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3499,6 +3499,125 @@ object functions { */ def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) } + ////////////////////////////////////////////////////////////////////////////////////////////// + // Mask functions + ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Returns a string which is the masked representation of the input. + * @group mask_funcs + * @since 2.4.0 + */ + def mask(e: Column): Column = withExpr { new Mask(e.expr) } + + /** + * Returns a string which is the masked representation of the input, using `upper`, `lower` and + * `digit` as replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask(e: Column, upper: String, lower: String, digit: String): Column = withExpr { + Mask(e.expr, upper, lower, digit) + } + + /** + * Returns a string with the first `n` characters masked. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_first_n(e: Column, n: Int): Column = withExpr { new MaskFirstN(e.expr, Literal(n)) } + + /** + * Returns a string with the first `n` characters masked, using `upper`, `lower` and `digit` as + * replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_first_n( + e: Column, + n: Int, + upper: String, + lower: String, + digit: String): Column = withExpr { + MaskFirstN(e.expr, n, upper, lower, digit) + } + + /** + * Returns a string with the last `n` characters masked. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_last_n(e: Column, n: Int): Column = withExpr { new MaskLastN(e.expr, Literal(n)) } + + /** + * Returns a string with the last `n` characters masked, using `upper`, `lower` and `digit` as + * replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_last_n( + e: Column, + n: Int, + upper: String, + lower: String, + digit: String): Column = withExpr { + MaskLastN(e.expr, n, upper, lower, digit) + } + + /** + * Returns a string with all but the first `n` characters masked. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_show_first_n(e: Column, n: Int): Column = withExpr { + new MaskShowFirstN(e.expr, Literal(n)) + } + + /** + * Returns a string with all but the first `n` characters masked, using `upper`, `lower` and + * `digit` as replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_show_first_n( + e: Column, + n: Int, + upper: String, + lower: String, + digit: String): Column = withExpr { + MaskShowFirstN(e.expr, n, upper, lower, digit) + } + + /** + * Returns a string with all but the last `n` characters masked. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_show_last_n(e: Column, n: Int): Column = withExpr { + new MaskShowLastN(e.expr, Literal(n)) + } + + /** + * Returns a string with all but the last `n` characters masked, using `upper`, `lower` and + * `digit` as replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_show_last_n( + e: Column, + n: Int, + upper: String, + lower: String, + digit: String): Column = withExpr { + MaskShowLastN(e.expr, n, upper, lower, digit) + } + + /** + * Returns a hashed value based on the input column. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_hash(e: Column): Column = withExpr { MaskHash(e.expr) } + // scalastyle:off line.size.limit // scalastyle:off parameter.number diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 79e743d961af8..cc8bad4ded53e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -276,6 +276,113 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("mask functions") { + val df = Seq("TestString-123", "", null).toDF("a") + checkAnswer(df.select(mask($"a")), Seq(Row("XxxxXxxxxx-nnn"), Row(""), Row(null))) + checkAnswer(df.select(mask_first_n($"a", 4)), Seq(Row("XxxxString-123"), Row(""), Row(null))) + checkAnswer(df.select(mask_last_n($"a", 4)), Seq(Row("TestString-nnn"), Row(""), Row(null))) + checkAnswer(df.select(mask_show_first_n($"a", 4)), + Seq(Row("TestXxxxxx-nnn"), Row(""), Row(null))) + checkAnswer(df.select(mask_show_last_n($"a", 4)), + Seq(Row("XxxxXxxxxx-123"), Row(""), Row(null))) + checkAnswer(df.select(mask_hash($"a")), + Seq(Row("dd78d68ad1b23bde126812482dd70ac6"), + Row("d41d8cd98f00b204e9800998ecf8427e"), + Row(null))) + + checkAnswer(df.select(mask($"a", "U", "l", "#")), + Seq(Row("UlllUlllll-###"), Row(""), Row(null))) + checkAnswer(df.select(mask_first_n($"a", 4, "U", "l", "#")), + Seq(Row("UlllString-123"), Row(""), Row(null))) + checkAnswer(df.select(mask_last_n($"a", 4, "U", "l", "#")), + Seq(Row("TestString-###"), Row(""), Row(null))) + checkAnswer(df.select(mask_show_first_n($"a", 4, "U", "l", "#")), + Seq(Row("TestUlllll-###"), Row(""), Row(null))) + checkAnswer(df.select(mask_show_last_n($"a", 4, "U", "l", "#")), + Seq(Row("UlllUlllll-123"), Row(""), Row(null))) + + checkAnswer( + df.selectExpr("mask(a)", "mask(a, 'U')", "mask(a, 'U', 'l')", "mask(a, 'U', 'l', '#')"), + Seq(Row("XxxxXxxxxx-nnn", "UxxxUxxxxx-nnn", "UlllUlllll-nnn", "UlllUlllll-###"), + Row("", "", "", ""), + Row(null, null, null, null))) + checkAnswer(sql("select mask(null)"), Row(null)) + checkAnswer(sql("select mask('AAaa11', null, null, null)"), Row("XXxxnn")) + intercept[AnalysisException] { + checkAnswer(df.selectExpr("mask(a, a)"), Seq(Row("XxxxXxxxxx-nnn"), Row(""), Row(null))) + } + + checkAnswer( + df.selectExpr( + "mask_first_n(a)", + "mask_first_n(a, 6)", + "mask_first_n(a, 6, 'U')", + "mask_first_n(a, 6, 'U', 'l')", + "mask_first_n(a, 6, 'U', 'l', '#')"), + Seq(Row("XxxxString-123", "XxxxXxring-123", "UxxxUxring-123", "UlllUlring-123", + "UlllUlring-123"), + Row("", "", "", "", ""), + Row(null, null, null, null, null))) + checkAnswer(sql("select mask_first_n(null)"), Row(null)) + checkAnswer(sql("select mask_first_n('A1aA1a', null, null, null, null)"), Row("XnxX1a")) + intercept[AnalysisException] { + checkAnswer(spark.range(1).selectExpr("mask_first_n('A1aA1a', id)"), Row("XnxX1a")) + } + + checkAnswer( + df.selectExpr( + "mask_last_n(a)", + "mask_last_n(a, 6)", + "mask_last_n(a, 6, 'U')", + "mask_last_n(a, 6, 'U', 'l')", + "mask_last_n(a, 6, 'U', 'l', '#')"), + Seq(Row("TestString-nnn", "TestStrixx-nnn", "TestStrixx-nnn", "TestStrill-nnn", + "TestStrill-###"), + Row("", "", "", "", ""), + Row(null, null, null, null, null))) + checkAnswer(sql("select mask_last_n(null)"), Row(null)) + checkAnswer(sql("select mask_last_n('A1aA1a', null, null, null, null)"), Row("A1xXnx")) + intercept[AnalysisException] { + checkAnswer(spark.range(1).selectExpr("mask_last_n('A1aA1a', id)"), Row("A1xXnx")) + } + + checkAnswer( + df.selectExpr( + "mask_show_first_n(a)", + "mask_show_first_n(a, 6)", + "mask_show_first_n(a, 6, 'U')", + "mask_show_first_n(a, 6, 'U', 'l')", + "mask_show_first_n(a, 6, 'U', 'l', '#')"), + Seq(Row("TestXxxxxx-nnn", "TestStxxxx-nnn", "TestStxxxx-nnn", "TestStllll-nnn", + "TestStllll-###"), + Row("", "", "", "", ""), + Row(null, null, null, null, null))) + checkAnswer(sql("select mask_show_first_n(null)"), Row(null)) + checkAnswer(sql("select mask_show_first_n('A1aA1a', null, null, null, null)"), Row("A1aAnx")) + intercept[AnalysisException] { + checkAnswer(spark.range(1).selectExpr("mask_show_first_n('A1aA1a', id)"), Row("A1aAnx")) + } + + checkAnswer( + df.selectExpr( + "mask_show_last_n(a)", + "mask_show_last_n(a, 6)", + "mask_show_last_n(a, 6, 'U')", + "mask_show_last_n(a, 6, 'U', 'l')", + "mask_show_last_n(a, 6, 'U', 'l', '#')"), + Seq(Row("XxxxXxxxxx-123", "XxxxXxxxng-123", "UxxxUxxxng-123", "UlllUlllng-123", + "UlllUlllng-123"), + Row("", "", "", "", ""), + Row(null, null, null, null, null))) + checkAnswer(sql("select mask_show_last_n(null)"), Row(null)) + checkAnswer(sql("select mask_show_last_n('A1aA1a', null, null, null, null)"), Row("XnaA1a")) + intercept[AnalysisException] { + checkAnswer(spark.range(1).selectExpr("mask_show_last_n('A1aA1a', id)"), Row("XnaA1a")) + } + + checkAnswer(sql("select mask_hash(null)"), Row(null)) + } + test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), From 24ef7fbfa94fc85663eaf49cc574f81387c66c62 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 30 May 2018 15:31:40 -0700 Subject: [PATCH 0895/2461] [SPARK-24276][SQL] Order of literals in IN should not affect semantic equality ## What changes were proposed in this pull request? When two `In` operators are created with the same list of values, but different order, we are considering them as semantically different. This is wrong, since they have the same semantic meaning. The PR adds a canonicalization rule which orders the literals in the `In` operator so the semantic equality works properly. ## How was this patch tested? added UT Author: Marco Gaido Closes #21331 from mgaido91/SPARK-24276. --- .../catalyst/expressions/Canonicalize.scala | 6 +++ .../expressions/CanonicalizeSuite.scala | 53 +++++++++++++++++++ 2 files changed, 59 insertions(+) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index d848ba18356d3..7541f527a52a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -30,6 +30,7 @@ package org.apache.spark.sql.catalyst.expressions * by `hashCode`. * - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`. * - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`. + * - Elements in [[In]] are reordered by `hashCode`. */ object Canonicalize { def execute(e: Expression): Expression = { @@ -85,6 +86,11 @@ object Canonicalize { case Not(GreaterThanOrEqual(l, r)) => LessThan(l, r) case Not(LessThanOrEqual(l, r)) => GreaterThan(l, r) + // order the list in the In operator + // In subqueries contain only one element of type ListQuery. So checking that the length > 1 + // we are not reordering In subqueries. + case In(value, list) if list.length > 1 => In(value, list.sortBy(_.hashCode())) + case _ => e } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala new file mode 100644 index 0000000000000..28e6940f3cca3 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.logical.Range + +class CanonicalizeSuite extends SparkFunSuite { + + test("SPARK-24276: IN expression with different order are semantically equal") { + val range = Range(1, 1, 1, 1) + val idAttr = range.output.head + + val in1 = In(idAttr, Seq(Literal(1), Literal(2))) + val in2 = In(idAttr, Seq(Literal(2), Literal(1))) + val in3 = In(idAttr, Seq(Literal(1), Literal(2), Literal(3))) + + assert(in1.canonicalized.semanticHash() == in2.canonicalized.semanticHash()) + assert(in1.canonicalized.semanticHash() != in3.canonicalized.semanticHash()) + + assert(range.where(in1).sameResult(range.where(in2))) + assert(!range.where(in1).sameResult(range.where(in3))) + + val arrays1 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))), + CreateArray(Seq(Literal(2), Literal(1))))) + val arrays2 = In(idAttr, Seq(CreateArray(Seq(Literal(2), Literal(1))), + CreateArray(Seq(Literal(1), Literal(2))))) + val arrays3 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))), + CreateArray(Seq(Literal(3), Literal(1))))) + + assert(arrays1.canonicalized.semanticHash() == arrays2.canonicalized.semanticHash()) + assert(arrays1.canonicalized.semanticHash() != arrays3.canonicalized.semanticHash()) + + assert(range.where(arrays1).sameResult(range.where(arrays2))) + assert(!range.where(arrays1).sameResult(range.where(arrays3))) + } +} From 0053e153faaa76ea38c845adab137d5be970e5af Mon Sep 17 00:00:00 2001 From: William Sheu Date: Wed, 30 May 2018 22:37:27 -0700 Subject: [PATCH 0896/2461] [SPARK-24337][CORE] Improve error messages for Spark conf values ## What changes were proposed in this pull request? Improve the exception messages when retrieving Spark conf values to include the key name when the value is invalid. ## How was this patch tested? Unit tests for all get* operations in SparkConf that require a specific value format Author: William Sheu Closes #21454 from PenguinToast/SPARK-24337-spark-config-errors. --- .../scala/org/apache/spark/SparkConf.scala | 85 ++++++++++++++----- .../org/apache/spark/SparkConfSuite.scala | 32 +++++++ 2 files changed, 96 insertions(+), 21 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index dab409572646f..6c4c5c94cfa28 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -265,16 +265,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a time parameter as seconds; throws a NoSuchElementException if it's not set. If no * suffix is provided then seconds are assumed. * @throws java.util.NoSuchElementException If the time parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as seconds */ - def getTimeAsSeconds(key: String): Long = { + def getTimeAsSeconds(key: String): Long = catchIllegalValue(key) { Utils.timeStringAsSeconds(get(key)) } /** * Get a time parameter as seconds, falling back to a default if not set. If no * suffix is provided then seconds are assumed. + * @throws NumberFormatException If the value cannot be interpreted as seconds */ - def getTimeAsSeconds(key: String, defaultValue: String): Long = { + def getTimeAsSeconds(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.timeStringAsSeconds(get(key, defaultValue)) } @@ -282,16 +284,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a time parameter as milliseconds; throws a NoSuchElementException if it's not set. If no * suffix is provided then milliseconds are assumed. * @throws java.util.NoSuchElementException If the time parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as milliseconds */ - def getTimeAsMs(key: String): Long = { + def getTimeAsMs(key: String): Long = catchIllegalValue(key) { Utils.timeStringAsMs(get(key)) } /** * Get a time parameter as milliseconds, falling back to a default if not set. If no * suffix is provided then milliseconds are assumed. + * @throws NumberFormatException If the value cannot be interpreted as milliseconds */ - def getTimeAsMs(key: String, defaultValue: String): Long = { + def getTimeAsMs(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.timeStringAsMs(get(key, defaultValue)) } @@ -299,23 +303,26 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a size parameter as bytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then bytes are assumed. * @throws java.util.NoSuchElementException If the size parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as bytes */ - def getSizeAsBytes(key: String): Long = { + def getSizeAsBytes(key: String): Long = catchIllegalValue(key) { Utils.byteStringAsBytes(get(key)) } /** * Get a size parameter as bytes, falling back to a default if not set. If no * suffix is provided then bytes are assumed. + * @throws NumberFormatException If the value cannot be interpreted as bytes */ - def getSizeAsBytes(key: String, defaultValue: String): Long = { + def getSizeAsBytes(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.byteStringAsBytes(get(key, defaultValue)) } /** * Get a size parameter as bytes, falling back to a default if not set. + * @throws NumberFormatException If the value cannot be interpreted as bytes */ - def getSizeAsBytes(key: String, defaultValue: Long): Long = { + def getSizeAsBytes(key: String, defaultValue: Long): Long = catchIllegalValue(key) { Utils.byteStringAsBytes(get(key, defaultValue + "B")) } @@ -323,16 +330,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Kibibytes are assumed. * @throws java.util.NoSuchElementException If the size parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as Kibibytes */ - def getSizeAsKb(key: String): Long = { + def getSizeAsKb(key: String): Long = catchIllegalValue(key) { Utils.byteStringAsKb(get(key)) } /** * Get a size parameter as Kibibytes, falling back to a default if not set. If no * suffix is provided then Kibibytes are assumed. + * @throws NumberFormatException If the value cannot be interpreted as Kibibytes */ - def getSizeAsKb(key: String, defaultValue: String): Long = { + def getSizeAsKb(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.byteStringAsKb(get(key, defaultValue)) } @@ -340,16 +349,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Mebibytes are assumed. * @throws java.util.NoSuchElementException If the size parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as Mebibytes */ - def getSizeAsMb(key: String): Long = { + def getSizeAsMb(key: String): Long = catchIllegalValue(key) { Utils.byteStringAsMb(get(key)) } /** * Get a size parameter as Mebibytes, falling back to a default if not set. If no * suffix is provided then Mebibytes are assumed. + * @throws NumberFormatException If the value cannot be interpreted as Mebibytes */ - def getSizeAsMb(key: String, defaultValue: String): Long = { + def getSizeAsMb(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.byteStringAsMb(get(key, defaultValue)) } @@ -357,16 +368,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Gibibytes are assumed. * @throws java.util.NoSuchElementException If the size parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as Gibibytes */ - def getSizeAsGb(key: String): Long = { + def getSizeAsGb(key: String): Long = catchIllegalValue(key) { Utils.byteStringAsGb(get(key)) } /** * Get a size parameter as Gibibytes, falling back to a default if not set. If no * suffix is provided then Gibibytes are assumed. + * @throws NumberFormatException If the value cannot be interpreted as Gibibytes */ - def getSizeAsGb(key: String, defaultValue: String): Long = { + def getSizeAsGb(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.byteStringAsGb(get(key, defaultValue)) } @@ -394,23 +407,35 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria } - /** Get a parameter as an integer, falling back to a default if not set */ - def getInt(key: String, defaultValue: Int): Int = { + /** + * Get a parameter as an integer, falling back to a default if not set + * @throws NumberFormatException If the value cannot be interpreted as an integer + */ + def getInt(key: String, defaultValue: Int): Int = catchIllegalValue(key) { getOption(key).map(_.toInt).getOrElse(defaultValue) } - /** Get a parameter as a long, falling back to a default if not set */ - def getLong(key: String, defaultValue: Long): Long = { + /** + * Get a parameter as a long, falling back to a default if not set + * @throws NumberFormatException If the value cannot be interpreted as a long + */ + def getLong(key: String, defaultValue: Long): Long = catchIllegalValue(key) { getOption(key).map(_.toLong).getOrElse(defaultValue) } - /** Get a parameter as a double, falling back to a default if not set */ - def getDouble(key: String, defaultValue: Double): Double = { + /** + * Get a parameter as a double, falling back to a default if not ste + * @throws NumberFormatException If the value cannot be interpreted as a double + */ + def getDouble(key: String, defaultValue: Double): Double = catchIllegalValue(key) { getOption(key).map(_.toDouble).getOrElse(defaultValue) } - /** Get a parameter as a boolean, falling back to a default if not set */ - def getBoolean(key: String, defaultValue: Boolean): Boolean = { + /** + * Get a parameter as a boolean, falling back to a default if not set + * @throws IllegalArgumentException If the value cannot be interpreted as a boolean + */ + def getBoolean(key: String, defaultValue: Boolean): Boolean = catchIllegalValue(key) { getOption(key).map(_.toBoolean).getOrElse(defaultValue) } @@ -448,6 +473,24 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria */ private[spark] def getenv(name: String): String = System.getenv(name) + /** + * Wrapper method for get() methods which require some specific value format. This catches + * any [[NumberFormatException]] or [[IllegalArgumentException]] and re-raises it with the + * incorrectly configured key in the exception message. + */ + private def catchIllegalValue[T](key: String)(getValue: => T): T = { + try { + getValue + } catch { + case e: NumberFormatException => + // NumberFormatException doesn't have a constructor that takes a cause for some reason. + throw new NumberFormatException(s"Illegal value for config key $key: ${e.getMessage}") + .initCause(e) + case e: IllegalArgumentException => + throw new IllegalArgumentException(s"Illegal value for config key $key: ${e.getMessage}", e) + } + } + /** * Checks for illegal or deprecated config settings. Throws an exception for the former. Not * idempotent - may mutate this conf object to convert deprecated settings to supported ones. diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index bff808eb540ac..0d06b02e74e34 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -339,6 +339,38 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst } } + val defaultIllegalValue = "SomeIllegalValue" + val illegalValueTests : Map[String, (SparkConf, String) => Any] = Map( + "getTimeAsSeconds" -> (_.getTimeAsSeconds(_)), + "getTimeAsSeconds with default" -> (_.getTimeAsSeconds(_, defaultIllegalValue)), + "getTimeAsMs" -> (_.getTimeAsMs(_)), + "getTimeAsMs with default" -> (_.getTimeAsMs(_, defaultIllegalValue)), + "getSizeAsBytes" -> (_.getSizeAsBytes(_)), + "getSizeAsBytes with default string" -> (_.getSizeAsBytes(_, defaultIllegalValue)), + "getSizeAsBytes with default long" -> (_.getSizeAsBytes(_, 0L)), + "getSizeAsKb" -> (_.getSizeAsKb(_)), + "getSizeAsKb with default" -> (_.getSizeAsKb(_, defaultIllegalValue)), + "getSizeAsMb" -> (_.getSizeAsMb(_)), + "getSizeAsMb with default" -> (_.getSizeAsMb(_, defaultIllegalValue)), + "getSizeAsGb" -> (_.getSizeAsGb(_)), + "getSizeAsGb with default" -> (_.getSizeAsGb(_, defaultIllegalValue)), + "getInt" -> (_.getInt(_, 0)), + "getLong" -> (_.getLong(_, 0L)), + "getDouble" -> (_.getDouble(_, 0.0)), + "getBoolean" -> (_.getBoolean(_, false)) + ) + + illegalValueTests.foreach { case (name, getValue) => + test(s"SPARK-24337: $name throws an useful error message with key name") { + val key = "SomeKey" + val conf = new SparkConf() + conf.set(key, "SomeInvalidValue") + val thrown = intercept[IllegalArgumentException] { + getValue(conf, key) + } + assert(thrown.getMessage.contains(key)) + } + } } class Class1 {} From 90ae98d1accb3e4b7d381de072257bdece8dd7e0 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 31 May 2018 06:53:10 -0700 Subject: [PATCH 0897/2461] [SPARK-24146][PYSPARK][ML] spark.ml parity for sequential pattern mining - PrefixSpan: Python API ## What changes were proposed in this pull request? spark.ml parity for sequential pattern mining - PrefixSpan: Python API ## How was this patch tested? doctests Author: WeichenXu Closes #21265 from WeichenXu123/prefix_span_py. --- .../org/apache/spark/ml/fpm/PrefixSpan.scala | 6 +- python/pyspark/ml/fpm.py | 104 +++++++++++++++++- 2 files changed, 106 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala index 41716c621ca98..bd1c1a8885201 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -53,7 +53,7 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params @Since("2.4.0") val minSupport = new DoubleParam(this, "minSupport", "The minimal support level of the " + "sequential pattern. Sequential pattern that appears more than " + - "(minSupport * size-of-the-dataset)." + + "(minSupport * size-of-the-dataset) " + "times will be output.", ParamValidators.gtEq(0.0)) /** @group getParam */ @@ -128,10 +128,10 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params * Finds the complete set of frequent sequential patterns in the input sequences of itemsets. * * @param dataset A dataset or a dataframe containing a sequence column which is - * {{{Seq[Seq[_]]}}} type + * {{{ArrayType(ArrayType(T))}}} type, T is the item type for the input dataset. * @return A `DataFrame` that contains columns of sequence and corresponding frequency. * The schema of it will be: - * - `sequence: Seq[Seq[T]]` (T is the item type) + * - `sequence: ArrayType(ArrayType(T))` (T is the item type) * - `freq: Long` */ @Since("2.4.0") diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index b8dafd49d354d..fd19fd96c4df6 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -16,8 +16,9 @@ # from pyspark import keyword_only, since +from pyspark.sql import DataFrame from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, _jvm from pyspark.ml.param.shared import * __all__ = ["FPGrowth", "FPGrowthModel"] @@ -243,3 +244,104 @@ def setParams(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", def _create_model(self, java_model): return FPGrowthModel(java_model) + + +class PrefixSpan(JavaParams): + """ + .. note:: Experimental + + A parallel PrefixSpan algorithm to mine frequent sequential patterns. + The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns + Efficiently by Prefix-Projected Pattern Growth + (see here). + This class is not yet an Estimator/Transformer, use :py:func:`findFrequentSequentialPatterns` + method to run the PrefixSpan algorithm. + + @see Sequential Pattern Mining + (Wikipedia) + .. versionadded:: 2.4.0 + + """ + + minSupport = Param(Params._dummy(), "minSupport", "The minimal support level of the " + + "sequential pattern. Sequential pattern that appears more than " + + "(minSupport * size-of-the-dataset) times will be output. Must be >= 0.", + typeConverter=TypeConverters.toFloat) + + maxPatternLength = Param(Params._dummy(), "maxPatternLength", + "The maximal length of the sequential pattern. Must be > 0.", + typeConverter=TypeConverters.toInt) + + maxLocalProjDBSize = Param(Params._dummy(), "maxLocalProjDBSize", + "The maximum number of items (including delimiters used in the " + + "internal storage format) allowed in a projected database before " + + "local processing. If a projected database exceeds this size, " + + "another iteration of distributed prefix growth is run. " + + "Must be > 0.", + typeConverter=TypeConverters.toInt) + + sequenceCol = Param(Params._dummy(), "sequenceCol", "The name of the sequence column in " + + "dataset, rows with nulls in this column are ignored.", + typeConverter=TypeConverters.toString) + + @keyword_only + def __init__(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, + sequenceCol="sequence"): + """ + __init__(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \ + sequenceCol="sequence") + """ + super(PrefixSpan, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.fpm.PrefixSpan", self.uid) + self._setDefault(minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, + sequenceCol="sequence") + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.4.0") + def setParams(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, + sequenceCol="sequence"): + """ + setParams(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \ + sequenceCol="sequence") + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + @since("2.4.0") + def findFrequentSequentialPatterns(self, dataset): + """ + .. note:: Experimental + Finds the complete set of frequent sequential patterns in the input sequences of itemsets. + + :param dataset: A dataframe containing a sequence column which is + `ArrayType(ArrayType(T))` type, T is the item type for the input dataset. + :return: A `DataFrame` that contains columns of sequence and corresponding frequency. + The schema of it will be: + - `sequence: ArrayType(ArrayType(T))` (T is the item type) + - `freq: Long` + + >>> from pyspark.ml.fpm import PrefixSpan + >>> from pyspark.sql import Row + >>> df = sc.parallelize([Row(sequence=[[1, 2], [3]]), + ... Row(sequence=[[1], [3, 2], [1, 2]]), + ... Row(sequence=[[1, 2], [5]]), + ... Row(sequence=[[6]])]).toDF() + >>> prefixSpan = PrefixSpan(minSupport=0.5, maxPatternLength=5) + >>> prefixSpan.findFrequentSequentialPatterns(df).sort("sequence").show(truncate=False) + +----------+----+ + |sequence |freq| + +----------+----+ + |[[1]] |3 | + |[[1], [3]]|2 | + |[[1, 2]] |3 | + |[[2]] |3 | + |[[3]] |2 | + +----------+----+ + + .. versionadded:: 2.4.0 + """ + self._transfer_params_to_java() + jdf = self._java_obj.findFrequentSequentialPatterns(dataset._jdf) + return DataFrame(jdf, dataset.sql_ctx) From 698b9a0981f0ec322e15d6ac89cc38c8f49ed33d Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 31 May 2018 09:34:39 -0700 Subject: [PATCH 0898/2461] [WEBUI] Avoid possibility of script in query param keys As discussed separately, this avoids the possibility of XSS on certain request param keys. CC vanzin Author: Sean Owen Closes #21464 from srowen/XSS2. --- .../src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala | 4 +++- core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index f651fe97c2cd5..178d2c8d1a10a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -206,7 +206,9 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We jobs: Seq[v1.JobData], killEnabled: Boolean): Seq[Node] = { // stripXSS is called to remove suspicious characters used in XSS attacks - val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS)) + val allParameters = request.getParameterMap.asScala.toMap.map { case (k, v) => + UIUtils.stripXSS(k) -> v.map(UIUtils.stripXSS).toSeq + } val parameterOtherTable = allParameters.filterNot(_._1.startsWith(jobTag)) .map(para => para._1 + "=" + para._2(0)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index b8b20db1fa407..56e4d6838a99a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -43,7 +43,9 @@ private[ui] class StageTableBase( killEnabled: Boolean, isFailedStage: Boolean) { // stripXSS is called to remove suspicious characters used in XSS attacks - val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS)) + val allParameters = request.getParameterMap.asScala.toMap.map { case (k, v) => + UIUtils.stripXSS(k) -> v.map(UIUtils.stripXSS).toSeq + } val parameterOtherTable = allParameters.filterNot(_._1.startsWith(stageTag)) .map(para => para._1 + "=" + para._2(0)) From 7a82e93b349b4f414f2075dd5add8e4ed72fe357 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 31 May 2018 10:05:20 -0700 Subject: [PATCH 0899/2461] [SPARK-24414][UI] Calculate the correct number of tasks for a stage. This change takes into account all non-pending tasks when calculating the number of tasks to be shown. This also means that when the stage is pending, the task table (or, in fact, most of the data in the stage page) will not be rendered. I also fixed the label when the known number of tasks is larger than the recorded number of tasks (it was inverted). Author: Marcelo Vanzin Closes #21457 from vanzin/SPARK-24414. --- .../scala/org/apache/spark/ui/jobs/StagePage.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 2575914121c39..d4e6a7bc3effa 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -117,8 +117,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val localitySummary = store.localitySummary(stageData.stageId, stageData.attemptId) - val totalTasks = stageData.numActiveTasks + stageData.numCompleteTasks + - stageData.numFailedTasks + stageData.numKilledTasks + val totalTasks = taskCount(stageData) if (totalTasks == 0) { val content =
    @@ -133,7 +132,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val totalTasksNumStr = if (totalTasks == storedTasks) { s"$totalTasks" } else { - s"$totalTasks, showing ${storedTasks}" + s"$storedTasks, showing ${totalTasks}" } val summary = @@ -686,7 +685,7 @@ private[ui] class TaskDataSource( private var _tasksToShow: Seq[TaskData] = null - override def dataSize: Int = stage.numTasks + override def dataSize: Int = taskCount(stage) override def sliceData(from: Int, to: Int): Seq[TaskData] = { if (_tasksToShow == null) { @@ -1052,4 +1051,9 @@ private[ui] object ApiHelper { (stage.map(_.name).getOrElse(""), stage.flatMap(_.description).getOrElse(job.name)) } + def taskCount(stageData: StageData): Int = { + stageData.numActiveTasks + stageData.numCompleteTasks + stageData.numFailedTasks + + stageData.numKilledTasks + } + } From 223df5d9d4fbf48db017edb41f9b7e4033679f35 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 31 May 2018 11:23:57 -0700 Subject: [PATCH 0900/2461] [SPARK-24397][PYSPARK] Added TaskContext.getLocalProperty(key) in Python ## What changes were proposed in this pull request? This adds a new API `TaskContext.getLocalProperty(key)` to the Python TaskContext. It mirrors the Java TaskContext API of returning a string value if the key exists, or None if the key does not exist. ## How was this patch tested? New test added. Author: Tathagata Das Closes #21437 from tdas/SPARK-24397. --- .../org/apache/spark/api/python/PythonRunner.scala | 7 +++++++ python/pyspark/taskcontext.py | 7 +++++++ python/pyspark/tests.py | 14 ++++++++++++++ python/pyspark/worker.py | 6 ++++++ 4 files changed, 34 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index f075a7e0eb0b4..41eac10d9b267 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -183,6 +183,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( dataOut.writeInt(context.partitionId()) dataOut.writeInt(context.attemptNumber()) dataOut.writeLong(context.taskAttemptId()) + val localProps = context.asInstanceOf[TaskContextImpl].getLocalProperties.asScala + dataOut.writeInt(localProps.size) + localProps.foreach { case (k, v) => + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v, dataOut) + } + // sparkFilesDir PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) // Python includes (*.zip and *.egg files) diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index e5218d9e75e78..63ae1f30e17ca 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -34,6 +34,7 @@ class TaskContext(object): _partitionId = None _stageId = None _taskAttemptId = None + _localProperties = None def __new__(cls): """Even if users construct TaskContext instead of using get, give them the singleton.""" @@ -88,3 +89,9 @@ def taskAttemptId(self): TaskAttemptID. """ return self._taskAttemptId + + def getLocalProperty(self, key): + """ + Get a local property set upstream in the driver, or None if it is missing. + """ + return self._localProperties.get(key, None) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 3b37cc028c1b7..30723b8e15b36 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -574,6 +574,20 @@ def test_tc_on_driver(self): tc = TaskContext.get() self.assertTrue(tc is None) + def test_get_local_property(self): + """Verify that local properties set on the driver are available in TaskContext.""" + key = "testkey" + value = "testvalue" + self.sc.setLocalProperty(key, value) + try: + rdd = self.sc.parallelize(range(1), 1) + prop1 = rdd.map(lambda x: TaskContext.get().getLocalProperty(key)).collect()[0] + self.assertEqual(prop1, value) + prop2 = rdd.map(lambda x: TaskContext.get().getLocalProperty("otherkey")).collect()[0] + self.assertTrue(prop2 is None) + finally: + self.sc.setLocalProperty(key, None) + class RDDTests(ReusedPySparkTestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 5d2e58bef6466..fbcb8af8bfb24 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -222,6 +222,12 @@ def main(infile, outfile): taskContext._partitionId = read_int(infile) taskContext._attemptNumber = read_int(infile) taskContext._taskAttemptId = read_long(infile) + taskContext._localProperties = dict() + for i in range(read_int(infile)): + k = utf8_deserializer.loads(infile) + v = utf8_deserializer.loads(infile) + taskContext._localProperties[k] = v + shuffle.MemoryBytesSpilled = 0 shuffle.DiskBytesSpilled = 0 _accumulatorRegistry.clear() From cc976f6cb858adb5f52987b56dda54769915ce50 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 31 May 2018 11:38:23 -0700 Subject: [PATCH 0901/2461] [SPARK-23900][SQL] format_number support user specifed format as argument ## What changes were proposed in this pull request? `format_number` support user specifed format as argument. For example: ```sql spark-sql> SELECT format_number(12332.123456, '##################.###'); 12332.123 ``` ## How was this patch tested? unit test Author: Yuming Wang Closes #21010 from wangyum/SPARK-23900. --- .../expressions/stringExpressions.scala | 142 ++++++++++++------ .../expressions/StringExpressionsSuite.scala | 24 +++ 2 files changed, 116 insertions(+), 50 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 9823b2fc5ad97..bedad7da334ae 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1916,12 +1916,15 @@ case class Encode(value: Expression, charset: Expression) usage = """ _FUNC_(expr1, expr2) - Formats the number `expr1` like '#,###,###.##', rounded to `expr2` decimal places. If `expr2` is 0, the result has no decimal point or fractional part. + `expr2` also accept a user specified format. This is supposed to function like MySQL's FORMAT. """, examples = """ Examples: > SELECT _FUNC_(12332.123456, 4); 12,332.1235 + > SELECT _FUNC_(12332.123456, '##################.###'); + 12332.123 """) case class FormatNumber(x: Expression, d: Expression) extends BinaryExpression with ExpectsInputTypes { @@ -1930,14 +1933,20 @@ case class FormatNumber(x: Expression, d: Expression) override def right: Expression = d override def dataType: DataType = StringType override def nullable: Boolean = true - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) + override def inputTypes: Seq[AbstractDataType] = + Seq(NumericType, TypeCollection(IntegerType, StringType)) + + private val defaultFormat = "#,###,###,###,###,###,##0" // Associated with the pattern, for the last d value, and we will update the // pattern (DecimalFormat) once the new coming d value differ with the last one. // This is an Option to distinguish between 0 (numberFormat is valid) and uninitialized after // serialization (numberFormat has not been updated for dValue = 0). @transient - private var lastDValue: Option[Int] = None + private var lastDIntValue: Option[Int] = None + + @transient + private var lastDStringValue: Option[String] = None // A cached DecimalFormat, for performance concern, we will change it // only if the d value changed. @@ -1950,33 +1959,49 @@ case class FormatNumber(x: Expression, d: Expression) private lazy val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US)) override protected def nullSafeEval(xObject: Any, dObject: Any): Any = { - val dValue = dObject.asInstanceOf[Int] - if (dValue < 0) { - return null - } - - lastDValue match { - case Some(last) if last == dValue => - // use the current pattern - case _ => - // construct a new DecimalFormat only if a new dValue - pattern.delete(0, pattern.length) - pattern.append("#,###,###,###,###,###,##0") - - // decimal place - if (dValue > 0) { - pattern.append(".") - - var i = 0 - while (i < dValue) { - i += 1 - pattern.append("0") - } + right.dataType match { + case IntegerType => + val dValue = dObject.asInstanceOf[Int] + if (dValue < 0) { + return null } - lastDValue = Some(dValue) + lastDIntValue match { + case Some(last) if last == dValue => + // use the current pattern + case _ => + // construct a new DecimalFormat only if a new dValue + pattern.delete(0, pattern.length) + pattern.append(defaultFormat) + + // decimal place + if (dValue > 0) { + pattern.append(".") + + var i = 0 + while (i < dValue) { + i += 1 + pattern.append("0") + } + } + + lastDIntValue = Some(dValue) - numberFormat.applyLocalizedPattern(pattern.toString) + numberFormat.applyLocalizedPattern(pattern.toString) + } + case StringType => + val dValue = dObject.asInstanceOf[UTF8String].toString + lastDStringValue match { + case Some(last) if last == dValue => + case _ => + pattern.delete(0, pattern.length) + lastDStringValue = Some(dValue) + if (dValue.isEmpty) { + numberFormat.applyLocalizedPattern(defaultFormat) + } else { + numberFormat.applyLocalizedPattern(dValue) + } + } } x.dataType match { @@ -2008,35 +2033,52 @@ case class FormatNumber(x: Expression, d: Expression) // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.') // as a decimal separator. val usLocale = "US" - val i = ctx.freshName("i") - val dFormat = ctx.freshName("dFormat") - val lastDValue = - ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;") - val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") val numberFormat = ctx.addMutableState(df, "numberFormat", v => s"""$v = new $df("", new $dfs($l.$usLocale));""") - s""" - if ($d >= 0) { - $pattern.delete(0, $pattern.length()); - if ($d != $lastDValue) { - $pattern.append("#,###,###,###,###,###,##0"); - - if ($d > 0) { - $pattern.append("."); - for (int $i = 0; $i < $d; $i++) { - $pattern.append("0"); + right.dataType match { + case IntegerType => + val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") + val i = ctx.freshName("i") + val lastDValue = + ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;") + s""" + if ($d >= 0) { + $pattern.delete(0, $pattern.length()); + if ($d != $lastDValue) { + $pattern.append("$defaultFormat"); + + if ($d > 0) { + $pattern.append("."); + for (int $i = 0; $i < $d; $i++) { + $pattern.append("0"); + } + } + $lastDValue = $d; + $numberFormat.applyLocalizedPattern($pattern.toString()); } + ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); + } else { + ${ev.value} = null; + ${ev.isNull} = true; } - $lastDValue = $d; - $numberFormat.applyLocalizedPattern($pattern.toString()); - } - ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); - } else { - ${ev.value} = null; - ${ev.isNull} = true; - } - """ + """ + case StringType => + val lastDValue = ctx.addMutableState("String", "lastDValue", v => s"""$v = null;""") + val dValue = ctx.freshName("dValue") + s""" + String $dValue = $d.toString(); + if (!$dValue.equals($lastDValue)) { + $lastDValue = $dValue; + if ($dValue.isEmpty()) { + $numberFormat.applyLocalizedPattern("$defaultFormat"); + } else { + $numberFormat.applyLocalizedPattern($dValue); + } + } + ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); + """ + } }) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index f1a6f9b8889fa..aa334e040d5fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -706,6 +706,30 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { "15,159,339,180,002,773.2778") checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null) assert(FormatNumber(Literal.create(null, NullType), Literal(3)).resolved === false) + + checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##############.###")), "12332.123") + checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##.###")), "12332.123") + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal("##.####")), "4") + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal("##.####")), "4") + checkEvaluation(FormatNumber(Literal(4.0f), Literal("##.###")), "4") + checkEvaluation(FormatNumber(Literal(4), Literal("##.###")), "4") + checkEvaluation(FormatNumber(Literal(12831273.23481d), + Literal("###,###,###,###,###.###")), "12,831,273.235") + checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal("")), "12,831,274") + checkEvaluation(FormatNumber(Literal(123123324123L), Literal("###,###,###,###,###.###")), + "123,123,324,123") + checkEvaluation( + FormatNumber(Literal(Decimal(123123324123L) * Decimal(123123.21234d)), + Literal("###,###,###,###,###.####")), "15,159,339,180,002,773.2778") + checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal("##.###")), null) + assert(FormatNumber(Literal.create(null, NullType), Literal("##.###")).resolved === false) + + checkEvaluation(FormatNumber(Literal(12332.123456), Literal("#,###,###,###,###,###,##0")), + "12,332") + checkEvaluation(FormatNumber( + Literal.create(null, IntegerType), Literal.create(null, StringType)), null) + checkEvaluation(FormatNumber( + Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) } test("find in set") { From 21e1fc7d4aed688d7b685be6ce93f76752159c98 Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Thu, 31 May 2018 14:28:33 -0700 Subject: [PATCH 0902/2461] [SPARK-24232][K8S] Add support for secret env vars ## What changes were proposed in this pull request? * Allows to refer a secret as an env var. * Introduces new config properties in the form: spark.kubernetes{driver,executor}.secretKeyRef.ENV_NAME=name:key ENV_NAME is case sensitive. * Updates docs. * Adds required unit tests. ## How was this patch tested? Manually tested and confirmed that the secrets exist in driver's and executor's container env. Also job finished successfully. First created a secret with the following yaml: ``` apiVersion: v1 kind: Secret metadata: name: test-secret data: username: c3RhdnJvcwo= password: Mzk1MjgkdmRnN0pi ------- $ echo -n 'stavros' | base64 c3RhdnJvcw== $ echo -n '39528$vdg7Jb' | base64 MWYyZDFlMmU2N2Rm ``` Run a job as follows: ```./bin/spark-submit \ --master k8s://http://localhost:9000 \ --deploy-mode cluster \ --name spark-pi \ --class org.apache.spark.examples.SparkPi \ --conf spark.executor.instances=1 \ --conf spark.kubernetes.container.image=skonto/spark:k8envs3 \ --conf spark.kubernetes.driver.secretKeyRef.MY_USERNAME=test-secret:username \ --conf spark.kubernetes.driver.secretKeyRef.My_password=test-secret:password \ --conf spark.kubernetes.executor.secretKeyRef.MY_USERNAME=test-secret:username \ --conf spark.kubernetes.executor.secretKeyRef.My_password=test-secret:password \ local:///opt/spark/examples/jars/spark-examples_2.11-2.4.0-SNAPSHOT.jar 10000 ``` Secret loaded correctly at the driver container: ![image](https://user-images.githubusercontent.com/7945591/40174346-7fee70c8-59dd-11e8-8705-995a5472716f.png) Also if I log into the exec container: kubectl exec -it spark-pi-1526555613156-exec-1 bash bash-4.4# env > SPARK_EXECUTOR_MEMORY=1g > SPARK_EXECUTOR_CORES=1 > LANG=C.UTF-8 > HOSTNAME=spark-pi-1526555613156-exec-1 > SPARK_APPLICATION_ID=spark-application-1526555618626 > **MY_USERNAME=stavros** > > JAVA_HOME=/usr/lib/jvm/java-1.8-openjdk > KUBERNETES_PORT_443_TCP_PROTO=tcp > KUBERNETES_PORT_443_TCP_ADDR=10.100.0.1 > JAVA_VERSION=8u151 > KUBERNETES_PORT=tcp://10.100.0.1:443 > PWD=/opt/spark/work-dir > HOME=/root > SPARK_LOCAL_DIRS=/var/data/spark-b569b0ae-b7ef-4f91-bcd5-0f55535d3564 > KUBERNETES_SERVICE_PORT_HTTPS=443 > KUBERNETES_PORT_443_TCP_PORT=443 > SPARK_HOME=/opt/spark > SPARK_DRIVER_URL=spark://CoarseGrainedSchedulerspark-pi-1526555613156-driver-svc.default.svc:7078 > KUBERNETES_PORT_443_TCP=tcp://10.100.0.1:443 > SPARK_EXECUTOR_POD_IP=9.0.9.77 > TERM=xterm > SPARK_EXECUTOR_ID=1 > SHLVL=1 > KUBERNETES_SERVICE_PORT=443 > SPARK_CONF_DIR=/opt/spark/conf > PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/lib/jvm/java-1.8-openjdk/jre/bin:/usr/lib/jvm/java-1.8-openjdk/bin > JAVA_ALPINE_VERSION=8.151.12-r0 > KUBERNETES_SERVICE_HOST=10.100.0.1 > **My_password=39528$vdg7Jb** > _=/usr/bin/env > Author: Stavros Kontopoulos Closes #21317 from skonto/k8s-fix-env-secrets. --- docs/running-on-kubernetes.md | 22 +++++++ .../org/apache/spark/deploy/k8s/Config.scala | 2 + .../spark/deploy/k8s/KubernetesConf.scala | 11 +++- .../k8s/features/EnvSecretsFeatureStep.scala | 57 ++++++++++++++++++ .../k8s/submit/KubernetesDriverBuilder.scala | 11 +++- .../k8s/KubernetesExecutorBuilder.scala | 12 +++- .../deploy/k8s/KubernetesConfSuite.scala | 12 +++- .../BasicDriverFeatureStepSuite.scala | 2 + .../BasicExecutorFeatureStepSuite.scala | 3 + ...ubernetesCredentialsFeatureStepSuite.scala | 3 + .../DriverServiceFeatureStepSuite.scala | 6 ++ .../features/EnvSecretsFeatureStepSuite.scala | 59 +++++++++++++++++++ .../KubernetesFeaturesTestUtils.scala | 7 ++- .../features/LocalDirsFeatureStepSuite.scala | 1 + .../MountSecretsFeatureStepSuite.scala | 1 + .../spark/deploy/k8s/submit/ClientSuite.scala | 1 + .../submit/KubernetesDriverBuilderSuite.scala | 13 +++- .../k8s/KubernetesExecutorBuilderSuite.scala | 11 +++- 18 files changed, 222 insertions(+), 12 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index e9e1f3e280609..a4b2b98b0b649 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -140,6 +140,12 @@ namespace as that of the driver and executor pods. For example, to mount a secre --conf spark.kubernetes.executor.secrets.spark-secret=/etc/secrets ``` +To use a secret through an environment variable use the following options to the `spark-submit` command: +``` +--conf spark.kubernetes.driver.secretKeyRef.ENV_NAME=name:key +--conf spark.kubernetes.executor.secretKeyRef.ENV_NAME=name:key +``` + ## Introspection and Debugging These are the different ways in which you can investigate a running/completed Spark application, monitor progress, and @@ -602,4 +608,20 @@ specific to Spark on Kubernetes. spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. + + spark.kubernetes.driver.secretKeyRef.[EnvName] + (none) + + Add as an environment variable to the driver container with name EnvName (case sensitive), the value referenced by key key in the data of the referenced Kubernetes Secret. For example, + spark.kubernetes.driver.secretKeyRef.ENV_VAR=spark-secret:key. + + + + spark.kubernetes.executor.secretKeyRef.[EnvName] + (none) + + Add as an environment variable to the executor container with name EnvName (case sensitive), the value referenced by key key in the data of the referenced Kubernetes Secret. For example, + spark.kubernetes.executor.secrets.ENV_VAR=spark-secret:key. + + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 4086970ffb256..560dedf431b08 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -162,10 +162,12 @@ private[spark] object Config extends Logging { val KUBERNETES_DRIVER_LABEL_PREFIX = "spark.kubernetes.driver.label." val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation." val KUBERNETES_DRIVER_SECRETS_PREFIX = "spark.kubernetes.driver.secrets." + val KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX = "spark.kubernetes.driver.secretKeyRef." val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label." val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets." + val KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX = "spark.kubernetes.executor.secretKeyRef." val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index 77b634ddfabcc..5a944187a7096 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -54,6 +54,7 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( roleLabels: Map[String, String], roleAnnotations: Map[String, String], roleSecretNamesToMountPaths: Map[String, String], + roleSecretEnvNamesToKeyRefs: Map[String, String], roleEnvs: Map[String, String]) { def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE) @@ -129,6 +130,8 @@ private[spark] object KubernetesConf { sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) val driverSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_DRIVER_SECRETS_PREFIX) + val driverSecretEnvNamesToKeyRefs = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX) val driverEnvs = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_DRIVER_ENV_PREFIX) @@ -140,6 +143,7 @@ private[spark] object KubernetesConf { driverLabels, driverAnnotations, driverSecretNamesToMountPaths, + driverSecretEnvNamesToKeyRefs, driverEnvs) } @@ -167,8 +171,10 @@ private[spark] object KubernetesConf { executorCustomLabels val executorAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) - val executorSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( + val executorMountSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) + val executorEnvSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX) val executorEnv = sparkConf.getExecutorEnv.toMap KubernetesConf( @@ -178,7 +184,8 @@ private[spark] object KubernetesConf { appId, executorLabels, executorAnnotations, - executorSecrets, + executorMountSecrets, + executorEnvSecrets, executorEnv) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala new file mode 100644 index 0000000000000..03ff7d48420ff --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, HasMetadata} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesRoleSpecificConf, SparkPod} + +private[spark] class EnvSecretsFeatureStep( + kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val addedEnvSecrets = kubernetesConf + .roleSecretEnvNamesToKeyRefs + .map{ case (envName, keyRef) => + // Keyref parts + val keyRefParts = keyRef.split(":") + require(keyRefParts.size == 2, "SecretKeyRef must be in the form name:key.") + val name = keyRefParts(0) + val key = keyRefParts(1) + new EnvVarBuilder() + .withName(envName) + .withNewValueFrom() + .withNewSecretKeyRef() + .withKey(key) + .withName(name) + .endSecretKeyRef() + .endValueFrom() + .build() + } + + val containerWithEnvVars = new ContainerBuilder(pod.container) + .addAllToEnv(addedEnvSecrets.toSeq.asJava) + .build() + SparkPod(pod.pod, containerWithEnvVars) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index 10b0154466a3a..fdc5eb0d75832 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} -import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features._ private[spark] class KubernetesDriverBuilder( provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep = @@ -30,6 +30,9 @@ private[spark] class KubernetesDriverBuilder( provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => MountSecretsFeatureStep) = new MountSecretsFeatureStep(_), + provideEnvSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => EnvSecretsFeatureStep) = + new EnvSecretsFeatureStep(_), provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => LocalDirsFeatureStep = new LocalDirsFeatureStep(_)) { @@ -41,10 +44,14 @@ private[spark] class KubernetesDriverBuilder( provideCredentialsStep(kubernetesConf), provideServiceStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) - val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + var allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) } else baseFeatures + allFeatures = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + allFeatures ++ Seq(provideEnvSecretsStep(kubernetesConf)) + } else allFeatures + var spec = KubernetesDriverSpec.initialSpec(kubernetesConf.sparkConf.getAll.toMap) for (feature <- allFeatures) { val configuredPod = feature.configurePod(spec.pod) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index d8f63d57574fb..d5e1de36a58df 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster.k8s import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} private[spark] class KubernetesExecutorBuilder( provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep = @@ -25,6 +25,9 @@ private[spark] class KubernetesExecutorBuilder( provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => MountSecretsFeatureStep = new MountSecretsFeatureStep(_), + provideEnvSecretsStep: + (KubernetesConf[_ <: KubernetesRoleSpecificConf] => EnvSecretsFeatureStep) = + new EnvSecretsFeatureStep(_), provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => LocalDirsFeatureStep = new LocalDirsFeatureStep(_)) { @@ -32,9 +35,14 @@ private[spark] class KubernetesExecutorBuilder( def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { val baseFeatures = Seq(provideBasicStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) - val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + var allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) } else baseFeatures + + allFeatures = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + allFeatures ++ Seq(provideEnvSecretsStep(kubernetesConf)) + } else allFeatures + var executorPod = SparkPod.initialPod() for (feature <- allFeatures) { executorPod = feature.configurePod(executorPod) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala index f10202f7a3546..3d23e1cb90fd2 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala @@ -40,6 +40,9 @@ class KubernetesConfSuite extends SparkFunSuite { private val SECRET_NAMES_TO_MOUNT_PATHS = Map( "secret1" -> "/mnt/secrets/secret1", "secret2" -> "/mnt/secrets/secret2") + private val SECRET_ENV_VARS = Map( + "envName1" -> "name1:key1", + "envName2" -> "name2:key2") private val CUSTOM_ENVS = Map( "customEnvKey1" -> "customEnvValue1", "customEnvKey2" -> "customEnvValue2") @@ -103,6 +106,9 @@ class KubernetesConfSuite extends SparkFunSuite { SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$key", value) } + SECRET_ENV_VARS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX$key", value) + } CUSTOM_ENVS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_DRIVER_ENV_PREFIX$key", value) } @@ -121,6 +127,7 @@ class KubernetesConfSuite extends SparkFunSuite { CUSTOM_LABELS) assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + assert(conf.roleSecretEnvNamesToKeyRefs === SECRET_ENV_VARS) assert(conf.roleEnvs === CUSTOM_ENVS) } @@ -155,6 +162,9 @@ class KubernetesConfSuite extends SparkFunSuite { CUSTOM_ANNOTATIONS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_EXECUTOR_ANNOTATION_PREFIX$key", value) } + SECRET_ENV_VARS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX$key", value) + } SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_EXECUTOR_SECRETS_PREFIX$key", value) } @@ -170,6 +180,6 @@ class KubernetesConfSuite extends SparkFunSuite { SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ CUSTOM_LABELS) assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + assert(conf.roleSecretEnvNamesToKeyRefs === SECRET_ENV_VARS) } - } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala index eee85b8baa730..b2813d8b3265d 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -69,6 +69,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { DRIVER_LABELS, DRIVER_ANNOTATIONS, Map.empty, + Map.empty, DRIVER_ENVS) val featureStep = new BasicDriverFeatureStep(kubernetesConf) @@ -138,6 +139,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { DRIVER_LABELS, DRIVER_ANNOTATIONS, Map.empty, + Map.empty, Map.empty) val step = new BasicDriverFeatureStep(kubernetesConf) val additionalProperties = step.getAdditionalPodSystemProperties() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index a764f7630b5c8..9182134b3337c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -87,6 +87,7 @@ class BasicExecutorFeatureStepSuite LABELS, ANNOTATIONS, Map.empty, + Map.empty, Map.empty)) val executor = step.configurePod(SparkPod.initialPod()) @@ -124,6 +125,7 @@ class BasicExecutorFeatureStepSuite LABELS, ANNOTATIONS, Map.empty, + Map.empty, Map.empty)) assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63) } @@ -142,6 +144,7 @@ class BasicExecutorFeatureStepSuite LABELS, ANNOTATIONS, Map.empty, + Map.empty, Map("qux" -> "quux"))) val executor = step.configurePod(SparkPod.initialPod()) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala index 9f817d3bfc79a..f81894f8055f1 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala @@ -59,6 +59,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) @@ -88,6 +89,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) @@ -124,6 +126,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala index c299d56865ec0..f265522a8823a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala @@ -65,6 +65,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty)) assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod()) assert(configurationStep.getAdditionalKubernetesResources().size === 1) @@ -94,6 +95,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty)) val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX @@ -113,6 +115,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty)) val resolvedService = configurationStep .getAdditionalKubernetesResources() @@ -141,6 +144,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty), clock) val driverService = configurationStep @@ -166,6 +170,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty), clock) fail("The driver bind address should not be allowed.") @@ -189,6 +194,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty), clock) fail("The driver host address should not be allowed.") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala new file mode 100644 index 0000000000000..8b0b2d0739c76 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.PodBuilder + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s._ + +class EnvSecretsFeatureStepSuite extends SparkFunSuite{ + private val KEY_REF_NAME_FOO = "foo" + private val KEY_REF_NAME_BAR = "bar" + private val KEY_REF_KEY_FOO = "key_foo" + private val KEY_REF_KEY_BAR = "key_bar" + private val ENV_NAME_FOO = "MY_FOO" + private val ENV_NAME_BAR = "MY_bar" + + test("sets up all keyRefs") { + val baseDriverPod = SparkPod.initialPod() + val envVarsToKeys = Map( + ENV_NAME_BAR -> s"${KEY_REF_NAME_BAR}:${KEY_REF_KEY_BAR}", + ENV_NAME_FOO -> s"${KEY_REF_NAME_FOO}:${KEY_REF_KEY_FOO}") + val sparkConf = new SparkConf(false) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesExecutorSpecificConf("1", new PodBuilder().build()), + "resource-name-prefix", + "app-id", + Map.empty, + Map.empty, + Map.empty, + envVarsToKeys, + Map.empty) + + val step = new EnvSecretsFeatureStep(kubernetesConf) + val driverContainerWithEnvSecrets = step.configurePod(baseDriverPod).container + + val expectedVars = + Seq(s"${ENV_NAME_BAR}", s"${ENV_NAME_FOO}") + + expectedVars.foreach { envName => + assert(KubernetesFeaturesTestUtils.containerHasEnvVar(driverContainerWithEnvSecrets, envName)) + } + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala index 27bff74ce38af..f90380e30e52a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala @@ -16,7 +16,9 @@ */ package org.apache.spark.deploy.k8s.features -import io.fabric8.kubernetes.api.model.{HasMetadata, PodBuilder, SecretBuilder} +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{Container, HasMetadata, PodBuilder, SecretBuilder} import org.mockito.Matchers import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock @@ -58,4 +60,7 @@ object KubernetesFeaturesTestUtils { .build()) } + def containerHasEnvVar(container: Container, envVarName: String): Boolean = { + container.getEnv.asScala.exists(envVar => envVar.getName == envVarName) + } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala index 91e184b84b86e..2542a02d37766 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala @@ -43,6 +43,7 @@ class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala index 9d02f56cc206d..9155793774123 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala @@ -41,6 +41,7 @@ class MountSecretsFeatureStepSuite extends SparkFunSuite { Map.empty, Map.empty, secretNamesToMountPaths, + Map.empty, Map.empty) val step = new MountSecretsFeatureStep(kubernetesConf) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index c1b203e03a357..0775338098a13 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -142,6 +142,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) when(driverBuilder.buildFromFeatures(kubernetesConf)).thenReturn(BUILT_KUBERNETES_SPEC) when(kubernetesClient.pods()).thenReturn(podOperations) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index a511d254d2175..cb724068ea4f3 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf} -import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} class KubernetesDriverBuilderSuite extends SparkFunSuite { @@ -27,6 +27,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val SERVICE_STEP_TYPE = "service" private val LOCAL_DIRS_STEP_TYPE = "local-dirs" private val SECRETS_STEP_TYPE = "mount-secrets" + private val ENV_SECRETS_STEP_TYPE = "env-secrets" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicDriverFeatureStep]) @@ -43,12 +44,16 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) + private val builderUnderTest: KubernetesDriverBuilder = new KubernetesDriverBuilder( _ => basicFeatureStep, _ => credentialsStep, _ => serviceStep, _ => secretsStep, + _ => envSecretsStep, _ => localDirsStep) test("Apply fundamental steps all the time.") { @@ -64,6 +69,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -86,6 +92,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map("secret" -> "secretMountPath"), + Map("EnvName" -> "SecretName:secretKey"), Map.empty) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -93,7 +100,9 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { CREDENTIALS_STEP_TYPE, SERVICE_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, - SECRETS_STEP_TYPE) + SECRETS_STEP_TYPE, + ENV_SECRETS_STEP_TYPE + ) } private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index 9ee86b5a423a9..753cd30a237f3 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -20,23 +20,27 @@ import io.fabric8.kubernetes.api.model.PodBuilder import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} class KubernetesExecutorBuilderSuite extends SparkFunSuite { private val BASIC_STEP_TYPE = "basic" private val SECRETS_STEP_TYPE = "mount-secrets" + private val ENV_SECRETS_STEP_TYPE = "env-secrets" private val LOCAL_DIRS_STEP_TYPE = "local-dirs" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep]) private val mountSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep]) private val builderUnderTest = new KubernetesExecutorBuilder( _ => basicFeatureStep, _ => mountSecretsStep, + _ => envSecretsStep, _ => localDirsStep) test("Basic steps are consistently applied.") { @@ -49,6 +53,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE) @@ -64,12 +69,14 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map("secret" -> "secretMountPath"), + Map("secret-name" -> "secret-key"), Map.empty) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, - SECRETS_STEP_TYPE) + SECRETS_STEP_TYPE, + ENV_SECRETS_STEP_TYPE) } private def validateStepTypesApplied(resolvedPod: SparkPod, stepTypes: String*): Unit = { From 2c9c8629b7aa87c057ec9b84b6082f8b9b9319b1 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 1 Jun 2018 08:44:57 +0800 Subject: [PATCH 0903/2461] [MINOR][YARN] Add YARN-specific credential providers in debug logging message This PR adds a debugging log for YARN-specific credential providers which is loaded by service loader mechanism. It took me a while to debug if it's actually loaded or not. I had to explicitly set the deprecated configuration and check if that's actually being loaded. The change scope is manually tested. Logs are like: ``` Using the following builtin delegation token providers: hadoopfs, hive, hbase. Using the following YARN-specific credential providers: yarn-test. ``` Author: hyukjinkwon Closes #21466 from HyukjinKwon/minor-log. Change-Id: I18e2fb8eeb3289b148f24c47bb3130a560a881cf --- .../spark/deploy/security/HadoopDelegationTokenManager.scala | 4 ++-- .../yarn/security/YARNHadoopDelegationTokenManager.scala | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index 5151df00476f9..ab8d8d96a9b08 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging * * Also, each HadoopDelegationTokenProvider is controlled by * spark.security.credentials.{service}.enabled, and will not be loaded if this config is set to - * false. For example, Hive's delegation token provider [[HiveDelegationTokenProvider]] can be + * false. For example, Hive's delegation token provider [[HiveDelegationTokenProvider]] can be * enabled/disabled by the configuration spark.security.credentials.hive.enabled. * * @param sparkConf Spark configuration @@ -52,7 +52,7 @@ private[spark] class HadoopDelegationTokenManager( // Maintain all the registered delegation token providers private val delegationTokenProviders = getDelegationTokenProviders - logDebug(s"Using the following delegation token providers: " + + logDebug("Using the following builtin delegation token providers: " + s"${delegationTokenProviders.keys.mkString(", ")}.") /** Construct a [[HadoopDelegationTokenManager]] for the default Hadoop filesystem */ diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala index d4eeb6bbcf886..26a2e5d730218 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala @@ -44,6 +44,10 @@ private[yarn] class YARNHadoopDelegationTokenManager( // public for testing val credentialProviders = getCredentialProviders + if (credentialProviders.nonEmpty) { + logDebug("Using the following YARN-specific credential providers: " + + s"${credentialProviders.keys.mkString(", ")}.") + } /** * Writes delegation tokens to creds. Delegation tokens are fetched from all registered From cbaa729132e9aee0e5d8aed332848e21dd8670a8 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 1 Jun 2018 10:01:15 +0800 Subject: [PATCH 0904/2461] [SPARK-24330][SQL] Refactor ExecuteWriteTask and Use `while` in writing files ## What changes were proposed in this pull request? 1. Refactor ExecuteWriteTask in FileFormatWriter to reduce common logic and improve readability. After the change, callers only need to call `commit()` or `abort` at the end of task. Also there is less code in `SingleDirectoryWriteTask` and `DynamicPartitionWriteTask`. Definitions of related classes are moved to a new file, and `ExecuteWriteTask` is renamed to `FileFormatDataWriter`. 2. As per code style guide: https://github.com/databricks/scala-style-guide#traversal-and-zipwithindex , we avoid using `for` for looping in [FileFormatWriter](https://github.com/apache/spark/pull/21381/files#diff-3b69eb0963b68c65cfe8075f8a42e850L536) , or `foreach` in [WriteToDataSourceV2Exec](https://github.com/apache/spark/pull/21381/files#diff-6fbe10db766049a395bae2e785e9d56eL119). In such critical code path, using `while` is good for performance. ## How was this patch tested? Existing unit test. I tried the microbenchmark in https://github.com/apache/spark/pull/21409 | Workload | Before changes(Best/Avg Time(ms)) | After changes(Best/Avg Time(ms)) | | --- | --- | -- | |Output Single Int Column| 2018 / 2043 | 2096 / 2236 | |Output Single Double Column| 1978 / 2043 | 2013 / 2018 | |Output Int and String Column| 6332 / 6706 | 6162 / 6298 | |Output Partitions| 4458 / 5094 | 3792 / 4008 | |Output Buckets| 5695 / 6102 | 5120 / 5154 | Also a microbenchmark on my laptop for general comparison among while/foreach/for : ``` class Writer { var sum = 0L def write(l: Long): Unit = sum += l } def testWhile(iterator: Iterator[Long]): Long = { val w = new Writer while (iterator.hasNext) { w.write(iterator.next()) } w.sum } def testForeach(iterator: Iterator[Long]): Long = { val w = new Writer iterator.foreach(w.write) w.sum } def testFor(iterator: Iterator[Long]): Long = { val w = new Writer for (x <- iterator) { w.write(x) } w.sum } val data = 0L to 100000000L val start = System.nanoTime (0 to 10).foreach(_ => testWhile(data.iterator)) println("benchmark while: " + (System.nanoTime - start)/1000000) val start2 = System.nanoTime (0 to 10).foreach(_ => testForeach(data.iterator)) println("benchmark foreach: " + (System.nanoTime - start2)/1000000) val start3 = System.nanoTime (0 to 10).foreach(_ => testForeach(data.iterator)) println("benchmark for: " + (System.nanoTime - start3)/1000000) ``` Benchmark result: `while`: 15401 ms `foreach`: 43034 ms `for`: 41279 ms Author: Gengliang Wang Closes #21381 from gengliangwang/refactorExecuteWriteTask. --- .../datasources/BasicWriteStatsTracker.scala | 2 +- .../datasources/FileFormatDataWriter.scala | 313 ++++++++++++++++ .../datasources/FileFormatWriter.scala | 353 +----------------- .../datasources/v2/WriteToDataSourceV2.scala | 4 +- 4 files changed, 334 insertions(+), 338 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index 69c03d862391e..ba7d2b7cbdb1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.SerializableConfiguration /** - * Simple metrics collected during an instance of [[FileFormatWriter.ExecuteWriteTask]]. + * Simple metrics collected during an instance of [[FileFormatDataWriter]]. * These were first introduced in https://github.com/apache/spark/pull/18159 (SPARK-20703). */ case class BasicWriteTaskStats( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala new file mode 100644 index 0000000000000..6499328e89ce7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources + +import scala.collection.mutable + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.TaskAttemptContext + +import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.StringType +import org.apache.spark.util.SerializableConfiguration + +/** + * Abstract class for writing out data in a single Spark task. + * Exceptions thrown by the implementation of this trait will automatically trigger task aborts. + */ +abstract class FileFormatDataWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol) { + /** + * Max number of files a single task writes out due to file size. In most cases the number of + * files written should be very small. This is just a safe guard to protect some really bad + * settings, e.g. maxRecordsPerFile = 1. + */ + protected val MAX_FILE_COUNTER: Int = 1000 * 1000 + protected val updatedPartitions: mutable.Set[String] = mutable.Set[String]() + protected var currentWriter: OutputWriter = _ + + /** Trackers for computing various statistics on the data as it's being written out. */ + protected val statsTrackers: Seq[WriteTaskStatsTracker] = + description.statsTrackers.map(_.newTaskInstance()) + + protected def releaseResources(): Unit = { + if (currentWriter != null) { + try { + currentWriter.close() + } finally { + currentWriter = null + } + } + } + + /** Writes a record */ + def write(record: InternalRow): Unit + + /** + * Returns the summary of relative information which + * includes the list of partition strings written out. The list of partitions is sent back + * to the driver and used to update the catalog. Other information will be sent back to the + * driver too and used to e.g. update the metrics in UI. + */ + def commit(): WriteTaskResult = { + releaseResources() + val summary = ExecutedWriteSummary( + updatedPartitions = updatedPartitions.toSet, + stats = statsTrackers.map(_.getFinalStats())) + WriteTaskResult(committer.commitTask(taskAttemptContext), summary) + } + + def abort(): Unit = { + try { + releaseResources() + } finally { + committer.abortTask(taskAttemptContext) + } + } +} + +/** FileFormatWriteTask for empty partitions */ +class EmptyDirectoryDataWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol +) extends FileFormatDataWriter(description, taskAttemptContext, committer) { + override def write(record: InternalRow): Unit = {} +} + +/** Writes data to a single directory (used for non-dynamic-partition writes). */ +class SingleDirectoryDataWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol) + extends FileFormatDataWriter(description, taskAttemptContext, committer) { + private var fileCounter: Int = _ + private var recordsInFile: Long = _ + // Initialize currentWriter and statsTrackers + newOutputWriter() + + private def newOutputWriter(): Unit = { + recordsInFile = 0 + releaseResources() + + val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext) + val currentPath = committer.newTaskTempFile( + taskAttemptContext, + None, + f"-c$fileCounter%03d" + ext) + + currentWriter = description.outputWriterFactory.newInstance( + path = currentPath, + dataSchema = description.dataColumns.toStructType, + context = taskAttemptContext) + + statsTrackers.foreach(_.newFile(currentPath)) + } + + override def write(record: InternalRow): Unit = { + if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) { + fileCounter += 1 + assert(fileCounter < MAX_FILE_COUNTER, + s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") + + newOutputWriter() + } + + currentWriter.write(record) + statsTrackers.foreach(_.newRow(record)) + recordsInFile += 1 + } +} + +/** + * Writes data to using dynamic partition writes, meaning this single function can write to + * multiple directories (partitions) or files (bucketing). + */ +class DynamicPartitionDataWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol) + extends FileFormatDataWriter(description, taskAttemptContext, committer) { + + /** Flag saying whether or not the data to be written out is partitioned. */ + private val isPartitioned = description.partitionColumns.nonEmpty + + /** Flag saying whether or not the data to be written out is bucketed. */ + private val isBucketed = description.bucketIdExpression.isDefined + + assert(isPartitioned || isBucketed, + s"""DynamicPartitionWriteTask should be used for writing out data that's either + |partitioned or bucketed. In this case neither is true. + |WriteJobDescription: $description + """.stripMargin) + + private var fileCounter: Int = _ + private var recordsInFile: Long = _ + private var currentPartionValues: Option[UnsafeRow] = None + private var currentBucketId: Option[Int] = None + + /** Extracts the partition values out of an input row. */ + private lazy val getPartitionValues: InternalRow => UnsafeRow = { + val proj = UnsafeProjection.create(description.partitionColumns, description.allColumns) + row => proj(row) + } + + /** Expression that given partition columns builds a path string like: col1=val/col2=val/... */ + private lazy val partitionPathExpression: Expression = Concat( + description.partitionColumns.zipWithIndex.flatMap { case (c, i) => + val partitionName = ScalaUDF( + ExternalCatalogUtils.getPartitionPathString _, + StringType, + Seq(Literal(c.name), Cast(c, StringType, Option(description.timeZoneId)))) + if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName) + }) + + /** Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns + * the partition string. */ + private lazy val getPartitionPath: InternalRow => String = { + val proj = UnsafeProjection.create(Seq(partitionPathExpression), description.partitionColumns) + row => proj(row).getString(0) + } + + /** Given an input row, returns the corresponding `bucketId` */ + private lazy val getBucketId: InternalRow => Int = { + val proj = + UnsafeProjection.create(description.bucketIdExpression.toSeq, description.allColumns) + row => proj(row).getInt(0) + } + + /** Returns the data columns to be written given an input row */ + private val getOutputRow = + UnsafeProjection.create(description.dataColumns, description.allColumns) + + /** + * Opens a new OutputWriter given a partition key and/or a bucket id. + * If bucket id is specified, we will append it to the end of the file name, but before the + * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet + * + * @param partitionValues the partition which all tuples being written by this `OutputWriter` + * belong to + * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to + */ + private def newOutputWriter(partitionValues: Option[InternalRow], bucketId: Option[Int]): Unit = { + recordsInFile = 0 + releaseResources() + + val partDir = partitionValues.map(getPartitionPath(_)) + partDir.foreach(updatedPartitions.add) + + val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") + + // This must be in a form that matches our bucketing format. See BucketingUtils. + val ext = f"$bucketIdStr.c$fileCounter%03d" + + description.outputWriterFactory.getFileExtension(taskAttemptContext) + + val customPath = partDir.flatMap { dir => + description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) + } + val currentPath = if (customPath.isDefined) { + committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) + } else { + committer.newTaskTempFile(taskAttemptContext, partDir, ext) + } + + currentWriter = description.outputWriterFactory.newInstance( + path = currentPath, + dataSchema = description.dataColumns.toStructType, + context = taskAttemptContext) + + statsTrackers.foreach(_.newFile(currentPath)) + } + + override def write(record: InternalRow): Unit = { + val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(record)) else None + val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None + + if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) { + // See a new partition or bucket - write to a new partition dir (or a new bucket file). + if (isPartitioned && currentPartionValues != nextPartitionValues) { + currentPartionValues = Some(nextPartitionValues.get.copy()) + statsTrackers.foreach(_.newPartition(currentPartionValues.get)) + } + if (isBucketed) { + currentBucketId = nextBucketId + statsTrackers.foreach(_.newBucket(currentBucketId.get)) + } + + fileCounter = 0 + newOutputWriter(currentPartionValues, currentBucketId) + } else if (description.maxRecordsPerFile > 0 && + recordsInFile >= description.maxRecordsPerFile) { + // Exceeded the threshold in terms of the number of records per file. + // Create a new file by increasing the file counter. + fileCounter += 1 + assert(fileCounter < MAX_FILE_COUNTER, + s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") + + newOutputWriter(currentPartionValues, currentBucketId) + } + val outputRow = getOutputRow(record) + currentWriter.write(outputRow) + statsTrackers.foreach(_.newRow(outputRow)) + recordsInFile += 1 + } +} + +/** A shared job description for all the write tasks. */ +class WriteJobDescription( + val uuid: String, // prevent collision between different (appending) write jobs + val serializableHadoopConf: SerializableConfiguration, + val outputWriterFactory: OutputWriterFactory, + val allColumns: Seq[Attribute], + val dataColumns: Seq[Attribute], + val partitionColumns: Seq[Attribute], + val bucketIdExpression: Option[Expression], + val path: String, + val customPartitionLocations: Map[TablePartitionSpec, String], + val maxRecordsPerFile: Long, + val timeZoneId: String, + val statsTrackers: Seq[WriteJobStatsTracker]) + extends Serializable { + + assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns), + s""" + |All columns: ${allColumns.mkString(", ")} + |Partition columns: ${partitionColumns.mkString(", ")} + |Data columns: ${dataColumns.mkString(", ")} + """.stripMargin) +} + +/** The result of a successful write task. */ +case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteSummary) + +/** + * Wrapper class for the metrics of writing data out. + * + * @param updatedPartitions the partitions updated during writing data out. Only valid + * for dynamic partition. + * @param stats one `WriteTaskStats` object for every `WriteJobStatsTracker` that the job had. + */ +case class ExecutedWriteSummary( + updatedPartitions: Set[String], + stats: Seq[WriteTaskStats]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 401597f967218..52da8356ab835 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources import java.util.{Date, UUID} -import scala.collection.mutable - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ @@ -30,62 +28,25 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils} -import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils} +import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, _} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution} -import org.apache.spark.sql.types.StringType import org.apache.spark.util.{SerializableConfiguration, Utils} /** A helper object for writing FileFormat data out to a location. */ object FileFormatWriter extends Logging { - - /** - * Max number of files a single task writes out due to file size. In most cases the number of - * files written should be very small. This is just a safe guard to protect some really bad - * settings, e.g. maxRecordsPerFile = 1. - */ - private val MAX_FILE_COUNTER = 1000 * 1000 - /** Describes how output files should be placed in the filesystem. */ case class OutputSpec( - outputPath: String, - customPartitionLocations: Map[TablePartitionSpec, String], - outputColumns: Seq[Attribute]) - - /** A shared job description for all the write tasks. */ - private class WriteJobDescription( - val uuid: String, // prevent collision between different (appending) write jobs - val serializableHadoopConf: SerializableConfiguration, - val outputWriterFactory: OutputWriterFactory, - val allColumns: Seq[Attribute], - val dataColumns: Seq[Attribute], - val partitionColumns: Seq[Attribute], - val bucketIdExpression: Option[Expression], - val path: String, - val customPartitionLocations: Map[TablePartitionSpec, String], - val maxRecordsPerFile: Long, - val timeZoneId: String, - val statsTrackers: Seq[WriteJobStatsTracker]) - extends Serializable { - - assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns), - s""" - |All columns: ${allColumns.mkString(", ")} - |Partition columns: ${partitionColumns.mkString(", ")} - |Data columns: ${dataColumns.mkString(", ")} - """.stripMargin) - } - - /** The result of a successful write task. */ - private case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteSummary) + outputPath: String, + customPartitionLocations: Map[TablePartitionSpec, String], + outputColumns: Seq[Attribute]) /** * Basic work flow of this command is: @@ -262,30 +223,27 @@ object FileFormatWriter extends Logging { committer.setupTask(taskAttemptContext) - val writeTask = + val dataWriter = if (sparkPartitionId != 0 && !iterator.hasNext) { // In case of empty job, leave first partition to save meta for file format like parquet. - new EmptyDirectoryWriteTask(description) + new EmptyDirectoryDataWriter(description, taskAttemptContext, committer) } else if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) { - new SingleDirectoryWriteTask(description, taskAttemptContext, committer) + new SingleDirectoryDataWriter(description, taskAttemptContext, committer) } else { - new DynamicPartitionWriteTask(description, taskAttemptContext, committer) + new DynamicPartitionDataWriter(description, taskAttemptContext, committer) } try { Utils.tryWithSafeFinallyAndFailureCallbacks(block = { // Execute the task to write rows out and commit the task. - val summary = writeTask.execute(iterator) - writeTask.releaseResources() - WriteTaskResult(committer.commitTask(taskAttemptContext), summary) - })(catchBlock = { - // If there is an error, release resource and then abort the task - try { - writeTask.releaseResources() - } finally { - committer.abortTask(taskAttemptContext) - logError(s"Job $jobId aborted.") + while (iterator.hasNext) { + dataWriter.write(iterator.next()) } + dataWriter.commit() + })(catchBlock = { + // If there is an error, abort the task + dataWriter.abort() + logError(s"Job $jobId aborted.") }) } catch { case e: FetchFailedException => @@ -302,7 +260,7 @@ object FileFormatWriter extends Logging { private def processStats( statsTrackers: Seq[WriteJobStatsTracker], statsPerTask: Seq[Seq[WriteTaskStats]]) - : Unit = { + : Unit = { val numStatsTrackers = statsTrackers.length assert(statsPerTask.forall(_.length == numStatsTrackers), @@ -321,281 +279,4 @@ object FileFormatWriter extends Logging { case (statsTracker, stats) => statsTracker.processStats(stats) } } - - /** - * A simple trait for writing out data in a single Spark task, without any concerns about how - * to commit or abort tasks. Exceptions thrown by the implementation of this trait will - * automatically trigger task aborts. - */ - private trait ExecuteWriteTask { - - /** - * Writes data out to files, and then returns the summary of relative information which - * includes the list of partition strings written out. The list of partitions is sent back - * to the driver and used to update the catalog. Other information will be sent back to the - * driver too and used to e.g. update the metrics in UI. - */ - def execute(iterator: Iterator[InternalRow]): ExecutedWriteSummary - def releaseResources(): Unit - } - - /** ExecuteWriteTask for empty partitions */ - private class EmptyDirectoryWriteTask(description: WriteJobDescription) - extends ExecuteWriteTask { - - val statsTrackers: Seq[WriteTaskStatsTracker] = - description.statsTrackers.map(_.newTaskInstance()) - - override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { - ExecutedWriteSummary( - updatedPartitions = Set.empty, - stats = statsTrackers.map(_.getFinalStats())) - } - - override def releaseResources(): Unit = {} - } - - /** Writes data to a single directory (used for non-dynamic-partition writes). */ - private class SingleDirectoryWriteTask( - description: WriteJobDescription, - taskAttemptContext: TaskAttemptContext, - committer: FileCommitProtocol) extends ExecuteWriteTask { - - private[this] var currentWriter: OutputWriter = _ - - val statsTrackers: Seq[WriteTaskStatsTracker] = - description.statsTrackers.map(_.newTaskInstance()) - - private def newOutputWriter(fileCounter: Int): Unit = { - val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext) - val currentPath = committer.newTaskTempFile( - taskAttemptContext, - None, - f"-c$fileCounter%03d" + ext) - - currentWriter = description.outputWriterFactory.newInstance( - path = currentPath, - dataSchema = description.dataColumns.toStructType, - context = taskAttemptContext) - - statsTrackers.map(_.newFile(currentPath)) - } - - override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { - var fileCounter = 0 - var recordsInFile: Long = 0L - newOutputWriter(fileCounter) - - while (iter.hasNext) { - if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) { - fileCounter += 1 - assert(fileCounter < MAX_FILE_COUNTER, - s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") - - recordsInFile = 0 - releaseResources() - newOutputWriter(fileCounter) - } - - val internalRow = iter.next() - currentWriter.write(internalRow) - statsTrackers.foreach(_.newRow(internalRow)) - recordsInFile += 1 - } - releaseResources() - ExecutedWriteSummary( - updatedPartitions = Set.empty, - stats = statsTrackers.map(_.getFinalStats())) - } - - override def releaseResources(): Unit = { - if (currentWriter != null) { - try { - currentWriter.close() - } finally { - currentWriter = null - } - } - } - } - - /** - * Writes data to using dynamic partition writes, meaning this single function can write to - * multiple directories (partitions) or files (bucketing). - */ - private class DynamicPartitionWriteTask( - desc: WriteJobDescription, - taskAttemptContext: TaskAttemptContext, - committer: FileCommitProtocol) extends ExecuteWriteTask { - - /** Flag saying whether or not the data to be written out is partitioned. */ - val isPartitioned = desc.partitionColumns.nonEmpty - - /** Flag saying whether or not the data to be written out is bucketed. */ - val isBucketed = desc.bucketIdExpression.isDefined - - assert(isPartitioned || isBucketed, - s"""DynamicPartitionWriteTask should be used for writing out data that's either - |partitioned or bucketed. In this case neither is true. - |WriteJobDescription: ${desc} - """.stripMargin) - - // currentWriter is initialized whenever we see a new key (partitionValues + BucketId) - private var currentWriter: OutputWriter = _ - - /** Trackers for computing various statistics on the data as it's being written out. */ - private val statsTrackers: Seq[WriteTaskStatsTracker] = - desc.statsTrackers.map(_.newTaskInstance()) - - /** Extracts the partition values out of an input row. */ - private lazy val getPartitionValues: InternalRow => UnsafeRow = { - val proj = UnsafeProjection.create(desc.partitionColumns, desc.allColumns) - row => proj(row) - } - - /** Expression that given partition columns builds a path string like: col1=val/col2=val/... */ - private lazy val partitionPathExpression: Expression = Concat( - desc.partitionColumns.zipWithIndex.flatMap { case (c, i) => - val partitionName = ScalaUDF( - ExternalCatalogUtils.getPartitionPathString _, - StringType, - Seq(Literal(c.name), Cast(c, StringType, Option(desc.timeZoneId)))) - if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName) - }) - - /** Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns - * the partition string. */ - private lazy val getPartitionPath: InternalRow => String = { - val proj = UnsafeProjection.create(Seq(partitionPathExpression), desc.partitionColumns) - row => proj(row).getString(0) - } - - /** Given an input row, returns the corresponding `bucketId` */ - private lazy val getBucketId: InternalRow => Int = { - val proj = UnsafeProjection.create(desc.bucketIdExpression.toSeq, desc.allColumns) - row => proj(row).getInt(0) - } - - /** Returns the data columns to be written given an input row */ - private val getOutputRow = UnsafeProjection.create(desc.dataColumns, desc.allColumns) - - /** - * Opens a new OutputWriter given a partition key and/or a bucket id. - * If bucket id is specified, we will append it to the end of the file name, but before the - * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet - * - * @param partitionValues the partition which all tuples being written by this `OutputWriter` - * belong to - * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to - * @param fileCounter the number of files that have been written in the past for this specific - * partition. This is used to limit the max number of records written for a - * single file. The value should start from 0. - * @param updatedPartitions the set of updated partition paths, we should add the new partition - * path of this writer to it. - */ - private def newOutputWriter( - partitionValues: Option[InternalRow], - bucketId: Option[Int], - fileCounter: Int, - updatedPartitions: mutable.Set[String]): Unit = { - - val partDir = partitionValues.map(getPartitionPath(_)) - partDir.foreach(updatedPartitions.add) - - val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") - - // This must be in a form that matches our bucketing format. See BucketingUtils. - val ext = f"$bucketIdStr.c$fileCounter%03d" + - desc.outputWriterFactory.getFileExtension(taskAttemptContext) - - val customPath = partDir.flatMap { dir => - desc.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) - } - val currentPath = if (customPath.isDefined) { - committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) - } else { - committer.newTaskTempFile(taskAttemptContext, partDir, ext) - } - - currentWriter = desc.outputWriterFactory.newInstance( - path = currentPath, - dataSchema = desc.dataColumns.toStructType, - context = taskAttemptContext) - - statsTrackers.foreach(_.newFile(currentPath)) - } - - override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { - // If anything below fails, we should abort the task. - var recordsInFile: Long = 0L - var fileCounter = 0 - val updatedPartitions = mutable.Set[String]() - var currentPartionValues: Option[UnsafeRow] = None - var currentBucketId: Option[Int] = None - - for (row <- iter) { - val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(row)) else None - val nextBucketId = if (isBucketed) Some(getBucketId(row)) else None - - if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) { - // See a new partition or bucket - write to a new partition dir (or a new bucket file). - if (isPartitioned && currentPartionValues != nextPartitionValues) { - currentPartionValues = Some(nextPartitionValues.get.copy()) - statsTrackers.foreach(_.newPartition(currentPartionValues.get)) - } - if (isBucketed) { - currentBucketId = nextBucketId - statsTrackers.foreach(_.newBucket(currentBucketId.get)) - } - - recordsInFile = 0 - fileCounter = 0 - - releaseResources() - newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions) - } else if (desc.maxRecordsPerFile > 0 && - recordsInFile >= desc.maxRecordsPerFile) { - // Exceeded the threshold in terms of the number of records per file. - // Create a new file by increasing the file counter. - recordsInFile = 0 - fileCounter += 1 - assert(fileCounter < MAX_FILE_COUNTER, - s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") - - releaseResources() - newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions) - } - val outputRow = getOutputRow(row) - currentWriter.write(outputRow) - statsTrackers.foreach(_.newRow(outputRow)) - recordsInFile += 1 - } - releaseResources() - - ExecutedWriteSummary( - updatedPartitions = updatedPartitions.toSet, - stats = statsTrackers.map(_.getFinalStats())) - } - - override def releaseResources(): Unit = { - if (currentWriter != null) { - try { - currentWriter.close() - } finally { - currentWriter = null - } - } - } - } } - -/** - * Wrapper class for the metrics of writing data out. - * - * @param updatedPartitions the partitions updated during writing data out. Only valid - * for dynamic partition. - * @param stats one `WriteTaskStats` object for every `WriteJobStatsTracker` that the job had. - */ -case class ExecutedWriteSummary( - updatedPartitions: Set[String], - stats: Seq[WriteTaskStats]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index ea283ed77efda..ea4bda327f36f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -116,7 +116,9 @@ object DataWritingSparkTask extends Logging { // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { - iter.foreach(dataWriter.write) + while (iter.hasNext) { + dataWriter.write(iter.next()) + } val msg = if (useCommitCoordinator) { val coordinator = SparkEnv.get.outputCommitCoordinator From b2d022656298c7a39ff3e84b04f813d5f315cb95 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 1 Jun 2018 11:58:59 +0800 Subject: [PATCH 0905/2461] [SPARK-24444][DOCS][PYTHON] Improve Pandas UDF docs to explain column assignment ## What changes were proposed in this pull request? Added sections to pandas_udf docs, in the grouped map section, to indicate columns are assigned by position. ## How was this patch tested? NA Author: Bryan Cutler Closes #21471 from BryanCutler/arrow-doc-pandas_udf-column_by_pos-SPARK-21427. --- docs/sql-programming-guide.md | 9 +++++++++ python/pyspark/sql/functions.py | 9 ++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 50600861912b1..4d8a738507bd1 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1752,6 +1752,15 @@ To use `groupBy().apply()`, the user needs to define the following: * A Python function that defines the computation for each group. * A `StructType` object or a string that defines the schema of the output `DataFrame`. +The output schema will be applied to the columns of the returned `pandas.DataFrame` in order by position, +not by name. This means that the columns in the `pandas.DataFrame` must be indexed so that their +position matches the corresponding field in the schema. + +Note that when creating a new `pandas.DataFrame` using a dictionary, the actual position of the column +can differ from the order that it was placed in the dictionary. It is recommended in this case to +explicitly define the column order using the `columns` keyword, e.g. +`pandas.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])`, or alternatively use an `OrderedDict`. + Note that all data for a group will be loaded into memory before the function is applied. This can lead to out of memory exceptons, especially if the group sizes are skewed. The configuration for [maxRecordsPerBatch](#setting-arrow-batch-size) is not applied on groups and it is up to the user diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index efcce25a08e04..fd656c5c35844 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2500,7 +2500,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` The returnType should be a :class:`StructType` describing the schema of the returned `pandas.DataFrame`. - The length of the returned `pandas.DataFrame` can be arbitrary. + The length of the returned `pandas.DataFrame` can be arbitrary and the columns must be + indexed so that their position matches the corresponding field in the schema. Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. @@ -2548,6 +2549,12 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 2|6.0| +---+---+ + .. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is + recommended to explicitly index the columns by name to ensure the positions are correct, + or alternatively use an `OrderedDict`. + For example, `pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])` or + `pd.DataFrame(OrderedDict([('id', ids), ('a', data)]))`. + .. seealso:: :meth:`pyspark.sql.GroupedData.apply` 3. GROUPED_AGG From 22df953f6bb191858053eafbabaa5b3ebca29f56 Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Thu, 31 May 2018 21:25:45 -0700 Subject: [PATCH 0906/2461] [SPARK-24326][MESOS] add support for local:// scheme for the app jar ## What changes were proposed in this pull request? * Adds support for local:// scheme like in k8s case for image based deployments where the jar is already in the image. Affects cluster mode and the mesos dispatcher. Covers also file:// scheme. Keeps the default case where jar resolution happens on the host. ## How was this patch tested? Dispatcher image with the patch, use it to start DC/OS Spark service: skonto/spark-local-disp:test Test image with my application jar located at the root folder: skonto/spark-local:test Dockerfile for that image. From mesosphere/spark:2.3.0-2.2.1-2-hadoop-2.6 COPY spark-examples_2.11-2.2.1.jar / WORKDIR /opt/spark/dist Tests: The following work as expected: * local normal example ``` dcos spark run --submit-args="--conf spark.mesos.appJar.local.resolution.mode=container --conf spark.executor.memory=1g --conf spark.mesos.executor.docker.image=skonto/spark-local:test --conf spark.executor.cores=2 --conf spark.cores.max=8 --class org.apache.spark.examples.SparkPi local:///spark-examples_2.11-2.2.1.jar" ``` * make sure the flag does not affect other uris ``` dcos spark run --submit-args="--conf spark.mesos.appJar.local.resolution.mode=container --conf spark.executor.memory=1g --conf spark.executor.cores=2 --conf spark.cores.max=8 --class org.apache.spark.examples.SparkPi https://s3-eu-west-1.amazonaws.com/fdp-stavros-test/spark-examples_2.11-2.1.1.jar" ``` * normal example no local ``` dcos spark run --submit-args="--conf spark.executor.memory=1g --conf spark.executor.cores=2 --conf spark.cores.max=8 --class org.apache.spark.examples.SparkPi https://s3-eu-west-1.amazonaws.com/fdp-stavros-test/spark-examples_2.11-2.1.1.jar" ``` The following fails * uses local with no setting, default is host. ``` dcos spark run --submit-args="--conf spark.executor.memory=1g --conf spark.mesos.executor.docker.image=skonto/spark-local:test --conf spark.executor.cores=2 --conf spark.cores.max=8 --class org.apache.spark.examples.SparkPi local:///spark-examples_2.11-2.2.1.jar" ``` ![image](https://user-images.githubusercontent.com/7945591/40283021-8d349762-5c80-11e8-9d62-2a61a4318fd5.png) Author: Stavros Kontopoulos Closes #21378 from skonto/local-upstream. --- docs/running-on-mesos.md | 12 +++++++ .../cluster/mesos/MesosClusterScheduler.scala | 36 +++++++++++++++---- 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 3c2a1501ca692..66ffb17949845 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -753,6 +753,18 @@ See the [configuration page](configuration.html) for information on Spark config spark.cores.max is reached + + spark.mesos.appJar.local.resolution.mode + host + + Provides support for the `local:///` scheme to reference the app jar resource in cluster mode. + If user uses a local resource (`local:///path/to/jar`) and the config option is not used it defaults to `host` eg. + the mesos fetcher tries to get the resource from the host's file system. + If the value is unknown it prints a warning msg in the dispatcher logs and defaults to `host`. + If the value is `container` then spark submit in the container will use the jar in the container's path: + `/path/to/jar`. + + # Troubleshooting and Debugging diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index b36f46456f9a5..7d80eedcc43ce 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -30,8 +30,7 @@ import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} -import org.apache.spark.deploy.mesos.MesosDriverDescription -import org.apache.spark.deploy.mesos.config +import org.apache.spark.deploy.mesos.{config, MesosDriverDescription} import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.Utils @@ -418,6 +417,18 @@ private[spark] class MesosClusterScheduler( envBuilder.build() } + private def isContainerLocalAppJar(desc: MesosDriverDescription): Boolean = { + val isLocalJar = desc.jarUrl.startsWith("local://") + val isContainerLocal = desc.conf.getOption("spark.mesos.appJar.local.resolution.mode").exists { + case "container" => true + case "host" => false + case other => + logWarning(s"Unknown spark.mesos.appJar.local.resolution.mode $other, using host.") + false + } + isLocalJar && isContainerLocal + } + private def getDriverUris(desc: MesosDriverDescription): List[CommandInfo.URI] = { val confUris = List(conf.getOption("spark.mesos.uris"), desc.conf.getOption("spark.mesos.uris"), @@ -425,10 +436,14 @@ private[spark] class MesosClusterScheduler( _.map(_.split(",").map(_.trim)) ).flatten - val jarUrl = desc.jarUrl.stripPrefix("file:").stripPrefix("local:") - - ((jarUrl :: confUris) ++ getDriverExecutorURI(desc).toList).map(uri => - CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build()) + if (isContainerLocalAppJar(desc)) { + (confUris ++ getDriverExecutorURI(desc).toList).map(uri => + CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build()) + } else { + val jarUrl = desc.jarUrl.stripPrefix("file:").stripPrefix("local:") + ((jarUrl :: confUris) ++ getDriverExecutorURI(desc).toList).map(uri => + CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build()) + } } private def getContainerInfo(desc: MesosDriverDescription): ContainerInfo.Builder = { @@ -480,7 +495,14 @@ private[spark] class MesosClusterScheduler( (cmdExecutable, ".") } val cmdOptions = generateCmdOption(desc, sandboxPath).mkString(" ") - val primaryResource = new File(sandboxPath, desc.jarUrl.split("/").last).toString() + val primaryResource = { + if (isContainerLocalAppJar(desc)) { + new File(desc.jarUrl.stripPrefix("local://")).toString() + } else { + new File(sandboxPath, desc.jarUrl.split("/").last).toString() + } + } + val appArguments = desc.command.arguments.mkString(" ") s"$executable $cmdOptions $primaryResource $appArguments" From 98909c398dbcbffcae8015d36f44f185d3280af6 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 31 May 2018 22:04:26 -0700 Subject: [PATCH 0907/2461] [SPARK-23920][SQL] add array_remove to remove all elements that equal element from array ## What changes were proposed in this pull request? add array_remove to remove all elements that equal element from array ## How was this patch tested? add unit tests Author: Huaxin Gao Closes #21069 from huaxingao/spark-23920. --- python/pyspark/sql/functions.py | 16 +++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 123 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 58 +++++++++ .../org/apache/spark/sql/functions.scala | 9 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 29 +++++ 6 files changed, 236 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index fd656c5c35844..1759195c6fcc0 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1964,6 +1964,22 @@ def element_at(col, extraction): return Column(sc._jvm.functions.element_at(_to_java_column(col), extraction)) +@since(2.4) +def array_remove(col, element): + """ + Collection function: Remove all elements that equal to element from the given array. + + :param col: name of column containing array + :param element: element to be removed from the array + + >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data']) + >>> df.select(array_remove(df.data, 1)).collect() + [Row(array_remove(data, 1)=[2, 3]), Row(array_remove(data, 1)=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_remove(_to_java_column(col), element)) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 23a4a440fac23..49fb35b083580 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -430,6 +430,7 @@ object FunctionRegistry { expression[Concat]("concat"), expression[Flatten]("flatten"), expression[ArrayRepeat]("array_repeat"), + expression[ArrayRemove]("array_remove"), CreateStruct.registryEntry, // mask functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8a877b02c8191..176995affe701 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2066,3 +2066,126 @@ case class ArrayRepeat(left: Expression, right: Expression) } } + +/** + * Remove all elements that equal to element from the given array + */ +@ExpressionDescription( + usage = "_FUNC_(array, element) - Remove all elements that equal to element from array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3, null, 3), 3); + [1,2,null] + """, since = "2.4.0") +case class ArrayRemove(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = left.dataType + + override def inputTypes: Seq[AbstractDataType] = { + val elementType = left.dataType match { + case t: ArrayType => t.elementType + case _ => AnyDataType + } + Seq(ArrayType, elementType) + } + + lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(right.dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") + } + } + + override def nullSafeEval(arr: Any, value: Any): Any = { + val newArray = new Array[Any](arr.asInstanceOf[ArrayData].numElements()) + var pos = 0 + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == null || !ordering.equiv(v, value)) { + newArray(pos) = v + pos += 1 + } + ) + new GenericArrayData(newArray.slice(0, pos)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (arr, value) => { + val numsToRemove = ctx.freshName("numsToRemove") + val newArraySize = ctx.freshName("newArraySize") + val i = ctx.freshName("i") + val getValue = CodeGenerator.getValue(arr, elementType, i) + val isEqual = ctx.genEqual(elementType, value, getValue) + s""" + |int $numsToRemove = 0; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | if (!$arr.isNullAt($i) && $isEqual) { + | $numsToRemove = $numsToRemove + 1; + | } + |} + |int $newArraySize = $arr.numElements() - $numsToRemove; + |${genCodeForResult(ctx, ev, arr, value, newArraySize)} + """.stripMargin + }) + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + value: String, + newArraySize: String): String = { + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val getValue = CodeGenerator.getValue(inputArray, elementType, i) + val isEqual = ctx.genEqual(elementType, value, getValue) + if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + s""" + |int $pos = 0; + |Object[] $values = new Object[$newArraySize]; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $values[$pos] = null; + | $pos = $pos + 1; + | } + | else { + | if (!($isEqual)) { + | $values[$pos] = $getValue; + | $pos = $pos + 1; + | } + | } + |} + |${ev.value} = new $arrayClass($values); + """.stripMargin + } else { + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + s""" + |${ctx.createUnsafeArray(values, newArraySize, elementType, s" $prettyName failed.")} + |int $pos = 0; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $values.setNullAt($pos); + | $pos = $pos + 1; + | } + | else { + | if (!($isEqual)) { + | $values.set$primitiveValueTypeName($pos, $getValue); + | $pos = $pos + 1; + | } + | } + |} + |${ev.value} = $values; + """.stripMargin + } + } + + override def prettyName: String = "array_remove" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 3fc0b08c56e02..f8ad624ce0e3d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -622,4 +622,62 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayRepeat(strArray, Literal(2)), Seq(Seq("hi", "hola"), Seq("hi", "hola"))) checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null) } + + test("Array remove") { + val a0 = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType)) + val a1 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) + val a2 = Literal.create(Seq[String](null, "", null, ""), ArrayType(StringType)) + val a3 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val a4 = Literal.create(null, ArrayType(StringType)) + val a5 = Literal.create(Seq(1, null, 8, 9, null), ArrayType(IntegerType)) + val a6 = Literal.create(Seq(true, false, false, true), ArrayType(BooleanType)) + + checkEvaluation(ArrayRemove(a0, Literal(0)), Seq(1, 2, 3, 2, 2, 5)) + checkEvaluation(ArrayRemove(a0, Literal(1)), Seq(2, 3, 2, 2, 5)) + checkEvaluation(ArrayRemove(a0, Literal(2)), Seq(1, 3, 5)) + checkEvaluation(ArrayRemove(a0, Literal(3)), Seq(1, 2, 2, 2, 5)) + checkEvaluation(ArrayRemove(a0, Literal(5)), Seq(1, 2, 3, 2, 2)) + checkEvaluation(ArrayRemove(a0, Literal(null, IntegerType)), null) + + checkEvaluation(ArrayRemove(a1, Literal("")), Seq("b", "a", "a", "c", "b")) + checkEvaluation(ArrayRemove(a1, Literal("a")), Seq("b", "c", "b")) + checkEvaluation(ArrayRemove(a1, Literal("b")), Seq("a", "a", "c")) + checkEvaluation(ArrayRemove(a1, Literal("c")), Seq("b", "a", "a", "b")) + + checkEvaluation(ArrayRemove(a2, Literal("")), Seq(null, null)) + checkEvaluation(ArrayRemove(a2, Literal(null, StringType)), null) + + checkEvaluation(ArrayRemove(a3, Literal(1)), Seq.empty[Integer]) + + checkEvaluation(ArrayRemove(a4, Literal("a")), null) + + checkEvaluation(ArrayRemove(a5, Literal(9)), Seq(1, null, 8, null)) + checkEvaluation(ArrayRemove(a6, Literal(false)), Seq(true, true)) + + // complex data types + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), + Array[Byte](1, 2), Array[Byte](5, 6)), ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), + ArrayType(BinaryType)) + val nullBinary = Literal.create(null, BinaryType) + + val dataToRemove1 = Literal.create(Array[Byte](5, 6), BinaryType) + checkEvaluation(ArrayRemove(b0, dataToRemove1), + Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](1, 2))) + checkEvaluation(ArrayRemove(b0, nullBinary), null) + checkEvaluation(ArrayRemove(b1, dataToRemove1), Seq[Array[Byte]](Array[Byte](2, 1), null)) + checkEvaluation(ArrayRemove(b2, dataToRemove1), Seq[Array[Byte]](null, Array[Byte](1, 2))) + + val c0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val c1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1)), ArrayType(ArrayType(IntegerType))) + val dataToRemove2 = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) + checkEvaluation(ArrayRemove(c0, dataToRemove2), Seq[Seq[Int]](Seq[Int](3, 4))) + checkEvaluation(ArrayRemove(c1, dataToRemove2), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) + checkEvaluation(ArrayRemove(c2, dataToRemove2), Seq[Seq[Int]](null, Seq[Int](2, 1))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 443ba2aa3757d..a2aae9a708ff3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3169,6 +3169,15 @@ object functions { */ def array_sort(e: Column): Column = withExpr { ArraySort(e.expr) } + /** + * Remove all elements that equal to element from the given array. + * @group collection_funcs + * @since 2.4.0 + */ + def array_remove(column: Column, element: Any): Column = withExpr { + ArrayRemove(column.expr, Literal(element)) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index cc8bad4ded53e..59119bbbd8a2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1110,6 +1110,35 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } + test("array remove") { + val df = Seq( + (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", "")), + (Array.empty[Int], Array.empty[String], Array.empty[String]), + (null, null, null) + ).toDF("a", "b", "c") + checkAnswer( + df.select(array_remove($"a", 2), array_remove($"b", "a"), array_remove($"c", "")), + Seq( + Row(Seq(1, 3), Seq("b", "c"), Seq.empty[String]), + Row(Seq.empty[Int], Seq.empty[String], Seq.empty[String]), + Row(null, null, null)) + ) + + checkAnswer( + df.selectExpr("array_remove(a, 2)", "array_remove(b, \"a\")", + "array_remove(c, \"\")"), + Seq( + Row(Seq(1, 3), Seq("b", "c"), Seq.empty[String]), + Row(Seq.empty[Int], Seq.empty[String], Seq.empty[String]), + Row(null, null, null)) + ) + + val e = intercept[AnalysisException] { + Seq(("a string element", "a")).toDF().selectExpr("array_remove(_1, _2)") + } + assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type")) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 6039b132304cc77ed39e4ca7813850507ae0b440 Mon Sep 17 00:00:00 2001 From: Huang Tengfei Date: Fri, 1 Jun 2018 10:47:53 -0700 Subject: [PATCH 0908/2461] [SPARK-24351][SS] offsetLog/commitLog purge thresholdBatchId should be computed with current committed epoch but not currentBatchId in CP mode ## What changes were proposed in this pull request? Compute the thresholdBatchId to purge metadata based on current committed epoch instead of currentBatchId in CP mode to avoid cleaning all the committed metadata in some case as described in the jira [SPARK-24351](https://issues.apache.org/jira/browse/SPARK-24351). ## How was this patch tested? Add new unit test. Author: Huang Tengfei Closes #21400 from ivoson/branch-cp-meta. --- .../continuous/ContinuousExecution.scala | 11 +++-- .../continuous/ContinuousSuite.scala | 46 +++++++++++++++++++ 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index d16b24c89ebef..e3d0cea608b2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -318,9 +318,14 @@ class ContinuousExecution( } } - if (minLogEntriesToMaintain < currentBatchId) { - offsetLog.purge(currentBatchId - minLogEntriesToMaintain) - commitLog.purge(currentBatchId - minLogEntriesToMaintain) + // Since currentBatchId increases independently in cp mode, the current committed epoch may + // be far behind currentBatchId. It is not safe to discard the metadata with thresholdBatchId + // computed based on currentBatchId. As minLogEntriesToMaintain is used to keep the minimum + // number of batches that must be retained and made recoverable, so we should keep the + // specified number of metadata that have been committed. + if (minLogEntriesToMaintain <= epoch) { + offsetLog.purge(epoch + 1 - minLogEntriesToMaintain) + commitLog.purge(epoch + 1 - minLogEntriesToMaintain) } awaitProgressLock.lock() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index cd1704ac2fdad..4980b0cd41f81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -297,3 +297,49 @@ class ContinuousStressSuite extends ContinuousSuiteBase { CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_)))) } } + +class ContinuousMetaSuite extends ContinuousSuiteBase { + import testImplicits._ + + // We need to specify spark.sql.streaming.minBatchesToRetain to do the following test. + override protected def createSparkSession = new TestSparkSession( + new SparkContext( + "local[10]", + "continuous-stream-test-sql-context", + sparkConf.set("spark.sql.testkey", "true") + .set("spark.sql.streaming.minBatchesToRetain", "2"))) + + test("SPARK-24351: check offsetLog/commitLog retained in the checkpoint directory") { + withTempDir { checkpointDir => + val input = ContinuousMemoryStream[Int] + val df = input.toDF().mapPartitions(iter => { + // Sleep the task thread for 300 ms to make sure epoch processing time 3 times + // longer than epoch creating interval. So the gap between last committed + // epoch and currentBatchId grows over time. + Thread.sleep(300) + iter.map(row => row.getInt(0) * 2) + }) + + testStream(df)( + StartStream(trigger = Trigger.Continuous(100), + checkpointLocation = checkpointDir.getAbsolutePath), + AddData(input, 1), + CheckAnswer(2), + // Make sure epoch 2 has been committed before the following validation. + AwaitEpoch(2), + StopStream, + AssertOnQuery(q => { + q.commitLog.getLatest() match { + case Some((latestEpochId, _)) => + val commitLogValidateResult = q.commitLog.get(latestEpochId - 1).isDefined && + q.commitLog.get(latestEpochId - 2).isEmpty + val offsetLogValidateResult = q.offsetLog.get(latestEpochId - 1).isDefined && + q.offsetLog.get(latestEpochId - 2).isEmpty + commitLogValidateResult && offsetLogValidateResult + case None => false + } + }) + ) + } + } +} From d2c3de7efcfacadff20b023924d4566a5bf9ad7a Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 1 Jun 2018 11:51:10 -0700 Subject: [PATCH 0909/2461] Revert "[SPARK-24369][SQL] Correct handling for multiple distinct aggregations having the same argument set" This reverts commit 1e46f92f956a00d04d47340489b6125d44dbd47b. --- .../optimizer/RewriteDistinctAggregates.scala | 7 +++---- .../apache/spark/sql/execution/SparkStrategies.scala | 2 +- .../src/test/resources/sql-tests/inputs/group-by.sql | 6 +----- .../test/resources/sql-tests/results/group-by.sql.out | 11 +---------- 4 files changed, 6 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index bc898ab0dc723..4448ace7105a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -115,8 +115,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Extract distinct aggregate expressions. - val distincgAggExpressions = aggExpressions.filter(_.isDistinct) - val distinctAggGroups = distincgAggExpressions.groupBy { e => + val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e => val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet if (unfoldableChildren.nonEmpty) { // Only expand the unfoldable children @@ -133,7 +132,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Aggregation strategy can handle queries with a single distinct group. - if (distincgAggExpressions.size > 1) { + if (distinctAggGroups.size > 1) { // Create the attributes for the grouping id and the group by clause. val gid = AttributeReference("gid", IntegerType, nullable = false)() val groupByMap = a.groupingExpressions.collect { @@ -152,7 +151,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Setup unique distinct aggregate children. - val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq + val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index b9452b58657a4..b97a87a122406 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -386,7 +386,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { aggregateExpressions.partition(_.isDistinct) if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { // This is a sanity check. We should not reach here when we have multiple distinct - // column sets. Our `RewriteDistinctAggregates` should take care this case. + // column sets. Our MultipleDistinctRewriter should take care this case. sys.error("You hit a query analyzer bug. Please report your query to " + "Spark user mailing list.") } diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 2c18d6aaabdba..c5070b734d521 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -68,8 +68,4 @@ SELECT 1 from ( FROM (select 1 as x) a WHERE false ) b -where b.z != b.z; - --- SPARK-24369 multiple distinct aggregations having the same argument set -SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) - FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y); +where b.z != b.z diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 581aa1754ce14..c1abc6dff754b 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 27 +-- Number of queries: 26 -- !query 0 @@ -241,12 +241,3 @@ where b.z != b.z struct<1:int> -- !query 25 output - - --- !query 26 -SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) - FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y) --- !query 26 schema -struct --- !query 26 output -1.0 1.0 3 From 09e78c1eaa742b9cab4564928e5a5401fe0198a9 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 1 Jun 2018 11:55:34 -0700 Subject: [PATCH 0910/2461] [INFRA] Close stale PRs. Closes #21444 From 8ef167a5f9ba8a79bb7ca98a9844fe9cfcfea060 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 1 Jun 2018 13:46:05 -0700 Subject: [PATCH 0911/2461] [SPARK-24340][CORE] Clean up non-shuffle disk block manager files following executor exits on a Standalone cluster ## What changes were proposed in this pull request? Currently we only clean up the local directories on application removed. However, when executors die and restart repeatedly, many temp files are left untouched in the local directories, which is undesired behavior and could cause disk space used up gradually. We can detect executor death in the Worker, and clean up the non-shuffle files (files not ended with ".index" or ".data") in the local directories, we should not touch the shuffle files since they are expected to be used by the external shuffle service. Scope of this PR is limited to only implement the cleanup logic on a Standalone cluster, we defer to experts familiar with other cluster managers(YARN/Mesos/K8s) to determine whether it's worth to add similar support. ## How was this patch tested? Add new test suite to cover. Author: Xingbo Jiang Closes #21390 from jiangxb1987/cleanupNonshuffleFiles. --- .../apache/spark/network/util/JavaUtils.java | 45 ++-- .../shuffle/ExternalShuffleBlockHandler.java | 7 + .../shuffle/ExternalShuffleBlockResolver.java | 43 ++++ .../shuffle/NonShuffleFilesCleanupSuite.java | 221 ++++++++++++++++++ .../shuffle/TestShuffleDataContext.java | 15 ++ .../spark/deploy/ExternalShuffleService.scala | 5 + .../apache/spark/deploy/worker/Worker.scala | 17 +- .../spark/deploy/worker/WorkerSuite.scala | 55 ++++- docs/spark-standalone.md | 12 + 9 files changed, 400 insertions(+), 20 deletions(-) create mode 100644 common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java index afc59efaef810..b5497087634ce 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -17,10 +17,7 @@ package org.apache.spark.network.util; -import java.io.Closeable; -import java.io.EOFException; -import java.io.File; -import java.io.IOException; +import java.io.*; import java.nio.ByteBuffer; import java.nio.channels.ReadableByteChannel; import java.nio.charset.StandardCharsets; @@ -91,11 +88,24 @@ public static String bytesToString(ByteBuffer b) { * @throws IOException if deletion is unsuccessful */ public static void deleteRecursively(File file) throws IOException { + deleteRecursively(file, null); + } + + /** + * Delete a file or directory and its contents recursively. + * Don't follow directories if they are symlinks. + * + * @param file Input file / dir to be deleted + * @param filter A filename filter that make sure only files / dirs with the satisfied filenames + * are deleted. + * @throws IOException if deletion is unsuccessful + */ + public static void deleteRecursively(File file, FilenameFilter filter) throws IOException { if (file == null) { return; } // On Unix systems, use operating system command to run faster // If that does not work out, fallback to the Java IO way - if (SystemUtils.IS_OS_UNIX) { + if (SystemUtils.IS_OS_UNIX && filter == null) { try { deleteRecursivelyUsingUnixNative(file); return; @@ -105,15 +115,17 @@ public static void deleteRecursively(File file) throws IOException { } } - deleteRecursivelyUsingJavaIO(file); + deleteRecursivelyUsingJavaIO(file, filter); } - private static void deleteRecursivelyUsingJavaIO(File file) throws IOException { + private static void deleteRecursivelyUsingJavaIO( + File file, + FilenameFilter filter) throws IOException { if (file.isDirectory() && !isSymlink(file)) { IOException savedIOException = null; - for (File child : listFilesSafely(file)) { + for (File child : listFilesSafely(file, filter)) { try { - deleteRecursively(child); + deleteRecursively(child, filter); } catch (IOException e) { // In case of multiple exceptions, only last one will be thrown savedIOException = e; @@ -124,10 +136,13 @@ private static void deleteRecursivelyUsingJavaIO(File file) throws IOException { } } - boolean deleted = file.delete(); - // Delete can also fail if the file simply did not exist. - if (!deleted && file.exists()) { - throw new IOException("Failed to delete: " + file.getAbsolutePath()); + // Delete file only when it's a normal file or an empty directory. + if (file.isFile() || (file.isDirectory() && listFilesSafely(file, null).length == 0)) { + boolean deleted = file.delete(); + // Delete can also fail if the file simply did not exist. + if (!deleted && file.exists()) { + throw new IOException("Failed to delete: " + file.getAbsolutePath()); + } } } @@ -157,9 +172,9 @@ private static void deleteRecursivelyUsingUnixNative(File file) throws IOExcepti } } - private static File[] listFilesSafely(File file) throws IOException { + private static File[] listFilesSafely(File file, FilenameFilter filter) throws IOException { if (file.exists()) { - File[] files = file.listFiles(); + File[] files = file.listFiles(filter); if (files == null) { throw new IOException("Failed to list files for dir: " + file); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index fc7bba41185f0..098fa7974b87b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -138,6 +138,13 @@ public void applicationRemoved(String appId, boolean cleanupLocalDirs) { blockManager.applicationRemoved(appId, cleanupLocalDirs); } + /** + * Clean up any non-shuffle files in any local directories associated with an finished executor. + */ + public void executorRemoved(String executorId, String appId) { + blockManager.executorRemoved(executorId, appId); + } + /** * Register an (application, executor) with the given shuffle info. * diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index e6399897be9c2..58fb17f60a79d 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -211,6 +211,26 @@ public void applicationRemoved(String appId, boolean cleanupLocalDirs) { } } + /** + * Removes all the non-shuffle files in any local directories associated with the finished + * executor. + */ + public void executorRemoved(String executorId, String appId) { + logger.info("Clean up non-shuffle files associated with the finished executor {}", executorId); + AppExecId fullId = new AppExecId(appId, executorId); + final ExecutorShuffleInfo executor = executors.get(fullId); + if (executor == null) { + // Executor not registered, skip clean up of the local directories. + logger.info("Executor is not registered (appId={}, execId={})", appId, executorId); + } else { + logger.info("Cleaning up non-shuffle files in executor {}'s {} local dirs", fullId, + executor.localDirs.length); + + // Execute the actual deletion in a different thread, as it may take some time. + directoryCleaner.execute(() -> deleteNonShuffleFiles(executor.localDirs)); + } + } + /** * Synchronously deletes each directory one at a time. * Should be executed in its own thread, as this may take a long time. @@ -226,6 +246,29 @@ private void deleteExecutorDirs(String[] dirs) { } } + /** + * Synchronously deletes non-shuffle files in each directory recursively. + * Should be executed in its own thread, as this may take a long time. + */ + private void deleteNonShuffleFiles(String[] dirs) { + FilenameFilter filter = new FilenameFilter() { + @Override + public boolean accept(File dir, String name) { + // Don't delete shuffle data or shuffle index files. + return !name.endsWith(".index") && !name.endsWith(".data"); + } + }; + + for (String localDir : dirs) { + try { + JavaUtils.deleteRecursively(new File(localDir), filter); + logger.debug("Successfully cleaned up non-shuffle files in directory: {}", localDir); + } catch (Exception e) { + logger.error("Failed to delete non-shuffle files in directory: " + localDir, e); + } + } + } + /** * Sort-based shuffle data uses an index called "shuffle_ShuffleId_MapId_0.index" into a data file * called "shuffle_ShuffleId_MapId_0.data". This logic is from IndexShuffleBlockResolver, diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java new file mode 100644 index 0000000000000..d22f3ace4103b --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.File; +import java.io.FilenameFilter; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Random; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; + +import com.google.common.util.concurrent.MoreExecutors; +import org.junit.Test; +import static org.junit.Assert.assertTrue; + +import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class NonShuffleFilesCleanupSuite { + + // Same-thread Executor used to ensure cleanup happens synchronously in test thread. + private Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); + private TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); + private static final String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager"; + + @Test + public void cleanupOnRemovedExecutorWithShuffleFiles() throws IOException { + cleanupOnRemovedExecutor(true); + } + + @Test + public void cleanupOnRemovedExecutorWithoutShuffleFiles() throws IOException { + cleanupOnRemovedExecutor(false); + } + + private void cleanupOnRemovedExecutor(boolean withShuffleFiles) throws IOException { + TestShuffleDataContext dataContext = initDataContext(withShuffleFiles); + + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); + resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); + resolver.executorRemoved("exec0", "app"); + + assertCleanedUp(dataContext); + } + + @Test + public void cleanupUsesExecutorWithShuffleFiles() throws IOException { + cleanupUsesExecutor(true); + } + + @Test + public void cleanupUsesExecutorWithoutShuffleFiles() throws IOException { + cleanupUsesExecutor(false); + } + + private void cleanupUsesExecutor(boolean withShuffleFiles) throws IOException { + TestShuffleDataContext dataContext = initDataContext(withShuffleFiles); + + AtomicBoolean cleanupCalled = new AtomicBoolean(false); + + // Executor which does nothing to ensure we're actually using it. + Executor noThreadExecutor = runnable -> cleanupCalled.set(true); + + ExternalShuffleBlockResolver manager = + new ExternalShuffleBlockResolver(conf, null, noThreadExecutor); + + manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); + manager.executorRemoved("exec0", "app"); + + assertTrue(cleanupCalled.get()); + assertStillThere(dataContext); + } + + @Test + public void cleanupOnlyRemovedExecutorWithShuffleFiles() throws IOException { + cleanupOnlyRemovedExecutor(true); + } + + @Test + public void cleanupOnlyRemovedExecutorWithoutShuffleFiles() throws IOException { + cleanupOnlyRemovedExecutor(false); + } + + private void cleanupOnlyRemovedExecutor(boolean withShuffleFiles) throws IOException { + TestShuffleDataContext dataContext0 = initDataContext(withShuffleFiles); + TestShuffleDataContext dataContext1 = initDataContext(withShuffleFiles); + + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); + resolver.registerExecutor("app", "exec0", dataContext0.createExecutorInfo(SORT_MANAGER)); + resolver.registerExecutor("app", "exec1", dataContext1.createExecutorInfo(SORT_MANAGER)); + + + resolver.executorRemoved("exec-nonexistent", "app"); + assertStillThere(dataContext0); + assertStillThere(dataContext1); + + resolver.executorRemoved("exec0", "app"); + assertCleanedUp(dataContext0); + assertStillThere(dataContext1); + + resolver.executorRemoved("exec1", "app"); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + + // Make sure it's not an error to cleanup multiple times + resolver.executorRemoved("exec1", "app"); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + } + + @Test + public void cleanupOnlyRegisteredExecutorWithShuffleFiles() throws IOException { + cleanupOnlyRegisteredExecutor(true); + } + + @Test + public void cleanupOnlyRegisteredExecutorWithoutShuffleFiles() throws IOException { + cleanupOnlyRegisteredExecutor(false); + } + + private void cleanupOnlyRegisteredExecutor(boolean withShuffleFiles) throws IOException { + TestShuffleDataContext dataContext = initDataContext(withShuffleFiles); + + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); + resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); + + resolver.executorRemoved("exec1", "app"); + assertStillThere(dataContext); + + resolver.executorRemoved("exec0", "app"); + assertCleanedUp(dataContext); + } + + private static void assertStillThere(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + assertTrue(localDir + " was cleaned up prematurely", new File(localDir).exists()); + } + } + + private static FilenameFilter filter = new FilenameFilter() { + @Override + public boolean accept(File dir, String name) { + // Don't delete shuffle data or shuffle index files. + return !name.endsWith(".index") && !name.endsWith(".data"); + } + }; + + private static boolean assertOnlyShuffleDataInDir(File[] dirs) { + for (File dir : dirs) { + assertTrue(dir.getName() + " wasn't cleaned up", !dir.exists() || + dir.listFiles(filter).length == 0 || assertOnlyShuffleDataInDir(dir.listFiles())); + } + return true; + } + + private static void assertCleanedUp(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + File[] dirs = new File[] {new File(localDir)}; + assertOnlyShuffleDataInDir(dirs); + } + } + + private static TestShuffleDataContext initDataContext(boolean withShuffleFiles) + throws IOException { + if (withShuffleFiles) { + return initDataContextWithShuffleFiles(); + } else { + return initDataContextWithoutShuffleFiles(); + } + } + + private static TestShuffleDataContext initDataContextWithShuffleFiles() throws IOException { + TestShuffleDataContext dataContext = createDataContext(); + createShuffleFiles(dataContext); + createNonShuffleFiles(dataContext); + return dataContext; + } + + private static TestShuffleDataContext initDataContextWithoutShuffleFiles() throws IOException { + TestShuffleDataContext dataContext = createDataContext(); + createNonShuffleFiles(dataContext); + return dataContext; + } + + private static TestShuffleDataContext createDataContext() { + TestShuffleDataContext dataContext = new TestShuffleDataContext(10, 5); + dataContext.create(); + return dataContext; + } + + private static void createShuffleFiles(TestShuffleDataContext dataContext) throws IOException { + Random rand = new Random(123); + dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), new byte[][] { + "ABC".getBytes(StandardCharsets.UTF_8), + "DEF".getBytes(StandardCharsets.UTF_8)}); + } + + private static void createNonShuffleFiles(TestShuffleDataContext dataContext) throws IOException { + // Create spill file(s) + dataContext.insertSpillData(); + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 81e01949e50fa..6989c3baf2e28 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -22,6 +22,7 @@ import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; +import java.util.UUID; import com.google.common.io.Closeables; import com.google.common.io.Files; @@ -94,6 +95,20 @@ public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) thr } } + /** Creates spill file(s) within the local dirs. */ + public void insertSpillData() throws IOException { + String filename = "temp_local_" + UUID.randomUUID(); + OutputStream dataStream = null; + + try { + dataStream = new FileOutputStream( + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, filename)); + dataStream.write(42); + } finally { + Closeables.close(dataStream, false); + } + } + /** * Creates an ExecutorShuffleInfo object based on the given shuffle manager which targets this * context's directories. diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index f975fa5cb4e23..b59a4fe66587c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -94,6 +94,11 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana blockHandler.applicationRemoved(appId, true /* cleanupLocalDirs */) } + /** Clean up all the non-shuffle files associated with an executor that has exited. */ + def executorRemoved(executorId: String, appId: String): Unit = { + blockHandler.executorRemoved(executorId, appId) + } + def stop() { if (server != null) { server.close() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 563b84934f264..ee1ca0bba5749 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -23,6 +23,7 @@ import java.text.SimpleDateFormat import java.util.{Date, Locale, UUID} import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} +import java.util.function.Supplier import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap} import scala.concurrent.ExecutionContext @@ -49,7 +50,8 @@ private[deploy] class Worker( endpointName: String, workDirPath: String = null, val conf: SparkConf, - val securityMgr: SecurityManager) + val securityMgr: SecurityManager, + externalShuffleServiceSupplier: Supplier[ExternalShuffleService] = null) extends ThreadSafeRpcEndpoint with Logging { private val host = rpcEnv.address.host @@ -97,6 +99,10 @@ private[deploy] class Worker( private val APP_DATA_RETENTION_SECONDS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) + // Whether or not cleanup the non-shuffle files on executor exits. + private val CLEANUP_NON_SHUFFLE_FILES_ENABLED = + conf.getBoolean("spark.storage.cleanupFilesAfterExecutorExit", true) + private val testing: Boolean = sys.props.contains("spark.testing") private var master: Option[RpcEndpointRef] = None @@ -142,7 +148,11 @@ private[deploy] class Worker( WorkerWebUI.DEFAULT_RETAINED_DRIVERS) // The shuffle service is not actually started unless configured. - private val shuffleService = new ExternalShuffleService(conf, securityMgr) + private val shuffleService = if (externalShuffleServiceSupplier != null) { + externalShuffleServiceSupplier.get() + } else { + new ExternalShuffleService(conf, securityMgr) + } private val publicAddress = { val envVar = conf.getenv("SPARK_PUBLIC_DNS") @@ -732,6 +742,9 @@ private[deploy] class Worker( trimFinishedExecutorsIfNecessary() coresUsed -= executor.cores memoryUsed -= executor.memory + if (CLEANUP_NON_SHUFFLE_FILES_ENABLED) { + shuffleService.executorRemoved(executorStateChanged.execId.toString, appId) + } case None => logInfo("Unknown Executor " + fullId + " finished with state " + state + message.map(" message " + _).getOrElse("") + diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index ce212a7513310..e3fe2b696aa1f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -17,10 +17,19 @@ package org.apache.spark.deploy.worker +import java.util.concurrent.atomic.AtomicBoolean +import java.util.function.Supplier + +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} -import org.apache.spark.deploy.{Command, ExecutorState} +import org.apache.spark.deploy.{Command, ExecutorState, ExternalShuffleService} import org.apache.spark.deploy.DeployMessages.{DriverStateChanged, ExecutorStateChanged} import org.apache.spark.deploy.master.DriverState import org.apache.spark.rpc.{RpcAddress, RpcEnv} @@ -29,6 +38,8 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { import org.apache.spark.deploy.DeployTestUtils._ + @Mock(answer = RETURNS_SMART_NULLS) private var shuffleService: ExternalShuffleService = _ + def cmd(javaOpts: String*): Command = { Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts : _*)) } @@ -36,15 +47,21 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { private var _worker: Worker = _ - private def makeWorker(conf: SparkConf): Worker = { + private def makeWorker( + conf: SparkConf, + shuffleServiceSupplier: Supplier[ExternalShuffleService] = null): Worker = { assert(_worker === null, "Some Worker's RpcEnv is leaked in tests") val securityMgr = new SecurityManager(conf) val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, securityMgr) _worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "Worker", "/tmp", conf, securityMgr) + "Worker", "/tmp", conf, securityMgr, shuffleServiceSupplier) _worker } + before { + MockitoAnnotations.initMocks(this) + } + after { if (_worker != null) { _worker.rpcEnv.shutdown() @@ -194,4 +211,36 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { assert(worker.finishedDrivers.size === expectedValue) } } + + test("cleanup non-shuffle files after executor exits when config " + + "spark.storage.cleanupFilesAfterExecutorExit=true") { + testCleanupFilesWithConfig(true) + } + + test("don't cleanup non-shuffle files after executor exits when config " + + "spark.storage.cleanupFilesAfterExecutorExit=false") { + testCleanupFilesWithConfig(false) + } + + private def testCleanupFilesWithConfig(value: Boolean) = { + val conf = new SparkConf().set("spark.storage.cleanupFilesAfterExecutorExit", value.toString) + + val cleanupCalled = new AtomicBoolean(false) + when(shuffleService.executorRemoved(any[String], any[String])).thenAnswer(new Answer[Unit] { + override def answer(invocations: InvocationOnMock): Unit = { + cleanupCalled.set(true) + } + }) + val externalShuffleServiceSupplier = new Supplier[ExternalShuffleService] { + override def get: ExternalShuffleService = shuffleService + } + val worker = makeWorker(conf, externalShuffleServiceSupplier) + // initialize workers + for (i <- 0 until 10) { + worker.executors += s"app1/$i" -> createExecutorRunner(i) + } + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", 0, ExecutorState.EXITED, None, None)) + assert(cleanupCalled.get() == value) + } } diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index f06e72a387df1..14d742de5655c 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -254,6 +254,18 @@ SPARK_WORKER_OPTS supports the following system properties: especially if you run jobs very frequently. + + spark.storage.cleanupFilesAfterExecutorExit + true + + Enable cleanup non-shuffle files(such as temp. shuffle blocks, cached RDD/broadcast blocks, + spill files, etc) of worker directories following executor exits. Note that this doesn't + overlap with `spark.worker.cleanup.enabled`, as this enables cleanup of non-shuffle files in + local directories of a dead executor, while `spark.worker.cleanup.enabled` enables cleanup of + all files/subdirectories of a stopped and timeout application. + This only affects Standalone mode, support of other cluster manangers can be added in the future. + + spark.worker.ui.compressedLogFileLengthCacheSize 100 From a36c1a6bbd1deb119d96316ccbb6dc96ad174796 Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Fri, 1 Jun 2018 23:43:10 -0700 Subject: [PATCH 0912/2461] [SPARK-23668][K8S] Added missing config property in running-on-kubernetes.md ## What changes were proposed in this pull request? PR https://github.com/apache/spark/pull/20811 introduced a new Spark configuration property `spark.kubernetes.container.image.pullSecrets` for specifying image pull secrets. However, the documentation wasn't updated accordingly. This PR adds the property introduced into running-on-kubernetes.md. ## How was this patch tested? N/A. foxish mccheah please help merge this. Thanks! Author: Yinan Li Closes #21480 from liyinan926/master. --- docs/running-on-kubernetes.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index a4b2b98b0b649..4eac9bd9032e4 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -327,6 +327,13 @@ specific to Spark on Kubernetes. Container image pull policy used when pulling images within Kubernetes. + + spark.kubernetes.container.image.pullSecrets + + + Comma separated list of Kubernetes secrets used to pull images from private image registries. + + spark.kubernetes.allocation.batch.size 5 From de4feae3cd2fef9e83cac749b04ea9395bdd805e Mon Sep 17 00:00:00 2001 From: Misha Dmitriev Date: Sat, 2 Jun 2018 23:07:39 -0500 Subject: [PATCH 0913/2461] [SPARK-24356][CORE] Duplicate strings in File.path managed by FileSegmentManagedBuffer This patch eliminates duplicate strings that come from the 'path' field of java.io.File objects created by FileSegmentManagedBuffer. That is, we want to avoid the situation when multiple File instances for the same pathname "foo/bar" are created, each with a separate copy of the "foo/bar" String instance. In some scenarios such duplicate strings may waste a lot of memory (~ 10% of the heap). To avoid that, we intern the pathname with String.intern(), and before that we make sure that it's in a normalized form (contains no "//", "///" etc.) Otherwise, the code in java.io.File would normalize it later, creating a new "foo/bar" String copy. Unfortunately, the normalization code that java.io.File uses internally is in the package-private class java.io.FileSystem, so we cannot call it here directly. ## What changes were proposed in this pull request? Added code to ExternalShuffleBlockResolver.getFile(), that normalizes and then interns the pathname string before passing it to the File() constructor. ## How was this patch tested? Added unit test Author: Misha Dmitriev Closes #21456 from countmdm/misha/spark-24356. --- .../shuffle/ExternalShuffleBlockResolver.java | 30 ++++++++++++++++++- .../ExternalShuffleBlockResolverSuite.java | 20 +++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 58fb17f60a79d..0b7a27402369d 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -24,6 +24,8 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Executors; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; @@ -59,6 +61,7 @@ public class ExternalShuffleBlockResolver { private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockResolver.class); private static final ObjectMapper mapper = new ObjectMapper(); + /** * This a common prefix to the key for each app registration we stick in leveldb, so they * are easy to find, since leveldb lets you search based on prefix. @@ -66,6 +69,8 @@ public class ExternalShuffleBlockResolver { private static final String APP_KEY_PREFIX = "AppExecShuffleInfo"; private static final StoreVersion CURRENT_VERSION = new StoreVersion(1, 0); + private static final Pattern MULTIPLE_SEPARATORS = Pattern.compile(File.separator + "{2,}"); + // Map containing all registered executors' metadata. @VisibleForTesting final ConcurrentMap executors; @@ -302,7 +307,8 @@ static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) int hash = JavaUtils.nonNegativeHash(filename); String localDir = localDirs[hash % localDirs.length]; int subDirId = (hash / localDirs.length) % subDirsPerLocalDir; - return new File(new File(localDir, String.format("%02x", subDirId)), filename); + return new File(createNormalizedInternedPathname( + localDir, String.format("%02x", subDirId), filename)); } void close() { @@ -315,6 +321,28 @@ void close() { } } + /** + * This method is needed to avoid the situation when multiple File instances for the + * same pathname "foo/bar" are created, each with a separate copy of the "foo/bar" String. + * According to measurements, in some scenarios such duplicate strings may waste a lot + * of memory (~ 10% of the heap). To avoid that, we intern the pathname, and before that + * we make sure that it's in a normalized form (contains no "//", "///" etc.) Otherwise, + * the internal code in java.io.File would normalize it later, creating a new "foo/bar" + * String copy. Unfortunately, we cannot just reuse the normalization code that java.io.File + * uses, since it is in the package-private class java.io.FileSystem. + */ + @VisibleForTesting + static String createNormalizedInternedPathname(String dir1, String dir2, String fname) { + String pathname = dir1 + File.separator + dir2 + File.separator + fname; + Matcher m = MULTIPLE_SEPARATORS.matcher(pathname); + pathname = m.replaceAll("/"); + // A single trailing slash needs to be taken care of separately + if (pathname.length() > 1 && pathname.endsWith("/")) { + pathname = pathname.substring(0, pathname.length() - 1); + } + return pathname.intern(); + } + /** Simply encodes an executor's full ID, which is appId + execId. */ public static class AppExecId { public final String appId; diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 6d201b8fe8d7d..d2072a54fa415 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network.shuffle; +import java.io.File; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; @@ -135,4 +136,23 @@ public void jsonSerializationOfExecutorRegistration() throws IOException { "\"subDirsPerLocalDir\": 7, \"shuffleManager\": " + "\"" + SORT_MANAGER + "\"}"; assertEquals(shuffleInfo, mapper.readValue(legacyShuffleJson, ExecutorShuffleInfo.class)); } + + @Test + public void testNormalizeAndInternPathname() { + assertPathsMatch("/foo", "bar", "baz", "/foo/bar/baz"); + assertPathsMatch("//foo/", "bar/", "//baz", "/foo/bar/baz"); + assertPathsMatch("foo", "bar", "baz///", "foo/bar/baz"); + assertPathsMatch("/foo/", "/bar//", "/baz", "/foo/bar/baz"); + assertPathsMatch("/", "", "", "/"); + assertPathsMatch("/", "/", "/", "/"); + } + + private void assertPathsMatch(String p1, String p2, String p3, String expectedPathname) { + String normPathname = + ExternalShuffleBlockResolver.createNormalizedInternedPathname(p1, p2, p3); + assertEquals(expectedPathname, normPathname); + File file = new File(normPathname); + String returnedPath = file.getPath(); + assertTrue(normPathname == returnedPath); + } } From a2166ecddaec030f78acaa66ce660d979a35079c Mon Sep 17 00:00:00 2001 From: xueyu Date: Mon, 4 Jun 2018 08:10:49 +0700 Subject: [PATCH 0914/2461] [SPARK-24455][CORE] fix typo in TaskSchedulerImpl comment change runTasks to submitTasks in the TaskSchedulerImpl.scala 's comment Author: xueyu Author: Xue Yu <278006819@qq.com> Closes #21485 from xueyumusic/fixtypo1. --- .../org/apache/spark/scheduler/TaskSchedulerImpl.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 8e97b3da33820..598b62f85a1fa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -42,7 +42,7 @@ import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils} * up to launch speculative tasks, etc. * * Clients should first call initialize() and start(), then submit task sets through the - * runTasks method. + * submitTasks method. * * THREADING: [[SchedulerBackend]]s and task-submitting clients can call this class from multiple * threads, so it needs locks in public API methods to maintain its state. In addition, some @@ -62,7 +62,7 @@ private[spark] class TaskSchedulerImpl( this(sc, sc.conf.get(config.MAX_TASK_FAILURES)) } - // Lazily initializing blackListTrackOpt to avoid getting empty ExecutorAllocationClient, + // Lazily initializing blacklistTrackerOpt to avoid getting empty ExecutorAllocationClient, // because ExecutorAllocationClient is created after this TaskSchedulerImpl. private[scheduler] lazy val blacklistTrackerOpt = maybeCreateBlacklistTracker(sc) @@ -228,7 +228,7 @@ private[spark] class TaskSchedulerImpl( // 1. The task set manager has been created and some tasks have been scheduled. // In this case, send a kill signal to the executors to kill the task and then abort // the stage. - // 2. The task set manager has been created but no tasks has been scheduled. In this case, + // 2. The task set manager has been created but no tasks have been scheduled. In this case, // simply abort the stage. tsm.runningTasksSet.foreach { tid => taskIdToExecutorId.get(tid).foreach(execId => @@ -694,7 +694,7 @@ private[spark] class TaskSchedulerImpl( * * After stage failure and retry, there may be multiple TaskSetManagers for the stage. * If an earlier attempt of a stage completes a task, we should ensure that the later attempts - * do not also submit those same tasks. That also means that a task completion from an earlier + * do not also submit those same tasks. That also means that a task completion from an earlier * attempt can lead to the entire stage getting marked as successful. */ private[scheduler] def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int) = { From 416cd1fd96c0db9194e32ba877b1396b6dc13c8e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 3 Jun 2018 21:57:42 -0700 Subject: [PATCH 0915/2461] [SPARK-24369][SQL] Correct handling for multiple distinct aggregations having the same argument set ## What changes were proposed in this pull request? bring back https://github.com/apache/spark/pull/21443 This is a different approach: just change the check to count distinct columns with `toSet` ## How was this patch tested? a new test to verify the planner behavior. Author: Wenchen Fan Author: Takeshi Yamamuro Closes #21487 from cloud-fan/back. --- .../spark/sql/execution/SparkStrategies.scala | 4 ++-- .../resources/sql-tests/inputs/group-by.sql | 6 +++++- .../sql-tests/results/group-by.sql.out | 11 +++++++++- .../spark/sql/execution/PlannerSuite.scala | 21 +++++++++++++++++++ 4 files changed, 38 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index b97a87a122406..be34387f6a874 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -384,9 +384,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val (functionsWithDistinct, functionsWithoutDistinct) = aggregateExpressions.partition(_.isDistinct) - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + if (functionsWithDistinct.map(_.aggregateFunction.children.toSet).distinct.length > 1) { // This is a sanity check. We should not reach here when we have multiple distinct - // column sets. Our MultipleDistinctRewriter should take care this case. + // column sets. Our `RewriteDistinctAggregates` should take care this case. sys.error("You hit a query analyzer bug. Please report your query to " + "Spark user mailing list.") } diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index c5070b734d521..2c18d6aaabdba 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -68,4 +68,8 @@ SELECT 1 from ( FROM (select 1 as x) a WHERE false ) b -where b.z != b.z +where b.z != b.z; + +-- SPARK-24369 multiple distinct aggregations having the same argument set +SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) + FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y); diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index c1abc6dff754b..581aa1754ce14 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 27 -- !query 0 @@ -241,3 +241,12 @@ where b.z != b.z struct<1:int> -- !query 25 output + + +-- !query 26 +SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) + FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y) +-- !query 26 schema +struct +-- !query 26 output +1.0 1.0 3 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index b2aba8e72c5db..98a50fbd52b4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -69,6 +69,27 @@ class PlannerSuite extends SharedSQLContext { testPartialAggregationPlan(query) } + test("mixed aggregates with same distinct columns") { + def assertNoExpand(plan: SparkPlan): Unit = { + assert(plan.collect { case e: ExpandExec => e }.isEmpty) + } + + withTempView("v") { + Seq((1, 1.0, 1.0), (1, 2.0, 2.0)).toDF("i", "j", "k").createTempView("v") + // one distinct column + val query1 = sql("SELECT sum(DISTINCT j), max(DISTINCT j) FROM v GROUP BY i") + assertNoExpand(query1.queryExecution.executedPlan) + + // 2 distinct columns + val query2 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT j, k) FROM v GROUP BY i") + assertNoExpand(query2.queryExecution.executedPlan) + + // 2 distinct columns with different order + val query3 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT k, j) FROM v GROUP BY i") + assertNoExpand(query3.queryExecution.executedPlan) + } + } + test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { def checkPlan(fieldTypes: Seq[DataType]): Unit = { withTempView("testLimit") { From 1d9338bb10b953daddb23b8879ff99aa5c57dbea Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 3 Jun 2018 22:02:21 -0700 Subject: [PATCH 0916/2461] [SPARK-23786][SQL] Checking column names of csv headers ## What changes were proposed in this pull request? Currently column names of headers in CSV files are not checked against provided schema of CSV data. It could cause errors like showed in the [SPARK-23786](https://issues.apache.org/jira/browse/SPARK-23786) and https://github.com/apache/spark/pull/20894#issuecomment-375957777. I introduced new CSV option - `enforceSchema`. If it is enabled (by default `true`), Spark forcibly applies provided or inferred schema to CSV files. In that case, CSV headers are ignored and not checked against the schema. If `enforceSchema` is set to `false`, additional checks can be performed. For example, if column in CSV header and in the schema have different ordering, the following exception is thrown: ``` java.lang.IllegalArgumentException: CSV file header does not contain the expected fields Header: depth, temperature Schema: temperature, depth CSV file: marina.csv ``` ## How was this patch tested? The changes were tested by existing tests of CSVSuite and by 2 new tests. Author: Maxim Gekk Author: Maxim Gekk Closes #20894 from MaxGekk/check-column-names. --- python/pyspark/sql/readwriter.py | 15 +- python/pyspark/sql/streaming.py | 15 +- python/pyspark/sql/tests.py | 18 ++ .../apache/spark/sql/DataFrameReader.scala | 19 ++ .../datasources/csv/CSVDataSource.scala | 126 +++++++++++- .../datasources/csv/CSVFileFormat.scala | 9 +- .../datasources/csv/CSVOptions.scala | 6 + .../execution/datasources/csv/CSVUtils.scala | 22 +- .../datasources/csv/UnivocityParser.scala | 26 +-- .../execution/datasources/csv/CSVSuite.scala | 192 ++++++++++++++++++ 10 files changed, 411 insertions(+), 37 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 448a4732001b5..a0e20d39c20da 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -346,7 +346,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, - samplingRatio=None): + samplingRatio=None, enforceSchema=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -373,6 +373,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non default value, ``false``. :param inferSchema: infers the input schema automatically from data. It requires one extra pass over the data. If None is set, it uses the default value, ``false``. + :param enforceSchema: If it is set to ``true``, the specified or inferred schema will be + forcibly applied to datasource files, and headers in CSV files will be + ignored. If the option is set to ``false``, the schema will be + validated against all headers in CSV files or the first header in RDD + if the ``header`` option is set to ``true``. Field names in the schema + and column names in CSV headers are checked by their positions + taking into account ``spark.sql.caseSensitive``. If None is set, + ``true`` is used by default. Though the default value is ``true``, + it is recommended to disable the ``enforceSchema`` option + to avoid incorrect results. :param ignoreLeadingWhiteSpace: A flag indicating whether or not leading whitespaces from values being read should be skipped. If None is set, it uses the default value, ``false``. @@ -449,7 +459,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, - charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio) + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio, + enforceSchema=enforceSchema) if isinstance(path, basestring): path = [path] if type(path) == list: diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 15f9407389864..fae50b3d5d532 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -564,7 +564,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, - columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None): + columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, + enforceSchema=None): """Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -592,6 +593,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non default value, ``false``. :param inferSchema: infers the input schema automatically from data. It requires one extra pass over the data. If None is set, it uses the default value, ``false``. + :param enforceSchema: If it is set to ``true``, the specified or inferred schema will be + forcibly applied to datasource files, and headers in CSV files will be + ignored. If the option is set to ``false``, the schema will be + validated against all headers in CSV files or the first header in RDD + if the ``header`` option is set to ``true``. Field names in the schema + and column names in CSV headers are checked by their positions + taking into account ``spark.sql.caseSensitive``. If None is set, + ``true`` is used by default. Though the default value is ``true``, + it is recommended to disable the ``enforceSchema`` option + to avoid incorrect results. :param ignoreLeadingWhiteSpace: a flag indicating whether or not leading whitespaces from values being read should be skipped. If None is set, it uses the default value, ``false``. @@ -664,7 +675,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, - charToEscapeQuoteEscaping=charToEscapeQuoteEscaping) + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a2450932e303d..ea2dd7605dc57 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3056,6 +3056,24 @@ def test_csv_sampling_ratio(self): .csv(rdd, samplingRatio=0.5).schema self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)])) + def test_checking_csv_header(self): + path = tempfile.mkdtemp() + shutil.rmtree(path) + try: + self.spark.createDataFrame([[1, 1000], [2000, 2]])\ + .toDF('f1', 'f2').write.option("header", "true").csv(path) + schema = StructType([ + StructField('f2', IntegerType(), nullable=True), + StructField('f1', IntegerType(), nullable=True)]) + df = self.spark.read.option('header', 'true').schema(schema)\ + .csv(path, enforceSchema=False) + self.assertRaisesRegexp( + Exception, + "CSV header does not conform to the schema", + lambda: df.collect()) + finally: + shutil.rmtree(path) + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index ac4580a0919ad..de6be5f76e15a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -22,6 +22,7 @@ import java.util.{Locale, Properties} import scala.collection.JavaConverters._ import com.fasterxml.jackson.databind.ObjectMapper +import com.univocity.parsers.csv.CsvParser import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability @@ -474,6 +475,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * it determines the columns as string types and it reads only the first line to determine the * names and the number of fields. * + * If the enforceSchema is set to `false`, only the CSV header in the first line is checked + * to conform specified or inferred schema. + * * @param csvDataset input Dataset with one CSV row per record * @since 2.2.0 */ @@ -499,6 +503,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine => + CSVDataSource.checkHeader( + firstLine, + new CsvParser(parsedOptions.asParserSettings), + actualSchema, + csvDataset.getClass.getCanonicalName, + parsedOptions.enforceSchema, + sparkSession.sessionState.conf.caseSensitiveAnalysis) filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions)) }.getOrElse(filteredLines.rdd) @@ -539,6 +550,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `comment` (default empty string): sets a single character used for skipping lines * beginning with this character. By default, it is disabled.
  • *
  • `header` (default `false`): uses the first line as names of columns.
  • + *
  • `enforceSchema` (default `true`): If it is set to `true`, the specified or inferred schema + * will be forcibly applied to datasource files, and headers in CSV files will be ignored. + * If the option is set to `false`, the schema will be validated against all headers in CSV files + * in the case when the `header` option is set to `true`. Field names in the schema + * and column names in CSV headers are checked by their positions taking into account + * `spark.sql.caseSensitive`. Though the default value is true, it is recommended to disable + * the `enforceSchema` option to avoid incorrect results.
  • *
  • `inferSchema` (default `false`): infers the input schema automatically from data. It * requires one extra pass over the data.
  • *
  • `samplingRatio` (default is 1.0): defines fraction of rows used for schema inferring.
  • @@ -583,6 +601,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`. *
  • `multiLine` (default `false`): parse one record, which may span multiple lines.
  • * + * * @since 2.0.0 */ @scala.annotation.varargs diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index dc54d182651b1..82322df407521 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.spark.TaskContext import org.apache.spark.input.{PortableDataStream, StreamInputFormat} +import org.apache.spark.internal.Logging import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow @@ -50,7 +51,10 @@ abstract class CSVDataSource extends Serializable { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] + requiredSchema: StructType, + // Actual schema of data in the csv file + dataSchema: StructType, + caseSensitive: Boolean): Iterator[InternalRow] /** * Infers the schema from `inputPaths` files. @@ -110,7 +114,7 @@ abstract class CSVDataSource extends Serializable { } } -object CSVDataSource { +object CSVDataSource extends Logging { def apply(options: CSVOptions): CSVDataSource = { if (options.multiLine) { MultiLineCSVDataSource @@ -118,6 +122,84 @@ object CSVDataSource { TextInputCSVDataSource } } + + /** + * Checks that column names in a CSV header and field names in the schema are the same + * by taking into account case sensitivity. + * + * @param schema - provided (or inferred) schema to which CSV must conform. + * @param columnNames - names of CSV columns that must be checked against to the schema. + * @param fileName - name of CSV file that are currently checked. It is used in error messages. + * @param enforceSchema - if it is `true`, column names are ignored otherwise the CSV column + * names are checked for conformance to the schema. In the case if + * the column name don't conform to the schema, an exception is thrown. + * @param caseSensitive - if it is set to `false`, comparison of column names and schema field + * names is not case sensitive. + */ + def checkHeaderColumnNames( + schema: StructType, + columnNames: Array[String], + fileName: String, + enforceSchema: Boolean, + caseSensitive: Boolean): Unit = { + if (columnNames != null) { + val fieldNames = schema.map(_.name).toIndexedSeq + val (headerLen, schemaSize) = (columnNames.size, fieldNames.length) + var errorMessage: Option[String] = None + + if (headerLen == schemaSize) { + var i = 0 + while (errorMessage.isEmpty && i < headerLen) { + var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i)) + if (!caseSensitive) { + nameInSchema = nameInSchema.toLowerCase + nameInHeader = nameInHeader.toLowerCase + } + if (nameInHeader != nameInSchema) { + errorMessage = Some( + s"""|CSV header does not conform to the schema. + | Header: ${columnNames.mkString(", ")} + | Schema: ${fieldNames.mkString(", ")} + |Expected: ${fieldNames(i)} but found: ${columnNames(i)} + |CSV file: $fileName""".stripMargin) + } + i += 1 + } + } else { + errorMessage = Some( + s"""|Number of column in CSV header is not equal to number of fields in the schema: + | Header length: $headerLen, schema size: $schemaSize + |CSV file: $fileName""".stripMargin) + } + + errorMessage.foreach { msg => + if (enforceSchema) { + logWarning(msg) + } else { + throw new IllegalArgumentException(msg) + } + } + } + } + + /** + * Checks that CSV header contains the same column names as fields names in the given schema + * by taking into account case sensitivity. + */ + def checkHeader( + header: String, + parser: CsvParser, + schema: StructType, + fileName: String, + enforceSchema: Boolean, + caseSensitive: Boolean): Unit = { + checkHeaderColumnNames( + schema, + parser.parseLine(header), + fileName, + enforceSchema, + caseSensitive) + } } object TextInputCSVDataSource extends CSVDataSource { @@ -127,7 +209,9 @@ object TextInputCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] = { + requiredSchema: StructType, + dataSchema: StructType, + caseSensitive: Boolean): Iterator[InternalRow] = { val lines = { val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) @@ -136,8 +220,24 @@ object TextInputCSVDataSource extends CSVDataSource { } } - val shouldDropHeader = parser.options.headerFlag && file.start == 0 - UnivocityParser.parseIterator(lines, shouldDropHeader, parser, schema) + val hasHeader = parser.options.headerFlag && file.start == 0 + if (hasHeader) { + // Checking that column names in the header are matched to field names of the schema. + // The header will be removed from lines. + // Note: if there are only comments in the first block, the header would probably + // be not extracted. + CSVUtils.extractHeader(lines, parser.options).foreach { header => + CSVDataSource.checkHeader( + header, + parser.tokenizer, + dataSchema, + file.filePath, + parser.options.enforceSchema, + caseSensitive) + } + } + + UnivocityParser.parseIterator(lines, parser, requiredSchema) } override def infer( @@ -206,12 +306,24 @@ object MultiLineCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] = { + requiredSchema: StructType, + dataSchema: StructType, + caseSensitive: Boolean): Iterator[InternalRow] = { + def checkHeader(header: Array[String]): Unit = { + CSVDataSource.checkHeaderColumnNames( + dataSchema, + header, + file.filePath, + parser.options.enforceSchema, + caseSensitive) + } + UnivocityParser.parseStream( CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))), parser.options.headerFlag, parser, - schema) + requiredSchema, + checkHeader) } override def infer( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 21279d6daf7ad..b90275de9f40a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -130,6 +130,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { "df.filter($\"_corrupt_record\".isNotNull).count()." ) } + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value @@ -137,7 +138,13 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), parsedOptions) - CSVDataSource(parsedOptions).readFile(conf, file, parser, requiredSchema) + CSVDataSource(parsedOptions).readFile( + conf, + file, + parser, + requiredSchema, + dataSchema, + caseSensitive) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 7119189a4e131..fab8d62da0c1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -156,6 +156,12 @@ class CSVOptions( val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + /** + * Forcibly apply the specified or inferred schema to datasource files. + * If the option is enabled, headers of CSV files will be ignored. + */ + val enforceSchema = getBool("enforceSchema", default = true) + def asWriterSettings: CsvWriterSettings = { val writerSettings = new CsvWriterSettings() val format = writerSettings.getFormat diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 9dae41b63e810..1012e774118e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -68,12 +68,8 @@ object CSVUtils { } } - /** - * Drop header line so that only data can remain. - * This is similar with `filterHeaderLine` above and currently being used in CSV reading path. - */ - def dropHeaderLine(iter: Iterator[String], options: CSVOptions): Iterator[String] = { - val nonEmptyLines = if (options.isCommentSet) { + def skipComments(iter: Iterator[String], options: CSVOptions): Iterator[String] = { + if (options.isCommentSet) { val commentPrefix = options.comment.toString iter.dropWhile { line => line.trim.isEmpty || line.trim.startsWith(commentPrefix) @@ -81,11 +77,19 @@ object CSVUtils { } else { iter.dropWhile(_.trim.isEmpty) } - - if (nonEmptyLines.hasNext) nonEmptyLines.drop(1) - iter } + /** + * Extracts header and moves iterator forward so that only data remains in it + */ + def extractHeader(iter: Iterator[String], options: CSVOptions): Option[String] = { + val nonEmptyLines = skipComments(iter, options) + if (nonEmptyLines.hasNext) { + Some(nonEmptyLines.next()) + } else { + None + } + } /** * Helper method that converts string representation of a character to actual character. * It handles some Java escaped strings and throws exception if given string is longer than one diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 4f00cc5eb3f39..5f7d5696b71a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -45,7 +45,7 @@ class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any - private val tokenizer = { + val tokenizer = { val parserSetting = options.asParserSettings if (options.columnPruning && requiredSchema.length < dataSchema.length) { val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))) @@ -250,14 +250,15 @@ private[csv] object UnivocityParser { inputStream: InputStream, shouldDropHeader: Boolean, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] = { + schema: StructType, + checkHeader: Array[String] => Unit): Iterator[InternalRow] = { val tokenizer = parser.tokenizer val safeParser = new FailureSafeParser[Array[String]]( input => Seq(parser.convert(input)), parser.options.parseMode, schema, parser.options.columnNameOfCorruptRecord) - convertStream(inputStream, shouldDropHeader, tokenizer) { tokens => + convertStream(inputStream, shouldDropHeader, tokenizer, checkHeader) { tokens => safeParser.parse(tokens) }.flatten } @@ -265,11 +266,14 @@ private[csv] object UnivocityParser { private def convertStream[T]( inputStream: InputStream, shouldDropHeader: Boolean, - tokenizer: CsvParser)(convert: Array[String] => T) = new Iterator[T] { + tokenizer: CsvParser, + checkHeader: Array[String] => Unit = _ => ())( + convert: Array[String] => T) = new Iterator[T] { tokenizer.beginParsing(inputStream) private var nextRecord = { if (shouldDropHeader) { - tokenizer.parseNext() + val firstRecord = tokenizer.parseNext() + checkHeader(firstRecord) } tokenizer.parseNext() } @@ -291,21 +295,11 @@ private[csv] object UnivocityParser { */ def parseIterator( lines: Iterator[String], - shouldDropHeader: Boolean, parser: UnivocityParser, schema: StructType): Iterator[InternalRow] = { val options = parser.options - val linesWithoutHeader = if (shouldDropHeader) { - // Note that if there are only comments in the first block, the header would probably - // be not dropped. - CSVUtils.dropHeaderLine(lines, options) - } else { - lines - } - - val filteredLines: Iterator[String] = - CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options) + val filteredLines: Iterator[String] = CSVUtils.filterCommentAndEmpty(lines, options) val safeParser = new FailureSafeParser[String]( input => Seq(parser.parse(input)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index afe10bdc4de26..d2f166c7d1877 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -23,9 +23,13 @@ import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat import java.util.Locale +import scala.collection.JavaConverters._ + import org.apache.commons.lang3.time.FastDateFormat import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec +import org.apache.log4j.{AppenderSkeleton, LogManager} +import org.apache.log4j.spi.LoggingEvent import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} @@ -1410,4 +1414,192 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te checkAnswer(idf, List(Row(15, 10, 5), Row(-15, -10, -5))) } } + + def checkHeader(multiLine: Boolean): Unit = { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTempPath { path => + val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType) + val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema) + odf.write.option("header", true).csv(path.getCanonicalPath) + val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType) + val exception = intercept[SparkException] { + spark.read + .schema(ischema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(exception.getMessage.contains("CSV header does not conform to the schema")) + + val shortSchema = new StructType().add("f1", DoubleType) + val exceptionForShortSchema = intercept[SparkException] { + spark.read + .schema(shortSchema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(exceptionForShortSchema.getMessage.contains( + "Number of column in CSV header is not equal to number of fields in the schema")) + + val longSchema = new StructType() + .add("f1", DoubleType) + .add("f2", DoubleType) + .add("f3", DoubleType) + + val exceptionForLongSchema = intercept[SparkException] { + spark.read + .schema(longSchema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(exceptionForLongSchema.getMessage.contains("Header length: 2, schema size: 3")) + + val caseSensitiveSchema = new StructType().add("F1", DoubleType).add("f2", DoubleType) + val caseSensitiveException = intercept[SparkException] { + spark.read + .schema(caseSensitiveSchema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(caseSensitiveException.getMessage.contains( + "CSV header does not conform to the schema")) + } + } + } + + test(s"SPARK-23786: Checking column names against schema in the multiline mode") { + checkHeader(multiLine = true) + } + + test(s"SPARK-23786: Checking column names against schema in the per-line mode") { + checkHeader(multiLine = false) + } + + test("SPARK-23786: CSV header must not be checked if it doesn't exist") { + withTempPath { path => + val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType) + val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema) + odf.write.option("header", false).csv(path.getCanonicalPath) + val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType) + val idf = spark.read + .schema(ischema) + .option("header", false) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + + checkAnswer(idf, odf) + } + } + + test("SPARK-23786: Ignore column name case if spark.sql.caseSensitive is false") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + withTempPath { path => + val oschema = new StructType().add("A", StringType) + val odf = spark.createDataFrame(List(Row("0")).asJava, oschema) + odf.write.option("header", true).csv(path.getCanonicalPath) + val ischema = new StructType().add("a", StringType) + val idf = spark.read.schema(ischema) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + checkAnswer(idf, odf) + } + } + } + + test("SPARK-23786: check header on parsing of dataset of strings") { + val ds = Seq("columnA,columnB", "1.0,1000.0").toDS() + val ischema = new StructType().add("columnB", DoubleType).add("columnA", DoubleType) + val exception = intercept[IllegalArgumentException] { + spark.read.schema(ischema).option("header", true).option("enforceSchema", false).csv(ds) + } + + assert(exception.getMessage.contains("CSV header does not conform to the schema")) + } + + test("SPARK-23786: enforce inferred schema") { + val expectedSchema = new StructType().add("_c0", DoubleType).add("_c1", StringType) + val withHeader = spark.read + .option("inferSchema", true) + .option("enforceSchema", false) + .option("header", true) + .csv(Seq("_c0,_c1", "1.0,a").toDS()) + assert(withHeader.schema == expectedSchema) + checkAnswer(withHeader, Seq(Row(1.0, "a"))) + + // Ignore the inferSchema flag if an user sets a schema + val schema = new StructType().add("colA", DoubleType).add("colB", StringType) + val ds = spark.read + .option("inferSchema", true) + .option("enforceSchema", false) + .option("header", true) + .schema(schema) + .csv(Seq("colA,colB", "1.0,a").toDS()) + assert(ds.schema == schema) + checkAnswer(ds, Seq(Row(1.0, "a"))) + + val exception = intercept[IllegalArgumentException] { + spark.read + .option("inferSchema", true) + .option("enforceSchema", false) + .option("header", true) + .schema(schema) + .csv(Seq("col1,col2", "1.0,a").toDS()) + } + assert(exception.getMessage.contains("CSV header does not conform to the schema")) + } + + test("SPARK-23786: warning should be printed if CSV header doesn't conform to schema") { + class TestAppender extends AppenderSkeleton { + var events = new java.util.ArrayList[LoggingEvent] + override def close(): Unit = {} + override def requiresLayout: Boolean = false + protected def append(event: LoggingEvent): Unit = events.add(event) + } + + val testAppender1 = new TestAppender + LogManager.getRootLogger.addAppender(testAppender1) + try { + val ds = Seq("columnA,columnB", "1.0,1000.0").toDS() + val ischema = new StructType().add("columnB", DoubleType).add("columnA", DoubleType) + + spark.read.schema(ischema).option("header", true).option("enforceSchema", true).csv(ds) + } finally { + LogManager.getRootLogger.removeAppender(testAppender1) + } + assert(testAppender1.events.asScala + .exists(msg => msg.getRenderedMessage.contains("CSV header does not conform to the schema"))) + + val testAppender2 = new TestAppender + LogManager.getRootLogger.addAppender(testAppender2) + try { + withTempPath { path => + val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType) + val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema) + odf.write.option("header", true).csv(path.getCanonicalPath) + val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType) + spark.read + .schema(ischema) + .option("header", true) + .option("enforceSchema", true) + .csv(path.getCanonicalPath) + .collect() + } + } finally { + LogManager.getRootLogger.removeAppender(testAppender2) + } + assert(testAppender2.events.asScala + .exists(msg => msg.getRenderedMessage.contains("CSV header does not conform to the schema"))) + } } From 0be5aa27460f87b5627f9de16ec25b09368d205a Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 4 Jun 2018 10:16:13 -0700 Subject: [PATCH 0917/2461] [SPARK-23903][SQL] Add support for date extract ## What changes were proposed in this pull request? Add support for date `extract` function: ```sql spark-sql> SELECT EXTRACT(YEAR FROM TIMESTAMP '2000-12-16 12:21:13'); 2000 ``` Supported field same as [Hive](https://github.com/apache/hive/blob/rel/release-2.3.3/ql/src/java/org/apache/hadoop/hive/ql/parse/IdentifiersParser.g#L308-L316): `YEAR`, `QUARTER`, `MONTH`, `WEEK`, `DAY`, `DAYOFWEEK`, `HOUR`, `MINUTE`, `SECOND`. ## How was this patch tested? unit tests Author: Yuming Wang Closes #21479 from wangyum/SPARK-23903. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 3 + .../sql/catalyst/parser/AstBuilder.scala | 28 ++++++ .../parser/TableIdentifierParserSuite.scala | 2 +- .../resources/sql-tests/inputs/extract.sql | 21 ++++ .../sql-tests/results/extract.sql.out | 96 +++++++++++++++++++ 5 files changed, 149 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/extract.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/extract.sql.out diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 7c54851097af3..3fe00eefde7d8 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -592,6 +592,7 @@ primaryExpression | identifier #columnReference | base=primaryExpression '.' fieldName=identifier #dereference | '(' expression ')' #parenthesizedExpression + | EXTRACT '(' field=identifier FROM source=valueExpression ')' #extract ; constant @@ -739,6 +740,7 @@ nonReserved | VIEW | REPLACE | IF | POSITION + | EXTRACT | NO | DATA | START | TRANSACTION | COMMIT | ROLLBACK | IGNORE | SORT | CLUSTER | DISTRIBUTE | UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION @@ -878,6 +880,7 @@ TRAILING: 'TRAILING'; IF: 'IF'; POSITION: 'POSITION'; +EXTRACT: 'EXTRACT'; EQ : '=' | '=='; NSEQ: '<=>'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index b9ece295c2510..383ebde3229d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1206,6 +1206,34 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging new StringLocate(expression(ctx.substr), expression(ctx.str)) } + /** + * Create a Extract expression. + */ + override def visitExtract(ctx: ExtractContext): Expression = withOrigin(ctx) { + ctx.field.getText.toUpperCase(Locale.ROOT) match { + case "YEAR" => + Year(expression(ctx.source)) + case "QUARTER" => + Quarter(expression(ctx.source)) + case "MONTH" => + Month(expression(ctx.source)) + case "WEEK" => + WeekOfYear(expression(ctx.source)) + case "DAY" => + DayOfMonth(expression(ctx.source)) + case "DAYOFWEEK" => + DayOfWeek(expression(ctx.source)) + case "HOUR" => + Hour(expression(ctx.source)) + case "MINUTE" => + Minute(expression(ctx.source)) + case "SECOND" => + Second(expression(ctx.source)) + case other => + throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx) + } + } + /** * Create a (windowed) Function expression. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index 89903c2825125..ff0de0fb7c1f0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -51,7 +51,7 @@ class TableIdentifierParserSuite extends SparkFunSuite { "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger", "true", "truncate", "update", "user", "values", "with", "regexp", "rlike", "bigint", "binary", "boolean", "current_date", "current_timestamp", "date", "double", "float", - "int", "smallint", "timestamp", "at", "position", "both", "leading", "trailing") + "int", "smallint", "timestamp", "at", "position", "both", "leading", "trailing", "extract") val hiveStrictNonReservedKeyword = Seq("anti", "full", "inner", "left", "semi", "right", "natural", "union", "intersect", "except", "database", "on", "join", "cross", "select", "from", diff --git a/sql/core/src/test/resources/sql-tests/inputs/extract.sql b/sql/core/src/test/resources/sql-tests/inputs/extract.sql new file mode 100644 index 0000000000000..9adf5d70056e2 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/extract.sql @@ -0,0 +1,21 @@ +CREATE TEMPORARY VIEW t AS select '2011-05-06 07:08:09.1234567' as c; + +select extract(year from c) from t; + +select extract(quarter from c) from t; + +select extract(month from c) from t; + +select extract(week from c) from t; + +select extract(day from c) from t; + +select extract(dayofweek from c) from t; + +select extract(hour from c) from t; + +select extract(minute from c) from t; + +select extract(second from c) from t; + +select extract(not_supported from c) from t; diff --git a/sql/core/src/test/resources/sql-tests/results/extract.sql.out b/sql/core/src/test/resources/sql-tests/results/extract.sql.out new file mode 100644 index 0000000000000..160e4c7d78455 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/extract.sql.out @@ -0,0 +1,96 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 11 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS select '2011-05-06 07:08:09.1234567' as c +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select extract(year from c) from t +-- !query 1 schema +struct +-- !query 1 output +2011 + + +-- !query 2 +select extract(quarter from c) from t +-- !query 2 schema +struct +-- !query 2 output +2 + + +-- !query 3 +select extract(month from c) from t +-- !query 3 schema +struct +-- !query 3 output +5 + + +-- !query 4 +select extract(week from c) from t +-- !query 4 schema +struct +-- !query 4 output +18 + + +-- !query 5 +select extract(day from c) from t +-- !query 5 schema +struct +-- !query 5 output +6 + + +-- !query 6 +select extract(dayofweek from c) from t +-- !query 6 schema +struct +-- !query 6 output +6 + + +-- !query 7 +select extract(hour from c) from t +-- !query 7 schema +struct +-- !query 7 output +7 + + +-- !query 8 +select extract(minute from c) from t +-- !query 8 schema +struct +-- !query 8 output +8 + + +-- !query 9 +select extract(second from c) from t +-- !query 9 schema +struct +-- !query 9 output +9 + + +-- !query 10 +select extract(not_supported from c) from t +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.catalyst.parser.ParseException + +Literals of type 'NOT_SUPPORTED' are currently not supported.(line 1, pos 7) + +== SQL == +select extract(not_supported from c) from t +-------^^^ From 7297ae04d87b6e3d48b747a7c1d53687fcc3971c Mon Sep 17 00:00:00 2001 From: aokolnychyi Date: Mon, 4 Jun 2018 13:28:16 -0700 Subject: [PATCH 0918/2461] [SPARK-21896][SQL] Fix StackOverflow caused by window functions inside aggregate functions ## What changes were proposed in this pull request? This PR explicitly prohibits window functions inside aggregates. Currently, this will cause StackOverflow during analysis. See PR #19193 for previous discussion. ## How was this patch tested? This PR comes with a dedicated unit test. Author: aokolnychyi Closes #21473 from aokolnychyi/fix-stackoverflow-window-funcs. --- .../sql/catalyst/analysis/Analyzer.scala | 10 ++++-- .../spark/sql/DataFrameAggregateSuite.scala | 34 +++++++++++++++++-- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3eaa9ecf5d075..f9947d1fa6c78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1744,10 +1744,10 @@ class Analyzer( * it into the plan tree. */ object ExtractWindowExpressions extends Rule[LogicalPlan] { - private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean = - projectList.exists(hasWindowFunction) + private def hasWindowFunction(exprs: Seq[Expression]): Boolean = + exprs.exists(hasWindowFunction) - private def hasWindowFunction(expr: NamedExpression): Boolean = { + private def hasWindowFunction(expr: Expression): Boolean = { expr.find { case window: WindowExpression => true case _ => false @@ -1830,6 +1830,10 @@ class Analyzer( seenWindowAggregates += newAgg WindowExpression(newAgg, spec) + case AggregateExpression(aggFunc, _, _, _) if hasWindowFunction(aggFunc.children) => + failAnalysis("It is not allowed to use a window function inside an aggregate " + + "function. Please use the inner window function in a sub-query.") + // Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...), // we need to extract SUM(x). case agg: AggregateExpression if !seenWindowAggregates.contains(agg) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 96c28961e5aaf..f495a949ebc5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql import scala.util.Random -import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.scalatest.Matchers.the + import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -687,4 +687,34 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-21896: Window functions inside aggregate functions") { + def checkWindowError(df: => DataFrame): Unit = { + val thrownException = the [AnalysisException] thrownBy { + df.queryExecution.analyzed + } + assert(thrownException.message.contains("not allowed to use a window function")) + } + + checkWindowError(testData2.select(min(avg('b).over(Window.partitionBy('a))))) + checkWindowError(testData2.agg(sum('b), max(rank().over(Window.orderBy('a))))) + checkWindowError(testData2.groupBy('a).agg(sum('b), max(rank().over(Window.orderBy('b))))) + checkWindowError(testData2.groupBy('a).agg(max(sum(sum('b)).over(Window.orderBy('a))))) + checkWindowError( + testData2.groupBy('a).agg(sum('b).as("s"), max(count("*").over())).where('s === 3)) + checkAnswer( + testData2.groupBy('a).agg(max('b), sum('b).as("s"), count("*").over()).where('s === 3), + Row(1, 2, 3, 3) :: Row(2, 2, 3, 3) :: Row(3, 2, 3, 3) :: Nil) + + checkWindowError(sql("SELECT MIN(AVG(b) OVER(PARTITION BY a)) FROM testData2")) + checkWindowError(sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY a)) FROM testData2")) + checkWindowError(sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a")) + checkWindowError(sql("SELECT MAX(SUM(SUM(b)) OVER(ORDER BY a)) FROM testData2 GROUP BY a")) + checkWindowError( + sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a HAVING SUM(b) = 3")) + checkAnswer( + sql("SELECT a, MAX(b), RANK() OVER(ORDER BY a) FROM testData2 GROUP BY a HAVING SUM(b) = 3"), + Row(1, 2, 1) :: Row(2, 2, 2) :: Row(3, 2, 3) :: Nil) + } + } From b24d3dba6571fd3c9e2649aceeaadc3f9c6cc90f Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Mon, 4 Jun 2018 14:54:31 -0700 Subject: [PATCH 0919/2461] [SPARK-24290][ML] add support for Array input for instrumentation.logNamedValue ## What changes were proposed in this pull request? Extend instrumentation.logNamedValue to support Array input change the logging for "clusterSizes" to new method ## How was this patch tested? N/A Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21347 from ludatabricks/SPARK-24290. --- .../spark/ml/clustering/BisectingKMeans.scala | 3 +-- .../spark/ml/clustering/GaussianMixture.scala | 3 +-- .../org/apache/spark/ml/clustering/KMeans.scala | 3 +-- .../org/apache/spark/ml/util/Instrumentation.scala | 13 +++++++++++++ 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 1ad4e097246a3..9c9614509c64f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -276,8 +276,7 @@ class BisectingKMeans @Since("2.0.0") ( val summary = new BisectingKMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(Some(summary)) - // TODO: need to extend logNamedValue to support Array - instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]")) + instr.logNamedValue("clusterSizes", summary.clusterSizes) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 3091bb5a2e54c..64ecc1ebda589 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -426,8 +426,7 @@ class GaussianMixture @Since("2.0.0") ( $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood) model.setSummary(Some(summary)) instr.logNamedValue("logLikelihood", logLikelihood) - // TODO: need to extend logNamedValue to support Array - instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]")) + instr.logNamedValue("clusterSizes", summary.clusterSizes) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index e72d7f9485e6a..1704412741d49 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -359,8 +359,7 @@ class KMeans @Since("1.5.0") ( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(Some(summary)) - // TODO: need to extend logNamedValue to support Array - instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]")) + instr.logNamedValue("clusterSizes", summary.clusterSizes) instr.logSuccess(model) if (handlePersistence) { instances.unpersist() diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index 467130b37c16e..3a1c166d46257 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -132,6 +132,19 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( log(compact(render(name -> value))) } + def logNamedValue(name: String, value: Array[String]): Unit = { + log(compact(render(name -> compact(render(value.toSeq))))) + } + + def logNamedValue(name: String, value: Array[Long]): Unit = { + log(compact(render(name -> compact(render(value.toSeq))))) + } + + def logNamedValue(name: String, value: Array[Double]): Unit = { + log(compact(render(name -> compact(render(value.toSeq))))) + } + + /** * Logs the successful completion of the training session. */ From ff0501b0c27dc8149bd5fb38a19d9b0056698766 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Mon, 4 Jun 2018 16:08:27 -0700 Subject: [PATCH 0920/2461] [SPARK-24300][ML] change the way to set seed in ml.cluster.LDASuite.generateLDAData ## What changes were proposed in this pull request? Using different RNG in all different partitions. ## How was this patch tested? manually Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21492 from ludatabricks/SPARK-24300. --- .../test/scala/org/apache/spark/ml/clustering/LDASuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 096b5416899e1..db92132d18b7b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -34,9 +34,8 @@ object LDASuite { vocabSize: Int): DataFrame = { val avgWC = 1 // average instances of each word in a doc val sc = spark.sparkContext - val rng = new java.util.Random() - rng.setSeed(1) val rdd = sc.parallelize(1 to rows).map { i => + val rng = new java.util.Random(i) Vectors.dense(Array.fill(vocabSize)(rng.nextInt(2 * avgWC).toDouble)) }.map(v => new TestRow(v)) spark.createDataFrame(rdd) From dbb4d83829ec4b51d6e6d3a96f7a4e611d8827bc Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Tue, 5 Jun 2018 08:23:08 +0700 Subject: [PATCH 0921/2461] [SPARK-24215][PYSPARK] Implement _repr_html_ for dataframes in PySpark ## What changes were proposed in this pull request? Implement `_repr_html_` for PySpark while in notebook and add config named "spark.sql.repl.eagerEval.enabled" to control this. The dev list thread for context: http://apache-spark-developers-list.1001551.n3.nabble.com/eager-execution-and-debuggability-td23928.html ## How was this patch tested? New ut in DataFrameSuite and manual test in jupyter. Some screenshot below. **After:** ![image](https://user-images.githubusercontent.com/4833765/40268422-8db5bef0-5b9f-11e8-80f1-04bc654a4f2c.png) **Before:** ![image](https://user-images.githubusercontent.com/4833765/40268431-9f92c1b8-5b9f-11e8-9db9-0611f0940b26.png) Author: Yuanjian Li Closes #21370 from xuanyuanking/SPARK-24215. --- docs/configuration.md | 27 ++++++ python/pyspark/sql/dataframe.py | 65 +++++++++++++- python/pyspark/sql/tests.py | 30 +++++++ .../scala/org/apache/spark/sql/Dataset.scala | 84 ++++++++++++------- 4 files changed, 176 insertions(+), 30 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 64af0e98a82f5..5588c372d3e42 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -456,6 +456,33 @@ Apart from these, the following properties are also available, and may be useful from JVM to Python worker for every task. + + spark.sql.repl.eagerEval.enabled + false + + Enable eager evaluation or not. If true and the REPL you are using supports eager evaluation, + Dataset will be ran automatically. The HTML table which generated by _repl_html_ + called by notebooks like Jupyter will feedback the queries user have defined. For plain Python + REPL, the output will be shown like dataframe.show() + (see SPARK-24215 for more details). + + + + spark.sql.repl.eagerEval.maxNumRows + 20 + + Default number of rows in eager evaluation output HTML table generated by _repr_html_ or plain text, + this only take effect when spark.sql.repl.eagerEval.enabled is set to true. + + + + spark.sql.repl.eagerEval.truncate + 20 + + Default number of truncate in eager evaluation output HTML table generated by _repr_html_ or + plain text, this only take effect when spark.sql.repl.eagerEval.enabled set to true. + + spark.files diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 808235ab25440..1e6a1acebb5ca 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -78,6 +78,9 @@ def __init__(self, jdf, sql_ctx): self.is_cached = False self._schema = None # initialized lazily self._lazy_rdd = None + # Check whether _repr_html is supported or not, we use it to avoid calling _jdf twice + # by __repr__ and _repr_html_ while eager evaluation opened. + self._support_repr_html = False @property @since(1.3) @@ -351,8 +354,68 @@ def show(self, n=20, truncate=True, vertical=False): else: print(self._jdf.showString(n, int(truncate), vertical)) + @property + def _eager_eval(self): + """Returns true if the eager evaluation enabled. + """ + return self.sql_ctx.getConf( + "spark.sql.repl.eagerEval.enabled", "false").lower() == "true" + + @property + def _max_num_rows(self): + """Returns the max row number for eager evaluation. + """ + return int(self.sql_ctx.getConf( + "spark.sql.repl.eagerEval.maxNumRows", "20")) + + @property + def _truncate(self): + """Returns the truncate length for eager evaluation. + """ + return int(self.sql_ctx.getConf( + "spark.sql.repl.eagerEval.truncate", "20")) + def __repr__(self): - return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) + if not self._support_repr_html and self._eager_eval: + vertical = False + return self._jdf.showString( + self._max_num_rows, self._truncate, vertical) + else: + return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) + + def _repr_html_(self): + """Returns a dataframe with html code when you enabled eager evaluation + by 'spark.sql.repl.eagerEval.enabled', this only called by REPL you are + using support eager evaluation with HTML. + """ + import cgi + if not self._support_repr_html: + self._support_repr_html = True + if self._eager_eval: + max_num_rows = max(self._max_num_rows, 0) + vertical = False + sock_info = self._jdf.getRowsToPython( + max_num_rows, self._truncate, vertical) + rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) + head = rows[0] + row_data = rows[1:] + has_more_data = len(row_data) > max_num_rows + row_data = row_data[:max_num_rows] + + html = "\n" + # generate table head + html += "\n" % "\n" % "
    %s
    ".join(map(lambda x: cgi.escape(x), head)) + # generate table rows + for row in row_data: + html += "
    %s
    ".join( + map(lambda x: cgi.escape(x), row)) + html += "
    \n" + if has_more_data: + html += "only showing top %d %s\n" % ( + max_num_rows, "row" if max_num_rows == 1 else "rows") + return html + else: + return None @since(2.1) def checkpoint(self, eager=True): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ea2dd7605dc57..487eb19c3b98a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3074,6 +3074,36 @@ def test_checking_csv_header(self): finally: shutil.rmtree(path) + def test_repr_html(self): + import re + pattern = re.compile(r'^ *\|', re.MULTILINE) + df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value")) + self.assertEquals(None, df._repr_html_()) + with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): + expected1 = """ + | + | + | + |
    keyvalue
    11
    2222222222
    + |""" + self.assertEquals(re.sub(pattern, '', expected1), df._repr_html_()) + with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): + expected2 = """ + | + | + | + |
    keyvalue
    11
    222222
    + |""" + self.assertEquals(re.sub(pattern, '', expected2), df._repr_html_()) + with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): + expected3 = """ + | + | + |
    keyvalue
    11
    + |only showing top 1 row + |""" + self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_()) + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index abb5ae53f4d73..f5526104690d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -231,16 +231,17 @@ class Dataset[T] private[sql]( } /** - * Compose the string representing rows for output + * Get rows represented in Sequence by specific truncate and vertical requirement. * - * @param _numRows Number of rows to show + * @param numRows Number of rows to return * @param truncate If set to more than 0, truncates strings to `truncate` characters and * all cells will be aligned right. - * @param vertical If set to true, prints output rows vertically (one line per column value). + * @param vertical If set to true, the rows to return do not need truncate. */ - private[sql] def showString( - _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { - val numRows = _numRows.max(0).min(Int.MaxValue - 1) + private[sql] def getRows( + numRows: Int, + truncate: Int, + vertical: Boolean): Seq[Seq[String]] = { val newDf = toDF() val castCols = newDf.logicalPlan.output.map { col => // Since binary types in top-level schema fields have a specific format to print, @@ -251,14 +252,12 @@ class Dataset[T] private[sql]( Column(col).cast(StringType) } } - val takeResult = newDf.select(castCols: _*).take(numRows + 1) - val hasMoreData = takeResult.length > numRows - val data = takeResult.take(numRows) + val data = newDf.select(castCols: _*).take(numRows + 1) // For array values, replace Seq and Array with square brackets // For cells that are beyond `truncate` characters, replace it with the // first `truncate-3` and "..." - val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row => + schema.fieldNames.toSeq +: data.map { row => row.toSeq.map { cell => val str = cell match { case null => "null" @@ -274,6 +273,26 @@ class Dataset[T] private[sql]( } }: Seq[String] } + } + + /** + * Compose the string representing rows for output + * + * @param _numRows Number of rows to show + * @param truncate If set to more than 0, truncates strings to `truncate` characters and + * all cells will be aligned right. + * @param vertical If set to true, prints output rows vertically (one line per column value). + */ + private[sql] def showString( + _numRows: Int, + truncate: Int = 20, + vertical: Boolean = false): String = { + val numRows = _numRows.max(0).min(Int.MaxValue - 1) + // Get rows represented by Seq[Seq[String]], we may get one more line if it has more data. + val tmpRows = getRows(numRows, truncate, vertical) + + val hasMoreData = tmpRows.length - 1 > numRows + val rows = tmpRows.take(numRows + 1) val sb = new StringBuilder val numCols = schema.fieldNames.length @@ -291,31 +310,25 @@ class Dataset[T] private[sql]( } } + val paddedRows = rows.map { row => + row.zipWithIndex.map { case (cell, i) => + if (truncate > 0) { + StringUtils.leftPad(cell, colWidths(i)) + } else { + StringUtils.rightPad(cell, colWidths(i)) + } + } + } + // Create SeparateLine val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() // column names - rows.head.zipWithIndex.map { case (cell, i) => - if (truncate > 0) { - StringUtils.leftPad(cell, colWidths(i)) - } else { - StringUtils.rightPad(cell, colWidths(i)) - } - }.addString(sb, "|", "|", "|\n") - + paddedRows.head.addString(sb, "|", "|", "|\n") sb.append(sep) // data - rows.tail.foreach { - _.zipWithIndex.map { case (cell, i) => - if (truncate > 0) { - StringUtils.leftPad(cell.toString, colWidths(i)) - } else { - StringUtils.rightPad(cell.toString, colWidths(i)) - } - }.addString(sb, "|", "|", "|\n") - } - + paddedRows.tail.foreach(_.addString(sb, "|", "|", "|\n")) sb.append(sep) } else { // Extended display mode enabled @@ -346,7 +359,7 @@ class Dataset[T] private[sql]( } // Print a footer - if (vertical && data.isEmpty) { + if (vertical && rows.tail.isEmpty) { // In a vertical mode, print an empty row set explicitly sb.append("(0 rows)\n") } else if (hasMoreData) { @@ -3209,6 +3222,19 @@ class Dataset[T] private[sql]( } } + private[sql] def getRowsToPython( + _numRows: Int, + truncate: Int, + vertical: Boolean): Array[Any] = { + EvaluatePython.registerPicklers() + val numRows = _numRows.max(0).min(Int.MaxValue - 1) + val rows = getRows(numRows, truncate, vertical).map(_.toArray).toArray + val toJava: (Any) => Any = EvaluatePython.toJava(_, ArrayType(ArrayType(StringType))) + val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler( + rows.iterator.map(toJava)) + PythonRDD.serveIterator(iter, "serve-GetRows") + } + /** * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. */ From b3417b731d4e323398a0d7ec6e86405f4464f4f9 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 5 Jun 2018 08:29:29 +0700 Subject: [PATCH 0922/2461] [SPARK-16451][REPL] Fail shell if SparkSession fails to start. Currently, in spark-shell, if the session fails to start, the user sees a bunch of unrelated errors which are caused by code in the shell initialization that references the "spark" variable, which does not exist in that case. Things like: ``` :14: error: not found: value spark import spark.sql ``` The user is also left with a non-working shell (unless they want to just write non-Spark Scala or Python code, that is). This change fails the whole shell session at the point where the failure occurs, so that the last error message is the one with the actual information about the failure. For the python error handling, I moved the session initialization code to session.py, so that traceback.print_exc() only shows the last error. Otherwise, the printed exception would contain all previous exceptions with a message "During handling of the above exception, another exception occurred", making the actual error kinda hard to parse. Tested with spark-shell, pyspark (with 2.7 and 3.5), by forcing an error during SparkContext initialization. Author: Marcelo Vanzin Closes #21368 from vanzin/SPARK-16451. --- python/pyspark/shell.py | 26 ++----- python/pyspark/sql/session.py | 34 +++++++++ .../scala/org/apache/spark/repl/Main.scala | 72 ++++++++++--------- 3 files changed, 81 insertions(+), 51 deletions(-) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index b5fcf7092d93a..472c3cd4452f0 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -38,25 +38,13 @@ SparkContext._ensure_initialized() try: - # Try to access HiveConf, it will raise exception if Hive is not added - conf = SparkConf() - if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive': - SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf() - spark = SparkSession.builder\ - .enableHiveSupport()\ - .getOrCreate() - else: - spark = SparkSession.builder.getOrCreate() -except py4j.protocol.Py4JError: - if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': - warnings.warn("Fall back to non-hive support because failing to access HiveConf, " - "please make sure you build spark with hive") - spark = SparkSession.builder.getOrCreate() -except TypeError: - if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': - warnings.warn("Fall back to non-hive support because failing to access HiveConf, " - "please make sure you build spark with hive") - spark = SparkSession.builder.getOrCreate() + spark = SparkSession._create_shell_session() +except Exception: + import sys + import traceback + warnings.warn("Failed to initialize Spark session.") + traceback.print_exc(file=sys.stderr) + sys.exit(1) sc = spark.sparkContext sql = spark.sql diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index d675a240172a7..e880dd1ca6d1a 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -547,6 +547,40 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): df._schema = schema return df + @staticmethod + def _create_shell_session(): + """ + Initialize a SparkSession for a pyspark shell session. This is called from shell.py + to make error handling simpler without needing to declare local variables in that + script, which would expose those to users. + """ + import py4j + from pyspark.conf import SparkConf + from pyspark.context import SparkContext + try: + # Try to access HiveConf, it will raise exception if Hive is not added + conf = SparkConf() + if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive': + SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf() + return SparkSession.builder\ + .enableHiveSupport()\ + .getOrCreate() + else: + return SparkSession.builder.getOrCreate() + except py4j.protocol.Py4JError: + if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': + warnings.warn("Fall back to non-hive support because failing to access HiveConf, " + "please make sure you build spark with hive") + + try: + return SparkSession.builder.getOrCreate() + except TypeError: + if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': + warnings.warn("Fall back to non-hive support because failing to access HiveConf, " + "please make sure you build spark with hive") + + return SparkSession.builder.getOrCreate() + @since(2.0) @ignore_unicode_prefix def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True): diff --git a/repl/src/main/scala/org/apache/spark/repl/Main.scala b/repl/src/main/scala/org/apache/spark/repl/Main.scala index cc76a703bdf8f..e4ddcef9772e4 100644 --- a/repl/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/src/main/scala/org/apache/spark/repl/Main.scala @@ -44,6 +44,7 @@ object Main extends Logging { var interp: SparkILoop = _ private var hasErrors = false + private var isShellSession = false private def scalaOptionError(msg: String): Unit = { hasErrors = true @@ -53,6 +54,7 @@ object Main extends Logging { } def main(args: Array[String]) { + isShellSession = true doMain(args, new SparkILoop) } @@ -79,44 +81,50 @@ object Main extends Logging { } def createSparkSession(): SparkSession = { - val execUri = System.getenv("SPARK_EXECUTOR_URI") - conf.setIfMissing("spark.app.name", "Spark shell") - // SparkContext will detect this configuration and register it with the RpcEnv's - // file server, setting spark.repl.class.uri to the actual URI for executors to - // use. This is sort of ugly but since executors are started as part of SparkContext - // initialization in certain cases, there's an initialization order issue that prevents - // this from being set after SparkContext is instantiated. - conf.set("spark.repl.class.outputDir", outputDir.getAbsolutePath()) - if (execUri != null) { - conf.set("spark.executor.uri", execUri) - } - if (System.getenv("SPARK_HOME") != null) { - conf.setSparkHome(System.getenv("SPARK_HOME")) - } + try { + val execUri = System.getenv("SPARK_EXECUTOR_URI") + conf.setIfMissing("spark.app.name", "Spark shell") + // SparkContext will detect this configuration and register it with the RpcEnv's + // file server, setting spark.repl.class.uri to the actual URI for executors to + // use. This is sort of ugly but since executors are started as part of SparkContext + // initialization in certain cases, there's an initialization order issue that prevents + // this from being set after SparkContext is instantiated. + conf.set("spark.repl.class.outputDir", outputDir.getAbsolutePath()) + if (execUri != null) { + conf.set("spark.executor.uri", execUri) + } + if (System.getenv("SPARK_HOME") != null) { + conf.setSparkHome(System.getenv("SPARK_HOME")) + } - val builder = SparkSession.builder.config(conf) - if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == "hive") { - if (SparkSession.hiveClassesArePresent) { - // In the case that the property is not set at all, builder's config - // does not have this value set to 'hive' yet. The original default - // behavior is that when there are hive classes, we use hive catalog. - sparkSession = builder.enableHiveSupport().getOrCreate() - logInfo("Created Spark session with Hive support") + val builder = SparkSession.builder.config(conf) + if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == "hive") { + if (SparkSession.hiveClassesArePresent) { + // In the case that the property is not set at all, builder's config + // does not have this value set to 'hive' yet. The original default + // behavior is that when there are hive classes, we use hive catalog. + sparkSession = builder.enableHiveSupport().getOrCreate() + logInfo("Created Spark session with Hive support") + } else { + // Need to change it back to 'in-memory' if no hive classes are found + // in the case that the property is set to hive in spark-defaults.conf + builder.config(CATALOG_IMPLEMENTATION.key, "in-memory") + sparkSession = builder.getOrCreate() + logInfo("Created Spark session") + } } else { - // Need to change it back to 'in-memory' if no hive classes are found - // in the case that the property is set to hive in spark-defaults.conf - builder.config(CATALOG_IMPLEMENTATION.key, "in-memory") + // In the case that the property is set but not to 'hive', the internal + // default is 'in-memory'. So the sparkSession will use in-memory catalog. sparkSession = builder.getOrCreate() logInfo("Created Spark session") } - } else { - // In the case that the property is set but not to 'hive', the internal - // default is 'in-memory'. So the sparkSession will use in-memory catalog. - sparkSession = builder.getOrCreate() - logInfo("Created Spark session") + sparkContext = sparkSession.sparkContext + sparkSession + } catch { + case e: Exception if isShellSession => + logError("Failed to initialize Spark session.", e) + sys.exit(1) } - sparkContext = sparkSession.sparkContext - sparkSession } } From e8c1a0c2fdb09a628d9cc925676af870d5a7a946 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 4 Jun 2018 21:24:35 -0700 Subject: [PATCH 0923/2461] [SPARK-15784] Add Power Iteration Clustering to spark.ml ## What changes were proposed in this pull request? According to the discussion on JIRA. I rewrite the Power Iteration Clustering API in `spark.ml`. ## How was this patch tested? Unit test. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: WeichenXu Closes #21493 from WeichenXu123/pic_api. --- .../clustering/PowerIterationClustering.scala | 157 +++++---------- .../PowerIterationClusteringSuite.scala | 179 ++++++++---------- 2 files changed, 125 insertions(+), 211 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala index 2c30a1d9aa947..1b9a3499947d9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala @@ -18,21 +18,20 @@ package org.apache.spark.ml.clustering import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.Transformer import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{PowerIterationClustering => MLlibPowerIterationClustering} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.types._ /** * Common params for PowerIterationClustering */ private[clustering] trait PowerIterationClusteringParams extends Params with HasMaxIter - with HasPredictionCol { + with HasWeightCol { /** * The number of clusters to create (k). Must be > 1. Default: 2. @@ -66,62 +65,33 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has def getInitMode: String = $(initMode) /** - * Param for the name of the input column for vertex IDs. - * Default: "id" + * Param for the name of the input column for source vertex IDs. + * Default: "src" * @group param */ @Since("2.4.0") - val idCol = new Param[String](this, "idCol", "Name of the input column for vertex IDs.", + val srcCol = new Param[String](this, "srcCol", "Name of the input column for source vertex IDs.", (value: String) => value.nonEmpty) - setDefault(idCol, "id") - - /** @group getParam */ - @Since("2.4.0") - def getIdCol: String = getOrDefault(idCol) - - /** - * Param for the name of the input column for neighbors in the adjacency list representation. - * Default: "neighbors" - * @group param - */ - @Since("2.4.0") - val neighborsCol = new Param[String](this, "neighborsCol", - "Name of the input column for neighbors in the adjacency list representation.", - (value: String) => value.nonEmpty) - - setDefault(neighborsCol, "neighbors") - /** @group getParam */ @Since("2.4.0") - def getNeighborsCol: String = $(neighborsCol) + def getSrcCol: String = getOrDefault(srcCol) /** - * Param for the name of the input column for neighbors in the adjacency list representation. - * Default: "similarities" + * Name of the input column for destination vertex IDs. + * Default: "dst" * @group param */ @Since("2.4.0") - val similaritiesCol = new Param[String](this, "similaritiesCol", - "Name of the input column for neighbors in the adjacency list representation.", + val dstCol = new Param[String](this, "dstCol", + "Name of the input column for destination vertex IDs.", (value: String) => value.nonEmpty) - setDefault(similaritiesCol, "similarities") - /** @group getParam */ @Since("2.4.0") - def getSimilaritiesCol: String = $(similaritiesCol) + def getDstCol: String = $(dstCol) - protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnTypes(schema, $(idCol), Seq(IntegerType, LongType)) - SchemaUtils.checkColumnTypes(schema, $(neighborsCol), - Seq(ArrayType(IntegerType, containsNull = false), - ArrayType(LongType, containsNull = false))) - SchemaUtils.checkColumnTypes(schema, $(similaritiesCol), - Seq(ArrayType(FloatType, containsNull = false), - ArrayType(DoubleType, containsNull = false))) - SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) - } + setDefault(srcCol -> "src", dstCol -> "dst") } /** @@ -131,21 +101,8 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has * PIC finds a very low-dimensional embedding of a dataset using truncated power * iteration on a normalized pair-wise similarity matrix of the data. * - * PIC takes an affinity matrix between items (or vertices) as input. An affinity matrix - * is a symmetric matrix whose entries are non-negative similarities between items. - * PIC takes this matrix (or graph) as an adjacency matrix. Specifically, each input row includes: - * - `idCol`: vertex ID - * - `neighborsCol`: neighbors of vertex in `idCol` - * - `similaritiesCol`: non-negative weights (similarities) of edges between the vertex - * in `idCol` and each neighbor in `neighborsCol` - * PIC returns a cluster assignment for each input vertex. It appends a new column `predictionCol` - * containing the cluster assignment in `[0,k)` for each row (vertex). - * - * Notes: - * - [[PowerIterationClustering]] is a transformer with an expensive [[transform]] operation. - * Transform runs the iterative PIC algorithm to cluster the whole input dataset. - * - Input validation: This validates that similarities are non-negative but does NOT validate - * that the input matrix is symmetric. + * This class is not yet an Estimator/Transformer, use `assignClusters` method to run the + * PowerIterationClustering algorithm. * * @see * Spectral clustering (Wikipedia) @@ -154,7 +111,7 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has @Experimental class PowerIterationClustering private[clustering] ( @Since("2.4.0") override val uid: String) - extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable { + extends PowerIterationClusteringParams with DefaultParamsWritable { setDefault( k -> 2, @@ -164,10 +121,6 @@ class PowerIterationClustering private[clustering] ( @Since("2.4.0") def this() = this(Identifiable.randomUID("PowerIterationClustering")) - /** @group setParam */ - @Since("2.4.0") - def setPredictionCol(value: String): this.type = set(predictionCol, value) - /** @group setParam */ @Since("2.4.0") def setK(value: Int): this.type = set(k, value) @@ -182,66 +135,56 @@ class PowerIterationClustering private[clustering] ( /** @group setParam */ @Since("2.4.0") - def setIdCol(value: String): this.type = set(idCol, value) + def setSrcCol(value: String): this.type = set(srcCol, value) /** @group setParam */ @Since("2.4.0") - def setNeighborsCol(value: String): this.type = set(neighborsCol, value) + def setDstCol(value: String): this.type = set(dstCol, value) /** @group setParam */ @Since("2.4.0") - def setSimilaritiesCol(value: String): this.type = set(similaritiesCol, value) + def setWeightCol(value: String): this.type = set(weightCol, value) + /** + * Run the PIC algorithm and returns a cluster assignment for each input vertex. + * + * @param dataset A dataset with columns src, dst, weight representing the affinity matrix, + * which is the matrix A in the PIC paper. Suppose the src column value is i, + * the dst column value is j, the weight column value is similarity s,,ij,, + * which must be nonnegative. This is a symmetric matrix and hence + * s,,ij,, = s,,ji,,. For any (i, j) with nonzero similarity, there should be + * either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. Rows with i = j are + * ignored, because we assume s,,ij,, = 0.0. + * + * @return A dataset that contains columns of vertex id and the corresponding cluster for the id. + * The schema of it will be: + * - id: Long + * - cluster: Int + */ @Since("2.4.0") - override def transform(dataset: Dataset[_]): DataFrame = { - transformSchema(dataset.schema, logging = true) + def assignClusters(dataset: Dataset[_]): DataFrame = { + val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) { + lit(1.0) + } else { + col($(weightCol)).cast(DoubleType) + } - val sparkSession = dataset.sparkSession - val idColValue = $(idCol) - val rdd: RDD[(Long, Long, Double)] = - dataset.select( - col($(idCol)).cast(LongType), - col($(neighborsCol)).cast(ArrayType(LongType, containsNull = false)), - col($(similaritiesCol)).cast(ArrayType(DoubleType, containsNull = false)) - ).rdd.flatMap { - case Row(id: Long, nbrs: Seq[_], sims: Seq[_]) => - require(nbrs.size == sims.size, s"The length of the neighbor ID list must be " + - s"equal to the the length of the neighbor similarity list. Row for ID " + - s"$idColValue=$id has neighbor ID list of length ${nbrs.length} but similarity list " + - s"of length ${sims.length}.") - nbrs.asInstanceOf[Seq[Long]].zip(sims.asInstanceOf[Seq[Double]]).map { - case (nbr, similarity) => (id, nbr, similarity) - } - } + SchemaUtils.checkColumnTypes(dataset.schema, $(srcCol), Seq(IntegerType, LongType)) + SchemaUtils.checkColumnTypes(dataset.schema, $(dstCol), Seq(IntegerType, LongType)) + val rdd: RDD[(Long, Long, Double)] = dataset.select( + col($(srcCol)).cast(LongType), + col($(dstCol)).cast(LongType), + w).rdd.map { + case Row(src: Long, dst: Long, weight: Double) => (src, dst, weight) + } val algorithm = new MLlibPowerIterationClustering() .setK($(k)) .setInitializationMode($(initMode)) .setMaxIterations($(maxIter)) val model = algorithm.run(rdd) - val predictionsRDD: RDD[Row] = model.assignments.map { assignment => - Row(assignment.id, assignment.cluster) - } - - val predictionsSchema = StructType(Seq( - StructField($(idCol), LongType, nullable = false), - StructField($(predictionCol), IntegerType, nullable = false))) - val predictions = { - val uncastPredictions = sparkSession.createDataFrame(predictionsRDD, predictionsSchema) - dataset.schema($(idCol)).dataType match { - case _: LongType => - uncastPredictions - case otherType => - uncastPredictions.select(col($(idCol)).cast(otherType).alias($(idCol))) - } - } - - dataset.join(predictions, $(idCol)) - } - - @Since("2.4.0") - override def transformSchema(schema: StructType): StructType = { - validateAndTransformSchema(schema) + import dataset.sparkSession.implicits._ + model.assignments.toDF } @Since("2.4.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala index 65328df17baff..b7072728d48f0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala @@ -17,19 +17,19 @@ package org.apache.spark.ml.clustering -import scala.collection.mutable - import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.types._ class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + @transient var data: Dataset[_] = _ final val r1 = 1.0 final val n1 = 10 @@ -48,10 +48,9 @@ class PowerIterationClusteringSuite extends SparkFunSuite assert(pic.getK === 2) assert(pic.getMaxIter === 20) assert(pic.getInitMode === "random") - assert(pic.getPredictionCol === "prediction") - assert(pic.getIdCol === "id") - assert(pic.getNeighborsCol === "neighbors") - assert(pic.getSimilaritiesCol === "similarities") + assert(pic.getSrcCol === "src") + assert(pic.getDstCol === "dst") + assert(!pic.isDefined(pic.weightCol)) } test("parameter validation") { @@ -62,125 +61,102 @@ class PowerIterationClusteringSuite extends SparkFunSuite new PowerIterationClustering().setInitMode("no_such_a_mode") } intercept[IllegalArgumentException] { - new PowerIterationClustering().setIdCol("") + new PowerIterationClustering().setSrcCol("") } intercept[IllegalArgumentException] { - new PowerIterationClustering().setNeighborsCol("") - } - intercept[IllegalArgumentException] { - new PowerIterationClustering().setSimilaritiesCol("") + new PowerIterationClustering().setDstCol("") } } test("power iteration clustering") { val n = n1 + n2 - val model = new PowerIterationClustering() + val assignments = new PowerIterationClustering() .setK(2) .setMaxIter(40) - val result = model.transform(data) - - val predictions = Array.fill(2)(mutable.Set.empty[Long]) - result.select("id", "prediction").collect().foreach { - case Row(id: Long, cluster: Integer) => predictions(cluster) += id - } - assert(predictions.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) - - val result2 = new PowerIterationClustering() + .setWeightCol("weight") + .assignClusters(data) + val localAssignments = assignments + .select('id, 'cluster) + .as[(Long, Int)].collect().toSet + val expectedResult = (0 until n1).map(x => (x, 1)).toSet ++ + (n1 until n).map(x => (x, 0)).toSet + assert(localAssignments === expectedResult) + + val assignments2 = new PowerIterationClustering() .setK(2) .setMaxIter(10) .setInitMode("degree") - .transform(data) - val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) - result2.select("id", "prediction").collect().foreach { - case Row(id: Long, cluster: Integer) => predictions2(cluster) += id - } - assert(predictions2.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) + .setWeightCol("weight") + .assignClusters(data) + val localAssignments2 = assignments2 + .select('id, 'cluster) + .as[(Long, Int)].collect().toSet + assert(localAssignments2 === expectedResult) } test("supported input types") { - val model = new PowerIterationClustering() + val pic = new PowerIterationClustering() .setK(2) .setMaxIter(1) + .setWeightCol("weight") - def runTest(idType: DataType, neighborType: DataType, similarityType: DataType): Unit = { + def runTest(srcType: DataType, dstType: DataType, weightType: DataType): Unit = { val typedData = data.select( - col("id").cast(idType).alias("id"), - col("neighbors").cast(ArrayType(neighborType, containsNull = false)).alias("neighbors"), - col("similarities").cast(ArrayType(similarityType, containsNull = false)) - .alias("similarities") + col("src").cast(srcType).alias("src"), + col("dst").cast(dstType).alias("dst"), + col("weight").cast(weightType).alias("weight") ) - model.transform(typedData).collect() - } - - for (idType <- Seq(IntegerType, LongType)) { - runTest(idType, LongType, DoubleType) - } - for (neighborType <- Seq(IntegerType, LongType)) { - runTest(LongType, neighborType, DoubleType) - } - for (similarityType <- Seq(FloatType, DoubleType)) { - runTest(LongType, LongType, similarityType) + pic.assignClusters(typedData).collect() } - } - test("invalid input: wrong types") { - val model = new PowerIterationClustering() - .setK(2) - .setMaxIter(1) - intercept[IllegalArgumentException] { - val typedData = data.select( - col("id").cast(DoubleType).alias("id"), - col("neighbors"), - col("similarities") - ) - model.transform(typedData) + for (srcType <- Seq(IntegerType, LongType)) { + runTest(srcType, LongType, DoubleType) } - intercept[IllegalArgumentException] { - val typedData = data.select( - col("id"), - col("neighbors").cast(ArrayType(DoubleType, containsNull = false)).alias("neighbors"), - col("similarities") - ) - model.transform(typedData) + for (dstType <- Seq(IntegerType, LongType)) { + runTest(LongType, dstType, DoubleType) } - intercept[IllegalArgumentException] { - val typedData = data.select( - col("id"), - col("neighbors"), - col("neighbors").alias("similarities") - ) - model.transform(typedData) + for (weightType <- Seq(FloatType, DoubleType)) { + runTest(LongType, LongType, weightType) } } test("invalid input: negative similarity") { - val model = new PowerIterationClustering() + val pic = new PowerIterationClustering() .setMaxIter(1) + .setWeightCol("weight") val badData = spark.createDataFrame(Seq( - (0, Array(1), Array(-1.0)), - (1, Array(0), Array(-1.0)) - )).toDF("id", "neighbors", "similarities") + (0, 1, -1.0), + (1, 0, -1.0) + )).toDF("src", "dst", "weight") val msg = intercept[SparkException] { - model.transform(badData) + pic.assignClusters(badData) }.getCause.getMessage assert(msg.contains("Similarity must be nonnegative")) } - test("invalid input: mismatched lengths for neighbor and similarity arrays") { - val model = new PowerIterationClustering() - .setMaxIter(1) - val badData = spark.createDataFrame(Seq( - (0, Array(1), Array(0.5)), - (1, Array(0, 2), Array(0.5)), - (2, Array(1), Array(0.5)) - )).toDF("id", "neighbors", "similarities") - val msg = intercept[SparkException] { - model.transform(badData) - }.getCause.getMessage - assert(msg.contains("The length of the neighbor ID list must be equal to the the length of " + - "the neighbor similarity list.")) - assert(msg.contains(s"Row for ID ${model.getIdCol}=1")) + test("test default weight") { + val dataWithoutWeight = data.sample(0.5, 1L).select('src, 'dst) + + val assignments = new PowerIterationClustering() + .setK(2) + .setMaxIter(40) + .assignClusters(dataWithoutWeight) + val localAssignments = assignments + .select('id, 'cluster) + .as[(Long, Int)].collect().toSet + + val dataWithWeightOne = dataWithoutWeight.withColumn("weight", lit(1.0)) + + val assignments2 = new PowerIterationClustering() + .setK(2) + .setMaxIter(40) + .assignClusters(dataWithWeightOne) + val localAssignments2 = assignments2 + .select('id, 'cluster) + .as[(Long, Int)].collect().toSet + + assert(localAssignments === localAssignments2) } test("read/write") { @@ -188,10 +164,9 @@ class PowerIterationClusteringSuite extends SparkFunSuite .setK(4) .setMaxIter(100) .setInitMode("degree") - .setIdCol("test_id") - .setNeighborsCol("myNeighborsCol") - .setSimilaritiesCol("mySimilaritiesCol") - .setPredictionCol("test_prediction") + .setSrcCol("src1") + .setDstCol("dst1") + .setWeightCol("weight") testDefaultReadWrite(t) } } @@ -222,17 +197,13 @@ object PowerIterationClusteringSuite { val n = n1 + n2 val points = genCircle(r1, n1) ++ genCircle(r2, n2) - val rows = for (i <- 1 until n) yield { - val neighbors = for (j <- 0 until i) yield { - j.toLong + val rows = (for (i <- 1 until n) yield { + for (j <- 0 until i) yield { + (i.toLong, j.toLong, sim(points(i), points(j))) } - val similarities = for (j <- 0 until i) yield { - sim(points(i), points(j)) - } - (i.toLong, neighbors.toArray, similarities.toArray) - } + }).flatMap(_.iterator) - spark.createDataFrame(rows).toDF("id", "neighbors", "similarities") + spark.createDataFrame(rows).toDF("src", "dst", "weight") } } From 2c2a86b5d5be6f77ee72d16f990b39ae59f479b9 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 5 Jun 2018 01:08:55 -0700 Subject: [PATCH 0924/2461] [SPARK-24453][SS] Fix error recovering from the failure in a no-data batch ## What changes were proposed in this pull request? The error occurs when we are recovering from a failure in a no-data batch (say X) that has been planned (i.e. written to offset log) but not executed (i.e. not written to commit log). Upon recovery the following sequence of events happen. 1. `MicroBatchExecution.populateStartOffsets` sets `currentBatchId` to X. Since there was no data in the batch, the `availableOffsets` is same as `committedOffsets`, so `isNewDataAvailable` is `false`. 2. When `MicroBatchExecution.constructNextBatch` is called, ideally it should immediately return true because the next batch has already been constructed. However, the check of whether the batch has been constructed was `if (isNewDataAvailable) return true`. Since the planned batch is a no-data batch, it escaped this check and proceeded to plan the same batch X *once again*. The solution is to have an explicit flag that signifies whether a batch has already been constructed or not. `populateStartOffsets` is going to set the flag appropriately. ## How was this patch tested? new unit test Author: Tathagata Das Closes #21491 from tdas/SPARK-24453. --- .../streaming/MicroBatchExecution.scala | 38 ++++++---- .../streaming/MicroBatchExecutionSuite.scala | 71 +++++++++++++++++++ .../spark/sql/streaming/StreamTest.scala | 2 +- 3 files changed, 98 insertions(+), 13 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 7817360810bde..17ffa2a517312 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -126,6 +126,12 @@ class MicroBatchExecution( _logicalPlan } + /** + * Signifies whether current batch (i.e. for the batch `currentBatchId`) has been constructed + * (i.e. written to the offsetLog) and is ready for execution. + */ + private var isCurrentBatchConstructed = false + /** * Signals to the thread executing micro-batches that it should stop running after the next * batch. This method blocks until the thread stops running. @@ -154,7 +160,6 @@ class MicroBatchExecution( triggerExecutor.execute(() => { if (isActive) { - var currentBatchIsRunnable = false // Whether the current batch is runnable / has been run var currentBatchHasNewData = false // Whether the current batch had new data startTrigger() @@ -175,7 +180,9 @@ class MicroBatchExecution( // new data to process as `constructNextBatch` may decide to run a batch for // state cleanup, etc. `isNewDataAvailable` will be updated to reflect whether new data // is available or not. - currentBatchIsRunnable = constructNextBatch(noDataBatchesEnabled) + if (!isCurrentBatchConstructed) { + isCurrentBatchConstructed = constructNextBatch(noDataBatchesEnabled) + } // Remember whether the current batch has data or not. This will be required later // for bookkeeping after running the batch, when `isNewDataAvailable` will have changed @@ -183,7 +190,7 @@ class MicroBatchExecution( currentBatchHasNewData = isNewDataAvailable currentStatus = currentStatus.copy(isDataAvailable = isNewDataAvailable) - if (currentBatchIsRunnable) { + if (isCurrentBatchConstructed) { if (currentBatchHasNewData) updateStatusMessage("Processing new data") else updateStatusMessage("No new data but cleaning up state") runBatch(sparkSessionForStream) @@ -194,9 +201,12 @@ class MicroBatchExecution( finishTrigger(currentBatchHasNewData) // Must be outside reportTimeTaken so it is recorded - // If the current batch has been executed, then increment the batch id, else there was - // no data to execute the batch - if (currentBatchIsRunnable) currentBatchId += 1 else Thread.sleep(pollingDelayMs) + // If the current batch has been executed, then increment the batch id and reset flag. + // Otherwise, there was no data to execute the batch and sleep for some time + if (isCurrentBatchConstructed) { + currentBatchId += 1 + isCurrentBatchConstructed = false + } else Thread.sleep(pollingDelayMs) } updateStatusMessage("Waiting for next trigger") isActive @@ -231,6 +241,7 @@ class MicroBatchExecution( /* First assume that we are re-executing the latest known batch * in the offset log */ currentBatchId = latestBatchId + isCurrentBatchConstructed = true availableOffsets = nextOffsets.toStreamProgress(sources) /* Initialize committed offsets to a committed batch, which at this * is the second latest batch id in the offset log. */ @@ -269,6 +280,7 @@ class MicroBatchExecution( // here, so we do nothing here. } currentBatchId = latestCommittedBatchId + 1 + isCurrentBatchConstructed = false committedOffsets ++= availableOffsets // Construct a new batch be recomputing availableOffsets } else if (latestCommittedBatchId < latestBatchId - 1) { @@ -313,11 +325,8 @@ class MicroBatchExecution( * - If either of the above is true, then construct the next batch by committing to the offset * log that range of offsets that the next batch will process. */ - private def constructNextBatch(noDataBatchesEnables: Boolean): Boolean = withProgressLocked { - // If new data is already available that means this method has already been called before - // and it must have already committed the offset range of next batch to the offset log. - // Hence do nothing, just return true. - if (isNewDataAvailable) return true + private def constructNextBatch(noDataBatchesEnabled: Boolean): Boolean = withProgressLocked { + if (isCurrentBatchConstructed) return true // Generate a map from each unique source to the next available offset. val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map { @@ -348,9 +357,14 @@ class MicroBatchExecution( batchTimestampMs = triggerClock.getTimeMillis()) // Check whether next batch should be constructed - val lastExecutionRequiresAnotherBatch = noDataBatchesEnables && + val lastExecutionRequiresAnotherBatch = noDataBatchesEnabled && Option(lastExecution).exists(_.shouldRunAnotherBatch(offsetSeqMetadata)) val shouldConstructNextBatch = isNewDataAvailable || lastExecutionRequiresAnotherBatch + logTrace( + s"noDataBatchesEnabled = $noDataBatchesEnabled, " + + s"lastExecutionRequiresAnotherBatch = $lastExecutionRequiresAnotherBatch, " + + s"isNewDataAvailable = $isNewDataAvailable, " + + s"shouldConstructNextBatch = $shouldConstructNextBatch") if (shouldConstructNextBatch) { // Commit the next batch offset range to the offset log diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala new file mode 100644 index 0000000000000..c228740df07c8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.functions.{count, window} +import org.apache.spark.sql.streaming.StreamTest + +class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter { + + import testImplicits._ + + after { + sqlContext.streams.active.foreach(_.stop()) + } + + test("SPARK-24156: do not plan a no-data batch again after it has already been planned") { + val inputData = MemoryStream[Int] + val df = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(df)( + AddData(inputData, 10, 11, 12, 13, 14, 15), // Set watermark to 5 + CheckAnswer(), + AddData(inputData, 25), // Set watermark to 15 to make MicroBatchExecution run no-data batch + CheckAnswer((10, 5)), // Last batch should be a no-data batch + StopStream, + Execute { q => + // Delete the last committed batch from the commit log to signify that the last batch + // (a no-data batch) never completed + val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) + q.commitLog.purgeAfter(commit - 1) + }, + // Add data before start so that MicroBatchExecution can plan a batch. It should not, + // it should first re-run the incomplete no-data batch and then run a new batch to process + // new data. + AddData(inputData, 30), + StartStream(), + CheckNewAnswer((15, 1)), // This should not throw the error reported in SPARK-24156 + StopStream, + Execute { q => + // Delete the entire commit log + val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) + q.commitLog.purge(commit + 1) + }, + AddData(inputData, 50), + StartStream(), + CheckNewAnswer((25, 1), (30, 1)) // This should not throw the error reported in SPARK-24156 + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index f348dac1319cb..4c3fd58cb2e45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -292,7 +292,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be /** Execute arbitrary code */ object Execute { def apply(func: StreamExecution => Any): AssertOnQuery = - AssertOnQuery(query => { func(query); true }) + AssertOnQuery(query => { func(query); true }, "Execute") } object AwaitEpoch { From 93df3cd03503fca7745141fbd2676b8bf70fe92f Mon Sep 17 00:00:00 2001 From: jinxing Date: Tue, 5 Jun 2018 11:32:42 -0700 Subject: [PATCH 0925/2461] [SPARK-22384][SQL] Refine partition pruning when attribute is wrapped in Cast ## What changes were proposed in this pull request? Sql below will get all partitions from metastore, which put much burden on metastore; ``` CREATE TABLE `partition_test`(`col` int) PARTITIONED BY (`pt` byte) SELECT * FROM partition_test WHERE CAST(pt AS INT)=1 ``` The reason is that the the analyzed attribute `dt` is wrapped in `Cast` and `HiveShim` fails to generate a proper partition filter. This pr proposes to take `Cast` into consideration when generate partition filter. ## How was this patch tested? Test added. This pr proposes to use analyzed expressions in `HiveClientSuite` Author: jinxing Closes #19602 from jinxing64/SPARK-22384. --- .../spark/sql/hive/client/HiveShim.scala | 23 +++- .../sql/hive/client/HiveClientSuite.scala | 102 ++++++++++++------ 2 files changed, 86 insertions(+), 39 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 948ba542b5733..130e258e78ca2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -24,7 +24,6 @@ import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap, S import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ -import scala.util.Try import scala.util.control.NonFatal import org.apache.hadoop.fs.Path @@ -657,17 +656,31 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { val useAdvanced = SQLConf.get.advancedPartitionPredicatePushdownEnabled + object ExtractAttribute { + def unapply(expr: Expression): Option[Attribute] = { + expr match { + case attr: Attribute => Some(attr) + case Cast(child, dt, _) if !Cast.mayTruncate(child.dataType, dt) => unapply(child) + case _ => None + } + } + } + def convert(expr: Expression): Option[String] = expr match { - case In(NonVarcharAttribute(name), ExtractableLiterals(values)) if useAdvanced => + case In(ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiterals(values)) + if useAdvanced => Some(convertInToOr(name, values)) - case InSet(NonVarcharAttribute(name), ExtractableValues(values)) if useAdvanced => + case InSet(ExtractAttribute(NonVarcharAttribute(name)), ExtractableValues(values)) + if useAdvanced => Some(convertInToOr(name, values)) - case op @ SpecialBinaryComparison(NonVarcharAttribute(name), ExtractableLiteral(value)) => + case op @ SpecialBinaryComparison( + ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiteral(value)) => Some(s"$name ${op.symbol} $value") - case op @ SpecialBinaryComparison(ExtractableLiteral(value), NonVarcharAttribute(name)) => + case op @ SpecialBinaryComparison( + ExtractableLiteral(value), ExtractAttribute(NonVarcharAttribute(name))) => Some(s"$value ${op.symbol} $name") case And(expr1, expr2) if useAdvanced => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index f991352b207d4..55275f6b37945 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -22,13 +22,13 @@ import org.apache.hadoop.hive.conf.HiveConf import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Expression, In, InSet} -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.LongType // TODO: Refactor this to `HivePartitionFilteringSuite` class HiveClientSuite(version: String) extends HiveVersionSuite(version) with BeforeAndAfterAll { - import CatalystSqlParser._ private val tryDirectSqlKey = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname @@ -46,8 +46,7 @@ class HiveClientSuite(version: String) val hadoopConf = new Configuration() hadoopConf.setBoolean(tryDirectSqlKey, tryDirectSql) val client = buildClient(hadoopConf) - client - .runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)") + client.runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)") val partitions = for { @@ -66,6 +65,15 @@ class HiveClientSuite(version: String) client } + private def attr(name: String): Attribute = { + client.getTable("default", "test").partitionSchema.fields + .find(field => field.name.equals(name)) match { + case Some(field) => AttributeReference(field.name, field.dataType)() + case None => + fail(s"Illegal name of partition attribute: $name") + } + } + override def beforeAll() { super.beforeAll() client = init(true) @@ -74,7 +82,7 @@ class HiveClientSuite(version: String) test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") { val client = init(false) val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"), - Seq(parseExpression("ds=20170101"))) + Seq(attr("ds") === 20170101)) assert(filteredPartitions.size == testPartitionCount) } @@ -82,7 +90,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds<=>20170101") { // Should return all partitions where <=> is not supported testMetastorePartitionFiltering( - "ds<=>20170101", + attr("ds") <=> 20170101, 20170101 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -90,7 +98,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds=20170101") { testMetastorePartitionFiltering( - "ds=20170101", + attr("ds") === 20170101, 20170101 to 20170101, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -100,7 +108,7 @@ class HiveClientSuite(version: String) // Should return all partitions where h=0 because getPartitionsByFilter does not support // comparisons to non-literal values testMetastorePartitionFiltering( - "ds=(20170101 + 1) and h=0", + attr("ds") === (Literal(20170101) + 1) && attr("h") === 0, 20170101 to 20170103, 0 to 0, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -108,7 +116,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: chunk='aa'") { testMetastorePartitionFiltering( - "chunk='aa'", + attr("chunk") === "aa", 20170101 to 20170103, 0 to 23, "aa" :: Nil) @@ -116,7 +124,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: 20170101=ds") { testMetastorePartitionFiltering( - "20170101=ds", + Literal(20170101) === attr("ds"), 20170101 to 20170101, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -124,7 +132,15 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds=20170101 and h=10") { testMetastorePartitionFiltering( - "ds=20170101 and h=10", + attr("ds") === 20170101 && attr("h") === 10, + 20170101 to 20170101, + 10 to 10, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: chunk in cast(ds as long)=20170101L") { + testMetastorePartitionFiltering( + attr("ds").cast(LongType) === 20170101L && attr("h") === 10, 20170101 to 20170101, 10 to 10, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -132,7 +148,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds=20170101 or ds=20170102") { testMetastorePartitionFiltering( - "ds=20170101 or ds=20170102", + attr("ds") === 20170101 || attr("ds") === 20170102, 20170101 to 20170102, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -140,7 +156,15 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds in (20170102, 20170103) (using IN expression)") { testMetastorePartitionFiltering( - "ds in (20170102, 20170103)", + attr("ds").in(20170102, 20170103), + 20170102 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using IN expression)") { + testMetastorePartitionFiltering( + attr("ds").cast(LongType).in(20170102L, 20170103L), 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -148,7 +172,19 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds in (20170102, 20170103) (using INSET expression)") { testMetastorePartitionFiltering( - "ds in (20170102, 20170103)", + attr("ds").in(20170102, 20170103), + 20170102 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil, { + case expr @ In(v, list) if expr.inSetConvertible => + InSet(v, list.map(_.eval(EmptyRow)).toSet) + }) + } + + test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using INSET expression)") + { + testMetastorePartitionFiltering( + attr("ds").cast(LongType).in(20170102L, 20170103L), 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { @@ -159,7 +195,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: chunk in ('ab', 'ba') (using IN expression)") { testMetastorePartitionFiltering( - "chunk in ('ab', 'ba')", + attr("chunk").in("ab", "ba"), 20170101 to 20170103, 0 to 23, "ab" :: "ba" :: Nil) @@ -167,7 +203,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: chunk in ('ab', 'ba') (using INSET expression)") { testMetastorePartitionFiltering( - "chunk in ('ab', 'ba')", + attr("chunk").in("ab", "ba"), 20170101 to 20170103, 0 to 23, "ab" :: "ba" :: Nil, { @@ -179,26 +215,24 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<8)") { val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb")) val day2 = (20170102 to 20170102, 0 to 7, Seq("aa", "ab", "ba", "bb")) - testMetastorePartitionFiltering( - "(ds=20170101 and h>=8) or (ds=20170102 and h<8)", - day1 :: day2 :: Nil) + testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 8) || + (attr("ds") === 20170102 && attr("h") < 8), day1 :: day2 :: Nil) } test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))") { val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb")) // Day 2 should include all hours because we can't build a filter for h<(7+1) val day2 = (20170102 to 20170102, 0 to 23, Seq("aa", "ab", "ba", "bb")) - testMetastorePartitionFiltering( - "(ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))", - day1 :: day2 :: Nil) + testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 8) || + (attr("ds") === 20170102 && attr("h") < (Literal(7) + 1)), day1 :: day2 :: Nil) } test("getPartitionsByFilter: " + "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))") { val day1 = (20170101 to 20170101, 8 to 23, Seq("ab", "ba")) val day2 = (20170102 to 20170102, 0 to 7, Seq("ab", "ba")) - testMetastorePartitionFiltering( - "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))", + testMetastorePartitionFiltering(attr("chunk").in("ab", "ba") && + ((attr("ds") === 20170101 && attr("h") >= 8) || (attr("ds") === 20170102 && attr("h") < 8)), day1 :: day2 :: Nil) } @@ -207,41 +241,41 @@ class HiveClientSuite(version: String) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedDs: Seq[Int], expectedH: Seq[Int], expectedChunks: Seq[String]): Unit = { testMetastorePartitionFiltering( - filterString, + filterExpr, (expectedDs, expectedH, expectedChunks) :: Nil, identity) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedDs: Seq[Int], expectedH: Seq[Int], expectedChunks: Seq[String], transform: Expression => Expression): Unit = { testMetastorePartitionFiltering( - filterString, + filterExpr, (expectedDs, expectedH, expectedChunks) :: Nil, - identity) + transform) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])]): Unit = { - testMetastorePartitionFiltering(filterString, expectedPartitionCubes, identity) + testMetastorePartitionFiltering(filterExpr, expectedPartitionCubes, identity) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])], transform: Expression => Expression): Unit = { val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"), Seq( - transform(parseExpression(filterString)) + transform(filterExpr) )) val expectedPartitionCount = expectedPartitionCubes.map { From e9efb62e0795c8d5233b7e5bfc276d74953942b8 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 6 Jun 2018 08:31:35 +0700 Subject: [PATCH 0926/2461] [SPARK-24187][R][SQL] Add array_join function to SparkR ## What changes were proposed in this pull request? This PR adds array_join function to SparkR ## How was this patch tested? Add unit test in test_sparkSQL.R Author: Huaxin Gao Closes #21313 from huaxingao/spark-24187. --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 29 ++++++++++++++++++++++++--- R/pkg/R/generics.R | 4 ++++ R/pkg/tests/fulltests/test_sparkSQL.R | 15 ++++++++++++++ 4 files changed, 46 insertions(+), 3 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 73a33af4dd48b..9696f6987ad78 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -201,6 +201,7 @@ exportMethods("%<=>%", "approxCountDistinct", "approxQuantile", "array_contains", + "array_join", "array_max", "array_min", "array_position", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index abc91aeeb4825..3bff633fbc1ff 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -221,7 +221,9 @@ NULL #' head(select(tmp3, element_at(tmp3$v3, "Valiant"))) #' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp)) #' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5))) -#' head(select(tmp, concat(df$mpg, df$cyl, df$hp)))} +#' head(select(tmp, concat(df$mpg, df$cyl, df$hp))) +#' tmp5 <- mutate(df, v6 = create_array(df$model, df$model)) +#' head(select(tmp5, array_join(tmp5$v6, "#"), array_join(tmp5$v6, "#", "NULL")))} NULL #' Window functions for Column operations @@ -3006,6 +3008,27 @@ setMethod("array_contains", column(jc) }) +#' @details +#' \code{array_join}: Concatenates the elements of column using the delimiter. +#' Null values are replaced with nullReplacement if set, otherwise they are ignored. +#' +#' @param delimiter a character string that is used to concatenate the elements of column. +#' @param nullReplacement an optional character string that is used to replace the Null values. +#' @rdname column_collection_functions +#' @aliases array_join array_join,Column-method +#' @note array_join since 2.4.0 +setMethod("array_join", + signature(x = "Column", delimiter = "character"), + function(x, delimiter, nullReplacement = NULL) { + jc <- if (is.null(nullReplacement)) { + callJStatic("org.apache.spark.sql.functions", "array_join", x@jc, delimiter) + } else { + callJStatic("org.apache.spark.sql.functions", "array_join", x@jc, delimiter, + as.character(nullReplacement)) + } + column(jc) + }) + #' @details #' \code{array_max}: Returns the maximum value of the array. #' @@ -3197,8 +3220,8 @@ setMethod("size", #' (or starting from the end if start is negative) with the specified length. #' #' @rdname column_collection_functions -#' @param start an index indicating the first element occuring in the result. -#' @param length a number of consecutive elements choosen to the result. +#' @param start an index indicating the first element occurring in the result. +#' @param length a number of consecutive elements chosen to the result. #' @aliases slice slice,Column-method #' @note slice since 2.4.0 setMethod("slice", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 8894cb1c5b92f..9321bbaf96ff8 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -757,6 +757,10 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_join", function(x, delimiter, ...) { standardGeneric("array_join") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_max", function(x) { standardGeneric("array_max") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 16c1fd5a065eb..36e0f78bb0599 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1518,6 +1518,21 @@ test_that("column functions", { result <- collect(select(df, arrays_overlap(df[[1]], df[[2]])))[[1]] expect_equal(result, c(TRUE, FALSE, NA)) + # Test array_join() + df <- createDataFrame(list(list(list("Hello", "World!")))) + result <- collect(select(df, array_join(df[[1]], "#")))[[1]] + expect_equal(result, "Hello#World!") + df2 <- createDataFrame(list(list(list("Hello", NA, "World!")))) + result <- collect(select(df2, array_join(df2[[1]], "#", "Beautiful")))[[1]] + expect_equal(result, "Hello#Beautiful#World!") + result <- collect(select(df2, array_join(df2[[1]], "#")))[[1]] + expect_equal(result, "Hello#World!") + df3 <- createDataFrame(list(list(list("Hello", NULL, "World!")))) + result <- collect(select(df3, array_join(df3[[1]], "#", "Beautiful")))[[1]] + expect_equal(result, "Hello#Beautiful#World!") + result <- collect(select(df3, array_join(df3[[1]], "#")))[[1]] + expect_equal(result, "Hello#World!") + # Test array_sort() and sort_array() df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L)))) From e76b0124fbe463def00b1dffcfd8fd47e04772fe Mon Sep 17 00:00:00 2001 From: Asher Saban Date: Wed, 6 Jun 2018 07:14:08 -0700 Subject: [PATCH 0927/2461] [SPARK-23803][SQL] Support bucket pruning ## What changes were proposed in this pull request? support bucket pruning when filtering on a single bucketed column on the following predicates - EqualTo, EqualNullSafe, In, And/Or predicates ## How was this patch tested? refactored unit tests to test the above. based on gatorsmile work in https://github.com/apache/spark/commit/e3c75c6398b1241500343ff237e9bcf78b5396f9 Author: Asher Saban Author: asaban Closes #20915 from sabanas/filter-prune-buckets. --- .../sql/execution/DataSourceScanExec.scala | 32 ++++- .../datasources/BucketingUtils.scala | 14 +++ .../datasources/DataSourceStrategy.scala | 12 -- .../datasources/FileSourceStrategy.scala | 98 +++++++++++++++- .../spark/sql/sources/BucketedReadSuite.scala | 110 +++++++++++++++--- 5 files changed, 231 insertions(+), 35 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 61c14fee09337..d7f2654be0451 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.sources.{BaseRelation, Filter} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils +import org.apache.spark.util.collection.BitSet trait DataSourceScanExec extends LeafExecNode with CodegenSupport { val relation: BaseRelation @@ -151,6 +152,7 @@ case class RowDataSourceScanExec( * @param output Output attributes of the scan, including data attributes and partition attributes. * @param requiredSchema Required schema of the underlying relation, excluding partition columns. * @param partitionFilters Predicates to use for partition pruning. + * @param optionalBucketSet Bucket ids for bucket pruning * @param dataFilters Filters on non-partition columns. * @param tableIdentifier identifier for the table in the metastore. */ @@ -159,6 +161,7 @@ case class FileSourceScanExec( output: Seq[Attribute], requiredSchema: StructType, partitionFilters: Seq[Expression], + optionalBucketSet: Option[BitSet], dataFilters: Seq[Expression], override val tableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with ColumnarBatchScan { @@ -286,7 +289,20 @@ case class FileSourceScanExec( } getOrElse { metadata } - withOptPartitionCount + + val withSelectedBucketsCount = relation.bucketSpec.map { spec => + val numSelectedBuckets = optionalBucketSet.map { b => + b.cardinality() + } getOrElse { + spec.numBuckets + } + withOptPartitionCount + ("SelectedBucketsCount" -> + s"$numSelectedBuckets out of ${spec.numBuckets}") + } getOrElse { + withOptPartitionCount + } + + withSelectedBucketsCount } private lazy val inputRDD: RDD[InternalRow] = { @@ -365,7 +381,7 @@ case class FileSourceScanExec( selectedPartitions: Seq[PartitionDirectory], fsRelation: HadoopFsRelation): RDD[InternalRow] = { logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") - val bucketed = + val filesGroupedToBuckets = selectedPartitions.flatMap { p => p.files.map { f => val hosts = getBlockHosts(getBlockLocations(f), 0, f.getLen) @@ -377,8 +393,17 @@ case class FileSourceScanExec( .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}")) } + val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { + val bucketSet = optionalBucketSet.get + filesGroupedToBuckets.filter { + f => bucketSet.get(f._1) + } + } else { + filesGroupedToBuckets + } + val filePartitions = Seq.tabulate(bucketSpec.numBuckets) { bucketId => - FilePartition(bucketId, bucketed.getOrElse(bucketId, Nil)) + FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Nil)) } new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions) @@ -503,6 +528,7 @@ case class FileSourceScanExec( output.map(QueryPlan.normalizeExprId(_, output)), requiredSchema, QueryPlan.normalizePredicates(partitionFilters, output), + optionalBucketSet, QueryPlan.normalizePredicates(dataFilters, output), None) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala index ea4fe9c8ade5f..a776fc3e7021d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning + object BucketingUtils { // The file name of bucketed data should have 3 parts: // 1. some other information in the head of file name @@ -35,5 +38,16 @@ object BucketingUtils { case other => None } + // Given bucketColumn, numBuckets and value, returns the corresponding bucketId + def getBucketIdFromValue(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { + val mutableInternalRow = new SpecificInternalRow(Seq(bucketColumn.dataType)) + mutableInternalRow.update(0, value) + + val bucketIdGenerator = UnsafeProjection.create( + HashPartitioning(Seq(bucketColumn), numBuckets).partitionIdExpression :: Nil, + bucketColumn :: Nil) + bucketIdGenerator(mutableInternalRow).getInt(0) + } + def bucketIdToString(id: Int): String = f"_$id%05d" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 3f41612c08065..7b129435c45db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -312,18 +312,6 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with case _ => Nil } - // Get the bucket ID based on the bucketing values. - // Restriction: Bucket pruning works iff the bucketing column has one and only one column. - def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { - val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType)) - mutableRow(0) = cast(Literal(value), bucketColumn.dataType).eval(null) - val bucketIdGeneration = UnsafeProjection.create( - HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil, - bucketColumn :: Nil) - - bucketIdGeneration(mutableRow).getInt(0) - } - // Based on Public API. private def pruneFilterProject( relation: LogicalRelation, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 0a568d6b8adce..fe27b78bf3360 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.internal.Logging import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.FileSourceScanExec -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan} +import org.apache.spark.util.collection.BitSet /** * A strategy for planning scans over collections of files that might be partitioned or bucketed @@ -50,6 +51,91 @@ import org.apache.spark.sql.execution.SparkPlan * and add it. Proceed to the next file. */ object FileSourceStrategy extends Strategy with Logging { + + // should prune buckets iff num buckets is greater than 1 and there is only one bucket column + private def shouldPruneBuckets(bucketSpec: Option[BucketSpec]): Boolean = { + bucketSpec match { + case Some(spec) => spec.bucketColumnNames.length == 1 && spec.numBuckets > 1 + case None => false + } + } + + private def getExpressionBuckets( + expr: Expression, + bucketColumnName: String, + numBuckets: Int): BitSet = { + + def getBucketNumber(attr: Attribute, v: Any): Int = { + BucketingUtils.getBucketIdFromValue(attr, numBuckets, v) + } + + def getBucketSetFromIterable(attr: Attribute, iter: Iterable[Any]): BitSet = { + val matchedBuckets = new BitSet(numBuckets) + iter + .map(v => getBucketNumber(attr, v)) + .foreach(bucketNum => matchedBuckets.set(bucketNum)) + matchedBuckets + } + + def getBucketSetFromValue(attr: Attribute, v: Any): BitSet = { + val matchedBuckets = new BitSet(numBuckets) + matchedBuckets.set(getBucketNumber(attr, v)) + matchedBuckets + } + + expr match { + case expressions.Equality(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => + getBucketSetFromValue(a, v) + case expressions.In(a: Attribute, list) + if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => + getBucketSetFromIterable(a, list.map(e => e.eval(EmptyRow))) + case expressions.InSet(a: Attribute, hset) + if hset.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => + getBucketSetFromIterable(a, hset.map(e => expressions.Literal(e).eval(EmptyRow))) + case expressions.IsNull(a: Attribute) if a.name == bucketColumnName => + getBucketSetFromValue(a, null) + case expressions.And(left, right) => + getExpressionBuckets(left, bucketColumnName, numBuckets) & + getExpressionBuckets(right, bucketColumnName, numBuckets) + case expressions.Or(left, right) => + getExpressionBuckets(left, bucketColumnName, numBuckets) | + getExpressionBuckets(right, bucketColumnName, numBuckets) + case _ => + val matchedBuckets = new BitSet(numBuckets) + matchedBuckets.setUntil(numBuckets) + matchedBuckets + } + } + + private def genBucketSet( + normalizedFilters: Seq[Expression], + bucketSpec: BucketSpec): Option[BitSet] = { + if (normalizedFilters.isEmpty) { + return None + } + + val bucketColumnName = bucketSpec.bucketColumnNames.head + val numBuckets = bucketSpec.numBuckets + + val normalizedFiltersAndExpr = normalizedFilters + .reduce(expressions.And) + val matchedBuckets = getExpressionBuckets(normalizedFiltersAndExpr, bucketColumnName, + numBuckets) + + val numBucketsSelected = matchedBuckets.cardinality() + + logInfo { + s"Pruned ${numBuckets - numBucketsSelected} out of $numBuckets buckets." + } + + // None means all the buckets need to be scanned + if (numBucketsSelected == numBuckets) { + None + } else { + Some(matchedBuckets) + } + } + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projects, filters, l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) => @@ -82,6 +168,13 @@ object FileSourceStrategy extends Strategy with Logging { logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}") + val bucketSpec: Option[BucketSpec] = fsRelation.bucketSpec + val bucketSet = if (shouldPruneBuckets(bucketSpec)) { + genBucketSet(normalizedFilters, bucketSpec.get) + } else { + None + } + val dataColumns = l.resolve(fsRelation.dataSchema, fsRelation.sparkSession.sessionState.analyzer.resolver) @@ -111,6 +204,7 @@ object FileSourceStrategy extends Strategy with Logging { outputAttributes, outputSchema, partitionKeyFilters.toSeq, + bucketSet, dataFilters, table.map(_.identifier)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index fb61fa716b946..a9414200e70f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -22,10 +22,11 @@ import java.net.URI import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.{DataSourceScanExec, SortExec} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.datasources.BucketingUtils import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.functions._ @@ -52,6 +53,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { s <- Seq(null, "a", "b", "c", "d", "e", "f", null, "g") } yield (i % 5, s, i % 13)).toDF("i", "j", "k") + // number of buckets that doesn't yield empty buckets when bucketing on column j on df/nullDF + // empty buckets before filtering might hide bugs in pruning logic + private val NumBucketsForPruningDF = 7 + private val NumBucketsForPruningNullDf = 5 + test("read bucketed data") { withTable("bucketed_table") { df.write @@ -90,32 +96,37 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { originalDataFrame: DataFrame): Unit = { // This test verifies parts of the plan. Disable whole stage codegen. withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { - val strategy = DataSourceStrategy(spark.sessionState.conf) val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k") val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec // Limit: bucket pruning only works when the bucket column has one and only one column assert(bucketColumnNames.length == 1) val bucketColumnIndex = bucketedDataFrame.schema.fieldIndex(bucketColumnNames.head) val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex) - val matchedBuckets = new BitSet(numBuckets) - bucketValues.foreach { value => - matchedBuckets.set(strategy.getBucketId(bucketColumn, numBuckets, value)) - } // Filter could hide the bug in bucket pruning. Thus, skipping all the filters val plan = bucketedDataFrame.filter(filterCondition).queryExecution.executedPlan val rdd = plan.find(_.isInstanceOf[DataSourceScanExec]) assert(rdd.isDefined, plan) - val checkedResult = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => - if (matchedBuckets.get(index % numBuckets) && iter.nonEmpty) Iterator(index) else Iterator() + // if nothing should be pruned, skip the pruning test + if (bucketValues.nonEmpty) { + val matchedBuckets = new BitSet(numBuckets) + bucketValues.foreach { value => + matchedBuckets.set(BucketingUtils.getBucketIdFromValue(bucketColumn, numBuckets, value)) + } + val invalidBuckets = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => + // return indexes of partitions that should have been pruned and are not empty + if (!matchedBuckets.get(index % numBuckets) && iter.nonEmpty) { + Iterator(index) + } else { + Iterator() + } + }.collect() + + if (invalidBuckets.nonEmpty) { + fail(s"Buckets ${invalidBuckets.mkString(",")} should have been pruned from:\n$plan") + } } - // TODO: These tests are not testing the right columns. -// // checking if all the pruned buckets are empty -// val invalidBuckets = checkedResult.collect().toList -// if (invalidBuckets.nonEmpty) { -// fail(s"Buckets $invalidBuckets should have been pruned from:\n$plan") -// } checkAnswer( bucketedDataFrame.filter(filterCondition).orderBy("i", "j", "k"), @@ -125,7 +136,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { test("read partitioning bucketed tables with bucket pruning filters") { withTable("bucketed_table") { - val numBuckets = 8 + val numBuckets = NumBucketsForPruningDF val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) // json does not support predicate push-down, and thus json is used here df.write @@ -155,13 +166,21 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { bucketValues = Seq(j, j + 1, j + 2, j + 3), filterCondition = $"j".isin(j, j + 1, j + 2, j + 3), df) + + // Case 4: InSet + val inSetExpr = expressions.InSet($"j".expr, Set(j, j + 1, j + 2, j + 3).map(lit(_).expr)) + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(j, j + 1, j + 2, j + 3), + filterCondition = Column(inSetExpr), + df) } } } test("read non-partitioning bucketed tables with bucket pruning filters") { withTable("bucketed_table") { - val numBuckets = 8 + val numBuckets = NumBucketsForPruningDF val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) // json does not support predicate push-down, and thus json is used here df.write @@ -181,7 +200,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { test("read partitioning bucketed tables having null in bucketing key") { withTable("bucketed_table") { - val numBuckets = 8 + val numBuckets = NumBucketsForPruningNullDf val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) // json does not support predicate push-down, and thus json is used here nullDF.write @@ -208,7 +227,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { test("read partitioning bucketed tables having composite filters") { withTable("bucketed_table") { - val numBuckets = 8 + val numBuckets = NumBucketsForPruningDF val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) // json does not support predicate push-down, and thus json is used here df.write @@ -229,7 +248,62 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { bucketValues = j :: Nil, filterCondition = $"j" === j && $"i" > j % 5, df) + + // check multiple bucket values OR condition + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(j, j + 1), + filterCondition = $"j" === j || $"j" === (j + 1), + df) + + // check bucket value and none bucket value OR condition + checkPrunedAnswers( + bucketSpec, + bucketValues = Nil, + filterCondition = $"j" === j || $"i" === 0, + df) + + // check AND condition in complex expression + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(j), + filterCondition = ($"i" === 0 || $"k" > $"j") && $"j" === j, + df) + } + } + } + + test("read bucketed table without filters") { + withTable("bucketed_table") { + val numBuckets = NumBucketsForPruningDF + val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) + // json does not support predicate push-down, and thus json is used here + df.write + .format("json") + .bucketBy(numBuckets, "j") + .saveAsTable("bucketed_table") + + val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k") + val plan = bucketedDataFrame.queryExecution.executedPlan + val rdd = plan.find(_.isInstanceOf[DataSourceScanExec]) + assert(rdd.isDefined, plan) + + val emptyBuckets = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => + // return indexes of empty partitions + if (iter.isEmpty) { + Iterator(index) + } else { + Iterator() + } + }.collect() + + if (emptyBuckets.nonEmpty) { + fail(s"Buckets ${emptyBuckets.mkString(",")} should not have been pruned from:\n$plan") } + + checkAnswer( + bucketedDataFrame.orderBy("i", "j", "k"), + df.orderBy("i", "j", "k")) } } From 1462bba4fd99a264ebc8679db91dfc62d0b9a35f Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Fri, 8 Jun 2018 13:27:52 +0200 Subject: [PATCH 0928/2461] [SPARK-24119][SQL] Add interpreted execution to SortPrefix expression ## What changes were proposed in this pull request? Implemented eval in SortPrefix expression. ## How was this patch tested? - ran existing sbt SQL tests - added unit test - ran existing Python SQL tests - manual tests: disabling codegen -- patching code to disable beyond what spark.sql.codegen.wholeStage=false can do -- and running sbt SQL tests Author: Bruce Robbins Closes #21231 from bersprockets/sortprefixeval. --- .../sql/catalyst/expressions/SortOrder.scala | 37 +++++++- .../SortOrderExpressionsSuite.scala | 95 +++++++++++++++++++ 2 files changed, 131 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 2ce9d072c71c9..76a881146a146 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ abstract sealed class SortDirection { @@ -148,7 +149,41 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { (!child.isAscending && child.nullOrdering == NullsLast) } - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + private lazy val calcPrefix: Any => Long = child.child.dataType match { + case BooleanType => (raw) => + if (raw.asInstanceOf[Boolean]) 1 else 0 + case DateType | TimestampType | _: IntegralType => (raw) => + raw.asInstanceOf[java.lang.Number].longValue() + case FloatType | DoubleType => (raw) => { + val dVal = raw.asInstanceOf[java.lang.Number].doubleValue() + DoublePrefixComparator.computePrefix(dVal) + } + case StringType => (raw) => + StringPrefixComparator.computePrefix(raw.asInstanceOf[UTF8String]) + case BinaryType => (raw) => + BinaryPrefixComparator.computePrefix(raw.asInstanceOf[Array[Byte]]) + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => + _.asInstanceOf[Decimal].toUnscaledLong + case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS => + val p = Decimal.MAX_LONG_DIGITS + val s = p - (dt.precision - dt.scale) + (raw) => { + val value = raw.asInstanceOf[Decimal] + if (value.changePrecision(p, s)) value.toUnscaledLong else Long.MinValue + } + case dt: DecimalType => (raw) => + DoublePrefixComparator.computePrefix(raw.asInstanceOf[Decimal].toDouble) + case _ => (Any) => 0L + } + + override def eval(input: InternalRow): Any = { + val value = child.child.eval(input) + if (value == null) { + null + } else { + calcPrefix(value) + } + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childCode = child.child.genCode(ctx) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala new file mode 100644 index 0000000000000..cc2e2a993d629 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.sql.{Date, Timestamp} +import java.util.TimeZone + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ + +class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("SortPrefix") { + val b1 = Literal.create(false, BooleanType) + val b2 = Literal.create(true, BooleanType) + val i1 = Literal.create(20132983, IntegerType) + val i2 = Literal.create(-20132983, IntegerType) + val l1 = Literal.create(20132983, LongType) + val l2 = Literal.create(-20132983, LongType) + val millis = 1524954911000L; + // Explicitly choose a time zone, since Date objects can create different values depending on + // local time zone of the machine on which the test is running + val oldDefaultTZ = TimeZone.getDefault + val d1 = try { + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + Literal.create(new java.sql.Date(millis), DateType) + } finally { + TimeZone.setDefault(oldDefaultTZ) + } + val t1 = Literal.create(new Timestamp(millis), TimestampType) + val f1 = Literal.create(0.7788229f, FloatType) + val f2 = Literal.create(-0.7788229f, FloatType) + val db1 = Literal.create(0.7788229d, DoubleType) + val db2 = Literal.create(-0.7788229d, DoubleType) + val s1 = Literal.create("T", StringType) + val s2 = Literal.create("This is longer than 8 characters", StringType) + val bin1 = Literal.create(Array[Byte](12), BinaryType) + val bin2 = Literal.create(Array[Byte](12, 17, 99, 0, 0, 0, 2, 3, 0xf4.asInstanceOf[Byte]), + BinaryType) + val dec1 = Literal(Decimal(20132983L, 10, 2)) + val dec2 = Literal(Decimal(20132983L, 19, 2)) + val dec3 = Literal(Decimal(20132983L, 21, 2)) + val list1 = Literal(List(1, 2), ArrayType(IntegerType)) + val nullVal = Literal.create(null, IntegerType) + + checkEvaluation(SortPrefix(SortOrder(b1, Ascending)), 0L) + checkEvaluation(SortPrefix(SortOrder(b2, Ascending)), 1L) + checkEvaluation(SortPrefix(SortOrder(i1, Ascending)), 20132983L) + checkEvaluation(SortPrefix(SortOrder(i2, Ascending)), -20132983L) + checkEvaluation(SortPrefix(SortOrder(l1, Ascending)), 20132983L) + checkEvaluation(SortPrefix(SortOrder(l2, Ascending)), -20132983L) + // For some reason, the Literal.create code gives us the number of days since the epoch + checkEvaluation(SortPrefix(SortOrder(d1, Ascending)), 17649L) + checkEvaluation(SortPrefix(SortOrder(t1, Ascending)), millis * 1000) + checkEvaluation(SortPrefix(SortOrder(f1, Ascending)), + DoublePrefixComparator.computePrefix(f1.value.asInstanceOf[Float].toDouble)) + checkEvaluation(SortPrefix(SortOrder(f2, Ascending)), + DoublePrefixComparator.computePrefix(f2.value.asInstanceOf[Float].toDouble)) + checkEvaluation(SortPrefix(SortOrder(db1, Ascending)), + DoublePrefixComparator.computePrefix(db1.value.asInstanceOf[Double])) + checkEvaluation(SortPrefix(SortOrder(db2, Ascending)), + DoublePrefixComparator.computePrefix(db2.value.asInstanceOf[Double])) + checkEvaluation(SortPrefix(SortOrder(s1, Ascending)), + StringPrefixComparator.computePrefix(s1.value.asInstanceOf[UTF8String])) + checkEvaluation(SortPrefix(SortOrder(s2, Ascending)), + StringPrefixComparator.computePrefix(s2.value.asInstanceOf[UTF8String])) + checkEvaluation(SortPrefix(SortOrder(bin1, Ascending)), + BinaryPrefixComparator.computePrefix(bin1.value.asInstanceOf[Array[Byte]])) + checkEvaluation(SortPrefix(SortOrder(bin2, Ascending)), + BinaryPrefixComparator.computePrefix(bin2.value.asInstanceOf[Array[Byte]])) + checkEvaluation(SortPrefix(SortOrder(dec1, Ascending)), 20132983L) + checkEvaluation(SortPrefix(SortOrder(dec2, Ascending)), 2013298L) + checkEvaluation(SortPrefix(SortOrder(dec3, Ascending)), + DoublePrefixComparator.computePrefix(201329.83d)) + checkEvaluation(SortPrefix(SortOrder(list1, Ascending)), 0L) + checkEvaluation(SortPrefix(SortOrder(nullVal, Ascending)), null) + } +} From 2c100209f0b73e882ab953993b307867d1df7c2f Mon Sep 17 00:00:00 2001 From: Shahid Date: Fri, 8 Jun 2018 08:44:59 -0500 Subject: [PATCH 0929/2461] [SPARK-24224][ML-EXAMPLES] Java example code for Power Iteration Clustering in spark.ml ## What changes were proposed in this pull request? Java example code for Power Iteration Clustering in spark.ml ## How was this patch tested? Locally tested Author: Shahid Closes #21283 from shahidki31/JavaPicExample. --- .../JavaPowerIterationClusteringExample.java | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaPowerIterationClusteringExample.java diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPowerIterationClusteringExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPowerIterationClusteringExample.java new file mode 100644 index 0000000000000..51865637df6f6 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPowerIterationClusteringExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.clustering.PowerIterationClustering; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaPowerIterationClusteringExample { + public static void main(String[] args) { + // Create a SparkSession. + SparkSession spark = SparkSession + .builder() + .appName("JavaPowerIterationClustering") + .getOrCreate(); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(0L, 1L, 1.0), + RowFactory.create(0L, 2L, 1.0), + RowFactory.create(1L, 2L, 1.0), + RowFactory.create(3L, 4L, 1.0), + RowFactory.create(4L, 0L, 0.1) + ); + + StructType schema = new StructType(new StructField[]{ + new StructField("src", DataTypes.LongType, false, Metadata.empty()), + new StructField("dst", DataTypes.LongType, false, Metadata.empty()), + new StructField("weight", DataTypes.DoubleType, false, Metadata.empty()) + }); + + Dataset df = spark.createDataFrame(data, schema); + + PowerIterationClustering model = new PowerIterationClustering() + .setK(2) + .setMaxIter(10) + .setInitMode("degree") + .setWeightCol("weight"); + + Dataset result = model.assignClusters(df); + result.show(false); + // $example off$ + spark.stop(); + } +} From a5d775a1f3aad7bef0ac0f93869eaf96b677411b Mon Sep 17 00:00:00 2001 From: Shahid Date: Fri, 8 Jun 2018 08:45:56 -0500 Subject: [PATCH 0930/2461] [SPARK-24191][ML] Scala Example code for Power Iteration Clustering ## What changes were proposed in this pull request? Added example code for Power Iteration Clustering in Spark ML examples Author: Shahid Closes #21248 from shahidki31/sparkCommit. --- .../ml/PowerIterationClusteringExample.scala | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/PowerIterationClusteringExample.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PowerIterationClusteringExample.scala new file mode 100644 index 0000000000000..ca8f7affb14e8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PowerIterationClusteringExample.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.clustering.PowerIterationClustering +// $example off$ +import org.apache.spark.sql.SparkSession + +object PowerIterationClusteringExample { + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName(s"${this.getClass.getSimpleName}") + .getOrCreate() + + // $example on$ + val dataset = spark.createDataFrame(Seq( + (0L, 1L, 1.0), + (0L, 2L, 1.0), + (1L, 2L, 1.0), + (3L, 4L, 1.0), + (4L, 0L, 0.1) + )).toDF("src", "dst", "weight") + + val model = new PowerIterationClustering(). + setK(2). + setMaxIter(20). + setInitMode("degree"). + setWeightCol("weight") + + val prediction = model.assignClusters(dataset).select("id", "cluster") + + // Shows the cluster assignment + prediction.show(false) + // $example off$ + + spark.stop() + } + } From 173fe450df203b262b58f7e71c6b52a79db95ee0 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 8 Jun 2018 09:32:11 -0700 Subject: [PATCH 0931/2461] [SPARK-24477][SPARK-24454][ML][PYTHON] Imports submodule in ml/__init__.py and add ImageSchema into __all__ ## What changes were proposed in this pull request? This PR attaches submodules to ml's `__init__.py` module. Also, adds `ImageSchema` into `image.py` explicitly. ## How was this patch tested? Before: ```python >>> from pyspark import ml >>> ml.image Traceback (most recent call last): File "", line 1, in AttributeError: 'module' object has no attribute 'image' >>> ml.image.ImageSchema Traceback (most recent call last): File "", line 1, in AttributeError: 'module' object has no attribute 'image' ``` ```python >>> "image" in globals() False >>> from pyspark.ml import * >>> "image" in globals() False >>> image Traceback (most recent call last): File "", line 1, in NameError: name 'image' is not defined ``` After: ```python >>> from pyspark import ml >>> ml.image >>> ml.image.ImageSchema ``` ```python >>> "image" in globals() False >>> from pyspark.ml import * >>> "image" in globals() True >>> image ``` Author: hyukjinkwon Closes #21483 from HyukjinKwon/SPARK-24454. --- python/pyspark/ml/__init__.py | 8 +++++++- python/pyspark/ml/image.py | 2 ++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py index 129d7d68f7cbb..d99a25390db15 100644 --- a/python/pyspark/ml/__init__.py +++ b/python/pyspark/ml/__init__.py @@ -21,5 +21,11 @@ """ from pyspark.ml.base import Estimator, Model, Transformer, UnaryTransformer from pyspark.ml.pipeline import Pipeline, PipelineModel +from pyspark.ml import classification, clustering, evaluation, feature, fpm, \ + image, pipeline, recommendation, regression, stat, tuning, util, linalg, param -__all__ = ["Transformer", "UnaryTransformer", "Estimator", "Model", "Pipeline", "PipelineModel"] +__all__ = [ + "Transformer", "UnaryTransformer", "Estimator", "Model", "Pipeline", "PipelineModel", + "classification", "clustering", "evaluation", "feature", "fpm", "image", + "recommendation", "regression", "stat", "tuning", "util", "linalg", "param", +] diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 96d702f844839..5f0c57ee3cc67 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -31,6 +31,8 @@ from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string from pyspark.sql import DataFrame, SparkSession +__all__ = ["ImageSchema"] + class _ImageSchema(object): """ From 1a644afbac35c204f9ad55f86999319a9ab458c6 Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Fri, 8 Jun 2018 11:18:34 -0700 Subject: [PATCH 0932/2461] [SPARK-23984][K8S] Initial Python Bindings for PySpark on K8s ## What changes were proposed in this pull request? Introducing Python Bindings for PySpark. - [x] Running PySpark Jobs - [x] Increased Default Memory Overhead value - [ ] Dependency Management for virtualenv/conda ## How was this patch tested? This patch was tested with - [x] Unit Tests - [x] Integration tests with [this addition](https://github.com/apache-spark-on-k8s/spark-integration/pull/46) ``` KubernetesSuite: - Run SparkPi with no resources - Run SparkPi with a very long application name. - Run SparkPi with a master URL without a scheme. - Run SparkPi with an argument. - Run SparkPi with custom labels, annotations, and environment variables. - Run SparkPi with a test secret mounted into the driver and executor pods - Run extraJVMOptions check on driver - Run SparkRemoteFileTest using a remote data file - Run PySpark on simple pi.py example - Run PySpark with Python2 to test a pyfiles example - Run PySpark with Python3 to test a pyfiles example Run completed in 4 minutes, 28 seconds. Total number of tests run: 11 Suites: completed 2, aborted 0 Tests: succeeded 11, failed 0, canceled 0, ignored 0, pending 0 All tests passed. ``` Author: Ilan Filonenko Author: Ilan Filonenko Closes #21092 from ifilonenko/master. --- bin/docker-image-tool.sh | 23 +++- .../org/apache/spark/deploy/SparkSubmit.scala | 14 ++- docs/running-on-kubernetes.md | 16 ++- .../src/main/python/py_container_checks.py | 32 ++++++ examples/src/main/python/pyfiles.py | 38 ++++++ .../org/apache/spark/deploy/k8s/Config.scala | 40 +++++++ .../apache/spark/deploy/k8s/Constants.scala | 7 +- .../spark/deploy/k8s/KubernetesConf.scala | 62 +++++++--- .../spark/deploy/k8s/KubernetesUtils.scala | 2 +- .../k8s/features/BasicDriverFeatureStep.scala | 14 +-- .../features/BasicExecutorFeatureStep.scala | 3 +- .../bindings/JavaDriverFeatureStep.scala | 44 +++++++ .../bindings/PythonDriverFeatureStep.scala | 73 ++++++++++++ .../submit/KubernetesClientApplication.scala | 16 ++- .../k8s/submit/KubernetesDriverBuilder.scala | 39 +++++-- .../deploy/k8s/submit/MainAppResource.scala | 5 + .../k8s/KubernetesExecutorBuilder.scala | 22 ++-- .../deploy/k8s/KubernetesConfSuite.scala | 66 +++++++++-- .../BasicDriverFeatureStepSuite.scala | 58 +++++++++- .../BasicExecutorFeatureStepSuite.scala | 9 +- ...ubernetesCredentialsFeatureStepSuite.scala | 9 +- .../DriverServiceFeatureStepSuite.scala | 18 ++- .../features/EnvSecretsFeatureStepSuite.scala | 3 +- .../features/LocalDirsFeatureStepSuite.scala | 3 +- .../MountSecretsFeatureStepSuite.scala | 3 +- .../bindings/JavaDriverFeatureStepSuite.scala | 60 ++++++++++ .../PythonDriverFeatureStepSuite.scala | 108 ++++++++++++++++++ .../spark/deploy/k8s/submit/ClientSuite.scala | 3 +- .../submit/KubernetesDriverBuilderSuite.scala | 78 +++++++++++-- .../k8s/KubernetesExecutorBuilderSuite.scala | 6 +- .../spark/bindings/python/Dockerfile | 39 +++++++ .../src/main/dockerfiles/spark/entrypoint.sh | 30 +++++ 32 files changed, 842 insertions(+), 101 deletions(-) create mode 100644 examples/src/main/python/py_container_checks.py create mode 100644 examples/src/main/python/pyfiles.py create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala create mode 100644 resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index f090240065bf1..a871ab5d448c3 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -63,12 +63,20 @@ function build { if [ ! -d "$IMG_PATH" ]; then error "Cannot find docker image. This script must be run from a runnable distribution of Apache Spark." fi - - local DOCKERFILE=${DOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} + local BINDING_BUILD_ARGS=( + --build-arg + base_img=$(image_ref spark) + ) + local BASEDOCKERFILE=${BASEDOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} + local PYDOCKERFILE=${PYDOCKERFILE:-"$IMG_PATH/spark/bindings/python/Dockerfile"} docker build "${BUILD_ARGS[@]}" \ -t $(image_ref spark) \ - -f "$DOCKERFILE" . + -f "$BASEDOCKERFILE" . + + docker build "${BINDING_BUILD_ARGS[@]}" \ + -t $(image_ref spark-py) \ + -f "$PYDOCKERFILE" . } function push { @@ -86,7 +94,8 @@ Commands: push Push a pre-built image to a registry. Requires a repository address to be provided. Options: - -f file Dockerfile to build. By default builds the Dockerfile shipped with Spark. + -f file Dockerfile to build for JVM based Jobs. By default builds the Dockerfile shipped with Spark. + -p file Dockerfile with Python baked in. By default builds the Dockerfile shipped with Spark. -r repo Repository address. -t tag Tag to apply to the built image, or to identify the image to be pushed. -m Use minikube's Docker daemon. @@ -116,12 +125,14 @@ fi REPO= TAG= -DOCKERFILE= +BASEDOCKERFILE= +PYDOCKERFILE= while getopts f:mr:t: option do case "${option}" in - f) DOCKERFILE=${OPTARG};; + f) BASEDOCKERFILE=${OPTARG};; + p) PYDOCKERFILE=${OPTARG};; r) REPO=${OPTARG};; t) TAG=${OPTARG};; m) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index a46af26feb061..e83d82f847c61 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -285,8 +285,6 @@ private[spark] class SparkSubmit extends Logging { case (STANDALONE, CLUSTER) if args.isR => error("Cluster deploy mode is currently not supported for R " + "applications on standalone clusters.") - case (KUBERNETES, _) if args.isPython => - error("Python applications are currently not supported for Kubernetes.") case (KUBERNETES, _) if args.isR => error("R applications are currently not supported for Kubernetes.") case (LOCAL, CLUSTER) => @@ -694,9 +692,17 @@ private[spark] class SparkSubmit extends Logging { if (isKubernetesCluster) { childMainClass = KUBERNETES_CLUSTER_SUBMIT_CLASS if (args.primaryResource != SparkLauncher.NO_RESOURCE) { - childArgs ++= Array("--primary-java-resource", args.primaryResource) + if (args.isPython) { + childArgs ++= Array("--primary-py-file", args.primaryResource) + childArgs ++= Array("--main-class", "org.apache.spark.deploy.PythonRunner") + if (args.pyFiles != null) { + childArgs ++= Array("--other-py-files", args.pyFiles) + } + } else { + childArgs ++= Array("--primary-java-resource", args.primaryResource) + childArgs ++= Array("--main-class", args.mainClass) + } } - childArgs ++= Array("--main-class", args.mainClass) if (args.childArgs != null) { args.childArgs.foreach { arg => childArgs += ("--arg", arg) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 4eac9bd9032e4..408e446ea4822 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -270,7 +270,6 @@ future versions of the spark-kubernetes integration. Some of these include: -* PySpark * R * Dynamic Executor Scaling * Local File Dependency Management @@ -631,4 +630,19 @@ specific to Spark on Kubernetes. spark.kubernetes.executor.secrets.ENV_VAR=spark-secret:key. + + spark.kubernetes.memoryOverheadFactor + 0.1 + + This sets the Memory Overhead Factor that will allocate memory to non-JVM memory, which includes off-heap memory allocations, non-JVM tasks, and various systems processes. For JVM-based jobs this value will default to 0.10 and 0.40 for non-JVM jobs. + This is done as non-JVM tasks need more non-JVM heap space and such tasks commonly fail with "Memory Overhead Exceeded" errors. This prempts this error with a higher default. + + + + spark.kubernetes.pyspark.pythonversion + "2" + + This sets the major Python version of the docker image used to run the driver and executor containers. Can either be 2 or 3. + + diff --git a/examples/src/main/python/py_container_checks.py b/examples/src/main/python/py_container_checks.py new file mode 100644 index 0000000000000..f6b3be2806c82 --- /dev/null +++ b/examples/src/main/python/py_container_checks.py @@ -0,0 +1,32 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import sys + + +def version_check(python_env, major_python_version): + """ + These are various tests to test the Python container image. + This file will be distributed via --py-files in the e2e tests. + """ + env_version = os.environ.get('PYSPARK_PYTHON') + print("Python runtime version check is: " + + str(sys.version_info[0] == major_python_version)) + + print("Python environment version check is: " + + str(env_version == python_env)) diff --git a/examples/src/main/python/pyfiles.py b/examples/src/main/python/pyfiles.py new file mode 100644 index 0000000000000..4193654b49a12 --- /dev/null +++ b/examples/src/main/python/pyfiles.py @@ -0,0 +1,38 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark.sql import SparkSession + + +if __name__ == "__main__": + """ + Usage: pyfiles [major_python_version] + """ + spark = SparkSession \ + .builder \ + .appName("PyFilesTest") \ + .getOrCreate() + + from py_container_checks import version_check + # Begin of Python container checks + version_check(sys.argv[1], 2 if sys.argv[1] == "python" else 3) + + spark.stop() diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 560dedf431b08..590deaa72e7ee 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -117,6 +117,28 @@ private[spark] object Config extends Logging { .stringConf .createWithDefault("spark") + val KUBERNETES_PYSPARK_PY_FILES = + ConfigBuilder("spark.kubernetes.python.pyFiles") + .doc("The PyFiles that are distributed via client arguments") + .internal() + .stringConf + .createOptional + + val KUBERNETES_PYSPARK_MAIN_APP_RESOURCE = + ConfigBuilder("spark.kubernetes.python.mainAppResource") + .doc("The main app resource for pyspark jobs") + .internal() + .stringConf + .createOptional + + val KUBERNETES_PYSPARK_APP_ARGS = + ConfigBuilder("spark.kubernetes.python.appArgs") + .doc("The app arguments for PySpark Jobs") + .internal() + .stringConf + .createOptional + + val KUBERNETES_ALLOCATION_BATCH_SIZE = ConfigBuilder("spark.kubernetes.allocation.batch.size") .doc("Number of pods to launch at once in each round of executor allocation.") @@ -154,6 +176,24 @@ private[spark] object Config extends Logging { .checkValue(interval => interval > 0, s"Logging interval must be a positive time value.") .createWithDefaultString("1s") + val MEMORY_OVERHEAD_FACTOR = + ConfigBuilder("spark.kubernetes.memoryOverheadFactor") + .doc("This sets the Memory Overhead Factor that will allocate memory to non-JVM jobs " + + "which in the case of JVM tasks will default to 0.10 and 0.40 for non-JVM jobs") + .doubleConf + .checkValue(mem_overhead => mem_overhead >= 0 && mem_overhead < 1, + "Ensure that memory overhead is a double between 0 --> 1.0") + .createWithDefault(0.1) + + val PYSPARK_MAJOR_PYTHON_VERSION = + ConfigBuilder("spark.kubernetes.pyspark.pythonversion") + .doc("This sets the major Python version. Either 2 or 3. (Python2 or Python3)") + .stringConf + .checkValue(pv => List("2", "3").contains(pv), + "Ensure that major Python version is either Python2 or Python3") + .createWithDefault("2") + + val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = "spark.kubernetes.authenticate.submission" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 8da5f24044aad..69bd03d1eda6f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -71,9 +71,14 @@ private[spark] object Constants { val SPARK_CONF_FILE_NAME = "spark.properties" val SPARK_CONF_PATH = s"$SPARK_CONF_DIR_INTERNAL/$SPARK_CONF_FILE_NAME" + // BINDINGS + val ENV_PYSPARK_PRIMARY = "PYSPARK_PRIMARY" + val ENV_PYSPARK_FILES = "PYSPARK_FILES" + val ENV_PYSPARK_ARGS = "PYSPARK_APP_ARGS" + val ENV_PYSPARK_MAJOR_PYTHON_VERSION = "PYSPARK_MAJOR_PYTHON_VERSION" + // Miscellaneous val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc" val DRIVER_CONTAINER_NAME = "spark-kubernetes-driver" - val MEMORY_OVERHEAD_FACTOR = 0.10 val MEMORY_OVERHEAD_MIN_MIB = 384L } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index 5a944187a7096..b0ccaa36b01ed 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -16,14 +16,17 @@ */ package org.apache.spark.deploy.k8s +import scala.collection.mutable + import io.fabric8.kubernetes.api.model.{LocalObjectReference, LocalObjectReferenceBuilder, Pod} import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.{JavaMainAppResource, MainAppResource} +import org.apache.spark.deploy.k8s.submit._ import org.apache.spark.internal.config.ConfigEntry + private[spark] sealed trait KubernetesRoleSpecificConf /* @@ -55,7 +58,8 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( roleAnnotations: Map[String, String], roleSecretNamesToMountPaths: Map[String, String], roleSecretEnvNamesToKeyRefs: Map[String, String], - roleEnvs: Map[String, String]) { + roleEnvs: Map[String, String], + sparkFiles: Seq[String]) { def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE) @@ -64,10 +68,14 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( .map(str => str.split(",").toSeq) .getOrElse(Seq.empty[String]) - def sparkFiles(): Seq[String] = sparkConf - .getOption("spark.files") - .map(str => str.split(",").toSeq) - .getOrElse(Seq.empty[String]) + def pyFiles(): Option[String] = sparkConf + .get(KUBERNETES_PYSPARK_PY_FILES) + + def pySparkMainResource(): Option[String] = sparkConf + .get(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE) + + def pySparkPythonVersion(): String = sparkConf + .get(PYSPARK_MAJOR_PYTHON_VERSION) def imagePullPolicy(): String = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) @@ -102,17 +110,30 @@ private[spark] object KubernetesConf { appId: String, mainAppResource: Option[MainAppResource], mainClass: String, - appArgs: Array[String]): KubernetesConf[KubernetesDriverSpecificConf] = { + appArgs: Array[String], + maybePyFiles: Option[String]): KubernetesConf[KubernetesDriverSpecificConf] = { val sparkConfWithMainAppJar = sparkConf.clone() + val additionalFiles = mutable.ArrayBuffer.empty[String] mainAppResource.foreach { - case JavaMainAppResource(res) => - val previousJars = sparkConf - .getOption("spark.jars") - .map(_.split(",")) - .getOrElse(Array.empty) - if (!previousJars.contains(res)) { - sparkConfWithMainAppJar.setJars(previousJars ++ Seq(res)) - } + case JavaMainAppResource(res) => + val previousJars = sparkConf + .getOption("spark.jars") + .map(_.split(",")) + .getOrElse(Array.empty) + if (!previousJars.contains(res)) { + sparkConfWithMainAppJar.setJars(previousJars ++ Seq(res)) + } + // The function of this outer match is to account for multiple nonJVM + // bindings that will all have increased MEMORY_OVERHEAD_FACTOR to 0.4 + case nonJVM: NonJVMResource => + nonJVM match { + case PythonMainAppResource(res) => + additionalFiles += res + maybePyFiles.foreach{maybePyFiles => + additionalFiles.appendAll(maybePyFiles.split(","))} + sparkConfWithMainAppJar.set(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE, res) + } + sparkConfWithMainAppJar.setIfMissing(MEMORY_OVERHEAD_FACTOR, 0.4) } val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( @@ -135,6 +156,11 @@ private[spark] object KubernetesConf { val driverEnvs = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_DRIVER_ENV_PREFIX) + val sparkFiles = sparkConf + .getOption("spark.files") + .map(str => str.split(",").toSeq) + .getOrElse(Seq.empty[String]) ++ additionalFiles + KubernetesConf( sparkConfWithMainAppJar, KubernetesDriverSpecificConf(mainAppResource, mainClass, appName, appArgs), @@ -144,7 +170,8 @@ private[spark] object KubernetesConf { driverAnnotations, driverSecretNamesToMountPaths, driverSecretEnvNamesToKeyRefs, - driverEnvs) + driverEnvs, + sparkFiles) } def createExecutorConf( @@ -186,6 +213,7 @@ private[spark] object KubernetesConf { executorAnnotations, executorMountSecrets, executorEnvSecrets, - executorEnv) + executorEnv, + Seq.empty[String]) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index ee629068ad90d..593fb531a004d 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -52,7 +52,7 @@ private[spark] object KubernetesUtils { } } - private def resolveFileUri(uri: String): String = { + def resolveFileUri(uri: String): String = { val fileUri = Utils.resolveURI(uri) val fileScheme = Option(fileUri.getScheme).getOrElse("file") fileScheme match { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala index 07bdccbe0479d..143dc8a12304e 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala @@ -25,8 +25,8 @@ import org.apache.spark.SparkException import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit._ import org.apache.spark.internal.config._ -import org.apache.spark.launcher.SparkLauncher private[spark] class BasicDriverFeatureStep( conf: KubernetesConf[KubernetesDriverSpecificConf]) @@ -48,7 +48,8 @@ private[spark] class BasicDriverFeatureStep( private val driverMemoryMiB = conf.get(DRIVER_MEMORY) private val memoryOverheadMiB = conf .get(DRIVER_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) + .getOrElse(math.max((conf.get(MEMORY_OVERHEAD_FACTOR) * driverMemoryMiB).toInt, + MEMORY_OVERHEAD_MIN_MIB)) private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB override def configurePod(pod: SparkPod): SparkPod = { @@ -88,13 +89,6 @@ private[spark] class BasicDriverFeatureStep( .addToRequests("memory", driverMemoryQuantity) .addToLimits("memory", driverMemoryQuantity) .endResources() - .addToArgs("driver") - .addToArgs("--properties-file", SPARK_CONF_PATH) - .addToArgs("--class", conf.roleSpecificConf.mainClass) - // The user application jar is merged into the spark.jars list and managed through that - // property, so there is no need to reference it explicitly here. - .addToArgs(SparkLauncher.NO_RESOURCE) - .addToArgs(conf.roleSpecificConf.appArgs: _*) .build() val driverPod = new PodBuilder(pod.pod) @@ -122,7 +116,7 @@ private[spark] class BasicDriverFeatureStep( val resolvedSparkJars = KubernetesUtils.resolveFileUrisAndPath( conf.sparkJars()) val resolvedSparkFiles = KubernetesUtils.resolveFileUrisAndPath( - conf.sparkFiles()) + conf.sparkFiles) if (resolvedSparkJars.nonEmpty) { additionalProps.put("spark.jars", resolvedSparkJars.mkString(",")) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index 529069d3b8a0c..91c54a9776982 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -54,7 +54,8 @@ private[spark] class BasicExecutorFeatureStep( private val memoryOverheadMiB = kubernetesConf .get(EXECUTOR_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt, + .getOrElse(math.max( + (kubernetesConf.get(MEMORY_OVERHEAD_FACTOR) * executorMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala new file mode 100644 index 0000000000000..f52ec9fdc677e --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features.bindings + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Constants.SPARK_CONF_PATH +import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep +import org.apache.spark.launcher.SparkLauncher + +private[spark] class JavaDriverFeatureStep( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val withDriverArgs = new ContainerBuilder(pod.container) + .addToArgs("driver") + .addToArgs("--properties-file", SPARK_CONF_PATH) + .addToArgs("--class", kubernetesConf.roleSpecificConf.mainClass) + // The user application jar is merged into the spark.jars list and managed through that + // property, so there is no need to reference it explicitly here. + .addToArgs(SparkLauncher.NO_RESOURCE) + .addToArgs(kubernetesConf.roleSpecificConf.appArgs: _*) + .build() + SparkPod(pod.pod, withDriverArgs) + } + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala new file mode 100644 index 0000000000000..c20bcac1f8987 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features.bindings + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, HasMetadata} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep + +private[spark] class PythonDriverFeatureStep( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val roleConf = kubernetesConf.roleSpecificConf + require(roleConf.mainAppResource.isDefined, "PySpark Main Resource must be defined") + val maybePythonArgs = Option(roleConf.appArgs).filter(_.nonEmpty).map( + pyArgs => + new EnvVarBuilder() + .withName(ENV_PYSPARK_ARGS) + .withValue(pyArgs.mkString(",")) + .build()) + val maybePythonFiles = kubernetesConf.pyFiles().map( + // Dilineation by ":" is to append the PySpark Files to the PYTHONPATH + // of the respective PySpark pod + pyFiles => + new EnvVarBuilder() + .withName(ENV_PYSPARK_FILES) + .withValue(KubernetesUtils.resolveFileUrisAndPath(pyFiles.split(",")) + .mkString(":")) + .build()) + val envSeq = + Seq(new EnvVarBuilder() + .withName(ENV_PYSPARK_PRIMARY) + .withValue(KubernetesUtils.resolveFileUri(kubernetesConf.pySparkMainResource().get)) + .build(), + new EnvVarBuilder() + .withName(ENV_PYSPARK_MAJOR_PYTHON_VERSION) + .withValue(kubernetesConf.pySparkPythonVersion()) + .build()) + val pythonEnvs = envSeq ++ + maybePythonArgs.toSeq ++ + maybePythonFiles.toSeq + + val withPythonPrimaryContainer = new ContainerBuilder(pod.container) + .addAllToEnv(pythonEnvs.asJava) + .addToArgs("driver-py") + .addToArgs("--properties-file", SPARK_CONF_PATH) + .addToArgs("--class", roleConf.mainClass) + .build() + + SparkPod(pod.pod, withPythonPrimaryContainer) + } + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index a97f5650fb869..eaff47205dbbc 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -39,11 +39,13 @@ import org.apache.spark.util.Utils * @param mainAppResource the main application resource if any * @param mainClass the main class of the application to run * @param driverArgs arguments to the driver + * @param maybePyFiles additional Python files via --py-files */ private[spark] case class ClientArguments( mainAppResource: Option[MainAppResource], mainClass: String, - driverArgs: Array[String]) + driverArgs: Array[String], + maybePyFiles: Option[String]) private[spark] object ClientArguments { @@ -51,10 +53,15 @@ private[spark] object ClientArguments { var mainAppResource: Option[MainAppResource] = None var mainClass: Option[String] = None val driverArgs = mutable.ArrayBuffer.empty[String] + var maybePyFiles : Option[String] = None args.sliding(2, 2).toList.foreach { case Array("--primary-java-resource", primaryJavaResource: String) => mainAppResource = Some(JavaMainAppResource(primaryJavaResource)) + case Array("--primary-py-file", primaryPythonResource: String) => + mainAppResource = Some(PythonMainAppResource(primaryPythonResource)) + case Array("--other-py-files", pyFiles: String) => + maybePyFiles = Some(pyFiles) case Array("--main-class", clazz: String) => mainClass = Some(clazz) case Array("--arg", arg: String) => @@ -69,7 +76,8 @@ private[spark] object ClientArguments { ClientArguments( mainAppResource, mainClass.get, - driverArgs.toArray) + driverArgs.toArray, + maybePyFiles) } } @@ -206,6 +214,7 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val kubernetesResourceNamePrefix = { s"$appName-$launchTime".toLowerCase.replaceAll("\\.", "-") } + sparkConf.set(KUBERNETES_PYSPARK_PY_FILES, clientArguments.maybePyFiles.getOrElse("")) val kubernetesConf = KubernetesConf.createDriverConf( sparkConf, appName, @@ -213,7 +222,8 @@ private[spark] class KubernetesClientApplication extends SparkApplication { kubernetesAppId, clientArguments.mainAppResource, clientArguments.mainClass, - clientArguments.driverArgs) + clientArguments.driverArgs, + clientArguments.maybePyFiles) val builder = new KubernetesDriverBuilder val namespace = kubernetesConf.namespace() // The master URL has been checked for validity already in SparkSubmit. diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index fdc5eb0d75832..5762d8245f778 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -17,7 +17,8 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} -import org.apache.spark.deploy.k8s.features._ +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeatureConfigStep, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep} private[spark] class KubernetesDriverBuilder( provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep = @@ -33,9 +34,17 @@ private[spark] class KubernetesDriverBuilder( provideEnvSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => EnvSecretsFeatureStep) = new EnvSecretsFeatureStep(_), - provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) - => LocalDirsFeatureStep = - new LocalDirsFeatureStep(_)) { + provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => LocalDirsFeatureStep) = + new LocalDirsFeatureStep(_), + provideJavaStep: ( + KubernetesConf[KubernetesDriverSpecificConf] + => JavaDriverFeatureStep) = + new JavaDriverFeatureStep(_), + providePythonStep: ( + KubernetesConf[KubernetesDriverSpecificConf] + => PythonDriverFeatureStep) = + new PythonDriverFeatureStep(_)) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]): KubernetesDriverSpec = { @@ -44,13 +53,23 @@ private[spark] class KubernetesDriverBuilder( provideCredentialsStep(kubernetesConf), provideServiceStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) - var allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { - baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) - } else baseFeatures - allFeatures = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { - allFeatures ++ Seq(provideEnvSecretsStep(kubernetesConf)) - } else allFeatures + val maybeRoleSecretNamesStep = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + Some(provideSecretsStep(kubernetesConf)) } else None + + val maybeProvideSecretsStep = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + Some(provideEnvSecretsStep(kubernetesConf)) } else None + + val bindingsStep = kubernetesConf.roleSpecificConf.mainAppResource.map { + case JavaMainAppResource(_) => + provideJavaStep(kubernetesConf) + case PythonMainAppResource(_) => + providePythonStep(kubernetesConf)}.getOrElse(provideJavaStep(kubernetesConf)) + + val allFeatures: Seq[KubernetesFeatureConfigStep] = + (baseFeatures :+ bindingsStep) ++ + maybeRoleSecretNamesStep.toSeq ++ + maybeProvideSecretsStep.toSeq var spec = KubernetesDriverSpec.initialSpec(kubernetesConf.sparkConf.getAll.toMap) for (feature <- allFeatures) { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala index cca9f4627a1f6..cbe081ae35683 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala @@ -18,4 +18,9 @@ package org.apache.spark.deploy.k8s.submit private[spark] sealed trait MainAppResource +private[spark] sealed trait NonJVMResource + private[spark] case class JavaMainAppResource(primaryResource: String) extends MainAppResource + +private[spark] case class PythonMainAppResource(primaryResource: String) + extends MainAppResource with NonJVMResource diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index d5e1de36a58df..769a0a5a63047 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster.k8s import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, KubernetesFeatureConfigStep, LocalDirsFeatureStep, MountSecretsFeatureStep} private[spark] class KubernetesExecutorBuilder( provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep = @@ -34,14 +34,20 @@ private[spark] class KubernetesExecutorBuilder( def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { - val baseFeatures = Seq(provideBasicStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) - var allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { - baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) - } else baseFeatures + val baseFeatures = Seq( + provideBasicStep(kubernetesConf), + provideLocalDirsStep(kubernetesConf)) - allFeatures = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { - allFeatures ++ Seq(provideEnvSecretsStep(kubernetesConf)) - } else allFeatures + val maybeRoleSecretNamesStep = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + Some(provideSecretsStep(kubernetesConf)) } else None + + val maybeProvideSecretsStep = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + Some(provideEnvSecretsStep(kubernetesConf)) } else None + + val allFeatures: Seq[KubernetesFeatureConfigStep] = + baseFeatures ++ + maybeRoleSecretNamesStep.toSeq ++ + maybeProvideSecretsStep.toSeq var executorPod = SparkPod.initialPod() for (feature <- allFeatures) { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala index 3d23e1cb90fd2..661f942435921 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala @@ -22,7 +22,7 @@ import io.fabric8.kubernetes.api.model.{LocalObjectReferenceBuilder, PodBuilder} import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.JavaMainAppResource +import org.apache.spark.deploy.k8s.submit._ class KubernetesConfSuite extends SparkFunSuite { @@ -56,9 +56,10 @@ class KubernetesConfSuite extends SparkFunSuite { APP_NAME, RESOURCE_NAME_PREFIX, APP_ID, - None, + mainAppResource = None, MAIN_CLASS, - APP_ARGS) + APP_ARGS, + maybePyFiles = None) assert(conf.appId === APP_ID) assert(conf.sparkConf.getAll.toMap === sparkConf.getAll.toMap) assert(conf.appResourceNamePrefix === RESOURCE_NAME_PREFIX) @@ -79,7 +80,8 @@ class KubernetesConfSuite extends SparkFunSuite { APP_ID, mainAppJar, MAIN_CLASS, - APP_ARGS) + APP_ARGS, + maybePyFiles = None) assert(kubernetesConfWithMainJar.sparkConf.get("spark.jars") .split(",") === Array("local:///opt/spark/jar1.jar", "local:///opt/spark/main.jar")) @@ -88,15 +90,59 @@ class KubernetesConfSuite extends SparkFunSuite { APP_NAME, RESOURCE_NAME_PREFIX, APP_ID, - None, + mainAppResource = None, MAIN_CLASS, - APP_ARGS) + APP_ARGS, + maybePyFiles = None) assert(kubernetesConfWithoutMainJar.sparkConf.get("spark.jars").split(",") === Array("local:///opt/spark/jar1.jar")) + assert(kubernetesConfWithoutMainJar.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.1) } - test("Resolve driver labels, annotations, secret mount paths, and envs.") { + test("Creating driver conf with a python primary file") { + val mainResourceFile = "local:///opt/spark/main.py" + val inputPyFiles = Array("local:///opt/spark/example2.py", "local:///example3.py") val sparkConf = new SparkConf(false) + .setJars(Seq("local:///opt/spark/jar1.jar")) + .set("spark.files", "local:///opt/spark/example4.py") + val mainAppResource = Some(PythonMainAppResource(mainResourceFile)) + val kubernetesConfWithMainResource = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + mainAppResource, + MAIN_CLASS, + APP_ARGS, + Some(inputPyFiles.mkString(","))) + assert(kubernetesConfWithMainResource.sparkConf.get("spark.jars").split(",") + === Array("local:///opt/spark/jar1.jar")) + assert(kubernetesConfWithMainResource.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.4) + assert(kubernetesConfWithMainResource.sparkFiles + === Array("local:///opt/spark/example4.py", mainResourceFile) ++ inputPyFiles) + } + + test("Testing explicit setting of memory overhead on non-JVM tasks") { + val sparkConf = new SparkConf(false) + .set(MEMORY_OVERHEAD_FACTOR, 0.3) + + val mainResourceFile = "local:///opt/spark/main.py" + val mainAppResource = Some(PythonMainAppResource(mainResourceFile)) + val conf = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + mainAppResource, + MAIN_CLASS, + APP_ARGS, + None) + assert(conf.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.3) + } + + test("Resolve driver labels, annotations, secret mount paths, envs, and memory overhead") { + val sparkConf = new SparkConf(false) + .set(MEMORY_OVERHEAD_FACTOR, 0.3) CUSTOM_LABELS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$key", value) } @@ -118,9 +164,10 @@ class KubernetesConfSuite extends SparkFunSuite { APP_NAME, RESOURCE_NAME_PREFIX, APP_ID, - None, + mainAppResource = None, MAIN_CLASS, - APP_ARGS) + APP_ARGS, + maybePyFiles = None) assert(conf.roleLabels === Map( SPARK_APP_ID_LABEL -> APP_ID, SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) ++ @@ -129,6 +176,7 @@ class KubernetesConfSuite extends SparkFunSuite { assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) assert(conf.roleSecretEnvNamesToKeyRefs === SECRET_ENV_VARS) assert(conf.roleEnvs === CUSTOM_ENVS) + assert(conf.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.3) } test("Basic executor translated fields.") { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala index b2813d8b3265d..04b909db9d9f3 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -24,6 +24,8 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.JavaMainAppResource +import org.apache.spark.deploy.k8s.submit.PythonMainAppResource class BasicDriverFeatureStepSuite extends SparkFunSuite { @@ -33,6 +35,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" private val APP_NAME = "spark-test" private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" + private val PY_MAIN_CLASS = "org.apache.spark.deploy.PythonRunner" private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") private val CUSTOM_ANNOTATION_KEY = "customAnnotation" private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" @@ -60,7 +63,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { val kubernetesConf = KubernetesConf( sparkConf, KubernetesDriverSpecificConf( - None, + Some(JavaMainAppResource("")), APP_NAME, MAIN_CLASS, APP_ARGS), @@ -70,7 +73,8 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { DRIVER_ANNOTATIONS, Map.empty, Map.empty, - DRIVER_ENVS) + DRIVER_ENVS, + Seq.empty[String]) val featureStep = new BasicDriverFeatureStep(kubernetesConf) val basePod = SparkPod.initialPod() @@ -110,7 +114,6 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { assert(driverPodMetadata.getLabels.asScala === DRIVER_LABELS) assert(driverPodMetadata.getAnnotations.asScala === DRIVER_ANNOTATIONS) assert(configuredPod.pod.getSpec.getRestartPolicy === "Never") - val expectedSparkConf = Map( KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", "spark.app.id" -> APP_ID, @@ -119,6 +122,50 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { assert(featureStep.getAdditionalPodSystemProperties() === expectedSparkConf) } + test("Check appropriate entrypoint rerouting for various bindings") { + val javaSparkConf = new SparkConf() + .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "4g") + .set(CONTAINER_IMAGE, "spark-driver:latest") + val pythonSparkConf = new SparkConf() + .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "4g") + .set(CONTAINER_IMAGE, "spark-driver:latest") + val javaKubernetesConf = KubernetesConf( + javaSparkConf, + KubernetesDriverSpecificConf( + Some(JavaMainAppResource("")), + APP_NAME, + PY_MAIN_CLASS, + APP_ARGS), + RESOURCE_NAME_PREFIX, + APP_ID, + DRIVER_LABELS, + DRIVER_ANNOTATIONS, + Map.empty, + Map.empty, + DRIVER_ENVS, + Seq.empty[String]) + val pythonKubernetesConf = KubernetesConf( + pythonSparkConf, + KubernetesDriverSpecificConf( + Some(PythonMainAppResource("")), + APP_NAME, + PY_MAIN_CLASS, + APP_ARGS), + RESOURCE_NAME_PREFIX, + APP_ID, + DRIVER_LABELS, + DRIVER_ANNOTATIONS, + Map.empty, + Map.empty, + DRIVER_ENVS, + Seq.empty[String]) + val javaFeatureStep = new BasicDriverFeatureStep(javaKubernetesConf) + val pythonFeatureStep = new BasicDriverFeatureStep(pythonKubernetesConf) + val basePod = SparkPod.initialPod() + val configuredJavaPod = javaFeatureStep.configurePod(basePod) + val configuredPythonPod = pythonFeatureStep.configurePod(basePod) + } + test("Additional system properties resolve jars and set cluster-mode confs.") { val allJars = Seq("local:///opt/spark/jar1.jar", "hdfs:///opt/spark/jar2.jar") val allFiles = Seq("https://localhost:9000/file1.txt", "local:///opt/spark/file2.txt") @@ -130,7 +177,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { val kubernetesConf = KubernetesConf( sparkConf, KubernetesDriverSpecificConf( - None, + Some(JavaMainAppResource("")), APP_NAME, MAIN_CLASS, APP_ARGS), @@ -140,7 +187,8 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { DRIVER_ANNOTATIONS, Map.empty, Map.empty, - Map.empty) + DRIVER_ENVS, + allFiles) val step = new BasicDriverFeatureStep(kubernetesConf) val additionalProperties = step.getAdditionalPodSystemProperties() val expectedSparkConf = Map( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index 9182134b3337c..f06030aa55c0c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -88,7 +88,8 @@ class BasicExecutorFeatureStepSuite ANNOTATIONS, Map.empty, Map.empty, - Map.empty)) + Map.empty, + Seq.empty[String])) val executor = step.configurePod(SparkPod.initialPod()) // The executor pod name and default labels. @@ -126,7 +127,8 @@ class BasicExecutorFeatureStepSuite ANNOTATIONS, Map.empty, Map.empty, - Map.empty)) + Map.empty, + Seq.empty[String])) assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63) } @@ -145,7 +147,8 @@ class BasicExecutorFeatureStepSuite ANNOTATIONS, Map.empty, Map.empty, - Map("qux" -> "quux"))) + Map("qux" -> "quux"), + Seq.empty[String])) val executor = step.configurePod(SparkPod.initialPod()) checkEnv(executor, diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala index f81894f8055f1..7cea83591f3e8 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala @@ -60,7 +60,8 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) assert(kubernetesCredentialsStep.getAdditionalPodSystemProperties().isEmpty) @@ -90,7 +91,8 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) @@ -127,7 +129,8 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() val expectedSparkConf = Map( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala index f265522a8823a..77d38bf19cd10 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala @@ -66,7 +66,8 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, - Map.empty)) + Map.empty, + Seq.empty[String])) assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod()) assert(configurationStep.getAdditionalKubernetesResources().size === 1) assert(configurationStep.getAdditionalKubernetesResources().head.isInstanceOf[Service]) @@ -96,7 +97,8 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, - Map.empty)) + Map.empty, + Seq.empty[String])) val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX val expectedHostName = s"$expectedServiceName.my-namespace.svc" @@ -116,7 +118,8 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, - Map.empty)) + Map.empty, + Seq.empty[String])) val resolvedService = configurationStep .getAdditionalKubernetesResources() .head @@ -145,7 +148,8 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, - Map.empty), + Map.empty, + Seq.empty[String]), clock) val driverService = configurationStep .getAdditionalKubernetesResources() @@ -171,7 +175,8 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, - Map.empty), + Map.empty, + Seq.empty[String]), clock) fail("The driver bind address should not be allowed.") } catch { @@ -195,7 +200,8 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, - Map.empty), + Map.empty, + Seq.empty[String]), clock) fail("The driver host address should not be allowed.") } catch { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala index 8b0b2d0739c76..af6b35eae484a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala @@ -44,7 +44,8 @@ class EnvSecretsFeatureStepSuite extends SparkFunSuite{ Map.empty, Map.empty, envVarsToKeys, - Map.empty) + Map.empty, + Seq.empty[String]) val step = new EnvSecretsFeatureStep(kubernetesConf) val driverContainerWithEnvSecrets = step.configurePod(baseDriverPod).container diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala index 2542a02d37766..bd6ce4b42fc8e 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala @@ -44,7 +44,8 @@ class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Seq.empty[String]) } test("Resolve to default local dir if neither env nor configuration are set") { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala index 9155793774123..eff75b8a15daa 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala @@ -42,7 +42,8 @@ class MountSecretsFeatureStepSuite extends SparkFunSuite { Map.empty, secretNamesToMountPaths, Map.empty, - Map.empty) + Map.empty, + Seq.empty[String]) val step = new MountSecretsFeatureStep(kubernetesConf) val driverPodWithSecretsMounted = step.configurePod(baseDriverPod).pod diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala new file mode 100644 index 0000000000000..0f2bf2fa1d9b5 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features.bindings + +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.PythonMainAppResource + +class JavaDriverFeatureStepSuite extends SparkFunSuite { + + test("Java Step modifies container correctly") { + val baseDriverPod = SparkPod.initialPod() + val sparkConf = new SparkConf(false) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + Some(PythonMainAppResource("local:///main.jar")), + "test-class", + "java-runner", + Seq("5 7")), + appResourceNamePrefix = "", + appId = "", + roleLabels = Map.empty, + roleAnnotations = Map.empty, + roleSecretNamesToMountPaths = Map.empty, + roleSecretEnvNamesToKeyRefs = Map.empty, + roleEnvs = Map.empty, + sparkFiles = Seq.empty[String]) + + val step = new JavaDriverFeatureStep(kubernetesConf) + val driverPod = step.configurePod(baseDriverPod).pod + val driverContainerwithJavaStep = step.configurePod(baseDriverPod).container + assert(driverContainerwithJavaStep.getArgs.size === 7) + val args = driverContainerwithJavaStep + .getArgs.asScala + assert(args === List( + "driver", + "--properties-file", SPARK_CONF_PATH, + "--class", "test-class", + "spark-internal", "5 7")) + + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala new file mode 100644 index 0000000000000..a1f9a5d9e264e --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features.bindings + +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.PythonMainAppResource + +class PythonDriverFeatureStepSuite extends SparkFunSuite { + + test("Python Step modifies container correctly") { + val expectedMainResource = "/main.py" + val mainResource = "local:///main.py" + val pyFiles = Seq("local:///example2.py", "local:///example3.py") + val expectedPySparkFiles = + "/example2.py:/example3.py" + val baseDriverPod = SparkPod.initialPod() + val sparkConf = new SparkConf(false) + .set(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE, mainResource) + .set(KUBERNETES_PYSPARK_PY_FILES, pyFiles.mkString(",")) + .set("spark.files", "local:///example.py") + .set(PYSPARK_MAJOR_PYTHON_VERSION, "2") + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + Some(PythonMainAppResource("local:///main.py")), + "test-app", + "python-runner", + Seq("5 7")), + appResourceNamePrefix = "", + appId = "", + roleLabels = Map.empty, + roleAnnotations = Map.empty, + roleSecretNamesToMountPaths = Map.empty, + roleSecretEnvNamesToKeyRefs = Map.empty, + roleEnvs = Map.empty, + sparkFiles = Seq.empty[String]) + + val step = new PythonDriverFeatureStep(kubernetesConf) + val driverPod = step.configurePod(baseDriverPod).pod + val driverContainerwithPySpark = step.configurePod(baseDriverPod).container + assert(driverContainerwithPySpark.getEnv.size === 4) + val envs = driverContainerwithPySpark + .getEnv + .asScala + .map(env => (env.getName, env.getValue)) + .toMap + assert(envs(ENV_PYSPARK_PRIMARY) === expectedMainResource) + assert(envs(ENV_PYSPARK_FILES) === expectedPySparkFiles) + assert(envs(ENV_PYSPARK_ARGS) === "5 7") + assert(envs(ENV_PYSPARK_MAJOR_PYTHON_VERSION) === "2") + } + test("Python Step testing empty pyfiles") { + val mainResource = "local:///main.py" + val baseDriverPod = SparkPod.initialPod() + val sparkConf = new SparkConf(false) + .set(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE, mainResource) + .set(PYSPARK_MAJOR_PYTHON_VERSION, "3") + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + Some(PythonMainAppResource("local:///main.py")), + "test-class-py", + "python-runner", + Seq.empty[String]), + appResourceNamePrefix = "", + appId = "", + roleLabels = Map.empty, + roleAnnotations = Map.empty, + roleSecretNamesToMountPaths = Map.empty, + roleSecretEnvNamesToKeyRefs = Map.empty, + roleEnvs = Map.empty, + sparkFiles = Seq.empty[String]) + val step = new PythonDriverFeatureStep(kubernetesConf) + val driverContainerwithPySpark = step.configurePod(baseDriverPod).container + val args = driverContainerwithPySpark + .getArgs.asScala + assert(driverContainerwithPySpark.getArgs.size === 5) + assert(args === List( + "driver-py", + "--properties-file", SPARK_CONF_PATH, + "--class", "test-class-py")) + val envs = driverContainerwithPySpark + .getEnv + .asScala + .map(env => (env.getName, env.getValue)) + .toMap + assert(envs(ENV_PYSPARK_MAJOR_PYTHON_VERSION) === "3") + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index 0775338098a13..a8a8218c621ea 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -143,7 +143,8 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Seq.empty[String]) when(driverBuilder.buildFromFeatures(kubernetesConf)).thenReturn(BUILT_KUBERNETES_SPEC) when(kubernetesClient.pods()).thenReturn(podOperations) when(podOperations.withName(POD_NAME)).thenReturn(namedPods) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index cb724068ea4f3..4e8c300543430 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf} import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep} class KubernetesDriverBuilderSuite extends SparkFunSuite { @@ -27,6 +28,8 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val SERVICE_STEP_TYPE = "service" private val LOCAL_DIRS_STEP_TYPE = "local-dirs" private val SECRETS_STEP_TYPE = "mount-secrets" + private val JAVA_STEP_TYPE = "java-bindings" + private val PYSPARK_STEP_TYPE = "pyspark-bindings" private val ENV_SECRETS_STEP_TYPE = "env-secrets" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( @@ -44,6 +47,12 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + private val javaStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + JAVA_STEP_TYPE, classOf[JavaDriverFeatureStep]) + + private val pythonStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + PYSPARK_STEP_TYPE, classOf[PythonDriverFeatureStep]) + private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) @@ -54,13 +63,15 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { _ => serviceStep, _ => secretsStep, _ => envSecretsStep, - _ => localDirsStep) + _ => localDirsStep, + _ => javaStep, + _ => pythonStep) test("Apply fundamental steps all the time.") { val conf = KubernetesConf( new SparkConf(false), KubernetesDriverSpecificConf( - None, + Some(JavaMainAppResource("example.jar")), "test-app", "main", Seq.empty), @@ -70,13 +81,15 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, CREDENTIALS_STEP_TYPE, SERVICE_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE) + LOCAL_DIRS_STEP_TYPE, + JAVA_STEP_TYPE) } test("Apply secrets step if secrets are present.") { @@ -93,7 +106,8 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map("secret" -> "secretMountPath"), Map("EnvName" -> "SecretName:secretKey"), - Map.empty) + Map.empty, + Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, @@ -101,8 +115,58 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { SERVICE_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, SECRETS_STEP_TYPE, - ENV_SECRETS_STEP_TYPE - ) + ENV_SECRETS_STEP_TYPE, + JAVA_STEP_TYPE) + } + + test("Apply Java step if main resource is none.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + None, + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + JAVA_STEP_TYPE) + } + + test("Apply Python step if main resource is python.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + Some(PythonMainAppResource("example.py")), + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + PYSPARK_STEP_TYPE) } private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index 753cd30a237f3..a6bc8bce32926 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -54,7 +54,8 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE) } @@ -70,7 +71,8 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map("secret" -> "secretMountPath"), Map("secret-name" -> "secret-key"), - Map.empty) + Map.empty, + Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile new file mode 100644 index 0000000000000..72bb9620b45de --- /dev/null +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +ARG base_img +FROM $base_img +WORKDIR / +RUN mkdir ${SPARK_HOME}/python +COPY python/lib ${SPARK_HOME}/python/lib +# TODO: Investigate running both pip and pip3 via virtualenvs +RUN apk add --no-cache python && \ + apk add --no-cache python3 && \ + python -m ensurepip && \ + python3 -m ensurepip && \ + # We remove ensurepip since it adds no functionality since pip is + # installed on the image and it just takes up 1.6MB on the image + rm -r /usr/lib/python*/ensurepip && \ + pip install --upgrade pip setuptools && \ + # You may install with python3 packages by using pip3.6 + # Removed the .cache to save space + rm -r /root/.cache + +ENV PYTHONPATH ${SPARK_HOME}/python/lib/pyspark.zip:${SPARK_HOME}/python/lib/py4j-*.zip + +WORKDIR /opt/spark/work-dir +ENTRYPOINT [ "/opt/entrypoint.sh" ] diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 3e166116aa3fd..acdb4b1f09e0a 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -53,6 +53,28 @@ if [ -n "$SPARK_MOUNTED_FILES_DIR" ]; then cp -R "$SPARK_MOUNTED_FILES_DIR/." . fi +if [ -n "$PYSPARK_FILES" ]; then + PYTHONPATH="$PYTHONPATH:$PYSPARK_FILES" +fi + +PYSPARK_ARGS="" +if [ -n "$PYSPARK_APP_ARGS" ]; then + PYSPARK_ARGS="$PYSPARK_APP_ARGS" +fi + + +if [ "$PYSPARK_MAJOR_PYTHON_VERSION" == "2" ]; then + pyv="$(python -V 2>&1)" + export PYTHON_VERSION="${pyv:7}" + export PYSPARK_PYTHON="python" + export PYSPARK_DRIVER_PYTHON="python" +elif [ "$PYSPARK_MAJOR_PYTHON_VERSION" == "3" ]; then + pyv3="$(python3 -V 2>&1)" + export PYTHON_VERSION="${pyv3:7}" + export PYSPARK_PYTHON="python3" + export PYSPARK_DRIVER_PYTHON="python3" +fi + case "$SPARK_K8S_CMD" in driver) CMD=( @@ -62,6 +84,14 @@ case "$SPARK_K8S_CMD" in "$@" ) ;; + driver-py) + CMD=( + "$SPARK_HOME/bin/spark-submit" + --conf "spark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS" + --deploy-mode client + "$@" $PYSPARK_PRIMARY $PYSPARK_ARGS + ) + ;; executor) CMD=( From b070ded2843e88131c90cb9ef1b4f8d533f8361d Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 9 Jun 2018 01:27:51 +0700 Subject: [PATCH 0933/2461] [SPARK-17756][PYTHON][STREAMING] Workaround to avoid return type mismatch in PythonTransformFunction ## What changes were proposed in this pull request? This PR proposes to wrap the transformed rdd within `TransformFunction`. `PythonTransformFunction` looks requiring to return `JavaRDD` in `_jrdd`. https://github.com/apache/spark/blob/39e2bad6a866d27c3ca594d15e574a1da3ee84cc/python/pyspark/streaming/util.py#L67 https://github.com/apache/spark/blob/6ee28423ad1b2e6089b82af64a31d77d3552bb38/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala#L43 However, this could be `JavaPairRDD` by some APIs, for example, `zip` in PySpark's RDD API. `_jrdd` could be checked as below: ```python >>> rdd.zip(rdd)._jrdd.getClass().toString() u'class org.apache.spark.api.java.JavaPairRDD' ``` So, here, I wrapped it with `map` so that it ensures returning `JavaRDD`. ```python >>> rdd.zip(rdd).map(lambda x: x)._jrdd.getClass().toString() u'class org.apache.spark.api.java.JavaRDD' ``` I tried to elaborate some failure cases as below: ```python from pyspark.streaming import StreamingContext ssc = StreamingContext(spark.sparkContext, 10) ssc.queueStream([sc.range(10)]) \ .transform(lambda rdd: rdd.cartesian(rdd)) \ .pprint() ssc.start() ``` ```python from pyspark.streaming import StreamingContext ssc = StreamingContext(spark.sparkContext, 10) ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.cartesian(rdd)) ssc.start() ``` ```python from pyspark.streaming import StreamingContext ssc = StreamingContext(spark.sparkContext, 10) ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.zip(rdd)) ssc.start() ``` ```python from pyspark.streaming import StreamingContext ssc = StreamingContext(spark.sparkContext, 10) ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.zip(rdd).union(rdd.zip(rdd))) ssc.start() ``` ```python from pyspark.streaming import StreamingContext ssc = StreamingContext(spark.sparkContext, 10) ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.zip(rdd).coalesce(1)) ssc.start() ``` ## How was this patch tested? Unit tests were added in `python/pyspark/streaming/tests.py` and manually tested. Author: hyukjinkwon Closes #19498 from HyukjinKwon/SPARK-17756. --- python/pyspark/streaming/context.py | 2 +- python/pyspark/streaming/tests.py | 6 ++++++ python/pyspark/streaming/util.py | 11 ++++++++++- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 17c34f8a1c54c..dd924ef89868e 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -338,7 +338,7 @@ def transform(self, dstreams, transformFunc): jdstreams = [d._jdstream for d in dstreams] # change the final serializer to sc.serializer func = TransformFunction(self._sc, - lambda t, *rdds: transformFunc(rdds).map(lambda x: x), + lambda t, *rdds: transformFunc(rdds), *[d._jrdd_deserializer for d in dstreams]) jfunc = self._jvm.TransformFunction(func) jdstream = self._jssc.transform(jdstreams, jfunc) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index e4a428a0b27e7..373784f826677 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -779,6 +779,12 @@ def func(rdds): self.assertEqual([2, 3, 1], self._take(dstream, 3)) + def test_transform_pairrdd(self): + # This regression test case is for SPARK-17756. + dstream = self.ssc.queueStream( + [[1], [2], [3]]).transform(lambda rdd: rdd.cartesian(rdd)) + self.assertEqual([(1, 1), (2, 2), (3, 3)], self._take(dstream, 3)) + def test_get_active(self): self.assertEqual(StreamingContext.getActive(), None) diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index df184471993ff..b4b9f97feb7ca 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -20,6 +20,8 @@ import traceback import sys +from py4j.java_gateway import is_instance_of + from pyspark import SparkContext, RDD @@ -65,7 +67,14 @@ def call(self, milliseconds, jrdds): t = datetime.fromtimestamp(milliseconds / 1000.0) r = self.func(t, *rdds) if r: - return r._jrdd + # Here, we work around to ensure `_jrdd` is `JavaRDD` by wrapping it by `map`. + # org.apache.spark.streaming.api.python.PythonTransformFunction requires to return + # `JavaRDD`; however, this could be `JavaPairRDD` by some APIs, for example, `zip`. + # See SPARK-17756. + if is_instance_of(self.ctx._gateway, r._jrdd, "org.apache.spark.api.java.JavaRDD"): + return r._jrdd + else: + return r.map(lambda x: x)._jrdd except: self.failure = traceback.format_exc() From f433ef786770e48e3594ad158ce9908f98ef0d9a Mon Sep 17 00:00:00 2001 From: Sean Suchter Date: Fri, 8 Jun 2018 15:15:24 -0700 Subject: [PATCH 0934/2461] [SPARK-23010][K8S] Initial checkin of k8s integration tests. These tests were developed in the https://github.com/apache-spark-on-k8s/spark-integration repo by several contributors. This is a copy of the current state into the main apache spark repo. The only changes from the current spark-integration repo state are: * Move the files from the repo root into resource-managers/kubernetes/integration-tests * Add a reference to these tests in the root README.md * Fix a path reference in dev/dev-run-integration-tests.sh * Add a TODO in include/util.sh ## What changes were proposed in this pull request? Incorporation of Kubernetes integration tests. ## How was this patch tested? This code has its own unit tests, but the main purpose is to provide the integration tests. I tested this on my laptop by running dev/dev-run-integration-tests.sh --spark-tgz ~/spark-2.4.0-SNAPSHOT-bin--.tgz The spark-integration tests have already been running for months in AMPLab, here is an example: https://amplab.cs.berkeley.edu/jenkins/job/testing-k8s-scheduled-spark-integration-master/ Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Sean Suchter Author: Sean Suchter Closes #20697 from ssuchter/ssuchter-k8s-integration-tests. --- README.md | 2 + dev/tox.ini | 2 +- pom.xml | 1 + .../kubernetes/integration-tests/README.md | 52 ++++ .../dev/dev-run-integration-tests.sh | 93 ++++++ .../integration-tests/dev/spark-rbac.yaml | 52 ++++ .../kubernetes/integration-tests/pom.xml | 155 +++++++++ .../scripts/setup-integration-test-env.sh | 91 ++++++ .../src/test/resources/log4j.properties | 31 ++ .../k8s/integrationtest/KubernetesSuite.scala | 294 ++++++++++++++++++ .../KubernetesTestComponents.scala | 120 +++++++ .../k8s/integrationtest/ProcessUtils.scala | 46 +++ .../SparkReadinessWatcher.scala | 41 +++ .../deploy/k8s/integrationtest/Utils.scala | 30 ++ .../backend/IntegrationTestBackend.scala | 43 +++ .../backend/minikube/Minikube.scala | 84 +++++ .../minikube/MinikubeTestBackend.scala | 42 +++ .../deploy/k8s/integrationtest/config.scala | 38 +++ .../k8s/integrationtest/constants.scala | 22 ++ 19 files changed, 1238 insertions(+), 1 deletion(-) create mode 100644 resource-managers/kubernetes/integration-tests/README.md create mode 100755 resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh create mode 100644 resource-managers/kubernetes/integration-tests/dev/spark-rbac.yaml create mode 100644 resource-managers/kubernetes/integration-tests/pom.xml create mode 100755 resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh create mode 100644 resource-managers/kubernetes/integration-tests/src/test/resources/log4j.properties create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SparkReadinessWatcher.scala create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/MinikubeTestBackend.scala create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/config.scala create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/constants.scala diff --git a/README.md b/README.md index 1e521a7e7b178..531d330234062 100644 --- a/README.md +++ b/README.md @@ -81,6 +81,8 @@ can be run using: Please see the guidance on how to [run tests for a module, or individual tests](http://spark.apache.org/developer-tools.html#individual-tests). +There is also a Kubernetes integration test, see resource-managers/kubernetes/integration-tests/README.md + ## A Note About Hadoop Versions Spark uses the Hadoop core library to talk to HDFS and other Hadoop-supported diff --git a/dev/tox.ini b/dev/tox.ini index 583c1eaaa966b..28dad8f3b5c7c 100644 --- a/dev/tox.ini +++ b/dev/tox.ini @@ -16,4 +16,4 @@ [pycodestyle] ignore=E402,E731,E241,W503,E226,E722,E741,E305 max-line-length=100 -exclude=cloudpickle.py,heapq3.py,shared.py,python/docs/conf.py,work/*/*.py,python/.eggs/* +exclude=cloudpickle.py,heapq3.py,shared.py,python/docs/conf.py,work/*/*.py,python/.eggs/*,dist/* diff --git a/pom.xml b/pom.xml index 883c096ae1ae9..23bbd3b09734e 100644 --- a/pom.xml +++ b/pom.xml @@ -2705,6 +2705,7 @@ kubernetes resource-managers/kubernetes/core + resource-managers/kubernetes/integration-tests diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md new file mode 100644 index 0000000000000..b3863e6b7d1af --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/README.md @@ -0,0 +1,52 @@ +--- +layout: global +title: Spark on Kubernetes Integration Tests +--- + +# Running the Kubernetes Integration Tests + +Note that the integration test framework is currently being heavily revised and +is subject to change. Note that currently the integration tests only run with Java 8. + +The simplest way to run the integration tests is to install and run Minikube, then run the following: + + dev/dev-run-integration-tests.sh + +The minimum tested version of Minikube is 0.23.0. The kube-dns addon must be enabled. Minikube should +run with a minimum of 3 CPUs and 4G of memory: + + minikube start --cpus 3 --memory 4096 + +You can download Minikube [here](https://github.com/kubernetes/minikube/releases). + +# Integration test customization + +Configuration of the integration test runtime is done through passing different arguments to the test script. The main useful options are outlined below. + +## Re-using Docker Images + +By default, the test framework will build new Docker images on every test execution. A unique image tag is generated, +and it is written to file at `target/imageTag.txt`. To reuse the images built in a previous run, or to use a Docker image tag +that you have built by other means already, pass the tag to the test script: + + dev/dev-run-integration-tests.sh --image-tag + +where if you still want to use images that were built before by the test framework: + + dev/dev-run-integration-tests.sh --image-tag $(cat target/imageTag.txt) + +## Spark Distribution Under Test + +The Spark code to test is handed to the integration test system via a tarball. Here is the option that is used to specify the tarball: + +* `--spark-tgz ` - set `` to point to a tarball containing the Spark distribution to test. + +TODO: Don't require the packaging of the built Spark artifacts into this tarball, just read them out of the current tree. + +## Customizing the Namespace and Service Account + +* `--namespace ` - set `` to the namespace in which the tests should be run. +* `--service-account ` - set `` to the name of the Kubernetes service account to +use in the namespace specified by the `--namespace`. The service account is expected to have permissions to get, list, watch, +and create pods. For clusters with RBAC turned on, it's important that the right permissions are granted to the service account +in the namespace through an appropriate role and role binding. A reference RBAC configuration is provided in `dev/spark-rbac.yaml`. diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh new file mode 100755 index 0000000000000..ea893fa39eede --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -0,0 +1,93 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +TEST_ROOT_DIR=$(git rev-parse --show-toplevel)/resource-managers/kubernetes/integration-tests + +cd "${TEST_ROOT_DIR}" + +DEPLOY_MODE="minikube" +IMAGE_REPO="docker.io/kubespark" +SPARK_TGZ="N/A" +IMAGE_TAG="N/A" +SPARK_MASTER= +NAMESPACE= +SERVICE_ACCOUNT= + +# Parse arguments +while (( "$#" )); do + case $1 in + --image-repo) + IMAGE_REPO="$2" + shift + ;; + --image-tag) + IMAGE_TAG="$2" + shift + ;; + --deploy-mode) + DEPLOY_MODE="$2" + shift + ;; + --spark-tgz) + SPARK_TGZ="$2" + shift + ;; + --spark-master) + SPARK_MASTER="$2" + shift + ;; + --namespace) + NAMESPACE="$2" + shift + ;; + --service-account) + SERVICE_ACCOUNT="$2" + shift + ;; + *) + break + ;; + esac + shift +done + +cd $TEST_ROOT_DIR + +properties=( + -Dspark.kubernetes.test.sparkTgz=$SPARK_TGZ \ + -Dspark.kubernetes.test.imageTag=$IMAGE_TAG \ + -Dspark.kubernetes.test.imageRepo=$IMAGE_REPO \ + -Dspark.kubernetes.test.deployMode=$DEPLOY_MODE +) + +if [ -n $NAMESPACE ]; +then + properties=( ${properties[@]} -Dspark.kubernetes.test.namespace=$NAMESPACE ) +fi + +if [ -n $SERVICE_ACCOUNT ]; +then + properties=( ${properties[@]} -Dspark.kubernetes.test.serviceAccountName=$SERVICE_ACCOUNT ) +fi + +if [ -n $SPARK_MASTER ]; +then + properties=( ${properties[@]} -Dspark.kubernetes.test.master=$SPARK_MASTER ) +fi + +../../../build/mvn integration-test ${properties[@]} diff --git a/resource-managers/kubernetes/integration-tests/dev/spark-rbac.yaml b/resource-managers/kubernetes/integration-tests/dev/spark-rbac.yaml new file mode 100644 index 0000000000000..a4c242f2f2645 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/dev/spark-rbac.yaml @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +apiVersion: v1 +kind: Namespace +metadata: + name: spark +--- +apiVersion: v1 +kind: ServiceAccount +metadata: + name: spark-sa + namespace: spark +--- +apiVersion: rbac.authorization.k8s.io/v1beta1 +kind: ClusterRole +metadata: + name: spark-role +rules: +- apiGroups: + - "" + resources: + - "pods" + verbs: + - "*" +--- +apiVersion: rbac.authorization.k8s.io/v1beta1 +kind: ClusterRoleBinding +metadata: + name: spark-role-binding +subjects: +- kind: ServiceAccount + name: spark-sa + namespace: spark +roleRef: + kind: ClusterRole + name: spark-role + apiGroup: rbac.authorization.k8s.io \ No newline at end of file diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml new file mode 100644 index 0000000000000..520bda89e034d --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -0,0 +1,155 @@ + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.4.0-SNAPSHOT + ../../../pom.xml + + + spark-kubernetes-integration-tests_2.11 + spark-kubernetes-integration-tests + + 1.3.0 + 1.4.0 + + 3.0.0 + 3.2.2 + 1.0 + kubernetes-integration-tests + ${project.build.directory}/spark-dist-unpacked + N/A + ${project.build.directory}/imageTag.txt + minikube + docker.io/kubespark + + + jar + Spark Project Kubernetes Integration Tests + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + io.fabric8 + kubernetes-client + ${kubernetes-client.version} + + + + + + + org.codehaus.mojo + exec-maven-plugin + ${exec-maven-plugin.version} + + + setup-integration-test-env + pre-integration-test + + exec + + + scripts/setup-integration-test-env.sh + + --unpacked-spark-tgz + ${spark.kubernetes.test.unpackSparkDir} + + --image-repo + ${spark.kubernetes.test.imageRepo} + + --image-tag + ${spark.kubernetes.test.imageTag} + + --image-tag-output-file + ${spark.kubernetes.test.imageTagFile} + + --deploy-mode + ${spark.kubernetes.test.deployMode} + + --spark-tgz + ${spark.kubernetes.test.sparkTgz} + + + + + + + + org.scalatest + scalatest-maven-plugin + ${scalatest-maven-plugin.version} + + ${project.build.directory}/surefire-reports + . + SparkTestSuite.txt + -ea -Xmx3g -XX:ReservedCodeCacheSize=512m ${extraScalaTestArgs} + + + file:src/test/resources/log4j.properties + true + ${spark.kubernetes.test.imageTagFile} + ${spark.kubernetes.test.unpackSparkDir} + ${spark.kubernetes.test.imageRepo} + ${spark.kubernetes.test.deployMode} + ${spark.kubernetes.test.master} + ${spark.kubernetes.test.namespace} + ${spark.kubernetes.test.serviceAccountName} + + ${test.exclude.tags} + + + + test + + test + + + + (?<!Suite) + + + + integration-test + integration-test + + test + + + + + + + + + diff --git a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh new file mode 100755 index 0000000000000..ccfb8e767c529 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh @@ -0,0 +1,91 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +TEST_ROOT_DIR=$(git rev-parse --show-toplevel) +UNPACKED_SPARK_TGZ="$TEST_ROOT_DIR/target/spark-dist-unpacked" +IMAGE_TAG_OUTPUT_FILE="$TEST_ROOT_DIR/target/image-tag.txt" +DEPLOY_MODE="minikube" +IMAGE_REPO="docker.io/kubespark" +IMAGE_TAG="N/A" +SPARK_TGZ="N/A" + +# Parse arguments +while (( "$#" )); do + case $1 in + --unpacked-spark-tgz) + UNPACKED_SPARK_TGZ="$2" + shift + ;; + --image-repo) + IMAGE_REPO="$2" + shift + ;; + --image-tag) + IMAGE_TAG="$2" + shift + ;; + --image-tag-output-file) + IMAGE_TAG_OUTPUT_FILE="$2" + shift + ;; + --deploy-mode) + DEPLOY_MODE="$2" + shift + ;; + --spark-tgz) + SPARK_TGZ="$2" + shift + ;; + *) + break + ;; + esac + shift +done + +if [[ $SPARK_TGZ == "N/A" ]]; +then + echo "Must specify a Spark tarball to build Docker images against with --spark-tgz." && exit 1; +fi + +rm -rf $UNPACKED_SPARK_TGZ +mkdir -p $UNPACKED_SPARK_TGZ +tar -xzvf $SPARK_TGZ --strip-components=1 -C $UNPACKED_SPARK_TGZ; + +if [[ $IMAGE_TAG == "N/A" ]]; +then + IMAGE_TAG=$(uuidgen); + cd $UNPACKED_SPARK_TGZ + if [[ $DEPLOY_MODE == cloud ]] ; + then + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG build + if [[ $IMAGE_REPO == gcr.io* ]] ; + then + gcloud docker -- push $IMAGE_REPO/spark:$IMAGE_TAG + else + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG push + fi + else + # -m option for minikube. + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -m -r $IMAGE_REPO -t $IMAGE_TAG build + fi + cd - +fi + +rm -f $IMAGE_TAG_OUTPUT_FILE +echo -n $IMAGE_TAG > $IMAGE_TAG_OUTPUT_FILE diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/log4j.properties b/resource-managers/kubernetes/integration-tests/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..866126bc3c1c2 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/log4j.properties @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/integration-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/integration-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from a few verbose libraries. +log4j.logger.com.sun.jersey=WARN +log4j.logger.org.apache.hadoop=WARN +log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.mortbay=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala new file mode 100644 index 0000000000000..65c513cf241a4 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.io.File +import java.nio.file.{Path, Paths} +import java.util.UUID +import java.util.regex.Pattern + +import scala.collection.JavaConverters._ + +import com.google.common.io.PatternFilenameFilter +import io.fabric8.kubernetes.api.model.{Container, Pod} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.scalatest.concurrent.{Eventually, PatienceConfiguration} +import org.scalatest.time.{Minutes, Seconds, Span} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.integrationtest.backend.{IntegrationTestBackend, IntegrationTestBackendFactory} +import org.apache.spark.deploy.k8s.integrationtest.config._ + +private[spark] class KubernetesSuite extends SparkFunSuite + with BeforeAndAfterAll with BeforeAndAfter { + + import KubernetesSuite._ + + private var testBackend: IntegrationTestBackend = _ + private var sparkHomeDir: Path = _ + private var kubernetesTestComponents: KubernetesTestComponents = _ + private var sparkAppConf: SparkAppConf = _ + private var image: String = _ + private var containerLocalSparkDistroExamplesJar: String = _ + private var appLocator: String = _ + private var driverPodName: String = _ + + override def beforeAll(): Unit = { + // The scalatest-maven-plugin gives system properties that are referenced but not set null + // values. We need to remove the null-value properties before initializing the test backend. + val nullValueProperties = System.getProperties.asScala + .filter(entry => entry._2.equals("null")) + .map(entry => entry._1.toString) + nullValueProperties.foreach { key => + System.clearProperty(key) + } + + val sparkDirProp = System.getProperty("spark.kubernetes.test.unpackSparkDir") + require(sparkDirProp != null, "Spark home directory must be provided in system properties.") + sparkHomeDir = Paths.get(sparkDirProp) + require(sparkHomeDir.toFile.isDirectory, + s"No directory found for spark home specified at $sparkHomeDir.") + val imageTag = getTestImageTag + val imageRepo = getTestImageRepo + image = s"$imageRepo/spark:$imageTag" + + val sparkDistroExamplesJarFile: File = sparkHomeDir.resolve(Paths.get("examples", "jars")) + .toFile + .listFiles(new PatternFilenameFilter(Pattern.compile("^spark-examples_.*\\.jar$")))(0) + containerLocalSparkDistroExamplesJar = s"local:///opt/spark/examples/jars/" + + s"${sparkDistroExamplesJarFile.getName}" + testBackend = IntegrationTestBackendFactory.getTestBackend + testBackend.initialize() + kubernetesTestComponents = new KubernetesTestComponents(testBackend.getKubernetesClient) + } + + override def afterAll(): Unit = { + testBackend.cleanUp() + } + + before { + appLocator = UUID.randomUUID().toString.replaceAll("-", "") + driverPodName = "spark-test-app-" + UUID.randomUUID().toString.replaceAll("-", "") + sparkAppConf = kubernetesTestComponents.newSparkAppConf() + .set("spark.kubernetes.container.image", image) + .set("spark.kubernetes.driver.pod.name", driverPodName) + .set("spark.kubernetes.driver.label.spark-app-locator", appLocator) + .set("spark.kubernetes.executor.label.spark-app-locator", appLocator) + if (!kubernetesTestComponents.hasUserSpecifiedNamespace) { + kubernetesTestComponents.createNamespace() + } + } + + after { + if (!kubernetesTestComponents.hasUserSpecifiedNamespace) { + kubernetesTestComponents.deleteNamespace() + } + deleteDriverPod() + } + + test("Run SparkPi with no resources") { + runSparkPiAndVerifyCompletion() + } + + test("Run SparkPi with a very long application name.") { + sparkAppConf.set("spark.app.name", "long" * 40) + runSparkPiAndVerifyCompletion() + } + + test("Run SparkPi with a master URL without a scheme.") { + val url = kubernetesTestComponents.kubernetesClient.getMasterUrl + val k8sMasterUrl = if (url.getPort < 0) { + s"k8s://${url.getHost}" + } else { + s"k8s://${url.getHost}:${url.getPort}" + } + sparkAppConf.set("spark.master", k8sMasterUrl) + runSparkPiAndVerifyCompletion() + } + + test("Run SparkPi with an argument.") { + runSparkPiAndVerifyCompletion(appArgs = Array("5")) + } + + test("Run SparkPi with custom labels, annotations, and environment variables.") { + sparkAppConf + .set("spark.kubernetes.driver.label.label1", "label1-value") + .set("spark.kubernetes.driver.label.label2", "label2-value") + .set("spark.kubernetes.driver.annotation.annotation1", "annotation1-value") + .set("spark.kubernetes.driver.annotation.annotation2", "annotation2-value") + .set("spark.kubernetes.driverEnv.ENV1", "VALUE1") + .set("spark.kubernetes.driverEnv.ENV2", "VALUE2") + .set("spark.kubernetes.executor.label.label1", "label1-value") + .set("spark.kubernetes.executor.label.label2", "label2-value") + .set("spark.kubernetes.executor.annotation.annotation1", "annotation1-value") + .set("spark.kubernetes.executor.annotation.annotation2", "annotation2-value") + .set("spark.executorEnv.ENV1", "VALUE1") + .set("spark.executorEnv.ENV2", "VALUE2") + + runSparkPiAndVerifyCompletion( + driverPodChecker = (driverPod: Pod) => { + doBasicDriverPodCheck(driverPod) + checkCustomSettings(driverPod) + }, + executorPodChecker = (executorPod: Pod) => { + doBasicExecutorPodCheck(executorPod) + checkCustomSettings(executorPod) + }) + } + + // TODO(ssuchter): Enable the below after debugging + // test("Run PageRank using remote data file") { + // sparkAppConf + // .set("spark.kubernetes.mountDependencies.filesDownloadDir", + // CONTAINER_LOCAL_FILE_DOWNLOAD_PATH) + // .set("spark.files", REMOTE_PAGE_RANK_DATA_FILE) + // runSparkPageRankAndVerifyCompletion( + // appArgs = Array(CONTAINER_LOCAL_DOWNLOADED_PAGE_RANK_DATA_FILE)) + // } + + private def runSparkPiAndVerifyCompletion( + appResource: String = containerLocalSparkDistroExamplesJar, + driverPodChecker: Pod => Unit = doBasicDriverPodCheck, + executorPodChecker: Pod => Unit = doBasicExecutorPodCheck, + appArgs: Array[String] = Array.empty[String], + appLocator: String = appLocator): Unit = { + runSparkApplicationAndVerifyCompletion( + appResource, + SPARK_PI_MAIN_CLASS, + Seq("Pi is roughly 3"), + appArgs, + driverPodChecker, + executorPodChecker, + appLocator) + } + + private def runSparkPageRankAndVerifyCompletion( + appResource: String = containerLocalSparkDistroExamplesJar, + driverPodChecker: Pod => Unit = doBasicDriverPodCheck, + executorPodChecker: Pod => Unit = doBasicExecutorPodCheck, + appArgs: Array[String], + appLocator: String = appLocator): Unit = { + runSparkApplicationAndVerifyCompletion( + appResource, + SPARK_PAGE_RANK_MAIN_CLASS, + Seq("1 has rank", "2 has rank", "3 has rank", "4 has rank"), + appArgs, + driverPodChecker, + executorPodChecker, + appLocator) + } + + private def runSparkApplicationAndVerifyCompletion( + appResource: String, + mainClass: String, + expectedLogOnCompletion: Seq[String], + appArgs: Array[String], + driverPodChecker: Pod => Unit, + executorPodChecker: Pod => Unit, + appLocator: String): Unit = { + val appArguments = SparkAppArguments( + mainAppResource = appResource, + mainClass = mainClass, + appArgs = appArgs) + SparkAppLauncher.launch(appArguments, sparkAppConf, TIMEOUT.value.toSeconds.toInt, sparkHomeDir) + + val driverPod = kubernetesTestComponents.kubernetesClient + .pods() + .withLabel("spark-app-locator", appLocator) + .withLabel("spark-role", "driver") + .list() + .getItems + .get(0) + driverPodChecker(driverPod) + + val executorPods = kubernetesTestComponents.kubernetesClient + .pods() + .withLabel("spark-app-locator", appLocator) + .withLabel("spark-role", "executor") + .list() + .getItems + executorPods.asScala.foreach { pod => + executorPodChecker(pod) + } + + Eventually.eventually(TIMEOUT, INTERVAL) { + expectedLogOnCompletion.foreach { e => + assert(kubernetesTestComponents.kubernetesClient + .pods() + .withName(driverPod.getMetadata.getName) + .getLog + .contains(e), "The application did not complete.") + } + } + } + + private def doBasicDriverPodCheck(driverPod: Pod): Unit = { + assert(driverPod.getMetadata.getName === driverPodName) + assert(driverPod.getSpec.getContainers.get(0).getImage === image) + assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") + } + + private def doBasicExecutorPodCheck(executorPod: Pod): Unit = { + assert(executorPod.getSpec.getContainers.get(0).getImage === image) + assert(executorPod.getSpec.getContainers.get(0).getName === "executor") + } + + private def checkCustomSettings(pod: Pod): Unit = { + assert(pod.getMetadata.getLabels.get("label1") === "label1-value") + assert(pod.getMetadata.getLabels.get("label2") === "label2-value") + assert(pod.getMetadata.getAnnotations.get("annotation1") === "annotation1-value") + assert(pod.getMetadata.getAnnotations.get("annotation2") === "annotation2-value") + + val container = pod.getSpec.getContainers.get(0) + val envVars = container + .getEnv + .asScala + .map { env => + (env.getName, env.getValue) + } + .toMap + assert(envVars("ENV1") === "VALUE1") + assert(envVars("ENV2") === "VALUE2") + } + + private def deleteDriverPod(): Unit = { + kubernetesTestComponents.kubernetesClient.pods().withName(driverPodName).delete() + Eventually.eventually(TIMEOUT, INTERVAL) { + assert(kubernetesTestComponents.kubernetesClient + .pods() + .withName(driverPodName) + .get() == null) + } + } +} + +private[spark] object KubernetesSuite { + + val TIMEOUT = PatienceConfiguration.Timeout(Span(2, Minutes)) + val INTERVAL = PatienceConfiguration.Interval(Span(2, Seconds)) + val SPARK_PI_MAIN_CLASS: String = "org.apache.spark.examples.SparkPi" + val SPARK_PAGE_RANK_MAIN_CLASS: String = "org.apache.spark.examples.SparkPageRank" + + // val CONTAINER_LOCAL_FILE_DOWNLOAD_PATH = "/var/spark-data/spark-files" + + // val REMOTE_PAGE_RANK_DATA_FILE = + // "https://storage.googleapis.com/spark-k8s-integration-tests/files/pagerank_data.txt" + // val CONTAINER_LOCAL_DOWNLOADED_PAGE_RANK_DATA_FILE = + // s"$CONTAINER_LOCAL_FILE_DOWNLOAD_PATH/pagerank_data.txt" + + // case object ShuffleNotReadyException extends Exception +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala new file mode 100644 index 0000000000000..48727142dd052 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.nio.file.{Path, Paths} +import java.util.UUID + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import io.fabric8.kubernetes.client.DefaultKubernetesClient +import org.scalatest.concurrent.Eventually + +import org.apache.spark.internal.Logging + +private[spark] class KubernetesTestComponents(defaultClient: DefaultKubernetesClient) { + + val namespaceOption = Option(System.getProperty("spark.kubernetes.test.namespace")) + val hasUserSpecifiedNamespace = namespaceOption.isDefined + val namespace = namespaceOption.getOrElse(UUID.randomUUID().toString.replaceAll("-", "")) + private val serviceAccountName = + Option(System.getProperty("spark.kubernetes.test.serviceAccountName")) + .getOrElse("default") + val kubernetesClient = defaultClient.inNamespace(namespace) + val clientConfig = kubernetesClient.getConfiguration + + def createNamespace(): Unit = { + defaultClient.namespaces.createNew() + .withNewMetadata() + .withName(namespace) + .endMetadata() + .done() + } + + def deleteNamespace(): Unit = { + defaultClient.namespaces.withName(namespace).delete() + Eventually.eventually(KubernetesSuite.TIMEOUT, KubernetesSuite.INTERVAL) { + val namespaceList = defaultClient + .namespaces() + .list() + .getItems + .asScala + require(!namespaceList.exists(_.getMetadata.getName == namespace)) + } + } + + def newSparkAppConf(): SparkAppConf = { + new SparkAppConf() + .set("spark.master", s"k8s://${kubernetesClient.getMasterUrl}") + .set("spark.kubernetes.namespace", namespace) + .set("spark.executor.memory", "500m") + .set("spark.executor.cores", "1") + .set("spark.executors.instances", "1") + .set("spark.app.name", "spark-test-app") + .set("spark.ui.enabled", "true") + .set("spark.testing", "false") + .set("spark.kubernetes.submission.waitAppCompletion", "false") + .set("spark.kubernetes.authenticate.driver.serviceAccountName", serviceAccountName) + } +} + +private[spark] class SparkAppConf { + + private val map = mutable.Map[String, String]() + + def set(key: String, value: String): SparkAppConf = { + map.put(key, value) + this + } + + def get(key: String): String = map.getOrElse(key, "") + + def setJars(jars: Seq[String]): Unit = set("spark.jars", jars.mkString(",")) + + override def toString: String = map.toString + + def toStringArray: Iterable[String] = map.toList.flatMap(t => List("--conf", s"${t._1}=${t._2}")) +} + +private[spark] case class SparkAppArguments( + mainAppResource: String, + mainClass: String, + appArgs: Array[String]) + +private[spark] object SparkAppLauncher extends Logging { + + def launch( + appArguments: SparkAppArguments, + appConf: SparkAppConf, + timeoutSecs: Int, + sparkHomeDir: Path): Unit = { + val sparkSubmitExecutable = sparkHomeDir.resolve(Paths.get("bin", "spark-submit")) + logInfo(s"Launching a spark app with arguments $appArguments and conf $appConf") + val appArgsArray = + if (appArguments.appArgs.length > 0) Array(appArguments.appArgs.mkString(" ")) + else Array[String]() + val commandLine = (Array(sparkSubmitExecutable.toFile.getAbsolutePath, + "--deploy-mode", "cluster", + "--class", appArguments.mainClass, + "--master", appConf.get("spark.master") + ) ++ appConf.toStringArray :+ + appArguments.mainAppResource) ++ + appArgsArray + ProcessUtils.executeProcess(commandLine, timeoutSecs) + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala new file mode 100644 index 0000000000000..d8f3a6cec05c3 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.util.concurrent.TimeUnit + +import scala.collection.mutable.ArrayBuffer +import scala.io.Source + +import org.apache.spark.internal.Logging + +object ProcessUtils extends Logging { + /** + * executeProcess is used to run a command and return the output if it + * completes within timeout seconds. + */ + def executeProcess(fullCommand: Array[String], timeout: Long): Seq[String] = { + val pb = new ProcessBuilder().command(fullCommand: _*) + pb.redirectErrorStream(true) + val proc = pb.start() + val outputLines = new ArrayBuffer[String] + Utils.tryWithResource(proc.getInputStream)( + Source.fromInputStream(_, "UTF-8").getLines().foreach { line => + logInfo(line) + outputLines += line + }) + assert(proc.waitFor(timeout, TimeUnit.SECONDS), + s"Timed out while executing ${fullCommand.mkString(" ")}") + assert(proc.exitValue == 0, s"Failed to execute ${fullCommand.mkString(" ")}") + outputLines + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SparkReadinessWatcher.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SparkReadinessWatcher.scala new file mode 100644 index 0000000000000..f1fd6dc19ce54 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SparkReadinessWatcher.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.util.concurrent.TimeUnit + +import com.google.common.util.concurrent.SettableFuture +import io.fabric8.kubernetes.api.model.HasMetadata +import io.fabric8.kubernetes.client.{KubernetesClientException, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action +import io.fabric8.kubernetes.client.internal.readiness.Readiness + +private[spark] class SparkReadinessWatcher[T <: HasMetadata] extends Watcher[T] { + + private val signal = SettableFuture.create[Boolean] + + override def eventReceived(action: Action, resource: T): Unit = { + if ((action == Action.MODIFIED || action == Action.ADDED) && + Readiness.isReady(resource)) { + signal.set(true) + } + } + + override def onClose(cause: KubernetesClientException): Unit = {} + + def waitUntilReady(): Boolean = signal.get(60, TimeUnit.SECONDS) +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala new file mode 100644 index 0000000000000..663f8b6523ac8 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.io.Closeable +import java.net.URI + +import org.apache.spark.internal.Logging + +object Utils extends Logging { + + def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = { + val resource = createResource + try f.apply(resource) finally resource.close() + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala new file mode 100644 index 0000000000000..284712c6d250e --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.k8s.integrationtest.backend + +import io.fabric8.kubernetes.client.DefaultKubernetesClient + +import org.apache.spark.deploy.k8s.integrationtest.backend.minikube.MinikubeTestBackend + +private[spark] trait IntegrationTestBackend { + def initialize(): Unit + def getKubernetesClient: DefaultKubernetesClient + def cleanUp(): Unit = {} +} + +private[spark] object IntegrationTestBackendFactory { + val deployModeConfigKey = "spark.kubernetes.test.deployMode" + + def getTestBackend: IntegrationTestBackend = { + val deployMode = Option(System.getProperty(deployModeConfigKey)) + .getOrElse("minikube") + if (deployMode == "minikube") { + MinikubeTestBackend + } else { + throw new IllegalArgumentException( + "Invalid " + deployModeConfigKey + ": " + deployMode) + } + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala new file mode 100644 index 0000000000000..6494cbc18f33e --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest.backend.minikube + +import java.io.File +import java.nio.file.Paths + +import io.fabric8.kubernetes.client.{ConfigBuilder, DefaultKubernetesClient} + +import org.apache.spark.deploy.k8s.integrationtest.ProcessUtils +import org.apache.spark.internal.Logging + +// TODO support windows +private[spark] object Minikube extends Logging { + + private val MINIKUBE_STARTUP_TIMEOUT_SECONDS = 60 + + def getMinikubeIp: String = { + val outputs = executeMinikube("ip") + .filter(_.matches("^\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}$")) + assert(outputs.size == 1, "Unexpected amount of output from minikube ip") + outputs.head + } + + def getMinikubeStatus: MinikubeStatus.Value = { + val statusString = executeMinikube("status") + .filter(line => line.contains("minikubeVM: ") || line.contains("minikube:")) + .head + .replaceFirst("minikubeVM: ", "") + .replaceFirst("minikube: ", "") + MinikubeStatus.unapply(statusString) + .getOrElse(throw new IllegalStateException(s"Unknown status $statusString")) + } + + def getKubernetesClient: DefaultKubernetesClient = { + val kubernetesMaster = s"https://${getMinikubeIp}:8443" + val userHome = System.getProperty("user.home") + val kubernetesConf = new ConfigBuilder() + .withApiVersion("v1") + .withMasterUrl(kubernetesMaster) + .withCaCertFile(Paths.get(userHome, ".minikube", "ca.crt").toFile.getAbsolutePath) + .withClientCertFile(Paths.get(userHome, ".minikube", "apiserver.crt").toFile.getAbsolutePath) + .withClientKeyFile(Paths.get(userHome, ".minikube", "apiserver.key").toFile.getAbsolutePath) + .build() + new DefaultKubernetesClient(kubernetesConf) + } + + private def executeMinikube(action: String, args: String*): Seq[String] = { + ProcessUtils.executeProcess( + Array("bash", "-c", s"minikube $action") ++ args, MINIKUBE_STARTUP_TIMEOUT_SECONDS) + } +} + +private[spark] object MinikubeStatus extends Enumeration { + + // The following states are listed according to + // https://github.com/docker/machine/blob/master/libmachine/state/state.go. + val STARTING = status("Starting") + val RUNNING = status("Running") + val PAUSED = status("Paused") + val STOPPING = status("Stopping") + val STOPPED = status("Stopped") + val ERROR = status("Error") + val TIMEOUT = status("Timeout") + val SAVED = status("Saved") + val NONE = status("") + + def status(value: String): Value = new Val(nextId, value) + def unapply(s: String): Option[Value] = values.find(s == _.toString) +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/MinikubeTestBackend.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/MinikubeTestBackend.scala new file mode 100644 index 0000000000000..cb9324179d70e --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/MinikubeTestBackend.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest.backend.minikube + +import io.fabric8.kubernetes.client.DefaultKubernetesClient + +import org.apache.spark.deploy.k8s.integrationtest.backend.IntegrationTestBackend + +private[spark] object MinikubeTestBackend extends IntegrationTestBackend { + + private var defaultClient: DefaultKubernetesClient = _ + + override def initialize(): Unit = { + val minikubeStatus = Minikube.getMinikubeStatus + require(minikubeStatus == MinikubeStatus.RUNNING, + s"Minikube must be running to use the Minikube backend for integration tests." + + s" Current status is: $minikubeStatus.") + defaultClient = Minikube.getKubernetesClient + } + + override def cleanUp(): Unit = { + super.cleanUp() + } + + override def getKubernetesClient: DefaultKubernetesClient = { + defaultClient + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/config.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/config.scala new file mode 100644 index 0000000000000..a81ef455c6766 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/config.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.io.File + +import com.google.common.base.Charsets +import com.google.common.io.Files + +package object config { + def getTestImageTag: String = { + val imageTagFileProp = System.getProperty("spark.kubernetes.test.imageTagFile") + require(imageTagFileProp != null, "Image tag file must be provided in system properties.") + val imageTagFile = new File(imageTagFileProp) + require(imageTagFile.isFile, s"No file found for image tag at ${imageTagFile.getAbsolutePath}.") + Files.toString(imageTagFile, Charsets.UTF_8).trim + } + + def getTestImageRepo: String = { + val imageRepo = System.getProperty("spark.kubernetes.test.imageRepo") + require(imageRepo != null, "Image repo must be provided in system properties.") + imageRepo + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/constants.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/constants.scala new file mode 100644 index 0000000000000..0807a68cd823c --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/constants.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +package object constants { + val MINIKUBE_TEST_BACKEND = "minikube" + val GCE_TEST_BACKEND = "gce" +} From 36a3409134687d6a2894cd6a77554b8439cacec1 Mon Sep 17 00:00:00 2001 From: Thiruvasakan Paramasivan Date: Fri, 8 Jun 2018 17:17:43 -0700 Subject: [PATCH 0935/2461] [SPARK-24412][SQL] Adding docs about automagical type casting in `isin` and `isInCollection` APIs ## What changes were proposed in this pull request? Update documentation for `isInCollection` API to clealy explain the "auto-casting" of elements if their types are different. ## How was this patch tested? No-Op Author: Thiruvasakan Paramasivan Closes #21519 from trvskn/sql-doc-update. --- .../scala/org/apache/spark/sql/Column.scala | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index b3e59f53ee3de..2dbb53e7c906b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -781,6 +781,14 @@ class Column(val expr: Expression) extends Logging { * A boolean expression that is evaluated to true if the value of this expression is contained * by the evaluated values of the arguments. * + * Note: Since the type of the elements in the list are inferred only during the run time, + * the elements will be "up-casted" to the most common type for comparison. + * For eg: + * 1) In the case of "Int vs String", the "Int" will be up-casted to "String" and the + * comparison will look like "String vs String". + * 2) In the case of "Float vs Double", the "Float" will be up-casted to "Double" and the + * comparison will look like "Double vs Double" + * * @group expr_ops * @since 1.5.0 */ @@ -791,6 +799,14 @@ class Column(val expr: Expression) extends Logging { * A boolean expression that is evaluated to true if the value of this expression is contained * by the provided collection. * + * Note: Since the type of the elements in the collection are inferred only during the run time, + * the elements will be "up-casted" to the most common type for comparison. + * For eg: + * 1) In the case of "Int vs String", the "Int" will be up-casted to "String" and the + * comparison will look like "String vs String". + * 2) In the case of "Float vs Double", the "Float" will be up-casted to "Double" and the + * comparison will look like "Double vs Double" + * * @group expr_ops * @since 2.4.0 */ @@ -800,6 +816,14 @@ class Column(val expr: Expression) extends Logging { * A boolean expression that is evaluated to true if the value of this expression is contained * by the provided collection. * + * Note: Since the type of the elements in the collection are inferred only during the run time, + * the elements will be "up-casted" to the most common type for comparison. + * For eg: + * 1) In the case of "Int vs String", the "Int" will be up-casted to "String" and the + * comparison will look like "String vs String". + * 2) In the case of "Float vs Double", the "Float" will be up-casted to "Double" and the + * comparison will look like "Double vs Double" + * * @group java_expr_ops * @since 2.4.0 */ From f07c5064a3967cdddf57c2469635ee50a26d864c Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 8 Jun 2018 18:51:56 -0700 Subject: [PATCH 0936/2461] [SPARK-24468][SQL] Handle negative scale when adjusting precision for decimal operations ## What changes were proposed in this pull request? In SPARK-22036 we introduced the possibility to allow precision loss in arithmetic operations (according to the SQL standard). The implementation was drawn from Hive's one, where Decimals with a negative scale are not allowed in the operations. The PR handles the case when the scale is negative, removing the assertion that it is not. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21499 from mgaido91/SPARK-24468. --- .../apache/spark/sql/types/DecimalType.scala | 8 +- .../analysis/DecimalPrecisionSuite.scala | 9 + .../native/decimalArithmeticOperations.sql | 4 + .../decimalArithmeticOperations.sql.out | 164 +++++++++++------- 4 files changed, 117 insertions(+), 68 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index ef3b67c0d48d0..dbf51c398fa47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -161,13 +161,17 @@ object DecimalType extends AbstractDataType { * This method is used only when `spark.sql.decimalOperations.allowPrecisionLoss` is set to true. */ private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = { - // Assumptions: + // Assumption: assert(precision >= scale) - assert(scale >= 0) if (precision <= MAX_PRECISION) { // Adjustment only needed when we exceed max precision DecimalType(precision, scale) + } else if (scale < 0) { + // Decimal can have negative scale (SPARK-24468). In this case, we cannot allow a precision + // loss since we would cause a loss of digits in the integer part. + // In this case, we are likely to meet an overflow. + DecimalType(MAX_PRECISION, scale) } else { // Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION. val intDigits = precision - scale diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index c86dc18dfa680..bd87ca6017e99 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -272,6 +272,15 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter { } } + test("SPARK-24468: operations on decimals with negative scale") { + val a = AttributeReference("a", DecimalType(3, -10))() + val b = AttributeReference("b", DecimalType(1, -1))() + val c = AttributeReference("c", DecimalType(35, 1))() + checkType(Multiply(a, b), DecimalType(5, -11)) + checkType(Multiply(a, c), DecimalType(38, -9)) + checkType(Multiply(b, c), DecimalType(37, 0)) + } + /** strength reduction for integer/decimal comparisons */ def ruleTest(initial: Expression, transformed: Expression): Unit = { val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql index 9be7fcdadfea8..28a0e20c0f495 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql @@ -40,12 +40,14 @@ select 10.3000 * 3.0; select 10.30000 * 30.0; select 10.300000000000000000 * 3.000000000000000000; select 10.300000000000000000 * 3.0000000000000000000; +select 2.35E10 * 1.0; -- arithmetic operations causing an overflow return NULL select (5e36 + 0.1) + 5e36; select (-4e36 - 0.1) - 7e36; select 12345678901234567890.0 * 12345678901234567890.0; select 1e35 / 0.1; +select 1.2345678901234567890E30 * 1.2345678901234567890E25; -- arithmetic operations causing a precision loss are truncated select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345; @@ -67,12 +69,14 @@ select 10.3000 * 3.0; select 10.30000 * 30.0; select 10.300000000000000000 * 3.000000000000000000; select 10.300000000000000000 * 3.0000000000000000000; +select 2.35E10 * 1.0; -- arithmetic operations causing an overflow return NULL select (5e36 + 0.1) + 5e36; select (-4e36 - 0.1) - 7e36; select 12345678901234567890.0 * 12345678901234567890.0; select 1e35 / 0.1; +select 1.2345678901234567890E30 * 1.2345678901234567890E25; -- arithmetic operations causing a precision loss return NULL select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out index 6bfdb84548d4d..cbf44548b3cce 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 36 +-- Number of queries: 40 -- !query 0 @@ -114,190 +114,222 @@ struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.00000000000000000 -- !query 13 -select (5e36 + 0.1) + 5e36 +select 2.35E10 * 1.0 -- !query 13 schema -struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<(CAST(2.35E+10 AS DECIMAL(12,1)) * CAST(1.0 AS DECIMAL(12,1))):decimal(6,-7)> -- !query 13 output -NULL +23500000000 -- !query 14 -select (-4e36 - 0.1) - 7e36 +select (5e36 + 0.1) + 5e36 -- !query 14 schema -struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> -- !query 14 output NULL -- !query 15 -select 12345678901234567890.0 * 12345678901234567890.0 +select (-4e36 - 0.1) - 7e36 -- !query 15 schema -struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> -- !query 15 output NULL -- !query 16 -select 1e35 / 0.1 +select 12345678901234567890.0 * 12345678901234567890.0 -- !query 16 schema -struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> -- !query 16 output NULL -- !query 17 -select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 +select 1e35 / 0.1 -- !query 17 schema -struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,6)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,6))):decimal(38,6)> +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> -- !query 17 output -10012345678912345678912345678911.246907 +NULL -- !query 18 -select 123456789123456789.1234567890 * 1.123456789123456789 +select 1.2345678901234567890E30 * 1.2345678901234567890E25 -- !query 18 schema -struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> +struct<(CAST(1.2345678901234567890E+30 AS DECIMAL(25,-6)) * CAST(1.2345678901234567890E+25 AS DECIMAL(25,-6))):decimal(38,-17)> -- !query 18 output -138698367904130467.654320988515622621 +NULL -- !query 19 -select 12345678912345.123456789123 / 0.000000012345678 +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 -- !query 19 schema -struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,9)> +struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,6)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,6))):decimal(38,6)> -- !query 19 output -1000000073899961059796.725866332 +10012345678912345678912345678911.246907 -- !query 20 -set spark.sql.decimalOperations.allowPrecisionLoss=false +select 123456789123456789.1234567890 * 1.123456789123456789 -- !query 20 schema -struct +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> -- !query 20 output -spark.sql.decimalOperations.allowPrecisionLoss false +138698367904130467.654320988515622621 -- !query 21 -select id, a+b, a-b, a*b, a/b from decimals_test order by id +select 12345678912345.123456789123 / 0.000000012345678 -- !query 21 schema -struct +struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,9)> -- !query 21 output -1 1099 -899 NULL 0.1001001001001001 -2 24690.246 0 NULL 1 -3 1234.2234567891011 -1233.9765432108989 NULL 0.000100037913541123 -4 123456789123456790.123456789123456789 123456789123456787.876543210876543211 NULL 109890109097814272.043109406191131436 +1000000073899961059796.725866332 -- !query 22 -select id, a*10, b/10 from decimals_test order by id +set spark.sql.decimalOperations.allowPrecisionLoss=false -- !query 22 schema -struct +struct -- !query 22 output -1 1000 99.9 -2 123451.23 1234.5123 -3 1.234567891011 123.41 -4 1234567891234567890 0.1123456789123456789 +spark.sql.decimalOperations.allowPrecisionLoss false -- !query 23 -select 10.3 * 3.0 +select id, a+b, a-b, a*b, a/b from decimals_test order by id -- !query 23 schema -struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> +struct -- !query 23 output -30.9 +1 1099 -899 NULL 0.1001001001001001 +2 24690.246 0 NULL 1 +3 1234.2234567891011 -1233.9765432108989 NULL 0.000100037913541123 +4 123456789123456790.123456789123456789 123456789123456787.876543210876543211 NULL 109890109097814272.043109406191131436 -- !query 24 -select 10.3000 * 3.0 +select id, a*10, b/10 from decimals_test order by id -- !query 24 schema -struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> +struct -- !query 24 output -30.9 +1 1000 99.9 +2 123451.23 1234.5123 +3 1.234567891011 123.41 +4 1234567891234567890 0.1123456789123456789 -- !query 25 -select 10.30000 * 30.0 +select 10.3 * 3.0 -- !query 25 schema -struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> +struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> -- !query 25 output -309 +30.9 -- !query 26 -select 10.300000000000000000 * 3.000000000000000000 +select 10.3000 * 3.0 -- !query 26 schema -struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,36)> +struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> -- !query 26 output 30.9 -- !query 27 -select 10.300000000000000000 * 3.0000000000000000000 +select 10.30000 * 30.0 -- !query 27 schema -struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,37)> +struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> -- !query 27 output -NULL +309 -- !query 28 -select (5e36 + 0.1) + 5e36 +select 10.300000000000000000 * 3.000000000000000000 -- !query 28 schema -struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,36)> -- !query 28 output -NULL +30.9 -- !query 29 -select (-4e36 - 0.1) - 7e36 +select 10.300000000000000000 * 3.0000000000000000000 -- !query 29 schema -struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,37)> -- !query 29 output NULL -- !query 30 -select 12345678901234567890.0 * 12345678901234567890.0 +select 2.35E10 * 1.0 -- !query 30 schema -struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +struct<(CAST(2.35E+10 AS DECIMAL(12,1)) * CAST(1.0 AS DECIMAL(12,1))):decimal(6,-7)> -- !query 30 output -NULL +23500000000 -- !query 31 -select 1e35 / 0.1 +select (5e36 + 0.1) + 5e36 -- !query 31 schema -struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)> +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> -- !query 31 output NULL -- !query 32 -select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 +select (-4e36 - 0.1) - 7e36 -- !query 32 schema -struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,7)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,7))):decimal(38,7)> +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> -- !query 32 output NULL -- !query 33 -select 123456789123456789.1234567890 * 1.123456789123456789 +select 12345678901234567890.0 * 12345678901234567890.0 -- !query 33 schema -struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)> +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> -- !query 33 output NULL -- !query 34 -select 12345678912345.123456789123 / 0.000000012345678 +select 1e35 / 0.1 -- !query 34 schema -struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,18)> +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)> -- !query 34 output NULL -- !query 35 -drop table decimals_test +select 1.2345678901234567890E30 * 1.2345678901234567890E25 -- !query 35 schema -struct<> +struct<(CAST(1.2345678901234567890E+30 AS DECIMAL(25,-6)) * CAST(1.2345678901234567890E+25 AS DECIMAL(25,-6))):decimal(38,-17)> -- !query 35 output +NULL + + +-- !query 36 +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 +-- !query 36 schema +struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,7)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,7))):decimal(38,7)> +-- !query 36 output +NULL + + +-- !query 37 +select 123456789123456789.1234567890 * 1.123456789123456789 +-- !query 37 schema +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)> +-- !query 37 output +NULL + + +-- !query 38 +select 12345678912345.123456789123 / 0.000000012345678 +-- !query 38 schema +struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,18)> +-- !query 38 output +NULL + + +-- !query 39 +drop table decimals_test +-- !query 39 schema +struct<> +-- !query 39 output From 3e5b4ae63a468858ff8b9f7f3231cc877846a0af Mon Sep 17 00:00:00 2001 From: edorigatti Date: Mon, 11 Jun 2018 10:15:42 +0800 Subject: [PATCH 0937/2461] [SPARK-23754][PYTHON][FOLLOWUP] Move UDF stop iteration wrapping from driver to executor ## What changes were proposed in this pull request? SPARK-23754 was fixed in #21383 by changing the UDF code to wrap the user function, but this required a hack to save its argspec. This PR reverts this change and fixes the `StopIteration` bug in the worker ## How does this work? The root of the problem is that when an user-supplied function raises a `StopIteration`, pyspark might stop processing data, if this function is used in a for-loop. The solution is to catch `StopIteration`s exceptions and re-raise them as `RuntimeError`s, so that the execution fails and the error is reported to the user. This is done using the `fail_on_stopiteration` wrapper, in different ways depending on where the function is used: - In RDDs, the user function is wrapped in the driver, because this function is also called in the driver itself. - In SQL UDFs, the function is wrapped in the worker, since all processing happens there. Moreover, the worker needs the signature of the user function, which is lost when wrapping it, but passing this signature to the worker requires a not so nice hack. ## How was this patch tested? Same tests, plus tests for pandas UDFs Author: edorigatti Closes #21467 from e-dorigatti/fix_udf_hack. --- python/pyspark/sql/tests.py | 71 ++++++++++++++++++++++++++++--------- python/pyspark/sql/udf.py | 14 ++------ python/pyspark/tests.py | 37 +++++++++++-------- python/pyspark/util.py | 9 ++--- python/pyspark/worker.py | 18 ++++++---- 5 files changed, 92 insertions(+), 57 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 487eb19c3b98a..4a3941de8a6a6 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -900,22 +900,6 @@ def __call__(self, x): self.assertEqual(f, f_.func) self.assertEqual(return_type, f_.returnType) - def test_stopiteration_in_udf(self): - # test for SPARK-23754 - from pyspark.sql.functions import udf - from py4j.protocol import Py4JJavaError - - def foo(x): - raise StopIteration() - - with self.assertRaises(Py4JJavaError) as cm: - self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show() - - self.assertIn( - "Caught StopIteration thrown from user's code; failing the task", - cm.exception.java_exception.toString() - ) - def test_validate_column_types(self): from pyspark.sql.functions import udf, to_json from pyspark.sql.column import _to_java_column @@ -4144,6 +4128,61 @@ def foo(df): def foo(k, v, w): return k + def test_stopiteration_in_udf(self): + from pyspark.sql.functions import udf, pandas_udf, PandasUDFType + from py4j.protocol import Py4JJavaError + + def foo(x): + raise StopIteration() + + def foofoo(x, y): + raise StopIteration() + + exc_message = "Caught StopIteration thrown from user's code; failing the task" + df = self.spark.range(0, 100) + + # plain udf (test for SPARK-23754) + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.withColumn('v', udf(foo)('id')).collect + ) + + # pandas scalar udf + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.withColumn( + 'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id') + ).collect + ) + + # pandas grouped map + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.groupBy('id').apply( + pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP) + ).collect + ) + + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.groupBy('id').apply( + pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP) + ).collect + ) + + # pandas grouped agg + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.groupBy('id').agg( + pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id') + ).collect + ) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index c8fb49d7c2b65..9dbe49b831cef 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -25,7 +25,7 @@ from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\ to_arrow_type, to_arrow_schema -from pyspark.util import _get_argspec, fail_on_stopiteration +from pyspark.util import _get_argspec __all__ = ["UDFRegistration"] @@ -157,17 +157,7 @@ def _create_judf(self): spark = SparkSession.builder.getOrCreate() sc = spark.sparkContext - func = fail_on_stopiteration(self.func) - - # for pandas UDFs the worker needs to know if the function takes - # one or two arguments, but the signature is lost when wrapping with - # fail_on_stopiteration, so we store it here - if self.evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): - func._argspec = _get_argspec(self.func) - - wrapped_func = _wrap_function(sc, func, self.returnType) + wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( self._name, wrapped_func, jdt, self.evalType, self.deterministic) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 30723b8e15b36..18b2f251dc9fd 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1291,27 +1291,34 @@ def test_pipe_unicode(self): result = rdd.pipe('cat').collect() self.assertEqual(data, result) - def test_stopiteration_in_client_code(self): + def test_stopiteration_in_user_code(self): def stopit(*x): raise StopIteration() seq_rdd = self.sc.parallelize(range(10)) keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) - - self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect) - self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect) - self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) - self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit) - self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit) - self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit) - self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit) - - # the exception raised is non-deterministic - self.assertRaises((Py4JJavaError, RuntimeError), - seq_rdd.aggregate, 0, stopit, lambda *x: 1) - self.assertRaises((Py4JJavaError, RuntimeError), - seq_rdd.aggregate, 0, lambda *x: 1, stopit) + msg = "Caught StopIteration thrown from user's code; failing the task" + + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, + seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) + + # these methods call the user function both in the driver and in the executor + # the exception raised is different according to where the StopIteration happens + # RuntimeError is raised if in the driver + # Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + keyed_rdd.reduceByKeyLocally, stopit) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + seq_rdd.aggregate, 0, stopit, lambda *x: 1) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + seq_rdd.aggregate, 0, lambda *x: 1, stopit) class ProfilerTests(PySparkTestCase): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index e95a9b523393f..f015542c8799d 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -53,12 +53,7 @@ def _get_argspec(f): """ Get argspec of a function. Supports both Python 2 and Python 3. """ - - if hasattr(f, '_argspec'): - # only used for pandas UDF: they wrap the user function, losing its signature - # workers need this signature, so UDF saves it here - argspec = f._argspec - elif sys.version_info[0] < 3: + if sys.version_info[0] < 3: argspec = inspect.getargspec(f) else: # `getargspec` is deprecated since python3.0 (incompatible with function annotations). @@ -97,7 +92,7 @@ def majorMinorVersion(sparkVersion): def fail_on_stopiteration(f): """ Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError' - prevents silent loss of data when 'f' is used in a for loop + prevents silent loss of data when 'f' is used in a for loop in Spark code """ def wrapper(*args, **kwargs): try: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index fbcb8af8bfb24..a30d6bf523a50 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -35,7 +35,7 @@ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type -from pyspark.util import _get_argspec +from pyspark.util import _get_argspec, fail_on_stopiteration from pyspark import shuffle pickleSer = PickleSerializer() @@ -92,10 +92,9 @@ def verify_result_length(*a): return lambda *a: (verify_result_length(*a), arrow_return_type) -def wrap_grouped_map_pandas_udf(f, return_type): +def wrap_grouped_map_pandas_udf(f, return_type, argspec): def wrapped(key_series, value_series): import pandas as pd - argspec = _get_argspec(f) if len(argspec.args) == 1: result = f(pd.concat(value_series, axis=1)) @@ -140,15 +139,20 @@ def read_single_udf(pickleSer, infile, eval_type): else: row_func = chain(row_func, f) + # make sure StopIteration's raised in the user code are not ignored + # when they are processed in a for loop, raise them as RuntimeError's instead + func = fail_on_stopiteration(row_func) + # the last returnType will be the return type of UDF if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF: - return arg_offsets, wrap_scalar_pandas_udf(row_func, return_type) + return arg_offsets, wrap_scalar_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: - return arg_offsets, wrap_grouped_map_pandas_udf(row_func, return_type) + argspec = _get_argspec(row_func) # signature was lost when wrapping it + return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: - return arg_offsets, wrap_grouped_agg_pandas_udf(row_func, return_type) + return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_BATCHED_UDF: - return arg_offsets, wrap_udf(row_func, return_type) + return arg_offsets, wrap_udf(func, return_type) else: raise ValueError("Unknown eval type: {}".format(eval_type)) From a99d284c16cc4e00ce7c83ecdc3db6facd467552 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Mon, 11 Jun 2018 12:15:14 -0700 Subject: [PATCH 0938/2461] [SPARK-19826][ML][PYTHON] add spark.ml Python API for PIC ## What changes were proposed in this pull request? add spark.ml Python API for PIC ## How was this patch tested? add doctest Author: Huaxin Gao Closes #21513 from huaxingao/spark--19826. --- python/pyspark/ml/clustering.py | 184 +++++++++++++++++++++++++++++++- 1 file changed, 179 insertions(+), 5 deletions(-) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index b3d5fb17f6b81..4aa1cf84b5824 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -19,14 +19,15 @@ from pyspark import since, keyword_only from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaWrapper from pyspark.ml.param.shared import * from pyspark.ml.common import inherit_doc +from pyspark.sql import DataFrame __all__ = ['BisectingKMeans', 'BisectingKMeansModel', 'BisectingKMeansSummary', 'KMeans', 'KMeansModel', 'GaussianMixture', 'GaussianMixtureModel', 'GaussianMixtureSummary', - 'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel'] + 'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel', 'PowerIterationClustering'] class ClusteringSummary(JavaWrapper): @@ -836,7 +837,7 @@ class LDA(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed, HasCheckpointInter Terminology: - - "term" = "word": an el + - "term" = "word": an element of the vocabulary - "token": instance of a term appearing in a document - "topic": multinomial distribution over terms representing some concept - "document": one piece of text, corresponding to one row in the input data @@ -938,7 +939,7 @@ def __init__(self, featuresCol="features", maxIter=20, seed=None, checkpointInte k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,\ subsamplingRate=0.05, optimizeDocConcentration=True,\ docConcentration=None, topicConcentration=None,\ - topicDistributionCol="topicDistribution", keepLastCheckpoint=True): + topicDistributionCol="topicDistribution", keepLastCheckpoint=True) """ super(LDA, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.LDA", self.uid) @@ -967,7 +968,7 @@ def setParams(self, featuresCol="features", maxIter=20, seed=None, checkpointInt k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,\ subsamplingRate=0.05, optimizeDocConcentration=True,\ docConcentration=None, topicConcentration=None,\ - topicDistributionCol="topicDistribution", keepLastCheckpoint=True): + topicDistributionCol="topicDistribution", keepLastCheckpoint=True) Sets params for LDA. """ @@ -1156,6 +1157,179 @@ def getKeepLastCheckpoint(self): return self.getOrDefault(self.keepLastCheckpoint) +@inherit_doc +class PowerIterationClustering(HasMaxIter, HasWeightCol, JavaParams, JavaMLReadable, + JavaMLWritable): + """ + .. note:: Experimental + + Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by + Lin and Cohen. From the abstract: + PIC finds a very low-dimensional embedding of a dataset using truncated power + iteration on a normalized pair-wise similarity matrix of the data. + + This class is not yet an Estimator/Transformer, use :py:func:`assignClusters` method + to run the PowerIterationClustering algorithm. + + .. seealso:: `Wikipedia on Spectral clustering \ + `_ + + >>> data = [(1, 0, 0.5), \ + (2, 0, 0.5), (2, 1, 0.7), \ + (3, 0, 0.5), (3, 1, 0.7), (3, 2, 0.9), \ + (4, 0, 0.5), (4, 1, 0.7), (4, 2, 0.9), (4, 3, 1.1), \ + (5, 0, 0.5), (5, 1, 0.7), (5, 2, 0.9), (5, 3, 1.1), (5, 4, 1.3)] + >>> df = spark.createDataFrame(data).toDF("src", "dst", "weight") + >>> pic = PowerIterationClustering(k=2, maxIter=40, weightCol="weight") + >>> assignments = pic.assignClusters(df) + >>> assignments.sort(assignments.id).show(truncate=False) + +---+-------+ + |id |cluster| + +---+-------+ + |0 |1 | + |1 |1 | + |2 |1 | + |3 |1 | + |4 |1 | + |5 |0 | + +---+-------+ + ... + >>> pic_path = temp_path + "/pic" + >>> pic.save(pic_path) + >>> pic2 = PowerIterationClustering.load(pic_path) + >>> pic2.getK() + 2 + >>> pic2.getMaxIter() + 40 + + .. versionadded:: 2.4.0 + """ + + k = Param(Params._dummy(), "k", + "The number of clusters to create. Must be > 1.", + typeConverter=TypeConverters.toInt) + initMode = Param(Params._dummy(), "initMode", + "The initialization algorithm. This can be either " + + "'random' to use a random vector as vertex properties, or 'degree' to use " + + "a normalized sum of similarities with other vertices. Supported options: " + + "'random' and 'degree'.", + typeConverter=TypeConverters.toString) + srcCol = Param(Params._dummy(), "srcCol", + "Name of the input column for source vertex IDs.", + typeConverter=TypeConverters.toString) + dstCol = Param(Params._dummy(), "dstCol", + "Name of the input column for destination vertex IDs.", + typeConverter=TypeConverters.toString) + + @keyword_only + def __init__(self, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst", + weightCol=None): + """ + __init__(self, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",\ + weightCol=None) + """ + super(PowerIterationClustering, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.clustering.PowerIterationClustering", self.uid) + self._setDefault(k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst") + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.4.0") + def setParams(self, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst", + weightCol=None): + """ + setParams(self, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",\ + weightCol=None) + Sets params for PowerIterationClustering. + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + @since("2.4.0") + def setK(self, value): + """ + Sets the value of :py:attr:`k`. + """ + return self._set(k=value) + + @since("2.4.0") + def getK(self): + """ + Gets the value of :py:attr:`k` or its default value. + """ + return self.getOrDefault(self.k) + + @since("2.4.0") + def setInitMode(self, value): + """ + Sets the value of :py:attr:`initMode`. + """ + return self._set(initMode=value) + + @since("2.4.0") + def getInitMode(self): + """ + Gets the value of :py:attr:`initMode` or its default value. + """ + return self.getOrDefault(self.initMode) + + @since("2.4.0") + def setSrcCol(self, value): + """ + Sets the value of :py:attr:`srcCol`. + """ + return self._set(srcCol=value) + + @since("2.4.0") + def getSrcCol(self): + """ + Gets the value of :py:attr:`srcCol` or its default value. + """ + return self.getOrDefault(self.srcCol) + + @since("2.4.0") + def setDstCol(self, value): + """ + Sets the value of :py:attr:`dstCol`. + """ + return self._set(dstCol=value) + + @since("2.4.0") + def getDstCol(self): + """ + Gets the value of :py:attr:`dstCol` or its default value. + """ + return self.getOrDefault(self.dstCol) + + @since("2.4.0") + def assignClusters(self, dataset): + """ + Run the PIC algorithm and returns a cluster assignment for each input vertex. + + :param dataset: + A dataset with columns src, dst, weight representing the affinity matrix, + which is the matrix A in the PIC paper. Suppose the src column value is i, + the dst column value is j, the weight column value is similarity s,,ij,, + which must be nonnegative. This is a symmetric matrix and hence + s,,ij,, = s,,ji,,. For any (i, j) with nonzero similarity, there should be + either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. Rows with i = j are + ignored, because we assume s,,ij,, = 0.0. + + :return: + A dataset that contains columns of vertex id and the corresponding cluster for + the id. The schema of it will be: + - id: Long + - cluster: Int + + .. versionadded:: 2.4.0 + """ + self._transfer_params_to_java() + jdf = self._java_obj.assignClusters(dataset._jdf) + return DataFrame(jdf, dataset.sql_ctx) + + if __name__ == "__main__": import doctest import pyspark.ml.clustering From 9b6f24202f6f8d9d76bbe53f379743318acb19f9 Mon Sep 17 00:00:00 2001 From: Jonathan Kelly Date: Mon, 11 Jun 2018 16:41:15 -0500 Subject: [PATCH 0939/2461] [MINOR][CORE] Log committer class used by HadoopMapRedCommitProtocol ## What changes were proposed in this pull request? When HadoopMapRedCommitProtocol is used (e.g., when using saveAsTextFile() or saveAsHadoopFile() with RDDs), it's not easy to determine which output committer class was used, so this PR simply logs the class that was used, similarly to what is done in SQLHadoopMapReduceCommitProtocol. ## How was this patch tested? Built Spark then manually inspected logging when calling saveAsTextFile(): ```scala scala> sc.setLogLevel("INFO") scala> sc.textFile("README.md").saveAsTextFile("/tmp/out") ... 18/05/29 10:06:20 INFO HadoopMapRedCommitProtocol: Using output committer class org.apache.hadoop.mapred.FileOutputCommitter ``` Author: Jonathan Kelly Closes #21452 from ejono/master. --- .../apache/spark/internal/io/HadoopMapRedCommitProtocol.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala index ddbd624b380d4..af0aa41518766 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala @@ -31,6 +31,8 @@ class HadoopMapRedCommitProtocol(jobId: String, path: String) override def setupCommitter(context: NewTaskAttemptContext): OutputCommitter = { val config = context.getConfiguration.asInstanceOf[JobConf] - config.getOutputCommitter + val committer = config.getOutputCommitter + logInfo(s"Using output committer class ${committer.getClass.getCanonicalName}") + committer } } From 2dc047a3189290411def92f6d7e9a4e01bdb2c30 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Mon, 11 Jun 2018 17:12:33 -0500 Subject: [PATCH 0940/2461] [SPARK-24520] Double braces in documentations There are double braces in the markdown, which break the link. ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Fokko Driesprong Closes #21528 from Fokko/patch-1. --- docs/structured-streaming-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 602a4c70848e7..0842e8dd88672 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -926,7 +926,7 @@ event time. For a specific window starting at time `T`, the engine will maintain data to update the state until `(max event time seen by the engine - late threshold > T)`. In other words, late data within the threshold will be aggregated, but data later than the threshold will start getting dropped -(see [later]((#semantic-guarantees-of-aggregation-with-watermarking)) +(see [later](#semantic-guarantees-of-aggregation-with-watermarking) in the section for the exact guarantees). Let's understand this with an example. We can easily define watermarking on the previous example using `withWatermark()` as shown below. From f5af86ea753c446df59a0a8c16c685224690d633 Mon Sep 17 00:00:00 2001 From: Xiaodong <11539188+XD-DENG@users.noreply.github.com> Date: Mon, 11 Jun 2018 17:13:11 -0500 Subject: [PATCH 0941/2461] [SPARK-24134][DOCS] A missing full-stop in doc "Tuning Spark". MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In the document [Tuning Spark -> Determining Memory Consumption](https://spark.apache.org/docs/latest/tuning.html#determining-memory-consumption), a full stop was missing in the second paragraph. It's `...use SizeEstimator’s estimate method This is useful for experimenting...`, while there is supposed to be a full stop before `This`. Screenshot showing before change is attached below. screen shot 2018-05-01 at 5 22 32 pm ## How was this patch tested? This is a simple change in doc. Only one full stop was added in plain text. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Xiaodong <11539188+XD-DENG@users.noreply.github.com> Closes #21205 from XD-DENG/patch-1. --- docs/tuning.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tuning.md b/docs/tuning.md index 912c39879be8f..1c3bd0e8758ff 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -132,7 +132,7 @@ The best way to size the amount of memory consumption a dataset will require is into cache, and look at the "Storage" page in the web UI. The page will tell you how much memory the RDD is occupying. -To estimate the memory consumption of a particular object, use `SizeEstimator`'s `estimate` method +To estimate the memory consumption of a particular object, use `SizeEstimator`'s `estimate` method. This is useful for experimenting with different data layouts to trim memory usage, as well as determining the amount of space a broadcast variable will occupy on each executor heap. From 048197749ef990e4def1fcbf488f3ded38d95cae Mon Sep 17 00:00:00 2001 From: liutang123 Date: Mon, 11 Jun 2018 17:48:07 -0700 Subject: [PATCH 0942/2461] [SPARK-22144][SQL] ExchangeCoordinator combine the partitions of an 0 sized pre-shuffle to 0 ## What changes were proposed in this pull request? when the length of pre-shuffle's partitions is 0, the length of post-shuffle's partitions should be 0 instead of spark.sql.shuffle.partitions. ## How was this patch tested? ExchangeCoordinator converted a pre-shuffle that partitions is 0 to a post-shuffle that partitions is 0 instead of one that partitions is spark.sql.shuffle.partitions. Author: liutang123 Closes #19364 from liutang123/SPARK-22144. --- .../spark/sql/execution/exchange/ExchangeCoordinator.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala index 78f11ca8d8c78..051e610eb2705 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala @@ -232,16 +232,16 @@ class ExchangeCoordinator( // number of post-shuffle partitions. val partitionStartIndices = if (mapOutputStatistics.length == 0) { - None + Array.empty[Int] } else { - Some(estimatePartitionStartIndices(mapOutputStatistics)) + estimatePartitionStartIndices(mapOutputStatistics) } var k = 0 while (k < numExchanges) { val exchange = exchanges(k) val rdd = - exchange.preparePostShuffleRDD(shuffleDependencies(k), partitionStartIndices) + exchange.preparePostShuffleRDD(shuffleDependencies(k), Some(partitionStartIndices)) newPostShuffleRDDs.put(exchange, rdd) k += 1 From dc22465f3e1ef5ad59306b1f591d6fd16d674eb7 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 12 Jun 2018 09:32:14 +0800 Subject: [PATCH 0943/2461] [SPARK-23732][DOCS] Fix source links in generated scaladoc. Apply the suggestion on the bug to fix source links. Tested with the 2.3.1 release docs. Author: Marcelo Vanzin Closes #21521 from vanzin/SPARK-23732. --- project/SparkBuild.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 4cb6495a33b61..adc2b6b5d273d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -729,7 +729,8 @@ object Unidoc { scalacOptions in (ScalaUnidoc, unidoc) ++= Seq( "-groups", // Group similar methods together based on the @group annotation. - "-skip-packages", "org.apache.hadoop" + "-skip-packages", "org.apache.hadoop", + "-sourcepath", (baseDirectory in ThisBuild).value.getAbsolutePath ) ++ ( // Add links to sources when generating Scaladoc for a non-snapshot release if (!isSnapshot.value) { From 01452ea9c75ff027ceeb8314368c6bbedefdb2bf Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 11 Jun 2018 22:08:44 -0700 Subject: [PATCH 0944/2461] [SPARK-24502][SQL] flaky test: UnsafeRowSerializerSuite ## What changes were proposed in this pull request? `UnsafeRowSerializerSuite` calls `UnsafeProjection.create` which accesses `SQLConf.get`, while the current active SparkSession may already be stopped, and we may hit exception like this ``` sbt.ForkMain$ForkError: java.lang.IllegalStateException: LiveListenerBus is stopped. at org.apache.spark.scheduler.LiveListenerBus.addToQueue(LiveListenerBus.scala:97) at org.apache.spark.scheduler.LiveListenerBus.addToStatusQueue(LiveListenerBus.scala:80) at org.apache.spark.sql.internal.SharedState.(SharedState.scala:93) at org.apache.spark.sql.SparkSession$$anonfun$sharedState$1.apply(SparkSession.scala:120) at org.apache.spark.sql.SparkSession$$anonfun$sharedState$1.apply(SparkSession.scala:120) at scala.Option.getOrElse(Option.scala:121) at org.apache.spark.sql.SparkSession.sharedState$lzycompute(SparkSession.scala:120) at org.apache.spark.sql.SparkSession.sharedState(SparkSession.scala:119) at org.apache.spark.sql.internal.BaseSessionStateBuilder.build(BaseSessionStateBuilder.scala:286) at org.apache.spark.sql.test.TestSparkSession.sessionState$lzycompute(TestSQLContext.scala:42) at org.apache.spark.sql.test.TestSparkSession.sessionState(TestSQLContext.scala:41) at org.apache.spark.sql.SparkSession$$anonfun$1$$anonfun$apply$1.apply(SparkSession.scala:95) at org.apache.spark.sql.SparkSession$$anonfun$1$$anonfun$apply$1.apply(SparkSession.scala:95) at scala.Option.map(Option.scala:146) at org.apache.spark.sql.SparkSession$$anonfun$1.apply(SparkSession.scala:95) at org.apache.spark.sql.SparkSession$$anonfun$1.apply(SparkSession.scala:94) at org.apache.spark.sql.internal.SQLConf$.get(SQLConf.scala:126) at org.apache.spark.sql.catalyst.expressions.CodeGeneratorWithInterpretedFallback.createObject(CodeGeneratorWithInterpretedFallback.scala:54) at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:157) at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:150) at org.apache.spark.sql.execution.UnsafeRowSerializerSuite.org$apache$spark$sql$execution$UnsafeRowSerializerSuite$$unsafeRowConverter(UnsafeRowSerializerSuite.scala:54) at org.apache.spark.sql.execution.UnsafeRowSerializerSuite.org$apache$spark$sql$execution$UnsafeRowSerializerSuite$$toUnsafeRow(UnsafeRowSerializerSuite.scala:49) at org.apache.spark.sql.execution.UnsafeRowSerializerSuite$$anonfun$2.apply(UnsafeRowSerializerSuite.scala:63) at org.apache.spark.sql.execution.UnsafeRowSerializerSuite$$anonfun$2.apply(UnsafeRowSerializerSuite.scala:60) ... ``` ## How was this patch tested? N/A Author: Wenchen Fan Closes #21518 from cloud-fan/test. --- .../apache/spark/sql/LocalSparkSession.scala | 4 + .../execution/UnsafeRowSerializerSuite.scala | 80 +++++++------------ 2 files changed, 35 insertions(+), 49 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala index d66a6902b0510..cbef1c7828319 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala @@ -30,11 +30,15 @@ trait LocalSparkSession extends BeforeAndAfterEach with BeforeAndAfterAll { self override def beforeAll() { super.beforeAll() InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE) + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() } override def afterEach() { try { resetSparkContext() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() } finally { super.afterEach() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index a3ae93810aa3c..d305ce3e698ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -21,15 +21,13 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File} import java.util.Properties import org.apache.spark._ -import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row +import org.apache.spark.sql.{LocalSparkSession, Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ import org.apache.spark.storage.ShuffleBlockId -import org.apache.spark.util.Utils import org.apache.spark.util.collection.ExternalSorter /** @@ -43,7 +41,7 @@ class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStrea } } -class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { +class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession { private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = { val converter = unsafeRowConverter(schema) @@ -58,7 +56,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { } test("toUnsafeRow() test helper method") { - // This currently doesnt work because the generic getter throws an exception. + // This currently doesn't work because the generic getter throws an exception. val row = Row("Hello", 123) val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) assert(row.getString(0) === unsafeRow.getUTF8String(0).toString) @@ -97,59 +95,43 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { } test("SPARK-10466: external sorter spilling with unsafe row serializer") { - var sc: SparkContext = null - var outputFile: File = null - val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten - Utils.tryWithSafeFinally { - val conf = new SparkConf() - .set("spark.shuffle.spill.initialMemoryThreshold", "1") - .set("spark.shuffle.sort.bypassMergeThreshold", "0") - .set("spark.testing.memory", "80000") - - sc = new SparkContext("local", "test", conf) - outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "") - // prepare data - val converter = unsafeRowConverter(Array(IntegerType)) - val data = (1 to 10000).iterator.map { i => - (i, converter(Row(i))) - } - val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0) - val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null) - - val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( - taskContext, - partitioner = Some(new HashPartitioner(10)), - serializer = new UnsafeRowSerializer(numFields = 1)) - - // Ensure we spilled something and have to merge them later - assert(sorter.numSpills === 0) - sorter.insertAll(data) - assert(sorter.numSpills > 0) + val conf = new SparkConf() + .set("spark.shuffle.spill.initialMemoryThreshold", "1") + .set("spark.shuffle.sort.bypassMergeThreshold", "0") + .set("spark.testing.memory", "80000") + spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate() + val outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "") + outputFile.deleteOnExit() + // prepare data + val converter = unsafeRowConverter(Array(IntegerType)) + val data = (1 to 10000).iterator.map { i => + (i, converter(Row(i))) + } + val taskMemoryManager = new TaskMemoryManager(spark.sparkContext.env.memoryManager, 0) + val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null) - // Merging spilled files should not throw assertion error - sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile) - } { - // Clean up - if (sc != null) { - sc.stop() - } + val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( + taskContext, + partitioner = Some(new HashPartitioner(10)), + serializer = new UnsafeRowSerializer(numFields = 1)) - // restore the spark env - SparkEnv.set(oldEnv) + // Ensure we spilled something and have to merge them later + assert(sorter.numSpills === 0) + sorter.insertAll(data) + assert(sorter.numSpills > 0) - if (outputFile != null) { - outputFile.delete() - } - } + // Merging spilled files should not throw assertion error + sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile) } test("SPARK-10403: unsafe row serializer with SortShuffleManager") { val conf = new SparkConf().set("spark.shuffle.manager", "sort") - sc = new SparkContext("local", "test", conf) + spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate() val row = Row("Hello", 123) val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) - val rowsRDD = sc.parallelize(Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow))) - .asInstanceOf[RDD[Product2[Int, InternalRow]]] + val rowsRDD = spark.sparkContext.parallelize( + Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow)) + ).asInstanceOf[RDD[Product2[Int, InternalRow]]] val dependency = new ShuffleDependency[Int, InternalRow, InternalRow]( rowsRDD, From 1d7db65e968de1c601e7f8b1ec9bc783ef2dbd01 Mon Sep 17 00:00:00 2001 From: Tom Saleeba Date: Tue, 12 Jun 2018 09:22:52 -0500 Subject: [PATCH 0945/2461] docs: fix typo no => no[t] ## What changes were proposed in this pull request? Fixing a typo. ## How was this patch tested? Visual check of the docs. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Tom Saleeba Closes #21496 from tomsaleeba/patch-1. --- sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 2dbb53e7c906b..4eee3de5f7d4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -104,7 +104,7 @@ class TypedColumn[-T, U]( * * {{{ * df("columnName") // On a specific `df` DataFrame. - * col("columnName") // A generic column no yet associated with a DataFrame. + * col("columnName") // A generic column not yet associated with a DataFrame. * col("columnName.field") // Extracting a struct field * col("`a.column.with.dots`") // Escape `.` in column names. * $"columnName" // Scala short hand for a named column. From 5d6a53d9831cc1e2115560db5cebe0eea2565dcd Mon Sep 17 00:00:00 2001 From: Lee Dongjin Date: Tue, 12 Jun 2018 08:16:37 -0700 Subject: [PATCH 0946/2461] [SPARK-15064][ML] Locale support in StopWordsRemover ## What changes were proposed in this pull request? Add locale support for `StopWordsRemover`. ## How was this patch tested? [Scala|Python] unit tests. Author: Lee Dongjin Closes #21501 from dongjinleekr/feature/SPARK-15064. --- .../spark/ml/feature/StopWordsRemover.scala | 30 +++++++++-- .../ml/feature/StopWordsRemoverSuite.scala | 51 +++++++++++++++++++ python/pyspark/ml/feature.py | 30 +++++++++-- python/pyspark/ml/tests.py | 7 +++ 4 files changed, 109 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 3fcd84c029e61..0f946dd2e015b 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -17,9 +17,11 @@ package org.apache.spark.ml.feature +import java.util.Locale + import org.apache.spark.annotation.Since import org.apache.spark.ml.Transformer -import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} +import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -84,7 +86,27 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String @Since("1.5.0") def getCaseSensitive: Boolean = $(caseSensitive) - setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"), caseSensitive -> false) + /** + * Locale of the input for case insensitive matching. Ignored when [[caseSensitive]] + * is true. + * Default: Locale.getDefault.toString + * @group param + */ + @Since("2.4.0") + val locale: Param[String] = new Param[String](this, "locale", + "Locale of the input for case insensitive matching. Ignored when caseSensitive is true.", + ParamValidators.inArray[String](Locale.getAvailableLocales.map(_.toString))) + + /** @group setParam */ + @Since("2.4.0") + def setLocale(value: String): this.type = set(locale, value) + + /** @group getParam */ + @Since("2.4.0") + def getLocale: String = $(locale) + + setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"), + caseSensitive -> false, locale -> Locale.getDefault.toString) @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { @@ -95,8 +117,8 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String terms.filter(s => !stopWordsSet.contains(s)) } } else { - // TODO: support user locale (SPARK-15064) - val toLower = (s: String) => if (s != null) s.toLowerCase else s + val lc = new Locale($(locale)) + val toLower = (s: String) => if (s != null) s.toLowerCase(lc) else s val lowerStopWords = $(stopWords).map(toLower(_)).toSet udf { terms: Seq[String] => terms.filter(s => !lowerStopWords.contains(toLower(s))) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index 21259a50916d2..20972d1f403b9 100755 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -65,6 +65,57 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest { testStopWordsRemover(remover, dataSet) } + test("StopWordsRemover with localed input (case insensitive)") { + val stopWords = Array("milk", "cookie") + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + .setStopWords(stopWords) + .setCaseSensitive(false) + .setLocale("tr") // Turkish alphabet: has no Q, W, X but has dotted and dotless 'I's. + val dataSet = Seq( + // scalastyle:off + (Seq("mİlk", "and", "nuts"), Seq("and", "nuts")), + // scalastyle:on + (Seq("cookIe", "and", "nuts"), Seq("cookIe", "and", "nuts")), + (Seq(null), Seq(null)), + (Seq(), Seq()) + ).toDF("raw", "expected") + + testStopWordsRemover(remover, dataSet) + } + + test("StopWordsRemover with localed input (case sensitive)") { + val stopWords = Array("milk", "cookie") + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + .setStopWords(stopWords) + .setCaseSensitive(true) + .setLocale("tr") // Turkish alphabet: has no Q, W, X but has dotted and dotless 'I's. + val dataSet = Seq( + // scalastyle:off + (Seq("mİlk", "and", "nuts"), Seq("mİlk", "and", "nuts")), + // scalastyle:on + (Seq("cookIe", "and", "nuts"), Seq("cookIe", "and", "nuts")), + (Seq(null), Seq(null)), + (Seq(), Seq()) + ).toDF("raw", "expected") + + testStopWordsRemover(remover, dataSet) + } + + test("StopWordsRemover with invalid locale") { + intercept[IllegalArgumentException] { + val stopWords = Array("test", "a", "an", "the") + new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + .setStopWords(stopWords) + .setLocale("rt") // invalid locale + } + } + test("StopWordsRemover case sensitive") { val remover = new StopWordsRemover() .setInputCol("raw") diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index cdda30cfab482..14800d4d9327a 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2582,25 +2582,31 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl typeConverter=TypeConverters.toListString) caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " + "comparison over the stop words", typeConverter=TypeConverters.toBoolean) + locale = Param(Params._dummy(), "locale", "locale of the input. ignored when case sensitive " + + "is true", typeConverter=TypeConverters.toString) @keyword_only - def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False): + def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False, + locale=None): """ - __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false) + __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \ + locale=None) """ super(StopWordsRemover, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover", self.uid) self._setDefault(stopWords=StopWordsRemover.loadDefaultStopWords("english"), - caseSensitive=False) + caseSensitive=False, locale=self._java_obj.getLocale()) kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.6.0") - def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False): + def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False, + locale=None): """ - setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false) + setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \ + locale=None) Sets params for this StopWordRemover. """ kwargs = self._input_kwargs @@ -2634,6 +2640,20 @@ def getCaseSensitive(self): """ return self.getOrDefault(self.caseSensitive) + @since("2.4.0") + def setLocale(self, value): + """ + Sets the value of :py:attr:`locale`. + """ + return self._set(locale=value) + + @since("2.4.0") + def getLocale(self): + """ + Gets the value of :py:attr:`locale`. + """ + return self.getOrDefault(self.locale) + @staticmethod @since("2.0.0") def loadDefaultStopWords(language): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 0dde0db9e3339..ebd36cbb5f7a7 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -681,6 +681,13 @@ def test_stopwordsremover(self): self.assertEqual(stopWordRemover.getStopWords(), stopwords) transformedDF = stopWordRemover.transform(dataset) self.assertEqual(transformedDF.head().output, []) + # with locale + stopwords = ["BELKİ"] + dataset = self.spark.createDataFrame([Row(input=["belki"])]) + stopWordRemover.setStopWords(stopwords).setLocale("tr") + self.assertEqual(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, []) def test_count_vectorizer_with_binary(self): dataset = self.spark.createDataFrame([ From 2824f1436bb0371b7216730455f02456ef8479ce Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 12 Jun 2018 09:56:35 -0700 Subject: [PATCH 0947/2461] [SPARK-24531][TESTS] Remove version 2.2.0 from testing versions in HiveExternalCatalogVersionsSuite ## What changes were proposed in this pull request? Removing version 2.2.0 from testing versions in HiveExternalCatalogVersionsSuite as it is not present anymore in the mirrors and this is blocking all the open PRs. ## How was this patch tested? running UTs Author: Marco Gaido Closes #21540 from mgaido91/SPARK-24531. --- .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index ea86ab9772bc7..6f904c937348d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -195,7 +195,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0", "2.2.1", "2.3.0") + val testingVersions = Seq("2.0.2", "2.1.2", "2.2.1", "2.3.0") protected var spark: SparkSession = _ From 3af1d3e6d95719e15a997877d5ecd3bb40c08b9c Mon Sep 17 00:00:00 2001 From: Sanket Chintapalli Date: Tue, 12 Jun 2018 13:55:08 -0500 Subject: [PATCH 0948/2461] [SPARK-24416] Fix configuration specification for killBlacklisted executors ## What changes were proposed in this pull request? spark.blacklist.killBlacklistedExecutors is defined as (Experimental) If set to "true", allow Spark to automatically kill, and attempt to re-create, executors when they are blacklisted. Note that, when an entire node is added to the blacklist, all of the executors on that node will be killed. I presume the killing of blacklisted executors only happens after the stage completes successfully and all tasks have completed or on fetch failures (updateBlacklistForFetchFailure/updateBlacklistForSuccessfulTaskSet). It is confusing because the definition states that the executor will be attempted to be recreated as soon as it is blacklisted. This is not true while the stage is in progress and an executor is blacklisted, it will not attempt to cleanup until the stage finishes. Author: Sanket Chintapalli Closes #21475 from redsanket/SPARK-24416. --- docs/configuration.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 5588c372d3e42..6aa7878fe614d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1656,9 +1656,10 @@ Apart from these, the following properties are also available, and may be useful spark.blacklist.killBlacklistedExecutors false - (Experimental) If set to "true", allow Spark to automatically kill, and attempt to re-create, - executors when they are blacklisted. Note that, when an entire node is added to the blacklist, - all of the executors on that node will be killed. + (Experimental) If set to "true", allow Spark to automatically kill the executors + when they are blacklisted on fetch failure or blacklisted for the entire application, + as controlled by spark.blacklist.application.*. Note that, when an entire node is added + to the blacklist, all of the executors on that node will be killed. From f0ef1b311dd5399290ad6abe4ca491bdb13478f0 Mon Sep 17 00:00:00 2001 From: DylanGuedes Date: Tue, 12 Jun 2018 11:57:25 -0700 Subject: [PATCH 0949/2461] [SPARK-23931][SQL] Adds arrays_zip function to sparksql Signed-off-by: DylanGuedes ## What changes were proposed in this pull request? Addition of arrays_zip function to spark sql functions. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) Unit tests that checks if the results are correct. Author: DylanGuedes Closes #21045 from DylanGuedes/SPARK-23931. --- python/pyspark/sql/functions.py | 17 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 166 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 86 +++++++++ .../org/apache/spark/sql/functions.scala | 8 + .../spark/sql/DataFrameFunctionsSuite.scala | 47 +++++ 6 files changed, 325 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1759195c6fcc0..0715297042520 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2394,6 +2394,23 @@ def array_repeat(col, count): return Column(sc._jvm.functions.array_repeat(_to_java_column(col), count)) +@since(2.4) +def arrays_zip(*cols): + """ + Collection function: Returns a merged array of structs in which the N-th struct contains all + N-th values of input arrays. + + :param cols: columns of arrays to be merged. + + >>> from pyspark.sql.functions import arrays_zip + >>> df = spark.createDataFrame([(([1, 2, 3], [2, 3, 4]))], ['vals1', 'vals2']) + >>> df.select(arrays_zip(df.vals1, df.vals2).alias('zipped')).collect() + [Row(zipped=[Row(vals1=1, vals2=2), Row(vals1=2, vals2=3), Row(vals1=3, vals2=4)])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.arrays_zip(_to_seq(sc, cols, _to_java_column))) + + # ---------------------------- User Defined Function ---------------------------------- class PandasUDFType(object): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 49fb35b083580..3c0b72873af54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -423,6 +423,7 @@ object FunctionRegistry { expression[Size]("size"), expression[Slice]("slice"), expression[Size]("cardinality"), + expression[ArraysZip]("arrays_zip"), expression[SortArray]("sort_array"), expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 176995affe701..d76f3013f0c41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -128,6 +128,172 @@ case class MapKeys(child: Expression) override def prettyName: String = "map_keys" } +@ExpressionDescription( + usage = """ + _FUNC_(a1, a2, ...) - Returns a merged array of structs in which the N-th struct contains all + N-th values of input arrays. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(2, 3, 4)); + [[1, 2], [2, 3], [3, 4]] + > SELECT _FUNC_(array(1, 2), array(2, 3), array(3, 4)); + [[1, 2, 3], [2, 3, 4]] + """, + since = "2.4.0") +case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType) + + override def dataType: DataType = ArrayType(mountSchema) + + override def nullable: Boolean = children.exists(_.nullable) + + private lazy val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType]) + + private lazy val arrayElementTypes = arrayTypes.map(_.elementType) + + @transient private lazy val mountSchema: StructType = { + val fields = children.zip(arrayElementTypes).zipWithIndex.map { + case ((expr: NamedExpression, elementType), _) => + StructField(expr.name, elementType, nullable = true) + case ((_, elementType), idx) => + StructField(idx.toString, elementType, nullable = true) + } + StructType(fields) + } + + @transient lazy val numberOfArrays: Int = children.length + + @transient lazy val genericArrayData = classOf[GenericArrayData].getName + + def emptyInputGenCode(ev: ExprCode): ExprCode = { + ev.copy(code""" + |${CodeGenerator.javaType(dataType)} ${ev.value} = new $genericArrayData(new Object[0]); + |boolean ${ev.isNull} = false; + """.stripMargin) + } + + def nonEmptyInputGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val genericInternalRow = classOf[GenericInternalRow].getName + val arrVals = ctx.freshName("arrVals") + val biggestCardinality = ctx.freshName("biggestCardinality") + + val currentRow = ctx.freshName("currentRow") + val j = ctx.freshName("j") + val i = ctx.freshName("i") + val args = ctx.freshName("args") + + val evals = children.map(_.genCode(ctx)) + val getValuesAndCardinalities = evals.zipWithIndex.map { case (eval, index) => + s""" + |if ($biggestCardinality != -1) { + | ${eval.code} + | if (!${eval.isNull}) { + | $arrVals[$index] = ${eval.value}; + | $biggestCardinality = Math.max($biggestCardinality, ${eval.value}.numElements()); + | } else { + | $biggestCardinality = -1; + | } + |} + """.stripMargin + } + + val splittedGetValuesAndCardinalities = ctx.splitExpressions( + expressions = getValuesAndCardinalities, + funcName = "getValuesAndCardinalities", + returnType = "int", + makeSplitFunction = body => + s""" + |$body + |return $biggestCardinality; + """.stripMargin, + foldFunctions = _.map(funcCall => s"$biggestCardinality = $funcCall;").mkString("\n"), + arguments = + ("ArrayData[]", arrVals) :: + ("int", biggestCardinality) :: Nil) + + val getValueForType = arrayElementTypes.zipWithIndex.map { case (eleType, idx) => + val g = CodeGenerator.getValue(s"$arrVals[$idx]", eleType, i) + s""" + |if ($i < $arrVals[$idx].numElements() && !$arrVals[$idx].isNullAt($i)) { + | $currentRow[$idx] = $g; + |} else { + | $currentRow[$idx] = null; + |} + """.stripMargin + } + + val getValueForTypeSplitted = ctx.splitExpressions( + expressions = getValueForType, + funcName = "extractValue", + arguments = + ("int", i) :: + ("Object[]", currentRow) :: + ("ArrayData[]", arrVals) :: Nil) + + val initVariables = s""" + |ArrayData[] $arrVals = new ArrayData[$numberOfArrays]; + |int $biggestCardinality = 0; + |${CodeGenerator.javaType(dataType)} ${ev.value} = null; + """.stripMargin + + ev.copy(code""" + |$initVariables + |$splittedGetValuesAndCardinalities + |boolean ${ev.isNull} = $biggestCardinality == -1; + |if (!${ev.isNull}) { + | Object[] $args = new Object[$biggestCardinality]; + | for (int $i = 0; $i < $biggestCardinality; $i ++) { + | Object[] $currentRow = new Object[$numberOfArrays]; + | $getValueForTypeSplitted + | $args[$i] = new $genericInternalRow($currentRow); + | } + | ${ev.value} = new $genericArrayData($args); + |} + """.stripMargin) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + if (numberOfArrays == 0) { + emptyInputGenCode(ev) + } else { + nonEmptyInputGenCode(ctx, ev) + } + } + + override def eval(input: InternalRow): Any = { + val inputArrays = children.map(_.eval(input).asInstanceOf[ArrayData]) + if (inputArrays.contains(null)) { + null + } else { + val biggestCardinality = if (inputArrays.isEmpty) { + 0 + } else { + inputArrays.map(_.numElements()).max + } + + val result = new Array[InternalRow](biggestCardinality) + val zippedArrs: Seq[(ArrayData, Int)] = inputArrays.zipWithIndex + + for (i <- 0 until biggestCardinality) { + val currentLayer: Seq[Object] = zippedArrs.map { case (arr, index) => + if (i < arr.numElements() && !arr.isNullAt(i)) { + arr.get(i, arrayElementTypes(index)) + } else { + null + } + } + + result(i) = InternalRow.apply(currentLayer: _*) + } + new GenericArrayData(result) + } + } + + override def prettyName: String = "arrays_zip" +} + /** * Returns an unordered array containing the values of the map. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index f8ad624ce0e3d..85e692bdc4ef1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ @@ -315,6 +316,91 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper Some(Literal.create(null, StringType))), null) } + test("ArraysZip") { + val literals = Seq( + Literal.create(Seq(9001, 9002, 9003, null), ArrayType(IntegerType)), + Literal.create(Seq(null, 1L, null, 4L, 11L), ArrayType(LongType)), + Literal.create(Seq(-1, -3, 900, null), ArrayType(IntegerType)), + Literal.create(Seq("a", null, "c"), ArrayType(StringType)), + Literal.create(Seq(null, false, true), ArrayType(BooleanType)), + Literal.create(Seq(1.1, null, 1.3, null), ArrayType(DoubleType)), + Literal.create(Seq(), ArrayType(NullType)), + Literal.create(Seq(null), ArrayType(NullType)), + Literal.create(Seq(192.toByte), ArrayType(ByteType)), + Literal.create( + Seq(Seq(1, 2, 3), null, Seq(4, 5), Seq(1, null, 3)), ArrayType(ArrayType(IntegerType))), + Literal.create(Seq(Array[Byte](1.toByte, 5.toByte)), ArrayType(BinaryType)) + ) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(1))), + List(Row(9001, null), Row(9002, 1L), Row(9003, null), Row(null, 4L), Row(null, 11L))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(2))), + List(Row(9001, -1), Row(9002, -3), Row(9003, 900), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(3))), + List(Row(9001, "a"), Row(9002, null), Row(9003, "c"), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(4))), + List(Row(9001, null), Row(9002, false), Row(9003, true), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(5))), + List(Row(9001, 1.1), Row(9002, null), Row(9003, 1.3), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(6))), + List(Row(9001, null), Row(9002, null), Row(9003, null), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(7))), + List(Row(9001, null), Row(9002, null), Row(9003, null), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(1), literals(2), literals(3))), + List( + Row(9001, null, -1, "a"), + Row(9002, 1L, -3, null), + Row(9003, null, 900, "c"), + Row(null, 4L, null, null), + Row(null, 11L, null, null))) + + checkEvaluation(ArraysZip(Seq(literals(4), literals(5), literals(6), literals(7), literals(8))), + List( + Row(null, 1.1, null, null, 192.toByte), + Row(false, null, null, null, null), + Row(true, 1.3, null, null, null), + Row(null, null, null, null, null))) + + checkEvaluation(ArraysZip(Seq(literals(9), literals(0))), + List( + Row(List(1, 2, 3), 9001), + Row(null, 9002), + Row(List(4, 5), 9003), + Row(List(1, null, 3), null))) + + checkEvaluation(ArraysZip(Seq(literals(7), literals(10))), + List(Row(null, Array[Byte](1.toByte, 5.toByte)))) + + val longLiteral = + Literal.create((0 to 1000).toSeq, ArrayType(IntegerType)) + + checkEvaluation(ArraysZip(Seq(literals(0), longLiteral)), + List(Row(9001, 0), Row(9002, 1), Row(9003, 2)) ++ + (3 to 1000).map { Row(null, _) }.toList) + + val manyLiterals = (0 to 1000).map { _ => + Literal.create(Seq(1), ArrayType(IntegerType)) + }.toSeq + + val numbers = List( + Row(Seq(9001) ++ (0 to 1000).map { _ => 1 }.toSeq: _*), + Row(Seq(9002) ++ (0 to 1000).map { _ => null }.toSeq: _*), + Row(Seq(9003) ++ (0 to 1000).map { _ => null }.toSeq: _*), + Row(Seq(null) ++ (0 to 1000).map { _ => null }.toSeq: _*)) + checkEvaluation(ArraysZip(Seq(literals(0)) ++ manyLiterals), + List(numbers(0), numbers(1), numbers(2), numbers(3))) + + checkEvaluation(ArraysZip(Seq(literals(0), Literal.create(null, ArrayType(IntegerType)))), null) + checkEvaluation(ArraysZip(Seq()), List()) + } + test("Array Min") { checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11) checkEvaluation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a2aae9a708ff3..266a136fc2410 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3508,6 +3508,14 @@ object functions { */ def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) } + /** + * Returns a merged array of structs in which the N-th struct contains all N-th values of input + * arrays. + * @group collection_funcs + * @since 2.4.0 + */ + def arrays_zip(e: Column*): Column = withExpr { ArraysZip(e.map(_.expr)) } + ////////////////////////////////////////////////////////////////////////////////////////////// // Mask functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 59119bbbd8a2c..959a77a9ea345 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -479,6 +479,53 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("dataframe arrays_zip function") { + val df1 = Seq((Seq(9001, 9002, 9003), Seq(4, 5, 6))).toDF("val1", "val2") + val df2 = Seq((Seq("a", "b"), Seq(true, false), Seq(10, 11))).toDF("val1", "val2", "val3") + val df3 = Seq((Seq("a", "b"), Seq(4, 5, 6))).toDF("val1", "val2") + val df4 = Seq((Seq("a", "b", null), Seq(4L))).toDF("val1", "val2") + val df5 = Seq((Seq(-1), Seq(null), Seq(), Seq(null, null))).toDF("val1", "val2", "val3", "val4") + val df6 = Seq((Seq(192.toByte, 256.toByte), Seq(1.1), Seq(), Seq(null, null))) + .toDF("v1", "v2", "v3", "v4") + val df7 = Seq((Seq(Seq(1, 2, 3), Seq(4, 5)), Seq(1.1, 2.2))).toDF("v1", "v2") + val df8 = Seq((Seq(Array[Byte](1.toByte, 5.toByte)), Seq(null))).toDF("v1", "v2") + + val expectedValue1 = Row(Seq(Row(9001, 4), Row(9002, 5), Row(9003, 6))) + checkAnswer(df1.select(arrays_zip($"val1", $"val2")), expectedValue1) + checkAnswer(df1.selectExpr("arrays_zip(val1, val2)"), expectedValue1) + + val expectedValue2 = Row(Seq(Row("a", true, 10), Row("b", false, 11))) + checkAnswer(df2.select(arrays_zip($"val1", $"val2", $"val3")), expectedValue2) + checkAnswer(df2.selectExpr("arrays_zip(val1, val2, val3)"), expectedValue2) + + val expectedValue3 = Row(Seq(Row("a", 4), Row("b", 5), Row(null, 6))) + checkAnswer(df3.select(arrays_zip($"val1", $"val2")), expectedValue3) + checkAnswer(df3.selectExpr("arrays_zip(val1, val2)"), expectedValue3) + + val expectedValue4 = Row(Seq(Row("a", 4L), Row("b", null), Row(null, null))) + checkAnswer(df4.select(arrays_zip($"val1", $"val2")), expectedValue4) + checkAnswer(df4.selectExpr("arrays_zip(val1, val2)"), expectedValue4) + + val expectedValue5 = Row(Seq(Row(-1, null, null, null), Row(null, null, null, null))) + checkAnswer(df5.select(arrays_zip($"val1", $"val2", $"val3", $"val4")), expectedValue5) + checkAnswer(df5.selectExpr("arrays_zip(val1, val2, val3, val4)"), expectedValue5) + + val expectedValue6 = Row(Seq( + Row(192.toByte, 1.1, null, null), Row(256.toByte, null, null, null))) + checkAnswer(df6.select(arrays_zip($"v1", $"v2", $"v3", $"v4")), expectedValue6) + checkAnswer(df6.selectExpr("arrays_zip(v1, v2, v3, v4)"), expectedValue6) + + val expectedValue7 = Row(Seq( + Row(Seq(1, 2, 3), 1.1), Row(Seq(4, 5), 2.2))) + checkAnswer(df7.select(arrays_zip($"v1", $"v2")), expectedValue7) + checkAnswer(df7.selectExpr("arrays_zip(v1, v2)"), expectedValue7) + + val expectedValue8 = Row(Seq( + Row(Array[Byte](1.toByte, 5.toByte), null))) + checkAnswer(df8.select(arrays_zip($"v1", $"v2")), expectedValue8) + checkAnswer(df8.selectExpr("arrays_zip(v1, v2)"), expectedValue8) + } + test("map size function") { val df = Seq( (Map[Int, Int](1 -> 1, 2 -> 2), "x"), From cc88d7fad16e8b5cbf7b6b9bfe412908782b4a45 Mon Sep 17 00:00:00 2001 From: Fangshi Li Date: Tue, 12 Jun 2018 12:10:08 -0700 Subject: [PATCH 0950/2461] [SPARK-24216][SQL] Spark TypedAggregateExpression uses getSimpleName that is not safe in scala ## What changes were proposed in this pull request? When user create a aggregator object in scala and pass the aggregator to Spark Dataset's agg() method, Spark's will initialize TypedAggregateExpression with the nodeName field as aggregator.getClass.getSimpleName. However, getSimpleName is not safe in scala environment, depending on how user creates the aggregator object. For example, if the aggregator class full qualified name is "com.my.company.MyUtils$myAgg$2$", the getSimpleName will throw java.lang.InternalError "Malformed class name". This has been reported in scalatest https://github.com/scalatest/scalatest/pull/1044 and discussed in many scala upstream jiras such as SI-8110, SI-5425. To fix this issue, we follow the solution in https://github.com/scalatest/scalatest/pull/1044 to add safer version of getSimpleName as a util method, and TypedAggregateExpression will invoke this util method rather than getClass.getSimpleName. ## How was this patch tested? added unit test Author: Fangshi Li Closes #21276 from fangshil/SPARK-24216. --- .../org/apache/spark/util/AccumulatorV2.scala | 6 +- .../scala/org/apache/spark/util/Utils.scala | 59 ++++++++++++++++++- .../org/apache/spark/util/UtilsSuite.scala | 16 +++++ .../spark/ml/util/Instrumentation.scala | 5 +- .../aggregate/TypedAggregateExpression.scala | 5 +- .../v2/DataSourceV2StringFormat.scala | 4 +- 6 files changed, 89 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 3b469a69437b9..bf618b4afbce0 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -200,10 +200,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { } override def toString: String = { + // getClass.getSimpleName can cause Malformed class name error, + // call safer `Utils.getSimpleName` instead if (metadata == null) { - "Un-registered Accumulator: " + getClass.getSimpleName + "Un-registered Accumulator: " + Utils.getSimpleName(getClass) } else { - getClass.getSimpleName + s"(id: $id, name: $name, value: $value)" + Utils.getSimpleName(getClass) + s"(id: $id, name: $name, value: $value)" } } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index f9191a59c1655..7428db2158538 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -19,6 +19,7 @@ package org.apache.spark.util import java.io._ import java.lang.{Byte => JByte} +import java.lang.InternalError import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo} import java.lang.reflect.InvocationTargetException import java.math.{MathContext, RoundingMode} @@ -1820,7 +1821,7 @@ private[spark] object Utils extends Logging { /** Return the class name of the given object, removing all dollar signs */ def getFormattedClassName(obj: AnyRef): String = { - obj.getClass.getSimpleName.replace("$", "") + getSimpleName(obj.getClass).replace("$", "") } /** @@ -2715,6 +2716,62 @@ private[spark] object Utils extends Logging { HashCodes.fromBytes(secretBytes).toString() } + /** + * Safer than Class obj's getSimpleName which may throw Malformed class name error in scala. + * This method mimicks scalatest's getSimpleNameOfAnObjectsClass. + */ + def getSimpleName(cls: Class[_]): String = { + try { + return cls.getSimpleName + } catch { + case err: InternalError => return stripDollars(stripPackages(cls.getName)) + } + } + + /** + * Remove the packages from full qualified class name + */ + private def stripPackages(fullyQualifiedName: String): String = { + fullyQualifiedName.split("\\.").takeRight(1)(0) + } + + /** + * Remove trailing dollar signs from qualified class name, + * and return the trailing part after the last dollar sign in the middle + */ + private def stripDollars(s: String): String = { + val lastDollarIndex = s.lastIndexOf('$') + if (lastDollarIndex < s.length - 1) { + // The last char is not a dollar sign + if (lastDollarIndex == -1 || !s.contains("$iw")) { + // The name does not have dollar sign or is not an intepreter + // generated class, so we should return the full string + s + } else { + // The class name is intepreter generated, + // return the part after the last dollar sign + // This is the same behavior as getClass.getSimpleName + s.substring(lastDollarIndex + 1) + } + } + else { + // The last char is a dollar sign + // Find last non-dollar char + val lastNonDollarChar = s.reverse.find(_ != '$') + lastNonDollarChar match { + case None => s + case Some(c) => + val lastNonDollarIndex = s.lastIndexOf(c) + if (lastNonDollarIndex == -1) { + s + } else { + // Strip the trailing dollar signs + // Invoke stripDollars again to get the simple name + stripDollars(s.substring(0, lastNonDollarIndex + 1)) + } + } + } + } } private[util] object CallerContext extends Logging { diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 3b4273184f1e9..418d2f9b88500 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -1168,6 +1168,22 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { Utils.checkAndGetK8sMasterUrl("k8s://foo://host:port") } } + + object MalformedClassObject { + class MalformedClass + } + + test("Safe getSimpleName") { + // getSimpleName on class of MalformedClass will result in error: Malformed class name + // Utils.getSimpleName works + val err = intercept[java.lang.InternalError] { + classOf[MalformedClassObject.MalformedClass].getSimpleName + } + assert(err.getMessage === "Malformed class name") + + assert(Utils.getSimpleName(classOf[MalformedClassObject.MalformedClass]) === + "UtilsSuite$MalformedClassObject$MalformedClass") + } } private class SimpleExtension diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index 3a1c166d46257..11f46eb9e4359 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -30,6 +30,7 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.Param import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset +import org.apache.spark.util.Utils /** * A small wrapper that defines a training session for an estimator, and some methods to log @@ -47,7 +48,9 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( private val id = UUID.randomUUID() private val prefix = { - val className = estimator.getClass.getSimpleName + // estimator.getClass.getSimpleName can cause Malformed class name error, + // call safer `Utils.getSimpleName` instead + val className = Utils.getSimpleName(estimator.getClass) s"$className-${estimator.uid}-${dataset.hashCode()}-$id: " } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index aab8cc50b9526..6d44890704f49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils object TypedAggregateExpression { def apply[BUF : Encoder, OUT : Encoder]( @@ -109,7 +110,9 @@ trait TypedAggregateExpression extends AggregateFunction { s"$nodeName($input)" } - override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$") + // aggregator.getClass.getSimpleName can cause Malformed class name error, + // call safer `Utils.getSimpleName` instead + override def nodeName: String = Utils.getSimpleName(aggregator.getClass).stripSuffix("$"); } // TODO: merge these 2 implementations once we refactor the `AggregateFunction` interface. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala index 693e67dcd108e..97e6c6d702acb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala @@ -53,7 +53,9 @@ trait DataSourceV2StringFormat { private def sourceName: String = source match { case registered: DataSourceRegister => registered.shortName() - case _ => source.getClass.getSimpleName.stripSuffix("$") + // source.getClass.getSimpleName can cause Malformed class name error, + // call safer `Utils.getSimpleName` instead + case _ => Utils.getSimpleName(source.getClass) } def metadataString: String = { From ada28f25955a9e8ddd182ad41b2a4ef278f3d809 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 12 Jun 2018 12:31:22 -0700 Subject: [PATCH 0951/2461] [SPARK-23933][SQL] Add map_from_arrays function ## What changes were proposed in this pull request? The PR adds the SQL function `map_from_arrays`. The behavior of the function is based on Presto's `map`. Since SparkSQL already had a `map` function, we prepared the different name for this behavior. This function returns returns a map from a pair of arrays for keys and values. ## How was this patch tested? Added UTs Author: Kazuaki Ishizaki Closes #21258 from kiszk/SPARK-23933. --- python/pyspark/sql/functions.py | 19 +++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/complexTypeCreator.scala | 72 ++++++++++++++++++- .../expressions/ComplexTypeSuite.scala | 44 ++++++++++++ .../org/apache/spark/sql/functions.scala | 11 +++ .../spark/sql/DataFrameFunctionsSuite.scala | 30 ++++++++ 6 files changed, 176 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 0715297042520..1cdbb8a4c3e8b 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1819,6 +1819,25 @@ def create_map(*cols): return Column(jc) +@since(2.4) +def map_from_arrays(col1, col2): + """Creates a new map from two arrays. + + :param col1: name of column containing a set of keys. All elements should not be null + :param col2: name of column containing a set of values + + >>> df = spark.createDataFrame([([2, 5], ['a', 'b'])], ['k', 'v']) + >>> df.select(map_from_arrays(df.k, df.v).alias("map")).show() + +----------------+ + | map| + +----------------+ + |[2 -> a, 5 -> b]| + +----------------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_from_arrays(_to_java_column(col1), _to_java_column(col2))) + + @since(1.4) def array(*cols): """Creates a new array column. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3c0b72873af54..3700c63d817ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -417,6 +417,7 @@ object FunctionRegistry { expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[ElementAt]("element_at"), + expression[MapFromArrays]("map_from_arrays"), expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[MapEntries]("map_entries"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index a9867aaeb0cfe..0a5f8a907b50a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods @@ -236,6 +236,76 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def prettyName: String = "map" } +/** + * Returns a catalyst Map containing the two arrays in children expressions as keys and values. + */ +@ExpressionDescription( + usage = """ + _FUNC_(keys, values) - Creates a map with a pair of the given key/value arrays. All elements + in keys should not be null""", + examples = """ + Examples: + > SELECT _FUNC_([1.0, 3.0], ['2', '4']); + {1.0:"2",3.0:"4"} + """, since = "2.4.0") +case class MapFromArrays(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) + + override def dataType: DataType = { + MapType( + keyType = left.dataType.asInstanceOf[ArrayType].elementType, + valueType = right.dataType.asInstanceOf[ArrayType].elementType, + valueContainsNull = right.dataType.asInstanceOf[ArrayType].containsNull) + } + + override def nullSafeEval(keyArray: Any, valueArray: Any): Any = { + val keyArrayData = keyArray.asInstanceOf[ArrayData] + val valueArrayData = valueArray.asInstanceOf[ArrayData] + if (keyArrayData.numElements != valueArrayData.numElements) { + throw new RuntimeException("The given two arrays should have the same length") + } + val leftArrayType = left.dataType.asInstanceOf[ArrayType] + if (leftArrayType.containsNull) { + var i = 0 + while (i < keyArrayData.numElements) { + if (keyArrayData.isNullAt(i)) { + throw new RuntimeException("Cannot use null as map key!") + } + i += 1 + } + } + new ArrayBasedMapData(keyArrayData.copy(), valueArrayData.copy()) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (keyArrayData, valueArrayData) => { + val arrayBasedMapData = classOf[ArrayBasedMapData].getName + val leftArrayType = left.dataType.asInstanceOf[ArrayType] + val keyArrayElemNullCheck = if (!leftArrayType.containsNull) "" else { + val i = ctx.freshName("i") + s""" + |for (int $i = 0; $i < $keyArrayData.numElements(); $i++) { + | if ($keyArrayData.isNullAt($i)) { + | throw new RuntimeException("Cannot use null as map key!"); + | } + |} + """.stripMargin + } + s""" + |if ($keyArrayData.numElements() != $valueArrayData.numElements()) { + | throw new RuntimeException("The given two arrays should have the same length"); + |} + |$keyArrayElemNullCheck + |${ev.value} = new $arrayBasedMapData($keyArrayData.copy(), $valueArrayData.copy()); + """.stripMargin + }) + } + + override def prettyName: String = "map_from_arrays" +} + /** * An expression representing a not yet available attribute name. This expression is unevaluable * and as its name suggests it is a temporary place holder until we're able to determine the diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index b4138ce366b3a..726193b411737 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -186,6 +186,50 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("MapFromArrays") { + def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { + // catalyst map is order-sensitive, so we create ListMap here to preserve the elements order. + scala.collection.immutable.ListMap(keys.zip(values): _*) + } + + val intSeq = Seq(5, 10, 15, 20, 25) + val longSeq = intSeq.map(_.toLong) + val strSeq = intSeq.map(_.toString) + val integerSeq = Seq[java.lang.Integer](5, 10, 15, 20, 25) + val intWithNullSeq = Seq[java.lang.Integer](5, 10, null, 20, 25) + val longWithNullSeq = intSeq.map(java.lang.Long.valueOf(_)) + + val intArray = Literal.create(intSeq, ArrayType(IntegerType, false)) + val longArray = Literal.create(longSeq, ArrayType(LongType, false)) + val strArray = Literal.create(strSeq, ArrayType(StringType, false)) + + val integerArray = Literal.create(integerSeq, ArrayType(IntegerType, true)) + val intWithNullArray = Literal.create(intWithNullSeq, ArrayType(IntegerType, true)) + val longWithNullArray = Literal.create(longWithNullSeq, ArrayType(LongType, true)) + + val nullArray = Literal.create(null, ArrayType(StringType, false)) + + checkEvaluation(MapFromArrays(intArray, longArray), createMap(intSeq, longSeq)) + checkEvaluation(MapFromArrays(intArray, strArray), createMap(intSeq, strSeq)) + checkEvaluation(MapFromArrays(integerArray, strArray), createMap(integerSeq, strSeq)) + + checkEvaluation( + MapFromArrays(strArray, intWithNullArray), createMap(strSeq, intWithNullSeq)) + checkEvaluation( + MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq)) + checkEvaluation( + MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq)) + checkEvaluation(MapFromArrays(nullArray, nullArray), null) + + intercept[RuntimeException] { + checkEvaluation(MapFromArrays(intWithNullArray, strArray), null) + } + intercept[RuntimeException] { + checkEvaluation( + MapFromArrays(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null) + } + } + test("CreateStruct") { val row = create_row(1, 2, 3) val c1 = 'a.int.at(0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 266a136fc2410..87bd7b3b0f9c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1070,6 +1070,17 @@ object functions { @scala.annotation.varargs def map(cols: Column*): Column = withExpr { CreateMap(cols.map(_.expr)) } + /** + * Creates a new map column. The array in the first column is used for keys. The array in the + * second column is used for values. All elements in the array for key should not be null. + * + * @group normal_funcs + * @since 2.4 + */ + def map_from_arrays(keys: Column, values: Column): Column = withExpr { + MapFromArrays(keys.expr, values.expr) + } + /** * Marks a DataFrame as small enough for use in broadcast joins. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 959a77a9ea345..4e5c1c56e2673 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -62,6 +62,36 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(row.getMap[Int, String](0) === Map(2 -> "a")) } + test("map with arrays") { + val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("k", "v") + val expectedType = MapType(IntegerType, StringType, valueContainsNull = true) + val row = df1.select(map_from_arrays($"k", $"v")).first() + assert(row.schema(0).dataType === expectedType) + assert(row.getMap[Int, String](0) === Map(1 -> "a", 2 -> "b")) + checkAnswer(df1.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> "a", 2 -> "b")))) + + val df2 = Seq((Seq(1, 2), Seq(null, "b"))).toDF("k", "v") + checkAnswer(df2.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> null, 2 -> "b")))) + + val df3 = Seq((null, null)).toDF("k", "v") + checkAnswer(df3.select(map_from_arrays($"k", $"v")), Seq(Row(null))) + + val df4 = Seq((1, "a")).toDF("k", "v") + intercept[AnalysisException] { + df4.select(map_from_arrays($"k", $"v")) + } + + val df5 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v") + intercept[RuntimeException] { + df5.select(map_from_arrays($"k", $"v")).collect + } + + val df6 = Seq((Seq(1, 2), Seq("a"))).toDF("k", "v") + intercept[RuntimeException] { + df6.select(map_from_arrays($"k", $"v")).collect + } + } + test("struct with column name") { val df = Seq((1, "str")).toDF("a", "b") val row = df.select(struct("a", "b")).first() From 0d3714d221460a2a1141134c3d451f18c4e0d46f Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 12 Jun 2018 15:57:43 -0700 Subject: [PATCH 0952/2461] [SPARK-23010][BUILD][FOLLOWUP] Fix java checkstyle failure of kubernetes-integration-tests ## What changes were proposed in this pull request? Fix java checkstyle failure of kubernetes-integration-tests ## How was this patch tested? Checked manually on my local environment. Author: Xingbo Jiang Closes #21545 from jiangxb1987/k8s-checkstyle. --- project/SparkBuild.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index adc2b6b5d273d..b606f9355e03b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -57,11 +57,11 @@ object BuildCommons { val optionallyEnabledProjects@Seq(kubernetes, mesos, yarn, streamingFlumeSink, streamingFlume, streamingKafka, sparkGangliaLgpl, streamingKinesisAsl, - dockerIntegrationTests, hadoopCloud) = + dockerIntegrationTests, hadoopCloud, kubernetesIntegrationTests) = Seq("kubernetes", "mesos", "yarn", "streaming-flume-sink", "streaming-flume", "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", - "docker-integration-tests", "hadoop-cloud").map(ProjectRef(buildLocation, _)) + "docker-integration-tests", "hadoop-cloud", "kubernetes-integration-tests").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKafka010Assembly, streamingKinesisAslAssembly) = Seq("network-yarn", "streaming-flume-assembly", "streaming-kafka-0-8-assembly", "streaming-kafka-0-10-assembly", "streaming-kinesis-asl-assembly") From f53818d35bdef5d20a2718b14a2fed4c468545c6 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 12 Jun 2018 16:42:44 -0700 Subject: [PATCH 0953/2461] [SPARK-24506][UI] Add UI filters to tabs added after binding ## What changes were proposed in this pull request? Currently, `spark.ui.filters` are not applied to the handlers added after binding the server. This means that every page which is added after starting the UI will not have the filters configured on it. This can allow unauthorized access to the pages. The PR adds the filters also to the handlers added after the UI starts. ## How was this patch tested? manual tests (without the patch, starting the thriftserver with `--conf spark.ui.filters=org.apache.hadoop.security.authentication.server.AuthenticationFilter --conf spark.org.apache.hadoop.security.authentication.server.AuthenticationFilter.params="type=simple"` you can access `http://localhost:4040/sqlserver`; with the patch, 401 is the response as for the other pages). Author: Marco Gaido Closes #21523 from mgaido91/SPARK-24506. --- .../org/apache/spark/deploy/history/HistoryServer.scala | 1 - core/src/main/scala/org/apache/spark/ui/JettyUtils.scala | 8 +++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index a9a4d5a4ec6a2..066275e8f8425 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -152,7 +152,6 @@ class HistoryServer( assert(serverInfo.isDefined, "HistoryServer must be bound before attaching SparkUIs") handlers.synchronized { ui.getHandlers.foreach(attachHandler) - addFilters(ui.getHandlers, conf) } } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index d6a025a6f12da..52a955111231a 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -263,7 +263,7 @@ private[spark] object JettyUtils extends Logging { filters.foreach { case filter : String => if (!filter.isEmpty) { - logInfo("Adding filter: " + filter) + logInfo(s"Adding filter $filter to ${handlers.map(_.getContextPath).mkString(", ")}.") val holder : FilterHolder = new FilterHolder() holder.setClassName(filter) // Get any parameters for each filter @@ -407,7 +407,7 @@ private[spark] object JettyUtils extends Logging { } pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads)) - ServerInfo(server, httpPort, securePort, collection) + ServerInfo(server, httpPort, securePort, conf, collection) } catch { case e: Exception => server.stop() @@ -507,10 +507,12 @@ private[spark] case class ServerInfo( server: Server, boundPort: Int, securePort: Option[Int], + conf: SparkConf, private val rootHandler: ContextHandlerCollection) { - def addHandler(handler: ContextHandler): Unit = { + def addHandler(handler: ServletContextHandler): Unit = { handler.setVirtualHosts(JettyUtils.toVirtualHosts(JettyUtils.SPARK_CONNECTOR_NAME)) + JettyUtils.addFilters(Seq(handler), conf) rootHandler.addHandler(handler) if (!handler.isStarted()) { handler.start() From 9786ce66c52d41b1d58ddedb3a984f561fd09ff3 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 13 Jun 2018 09:10:52 +0800 Subject: [PATCH 0954/2461] [SPARK-22239][SQL][PYTHON] Enable grouped aggregate pandas UDFs as window functions with unbounded window frames ## What changes were proposed in this pull request? This PR enables using a grouped aggregate pandas UDFs as window functions. The semantics is the same as using SQL aggregation function as window functions. ``` >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> from pyspark.sql import Window >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) >>> pandas_udf("double", PandasUDFType.GROUPED_AGG) ... def mean_udf(v): ... return v.mean() >>> w = Window.partitionBy('id') >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() +---+----+------+ | id| v|mean_v| +---+----+------+ | 1| 1.0| 1.5| | 1| 2.0| 1.5| | 2| 3.0| 6.0| | 2| 5.0| 6.0| | 2|10.0| 6.0| +---+----+------+ ``` The scope of this PR is somewhat limited in terms of: (1) Only supports unbounded window, which acts essentially as group by. (2) Only supports aggregation functions, not "transform" like window functions (n -> n mapping) Both of these are left as future work. Especially, (1) needs careful thinking w.r.t. how to pass rolling window data to python efficiently. (2) is a bit easier but does require more changes therefore I think it's better to leave it as a separate PR. ## How was this patch tested? WindowPandasUDFTests Author: Li Jin Closes #21082 from icexelloss/SPARK-22239-window-udf. --- .../spark/api/python/PythonRunner.scala | 2 + python/pyspark/rdd.py | 1 + python/pyspark/sql/functions.py | 34 ++- python/pyspark/sql/tests.py | 238 ++++++++++++++++++ python/pyspark/worker.py | 20 +- .../sql/catalyst/analysis/Analyzer.scala | 11 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 12 +- .../sql/catalyst/expressions/PythonUDF.scala | 6 +- .../expressions/windowExpressions.scala | 33 ++- .../sql/catalyst/optimizer/Optimizer.scala | 7 +- .../sql/catalyst/planning/patterns.scala | 42 +++- .../spark/sql/execution/SparkPlanner.scala | 1 + .../spark/sql/execution/SparkStrategies.scala | 20 +- .../execution/python/ExtractPythonUDFs.scala | 2 +- .../execution/python/WindowInPandasExec.scala | 173 +++++++++++++ 15 files changed, 580 insertions(+), 22 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 41eac10d9b267..ebabedf950e39 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -40,6 +40,7 @@ private[spark] object PythonEvalType { val SQL_SCALAR_PANDAS_UDF = 200 val SQL_GROUPED_MAP_PANDAS_UDF = 201 val SQL_GROUPED_AGG_PANDAS_UDF = 202 + val SQL_WINDOW_AGG_PANDAS_UDF = 203 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" @@ -47,6 +48,7 @@ private[spark] object PythonEvalType { case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF" case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF" case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF" + case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF" } } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 14d9128502ab0..7e7e5822a6b20 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -74,6 +74,7 @@ class PythonEvalType(object): SQL_SCALAR_PANDAS_UDF = 200 SQL_GROUPED_MAP_PANDAS_UDF = 201 SQL_GROUPED_AGG_PANDAS_UDF = 202 + SQL_WINDOW_AGG_PANDAS_UDF = 203 def portable_hash(x): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1cdbb8a4c3e8b..a5e3384e802b8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2616,10 +2616,12 @@ def pandas_udf(f=None, returnType=None, functionType=None): The returned scalar can be either a python primitive type, e.g., `int` or `float` or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. - :class:`ArrayType`, :class:`MapType` and :class:`StructType` are currently not supported as - output types. + :class:`MapType` and :class:`StructType` are currently not supported as output types. - Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` + Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` and + :class:`pyspark.sql.Window` + + This example shows using grouped aggregated UDFs with groupby: >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( @@ -2636,7 +2638,31 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 2| 6.0| +---+-----------+ - .. seealso:: :meth:`pyspark.sql.GroupedData.agg` + This example shows using grouped aggregated UDFs as window functions. Note that only + unbounded window frame is supported at the moment: + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> from pyspark.sql import Window + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP + ... def mean_udf(v): + ... return v.mean() + >>> w = Window.partitionBy('id') \\ + ... .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP + +---+----+------+ + | id| v|mean_v| + +---+----+------+ + | 1| 1.0| 1.5| + | 1| 2.0| 1.5| + | 2| 3.0| 6.0| + | 2| 5.0| 6.0| + | 2|10.0| 6.0| + +---+----+------+ + + .. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window` .. note:: The user-defined functions are considered deterministic by default. Due to optimization, duplicate invocations may be eliminated or the function may even be invoked diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4a3941de8a6a6..2d7a4f62d4ee8 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -5454,6 +5454,15 @@ def test_retain_group_columns(self): expected1 = df.groupby(df.id).agg(sum(df.v)) self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + def test_array_type(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + df = self.data + + array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array', PandasUDFType.GROUPED_AGG) + result1 = df.groupby('id').agg(array_udf(df['v']).alias('v2')) + self.assertEquals(result1.first()['v2'], [1.0, 2.0]) + def test_invalid_args(self): from pyspark.sql.functions import mean @@ -5479,6 +5488,235 @@ def test_invalid_args(self): 'mixture.*aggregate function.*group aggregate pandas UDF'): df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() + +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) +class WindowPandasUDFTests(ReusedSQLTestCase): + @property + def data(self): + from pyspark.sql.functions import array, explode, col, lit + return self.spark.range(10).toDF('id') \ + .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ + .withColumn("v", explode(col('vs'))) \ + .drop('vs') \ + .withColumn('w', lit(1.0)) + + @property + def python_plus_one(self): + from pyspark.sql.functions import udf + return udf(lambda v: v + 1, 'double') + + @property + def pandas_scalar_time_two(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + return pandas_udf(lambda v: v * 2, 'double') + + @property + def pandas_agg_mean_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def avg(v): + return v.mean() + return avg + + @property + def pandas_agg_max_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def max(v): + return v.max() + return max + + @property + def pandas_agg_min_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def min(v): + return v.min() + return min + + @property + def unbounded_window(self): + return Window.partitionBy('id') \ + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + + @property + def ordered_window(self): + return Window.partitionBy('id').orderBy('v') + + @property + def unpartitioned_window(self): + return Window.partitionBy() + + def test_simple(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType, percent_rank, mean, max + + df = self.data + w = self.unbounded_window + + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w)) + expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) + + result2 = df.select(mean_udf(df['v']).over(w)) + expected2 = df.select(mean(df['v']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_multiple_udfs(self): + from pyspark.sql.functions import max, min, mean + + df = self.data + w = self.unbounded_window + + result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \ + .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \ + .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w)) + + expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) \ + .withColumn('max_v', max(df['v']).over(w)) \ + .withColumn('min_w', min(df['w']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_replace_existing(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unbounded_window + + result1 = df.withColumn('v', self.pandas_agg_mean_udf(df['v']).over(w)) + expected1 = df.withColumn('v', mean(df['v']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_mixed_sql(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unbounded_window + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('v', mean_udf(df['v'] * 2).over(w) + 1) + expected1 = df.withColumn('v', mean(df['v'] * 2).over(w) + 1) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_mixed_udf(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unbounded_window + + plus_one = self.python_plus_one + time_two = self.pandas_scalar_time_two + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn( + 'v2', + plus_one(mean_udf(plus_one(df['v'])).over(w))) + expected1 = df.withColumn( + 'v2', + plus_one(mean(plus_one(df['v'])).over(w))) + + result2 = df.withColumn( + 'v2', + time_two(mean_udf(time_two(df['v'])).over(w))) + expected2 = df.withColumn( + 'v2', + time_two(mean(time_two(df['v'])).over(w))) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_without_partitionBy(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unpartitioned_window + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('v2', mean_udf(df['v']).over(w)) + expected1 = df.withColumn('v2', mean(df['v']).over(w)) + + result2 = df.select(mean_udf(df['v']).over(w)) + expected2 = df.select(mean(df['v']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_mixed_sql_and_udf(self): + from pyspark.sql.functions import max, min, rank, col + + df = self.data + w = self.unbounded_window + ow = self.ordered_window + max_udf = self.pandas_agg_max_udf + min_udf = self.pandas_agg_min_udf + + result1 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min_udf(df['v']).over(w)) + expected1 = df.withColumn('v_diff', max(df['v']).over(w) - min(df['v']).over(w)) + + # Test mixing sql window function and window udf in the same expression + result2 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min(df['v']).over(w)) + expected2 = expected1 + + # Test chaining sql aggregate function and udf + result3 = df.withColumn('max_v', max_udf(df['v']).over(w)) \ + .withColumn('min_v', min(df['v']).over(w)) \ + .withColumn('v_diff', col('max_v') - col('min_v')) \ + .drop('max_v', 'min_v') + expected3 = expected1 + + # Test mixing sql window function and udf + result4 = df.withColumn('max_v', max_udf(df['v']).over(w)) \ + .withColumn('rank', rank().over(ow)) + expected4 = df.withColumn('max_v', max(df['v']).over(w)) \ + .withColumn('rank', rank().over(ow)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + + def test_array_type(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + df = self.data + w = self.unbounded_window + + array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array', PandasUDFType.GROUPED_AGG) + result1 = df.withColumn('v2', array_udf(df['v']).over(w)) + self.assertEquals(result1.first()['v2'], [1.0, 2.0]) + + def test_invalid_args(self): + from pyspark.sql.functions import mean, pandas_udf, PandasUDFType + + df = self.data + w = self.unbounded_window + ow = self.ordered_window + mean_udf = self.pandas_agg_mean_udf + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + '.*not supported within a window function'): + foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP) + df.withColumn('v2', foo_udf(df['v']).over(w)) + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + '.*Only unbounded window frame is supported.*'): + df.withColumn('mean_v', mean_udf(df['v']).over(ow)) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index a30d6bf523a50..38fe2ef06eac5 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -128,6 +128,21 @@ def wrapped(*series): return lambda *a: (wrapped(*a), arrow_return_type) +def wrap_window_agg_pandas_udf(f, return_type): + # This is similar to grouped_agg_pandas_udf, the only difference + # is that window_agg_pandas_udf needs to repeat the return value + # to match window length, where grouped_agg_pandas_udf just returns + # the scalar value. + arrow_return_type = to_arrow_type(return_type) + + def wrapped(*series): + import pandas as pd + result = f(*series) + return pd.Series([result]).repeat(len(series[0])) + + return lambda *a: (wrapped(*a), arrow_return_type) + + def read_single_udf(pickleSer, infile, eval_type): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] @@ -151,6 +166,8 @@ def read_single_udf(pickleSer, infile, eval_type): return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) + elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF: + return arg_offsets, wrap_window_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_BATCHED_UDF: return arg_offsets, wrap_udf(func, return_type) else: @@ -195,7 +212,8 @@ def read_udfs(pickleSer, infile, eval_type): if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF): timezone = utf8_deserializer.loads(infile) ser = ArrowStreamPandasSerializer(timezone) else: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f9947d1fa6c78..6e3107f1c6f75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1739,9 +1739,10 @@ class Analyzer( * 1. For a list of [[Expression]]s (a projectList or an aggregateExpressions), partitions * it two lists of [[Expression]]s, one for all [[WindowExpression]]s and another for * all regular expressions. - * 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s. - * 3. For every distinct [[WindowSpecDefinition]], creates a [[Window]] operator and inserts - * it into the plan tree. + * 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s + * and [[WindowFunctionType]]s. + * 3. For every distinct [[WindowSpecDefinition]] and [[WindowFunctionType]], creates a + * [[Window]] operator and inserts it into the plan tree. */ object ExtractWindowExpressions extends Rule[LogicalPlan] { private def hasWindowFunction(exprs: Seq[Expression]): Boolean = @@ -1901,7 +1902,7 @@ class Analyzer( s"Please file a bug report with this error message, stack trace, and the query.") } else { val spec = distinctWindowSpec.head - (spec.partitionSpec, spec.orderSpec) + (spec.partitionSpec, spec.orderSpec, WindowFunctionType.functionType(expr)) } }.toSeq @@ -1909,7 +1910,7 @@ class Analyzer( // setting this to the child of the next Window operator. val windowOps = groupedWindowExpressions.foldLeft(child) { - case (last, ((partitionSpec, orderSpec), windowExpressions)) => + case (last, ((partitionSpec, orderSpec, _), windowExpressions)) => Window(windowExpressions, partitionSpec, orderSpec, last) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 90bda2a72ad82..af256b98b34f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ @@ -112,12 +113,19 @@ trait CheckAnalysis extends PredicateHelper { failAnalysis("An offset window function can only be evaluated in an ordered " + s"row-based window frame with a single offset: $w") + case _ @ WindowExpression(_: PythonUDF, + WindowSpecDefinition(_, _, frame: SpecifiedWindowFrame)) + if !frame.isUnbounded => + failAnalysis("Only unbounded window frame is supported with Pandas UDFs.") + case w @ WindowExpression(e, s) => // Only allow window functions with an aggregate expression or an offset window - // function. + // function or a Pandas window UDF. e match { case _: AggregateExpression | _: OffsetWindowFunction | _: AggregateWindowFunction => w + case f: PythonUDF if PythonUDF.isWindowPandasUDF(f) => + w case _ => failAnalysis(s"Expression '$e' not supported within a window function.") } @@ -154,7 +162,7 @@ trait CheckAnalysis extends PredicateHelper { case Aggregate(groupingExprs, aggregateExprs, child) => def isAggregateExpression(expr: Expression) = { - expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupAggPandasUDF(expr) + expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr) } def checkValidAggregateExpression(expr: Expression): Unit = expr match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index efd664dde725a..6530b176968f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -34,10 +34,14 @@ object PythonUDF { e.isInstanceOf[PythonUDF] && SCALAR_TYPES.contains(e.asInstanceOf[PythonUDF].evalType) } - def isGroupAggPandasUDF(e: Expression): Boolean = { + def isGroupedAggPandasUDF(e: Expression): Boolean = { e.isInstanceOf[PythonUDF] && e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF } + + // This is currently same as GroupedAggPandasUDF, but we might support new types in the future, + // e.g, N -> N transform. + def isWindowPandasUDF(e: Expression): Boolean = isGroupedAggPandasUDF(e) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 9fe2fb2b95e4d..f957aaa96e98c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -21,7 +21,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.aggregate.{DeclarativeAggregate, NoOp} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, NoOp} import org.apache.spark.sql.types._ /** @@ -297,6 +297,37 @@ trait WindowFunction extends Expression { def frame: WindowFrame = UnspecifiedFrame } +/** + * Case objects that describe whether a window function is a SQL window function or a Python + * user-defined window function. + */ +sealed trait WindowFunctionType + +object WindowFunctionType { + case object SQL extends WindowFunctionType + case object Python extends WindowFunctionType + + def functionType(windowExpression: NamedExpression): WindowFunctionType = { + val t = windowExpression.collectFirst { + case _: WindowFunction | _: AggregateFunction => SQL + case udf: PythonUDF if PythonUDF.isWindowPandasUDF(udf) => Python + } + + // Normally a window expression would either have a SQL window function, a SQL + // aggregate function or a python window UDF. However, sometimes the optimizer will replace + // the window function if the value of the window function can be predetermined. + // For example, for query: + // + // select count(NULL) over () from values 1.0, 2.0, 3.0 T(a) + // + // The window function will be replaced by expression literal(0) + // To handle this case, if a window expression doesn't have a regular window function, we + // consider its type to be SQL as literal(0) is also a SQL expression. + t.getOrElse(SQL) + } +} + + /** * An offset window function is a window function that returns the value of the input column offset * by a number of rows within the partition. For instance: an OffsetWindowfunction for value x with diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index bfa61116a6658..aa992def1ce6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -621,12 +621,15 @@ object CollapseRepartition extends Rule[LogicalPlan] { /** * Collapse Adjacent Window Expression. * - If the partition specs and order specs are the same and the window expression are - * independent, collapse into the parent. + * independent and are of the same window function type, collapse into the parent. */ object CollapseWindow extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild)) - if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty => + if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty && + // This assumes Window contains the same type of window expressions. This is ensured + // by ExtractWindowFunctions. + WindowFunctionType.functionType(we1.head) == WindowFunctionType.functionType(we2.head) => w1.copy(windowExpressions = we2 ++ we1, child = grandChild) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 626f905707191..84be677e438a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.planning import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ @@ -215,7 +216,7 @@ object PhysicalAggregation { case agg: AggregateExpression if !equivalentAggregateExpressions.addExpr(agg) => agg case udf: PythonUDF - if PythonUDF.isGroupAggPandasUDF(udf) && + if PythonUDF.isGroupedAggPandasUDF(udf) && !equivalentAggregateExpressions.addExpr(udf) => udf } } @@ -245,7 +246,7 @@ object PhysicalAggregation { equivalentAggregateExpressions.getEquivalentExprs(ae).headOption .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute // Similar to AggregateExpression - case ue: PythonUDF if PythonUDF.isGroupAggPandasUDF(ue) => + case ue: PythonUDF if PythonUDF.isGroupedAggPandasUDF(ue) => equivalentAggregateExpressions.getEquivalentExprs(ue).headOption .getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute case expression => @@ -268,3 +269,40 @@ object PhysicalAggregation { case _ => None } } + +/** + * An extractor used when planning physical execution of a window. This extractor outputs + * the window function type of the logical window. + * + * The input logical window must contain same type of window functions, which is ensured by + * the rule ExtractWindowExpressions in the analyzer. + */ +object PhysicalWindow { + // windowFunctionType, windowExpression, partitionSpec, orderSpec, child + private type ReturnType = + (WindowFunctionType, Seq[NamedExpression], Seq[Expression], Seq[SortOrder], LogicalPlan) + + def unapply(a: Any): Option[ReturnType] = a match { + case expr @ logical.Window(windowExpressions, partitionSpec, orderSpec, child) => + + // The window expression should not be empty here, otherwise it's a bug. + if (windowExpressions.isEmpty) { + throw new AnalysisException(s"Window expression is empty in $expr") + } + + val windowFunctionType = windowExpressions.map(WindowFunctionType.functionType) + .reduceLeft { (t1: WindowFunctionType, t2: WindowFunctionType) => + if (t1 != t2) { + // We shouldn't have different window function type here, otherwise it's a bug. + throw new AnalysisException( + s"Found different window function type in $windowExpressions") + } else { + t1 + } + } + + Some((windowFunctionType, windowExpressions, partitionSpec, orderSpec, child)) + + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 74048871f8d42..75f5ec0e253df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -41,6 +41,7 @@ class SparkPlanner( DataSourceStrategy(conf) :: SpecialLimits :: Aggregation :: + Window :: JoinSelection :: InMemoryScans :: BasicOperators :: Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index be34387f6a874..d6951ad01fb0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -327,7 +327,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case PhysicalAggregation( namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => - if (aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF)) { + if (aggregateExpressions.exists(PythonUDF.isGroupedAggPandasUDF)) { throw new AnalysisException( "Streaming aggregation doesn't support group aggregate pandas UDF") } @@ -428,6 +428,22 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + object Window extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalWindow( + WindowFunctionType.SQL, windowExprs, partitionSpec, orderSpec, child) => + execution.window.WindowExec( + windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil + + case PhysicalWindow( + WindowFunctionType.Python, windowExprs, partitionSpec, orderSpec, child) => + execution.python.WindowInPandasExec( + windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil + + case _ => Nil + } + } + protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1) object InMemoryScans extends Strategy { @@ -548,8 +564,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil case e @ logical.Expand(_, _, child) => execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil - case logical.Window(windowExprs, partitionSpec, orderSpec, child) => - execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil case logical.Sample(lb, ub, withReplacement, seed, child) => execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data, _) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 9d56f48249982..1e096100f7f43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -39,7 +39,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { */ private def belongAggregate(e: Expression, agg: Aggregate): Boolean = { e.isInstanceOf[AggregateExpression] || - PythonUDF.isGroupAggPandasUDF(e) || + PythonUDF.isGroupedAggPandasUDF(e) || agg.groupingExpressions.exists(_.semanticEquals(e)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala new file mode 100644 index 0000000000000..c76832a1a3829 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.File + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.util.Utils + +case class WindowInPandasExec( + windowExpression: Seq[NamedExpression], + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + child: SparkPlan) extends UnaryExecNode { + + override def output: Seq[Attribute] = + child.output ++ windowExpression.map(_.toAttribute) + + override def requiredChildDistribution: Seq[Distribution] = { + if (partitionSpec.isEmpty) { + // Only show warning when the number of bytes is larger than 100 MB? + logWarning("No Partition Defined for Window operation! Moving all data to a single " + + "partition, this can cause serious performance degradation.") + AllTuples :: Nil + } else { + ClusteredDistribution(partitionSpec) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = child.outputPartitioning + + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { + udf.children match { + case Seq(u: PythonUDF) => + val (chained, children) = collectFunctions(u) + (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) + case children => + // There should not be any other UDFs, or the children can't be evaluated directly. + assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) + (ChainedPythonFunctions(Seq(udf.func)), udf.children) + } + } + + /** + * Create the resulting projection. + * + * This method uses Code Generation. It can only be used on the executor side. + * + * @param expressions unbound ordered function expressions. + * @return the final resulting projection. + */ + private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = { + val references = expressions.zipWithIndex.map { case (e, i) => + // Results of window expressions will be on the right side of child's output + BoundReference(child.output.size + i, e.dataType, e.nullable) + } + val unboundToRefMap = expressions.zip(references).toMap + val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) + UnsafeProjection.create( + child.output ++ patchedWindowExpression, + child.output) + } + + protected override def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute() + + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + val sessionLocalTimeZone = conf.sessionLocalTimeZone + val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + + // Extract window expressions and window functions + val expressions = windowExpression.flatMap(_.collect { case e: WindowExpression => e }) + + val udfExpressions = expressions.map(_.windowFunction.asInstanceOf[PythonUDF]) + + val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip + + // Filter child output attributes down to only those that are UDF inputs. + // Also eliminate duplicate UDF inputs. + val allInputs = new ArrayBuffer[Expression] + val dataTypes = new ArrayBuffer[DataType] + val argOffsets = inputs.map { input => + input.map { e => + if (allInputs.exists(_.semanticEquals(e))) { + allInputs.indexWhere(_.semanticEquals(e)) + } else { + allInputs += e + dataTypes += e.dataType + allInputs.length - 1 + } + }.toArray + }.toArray + + // Schema of input rows to the python runner + val windowInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => + StructField(s"_$i", dt) + }) + + inputRDD.mapPartitionsInternal { iter => + val context = TaskContext.get() + + val grouped = if (partitionSpec.isEmpty) { + // Use an empty unsafe row as a place holder for the grouping key + Iterator((new UnsafeRow(), iter)) + } else { + GroupedIterator(iter, partitionSpec, child.output) + } + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = HybridRowQueue(context.taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) + context.addTaskCompletionListener { _ => + queue.close() + } + + val inputProj = UnsafeProjection.create(allInputs, child.output) + val pythonInput = grouped.map { case (_, rows) => + rows.map { row => + queue.add(row.asInstanceOf[UnsafeRow]) + inputProj(row) + } + } + + val windowFunctionResult = new ArrowPythonRunner( + pyFuncs, bufferSize, reuseWorker, + PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, + argOffsets, windowInputSchema, + sessionLocalTimeZone, pandasRespectSessionTimeZone) + .compute(pythonInput, context.partitionId(), context) + + val joined = new JoinedRow + val resultProj = createResultProjection(expressions) + + windowFunctionResult.flatMap(_.rowIterator.asScala).map { windowOutput => + val leftRow = queue.remove() + val joinedRow = joined(leftRow, windowOutput) + resultProj(joinedRow) + } + } + } +} From 3352d6fe9a1efb6dee18e40bdf584930b10d1d3e Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 13 Jun 2018 12:34:46 +0800 Subject: [PATCH 0955/2461] [SPARK-24466][SS] Fix TextSocketMicroBatchReader to be compatible with netcat again ## What changes were proposed in this pull request? TextSocketMicroBatchReader was no longer be compatible with netcat due to launching temporary reader for reading schema, and closing reader, and re-opening reader. While reliable socket server should be able to handle this without any issue, nc command normally can't handle multiple connections and simply exits when closing temporary reader. This patch fixes TextSocketMicroBatchReader to be compatible with netcat again, via deferring opening socket to the first call of planInputPartitions() instead of constructor. ## How was this patch tested? Added unit test which fails on current and succeeds with the patch. And also manually tested. Author: Jungtaek Lim Closes #21497 from HeartSaVioR/SPARK-24466. --- .../execution/streaming/sources/socket.scala | 7 +- .../sources/TextSocketStreamSuite.scala | 119 ++++++++++++------ 2 files changed, 85 insertions(+), 41 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 8240e06d4ab72..91e3b7179c34a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -22,6 +22,7 @@ import java.net.Socket import java.sql.Timestamp import java.text.SimpleDateFormat import java.util.{Calendar, List => JList, Locale, Optional} +import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ @@ -76,7 +77,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR @GuardedBy("this") private var lastOffsetCommitted: LongOffset = LongOffset(-1L) - initialize() + private val initialized: AtomicBoolean = new AtomicBoolean(false) /** This method is only used for unit test */ private[sources] def getCurrentOffset(): LongOffset = synchronized { @@ -149,6 +150,10 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR // Internal buffer only holds the batches after lastOffsetCommitted val rawList = synchronized { + if (initialized.compareAndSet(false, true)) { + initialize() + } + val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 batches.slice(sliceStart, sliceEnd) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index a15a980bb92fd..52e8386f6b1fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution.streaming.sources -import java.io.IOException -import java.net.InetSocketAddress +import java.net.{InetSocketAddress, SocketException} import java.nio.ByteBuffer import java.nio.channels.ServerSocketChannel import java.sql.Timestamp @@ -33,9 +32,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} -import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} @@ -101,7 +101,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") { val ref = spark import ref.implicits._ @@ -130,7 +130,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") { val socket = spark .readStream .format("socket") @@ -216,20 +216,11 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before "socket source does not support a user-specified schema")) } - test("no server up") { - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> "0") - intercept[IOException] { - batchReader = provider.createMicroBatchReader( - Optional.empty(), "", new DataSourceOptions(parameters.asJava)) - } - } - test("input row metrics") { serverThread = new ServerThread() serverThread.start() - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") { val ref = spark import ref.implicits._ @@ -256,6 +247,66 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before } } + test("verify ServerThread only accepts the first connection") { + serverThread = new ServerThread() + serverThread.start() + + withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") { + val ref = spark + import ref.implicits._ + + val socket = spark + .readStream + .format("socket") + .options(Map("host" -> "localhost", "port" -> serverThread.port.toString)) + .load() + .as[String] + + assert(socket.schema === StructType(StructField("value", StringType) :: Nil)) + + testStream(socket)( + StartStream(), + AddSocketData("hello"), + CheckAnswer("hello"), + AddSocketData("world"), + CheckLastBatch("world"), + CheckAnswer("hello", "world"), + StopStream + ) + + // we are trying to connect to the server once again which should fail + try { + val socket2 = spark + .readStream + .format("socket") + .options(Map("host" -> "localhost", "port" -> serverThread.port.toString)) + .load() + .as[String] + + testStream(socket2)( + StartStream(), + AddSocketData("hello"), + CheckAnswer("hello"), + AddSocketData("world"), + CheckLastBatch("world"), + CheckAnswer("hello", "world"), + StopStream + ) + + fail("StreamingQueryException is expected!") + } catch { + case e: StreamingQueryException if e.cause.isInstanceOf[SocketException] => // pass + } + } + } + + /** + * This class tries to mimic the behavior of netcat, so that we can ensure + * TextSocketStream supports netcat, which only accepts the first connection + * and exits the process when the first connection is closed. + * + * Please refer SPARK-24466 for more details. + */ private class ServerThread extends Thread with Logging { private val serverSocketChannel = ServerSocketChannel.open() serverSocketChannel.bind(new InetSocketAddress(0)) @@ -265,36 +316,24 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before override def run(): Unit = { try { + val clientSocketChannel = serverSocketChannel.accept() + + // Close server socket channel immediately to mimic the behavior that + // only first connection will be made and deny any further connections + // Note that the first client socket channel will be available + serverSocketChannel.close() + + clientSocketChannel.configureBlocking(false) + clientSocketChannel.socket().setTcpNoDelay(true) + while (true) { - val clientSocketChannel = serverSocketChannel.accept() - clientSocketChannel.configureBlocking(false) - clientSocketChannel.socket().setTcpNoDelay(true) - - // Check whether remote client is closed but still send data to this closed socket. - // This happens in DataStreamReader where a source will be created to get the schema. - var remoteIsClosed = false - var cnt = 0 - while (cnt < 3 && !remoteIsClosed) { - if (clientSocketChannel.read(ByteBuffer.allocate(1)) != -1) { - cnt += 1 - Thread.sleep(100) - } else { - remoteIsClosed = true - } - } - - if (remoteIsClosed) { - logInfo(s"remote client ${clientSocketChannel.socket()} is closed") - } else { - while (true) { - val line = messageQueue.take() + "\n" - clientSocketChannel.write(ByteBuffer.wrap(line.getBytes("UTF-8"))) - } - } + val line = messageQueue.take() + "\n" + clientSocketChannel.write(ByteBuffer.wrap(line.getBytes("UTF-8"))) } } catch { case e: InterruptedException => } finally { + // no harm to call close() again... serverSocketChannel.close() } } From 4c388bccf1bcac8f833fd9214096dd164c3ea065 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 13 Jun 2018 12:36:20 +0800 Subject: [PATCH 0956/2461] [SPARK-24485][SS] Measure and log elapsed time for filesystem operations in HDFSBackedStateStoreProvider ## What changes were proposed in this pull request? This patch measures and logs elapsed time for each operation which communicate with file system (mostly remote HDFS in production) in HDFSBackedStateStoreProvider to help investigating any latency issue. ## How was this patch tested? Manually tested. Author: Jungtaek Lim Closes #21506 from HeartSaVioR/SPARK-24485. --- .../scala/org/apache/spark/util/Utils.scala | 11 ++- .../state/HDFSBackedStateStoreProvider.scala | 83 +++++++++++-------- .../streaming/statefulOperators.scala | 9 +- 3 files changed, 62 insertions(+), 41 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 7428db2158538..c139db46b63a3 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -31,6 +31,7 @@ import java.nio.file.Files import java.security.SecureRandom import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ +import java.util.concurrent.TimeUnit.NANOSECONDS import java.util.concurrent.atomic.AtomicBoolean import java.util.zip.GZIPInputStream @@ -434,7 +435,7 @@ private[spark] object Utils extends Logging { new URI("file:///" + rawFileName).getPath.substring(1) } - /** + /** * Download a file or directory to target directory. Supports fetching the file in a variety of * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based * on the URL parameter. Fetching directories is only supported from Hadoop-compatible @@ -507,6 +508,14 @@ private[spark] object Utils extends Logging { targetFile } + /** Records the duration of running `body`. */ + def timeTakenMs[T](body: => T): (T, Long) = { + val startTime = System.nanoTime() + val result = body + val endTime = System.nanoTime() + (result, math.max(NANOSECONDS.toMillis(endTime - startTime), 0)) + } + /** * Download `in` to `tempFile`, then move it to `destFile`. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index df722b953228b..118c82aa75e68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -18,12 +18,10 @@ package org.apache.spark.sql.execution.streaming.state import java.io._ -import java.nio.channels.ClosedChannelException import java.util.Locale import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.util.Random import scala.util.control.NonFatal import com.google.common.io.ByteStreams @@ -280,38 +278,49 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit if (loadedCurrentVersionMap.isDefined) { return loadedCurrentVersionMap.get } - val snapshotCurrentVersionMap = readSnapshotFile(version) - if (snapshotCurrentVersionMap.isDefined) { - synchronized { loadedMaps.put(version, snapshotCurrentVersionMap.get) } - return snapshotCurrentVersionMap.get - } - // Find the most recent map before this version that we can. - // [SPARK-22305] This must be done iteratively to avoid stack overflow. - var lastAvailableVersion = version - var lastAvailableMap: Option[MapType] = None - while (lastAvailableMap.isEmpty) { - lastAvailableVersion -= 1 + logWarning(s"The state for version $version doesn't exist in loadedMaps. " + + "Reading snapshot file and delta files if needed..." + + "Note that this is normal for the first batch of starting query.") - if (lastAvailableVersion <= 0) { - // Use an empty map for versions 0 or less. - lastAvailableMap = Some(new MapType) - } else { - lastAvailableMap = - synchronized { loadedMaps.get(lastAvailableVersion) } - .orElse(readSnapshotFile(lastAvailableVersion)) + val (result, elapsedMs) = Utils.timeTakenMs { + val snapshotCurrentVersionMap = readSnapshotFile(version) + if (snapshotCurrentVersionMap.isDefined) { + synchronized { loadedMaps.put(version, snapshotCurrentVersionMap.get) } + return snapshotCurrentVersionMap.get + } + + // Find the most recent map before this version that we can. + // [SPARK-22305] This must be done iteratively to avoid stack overflow. + var lastAvailableVersion = version + var lastAvailableMap: Option[MapType] = None + while (lastAvailableMap.isEmpty) { + lastAvailableVersion -= 1 + + if (lastAvailableVersion <= 0) { + // Use an empty map for versions 0 or less. + lastAvailableMap = Some(new MapType) + } else { + lastAvailableMap = + synchronized { loadedMaps.get(lastAvailableVersion) } + .orElse(readSnapshotFile(lastAvailableVersion)) + } + } + + // Load all the deltas from the version after the last available one up to the target version. + // The last available version is the one with a full snapshot, so it doesn't need deltas. + val resultMap = new MapType(lastAvailableMap.get) + for (deltaVersion <- lastAvailableVersion + 1 to version) { + updateFromDeltaFile(deltaVersion, resultMap) } - } - // Load all the deltas from the version after the last available one up to the target version. - // The last available version is the one with a full snapshot, so it doesn't need deltas. - val resultMap = new MapType(lastAvailableMap.get) - for (deltaVersion <- lastAvailableVersion + 1 to version) { - updateFromDeltaFile(deltaVersion, resultMap) + synchronized { loadedMaps.put(version, resultMap) } + resultMap } - synchronized { loadedMaps.put(version, resultMap) } - resultMap + logDebug(s"Loading state for $version takes $elapsedMs ms.") + + result } private def writeUpdateToDeltaFile( @@ -490,7 +499,9 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** Perform a snapshot of the store to allow delta files to be consolidated */ private def doSnapshot(): Unit = { try { - val files = fetchFiles() + val (files, e1) = Utils.timeTakenMs(fetchFiles()) + logDebug(s"fetchFiles() took $e1 ms.") + if (files.nonEmpty) { val lastVersion = files.last.version val deltaFilesForLastVersion = @@ -498,7 +509,8 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit synchronized { loadedMaps.get(lastVersion) } match { case Some(map) => if (deltaFilesForLastVersion.size > storeConf.minDeltasForSnapshot) { - writeSnapshotFile(lastVersion, map) + val (_, e2) = Utils.timeTakenMs(writeSnapshotFile(lastVersion, map)) + logDebug(s"writeSnapshotFile() took $e2 ms.") } case None => // The last map is not loaded, probably some other instance is in charge @@ -517,7 +529,9 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit */ private[state] def cleanup(): Unit = { try { - val files = fetchFiles() + val (files, e1) = Utils.timeTakenMs(fetchFiles()) + logDebug(s"fetchFiles() took $e1 ms.") + if (files.nonEmpty) { val earliestVersionToRetain = files.last.version - storeConf.minVersionsToRetain if (earliestVersionToRetain > 0) { @@ -527,9 +541,12 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit mapsToRemove.foreach(loadedMaps.remove) } val filesToDelete = files.filter(_.version < earliestFileToRetain.version) - filesToDelete.foreach { f => - fm.delete(f.path) + val (_, e2) = Utils.timeTakenMs { + filesToDelete.foreach { f => + fm.delete(f.path) + } } + logDebug(s"deleting files took $e2 ms.") logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this: " + filesToDelete.mkString(", ")) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 1691a6320a526..6759fb42b4052 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress} import org.apache.spark.sql.types._ -import org.apache.spark.util.{CompletionIterator, NextIterator} +import org.apache.spark.util.{CompletionIterator, NextIterator, Utils} /** Used to identify the state store for a given operator. */ @@ -97,12 +97,7 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => } /** Records the duration of running `body` for the next query progress update. */ - protected def timeTakenMs(body: => Unit): Long = { - val startTime = System.nanoTime() - val result = body - val endTime = System.nanoTime() - math.max(NANOSECONDS.toMillis(endTime - startTime), 0) - } + protected def timeTakenMs(body: => Unit): Long = Utils.timeTakenMs(body)._2 /** * Set the SQL metrics related to the state store. From 7703b46d2843db99e28110c4c7ccf60934412504 Mon Sep 17 00:00:00 2001 From: Arun Mahadevan Date: Wed, 13 Jun 2018 20:43:16 +0800 Subject: [PATCH 0957/2461] [SPARK-24479][SS] Added config for registering streamingQueryListeners ## What changes were proposed in this pull request? Currently a "StreamingQueryListener" can only be registered programatically. We could have a new config "spark.sql.streamingQueryListeners" similar to "spark.sql.queryExecutionListeners" and "spark.extraListeners" for users to register custom streaming listeners. ## How was this patch tested? New unit test and running example programs. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Arun Mahadevan Closes #21504 from arunmahadevan/SPARK-24480. --- .../spark/sql/internal/StaticSQLConf.scala | 8 +++ .../sql/streaming/StreamingQueryManager.scala | 15 +++++ .../StreamingQueryListenersConfSuite.scala | 66 +++++++++++++++++++ 3 files changed, 89 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index fe0ad39c29025..382ef28f49a7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -96,6 +96,14 @@ object StaticSQLConf { .toSequence .createOptional + val STREAMING_QUERY_LISTENERS = buildStaticConf("spark.sql.streaming.streamingQueryListeners") + .doc("List of class names implementing StreamingQueryListener that will be automatically " + + "added to newly created sessions. The classes should have either a no-arg constructor, " + + "or a constructor that expects a SparkConf argument.") + .stringConf + .toSequence + .createOptional + val UI_RETAINED_EXECUTIONS = buildStaticConf("spark.sql.ui.retainedExecutions") .doc("Number of executions to retain in the Spark UI.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 97da2b1325f58..25bb05212d66f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path +import org.apache.spark.SparkException import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} @@ -32,6 +33,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger} import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS import org.apache.spark.sql.sources.v2.StreamWriteSupport import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -55,6 +57,19 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo @GuardedBy("awaitTerminationLock") private var lastTerminatedQuery: StreamingQuery = null + try { + sparkSession.sparkContext.conf.get(STREAMING_QUERY_LISTENERS).foreach { classNames => + Utils.loadExtensions(classOf[StreamingQueryListener], classNames, + sparkSession.sparkContext.conf).foreach(listener => { + addListener(listener) + logInfo(s"Registered listener ${listener.getClass.getName}") + }) + } + } catch { + case e: Exception => + throw new SparkException("Exception when registering StreamingQueryListener", e) + } + /** * Returns a list of active queries associated with this SQLContext * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala new file mode 100644 index 0000000000000..1aaf8a9aa2d55 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import scala.language.reflectiveCalls + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkConf +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.streaming.StreamingQueryListener._ + + +class StreamingQueryListenersConfSuite extends StreamTest with BeforeAndAfter { + + import testImplicits._ + + + override protected def sparkConf: SparkConf = + super.sparkConf.set("spark.sql.streaming.streamingQueryListeners", + "org.apache.spark.sql.streaming.TestListener") + + test("test if the configured query lister is loaded") { + testStream(MemoryStream[Int].toDS)( + StartStream(), + StopStream + ) + + assert(TestListener.queryStartedEvent != null) + assert(TestListener.queryTerminatedEvent != null) + } + +} + +object TestListener { + @volatile var queryStartedEvent: QueryStartedEvent = null + @volatile var queryTerminatedEvent: QueryTerminatedEvent = null +} + +class TestListener(sparkConf: SparkConf) extends StreamingQueryListener { + + override def onQueryStarted(event: QueryStartedEvent): Unit = { + TestListener.queryStartedEvent = event + } + + override def onQueryProgress(event: QueryProgressEvent): Unit = {} + + override def onQueryTerminated(event: QueryTerminatedEvent): Unit = { + TestListener.queryTerminatedEvent = event + } +} From 299d297e250ca3d46616a97e4256aa9ad6a135e5 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 13 Jun 2018 07:09:48 -0700 Subject: [PATCH 0958/2461] [SPARK-24500][SQL] Make sure streams are materialized during Tree transforms. ## What changes were proposed in this pull request? If you construct catalyst trees using `scala.collection.immutable.Stream` you can run into situations where valid transformations do not seem to have any effect. There are two causes for this behavior: - `Stream` is evaluated lazily. Note that default implementation will generally only evaluate a function for the first element (this makes testing a bit tricky). - `TreeNode` and `QueryPlan` use side effects to detect if a tree has changed. Mapping over a stream is lazy and does not need to trigger this side effect. If this happens the node will invalidly assume that it did not change and return itself instead if the newly created node (this is for GC reasons). This PR fixes this issue by forcing materialization on streams in `TreeNode` and `QueryPlan`. ## How was this patch tested? Unit tests were added to `TreeNodeSuite` and `LogicalPlanSuite`. An integration test was added to the `PlannerSuite` Author: Herman van Hovell Closes #21539 from hvanhovell/SPARK-24500. --- .../spark/sql/catalyst/plans/QueryPlan.scala | 1 + .../spark/sql/catalyst/trees/TreeNode.scala | 122 ++++++++---------- .../sql/catalyst/plans/LogicalPlanSuite.scala | 20 ++- .../sql/catalyst/trees/TreeNodeSuite.scala | 25 +++- .../spark/sql/execution/PlannerSuite.scala | 11 +- 5 files changed, 109 insertions(+), 70 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 64cb8c726772f..e431c9523a9da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -119,6 +119,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT case Some(value) => Some(recursiveTransform(value)) case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs + case stream: Stream[_] => stream.map(recursiveTransform).force case seq: Traversable[_] => seq.map(recursiveTransform) case other: AnyRef => other case null => null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 9c7d47f99ee10..becfa8d982213 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -199,44 +199,33 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { var changed = false val remainingNewChildren = newChildren.toBuffer val remainingOldChildren = children.toBuffer + def mapTreeNode(node: TreeNode[_]): TreeNode[_] = { + val newChild = remainingNewChildren.remove(0) + val oldChild = remainingOldChildren.remove(0) + if (newChild fastEquals oldChild) { + oldChild + } else { + changed = true + newChild + } + } + def mapChild(child: Any): Any = child match { + case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg) + case nonChild: AnyRef => nonChild + case null => null + } val newArgs = mapProductIterator { case s: StructType => s // Don't convert struct types to some other type of Seq[StructField] // Handle Seq[TreeNode] in TreeNode parameters. - case s: Seq[_] => s.map { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = remainingNewChildren.remove(0) - val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { - oldChild - } else { - changed = true - newChild - } - case nonChild: AnyRef => nonChild - case null => null - } - case m: Map[_, _] => m.mapValues { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = remainingNewChildren.remove(0) - val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { - oldChild - } else { - changed = true - newChild - } - case nonChild: AnyRef => nonChild - case null => null - }.view.force // `mapValues` is lazy and we need to force it to materialize - case arg: TreeNode[_] if containsChild(arg) => - val newChild = remainingNewChildren.remove(0) - val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { - oldChild - } else { - changed = true - newChild - } + case s: Stream[_] => + // Stream is lazy so we need to force materialization + s.map(mapChild).force + case s: Seq[_] => + s.map(mapChild) + case m: Map[_, _] => + // `mapValues` is lazy and we need to force it to materialize + m.mapValues(mapChild).view.force + case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg) case nonChild: AnyRef => nonChild case null => null } @@ -301,6 +290,37 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { def mapChildren(f: BaseType => BaseType): BaseType = { if (children.nonEmpty) { var changed = false + def mapChild(child: Any): Any = child match { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = f(arg.asInstanceOf[BaseType]) + if (!(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) => + val newChild1 = if (containsChild(arg1)) { + f(arg1.asInstanceOf[BaseType]) + } else { + arg1.asInstanceOf[BaseType] + } + + val newChild2 = if (containsChild(arg2)) { + f(arg2.asInstanceOf[BaseType]) + } else { + arg2.asInstanceOf[BaseType] + } + + if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { + changed = true + (newChild1, newChild2) + } else { + tuple + } + case other => other + } + val newArgs = mapProductIterator { case arg: TreeNode[_] if containsChild(arg) => val newChild = f(arg.asInstanceOf[BaseType]) @@ -330,36 +350,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case other => other }.view.force // `mapValues` is lazy and we need to force it to materialize case d: DataType => d // Avoid unpacking Structs - case args: Traversable[_] => args.map { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = f(arg.asInstanceOf[BaseType]) - if (!(newChild fastEquals arg)) { - changed = true - newChild - } else { - arg - } - case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) => - val newChild1 = if (containsChild(arg1)) { - f(arg1.asInstanceOf[BaseType]) - } else { - arg1.asInstanceOf[BaseType] - } - - val newChild2 = if (containsChild(arg2)) { - f(arg2.asInstanceOf[BaseType]) - } else { - arg2.asInstanceOf[BaseType] - } - - if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { - changed = true - (newChild1, newChild2) - } else { - tuple - } - case other => other - } + case args: Stream[_] => args.map(mapChild).force // Force materialization on stream + case args: Traversable[_] => args.map(mapChild) case nonChild: AnyRef => nonChild case null => null } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index 14041747fd20e..bf569cb869428 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Coalesce, Literal, NamedExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.IntegerType @@ -101,4 +101,22 @@ class LogicalPlanSuite extends SparkFunSuite { assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true) assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming) } + + test("transformExpressions works with a Stream") { + val id1 = NamedExpression.newExprId + val id2 = NamedExpression.newExprId + val plan = Project(Stream( + Alias(Literal(1), "a")(exprId = id1), + Alias(Literal(2), "b")(exprId = id2)), + OneRowRelation()) + val result = plan.transformExpressions { + case Literal(v: Int, IntegerType) if v != 1 => + Literal(v + 1, IntegerType) + } + val expected = Project(Stream( + Alias(Literal(1), "a")(exprId = id1), + Alias(Literal(3), "b")(exprId = id2)), + OneRowRelation()) + assert(result.sameResult(expected)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 84d0ba7bef642..b7092f4c42d4c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -29,14 +29,14 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource, JarResource} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.dsl.expressions.DslString import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Union} import org.apache.spark.sql.catalyst.plans.physical.{IdentityBroadcastMode, RoundRobinPartitioning, SinglePartition} -import org.apache.spark.sql.types.{BooleanType, DoubleType, FloatType, IntegerType, Metadata, NullType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFallback { @@ -574,4 +574,25 @@ class TreeNodeSuite extends SparkFunSuite { val right = JsonMethods.parse(rightJson) assert(left == right) } + + test("transform works on stream of children") { + val before = Coalesce(Stream(Literal(1), Literal(2))) + // Note it is a bit tricky to exhibit the broken behavior. Basically we want to create the + // situation in which the TreeNode.mapChildren function's change detection is not triggered. A + // stream's first element is typically materialized, so in order to not trip the TreeNode change + // detection logic, we should not change the first element in the sequence. + val result = before.transform { + case Literal(v: Int, IntegerType) if v != 1 => + Literal(v + 1, IntegerType) + } + val expected = Coalesce(Stream(Literal(1), Literal(3))) + assert(result === expected) + } + + test("withNewChildren on stream of children") { + val before = Coalesce(Stream(Literal(1), Literal(2))) + val result = before.withNewChildren(Stream(Literal(1), Literal(3))) + val expected = Coalesce(Stream(Literal(1), Literal(3))) + assert(result === expected) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 98a50fbd52b4d..ed0ff1be476c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -21,8 +21,8 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort, Union} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} @@ -679,6 +679,13 @@ class PlannerSuite extends SharedSQLContext { } assert(rangeExecInZeroPartition.head.outputPartitioning == UnknownPartitioning(0)) } + + test("SPARK-24500: create union with stream of children") { + val df = Union(Stream( + Range(1, 1, 1, 1), + Range(1, 2, 1, 1))) + df.queryExecution.executedPlan.execute() + } } // Used for unit-testing EnsureRequirements From 1b46f41c55f5cd29956e17d7da95a95580cf273f Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 13 Jun 2018 13:13:01 -0700 Subject: [PATCH 0959/2461] [SPARK-24235][SS] Implement continuous shuffle writer for single reader partition. ## What changes were proposed in this pull request? https://docs.google.com/document/d/1IL4kJoKrZWeyIhklKUJqsW-yEN7V7aL05MmM65AYOfE/edit Implement continuous shuffle write RDD for a single reader partition. (I don't believe any implementation changes are actually required for multiple reader partitions, but this PR is already very large, so I want to exclude those for now to keep the size down.) ## How was this patch tested? new unit tests Author: Jose Torres Closes #21428 from jose-torres/writerTask. --- .../shuffle/ContinuousShuffleReadRDD.scala | 6 +- .../shuffle/ContinuousShuffleWriter.scala | 27 ++ ...scala => RPCContinuousShuffleReader.scala} | 24 +- .../shuffle/RPCContinuousShuffleWriter.scala | 60 +++++ ...ite.scala => ContinuousShuffleSuite.scala} | 231 ++++++++++++++---- 5 files changed, 281 insertions(+), 67 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/{UnsafeRowReceiver.scala => RPCContinuousShuffleReader.scala} (86%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala rename sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/{ContinuousShuffleReadSuite.scala => ContinuousShuffleSuite.scala} (65%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 801b28b751bee..cf6572d3de1f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -34,8 +34,10 @@ case class ContinuousShuffleReadPartition( // Initialized only on the executor, and only once even as we call compute() multiple times. lazy val (reader: ContinuousShuffleReader, endpoint) = { val env = SparkEnv.get.rpcEnv - val receiver = new UnsafeRowReceiver(queueSize, numShuffleWriters, epochIntervalMs, env) - val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver) + val receiver = new RPCContinuousShuffleReader( + queueSize, numShuffleWriters, epochIntervalMs, env) + val endpoint = env.setupEndpoint(s"RPCContinuousShuffleReader-${UUID.randomUUID()}", receiver) + TaskContext.get().addTaskCompletionListener { ctx => env.stop(endpoint) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala new file mode 100644 index 0000000000000..47b1f78b24505 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +/** + * Trait for writing to a continuous processing shuffle. + */ +trait ContinuousShuffleWriter { + def write(epoch: Iterator[UnsafeRow]): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala similarity index 86% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala index d81f552d56626..834e84675c7d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala @@ -20,26 +20,24 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean -import scala.collection.mutable - import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.util.NextIterator /** - * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker. + * Messages for the RPCContinuousShuffleReader endpoint. Either an incoming row or an epoch marker. * * Each message comes tagged with writerId, identifying which writer the message is coming * from. The receiver will only begin the next epoch once all writers have sent an epoch * marker ending the current epoch. */ -private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable { +private[shuffle] sealed trait RPCContinuousShuffleMessage extends Serializable { def writerId: Int } private[shuffle] case class ReceiverRow(writerId: Int, row: UnsafeRow) - extends UnsafeRowReceiverMessage -private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends UnsafeRowReceiverMessage + extends RPCContinuousShuffleMessage +private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends RPCContinuousShuffleMessage /** * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle @@ -48,7 +46,7 @@ private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends UnsafeRow * TODO: Support multiple source tasks. We need to output a single epoch marker once all * source tasks have sent one. */ -private[shuffle] class UnsafeRowReceiver( +private[shuffle] class RPCContinuousShuffleReader( queueSize: Int, numShuffleWriters: Int, epochIntervalMs: Long, @@ -57,7 +55,7 @@ private[shuffle] class UnsafeRowReceiver( // Note that this queue will be drained from the main task thread and populated in the RPC // response thread. private val queues = Array.fill(numShuffleWriters) { - new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) + new ArrayBlockingQueue[RPCContinuousShuffleMessage](queueSize) } // Exposed for testing to determine if the endpoint gets stopped on task end. @@ -68,7 +66,9 @@ private[shuffle] class UnsafeRowReceiver( } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case r: UnsafeRowReceiverMessage => + case r: RPCContinuousShuffleMessage => + // Note that this will block a thread the shared RPC handler pool! + // The TCP based shuffle handler (SPARK-24541) will avoid this problem. queues(r.writerId).put(r) context.reply(()) } @@ -79,10 +79,10 @@ private[shuffle] class UnsafeRowReceiver( private val writerEpochMarkersReceived = Array.fill(numShuffleWriters)(false) private val executor = Executors.newFixedThreadPool(numShuffleWriters) - private val completion = new ExecutorCompletionService[UnsafeRowReceiverMessage](executor) + private val completion = new ExecutorCompletionService[RPCContinuousShuffleMessage](executor) - private def completionTask(writerId: Int) = new Callable[UnsafeRowReceiverMessage] { - override def call(): UnsafeRowReceiverMessage = queues(writerId).take() + private def completionTask(writerId: Int) = new Callable[RPCContinuousShuffleMessage] { + override def call(): RPCContinuousShuffleMessage = queues(writerId).take() } // Initialize by submitting tasks to read the first row from each writer. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala new file mode 100644 index 0000000000000..1c6f3ddb395e6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import scala.concurrent.Future +import scala.concurrent.duration.Duration + +import org.apache.spark.Partitioner +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.ThreadUtils + +/** + * A [[ContinuousShuffleWriter]] sending data to [[RPCContinuousShuffleReader]] instances. + * + * @param writerId The partition ID of this writer. + * @param outputPartitioner The partitioner on the reader side of the shuffle. + * @param endpoints The [[RPCContinuousShuffleReader]] endpoints to write to. Indexed by + * partition ID within outputPartitioner. + */ +class RPCContinuousShuffleWriter( + writerId: Int, + outputPartitioner: Partitioner, + endpoints: Array[RpcEndpointRef]) extends ContinuousShuffleWriter { + + if (outputPartitioner.numPartitions != 1) { + throw new IllegalArgumentException("multiple readers not yet supported") + } + + if (outputPartitioner.numPartitions != endpoints.length) { + throw new IllegalArgumentException(s"partitioner size ${outputPartitioner.numPartitions} did " + + s"not match endpoint count ${endpoints.length}") + } + + def write(epoch: Iterator[UnsafeRow]): Unit = { + while (epoch.hasNext) { + val row = epoch.next() + endpoints(outputPartitioner.getPartition(row)).askSync[Unit](ReceiverRow(writerId, row)) + } + + val futures = endpoints.map(_.ask[Unit](ReceiverEpochMarker(writerId))).toSeq + implicit val ec = ThreadUtils.sameThread + ThreadUtils.awaitResult(Future.sequence(futures), Duration.Inf) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala similarity index 65% rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala index 2e4d607a403ca..a8e3611b585cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala @@ -17,29 +17,14 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle -import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.{HashPartitioner, Partition, TaskContext, TaskContextImpl} import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{DataType, IntegerType, StringType} import org.apache.spark.unsafe.types.UTF8String -class ContinuousShuffleReadSuite extends StreamTest { - - private def unsafeRow(value: Int) = { - UnsafeProjection.create(Array(IntegerType : DataType))( - new GenericInternalRow(Array(value: Any))) - } - - private def unsafeRow(value: String) = { - UnsafeProjection.create(Array(StringType : DataType))( - new GenericInternalRow(Array(UTF8String.fromString(value): Any))) - } - - private def send(endpoint: RpcEndpointRef, messages: UnsafeRowReceiverMessage*) = { - messages.foreach(endpoint.askSync[Unit](_)) - } - +class ContinuousShuffleSuite extends StreamTest { // In this unit test, we emulate that we're in the task thread where // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context // thread local to be set. @@ -58,39 +43,29 @@ class ContinuousShuffleReadSuite extends StreamTest { super.afterEach() } - test("receiver stopped with row last") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverEpochMarker(0), - ReceiverRow(0, unsafeRow(111)) - ) + private implicit def unsafeRow(value: Int) = { + UnsafeProjection.create(Array(IntegerType : DataType))( + new GenericInternalRow(Array(value: Any))) + } - ctx.markTaskCompleted(None) - val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader - eventually(timeout(streamingTimeout)) { - assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) - } + private def unsafeRow(value: String) = { + UnsafeProjection.create(Array(StringType : DataType))( + new GenericInternalRow(Array(UTF8String.fromString(value): Any))) } - test("receiver stopped with marker last") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverRow(0, unsafeRow(111)), - ReceiverEpochMarker(0) - ) + private def send(endpoint: RpcEndpointRef, messages: RPCContinuousShuffleMessage*) = { + messages.foreach(endpoint.askSync[Unit](_)) + } - ctx.markTaskCompleted(None) - val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader - eventually(timeout(streamingTimeout)) { - assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) - } + private def readRDDEndpoint(rdd: ContinuousShuffleReadRDD) = { + rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint } - test("one epoch") { + private def readEpoch(rdd: ContinuousShuffleReadRDD) = { + rdd.compute(rdd.partitions(0), ctx).toSeq.map(_.getInt(0)) + } + + test("reader - one epoch") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( @@ -105,7 +80,7 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333)) } - test("multiple epochs") { + test("reader - multiple epochs") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( @@ -124,7 +99,7 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333)) } - test("empty epochs") { + test("reader - empty epochs") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint @@ -148,7 +123,7 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) } - test("multiple partitions") { + test("reader - multiple partitions") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5) // Send all data before processing to ensure there's no crossover. for (p <- rdd.partitions) { @@ -169,7 +144,7 @@ class ContinuousShuffleReadSuite extends StreamTest { } } - test("blocks waiting for new rows") { + test("reader - blocks waiting for new rows") { val rdd = new ContinuousShuffleReadRDD( sparkContext, numPartitions = 1, epochIntervalMs = Long.MaxValue) val epoch = rdd.compute(rdd.partitions(0), ctx) @@ -195,7 +170,7 @@ class ContinuousShuffleReadSuite extends StreamTest { } } - test("multiple writers") { + test("reader - multiple writers") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( @@ -213,7 +188,7 @@ class ContinuousShuffleReadSuite extends StreamTest { Set("writer0-row0", "writer1-row0", "writer2-row0")) } - test("epoch only ends when all writers send markers") { + test("reader - epoch only ends when all writers send markers") { val rdd = new ContinuousShuffleReadRDD( sparkContext, numPartitions = 1, numShuffleWriters = 3, epochIntervalMs = Long.MaxValue) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint @@ -233,6 +208,7 @@ class ContinuousShuffleReadSuite extends StreamTest { // After checking the right rows, block until we get an epoch marker indicating there's no next. // (Also fail the assertion if for some reason we get a row.) + val readEpochMarkerThread = new Thread { override def run(): Unit = { assert(!epoch.hasNext) @@ -251,10 +227,10 @@ class ContinuousShuffleReadSuite extends StreamTest { } // Join to pick up assertion failures. - readEpochMarkerThread.join() + readEpochMarkerThread.join(streamingTimeout.toMillis) } - test("writer epochs non aligned") { + test("reader - writer epochs non aligned") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint // We send multiple epochs for 0, then multiple for 1, then multiple for 2. The receiver should @@ -288,4 +264,153 @@ class ContinuousShuffleReadSuite extends StreamTest { val thirdEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet assert(thirdEpoch == Set("writer1-row1", "writer2-row0")) } + + test("one epoch") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + writer.write(Iterator(1, 2, 3)) + + assert(readEpoch(reader) == Seq(1, 2, 3)) + } + + test("multiple epochs") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + writer.write(Iterator(1, 2, 3)) + writer.write(Iterator(4, 5, 6)) + + assert(readEpoch(reader) == Seq(1, 2, 3)) + assert(readEpoch(reader) == Seq(4, 5, 6)) + } + + test("empty epochs") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + writer.write(Iterator()) + writer.write(Iterator(1, 2)) + writer.write(Iterator()) + writer.write(Iterator()) + writer.write(Iterator(3, 4)) + writer.write(Iterator()) + + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq(1, 2)) + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq(3, 4)) + assert(readEpoch(reader) == Seq()) + } + + test("blocks waiting for writer") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + val readerEpoch = reader.compute(reader.partitions(0), ctx) + + val readRowThread = new Thread { + override def run(): Unit = { + assert(readerEpoch.toSeq.map(_.getInt(0)) == Seq(1)) + } + } + readRowThread.start() + + eventually(timeout(streamingTimeout)) { + assert(readRowThread.getState == Thread.State.TIMED_WAITING) + } + + // Once we write the epoch the thread should stop waiting and succeed. + writer.write(Iterator(1)) + readRowThread.join(streamingTimeout.toMillis) + } + + test("multiple writer partitions") { + val numWriterPartitions = 3 + + val reader = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) + val writers = (0 until 3).map { idx => + new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + } + + writers(0).write(Iterator(1, 4, 7)) + writers(1).write(Iterator(2, 5)) + writers(2).write(Iterator(3, 6)) + + writers(0).write(Iterator(4, 7, 10)) + writers(1).write(Iterator(5, 8)) + writers(2).write(Iterator(6, 9)) + + // Since there are multiple asynchronous writers, the original row sequencing is not guaranteed. + // The epochs should be deterministically preserved, however. + assert(readEpoch(reader).toSet == Seq(1, 2, 3, 4, 5, 6, 7).toSet) + assert(readEpoch(reader).toSet == Seq(4, 5, 6, 7, 8, 9, 10).toSet) + } + + test("reader epoch only ends when all writer partitions write it") { + val numWriterPartitions = 3 + + val reader = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) + val writers = (0 until 3).map { idx => + new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + } + + writers(1).write(Iterator()) + writers(2).write(Iterator()) + + val readerEpoch = reader.compute(reader.partitions(0), ctx) + + val readEpochMarkerThread = new Thread { + override def run(): Unit = { + assert(!readerEpoch.hasNext) + } + } + + readEpochMarkerThread.start() + eventually(timeout(streamingTimeout)) { + assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING) + } + + writers(0).write(Iterator()) + readEpochMarkerThread.join(streamingTimeout.toMillis) + } + + test("receiver stopped with row last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(111)) + ) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) + } + } + + test("receiver stopped with marker last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0) + ) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) + } + } } From 3bf76918fb67fb3ee9aed254d4fb3b87a7e66117 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 13 Jun 2018 15:18:19 -0700 Subject: [PATCH 0960/2461] [SPARK-24531][TESTS] Replace 2.3.0 version with 2.3.1 ## What changes were proposed in this pull request? The PR updates the 2.3 version tested to the new release 2.3.1. ## How was this patch tested? existing UTs Author: Marco Gaido Closes #21543 from mgaido91/patch-1. --- .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 6f904c937348d..514921875f1f9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -195,7 +195,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.0.2", "2.1.2", "2.2.1", "2.3.0") + val testingVersions = Seq("2.0.2", "2.1.2", "2.2.1", "2.3.1") protected var spark: SparkSession = _ From 534065efeb51ff0d308fa6cc9dea0715f8ce25ad Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 14 Jun 2018 14:20:48 +0800 Subject: [PATCH 0961/2461] [MINOR][CORE][TEST] Remove unnecessary sort in UnsafeInMemorySorterSuite ## What changes were proposed in this pull request? We don't require specific ordering of the input data, the sort action is not necessary and misleading. ## How was this patch tested? Existing test suite. Author: Xingbo Jiang Closes #21536 from jiangxb1987/sorterSuite. --- .../util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index c145532328514..85ffdca436e14 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -129,7 +129,6 @@ public int compare( final UnsafeSorterIterator iter = sorter.getSortedIterator(); int iterLength = 0; long prevPrefix = -1; - Arrays.sort(dataToSort); while (iter.hasNext()) { iter.loadNext(); final String str = From fdadc4be08dcf1a06383bbb05e53540da2092c63 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 14 Jun 2018 09:20:41 -0700 Subject: [PATCH 0962/2461] [SPARK-24495][SQL] EnsureRequirement returns wrong plan when reordering equal keys ## What changes were proposed in this pull request? `EnsureRequirement` in its `reorder` method currently assumes that the same key appears only once in the join condition. This of course might not be the case, and when it is not satisfied, it returns a wrong plan which produces a wrong result of the query. ## How was this patch tested? added UT Author: Marco Gaido Closes #21529 from mgaido91/SPARK-24495. --- .../execution/exchange/EnsureRequirements.scala | 14 ++++++++++++-- .../scala/org/apache/spark/sql/JoinSuite.scala | 11 +++++++++++ .../spark/sql/execution/PlannerSuite.scala | 17 +++++++++++++++++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index e3d28388c5470..ad95879d86f42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.exchange +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ @@ -227,9 +228,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { val leftKeysBuffer = ArrayBuffer[Expression]() val rightKeysBuffer = ArrayBuffer[Expression]() + val pickedIndexes = mutable.Set[Int]() + val keysAndIndexes = currentOrderOfKeys.zipWithIndex expectedOrderOfKeys.foreach(expression => { - val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression)) + val index = keysAndIndexes.find { case (e, idx) => + // As we may have the same key used many times, we need to filter out its occurrence we + // have already used. + e.semanticEquals(expression) && !pickedIndexes.contains(idx) + }.map(_._2).get + pickedIndexes += index leftKeysBuffer.append(leftKeys(index)) rightKeysBuffer.append(rightKeys(index)) }) @@ -270,7 +278,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { * partitioning of the join nodes' children. */ private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = { - plan.transformUp { + plan match { case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => val (reorderedLeftKeys, reorderedRightKeys) = @@ -288,6 +296,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { val (reorderedLeftKeys, reorderedRightKeys) = reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right) + + case other => other } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 8fa747465cb1a..44767dfc92497 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -882,4 +882,15 @@ class JoinSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(3, 8, 7, 2) :: Row(3, 8, 4, 2) :: Nil) } } + + test("SPARK-24495: Join may return wrong result when having duplicated equal-join keys") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", + SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df1 = spark.range(0, 100, 1, 2) + val df2 = spark.range(100).select($"id".as("b1"), (- $"id").as("b2")) + val res = df1.join(df2, $"id" === $"b1" && $"id" === $"b2").select($"b1", $"b2", $"id") + checkAnswer(res, Row(0, 0, 0)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index ed0ff1be476c7..37d468739c613 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -680,6 +680,23 @@ class PlannerSuite extends SharedSQLContext { assert(rangeExecInZeroPartition.head.outputPartitioning == UnknownPartitioning(0)) } + test("SPARK-24495: EnsureRequirements can return wrong plan when reusing the same key in join") { + val plan1 = DummySparkPlan(outputOrdering = Seq(orderingA), + outputPartitioning = HashPartitioning(exprA :: exprA :: Nil, 5)) + val plan2 = DummySparkPlan(outputOrdering = Seq(orderingB), + outputPartitioning = HashPartitioning(exprB :: Nil, 5)) + val smjExec = SortMergeJoinExec( + exprA :: exprA :: Nil, exprB :: exprC :: Nil, Inner, None, plan1, plan2) + + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec) + outputPlan match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, _, _) => + assert(leftKeys == Seq(exprA, exprA)) + assert(rightKeys == Seq(exprB, exprC)) + case _ => fail() + } + } + test("SPARK-24500: create union with stream of children") { val df = Union(Stream( Range(1, 1, 1, 1), From d3eed8fd6d65d95306abfb513a9e0fde05b703ac Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 14 Jun 2018 13:16:20 -0700 Subject: [PATCH 0963/2461] =?UTF-8?q?[SPARK-24563][PYTHON]=20Catch=20TypeE?= =?UTF-8?q?rror=20when=20testing=20existence=20of=20HiveConf=20when=20crea?= =?UTF-8?q?ting=20pysp=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ark shell ## What changes were proposed in this pull request? This PR catches TypeError when testing existence of HiveConf when creating pyspark shell ## How was this patch tested? Manually tested. Here are the manual test cases: Build with hive: ``` (pyarrow-dev) Lis-MacBook-Pro:spark icexelloss$ bin/pyspark Python 3.6.5 | packaged by conda-forge | (default, Apr 6 2018, 13:44:09) [GCC 4.2.1 Compatible Apple LLVM 6.1.0 (clang-602.0.53)] on darwin Type "help", "copyright", "credits" or "license" for more information. 18/06/14 14:55:41 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel). Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /__ / .__/\_,_/_/ /_/\_\ version 2.4.0-SNAPSHOT /_/ Using Python version 3.6.5 (default, Apr 6 2018 13:44:09) SparkSession available as 'spark'. >>> spark.conf.get('spark.sql.catalogImplementation') 'hive' ``` Build without hive: ``` (pyarrow-dev) Lis-MacBook-Pro:spark icexelloss$ bin/pyspark Python 3.6.5 | packaged by conda-forge | (default, Apr 6 2018, 13:44:09) [GCC 4.2.1 Compatible Apple LLVM 6.1.0 (clang-602.0.53)] on darwin Type "help", "copyright", "credits" or "license" for more information. 18/06/14 15:04:52 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel). Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /__ / .__/\_,_/_/ /_/\_\ version 2.4.0-SNAPSHOT /_/ Using Python version 3.6.5 (default, Apr 6 2018 13:44:09) SparkSession available as 'spark'. >>> spark.conf.get('spark.sql.catalogImplementation') 'in-memory' ``` Failed to start shell: ``` (pyarrow-dev) Lis-MacBook-Pro:spark icexelloss$ bin/pyspark Python 3.6.5 | packaged by conda-forge | (default, Apr 6 2018, 13:44:09) [GCC 4.2.1 Compatible Apple LLVM 6.1.0 (clang-602.0.53)] on darwin Type "help", "copyright", "credits" or "license" for more information. 18/06/14 15:07:53 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel). /Users/icexelloss/workspace/spark/python/pyspark/shell.py:45: UserWarning: Failed to initialize Spark session. warnings.warn("Failed to initialize Spark session.") Traceback (most recent call last): File "/Users/icexelloss/workspace/spark/python/pyspark/shell.py", line 41, in spark = SparkSession._create_shell_session() File "/Users/icexelloss/workspace/spark/python/pyspark/sql/session.py", line 581, in _create_shell_session return SparkSession.builder.getOrCreate() File "/Users/icexelloss/workspace/spark/python/pyspark/sql/session.py", line 168, in getOrCreate raise py4j.protocol.Py4JError("Fake Py4JError") py4j.protocol.Py4JError: Fake Py4JError (pyarrow-dev) Lis-MacBook-Pro:spark icexelloss$ ``` Author: Li Jin Closes #21569 from icexelloss/SPARK-24563-fix-pyspark-shell-without-hive. --- python/pyspark/sql/session.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index e880dd1ca6d1a..f1ad6b1212ed9 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -567,14 +567,7 @@ def _create_shell_session(): .getOrCreate() else: return SparkSession.builder.getOrCreate() - except py4j.protocol.Py4JError: - if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': - warnings.warn("Fall back to non-hive support because failing to access HiveConf, " - "please make sure you build spark with hive") - - try: - return SparkSession.builder.getOrCreate() - except TypeError: + except (py4j.protocol.Py4JError, TypeError): if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': warnings.warn("Fall back to non-hive support because failing to access HiveConf, " "please make sure you build spark with hive") From b8f27ae3b34134a01998b77db4b7935e7f82a4fe Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 14 Jun 2018 13:27:27 -0700 Subject: [PATCH 0964/2461] [SPARK-24543][SQL] Support any type as DDL string for from_json's schema ## What changes were proposed in this pull request? In the PR, I propose to support any DataType represented as DDL string for the from_json function. After the changes, it will be possible to specify `MapType` in SQL like: ```sql select from_json('{"a":1, "b":2}', 'map') ``` and in Scala (similar in other languages) ```scala val in = Seq("""{"a": {"b": 1}}""").toDS() val schema = "map>" val out = in.select(from_json($"value", schema, Map.empty[String, String])) ``` ## How was this patch tested? Added a couple sql tests and modified existing tests for Python and Scala. The former tests were modified because it is not imported for them in which format schema for `from_json` is provided. Author: Maxim Gekk Closes #21550 from MaxGekk/from_json-ddl-schema. --- python/pyspark/sql/functions.py | 3 +-- .../catalyst/expressions/jsonExpressions.scala | 5 ++--- .../org/apache/spark/sql/types/DataType.scala | 11 +++++++++++ .../scala/org/apache/spark/sql/functions.scala | 2 +- .../sql-tests/inputs/json-functions.sql | 4 ++++ .../sql-tests/results/json-functions.sql.out | 18 +++++++++++++++++- .../apache/spark/sql/JsonFunctionsSuite.scala | 4 ++-- 7 files changed, 38 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a5e3384e802b8..e6346691fb1d4 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2168,8 +2168,7 @@ def from_json(col, schema, options={}): [Row(json=Row(a=1))] >>> df.select(from_json(df.value, "a INT").alias("json")).collect() [Row(json=Row(a=1))] - >>> schema = MapType(StringType(), IntegerType()) - >>> df.select(from_json(df.value, schema).alias("json")).collect() + >>> df.select(from_json(df.value, "MAP").alias("json")).collect() [Row(json={u'a': 1})] >>> data = [(1, '''[{"a": 1}]''')] >>> schema = ArrayType(StructType([StructField("a", IntegerType())])) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 04a4eb0ffc032..f6d74f5b74c8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -747,8 +746,8 @@ case class StructsToJson( object JsonExprUtils { - def validateSchemaLiteral(exp: Expression): StructType = exp match { - case Literal(s, StringType) => CatalystSqlParser.parseTableSchema(s.toString) + def validateSchemaLiteral(exp: Expression): DataType = exp match { + case Literal(s, StringType) => DataType.fromDDL(s.toString) case e => throw new AnalysisException(s"Expected a string literal instead of $e") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 0bef11659fc9e..fd40741cfb5f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.types import java.util.Locale +import scala.util.control.NonFatal + import org.json4s._ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ @@ -26,6 +28,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils @@ -110,6 +113,14 @@ abstract class DataType extends AbstractDataType { @InterfaceStability.Stable object DataType { + def fromDDL(ddl: String): DataType = { + try { + CatalystSqlParser.parseDataType(ddl) + } catch { + case NonFatal(_) => CatalystSqlParser.parseTableSchema(ddl) + } + } + def fromJson(json: String): DataType = parseDataType(parse(json)) private val nonDecimalNameToType = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 87bd7b3b0f9c6..8551058ec58ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3369,7 +3369,7 @@ object functions { val dataType = try { DataType.fromJson(schema) } catch { - case NonFatal(_) => StructType.fromDDL(schema) + case NonFatal(_) => DataType.fromDDL(schema) } from_json(e, dataType, options) } diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql index fea069eac4d48..dc15d13cd1dd3 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -31,3 +31,7 @@ CREATE TEMPORARY VIEW jsonTable(jsonField, a) AS SELECT * FROM VALUES ('{"a": 1, SELECT json_tuple(jsonField, 'b', CAST(NULL AS STRING), a) FROM jsonTable; -- Clean up DROP VIEW IF EXISTS jsonTable; + +-- from_json - complex types +select from_json('{"a":1, "b":2}', 'map'); +select from_json('{"a":1, "b":"2"}', 'struct'); diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 14a69128ffb41..2b3288dc5a137 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 28 -- !query 0 @@ -258,3 +258,19 @@ DROP VIEW IF EXISTS jsonTable struct<> -- !query 25 output + + +-- !query 26 +select from_json('{"a":1, "b":2}', 'map') +-- !query 26 schema +struct> +-- !query 26 output +{"a":1,"b":2} + + +-- !query 27 +select from_json('{"a":1, "b":"2"}', 'struct') +-- !query 27 schema +struct> +-- !query 27 output +{"a":1,"b":"2"} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 055e1fc5640f3..7bf17cbcd9c97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -354,8 +354,8 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { test("SPARK-24027: from_json - map>") { val in = Seq("""{"a": {"b": 1}}""").toDS() - val schema = MapType(StringType, MapType(StringType, IntegerType)) - val out = in.select(from_json($"value", schema)) + val schema = "map>" + val out = in.select(from_json($"value", schema, Map.empty[String, String])) checkAnswer(out, Row(Map("a" -> Map("b" -> 1)))) } From 18cb0c07988578156c869682d8a2c4151e8d35e5 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Thu, 14 Jun 2018 14:54:46 -0700 Subject: [PATCH 0965/2461] [SPARK-24319][SPARK SUBMIT] Fix spark-submit execution where no main class is required. ## What changes were proposed in this pull request? With [PR 20925](https://github.com/apache/spark/pull/20925) now it's not possible to execute the following commands: * run-example * run-example --help * run-example --version * run-example --usage-error * run-example --status ... * run-example --kill ... In this PR the execution will be allowed for the mentioned commands. ## How was this patch tested? Existing unit tests extended + additional written. Author: Gabor Somogyi Closes #21450 from gaborgsomogyi/SPARK-24319. --- .../java/org/apache/spark/launcher/Main.java | 36 +++++++++---- .../launcher/SparkSubmitCommandBuilder.java | 33 ++++++------ .../SparkSubmitCommandBuilderSuite.java | 54 +++++++++++++++++-- 3 files changed, 90 insertions(+), 33 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 1e34bb8c73279..d967aa39a4827 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -17,6 +17,7 @@ package org.apache.spark.launcher; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -54,10 +55,12 @@ public static void main(String[] argsArray) throws Exception { String className = args.remove(0); boolean printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND")); - AbstractCommandBuilder builder; + Map env = new HashMap<>(); + List cmd; if (className.equals("org.apache.spark.deploy.SparkSubmit")) { try { - builder = new SparkSubmitCommandBuilder(args); + AbstractCommandBuilder builder = new SparkSubmitCommandBuilder(args); + cmd = buildCommand(builder, env, printLaunchCommand); } catch (IllegalArgumentException e) { printLaunchCommand = false; System.err.println("Error: " + e.getMessage()); @@ -76,17 +79,12 @@ public static void main(String[] argsArray) throws Exception { help.add(parser.className); } help.add(parser.USAGE_ERROR); - builder = new SparkSubmitCommandBuilder(help); + AbstractCommandBuilder builder = new SparkSubmitCommandBuilder(help); + cmd = buildCommand(builder, env, printLaunchCommand); } } else { - builder = new SparkClassCommandBuilder(className, args); - } - - Map env = new HashMap<>(); - List cmd = builder.buildCommand(env); - if (printLaunchCommand) { - System.err.println("Spark Command: " + join(" ", cmd)); - System.err.println("========================================"); + AbstractCommandBuilder builder = new SparkClassCommandBuilder(className, args); + cmd = buildCommand(builder, env, printLaunchCommand); } if (isWindows()) { @@ -101,6 +99,22 @@ public static void main(String[] argsArray) throws Exception { } } + /** + * Prepare spark commands with the appropriate command builder. + * If printLaunchCommand is set then the commands will be printed to the stderr. + */ + private static List buildCommand( + AbstractCommandBuilder builder, + Map env, + boolean printLaunchCommand) throws IOException, IllegalArgumentException { + List cmd = builder.buildCommand(env); + if (printLaunchCommand) { + System.err.println("Spark Command: " + join(" ", cmd)); + System.err.println("========================================"); + } + return cmd; + } + /** * Prepare a command line for execution from a Windows batch script. * diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 5cb6457bf5c21..cc65f78b45c30 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -90,7 +90,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { final List userArgs; private final List parsedArgs; - private final boolean requiresAppResource; + // Special command means no appResource and no mainClass required + private final boolean isSpecialCommand; private final boolean isExample; /** @@ -105,7 +106,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { * spark-submit argument list to be modified after creation. */ SparkSubmitCommandBuilder() { - this.requiresAppResource = true; + this.isSpecialCommand = false; this.isExample = false; this.parsedArgs = new ArrayList<>(); this.userArgs = new ArrayList<>(); @@ -138,25 +139,26 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { case RUN_EXAMPLE: isExample = true; + appResource = SparkLauncher.NO_RESOURCE; submitArgs = args.subList(1, args.size()); } this.isExample = isExample; OptionParser parser = new OptionParser(true); parser.parse(submitArgs); - this.requiresAppResource = parser.requiresAppResource; + this.isSpecialCommand = parser.isSpecialCommand; } else { this.isExample = isExample; - this.requiresAppResource = false; + this.isSpecialCommand = true; } } @Override public List buildCommand(Map env) throws IOException, IllegalArgumentException { - if (PYSPARK_SHELL.equals(appResource) && requiresAppResource) { + if (PYSPARK_SHELL.equals(appResource) && !isSpecialCommand) { return buildPySparkShellCommand(env); - } else if (SPARKR_SHELL.equals(appResource) && requiresAppResource) { + } else if (SPARKR_SHELL.equals(appResource) && !isSpecialCommand) { return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); @@ -166,18 +168,18 @@ public List buildCommand(Map env) List buildSparkSubmitArgs() { List args = new ArrayList<>(); OptionParser parser = new OptionParser(false); - final boolean requiresAppResource; + final boolean isSpecialCommand; // If the user args array is not empty, we need to parse it to detect exactly what // the user is trying to run, so that checks below are correct. if (!userArgs.isEmpty()) { parser.parse(userArgs); - requiresAppResource = parser.requiresAppResource; + isSpecialCommand = parser.isSpecialCommand; } else { - requiresAppResource = this.requiresAppResource; + isSpecialCommand = this.isSpecialCommand; } - if (!allowsMixedArguments && requiresAppResource) { + if (!allowsMixedArguments && !isSpecialCommand) { checkArgument(appResource != null, "Missing application resource."); } @@ -229,7 +231,7 @@ List buildSparkSubmitArgs() { args.add(join(",", pyFiles)); } - if (isExample) { + if (isExample && !isSpecialCommand) { checkArgument(mainClass != null, "Missing example class name."); } @@ -421,7 +423,7 @@ private List findExamplesJars() { private class OptionParser extends SparkSubmitOptionParser { - boolean requiresAppResource = true; + boolean isSpecialCommand = false; private final boolean errorOnUnknownArgs; OptionParser(boolean errorOnUnknownArgs) { @@ -470,17 +472,14 @@ protected boolean handle(String opt, String value) { break; case KILL_SUBMISSION: case STATUS: - requiresAppResource = false; + isSpecialCommand = true; parsedArgs.add(opt); parsedArgs.add(value); break; case HELP: case USAGE_ERROR: - requiresAppResource = false; - parsedArgs.add(opt); - break; case VERSION: - requiresAppResource = false; + isSpecialCommand = true; parsedArgs.add(opt); break; default: diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index 2e050f8413074..b343094b2e7b8 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.launcher; import java.io.File; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -27,7 +28,10 @@ import org.junit.AfterClass; import org.junit.BeforeClass; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; + import static org.junit.Assert.*; public class SparkSubmitCommandBuilderSuite extends BaseSuite { @@ -35,6 +39,9 @@ public class SparkSubmitCommandBuilderSuite extends BaseSuite { private static File dummyPropsFile; private static SparkSubmitOptionParser parser; + @Rule + public ExpectedException expectedException = ExpectedException.none(); + @BeforeClass public static void setUp() throws Exception { dummyPropsFile = File.createTempFile("spark", "properties"); @@ -74,8 +81,11 @@ public void testCliHelpAndNoArg() throws Exception { @Test public void testCliKillAndStatus() throws Exception { - testCLIOpts(parser.STATUS); - testCLIOpts(parser.KILL_SUBMISSION); + List params = Arrays.asList("driver-20160531171222-0000"); + testCLIOpts(null, parser.STATUS, params); + testCLIOpts(null, parser.KILL_SUBMISSION, params); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.STATUS, params); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.KILL_SUBMISSION, params); } @Test @@ -190,6 +200,33 @@ public void testSparkRShell() throws Exception { env.get("SPARKR_SUBMIT_ARGS")); } + @Test(expected = IllegalArgumentException.class) + public void testExamplesRunnerNoArg() throws Exception { + List sparkSubmitArgs = Arrays.asList(SparkSubmitCommandBuilder.RUN_EXAMPLE); + Map env = new HashMap<>(); + buildCommand(sparkSubmitArgs, env); + } + + @Test + public void testExamplesRunnerNoMainClass() throws Exception { + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.HELP, null); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.USAGE_ERROR, null); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.VERSION, null); + } + + @Test + public void testExamplesRunnerWithMasterNoMainClass() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Missing example class name."); + + List sparkSubmitArgs = Arrays.asList( + SparkSubmitCommandBuilder.RUN_EXAMPLE, + parser.MASTER + "=foo" + ); + Map env = new HashMap<>(); + buildCommand(sparkSubmitArgs, env); + } + @Test public void testExamplesRunner() throws Exception { List sparkSubmitArgs = Arrays.asList( @@ -344,10 +381,17 @@ private List buildCommand(List args, Map env) th return newCommandBuilder(args).buildCommand(env); } - private void testCLIOpts(String opt) throws Exception { - List helpArgs = Arrays.asList(opt, "driver-20160531171222-0000"); + private void testCLIOpts(String appResource, String opt, List params) throws Exception { + List args = new ArrayList<>(); + if (appResource != null) { + args.add(appResource); + } + args.add(opt); + if (params != null) { + args.addAll(params); + } Map env = new HashMap<>(); - List cmd = buildCommand(helpArgs, env); + List cmd = buildCommand(args, env); assertTrue(opt + " should be contained in the final cmd.", cmd.contains(opt)); } From 270a9a3cac25f3e799460320d0fc94ccd7ecfaea Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 14 Jun 2018 15:56:21 -0700 Subject: [PATCH 0966/2461] [SPARK-24248][K8S] Use level triggering and state reconciliation in scheduling and lifecycle ## What changes were proposed in this pull request? Previously, the scheduler backend was maintaining state in many places, not only for reading state but also writing to it. For example, state had to be managed in both the watch and in the executor allocator runnable. Furthermore, one had to keep track of multiple hash tables. We can do better here by: 1. Consolidating the places where we manage state. Here, we take inspiration from traditional Kubernetes controllers. These controllers tend to follow a level-triggered mechanism. This means that the controller will continuously monitor the API server via watches and polling, and on periodic passes, the controller will reconcile the current state of the cluster with the desired state. We implement this by introducing the concept of a pod snapshot, which is a given state of the executors in the Kubernetes cluster. We operate periodically on snapshots. To prevent overloading the API server with polling requests to get the state of the cluster (particularly for executor allocation where we want to be checking frequently to get executors to launch without unbearably bad latency), we use watches to populate snapshots by applying observed events to a previous snapshot to get a new snapshot. Whenever we do poll the cluster, the polled state replaces any existing snapshot - this ensures eventual consistency and mirroring of the cluster, as is desired in a level triggered architecture. 2. Storing less specialized in-memory state in general. Previously we were creating hash tables to represent the state of executors. Instead, it's easier to represent state solely by the snapshots. ## How was this patch tested? Integration tests should test there's no regressions end to end. Unit tests to be updated, in particular focusing on different orderings of events, particularly accounting for when events come in unexpected ordering. Author: mcheah Closes #21366 from mccheah/event-queue-driven-scheduling. --- LICENSE | 1 + .../org/apache/spark/util/ThreadUtils.scala | 31 +- licenses/LICENSE-jmock.txt | 28 ++ pom.xml | 6 + resource-managers/kubernetes/core/pom.xml | 12 +- .../org/apache/spark/deploy/k8s/Config.scala | 19 +- .../cluster/k8s/ExecutorPodStates.scala | 37 ++ .../cluster/k8s/ExecutorPodsAllocator.scala | 149 ++++++ .../k8s/ExecutorPodsLifecycleManager.scala | 176 +++++++ .../ExecutorPodsPollingSnapshotSource.scala | 68 +++ .../cluster/k8s/ExecutorPodsSnapshot.scala | 74 +++ .../k8s/ExecutorPodsSnapshotsStore.scala | 32 ++ .../k8s/ExecutorPodsSnapshotsStoreImpl.scala | 113 +++++ .../k8s/ExecutorPodsWatchSnapshotSource.scala | 67 +++ .../k8s/KubernetesClusterManager.scala | 42 +- .../KubernetesClusterSchedulerBackend.scala | 417 +++------------- .../spark/deploy/k8s/Fabric8Aliases.scala | 30 ++ .../spark/deploy/k8s/submit/ClientSuite.scala | 9 +- ...erministicExecutorPodsSnapshotsStore.scala | 51 ++ .../k8s/ExecutorLifecycleTestUtils.scala | 123 +++++ .../k8s/ExecutorPodsAllocatorSuite.scala | 179 +++++++ .../ExecutorPodsLifecycleManagerSuite.scala | 126 +++++ ...ecutorPodsPollingSnapshotSourceSuite.scala | 85 ++++ .../k8s/ExecutorPodsSnapshotSuite.scala | 60 +++ .../k8s/ExecutorPodsSnapshotsStoreSuite.scala | 137 ++++++ ...ExecutorPodsWatchSnapshotSourceSuite.scala | 75 +++ ...bernetesClusterSchedulerBackendSuite.scala | 451 +++--------------- 27 files changed, 1842 insertions(+), 756 deletions(-) create mode 100644 licenses/LICENSE-jmock.txt create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala diff --git a/LICENSE b/LICENSE index 820f14dbdeed0..cc1f580207a75 100644 --- a/LICENSE +++ b/LICENSE @@ -237,6 +237,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.2.7 - https://github.com/jpmml/jpmml-model) + (BSD 3 Clause) jmock (org.jmock:jmock-junit4:2.8.4 - http://jmock.org/) (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) (BSD License) ANTLR 4.5.2-1 (org.antlr:antlr4:4.5.2-1 - http://wwww.antlr.org/) (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 165a15c73e7ca..0f08a2b0ad895 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -19,13 +19,12 @@ package org.apache.spark.util import java.util.concurrent._ +import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor} -import scala.concurrent.duration.Duration +import scala.concurrent.duration.{Duration, FiniteDuration} import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread} import scala.util.control.NonFatal -import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} - import org.apache.spark.SparkException private[spark] object ThreadUtils { @@ -103,6 +102,22 @@ private[spark] object ThreadUtils { executor } + /** + * Wrapper over ScheduledThreadPoolExecutor. + */ + def newDaemonThreadPoolScheduledExecutor(threadNamePrefix: String, numThreads: Int) + : ScheduledExecutorService = { + val threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat(s"$threadNamePrefix-%d") + .build() + val executor = new ScheduledThreadPoolExecutor(numThreads, threadFactory) + // By default, a cancelled task is not automatically removed from the work queue until its delay + // elapses. We have to enable it manually. + executor.setRemoveOnCancelPolicy(true) + executor + } + /** * Run a piece of code in a new thread and return the result. Exception in the new thread is * thrown in the caller thread with an adjusted stack trace that removes references to this @@ -229,4 +244,14 @@ private[spark] object ThreadUtils { } } // scalastyle:on awaitready + + def shutdown( + executor: ExecutorService, + gracePeriod: Duration = FiniteDuration(30, TimeUnit.SECONDS)): Unit = { + executor.shutdown() + executor.awaitTermination(gracePeriod.toMillis, TimeUnit.MILLISECONDS) + if (!executor.isShutdown) { + executor.shutdownNow() + } + } } diff --git a/licenses/LICENSE-jmock.txt b/licenses/LICENSE-jmock.txt new file mode 100644 index 0000000000000..ed7964fe3d9ef --- /dev/null +++ b/licenses/LICENSE-jmock.txt @@ -0,0 +1,28 @@ +Copyright (c) 2000-2017, jMock.org +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +Redistributions of source code must retain the above copyright notice, +this list of conditions and the following disclaimer. Redistributions +in binary form must reproduce the above copyright notice, this list of +conditions and the following disclaimer in the documentation and/or +other materials provided with the distribution. + +Neither the name of jMock nor the names of its contributors may be +used to endorse or promote products derived from this software without +specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pom.xml b/pom.xml index 23bbd3b09734e..4b4e6c13ea8fd 100644 --- a/pom.xml +++ b/pom.xml @@ -760,6 +760,12 @@ 1.10.19 test + + org.jmock + jmock-junit4 + test + 2.8.4 + org.scalacheck scalacheck_${scala.binary.version} diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index a62f271273465..a6dd47a6b7d95 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -77,6 +77,12 @@ + + com.squareup.okhttp3 + okhttp + 3.8.1 + + org.mockito mockito-core @@ -84,9 +90,9 @@ - com.squareup.okhttp3 - okhttp - 3.8.1 + org.jmock + jmock-junit4 + test diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 590deaa72e7ee..bf33179ae3dab 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -176,6 +176,24 @@ private[spark] object Config extends Logging { .checkValue(interval => interval > 0, s"Logging interval must be a positive time value.") .createWithDefaultString("1s") + val KUBERNETES_EXECUTOR_API_POLLING_INTERVAL = + ConfigBuilder("spark.kubernetes.executor.apiPollingInterval") + .doc("Interval between polls against the Kubernetes API server to inspect the " + + "state of executors.") + .timeConf(TimeUnit.MILLISECONDS) + .checkValue(interval => interval > 0, s"API server polling interval must be a" + + " positive time value.") + .createWithDefaultString("30s") + + val KUBERNETES_EXECUTOR_EVENT_PROCESSING_INTERVAL = + ConfigBuilder("spark.kubernetes.executor.eventProcessingInterval") + .doc("Interval between successive inspection of executor events sent from the" + + " Kubernetes API.") + .timeConf(TimeUnit.MILLISECONDS) + .checkValue(interval => interval > 0, s"Event processing interval must be a positive" + + " time value.") + .createWithDefaultString("1s") + val MEMORY_OVERHEAD_FACTOR = ConfigBuilder("spark.kubernetes.memoryOverheadFactor") .doc("This sets the Memory Overhead Factor that will allocate memory to non-JVM jobs " + @@ -193,7 +211,6 @@ private[spark] object Config extends Logging { "Ensure that major Python version is either Python2 or Python3") .createWithDefault("2") - val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = "spark.kubernetes.authenticate.submission" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala new file mode 100644 index 0000000000000..83daddf714489 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +sealed trait ExecutorPodState { + def pod: Pod +} + +case class PodRunning(pod: Pod) extends ExecutorPodState + +case class PodPending(pod: Pod) extends ExecutorPodState + +sealed trait FinalPodState extends ExecutorPodState + +case class PodSucceeded(pod: Pod) extends FinalPodState + +case class PodFailed(pod: Pod) extends FinalPodState + +case class PodDeleted(pod: Pod) extends FinalPodState + +case class PodUnknown(pod: Pod) extends ExecutorPodState diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala new file mode 100644 index 0000000000000..5a143ad3600fd --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.atomic.{AtomicInteger, AtomicLong} + +import io.fabric8.kubernetes.api.model.PodBuilder +import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.mutable + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.KubernetesConf +import org.apache.spark.internal.Logging +import org.apache.spark.util.{Clock, Utils} + +private[spark] class ExecutorPodsAllocator( + conf: SparkConf, + executorBuilder: KubernetesExecutorBuilder, + kubernetesClient: KubernetesClient, + snapshotsStore: ExecutorPodsSnapshotsStore, + clock: Clock) extends Logging { + + private val EXECUTOR_ID_COUNTER = new AtomicLong(0L) + + private val totalExpectedExecutors = new AtomicInteger(0) + + private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) + + private val podAllocationDelay = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) + + private val podCreationTimeout = math.max(podAllocationDelay * 5, 60000) + + private val kubernetesDriverPodName = conf + .get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse(throw new SparkException("Must specify the driver pod name")) + + private val driverPod = kubernetesClient.pods() + .withName(kubernetesDriverPodName) + .get() + + // Executor IDs that have been requested from Kubernetes but have not been detected in any + // snapshot yet. Mapped to the timestamp when they were created. + private val newlyCreatedExecutors = mutable.Map.empty[Long, Long] + + def start(applicationId: String): Unit = { + snapshotsStore.addSubscriber(podAllocationDelay) { + onNewSnapshots(applicationId, _) + } + } + + def setTotalExpectedExecutors(total: Int): Unit = totalExpectedExecutors.set(total) + + private def onNewSnapshots(applicationId: String, snapshots: Seq[ExecutorPodsSnapshot]): Unit = { + newlyCreatedExecutors --= snapshots.flatMap(_.executorPods.keys) + // For all executors we've created against the API but have not seen in a snapshot + // yet - check the current time. If the current time has exceeded some threshold, + // assume that the pod was either never created (the API server never properly + // handled the creation request), or the API server created the pod but we missed + // both the creation and deletion events. In either case, delete the missing pod + // if possible, and mark such a pod to be rescheduled below. + newlyCreatedExecutors.foreach { case (execId, timeCreated) => + val currentTime = clock.getTimeMillis() + if (currentTime - timeCreated > podCreationTimeout) { + logWarning(s"Executor with id $execId was not detected in the Kubernetes" + + s" cluster after $podCreationTimeout milliseconds despite the fact that a" + + " previous allocation attempt tried to create it. The executor may have been" + + " deleted but the application missed the deletion event.") + Utils.tryLogNonFatalError { + kubernetesClient + .pods() + .withLabel(SPARK_EXECUTOR_ID_LABEL, execId.toString) + .delete() + } + newlyCreatedExecutors -= execId + } else { + logDebug(s"Executor with id $execId was not found in the Kubernetes cluster since it" + + s" was created ${currentTime - timeCreated} milliseconds ago.") + } + } + + if (snapshots.nonEmpty) { + // Only need to examine the cluster as of the latest snapshot, the "current" state, to see if + // we need to allocate more executors or not. + val latestSnapshot = snapshots.last + val currentRunningExecutors = latestSnapshot.executorPods.values.count { + case PodRunning(_) => true + case _ => false + } + val currentPendingExecutors = latestSnapshot.executorPods.values.count { + case PodPending(_) => true + case _ => false + } + val currentTotalExpectedExecutors = totalExpectedExecutors.get + logDebug(s"Currently have $currentRunningExecutors running executors and" + + s" $currentPendingExecutors pending executors. $newlyCreatedExecutors executors" + + s" have been requested but are pending appearance in the cluster.") + if (newlyCreatedExecutors.isEmpty + && currentPendingExecutors == 0 + && currentRunningExecutors < currentTotalExpectedExecutors) { + val numExecutorsToAllocate = math.min( + currentTotalExpectedExecutors - currentRunningExecutors, podAllocationSize) + logInfo(s"Going to request $numExecutorsToAllocate executors from Kubernetes.") + for ( _ <- 0 until numExecutorsToAllocate) { + val newExecutorId = EXECUTOR_ID_COUNTER.incrementAndGet() + val executorConf = KubernetesConf.createExecutorConf( + conf, + newExecutorId.toString, + applicationId, + driverPod) + val executorPod = executorBuilder.buildFromFeatures(executorConf) + val podWithAttachedContainer = new PodBuilder(executorPod.pod) + .editOrNewSpec() + .addToContainers(executorPod.container) + .endSpec() + .build() + kubernetesClient.pods().create(podWithAttachedContainer) + newlyCreatedExecutors(newExecutorId) = clock.getTimeMillis() + logDebug(s"Requested executor with id $newExecutorId from Kubernetes.") + } + } else if (currentRunningExecutors >= currentTotalExpectedExecutors) { + // TODO handle edge cases if we end up with more running executors than expected. + logDebug("Current number of running executors is equal to the number of requested" + + " executors. Not scaling up further.") + } else if (newlyCreatedExecutors.nonEmpty || currentPendingExecutors != 0) { + logDebug(s"Still waiting for ${newlyCreatedExecutors.size + currentPendingExecutors}" + + s" executors to begin running before requesting for more executors. # of executors in" + + s" pending status in the cluster: $currentPendingExecutors. # of executors that we have" + + s" created but we have not observed as being present in the cluster yet:" + + s" ${newlyCreatedExecutors.size}.") + } + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala new file mode 100644 index 0000000000000..b28d93990313e --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import com.google.common.cache.Cache +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.ExecutorExited +import org.apache.spark.util.Utils + +private[spark] class ExecutorPodsLifecycleManager( + conf: SparkConf, + executorBuilder: KubernetesExecutorBuilder, + kubernetesClient: KubernetesClient, + snapshotsStore: ExecutorPodsSnapshotsStore, + // Use a best-effort to track which executors have been removed already. It's not generally + // job-breaking if we remove executors more than once but it's ideal if we make an attempt + // to avoid doing so. Expire cache entries so that this data structure doesn't grow beyond + // bounds. + removedExecutorsCache: Cache[java.lang.Long, java.lang.Long]) extends Logging { + + import ExecutorPodsLifecycleManager._ + + private val eventProcessingInterval = conf.get(KUBERNETES_EXECUTOR_EVENT_PROCESSING_INTERVAL) + + def start(schedulerBackend: KubernetesClusterSchedulerBackend): Unit = { + snapshotsStore.addSubscriber(eventProcessingInterval) { + onNewSnapshots(schedulerBackend, _) + } + } + + private def onNewSnapshots( + schedulerBackend: KubernetesClusterSchedulerBackend, + snapshots: Seq[ExecutorPodsSnapshot]): Unit = { + val execIdsRemovedInThisRound = mutable.HashSet.empty[Long] + snapshots.foreach { snapshot => + snapshot.executorPods.foreach { case (execId, state) => + state match { + case deleted@PodDeleted(_) => + logDebug(s"Snapshot reported deleted executor with id $execId," + + s" pod name ${state.pod.getMetadata.getName}") + removeExecutorFromSpark(schedulerBackend, deleted, execId) + execIdsRemovedInThisRound += execId + case failed@PodFailed(_) => + logDebug(s"Snapshot reported failed executor with id $execId," + + s" pod name ${state.pod.getMetadata.getName}") + onFinalNonDeletedState(failed, execId, schedulerBackend, execIdsRemovedInThisRound) + case succeeded@PodSucceeded(_) => + logDebug(s"Snapshot reported succeeded executor with id $execId," + + s" pod name ${state.pod.getMetadata.getName}. Note that succeeded executors are" + + s" unusual unless Spark specifically informed the executor to exit.") + onFinalNonDeletedState(succeeded, execId, schedulerBackend, execIdsRemovedInThisRound) + case _ => + } + } + } + + // Reconcile the case where Spark claims to know about an executor but the corresponding pod + // is missing from the cluster. This would occur if we miss a deletion event and the pod + // transitions immediately from running io absent. We only need to check against the latest + // snapshot for this, and we don't do this for executors in the deleted executors cache or + // that we just removed in this round. + if (snapshots.nonEmpty) { + val latestSnapshot = snapshots.last + (schedulerBackend.getExecutorIds().map(_.toLong).toSet + -- latestSnapshot.executorPods.keySet + -- execIdsRemovedInThisRound).foreach { missingExecutorId => + if (removedExecutorsCache.getIfPresent(missingExecutorId) == null) { + val exitReasonMessage = s"The executor with ID $missingExecutorId was not found in the" + + s" cluster but we didn't get a reason why. Marking the executor as failed. The" + + s" executor may have been deleted but the driver missed the deletion event." + logDebug(exitReasonMessage) + val exitReason = ExecutorExited( + UNKNOWN_EXIT_CODE, + exitCausedByApp = false, + exitReasonMessage) + schedulerBackend.doRemoveExecutor(missingExecutorId.toString, exitReason) + execIdsRemovedInThisRound += missingExecutorId + } + } + } + logDebug(s"Removed executors with ids ${execIdsRemovedInThisRound.mkString(",")}" + + s" from Spark that were either found to be deleted or non-existent in the cluster.") + } + + private def onFinalNonDeletedState( + podState: FinalPodState, + execId: Long, + schedulerBackend: KubernetesClusterSchedulerBackend, + execIdsRemovedInRound: mutable.Set[Long]): Unit = { + removeExecutorFromK8s(podState.pod) + removeExecutorFromSpark(schedulerBackend, podState, execId) + execIdsRemovedInRound += execId + } + + private def removeExecutorFromK8s(updatedPod: Pod): Unit = { + // If deletion failed on a previous try, we can try again if resync informs us the pod + // is still around. + // Delete as best attempt - duplicate deletes will throw an exception but the end state + // of getting rid of the pod is what matters. + Utils.tryLogNonFatalError { + kubernetesClient + .pods() + .withName(updatedPod.getMetadata.getName) + .delete() + } + } + + private def removeExecutorFromSpark( + schedulerBackend: KubernetesClusterSchedulerBackend, + podState: FinalPodState, + execId: Long): Unit = { + if (removedExecutorsCache.getIfPresent(execId) == null) { + removedExecutorsCache.put(execId, execId) + val exitReason = findExitReason(podState, execId) + schedulerBackend.doRemoveExecutor(execId.toString, exitReason) + } + } + + private def findExitReason(podState: FinalPodState, execId: Long): ExecutorExited = { + val exitCode = findExitCode(podState) + val (exitCausedByApp, exitMessage) = podState match { + case PodDeleted(_) => + (false, s"The executor with id $execId was deleted by a user or the framework.") + case _ => + val msg = exitReasonMessage(podState, execId, exitCode) + (true, msg) + } + ExecutorExited(exitCode, exitCausedByApp, exitMessage) + } + + private def exitReasonMessage(podState: FinalPodState, execId: Long, exitCode: Int) = { + val pod = podState.pod + s""" + |The executor with id $execId exited with exit code $exitCode. + |The API gave the following brief reason: ${pod.getStatus.getReason} + |The API gave the following message: ${pod.getStatus.getMessage} + |The API gave the following container statuses: + | + |${pod.getStatus.getContainerStatuses.asScala.map(_.toString).mkString("\n===\n")} + """.stripMargin + } + + private def findExitCode(podState: FinalPodState): Int = { + podState.pod.getStatus.getContainerStatuses.asScala.find { containerStatus => + containerStatus.getState.getTerminated != null + }.map { terminatedContainer => + terminatedContainer.getState.getTerminated.getExitCode.toInt + }.getOrElse(UNKNOWN_EXIT_CODE) + } +} + +private object ExecutorPodsLifecycleManager { + val UNKNOWN_EXIT_CODE = -1 +} + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala new file mode 100644 index 0000000000000..e77e604d00e0f --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.{Future, ScheduledExecutorService, TimeUnit} + +import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.ThreadUtils + +private[spark] class ExecutorPodsPollingSnapshotSource( + conf: SparkConf, + kubernetesClient: KubernetesClient, + snapshotsStore: ExecutorPodsSnapshotsStore, + pollingExecutor: ScheduledExecutorService) extends Logging { + + private val pollingInterval = conf.get(KUBERNETES_EXECUTOR_API_POLLING_INTERVAL) + + private var pollingFuture: Future[_] = _ + + def start(applicationId: String): Unit = { + require(pollingFuture == null, "Cannot start polling more than once.") + logDebug(s"Starting to check for executor pod state every $pollingInterval ms.") + pollingFuture = pollingExecutor.scheduleWithFixedDelay( + new PollRunnable(applicationId), pollingInterval, pollingInterval, TimeUnit.MILLISECONDS) + } + + def stop(): Unit = { + if (pollingFuture != null) { + pollingFuture.cancel(true) + pollingFuture = null + } + ThreadUtils.shutdown(pollingExecutor) + } + + private class PollRunnable(applicationId: String) extends Runnable { + override def run(): Unit = { + logDebug(s"Resynchronizing full executor pod state from Kubernetes.") + snapshotsStore.replaceSnapshot(kubernetesClient + .pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .list() + .getItems + .asScala) + } + } + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala new file mode 100644 index 0000000000000..26be918043412 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging + +/** + * An immutable view of the current executor pods that are running in the cluster. + */ +private[spark] case class ExecutorPodsSnapshot(executorPods: Map[Long, ExecutorPodState]) { + + import ExecutorPodsSnapshot._ + + def withUpdate(updatedPod: Pod): ExecutorPodsSnapshot = { + val newExecutorPods = executorPods ++ toStatesByExecutorId(Seq(updatedPod)) + new ExecutorPodsSnapshot(newExecutorPods) + } +} + +object ExecutorPodsSnapshot extends Logging { + + def apply(executorPods: Seq[Pod]): ExecutorPodsSnapshot = { + ExecutorPodsSnapshot(toStatesByExecutorId(executorPods)) + } + + def apply(): ExecutorPodsSnapshot = ExecutorPodsSnapshot(Map.empty[Long, ExecutorPodState]) + + private def toStatesByExecutorId(executorPods: Seq[Pod]): Map[Long, ExecutorPodState] = { + executorPods.map { pod => + (pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL).toLong, toState(pod)) + }.toMap + } + + private def toState(pod: Pod): ExecutorPodState = { + if (isDeleted(pod)) { + PodDeleted(pod) + } else { + val phase = pod.getStatus.getPhase.toLowerCase + phase match { + case "pending" => + PodPending(pod) + case "running" => + PodRunning(pod) + case "failed" => + PodFailed(pod) + case "succeeded" => + PodSucceeded(pod) + case _ => + logWarning(s"Received unknown phase $phase for executor pod with name" + + s" ${pod.getMetadata.getName} in namespace ${pod.getMetadata.getNamespace}") + PodUnknown(pod) + } + } + } + + private def isDeleted(pod: Pod): Boolean = pod.getMetadata.getDeletionTimestamp != null +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala new file mode 100644 index 0000000000000..dd264332cf9e8 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +private[spark] trait ExecutorPodsSnapshotsStore { + + def addSubscriber + (processBatchIntervalMillis: Long) + (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit) + + def stop(): Unit + + def updatePod(updatedPod: Pod): Unit + + def replaceSnapshot(newSnapshot: Seq[Pod]): Unit +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala new file mode 100644 index 0000000000000..5583b4617eeb2 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent._ + +import io.fabric8.kubernetes.api.model.Pod +import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * Controls the propagation of the Spark application's executor pods state to subscribers that + * react to that state. + *
    + * Roughly follows a producer-consumer model. Producers report states of executor pods, and these + * states are then published to consumers that can perform any actions in response to these states. + *
    + * Producers push updates in one of two ways. An incremental update sent by updatePod() represents + * a known new state of a single executor pod. A full sync sent by replaceSnapshot() indicates that + * the passed pods are all of the most up to date states of all executor pods for the application. + * The combination of the states of all executor pods for the application is collectively known as + * a snapshot. The store keeps track of the most up to date snapshot, and applies updates to that + * most recent snapshot - either by incrementally updating the snapshot with a single new pod state, + * or by replacing the snapshot entirely on a full sync. + *
    + * Consumers, or subscribers, register that they want to be informed about all snapshots of the + * executor pods. Every time the store replaces its most up to date snapshot from either an + * incremental update or a full sync, the most recent snapshot after the update is posted to the + * subscriber's buffer. Subscribers receive blocks of snapshots produced by the producers in + * time-windowed chunks. Each subscriber can choose to receive their snapshot chunks at different + * time intervals. + */ +private[spark] class ExecutorPodsSnapshotsStoreImpl(subscribersExecutor: ScheduledExecutorService) + extends ExecutorPodsSnapshotsStore { + + private val SNAPSHOT_LOCK = new Object() + + private val subscribers = mutable.Buffer.empty[SnapshotsSubscriber] + private val pollingTasks = mutable.Buffer.empty[Future[_]] + + @GuardedBy("SNAPSHOT_LOCK") + private var currentSnapshot = ExecutorPodsSnapshot() + + override def addSubscriber( + processBatchIntervalMillis: Long) + (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit): Unit = { + val newSubscriber = SnapshotsSubscriber( + new LinkedBlockingQueue[ExecutorPodsSnapshot](), onNewSnapshots) + SNAPSHOT_LOCK.synchronized { + newSubscriber.snapshotsBuffer.add(currentSnapshot) + } + subscribers += newSubscriber + pollingTasks += subscribersExecutor.scheduleWithFixedDelay( + toRunnable(() => callSubscriber(newSubscriber)), + 0L, + processBatchIntervalMillis, + TimeUnit.MILLISECONDS) + } + + override def stop(): Unit = { + pollingTasks.foreach(_.cancel(true)) + ThreadUtils.shutdown(subscribersExecutor) + } + + override def updatePod(updatedPod: Pod): Unit = SNAPSHOT_LOCK.synchronized { + currentSnapshot = currentSnapshot.withUpdate(updatedPod) + addCurrentSnapshotToSubscribers() + } + + override def replaceSnapshot(newSnapshot: Seq[Pod]): Unit = SNAPSHOT_LOCK.synchronized { + currentSnapshot = ExecutorPodsSnapshot(newSnapshot) + addCurrentSnapshotToSubscribers() + } + + private def addCurrentSnapshotToSubscribers(): Unit = { + subscribers.foreach { subscriber => + subscriber.snapshotsBuffer.add(currentSnapshot) + } + } + + private def callSubscriber(subscriber: SnapshotsSubscriber): Unit = { + Utils.tryLogNonFatalError { + val currentSnapshots = mutable.Buffer.empty[ExecutorPodsSnapshot].asJava + subscriber.snapshotsBuffer.drainTo(currentSnapshots) + subscriber.onNewSnapshots(currentSnapshots.asScala) + } + } + + private def toRunnable[T](runnable: () => Unit): Runnable = new Runnable { + override def run(): Unit = runnable() + } + + private case class SnapshotsSubscriber( + snapshotsBuffer: BlockingQueue[ExecutorPodsSnapshot], + onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit) +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala new file mode 100644 index 0000000000000..a6749a644e00c --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.io.Closeable + +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action + +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +private[spark] class ExecutorPodsWatchSnapshotSource( + snapshotsStore: ExecutorPodsSnapshotsStore, + kubernetesClient: KubernetesClient) extends Logging { + + private var watchConnection: Closeable = _ + + def start(applicationId: String): Unit = { + require(watchConnection == null, "Cannot start the watcher twice.") + logDebug(s"Starting watch for pods with labels $SPARK_APP_ID_LABEL=$applicationId," + + s" $SPARK_ROLE_LABEL=$SPARK_POD_EXECUTOR_ROLE.") + watchConnection = kubernetesClient.pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .watch(new ExecutorPodsWatcher()) + } + + def stop(): Unit = { + if (watchConnection != null) { + Utils.tryLogNonFatalError { + watchConnection.close() + } + watchConnection = null + } + } + + private class ExecutorPodsWatcher extends Watcher[Pod] { + override def eventReceived(action: Action, pod: Pod): Unit = { + val podName = pod.getMetadata.getName + logDebug(s"Received executor pod update for pod named $podName, action $action") + snapshotsStore.updatePod(pod) + } + + override def onClose(e: KubernetesClientException): Unit = { + logWarning("Kubernetes client has been closed (this is expected if the application is" + + " shutting down.)", e) + } + } + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index 0ea80dfbc0d97..c6e931a38405f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -17,7 +17,9 @@ package org.apache.spark.scheduler.cluster.k8s import java.io.File +import java.util.concurrent.TimeUnit +import com.google.common.cache.CacheBuilder import io.fabric8.kubernetes.client.Config import org.apache.spark.{SparkContext, SparkException} @@ -26,7 +28,7 @@ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{SystemClock, ThreadUtils} private[spark] class KubernetesClusterManager extends ExternalClusterManager with Logging { @@ -56,17 +58,45 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - val allocatorExecutor = ThreadUtils - .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( "kubernetes-executor-requests") + + val subscribersExecutor = ThreadUtils + .newDaemonThreadPoolScheduledExecutor( + "kubernetes-executor-snapshots-subscribers", 2) + val snapshotsStore = new ExecutorPodsSnapshotsStoreImpl(subscribersExecutor) + val removedExecutorsCache = CacheBuilder.newBuilder() + .expireAfterWrite(3, TimeUnit.MINUTES) + .build[java.lang.Long, java.lang.Long]() + val executorPodsLifecycleEventHandler = new ExecutorPodsLifecycleManager( + sc.conf, + new KubernetesExecutorBuilder(), + kubernetesClient, + snapshotsStore, + removedExecutorsCache) + + val executorPodsAllocator = new ExecutorPodsAllocator( + sc.conf, new KubernetesExecutorBuilder(), kubernetesClient, snapshotsStore, new SystemClock()) + + val podsWatchEventSource = new ExecutorPodsWatchSnapshotSource( + snapshotsStore, + kubernetesClient) + + val eventsPollingExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( + "kubernetes-executor-pod-polling-sync") + val podsPollingEventSource = new ExecutorPodsPollingSnapshotSource( + sc.conf, kubernetesClient, snapshotsStore, eventsPollingExecutor) + new KubernetesClusterSchedulerBackend( scheduler.asInstanceOf[TaskSchedulerImpl], sc.env.rpcEnv, - new KubernetesExecutorBuilder, kubernetesClient, - allocatorExecutor, - requestExecutorsService) + requestExecutorsService, + snapshotsStore, + executorPodsAllocator, + executorPodsLifecycleEventHandler, + podsWatchEventSource, + podsPollingEventSource) } override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index d86664c81071b..fa6dc2c479bbf 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -16,60 +16,32 @@ */ package org.apache.spark.scheduler.cluster.k8s -import java.io.Closeable -import java.net.InetAddress -import java.util.concurrent.{ConcurrentHashMap, ExecutorService, ScheduledExecutorService, TimeUnit} -import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, AtomicReference} -import javax.annotation.concurrent.GuardedBy +import java.util.concurrent.ExecutorService -import io.fabric8.kubernetes.api.model._ -import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher} -import io.fabric8.kubernetes.client.Watcher.Action -import scala.collection.JavaConverters._ -import scala.collection.mutable +import io.fabric8.kubernetes.client.KubernetesClient import scala.concurrent.{ExecutionContext, Future} -import org.apache.spark.SparkException -import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.KubernetesConf -import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv} -import org.apache.spark.scheduler.{ExecutorExited, SlaveLost, TaskSchedulerImpl} +import org.apache.spark.rpc.{RpcAddress, RpcEnv} +import org.apache.spark.scheduler.{ExecutorLossReason, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} private[spark] class KubernetesClusterSchedulerBackend( scheduler: TaskSchedulerImpl, rpcEnv: RpcEnv, - executorBuilder: KubernetesExecutorBuilder, kubernetesClient: KubernetesClient, - allocatorExecutor: ScheduledExecutorService, - requestExecutorsService: ExecutorService) + requestExecutorsService: ExecutorService, + snapshotsStore: ExecutorPodsSnapshotsStore, + podAllocator: ExecutorPodsAllocator, + lifecycleEventHandler: ExecutorPodsLifecycleManager, + watchEvents: ExecutorPodsWatchSnapshotSource, + pollEvents: ExecutorPodsPollingSnapshotSource) extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { - import KubernetesClusterSchedulerBackend._ - - private val EXECUTOR_ID_COUNTER = new AtomicLong(0L) - private val RUNNING_EXECUTOR_PODS_LOCK = new Object - @GuardedBy("RUNNING_EXECUTOR_PODS_LOCK") - private val runningExecutorsToPods = new mutable.HashMap[String, Pod] - private val executorPodsByIPs = new ConcurrentHashMap[String, Pod]() - private val podsWithKnownExitReasons = new ConcurrentHashMap[String, ExecutorExited]() - private val disconnectedPodsByExecutorIdPendingRemoval = new ConcurrentHashMap[String, Pod]() - - private val kubernetesNamespace = conf.get(KUBERNETES_NAMESPACE) - - private val kubernetesDriverPodName = conf - .get(KUBERNETES_DRIVER_POD_NAME) - .getOrElse(throw new SparkException("Must specify the driver pod name")) private implicit val requestExecutorContext = ExecutionContext.fromExecutorService( requestExecutorsService) - private val driverPod = kubernetesClient.pods() - .inNamespace(kubernetesNamespace) - .withName(kubernetesDriverPodName) - .get() - protected override val minRegisteredRatio = if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { 0.8 @@ -77,372 +49,93 @@ private[spark] class KubernetesClusterSchedulerBackend( super.minRegisteredRatio } - private val executorWatchResource = new AtomicReference[Closeable] - private val totalExpectedExecutors = new AtomicInteger(0) - - private val driverUrl = RpcEndpointAddress( - conf.get("spark.driver.host"), - conf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString - private val initialExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf) - private val podAllocationInterval = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) - - private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) - - private val executorLostReasonCheckMaxAttempts = conf.get( - KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS) - - private val allocatorRunnable = new Runnable { - - // Maintains a map of executor id to count of checks performed to learn the loss reason - // for an executor. - private val executorReasonCheckAttemptCounts = new mutable.HashMap[String, Int] - - override def run(): Unit = { - handleDisconnectedExecutors() - - val executorsToAllocate = mutable.Map[String, Pod]() - val currentTotalRegisteredExecutors = totalRegisteredExecutors.get - val currentTotalExpectedExecutors = totalExpectedExecutors.get - val currentNodeToLocalTaskCount = getNodesWithLocalTaskCounts() - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - if (currentTotalRegisteredExecutors < runningExecutorsToPods.size) { - logDebug("Waiting for pending executors before scaling") - } else if (currentTotalExpectedExecutors <= runningExecutorsToPods.size) { - logDebug("Maximum allowed executor limit reached. Not scaling up further.") - } else { - for (_ <- 0 until math.min( - currentTotalExpectedExecutors - runningExecutorsToPods.size, podAllocationSize)) { - val executorId = EXECUTOR_ID_COUNTER.incrementAndGet().toString - val executorConf = KubernetesConf.createExecutorConf( - conf, - executorId, - applicationId(), - driverPod) - val executorPod = executorBuilder.buildFromFeatures(executorConf) - val podWithAttachedContainer = new PodBuilder(executorPod.pod) - .editOrNewSpec() - .addToContainers(executorPod.container) - .endSpec() - .build() - - executorsToAllocate(executorId) = podWithAttachedContainer - logInfo( - s"Requesting a new executor, total executors is now ${runningExecutorsToPods.size}") - } - } - } - - val allocatedExecutors = executorsToAllocate.mapValues { pod => - Utils.tryLog { - kubernetesClient.pods().create(pod) - } - } - - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - allocatedExecutors.map { - case (executorId, attemptedAllocatedExecutor) => - attemptedAllocatedExecutor.map { successfullyAllocatedExecutor => - runningExecutorsToPods.put(executorId, successfullyAllocatedExecutor) - } - } - } - } - - def handleDisconnectedExecutors(): Unit = { - // For each disconnected executor, synchronize with the loss reasons that may have been found - // by the executor pod watcher. If the loss reason was discovered by the watcher, - // inform the parent class with removeExecutor. - disconnectedPodsByExecutorIdPendingRemoval.asScala.foreach { - case (executorId, executorPod) => - val knownExitReason = Option(podsWithKnownExitReasons.remove( - executorPod.getMetadata.getName)) - knownExitReason.fold { - removeExecutorOrIncrementLossReasonCheckCount(executorId) - } { executorExited => - logWarning(s"Removing executor $executorId with loss reason " + executorExited.message) - removeExecutor(executorId, executorExited) - // We don't delete the pod running the executor that has an exit condition caused by - // the application from the Kubernetes API server. This allows users to debug later on - // through commands such as "kubectl logs " and - // "kubectl describe pod ". Note that exited containers have terminated and - // therefore won't take CPU and memory resources. - // Otherwise, the executor pod is marked to be deleted from the API server. - if (executorExited.exitCausedByApp) { - logInfo(s"Executor $executorId exited because of the application.") - deleteExecutorFromDataStructures(executorId) - } else { - logInfo(s"Executor $executorId failed because of a framework error.") - deleteExecutorFromClusterAndDataStructures(executorId) - } - } - } - } - - def removeExecutorOrIncrementLossReasonCheckCount(executorId: String): Unit = { - val reasonCheckCount = executorReasonCheckAttemptCounts.getOrElse(executorId, 0) - if (reasonCheckCount >= executorLostReasonCheckMaxAttempts) { - removeExecutor(executorId, SlaveLost("Executor lost for unknown reasons.")) - deleteExecutorFromClusterAndDataStructures(executorId) - } else { - executorReasonCheckAttemptCounts.put(executorId, reasonCheckCount + 1) - } - } - - def deleteExecutorFromClusterAndDataStructures(executorId: String): Unit = { - deleteExecutorFromDataStructures(executorId).foreach { pod => - kubernetesClient.pods().delete(pod) - } - } - - def deleteExecutorFromDataStructures(executorId: String): Option[Pod] = { - disconnectedPodsByExecutorIdPendingRemoval.remove(executorId) - executorReasonCheckAttemptCounts -= executorId - podsWithKnownExitReasons.remove(executorId) - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - runningExecutorsToPods.remove(executorId).orElse { - logWarning(s"Unable to remove pod for unknown executor $executorId") - None - } - } - } - } - - override def sufficientResourcesRegistered(): Boolean = { - totalRegisteredExecutors.get() >= initialExecutors * minRegisteredRatio + // Allow removeExecutor to be accessible by ExecutorPodsLifecycleEventHandler + private[k8s] def doRemoveExecutor(executorId: String, reason: ExecutorLossReason): Unit = { + removeExecutor(executorId, reason) } override def start(): Unit = { super.start() - executorWatchResource.set( - kubernetesClient - .pods() - .withLabel(SPARK_APP_ID_LABEL, applicationId()) - .watch(new ExecutorPodsWatcher())) - - allocatorExecutor.scheduleWithFixedDelay( - allocatorRunnable, 0L, podAllocationInterval, TimeUnit.MILLISECONDS) - if (!Utils.isDynamicAllocationEnabled(conf)) { - doRequestTotalExecutors(initialExecutors) + podAllocator.setTotalExpectedExecutors(initialExecutors) } + lifecycleEventHandler.start(this) + podAllocator.start(applicationId()) + watchEvents.start(applicationId()) + pollEvents.start(applicationId()) } override def stop(): Unit = { - // stop allocation of new resources and caches. - allocatorExecutor.shutdown() - allocatorExecutor.awaitTermination(30, TimeUnit.SECONDS) - - // send stop message to executors so they shut down cleanly super.stop() - try { - val resource = executorWatchResource.getAndSet(null) - if (resource != null) { - resource.close() - } - } catch { - case e: Throwable => logWarning("Failed to close the executor pod watcher", e) + Utils.tryLogNonFatalError { + snapshotsStore.stop() } - // then delete the executor pods Utils.tryLogNonFatalError { - deleteExecutorPodsOnStop() - executorPodsByIPs.clear() + watchEvents.stop() } + Utils.tryLogNonFatalError { - logInfo("Closing kubernetes client") - kubernetesClient.close() + pollEvents.stop() } - } - /** - * @return A map of K8s cluster nodes to the number of tasks that could benefit from data - * locality if an executor launches on the cluster node. - */ - private def getNodesWithLocalTaskCounts() : Map[String, Int] = { - val nodeToLocalTaskCount = synchronized { - mutable.Map[String, Int]() ++ hostToLocalTaskCount + Utils.tryLogNonFatalError { + kubernetesClient.pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId()) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .delete() } - for (pod <- executorPodsByIPs.values().asScala) { - // Remove cluster nodes that are running our executors already. - // TODO: This prefers spreading out executors across nodes. In case users want - // consolidating executors on fewer nodes, introduce a flag. See the spark.deploy.spreadOut - // flag that Spark standalone has: https://spark.apache.org/docs/latest/spark-standalone.html - nodeToLocalTaskCount.remove(pod.getSpec.getNodeName).nonEmpty || - nodeToLocalTaskCount.remove(pod.getStatus.getHostIP).nonEmpty || - nodeToLocalTaskCount.remove( - InetAddress.getByName(pod.getStatus.getHostIP).getCanonicalHostName).nonEmpty + Utils.tryLogNonFatalError { + ThreadUtils.shutdown(requestExecutorsService) + } + + Utils.tryLogNonFatalError { + kubernetesClient.close() } - nodeToLocalTaskCount.toMap[String, Int] } override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = Future[Boolean] { - totalExpectedExecutors.set(requestedTotal) + // TODO when we support dynamic allocation, the pod allocator should be told to process the + // current snapshot in order to decrease/increase the number of executors accordingly. + podAllocator.setTotalExpectedExecutors(requestedTotal) true } - override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] { - val podsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized { - executorIds.flatMap { executorId => - runningExecutorsToPods.remove(executorId) match { - case Some(pod) => - disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) - Some(pod) - - case None => - logWarning(s"Unable to remove pod for unknown executor $executorId") - None - } - } - } - - kubernetesClient.pods().delete(podsToDelete: _*) - true + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= initialExecutors * minRegisteredRatio } - private def deleteExecutorPodsOnStop(): Unit = { - val executorPodsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized { - val runningExecutorPodsCopy = Seq(runningExecutorsToPods.values.toSeq: _*) - runningExecutorsToPods.clear() - runningExecutorPodsCopy - } - kubernetesClient.pods().delete(executorPodsToDelete: _*) + override def getExecutorIds(): Seq[String] = synchronized { + super.getExecutorIds() } - private class ExecutorPodsWatcher extends Watcher[Pod] { - - private val DEFAULT_CONTAINER_FAILURE_EXIT_STATUS = -1 - - override def eventReceived(action: Action, pod: Pod): Unit = { - val podName = pod.getMetadata.getName - val podIP = pod.getStatus.getPodIP - - action match { - case Action.MODIFIED if (pod.getStatus.getPhase == "Running" - && pod.getMetadata.getDeletionTimestamp == null) => - val clusterNodeName = pod.getSpec.getNodeName - logInfo(s"Executor pod $podName ready, launched at $clusterNodeName as IP $podIP.") - executorPodsByIPs.put(podIP, pod) - - case Action.DELETED | Action.ERROR => - val executorId = getExecutorId(pod) - logDebug(s"Executor pod $podName at IP $podIP was at $action.") - if (podIP != null) { - executorPodsByIPs.remove(podIP) - } - - val executorExitReason = if (action == Action.ERROR) { - logWarning(s"Received error event of executor pod $podName. Reason: " + - pod.getStatus.getReason) - executorExitReasonOnError(pod) - } else if (action == Action.DELETED) { - logWarning(s"Received delete event of executor pod $podName. Reason: " + - pod.getStatus.getReason) - executorExitReasonOnDelete(pod) - } else { - throw new IllegalStateException( - s"Unknown action that should only be DELETED or ERROR: $action") - } - podsWithKnownExitReasons.put(pod.getMetadata.getName, executorExitReason) - - if (!disconnectedPodsByExecutorIdPendingRemoval.containsKey(executorId)) { - log.warn(s"Executor with id $executorId was not marked as disconnected, but the " + - s"watch received an event of type $action for this executor. The executor may " + - "have failed to start in the first place and never registered with the driver.") - } - disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) - - case _ => logDebug(s"Received event of executor pod $podName: " + action) - } - } - - override def onClose(cause: KubernetesClientException): Unit = { - logDebug("Executor pod watch closed.", cause) - } - - private def getExecutorExitStatus(pod: Pod): Int = { - val containerStatuses = pod.getStatus.getContainerStatuses - if (!containerStatuses.isEmpty) { - // we assume the first container represents the pod status. This assumption may not hold - // true in the future. Revisit this if side-car containers start running inside executor - // pods. - getExecutorExitStatus(containerStatuses.get(0)) - } else DEFAULT_CONTAINER_FAILURE_EXIT_STATUS - } - - private def getExecutorExitStatus(containerStatus: ContainerStatus): Int = { - Option(containerStatus.getState).map { containerState => - Option(containerState.getTerminated).map { containerStateTerminated => - containerStateTerminated.getExitCode.intValue() - }.getOrElse(UNKNOWN_EXIT_CODE) - }.getOrElse(UNKNOWN_EXIT_CODE) - } - - private def isPodAlreadyReleased(pod: Pod): Boolean = { - val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - !runningExecutorsToPods.contains(executorId) - } - } - - private def executorExitReasonOnError(pod: Pod): ExecutorExited = { - val containerExitStatus = getExecutorExitStatus(pod) - // container was probably actively killed by the driver. - if (isPodAlreadyReleased(pod)) { - ExecutorExited(containerExitStatus, exitCausedByApp = false, - s"Container in pod ${pod.getMetadata.getName} exited from explicit termination " + - "request.") - } else { - val containerExitReason = s"Pod ${pod.getMetadata.getName}'s executor container " + - s"exited with exit status code $containerExitStatus." - ExecutorExited(containerExitStatus, exitCausedByApp = true, containerExitReason) - } - } - - private def executorExitReasonOnDelete(pod: Pod): ExecutorExited = { - val exitMessage = if (isPodAlreadyReleased(pod)) { - s"Container in pod ${pod.getMetadata.getName} exited from explicit termination request." - } else { - s"Pod ${pod.getMetadata.getName} deleted or lost." - } - ExecutorExited(getExecutorExitStatus(pod), exitCausedByApp = false, exitMessage) - } - - private def getExecutorId(pod: Pod): String = { - val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) - require(executorId != null, "Unexpected pod metadata; expected all executor pods " + - s"to have label $SPARK_EXECUTOR_ID_LABEL.") - executorId - } + override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] { + kubernetesClient.pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId()) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .withLabelIn(SPARK_EXECUTOR_ID_LABEL, executorIds: _*) + .delete() + // Don't do anything else - let event handling from the Kubernetes API do the Spark changes } override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { new KubernetesDriverEndpoint(rpcEnv, properties) } - private class KubernetesDriverEndpoint( - rpcEnv: RpcEnv, - sparkProperties: Seq[(String, String)]) + private class KubernetesDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends DriverEndpoint(rpcEnv, sparkProperties) { override def onDisconnected(rpcAddress: RpcAddress): Unit = { - addressToExecutorId.get(rpcAddress).foreach { executorId => - if (disableExecutor(executorId)) { - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - runningExecutorsToPods.get(executorId).foreach { pod => - disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) - } - } - } - } + // Don't do anything besides disabling the executor - allow the Kubernetes API events to + // drive the rest of the lifecycle decisions + // TODO what if we disconnect from a networking issue? Probably want to mark the executor + // to be deleted eventually. + addressToExecutorId.get(rpcAddress).foreach(disableExecutor) } } -} -private object KubernetesClusterSchedulerBackend { - private val UNKNOWN_EXIT_CODE = -1 } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala new file mode 100644 index 0000000000000..527fc6b0d8f87 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{DoneablePod, HasMetadata, Pod, PodList} +import io.fabric8.kubernetes.client.{Watch, Watcher} +import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable, PodResource} + +object Fabric8Aliases { + type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] + type LABELED_PODS = FilterWatchListDeletable[ + Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]] + type SINGLE_POD = PodResource[Pod, DoneablePod] + type RESOURCE_LIST = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ + HasMetadata, Boolean] +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index a8a8218c621ea..d045d9ae89c07 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.mockito.MockitoSugar._ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ class ClientSuite extends SparkFunSuite with BeforeAndAfter { @@ -103,15 +104,11 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { .build() } - private type ResourceList = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ - HasMetadata, Boolean] - private type Pods = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] - @Mock private var kubernetesClient: KubernetesClient = _ @Mock - private var podOperations: Pods = _ + private var podOperations: PODS = _ @Mock private var namedPods: PodResource[Pod, DoneablePod] = _ @@ -123,7 +120,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { private var driverBuilder: KubernetesDriverBuilder = _ @Mock - private var resourceList: ResourceList = _ + private var resourceList: RESOURCE_LIST = _ private var kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf] = _ diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala new file mode 100644 index 0000000000000..f7721e6fd6388 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod +import scala.collection.mutable + +class DeterministicExecutorPodsSnapshotsStore extends ExecutorPodsSnapshotsStore { + + private val snapshotsBuffer = mutable.Buffer.empty[ExecutorPodsSnapshot] + private val subscribers = mutable.Buffer.empty[Seq[ExecutorPodsSnapshot] => Unit] + + private var currentSnapshot = ExecutorPodsSnapshot() + + override def addSubscriber + (processBatchIntervalMillis: Long) + (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit): Unit = { + subscribers += onNewSnapshots + } + + override def stop(): Unit = {} + + def notifySubscribers(): Unit = { + subscribers.foreach(_(snapshotsBuffer)) + snapshotsBuffer.clear() + } + + override def updatePod(updatedPod: Pod): Unit = { + currentSnapshot = currentSnapshot.withUpdate(updatedPod) + snapshotsBuffer += currentSnapshot + } + + override def replaceSnapshot(newSnapshot: Seq[Pod]): Unit = { + currentSnapshot = ExecutorPodsSnapshot(newSnapshot) + snapshotsBuffer += currentSnapshot + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala new file mode 100644 index 0000000000000..c6b667ed85e8c --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, Pod, PodBuilder} + +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.SparkPod + +object ExecutorLifecycleTestUtils { + + val TEST_SPARK_APP_ID = "spark-app-id" + + def failedExecutorWithoutDeletion(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("failed") + .addNewContainerStatus() + .withName("spark-executor") + .withImage("k8s-spark") + .withNewState() + .withNewTerminated() + .withMessage("Failed") + .withExitCode(1) + .endTerminated() + .endState() + .endContainerStatus() + .addNewContainerStatus() + .withName("spark-executor-sidecar") + .withImage("k8s-spark-sidecar") + .withNewState() + .withNewTerminated() + .withMessage("Failed") + .withExitCode(1) + .endTerminated() + .endState() + .endContainerStatus() + .withMessage("Executor failed.") + .withReason("Executor failed because of a thrown error.") + .endStatus() + .build() + } + + def pendingExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("pending") + .endStatus() + .build() + } + + def runningExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("running") + .endStatus() + .build() + } + + def succeededExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("succeeded") + .endStatus() + .build() + } + + def deletedExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewMetadata() + .withNewDeletionTimestamp("523012521") + .endMetadata() + .build() + } + + def unknownExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("unknown") + .endStatus() + .build() + } + + def podWithAttachedContainerForId(executorId: Long): Pod = { + val sparkPod = executorPodWithId(executorId) + val podWithAttachedContainer = new PodBuilder(sparkPod.pod) + .editOrNewSpec() + .addToContainers(sparkPod.container) + .endSpec() + .build() + podWithAttachedContainer + } + + def executorPodWithId(executorId: Long): SparkPod = { + val pod = new PodBuilder() + .withNewMetadata() + .withName(s"spark-executor-$executorId") + .addToLabels(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID) + .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) + .endMetadata() + .build() + val container = new ContainerBuilder() + .withName("spark-executor") + .withImage("k8s-spark") + .build() + SparkPod(pod, container) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala new file mode 100644 index 0000000000000..0c19f5946b75f --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder} +import io.fabric8.kubernetes.client.KubernetesClient +import io.fabric8.kubernetes.client.dsl.PodResource +import org.mockito.{ArgumentMatcher, Matchers, Mock, MockitoAnnotations} +import org.mockito.Matchers.any +import org.mockito.Mockito.{never, times, verify, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ +import org.apache.spark.util.ManualClock + +class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter { + + private val driverPodName = "driver" + + private val driverPod = new PodBuilder() + .withNewMetadata() + .withName(driverPodName) + .addToLabels(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID) + .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE) + .withUid("driver-pod-uid") + .endMetadata() + .build() + + private val conf = new SparkConf().set(KUBERNETES_DRIVER_POD_NAME, driverPodName) + + private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) + private val podAllocationDelay = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) + private val podCreationTimeout = math.max(podAllocationDelay * 5, 60000L) + + private var waitForExecutorPodsClock: ManualClock = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var labeledPods: LABELED_PODS = _ + + @Mock + private var driverPodOperations: PodResource[Pod, DoneablePod] = _ + + @Mock + private var executorBuilder: KubernetesExecutorBuilder = _ + + private var snapshotsStore: DeterministicExecutorPodsSnapshotsStore = _ + + private var podsAllocatorUnderTest: ExecutorPodsAllocator = _ + + before { + MockitoAnnotations.initMocks(this) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withName(driverPodName)).thenReturn(driverPodOperations) + when(driverPodOperations.get).thenReturn(driverPod) + when(executorBuilder.buildFromFeatures(kubernetesConfWithCorrectFields())) + .thenAnswer(executorPodAnswer()) + snapshotsStore = new DeterministicExecutorPodsSnapshotsStore() + waitForExecutorPodsClock = new ManualClock(0L) + podsAllocatorUnderTest = new ExecutorPodsAllocator( + conf, executorBuilder, kubernetesClient, snapshotsStore, waitForExecutorPodsClock) + podsAllocatorUnderTest.start(TEST_SPARK_APP_ID) + } + + test("Initially request executors in batches. Do not request another batch if the" + + " first has not finished.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize + 1) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + for (nextId <- 1 to podAllocationSize) { + verify(podOperations).create(podWithAttachedContainerForId(nextId)) + } + verify(podOperations, never()).create(podWithAttachedContainerForId(podAllocationSize + 1)) + } + + test("Request executors in batches. Allow another batch to be requested if" + + " all pending executors start running.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize + 1) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + for (execId <- 1 until podAllocationSize) { + snapshotsStore.updatePod(runningExecutor(execId)) + } + snapshotsStore.notifySubscribers() + verify(podOperations, never()).create(podWithAttachedContainerForId(podAllocationSize + 1)) + snapshotsStore.updatePod(runningExecutor(podAllocationSize)) + snapshotsStore.notifySubscribers() + verify(podOperations).create(podWithAttachedContainerForId(podAllocationSize + 1)) + snapshotsStore.updatePod(runningExecutor(podAllocationSize)) + snapshotsStore.notifySubscribers() + verify(podOperations, times(podAllocationSize + 1)).create(any(classOf[Pod])) + } + + test("When a current batch reaches error states immediately, re-request" + + " them on the next batch.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + for (execId <- 1 until podAllocationSize) { + snapshotsStore.updatePod(runningExecutor(execId)) + } + val failedPod = failedExecutorWithoutDeletion(podAllocationSize) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + verify(podOperations).create(podWithAttachedContainerForId(podAllocationSize + 1)) + } + + test("When an executor is requested but the API does not report it in a reasonable time, retry" + + " requesting that executor.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(1) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + waitForExecutorPodsClock.setTime(podCreationTimeout + 1) + when(podOperations.withLabel(SPARK_EXECUTOR_ID_LABEL, "1")).thenReturn(labeledPods) + snapshotsStore.notifySubscribers() + verify(labeledPods).delete() + verify(podOperations).create(podWithAttachedContainerForId(2)) + } + + private def executorPodAnswer(): Answer[SparkPod] = { + new Answer[SparkPod] { + override def answer(invocation: InvocationOnMock): SparkPod = { + val k8sConf = invocation.getArgumentAt( + 0, classOf[KubernetesConf[KubernetesExecutorSpecificConf]]) + executorPodWithId(k8sConf.roleSpecificConf.executorId.toInt) + } + } + } + + private def kubernetesConfWithCorrectFields(): KubernetesConf[KubernetesExecutorSpecificConf] = + Matchers.argThat(new ArgumentMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { + override def matches(argument: scala.Any): Boolean = { + if (!argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]]) { + false + } else { + val k8sConf = argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] + val executorSpecificConf = k8sConf.roleSpecificConf + val expectedK8sConf = KubernetesConf.createExecutorConf( + conf, + executorSpecificConf.executorId, + TEST_SPARK_APP_ID, + driverPod) + k8sConf.sparkConf.getAll.toMap == conf.getAll.toMap && + // Since KubernetesConf.createExecutorConf clones the SparkConf object, force + // deep equality comparison for the SparkConf object and use object equality + // comparison on all other fields. + k8sConf.copy(sparkConf = conf) == expectedK8sConf.copy(sparkConf = conf) + } + } + }) + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala new file mode 100644 index 0000000000000..562ace9f49d4d --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import com.google.common.cache.CacheBuilder +import io.fabric8.kubernetes.api.model.{DoneablePod, Pod} +import io.fabric8.kubernetes.client.KubernetesClient +import io.fabric8.kubernetes.client.dsl.PodResource +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Matchers.any +import org.mockito.Mockito.{mock, times, verify, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfter +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.ExecutorExited +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfter { + + private var namedExecutorPods: mutable.Map[String, PodResource[Pod, DoneablePod]] = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var executorBuilder: KubernetesExecutorBuilder = _ + + @Mock + private var schedulerBackend: KubernetesClusterSchedulerBackend = _ + + private var snapshotsStore: DeterministicExecutorPodsSnapshotsStore = _ + private var eventHandlerUnderTest: ExecutorPodsLifecycleManager = _ + + before { + MockitoAnnotations.initMocks(this) + val removedExecutorsCache = CacheBuilder.newBuilder().build[java.lang.Long, java.lang.Long] + snapshotsStore = new DeterministicExecutorPodsSnapshotsStore() + namedExecutorPods = mutable.Map.empty[String, PodResource[Pod, DoneablePod]] + when(schedulerBackend.getExecutorIds()).thenReturn(Seq.empty[String]) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withName(any(classOf[String]))).thenAnswer(namedPodsAnswer()) + eventHandlerUnderTest = new ExecutorPodsLifecycleManager( + new SparkConf(), + executorBuilder, + kubernetesClient, + snapshotsStore, + removedExecutorsCache) + eventHandlerUnderTest.start(schedulerBackend) + } + + test("When an executor reaches error states immediately, remove from the scheduler backend.") { + val failedPod = failedExecutorWithoutDeletion(1) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + val msg = exitReasonMessage(1, failedPod) + val expectedLossReason = ExecutorExited(1, exitCausedByApp = true, msg) + verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason) + verify(namedExecutorPods(failedPod.getMetadata.getName)).delete() + } + + test("Don't remove executors twice from Spark but remove from K8s repeatedly.") { + val failedPod = failedExecutorWithoutDeletion(1) + snapshotsStore.updatePod(failedPod) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + val msg = exitReasonMessage(1, failedPod) + val expectedLossReason = ExecutorExited(1, exitCausedByApp = true, msg) + verify(schedulerBackend, times(1)).doRemoveExecutor("1", expectedLossReason) + verify(namedExecutorPods(failedPod.getMetadata.getName), times(2)).delete() + } + + test("When the scheduler backend lists executor ids that aren't present in the cluster," + + " remove those executors from Spark.") { + when(schedulerBackend.getExecutorIds()).thenReturn(Seq("1")) + val msg = s"The executor with ID 1 was not found in the cluster but we didn't" + + s" get a reason why. Marking the executor as failed. The executor may have been" + + s" deleted but the driver missed the deletion event." + val expectedLossReason = ExecutorExited(-1, exitCausedByApp = false, msg) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason) + } + + private def exitReasonMessage(failedExecutorId: Int, failedPod: Pod): String = { + s""" + |The executor with id $failedExecutorId exited with exit code 1. + |The API gave the following brief reason: ${failedPod.getStatus.getReason} + |The API gave the following message: ${failedPod.getStatus.getMessage} + |The API gave the following container statuses: + | + |${failedPod.getStatus.getContainerStatuses.asScala.map(_.toString).mkString("\n===\n")} + """.stripMargin + } + + private def namedPodsAnswer(): Answer[PodResource[Pod, DoneablePod]] = { + new Answer[PodResource[Pod, DoneablePod]] { + override def answer(invocation: InvocationOnMock): PodResource[Pod, DoneablePod] = { + val podName = invocation.getArgumentAt(0, classOf[String]) + namedExecutorPods.getOrElseUpdate( + podName, mock(classOf[PodResource[Pod, DoneablePod]])) + } + } + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala new file mode 100644 index 0000000000000..1b26d6af296a5 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.TimeUnit + +import io.fabric8.kubernetes.api.model.PodListBuilder +import io.fabric8.kubernetes.client.KubernetesClient +import org.jmock.lib.concurrent.DeterministicScheduler +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Mockito.{verify, when} +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsPollingSnapshotSourceSuite extends SparkFunSuite with BeforeAndAfter { + + private val sparkConf = new SparkConf + + private val pollingInterval = sparkConf.get(KUBERNETES_EXECUTOR_API_POLLING_INTERVAL) + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var appIdLabeledPods: LABELED_PODS = _ + + @Mock + private var executorRoleLabeledPods: LABELED_PODS = _ + + @Mock + private var eventQueue: ExecutorPodsSnapshotsStore = _ + + private var pollingExecutor: DeterministicScheduler = _ + private var pollingSourceUnderTest: ExecutorPodsPollingSnapshotSource = _ + + before { + MockitoAnnotations.initMocks(this) + pollingExecutor = new DeterministicScheduler() + pollingSourceUnderTest = new ExecutorPodsPollingSnapshotSource( + sparkConf, + kubernetesClient, + eventQueue, + pollingExecutor) + pollingSourceUnderTest.start(TEST_SPARK_APP_ID) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)) + .thenReturn(appIdLabeledPods) + when(appIdLabeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)) + .thenReturn(executorRoleLabeledPods) + } + + test("Items returned by the API should be pushed to the event queue") { + when(executorRoleLabeledPods.list()) + .thenReturn(new PodListBuilder() + .addToItems( + runningExecutor(1), + runningExecutor(2)) + .build()) + pollingExecutor.tick(pollingInterval, TimeUnit.MILLISECONDS) + verify(eventQueue).replaceSnapshot(Seq(runningExecutor(1), runningExecutor(2))) + + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala new file mode 100644 index 0000000000000..70e19c904eddb --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import org.apache.spark.SparkFunSuite +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsSnapshotSuite extends SparkFunSuite { + + test("States are interpreted correctly from pod metadata.") { + val pods = Seq( + pendingExecutor(0), + runningExecutor(1), + succeededExecutor(2), + failedExecutorWithoutDeletion(3), + deletedExecutor(4), + unknownExecutor(5)) + val snapshot = ExecutorPodsSnapshot(pods) + assert(snapshot.executorPods === + Map( + 0L -> PodPending(pods(0)), + 1L -> PodRunning(pods(1)), + 2L -> PodSucceeded(pods(2)), + 3L -> PodFailed(pods(3)), + 4L -> PodDeleted(pods(4)), + 5L -> PodUnknown(pods(5)))) + } + + test("Updates add new pods for non-matching ids and edit existing pods for matching ids") { + val originalPods = Seq( + pendingExecutor(0), + runningExecutor(1)) + val originalSnapshot = ExecutorPodsSnapshot(originalPods) + val snapshotWithUpdatedPod = originalSnapshot.withUpdate(succeededExecutor(1)) + assert(snapshotWithUpdatedPod.executorPods === + Map( + 0L -> PodPending(originalPods(0)), + 1L -> PodSucceeded(succeededExecutor(1)))) + val snapshotWithNewPod = snapshotWithUpdatedPod.withUpdate(pendingExecutor(2)) + assert(snapshotWithNewPod.executorPods === + Map( + 0L -> PodPending(originalPods(0)), + 1L -> PodSucceeded(succeededExecutor(1)), + 2L -> PodPending(pendingExecutor(2)))) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala new file mode 100644 index 0000000000000..cf54b3c4eb329 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference + +import io.fabric8.kubernetes.api.model.{Pod, PodBuilder} +import org.jmock.lib.concurrent.DeterministicScheduler +import org.scalatest.BeforeAndAfter +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.Constants._ + +class ExecutorPodsSnapshotsStoreSuite extends SparkFunSuite with BeforeAndAfter { + + private var eventBufferScheduler: DeterministicScheduler = _ + private var eventQueueUnderTest: ExecutorPodsSnapshotsStoreImpl = _ + + before { + eventBufferScheduler = new DeterministicScheduler() + eventQueueUnderTest = new ExecutorPodsSnapshotsStoreImpl(eventBufferScheduler) + } + + test("Subscribers get notified of events periodically.") { + val receivedSnapshots1 = mutable.Buffer.empty[ExecutorPodsSnapshot] + val receivedSnapshots2 = mutable.Buffer.empty[ExecutorPodsSnapshot] + eventQueueUnderTest.addSubscriber(1000) { + receivedSnapshots1 ++= _ + } + eventQueueUnderTest.addSubscriber(2000) { + receivedSnapshots2 ++= _ + } + + eventBufferScheduler.runUntilIdle() + assert(receivedSnapshots1 === Seq(ExecutorPodsSnapshot())) + assert(receivedSnapshots2 === Seq(ExecutorPodsSnapshot())) + + pushPodWithIndex(1) + // Force time to move forward so that the buffer is emitted, scheduling the + // processing task on the subscription executor... + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + // ... then actually execute the subscribers. + + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + assert(receivedSnapshots2 === Seq(ExecutorPodsSnapshot())) + + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + + // Don't repeat snapshots + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + assert(receivedSnapshots2 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + pushPodWithIndex(2) + pushPodWithIndex(3) + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2), podWithIndex(3))))) + assert(receivedSnapshots2 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2), podWithIndex(3))))) + assert(receivedSnapshots1 === receivedSnapshots2) + } + + test("Even without sending events, initially receive an empty buffer.") { + val receivedInitialSnapshot = new AtomicReference[Seq[ExecutorPodsSnapshot]](null) + eventQueueUnderTest.addSubscriber(1000) { + receivedInitialSnapshot.set + } + assert(receivedInitialSnapshot.get == null) + eventBufferScheduler.runUntilIdle() + assert(receivedInitialSnapshot.get === Seq(ExecutorPodsSnapshot())) + } + + test("Replacing the snapshot passes the new snapshot to subscribers.") { + val receivedSnapshots = mutable.Buffer.empty[ExecutorPodsSnapshot] + eventQueueUnderTest.addSubscriber(1000) { + receivedSnapshots ++= _ + } + eventQueueUnderTest.updatePod(podWithIndex(1)) + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + assert(receivedSnapshots === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + eventQueueUnderTest.replaceSnapshot(Seq(podWithIndex(2))) + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + assert(receivedSnapshots === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))), + ExecutorPodsSnapshot(Seq(podWithIndex(2))))) + } + + private def pushPodWithIndex(index: Int): Unit = + eventQueueUnderTest.updatePod(podWithIndex(index)) + + private def podWithIndex(index: Int): Pod = + new PodBuilder() + .editOrNewMetadata() + .withName(s"pod-$index") + .addToLabels(SPARK_EXECUTOR_ID_LABEL, index.toString) + .endMetadata() + .editOrNewStatus() + .withPhase("running") + .endStatus() + .build() +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala new file mode 100644 index 0000000000000..ac1968b4ff810 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action +import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} +import org.mockito.Mockito.{verify, when} +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsWatchSnapshotSourceSuite extends SparkFunSuite with BeforeAndAfter { + + @Mock + private var eventQueue: ExecutorPodsSnapshotsStore = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var appIdLabeledPods: LABELED_PODS = _ + + @Mock + private var executorRoleLabeledPods: LABELED_PODS = _ + + @Mock + private var watchConnection: Watch = _ + + private var watch: ArgumentCaptor[Watcher[Pod]] = _ + + private var watchSourceUnderTest: ExecutorPodsWatchSnapshotSource = _ + + before { + MockitoAnnotations.initMocks(this) + watch = ArgumentCaptor.forClass(classOf[Watcher[Pod]]) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)) + .thenReturn(appIdLabeledPods) + when(appIdLabeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)) + .thenReturn(executorRoleLabeledPods) + when(executorRoleLabeledPods.watch(watch.capture())).thenReturn(watchConnection) + watchSourceUnderTest = new ExecutorPodsWatchSnapshotSource( + eventQueue, kubernetesClient) + watchSourceUnderTest.start(TEST_SPARK_APP_ID) + } + + test("Watch events should be pushed to the snapshots store as snapshot updates.") { + watch.getValue.eventReceived(Action.ADDED, runningExecutor(1)) + watch.getValue.eventReceived(Action.MODIFIED, runningExecutor(2)) + verify(eventQueue).updatePod(runningExecutor(1)) + verify(eventQueue).updatePod(runningExecutor(2)) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index 96065e83f069c..52e7a12dbaf06 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -16,85 +16,36 @@ */ package org.apache.spark.scheduler.cluster.k8s -import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit} - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, DoneablePod, Pod, PodBuilder, PodList} -import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} -import io.fabric8.kubernetes.client.Watcher.Action -import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NonNamespaceOperation, PodResource} -import org.hamcrest.{BaseMatcher, Description, Matcher} -import org.mockito.{AdditionalAnswers, ArgumentCaptor, Matchers, Mock, MockitoAnnotations} -import org.mockito.Matchers.{any, eq => mockitoEq} -import org.mockito.Mockito.{doNothing, never, times, verify, when} +import io.fabric8.kubernetes.client.KubernetesClient +import org.jmock.lib.concurrent.DeterministicScheduler +import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} +import org.mockito.Matchers.{eq => mockitoEq} +import org.mockito.Mockito.{never, verify, when} import org.scalatest.BeforeAndAfter -import org.scalatest.mockito.MockitoSugar._ -import scala.collection.JavaConverters._ -import scala.concurrent.Future import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.rpc._ -import org.apache.spark.scheduler.{ExecutorExited, LiveListenerBus, SlaveLost, TaskSchedulerImpl} -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RemoveExecutor} +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.scheduler.{ExecutorKilled, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.ThreadUtils +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils.TEST_SPARK_APP_ID class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAndAfter { - private val APP_ID = "test-spark-app" - private val DRIVER_POD_NAME = "spark-driver-pod" - private val NAMESPACE = "test-namespace" - private val SPARK_DRIVER_HOST = "localhost" - private val SPARK_DRIVER_PORT = 7077 - private val POD_ALLOCATION_INTERVAL = "1m" - private val FIRST_EXECUTOR_POD = new PodBuilder() - .withNewMetadata() - .withName("pod1") - .endMetadata() - .withNewSpec() - .withNodeName("node1") - .endSpec() - .withNewStatus() - .withHostIP("192.168.99.100") - .endStatus() - .build() - private val SECOND_EXECUTOR_POD = new PodBuilder() - .withNewMetadata() - .withName("pod2") - .endMetadata() - .withNewSpec() - .withNodeName("node2") - .endSpec() - .withNewStatus() - .withHostIP("192.168.99.101") - .endStatus() - .build() - - private type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] - private type LABELED_PODS = FilterWatchListDeletable[ - Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]] - private type IN_NAMESPACE_PODS = NonNamespaceOperation[ - Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] - - @Mock - private var sparkContext: SparkContext = _ - - @Mock - private var listenerBus: LiveListenerBus = _ - - @Mock - private var taskSchedulerImpl: TaskSchedulerImpl = _ + private val requestExecutorsService = new DeterministicScheduler() + private val sparkConf = new SparkConf(false) + .set("spark.executor.instances", "3") @Mock - private var allocatorExecutor: ScheduledExecutorService = _ + private var sc: SparkContext = _ @Mock - private var requestExecutorsService: ExecutorService = _ + private var rpcEnv: RpcEnv = _ @Mock - private var executorBuilder: KubernetesExecutorBuilder = _ + private var driverEndpointRef: RpcEndpointRef = _ @Mock private var kubernetesClient: KubernetesClient = _ @@ -103,347 +54,97 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private var podOperations: PODS = _ @Mock - private var podsWithLabelOperations: LABELED_PODS = _ + private var labeledPods: LABELED_PODS = _ @Mock - private var podsInNamespace: IN_NAMESPACE_PODS = _ + private var taskScheduler: TaskSchedulerImpl = _ @Mock - private var podsWithDriverName: PodResource[Pod, DoneablePod] = _ + private var eventQueue: ExecutorPodsSnapshotsStore = _ @Mock - private var rpcEnv: RpcEnv = _ + private var podAllocator: ExecutorPodsAllocator = _ @Mock - private var driverEndpointRef: RpcEndpointRef = _ + private var lifecycleEventHandler: ExecutorPodsLifecycleManager = _ @Mock - private var executorPodsWatch: Watch = _ + private var watchEvents: ExecutorPodsWatchSnapshotSource = _ @Mock - private var successFuture: Future[Boolean] = _ + private var pollEvents: ExecutorPodsPollingSnapshotSource = _ - private var sparkConf: SparkConf = _ - private var executorPodsWatcherArgument: ArgumentCaptor[Watcher[Pod]] = _ - private var allocatorRunnable: ArgumentCaptor[Runnable] = _ - private var requestExecutorRunnable: ArgumentCaptor[Runnable] = _ private var driverEndpoint: ArgumentCaptor[RpcEndpoint] = _ - - private val driverPod = new PodBuilder() - .withNewMetadata() - .withName(DRIVER_POD_NAME) - .addToLabels(SPARK_APP_ID_LABEL, APP_ID) - .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE) - .endMetadata() - .build() + private var schedulerBackendUnderTest: KubernetesClusterSchedulerBackend = _ before { MockitoAnnotations.initMocks(this) - sparkConf = new SparkConf() - .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME) - .set(KUBERNETES_NAMESPACE, NAMESPACE) - .set("spark.driver.host", SPARK_DRIVER_HOST) - .set("spark.driver.port", SPARK_DRIVER_PORT.toString) - .set(KUBERNETES_ALLOCATION_BATCH_DELAY.key, POD_ALLOCATION_INTERVAL) - executorPodsWatcherArgument = ArgumentCaptor.forClass(classOf[Watcher[Pod]]) - allocatorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) - requestExecutorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) + when(taskScheduler.sc).thenReturn(sc) + when(sc.conf).thenReturn(sparkConf) driverEndpoint = ArgumentCaptor.forClass(classOf[RpcEndpoint]) - when(sparkContext.conf).thenReturn(sparkConf) - when(sparkContext.listenerBus).thenReturn(listenerBus) - when(taskSchedulerImpl.sc).thenReturn(sparkContext) - when(kubernetesClient.pods()).thenReturn(podOperations) - when(podOperations.withLabel(SPARK_APP_ID_LABEL, APP_ID)).thenReturn(podsWithLabelOperations) - when(podsWithLabelOperations.watch(executorPodsWatcherArgument.capture())) - .thenReturn(executorPodsWatch) - when(podOperations.inNamespace(NAMESPACE)).thenReturn(podsInNamespace) - when(podsInNamespace.withName(DRIVER_POD_NAME)).thenReturn(podsWithDriverName) - when(podsWithDriverName.get()).thenReturn(driverPod) - when(allocatorExecutor.scheduleWithFixedDelay( - allocatorRunnable.capture(), - mockitoEq(0L), - mockitoEq(TimeUnit.MINUTES.toMillis(1)), - mockitoEq(TimeUnit.MILLISECONDS))).thenReturn(null) - // Creating Futures in Scala backed by a Java executor service resolves to running - // ExecutorService#execute (as opposed to submit) - doNothing().when(requestExecutorsService).execute(requestExecutorRunnable.capture()) when(rpcEnv.setupEndpoint( mockitoEq(CoarseGrainedSchedulerBackend.ENDPOINT_NAME), driverEndpoint.capture())) .thenReturn(driverEndpointRef) - - // Used by the CoarseGrainedSchedulerBackend when making RPC calls. - when(driverEndpointRef.ask[Boolean] - (any(classOf[Any])) - (any())).thenReturn(successFuture) - when(successFuture.failed).thenReturn(Future[Throwable] { - // emulate behavior of the Future.failed method. - throw new NoSuchElementException() - }(ThreadUtils.sameThread)) - } - - test("Basic lifecycle expectations when starting and stopping the scheduler.") { - val scheduler = newSchedulerBackend() - scheduler.start() - assert(executorPodsWatcherArgument.getValue != null) - assert(allocatorRunnable.getValue != null) - scheduler.stop() - verify(executorPodsWatch).close() - } - - test("Static allocation should request executors upon first allocator run.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend() - scheduler.start() - requestExecutorRunnable.getValue.run() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - allocatorRunnable.getValue.run() - verify(podOperations).create(firstResolvedPod) - verify(podOperations).create(secondResolvedPod) - } - - test("Killing executors deletes the executor pods") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend() - scheduler.start() - requestExecutorRunnable.getValue.run() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))) - .thenAnswer(AdditionalAnswers.returnsFirstArg()) - allocatorRunnable.getValue.run() - scheduler.doKillExecutors(Seq("2")) - requestExecutorRunnable.getAllValues.asScala.last.run() - verify(podOperations).delete(secondResolvedPod) - verify(podOperations, never()).delete(firstResolvedPod) - } - - test("Executors should be requested in batches.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend() - scheduler.start() - requestExecutorRunnable.getValue.run() - when(podOperations.create(any(classOf[Pod]))) - .thenAnswer(AdditionalAnswers.returnsFirstArg()) - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).create(firstResolvedPod) - verify(podOperations, never()).create(secondResolvedPod) - val registerFirstExecutorMessage = RegisterExecutor( - "1", mock[RpcEndpointRef], "localhost", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - allocatorRunnable.getValue.run() - verify(podOperations).create(secondResolvedPod) - } - - test("Scaled down executors should be cleaned up") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend() - scheduler.start() - - // The scheduler backend spins up one executor pod. - requestExecutorRunnable.getValue.run() - when(podOperations.create(any(classOf[Pod]))) - .thenAnswer(AdditionalAnswers.returnsFirstArg()) - val resolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - allocatorRunnable.getValue.run() - val executorEndpointRef = mock[RpcEndpointRef] - when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) - val registerFirstExecutorMessage = RegisterExecutor( - "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - - // Request that there are 0 executors and trigger deletion from driver. - scheduler.doRequestTotalExecutors(0) - requestExecutorRunnable.getAllValues.asScala.last.run() - scheduler.doKillExecutors(Seq("1")) - requestExecutorRunnable.getAllValues.asScala.last.run() - verify(podOperations, times(1)).delete(resolvedPod) - driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) - - val exitedPod = exitPod(resolvedPod, 0) - executorPodsWatcherArgument.getValue.eventReceived(Action.DELETED, exitedPod) - allocatorRunnable.getValue.run() - - // No more deletion attempts of the executors. - // This is graceful termination and should not be detected as a failure. - verify(podOperations, times(1)).delete(resolvedPod) - verify(driverEndpointRef, times(1)).send( - RemoveExecutor("1", ExecutorExited( - 0, - exitCausedByApp = false, - s"Container in pod ${exitedPod.getMetadata.getName} exited from" + - s" explicit termination request."))) - } - - test("Executors that fail should not be deleted.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - val executorEndpointRef = mock[RpcEndpointRef] - when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) - val registerFirstExecutorMessage = RegisterExecutor( - "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) - executorPodsWatcherArgument.getValue.eventReceived( - Action.ERROR, exitPod(firstResolvedPod, 1)) - - // A replacement executor should be created but the error pod should persist. - val replacementPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - scheduler.doRequestTotalExecutors(1) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getAllValues.asScala.last.run() - verify(podOperations, never()).delete(firstResolvedPod) - verify(driverEndpointRef).send( - RemoveExecutor("1", ExecutorExited( - 1, - exitCausedByApp = true, - s"Pod ${FIRST_EXECUTOR_POD.getMetadata.getName}'s executor container exited with" + - " exit status code 1."))) - } - - test("Executors disconnected due to unknown reasons are deleted and replaced.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val executorLostReasonCheckMaxAttempts = sparkConf.get( - KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS) - - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - val executorEndpointRef = mock[RpcEndpointRef] - when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) - val registerFirstExecutorMessage = RegisterExecutor( - "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - - driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) - 1 to executorLostReasonCheckMaxAttempts foreach { _ => - allocatorRunnable.getValue.run() - verify(podOperations, never()).delete(FIRST_EXECUTOR_POD) + when(kubernetesClient.pods()).thenReturn(podOperations) + schedulerBackendUnderTest = new KubernetesClusterSchedulerBackend( + taskScheduler, + rpcEnv, + kubernetesClient, + requestExecutorsService, + eventQueue, + podAllocator, + lifecycleEventHandler, + watchEvents, + pollEvents) { + override def applicationId(): String = TEST_SPARK_APP_ID } - - val recreatedResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).delete(firstResolvedPod) - verify(driverEndpointRef).send( - RemoveExecutor("1", SlaveLost("Executor lost for unknown reasons."))) } - test("Executors that fail to start on the Kubernetes API call rebuild in the next batch.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(firstResolvedPod)) - .thenThrow(new RuntimeException("test")) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - verify(podOperations, times(1)).create(firstResolvedPod) - val recreatedResolvedPod = expectPodCreationWithId(2, FIRST_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).create(recreatedResolvedPod) + test("Start all components") { + schedulerBackendUnderTest.start() + verify(podAllocator).setTotalExpectedExecutors(3) + verify(podAllocator).start(TEST_SPARK_APP_ID) + verify(lifecycleEventHandler).start(schedulerBackendUnderTest) + verify(watchEvents).start(TEST_SPARK_APP_ID) + verify(pollEvents).start(TEST_SPARK_APP_ID) } - test("Executors that are initially created but the watch notices them fail are rebuilt" + - " in the next batch.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(FIRST_EXECUTOR_POD)).thenAnswer(AdditionalAnswers.returnsFirstArg()) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - verify(podOperations, times(1)).create(firstResolvedPod) - executorPodsWatcherArgument.getValue.eventReceived(Action.ERROR, firstResolvedPod) - val recreatedResolvedPod = expectPodCreationWithId(2, FIRST_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).create(recreatedResolvedPod) + test("Stop all components") { + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)).thenReturn(labeledPods) + when(labeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)).thenReturn(labeledPods) + schedulerBackendUnderTest.stop() + verify(eventQueue).stop() + verify(watchEvents).stop() + verify(pollEvents).stop() + verify(labeledPods).delete() + verify(kubernetesClient).close() } - private def newSchedulerBackend(): KubernetesClusterSchedulerBackend = { - new KubernetesClusterSchedulerBackend( - taskSchedulerImpl, - rpcEnv, - executorBuilder, - kubernetesClient, - allocatorExecutor, - requestExecutorsService) { - - override def applicationId(): String = APP_ID - } + test("Remove executor") { + schedulerBackendUnderTest.start() + schedulerBackendUnderTest.doRemoveExecutor( + "1", ExecutorKilled) + verify(driverEndpointRef).send(RemoveExecutor("1", ExecutorKilled)) } - private def exitPod(basePod: Pod, exitCode: Int): Pod = { - new PodBuilder(basePod) - .editStatus() - .addNewContainerStatus() - .withNewState() - .withNewTerminated() - .withExitCode(exitCode) - .endTerminated() - .endState() - .endContainerStatus() - .endStatus() - .build() + test("Kill executors") { + schedulerBackendUnderTest.start() + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)).thenReturn(labeledPods) + when(labeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)).thenReturn(labeledPods) + when(labeledPods.withLabelIn(SPARK_EXECUTOR_ID_LABEL, "1", "2")).thenReturn(labeledPods) + schedulerBackendUnderTest.doKillExecutors(Seq("1", "2")) + verify(labeledPods, never()).delete() + requestExecutorsService.runNextPendingCommand() + verify(labeledPods).delete() } - private def expectPodCreationWithId(executorId: Int, expectedPod: Pod): Pod = { - val resolvedPod = new PodBuilder(expectedPod) - .editMetadata() - .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) - .endMetadata() - .build() - val resolvedContainer = new ContainerBuilder().build() - when(executorBuilder.buildFromFeatures(Matchers.argThat( - new BaseMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { - override def matches(argument: scala.Any) - : Boolean = { - argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] && - argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] - .roleSpecificConf.executorId == executorId.toString - } - - override def describeTo(description: Description): Unit = {} - }))).thenReturn(SparkPod(resolvedPod, resolvedContainer)) - new PodBuilder(resolvedPod) - .editSpec() - .addToContainers(resolvedContainer) - .endSpec() - .build() + test("Request total executors") { + schedulerBackendUnderTest.start() + schedulerBackendUnderTest.doRequestTotalExecutors(5) + verify(podAllocator).setTotalExpectedExecutors(3) + verify(podAllocator, never()).setTotalExpectedExecutors(5) + requestExecutorsService.runNextPendingCommand() + verify(podAllocator).setTotalExpectedExecutors(5) } + } From 22daeba59b3ffaccafc9ff4b521abc265d0e58dd Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Thu, 14 Jun 2018 20:59:42 -0700 Subject: [PATCH 0967/2461] [SPARK-24478][SQL] Move projection and filter push down to physical conversion ## What changes were proposed in this pull request? This removes the v2 optimizer rule for push-down and instead pushes filters and required columns when converting to a physical plan, as suggested by marmbrus. This makes the v2 relation cleaner because the output and filters do not change in the logical plan. A side-effect of this change is that the stats from the logical (optimized) plan no longer reflect pushed filters and projection. This is a temporary state, until the planner gathers stats from the physical plan instead. An alternative to this approach is https://github.com/rdblue/spark/commit/9d3a11e68bca6c5a56a2be47fb09395350362ac5. The first commit was proposed in #21262. This PR replaces #21262. ## How was this patch tested? Existing tests. Author: Ryan Blue Closes #21503 from rdblue/SPARK-24478-move-push-down-to-physical-conversion. --- .../v2/reader/SupportsReportStatistics.java | 5 ++ .../spark/sql/execution/SparkOptimizer.scala | 4 +- .../datasources/v2/DataSourceV2Relation.scala | 85 +++++-------------- .../datasources/v2/DataSourceV2Strategy.scala | 47 +++++++++- .../v2/PushDownOperatorsToDataSource.scala | 66 -------------- .../sql/sources/v2/DataSourceV2Suite.scala | 11 +-- 6 files changed, 79 insertions(+), 139 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java index 11bb13fd3b211..a79080a249ec8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -22,6 +22,11 @@ /** * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report statistics to Spark. + * + * Statistics are reported to the optimizer before a projection or any filters are pushed to the + * DataSourceReader. Implementations that return more accurate statistics based on projection and + * filters will not improve query performance until the planner can push operators before getting + * stats. */ @InterfaceStability.Evolving public interface SupportsReportStatistics extends DataSourceReader { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 1c8e4050978dc..00ff4c8ac310b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -21,7 +21,6 @@ import org.apache.spark.sql.ExperimentalMethods import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions -import org.apache.spark.sql.execution.datasources.v2.PushDownOperatorsToDataSource import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate class SparkOptimizer( @@ -32,8 +31,7 @@ class SparkOptimizer( override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ - Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ - Batch("Push down operators to data source scan", Once, PushDownOperatorsToDataSource)) ++ + Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++ postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 90fb5a14c9fc9..e08af218513fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -22,7 +22,6 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} -import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources.{DataSourceRegister, Filter} @@ -32,69 +31,27 @@ import org.apache.spark.sql.types.StructType case class DataSourceV2Relation( source: DataSourceV2, + output: Seq[AttributeReference], options: Map[String, String], - projection: Seq[AttributeReference], - filters: Option[Seq[Expression]] = None, userSpecifiedSchema: Option[StructType] = None) extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { import DataSourceV2Relation._ - override def simpleString: String = "RelationV2 " + metadataString - - override lazy val schema: StructType = reader.readSchema() - - override lazy val output: Seq[AttributeReference] = { - // use the projection attributes to avoid assigning new ids. fields that are not projected - // will be assigned new ids, which is okay because they are not projected. - val attrMap = projection.map(a => a.name -> a).toMap - schema.map(f => attrMap.getOrElse(f.name, - AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())) - } - - private lazy val v2Options: DataSourceOptions = makeV2Options(options) + override def pushedFilters: Seq[Expression] = Seq.empty - // postScanFilters: filters that need to be evaluated after the scan. - // pushedFilters: filters that will be pushed down and evaluated in the underlying data sources. - // Note: postScanFilters and pushedFilters can overlap, e.g. the parquet row group filter. - lazy val ( - reader: DataSourceReader, - postScanFilters: Seq[Expression], - pushedFilters: Seq[Expression]) = { - val newReader = userSpecifiedSchema match { - case Some(s) => - source.asReadSupportWithSchema.createReader(s, v2Options) - case _ => - source.asReadSupport.createReader(v2Options) - } - - DataSourceV2Relation.pushRequiredColumns(newReader, projection.toStructType) - - val (postScanFilters, pushedFilters) = filters match { - case Some(filterSeq) => - DataSourceV2Relation.pushFilters(newReader, filterSeq) - case _ => - (Nil, Nil) - } - logInfo(s"Post-Scan Filters: ${postScanFilters.mkString(",")}") - logInfo(s"Pushed Filters: ${pushedFilters.mkString(", ")}") - - (newReader, postScanFilters, pushedFilters) - } - - override def doCanonicalize(): LogicalPlan = { - val c = super.doCanonicalize().asInstanceOf[DataSourceV2Relation] + override def simpleString: String = "RelationV2 " + metadataString - // override output with canonicalized output to avoid attempting to configure a reader - val canonicalOutput: Seq[AttributeReference] = this.output - .map(a => QueryPlan.normalizeExprId(a, projection)) + lazy val v2Options: DataSourceOptions = makeV2Options(options) - new DataSourceV2Relation(c.source, c.options, c.projection) { - override lazy val output: Seq[AttributeReference] = canonicalOutput - } + def newReader: DataSourceReader = userSpecifiedSchema match { + case Some(userSchema) => + source.asReadSupportWithSchema.createReader(userSchema, v2Options) + case None => + source.asReadSupport.createReader(v2Options) } - override def computeStats(): Statistics = reader match { + override def computeStats(): Statistics = newReader match { case r: SupportsReportStatistics => Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) case _ => @@ -102,9 +59,7 @@ case class DataSourceV2Relation( } override def newInstance(): DataSourceV2Relation = { - // projection is used to maintain id assignment. - // if projection is not set, use output so the copy is not equal to the original - copy(projection = projection.map(_.newInstance())) + copy(output = output.map(_.newInstance())) } } @@ -206,21 +161,27 @@ object DataSourceV2Relation { def create( source: DataSourceV2, options: Map[String, String], - filters: Option[Seq[Expression]] = None, userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { - val projection = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes - DataSourceV2Relation(source, options, projection, filters, userSpecifiedSchema) + val output = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes + DataSourceV2Relation(source, output, options, userSpecifiedSchema) } - private def pushRequiredColumns(reader: DataSourceReader, struct: StructType): Unit = { + def pushRequiredColumns( + relation: DataSourceV2Relation, + reader: DataSourceReader, + struct: StructType): Seq[AttributeReference] = { reader match { case projectionSupport: SupportsPushDownRequiredColumns => projectionSupport.pruneColumns(struct) + // return the output columns from the relation that were projected + val attrMap = relation.output.map(a => a.name -> a).toMap + projectionSupport.readSchema().map(f => attrMap(f.name)) case _ => + relation.output } } - private def pushFilters( + def pushFilters( reader: DataSourceReader, filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { reader match { @@ -248,7 +209,7 @@ object DataSourceV2Relation { // the data source cannot guarantee the rows returned can pass these filters. // As a result we must return it so Spark can plan an extra filter operator. val postScanFilters = - r.pushFilters(translatedFilterToExpr.keys.toArray).map(translatedFilterToExpr) + r.pushFilters(translatedFilterToExpr.keys.toArray).map(translatedFilterToExpr) // The filters which are marked as pushed to this data source val pushedFilters = r.pushedFilters().map(translatedFilterToExpr) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 1b7c639f10f98..8bf858c38d76c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -17,15 +17,56 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.Strategy +import org.apache.spark.sql.{execution, Strategy} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case r: DataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil + case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => + val projectSet = AttributeSet(project.flatMap(_.references)) + val filterSet = AttributeSet(filters.flatMap(_.references)) + + val projection = if (filterSet.subsetOf(projectSet) && + AttributeSet(relation.output) == projectSet) { + // When the required projection contains all of the filter columns and column pruning alone + // can produce the required projection, push the required projection. + // A final projection may still be needed if the data source produces a different column + // order or if it cannot prune all of the nested columns. + relation.output + } else { + // When there are filter columns not already in the required projection or when the required + // projection is more complicated than column pruning, base column pruning on the set of + // all columns needed by both. + (projectSet ++ filterSet).toSeq + } + + val reader = relation.newReader + + val output = DataSourceV2Relation.pushRequiredColumns(relation, reader, + projection.asInstanceOf[Seq[AttributeReference]].toStructType) + + val (postScanFilters, pushedFilters) = DataSourceV2Relation.pushFilters(reader, filters) + + logInfo(s"Post-Scan Filters: ${postScanFilters.mkString(",")}") + logInfo(s"Pushed Filters: ${pushedFilters.mkString(", ")}") + + val scan = DataSourceV2ScanExec( + output, relation.source, relation.options, pushedFilters, reader) + + val filter = postScanFilters.reduceLeftOption(And) + val withFilter = filter.map(execution.FilterExec(_, scan)).getOrElse(scan) + + val withProjection = if (withFilter.output != project) { + execution.ProjectExec(project, withFilter) + } else { + withFilter + } + + withProjection :: Nil case r: StreamingDataSourceV2Relation => DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala deleted file mode 100644 index e894f8afd6762..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet} -import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} -import org.apache.spark.sql.catalyst.rules.Rule - -object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan match { - // PhysicalOperation guarantees that filters are deterministic; no need to check - case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => - assert(relation.filters.isEmpty, "data source v2 should do push down only once.") - - val projectAttrs = project.map(_.toAttribute) - val projectSet = AttributeSet(project.flatMap(_.references)) - val filterSet = AttributeSet(filters.flatMap(_.references)) - - val projection = if (filterSet.subsetOf(projectSet) && - AttributeSet(projectAttrs) == projectSet) { - // When the required projection contains all of the filter columns and column pruning alone - // can produce the required projection, push the required projection. - // A final projection may still be needed if the data source produces a different column - // order or if it cannot prune all of the nested columns. - projectAttrs - } else { - // When there are filter columns not already in the required projection or when the required - // projection is more complicated than column pruning, base column pruning on the set of - // all columns needed by both. - (projectSet ++ filterSet).toSeq - } - - val newRelation = relation.copy( - projection = projection.asInstanceOf[Seq[AttributeReference]], - filters = Some(filters)) - - // Add a Filter for any filters that need to be evaluated after scan. - val postScanFilterCond = newRelation.postScanFilters.reduceLeftOption(And) - val filtered = postScanFilterCond.map(Filter(_, newRelation)).getOrElse(newRelation) - - // Add a Project to ensure the output matches the required projection - if (newRelation.output != projectAttrs) { - Project(project, filtered) - } else { - filtered - } - - case other => other.mapChildren(apply) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 505a3f3465c02..e96cd4500458d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -323,21 +323,22 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("SPARK-23315: get output from canonicalized data source v2 related plans") { - def checkCanonicalizedOutput(df: DataFrame, numOutput: Int): Unit = { + def checkCanonicalizedOutput( + df: DataFrame, logicalNumOutput: Int, physicalNumOutput: Int): Unit = { val logical = df.queryExecution.optimizedPlan.collect { case d: DataSourceV2Relation => d }.head - assert(logical.canonicalized.output.length == numOutput) + assert(logical.canonicalized.output.length == logicalNumOutput) val physical = df.queryExecution.executedPlan.collect { case d: DataSourceV2ScanExec => d }.head - assert(physical.canonicalized.output.length == numOutput) + assert(physical.canonicalized.output.length == physicalNumOutput) } val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() - checkCanonicalizedOutput(df, 2) - checkCanonicalizedOutput(df.select('i), 1) + checkCanonicalizedOutput(df, 2, 2) + checkCanonicalizedOutput(df.select('i), 2, 1) } } From 6567fc43aca75b41900cde976594e21c8b0ca98a Mon Sep 17 00:00:00 2001 From: Ruben Berenguel Montoro Date: Fri, 15 Jun 2018 16:59:00 +0800 Subject: [PATCH 0968/2461] [PYTHON] Fix typo in serializer exception ## What changes were proposed in this pull request? Fix typo in exception raised in Python serializer ## How was this patch tested? No code changes Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Ruben Berenguel Montoro Closes #21566 from rberenguel/fix_typo_pyspark_serializers. --- python/pyspark/serializers.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 15753f77bd903..4c16b5fc26f3d 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -33,8 +33,9 @@ [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] >>> sc.stop() -PySpark serialize objects in batches; By default, the batch size is chosen based -on the size of objects, also configurable by SparkContext's C{batchSize} parameter: +PySpark serializes objects in batches; by default, the batch size is chosen based +on the size of objects and is also configurable by SparkContext's C{batchSize} +parameter: >>> sc = SparkContext('local', 'test', batchSize=2) >>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) @@ -100,7 +101,7 @@ def load_stream(self, stream): def _load_stream_without_unbatching(self, stream): """ Return an iterator of deserialized batches (iterable) of objects from the input stream. - if the serializer does not operate on batches the default implementation returns an + If the serializer does not operate on batches the default implementation returns an iterator of single element lists. """ return map(lambda x: [x], self.load_stream(stream)) @@ -461,7 +462,7 @@ def dumps(self, obj): return obj -# Hook namedtuple, make it picklable +# Hack namedtuple, make it picklable __cls = {} @@ -525,15 +526,15 @@ def namedtuple(*args, **kwargs): cls = _old_namedtuple(*args, **kwargs) return _hack_namedtuple(cls) - # replace namedtuple with new one + # replace namedtuple with the new one collections.namedtuple.__globals__["_old_namedtuple_kwdefaults"] = _old_namedtuple_kwdefaults collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple collections.namedtuple.__code__ = namedtuple.__code__ collections.namedtuple.__hijack = 1 - # hack the cls already generated by namedtuple - # those created in other module can be pickled as normal, + # hack the cls already generated by namedtuple. + # Those created in other modules can be pickled as normal, # so only hack those in __main__ module for n, o in sys.modules["__main__"].__dict__.items(): if (type(o) is type and o.__base__ is tuple @@ -627,7 +628,7 @@ def loads(self, obj): elif _type == b'P': return pickle.loads(obj[1:]) else: - raise ValueError("invalid sevialization type: %s" % _type) + raise ValueError("invalid serialization type: %s" % _type) class CompressedSerializer(FramedSerializer): From 495d8cf09ae7134aa6d2feb058612980e02955fa Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Fri, 15 Jun 2018 09:59:02 -0700 Subject: [PATCH 0969/2461] [SPARK-24490][WEBUI] Use WebUI.addStaticHandler in web UIs `WebUI` defines `addStaticHandler` that web UIs don't use (and simply introduce duplication). Let's clean them up and remove duplications. Local build and waiting for Jenkins Author: Jacek Laskowski Closes #21510 from jaceklaskowski/SPARK-24490-Use-WebUI.addStaticHandler. --- .../spark/deploy/history/HistoryServer.scala | 2 +- .../spark/deploy/master/ui/MasterWebUI.scala | 2 +- .../spark/deploy/worker/ui/WorkerWebUI.scala | 2 +- .../scala/org/apache/spark/ui/SparkUI.scala | 2 +- .../scala/org/apache/spark/ui/WebUI.scala | 52 ++++++++++--------- .../deploy/mesos/ui/MesosClusterUI.scala | 2 +- .../spark/streaming/ui/StreamingTab.scala | 2 +- 7 files changed, 33 insertions(+), 31 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 066275e8f8425..56f3f59504a7d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -124,7 +124,7 @@ class HistoryServer( attachHandler(ApiRootResource.getServletHandler(this)) - attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(SparkUI.STATIC_RESOURCE_DIR) val contextHandler = new ServletContextHandler contextHandler.setContextPath(HistoryServer.UI_PATH_PREFIX) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 35b7ddd46e4db..e87b2240564bd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -43,7 +43,7 @@ class MasterWebUI( val masterPage = new MasterPage(this) attachPage(new ApplicationPage(this)) attachPage(masterPage) - attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR) attachHandler(createRedirectHandler( "/app/kill", "/", masterPage.handleAppKillRequest, httpMethods = Set("POST"))) attachHandler(createRedirectHandler( diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index db696b04384bd..ea67b7434a769 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -47,7 +47,7 @@ class WorkerWebUI( val logPage = new LogPage(this) attachPage(logPage) attachPage(new WorkerPage(this)) - attachHandler(createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static")) + addStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE) attachHandler(createServletHandler("/log", (request: HttpServletRequest) => logPage.renderLog(request), worker.securityMgr, diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index b44ac0ea1febc..d315ef66e0dc0 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -65,7 +65,7 @@ private[spark] class SparkUI private ( attachTab(new StorageTab(this, store)) attachTab(new EnvironmentTab(this, store)) attachTab(new ExecutorsTab(this)) - attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(SparkUI.STATIC_RESOURCE_DIR) attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath)) attachHandler(ApiRootResource.getServletHandler(this)) diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 8b75f5d8fe1a8..2e43f17e6a8e3 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -60,23 +60,25 @@ private[spark] abstract class WebUI( def getHandlers: Seq[ServletContextHandler] = handlers def getSecurityManager: SecurityManager = securityManager - /** Attach a tab to this UI, along with all of its attached pages. */ - def attachTab(tab: WebUITab) { + /** Attaches a tab to this UI, along with all of its attached pages. */ + def attachTab(tab: WebUITab): Unit = { tab.pages.foreach(attachPage) tabs += tab } - def detachTab(tab: WebUITab) { + /** Detaches a tab from this UI, along with all of its attached pages. */ + def detachTab(tab: WebUITab): Unit = { tab.pages.foreach(detachPage) tabs -= tab } - def detachPage(page: WebUIPage) { + /** Detaches a page from this UI, along with all of its attached handlers. */ + def detachPage(page: WebUIPage): Unit = { pageToHandlers.remove(page).foreach(_.foreach(detachHandler)) } - /** Attach a page to this UI. */ - def attachPage(page: WebUIPage) { + /** Attaches a page to this UI. */ + def attachPage(page: WebUIPage): Unit = { val pagePath = "/" + page.prefix val renderHandler = createServletHandler(pagePath, (request: HttpServletRequest) => page.render(request), securityManager, conf, basePath) @@ -88,41 +90,41 @@ private[spark] abstract class WebUI( handlers += renderHandler } - /** Attach a handler to this UI. */ - def attachHandler(handler: ServletContextHandler) { + /** Attaches a handler to this UI. */ + def attachHandler(handler: ServletContextHandler): Unit = { handlers += handler serverInfo.foreach(_.addHandler(handler)) } - /** Detach a handler from this UI. */ - def detachHandler(handler: ServletContextHandler) { + /** Detaches a handler from this UI. */ + def detachHandler(handler: ServletContextHandler): Unit = { handlers -= handler serverInfo.foreach(_.removeHandler(handler)) } /** - * Add a handler for static content. + * Detaches the content handler at `path` URI. * - * @param resourceBase Root of where to find resources to serve. - * @param path Path in UI where to mount the resources. + * @param path Path in UI to unmount. */ - def addStaticHandler(resourceBase: String, path: String): Unit = { - attachHandler(JettyUtils.createStaticHandler(resourceBase, path)) + def detachHandler(path: String): Unit = { + handlers.find(_.getContextPath() == path).foreach(detachHandler) } /** - * Remove a static content handler. + * Adds a handler for static content. * - * @param path Path in UI to unmount. + * @param resourceBase Root of where to find resources to serve. + * @param path Path in UI where to mount the resources. */ - def removeStaticHandler(path: String): Unit = { - handlers.find(_.getContextPath() == path).foreach(detachHandler) + def addStaticHandler(resourceBase: String, path: String = "/static"): Unit = { + attachHandler(JettyUtils.createStaticHandler(resourceBase, path)) } - /** Initialize all components of the server. */ + /** A hook to initialize components of the UI */ def initialize(): Unit - /** Bind to the HTTP server behind this web interface. */ + /** Binds to the HTTP server behind this web interface. */ def bind(): Unit = { assert(serverInfo.isEmpty, s"Attempted to bind $className more than once!") try { @@ -136,17 +138,17 @@ private[spark] abstract class WebUI( } } - /** Return the url of web interface. Only valid after bind(). */ + /** @return The url of web interface. Only valid after [[bind]]. */ def webUrl: String = s"http://$publicHostName:$boundPort" - /** Return the actual port to which this server is bound. Only valid after bind(). */ + /** @return The actual port to which this server is bound. Only valid after [[bind]]. */ def boundPort: Int = serverInfo.map(_.boundPort).getOrElse(-1) - /** Stop the server behind this web interface. Only valid after bind(). */ + /** Stops the server behind this web interface. Only valid after [[bind]]. */ def stop(): Unit = { assert(serverInfo.isDefined, s"Attempted to stop $className before binding to a server!") - serverInfo.get.stop() + serverInfo.foreach(_.stop()) } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala index 604978967d6db..15bbe60d6c8fb 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala @@ -40,7 +40,7 @@ private[spark] class MesosClusterUI( override def initialize() { attachPage(new MesosClusterPage(this)) attachPage(new DriverPage(this)) - attachHandler(createStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index 9d1b82a6341b1..25e71258b9369 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -49,7 +49,7 @@ private[spark] class StreamingTab(val ssc: StreamingContext) def detach() { getSparkUI(ssc).detachTab(this) - getSparkUI(ssc).removeStaticHandler("/static/streaming") + getSparkUI(ssc).detachHandler("/static/streaming") } } From b5ccf0d3957a444db93893c0ce4417bfbbb11822 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 15 Jun 2018 12:56:39 -0700 Subject: [PATCH 0970/2461] [SPARK-24396][SS][PYSPARK] Add Structured Streaming ForeachWriter for python ## What changes were proposed in this pull request? This PR adds `foreach` for streaming queries in Python. Users will be able to specify their processing logic in two different ways. - As a function that takes a row as input. - As an object that has methods `open`, `process`, and `close` methods. See the python docs in this PR for more details. ## How was this patch tested? Added java and python unit tests Author: Tathagata Das Closes #21477 from tdas/SPARK-24396. --- python/pyspark/sql/streaming.py | 162 +++++++++++ python/pyspark/sql/tests.py | 257 ++++++++++++++++++ python/pyspark/tests.py | 4 +- .../org/apache/spark/sql/ForeachWriter.scala | 62 ++++- .../python/PythonForeachWriter.scala | 161 +++++++++++ .../sources/ForeachWriterProvider.scala | 52 +++- .../sql/streaming/DataStreamWriter.scala | 48 +--- .../python/PythonForeachWriterSuite.scala | 137 ++++++++++ 8 files changed, 811 insertions(+), 72 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index fae50b3d5d532..4984593bab491 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -854,6 +854,168 @@ def trigger(self, processingTime=None, once=None, continuous=None): self._jwrite = self._jwrite.trigger(jTrigger) return self + @since(2.4) + def foreach(self, f): + """ + Sets the output of the streaming query to be processed using the provided writer ``f``. + This is often used to write the output of a streaming query to arbitrary storage systems. + The processing logic can be specified in two ways. + + #. A **function** that takes a row as input. + This is a simple way to express your processing logic. Note that this does + not allow you to deduplicate generated data when failures cause reprocessing of + some input data. That would require you to specify the processing logic in the next + way. + + #. An **object** with a ``process`` method and optional ``open`` and ``close`` methods. + The object can have the following methods. + + * ``open(partition_id, epoch_id)``: *Optional* method that initializes the processing + (for example, open a connection, start a transaction, etc). Additionally, you can + use the `partition_id` and `epoch_id` to deduplicate regenerated data + (discussed later). + + * ``process(row)``: *Non-optional* method that processes each :class:`Row`. + + * ``close(error)``: *Optional* method that finalizes and cleans up (for example, + close connection, commit transaction, etc.) after all rows have been processed. + + The object will be used by Spark in the following way. + + * A single copy of this object is responsible of all the data generated by a + single task in a query. In other words, one instance is responsible for + processing one partition of the data generated in a distributed manner. + + * This object must be serializable because each task will get a fresh + serialized-deserialized copy of the provided object. Hence, it is strongly + recommended that any initialization for writing data (e.g. opening a + connection or starting a transaction) is done after the `open(...)` + method has been called, which signifies that the task is ready to generate data. + + * The lifecycle of the methods are as follows. + + For each partition with ``partition_id``: + + ... For each batch/epoch of streaming data with ``epoch_id``: + + ....... Method ``open(partitionId, epochId)`` is called. + + ....... If ``open(...)`` returns true, for each row in the partition and + batch/epoch, method ``process(row)`` is called. + + ....... Method ``close(errorOrNull)`` is called with error (if any) seen while + processing rows. + + Important points to note: + + * The `partitionId` and `epochId` can be used to deduplicate generated data when + failures cause reprocessing of some input data. This depends on the execution + mode of the query. If the streaming query is being executed in the micro-batch + mode, then every partition represented by a unique tuple (partition_id, epoch_id) + is guaranteed to have the same data. Hence, (partition_id, epoch_id) can be used + to deduplicate and/or transactionally commit data and achieve exactly-once + guarantees. However, if the streaming query is being executed in the continuous + mode, then this guarantee does not hold and therefore should not be used for + deduplication. + + * The ``close()`` method (if exists) will be called if `open()` method exists and + returns successfully (irrespective of the return value), except if the Python + crashes in the middle. + + .. note:: Evolving. + + >>> # Print every row using a function + >>> def print_row(row): + ... print(row) + ... + >>> writer = sdf.writeStream.foreach(print_row) + >>> # Print every row using a object with process() method + >>> class RowPrinter: + ... def open(self, partition_id, epoch_id): + ... print("Opened %d, %d" % (partition_id, epoch_id)) + ... return True + ... def process(self, row): + ... print(row) + ... def close(self, error): + ... print("Closed with error: %s" % str(error)) + ... + >>> writer = sdf.writeStream.foreach(RowPrinter()) + """ + + from pyspark.rdd import _wrap_function + from pyspark.serializers import PickleSerializer, AutoBatchedSerializer + from pyspark.taskcontext import TaskContext + + if callable(f): + # The provided object is a callable function that is supposed to be called on each row. + # Construct a function that takes an iterator and calls the provided function on each + # row. + def func_without_process(_, iterator): + for x in iterator: + f(x) + return iter([]) + + func = func_without_process + + else: + # The provided object is not a callable function. Then it is expected to have a + # 'process(row)' method, and optional 'open(partition_id, epoch_id)' and + # 'close(error)' methods. + + if not hasattr(f, 'process'): + raise Exception("Provided object does not have a 'process' method") + + if not callable(getattr(f, 'process')): + raise Exception("Attribute 'process' in provided object is not callable") + + def doesMethodExist(method_name): + exists = hasattr(f, method_name) + if exists and not callable(getattr(f, method_name)): + raise Exception( + "Attribute '%s' in provided object is not callable" % method_name) + return exists + + open_exists = doesMethodExist('open') + close_exists = doesMethodExist('close') + + def func_with_open_process_close(partition_id, iterator): + epoch_id = TaskContext.get().getLocalProperty('streaming.sql.batchId') + if epoch_id: + epoch_id = int(epoch_id) + else: + raise Exception("Could not get batch id from TaskContext") + + # Check if the data should be processed + should_process = True + if open_exists: + should_process = f.open(partition_id, epoch_id) + + error = None + + try: + if should_process: + for x in iterator: + f.process(x) + except Exception as ex: + error = ex + finally: + if close_exists: + f.close(error) + if error: + raise error + + return iter([]) + + func = func_with_open_process_close + + serializer = AutoBatchedSerializer(PickleSerializer()) + wrapped_func = _wrap_function(self._spark._sc, func, serializer, serializer) + jForeachWriter = \ + self._spark._sc._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter( + wrapped_func, self._df._jdf.schema()) + self._jwrite.foreach(jForeachWriter) + return self + @ignore_unicode_prefix @since(2.0) def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None, diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2d7a4f62d4ee8..4e5fafa77e109 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1869,6 +1869,263 @@ def test_query_manager_await_termination(self): q.stop() shutil.rmtree(tmpPath) + class ForeachWriterTester: + + def __init__(self, spark): + self.spark = spark + + def write_open_event(self, partitionId, epochId): + self._write_event( + self.open_events_dir, + {'partition': partitionId, 'epoch': epochId}) + + def write_process_event(self, row): + self._write_event(self.process_events_dir, {'value': 'text'}) + + def write_close_event(self, error): + self._write_event(self.close_events_dir, {'error': str(error)}) + + def write_input_file(self): + self._write_event(self.input_dir, "text") + + def open_events(self): + return self._read_events(self.open_events_dir, 'partition INT, epoch INT') + + def process_events(self): + return self._read_events(self.process_events_dir, 'value STRING') + + def close_events(self): + return self._read_events(self.close_events_dir, 'error STRING') + + def run_streaming_query_on_writer(self, writer, num_files): + self._reset() + try: + sdf = self.spark.readStream.format('text').load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + for i in range(num_files): + self.write_input_file() + sq.processAllAvailable() + finally: + self.stop_all() + + def assert_invalid_writer(self, writer, msg=None): + self._reset() + try: + sdf = self.spark.readStream.format('text').load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + self.write_input_file() + sq.processAllAvailable() + self.fail("invalid writer %s did not fail the query" % str(writer)) # not expected + except Exception as e: + if msg: + assert(msg in str(e), "%s not in %s" % (msg, str(e))) + + finally: + self.stop_all() + + def stop_all(self): + for q in self.spark._wrapped.streams.active: + q.stop() + + def _reset(self): + self.input_dir = tempfile.mkdtemp() + self.open_events_dir = tempfile.mkdtemp() + self.process_events_dir = tempfile.mkdtemp() + self.close_events_dir = tempfile.mkdtemp() + + def _read_events(self, dir, json): + rows = self.spark.read.schema(json).json(dir).collect() + dicts = [row.asDict() for row in rows] + return dicts + + def _write_event(self, dir, event): + import uuid + with open(os.path.join(dir, str(uuid.uuid4())), 'w') as f: + f.write("%s\n" % str(event)) + + def __getstate__(self): + return (self.open_events_dir, self.process_events_dir, self.close_events_dir) + + def __setstate__(self, state): + self.open_events_dir, self.process_events_dir, self.close_events_dir = state + + def test_streaming_foreach_with_simple_function(self): + tester = self.ForeachWriterTester(self.spark) + + def foreach_func(row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(foreach_func, 2) + self.assertEqual(len(tester.process_events()), 2) + + def test_streaming_foreach_with_basic_open_process_close(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partitionId, epochId): + tester.write_open_event(partitionId, epochId) + return True + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + open_events = tester.open_events() + self.assertEqual(len(open_events), 2) + self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1}) + + self.assertEqual(len(tester.process_events()), 2) + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) + + def test_streaming_foreach_with_open_returning_false(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) + return False + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + self.assertEqual(len(tester.open_events()), 2) + + self.assertEqual(len(tester.process_events()), 0) # no row was processed + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) + + def test_streaming_foreach_without_open_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 2) + + def test_streaming_foreach_without_close_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) + return True + + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 2) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_without_open_and_close_methods(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_with_process_throwing_error(self): + from pyspark.sql.utils import StreamingQueryException + + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + raise Exception("test error") + + def close(self, error): + tester.write_close_event(error) + + try: + tester.run_streaming_query_on_writer(ForeachWriter(), 1) + self.fail("bad writer did not fail the query") # this is not expected + except StreamingQueryException as e: + # TODO: Verify whether original error message is inside the exception + pass + + self.assertEqual(len(tester.process_events()), 0) # no row was processed + close_events = tester.close_events() + self.assertEqual(len(close_events), 1) + # TODO: Verify whether original error message is inside the exception + + def test_streaming_foreach_with_invalid_writers(self): + + tester = self.ForeachWriterTester(self.spark) + + def func_with_iterator_input(iter): + for x in iter: + print(x) + + tester.assert_invalid_writer(func_with_iterator_input) + + class WriterWithoutProcess: + def open(self, partition): + pass + + tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 'process'") + + class WriterWithNonCallableProcess(): + process = True + + tester.assert_invalid_writer(WriterWithNonCallableProcess(), + "'process' in provided object is not callable") + + class WriterWithNoParamProcess(): + def process(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamProcess()) + + # Abstract class for tests below + class WithProcess(): + def process(self, row): + pass + + class WriterWithNonCallableOpen(WithProcess): + open = True + + tester.assert_invalid_writer(WriterWithNonCallableOpen(), + "'open' in provided object is not callable") + + class WriterWithNoParamOpen(WithProcess): + def open(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamOpen()) + + class WriterWithNonCallableClose(WithProcess): + close = True + + tester.assert_invalid_writer(WriterWithNonCallableClose(), + "'close' in provided object is not callable") + def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 18b2f251dc9fd..a4c5fb1db8b37 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -581,9 +581,9 @@ def test_get_local_property(self): self.sc.setLocalProperty(key, value) try: rdd = self.sc.parallelize(range(1), 1) - prop1 = rdd.map(lambda x: TaskContext.get().getLocalProperty(key)).collect()[0] + prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0] self.assertEqual(prop1, value) - prop2 = rdd.map(lambda x: TaskContext.get().getLocalProperty("otherkey")).collect()[0] + prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0] self.assertTrue(prop2 is None) finally: self.sc.setLocalProperty(key, None) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala index 86e02e98c01f3..b21c50af18433 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala @@ -20,10 +20,48 @@ package org.apache.spark.sql import org.apache.spark.annotation.InterfaceStability /** - * A class to consume data generated by a `StreamingQuery`. Typically this is used to send the - * generated data to external systems. Each partition will use a new deserialized instance, so you - * usually should do all the initialization (e.g. opening a connection or initiating a transaction) - * in the `open` method. + * The abstract class for writing custom logic to process data generated by a query. + * This is often used to write the output of a streaming query to arbitrary storage systems. + * Any implementation of this base class will be used by Spark in the following way. + * + *
      + *
    • A single instance of this class is responsible of all the data generated by a single task + * in a query. In other words, one instance is responsible for processing one partition of the + * data generated in a distributed manner. + * + *
    • Any implementation of this class must be serializable because each task will get a fresh + * serialized-deserialized copy of the provided object. Hence, it is strongly recommended that + * any initialization for writing data (e.g. opening a connection or starting a transaction) + * is done after the `open(...)` method has been called, which signifies that the task is + * ready to generate data. + * + *
    • The lifecycle of the methods are as follows. + * + *
      + *   For each partition with `partitionId`:
      + *       For each batch/epoch of streaming data (if its streaming query) with `epochId`:
      + *           Method `open(partitionId, epochId)` is called.
      + *           If `open` returns true:
      + *                For each row in the partition and batch/epoch, method `process(row)` is called.
      + *           Method `close(errorOrNull)` is called with error (if any) seen while processing rows.
      + *   
      + * + *
    + * + * Important points to note: + *
      + *
    • The `partitionId` and `epochId` can be used to deduplicate generated data when failures + * cause reprocessing of some input data. This depends on the execution mode of the query. If + * the streaming query is being executed in the micro-batch mode, then every partition + * represented by a unique tuple (partitionId, epochId) is guaranteed to have the same data. + * Hence, (partitionId, epochId) can be used to deduplicate and/or transactionally commit data + * and achieve exactly-once guarantees. However, if the streaming query is being executed in the + * continuous mode, then this guarantee does not hold and therefore should not be used for + * deduplication. + * + *
    • The `close()` method will be called if `open()` method returns successfully (irrespective + * of the return value), except if the JVM crashes in the middle. + *
    * * Scala example: * {{{ @@ -63,6 +101,7 @@ import org.apache.spark.annotation.InterfaceStability * } * }); * }}} + * * @since 2.0.0 */ @InterfaceStability.Evolving @@ -71,23 +110,18 @@ abstract class ForeachWriter[T] extends Serializable { // TODO: Move this to org.apache.spark.sql.util or consolidate this with batch API. /** - * Called when starting to process one partition of new data in the executor. The `version` is - * for data deduplication when there are failures. When recovering from a failure, some data may - * be generated multiple times but they will always have the same version. - * - * If this method finds using the `partitionId` and `version` that this partition has already been - * processed, it can return `false` to skip the further data processing. However, `close` still - * will be called for cleaning up resources. + * Called when starting to process one partition of new data in the executor. See the class + * docs for more information on how to use the `partitionId` and `epochId`. * * @param partitionId the partition id. - * @param version a unique id for data deduplication. + * @param epochId a unique id for data deduplication. * @return `true` if the corresponding partition and version id should be processed. `false` * indicates the partition should be skipped. */ - def open(partitionId: Long, version: Long): Boolean + def open(partitionId: Long, epochId: Long): Boolean /** - * Called to process the data in the executor side. This method will be called only when `open` + * Called to process the data in the executor side. This method will be called only if `open` * returns `true`. */ def process(value: T): Unit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala new file mode 100644 index 0000000000000..a58773122922f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.File +import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.ReentrantLock + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python._ +import org.apache.spark.internal.Logging +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{NextIterator, Utils} + +class PythonForeachWriter(func: PythonFunction, schema: StructType) + extends ForeachWriter[UnsafeRow] { + + private lazy val context = TaskContext.get() + private lazy val buffer = new PythonForeachWriter.UnsafeRowBuffer( + context.taskMemoryManager, new File(Utils.getLocalDir(SparkEnv.get.conf)), schema.fields.length) + private lazy val inputRowIterator = buffer.iterator + + private lazy val inputByteIterator = { + EvaluatePython.registerPicklers() + val objIterator = inputRowIterator.map { row => EvaluatePython.toJava(row, schema) } + new SerDeUtil.AutoBatchedPickler(objIterator) + } + + private lazy val pythonRunner = { + val conf = SparkEnv.get.conf + val bufferSize = conf.getInt("spark.buffer.size", 65536) + val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true) + PythonRunner(func, bufferSize, reuseWorker) + } + + private lazy val outputIterator = + pythonRunner.compute(inputByteIterator, context.partitionId(), context) + + override def open(partitionId: Long, version: Long): Boolean = { + outputIterator // initialize everything + TaskContext.get.addTaskCompletionListener { _ => buffer.close() } + true + } + + override def process(value: UnsafeRow): Unit = { + buffer.add(value) + } + + override def close(errorOrNull: Throwable): Unit = { + buffer.allRowsAdded() + if (outputIterator.hasNext) outputIterator.next() // to throw python exception if there was one + } +} + +object PythonForeachWriter { + + /** + * A buffer that is designed for the sole purpose of buffering UnsafeRows in PythonForeachWriter. + * It is designed to be used with only 1 writer thread (i.e. JVM task thread) and only 1 reader + * thread (i.e. PythonRunner writing thread that reads from the buffer and writes to the Python + * worker stdin). Adds to the buffer are non-blocking, and reads through the buffer's iterator + * are blocking, that is, it blocks until new data is available or all data has been added. + * + * Internally, it uses a [[HybridRowQueue]] to buffer the rows in a practically unlimited queue + * across memory and local disk. However, HybridRowQueue is designed to be used only with + * EvalPythonExec where the reader is always behind the the writer, that is, the reader does not + * try to read n+1 rows if the writer has only written n rows at any point of time. This + * assumption is not true for PythonForeachWriter where rows may be added at a different rate as + * they are consumed by the python worker. Hence, to maintain the invariant of the reader being + * behind the writer while using HybridRowQueue, the buffer does the following + * - Keeps a count of the rows in the HybridRowQueue + * - Blocks the buffer's consuming iterator when the count is 0 so that the reader does not + * try to read more rows than what has been written. + * + * The implementation of the blocking iterator (ReentrantLock, Condition, etc.) has been borrowed + * from that of ArrayBlockingQueue. + */ + class UnsafeRowBuffer(taskMemoryManager: TaskMemoryManager, tempDir: File, numFields: Int) + extends Logging { + private val queue = HybridRowQueue(taskMemoryManager, tempDir, numFields) + private val lock = new ReentrantLock() + private val unblockRemove = lock.newCondition() + + // All of these are guarded by `lock` + private var count = 0L + private var allAdded = false + private var exception: Throwable = null + + val iterator = new NextIterator[UnsafeRow] { + override protected def getNext(): UnsafeRow = { + val row = remove() + if (row == null) finished = true + row + } + override protected def close(): Unit = { } + } + + def add(row: UnsafeRow): Unit = withLock { + assert(queue.add(row), s"Failed to add row to HybridRowQueue while sending data to Python" + + s"[count = $count, allAdded = $allAdded, exception = $exception]") + count += 1 + unblockRemove.signal() + logTrace(s"Added $row, $count left") + } + + private def remove(): UnsafeRow = withLock { + while (count == 0 && !allAdded && exception == null) { + unblockRemove.await(100, TimeUnit.MILLISECONDS) + } + + // If there was any error in the adding thread, then rethrow it in the removing thread + if (exception != null) throw exception + + if (count > 0) { + val row = queue.remove() + assert(row != null, "HybridRowQueue.remove() returned null " + + s"[count = $count, allAdded = $allAdded, exception = $exception]") + count -= 1 + logTrace(s"Removed $row, $count left") + row + } else { + null + } + } + + def allRowsAdded(): Unit = withLock { + allAdded = true + unblockRemove.signal() + } + + def close(): Unit = { queue.close() } + + private def withLock[T](f: => T): T = { + lock.lockInterruptibly() + try { f } catch { + case e: Throwable => + if (exception == null) exception = e + throw e + } finally { lock.unlock() } + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala index df5d69d57e36f..f677f25f116a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.python.PythonForeachWriter import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter @@ -31,9 +33,14 @@ import org.apache.spark.sql.types.StructType * [[ForeachWriter]]. * * @param writer The [[ForeachWriter]] to process all data. + * @param converter An object to convert internal rows to target type T. Either it can be + * a [[ExpressionEncoder]] or a direct converter function. * @tparam T The expected type of the sink. */ -case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport { +case class ForeachWriterProvider[T]( + writer: ForeachWriter[T], + converter: Either[ExpressionEncoder[T], InternalRow => T]) extends StreamWriteSupport { + override def createStreamWriter( queryId: String, schema: StructType, @@ -44,10 +51,16 @@ case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends S override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { - val encoder = encoderFor[T].resolveAndBind( - schema.toAttributes, - SparkSession.getActiveSession.get.sessionState.analyzer) - ForeachWriterFactory(writer, encoder) + val rowConverter: InternalRow => T = converter match { + case Left(enc) => + val boundEnc = enc.resolveAndBind( + schema.toAttributes, + SparkSession.getActiveSession.get.sessionState.analyzer) + boundEnc.fromRow + case Right(func) => + func + } + ForeachWriterFactory(writer, rowConverter) } override def toString: String = "ForeachSink" @@ -55,29 +68,44 @@ case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends S } } -case class ForeachWriterFactory[T: Encoder]( +object ForeachWriterProvider { + def apply[T]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T]): ForeachWriterProvider[_] = { + writer match { + case pythonWriter: PythonForeachWriter => + new ForeachWriterProvider[UnsafeRow]( + pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow])) + case _ => + new ForeachWriterProvider[T](writer, Left(encoder)) + } + } +} + +case class ForeachWriterFactory[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T]) + rowConverter: InternalRow => T) extends DataWriterFactory[InternalRow] { override def createDataWriter( partitionId: Int, attemptNumber: Int, epochId: Long): ForeachDataWriter[T] = { - new ForeachDataWriter(writer, encoder, partitionId, epochId) + new ForeachDataWriter(writer, rowConverter, partitionId, epochId) } } /** * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]]. + * * @param writer The [[ForeachWriter]] to process all data. - * @param encoder An encoder which can convert [[InternalRow]] to the required type [[T]] + * @param rowConverter A function which can convert [[InternalRow]] to the required type [[T]] * @param partitionId * @param epochId * @tparam T The type expected by the writer. */ -class ForeachDataWriter[T : Encoder]( +class ForeachDataWriter[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T], + rowConverter: InternalRow => T, partitionId: Int, epochId: Long) extends DataWriter[InternalRow] { @@ -89,7 +117,7 @@ class ForeachDataWriter[T : Encoder]( if (!opened) return try { - writer.process(encoder.fromRow(record)) + writer.process(rowConverter(record)) } catch { case t: Throwable => writer.close(t) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index effc1471e8e12..e035c9cdc379e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -269,7 +269,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = new ForeachWriterProvider[T](foreachWriter)(ds.exprEnc) + val sink = ForeachWriterProvider[T](foreachWriter, ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), @@ -307,49 +307,9 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } /** - * Starts the execution of the streaming query, which will continually send results to the given - * `ForeachWriter` as new data arrives. The `ForeachWriter` can be used to send the data - * generated by the `DataFrame`/`Dataset` to an external system. - * - * Scala example: - * {{{ - * datasetOfString.writeStream.foreach(new ForeachWriter[String] { - * - * def open(partitionId: Long, version: Long): Boolean = { - * // open connection - * } - * - * def process(record: String) = { - * // write string to connection - * } - * - * def close(errorOrNull: Throwable): Unit = { - * // close the connection - * } - * }).start() - * }}} - * - * Java example: - * {{{ - * datasetOfString.writeStream().foreach(new ForeachWriter() { - * - * @Override - * public boolean open(long partitionId, long version) { - * // open connection - * } - * - * @Override - * public void process(String value) { - * // write string to connection - * } - * - * @Override - * public void close(Throwable errorOrNull) { - * // close the connection - * } - * }).start(); - * }}} - * + * Sets the output of the streaming query to be processed using the provided writer object. + * object. See [[org.apache.spark.sql.ForeachWriter]] for more details on the lifecycle and + * semantics. * @since 2.0.0 */ def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala new file mode 100644 index 0000000000000..07e6034770127 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar._ + +import org.apache.spark._ +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.execution.python.PythonForeachWriter.UnsafeRowBuffer +import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.util.Utils + +class PythonForeachWriterSuite extends SparkFunSuite with Eventually { + + testWithBuffer("UnsafeRowBuffer: iterator blocks when no data is available") { b => + b.assertIteratorBlocked() + + b.add(Seq(1)) + b.assertOutput(Seq(1)) + b.assertIteratorBlocked() + + b.add(2 to 100) + b.assertOutput(1 to 100) + b.assertIteratorBlocked() + } + + testWithBuffer("UnsafeRowBuffer: iterator unblocks when all data added") { b => + b.assertIteratorBlocked() + b.add(Seq(1)) + b.assertIteratorBlocked() + + b.allAdded() + b.assertThreadTerminated() + b.assertOutput(Seq(1)) + } + + testWithBuffer( + "UnsafeRowBuffer: handles more data than memory", + memBytes = 5, + sleepPerRowReadMs = 1) { b => + + b.assertIteratorBlocked() + b.add(1 to 2000) + b.assertOutput(1 to 2000) + } + + def testWithBuffer( + name: String, + memBytes: Long = 4 << 10, + sleepPerRowReadMs: Int = 0 + )(f: BufferTester => Unit): Unit = { + + test(name) { + var tester: BufferTester = null + try { + tester = new BufferTester(memBytes, sleepPerRowReadMs) + f(tester) + } finally { + if (tester == null) tester.close() + } + } + } + + + class BufferTester(memBytes: Long, sleepPerRowReadMs: Int) { + private val buffer = { + val mem = new TestMemoryManager(new SparkConf()) + mem.limit(memBytes) + val taskM = new TaskMemoryManager(mem, 0) + new UnsafeRowBuffer(taskM, Utils.createTempDir(), 1) + } + private val iterator = buffer.iterator + private val outputBuffer = new ArrayBuffer[Int] + private val testTimeout = timeout(20.seconds) + private val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) + private val thread = new Thread() { + override def run(): Unit = { + while (iterator.hasNext) { + outputBuffer.synchronized { + outputBuffer += iterator.next().getInt(0) + } + Thread.sleep(sleepPerRowReadMs) + } + } + } + thread.start() + + def add(ints: Seq[Int]): Unit = { + ints.foreach { i => buffer.add(intProj.apply(new GenericInternalRow(Array[Any](i)))) } + } + + def allAdded(): Unit = { buffer.allRowsAdded() } + + def assertOutput(expectedOutput: Seq[Int]): Unit = { + eventually(testTimeout) { + val output = outputBuffer.synchronized { outputBuffer.toArray }.toSeq + assert(output == expectedOutput) + } + } + + def assertIteratorBlocked(): Unit = { + import Thread.State._ + eventually(testTimeout) { + assert(thread.isAlive) + assert(thread.getState == TIMED_WAITING || thread.getState == WAITING) + } + } + + def assertThreadTerminated(): Unit = { + eventually(testTimeout) { assert(!thread.isAlive) } + } + + def close(): Unit = { + thread.interrupt() + thread.join() + } + } +} From 90da7dc241f8eec2348c0434312c97c116330bc4 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 15 Jun 2018 13:47:48 -0700 Subject: [PATCH 0971/2461] [SPARK-24452][SQL][CORE] Avoid possible overflow in int add or multiple ## What changes were proposed in this pull request? This PR fixes possible overflow in int add or multiply. In particular, their overflows in multiply are detected by [Spotbugs](https://spotbugs.github.io/) The following assignments may cause overflow in right hand side. As a result, the result may be negative. ``` long = int * int long = int + int ``` To avoid this problem, this PR performs cast from int to long in right hand side. ## How was this patch tested? Existing UTs. Author: Kazuaki Ishizaki Closes #21481 from kiszk/SPARK-24452. --- .../spark/unsafe/map/BytesToBytesMap.java | 2 +- .../spark/deploy/worker/DriverRunner.scala | 2 +- .../apache/spark/rdd/AsyncRDDActions.scala | 2 +- .../apache/spark/storage/BlockManager.scala | 2 +- .../catalyst/expressions/UnsafeArrayData.java | 14 +-- .../VariableLengthRowBasedKeyValueBatch.java | 2 +- .../vectorized/OffHeapColumnVector.java | 106 +++++++++--------- .../vectorized/OnHeapColumnVector.java | 10 +- .../sources/RateStreamMicroBatchReader.scala | 2 +- .../spark/sql/hive/client/HiveShim.scala | 2 +- .../util/FileBasedWriteAheadLog.scala | 2 +- 11 files changed, 73 insertions(+), 73 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 5f0045507aaab..9a767dd739b91 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -703,7 +703,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff // must be stored in the same memory page. // (8 byte key length) (key) (value) (8 byte pointer to next value) int uaoSize = UnsafeAlignedOffset.getUaoSize(); - final long recordLength = (2 * uaoSize) + klen + vlen + 8; + final long recordLength = (2L * uaoSize) + klen + vlen + 8; if (currentPage == null || currentPage.size() - pageCursor < recordLength) { if (!acquireNewPage(recordLength + uaoSize)) { return false; diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 58a181128eb4d..a6d13d12fc28d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -225,7 +225,7 @@ private[deploy] class DriverRunner( // check if attempting another run keepTrying = supervise && exitCode != 0 && !killed if (keepTrying) { - if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000) { + if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000L) { waitSeconds = 1 } logInfo(s"Command exited with status $exitCode, re-launching after $waitSeconds s.") diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 13db4985b0b80..ba9dae4ad48ec 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -95,7 +95,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi // the left side of max is >=1 whenever partsScanned >= 2 numPartsToTry = Math.max(1, (1.5 * num * partsScanned / results.size).toInt - partsScanned) - numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4L) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e0276a4dc4224..df1a4bef616b2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -291,7 +291,7 @@ private[spark] class BlockManager( case e: Exception if i < MAX_ATTEMPTS => logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}" + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) - Thread.sleep(SLEEP_TIME_SECS * 1000) + Thread.sleep(SLEEP_TIME_SECS * 1000L) case NonFatal(e) => throw new SparkException("Unable to register with external shuffle server due to : " + e.getMessage, e) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index d5d934bc91cab..4dd2b7365652a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -83,7 +83,7 @@ public static long calculateSizeOfUnderlyingByteArray(long numFields, int elemen private long elementOffset; private long getElementOffset(int ordinal, int elementSize) { - return elementOffset + ordinal * elementSize; + return elementOffset + ordinal * (long)elementSize; } public Object getBaseObject() { return baseObject; } @@ -414,7 +414,7 @@ public byte[] toByteArray() { public short[] toShortArray() { short[] values = new short[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2); + baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2L); return values; } @@ -422,7 +422,7 @@ public short[] toShortArray() { public int[] toIntArray() { int[] values = new int[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4); + baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4L); return values; } @@ -430,7 +430,7 @@ public int[] toIntArray() { public long[] toLongArray() { long[] values = new long[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8); + baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8L); return values; } @@ -438,7 +438,7 @@ public long[] toLongArray() { public float[] toFloatArray() { float[] values = new float[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4); + baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4L); return values; } @@ -446,14 +446,14 @@ public float[] toFloatArray() { public double[] toDoubleArray() { double[] values = new double[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8); + baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8L); return values; } private static UnsafeArrayData fromPrimitiveArray( Object arr, int offset, int length, int elementSize) { final long headerInBytes = calculateHeaderPortionInBytes(length); - final long valueRegionInBytes = elementSize * length; + final long valueRegionInBytes = (long)elementSize * length; final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8; if (totalSizeInLongs > Integer.MAX_VALUE / 8) { throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " + diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java index 905e6820ce6e2..c823de4810f2b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java @@ -41,7 +41,7 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueB @Override public UnsafeRow appendRow(Object kbase, long koff, int klen, Object vbase, long voff, int vlen) { - final long recordLength = 8 + klen + vlen + 8; + final long recordLength = 8L + klen + vlen + 8; // if run out of max supported rows or page size, return null if (numRows >= capacity || page == null || page.size() - pageCursor < recordLength) { return null; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 4733f36174f42..6fdadde628551 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -216,12 +216,12 @@ protected UTF8String getBytesAsUTF8String(int rowId, int count) { @Override public void putShort(int rowId, short value) { - Platform.putShort(null, data + 2 * rowId, value); + Platform.putShort(null, data + 2L * rowId, value); } @Override public void putShorts(int rowId, int count, short value) { - long offset = data + 2 * rowId; + long offset = data + 2L * rowId; for (int i = 0; i < count; ++i, offset += 2) { Platform.putShort(null, offset, value); } @@ -229,20 +229,20 @@ public void putShorts(int rowId, int count, short value) { @Override public void putShorts(int rowId, int count, short[] src, int srcIndex) { - Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2, - null, data + 2 * rowId, count * 2); + Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2L, + null, data + 2L * rowId, count * 2L); } @Override public void putShorts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 2, count * 2); + null, data + rowId * 2L, count * 2L); } @Override public short getShort(int rowId) { if (dictionary == null) { - return Platform.getShort(null, data + 2 * rowId); + return Platform.getShort(null, data + 2L * rowId); } else { return (short) dictionary.decodeToInt(dictionaryIds.getDictId(rowId)); } @@ -252,7 +252,7 @@ public short getShort(int rowId) { public short[] getShorts(int rowId, int count) { assert(dictionary == null); short[] array = new short[count]; - Platform.copyMemory(null, data + rowId * 2, array, Platform.SHORT_ARRAY_OFFSET, count * 2); + Platform.copyMemory(null, data + rowId * 2L, array, Platform.SHORT_ARRAY_OFFSET, count * 2L); return array; } @@ -262,12 +262,12 @@ public short[] getShorts(int rowId, int count) { @Override public void putInt(int rowId, int value) { - Platform.putInt(null, data + 4 * rowId, value); + Platform.putInt(null, data + 4L * rowId, value); } @Override public void putInts(int rowId, int count, int value) { - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putInt(null, offset, value); } @@ -275,24 +275,24 @@ public void putInts(int rowId, int count, int value) { @Override public void putInts(int rowId, int count, int[] src, int srcIndex) { - Platform.copyMemory(src, Platform.INT_ARRAY_OFFSET + srcIndex * 4, - null, data + 4 * rowId, count * 4); + Platform.copyMemory(src, Platform.INT_ARRAY_OFFSET + srcIndex * 4L, + null, data + 4L * rowId, count * 4L); } @Override public void putInts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 4, count * 4); + null, data + rowId * 4L, count * 4L); } @Override public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, - null, data + 4 * rowId, count * 4); + null, data + 4L * rowId, count * 4L); } else { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4, srcOffset += 4) { Platform.putInt(null, offset, java.lang.Integer.reverseBytes(Platform.getInt(src, srcOffset))); @@ -303,7 +303,7 @@ public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) @Override public int getInt(int rowId) { if (dictionary == null) { - return Platform.getInt(null, data + 4 * rowId); + return Platform.getInt(null, data + 4L * rowId); } else { return dictionary.decodeToInt(dictionaryIds.getDictId(rowId)); } @@ -313,7 +313,7 @@ public int getInt(int rowId) { public int[] getInts(int rowId, int count) { assert(dictionary == null); int[] array = new int[count]; - Platform.copyMemory(null, data + rowId * 4, array, Platform.INT_ARRAY_OFFSET, count * 4); + Platform.copyMemory(null, data + rowId * 4L, array, Platform.INT_ARRAY_OFFSET, count * 4L); return array; } @@ -325,7 +325,7 @@ public int[] getInts(int rowId, int count) { public int getDictId(int rowId) { assert(dictionary == null) : "A ColumnVector dictionary should not have a dictionary for itself."; - return Platform.getInt(null, data + 4 * rowId); + return Platform.getInt(null, data + 4L * rowId); } // @@ -334,12 +334,12 @@ public int getDictId(int rowId) { @Override public void putLong(int rowId, long value) { - Platform.putLong(null, data + 8 * rowId, value); + Platform.putLong(null, data + 8L * rowId, value); } @Override public void putLongs(int rowId, int count, long value) { - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putLong(null, offset, value); } @@ -347,24 +347,24 @@ public void putLongs(int rowId, int count, long value) { @Override public void putLongs(int rowId, int count, long[] src, int srcIndex) { - Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8, - null, data + 8 * rowId, count * 8); + Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8L, + null, data + 8L * rowId, count * 8L); } @Override public void putLongs(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 8, count * 8); + null, data + rowId * 8L, count * 8L); } @Override public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, - null, data + 8 * rowId, count * 8); + null, data + 8L * rowId, count * 8L); } else { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8, srcOffset += 8) { Platform.putLong(null, offset, java.lang.Long.reverseBytes(Platform.getLong(src, srcOffset))); @@ -375,7 +375,7 @@ public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) @Override public long getLong(int rowId) { if (dictionary == null) { - return Platform.getLong(null, data + 8 * rowId); + return Platform.getLong(null, data + 8L * rowId); } else { return dictionary.decodeToLong(dictionaryIds.getDictId(rowId)); } @@ -385,7 +385,7 @@ public long getLong(int rowId) { public long[] getLongs(int rowId, int count) { assert(dictionary == null); long[] array = new long[count]; - Platform.copyMemory(null, data + rowId * 8, array, Platform.LONG_ARRAY_OFFSET, count * 8); + Platform.copyMemory(null, data + rowId * 8L, array, Platform.LONG_ARRAY_OFFSET, count * 8L); return array; } @@ -395,12 +395,12 @@ public long[] getLongs(int rowId, int count) { @Override public void putFloat(int rowId, float value) { - Platform.putFloat(null, data + rowId * 4, value); + Platform.putFloat(null, data + rowId * 4L, value); } @Override public void putFloats(int rowId, int count, float value) { - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putFloat(null, offset, value); } @@ -408,18 +408,18 @@ public void putFloats(int rowId, int count, float value) { @Override public void putFloats(int rowId, int count, float[] src, int srcIndex) { - Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4, - null, data + 4 * rowId, count * 4); + Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4L, + null, data + 4L * rowId, count * 4L); } @Override public void putFloats(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 4, count * 4); + null, data + rowId * 4L, count * 4L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putFloat(null, offset, bb.getFloat(srcIndex + (4 * i))); } @@ -429,7 +429,7 @@ public void putFloats(int rowId, int count, byte[] src, int srcIndex) { @Override public float getFloat(int rowId) { if (dictionary == null) { - return Platform.getFloat(null, data + rowId * 4); + return Platform.getFloat(null, data + rowId * 4L); } else { return dictionary.decodeToFloat(dictionaryIds.getDictId(rowId)); } @@ -439,7 +439,7 @@ public float getFloat(int rowId) { public float[] getFloats(int rowId, int count) { assert(dictionary == null); float[] array = new float[count]; - Platform.copyMemory(null, data + rowId * 4, array, Platform.FLOAT_ARRAY_OFFSET, count * 4); + Platform.copyMemory(null, data + rowId * 4L, array, Platform.FLOAT_ARRAY_OFFSET, count * 4L); return array; } @@ -450,12 +450,12 @@ public float[] getFloats(int rowId, int count) { @Override public void putDouble(int rowId, double value) { - Platform.putDouble(null, data + rowId * 8, value); + Platform.putDouble(null, data + rowId * 8L, value); } @Override public void putDoubles(int rowId, int count, double value) { - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putDouble(null, offset, value); } @@ -463,18 +463,18 @@ public void putDoubles(int rowId, int count, double value) { @Override public void putDoubles(int rowId, int count, double[] src, int srcIndex) { - Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex * 8, - null, data + 8 * rowId, count * 8); + Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex * 8L, + null, data + 8L * rowId, count * 8L); } @Override public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 8, count * 8); + null, data + rowId * 8L, count * 8L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putDouble(null, offset, bb.getDouble(srcIndex + (8 * i))); } @@ -484,7 +484,7 @@ public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { @Override public double getDouble(int rowId) { if (dictionary == null) { - return Platform.getDouble(null, data + rowId * 8); + return Platform.getDouble(null, data + rowId * 8L); } else { return dictionary.decodeToDouble(dictionaryIds.getDictId(rowId)); } @@ -494,7 +494,7 @@ public double getDouble(int rowId) { public double[] getDoubles(int rowId, int count) { assert(dictionary == null); double[] array = new double[count]; - Platform.copyMemory(null, data + rowId * 8, array, Platform.DOUBLE_ARRAY_OFFSET, count * 8); + Platform.copyMemory(null, data + rowId * 8L, array, Platform.DOUBLE_ARRAY_OFFSET, count * 8L); return array; } @@ -504,26 +504,26 @@ public double[] getDoubles(int rowId, int count) { @Override public void putArray(int rowId, int offset, int length) { assert(offset >= 0 && offset + length <= childColumns[0].capacity); - Platform.putInt(null, lengthData + 4 * rowId, length); - Platform.putInt(null, offsetData + 4 * rowId, offset); + Platform.putInt(null, lengthData + 4L * rowId, length); + Platform.putInt(null, offsetData + 4L * rowId, offset); } @Override public int getArrayLength(int rowId) { - return Platform.getInt(null, lengthData + 4 * rowId); + return Platform.getInt(null, lengthData + 4L * rowId); } @Override public int getArrayOffset(int rowId) { - return Platform.getInt(null, offsetData + 4 * rowId); + return Platform.getInt(null, offsetData + 4L * rowId); } // APIs dealing with ByteArrays @Override public int putByteArray(int rowId, byte[] value, int offset, int length) { int result = arrayData().appendBytes(length, value, offset); - Platform.putInt(null, lengthData + 4 * rowId, length); - Platform.putInt(null, offsetData + 4 * rowId, result); + Platform.putInt(null, lengthData + 4L * rowId, length); + Platform.putInt(null, offsetData + 4L * rowId, result); return result; } @@ -533,19 +533,19 @@ protected void reserveInternal(int newCapacity) { int oldCapacity = (nulls == 0L) ? 0 : capacity; if (isArray() || type instanceof MapType) { this.lengthData = - Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4); + Platform.reallocateMemory(lengthData, oldCapacity * 4L, newCapacity * 4L); this.offsetData = - Platform.reallocateMemory(offsetData, oldCapacity * 4, newCapacity * 4); + Platform.reallocateMemory(offsetData, oldCapacity * 4L, newCapacity * 4L); } else if (type instanceof ByteType || type instanceof BooleanType) { this.data = Platform.reallocateMemory(data, oldCapacity, newCapacity); } else if (type instanceof ShortType) { - this.data = Platform.reallocateMemory(data, oldCapacity * 2, newCapacity * 2); + this.data = Platform.reallocateMemory(data, oldCapacity * 2L, newCapacity * 2L); } else if (type instanceof IntegerType || type instanceof FloatType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { - this.data = Platform.reallocateMemory(data, oldCapacity * 4, newCapacity * 4); + this.data = Platform.reallocateMemory(data, oldCapacity * 4L, newCapacity * 4L); } else if (type instanceof LongType || type instanceof DoubleType || DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) { - this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8); + this.data = Platform.reallocateMemory(data, oldCapacity * 8L, newCapacity * 8L); } else if (childColumns != null) { // Nothing to store. } else { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 23dcc104e67c4..577eab6ed14c8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -231,7 +231,7 @@ public void putShorts(int rowId, int count, short[] src, int srcIndex) { @Override public void putShorts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, shortData, - Platform.SHORT_ARRAY_OFFSET + rowId * 2, count * 2); + Platform.SHORT_ARRAY_OFFSET + rowId * 2L, count * 2L); } @Override @@ -276,7 +276,7 @@ public void putInts(int rowId, int count, int[] src, int srcIndex) { @Override public void putInts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, intData, - Platform.INT_ARRAY_OFFSET + rowId * 4, count * 4); + Platform.INT_ARRAY_OFFSET + rowId * 4L, count * 4L); } @Override @@ -342,7 +342,7 @@ public void putLongs(int rowId, int count, long[] src, int srcIndex) { @Override public void putLongs(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, longData, - Platform.LONG_ARRAY_OFFSET + rowId * 8, count * 8); + Platform.LONG_ARRAY_OFFSET + rowId * 8L, count * 8L); } @Override @@ -394,7 +394,7 @@ public void putFloats(int rowId, int count, float[] src, int srcIndex) { public void putFloats(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, floatData, - Platform.DOUBLE_ARRAY_OFFSET + rowId * 4, count * 4); + Platform.DOUBLE_ARRAY_OFFSET + rowId * 4L, count * 4L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); for (int i = 0; i < count; ++i) { @@ -443,7 +443,7 @@ public void putDoubles(int rowId, int count, double[] src, int srcIndex) { public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, doubleData, - Platform.DOUBLE_ARRAY_OFFSET + rowId * 8, count * 8); + Platform.DOUBLE_ARRAY_OFFSET + rowId * 8L, count * 8L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); for (int i = 0; i < count; ++i) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index fbff8db987110..b393c48baee8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -202,7 +202,7 @@ class RateStreamMicroBatchInputPartitionReader( rangeEnd: Long, localStartTimeMs: Long, relativeMsPerValue: Double) extends InputPartitionReader[Row] { - private var count = 0 + private var count: Long = 0 override def next(): Boolean = { rangeStart + partitionId + numPartitions * count < rangeEnd diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 130e258e78ca2..8620f3f6d99fb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -342,7 +342,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { } override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { - conf.getIntVar(HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY) * 1000 + conf.getIntVar(HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY) * 1000L } override def loadPartition( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index ab7c8558321c8..2e8599026ea1d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -222,7 +222,7 @@ private[streaming] class FileBasedWriteAheadLog( pastLogs += LogInfo(currentLogWriterStartTime, currentLogWriterStopTime, _) } currentLogWriterStartTime = currentTime - currentLogWriterStopTime = currentTime + (rollingIntervalSecs * 1000) + currentLogWriterStopTime = currentTime + (rollingIntervalSecs * 1000L) val newLogPath = new Path(logDirectory, timeToLogFile(currentLogWriterStartTime, currentLogWriterStopTime)) currentLogPath = Some(newLogPath.toString) From e4fee395ecd93ad4579d9afbf0861f82a303e563 Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Fri, 15 Jun 2018 13:56:48 -0700 Subject: [PATCH 0972/2461] [SPARK-24525][SS] Provide an option to limit number of rows in a MemorySink ## What changes were proposed in this pull request? Provide an option to limit number of rows in a MemorySink. Currently, MemorySink and MemorySinkV2 have unbounded size, meaning that if they're used on big data, they can OOM the stream. This change adds a maxMemorySinkRows option to limit how many rows MemorySink and MemorySinkV2 can hold. By default, they are still unbounded. ## How was this patch tested? Added new unit tests. Author: Mukul Murthy Closes #21559 from mukulmurthy/SPARK-24525. --- .../sql/execution/streaming/memory.scala | 70 ++++++++++++++-- .../streaming/sources/memoryV2.scala | 44 +++++++--- .../sql/streaming/DataStreamWriter.scala | 4 +- .../execution/streaming/MemorySinkSuite.scala | 62 +++++++++++++- .../streaming/MemorySinkV2Suite.scala | 80 ++++++++++++++++++- .../spark/sql/streaming/StreamTest.scala | 4 +- 6 files changed, 239 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index b137f98045c5a..7fa13c4aa2c01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode @@ -221,19 +222,60 @@ class MemoryStreamInputPartition(records: Array[UnsafeRow]) } /** A common trait for MemorySinks with methods used for testing */ -trait MemorySinkBase extends BaseStreamingSink { +trait MemorySinkBase extends BaseStreamingSink with Logging { def allData: Seq[Row] def latestBatchData: Seq[Row] def dataSinceBatch(sinceBatchId: Long): Seq[Row] def latestBatchId: Option[Long] + + /** + * Truncates the given rows to return at most maxRows rows. + * @param rows The data that may need to be truncated. + * @param batchLimit Number of rows to keep in this batch; the rest will be truncated + * @param sinkLimit Total number of rows kept in this sink, for logging purposes. + * @param batchId The ID of the batch that sent these rows, for logging purposes. + * @return Truncated rows. + */ + protected def truncateRowsIfNeeded( + rows: Array[Row], + batchLimit: Int, + sinkLimit: Int, + batchId: Long): Array[Row] = { + if (rows.length > batchLimit && batchLimit >= 0) { + logWarning(s"Truncating batch $batchId to $batchLimit rows because of sink limit $sinkLimit") + rows.take(batchLimit) + } else { + rows + } + } +} + +/** + * Companion object to MemorySinkBase. + */ +object MemorySinkBase { + val MAX_MEMORY_SINK_ROWS = "maxRows" + val MAX_MEMORY_SINK_ROWS_DEFAULT = -1 + + /** + * Gets the max number of rows a MemorySink should store. This number is based on the memory + * sink row limit option if it is set. If not, we use a large value so that data truncates + * rather than causing out of memory errors. + * @param options Options for writing from which we get the max rows option + * @return The maximum number of rows a memorySink should store. + */ + def getMemorySinkCapacity(options: DataSourceOptions): Int = { + val maxRows = options.getInt(MAX_MEMORY_SINK_ROWS, MAX_MEMORY_SINK_ROWS_DEFAULT) + if (maxRows >= 0) maxRows else Int.MaxValue - 10 + } } /** * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink - with MemorySinkBase with Logging { +class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSourceOptions) + extends Sink with MemorySinkBase with Logging { private case class AddedData(batchId: Long, data: Array[Row]) @@ -241,6 +283,12 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink @GuardedBy("this") private val batches = new ArrayBuffer[AddedData]() + /** The number of rows in this MemorySink. */ + private var numRows = 0 + + /** The capacity in rows of this sink. */ + val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) + /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { batches.flatMap(_.data) @@ -273,14 +321,23 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink logDebug(s"Committing batch $batchId to $this") outputMode match { case Append | Update => - val rows = AddedData(batchId, data.collect()) - synchronized { batches += rows } + var rowsToAdd = data.collect() + synchronized { + rowsToAdd = + truncateRowsIfNeeded(rowsToAdd, sinkCapacity - numRows, sinkCapacity, batchId) + val rows = AddedData(batchId, rowsToAdd) + batches += rows + numRows += rowsToAdd.length + } case Complete => - val rows = AddedData(batchId, data.collect()) + var rowsToAdd = data.collect() synchronized { + rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity, sinkCapacity, batchId) + val rows = AddedData(batchId, rowsToAdd) batches.clear() batches += rows + numRows = rowsToAdd.length } case _ => @@ -294,6 +351,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink def clear(): Unit = synchronized { batches.clear() + numRows = 0 } override def toString(): String = "MemorySink" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 468313bfe8c3c..47b482007822d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -46,7 +46,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB schema: StructType, mode: OutputMode, options: DataSourceOptions): StreamWriter = { - new MemoryStreamWriter(this, mode) + new MemoryStreamWriter(this, mode, options) } private case class AddedData(batchId: Long, data: Array[Row]) @@ -55,6 +55,9 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB @GuardedBy("this") private val batches = new ArrayBuffer[AddedData]() + /** The number of rows in this MemorySink. */ + private var numRows = 0 + /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { batches.flatMap(_.data) @@ -81,7 +84,11 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB }.mkString("\n") } - def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row]): Unit = { + def write( + batchId: Long, + outputMode: OutputMode, + newRows: Array[Row], + sinkCapacity: Int): Unit = { val notCommitted = synchronized { latestBatchId.isEmpty || batchId > latestBatchId.get } @@ -89,14 +96,21 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB logDebug(s"Committing batch $batchId to $this") outputMode match { case Append | Update => - val rows = AddedData(batchId, newRows) - synchronized { batches += rows } + synchronized { + val rowsToAdd = + truncateRowsIfNeeded(newRows, sinkCapacity - numRows, sinkCapacity, batchId) + val rows = AddedData(batchId, rowsToAdd) + batches += rows + numRows += rowsToAdd.length + } case Complete => - val rows = AddedData(batchId, newRows) synchronized { + val rowsToAdd = truncateRowsIfNeeded(newRows, sinkCapacity, sinkCapacity, batchId) + val rows = AddedData(batchId, rowsToAdd) batches.clear() batches += rows + numRows = rowsToAdd.length } case _ => @@ -110,6 +124,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB def clear(): Unit = synchronized { batches.clear() + numRows = 0 } override def toString(): String = "MemorySinkV2" @@ -117,16 +132,22 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} -class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) +class MemoryWriter( + sink: MemorySinkV2, + batchId: Long, + outputMode: OutputMode, + options: DataSourceOptions) extends DataSourceWriter with Logging { + val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) def commit(messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { case message: MemoryWriterCommitMessage => message.data } - sink.write(batchId, outputMode, newRows) + sink.write(batchId, outputMode, newRows, sinkCapacity) } override def abort(messages: Array[WriterCommitMessage]): Unit = { @@ -134,16 +155,21 @@ class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) } } -class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) +class MemoryStreamWriter( + val sink: MemorySinkV2, + outputMode: OutputMode, + options: DataSourceOptions) extends StreamWriter { + val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { case message: MemoryWriterCommitMessage => message.data } - sink.write(epochId, outputMode, newRows) + sink.write(epochId, outputMode, newRows, sinkCapacity) } override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index e035c9cdc379e..43e80e4e54239 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.StreamWriteSupport +import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -249,7 +249,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes)) (s, r) case _ => - val s = new MemorySink(df.schema, outputMode) + val s = new MemorySink(df.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s)) (s, r) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index 3bc36ce55d902..b2fd6ba27ebb8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.execution.streaming +import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.{OutputMode, StreamTest} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -36,7 +38,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Append output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Append) + val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty()) // Before adding data, check output assert(sink.latestBatchId === None) @@ -68,9 +70,35 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, 1 to 9) } + test("directly add data in Append output mode with row limit") { + implicit val schema = new StructType().add(new StructField("value", IntegerType)) + + var optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) + var options = new DataSourceOptions(optionsMap.toMap.asJava) + val sink = new MemorySink(schema, OutputMode.Append, options) + + // Before adding data, check output + assert(sink.latestBatchId === None) + checkAnswer(sink.latestBatchData, Seq.empty) + checkAnswer(sink.allData, Seq.empty) + + // Add batch 0 and check outputs + sink.addBatch(0, 1 to 3) + assert(sink.latestBatchId === Some(0)) + checkAnswer(sink.latestBatchData, 1 to 3) + checkAnswer(sink.allData, 1 to 3) + + // Add batch 1 and check outputs + sink.addBatch(1, 4 to 6) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 5) + checkAnswer(sink.allData, 1 to 5) // new data should not go over the limit + } + test("directly add data in Update output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Update) + val sink = new MemorySink(schema, OutputMode.Update, DataSourceOptions.empty()) // Before adding data, check output assert(sink.latestBatchId === None) @@ -104,7 +132,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Complete output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Complete) + val sink = new MemorySink(schema, OutputMode.Complete, DataSourceOptions.empty()) // Before adding data, check output assert(sink.latestBatchId === None) @@ -136,6 +164,32 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, 7 to 9) } + test("directly add data in Complete output mode with row limit") { + implicit val schema = new StructType().add(new StructField("value", IntegerType)) + + var optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) + var options = new DataSourceOptions(optionsMap.toMap.asJava) + val sink = new MemorySink(schema, OutputMode.Complete, options) + + // Before adding data, check output + assert(sink.latestBatchId === None) + checkAnswer(sink.latestBatchData, Seq.empty) + checkAnswer(sink.allData, Seq.empty) + + // Add batch 0 and check outputs + sink.addBatch(0, 1 to 3) + assert(sink.latestBatchId === Some(0)) + checkAnswer(sink.latestBatchData, 1 to 3) + checkAnswer(sink.allData, 1 to 3) + + // Add batch 1 and check outputs + sink.addBatch(1, 4 to 10) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 8) + checkAnswer(sink.allData, 4 to 8) // new data should replace old data + } + test("registering as a table in Append output mode") { val input = MemoryStream[Int] @@ -211,7 +265,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("MemoryPlan statistics") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Append) + val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty()) val plan = new MemoryPlan(sink) // Before adding data, check output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 9be22d94b5654..e539510e15755 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -17,11 +17,16 @@ package org.apache.spark.sql.execution.streaming +import scala.collection.JavaConverters._ + import org.scalatest.BeforeAndAfter import org.apache.spark.sql.Row import org.apache.spark.sql.execution.streaming.sources._ +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.{OutputMode, StreamTest} +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.StructType class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("data writer") { @@ -40,7 +45,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("continuous writer") { val sink = new MemorySinkV2 - val writer = new MemoryStreamWriter(sink, OutputMode.Append()) + val writer = new MemoryStreamWriter(sink, OutputMode.Append(), DataSourceOptions.empty()) writer.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), @@ -62,7 +67,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("microbatch writer") { val sink = new MemorySinkV2 - new MemoryWriter(sink, 0, OutputMode.Append()).commit( + new MemoryWriter(sink, 0, OutputMode.Append(), DataSourceOptions.empty()).commit( Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), @@ -70,7 +75,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - new MemoryWriter(sink, 19, OutputMode.Append()).commit( + new MemoryWriter(sink, 19, OutputMode.Append(), DataSourceOptions.empty()).commit( Array( MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), MemoryWriterCommitMessage(0, Seq(Row(33))) @@ -80,4 +85,73 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) } + + test("continuous writer with row limit") { + val sink = new MemorySinkV2 + val optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 7.toString()) + val options = new DataSourceOptions(optionsMap.toMap.asJava) + val appendWriter = new MemoryStreamWriter(sink, OutputMode.Append(), options) + appendWriter.commit(0, Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), + MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), + MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))))) + assert(sink.latestBatchId.contains(0)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) + appendWriter.commit(19, Array( + MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), + MemoryWriterCommitMessage(0, Seq(Row(33))))) + assert(sink.latestBatchId.contains(19)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11)) + + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11)) + + val completeWriter = new MemoryStreamWriter(sink, OutputMode.Complete(), options) + completeWriter.commit(20, Array( + MemoryWriterCommitMessage(4, Seq(Row(11), Row(22))), + MemoryWriterCommitMessage(5, Seq(Row(33))))) + assert(sink.latestBatchId.contains(20)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) + completeWriter.commit(21, Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2), Row(3))), + MemoryWriterCommitMessage(1, Seq(Row(4), Row(5), Row(6))), + MemoryWriterCommitMessage(2, Seq(Row(7), Row(8), Row(9))))) + assert(sink.latestBatchId.contains(21)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5, 6, 7)) + + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5, 6, 7)) + } + + test("microbatch writer with row limit") { + val sink = new MemorySinkV2 + val optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) + val options = new DataSourceOptions(optionsMap.toMap.asJava) + + new MemoryWriter(sink, 25, OutputMode.Append(), options).commit(Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), + MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))))) + assert(sink.latestBatchId.contains(25)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) + new MemoryWriter(sink, 26, OutputMode.Append(), options).commit(Array( + MemoryWriterCommitMessage(2, Seq(Row(5), Row(6))), + MemoryWriterCommitMessage(3, Seq(Row(7), Row(8))))) + assert(sink.latestBatchId.contains(26)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(5)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5)) + + new MemoryWriter(sink, 27, OutputMode.Complete(), options).commit(Array( + MemoryWriterCommitMessage(4, Seq(Row(9), Row(10))), + MemoryWriterCommitMessage(5, Seq(Row(11), Row(12))))) + assert(sink.latestBatchId.contains(27)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) + new MemoryWriter(sink, 28, OutputMode.Complete(), options).commit(Array( + MemoryWriterCommitMessage(4, Seq(Row(13), Row(14), Row(15))), + MemoryWriterCommitMessage(5, Seq(Row(16), Row(17), Row(18))))) + assert(sink.latestBatchId.contains(28)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 4c3fd58cb2e45..e41b4534ed51d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -45,6 +45,7 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -337,7 +338,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be var currentStream: StreamExecution = null var lastStream: StreamExecution = null val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for - val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode) + val sink = if (useV2Sink) new MemorySinkV2 + else new MemorySink(stream.schema, outputMode, DataSourceOptions.empty()) val resetConfValues = mutable.Map[String, Option[String]]() val defaultCheckpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath From c7c0b086a0b18424725433ade840d5121ac2b86e Mon Sep 17 00:00:00 2001 From: James Yu Date: Fri, 15 Jun 2018 21:04:04 -0700 Subject: [PATCH 0973/2461] add one supported type missing from the javadoc ## What changes were proposed in this pull request? The supported java.math.BigInteger type is not mentioned in the javadoc of Encoders.bean() ## How was this patch tested? only Javadoc fix Please review http://spark.apache.org/contributing.html before opening a pull request. Author: James Yu Closes #21544 from yuj/master. --- sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index 0b95a8821b05a..b47ec0b72c638 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -132,7 +132,7 @@ object Encoders { * - primitive types: boolean, int, double, etc. * - boxed types: Boolean, Integer, Double, etc. * - String - * - java.math.BigDecimal + * - java.math.BigDecimal, java.math.BigInteger * - time related: java.sql.Date, java.sql.Timestamp * - collection types: only array and java.util.List currently, map support is in progress * - nested java bean. From b0a935255951280b49c39968f6234163e2f0e379 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 18 Jun 2018 15:32:34 +0800 Subject: [PATCH 0974/2461] [SPARK-24573][INFRA] Runs SBT checkstyle after the build to work around a side-effect ## What changes were proposed in this pull request? Seems checkstyle affects the build in the PR builder in Jenkins. I can't reproduce in my local and seems it can only be reproduced in the PR builder. I was checking the places it goes through and this is just a speculation that checkstyle's compilation in SBT has a side effect to the assembly build. This PR proposes to run the SBT checkstyle after the build. ## How was this patch tested? Jenkins tests. Author: hyukjinkwon Closes #21579 from HyukjinKwon/investigate-javastyle. --- dev/run-tests.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 5e8c8590b5c34..cd4590864b7d7 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -357,7 +357,7 @@ def build_spark_unidoc_sbt(hadoop_version): exec_sbt(profiles_and_goals) -def build_spark_assembly_sbt(hadoop_version): +def build_spark_assembly_sbt(hadoop_version, checkstyle=False): # Enable all of the profiles for the build: build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags sbt_goals = ["assembly/package"] @@ -366,6 +366,9 @@ def build_spark_assembly_sbt(hadoop_version): " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) + if checkstyle: + run_java_style_checks() + # Note that we skip Unidoc build only if Hadoop 2.6 is explicitly set in this SBT build. # Due to a different dependency resolution in SBT & Unidoc by an unknown reason, the # documentation build fails on a specific machine & environment in Jenkins but it was unable @@ -570,11 +573,13 @@ def main(): or f.endswith("scalastyle-config.xml") for f in changed_files): run_scala_style_checks() + should_run_java_style_checks = False if not changed_files or any(f.endswith(".java") or f.endswith("checkstyle.xml") or f.endswith("checkstyle-suppressions.xml") for f in changed_files): - run_java_style_checks() + # Run SBT Checkstyle after the build to prevent a side-effect to the build. + should_run_java_style_checks = True if not changed_files or any(f.endswith("lint-python") or f.endswith("tox.ini") or f.endswith(".py") @@ -603,7 +608,7 @@ def main(): detect_binary_inop_with_mima(hadoop_version) # Since we did not build assembly/package before running dev/mima, we need to # do it here because the tests still rely on it; see SPARK-13294 for details. - build_spark_assembly_sbt(hadoop_version) + build_spark_assembly_sbt(hadoop_version, should_run_java_style_checks) # run the test suites run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags) From e219e692ef70c161f37a48bfdec2a94b29260004 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 19 Jun 2018 00:24:54 +0800 Subject: [PATCH 0975/2461] [SPARK-23772][SQL] Provide an option to ignore column of all null values or empty array during JSON schema inference ## What changes were proposed in this pull request? This pr added a new JSON option `dropFieldIfAllNull ` to ignore column of all null values or empty array/struct during JSON schema inference. ## How was this patch tested? Added tests in `JsonSuite`. Author: Takeshi Yamamuro Author: Xiangrui Meng Closes #20929 from maropu/SPARK-23772. --- python/pyspark/sql/readwriter.py | 5 +- .../spark/sql/catalyst/json/JSONOptions.scala | 3 ++ .../apache/spark/sql/DataFrameReader.scala | 2 + .../datasources/json/JsonInferSchema.scala | 40 +++++++-------- .../sql/streaming/DataStreamReader.scala | 2 + .../datasources/json/JsonSuite.scala | 49 +++++++++++++++++++ 6 files changed, 80 insertions(+), 21 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index a0e20d39c20da..3efe2adb6e2a4 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -177,7 +177,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None, - encoding=None): + dropFieldIfAllNull=None, encoding=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -246,6 +246,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. :param samplingRatio: defines fraction of input JSON objects used for schema inferring. If None is set, it uses the default value, ``1.0``. + :param dropFieldIfAllNull: whether to ignore column of all null values or empty + array/struct during schema inference. If None is set, it + uses the default value, ``false``. >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 2ff12acb2946f..c081772116f84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -73,6 +73,9 @@ private[sql] class JSONOptions( val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) + // Whether to ignore column of all null values or empty array/struct during schema inference + val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false) + val timeZone: TimeZone = DateTimeUtils.getTimeZone( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index de6be5f76e15a..ec9352a7fa055 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -381,6 +381,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * that should be used for parsing. *
  • `samplingRatio` (default is 1.0): defines fraction of input JSON objects used * for schema inferring.
  • + *
  • `dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or + * empty array/struct during schema inference.
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index e7eed95a560a3..f6edc7bfb3750 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -75,7 +75,7 @@ private[sql] object JsonInferSchema { // active SparkSession and `SQLConf.get` may point to the wrong configs. val rootType = mergedTypesFromPartitions.toLocalIterator.fold(StructType(Nil))(typeMerger) - canonicalizeType(rootType) match { + canonicalizeType(rootType, configOptions) match { case Some(st: StructType) => st case _ => // canonicalizeType erases all empty structs, including the only one we want to keep @@ -181,33 +181,33 @@ private[sql] object JsonInferSchema { } /** - * Convert NullType to StringType and remove StructTypes with no fields + * Recursively canonicalizes inferred types, e.g., removes StructTypes with no fields, + * drops NullTypes or converts them to StringType based on provided options. */ - private def canonicalizeType(tpe: DataType): Option[DataType] = tpe match { - case at @ ArrayType(elementType, _) => - for { - canonicalType <- canonicalizeType(elementType) - } yield { - at.copy(canonicalType) - } + private def canonicalizeType(tpe: DataType, options: JSONOptions): Option[DataType] = tpe match { + case at: ArrayType => + canonicalizeType(at.elementType, options) + .map(t => at.copy(elementType = t)) case StructType(fields) => - val canonicalFields: Array[StructField] = for { - field <- fields - if field.name.length > 0 - canonicalType <- canonicalizeType(field.dataType) - } yield { - field.copy(dataType = canonicalType) + val canonicalFields = fields.filter(_.name.nonEmpty).flatMap { f => + canonicalizeType(f.dataType, options) + .map(t => f.copy(dataType = t)) } - - if (canonicalFields.length > 0) { - Some(StructType(canonicalFields)) + // SPARK-8093: empty structs should be deleted + if (canonicalFields.isEmpty) { + None } else { - // per SPARK-8093: empty structs should be deleted + Some(StructType(canonicalFields)) + } + + case NullType => + if (options.dropFieldIfAllNull) { None + } else { + Some(StringType) } - case NullType => Some(StringType) case other => Some(other) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index ae93965bc50ed..ef8dc3a325a33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -270,6 +270,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * per file *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator * that should be used for parsing.
  • + *
  • `dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or + * empty array/struct during schema inference.
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 4b3921c61a000..a8a4a524a97f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2427,4 +2427,53 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { spark.read.option("mode", "PERMISSIVE").option("encoding", "UTF-8").json(Seq(badJson).toDS()), Row(badJson)) } + + test("SPARK-23772 ignore column of all null values or empty array during schema inference") { + withTempPath { tempDir => + val path = tempDir.getAbsolutePath + + // primitive types + Seq( + """{"a":null, "b":1, "c":3.0}""", + """{"a":null, "b":null, "c":"string"}""", + """{"a":null, "b":null, "c":null}""") + .toDS().write.text(path) + var df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + var expectedSchema = new StructType() + .add("b", LongType).add("c", StringType) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(1, "3.0") :: Row(null, "string") :: Row(null, null) :: Nil) + + // arrays + Seq( + """{"a":[2, 1], "b":[null, null], "c":null, "d":[[], [null]], "e":[[], null, [[]]]}""", + """{"a":[null], "b":[null], "c":[], "d":[null, []], "e":null}""", + """{"a":null, "b":null, "c":[], "d":null, "e":[null, [], null]}""") + .toDS().write.mode("overwrite").text(path) + df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + expectedSchema = new StructType() + .add("a", ArrayType(LongType)) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(Array(2, 1)) :: Row(Array(null)) :: Row(null) :: Nil) + + // structs + Seq( + """{"a":{"a1": 1, "a2":"string"}, "b":{}}""", + """{"a":{"a1": 2, "a2":null}, "b":{"b1":[null]}}""", + """{"a":null, "b":null}""") + .toDS().write.mode("overwrite").text(path) + df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + expectedSchema = new StructType() + .add("a", StructType(StructField("a1", LongType) :: StructField("a2", StringType) + :: Nil)) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(Row(1, "string")) :: Row(Row(2, null)) :: Row(null) :: Nil) + } + } } From bce177552564a4862bc979d39790cf553a477d74 Mon Sep 17 00:00:00 2001 From: trystanleftwich Date: Tue, 19 Jun 2018 00:34:24 +0800 Subject: [PATCH 0976/2461] [SPARK-24526][BUILD][TEST-MAVEN] Spaces in the build dir causes failures in the build/mvn script ## What changes were proposed in this pull request? Fix the call to ${MVN_BIN} to be wrapped in quotes so it will handle having spaces in the path. ## How was this patch tested? Ran the following to confirm using the build/mvn tool with a space in the build dir now works without error ``` mkdir /tmp/test\ spaces cd /tmp/test\ spaces git clone https://github.com/apache/spark.git cd spark # Remove all mvn references in PATH so the script will download mvn to the local dir ./build/mvn -DskipTests clean package ``` Please review http://spark.apache.org/contributing.html before opening a pull request. Author: trystanleftwich Closes #21534 from trystanleftwich/SPARK-24526. --- build/mvn | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/mvn b/build/mvn index efa4f9364ea52..1405983982d4c 100755 --- a/build/mvn +++ b/build/mvn @@ -154,4 +154,4 @@ export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"} echo "Using \`mvn\` from path: $MVN_BIN" 1>&2 # Last, call the `mvn` command as usual -${MVN_BIN} -DzincPort=${ZINC_PORT} "$@" +"${MVN_BIN}" -DzincPort=${ZINC_PORT} "$@" From 8f225e055c2031ca85d61721ab712170ab4e50c1 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 18 Jun 2018 11:01:17 -0700 Subject: [PATCH 0977/2461] [SPARK-24548][SQL] Fix incorrect schema of Dataset with tuple encoders ## What changes were proposed in this pull request? When creating tuple expression encoders, we should give the serializer expressions of tuple items correct names, so we can have correct output schema when we use such tuple encoders. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21576 from viirya/SPARK-24548. --- .../catalyst/encoders/ExpressionEncoder.scala | 3 ++- .../org/apache/spark/sql/JavaDatasetSuite.java | 18 ++++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 13 +++++++++++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index efc2882f0a3d3..cbea3c017a265 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -128,7 +128,7 @@ object ExpressionEncoder { case b: BoundReference if b == originalInputObject => newInputObject }) - if (enc.flat) { + val serializerExpr = if (enc.flat) { newSerializer.head } else { // For non-flat encoder, the input object is not top level anymore after being combined to @@ -146,6 +146,7 @@ object ExpressionEncoder { Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil)) If(nullCheck, Literal.create(null, struct.dataType), struct) } + Alias(serializerExpr, s"_${index + 1}")() } val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) => diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index c132cab1b38cf..2c695fc58fd8c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -34,6 +34,7 @@ import org.junit.*; import org.junit.rules.ExpectedException; +import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.*; import org.apache.spark.sql.*; @@ -336,6 +337,23 @@ public void testTupleEncoder() { Assert.assertEquals(data5, ds5.collectAsList()); } + @Test + public void testTupleEncoderSchema() { + Encoder>> encoder = + Encoders.tuple(Encoders.STRING(), Encoders.tuple(Encoders.STRING(), Encoders.STRING())); + List>> data = Arrays.asList(tuple2("1", tuple2("a", "b")), + tuple2("2", tuple2("c", "d"))); + Dataset ds1 = spark.createDataset(data, encoder).toDF("value1", "value2"); + + JavaPairRDD> pairRDD = jsc.parallelizePairs(data); + Dataset ds2 = spark.createDataset(JavaPairRDD.toRDD(pairRDD), encoder) + .toDF("value1", "value2"); + + Assert.assertEquals(ds1.schema(), ds2.schema()); + Assert.assertEquals(ds1.select(expr("value2._1")).collectAsList(), + ds2.select(expr("value2._1")).collectAsList()); + } + @Test public void testNestedTupleEncoder() { // test ((int, string), string) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d477d78dc14e3..093cee91d2f49 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1466,6 +1466,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS() intercept[NullPointerException](ds.as[(Int, Int)].collect()) } + + test("SPARK-24548: Dataset with tuple encoders should have correct schema") { + val encoder = Encoders.tuple(newStringEncoder, + Encoders.tuple(newStringEncoder, newStringEncoder)) + + val data = Seq(("a", ("1", "2")), ("b", ("3", "4"))) + val rdd = sparkContext.parallelize(data) + + val ds1 = spark.createDataset(rdd) + val ds2 = spark.createDataset(rdd)(encoder) + assert(ds1.schema == ds2.schema) + checkDataset(ds1.select("_2._2"), ds2.select("_2._2").collect(): _*) + } } case class TestDataUnion(x: Int, y: Int, z: Int) From 1737d45e08a5f1fb78515b14321721d7197b443a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 18 Jun 2018 20:15:01 -0700 Subject: [PATCH 0978/2461] [SPARK-24478][SQL][FOLLOWUP] Move projection and filter push down to physical conversion ## What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/21503, to completely move operator pushdown to the planner rule. The code are mostly from https://github.com/apache/spark/pull/21319 ## How was this patch tested? existing tests Author: Wenchen Fan Closes #21574 from cloud-fan/followup. --- .../v2/reader/SupportsReportStatistics.java | 7 +- .../datasources/v2/DataSourceV2Relation.scala | 109 ++++----------- .../datasources/v2/DataSourceV2Strategy.scala | 124 +++++++++++++----- 3 files changed, 123 insertions(+), 117 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java index a79080a249ec8..926396414816c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -23,10 +23,9 @@ * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report statistics to Spark. * - * Statistics are reported to the optimizer before a projection or any filters are pushed to the - * DataSourceReader. Implementations that return more accurate statistics based on projection and - * filters will not improve query performance until the planner can push operators before getting - * stats. + * Statistics are reported to the optimizer before any operator is pushed to the DataSourceReader. + * Implementations that return more accurate statistics based on pushed operators will not improve + * query performance until the planner can push operators before getting stats. */ @InterfaceStability.Evolving public interface SupportsReportStatistics extends DataSourceReader { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index e08af218513fd..7613eb210c659 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -23,17 +23,24 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.sources.{DataSourceRegister, Filter} +import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} -import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsReportStatistics} import org.apache.spark.sql.types.StructType +/** + * A logical plan representing a data source v2 scan. + * + * @param source An instance of a [[DataSourceV2]] implementation. + * @param options The options for this scan. Used to create fresh [[DataSourceReader]]. + * @param userSpecifiedSchema The user-specified schema for this scan. Used to create fresh + * [[DataSourceReader]]. + */ case class DataSourceV2Relation( source: DataSourceV2, output: Seq[AttributeReference], options: Map[String, String], - userSpecifiedSchema: Option[StructType] = None) + userSpecifiedSchema: Option[StructType]) extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { import DataSourceV2Relation._ @@ -42,14 +49,7 @@ case class DataSourceV2Relation( override def simpleString: String = "RelationV2 " + metadataString - lazy val v2Options: DataSourceOptions = makeV2Options(options) - - def newReader: DataSourceReader = userSpecifiedSchema match { - case Some(userSchema) => - source.asReadSupportWithSchema.createReader(userSchema, v2Options) - case None => - source.asReadSupport.createReader(v2Options) - } + def newReader(): DataSourceReader = source.createReader(options, userSpecifiedSchema) override def computeStats(): Statistics = newReader match { case r: SupportsReportStatistics => @@ -139,83 +139,26 @@ object DataSourceV2Relation { source.getClass.getSimpleName } } - } - - private def makeV2Options(options: Map[String, String]): DataSourceOptions = { - new DataSourceOptions(options.asJava) - } - private def schema( - source: DataSourceV2, - v2Options: DataSourceOptions, - userSchema: Option[StructType]): StructType = { - val reader = userSchema match { - case Some(s) => - source.asReadSupportWithSchema.createReader(s, v2Options) - case _ => - source.asReadSupport.createReader(v2Options) + def createReader( + options: Map[String, String], + userSpecifiedSchema: Option[StructType]): DataSourceReader = { + val v2Options = new DataSourceOptions(options.asJava) + userSpecifiedSchema match { + case Some(s) => + asReadSupportWithSchema.createReader(s, v2Options) + case _ => + asReadSupport.createReader(v2Options) + } } - reader.readSchema() } def create( source: DataSourceV2, options: Map[String, String], - userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { - val output = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes - DataSourceV2Relation(source, output, options, userSpecifiedSchema) - } - - def pushRequiredColumns( - relation: DataSourceV2Relation, - reader: DataSourceReader, - struct: StructType): Seq[AttributeReference] = { - reader match { - case projectionSupport: SupportsPushDownRequiredColumns => - projectionSupport.pruneColumns(struct) - // return the output columns from the relation that were projected - val attrMap = relation.output.map(a => a.name -> a).toMap - projectionSupport.readSchema().map(f => attrMap(f.name)) - case _ => - relation.output - } - } - - def pushFilters( - reader: DataSourceReader, - filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { - reader match { - case r: SupportsPushDownCatalystFilters => - val postScanFilters = r.pushCatalystFilters(filters.toArray) - val pushedFilters = r.pushedCatalystFilters() - (postScanFilters, pushedFilters) - - case r: SupportsPushDownFilters => - // A map from translated data source filters to original catalyst filter expressions. - val translatedFilterToExpr = scala.collection.mutable.HashMap.empty[Filter, Expression] - // Catalyst filter expression that can't be translated to data source filters. - val untranslatableExprs = scala.collection.mutable.ArrayBuffer.empty[Expression] - - for (filterExpr <- filters) { - val translated = DataSourceStrategy.translateFilter(filterExpr) - if (translated.isDefined) { - translatedFilterToExpr(translated.get) = filterExpr - } else { - untranslatableExprs += filterExpr - } - } - - // Data source filters that need to be evaluated again after scanning. which means - // the data source cannot guarantee the rows returned can pass these filters. - // As a result we must return it so Spark can plan an extra filter operator. - val postScanFilters = - r.pushFilters(translatedFilterToExpr.keys.toArray).map(translatedFilterToExpr) - // The filters which are marked as pushed to this data source - val pushedFilters = r.pushedFilters().map(translatedFilterToExpr) - - (untranslatableExprs ++ postScanFilters, pushedFilters) - - case _ => (filters, Nil) - } + userSpecifiedSchema: Option[StructType]): DataSourceV2Relation = { + val reader = source.createReader(options, userSpecifiedSchema) + DataSourceV2Relation( + source, reader.readSchema().toAttributes, options, userSpecifiedSchema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 8bf858c38d76c..182aa2906cf1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -17,51 +17,115 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.{execution, Strategy} -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet} +import scala.collection.mutable + +import org.apache.spark.sql.{sources, Strategy} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns} object DataSourceV2Strategy extends Strategy { - override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => - val projectSet = AttributeSet(project.flatMap(_.references)) - val filterSet = AttributeSet(filters.flatMap(_.references)) - - val projection = if (filterSet.subsetOf(projectSet) && - AttributeSet(relation.output) == projectSet) { - // When the required projection contains all of the filter columns and column pruning alone - // can produce the required projection, push the required projection. - // A final projection may still be needed if the data source produces a different column - // order or if it cannot prune all of the nested columns. - relation.output - } else { - // When there are filter columns not already in the required projection or when the required - // projection is more complicated than column pruning, base column pruning on the set of - // all columns needed by both. - (projectSet ++ filterSet).toSeq - } - val reader = relation.newReader + /** + * Pushes down filters to the data source reader + * + * @return pushed filter and post-scan filters. + */ + private def pushFilters( + reader: DataSourceReader, + filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + reader match { + case r: SupportsPushDownCatalystFilters => + val postScanFilters = r.pushCatalystFilters(filters.toArray) + val pushedFilters = r.pushedCatalystFilters() + (pushedFilters, postScanFilters) + + case r: SupportsPushDownFilters => + // A map from translated data source filters to original catalyst filter expressions. + val translatedFilterToExpr = mutable.HashMap.empty[sources.Filter, Expression] + // Catalyst filter expression that can't be translated to data source filters. + val untranslatableExprs = mutable.ArrayBuffer.empty[Expression] + + for (filterExpr <- filters) { + val translated = DataSourceStrategy.translateFilter(filterExpr) + if (translated.isDefined) { + translatedFilterToExpr(translated.get) = filterExpr + } else { + untranslatableExprs += filterExpr + } + } + + // Data source filters that need to be evaluated again after scanning. which means + // the data source cannot guarantee the rows returned can pass these filters. + // As a result we must return it so Spark can plan an extra filter operator. + val postScanFilters = r.pushFilters(translatedFilterToExpr.keys.toArray) + .map(translatedFilterToExpr) + // The filters which are marked as pushed to this data source + val pushedFilters = r.pushedFilters().map(translatedFilterToExpr) + (pushedFilters, untranslatableExprs ++ postScanFilters) + + case _ => (Nil, filters) + } + } - val output = DataSourceV2Relation.pushRequiredColumns(relation, reader, - projection.asInstanceOf[Seq[AttributeReference]].toStructType) + /** + * Applies column pruning to the data source, w.r.t. the references of the given expressions. + * + * @return new output attributes after column pruning. + */ + // TODO: nested column pruning. + private def pruneColumns( + reader: DataSourceReader, + relation: DataSourceV2Relation, + exprs: Seq[Expression]): Seq[AttributeReference] = { + reader match { + case r: SupportsPushDownRequiredColumns => + val requiredColumns = AttributeSet(exprs.flatMap(_.references)) + val neededOutput = relation.output.filter(requiredColumns.contains) + if (neededOutput != relation.output) { + r.pruneColumns(neededOutput.toStructType) + val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap + r.readSchema().toAttributes.map { + // We have to keep the attribute id during transformation. + a => a.withExprId(nameToAttr(a.name).exprId) + } + } else { + relation.output + } + + case _ => relation.output + } + } - val (postScanFilters, pushedFilters) = DataSourceV2Relation.pushFilters(reader, filters) - logInfo(s"Post-Scan Filters: ${postScanFilters.mkString(",")}") - logInfo(s"Pushed Filters: ${pushedFilters.mkString(", ")}") + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => + val reader = relation.newReader() + // `pushedFilters` will be pushed down and evaluated in the underlying data sources. + // `postScanFilters` need to be evaluated after the scan. + // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. + val (pushedFilters, postScanFilters) = pushFilters(reader, filters) + val output = pruneColumns(reader, relation, project ++ postScanFilters) + logInfo( + s""" + |Pushing operators to ${relation.source.getClass} + |Pushed Filters: ${pushedFilters.mkString(", ")} + |Post-Scan Filters: ${postScanFilters.mkString(",")} + |Output: ${output.mkString(", ")} + """.stripMargin) val scan = DataSourceV2ScanExec( output, relation.source, relation.options, pushedFilters, reader) - val filter = postScanFilters.reduceLeftOption(And) - val withFilter = filter.map(execution.FilterExec(_, scan)).getOrElse(scan) + val filterCondition = postScanFilters.reduceLeftOption(And) + val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) val withProjection = if (withFilter.output != project) { - execution.ProjectExec(project, withFilter) + ProjectExec(project, withFilter) } else { withFilter } From 9a75c18290fff7d116cf88a44f9120bf67d8bd27 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 18 Jun 2018 20:17:04 -0700 Subject: [PATCH 0979/2461] [SPARK-24542][SQL] UDF series UDFXPathXXXX allow users to pass carefully crafted XML to access arbitrary files ## What changes were proposed in this pull request? UDF series UDFXPathXXXX allow users to pass carefully crafted XML to access arbitrary files. Spark does not have built-in access control. When users use the external access control library, users might bypass them and access the file contents. This PR basically patches the Hive fix to Apache Spark. https://issues.apache.org/jira/browse/HIVE-18879 ## How was this patch tested? A unit test case Author: Xiao Li Closes #21549 from gatorsmile/xpathSecurity. --- .../expressions/xml/UDFXPathUtil.java | 28 ++++++++++++++++++- .../expressions/xml/UDFXPathUtilSuite.scala | 21 ++++++++++++++ .../xml/XPathExpressionSuite.scala | 5 ++-- 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java index d224332d8a6c9..023ec139652c5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java @@ -21,6 +21,9 @@ import java.io.Reader; import javax.xml.namespace.QName; +import javax.xml.parsers.DocumentBuilder; +import javax.xml.parsers.DocumentBuilderFactory; +import javax.xml.parsers.ParserConfigurationException; import javax.xml.xpath.XPath; import javax.xml.xpath.XPathConstants; import javax.xml.xpath.XPathExpression; @@ -37,9 +40,15 @@ * This is based on Hive's UDFXPathUtil implementation. */ public class UDFXPathUtil { + public static final String SAX_FEATURE_PREFIX = "http://xml.org/sax/features/"; + public static final String EXTERNAL_GENERAL_ENTITIES_FEATURE = "external-general-entities"; + public static final String EXTERNAL_PARAMETER_ENTITIES_FEATURE = "external-parameter-entities"; + private DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance(); + private DocumentBuilder builder = null; private XPath xpath = XPathFactory.newInstance().newXPath(); private ReusableStringReader reader = new ReusableStringReader(); private InputSource inputSource = new InputSource(reader); + private XPathExpression expression = null; private String oldPath = null; @@ -65,14 +74,31 @@ public Object eval(String xml, String path, QName qname) throws XPathExpressionE return null; } + if (builder == null){ + try { + initializeDocumentBuilderFactory(); + builder = dbf.newDocumentBuilder(); + } catch (ParserConfigurationException e) { + throw new RuntimeException( + "Error instantiating DocumentBuilder, cannot build xml parser", e); + } + } + reader.set(xml); try { - return expression.evaluate(inputSource, qname); + return expression.evaluate(builder.parse(inputSource), qname); } catch (XPathExpressionException e) { throw new RuntimeException("Invalid XML document: " + e.getMessage() + "\n" + xml, e); + } catch (Exception e) { + throw new RuntimeException("Error loading expression '" + oldPath + "'", e); } } + private void initializeDocumentBuilderFactory() throws ParserConfigurationException { + dbf.setFeature(SAX_FEATURE_PREFIX + EXTERNAL_GENERAL_ENTITIES_FEATURE, false); + dbf.setFeature(SAX_FEATURE_PREFIX + EXTERNAL_PARAMETER_ENTITIES_FEATURE, false); + } + public Boolean evalBoolean(String xml, String path) throws XPathExpressionException { return (Boolean) eval(xml, path, XPathConstants.BOOLEAN); } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala index c4cde7091154b..0fec15bc42c17 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala @@ -77,6 +77,27 @@ class UDFXPathUtilSuite extends SparkFunSuite { assert(ret == "foo") } + test("embedFailure") { + import org.apache.commons.io.FileUtils + import java.io.File + val secretValue = String.valueOf(Math.random) + val tempFile = File.createTempFile("verifyembed", ".tmp") + tempFile.deleteOnExit() + val fname = tempFile.getAbsolutePath + + FileUtils.writeStringToFile(tempFile, secretValue) + + val xml = + s""" + | + |]> + |&embed; + """.stripMargin + val evaled = new UDFXPathUtil().evalString(xml, "/foo") + assert(evaled.isEmpty) + } + test("number eval") { var ret = util.evalNumber("truefalseb3c1-77", "a/c[2]") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala index bfa18a0919e45..c6f6d3abb860c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala @@ -40,8 +40,9 @@ class XPathExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { // Test error message for invalid XML document val e1 = intercept[RuntimeException] { testExpr("/a>", "a", null.asInstanceOf[T]) } - assert(e1.getCause.getMessage.contains("Invalid XML document") && - e1.getCause.getMessage.contains("/a>")) + assert(e1.getCause.getCause.getMessage.contains( + "XML document structures must start and end within the same entity.")) + assert(e1.getMessage.contains("/a>")) // Test error message for invalid xpath val e2 = intercept[RuntimeException] { testExpr("", "!#$", null.asInstanceOf[T]) } From a78a9046413255756653f70165520efd486fb493 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 19 Jun 2018 10:42:08 -0700 Subject: [PATCH 0980/2461] [SPARK-24521][SQL][TEST] Fix ineffective test in CachedTableSuite ## What changes were proposed in this pull request? test("withColumn doesn't invalidate cached dataframe") in CachedTableSuite doesn't not work because: The UDF is executed and test count incremented when "df.cache()" is called and the subsequent "df.collect()" has no effect on the test result. This PR fixed this test and add another test for caching UDF. ## How was this patch tested? Add new tests. Author: Li Jin Closes #21531 from icexelloss/fix-cache-test. --- .../apache/spark/sql/CachedTableSuite.scala | 19 ---------- .../apache/spark/sql/DatasetCacheSuite.scala | 38 ++++++++++++++++++- 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 81b7e18773f81..6982c22f4771d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -83,25 +83,6 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext }.sum } - test("withColumn doesn't invalidate cached dataframe") { - var evalCount = 0 - val myUDF = udf((x: String) => { evalCount += 1; "result" }) - val df = Seq(("test", 1)).toDF("s", "i").select(myUDF($"s")) - df.cache() - - df.collect() - assert(evalCount === 1) - - df.collect() - assert(evalCount === 1) - - val df2 = df.withColumn("newColumn", lit(1)) - df2.collect() - - // We should not reevaluate the cached dataframe - assert(evalCount === 1) - } - test("cache temp table") { withTempView("tempTable") { testData.select('key).createOrReplaceTempView("tempTable") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index e0561ee2797a5..82a93f74dd76c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql +import org.scalatest.concurrent.TimeLimits +import org.scalatest.time.SpanSugar._ + import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.storage.StorageLevel -class DatasetCacheSuite extends QueryTest with SharedSQLContext { +class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits { import testImplicits._ test("get storage level") { @@ -96,4 +99,37 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { agged.unpersist() assert(agged.storageLevel == StorageLevel.NONE, "The Dataset agged should not be cached.") } + + test("persist and then withColumn") { + val df = Seq(("test", 1)).toDF("s", "i") + val df2 = df.withColumn("newColumn", lit(1)) + + df.cache() + assertCached(df) + assertCached(df2) + + df.count() + assertCached(df2) + + df.unpersist() + assert(df.storageLevel == StorageLevel.NONE) + } + + test("cache UDF result correctly") { + val expensiveUDF = udf({x: Int => Thread.sleep(10000); x}) + val df = spark.range(0, 10).toDF("a").withColumn("b", expensiveUDF($"a")) + val df2 = df.agg(sum(df("b"))) + + df.cache() + df.count() + assertCached(df2) + + // udf has been evaluated during caching, and thus should not be re-evaluated here + failAfter(5 seconds) { + df2.collect() + } + + df.unpersist() + assert(df.storageLevel == StorageLevel.NONE) + } } From 9dbe53eb6bb5916d28000f2c0d646cf23094ac11 Mon Sep 17 00:00:00 2001 From: yucai Date: Tue, 19 Jun 2018 10:52:51 -0700 Subject: [PATCH 0981/2461] [SPARK-24556][SQL] Always rewrite output partitioning in ReusedExchangeExec and InMemoryTableScanExec ## What changes were proposed in this pull request? Currently, ReusedExchange and InMemoryTableScanExec only rewrite output partitioning if child's partitioning is HashPartitioning and do nothing for other partitioning, e.g., RangePartitioning. We should always rewrite it, otherwise, unnecessary shuffle could be introduced like https://issues.apache.org/jira/browse/SPARK-24556. ## How was this patch tested? Add new tests. Author: yucai Closes #21564 from yucai/SPARK-24556. --- .../columnar/InMemoryTableScanExec.scala | 6 +- .../sql/execution/exchange/Exchange.scala | 4 +- .../spark/sql/execution/PlannerSuite.scala | 64 ++++++++++++++++++- 3 files changed, 67 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 0b4dd76c7d860..997cf92449c68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.vectorized._ import org.apache.spark.sql.types._ @@ -169,8 +169,8 @@ case class InMemoryTableScanExec( // But the cached version could alias output, so we need to replace output. override def outputPartitioning: Partitioning = { relation.cachedPlan.outputPartitioning match { - case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning] - case _ => relation.cachedPlan.outputPartitioning + case e: Expression => updateAttribute(e).asInstanceOf[Partitioning] + case other => other } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index 09f79a2de0ba0..1a5b7599bb7d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -24,7 +24,7 @@ import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder} -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode} import org.apache.spark.sql.internal.SQLConf @@ -70,7 +70,7 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan } override def outputPartitioning: Partitioning = child.outputPartitioning match { - case h: HashPartitioning => h.copy(expressions = h.expressions.map(updateAttr)) + case e: Expression => updateAttr(e).asInstanceOf[Partitioning] case other => other } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 37d468739c613..d254345e8fa54 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{execution, Row} +import org.apache.spark.sql.{execution, DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort, Union} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ @@ -703,6 +703,66 @@ class PlannerSuite extends SharedSQLContext { Range(1, 2, 1, 1))) df.queryExecution.executedPlan.execute() } + + test("SPARK-24556: always rewrite output partitioning in ReusedExchangeExec " + + "and InMemoryTableScanExec") { + def checkOutputPartitioningRewrite( + plans: Seq[SparkPlan], + expectedPartitioningClass: Class[_]): Unit = { + assert(plans.size == 1) + val plan = plans.head + val partitioning = plan.outputPartitioning + assert(partitioning.getClass == expectedPartitioningClass) + val partitionedAttrs = partitioning.asInstanceOf[Expression].references + assert(partitionedAttrs.subsetOf(plan.outputSet)) + } + + def checkReusedExchangeOutputPartitioningRewrite( + df: DataFrame, + expectedPartitioningClass: Class[_]): Unit = { + val reusedExchange = df.queryExecution.executedPlan.collect { + case r: ReusedExchangeExec => r + } + checkOutputPartitioningRewrite(reusedExchange, expectedPartitioningClass) + } + + def checkInMemoryTableScanOutputPartitioningRewrite( + df: DataFrame, + expectedPartitioningClass: Class[_]): Unit = { + val inMemoryScan = df.queryExecution.executedPlan.collect { + case m: InMemoryTableScanExec => m + } + checkOutputPartitioningRewrite(inMemoryScan, expectedPartitioningClass) + } + + // ReusedExchange is HashPartitioning + val df1 = Seq(1 -> "a").toDF("i", "j").repartition($"i") + val df2 = Seq(1 -> "a").toDF("i", "j").repartition($"i") + checkReusedExchangeOutputPartitioningRewrite(df1.union(df2), classOf[HashPartitioning]) + + // ReusedExchange is RangePartitioning + val df3 = Seq(1 -> "a").toDF("i", "j").orderBy($"i") + val df4 = Seq(1 -> "a").toDF("i", "j").orderBy($"i") + checkReusedExchangeOutputPartitioningRewrite(df3.union(df4), classOf[RangePartitioning]) + + // InMemoryTableScan is HashPartitioning + Seq(1 -> "a").toDF("i", "j").repartition($"i").persist() + checkInMemoryTableScanOutputPartitioningRewrite( + Seq(1 -> "a").toDF("i", "j").repartition($"i"), classOf[HashPartitioning]) + + // InMemoryTableScan is RangePartitioning + spark.range(1, 100, 1, 10).toDF().persist() + checkInMemoryTableScanOutputPartitioningRewrite( + spark.range(1, 100, 1, 10).toDF(), classOf[RangePartitioning]) + + // InMemoryTableScan is PartitioningCollection + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m").persist() + checkInMemoryTableScanOutputPartitioningRewrite( + Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m"), + classOf[PartitioningCollection]) + } + } } // Used for unit-testing EnsureRequirements From 13092d733791b19cd7994084178306e0c449f2ed Mon Sep 17 00:00:00 2001 From: rimolive Date: Tue, 19 Jun 2018 13:25:00 -0700 Subject: [PATCH 0982/2461] [SPARK-24534][K8S] Bypass non spark-on-k8s commands ## What changes were proposed in this pull request? This PR changes the entrypoint.sh to provide an option to run non spark-on-k8s commands (init, driver, executor) in order to let the user keep with the normal workflow without hacking the image to bypass the entrypoint ## How was this patch tested? This patch was built manually in my local machine and I ran some tests with a combination of ```docker run``` commands. Author: rimolive Closes #21572 from rimolive/rimolive-spark-24534. --- .../src/main/dockerfiles/spark/entrypoint.sh | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index acdb4b1f09e0a..2f4e115e84ecd 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -37,11 +37,17 @@ if [ -z "$uidentry" ] ; then fi SPARK_K8S_CMD="$1" -if [ -z "$SPARK_K8S_CMD" ]; then - echo "No command to execute has been provided." 1>&2 - exit 1 -fi -shift 1 +case "$SPARK_K8S_CMD" in + driver | driver-py | executor) + shift 1 + ;; + "") + ;; + *) + echo "Non-spark-on-k8s command provided, proceeding in pass-through mode..." + exec /sbin/tini -s -- "$@" + ;; +esac SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" env | grep SPARK_JAVA_OPT_ | sort -t_ -k4 -n | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt @@ -92,7 +98,6 @@ case "$SPARK_K8S_CMD" in "$@" $PYSPARK_PRIMARY $PYSPARK_ARGS ) ;; - executor) CMD=( ${JAVA_HOME}/bin/java From 2cb976355c615eee4ebd0a86f3911fa9284fccf6 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 19 Jun 2018 13:56:51 -0700 Subject: [PATCH 0983/2461] [SPARK-24565][SS] Add API for in Structured Streaming for exposing output rows of each microbatch as a DataFrame ## What changes were proposed in this pull request? Currently, the micro-batches in the MicroBatchExecution is not exposed to the user through any public API. This was because we did not want to expose the micro-batches, so that all the APIs we expose, we can eventually support them in the Continuous engine. But now that we have better sense of buiding a ContinuousExecution, I am considering adding APIs which will run only the MicroBatchExecution. I have quite a few use cases where exposing the microbatch output as a dataframe is useful. - Pass the output rows of each batch to a library that is designed only the batch jobs (example, uses many ML libraries need to collect() while learning). - Reuse batch data sources for output whose streaming version does not exists (e.g. redshift data source). - Writer the output rows to multiple places by writing twice for each batch. This is not the most elegant thing to do for multiple-output streaming queries but is likely to be better than running two streaming queries processing the same data twice. The proposal is to add a method `foreachBatch(f: Dataset[T] => Unit)` to Scala/Java/Python `DataStreamWriter`. ## How was this patch tested? New unit tests. Author: Tathagata Das Closes #21571 from tdas/foreachBatch. --- python/pyspark/java_gateway.py | 25 ++- python/pyspark/sql/streaming.py | 33 +++- python/pyspark/sql/tests.py | 36 +++++ python/pyspark/sql/utils.py | 23 +++ python/pyspark/streaming/context.py | 18 +-- .../streaming/sources/ForeachBatchSink.scala | 58 +++++++ .../sql/streaming/DataStreamWriter.scala | 63 +++++++- .../sources/ForeachBatchSinkSuite.scala | 148 ++++++++++++++++++ 8 files changed, 383 insertions(+), 21 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 0afbe9dc6aa3e..fa2d5e8db716a 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -31,7 +31,7 @@ if sys.version >= '3': xrange = range -from py4j.java_gateway import java_import, JavaGateway, GatewayParameters +from py4j.java_gateway import java_import, JavaGateway, JavaObject, GatewayParameters from pyspark.find_spark_home import _find_spark_home from pyspark.serializers import read_int, write_with_length, UTF8Deserializer @@ -145,3 +145,26 @@ def do_server_auth(conn, auth_secret): if reply != "ok": conn.close() raise Exception("Unexpected reply from iterator server.") + + +def ensure_callback_server_started(gw): + """ + Start callback server if not already started. The callback server is needed if the Java + driver process needs to callback into the Python driver process to execute Python code. + """ + + # getattr will fallback to JVM, so we cannot test by hasattr() + if "_callback_server" not in gw.__dict__ or gw._callback_server is None: + gw.callback_server_parameters.eager_load = True + gw.callback_server_parameters.daemonize = True + gw.callback_server_parameters.daemonize_connections = True + gw.callback_server_parameters.port = 0 + gw.start_callback_server(gw.callback_server_parameters) + cbport = gw._callback_server.server_socket.getsockname()[1] + gw._callback_server.port = cbport + # gateway with real port + gw._python_proxy_port = gw._callback_server.port + # get the GatewayServer object in JVM by ID + jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) + # update the port of CallbackClient with real port + jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 4984593bab491..8c1fd4af674d7 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -24,12 +24,14 @@ else: intlike = (int, long) +from py4j.java_gateway import java_import + from pyspark import since, keyword_only from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.readwriter import OptionUtils, to_str from pyspark.sql.types import * -from pyspark.sql.utils import StreamingQueryException +from pyspark.sql.utils import ForeachBatchFunction, StreamingQueryException __all__ = ["StreamingQuery", "StreamingQueryManager", "DataStreamReader", "DataStreamWriter"] @@ -1016,6 +1018,35 @@ def func_with_open_process_close(partition_id, iterator): self._jwrite.foreach(jForeachWriter) return self + @since(2.4) + def foreachBatch(self, func): + """ + Sets the output of the streaming query to be processed using the provided + function. This is supported only the in the micro-batch execution modes (that is, when the + trigger is not continuous). In every micro-batch, the provided function will be called in + every micro-batch with (i) the output rows as a DataFrame and (ii) the batch identifier. + The batchId can be used deduplicate and transactionally write the output + (that is, the provided Dataset) to external systems. The output DataFrame is guaranteed + to exactly same for the same batchId (assuming all operations are deterministic in the + query). + + .. note:: Evolving. + + >>> def func(batch_df, batch_id): + ... batch_df.collect() + ... + >>> writer = sdf.writeStream.foreach(func) + """ + + from pyspark.java_gateway import ensure_callback_server_started + gw = self._spark._sc._gateway + java_import(gw.jvm, "org.apache.spark.sql.execution.streaming.sources.*") + + wrapped_func = ForeachBatchFunction(self._spark, func) + gw.jvm.PythonForeachBatchHelper.callForeachBatch(self._jwrite, wrapped_func) + ensure_callback_server_started(gw) + return self + @ignore_unicode_prefix @since(2.0) def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None, diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4e5fafa77e109..94ab867f0bd9b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2126,6 +2126,42 @@ class WriterWithNonCallableClose(WithProcess): tester.assert_invalid_writer(WriterWithNonCallableClose(), "'close' in provided object is not callable") + def test_streaming_foreachBatch(self): + q = None + collected = dict() + + def collectBatch(batch_df, batch_id): + collected[batch_id] = batch_df.collect() + + try: + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.assertTrue(0 in collected) + self.assertTrue(len(collected[0]), 2) + finally: + if q: + q.stop() + + def test_streaming_foreachBatch_propagates_python_errors(self): + from pyspark.sql.utils import StreamingQueryException + + q = None + + def collectBatch(df, id): + raise Exception("this should fail the query") + + try: + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.fail("Expected a failure") + except StreamingQueryException as e: + self.assertTrue("this should fail" in str(e)) + finally: + if q: + q.stop() + def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 45363f089a73d..bb9ce02c4b60f 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -150,3 +150,26 @@ def require_minimum_pyarrow_version(): if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version): raise ImportError("PyArrow >= %s must be installed; however, " "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__)) + + +class ForeachBatchFunction(object): + """ + This is the Python implementation of Java interface 'ForeachBatchFunction'. This wraps + the user-defined 'foreachBatch' function such that it can be called from the JVM when + the query is active. + """ + + def __init__(self, sql_ctx, func): + self.sql_ctx = sql_ctx + self.func = func + + def call(self, jdf, batch_id): + from pyspark.sql.dataframe import DataFrame + try: + self.func(DataFrame(jdf, self.sql_ctx), batch_id) + except Exception as e: + self.error = e + raise e + + class Java: + implements = ['org.apache.spark.sql.execution.streaming.sources.PythonForeachBatchFunction'] diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index dd924ef89868e..a4515828d180c 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -79,22 +79,8 @@ def _ensure_initialized(cls): java_import(gw.jvm, "org.apache.spark.streaming.api.java.*") java_import(gw.jvm, "org.apache.spark.streaming.api.python.*") - # start callback server - # getattr will fallback to JVM, so we cannot test by hasattr() - if "_callback_server" not in gw.__dict__ or gw._callback_server is None: - gw.callback_server_parameters.eager_load = True - gw.callback_server_parameters.daemonize = True - gw.callback_server_parameters.daemonize_connections = True - gw.callback_server_parameters.port = 0 - gw.start_callback_server(gw.callback_server_parameters) - cbport = gw._callback_server.server_socket.getsockname()[1] - gw._callback_server.port = cbport - # gateway with real port - gw._python_proxy_port = gw._callback_server.port - # get the GatewayServer object in JVM by ID - jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) - # update the port of CallbackClient with real port - jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port) + from pyspark.java_gateway import ensure_callback_server_started + ensure_callback_server_started(gw) # register serializer for TransformFunction # it happens before creating SparkContext when loading from checkpointing diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala new file mode 100644 index 0000000000000..03c567c58d46a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.api.python.PythonException +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.streaming.Sink +import org.apache.spark.sql.streaming.DataStreamWriter + +class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) => Unit, encoder: ExpressionEncoder[T]) + extends Sink { + + override def addBatch(batchId: Long, data: DataFrame): Unit = { + val resolvedEncoder = encoder.resolveAndBind( + data.logicalPlan.output, + data.sparkSession.sessionState.analyzer) + val rdd = data.queryExecution.toRdd.map[T](resolvedEncoder.fromRow)(encoder.clsTag) + val ds = data.sparkSession.createDataset(rdd)(encoder) + batchWriter(ds, batchId) + } + + override def toString(): String = "ForeachBatchSink" +} + + +/** + * Interface that is meant to be extended by Python classes via Py4J. + * Py4J allows Python classes to implement Java interfaces so that the JVM can call back + * Python objects. In this case, this allows the user-defined Python `foreachBatch` function + * to be called from JVM when the query is active. + * */ +trait PythonForeachBatchFunction { + /** Call the Python implementation of this function */ + def call(batchDF: DataFrame, batchId: Long): Unit +} + +object PythonForeachBatchHelper { + def callForeachBatch(dsw: DataStreamWriter[Row], pythonFunc: PythonForeachBatchFunction): Unit = { + dsw.foreachBatch(pythonFunc.call _) + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 43e80e4e54239..926c0b69a03fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -21,14 +21,15 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.InterfaceStability -import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter} +import org.apache.spark.annotation.{InterfaceStability, Since} +import org.apache.spark.api.java.function.VoidFunction2 +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger -import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2} +import org.apache.spark.sql.execution.streaming.sources._ import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} /** @@ -279,6 +280,21 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { outputMode, useTempCheckpointLocation = true, trigger = trigger) + } else if (source == "foreachBatch") { + assertNotPartitioned("foreachBatch") + if (trigger.isInstanceOf[ContinuousTrigger]) { + throw new AnalysisException("'foreachBatch' is not supported with continuous trigger") + } + val sink = new ForeachBatchSink[T](foreachBatchWriter, ds.exprEnc) + df.sparkSession.sessionState.streamingQueryManager.startQuery( + extraOptions.get("queryName"), + extraOptions.get("checkpointLocation"), + df, + extraOptions.toMap, + sink, + outputMode, + useTempCheckpointLocation = true, + trigger = trigger) } else { val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",") @@ -322,6 +338,45 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { this } + /** + * :: Experimental :: + * + * (Scala-specific) Sets the output of the streaming query to be processed using the provided + * function. This is supported only the in the micro-batch execution modes (that is, when the + * trigger is not continuous). In every micro-batch, the provided function will be called in + * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. + * The batchId can be used deduplicate and transactionally write the output + * (that is, the provided Dataset) to external systems. The output Dataset is guaranteed + * to exactly same for the same batchId (assuming all operations are deterministic in the query). + * + * @since 2.4.0 + */ + @InterfaceStability.Evolving + def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = { + this.source = "foreachBatch" + if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null") + this.foreachBatchWriter = function + this + } + + /** + * :: Experimental :: + * + * (Java-specific) Sets the output of the streaming query to be processed using the provided + * function. This is supported only the in the micro-batch execution modes (that is, when the + * trigger is not continuous). In every micro-batch, the provided function will be called in + * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. + * The batchId can be used deduplicate and transactionally write the output + * (that is, the provided Dataset) to external systems. The output Dataset is guaranteed + * to exactly same for the same batchId (assuming all operations are deterministic in the query). + * + * @since 2.4.0 + */ + @InterfaceStability.Evolving + def foreachBatch(function: VoidFunction2[Dataset[T], Long]): DataStreamWriter[T] = { + foreachBatch((batchDs: Dataset[T], batchId: Long) => function.call(batchDs, batchId)) + } + private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => cols.map(normalize(_, "Partition")) } @@ -358,5 +413,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { private var foreachWriter: ForeachWriter[T] = null + private var foreachBatchWriter: (Dataset[T], Long) => Unit = null + private var partitioningColumns: Option[Seq[String]] = None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala new file mode 100644 index 0000000000000..a4233e15e4ffd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import scala.collection.mutable + +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming._ + +case class KV(key: Int, value: Long) + +class ForeachBatchSinkSuite extends StreamTest { + import testImplicits._ + + test("foreachBatch with non-stateful query") { + val mem = MemoryStream[Int] + val ds = mem.toDS.map(_ + 1) + + val tester = new ForeachBatchTester[Int](mem) + val writer = (ds: Dataset[Int], batchId: Long) => tester.record(batchId, ds.map(_ + 1)) + + import tester._ + testWriter(ds, writer)( + check(in = 1, 2, 3)(out = 3, 4, 5), // out = in + 2 (i.e. 1 in query, 1 in writer) + check(in = 5, 6, 7)(out = 7, 8, 9)) + } + + test("foreachBatch with stateful query in update mode") { + val mem = MemoryStream[Int] + val ds = mem.toDF() + .select($"value" % 2 as "key") + .groupBy("key") + .agg(count("*") as "value") + .toDF.as[KV] + + val tester = new ForeachBatchTester[KV](mem) + val writer = (batchDS: Dataset[KV], batchId: Long) => tester.record(batchId, batchDS) + + import tester._ + testWriter(ds, writer, outputMode = OutputMode.Update)( + check(in = 0)(out = (0, 1L)), + check(in = 1)(out = (1, 1L)), + check(in = 2, 3)(out = (0, 2L), (1, 2L))) + } + + test("foreachBatch with stateful query in complete mode") { + val mem = MemoryStream[Int] + val ds = mem.toDF() + .select($"value" % 2 as "key") + .groupBy("key") + .agg(count("*") as "value") + .toDF.as[KV] + + val tester = new ForeachBatchTester[KV](mem) + val writer = (batchDS: Dataset[KV], batchId: Long) => tester.record(batchId, batchDS) + + import tester._ + testWriter(ds, writer, outputMode = OutputMode.Complete)( + check(in = 0)(out = (0, 1L)), + check(in = 1)(out = (0, 1L), (1, 1L)), + check(in = 2)(out = (0, 2L), (1, 1L))) + } + + test("foreachBatchSink does not affect metric generation") { + val mem = MemoryStream[Int] + val ds = mem.toDS.map(_ + 1) + + val tester = new ForeachBatchTester[Int](mem) + val writer = (ds: Dataset[Int], batchId: Long) => tester.record(batchId, ds.map(_ + 1)) + + import tester._ + testWriter(ds, writer)( + check(in = 1, 2, 3)(out = 3, 4, 5), + checkMetrics) + } + + test("throws errors in invalid situations") { + val ds = MemoryStream[Int].toDS + val ex1 = intercept[IllegalArgumentException] { + ds.writeStream.foreachBatch(null.asInstanceOf[(Dataset[Int], Long) => Unit]).start() + } + assert(ex1.getMessage.contains("foreachBatch function cannot be null")) + val ex2 = intercept[AnalysisException] { + ds.writeStream.foreachBatch((_, _) => {}).trigger(Trigger.Continuous("1 second")).start() + } + assert(ex2.getMessage.contains("'foreachBatch' is not supported with continuous trigger")) + val ex3 = intercept[AnalysisException] { + ds.writeStream.foreachBatch((_, _) => {}).partitionBy("value").start() + } + assert(ex3.getMessage.contains("'foreachBatch' does not support partitioning")) + } + + // ============== Helper classes and methods ================= + + private class ForeachBatchTester[T: Encoder](memoryStream: MemoryStream[Int]) { + trait Test + private case class Check(in: Seq[Int], out: Seq[T]) extends Test + private case object CheckMetrics extends Test + + private val recordedOutput = new mutable.HashMap[Long, Seq[T]] + + def testWriter( + ds: Dataset[T], + outputBatchWriter: (Dataset[T], Long) => Unit, + outputMode: OutputMode = OutputMode.Append())(tests: Test*): Unit = { + try { + var expectedBatchId = -1 + val query = ds.writeStream.outputMode(outputMode).foreachBatch(outputBatchWriter).start() + + tests.foreach { + case Check(in, out) => + expectedBatchId += 1 + memoryStream.addData(in) + query.processAllAvailable() + assert(recordedOutput.contains(expectedBatchId)) + val ds: Dataset[T] = spark.createDataset[T](recordedOutput(expectedBatchId)) + checkDataset[T](ds, out: _*) + case CheckMetrics => + assert(query.recentProgress.exists(_.numInputRows > 0)) + } + } finally { + sqlContext.streams.active.foreach(_.stop()) + } + } + + def check(in: Int*)(out: T*): Test = Check(in, out) + def checkMetrics: Test = CheckMetrics + def record(batchId: Long, ds: Dataset[T]): Unit = recordedOutput.put(batchId, ds.collect()) + implicit def conv(x: (Int, Long)): KV = KV(x._1, x._2) + } +} From bc0498d5820ded2b428277e396502e74ef0ce36d Mon Sep 17 00:00:00 2001 From: Maryann Xue Date: Tue, 19 Jun 2018 15:27:20 -0700 Subject: [PATCH 0984/2461] [SPARK-24583][SQL] Wrong schema type in InsertIntoDataSourceCommand ## What changes were proposed in this pull request? Change insert input schema type: "insertRelationType" -> "insertRelationType.asNullable", in order to avoid nullable being overridden. ## How was this patch tested? Added one test in InsertSuite. Author: Maryann Xue Closes #21585 from maryannxue/spark-24583. --- .../InsertIntoDataSourceCommand.scala | 5 +- .../spark/sql/sources/InsertSuite.scala | 51 ++++++++++++++++++- 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index a813829d50cb1..80d7608a22891 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -38,9 +38,8 @@ case class InsertIntoDataSourceCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] val data = Dataset.ofRows(sparkSession, query) - // Apply the schema of the existing table to the new data. - val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) - relation.insert(df, overwrite) + // Data has been casted to the target relation's schema by the PreprocessTableInsertion rule. + relation.insert(data, overwrite) // Re-cache all cached plans(including this relation itself, if it's cached) that refer to this // data source relation. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index fef01c860db6e..438d5d8176b8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -20,12 +20,36 @@ package org.apache.spark.sql.sources import java.io.File import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +class SimpleInsertSource extends SchemaRelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: StructType): BaseRelation = { + SimpleInsert(schema)(sqlContext.sparkSession) + } +} + +case class SimpleInsert(userSpecifiedSchema: StructType)(@transient val sparkSession: SparkSession) + extends BaseRelation with InsertableRelation { + + override def sqlContext: SQLContext = sparkSession.sqlContext + + override def schema: StructType = userSpecifiedSchema + + override def insert(input: DataFrame, overwrite: Boolean): Unit = { + input.collect + } +} + class InsertSuite extends DataSourceTest with SharedSQLContext { import testImplicits._ @@ -520,4 +544,29 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { } } } + + test("SPARK-24583 Wrong schema type in InsertIntoDataSourceCommand") { + withTable("test_table") { + val schema = new StructType() + .add("i", LongType, false) + .add("s", StringType, false) + val newTable = CatalogTable( + identifier = TableIdentifier("test_table", None), + tableType = CatalogTableType.EXTERNAL, + storage = CatalogStorageFormat( + locationUri = None, + inputFormat = None, + outputFormat = None, + serde = None, + compressed = false, + properties = Map.empty), + schema = schema, + provider = Some(classOf[SimpleInsertSource].getName)) + + spark.sessionState.catalog.createTable(newTable, false) + + sql("INSERT INTO TABLE test_table SELECT 1, 'a'") + sql("INSERT INTO TABLE test_table SELECT 2, null") + } + } } From bc111463a766a5619966a282fbe0fec991088ceb Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 19 Jun 2018 22:29:00 -0700 Subject: [PATCH 0985/2461] [SPARK-23778][CORE] Avoid unneeded shuffle when union gets an empty RDD ## What changes were proposed in this pull request? When a `union` is invoked on several RDDs of which one is an empty RDD, the result of the operation is a `UnionRDD`. This causes an unneeded extra-shuffle when all the other RDDs have the same partitioning. The PR ignores incoming empty RDDs in the union method. ## How was this patch tested? added UT Author: Marco Gaido Closes #21333 from mgaido91/SPARK-23778. --- .../main/scala/org/apache/spark/SparkContext.scala | 9 +++++---- .../test/scala/org/apache/spark/rdd/RDDSuite.scala | 14 +++++++++++++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 5e8595603cc90..74bfb5d6d2ea3 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1306,11 +1306,12 @@ class SparkContext(config: SparkConf) extends Logging { /** Build the union of a list of RDDs. */ def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = withScope { - val partitioners = rdds.flatMap(_.partitioner).toSet - if (rdds.forall(_.partitioner.isDefined) && partitioners.size == 1) { - new PartitionerAwareUnionRDD(this, rdds) + val nonEmptyRdds = rdds.filter(!_.partitions.isEmpty) + val partitioners = nonEmptyRdds.flatMap(_.partitioner).toSet + if (nonEmptyRdds.forall(_.partitioner.isDefined) && partitioners.size == 1) { + new PartitionerAwareUnionRDD(this, nonEmptyRdds) } else { - new UnionRDD(this, rdds) + new UnionRDD(this, nonEmptyRdds) } } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 191c61250ce21..5148ce05bd918 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -154,6 +154,16 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { } } + test("SPARK-23778: empty RDD in union should not produce a UnionRDD") { + val rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1)) + val emptyRDD = sc.emptyRDD[(Int, Boolean)] + val unionRDD = sc.union(emptyRDD, rddWithPartitioner) + assert(unionRDD.isInstanceOf[PartitionerAwareUnionRDD[_]]) + val unionAllEmptyRDD = sc.union(emptyRDD, emptyRDD) + assert(unionAllEmptyRDD.isInstanceOf[UnionRDD[_]]) + assert(unionAllEmptyRDD.collect().isEmpty) + } + test("partitioner aware union") { def makeRDDWithPartitioner(seq: Seq[Int]): RDD[Int] = { sc.makeRDD(seq, 1) @@ -1047,7 +1057,9 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { private class CyclicalDependencyRDD[T: ClassTag] extends RDD[T](sc, Nil) { private val mutableDependencies: ArrayBuffer[Dependency[_]] = ArrayBuffer.empty override def compute(p: Partition, c: TaskContext): Iterator[T] = Iterator.empty - override def getPartitions: Array[Partition] = Array.empty + override def getPartitions: Array[Partition] = Array(new Partition { + override def index: Int = 0 + }) override def getDependencies: Seq[Dependency[_]] = mutableDependencies def addDependency(dep: Dependency[_]) { mutableDependencies += dep From c8ef9232cf8b8ef262404b105cea83c1f393d8c3 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 20 Jun 2018 18:38:42 +0200 Subject: [PATCH 0986/2461] [MINOR][SQL] Remove invalid comment from SparkStrategies ## What changes were proposed in this pull request? This patch is removing invalid comment from SparkStrategies, given that TODO-like comment is no longer preferred one as the comment: https://github.com/apache/spark/pull/21388#issuecomment-396856235 Removing invalid comment will prevent contributors to spend their times which is not going to be merged. ## How was this patch tested? N/A Author: Jungtaek Lim Closes #21595 from HeartSaVioR/MINOR-remove-invalid-comment-on-spark-strategies. --- .../scala/org/apache/spark/sql/execution/SparkStrategies.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d6951ad01fb0c..07a6fcae83b70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -494,7 +494,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - // Can we automate these 'pass through' operations? object BasicOperators extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil From c5a0d1132a5608f2110781763f4c2229c6cd7175 Mon Sep 17 00:00:00 2001 From: aokolnychyi Date: Wed, 20 Jun 2018 18:57:13 +0200 Subject: [PATCH 0987/2461] [SPARK-24575][SQL] Prohibit window expressions inside WHERE and HAVING clauses ## What changes were proposed in this pull request? As discussed [before](https://github.com/apache/spark/pull/19193#issuecomment-393726964), this PR prohibits window expressions inside WHERE and HAVING clauses. ## How was this patch tested? This PR comes with a dedicated unit test. Author: aokolnychyi Closes #21580 from aokolnychyi/spark-24575. --- .../sql/catalyst/analysis/Analyzer.scala | 3 ++ .../sql/DataFrameWindowFunctionsSuite.scala | 42 +++++++++++++++++-- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6e3107f1c6f75..e187133d03b17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1923,6 +1923,9 @@ class Analyzer( // "Aggregate with Having clause" will be triggered. def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case Filter(condition, _) if hasWindowFunction(condition) => + failAnalysis("It is not allowed to use window functions inside WHERE and HAVING clauses") + // Aggregate with Having clause. This rule works with an unresolved Aggregate because // a resolved Aggregate will not have Window Functions. case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 3ea398aad7375..97a843978f0bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql -import java.sql.{Date, Timestamp} - -import scala.collection.mutable +import org.scalatest.Matchers.the import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} @@ -27,7 +25,6 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval /** * Window function testing for DataFrame API. @@ -624,4 +621,41 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-24575: Window functions inside WHERE and HAVING clauses") { + def checkAnalysisError(df: => DataFrame): Unit = { + val thrownException = the [AnalysisException] thrownBy { + df.queryExecution.analyzed + } + assert(thrownException.message.contains("window functions inside WHERE and HAVING clauses")) + } + + checkAnalysisError(testData2.select('a).where(rank().over(Window.orderBy('b)) === 1)) + checkAnalysisError(testData2.where('b === 2 && rank().over(Window.orderBy('b)) === 1)) + checkAnalysisError( + testData2.groupBy('a) + .agg(avg('b).as("avgb")) + .where('a > 'avgb && rank().over(Window.orderBy('a)) === 1)) + checkAnalysisError( + testData2.groupBy('a) + .agg(max('b).as("maxb"), sum('b).as("sumb")) + .where(rank().over(Window.orderBy('a)) === 1)) + checkAnalysisError( + testData2.groupBy('a) + .agg(max('b).as("maxb"), sum('b).as("sumb")) + .where('sumb === 5 && rank().over(Window.orderBy('a)) === 1)) + + checkAnalysisError(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1")) + checkAnalysisError(sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1")) + checkAnalysisError( + sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1")) + checkAnalysisError( + sql("SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1")) + checkAnalysisError( + sql( + s"""SELECT a, MAX(b) + |FROM testData2 + |GROUP BY a + |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin)) + } } From 3f4bda7289f1bfbbe8b9bc4b516007f569c44d2e Mon Sep 17 00:00:00 2001 From: Wenbo Zhao Date: Wed, 20 Jun 2018 14:26:04 -0700 Subject: [PATCH 0988/2461] [SPARK-24578][CORE] Cap sub-region's size of returned nio buffer ## What changes were proposed in this pull request? This PR tries to fix the performance regression introduced by SPARK-21517. In our production job, we performed many parallel computations, with high possibility, some task could be scheduled to a host-2 where it needs to read the cache block data from host-1. Often, this big transfer makes the cluster suffer time out issue (it will retry 3 times, each with 120s timeout, and then do recompute to put the cache block into the local MemoryStore). The root cause is that we don't do `consolidateIfNeeded` anymore as we are using ``` Unpooled.wrappedBuffer(chunks.length, getChunks(): _*) ``` in ChunkedByteBuffer. If we have many small chunks, it could cause the `buf.notBuffer(...)` have very bad performance in the case that we have to call `copyByteBuf(...)` many times. ## How was this patch tested? Existing unit tests and also test in production Author: Wenbo Zhao Closes #21593 from WenboZhao/spark-24578. --- .../network/protocol/MessageWithHeader.java | 25 ++++--------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java index a5337656cbd84..e7b66a6f33a82 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -137,30 +137,15 @@ protected void deallocate() { } private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException { - ByteBuffer buffer = buf.nioBuffer(); - int written = (buffer.remaining() <= NIO_BUFFER_LIMIT) ? - target.write(buffer) : writeNioBuffer(target, buffer); + // SPARK-24578: cap the sub-region's size of returned nio buffer to improve the performance + // for the case that the passed-in buffer has too many components. + int length = Math.min(buf.readableBytes(), NIO_BUFFER_LIMIT); + ByteBuffer buffer = buf.nioBuffer(buf.readerIndex(), length); + int written = target.write(buffer); buf.skipBytes(written); return written; } - private int writeNioBuffer( - WritableByteChannel writeCh, - ByteBuffer buf) throws IOException { - int originalLimit = buf.limit(); - int ret = 0; - - try { - int ioSize = Math.min(buf.remaining(), NIO_BUFFER_LIMIT); - buf.limit(buf.position() + ioSize); - ret = writeCh.write(buf); - } finally { - buf.limit(originalLimit); - } - - return ret; - } - @Override public MessageWithHeader touch(Object o) { super.touch(o); From 15747cfd3246385ffb23e19e28d2e4effa710bf6 Mon Sep 17 00:00:00 2001 From: Ray Burgemeestre Date: Wed, 20 Jun 2018 17:09:37 -0700 Subject: [PATCH 0989/2461] [SPARK-24547][K8S] Allow for building spark on k8s docker images without cache and don't forget to push spark-py container. ## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-24547 TL;DR from JIRA issue: - First time I generated images for 2.4.0 Docker was using it's cache, so actually when running jobs, old jars where still in the Docker image. This produces errors like this in the executors: `java.io.InvalidClassException: org.apache.spark.storage.BlockManagerId; local class incompatible: stream classdesc serialVersionUID = 6155820641931972169, local class serialVersionUID = -3720498261147521051` - The second problem was that the spark container is pushed, but the spark-py container wasn't yet. This was just forgotten in the initial PR. - A third problem I also ran into because I had an older docker was https://github.com/apache/spark/pull/21551 so I have not included a fix for that in this ticket. ## How was this patch tested? I've tested it on my own Spark on k8s deployment. Author: Ray Burgemeestre Closes #21555 from rayburgemeestre/SPARK-24547. --- bin/docker-image-tool.sh | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index a871ab5d448c3..a3f1bcffaea57 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -70,17 +70,18 @@ function build { local BASEDOCKERFILE=${BASEDOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} local PYDOCKERFILE=${PYDOCKERFILE:-"$IMG_PATH/spark/bindings/python/Dockerfile"} - docker build "${BUILD_ARGS[@]}" \ + docker build $NOCACHEARG "${BUILD_ARGS[@]}" \ -t $(image_ref spark) \ -f "$BASEDOCKERFILE" . - docker build "${BINDING_BUILD_ARGS[@]}" \ + docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ -t $(image_ref spark-py) \ -f "$PYDOCKERFILE" . } function push { docker push "$(image_ref spark)" + docker push "$(image_ref spark-py)" } function usage { @@ -99,6 +100,7 @@ Options: -r repo Repository address. -t tag Tag to apply to the built image, or to identify the image to be pushed. -m Use minikube's Docker daemon. + -n Build docker image with --no-cache Using minikube when building images will do so directly into minikube's Docker daemon. There is no need to push the images into minikube in that case, they'll be automatically @@ -127,7 +129,8 @@ REPO= TAG= BASEDOCKERFILE= PYDOCKERFILE= -while getopts f:mr:t: option +NOCACHEARG= +while getopts f:mr:t:n option do case "${option}" in @@ -135,6 +138,7 @@ do p) PYDOCKERFILE=${OPTARG};; r) REPO=${OPTARG};; t) TAG=${OPTARG};; + n) NOCACHEARG="--no-cache";; m) if ! which minikube 1>/dev/null; then error "Cannot find minikube." From 9de11d3f901bc206a33b9da3e7499bcd43e0142a Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 21 Jun 2018 12:24:53 +0900 Subject: [PATCH 0990/2461] [SPARK-23912][SQL] add array_distinct ## What changes were proposed in this pull request? Add array_distinct to remove duplicate value from the array. ## How was this patch tested? Add unit tests Author: Huaxin Gao Closes #21050 from huaxingao/spark-23912. --- python/pyspark/sql/functions.py | 14 + .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 279 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 45 +++ .../org/apache/spark/sql/functions.scala | 7 + .../spark/sql/DataFrameFunctionsSuite.scala | 22 ++ 6 files changed, 368 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e6346691fb1d4..11b179fe26bfc 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1999,6 +1999,20 @@ def array_remove(col, element): return Column(sc._jvm.functions.array_remove(_to_java_column(col), element)) +@since(2.4) +def array_distinct(col): + """ + Collection function: removes duplicate values from the array. + :param col: name of column or expression + + >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data']) + >>> df.select(array_distinct(df.data)).collect() + [Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_distinct(_to_java_column(col))) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3700c63d817ea..4b09b9a7e75df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -433,6 +433,7 @@ object FunctionRegistry { expression[Flatten]("flatten"), expression[ArrayRepeat]("array_repeat"), expression[ArrayRemove]("array_remove"), + expression[ArrayDistinct]("array_distinct"), CreateStruct.registryEntry, // mask functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index d76f3013f0c41..7c064a130ff35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +import org.apache.spark.util.collection.OpenHashSet /** * Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit @@ -2355,3 +2356,281 @@ case class ArrayRemove(left: Expression, right: Expression) override def prettyName: String = "array_remove" } + +/** + * Removes duplicate values from the array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Removes duplicate values from the array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3, null, 3)); + [1,2,3,null] + """, since = "2.4.0") +case class ArrayDistinct(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + override def dataType: DataType = child.dataType + + @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName") + } + } + + @transient private lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } + + override def nullSafeEval(array: Any): Any = { + val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) + if (elementTypeSupportEquals) { + new GenericArrayData(data.distinct.asInstanceOf[Array[Any]]) + } else { + var foundNullElement = false + var pos = 0 + for (i <- 0 until data.length) { + if (data(i) == null) { + if (!foundNullElement) { + foundNullElement = true + pos = pos + 1 + } + } else { + var j = 0 + var done = false + while (j <= i && !done) { + if (data(j) != null && ordering.equiv(data(j), data(i))) { + done = true + } + j = j + 1 + } + if (i == j - 1) { + pos = pos + 1 + } + } + } + new GenericArrayData(data.slice(0, pos)) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (array) => { + val i = ctx.freshName("i") + val j = ctx.freshName("j") + val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray") + val getValue1 = CodeGenerator.getValue(array, elementType, i) + val getValue2 = CodeGenerator.getValue(array, elementType, j) + val foundNullElement = ctx.freshName("foundNullElement") + val openHashSet = classOf[OpenHashSet[_]].getName + val hs = ctx.freshName("hs") + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" + if (elementTypeSupportEquals) { + s""" + |int $sizeOfDistinctArray = 0; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $array.numElements(); $i ++) { + | if ($array.isNullAt($i)) { + | $foundNullElement = true; + | } else { + | $hs.add($getValue1); + | } + |} + |$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 1 : 0); + |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} + """.stripMargin + } else { + s""" + |int $sizeOfDistinctArray = 0; + |boolean $foundNullElement = false; + |for (int $i = 0; $i < $array.numElements(); $i ++) { + | if ($array.isNullAt($i)) { + | if (!($foundNullElement)) { + | $sizeOfDistinctArray = $sizeOfDistinctArray + 1; + | $foundNullElement = true; + | } + | } else { + | int $j; + | for ($j = 0; $j < $i; $j ++) { + | if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)}) { + | break; + | } + | } + | if ($i == $j) { + | $sizeOfDistinctArray = $sizeOfDistinctArray + 1; + | } + | } + |} + | + |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} + """.stripMargin + } + }) + } + + private def setNull( + isPrimitive: Boolean, + foundNullElement: String, + distinctArray: String, + pos: String): String = { + val setNullValue = + if (!isPrimitive) { + s"$distinctArray[$pos] = null"; + } else { + s"$distinctArray.setNullAt($pos)"; + } + + s""" + |if (!($foundNullElement)) { + | $setNullValue; + | $pos = $pos + 1; + | $foundNullElement = true; + |} + """.stripMargin + } + + private def setNotNullValue(isPrimitive: Boolean, + distinctArray: String, + pos: String, + getValue1: String, + primitiveValueTypeName: String): String = { + if (!isPrimitive) { + s"$distinctArray[$pos] = $getValue1"; + } else { + s"$distinctArray.set$primitiveValueTypeName($pos, $getValue1)"; + } + } + + private def setValueForFastEval( + isPrimitive: Boolean, + hs: String, + distinctArray: String, + pos: String, + getValue1: String, + primitiveValueTypeName: String): String = { + val setValue = setNotNullValue(isPrimitive, + distinctArray, pos, getValue1, primitiveValueTypeName) + s""" + |if (!($hs.contains($getValue1))) { + | $hs.add($getValue1); + | $setValue; + | $pos = $pos + 1; + |} + """.stripMargin + } + + private def setValueForBruteForceEval( + isPrimitive: Boolean, + i: String, + j: String, + inputArray: String, + distinctArray: String, + pos: String, + getValue1: String, + isEqual: String, + primitiveValueTypeName: String): String = { + val setValue = setNotNullValue(isPrimitive, + distinctArray, pos, getValue1, primitiveValueTypeName) + s""" + |int $j; + |for ($j = 0; $j < $i; $j ++) { + | if (!$inputArray.isNullAt($j) && $isEqual) { + | break; + | } + |} + |if ($i == $j) { + | $setValue; + | $pos = $pos + 1; + |} + """.stripMargin + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + size: String): String = { + val distinctArray = ctx.freshName("distinctArray") + val i = ctx.freshName("i") + val j = ctx.freshName("j") + val pos = ctx.freshName("pos") + val getValue1 = CodeGenerator.getValue(inputArray, elementType, i) + val getValue2 = CodeGenerator.getValue(inputArray, elementType, j) + val isEqual = ctx.genEqual(elementType, getValue1, getValue2) + val foundNullElement = ctx.freshName("foundNullElement") + val hs = ctx.freshName("hs") + val openHashSet = classOf[OpenHashSet[_]].getName + if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" + val setNullForNonPrimitive = + setNull(false, foundNullElement, distinctArray, pos) + if (elementTypeSupportEquals) { + val setValueForFast = setValueForFastEval(false, hs, distinctArray, pos, getValue1, "") + s""" + |int $pos = 0; + |Object[] $distinctArray = new Object[$size]; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForNonPrimitive; + | } else { + | $setValueForFast; + | } + |} + |${ev.value} = new $arrayClass($distinctArray); + """.stripMargin + } else { + val setValueForBruteForce = setValueForBruteForceEval( + false, i, j, inputArray, distinctArray, pos, getValue1, isEqual, "") + s""" + |int $pos = 0; + |Object[] $distinctArray = new Object[$size]; + |boolean $foundNullElement = false; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForNonPrimitive; + | } else { + | $setValueForBruteForce; + | } + |} + |${ev.value} = new $arrayClass($distinctArray); + """.stripMargin + } + } else { + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val setNullForPrimitive = setNull(true, foundNullElement, distinctArray, pos) + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$primitiveValueTypeName()" + val setValueForFast = + setValueForFastEval(true, hs, distinctArray, pos, getValue1, primitiveValueTypeName) + s""" + |${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")} + |int $pos = 0; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForPrimitive; + | } else { + | $setValueForFast; + | } + |} + |${ev.value} = $distinctArray; + """.stripMargin + } + } + + override def prettyName: String = "array_distinct" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 85e692bdc4ef1..f377f9c8cd533 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -766,4 +766,49 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayRemove(c1, dataToRemove2), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) checkEvaluation(ArrayRemove(c2, dataToRemove2), Seq[Seq[Int]](null, Seq[Int](2, 1))) } + + test("Array Distinct") { + val a0 = Literal.create(Seq(2, 1, 2, 3, 4, 4, 5), ArrayType(IntegerType)) + val a1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val a2 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) + val a3 = Literal.create(Seq("b", null, "a", null, "a", null), ArrayType(StringType)) + val a4 = Literal.create(Seq(null, null, null), ArrayType(NullType)) + val a5 = Literal.create(Seq(true, false, false, true), ArrayType(BooleanType)) + val a6 = Literal.create(Seq(1.123, 0.1234, 1.121, 1.123, 1.1230, 1.121, 0.1234), + ArrayType(DoubleType)) + val a7 = Literal.create(Seq(1.123f, 0.1234f, 1.121f, 1.123f, 1.1230f, 1.121f, 0.1234f), + ArrayType(FloatType)) + + checkEvaluation(new ArrayDistinct(a0), Seq(2, 1, 3, 4, 5)) + checkEvaluation(new ArrayDistinct(a1), Seq.empty[Integer]) + checkEvaluation(new ArrayDistinct(a2), Seq("b", "a", "c")) + checkEvaluation(new ArrayDistinct(a3), Seq("b", null, "a")) + checkEvaluation(new ArrayDistinct(a4), Seq(null)) + checkEvaluation(new ArrayDistinct(a5), Seq(true, false)) + checkEvaluation(new ArrayDistinct(a6), Seq(1.123, 0.1234, 1.121)) + checkEvaluation(new ArrayDistinct(a7), Seq(1.123f, 0.1234f, 1.121f)) + + // complex data types + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), + Array[Byte](1, 2), Array[Byte](5, 6)), ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), null, Array[Byte](1, 2), + null, Array[Byte](5, 6), null), ArrayType(BinaryType)) + + checkEvaluation(ArrayDistinct(b0), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2))) + checkEvaluation(ArrayDistinct(b1), Seq[Array[Byte]](Array[Byte](2, 1), null)) + checkEvaluation(ArrayDistinct(b2), Seq[Array[Byte]](Array[Byte](5, 6), null, + Array[Byte](1, 2))) + + val c0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2), + Seq[Int](3, 4), Seq[Int](1, 2)), ArrayType(ArrayType(IntegerType))) + val c1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1), null, null, Seq[Int](2, 1), null), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayDistinct(c0), Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4))) + checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) + checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8551058ec58ce..965dbb69c8efb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3189,6 +3189,13 @@ object functions { ArrayRemove(column.expr, Literal(element)) } + /** + * Removes duplicate values from the array. + * @group collection_funcs + * @since 2.4.0 + */ + def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 4e5c1c56e2673..3dc696bd01eeb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1216,6 +1216,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type")) } + test("array_distinct functions") { + val df = Seq( + (Array[Int](2, 1, 3, 4, 3, 5), Array("b", "c", "a", "c", "b", "", "")), + (Array.empty[Int], Array.empty[String]), + (null, null) + ).toDF("a", "b") + checkAnswer( + df.select(array_distinct($"a"), array_distinct($"b")), + Seq( + Row(Seq(2, 1, 3, 4, 5), Seq("b", "c", "a", "")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + checkAnswer( + df.selectExpr("array_distinct(a)", "array_distinct(b)"), + Seq( + Row(Seq(2, 1, 3, 4, 5), Seq("b", "c", "a", "")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 54fcaafb094e299f21c18370fddb4a727c88d875 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 20 Jun 2018 23:38:37 -0700 Subject: [PATCH 0991/2461] [SPARK-24571][SQL] Support Char literals ## What changes were proposed in this pull request? In the PR, I propose to automatically convert a `Literal` with `Char` type to a `Literal` of `String` type. Currently, the following code: ```scala val df = Seq("Amsterdam", "San Francisco", "London").toDF("city") df.where($"city".contains('o')).show(false) ``` fails with the exception: ``` Unsupported literal type class java.lang.Character o java.lang.RuntimeException: Unsupported literal type class java.lang.Character o at org.apache.spark.sql.catalyst.expressions.Literal$.apply(literals.scala:78) ``` The PR fixes this issue by converting `char` to `string` of length `1`. I believe it makes sense to does not differentiate `char` and `string(1)` in _a unified, multi-language data platform_ like Spark which supports languages like Python/R. Author: Maxim Gekk Author: Maxim Gekk Closes #21578 from MaxGekk/support-char-literals. --- .../spark/sql/catalyst/CatalystTypeConverters.scala | 1 + .../apache/spark/sql/catalyst/expressions/literals.scala | 1 + .../spark/sql/catalyst/CatalystTypeConvertersSuite.scala | 8 ++++++++ .../sql/catalyst/expressions/LiteralExpressionSuite.scala | 7 +++++++ .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 8 ++++++++ 5 files changed, 25 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 9e9105a157abe..93df73ab1eaf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -286,6 +286,7 @@ object CatalystTypeConverters { override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match { case str: String => UTF8String.fromString(str) case utf8: UTF8String => utf8 + case chr: Char => UTF8String.fromString(chr.toString) case other => throw new IllegalArgumentException( s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + s"cannot be converted to the string type") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 246025b82d59e..0cc2a332f2c30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -57,6 +57,7 @@ object Literal { case b: Byte => Literal(b, ByteType) case s: Short => Literal(s, ShortType) case s: String => Literal(UTF8String.fromString(s), StringType) + case c: Char => Literal(UTF8String.fromString(c.toString), StringType) case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), DecimalType.fromBigDecimal(d)) case d: JavaBigDecimal => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala index f99af9b84d959..89452ee05cff3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class CatalystTypeConvertersSuite extends SparkFunSuite { @@ -139,4 +140,11 @@ class CatalystTypeConvertersSuite extends SparkFunSuite { assert(exception.getMessage.contains("The value (0.1) of the type " + "(java.lang.Double) cannot be converted to the string type")) } + + test("SPARK-24571: convert Char to String") { + val chr: Char = 'X' + val converter = CatalystTypeConverters.createToCatalystConverter(StringType) + val expected = UTF8String.fromString("X") + assert(converter(chr) === expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index a9e0eb0e377a6..86f80fe66d28b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -219,4 +219,11 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkUnsupportedTypeInLiteral(Map("key1" -> 1, "key2" -> 2)) checkUnsupportedTypeInLiteral(("mike", 29, 1.0)) } + + test("SPARK-24571: char literals") { + checkEvaluation(Literal('X'), "X") + checkEvaluation(Literal.create('0'), "0") + checkEvaluation(Literal('\u0000'), "\u0000") + checkEvaluation(Literal.create('\n'), "\n") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 093cee91d2f49..2d20c50584c03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1479,6 +1479,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds1.schema == ds2.schema) checkDataset(ds1.select("_2._2"), ds2.select("_2._2").collect(): _*) } + + test("SPARK-24571: filtering of string values by char literal") { + val df = Seq("Amsterdam", "San Francisco", "X").toDF("city") + checkAnswer(df.where('city === 'X'), Seq(Row("X"))) + checkAnswer( + df.where($"city".contains(new java.lang.Character('A'))), + Seq(Row("Amsterdam"))) + } } case class TestDataUnion(x: Int, y: Int, z: Int) From 7236e759c9856970116bf4dd20813dbf14440462 Mon Sep 17 00:00:00 2001 From: Chongguang LIU Date: Thu, 21 Jun 2018 14:58:57 +0800 Subject: [PATCH 0992/2461] [SPARK-24574][SQL] array_contains, array_position, array_remove and element_at functions deal with Column type ## What changes were proposed in this pull request? For the function ```def array_contains(column: Column, value: Any): Column ``` , if we pass the `value` parameter as a Column type, it will yield a runtime exception. This PR proposes a pattern matching to detect if `value` is of type Column. If yes, it will use the .expr of the column, otherwise it will work as it used to. Same thing for ```array_position, array_remove and element_at``` functions ## How was this patch tested? Unit test modified to cover this code change. Ping ueshin Author: Chongguang LIU Closes #21581 from chongguang/SPARK-24574. --- .../org/apache/spark/sql/functions.scala | 8 +-- .../spark/sql/DataFrameFunctionsSuite.scala | 69 +++++++++++++++---- 2 files changed, 58 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 965dbb69c8efb..c296a1bf9d69d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3093,7 +3093,7 @@ object functions { * @since 1.5.0 */ def array_contains(column: Column, value: Any): Column = withExpr { - ArrayContains(column.expr, Literal(value)) + ArrayContains(column.expr, lit(value).expr) } /** @@ -3157,7 +3157,7 @@ object functions { * @since 2.4.0 */ def array_position(column: Column, value: Any): Column = withExpr { - ArrayPosition(column.expr, Literal(value)) + ArrayPosition(column.expr, lit(value).expr) } /** @@ -3168,7 +3168,7 @@ object functions { * @since 2.4.0 */ def element_at(column: Column, value: Any): Column = withExpr { - ElementAt(column.expr, Literal(value)) + ElementAt(column.expr, lit(value).expr) } /** @@ -3186,7 +3186,7 @@ object functions { * @since 2.4.0 */ def array_remove(column: Column, element: Any): Column = withExpr { - ArrayRemove(column.expr, Literal(element)) + ArrayRemove(column.expr, lit(element).expr) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 3dc696bd01eeb..fcdd33f544311 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -635,9 +635,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("array contains function") { val df = Seq( - (Seq[Int](1, 2), "x"), - (Seq[Int](), "x") - ).toDF("a", "b") + (Seq[Int](1, 2), "x", 1), + (Seq[Int](), "x", 1) + ).toDF("a", "b", "c") // Simple test cases checkAnswer( @@ -648,6 +648,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_contains(a, 1)"), Seq(Row(true), Row(false)) ) + checkAnswer( + df.select(array_contains(df("a"), df("c"))), + Seq(Row(true), Row(false)) + ) + checkAnswer( + df.selectExpr("array_contains(a, c)"), + Seq(Row(true), Row(false)) + ) // In hive, this errors because null has no type information intercept[AnalysisException] { @@ -862,9 +870,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("array position function") { val df = Seq( - (Seq[Int](1, 2), "x"), - (Seq[Int](), "x") - ).toDF("a", "b") + (Seq[Int](1, 2), "x", 1), + (Seq[Int](), "x", 1) + ).toDF("a", "b", "c") checkAnswer( df.select(array_position(df("a"), 1)), @@ -874,7 +882,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_position(a, 1)"), Seq(Row(1L), Row(0L)) ) - + checkAnswer( + df.selectExpr("array_position(a, c)"), + Seq(Row(1L), Row(0L)) + ) + checkAnswer( + df.select(array_position(df("a"), df("c"))), + Seq(Row(1L), Row(0L)) + ) checkAnswer( df.select(array_position(df("a"), null)), Seq(Row(null), Row(null)) @@ -901,10 +916,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("element_at function") { val df = Seq( - (Seq[String]("1", "2", "3")), - (Seq[String](null, "")), - (Seq[String]()) - ).toDF("a") + (Seq[String]("1", "2", "3"), 1), + (Seq[String](null, ""), -1), + (Seq[String](), 2) + ).toDF("a", "b") intercept[Exception] { checkAnswer( @@ -922,6 +937,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.select(element_at(df("a"), 4)), Seq(Row(null), Row(null), Row(null)) ) + checkAnswer( + df.select(element_at(df("a"), df("b"))), + Seq(Row("1"), Row(""), Row(null)) + ) + checkAnswer( + df.selectExpr("element_at(a, b)"), + Seq(Row("1"), Row(""), Row(null)) + ) checkAnswer( df.select(element_at(df("a"), 1)), @@ -1189,10 +1212,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("array remove") { val df = Seq( - (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", "")), - (Array.empty[Int], Array.empty[String], Array.empty[String]), - (null, null, null) - ).toDF("a", "b", "c") + (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", ""), 2), + (Array.empty[Int], Array.empty[String], Array.empty[String], 2), + (null, null, null, 2) + ).toDF("a", "b", "c", "d") checkAnswer( df.select(array_remove($"a", 2), array_remove($"b", "a"), array_remove($"c", "")), Seq( @@ -1201,6 +1224,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null, null, null)) ) + checkAnswer( + df.select(array_remove($"a", $"d")), + Seq( + Row(Seq(1, 3)), + Row(Seq.empty[Int]), + Row(null)) + ) + + checkAnswer( + df.selectExpr("array_remove(a, d)"), + Seq( + Row(Seq(1, 3)), + Row(Seq.empty[Int]), + Row(null)) + ) + checkAnswer( df.selectExpr("array_remove(a, 2)", "array_remove(b, \"a\")", "array_remove(c, \"\")"), From c0cad596b84feeb7840c1b7a51c1ef275940a5ed Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Thu, 21 Jun 2018 16:41:43 +0800 Subject: [PATCH 0993/2461] [SPARK-24614][PYSPARK] Fix for SyntaxWarning on tests.py ## What changes were proposed in this pull request? Fix for SyntaxWarning on tests.py ## How was this patch tested? ./dev/run-tests Author: Rekha Joshi Closes #21604 from rekhajoshm/SPARK-24614. --- python/pyspark/sql/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 94ab867f0bd9b..3c8a8fcf6e946 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1918,7 +1918,7 @@ def assert_invalid_writer(self, writer, msg=None): self.fail("invalid writer %s did not fail the query" % str(writer)) # not expected except Exception as e: if msg: - assert(msg in str(e), "%s not in %s" % (msg, str(e))) + assert msg in str(e), "%s not in %s" % (msg, str(e)) finally: self.stop_all() From b56e9c613fb345472da3db1a567ee129621f6bf3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Thu, 21 Jun 2018 09:17:18 -0500 Subject: [PATCH 0994/2461] [SPARK-16630][YARN] Blacklist a node if executors won't launch on it MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This change extends YARN resource allocation handling with blacklisting functionality. This handles cases when node is messed up or misconfigured such that a container won't launch on it. Before this change backlisting only focused on task execution but this change introduces YarnAllocatorBlacklistTracker which tracks allocation failures per host (when enabled via "spark.yarn.blacklist.executor.launch.blacklisting.enabled"). ## How was this patch tested? ### With unit tests Including a new suite: YarnAllocatorBlacklistTrackerSuite. #### Manually It was tested on a cluster by deleting the Spark jars on one of the node. #### Behaviour before these changes Starting Spark as: ``` spark2-shell --master yarn --deploy-mode client --num-executors 4 --conf spark.executor.memory=4g --conf "spark.yarn.max.executor.failures=6" ``` Log is: ``` 18/04/12 06:49:36 INFO yarn.ApplicationMaster: Final app status: FAILED, exitCode: 11, (reason: Max number of executor failures (6) reached) 18/04/12 06:49:39 INFO yarn.ApplicationMaster: Unregistering ApplicationMaster with FAILED (diag message: Max number of executor failures (6) reached) 18/04/12 06:49:39 INFO impl.AMRMClientImpl: Waiting for application to be successfully unregistered. 18/04/12 06:49:39 INFO yarn.ApplicationMaster: Deleting staging directory hdfs://apiros-1.gce.test.com:8020/user/systest/.sparkStaging/application_1523459048274_0016 18/04/12 06:49:39 INFO util.ShutdownHookManager: Shutdown hook called ``` #### Behaviour after these changes Starting Spark as: ``` spark2-shell --master yarn --deploy-mode client --num-executors 4 --conf spark.executor.memory=4g --conf "spark.yarn.max.executor.failures=6" --conf "spark.yarn.blacklist.executor.launch.blacklisting.enabled=true" ``` And the log is: ``` 18/04/13 05:37:43 INFO yarn.YarnAllocator: Will request 1 executor container(s), each with 1 core(s) and 4505 MB memory (including 409 MB of overhead) 18/04/13 05:37:43 INFO yarn.YarnAllocator: Submitted 1 unlocalized container requests. 18/04/13 05:37:43 INFO yarn.YarnAllocator: Launching container container_1523459048274_0025_01_000008 on host apiros-4.gce.test.com for executor with ID 6 18/04/13 05:37:43 INFO yarn.YarnAllocator: Received 1 containers from YARN, launching executors on 1 of them. 18/04/13 05:37:43 INFO yarn.YarnAllocator: Completed container container_1523459048274_0025_01_000007 on host: apiros-4.gce.test.com (state: COMPLETE, exit status: 1) 18/04/13 05:37:43 INFO yarn.YarnAllocatorBlacklistTracker: blacklisting host as YARN allocation failed: apiros-4.gce.test.com 18/04/13 05:37:43 INFO yarn.YarnAllocatorBlacklistTracker: adding nodes to YARN application master's blacklist: List(apiros-4.gce.test.com) 18/04/13 05:37:43 WARN yarn.YarnAllocator: Container marked as failed: container_1523459048274_0025_01_000007 on host: apiros-4.gce.test.com. Exit status: 1. Diagnostics: Exception from container-launch. Container id: container_1523459048274_0025_01_000007 Exit code: 1 Stack trace: ExitCodeException exitCode=1: at org.apache.hadoop.util.Shell.runCommand(Shell.java:604) at org.apache.hadoop.util.Shell.run(Shell.java:507) at org.apache.hadoop.util.Shell$ShellCommandExecutor.execute(Shell.java:789) at org.apache.hadoop.yarn.server.nodemanager.DefaultContainerExecutor.launchContainer(DefaultContainerExecutor.java:213) at org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:302) at org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:82) at java.util.concurrent.FutureTask.run(FutureTask.java:266) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) at java.lang.Thread.run(Thread.java:748) ``` Where the most important part is: ``` 18/04/13 05:37:43 INFO yarn.YarnAllocatorBlacklistTracker: blacklisting host as YARN allocation failed: apiros-4.gce.test.com 18/04/13 05:37:43 INFO yarn.YarnAllocatorBlacklistTracker: adding nodes to YARN application master's blacklist: List(apiros-4.gce.test.com) ``` And execution was continued (no shutdown called). ### Testing the backlisting of the whole cluster Starting Spark with YARN blacklisting enabled then removing a the Spark core jar one by one from all the cluster nodes. Then executing a simple spark job which fails checking the yarn log the expected exit status is contained: ``` 18/06/15 01:07:10 INFO yarn.ApplicationMaster: Final app status: FAILED, exitCode: 11, (reason: Due to executor failures all available nodes are blacklisted) 18/06/15 01:07:13 INFO util.ShutdownHookManager: Shutdown hook called ``` Author: “attilapiros” Closes #21068 from attilapiros/SPARK-16630. --- .../spark/scheduler/BlacklistTracker.scala | 2 +- .../CoarseGrainedSchedulerBackend.scala | 3 +- .../apache/spark/HeartbeatReceiverSuite.scala | 6 +- docs/running-on-yarn.md | 10 + ...osCoarseGrainedSchedulerBackendSuite.scala | 1 + .../spark/deploy/yarn/ApplicationMaster.scala | 4 + .../spark/deploy/yarn/YarnAllocator.scala | 60 ++---- .../yarn/YarnAllocatorBlacklistTracker.scala | 187 ++++++++++++++++++ .../org/apache/spark/deploy/yarn/config.scala | 6 + .../deploy/yarn/FailureTrackerSuite.scala | 100 ++++++++++ .../YarnAllocatorBlacklistTrackerSuite.scala | 140 +++++++++++++ .../deploy/yarn/YarnAllocatorSuite.scala | 16 +- 12 files changed, 479 insertions(+), 56 deletions(-) create mode 100644 resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala create mode 100644 resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/FailureTrackerSuite.scala create mode 100644 resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index 30cf75d43ee09..980fbbe516b91 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -371,7 +371,7 @@ private[scheduler] class BlacklistTracker ( } -private[scheduler] object BlacklistTracker extends Logging { +private[spark] object BlacklistTracker extends Logging { private val DEFAULT_TIMEOUT = "1h" diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index d8794e8e551aa..9b90e309d2e04 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -170,8 +170,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp if (executorDataMap.contains(executorId)) { executorRef.send(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) context.reply(true) - } else if (scheduler.nodeBlacklist != null && - scheduler.nodeBlacklist.contains(hostname)) { + } else if (scheduler.nodeBlacklist.contains(hostname)) { // If the cluster manager gives us an executor on a blacklisted node (because it // already started allocating those resources before we informed it of our blacklist, // or if it ignored our blacklist), then we reject that executor immediately. diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 88916488c0def..b705556e54b14 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark import java.util.concurrent.{ExecutorService, TimeUnit} -import scala.collection.Map import scala.collection.mutable import scala.concurrent.Future import scala.concurrent.duration._ @@ -73,6 +72,7 @@ class HeartbeatReceiverSuite sc = spy(new SparkContext(conf)) scheduler = mock(classOf[TaskSchedulerImpl]) when(sc.taskScheduler).thenReturn(scheduler) + when(scheduler.nodeBlacklist).thenReturn(Predef.Set[String]()) when(scheduler.sc).thenReturn(sc) heartbeatReceiverClock = new ManualClock heartbeatReceiver = new HeartbeatReceiver(sc, heartbeatReceiverClock) @@ -241,7 +241,7 @@ class HeartbeatReceiverSuite } === Some(true)) } - private def getTrackedExecutors: Map[String, Long] = { + private def getTrackedExecutors: collection.Map[String, Long] = { // We may receive undesired SparkListenerExecutorAdded from LocalSchedulerBackend, // so exclude it from the map. See SPARK-10800. heartbeatReceiver.invokePrivate(_executorLastSeen()). @@ -272,7 +272,7 @@ private class FakeSchedulerBackend( protected override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = { clusterManagerEndpoint.ask[Boolean]( - RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount, Set.empty[String])) + RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount, Set.empty)) } protected override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = { diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 4dbcbeafbbd9d..575da7205b529 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -411,6 +411,16 @@ To use a custom metrics.properties for the application master and executors, upd name matches both the include and the exclude pattern, this file will be excluded eventually. + + spark.yarn.blacklist.executor.launch.blacklisting.enabled + false + + Flag to enable blacklisting of nodes having YARN resource allocation problems. + The error limit for blacklisting can be configured by + spark.blacklist.application.maxFailedExecutorsPerNode. + + + # Important notes diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index f4bd1ee9da6f7..b790c7cd27794 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -789,6 +789,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.nodeBlacklist).thenReturn(Set[String]()) when(taskScheduler.sc).thenReturn(sc) externalShuffleClient = mock[MesosExternalShuffleClient] diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 3d6ee50b070a3..ecc576910db9e 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -515,6 +515,10 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_MAX_EXECUTOR_FAILURES, s"Max number of executor failures ($maxNumExecutorFailures) reached") + } else if (allocator.isAllNodeBlacklisted) { + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_MAX_EXECUTOR_FAILURES, + "Due to executor failures all available nodes are blacklisted") } else { logDebug("Sending progress") allocator.allocateResources() diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index ebee3d431744d..fae054e0eea00 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -24,7 +24,7 @@ import java.util.regex.Pattern import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.util.control.NonFatal import org.apache.hadoop.yarn.api.records._ @@ -66,7 +66,8 @@ private[yarn] class YarnAllocator( appAttemptId: ApplicationAttemptId, securityMgr: SecurityManager, localResources: Map[String, LocalResource], - resolver: SparkRackResolver) + resolver: SparkRackResolver, + clock: Clock = new SystemClock) extends Logging { import YarnAllocator._ @@ -102,18 +103,14 @@ private[yarn] class YarnAllocator( private var executorIdCounter: Int = driverRef.askSync[Int](RetrieveLastAllocatedExecutorId) - // Queue to store the timestamp of failed executors - private val failedExecutorsTimeStamps = new Queue[Long]() + private[spark] val failureTracker = new FailureTracker(sparkConf, clock) - private var clock: Clock = new SystemClock - - private val executorFailuresValidityInterval = - sparkConf.get(EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS).getOrElse(-1L) + private val allocatorBlacklistTracker = + new YarnAllocatorBlacklistTracker(sparkConf, amClient, failureTracker) @volatile private var targetNumExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(sparkConf) - private var currentNodeBlacklist = Set.empty[String] // Executor loss reason requests that are pending - maps from executor ID for inquiry to a // list of requesters that should be responded to once we find out why the given executor @@ -149,7 +146,6 @@ private[yarn] class YarnAllocator( private val labelExpression = sparkConf.get(EXECUTOR_NODE_LABEL_EXPRESSION) - // A map to store preferred hostname and possible task numbers running on it. private var hostToLocalTaskCounts: Map[String, Int] = Map.empty @@ -160,26 +156,11 @@ private[yarn] class YarnAllocator( private[yarn] val containerPlacementStrategy = new LocalityPreferredContainerPlacementStrategy(sparkConf, conf, resource, resolver) - /** - * Use a different clock for YarnAllocator. This is mainly used for testing. - */ - def setClock(newClock: Clock): Unit = { - clock = newClock - } - def getNumExecutorsRunning: Int = runningExecutors.size() - def getNumExecutorsFailed: Int = synchronized { - val endTime = clock.getTimeMillis() + def getNumExecutorsFailed: Int = failureTracker.numFailedExecutors - while (executorFailuresValidityInterval > 0 - && failedExecutorsTimeStamps.nonEmpty - && failedExecutorsTimeStamps.head < endTime - executorFailuresValidityInterval) { - failedExecutorsTimeStamps.dequeue() - } - - failedExecutorsTimeStamps.size - } + def isAllNodeBlacklisted: Boolean = allocatorBlacklistTracker.isAllNodeBlacklisted /** * A sequence of pending container requests that have not yet been fulfilled. @@ -204,9 +185,8 @@ private[yarn] class YarnAllocator( * @param localityAwareTasks number of locality aware tasks to be used as container placement hint * @param hostToLocalTaskCount a map of preferred hostname to possible task counts to be used as * container placement hint. - * @param nodeBlacklist a set of blacklisted nodes, which is passed in to avoid allocating new - * containers on them. It will be used to update the application master's - * blacklist. + * @param nodeBlacklist blacklisted nodes, which is passed in to avoid allocating new containers + * on them. It will be used to update the application master's blacklist. * @return Whether the new requested total is different than the old value. */ def requestTotalExecutorsWithPreferredLocalities( @@ -220,19 +200,7 @@ private[yarn] class YarnAllocator( if (requestedTotal != targetNumExecutors) { logInfo(s"Driver requested a total number of $requestedTotal executor(s).") targetNumExecutors = requestedTotal - - // Update blacklist infomation to YARN ResouceManager for this application, - // in order to avoid allocating new Containers on the problematic nodes. - val blacklistAdditions = nodeBlacklist -- currentNodeBlacklist - val blacklistRemovals = currentNodeBlacklist -- nodeBlacklist - if (blacklistAdditions.nonEmpty) { - logInfo(s"adding nodes to YARN application master's blacklist: $blacklistAdditions") - } - if (blacklistRemovals.nonEmpty) { - logInfo(s"removing nodes from YARN application master's blacklist: $blacklistRemovals") - } - amClient.updateBlacklist(blacklistAdditions.toList.asJava, blacklistRemovals.toList.asJava) - currentNodeBlacklist = nodeBlacklist + allocatorBlacklistTracker.setSchedulerBlacklistedNodes(nodeBlacklist) true } else { false @@ -268,6 +236,7 @@ private[yarn] class YarnAllocator( val allocateResponse = amClient.allocate(progressIndicator) val allocatedContainers = allocateResponse.getAllocatedContainers() + allocatorBlacklistTracker.setNumClusterNodes(allocateResponse.getNumClusterNodes) if (allocatedContainers.size > 0) { logDebug(("Allocated containers: %d. Current executor count: %d. " + @@ -602,8 +571,9 @@ private[yarn] class YarnAllocator( completedContainer.getDiagnostics, PMEM_EXCEEDED_PATTERN)) case _ => - // Enqueue the timestamp of failed executor - failedExecutorsTimeStamps.enqueue(clock.getTimeMillis()) + // all the failures which not covered above, like: + // disk failure, kill by app master or resource manager, ... + allocatorBlacklistTracker.handleResourceAllocationFailure(hostOpt) (true, "Container marked as failed: " + containerId + onHostStr + ". Exit status: " + completedContainer.getExitStatus + ". Diagnostics: " + completedContainer.getDiagnostics) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala new file mode 100644 index 0000000000000..1b48a0ee7ad32 --- /dev/null +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.HashMap + +import org.apache.hadoop.yarn.client.api.AMRMClient +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.yarn.config._ +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.scheduler.BlacklistTracker +import org.apache.spark.util.{Clock, SystemClock, Utils} + +/** + * YarnAllocatorBlacklistTracker is responsible for tracking the blacklisted nodes + * and synchronizing the node list to YARN. + * + * Blacklisted nodes are coming from two different sources: + * + *
      + *
    • from the scheduler as task level blacklisted nodes + *
    • from this class (tracked here) as YARN resource allocation problems + *
    + * + * The reason to realize this logic here (and not in the driver) is to avoid possible delays + * between synchronizing the blacklisted nodes with YARN and resource allocations. + */ +private[spark] class YarnAllocatorBlacklistTracker( + sparkConf: SparkConf, + amClient: AMRMClient[ContainerRequest], + failureTracker: FailureTracker) + extends Logging { + + private val blacklistTimeoutMillis = BlacklistTracker.getBlacklistTimeout(sparkConf) + + private val launchBlacklistEnabled = sparkConf.get(YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED) + + private val maxFailuresPerHost = sparkConf.get(MAX_FAILED_EXEC_PER_NODE) + + private val allocatorBlacklist = new HashMap[String, Long]() + + private var currentBlacklistedYarnNodes = Set.empty[String] + + private var schedulerBlacklist = Set.empty[String] + + private var numClusterNodes = Int.MaxValue + + def setNumClusterNodes(numClusterNodes: Int): Unit = { + this.numClusterNodes = numClusterNodes + } + + def handleResourceAllocationFailure(hostOpt: Option[String]): Unit = { + hostOpt match { + case Some(hostname) if launchBlacklistEnabled => + // failures on an already blacklisted nodes are not even tracked. + // otherwise, such failures could shutdown the application + // as resource requests are asynchronous + // and a late failure response could exceed MAX_EXECUTOR_FAILURES + if (!schedulerBlacklist.contains(hostname) && + !allocatorBlacklist.contains(hostname)) { + failureTracker.registerFailureOnHost(hostname) + updateAllocationBlacklistedNodes(hostname) + } + case _ => + failureTracker.registerExecutorFailure() + } + } + + private def updateAllocationBlacklistedNodes(hostname: String): Unit = { + val failuresOnHost = failureTracker.numFailuresOnHost(hostname) + if (failuresOnHost > maxFailuresPerHost) { + logInfo(s"blacklisting $hostname as YARN allocation failed $failuresOnHost times") + allocatorBlacklist.put( + hostname, + failureTracker.clock.getTimeMillis() + blacklistTimeoutMillis) + refreshBlacklistedNodes() + } + } + + def setSchedulerBlacklistedNodes(schedulerBlacklistedNodesWithExpiry: Set[String]): Unit = { + this.schedulerBlacklist = schedulerBlacklistedNodesWithExpiry + refreshBlacklistedNodes() + } + + def isAllNodeBlacklisted: Boolean = currentBlacklistedYarnNodes.size >= numClusterNodes + + private def refreshBlacklistedNodes(): Unit = { + removeExpiredYarnBlacklistedNodes() + val allBlacklistedNodes = schedulerBlacklist ++ allocatorBlacklist.keySet + synchronizeBlacklistedNodeWithYarn(allBlacklistedNodes) + } + + private def synchronizeBlacklistedNodeWithYarn(nodesToBlacklist: Set[String]): Unit = { + // Update blacklist information to YARN ResourceManager for this application, + // in order to avoid allocating new Containers on the problematic nodes. + val additions = (nodesToBlacklist -- currentBlacklistedYarnNodes).toList.sorted + val removals = (currentBlacklistedYarnNodes -- nodesToBlacklist).toList.sorted + if (additions.nonEmpty) { + logInfo(s"adding nodes to YARN application master's blacklist: $additions") + } + if (removals.nonEmpty) { + logInfo(s"removing nodes from YARN application master's blacklist: $removals") + } + amClient.updateBlacklist(additions.asJava, removals.asJava) + currentBlacklistedYarnNodes = nodesToBlacklist + } + + private def removeExpiredYarnBlacklistedNodes(): Unit = { + val now = failureTracker.clock.getTimeMillis() + allocatorBlacklist.retain { (_, expiryTime) => expiryTime > now } + } +} + +/** + * FailureTracker is responsible for tracking executor failures both for each host separately + * and for all hosts altogether. + */ +private[spark] class FailureTracker( + sparkConf: SparkConf, + val clock: Clock = new SystemClock) extends Logging { + + private val executorFailuresValidityInterval = + sparkConf.get(config.EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS).getOrElse(-1L) + + // Queue to store the timestamp of failed executors for each host + private val failedExecutorsTimeStampsPerHost = mutable.Map[String, mutable.Queue[Long]]() + + private val failedExecutorsTimeStamps = new mutable.Queue[Long]() + + private def updateAndCountFailures(failedExecutorsWithTimeStamps: mutable.Queue[Long]): Int = { + val endTime = clock.getTimeMillis() + while (executorFailuresValidityInterval > 0 && + failedExecutorsWithTimeStamps.nonEmpty && + failedExecutorsWithTimeStamps.head < endTime - executorFailuresValidityInterval) { + failedExecutorsWithTimeStamps.dequeue() + } + failedExecutorsWithTimeStamps.size + } + + def numFailedExecutors: Int = synchronized { + updateAndCountFailures(failedExecutorsTimeStamps) + } + + def registerFailureOnHost(hostname: String): Unit = synchronized { + val timeMillis = clock.getTimeMillis() + failedExecutorsTimeStamps.enqueue(timeMillis) + val failedExecutorsOnHost = + failedExecutorsTimeStampsPerHost.getOrElse(hostname, { + val failureOnHost = mutable.Queue[Long]() + failedExecutorsTimeStampsPerHost.put(hostname, failureOnHost) + failureOnHost + }) + failedExecutorsOnHost.enqueue(timeMillis) + } + + def registerExecutorFailure(): Unit = synchronized { + val timeMillis = clock.getTimeMillis() + failedExecutorsTimeStamps.enqueue(timeMillis) + } + + def numFailuresOnHost(hostname: String): Int = { + failedExecutorsTimeStampsPerHost.get(hostname).map { failedExecutorsOnHost => + updateAndCountFailures(failedExecutorsOnHost) + }.getOrElse(0) + } + +} + diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 1a99b3bd57672..129084a86597a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -328,4 +328,10 @@ package object config { CACHED_FILES_TYPES, CACHED_CONF_ARCHIVE) + /* YARN allocator-level blacklisting related config entries. */ + private[spark] val YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED = + ConfigBuilder("spark.yarn.blacklist.executor.launch.blacklisting.enabled") + .booleanConf + .createWithDefault(false) + } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/FailureTrackerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/FailureTrackerSuite.scala new file mode 100644 index 0000000000000..4f77b9c99dd25 --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/FailureTrackerSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import org.scalatest.Matchers + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.util.ManualClock + +class FailureTrackerSuite extends SparkFunSuite with Matchers { + + override def beforeAll(): Unit = { + super.beforeAll() + } + + test("failures expire if validity interval is set") { + val sparkConf = new SparkConf() + sparkConf.set(config.EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS, 100L) + + val clock = new ManualClock() + val failureTracker = new FailureTracker(sparkConf, clock) + + clock.setTime(0) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (1) + failureTracker.numFailedExecutors should be (1) + + clock.setTime(10) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (1) + failureTracker.numFailedExecutors should be (2) + + clock.setTime(20) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (2) + failureTracker.numFailedExecutors should be (3) + + clock.setTime(30) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (2) + failureTracker.numFailedExecutors should be (4) + + clock.setTime(101) + failureTracker.numFailuresOnHost("host1") should be (1) + failureTracker.numFailedExecutors should be (3) + + clock.setTime(231) + failureTracker.numFailuresOnHost("host1") should be (0) + failureTracker.numFailuresOnHost("host2") should be (0) + failureTracker.numFailedExecutors should be (0) + } + + + test("failures never expire if validity interval is not set (-1)") { + val sparkConf = new SparkConf() + + val clock = new ManualClock() + val failureTracker = new FailureTracker(sparkConf, clock) + + clock.setTime(0) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (1) + failureTracker.numFailedExecutors should be (1) + + clock.setTime(10) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (1) + failureTracker.numFailedExecutors should be (2) + + clock.setTime(20) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (2) + failureTracker.numFailedExecutors should be (3) + + clock.setTime(30) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (2) + failureTracker.numFailedExecutors should be (4) + + clock.setTime(1000) + failureTracker.numFailuresOnHost("host1") should be (2) + failureTracker.numFailuresOnHost("host2") should be (2) + failureTracker.numFailedExecutors should be (4) + } + +} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala new file mode 100644 index 0000000000000..aeac68e6ed330 --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import java.util.Arrays +import java.util.Collections + +import org.apache.hadoop.yarn.client.api.AMRMClient +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest +import org.mockito.Mockito._ +import org.scalatest.{BeforeAndAfterEach, Matchers} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.yarn.config.YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED +import org.apache.spark.internal.config.{BLACKLIST_TIMEOUT_CONF, MAX_FAILED_EXEC_PER_NODE} +import org.apache.spark.util.ManualClock + +class YarnAllocatorBlacklistTrackerSuite extends SparkFunSuite with Matchers + with BeforeAndAfterEach { + + val BLACKLIST_TIMEOUT = 100L + val MAX_FAILED_EXEC_PER_NODE_VALUE = 2 + + var amClientMock: AMRMClient[ContainerRequest] = _ + var yarnBlacklistTracker: YarnAllocatorBlacklistTracker = _ + var failureTracker: FailureTracker = _ + var clock: ManualClock = _ + + override def beforeEach(): Unit = { + val sparkConf = new SparkConf() + sparkConf.set(BLACKLIST_TIMEOUT_CONF, BLACKLIST_TIMEOUT) + sparkConf.set(YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED, true) + sparkConf.set(MAX_FAILED_EXEC_PER_NODE, MAX_FAILED_EXEC_PER_NODE_VALUE) + clock = new ManualClock() + + amClientMock = mock(classOf[AMRMClient[ContainerRequest]]) + failureTracker = new FailureTracker(sparkConf, clock) + yarnBlacklistTracker = + new YarnAllocatorBlacklistTracker(sparkConf, amClientMock, failureTracker) + yarnBlacklistTracker.setNumClusterNodes(4) + super.beforeEach() + } + + test("expiring its own blacklisted nodes") { + (1 to MAX_FAILED_EXEC_PER_NODE_VALUE).foreach { + _ => { + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host")) + // host should not be blacklisted at these failures as MAX_FAILED_EXEC_PER_NODE is 2 + verify(amClientMock, never()) + .updateBlacklist(Arrays.asList("host"), Collections.emptyList()) + } + } + + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host")) + // the third failure on the host triggers the blacklisting + verify(amClientMock).updateBlacklist(Arrays.asList("host"), Collections.emptyList()) + + clock.advance(BLACKLIST_TIMEOUT) + + // trigger synchronisation of blacklisted nodes with YARN + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set()) + verify(amClientMock).updateBlacklist(Collections.emptyList(), Arrays.asList("host")) + } + + test("not handling the expiry of scheduler blacklisted nodes") { + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host1", "host2")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host1", "host2"), Collections.emptyList()) + + // advance timer more then host1, host2 expiry time + clock.advance(200L) + + // expired blacklisted nodes (simulating a resource request) + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host1", "host2")) + // no change is communicated to YARN regarding the blacklisting + verify(amClientMock).updateBlacklist(Collections.emptyList(), Collections.emptyList()) + } + + test("combining scheduler and allocation blacklist") { + (1 to MAX_FAILED_EXEC_PER_NODE_VALUE).foreach { + _ => { + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host1")) + // host1 should not be blacklisted at these failures as MAX_FAILED_EXEC_PER_NODE is 2 + verify(amClientMock, never()) + .updateBlacklist(Arrays.asList("host1"), Collections.emptyList()) + } + } + + // as this is the third failure on host1 the node will be blacklisted + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host1")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host1"), Collections.emptyList()) + + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host2", "host3")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host2", "host3"), Collections.emptyList()) + + clock.advance(10L) + + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host3", "host4")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host4"), Arrays.asList("host2")) + } + + test("blacklist all available nodes") { + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host1", "host2", "host3")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host1", "host2", "host3"), Collections.emptyList()) + + clock.advance(60L) + (1 to MAX_FAILED_EXEC_PER_NODE_VALUE).foreach { + _ => { + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host4")) + // host4 should not be blacklisted at these failures as MAX_FAILED_EXEC_PER_NODE is 2 + verify(amClientMock, never()) + .updateBlacklist(Arrays.asList("host4"), Collections.emptyList()) + } + } + + // the third failure on the host triggers the blacklisting + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host4")) + + verify(amClientMock).updateBlacklist(Arrays.asList("host4"), Collections.emptyList()) + assert(yarnBlacklistTracker.isAllNodeBlacklisted === true) + } +} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 525abb6f2b350..3f783baed110d 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -59,6 +59,8 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter var rmClient: AMRMClient[ContainerRequest] = _ + var clock: ManualClock = _ + var containerNum = 0 override def beforeEach() { @@ -66,6 +68,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter rmClient = AMRMClient.createAMRMClient() rmClient.init(conf) rmClient.start() + clock = new ManualClock() } override def afterEach() { @@ -101,7 +104,8 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter appAttemptId, new SecurityManager(sparkConf), Map(), - new MockResolver()) + new MockResolver(), + clock) } def createContainer(host: String): Container = { @@ -332,10 +336,14 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map(), Set("hostA")) verify(mockAmClient).updateBlacklist(Seq("hostA").asJava, Seq[String]().asJava) - handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map(), Set("hostA", "hostB")) + val blacklistedNodes = Set( + "hostA", + "hostB" + ) + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map(), blacklistedNodes) verify(mockAmClient).updateBlacklist(Seq("hostB").asJava, Seq[String]().asJava) - handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map(), Set()) + handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map(), Set.empty) verify(mockAmClient).updateBlacklist(Seq[String]().asJava, Seq("hostA", "hostB").asJava) } @@ -353,8 +361,6 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter test("window based failure executor counting") { sparkConf.set("spark.yarn.executor.failuresValidityInterval", "100s") val handler = createAllocator(4) - val clock = new ManualClock(0L) - handler.setClock(clock) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) From c8e909cd498b67b121fa920ceee7631c652dac38 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 21 Jun 2018 13:25:15 -0500 Subject: [PATCH 0995/2461] [SPARK-24589][CORE] Correctly identify tasks in output commit coordinator. When an output stage is retried, it's possible that tasks from the previous attempt are still running. In that case, there would be a new task for the same partition in the new attempt, and the coordinator would allow both tasks to commit their output since it did not keep track of stage attempts. The change adds more information to the stage state tracked by the coordinator, so that only one task is allowed to commit the output in the above case. The stage state in the coordinator is also maintained across stage retries, so that a stray speculative task from a previous stage attempt is not allowed to commit. This also removes some code added in SPARK-18113 that allowed for duplicate commit requests; with the RPC code used in Spark 2, that situation cannot happen, so there is no need to handle it. Author: Marcelo Vanzin Closes #21577 from vanzin/SPARK-24552. --- .../spark/mapred/SparkHadoopMapRedUtil.scala | 8 +- .../apache/spark/scheduler/DAGScheduler.scala | 23 ++-- .../scheduler/OutputCommitCoordinator.scala | 128 ++++++++++-------- .../OutputCommitCoordinatorSuite.scala | 116 +++++++++++----- .../datasources/v2/WriteToDataSourceV2.scala | 9 +- 5 files changed, 179 insertions(+), 105 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 764735dc4eae7..db8aff94ea1e1 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -69,9 +69,9 @@ object SparkHadoopMapRedUtil extends Logging { if (shouldCoordinateWithDriver) { val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator - val taskAttemptNumber = TaskContext.get().attemptNumber() - val stageId = TaskContext.get().stageId() - val canCommit = outputCommitCoordinator.canCommit(stageId, splitId, taskAttemptNumber) + val ctx = TaskContext.get() + val canCommit = outputCommitCoordinator.canCommit(ctx.stageId(), ctx.stageAttemptNumber(), + splitId, ctx.attemptNumber()) if (canCommit) { performCommit() @@ -81,7 +81,7 @@ object SparkHadoopMapRedUtil extends Logging { logInfo(message) // We need to abort the task so that the driver can reschedule new attempts, if necessary committer.abortTask(mrTaskContext) - throw new CommitDeniedException(message, stageId, splitId, taskAttemptNumber) + throw new CommitDeniedException(message, ctx.stageId(), splitId, ctx.attemptNumber()) } } else { // Speculation is disabled or a user has chosen to manually bypass the commit coordination diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 041eade82d3ca..f74425d73b392 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1171,6 +1171,7 @@ class DAGScheduler( outputCommitCoordinator.taskCompleted( stageId, + task.stageAttemptId, task.partitionId, event.taskInfo.attemptNumber, // this is a task attempt number event.reason) @@ -1330,23 +1331,24 @@ class DAGScheduler( s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + s"(attempt ${failedStage.latestInfo.attemptNumber}) running") } else { + failedStage.fetchFailedAttemptIds.add(task.stageAttemptId) + val shouldAbortStage = + failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts || + disallowStageRetryForTest + // It is likely that we receive multiple FetchFailed for a single stage (because we have // multiple tasks running concurrently on different executors). In that case, it is // possible the fetch failure has already been handled by the scheduler. if (runningStages.contains(failedStage)) { logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + s"due to a fetch failure from $mapStage (${mapStage.name})") - markStageAsFinished(failedStage, Some(failureMessage)) + markStageAsFinished(failedStage, errorMessage = Some(failureMessage), + willRetry = !shouldAbortStage) } else { logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " + s"longer running") } - failedStage.fetchFailedAttemptIds.add(task.stageAttemptId) - val shouldAbortStage = - failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts || - disallowStageRetryForTest - if (shouldAbortStage) { val abortMessage = if (disallowStageRetryForTest) { "Fetch failure will not retry stage due to testing config" @@ -1545,7 +1547,10 @@ class DAGScheduler( /** * Marks a stage as finished and removes it from the list of running stages. */ - private def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = { + private def markStageAsFinished( + stage: Stage, + errorMessage: Option[String] = None, + willRetry: Boolean = false): Unit = { val serviceTime = stage.latestInfo.submissionTime match { case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) case _ => "Unknown" @@ -1564,7 +1569,9 @@ class DAGScheduler( logInfo(s"$stage (${stage.name}) failed in $serviceTime s due to ${errorMessage.get}") } - outputCommitCoordinator.stageEnd(stage.id) + if (!willRetry) { + outputCommitCoordinator.stageEnd(stage.id) + } listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) runningStages -= stage } diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 83d87b548a430..b382d623806e2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -27,7 +27,11 @@ import org.apache.spark.util.{RpcUtils, ThreadUtils} private sealed trait OutputCommitCoordinationMessage extends Serializable private case object StopCoordinator extends OutputCommitCoordinationMessage -private case class AskPermissionToCommitOutput(stage: Int, partition: Int, attemptNumber: Int) +private case class AskPermissionToCommitOutput( + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int) /** * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins" @@ -45,13 +49,15 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) // Initialized by SparkEnv var coordinatorRef: Option[RpcEndpointRef] = None - private type StageId = Int - private type PartitionId = Int - private type TaskAttemptNumber = Int - private val NO_AUTHORIZED_COMMITTER: TaskAttemptNumber = -1 + // Class used to identify a committer. The task ID for a committer is implicitly defined by + // the partition being processed, but the coordinator needs to keep track of both the stage + // attempt and the task attempt, because in some situations the same task may be running + // concurrently in two different attempts of the same stage. + private case class TaskIdentifier(stageAttempt: Int, taskAttempt: Int) + private case class StageState(numPartitions: Int) { - val authorizedCommitters = Array.fill[TaskAttemptNumber](numPartitions)(NO_AUTHORIZED_COMMITTER) - val failures = mutable.Map[PartitionId, mutable.Set[TaskAttemptNumber]]() + val authorizedCommitters = Array.fill[TaskIdentifier](numPartitions)(null) + val failures = mutable.Map[Int, mutable.Set[TaskIdentifier]]() } /** @@ -64,7 +70,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance. */ - private val stageStates = mutable.Map[StageId, StageState]() + private val stageStates = mutable.Map[Int, StageState]() /** * Returns whether the OutputCommitCoordinator's internal data structures are all empty. @@ -87,10 +93,11 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * @return true if this task is authorized to commit, false otherwise */ def canCommit( - stage: StageId, - partition: PartitionId, - attemptNumber: TaskAttemptNumber): Boolean = { - val msg = AskPermissionToCommitOutput(stage, partition, attemptNumber) + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int): Boolean = { + val msg = AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber) coordinatorRef match { case Some(endpointRef) => ThreadUtils.awaitResult(endpointRef.ask[Boolean](msg), @@ -103,26 +110,35 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) } /** - * Called by the DAGScheduler when a stage starts. + * Called by the DAGScheduler when a stage starts. Initializes the stage's state if it hasn't + * yet been initialized. * * @param stage the stage id. * @param maxPartitionId the maximum partition id that could appear in this stage's tasks (i.e. * the maximum possible value of `context.partitionId`). */ - private[scheduler] def stageStart(stage: StageId, maxPartitionId: Int): Unit = synchronized { - stageStates(stage) = new StageState(maxPartitionId + 1) + private[scheduler] def stageStart(stage: Int, maxPartitionId: Int): Unit = synchronized { + stageStates.get(stage) match { + case Some(state) => + require(state.authorizedCommitters.length == maxPartitionId + 1) + logInfo(s"Reusing state from previous attempt of stage $stage.") + + case _ => + stageStates(stage) = new StageState(maxPartitionId + 1) + } } // Called by DAGScheduler - private[scheduler] def stageEnd(stage: StageId): Unit = synchronized { + private[scheduler] def stageEnd(stage: Int): Unit = synchronized { stageStates.remove(stage) } // Called by DAGScheduler private[scheduler] def taskCompleted( - stage: StageId, - partition: PartitionId, - attemptNumber: TaskAttemptNumber, + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int, reason: TaskEndReason): Unit = synchronized { val stageState = stageStates.getOrElse(stage, { logDebug(s"Ignoring task completion for completed stage") @@ -131,16 +147,17 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) reason match { case Success => // The task output has been committed successfully - case denied: TaskCommitDenied => - logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " + - s"attempt: $attemptNumber") - case otherReason => + case _: TaskCommitDenied => + logInfo(s"Task was denied committing, stage: $stage.$stageAttempt, " + + s"partition: $partition, attempt: $attemptNumber") + case _ => // Mark the attempt as failed to blacklist from future commit protocol - stageState.failures.getOrElseUpdate(partition, mutable.Set()) += attemptNumber - if (stageState.authorizedCommitters(partition) == attemptNumber) { + val taskId = TaskIdentifier(stageAttempt, attemptNumber) + stageState.failures.getOrElseUpdate(partition, mutable.Set()) += taskId + if (stageState.authorizedCommitters(partition) == taskId) { logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " + s"partition=$partition) failed; clearing lock") - stageState.authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER + stageState.authorizedCommitters(partition) = null } } } @@ -155,47 +172,41 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) // Marked private[scheduler] instead of private so this can be mocked in tests private[scheduler] def handleAskPermissionToCommit( - stage: StageId, - partition: PartitionId, - attemptNumber: TaskAttemptNumber): Boolean = synchronized { + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int): Boolean = synchronized { stageStates.get(stage) match { - case Some(state) if attemptFailed(state, partition, attemptNumber) => - logInfo(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage," + - s" partition=$partition as task attempt $attemptNumber has already failed.") + case Some(state) if attemptFailed(state, stageAttempt, partition, attemptNumber) => + logInfo(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " + + s"task attempt $attemptNumber already marked as failed.") false case Some(state) => - state.authorizedCommitters(partition) match { - case NO_AUTHORIZED_COMMITTER => - logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " + - s"partition=$partition") - state.authorizedCommitters(partition) = attemptNumber - true - case existingCommitter => - // Coordinator should be idempotent when receiving AskPermissionToCommit. - if (existingCommitter == attemptNumber) { - logWarning(s"Authorizing duplicate request to commit for " + - s"attemptNumber=$attemptNumber to commit for stage=$stage," + - s" partition=$partition; existingCommitter = $existingCommitter." + - s" This can indicate dropped network traffic.") - true - } else { - logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " + - s"partition=$partition; existingCommitter = $existingCommitter") - false - } + val existing = state.authorizedCommitters(partition) + if (existing == null) { + logDebug(s"Commit allowed for stage=$stage.$stageAttempt, partition=$partition, " + + s"task attempt $attemptNumber") + state.authorizedCommitters(partition) = TaskIdentifier(stageAttempt, attemptNumber) + true + } else { + logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " + + s"already committed by $existing") + false } case None => - logDebug(s"Stage $stage has completed, so not allowing" + - s" attempt number $attemptNumber of partition $partition to commit") + logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " + + "stage already marked as completed.") false } } private def attemptFailed( stageState: StageState, - partition: PartitionId, - attempt: TaskAttemptNumber): Boolean = synchronized { - stageState.failures.get(partition).exists(_.contains(attempt)) + stageAttempt: Int, + partition: Int, + attempt: Int): Boolean = synchronized { + val failInfo = TaskIdentifier(stageAttempt, attempt) + stageState.failures.get(partition).exists(_.contains(failInfo)) } } @@ -215,9 +226,10 @@ private[spark] object OutputCommitCoordinator { } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case AskPermissionToCommitOutput(stage, partition, attemptNumber) => + case AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber) => context.reply( - outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, attemptNumber)) + outputCommitCoordinator.handleAskPermissionToCommit(stage, stageAttempt, partition, + attemptNumber)) } } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 03b1903902491..158c9eb75f2b6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -35,6 +35,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.internal.io.{FileCommitProtocol, HadoopMapRedCommitProtocol, SparkHadoopWriterUtils} import org.apache.spark.rdd.{FakeOutputCommitter, RDD} +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.{ThreadUtils, Utils} /** @@ -153,7 +154,7 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Job should not complete if all commits are denied") { // Create a mock OutputCommitCoordinator that denies all attempts to commit doReturn(false).when(outputCommitCoordinator).handleAskPermissionToCommit( - Matchers.any(), Matchers.any(), Matchers.any()) + Matchers.any(), Matchers.any(), Matchers.any(), Matchers.any()) val rdd: RDD[Int] = sc.parallelize(Seq(1), 1) def resultHandler(x: Int, y: Unit): Unit = {} val futureAction: SimpleFutureAction[Unit] = sc.submitJob[Int, Unit, Unit](rdd, @@ -169,45 +170,106 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Only authorized committer failures can clear the authorized committer lock (SPARK-6614)") { val stage: Int = 1 + val stageAttempt: Int = 1 val partition: Int = 2 val authorizedCommitter: Int = 3 val nonAuthorizedCommitter: Int = 100 outputCommitCoordinator.stageStart(stage, maxPartitionId = 2) - assert(outputCommitCoordinator.canCommit(stage, partition, authorizedCommitter)) - assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter)) + assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, authorizedCommitter)) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter)) // The non-authorized committer fails - outputCommitCoordinator.taskCompleted( - stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled("test")) + outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition, + attemptNumber = nonAuthorizedCommitter, reason = TaskKilled("test")) // New tasks should still not be able to commit because the authorized committer has not failed - assert( - !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 1)) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter + 1)) // The authorized committer now fails, clearing the lock - outputCommitCoordinator.taskCompleted( - stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled("test")) + outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition, + attemptNumber = authorizedCommitter, reason = TaskKilled("test")) // A new task should now be allowed to become the authorized committer - assert( - outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 2)) + assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter + 2)) // There can only be one authorized committer - assert( - !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 3)) - } - - test("Duplicate calls to canCommit from the authorized committer gets idempotent responses.") { - val rdd = sc.parallelize(Seq(1), 1) - sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).callCanCommitMultipleTimes _, - 0 until rdd.partitions.size) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter + 3)) } test("SPARK-19631: Do not allow failed attempts to be authorized for committing") { val stage: Int = 1 + val stageAttempt: Int = 1 val partition: Int = 1 val failedAttempt: Int = 0 outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) - outputCommitCoordinator.taskCompleted(stage, partition, attemptNumber = failedAttempt, + outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition, + attemptNumber = failedAttempt, reason = ExecutorLostFailure("0", exitCausedByApp = true, None)) - assert(!outputCommitCoordinator.canCommit(stage, partition, failedAttempt)) - assert(outputCommitCoordinator.canCommit(stage, partition, failedAttempt + 1)) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, failedAttempt)) + assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, failedAttempt + 1)) + } + + test("SPARK-24589: Differentiate tasks from different stage attempts") { + var stage = 1 + val taskAttempt = 1 + val partition = 1 + + outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) + assert(outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt)) + assert(!outputCommitCoordinator.canCommit(stage, 2, partition, taskAttempt)) + + // Fail the task in the first attempt, the task in the second attempt should succeed. + stage += 1 + outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) + outputCommitCoordinator.taskCompleted(stage, 1, partition, taskAttempt, + ExecutorLostFailure("0", exitCausedByApp = true, None)) + assert(!outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt)) + assert(outputCommitCoordinator.canCommit(stage, 2, partition, taskAttempt)) + + // Commit the 1st attempt, fail the 2nd attempt, make sure 3rd attempt cannot commit, + // then fail the 1st attempt and make sure the 4th one can commit again. + stage += 1 + outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) + assert(outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt)) + outputCommitCoordinator.taskCompleted(stage, 2, partition, taskAttempt, + ExecutorLostFailure("0", exitCausedByApp = true, None)) + assert(!outputCommitCoordinator.canCommit(stage, 3, partition, taskAttempt)) + outputCommitCoordinator.taskCompleted(stage, 1, partition, taskAttempt, + ExecutorLostFailure("0", exitCausedByApp = true, None)) + assert(outputCommitCoordinator.canCommit(stage, 4, partition, taskAttempt)) + } + + test("SPARK-24589: Make sure stage state is cleaned up") { + // Normal application without stage failures. + sc.parallelize(1 to 100, 100) + .map { i => (i % 10, i) } + .reduceByKey(_ + _) + .collect() + + assert(sc.dagScheduler.outputCommitCoordinator.isEmpty) + + // Force failures in a few tasks so that a stage is retried. Collect the ID of the failing + // stage so that we can check the state of the output committer. + val retriedStage = sc.parallelize(1 to 100, 10) + .map { i => (i % 10, i) } + .reduceByKey { case (_, _) => + val ctx = TaskContext.get() + if (ctx.stageAttemptNumber() == 0) { + throw new FetchFailedException(SparkEnv.get.blockManager.blockManagerId, 1, 1, 1, + new Exception("Failure for test.")) + } else { + ctx.stageId() + } + } + .collect() + .map { case (k, v) => v } + .toSet + + assert(retriedStage.size === 1) + assert(sc.dagScheduler.outputCommitCoordinator.isEmpty) + verify(sc.env.outputCommitCoordinator, times(2)) + .stageStart(Matchers.eq(retriedStage.head), Matchers.any()) + verify(sc.env.outputCommitCoordinator).stageEnd(Matchers.eq(retriedStage.head)) } } @@ -243,16 +305,6 @@ private case class OutputCommitFunctions(tempDirPath: String) { if (ctx.attemptNumber == 0) failingOutputCommitter else successfulOutputCommitter) } - // Receiver should be idempotent for AskPermissionToCommitOutput - def callCanCommitMultipleTimes(iter: Iterator[Int]): Unit = { - val ctx = TaskContext.get() - val canCommit1 = SparkEnv.get.outputCommitCoordinator - .canCommit(ctx.stageId(), ctx.partitionId(), ctx.attemptNumber()) - val canCommit2 = SparkEnv.get.outputCommitCoordinator - .canCommit(ctx.stageId(), ctx.partitionId(), ctx.attemptNumber()) - assert(canCommit1 && canCommit2) - } - private def runCommitWithProvidedCommitter( ctx: TaskContext, iter: Iterator[Int], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index ea4bda327f36f..11ed7131e7e3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -109,6 +109,7 @@ object DataWritingSparkTask extends Logging { iter: Iterator[InternalRow], useCommitCoordinator: Boolean): WriterCommitMessage = { val stageId = context.stageId() + val stageAttempt = context.stageAttemptNumber() val partId = context.partitionId() val attemptId = context.attemptNumber() val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0") @@ -122,12 +123,14 @@ object DataWritingSparkTask extends Logging { val msg = if (useCommitCoordinator) { val coordinator = SparkEnv.get.outputCommitCoordinator - val commitAuthorized = coordinator.canCommit(context.stageId(), partId, attemptId) + val commitAuthorized = coordinator.canCommit(stageId, stageAttempt, partId, attemptId) if (commitAuthorized) { - logInfo(s"Writer for stage $stageId, task $partId.$attemptId is authorized to commit.") + logInfo(s"Writer for stage $stageId / $stageAttempt, " + + s"task $partId.$attemptId is authorized to commit.") dataWriter.commit() } else { - val message = s"Stage $stageId, task $partId.$attemptId: driver did not authorize commit" + val message = s"Stage $stageId / $stageAttempt, " + + s"task $partId.$attemptId: driver did not authorize commit" logInfo(message) // throwing CommitDeniedException will trigger the catch block for abort throw new CommitDeniedException(message, stageId, partId, attemptId) From b9a6f7499a7f4d95efaf6a753c480f3642e22187 Mon Sep 17 00:00:00 2001 From: Maryann Xue Date: Thu, 21 Jun 2018 11:45:30 -0700 Subject: [PATCH 0996/2461] [SPARK-24613][SQL] Cache with UDF could not be matched with subsequent dependent caches ## What changes were proposed in this pull request? Wrap the logical plan with a `AnalysisBarrier` for execution plan compilation in CacheManager, in order to avoid the plan being analyzed again. ## How was this patch tested? Add one test in `DatasetCacheSuite` Author: Maryann Xue Closes #21602 from maryannxue/cache-mismatch. --- .../spark/sql/execution/CacheManager.scala | 6 +++--- .../org/apache/spark/sql/DatasetCacheSuite.scala | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 93bf91e56f1bd..2db7c02e86014 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.expressions.SubqueryExpression -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ResolvedHint} +import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, LogicalPlan, ResolvedHint} import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.storage.StorageLevel @@ -97,7 +97,7 @@ class CacheManager extends Logging { val inMemoryRelation = InMemoryRelation( sparkSession.sessionState.conf.useCompression, sparkSession.sessionState.conf.columnBatchSize, storageLevel, - sparkSession.sessionState.executePlan(planToCache).executedPlan, + sparkSession.sessionState.executePlan(AnalysisBarrier(planToCache)).executedPlan, tableName, planToCache) cachedData.add(CachedData(planToCache, inMemoryRelation)) @@ -142,7 +142,7 @@ class CacheManager extends Logging { // Remove the cache entry before we create a new one, so that we can have a different // physical plan. it.remove() - val plan = spark.sessionState.executePlan(cd.plan).executedPlan + val plan = spark.sessionState.executePlan(AnalysisBarrier(cd.plan)).executedPlan val newCache = InMemoryRelation( cacheBuilder = cd.cachedRepresentation .cacheBuilder.copy(cachedPlan = plan)(_cachedColumnBuffers = null), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 82a93f74dd76c..c4f056334cd1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.scalatest.concurrent.TimeLimits import org.scalatest.time.SpanSugar._ +import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.storage.StorageLevel @@ -132,4 +133,19 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits df.unpersist() assert(df.storageLevel == StorageLevel.NONE) } + + test("SPARK-24613 Cache with UDF could not be matched with subsequent dependent caches") { + val udf1 = udf({x: Int => x + 1}) + val df = spark.range(0, 10).toDF("a").withColumn("b", udf1($"a")) + val df2 = df.agg(sum(df("b"))) + + df.cache() + df.count() + df2.cache() + + val plan = df2.queryExecution.withCachedData + assert(plan.isInstanceOf[InMemoryRelation]) + val internalPlan = plan.asInstanceOf[InMemoryRelation].cacheBuilder.cachedPlan + assert(internalPlan.find(_.isInstanceOf[InMemoryTableScanExec]).isDefined) + } } From dc8a6befa5dad861a731b4d7865f3ccf37482ae0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 21 Jun 2018 15:38:46 -0700 Subject: [PATCH 0997/2461] [SPARK-24588][SS] streaming join should require HashClusteredPartitioning from children ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/19080 we simplified the distribution/partitioning framework, and make all the join-like operators require `HashClusteredDistribution` from children. Unfortunately streaming join operator was missed. This can cause wrong result. Think about ``` val input1 = MemoryStream[Int] val input2 = MemoryStream[Int] val df1 = input1.toDF.select('value as 'a, 'value * 2 as 'b) val df2 = input2.toDF.select('value as 'a, 'value * 2 as 'b).repartition('b) val joined = df1.join(df2, Seq("a", "b")).select('a) ``` The physical plan is ``` *(3) Project [a#5] +- StreamingSymmetricHashJoin [a#5, b#6], [a#10, b#11], Inner, condition = [ leftOnly = null, rightOnly = null, both = null, full = null ], state info [ checkpoint = , runId = 54e31fce-f055-4686-b75d-fcd2b076f8d8, opId = 0, ver = 0, numPartitions = 5], 0, state cleanup [ left = null, right = null ] :- Exchange hashpartitioning(a#5, b#6, 5) : +- *(1) Project [value#1 AS a#5, (value#1 * 2) AS b#6] : +- StreamingRelation MemoryStream[value#1], [value#1] +- Exchange hashpartitioning(b#11, 5) +- *(2) Project [value#3 AS a#10, (value#3 * 2) AS b#11] +- StreamingRelation MemoryStream[value#3], [value#3] ``` The left table is hash partitioned by `a, b`, while the right table is hash partitioned by `b`. This means, we may have a matching record that is in different partitions, which should be in the output but not. ## How was this patch tested? N/A Author: Wenchen Fan Closes #21587 from cloud-fan/join. --- .../plans/physical/partitioning.scala | 53 +++-- .../sql/catalyst/DistributionSuite.scala | 201 +++++++++++++++--- .../v2/DataSourcePartitioning.scala | 4 +- .../StreamingSymmetricHashJoinExec.scala | 4 +- .../sql/streaming/StreamingJoinSuite.scala | 14 ++ 5 files changed, 217 insertions(+), 59 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 4d9a9925fe3ff..cc1a5e835d9cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -99,16 +99,19 @@ case class ClusteredDistribution( * This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the * number of partitions, this distribution strictly requires which partition the tuple should be in. */ -case class HashClusteredDistribution(expressions: Seq[Expression]) extends Distribution { +case class HashClusteredDistribution( + expressions: Seq[Expression], + requiredNumPartitions: Option[Int] = None) extends Distribution { require( expressions != Nil, - "The expressions for hash of a HashPartitionedDistribution should not be Nil. " + + "The expressions for hash of a HashClusteredDistribution should not be Nil. " + "An AllTuples should be used to represent a distribution that only has " + "a single partition.") - override def requiredNumPartitions: Option[Int] = None - override def createPartitioning(numPartitions: Int): Partitioning = { + assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions, + s"This HashClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " + + s"the actual number of partitions is $numPartitions.") HashPartitioning(expressions, numPartitions) } } @@ -163,11 +166,22 @@ trait Partitioning { * i.e. the current dataset does not need to be re-partitioned for the `required` * Distribution (it is possible that tuples within a partition need to be reorganized). * + * A [[Partitioning]] can never satisfy a [[Distribution]] if its `numPartitions` does't match + * [[Distribution.requiredNumPartitions]]. + */ + final def satisfies(required: Distribution): Boolean = { + required.requiredNumPartitions.forall(_ == numPartitions) && satisfies0(required) + } + + /** + * The actual method that defines whether this [[Partitioning]] can satisfy the given + * [[Distribution]], after the `numPartitions` check. + * * By default a [[Partitioning]] can satisfy [[UnspecifiedDistribution]], and [[AllTuples]] if - * the [[Partitioning]] only have one partition. Implementations can overwrite this method with - * special logic. + * the [[Partitioning]] only have one partition. Implementations can also overwrite this method + * with special logic. */ - def satisfies(required: Distribution): Boolean = required match { + protected def satisfies0(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case AllTuples => numPartitions == 1 case _ => false @@ -186,9 +200,8 @@ case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning case object SinglePartition extends Partitioning { val numPartitions = 1 - override def satisfies(required: Distribution): Boolean = required match { + override def satisfies0(required: Distribution): Boolean = required match { case _: BroadcastDistribution => false - case ClusteredDistribution(_, Some(requiredNumPartitions)) => requiredNumPartitions == 1 case _ => true } } @@ -205,16 +218,15 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = { - super.satisfies(required) || { + override def satisfies0(required: Distribution): Boolean = { + super.satisfies0(required) || { required match { case h: HashClusteredDistribution => expressions.length == h.expressions.length && expressions.zip(h.expressions).forall { case (l, r) => l.semanticEquals(r) } - case ClusteredDistribution(requiredClustering, requiredNumPartitions) => - expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) && - (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) + case ClusteredDistribution(requiredClustering, _) => + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) case _ => false } } @@ -246,15 +258,14 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = { - super.satisfies(required) || { + override def satisfies0(required: Distribution): Boolean = { + super.satisfies0(required) || { required match { case OrderedDistribution(requiredOrdering) => val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering, requiredNumPartitions) => - ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) && - (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) + case ClusteredDistribution(requiredClustering, _) => + ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) case _ => false } } @@ -295,7 +306,7 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) * Returns true if any `partitioning` of this collection satisfies the given * [[Distribution]]. */ - override def satisfies(required: Distribution): Boolean = + override def satisfies0(required: Distribution): Boolean = partitionings.exists(_.satisfies(required)) override def toString: String = { @@ -310,7 +321,7 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning { override val numPartitions: Int = 1 - override def satisfies(required: Distribution): Boolean = required match { + override def satisfies0(required: Distribution): Boolean = required match { case BroadcastDistribution(m) if m == mode => true case _ => false } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index b47b8adfe5d55..39228102682b9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -41,34 +41,127 @@ class DistributionSuite extends SparkFunSuite { } } - test("HashPartitioning (with nullSafe = true) is the output partitioning") { - // Cases which do not need an exchange between two data properties. + test("UnspecifiedDistribution and AllTuples") { + // except `BroadcastPartitioning`, all other partitioning can satisfy UnspecifiedDistribution checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), + UnknownPartitioning(-1), UnspecifiedDistribution, true) checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), - ClusteredDistribution(Seq('a, 'b, 'c)), + RoundRobinPartitioning(10), + UnspecifiedDistribution, true) checkSatisfied( - HashPartitioning(Seq('b, 'c), 10), - ClusteredDistribution(Seq('a, 'b, 'c)), + SinglePartition, + UnspecifiedDistribution, + true) + + checkSatisfied( + HashPartitioning(Seq('a), 10), + UnspecifiedDistribution, + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc), 10), + UnspecifiedDistribution, + true) + + checkSatisfied( + BroadcastPartitioning(IdentityBroadcastMode), + UnspecifiedDistribution, + false) + + // except `BroadcastPartitioning`, all other partitioning can satisfy AllTuples if they have + // only one partition. + checkSatisfied( + UnknownPartitioning(1), + AllTuples, + true) + + checkSatisfied( + UnknownPartitioning(10), + AllTuples, + false) + + checkSatisfied( + RoundRobinPartitioning(1), + AllTuples, + true) + + checkSatisfied( + RoundRobinPartitioning(10), + AllTuples, + false) + + checkSatisfied( + SinglePartition, + AllTuples, + true) + + checkSatisfied( + HashPartitioning(Seq('a), 1), + AllTuples, true) + checkSatisfied( + HashPartitioning(Seq('a), 10), + AllTuples, + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc), 1), + AllTuples, + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc), 10), + AllTuples, + false) + + checkSatisfied( + BroadcastPartitioning(IdentityBroadcastMode), + AllTuples, + false) + } + + test("SinglePartition is the output partitioning") { + // SinglePartition can satisfy all the distributions except `BroadcastDistribution` checkSatisfied( SinglePartition, ClusteredDistribution(Seq('a, 'b, 'c)), true) + checkSatisfied( + SinglePartition, + HashClusteredDistribution(Seq('a, 'b, 'c)), + true) + checkSatisfied( SinglePartition, OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), true) - // Cases which need an exchange between two data properties. + checkSatisfied( + SinglePartition, + BroadcastDistribution(IdentityBroadcastMode), + false) + } + + test("HashPartitioning is the output partitioning") { + // HashPartitioning can satisfy ClusteredDistribution iff its hash expressions are a subset of + // the required clustering expressions. + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + HashPartitioning(Seq('b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c)), + true) + checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), ClusteredDistribution(Seq('b, 'c)), @@ -79,37 +172,43 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('d, 'e)), false) + // HashPartitioning can satisfy HashClusteredDistribution iff its hash expressions are exactly + // same with the required hash clustering expressions. checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), - AllTuples, + HashClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + HashPartitioning(Seq('c, 'b, 'a), 10), + HashClusteredDistribution(Seq('a, 'b, 'c)), false) + checkSatisfied( + HashPartitioning(Seq('a, 'b), 10), + HashClusteredDistribution(Seq('a, 'b, 'c)), + false) + + // HashPartitioning cannot satisfy OrderedDistribution checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), false) checkSatisfied( - HashPartitioning(Seq('b, 'c), 10), + HashPartitioning(Seq('a, 'b, 'c), 1), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), - false) + false) // TODO: this can be relaxed. - // TODO: We should check functional dependencies - /* checkSatisfied( - ClusteredDistribution(Seq('b)), - ClusteredDistribution(Seq('b + 1)), - true) - */ + HashPartitioning(Seq('b, 'c), 10), + OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), + false) } test("RangePartitioning is the output partitioning") { - // Cases which do not need an exchange between two data properties. - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - UnspecifiedDistribution, - true) - + // RangePartitioning can satisfy OrderedDistribution iff its ordering is a prefix + // of the required ordering, or the required ordering is a prefix of its ordering. checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), @@ -125,6 +224,27 @@ class DistributionSuite extends SparkFunSuite { OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc, 'd.desc)), true) + // TODO: We can have an optimization to first sort the dataset + // by a.asc and then sort b, and c in a partition. This optimization + // should tradeoff the benefit of a less number of Exchange operators + // and the parallelism. + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('a.asc, 'b.desc, 'c.asc)), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('b.asc, 'a.asc)), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('a.asc, 'b.asc, 'd.desc)), + false) + + // RangePartitioning can satisfy ClusteredDistribution iff its ordering expressions are a subset + // of the required clustering expressions. checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), ClusteredDistribution(Seq('a, 'b, 'c)), @@ -140,34 +260,47 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('b, 'c, 'a, 'd)), true) - // Cases which need an exchange between two data properties. - // TODO: We can have an optimization to first sort the dataset - // by a.asc and then sort b, and c in a partition. This optimization - // should tradeoff the benefit of a less number of Exchange operators - // and the parallelism. checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - OrderedDistribution(Seq('a.asc, 'b.desc, 'c.asc)), + ClusteredDistribution(Seq('a, 'b)), false) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - OrderedDistribution(Seq('b.asc, 'a.asc)), + ClusteredDistribution(Seq('c, 'd)), false) + // RangePartitioning cannot satisfy HashClusteredDistribution checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('a, 'b)), + HashClusteredDistribution(Seq('a, 'b, 'c)), false) + } + test("Partitioning.numPartitions must match Distribution.requiredNumPartitions to satisfy it") { checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('c, 'd)), + SinglePartition, + ClusteredDistribution(Seq('a, 'b, 'c), Some(10)), + false) + + checkSatisfied( + SinglePartition, + HashClusteredDistribution(Seq('a, 'b, 'c), Some(10)), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c), Some(5)), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + HashClusteredDistribution(Seq('a, 'b, 'c), Some(5)), false) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - AllTuples, + ClusteredDistribution(Seq('a, 'b, 'c), Some(5)), false) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala index 017a6737161a6..33079d5912506 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala @@ -30,8 +30,8 @@ class DataSourcePartitioning( override val numPartitions: Int = partitioning.numPartitions() - override def satisfies(required: physical.Distribution): Boolean = { - super.satisfies(required) || { + override def satisfies0(required: physical.Distribution): Boolean = { + super.satisfies0(required) || { required match { case d: physical.ClusteredDistribution if isCandidate(d.clustering) => val attrs = d.clustering.map(_.asInstanceOf[Attribute]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index afa664eb76525..50cf971e4ec3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -167,8 +167,8 @@ case class StreamingSymmetricHashJoinExec( val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: - ClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil + HashClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: + HashClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil override def output: Seq[Attribute] = joinType match { case _: InnerLike => left.output ++ right.output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 1f62357e6d09e..c5cc8df4356a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -404,6 +404,20 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(input3, 5, 10), CheckNewAnswer((5, 10, 5, 15, 5, 25))) } + + test("streaming join should require HashClusteredDistribution from children") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val df1 = input1.toDF.select('value as 'a, 'value * 2 as 'b) + val df2 = input2.toDF.select('value as 'a, 'value * 2 as 'b).repartition('b) + val joined = df1.join(df2, Seq("a", "b")).select('a) + + testStream(joined)( + AddData(input1, 1.to(1000): _*), + AddData(input2, 1.to(1000): _*), + CheckAnswer(1.to(1000): _*)) + } } From 92c2f00bd275a90b6912fb8c8cf542002923629c Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Fri, 22 Jun 2018 16:18:22 +0900 Subject: [PATCH 0998/2461] [SPARK-23934][SQL] Adding map_from_entries function ## What changes were proposed in this pull request? The PR adds the `map_from_entries` function that returns a map created from the given array of entries. ## How was this patch tested? New tests added into: - `CollectionExpressionSuite` - `DataFrameFunctionSuite` ## CodeGen Examples ### Primitive-type Keys and Values ``` val idf = Seq( Seq((1, 10), (2, 20), (3, 10)), Seq((1, 10), null, (2, 20)) ).toDF("a") idf.filter('a.isNotNull).select(map_from_entries('a)).debugCodegen ``` Result: ``` /* 042 */ boolean project_isNull_0 = false; /* 043 */ MapData project_value_0 = null; /* 044 */ /* 045 */ for (int project_idx_2 = 0; !project_isNull_0 && project_idx_2 < inputadapter_value_0.numElements(); project_idx_2++) { /* 046 */ project_isNull_0 |= inputadapter_value_0.isNullAt(project_idx_2); /* 047 */ } /* 048 */ if (!project_isNull_0) { /* 049 */ final int project_numEntries_0 = inputadapter_value_0.numElements(); /* 050 */ /* 051 */ final long project_keySectionSize_0 = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(project_numEntries_0, 4); /* 052 */ final long project_valueSectionSize_0 = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(project_numEntries_0, 4); /* 053 */ final long project_byteArraySize_0 = 8 + project_keySectionSize_0 + project_valueSectionSize_0; /* 054 */ if (project_byteArraySize_0 > 2147483632) { /* 055 */ final Object[] project_keys_0 = new Object[project_numEntries_0]; /* 056 */ final Object[] project_values_0 = new Object[project_numEntries_0]; /* 057 */ /* 058 */ for (int project_idx_1 = 0; project_idx_1 < project_numEntries_0; project_idx_1++) { /* 059 */ InternalRow project_entry_1 = inputadapter_value_0.getStruct(project_idx_1, 2); /* 060 */ /* 061 */ project_keys_0[project_idx_1] = project_entry_1.getInt(0); /* 062 */ project_values_0[project_idx_1] = project_entry_1.getInt(1); /* 063 */ } /* 064 */ /* 065 */ project_value_0 = org.apache.spark.sql.catalyst.util.ArrayBasedMapData.apply(project_keys_0, project_values_0); /* 066 */ /* 067 */ } else { /* 068 */ final byte[] project_byteArray_0 = new byte[(int)project_byteArraySize_0]; /* 069 */ UnsafeMapData project_unsafeMapData_0 = new UnsafeMapData(); /* 070 */ Platform.putLong(project_byteArray_0, 16, project_keySectionSize_0); /* 071 */ Platform.putLong(project_byteArray_0, 24, project_numEntries_0); /* 072 */ Platform.putLong(project_byteArray_0, 24 + project_keySectionSize_0, project_numEntries_0); /* 073 */ project_unsafeMapData_0.pointTo(project_byteArray_0, 16, (int)project_byteArraySize_0); /* 074 */ ArrayData project_keyArrayData_0 = project_unsafeMapData_0.keyArray(); /* 075 */ ArrayData project_valueArrayData_0 = project_unsafeMapData_0.valueArray(); /* 076 */ /* 077 */ for (int project_idx_0 = 0; project_idx_0 < project_numEntries_0; project_idx_0++) { /* 078 */ InternalRow project_entry_0 = inputadapter_value_0.getStruct(project_idx_0, 2); /* 079 */ /* 080 */ project_keyArrayData_0.setInt(project_idx_0, project_entry_0.getInt(0)); /* 081 */ project_valueArrayData_0.setInt(project_idx_0, project_entry_0.getInt(1)); /* 082 */ } /* 083 */ /* 084 */ project_value_0 = project_unsafeMapData_0; /* 085 */ } /* 086 */ /* 087 */ } ``` ### Non-primitive-type Keys and Values ``` val sdf = Seq( Seq(("a", null), ("b", "bb"), ("c", "aa")), Seq(("a", "aa"), null, (null, "bb")) ).toDF("a") sdf.filter('a.isNotNull).select(map_from_entries('a)).debugCodegen ``` Result: ``` /* 042 */ boolean project_isNull_0 = false; /* 043 */ MapData project_value_0 = null; /* 044 */ /* 045 */ for (int project_idx_1 = 0; !project_isNull_0 && project_idx_1 < inputadapter_value_0.numElements(); project_idx_1++) { /* 046 */ project_isNull_0 |= inputadapter_value_0.isNullAt(project_idx_1); /* 047 */ } /* 048 */ if (!project_isNull_0) { /* 049 */ final int project_numEntries_0 = inputadapter_value_0.numElements(); /* 050 */ /* 051 */ final Object[] project_keys_0 = new Object[project_numEntries_0]; /* 052 */ final Object[] project_values_0 = new Object[project_numEntries_0]; /* 053 */ /* 054 */ for (int project_idx_0 = 0; project_idx_0 < project_numEntries_0; project_idx_0++) { /* 055 */ InternalRow project_entry_0 = inputadapter_value_0.getStruct(project_idx_0, 2); /* 056 */ /* 057 */ if (project_entry_0.isNullAt(0)) { /* 058 */ throw new RuntimeException("The first field from a struct (key) can't be null."); /* 059 */ } /* 060 */ /* 061 */ project_keys_0[project_idx_0] = project_entry_0.getUTF8String(0); /* 062 */ project_values_0[project_idx_0] = project_entry_0.getUTF8String(1); /* 063 */ } /* 064 */ /* 065 */ project_value_0 = org.apache.spark.sql.catalyst.util.ArrayBasedMapData.apply(project_keys_0, project_values_0); /* 066 */ /* 067 */ } ``` Author: Marek Novotny Closes #21282 from mn-mikke/feature/array-api-map_from_entries-to-master. --- python/pyspark/sql/functions.py | 20 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/codegen/CodeGenerator.scala | 30 +++ .../expressions/collectionOperations.scala | 235 ++++++++++++++++-- .../CollectionExpressionsSuite.scala | 51 ++++ .../org/apache/spark/sql/functions.scala | 7 + .../spark/sql/DataFrameFunctionsSuite.scala | 50 ++++ 7 files changed, 378 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 11b179fe26bfc..5f5d73307e9c8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2412,6 +2412,26 @@ def map_entries(col): return Column(sc._jvm.functions.map_entries(_to_java_column(col))) +@since(2.4) +def map_from_entries(col): + """ + Collection function: Returns a map created from the given array of entries. + + :param col: name of column or expression + + >>> from pyspark.sql.functions import map_from_entries + >>> df = spark.sql("SELECT array(struct(1, 'a'), struct(2, 'b')) as data") + >>> df.select(map_from_entries("data").alias("map")).show() + +----------------+ + | map| + +----------------+ + |[1 -> a, 2 -> b]| + +----------------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_from_entries(_to_java_column(col))) + + @ignore_unicode_prefix @since(2.4) def array_repeat(col, count): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 4b09b9a7e75df..8abc616c1a3f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -421,6 +421,7 @@ object FunctionRegistry { expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[MapEntries]("map_entries"), + expression[MapFromEntries]("map_from_entries"), expression[Size]("size"), expression[Slice]("slice"), expression[Size]("cardinality"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 66315e5906253..4cc0968911cb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -819,6 +819,36 @@ class CodegenContext { } } + /** + * Generates code to do null safe execution when accessing properties of complex + * ArrayData elements. + * + * @param nullElements used to decide whether the ArrayData might contain null or not. + * @param isNull a variable indicating whether the result will be evaluated to null or not. + * @param arrayData a variable name representing the ArrayData. + * @param execute the code that should be executed only if the ArrayData doesn't contain + * any null. + */ + def nullArrayElementsSaveExec( + nullElements: Boolean, + isNull: String, + arrayData: String)( + execute: String): String = { + val i = freshName("idx") + if (nullElements) { + s""" + |for (int $i = 0; !$isNull && $i < $arrayData.numElements(); $i++) { + | $isNull |= $arrayData.isNullAt($i); + |} + |if (!$isNull) { + | $execute + |} + """.stripMargin + } else { + execute + } + } + /** * Splits the generated code of expressions into multiple functions, because function has * 64kb code size limit in JVM. If the class to which the function would be inlined would grow diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 7c064a130ff35..3afabe14606e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -475,6 +475,223 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp override def prettyName: String = "map_entries" } +/** + * Returns a map created from the given array of entries. + */ +@ExpressionDescription( + usage = "_FUNC_(arrayOfEntries) - Returns a map created from the given array of entries.", + examples = """ + Examples: + > SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b'))); + {1:"a",2:"b"} + """, + since = "2.4.0") +case class MapFromEntries(child: Expression) extends UnaryExpression { + + @transient + private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match { + case ArrayType( + StructType(Array( + StructField(_, keyType, keyNullable, _), + StructField(_, valueType, valueNullable, _))), + containsNull) => Some((MapType(keyType, valueType, valueNullable), keyNullable, containsNull)) + case _ => None + } + + private def nullEntries: Boolean = dataTypeDetails.get._3 + + override def nullable: Boolean = child.nullable || nullEntries + + override def dataType: MapType = dataTypeDetails.get._1 + + override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { + case Some(_) => TypeCheckResult.TypeCheckSuccess + case None => TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " + + s"${child.dataType.simpleString} type. $prettyName accepts only arrays of pair structs.") + } + + override protected def nullSafeEval(input: Any): Any = { + val arrayData = input.asInstanceOf[ArrayData] + val numEntries = arrayData.numElements() + var i = 0 + if(nullEntries) { + while (i < numEntries) { + if (arrayData.isNullAt(i)) return null + i += 1 + } + } + val keyArray = new Array[AnyRef](numEntries) + val valueArray = new Array[AnyRef](numEntries) + i = 0 + while (i < numEntries) { + val entry = arrayData.getStruct(i, 2) + val key = entry.get(0, dataType.keyType) + if (key == null) { + throw new RuntimeException("The first field from a struct (key) can't be null.") + } + keyArray.update(i, key) + val value = entry.get(1, dataType.valueType) + valueArray.update(i, value) + i += 1 + } + ArrayBasedMapData(keyArray, valueArray) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val numEntries = ctx.freshName("numEntries") + val isKeyPrimitive = CodeGenerator.isPrimitiveType(dataType.keyType) + val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) + val code = if (isKeyPrimitive && isValuePrimitive) { + genCodeForPrimitiveElements(ctx, c, ev.value, numEntries) + } else { + genCodeForAnyElements(ctx, c, ev.value, numEntries) + } + ctx.nullArrayElementsSaveExec(nullEntries, ev.isNull, c) { + s""" + |final int $numEntries = $c.numElements(); + |$code + """.stripMargin + } + }) + } + + private def genCodeForAssignmentLoop( + ctx: CodegenContext, + childVariable: String, + mapData: String, + numEntries: String, + keyAssignment: (String, String) => String, + valueAssignment: (String, String) => String): String = { + val entry = ctx.freshName("entry") + val i = ctx.freshName("idx") + + val nullKeyCheck = if (dataTypeDetails.get._2) { + s""" + |if ($entry.isNullAt(0)) { + | throw new RuntimeException("The first field from a struct (key) can't be null."); + |} + """.stripMargin + } else { + "" + } + + s""" + |for (int $i = 0; $i < $numEntries; $i++) { + | InternalRow $entry = $childVariable.getStruct($i, 2); + | $nullKeyCheck + | ${keyAssignment(CodeGenerator.getValue(entry, dataType.keyType, "0"), i)} + | ${valueAssignment(entry, i)} + |} + """.stripMargin + } + + private def genCodeForPrimitiveElements( + ctx: CodegenContext, + childVariable: String, + mapData: String, + numEntries: String): String = { + val byteArraySize = ctx.freshName("byteArraySize") + val keySectionSize = ctx.freshName("keySectionSize") + val valueSectionSize = ctx.freshName("valueSectionSize") + val data = ctx.freshName("byteArray") + val unsafeMapData = ctx.freshName("unsafeMapData") + val keyArrayData = ctx.freshName("keyArrayData") + val valueArrayData = ctx.freshName("valueArrayData") + + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val keySize = dataType.keyType.defaultSize + val valueSize = dataType.valueType.defaultSize + val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $keySize)" + val vByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $valueSize)" + val keyTypeName = CodeGenerator.primitiveTypeName(dataType.keyType) + val valueTypeName = CodeGenerator.primitiveTypeName(dataType.valueType) + + val keyAssignment = (key: String, idx: String) => s"$keyArrayData.set$keyTypeName($idx, $key);" + val valueAssignment = (entry: String, idx: String) => { + val value = CodeGenerator.getValue(entry, dataType.valueType, "1") + val valueNullUnsafeAssignment = s"$valueArrayData.set$valueTypeName($idx, $value);" + if (dataType.valueContainsNull) { + s""" + |if ($entry.isNullAt(1)) { + | $valueArrayData.setNullAt($idx); + |} else { + | $valueNullUnsafeAssignment + |} + """.stripMargin + } else { + valueNullUnsafeAssignment + } + } + val assignmentLoop = genCodeForAssignmentLoop( + ctx, + childVariable, + mapData, + numEntries, + keyAssignment, + valueAssignment + ) + + s""" + |final long $keySectionSize = $kByteSize; + |final long $valueSectionSize = $vByteSize; + |final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize; + |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | ${genCodeForAnyElements(ctx, childVariable, mapData, numEntries)} + |} else { + | final byte[] $data = new byte[(int)$byteArraySize]; + | UnsafeMapData $unsafeMapData = new UnsafeMapData(); + | Platform.putLong($data, $baseOffset, $keySectionSize); + | Platform.putLong($data, ${baseOffset + 8}, $numEntries); + | Platform.putLong($data, ${baseOffset + 8} + $keySectionSize, $numEntries); + | $unsafeMapData.pointTo($data, $baseOffset, (int)$byteArraySize); + | ArrayData $keyArrayData = $unsafeMapData.keyArray(); + | ArrayData $valueArrayData = $unsafeMapData.valueArray(); + | $assignmentLoop + | $mapData = $unsafeMapData; + |} + """.stripMargin + } + + private def genCodeForAnyElements( + ctx: CodegenContext, + childVariable: String, + mapData: String, + numEntries: String): String = { + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + val mapDataClass = classOf[ArrayBasedMapData].getName() + + val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) + val valueAssignment = (entry: String, idx: String) => { + val value = CodeGenerator.getValue(entry, dataType.valueType, "1") + if (dataType.valueContainsNull && isValuePrimitive) { + s"$values[$idx] = $entry.isNullAt(1) ? null : (Object)$value;" + } else { + s"$values[$idx] = $value;" + } + } + val keyAssignment = (key: String, idx: String) => s"$keys[$idx] = $key;" + val assignmentLoop = genCodeForAssignmentLoop( + ctx, + childVariable, + mapData, + numEntries, + keyAssignment, + valueAssignment) + + s""" + |final Object[] $keys = new Object[$numEntries]; + |final Object[] $values = new Object[$numEntries]; + |$assignmentLoop + |$mapData = $mapDataClass.apply($keys, $values); + """.stripMargin + } + + override def prettyName: String = "map_from_entries" +} + + /** * Common base class for [[SortArray]] and [[ArraySort]]. */ @@ -1990,24 +2207,10 @@ case class Flatten(child: Expression) extends UnaryExpression { } else { genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value) } - if (childDataType.containsNull) nullElementsProtection(ev, c, code) else code + ctx.nullArrayElementsSaveExec(childDataType.containsNull, ev.isNull, c)(code) }) } - private def nullElementsProtection( - ev: ExprCode, - childVariableName: String, - coreLogic: String): String = { - s""" - |for (int z = 0; !${ev.isNull} && z < $childVariableName.numElements(); z++) { - | ${ev.isNull} |= $childVariableName.isNullAt(z); - |} - |if (!${ev.isNull}) { - | $coreLogic - |} - """.stripMargin - } - private def genCodeForNumberOfElements( ctx: CodegenContext, childVariableName: String) : (String, String) = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index f377f9c8cd533..5b8cf5128fe21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -80,6 +80,57 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapEntries(ms2), null) } + test("MapFromEntries") { + def arrayType(keyType: DataType, valueType: DataType) : DataType = { + ArrayType( + StructType(Seq( + StructField("a", keyType), + StructField("b", valueType))), + true) + } + def r(values: Any*): InternalRow = create_row(values: _*) + + // Primitive-type keys and values + val aiType = arrayType(IntegerType, IntegerType) + val ai0 = Literal.create(Seq(r(1, 10), r(2, 20), r(3, 20)), aiType) + val ai1 = Literal.create(Seq(r(1, null), r(2, 20), r(3, null)), aiType) + val ai2 = Literal.create(Seq.empty, aiType) + val ai3 = Literal.create(null, aiType) + val ai4 = Literal.create(Seq(r(1, 10), r(1, 20)), aiType) + val ai5 = Literal.create(Seq(r(1, 10), r(null, 20)), aiType) + val ai6 = Literal.create(Seq(null, r(2, 20), null), aiType) + + checkEvaluation(MapFromEntries(ai0), Map(1 -> 10, 2 -> 20, 3 -> 20)) + checkEvaluation(MapFromEntries(ai1), Map(1 -> null, 2 -> 20, 3 -> null)) + checkEvaluation(MapFromEntries(ai2), Map.empty) + checkEvaluation(MapFromEntries(ai3), null) + checkEvaluation(MapKeys(MapFromEntries(ai4)), Seq(1, 1)) + checkExceptionInExpression[RuntimeException]( + MapFromEntries(ai5), + "The first field from a struct (key) can't be null.") + checkEvaluation(MapFromEntries(ai6), null) + + // Non-primitive-type keys and values + val asType = arrayType(StringType, StringType) + val as0 = Literal.create(Seq(r("a", "aa"), r("b", "bb"), r("c", "bb")), asType) + val as1 = Literal.create(Seq(r("a", null), r("b", "bb"), r("c", null)), asType) + val as2 = Literal.create(Seq.empty, asType) + val as3 = Literal.create(null, asType) + val as4 = Literal.create(Seq(r("a", "aa"), r("a", "bb")), asType) + val as5 = Literal.create(Seq(r("a", "aa"), r(null, "bb")), asType) + val as6 = Literal.create(Seq(null, r("b", "bb"), null), asType) + + checkEvaluation(MapFromEntries(as0), Map("a" -> "aa", "b" -> "bb", "c" -> "bb")) + checkEvaluation(MapFromEntries(as1), Map("a" -> null, "b" -> "bb", "c" -> null)) + checkEvaluation(MapFromEntries(as2), Map.empty) + checkEvaluation(MapFromEntries(as3), null) + checkEvaluation(MapKeys(MapFromEntries(as4)), Seq("a", "a")) + checkExceptionInExpression[RuntimeException]( + MapFromEntries(as5), + "The first field from a struct (key) can't be null.") + checkEvaluation(MapFromEntries(as6), null) + } + test("Sort Array") { val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c296a1bf9d69d..f792a6fba1d8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3526,6 +3526,13 @@ object functions { */ def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) } + /** + * Returns a map created from the given array of entries. + * @group collection_funcs + * @since 2.4.0 + */ + def map_from_entries(e: Column): Column = withExpr { MapFromEntries(e.expr) } + /** * Returns a merged array of structs in which the N-th struct contains all N-th values of input * arrays. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index fcdd33f544311..25fdbab745128 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -633,6 +633,56 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected) } + test("map_from_entries function") { + def dummyFilter(c: Column): Column = c.isNull || c.isNotNull + val oneRowDF = Seq(3215).toDF("i") + + // Test cases with primitive-type keys and values + val idf = Seq( + Seq((1, 10), (2, 20), (3, 10)), + Seq((1, 10), null, (2, 20)), + Seq.empty, + null + ).toDF("a") + val iExpected = Seq( + Row(Map(1 -> 10, 2 -> 20, 3 -> 10)), + Row(null), + Row(Map.empty), + Row(null)) + + checkAnswer(idf.select(map_from_entries('a)), iExpected) + checkAnswer(idf.selectExpr("map_from_entries(a)"), iExpected) + checkAnswer(idf.filter(dummyFilter('a)).select(map_from_entries('a)), iExpected) + checkAnswer( + oneRowDF.selectExpr("map_from_entries(array(struct(1, null), struct(2, null)))"), + Seq(Row(Map(1 -> null, 2 -> null))) + ) + checkAnswer( + oneRowDF.filter(dummyFilter('i)) + .selectExpr("map_from_entries(array(struct(1, null), struct(2, null)))"), + Seq(Row(Map(1 -> null, 2 -> null))) + ) + + // Test cases with non-primitive-type keys and values + val sdf = Seq( + Seq(("a", "aa"), ("b", "bb"), ("c", "aa")), + Seq(("a", "aa"), null, ("b", "bb")), + Seq(("a", null), ("b", null)), + Seq.empty, + null + ).toDF("a") + val sExpected = Seq( + Row(Map("a" -> "aa", "b" -> "bb", "c" -> "aa")), + Row(null), + Row(Map("a" -> null, "b" -> null)), + Row(Map.empty), + Row(null)) + + checkAnswer(sdf.select(map_from_entries('a)), sExpected) + checkAnswer(sdf.selectExpr("map_from_entries(a)"), sExpected) + checkAnswer(sdf.filter(dummyFilter('a)).select(map_from_entries('a)), sExpected) + } + test("array contains function") { val df = Seq( (Seq[Int](1, 2), "x", 1), From 39dfaf2fd167cafc84ec9cc637c114ed54a331e3 Mon Sep 17 00:00:00 2001 From: Hieu Huynh <“Hieu.huynh@oath.com”> Date: Fri, 22 Jun 2018 09:16:14 -0500 Subject: [PATCH 0999/2461] [SPARK-24519] Make the threshold for highly compressed map status configurable MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Problem** MapStatus uses hardcoded value of 2000 partitions to determine if it should use highly compressed map status. We should make it configurable to allow users to more easily tune their jobs with respect to this without having for them to modify their code to change the number of partitions. Note we can leave this as an internal/undocumented config for now until we have more advise for the users on how to set this config. Some of my reasoning: The config gives you a way to easily change something without the user having to change code, redeploy jar, and then run again. You can simply change the config and rerun. It also allows for easier experimentation. Changing the # of partitions has other side affects, whether good or bad is situation dependent. It can be worse are you could be increasing # of output files when you don't want to be, affects the # of tasks needs and thus executors to run in parallel, etc. There have been various talks about this number at spark summits where people have told customers to increase it to be 2001 partitions. Note if you just do a search for spark 2000 partitions you will fine various things all talking about this number. This shows that people are modifying their code to take this into account so it seems to me having this configurable would be better. Once we have more advice for users we could expose this and document information on it. **What changes were proposed in this pull request?** I make the hardcoded value mentioned above to be configurable under the name _SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS_, which has default value to be 2000. Users can set it to the value they want by setting the property name _spark.shuffle.minNumPartitionsToHighlyCompress_ **How was this patch tested?** I wrote a unit test to make sure that the default value is 2000, and _IllegalArgumentException_ will be thrown if user set it to a non-positive value. The unit test also checks that highly compressed map status is correctly used when the number of partition is greater than _SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS_. Author: Hieu Huynh <“Hieu.huynh@oath.com”> Closes #21527 from hthuynh2/spark_branch_1. --- .../spark/internal/config/package.scala | 7 +++++ .../apache/spark/scheduler/MapStatus.scala | 4 ++- .../spark/scheduler/MapStatusSuite.scala | 28 +++++++++++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index a54b091a64d50..38a043c85ae33 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -552,4 +552,11 @@ package object config { .timeConf(TimeUnit.SECONDS) .createWithDefaultString("1h") + private[spark] val SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS = + ConfigBuilder("spark.shuffle.minNumPartitionsToHighlyCompress") + .internal() + .doc("Number of partitions to determine if MapStatus should use HighlyCompressedMapStatus") + .intConf + .checkValue(v => v > 0, "The value should be a positive integer.") + .createWithDefault(2000) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 2ec2f2031aa45..659694dd189ad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -50,7 +50,9 @@ private[spark] sealed trait MapStatus { private[spark] object MapStatus { def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { - if (uncompressedSizes.length > 2000) { + if (uncompressedSizes.length > Option(SparkEnv.get) + .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) + .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)) { HighlyCompressedMapStatus(loc, uncompressedSizes) } else { new CompressedMapStatus(loc, uncompressedSizes) diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 2155a0f2b6c21..354e6386fa60e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -188,4 +188,32 @@ class MapStatusSuite extends SparkFunSuite { assert(count === 3000) } } + + test("SPARK-24519: HighlyCompressedMapStatus has configurable threshold") { + val conf = new SparkConf() + val env = mock(classOf[SparkEnv]) + doReturn(conf).when(env).conf + SparkEnv.set(env) + val sizes = Array.fill[Long](500)(150L) + // Test default value + val status = MapStatus(null, sizes) + assert(status.isInstanceOf[CompressedMapStatus]) + // Test Non-positive values + for (s <- -1 to 0) { + assertThrows[IllegalArgumentException] { + conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) + val status = MapStatus(null, sizes) + } + } + // Test positive values + Seq(1, 100, 499, 500, 501).foreach { s => + conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) + val status = MapStatus(null, sizes) + if(sizes.length > s) { + assert(status.isInstanceOf[HighlyCompressedMapStatus]) + } else { + assert(status.isInstanceOf[CompressedMapStatus]) + } + } + } } From 33e77fa89b5805ecb1066fc534723527f70d37c7 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 22 Jun 2018 10:14:12 -0700 Subject: [PATCH 1000/2461] [SPARK-24518][CORE] Using Hadoop credential provider API to store password ## What changes were proposed in this pull request? In our distribution, because we don't do such fine-grained access control of config file, also configuration file is world readable shared between different components, so password may leak to different users. Hadoop credential provider API support storing password in a secure way, in which Spark could read it in a secure way, so here propose to add support of using credential provider API to get password. ## How was this patch tested? Adding tests and verified locally. Author: jerryshao Closes #21548 from jerryshao/SPARK-24518. --- .../scala/org/apache/spark/SSLOptions.scala | 11 ++- .../org/apache/spark/SecurityManager.scala | 9 ++- .../org/apache/spark/SSLOptionsSuite.scala | 75 +++++++++++++++++-- docs/security.md | 23 +++++- 4 files changed, 107 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 04c38f12acc78..1632e0c69eef5 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -21,6 +21,7 @@ import java.io.File import java.security.NoSuchAlgorithmException import javax.net.ssl.SSLContext +import org.apache.hadoop.conf.Configuration import org.eclipse.jetty.util.ssl.SslContextFactory import org.apache.spark.internal.Logging @@ -163,11 +164,16 @@ private[spark] object SSLOptions extends Logging { * missing in SparkConf, the corresponding setting is used from the default configuration. * * @param conf Spark configuration object where the settings are collected from + * @param hadoopConf Hadoop configuration to get settings * @param ns the namespace name * @param defaults the default configuration * @return [[org.apache.spark.SSLOptions]] object */ - def parse(conf: SparkConf, ns: String, defaults: Option[SSLOptions] = None): SSLOptions = { + def parse( + conf: SparkConf, + hadoopConf: Configuration, + ns: String, + defaults: Option[SSLOptions] = None): SSLOptions = { val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = defaults.exists(_.enabled)) val port = conf.getWithSubstitution(s"$ns.port").map(_.toInt) @@ -179,9 +185,11 @@ private[spark] object SSLOptions extends Logging { .orElse(defaults.flatMap(_.keyStore)) val keyStorePassword = conf.getWithSubstitution(s"$ns.keyStorePassword") + .orElse(Option(hadoopConf.getPassword(s"$ns.keyStorePassword")).map(new String(_))) .orElse(defaults.flatMap(_.keyStorePassword)) val keyPassword = conf.getWithSubstitution(s"$ns.keyPassword") + .orElse(Option(hadoopConf.getPassword(s"$ns.keyPassword")).map(new String(_))) .orElse(defaults.flatMap(_.keyPassword)) val keyStoreType = conf.getWithSubstitution(s"$ns.keyStoreType") @@ -194,6 +202,7 @@ private[spark] object SSLOptions extends Logging { .orElse(defaults.flatMap(_.trustStore)) val trustStorePassword = conf.getWithSubstitution(s"$ns.trustStorePassword") + .orElse(Option(hadoopConf.getPassword(s"$ns.trustStorePassword")).map(new String(_))) .orElse(defaults.flatMap(_.trustStorePassword)) val trustStoreType = conf.getWithSubstitution(s"$ns.trustStoreType") diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index b87476322573d..3cfafeb951105 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -19,11 +19,11 @@ package org.apache.spark import java.net.{Authenticator, PasswordAuthentication} import java.nio.charset.StandardCharsets.UTF_8 -import javax.net.ssl._ import org.apache.hadoop.io.Text import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.launcher.SparkLauncher @@ -111,11 +111,14 @@ private[spark] class SecurityManager( ) } + private val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) // the default SSL configuration - it will be used by all communication layers unless overwritten - private val defaultSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl", defaults = None) + private val defaultSSLOptions = + SSLOptions.parse(sparkConf, hadoopConf, "spark.ssl", defaults = None) def getSSLOptions(module: String): SSLOptions = { - val opts = SSLOptions.parse(sparkConf, s"spark.ssl.$module", Some(defaultSSLOptions)) + val opts = + SSLOptions.parse(sparkConf, hadoopConf, s"spark.ssl.$module", Some(defaultSSLOptions)) logDebug(s"Created SSL options for $module: $opts") opts } diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 8eabc2b3cb958..5dbfc5c10a6f8 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -18,8 +18,11 @@ package org.apache.spark import java.io.File +import java.util.UUID import javax.net.ssl.SSLContext +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.security.alias.{CredentialProvider, CredentialProviderFactory} import org.scalatest.BeforeAndAfterAll import org.apache.spark.util.SparkConfWithEnv @@ -40,6 +43,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { .toSet val conf = new SparkConf + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") @@ -49,7 +53,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.ssl.enabledAlgorithms", algorithms.mkString(",")) conf.set("spark.ssl.protocol", "TLSv1.2") - val opts = SSLOptions.parse(conf, "spark.ssl") + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl") assert(opts.enabled === true) assert(opts.trustStore.isDefined === true) @@ -70,6 +74,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath val conf = new SparkConf + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") @@ -80,8 +85,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") conf.set("spark.ssl.protocol", "SSLv3") - val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) - val opts = SSLOptions.parse(conf, "spark.ssl.ui", defaults = Some(defaultOpts)) + val defaultOpts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl.ui", defaults = Some(defaultOpts)) assert(opts.enabled === true) assert(opts.trustStore.isDefined === true) @@ -103,6 +108,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath val conf = new SparkConf + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.ui.enabled", "false") conf.set("spark.ssl.ui.port", "4242") @@ -117,8 +123,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.ssl.ui.enabledAlgorithms", "ABC, DEF") conf.set("spark.ssl.protocol", "SSLv3") - val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) - val opts = SSLOptions.parse(conf, "spark.ssl.ui", defaults = Some(defaultOpts)) + val defaultOpts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl.ui", defaults = Some(defaultOpts)) assert(opts.enabled === false) assert(opts.port === Some(4242)) @@ -139,14 +145,71 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val conf = new SparkConfWithEnv(Map( "ENV1" -> "val1", "ENV2" -> "val2")) + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", "${env:ENV1}") conf.set("spark.ssl.trustStore", "${env:ENV2}") - val opts = SSLOptions.parse(conf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) assert(opts.keyStore === Some(new File("val1"))) assert(opts.trustStore === Some(new File("val2"))) } + test("get password from Hadoop credential provider") { + val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath + val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + + val conf = new SparkConf + val hadoopConf = new Configuration() + val tmpPath = s"localjceks://file${sys.props("java.io.tmpdir")}/test-" + + s"${UUID.randomUUID().toString}.jceks" + val provider = createCredentialProvider(tmpPath, hadoopConf) + + conf.set("spark.ssl.enabled", "true") + conf.set("spark.ssl.keyStore", keyStorePath) + storePassword(provider, "spark.ssl.keyStorePassword", "password") + storePassword(provider, "spark.ssl.keyPassword", "password") + conf.set("spark.ssl.trustStore", trustStorePath) + storePassword(provider, "spark.ssl.trustStorePassword", "password") + conf.set("spark.ssl.enabledAlgorithms", + "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ssl.protocol", "SSLv3") + + val defaultOpts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl.ui", defaults = Some(defaultOpts)) + + assert(opts.enabled === true) + assert(opts.trustStore.isDefined === true) + assert(opts.trustStore.get.getName === "truststore") + assert(opts.trustStore.get.getAbsolutePath === trustStorePath) + assert(opts.keyStore.isDefined === true) + assert(opts.keyStore.get.getName === "keystore") + assert(opts.keyStore.get.getAbsolutePath === keyStorePath) + assert(opts.trustStorePassword === Some("password")) + assert(opts.keyStorePassword === Some("password")) + assert(opts.keyPassword === Some("password")) + assert(opts.protocol === Some("SSLv3")) + assert(opts.enabledAlgorithms === + Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) + } + + private def createCredentialProvider(tmpPath: String, conf: Configuration): CredentialProvider = { + conf.set(CredentialProviderFactory.CREDENTIAL_PROVIDER_PATH, tmpPath) + + val provider = CredentialProviderFactory.getProviders(conf).get(0) + if (provider == null) { + throw new IllegalStateException(s"Fail to get credential provider with path $tmpPath") + } + + provider + } + + private def storePassword( + provider: CredentialProvider, + passwordKey: String, + password: String): Unit = { + provider.createCredentialEntry(passwordKey, password.toCharArray) + provider.flush() + } } diff --git a/docs/security.md b/docs/security.md index 8c0c66fb5a285..6ef3a808e0471 100644 --- a/docs/security.md +++ b/docs/security.md @@ -177,7 +177,7 @@ ACLs can be configured for either users or groups. Configuration entries accept lists as input, meaning multiple users or groups can be given the desired privileges. This can be used if you run on a shared cluster and have a set of administrators or developers who need to monitor applications they may not have started themselves. A wildcard (`*`) added to specific ACL -means that all users will have the respective pivilege. By default, only the user submitting the +means that all users will have the respective privilege. By default, only the user submitting the application is added to the ACLs. Group membership is established by using a configurable group mapping provider. The mapper is @@ -446,6 +446,27 @@ replaced with one of the above namespaces. +Spark also supports retrieving `${ns}.keyPassword`, `${ns}.keyStorePassword` and `${ns}.trustStorePassword` from +[Hadoop Credential Providers](https://hadoop.apache.org/docs/current/hadoop-project-dist/hadoop-common/CredentialProviderAPI.html). +User could store password into credential file and make it accessible by different components, like: + +``` +hadoop credential create spark.ssl.keyPassword -value password \ + -provider jceks://hdfs@nn1.example.com:9001/user/backup/ssl.jceks +``` + +To configure the location of the credential provider, set the `hadoop.security.credential.provider.path` +config option in the Hadoop configuration used by Spark, like: + +``` + + hadoop.security.credential.provider.path + jceks://hdfs@nn1.example.com:9001/user/backup/ssl.jceks + +``` + +Or via SparkConf "spark.hadoop.hadoop.security.credential.provider.path=jceks://hdfs@nn1.example.com:9001/user/backup/ssl.jceks". + ## Preparing the key stores Key stores can be generated by `keytool` program. The reference documentation for this tool for From 4e7d8678a3d9b12797d07f5497e0ed9e471428dd Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 22 Jun 2018 12:38:34 -0500 Subject: [PATCH 1001/2461] [SPARK-24372][BUILD] Add scripts to help with preparing releases. The "do-release.sh" script asks questions about the RC being prepared, trying to find out as much as possible automatically, and then executes the existing scripts with proper arguments to prepare the release. This script was used to prepare the 2.3.1 release candidates, so was tested in that context. The docker version runs that same script inside a docker image especially crafted for building Spark releases. That image is based on the work by Felix C. linked in the bug. At this point is has been only midly tested. I also added a template for the vote e-mail, with placeholders for things that need to be replaced, although there is no automation around that for the moment. It shouldn't be hard to hook up certain things like version and tags to this, or to figure out certain things like the repo URL from the output of the release scripts. Author: Marcelo Vanzin Closes #21515 from vanzin/SPARK-24372. --- dev/.rat-excludes | 1 + dev/create-release/do-release-docker.sh | 143 +++++++++++++++ dev/create-release/do-release.sh | 81 +++++++++ dev/create-release/release-build.sh | 190 ++++++++++++-------- dev/create-release/release-tag.sh | 26 ++- dev/create-release/release-util.sh | 228 ++++++++++++++++++++++++ dev/create-release/spark-rm/Dockerfile | 87 +++++++++ dev/create-release/vote.tmpl | 65 +++++++ 8 files changed, 736 insertions(+), 85 deletions(-) create mode 100755 dev/create-release/do-release-docker.sh create mode 100755 dev/create-release/do-release.sh create mode 100644 dev/create-release/release-util.sh create mode 100644 dev/create-release/spark-rm/Dockerfile create mode 100644 dev/create-release/vote.tmpl diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 9552d001a079c..23b24212b4d29 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -106,3 +106,4 @@ spark-warehouse structured-streaming/* kafka-source-initial-offset-version-2.1.0.bin kafka-source-initial-offset-future-version.bin +vote.tmpl diff --git a/dev/create-release/do-release-docker.sh b/dev/create-release/do-release-docker.sh new file mode 100755 index 0000000000000..fa7b73cdb40ec --- /dev/null +++ b/dev/create-release/do-release-docker.sh @@ -0,0 +1,143 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# Creates a Spark release candidate. The script will update versions, tag the branch, +# build Spark binary packages and documentation, and upload maven artifacts to a staging +# repository. There is also a dry run mode where only local builds are performed, and +# nothing is uploaded to the ASF repos. +# +# Run with "-h" for options. +# + +set -e +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + +function usage { + local NAME=$(basename $0) + cat < "$GPG_KEY_FILE" + +run_silent "Building spark-rm image with tag $IMGTAG..." "docker-build.log" \ + docker build -t "spark-rm:$IMGTAG" --build-arg UID=$UID "$SELF/spark-rm" + +# Write the release information to a file with environment variables to be used when running the +# image. +ENVFILE="$WORKDIR/env.list" +fcreate_secure "$ENVFILE" + +function cleanup { + rm -f "$ENVFILE" + rm -f "$GPG_KEY_FILE" +} + +trap cleanup EXIT + +cat > $ENVFILE <> $ENVFILE + JAVA_VOL="--volume $JAVA:/opt/spark-java" +fi + +echo "Building $RELEASE_TAG; output will be at $WORKDIR/output" +docker run -ti \ + --env-file "$ENVFILE" \ + --volume "$WORKDIR:/opt/spark-rm" \ + $JAVA_VOL \ + "spark-rm:$IMGTAG" diff --git a/dev/create-release/do-release.sh b/dev/create-release/do-release.sh new file mode 100755 index 0000000000000..f1d4f3ab5ddec --- /dev/null +++ b/dev/create-release/do-release.sh @@ -0,0 +1,81 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + +while getopts "bn" opt; do + case $opt in + b) GIT_BRANCH=$OPTARG ;; + n) DRY_RUN=1 ;; + ?) error "Invalid option: $OPTARG" ;; + esac +done + +if [ "$RUNNING_IN_DOCKER" = "1" ]; then + # Inside docker, need to import the GPG key stored in the current directory. + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --import "$SELF/gpg.key" + + # We may need to adjust the path since JAVA_HOME may be overridden by the driver script. + if [ -n "$JAVA_HOME" ]; then + export PATH="$JAVA_HOME/bin:$PATH" + else + # JAVA_HOME for the openjdk package. + export JAVA_HOME=/usr + fi +else + # Outside docker, need to ask for information about the release. + get_release_info +fi + +function should_build { + local WHAT=$1 + [ -z "$RELEASE_STEP" ] || [ "$WHAT" = "$RELEASE_STEP" ] +} + +if should_build "tag" && [ $SKIP_TAG = 0 ]; then + run_silent "Creating release tag $RELEASE_TAG..." "tag.log" \ + "$SELF/release-tag.sh" + echo "It may take some time for the tag to be synchronized to github." + echo "Press enter when you've verified that the new tag ($RELEASE_TAG) is available." + read +else + echo "Skipping tag creation for $RELEASE_TAG." +fi + +if should_build "build"; then + run_silent "Building Spark..." "build.log" \ + "$SELF/release-build.sh" package +else + echo "Skipping build step." +fi + +if should_build "docs"; then + run_silent "Building documentation..." "docs.log" \ + "$SELF/release-build.sh" docs +else + echo "Skipping docs step." +fi + +if should_build "publish"; then + run_silent "Publishing release" "publish.log" \ + "$SELF/release-build.sh" publish-release +else + echo "Skipping publish step." +fi diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 5faa3d3260a56..24a62a8f4c7d3 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -17,6 +17,9 @@ # limitations under the License. # +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + function exit_with_usage { cat << EOF usage: release-build.sh @@ -87,49 +90,56 @@ NEXUS_ROOT=https://repository.apache.org/service/local/staging NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads BASE_DIR=$(pwd) -MVN="build/mvn --force" - -# Hive-specific profiles for some builds -HIVE_PROFILES="-Phive -Phive-thriftserver" -# Profiles for publishing snapshots and release to Maven Central -PUBLISH_PROFILES="-Pmesos -Pyarn -Pkubernetes -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" -# Profiles for building binary releases -BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pkubernetes -Pflume -Psparkr" -# Scala 2.11 only profiles for some builds -SCALA_2_11_PROFILES="-Pkafka-0-8" -# Scala 2.12 only profiles for some builds -SCALA_2_12_PROFILES="-Pscala-2.12" +init_java +init_maven_sbt rm -rf spark -git clone https://git-wip-us.apache.org/repos/asf/spark.git +git clone "$ASF_REPO" cd spark git checkout $GIT_REF git_hash=`git rev-parse --short HEAD` echo "Checked out Spark git hash $git_hash" if [ -z "$SPARK_VERSION" ]; then - SPARK_VERSION=$($MVN help:evaluate -Dexpression=project.version \ - | grep -v INFO | grep -v WARNING | grep -v Download) + # Run $MVN in a separate command so that 'set -e' does the right thing. + TMP=$(mktemp) + $MVN help:evaluate -Dexpression=project.version > $TMP + SPARK_VERSION=$(cat $TMP | grep -v INFO | grep -v WARNING | grep -v Download) + rm $TMP fi -# Verify we have the right java version set -if [ -z "$JAVA_HOME" ]; then - echo "Please set JAVA_HOME." - exit 1 +# Depending on the version being built, certain extra profiles need to be activated, and +# different versions of Scala are supported. +BASE_PROFILES="-Pmesos -Pyarn" +PUBLISH_SCALA_2_10=0 +SCALA_2_10_PROFILES="-Pscala-2.10" +SCALA_2_11_PROFILES= +SCALA_2_12_PROFILES="-Pscala-2.12" + +if [[ $SPARK_VERSION > "2.3" ]]; then + BASE_PROFILES="$BASE_PROFILES -Pkubernetes -Pflume" + SCALA_2_11_PROFILES="-Pkafka-0-8" +else + PUBLISH_SCALA_2_10=1 fi -java_version=$("${JAVA_HOME}"/bin/javac -version 2>&1 | cut -d " " -f 2) +# Hive-specific profiles for some builds +HIVE_PROFILES="-Phive -Phive-thriftserver" +# Profiles for publishing snapshots and release to Maven Central +PUBLISH_PROFILES="$BASE_PROFILES $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" +# Profiles for building binary releases +BASE_RELEASE_PROFILES="$BASE_PROFILES -Psparkr" if [[ ! $SPARK_VERSION < "2.2." ]]; then - if [[ $java_version < "1.8." ]]; then - echo "Java version $java_version is less than required 1.8 for 2.2+" + if [[ $JAVA_VERSION < "1.8." ]]; then + echo "Java version $JAVA_VERSION is less than required 1.8 for 2.2+" echo "Please set JAVA_HOME correctly." exit 1 fi else - if [[ $java_version > "1.7." ]]; then + if ! [[ $JAVA_VERSION =~ 1\.7\..* ]]; then if [ -z "$JAVA_7_HOME" ]; then - echo "Java version $java_version is higher than required 1.7 for pre-2.2" + echo "Java version $JAVA_VERSION is higher than required 1.7 for pre-2.2" echo "Please set JAVA_HOME correctly." exit 1 else @@ -174,8 +184,9 @@ if [[ "$1" == "package" ]]; then FLAGS=$2 ZINC_PORT=$3 BUILD_PACKAGE=$4 - cp -r spark spark-$SPARK_VERSION-bin-$NAME + echo "Building binary dist $NAME" + cp -r spark spark-$SPARK_VERSION-bin-$NAME cd spark-$SPARK_VERSION-bin-$NAME # TODO There should probably be a flag to make-distribution to allow 2.12 support @@ -244,31 +255,39 @@ if [[ "$1" == "package" ]]; then spark-$SPARK_VERSION-bin-$NAME.tgz.sha512 } - # TODO: Check exit codes of children here: - # http://stackoverflow.com/questions/1570262/shell-get-exit-code-of-background-process - # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds # share the same Zinc server. - make_binary_release "hadoop2.6" "-Phadoop-2.6 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3035" "withr" & - make_binary_release "hadoop2.7" "-Phadoop-2.7 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3036" "withpip" & - make_binary_release "without-hadoop" "-Phadoop-provided $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3038" & - wait - rm -rf spark-$SPARK_VERSION-bin-*/ + if ! make_binary_release "hadoop2.6" "$MVN_EXTRA_OPTS -B -Phadoop-2.6 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3035" "withr"; then + error "Failed to build hadoop2.6 package. Check logs for details." + fi + if ! is_dry_run; then + if ! make_binary_release "hadoop2.7" "$MVN_EXTRA_OPTS -B -Phadoop-2.7 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3036" "withpip"; then + error "Failed to build hadoop2.7 package. Check logs for details." + fi + if ! make_binary_release "without-hadoop" "$MVN_EXTRA_OPTS -B -Phadoop-provided $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3037"; then + error "Failed to build without-hadoop package. Check logs for details." + fi + fi - svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark - rm -rf "svn-spark/${DEST_DIR_NAME}-bin" - mkdir -p "svn-spark/${DEST_DIR_NAME}-bin" + rm -rf spark-$SPARK_VERSION-bin-*/ - echo "Copying release tarballs" - cp spark-* "svn-spark/${DEST_DIR_NAME}-bin/" - cp pyspark-* "svn-spark/${DEST_DIR_NAME}-bin/" - cp SparkR_* "svn-spark/${DEST_DIR_NAME}-bin/" - svn add "svn-spark/${DEST_DIR_NAME}-bin" + if ! is_dry_run; then + svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark + rm -rf "svn-spark/${DEST_DIR_NAME}-bin" + mkdir -p "svn-spark/${DEST_DIR_NAME}-bin" + + echo "Copying release tarballs" + cp spark-* "svn-spark/${DEST_DIR_NAME}-bin/" + cp pyspark-* "svn-spark/${DEST_DIR_NAME}-bin/" + cp SparkR_* "svn-spark/${DEST_DIR_NAME}-bin/" + svn add "svn-spark/${DEST_DIR_NAME}-bin" + + cd svn-spark + svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION" + cd .. + rm -rf svn-spark + fi - cd svn-spark - svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION" - cd .. - rm -rf svn-spark exit 0 fi @@ -282,18 +301,22 @@ if [[ "$1" == "docs" ]]; then cd .. cd .. - svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark - rm -rf "svn-spark/${DEST_DIR_NAME}-docs" - mkdir -p "svn-spark/${DEST_DIR_NAME}-docs" + if ! is_dry_run; then + svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark + rm -rf "svn-spark/${DEST_DIR_NAME}-docs" + mkdir -p "svn-spark/${DEST_DIR_NAME}-docs" - echo "Copying release documentation" - cp -R "spark/docs/_site" "svn-spark/${DEST_DIR_NAME}-docs/" - svn add "svn-spark/${DEST_DIR_NAME}-docs" + echo "Copying release documentation" + cp -R "spark/docs/_site" "svn-spark/${DEST_DIR_NAME}-docs/" + svn add "svn-spark/${DEST_DIR_NAME}-docs" - cd svn-spark - svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION docs" - cd .. - rm -rf svn-spark + cd svn-spark + svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION docs" + cd .. + rm -rf svn-spark + fi + + mv "spark/docs/_site" docs/ exit 0 fi @@ -341,13 +364,15 @@ if [[ "$1" == "publish-release" ]]; then # Using Nexus API documented here: # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API - echo "Creating Nexus staging repository" - repo_request="Apache Spark $SPARK_VERSION (commit $git_hash)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) - staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") - echo "Created Nexus staging repository: $staged_repo_id" + if ! is_dry_run; then + echo "Creating Nexus staging repository" + repo_request="Apache Spark $SPARK_VERSION (commit $git_hash)" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) + staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") + echo "Created Nexus staging repository: $staged_repo_id" + fi tmp_repo=$(mktemp -d spark-repo-XXXXX) @@ -356,6 +381,12 @@ if [[ "$1" == "publish-release" ]]; then $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $SCALA_2_11_PROFILES $PUBLISH_PROFILES clean install + if ! is_dry_run && [[ $PUBLISH_SCALA_2_10 = 1 ]]; then + ./dev/change-scala-version.sh 2.10 + $MVN -DzincPort=$((ZINC_PORT + 1)) -Dmaven.repo.local=$tmp_repo -Dscala-2.10 \ + -DskipTests $PUBLISH_PROFILES $SCALA_2_10_PROFILES clean install + fi + #./dev/change-scala-version.sh 2.12 #$MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo \ # -DskipTests $SCALA_2_12_PROFILES §$PUBLISH_PROFILES clean install @@ -386,23 +417,26 @@ if [[ "$1" == "publish-release" ]]; then sha1sum $file | cut -f1 -d' ' > $file.sha1 done - nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id - echo "Uplading files to $nexus_upload" - for file in $(find . -type f) - do - # strip leading ./ - file_short=$(echo $file | sed -e "s/\.\///") - dest_url="$nexus_upload/org/apache/spark/$file_short" - echo " Uploading $file_short" - curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url - done + if ! is_dry_run; then + nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id + echo "Uplading files to $nexus_upload" + for file in $(find . -type f) + do + # strip leading ./ + file_short=$(echo $file | sed -e "s/\.\///") + dest_url="$nexus_upload/org/apache/spark/$file_short" + echo " Uploading $file_short" + curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url + done + + echo "Closing nexus staging repository" + repo_request="$staged_repo_idApache Spark $SPARK_VERSION (commit $git_hash)" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) + echo "Closed Nexus staging repository: $staged_repo_id" + fi - echo "Closing nexus staging repository" - repo_request="$staged_repo_idApache Spark $SPARK_VERSION (commit $git_hash)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) - echo "Closed Nexus staging repository: $staged_repo_id" popd rm -rf $tmp_repo cd .. diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh index a05716a5f66bb..628bc0504c9c8 100755 --- a/dev/create-release/release-tag.sh +++ b/dev/create-release/release-tag.sh @@ -17,6 +17,9 @@ # limitations under the License. # +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + function exit_with_usage { cat << EOF usage: tag-release.sh @@ -36,6 +39,7 @@ EOF } set -e +set -o pipefail if [[ $@ == *"help"* ]]; then exit_with_usage @@ -54,8 +58,10 @@ for env in ASF_USERNAME ASF_PASSWORD RELEASE_VERSION RELEASE_TAG NEXT_VERSION GI fi done +init_java +init_maven_sbt + ASF_SPARK_REPO="git-wip-us.apache.org/repos/asf/spark.git" -MVN="build/mvn --force" rm -rf spark git clone "https://$ASF_USERNAME:$ASF_PASSWORD@$ASF_SPARK_REPO" -b $GIT_BRANCH @@ -94,9 +100,15 @@ sed -i".tmp7" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$R_NEXT_VERSION" git commit -a -m "Preparing development version $NEXT_VERSION" -# Push changes -git push origin $RELEASE_TAG -git push origin HEAD:$GIT_BRANCH - -cd .. -rm -rf spark +if ! is_dry_run; then + # Push changes + git push origin $RELEASE_TAG + git push origin HEAD:$GIT_BRANCH + + cd .. + rm -rf spark +else + cd .. + mv spark spark.tag + echo "Clone with version changes and tag available as spark.tag in the output directory." +fi diff --git a/dev/create-release/release-util.sh b/dev/create-release/release-util.sh new file mode 100644 index 0000000000000..7426b0d6ca08d --- /dev/null +++ b/dev/create-release/release-util.sh @@ -0,0 +1,228 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +DRY_RUN=${DRY_RUN:-0} +GPG="gpg --no-tty --batch" +ASF_REPO="https://git-wip-us.apache.org/repos/asf/spark.git" +ASF_REPO_WEBUI="https://git-wip-us.apache.org/repos/asf?p=spark.git" + +function error { + echo "$*" + exit 1 +} + +function read_config { + local PROMPT="$1" + local DEFAULT="$2" + local REPLY= + + read -p "$PROMPT [$DEFAULT]: " REPLY + local RETVAL="${REPLY:-$DEFAULT}" + if [ -z "$RETVAL" ]; then + error "$PROMPT is must be provided." + fi + echo "$RETVAL" +} + +function parse_version { + grep -e '.*' | \ + head -n 2 | tail -n 1 | cut -d'>' -f2 | cut -d '<' -f1 +} + +function run_silent { + local BANNER="$1" + local LOG_FILE="$2" + shift 2 + + echo "========================" + echo "= $BANNER" + echo "Command: $@" + echo "Log file: $LOG_FILE" + + "$@" 1>"$LOG_FILE" 2>&1 + + local EC=$? + if [ $EC != 0 ]; then + echo "Command FAILED. Check full logs for details." + tail "$LOG_FILE" + exit $EC + fi +} + +function fcreate_secure { + local FPATH="$1" + rm -f "$FPATH" + touch "$FPATH" + chmod 600 "$FPATH" +} + +function check_for_tag { + curl -s --head --fail "$ASF_REPO_WEBUI;a=commit;h=$1" >/dev/null +} + +function get_release_info { + if [ -z "$GIT_BRANCH" ]; then + # If no branch is specified, found out the latest branch from the repo. + GIT_BRANCH=$(git ls-remote --heads "$ASF_REPO" | + grep -v refs/heads/master | + awk '{print $2}' | + sort -r | + head -n 1 | + cut -d/ -f3) + fi + + export GIT_BRANCH=$(read_config "Branch" "$GIT_BRANCH") + + # Find the current version for the branch. + local VERSION=$(curl -s "$ASF_REPO_WEBUI;a=blob_plain;f=pom.xml;hb=refs/heads/$GIT_BRANCH" | + parse_version) + echo "Current branch version is $VERSION." + + if [[ ! $VERSION =~ .*-SNAPSHOT ]]; then + error "Not a SNAPSHOT version: $VERSION" + fi + + NEXT_VERSION="$VERSION" + RELEASE_VERSION="${VERSION/-SNAPSHOT/}" + SHORT_VERSION=$(echo "$VERSION" | cut -d . -f 1-2) + local REV=$(echo "$VERSION" | cut -d . -f 3) + + # Find out what rc is being prepared. + # - If the current version is "x.y.0", then this is rc1 of the "x.y.0" release. + # - If not, need to check whether the previous version has been already released or not. + # - If it has, then we're building rc1 of the current version. + # - If it has not, we're building the next RC of the previous version. + local RC_COUNT + if [ $REV != 0 ]; then + local PREV_REL_REV=$((REV - 1)) + local PREV_REL_TAG="v${SHORT_VERSION}.${PREV_REL_REV}" + if check_for_tag "$PREV_REL_TAG"; then + RC_COUNT=1 + REV=$((REV + 1)) + NEXT_VERSION="${SHORT_VERSION}.${REV}-SNAPSHOT" + else + RELEASE_VERSION="${SHORT_VERSION}.${PREV_REL_REV}" + RC_COUNT=$(git ls-remote --tags "$ASF_REPO" "v${RELEASE_VERSION}-rc*" | wc -l) + RC_COUNT=$((RC_COUNT + 1)) + fi + else + REV=$((REV + 1)) + NEXT_VERSION="${SHORT_VERSION}.${REV}-SNAPSHOT" + RC_COUNT=1 + fi + + export NEXT_VERSION + export RELEASE_VERSION=$(read_config "Release" "$RELEASE_VERSION") + + RC_COUNT=$(read_config "RC #" "$RC_COUNT") + + # Check if the RC already exists, and if re-creating the RC, skip tag creation. + RELEASE_TAG="v${RELEASE_VERSION}-rc${RC_COUNT}" + SKIP_TAG=0 + if check_for_tag "$RELEASE_TAG"; then + read -p "$RELEASE_TAG already exists. Continue anyway [y/n]? " ANSWER + if [ "$ANSWER" != "y" ]; then + error "Exiting." + fi + SKIP_TAG=1 + fi + + + export RELEASE_TAG + + GIT_REF="$RELEASE_TAG" + if is_dry_run; then + echo "This is a dry run. Please confirm the ref that will be built for testing." + GIT_REF=$(read_config "Ref" "$GIT_REF") + fi + export GIT_REF + export SPARK_PACKAGE_VERSION="$RELEASE_TAG" + + # Gather some user information. + export ASF_USERNAME=$(read_config "ASF user" "$LOGNAME") + + GIT_NAME=$(git config user.name || echo "") + export GIT_NAME=$(read_config "Full name" "$GIT_NAME") + + export GIT_EMAIL="$ASF_USERNAME@apache.org" + export GPG_KEY=$(read_config "GPG key" "$GIT_EMAIL") + + cat <&1 | cut -d " " -f 2) + export JAVA_VERSION +} + +# Initializes MVN_EXTRA_OPTS and SBT_OPTS depending on the JAVA_VERSION in use. Requires init_java. +function init_maven_sbt { + MVN="build/mvn -B" + MVN_EXTRA_OPTS= + SBT_OPTS= + if [[ $JAVA_VERSION < "1.8." ]]; then + # Needed for maven central when using Java 7. + SBT_OPTS="-Dhttps.protocols=TLSv1.1,TLSv1.2" + MVN_EXTRA_OPTS="-Dhttps.protocols=TLSv1.1,TLSv1.2" + MVN="$MVN $MVN_EXTRA_OPTS" + fi + export MVN MVN_EXTRA_OPTS SBT_OPTS +} diff --git a/dev/create-release/spark-rm/Dockerfile b/dev/create-release/spark-rm/Dockerfile new file mode 100644 index 0000000000000..07ce320177f5a --- /dev/null +++ b/dev/create-release/spark-rm/Dockerfile @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Image for building Spark releases. Based on Ubuntu 16.04. +# +# Includes: +# * Java 8 +# * Ivy +# * Python/PyPandoc (2.7.12/3.5.2) +# * R-base/R-base-dev (3.3.2+) +# * Ruby 2.3 build utilities + +FROM ubuntu:16.04 + +# These arguments are just for reuse and not really meant to be customized. +ARG APT_INSTALL="apt-get install --no-install-recommends -y" + +ARG BASE_PIP_PKGS="setuptools wheel virtualenv" +ARG PIP_PKGS="pyopenssl pypandoc numpy pygments sphinx" + +# Install extra needed repos and refresh. +# - CRAN repo +# - Ruby repo (for doc generation) +# +# This is all in a single "RUN" command so that if anything changes, "apt update" is run to fetch +# the most current package versions (instead of potentially using old versions cached by docker). +RUN echo 'deb http://cran.cnr.Berkeley.edu/bin/linux/ubuntu xenial/' >> /etc/apt/sources.list && \ + gpg --keyserver keyserver.ubuntu.com --recv-key E084DAB9 && \ + gpg -a --export E084DAB9 | apt-key add - && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* && \ + apt-get clean && \ + apt-get update && \ + $APT_INSTALL software-properties-common && \ + apt-add-repository -y ppa:brightbox/ruby-ng && \ + apt-get update && \ + # Install openjdk 8. + $APT_INSTALL openjdk-8-jdk && \ + update-alternatives --set java /usr/lib/jvm/java-8-openjdk-amd64/jre/bin/java && \ + # Install build / source control tools + $APT_INSTALL curl wget git maven ivy subversion make gcc lsof libffi-dev \ + pandoc pandoc-citeproc libssl-dev libcurl4-openssl-dev libxml2-dev && \ + ln -s -T /usr/share/java/ivy.jar /usr/share/ant/lib/ivy.jar && \ + curl -sL https://deb.nodesource.com/setup_4.x | bash && \ + $APT_INSTALL nodejs && \ + # Install needed python packages. Use pip for installing packages (for consistency). + $APT_INSTALL libpython2.7-dev libpython3-dev python-pip python3-pip && \ + pip install $BASE_PIP_PKGS && \ + pip install $PIP_PKGS && \ + cd && \ + virtualenv -p python3 p35 && \ + . p35/bin/activate && \ + pip install $BASE_PIP_PKGS && \ + pip install $PIP_PKGS && \ + # Install R packages and dependencies used when building. + # R depends on pandoc*, libssl (which are installed above). + $APT_INSTALL r-base r-base-dev && \ + $APT_INSTALL texlive-latex-base texlive texlive-fonts-extra texinfo qpdf && \ + Rscript -e "install.packages(c('curl', 'xml2', 'httr', 'devtools', 'testthat', 'knitr', 'rmarkdown', 'roxygen2', 'e1071', 'survival'), repos='http://cran.us.r-project.org/')" && \ + Rscript -e "devtools::install_github('jimhester/lintr')" && \ + # Install tools needed to build the documentation. + $APT_INSTALL ruby2.3 ruby2.3-dev && \ + gem install jekyll --no-rdoc --no-ri && \ + gem install jekyll-redirect-from && \ + gem install pygments.rb + +WORKDIR /opt/spark-rm/output + +ARG UID +RUN useradd -m -s /bin/bash -p spark-rm -u $UID spark-rm +USER spark-rm:spark-rm + +ENTRYPOINT [ "/opt/spark-rm/do-release.sh" ] diff --git a/dev/create-release/vote.tmpl b/dev/create-release/vote.tmpl new file mode 100644 index 0000000000000..2ce953c2f7ec4 --- /dev/null +++ b/dev/create-release/vote.tmpl @@ -0,0 +1,65 @@ +Please vote on releasing the following candidate as Apache Spark version {version}. + +The vote is open until {deadline} and passes if a majority +1 PMC votes are cast, with +a minimum of 3 +1 votes. + +[ ] +1 Release this package as Apache Spark {version} +[ ] -1 Do not release this package because ... + +To learn more about Apache Spark, please see http://spark.apache.org/ + +The tag to be voted on is {tag} (commit {tag_commit}): +https://github.com/apache/spark/tree/{tag} + +The release files, including signatures, digests, etc. can be found at: +https://dist.apache.org/repos/dist/dev/spark/{tag}-bin/ + +Signatures used for Spark RCs can be found in this file: +https://dist.apache.org/repos/dist/dev/spark/KEYS + +The staging repository for this release can be found at: +https://repository.apache.org/content/repositories/orgapachespark-{repo_id}/ + +The documentation corresponding to this release can be found at: +https://dist.apache.org/repos/dist/dev/spark/{tag}-docs/ + +The list of bug fixes going into {version} can be found at the following URL: +https://issues.apache.org/jira/projects/SPARK/versions/{jira_version_id} + +FAQ + +========================= +How can I help test this release? +========================= + +If you are a Spark user, you can help us test this release by taking +an existing Spark workload and running on this release candidate, then +reporting any regressions. + +If you're working in PySpark you can set up a virtual env and install +the current RC and see if anything important breaks, in the Java/Scala +you can add the staging repository to your projects resolvers and test +with the RC (make sure to clean up the artifact cache before/after so +you don't end up building with a out of date RC going forward). + +=========================================== +What should happen to JIRA tickets still targeting {version}? +=========================================== + +The current list of open tickets targeted at {version} can be found at: +{open_issues_link} + +Committers should look at those and triage. Extremely important bug +fixes, documentation, and API tweaks that impact compatibility should +be worked on immediately. Everything else please retarget to an +appropriate release. + +================== +But my bug isn't fixed? +================== + +In order to make timely releases, we will typically not hold the +release unless the bug in question is a regression from the previous +release. That being said, if there is something which is a regression +that has not been correctly targeted please ping me or a committer to +help target the issue. \ No newline at end of file From c7e2742f9bce2fcb7c717df80761939272beff54 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 23 Jun 2018 17:40:20 -0700 Subject: [PATCH 1002/2461] [SPARK-24190][SQL] Allow saving of JSON files in UTF-16 and UTF-32 ## What changes were proposed in this pull request? Currently, restrictions in JSONOptions for `encoding` and `lineSep` are the same for read and for write. For example, a requirement for `lineSep` in the code: ``` df.write.option("encoding", "UTF-32BE").json(file) ``` doesn't allow to skip `lineSep` and use its default value `\n` because it throws the exception: ``` equirement failed: The lineSep option must be specified for the UTF-32BE encoding java.lang.IllegalArgumentException: requirement failed: The lineSep option must be specified for the UTF-32BE encoding ``` In the PR, I propose to separate JSONOptions in read and write, and make JSONOptions in write less restrictive. ## How was this patch tested? Added new test for blacklisted encodings in read. And the `lineSep` option was removed in write for some tests. Author: Maxim Gekk Author: Maxim Gekk Closes #21247 from MaxGekk/json-options-in-write. --- .../spark/sql/catalyst/json/JSONOptions.scala | 71 +++++++++++++------ .../datasources/json/JsonFileFormat.scala | 13 ++-- .../datasources/json/JsonSuite.scala | 56 ++++++++++----- 3 files changed, 95 insertions(+), 45 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index c081772116f84..47eeb70e00427 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -97,32 +97,16 @@ private[sql] class JSONOptions( sep } + protected def checkedEncoding(enc: String): String = enc + /** * Standard encoding (charset) name. For example UTF-8, UTF-16LE and UTF-32BE. - * If the encoding is not specified (None), it will be detected automatically - * when the multiLine option is set to `true`. + * If the encoding is not specified (None) in read, it will be detected automatically + * when the multiLine option is set to `true`. If encoding is not specified in write, + * UTF-8 is used by default. */ val encoding: Option[String] = parameters.get("encoding") - .orElse(parameters.get("charset")).map { enc => - // The following encodings are not supported in per-line mode (multiline is false) - // because they cause some problems in reading files with BOM which is supposed to - // present in the files with such encodings. After splitting input files by lines, - // only the first lines will have the BOM which leads to impossibility for reading - // the rest lines. Besides of that, the lineSep option must have the BOM in such - // encodings which can never present between lines. - val blacklist = Seq(Charset.forName("UTF-16"), Charset.forName("UTF-32")) - val isBlacklisted = blacklist.contains(Charset.forName(enc)) - require(multiLine || !isBlacklisted, - s"""The $enc encoding in the blacklist is not allowed when multiLine is disabled. - |Blacklist: ${blacklist.mkString(", ")}""".stripMargin) - - val isLineSepRequired = - multiLine || Charset.forName(enc) == StandardCharsets.UTF_8 || lineSeparator.nonEmpty - - require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding") - - enc - } + .orElse(parameters.get("charset")).map(checkedEncoding) val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => lineSep.getBytes(encoding.getOrElse("UTF-8")) @@ -141,3 +125,46 @@ private[sql] class JSONOptions( factory.configure(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS, allowUnquotedControlChars) } } + +private[sql] class JSONOptionsInRead( + @transient override val parameters: CaseInsensitiveMap[String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String) + extends JSONOptions(parameters, defaultTimeZoneId, defaultColumnNameOfCorruptRecord) { + + def this( + parameters: Map[String, String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String = "") = { + this( + CaseInsensitiveMap(parameters), + defaultTimeZoneId, + defaultColumnNameOfCorruptRecord) + } + + protected override def checkedEncoding(enc: String): String = { + val isBlacklisted = JSONOptionsInRead.blacklist.contains(Charset.forName(enc)) + require(multiLine || !isBlacklisted, + s"""The ${enc} encoding must not be included in the blacklist when multiLine is disabled: + |Blacklist: ${JSONOptionsInRead.blacklist.mkString(", ")}""".stripMargin) + + val isLineSepRequired = + multiLine || Charset.forName(enc) == StandardCharsets.UTF_8 || lineSeparator.nonEmpty + require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding") + + enc + } +} + +private[sql] object JSONOptionsInRead { + // The following encodings are not supported in per-line mode (multiline is false) + // because they cause some problems in reading files with BOM which is supposed to + // present in the files with such encodings. After splitting input files by lines, + // only the first lines will have the BOM which leads to impossibility for reading + // the rest lines. Besides of that, the lineSep option must have the BOM in such + // encodings which can never present between lines. + val blacklist = Seq( + Charset.forName("UTF-16"), + Charset.forName("UTF-32") + ) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 3b04510d29695..e9a0b383b5f49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions, JSONOptionsInRead} import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ @@ -40,7 +40,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], path: Path): Boolean = { - val parsedOptions = new JSONOptions( + val parsedOptions = new JSONOptionsInRead( options, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) @@ -52,7 +52,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - val parsedOptions = new JSONOptions( + val parsedOptions = new JSONOptionsInRead( options, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) @@ -99,7 +99,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - val parsedOptions = new JSONOptions( + val parsedOptions = new JSONOptionsInRead( options, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) @@ -158,6 +158,11 @@ private[json] class JsonOutputWriter( case None => StandardCharsets.UTF_8 } + if (JSONOptionsInRead.blacklist.contains(encoding)) { + logWarning(s"The JSON file ($path) was written in the encoding ${encoding.displayName()}" + + " which can be read back by Spark only if multiLine is enabled.") + } + private val writer = CodecStreams.createOutputStreamWriter( context, new Path(path), encoding) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index a8a4a524a97f9..897424daca0cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.datasources.json -import java.io.{File, FileOutputStream, StringWriter} -import java.nio.charset.{StandardCharsets, UnsupportedCharsetException} -import java.nio.file.{Files, Paths, StandardOpenOption} +import java.io._ +import java.nio.charset.{Charset, StandardCharsets, UnsupportedCharsetException} +import java.nio.file.Files import java.sql.{Date, Timestamp} import java.util.Locale @@ -2262,7 +2262,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { withTempPath { path => val df = spark.createDataset(Seq(("Dog", 42))) df.write - .options(Map("encoding" -> encoding, "lineSep" -> "\n")) + .options(Map("encoding" -> encoding)) .json(path.getCanonicalPath) checkEncoding( @@ -2286,16 +2286,22 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-23723: wrong output encoding") { val encoding = "UTF-128" - val exception = intercept[UnsupportedCharsetException] { + val exception = intercept[SparkException] { withTempPath { path => val df = spark.createDataset(Seq((0))) df.write - .options(Map("encoding" -> encoding, "lineSep" -> "\n")) + .options(Map("encoding" -> encoding)) .json(path.getCanonicalPath) } } - assert(exception.getMessage == encoding) + val baos = new ByteArrayOutputStream() + val ps = new PrintStream(baos, true, "UTF-8") + exception.printStackTrace(ps) + ps.flush() + + assert(baos.toString.contains( + "java.nio.charset.UnsupportedCharsetException: UTF-128")) } test("SPARK-23723: read back json in UTF-16LE") { @@ -2316,18 +2322,17 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-23723: write json in UTF-16/32 with multiline off") { Seq("UTF-16", "UTF-32").foreach { encoding => withTempPath { path => - val ds = spark.createDataset(Seq( - ("a", 1), ("b", 2), ("c", 3)) - ).repartition(2) - val e = intercept[IllegalArgumentException] { - ds.write - .option("encoding", encoding) - .option("multiline", "false") - .format("json").mode("overwrite") - .save(path.getCanonicalPath) - }.getMessage - assert(e.contains( - s"$encoding encoding in the blacklist is not allowed when multiLine is disabled")) + val ds = spark.createDataset(Seq(("a", 1))).repartition(1) + ds.write + .option("encoding", encoding) + .option("multiline", false) + .json(path.getCanonicalPath) + val jsonFiles = path.listFiles().filter(_.getName.endsWith("json")) + jsonFiles.foreach { jsonFile => + val readback = Files.readAllBytes(jsonFile.toPath) + val expected = ("""{"_1":"a","_2":1}""" + "\n").getBytes(Charset.forName(encoding)) + assert(readback === expected) + } } } } @@ -2476,4 +2481,17 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { checkAnswer(df, Row(Row(1, "string")) :: Row(Row(2, null)) :: Row(null) :: Nil) } } + + test("SPARK-24190: restrictions for JSONOptions in read") { + for (encoding <- Set("UTF-16", "UTF-32")) { + val exception = intercept[IllegalArgumentException] { + spark.read + .option("encoding", encoding) + .option("multiLine", false) + .json(testFile("test-data/utf16LE.json")) + .count() + } + assert(exception.getMessage.contains("encoding must not be included in the blacklist")) + } + } } From 98f363b77488792009f5b97bf831ede280f232e2 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 23 Jun 2018 17:51:18 -0700 Subject: [PATCH 1003/2461] [SPARK-24206][SQL] Improve FilterPushdownBenchmark benchmark code ## What changes were proposed in this pull request? This pr added benchmark code `FilterPushdownBenchmark` for string pushdown and updated performance results on the AWS `r3.xlarge`. ## How was this patch tested? N/A Author: Takeshi Yamamuro Closes #21288 from maropu/UpdateParquetBenchmark. --- .../spark/sql/FilterPushdownBenchmark.scala | 243 ---------- .../benchmark/FilterPushdownBenchmark.scala | 442 ++++++++++++++++++ 2 files changed, 442 insertions(+), 243 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala deleted file mode 100644 index c6dd7dadc9d93..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala +++ /dev/null @@ -1,243 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.io.File - -import scala.util.{Random, Try} - -import org.apache.spark.SparkConf -import org.apache.spark.sql.functions.monotonically_increasing_id -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.{Benchmark, Utils} - - -/** - * Benchmark to measure read performance with Filter pushdown. - */ -object FilterPushdownBenchmark { - val conf = new SparkConf() - conf.set("orc.compression", "snappy") - conf.set("spark.sql.parquet.compression.codec", "snappy") - - private val spark = SparkSession.builder() - .master("local[1]") - .appName("FilterPushdownBenchmark") - .config(conf) - .getOrCreate() - - def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - - def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(spark.catalog.dropTempView) - } - - def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) - (keys, values).zipped.foreach(spark.conf.set) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => spark.conf.set(key, value) - case (key, None) => spark.conf.unset(key) - } - } - } - - private def prepareTable(dir: File, numRows: Int, width: Int): Unit = { - import spark.implicits._ - val selectExpr = (1 to width).map(i => s"CAST(value AS STRING) c$i") - val df = spark.range(numRows).map(_ => Random.nextLong).selectExpr(selectExpr: _*) - .withColumn("id", monotonically_increasing_id()) - - val dirORC = dir.getCanonicalPath + "/orc" - val dirParquet = dir.getCanonicalPath + "/parquet" - - df.write.mode("overwrite").orc(dirORC) - df.write.mode("overwrite").parquet(dirParquet) - - spark.read.orc(dirORC).createOrReplaceTempView("orcTable") - spark.read.parquet(dirParquet).createOrReplaceTempView("parquetTable") - } - - def filterPushDownBenchmark( - values: Int, - title: String, - whereExpr: String, - selectExpr: String = "*"): Unit = { - val benchmark = new Benchmark(title, values, minNumIters = 5) - - Seq(false, true).foreach { pushDownEnabled => - val name = s"Parquet Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" - benchmark.addCase(name) { _ => - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { - spark.sql(s"SELECT $selectExpr FROM parquetTable WHERE $whereExpr").collect() - } - } - } - - Seq(false, true).foreach { pushDownEnabled => - val name = s"Native ORC Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" - benchmark.addCase(name) { _ => - withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { - spark.sql(s"SELECT $selectExpr FROM orcTable WHERE $whereExpr").collect() - } - } - } - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz - - Select 0 row (id IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7882 / 7957 2.0 501.1 1.0X - Parquet Vectorized (Pushdown) 55 / 60 285.2 3.5 142.9X - Native ORC Vectorized 5592 / 5627 2.8 355.5 1.4X - Native ORC Vectorized (Pushdown) 66 / 70 237.2 4.2 118.9X - - Select 0 row (7864320 < id < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7884 / 7909 2.0 501.2 1.0X - Parquet Vectorized (Pushdown) 739 / 752 21.3 47.0 10.7X - Native ORC Vectorized 5614 / 5646 2.8 356.9 1.4X - Native ORC Vectorized (Pushdown) 81 / 83 195.2 5.1 97.8X - - Select 1 row (id = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7905 / 8027 2.0 502.6 1.0X - Parquet Vectorized (Pushdown) 740 / 766 21.2 47.1 10.7X - Native ORC Vectorized 5684 / 5738 2.8 361.4 1.4X - Native ORC Vectorized (Pushdown) 78 / 81 202.4 4.9 101.7X - - Select 1 row (id <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7928 / 7993 2.0 504.1 1.0X - Parquet Vectorized (Pushdown) 747 / 772 21.0 47.5 10.6X - Native ORC Vectorized 5728 / 5753 2.7 364.2 1.4X - Native ORC Vectorized (Pushdown) 76 / 78 207.9 4.8 104.8X - - Select 1 row (7864320 <= id <= 7864320):Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7939 / 8021 2.0 504.8 1.0X - Parquet Vectorized (Pushdown) 746 / 770 21.1 47.4 10.6X - Native ORC Vectorized 5690 / 5734 2.8 361.7 1.4X - Native ORC Vectorized (Pushdown) 76 / 79 206.7 4.8 104.3X - - Select 1 row (7864319 < id < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7972 / 8019 2.0 506.9 1.0X - Parquet Vectorized (Pushdown) 742 / 764 21.2 47.2 10.7X - Native ORC Vectorized 5704 / 5743 2.8 362.6 1.4X - Native ORC Vectorized (Pushdown) 76 / 78 207.9 4.8 105.4X - - Select 10% rows (id < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 8733 / 8808 1.8 555.2 1.0X - Parquet Vectorized (Pushdown) 2213 / 2267 7.1 140.7 3.9X - Native ORC Vectorized 6420 / 6463 2.4 408.2 1.4X - Native ORC Vectorized (Pushdown) 1313 / 1331 12.0 83.5 6.7X - - Select 50% rows (id < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 11518 / 11591 1.4 732.3 1.0X - Parquet Vectorized (Pushdown) 7962 / 7991 2.0 506.2 1.4X - Native ORC Vectorized 8927 / 8985 1.8 567.6 1.3X - Native ORC Vectorized (Pushdown) 6102 / 6160 2.6 387.9 1.9X - - Select 90% rows (id < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14255 / 14389 1.1 906.3 1.0X - Parquet Vectorized (Pushdown) 13564 / 13594 1.2 862.4 1.1X - Native ORC Vectorized 11442 / 11608 1.4 727.5 1.2X - Native ORC Vectorized (Pushdown) 10991 / 11029 1.4 698.8 1.3X - - Select all rows (id IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14917 / 14938 1.1 948.4 1.0X - Parquet Vectorized (Pushdown) 14910 / 14964 1.1 948.0 1.0X - Native ORC Vectorized 11986 / 12069 1.3 762.0 1.2X - Native ORC Vectorized (Pushdown) 12037 / 12123 1.3 765.3 1.2X - - Select all rows (id > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14951 / 14976 1.1 950.6 1.0X - Parquet Vectorized (Pushdown) 14934 / 15016 1.1 949.5 1.0X - Native ORC Vectorized 12000 / 12156 1.3 763.0 1.2X - Native ORC Vectorized (Pushdown) 12079 / 12113 1.3 767.9 1.2X - - Select all rows (id != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14930 / 14972 1.1 949.3 1.0X - Parquet Vectorized (Pushdown) 15015 / 15047 1.0 954.6 1.0X - Native ORC Vectorized 12090 / 12259 1.3 768.7 1.2X - Native ORC Vectorized (Pushdown) 12021 / 12096 1.3 764.2 1.2X - */ - benchmark.run() - } - - def main(args: Array[String]): Unit = { - val numRows = 1024 * 1024 * 15 - val width = 5 - val mid = numRows / 2 - - withTempPath { dir => - withTempTable("orcTable", "patquetTable") { - prepareTable(dir, numRows, width) - - Seq("id IS NULL", s"$mid < id AND id < $mid").foreach { whereExpr => - val title = s"Select 0 row ($whereExpr)".replace("id AND id", "id") - filterPushDownBenchmark(numRows, title, whereExpr) - } - - Seq( - s"id = $mid", - s"id <=> $mid", - s"$mid <= id AND id <= $mid", - s"${mid - 1} < id AND id < ${mid + 1}" - ).foreach { whereExpr => - val title = s"Select 1 row ($whereExpr)".replace("id AND id", "id") - filterPushDownBenchmark(numRows, title, whereExpr) - } - - val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(id)") - - Seq(10, 50, 90).foreach { percent => - filterPushDownBenchmark( - numRows, - s"Select $percent% rows (id < ${numRows * percent / 100})", - s"id < ${numRows * percent / 100}", - selectExpr - ) - } - - Seq("id IS NOT NULL", "id > -1", "id != -1").foreach { whereExpr => - filterPushDownBenchmark( - numRows, - s"Select all rows ($whereExpr)", - whereExpr, - selectExpr) - } - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala new file mode 100644 index 0000000000000..6d7c7de9a856e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala @@ -0,0 +1,442 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import java.io.File + +import scala.util.{Random, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.functions.monotonically_increasing_id +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.{Benchmark, Utils} + + +/** + * Benchmark to measure read performance with Filter pushdown. + * To run this: + * spark-submit --class + */ +object FilterPushdownBenchmark { + val conf = new SparkConf() + .setAppName("FilterPushdownBenchmark") + // Since `spark.master` always exists, overrides this value + .set("spark.master", "local[1]") + .setIfMissing("spark.driver.memory", "3g") + .setIfMissing("spark.executor.memory", "3g") + .setIfMissing("spark.ui.enabled", "false") + .setIfMissing("orc.compression", "snappy") + .setIfMissing("spark.sql.parquet.compression.codec", "snappy") + + private val spark = SparkSession.builder().config(conf).getOrCreate() + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + + private def prepareTable( + dir: File, numRows: Int, width: Int, useStringForValue: Boolean): Unit = { + import spark.implicits._ + val selectExpr = (1 to width).map(i => s"CAST(value AS STRING) c$i") + val valueCol = if (useStringForValue) { + monotonically_increasing_id().cast("string") + } else { + monotonically_increasing_id() + } + val df = spark.range(numRows).map(_ => Random.nextLong).selectExpr(selectExpr: _*) + .withColumn("value", valueCol) + .sort("value") + + saveAsOrcTable(df, dir.getCanonicalPath + "/orc") + saveAsParquetTable(df, dir.getCanonicalPath + "/parquet") + } + + private def prepareStringDictTable( + dir: File, numRows: Int, numDistinctValues: Int, width: Int): Unit = { + val selectExpr = (0 to width).map { + case 0 => s"CAST(id % $numDistinctValues AS STRING) AS value" + case i => s"CAST(rand() AS STRING) c$i" + } + val df = spark.range(numRows).selectExpr(selectExpr: _*).sort("value") + + saveAsOrcTable(df, dir.getCanonicalPath + "/orc") + saveAsParquetTable(df, dir.getCanonicalPath + "/parquet") + } + + private def saveAsOrcTable(df: DataFrame, dir: String): Unit = { + // To always turn on dictionary encoding, we set 1.0 at the threshold (the default is 0.8) + df.write.mode("overwrite").option("orc.dictionary.key.threshold", 1.0).orc(dir) + spark.read.orc(dir).createOrReplaceTempView("orcTable") + } + + private def saveAsParquetTable(df: DataFrame, dir: String): Unit = { + df.write.mode("overwrite").parquet(dir) + spark.read.parquet(dir).createOrReplaceTempView("parquetTable") + } + + def filterPushDownBenchmark( + values: Int, + title: String, + whereExpr: String, + selectExpr: String = "*"): Unit = { + val benchmark = new Benchmark(title, values, minNumIters = 5) + + Seq(false, true).foreach { pushDownEnabled => + val name = s"Parquet Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { + spark.sql(s"SELECT $selectExpr FROM parquetTable WHERE $whereExpr").collect() + } + } + } + + Seq(false, true).foreach { pushDownEnabled => + val name = s"Native ORC Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { + spark.sql(s"SELECT $selectExpr FROM orcTable WHERE $whereExpr").collect() + } + } + } + + /* + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz + Select 0 string row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 9201 / 9300 1.7 585.0 1.0X + Parquet Vectorized (Pushdown) 89 / 105 176.3 5.7 103.1X + Native ORC Vectorized 8886 / 8898 1.8 564.9 1.0X + Native ORC Vectorized (Pushdown) 110 / 128 143.4 7.0 83.9X + + + Select 0 string row + ('7864320' < value < '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 9336 / 9357 1.7 593.6 1.0X + Parquet Vectorized (Pushdown) 927 / 937 17.0 58.9 10.1X + Native ORC Vectorized 9026 / 9041 1.7 573.9 1.0X + Native ORC Vectorized (Pushdown) 257 / 272 61.1 16.4 36.3X + + + Select 1 string row (value = '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 9209 / 9223 1.7 585.5 1.0X + Parquet Vectorized (Pushdown) 908 / 925 17.3 57.7 10.1X + Native ORC Vectorized 8878 / 8904 1.8 564.4 1.0X + Native ORC Vectorized (Pushdown) 248 / 261 63.4 15.8 37.1X + + + Select 1 string row + (value <=> '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 9194 / 9216 1.7 584.5 1.0X + Parquet Vectorized (Pushdown) 899 / 908 17.5 57.2 10.2X + Native ORC Vectorized 8934 / 8962 1.8 568.0 1.0X + Native ORC Vectorized (Pushdown) 249 / 254 63.3 15.8 37.0X + + + Select 1 string row + ('7864320' <= value <= '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 9332 / 9351 1.7 593.3 1.0X + Parquet Vectorized (Pushdown) 915 / 934 17.2 58.2 10.2X + Native ORC Vectorized 9049 / 9057 1.7 575.3 1.0X + Native ORC Vectorized (Pushdown) 248 / 258 63.5 15.8 37.7X + + + Select all string rows + (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 20478 / 20497 0.8 1301.9 1.0X + Parquet Vectorized (Pushdown) 20461 / 20550 0.8 1300.9 1.0X + Native ORC Vectorized 27464 / 27482 0.6 1746.1 0.7X + Native ORC Vectorized (Pushdown) 27454 / 27488 0.6 1745.5 0.7X + + + Select 0 int row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8489 / 8519 1.9 539.7 1.0X + Parquet Vectorized (Pushdown) 64 / 69 246.1 4.1 132.8X + Native ORC Vectorized 8064 / 8099 2.0 512.7 1.1X + Native ORC Vectorized (Pushdown) 88 / 94 178.6 5.6 96.4X + + + Select 0 int row + (7864320 < value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8494 / 8514 1.9 540.0 1.0X + Parquet Vectorized (Pushdown) 835 / 840 18.8 53.1 10.2X + Native ORC Vectorized 8090 / 8106 1.9 514.4 1.0X + Native ORC Vectorized (Pushdown) 249 / 257 63.2 15.8 34.1X + + + Select 1 int row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8552 / 8560 1.8 543.7 1.0X + Parquet Vectorized (Pushdown) 837 / 841 18.8 53.2 10.2X + Native ORC Vectorized 8178 / 8188 1.9 519.9 1.0X + Native ORC Vectorized (Pushdown) 249 / 258 63.2 15.8 34.4X + + + Select 1 int row (value <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8562 / 8580 1.8 544.3 1.0X + Parquet Vectorized (Pushdown) 833 / 836 18.9 53.0 10.3X + Native ORC Vectorized 8164 / 8185 1.9 519.0 1.0X + Native ORC Vectorized (Pushdown) 245 / 254 64.3 15.6 35.0X + + + Select 1 int row + (7864320 <= value <= 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8540 / 8555 1.8 542.9 1.0X + Parquet Vectorized (Pushdown) 837 / 839 18.8 53.2 10.2X + Native ORC Vectorized 8182 / 8231 1.9 520.2 1.0X + Native ORC Vectorized (Pushdown) 250 / 259 62.9 15.9 34.1X + + + Select 1 int row + (7864319 < value < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8535 / 8555 1.8 542.6 1.0X + Parquet Vectorized (Pushdown) 835 / 841 18.8 53.1 10.2X + Native ORC Vectorized 8159 / 8179 1.9 518.8 1.0X + Native ORC Vectorized (Pushdown) 244 / 250 64.5 15.5 35.0X + + + Select 10% int rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 9609 / 9634 1.6 610.9 1.0X + Parquet Vectorized (Pushdown) 2663 / 2672 5.9 169.3 3.6X + Native ORC Vectorized 9824 / 9850 1.6 624.6 1.0X + Native ORC Vectorized (Pushdown) 2717 / 2722 5.8 172.7 3.5X + + + Select 50% int rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 13592 / 13613 1.2 864.2 1.0X + Parquet Vectorized (Pushdown) 9720 / 9738 1.6 618.0 1.4X + Native ORC Vectorized 16366 / 16397 1.0 1040.5 0.8X + Native ORC Vectorized (Pushdown) 12437 / 12459 1.3 790.7 1.1X + + + Select 90% int rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 17580 / 17617 0.9 1117.7 1.0X + Parquet Vectorized (Pushdown) 16803 / 16827 0.9 1068.3 1.0X + Native ORC Vectorized 24169 / 24187 0.7 1536.6 0.7X + Native ORC Vectorized (Pushdown) 22147 / 22341 0.7 1408.1 0.8X + + + Select all int rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 18461 / 18491 0.9 1173.7 1.0X + Parquet Vectorized (Pushdown) 18466 / 18530 0.9 1174.1 1.0X + Native ORC Vectorized 24231 / 24270 0.6 1540.6 0.8X + Native ORC Vectorized (Pushdown) 24207 / 24304 0.6 1539.0 0.8X + + + Select all int rows (value > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 18414 / 18453 0.9 1170.7 1.0X + Parquet Vectorized (Pushdown) 18435 / 18464 0.9 1172.1 1.0X + Native ORC Vectorized 24430 / 24454 0.6 1553.2 0.8X + Native ORC Vectorized (Pushdown) 24410 / 24465 0.6 1552.0 0.8X + + + Select all int rows (value != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 18446 / 18457 0.9 1172.8 1.0X + Parquet Vectorized (Pushdown) 18428 / 18440 0.9 1171.6 1.0X + Native ORC Vectorized 24414 / 24450 0.6 1552.2 0.8X + Native ORC Vectorized (Pushdown) 24385 / 24472 0.6 1550.4 0.8X + + + Select 0 distinct string row + (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8322 / 8352 1.9 529.1 1.0X + Parquet Vectorized (Pushdown) 53 / 57 296.3 3.4 156.7X + Native ORC Vectorized 7903 / 7953 2.0 502.4 1.1X + Native ORC Vectorized (Pushdown) 80 / 82 197.2 5.1 104.3X + + + Select 0 distinct string row + ('100' < value < '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8712 / 8743 1.8 553.9 1.0X + Parquet Vectorized (Pushdown) 995 / 1030 15.8 63.3 8.8X + Native ORC Vectorized 8345 / 8362 1.9 530.6 1.0X + Native ORC Vectorized (Pushdown) 84 / 87 187.6 5.3 103.9X + + + Select 1 distinct string row + (value = '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8574 / 8610 1.8 545.1 1.0X + Parquet Vectorized (Pushdown) 1127 / 1135 14.0 71.6 7.6X + Native ORC Vectorized 8163 / 8181 1.9 519.0 1.1X + Native ORC Vectorized (Pushdown) 426 / 433 36.9 27.1 20.1X + + + Select 1 distinct string row + (value <=> '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8549 / 8568 1.8 543.5 1.0X + Parquet Vectorized (Pushdown) 1124 / 1131 14.0 71.4 7.6X + Native ORC Vectorized 8163 / 8210 1.9 519.0 1.0X + Native ORC Vectorized (Pushdown) 426 / 436 36.9 27.1 20.1X + + + Select 1 distinct string row + ('100' <= value <= '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8889 / 8896 1.8 565.2 1.0X + Parquet Vectorized (Pushdown) 1161 / 1168 13.6 73.8 7.7X + Native ORC Vectorized 8519 / 8554 1.8 541.6 1.0X + Native ORC Vectorized (Pushdown) 430 / 437 36.6 27.3 20.7X + + + Select all distinct string rows + (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 20433 / 20533 0.8 1299.1 1.0X + Parquet Vectorized (Pushdown) 20433 / 20456 0.8 1299.1 1.0X + Native ORC Vectorized 25435 / 25513 0.6 1617.1 0.8X + Native ORC Vectorized (Pushdown) 25435 / 25507 0.6 1617.1 0.8X + */ + + benchmark.run() + } + + private def runIntBenchmark(numRows: Int, width: Int, mid: Int): Unit = { + Seq("value IS NULL", s"$mid < value AND value < $mid").foreach { whereExpr => + val title = s"Select 0 int row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + Seq( + s"value = $mid", + s"value <=> $mid", + s"$mid <= value AND value <= $mid", + s"${mid - 1} < value AND value < ${mid + 1}" + ).foreach { whereExpr => + val title = s"Select 1 int row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% int rows (value < ${numRows * percent / 100})", + s"value < ${numRows * percent / 100}", + selectExpr + ) + } + + Seq("value IS NOT NULL", "value > -1", "value != -1").foreach { whereExpr => + filterPushDownBenchmark( + numRows, + s"Select all int rows ($whereExpr)", + whereExpr, + selectExpr) + } + } + + private def runStringBenchmark( + numRows: Int, width: Int, searchValue: Int, colType: String): Unit = { + Seq("value IS NULL", s"'$searchValue' < value AND value < '$searchValue'") + .foreach { whereExpr => + val title = s"Select 0 $colType row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + Seq( + s"value = '$searchValue'", + s"value <=> '$searchValue'", + s"'$searchValue' <= value AND value <= '$searchValue'" + ).foreach { whereExpr => + val title = s"Select 1 $colType row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + + Seq("value IS NOT NULL").foreach { whereExpr => + filterPushDownBenchmark( + numRows, + s"Select all $colType rows ($whereExpr)", + whereExpr, + selectExpr) + } + } + + def main(args: Array[String]): Unit = { + val numRows = 1024 * 1024 * 15 + val width = 5 + + // Pushdown for many distinct value case + withTempPath { dir => + val mid = numRows / 2 + + withTempTable("orcTable", "patquetTable") { + Seq(true, false).foreach { useStringForValue => + prepareTable(dir, numRows, width, useStringForValue) + if (useStringForValue) { + runStringBenchmark(numRows, width, mid, "string") + } else { + runIntBenchmark(numRows, width, mid) + } + } + } + } + + // Pushdown for few distinct value case (use dictionary encoding) + withTempPath { dir => + val numDistinctValues = 200 + val mid = numDistinctValues / 2 + + withTempTable("orcTable", "patquetTable") { + prepareStringDictTable(dir, numRows, numDistinctValues, width) + runStringBenchmark(numRows, width, mid, "distinct string") + } + } + } +} From a5849ad9a3e5d41b5938faa7c592bcc6aec36044 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Sun, 24 Jun 2018 09:28:46 +0800 Subject: [PATCH 1004/2461] [SPARK-24324][PYTHON] Pandas Grouped Map UDF should assign result columns by name ## What changes were proposed in this pull request? Currently, a `pandas_udf` of type `PandasUDFType.GROUPED_MAP` will assign the resulting columns based on index of the return pandas.DataFrame. If a new DataFrame is returned and constructed using a dict, then the order of the columns could be arbitrary and be different than the defined schema for the UDF. If the schema types still match, then no error will be raised and the user will see column names and column data mixed up. This change will first try to assign columns using the return type field names. If a KeyError occurs, then the column index is checked if it is string based. If so, then the error is raised as it is most likely a naming mistake, else it will fallback to assign columns by position and raise a TypeError if the field types do not match. ## How was this patch tested? Added a test that returns a new DataFrame with column order different than the schema. Author: Bryan Cutler Closes #21427 from BryanCutler/arrow-grouped-map-mixesup-cols-SPARK-24324. --- docs/sql-programming-guide.md | 12 +- python/pyspark/sql/functions.py | 7 +- python/pyspark/sql/tests.py | 104 +++++++++++++++++- python/pyspark/worker.py | 55 ++++++--- .../apache/spark/sql/internal/SQLConf.scala | 13 +++ .../sql/execution/arrow/ArrowUtils.scala | 16 +++ .../python/AggregateInPandasExec.scala | 15 ++- .../python/ArrowEvalPythonExec.scala | 15 ++- .../execution/python/ArrowPythonRunner.scala | 15 ++- .../python/FlatMapGroupsInPandasExec.scala | 15 ++- .../execution/python/WindowInPandasExec.scala | 14 ++- 11 files changed, 226 insertions(+), 55 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 4d8a738507bd1..d2db067989434 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1752,14 +1752,10 @@ To use `groupBy().apply()`, the user needs to define the following: * A Python function that defines the computation for each group. * A `StructType` object or a string that defines the schema of the output `DataFrame`. -The output schema will be applied to the columns of the returned `pandas.DataFrame` in order by position, -not by name. This means that the columns in the `pandas.DataFrame` must be indexed so that their -position matches the corresponding field in the schema. - -Note that when creating a new `pandas.DataFrame` using a dictionary, the actual position of the column -can differ from the order that it was placed in the dictionary. It is recommended in this case to -explicitly define the column order using the `columns` keyword, e.g. -`pandas.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])`, or alternatively use an `OrderedDict`. +The column labels of the returned `pandas.DataFrame` must either match the field names in the +defined output schema if specified as strings, or match the field data types by position if not +strings, e.g. integer indices. See [pandas.DataFrame](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html#pandas.DataFrame) +on how to label columns when constructing a `pandas.DataFrame`. Note that all data for a group will be loaded into memory before the function is applied. This can lead to out of memory exceptons, especially if the group sizes are skewed. The configuration for diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 5f5d73307e9c8..9652d3e79b875 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2584,9 +2584,10 @@ def pandas_udf(f=None, returnType=None, functionType=None): A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` The returnType should be a :class:`StructType` describing the schema of the returned - `pandas.DataFrame`. - The length of the returned `pandas.DataFrame` can be arbitrary and the columns must be - indexed so that their position matches the corresponding field in the schema. + `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` must either match + the field names in the defined returnType schema if specified as strings, or match the + field data types by position if not strings, e.g. integer indices. + The length of the returned `pandas.DataFrame` can be arbitrary. Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 3c8a8fcf6e946..35a0636e5cfc0 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4742,7 +4742,6 @@ def test_vectorized_udf_chained(self): def test_vectorized_udf_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, col - df = self.spark.range(10) with QuietTest(self.sc): with self.assertRaisesRegexp( NotImplementedError, @@ -5327,6 +5326,109 @@ def foo3(key, pdf): expected4 = udf3.func((), pdf) self.assertPandasEqual(expected4, result4) + def test_column_order(self): + from collections import OrderedDict + import pandas as pd + from pyspark.sql.functions import pandas_udf, PandasUDFType + + # Helper function to set column names from a list + def rename_pdf(pdf, names): + pdf.rename(columns={old: new for old, new in + zip(pd_result.columns, names)}, inplace=True) + + df = self.data + grouped_df = df.groupby('id') + grouped_pdf = df.toPandas().groupby('id') + + # Function returns a pdf with required column names, but order could be arbitrary using dict + def change_col_order(pdf): + # Constructing a DataFrame from a dict should result in the same order, + # but use from_items to ensure the pdf column order is different than schema + return pd.DataFrame.from_items([ + ('id', pdf.id), + ('u', pdf.v * 2), + ('v', pdf.v)]) + + ordered_udf = pandas_udf( + change_col_order, + 'id long, v int, u int', + PandasUDFType.GROUPED_MAP + ) + + # The UDF result should assign columns by name from the pdf + result = grouped_df.apply(ordered_udf).sort('id', 'v')\ + .select('id', 'u', 'v').toPandas() + pd_result = grouped_pdf.apply(change_col_order) + expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + # Function returns a pdf with positional columns, indexed by range + def range_col_order(pdf): + # Create a DataFrame with positional columns, fix types to long + return pd.DataFrame(list(zip(pdf.id, pdf.v * 3, pdf.v)), dtype='int64') + + range_udf = pandas_udf( + range_col_order, + 'id long, u long, v long', + PandasUDFType.GROUPED_MAP + ) + + # The UDF result uses positional columns from the pdf + result = grouped_df.apply(range_udf).sort('id', 'v') \ + .select('id', 'u', 'v').toPandas() + pd_result = grouped_pdf.apply(range_col_order) + rename_pdf(pd_result, ['id', 'u', 'v']) + expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + # Function returns a pdf with columns indexed with integers + def int_index(pdf): + return pd.DataFrame(OrderedDict([(0, pdf.id), (1, pdf.v * 4), (2, pdf.v)])) + + int_index_udf = pandas_udf( + int_index, + 'id long, u int, v int', + PandasUDFType.GROUPED_MAP + ) + + # The UDF result should assign columns by position of integer index + result = grouped_df.apply(int_index_udf).sort('id', 'v') \ + .select('id', 'u', 'v').toPandas() + pd_result = grouped_pdf.apply(int_index) + rename_pdf(pd_result, ['id', 'u', 'v']) + expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP) + def column_name_typo(pdf): + return pd.DataFrame({'iid': pdf.id, 'v': pdf.v}) + + @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP) + def invalid_positional_types(pdf): + return pd.DataFrame([(u'a', 1.2)]) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, "KeyError: 'id'"): + grouped_df.apply(column_name_typo).collect() + with self.assertRaisesRegexp(Exception, "No cast implemented"): + grouped_df.apply(invalid_positional_types).collect() + + def test_positional_assignment_conf(self): + import pandas as pd + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with self.sql_conf({"spark.sql.execution.pandas.groupedMap.assignColumnsByPosition": True}): + + @pandas_udf("a string, b float", PandasUDFType.GROUPED_MAP) + def foo(_): + return pd.DataFrame([('hi', 1)], columns=['x', 'y']) + + df = self.data + result = df.groupBy('id').apply(foo).select('a', 'b').collect() + for r in result: + self.assertEqual(r.a, 'hi') + self.assertEqual(r.b, 1) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 38fe2ef06eac5..eaaae2b14e107 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -38,6 +38,9 @@ from pyspark.util import _get_argspec, fail_on_stopiteration from pyspark import shuffle +if sys.version >= '3': + basestring = str + pickleSer = PickleSerializer() utf8_deserializer = UTF8Deserializer() @@ -92,7 +95,10 @@ def verify_result_length(*a): return lambda *a: (verify_result_length(*a), arrow_return_type) -def wrap_grouped_map_pandas_udf(f, return_type, argspec): +def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf): + assign_cols_by_pos = runner_conf.get( + "spark.sql.execution.pandas.groupedMap.assignColumnsByPosition", False) + def wrapped(key_series, value_series): import pandas as pd @@ -110,9 +116,13 @@ def wrapped(key_series, value_series): "Number of columns of the returned pandas.DataFrame " "doesn't match specified schema. " "Expected: {} Actual: {}".format(len(return_type), len(result.columns))) - arrow_return_types = (to_arrow_type(field.dataType) for field in return_type) - return [(result[result.columns[i]], arrow_type) - for i, arrow_type in enumerate(arrow_return_types)] + + # Assign result columns by schema name if user labeled with strings, else use position + if not assign_cols_by_pos and any(isinstance(name, basestring) for name in result.columns): + return [(result[field.name], to_arrow_type(field.dataType)) for field in return_type] + else: + return [(result[result.columns[i]], to_arrow_type(field.dataType)) + for i, field in enumerate(return_type)] return wrapped @@ -143,7 +153,7 @@ def wrapped(*series): return lambda *a: (wrapped(*a), arrow_return_type) -def read_single_udf(pickleSer, infile, eval_type): +def read_single_udf(pickleSer, infile, eval_type, runner_conf): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] row_func = None @@ -163,7 +173,7 @@ def read_single_udf(pickleSer, infile, eval_type): return arg_offsets, wrap_scalar_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = _get_argspec(row_func) # signature was lost when wrapping it - return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) + return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF: @@ -175,6 +185,26 @@ def read_single_udf(pickleSer, infile, eval_type): def read_udfs(pickleSer, infile, eval_type): + runner_conf = {} + + if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF): + + # Load conf used for pandas_udf evaluation + num_conf = read_int(infile) + for i in range(num_conf): + k = utf8_deserializer.loads(infile) + v = utf8_deserializer.loads(infile) + runner_conf[k] = v + + # NOTE: if timezone is set here, that implies respectSessionTimeZone is True + timezone = runner_conf.get("spark.sql.session.timeZone", None) + ser = ArrowStreamPandasSerializer(timezone) + else: + ser = BatchedSerializer(PickleSerializer(), 100) + num_udfs = read_int(infile) udfs = {} call_udf = [] @@ -189,7 +219,7 @@ def read_udfs(pickleSer, infile, eval_type): # See FlatMapGroupsInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes - arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) + arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf) udfs['f'] = udf split_offset = arg_offsets[0] + 1 arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]] @@ -201,7 +231,7 @@ def read_udfs(pickleSer, infile, eval_type): # In the special case of a single UDF this will return a single result rather # than a tuple of results; this is the format that the JVM side expects. for i in range(num_udfs): - arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) + arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf) udfs['f%d' % i] = udf args = ["a[%d]" % o for o in arg_offsets] call_udf.append("f%d(%s)" % (i, ", ".join(args))) @@ -210,15 +240,6 @@ def read_udfs(pickleSer, infile, eval_type): mapper = eval(mapper_str, udfs) func = lambda _, it: map(mapper, it) - if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, - PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF): - timezone = utf8_deserializer.loads(infile) - ser = ArrowStreamPandasSerializer(timezone) - else: - ser = BatchedSerializer(PickleSerializer(), 100) - # profiling is not supported for UDF return func, None, ser, ser diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8d2320d8a6ed7..d5fb524a1396f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1161,6 +1161,16 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION = + buildConf("spark.sql.execution.pandas.groupedMap.assignColumnsByPosition") + .internal() + .doc("When true, a grouped map Pandas UDF will assign columns from the returned " + + "Pandas DataFrame based on position, regardless of column label type. When false, " + + "columns will be looked up by name if labeled with a string and fallback to use " + + "position if not.") + .booleanConf + .createWithDefault(false) + val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter") .internal() .doc("When true, the apply function of the rule verifies whether the right node of the" + @@ -1647,6 +1657,9 @@ class SQLConf extends Serializable with Logging { def pandasRespectSessionTimeZone: Boolean = getConf(PANDAS_RESPECT_SESSION_LOCAL_TIMEZONE) + def pandasGroupedMapAssignColumnssByPosition: Boolean = + getConf(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION) + def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index 6ad11bda84bf6..93c8127681b3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -23,6 +23,7 @@ import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit} import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ object ArrowUtils { @@ -120,4 +121,19 @@ object ArrowUtils { StructField(field.getName, dt, field.isNullable) }) } + + /** Return Map with conf settings to be used in ArrowPythonRunner */ + def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = { + val timeZoneConf = if (conf.pandasRespectSessionTimeZone) { + Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone) + } else { + Nil + } + val pandasColsByPosition = if (conf.pandasGroupedMapAssignColumnssByPosition) { + Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION.key -> "true") + } else { + Nil + } + Map(timeZoneConf ++ pandasColsByPosition: _*) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 8e01e8e56a5bd..d00f6f042d6e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils @@ -81,7 +82,7 @@ case class AggregateInPandasExec( val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip @@ -135,10 +136,14 @@ case class AggregateInPandasExec( } val columnarBatchIter = new ArrowPythonRunner( - pyFuncs, bufferSize, reuseWorker, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, argOffsets, aggInputSchema, - sessionLocalTimeZone, pandasRespectSessionTimeZone) - .compute(projectedRowIter, context.partitionId(), context) + pyFuncs, + bufferSize, + reuseWorker, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + argOffsets, + aggInputSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context) val joinedAttributes = groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index c4de214679ae4..0bc21c0986e69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -24,6 +24,7 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType /** @@ -63,7 +64,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi private val batchSize = conf.arrowMaxRecordsPerBatch private val sessionLocalTimeZone = conf.sessionLocalTimeZone - private val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) protected override def evaluate( funcs: Seq[ChainedPythonFunctions], @@ -80,10 +81,14 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) val columnarBatchIter = new ArrowPythonRunner( - funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_SCALAR_PANDAS_UDF, argOffsets, schema, - sessionLocalTimeZone, pandasRespectSessionTimeZone) - .compute(batchIter, context.partitionId(), context) + funcs, + bufferSize, + reuseWorker, + PythonEvalType.SQL_SCALAR_PANDAS_UDF, + argOffsets, + schema, + sessionLocalTimeZone, + pythonRunnerConf).compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 01e19bddbfb66..ca665652f204d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -45,7 +45,7 @@ class ArrowPythonRunner( argOffsets: Array[Array[Int]], schema: StructType, timeZoneId: String, - respectTimeZone: Boolean) + conf: Map[String, String]) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( funcs, bufferSize, reuseWorker, evalType, argOffsets) { @@ -58,12 +58,15 @@ class ArrowPythonRunner( new WriterThread(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { - PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) - if (respectTimeZone) { - PythonRDD.writeUTF(timeZoneId, dataOut) - } else { - dataOut.writeInt(SpecialLengths.NULL) + + // Write config for the worker as a number of key -> value pairs of strings + dataOut.writeInt(conf.size) + for ((k, v) <- conf) { + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v, dataOut) } + + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 513e174c7733e..f5a563baf52df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType /** @@ -77,7 +78,7 @@ case class FlatMapGroupsInPandasExec( val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) // Deduplicate the grouping attributes. // If a grouping attribute also appears in data attributes, then we don't need to send the @@ -139,10 +140,14 @@ case class FlatMapGroupsInPandasExec( val context = TaskContext.get() val columnarBatchIter = new ArrowPythonRunner( - chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema, - sessionLocalTimeZone, pandasRespectSessionTimeZone) - .compute(grouped, context.partitionId(), context) + chainedFunc, + bufferSize, + reuseWorker, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + argOffsets, + dedupSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index c76832a1a3829..628029b13a6c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils @@ -97,7 +98,7 @@ case class WindowInPandasExec( val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) // Extract window expressions and window functions val expressions = windowExpression.flatMap(_.collect { case e: WindowExpression => e }) @@ -154,11 +155,14 @@ case class WindowInPandasExec( } val windowFunctionResult = new ArrowPythonRunner( - pyFuncs, bufferSize, reuseWorker, + pyFuncs, + bufferSize, + reuseWorker, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, - argOffsets, windowInputSchema, - sessionLocalTimeZone, pandasRespectSessionTimeZone) - .compute(pythonInput, context.partitionId(), context) + argOffsets, + windowInputSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(pythonInput, context.partitionId(), context) val joined = new JoinedRow val resultProj = createResultProjection(expressions) From f596ebe4d3170590b6fce34c179e51ee80c965d3 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 24 Jun 2018 23:14:42 -0700 Subject: [PATCH 1005/2461] [SPARK-24327][SQL] Verify and normalize a partition column name based on the JDBC resolved schema ## What changes were proposed in this pull request? This pr modified JDBC datasource code to verify and normalize a partition column based on the JDBC resolved schema before building `JDBCRelation`. Closes #20370 ## How was this patch tested? Added tests in `JDBCSuite`. Author: Takeshi Yamamuro Closes #21379 from maropu/SPARK-24327. --- .../scala/org/apache/spark/util/Utils.scala | 2 +- .../datasources/jdbc/JDBCRelation.scala | 76 +++++++++++++++---- .../jdbc/JdbcRelationProvider.scala | 6 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 51 ++++++++++++- 4 files changed, 118 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c139db46b63a3..a6fd3637663e8 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -100,7 +100,7 @@ private[spark] object Utils extends Logging { */ val DEFAULT_MAX_TO_STRING_FIELDS = 25 - private def maxNumToStringFields = { + private[spark] def maxNumToStringFields = { if (SparkEnv.get != null) { SparkEnv.get.conf.getInt("spark.debug.maxToStringFields", DEFAULT_MAX_TO_STRING_FIELDS) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index b23e5a7722004..b84543ccd7869 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -22,10 +22,12 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.Partition import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils /** * Instructions on how to partition the table among workers. @@ -48,10 +50,17 @@ private[sql] object JDBCRelation extends Logging { * Null value predicate is added to the first partition where clause to include * the rows with null value for the partitions column. * + * @param schema resolved schema of a JDBC table * @param partitioning partition information to generate the where clause for each partition + * @param resolver function used to determine if two identifiers are equal + * @param jdbcOptions JDBC options that contains url * @return an array of partitions with where clause for each partition */ - def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = { + def columnPartition( + schema: StructType, + partitioning: JDBCPartitioningInfo, + resolver: Resolver, + jdbcOptions: JDBCOptions): Array[Partition] = { if (partitioning == null || partitioning.numPartitions <= 1 || partitioning.lowerBound == partitioning.upperBound) { return Array[Partition](JDBCPartition(null, 0)) @@ -78,7 +87,10 @@ private[sql] object JDBCRelation extends Logging { // Overflow and silliness can happen if you subtract then divide. // Here we get a little roundoff, but that's (hopefully) OK. val stride: Long = upperBound / numPartitions - lowerBound / numPartitions - val column = partitioning.column + + val column = verifyAndGetNormalizedColumnName( + schema, partitioning.column, resolver, jdbcOptions) + var i: Int = 0 var currentValue: Long = lowerBound val ans = new ArrayBuffer[Partition]() @@ -99,10 +111,57 @@ private[sql] object JDBCRelation extends Logging { } ans.toArray } + + // Verify column name based on the JDBC resolved schema + private def verifyAndGetNormalizedColumnName( + schema: StructType, + columnName: String, + resolver: Resolver, + jdbcOptions: JDBCOptions): String = { + val dialect = JdbcDialects.get(jdbcOptions.url) + schema.map(_.name).find { fieldName => + resolver(fieldName, columnName) || + resolver(dialect.quoteIdentifier(fieldName), columnName) + }.map(dialect.quoteIdentifier).getOrElse { + throw new AnalysisException(s"User-defined partition column $columnName not " + + s"found in the JDBC relation: ${schema.simpleString(Utils.maxNumToStringFields)}") + } + } + + /** + * Takes a (schema, table) specification and returns the table's Catalyst schema. + * If `customSchema` defined in the JDBC options, replaces the schema's dataType with the + * custom schema's type. + * + * @param resolver function used to determine if two identifiers are equal + * @param jdbcOptions JDBC options that contains url, table and other information. + * @return resolved Catalyst schema of a JDBC table + */ + def getSchema(resolver: Resolver, jdbcOptions: JDBCOptions): StructType = { + val tableSchema = JDBCRDD.resolveTable(jdbcOptions) + jdbcOptions.customSchema match { + case Some(customSchema) => JdbcUtils.getCustomSchema( + tableSchema, customSchema, resolver) + case None => tableSchema + } + } + + /** + * Resolves a Catalyst schema of a JDBC table and returns [[JDBCRelation]] with the schema. + */ + def apply( + parts: Array[Partition], + jdbcOptions: JDBCOptions)( + sparkSession: SparkSession): JDBCRelation = { + val schema = JDBCRelation.getSchema(sparkSession.sessionState.conf.resolver, jdbcOptions) + JDBCRelation(schema, parts, jdbcOptions)(sparkSession) + } } private[sql] case class JDBCRelation( - parts: Array[Partition], jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession) + override val schema: StructType, + parts: Array[Partition], + jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession) extends BaseRelation with PrunedFilteredScan with InsertableRelation { @@ -111,15 +170,6 @@ private[sql] case class JDBCRelation( override val needConversion: Boolean = false - override val schema: StructType = { - val tableSchema = JDBCRDD.resolveTable(jdbcOptions) - jdbcOptions.customSchema match { - case Some(customSchema) => JdbcUtils.getCustomSchema( - tableSchema, customSchema, sparkSession.sessionState.conf.resolver) - case None => tableSchema - } - } - // Check if JDBCRDD.compileFilter can accept input filters override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index f8c5677ea0f2a..2b488bb7121dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -48,8 +48,10 @@ class JdbcRelationProvider extends CreatableRelationProvider JDBCPartitioningInfo( partitionColumn.get, lowerBound.get, upperBound.get, numPartitions.get) } - val parts = JDBCRelation.columnPartition(partitionInfo) - JDBCRelation(parts, jdbcOptions)(sqlContext.sparkSession) + val resolver = sqlContext.conf.resolver + val schema = JDBCRelation.getSchema(resolver, jdbcOptions) + val parts = JDBCRelation.columnPartition(schema, partitionInfo, resolver, jdbcOptions) + JDBCRelation(schema, parts, jdbcOptions)(sqlContext.sparkSession) } override def createRelation( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index bc2aca65e803f..6ea61f02a8206 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -31,8 +31,9 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils} +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRDD, JDBCRelation, JdbcUtils} import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -238,6 +239,11 @@ class JDBCSuite extends SparkFunSuite |OPTIONS (url '$url', dbtable 'TEST."mixedCaseCols"', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement("CREATE TABLE test.partition (THEID INTEGER, `THE ID` INTEGER) " + + "AS SELECT 1, 1") + .executeUpdate() + conn.commit() + // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types. } @@ -1206,4 +1212,47 @@ class JDBCSuite extends SparkFunSuite }.getMessage assert(errMsg.contains("Statement was canceled or the session timed out")) } + + test("SPARK-24327 verify and normalize a partition column based on a JDBC resolved schema") { + def testJdbcParitionColumn(partColName: String, expectedColumnName: String): Unit = { + val df = spark.read.format("jdbc") + .option("url", urlWithUserAndPass) + .option("dbtable", "TEST.PARTITION") + .option("partitionColumn", partColName) + .option("lowerBound", 1) + .option("upperBound", 4) + .option("numPartitions", 3) + .load() + + val quotedPrtColName = testH2Dialect.quoteIdentifier(expectedColumnName) + df.logicalPlan match { + case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) => + val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet + assert(whereClauses === Set( + s"$quotedPrtColName < 2 or $quotedPrtColName is null", + s"$quotedPrtColName >= 2 AND $quotedPrtColName < 3", + s"$quotedPrtColName >= 3")) + } + } + + testJdbcParitionColumn("THEID", "THEID") + testJdbcParitionColumn("\"THEID\"", "THEID") + withSQLConf("spark.sql.caseSensitive" -> "false") { + testJdbcParitionColumn("ThEiD", "THEID") + } + testJdbcParitionColumn("THE ID", "THE ID") + + def testIncorrectJdbcPartitionColumn(partColName: String): Unit = { + val errMsg = intercept[AnalysisException] { + testJdbcParitionColumn(partColName, "THEID") + }.getMessage + assert(errMsg.contains(s"User-defined partition column $partColName not found " + + "in the JDBC relation:")) + } + + testIncorrectJdbcPartitionColumn("NoExistingColumn") + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + testIncorrectJdbcPartitionColumn(testH2Dialect.quoteIdentifier("ThEiD")) + } + } } From 6e0596e2639117d7a0a58b644b0600086f45c7f9 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sun, 24 Jun 2018 23:56:47 -0700 Subject: [PATCH 1006/2461] [SPARK-23931][SQL][FOLLOW-UP] Make `arrays_zip` in function.scala `@scala.annotation.varargs`. ## What changes were proposed in this pull request? This is a follow-up pr of #21045 which added `arrays_zip`. The `arrays_zip` in functions.scala should've been `scala.annotation.varargs`. This pr makes it `scala.annotation.varargs`. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #21630 from ueshin/issues/SPARK-23931/fup1. --- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f792a6fba1d8f..40c40e7083d1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3539,6 +3539,7 @@ object functions { * @group collection_funcs * @since 2.4.0 */ + @scala.annotation.varargs def arrays_zip(e: Column*): Column = withExpr { ArraysZip(e.map(_.expr)) } ////////////////////////////////////////////////////////////////////////////////////////////// From 8ab8ef7733b42e94f687a5520332814ac9caeda8 Mon Sep 17 00:00:00 2001 From: Jim Kleckner Date: Mon, 25 Jun 2018 16:23:23 +0800 Subject: [PATCH 1007/2461] Fix minor typo in docs/cloud-integration.md ## What changes were proposed in this pull request? Minor typo in docs/cloud-integration.md ## How was this patch tested? This is trivial enough that it should not affect tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Jim Kleckner Closes #21629 from jkleckner/fix-doc-typo. --- docs/cloud-integration.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md index ac1c336988930..18e8fe77bbdbe 100644 --- a/docs/cloud-integration.md +++ b/docs/cloud-integration.md @@ -70,7 +70,7 @@ be safely used as the direct destination of work with the normal rename-based co ### Installation With the relevant libraries on the classpath and Spark configured with valid credentials, -objects can be can be read or written by using their URLs as the path to data. +objects can be read or written by using their URLs as the path to data. For example `sparkContext.textFile("s3a://landsat-pds/scene_list.gz")` will create an RDD of the file `scene_list.gz` stored in S3, using the s3a connector. @@ -184,7 +184,8 @@ is no need for a workflow of write-then-rename to ensure that files aren't picke while they are still being written. Applications can write straight to the monitored directory. 1. Streams should only be checkpointed to a store implementing a fast and -atomic `rename()` operation Otherwise the checkpointing may be slow and potentially unreliable. +atomic `rename()` operation. +Otherwise the checkpointing may be slow and potentially unreliable. ## Further Reading From bac50aa37168a7612702a4503750a78ed5d59c78 Mon Sep 17 00:00:00 2001 From: Maryann Xue Date: Mon, 25 Jun 2018 07:17:30 -0700 Subject: [PATCH 1008/2461] [SPARK-24596][SQL] Non-cascading Cache Invalidation ## What changes were proposed in this pull request? 1. Add parameter 'cascade' in CacheManager.uncacheQuery(). Under 'cascade=false' mode, only invalidate the current cache, and for other dependent caches, rebuild execution plan and reuse cached buffer. 2. Pass true/false from callers in different uncache scenarios: - Drop tables and regular (persistent) views: regular mode - Drop temporary views: non-cascading mode - Modify table contents (INSERT/UPDATE/MERGE/DELETE): regular mode - Call `DataSet.unpersist()`: non-cascading mode - Call `Catalog.uncacheTable()`: follow the same convention as drop tables/view, which is, use non-cascading mode for temporary views and regular mode for the rest Note that a regular (persistent) view is a database object just like a table, so after dropping a regular view (whether cached or not cached), any query referring to that view should no long be valid. Hence if a cached persistent view is dropped, we need to invalidate the all dependent caches so that exceptions will be thrown for any later reference. On the other hand, a temporary view is in fact equivalent to an unnamed DataSet, and dropping a temporary view should have no impact on queries referencing that view. Thus we should do non-cascading uncaching for temporary views, which also guarantees a consistent uncaching behavior between temporary views and unnamed DataSets. ## How was this patch tested? New tests in CachedTableSuite and DatasetCacheSuite. Author: Maryann Xue Closes #21594 from maryannxue/noncascading-cache. --- docs/sql-programming-guide.md | 1 + .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../spark/sql/execution/CacheManager.scala | 50 ++++++++++--- .../execution/columnar/InMemoryRelation.scala | 10 +++ .../spark/sql/execution/command/ddl.scala | 8 ++- .../spark/sql/execution/command/tables.scala | 2 +- .../spark/sql/internal/CatalogImpl.scala | 12 ++-- .../apache/spark/sql/CachedTableSuite.scala | 66 ++++++++++++++++- .../apache/spark/sql/DatasetCacheSuite.scala | 70 +++++++++++++++++-- 9 files changed, 197 insertions(+), 26 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d2db067989434..196b814420be1 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1827,6 +1827,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. + - Since Spark 2.4, Spark has enabled non-cascading SQL cache invalidation in addition to the traditional cache invalidation mechanism. The non-cascading cache invalidation mechanism allows users to remove a cache without impacting its dependent caches. This new cache invalidation mechanism is used in scenarios where the data of the cache to be removed is still valid, e.g., calling unpersist() on a Dataset, or dropping a temporary view. This allows users to free up memory and keep the desired caches valid at the same time. - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f5526104690d2..57f1e173211af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2964,6 +2964,7 @@ class Dataset[T] private[sql]( /** * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. + * This will not un-persist any cached data that is built upon this Dataset. * * @param blocking Whether to block until all blocks are deleted. * @@ -2971,12 +2972,13 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def unpersist(blocking: Boolean): this.type = { - sparkSession.sharedState.cacheManager.uncacheQuery(this, blocking) + sparkSession.sharedState.cacheManager.uncacheQuery(this, cascade = false, blocking) this } /** * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. + * This will not un-persist any cached data that is built upon this Dataset. * * @group basic * @since 1.6.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 2db7c02e86014..39d9a95ca4710 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -105,24 +105,50 @@ class CacheManager extends Logging { } /** - * Un-cache all the cache entries that refer to the given plan. + * Un-cache the given plan or all the cache entries that refer to the given plan. + * @param query The [[Dataset]] to be un-cached. + * @param cascade If true, un-cache all the cache entries that refer to the given + * [[Dataset]]; otherwise un-cache the given [[Dataset]] only. + * @param blocking Whether to block until all blocks are deleted. */ - def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock { - uncacheQuery(query.sparkSession, query.logicalPlan, blocking) + def uncacheQuery( + query: Dataset[_], + cascade: Boolean, + blocking: Boolean = true): Unit = writeLock { + uncacheQuery(query.sparkSession, query.logicalPlan, cascade, blocking) } /** - * Un-cache all the cache entries that refer to the given plan. + * Un-cache the given plan or all the cache entries that refer to the given plan. + * @param spark The Spark session. + * @param plan The plan to be un-cached. + * @param cascade If true, un-cache all the cache entries that refer to the given + * plan; otherwise un-cache the given plan only. + * @param blocking Whether to block until all blocks are deleted. */ - def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock { + def uncacheQuery( + spark: SparkSession, + plan: LogicalPlan, + cascade: Boolean, + blocking: Boolean): Unit = writeLock { + val shouldRemove: LogicalPlan => Boolean = + if (cascade) { + _.find(_.sameResult(plan)).isDefined + } else { + _.sameResult(plan) + } val it = cachedData.iterator() while (it.hasNext) { val cd = it.next() - if (cd.plan.find(_.sameResult(plan)).isDefined) { + if (shouldRemove(cd.plan)) { cd.cachedRepresentation.cacheBuilder.clearCache(blocking) it.remove() } } + // Re-compile dependent cached queries after removing the cached query. + if (!cascade) { + recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined, clearCache = false) + } } /** @@ -132,20 +158,24 @@ class CacheManager extends Logging { recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined) } - private def recacheByCondition(spark: SparkSession, condition: LogicalPlan => Boolean): Unit = { + private def recacheByCondition( + spark: SparkSession, + condition: LogicalPlan => Boolean, + clearCache: Boolean = true): Unit = { val it = cachedData.iterator() val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData] while (it.hasNext) { val cd = it.next() if (condition(cd.plan)) { - cd.cachedRepresentation.cacheBuilder.clearCache() + if (clearCache) { + cd.cachedRepresentation.cacheBuilder.clearCache() + } // Remove the cache entry before we create a new one, so that we can have a different // physical plan. it.remove() val plan = spark.sessionState.executePlan(AnalysisBarrier(cd.plan)).executedPlan val newCache = InMemoryRelation( - cacheBuilder = cd.cachedRepresentation - .cacheBuilder.copy(cachedPlan = plan)(_cachedColumnBuffers = null), + cacheBuilder = cd.cachedRepresentation.cacheBuilder.withCachedPlan(plan), logicalPlan = cd.plan) needToRecache += cd.copy(cachedRepresentation = newCache) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index da35a4734e65a..7c8faec53a828 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -74,6 +74,16 @@ case class CachedRDDBuilder( } } + def withCachedPlan(cachedPlan: SparkPlan): CachedRDDBuilder = { + new CachedRDDBuilder( + useCompression, + batchSize, + storageLevel, + cachedPlan = cachedPlan, + tableName + )(_cachedColumnBuffers) + } + private def buildBuffers(): RDD[CachedBatch] = { val output = cachedPlan.output val cached = cachedPlan.execute().mapPartitionsInternal { rowIterator => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index bf4d96fa18d0d..04bf8c6dd917f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -189,8 +189,9 @@ case class DropTableCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog + val isTempView = catalog.isTemporaryTable(tableName) - if (!catalog.isTemporaryTable(tableName) && catalog.tableExists(tableName)) { + if (!isTempView && catalog.tableExists(tableName)) { // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view // issue an exception. catalog.getTableMetadata(tableName).tableType match { @@ -204,9 +205,10 @@ case class DropTableCommand( } } - if (catalog.isTemporaryTable(tableName) || catalog.tableExists(tableName)) { + if (isTempView || catalog.tableExists(tableName)) { try { - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName)) + sparkSession.sharedState.cacheManager.uncacheQuery( + sparkSession.table(tableName), cascade = !isTempView) } catch { case NonFatal(e) => log.warn(e.toString, e) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 44749190c79eb..ec3961f84bd8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -493,7 +493,7 @@ case class TruncateTableCommand( spark.sessionState.refreshTable(tableName.unquotedString) // Also try to drop the contents of the table from the columnar cache try { - spark.sharedState.cacheManager.uncacheQuery(spark.table(table.identifier)) + spark.sharedState.cacheManager.uncacheQuery(spark.table(table.identifier), cascade = true) } catch { case NonFatal(e) => log.warn(s"Exception when attempting to uncache table $tableIdentWithDB", e) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 6ae307bce10c8..4698e8ab13ce3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -364,7 +364,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def dropTempView(viewName: String): Boolean = { sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef => - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true) + sparkSession.sharedState.cacheManager.uncacheQuery( + sparkSession, viewDef, cascade = false, blocking = true) sessionCatalog.dropTempView(viewName) } } @@ -379,7 +380,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def dropGlobalTempView(viewName: String): Boolean = { sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef => - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true) + sparkSession.sharedState.cacheManager.uncacheQuery( + sparkSession, viewDef, cascade = false, blocking = true) sessionCatalog.dropGlobalTempView(viewName) } } @@ -438,7 +440,9 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 2.0.0 */ override def uncacheTable(tableName: String): Unit = { - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName)) + val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) + val cascade = !sessionCatalog.isTemporaryTable(tableIdent) + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName), cascade) } /** @@ -490,7 +494,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { // cached version and make the new version cached lazily. if (isCached(table)) { // Uncache the logicalPlan. - sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true) + sparkSession.sharedState.cacheManager.uncacheQuery(table, cascade = true, blocking = true) // Cache it again. sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 6982c22f4771d..60c73df88896b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec -import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} @@ -801,4 +800,69 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext } assert(cachedData.collect === Seq(1001)) } + + test("SPARK-24596 Non-cascading Cache Invalidation - uncache temporary view") { + withTempView("t1", "t2") { + sql("CACHE TABLE t1 AS SELECT * FROM testData WHERE key > 1") + sql("CACHE TABLE t2 as SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("UNCACHE TABLE t1") + assert(!spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + } + } + + test("SPARK-24596 Non-cascading Cache Invalidation - drop temporary view") { + withTempView("t1", "t2") { + sql("CACHE TABLE t1 AS SELECT * FROM testData WHERE key > 1") + sql("CACHE TABLE t2 as SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("DROP VIEW t1") + assert(spark.catalog.isCached("t2")) + } + } + + test("SPARK-24596 Non-cascading Cache Invalidation - drop persistent view") { + withTable("t") { + spark.range(1, 10).toDF("key").withColumn("value", 'key * 2) + .write.format("json").saveAsTable("t") + withView("t1") { + withTempView("t2") { + sql("CREATE VIEW t1 AS SELECT * FROM t WHERE key > 1") + + sql("CACHE TABLE t1") + sql("CACHE TABLE t2 AS SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("DROP VIEW t1") + assert(!spark.catalog.isCached("t2")) + } + } + } + } + + test("SPARK-24596 Non-cascading Cache Invalidation - uncache table") { + withTable("t") { + spark.range(1, 10).toDF("key").withColumn("value", 'key * 2) + .write.format("json").saveAsTable("t") + withTempView("t1", "t2") { + sql("CACHE TABLE t") + sql("CACHE TABLE t1 AS SELECT * FROM t WHERE key > 1") + sql("CACHE TABLE t2 AS SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t")) + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("UNCACHE TABLE t") + assert(!spark.catalog.isCached("t")) + assert(!spark.catalog.isCached("t1")) + assert(!spark.catalog.isCached("t2")) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index c4f056334cd1a..5c6a021d5b767 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -29,6 +29,16 @@ import org.apache.spark.storage.StorageLevel class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits { import testImplicits._ + /** + * Asserts that a cached [[Dataset]] will be built using the given number of other cached results. + */ + private def assertCacheDependency(df: DataFrame, numOfCachesDependedUpon: Int = 1): Unit = { + val plan = df.queryExecution.withCachedData + assert(plan.isInstanceOf[InMemoryRelation]) + val internalPlan = plan.asInstanceOf[InMemoryRelation].cacheBuilder.cachedPlan + assert(internalPlan.find(_.isInstanceOf[InMemoryTableScanExec]).size == numOfCachesDependedUpon) + } + test("get storage level") { val ds1 = Seq("1", "2").toDS().as("a") val ds2 = Seq(2, 3).toDS().as("b") @@ -117,7 +127,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits } test("cache UDF result correctly") { - val expensiveUDF = udf({x: Int => Thread.sleep(10000); x}) + val expensiveUDF = udf({x: Int => Thread.sleep(5000); x}) val df = spark.range(0, 10).toDF("a").withColumn("b", expensiveUDF($"a")) val df2 = df.agg(sum(df("b"))) @@ -126,7 +136,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits assertCached(df2) // udf has been evaluated during caching, and thus should not be re-evaluated here - failAfter(5 seconds) { + failAfter(3 seconds) { df2.collect() } @@ -143,9 +153,57 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits df.count() df2.cache() - val plan = df2.queryExecution.withCachedData - assert(plan.isInstanceOf[InMemoryRelation]) - val internalPlan = plan.asInstanceOf[InMemoryRelation].cacheBuilder.cachedPlan - assert(internalPlan.find(_.isInstanceOf[InMemoryTableScanExec]).isDefined) + assertCacheDependency(df2) + } + + test("SPARK-24596 Non-cascading Cache Invalidation") { + val df = Seq(("a", 1), ("b", 2)).toDF("s", "i") + val df2 = df.filter('i > 1) + val df3 = df.filter('i < 2) + + df2.cache() + df.cache() + df.count() + df3.cache() + + df.unpersist() + + // df un-cached; df2 and df3's cache plan re-compiled + assert(df.storageLevel == StorageLevel.NONE) + assertCacheDependency(df2, 0) + assertCacheDependency(df3, 0) + } + + test("SPARK-24596 Non-cascading Cache Invalidation - verify cached data reuse") { + val expensiveUDF = udf({ x: Int => Thread.sleep(5000); x }) + val df = spark.range(0, 5).toDF("a") + val df1 = df.withColumn("b", expensiveUDF($"a")) + val df2 = df1.groupBy('a).agg(sum('b)) + val df3 = df.agg(sum('a)) + + df1.cache() + df2.cache() + df2.collect() + df3.cache() + + assertCacheDependency(df2) + + df1.unpersist(blocking = true) + + // df1 un-cached; df2's cache plan re-compiled + assert(df1.storageLevel == StorageLevel.NONE) + assertCacheDependency(df1.groupBy('a).agg(sum('b)), 0) + + val df4 = df1.groupBy('a).agg(sum('b)).agg(sum("sum(b)")) + assertCached(df4) + // reuse loaded cache + failAfter(3 seconds) { + checkDataset(df4, Row(10)) + } + + val df5 = df.agg(sum('a)).filter($"sum(a)" > 1) + assertCached(df5) + // first time use, load cache + checkDataset(df5, Row(10)) } } From 594ac4f7b816488091202918c409487058e6d8ac Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 25 Jun 2018 23:44:20 +0800 Subject: [PATCH 1009/2461] [SPARK-24633][SQL] Fix codegen when split is required for arrays_zip ## What changes were proposed in this pull request? In function array_zip, when split is required by the high number of arguments, a codegen error can happen. The PR fixes codegen for cases when splitting the code is required. ## How was this patch tested? added UT Author: Marco Gaido Closes #21621 from mgaido91/SPARK-24633. --- .../catalyst/expressions/collectionOperations.scala | 4 ++-- .../apache/spark/sql/DataFrameFunctionsSuite.scala | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 3afabe14606e4..b6137b07555f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -200,7 +200,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI """.stripMargin } - val splittedGetValuesAndCardinalities = ctx.splitExpressions( + val splittedGetValuesAndCardinalities = ctx.splitExpressionsWithCurrentInputs( expressions = getValuesAndCardinalities, funcName = "getValuesAndCardinalities", returnType = "int", @@ -210,7 +210,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI |return $biggestCardinality; """.stripMargin, foldFunctions = _.map(funcCall => s"$biggestCardinality = $funcCall;").mkString("\n"), - arguments = + extraArguments = ("ArrayData[]", arrVals) :: ("int", biggestCardinality) :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 25fdbab745128..47fe67d8daea3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -556,6 +556,17 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df8.selectExpr("arrays_zip(v1, v2)"), expectedValue8) } + test("SPARK-24633: arrays_zip splits input processing correctly") { + Seq("true", "false").foreach { wholestageCodegenEnabled => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholestageCodegenEnabled) { + val df = spark.range(1) + val exprs = (0 to 5).map(x => array($"id" + lit(x))) + checkAnswer(df.select(arrays_zip(exprs: _*)), + Row(Seq(Row(0, 1, 2, 3, 4, 5)))) + } + } + } + test("map size function") { val df = Seq( (Map[Int, Int](1 -> 1, 2 -> 2), "x"), From 5264164a67df498b73facae207eda12ee133be7d Mon Sep 17 00:00:00 2001 From: Stacy Kerkela Date: Mon, 25 Jun 2018 23:41:39 +0200 Subject: [PATCH 1010/2461] [SPARK-24648][SQL] SqlMetrics should be threadsafe Use LongAdder to make SQLMetrics thread safe. ## What changes were proposed in this pull request? Replace += with LongAdder.add() for concurrent counting ## How was this patch tested? Unit tests with local threads Author: Stacy Kerkela Closes #21634 from dbkerkela/sqlmetrics-concurrency-stacy. --- .../sql/execution/metric/SQLMetrics.scala | 33 ++++++++++------- .../execution/metric/SQLMetricsSuite.scala | 36 ++++++++++++++++++- 2 files changed, 55 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 77b907870d678..b4f0ae1eb1a18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.metric import java.text.NumberFormat import java.util.Locale +import java.util.concurrent.atomic.LongAdder import org.apache.spark.SparkContext import org.apache.spark.scheduler.AccumulableInfo @@ -32,40 +33,45 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils} * on the driver side must be explicitly posted using [[SQLMetrics.postDriverMetricUpdates()]]. */ class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long] { + // This is a workaround for SPARK-11013. // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will // update it at the end of task and the value will be at least 0. Then we can filter out the -1 // values before calculate max, min, etc. - private[this] var _value = initValue - private var _zeroValue = initValue + private[this] val _value = new LongAdder + private val _zeroValue = initValue + _value.add(initValue) override def copy(): SQLMetric = { - val newAcc = new SQLMetric(metricType, _value) - newAcc._zeroValue = initValue + val newAcc = new SQLMetric(metricType, initValue) + newAcc.add(_value.sum()) newAcc } - override def reset(): Unit = _value = _zeroValue + override def reset(): Unit = this.set(_zeroValue) override def merge(other: AccumulatorV2[Long, Long]): Unit = other match { - case o: SQLMetric => _value += o.value + case o: SQLMetric => _value.add(o.value) case _ => throw new UnsupportedOperationException( s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") } - override def isZero(): Boolean = _value == _zeroValue + override def isZero(): Boolean = _value.sum() == _zeroValue - override def add(v: Long): Unit = _value += v + override def add(v: Long): Unit = _value.add(v) // We can set a double value to `SQLMetric` which stores only long value, if it is // average metrics. def set(v: Double): Unit = SQLMetrics.setDoubleForAverageMetrics(this, v) - def set(v: Long): Unit = _value = v + def set(v: Long): Unit = { + _value.reset() + _value.add(v) + } - def +=(v: Long): Unit = _value += v + def +=(v: Long): Unit = _value.add(v) - override def value: Long = _value + override def value: Long = _value.sum() // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { @@ -153,7 +159,7 @@ object SQLMetrics { Seq.fill(3)(0L) } else { val sorted = validValues.sorted - Seq(sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + Seq(sorted.head, sorted(validValues.length / 2), sorted(validValues.length - 1)) } metric.map(v => numberFormat.format(v.toDouble / baseForAvgMetric)) } @@ -173,7 +179,8 @@ object SQLMetrics { Seq.fill(4)(0L) } else { val sorted = validValues.sorted - Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + Seq(sorted.sum, sorted.head, sorted(validValues.length / 2), + sorted(validValues.length - 1)) } metric.map(strFormat) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index a3a3f3851e21c..8263c9c81c49e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.execution.metric import java.io.File +import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future} import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.execution.ui.SQLAppStatusStore import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -504,4 +504,38 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared test("writing data out metrics with dynamic partition: parquet") { testMetricsDynamicPartition("parquet", "parquet", "t1") } + + test("writing metrics from single thread") { + val nAdds = 10 + val acc = new SQLMetric("test", -10) + assert(acc.isZero()) + acc.set(0) + for (i <- 1 to nAdds) acc.add(1) + assert(!acc.isZero()) + assert(nAdds === acc.value) + acc.reset() + assert(acc.isZero()) + } + + test("writing metrics from multiple threads") { + implicit val ec: ExecutionContextExecutor = ExecutionContext.global + val nFutures = 1000 + val nAdds = 100 + val acc = new SQLMetric("test", -10) + assert(acc.isZero() === true) + acc.set(0) + val l = for ( i <- 1 to nFutures ) yield { + Future { + for (j <- 1 to nAdds) acc.add(1) + i + } + } + for (futures <- Future.sequence(l)) { + assert(nFutures === futures.length) + assert(!acc.isZero()) + assert(nFutures * nAdds === acc.value) + acc.reset() + assert(acc.isZero()) + } + } } From baa01c8ca9e8ea456f986fbb223c61ad541b52b0 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 25 Jun 2018 15:12:33 -0700 Subject: [PATCH 1011/2461] [INFRA] Close stale PR. Closes #21614 From 6d16b9885d6ad01e1cc56d5241b7ebad99487a0c Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 25 Jun 2018 16:54:57 -0700 Subject: [PATCH 1012/2461] [SPARK-24552][CORE][SQL] Use task ID instead of attempt number for writes. This passes the unique task attempt id instead of attempt number to v2 data sources because attempt number is reused when stages are retried. When attempt numbers are reused, sources that track data by partition id and attempt number may incorrectly clean up data because the same attempt number can be both committed and aborted. For v1 / Hadoop writes, generate a unique ID based on available attempt numbers to avoid a similar problem. Closes #21558 Author: Marcelo Vanzin Author: Ryan Blue Closes #21606 from vanzin/SPARK-24552.2. --- .../spark/internal/io/SparkHadoopWriter.scala | 6 +++- .../sql/kafka010/KafkaStreamWriter.scala | 2 +- .../sources/v2/writer/DataSourceWriter.java | 8 +++--- .../sql/sources/v2/writer/DataWriter.java | 8 +++--- .../sources/v2/writer/DataWriterFactory.java | 11 +++----- .../datasources/v2/WriteToDataSourceV2.scala | 28 ++++++++++--------- .../continuous/ContinuousWriteRDD.scala | 2 +- .../sources/ForeachWriterProvider.scala | 2 +- .../sources/PackedRowWriterFactory.scala | 2 +- .../streaming/sources/memoryV2.scala | 2 +- .../sources/v2/SimpleWritableDataSource.scala | 8 +++--- 11 files changed, 41 insertions(+), 38 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala index abf39213fa0d2..9ebd0aa301592 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala @@ -76,13 +76,17 @@ object SparkHadoopWriter extends Logging { // Try to write all RDD partitions as a Hadoop OutputFormat. try { val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => { + // SPARK-24552: Generate a unique "attempt ID" based on the stage and task attempt numbers. + // Assumes that there won't be more than Short.MaxValue attempts, at least not concurrently. + val attemptId = (context.stageAttemptNumber << 16) | context.attemptNumber + executeTask( context = context, config = config, jobTrackerId = jobTrackerId, commitJobId = commitJobId, sparkPartitionId = context.partitionId, - sparkAttemptNumber = context.attemptNumber, + sparkAttemptNumber = attemptId, committer = committer, iterator = iter) }) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala index ae5b5c52d514e..32923dc9f5a6b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala @@ -67,7 +67,7 @@ case class KafkaStreamWriterFactory( override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[InternalRow] = { new KafkaStreamDataWriter(topic, producerParams, schema.toAttributes) } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index 0030a9f05dba7..7eedc85a5d6f3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -64,8 +64,8 @@ public interface DataSourceWriter { DataWriterFactory createWriterFactory(); /** - * Returns whether Spark should use the commit coordinator to ensure that at most one attempt for - * each task commits. + * Returns whether Spark should use the commit coordinator to ensure that at most one task for + * each partition commits. * * @return true if commit coordinator should be used, false otherwise. */ @@ -90,9 +90,9 @@ default void onDataWriterCommit(WriterCommitMessage message) {} * is undefined and @{@link #abort(WriterCommitMessage[])} may not be able to deal with it. * * Note that speculative execution may cause multiple tasks to run for a partition. By default, - * Spark uses the commit coordinator to allow at most one attempt to commit. Implementations can + * Spark uses the commit coordinator to allow at most one task to commit. Implementations can * disable this behavior by overriding {@link #useCommitCoordinator()}. If disabled, multiple - * attempts may have committed successfully and one successful commit message per task will be + * tasks may have committed successfully and one successful commit message per task will be * passed to this commit method. The remaining commit messages are ignored by Spark. */ void commit(WriterCommitMessage[] messages); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 39bf458298862..1626c0013e4e7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int, long)} and is + * A data writer returned by {@link DataWriterFactory#createDataWriter(int, long, long)} and is * responsible for writing data for an input RDD partition. * * One Spark task has one exclusive data writer, so there is no thread-safe concern. @@ -39,14 +39,14 @@ * {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an * exception will be sent to the driver side, and Spark may retry this writing task a few times. - * In each retry, {@link DataWriterFactory#createDataWriter(int, int, long)} will receive a - * different `attemptNumber`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])} + * In each retry, {@link DataWriterFactory#createDataWriter(int, long, long)} will receive a + * different `taskId`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])} * when the configured number of retries is exhausted. * * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task * takes too long to finish. Different from retried tasks, which are launched one by one after the * previous one fails, speculative tasks are running simultaneously. It's possible that one input - * RDD partition has multiple data writers with different `attemptNumber` running at the same time, + * RDD partition has multiple data writers with different `taskId` running at the same time, * and data sources should guarantee that these data writers don't conflict and can work together. * Implementations can coordinate with driver during {@link #commit()} to make sure only one of * these data writers can commit successfully. Or implementations can allow all of them to commit diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index 7527bcc0c4027..0932ff8f8f8a7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -42,15 +42,12 @@ public interface DataWriterFactory extends Serializable { * Usually Spark processes many RDD partitions at the same time, * implementations should use the partition id to distinguish writers for * different partitions. - * @param attemptNumber Spark may launch multiple tasks with the same task id. For example, a task - * failed, Spark launches a new task wth the same task id but different - * attempt number. Or a task is too slow, Spark launches new tasks wth the - * same task id but different attempt number, which means there are multiple - * tasks with the same task id running at the same time. Implementations can - * use this attempt number to distinguish writers of different task attempts. + * @param taskId A unique identifier for a task that is performing the write of the partition + * data. Spark may run multiple tasks for the same partition (due to speculation + * or task failures, for example). * @param epochId A monotonically increasing id for streaming queries that are split in to * discrete periods of execution. For non-streaming queries, * this ID will always be 0. */ - DataWriter createDataWriter(int partitionId, int attemptNumber, long epochId); + DataWriter createDataWriter(int partitionId, long taskId, long epochId); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 11ed7131e7e3d..b1148c0f62f7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -29,10 +29,8 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.{MicroBatchExecution, StreamExecution} -import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} +import org.apache.spark.sql.execution.streaming.MicroBatchExecution import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -111,9 +109,10 @@ object DataWritingSparkTask extends Logging { val stageId = context.stageId() val stageAttempt = context.stageAttemptNumber() val partId = context.partitionId() + val taskId = context.taskAttemptId() val attemptId = context.attemptNumber() val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0") - val dataWriter = writeTask.createDataWriter(partId, attemptId, epochId.toLong) + val dataWriter = writeTask.createDataWriter(partId, taskId, epochId.toLong) // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { @@ -125,12 +124,12 @@ object DataWritingSparkTask extends Logging { val coordinator = SparkEnv.get.outputCommitCoordinator val commitAuthorized = coordinator.canCommit(stageId, stageAttempt, partId, attemptId) if (commitAuthorized) { - logInfo(s"Writer for stage $stageId / $stageAttempt, " + - s"task $partId.$attemptId is authorized to commit.") + logInfo(s"Commit authorized for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") dataWriter.commit() } else { - val message = s"Stage $stageId / $stageAttempt, " + - s"task $partId.$attemptId: driver did not authorize commit" + val message = s"Commit denied for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)" logInfo(message) // throwing CommitDeniedException will trigger the catch block for abort throw new CommitDeniedException(message, stageId, partId, attemptId) @@ -141,15 +140,18 @@ object DataWritingSparkTask extends Logging { dataWriter.commit() } - logInfo(s"Writer for stage $stageId, task $partId.$attemptId committed.") + logInfo(s"Committed partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") msg })(catchBlock = { // If there is an error, abort this writer - logError(s"Writer for stage $stageId, task $partId.$attemptId is aborting.") + logError(s"Aborting commit for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") dataWriter.abort() - logError(s"Writer for stage $stageId, task $partId.$attemptId aborted.") + logError(s"Aborted commit for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") }) } } @@ -160,10 +162,10 @@ class InternalRowDataWriterFactory( override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[InternalRow] = { new InternalRowDataWriter( - rowWriterFactory.createDataWriter(partitionId, attemptNumber, epochId), + rowWriterFactory.createDataWriter(partitionId, taskId, epochId), RowEncoder.apply(schema).resolveAndBind()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala index ef5f0da1e7cc2..76f3f5baa8d56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -56,7 +56,7 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor val dataIterator = prev.compute(split, context) dataWriter = writeTask.createDataWriter( context.partitionId(), - context.attemptNumber(), + context.taskAttemptId(), EpochTracker.getCurrentEpoch.get) while (dataIterator.hasNext) { dataWriter.write(dataIterator.next()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala index f677f25f116a2..bc9b6d93ce7d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala @@ -88,7 +88,7 @@ case class ForeachWriterFactory[T]( extends DataWriterFactory[InternalRow] { override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): ForeachDataWriter[T] = { new ForeachDataWriter(writer, rowConverter, partitionId, epochId) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala index e07355aa37dba..b501d90c81f06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, Dat case object PackedRowWriterFactory extends DataWriterFactory[Row] { override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[Row] = { new PackedRowDataWriter() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 47b482007822d..29f8cca476722 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -180,7 +180,7 @@ class MemoryStreamWriter( case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] { override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[Row] = { new MemoryDataWriter(partitionId, outputMode) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 694bb3b95b0f0..1334cf71ae988 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -209,10 +209,10 @@ class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: Serializable override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[Row] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) - val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") + val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") val fs = filePath.getFileSystem(conf.value) new SimpleCSVDataWriter(fs, filePath) } @@ -245,10 +245,10 @@ class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: Seriali override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[InternalRow] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) - val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") + val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") val fs = filePath.getFileSystem(conf.value) new InternalRowCSVDataWriter(fs, filePath) } From d48803bf64dc0fccd6f560738b4682f0c05e767a Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 25 Jun 2018 17:08:23 -0700 Subject: [PATCH 1013/2461] [SPARK-24324][PYTHON][FOLLOWUP] Grouped Map positional conf should have deprecation note ## What changes were proposed in this pull request? Followup to the discussion of the added conf in SPARK-24324 which allows assignment by column position only. This conf is to preserve old behavior and will be removed in future releases, so it should have a note to indicate that. ## How was this patch tested? NA Author: Bryan Cutler Closes #21637 from BryanCutler/arrow-groupedMap-conf-deprecate-followup-SPARK-24324. --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d5fb524a1396f..e768416f257c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1167,7 +1167,7 @@ object SQLConf { .doc("When true, a grouped map Pandas UDF will assign columns from the returned " + "Pandas DataFrame based on position, regardless of column label type. When false, " + "columns will be looked up by name if labeled with a string and fallback to use " + - "position if not.") + "position if not. This configuration will be deprecated in future releases.") .booleanConf .createWithDefault(false) From 4c059ebc6008b4e78cbebc87a421cb87d1b800ed Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Tue, 26 Jun 2018 09:48:15 +0800 Subject: [PATCH 1014/2461] [SPARK-23776][DOC] Update instructions for running PySpark after building with SBT ## What changes were proposed in this pull request? This update tells the reader how to build Spark with SBT such that pyspark-sql tests will succeed. If you follow the current instructions for building Spark with SBT, pyspark/sql/udf.py fails with:
    AnalysisException: u'Can not load class test.org.apache.spark.sql.JavaStringLength, please make sure it is on the classpath;'
    
    ## How was this patch tested? I ran the doc build command (SKIP_API=1 jekyll build) and eyeballed the result. Author: Bruce Robbins Closes #21628 from bersprockets/SPARK-23776_doc. --- docs/building-spark.md | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/building-spark.md b/docs/building-spark.md index 0236bb05849ad..c3bcd90ccc78f 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -215,19 +215,23 @@ If you are building Spark for use in a Python environment and you wish to pip in Alternatively, you can also run make-distribution with the --pip option. -## PySpark Tests with Maven +## PySpark Tests with Maven or SBT If you are building PySpark and wish to run the PySpark tests you will need to build Spark with Hive support. ./build/mvn -DskipTests clean package -Phive ./python/run-tests +If you are building PySpark with SBT and wish to run the PySpark tests, you will need to build Spark with Hive support and also build the test components: + + ./build/sbt -Phive clean package + ./build/sbt test:compile + ./python/run-tests + The run-tests script also can be limited to a specific Python version or a specific module ./python/run-tests --python-executables=python --modules=pyspark-sql -**Note:** You can also run Python tests with an sbt build, provided you build Spark with Hive support. - ## Running R Tests To run the SparkR tests you will need to install the [knitr](https://cran.r-project.org/package=knitr), [rmarkdown](https://cran.r-project.org/package=rmarkdown), [testthat](https://cran.r-project.org/package=testthat), [e1071](https://cran.r-project.org/package=e1071) and [survival](https://cran.r-project.org/package=survival) packages first: From c7967c6049327a03b63ea7a3b0001a97d31e309d Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 26 Jun 2018 09:48:52 +0800 Subject: [PATCH 1015/2461] [SPARK-24418][BUILD] Upgrade Scala to 2.11.12 and 2.12.6 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Scala is upgraded to `2.11.12` and `2.12.6`. We used `loadFIles()` in `ILoop` as a hook to initialize the Spark before REPL sees any files in Scala `2.11.8`. However, it was a hack, and it was not intended to be a public API, so it was removed in Scala `2.11.12`. From the discussion in Scala community, https://github.com/scala/bug/issues/10913 , we can use `initializeSynchronous` to initialize Spark instead. This PR implements the Spark initialization there. However, in Scala `2.11.12`'s `ILoop.scala`, in function `def startup()`, the first thing it calls is `printWelcome()`. As a result, Scala will call `printWelcome()` and `splash` before calling `initializeSynchronous`. Thus, the Spark shell will allow users to type commends first, and then show the Spark UI URL. It's working, but it will change the Spark Shell interface as the following. ```scala ➜ apache-spark git:(scala-2.11.12) ✗ ./bin/spark-shell Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel). Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /___/ .__/\_,_/_/ /_/\_\ version 2.4.0-SNAPSHOT /_/ Using Scala version 2.11.12 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_161) Type in expressions to have them evaluated. Type :help for more information. scala> Spark context Web UI available at http://192.168.1.169:4040 Spark context available as 'sc' (master = local[*], app id = local-1528180279528). Spark session available as 'spark'. scala> ``` It seems there is no easy way to inject the Spark initialization code in the proper place as Scala doesn't provide a hook. Maybe som-snytt can comment on this. The following command is used to update the dep files. ```scala ./dev/test-dependencies.sh --replace-manifest ``` ## How was this patch tested? Existing tests Author: DB Tsai Closes #21495 from dbtsai/scala-2.11.12. --- LICENSE | 12 +++++----- dev/deps/spark-deps-hadoop-2.6 | 10 ++++---- dev/deps/spark-deps-hadoop-2.7 | 10 ++++---- dev/deps/spark-deps-hadoop-3.1 | 10 ++++---- pom.xml | 8 +++---- .../org/apache/spark/repl/SparkILoop.scala | 24 +++++++------------ .../spark/repl/SparkILoopInterpreter.scala | 18 ++++++++++++-- 7 files changed, 50 insertions(+), 42 deletions(-) diff --git a/LICENSE b/LICENSE index cc1f580207a75..6f5d9452e800d 100644 --- a/LICENSE +++ b/LICENSE @@ -243,18 +243,18 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) (BSD licence) ANTLR StringTemplate (org.antlr:stringtemplate:3.2.1 - http://www.stringtemplate.org) (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) - (BSD) JLine (jline:jline:0.9.94 - http://jline.sourceforge.net) + (BSD) JLine (jline:jline:2.14.3 - https://github.com/jline/jline2) (BSD) ParaNamer Core (com.thoughtworks.paranamer:paranamer:2.3 - http://paranamer.codehaus.org/paranamer) (BSD) ParaNamer Core (com.thoughtworks.paranamer:paranamer:2.6 - http://paranamer.codehaus.org/paranamer) (BSD 3 Clause) Scala (http://www.scala-lang.org/download/#License) (Interpreter classes (all .scala files in repl/src/main/scala except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala), and for SerializableMapWrapper in JavaUtils.scala) - (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scalap (org.scala-lang:scalap:2.11.8 - http://www.scala-lang.org/) + (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.12 - http://www.scala-lang.org/) + (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.12 - http://www.scala-lang.org/) + (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.12 - http://www.scala-lang.org/) + (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.12 - http://www.scala-lang.org/) + (BSD-like) Scalap (org.scala-lang:scalap:2.11.12 - http://www.scala-lang.org/) (BSD-style) scalacheck (org.scalacheck:scalacheck_2.11:1.10.0 - http://www.scalacheck.org) (BSD-style) spire (org.spire-math:spire_2.11:0.7.1 - http://spire-math.org) (BSD-style) spire-macros (org.spire-math:spire-macros_2.11:0.7.1 - http://spire-math.org) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 723180a14febb..96e9c27210d05 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -122,7 +122,7 @@ jersey-server-2.22.2.jar jets3t-0.9.4.jar jetty-6.1.26.jar jetty-util-6.1.26.jar -jline-2.12.1.jar +jline-2.14.3.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -172,10 +172,10 @@ parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar py4j-0.10.7.jar pyrolite-4.13.jar -scala-compiler-2.11.8.jar -scala-library-2.11.8.jar -scala-parser-combinators_2.11-1.0.4.jar -scala-reflect-2.11.8.jar +scala-compiler-2.11.12.jar +scala-library-2.11.12.jar +scala-parser-combinators_2.11-1.1.0.jar +scala-reflect-2.11.12.jar scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index ea08a001a1c9b..4a6ee027ec355 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -122,7 +122,7 @@ jersey-server-2.22.2.jar jets3t-0.9.4.jar jetty-6.1.26.jar jetty-util-6.1.26.jar -jline-2.12.1.jar +jline-2.14.3.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -173,10 +173,10 @@ parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar py4j-0.10.7.jar pyrolite-4.13.jar -scala-compiler-2.11.8.jar -scala-library-2.11.8.jar -scala-parser-combinators_2.11-1.0.4.jar -scala-reflect-2.11.8.jar +scala-compiler-2.11.12.jar +scala-library-2.11.12.jar +scala-parser-combinators_2.11-1.1.0.jar +scala-reflect-2.11.12.jar scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index da874026d7d10..e0b560c8ec71f 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -122,7 +122,7 @@ jersey-server-2.22.2.jar jets3t-0.9.4.jar jetty-webapp-9.3.20.v20170531.jar jetty-xml-9.3.20.v20170531.jar -jline-2.12.1.jar +jline-2.14.3.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -192,10 +192,10 @@ protobuf-java-2.5.0.jar py4j-0.10.7.jar pyrolite-4.13.jar re2j-1.1.jar -scala-compiler-2.11.8.jar -scala-library-2.11.8.jar -scala-parser-combinators_2.11-1.0.4.jar -scala-reflect-2.11.8.jar +scala-compiler-2.11.12.jar +scala-library-2.11.12.jar +scala-parser-combinators_2.11-1.1.0.jar +scala-reflect-2.11.12.jar scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar diff --git a/pom.xml b/pom.xml index 4b4e6c13ea8fd..90e64ff71d229 100644 --- a/pom.xml +++ b/pom.xml @@ -155,7 +155,7 @@ 3.4.1 3.2.2 - 2.11.8 + 2.11.12 2.11 1.9.13 2.6.7 @@ -740,13 +740,13 @@ org.scala-lang.modules scala-parser-combinators_${scala.binary.version} - 1.0.4 + 1.1.0 jline jline - 2.12.1 + 2.14.3 org.scalatest @@ -2755,7 +2755,7 @@ scala-2.12 - 2.12.4 + 2.12.6 2.12 diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index e69441a475e9a..a44051b351e19 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -36,7 +36,7 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def this() = this(None, new JPrintWriter(Console.out, true)) override def createInterpreter(): Unit = { - intp = new SparkILoopInterpreter(settings, out) + intp = new SparkILoopInterpreter(settings, out, initializeSpark) } val initializationCommands: Seq[String] = Seq( @@ -73,11 +73,15 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) "import org.apache.spark.sql.functions._" ) - def initializeSpark() { - intp.beQuietDuring { - savingReplayStack { // remove the commands from session history. - initializationCommands.foreach(processLine) + def initializeSpark(): Unit = { + if (!intp.reporter.hasErrors) { + // `savingReplayStack` removes the commands from session history. + savingReplayStack { + initializationCommands.foreach(intp quietRun _) } + } else { + throw new RuntimeException(s"Scala $versionString interpreter encountered " + + "errors during initialization") } } @@ -101,16 +105,6 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) /** Available commands */ override def commands: List[LoopCommand] = standardCommands - /** - * We override `loadFiles` because we need to initialize Spark *before* the REPL - * sees any files, so that the Spark context is visible in those files. This is a bit of a - * hack, but there isn't another hook available to us at this point. - */ - override def loadFiles(settings: Settings): Unit = { - initializeSpark() - super.loadFiles(settings) - } - override def resetCommand(line: String): Unit = { super.resetCommand(line) initializeSpark() diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala index e736607a9a6b9..4e63816402a10 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala @@ -21,8 +21,22 @@ import scala.collection.mutable import scala.tools.nsc.Settings import scala.tools.nsc.interpreter._ -class SparkILoopInterpreter(settings: Settings, out: JPrintWriter) extends IMain(settings, out) { - self => +class SparkILoopInterpreter(settings: Settings, out: JPrintWriter, initializeSpark: () => Unit) + extends IMain(settings, out) { self => + + /** + * We override `initializeSynchronous` to initialize Spark *after* `intp` is properly initialized + * and *before* the REPL sees any files in the private `loadInitFiles` functions, so that + * the Spark context is visible in those files. + * + * This is a bit of a hack, but there isn't another hook available to us at this point. + * + * See the discussion in Scala community https://github.com/scala/bug/issues/10913 for detail. + */ + override def initializeSynchronous(): Unit = { + super.initializeSynchronous() + initializeSpark() + } override lazy val memberHandlers = new { val intp: self.type = self From e07aee2165af4d301ae12005a6d9ffb030bc2650 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Tue, 26 Jun 2018 09:51:55 +0800 Subject: [PATCH 1016/2461] [SPARK-24636][SQL] Type coercion of arrays for array_join function ## What changes were proposed in this pull request? Presto's implementation accepts arbitrary arrays of primitive types as an input: ``` presto> SELECT array_join(ARRAY [1, 2, 3], ', '); _col0 --------- 1, 2, 3 (1 row) ``` This PR proposes to implement a type coercion rule for ```array_join``` function that converts arrays of primitive as well as non-primitive types to arrays of string. ## How was this patch tested? New test cases add into: - sql-tests/inputs/typeCoercion/native/arrayJoin.sql - DataFrameFunctionsSuite.scala Author: Marek Novotny Closes #21620 from mn-mikke/SPARK-24636. --- .../sql/catalyst/analysis/TypeCoercion.scala | 8 ++ .../expressions/collectionOperations.scala | 1 + .../inputs/typeCoercion/native/arrayJoin.sql | 11 +++ .../typeCoercion/native/arrayJoin.sql.out | 90 +++++++++++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 17 ++++ 5 files changed, 127 insertions(+) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index b2817b0538a7f..637923928a7da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -536,6 +536,14 @@ object TypeCoercion { case None => c } + case aj @ ArrayJoin(arr, d, nr) if !ArrayType(StringType).acceptsType(arr.dataType) && + ArrayType.acceptsType(arr.dataType) => + val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull + ImplicitTypeCasts.implicitCast(arr, ArrayType(StringType, containsNull)) match { + case Some(castedArr) => ArrayJoin(castedArr, d, nr) + case None => aj + } + case m @ CreateMap(children) if m.keys.length == m.values.length && (!haveSameType(m.keys) || !haveSameType(m.values)) => val newKeys = if (haveSameType(m.keys)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b6137b07555f4..58612f65c1a53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1621,6 +1621,7 @@ case class ArrayJoin( override def dataType: DataType = StringType + override def prettyName: String = "array_join" } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql new file mode 100644 index 0000000000000..99729c007b104 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql @@ -0,0 +1,11 @@ +SELECT array_join(array(true, false), ', '); +SELECT array_join(array(2Y, 1Y), ', '); +SELECT array_join(array(2S, 1S), ', '); +SELECT array_join(array(2, 1), ', '); +SELECT array_join(array(2L, 1L), ', '); +SELECT array_join(array(9223372036854775809, 9223372036854775808), ', '); +SELECT array_join(array(2.0D, 1.0D), ', '); +SELECT array_join(array(float(2.0), float(1.0)), ', '); +SELECT array_join(array(date '2016-03-14', date '2016-03-13'), ', '); +SELECT array_join(array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), ', '); +SELECT array_join(array('a', 'b'), ', '); diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out new file mode 100644 index 0000000000000..b23a62dacef7c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out @@ -0,0 +1,90 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 11 + + +-- !query 0 +SELECT array_join(array(true, false), ', ') +-- !query 0 schema +struct +-- !query 0 output +true, false + + +-- !query 1 +SELECT array_join(array(2Y, 1Y), ', ') +-- !query 1 schema +struct +-- !query 1 output +2, 1 + + +-- !query 2 +SELECT array_join(array(2S, 1S), ', ') +-- !query 2 schema +struct +-- !query 2 output +2, 1 + + +-- !query 3 +SELECT array_join(array(2, 1), ', ') +-- !query 3 schema +struct +-- !query 3 output +2, 1 + + +-- !query 4 +SELECT array_join(array(2L, 1L), ', ') +-- !query 4 schema +struct +-- !query 4 output +2, 1 + + +-- !query 5 +SELECT array_join(array(9223372036854775809, 9223372036854775808), ', ') +-- !query 5 schema +struct +-- !query 5 output +9223372036854775809, 9223372036854775808 + + +-- !query 6 +SELECT array_join(array(2.0D, 1.0D), ', ') +-- !query 6 schema +struct +-- !query 6 output +2.0, 1.0 + + +-- !query 7 +SELECT array_join(array(float(2.0), float(1.0)), ', ') +-- !query 7 schema +struct +-- !query 7 output +2.0, 1.0 + + +-- !query 8 +SELECT array_join(array(date '2016-03-14', date '2016-03-13'), ', ') +-- !query 8 schema +struct +-- !query 8 output +2016-03-14, 2016-03-13 + + +-- !query 9 +SELECT array_join(array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), ', ') +-- !query 9 schema +struct +-- !query 9 output +2016-11-15 20:54:00, 2016-11-12 20:54:00 + + +-- !query 10 +SELECT array_join(array('a', 'b'), ', ') +-- !query 10 schema +struct +-- !query 10 output +a, b diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 47fe67d8daea3..5d6a6c0832c96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -805,6 +805,23 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer( df.selectExpr("array_join(x, delimiter, 'NULL')"), Seq(Row("a,b"), Row("a,NULL,b"), Row(""))) + + val idf = Seq(Seq(1, 2, 3)).toDF("x") + + checkAnswer( + idf.select(array_join(idf("x"), ", ")), + Seq(Row("1, 2, 3")) + ) + checkAnswer( + idf.selectExpr("array_join(x, ', ')"), + Seq(Row("1, 2, 3")) + ) + intercept[AnalysisException] { + idf.selectExpr("array_join(x, 1)") + } + intercept[AnalysisException] { + idf.selectExpr("array_join(x, ', ', 1)") + } } test("array_min function") { From dcaa49ff1edd7fcf0f000c6f93ae0e30bd5b6464 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 26 Jun 2018 14:33:04 -0700 Subject: [PATCH 1017/2461] [SPARK-24658][SQL] Remove workaround for ANTLR bug ## What changes were proposed in this pull request? Issue antlr/antlr4#781 has already been fixed, so the workaround of extracting the pattern into a separate rule is no longer needed. The presto already removed it: https://github.com/prestodb/presto/pull/10744. ## How was this patch tested? Existing tests Author: Yuming Wang Closes #21641 from wangyum/ANTLR-780. --- .../org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 3fe00eefde7d8..dc95751bf905c 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -539,18 +539,11 @@ expression booleanExpression : NOT booleanExpression #logicalNot | EXISTS '(' query ')' #exists - | predicated #booleanDefault + | valueExpression predicate? #predicated | left=booleanExpression operator=AND right=booleanExpression #logicalBinary | left=booleanExpression operator=OR right=booleanExpression #logicalBinary ; -// workaround for: -// https://github.com/antlr/antlr4/issues/780 -// https://github.com/antlr/antlr4/issues/781 -predicated - : valueExpression predicate? - ; - predicate : NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression | NOT? kind=IN '(' expression (',' expression)* ')' From 02f8781fa2649cf1d3a5cb932e1c8408790974ff Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 26 Jun 2018 15:17:00 -0700 Subject: [PATCH 1018/2461] [SPARK-24423][SQL] Add a new option for JDBC sources ## What changes were proposed in this pull request? Here is the description in the JIRA - Currently, our JDBC connector provides the option `dbtable` for users to specify the to-be-loaded JDBC source table. ```SQL val jdbcDf = spark.read .format("jdbc") .option("dbtable", "dbName.tableName") .options(jdbcCredentials: Map) .load() ``` Normally, users do not fetch the whole JDBC table due to the poor performance/throughput of JDBC. Thus, they normally just fetch a small set of tables. For advanced users, they can pass a subquery as the option. ```SQL val query = """ (select * from tableName limit 10) as tmp """ val jdbcDf = spark.read .format("jdbc") .option("dbtable", query) .options(jdbcCredentials: Map) .load() ``` However, this is straightforward to end users. We should simply allow users to specify the query by a new option `query`. We will handle the complexity for them. ```SQL val query = """select * from tableName limit 10""" val jdbcDf = spark.read .format("jdbc") .option("query", query) .options(jdbcCredentials: Map) .load() ``` ## How was this patch tested? Added tests in JDBCSuite and JDBCWriterSuite. Also tested against MySQL, Postgress, Oracle, DB2 (using docker infrastructure) to make sure there are no syntax issues. Author: Dilip Biswal Closes #21590 from dilipbiswal/SPARK-24423. --- docs/sql-programming-guide.md | 30 +++++- .../datasources/jdbc/JDBCOptions.scala | 66 ++++++++++++- .../execution/datasources/jdbc/JDBCRDD.scala | 4 +- .../datasources/jdbc/JDBCRelation.scala | 4 +- .../jdbc/JdbcRelationProvider.scala | 5 +- .../datasources/jdbc/JdbcUtils.scala | 10 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 94 ++++++++++++++++++- .../spark/sql/jdbc/JDBCWriteSuite.scala | 14 ++- 8 files changed, 204 insertions(+), 23 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 196b814420be1..7c4ef41cc8907 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1302,9 +1302,33 @@ the following case-insensitive options: dbtable - The JDBC table that should be read. Note that anything that is valid in a FROM clause of - a SQL query can be used. For example, instead of a full table you could also use a - subquery in parentheses. + The JDBC table that should be read from or written into. Note that when using it in the read + path anything that is valid in a FROM clause of a SQL query can be used. + For example, instead of a full table you could also use a subquery in parentheses. It is not + allowed to specify `dbtable` and `query` options at the same time. + + + + query + + A query that will be used to read data into Spark. The specified query will be parenthesized and used + as a subquery in the FROM clause. Spark will also assign an alias to the subquery clause. + As an example, spark will issue a query of the following form to the JDBC Source.

    + SELECT <columns> FROM (<user_specified_query>) spark_gen_alias

    + Below are couple of restrictions while using this option.
    +
      +
    1. It is not allowed to specify `dbtable` and `query` options at the same time.
    2. +
    3. It is not allowed to spcify `query` and `partitionColumn` options at the same time. When specifying + `partitionColumn` option is required, the subquery can be specified using `dbtable` option instead and + partition columns can be qualified using the subquery alias provided as part of `dbtable`.
      + Example:
      + + spark.read.format("jdbc")
      +    .option("dbtable", "(select c1, c2 from t1) as subq")
      +    .option("partitionColumn", "subq.c1"
      +    .load() +
    4. +
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index a73a97c06fe5a..eea966d30948b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.StructType * Options for the JDBC data source. */ class JDBCOptions( - @transient private val parameters: CaseInsensitiveMap[String]) + @transient val parameters: CaseInsensitiveMap[String]) extends Serializable { import JDBCOptions._ @@ -65,11 +65,31 @@ class JDBCOptions( // Required parameters // ------------------------------------------------------------ require(parameters.isDefinedAt(JDBC_URL), s"Option '$JDBC_URL' is required.") - require(parameters.isDefinedAt(JDBC_TABLE_NAME), s"Option '$JDBC_TABLE_NAME' is required.") // a JDBC URL val url = parameters(JDBC_URL) - // name of table - val table = parameters(JDBC_TABLE_NAME) + // table name or a table subquery. + val tableOrQuery = (parameters.get(JDBC_TABLE_NAME), parameters.get(JDBC_QUERY_STRING)) match { + case (Some(name), Some(subquery)) => + throw new IllegalArgumentException( + s"Both '$JDBC_TABLE_NAME' and '$JDBC_QUERY_STRING' can not be specified at the same time." + ) + case (None, None) => + throw new IllegalArgumentException( + s"Option '$JDBC_TABLE_NAME' or '$JDBC_QUERY_STRING' is required." + ) + case (Some(name), None) => + if (name.isEmpty) { + throw new IllegalArgumentException(s"Option '$JDBC_TABLE_NAME' can not be empty.") + } else { + name.trim + } + case (None, Some(subquery)) => + if (subquery.isEmpty) { + throw new IllegalArgumentException(s"Option `$JDBC_QUERY_STRING` can not be empty.") + } else { + s"(${subquery}) __SPARK_GEN_JDBC_SUBQUERY_NAME_${curId.getAndIncrement()}" + } + } // ------------------------------------------------------------ // Optional parameters @@ -109,6 +129,20 @@ class JDBCOptions( s"When reading JDBC data sources, users need to specify all or none for the following " + s"options: '$JDBC_PARTITION_COLUMN', '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', " + s"and '$JDBC_NUM_PARTITIONS'") + + require(!(parameters.get(JDBC_QUERY_STRING).isDefined && partitionColumn.isDefined), + s""" + |Options '$JDBC_QUERY_STRING' and '$JDBC_PARTITION_COLUMN' can not be specified together. + |Please define the query using `$JDBC_TABLE_NAME` option instead and make sure to qualify + |the partition columns using the supplied subquery alias to resolve any ambiguity. + |Example : + |spark.read.format("jdbc") + | .option("dbtable", "(select c1, c2 from t1) as subq") + | .option("partitionColumn", "subq.c1" + | .load() + """.stripMargin + ) + val fetchSize = { val size = parameters.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt require(size >= 0, @@ -149,7 +183,30 @@ class JDBCOptions( val sessionInitStatement = parameters.get(JDBC_SESSION_INIT_STATEMENT) } +class JdbcOptionsInWrite( + @transient override val parameters: CaseInsensitiveMap[String]) + extends JDBCOptions(parameters) { + + import JDBCOptions._ + + def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) + + def this(url: String, table: String, parameters: Map[String, String]) = { + this(CaseInsensitiveMap(parameters ++ Map( + JDBCOptions.JDBC_URL -> url, + JDBCOptions.JDBC_TABLE_NAME -> table))) + } + + require( + parameters.get(JDBC_TABLE_NAME).isDefined, + s"Option '$JDBC_TABLE_NAME' is required. " + + s"Option '$JDBC_QUERY_STRING' is not applicable while writing.") + + val table = parameters(JDBC_TABLE_NAME) +} + object JDBCOptions { + private val curId = new java.util.concurrent.atomic.AtomicLong(0L) private val jdbcOptionNames = collection.mutable.Set[String]() private def newOption(name: String): String = { @@ -159,6 +216,7 @@ object JDBCOptions { val JDBC_URL = newOption("url") val JDBC_TABLE_NAME = newOption("dbtable") + val JDBC_QUERY_STRING = newOption("query") val JDBC_DRIVER_CLASS = newOption("driver") val JDBC_PARTITION_COLUMN = newOption("partitionColumn") val JDBC_LOWER_BOUND = newOption("lowerBound") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 0bab3689e5d0e..1b3b17c75e756 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -51,7 +51,7 @@ object JDBCRDD extends Logging { */ def resolveTable(options: JDBCOptions): StructType = { val url = options.url - val table = options.table + val table = options.tableOrQuery val dialect = JdbcDialects.get(url) val conn: Connection = JdbcUtils.createConnectionFactory(options)() try { @@ -296,7 +296,7 @@ private[jdbc] class JDBCRDD( val myWhereClause = getWhereClause(part) - val sqlText = s"SELECT $columnList FROM ${options.table} $myWhereClause" + val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) stmt.setFetchSize(options.fetchSize) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index b84543ccd7869..97e2d255cb7be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -189,12 +189,12 @@ private[sql] case class JDBCRelation( override def insert(data: DataFrame, overwrite: Boolean): Unit = { data.write .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) - .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties) + .jdbc(jdbcOptions.url, jdbcOptions.tableOrQuery, jdbcOptions.asProperties) } override def toString: String = { val partitioningInfo = if (parts.nonEmpty) s" [numPartitions=${parts.length}]" else "" // credentials should not be included in the plan output, table information is sufficient. - s"JDBCRelation(${jdbcOptions.table})" + partitioningInfo + s"JDBCRelation(${jdbcOptions.tableOrQuery})" + partitioningInfo } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index 2b488bb7121dc..782d626c1573c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -59,7 +59,7 @@ class JdbcRelationProvider extends CreatableRelationProvider mode: SaveMode, parameters: Map[String, String], df: DataFrame): BaseRelation = { - val options = new JDBCOptions(parameters) + val options = new JdbcOptionsInWrite(parameters) val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis val conn = JdbcUtils.createConnectionFactory(options)() @@ -86,7 +86,8 @@ class JdbcRelationProvider extends CreatableRelationProvider case SaveMode.ErrorIfExists => throw new AnalysisException( - s"Table or view '${options.table}' already exists. SaveMode: ErrorIfExists.") + s"Table or view '${options.table}' already exists. " + + s"SaveMode: ErrorIfExists.") case SaveMode.Ignore => // With `SaveMode.Ignore` mode, if table already exists, the save operation is expected diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 433443007cfd8..b81737eda475b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -67,7 +67,7 @@ object JdbcUtils extends Logging { /** * Returns true if the table already exists in the JDBC database. */ - def tableExists(conn: Connection, options: JDBCOptions): Boolean = { + def tableExists(conn: Connection, options: JdbcOptionsInWrite): Boolean = { val dialect = JdbcDialects.get(options.url) // Somewhat hacky, but there isn't a good way to identify whether a table exists for all @@ -100,7 +100,7 @@ object JdbcUtils extends Logging { /** * Truncates a table from the JDBC database without side effects. */ - def truncateTable(conn: Connection, options: JDBCOptions): Unit = { + def truncateTable(conn: Connection, options: JdbcOptionsInWrite): Unit = { val dialect = JdbcDialects.get(options.url) val statement = conn.createStatement try { @@ -255,7 +255,7 @@ object JdbcUtils extends Logging { val dialect = JdbcDialects.get(options.url) try { - val statement = conn.prepareStatement(dialect.getSchemaQuery(options.table)) + val statement = conn.prepareStatement(dialect.getSchemaQuery(options.tableOrQuery)) try { statement.setQueryTimeout(options.queryTimeout) Some(getSchema(statement.executeQuery(), dialect)) @@ -809,7 +809,7 @@ object JdbcUtils extends Logging { df: DataFrame, tableSchema: Option[StructType], isCaseSensitive: Boolean, - options: JDBCOptions): Unit = { + options: JdbcOptionsInWrite): Unit = { val url = options.url val table = options.table val dialect = JdbcDialects.get(url) @@ -838,7 +838,7 @@ object JdbcUtils extends Logging { def createTable( conn: Connection, df: DataFrame, - options: JDBCOptions): Unit = { + options: JdbcOptionsInWrite): Unit = { val strSchema = schemaString( df, options.url, options.createTableColumnTypes) val table = options.table diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 6ea61f02a8206..0389273d6cdfa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -25,7 +25,7 @@ import org.h2.jdbc.JdbcSQLException import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.DataSourceScanExec @@ -39,7 +39,7 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCSuite extends SparkFunSuite +class JDBCSuite extends QueryTest with BeforeAndAfter with PrivateMethodTester with SharedSQLContext { import testImplicits._ @@ -1099,7 +1099,7 @@ class JDBCSuite extends SparkFunSuite test("SPARK-19318: Connection properties keys should be case-sensitive.") { def testJdbcOptions(options: JDBCOptions): Unit = { // Spark JDBC data source options are case-insensitive - assert(options.table == "t1") + assert(options.tableOrQuery == "t1") // When we convert it to properties, it should be case-sensitive. assert(options.asProperties.size == 3) assert(options.asProperties.get("customkey") == null) @@ -1255,4 +1255,92 @@ class JDBCSuite extends SparkFunSuite testIncorrectJdbcPartitionColumn(testH2Dialect.quoteIdentifier("ThEiD")) } } + + test("query JDBC option - negative tests") { + val query = "SELECT * FROM test.people WHERE theid = 1" + // load path + val e1 = intercept[RuntimeException] { + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("query", query) + .option("dbtable", "test.people") + .load() + }.getMessage + assert(e1.contains("Both 'dbtable' and 'query' can not be specified at the same time.")) + + // jdbc api path + val properties = new Properties() + properties.setProperty(JDBCOptions.JDBC_QUERY_STRING, query) + val e2 = intercept[RuntimeException] { + spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", properties).collect() + }.getMessage + assert(e2.contains("Both 'dbtable' and 'query' can not be specified at the same time.")) + + val e3 = intercept[RuntimeException] { + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW queryOption + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', query '$query', dbtable 'TEST.PEOPLE', + | user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + }.getMessage + assert(e3.contains("Both 'dbtable' and 'query' can not be specified at the same time.")) + + val e4 = intercept[RuntimeException] { + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("query", "") + .load() + }.getMessage + assert(e4.contains("Option `query` can not be empty.")) + + // Option query and partitioncolumn are not allowed together. + val expectedErrorMsg = + s""" + |Options 'query' and 'partitionColumn' can not be specified together. + |Please define the query using `dbtable` option instead and make sure to qualify + |the partition columns using the supplied subquery alias to resolve any ambiguity. + |Example : + |spark.read.format("jdbc") + | .option("dbtable", "(select c1, c2 from t1) as subq") + | .option("partitionColumn", "subq.c1" + | .load() + """.stripMargin + val e5 = intercept[RuntimeException] { + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW queryOption + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', query '$query', user 'testUser', password 'testPass', + | partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3') + """.stripMargin.replaceAll("\n", " ")) + }.getMessage + assert(e5.contains(expectedErrorMsg)) + } + + test("query JDBC option") { + val query = "SELECT name, theid FROM test.people WHERE theid = 1" + // query option to pass on the query string. + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("query", query) + .load() + checkAnswer( + df, + Row("fred", 1) :: Nil) + + // query option in the create table path. + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW queryOption + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', query '$query', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + + checkAnswer( + sql("select name, theid from queryOption"), + Row("fred", 1) :: Nil) + + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 1c2c92d1f0737..b751ec2de4825 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -293,13 +293,23 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { test("save errors if dbtable is not specified") { val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val e = intercept[RuntimeException] { + val e1 = intercept[RuntimeException] { + df.write.format("jdbc") + .option("url", url1) + .options(properties.asScala) + .save() + }.getMessage + assert(e1.contains("Option 'dbtable' or 'query' is required")) + + val e2 = intercept[RuntimeException] { df.write.format("jdbc") .option("url", url1) .options(properties.asScala) + .option("query", "select * from TEST.SAVETEST") .save() }.getMessage - assert(e.contains("Option 'dbtable' is required")) + val msg = "Option 'dbtable' is required. Option 'query' is not applicable while writing." + assert(e2.contains(msg)) } test("save errors if wrong user/password combination") { From 16f2c3ea46a330bff7fae33f2521eb36a6280f04 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 26 Jun 2018 15:56:58 -0700 Subject: [PATCH 1019/2461] [SPARK-6237][NETWORK] Network-layer changes to allow stream upload. These changes allow an RPCHandler to receive an upload as a stream of data, without having to buffer the entire message in the FrameDecoder. The primary use case is for replicating large blocks. By itself, this change is adding dead-code that is not being used -- it is a step towards SPARK-24296. Added unit tests for handling streaming data, including successfully sending data, and failures in reading the stream with concurrent requests. Summary of changes: * Introduce a new UploadStream RPC which is sent to push a large payload as a stream (in contrast, the pre-existing StreamRequest and StreamResponse RPCs are used for pull-based streaming). * Generalize RpcHandler.receive() to support requests which contain streams. * Generalize StreamInterceptor to handle both request and response messages (previously it only handled responses). * Introduce StdChannelListener to abstract away common logging logic in ChannelFuture listeners. Author: Imran Rashid Closes #21346 from squito/upload_stream. --- .../network/client/StreamCallbackWithID.java | 22 ++ .../network/client/StreamInterceptor.java | 26 +- .../spark/network/client/TransportClient.java | 175 ++++++----- .../spark/network/crypto/AuthRpcHandler.java | 9 + .../spark/network/protocol/Message.java | 3 +- .../network/protocol/MessageDecoder.java | 3 + .../network/protocol/StreamResponse.java | 2 +- .../spark/network/protocol/UploadStream.java | 107 +++++++ .../spark/network/sasl/SaslRpcHandler.java | 9 + .../spark/network/server/RpcHandler.java | 34 ++- .../server/TransportRequestHandler.java | 95 +++++- .../spark/network/RpcIntegrationSuite.java | 289 ++++++++++++++++-- .../org/apache/spark/network/StreamSuite.java | 94 ++---- .../spark/network/StreamTestHelper.java | 104 +++++++ project/MimaExcludes.scala | 3 + 15 files changed, 799 insertions(+), 176 deletions(-) create mode 100644 common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java create mode 100644 common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java new file mode 100644 index 0000000000000..bd173b653e33e --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +public interface StreamCallbackWithID extends StreamCallback { + String getID(); +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java index b0e85bae7c309..f3eb744ff7345 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java @@ -22,22 +22,24 @@ import io.netty.buffer.ByteBuf; +import org.apache.spark.network.protocol.Message; +import org.apache.spark.network.server.MessageHandler; import org.apache.spark.network.util.TransportFrameDecoder; /** * An interceptor that is registered with the frame decoder to feed stream data to a * callback. */ -class StreamInterceptor implements TransportFrameDecoder.Interceptor { +public class StreamInterceptor implements TransportFrameDecoder.Interceptor { - private final TransportResponseHandler handler; + private final MessageHandler handler; private final String streamId; private final long byteCount; private final StreamCallback callback; private long bytesRead; - StreamInterceptor( - TransportResponseHandler handler, + public StreamInterceptor( + MessageHandler handler, String streamId, long byteCount, StreamCallback callback) { @@ -50,16 +52,24 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor { @Override public void exceptionCaught(Throwable cause) throws Exception { - handler.deactivateStream(); + deactivateStream(); callback.onFailure(streamId, cause); } @Override public void channelInactive() throws Exception { - handler.deactivateStream(); + deactivateStream(); callback.onFailure(streamId, new ClosedChannelException()); } + private void deactivateStream() { + if (handler instanceof TransportResponseHandler) { + // we only have to do this for TransportResponseHandler as it exposes numOutstandingFetches + // (there is no extra cleanup that needs to happen) + ((TransportResponseHandler) handler).deactivateStream(); + } + } + @Override public boolean handle(ByteBuf buf) throws Exception { int toRead = (int) Math.min(buf.readableBytes(), byteCount - bytesRead); @@ -72,10 +82,10 @@ public boolean handle(ByteBuf buf) throws Exception { RuntimeException re = new IllegalStateException(String.format( "Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead)); callback.onFailure(streamId, re); - handler.deactivateStream(); + deactivateStream(); throw re; } else if (bytesRead == byteCount) { - handler.deactivateStream(); + deactivateStream(); callback.onComplete(streamId); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 8f354ad78bbaa..325225dc0ea2c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -32,15 +32,15 @@ import com.google.common.base.Throwables; import com.google.common.util.concurrent.SettableFuture; import io.netty.channel.Channel; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.OneWayMessage; -import org.apache.spark.network.protocol.RpcRequest; -import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.StreamRequest; +import org.apache.spark.network.protocol.*; + import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; /** @@ -133,34 +133,21 @@ public void fetchChunk( long streamId, int chunkIndex, ChunkReceivedCallback callback) { - long startTime = System.currentTimeMillis(); if (logger.isDebugEnabled()) { logger.debug("Sending fetch chunk request {} to {}", chunkIndex, getRemoteAddress(channel)); } StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex); - handler.addFetchRequest(streamChunkId, callback); - - channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(future -> { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - if (logger.isTraceEnabled()) { - logger.trace("Sending request {} to {} took {} ms", streamChunkId, - getRemoteAddress(channel), timeTaken); - } - } else { - String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId, - getRemoteAddress(channel), future.cause()); - logger.error(errorMsg, future.cause()); + StdChannelListener listener = new StdChannelListener(streamChunkId) { + @Override + void handleFailure(String errorMsg, Throwable cause) { handler.removeFetchRequest(streamChunkId); - channel.close(); - try { - callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } + callback.onFailure(chunkIndex, new IOException(errorMsg, cause)); } - }); + }; + handler.addFetchRequest(streamChunkId, callback); + + channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(listener); } /** @@ -170,7 +157,12 @@ public void fetchChunk( * @param callback Object to call with the stream data. */ public void stream(String streamId, StreamCallback callback) { - long startTime = System.currentTimeMillis(); + StdChannelListener listener = new StdChannelListener(streamId) { + @Override + void handleFailure(String errorMsg, Throwable cause) throws Exception { + callback.onFailure(streamId, new IOException(errorMsg, cause)); + } + }; if (logger.isDebugEnabled()) { logger.debug("Sending stream request for {} to {}", streamId, getRemoteAddress(channel)); } @@ -180,25 +172,7 @@ public void stream(String streamId, StreamCallback callback) { // when responses arrive. synchronized (this) { handler.addStreamCallback(streamId, callback); - channel.writeAndFlush(new StreamRequest(streamId)).addListener(future -> { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - if (logger.isTraceEnabled()) { - logger.trace("Sending request for {} to {} took {} ms", streamId, - getRemoteAddress(channel), timeTaken); - } - } else { - String errorMsg = String.format("Failed to send request for %s to %s: %s", streamId, - getRemoteAddress(channel), future.cause()); - logger.error(errorMsg, future.cause()); - channel.close(); - try { - callback.onFailure(streamId, new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } - } - }); + channel.writeAndFlush(new StreamRequest(streamId)).addListener(listener); } } @@ -211,35 +185,44 @@ public void stream(String streamId, StreamCallback callback) { * @return The RPC's id. */ public long sendRpc(ByteBuffer message, RpcResponseCallback callback) { - long startTime = System.currentTimeMillis(); if (logger.isTraceEnabled()) { logger.trace("Sending RPC to {}", getRemoteAddress(channel)); } - long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); + long requestId = requestId(); handler.addRpcRequest(requestId, callback); + RpcChannelListener listener = new RpcChannelListener(requestId, callback); channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))) - .addListener(future -> { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - if (logger.isTraceEnabled()) { - logger.trace("Sending request {} to {} took {} ms", requestId, - getRemoteAddress(channel), timeTaken); - } - } else { - String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, - getRemoteAddress(channel), future.cause()); - logger.error(errorMsg, future.cause()); - handler.removeRpcRequest(requestId); - channel.close(); - try { - callback.onFailure(new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } - } - }); + .addListener(listener); + + return requestId; + } + + /** + * Send data to the remote end as a stream. This differs from stream() in that this is a request + * to *send* data to the remote end, not to receive it from the remote. + * + * @param meta meta data associated with the stream, which will be read completely on the + * receiving end before the stream itself. + * @param data this will be streamed to the remote end to allow for transferring large amounts + * of data without reading into memory. + * @param callback handles the reply -- onSuccess will only be called when both message and data + * are received successfully. + */ + public long uploadStream( + ManagedBuffer meta, + ManagedBuffer data, + RpcResponseCallback callback) { + if (logger.isTraceEnabled()) { + logger.trace("Sending RPC to {}", getRemoteAddress(channel)); + } + + long requestId = requestId(); + handler.addRpcRequest(requestId, callback); + + RpcChannelListener listener = new RpcChannelListener(requestId, callback); + channel.writeAndFlush(new UploadStream(requestId, meta, data)).addListener(listener); return requestId; } @@ -319,4 +302,60 @@ public String toString() { .add("isActive", isActive()) .toString(); } + + private static long requestId() { + return Math.abs(UUID.randomUUID().getLeastSignificantBits()); + } + + private class StdChannelListener + implements GenericFutureListener> { + final long startTime; + final Object requestId; + + StdChannelListener(Object requestId) { + this.startTime = System.currentTimeMillis(); + this.requestId = requestId; + } + + @Override + public void operationComplete(Future future) throws Exception { + if (future.isSuccess()) { + if (logger.isTraceEnabled()) { + long timeTaken = System.currentTimeMillis() - startTime; + logger.trace("Sending request {} to {} took {} ms", requestId, + getRemoteAddress(channel), timeTaken); + } + } else { + String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, + getRemoteAddress(channel), future.cause()); + logger.error(errorMsg, future.cause()); + channel.close(); + try { + handleFailure(errorMsg, future.cause()); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } + } + } + + void handleFailure(String errorMsg, Throwable cause) throws Exception {} + } + + private class RpcChannelListener extends StdChannelListener { + final long rpcRequestId; + final RpcResponseCallback callback; + + RpcChannelListener(long rpcRequestId, RpcResponseCallback callback) { + super("RPC " + rpcRequestId); + this.rpcRequestId = rpcRequestId; + this.callback = callback; + } + + @Override + void handleFailure(String errorMsg, Throwable cause) { + handler.removeRpcRequest(rpcRequestId); + callback.onFailure(new IOException(errorMsg, cause)); + } + } + } diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java index 8a6e3858081bf..fb44dbbb0953b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java @@ -29,6 +29,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.sasl.SaslRpcHandler; @@ -149,6 +150,14 @@ public void receive(TransportClient client, ByteBuffer message) { delegate.receive(client, message); } + @Override + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + return delegate.receiveStream(client, message, callback); + } + @Override public StreamManager getStreamManager() { return delegate.getStreamManager(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java index 434935a8ef2ad..0ccd70c03aba8 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -37,7 +37,7 @@ enum Type implements Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), RpcRequest(3), RpcResponse(4), RpcFailure(5), StreamRequest(6), StreamResponse(7), StreamFailure(8), - OneWayMessage(9), User(-1); + OneWayMessage(9), UploadStream(10), User(-1); private final byte id; @@ -65,6 +65,7 @@ public static Type decode(ByteBuf buf) { case 7: return StreamResponse; case 8: return StreamFailure; case 9: return OneWayMessage; + case 10: return UploadStream; case -1: throw new IllegalArgumentException("User type messages cannot be decoded."); default: throw new IllegalArgumentException("Unknown message type: " + id); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index 39a7495828a8a..bf80aed0afe10 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -80,6 +80,9 @@ private Message decode(Message.Type msgType, ByteBuf in) { case StreamFailure: return StreamFailure.decode(in); + case UploadStream: + return UploadStream.decode(in); + default: throw new IllegalArgumentException("Unexpected message type: " + msgType); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java index 87e212f3e157b..50b811604b84b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java @@ -67,7 +67,7 @@ public static StreamResponse decode(ByteBuf buf) { @Override public int hashCode() { - return Objects.hashCode(byteCount, streamId, body()); + return Objects.hashCode(byteCount, streamId); } @Override diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java new file mode 100644 index 0000000000000..fa1d26e76b852 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * An RPC with data that is sent outside of the frame, so it can be read as a stream. + */ +public final class UploadStream extends AbstractMessage implements RequestMessage { + /** Used to link an RPC request with its response. */ + public final long requestId; + public final ManagedBuffer meta; + public final long bodyByteCount; + + public UploadStream(long requestId, ManagedBuffer meta, ManagedBuffer body) { + super(body, false); // body is *not* included in the frame + this.requestId = requestId; + this.meta = meta; + bodyByteCount = body.size(); + } + + // this version is called when decoding the bytes on the receiving end. The body is handled + // separately. + private UploadStream(long requestId, ManagedBuffer meta, long bodyByteCount) { + super(null, false); + this.requestId = requestId; + this.meta = meta; + this.bodyByteCount = bodyByteCount; + } + + @Override + public Type type() { return Type.UploadStream; } + + @Override + public int encodedLength() { + // the requestId, meta size, meta and bodyByteCount (body is not included) + return 8 + 4 + ((int) meta.size()) + 8; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(requestId); + try { + ByteBuffer metaBuf = meta.nioByteBuffer(); + buf.writeInt(metaBuf.remaining()); + buf.writeBytes(metaBuf); + } catch (IOException io) { + throw new RuntimeException(io); + } + buf.writeLong(bodyByteCount); + } + + public static UploadStream decode(ByteBuf buf) { + long requestId = buf.readLong(); + int metaSize = buf.readInt(); + ManagedBuffer meta = new NettyManagedBuffer(buf.readRetainedSlice(metaSize)); + long bodyByteCount = buf.readLong(); + // This is called by the frame decoder, so the data is still null. We need a StreamInterceptor + // to read the data. + return new UploadStream(requestId, meta, bodyByteCount); + } + + @Override + public int hashCode() { + return Long.hashCode(requestId); + } + + @Override + public boolean equals(Object other) { + if (other instanceof UploadStream) { + UploadStream o = (UploadStream) other; + return requestId == o.requestId && super.equals(o); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("requestId", requestId) + .add("body", body()) + .toString(); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 0231428318add..355a3def8cc22 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -28,6 +28,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; @@ -132,6 +133,14 @@ public void receive(TransportClient client, ByteBuffer message) { delegate.receive(client, message); } + @Override + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + return delegate.receiveStream(client, message, callback); + } + @Override public StreamManager getStreamManager() { return delegate.getStreamManager(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 8f7554e2e07d5..38569baf82bce 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -23,6 +23,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; /** @@ -36,7 +37,8 @@ public abstract class RpcHandler { * Receive a single RPC message. Any exception thrown while in this method will be sent back to * the client in string form as a standard RPC failure. * - * This method will not be called in parallel for a single TransportClient (i.e., channel). + * Neither this method nor #receiveStream will be called in parallel for a single + * TransportClient (i.e., channel). * * @param client A channel client which enables the handler to make requests back to the sender * of this RPC. This will always be the exact same object for a particular channel. @@ -49,6 +51,36 @@ public abstract void receive( ByteBuffer message, RpcResponseCallback callback); + /** + * Receive a single RPC message which includes data that is to be received as a stream. Any + * exception thrown while in this method will be sent back to the client in string form as a + * standard RPC failure. + * + * Neither this method nor #receive will be called in parallel for a single TransportClient + * (i.e., channel). + * + * An error while reading data from the stream + * ({@link org.apache.spark.network.client.StreamCallback#onData(String, ByteBuffer)}) + * will fail the entire channel. A failure in "post-processing" the stream in + * {@link org.apache.spark.network.client.StreamCallback#onComplete(String)} will result in an + * rpcFailure, but the channel will remain active. + * + * @param client A channel client which enables the handler to make requests back to the sender + * of this RPC. This will always be the exact same object for a particular channel. + * @param messageHeader The serialized bytes of the header portion of the RPC. This is in meant + * to be relatively small, and will be buffered entirely in memory, to + * facilitate how the streaming portion should be received. + * @param callback Callback which should be invoked exactly once upon success or failure of the + * RPC. + * @return a StreamCallback for handling the accompanying streaming data + */ + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer messageHeader, + RpcResponseCallback callback) { + throw new UnsupportedOperationException(); + } + /** * Returns the StreamManager which contains the state about which streams are currently being * fetched by a TransportClient. diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index e94453578e6b0..e1d7b2dbff60f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,6 +17,7 @@ package org.apache.spark.network.server; +import java.io.IOException; import java.net.SocketAddress; import java.nio.ByteBuffer; @@ -28,20 +29,10 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.ChunkFetchFailure; -import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.protocol.OneWayMessage; -import org.apache.spark.network.protocol.RequestMessage; -import org.apache.spark.network.protocol.RpcFailure; -import org.apache.spark.network.protocol.RpcRequest; -import org.apache.spark.network.protocol.RpcResponse; -import org.apache.spark.network.protocol.StreamFailure; -import org.apache.spark.network.protocol.StreamRequest; -import org.apache.spark.network.protocol.StreamResponse; +import org.apache.spark.network.client.*; +import org.apache.spark.network.protocol.*; +import org.apache.spark.network.util.TransportFrameDecoder; + import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; /** @@ -52,6 +43,7 @@ * The messages should have been processed by the pipeline setup by {@link TransportServer}. */ public class TransportRequestHandler extends MessageHandler { + private static final Logger logger = LoggerFactory.getLogger(TransportRequestHandler.class); /** The Netty channel that this handler is associated with. */ @@ -113,6 +105,8 @@ public void handle(RequestMessage request) { processOneWayMessage((OneWayMessage) request); } else if (request instanceof StreamRequest) { processStreamRequest((StreamRequest) request); + } else if (request instanceof UploadStream) { + processStreamUpload((UploadStream) request); } else { throw new IllegalArgumentException("Unknown request type: " + request); } @@ -203,6 +197,79 @@ public void onFailure(Throwable e) { } } + /** + * Handle a request from the client to upload a stream of data. + */ + private void processStreamUpload(final UploadStream req) { + assert (req.body() == null); + try { + RpcResponseCallback callback = new RpcResponseCallback() { + @Override + public void onSuccess(ByteBuffer response) { + respond(new RpcResponse(req.requestId, new NioManagedBuffer(response))); + } + + @Override + public void onFailure(Throwable e) { + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + } + }; + TransportFrameDecoder frameDecoder = (TransportFrameDecoder) + channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); + ByteBuffer meta = req.meta.nioByteBuffer(); + StreamCallbackWithID streamHandler = rpcHandler.receiveStream(reverseClient, meta, callback); + if (streamHandler == null) { + throw new NullPointerException("rpcHandler returned a null streamHandler"); + } + StreamCallbackWithID wrappedCallback = new StreamCallbackWithID() { + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + streamHandler.onData(streamId, buf); + } + + @Override + public void onComplete(String streamId) throws IOException { + try { + streamHandler.onComplete(streamId); + callback.onSuccess(ByteBuffer.allocate(0)); + } catch (Exception ex) { + IOException ioExc = new IOException("Failure post-processing complete stream;" + + " failing this rpc and leaving channel active"); + callback.onFailure(ioExc); + streamHandler.onFailure(streamId, ioExc); + } + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + callback.onFailure(new IOException("Destination failed while reading stream", cause)); + streamHandler.onFailure(streamId, cause); + } + + @Override + public String getID() { + return streamHandler.getID(); + } + }; + if (req.bodyByteCount > 0) { + StreamInterceptor interceptor = new StreamInterceptor(this, wrappedCallback.getID(), + req.bodyByteCount, wrappedCallback); + frameDecoder.setInterceptor(interceptor); + } else { + wrappedCallback.onComplete(wrappedCallback.getID()); + } + } catch (Exception e) { + logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + // We choose to totally fail the channel, rather than trying to recover as we do in other + // cases. We don't know how many bytes of the stream the client has already sent for the + // stream, it's not worth trying to recover. + channel.pipeline().fireExceptionCaught(e); + } finally { + req.meta.release(); + } + } + private void processOneWayMessage(OneWayMessage req) { try { rpcHandler.receive(reverseClient, req.body().nioByteBuffer()); diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 8ff737b129641..1f4d75c7e2ec5 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -17,43 +17,46 @@ package org.apache.spark.network; +import java.io.*; import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.Iterator; -import java.util.List; -import java.util.Set; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import com.google.common.collect.Sets; +import com.google.common.io.Files; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; import static org.junit.Assert.*; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.server.OneForOneStreamManager; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.*; +import org.apache.spark.network.server.*; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class RpcIntegrationSuite { + static TransportConf conf; static TransportServer server; static TransportClientFactory clientFactory; static RpcHandler rpcHandler; static List oneWayMsgs; + static StreamTestHelper testData; + + static ConcurrentHashMap streamCallbacks = + new ConcurrentHashMap<>(); @BeforeClass public static void setUp() throws Exception { - TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); + conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); + testData = new StreamTestHelper(); rpcHandler = new RpcHandler() { @Override public void receive( @@ -71,6 +74,14 @@ public void receive( } } + @Override + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer messageHeader, + RpcResponseCallback callback) { + return receiveStreamHelper(JavaUtils.bytesToString(messageHeader)); + } + @Override public void receive(TransportClient client, ByteBuffer message) { oneWayMsgs.add(JavaUtils.bytesToString(message)); @@ -85,10 +96,71 @@ public void receive(TransportClient client, ByteBuffer message) { oneWayMsgs = new ArrayList<>(); } + private static StreamCallbackWithID receiveStreamHelper(String msg) { + try { + if (msg.startsWith("fail/")) { + String[] parts = msg.split("/"); + switch (parts[1]) { + case "exception-ondata": + return new StreamCallbackWithID() { + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + throw new IOException("failed to read stream data!"); + } + + @Override + public void onComplete(String streamId) throws IOException { + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + } + + @Override + public String getID() { + return msg; + } + }; + case "exception-oncomplete": + return new StreamCallbackWithID() { + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + } + + @Override + public void onComplete(String streamId) throws IOException { + throw new IOException("exception in onComplete"); + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + } + + @Override + public String getID() { + return msg; + } + }; + case "null": + return null; + default: + throw new IllegalArgumentException("unexpected msg: " + msg); + } + } else { + VerifyingStreamCallback streamCallback = new VerifyingStreamCallback(msg); + streamCallbacks.put(msg, streamCallback); + return streamCallback; + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + @AfterClass public static void tearDown() { server.close(); clientFactory.close(); + testData.cleanup(); } static class RpcResult { @@ -130,6 +202,59 @@ public void onFailure(Throwable e) { return res; } + private RpcResult sendRpcWithStream(String... streams) throws Exception { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + final Semaphore sem = new Semaphore(0); + RpcResult res = new RpcResult(); + res.successMessages = Collections.synchronizedSet(new HashSet()); + res.errorMessages = Collections.synchronizedSet(new HashSet()); + + for (String stream : streams) { + int idx = stream.lastIndexOf('/'); + ManagedBuffer meta = new NioManagedBuffer(JavaUtils.stringToBytes(stream)); + String streamName = (idx == -1) ? stream : stream.substring(idx + 1); + ManagedBuffer data = testData.openStream(conf, streamName); + client.uploadStream(meta, data, new RpcStreamCallback(stream, res, sem)); + } + + if (!sem.tryAcquire(streams.length, 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } + streamCallbacks.values().forEach(streamCallback -> { + try { + streamCallback.verify(); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + client.close(); + return res; + } + + private static class RpcStreamCallback implements RpcResponseCallback { + final String streamId; + final RpcResult res; + final Semaphore sem; + + RpcStreamCallback(String streamId, RpcResult res, Semaphore sem) { + this.streamId = streamId; + this.res = res; + this.sem = sem; + } + + @Override + public void onSuccess(ByteBuffer message) { + res.successMessages.add(streamId); + sem.release(); + } + + @Override + public void onFailure(Throwable e) { + res.errorMessages.add(e.getMessage()); + sem.release(); + } + } + @Test public void singleRPC() throws Exception { RpcResult res = sendRPC("hello/Aaron"); @@ -193,10 +318,83 @@ public void sendOneWayMessage() throws Exception { } } + @Test + public void sendRpcWithStreamOneAtATime() throws Exception { + for (String stream : StreamTestHelper.STREAMS) { + RpcResult res = sendRpcWithStream(stream); + assertTrue("there were error messages!" + res.errorMessages, res.errorMessages.isEmpty()); + assertEquals(Sets.newHashSet(stream), res.successMessages); + } + } + + @Test + public void sendRpcWithStreamConcurrently() throws Exception { + String[] streams = new String[10]; + for (int i = 0; i < 10; i++) { + streams[i] = StreamTestHelper.STREAMS[i % StreamTestHelper.STREAMS.length]; + } + RpcResult res = sendRpcWithStream(streams); + assertEquals(Sets.newHashSet(StreamTestHelper.STREAMS), res.successMessages); + assertTrue(res.errorMessages.isEmpty()); + } + + @Test + public void sendRpcWithStreamFailures() throws Exception { + // when there is a failure reading stream data, we don't try to keep the channel usable, + // just send back a decent error msg. + RpcResult exceptionInCallbackResult = + sendRpcWithStream("fail/exception-ondata/smallBuffer", "smallBuffer"); + assertErrorAndClosed(exceptionInCallbackResult, "Destination failed while reading stream"); + + RpcResult nullStreamHandler = + sendRpcWithStream("fail/null/smallBuffer", "smallBuffer"); + assertErrorAndClosed(exceptionInCallbackResult, "Destination failed while reading stream"); + + // OTOH, if there is a failure during onComplete, the channel should still be fine + RpcResult exceptionInOnComplete = + sendRpcWithStream("fail/exception-oncomplete/smallBuffer", "smallBuffer"); + assertErrorsContain(exceptionInOnComplete.errorMessages, + Sets.newHashSet("Failure post-processing")); + assertEquals(Sets.newHashSet("smallBuffer"), exceptionInOnComplete.successMessages); + } + private void assertErrorsContain(Set errors, Set contains) { - assertEquals(contains.size(), errors.size()); + assertEquals("Expected " + contains.size() + " errors, got " + errors.size() + "errors: " + + errors, contains.size(), errors.size()); + + Pair, Set> r = checkErrorsContain(errors, contains); + assertTrue("Could not find error containing " + r.getRight() + "; errors: " + errors, + r.getRight().isEmpty()); + + assertTrue(r.getLeft().isEmpty()); + } + + private void assertErrorAndClosed(RpcResult result, String expectedError) { + assertTrue("unexpected success: " + result.successMessages, result.successMessages.isEmpty()); + // we expect 1 additional error, which contains *either* "closed" or "Connection reset" + Set errors = result.errorMessages; + assertEquals("Expected 2 errors, got " + errors.size() + "errors: " + + errors, 2, errors.size()); + + Set containsAndClosed = Sets.newHashSet(expectedError); + containsAndClosed.add("closed"); + containsAndClosed.add("Connection reset"); + + Pair, Set> r = checkErrorsContain(errors, containsAndClosed); + Set errorsNotFound = r.getRight(); + assertEquals(1, errorsNotFound.size()); + String err = errorsNotFound.iterator().next(); + assertTrue(err.equals("closed") || err.equals("Connection reset")); + + assertTrue(r.getLeft().isEmpty()); + } + + private Pair, Set> checkErrorsContain( + Set errors, + Set contains) { Set remainingErrors = Sets.newHashSet(errors); + Set notFound = Sets.newHashSet(); for (String contain : contains) { Iterator it = remainingErrors.iterator(); boolean foundMatch = false; @@ -207,9 +405,66 @@ private void assertErrorsContain(Set errors, Set contains) { break; } } - assertTrue("Could not find error containing " + contain + "; errors: " + errors, foundMatch); + if (!foundMatch) { + notFound.add(contain); + } + } + return new ImmutablePair<>(remainingErrors, notFound); + } + + private static class VerifyingStreamCallback implements StreamCallbackWithID { + final String streamId; + final StreamSuite.TestCallback helper; + final OutputStream out; + final File outFile; + + VerifyingStreamCallback(String streamId) throws IOException { + if (streamId.equals("file")) { + outFile = File.createTempFile("data", ".tmp", testData.tempDir); + out = new FileOutputStream(outFile); + } else { + out = new ByteArrayOutputStream(); + outFile = null; + } + this.streamId = streamId; + helper = new StreamSuite.TestCallback(out); + } + + void verify() throws IOException { + if (streamId.equals("file")) { + assertTrue("File stream did not match.", Files.equal(testData.testFile, outFile)); + } else { + byte[] result = ((ByteArrayOutputStream)out).toByteArray(); + ByteBuffer srcBuffer = testData.srcBuffer(streamId); + ByteBuffer base; + synchronized (srcBuffer) { + base = srcBuffer.duplicate(); + } + byte[] expected = new byte[base.remaining()]; + base.get(expected); + assertEquals(expected.length, result.length); + assertTrue("buffers don't match", Arrays.equals(expected, result)); + } + } + + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + helper.onData(streamId, buf); } - assertTrue(remainingErrors.isEmpty()); + @Override + public void onComplete(String streamId) throws IOException { + helper.onComplete(streamId); + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + helper.onFailure(streamId, cause); + } + + @Override + public String getID() { + return streamId; + } } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java index f253a07e64be1..f3050cb79cdfd 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -26,7 +26,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.Random; import java.util.concurrent.Executors; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; @@ -37,9 +36,7 @@ import org.junit.Test; import static org.junit.Assert.*; -import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportClient; @@ -51,16 +48,11 @@ import org.apache.spark.network.util.TransportConf; public class StreamSuite { - private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" }; + private static final String[] STREAMS = StreamTestHelper.STREAMS; + private static StreamTestHelper testData; private static TransportServer server; private static TransportClientFactory clientFactory; - private static File testFile; - private static File tempDir; - - private static ByteBuffer emptyBuffer; - private static ByteBuffer smallBuffer; - private static ByteBuffer largeBuffer; private static ByteBuffer createBuffer(int bufSize) { ByteBuffer buf = ByteBuffer.allocate(bufSize); @@ -73,23 +65,7 @@ private static ByteBuffer createBuffer(int bufSize) { @BeforeClass public static void setUp() throws Exception { - tempDir = Files.createTempDir(); - emptyBuffer = createBuffer(0); - smallBuffer = createBuffer(100); - largeBuffer = createBuffer(100000); - - testFile = File.createTempFile("stream-test-file", "txt", tempDir); - FileOutputStream fp = new FileOutputStream(testFile); - try { - Random rnd = new Random(); - for (int i = 0; i < 512; i++) { - byte[] fileContent = new byte[1024]; - rnd.nextBytes(fileContent); - fp.write(fileContent); - } - } finally { - fp.close(); - } + testData = new StreamTestHelper(); final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); final StreamManager streamManager = new StreamManager() { @@ -100,18 +76,7 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { @Override public ManagedBuffer openStream(String streamId) { - switch (streamId) { - case "largeBuffer": - return new NioManagedBuffer(largeBuffer); - case "smallBuffer": - return new NioManagedBuffer(smallBuffer); - case "emptyBuffer": - return new NioManagedBuffer(emptyBuffer); - case "file": - return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length()); - default: - throw new IllegalArgumentException("Invalid stream: " + streamId); - } + return testData.openStream(conf, streamId); } }; RpcHandler handler = new RpcHandler() { @@ -137,12 +102,7 @@ public StreamManager getStreamManager() { public static void tearDown() { server.close(); clientFactory.close(); - if (tempDir != null) { - for (File f : tempDir.listFiles()) { - f.delete(); - } - tempDir.delete(); - } + testData.cleanup(); } @Test @@ -234,21 +194,21 @@ public void run() { case "largeBuffer": baos = new ByteArrayOutputStream(); out = baos; - srcBuffer = largeBuffer; + srcBuffer = testData.largeBuffer; break; case "smallBuffer": baos = new ByteArrayOutputStream(); out = baos; - srcBuffer = smallBuffer; + srcBuffer = testData.smallBuffer; break; case "file": - outFile = File.createTempFile("data", ".tmp", tempDir); + outFile = File.createTempFile("data", ".tmp", testData.tempDir); out = new FileOutputStream(outFile); break; case "emptyBuffer": baos = new ByteArrayOutputStream(); out = baos; - srcBuffer = emptyBuffer; + srcBuffer = testData.emptyBuffer; break; default: throw new IllegalArgumentException(streamId); @@ -256,10 +216,10 @@ public void run() { TestCallback callback = new TestCallback(out); client.stream(streamId, callback); - waitForCompletion(callback); + callback.waitForCompletion(timeoutMs); if (srcBuffer == null) { - assertTrue("File stream did not match.", Files.equal(testFile, outFile)); + assertTrue("File stream did not match.", Files.equal(testData.testFile, outFile)); } else { ByteBuffer base; synchronized (srcBuffer) { @@ -292,23 +252,9 @@ public void check() throws Throwable { throw error; } } - - private void waitForCompletion(TestCallback callback) throws Exception { - long now = System.currentTimeMillis(); - long deadline = now + timeoutMs; - synchronized (callback) { - while (!callback.completed && now < deadline) { - callback.wait(deadline - now); - now = System.currentTimeMillis(); - } - } - assertTrue("Timed out waiting for stream.", callback.completed); - assertNull(callback.error); - } - } - private static class TestCallback implements StreamCallback { + static class TestCallback implements StreamCallback { private final OutputStream out; public volatile boolean completed; @@ -344,6 +290,22 @@ public void onFailure(String streamId, Throwable cause) { } } + void waitForCompletion(long timeoutMs) { + long now = System.currentTimeMillis(); + long deadline = now + timeoutMs; + synchronized (this) { + while (!completed && now < deadline) { + try { + wait(deadline - now); + } catch (InterruptedException ie) { + throw new RuntimeException(ie); + } + now = System.currentTimeMillis(); + } + } + assertTrue("Timed out waiting for stream.", completed); + assertNull(error); + } } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java b/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java new file mode 100644 index 0000000000000..0f5c82c9e9b1f --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Random; + +import com.google.common.io.Files; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.TransportConf; + +class StreamTestHelper { + static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" }; + + final File testFile; + final File tempDir; + + final ByteBuffer emptyBuffer; + final ByteBuffer smallBuffer; + final ByteBuffer largeBuffer; + + private static ByteBuffer createBuffer(int bufSize) { + ByteBuffer buf = ByteBuffer.allocate(bufSize); + for (int i = 0; i < bufSize; i ++) { + buf.put((byte) i); + } + buf.flip(); + return buf; + } + + StreamTestHelper() throws Exception { + tempDir = Files.createTempDir(); + emptyBuffer = createBuffer(0); + smallBuffer = createBuffer(100); + largeBuffer = createBuffer(100000); + + testFile = File.createTempFile("stream-test-file", "txt", tempDir); + FileOutputStream fp = new FileOutputStream(testFile); + try { + Random rnd = new Random(); + for (int i = 0; i < 512; i++) { + byte[] fileContent = new byte[1024]; + rnd.nextBytes(fileContent); + fp.write(fileContent); + } + } finally { + fp.close(); + } + } + + public ByteBuffer srcBuffer(String name) { + switch (name) { + case "largeBuffer": + return largeBuffer; + case "smallBuffer": + return smallBuffer; + case "emptyBuffer": + return emptyBuffer; + default: + throw new IllegalArgumentException("Invalid stream: " + name); + } + } + + public ManagedBuffer openStream(TransportConf conf, String streamId) { + switch (streamId) { + case "file": + return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length()); + default: + return new NioManagedBuffer(srcBuffer(streamId)); + } + } + + void cleanup() { + if (tempDir != null) { + try { + JavaUtils.deleteRecursively(tempDir); + } catch (IOException io) { + throw new RuntimeException(io); + } + } + } +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 4f6d5ff898681..eeb097ef153ad 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,9 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-6237][NETWORK] Network-layer changes to allow stream upload + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.receive"), + // [SPARK-20087][CORE] Attach accumulators / metrics to 'TaskKilled' end reason ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.apply"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.copy"), From 1b9368f7d4c1d5c0df49204f48515d3b4ffe3e13 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Wed, 27 Jun 2018 10:27:40 +0800 Subject: [PATCH 1020/2461] [SPARK-24659][SQL] GenericArrayData.equals should respect element type differences ## What changes were proposed in this pull request? Fix `GenericArrayData.equals`, so that it respects the actual types of the elements. e.g. an instance that represents an `array` and another instance that represents an `array` should be considered incompatible, and thus should return false for `equals`. `GenericArrayData` doesn't keep any schema information by itself, and rather relies on the Java objects referenced by its `array` field's elements to keep track of their own object types. So, the most straightforward way to respect their types is to call `equals` on the elements, instead of using Scala's `==` operator, which can have semantics that are not always desirable: ``` new java.lang.Integer(123) == new java.lang.Long(123L) // true in Scala new java.lang.Integer(123).equals(new java.lang.Long(123L)) // false in Scala ``` ## How was this patch tested? Added unit test in `ComplexDataSuite` Author: Kris Mok Closes #21643 from rednaxelafx/fix-genericarraydata-equals. --- .../sql/catalyst/util/GenericArrayData.scala | 2 +- .../sql/catalyst/util/ComplexDataSuite.scala | 36 +++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index 9e39ed9c3a778..83ad08d8e1758 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -122,7 +122,7 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData { if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { return false } - case _ => if (o1 != o2) { + case _ => if (!o1.equals(o2)) { return false } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala index 9d285916bcf42..229e32479082c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala @@ -104,4 +104,40 @@ class ComplexDataSuite extends SparkFunSuite { // The copied data should not be changed externally. assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a") } + + test("SPARK-24659: GenericArrayData.equals should respect element type differences") { + import scala.reflect.ClassTag + + // Expected positive cases + def arraysShouldEqual[T: ClassTag](element: T*): Unit = { + val array1 = new GenericArrayData(Array[T](element: _*)) + val array2 = new GenericArrayData(Array[T](element: _*)) + assert(array1.equals(array2)) + } + arraysShouldEqual(true, false) // Boolean + arraysShouldEqual(0.toByte, 123.toByte, Byte.MinValue, Byte.MaxValue) // Byte + arraysShouldEqual(0.toShort, 123.toShort, Short.MinValue, Short.MaxValue) // Short + arraysShouldEqual(0, 123, -65536, Int.MinValue, Int.MaxValue) // Int + arraysShouldEqual(0L, 123L, -65536L, Long.MinValue, Long.MaxValue) // Long + arraysShouldEqual(0.0F, 123.0F, Float.MinValue, Float.MaxValue, Float.MinPositiveValue, + Float.PositiveInfinity, Float.NegativeInfinity, Float.NaN) // Float + arraysShouldEqual(0.0, 123.0, Double.MinValue, Double.MaxValue, Double.MinPositiveValue, + Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN) // Double + arraysShouldEqual(Array[Byte](123.toByte), Array[Byte](), null) // SQL Binary + arraysShouldEqual(UTF8String.fromString("foo"), null) // SQL String + + // Expected negative cases + // Spark SQL considers cases like array vs array to be incompatible, + // so an underlying implementation of array type should return false in such cases. + def arraysShouldNotEqual[T: ClassTag, U: ClassTag](element1: T, element2: U): Unit = { + val array1 = new GenericArrayData(Array[T](element1)) + val array2 = new GenericArrayData(Array[U](element2)) + assert(!array1.equals(array2)) + } + arraysShouldNotEqual(true, 1) // Boolean <-> Int + arraysShouldNotEqual(123.toByte, 123) // Byte <-> Int + arraysShouldNotEqual(123.toByte, 123L) // Byte <-> Long + arraysShouldNotEqual(123.toShort, 123) // Short <-> Int + arraysShouldNotEqual(123, 123L) // Int <-> Long + } } From d08f53dc61f662f5291f71bcbe1a7b9f531a34d2 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 27 Jun 2018 10:36:51 +0800 Subject: [PATCH 1021/2461] [SPARK-24605][SQL] size(null) returns null instead of -1 ## What changes were proposed in this pull request? In PR, I propose new behavior of `size(null)` under the config flag `spark.sql.legacy.sizeOfNull`. If the former one is disabled, the `size()` function returns `null` for `null` input. By default the `spark.sql.legacy.sizeOfNull` is enabled to keep backward compatibility with previous versions. In that case, `size(null)` returns `-1`. ## How was this patch tested? Modified existing tests for the `size()` function to check new behavior (`null`) and old one (`-1`). Author: Maxim Gekk Closes #21598 from MaxGekk/legacy-size-of-null. --- .../expressions/collectionOperations.scala | 38 ++++++++++--- .../apache/spark/sql/internal/SQLConf.scala | 8 +++ .../CollectionExpressionsSuite.scala | 30 +++++++---- .../org/apache/spark/sql/functions.scala | 2 +- .../spark/sql/DataFrameFunctionsSuite.scala | 54 +++++++++++-------- 5 files changed, 93 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 58612f65c1a53..abd6c88d3d985 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -67,37 +67,61 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression /** - * Given an array or map, returns its size. Returns -1 if null. + * Given an array or map, returns total number of elements in it. */ @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the size of an array or a map. Returns -1 if null.", + usage = """ + _FUNC_(expr) - Returns the size of an array or a map. + The function returns -1 if its input is null and spark.sql.legacy.sizeOfNull is set to true. + If spark.sql.legacy.sizeOfNull is set to false, the function returns null for null input. + By default, the spark.sql.legacy.sizeOfNull parameter is set to true. + """, examples = """ Examples: > SELECT _FUNC_(array('b', 'd', 'c', 'a')); 4 + > SELECT _FUNC_(map('a', 1, 'b', 2)); + 2 + > SELECT _FUNC_(NULL); + -1 """) -case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Size( + child: Expression, + legacySizeOfNull: Boolean) + extends UnaryExpression with ExpectsInputTypes { + + def this(child: Expression) = + this( + child, + legacySizeOfNull = SQLConf.get.getConf(SQLConf.LEGACY_SIZE_OF_NULL)) + override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType)) - override def nullable: Boolean = false + override def nullable: Boolean = if (legacySizeOfNull) false else super.nullable override def eval(input: InternalRow): Any = { val value = child.eval(input) if (value == null) { - -1 + if (legacySizeOfNull) -1 else null } else child.dataType match { case _: ArrayType => value.asInstanceOf[ArrayData].numElements() case _: MapType => value.asInstanceOf[MapData].numElements() + case other => throw new UnsupportedOperationException( + s"The size function doesn't support the operand type ${other.getClass.getCanonicalName}") } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val childGen = child.genCode(ctx) - ev.copy(code = code""" + if (legacySizeOfNull) { + val childGen = child.genCode(ctx) + ev.copy(code = code""" boolean ${ev.isNull} = false; ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : (${childGen.value}).numElements();""", isNull = FalseLiteral) + } else { + defineCodeGen(ctx, ev, c => s"($c).numElements()") + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e768416f257c9..239c8266351ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1324,6 +1324,12 @@ object SQLConf { "Other column values can be ignored during parsing even if they are malformed.") .booleanConf .createWithDefault(true) + + val LEGACY_SIZE_OF_NULL = buildConf("spark.sql.legacy.sizeOfNull") + .doc("If it is set to true, size of null returns -1. This behavior was inherited from Hive. " + + "The size function returns null for null input if the flag is disabled.") + .booleanConf + .createWithDefault(true) } /** @@ -1686,6 +1692,8 @@ class SQLConf extends Serializable with Logging { def csvColumnPruning: Boolean = getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING) + def legacySizeOfNull: Boolean = getConf(SQLConf.LEGACY_SIZE_OF_NULL) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 5b8cf5128fe21..caea4fb25ff7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -24,25 +24,37 @@ import org.apache.spark.sql.types._ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { - test("Array and Map Size") { + def testSize(legacySizeOfNull: Boolean, sizeOfNull: Any): Unit = { val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType)) - checkEvaluation(Size(a0), 3) - checkEvaluation(Size(a1), 0) - checkEvaluation(Size(a2), 2) + checkEvaluation(Size(a0, legacySizeOfNull), 3) + checkEvaluation(Size(a1, legacySizeOfNull), 0) + checkEvaluation(Size(a2, legacySizeOfNull), 2) val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType)) val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType)) - checkEvaluation(Size(m0), 2) - checkEvaluation(Size(m1), 0) - checkEvaluation(Size(m2), 1) + checkEvaluation(Size(m0, legacySizeOfNull), 2) + checkEvaluation(Size(m1, legacySizeOfNull), 0) + checkEvaluation(Size(m2, legacySizeOfNull), 1) + + checkEvaluation( + Size(Literal.create(null, MapType(StringType, StringType)), legacySizeOfNull), + expected = sizeOfNull) + checkEvaluation( + Size(Literal.create(null, ArrayType(StringType)), legacySizeOfNull), + expected = sizeOfNull) + } + + test("Array and Map Size - legacy") { + testSize(legacySizeOfNull = true, sizeOfNull = -1) + } - checkEvaluation(Size(Literal.create(null, MapType(StringType, StringType))), -1) - checkEvaluation(Size(Literal.create(null, ArrayType(StringType))), -1) + test("Array and Map Size") { + testSize(legacySizeOfNull = false, sizeOfNull = null) } test("MapKeys/MapValues") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 40c40e7083d1c..ef99ce3ad69d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3431,7 +3431,7 @@ object functions { * @group collection_funcs * @since 1.5.0 */ - def size(e: Column): Column = withExpr { Size(e.expr) } + def size(e: Column): Column = withExpr { new Size(e.expr) } /** * Sorts the input array for the given column in ascending order, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 5d6a6c0832c96..b109898b5bfb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -487,26 +487,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { }.getMessage().contains("only supports array input")) } - test("array size function") { + def testSizeOfArray(sizeOfNull: Any): Unit = { val df = Seq( (Seq[Int](1, 2), "x"), (Seq[Int](), "y"), (Seq[Int](1, 2, 3), "z"), (null, "empty") ).toDF("a", "b") - checkAnswer( - df.select(size($"a")), - Seq(Row(2), Row(0), Row(3), Row(-1)) - ) - checkAnswer( - df.selectExpr("size(a)"), - Seq(Row(2), Row(0), Row(3), Row(-1)) - ) - checkAnswer( - df.selectExpr("cardinality(a)"), - Seq(Row(2L), Row(0L), Row(3L), Row(-1L)) - ) + checkAnswer(df.select(size($"a")), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull))) + checkAnswer(df.selectExpr("size(a)"), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull))) + checkAnswer(df.selectExpr("cardinality(a)"), Seq(Row(2L), Row(0L), Row(3L), Row(sizeOfNull))) + } + + test("array size function - legacy") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { + testSizeOfArray(sizeOfNull = -1) + } + } + + test("array size function") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "false") { + testSizeOfArray(sizeOfNull = null) + } } test("dataframe arrays_zip function") { @@ -567,21 +570,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } } - test("map size function") { + def testSizeOfMap(sizeOfNull: Any): Unit = { val df = Seq( (Map[Int, Int](1 -> 1, 2 -> 2), "x"), (Map[Int, Int](), "y"), (Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z"), (null, "empty") ).toDF("a", "b") - checkAnswer( - df.select(size($"a")), - Seq(Row(2), Row(0), Row(3), Row(-1)) - ) - checkAnswer( - df.selectExpr("size(a)"), - Seq(Row(2), Row(0), Row(3), Row(-1)) - ) + + checkAnswer(df.select(size($"a")), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull))) + checkAnswer(df.selectExpr("size(a)"), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull))) + } + + test("map size function - legacy") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { + testSizeOfMap(sizeOfNull = -1: Int) + } + } + + test("map size function") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "false") { + testSizeOfMap(sizeOfNull = null) + } } test("map_keys/map_values function") { From 2669b4de3b336dde84b698c20dbc73b30abf79d4 Mon Sep 17 00:00:00 2001 From: "Vayda, Oleksandr: IT (PRG)" Date: Wed, 27 Jun 2018 11:52:31 +0900 Subject: [PATCH 1022/2461] [SPARK-23927][SQL] Add "sequence" expression MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? The PR adds the SQL function ```sequence```. https://issues.apache.org/jira/browse/SPARK-23927 The behavior of the function is based on Presto's one. Ref: https://prestodb.io/docs/current/functions/array.html - ```sequence(start, stop) → array``` Generate a sequence of integers from ```start``` to ```stop```, incrementing by ```1``` if ```start``` is less than or equal to ```stop```, otherwise ```-1```. - ```sequence(start, stop, step) → array``` Generate a sequence of integers from ```start``` to ```stop```, incrementing by ```step```. - ```sequence(start_date, stop_date) → array``` Generate a sequence of dates from ```start_date``` to ```stop_date```, incrementing by ```interval 1 day``` if ```start_date``` is less than or equal to ```stop_date```, otherwise ```- interval 1 day```. - ```sequence(start_date, stop_date, step_interval) → array``` Generate a sequence of dates from ```start_date``` to ```stop_date```, incrementing by ```step_interval```. The type of ```step_interval``` is ```CalendarInterval```. - ```sequence(start_timestemp, stop_timestemp) → array``` Generate a sequence of timestamps from ```start_timestamps``` to ```stop_timestamps```, incrementing by ```interval 1 day``` if ```start_date``` is less than or equal to ```stop_date```, otherwise ```- interval 1 day```. - ```sequence(start_timestamp, stop_timestamp, step_interval) → array``` Generate a sequence of timestamps from ```start_timestamps``` to ```stop_timestamps```, incrementing by ```step_interval```. The type of ```step_interval``` is ```CalendarInterval```. ## How was this patch tested? Added unit tests. Author: Vayda, Oleksandr: IT (PRG) Closes #21155 from wajda/feature/array-api-sequence. --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/analysis/TypeCoercion.scala | 7 + .../expressions/collectionOperations.scala | 402 +++++++++++++++++- .../CollectionExpressionsSuite.scala | 292 +++++++++++++ .../org/apache/spark/sql/functions.scala | 21 + .../spark/sql/DataFrameFunctionsSuite.scala | 56 +++ 6 files changed, 777 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 8abc616c1a3f7..a574d8a84d4fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -432,6 +432,7 @@ object FunctionRegistry { expression[Reverse]("reverse"), expression[Concat]("concat"), expression[Flatten]("flatten"), + expression[Sequence]("sequence"), expression[ArrayRepeat]("array_repeat"), expression[ArrayRemove]("array_remove"), expression[ArrayDistinct]("array_distinct"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 637923928a7da..3ebab430ffbcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -544,6 +544,13 @@ object TypeCoercion { case None => aj } + case s @ Sequence(_, _, _, timeZoneId) if !haveSameType(s.coercibleChildren) => + val types = s.coercibleChildren.map(_.dataType) + findWiderCommonType(types) match { + case Some(widerDataType) => s.castChildrenTo(widerDataType) + case None => s + } + case m @ CreateMap(children) if m.keys.length == m.values.length && (!haveSameType(m.keys) || !haveSameType(m.values)) => val newKeys = if (haveSameType(m.keys)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index abd6c88d3d985..0395e1ef9a7ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -16,9 +16,10 @@ */ package org.apache.spark.sql.catalyst.expressions -import java.util.Comparator +import java.util.{Comparator, TimeZone} import scala.collection.mutable +import scala.reflect.ClassTag import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} @@ -26,11 +27,13 @@ import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.collection.OpenHashSet /** @@ -2313,6 +2316,401 @@ case class Flatten(child: Expression) extends UnaryExpression { override def prettyName: String = "flatten" } +@ExpressionDescription( + usage = """ + _FUNC_(start, stop, step) - Generates an array of elements from start to stop (inclusive), + incrementing by step. The type of the returned elements is the same as the type of argument + expressions. + + Supported types are: byte, short, integer, long, date, timestamp. + + The start and stop expressions must resolve to the same type. + If start and stop expressions resolve to the 'date' or 'timestamp' type + then the step expression must resolve to the 'interval' type, otherwise to the same type + as the start and stop expressions. + """, + arguments = """ + Arguments: + * start - an expression. The start of the range. + * stop - an expression. The end the range (inclusive). + * step - an optional expression. The step of the range. + By default step is 1 if start is less than or equal to stop, otherwise -1. + For the temporal sequences it's 1 day and -1 day respectively. + If start is greater than stop then the step must be negative, and vice versa. + """, + examples = """ + Examples: + > SELECT _FUNC_(1, 5); + [1, 2, 3, 4, 5] + > SELECT _FUNC_(5, 1); + [5, 4, 3, 2, 1] + > SELECT _FUNC_(to_date('2018-01-01'), to_date('2018-03-01'), interval 1 month); + [2018-01-01, 2018-02-01, 2018-03-01] + """, + since = "2.4.0" +) +case class Sequence( + start: Expression, + stop: Expression, + stepOpt: Option[Expression], + timeZoneId: Option[String] = None) + extends Expression + with TimeZoneAwareExpression { + + import Sequence._ + + def this(start: Expression, stop: Expression) = + this(start, stop, None, None) + + def this(start: Expression, stop: Expression, step: Expression) = + this(start, stop, Some(step), None) + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Some(timeZoneId)) + + override def children: Seq[Expression] = Seq(start, stop) ++ stepOpt + + override def foldable: Boolean = children.forall(_.foldable) + + override def nullable: Boolean = children.exists(_.nullable) + + override lazy val dataType: ArrayType = ArrayType(start.dataType, containsNull = false) + + override def checkInputDataTypes(): TypeCheckResult = { + val startType = start.dataType + def stepType = stepOpt.get.dataType + val typesCorrect = + startType.sameType(stop.dataType) && + (startType match { + case TimestampType | DateType => + stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType) + case _: IntegralType => + stepOpt.isEmpty || stepType.sameType(startType) + case _ => false + }) + + if (typesCorrect) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"$prettyName only supports integral, timestamp or date types") + } + } + + def coercibleChildren: Seq[Expression] = children.filter(_.dataType != CalendarIntervalType) + + def castChildrenTo(widerType: DataType): Expression = Sequence( + Cast(start, widerType), + Cast(stop, widerType), + stepOpt.map(step => if (step.dataType != CalendarIntervalType) Cast(step, widerType) else step), + timeZoneId) + + private lazy val impl: SequenceImpl = dataType.elementType match { + case iType: IntegralType => + type T = iType.InternalType + val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe)) + new IntegralSequenceImpl(iType)(ct, iType.integral) + + case TimestampType => + new TemporalSequenceImpl[Long](LongType, 1, identity, timeZone) + + case DateType => + new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, timeZone) + } + + override def eval(input: InternalRow): Any = { + val startVal = start.eval(input) + if (startVal == null) return null + val stopVal = stop.eval(input) + if (stopVal == null) return null + val stepVal = stepOpt.map(_.eval(input)).getOrElse(impl.defaultStep(startVal, stopVal)) + if (stepVal == null) return null + + ArrayData.toArrayData(impl.eval(startVal, stopVal, stepVal)) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val startGen = start.genCode(ctx) + val stopGen = stop.genCode(ctx) + val stepGen = stepOpt.map(_.genCode(ctx)).getOrElse( + impl.defaultStep.genCode(ctx, startGen, stopGen)) + + val resultType = CodeGenerator.javaType(dataType) + val resultCode = { + val arr = ctx.freshName("arr") + val arrElemType = CodeGenerator.javaType(dataType.elementType) + s""" + |final $arrElemType[] $arr = null; + |${impl.genCode(ctx, startGen.value, stopGen.value, stepGen.value, arr, arrElemType)} + |${ev.value} = UnsafeArrayData.fromPrimitiveArray($arr); + """.stripMargin + } + + if (nullable) { + val nullSafeEval = + startGen.code + ctx.nullSafeExec(start.nullable, startGen.isNull) { + stopGen.code + ctx.nullSafeExec(stop.nullable, stopGen.isNull) { + stepGen.code + ctx.nullSafeExec(stepOpt.exists(_.nullable), stepGen.isNull) { + s""" + |${ev.isNull} = false; + |$resultCode + """.stripMargin + } + } + } + ev.copy(code = + code""" + |boolean ${ev.isNull} = true; + |$resultType ${ev.value} = null; + |$nullSafeEval + """.stripMargin) + + } else { + ev.copy(code = + code""" + |${startGen.code} + |${stopGen.code} + |${stepGen.code} + |$resultType ${ev.value} = null; + |$resultCode + """.stripMargin, + isNull = FalseLiteral) + } + } +} + +object Sequence { + + private type LessThanOrEqualFn = (Any, Any) => Boolean + + private class DefaultStep(lteq: LessThanOrEqualFn, stepType: DataType, one: Any) { + private val negativeOne = UnaryMinus(Literal(one)).eval() + + def apply(start: Any, stop: Any): Any = { + if (lteq(start, stop)) one else negativeOne + } + + def genCode(ctx: CodegenContext, startGen: ExprCode, stopGen: ExprCode): ExprCode = { + val Seq(oneVal, negativeOneVal) = Seq(one, negativeOne).map(Literal(_).genCode(ctx).value) + ExprCode.forNonNullValue(JavaCode.expression( + s"${startGen.value} <= ${stopGen.value} ? $oneVal : $negativeOneVal", + stepType)) + } + } + + private trait SequenceImpl { + def eval(start: Any, stop: Any, step: Any): Any + + def genCode( + ctx: CodegenContext, + start: String, + stop: String, + step: String, + arr: String, + elemType: String): String + + val defaultStep: DefaultStep + } + + private class IntegralSequenceImpl[T: ClassTag] + (elemType: IntegralType)(implicit num: Integral[T]) extends SequenceImpl { + + override val defaultStep: DefaultStep = new DefaultStep( + (elemType.ordering.lteq _).asInstanceOf[LessThanOrEqualFn], + elemType, + num.one) + + override def eval(input1: Any, input2: Any, input3: Any): Array[T] = { + import num._ + + val start = input1.asInstanceOf[T] + val stop = input2.asInstanceOf[T] + val step = input3.asInstanceOf[T] + + var i: Int = getSequenceLength(start, stop, step) + val arr = new Array[T](i) + while (i > 0) { + i -= 1 + arr(i) = start + step * num.fromInt(i) + } + arr + } + + override def genCode( + ctx: CodegenContext, + start: String, + stop: String, + step: String, + arr: String, + elemType: String): String = { + val i = ctx.freshName("i") + s""" + |${genSequenceLengthCode(ctx, start, stop, step, i)} + |$arr = new $elemType[$i]; + |while ($i > 0) { + | $i--; + | $arr[$i] = ($elemType) ($start + $step * $i); + |} + """.stripMargin + } + } + + private class TemporalSequenceImpl[T: ClassTag] + (dt: IntegralType, scale: Long, fromLong: Long => T, timeZone: TimeZone) + (implicit num: Integral[T]) extends SequenceImpl { + + override val defaultStep: DefaultStep = new DefaultStep( + (dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn], + CalendarIntervalType, + new CalendarInterval(0, MICROS_PER_DAY)) + + private val backedSequenceImpl = new IntegralSequenceImpl[T](dt) + private val microsPerMonth = 28 * CalendarInterval.MICROS_PER_DAY + + override def eval(input1: Any, input2: Any, input3: Any): Array[T] = { + val start = input1.asInstanceOf[T] + val stop = input2.asInstanceOf[T] + val step = input3.asInstanceOf[CalendarInterval] + val stepMonths = step.months + val stepMicros = step.microseconds + + if (stepMonths == 0) { + backedSequenceImpl.eval(start, stop, fromLong(stepMicros / scale)) + + } else { + // To estimate the resulted array length we need to make assumptions + // about a month length in microseconds + val intervalStepInMicros = stepMicros + stepMonths * microsPerMonth + val startMicros: Long = num.toLong(start) * scale + val stopMicros: Long = num.toLong(stop) * scale + val maxEstimatedArrayLength = + getSequenceLength(startMicros, stopMicros, intervalStepInMicros) + + val stepSign = if (stopMicros > startMicros) +1 else -1 + val exclusiveItem = stopMicros + stepSign + val arr = new Array[T](maxEstimatedArrayLength) + var t = startMicros + var i = 0 + + while (t < exclusiveItem ^ stepSign < 0) { + arr(i) = fromLong(t / scale) + t = timestampAddInterval(t, stepMonths, stepMicros, timeZone) + i += 1 + } + + // truncate array to the correct length + if (arr.length == i) arr else arr.slice(0, i) + } + } + + override def genCode( + ctx: CodegenContext, + start: String, + stop: String, + step: String, + arr: String, + elemType: String): String = { + val stepMonths = ctx.freshName("stepMonths") + val stepMicros = ctx.freshName("stepMicros") + val stepScaled = ctx.freshName("stepScaled") + val intervalInMicros = ctx.freshName("intervalInMicros") + val startMicros = ctx.freshName("startMicros") + val stopMicros = ctx.freshName("stopMicros") + val arrLength = ctx.freshName("arrLength") + val stepSign = ctx.freshName("stepSign") + val exclusiveItem = ctx.freshName("exclusiveItem") + val t = ctx.freshName("t") + val i = ctx.freshName("i") + val genTimeZone = ctx.addReferenceObj("timeZone", timeZone, classOf[TimeZone].getName) + + val sequenceLengthCode = + s""" + |final long $intervalInMicros = $stepMicros + $stepMonths * ${microsPerMonth}L; + |${genSequenceLengthCode(ctx, startMicros, stopMicros, intervalInMicros, arrLength)} + """.stripMargin + + val timestampAddIntervalCode = + s""" + |$t = org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampAddInterval( + | $t, $stepMonths, $stepMicros, $genTimeZone); + """.stripMargin + + s""" + |final int $stepMonths = $step.months; + |final long $stepMicros = $step.microseconds; + | + |if ($stepMonths == 0) { + | final $elemType $stepScaled = ($elemType) ($stepMicros / ${scale}L); + | ${backedSequenceImpl.genCode(ctx, start, stop, stepScaled, arr, elemType)}; + | + |} else { + | final long $startMicros = $start * ${scale}L; + | final long $stopMicros = $stop * ${scale}L; + | + | $sequenceLengthCode + | + | final int $stepSign = $stopMicros > $startMicros ? +1 : -1; + | final long $exclusiveItem = $stopMicros + $stepSign; + | + | $arr = new $elemType[$arrLength]; + | long $t = $startMicros; + | int $i = 0; + | + | while ($t < $exclusiveItem ^ $stepSign < 0) { + | $arr[$i] = ($elemType) ($t / ${scale}L); + | $timestampAddIntervalCode + | $i += 1; + | } + | + | if ($arr.length > $i) { + | $arr = java.util.Arrays.copyOf($arr, $i); + | } + |} + """.stripMargin + } + } + + private def getSequenceLength[U](start: U, stop: U, step: U)(implicit num: Integral[U]): Int = { + import num._ + require( + (step > num.zero && start <= stop) + || (step < num.zero && start >= stop) + || (step == num.zero && start == stop), + s"Illegal sequence boundaries: $start to $stop by $step") + + val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / step.toLong + + require( + len <= MAX_ROUNDED_ARRAY_LENGTH, + s"Too long sequence: $len. Should be <= $MAX_ROUNDED_ARRAY_LENGTH") + + len.toInt + } + + private def genSequenceLengthCode( + ctx: CodegenContext, + start: String, + stop: String, + step: String, + len: String): String = { + val longLen = ctx.freshName("longLen") + s""" + |if (!(($step > 0 && $start <= $stop) || + | ($step < 0 && $start >= $stop) || + | ($step == 0 && $start == $stop))) { + | throw new IllegalArgumentException( + | "Illegal sequence boundaries: " + $start + " to " + $stop + " by " + $step); + |} + |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $step; + |if ($longLen > $MAX_ROUNDED_ARRAY_LENGTH) { + | throw new IllegalArgumentException( + | "Too long sequence: " + $longLen + ". Should be <= $MAX_ROUNDED_ARRAY_LENGTH"); + |} + |int $len = (int) $longLen; + """.stripMargin + } +} + /** * Returns the array containing the given input value (left) count (right) times. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index caea4fb25ff7e..d7744eb4c7dc7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -17,10 +17,16 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.{Date, Timestamp} +import java.util.TimeZone + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeTestUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH +import org.apache.spark.unsafe.types.CalendarInterval class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -484,6 +490,292 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123) } + test("Sequence of numbers") { + // test null handling + + checkEvaluation(new Sequence(Literal(null, LongType), Literal(1L)), null) + checkEvaluation(new Sequence(Literal(1L), Literal(null, LongType)), null) + checkEvaluation(new Sequence(Literal(null, LongType), Literal(1L), Literal(1L)), null) + checkEvaluation(new Sequence(Literal(1L), Literal(null, LongType), Literal(1L)), null) + checkEvaluation(new Sequence(Literal(1L), Literal(1L), Literal(null, LongType)), null) + + // test sequence boundaries checking + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)), + EmptyRow, s"Too long sequence: 4294967296. Should be <= $MAX_ROUNDED_ARRAY_LENGTH") + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(1), Literal(2), Literal(0)), EmptyRow, "boundaries: 1 to 2 by 0") + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(2), Literal(1), Literal(0)), EmptyRow, "boundaries: 2 to 1 by 0") + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(2), Literal(1), Literal(1)), EmptyRow, "boundaries: 2 to 1 by 1") + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(1), Literal(2), Literal(-1)), EmptyRow, "boundaries: 1 to 2 by -1") + + // test sequence with one element (zero step or equal start and stop) + + checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(-1)), Seq(1)) + checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(0)), Seq(1)) + checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(1)), Seq(1)) + checkEvaluation(new Sequence(Literal(1), Literal(2), Literal(2)), Seq(1)) + checkEvaluation(new Sequence(Literal(1), Literal(0), Literal(-2)), Seq(1)) + + // test sequence of different integral types (ascending and descending) + + checkEvaluation(new Sequence(Literal(1L), Literal(3L), Literal(1L)), Seq(1L, 2L, 3L)) + checkEvaluation(new Sequence(Literal(-3), Literal(3), Literal(3)), Seq(-3, 0, 3)) + checkEvaluation( + new Sequence(Literal(3.toShort), Literal(-3.toShort), Literal(-3.toShort)), + Seq(3.toShort, 0.toShort, -3.toShort)) + checkEvaluation( + new Sequence(Literal(-1.toByte), Literal(-3.toByte), Literal(-1.toByte)), + Seq(-1.toByte, -2.toByte, -3.toByte)) + } + + test("Sequence of timestamps") { + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-02 00:00:00")), + Literal(CalendarInterval.fromString("interval 12 hours"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-02 00:00:01")), + Literal(CalendarInterval.fromString("interval 12 hours"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-02 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 12 hours").negate())), + Seq( + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-02 00:00:00")), + Literal(Timestamp.valueOf("2017-12-31 23:59:59")), + Literal(CalendarInterval.fromString("interval 12 hours").negate())), + Seq( + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-03-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:00:00"), + Timestamp.valueOf("2018-03-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-03-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month").negate())), + Seq( + Timestamp.valueOf("2018-03-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-03-03 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month 1 day").negate())), + Seq( + Timestamp.valueOf("2018-03-03 00:00:00"), + Timestamp.valueOf("2018-02-02 00:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-31 00:00:00")), + Literal(Timestamp.valueOf("2018-04-30 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month"))), + Seq( + Timestamp.valueOf("2018-01-31 00:00:00"), + Timestamp.valueOf("2018-02-28 00:00:00"), + Timestamp.valueOf("2018-03-31 00:00:00"), + Timestamp.valueOf("2018-04-30 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-03-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month 1 second"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:00:01"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-03-01 00:04:06")), + Literal(CalendarInterval.fromString("interval 1 month 2 minutes 3 seconds"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:02:03"), + Timestamp.valueOf("2018-03-01 00:04:06"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2023-01-01 00:00:00")), + Literal(CalendarInterval.fromYearMonthString("1-5"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00.000"), + Timestamp.valueOf("2019-06-01 00:00:00.000"), + Timestamp.valueOf("2020-11-01 00:00:00.000"), + Timestamp.valueOf("2022-04-01 00:00:00.000"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2022-04-01 00:00:00")), + Literal(Timestamp.valueOf("2017-01-01 00:00:00")), + Literal(CalendarInterval.fromYearMonthString("1-5").negate())), + Seq( + Timestamp.valueOf("2022-04-01 00:00:00.000"), + Timestamp.valueOf("2020-11-01 00:00:00.000"), + Timestamp.valueOf("2019-06-01 00:00:00.000"), + Timestamp.valueOf("2018-01-01 00:00:00.000"))) + } + + test("Sequence on DST boundaries") { + val timeZone = TimeZone.getTimeZone("Europe/Prague") + val dstOffset = timeZone.getDSTSavings + + def noDST(t: Timestamp): Timestamp = new Timestamp(t.getTime - dstOffset) + + DateTimeTestUtils.withDefaultTimeZone(timeZone) { + // Spring time change + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-03-25 01:30:00")), + Literal(Timestamp.valueOf("2018-03-25 03:30:00")), + Literal(CalendarInterval.fromString("interval 30 minutes"))), + Seq( + Timestamp.valueOf("2018-03-25 01:30:00"), + Timestamp.valueOf("2018-03-25 03:00:00"), + Timestamp.valueOf("2018-03-25 03:30:00"))) + + // Autumn time change + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-10-28 01:30:00")), + Literal(Timestamp.valueOf("2018-10-28 03:30:00")), + Literal(CalendarInterval.fromString("interval 30 minutes"))), + Seq( + Timestamp.valueOf("2018-10-28 01:30:00"), + noDST(Timestamp.valueOf("2018-10-28 02:00:00")), + noDST(Timestamp.valueOf("2018-10-28 02:30:00")), + Timestamp.valueOf("2018-10-28 02:00:00"), + Timestamp.valueOf("2018-10-28 02:30:00"), + Timestamp.valueOf("2018-10-28 03:00:00"), + Timestamp.valueOf("2018-10-28 03:30:00"))) + } + } + + test("Sequence of dates") { + DateTimeTestUtils.withDefaultTimeZone(TimeZone.getTimeZone("UTC")) { + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2018-01-05")), + Literal(CalendarInterval.fromString("interval 2 days"))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2018-01-03"), + Date.valueOf("2018-01-05"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2018-03-01")), + Literal(CalendarInterval.fromString("interval 1 month"))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2018-02-01"), + Date.valueOf("2018-03-01"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-31")), + Literal(Date.valueOf("2018-04-30")), + Literal(CalendarInterval.fromString("interval 1 month"))), + Seq( + Date.valueOf("2018-01-31"), + Date.valueOf("2018-02-28"), + Date.valueOf("2018-03-31"), + Date.valueOf("2018-04-30"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2023-01-01")), + Literal(CalendarInterval.fromYearMonthString("1-5"))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2019-06-01"), + Date.valueOf("2020-11-01"), + Date.valueOf("2022-04-01"))) + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence( + Literal(Date.valueOf("1970-01-02")), + Literal(Date.valueOf("1970-01-01")), + Literal(CalendarInterval.fromString("interval 1 day"))), + EmptyRow, "sequence boundaries: 1 to 0 by 1") + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence( + Literal(Date.valueOf("1970-01-01")), + Literal(Date.valueOf("1970-02-01")), + Literal(CalendarInterval.fromString("interval 1 month").negate())), + EmptyRow, + s"sequence boundaries: 0 to 2678400000000 by -${28 * CalendarInterval.MICROS_PER_DAY}") + } + } + + test("Sequence with default step") { + // +/- 1 for integral type + checkEvaluation(new Sequence(Literal(1), Literal(3)), Seq(1, 2, 3)) + checkEvaluation(new Sequence(Literal(3), Literal(1)), Seq(3, 2, 1)) + + // +/- 1 day for timestamps + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-03 00:00:00"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-03 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-03 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00"))), + Seq( + Timestamp.valueOf("2018-01-03 00:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + // +/- 1 day for dates + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2018-01-03"))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2018-01-02"), + Date.valueOf("2018-01-03"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-03")), + Literal(Date.valueOf("2018-01-01"))), + Seq( + Date.valueOf("2018-01-03"), + Date.valueOf("2018-01-02"), + Date.valueOf("2018-01-01"))) + } + test("Reverse") { // Primitive-type elements val ai0 = Literal.create(Seq(2, 1, 4, 3), ArrayType(IntegerType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ef99ce3ad69d9..0b4f526799578 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3485,6 +3485,27 @@ object functions { */ def flatten(e: Column): Column = withExpr { Flatten(e.expr) } + /** + * Generate a sequence of integers from start to stop, incrementing by step. + * + * @group collection_funcs + * @since 2.4.0 + */ + def sequence(start: Column, stop: Column, step: Column): Column = withExpr { + new Sequence(start.expr, stop.expr, step.expr) + } + + /** + * Generate a sequence of integers from start to stop, + * incrementing by 1 if start is less than or equal to stop, otherwise -1. + * + * @group collection_funcs + * @since 2.4.0 + */ + def sequence(start: Column, stop: Column): Column = withExpr { + new Sequence(start.expr, stop.expr) + } + /** * Creates an array containing the left argument repeated the number of times given by the * right argument. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index b109898b5bfb3..4c28e2f1cd909 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -18,12 +18,15 @@ package org.apache.spark.sql import java.nio.charset.StandardCharsets +import java.sql.{Date, Timestamp} +import java.util.TimeZone import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.DateTimeTestUtils import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -862,6 +865,59 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df.selectExpr("array_max(a)"), answer) } + test("sequence") { + checkAnswer(Seq((-2, 2)).toDF().select(sequence('_1, '_2)), Seq(Row(Array(-2, -1, 0, 1, 2)))) + checkAnswer(Seq((7, 2, -2)).toDF().select(sequence('_1, '_2, '_3)), Seq(Row(Array(7, 5, 3)))) + + checkAnswer( + spark.sql("select sequence(" + + " cast('2018-01-01 00:00:00' as timestamp)" + + ", cast('2018-01-02 00:00:00' as timestamp)" + + ", interval 12 hours)"), + Seq(Row(Array( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))))) + + DateTimeTestUtils.withDefaultTimeZone(TimeZone.getTimeZone("UTC")) { + checkAnswer( + spark.sql("select sequence(" + + " cast('2018-01-01' as date)" + + ", cast('2018-03-01' as date)" + + ", interval 1 month)"), + Seq(Row(Array( + Date.valueOf("2018-01-01"), + Date.valueOf("2018-02-01"), + Date.valueOf("2018-03-01"))))) + } + + // test type coercion + checkAnswer( + Seq((1.toByte, 3L, 1)).toDF().select(sequence('_1, '_2, '_3)), + Seq(Row(Array(1L, 2L, 3L)))) + + checkAnswer( + spark.sql("select sequence(" + + " cast('2018-01-01' as date)" + + ", cast('2018-01-02 00:00:00' as timestamp)" + + ", interval 12 hours)"), + Seq(Row(Array( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))))) + + // test invalid data types + intercept[AnalysisException] { + Seq((true, false)).toDF().selectExpr("sequence(_1, _2)") + } + intercept[AnalysisException] { + Seq((true, false, 42)).toDF().selectExpr("sequence(_1, _2, _3)") + } + intercept[AnalysisException] { + Seq((1, 2, 0.5)).toDF().selectExpr("sequence(_1, _2, _3)") + } + } + test("reverse function") { val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on From 9a76f23c6a1756053c30f58baea2966d1b023981 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 27 Jun 2018 11:52:48 +0800 Subject: [PATCH 1023/2461] [SPARK-23927][SQL][FOLLOW-UP] Fix a build failure. ## What changes were proposed in this pull request? This pr is a follow-up pr of #21155. The #21155 removed unnecessary import at that time, but the import became necessary in another pr. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #21646 from ueshin/issues/SPARK-23927/fup1. --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 0395e1ef9a7ad..8b278f067749e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods From a1a64e3583cfa451b4d0d2361c1da2972a5e4444 Mon Sep 17 00:00:00 2001 From: Yuexin Zhang Date: Wed, 27 Jun 2018 16:05:36 +0800 Subject: [PATCH 1024/2461] [SPARK-21335][DOC] doc changes for disallowed un-aliased subquery use case ## What changes were proposed in this pull request? Document a change for un-aliased subquery use case, to address the last question in PR #18559: https://github.com/apache/spark/pull/18559#issuecomment-316884858 (Please fill in changes proposed in this fix) ## How was this patch tested? it does not affect tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Yuexin Zhang Closes #21647 from cnZach/doc_change_for_SPARK-20690_SPARK-21335. --- docs/sql-programming-guide.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 7c4ef41cc8907..cd7329b621122 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -2017,6 +2017,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible. - In PySpark, `df.replace` does not allow to omit `value` when `to_replace` is not a dictionary. Previously, `value` could be omitted in the other cases and had `None` by default, which is counterintuitive and error-prone. + - Un-aliased subquery's semantic has not been well defined with confusing behaviors. Since Spark 2.3, we invalidate such confusing cases, for example: `SELECT v.i from (SELECT i FROM v)`, Spark will throw an analysis exception in this case because users should not be able to use the qualifier inside a subquery. See [SPARK-20690](https://issues.apache.org/jira/browse/SPARK-20690) and [SPARK-21335](https://issues.apache.org/jira/browse/SPARK-21335) for more details. ## Upgrading From Spark SQL 2.1 to 2.2 From 6a0b77a55d53e74ac0a0892556c3a7a933474948 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Wed, 27 Jun 2018 10:43:06 -0700 Subject: [PATCH 1025/2461] [SPARK-24215][PYSPARK][FOLLOW UP] Implement eager evaluation for DataFrame APIs in PySpark ## What changes were proposed in this pull request? Address comments in #21370 and add more test. ## How was this patch tested? Enhance test in pyspark/sql/test.py and DataFrameSuite Author: Yuanjian Li Closes #21553 from xuanyuanking/SPARK-24215-follow. --- docs/configuration.md | 27 --------- python/pyspark/sql/dataframe.py | 3 +- python/pyspark/sql/tests.py | 46 ++++++++++++++- .../apache/spark/sql/internal/SQLConf.scala | 23 ++++++++ .../scala/org/apache/spark/sql/Dataset.scala | 11 ++-- .../org/apache/spark/sql/DataFrameSuite.scala | 59 +++++++++++++++++++ 6 files changed, 131 insertions(+), 38 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 6aa7878fe614d..0c7c4472be643 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -456,33 +456,6 @@ Apart from these, the following properties are also available, and may be useful from JVM to Python worker for every task. - - spark.sql.repl.eagerEval.enabled - false - - Enable eager evaluation or not. If true and the REPL you are using supports eager evaluation, - Dataset will be ran automatically. The HTML table which generated by _repl_html_ - called by notebooks like Jupyter will feedback the queries user have defined. For plain Python - REPL, the output will be shown like dataframe.show() - (see SPARK-24215 for more details). - - - - spark.sql.repl.eagerEval.maxNumRows - 20 - - Default number of rows in eager evaluation output HTML table generated by _repr_html_ or plain text, - this only take effect when spark.sql.repl.eagerEval.enabled is set to true. - - - - spark.sql.repl.eagerEval.truncate - 20 - - Default number of truncate in eager evaluation output HTML table generated by _repr_html_ or - plain text, this only take effect when spark.sql.repl.eagerEval.enabled set to true. - - spark.files diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1e6a1acebb5ca..cb3fe448b6fc7 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -393,9 +393,8 @@ def _repr_html_(self): self._support_repr_html = True if self._eager_eval: max_num_rows = max(self._max_num_rows, 0) - vertical = False sock_info = self._jdf.getRowsToPython( - max_num_rows, self._truncate, vertical) + max_num_rows, self._truncate) rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) head = rows[0] row_data = rows[1:] diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 35a0636e5cfc0..8d738069adb3d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3351,11 +3351,41 @@ def test_checking_csv_header(self): finally: shutil.rmtree(path) - def test_repr_html(self): + def test_repr_behaviors(self): import re pattern = re.compile(r'^ *\|', re.MULTILINE) df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value")) - self.assertEquals(None, df._repr_html_()) + + # test when eager evaluation is enabled and _repr_html_ will not be called + with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): + expected1 = """+-----+-----+ + || key|value| + |+-----+-----+ + || 1| 1| + ||22222|22222| + |+-----+-----+ + |""" + self.assertEquals(re.sub(pattern, '', expected1), df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): + expected2 = """+---+-----+ + ||key|value| + |+---+-----+ + || 1| 1| + ||222| 222| + |+---+-----+ + |""" + self.assertEquals(re.sub(pattern, '', expected2), df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): + expected3 = """+---+-----+ + ||key|value| + |+---+-----+ + || 1| 1| + |+---+-----+ + |only showing top 1 row + |""" + self.assertEquals(re.sub(pattern, '', expected3), df.__repr__()) + + # test when eager evaluation is enabled and _repr_html_ will be called with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): expected1 = """ | @@ -3381,6 +3411,18 @@ def test_repr_html(self): |""" self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_()) + # test when eager evaluation is disabled and _repr_html_ will be called + with self.sql_conf({"spark.sql.repl.eagerEval.enabled": False}): + expected = "DataFrame[key: bigint, value: string]" + self.assertEquals(None, df._repr_html_()) + self.assertEquals(expected, df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): + self.assertEquals(None, df._repr_html_()) + self.assertEquals(expected, df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): + self.assertEquals(None, df._repr_html_()) + self.assertEquals(expected, df.__repr__()) + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 239c8266351ae..e1752ff997b69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1330,6 +1330,29 @@ object SQLConf { "The size function returns null for null input if the flag is disabled.") .booleanConf .createWithDefault(true) + + val REPL_EAGER_EVAL_ENABLED = buildConf("spark.sql.repl.eagerEval.enabled") + .doc("Enables eager evaluation or not. When true, the top K rows of Dataset will be " + + "displayed if and only if the REPL supports the eager evaluation. Currently, the " + + "eager evaluation is only supported in PySpark. For the notebooks like Jupyter, " + + "the HTML table (generated by _repr_html_) will be returned. For plain Python REPL, " + + "the returned outputs are formatted like dataframe.show().") + .booleanConf + .createWithDefault(false) + + val REPL_EAGER_EVAL_MAX_NUM_ROWS = buildConf("spark.sql.repl.eagerEval.maxNumRows") + .doc("The max number of rows that are returned by eager evaluation. This only takes " + + "effect when spark.sql.repl.eagerEval.enabled is set to true. The valid range of this " + + "config is from 0 to (Int.MaxValue - 1), so the invalid config like negative and " + + "greater than (Int.MaxValue - 1) will be normalized to 0 and (Int.MaxValue - 1).") + .intConf + .createWithDefault(20) + + val REPL_EAGER_EVAL_TRUNCATE = buildConf("spark.sql.repl.eagerEval.truncate") + .doc("The max number of characters for each cell that is returned by eager evaluation. " + + "This only takes effect when spark.sql.repl.eagerEval.enabled is set to true.") + .intConf + .createWithDefault(20) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 57f1e173211af..2ec236fc75efc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -236,12 +236,10 @@ class Dataset[T] private[sql]( * @param numRows Number of rows to return * @param truncate If set to more than 0, truncates strings to `truncate` characters and * all cells will be aligned right. - * @param vertical If set to true, the rows to return do not need truncate. */ private[sql] def getRows( numRows: Int, - truncate: Int, - vertical: Boolean): Seq[Seq[String]] = { + truncate: Int): Seq[Seq[String]] = { val newDf = toDF() val castCols = newDf.logicalPlan.output.map { col => // Since binary types in top-level schema fields have a specific format to print, @@ -289,7 +287,7 @@ class Dataset[T] private[sql]( vertical: Boolean = false): String = { val numRows = _numRows.max(0).min(Int.MaxValue - 1) // Get rows represented by Seq[Seq[String]], we may get one more line if it has more data. - val tmpRows = getRows(numRows, truncate, vertical) + val tmpRows = getRows(numRows, truncate) val hasMoreData = tmpRows.length - 1 > numRows val rows = tmpRows.take(numRows + 1) @@ -3226,11 +3224,10 @@ class Dataset[T] private[sql]( private[sql] def getRowsToPython( _numRows: Int, - truncate: Int, - vertical: Boolean): Array[Any] = { + truncate: Int): Array[Any] = { EvaluatePython.registerPicklers() val numRows = _numRows.max(0).min(Int.MaxValue - 1) - val rows = getRows(numRows, truncate, vertical).map(_.toArray).toArray + val rows = getRows(numRows, truncate).map(_.toArray).toArray val toJava: (Any) => Any = EvaluatePython.toJava(_, ArrayType(ArrayType(StringType))) val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler( rows.iterator.map(toJava)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1cc8cb3874c9b..ea00d22bff001 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1044,6 +1044,65 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData.select($"*").show(1000) } + test("getRows: truncate = [0, 20]") { + val longString = Array.fill(21)("1").mkString + val df = sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = Seq( + Seq("value"), + Seq("1"), + Seq("111111111111111111111")) + assert(df.getRows(10, 0) === expectedAnswerForFalse) + val expectedAnswerForTrue = Seq( + Seq("value"), + Seq("1"), + Seq("11111111111111111...")) + assert(df.getRows(10, 20) === expectedAnswerForTrue) + } + + test("getRows: truncate = [3, 17]") { + val longString = Array.fill(21)("1").mkString + val df = sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = Seq( + Seq("value"), + Seq("1"), + Seq("111")) + assert(df.getRows(10, 3) === expectedAnswerForFalse) + val expectedAnswerForTrue = Seq( + Seq("value"), + Seq("1"), + Seq("11111111111111...")) + assert(df.getRows(10, 17) === expectedAnswerForTrue) + } + + test("getRows: numRows = 0") { + val expectedAnswer = Seq(Seq("key", "value"), Seq("1", "1")) + assert(testData.select($"*").getRows(0, 20) === expectedAnswer) + } + + test("getRows: array") { + val df = Seq( + (Array(1, 2, 3), Array(1, 2, 3)), + (Array(2, 3, 4), Array(2, 3, 4)) + ).toDF() + val expectedAnswer = Seq( + Seq("_1", "_2"), + Seq("[1, 2, 3]", "[1, 2, 3]"), + Seq("[2, 3, 4]", "[2, 3, 4]")) + assert(df.getRows(10, 20) === expectedAnswer) + } + + test("getRows: binary") { + val df = Seq( + ("12".getBytes(StandardCharsets.UTF_8), "ABC.".getBytes(StandardCharsets.UTF_8)), + ("34".getBytes(StandardCharsets.UTF_8), "12346".getBytes(StandardCharsets.UTF_8)) + ).toDF() + val expectedAnswer = Seq( + Seq("_1", "_2"), + Seq("[31 32]", "[41 42 43 2E]"), + Seq("[33 34]", "[31 32 33 34 36]")) + assert(df.getRows(10, 20) === expectedAnswer) + } + test("showString: truncate = [0, 20]") { val longString = Array.fill(21)("1").mkString val df = sparkContext.parallelize(Seq("1", longString)).toDF() From 78ecb6d457970b136a2e0e0e27d170c84ea28eac Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 27 Jun 2018 10:57:29 -0700 Subject: [PATCH 1026/2461] [SPARK-24446][YARN] Properly quote library path for YARN. Because the way YARN executes commands via bash -c, everything needs to be quoted so that the whole command is fully contained inside a bash string and is interpreted correctly when the string is read by bash. This is a bit different than the quoting done when executing things as if typing in a bash shell. Tweaked unit tests to exercise the bad behavior, which would cause existing tests to time out without the fix. Also tested on a real cluster, verifying the shell script created by YARN to run the container. Author: Marcelo Vanzin Closes #21476 from vanzin/SPARK-24446. --- .../org/apache/spark/deploy/yarn/Client.scala | 22 +++++++++++++++++-- .../spark/deploy/yarn/ExecutorRunnable.scala | 11 +++++----- .../deploy/yarn/BaseYarnClusterSuite.scala | 9 ++++++++ 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 7225ff03dc34e..793d012218490 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -899,7 +899,8 @@ private[spark] class Client( val libraryPaths = Seq(sparkConf.get(DRIVER_LIBRARY_PATH), sys.props.get("spark.driver.libraryPath")).flatten if (libraryPaths.nonEmpty) { - prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(libraryPaths))) + prefixEnv = Some(createLibraryPathPrefix(libraryPaths.mkString(File.pathSeparator), + sparkConf)) } if (sparkConf.get(AM_JAVA_OPTIONS).isDefined) { logWarning(s"${AM_JAVA_OPTIONS.key} will not take effect in cluster mode") @@ -921,7 +922,7 @@ private[spark] class Client( .map(YarnSparkHadoopUtil.escapeForShell) } sparkConf.get(AM_LIBRARY_PATH).foreach { paths => - prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths)))) + prefixEnv = Some(createLibraryPathPrefix(paths, sparkConf)) } } @@ -1485,6 +1486,23 @@ private object Client extends Logging { YarnAppReport(report.getYarnApplicationState(), report.getFinalApplicationStatus(), diagsOpt) } + /** + * Create a properly quoted and escaped library path string to be added as a prefix to the command + * executed by YARN. This is different from normal quoting / escaping due to YARN executing the + * command through "bash -c". + */ + def createLibraryPathPrefix(libpath: String, conf: SparkConf): String = { + val cmdPrefix = if (Utils.isWindows) { + Utils.libraryPathEnvPrefix(Seq(libpath)) + } else { + val envName = Utils.libraryPathEnvName + // For quotes, escape both the quote and the escape character when encoding in the command + // string. + val quoted = libpath.replace("\"", "\\\\\\\"") + envName + "=\\\"" + quoted + File.pathSeparator + "$" + envName + "\\\"" + } + getClusterPath(conf, cmdPrefix) + } } private[spark] class YarnClusterApplication extends SparkApplication { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index a2a18cdff65af..49a0b93aa5c40 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -131,10 +131,6 @@ private[yarn] class ExecutorRunnable( // Extra options for the JVM val javaOpts = ListBuffer[String]() - // Set the environment variable through a command prefix - // to append to the existing value of the variable - var prefixEnv: Option[String] = None - // Set the JVM memory val executorMemoryString = executorMemory + "m" javaOpts += "-Xmx" + executorMemoryString @@ -144,8 +140,11 @@ private[yarn] class ExecutorRunnable( val subsOpt = Utils.substituteAppNExecIds(opts, appId, executorId) javaOpts ++= Utils.splitCommandString(subsOpt).map(YarnSparkHadoopUtil.escapeForShell) } - sparkConf.get(EXECUTOR_LIBRARY_PATH).foreach { p => - prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p)))) + + // Set the library path through a command prefix to append to the existing value of the + // env variable. + val prefixEnv = sparkConf.get(EXECUTOR_LIBRARY_PATH).map { libPath => + Client.createLibraryPathPrefix(libPath, sparkConf) } javaOpts += "-Djava.io.tmpdir=" + diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index ac67f2196e0a0..b0abcc9149d08 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -36,6 +36,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.launcher._ import org.apache.spark.util.Utils @@ -216,6 +217,14 @@ abstract class BaseYarnClusterSuite props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") + // SPARK-24446: make sure special characters in the library path do not break containers. + if (!Utils.isWindows) { + val libPath = """/tmp/does not exist:$PWD/tmp:/tmp/quote":/tmp/ampersand&""" + props.setProperty(AM_LIBRARY_PATH.key, libPath) + props.setProperty(DRIVER_LIBRARY_PATH.key, libPath) + props.setProperty(EXECUTOR_LIBRARY_PATH.key, libPath) + } + yarnCluster.getConfig().asScala.foreach { e => props.setProperty("spark.hadoop." + e.getKey(), e.getValue()) } From c04cb2d1b72b1edaddf684755f5a9d6aaf00e03b Mon Sep 17 00:00:00 2001 From: debugger87 Date: Wed, 27 Jun 2018 11:34:28 -0700 Subject: [PATCH 1027/2461] [SPARK-21687][SQL] Spark SQL should set createTime for Hive partition ## What changes were proposed in this pull request? Set createTime for every hive partition created in Spark SQL, which could be used to manage data lifecycle in Hive warehouse. We found that almost every partition modified by spark sql has not been set createTime. ``` mysql> select * from partitions where create_time=0 limit 1\G; *************************** 1. row *************************** PART_ID: 1028584 CREATE_TIME: 0 LAST_ACCESS_TIME: 1502203611 PART_NAME: date=20170130 SD_ID: 1543605 TBL_ID: 211605 LINK_TARGET_ID: NULL 1 row in set (0.27 sec) ``` ## How was this patch tested? N/A Author: debugger87 Author: Chaozhong Yang Closes #18900 from debugger87/fix/set-create-time-for-hive-partition. --- .../spark/sql/catalyst/catalog/interface.scala | 6 ++++++ .../sql/catalyst/catalog/SessionCatalogSuite.scala | 6 ++++-- .../results/describe-part-after-analyze.sql.out | 14 ++++++++++++++ .../resources/sql-tests/results/describe.sql.out | 4 ++++ .../sql-tests/results/show-tables.sql.out | 2 ++ .../spark/sql/hive/client/HiveClientImpl.scala | 4 ++++ 6 files changed, 34 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index f3e67dc4e975c..c6105c5526049 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -93,12 +93,16 @@ object CatalogStorageFormat { * @param spec partition spec values indexed by column name * @param storage storage format of the partition * @param parameters some parameters for the partition + * @param createTime creation time of the partition, in milliseconds + * @param lastAccessTime last access time, in milliseconds * @param stats optional statistics (number of rows, total size, etc.) */ case class CatalogTablePartition( spec: CatalogTypes.TablePartitionSpec, storage: CatalogStorageFormat, parameters: Map[String, String] = Map.empty, + createTime: Long = System.currentTimeMillis, + lastAccessTime: Long = -1, stats: Option[CatalogStatistics] = None) { def toLinkedHashMap: mutable.LinkedHashMap[String, String] = { @@ -109,6 +113,8 @@ case class CatalogTablePartition( if (parameters.nonEmpty) { map.put("Partition Parameters", s"{${parameters.map(p => p._1 + "=" + p._2).mkString(", ")}}") } + map.put("Created Time", new Date(createTime).toString) + map.put("Last Access", new Date(lastAccessTime).toString) stats.foreach(s => map.put("Partition Statistics", s.simpleString)) map } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 6abab0073cca3..6a7375ee186fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -1114,11 +1114,13 @@ abstract class SessionCatalogSuite extends AnalysisTest { // And for hive serde table, hive metastore will set some values(e.g.transient_lastDdlTime) // in table's parameters and storage's properties, here we also ignore them. val actualPartsNormalize = actualParts.map(p => - p.copy(parameters = Map.empty, storage = p.storage.copy( + p.copy(parameters = Map.empty, createTime = -1, lastAccessTime = -1, + storage = p.storage.copy( properties = Map.empty, locationUri = None, serde = None))).toSet val expectedPartsNormalize = expectedParts.map(p => - p.copy(parameters = Map.empty, storage = p.storage.copy( + p.copy(parameters = Map.empty, createTime = -1, lastAccessTime = -1, + storage = p.storage.copy( properties = Map.empty, locationUri = None, serde = None))).toSet actualPartsNormalize == expectedPartsNormalize diff --git a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out index 58ed201e2a60f..8ba69c698b551 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out @@ -57,6 +57,8 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 +Created Time [not included in comparison] +Last Access [not included in comparison] # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -89,6 +91,8 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 +Created Time [not included in comparison] +Last Access [not included in comparison] Partition Statistics 1121 bytes, 3 rows # Storage Information @@ -122,6 +126,8 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 +Created Time [not included in comparison] +Last Access [not included in comparison] Partition Statistics 1121 bytes, 3 rows # Storage Information @@ -147,6 +153,8 @@ Database default Table t Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 +Created Time [not included in comparison] +Last Access [not included in comparison] Partition Statistics 1098 bytes, 4 rows # Storage Information @@ -180,6 +188,8 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 +Created Time [not included in comparison] +Last Access [not included in comparison] Partition Statistics 1121 bytes, 3 rows # Storage Information @@ -205,6 +215,8 @@ Database default Table t Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 +Created Time [not included in comparison] +Last Access [not included in comparison] Partition Statistics 1098 bytes, 4 rows # Storage Information @@ -230,6 +242,8 @@ Database default Table t Partition Values [ds=2017-09-01, hr=5] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-09-01/hr=5 +Created Time [not included in comparison] +Last Access [not included in comparison] Partition Statistics 1144 bytes, 2 rows # Storage Information diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index 8c908b7625056..79390cb424444 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -282,6 +282,8 @@ Table t Partition Values [c=Us, d=1] Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1 Storage Properties [a=1, b=2] +Created Time [not included in comparison] +Last Access [not included in comparison] # Storage Information Num Buckets 2 @@ -311,6 +313,8 @@ Table t Partition Values [c=Us, d=1] Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1 Storage Properties [a=1, b=2] +Created Time [not included in comparison] +Last Access [not included in comparison] # Storage Information Num Buckets 2 diff --git a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out index 975bb06124744..abeb7e18f031e 100644 --- a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out @@ -178,6 +178,8 @@ struct -- !query 14 output showdb show_t1 false Partition Values: [c=Us, d=1] Location [not included in comparison]sql/core/spark-warehouse/showdb.db/show_t1/c=Us/d=1 +Created Time [not included in comparison] +Last Access [not included in comparison] -- !query 15 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index da9fe2d3088b4..1df46d7431a21 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -995,6 +995,8 @@ private[hive] object HiveClientImpl { tpart.setTableName(ht.getTableName) tpart.setValues(partValues.asJava) tpart.setSd(storageDesc) + tpart.setCreateTime((p.createTime / 1000).toInt) + tpart.setLastAccessTime((p.lastAccessTime / 1000).toInt) tpart.setParameters(mutable.Map(p.parameters.toSeq: _*).asJava) new HivePartition(ht, tpart) } @@ -1019,6 +1021,8 @@ private[hive] object HiveClientImpl { compressed = apiPartition.getSd.isCompressed, properties = Option(apiPartition.getSd.getSerdeInfo.getParameters) .map(_.asScala.toMap).orNull), + createTime = apiPartition.getCreateTime.toLong * 1000, + lastAccessTime = apiPartition.getLastAccessTime.toLong * 1000, parameters = properties, stats = readHiveStats(properties)) } From 776befbfd5b3c317a713d4fa3882cda6264db9ba Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 27 Jun 2018 14:26:08 -0700 Subject: [PATCH 1028/2461] [SPARK-24660][SHS] Show correct error pages when downloading logs ## What changes were proposed in this pull request? SHS is showing bad errors when trying to download logs is not successful. This may happen because the requested application doesn't exist or the user doesn't have permissions for it, for instance. The PR fixes the response when errors occur, so that they are displayed properly. ## How was this patch tested? manual tests **Before the patch:** 1. Unauthorized user ![screen shot 2018-06-26 at 3 53 33 pm](https://user-images.githubusercontent.com/8821783/41918118-f8b37e70-795b-11e8-91e8-d0250239f09d.png) 2. Non-existing application ![screen shot 2018-06-26 at 3 25 19 pm](https://user-images.githubusercontent.com/8821783/41918082-e3034c72-795b-11e8-970e-cee4a1eae77f.png) **After the patch** 1. Unauthorized user ![screen shot 2018-06-26 at 3 41 29 pm](https://user-images.githubusercontent.com/8821783/41918155-0d950476-795c-11e8-8d26-7b7ce73e6fe1.png) 2. Non-existing application ![screen shot 2018-06-26 at 3 40 37 pm](https://user-images.githubusercontent.com/8821783/41918175-1a14bb88-795c-11e8-91ab-eadf29190a02.png) Author: Marco Gaido Closes #21644 from mgaido91/SPARK-24660. --- .../spark/status/api/v1/ApiRootResource.scala | 30 ++++--------------- .../status/api/v1/JacksonMessageWriter.scala | 5 +--- .../api/v1/OneApplicationResource.scala | 7 ++--- .../scala/org/apache/spark/ui/UIUtils.scala | 5 ++++ 4 files changed, 13 insertions(+), 34 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index d121068718b8a..84c2ad48f1f27 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -28,7 +28,7 @@ import org.glassfish.jersey.server.ServerProperties import org.glassfish.jersey.servlet.ServletContainer import org.apache.spark.SecurityManager -import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.{SparkUI, UIUtils} /** * Main entry point for serving spark application metrics as json, using JAX-RS. @@ -148,38 +148,18 @@ private[v1] trait BaseAppResource extends ApiRequestContext { } private[v1] class ForbiddenException(msg: String) extends WebApplicationException( - Response.status(Response.Status.FORBIDDEN).entity(msg).build()) + UIUtils.buildErrorResponse(Response.Status.FORBIDDEN, msg)) private[v1] class NotFoundException(msg: String) extends WebApplicationException( - new NoSuchElementException(msg), - Response - .status(Response.Status.NOT_FOUND) - .entity(ErrorWrapper(msg)) - .build() -) + UIUtils.buildErrorResponse(Response.Status.NOT_FOUND, msg)) private[v1] class ServiceUnavailable(msg: String) extends WebApplicationException( - new ServiceUnavailableException(msg), - Response - .status(Response.Status.SERVICE_UNAVAILABLE) - .entity(ErrorWrapper(msg)) - .build() -) + UIUtils.buildErrorResponse(Response.Status.SERVICE_UNAVAILABLE, msg)) private[v1] class BadParameterException(msg: String) extends WebApplicationException( - new IllegalArgumentException(msg), - Response - .status(Response.Status.BAD_REQUEST) - .entity(ErrorWrapper(msg)) - .build() -) { + UIUtils.buildErrorResponse(Response.Status.BAD_REQUEST, msg)) { def this(param: String, exp: String, actual: String) = { this(raw"""Bad value for parameter "$param". Expected a $exp, got "$actual"""") } } -/** - * Signal to JacksonMessageWriter to not convert the message into json (which would result in an - * extra set of quotes). - */ -private[v1] case class ErrorWrapper(s: String) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala index 76af33c1a18db..4560d300cb0c8 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala @@ -68,10 +68,7 @@ private[v1] class JacksonMessageWriter extends MessageBodyWriter[Object]{ mediaType: MediaType, multivaluedMap: MultivaluedMap[String, AnyRef], outputStream: OutputStream): Unit = { - t match { - case ErrorWrapper(err) => outputStream.write(err.getBytes(StandardCharsets.UTF_8)) - case _ => mapper.writeValue(outputStream, t) - } + mapper.writeValue(outputStream, t) } override def getSize( diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala index 974697890dd03..32100c5704538 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala @@ -140,11 +140,8 @@ private[v1] class AbstractApplicationResource extends BaseAppResource { .header("Content-Type", MediaType.APPLICATION_OCTET_STREAM) .build() } catch { - case NonFatal(e) => - Response.serverError() - .entity(s"Event logs are not available for app: $appId.") - .status(Response.Status.SERVICE_UNAVAILABLE) - .build() + case NonFatal(_) => + throw new ServiceUnavailable(s"Event logs are not available for app: $appId.") } } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 5d015b0531ef6..732b7528f499e 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -21,6 +21,7 @@ import java.net.URLDecoder import java.text.SimpleDateFormat import java.util.{Date, Locale, TimeZone} import javax.servlet.http.HttpServletRequest +import javax.ws.rs.core.{MediaType, Response} import scala.util.control.NonFatal import scala.xml._ @@ -566,4 +567,8 @@ private[spark] object UIUtils extends Logging { NEWLINE_AND_SINGLE_QUOTE_REGEX.replaceAllIn(requestParameter, "")) } } + + def buildErrorResponse(status: Response.Status, msg: String): Response = { + Response.status(status).entity(msg).`type`(MediaType.TEXT_PLAIN).build() + } } From 221d03acca19bdf7a2624a29c180c99f098205d8 Mon Sep 17 00:00:00 2001 From: Sanket Chintapalli Date: Wed, 27 Jun 2018 14:37:19 -0700 Subject: [PATCH 1029/2461] [SPARK-24533] Typesafe rebranded to lightbend. Changing the build downloads path Typesafe has rebranded to lightbend. Just changing the downloads path to avoid redirection Tested by running build/mvn -DskipTests package Author: Sanket Chintapalli Closes #21636 from redsanket/SPARK-24533. --- build/mvn | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/build/mvn b/build/mvn index 1405983982d4c..ae4276dbc7e32 100755 --- a/build/mvn +++ b/build/mvn @@ -93,7 +93,7 @@ install_mvn() { install_zinc() { local zinc_path="zinc-0.3.15/bin/zinc" [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 - local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.typesafe.com} + local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.lightbend.com} install_app \ "${TYPESAFE_MIRROR}/zinc/0.3.15" \ @@ -109,7 +109,7 @@ install_scala() { # determine the Scala version used in Spark local scala_version=`grep "scala.version" "${_DIR}/../pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'` local scala_bin="${_DIR}/scala-${scala_version}/bin/scala" - local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.typesafe.com} + local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.lightbend.com} install_app \ "${TYPESAFE_MIRROR}/scala/${scala_version}" \ From 893ea224cc738766be207c87f4b913fe8fea4c94 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 27 Jun 2018 15:25:51 -0700 Subject: [PATCH 1030/2461] [SPARK-24204][SQL] Verify a schema in Json/Orc/ParquetFileFormat ## What changes were proposed in this pull request? This pr added code to verify a schema in Json/Orc/ParquetFileFormat along with CSVFileFormat. ## How was this patch tested? Added verification tests in `FileBasedDataSourceSuite` and `HiveOrcSourceSuite`. Author: Takeshi Yamamuro Closes #21389 from maropu/SPARK-24204. --- .../datasources/DataSourceUtils.scala | 106 +++++++++ .../datasources/csv/CSVFileFormat.scala | 4 +- .../execution/datasources/csv/CSVUtils.scala | 19 -- .../datasources/json/JsonFileFormat.scala | 4 + .../datasources/orc/OrcFileFormat.scala | 4 + .../parquet/ParquetFileFormat.scala | 3 + .../spark/sql/FileBasedDataSourceSuite.scala | 213 +++++++++++++++++- .../execution/datasources/csv/CSVSuite.scala | 33 --- .../spark/sql/hive/orc/OrcFileFormat.scala | 4 + .../sql/hive/orc/HiveOrcSourceSuite.scala | 49 +++- 10 files changed, 383 insertions(+), 56 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala new file mode 100644 index 0000000000000..c5347218c4b40 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat +import org.apache.spark.sql.execution.datasources.json.JsonFileFormat +import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.types._ + + +object DataSourceUtils { + + /** + * Verify if the schema is supported in datasource in write path. + */ + def verifyWriteSchema(format: FileFormat, schema: StructType): Unit = { + verifySchema(format, schema, isReadPath = false) + } + + /** + * Verify if the schema is supported in datasource in read path. + */ + def verifyReadSchema(format: FileFormat, schema: StructType): Unit = { + verifySchema(format, schema, isReadPath = true) + } + + /** + * Verify if the schema is supported in datasource. This verification should be done + * in a driver side, e.g., `prepareWrite`, `buildReader`, and `buildReaderWithPartitionValues` + * in `FileFormat`. + * + * Unsupported data types of csv, json, orc, and parquet are as follows; + * csv -> R/W: Interval, Null, Array, Map, Struct + * json -> W: Interval + * orc -> W: Interval, Null + * parquet -> R/W: Interval, Null + */ + private def verifySchema(format: FileFormat, schema: StructType, isReadPath: Boolean): Unit = { + def throwUnsupportedException(dataType: DataType): Unit = { + throw new UnsupportedOperationException( + s"$format data source does not support ${dataType.simpleString} data type.") + } + + def verifyType(dataType: DataType): Unit = dataType match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | + StringType | BinaryType | DateType | TimestampType | _: DecimalType => + + // All the unsupported types for CSV + case _: NullType | _: CalendarIntervalType | _: StructType | _: ArrayType | _: MapType + if format.isInstanceOf[CSVFileFormat] => + throwUnsupportedException(dataType) + + case st: StructType => st.foreach { f => verifyType(f.dataType) } + + case ArrayType(elementType, _) => verifyType(elementType) + + case MapType(keyType, valueType, _) => + verifyType(keyType) + verifyType(valueType) + + case udt: UserDefinedType[_] => verifyType(udt.sqlType) + + // Interval type not supported in all the write path + case _: CalendarIntervalType if !isReadPath => + throwUnsupportedException(dataType) + + // JSON and ORC don't support an Interval type, but we pass it in read pass + // for back-compatibility. + case _: CalendarIntervalType if format.isInstanceOf[JsonFileFormat] || + format.isInstanceOf[OrcFileFormat] => + + // Interval type not supported in the other read path + case _: CalendarIntervalType => + throwUnsupportedException(dataType) + + // For JSON & ORC backward-compatibility + case _: NullType if format.isInstanceOf[JsonFileFormat] || + (isReadPath && format.isInstanceOf[OrcFileFormat]) => + + // Null type not supported in the other path + case _: NullType => + throwUnsupportedException(dataType) + + // We keep this default case for safeguards + case _ => throwUnsupportedException(dataType) + } + + schema.foreach(field => verifyType(field.dataType)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index b90275de9f40a..fa366ccce6b61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -66,7 +66,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - CSVUtils.verifySchema(dataSchema) + DataSourceUtils.verifyWriteSchema(this, dataSchema) val conf = job.getConfiguration val csvOptions = new CSVOptions( options, @@ -98,7 +98,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - CSVUtils.verifySchema(dataSchema) + DataSourceUtils.verifyReadSchema(this, dataSchema) val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 1012e774118e2..7ce65fa89b02d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -117,25 +117,6 @@ object CSVUtils { } } - /** - * Verify if the schema is supported in CSV datasource. - */ - def verifySchema(schema: StructType): Unit = { - def verifyType(dataType: DataType): Unit = dataType match { - case ByteType | ShortType | IntegerType | LongType | FloatType | - DoubleType | BooleanType | _: DecimalType | TimestampType | - DateType | StringType => - - case udt: UserDefinedType[_] => verifyType(udt.sqlType) - - case _ => - throw new UnsupportedOperationException( - s"CSV data source does not support ${dataType.simpleString} data type.") - } - - schema.foreach(field => verifyType(field.dataType)) - } - /** * Sample CSV dataset as configured by `samplingRatio`. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index e9a0b383b5f49..383bff1375a93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -65,6 +65,8 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { + DataSourceUtils.verifyWriteSchema(this, dataSchema) + val conf = job.getConfiguration val parsedOptions = new JSONOptions( options, @@ -96,6 +98,8 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { + DataSourceUtils.verifyReadSchema(this, dataSchema) + val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 1de2ca2914c44..df488a748e3e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -89,6 +89,8 @@ class OrcFileFormat job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { + DataSourceUtils.verifyWriteSchema(this, dataSchema) + val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) val conf = job.getConfiguration @@ -141,6 +143,8 @@ class OrcFileFormat filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + DataSourceUtils.verifyReadSchema(this, dataSchema) + if (sparkSession.sessionState.conf.orcFilterPushDown) { OrcFilters.createFilter(dataSchema, filters).foreach { f => OrcInputFormat.setSearchArgument(hadoopConf, f, dataSchema.fieldNames) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 60fc9ec7e1f82..9602a08911dea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -78,6 +78,7 @@ class ParquetFileFormat job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { + DataSourceUtils.verifyWriteSchema(this, dataSchema) val parquetOptions = new ParquetOptions(options, sparkSession.sessionState.conf) @@ -302,6 +303,8 @@ class ParquetFileFormat filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + DataSourceUtils.verifyReadSchema(this, dataSchema) + hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) hadoopConf.set( ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 06303099f5310..86f9647b4ac4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql -import java.io.FileNotFoundException +import java.io.{File, FileNotFoundException} +import java.util.Locale import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException +import org.apache.spark.sql.TestingUDT.{IntervalData, IntervalUDT, NullData, NullUDT} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -202,4 +204,213 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } } } + + // Unsupported data types of csv, json, orc, and parquet are as follows; + // csv -> R/W: Interval, Null, Array, Map, Struct + // json -> W: Interval + // orc -> W: Interval, Null + // parquet -> R/W: Interval, Null + test("SPARK-24204 error handling for unsupported Array/Map/Struct types - csv") { + withTempDir { dir => + val csvDir = new File(dir, "csv").getCanonicalPath + var msg = intercept[UnsupportedOperationException] { + Seq((1, "Tesla")).toDF("a", "b").selectExpr("struct(a, b)").write.csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support struct data type")) + + msg = intercept[UnsupportedOperationException] { + val schema = StructType.fromDDL("a struct") + spark.range(1).write.mode("overwrite").csv(csvDir) + spark.read.schema(schema).csv(csvDir).collect() + }.getMessage + assert(msg.contains("CSV data source does not support struct data type")) + + msg = intercept[UnsupportedOperationException] { + Seq((1, Map("Tesla" -> 3))).toDF("id", "cars").write.mode("overwrite").csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support map data type")) + + msg = intercept[UnsupportedOperationException] { + val schema = StructType.fromDDL("a map") + spark.range(1).write.mode("overwrite").csv(csvDir) + spark.read.schema(schema).csv(csvDir).collect() + }.getMessage + assert(msg.contains("CSV data source does not support map data type")) + + msg = intercept[UnsupportedOperationException] { + Seq((1, Array("Tesla", "Chevy", "Ford"))).toDF("id", "brands") + .write.mode("overwrite").csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support array data type")) + + msg = intercept[UnsupportedOperationException] { + val schema = StructType.fromDDL("a array") + spark.range(1).write.mode("overwrite").csv(csvDir) + spark.read.schema(schema).csv(csvDir).collect() + }.getMessage + assert(msg.contains("CSV data source does not support array data type")) + + msg = intercept[UnsupportedOperationException] { + Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors") + .write.mode("overwrite").csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support array data type")) + + msg = intercept[UnsupportedOperationException] { + val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil) + spark.range(1).write.mode("overwrite").csv(csvDir) + spark.read.schema(schema).csv(csvDir).collect() + }.getMessage + assert(msg.contains("CSV data source does not support array data type.")) + } + } + + test("SPARK-24204 error handling for unsupported Interval data types - csv, json, parquet, orc") { + withTempDir { dir => + val tempDir = new File(dir, "files").getCanonicalPath + + // write path + Seq("csv", "json", "parquet", "orc").foreach { format => + var msg = intercept[AnalysisException] { + sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.contains("Cannot save interval data type into external storage.")) + + msg = intercept[UnsupportedOperationException] { + spark.udf.register("testType", () => new IntervalData()) + sql("select testType()").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support calendarinterval data type.")) + } + + // read path + Seq("parquet", "csv").foreach { format => + var msg = intercept[UnsupportedOperationException] { + val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support calendarinterval data type.")) + + msg = intercept[UnsupportedOperationException] { + val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support calendarinterval data type.")) + } + + // We expect the types below should be passed for backward-compatibility + Seq("orc", "json").foreach { format => + // Interval type + var schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + + // UDT having interval data + schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + } + } + } + + test("SPARK-24204 error handling for unsupported Null data types - csv, parquet, orc") { + withTempDir { dir => + val tempDir = new File(dir, "files").getCanonicalPath + + Seq("orc").foreach { format => + // write path + var msg = intercept[UnsupportedOperationException] { + sql("select null").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + msg = intercept[UnsupportedOperationException] { + spark.udf.register("testType", () => new NullData()) + sql("select testType()").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + // read path + // We expect the types below should be passed for backward-compatibility + + // Null type + var schema = StructType(StructField("a", NullType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + + // UDT having null data + schema = StructType(StructField("a", new NullUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + } + + Seq("parquet", "csv").foreach { format => + // write path + var msg = intercept[UnsupportedOperationException] { + sql("select null").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + msg = intercept[UnsupportedOperationException] { + spark.udf.register("testType", () => new NullData()) + sql("select testType()").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + // read path + msg = intercept[UnsupportedOperationException] { + val schema = StructType(StructField("a", NullType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + msg = intercept[UnsupportedOperationException] { + val schema = StructType(StructField("a", new NullUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + } + } + } +} + +object TestingUDT { + + @SQLUserDefinedType(udt = classOf[IntervalUDT]) + class IntervalData extends Serializable + + class IntervalUDT extends UserDefinedType[IntervalData] { + + override def sqlType: DataType = CalendarIntervalType + override def serialize(obj: IntervalData): Any = + throw new NotImplementedError("Not implemented") + override def deserialize(datum: Any): IntervalData = + throw new NotImplementedError("Not implemented") + override def userClass: Class[IntervalData] = classOf[IntervalData] + } + + @SQLUserDefinedType(udt = classOf[NullUDT]) + private[sql] class NullData extends Serializable + + private[sql] class NullUDT extends UserDefinedType[NullData] { + + override def sqlType: DataType = NullType + override def serialize(obj: NullData): Any = throw new NotImplementedError("Not implemented") + override def deserialize(datum: Any): NullData = + throw new NotImplementedError("Not implemented") + override def userClass: Class[NullData] = classOf[NullData] + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index d2f166c7d1877..365239d040ef2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -740,39 +740,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te assert(numbers.count() == 8) } - test("error handling for unsupported data types.") { - withTempDir { dir => - val csvDir = new File(dir, "csv").getCanonicalPath - var msg = intercept[UnsupportedOperationException] { - Seq((1, "Tesla")).toDF("a", "b").selectExpr("struct(a, b)").write.csv(csvDir) - }.getMessage - assert(msg.contains("CSV data source does not support struct data type")) - - msg = intercept[UnsupportedOperationException] { - Seq((1, Map("Tesla" -> 3))).toDF("id", "cars").write.csv(csvDir) - }.getMessage - assert(msg.contains("CSV data source does not support map data type")) - - msg = intercept[UnsupportedOperationException] { - Seq((1, Array("Tesla", "Chevy", "Ford"))).toDF("id", "brands").write.csv(csvDir) - }.getMessage - assert(msg.contains("CSV data source does not support array data type")) - - msg = intercept[UnsupportedOperationException] { - Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors") - .write.csv(csvDir) - }.getMessage - assert(msg.contains("CSV data source does not support array data type")) - - msg = intercept[UnsupportedOperationException] { - val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil) - spark.range(1).write.csv(csvDir) - spark.read.schema(schema).csv(csvDir).collect() - }.getMessage - assert(msg.contains("CSV data source does not support array data type.")) - } - } - test("SPARK-15585 turn off quotations") { val cars = spark.read .format("csv") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 237ed9bc05988..dd2144c5fcea8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -72,6 +72,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { + DataSourceUtils.verifyWriteSchema(this, dataSchema) + val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) val configuration = job.getConfiguration @@ -121,6 +123,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + DataSourceUtils.verifyReadSchema(this, dataSchema) + if (sparkSession.sessionState.conf.orcFilterPushDown) { // Sets pushed predicates OrcFilters.createFilter(requiredSchema, filters.toArray).foreach { f => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala index d556a030e2186..69009e1b520c2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql.hive.orc import java.io.File -import org.apache.spark.sql.Row +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.TestingUDT.{IntervalData, IntervalUDT} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.orc.OrcSuite import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.HiveSerDe +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { @@ -133,4 +135,49 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { Utils.deleteRecursively(location) } } + + test("SPARK-24204 error handling for unsupported data types") { + withTempDir { dir => + val orcDir = new File(dir, "orc").getCanonicalPath + + // write path + var msg = intercept[AnalysisException] { + sql("select interval 1 days").write.mode("overwrite").orc(orcDir) + }.getMessage + assert(msg.contains("Cannot save interval data type into external storage.")) + + msg = intercept[UnsupportedOperationException] { + sql("select null").write.mode("overwrite").orc(orcDir) + }.getMessage + assert(msg.contains("ORC data source does not support null data type.")) + + msg = intercept[UnsupportedOperationException] { + spark.udf.register("testType", () => new IntervalData()) + sql("select testType()").write.mode("overwrite").orc(orcDir) + }.getMessage + assert(msg.contains("ORC data source does not support calendarinterval data type.")) + + // read path + msg = intercept[UnsupportedOperationException] { + val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) + spark.range(1).write.mode("overwrite").orc(orcDir) + spark.read.schema(schema).orc(orcDir).collect() + }.getMessage + assert(msg.contains("ORC data source does not support calendarinterval data type.")) + + msg = intercept[UnsupportedOperationException] { + val schema = StructType(StructField("a", NullType, true) :: Nil) + spark.range(1).write.mode("overwrite").orc(orcDir) + spark.read.schema(schema).orc(orcDir).collect() + }.getMessage + assert(msg.contains("ORC data source does not support null data type.")) + + msg = intercept[UnsupportedOperationException] { + val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) + spark.range(1).write.mode("overwrite").orc(orcDir) + spark.read.schema(schema).orc(orcDir).collect() + }.getMessage + assert(msg.contains("ORC data source does not support calendarinterval data type.")) + } + } } From c5aa54d54b301555bad1ff0653df11293f0033ed Mon Sep 17 00:00:00 2001 From: "Kallman, Steven" Date: Wed, 27 Jun 2018 15:36:59 -0700 Subject: [PATCH 1031/2461] [SPARK-24553][WEB-UI] http 302 fixes for href redirect ## What changes were proposed in this pull request? Updated URL/href links to include a '/' before '?id' to make links consistent and avoid http 302 redirect errors within UI port 4040 tabs. ## How was this patch tested? Built a runnable distribution and executed jobs. Validated that http 302 redirects are no longer encountered when clicking on links within UI port 4040 tabs. Author: Steven Kallman Author: Kallman, Steven Closes #21600 from SJKallman/{Spark-24553}{WEB-UI}-redirect-href-fixes. --- .../src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala | 2 +- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 2 +- core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala | 2 +- .../org/apache/spark/sql/execution/ui/AllExecutionsPage.scala | 4 ++-- .../org/apache/spark/sql/execution/ui/ExecutionPage.scala | 2 +- .../spark/sql/hive/thriftserver/ui/ThriftServerPage.scala | 4 ++-- .../sql/hive/thriftserver/ui/ThriftServerSessionPage.scala | 2 +- .../main/scala/org/apache/spark/streaming/ui/BatchPage.scala | 2 +- 8 files changed, 10 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 178d2c8d1a10a..90e9a7a3630cf 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -464,7 +464,7 @@ private[ui] class JobDataSource( val jobDescription = UIUtils.makeDescription(lastStageDescription, basePath, plainText = false) - val detailUrl = "%s/jobs/job?id=%s".format(basePath, jobData.jobId) + val detailUrl = "%s/jobs/job/?id=%s".format(basePath, jobData.jobId) new JobTableRowData( jobData, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index d4e6a7bc3effa..55eb989962668 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -282,7 +282,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val _taskTable = new TaskPagedTable( stageData, UIUtils.prependBaseUri(request, parent.basePath) + - s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", + s"/stages/stage/?id=${stageId}&attempt=${stageAttemptId}", currentTime, pageSize = taskPageSize, sortColumn = taskSortColumn, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 56e4d6838a99a..d01acdae59c9f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -370,7 +370,7 @@ private[ui] class StagePagedTable( Seq.empty } - val nameLinkUri = s"$basePathUri/stages/stage?id=${s.stageId}&attempt=${s.attemptId}" + val nameLinkUri = s"$basePathUri/stages/stage/?id=${s.stageId}&attempt=${s.attemptId}" val nameLink = {s.name} val cachedRddInfos = store.rddList().filter { rdd => s.rddIds.contains(rdd.id) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index bf46bc4cf904d..a7a24ac3641b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -214,11 +214,11 @@ private[ui] abstract class ExecutionTable( } private def jobURL(request: HttpServletRequest, jobId: Long): String = - "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) + "%s/jobs/job/?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) private def executionURL(request: HttpServletRequest, executionID: Long): String = s"${UIUtils.prependBaseUri( - request, parent.basePath)}/${parent.prefix}/execution?id=$executionID" + request, parent.basePath)}/${parent.prefix}/execution/?id=$executionID" } private[ui] class RunningExecutionTable( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index 282f7b4bb5a58..877176b030f8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -122,7 +122,7 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging } private def jobURL(request: HttpServletRequest, jobId: Long): String = - "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) + "%s/jobs/job/?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) private def physicalPlanDescription(physicalPlanDescription: String): Seq[Node] = {
    diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 0950b30126773..771104ceb8842 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -76,7 +76,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => - [{id}] @@ -147,7 +147,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", "Total Execute") def generateDataRow(session: SessionInfo): Seq[Node] = { - val sessionLink = "%s/%s/session?id=%s".format( + val sessionLink = "%s/%s/session/?id=%s".format( UIUtils.prependBaseUri(request, parent.basePath), parent.prefix, session.sessionId)
    diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index c884aa0ecbdf8..163eb43aabc72 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -86,7 +86,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => - [{id}] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index ca9da6139649a..884d21d0afdd3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -109,7 +109,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { flatMap(info => info.failureReason).headOption.getOrElse("") val formattedDuration = duration.map(d => SparkUIUtils.formatDuration(d)).getOrElse("-") val detailUrl = s"${SparkUIUtils.prependBaseUri( - request, parent.basePath)}/jobs/job?id=${sparkJob.jobId}" + request, parent.basePath)}/jobs/job/?id=${sparkJob.jobId}" // In the first row, output op id and its information needs to be shown. In other rows, these // cells will be taken up due to "rowspan". From bd32b509a1728366494cba13f8f6612b7bd46ec0 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 28 Jun 2018 09:19:25 +0800 Subject: [PATCH 1032/2461] [SPARK-24645][SQL] Skip parsing when csvColumnPruning enabled and partitions scanned only ## What changes were proposed in this pull request? In the master, when `csvColumnPruning`(implemented in [this commit](https://github.com/apache/spark/commit/64fad0b519cf35b8c0a0dec18dd3df9488a5ed25#diff-d19881aceddcaa5c60620fdcda99b4c4)) enabled and partitions scanned only, it throws an exception below; ``` scala> val dir = "/tmp/spark-csv/csv" scala> spark.range(10).selectExpr("id % 2 AS p", "id").write.mode("overwrite").partitionBy("p").csv(dir) scala> spark.read.csv(dir).selectExpr("sum(p)").collect() 18/06/25 13:12:51 ERROR Executor: Exception in task 0.0 in stage 2.0 (TID 5) java.lang.NullPointerException at org.apache.spark.sql.execution.datasources.csv.UnivocityParser.org$apache$spark$sql$execution$datasources$csv$UnivocityParser$$convert(UnivocityParser.scala:197) at org.apache.spark.sql.execution.datasources.csv.UnivocityParser.parse(UnivocityParser.scala:190) at org.apache.spark.sql.execution.datasources.csv.UnivocityParser$$anonfun$5.apply(UnivocityParser.scala:309) at org.apache.spark.sql.execution.datasources.csv.UnivocityParser$$anonfun$5.apply(UnivocityParser.scala:309) at org.apache.spark.sql.execution.datasources.FailureSafeParser.parse(FailureSafeParser.scala:61) ... ``` This pr modified code to skip CSV parsing in the case. ## How was this patch tested? Added tests in `CSVSuite`. Author: Takeshi Yamamuro Closes #21631 from maropu/SPARK-24645. --- .../execution/datasources/csv/UnivocityParser.scala | 10 +++++++++- .../spark/sql/execution/datasources/csv/CSVSuite.scala | 10 ++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 5f7d5696b71a6..aa545e1a0c00a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -183,11 +183,19 @@ class UnivocityParser( } } + private val doParse = if (schema.nonEmpty) { + (input: String) => convert(tokenizer.parseLine(input)) + } else { + // If `columnPruning` enabled and partition attributes scanned only, + // `schema` gets empty. + (_: String) => InternalRow.empty + } + /** * Parses a single CSV string and turns it into either one resulting row or no row (if the * the record is malformed). */ - def parse(input: String): InternalRow = convert(tokenizer.parseLine(input)) + def parse(input: String): InternalRow = doParse(input) private def convert(tokens: Array[String]): InternalRow = { if (tokens.length != schema.length) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 365239d040ef2..84b91f6309fe8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1569,4 +1569,14 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te assert(testAppender2.events.asScala .exists(msg => msg.getRenderedMessage.contains("CSV header does not conform to the schema"))) } + + test("SPARK-24645 skip parsing when columnPruning enabled and partitions scanned only") { + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "true") { + withTempPath { path => + val dir = path.getAbsolutePath + spark.range(10).selectExpr("id % 2 AS p", "id").write.partitionBy("p").csv(dir) + checkAnswer(spark.read.csv(dir).selectExpr("sum(p)"), Row(5)) + } + } + } } From 1c9acc2438f9a97134ae5213a12112b2361fbb78 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 28 Jun 2018 09:21:10 +0800 Subject: [PATCH 1033/2461] [SPARK-24206][SQL][FOLLOW-UP] Update DataSourceReadBenchmark benchmark results ## What changes were proposed in this pull request? This pr corrected the default configuration (`spark.master=local[1]`) for benchmarks. Also, this updated performance results on the AWS `r3.xlarge`. ## How was this patch tested? N/A Author: Takeshi Yamamuro Closes #21625 from maropu/FixDataSourceReadBenchmark. --- .../benchmark/DataSourceReadBenchmark.scala | 296 +++++++++--------- 1 file changed, 152 insertions(+), 144 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala index fc6d8abc03c09..8711f5a8fa1ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala @@ -39,9 +39,11 @@ import org.apache.spark.util.{Benchmark, Utils} object DataSourceReadBenchmark { val conf = new SparkConf() .setAppName("DataSourceReadBenchmark") - .setIfMissing("spark.master", "local[1]") + // Since `spark.master` always exists, overrides this value + .set("spark.master", "local[1]") .setIfMissing("spark.driver.memory", "3g") .setIfMissing("spark.executor.memory", "3g") + .setIfMissing("spark.ui.enabled", "false") val spark = SparkSession.builder.config(conf).getOrCreate() @@ -154,73 +156,73 @@ object DataSourceReadBenchmark { } } - /* - Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 15231 / 15267 1.0 968.3 1.0X - SQL Json 8476 / 8498 1.9 538.9 1.8X - SQL Parquet Vectorized 121 / 127 130.0 7.7 125.9X - SQL Parquet MR 1515 / 1543 10.4 96.3 10.1X - SQL ORC Vectorized 164 / 171 95.9 10.4 92.9X - SQL ORC Vectorized with copy 228 / 234 69.0 14.5 66.8X - SQL ORC MR 1297 / 1309 12.1 82.5 11.7X + SQL CSV 22964 / 23096 0.7 1460.0 1.0X + SQL Json 8469 / 8593 1.9 538.4 2.7X + SQL Parquet Vectorized 164 / 177 95.8 10.4 139.9X + SQL Parquet MR 1687 / 1706 9.3 107.2 13.6X + SQL ORC Vectorized 191 / 197 82.3 12.2 120.2X + SQL ORC Vectorized with copy 215 / 219 73.2 13.7 106.9X + SQL ORC MR 1392 / 1412 11.3 88.5 16.5X SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 16344 / 16374 1.0 1039.1 1.0X - SQL Json 8634 / 8648 1.8 548.9 1.9X - SQL Parquet Vectorized 172 / 177 91.5 10.9 95.1X - SQL Parquet MR 1744 / 1746 9.0 110.9 9.4X - SQL ORC Vectorized 189 / 194 83.1 12.0 86.4X - SQL ORC Vectorized with copy 244 / 250 64.5 15.5 67.0X - SQL ORC MR 1341 / 1386 11.7 85.3 12.2X + SQL CSV 24090 / 24097 0.7 1531.6 1.0X + SQL Json 8791 / 8813 1.8 558.9 2.7X + SQL Parquet Vectorized 204 / 212 77.0 13.0 117.9X + SQL Parquet MR 1813 / 1850 8.7 115.3 13.3X + SQL ORC Vectorized 226 / 230 69.7 14.4 106.7X + SQL ORC Vectorized with copy 295 / 298 53.3 18.8 81.6X + SQL ORC MR 1526 / 1549 10.3 97.1 15.8X SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 17874 / 17875 0.9 1136.4 1.0X - SQL Json 9190 / 9204 1.7 584.3 1.9X - SQL Parquet Vectorized 141 / 160 111.2 9.0 126.4X - SQL Parquet MR 1930 / 2049 8.2 122.7 9.3X - SQL ORC Vectorized 259 / 264 60.7 16.5 69.0X - SQL ORC Vectorized with copy 265 / 272 59.4 16.8 67.5X - SQL ORC MR 1528 / 1569 10.3 97.2 11.7X + SQL CSV 25637 / 25791 0.6 1629.9 1.0X + SQL Json 9532 / 9570 1.7 606.0 2.7X + SQL Parquet Vectorized 181 / 191 86.8 11.5 141.5X + SQL Parquet MR 2210 / 2227 7.1 140.5 11.6X + SQL ORC Vectorized 309 / 317 50.9 19.6 83.0X + SQL ORC Vectorized with copy 316 / 322 49.8 20.1 81.2X + SQL ORC MR 1650 / 1680 9.5 104.9 15.5X SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 22812 / 22839 0.7 1450.4 1.0X - SQL Json 12026 / 12054 1.3 764.6 1.9X - SQL Parquet Vectorized 222 / 227 70.8 14.1 102.6X - SQL Parquet MR 2199 / 2204 7.2 139.8 10.4X - SQL ORC Vectorized 331 / 335 47.6 21.0 69.0X - SQL ORC Vectorized with copy 338 / 343 46.6 21.5 67.6X - SQL ORC MR 1618 / 1622 9.7 102.9 14.1X + SQL CSV 31617 / 31764 0.5 2010.1 1.0X + SQL Json 12440 / 12451 1.3 790.9 2.5X + SQL Parquet Vectorized 284 / 315 55.4 18.0 111.4X + SQL Parquet MR 2382 / 2390 6.6 151.5 13.3X + SQL ORC Vectorized 398 / 403 39.5 25.3 79.5X + SQL ORC Vectorized with copy 410 / 413 38.3 26.1 77.1X + SQL ORC MR 1783 / 1813 8.8 113.4 17.7X SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 18703 / 18740 0.8 1189.1 1.0X - SQL Json 11779 / 11869 1.3 748.9 1.6X - SQL Parquet Vectorized 143 / 145 110.1 9.1 130.9X - SQL Parquet MR 1954 / 1963 8.0 124.2 9.6X - SQL ORC Vectorized 347 / 355 45.3 22.1 53.8X - SQL ORC Vectorized with copy 356 / 359 44.1 22.7 52.5X - SQL ORC MR 1570 / 1598 10.0 99.8 11.9X + SQL CSV 26679 / 26742 0.6 1696.2 1.0X + SQL Json 12490 / 12541 1.3 794.1 2.1X + SQL Parquet Vectorized 174 / 183 90.4 11.1 153.3X + SQL Parquet MR 2201 / 2223 7.1 140.0 12.1X + SQL ORC Vectorized 415 / 429 37.9 26.4 64.3X + SQL ORC Vectorized with copy 422 / 428 37.2 26.9 63.2X + SQL ORC MR 1767 / 1773 8.9 112.3 15.1X SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 23832 / 23838 0.7 1515.2 1.0X - SQL Json 16204 / 16226 1.0 1030.2 1.5X - SQL Parquet Vectorized 242 / 306 65.1 15.4 98.6X - SQL Parquet MR 2462 / 2482 6.4 156.5 9.7X - SQL ORC Vectorized 419 / 451 37.6 26.6 56.9X - SQL ORC Vectorized with copy 426 / 447 36.9 27.1 55.9X - SQL ORC MR 1885 / 1931 8.3 119.8 12.6X + SQL CSV 34223 / 34324 0.5 2175.8 1.0X + SQL Json 17784 / 17785 0.9 1130.7 1.9X + SQL Parquet Vectorized 277 / 283 56.7 17.6 123.4X + SQL Parquet MR 2356 / 2386 6.7 149.8 14.5X + SQL ORC Vectorized 533 / 536 29.5 33.9 64.2X + SQL ORC Vectorized with copy 541 / 546 29.1 34.4 63.3X + SQL ORC MR 2166 / 2177 7.3 137.7 15.8X */ sqlBenchmark.run() @@ -294,41 +296,42 @@ object DataSourceReadBenchmark { } /* - Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 187 / 201 84.2 11.9 1.0X - ParquetReader Vectorized -> Row 101 / 103 156.4 6.4 1.9X + ParquetReader Vectorized 198 / 202 79.4 12.6 1.0X + ParquetReader Vectorized -> Row 119 / 121 132.3 7.6 1.7X Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 272 / 288 57.8 17.3 1.0X - ParquetReader Vectorized -> Row 213 / 219 73.7 13.6 1.3X + ParquetReader Vectorized 282 / 287 55.8 17.9 1.0X + ParquetReader Vectorized -> Row 246 / 247 64.0 15.6 1.1X Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 252 / 288 62.5 16.0 1.0X - ParquetReader Vectorized -> Row 232 / 246 67.7 14.8 1.1X + ParquetReader Vectorized 258 / 262 60.9 16.4 1.0X + ParquetReader Vectorized -> Row 259 / 260 60.8 16.5 1.0X Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 415 / 454 37.9 26.4 1.0X - ParquetReader Vectorized -> Row 407 / 432 38.6 25.9 1.0X + ParquetReader Vectorized 361 / 369 43.6 23.0 1.0X + ParquetReader Vectorized -> Row 361 / 371 43.6 22.9 1.0X Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 251 / 302 62.7 16.0 1.0X - ParquetReader Vectorized -> Row 220 / 234 71.5 14.0 1.1X + ParquetReader Vectorized 253 / 261 62.2 16.1 1.0X + ParquetReader Vectorized -> Row 254 / 256 61.9 16.2 1.0X Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 432 / 436 36.4 27.5 1.0X - ParquetReader Vectorized -> Row 414 / 422 38.0 26.4 1.0X + ParquetReader Vectorized 357 / 364 44.0 22.7 1.0X + ParquetReader Vectorized -> Row 358 / 366 44.0 22.7 1.0X */ parquetReaderBenchmark.run() } @@ -382,16 +385,17 @@ object DataSourceReadBenchmark { } /* - Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 19172 / 19173 0.5 1828.4 1.0X - SQL Json 12799 / 12873 0.8 1220.6 1.5X - SQL Parquet Vectorized 2558 / 2564 4.1 244.0 7.5X - SQL Parquet MR 4514 / 4583 2.3 430.4 4.2X - SQL ORC Vectorized 2561 / 2697 4.1 244.3 7.5X - SQL ORC Vectorized with copy 3076 / 3110 3.4 293.4 6.2X - SQL ORC MR 4197 / 4283 2.5 400.2 4.6X + SQL CSV 27145 / 27158 0.4 2588.7 1.0X + SQL Json 12969 / 13337 0.8 1236.8 2.1X + SQL Parquet Vectorized 2419 / 2448 4.3 230.7 11.2X + SQL Parquet MR 4631 / 4633 2.3 441.7 5.9X + SQL ORC Vectorized 2412 / 2465 4.3 230.0 11.3X + SQL ORC Vectorized with copy 2633 / 2675 4.0 251.1 10.3X + SQL ORC MR 4280 / 4350 2.4 408.2 6.3X */ benchmark.run() } @@ -445,16 +449,17 @@ object DataSourceReadBenchmark { } /* - Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 10889 / 10924 1.0 1038.5 1.0X - SQL Json 7903 / 7931 1.3 753.7 1.4X - SQL Parquet Vectorized 777 / 799 13.5 74.1 14.0X - SQL Parquet MR 1682 / 1708 6.2 160.4 6.5X - SQL ORC Vectorized 532 / 534 19.7 50.7 20.5X - SQL ORC Vectorized with copy 742 / 743 14.1 70.7 14.7X - SQL ORC MR 1996 / 2002 5.3 190.4 5.5X + SQL CSV 17345 / 17424 0.6 1654.1 1.0X + SQL Json 8639 / 8664 1.2 823.9 2.0X + SQL Parquet Vectorized 839 / 854 12.5 80.0 20.7X + SQL Parquet MR 1771 / 1775 5.9 168.9 9.8X + SQL ORC Vectorized 550 / 569 19.1 52.4 31.6X + SQL ORC Vectorized with copy 785 / 849 13.4 74.9 22.1X + SQL ORC MR 2168 / 2202 4.8 206.7 8.0X */ benchmark.run() } @@ -574,30 +579,31 @@ object DataSourceReadBenchmark { } /* - Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - Data column - CSV 25428 / 25454 0.6 1616.7 1.0X - Data column - Json 12689 / 12774 1.2 806.7 2.0X - Data column - Parquet Vectorized 222 / 231 70.7 14.1 114.3X - Data column - Parquet MR 3355 / 3397 4.7 213.3 7.6X - Data column - ORC Vectorized 332 / 338 47.4 21.1 76.6X - Data column - ORC Vectorized with copy 338 / 341 46.5 21.5 75.2X - Data column - ORC MR 2329 / 2356 6.8 148.0 10.9X - Partition column - CSV 17465 / 17502 0.9 1110.4 1.5X - Partition column - Json 10865 / 10876 1.4 690.8 2.3X - Partition column - Parquet Vectorized 48 / 52 325.4 3.1 526.1X - Partition column - Parquet MR 1695 / 1696 9.3 107.8 15.0X - Partition column - ORC Vectorized 49 / 54 319.9 3.1 517.2X - Partition column - ORC Vectorized with copy 49 / 52 324.1 3.1 524.0X - Partition column - ORC MR 1548 / 1549 10.2 98.4 16.4X - Both columns - CSV 25568 / 25595 0.6 1625.6 1.0X - Both columns - Json 13658 / 13673 1.2 868.4 1.9X - Both columns - Parquet Vectorized 270 / 296 58.3 17.1 94.3X - Both columns - Parquet MR 3501 / 3521 4.5 222.6 7.3X - Both columns - ORC Vectorized 377 / 380 41.7 24.0 67.4X - Both column - ORC Vectorized with copy 447 / 448 35.2 28.4 56.9X - Both columns - ORC MR 2440 / 2446 6.4 155.2 10.4X + Data column - CSV 32613 / 32841 0.5 2073.4 1.0X + Data column - Json 13343 / 13469 1.2 848.3 2.4X + Data column - Parquet Vectorized 302 / 318 52.1 19.2 108.0X + Data column - Parquet MR 2908 / 2924 5.4 184.9 11.2X + Data column - ORC Vectorized 412 / 425 38.1 26.2 79.1X + Data column - ORC Vectorized with copy 442 / 446 35.6 28.1 73.8X + Data column - ORC MR 2390 / 2396 6.6 152.0 13.6X + Partition column - CSV 9626 / 9683 1.6 612.0 3.4X + Partition column - Json 10909 / 10923 1.4 693.6 3.0X + Partition column - Parquet Vectorized 69 / 76 228.4 4.4 473.6X + Partition column - Parquet MR 1898 / 1933 8.3 120.7 17.2X + Partition column - ORC Vectorized 67 / 74 236.0 4.2 489.4X + Partition column - ORC Vectorized with copy 65 / 72 241.9 4.1 501.6X + Partition column - ORC MR 1743 / 1749 9.0 110.8 18.7X + Both columns - CSV 35523 / 35552 0.4 2258.5 0.9X + Both columns - Json 13676 / 13681 1.2 869.5 2.4X + Both columns - Parquet Vectorized 317 / 326 49.5 20.2 102.7X + Both columns - Parquet MR 3333 / 3336 4.7 211.9 9.8X + Both columns - ORC Vectorized 441 / 446 35.6 28.1 73.9X + Both column - ORC Vectorized with copy 517 / 524 30.4 32.9 63.1X + Both columns - ORC MR 2574 / 2577 6.1 163.6 12.7X */ benchmark.run() } @@ -684,41 +690,42 @@ object DataSourceReadBenchmark { } /* - Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 13518 / 13529 0.8 1289.2 1.0X - SQL Json 10895 / 10926 1.0 1039.0 1.2X - SQL Parquet Vectorized 1539 / 1581 6.8 146.8 8.8X - SQL Parquet MR 3746 / 3811 2.8 357.3 3.6X - ParquetReader Vectorized 1070 / 1112 9.8 102.0 12.6X - SQL ORC Vectorized 1389 / 1408 7.6 132.4 9.7X - SQL ORC Vectorized with copy 1736 / 1750 6.0 165.6 7.8X - SQL ORC MR 3799 / 3892 2.8 362.3 3.6X + SQL CSV 14875 / 14920 0.7 1418.6 1.0X + SQL Json 10974 / 10992 1.0 1046.5 1.4X + SQL Parquet Vectorized 1711 / 1750 6.1 163.2 8.7X + SQL Parquet MR 3838 / 3884 2.7 366.0 3.9X + ParquetReader Vectorized 1155 / 1168 9.1 110.2 12.9X + SQL ORC Vectorized 1341 / 1380 7.8 127.9 11.1X + SQL ORC Vectorized with copy 1659 / 1716 6.3 158.2 9.0X + SQL ORC MR 3594 / 3634 2.9 342.7 4.1X String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 10854 / 10892 1.0 1035.2 1.0X - SQL Json 8129 / 8138 1.3 775.3 1.3X - SQL Parquet Vectorized 1053 / 1104 10.0 100.4 10.3X - SQL Parquet MR 2840 / 2854 3.7 270.8 3.8X - ParquetReader Vectorized 978 / 1008 10.7 93.2 11.1X - SQL ORC Vectorized 1312 / 1387 8.0 125.1 8.3X - SQL ORC Vectorized with copy 1764 / 1772 5.9 168.2 6.2X - SQL ORC MR 3435 / 3445 3.1 327.6 3.2X + SQL CSV 17219 / 17264 0.6 1642.1 1.0X + SQL Json 8843 / 8864 1.2 843.3 1.9X + SQL Parquet Vectorized 1169 / 1178 9.0 111.4 14.7X + SQL Parquet MR 2676 / 2697 3.9 255.2 6.4X + ParquetReader Vectorized 1068 / 1071 9.8 101.8 16.1X + SQL ORC Vectorized 1319 / 1319 7.9 125.8 13.1X + SQL ORC Vectorized with copy 1638 / 1639 6.4 156.2 10.5X + SQL ORC MR 3230 / 3257 3.2 308.1 5.3X String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 8043 / 8048 1.3 767.1 1.0X - SQL Json 4911 / 4923 2.1 468.4 1.6X - SQL Parquet Vectorized 206 / 209 51.0 19.6 39.1X - SQL Parquet MR 1528 / 1537 6.9 145.8 5.3X - ParquetReader Vectorized 216 / 219 48.6 20.6 37.2X - SQL ORC Vectorized 462 / 466 22.7 44.1 17.4X - SQL ORC Vectorized with copy 568 / 572 18.5 54.2 14.2X - SQL ORC MR 1647 / 1649 6.4 157.1 4.9X + SQL CSV 13976 / 14053 0.8 1332.8 1.0X + SQL Json 5166 / 5176 2.0 492.6 2.7X + SQL Parquet Vectorized 274 / 282 38.2 26.2 50.9X + SQL Parquet MR 1553 / 1555 6.8 148.1 9.0X + ParquetReader Vectorized 241 / 246 43.5 23.0 57.9X + SQL ORC Vectorized 476 / 479 22.0 45.4 29.3X + SQL ORC Vectorized with copy 584 / 588 17.9 55.7 23.9X + SQL ORC MR 1720 / 1734 6.1 164.1 8.1X */ benchmark.run() } @@ -773,38 +780,39 @@ object DataSourceReadBenchmark { } /* - Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single Column Scan from 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 3663 / 3665 0.3 3493.2 1.0X - SQL Json 3122 / 3160 0.3 2977.5 1.2X - SQL Parquet Vectorized 40 / 42 26.2 38.2 91.5X - SQL Parquet MR 189 / 192 5.5 180.2 19.4X - SQL ORC Vectorized 48 / 51 21.6 46.2 75.6X - SQL ORC Vectorized with copy 49 / 52 21.4 46.7 74.9X - SQL ORC MR 280 / 289 3.7 267.1 13.1X + SQL CSV 3478 / 3481 0.3 3316.4 1.0X + SQL Json 2646 / 2654 0.4 2523.6 1.3X + SQL Parquet Vectorized 67 / 72 15.8 63.5 52.2X + SQL Parquet MR 207 / 214 5.1 197.6 16.8X + SQL ORC Vectorized 69 / 76 15.2 66.0 50.3X + SQL ORC Vectorized with copy 70 / 76 15.0 66.5 49.9X + SQL ORC MR 299 / 303 3.5 285.1 11.6X Single Column Scan from 50 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 11420 / 11505 0.1 10891.1 1.0X - SQL Json 11905 / 12120 0.1 11353.6 1.0X - SQL Parquet Vectorized 50 / 54 20.9 47.8 227.7X - SQL Parquet MR 195 / 199 5.4 185.8 58.6X - SQL ORC Vectorized 61 / 65 17.3 57.8 188.3X - SQL ORC Vectorized with copy 62 / 65 17.0 58.8 185.2X - SQL ORC MR 847 / 865 1.2 807.4 13.5X + SQL CSV 9214 / 9236 0.1 8786.7 1.0X + SQL Json 9943 / 9978 0.1 9482.7 0.9X + SQL Parquet Vectorized 77 / 86 13.6 73.3 119.8X + SQL Parquet MR 229 / 235 4.6 218.6 40.2X + SQL ORC Vectorized 84 / 96 12.5 80.0 109.9X + SQL ORC Vectorized with copy 83 / 91 12.6 79.4 110.7X + SQL ORC MR 843 / 854 1.2 804.0 10.9X - Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Single Column Scan from 100 columns Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 21278 / 21404 0.0 20292.4 1.0X - SQL Json 22455 / 22625 0.0 21414.7 0.9X - SQL Parquet Vectorized 73 / 75 14.4 69.3 292.8X - SQL Parquet MR 220 / 226 4.8 209.7 96.8X - SQL ORC Vectorized 82 / 86 12.8 78.2 259.4X - SQL ORC Vectorized with copy 82 / 90 12.7 78.7 258.0X - SQL ORC MR 1568 / 1582 0.7 1495.4 13.6X + SQL CSV 16503 / 16622 0.1 15738.9 1.0X + SQL Json 19109 / 19184 0.1 18224.2 0.9X + SQL Parquet Vectorized 99 / 108 10.6 94.3 166.8X + SQL Parquet MR 253 / 264 4.1 241.6 65.1X + SQL ORC Vectorized 107 / 114 9.8 101.6 154.8X + SQL ORC Vectorized with copy 107 / 118 9.8 102.1 154.1X + SQL ORC MR 1526 / 1529 0.7 1455.3 10.8X */ benchmark.run() } From 6a97e8eb31da76fe5af912a6304c07b63735062f Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Thu, 28 Jun 2018 09:59:00 +0800 Subject: [PATCH 1034/2461] [SPARK-24603][SQL] Fix findTightestCommonType reference in comments findTightestCommonTypeOfTwo has been renamed to findTightestCommonType ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Fokko Driesprong Closes #21597 from Fokko/fd-typo. --- .../sql/execution/datasources/json/JsonInferSchema.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index f6edc7bfb3750..8e1b430f4eb33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -334,8 +334,8 @@ private[sql] object JsonInferSchema { ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) // The case that given `DecimalType` is capable of given `IntegralType` is handled in - // `findTightestCommonTypeOfTwo`. Both cases below will be executed only when - // the given `DecimalType` is not capable of the given `IntegralType`. + // `findTightestCommonType`. Both cases below will be executed only when the given + // `DecimalType` is not capable of the given `IntegralType`. case (t1: IntegralType, t2: DecimalType) => compatibleType(DecimalType.forType(t1), t2) case (t1: DecimalType, t2: IntegralType) => From 5b0596648854c0c733b7c607661b78af7df18b89 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 28 Jun 2018 14:19:50 +0800 Subject: [PATCH 1035/2461] [SPARK-24564][TEST] Add test suite for RecordBinaryComparator ## What changes were proposed in this pull request? Add a new test suite to test RecordBinaryComparator. ## How was this patch tested? New test suite. Author: Xingbo Jiang Closes #21570 from jiangxb1987/rbc-test. --- .../spark/memory/TestMemoryConsumer.java | 10 + .../sort/RecordBinaryComparatorSuite.java | 256 ++++++++++++++++++ 2 files changed, 266 insertions(+) create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java diff --git a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java index db91329c94cb6..0bbaea6b834b8 100644 --- a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java +++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java @@ -17,6 +17,10 @@ package org.apache.spark.memory; +import com.google.common.annotations.VisibleForTesting; + +import org.apache.spark.unsafe.memory.MemoryBlock; + import java.io.IOException; public class TestMemoryConsumer extends MemoryConsumer { @@ -43,6 +47,12 @@ void free(long size) { used -= size; taskMemoryManager.releaseExecutionMemory(size, this); } + + @VisibleForTesting + public void freePage(MemoryBlock page) { + used -= page.size(); + taskMemoryManager.freePage(page, this); + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java new file mode 100644 index 0000000000000..a19ddbdbadba2 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.execution.sort; + +import org.apache.spark.SparkConf; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TestMemoryConsumer; +import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.execution.RecordBinaryComparator; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.util.collection.unsafe.sort.*; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test the RecordBinaryComparator, which compares two UnsafeRows by their binary form. + */ +public class RecordBinaryComparatorSuite { + + private final TaskMemoryManager memoryManager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + private final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); + + private final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + + private MemoryBlock dataPage; + private long pageCursor; + + private LongArray array; + private int pos; + + @Before + public void beforeEach() { + // Only compare between two input rows. + array = consumer.allocateArray(2); + pos = 0; + + dataPage = memoryManager.allocatePage(4096, consumer); + pageCursor = dataPage.getBaseOffset(); + } + + @After + public void afterEach() { + consumer.freePage(dataPage); + dataPage = null; + pageCursor = 0; + + consumer.freeArray(array); + array = null; + pos = 0; + } + + private void insertRow(UnsafeRow row) { + Object recordBase = row.getBaseObject(); + long recordOffset = row.getBaseOffset(); + int recordLength = row.getSizeInBytes(); + + Object baseObject = dataPage.getBaseObject(); + assert(pageCursor + recordLength <= dataPage.getBaseOffset() + dataPage.size()); + long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, pageCursor); + UnsafeAlignedOffset.putSize(baseObject, pageCursor, recordLength); + pageCursor += uaoSize; + Platform.copyMemory(recordBase, recordOffset, baseObject, pageCursor, recordLength); + pageCursor += recordLength; + + assert(pos < 2); + array.set(pos, recordAddress); + pos++; + } + + private int compare(int index1, int index2) { + Object baseObject = dataPage.getBaseObject(); + + long recordAddress1 = array.get(index1); + long baseOffset1 = memoryManager.getOffsetInPage(recordAddress1) + uaoSize; + int recordLength1 = UnsafeAlignedOffset.getSize(baseObject, baseOffset1 - uaoSize); + + long recordAddress2 = array.get(index2); + long baseOffset2 = memoryManager.getOffsetInPage(recordAddress2) + uaoSize; + int recordLength2 = UnsafeAlignedOffset.getSize(baseObject, baseOffset2 - uaoSize); + + return binaryComparator.compare(baseObject, baseOffset1, recordLength1, baseObject, + baseOffset2, recordLength2); + } + + private final RecordComparator binaryComparator = new RecordBinaryComparator(); + + // Compute the most compact size for UnsafeRow's backing data. + private int computeSizeInBytes(int originalSize) { + // All the UnsafeRows in this suite contains less than 64 columns, so the bitSetSize shall + // always be 8. + return 8 + (originalSize + 7) / 8 * 8; + } + + // Compute the relative offset of variable-length values. + private long relativeOffset(int numFields) { + // All the UnsafeRows in this suite contains less than 64 columns, so the bitSetSize shall + // always be 8. + return 8 + numFields * 8L; + } + + @Test + public void testBinaryComparatorForSingleColumnRow() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setInt(0, 11); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setInt(0, 42); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorForMultipleColumnRow() throws Exception { + int numFields = 5; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields; i++) { + row1.setDouble(i, i * 3.14); + } + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields; i++) { + row2.setDouble(i, 198.7 / (i + 1)); + } + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorForArrayColumn() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + UnsafeArrayData arrayData1 = UnsafeArrayData.fromPrimitiveArray(new int[]{11, 42, -1}); + row1.pointTo(data1, computeSizeInBytes(numFields * 8 + arrayData1.getSizeInBytes())); + row1.setLong(0, (relativeOffset(numFields) << 32) | (long) arrayData1.getSizeInBytes()); + Platform.copyMemory(arrayData1.getBaseObject(), arrayData1.getBaseOffset(), data1, + row1.getBaseOffset() + relativeOffset(numFields), arrayData1.getSizeInBytes()); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + UnsafeArrayData arrayData2 = UnsafeArrayData.fromPrimitiveArray(new int[]{22}); + row2.pointTo(data2, computeSizeInBytes(numFields * 8 + arrayData2.getSizeInBytes())); + row2.setLong(0, (relativeOffset(numFields) << 32) | (long) arrayData2.getSizeInBytes()); + Platform.copyMemory(arrayData2.getBaseObject(), arrayData2.getBaseOffset(), data2, + row2.getBaseOffset() + relativeOffset(numFields), arrayData2.getSizeInBytes()); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) > 0); + } + + @Test + public void testBinaryComparatorForMixedColumns() throws Exception { + int numFields = 4; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + UTF8String str1 = UTF8String.fromString("Milk tea"); + row1.pointTo(data1, computeSizeInBytes(numFields * 8 + str1.numBytes())); + row1.setInt(0, 11); + row1.setDouble(1, 3.14); + row1.setInt(2, -1); + row1.setLong(3, (relativeOffset(numFields) << 32) | (long) str1.numBytes()); + Platform.copyMemory(str1.getBaseObject(), str1.getBaseOffset(), data1, + row1.getBaseOffset() + relativeOffset(numFields), str1.numBytes()); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + UTF8String str2 = UTF8String.fromString("Java"); + row2.pointTo(data2, computeSizeInBytes(numFields * 8 + str2.numBytes())); + row2.setInt(0, 11); + row2.setDouble(1, 3.14); + row2.setInt(2, -1); + row2.setLong(3, (relativeOffset(numFields) << 32) | (long) str2.numBytes()); + Platform.copyMemory(str2.getBaseObject(), str2.getBaseOffset(), data2, + row2.getBaseOffset() + relativeOffset(numFields), str2.numBytes()); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) > 0); + } + + @Test + public void testBinaryComparatorForNullColumns() throws Exception { + int numFields = 3; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields; i++) { + row1.setNullAt(i); + } + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields - 1; i++) { + row2.setNullAt(i); + } + row2.setDouble(numFields - 1, 3.14); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) > 0); + } +} From 524827f0626281847582ec3056982db7eb83f8b1 Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Thu, 28 Jun 2018 12:40:39 -0700 Subject: [PATCH 1036/2461] [SPARK-14712][ML] LogisticRegressionModel.toString should summarize model ## What changes were proposed in this pull request? [SPARK-14712](https://issues.apache.org/jira/browse/SPARK-14712) spark.mllib LogisticRegressionModel overrides toString to print a little model info. We should do the same in spark.ml and override repr in pyspark. ## How was this patch tested? LogisticRegressionSuite.scala Python doctest in pyspark.ml.classification.py Author: bravo-zhang Closes #18826 from bravo-zhang/spark-14712. --- .../apache/spark/ml/classification/LogisticRegression.scala | 5 +++++ .../spark/ml/classification/LogisticRegressionSuite.scala | 6 ++++++ python/pyspark/ml/classification.py | 5 +++++ python/pyspark/mllib/classification.py | 3 +++ 4 files changed, 19 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 06ca37bc75146..92e342ed4a464 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1202,6 +1202,11 @@ class LogisticRegressionModel private[spark] ( */ @Since("1.6.0") override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) + + override def toString: String = { + s"LogisticRegressionModel: " + + s"uid = ${super.toString}, numClasses = $numClasses, numFeatures = $numFeatures" + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 36b7e51f93d01..75c2aeb146786 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -2751,6 +2751,12 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { assert(model.getFamily === family) } } + + test("toString") { + val model = new LogisticRegressionModel("logReg", Vectors.dense(0.1, 0.2, 0.3), 0.0) + val expected = "LogisticRegressionModel: uid = logReg, numClasses = 2, numFeatures = 3" + assert(model.toString === expected) + } } object LogisticRegressionSuite { diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 1754c48937a62..d5963f4f7042c 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -239,6 +239,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti True >>> blorModel.intercept == model2.intercept True + >>> model2 + LogisticRegressionModel: uid = ..., numClasses = 2, numFeatures = 2 .. versionadded:: 1.3.0 """ @@ -562,6 +564,9 @@ def evaluate(self, dataset): java_blr_summary = self._call_java("evaluate", dataset) return BinaryLogisticRegressionSummary(java_blr_summary) + def __repr__(self): + return self._call_java("toString") + class LogisticRegressionSummary(JavaWrapper): """ diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index bb281981fd56b..e00ed95ef0701 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -258,6 +258,9 @@ def load(cls, sc, path): model.setThreshold(threshold) return model + def __repr__(self): + return self._call_java("toString") + class LogisticRegressionWithSGD(object): """ From a95a4af76459016b0d52df90adab68a49904da99 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 28 Jun 2018 13:20:08 -0700 Subject: [PATCH 1037/2461] [SPARK-23120][PYSPARK][ML] Add basic PMML export support to PySpark ## What changes were proposed in this pull request? Adds basic PMML export support for Spark ML stages to PySpark as was previously done in Scala. Includes LinearRegressionModel as the first stage to implement. ## How was this patch tested? Doctest, the main testing work for this is on the Scala side. (TODO holden add the unittest once I finish locally). Author: Holden Karau Closes #21172 from holdenk/SPARK-23120-add-pmml-export-support-to-pyspark. --- python/pyspark/ml/regression.py | 3 ++- python/pyspark/ml/tests.py | 17 ++++++++++++ python/pyspark/ml/util.py | 46 +++++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index dba0e57b01a0b..83f0edb397271 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -95,6 +95,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction True >>> model.numFeatures 1 + >>> model.write().format("pmml").save(model_path + "_2") .. versionadded:: 1.4.0 """ @@ -161,7 +162,7 @@ def getEpsilon(self): return self.getOrDefault(self.epsilon) -class LinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): +class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritable, JavaMLReadable): """ Model fitted by :class:`LinearRegression`. diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index ebd36cbb5f7a7..bc782138292bf 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1362,6 +1362,23 @@ def test_linear_regression(self): except OSError: pass + def test_linear_regression_pmml_basic(self): + # Most of the validation is done in the Scala side, here we just check + # that we output text rather than parquet (e.g. that the format flag + # was respected). + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=1) + model = lr.fit(df) + path = tempfile.mkdtemp() + lr_path = path + "/lr-pmml" + model.write().format("pmml").save(lr_path) + pmml_text_list = self.sc.textFile(lr_path).collect() + pmml_text = "\n".join(pmml_text_list) + self.assertIn("Apache Spark", pmml_text) + self.assertIn("PMML", pmml_text) + def test_logistic_regression(self): lr = LogisticRegression(maxIter=1) path = tempfile.mkdtemp() diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 9fa85664939b8..080cd299f4fde 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -148,6 +148,23 @@ def overwrite(self): return self +@inherit_doc +class GeneralMLWriter(MLWriter): + """ + Utility class that can save ML instances in different formats. + + .. versionadded:: 2.4.0 + """ + + def format(self, source): + """ + Specifies the format of ML export (e.g. "pmml", "internal", or the fully qualified class + name for export). + """ + self.source = source + return self + + @inherit_doc class JavaMLWriter(MLWriter): """ @@ -192,6 +209,24 @@ def session(self, sparkSession): return self +@inherit_doc +class GeneralJavaMLWriter(JavaMLWriter): + """ + (Private) Specialization of :py:class:`GeneralMLWriter` for :py:class:`JavaParams` types + """ + + def __init__(self, instance): + super(GeneralJavaMLWriter, self).__init__(instance) + + def format(self, source): + """ + Specifies the format of ML export (e.g. "pmml", "internal", or the fully qualified class + name for export). + """ + self._jwrite.format(source) + return self + + @inherit_doc class MLWritable(object): """ @@ -220,6 +255,17 @@ def write(self): return JavaMLWriter(self) +@inherit_doc +class GeneralJavaMLWritable(JavaMLWritable): + """ + (Private) Mixin for ML instances that provide :py:class:`GeneralJavaMLWriter`. + """ + + def write(self): + """Returns an GeneralMLWriter instance for this ML instance.""" + return GeneralJavaMLWriter(self) + + @inherit_doc class MLReader(BaseReadWrite): """ From e1d3f80103f6df2eb8a962607dd5427df4b355dd Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Thu, 28 Jun 2018 13:22:52 -0700 Subject: [PATCH 1038/2461] [SPARK-24408][SQL][DOC] Move abs function to math_funcs group ## What changes were proposed in this pull request? A few math functions (`abs` , `bitwiseNOT`, `isnan`, `nanvl`) are not in **math_funcs** group. They should really be. ## How was this patch tested? Awaiting Jenkins Author: Jacek Laskowski Closes #21448 from jaceklaskowski/SPARK-24408-math-funcs-doc. --- .../scala/org/apache/spark/sql/functions.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0b4f526799578..acca9572cb14c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1031,14 +1031,6 @@ object functions { // Non-aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// - /** - * Computes the absolute value. - * - * @group normal_funcs - * @since 1.3.0 - */ - def abs(e: Column): Column = withExpr { Abs(e.expr) } - /** * Creates a new array column. The input columns must all have the same data type. * @@ -1336,7 +1328,7 @@ object functions { } /** - * Computes bitwise NOT. + * Computes bitwise NOT (~) of a number. * * @group normal_funcs * @since 1.4.0 @@ -1364,6 +1356,14 @@ object functions { // Math Functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Computes the absolute value of a numeric value. + * + * @group math_funcs + * @since 1.3.0 + */ + def abs(e: Column): Column = withExpr { Abs(e.expr) } + /** * @return inverse cosine of `e` in radians, as if computed by `java.lang.Math.acos` * From 2224861f2f93830d736b625c9a4cb72c918512b2 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 28 Jun 2018 14:07:28 -0700 Subject: [PATCH 1039/2461] [SPARK-24439][ML][PYTHON] Add distanceMeasure to BisectingKMeans in PySpark ## What changes were proposed in this pull request? add distanceMeasure to BisectingKMeans in Python. ## How was this patch tested? added doctest and also manually tested it. Author: Huaxin Gao Closes #21557 from huaxingao/spark-24439. --- python/pyspark/ml/clustering.py | 35 +++++++++++++------ .../ml/param/_shared_params_code_gen.py | 4 ++- python/pyspark/ml/param/shared.py | 24 +++++++++++++ 3 files changed, 51 insertions(+), 12 deletions(-) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 4aa1cf84b5824..6d77baf7349e4 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -349,8 +349,8 @@ def summary(self): @inherit_doc -class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed, - JavaMLWritable, JavaMLReadable): +class KMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol, HasMaxIter, + HasTol, HasSeed, JavaMLWritable, JavaMLReadable): """ K-means clustering with a k-means++ like initialization mode (the k-means|| algorithm by Bahmani et al). @@ -406,9 +406,6 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol typeConverter=TypeConverters.toString) initSteps = Param(Params._dummy(), "initSteps", "The number of steps for k-means|| " + "initialization mode. Must be > 0.", typeConverter=TypeConverters.toInt) - distanceMeasure = Param(Params._dummy(), "distanceMeasure", "The distance measure. " + - "Supported options: 'euclidean' and 'cosine'.", - typeConverter=TypeConverters.toString) @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", k=2, @@ -544,8 +541,8 @@ def summary(self): @inherit_doc -class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasSeed, - JavaMLWritable, JavaMLReadable): +class BisectingKMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol, + HasMaxIter, HasSeed, JavaMLWritable, JavaMLReadable): """ A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques" by Steinbach, Karypis, and Kumar, with modification to fit Spark. @@ -585,6 +582,8 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte >>> bkm2 = BisectingKMeans.load(bkm_path) >>> bkm2.getK() 2 + >>> bkm2.getDistanceMeasure() + 'euclidean' >>> model_path = temp_path + "/bkm_model" >>> model.save(model_path) >>> model2 = BisectingKMeansModel.load(model_path) @@ -607,10 +606,10 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", maxIter=20, - seed=None, k=4, minDivisibleClusterSize=1.0): + seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean"): """ __init__(self, featuresCol="features", predictionCol="prediction", maxIter=20, \ - seed=None, k=4, minDivisibleClusterSize=1.0) + seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean") """ super(BisectingKMeans, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.BisectingKMeans", @@ -622,10 +621,10 @@ def __init__(self, featuresCol="features", predictionCol="prediction", maxIter=2 @keyword_only @since("2.0.0") def setParams(self, featuresCol="features", predictionCol="prediction", maxIter=20, - seed=None, k=4, minDivisibleClusterSize=1.0): + seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean"): """ setParams(self, featuresCol="features", predictionCol="prediction", maxIter=20, \ - seed=None, k=4, minDivisibleClusterSize=1.0) + seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean") Sets params for BisectingKMeans. """ kwargs = self._input_kwargs @@ -659,6 +658,20 @@ def getMinDivisibleClusterSize(self): """ return self.getOrDefault(self.minDivisibleClusterSize) + @since("2.4.0") + def setDistanceMeasure(self, value): + """ + Sets the value of :py:attr:`distanceMeasure`. + """ + return self._set(distanceMeasure=value) + + @since("2.4.0") + def getDistanceMeasure(self): + """ + Gets the value of `distanceMeasure` or its default value. + """ + return self.getOrDefault(self.distanceMeasure) + def _create_model(self, java_model): return BisectingKMeansModel(java_model) diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 6e9e0a34cdfde..e45ba840b412b 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -162,7 +162,9 @@ def get$Name(self): "fitting. If set to true, then all sub-models will be available. Warning: For large " + "models, collecting all sub-models can cause OOMs on the Spark driver.", "False", "TypeConverters.toBoolean"), - ("loss", "the loss function to be optimized.", None, "TypeConverters.toString")] + ("loss", "the loss function to be optimized.", None, "TypeConverters.toString"), + ("distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.", + "'euclidean'", "TypeConverters.toString")] code = [] for name, doc, defaultValueStr, typeConverter in shared: diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 08408ee8fbfcc..618f5bf0a8103 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -790,3 +790,27 @@ def getCacheNodeIds(self): """ return self.getOrDefault(self.cacheNodeIds) + +class HasDistanceMeasure(Params): + """ + Mixin for param distanceMeasure: the distance measure. Supported options: 'euclidean' and 'cosine'. + """ + + distanceMeasure = Param(Params._dummy(), "distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.", typeConverter=TypeConverters.toString) + + def __init__(self): + super(HasDistanceMeasure, self).__init__() + self._setDefault(distanceMeasure='euclidean') + + def setDistanceMeasure(self, value): + """ + Sets the value of :py:attr:`distanceMeasure`. + """ + return self._set(distanceMeasure=value) + + def getDistanceMeasure(self): + """ + Gets the value of distanceMeasure or its default value. + """ + return self.getOrDefault(self.distanceMeasure) + From f6e6899a8b8af99cd06e84cae7c69e0fc35bc60a Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 28 Jun 2018 16:25:40 -0700 Subject: [PATCH 1040/2461] [SPARK-24386][SS] coalesce(1) aggregates in continuous processing ## What changes were proposed in this pull request? Provide a continuous processing implementation of coalesce(1), as well as allowing aggregates on top of it. The changes in ContinuousQueuedDataReader and such are to use split.index (the ID of the partition within the RDD currently being compute()d) rather than context.partitionId() (the partition ID of the scheduled task within the Spark job - that is, the post coalesce writer). In the absence of a narrow dependency, these values were previously always the same, so there was no need to distinguish. ## How was this patch tested? new unit test Author: Jose Torres Closes #21560 from jose-torres/coalesce. --- .../UnsupportedOperationChecker.scala | 11 ++ .../datasources/v2/DataSourceV2Strategy.scala | 16 ++- .../continuous/ContinuousCoalesceExec.scala | 51 +++++++ .../continuous/ContinuousCoalesceRDD.scala | 136 ++++++++++++++++++ .../continuous/ContinuousDataSourceRDD.scala | 7 +- .../continuous/ContinuousExecution.scala | 4 + .../ContinuousQueuedDataReader.scala | 6 +- .../shuffle/ContinuousShuffleReadRDD.scala | 10 +- .../shuffle/RPCContinuousShuffleReader.scala | 4 +- .../sources/ContinuousMemoryStream.scala | 11 +- .../ContinuousAggregationSuite.scala | 63 +++++++- .../ContinuousQueuedDataReaderSuite.scala | 2 +- .../shuffle/ContinuousShuffleSuite.scala | 7 +- 13 files changed, 310 insertions(+), 18 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 2bed41672fe33..5ced1ca200daa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -349,6 +349,17 @@ object UnsupportedOperationChecker { _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias | _: TypedFilter) => case node if node.nodeName == "StreamingRelationV2" => + case Repartition(1, false, _) => + case node: Aggregate => + val aboveSinglePartitionCoalesce = node.find { + case Repartition(1, false, _) => true + case _ => false + }.isDefined + + if (!aboveSinglePartitionCoalesce) { + throwError(s"In continuous processing mode, coalesce(1) must be called before " + + s"aggregate operation ${node.nodeName}.") + } case node => throwError(s"Continuous processing does not support ${node.nodeName} operations.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 182aa2906cf1e..2a7f1de2c7c19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -22,11 +22,12 @@ import scala.collection.mutable import org.apache.spark.sql.{sources, Strategy} import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns} +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader object DataSourceV2Strategy extends Strategy { @@ -141,6 +142,17 @@ object DataSourceV2Strategy extends Strategy { case WriteToContinuousDataSource(writer, query) => WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil + case Repartition(1, false, child) => + val isContinuous = child.collectFirst { + case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r + }.isDefined + + if (isContinuous) { + ContinuousCoalesceExec(1, planLater(child)) :: Nil + } else { + Nil + } + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala new file mode 100644 index 0000000000000..5f60343bacfaa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.UUID + +import org.apache.spark.{HashPartitioner, SparkEnv} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.continuous.shuffle.{ContinuousShuffleReadPartition, ContinuousShuffleReadRDD} + +/** + * Physical plan for coalescing a continuous processing plan. + * + * Currently, only coalesces to a single partition are supported. `numPartitions` must be 1. + */ +case class ContinuousCoalesceExec(numPartitions: Int, child: SparkPlan) extends SparkPlan { + override def output: Seq[Attribute] = child.output + + override def children: Seq[SparkPlan] = child :: Nil + + override def outputPartitioning: Partitioning = SinglePartition + + override def doExecute(): RDD[InternalRow] = { + assert(numPartitions == 1) + new ContinuousCoalesceRDD( + sparkContext, + numPartitions, + conf.continuousStreamingExecutorQueueSize, + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_INTERVAL_KEY).toLong, + child.execute()) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala new file mode 100644 index 0000000000000..ba85b355f974f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.UUID + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.continuous.shuffle._ +import org.apache.spark.util.ThreadUtils + +case class ContinuousCoalesceRDDPartition( + index: Int, + endpointName: String, + queueSize: Int, + numShuffleWriters: Int, + epochIntervalMs: Long) + extends Partition { + // Initialized only on the executor, and only once even as we call compute() multiple times. + lazy val (reader: ContinuousShuffleReader, endpoint) = { + val env = SparkEnv.get.rpcEnv + val receiver = new RPCContinuousShuffleReader( + queueSize, numShuffleWriters, epochIntervalMs, env) + val endpoint = env.setupEndpoint(endpointName, receiver) + + TaskContext.get().addTaskCompletionListener { ctx => + env.stop(endpoint) + } + (receiver, endpoint) + } + // This flag will be flipped on the executors to indicate that the threads processing + // partitions of the write-side RDD have been started. These will run indefinitely + // asynchronously as epochs of the coalesce RDD complete on the read side. + private[continuous] var writersInitialized: Boolean = false +} + +/** + * RDD for continuous coalescing. Asynchronously writes all partitions of `prev` into a local + * continuous shuffle, and then reads them in the task thread using `reader`. + */ +class ContinuousCoalesceRDD( + context: SparkContext, + numPartitions: Int, + readerQueueSize: Int, + epochIntervalMs: Long, + prev: RDD[InternalRow]) + extends RDD[InternalRow](context, Nil) { + + // When we support more than 1 target partition, we'll need to figure out how to pass in the + // required partitioner. + private val outputPartitioner = new HashPartitioner(1) + + private val readerEndpointNames = (0 until numPartitions).map { i => + s"ContinuousCoalesceRDD-part$i-${UUID.randomUUID()}" + } + + override def getPartitions: Array[Partition] = { + (0 until numPartitions).map { partIndex => + ContinuousCoalesceRDDPartition( + partIndex, + readerEndpointNames(partIndex), + readerQueueSize, + prev.getNumPartitions, + epochIntervalMs) + }.toArray + } + + private lazy val threadPool = ThreadUtils.newDaemonFixedThreadPool( + prev.getNumPartitions, + this.name) + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + val part = split.asInstanceOf[ContinuousCoalesceRDDPartition] + + if (!part.writersInitialized) { + val rpcEnv = SparkEnv.get.rpcEnv + + // trigger lazy initialization + part.endpoint + val endpointRefs = readerEndpointNames.map { endpointName => + rpcEnv.setupEndpointRef(rpcEnv.address, endpointName) + } + + val runnables = prev.partitions.map { prevSplit => + new Runnable() { + override def run(): Unit = { + TaskContext.setTaskContext(context) + + val writer: ContinuousShuffleWriter = new RPCContinuousShuffleWriter( + prevSplit.index, outputPartitioner, endpointRefs.toArray) + + EpochTracker.initializeCurrentEpoch( + context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong) + while (!context.isInterrupted() && !context.isCompleted()) { + writer.write(prev.compute(prevSplit, context).asInstanceOf[Iterator[UnsafeRow]]) + // Note that current epoch is a non-inheritable thread local, so each writer thread + // can properly increment its own epoch without affecting the main task thread. + EpochTracker.incrementCurrentEpoch() + } + } + } + } + + context.addTaskCompletionListener { ctx => + threadPool.shutdownNow() + } + + part.writersInitialized = true + + runnables.foreach(threadPool.execute) + } + + part.reader.read() + } + + override def clearDependencies(): Unit = { + throw new IllegalStateException("Continuous RDDs cannot be checkpointed") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala index a7ccce10b0cee..73868d5967e90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -51,11 +51,11 @@ class ContinuousDataSourceRDD( sc: SparkContext, dataQueueSize: Int, epochPollIntervalMs: Long, - @transient private val readerFactories: Seq[InputPartition[UnsafeRow]]) + private val readerInputPartitions: Seq[InputPartition[UnsafeRow]]) extends RDD[UnsafeRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { - readerFactories.zipWithIndex.map { + readerInputPartitions.zipWithIndex.map { case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition) }.toArray } @@ -74,8 +74,7 @@ class ContinuousDataSourceRDD( val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition] if (partition.queueReader == null) { partition.queueReader = - new ContinuousQueuedDataReader( - partition.inputPartition, context, dataQueueSize, epochPollIntervalMs) + new ContinuousQueuedDataReader(partition, context, dataQueueSize, epochPollIntervalMs) } partition.queueReader diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index e3d0cea608b2a..a0bb8292d7766 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -216,6 +216,9 @@ class ContinuousExecution( currentEpochCoordinatorId = epochCoordinatorId sparkSessionForQuery.sparkContext.setLocalProperty( ContinuousExecution.EPOCH_COORDINATOR_ID_KEY, epochCoordinatorId) + sparkSessionForQuery.sparkContext.setLocalProperty( + ContinuousExecution.EPOCH_INTERVAL_KEY, + trigger.asInstanceOf[ContinuousTrigger].intervalMs.toString) // Use the parent Spark session for the endpoint since it's where this query ID is registered. val epochEndpoint = @@ -382,4 +385,5 @@ class ContinuousExecution( object ContinuousExecution { val START_EPOCH_KEY = "__continuous_start_epoch" val EPOCH_COORDINATOR_ID_KEY = "__epoch_coordinator_id" + val EPOCH_INTERVAL_KEY = "__continuous_epoch_interval" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala index f38577b6a9f16..8c74b8244d096 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -37,11 +37,11 @@ import org.apache.spark.util.ThreadUtils * offsets across epochs. Each compute() should call the next() method here until null is returned. */ class ContinuousQueuedDataReader( - partition: InputPartition[UnsafeRow], + partition: ContinuousDataSourceRDDPartition, context: TaskContext, dataQueueSize: Int, epochPollIntervalMs: Long) extends Closeable { - private val reader = partition.createPartitionReader() + private val reader = partition.inputPartition.createPartitionReader() // Important sequencing - we must get our starting point before the provider threads start running private var currentOffset: PartitionOffset = @@ -113,7 +113,7 @@ class ContinuousQueuedDataReader( currentEntry match { case EpochMarker => epochCoordEndpoint.send(ReportPartitionOffset( - context.partitionId(), EpochTracker.getCurrentEpoch.get, currentOffset)) + partition.index, EpochTracker.getCurrentEpoch.get, currentOffset)) null case ContinuousRow(row, offset) => currentOffset = offset diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index cf6572d3de1f7..518223f3cd008 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -21,12 +21,14 @@ import java.util.UUID import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.RpcAddress import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.NextIterator case class ContinuousShuffleReadPartition( index: Int, + endpointName: String, queueSize: Int, numShuffleWriters: Int, epochIntervalMs: Long) @@ -36,7 +38,7 @@ case class ContinuousShuffleReadPartition( val env = SparkEnv.get.rpcEnv val receiver = new RPCContinuousShuffleReader( queueSize, numShuffleWriters, epochIntervalMs, env) - val endpoint = env.setupEndpoint(s"RPCContinuousShuffleReader-${UUID.randomUUID()}", receiver) + val endpoint = env.setupEndpoint(endpointName, receiver) TaskContext.get().addTaskCompletionListener { ctx => env.stop(endpoint) @@ -61,12 +63,14 @@ class ContinuousShuffleReadRDD( numPartitions: Int, queueSize: Int = 1024, numShuffleWriters: Int = 1, - epochIntervalMs: Long = 1000) + epochIntervalMs: Long = 1000, + val endpointNames: Seq[String] = Seq(s"RPCContinuousShuffleReader-${UUID.randomUUID()}")) extends RDD[UnsafeRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { (0 until numPartitions).map { partIndex => - ContinuousShuffleReadPartition(partIndex, queueSize, numShuffleWriters, epochIntervalMs) + ContinuousShuffleReadPartition( + partIndex, endpointNames(partIndex), queueSize, numShuffleWriters, epochIntervalMs) }.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala index 834e84675c7d5..502ae0d4822e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala @@ -46,7 +46,7 @@ private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends RPCContin * TODO: Support multiple source tasks. We need to output a single epoch marker once all * source tasks have sent one. */ -private[shuffle] class RPCContinuousShuffleReader( +private[continuous] class RPCContinuousShuffleReader( queueSize: Int, numShuffleWriters: Int, epochIntervalMs: Long, @@ -107,7 +107,7 @@ private[shuffle] class RPCContinuousShuffleReader( } logWarning( s"Completion service failed to make progress after $epochIntervalMs ms. Waiting " + - s"for writers $writerIdsUncommitted to send epoch markers.") + s"for writers ${writerIdsUncommitted.mkString(",")} to send epoch markers.") // The completion service guarantees this future will be available immediately. case future => future.get() match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index d1c3498450096..0bf90b8063326 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -23,12 +23,13 @@ import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ +import scala.collection.SortedMap import scala.collection.mutable.ListBuffer import org.json4s.NoTypeHints import org.json4s.jackson.Serialization -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.{Encoder, Row, SQLContext} import org.apache.spark.sql.execution.streaming._ @@ -184,6 +185,14 @@ class ContinuousMemoryStreamInputPartitionReader( private var currentOffset = startOffset private var current: Option[Row] = None + // Defense-in-depth against failing to propagate the task context. Since it's not inheritable, + // we have to do a bit of error prone work to get it into every thread used by continuous + // processing. We hope that some unit test will end up instantiating a continuous memory stream + // in such cases. + if (TaskContext.get() == null) { + throw new IllegalStateException("Task context was not set!") + } + override def next(): Boolean = { current = getRecord while (current.isEmpty) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala index b7ef637f5270e..0223812600961 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala @@ -31,7 +31,8 @@ class ContinuousAggregationSuite extends ContinuousSuiteBase { testStream(input.toDF().agg(max('value)), OutputMode.Complete)() } - assert(ex.getMessage.contains("Continuous processing does not support Aggregate operations")) + assert(ex.getMessage.contains( + "In continuous processing mode, coalesce(1) must be called before aggregate operation")) } test("basic") { @@ -50,6 +51,66 @@ class ContinuousAggregationSuite extends ContinuousSuiteBase { } } + test("multiple partitions with coalesce") { + val input = ContinuousMemoryStream[Int] + + val df = input.toDF().coalesce(1).agg(max('value)) + + testStream(df, OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + + test("multiple partitions with coalesce - multiple transformations") { + val input = ContinuousMemoryStream[Int] + + // We use a barrier to make sure predicates both before and after coalesce work + val df = input.toDF() + .select('value as 'copy, 'value) + .where('copy =!= 1) + .planWithBarrier + .coalesce(1) + .where('copy =!= 2) + .agg(max('value)) + + testStream(df, OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(0), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + + test("multiple partitions with multiple coalesce") { + val input = ContinuousMemoryStream[Int] + + val df = input.toDF() + .coalesce(1) + .planWithBarrier + .coalesce(1) + .select('value as 'copy, 'value) + .agg(max('value)) + + testStream(df, OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + test("repeated restart") { withSQLConf(("spark.sql.streaming.unsupportedOperationCheck", "false")) { val input = ContinuousMemoryStream.singlePartition[Int] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index e663fa8312da4..0e7e6febb53df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -92,7 +92,7 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { } } val reader = new ContinuousQueuedDataReader( - factory, + new ContinuousDataSourceRDDPartition(0, factory), mockContext, dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize, epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala index a8e3611b585cf..f84f3d49707bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle +import java.util.UUID + import org.apache.spark.{HashPartitioner, Partition, TaskContext, TaskContextImpl} import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} @@ -124,7 +126,10 @@ class ContinuousShuffleSuite extends StreamTest { } test("reader - multiple partitions") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5) + val rdd = new ContinuousShuffleReadRDD( + sparkContext, + numPartitions = 5, + endpointNames = Seq.fill(5)(s"endpt-${UUID.randomUUID()}")) // Send all data before processing to ensure there's no crossover. for (p <- rdd.partitions) { val part = p.asInstanceOf[ContinuousShuffleReadPartition] From f71e8da5efde96aacc89e59c6e27b71fffcbc25f Mon Sep 17 00:00:00 2001 From: xueyu <278006819@qq.com> Date: Fri, 29 Jun 2018 10:44:17 -0700 Subject: [PATCH 1041/2461] [SPARK-24566][CORE] Fix spark.storage.blockManagerSlaveTimeoutMs default config This PR use spark.network.timeout in place of spark.storage.blockManagerSlaveTimeoutMs when it is not configured, as configuration doc said manual test Author: xueyu <278006819@qq.com> Closes #21575 from xueyumusic/slaveTimeOutConfig. --- core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala | 5 ++--- .../cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index ff960b396dbf1..bcbc8df0d5865 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -74,10 +74,9 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) // "spark.network.timeout" uses "seconds", while `spark.storage.blockManagerSlaveTimeoutMs` uses // "milliseconds" - private val slaveTimeoutMs = - sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", "120s") private val executorTimeoutMs = - sc.conf.getTimeAsSeconds("spark.network.timeout", s"${slaveTimeoutMs}ms") * 1000 + sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", + s"${sc.conf.getTimeAsSeconds("spark.network.timeout", "120s")}s") // "spark.network.timeoutInterval" uses "seconds", while // "spark.storage.blockManagerTimeoutIntervalMs" uses "milliseconds" diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index d35bea4aca311..1ce2f816dffb2 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -634,7 +634,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( slave.hostname, externalShufflePort, sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", - s"${sc.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L}ms"), + s"${sc.conf.getTimeAsSeconds("spark.network.timeout", "120s")}s"), sc.conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s")) slave.shuffleRegistered = true } From 03545ce6de08bd0ad685c5f59b73bc22dfc40887 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 30 Jun 2018 13:58:50 +0800 Subject: [PATCH 1042/2461] [SPARK-24638][SQL] StringStartsWith support push down ## What changes were proposed in this pull request? `StringStartsWith` support push down. About 50% savings in compute time. ## How was this patch tested? unit tests, manual tests and performance test: ```scala cat < SPARK-24638.scala def benchmark(func: () => Unit): Long = { val start = System.currentTimeMillis() for(i <- 0 until 100) { func() } val end = System.currentTimeMillis() end - start } val path = "/tmp/spark/parquet/string/" spark.range(10000000).selectExpr("concat(id, 'str', id) as id").coalesce(1).write.mode("overwrite").option("parquet.block.size", 1048576).parquet(path) val df = spark.read.parquet(path) spark.sql("set spark.sql.parquet.filterPushdown.string.startsWith=true") val pushdownEnable = benchmark(() => df.where("id like '999998%'").count()) spark.sql("set spark.sql.parquet.filterPushdown.string.startsWith=false") val pushdownDisable = benchmark(() => df.where("id like '999998%'").count()) val improvements = pushdownDisable - pushdownEnable println(s"improvements: $improvements") EOF bin/spark-shell -i SPARK-24638.scala ``` result: ```scala Loading SPARK-24638.scala... benchmark: (func: () => Unit)Long path: String = /tmp/spark/parquet/string/ df: org.apache.spark.sql.DataFrame = [id: string] res1: org.apache.spark.sql.DataFrame = [key: string, value: string] pushdownEnable: Long = 11608 res2: org.apache.spark.sql.DataFrame = [key: string, value: string] pushdownDisable: Long = 31981 improvements: Long = 20373 ``` Author: Yuming Wang Closes #21623 from wangyum/SPARK-24638. --- .../apache/spark/sql/internal/SQLConf.scala | 11 +++ .../parquet/ParquetFileFormat.scala | 4 +- .../datasources/parquet/ParquetFilters.scala | 35 +++++++- .../parquet/ParquetFilterSuite.scala | 84 ++++++++++++++++++- 4 files changed, 130 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e1752ff997b69..da1c34cdc78f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -378,6 +378,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED = + buildConf("spark.sql.parquet.filterPushdown.string.startsWith") + .doc("If true, enables Parquet filter push-down optimization for string startsWith function. " + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + .internal() + .booleanConf + .createWithDefault(true) + val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") .doc("Whether to be compatible with the legacy Parquet format adopted by Spark 1.4 and prior " + "versions, when converting Parquet schema to Spark SQL schema and vice versa.") @@ -1459,6 +1467,9 @@ class SQLConf extends Serializable with Logging { def parquetFilterPushDownDate: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DATE_ENABLED) + def parquetFilterPushDownStringStartWith: Boolean = + getConf(PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED) + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 9602a08911dea..93de1faef527a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -348,6 +348,7 @@ class ParquetFileFormat // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) val pushDownDate = sqlConf.parquetFilterPushDownDate + val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) @@ -358,7 +359,8 @@ class ParquetFileFormat // Collects all converted Parquet filter predicates. Notice that not all predicates can be // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` // is used here. - .flatMap(new ParquetFilters(pushDownDate).createFilter(requiredSchema, _)) + .flatMap(new ParquetFilters(pushDownDate, pushDownStringStartWith) + .createFilter(requiredSchema, _)) .reduceOption(FilterApi.and) } else { None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 310626197a763..21c9e2e4f82b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -22,16 +22,18 @@ import java.sql.Date import org.apache.parquet.filter2.predicate._ import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.PrimitiveComparator import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate import org.apache.spark.sql.sources import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Some utility function to convert Spark data source filters to Parquet filters. */ -private[parquet] class ParquetFilters(pushDownDate: Boolean) { +private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: Boolean) { private def dateToDays(date: Date): SQLDate = { DateTimeUtils.fromJavaDate(date) @@ -270,6 +272,37 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean) { case sources.Not(pred) => createFilter(schema, pred).map(FilterApi.not) + case sources.StringStartsWith(name, prefix) if pushDownStartWith && canMakeFilterOn(name) => + Option(prefix).map { v => + FilterApi.userDefined(binaryColumn(name), + new UserDefinedPredicate[Binary] with Serializable { + private val strToBinary = Binary.fromReusedByteArray(v.getBytes) + private val size = strToBinary.length + + override def canDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) < 0 || + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) > 0 + } + + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) == 0 && + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) == 0 + } + + override def keep(value: Binary): Boolean = { + UTF8String.fromBytes(value.getBytes).startsWith( + UTF8String.fromBytes(strToBinary.getBytes)) + } + } + ) + } + case _ => None } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 90da7eb8c4fb5..d9ae5858e5ed0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -55,7 +55,8 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} */ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { - private lazy val parquetFilters = new ParquetFilters(conf.parquetFilterPushDownDate) + private lazy val parquetFilters = + new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownStringStartWith) override def beforeEach(): Unit = { super.beforeEach() @@ -82,6 +83,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex withSQLConf( SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", SQLConf.PARQUET_FILTER_PUSHDOWN_DATE_ENABLED.key -> "true", + SQLConf.PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED.key -> "true", SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { val query = df .select(output.map(e => Column(e)): _*) @@ -140,6 +142,31 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) } + // This function tests that exactly go through the `canDrop` and `inverseCanDrop`. + private def testStringStartsWith(dataFrame: DataFrame, filter: String): Unit = { + withTempPath { dir => + val path = dir.getCanonicalPath + dataFrame.write.option("parquet.block.size", 512).parquet(path) + Seq(true, false).foreach { pushDown => + withSQLConf( + SQLConf.PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED.key -> pushDown.toString) { + val accu = new NumRowGroupsAcc + sparkContext.register(accu) + + val df = spark.read.parquet(path).filter(filter) + df.foreachPartition((it: Iterator[Row]) => it.foreach(v => accu.add(0))) + if (pushDown) { + assert(accu.value == 0) + } else { + assert(accu.value > 0) + } + + AccumulatorContext.remove(accu.id) + } + } + } + } + test("filter pushdown - boolean") { withParquetDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) @@ -574,7 +601,6 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex val df = spark.read.parquet(path).filter("a < 100") df.foreachPartition((it: Iterator[Row]) => it.foreach(v => accu.add(0))) - df.collect if (enablePushDown) { assert(accu.value == 0) @@ -660,6 +686,60 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assert(df.where("col > 0").count() === 2) } } + + test("filter pushdown - StringStartsWith") { + withParquetDataFrame((1 to 4).map(i => Tuple1(i + "str" + i))) { implicit df => + checkFilterPredicate( + '_1.startsWith("").asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq("1str1", "2str2", "3str3", "4str4").map(Row(_))) + + Seq("2", "2s", "2st", "2str", "2str2").foreach { prefix => + checkFilterPredicate( + '_1.startsWith(prefix).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + "2str2") + } + + Seq("2S", "null", "2str22").foreach { prefix => + checkFilterPredicate( + '_1.startsWith(prefix).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq.empty[Row]) + } + + checkFilterPredicate( + !'_1.startsWith("").asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq().map(Row(_))) + + Seq("2", "2s", "2st", "2str", "2str2").foreach { prefix => + checkFilterPredicate( + !'_1.startsWith(prefix).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq("1str1", "3str3", "4str4").map(Row(_))) + } + + Seq("2S", "null", "2str22").foreach { prefix => + checkFilterPredicate( + !'_1.startsWith(prefix).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq("1str1", "2str2", "3str3", "4str4").map(Row(_))) + } + + assertResult(None) { + parquetFilters.createFilter( + df.schema, + sources.StringStartsWith("_1", null)) + } + } + + import testImplicits._ + // Test canDrop() has taken effect + testStringStartsWith(spark.range(1024).map(_.toString).toDF(), "value like 'a%'") + // Test inverseCanDrop() has taken effect + testStringStartsWith(spark.range(1024).map(c => "100").toDF(), "value not like '10%'") + } } class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] { From 797971ed42cab41cbc3d039c0af4b26199bff783 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Fri, 29 Jun 2018 23:46:12 -0700 Subject: [PATCH 1043/2461] [SPARK-24696][SQL] ColumnPruning rule fails to remove extra Project ## What changes were proposed in this pull request? The ColumnPruning rule tries adding an extra Project if an input node produces fields more than needed, but as a post-processing step, it needs to remove the lower Project in the form of "Project - Filter - Project" otherwise it would conflict with PushPredicatesThroughProject and would thus cause a infinite optimization loop. The current post-processing method is defined as: ``` private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transform { case p1 Project(_, f Filter(_, p2 Project(_, child))) if p2.outputSet.subsetOf(child.outputSet) => p1.copy(child = f.copy(child = child)) } ``` This method works well when there is only one Filter but would not if there's two or more Filters. In this case, there is a deterministic filter and a non-deterministic filter so they stay as separate filter nodes and cannot be combined together. An simplified illustration of the optimization process that forms the infinite loop is shown below (F1 stands for the 1st filter, F2 for the 2nd filter, P for project, S for scan of relation, PredicatePushDown as abbrev. of PushPredicatesThroughProject): ``` F1 - F2 - P - S PredicatePushDown => F1 - P - F2 - S ColumnPruning => F1 - P - F2 - P - S => F1 - P - F2 - S (Project removed) PredicatePushDown => P - F1 - F2 - S ColumnPruning => P - F1 - P - F2 - S => P - F1 - P - F2 - P - S => P - F1 - F2 - P - S (only one Project removed) RemoveRedundantProject => F1 - F2 - P - S (goes back to the loop start) ``` So the problem is the ColumnPruning rule adds a Project under a Filter (and fails to remove it in the end), and that new Project triggers PushPredicateThroughProject. Once the filters have been push through the Project, a new Project will be added by the ColumnPruning rule and this goes on and on. The fix should be when adding Projects, the rule applies top-down, but later when removing extra Projects, the process should go bottom-up to ensure all extra Projects can be matched. ## How was this patch tested? Added a optimization rule test in ColumnPruningSuite; and a end-to-end test in SQLQuerySuite. Author: maryannxue Closes #21674 from maryannxue/spark-24696. --- .../spark/sql/catalyst/dsl/package.scala | 1 + .../sql/catalyst/optimizer/Optimizer.scala | 5 +++-- .../optimizer/ColumnPruningSuite.scala | 9 +++++++- .../org/apache/spark/sql/SQLQuerySuite.scala | 21 +++++++++++++++++++ 4 files changed, 33 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index efb2eba655e15..8cf69c6f3c922 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -149,6 +149,7 @@ package object dsl { } } + def rand(e: Long): Expression = Rand(Literal.create(e, LongType)) def sum(e: Expression): Expression = Sum(e).toAggregateExpression() def sumDistinct(e: Expression): Expression = Sum(e).toAggregateExpression(isDistinct = true) def count(e: Expression): Expression = Count(e).toAggregateExpression() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index aa992def1ce6c..2cc27d82f7d20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -526,9 +526,10 @@ object ColumnPruning extends Rule[LogicalPlan] { /** * The Project before Filter is not necessary but conflict with PushPredicatesThroughProject, - * so remove it. + * so remove it. Since the Projects have been added top-down, we need to remove in bottom-up + * order, otherwise lower Projects can be missed. */ - private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transform { + private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transformUp { case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child))) if p2.outputSet.subsetOf(child.outputSet) => p1.copy(child = f.copy(child = child)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 3f41f4b144096..8b05ba32e6eef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.optimizer import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -370,5 +369,13 @@ class ColumnPruningSuite extends PlanTest { comparePlans(optimized2, expected2.analyze) } + test("SPARK-24696 ColumnPruning rule fails to remove extra Project") { + val input = LocalRelation('key.int, 'value.string) + val query = input.select('key).where(rand(0L) > 0.5).where('key < 10).analyze + val optimized = Optimize.execute(query) + val expected = input.where(rand(0L) > 0.5).where('key < 10).select('key).analyze + comparePlans(optimized, expected) + } + // todo: add more tests for column pruning } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 640affc10ee58..dfb9c137b74f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2792,4 +2792,25 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-24696 ColumnPruning rule fails to remove extra Project") { + withTable("fact_stats", "dim_stats") { + val factData = Seq((1, 1, 99, 1), (2, 2, 99, 2), (3, 1, 99, 3), (4, 2, 99, 4)) + val storeData = Seq((1, "BW", "DE"), (2, "AZ", "US")) + spark.udf.register("filterND", udf((value: Int) => value > 2).asNondeterministic) + factData.toDF("date_id", "store_id", "product_id", "units_sold") + .write.mode("overwrite").partitionBy("store_id").format("parquet").saveAsTable("fact_stats") + storeData.toDF("store_id", "state_province", "country") + .write.mode("overwrite").format("parquet").saveAsTable("dim_stats") + val df = sql( + """ + |SELECT f.date_id, f.product_id, f.store_id FROM + |(SELECT date_id, product_id, store_id + | FROM fact_stats WHERE filterND(date_id)) AS f + |JOIN dim_stats s + |ON f.store_id = s.store_id WHERE s.country = 'DE' + """.stripMargin) + checkAnswer(df, Seq(Row(3, 99, 1))) + } + } } From d54d8b86301581142293341af25fd78b3278a2e8 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 29 Jun 2018 23:51:13 -0700 Subject: [PATCH 1044/2461] simplify rand in dsl/package.scala --- .../main/scala/org/apache/spark/sql/catalyst/dsl/package.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 8cf69c6f3c922..89e8c998f740d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -149,7 +149,7 @@ package object dsl { } } - def rand(e: Long): Expression = Rand(Literal.create(e, LongType)) + def rand(e: Long): Expression = Rand(e) def sum(e: Expression): Expression = Sum(e).toAggregateExpression() def sumDistinct(e: Expression): Expression = Sum(e).toAggregateExpression(isDistinct = true) def count(e: Expression): Expression = Count(e).toAggregateExpression() From f825847c82042a9eee7bd5cfab106310d279fc32 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 30 Jun 2018 19:27:16 -0500 Subject: [PATCH 1045/2461] [SPARK-24654][BUILD] Update, fix LICENSE and NOTICE, and specialize for source vs binary Whew, lots of work to track down again all the license requirements, but this ought to be a pretty good pass. Below, find a writeup on how I approached it for future reference. - LICENSE and NOTICE and licenses/ now reflect the *source* release - LICENSE-binary and NOTICE-binary and licenses-binary now reflect the binary release - Recreated all the license info from scratch - Added notes about how this was constructed for next time - License-oriented info was moved from NOTICE to LICENSE, esp. for Cat B deps - Some seemingly superfluous or stale license info was removed, especially for test-scope deps - Updated release script to put binary-oriented versions in binary releases ---- # Principles ASF projects distribute source and binary code under the Apache License 2.0. However these project distributions frequently include copies of source or binary code from third parties, under possibly other license terms. This triggers conditions of those licenses, which essentially amount to including license information in a LICENSE and/or NOTICE file, and including copies of license texts (here, in a directory called `license/`). See http://www.apache.org/dev/licensing-howto.html and https://www.apache.org/legal/resolved.html#required-third-party-notices # In Spark Spark produces source releases, and also binary releases of that code. Spark source code may contain source from third parties, possibly modified. This is true in Scala, Java, Python and R, and in the UI's JavaScript and CSS files. These must be handled appropriately per above in a LICENSE and NOTICE file created for the source release. Separately, the binary releases may contain binary code from third parties. This is very much true for Scala and Java, as Spark produces an 'assembly' binary release which includes all transitive binary dependencies of this part of Spark. With perhaps the exception of py4j, this doesn't occur in the same way for Python or R because of the way these ecosystems work. (Note that the JS and CSS for the UI will be in both 'source' and 'binary' releases.) These must also be handled in a separate LICENSE and NOTICE file for the binary release. # Binary Release License ## Transitive Maven Dependencies We'll first tackle the binary release, and that almost entirely means assessing the transitive dependencies of the Scala/Java backbone of Spark. Run `project-info-reports:dependencies` with essentially all profiles: a set that would bring in all different possible transitive dependencies. However, don't activate any of the '-lgpl' profiles as these would bring in LGPL-licensed dependencies that are explicitly excluded from Spark binary releases. ``` mvn -Phadoop-2.7 -Pyarn -Phive -Pmesos -Pkubernetes -Pflume -Pkinesis-asl -Pdocker-integration-tests -Phive-thriftserver -Pkafka-0-8 -Ddependency.locations.enabled=false project-info-reports:dependencies ``` Open `assembly/target/site/dependencies.html`. Find "Project Transitive Dependencies", and find "compile" and "runtime" (if exists). This is a list of all the dependencies that Spark is going to ship in its binary "assembly" distro and therefore whose licenses need to be appropriately considered in LICENSE and NOTICE. Copy this table into a spreadsheet for easy management. Next job is to fill in some blanks, as a few projects will not have clearly declared their licenses in a POM. Sort by license. This is a good time to verify all the dependencies are at least Cat A/B licenses, and not Cat X! http://www.apache.org/legal/resolved.html ### Apache License 2 The Apache License 2 variants are typically easiest to deal with as they will not require you to modify LICENSE, nor add to license/. It's still good form to list the ALv2 dependencies in LICENSE for completeness, but optional. They may require you to propagate bits from NOTICE. It's tedious to track down all the NOTICE files and evaluate what if anything needs to be copied to NOTICE. Fortunately, this can be made easier as the assembly module can be temporarily modified to produce a NOTICE file that concatenates all NOTICE files bundled with transitive dependencies. First change the packaging of `assembly/spark-assembly_2.11/pom.xml` to `jar`. Next add this stanza somewhere in the body of the same POM file: ``` org.apache.maven.plugins maven-shade-plugin false *:* package shade ``` Finally execute `mvn ... package` with all of the same `-P` profile flags as above. In the JAR file at `assembly/target/spark-assembly_2.11....jar` you'll find a file `META-INF/NOTICE` that concatenates all NOTICE files bundled with transitive dependencies. This should be the starting point for the binary release's NOTICE file. Some elements in the file are from Spark itself, like: ``` Spark Project Assembly Copyright 2018 The Apache Software Foundation Spark Project Core Copyright 2018 The Apache Software Foundation ``` These can be removed. Remove elements of the combined NOTICE file that aren't relevant to Spark. It's actually rare that we are sure that some element is completely irrelevant to Spark, because each transitive dependency includes all its transitive dependencies. So there may be nothing that can be done here. Of course, some projects may not publish NOTICE in their Maven artifacts. Ideally, search for the NOTICE file of projects that don't seem to have produced any text in NOTICE, but, there is some argument that projects that don't produce a NOTICE in their Maven artifacts don't entail an obligation on projects that depend solely on their Maven artifacts. ### Other Licenses Next are "Cat A" permissively licensed (BSD 2-Clause, BSD 3-Clause, MIT) components. List the components grouped by their license type in LICENSE. Then add the text of the license to licenses/. For example if you list "foo bar" as a BSD-licensed dependency, add its license text as licenses/LICENSE-foo-bar.txt. Public domain and similar works are treated like permissively licensed dependencies. And the same goes for all Cat B licenses too, like CDDL. However these additional require at least a URL pointer to the project's page. Use the artifact hyperlink in your spreadsheet if possible; if non-existent or doesn't resolve, do your best to determine a URL for the project's source. ### Shaded third-party dependencies Some third party dependencies actually copy in other dependencies rather than depend on them as Maven artifacts. This means they don't show up in the process above. These can be quite hard to track down, but are rare. A key example is reflectasm, embedded in kryo. ### Examples module The above _almost_ considers everything bundled in a Spark binary release. The main assembly won't include examples. The same must be done for dependencies marked as 'compile' for the examples module. See `examples/target/site/dependencies.html`. At the time of this writing however this just adds one dependency: `scopt`. ### provided scope Above we considered just compile and runtime scope dependencies, which makes sense as they are the ones that are packaged. However, for complicated reasons (shading), a few components that Spark does bundle are not marked as compile dependencies in the assembly. Therefore it's also necessary to consider 'provided' dependencies from `assembly/target/site/dependencies.html` actually! Right now that's just Jetty and JPMML artifacts. ## Python, R Don't forget that Py4J is also distributed in the binary release, actually. There should be no other R, Python code in the binary release. That's it. ## Sense checking Compare the contents of `jars/`, `examples/jars/` and `python/lib` from a recent binary release to see if anything appears there that doesn't seem to have been covered above. These additional components will have to be handled manually, but should be few or none of this type. # Source Release License While there are relatively fewer third-party source artifacts included as source code, there is no automated way to detect it, really. It requires some degree of manual auditing. Most third party source comes from included JS and CSS files. At the time of this writing, some places to look or consider: `build/sbt-launch-lib.bash`, `python/lib`, third party source in `python/pyspark` like `heapq3.py`, `docs/js/vendor`, and `core/src/main/resources/org/apache/spark/ui/static`. The principles are the same as above. Remember some JS files copy in other JS files! Look out for Modernizr. # One More Thing: JS and CSS in Binary Release Now that you've got a handle on source licenses, recall that all the JS and CSS source code will *also* be part of the binary release. Copy that info from source to binary license files accordingly. Author: Sean Owen Closes #21640 from srowen/SPARK-24654. --- LICENSE | 158 +-- LICENSE-binary | 520 ++++++++ NOTICE | 661 ---------- NOTICE-binary | 1170 +++++++++++++++++ dev/.rat-excludes | 4 + dev/make-distribution.sh | 7 +- .../LICENSE-AnchorJS.txt | 0 licenses-binary/LICENSE-CC0.txt | 121 ++ .../LICENSE-antlr.txt | 0 licenses-binary/LICENSE-arpack.txt | 8 + licenses-binary/LICENSE-automaton.txt | 24 + licenses-binary/LICENSE-bootstrap.txt | 13 + .../LICENSE-bouncycastle-bcprov.txt | 7 + licenses-binary/LICENSE-cloudpickle.txt | 28 + licenses-binary/LICENSE-d3.min.js.txt | 26 + .../LICENSE-dagre-d3.txt | 4 +- licenses-binary/LICENSE-datatables.txt | 7 + {licenses => licenses-binary}/LICENSE-f2j.txt | 0 licenses-binary/LICENSE-graphlib-dot.txt | 19 + licenses-binary/LICENSE-heapq.txt | 280 ++++ licenses-binary/LICENSE-janino.txt | 31 + licenses-binary/LICENSE-javassist.html | 373 ++++++ .../LICENSE-javolution.txt | 0 .../LICENSE-jline.txt | 0 .../LICENSE-jodd.txt | 12 +- .../LICENSE-join.txt | 0 licenses-binary/LICENSE-jquery.txt | 20 + licenses-binary/LICENSE-json-formatter.txt | 6 + licenses-binary/LICENSE-jtransforms.html | 388 ++++++ .../LICENSE-kryo.txt | 0 licenses-binary/LICENSE-leveldbjni.txt | 27 + licenses-binary/LICENSE-machinist.txt | 19 + .../LICENSE-matchMedia-polyfill.txt | 1 + .../LICENSE-minlog.txt | 0 licenses-binary/LICENSE-modernizr.txt | 21 + .../LICENSE-netlib.txt | 0 .../LICENSE-paranamer.txt | 0 .../LICENSE-pmml-model.txt | 0 .../LICENSE-protobuf.txt | 0 licenses-binary/LICENSE-py4j.txt | 27 + .../LICENSE-pyrolite.txt | 0 .../LICENSE-reflectasm.txt | 0 licenses-binary/LICENSE-respond.txt | 22 + licenses-binary/LICENSE-sbt-launch-lib.txt | 26 + .../LICENSE-scala.txt | 0 licenses-binary/LICENSE-scopt.txt | 9 + .../LICENSE-slf4j.txt | 0 licenses-binary/LICENSE-sorttable.js.txt | 16 + .../LICENSE-spire.txt | 0 licenses-binary/LICENSE-vis.txt | 22 + .../LICENSE-xmlenc.txt | 0 .../LICENSE-zstd-jni.txt | 0 .../LICENSE-zstd.txt | 0 licenses/LICENSE-CC0.txt | 121 ++ licenses/LICENSE-SnapTree.txt | 35 - licenses/LICENSE-bootstrap.txt | 13 + licenses/LICENSE-boto.txt | 20 - licenses/LICENSE-datatables.txt | 7 + licenses/LICENSE-graphlib-dot.txt | 2 +- licenses/LICENSE-jbcrypt.txt | 17 - .../{LICENSE-jmock.txt => LICENSE-join.txt} | 22 +- licenses/LICENSE-jquery.txt | 23 +- licenses/LICENSE-json-formatter.txt | 6 + licenses/LICENSE-matchMedia-polyfill.txt | 1 + licenses/LICENSE-postgresql.txt | 24 - licenses/LICENSE-respond.txt | 22 + licenses/LICENSE-scalacheck.txt | 32 - licenses/LICENSE-vis.txt | 22 + 68 files changed, 3526 insertions(+), 918 deletions(-) create mode 100644 LICENSE-binary create mode 100644 NOTICE-binary rename licenses/LICENSE-scopt.txt => licenses-binary/LICENSE-AnchorJS.txt (100%) create mode 100644 licenses-binary/LICENSE-CC0.txt rename {licenses => licenses-binary}/LICENSE-antlr.txt (100%) create mode 100644 licenses-binary/LICENSE-arpack.txt create mode 100644 licenses-binary/LICENSE-automaton.txt create mode 100644 licenses-binary/LICENSE-bootstrap.txt create mode 100644 licenses-binary/LICENSE-bouncycastle-bcprov.txt create mode 100644 licenses-binary/LICENSE-cloudpickle.txt create mode 100644 licenses-binary/LICENSE-d3.min.js.txt rename licenses/LICENSE-Mockito.txt => licenses-binary/LICENSE-dagre-d3.txt (94%) create mode 100644 licenses-binary/LICENSE-datatables.txt rename {licenses => licenses-binary}/LICENSE-f2j.txt (100%) create mode 100644 licenses-binary/LICENSE-graphlib-dot.txt create mode 100644 licenses-binary/LICENSE-heapq.txt create mode 100644 licenses-binary/LICENSE-janino.txt create mode 100644 licenses-binary/LICENSE-javassist.html rename {licenses => licenses-binary}/LICENSE-javolution.txt (100%) rename {licenses => licenses-binary}/LICENSE-jline.txt (100%) rename licenses/LICENSE-junit-interface.txt => licenses-binary/LICENSE-jodd.txt (69%) rename licenses/LICENSE-DPark.txt => licenses-binary/LICENSE-join.txt (100%) create mode 100644 licenses-binary/LICENSE-jquery.txt create mode 100644 licenses-binary/LICENSE-json-formatter.txt create mode 100644 licenses-binary/LICENSE-jtransforms.html rename {licenses => licenses-binary}/LICENSE-kryo.txt (100%) create mode 100644 licenses-binary/LICENSE-leveldbjni.txt create mode 100644 licenses-binary/LICENSE-machinist.txt create mode 100644 licenses-binary/LICENSE-matchMedia-polyfill.txt rename {licenses => licenses-binary}/LICENSE-minlog.txt (100%) create mode 100644 licenses-binary/LICENSE-modernizr.txt rename {licenses => licenses-binary}/LICENSE-netlib.txt (100%) rename {licenses => licenses-binary}/LICENSE-paranamer.txt (100%) rename licenses/LICENSE-jpmml-model.txt => licenses-binary/LICENSE-pmml-model.txt (100%) rename {licenses => licenses-binary}/LICENSE-protobuf.txt (100%) create mode 100644 licenses-binary/LICENSE-py4j.txt rename {licenses => licenses-binary}/LICENSE-pyrolite.txt (100%) rename {licenses => licenses-binary}/LICENSE-reflectasm.txt (100%) create mode 100644 licenses-binary/LICENSE-respond.txt create mode 100644 licenses-binary/LICENSE-sbt-launch-lib.txt rename {licenses => licenses-binary}/LICENSE-scala.txt (100%) create mode 100644 licenses-binary/LICENSE-scopt.txt rename {licenses => licenses-binary}/LICENSE-slf4j.txt (100%) create mode 100644 licenses-binary/LICENSE-sorttable.js.txt rename {licenses => licenses-binary}/LICENSE-spire.txt (100%) create mode 100644 licenses-binary/LICENSE-vis.txt rename {licenses => licenses-binary}/LICENSE-xmlenc.txt (100%) rename {licenses => licenses-binary}/LICENSE-zstd-jni.txt (100%) rename {licenses => licenses-binary}/LICENSE-zstd.txt (100%) create mode 100644 licenses/LICENSE-CC0.txt delete mode 100644 licenses/LICENSE-SnapTree.txt create mode 100644 licenses/LICENSE-bootstrap.txt delete mode 100644 licenses/LICENSE-boto.txt create mode 100644 licenses/LICENSE-datatables.txt delete mode 100644 licenses/LICENSE-jbcrypt.txt rename licenses/{LICENSE-jmock.txt => LICENSE-join.txt} (60%) create mode 100644 licenses/LICENSE-json-formatter.txt create mode 100644 licenses/LICENSE-matchMedia-polyfill.txt delete mode 100644 licenses/LICENSE-postgresql.txt create mode 100644 licenses/LICENSE-respond.txt delete mode 100644 licenses/LICENSE-scalacheck.txt create mode 100644 licenses/LICENSE-vis.txt diff --git a/LICENSE b/LICENSE index 6f5d9452e800d..b771bd552b762 100644 --- a/LICENSE +++ b/LICENSE @@ -201,103 +201,61 @@ limitations under the License. -======================================================================= -Apache Spark Subcomponents: - -The Apache Spark project contains subcomponents with separate copyright -notices and license terms. Your use of the source code for the these -subcomponents is subject to the terms and conditions of the following -licenses. - - -======================================================================== -For heapq (pyspark/heapq3.py): -======================================================================== - -See license/LICENSE-heapq.txt - -======================================================================== -For SnapTree: -======================================================================== - -See license/LICENSE-SnapTree.txt - -======================================================================== -For jbcrypt: -======================================================================== - -See license/LICENSE-jbcrypt.txt - -======================================================================== -BSD-style licenses -======================================================================== - -The following components are provided under a BSD-style license. See project link for details. -The text of each license is also included at licenses/LICENSE-[project].txt. - - (BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) - (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.2.7 - https://github.com/jpmml/jpmml-model) - (BSD 3 Clause) jmock (org.jmock:jmock-junit4:2.8.4 - http://jmock.org/) - (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) - (BSD License) ANTLR 4.5.2-1 (org.antlr:antlr4:4.5.2-1 - http://wwww.antlr.org/) - (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) - (BSD licence) ANTLR StringTemplate (org.antlr:stringtemplate:3.2.1 - http://www.stringtemplate.org) - (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) - (BSD) JLine (jline:jline:2.14.3 - https://github.com/jline/jline2) - (BSD) ParaNamer Core (com.thoughtworks.paranamer:paranamer:2.3 - http://paranamer.codehaus.org/paranamer) - (BSD) ParaNamer Core (com.thoughtworks.paranamer:paranamer:2.6 - http://paranamer.codehaus.org/paranamer) - (BSD 3 Clause) Scala (http://www.scala-lang.org/download/#License) - (Interpreter classes (all .scala files in repl/src/main/scala - except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala), - and for SerializableMapWrapper in JavaUtils.scala) - (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.12 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.12 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.12 - http://www.scala-lang.org/) - (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.12 - http://www.scala-lang.org/) - (BSD-like) Scalap (org.scala-lang:scalap:2.11.12 - http://www.scala-lang.org/) - (BSD-style) scalacheck (org.scalacheck:scalacheck_2.11:1.10.0 - http://www.scalacheck.org) - (BSD-style) spire (org.spire-math:spire_2.11:0.7.1 - http://spire-math.org) - (BSD-style) spire-macros (org.spire-math:spire-macros_2.11:0.7.1 - http://spire-math.org) - (New BSD License) Kryo (com.esotericsoftware:kryo:3.0.3 - https://github.com/EsotericSoftware/kryo) - (New BSD License) MinLog (com.esotericsoftware:minlog:1.3.0 - https://github.com/EsotericSoftware/minlog) - (New BSD license) Protocol Buffer Java API (com.google.protobuf:protobuf-java:2.5.0 - http://code.google.com/p/protobuf) - (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) - (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) - (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.7 - http://py4j.sourceforge.net/) - (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) - (BSD licence) sbt and sbt-launch-lib.bash - (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) - (BSD 3 Clause) DPark (https://github.com/douban/dpark/blob/master/LICENSE) - (BSD 3 Clause) CloudPickle (https://github.com/cloudpipe/cloudpickle/blob/master/LICENSE) - (BSD 2 Clause) Zstd-jni (https://github.com/luben/zstd-jni/blob/master/LICENSE) - (BSD license) Zstd (https://github.com/facebook/zstd/blob/v1.3.1/LICENSE) - -======================================================================== -MIT licenses -======================================================================== - -The following components are provided under the MIT License. See project link for details. -The text of each license is also included at licenses/LICENSE-[project].txt. - - (MIT License) JCL 1.1.1 implemented over SLF4J (org.slf4j:jcl-over-slf4j:1.7.5 - http://www.slf4j.org) - (MIT License) JUL to SLF4J bridge (org.slf4j:jul-to-slf4j:1.7.5 - http://www.slf4j.org) - (MIT License) SLF4J API Module (org.slf4j:slf4j-api:1.7.5 - http://www.slf4j.org) - (MIT License) SLF4J LOG4J-12 Binding (org.slf4j:slf4j-log4j12:1.7.5 - http://www.slf4j.org) - (MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/) - (MIT License) scopt (com.github.scopt:scopt_2.11:3.2.0 - https://github.com/scopt/scopt) - (The MIT License) Mockito (org.mockito:mockito-core:1.9.5 - http://www.mockito.org) - (MIT License) jquery (https://jquery.org/license/) - (MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs) - (MIT License) graphlib-dot (https://github.com/cpettitt/graphlib-dot) - (MIT License) dagre-d3 (https://github.com/cpettitt/dagre-d3) - (MIT License) sorttable (https://github.com/stuartlangridge/sorttable) - (MIT License) boto (https://github.com/boto/boto/blob/develop/LICENSE) - (MIT License) datatables (http://datatables.net/license) - (MIT License) mustache (https://github.com/mustache/mustache/blob/master/LICENSE) - (MIT License) cookies (http://code.google.com/p/cookies/wiki/License) - (MIT License) blockUI (http://jquery.malsup.com/block/) - (MIT License) RowsGroup (http://datatables.net/license/mit) - (MIT License) jsonFormatter (http://www.jqueryscript.net/other/jQuery-Plugin-For-Pretty-JSON-Formatting-jsonFormatter.html) - (MIT License) modernizr (https://github.com/Modernizr/Modernizr/blob/master/LICENSE) - (MIT License) machinist (https://github.com/typelevel/machinist) +------------------------------------------------------------------------------------ +This product bundles various third-party components under other open source licenses. +This section summarizes those components and their licenses. See licenses/ +for text of these licenses. + + +Apache Software Foundation License 2.0 +-------------------------------------- + +common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java +core/src/main/java/org/apache/spark/util/collection/TimSort.java +core/src/main/resources/org/apache/spark/ui/static/bootstrap* +core/src/main/resources/org/apache/spark/ui/static/jsonFormatter* +core/src/main/resources/org/apache/spark/ui/static/vis* +docs/js/vendor/bootstrap.js + + +Python Software Foundation License +---------------------------------- + +pyspark/heapq3.py + + +BSD 3-Clause +------------ + +python/lib/py4j-*-src.zip +python/pyspark/cloudpickle.py +python/pyspark/join.py +core/src/main/resources/org/apache/spark/ui/static/d3.min.js + +The CSS style for the navigation sidebar of the documentation was originally +submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project +is distributed under the 3-Clause BSD license. + + +MIT License +----------- + +core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js +core/src/main/resources/org/apache/spark/ui/static/*dataTables* +core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js +ore/src/main/resources/org/apache/spark/ui/static/jquery* +core/src/main/resources/org/apache/spark/ui/static/sorttable.js +docs/js/vendor/anchor.min.js +docs/js/vendor/jquery* +docs/js/vendor/modernizer* + + +Creative Commons CC0 1.0 Universal Public Domain Dedication +----------------------------------------------------------- +(see LICENSE-CC0.txt) + +data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg +data/mllib/images/kittens/54893.jpg +data/mllib/images/kittens/DP153539.jpg +data/mllib/images/kittens/DP802813.jpg +data/mllib/images/multi-channel/chr30.4.184.jpg \ No newline at end of file diff --git a/LICENSE-binary b/LICENSE-binary new file mode 100644 index 0000000000000..c033dd8ad2e6a --- /dev/null +++ b/LICENSE-binary @@ -0,0 +1,520 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +------------------------------------------------------------------------------------ +This project bundles some components that are also licensed under the Apache +License Version 2.0: + +commons-beanutils:commons-beanutils +org.apache.zookeeper:zookeeper +oro:oro +commons-configuration:commons-configuration +commons-digester:commons-digester +com.chuusai:shapeless_2.11 +com.googlecode.javaewah:JavaEWAH +com.twitter:chill-java +com.twitter:chill_2.11 +com.univocity:univocity-parsers +javax.jdo:jdo-api +joda-time:joda-time +net.sf.opencsv:opencsv +org.apache.derby:derby +org.objenesis:objenesis +org.roaringbitmap:RoaringBitmap +org.scalanlp:breeze-macros_2.11 +org.scalanlp:breeze_2.11 +org.typelevel:macro-compat_2.11 +org.yaml:snakeyaml +org.apache.xbean:xbean-asm5-shaded +com.squareup.okhttp3:logging-interceptor +com.squareup.okhttp3:okhttp +com.squareup.okio:okio +net.java.dev.jets3t:jets3t +org.apache.spark:spark-catalyst_2.11 +org.apache.spark:spark-kvstore_2.11 +org.apache.spark:spark-launcher_2.11 +org.apache.spark:spark-mllib-local_2.11 +org.apache.spark:spark-network-common_2.11 +org.apache.spark:spark-network-shuffle_2.11 +org.apache.spark:spark-sketch_2.11 +org.apache.spark:spark-tags_2.11 +org.apache.spark:spark-unsafe_2.11 +commons-httpclient:commons-httpclient +com.vlkan:flatbuffers +com.ning:compress-lzf +io.airlift:aircompressor +io.dropwizard.metrics:metrics-core +io.dropwizard.metrics:metrics-ganglia +io.dropwizard.metrics:metrics-graphite +io.dropwizard.metrics:metrics-json +io.dropwizard.metrics:metrics-jvm +org.iq80.snappy:snappy +com.clearspring.analytics:stream +com.jamesmurty.utils:java-xmlbuilder +commons-codec:commons-codec +commons-collections:commons-collections +io.fabric8:kubernetes-client +io.fabric8:kubernetes-model +io.netty:netty +io.netty:netty-all +net.hydromatic:eigenbase-properties +net.sf.supercsv:super-csv +org.apache.arrow:arrow-format +org.apache.arrow:arrow-memory +org.apache.arrow:arrow-vector +org.apache.calcite:calcite-avatica +org.apache.calcite:calcite-core +org.apache.calcite:calcite-linq4j +org.apache.commons:commons-crypto +org.apache.commons:commons-lang3 +org.apache.hadoop:hadoop-annotations +org.apache.hadoop:hadoop-auth +org.apache.hadoop:hadoop-client +org.apache.hadoop:hadoop-common +org.apache.hadoop:hadoop-hdfs +org.apache.hadoop:hadoop-mapreduce-client-app +org.apache.hadoop:hadoop-mapreduce-client-common +org.apache.hadoop:hadoop-mapreduce-client-core +org.apache.hadoop:hadoop-mapreduce-client-jobclient +org.apache.hadoop:hadoop-mapreduce-client-shuffle +org.apache.hadoop:hadoop-yarn-api +org.apache.hadoop:hadoop-yarn-client +org.apache.hadoop:hadoop-yarn-common +org.apache.hadoop:hadoop-yarn-server-common +org.apache.hadoop:hadoop-yarn-server-web-proxy +org.apache.httpcomponents:httpclient +org.apache.httpcomponents:httpcore +org.apache.orc:orc-core +org.apache.orc:orc-mapreduce +org.mortbay.jetty:jetty +org.mortbay.jetty:jetty-util +com.jolbox:bonecp +org.json4s:json4s-ast_2.11 +org.json4s:json4s-core_2.11 +org.json4s:json4s-jackson_2.11 +org.json4s:json4s-scalap_2.11 +com.carrotsearch:hppc +com.fasterxml.jackson.core:jackson-annotations +com.fasterxml.jackson.core:jackson-core +com.fasterxml.jackson.core:jackson-databind +com.fasterxml.jackson.dataformat:jackson-dataformat-yaml +com.fasterxml.jackson.module:jackson-module-jaxb-annotations +com.fasterxml.jackson.module:jackson-module-paranamer +com.fasterxml.jackson.module:jackson-module-scala_2.11 +com.github.mifmif:generex +com.google.code.findbugs:jsr305 +com.google.code.gson:gson +com.google.inject:guice +com.google.inject.extensions:guice-servlet +com.twitter:parquet-hadoop-bundle +commons-beanutils:commons-beanutils-core +commons-cli:commons-cli +commons-dbcp:commons-dbcp +commons-io:commons-io +commons-lang:commons-lang +commons-logging:commons-logging +commons-net:commons-net +commons-pool:commons-pool +io.fabric8:zjsonpatch +javax.inject:javax.inject +javax.validation:validation-api +log4j:apache-log4j-extras +log4j:log4j +net.sf.jpam:jpam +org.apache.avro:avro +org.apache.avro:avro-ipc +org.apache.avro:avro-mapred +org.apache.commons:commons-compress +org.apache.commons:commons-math3 +org.apache.curator:curator-client +org.apache.curator:curator-framework +org.apache.curator:curator-recipes +org.apache.directory.api:api-asn1-api +org.apache.directory.api:api-util +org.apache.directory.server:apacheds-i18n +org.apache.directory.server:apacheds-kerberos-codec +org.apache.htrace:htrace-core +org.apache.ivy:ivy +org.apache.mesos:mesos +org.apache.parquet:parquet-column +org.apache.parquet:parquet-common +org.apache.parquet:parquet-encoding +org.apache.parquet:parquet-format +org.apache.parquet:parquet-hadoop +org.apache.parquet:parquet-jackson +org.apache.thrift:libfb303 +org.apache.thrift:libthrift +org.codehaus.jackson:jackson-core-asl +org.codehaus.jackson:jackson-mapper-asl +org.datanucleus:datanucleus-api-jdo +org.datanucleus:datanucleus-core +org.datanucleus:datanucleus-rdbms +org.lz4:lz4-java +org.spark-project.hive:hive-beeline +org.spark-project.hive:hive-cli +org.spark-project.hive:hive-exec +org.spark-project.hive:hive-jdbc +org.spark-project.hive:hive-metastore +org.xerial.snappy:snappy-java +stax:stax-api +xerces:xercesImpl +org.codehaus.jackson:jackson-jaxrs +org.codehaus.jackson:jackson-xc +org.eclipse.jetty:jetty-client +org.eclipse.jetty:jetty-continuation +org.eclipse.jetty:jetty-http +org.eclipse.jetty:jetty-io +org.eclipse.jetty:jetty-jndi +org.eclipse.jetty:jetty-plus +org.eclipse.jetty:jetty-proxy +org.eclipse.jetty:jetty-security +org.eclipse.jetty:jetty-server +org.eclipse.jetty:jetty-servlet +org.eclipse.jetty:jetty-servlets +org.eclipse.jetty:jetty-util +org.eclipse.jetty:jetty-webapp +org.eclipse.jetty:jetty-xml + +core/src/main/java/org/apache/spark/util/collection/TimSort.java +core/src/main/resources/org/apache/spark/ui/static/bootstrap* +core/src/main/resources/org/apache/spark/ui/static/jsonFormatter* +core/src/main/resources/org/apache/spark/ui/static/vis* +docs/js/vendor/bootstrap.js + + +------------------------------------------------------------------------------------ +This product bundles various third-party components under other open source licenses. +This section summarizes those components and their licenses. See licenses-binary/ +for text of these licenses. + + +BSD 2-Clause +------------ + +com.github.luben:zstd-jni +javolution:javolution +com.esotericsoftware:kryo-shaded +com.esotericsoftware:minlog +com.esotericsoftware:reflectasm +com.google.protobuf:protobuf-java +org.codehaus.janino:commons-compiler +org.codehaus.janino:janino +jline:jline +org.jodd:jodd-core + + +BSD 3-Clause +------------ + +dk.brics.automaton:automaton +org.antlr:antlr-runtime +org.antlr:ST4 +org.antlr:stringtemplate +org.antlr:antlr4-runtime +antlr:antlr +com.github.fommil.netlib:core +com.thoughtworks.paranamer:paranamer +org.scala-lang:scala-compiler +org.scala-lang:scala-library +org.scala-lang:scala-reflect +org.scala-lang.modules:scala-parser-combinators_2.11 +org.scala-lang.modules:scala-xml_2.11 +org.fusesource.leveldbjni:leveldbjni-all +net.sourceforge.f2j:arpack_combined_all +xmlenc:xmlenc +net.sf.py4j:py4j +org.jpmml:pmml-model +org.jpmml:pmml-schema + +python/lib/py4j-*-src.zip +python/pyspark/cloudpickle.py +python/pyspark/join.py +core/src/main/resources/org/apache/spark/ui/static/d3.min.js + +The CSS style for the navigation sidebar of the documentation was originally +submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project +is distributed under the 3-Clause BSD license. + + +MIT License +----------- + +org.spire-math:spire-macros_2.11 +org.spire-math:spire_2.11 +org.typelevel:machinist_2.11 +net.razorvine:pyrolite +org.slf4j:jcl-over-slf4j +org.slf4j:jul-to-slf4j +org.slf4j:slf4j-api +org.slf4j:slf4j-log4j12 +com.github.scopt:scopt_2.11 +org.bouncycastle:bcprov-jdk15on + +core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js +core/src/main/resources/org/apache/spark/ui/static/*dataTables* +core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js +ore/src/main/resources/org/apache/spark/ui/static/jquery* +core/src/main/resources/org/apache/spark/ui/static/sorttable.js +docs/js/vendor/anchor.min.js +docs/js/vendor/jquery* +docs/js/vendor/modernizer* + + +Common Development and Distribution License (CDDL) 1.0 +------------------------------------------------------ + +javax.activation:activation http://www.oracle.com/technetwork/java/javase/tech/index-jsp-138795.html +javax.xml.stream:stax-api https://jcp.org/en/jsr/detail?id=173 + + +Common Development and Distribution License (CDDL) 1.1 +------------------------------------------------------ + +javax.annotation:javax.annotation-api https://jcp.org/en/jsr/detail?id=250 +javax.servlet:javax.servlet-api https://javaee.github.io/servlet-spec/ +javax.transaction:jta http://www.oracle.com/technetwork/java/index.html +javax.ws.rs:javax.ws.rs-api https://github.com/jax-rs +javax.xml.bind:jaxb-api https://github.com/javaee/jaxb-v2 +org.glassfish.hk2:hk2-api https://github.com/javaee/glassfish +org.glassfish.hk2:hk2-locator (same) +org.glassfish.hk2:hk2-utils +org.glassfish.hk2:osgi-resource-locator +org.glassfish.hk2.external:aopalliance-repackaged +org.glassfish.hk2.external:javax.inject +org.glassfish.jersey.bundles.repackaged:jersey-guava +org.glassfish.jersey.containers:jersey-container-servlet +org.glassfish.jersey.containers:jersey-container-servlet-core +org.glassfish.jersey.core:jersey-client +org.glassfish.jersey.core:jersey-common +org.glassfish.jersey.core:jersey-server +org.glassfish.jersey.media:jersey-media-jaxb + + +Mozilla Public License (MPL) 1.1 +-------------------------------- + +com.github.rwl:jtransforms https://sourceforge.net/projects/jtransforms/ + + +Python Software Foundation License +---------------------------------- + +pyspark/heapq3.py + + +Public Domain +------------- + +aopalliance:aopalliance +net.iharder:base64 +org.tukaani:xz + + +Creative Commons CC0 1.0 Universal Public Domain Dedication +----------------------------------------------------------- +(see LICENSE-CC0.txt) + +data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg +data/mllib/images/kittens/54893.jpg +data/mllib/images/kittens/DP153539.jpg +data/mllib/images/kittens/DP802813.jpg +data/mllib/images/multi-channel/chr30.4.184.jpg diff --git a/NOTICE b/NOTICE index 6ec240efbf12e..9246cc54caa3a 100644 --- a/NOTICE +++ b/NOTICE @@ -4,664 +4,3 @@ Copyright 2014 and onwards The Apache Software Foundation. This product includes software developed at The Apache Software Foundation (http://www.apache.org/). - -======================================================================== -Common Development and Distribution License 1.0 -======================================================================== - -The following components are provided under the Common Development and Distribution License 1.0. See project link for details. - - (CDDL 1.0) Glassfish Jasper (org.mortbay.jetty:jsp-2.1:6.1.14 - http://jetty.mortbay.org/project/modules/jsp-2.1) - (CDDL 1.0) JAX-RS (https://jax-rs-spec.java.net/) - (CDDL 1.0) Servlet Specification 2.5 API (org.mortbay.jetty:servlet-api-2.5:6.1.14 - http://jetty.mortbay.org/project/modules/servlet-api-2.5) - (CDDL 1.0) (GPL2 w/ CPE) javax.annotation API (https://glassfish.java.net/nonav/public/CDDL+GPL.html) - (COMMON DEVELOPMENT AND DISTRIBUTION LICENSE (CDDL) Version 1.0) (GNU General Public Library) Streaming API for XML (javax.xml.stream:stax-api:1.0-2 - no url defined) - (Common Development and Distribution License (CDDL) v1.0) JavaBeans Activation Framework (JAF) (javax.activation:activation:1.1 - http://java.sun.com/products/javabeans/jaf/index.jsp) - -======================================================================== -Common Development and Distribution License 1.1 -======================================================================== - -The following components are provided under the Common Development and Distribution License 1.1. See project link for details. - - (CDDL 1.1) (GPL2 w/ CPE) org.glassfish.hk2 (https://hk2.java.net) - (CDDL 1.1) (GPL2 w/ CPE) JAXB API bundle for GlassFish V3 (javax.xml.bind:jaxb-api:2.2.2 - https://jaxb.dev.java.net/) - (CDDL 1.1) (GPL2 w/ CPE) JAXB RI (com.sun.xml.bind:jaxb-impl:2.2.3-1 - http://jaxb.java.net/) - (CDDL 1.1) (GPL2 w/ CPE) Jersey 2 (https://jersey.java.net) - -======================================================================== -Common Public License 1.0 -======================================================================== - -The following components are provided under the Common Public 1.0 License. See project link for details. - - (Common Public License Version 1.0) JUnit (junit:junit-dep:4.10 - http://junit.org) - (Common Public License Version 1.0) JUnit (junit:junit:3.8.1 - http://junit.org) - (Common Public License Version 1.0) JUnit (junit:junit:4.8.2 - http://junit.org) - -======================================================================== -Eclipse Public License 1.0 -======================================================================== - -The following components are provided under the Eclipse Public License 1.0. See project link for details. - - (Eclipse Public License v1.0) Eclipse JDT Core (org.eclipse.jdt:core:3.1.1 - http://www.eclipse.org/jdt/) - -======================================================================== -Mozilla Public License 1.0 -======================================================================== - -The following components are provided under the Mozilla Public License 1.0. See project link for details. - - (GPL) (LGPL) (MPL) JTransforms (com.github.rwl:jtransforms:2.4.0 - http://sourceforge.net/projects/jtransforms/) - (Mozilla Public License Version 1.1) jamon-runtime (org.jamon:jamon-runtime:2.3.1 - http://www.jamon.org/jamon-runtime/) - - - -======================================================================== -NOTICE files -======================================================================== - -The following NOTICEs are pertain to software distributed with this project. - - -// ------------------------------------------------------------------ -// NOTICE file corresponding to the section 4d of The Apache License, -// Version 2.0, in this case for -// ------------------------------------------------------------------ - -Apache Avro -Copyright 2009-2013 The Apache Software Foundation - -This product includes software developed at -The Apache Software Foundation (http://www.apache.org/). - -Apache Commons Codec -Copyright 2002-2009 The Apache Software Foundation - -This product includes software developed by -The Apache Software Foundation (http://www.apache.org/). - --------------------------------------------------------------------------------- -src/test/org/apache/commons/codec/language/DoubleMetaphoneTest.java contains -test data from http://aspell.sourceforge.net/test/batch0.tab. - -Copyright (C) 2002 Kevin Atkinson (kevina@gnu.org). Verbatim copying -and distribution of this entire article is permitted in any medium, -provided this notice is preserved. --------------------------------------------------------------------------------- - -Apache HttpComponents HttpClient -Copyright 1999-2011 The Apache Software Foundation - -This project contains annotations derived from JCIP-ANNOTATIONS -Copyright (c) 2005 Brian Goetz and Tim Peierls. See http://www.jcip.net - -Apache HttpComponents HttpCore -Copyright 2005-2011 The Apache Software Foundation - -Curator Recipes -Copyright 2011-2014 The Apache Software Foundation - -Curator Framework -Copyright 2011-2014 The Apache Software Foundation - -Curator Client -Copyright 2011-2014 The Apache Software Foundation - -Apache Geronimo -Copyright 2003-2008 The Apache Software Foundation - -Activation 1.1 -Copyright 2003-2007 The Apache Software Foundation - -Apache Commons Lang -Copyright 2001-2014 The Apache Software Foundation - -This product includes software from the Spring Framework, -under the Apache License 2.0 (see: StringUtils.containsWhitespace()) - -Apache log4j -Copyright 2007 The Apache Software Foundation - -# Compress LZF - -This library contains efficient implementation of LZF compression format, -as well as additional helper classes that build on JDK-provided gzip (deflat) -codec. - -## Licensing - -Library is licensed under Apache License 2.0, as per accompanying LICENSE file. - -## Credit - -Library has been written by Tatu Saloranta (tatu.saloranta@iki.fi). -It was started at Ning, inc., as an official Open Source process used by -platform backend, but after initial versions has been developed outside of -Ning by supporting community. - -Other contributors include: - -* Jon Hartlaub (first versions of streaming reader/writer; unit tests) -* Cedrik Lime: parallel LZF implementation - -Various community members have contributed bug reports, and suggested minor -fixes; these can be found from file "VERSION.txt" in SCM. - -Objenesis -Copyright 2006-2009 Joe Walnes, Henri Tremblay, Leonardo Mesquita - -Apache Commons Net -Copyright 2001-2010 The Apache Software Foundation - - The Netty Project - ================= - -Please visit the Netty web site for more information: - - * http://netty.io/ - -Copyright 2011 The Netty Project - -The Netty Project licenses this file to you under the Apache License, -version 2.0 (the "License"); you may not use this file except in compliance -with the License. You may obtain a copy of the License at: - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -License for the specific language governing permissions and limitations -under the License. - -Also, please refer to each LICENSE..txt file, which is located in -the 'license' directory of the distribution file, for the license terms of the -components that this product depends on. - -------------------------------------------------------------------------------- -This product contains the extensions to Java Collections Framework which has -been derived from the works by JSR-166 EG, Doug Lea, and Jason T. Greene: - - * LICENSE: - * license/LICENSE.jsr166y.txt (Public Domain) - * HOMEPAGE: - * http://gee.cs.oswego.edu/cgi-bin/viewcvs.cgi/jsr166/ - * http://viewvc.jboss.org/cgi-bin/viewvc.cgi/jbosscache/experimental/jsr166/ - -This product contains a modified version of Robert Harder's Public Domain -Base64 Encoder and Decoder, which can be obtained at: - - * LICENSE: - * license/LICENSE.base64.txt (Public Domain) - * HOMEPAGE: - * http://iharder.sourceforge.net/current/java/base64/ - -This product contains a modified version of 'JZlib', a re-implementation of -zlib in pure Java, which can be obtained at: - - * LICENSE: - * license/LICENSE.jzlib.txt (BSD Style License) - * HOMEPAGE: - * http://www.jcraft.com/jzlib/ - -This product optionally depends on 'Protocol Buffers', Google's data -interchange format, which can be obtained at: - - * LICENSE: - * license/LICENSE.protobuf.txt (New BSD License) - * HOMEPAGE: - * http://code.google.com/p/protobuf/ - -This product optionally depends on 'SLF4J', a simple logging facade for Java, -which can be obtained at: - - * LICENSE: - * license/LICENSE.slf4j.txt (MIT License) - * HOMEPAGE: - * http://www.slf4j.org/ - -This product optionally depends on 'Apache Commons Logging', a logging -framework, which can be obtained at: - - * LICENSE: - * license/LICENSE.commons-logging.txt (Apache License 2.0) - * HOMEPAGE: - * http://commons.apache.org/logging/ - -This product optionally depends on 'Apache Log4J', a logging framework, -which can be obtained at: - - * LICENSE: - * license/LICENSE.log4j.txt (Apache License 2.0) - * HOMEPAGE: - * http://logging.apache.org/log4j/ - -This product optionally depends on 'JBoss Logging', a logging framework, -which can be obtained at: - - * LICENSE: - * license/LICENSE.jboss-logging.txt (GNU LGPL 2.1) - * HOMEPAGE: - * http://anonsvn.jboss.org/repos/common/common-logging-spi/ - -This product optionally depends on 'Apache Felix', an open source OSGi -framework implementation, which can be obtained at: - - * LICENSE: - * license/LICENSE.felix.txt (Apache License 2.0) - * HOMEPAGE: - * http://felix.apache.org/ - -This product optionally depends on 'Webbit', a Java event based -WebSocket and HTTP server: - - * LICENSE: - * license/LICENSE.webbit.txt (BSD License) - * HOMEPAGE: - * https://github.com/joewalnes/webbit - -# Jackson JSON processor - -Jackson is a high-performance, Free/Open Source JSON processing library. -It was originally written by Tatu Saloranta (tatu.saloranta@iki.fi), and has -been in development since 2007. -It is currently developed by a community of developers, as well as supported -commercially by FasterXML.com. - -Jackson core and extension components may be licensed under different licenses. -To find the details that apply to this artifact see the accompanying LICENSE file. -For more information, including possible other licensing options, contact -FasterXML.com (http://fasterxml.com). - -## Credits - -A list of contributors may be found from CREDITS file, which is included -in some artifacts (usually source distributions); but is always available -from the source code management (SCM) system project uses. - -Jackson core and extension components may licensed under different licenses. -To find the details that apply to this artifact see the accompanying LICENSE file. -For more information, including possible other licensing options, contact -FasterXML.com (http://fasterxml.com). - -mesos -Copyright 2014 The Apache Software Foundation - -Apache Thrift -Copyright 2006-2010 The Apache Software Foundation. - - Apache Ant - Copyright 1999-2013 The Apache Software Foundation - - The task is based on code Copyright (c) 2002, Landmark - Graphics Corp that has been kindly donated to the Apache Software - Foundation. - -Apache Commons IO -Copyright 2002-2012 The Apache Software Foundation - -Apache Commons Math -Copyright 2001-2013 The Apache Software Foundation - -=============================================================================== - -The inverse error function implementation in the Erf class is based on CUDA -code developed by Mike Giles, Oxford-Man Institute of Quantitative Finance, -and published in GPU Computing Gems, volume 2, 2010. -=============================================================================== - -The BracketFinder (package org.apache.commons.math3.optimization.univariate) -and PowellOptimizer (package org.apache.commons.math3.optimization.general) -classes are based on the Python code in module "optimize.py" (version 0.5) -developed by Travis E. Oliphant for the SciPy library (http://www.scipy.org/) -Copyright © 2003-2009 SciPy Developers. -=============================================================================== - -The LinearConstraint, LinearObjectiveFunction, LinearOptimizer, -RelationShip, SimplexSolver and SimplexTableau classes in package -org.apache.commons.math3.optimization.linear include software developed by -Benjamin McCann (http://www.benmccann.com) and distributed with -the following copyright: Copyright 2009 Google Inc. -=============================================================================== - -This product includes software developed by the -University of Chicago, as Operator of Argonne National -Laboratory. -The LevenbergMarquardtOptimizer class in package -org.apache.commons.math3.optimization.general includes software -translated from the lmder, lmpar and qrsolv Fortran routines -from the Minpack package -Minpack Copyright Notice (1999) University of Chicago. All rights reserved -=============================================================================== - -The GraggBulirschStoerIntegrator class in package -org.apache.commons.math3.ode.nonstiff includes software translated -from the odex Fortran routine developed by E. Hairer and G. Wanner. -Original source copyright: -Copyright (c) 2004, Ernst Hairer -=============================================================================== - -The EigenDecompositionImpl class in package -org.apache.commons.math3.linear includes software translated -from some LAPACK Fortran routines. Original source copyright: -Copyright (c) 1992-2008 The University of Tennessee. All rights reserved. -=============================================================================== - -The MersenneTwister class in package org.apache.commons.math3.random -includes software translated from the 2002-01-26 version of -the Mersenne-Twister generator written in C by Makoto Matsumoto and Takuji -Nishimura. Original source copyright: -Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, -All rights reserved -=============================================================================== - -The LocalizedFormatsTest class in the unit tests is an adapted version of -the OrekitMessagesTest class from the orekit library distributed under the -terms of the Apache 2 licence. Original source copyright: -Copyright 2010 CS Systèmes d'Information -=============================================================================== - -The HermiteInterpolator class and its corresponding test have been imported from -the orekit library distributed under the terms of the Apache 2 licence. Original -source copyright: -Copyright 2010-2012 CS Systèmes d'Information -=============================================================================== - -The creation of the package "o.a.c.m.analysis.integration.gauss" was inspired -by an original code donated by Sébastien Brisard. -=============================================================================== - -The complete text of licenses and disclaimers associated with the the original -sources enumerated above at the time of code translation are in the LICENSE.txt -file. - -This product currently only contains code developed by authors -of specific components, as identified by the source code files; -if such notes are missing files have been created by -Tatu Saloranta. - -For additional credits (generally to people who reported problems) -see CREDITS file. - -Apache Commons Lang -Copyright 2001-2011 The Apache Software Foundation - -Apache Commons Compress -Copyright 2002-2012 The Apache Software Foundation - -Apache Commons CLI -Copyright 2001-2009 The Apache Software Foundation - -Google Guice - Extensions - Servlet -Copyright 2006-2011 Google, Inc. - -Google Guice - Core Library -Copyright 2006-2011 Google, Inc. - -Apache Jakarta HttpClient -Copyright 1999-2007 The Apache Software Foundation - -Apache Hive -Copyright 2008-2013 The Apache Software Foundation - -This product includes software developed by The Apache Software -Foundation (http://www.apache.org/). - -This product includes software developed by The JDBM Project -(http://jdbm.sourceforge.net/). - -This product includes/uses ANTLR (http://www.antlr.org/), -Copyright (c) 2003-2011, Terrence Parr. - -This product includes/uses StringTemplate (http://www.stringtemplate.org/), -Copyright (c) 2011, Terrence Parr. - -This product includes/uses ASM (http://asm.ow2.org/), -Copyright (c) 2000-2007 INRIA, France Telecom. - -This product includes/uses JLine (http://jline.sourceforge.net/), -Copyright (c) 2002-2006, Marc Prud'hommeaux . - -This product includes/uses SQLLine (http://sqlline.sourceforge.net), -Copyright (c) 2002, 2003, 2004, 2005 Marc Prud'hommeaux . - -This product includes/uses SLF4J (http://www.slf4j.org/), -Copyright (c) 2004-2010 QOS.ch - -This product includes/uses Bootstrap (http://twitter.github.com/bootstrap/), -Copyright (c) 2012 Twitter, Inc. - -This product includes/uses Glyphicons (http://glyphicons.com/), -Copyright (c) 2010 - 2012 Jan Kovarík - -This product includes DataNucleus (http://www.datanucleus.org/) -Copyright 2008-2008 DataNucleus - -This product includes Guava (http://code.google.com/p/guava-libraries/) -Copyright (C) 2006 Google Inc. - -This product includes JavaEWAH (http://code.google.com/p/javaewah/) -Copyright (C) 2011 Google Inc. - -Apache Commons Pool -Copyright 1999-2009 The Apache Software Foundation - -This product includes/uses Kubernetes & OpenShift 3 Java Client (https://github.com/fabric8io/kubernetes-client) -Copyright (C) 2015 Red Hat, Inc. - -This product includes/uses OkHttp (https://github.com/square/okhttp) -Copyright (C) 2012 The Android Open Source Project - -========================================================================= -== NOTICE file corresponding to section 4(d) of the Apache License, == -== Version 2.0, in this case for the DataNucleus distribution. == -========================================================================= - -=================================================================== -This product includes software developed by many individuals, -including the following: -=================================================================== -Erik Bengtson -Andy Jefferson - -=================================================================== -This product has included contributions from some individuals, -including the following: -=================================================================== - -=================================================================== -This product has included contributions from some individuals, -including the following: -=================================================================== -Joerg von Frantzius -Thomas Marti -Barry Haddow -Marco Schulze -Ralph Ullrich -David Ezzio -Brendan de Beer -David Eaves -Martin Taal -Tony Lai -Roland Szabo -Marcus Mennemeier -Xuan Baldauf -Eric Sultan - -=================================================================== -This product also includes software developed by the TJDO project -(http://tjdo.sourceforge.net/). -=================================================================== - -=================================================================== -This product includes software developed by many individuals, -including the following: -=================================================================== -Andy Jefferson -Erik Bengtson -Joerg von Frantzius -Marco Schulze - -=================================================================== -This product has included contributions from some individuals, -including the following: -=================================================================== -Barry Haddow -Ralph Ullrich -David Ezzio -Brendan de Beer -David Eaves -Martin Taal -Tony Lai -Roland Szabo -Anton Troshin (Timesten) - -=================================================================== -This product also includes software developed by the Apache Commons project -(http://commons.apache.org/). -=================================================================== - -Apache Java Data Objects (JDO) -Copyright 2005-2006 The Apache Software Foundation - -========================================================================= -== NOTICE file corresponding to section 4(d) of the Apache License, == -== Version 2.0, in this case for the Apache Derby distribution. == -========================================================================= - -Apache Derby -Copyright 2004-2008 The Apache Software Foundation - -Portions of Derby were originally developed by -International Business Machines Corporation and are -licensed to the Apache Software Foundation under the -"Software Grant and Corporate Contribution License Agreement", -informally known as the "Derby CLA". -The following copyright notice(s) were affixed to portions of the code -with which this file is now or was at one time distributed -and are placed here unaltered. - -(C) Copyright 1997,2004 International Business Machines Corporation. All rights reserved. - -(C) Copyright IBM Corp. 2003. - -The portion of the functionTests under 'nist' was originally -developed by the National Institute of Standards and Technology (NIST), -an agency of the United States Department of Commerce, and adapted by -International Business Machines Corporation in accordance with the NIST -Software Acknowledgment and Redistribution document at -http://www.itl.nist.gov/div897/ctg/sql_form.htm - -Apache Commons Collections -Copyright 2001-2008 The Apache Software Foundation - -Apache Commons Configuration -Copyright 2001-2008 The Apache Software Foundation - -Apache Jakarta Commons Digester -Copyright 2001-2006 The Apache Software Foundation - -Apache Commons BeanUtils -Copyright 2000-2008 The Apache Software Foundation - -Apache Avro Mapred API -Copyright 2009-2013 The Apache Software Foundation - -Apache Avro IPC -Copyright 2009-2013 The Apache Software Foundation - - -Vis.js -Copyright 2010-2015 Almende B.V. - -Vis.js is dual licensed under both - - * The Apache 2.0 License - http://www.apache.org/licenses/LICENSE-2.0 - - and - - * The MIT License - http://opensource.org/licenses/MIT - -Vis.js may be distributed under either license. - - -Vis.js uses and redistributes the following third-party libraries: - -- component-emitter - https://github.com/component/emitter - The MIT License - -- hammer.js - http://hammerjs.github.io/ - The MIT License - -- moment.js - http://momentjs.com/ - The MIT License - -- keycharm - https://github.com/AlexDM0/keycharm - The MIT License - -=============================================================================== - -The CSS style for the navigation sidebar of the documentation was originally -submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project -is distributed under the 3-Clause BSD license. -=============================================================================== - -For CSV functionality: - -/* - * Copyright 2014 Databricks - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Copyright 2015 Ayasdi Inc - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -=============================================================================== -For dev/sparktestsupport/toposort.py: - -Copyright 2014 True Blade Systems, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/NOTICE-binary b/NOTICE-binary new file mode 100644 index 0000000000000..d56f99bdb55a6 --- /dev/null +++ b/NOTICE-binary @@ -0,0 +1,1170 @@ +Apache Spark +Copyright 2014 and onwards The Apache Software Foundation. + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +// ------------------------------------------------------------------ +// NOTICE file corresponding to the section 4d of The Apache License, +// Version 2.0, in this case for +// ------------------------------------------------------------------ + +Hive Beeline +Copyright 2016 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +Apache Avro +Copyright 2009-2014 The Apache Software Foundation + +This product currently only contains code developed by authors +of specific components, as identified by the source code files; +if such notes are missing files have been created by +Tatu Saloranta. + +For additional credits (generally to people who reported problems) +see CREDITS file. + +Apache Commons Compress +Copyright 2002-2012 The Apache Software Foundation + +This product includes software developed by +The Apache Software Foundation (http://www.apache.org/). + +Apache Avro Mapred API +Copyright 2009-2014 The Apache Software Foundation + +Apache Avro IPC +Copyright 2009-2014 The Apache Software Foundation + +Objenesis +Copyright 2006-2013 Joe Walnes, Henri Tremblay, Leonardo Mesquita + +Apache XBean :: ASM 5 shaded (repackaged) +Copyright 2005-2015 The Apache Software Foundation + +-------------------------------------- + +This product includes software developed at +OW2 Consortium (http://asm.ow2.org/) + +This product includes software developed by The Apache Software +Foundation (http://www.apache.org/). + +The binary distribution of this product bundles binaries of +org.iq80.leveldb:leveldb-api (https://github.com/dain/leveldb), which has the +following notices: +* Copyright 2011 Dain Sundstrom +* Copyright 2011 FuseSource Corp. http://fusesource.com + +The binary distribution of this product bundles binaries of +org.fusesource.hawtjni:hawtjni-runtime (https://github.com/fusesource/hawtjni), +which has the following notices: +* This product includes software developed by FuseSource Corp. + http://fusesource.com +* This product includes software developed at + Progress Software Corporation and/or its subsidiaries or affiliates. +* This product includes software developed by IBM Corporation and others. + +The binary distribution of this product bundles binaries of +Gson 2.2.4, +which has the following notices: + + The Netty Project + ================= + +Please visit the Netty web site for more information: + + * http://netty.io/ + +Copyright 2014 The Netty Project + +The Netty Project licenses this file to you under the Apache License, +version 2.0 (the "License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at: + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +License for the specific language governing permissions and limitations +under the License. + +Also, please refer to each LICENSE..txt file, which is located in +the 'license' directory of the distribution file, for the license terms of the +components that this product depends on. + +------------------------------------------------------------------------------- +This product contains the extensions to Java Collections Framework which has +been derived from the works by JSR-166 EG, Doug Lea, and Jason T. Greene: + + * LICENSE: + * license/LICENSE.jsr166y.txt (Public Domain) + * HOMEPAGE: + * http://gee.cs.oswego.edu/cgi-bin/viewcvs.cgi/jsr166/ + * http://viewvc.jboss.org/cgi-bin/viewvc.cgi/jbosscache/experimental/jsr166/ + +This product contains a modified version of Robert Harder's Public Domain +Base64 Encoder and Decoder, which can be obtained at: + + * LICENSE: + * license/LICENSE.base64.txt (Public Domain) + * HOMEPAGE: + * http://iharder.sourceforge.net/current/java/base64/ + +This product contains a modified portion of 'Webbit', an event based +WebSocket and HTTP server, which can be obtained at: + + * LICENSE: + * license/LICENSE.webbit.txt (BSD License) + * HOMEPAGE: + * https://github.com/joewalnes/webbit + +This product contains a modified portion of 'SLF4J', a simple logging +facade for Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.slf4j.txt (MIT License) + * HOMEPAGE: + * http://www.slf4j.org/ + +This product contains a modified portion of 'ArrayDeque', written by Josh +Bloch of Google, Inc: + + * LICENSE: + * license/LICENSE.deque.txt (Public Domain) + +This product contains a modified portion of 'Apache Harmony', an open source +Java SE, which can be obtained at: + + * LICENSE: + * license/LICENSE.harmony.txt (Apache License 2.0) + * HOMEPAGE: + * http://archive.apache.org/dist/harmony/ + +This product contains a modified version of Roland Kuhn's ASL2 +AbstractNodeQueue, which is based on Dmitriy Vyukov's non-intrusive MPSC queue. +It can be obtained at: + + * LICENSE: + * license/LICENSE.abstractnodequeue.txt (Public Domain) + * HOMEPAGE: + * https://github.com/akka/akka/blob/wip-2.2.3-for-scala-2.11/akka-actor/src/main/java/akka/dispatch/AbstractNodeQueue.java + +This product contains a modified portion of 'jbzip2', a Java bzip2 compression +and decompression library written by Matthew J. Francis. It can be obtained at: + + * LICENSE: + * license/LICENSE.jbzip2.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/jbzip2/ + +This product contains a modified portion of 'libdivsufsort', a C API library to construct +the suffix array and the Burrows-Wheeler transformed string for any input string of +a constant-size alphabet written by Yuta Mori. It can be obtained at: + + * LICENSE: + * license/LICENSE.libdivsufsort.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/libdivsufsort/ + +This product contains a modified portion of Nitsan Wakart's 'JCTools', Java Concurrency Tools for the JVM, + which can be obtained at: + + * LICENSE: + * license/LICENSE.jctools.txt (ASL2 License) + * HOMEPAGE: + * https://github.com/JCTools/JCTools + +This product optionally depends on 'JZlib', a re-implementation of zlib in +pure Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.jzlib.txt (BSD style License) + * HOMEPAGE: + * http://www.jcraft.com/jzlib/ + +This product optionally depends on 'Compress-LZF', a Java library for encoding and +decoding data in LZF format, written by Tatu Saloranta. It can be obtained at: + + * LICENSE: + * license/LICENSE.compress-lzf.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/ning/compress + +This product optionally depends on 'lz4', a LZ4 Java compression +and decompression library written by Adrien Grand. It can be obtained at: + + * LICENSE: + * license/LICENSE.lz4.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/jpountz/lz4-java + +This product optionally depends on 'lzma-java', a LZMA Java compression +and decompression library, which can be obtained at: + + * LICENSE: + * license/LICENSE.lzma-java.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/jponge/lzma-java + +This product contains a modified portion of 'jfastlz', a Java port of FastLZ compression +and decompression library written by William Kinney. It can be obtained at: + + * LICENSE: + * license/LICENSE.jfastlz.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/jfastlz/ + +This product contains a modified portion of and optionally depends on 'Protocol Buffers', Google's data +interchange format, which can be obtained at: + + * LICENSE: + * license/LICENSE.protobuf.txt (New BSD License) + * HOMEPAGE: + * http://code.google.com/p/protobuf/ + +This product optionally depends on 'Bouncy Castle Crypto APIs' to generate +a temporary self-signed X.509 certificate when the JVM does not provide the +equivalent functionality. It can be obtained at: + + * LICENSE: + * license/LICENSE.bouncycastle.txt (MIT License) + * HOMEPAGE: + * http://www.bouncycastle.org/ + +This product optionally depends on 'Snappy', a compression library produced +by Google Inc, which can be obtained at: + + * LICENSE: + * license/LICENSE.snappy.txt (New BSD License) + * HOMEPAGE: + * http://code.google.com/p/snappy/ + +This product optionally depends on 'JBoss Marshalling', an alternative Java +serialization API, which can be obtained at: + + * LICENSE: + * license/LICENSE.jboss-marshalling.txt (GNU LGPL 2.1) + * HOMEPAGE: + * http://www.jboss.org/jbossmarshalling + +This product optionally depends on 'Caliper', Google's micro- +benchmarking framework, which can be obtained at: + + * LICENSE: + * license/LICENSE.caliper.txt (Apache License 2.0) + * HOMEPAGE: + * http://code.google.com/p/caliper/ + +This product optionally depends on 'Apache Commons Logging', a logging +framework, which can be obtained at: + + * LICENSE: + * license/LICENSE.commons-logging.txt (Apache License 2.0) + * HOMEPAGE: + * http://commons.apache.org/logging/ + +This product optionally depends on 'Apache Log4J', a logging framework, which +can be obtained at: + + * LICENSE: + * license/LICENSE.log4j.txt (Apache License 2.0) + * HOMEPAGE: + * http://logging.apache.org/log4j/ + +This product optionally depends on 'Aalto XML', an ultra-high performance +non-blocking XML processor, which can be obtained at: + + * LICENSE: + * license/LICENSE.aalto-xml.txt (Apache License 2.0) + * HOMEPAGE: + * http://wiki.fasterxml.com/AaltoHome + +This product contains a modified version of 'HPACK', a Java implementation of +the HTTP/2 HPACK algorithm written by Twitter. It can be obtained at: + + * LICENSE: + * license/LICENSE.hpack.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/twitter/hpack + +This product contains a modified portion of 'Apache Commons Lang', a Java library +provides utilities for the java.lang API, which can be obtained at: + + * LICENSE: + * license/LICENSE.commons-lang.txt (Apache License 2.0) + * HOMEPAGE: + * https://commons.apache.org/proper/commons-lang/ + +The binary distribution of this product bundles binaries of +Commons Codec 1.4, +which has the following notices: + * src/test/org/apache/commons/codec/language/DoubleMetaphoneTest.javacontains test data from http://aspell.net/test/orig/batch0.tab.Copyright (C) 2002 Kevin Atkinson (kevina@gnu.org) + =============================================================================== + The content of package org.apache.commons.codec.language.bm has been translated + from the original php source code available at http://stevemorse.org/phoneticinfo.htm + with permission from the original authors. + Original source copyright:Copyright (c) 2008 Alexander Beider & Stephen P. Morse. + +The binary distribution of this product bundles binaries of +Commons Lang 2.6, +which has the following notices: + * This product includes software from the Spring Framework,under the Apache License 2.0 (see: StringUtils.containsWhitespace()) + +The binary distribution of this product bundles binaries of +Apache Log4j 1.2.17, +which has the following notices: + * ResolverUtil.java + Copyright 2005-2006 Tim Fennell + Dumbster SMTP test server + Copyright 2004 Jason Paul Kitchen + TypeUtil.java + Copyright 2002-2012 Ramnivas Laddad, Juergen Hoeller, Chris Beams + +The binary distribution of this product bundles binaries of +Jetty 6.1.26, +which has the following notices: + * ============================================================== + Jetty Web Container + Copyright 1995-2016 Mort Bay Consulting Pty Ltd. + ============================================================== + + The Jetty Web Container is Copyright Mort Bay Consulting Pty Ltd + unless otherwise noted. + + Jetty is dual licensed under both + + * The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0.html + + and + + * The Eclipse Public 1.0 License + http://www.eclipse.org/legal/epl-v10.html + + Jetty may be distributed under either license. + + ------ + Eclipse + + The following artifacts are EPL. + * org.eclipse.jetty.orbit:org.eclipse.jdt.core + + The following artifacts are EPL and ASL2. + * org.eclipse.jetty.orbit:javax.security.auth.message + + The following artifacts are EPL and CDDL 1.0. + * org.eclipse.jetty.orbit:javax.mail.glassfish + + ------ + Oracle + + The following artifacts are CDDL + GPLv2 with classpath exception. + https://glassfish.dev.java.net/nonav/public/CDDL+GPL.html + + * javax.servlet:javax.servlet-api + * javax.annotation:javax.annotation-api + * javax.transaction:javax.transaction-api + * javax.websocket:javax.websocket-api + + ------ + Oracle OpenJDK + + If ALPN is used to negotiate HTTP/2 connections, then the following + artifacts may be included in the distribution or downloaded when ALPN + module is selected. + + * java.sun.security.ssl + + These artifacts replace/modify OpenJDK classes. The modififications + are hosted at github and both modified and original are under GPL v2 with + classpath exceptions. + http://openjdk.java.net/legal/gplv2+ce.html + + ------ + OW2 + + The following artifacts are licensed by the OW2 Foundation according to the + terms of http://asm.ow2.org/license.html + + org.ow2.asm:asm-commons + org.ow2.asm:asm + + ------ + Apache + + The following artifacts are ASL2 licensed. + + org.apache.taglibs:taglibs-standard-spec + org.apache.taglibs:taglibs-standard-impl + + ------ + MortBay + + The following artifacts are ASL2 licensed. Based on selected classes from + following Apache Tomcat jars, all ASL2 licensed. + + org.mortbay.jasper:apache-jsp + org.apache.tomcat:tomcat-jasper + org.apache.tomcat:tomcat-juli + org.apache.tomcat:tomcat-jsp-api + org.apache.tomcat:tomcat-el-api + org.apache.tomcat:tomcat-jasper-el + org.apache.tomcat:tomcat-api + org.apache.tomcat:tomcat-util-scan + org.apache.tomcat:tomcat-util + + org.mortbay.jasper:apache-el + org.apache.tomcat:tomcat-jasper-el + org.apache.tomcat:tomcat-el-api + + ------ + Mortbay + + The following artifacts are CDDL + GPLv2 with classpath exception. + + https://glassfish.dev.java.net/nonav/public/CDDL+GPL.html + + org.eclipse.jetty.toolchain:jetty-schemas + + ------ + Assorted + + The UnixCrypt.java code implements the one way cryptography used by + Unix systems for simple password protection. Copyright 1996 Aki Yoshida, + modified April 2001 by Iris Van den Broeke, Daniel Deville. + Permission to use, copy, modify and distribute UnixCrypt + for non-commercial or commercial purposes and without fee is + granted provided that the copyright notice appears in all copies./ + +The binary distribution of this product bundles binaries of +Snappy for Java 1.0.4.1, +which has the following notices: + * This product includes software developed by Google + Snappy: http://code.google.com/p/snappy/ (New BSD License) + + This product includes software developed by Apache + PureJavaCrc32C from apache-hadoop-common http://hadoop.apache.org/ + (Apache 2.0 license) + + This library containd statically linked libstdc++. This inclusion is allowed by + "GCC RUntime Library Exception" + http://gcc.gnu.org/onlinedocs/libstdc++/manual/license.html + + == Contributors == + * Tatu Saloranta + * Providing benchmark suite + * Alec Wysoker + * Performance and memory usage improvement + +The binary distribution of this product bundles binaries of +Xerces2 Java Parser 2.9.1, +which has the following notices: + * ========================================================================= + == NOTICE file corresponding to section 4(d) of the Apache License, == + == Version 2.0, in this case for the Apache Xerces Java distribution. == + ========================================================================= + + Apache Xerces Java + Copyright 1999-2007 The Apache Software Foundation + + This product includes software developed at + The Apache Software Foundation (http://www.apache.org/). + + Portions of this software were originally based on the following: + - software copyright (c) 1999, IBM Corporation., http://www.ibm.com. + - software copyright (c) 1999, Sun Microsystems., http://www.sun.com. + - voluntary contributions made by Paul Eng on behalf of the + Apache Software Foundation that were originally developed at iClick, Inc., + software copyright (c) 1999. + +Apache Commons Collections +Copyright 2001-2015 The Apache Software Foundation + +Apache Commons Configuration +Copyright 2001-2008 The Apache Software Foundation + +Apache Jakarta Commons Digester +Copyright 2001-2006 The Apache Software Foundation + +Apache Commons BeanUtils +Copyright 2000-2008 The Apache Software Foundation + +ApacheDS Protocol Kerberos Codec +Copyright 2003-2013 The Apache Software Foundation + +ApacheDS I18n +Copyright 2003-2013 The Apache Software Foundation + +Apache Directory API ASN.1 API +Copyright 2003-2013 The Apache Software Foundation + +Apache Directory LDAP API Utilities +Copyright 2003-2013 The Apache Software Foundation + +Curator Client +Copyright 2011-2015 The Apache Software Foundation + +htrace-core +Copyright 2015 The Apache Software Foundation + + ========================================================================= + == NOTICE file corresponding to section 4(d) of the Apache License, == + == Version 2.0, in this case for the Apache Xerces Java distribution. == + ========================================================================= + + Portions of this software were originally based on the following: + - software copyright (c) 1999, IBM Corporation., http://www.ibm.com. + - software copyright (c) 1999, Sun Microsystems., http://www.sun.com. + - voluntary contributions made by Paul Eng on behalf of the + Apache Software Foundation that were originally developed at iClick, Inc., + software copyright (c) 1999. + +# Jackson JSON processor + +Jackson is a high-performance, Free/Open Source JSON processing library. +It was originally written by Tatu Saloranta (tatu.saloranta@iki.fi), and has +been in development since 2007. +It is currently developed by a community of developers, as well as supported +commercially by FasterXML.com. + +## Licensing + +Jackson core and extension components may licensed under different licenses. +To find the details that apply to this artifact see the accompanying LICENSE file. +For more information, including possible other licensing options, contact +FasterXML.com (http://fasterxml.com). + +## Credits + +A list of contributors may be found from CREDITS file, which is included +in some artifacts (usually source distributions); but is always available +from the source code management (SCM) system project uses. + +Apache HttpCore +Copyright 2005-2017 The Apache Software Foundation + +Curator Recipes +Copyright 2011-2015 The Apache Software Foundation + +Curator Framework +Copyright 2011-2015 The Apache Software Foundation + +Apache Commons Lang +Copyright 2001-2016 The Apache Software Foundation + +This product includes software from the Spring Framework, +under the Apache License 2.0 (see: StringUtils.containsWhitespace()) + +Apache Commons Math +Copyright 2001-2015 The Apache Software Foundation + +This product includes software developed for Orekit by +CS Systèmes d'Information (http://www.c-s.fr/) +Copyright 2010-2012 CS Systèmes d'Information + +Apache log4j +Copyright 2007 The Apache Software Foundation + +# Compress LZF + +This library contains efficient implementation of LZF compression format, +as well as additional helper classes that build on JDK-provided gzip (deflat) +codec. + +Library is licensed under Apache License 2.0, as per accompanying LICENSE file. + +## Credit + +Library has been written by Tatu Saloranta (tatu.saloranta@iki.fi). +It was started at Ning, inc., as an official Open Source process used by +platform backend, but after initial versions has been developed outside of +Ning by supporting community. + +Other contributors include: + +* Jon Hartlaub (first versions of streaming reader/writer; unit tests) +* Cedrik Lime: parallel LZF implementation + +Various community members have contributed bug reports, and suggested minor +fixes; these can be found from file "VERSION.txt" in SCM. + +Apache Commons Net +Copyright 2001-2012 The Apache Software Foundation + +Copyright 2011 The Netty Project + +http://www.apache.org/licenses/LICENSE-2.0 + +This product contains a modified version of 'JZlib', a re-implementation of +zlib in pure Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.jzlib.txt (BSD Style License) + * HOMEPAGE: + * http://www.jcraft.com/jzlib/ + +This product contains a modified version of 'Webbit', a Java event based +WebSocket and HTTP server: + +This product optionally depends on 'Protocol Buffers', Google's data +interchange format, which can be obtained at: + +This product optionally depends on 'SLF4J', a simple logging facade for Java, +which can be obtained at: + +This product optionally depends on 'Apache Log4J', a logging framework, +which can be obtained at: + +This product optionally depends on 'JBoss Logging', a logging framework, +which can be obtained at: + + * LICENSE: + * license/LICENSE.jboss-logging.txt (GNU LGPL 2.1) + * HOMEPAGE: + * http://anonsvn.jboss.org/repos/common/common-logging-spi/ + +This product optionally depends on 'Apache Felix', an open source OSGi +framework implementation, which can be obtained at: + + * LICENSE: + * license/LICENSE.felix.txt (Apache License 2.0) + * HOMEPAGE: + * http://felix.apache.org/ + +Jackson core and extension components may be licensed under different licenses. +To find the details that apply to this artifact see the accompanying LICENSE file. +For more information, including possible other licensing options, contact +FasterXML.com (http://fasterxml.com). + +Apache Ivy (TM) +Copyright 2007-2014 The Apache Software Foundation + +Portions of Ivy were originally developed at +Jayasoft SARL (http://www.jayasoft.fr/) +and are licensed to the Apache Software Foundation under the +"Software Grant License Agreement" + +SSH and SFTP support is provided by the JCraft JSch package, +which is open source software, available under +the terms of a BSD style license. +The original software and related information is available +at http://www.jcraft.com/jsch/. + + +ORC Core +Copyright 2013-2018 The Apache Software Foundation + +Apache Commons Lang +Copyright 2001-2011 The Apache Software Foundation + +ORC MapReduce +Copyright 2013-2018 The Apache Software Foundation + +Apache Parquet Format +Copyright 2017 The Apache Software Foundation + +Arrow Vectors +Copyright 2017 The Apache Software Foundation + +Arrow Format +Copyright 2017 The Apache Software Foundation + +Arrow Memory +Copyright 2017 The Apache Software Foundation + +Apache Commons CLI +Copyright 2001-2009 The Apache Software Foundation + +Google Guice - Extensions - Servlet +Copyright 2006-2011 Google, Inc. + +Apache Commons IO +Copyright 2002-2012 The Apache Software Foundation + +Google Guice - Core Library +Copyright 2006-2011 Google, Inc. + +mesos +Copyright 2017 The Apache Software Foundation + +Apache Parquet Hadoop Bundle (Incubating) +Copyright 2015 The Apache Software Foundation + +Hive Query Language +Copyright 2016 The Apache Software Foundation + +Apache Extras Companion for log4j 1.2. +Copyright 2007 The Apache Software Foundation + +Hive Metastore +Copyright 2016 The Apache Software Foundation + +Apache Commons Logging +Copyright 2003-2013 The Apache Software Foundation + +========================================================================= +== NOTICE file corresponding to section 4(d) of the Apache License, == +== Version 2.0, in this case for the DataNucleus distribution. == +========================================================================= + +=================================================================== +This product includes software developed by many individuals, +including the following: +=================================================================== +Erik Bengtson +Andy Jefferson + +=================================================================== +This product has included contributions from some individuals, +including the following: +=================================================================== + +=================================================================== +This product includes software developed by many individuals, +including the following: +=================================================================== +Andy Jefferson +Erik Bengtson +Joerg von Frantzius +Marco Schulze + +=================================================================== +This product has included contributions from some individuals, +including the following: +=================================================================== +Barry Haddow +Ralph Ullrich +David Ezzio +Brendan de Beer +David Eaves +Martin Taal +Tony Lai +Roland Szabo +Anton Troshin (Timesten) + +=================================================================== +This product also includes software developed by the TJDO project +(http://tjdo.sourceforge.net/). +=================================================================== + +=================================================================== +This product also includes software developed by the Apache Commons project +(http://commons.apache.org/). +=================================================================== + +Apache Commons Pool +Copyright 1999-2009 The Apache Software Foundation + +Apache Commons DBCP +Copyright 2001-2010 The Apache Software Foundation + +Apache Java Data Objects (JDO) +Copyright 2005-2006 The Apache Software Foundation + +Apache Jakarta HttpClient +Copyright 1999-2007 The Apache Software Foundation + +Calcite Avatica +Copyright 2012-2015 The Apache Software Foundation + +Calcite Core +Copyright 2012-2015 The Apache Software Foundation + +Calcite Linq4j +Copyright 2012-2015 The Apache Software Foundation + +Apache HttpClient +Copyright 1999-2017 The Apache Software Foundation + +Apache Commons Codec +Copyright 2002-2014 The Apache Software Foundation + +src/test/org/apache/commons/codec/language/DoubleMetaphoneTest.java +contains test data from http://aspell.net/test/orig/batch0.tab. +Copyright (C) 2002 Kevin Atkinson (kevina@gnu.org) + +=============================================================================== + +The content of package org.apache.commons.codec.language.bm has been translated +from the original php source code available at http://stevemorse.org/phoneticinfo.htm +with permission from the original authors. +Original source copyright: +Copyright (c) 2008 Alexander Beider & Stephen P. Morse. + +============================================================================= += NOTICE file corresponding to section 4d of the Apache License Version 2.0 = +============================================================================= +This product includes software developed by +Joda.org (http://www.joda.org/). + +=================================================================== +This product has included contributions from some individuals, +including the following: +=================================================================== +Joerg von Frantzius +Thomas Marti +Barry Haddow +Marco Schulze +Ralph Ullrich +David Ezzio +Brendan de Beer +David Eaves +Martin Taal +Tony Lai +Roland Szabo +Marcus Mennemeier +Xuan Baldauf +Eric Sultan + +Apache Thrift +Copyright 2006-2010 The Apache Software Foundation. + +========================================================================= +== NOTICE file corresponding to section 4(d) of the Apache License, +== Version 2.0, in this case for the Apache Derby distribution. +== +== DO NOT EDIT THIS FILE DIRECTLY. IT IS GENERATED +== BY THE buildnotice TARGET IN THE TOP LEVEL build.xml FILE. +== +========================================================================= + +Apache Derby +Copyright 2004-2015 The Apache Software Foundation + +========================================================================= + +Portions of Derby were originally developed by +International Business Machines Corporation and are +licensed to the Apache Software Foundation under the +"Software Grant and Corporate Contribution License Agreement", +informally known as the "Derby CLA". +The following copyright notice(s) were affixed to portions of the code +with which this file is now or was at one time distributed +and are placed here unaltered. + +(C) Copyright 1997,2004 International Business Machines Corporation. All rights reserved. + +(C) Copyright IBM Corp. 2003. + +The portion of the functionTests under 'nist' was originally +developed by the National Institute of Standards and Technology (NIST), +an agency of the United States Department of Commerce, and adapted by +International Business Machines Corporation in accordance with the NIST +Software Acknowledgment and Redistribution document at +http://www.itl.nist.gov/div897/ctg/sql_form.htm + +The JDBC apis for small devices and JDBC3 (under java/stubs/jsr169 and +java/stubs/jdbc3) were produced by trimming sources supplied by the +Apache Harmony project. In addition, the Harmony SerialBlob and +SerialClob implementations are used. The following notice covers the Harmony sources: + +Portions of Harmony were originally developed by +Intel Corporation and are licensed to the Apache Software +Foundation under the "Software Grant and Corporate Contribution +License Agreement", informally known as the "Intel Harmony CLA". + +The Derby build relies on source files supplied by the Apache Felix +project. The following notice covers the Felix files: + + Apache Felix Main + Copyright 2008 The Apache Software Foundation + + I. Included Software + + This product includes software developed at + The Apache Software Foundation (http://www.apache.org/). + Licensed under the Apache License 2.0. + + This product includes software developed at + The OSGi Alliance (http://www.osgi.org/). + Copyright (c) OSGi Alliance (2000, 2007). + Licensed under the Apache License 2.0. + + This product includes software from http://kxml.sourceforge.net. + Copyright (c) 2002,2003, Stefan Haustein, Oberhausen, Rhld., Germany. + Licensed under BSD License. + + II. Used Software + + This product uses software developed at + The OSGi Alliance (http://www.osgi.org/). + Copyright (c) OSGi Alliance (2000, 2007). + Licensed under the Apache License 2.0. + + III. License Summary + - Apache License 2.0 + - BSD License + +The Derby build relies on jar files supplied by the Apache Lucene +project. The following notice covers the Lucene files: + +Apache Lucene +Copyright 2013 The Apache Software Foundation + +Includes software from other Apache Software Foundation projects, +including, but not limited to: + - Apache Ant + - Apache Jakarta Regexp + - Apache Commons + - Apache Xerces + +ICU4J, (under analysis/icu) is licensed under an MIT styles license +and Copyright (c) 1995-2008 International Business Machines Corporation and others + +Some data files (under analysis/icu/src/data) are derived from Unicode data such +as the Unicode Character Database. See http://unicode.org/copyright.html for more +details. + +Brics Automaton (under core/src/java/org/apache/lucene/util/automaton) is +BSD-licensed, created by Anders Møller. See http://www.brics.dk/automaton/ + +The levenshtein automata tables (under core/src/java/org/apache/lucene/util/automaton) were +automatically generated with the moman/finenight FSA library, created by +Jean-Philippe Barrette-LaPierre. This library is available under an MIT license, +see http://sites.google.com/site/rrettesite/moman and +http://bitbucket.org/jpbarrette/moman/overview/ + +The class org.apache.lucene.util.WeakIdentityMap was derived from +the Apache CXF project and is Apache License 2.0. + +The Google Code Prettify is Apache License 2.0. +See http://code.google.com/p/google-code-prettify/ + +JUnit (junit-4.10) is licensed under the Common Public License v. 1.0 +See http://junit.sourceforge.net/cpl-v10.html + +This product includes code (JaspellTernarySearchTrie) from Java Spelling Checkin +g Package (jaspell): http://jaspell.sourceforge.net/ +License: The BSD License (http://www.opensource.org/licenses/bsd-license.php) + +The snowball stemmers in + analysis/common/src/java/net/sf/snowball +were developed by Martin Porter and Richard Boulton. +The snowball stopword lists in + analysis/common/src/resources/org/apache/lucene/analysis/snowball +were developed by Martin Porter and Richard Boulton. +The full snowball package is available from + http://snowball.tartarus.org/ + +The KStem stemmer in + analysis/common/src/org/apache/lucene/analysis/en +was developed by Bob Krovetz and Sergio Guzman-Lara (CIIR-UMass Amherst) +under the BSD-license. + +The Arabic,Persian,Romanian,Bulgarian, and Hindi analyzers (common) come with a default +stopword list that is BSD-licensed created by Jacques Savoy. These files reside in: +analysis/common/src/resources/org/apache/lucene/analysis/ar/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/fa/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/ro/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/bg/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/hi/stopwords.txt +See http://members.unine.ch/jacques.savoy/clef/index.html. + +The German,Spanish,Finnish,French,Hungarian,Italian,Portuguese,Russian and Swedish light stemmers +(common) are based on BSD-licensed reference implementations created by Jacques Savoy and +Ljiljana Dolamic. These files reside in: +analysis/common/src/java/org/apache/lucene/analysis/de/GermanLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/de/GermanMinimalStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/es/SpanishLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/fi/FinnishLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/fr/FrenchLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/fr/FrenchMinimalStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/hu/HungarianLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/it/ItalianLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/pt/PortugueseLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/ru/RussianLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/sv/SwedishLightStemmer.java + +The Stempel analyzer (stempel) includes BSD-licensed software developed +by the Egothor project http://egothor.sf.net/, created by Leo Galambos, Martin Kvapil, +and Edmond Nolan. + +The Polish analyzer (stempel) comes with a default +stopword list that is BSD-licensed created by the Carrot2 project. The file resides +in stempel/src/resources/org/apache/lucene/analysis/pl/stopwords.txt. +See http://project.carrot2.org/license.html. + +The SmartChineseAnalyzer source code (smartcn) was +provided by Xiaoping Gao and copyright 2009 by www.imdict.net. + +WordBreakTestUnicode_*.java (under modules/analysis/common/src/test/) +is derived from Unicode data such as the Unicode Character Database. +See http://unicode.org/copyright.html for more details. + +The Morfologik analyzer (morfologik) includes BSD-licensed software +developed by Dawid Weiss and Marcin Miłkowski (http://morfologik.blogspot.com/). + +Morfologik uses data from Polish ispell/myspell dictionary +(http://www.sjp.pl/slownik/en/) licenced on the terms of (inter alia) +LGPL and Creative Commons ShareAlike. + +Morfologic includes data from BSD-licensed dictionary of Polish (SGJP) +(http://sgjp.pl/morfeusz/) + +Servlet-api.jar and javax.servlet-*.jar are under the CDDL license, the original +source code for this can be found at http://www.eclipse.org/jetty/downloads.php + +=========================================================================== +Kuromoji Japanese Morphological Analyzer - Apache Lucene Integration +=========================================================================== + +This software includes a binary and/or source version of data from + + mecab-ipadic-2.7.0-20070801 + +which can be obtained from + + http://atilika.com/releases/mecab-ipadic/mecab-ipadic-2.7.0-20070801.tar.gz + +or + + http://jaist.dl.sourceforge.net/project/mecab/mecab-ipadic/2.7.0-20070801/mecab-ipadic-2.7.0-20070801.tar.gz + +=========================================================================== +mecab-ipadic-2.7.0-20070801 Notice +=========================================================================== + +Nara Institute of Science and Technology (NAIST), +the copyright holders, disclaims all warranties with regard to this +software, including all implied warranties of merchantability and +fitness, in no event shall NAIST be liable for +any special, indirect or consequential damages or any damages +whatsoever resulting from loss of use, data or profits, whether in an +action of contract, negligence or other tortuous action, arising out +of or in connection with the use or performance of this software. + +A large portion of the dictionary entries +originate from ICOT Free Software. The following conditions for ICOT +Free Software applies to the current dictionary as well. + +Each User may also freely distribute the Program, whether in its +original form or modified, to any third party or parties, PROVIDED +that the provisions of Section 3 ("NO WARRANTY") will ALWAYS appear +on, or be attached to, the Program, which is distributed substantially +in the same form as set out herein and that such intended +distribution, if actually made, will neither violate or otherwise +contravene any of the laws and regulations of the countries having +jurisdiction over the User or the intended distribution itself. + +NO WARRANTY + +The program was produced on an experimental basis in the course of the +research and development conducted during the project and is provided +to users as so produced on an experimental basis. Accordingly, the +program is provided without any warranty whatsoever, whether express, +implied, statutory or otherwise. The term "warranty" used herein +includes, but is not limited to, any warranty of the quality, +performance, merchantability and fitness for a particular purpose of +the program and the nonexistence of any infringement or violation of +any right of any third party. + +Each user of the program will agree and understand, and be deemed to +have agreed and understood, that there is no warranty whatsoever for +the program and, accordingly, the entire risk arising from or +otherwise connected with the program is assumed by the user. + +Therefore, neither ICOT, the copyright holder, or any other +organization that participated in or was otherwise related to the +development of the program and their respective officials, directors, +officers and other employees shall be held liable for any and all +damages, including, without limitation, general, special, incidental +and consequential damages, arising out of or otherwise in connection +with the use or inability to use the program or any product, material +or result produced or otherwise obtained by using the program, +regardless of whether they have been advised of, or otherwise had +knowledge of, the possibility of such damages at any time during the +project or thereafter. Each user will be deemed to have agreed to the +foregoing by his or her commencement of use of the program. The term +"use" as used herein includes, but is not limited to, the use, +modification, copying and distribution of the program and the +production of secondary products from the program. + +In the case where the program, whether in its original form or +modified, was distributed or delivered to or received by a user from +any person, organization or entity other than ICOT, unless it makes or +grants independently of ICOT any specific warranty to the user in +writing, such person, organization or entity, will also be exempted +from and not be held liable to the user for any such damages as noted +above as far as the program is concerned. + +The Derby build relies on a jar file supplied by the JSON Simple +project, hosted at https://code.google.com/p/json-simple/. +The JSON simple jar file is licensed under the Apache 2.0 License. + +Hive CLI +Copyright 2016 The Apache Software Foundation + +Hive JDBC +Copyright 2016 The Apache Software Foundation + + +Chill is a set of Scala extensions for Kryo. +Copyright 2012 Twitter, Inc. + +Third Party Dependencies: + +Kryo 2.17 +BSD 3-Clause License +http://code.google.com/p/kryo + +Commons-Codec 1.7 +Apache Public License 2.0 +http://hadoop.apache.org + + + +Breeze is distributed under an Apache License V2.0 (See LICENSE) + +=============================================================================== + +Proximal algorithms outlined in Proximal.scala (package breeze.optimize.proximal) +are based on https://github.com/cvxgrp/proximal (see LICENSE for details) and distributed with +Copyright (c) 2014 by Debasish Das (Verizon), all rights reserved. + +=============================================================================== + +QuadraticMinimizer class in package breeze.optimize.proximal is distributed with Copyright (c) +2014, Debasish Das (Verizon), all rights reserved. + +=============================================================================== + +NonlinearMinimizer class in package breeze.optimize.proximal is distributed with Copyright (c) +2015, Debasish Das (Verizon), all rights reserved. + + + ========================================================================= + == NOTICE file corresponding to section 4(d) of the Apache License, == + == Version 2.0, in this case for the distribution of jets3t. == + ========================================================================= + + This product includes software developed by: + + The Apache Software Foundation (http://www.apache.org/). + + The ExoLab Project (http://www.exolab.org/) + + Sun Microsystems (http://www.sun.com/) + + Codehaus (http://castor.codehaus.org) + + Tatu Saloranta (http://wiki.fasterxml.com/TatuSaloranta) + + + +stream-lib +Copyright 2016 AddThis + +This product includes software developed by AddThis. + +This product also includes code adapted from: + +Apache Solr (http://lucene.apache.org/solr/) +Copyright 2014 The Apache Software Foundation + +Apache Mahout (http://mahout.apache.org/) +Copyright 2014 The Apache Software Foundation \ No newline at end of file diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 23b24212b4d29..466135e72233a 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -11,6 +11,10 @@ cache .rat-excludes .*md derby.log +licenses/* +licenses-binary/* +LICENSE +NOTICE TAGS RELEASE control diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 84233c64caa9c..ad99ce55806af 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -211,9 +211,10 @@ mkdir -p "$DISTDIR/examples/src/main" cp -r "$SPARK_HOME/examples/src/main" "$DISTDIR/examples/src/" # Copy license and ASF files -cp "$SPARK_HOME/LICENSE" "$DISTDIR" -cp -r "$SPARK_HOME/licenses" "$DISTDIR" -cp "$SPARK_HOME/NOTICE" "$DISTDIR" +cp "$SPARK_HOME/LICENSE-binary" "$DISTDIR/LICENSE" +mkdir -p "$DISTDIR/licenses" +cp -r "$SPARK_HOME/licenses-binary" "$DISTDIR/licenses" +cp "$SPARK_HOME/NOTICE-binary" "$DISTDIR/NOTICE" if [ -e "$SPARK_HOME/CHANGES.txt" ]; then cp "$SPARK_HOME/CHANGES.txt" "$DISTDIR" diff --git a/licenses/LICENSE-scopt.txt b/licenses-binary/LICENSE-AnchorJS.txt similarity index 100% rename from licenses/LICENSE-scopt.txt rename to licenses-binary/LICENSE-AnchorJS.txt diff --git a/licenses-binary/LICENSE-CC0.txt b/licenses-binary/LICENSE-CC0.txt new file mode 100644 index 0000000000000..1625c17936079 --- /dev/null +++ b/licenses-binary/LICENSE-CC0.txt @@ -0,0 +1,121 @@ +Creative Commons Legal Code + +CC0 1.0 Universal + + CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE + LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN + ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS + INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES + REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS + PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM + THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED + HEREUNDER. + +Statement of Purpose + +The laws of most jurisdictions throughout the world automatically confer +exclusive Copyright and Related Rights (defined below) upon the creator +and subsequent owner(s) (each and all, an "owner") of an original work of +authorship and/or a database (each, a "Work"). + +Certain owners wish to permanently relinquish those rights to a Work for +the purpose of contributing to a commons of creative, cultural and +scientific works ("Commons") that the public can reliably and without fear +of later claims of infringement build upon, modify, incorporate in other +works, reuse and redistribute as freely as possible in any form whatsoever +and for any purposes, including without limitation commercial purposes. +These owners may contribute to the Commons to promote the ideal of a free +culture and the further production of creative, cultural and scientific +works, or to gain reputation or greater distribution for their Work in +part through the use and efforts of others. + +For these and/or other purposes and motivations, and without any +expectation of additional consideration or compensation, the person +associating CC0 with a Work (the "Affirmer"), to the extent that he or she +is an owner of Copyright and Related Rights in the Work, voluntarily +elects to apply CC0 to the Work and publicly distribute the Work under its +terms, with knowledge of his or her Copyright and Related Rights in the +Work and the meaning and intended legal effect of CC0 on those rights. + +1. Copyright and Related Rights. A Work made available under CC0 may be +protected by copyright and related or neighboring rights ("Copyright and +Related Rights"). Copyright and Related Rights include, but are not +limited to, the following: + + i. the right to reproduce, adapt, distribute, perform, display, + communicate, and translate a Work; + ii. moral rights retained by the original author(s) and/or performer(s); +iii. publicity and privacy rights pertaining to a person's image or + likeness depicted in a Work; + iv. rights protecting against unfair competition in regards to a Work, + subject to the limitations in paragraph 4(a), below; + v. rights protecting the extraction, dissemination, use and reuse of data + in a Work; + vi. database rights (such as those arising under Directive 96/9/EC of the + European Parliament and of the Council of 11 March 1996 on the legal + protection of databases, and under any national implementation + thereof, including any amended or successor version of such + directive); and +vii. other similar, equivalent or corresponding rights throughout the + world based on applicable law or treaty, and any national + implementations thereof. + +2. Waiver. To the greatest extent permitted by, but not in contravention +of, applicable law, Affirmer hereby overtly, fully, permanently, +irrevocably and unconditionally waives, abandons, and surrenders all of +Affirmer's Copyright and Related Rights and associated claims and causes +of action, whether now known or unknown (including existing as well as +future claims and causes of action), in the Work (i) in all territories +worldwide, (ii) for the maximum duration provided by applicable law or +treaty (including future time extensions), (iii) in any current or future +medium and for any number of copies, and (iv) for any purpose whatsoever, +including without limitation commercial, advertising or promotional +purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each +member of the public at large and to the detriment of Affirmer's heirs and +successors, fully intending that such Waiver shall not be subject to +revocation, rescission, cancellation, termination, or any other legal or +equitable action to disrupt the quiet enjoyment of the Work by the public +as contemplated by Affirmer's express Statement of Purpose. + +3. Public License Fallback. Should any part of the Waiver for any reason +be judged legally invalid or ineffective under applicable law, then the +Waiver shall be preserved to the maximum extent permitted taking into +account Affirmer's express Statement of Purpose. In addition, to the +extent the Waiver is so judged Affirmer hereby grants to each affected +person a royalty-free, non transferable, non sublicensable, non exclusive, +irrevocable and unconditional license to exercise Affirmer's Copyright and +Related Rights in the Work (i) in all territories worldwide, (ii) for the +maximum duration provided by applicable law or treaty (including future +time extensions), (iii) in any current or future medium and for any number +of copies, and (iv) for any purpose whatsoever, including without +limitation commercial, advertising or promotional purposes (the +"License"). The License shall be deemed effective as of the date CC0 was +applied by Affirmer to the Work. Should any part of the License for any +reason be judged legally invalid or ineffective under applicable law, such +partial invalidity or ineffectiveness shall not invalidate the remainder +of the License, and in such case Affirmer hereby affirms that he or she +will not (i) exercise any of his or her remaining Copyright and Related +Rights in the Work or (ii) assert any associated claims and causes of +action with respect to the Work, in either case contrary to Affirmer's +express Statement of Purpose. + +4. Limitations and Disclaimers. + + a. No trademark or patent rights held by Affirmer are waived, abandoned, + surrendered, licensed or otherwise affected by this document. + b. Affirmer offers the Work as-is and makes no representations or + warranties of any kind concerning the Work, express, implied, + statutory or otherwise, including without limitation warranties of + title, merchantability, fitness for a particular purpose, non + infringement, or the absence of latent or other defects, accuracy, or + the present or absence of errors, whether or not discoverable, all to + the greatest extent permissible under applicable law. + c. Affirmer disclaims responsibility for clearing rights of other persons + that may apply to the Work or any use thereof, including without + limitation any person's Copyright and Related Rights in the Work. + Further, Affirmer disclaims responsibility for obtaining any necessary + consents, permissions or other rights required for any use of the + Work. + d. Affirmer understands and acknowledges that Creative Commons is not a + party to this document and has no duty or obligation with respect to + this CC0 or use of the Work. \ No newline at end of file diff --git a/licenses/LICENSE-antlr.txt b/licenses-binary/LICENSE-antlr.txt similarity index 100% rename from licenses/LICENSE-antlr.txt rename to licenses-binary/LICENSE-antlr.txt diff --git a/licenses-binary/LICENSE-arpack.txt b/licenses-binary/LICENSE-arpack.txt new file mode 100644 index 0000000000000..a3ad80087bb63 --- /dev/null +++ b/licenses-binary/LICENSE-arpack.txt @@ -0,0 +1,8 @@ +Copyright © 2018 The University of Tennessee. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +· Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +· Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer listed in this license in the documentation and/or other materials provided with the distribution. +· Neither the name of the copyright holders nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +This software is provided by the copyright holders and contributors "as is" and any express or implied warranties, including, but not limited to, the implied warranties of merchantability and fitness for a particular purpose are disclaimed. in no event shall the copyright owner or contributors be liable for any direct, indirect, incidental, special, exemplary, or consequential damages (including, but not limited to, procurement of substitute goods or services; loss of use, data, or profits; or business interruption) however caused and on any theory of liability, whether in contract, strict liability, or tort (including negligence or otherwise) arising in any way out of the use of this software, even if advised of the possibility of such damage. \ No newline at end of file diff --git a/licenses-binary/LICENSE-automaton.txt b/licenses-binary/LICENSE-automaton.txt new file mode 100644 index 0000000000000..2fc6e8c3432f0 --- /dev/null +++ b/licenses-binary/LICENSE-automaton.txt @@ -0,0 +1,24 @@ +Copyright (c) 2001-2017 Anders Moeller +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. +3. The name of the author may not be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR +IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES +OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, +INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT +NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF +THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-bootstrap.txt b/licenses-binary/LICENSE-bootstrap.txt new file mode 100644 index 0000000000000..6c711832fbc85 --- /dev/null +++ b/licenses-binary/LICENSE-bootstrap.txt @@ -0,0 +1,13 @@ +Copyright 2013 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/licenses-binary/LICENSE-bouncycastle-bcprov.txt b/licenses-binary/LICENSE-bouncycastle-bcprov.txt new file mode 100644 index 0000000000000..c445a93a06dd4 --- /dev/null +++ b/licenses-binary/LICENSE-bouncycastle-bcprov.txt @@ -0,0 +1,7 @@ +Copyright (c) 2000 - 2018 The Legion of the Bouncy Castle Inc. (https://www.bouncycastle.org) + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-cloudpickle.txt b/licenses-binary/LICENSE-cloudpickle.txt new file mode 100644 index 0000000000000..b1e20fa1eda88 --- /dev/null +++ b/licenses-binary/LICENSE-cloudpickle.txt @@ -0,0 +1,28 @@ +Copyright (c) 2012, Regents of the University of California. +Copyright (c) 2009 `PiCloud, Inc. `_. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the University of California, Berkeley nor the + names of its contributors may be used to endorse or promote + products derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-d3.min.js.txt b/licenses-binary/LICENSE-d3.min.js.txt new file mode 100644 index 0000000000000..c71e3f254c068 --- /dev/null +++ b/licenses-binary/LICENSE-d3.min.js.txt @@ -0,0 +1,26 @@ +Copyright (c) 2010-2015, Michael Bostock +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* The name Michael Bostock may not be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL MICHAEL BOSTOCK BE LIABLE FOR ANY DIRECT, +INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-Mockito.txt b/licenses-binary/LICENSE-dagre-d3.txt similarity index 94% rename from licenses/LICENSE-Mockito.txt rename to licenses-binary/LICENSE-dagre-d3.txt index e0840a446caf5..4864fe05e9803 100644 --- a/licenses/LICENSE-Mockito.txt +++ b/licenses-binary/LICENSE-dagre-d3.txt @@ -1,6 +1,4 @@ -The MIT License - -Copyright (c) 2007 Mockito contributors +Copyright (c) 2013 Chris Pettitt Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/licenses-binary/LICENSE-datatables.txt b/licenses-binary/LICENSE-datatables.txt new file mode 100644 index 0000000000000..bb7708b5b5a49 --- /dev/null +++ b/licenses-binary/LICENSE-datatables.txt @@ -0,0 +1,7 @@ +Copyright (C) 2008-2018, SpryMedia Ltd. + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-f2j.txt b/licenses-binary/LICENSE-f2j.txt similarity index 100% rename from licenses/LICENSE-f2j.txt rename to licenses-binary/LICENSE-f2j.txt diff --git a/licenses-binary/LICENSE-graphlib-dot.txt b/licenses-binary/LICENSE-graphlib-dot.txt new file mode 100644 index 0000000000000..4864fe05e9803 --- /dev/null +++ b/licenses-binary/LICENSE-graphlib-dot.txt @@ -0,0 +1,19 @@ +Copyright (c) 2013 Chris Pettitt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-heapq.txt b/licenses-binary/LICENSE-heapq.txt new file mode 100644 index 0000000000000..0c4c4b954bea4 --- /dev/null +++ b/licenses-binary/LICENSE-heapq.txt @@ -0,0 +1,280 @@ + +# A. HISTORY OF THE SOFTWARE +# ========================== +# +# Python was created in the early 1990s by Guido van Rossum at Stichting +# Mathematisch Centrum (CWI, see http://www.cwi.nl) in the Netherlands +# as a successor of a language called ABC. Guido remains Python's +# principal author, although it includes many contributions from others. +# +# In 1995, Guido continued his work on Python at the Corporation for +# National Research Initiatives (CNRI, see http://www.cnri.reston.va.us) +# in Reston, Virginia where he released several versions of the +# software. +# +# In May 2000, Guido and the Python core development team moved to +# BeOpen.com to form the BeOpen PythonLabs team. In October of the same +# year, the PythonLabs team moved to Digital Creations (now Zope +# Corporation, see http://www.zope.com). In 2001, the Python Software +# Foundation (PSF, see http://www.python.org/psf/) was formed, a +# non-profit organization created specifically to own Python-related +# Intellectual Property. Zope Corporation is a sponsoring member of +# the PSF. +# +# All Python releases are Open Source (see http://www.opensource.org for +# the Open Source Definition). Historically, most, but not all, Python +# releases have also been GPL-compatible; the table below summarizes +# the various releases. +# +# Release Derived Year Owner GPL- +# from compatible? (1) +# +# 0.9.0 thru 1.2 1991-1995 CWI yes +# 1.3 thru 1.5.2 1.2 1995-1999 CNRI yes +# 1.6 1.5.2 2000 CNRI no +# 2.0 1.6 2000 BeOpen.com no +# 1.6.1 1.6 2001 CNRI yes (2) +# 2.1 2.0+1.6.1 2001 PSF no +# 2.0.1 2.0+1.6.1 2001 PSF yes +# 2.1.1 2.1+2.0.1 2001 PSF yes +# 2.2 2.1.1 2001 PSF yes +# 2.1.2 2.1.1 2002 PSF yes +# 2.1.3 2.1.2 2002 PSF yes +# 2.2.1 2.2 2002 PSF yes +# 2.2.2 2.2.1 2002 PSF yes +# 2.2.3 2.2.2 2003 PSF yes +# 2.3 2.2.2 2002-2003 PSF yes +# 2.3.1 2.3 2002-2003 PSF yes +# 2.3.2 2.3.1 2002-2003 PSF yes +# 2.3.3 2.3.2 2002-2003 PSF yes +# 2.3.4 2.3.3 2004 PSF yes +# 2.3.5 2.3.4 2005 PSF yes +# 2.4 2.3 2004 PSF yes +# 2.4.1 2.4 2005 PSF yes +# 2.4.2 2.4.1 2005 PSF yes +# 2.4.3 2.4.2 2006 PSF yes +# 2.4.4 2.4.3 2006 PSF yes +# 2.5 2.4 2006 PSF yes +# 2.5.1 2.5 2007 PSF yes +# 2.5.2 2.5.1 2008 PSF yes +# 2.5.3 2.5.2 2008 PSF yes +# 2.6 2.5 2008 PSF yes +# 2.6.1 2.6 2008 PSF yes +# 2.6.2 2.6.1 2009 PSF yes +# 2.6.3 2.6.2 2009 PSF yes +# 2.6.4 2.6.3 2009 PSF yes +# 2.6.5 2.6.4 2010 PSF yes +# 2.7 2.6 2010 PSF yes +# +# Footnotes: +# +# (1) GPL-compatible doesn't mean that we're distributing Python under +# the GPL. All Python licenses, unlike the GPL, let you distribute +# a modified version without making your changes open source. The +# GPL-compatible licenses make it possible to combine Python with +# other software that is released under the GPL; the others don't. +# +# (2) According to Richard Stallman, 1.6.1 is not GPL-compatible, +# because its license has a choice of law clause. According to +# CNRI, however, Stallman's lawyer has told CNRI's lawyer that 1.6.1 +# is "not incompatible" with the GPL. +# +# Thanks to the many outside volunteers who have worked under Guido's +# direction to make these releases possible. +# +# +# B. TERMS AND CONDITIONS FOR ACCESSING OR OTHERWISE USING PYTHON +# =============================================================== +# +# PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +# -------------------------------------------- +# +# 1. This LICENSE AGREEMENT is between the Python Software Foundation +# ("PSF"), and the Individual or Organization ("Licensee") accessing and +# otherwise using this software ("Python") in source or binary form and +# its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, PSF hereby +# grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +# analyze, test, perform and/or display publicly, prepare derivative works, +# distribute, and otherwise use Python alone or in any derivative version, +# provided, however, that PSF's License Agreement and PSF's notice of copyright, +# i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +# 2011, 2012, 2013 Python Software Foundation; All Rights Reserved" are retained +# in Python alone or in any derivative version prepared by Licensee. +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python. +# +# 4. PSF is making Python available to Licensee on an "AS IS" +# basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. Nothing in this License Agreement shall be deemed to create any +# relationship of agency, partnership, or joint venture between PSF and +# Licensee. This License Agreement does not grant permission to use PSF +# trademarks or trade name in a trademark sense to endorse or promote +# products or services of Licensee, or any third party. +# +# 8. By copying, installing or otherwise using Python, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. +# +# +# BEOPEN.COM LICENSE AGREEMENT FOR PYTHON 2.0 +# ------------------------------------------- +# +# BEOPEN PYTHON OPEN SOURCE LICENSE AGREEMENT VERSION 1 +# +# 1. This LICENSE AGREEMENT is between BeOpen.com ("BeOpen"), having an +# office at 160 Saratoga Avenue, Santa Clara, CA 95051, and the +# Individual or Organization ("Licensee") accessing and otherwise using +# this software in source or binary form and its associated +# documentation ("the Software"). +# +# 2. Subject to the terms and conditions of this BeOpen Python License +# Agreement, BeOpen hereby grants Licensee a non-exclusive, +# royalty-free, world-wide license to reproduce, analyze, test, perform +# and/or display publicly, prepare derivative works, distribute, and +# otherwise use the Software alone or in any derivative version, +# provided, however, that the BeOpen Python License is retained in the +# Software, alone or in any derivative version prepared by Licensee. +# +# 3. BeOpen is making the Software available to Licensee on an "AS IS" +# basis. BEOPEN MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, BEOPEN MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF THE SOFTWARE WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 4. BEOPEN SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF THE +# SOFTWARE FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS +# AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE, OR ANY +# DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 5. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 6. This License Agreement shall be governed by and interpreted in all +# respects by the law of the State of California, excluding conflict of +# law provisions. Nothing in this License Agreement shall be deemed to +# create any relationship of agency, partnership, or joint venture +# between BeOpen and Licensee. This License Agreement does not grant +# permission to use BeOpen trademarks or trade names in a trademark +# sense to endorse or promote products or services of Licensee, or any +# third party. As an exception, the "BeOpen Python" logos available at +# http://www.pythonlabs.com/logos.html may be used according to the +# permissions granted on that web page. +# +# 7. By copying, installing or otherwise using the software, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. +# +# +# CNRI LICENSE AGREEMENT FOR PYTHON 1.6.1 +# --------------------------------------- +# +# 1. This LICENSE AGREEMENT is between the Corporation for National +# Research Initiatives, having an office at 1895 Preston White Drive, +# Reston, VA 20191 ("CNRI"), and the Individual or Organization +# ("Licensee") accessing and otherwise using Python 1.6.1 software in +# source or binary form and its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, CNRI +# hereby grants Licensee a nonexclusive, royalty-free, world-wide +# license to reproduce, analyze, test, perform and/or display publicly, +# prepare derivative works, distribute, and otherwise use Python 1.6.1 +# alone or in any derivative version, provided, however, that CNRI's +# License Agreement and CNRI's notice of copyright, i.e., "Copyright (c) +# 1995-2001 Corporation for National Research Initiatives; All Rights +# Reserved" are retained in Python 1.6.1 alone or in any derivative +# version prepared by Licensee. Alternately, in lieu of CNRI's License +# Agreement, Licensee may substitute the following text (omitting the +# quotes): "Python 1.6.1 is made available subject to the terms and +# conditions in CNRI's License Agreement. This Agreement together with +# Python 1.6.1 may be located on the Internet using the following +# unique, persistent identifier (known as a handle): 1895.22/1013. This +# Agreement may also be obtained from a proxy server on the Internet +# using the following URL: http://hdl.handle.net/1895.22/1013". +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python 1.6.1 or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python 1.6.1. +# +# 4. CNRI is making Python 1.6.1 available to Licensee on an "AS IS" +# basis. CNRI MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, CNRI MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON 1.6.1 WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. CNRI SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# 1.6.1 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON 1.6.1, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. This License Agreement shall be governed by the federal +# intellectual property law of the United States, including without +# limitation the federal copyright law, and, to the extent such +# U.S. federal law does not apply, by the law of the Commonwealth of +# Virginia, excluding Virginia's conflict of law provisions. +# Notwithstanding the foregoing, with regard to derivative works based +# on Python 1.6.1 that incorporate non-separable material that was +# previously distributed under the GNU General Public License (GPL), the +# law of the Commonwealth of Virginia shall govern this License +# Agreement only as to issues arising under or with respect to +# Paragraphs 4, 5, and 7 of this License Agreement. Nothing in this +# License Agreement shall be deemed to create any relationship of +# agency, partnership, or joint venture between CNRI and Licensee. This +# License Agreement does not grant permission to use CNRI trademarks or +# trade name in a trademark sense to endorse or promote products or +# services of Licensee, or any third party. +# +# 8. By clicking on the "ACCEPT" button where indicated, or by copying, +# installing or otherwise using Python 1.6.1, Licensee agrees to be +# bound by the terms and conditions of this License Agreement. +# +# ACCEPT +# +# +# CWI LICENSE AGREEMENT FOR PYTHON 0.9.0 THROUGH 1.2 +# -------------------------------------------------- +# +# Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, +# The Netherlands. All rights reserved. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose and without fee is hereby granted, +# provided that the above copyright notice appear in all copies and that +# both that copyright notice and this permission notice appear in +# supporting documentation, and that the name of Stichting Mathematisch +# Centrum or CWI not be used in advertising or publicity pertaining to +# distribution of the software without specific, written prior +# permission. +# +# STICHTING MATHEMATISCH CENTRUM DISCLAIMS ALL WARRANTIES WITH REGARD TO +# THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS, IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM BE LIABLE +# FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-janino.txt b/licenses-binary/LICENSE-janino.txt new file mode 100644 index 0000000000000..d1e1f237c4641 --- /dev/null +++ b/licenses-binary/LICENSE-janino.txt @@ -0,0 +1,31 @@ +Janino - An embedded Java[TM] compiler + +Copyright (c) 2001-2016, Arno Unkrig +Copyright (c) 2015-2016 TIBCO Software Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials + provided with the distribution. + 3. Neither the name of JANINO nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER +IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN +IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-javassist.html b/licenses-binary/LICENSE-javassist.html new file mode 100644 index 0000000000000..5abd563a0c4d9 --- /dev/null +++ b/licenses-binary/LICENSE-javassist.html @@ -0,0 +1,373 @@ + + + Javassist License + + + + +
    MOZILLA PUBLIC LICENSE
    Version + 1.1 +

    +


    +
    +

    1. Definitions. +

      1.0.1. "Commercial Use" means distribution or otherwise making the + Covered Code available to a third party. +

      1.1. ''Contributor'' means each entity that creates or contributes + to the creation of Modifications. +

      1.2. ''Contributor Version'' means the combination of the Original + Code, prior Modifications used by a Contributor, and the Modifications made by + that particular Contributor. +

      1.3. ''Covered Code'' means the Original Code or Modifications or + the combination of the Original Code and Modifications, in each case including + portions thereof. +

      1.4. ''Electronic Distribution Mechanism'' means a mechanism + generally accepted in the software development community for the electronic + transfer of data. +

      1.5. ''Executable'' means Covered Code in any form other than Source + Code. +

      1.6. ''Initial Developer'' means the individual or entity identified + as the Initial Developer in the Source Code notice required by Exhibit + A. +

      1.7. ''Larger Work'' means a work which combines Covered Code or + portions thereof with code not governed by the terms of this License. +

      1.8. ''License'' means this document. +

      1.8.1. "Licensable" means having the right to grant, to the maximum + extent possible, whether at the time of the initial grant or subsequently + acquired, any and all of the rights conveyed herein. +

      1.9. ''Modifications'' means any addition to or deletion from the + substance or structure of either the Original Code or any previous + Modifications. When Covered Code is released as a series of files, a + Modification is: +

        A. Any addition to or deletion from the contents of a file + containing Original Code or previous Modifications. +

        B. Any new file that contains any part of the Original Code or + previous Modifications.
         

      1.10. ''Original Code'' +means Source Code of computer software code which is described in the Source +Code notice required by Exhibit A as Original Code, and which, at the +time of its release under this License is not already Covered Code governed by +this License. +

      1.10.1. "Patent Claims" means any patent claim(s), now owned or + hereafter acquired, including without limitation,  method, process, and + apparatus claims, in any patent Licensable by grantor. +

      1.11. ''Source Code'' means the preferred form of the Covered Code + for making modifications to it, including all modules it contains, plus any + associated interface definition files, scripts used to control compilation and + installation of an Executable, or source code differential comparisons against + either the Original Code or another well known, available Covered Code of the + Contributor's choice. The Source Code can be in a compressed or archival form, + provided the appropriate decompression or de-archiving software is widely + available for no charge. +

      1.12. "You'' (or "Your")  means an individual or a legal entity + exercising rights under, and complying with all of the terms of, this License + or a future version of this License issued under Section 6.1. For legal + entities, "You'' includes any entity which controls, is controlled by, or is + under common control with You. For purposes of this definition, "control'' + means (a) the power, direct or indirect, to cause the direction or management + of such entity, whether by contract or otherwise, or (b) ownership of more + than fifty percent (50%) of the outstanding shares or beneficial ownership of + such entity.

    2. Source Code License. +
      2.1. The Initial Developer Grant.
      The Initial Developer hereby + grants You a world-wide, royalty-free, non-exclusive license, subject to third + party intellectual property claims: +
        (a)  under intellectual property rights (other than + patent or trademark) Licensable by Initial Developer to use, reproduce, + modify, display, perform, sublicense and distribute the Original Code (or + portions thereof) with or without Modifications, and/or as part of a Larger + Work; and +

        (b) under Patents Claims infringed by the making, using or selling + of Original Code, to make, have made, use, practice, sell, and offer for + sale, and/or otherwise dispose of the Original Code (or portions thereof). +

          +
          (c) the licenses granted in this Section 2.1(a) and (b) + are effective on the date Initial Developer first distributes Original Code + under the terms of this License. +

          (d) Notwithstanding Section 2.1(b) above, no patent license is + granted: 1) for code that You delete from the Original Code; 2) separate + from the Original Code;  or 3) for infringements caused by: i) the + modification of the Original Code or ii) the combination of the Original + Code with other software or devices.
           

        2.2. Contributor + Grant.
        Subject to third party intellectual property claims, each + Contributor hereby grants You a world-wide, royalty-free, non-exclusive + license +

          (a)  under intellectual property rights (other + than patent or trademark) Licensable by Contributor, to use, reproduce, + modify, display, perform, sublicense and distribute the Modifications + created by such Contributor (or portions thereof) either on an unmodified + basis, with other Modifications, as Covered Code and/or as part of a Larger + Work; and +

          (b) under Patent Claims infringed by the making, using, or selling + of  Modifications made by that Contributor either alone and/or in combination with its Contributor Version (or portions of such + combination), to make, use, sell, offer for sale, have made, and/or + otherwise dispose of: 1) Modifications made by that Contributor (or portions + thereof); and 2) the combination of  Modifications made by that + Contributor with its Contributor Version (or portions of such + combination). +

          (c) the licenses granted in Sections 2.2(a) and 2.2(b) are + effective on the date Contributor first makes Commercial Use of the Covered + Code. +

          (d)    Notwithstanding Section 2.2(b) above, no + patent license is granted: 1) for any code that Contributor has deleted from + the Contributor Version; 2)  separate from the Contributor + Version;  3)  for infringements caused by: i) third party + modifications of Contributor Version or ii)  the combination of + Modifications made by that Contributor with other software  (except as + part of the Contributor Version) or other devices; or 4) under Patent Claims + infringed by Covered Code in the absence of Modifications made by that + Contributor.

      +


      3. Distribution Obligations. +

        3.1. Application of License.
        The Modifications which You create + or to which You contribute are governed by the terms of this License, + including without limitation Section 2.2. The Source Code version of + Covered Code may be distributed only under the terms of this License or a + future version of this License released under Section 6.1, and You must + include a copy of this License with every copy of the Source Code You + distribute. You may not offer or impose any terms on any Source Code version + that alters or restricts the applicable version of this License or the + recipients' rights hereunder. However, You may include an additional document + offering the additional rights described in Section 3.5. +

        3.2. Availability of Source Code.
        Any Modification which You + create or to which You contribute must be made available in Source Code form + under the terms of this License either on the same media as an Executable + version or via an accepted Electronic Distribution Mechanism to anyone to whom + you made an Executable version available; and if made available via Electronic + Distribution Mechanism, must remain available for at least twelve (12) months + after the date it initially became available, or at least six (6) months after + a subsequent version of that particular Modification has been made available + to such recipients. You are responsible for ensuring that the Source Code + version remains available even if the Electronic Distribution Mechanism is + maintained by a third party. +

        3.3. Description of Modifications.
        You must cause all Covered + Code to which You contribute to contain a file documenting the changes You + made to create that Covered Code and the date of any change. You must include + a prominent statement that the Modification is derived, directly or + indirectly, from Original Code provided by the Initial Developer and including + the name of the Initial Developer in (a) the Source Code, and (b) in any + notice in an Executable version or related documentation in which You describe + the origin or ownership of the Covered Code. +

        3.4. Intellectual Property Matters +

          (a) Third Party Claims.
          If Contributor has knowledge that a + license under a third party's intellectual property rights is required to + exercise the rights granted by such Contributor under Sections 2.1 or 2.2, + Contributor must include a text file with the Source Code distribution + titled "LEGAL'' which describes the claim and the party making the claim in + sufficient detail that a recipient will know whom to contact. If Contributor + obtains such knowledge after the Modification is made available as described + in Section 3.2, Contributor shall promptly modify the LEGAL file in all + copies Contributor makes available thereafter and shall take other steps + (such as notifying appropriate mailing lists or newsgroups) reasonably + calculated to inform those who received the Covered Code that new knowledge + has been obtained. +

          (b) Contributor APIs.
          If Contributor's Modifications include + an application programming interface and Contributor has knowledge of patent + licenses which are reasonably necessary to implement that API, Contributor + must also include this information in the LEGAL file. +
           

                  +(c)    Representations. +
          Contributor represents that, except as disclosed pursuant to Section + 3.4(a) above, Contributor believes that Contributor's Modifications are + Contributor's original creation(s) and/or Contributor has sufficient rights + to grant the rights conveyed by this License.
        +


        3.5. Required Notices.
        You must duplicate the notice in + Exhibit A in each file of the Source Code.  If it is not possible + to put such notice in a particular Source Code file due to its structure, then + You must include such notice in a location (such as a relevant directory) + where a user would be likely to look for such a notice.  If You created + one or more Modification(s) You may add your name as a Contributor to the + notice described in Exhibit A.  You must also duplicate this + License in any documentation for the Source Code where You describe + recipients' rights or ownership rights relating to Covered Code.  You may + choose to offer, and to charge a fee for, warranty, support, indemnity or + liability obligations to one or more recipients of Covered Code. However, You + may do so only on Your own behalf, and not on behalf of the Initial Developer + or any Contributor. You must make it absolutely clear than any such warranty, + support, indemnity or liability obligation is offered by You alone, and You + hereby agree to indemnify the Initial Developer and every Contributor for any + liability incurred by the Initial Developer or such Contributor as a result of + warranty, support, indemnity or liability terms You offer. +

        3.6. Distribution of Executable Versions.
        You may distribute + Covered Code in Executable form only if the requirements of Section + 3.1-3.5 have been met for that Covered Code, and if You include a + notice stating that the Source Code version of the Covered Code is available + under the terms of this License, including a description of how and where You + have fulfilled the obligations of Section 3.2. The notice must be + conspicuously included in any notice in an Executable version, related + documentation or collateral in which You describe recipients' rights relating + to the Covered Code. You may distribute the Executable version of Covered Code + or ownership rights under a license of Your choice, which may contain terms + different from this License, provided that You are in compliance with the + terms of this License and that the license for the Executable version does not + attempt to limit or alter the recipient's rights in the Source Code version + from the rights set forth in this License. If You distribute the Executable + version under a different license You must make it absolutely clear that any + terms which differ from this License are offered by You alone, not by the + Initial Developer or any Contributor. You hereby agree to indemnify the + Initial Developer and every Contributor for any liability incurred by the + Initial Developer or such Contributor as a result of any such terms You offer. + +

        3.7. Larger Works.
        You may create a Larger Work by combining + Covered Code with other code not governed by the terms of this License and + distribute the Larger Work as a single product. In such a case, You must make + sure the requirements of this License are fulfilled for the Covered + Code.

      4. Inability to Comply Due to Statute or Regulation. +
        If it is impossible for You to comply with any of the terms of this + License with respect to some or all of the Covered Code due to statute, + judicial order, or regulation then You must: (a) comply with the terms of this + License to the maximum extent possible; and (b) describe the limitations and + the code they affect. Such description must be included in the LEGAL file + described in Section 3.4 and must be included with all distributions of + the Source Code. Except to the extent prohibited by statute or regulation, + such description must be sufficiently detailed for a recipient of ordinary + skill to be able to understand it.
      5. Application of this License. +
        This License applies to code to which the Initial Developer has attached + the notice in Exhibit A and to related Covered Code.
      6. Versions + of the License. +
        6.1. New Versions.
        Netscape Communications Corporation + (''Netscape'') may publish revised and/or new versions of the License from + time to time. Each version will be given a distinguishing version number. +

        6.2. Effect of New Versions.
        Once Covered Code has been + published under a particular version of the License, You may always continue + to use it under the terms of that version. You may also choose to use such + Covered Code under the terms of any subsequent version of the License + published by Netscape. No one other than Netscape has the right to modify the + terms applicable to Covered Code created under this License. +

        6.3. Derivative Works.
        If You create or use a modified version + of this License (which you may only do in order to apply it to code which is + not already Covered Code governed by this License), You must (a) rename Your + license so that the phrases ''Mozilla'', ''MOZILLAPL'', ''MOZPL'', + ''Netscape'', "MPL", ''NPL'' or any confusingly similar phrase do not appear + in your license (except to note that your license differs from this License) + and (b) otherwise make it clear that Your version of the license contains + terms which differ from the Mozilla Public License and Netscape Public + License. (Filling in the name of the Initial Developer, Original Code or + Contributor in the notice described in Exhibit A shall not of + themselves be deemed to be modifications of this License.)

      7. + DISCLAIMER OF WARRANTY. +
        COVERED CODE IS PROVIDED UNDER THIS LICENSE ON AN "AS IS'' BASIS, WITHOUT + WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, WITHOUT + LIMITATION, WARRANTIES THAT THE COVERED CODE IS FREE OF DEFECTS, MERCHANTABLE, + FIT FOR A PARTICULAR PURPOSE OR NON-INFRINGING. THE ENTIRE RISK AS TO THE + QUALITY AND PERFORMANCE OF THE COVERED CODE IS WITH YOU. SHOULD ANY COVERED + CODE PROVE DEFECTIVE IN ANY RESPECT, YOU (NOT THE INITIAL DEVELOPER OR ANY + OTHER CONTRIBUTOR) ASSUME THE COST OF ANY NECESSARY SERVICING, REPAIR OR + CORRECTION. THIS DISCLAIMER OF WARRANTY CONSTITUTES AN ESSENTIAL PART OF THIS + LICENSE. NO USE OF ANY COVERED CODE IS AUTHORIZED HEREUNDER EXCEPT UNDER THIS + DISCLAIMER.
      8. TERMINATION. +
        8.1.  This License and the rights granted hereunder will + terminate automatically if You fail to comply with terms herein and fail to + cure such breach within 30 days of becoming aware of the breach. All + sublicenses to the Covered Code which are properly granted shall survive any + termination of this License. Provisions which, by their nature, must remain in + effect beyond the termination of this License shall survive. +

        8.2.  If You initiate litigation by asserting a patent + infringement claim (excluding declatory judgment actions) against Initial + Developer or a Contributor (the Initial Developer or Contributor against whom + You file such action is referred to as "Participant")  alleging that: +

        (a)  such Participant's Contributor Version directly or + indirectly infringes any patent, then any and all rights granted by such + Participant to You under Sections 2.1 and/or 2.2 of this License shall, upon + 60 days notice from Participant terminate prospectively, unless if within 60 + days after receipt of notice You either: (i)  agree in writing to pay + Participant a mutually agreeable reasonable royalty for Your past and future + use of Modifications made by such Participant, or (ii) withdraw Your + litigation claim with respect to the Contributor Version against such + Participant.  If within 60 days of notice, a reasonable royalty and + payment arrangement are not mutually agreed upon in writing by the parties or + the litigation claim is not withdrawn, the rights granted by Participant to + You under Sections 2.1 and/or 2.2 automatically terminate at the expiration of + the 60 day notice period specified above. +

        (b)  any software, hardware, or device, other than such + Participant's Contributor Version, directly or indirectly infringes any + patent, then any rights granted to You by such Participant under Sections + 2.1(b) and 2.2(b) are revoked effective as of the date You first made, used, + sold, distributed, or had made, Modifications made by that Participant. +

        8.3.  If You assert a patent infringement claim against + Participant alleging that such Participant's Contributor Version directly or + indirectly infringes any patent where such claim is resolved (such as by + license or settlement) prior to the initiation of patent infringement + litigation, then the reasonable value of the licenses granted by such + Participant under Sections 2.1 or 2.2 shall be taken into account in + determining the amount or value of any payment or license. +

        8.4.  In the event of termination under Sections 8.1 or 8.2 + above,  all end user license agreements (excluding distributors and + resellers) which have been validly granted by You or any distributor hereunder + prior to termination shall survive termination.

      9. LIMITATION OF + LIABILITY. +
        UNDER NO CIRCUMSTANCES AND UNDER NO LEGAL THEORY, WHETHER TORT (INCLUDING + NEGLIGENCE), CONTRACT, OR OTHERWISE, SHALL YOU, THE INITIAL DEVELOPER, ANY + OTHER CONTRIBUTOR, OR ANY DISTRIBUTOR OF COVERED CODE, OR ANY SUPPLIER OF ANY + OF SUCH PARTIES, BE LIABLE TO ANY PERSON FOR ANY INDIRECT, SPECIAL, + INCIDENTAL, OR CONSEQUENTIAL DAMAGES OF ANY CHARACTER INCLUDING, WITHOUT + LIMITATION, DAMAGES FOR LOSS OF GOODWILL, WORK STOPPAGE, COMPUTER FAILURE OR + MALFUNCTION, OR ANY AND ALL OTHER COMMERCIAL DAMAGES OR LOSSES, EVEN IF SUCH + PARTY SHALL HAVE BEEN INFORMED OF THE POSSIBILITY OF SUCH DAMAGES. THIS + LIMITATION OF LIABILITY SHALL NOT APPLY TO LIABILITY FOR DEATH OR PERSONAL + INJURY RESULTING FROM SUCH PARTY'S NEGLIGENCE TO THE EXTENT APPLICABLE LAW + PROHIBITS SUCH LIMITATION. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OR + LIMITATION OF INCIDENTAL OR CONSEQUENTIAL DAMAGES, SO THIS EXCLUSION AND + LIMITATION MAY NOT APPLY TO YOU.
      10. U.S. GOVERNMENT END USERS. +
        The Covered Code is a ''commercial item,'' as that term is defined in 48 + C.F.R. 2.101 (Oct. 1995), consisting of ''commercial computer software'' and + ''commercial computer software documentation,'' as such terms are used in 48 + C.F.R. 12.212 (Sept. 1995). Consistent with 48 C.F.R. 12.212 and 48 C.F.R. + 227.7202-1 through 227.7202-4 (June 1995), all U.S. Government End Users + acquire Covered Code with only those rights set forth herein.
      11. + MISCELLANEOUS. +
        This License represents the complete agreement concerning subject matter + hereof. If any provision of this License is held to be unenforceable, such + provision shall be reformed only to the extent necessary to make it + enforceable. This License shall be governed by California law provisions + (except to the extent applicable law, if any, provides otherwise), excluding + its conflict-of-law provisions. With respect to disputes in which at least one + party is a citizen of, or an entity chartered or registered to do business in + the United States of America, any litigation relating to this License shall be + subject to the jurisdiction of the Federal Courts of the Northern District of + California, with venue lying in Santa Clara County, California, with the + losing party responsible for costs, including without limitation, court costs + and reasonable attorneys' fees and expenses. The application of the United + Nations Convention on Contracts for the International Sale of Goods is + expressly excluded. Any law or regulation which provides that the language of + a contract shall be construed against the drafter shall not apply to this + License.
      12. RESPONSIBILITY FOR CLAIMS. +
        As between Initial Developer and the Contributors, each party is + responsible for claims and damages arising, directly or indirectly, out of its + utilization of rights under this License and You agree to work with Initial + Developer and Contributors to distribute such responsibility on an equitable + basis. Nothing herein is intended or shall be deemed to constitute any + admission of liability.
      13. MULTIPLE-LICENSED CODE. +
        Initial Developer may designate portions of the Covered Code as + "Multiple-Licensed".  "Multiple-Licensed" means that the Initial + Developer permits you to utilize portions of the Covered Code under Your + choice of the MPL or the alternative licenses, if any, specified by the + Initial Developer in the file described in Exhibit A.
      +


      EXHIBIT A -Mozilla Public License. +

        The contents of this file are subject to the Mozilla Public License + Version 1.1 (the "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at +
        http://www.mozilla.org/MPL/ +

        Software distributed under the License is distributed on an "AS IS" basis, + WITHOUT WARRANTY OF
        ANY KIND, either express or implied. See the License + for the specific language governing rights and
        limitations under the + License. +

        The Original Code is Javassist. +

        The Initial Developer of the Original Code is Shigeru Chiba. + Portions created by the Initial Developer are
          + Copyright (C) 1999- Shigeru Chiba. All Rights Reserved. +

        Contributor(s): __Bill Burke, Jason T. Greene______________. + +

        Alternatively, the contents of this software may be used under the + terms of the GNU Lesser General Public License Version 2.1 or later + (the "LGPL"), or the Apache License Version 2.0 (the "AL"), + in which case the provisions of the LGPL or the AL are applicable + instead of those above. If you wish to allow use of your version of + this software only under the terms of either the LGPL or the AL, and not to allow others to + use your version of this software under the terms of the MPL, indicate + your decision by deleting the provisions above and replace them with + the notice and other provisions required by the LGPL or the AL. If you do not + delete the provisions above, a recipient may use your version of this + software under the terms of any one of the MPL, the LGPL or the AL. + +

      + + \ No newline at end of file diff --git a/licenses/LICENSE-javolution.txt b/licenses-binary/LICENSE-javolution.txt similarity index 100% rename from licenses/LICENSE-javolution.txt rename to licenses-binary/LICENSE-javolution.txt diff --git a/licenses/LICENSE-jline.txt b/licenses-binary/LICENSE-jline.txt similarity index 100% rename from licenses/LICENSE-jline.txt rename to licenses-binary/LICENSE-jline.txt diff --git a/licenses/LICENSE-junit-interface.txt b/licenses-binary/LICENSE-jodd.txt similarity index 69% rename from licenses/LICENSE-junit-interface.txt rename to licenses-binary/LICENSE-jodd.txt index e835350c4e2a4..cc6b458adb386 100644 --- a/licenses/LICENSE-junit-interface.txt +++ b/licenses-binary/LICENSE-jodd.txt @@ -1,15 +1,15 @@ -Copyright (c) 2009-2012, Stefan Zeiger +Copyright (c) 2003-present, Jodd Team (https://jodd.org) All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - * Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. +1. Redistributions of source code must retain the above copyright notice, +this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE diff --git a/licenses/LICENSE-DPark.txt b/licenses-binary/LICENSE-join.txt similarity index 100% rename from licenses/LICENSE-DPark.txt rename to licenses-binary/LICENSE-join.txt diff --git a/licenses-binary/LICENSE-jquery.txt b/licenses-binary/LICENSE-jquery.txt new file mode 100644 index 0000000000000..45930542204fb --- /dev/null +++ b/licenses-binary/LICENSE-jquery.txt @@ -0,0 +1,20 @@ +Copyright JS Foundation and other contributors, https://js.foundation/ + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-json-formatter.txt b/licenses-binary/LICENSE-json-formatter.txt new file mode 100644 index 0000000000000..5193348fce126 --- /dev/null +++ b/licenses-binary/LICENSE-json-formatter.txt @@ -0,0 +1,6 @@ +Copyright 2014 Mohsen Azimi + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. \ No newline at end of file diff --git a/licenses-binary/LICENSE-jtransforms.html b/licenses-binary/LICENSE-jtransforms.html new file mode 100644 index 0000000000000..351c17412357b --- /dev/null +++ b/licenses-binary/LICENSE-jtransforms.html @@ -0,0 +1,388 @@ + + +Mozilla Public License version 1.1 + + + + +

      Mozilla Public License Version 1.1

      +

      1. Definitions.

      +
      +
      1.0.1. "Commercial Use" +
      means distribution or otherwise making the Covered Code available to a third party. +
      1.1. "Contributor" +
      means each entity that creates or contributes to the creation of Modifications. +
      1.2. "Contributor Version" +
      means the combination of the Original Code, prior Modifications used by a Contributor, + and the Modifications made by that particular Contributor. +
      1.3. "Covered Code" +
      means the Original Code or Modifications or the combination of the Original Code and + Modifications, in each case including portions thereof. +
      1.4. "Electronic Distribution Mechanism" +
      means a mechanism generally accepted in the software development community for the + electronic transfer of data. +
      1.5. "Executable" +
      means Covered Code in any form other than Source Code. +
      1.6. "Initial Developer" +
      means the individual or entity identified as the Initial Developer in the Source Code + notice required by Exhibit A. +
      1.7. "Larger Work" +
      means a work which combines Covered Code or portions thereof with code not governed + by the terms of this License. +
      1.8. "License" +
      means this document. +
      1.8.1. "Licensable" +
      means having the right to grant, to the maximum extent possible, whether at the + time of the initial grant or subsequently acquired, any and all of the rights + conveyed herein. +
      1.9. "Modifications" +
      +

      means any addition to or deletion from the substance or structure of either the + Original Code or any previous Modifications. When Covered Code is released as a + series of files, a Modification is: +

        +
      1. Any addition to or deletion from the contents of a file + containing Original Code or previous Modifications. +
      2. Any new file that contains any part of the Original Code or + previous Modifications. +
      +
      1.10. "Original Code" +
      means Source Code of computer software code which is described in the Source Code + notice required by Exhibit A as Original Code, and which, + at the time of its release under this License is not already Covered Code governed + by this License. +
      1.10.1. "Patent Claims" +
      means any patent claim(s), now owned or hereafter acquired, including without + limitation, method, process, and apparatus claims, in any patent Licensable by + grantor. +
      1.11. "Source Code" +
      means the preferred form of the Covered Code for making modifications to it, + including all modules it contains, plus any associated interface definition files, + scripts used to control compilation and installation of an Executable, or source + code differential comparisons against either the Original Code or another well known, + available Covered Code of the Contributor's choice. The Source Code can be in a + compressed or archival form, provided the appropriate decompression or de-archiving + software is widely available for no charge. +
      1.12. "You" (or "Your") +
      means an individual or a legal entity exercising rights under, and complying with + all of the terms of, this License or a future version of this License issued under + Section 6.1. For legal entities, "You" includes any entity + which controls, is controlled by, or is under common control with You. For purposes of + this definition, "control" means (a) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or otherwise, or (b) + ownership of more than fifty percent (50%) of the outstanding shares or beneficial + ownership of such entity. +
      +

      2. Source Code License.

      +

      2.1. The Initial Developer Grant.

      +

      The Initial Developer hereby grants You a world-wide, royalty-free, non-exclusive + license, subject to third party intellectual property claims: +

        +
      1. under intellectual property rights (other than patent or + trademark) Licensable by Initial Developer to use, reproduce, modify, display, perform, + sublicense and distribute the Original Code (or portions thereof) with or without + Modifications, and/or as part of a Larger Work; and +
      2. under Patents Claims infringed by the making, using or selling + of Original Code, to make, have made, use, practice, sell, and offer for sale, and/or + otherwise dispose of the Original Code (or portions thereof). +
      3. the licenses granted in this Section 2.1 + (a) and (b) are effective on + the date Initial Developer first distributes Original Code under the terms of this + License. +
      4. Notwithstanding Section 2.1 (b) + above, no patent license is granted: 1) for code that You delete from the Original Code; + 2) separate from the Original Code; or 3) for infringements caused by: i) the + modification of the Original Code or ii) the combination of the Original Code with other + software or devices. +
      +

      2.2. Contributor Grant.

      +

      Subject to third party intellectual property claims, each Contributor hereby grants You + a world-wide, royalty-free, non-exclusive license +

        +
      1. under intellectual property rights (other than patent or trademark) + Licensable by Contributor, to use, reproduce, modify, display, perform, sublicense and + distribute the Modifications created by such Contributor (or portions thereof) either on + an unmodified basis, with other Modifications, as Covered Code and/or as part of a Larger + Work; and +
      2. under Patent Claims infringed by the making, using, or selling of + Modifications made by that Contributor either alone and/or in combination with its + Contributor Version (or portions of such combination), to make, use, sell, offer for + sale, have made, and/or otherwise dispose of: 1) Modifications made by that Contributor + (or portions thereof); and 2) the combination of Modifications made by that Contributor + with its Contributor Version (or portions of such combination). +
      3. the licenses granted in Sections 2.2 + (a) and 2.2 (b) are effective + on the date Contributor first makes Commercial Use of the Covered Code. +
      4. Notwithstanding Section 2.2 (b) + above, no patent license is granted: 1) for any code that Contributor has deleted from + the Contributor Version; 2) separate from the Contributor Version; 3) for infringements + caused by: i) third party modifications of Contributor Version or ii) the combination of + Modifications made by that Contributor with other software (except as part of the + Contributor Version) or other devices; or 4) under Patent Claims infringed by Covered Code + in the absence of Modifications made by that Contributor. +
      +

      3. Distribution Obligations.

      +

      3.1. Application of License.

      +

      The Modifications which You create or to which You contribute are governed by the terms + of this License, including without limitation Section 2.2. The + Source Code version of Covered Code may be distributed only under the terms of this License + or a future version of this License released under Section 6.1, + and You must include a copy of this License with every copy of the Source Code You + distribute. You may not offer or impose any terms on any Source Code version that alters or + restricts the applicable version of this License or the recipients' rights hereunder. + However, You may include an additional document offering the additional rights described in + Section 3.5. +

      3.2. Availability of Source Code.

      +

      Any Modification which You create or to which You contribute must be made available in + Source Code form under the terms of this License either on the same media as an Executable + version or via an accepted Electronic Distribution Mechanism to anyone to whom you made an + Executable version available; and if made available via Electronic Distribution Mechanism, + must remain available for at least twelve (12) months after the date it initially became + available, or at least six (6) months after a subsequent version of that particular + Modification has been made available to such recipients. You are responsible for ensuring + that the Source Code version remains available even if the Electronic Distribution + Mechanism is maintained by a third party. +

      3.3. Description of Modifications.

      +

      You must cause all Covered Code to which You contribute to contain a file documenting the + changes You made to create that Covered Code and the date of any change. You must include a + prominent statement that the Modification is derived, directly or indirectly, from Original + Code provided by the Initial Developer and including the name of the Initial Developer in + (a) the Source Code, and (b) in any notice in an Executable version or related documentation + in which You describe the origin or ownership of the Covered Code. +

      3.4. Intellectual Property Matters

      +

      (a) Third Party Claims

      +

      If Contributor has knowledge that a license under a third party's intellectual property + rights is required to exercise the rights granted by such Contributor under Sections + 2.1 or 2.2, Contributor must include a + text file with the Source Code distribution titled "LEGAL" which describes the claim and the + party making the claim in sufficient detail that a recipient will know whom to contact. If + Contributor obtains such knowledge after the Modification is made available as described in + Section 3.2, Contributor shall promptly modify the LEGAL file in + all copies Contributor makes available thereafter and shall take other steps (such as + notifying appropriate mailing lists or newsgroups) reasonably calculated to inform those who + received the Covered Code that new knowledge has been obtained. +

      (b) Contributor APIs

      +

      If Contributor's Modifications include an application programming interface and Contributor + has knowledge of patent licenses which are reasonably necessary to implement that + API, Contributor must also include this information in the + legal file. +

      (c) Representations.

      +

      Contributor represents that, except as disclosed pursuant to Section 3.4 + (a) above, Contributor believes that Contributor's Modifications + are Contributor's original creation(s) and/or Contributor has sufficient rights to grant the + rights conveyed by this License. +

      3.5. Required Notices.

      +

      You must duplicate the notice in Exhibit A in each file of the + Source Code. If it is not possible to put such notice in a particular Source Code file due to + its structure, then You must include such notice in a location (such as a relevant directory) + where a user would be likely to look for such a notice. If You created one or more + Modification(s) You may add your name as a Contributor to the notice described in + Exhibit A. You must also duplicate this License in any documentation + for the Source Code where You describe recipients' rights or ownership rights relating to + Covered Code. You may choose to offer, and to charge a fee for, warranty, support, indemnity + or liability obligations to one or more recipients of Covered Code. However, You may do so + only on Your own behalf, and not on behalf of the Initial Developer or any Contributor. You + must make it absolutely clear than any such warranty, support, indemnity or liability + obligation is offered by You alone, and You hereby agree to indemnify the Initial Developer + and every Contributor for any liability incurred by the Initial Developer or such Contributor + as a result of warranty, support, indemnity or liability terms You offer. +

      3.6. Distribution of Executable Versions.

      +

      You may distribute Covered Code in Executable form only if the requirements of Sections + 3.1, 3.2, + 3.3, 3.4 and + 3.5 have been met for that Covered Code, and if You include a + notice stating that the Source Code version of the Covered Code is available under the terms + of this License, including a description of how and where You have fulfilled the obligations + of Section 3.2. The notice must be conspicuously included in any + notice in an Executable version, related documentation or collateral in which You describe + recipients' rights relating to the Covered Code. You may distribute the Executable version of + Covered Code or ownership rights under a license of Your choice, which may contain terms + different from this License, provided that You are in compliance with the terms of this + License and that the license for the Executable version does not attempt to limit or alter the + recipient's rights in the Source Code version from the rights set forth in this License. If + You distribute the Executable version under a different license You must make it absolutely + clear that any terms which differ from this License are offered by You alone, not by the + Initial Developer or any Contributor. You hereby agree to indemnify the Initial Developer and + every Contributor for any liability incurred by the Initial Developer or such Contributor as + a result of any such terms You offer. +

      3.7. Larger Works.

      +

      You may create a Larger Work by combining Covered Code with other code not governed by the + terms of this License and distribute the Larger Work as a single product. In such a case, + You must make sure the requirements of this License are fulfilled for the Covered Code. +

      4. Inability to Comply Due to Statute or Regulation.

      +

      If it is impossible for You to comply with any of the terms of this License with respect to + some or all of the Covered Code due to statute, judicial order, or regulation then You must: + (a) comply with the terms of this License to the maximum extent possible; and (b) describe + the limitations and the code they affect. Such description must be included in the + legal file described in Section + 3.4 and must be included with all distributions of the Source Code. + Except to the extent prohibited by statute or regulation, such description must be + sufficiently detailed for a recipient of ordinary skill to be able to understand it. +

      5. Application of this License.

      +

      This License applies to code to which the Initial Developer has attached the notice in + Exhibit A and to related Covered Code. +

      6. Versions of the License.

      +

      6.1. New Versions

      +

      Netscape Communications Corporation ("Netscape") may publish revised and/or new versions + of the License from time to time. Each version will be given a distinguishing version number. +

      6.2. Effect of New Versions

      +

      Once Covered Code has been published under a particular version of the License, You may + always continue to use it under the terms of that version. You may also choose to use such + Covered Code under the terms of any subsequent version of the License published by Netscape. + No one other than Netscape has the right to modify the terms applicable to Covered Code + created under this License. +

      6.3. Derivative Works

      +

      If You create or use a modified version of this License (which you may only do in order to + apply it to code which is not already Covered Code governed by this License), You must (a) + rename Your license so that the phrases "Mozilla", "MOZILLAPL", "MOZPL", "Netscape", "MPL", + "NPL" or any confusingly similar phrase do not appear in your license (except to note that + your license differs from this License) and (b) otherwise make it clear that Your version of + the license contains terms which differ from the Mozilla Public License and Netscape Public + License. (Filling in the name of the Initial Developer, Original Code or Contributor in the + notice described in Exhibit A shall not of themselves be deemed to + be modifications of this License.) +

      7. Disclaimer of warranty

      +

      Covered code is provided under this license on an "as is" + basis, without warranty of any kind, either expressed or implied, including, without + limitation, warranties that the covered code is free of defects, merchantable, fit for a + particular purpose or non-infringing. The entire risk as to the quality and performance of + the covered code is with you. Should any covered code prove defective in any respect, you + (not the initial developer or any other contributor) assume the cost of any necessary + servicing, repair or correction. This disclaimer of warranty constitutes an essential part + of this license. No use of any covered code is authorized hereunder except under this + disclaimer. +

      8. Termination

      +

      8.1. This License and the rights granted hereunder will terminate + automatically if You fail to comply with terms herein and fail to cure such breach + within 30 days of becoming aware of the breach. All sublicenses to the Covered Code which + are properly granted shall survive any termination of this License. Provisions which, by + their nature, must remain in effect beyond the termination of this License shall survive. +

      8.2. If You initiate litigation by asserting a patent infringement + claim (excluding declatory judgment actions) against Initial Developer or a Contributor + (the Initial Developer or Contributor against whom You file such action is referred to + as "Participant") alleging that: +

        +
      1. such Participant's Contributor Version directly or indirectly + infringes any patent, then any and all rights granted by such Participant to You under + Sections 2.1 and/or 2.2 of this + License shall, upon 60 days notice from Participant terminate prospectively, unless if + within 60 days after receipt of notice You either: (i) agree in writing to pay + Participant a mutually agreeable reasonable royalty for Your past and future use of + Modifications made by such Participant, or (ii) withdraw Your litigation claim with + respect to the Contributor Version against such Participant. If within 60 days of + notice, a reasonable royalty and payment arrangement are not mutually agreed upon in + writing by the parties or the litigation claim is not withdrawn, the rights granted by + Participant to You under Sections 2.1 and/or + 2.2 automatically terminate at the expiration of the 60 day + notice period specified above. +
      2. any software, hardware, or device, other than such Participant's + Contributor Version, directly or indirectly infringes any patent, then any rights + granted to You by such Participant under Sections 2.1(b) + and 2.2(b) are revoked effective as of the date You first + made, used, sold, distributed, or had made, Modifications made by that Participant. +
      +

      8.3. If You assert a patent infringement claim against Participant + alleging that such Participant's Contributor Version directly or indirectly infringes + any patent where such claim is resolved (such as by license or settlement) prior to the + initiation of patent infringement litigation, then the reasonable value of the licenses + granted by such Participant under Sections 2.1 or + 2.2 shall be taken into account in determining the amount or + value of any payment or license. +

      8.4. In the event of termination under Sections + 8.1 or 8.2 above, all end user + license agreements (excluding distributors and resellers) which have been validly + granted by You or any distributor hereunder prior to termination shall survive + termination. +

      9. Limitation of liability

      +

      Under no circumstances and under no legal theory, whether + tort (including negligence), contract, or otherwise, shall you, the initial developer, + any other contributor, or any distributor of covered code, or any supplier of any of + such parties, be liable to any person for any indirect, special, incidental, or + consequential damages of any character including, without limitation, damages for loss + of goodwill, work stoppage, computer failure or malfunction, or any and all other + commercial damages or losses, even if such party shall have been informed of the + possibility of such damages. This limitation of liability shall not apply to liability + for death or personal injury resulting from such party's negligence to the extent + applicable law prohibits such limitation. Some jurisdictions do not allow the exclusion + or limitation of incidental or consequential damages, so this exclusion and limitation + may not apply to you. +

      10. U.S. government end users

      +

      The Covered Code is a "commercial item," as that term is defined in 48 + C.F.R. 2.101 (Oct. 1995), consisting of + "commercial computer software" and "commercial computer software documentation," as such + terms are used in 48 C.F.R. 12.212 (Sept. + 1995). Consistent with 48 C.F.R. 12.212 and 48 C.F.R. + 227.7202-1 through 227.7202-4 (June 1995), all U.S. Government End Users + acquire Covered Code with only those rights set forth herein. +

      11. Miscellaneous

      +

      This License represents the complete agreement concerning subject matter hereof. If + any provision of this License is held to be unenforceable, such provision shall be + reformed only to the extent necessary to make it enforceable. This License shall be + governed by California law provisions (except to the extent applicable law, if any, + provides otherwise), excluding its conflict-of-law provisions. With respect to + disputes in which at least one party is a citizen of, or an entity chartered or + registered to do business in the United States of America, any litigation relating to + this License shall be subject to the jurisdiction of the Federal Courts of the + Northern District of California, with venue lying in Santa Clara County, California, + with the losing party responsible for costs, including without limitation, court + costs and reasonable attorneys' fees and expenses. The application of the United + Nations Convention on Contracts for the International Sale of Goods is expressly + excluded. Any law or regulation which provides that the language of a contract + shall be construed against the drafter shall not apply to this License. +

      12. Responsibility for claims

      +

      As between Initial Developer and the Contributors, each party is responsible for + claims and damages arising, directly or indirectly, out of its utilization of rights + under this License and You agree to work with Initial Developer and Contributors to + distribute such responsibility on an equitable basis. Nothing herein is intended or + shall be deemed to constitute any admission of liability. +

      13. Multiple-licensed code

      +

      Initial Developer may designate portions of the Covered Code as + "Multiple-Licensed". "Multiple-Licensed" means that the Initial Developer permits + you to utilize portions of the Covered Code under Your choice of the MPL + or the alternative licenses, if any, specified by the Initial Developer in the file + described in Exhibit A. +

      Exhibit A - Mozilla Public License.

      +
      "The contents of this file are subject to the Mozilla Public License
      +Version 1.1 (the "License"); you may not use this file except in
      +compliance with the License. You may obtain a copy of the License at
      +http://www.mozilla.org/MPL/
      +
      +Software distributed under the License is distributed on an "AS IS"
      +basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See the
      +License for the specific language governing rights and limitations
      +under the License.
      +
      +The Original Code is JTransforms.
      +
      +The Initial Developer of the Original Code is
      +Piotr Wendykier, Emory University.
      +Portions created by the Initial Developer are Copyright (C) 2007-2009
      +the Initial Developer. All Rights Reserved.
      +
      +Alternatively, the contents of this file may be used under the terms of
      +either the GNU General Public License Version 2 or later (the "GPL"), or
      +the GNU Lesser General Public License Version 2.1 or later (the "LGPL"),
      +in which case the provisions of the GPL or the LGPL are applicable instead
      +of those above. If you wish to allow use of your version of this file only
      +under the terms of either the GPL or the LGPL, and not to allow others to
      +use your version of this file under the terms of the MPL, indicate your
      +decision by deleting the provisions above and replace them with the notice
      +and other provisions required by the GPL or the LGPL. If you do not delete
      +the provisions above, a recipient may use your version of this file under
      +the terms of any one of the MPL, the GPL or the LGPL.
      +

      NOTE: The text of this Exhibit A may differ slightly from the text of + the notices in the Source Code files of the Original Code. You should + use the text of this Exhibit A rather than the text found in the + Original Code Source Code for Your Modifications. + +

      \ No newline at end of file diff --git a/licenses/LICENSE-kryo.txt b/licenses-binary/LICENSE-kryo.txt similarity index 100% rename from licenses/LICENSE-kryo.txt rename to licenses-binary/LICENSE-kryo.txt diff --git a/licenses-binary/LICENSE-leveldbjni.txt b/licenses-binary/LICENSE-leveldbjni.txt new file mode 100644 index 0000000000000..b4dabb9174c6d --- /dev/null +++ b/licenses-binary/LICENSE-leveldbjni.txt @@ -0,0 +1,27 @@ +Copyright (c) 2011 FuseSource Corp. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of FuseSource Corp. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-machinist.txt b/licenses-binary/LICENSE-machinist.txt new file mode 100644 index 0000000000000..68cc3a3e3a9c4 --- /dev/null +++ b/licenses-binary/LICENSE-machinist.txt @@ -0,0 +1,19 @@ +Copyright (c) 2011-2014 Erik Osheim, Tom Switzer + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-matchMedia-polyfill.txt b/licenses-binary/LICENSE-matchMedia-polyfill.txt new file mode 100644 index 0000000000000..2fd0bc2b37448 --- /dev/null +++ b/licenses-binary/LICENSE-matchMedia-polyfill.txt @@ -0,0 +1 @@ +matchMedia() polyfill - Test a CSS media type/query in JS. Authors & copyright (c) 2012: Scott Jehl, Paul Irish, Nicholas Zakas. Dual MIT/BSD license \ No newline at end of file diff --git a/licenses/LICENSE-minlog.txt b/licenses-binary/LICENSE-minlog.txt similarity index 100% rename from licenses/LICENSE-minlog.txt rename to licenses-binary/LICENSE-minlog.txt diff --git a/licenses-binary/LICENSE-modernizr.txt b/licenses-binary/LICENSE-modernizr.txt new file mode 100644 index 0000000000000..2bf24b9b9f848 --- /dev/null +++ b/licenses-binary/LICENSE-modernizr.txt @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-netlib.txt b/licenses-binary/LICENSE-netlib.txt similarity index 100% rename from licenses/LICENSE-netlib.txt rename to licenses-binary/LICENSE-netlib.txt diff --git a/licenses/LICENSE-paranamer.txt b/licenses-binary/LICENSE-paranamer.txt similarity index 100% rename from licenses/LICENSE-paranamer.txt rename to licenses-binary/LICENSE-paranamer.txt diff --git a/licenses/LICENSE-jpmml-model.txt b/licenses-binary/LICENSE-pmml-model.txt similarity index 100% rename from licenses/LICENSE-jpmml-model.txt rename to licenses-binary/LICENSE-pmml-model.txt diff --git a/licenses/LICENSE-protobuf.txt b/licenses-binary/LICENSE-protobuf.txt similarity index 100% rename from licenses/LICENSE-protobuf.txt rename to licenses-binary/LICENSE-protobuf.txt diff --git a/licenses-binary/LICENSE-py4j.txt b/licenses-binary/LICENSE-py4j.txt new file mode 100644 index 0000000000000..70af3e69ed67a --- /dev/null +++ b/licenses-binary/LICENSE-py4j.txt @@ -0,0 +1,27 @@ +Copyright (c) 2009-2011, Barthelemy Dagenais All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +- Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +- Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +- The name of the author may not be used to endorse or promote products +derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + diff --git a/licenses/LICENSE-pyrolite.txt b/licenses-binary/LICENSE-pyrolite.txt similarity index 100% rename from licenses/LICENSE-pyrolite.txt rename to licenses-binary/LICENSE-pyrolite.txt diff --git a/licenses/LICENSE-reflectasm.txt b/licenses-binary/LICENSE-reflectasm.txt similarity index 100% rename from licenses/LICENSE-reflectasm.txt rename to licenses-binary/LICENSE-reflectasm.txt diff --git a/licenses-binary/LICENSE-respond.txt b/licenses-binary/LICENSE-respond.txt new file mode 100644 index 0000000000000..dea4ff9e5b2ea --- /dev/null +++ b/licenses-binary/LICENSE-respond.txt @@ -0,0 +1,22 @@ +Copyright (c) 2012 Scott Jehl + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, +copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-sbt-launch-lib.txt b/licenses-binary/LICENSE-sbt-launch-lib.txt new file mode 100644 index 0000000000000..3b9156baaab78 --- /dev/null +++ b/licenses-binary/LICENSE-sbt-launch-lib.txt @@ -0,0 +1,26 @@ +// Generated from http://www.opensource.org/licenses/bsd-license.php +Copyright (c) 2011, Paul Phillips. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + * Neither the name of the author nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-scala.txt b/licenses-binary/LICENSE-scala.txt similarity index 100% rename from licenses/LICENSE-scala.txt rename to licenses-binary/LICENSE-scala.txt diff --git a/licenses-binary/LICENSE-scopt.txt b/licenses-binary/LICENSE-scopt.txt new file mode 100644 index 0000000000000..e92e9b592fba0 --- /dev/null +++ b/licenses-binary/LICENSE-scopt.txt @@ -0,0 +1,9 @@ +This project is licensed under the MIT license. + +Copyright (c) scopt contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-slf4j.txt b/licenses-binary/LICENSE-slf4j.txt similarity index 100% rename from licenses/LICENSE-slf4j.txt rename to licenses-binary/LICENSE-slf4j.txt diff --git a/licenses-binary/LICENSE-sorttable.js.txt b/licenses-binary/LICENSE-sorttable.js.txt new file mode 100644 index 0000000000000..b31a5b206bf40 --- /dev/null +++ b/licenses-binary/LICENSE-sorttable.js.txt @@ -0,0 +1,16 @@ +Copyright (c) 1997-2007 Stuart Langridge + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/licenses/LICENSE-spire.txt b/licenses-binary/LICENSE-spire.txt similarity index 100% rename from licenses/LICENSE-spire.txt rename to licenses-binary/LICENSE-spire.txt diff --git a/licenses-binary/LICENSE-vis.txt b/licenses-binary/LICENSE-vis.txt new file mode 100644 index 0000000000000..18b7323059a41 --- /dev/null +++ b/licenses-binary/LICENSE-vis.txt @@ -0,0 +1,22 @@ +vis.js +https://github.com/almende/vis + +A dynamic, browser-based visualization library. + +@version 4.16.1 +@date 2016-04-18 + +@license +Copyright (C) 2011-2016 Almende B.V, http://almende.com + +Vis.js is dual licensed under both + +* The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0 + +and + +* The MIT License + http://opensource.org/licenses/MIT + +Vis.js may be distributed under either license. \ No newline at end of file diff --git a/licenses/LICENSE-xmlenc.txt b/licenses-binary/LICENSE-xmlenc.txt similarity index 100% rename from licenses/LICENSE-xmlenc.txt rename to licenses-binary/LICENSE-xmlenc.txt diff --git a/licenses/LICENSE-zstd-jni.txt b/licenses-binary/LICENSE-zstd-jni.txt similarity index 100% rename from licenses/LICENSE-zstd-jni.txt rename to licenses-binary/LICENSE-zstd-jni.txt diff --git a/licenses/LICENSE-zstd.txt b/licenses-binary/LICENSE-zstd.txt similarity index 100% rename from licenses/LICENSE-zstd.txt rename to licenses-binary/LICENSE-zstd.txt diff --git a/licenses/LICENSE-CC0.txt b/licenses/LICENSE-CC0.txt new file mode 100644 index 0000000000000..1625c17936079 --- /dev/null +++ b/licenses/LICENSE-CC0.txt @@ -0,0 +1,121 @@ +Creative Commons Legal Code + +CC0 1.0 Universal + + CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE + LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN + ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS + INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES + REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS + PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM + THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED + HEREUNDER. + +Statement of Purpose + +The laws of most jurisdictions throughout the world automatically confer +exclusive Copyright and Related Rights (defined below) upon the creator +and subsequent owner(s) (each and all, an "owner") of an original work of +authorship and/or a database (each, a "Work"). + +Certain owners wish to permanently relinquish those rights to a Work for +the purpose of contributing to a commons of creative, cultural and +scientific works ("Commons") that the public can reliably and without fear +of later claims of infringement build upon, modify, incorporate in other +works, reuse and redistribute as freely as possible in any form whatsoever +and for any purposes, including without limitation commercial purposes. +These owners may contribute to the Commons to promote the ideal of a free +culture and the further production of creative, cultural and scientific +works, or to gain reputation or greater distribution for their Work in +part through the use and efforts of others. + +For these and/or other purposes and motivations, and without any +expectation of additional consideration or compensation, the person +associating CC0 with a Work (the "Affirmer"), to the extent that he or she +is an owner of Copyright and Related Rights in the Work, voluntarily +elects to apply CC0 to the Work and publicly distribute the Work under its +terms, with knowledge of his or her Copyright and Related Rights in the +Work and the meaning and intended legal effect of CC0 on those rights. + +1. Copyright and Related Rights. A Work made available under CC0 may be +protected by copyright and related or neighboring rights ("Copyright and +Related Rights"). Copyright and Related Rights include, but are not +limited to, the following: + + i. the right to reproduce, adapt, distribute, perform, display, + communicate, and translate a Work; + ii. moral rights retained by the original author(s) and/or performer(s); +iii. publicity and privacy rights pertaining to a person's image or + likeness depicted in a Work; + iv. rights protecting against unfair competition in regards to a Work, + subject to the limitations in paragraph 4(a), below; + v. rights protecting the extraction, dissemination, use and reuse of data + in a Work; + vi. database rights (such as those arising under Directive 96/9/EC of the + European Parliament and of the Council of 11 March 1996 on the legal + protection of databases, and under any national implementation + thereof, including any amended or successor version of such + directive); and +vii. other similar, equivalent or corresponding rights throughout the + world based on applicable law or treaty, and any national + implementations thereof. + +2. Waiver. To the greatest extent permitted by, but not in contravention +of, applicable law, Affirmer hereby overtly, fully, permanently, +irrevocably and unconditionally waives, abandons, and surrenders all of +Affirmer's Copyright and Related Rights and associated claims and causes +of action, whether now known or unknown (including existing as well as +future claims and causes of action), in the Work (i) in all territories +worldwide, (ii) for the maximum duration provided by applicable law or +treaty (including future time extensions), (iii) in any current or future +medium and for any number of copies, and (iv) for any purpose whatsoever, +including without limitation commercial, advertising or promotional +purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each +member of the public at large and to the detriment of Affirmer's heirs and +successors, fully intending that such Waiver shall not be subject to +revocation, rescission, cancellation, termination, or any other legal or +equitable action to disrupt the quiet enjoyment of the Work by the public +as contemplated by Affirmer's express Statement of Purpose. + +3. Public License Fallback. Should any part of the Waiver for any reason +be judged legally invalid or ineffective under applicable law, then the +Waiver shall be preserved to the maximum extent permitted taking into +account Affirmer's express Statement of Purpose. In addition, to the +extent the Waiver is so judged Affirmer hereby grants to each affected +person a royalty-free, non transferable, non sublicensable, non exclusive, +irrevocable and unconditional license to exercise Affirmer's Copyright and +Related Rights in the Work (i) in all territories worldwide, (ii) for the +maximum duration provided by applicable law or treaty (including future +time extensions), (iii) in any current or future medium and for any number +of copies, and (iv) for any purpose whatsoever, including without +limitation commercial, advertising or promotional purposes (the +"License"). The License shall be deemed effective as of the date CC0 was +applied by Affirmer to the Work. Should any part of the License for any +reason be judged legally invalid or ineffective under applicable law, such +partial invalidity or ineffectiveness shall not invalidate the remainder +of the License, and in such case Affirmer hereby affirms that he or she +will not (i) exercise any of his or her remaining Copyright and Related +Rights in the Work or (ii) assert any associated claims and causes of +action with respect to the Work, in either case contrary to Affirmer's +express Statement of Purpose. + +4. Limitations and Disclaimers. + + a. No trademark or patent rights held by Affirmer are waived, abandoned, + surrendered, licensed or otherwise affected by this document. + b. Affirmer offers the Work as-is and makes no representations or + warranties of any kind concerning the Work, express, implied, + statutory or otherwise, including without limitation warranties of + title, merchantability, fitness for a particular purpose, non + infringement, or the absence of latent or other defects, accuracy, or + the present or absence of errors, whether or not discoverable, all to + the greatest extent permissible under applicable law. + c. Affirmer disclaims responsibility for clearing rights of other persons + that may apply to the Work or any use thereof, including without + limitation any person's Copyright and Related Rights in the Work. + Further, Affirmer disclaims responsibility for obtaining any necessary + consents, permissions or other rights required for any use of the + Work. + d. Affirmer understands and acknowledges that Creative Commons is not a + party to this document and has no duty or obligation with respect to + this CC0 or use of the Work. \ No newline at end of file diff --git a/licenses/LICENSE-SnapTree.txt b/licenses/LICENSE-SnapTree.txt deleted file mode 100644 index a538825d89ec5..0000000000000 --- a/licenses/LICENSE-SnapTree.txt +++ /dev/null @@ -1,35 +0,0 @@ -SNAPTREE LICENSE - -Copyright (c) 2009-2012 Stanford University, unless otherwise specified. -All rights reserved. - -This software was developed by the Pervasive Parallelism Laboratory of -Stanford University, California, USA. - -Permission to use, copy, modify, and distribute this software in source -or binary form for any purpose with or without fee is hereby granted, -provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - 3. Neither the name of Stanford University nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - -THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF -SUCH DAMAGE. diff --git a/licenses/LICENSE-bootstrap.txt b/licenses/LICENSE-bootstrap.txt new file mode 100644 index 0000000000000..6c711832fbc85 --- /dev/null +++ b/licenses/LICENSE-bootstrap.txt @@ -0,0 +1,13 @@ +Copyright 2013 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/licenses/LICENSE-boto.txt b/licenses/LICENSE-boto.txt deleted file mode 100644 index 7bba0cd9e10a4..0000000000000 --- a/licenses/LICENSE-boto.txt +++ /dev/null @@ -1,20 +0,0 @@ -Copyright (c) 2006-2008 Mitch Garnaat http://garnaat.org/ - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the -"Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, dis- -tribute, sublicense, and/or sell copies of the Software, and to permit -persons to whom the Software is furnished to do so, subject to the fol- -lowing conditions: - -The above copyright notice and this permission notice shall be included -in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL- -ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS -IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-datatables.txt b/licenses/LICENSE-datatables.txt new file mode 100644 index 0000000000000..bb7708b5b5a49 --- /dev/null +++ b/licenses/LICENSE-datatables.txt @@ -0,0 +1,7 @@ +Copyright (C) 2008-2018, SpryMedia Ltd. + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-graphlib-dot.txt b/licenses/LICENSE-graphlib-dot.txt index c9e18cd562423..4864fe05e9803 100644 --- a/licenses/LICENSE-graphlib-dot.txt +++ b/licenses/LICENSE-graphlib-dot.txt @@ -1,4 +1,4 @@ -Copyright (c) 2012-2013 Chris Pettitt +Copyright (c) 2013 Chris Pettitt Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/licenses/LICENSE-jbcrypt.txt b/licenses/LICENSE-jbcrypt.txt deleted file mode 100644 index d332534c06356..0000000000000 --- a/licenses/LICENSE-jbcrypt.txt +++ /dev/null @@ -1,17 +0,0 @@ -jBCrypt is subject to the following license: - -/* - * Copyright (c) 2006 Damien Miller - * - * Permission to use, copy, modify, and distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ diff --git a/licenses/LICENSE-jmock.txt b/licenses/LICENSE-join.txt similarity index 60% rename from licenses/LICENSE-jmock.txt rename to licenses/LICENSE-join.txt index ed7964fe3d9ef..1d916090e4ea0 100644 --- a/licenses/LICENSE-jmock.txt +++ b/licenses/LICENSE-join.txt @@ -1,19 +1,21 @@ -Copyright (c) 2000-2017, jMock.org +Copyright (c) 2011, Douban Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: -Redistributions of source code must retain the above copyright notice, -this list of conditions and the following disclaimer. Redistributions -in binary form must reproduce the above copyright notice, this list of -conditions and the following disclaimer in the documentation and/or -other materials provided with the distribution. + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. -Neither the name of jMock nor the names of its contributors may be -used to endorse or promote products derived from this software without -specific prior written permission. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + + * Neither the name of the Douban Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -25,4 +27,4 @@ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-jquery.txt b/licenses/LICENSE-jquery.txt index e1dd696d3b6cc..45930542204fb 100644 --- a/licenses/LICENSE-jquery.txt +++ b/licenses/LICENSE-jquery.txt @@ -1,9 +1,20 @@ -The MIT License (MIT) +Copyright JS Foundation and other contributors, https://js.foundation/ -Copyright (c) +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-json-formatter.txt b/licenses/LICENSE-json-formatter.txt new file mode 100644 index 0000000000000..5193348fce126 --- /dev/null +++ b/licenses/LICENSE-json-formatter.txt @@ -0,0 +1,6 @@ +Copyright 2014 Mohsen Azimi + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. \ No newline at end of file diff --git a/licenses/LICENSE-matchMedia-polyfill.txt b/licenses/LICENSE-matchMedia-polyfill.txt new file mode 100644 index 0000000000000..2fd0bc2b37448 --- /dev/null +++ b/licenses/LICENSE-matchMedia-polyfill.txt @@ -0,0 +1 @@ +matchMedia() polyfill - Test a CSS media type/query in JS. Authors & copyright (c) 2012: Scott Jehl, Paul Irish, Nicholas Zakas. Dual MIT/BSD license \ No newline at end of file diff --git a/licenses/LICENSE-postgresql.txt b/licenses/LICENSE-postgresql.txt deleted file mode 100644 index 515bf9af4d432..0000000000000 --- a/licenses/LICENSE-postgresql.txt +++ /dev/null @@ -1,24 +0,0 @@ -PostgreSQL Database Management System -(formerly known as Postgres, then as Postgres95) - -Portions Copyright (c) 1996-2010, PostgreSQL Global Development Group - -Portions Copyright (c) 1994, The Regents of the University of California - -Permission to use, copy, modify, and distribute this software and its -documentation for any purpose, without fee, and without a written agreement -is hereby granted, provided that the above copyright notice and this -paragraph and the following two paragraphs appear in all copies. - -IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY FOR -DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING -LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS -DOCUMENTATION, EVEN IF THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. - -THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, -INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY -AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS -ON AN "AS IS" BASIS, AND THE UNIVERSITY OF CALIFORNIA HAS NO OBLIGATIONS TO -PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. - diff --git a/licenses/LICENSE-respond.txt b/licenses/LICENSE-respond.txt new file mode 100644 index 0000000000000..dea4ff9e5b2ea --- /dev/null +++ b/licenses/LICENSE-respond.txt @@ -0,0 +1,22 @@ +Copyright (c) 2012 Scott Jehl + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, +copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-scalacheck.txt b/licenses/LICENSE-scalacheck.txt deleted file mode 100644 index cb8f97842f4c4..0000000000000 --- a/licenses/LICENSE-scalacheck.txt +++ /dev/null @@ -1,32 +0,0 @@ -ScalaCheck LICENSE - -Copyright (c) 2007-2015, Rickard Nilsson -All rights reserved. - -Permission to use, copy, modify, and distribute this software in source -or binary form for any purpose with or without fee is hereby granted, -provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - 3. Neither the name of the author nor the names of its contributors - may be used to endorse or promote products derived from this - software without specific prior written permission. - - -THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF -SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-vis.txt b/licenses/LICENSE-vis.txt new file mode 100644 index 0000000000000..18b7323059a41 --- /dev/null +++ b/licenses/LICENSE-vis.txt @@ -0,0 +1,22 @@ +vis.js +https://github.com/almende/vis + +A dynamic, browser-based visualization library. + +@version 4.16.1 +@date 2016-04-18 + +@license +Copyright (C) 2011-2016 Almende B.V, http://almende.com + +Vis.js is dual licensed under both + +* The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0 + +and + +* The MIT License + http://opensource.org/licenses/MIT + +Vis.js may be distributed under either license. \ No newline at end of file From 8f91c697e251423b826cd6ac4ddd9e2dac15b96e Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Mon, 2 Jul 2018 14:35:37 +0800 Subject: [PATCH 1046/2461] [SPARK-24665][PYSPARK] Use SQLConf in PySpark to manage all sql configs ## What changes were proposed in this pull request? Use SQLConf for PySpark to manage all sql configs, drop all the hard code in config usage. ## How was this patch tested? Existing UT. Author: Yuanjian Li Closes #21648 from xuanyuanking/SPARK-24665. --- python/pyspark/sql/context.py | 5 +++ python/pyspark/sql/dataframe.py | 42 +++++-------------- .../apache/spark/sql/internal/SQLConf.scala | 6 +++ 3 files changed, 21 insertions(+), 32 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index e9ec7ba866761..9c094dd9a9033 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -93,6 +93,11 @@ def _ssql_ctx(self): """ return self._jsqlContext + @property + def _conf(self): + """Accessor for the JVM SQL-specific configurations""" + return self.sparkSession._jsparkSession.sessionState().conf() + @classmethod @since(1.6) def getOrCreate(cls, sc): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index cb3fe448b6fc7..c40aea9bcef0a 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -354,32 +354,12 @@ def show(self, n=20, truncate=True, vertical=False): else: print(self._jdf.showString(n, int(truncate), vertical)) - @property - def _eager_eval(self): - """Returns true if the eager evaluation enabled. - """ - return self.sql_ctx.getConf( - "spark.sql.repl.eagerEval.enabled", "false").lower() == "true" - - @property - def _max_num_rows(self): - """Returns the max row number for eager evaluation. - """ - return int(self.sql_ctx.getConf( - "spark.sql.repl.eagerEval.maxNumRows", "20")) - - @property - def _truncate(self): - """Returns the truncate length for eager evaluation. - """ - return int(self.sql_ctx.getConf( - "spark.sql.repl.eagerEval.truncate", "20")) - def __repr__(self): - if not self._support_repr_html and self._eager_eval: + if not self._support_repr_html and self.sql_ctx._conf.isReplEagerEvalEnabled(): vertical = False return self._jdf.showString( - self._max_num_rows, self._truncate, vertical) + self.sql_ctx._conf.replEagerEvalMaxNumRows(), + self.sql_ctx._conf.replEagerEvalTruncate(), vertical) else: return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) @@ -391,10 +371,10 @@ def _repr_html_(self): import cgi if not self._support_repr_html: self._support_repr_html = True - if self._eager_eval: - max_num_rows = max(self._max_num_rows, 0) + if self.sql_ctx._conf.isReplEagerEvalEnabled(): + max_num_rows = max(self.sql_ctx._conf.replEagerEvalMaxNumRows(), 0) sock_info = self._jdf.getRowsToPython( - max_num_rows, self._truncate) + max_num_rows, self.sql_ctx._conf.replEagerEvalTruncate()) rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) head = rows[0] row_data = rows[1:] @@ -2049,13 +2029,12 @@ def toPandas(self): import pandas as pd - if self.sql_ctx.getConf("spark.sql.execution.pandas.respectSessionTimeZone").lower() \ - == "true": - timezone = self.sql_ctx.getConf("spark.sql.session.timeZone") + if self.sql_ctx._conf.pandasRespectSessionTimeZone(): + timezone = self.sql_ctx._conf.sessionLocalTimeZone() else: timezone = None - if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": + if self.sql_ctx._conf.arrowEnabled(): use_arrow = True try: from pyspark.sql.types import to_arrow_schema @@ -2065,8 +2044,7 @@ def toPandas(self): to_arrow_schema(self.schema) except Exception as e: - if self.sql_ctx.getConf("spark.sql.execution.arrow.fallback.enabled", "true") \ - .lower() == "true": + if self.sql_ctx._conf.arrowFallbackEnabled(): msg = ( "toPandas attempted Arrow optimization because " "'spark.sql.execution.arrow.enabled' is set to true; however, " diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index da1c34cdc78f2..e2c48e2d8a14c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1728,6 +1728,12 @@ class SQLConf extends Serializable with Logging { def legacySizeOfNull: Boolean = getConf(SQLConf.LEGACY_SIZE_OF_NULL) + def isReplEagerEvalEnabled: Boolean = getConf(SQLConf.REPL_EAGER_EVAL_ENABLED) + + def replEagerEvalMaxNumRows: Int = getConf(SQLConf.REPL_EAGER_EVAL_MAX_NUM_ROWS) + + def replEagerEvalTruncate: Int = getConf(SQLConf.REPL_EAGER_EVAL_TRUNCATE) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ From 8008f9cb82e7c228b94eade2e7cb484d6d17e6a4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 2 Jul 2018 22:09:47 +0800 Subject: [PATCH 1047/2461] [SPARK-24715][BUILD] Override jline version as 2.14.3 in SBT ## What changes were proposed in this pull request? During SPARK-24418 (Upgrade Scala to 2.11.12 and 2.12.6), we upgrade `jline` version together. So, `mvn` works correctly. However, `sbt` brings old jline library and is hitting `NoSuchMethodError` in `master` branch, see https://github.com/apache/spark/pull/21495#issuecomment-401560826. This overrides jline version in SBT to make sbt build work. ## How was this patch tested? Manually test. Author: Liang-Chi Hsieh Closes #21692 from viirya/SPARK-24715. --- project/SparkBuild.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b606f9355e03b..f887e4570c85d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -464,7 +464,8 @@ object DockerIntegrationTests { */ object DependencyOverrides { lazy val settings = Seq( - dependencyOverrides += "com.google.guava" % "guava" % "14.0.1") + dependencyOverrides += "com.google.guava" % "guava" % "14.0.1", + dependencyOverrides += "jline" % "jline" % "2.14.3") } /** From f599cde69506a5aedeeec449cba9a8b5ab128282 Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Mon, 2 Jul 2018 22:39:00 +0800 Subject: [PATCH 1048/2461] [SPARK-24507][DOCUMENTATION] Update streaming guide ## What changes were proposed in this pull request? Updated streaming guide for direct stream and link to integration guide. ## How was this patch tested? jekyll build Author: Rekha Joshi Closes #21683 from rekhajoshm/SPARK-24507. --- docs/streaming-programming-guide.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index c30959263cdfa..118b05355c74d 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -2176,6 +2176,8 @@ the input data stream (using `inputStream.repartition()`). This distributes the received batches of data across the specified number of machines in the cluster before further processing. +For direct stream, please refer to [Spark Streaming + Kafka Integration Guide](streaming-kafka-integration.html) + ### Level of Parallelism in Data Processing {:.no_toc} Cluster resources can be under-utilized if the number of parallel tasks used in any stage of the From 42815548c7ef498439b9ba47134a6f3e1b519c83 Mon Sep 17 00:00:00 2001 From: mcheah Date: Mon, 2 Jul 2018 10:24:04 -0700 Subject: [PATCH 1049/2461] [SPARK-24683][K8S] Fix k8s no resource ## What changes were proposed in this pull request? Make SparkSubmit pass in the main class even if `SparkLauncher.NO_RESOURCE` is the primary resource. ## How was this patch tested? New integration test written to capture this case. Author: mcheah Closes #21660 from mccheah/fix-k8s-no-resource. --- .../scala/org/apache/spark/deploy/SparkSubmit.scala | 2 ++ .../deploy/k8s/submit/KubernetesDriverBuilder.scala | 3 ++- .../deploy/k8s/integrationtest/KubernetesSuite.scala | 10 ++++++++-- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index e83d82f847c61..2da778a29779d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -702,6 +702,8 @@ private[spark] class SparkSubmit extends Logging { childArgs ++= Array("--primary-java-resource", args.primaryResource) childArgs ++= Array("--main-class", args.mainClass) } + } else { + childArgs ++= Array("--main-class", args.mainClass) } if (args.childArgs != null) { args.childArgs.foreach { arg => diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index 5762d8245f778..0dd1c37661707 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -64,7 +64,8 @@ private[spark] class KubernetesDriverBuilder( case JavaMainAppResource(_) => provideJavaStep(kubernetesConf) case PythonMainAppResource(_) => - providePythonStep(kubernetesConf)}.getOrElse(provideJavaStep(kubernetesConf)) + providePythonStep(kubernetesConf)} + .getOrElse(provideJavaStep(kubernetesConf)) val allFeatures: Seq[KubernetesFeatureConfigStep] = (baseFeatures :+ bindingsStep) ++ diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 65c513cf241a4..6e334c83fbde8 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -21,17 +21,17 @@ import java.nio.file.{Path, Paths} import java.util.UUID import java.util.regex.Pattern -import scala.collection.JavaConverters._ - import com.google.common.io.PatternFilenameFilter import io.fabric8.kubernetes.api.model.{Container, Pod} import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.{Eventually, PatienceConfiguration} import org.scalatest.time.{Minutes, Seconds, Span} +import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite import org.apache.spark.deploy.k8s.integrationtest.backend.{IntegrationTestBackend, IntegrationTestBackendFactory} import org.apache.spark.deploy.k8s.integrationtest.config._ +import org.apache.spark.launcher.SparkLauncher private[spark] class KubernetesSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { @@ -109,6 +109,12 @@ private[spark] class KubernetesSuite extends SparkFunSuite runSparkPiAndVerifyCompletion() } + test("Use SparkLauncher.NO_RESOURCE") { + sparkAppConf.setJars(Seq(containerLocalSparkDistroExamplesJar)) + runSparkPiAndVerifyCompletion( + appResource = SparkLauncher.NO_RESOURCE) + } + test("Run SparkPi with a master URL without a scheme.") { val url = kubernetesTestComponents.kubernetesClient.getMasterUrl val k8sMasterUrl = if (url.getPort < 0) { From 85fe1297e35bcff9cf86bd53fee615e140ee5bfb Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Mon, 2 Jul 2018 13:08:16 -0700 Subject: [PATCH 1050/2461] [SPARK-24428][K8S] Fix unused code ## What changes were proposed in this pull request? Remove code that is misleading and is a leftover from a previous implementation. ## How was this patch tested? Manually. Author: Stavros Kontopoulos Closes #21462 from skonto/fix-k8s-docs. --- .../org/apache/spark/deploy/k8s/Constants.scala | 6 ------ .../cluster/k8s/KubernetesClusterManager.scala | 2 -- .../docker/src/main/dockerfiles/spark/entrypoint.sh | 12 +++++------- 3 files changed, 5 insertions(+), 15 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 69bd03d1eda6f..5ecdd3a04d77b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -25,9 +25,6 @@ private[spark] object Constants { val SPARK_POD_DRIVER_ROLE = "driver" val SPARK_POD_EXECUTOR_ROLE = "executor" - // Annotations - val SPARK_APP_NAME_ANNOTATION = "spark-app-name" - // Credentials secrets val DRIVER_CREDENTIALS_SECRETS_BASE_DIR = "/mnt/secrets/spark-kubernetes-credentials" @@ -50,17 +47,14 @@ private[spark] object Constants { val DEFAULT_BLOCKMANAGER_PORT = 7079 val DRIVER_PORT_NAME = "driver-rpc-port" val BLOCK_MANAGER_PORT_NAME = "blockmanager" - val EXECUTOR_PORT_NAME = "executor" // Environment Variables - val ENV_EXECUTOR_PORT = "SPARK_EXECUTOR_PORT" val ENV_DRIVER_URL = "SPARK_DRIVER_URL" val ENV_EXECUTOR_CORES = "SPARK_EXECUTOR_CORES" val ENV_EXECUTOR_MEMORY = "SPARK_EXECUTOR_MEMORY" val ENV_APPLICATION_ID = "SPARK_APPLICATION_ID" val ENV_EXECUTOR_ID = "SPARK_EXECUTOR_ID" val ENV_EXECUTOR_POD_IP = "SPARK_EXECUTOR_POD_IP" - val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH" val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_" val ENV_CLASSPATH = "SPARK_CLASSPATH" val ENV_DRIVER_BIND_ADDRESS = "SPARK_DRIVER_BIND_ADDRESS" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index c6e931a38405f..de2a52bc7a0b8 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -48,8 +48,6 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit sc: SparkContext, masterURL: String, scheduler: TaskScheduler): SchedulerBackend = { - val executorSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sc.conf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( KUBERNETES_MASTER_INTERNAL_URL, Some(sc.conf.get(KUBERNETES_NAMESPACE)), diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 2f4e115e84ecd..8bdb0f7a10795 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -51,12 +51,10 @@ esac SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" env | grep SPARK_JAVA_OPT_ | sort -t_ -k4 -n | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt -readarray -t SPARK_JAVA_OPTS < /tmp/java_opts.txt -if [ -n "$SPARK_MOUNTED_CLASSPATH" ]; then - SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_MOUNTED_CLASSPATH" -fi -if [ -n "$SPARK_MOUNTED_FILES_DIR" ]; then - cp -R "$SPARK_MOUNTED_FILES_DIR/." . +readarray -t SPARK_EXECUTOR_JAVA_OPTS < /tmp/java_opts.txt + +if [ -n "$SPARK_EXTRA_CLASSPATH" ]; then + SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_EXTRA_CLASSPATH" fi if [ -n "$PYSPARK_FILES" ]; then @@ -101,7 +99,7 @@ case "$SPARK_K8S_CMD" in executor) CMD=( ${JAVA_HOME}/bin/java - "${SPARK_JAVA_OPTS[@]}" + "${SPARK_EXECUTOR_JAVA_OPTS[@]}" -Xms$SPARK_EXECUTOR_MEMORY -Xmx$SPARK_EXECUTOR_MEMORY -cp "$SPARK_CLASSPATH" From a7c8f0c8cb144a026ea21e8780107e363ceacb8d Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 3 Jul 2018 12:20:03 +0800 Subject: [PATCH 1051/2461] [SPARK-24385][SQL] Resolve self-join condition ambiguity for EqualNullSafe ## What changes were proposed in this pull request? In Dataset.join we have a small hack for resolving ambiguity in the column name for self-joins. The current code supports only `EqualTo`. The PR extends the fix to `EqualNullSafe`. Credit for this PR should be given to daniel-shields. ## How was this patch tested? added UT Author: Marco Gaido Closes #21605 from mgaido91/SPARK-24385_2. --- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 5 +++++ .../scala/org/apache/spark/sql/DataFrameJoinSuite.scala | 8 ++++++++ 2 files changed, 13 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 2ec236fc75efc..c97246f30220d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1016,6 +1016,11 @@ class Dataset[T] private[sql]( catalyst.expressions.EqualTo( withPlan(plan.left).resolve(a.name), withPlan(plan.right).resolve(b.name)) + case catalyst.expressions.EqualNullSafe(a: AttributeReference, b: AttributeReference) + if a.sameRef(b) => + catalyst.expressions.EqualNullSafe( + withPlan(plan.left).resolve(a.name), + withPlan(plan.right).resolve(b.name)) }} withPlan { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 0d9eeabb397a1..10d9a11d2ee79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -287,4 +287,12 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { dfOne.join(dfTwo, $"a" === $"b", "left").queryExecution.optimizedPlan } } + + test("SPARK-24385: Resolve ambiguity in self-joins with EqualNullSafe") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") { + val df = spark.range(2) + // this throws an exception before the fix + df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan + } + } } From 5585c5765f13519a447587ca778d52ce6a36a484 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 3 Jul 2018 10:13:48 -0700 Subject: [PATCH 1052/2461] [SPARK-24420][BUILD] Upgrade ASM to 6.1 to support JDK9+ ## What changes were proposed in this pull request? Upgrade ASM to 6.1 to support JDK9+ ## How was this patch tested? Existing tests. Author: DB Tsai Closes #21459 from dbtsai/asm. --- core/pom.xml | 2 +- .../main/scala/org/apache/spark/util/ClosureCleaner.scala | 4 ++-- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- dev/deps/spark-deps-hadoop-3.1 | 2 +- graphx/pom.xml | 2 +- .../org/apache/spark/graphx/util/BytecodeUtils.scala | 4 ++-- pom.xml | 8 ++++---- repl/pom.xml | 4 ++-- .../scala/org/apache/spark/repl/ExecutorClassLoader.scala | 4 ++-- sql/core/pom.xml | 2 +- 11 files changed, 18 insertions(+), 18 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 220522d3a8296..d0b869e6ef92c 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -56,7 +56,7 @@ org.apache.xbean - xbean-asm5-shaded + xbean-asm6-shaded org.apache.hadoop diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index ad0c0639521f6..073d71c63b0c7 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -22,8 +22,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable.{Map, Set, Stack} import scala.language.existentials -import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor, Type} -import org.apache.xbean.asm5.Opcodes._ +import org.apache.xbean.asm6.{ClassReader, ClassVisitor, MethodVisitor, Type} +import org.apache.xbean.asm6.Opcodes._ import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.internal.Logging diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 96e9c27210d05..f50a0aac0aefc 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -192,7 +192,7 @@ stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar -xbean-asm5-shaded-4.4.jar +xbean-asm6-shaded-4.8.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 4a6ee027ec355..774f9dc39ce4d 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -193,7 +193,7 @@ stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar -xbean-asm5-shaded-4.4.jar +xbean-asm6-shaded-4.8.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.0.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index e0b560c8ec71f..19c05ad1e991f 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -214,7 +214,7 @@ token-provider-1.0.1.jar univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar woodstox-core-5.0.3.jar -xbean-asm5-shaded-4.4.jar +xbean-asm6-shaded-4.8.jar xz-1.0.jar zjsonpatch-0.3.0.jar zookeeper-3.4.9.jar diff --git a/graphx/pom.xml b/graphx/pom.xml index fbe77fcb958d5..0f5dc548600b2 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -53,7 +53,7 @@ org.apache.xbean - xbean-asm5-shaded + xbean-asm6-shaded com.google.guava diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index d76e84ed8c9ed..a559685b1633c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -22,8 +22,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable.HashSet import scala.language.existentials -import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor} -import org.apache.xbean.asm5.Opcodes._ +import org.apache.xbean.asm6.{ClassReader, ClassVisitor, MethodVisitor} +import org.apache.xbean.asm6.Opcodes._ import org.apache.spark.util.Utils diff --git a/pom.xml b/pom.xml index 90e64ff71d229..ca30f9f12b098 100644 --- a/pom.xml +++ b/pom.xml @@ -313,13 +313,13 @@ chill-java ${chill.version} - org.apache.xbean - xbean-asm5-shaded - 4.4 + xbean-asm6-shaded + 4.8 @@ -166,7 +166,7 @@ - + scala-2.12 diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 4dc399827ffed..42298b06a2c86 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -22,8 +22,8 @@ import java.net.{URI, URL, URLEncoder} import java.nio.channels.Channels import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.xbean.asm5._ -import org.apache.xbean.asm5.Opcodes._ +import org.apache.xbean.asm6._ +import org.apache.xbean.asm6.Opcodes._ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil diff --git a/sql/core/pom.xml b/sql/core/pom.xml index f270c70fbfcf0..18ae314309d7b 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -118,7 +118,7 @@ org.apache.xbean - xbean-asm5-shaded + xbean-asm6-shaded org.scalacheck From 776f299fc8146b400e97185b1577b0fc8f06e14b Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 4 Jul 2018 09:38:18 +0800 Subject: [PATCH 1053/2461] [SPARK-24709][SQL] schema_of_json() - schema inference from an example ## What changes were proposed in this pull request? In the PR, I propose to add new function - *schema_of_json()* which infers schema of JSON string literal. The result of the function is a string containing a schema in DDL format. One of the use cases is using of *schema_of_json()* in the combination with *from_json()*. Currently, _from_json()_ requires a schema as a mandatory argument. The *schema_of_json()* function will allow to point out an JSON string as an example which has the same schema as the first argument of _from_json()_. For instance: ```sql select from_json(json_column, schema_of_json('{"c1": [0], "c2": [{"c3":0}]}')) from json_table; ``` ## How was this patch tested? Added new test to `JsonFunctionsSuite`, `JsonExpressionsSuite` and SQL tests to `json-functions.sql` Author: Maxim Gekk Closes #21686 from MaxGekk/infer_schema_json. --- python/pyspark/sql/functions.py | 27 ++++++++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/jsonExpressions.scala | 52 ++++++++++++++++--- .../sql/catalyst}/json/JsonInferSchema.scala | 5 +- .../expressions/JsonExpressionsSuite.scala | 7 +++ .../datasources/json/JsonDataSource.scala | 2 +- .../org/apache/spark/sql/functions.scala | 42 +++++++++++++++ .../sql-tests/inputs/json-functions.sql | 4 ++ .../sql-tests/results/json-functions.sql.out | 20 ++++++- .../apache/spark/sql/JsonFunctionsSuite.scala | 17 +++++- .../datasources/json/JsonSuite.scala | 4 +- 11 files changed, 163 insertions(+), 18 deletions(-) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/datasources => catalyst/src/main/scala/org/apache/spark/sql/catalyst}/json/JsonInferSchema.scala (98%) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9652d3e79b875..4d371976364d3 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2189,11 +2189,16 @@ def from_json(col, schema, options={}): >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(from_json(df.value, schema).alias("json")).collect() [Row(json=[Row(a=1)])] + >>> schema = schema_of_json(lit('''{"a": 0}''')) + >>> df.select(from_json(df.value, schema).alias("json")).collect() + [Row(json=Row(a=1))] """ sc = SparkContext._active_spark_context if isinstance(schema, DataType): schema = schema.json() + elif isinstance(schema, Column): + schema = _to_java_column(schema) jc = sc._jvm.functions.from_json(_to_java_column(col), schema, options) return Column(jc) @@ -2235,6 +2240,28 @@ def to_json(col, options={}): return Column(jc) +@ignore_unicode_prefix +@since(2.4) +def schema_of_json(col): + """ + Parses a column containing a JSON string and infers its schema in DDL format. + + :param col: string column in json format + + >>> from pyspark.sql.types import * + >>> data = [(1, '{"a": 1}')] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(schema_of_json(df.value).alias("json")).collect() + [Row(json=u'struct')] + >>> df.select(schema_of_json(lit('{"a": 0}')).alias("json")).collect() + [Row(json=u'struct')] + """ + + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.schema_of_json(_to_java_column(col)) + return Column(jc) + + @since(1.5) def size(col): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index a574d8a84d4fb..80a0af672bf74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -505,6 +505,7 @@ object FunctionRegistry { // json expression[StructsToJson]("to_json"), expression[JsonToStructs]("from_json"), + expression[SchemaOfJson]("schema_of_json"), // cast expression[Cast]("cast"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index f6d74f5b74c8e..8cd86053a01c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, CharArrayWriter, InputStreamReader, StringWriter} +import java.io._ import scala.util.parsing.combinator.RegexParsers @@ -28,7 +28,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.json.JsonInferSchema.inferField +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -525,17 +526,19 @@ case class JsonToStructs( override def nullable: Boolean = true // Used in `FunctionRegistry` - def this(child: Expression, schema: Expression) = + def this(child: Expression, schema: Expression, options: Map[String, String]) = this( - schema = JsonExprUtils.validateSchemaLiteral(schema), - options = Map.empty[String, String], + schema = JsonExprUtils.evalSchemaExpr(schema), + options = options, child = child, timeZoneId = None, forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) + def this(child: Expression, schema: Expression) = this(child, schema, Map.empty[String, String]) + def this(child: Expression, schema: Expression, options: Expression) = this( - schema = JsonExprUtils.validateSchemaLiteral(schema), + schema = JsonExprUtils.evalSchemaExpr(schema), options = JsonExprUtils.convertToMapData(options), child = child, timeZoneId = None, @@ -744,11 +747,44 @@ case class StructsToJson( override def inputTypes: Seq[AbstractDataType] = TypeCollection(ArrayType, StructType) :: Nil } +/** + * A function infers schema of JSON string. + */ +@ExpressionDescription( + usage = "_FUNC_(json[, options]) - Returns schema in the DDL format of JSON string.", + examples = """ + Examples: + > SELECT _FUNC_('[{"col":0}]'); + array> + """, + since = "2.4.0") +case class SchemaOfJson(child: Expression) + extends UnaryExpression with String2StringExpression with CodegenFallback { + + private val jsonOptions = new JSONOptions(Map.empty, "UTC") + private val jsonFactory = new JsonFactory() + jsonOptions.setJacksonOptions(jsonFactory) + + override def convert(v: UTF8String): UTF8String = { + val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, v)) { parser => + parser.nextToken() + inferField(parser, jsonOptions) + } + + UTF8String.fromString(dt.catalogString) + } +} + object JsonExprUtils { - def validateSchemaLiteral(exp: Expression): DataType = exp match { + def evalSchemaExpr(exp: Expression): DataType = exp match { case Literal(s, StringType) => DataType.fromDDL(s.toString) - case e => throw new AnalysisException(s"Expected a string literal instead of $e") + case e @ SchemaOfJson(_: Literal) => + val ddlSchema = e.eval().asInstanceOf[UTF8String] + DataType.fromDDL(ddlSchema.toString) + case e => throw new AnalysisException( + "Schema should be specified in DDL format as a string literal" + + s" or output of the schema_of_json function instead of ${e.sql}") } def convertToMapData(exp: Expression): Map[String, String] = exp match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index 8e1b430f4eb33..491ca005877f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.json +package org.apache.spark.sql.catalyst.json import java.util.Comparator @@ -25,7 +25,6 @@ import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil -import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -103,7 +102,7 @@ private[sql] object JsonInferSchema { /** * Infer the type of a json document from the parser's token stream */ - private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { + def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { import com.fasterxml.jackson.core.JsonToken._ parser.getCurrentToken match { case null | VALUE_NULL => NullType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 00e97637eee7e..52203b9e337ba 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -706,4 +706,11 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with assert(schemaToCompare == schema) } } + + test("SPARK-24709: infer schema of json strings") { + checkEvaluation(SchemaOfJson(Literal.create("""{"col":0}""")), "struct") + checkEvaluation( + SchemaOfJson(Literal.create("""{"col0":["a"], "col1": {"col2": "b"}}""")), + "struct,col1:struct>") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 3b6df45e949e8..2fee2128ba1f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -33,7 +33,7 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat} import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index acca9572cb14c..614f65f0faaba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3381,6 +3381,48 @@ object functions { from_json(e, dataType, options) } + /** + * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string + * + * @group collection_funcs + * @since 2.4.0 + */ + def from_json(e: Column, schema: Column): Column = { + from_json(e, schema, Map.empty[String, String].asJava) + } + + /** + * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string + * @param options options to control how the json is parsed. accepts the same options and the + * json data source. + * + * @group collection_funcs + * @since 2.4.0 + */ + def from_json(e: Column, schema: Column, options: java.util.Map[String, String]): Column = { + withExpr(new JsonToStructs(e.expr, schema.expr, options.asScala.toMap)) + } + + /** + * Parses a column containing a JSON string and infers its schema. + * + * @param e a string column containing JSON data. + * + * @group collection_funcs + * @since 2.4.0 + */ + def schema_of_json(e: Column): Column = withExpr(new SchemaOfJson(e.expr)) + /** * (Scala-specific) Converts a column containing a `StructType`, `ArrayType` of `StructType`s, * a `MapType` or `ArrayType` of `MapType`s into a JSON string with the specified schema. diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql index dc15d13cd1dd3..79fdd5895e691 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -35,3 +35,7 @@ DROP VIEW IF EXISTS jsonTable; -- from_json - complex types select from_json('{"a":1, "b":2}', 'map'); select from_json('{"a":1, "b":"2"}', 'struct'); + +-- infer schema of json literal +select schema_of_json('{"c1":0, "c2":[1]}'); +select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}')); diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 2b3288dc5a137..3d49323751a10 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 28 +-- Number of queries: 30 -- !query 0 @@ -183,7 +183,7 @@ select from_json('{"a":1}', 1) struct<> -- !query 17 output org.apache.spark.sql.AnalysisException -Expected a string literal instead of 1;; line 1 pos 7 +Schema should be specified in DDL format as a string literal or output of the schema_of_json function instead of 1;; line 1 pos 7 -- !query 18 @@ -274,3 +274,19 @@ select from_json('{"a":1, "b":"2"}', 'struct') struct> -- !query 27 output {"a":1,"b":"2"} + + +-- !query 28 +select schema_of_json('{"c1":0, "c2":[1]}') +-- !query 28 schema +struct +-- !query 28 output +struct> + + +-- !query 29 +select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}')) +-- !query 29 schema +struct>> +-- !query 29 output +{"c1":[1,2,3]} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 7bf17cbcd9c97..d3b2701f2558e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.functions.{from_json, lit, map, struct, to_json} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -311,7 +311,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { val errMsg1 = intercept[AnalysisException] { df3.selectExpr("from_json(value, 1)") } - assert(errMsg1.getMessage.startsWith("Expected a string literal instead of")) + assert(errMsg1.getMessage.startsWith("Schema should be specified in DDL format as a string")) val errMsg2 = intercept[AnalysisException] { df3.selectExpr("""from_json(value, 'time InvalidType')""") } @@ -392,4 +392,17 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(Seq("""{"{"f": 1}": "a"}""").toDS().select(from_json($"value", schema)), Row(null)) } + + test("SPARK-24709: infers schemas of json strings and pass them to from_json") { + val in = Seq("""{"a": [1, 2, 3]}""").toDS() + val out = in.select(from_json('value, schema_of_json(lit("""{"a": [1]}"""))) as "parsed") + val expected = StructType(StructField( + "parsed", + StructType(StructField( + "a", + ArrayType(LongType, true), true) :: Nil), + true) :: Nil) + + assert(out.schema == expected) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 897424daca0cb..eab15b35c97d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -31,11 +31,11 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.{SparkException, TestUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{functions => F, _} -import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions} +import org.apache.spark.sql.catalyst.json.JsonInferSchema.compatibleType import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.ExternalRDD import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.datasources.json.JsonInferSchema.compatibleType import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ From b42fda8ab3b5f82b33b96fce3f584c50f2ed5a3a Mon Sep 17 00:00:00 2001 From: cclauss Date: Wed, 4 Jul 2018 09:40:58 +0800 Subject: [PATCH 1054/2461] [SPARK-23698] Remove raw_input() from Python 2 Signed-off-by: cclauss ## What changes were proposed in this pull request? Humans will be able to enter text in Python 3 prompts which they can not do today. The Python builtin __raw_input()__ was removed in Python 3 in favor of __input()__. This PR does the same thing in Python 2. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) flake8 testing Please review http://spark.apache.org/contributing.html before opening a pull request. Author: cclauss Closes #21702 from cclauss/python-fix-raw_input. --- dev/create-release/releaseutils.py | 5 ++++- dev/merge_spark_pr.py | 21 ++++++++++++--------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index 32f6cbb29f0be..ab812e1bb7c04 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -49,13 +49,16 @@ print("Install using 'sudo pip install unidecode'") sys.exit(-1) +if sys.version < '3': + input = raw_input + # Contributors list file name contributors_file_name = "contributors.txt" # Prompt the user to answer yes or no until they do so def yesOrNoPrompt(msg): - response = raw_input("%s [y/n]: " % msg) + response = input("%s [y/n]: " % msg) while response != "y" and response != "n": return yesOrNoPrompt(msg) return response == "y" diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 7f46a1c8f6a7c..79c7c021fe74a 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -39,6 +39,9 @@ except ImportError: JIRA_IMPORTED = False +if sys.version < '3': + input = raw_input + # Location of your Spark git development area SPARK_HOME = os.environ.get("SPARK_HOME", os.getcwd()) # Remote name which points to the Gihub site @@ -95,7 +98,7 @@ def run_cmd(cmd): def continue_maybe(prompt): - result = raw_input("\n%s (y/n): " % prompt) + result = input("\n%s (y/n): " % prompt) if result.lower() != "y": fail("Okay, exiting") @@ -134,7 +137,7 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): '--pretty=format:%an <%ae>']).split("\n") distinct_authors = sorted(set(commit_authors), key=lambda x: commit_authors.count(x), reverse=True) - primary_author = raw_input( + primary_author = input( "Enter primary author in the format of \"name \" [%s]: " % distinct_authors[0]) if primary_author == "": @@ -184,7 +187,7 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): def cherry_pick(pr_num, merge_hash, default_branch): - pick_ref = raw_input("Enter a branch name [%s]: " % default_branch) + pick_ref = input("Enter a branch name [%s]: " % default_branch) if pick_ref == "": pick_ref = default_branch @@ -231,7 +234,7 @@ def resolve_jira_issue(merge_branches, comment, default_jira_id=""): asf_jira = jira.client.JIRA({'server': JIRA_API_BASE}, basic_auth=(JIRA_USERNAME, JIRA_PASSWORD)) - jira_id = raw_input("Enter a JIRA id [%s]: " % default_jira_id) + jira_id = input("Enter a JIRA id [%s]: " % default_jira_id) if jira_id == "": jira_id = default_jira_id @@ -276,7 +279,7 @@ def resolve_jira_issue(merge_branches, comment, default_jira_id=""): default_fix_versions = filter(lambda x: x != v, default_fix_versions) default_fix_versions = ",".join(default_fix_versions) - fix_versions = raw_input("Enter comma-separated fix version(s) [%s]: " % default_fix_versions) + fix_versions = input("Enter comma-separated fix version(s) [%s]: " % default_fix_versions) if fix_versions == "": fix_versions = default_fix_versions fix_versions = fix_versions.replace(" ", "").split(",") @@ -315,7 +318,7 @@ def choose_jira_assignee(issue, asf_jira): if author in commentors: annotations.append("Commentor") print("[%d] %s (%s)" % (idx, author.displayName, ",".join(annotations))) - raw_assignee = raw_input( + raw_assignee = input( "Enter number of user, or userid, to assign to (blank to leave unassigned):") if raw_assignee == "": return None @@ -428,7 +431,7 @@ def main(): # Assumes branch names can be sorted lexicographically latest_branch = sorted(branch_names, reverse=True)[0] - pr_num = raw_input("Which pull request would you like to merge? (e.g. 34): ") + pr_num = input("Which pull request would you like to merge? (e.g. 34): ") pr = get_json("%s/pulls/%s" % (GITHUB_API_BASE, pr_num)) pr_events = get_json("%s/issues/%s/events" % (GITHUB_API_BASE, pr_num)) @@ -440,7 +443,7 @@ def main(): print("I've re-written the title as follows to match the standard format:") print("Original: %s" % pr["title"]) print("Modified: %s" % modified_title) - result = raw_input("Would you like to use the modified title? (y/n): ") + result = input("Would you like to use the modified title? (y/n): ") if result.lower() == "y": title = modified_title print("Using modified title:") @@ -491,7 +494,7 @@ def main(): merge_hash = merge_pr(pr_num, target_ref, title, body, pr_repo_desc) pick_prompt = "Would you like to pick %s into another branch?" % merge_hash - while raw_input("\n%s (y/n): " % pick_prompt).lower() == "y": + while input("\n%s (y/n): " % pick_prompt).lower() == "y": merged_refs = merged_refs + [cherry_pick(pr_num, merge_hash, latest_branch)] if JIRA_IMPORTED: From 5bf95f2a37e624eb6fb0ef6fbd2a40a129d5a470 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 4 Jul 2018 09:53:04 +0800 Subject: [PATCH 1055/2461] [BUILD] Close stale PRs Closes #20932 Closes #17843 Closes #13477 Closes #14291 Closes #20919 Closes #17907 Closes #18766 Closes #20809 Closes #8849 Closes #21076 Closes #21507 Closes #21336 Closes #21681 Closes #21691 Author: Sean Owen Closes #21708 from srowen/CloseStalePRs. From 7c08eb6d61d55ce45229f3302e6d463e7669183d Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 4 Jul 2018 12:21:26 +0800 Subject: [PATCH 1056/2461] [SPARK-24732][SQL] Type coercion between MapTypes. ## What changes were proposed in this pull request? Currently we don't allow type coercion between maps. We can support type coercion between MapTypes where both the key types and the value types are compatible. ## How was this patch tested? Added tests. Author: Takuya UESHIN Closes #21703 from ueshin/issues/SPARK-24732/maptypecoercion. --- .../sql/catalyst/analysis/TypeCoercion.scala | 12 +++++ .../catalyst/analysis/TypeCoercionSuite.scala | 45 ++++++++++++++++++- 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 3ebab430ffbcd..cf90e6e555fc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -179,6 +179,12 @@ object TypeCoercion { .orElse((t1, t2) match { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) + case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) => + findWiderTypeForTwo(kt1, kt2).flatMap { kt => + findWiderTypeForTwo(vt1, vt2).map { vt => + MapType(kt, vt, valueContainsNull1 || valueContainsNull2) + } + } case _ => None }) } @@ -220,6 +226,12 @@ object TypeCoercion { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => findWiderTypeWithoutStringPromotionForTwo(et1, et2) .map(ArrayType(_, containsNull1 || containsNull2)) + case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) => + findWiderTypeWithoutStringPromotionForTwo(kt1, kt2).flatMap { kt => + findWiderTypeWithoutStringPromotionForTwo(vt1, vt2).map { vt => + MapType(kt, vt, valueContainsNull1 || valueContainsNull2) + } + } case _ => None }) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 0acd3b490447d..4e5ca1b8cdd36 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -54,8 +54,9 @@ class TypeCoercionSuite extends AnalysisTest { // | NullType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType | IntegerType | // | CalendarIntervalType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | CalendarIntervalType | X | X | X | // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ - // Note: MapType*, StructType* are castable only when the internal child types also match; otherwise, not castable. + // Note: StructType* is castable only when the internal child types also match; otherwise, not castable. // Note: ArrayType* is castable when the element type is castable according to the table. + // Note: MapType* is castable when both the key type and the value type are castable according to the table. // scalastyle:on line.size.limit private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { @@ -487,12 +488,38 @@ class TypeCoercionSuite extends AnalysisTest { ArrayType(ArrayType(IntegerType), containsNull = false), ArrayType(ArrayType(LongType), containsNull = false), Some(ArrayType(ArrayType(LongType), containsNull = false))) + widenTestWithStringPromotion( + ArrayType(MapType(IntegerType, FloatType), containsNull = false), + ArrayType(MapType(LongType, DoubleType), containsNull = false), + Some(ArrayType(MapType(LongType, DoubleType), containsNull = false))) + + // MapType + widenTestWithStringPromotion( + MapType(ShortType, TimestampType, valueContainsNull = true), + MapType(DoubleType, StringType, valueContainsNull = false), + Some(MapType(DoubleType, StringType, valueContainsNull = true))) + widenTestWithStringPromotion( + MapType(IntegerType, ArrayType(TimestampType), valueContainsNull = false), + MapType(LongType, ArrayType(StringType), valueContainsNull = true), + Some(MapType(LongType, ArrayType(StringType), valueContainsNull = true))) + widenTestWithStringPromotion( + MapType(IntegerType, MapType(ShortType, TimestampType), valueContainsNull = false), + MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false), + Some(MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false))) // Without string promotion widenTestWithoutStringPromotion(IntegerType, StringType, None) widenTestWithoutStringPromotion(StringType, TimestampType, None) widenTestWithoutStringPromotion(ArrayType(LongType), ArrayType(StringType), None) widenTestWithoutStringPromotion(ArrayType(StringType), ArrayType(TimestampType), None) + widenTestWithoutStringPromotion( + MapType(LongType, IntegerType), MapType(StringType, IntegerType), None) + widenTestWithoutStringPromotion( + MapType(IntegerType, LongType), MapType(IntegerType, StringType), None) + widenTestWithoutStringPromotion( + MapType(StringType, IntegerType), MapType(TimestampType, IntegerType), None) + widenTestWithoutStringPromotion( + MapType(IntegerType, StringType), MapType(IntegerType, TimestampType), None) // String promotion widenTestWithStringPromotion(IntegerType, StringType, Some(StringType)) @@ -501,6 +528,22 @@ class TypeCoercionSuite extends AnalysisTest { ArrayType(LongType), ArrayType(StringType), Some(ArrayType(StringType))) widenTestWithStringPromotion( ArrayType(StringType), ArrayType(TimestampType), Some(ArrayType(StringType))) + widenTestWithStringPromotion( + MapType(LongType, IntegerType), + MapType(StringType, IntegerType), + Some(MapType(StringType, IntegerType))) + widenTestWithStringPromotion( + MapType(IntegerType, LongType), + MapType(IntegerType, StringType), + Some(MapType(IntegerType, StringType))) + widenTestWithStringPromotion( + MapType(StringType, IntegerType), + MapType(TimestampType, IntegerType), + Some(MapType(StringType, IntegerType))) + widenTestWithStringPromotion( + MapType(IntegerType, StringType), + MapType(IntegerType, TimestampType), + Some(MapType(IntegerType, StringType))) } private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) { From 772060d0940a97d89807befd682a70ae82e83ef4 Mon Sep 17 00:00:00 2001 From: Stan Zhai Date: Wed, 4 Jul 2018 10:12:36 +0200 Subject: [PATCH 1057/2461] [SPARK-24704][WEBUI] Fix the order of stages in the DAG graph ## What changes were proposed in this pull request? Before: ![wx20180630-155537](https://user-images.githubusercontent.com/1438757/42123357-2c2e2d84-7c83-11e8-8abd-1c2860f38783.png) After: ![wx20180630-155604](https://user-images.githubusercontent.com/1438757/42123359-32fae990-7c83-11e8-8a7b-cdcee94f9123.png) ## How was this patch tested? Manual tests. Author: Stan Zhai Closes #21680 from stanzhai/fix-dag-graph. --- .../src/main/scala/org/apache/spark/status/AppStatusStore.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 688f25a9fdea1..e237281c552b1 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -471,7 +471,7 @@ private[spark] class AppStatusStore( def operationGraphForJob(jobId: Int): Seq[RDDOperationGraph] = { val job = store.read(classOf[JobDataWrapper], jobId) - val stages = job.info.stageIds + val stages = job.info.stageIds.sorted stages.map { id => val g = store.read(classOf[RDDOperationGraphWrapper], id).toRDDOperationGraph() From b2deef64f604ddd9502a31105ed47cb63470ec85 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 4 Jul 2018 20:04:18 +0800 Subject: [PATCH 1058/2461] [SPARK-24727][SQL] Add a static config to control cache size for generated classes ## What changes were proposed in this pull request? Since SPARK-24250 has been resolved, executors correctly references user-defined configurations. So, this pr added a static config to control cache size for generated classes in `CodeGenerator`. ## How was this patch tested? Added tests in `ExecutorSideSQLConfSuite`. Author: Takeshi Yamamuro Closes #21705 from maropu/SPARK-24727. --- .../expressions/codegen/CodeGenerator.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 2 ++ .../spark/sql/internal/StaticSQLConf.scala | 8 +++++ .../internal/ExecutorSideSQLConfSuite.scala | 31 +++++++++++++++---- 4 files changed, 36 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4cc0968911cb5..838c045d5bcce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1415,7 +1415,7 @@ object CodeGenerator extends Logging { * weak keys/values and thus does not respond to memory pressure. */ private val cache = CacheBuilder.newBuilder() - .maximumSize(100) + .maximumSize(SQLConf.get.codegenCacheMaxEntries) .build( new CacheLoader[CodeAndComment, (GeneratedClass, Int)]() { override def load(code: CodeAndComment): (GeneratedClass, Int) = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e2c48e2d8a14c..50965c1abc68c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1508,6 +1508,8 @@ class SQLConf extends Serializable with Logging { def tableRelationCacheSize: Int = getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE) + def codegenCacheMaxEntries: Int = getConf(StaticSQLConf.CODEGEN_CACHE_MAX_ENTRIES) + def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED) def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index 382ef28f49a7a..384b1917a1f79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -66,6 +66,14 @@ object StaticSQLConf { .checkValue(cacheSize => cacheSize >= 0, "The maximum size of the cache must not be negative") .createWithDefault(1000) + val CODEGEN_CACHE_MAX_ENTRIES = buildStaticConf("spark.sql.codegen.cache.maxEntries") + .internal() + .doc("When nonzero, enable caching of generated classes for operators and expressions. " + + "All jobs share the cache that can use up to the specified number for generated classes.") + .intConf + .checkValue(maxEntries => maxEntries >= 0, "The maximum must not be negative") + .createWithDefault(100) + // When enabling the debug, Spark SQL internal table properties are not filtered out; however, // some related DDL commands (e.g., ANALYZE TABLE and CREATE TABLE LIKE) might not work properly. val DEBUG_MODE = buildStaticConf("spark.sql.debug") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala index 3dd0712e02448..855fe4f4523f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.test.SQLTestUtils class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { @@ -40,16 +40,24 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { spark = null } + override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + pairs.foreach { case (k, v) => + SQLConf.get.setConfString(k, v) + } + try f finally { + pairs.foreach { case (k, _) => + SQLConf.get.unsetConf(k) + } + } + } + test("ReadOnlySQLConf is correctly created at the executor side") { - SQLConf.get.setConfString("spark.sql.x", "a") - try { - val checks = spark.range(10).mapPartitions { it => + withSQLConf("spark.sql.x" -> "a") { + val checks = spark.range(10).mapPartitions { _ => val conf = SQLConf.get Iterator(conf.isInstanceOf[ReadOnlySQLConf] && conf.getConfString("spark.sql.x") == "a") }.collect() assert(checks.forall(_ == true)) - } finally { - SQLConf.get.unsetConf("spark.sql.x") } } @@ -63,4 +71,15 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { } } } + + test("SPARK-24727 CODEGEN_CACHE_MAX_ENTRIES is correctly referenced at the executor side") { + withSQLConf(StaticSQLConf.CODEGEN_CACHE_MAX_ENTRIES.key -> "300") { + val checks = spark.range(10).mapPartitions { _ => + val conf = SQLConf.get + Iterator(conf.isInstanceOf[ReadOnlySQLConf] && + conf.getConfString(StaticSQLConf.CODEGEN_CACHE_MAX_ENTRIES.key) == "300") + }.collect() + assert(checks.forall(_ == true)) + } + } } From 021145f36432b386cce30450c888a85393d5169f Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 4 Jul 2018 20:15:40 +0800 Subject: [PATCH 1059/2461] [SPARK-24716][SQL] Refactor ParquetFilters ## What changes were proposed in this pull request? Replace DataFrame schema to Parquet file schema when create `ParquetFilters`. Thus we can easily implement `Decimal` and `Timestamp` push down. some thing like this: ```scala // DecimalType: 32BitDecimalType case ParquetSchemaType(DECIMAL, INT32, decimal) if pushDownDecimal => (n: String, v: Any) => FilterApi.eq( intColumn(n), Option(v).map(_.asInstanceOf[JBigDecimal].unscaledValue().intValue() .asInstanceOf[Integer]).orNull) // DecimalType: 64BitDecimalType case ParquetSchemaType(DECIMAL, INT64, decimal) if pushDownDecimal => (n: String, v: Any) => FilterApi.eq( longColumn(n), Option(v).map(_.asInstanceOf[JBigDecimal].unscaledValue().longValue() .asInstanceOf[java.lang.Long]).orNull) // DecimalType: LegacyParquetFormat 32BitDecimalType & 64BitDecimalType case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, decimal) if pushDownDecimal && decimal.getPrecision <= Decimal.MAX_LONG_DIGITS => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(d => decimalToBinaryUsingUnscaledLong(decimal.getPrecision, d.asInstanceOf[JBigDecimal])).orNull) // DecimalType: ByteArrayDecimalType case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, decimal) if pushDownDecimal && decimal.getPrecision > Decimal.MAX_LONG_DIGITS => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(d => decimalToBinaryUsingUnscaledBytes(decimal.getPrecision, d.asInstanceOf[JBigDecimal])).orNull) ``` ```scala // INT96 doesn't support pushdown case ParquetSchemaType(TIMESTAMP_MICROS, INT64, null) => (n: String, v: Any) => FilterApi.eq( longColumn(n), Option(v).map(t => DateTimeUtils.fromJavaTimestamp(t.asInstanceOf[Timestamp]) .asInstanceOf[java.lang.Long]).orNull) case ParquetSchemaType(TIMESTAMP_MILLIS, INT64, null) => (n: String, v: Any) => FilterApi.eq( longColumn(n), Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[java.lang.Long]).orNull) ``` ## How was this patch tested? unit tests Author: Yuming Wang Closes #21696 from wangyum/SPARK-24716. --- .../parquet/ParquetFileFormat.scala | 34 ++-- .../datasources/parquet/ParquetFilters.scala | 173 ++++++++++-------- .../apache/spark/sql/sources/filters.scala | 2 +- .../parquet/ParquetFilterSuite.scala | 13 +- 4 files changed, 121 insertions(+), 101 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 93de1faef527a..52a18abb55241 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -353,25 +353,13 @@ class ParquetFileFormat (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) - // Try to push down filters when filter push-down is enabled. - val pushed = if (enableParquetFilterPushDown) { - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can be - // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` - // is used here. - .flatMap(new ParquetFilters(pushDownDate, pushDownStringStartWith) - .createFilter(requiredSchema, _)) - .reduceOption(FilterApi.and) - } else { - None - } - val fileSplit = new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty) + val filePath = fileSplit.getPath val split = new org.apache.parquet.hadoop.ParquetInputSplit( - fileSplit.getPath, + filePath, fileSplit.getStart, fileSplit.getStart + fileSplit.getLength, fileSplit.getLength, @@ -379,12 +367,28 @@ class ParquetFileFormat null) val sharedConf = broadcastedHadoopConf.value.value + + // Try to push down filters when filter push-down is enabled. + val pushed = if (enableParquetFilterPushDown) { + val parquetSchema = ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS) + .getFileMetaData.getSchema + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(new ParquetFilters(pushDownDate, pushDownStringStartWith) + .createFilter(parquetSchema, _)) + .reduceOption(FilterApi.and) + } else { + None + } + // PARQUET_INT96_TIMESTAMP_CONVERSION says to apply timezone conversions to int96 timestamps' // *only* if the file was created by something other than "parquet-mr", so check the actual // writer here for this file. We have to do this per-file, as each file in the table may // have different writers. def isCreatedByParquetMr(): Boolean = { - val footer = ParquetFileReader.readFooter(sharedConf, fileSplit.getPath, SKIP_ROW_GROUPS) + val footer = ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS) footer.getFileMetaData().getCreatedBy().startsWith("parquet-mr") } val convertTz = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 21c9e2e4f82b4..4827f706e6016 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -19,15 +19,19 @@ package org.apache.spark.sql.execution.datasources.parquet import java.sql.Date +import scala.collection.JavaConverters.asScalaBufferConverter + import org.apache.parquet.filter2.predicate._ import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.io.api.Binary -import org.apache.parquet.schema.PrimitiveComparator +import org.apache.parquet.schema.{DecimalMetadata, MessageType, OriginalType, PrimitiveComparator, PrimitiveType} +import org.apache.parquet.schema.OriginalType._ +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate import org.apache.spark.sql.sources -import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** @@ -35,171 +39,180 @@ import org.apache.spark.unsafe.types.UTF8String */ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: Boolean) { + private case class ParquetSchemaType( + originalType: OriginalType, + primitiveTypeName: PrimitiveTypeName, + decimalMetadata: DecimalMetadata) + + private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, null) + private val ParquetIntegerType = ParquetSchemaType(null, INT32, null) + private val ParquetLongType = ParquetSchemaType(null, INT64, null) + private val ParquetFloatType = ParquetSchemaType(null, FLOAT, null) + private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, null) + private val ParquetStringType = ParquetSchemaType(UTF8, BINARY, null) + private val ParquetBinaryType = ParquetSchemaType(null, BINARY, null) + private val ParquetDateType = ParquetSchemaType(DATE, INT32, null) + private def dateToDays(date: Date): SQLDate = { DateTimeUtils.fromJavaDate(date) } - private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case BooleanType => + private val makeEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetBooleanType => (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) - case IntegerType => + case ParquetIntegerType => (n: String, v: Any) => FilterApi.eq(intColumn(n), v.asInstanceOf[Integer]) - case LongType => + case ParquetLongType => (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => + case ParquetFloatType => (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => + case ParquetDoubleType => (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) // Binary.fromString and Binary.fromByteArray don't accept null values - case StringType => + case ParquetStringType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case BinaryType => + case ParquetBinaryType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case DateType if pushDownDate => + case ParquetDateType if pushDownDate => (n: String, v: Any) => FilterApi.eq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } - private val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case BooleanType => + private val makeNotEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetBooleanType => (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) - case IntegerType => + case ParquetIntegerType => (n: String, v: Any) => FilterApi.notEq(intColumn(n), v.asInstanceOf[Integer]) - case LongType => + case ParquetLongType => (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => + case ParquetFloatType => (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => + case ParquetDoubleType => (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => + case ParquetStringType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case BinaryType => + case ParquetBinaryType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case DateType if pushDownDate => + case ParquetDateType if pushDownDate => (n: String, v: Any) => FilterApi.notEq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } - private val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => + private val makeLt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetIntegerType => (n: String, v: Any) => FilterApi.lt(intColumn(n), v.asInstanceOf[Integer]) - case LongType => + case ParquetLongType => (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => + case ParquetFloatType => (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => + case ParquetDoubleType => (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => + case ParquetStringType => (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if pushDownDate => - (n: String, v: Any) => FilterApi.lt( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.lt(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) } - private val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.ltEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => + private val makeLtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetIntegerType => + (n: String, v: Any) => FilterApi.ltEq(intColumn(n), v.asInstanceOf[Integer]) + case ParquetLongType => (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => + case ParquetFloatType => (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => + case ParquetDoubleType => (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => + case ParquetStringType => (n: String, v: Any) => - FilterApi.ltEq(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if pushDownDate => - (n: String, v: Any) => FilterApi.ltEq( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.ltEq(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) } - private val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.gt(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => + private val makeGt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetIntegerType => + (n: String, v: Any) => FilterApi.gt(intColumn(n), v.asInstanceOf[Integer]) + case ParquetLongType => (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => + case ParquetFloatType => (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => + case ParquetDoubleType => (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => + case ParquetStringType => (n: String, v: Any) => - FilterApi.gt(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if pushDownDate => - (n: String, v: Any) => FilterApi.gt( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.gt(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) } - private val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.gtEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => + private val makeGtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetIntegerType => + (n: String, v: Any) => FilterApi.gtEq(intColumn(n), v.asInstanceOf[Integer]) + case ParquetLongType => (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => + case ParquetFloatType => (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => + case ParquetDoubleType => (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => + case ParquetStringType => (n: String, v: Any) => - FilterApi.gtEq(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if pushDownDate => - (n: String, v: Any) => FilterApi.gtEq( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.gtEq(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) } /** * Returns a map from name of the column to the data type, if predicate push down applies. */ - private def getFieldMap(dataType: DataType): Map[String, DataType] = dataType match { - case StructType(fields) => + private def getFieldMap(dataType: MessageType): Map[String, ParquetSchemaType] = dataType match { + case m: MessageType => // Here we don't flatten the fields in the nested schema but just look up through // root fields. Currently, accessing to nested fields does not push down filters // and it does not support to create filters for them. - fields.map(f => f.name -> f.dataType).toMap - case _ => Map.empty[String, DataType] + m.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f => + f.getName -> ParquetSchemaType( + f.getOriginalType, f.getPrimitiveTypeName, f.getDecimalMetadata) + }.toMap + case _ => Map.empty[String, ParquetSchemaType] } /** * Converts data sources filters to Parquet filter predicates. */ - def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = { + def createFilter(schema: MessageType, predicate: sources.Filter): Option[FilterPredicate] = { val nameToType = getFieldMap(schema) // Parquet does not allow dots in the column name because dots are used as a column path diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 2499e9b604f3e..bdd8c4da6bd30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -199,7 +199,7 @@ case class StringStartsWith(attribute: String, value: String) extends Filter { /** * A filter that evaluates to `true` iff the attribute evaluates to - * a string that starts with `value`. + * a string that ends with `value`. * * @since 1.3.1 */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index d9ae5858e5ed0..8b96c841c8c6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -103,7 +103,8 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assert(selectedFilters.nonEmpty, "No filter is pushed down") selectedFilters.foreach { pred => - val maybeFilter = parquetFilters.createFilter(df.schema, pred) + val maybeFilter = parquetFilters.createFilter( + new SparkToParquetSchemaConverter(conf).convert(df.schema), pred) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`) maybeFilter.exists(_.getClass === filterClass) @@ -542,12 +543,14 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex StructField("c", DoubleType, nullable = true) )) + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + assertResult(Some(and( lt(intColumn("a"), 10: Integer), gt(doubleColumn("c"), 1.5: java.lang.Double))) ) { parquetFilters.createFilter( - schema, + parquetSchema, sources.And( sources.LessThan("a", 10), sources.GreaterThan("c", 1.5D))) @@ -555,7 +558,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assertResult(None) { parquetFilters.createFilter( - schema, + parquetSchema, sources.And( sources.LessThan("a", 10), sources.StringContains("b", "prefix"))) @@ -563,7 +566,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assertResult(None) { parquetFilters.createFilter( - schema, + parquetSchema, sources.Not( sources.And( sources.GreaterThan("a", 1), @@ -729,7 +732,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assertResult(None) { parquetFilters.createFilter( - df.schema, + new SparkToParquetSchemaConverter(conf).convert(df.schema), sources.StringStartsWith("_1", null)) } } From 1a2655a9e75627b584787f9e4c6cdaa92e61fa3f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 4 Jul 2018 20:42:08 +0800 Subject: [PATCH 1060/2461] [SPARK-24635][SQL] Remove Blocks class from JavaCode class hierarchy ## What changes were proposed in this pull request? The `Blocks` class in `JavaCode` class hierarchy is not necessary. Its function can be taken by `CodeBlock`. We should remove it to make simpler class hierarchy. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #21619 from viirya/SPARK-24635. --- .../expressions/codegen/javaCode.scala | 40 +++++++------------ 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 250ce48d059e0..44f63e21e93bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -119,6 +119,7 @@ object JavaCode { * A trait representing a block of java code. */ trait Block extends JavaCode { + import Block._ // The expressions to be evaluated inside this block. def exprValues: Set[ExprValue] @@ -148,14 +149,17 @@ trait Block extends JavaCode { } // Concatenates this block with other block. - def + (other: Block): Block + def + (other: Block): Block = other match { + case EmptyBlock => this + case _ => code"$this\n$other" + } } object Block { val CODE_BLOCK_BUFFER_LENGTH: Int = 512 - implicit def blocksToBlock(blocks: Seq[Block]): Block = Blocks(blocks) + implicit def blocksToBlock(blocks: Seq[Block]): Block = blocks.reduceLeft(_ + _) implicit class BlockHelper(val sc: StringContext) extends AnyVal { def code(args: Any*): Block = { @@ -190,18 +194,17 @@ object Block { while (strings.hasNext) { val input = inputs.next input match { - case _: ExprValue | _: Block => + case _: ExprValue | _: CodeBlock => codeParts += buf.toString buf.clear blockInputs += input.asInstanceOf[JavaCode] + case EmptyBlock => case _ => buf.append(input) } buf.append(strings.next) } - if (buf.nonEmpty) { - codeParts += buf.toString - } + codeParts += buf.toString (codeParts.toSeq, blockInputs.toSeq) } @@ -209,7 +212,11 @@ object Block { /** * A block of java code. Including a sequence of code parts and some inputs to this block. - * The actual java code is generated by embedding the inputs into the code parts. + * The actual java code is generated by embedding the inputs into the code parts. Here we keep + * inputs of `JavaCode` instead of simply folding them as a string of code, because we need to + * track expressions (`ExprValue`) in this code block. We need to be able to manipulate the + * expressions later without changing the behavior of this code block in some applications, e.g., + * method splitting. */ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block { override lazy val exprValues: Set[ExprValue] = { @@ -230,30 +237,11 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends } buf.toString } - - override def + (other: Block): Block = other match { - case c: CodeBlock => Blocks(Seq(this, c)) - case b: Blocks => Blocks(Seq(this) ++ b.blocks) - case EmptyBlock => this - } -} - -case class Blocks(blocks: Seq[Block]) extends Block { - override lazy val exprValues: Set[ExprValue] = blocks.flatMap(_.exprValues).toSet - override lazy val code: String = blocks.map(_.toString).mkString("\n") - - override def + (other: Block): Block = other match { - case c: CodeBlock => Blocks(blocks :+ c) - case b: Blocks => Blocks(blocks ++ b.blocks) - case EmptyBlock => this - } } object EmptyBlock extends Block with Serializable { override val code: String = "" override val exprValues: Set[ExprValue] = Set.empty - - override def + (other: Block): Block = other } /** From ca8243f30fc6939ee099a9534e3b811d5c64d2cf Mon Sep 17 00:00:00 2001 From: Shahid Date: Wed, 4 Jul 2018 09:56:24 -0500 Subject: [PATCH 1061/2461] [MINOR][ML] Minor correction in the powerIterationSuite ## What changes were proposed in this pull request? Currently the power iteration clustering test in spark ml, maps the results to the labels 0 and 1 for assertion. Since the clustering outputs need not be the same as the mapped labels, it may cause failure in the test case. Even if it correctly maps, theoretically we cannot guarantee which set belongs to which cluster label. KMeans can assign label 0 to either of the set. PowerIterationClusteringSuite in the MLLib checks the clustering results without mapping to the particular cluster label, as shown below. `` val predictions = Array.fill(2)(mutable.Set.empty[Long]) model.assignments.collect().foreach { a => predictions(a.cluster) += a.id } assert(predictions.toSet == Set((0 until n1).toSet, (n1 until n).toSet)) `` ## How was this patch tested? Existing tests Author: Shahid Closes #21689 from shahidki31/picTestSuiteMinorCorrection. --- .../PowerIterationClusteringSuite.scala | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala index b7072728d48f0..55b460f1a4524 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.clustering +import scala.collection.mutable + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -76,12 +78,15 @@ class PowerIterationClusteringSuite extends SparkFunSuite .setMaxIter(40) .setWeightCol("weight") .assignClusters(data) - val localAssignments = assignments - .select('id, 'cluster) - .as[(Long, Int)].collect().toSet - val expectedResult = (0 until n1).map(x => (x, 1)).toSet ++ - (n1 until n).map(x => (x, 0)).toSet - assert(localAssignments === expectedResult) + .select("id", "cluster") + .as[(Long, Int)] + .collect() + + val predictions = Array.fill(2)(mutable.Set.empty[Long]) + assignments.foreach { + case (id, cluster) => predictions(cluster) += id + } + assert(predictions.toSet === Set((0 until n1).toSet, (n1 until n).toSet)) val assignments2 = new PowerIterationClustering() .setK(2) @@ -89,10 +94,15 @@ class PowerIterationClusteringSuite extends SparkFunSuite .setInitMode("degree") .setWeightCol("weight") .assignClusters(data) - val localAssignments2 = assignments2 - .select('id, 'cluster) - .as[(Long, Int)].collect().toSet - assert(localAssignments2 === expectedResult) + .select("id", "cluster") + .as[(Long, Int)] + .collect() + + val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) + assignments2.foreach { + case (id, cluster) => predictions2(cluster) += id + } + assert(predictions2.toSet === Set((0 until n1).toSet, (n1 until n).toSet)) } test("supported input types") { From bf764a33bef617aa9bae535a5ea73d6a3e278d42 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 4 Jul 2018 18:36:09 -0700 Subject: [PATCH 1062/2461] [SPARK-22384][SQL][FOLLOWUP] Refine partition pruning when attribute is wrapped in Cast ## What changes were proposed in this pull request? As mentioned in https://github.com/apache/spark/pull/21586 , `Cast.mayTruncate` is not 100% safe, string to boolean is allowed. Since changing `Cast.mayTruncate` also changes the behavior of Dataset, here I propose to add a new `Cast.canSafeCast` for partition pruning. ## How was this patch tested? new test cases Author: Wenchen Fan Closes #21712 from cloud-fan/safeCast. --- .../spark/sql/catalyst/expressions/Cast.scala | 20 +++++++++++++++++++ .../spark/sql/hive/client/HiveShim.scala | 5 +++-- .../sql/hive/client/HiveClientSuite.scala | 20 +++++++++++++++++-- 3 files changed, 41 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 699ea53b5df0f..7971ae602bd37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -134,6 +134,26 @@ object Cast { toPrecedence > 0 && fromPrecedence > toPrecedence } + /** + * Returns true iff we can safely cast the `from` type to `to` type without any truncating or + * precision lose, e.g. int -> long, date -> timestamp. + */ + def canSafeCast(from: AtomicType, to: AtomicType): Boolean = (from, to) match { + case _ if from == to => true + case (from: NumericType, to: DecimalType) if to.isWiderThan(from) => true + case (from: DecimalType, to: NumericType) if from.isTighterThan(to) => true + case (from, to) if legalNumericPrecedence(from, to) => true + case (DateType, TimestampType) => true + case (_, StringType) => true + case _ => false + } + + private def legalNumericPrecedence(from: DataType, to: DataType): Boolean = { + val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from) + val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to) + fromPrecedence >= 0 && fromPrecedence < toPrecedence + } + def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match { case (NullType, _) => true case (_, _) if from == to => false diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 8620f3f6d99fb..933384ed43e98 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, CatalogUtils, FunctionResource, FunctionResourceType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{IntegralType, StringType} +import org.apache.spark.sql.types.{AtomicType, IntegralType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -660,7 +660,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { def unapply(expr: Expression): Option[Attribute] = { expr match { case attr: Attribute => Some(attr) - case Cast(child, dt, _) if !Cast.mayTruncate(child.dataType, dt) => unapply(child) + case Cast(child @ AtomicType(), dt: AtomicType, _) + if Cast.canSafeCast(child.dataType.asInstanceOf[AtomicType], dt) => unapply(child) case _ => None } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index 55275f6b37945..fa9f753795f65 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType} // TODO: Refactor this to `HivePartitionFilteringSuite` class HiveClientSuite(version: String) @@ -122,6 +122,22 @@ class HiveClientSuite(version: String) "aa" :: Nil) } + test("getPartitionsByFilter: cast(chunk as int)=1 (not a valid partition predicate)") { + testMetastorePartitionFiltering( + attr("chunk").cast(IntegerType) === 1, + 20170101 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: cast(chunk as boolean)=true (not a valid partition predicate)") { + testMetastorePartitionFiltering( + attr("chunk").cast(BooleanType) === true, + 20170101 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + test("getPartitionsByFilter: 20170101=ds") { testMetastorePartitionFiltering( Literal(20170101) === attr("ds"), @@ -138,7 +154,7 @@ class HiveClientSuite(version: String) "aa" :: "ab" :: "ba" :: "bb" :: Nil) } - test("getPartitionsByFilter: chunk in cast(ds as long)=20170101L") { + test("getPartitionsByFilter: cast(ds as long)=20170101L and h=10") { testMetastorePartitionFiltering( attr("ds").cast(LongType) === 20170101L && attr("h") === 10, 20170101 to 20170101, From 489a5294d106130beda1509e3cbbaf707a3d703d Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Thu, 5 Jul 2018 09:56:48 +0800 Subject: [PATCH 1063/2461] [SPARK-17213][SPARK-17213][FOLLOW-UP] Improve the test of ## What changes were proposed in this pull request? This is a minor improvement for the test of SPARK-17213 ## How was this patch tested? N/A Author: Xiao Li Closes #21716 from gatorsmile/testMaster23. --- .../parquet/ParquetFilterSuite.scala | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 8b96c841c8c6e..f2c0bda256239 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -618,21 +618,25 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } test("SPARK-17213: Broken Parquet filter push-down for string columns") { - withTempPath { dir => - import testImplicits._ + Seq(true, false).foreach { vectorizedEnabled => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorizedEnabled.toString) { + withTempPath { dir => + import testImplicits._ - val path = dir.getCanonicalPath - // scalastyle:off nonascii - Seq("a", "é").toDF("name").write.parquet(path) - // scalastyle:on nonascii + val path = dir.getCanonicalPath + // scalastyle:off nonascii + Seq("a", "é").toDF("name").write.parquet(path) + // scalastyle:on nonascii - assert(spark.read.parquet(path).where("name > 'a'").count() == 1) - assert(spark.read.parquet(path).where("name >= 'a'").count() == 2) + assert(spark.read.parquet(path).where("name > 'a'").count() == 1) + assert(spark.read.parquet(path).where("name >= 'a'").count() == 2) - // scalastyle:off nonascii - assert(spark.read.parquet(path).where("name < 'é'").count() == 1) - assert(spark.read.parquet(path).where("name <= 'é'").count() == 2) - // scalastyle:on nonascii + // scalastyle:off nonascii + assert(spark.read.parquet(path).where("name < 'é'").count() == 1) + assert(spark.read.parquet(path).where("name <= 'é'").count() == 2) + // scalastyle:on nonascii + } + } } } From f997be0c3136f85762b841469e7dfcde7e699ced Mon Sep 17 00:00:00 2001 From: mcteo Date: Thu, 5 Jul 2018 10:05:41 +0800 Subject: [PATCH 1064/2461] [SPARK-24698][PYTHON] Fixed typo in pyspark.ml's Identifiable class. ## What changes were proposed in this pull request? Fixed a small typo in the code that caused 20 random characters to be added to the UID, rather than 12. Author: mcteo Closes #21675 from mcteo/SPARK-24698-fix. --- python/pyspark/ml/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 080cd299f4fde..e846834761e49 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -63,7 +63,7 @@ def _randomUID(cls): Generate a unique unicode id for the object. The default implementation concatenates the class name, "_", and 12 random hex chars. """ - return unicode(cls.__name__ + "_" + uuid.uuid4().hex[12:]) + return unicode(cls.__name__ + "_" + uuid.uuid4().hex[-12:]) @inherit_doc From 4be9f0c028cebb0d2975e93a6ebc56337cd2c585 Mon Sep 17 00:00:00 2001 From: Antonio Murgia Date: Thu, 5 Jul 2018 16:10:34 +0800 Subject: [PATCH 1065/2461] [SPARK-24673][SQL] scala sql function from_utc_timestamp second argument could be Column instead of String ## What changes were proposed in this pull request? Add an overloaded version to `from_utc_timestamp` and `to_utc_timestamp` having second argument as a `Column` instead of `String`. ## How was this patch tested? Unit testing, especially adding two tests to org.apache.spark.sql.DateFunctionsSuite.scala Author: Antonio Murgia Author: Antonio Murgia Closes #21693 from tmnd1991/feature/SPARK-24673. --- .../org/apache/spark/sql/functions.scala | 22 +++++++++++ .../apache/spark/sql/DateFunctionsSuite.scala | 38 ++++++++++++++++++- 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 614f65f0faaba..f2627e69939cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2934,6 +2934,17 @@ object functions { FromUTCTimestamp(ts.expr, Literal(tz)) } + /** + * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders + * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield + * '2017-07-14 03:40:00.0'. + * @group datetime_funcs + * @since 2.4.0 + */ + def from_utc_timestamp(ts: Column, tz: Column): Column = withExpr { + FromUTCTimestamp(ts.expr, tz.expr) + } + /** * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time * zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield @@ -2945,6 +2956,17 @@ object functions { ToUTCTimestamp(ts.expr, Literal(tz)) } + /** + * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time + * zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield + * '2017-07-14 01:40:00.0'. + * @group datetime_funcs + * @since 2.4.0 + */ + def to_utc_timestamp(ts: Column, tz: Column): Column = withExpr { + ToUTCTimestamp(ts.expr, tz.expr) + } + /** * Bucketize rows into one or more time windows given a timestamp specifying column. Window * starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 237412aa692e5..3af80b36ec42c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -663,7 +663,7 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df.selectExpr("datediff(a, d)"), Seq(Row(1), Row(1))) } - test("from_utc_timestamp") { + test("from_utc_timestamp with literal zone") { val df = Seq( (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00"), (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00") @@ -680,7 +680,24 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(Timestamp.valueOf("2015-07-24 17:00:00")))) } - test("to_utc_timestamp") { + test("from_utc_timestamp with column zone") { + val df = Seq( + (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00", "CET"), + (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00", "PST") + ).toDF("a", "b", "c") + checkAnswer( + df.select(from_utc_timestamp(col("a"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 02:00:00")), + Row(Timestamp.valueOf("2015-07-24 17:00:00")))) + checkAnswer( + df.select(from_utc_timestamp(col("b"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 02:00:00")), + Row(Timestamp.valueOf("2015-07-24 17:00:00")))) + } + + test("to_utc_timestamp with literal zone") { val df = Seq( (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00"), (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00") @@ -697,6 +714,23 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(Timestamp.valueOf("2015-07-25 07:00:00")))) } + test("to_utc_timestamp with column zone") { + val df = Seq( + (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00", "PST"), + (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00", "CET") + ).toDF("a", "b", "c") + checkAnswer( + df.select(to_utc_timestamp(col("a"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 07:00:00")), + Row(Timestamp.valueOf("2015-07-24 22:00:00")))) + checkAnswer( + df.select(to_utc_timestamp(col("b"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 07:00:00")), + Row(Timestamp.valueOf("2015-07-24 22:00:00")))) + } + test("SPARK-23715: to/from_utc_timestamp can retain the previous behavior") { withSQLConf(SQLConf.REJECT_TIMEZONE_IN_STRING.key -> "false") { checkAnswer( From 32cfd3e75a5ca65696fedfa4d49681e6fc3e698d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 5 Jul 2018 20:48:55 +0800 Subject: [PATCH 1066/2461] [SPARK-24361][SQL] Polish code block manipulation API ## What changes were proposed in this pull request? Current code block manipulation API is immature and hacky. We need a formal API to manipulate code blocks. The basic idea is making `JavaCode` as `TreeNode`. So we can use familiar `transform` API to manipulate code blocks and expressions in code blocks. For example, we can replace `SimpleExprValue` in a code block like this: ```scala code.transformExprValues { case SimpleExprValue("1 + 1", _) => aliasedParam } ``` The example use case is splitting code to methods. For example, we have an `ExprCode` containing generated code. But it is too long and we need to split it as method. Because statement-based expressions can't be directly passed into. We need to transform them as variables first: ```scala def getExprValues(block: Block): Set[ExprValue] = block match { case c: CodeBlock => c.blockInputs.collect { case e: ExprValue => e }.toSet case _ => Set.empty } def currentCodegenInputs(ctx: CodegenContext): Set[ExprValue] = { // Collects current variables in ctx.currentVars and ctx.INPUT_ROW. // It looks roughly like... ctx.currentVars.flatMap { v => getExprValues(v.code) ++ Set(v.value, v.isNull) }.toSet + ctx.INPUT_ROW } // A code block of an expression contains too long code, making it as method if (eval.code.length > 1024) { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { ... } else { "" } // Pick up variables and statements necessary to pass in. val currentVars = currentCodegenInputs(ctx) val varsPassIn = getExprValues(eval.code).intersect(currentVars) val aliasedExprs = HashMap.empty[SimpleExprValue, VariableValue] // Replace statement-based expressions which can't be directly passed in the method. val newCode = eval.code.transform { case block => block.transformExprValues { case s: SimpleExprValue(_, javaType) if varsPassIn.contains(s) => if (aliasedExprs.contains(s)) { aliasedExprs(s) } else { val aliasedVariable = JavaCode.variable(ctx.freshName("aliasedVar"), javaType) aliasedExprs += s -> aliasedVariable varsPassIn += aliasedVariable aliasedVariable } } } val params = varsPassIn.filter(!_.isInstanceOf[SimpleExprValue])).map { variable => s"${variable.javaType.getName} ${variable.variableName}" }.mkString(", ") val funcName = ctx.freshName("nodeName") val javaType = CodeGenerator.javaType(dataType) val newValue = JavaCode.variable(ctx.freshName("value"), dataType) val funcFullName = ctx.addNewFunction(funcName, s""" |private $javaType $funcName($params) { | $newCode | $setIsNull | return ${eval.value}; |} """.stripMargin)) eval.value = newValue val args = varsPassIn.filter(!_.isInstanceOf[SimpleExprValue])).map { variable => s"${variable.variableName}" } // Create a code block to assign statements to aliased variables. val createVariables = aliasedExprs.foldLeft(EmptyBlock) { (block, (statement, variable)) => block + code"${statement.javaType.getName} $variable = $statement;" } eval.code = createVariables + code"$javaType $newValue = $funcFullName($args);" } ``` ## How was this patch tested? Added unite tests. Author: Liang-Chi Hsieh Closes #21405 from viirya/codeblock-api. --- .../expressions/codegen/javaCode.scala | 48 +++++++++--- .../expressions/codegen/CodeBlockSuite.scala | 75 +++++++++++++++++-- 2 files changed, 104 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 44f63e21e93bb..2f8c853e836ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -22,6 +22,7 @@ import java.lang.{Boolean => JBool} import scala.collection.mutable.ArrayBuffer import scala.language.{existentials, implicitConversions} +import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{BooleanType, DataType} /** @@ -118,12 +119,9 @@ object JavaCode { /** * A trait representing a block of java code. */ -trait Block extends JavaCode { +trait Block extends TreeNode[Block] with JavaCode { import Block._ - // The expressions to be evaluated inside this block. - def exprValues: Set[ExprValue] - // Returns java code string for this code block. override def toString: String = _marginChar match { case Some(c) => code.stripMargin(c).trim @@ -148,11 +146,41 @@ trait Block extends JavaCode { this } + /** + * Apply a map function to each java expression codes present in this java code, and return a new + * java code based on the mapped java expression codes. + */ + def transformExprValues(f: PartialFunction[ExprValue, ExprValue]): this.type = { + var changed = false + + @inline def transform(e: ExprValue): ExprValue = { + val newE = f lift e + if (!newE.isDefined || newE.get.equals(e)) { + e + } else { + changed = true + newE.get + } + } + + def doTransform(arg: Any): AnyRef = arg match { + case e: ExprValue => transform(e) + case Some(value) => Some(doTransform(value)) + case seq: Traversable[_] => seq.map(doTransform) + case other: AnyRef => other + } + + val newArgs = mapProductIterator(doTransform) + if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this + } + // Concatenates this block with other block. def + (other: Block): Block = other match { case EmptyBlock => this case _ => code"$this\n$other" } + + override def verboseString: String = toString } object Block { @@ -219,12 +247,8 @@ object Block { * method splitting. */ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block { - override lazy val exprValues: Set[ExprValue] = { - blockInputs.flatMap { - case b: Block => b.exprValues - case e: ExprValue => Set(e) - }.toSet - } + override def children: Seq[Block] = + blockInputs.filter(_.isInstanceOf[Block]).asInstanceOf[Seq[Block]] override lazy val code: String = { val strings = codeParts.iterator @@ -239,9 +263,9 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends } } -object EmptyBlock extends Block with Serializable { +case object EmptyBlock extends Block with Serializable { override val code: String = "" - override val exprValues: Set[ExprValue] = Set.empty + override def children: Seq[Block] = Seq.empty } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala index d2c6420eadb20..55569b6f2933e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -65,7 +65,9 @@ class CodeBlockSuite extends SparkFunSuite { |boolean $isNull = false; |int $value = -1; """.stripMargin - val exprValues = code.exprValues + val exprValues = code.asInstanceOf[CodeBlock].blockInputs.collect { + case e: ExprValue => e + }.toSet assert(exprValues.size == 2) assert(exprValues === Set(value, isNull)) } @@ -94,7 +96,9 @@ class CodeBlockSuite extends SparkFunSuite { assert(code.toString == expected) - val exprValues = code.exprValues + val exprValues = code.children.flatMap(_.asInstanceOf[CodeBlock].blockInputs.collect { + case e: ExprValue => e + }).toSet assert(exprValues.size == 5) assert(exprValues === Set(isNull1, value1, isNull2, value2, literal)) } @@ -107,7 +111,7 @@ class CodeBlockSuite extends SparkFunSuite { assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}")) } - test("replace expr values in code block") { + test("transform expr in code block") { val expr = JavaCode.expression("1 + 1", IntegerType) val isNull = JavaCode.isNullVariable("expr1_isNull") val exprInFunc = JavaCode.variable("expr1", IntegerType) @@ -120,11 +124,11 @@ class CodeBlockSuite extends SparkFunSuite { |}""".stripMargin val aliasedParam = JavaCode.variable("aliased", expr.javaType) - val aliasedInputs = code.asInstanceOf[CodeBlock].blockInputs.map { - case _: SimpleExprValue => aliasedParam - case other => other + + // We want to replace all occurrences of `expr` with the variable `aliasedParam`. + val aliasedCode = code.transformExprValues { + case SimpleExprValue("1 + 1", java.lang.Integer.TYPE) => aliasedParam } - val aliasedCode = CodeBlock(code.asInstanceOf[CodeBlock].codeParts, aliasedInputs).stripMargin val expected = code""" |callFunc(int $aliasedParam) { @@ -133,4 +137,61 @@ class CodeBlockSuite extends SparkFunSuite { |}""".stripMargin assert(aliasedCode.toString == expected.toString) } + + test ("transform expr in nested blocks") { + val expr = JavaCode.expression("1 + 1", IntegerType) + val isNull = JavaCode.isNullVariable("expr1_isNull") + val exprInFunc = JavaCode.variable("expr1", IntegerType) + + val funcs = Seq("callFunc1", "callFunc2", "callFunc3") + val subBlocks = funcs.map { funcName => + code""" + |$funcName(int $expr) { + | boolean $isNull = false; + | int $exprInFunc = $expr + 1; + |}""".stripMargin + } + + val aliasedParam = JavaCode.variable("aliased", expr.javaType) + + val block = code"${subBlocks(0)}\n${subBlocks(1)}\n${subBlocks(2)}" + val transformedBlock = block.transform { + case b: Block => b.transformExprValues { + case SimpleExprValue("1 + 1", java.lang.Integer.TYPE) => aliasedParam + } + }.asInstanceOf[CodeBlock] + + val expected1 = + code""" + |callFunc1(int aliased) { + | boolean expr1_isNull = false; + | int expr1 = aliased + 1; + |}""".stripMargin + + val expected2 = + code""" + |callFunc2(int aliased) { + | boolean expr1_isNull = false; + | int expr1 = aliased + 1; + |}""".stripMargin + + val expected3 = + code""" + |callFunc3(int aliased) { + | boolean expr1_isNull = false; + | int expr1 = aliased + 1; + |}""".stripMargin + + val exprValues = transformedBlock.children.flatMap { block => + block.asInstanceOf[CodeBlock].blockInputs.collect { + case e: ExprValue => e + } + }.toSet + + assert(transformedBlock.children(0).toString == expected1.toString) + assert(transformedBlock.children(1).toString == expected2.toString) + assert(transformedBlock.children(2).toString == expected3.toString) + assert(transformedBlock.toString == (expected1 + expected2 + expected3).toString) + assert(exprValues === Set(isNull, exprInFunc, aliasedParam)) + } } From e58dadb77ed6cac3e1b2a037a6449e5a6e7f2cec Mon Sep 17 00:00:00 2001 From: Michael Mior Date: Thu, 5 Jul 2018 08:32:20 -0500 Subject: [PATCH 1067/2461] [SPARK-23820][CORE] Enable use of long form of callsite in logs This adds an option to event logging to include the long form of the callsite instead of the short form. Author: Michael Mior Closes #21433 from michaelmior/long-callsite. --- .../org/apache/spark/internal/config/package.scala | 3 +++ .../main/scala/org/apache/spark/storage/RDDInfo.scala | 10 +++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 38a043c85ae33..bda9795a0b925 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -72,6 +72,9 @@ package object config { private[spark] val EVENT_LOG_OVERWRITE = ConfigBuilder("spark.eventLog.overwrite").booleanConf.createWithDefault(false) + private[spark] val EVENT_LOG_CALLSITE_FORM = + ConfigBuilder("spark.eventLog.callsite").stringConf.createWithDefault("short") + private[spark] val EXECUTOR_CLASS_PATH = ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_CLASSPATH).stringConf.createOptional diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index e5abbf745cc41..9ccc8f9cc585b 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -17,7 +17,9 @@ package org.apache.spark.storage +import org.apache.spark.SparkEnv import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.config._ import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.util.Utils @@ -53,10 +55,16 @@ class RDDInfo( } private[spark] object RDDInfo { + private val callsiteForm = SparkEnv.get.conf.get(EVENT_LOG_CALLSITE_FORM) + def fromRdd(rdd: RDD[_]): RDDInfo = { val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd)) val parentIds = rdd.dependencies.map(_.rdd.id) + val callSite = callsiteForm match { + case "short" => rdd.creationSite.shortForm + case "long" => rdd.creationSite.longForm + } new RDDInfo(rdd.id, rddName, rdd.partitions.length, - rdd.getStorageLevel, parentIds, rdd.creationSite.shortForm, rdd.scope) + rdd.getStorageLevel, parentIds, callSite, rdd.scope) } } From 7bd6d5412072643f2320fd389f323cfc51368c81 Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Thu, 5 Jul 2018 08:38:26 -0500 Subject: [PATCH 1068/2461] [SPARK-24711][K8S] Fix tags for integration tests ## What changes were proposed in this pull request? - disables maven surfire plugin to allow tags function properly, doc here: http://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin ## How was this patch tested? Manually by adding tags. Author: Stavros Kontopoulos Closes #21697 from skonto/fix-tags. --- pom.xml | 2 +- .../dev/dev-run-integration-tests.sh | 20 +++++++++++++++++++ .../kubernetes/integration-tests/pom.xml | 11 ++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index ca30f9f12b098..cd567e227f331 100644 --- a/pom.xml +++ b/pom.xml @@ -2122,7 +2122,7 @@ org.apache.maven.plugins maven-surefire-plugin - 2.20.1 + 2.22.0 diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh index ea893fa39eede..3acd0f5cd3349 100755 --- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -27,6 +27,8 @@ IMAGE_TAG="N/A" SPARK_MASTER= NAMESPACE= SERVICE_ACCOUNT= +INCLUDE_TAGS= +EXCLUDE_TAGS= # Parse arguments while (( "$#" )); do @@ -59,6 +61,14 @@ while (( "$#" )); do SERVICE_ACCOUNT="$2" shift ;; + --include-tags) + INCLUDE_TAGS="$2" + shift + ;; + --exclude-tags) + EXCLUDE_TAGS="$2" + shift + ;; *) break ;; @@ -90,4 +100,14 @@ then properties=( ${properties[@]} -Dspark.kubernetes.test.master=$SPARK_MASTER ) fi +if [ -n $EXCLUDE_TAGS ]; +then + properties=( ${properties[@]} -Dtest.exclude.tags=$EXCLUDE_TAGS ) +fi + +if [ -n $INCLUDE_TAGS ]; +then + properties=( ${properties[@]} -Dtest.include.tags=$INCLUDE_TAGS ) +fi + ../../../build/mvn integration-test ${properties[@]} diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 520bda89e034d..6a2fff891098b 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -40,6 +40,7 @@ minikube docker.io/kubespark + jar Spark Project Kubernetes Integration Tests @@ -102,6 +103,15 @@ + + + org.apache.maven.plugins + maven-surefire-plugin + + true + + + @@ -126,6 +136,7 @@ ${spark.kubernetes.test.serviceAccountName} ${test.exclude.tags} + ${test.include.tags} From ac78bcce00ff8ec8e5b7335c2807aa0cd0f5406a Mon Sep 17 00:00:00 2001 From: cluo <0512lc@163.com> Date: Thu, 5 Jul 2018 09:06:25 -0500 Subject: [PATCH 1069/2461] [SPARK-24743][EXAMPLES] Update the JavaDirectKafkaWordCount example to support the new API of kafka ## What changes were proposed in this pull request? Add some required configs for Kafka consumer in JavaDirectKafkaWordCount class. ## How was this patch tested? Manual tests on Local mode. Author: cluo <0512lc@163.com> Closes #21717 from cluo512/SPARK-24743-update-JavaDirectKafkaWordCount. --- .../streaming/JavaDirectKafkaWordCount.java | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java index b6b163fa8b2cd..748bf58f30350 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java @@ -26,7 +26,9 @@ import scala.Tuple2; +import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.serialization.StringDeserializer; import org.apache.spark.SparkConf; import org.apache.spark.streaming.api.java.*; @@ -37,30 +39,33 @@ /** * Consumes messages from one or more topics in Kafka and does wordcount. - * Usage: JavaDirectKafkaWordCount + * Usage: JavaDirectKafkaWordCount * is a list of one or more Kafka brokers + * is a consumer group name to consume from topics * is a list of one or more kafka topics to consume from * * Example: * $ bin/run-example streaming.JavaDirectKafkaWordCount broker1-host:port,broker2-host:port \ - * topic1,topic2 + * consumer-group topic1,topic2 */ public final class JavaDirectKafkaWordCount { private static final Pattern SPACE = Pattern.compile(" "); public static void main(String[] args) throws Exception { - if (args.length < 2) { - System.err.println("Usage: JavaDirectKafkaWordCount \n" + - " is a list of one or more Kafka brokers\n" + - " is a list of one or more kafka topics to consume from\n\n"); + if (args.length < 3) { + System.err.println("Usage: JavaDirectKafkaWordCount \n" + + " is a list of one or more Kafka brokers\n" + + " is a consumer group name to consume from topics\n" + + " is a list of one or more kafka topics to consume from\n\n"); System.exit(1); } StreamingExamples.setStreamingLogLevels(); String brokers = args[0]; - String topics = args[1]; + String groupId = args[1]; + String topics = args[2]; // Create context with a 2 seconds batch interval SparkConf sparkConf = new SparkConf().setAppName("JavaDirectKafkaWordCount"); @@ -68,7 +73,10 @@ public static void main(String[] args) throws Exception { Set topicsSet = new HashSet<>(Arrays.asList(topics.split(","))); Map kafkaParams = new HashMap<>(); - kafkaParams.put("metadata.broker.list", brokers); + kafkaParams.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, brokers); + kafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, groupId); + kafkaParams.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + kafkaParams.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); // Create direct kafka stream with brokers and topics JavaInputDStream> messages = KafkaUtils.createDirectStream( From 33952cfa8182c1e925083e18c63c6152dcc3c8b4 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 5 Jul 2018 09:25:19 -0700 Subject: [PATCH 1070/2461] [SPARK-24675][SQL] Rename table: validate existence of new location MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? If table is renamed to a existing new location, data won't show up. ``` scala> Seq("hello").toDF("a").write.format("parquet").saveAsTable("t") scala> sql("select * from t").show() +-----+ | a| +-----+ |hello| +-----+ scala> sql("alter table t rename to test") res2: org.apache.spark.sql.DataFrame = [] scala> sql("select * from test").show() +---+ | a| +---+ +---+ ``` The file layout is like ``` $ tree test test ├── gabage └── t ├── _SUCCESS └── part-00000-856b0f10-08f1-42d6-9eb3-7719261f3d5e-c000.snappy.parquet ``` In Hive, if the new location exists, the renaming will fail even the location is empty. We should have the same validation in Catalog, in case of unexpected bugs. ## How was this patch tested? New unit test. Author: Gengliang Wang Closes #21655 from gengliangwang/validate_rename_table. --- docs/sql-programming-guide.md | 1 + .../sql/catalyst/catalog/SessionCatalog.scala | 20 +++++++++++++++++++ .../sql/execution/command/DDLSuite.scala | 18 +++++++++++++++++ 3 files changed, 39 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index cd7329b621122..ad23dae7c6b7c 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1850,6 +1850,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, renaming a managed table to existing location is not allowed. An exception is thrown when attempting to rename a managed table to existing location. - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. - Since Spark 2.4, Spark has enabled non-cascading SQL cache invalidation in addition to the traditional cache invalidation mechanism. The non-cascading cache invalidation mechanism allows users to remove a cache without impacting its dependent caches. This new cache invalidation mechanism is used in scenarios where the data of the cache to be removed is still valid, e.g., calling unpersist() on a Dataset, or dropping a temporary view. This allows users to free up memory and keep the desired caches valid at the same time. - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index c390337c03ff5..c26a34528c162 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -619,6 +619,7 @@ class SessionCatalog( requireTableExists(TableIdentifier(oldTableName, Some(db))) requireTableNotExists(TableIdentifier(newTableName, Some(db))) validateName(newTableName) + validateNewLocationOfRename(oldName, newName) externalCatalog.renameTable(db, oldTableName, newTableName) } else { if (newName.database.isDefined) { @@ -1366,4 +1367,23 @@ class SessionCatalog( // copy over temporary views tempViews.foreach(kv => target.tempViews.put(kv._1, kv._2)) } + + /** + * Validate the new locatoin before renaming a managed table, which should be non-existent. + */ + private def validateNewLocationOfRename( + oldName: TableIdentifier, + newName: TableIdentifier): Unit = { + val oldTable = getTableMetadata(oldName) + if (oldTable.tableType == CatalogTableType.MANAGED) { + val databaseLocation = + externalCatalog.getDatabase(oldName.database.getOrElse(currentDb)).locationUri + val newTableLocation = new Path(new Path(databaseLocation), formatTableName(newName.table)) + val fs = newTableLocation.getFileSystem(hadoopConf) + if (fs.exists(newTableLocation)) { + throw new AnalysisException(s"Can not rename the managed table('$oldName')" + + s". The associated location('$newTableLocation') already exists.") + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 3998ceca38b30..270ed7f80197c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -441,6 +441,24 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("rename a managed table with existing empty directory") { + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab2"))) + try { + withTable("tab1") { + sql(s"CREATE TABLE tab1 USING $dataSource AS SELECT 1, 'a'") + tableLoc.mkdir() + val ex = intercept[AnalysisException] { + sql("ALTER TABLE tab1 RENAME TO tab2") + }.getMessage + val expectedMsg = "Can not rename the managed table('`tab1`'). The associated location" + assert(ex.contains(expectedMsg)) + } + } finally { + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) + } + } + private def checkSchemaInCreatedDataSourceTable( path: File, userSpecifiedSchema: Option[String], From e71e93aaaa0d26301e10d3dc65f4db298424e99a Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Thu, 5 Jul 2018 16:35:16 -0500 Subject: [PATCH 1071/2461] [SPARK-24694][K8S] Pass all app args to integration tests ## What changes were proposed in this pull request? - Allows to pass more than one app args to tests. ## How was this patch tested? Manually tested it with a spark test that requires more than on app args. Author: Stavros Kontopoulos Closes #21672 from skonto/fix_itsets-args. --- .../k8s/integrationtest/KubernetesTestComponents.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala index 48727142dd052..b2471e51116cb 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala @@ -105,16 +105,13 @@ private[spark] object SparkAppLauncher extends Logging { sparkHomeDir: Path): Unit = { val sparkSubmitExecutable = sparkHomeDir.resolve(Paths.get("bin", "spark-submit")) logInfo(s"Launching a spark app with arguments $appArguments and conf $appConf") - val appArgsArray = - if (appArguments.appArgs.length > 0) Array(appArguments.appArgs.mkString(" ")) - else Array[String]() val commandLine = (Array(sparkSubmitExecutable.toFile.getAbsolutePath, "--deploy-mode", "cluster", "--class", appArguments.mainClass, "--master", appConf.get("spark.master") ) ++ appConf.toStringArray :+ appArguments.mainAppResource) ++ - appArgsArray + appArguments.appArgs ProcessUtils.executeProcess(commandLine, timeoutSecs) } } From 01fcba2c685be0603a404392685e9d52fb4cb82a Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 6 Jul 2018 11:10:50 +0800 Subject: [PATCH 1072/2461] [SPARK-24737][SQL] Type coercion between StructTypes. ## What changes were proposed in this pull request? We can support type coercion between `StructType`s where all the internal types are compatible. ## How was this patch tested? Added tests. Author: Takuya UESHIN Closes #21713 from ueshin/issues/SPARK-24737/structtypecoercion. --- .../sql/catalyst/analysis/TypeCoercion.scala | 69 ++++++-------- .../catalyst/analysis/TypeCoercionSuite.scala | 93 +++++++++++++++++-- 2 files changed, 114 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index cf90e6e555fc8..b6ca30c7398f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -102,25 +102,7 @@ object TypeCoercion { case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => Some(TimestampType) - case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) => - Some(StructType(fields1.zip(fields2).map { case (f1, f2) => - // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType - // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. - // - Different names: use f1.name - // - Different nullabilities: `nullable` is true iff one of them is nullable. - val dataType = findTightestCommonType(f1.dataType, f2.dataType).get - StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) - })) - - case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) => - findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2)) - - case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) => - val keyType = findTightestCommonType(kt1, kt2) - val valueType = findTightestCommonType(vt1, vt2) - Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2)) - - case _ => None + case (t1, t2) => findTypeForComplex(t1, t2, findTightestCommonType) } /** Promotes all the way to StringType. */ @@ -166,6 +148,30 @@ object TypeCoercion { case (l, r) => None } + private def findTypeForComplex( + t1: DataType, + t2: DataType, + findTypeFunc: (DataType, DataType) => Option[DataType]): Option[DataType] = (t1, t2) match { + case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => + findTypeFunc(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) + case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) => + findTypeFunc(kt1, kt2).flatMap { kt => + findTypeFunc(vt1, vt2).map { vt => + MapType(kt, vt, valueContainsNull1 || valueContainsNull2) + } + } + case (StructType(fields1), StructType(fields2)) if fields1.length == fields2.length => + val resolver = SQLConf.get.resolver + fields1.zip(fields2).foldLeft(Option(new StructType())) { + case (Some(struct), (field1, field2)) if resolver(field1.name, field2.name) => + findTypeFunc(field1.dataType, field2.dataType).map { + dt => struct.add(field1.name, dt, field1.nullable || field2.nullable) + } + case _ => None + } + case _ => None + } + /** * Case 2 type widening (see the classdoc comment above for TypeCoercion). * @@ -176,17 +182,7 @@ object TypeCoercion { findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse(stringPromotion(t1, t2)) - .orElse((t1, t2) match { - case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) - case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) => - findWiderTypeForTwo(kt1, kt2).flatMap { kt => - findWiderTypeForTwo(vt1, vt2).map { vt => - MapType(kt, vt, valueContainsNull1 || valueContainsNull2) - } - } - case _ => None - }) + .orElse(findTypeForComplex(t1, t2, findWiderTypeForTwo)) } /** @@ -222,18 +218,7 @@ object TypeCoercion { t2: DataType): Option[DataType] = { findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) - .orElse((t1, t2) match { - case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeWithoutStringPromotionForTwo(et1, et2) - .map(ArrayType(_, containsNull1 || containsNull2)) - case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) => - findWiderTypeWithoutStringPromotionForTwo(kt1, kt2).flatMap { kt => - findWiderTypeWithoutStringPromotionForTwo(vt1, vt2).map { vt => - MapType(kt, vt, valueContainsNull1 || valueContainsNull2) - } - } - case _ => None - }) + .orElse(findTypeForComplex(t1, t2, findWiderTypeWithoutStringPromotionForTwo)) } def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 4e5ca1b8cdd36..8cc5a23779a2a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -54,7 +54,7 @@ class TypeCoercionSuite extends AnalysisTest { // | NullType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType | IntegerType | // | CalendarIntervalType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | CalendarIntervalType | X | X | X | // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ - // Note: StructType* is castable only when the internal child types also match; otherwise, not castable. + // Note: StructType* is castable when all the internal child types are castable according to the table. // Note: ArrayType* is castable when the element type is castable according to the table. // Note: MapType* is castable when both the key type and the value type are castable according to the table. // scalastyle:on line.size.limit @@ -397,7 +397,7 @@ class TypeCoercionSuite extends AnalysisTest { widenTest( StructType(Seq(StructField("a", IntegerType, nullable = false))), StructType(Seq(StructField("a", DoubleType, nullable = false))), - None) + Some(StructType(Seq(StructField("a", DoubleType, nullable = false))))) widenTest( StructType(Seq(StructField("a", IntegerType, nullable = false))), @@ -454,15 +454,18 @@ class TypeCoercionSuite extends AnalysisTest { def widenTestWithStringPromotion( t1: DataType, t2: DataType, - expected: Option[DataType]): Unit = { - checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected) + expected: Option[DataType], + isSymmetric: Boolean = true): Unit = { + checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected, isSymmetric) } def widenTestWithoutStringPromotion( t1: DataType, t2: DataType, - expected: Option[DataType]): Unit = { - checkWidenType(TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected) + expected: Option[DataType], + isSymmetric: Boolean = true): Unit = { + checkWidenType( + TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected, isSymmetric) } // Decimal @@ -492,6 +495,10 @@ class TypeCoercionSuite extends AnalysisTest { ArrayType(MapType(IntegerType, FloatType), containsNull = false), ArrayType(MapType(LongType, DoubleType), containsNull = false), Some(ArrayType(MapType(LongType, DoubleType), containsNull = false))) + widenTestWithStringPromotion( + ArrayType(new StructType().add("num", ShortType), containsNull = false), + ArrayType(new StructType().add("num", LongType), containsNull = false), + Some(ArrayType(new StructType().add("num", LongType), containsNull = false))) // MapType widenTestWithStringPromotion( @@ -506,6 +513,64 @@ class TypeCoercionSuite extends AnalysisTest { MapType(IntegerType, MapType(ShortType, TimestampType), valueContainsNull = false), MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false), Some(MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false))) + widenTestWithStringPromotion( + MapType(IntegerType, new StructType().add("num", ShortType), valueContainsNull = false), + MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false), + Some(MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false))) + + // StructType + widenTestWithStringPromotion( + new StructType() + .add("num", ShortType, nullable = true).add("ts", StringType, nullable = false), + new StructType() + .add("num", DoubleType, nullable = false).add("ts", TimestampType, nullable = true), + Some(new StructType() + .add("num", DoubleType, nullable = true).add("ts", StringType, nullable = true))) + widenTestWithStringPromotion( + new StructType() + .add("arr", ArrayType(ShortType, containsNull = false), nullable = false), + new StructType() + .add("arr", ArrayType(DoubleType, containsNull = true), nullable = false), + Some(new StructType() + .add("arr", ArrayType(DoubleType, containsNull = true), nullable = false))) + widenTestWithStringPromotion( + new StructType() + .add("map", MapType(ShortType, TimestampType, valueContainsNull = true), nullable = false), + new StructType() + .add("map", MapType(DoubleType, StringType, valueContainsNull = false), nullable = false), + Some(new StructType() + .add("map", MapType(DoubleType, StringType, valueContainsNull = true), nullable = false))) + + widenTestWithStringPromotion( + new StructType().add("num", IntegerType), + new StructType().add("num", LongType).add("str", StringType), + None) + widenTestWithoutStringPromotion( + new StructType().add("num", IntegerType), + new StructType().add("num", LongType).add("str", StringType), + None) + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + widenTestWithStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + None) + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + None) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + widenTestWithStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + Some(new StructType().add("a", LongType)), + isSymmetric = false) + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + Some(new StructType().add("a", LongType)), + isSymmetric = false) + } // Without string promotion widenTestWithoutStringPromotion(IntegerType, StringType, None) @@ -520,6 +585,14 @@ class TypeCoercionSuite extends AnalysisTest { MapType(StringType, IntegerType), MapType(TimestampType, IntegerType), None) widenTestWithoutStringPromotion( MapType(IntegerType, StringType), MapType(IntegerType, TimestampType), None) + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("a", StringType), + None) + widenTestWithoutStringPromotion( + new StructType().add("a", StringType), + new StructType().add("a", IntegerType), + None) // String promotion widenTestWithStringPromotion(IntegerType, StringType, Some(StringType)) @@ -544,6 +617,14 @@ class TypeCoercionSuite extends AnalysisTest { MapType(IntegerType, StringType), MapType(IntegerType, TimestampType), Some(MapType(IntegerType, StringType))) + widenTestWithStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("a", StringType), + Some(new StructType().add("a", StringType))) + widenTestWithStringPromotion( + new StructType().add("a", StringType), + new StructType().add("a", IntegerType), + Some(new StructType().add("a", StringType))) } private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) { From bf67f70c48881ee99751f7d51fbcbda1e593d90a Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 6 Jul 2018 11:13:57 +0800 Subject: [PATCH 1073/2461] [SPARK-24692][TESTS] Improvement FilterPushdownBenchmark ## What changes were proposed in this pull request? Refer to the [`WideSchemaBenchmark`](https://github.com/apache/spark/blob/v2.3.1/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala) update `FilterPushdownBenchmark`: 1. Write the result to `benchmarks/FilterPushdownBenchmark-results.txt` for easy maintenance. 2. Add more benchmark case: `StringStartsWith`, `Decimal`, `InSet -> InFilters` and `tinyint`. ## How was this patch tested? manual tests Author: Yuming Wang Closes #21677 from wangyum/SPARK-24692. --- .../FilterPushdownBenchmark-results.txt | 580 ++++++++++++++++++ .../benchmark/FilterPushdownBenchmark.scala | 405 +++++------- 2 files changed, 748 insertions(+), 237 deletions(-) create mode 100644 sql/core/benchmarks/FilterPushdownBenchmark-results.txt diff --git a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt new file mode 100644 index 0000000000000..29fe4345d69da --- /dev/null +++ b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt @@ -0,0 +1,580 @@ +================================================================================================ +Pushdown for many distinct value case +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 string row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8970 / 9122 1.8 570.3 1.0X +Parquet Vectorized (Pushdown) 471 / 491 33.4 30.0 19.0X +Native ORC Vectorized 7661 / 7853 2.1 487.0 1.2X +Native ORC Vectorized (Pushdown) 1134 / 1161 13.9 72.1 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 string row ('7864320' < value < '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9246 / 9297 1.7 587.8 1.0X +Parquet Vectorized (Pushdown) 480 / 488 32.8 30.5 19.3X +Native ORC Vectorized 7838 / 7850 2.0 498.3 1.2X +Native ORC Vectorized (Pushdown) 1054 / 1118 14.9 67.0 8.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 string row (value = '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8989 / 9100 1.7 571.5 1.0X +Parquet Vectorized (Pushdown) 448 / 467 35.1 28.5 20.1X +Native ORC Vectorized 7680 / 7768 2.0 488.3 1.2X +Native ORC Vectorized (Pushdown) 1067 / 1118 14.7 67.8 8.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 string row (value <=> '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9115 / 9266 1.7 579.5 1.0X +Parquet Vectorized (Pushdown) 466 / 492 33.7 29.7 19.5X +Native ORC Vectorized 7800 / 7914 2.0 495.9 1.2X +Native ORC Vectorized (Pushdown) 1075 / 1102 14.6 68.4 8.5X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 string row ('7864320' <= value <= '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9099 / 9237 1.7 578.5 1.0X +Parquet Vectorized (Pushdown) 462 / 475 34.1 29.3 19.7X +Native ORC Vectorized 7847 / 7925 2.0 498.9 1.2X +Native ORC Vectorized (Pushdown) 1078 / 1114 14.6 68.5 8.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all string rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 19303 / 19547 0.8 1227.3 1.0X +Parquet Vectorized (Pushdown) 19924 / 20089 0.8 1266.7 1.0X +Native ORC Vectorized 18725 / 19079 0.8 1190.5 1.0X +Native ORC Vectorized (Pushdown) 19310 / 19492 0.8 1227.7 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 int row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8117 / 8323 1.9 516.1 1.0X +Parquet Vectorized (Pushdown) 484 / 494 32.5 30.8 16.8X +Native ORC Vectorized 6811 / 7036 2.3 433.0 1.2X +Native ORC Vectorized (Pushdown) 1061 / 1082 14.8 67.5 7.6X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 int row (7864320 < value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8105 / 8140 1.9 515.3 1.0X +Parquet Vectorized (Pushdown) 478 / 505 32.9 30.4 17.0X +Native ORC Vectorized 6914 / 7211 2.3 439.6 1.2X +Native ORC Vectorized (Pushdown) 1044 / 1064 15.1 66.4 7.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7983 / 8116 2.0 507.6 1.0X +Parquet Vectorized (Pushdown) 464 / 487 33.9 29.5 17.2X +Native ORC Vectorized 6703 / 6774 2.3 426.1 1.2X +Native ORC Vectorized (Pushdown) 1017 / 1058 15.5 64.6 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (value <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7942 / 7983 2.0 504.9 1.0X +Parquet Vectorized (Pushdown) 468 / 479 33.6 29.7 17.0X +Native ORC Vectorized 6677 / 6779 2.4 424.5 1.2X +Native ORC Vectorized (Pushdown) 1021 / 1068 15.4 64.9 7.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (7864320 <= value <= 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7909 / 7958 2.0 502.8 1.0X +Parquet Vectorized (Pushdown) 485 / 494 32.4 30.8 16.3X +Native ORC Vectorized 6751 / 6846 2.3 429.2 1.2X +Native ORC Vectorized (Pushdown) 1043 / 1077 15.1 66.3 7.6X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (7864319 < value < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8010 / 8033 2.0 509.2 1.0X +Parquet Vectorized (Pushdown) 472 / 489 33.3 30.0 17.0X +Native ORC Vectorized 6655 / 6808 2.4 423.1 1.2X +Native ORC Vectorized (Pushdown) 1015 / 1067 15.5 64.5 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% int rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8983 / 9035 1.8 571.1 1.0X +Parquet Vectorized (Pushdown) 2204 / 2231 7.1 140.1 4.1X +Native ORC Vectorized 7864 / 8011 2.0 500.0 1.1X +Native ORC Vectorized (Pushdown) 2674 / 2789 5.9 170.0 3.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% int rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12723 / 12903 1.2 808.9 1.0X +Parquet Vectorized (Pushdown) 9112 / 9282 1.7 579.3 1.4X +Native ORC Vectorized 12090 / 12230 1.3 768.7 1.1X +Native ORC Vectorized (Pushdown) 9242 / 9372 1.7 587.6 1.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% int rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 16453 / 16678 1.0 1046.1 1.0X +Parquet Vectorized (Pushdown) 15997 / 16262 1.0 1017.0 1.0X +Native ORC Vectorized 16652 / 17070 0.9 1058.7 1.0X +Native ORC Vectorized (Pushdown) 15843 / 16112 1.0 1007.2 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all int rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 17098 / 17254 0.9 1087.1 1.0X +Parquet Vectorized (Pushdown) 17302 / 17529 0.9 1100.1 1.0X +Native ORC Vectorized 16790 / 17098 0.9 1067.5 1.0X +Native ORC Vectorized (Pushdown) 17329 / 17914 0.9 1101.7 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all int rows (value > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 17088 / 17392 0.9 1086.4 1.0X +Parquet Vectorized (Pushdown) 17609 / 17863 0.9 1119.5 1.0X +Native ORC Vectorized 18334 / 69831 0.9 1165.7 0.9X +Native ORC Vectorized (Pushdown) 17465 / 17629 0.9 1110.4 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all int rows (value != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 16903 / 17233 0.9 1074.6 1.0X +Parquet Vectorized (Pushdown) 16945 / 17032 0.9 1077.3 1.0X +Native ORC Vectorized 16377 / 16762 1.0 1041.2 1.0X +Native ORC Vectorized (Pushdown) 16950 / 17212 0.9 1077.7 1.0X + + +================================================================================================ +Pushdown for few distinct value case (use dictionary encoding) +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 distinct string row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7245 / 7322 2.2 460.7 1.0X +Parquet Vectorized (Pushdown) 378 / 389 41.6 24.0 19.2X +Native ORC Vectorized 6720 / 6778 2.3 427.2 1.1X +Native ORC Vectorized (Pushdown) 1009 / 1032 15.6 64.2 7.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 distinct string row ('100' < value < '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7627 / 7795 2.1 484.9 1.0X +Parquet Vectorized (Pushdown) 384 / 406 41.0 24.4 19.9X +Native ORC Vectorized 6724 / 7824 2.3 427.5 1.1X +Native ORC Vectorized (Pushdown) 968 / 986 16.3 61.5 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 distinct string row (value = '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7157 / 7534 2.2 455.0 1.0X +Parquet Vectorized (Pushdown) 542 / 565 29.0 34.5 13.2X +Native ORC Vectorized 6716 / 7214 2.3 427.0 1.1X +Native ORC Vectorized (Pushdown) 1212 / 1288 13.0 77.0 5.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 distinct string row (value <=> '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7368 / 7552 2.1 468.4 1.0X +Parquet Vectorized (Pushdown) 544 / 556 28.9 34.6 13.5X +Native ORC Vectorized 6740 / 6867 2.3 428.5 1.1X +Native ORC Vectorized (Pushdown) 1230 / 1426 12.8 78.2 6.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 distinct string row ('100' <= value <= '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7427 / 7734 2.1 472.2 1.0X +Parquet Vectorized (Pushdown) 556 / 568 28.3 35.4 13.3X +Native ORC Vectorized 6847 / 7059 2.3 435.3 1.1X +Native ORC Vectorized (Pushdown) 1226 / 1230 12.8 77.9 6.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all distinct string rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 16998 / 17311 0.9 1080.7 1.0X +Parquet Vectorized (Pushdown) 16977 / 17250 0.9 1079.4 1.0X +Native ORC Vectorized 18447 / 19852 0.9 1172.8 0.9X +Native ORC Vectorized (Pushdown) 16614 / 17102 0.9 1056.3 1.0X + + +================================================================================================ +Pushdown benchmark for StringStartsWith +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +StringStartsWith filter: (value like '10%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9705 / 10814 1.6 617.0 1.0X +Parquet Vectorized (Pushdown) 3086 / 3574 5.1 196.2 3.1X +Native ORC Vectorized 10094 / 10695 1.6 641.8 1.0X +Native ORC Vectorized (Pushdown) 9611 / 9999 1.6 611.0 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +StringStartsWith filter: (value like '1000%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8016 / 8183 2.0 509.7 1.0X +Parquet Vectorized (Pushdown) 444 / 457 35.4 28.2 18.0X +Native ORC Vectorized 6970 / 7169 2.3 443.2 1.2X +Native ORC Vectorized (Pushdown) 7447 / 7503 2.1 473.5 1.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +StringStartsWith filter: (value like '786432%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7908 / 8046 2.0 502.8 1.0X +Parquet Vectorized (Pushdown) 408 / 429 38.6 25.9 19.4X +Native ORC Vectorized 7021 / 7100 2.2 446.4 1.1X +Native ORC Vectorized (Pushdown) 7310 / 7490 2.2 464.8 1.1X + + +================================================================================================ +Pushdown benchmark for decimal +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 decimal(9, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 3785 / 3867 4.2 240.6 1.0X +Parquet Vectorized (Pushdown) 3820 / 3928 4.1 242.9 1.0X +Native ORC Vectorized 3981 / 4049 4.0 253.1 1.0X +Native ORC Vectorized (Pushdown) 702 / 735 22.4 44.6 5.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% decimal(9, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4694 / 4813 3.4 298.4 1.0X +Parquet Vectorized (Pushdown) 4839 / 4907 3.3 307.6 1.0X +Native ORC Vectorized 4943 / 5032 3.2 314.2 0.9X +Native ORC Vectorized (Pushdown) 2043 / 2085 7.7 129.9 2.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% decimal(9, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8321 / 8472 1.9 529.0 1.0X +Parquet Vectorized (Pushdown) 8125 / 8471 1.9 516.6 1.0X +Native ORC Vectorized 8524 / 8616 1.8 541.9 1.0X +Native ORC Vectorized (Pushdown) 7961 / 8383 2.0 506.1 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% decimal(9, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9587 / 10112 1.6 609.5 1.0X +Parquet Vectorized (Pushdown) 9726 / 10370 1.6 618.3 1.0X +Native ORC Vectorized 10119 / 11147 1.6 643.4 0.9X +Native ORC Vectorized (Pushdown) 9366 / 9497 1.7 595.5 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 decimal(18, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4060 / 4093 3.9 258.1 1.0X +Parquet Vectorized (Pushdown) 4037 / 4125 3.9 256.6 1.0X +Native ORC Vectorized 4756 / 4811 3.3 302.4 0.9X +Native ORC Vectorized (Pushdown) 824 / 889 19.1 52.4 4.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% decimal(18, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5157 / 5271 3.0 327.9 1.0X +Parquet Vectorized (Pushdown) 5051 / 5141 3.1 321.1 1.0X +Native ORC Vectorized 5723 / 6146 2.7 363.9 0.9X +Native ORC Vectorized (Pushdown) 2198 / 2317 7.2 139.8 2.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% decimal(18, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8608 / 8647 1.8 547.3 1.0X +Parquet Vectorized (Pushdown) 8471 / 8584 1.9 538.6 1.0X +Native ORC Vectorized 9249 / 10048 1.7 588.0 0.9X +Native ORC Vectorized (Pushdown) 7645 / 8091 2.1 486.1 1.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% decimal(18, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 11658 / 11888 1.3 741.2 1.0X +Parquet Vectorized (Pushdown) 11812 / 12098 1.3 751.0 1.0X +Native ORC Vectorized 12943 / 13312 1.2 822.9 0.9X +Native ORC Vectorized (Pushdown) 13139 / 13465 1.2 835.4 0.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 decimal(38, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5491 / 5716 2.9 349.1 1.0X +Parquet Vectorized (Pushdown) 5515 / 5615 2.9 350.6 1.0X +Native ORC Vectorized 4582 / 4654 3.4 291.3 1.2X +Native ORC Vectorized (Pushdown) 815 / 861 19.3 51.8 6.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% decimal(38, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 6432 / 6527 2.4 409.0 1.0X +Parquet Vectorized (Pushdown) 6513 / 6607 2.4 414.1 1.0X +Native ORC Vectorized 5618 / 6085 2.8 357.2 1.1X +Native ORC Vectorized (Pushdown) 2403 / 2443 6.5 152.8 2.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% decimal(38, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 11041 / 11467 1.4 701.9 1.0X +Parquet Vectorized (Pushdown) 10909 / 11484 1.4 693.5 1.0X +Native ORC Vectorized 9860 / 10436 1.6 626.9 1.1X +Native ORC Vectorized (Pushdown) 7908 / 8069 2.0 502.8 1.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% decimal(38, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 14816 / 16877 1.1 942.0 1.0X +Parquet Vectorized (Pushdown) 15383 / 15740 1.0 978.0 1.0X +Native ORC Vectorized 14408 / 14771 1.1 916.0 1.0X +Native ORC Vectorized (Pushdown) 13968 / 14805 1.1 888.1 1.1X + + +================================================================================================ +Pushdown benchmark for InSet -> InFilters +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 5, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7477 / 7587 2.1 475.4 1.0X +Parquet Vectorized (Pushdown) 7862 / 8346 2.0 499.9 1.0X +Native ORC Vectorized 6447 / 7021 2.4 409.9 1.2X +Native ORC Vectorized (Pushdown) 983 / 1003 16.0 62.5 7.6X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 5, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7107 / 7290 2.2 451.9 1.0X +Parquet Vectorized (Pushdown) 7196 / 7258 2.2 457.5 1.0X +Native ORC Vectorized 6102 / 6222 2.6 388.0 1.2X +Native ORC Vectorized (Pushdown) 926 / 958 17.0 58.9 7.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 5, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7374 / 7692 2.1 468.8 1.0X +Parquet Vectorized (Pushdown) 7771 / 7848 2.0 494.1 0.9X +Native ORC Vectorized 6184 / 6356 2.5 393.2 1.2X +Native ORC Vectorized (Pushdown) 920 / 963 17.1 58.5 8.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 10, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7073 / 7326 2.2 449.7 1.0X +Parquet Vectorized (Pushdown) 7304 / 7647 2.2 464.4 1.0X +Native ORC Vectorized 6222 / 6579 2.5 395.6 1.1X +Native ORC Vectorized (Pushdown) 958 / 994 16.4 60.9 7.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 10, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7121 / 7501 2.2 452.7 1.0X +Parquet Vectorized (Pushdown) 7751 / 8334 2.0 492.8 0.9X +Native ORC Vectorized 6225 / 6680 2.5 395.8 1.1X +Native ORC Vectorized (Pushdown) 998 / 1020 15.8 63.5 7.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 10, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7157 / 7399 2.2 455.1 1.0X +Parquet Vectorized (Pushdown) 7806 / 7911 2.0 496.3 0.9X +Native ORC Vectorized 6548 / 6720 2.4 416.3 1.1X +Native ORC Vectorized (Pushdown) 1016 / 1050 15.5 64.6 7.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 50, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7662 / 7805 2.1 487.1 1.0X +Parquet Vectorized (Pushdown) 7590 / 7861 2.1 482.5 1.0X +Native ORC Vectorized 6840 / 8073 2.3 434.9 1.1X +Native ORC Vectorized (Pushdown) 1041 / 1075 15.1 66.2 7.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 50, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8230 / 9266 1.9 523.2 1.0X +Parquet Vectorized (Pushdown) 7735 / 7960 2.0 491.8 1.1X +Native ORC Vectorized 6945 / 7109 2.3 441.6 1.2X +Native ORC Vectorized (Pushdown) 1123 / 1144 14.0 71.4 7.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 50, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7656 / 8058 2.1 486.7 1.0X +Parquet Vectorized (Pushdown) 7860 / 8247 2.0 499.7 1.0X +Native ORC Vectorized 6684 / 7003 2.4 424.9 1.1X +Native ORC Vectorized (Pushdown) 1085 / 1172 14.5 69.0 7.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 100, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7594 / 8128 2.1 482.8 1.0X +Parquet Vectorized (Pushdown) 7845 / 7923 2.0 498.8 1.0X +Native ORC Vectorized 5859 / 6421 2.7 372.5 1.3X +Native ORC Vectorized (Pushdown) 1037 / 1054 15.2 66.0 7.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 100, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 6762 / 6775 2.3 429.9 1.0X +Parquet Vectorized (Pushdown) 6911 / 6970 2.3 439.4 1.0X +Native ORC Vectorized 5884 / 5960 2.7 374.1 1.1X +Native ORC Vectorized (Pushdown) 1028 / 1052 15.3 65.4 6.6X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 100, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 6718 / 6767 2.3 427.1 1.0X +Parquet Vectorized (Pushdown) 6812 / 6909 2.3 433.1 1.0X +Native ORC Vectorized 5842 / 5883 2.7 371.4 1.1X +Native ORC Vectorized (Pushdown) 1040 / 1058 15.1 66.1 6.5X + + +================================================================================================ +Pushdown benchmark for tinyint +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 tinyint row (value = CAST(63 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 3726 / 3775 4.2 236.9 1.0X +Parquet Vectorized (Pushdown) 3741 / 3789 4.2 237.9 1.0X +Native ORC Vectorized 2793 / 2909 5.6 177.6 1.3X +Native ORC Vectorized (Pushdown) 530 / 561 29.7 33.7 7.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% tinyint rows (value < CAST(12 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4385 / 4406 3.6 278.8 1.0X +Parquet Vectorized (Pushdown) 4398 / 4454 3.6 279.6 1.0X +Native ORC Vectorized 3420 / 3501 4.6 217.4 1.3X +Native ORC Vectorized (Pushdown) 1395 / 1432 11.3 88.7 3.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% tinyint rows (value < CAST(63 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7307 / 7394 2.2 464.6 1.0X +Parquet Vectorized (Pushdown) 7411 / 7461 2.1 471.2 1.0X +Native ORC Vectorized 6501 / 7814 2.4 413.4 1.1X +Native ORC Vectorized (Pushdown) 7341 / 8637 2.1 466.7 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% tinyint rows (value < CAST(114 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 11886 / 13122 1.3 755.7 1.0X +Parquet Vectorized (Pushdown) 12557 / 14173 1.3 798.4 0.9X +Native ORC Vectorized 10758 / 11971 1.5 684.0 1.1X +Native ORC Vectorized (Pushdown) 10564 / 10713 1.5 671.6 1.1X + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala index 6d7c7de9a856e..fc716dec9f337 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala @@ -17,25 +17,30 @@ package org.apache.spark.sql.execution.benchmark -import java.io.File +import java.io.{File, FileOutputStream, OutputStream} import scala.util.{Random, Try} +import org.scalatest.{BeforeAndAfterEachTestData, Suite, TestData} + import org.apache.spark.SparkConf +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.functions.monotonically_increasing_id import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType} import org.apache.spark.util.{Benchmark, Utils} - /** * Benchmark to measure read performance with Filter pushdown. * To run this: - * spark-submit --class + * build/sbt "sql/test-only *FilterPushdownBenchmark" + * + * Results will be written to "benchmarks/FilterPushdownBenchmark-results.txt". */ -object FilterPushdownBenchmark { - val conf = new SparkConf() - .setAppName("FilterPushdownBenchmark") +class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfterEachTest { + private val conf = new SparkConf() + .setAppName(this.getClass.getSimpleName) // Since `spark.master` always exists, overrides this value .set("spark.master", "local[1]") .setIfMissing("spark.driver.memory", "3g") @@ -44,8 +49,40 @@ object FilterPushdownBenchmark { .setIfMissing("orc.compression", "snappy") .setIfMissing("spark.sql.parquet.compression.codec", "snappy") + private val numRows = 1024 * 1024 * 15 + private val width = 5 + private val mid = numRows / 2 + private val blockSize = 1048576 + private val spark = SparkSession.builder().config(conf).getOrCreate() + private var out: OutputStream = _ + + override def beforeAll() { + super.beforeAll() + out = new FileOutputStream(new File("benchmarks/FilterPushdownBenchmark-results.txt")) + } + + override def beforeEach(td: TestData) { + super.beforeEach(td) + val separator = "=" * 96 + val testHeader = (separator + '\n' + td.name + '\n' + separator + '\n' + '\n').getBytes + out.write(testHeader) + } + + override def afterEach(td: TestData) { + out.write('\n') + super.afterEach(td) + } + + override def afterAll() { + try { + out.close() + } finally { + super.afterAll() + } + } + def withTempPath(f: File => Unit): Unit = { val path = Utils.createTempDir() path.delete() @@ -81,8 +118,7 @@ object FilterPushdownBenchmark { .withColumn("value", valueCol) .sort("value") - saveAsOrcTable(df, dir.getCanonicalPath + "/orc") - saveAsParquetTable(df, dir.getCanonicalPath + "/parquet") + saveAsTable(df, dir) } private def prepareStringDictTable( @@ -93,19 +129,22 @@ object FilterPushdownBenchmark { } val df = spark.range(numRows).selectExpr(selectExpr: _*).sort("value") - saveAsOrcTable(df, dir.getCanonicalPath + "/orc") - saveAsParquetTable(df, dir.getCanonicalPath + "/parquet") + saveAsTable(df, dir) } - private def saveAsOrcTable(df: DataFrame, dir: String): Unit = { - // To always turn on dictionary encoding, we set 1.0 at the threshold (the default is 0.8) - df.write.mode("overwrite").option("orc.dictionary.key.threshold", 1.0).orc(dir) - spark.read.orc(dir).createOrReplaceTempView("orcTable") - } + private def saveAsTable(df: DataFrame, dir: File): Unit = { + val orcPath = dir.getCanonicalPath + "/orc" + val parquetPath = dir.getCanonicalPath + "/parquet" - private def saveAsParquetTable(df: DataFrame, dir: String): Unit = { - df.write.mode("overwrite").parquet(dir) - spark.read.parquet(dir).createOrReplaceTempView("parquetTable") + // To always turn on dictionary encoding, we set 1.0 at the threshold (the default is 0.8) + df.write.mode("overwrite") + .option("orc.dictionary.key.threshold", 1.0) + .option("orc.stripe.size", blockSize).orc(orcPath) + spark.read.orc(orcPath).createOrReplaceTempView("orcTable") + + df.write.mode("overwrite") + .option("parquet.block.size", blockSize).parquet(parquetPath) + spark.read.parquet(parquetPath).createOrReplaceTempView("parquetTable") } def filterPushDownBenchmark( @@ -113,7 +152,7 @@ object FilterPushdownBenchmark { title: String, whereExpr: String, selectExpr: String = "*"): Unit = { - val benchmark = new Benchmark(title, values, minNumIters = 5) + val benchmark = new Benchmark(title, values, minNumIters = 5, output = Some(out)) Seq(false, true).foreach { pushDownEnabled => val name = s"Parquet Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" @@ -133,214 +172,6 @@ object FilterPushdownBenchmark { } } - /* - OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 - Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz - Select 0 string row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 9201 / 9300 1.7 585.0 1.0X - Parquet Vectorized (Pushdown) 89 / 105 176.3 5.7 103.1X - Native ORC Vectorized 8886 / 8898 1.8 564.9 1.0X - Native ORC Vectorized (Pushdown) 110 / 128 143.4 7.0 83.9X - - - Select 0 string row - ('7864320' < value < '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 9336 / 9357 1.7 593.6 1.0X - Parquet Vectorized (Pushdown) 927 / 937 17.0 58.9 10.1X - Native ORC Vectorized 9026 / 9041 1.7 573.9 1.0X - Native ORC Vectorized (Pushdown) 257 / 272 61.1 16.4 36.3X - - - Select 1 string row (value = '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 9209 / 9223 1.7 585.5 1.0X - Parquet Vectorized (Pushdown) 908 / 925 17.3 57.7 10.1X - Native ORC Vectorized 8878 / 8904 1.8 564.4 1.0X - Native ORC Vectorized (Pushdown) 248 / 261 63.4 15.8 37.1X - - - Select 1 string row - (value <=> '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 9194 / 9216 1.7 584.5 1.0X - Parquet Vectorized (Pushdown) 899 / 908 17.5 57.2 10.2X - Native ORC Vectorized 8934 / 8962 1.8 568.0 1.0X - Native ORC Vectorized (Pushdown) 249 / 254 63.3 15.8 37.0X - - - Select 1 string row - ('7864320' <= value <= '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 9332 / 9351 1.7 593.3 1.0X - Parquet Vectorized (Pushdown) 915 / 934 17.2 58.2 10.2X - Native ORC Vectorized 9049 / 9057 1.7 575.3 1.0X - Native ORC Vectorized (Pushdown) 248 / 258 63.5 15.8 37.7X - - - Select all string rows - (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 20478 / 20497 0.8 1301.9 1.0X - Parquet Vectorized (Pushdown) 20461 / 20550 0.8 1300.9 1.0X - Native ORC Vectorized 27464 / 27482 0.6 1746.1 0.7X - Native ORC Vectorized (Pushdown) 27454 / 27488 0.6 1745.5 0.7X - - - Select 0 int row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8489 / 8519 1.9 539.7 1.0X - Parquet Vectorized (Pushdown) 64 / 69 246.1 4.1 132.8X - Native ORC Vectorized 8064 / 8099 2.0 512.7 1.1X - Native ORC Vectorized (Pushdown) 88 / 94 178.6 5.6 96.4X - - - Select 0 int row - (7864320 < value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8494 / 8514 1.9 540.0 1.0X - Parquet Vectorized (Pushdown) 835 / 840 18.8 53.1 10.2X - Native ORC Vectorized 8090 / 8106 1.9 514.4 1.0X - Native ORC Vectorized (Pushdown) 249 / 257 63.2 15.8 34.1X - - - Select 1 int row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8552 / 8560 1.8 543.7 1.0X - Parquet Vectorized (Pushdown) 837 / 841 18.8 53.2 10.2X - Native ORC Vectorized 8178 / 8188 1.9 519.9 1.0X - Native ORC Vectorized (Pushdown) 249 / 258 63.2 15.8 34.4X - - - Select 1 int row (value <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8562 / 8580 1.8 544.3 1.0X - Parquet Vectorized (Pushdown) 833 / 836 18.9 53.0 10.3X - Native ORC Vectorized 8164 / 8185 1.9 519.0 1.0X - Native ORC Vectorized (Pushdown) 245 / 254 64.3 15.6 35.0X - - - Select 1 int row - (7864320 <= value <= 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8540 / 8555 1.8 542.9 1.0X - Parquet Vectorized (Pushdown) 837 / 839 18.8 53.2 10.2X - Native ORC Vectorized 8182 / 8231 1.9 520.2 1.0X - Native ORC Vectorized (Pushdown) 250 / 259 62.9 15.9 34.1X - - - Select 1 int row - (7864319 < value < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8535 / 8555 1.8 542.6 1.0X - Parquet Vectorized (Pushdown) 835 / 841 18.8 53.1 10.2X - Native ORC Vectorized 8159 / 8179 1.9 518.8 1.0X - Native ORC Vectorized (Pushdown) 244 / 250 64.5 15.5 35.0X - - - Select 10% int rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 9609 / 9634 1.6 610.9 1.0X - Parquet Vectorized (Pushdown) 2663 / 2672 5.9 169.3 3.6X - Native ORC Vectorized 9824 / 9850 1.6 624.6 1.0X - Native ORC Vectorized (Pushdown) 2717 / 2722 5.8 172.7 3.5X - - - Select 50% int rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 13592 / 13613 1.2 864.2 1.0X - Parquet Vectorized (Pushdown) 9720 / 9738 1.6 618.0 1.4X - Native ORC Vectorized 16366 / 16397 1.0 1040.5 0.8X - Native ORC Vectorized (Pushdown) 12437 / 12459 1.3 790.7 1.1X - - - Select 90% int rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 17580 / 17617 0.9 1117.7 1.0X - Parquet Vectorized (Pushdown) 16803 / 16827 0.9 1068.3 1.0X - Native ORC Vectorized 24169 / 24187 0.7 1536.6 0.7X - Native ORC Vectorized (Pushdown) 22147 / 22341 0.7 1408.1 0.8X - - - Select all int rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 18461 / 18491 0.9 1173.7 1.0X - Parquet Vectorized (Pushdown) 18466 / 18530 0.9 1174.1 1.0X - Native ORC Vectorized 24231 / 24270 0.6 1540.6 0.8X - Native ORC Vectorized (Pushdown) 24207 / 24304 0.6 1539.0 0.8X - - - Select all int rows (value > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 18414 / 18453 0.9 1170.7 1.0X - Parquet Vectorized (Pushdown) 18435 / 18464 0.9 1172.1 1.0X - Native ORC Vectorized 24430 / 24454 0.6 1553.2 0.8X - Native ORC Vectorized (Pushdown) 24410 / 24465 0.6 1552.0 0.8X - - - Select all int rows (value != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 18446 / 18457 0.9 1172.8 1.0X - Parquet Vectorized (Pushdown) 18428 / 18440 0.9 1171.6 1.0X - Native ORC Vectorized 24414 / 24450 0.6 1552.2 0.8X - Native ORC Vectorized (Pushdown) 24385 / 24472 0.6 1550.4 0.8X - - - Select 0 distinct string row - (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8322 / 8352 1.9 529.1 1.0X - Parquet Vectorized (Pushdown) 53 / 57 296.3 3.4 156.7X - Native ORC Vectorized 7903 / 7953 2.0 502.4 1.1X - Native ORC Vectorized (Pushdown) 80 / 82 197.2 5.1 104.3X - - - Select 0 distinct string row - ('100' < value < '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8712 / 8743 1.8 553.9 1.0X - Parquet Vectorized (Pushdown) 995 / 1030 15.8 63.3 8.8X - Native ORC Vectorized 8345 / 8362 1.9 530.6 1.0X - Native ORC Vectorized (Pushdown) 84 / 87 187.6 5.3 103.9X - - - Select 1 distinct string row - (value = '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8574 / 8610 1.8 545.1 1.0X - Parquet Vectorized (Pushdown) 1127 / 1135 14.0 71.6 7.6X - Native ORC Vectorized 8163 / 8181 1.9 519.0 1.1X - Native ORC Vectorized (Pushdown) 426 / 433 36.9 27.1 20.1X - - - Select 1 distinct string row - (value <=> '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8549 / 8568 1.8 543.5 1.0X - Parquet Vectorized (Pushdown) 1124 / 1131 14.0 71.4 7.6X - Native ORC Vectorized 8163 / 8210 1.9 519.0 1.0X - Native ORC Vectorized (Pushdown) 426 / 436 36.9 27.1 20.1X - - - Select 1 distinct string row - ('100' <= value <= '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8889 / 8896 1.8 565.2 1.0X - Parquet Vectorized (Pushdown) 1161 / 1168 13.6 73.8 7.7X - Native ORC Vectorized 8519 / 8554 1.8 541.6 1.0X - Native ORC Vectorized (Pushdown) 430 / 437 36.6 27.3 20.7X - - - Select all distinct string rows - (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 20433 / 20533 0.8 1299.1 1.0X - Parquet Vectorized (Pushdown) 20433 / 20456 0.8 1299.1 1.0X - Native ORC Vectorized 25435 / 25513 0.6 1617.1 0.8X - Native ORC Vectorized (Pushdown) 25435 / 25507 0.6 1617.1 0.8X - */ - benchmark.run() } @@ -408,14 +239,8 @@ object FilterPushdownBenchmark { } } - def main(args: Array[String]): Unit = { - val numRows = 1024 * 1024 * 15 - val width = 5 - - // Pushdown for many distinct value case + ignore("Pushdown for many distinct value case") { withTempPath { dir => - val mid = numRows / 2 - withTempTable("orcTable", "patquetTable") { Seq(true, false).foreach { useStringForValue => prepareTable(dir, numRows, width, useStringForValue) @@ -427,16 +252,122 @@ object FilterPushdownBenchmark { } } } + } - // Pushdown for few distinct value case (use dictionary encoding) + ignore("Pushdown for few distinct value case (use dictionary encoding)") { withTempPath { dir => val numDistinctValues = 200 - val mid = numDistinctValues / 2 withTempTable("orcTable", "patquetTable") { prepareStringDictTable(dir, numRows, numDistinctValues, width) - runStringBenchmark(numRows, width, mid, "distinct string") + runStringBenchmark(numRows, width, numDistinctValues / 2, "distinct string") } } } + + ignore("Pushdown benchmark for StringStartsWith") { + withTempPath { dir => + withTempTable("orcTable", "patquetTable") { + prepareTable(dir, numRows, width, true) + Seq( + "value like '10%'", + "value like '1000%'", + s"value like '${mid.toString.substring(0, mid.toString.length - 1)}%'" + ).foreach { whereExpr => + val title = s"StringStartsWith filter: ($whereExpr)" + filterPushDownBenchmark(numRows, title, whereExpr) + } + } + } + } + + ignore(s"Pushdown benchmark for ${DecimalType.simpleString}") { + withTempPath { dir => + Seq( + s"decimal(${Decimal.MAX_INT_DIGITS}, 2)", + s"decimal(${Decimal.MAX_LONG_DIGITS}, 2)", + s"decimal(${DecimalType.MAX_PRECISION}, 2)" + ).foreach { dt => + val columns = (1 to width).map(i => s"CAST(id AS string) c$i") + val df = spark.range(numRows).selectExpr(columns: _*) + .withColumn("value", monotonically_increasing_id().cast(dt)) + withTempTable("orcTable", "patquetTable") { + saveAsTable(df, dir) + + Seq(s"value = $mid").foreach { whereExpr => + val title = s"Select 1 $dt row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% $dt rows (value < ${numRows * percent / 100})", + s"value < ${numRows * percent / 100}", + selectExpr + ) + } + } + } + } + } + + ignore("Pushdown benchmark for InSet -> InFilters") { + withTempPath { dir => + withTempTable("orcTable", "patquetTable") { + prepareTable(dir, numRows, width, false) + Seq(5, 10, 50, 100).foreach { count => + Seq(10, 50, 90).foreach { distribution => + val filter = + Range(0, count).map(r => scala.util.Random.nextInt(numRows * distribution / 100)) + val whereExpr = s"value in(${filter.mkString(",")})" + val title = s"InSet -> InFilters (values count: $count, distribution: $distribution)" + filterPushDownBenchmark(numRows, title, whereExpr) + } + } + } + } + } + + ignore(s"Pushdown benchmark for ${ByteType.simpleString}") { + withTempPath { dir => + val columns = (1 to width).map(i => s"CAST(id AS string) c$i") + val df = spark.range(numRows).selectExpr(columns: _*) + .withColumn("value", (monotonically_increasing_id() % Byte.MaxValue).cast(ByteType)) + .orderBy("value") + withTempTable("orcTable", "patquetTable") { + saveAsTable(df, dir) + + Seq(s"value = CAST(${Byte.MaxValue / 2} AS ${ByteType.simpleString})") + .foreach { whereExpr => + val title = s"Select 1 ${ByteType.simpleString} row ($whereExpr)" + .replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% ${ByteType.simpleString} rows " + + s"(value < CAST(${Byte.MaxValue * percent / 100} AS ${ByteType.simpleString}))", + s"value < CAST(${Byte.MaxValue * percent / 100} AS ${ByteType.simpleString})", + selectExpr + ) + } + } + } + } +} + +trait BenchmarkBeforeAndAfterEachTest extends BeforeAndAfterEachTestData { this: Suite => + + override def beforeEach(td: TestData) { + super.beforeEach(td) + } + + override def afterEach(td: TestData) { + super.afterEach(td) + } } From 141953f4c44dbad1c2a7059e92bec5fe770af932 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Fri, 6 Jul 2018 00:08:03 -0700 Subject: [PATCH 1074/2461] [SPARK-24535][SPARKR] fix tests on java check error ## What changes were proposed in this pull request? change to skip tests if - couldn't determine java version fix problem on windows ## How was this patch tested? unit test, manual, win-builder Author: Felix Cheung Closes #21666 from felixcheung/rjavaskip. --- R/pkg/R/client.R | 24 +++++++++++++++--------- R/pkg/R/sparkR.R | 2 +- R/pkg/inst/tests/testthat/test_basic.R | 8 ++++++++ 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 4c87f64e7f0e1..660f0864403e0 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -71,15 +71,20 @@ checkJavaVersion <- function() { # If java is missing from PATH, we get an error in Unix and a warning in Windows javaVersionOut <- tryCatch( - launchScript(javaBin, "-version", wait = TRUE, stdout = TRUE, stderr = TRUE), - error = function(e) { - stop("Java version check failed. Please make sure Java is installed", - " and set JAVA_HOME to point to the installation directory.", e) - }, - warning = function(w) { - stop("Java version check failed. Please make sure Java is installed", - " and set JAVA_HOME to point to the installation directory.", w) - }) + if (is_windows()) { + # See SPARK-24535 + system2(javaBin, "-version", wait = TRUE, stdout = TRUE, stderr = TRUE) + } else { + launchScript(javaBin, "-version", wait = TRUE, stdout = TRUE, stderr = TRUE) + }, + error = function(e) { + stop("Java version check failed. Please make sure Java is installed", + " and set JAVA_HOME to point to the installation directory.", e) + }, + warning = function(w) { + stop("Java version check failed. Please make sure Java is installed", + " and set JAVA_HOME to point to the installation directory.", w) + }) javaVersionFilter <- Filter( function(x) { grepl(" version", x) @@ -93,6 +98,7 @@ checkJavaVersion <- function() { stop(paste("Java version", sparkJavaVersion, "is required for this package; found version:", javaVersionStr)) } + return(javaVersionNum) } launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index f7c1663d32c96..d3a9cbae7d808 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -167,7 +167,7 @@ sparkR.sparkContext <- function( submitOps <- getClientModeSparkSubmitOpts( Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"), sparkEnvirMap) - checkJavaVersion() + invisible(checkJavaVersion()) launchBackend( args = path, sparkHome = sparkHome, diff --git a/R/pkg/inst/tests/testthat/test_basic.R b/R/pkg/inst/tests/testthat/test_basic.R index 823d26f12feee..243f5f0298284 100644 --- a/R/pkg/inst/tests/testthat/test_basic.R +++ b/R/pkg/inst/tests/testthat/test_basic.R @@ -18,6 +18,10 @@ context("basic tests for CRAN") test_that("create DataFrame from list or data.frame", { + tryCatch( checkJavaVersion(), + error = function(e) { skip("error on Java check") }, + warning = function(e) { skip("warning on Java check") } ) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, sparkConfig = sparkRTestConfig) @@ -50,6 +54,10 @@ test_that("create DataFrame from list or data.frame", { }) test_that("spark.glm and predict", { + tryCatch( checkJavaVersion(), + error = function(e) { skip("error on Java check") }, + warning = function(e) { skip("warning on Java check") } ) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, sparkConfig = sparkRTestConfig) From a381bce7285ec30f58f28f523dfcfe0c13221bbf Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 6 Jul 2018 18:28:54 +0800 Subject: [PATCH 1075/2461] [SPARK-24673][SQL][PYTHON][FOLLOWUP] Support Column arguments in timezone of from_utc_timestamp/to_utc_timestamp ## What changes were proposed in this pull request? This pr supported column arguments in timezone of `from_utc_timestamp/to_utc_timestamp` (follow-up of #21693). ## How was this patch tested? Added tests. Author: Takeshi Yamamuro Closes #21723 from maropu/SPARK-24673-FOLLOWUP. --- python/pyspark/sql/functions.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 4d371976364d3..55e7d575b4681 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1285,11 +1285,21 @@ def from_utc_timestamp(timestamp, tz): that time as a timestamp in the given time zone. For example, 'GMT+1' would yield '2017-07-14 03:40:00.0'. - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(from_utc_timestamp(df.t, "PST").alias('local_time')).collect() + :param timestamp: the column that contains timestamps + :param tz: a string that has the ID of timezone, e.g. "GMT", "America/Los_Angeles", etc + + .. versionchanged:: 2.4 + `tz` can take a :class:`Column` containing timezone ID strings. + + >>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz']) + >>> df.select(from_utc_timestamp(df.ts, "PST").alias('local_time')).collect() [Row(local_time=datetime.datetime(1997, 2, 28, 2, 30))] + >>> df.select(from_utc_timestamp(df.ts, df.tz).alias('local_time')).collect() + [Row(local_time=datetime.datetime(1997, 2, 28, 19, 30))] """ sc = SparkContext._active_spark_context + if isinstance(tz, Column): + tz = _to_java_column(tz) return Column(sc._jvm.functions.from_utc_timestamp(_to_java_column(timestamp), tz)) @@ -1300,11 +1310,21 @@ def to_utc_timestamp(timestamp, tz): zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield '2017-07-14 01:40:00.0'. - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['ts']) + :param timestamp: the column that contains timestamps + :param tz: a string that has the ID of timezone, e.g. "GMT", "America/Los_Angeles", etc + + .. versionchanged:: 2.4 + `tz` can take a :class:`Column` containing timezone ID strings. + + >>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz']) >>> df.select(to_utc_timestamp(df.ts, "PST").alias('utc_time')).collect() [Row(utc_time=datetime.datetime(1997, 2, 28, 18, 30))] + >>> df.select(to_utc_timestamp(df.ts, df.tz).alias('utc_time')).collect() + [Row(utc_time=datetime.datetime(1997, 2, 28, 1, 30))] """ sc = SparkContext._active_spark_context + if isinstance(tz, Column): + tz = _to_java_column(tz) return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz)) From 4de0425df8d2545718a0583bc26592108aebc5ac Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 7 Jul 2018 10:54:14 +0800 Subject: [PATCH 1076/2461] [SPARK-24569][SQL] Aggregator with output type Option should produce consistent schema ## What changes were proposed in this pull request? SQL `Aggregator` with output type `Option[Boolean]` creates column of type `StructType`. It's not in consistency with a Dataset of similar java class. This changes the way `definedByConstructorParams` checks given type. For `Option[_]`, it goes to check its type argument. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21611 from viirya/SPARK-24569. --- .../spark/sql/catalyst/ScalaReflection.scala | 7 ++- .../spark/sql/DatasetAggregatorSuite.scala | 60 +++++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 11 ++++ 3 files changed, 77 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index f9acc208b715e..4543bba8f6ed4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -798,7 +798,12 @@ object ScalaReflection extends ScalaReflection { * Whether the fields of the given type is defined entirely by its constructor parameters. */ def definedByConstructorParams(tpe: Type): Boolean = cleanUpReflectionObjects { - tpe.dealias <:< localTypeOf[Product] || tpe.dealias <:< localTypeOf[DefinedByConstructorParams] + tpe.dealias match { + // `Option` is a `Product`, but we don't wanna treat `Option[Int]` as a struct type. + case t if t <:< localTypeOf[Option[_]] => definedByConstructorParams(t.typeArgs.head) + case _ => tpe.dealias <:< localTypeOf[Product] || + tpe.dealias <:< localTypeOf[DefinedByConstructorParams] + } } private val javaKeywords = Set("abstract", "assert", "boolean", "break", "byte", "case", "catch", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 0e7eaa9e88d57..538ea3c66c40e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -148,6 +148,41 @@ object VeryComplexResultAgg extends Aggregator[Row, String, ComplexAggData] { } +case class OptionBooleanData(name: String, isGood: Option[Boolean]) + +case class OptionBooleanAggregator(colName: String) + extends Aggregator[Row, Option[Boolean], Option[Boolean]] { + + override def zero: Option[Boolean] = None + + override def reduce(buffer: Option[Boolean], row: Row): Option[Boolean] = { + val index = row.fieldIndex(colName) + val value = if (row.isNullAt(index)) { + Option.empty[Boolean] + } else { + Some(row.getBoolean(index)) + } + merge(buffer, value) + } + + override def merge(b1: Option[Boolean], b2: Option[Boolean]): Option[Boolean] = { + if ((b1.isDefined && b1.get) || (b2.isDefined && b2.get)) { + Some(true) + } else if (b1.isDefined) { + b1 + } else { + b2 + } + } + + override def finish(reduction: Option[Boolean]): Option[Boolean] = reduction + + override def bufferEncoder: Encoder[Option[Boolean]] = OptionalBoolEncoder + override def outputEncoder: Encoder[Option[Boolean]] = OptionalBoolEncoder + + def OptionalBoolEncoder: Encoder[Option[Boolean]] = ExpressionEncoder() +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -333,4 +368,29 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { df.groupBy($"i").agg(VeryComplexResultAgg.toColumn), Row(1, Row(Row(1, "a"), Row(1, "a"))) :: Row(2, Row(Row(2, "bc"), Row(2, "bc"))) :: Nil) } + + test("SPARK-24569: Aggregator with output type Option[Boolean] creates column of type Row") { + val df = Seq( + OptionBooleanData("bob", Some(true)), + OptionBooleanData("bob", Some(false)), + OptionBooleanData("bob", None)).toDF() + val group = df + .groupBy("name") + .agg(OptionBooleanAggregator("isGood").toColumn.alias("isGood")) + assert(df.schema == group.schema) + checkAnswer(group, Row("bob", true) :: Nil) + checkDataset(group.as[OptionBooleanData], OptionBooleanData("bob", Some(true))) + } + + test("SPARK-24569: groupByKey with Aggregator of output type Option[Boolean]") { + val df = Seq( + OptionBooleanData("bob", Some(true)), + OptionBooleanData("bob", Some(false)), + OptionBooleanData("bob", None)).toDF() + val grouped = df.groupByKey((r: Row) => r.getString(0)) + .agg(OptionBooleanAggregator("isGood").toColumn).toDF("name", "isGood") + + assert(grouped.schema == df.schema) + checkDataset(grouped.as[OptionBooleanData], OptionBooleanData("bob", Some(true))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 2d20c50584c03..ce8db99d4e2f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1467,6 +1467,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext { intercept[NullPointerException](ds.as[(Int, Int)].collect()) } + test("SPARK-24569: Option of primitive types are mistakenly mapped to struct type") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + val a = Seq(Some(1)).toDS + val b = Seq(Some(1.2)).toDS + val expected = Seq((Some(1), Some(1.2))).toDS + val joined = a.joinWith(b, lit(true)) + assert(joined.schema == expected.schema) + checkDataset(joined, expected.collect: _*) + } + } + test("SPARK-24548: Dataset with tuple encoders should have correct schema") { val encoder = Encoders.tuple(newStringEncoder, Encoders.tuple(newStringEncoder, newStringEncoder)) From fc43690d36e7a17e45826a69ab86935fb0ee2be4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 7 Jul 2018 11:34:30 +0800 Subject: [PATCH 1077/2461] [SPARK-24749][SQL] Use sameType to compare Array's element type in ArrayContains ## What changes were proposed in this pull request? We should use `DataType.sameType` to compare element type in `ArrayContains`, otherwise nullability affects comparison result. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21724 from viirya/SPARK-24749. --- .../sql/catalyst/expressions/collectionOperations.scala | 2 +- .../catalyst/expressions/CollectionExpressionsSuite.scala | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8b278f067749e..fcac3a58e6a95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1085,7 +1085,7 @@ case class ArrayContains(left: Expression, right: Expression) if (right.dataType == NullType) { TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments") } else if (!left.dataType.isInstanceOf[ArrayType] - || left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) { + || !left.dataType.asInstanceOf[ArrayType].elementType.sameType(right.dataType)) { TypeCheckResult.TypeCheckFailure( "Arguments must be an array followed by a value of same type as the array members") } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index d7744eb4c7dc7..496ee1d496a36 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -213,6 +213,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) val a2 = Literal.create(Seq(null), ArrayType(LongType)) val a3 = Literal.create(null, ArrayType(StringType)) + val a4 = Literal.create(Seq(create_row(1)), ArrayType(StructType(Seq( + StructField("a", IntegerType, true))))) checkEvaluation(ArrayContains(a0, Literal(1)), true) checkEvaluation(ArrayContains(a0, Literal(0)), false) @@ -228,6 +230,11 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) + checkEvaluation(ArrayContains(a4, Literal.create(create_row(1), StructType(Seq( + StructField("a", IntegerType, false))))), true) + checkEvaluation(ArrayContains(a4, Literal.create(create_row(0), StructType(Seq( + StructField("a", IntegerType, false))))), false) + // binary val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), ArrayType(BinaryType)) From 74f6a92fcea9196d62c2d531c11ec7efd580b760 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 7 Jul 2018 11:37:41 +0800 Subject: [PATCH 1078/2461] [SPARK-24739][PYTHON] Make PySpark compatible with Python 3.7 ## What changes were proposed in this pull request? This PR proposes to make PySpark compatible with Python 3.7. There are rather radical change in semantic of `StopIteration` within a generator. It now throws it as a `RuntimeError`. To make it compatible, we should fix it: ```python try: next(...) except StopIteration return ``` See [release note](https://docs.python.org/3/whatsnew/3.7.html#porting-to-python-3-7) and [PEP 479](https://www.python.org/dev/peps/pep-0479/). ## How was this patch tested? Manually tested: ``` $ ./run-tests --python-executables=python3.7 Running PySpark tests. Output is in /.../spark/python/unit-tests.log Will test against the following Python executables: ['python3.7'] Will test the following Python modules: ['pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming'] Starting test(python3.7): pyspark.mllib.tests Starting test(python3.7): pyspark.sql.tests Starting test(python3.7): pyspark.streaming.tests Starting test(python3.7): pyspark.tests Finished test(python3.7): pyspark.streaming.tests (130s) Starting test(python3.7): pyspark.accumulators Finished test(python3.7): pyspark.accumulators (8s) Starting test(python3.7): pyspark.broadcast Finished test(python3.7): pyspark.broadcast (9s) Starting test(python3.7): pyspark.conf Finished test(python3.7): pyspark.conf (6s) Starting test(python3.7): pyspark.context Finished test(python3.7): pyspark.context (27s) Starting test(python3.7): pyspark.ml.classification Finished test(python3.7): pyspark.tests (200s) ... 3 tests were skipped Starting test(python3.7): pyspark.ml.clustering Finished test(python3.7): pyspark.mllib.tests (244s) Starting test(python3.7): pyspark.ml.evaluation Finished test(python3.7): pyspark.ml.classification (63s) Starting test(python3.7): pyspark.ml.feature Finished test(python3.7): pyspark.ml.clustering (48s) Starting test(python3.7): pyspark.ml.fpm Finished test(python3.7): pyspark.ml.fpm (0s) Starting test(python3.7): pyspark.ml.image Finished test(python3.7): pyspark.ml.evaluation (23s) Starting test(python3.7): pyspark.ml.linalg.__init__ Finished test(python3.7): pyspark.ml.linalg.__init__ (0s) Starting test(python3.7): pyspark.ml.recommendation Finished test(python3.7): pyspark.ml.image (20s) Starting test(python3.7): pyspark.ml.regression Finished test(python3.7): pyspark.ml.regression (58s) Starting test(python3.7): pyspark.ml.stat Finished test(python3.7): pyspark.ml.feature (90s) Starting test(python3.7): pyspark.ml.tests Finished test(python3.7): pyspark.ml.recommendation (82s) Starting test(python3.7): pyspark.ml.tuning Finished test(python3.7): pyspark.ml.stat (27s) Starting test(python3.7): pyspark.mllib.classification Finished test(python3.7): pyspark.sql.tests (362s) ... 102 tests were skipped Starting test(python3.7): pyspark.mllib.clustering Finished test(python3.7): pyspark.ml.tuning (29s) Starting test(python3.7): pyspark.mllib.evaluation Finished test(python3.7): pyspark.mllib.classification (39s) Starting test(python3.7): pyspark.mllib.feature Finished test(python3.7): pyspark.mllib.evaluation (30s) Starting test(python3.7): pyspark.mllib.fpm Finished test(python3.7): pyspark.mllib.feature (44s) Starting test(python3.7): pyspark.mllib.linalg.__init__ Finished test(python3.7): pyspark.mllib.linalg.__init__ (0s) Starting test(python3.7): pyspark.mllib.linalg.distributed Finished test(python3.7): pyspark.mllib.clustering (78s) Starting test(python3.7): pyspark.mllib.random Finished test(python3.7): pyspark.mllib.fpm (33s) Starting test(python3.7): pyspark.mllib.recommendation Finished test(python3.7): pyspark.mllib.random (12s) Starting test(python3.7): pyspark.mllib.regression Finished test(python3.7): pyspark.mllib.linalg.distributed (45s) Starting test(python3.7): pyspark.mllib.stat.KernelDensity Finished test(python3.7): pyspark.mllib.stat.KernelDensity (0s) Starting test(python3.7): pyspark.mllib.stat._statistics Finished test(python3.7): pyspark.mllib.recommendation (41s) Starting test(python3.7): pyspark.mllib.tree Finished test(python3.7): pyspark.mllib.regression (44s) Starting test(python3.7): pyspark.mllib.util Finished test(python3.7): pyspark.mllib.stat._statistics (20s) Starting test(python3.7): pyspark.profiler Finished test(python3.7): pyspark.mllib.tree (26s) Starting test(python3.7): pyspark.rdd Finished test(python3.7): pyspark.profiler (11s) Starting test(python3.7): pyspark.serializers Finished test(python3.7): pyspark.mllib.util (24s) Starting test(python3.7): pyspark.shuffle Finished test(python3.7): pyspark.shuffle (0s) Starting test(python3.7): pyspark.sql.catalog Finished test(python3.7): pyspark.serializers (15s) Starting test(python3.7): pyspark.sql.column Finished test(python3.7): pyspark.rdd (27s) Starting test(python3.7): pyspark.sql.conf Finished test(python3.7): pyspark.sql.catalog (24s) Starting test(python3.7): pyspark.sql.context Finished test(python3.7): pyspark.sql.conf (8s) Starting test(python3.7): pyspark.sql.dataframe Finished test(python3.7): pyspark.sql.column (29s) Starting test(python3.7): pyspark.sql.functions Finished test(python3.7): pyspark.sql.context (26s) Starting test(python3.7): pyspark.sql.group Finished test(python3.7): pyspark.sql.dataframe (51s) Starting test(python3.7): pyspark.sql.readwriter Finished test(python3.7): pyspark.ml.tests (266s) Starting test(python3.7): pyspark.sql.session Finished test(python3.7): pyspark.sql.group (36s) Starting test(python3.7): pyspark.sql.streaming Finished test(python3.7): pyspark.sql.functions (57s) Starting test(python3.7): pyspark.sql.types Finished test(python3.7): pyspark.sql.session (25s) Starting test(python3.7): pyspark.sql.udf Finished test(python3.7): pyspark.sql.types (10s) Starting test(python3.7): pyspark.sql.window Finished test(python3.7): pyspark.sql.readwriter (31s) Starting test(python3.7): pyspark.streaming.util Finished test(python3.7): pyspark.sql.streaming (22s) Starting test(python3.7): pyspark.util Finished test(python3.7): pyspark.util (0s) Finished test(python3.7): pyspark.streaming.util (0s) Finished test(python3.7): pyspark.sql.udf (16s) Finished test(python3.7): pyspark.sql.window (12s) ``` In my local (I have two Macs but both have the same issues), I currently faced some issues for now to install both extra dependencies PyArrow and Pandas same as Jenkins's, against Python 3.7. Author: hyukjinkwon Closes #21714 from HyukjinKwon/SPARK-24739. --- python/pyspark/rdd.py | 5 ++++- python/setup.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7e7e5822a6b20..951851804b1d8 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1370,7 +1370,10 @@ def takeUpToNumLeft(iterator): iterator = iter(iterator) taken = 0 while taken < left: - yield next(iterator) + try: + yield next(iterator) + except StopIteration: + return taken += 1 p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts)) diff --git a/python/setup.py b/python/setup.py index d309e0564530a..45eb74eb87ce7 100644 --- a/python/setup.py +++ b/python/setup.py @@ -219,6 +219,7 @@ def _supports_symlinks(): 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: Implementation :: CPython', 'Programming Language :: Python :: Implementation :: PyPy'] ) From 044b33b2ed2d423d798f2a632fab110c46f41567 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 7 Jul 2018 11:39:29 +0800 Subject: [PATCH 1079/2461] [SPARK-24740][PYTHON][ML] Make PySpark's tests compatible with NumPy 1.14+ ## What changes were proposed in this pull request? This PR proposes to make PySpark's tests compatible with NumPy 0.14+ NumPy 0.14.x introduced rather radical changes about its string representation. For example, the tests below are failed: ``` ********************************************************************** File "/.../spark/python/pyspark/ml/linalg/__init__.py", line 895, in __main__.DenseMatrix.__str__ Failed example: print(dm) Expected: DenseMatrix([[ 0., 2.], [ 1., 3.]]) Got: DenseMatrix([[0., 2.], [1., 3.]]) ********************************************************************** File "/.../spark/python/pyspark/ml/linalg/__init__.py", line 899, in __main__.DenseMatrix.__str__ Failed example: print(dm) Expected: DenseMatrix([[ 0., 1.], [ 2., 3.]]) Got: DenseMatrix([[0., 1.], [2., 3.]]) ********************************************************************** File "/.../spark/python/pyspark/ml/linalg/__init__.py", line 939, in __main__.DenseMatrix.toArray Failed example: m.toArray() Expected: array([[ 0., 2.], [ 1., 3.]]) Got: array([[0., 2.], [1., 3.]]) ********************************************************************** File "/.../spark/python/pyspark/ml/linalg/__init__.py", line 324, in __main__.DenseVector.dot Failed example: dense.dot(np.reshape([1., 2., 3., 4.], (2, 2), order='F')) Expected: array([ 5., 11.]) Got: array([ 5., 11.]) ********************************************************************** File "/.../spark/python/pyspark/ml/linalg/__init__.py", line 567, in __main__.SparseVector.dot Failed example: a.dot(np.array([[1, 1], [2, 2], [3, 3], [4, 4]])) Expected: array([ 22., 22.]) Got: array([22., 22.]) ``` See [release note](https://docs.scipy.org/doc/numpy-1.14.0/release.html#compatibility-notes). ## How was this patch tested? Manually tested: ``` $ ./run-tests --python-executables=python3.6,python2.7 --modules=pyspark-ml,pyspark-mllib Running PySpark tests. Output is in /.../spark/python/unit-tests.log Will test against the following Python executables: ['python3.6', 'python2.7'] Will test the following Python modules: ['pyspark-ml', 'pyspark-mllib'] Starting test(python2.7): pyspark.mllib.tests Starting test(python2.7): pyspark.ml.classification Starting test(python3.6): pyspark.mllib.tests Starting test(python2.7): pyspark.ml.clustering Finished test(python2.7): pyspark.ml.clustering (54s) Starting test(python2.7): pyspark.ml.evaluation Finished test(python2.7): pyspark.ml.classification (74s) Starting test(python2.7): pyspark.ml.feature Finished test(python2.7): pyspark.ml.evaluation (27s) Starting test(python2.7): pyspark.ml.fpm Finished test(python2.7): pyspark.ml.fpm (0s) Starting test(python2.7): pyspark.ml.image Finished test(python2.7): pyspark.ml.image (17s) Starting test(python2.7): pyspark.ml.linalg.__init__ Finished test(python2.7): pyspark.ml.linalg.__init__ (1s) Starting test(python2.7): pyspark.ml.recommendation Finished test(python2.7): pyspark.ml.feature (76s) Starting test(python2.7): pyspark.ml.regression Finished test(python2.7): pyspark.ml.recommendation (69s) Starting test(python2.7): pyspark.ml.stat Finished test(python2.7): pyspark.ml.regression (45s) Starting test(python2.7): pyspark.ml.tests Finished test(python2.7): pyspark.ml.stat (28s) Starting test(python2.7): pyspark.ml.tuning Finished test(python2.7): pyspark.ml.tuning (20s) Starting test(python2.7): pyspark.mllib.classification Finished test(python2.7): pyspark.mllib.classification (31s) Starting test(python2.7): pyspark.mllib.clustering Finished test(python2.7): pyspark.mllib.tests (260s) Starting test(python2.7): pyspark.mllib.evaluation Finished test(python3.6): pyspark.mllib.tests (266s) Starting test(python2.7): pyspark.mllib.feature Finished test(python2.7): pyspark.mllib.evaluation (21s) Starting test(python2.7): pyspark.mllib.fpm Finished test(python2.7): pyspark.mllib.feature (38s) Starting test(python2.7): pyspark.mllib.linalg.__init__ Finished test(python2.7): pyspark.mllib.linalg.__init__ (1s) Starting test(python2.7): pyspark.mllib.linalg.distributed Finished test(python2.7): pyspark.mllib.fpm (34s) Starting test(python2.7): pyspark.mllib.random Finished test(python2.7): pyspark.mllib.clustering (64s) Starting test(python2.7): pyspark.mllib.recommendation Finished test(python2.7): pyspark.mllib.random (15s) Starting test(python2.7): pyspark.mllib.regression Finished test(python2.7): pyspark.mllib.linalg.distributed (47s) Starting test(python2.7): pyspark.mllib.stat.KernelDensity Finished test(python2.7): pyspark.mllib.stat.KernelDensity (0s) Starting test(python2.7): pyspark.mllib.stat._statistics Finished test(python2.7): pyspark.mllib.recommendation (40s) Starting test(python2.7): pyspark.mllib.tree Finished test(python2.7): pyspark.mllib.regression (38s) Starting test(python2.7): pyspark.mllib.util Finished test(python2.7): pyspark.mllib.stat._statistics (19s) Starting test(python3.6): pyspark.ml.classification Finished test(python2.7): pyspark.mllib.tree (26s) Starting test(python3.6): pyspark.ml.clustering Finished test(python2.7): pyspark.mllib.util (27s) Starting test(python3.6): pyspark.ml.evaluation Finished test(python3.6): pyspark.ml.evaluation (30s) Starting test(python3.6): pyspark.ml.feature Finished test(python2.7): pyspark.ml.tests (234s) Starting test(python3.6): pyspark.ml.fpm Finished test(python3.6): pyspark.ml.fpm (1s) Starting test(python3.6): pyspark.ml.image Finished test(python3.6): pyspark.ml.clustering (55s) Starting test(python3.6): pyspark.ml.linalg.__init__ Finished test(python3.6): pyspark.ml.linalg.__init__ (0s) Starting test(python3.6): pyspark.ml.recommendation Finished test(python3.6): pyspark.ml.classification (71s) Starting test(python3.6): pyspark.ml.regression Finished test(python3.6): pyspark.ml.image (18s) Starting test(python3.6): pyspark.ml.stat Finished test(python3.6): pyspark.ml.stat (37s) Starting test(python3.6): pyspark.ml.tests Finished test(python3.6): pyspark.ml.regression (59s) Starting test(python3.6): pyspark.ml.tuning Finished test(python3.6): pyspark.ml.feature (93s) Starting test(python3.6): pyspark.mllib.classification Finished test(python3.6): pyspark.ml.recommendation (83s) Starting test(python3.6): pyspark.mllib.clustering Finished test(python3.6): pyspark.ml.tuning (29s) Starting test(python3.6): pyspark.mllib.evaluation Finished test(python3.6): pyspark.mllib.evaluation (26s) Starting test(python3.6): pyspark.mllib.feature Finished test(python3.6): pyspark.mllib.classification (43s) Starting test(python3.6): pyspark.mllib.fpm Finished test(python3.6): pyspark.mllib.clustering (81s) Starting test(python3.6): pyspark.mllib.linalg.__init__ Finished test(python3.6): pyspark.mllib.linalg.__init__ (2s) Starting test(python3.6): pyspark.mllib.linalg.distributed Finished test(python3.6): pyspark.mllib.fpm (48s) Starting test(python3.6): pyspark.mllib.random Finished test(python3.6): pyspark.mllib.feature (54s) Starting test(python3.6): pyspark.mllib.recommendation Finished test(python3.6): pyspark.mllib.random (18s) Starting test(python3.6): pyspark.mllib.regression Finished test(python3.6): pyspark.mllib.linalg.distributed (55s) Starting test(python3.6): pyspark.mllib.stat.KernelDensity Finished test(python3.6): pyspark.mllib.stat.KernelDensity (1s) Starting test(python3.6): pyspark.mllib.stat._statistics Finished test(python3.6): pyspark.mllib.recommendation (51s) Starting test(python3.6): pyspark.mllib.tree Finished test(python3.6): pyspark.mllib.regression (45s) Starting test(python3.6): pyspark.mllib.util Finished test(python3.6): pyspark.mllib.stat._statistics (21s) Finished test(python3.6): pyspark.mllib.tree (27s) Finished test(python3.6): pyspark.mllib.util (27s) Finished test(python3.6): pyspark.ml.tests (264s) ``` Author: hyukjinkwon Closes #21715 from HyukjinKwon/SPARK-24740. --- python/pyspark/ml/clustering.py | 6 ++++++ python/pyspark/ml/linalg/__init__.py | 5 +++++ python/pyspark/ml/stat.py | 6 ++++++ python/pyspark/mllib/clustering.py | 6 ++++++ python/pyspark/mllib/evaluation.py | 6 ++++++ python/pyspark/mllib/linalg/__init__.py | 6 ++++++ python/pyspark/mllib/linalg/distributed.py | 6 ++++++ python/pyspark/mllib/stat/_statistics.py | 6 ++++++ 8 files changed, 47 insertions(+) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 6d77baf7349e4..2f0660040dc7c 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -1345,8 +1345,14 @@ def assignClusters(self, dataset): if __name__ == "__main__": import doctest + import numpy import pyspark.ml.clustering from pyspark.sql import SparkSession + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.ml.clustering.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index 6a611a2b5b59d..2548fd0f50b33 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -1156,6 +1156,11 @@ def sparse(numRows, numCols, colPtrs, rowIndices, values): def _test(): import doctest + try: + # Numpy 1.14+ changed it's string format. + np.set_printoptions(legacy='1.13') + except TypeError: + pass (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: sys.exit(-1) diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index a06ab31a7a56a..370154fc6d62a 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -388,8 +388,14 @@ def summary(self, featuresCol, weightCol=None): if __name__ == "__main__": import doctest + import numpy import pyspark.ml.stat from pyspark.sql import SparkSession + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.ml.stat.__dict__.copy() # The small batch size here ensures that we see multiple batches, diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 0cbabab13a896..b09469b9f5c2d 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -1042,7 +1042,13 @@ def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0, def _test(): import doctest + import numpy import pyspark.mllib.clustering + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.mllib.clustering.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 36cb03369b8c0..6c65da58e4e2b 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -532,8 +532,14 @@ def accuracy(self): def _test(): import doctest + import numpy from pyspark.sql import SparkSession import pyspark.mllib.evaluation + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.mllib.evaluation.__dict__.copy() spark = SparkSession.builder\ .master("local[4]")\ diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 60d96d8d5ceb8..4afd6666400b0 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -1368,6 +1368,12 @@ def R(self): def _test(): import doctest + import numpy + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: sys.exit(-1) diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index bba88542167ad..7e8b15056cabe 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -1364,9 +1364,15 @@ def toCoordinateMatrix(self): def _test(): import doctest + import numpy from pyspark.sql import SparkSession from pyspark.mllib.linalg import Matrices import pyspark.mllib.linalg.distributed + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.mllib.linalg.distributed.__dict__.copy() spark = SparkSession.builder\ .master("local[2]")\ diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 3c75b132ecad2..937bb154c2356 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -303,7 +303,13 @@ def kolmogorovSmirnovTest(data, distName="norm", *params): def _test(): import doctest + import numpy from pyspark.sql import SparkSession + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = globals().copy() spark = SparkSession.builder\ .master("local[4]")\ From 79c66894296840cc4a5bf6c8718ecfd2b08bcca8 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 7 Jul 2018 22:16:48 +0200 Subject: [PATCH 1080/2461] [SPARK-24757][SQL] Improving the error message for broadcast timeouts ## What changes were proposed in this pull request? In the PR, I propose to provide a tip to user how to resolve the issue of timeout expiration for broadcast joins. In particular, they can increase the timeout via **spark.sql.broadcastTimeout** or disable the broadcast at all by setting **spark.sql.autoBroadcastJoinThreshold** to `-1`. ## How was this patch tested? It tested manually from `spark-shell`: ``` scala> spark.conf.set("spark.sql.broadcastTimeout", 1) scala> val df = spark.range(100).join(spark.range(15).as[Long].map { x => Thread.sleep(5000) x }).where("id = value") scala> df.count() ``` ``` org.apache.spark.SparkException: Could not execute broadcast in 1 secs. You can increase the timeout for broadcasts via spark.sql.broadcastTimeout or disable broadcast join by setting spark.sql.autoBroadcastJoinThreshold to -1 at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.doExecuteBroadcast(BroadcastExchangeExec.scala:150) ``` Author: Maxim Gekk Closes #21727 from MaxGekk/broadcast-timeout-error. --- .../execution/exchange/BroadcastExchangeExec.scala | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index c55f9b8f1a7fc..a80673c705f1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.exchange +import java.util.concurrent.TimeoutException + import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ import scala.util.control.NonFatal @@ -140,7 +142,16 @@ case class BroadcastExchangeExec( } override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { - ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]] + try { + ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]] + } catch { + case ex: TimeoutException => + logError(s"Could not execute broadcast in ${timeout.toSeconds} secs.", ex) + throw new SparkException(s"Could not execute broadcast in ${timeout.toSeconds} secs. " + + s"You can increase the timeout for broadcasts via ${SQLConf.BROADCAST_TIMEOUT.key} or " + + s"disable broadcast join by setting ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1", + ex) + } } } From e2c7e09f742a7e522efd74fe8e14c2620afdb522 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 9 Jul 2018 10:21:40 +0800 Subject: [PATCH 1081/2461] [SPARK-24646][CORE] Minor change to spark.yarn.dist.forceDownloadSchemes to support wildcard '*' ## What changes were proposed in this pull request? In the case of getting tokens via customized `ServiceCredentialProvider`, it is required that `ServiceCredentialProvider` be available in local spark-submit process classpath. In this case, all the configured remote sources should be forced to download to local. For the ease of using this configuration, here propose to add wildcard '*' support to `spark.yarn.dist.forceDownloadSchemes`, also clarify the usage of this configuration. ## How was this patch tested? New UT added. Author: jerryshao Closes #21633 from jerryshao/SPARK-21917-followup. --- .../org/apache/spark/deploy/SparkSubmit.scala | 5 ++-- .../spark/internal/config/package.scala | 5 ++-- .../spark/deploy/SparkSubmitSuite.scala | 29 ++++++++++++------- docs/running-on-yarn.md | 5 ++-- 4 files changed, 28 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 2da778a29779d..e7310ee886103 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -385,7 +385,7 @@ private[spark] class SparkSubmit extends Logging { val forceDownloadSchemes = sparkConf.get(FORCE_DOWNLOAD_SCHEMES) def shouldDownload(scheme: String): Boolean = { - forceDownloadSchemes.contains(scheme) || + forceDownloadSchemes.contains("*") || forceDownloadSchemes.contains(scheme) || Try { FileSystem.getFileSystemClass(scheme, hadoopConf) }.isFailure } @@ -578,7 +578,8 @@ private[spark] class SparkSubmit extends Logging { } // Add the main application jar and any added jars to classpath in case YARN client // requires these jars. - // This assumes both primaryResource and user jars are local jars, otherwise it will not be + // This assumes both primaryResource and user jars are local jars, or already downloaded + // to local by configuring "spark.yarn.dist.forceDownloadSchemes", otherwise it will not be // added to the classpath of YARN client. if (isYarnCluster) { if (isUserJar(args.primaryResource)) { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index bda9795a0b925..ba892bf7f60d6 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -486,10 +486,11 @@ package object config { private[spark] val FORCE_DOWNLOAD_SCHEMES = ConfigBuilder("spark.yarn.dist.forceDownloadSchemes") - .doc("Comma-separated list of schemes for which files will be downloaded to the " + + .doc("Comma-separated list of schemes for which resources will be downloaded to the " + "local disk prior to being added to YARN's distributed cache. For use in cases " + "where the YARN service does not support schemes that are supported by Spark, like http, " + - "https and ftp.") + "https and ftp, or jars required to be in the local YARN client's classpath. Wildcard " + + "'*' is denoted to download resources for all the schemes.") .stringConf .toSequence .createWithDefault(Nil) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 545c8d0423dc3..f829fecc30840 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -995,20 +995,24 @@ class SparkSubmitSuite } test("download remote resource if it is not supported by yarn service") { - testRemoteResources(enableHttpFs = false, blacklistHttpFs = false) + testRemoteResources(enableHttpFs = false) } test("avoid downloading remote resource if it is supported by yarn service") { - testRemoteResources(enableHttpFs = true, blacklistHttpFs = false) + testRemoteResources(enableHttpFs = true) } test("force download from blacklisted schemes") { - testRemoteResources(enableHttpFs = true, blacklistHttpFs = true) + testRemoteResources(enableHttpFs = true, blacklistSchemes = Seq("http")) + } + + test("force download for all the schemes") { + testRemoteResources(enableHttpFs = true, blacklistSchemes = Seq("*")) } private def testRemoteResources( enableHttpFs: Boolean, - blacklistHttpFs: Boolean): Unit = { + blacklistSchemes: Seq[String] = Nil): Unit = { val hadoopConf = new Configuration() updateConfWithFakeS3Fs(hadoopConf) if (enableHttpFs) { @@ -1025,8 +1029,8 @@ class SparkSubmitSuite val tmpHttpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir) val tmpHttpJarPath = s"http://${new File(tmpHttpJar.toURI).getAbsolutePath}" - val forceDownloadArgs = if (blacklistHttpFs) { - Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http") + val forceDownloadArgs = if (blacklistSchemes.nonEmpty) { + Seq("--conf", s"spark.yarn.dist.forceDownloadSchemes=${blacklistSchemes.mkString(",")}") } else { Nil } @@ -1044,14 +1048,19 @@ class SparkSubmitSuite val jars = conf.get("spark.yarn.dist.jars").split(",").toSet - // The URI of remote S3 resource should still be remote. - assert(jars.contains(tmpS3JarPath)) + def isSchemeBlacklisted(scheme: String) = { + blacklistSchemes.contains("*") || blacklistSchemes.contains(scheme) + } + + if (!isSchemeBlacklisted("s3")) { + assert(jars.contains(tmpS3JarPath)) + } - if (enableHttpFs && !blacklistHttpFs) { + if (enableHttpFs && blacklistSchemes.isEmpty) { // If Http FS is supported by yarn service, the URI of remote http resource should // still be remote. assert(jars.contains(tmpHttpJarPath)) - } else { + } else if (!enableHttpFs || isSchemeBlacklisted("http")) { // If Http FS is not supported by yarn service, or http scheme is configured to be force // downloading, the URI of remote http resource should be changed to a local one. val jarName = new File(tmpHttpJar.toURI).getName diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 575da7205b529..0b265b0cb1b31 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -218,9 +218,10 @@ To use a custom metrics.properties for the application master and executors, upd
      From 034913b62b579ae003431231c0272513de8f496c Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Mon, 9 Jul 2018 21:21:38 +0900 Subject: [PATCH 1082/2461] [SPARK-23936][SQL] Implement map_concat ## What changes were proposed in this pull request? Implement map_concat high order function. This implementation does not pick a winner when the specified maps have overlapping keys. Therefore, this implementation preserves existing duplicate keys in the maps and potentially introduces new duplicates (After discussion with ueshin, we settled on option 1 from [here](https://issues.apache.org/jira/browse/SPARK-23936?focusedCommentId=16464245&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-16464245)). ## How was this patch tested? New tests Manual tests Run all sbt SQL tests Run all pyspark sql tests Author: Bruce Robbins Closes #21073 from bersprockets/SPARK-23936. --- python/pyspark/sql/functions.py | 22 ++ .../sql/catalyst/CatalystTypeConverters.scala | 6 + .../catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/analysis/TypeCoercion.scala | 8 + .../expressions/collectionOperations.scala | 231 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 126 ++++++++++ .../org/apache/spark/sql/functions.scala | 8 + .../inputs/typeCoercion/native/mapconcat.sql | 94 +++++++ .../typeCoercion/native/mapconcat.sql.out | 143 +++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 78 ++++++ 10 files changed, 717 insertions(+) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 55e7d575b4681..9f61e29f9cd42 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2510,6 +2510,28 @@ def arrays_zip(*cols): return Column(sc._jvm.functions.arrays_zip(_to_seq(sc, cols, _to_java_column))) +@since(2.4) +def map_concat(*cols): + """Returns the union of all the given maps. + + :param cols: list of column names (string) or list of :class:`Column` expressions + + >>> from pyspark.sql.functions import map_concat + >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c', 1, 'd') as map2") + >>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False) + +--------------------------------+ + |map3 | + +--------------------------------+ + |[1 -> a, 2 -> b, 3 -> c, 1 -> d]| + +--------------------------------+ + """ + sc = SparkContext._active_spark_context + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + jc = sc._jvm.functions.map_concat(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + # ---------------------------- User Defined Function ---------------------------------- class PandasUDFType(object): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 93df73ab1eaf6..6f5fbdd79e668 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -431,6 +431,12 @@ object CatalystTypeConverters { map, (key: Any) => convertToCatalyst(key), (value: Any) => convertToCatalyst(value)) + case (keys: Array[_], values: Array[_]) => + // case for mapdata with duplicate keys + new ArrayBasedMapData( + new GenericArrayData(keys.map(convertToCatalyst)), + new GenericArrayData(values.map(convertToCatalyst)) + ) case other => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 80a0af672bf74..e7517e8c676e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -422,6 +422,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[MapEntries]("map_entries"), expression[MapFromEntries]("map_from_entries"), + expression[MapConcat]("map_concat"), expression[Size]("size"), expression[Slice]("slice"), expression[Size]("cardinality"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index b6ca30c7398f2..72908c1f433ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -548,6 +548,14 @@ object TypeCoercion { case None => s } + case m @ MapConcat(children) if children.forall(c => MapType.acceptsType(c.dataType)) && + !haveSameType(children) => + val types = children.map(_.dataType) + findWiderCommonType(types) match { + case Some(finalDataType) => MapConcat(children.map(Cast(_, finalDataType))) + case None => m + } + case m @ CreateMap(children) if m.keys.length == m.values.length && (!haveSameType(m.keys) || !haveSameType(m.values)) => val newKeys = if (haveSameType(m.keys)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index fcac3a58e6a95..879603b66b314 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -503,6 +503,237 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp override def prettyName: String = "map_entries" } +/** + * Returns the union of all the given maps. + */ +@ExpressionDescription( + usage = "_FUNC_(map, ...) - Returns the union of all the given maps", + examples = """ + Examples: + > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd')); + [[1 -> "a"], [2 -> "b"], [2 -> "c"], [3 -> "d"]] + """, since = "2.4.0") +case class MapConcat(children: Seq[Expression]) extends Expression { + + override def checkInputDataTypes(): TypeCheckResult = { + var funcName = s"function $prettyName" + if (children.exists(!_.dataType.isInstanceOf[MapType])) { + TypeCheckResult.TypeCheckFailure( + s"input to $funcName should all be of type map, but it's " + + children.map(_.dataType.simpleString).mkString("[", ", ", "]")) + } else { + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName) + } + } + + override def dataType: MapType = { + val dt = children.map(_.dataType.asInstanceOf[MapType]).headOption + .getOrElse(MapType(StringType, StringType)) + val valueContainsNull = children.map(_.dataType.asInstanceOf[MapType]) + .exists(_.valueContainsNull) + if (dt.valueContainsNull != valueContainsNull) { + dt.copy(valueContainsNull = valueContainsNull) + } else { + dt + } + } + + override def nullable: Boolean = children.exists(_.nullable) + + override def eval(input: InternalRow): Any = { + val maps = children.map(_.eval(input)) + if (maps.contains(null)) { + return null + } + val keyArrayDatas = maps.map(_.asInstanceOf[MapData].keyArray()) + val valueArrayDatas = maps.map(_.asInstanceOf[MapData].valueArray()) + + val numElements = keyArrayDatas.foldLeft(0L)((sum, ad) => sum + ad.numElements()) + if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " + + s"elements due to exceeding the map size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + } + val finalKeyArray = new Array[AnyRef](numElements.toInt) + val finalValueArray = new Array[AnyRef](numElements.toInt) + var position = 0 + for (i <- keyArrayDatas.indices) { + val keyArray = keyArrayDatas(i).toObjectArray(dataType.keyType) + val valueArray = valueArrayDatas(i).toObjectArray(dataType.valueType) + Array.copy(keyArray, 0, finalKeyArray, position, keyArray.length) + Array.copy(valueArray, 0, finalValueArray, position, valueArray.length) + position += keyArray.length + } + + new ArrayBasedMapData(new GenericArrayData(finalKeyArray), + new GenericArrayData(finalValueArray)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val mapCodes = children.map(_.genCode(ctx)) + val keyType = dataType.keyType + val valueType = dataType.valueType + val argsName = ctx.freshName("args") + val hasNullName = ctx.freshName("hasNull") + val mapDataClass = classOf[MapData].getName + val arrayBasedMapDataClass = classOf[ArrayBasedMapData].getName + val arrayDataClass = classOf[ArrayData].getName + + val init = + s""" + |$mapDataClass[] $argsName = new $mapDataClass[${mapCodes.size}]; + |boolean ${ev.isNull}, $hasNullName = false; + |$mapDataClass ${ev.value} = null; + """.stripMargin + + val assignments = mapCodes.zipWithIndex.map { case (m, i) => + s""" + |if (!$hasNullName) { + | ${m.code} + | $argsName[$i] = ${m.value}; + | if (${m.isNull}) { + | $hasNullName = true; + | } + |} + """.stripMargin + } + + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = assignments, + funcName = "getMapConcatInputs", + extraArguments = (s"$mapDataClass[]", argsName) :: ("boolean", hasNullName) :: Nil, + returnType = "boolean", + makeSplitFunction = body => + s""" + |$body + |return $hasNullName; + """.stripMargin, + foldFunctions = _.map(funcCall => s"$hasNullName = $funcCall;").mkString("\n") + ) + + val idxName = ctx.freshName("idx") + val numElementsName = ctx.freshName("numElems") + val finKeysName = ctx.freshName("finalKeys") + val finValsName = ctx.freshName("finalValues") + + val keyConcatenator = if (CodeGenerator.isPrimitiveType(keyType)) { + genCodeForPrimitiveArrays(ctx, keyType, false) + } else { + genCodeForNonPrimitiveArrays(ctx, keyType) + } + + val valueConcatenator = if (CodeGenerator.isPrimitiveType(valueType)) { + genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull) + } else { + genCodeForNonPrimitiveArrays(ctx, valueType) + } + + val keyArgsName = ctx.freshName("keyArgs") + val valArgsName = ctx.freshName("valArgs") + + val mapMerge = + s""" + |${ev.isNull} = $hasNullName; + |if (!${ev.isNull}) { + | $arrayDataClass[] $keyArgsName = new $arrayDataClass[${mapCodes.size}]; + | $arrayDataClass[] $valArgsName = new $arrayDataClass[${mapCodes.size}]; + | long $numElementsName = 0; + | for (int $idxName = 0; $idxName < $argsName.length; $idxName++) { + | $keyArgsName[$idxName] = $argsName[$idxName].keyArray(); + | $valArgsName[$idxName] = $argsName[$idxName].valueArray(); + | $numElementsName += $argsName[$idxName].numElements(); + | } + | if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful attempt to concat maps with " + + | $numElementsName + " elements due to exceeding the map size limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); + | } + | $arrayDataClass $finKeysName = $keyConcatenator.concat($keyArgsName, + | (int) $numElementsName); + | $arrayDataClass $finValsName = $valueConcatenator.concat($valArgsName, + | (int) $numElementsName); + | ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName); + |} + """.stripMargin + + ev.copy( + code = code""" + |$init + |$codes + |$mapMerge + """.stripMargin) + } + + private def genCodeForPrimitiveArrays( + ctx: CodegenContext, + elementType: DataType, + checkForNull: Boolean): String = { + val counter = ctx.freshName("counter") + val arrayData = ctx.freshName("arrayData") + val argsName = ctx.freshName("args") + val numElemName = ctx.freshName("numElements") + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + + val setterCode1 = + s""" + |$arrayData.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")} + |);""".stripMargin + + val setterCode = if (checkForNull) { + s""" + |if ($argsName[y].isNullAt(z)) { + | $arrayData.setNullAt($counter); + |} else { + | $setterCode1 + |}""".stripMargin + } else { + setterCode1 + } + + s""" + |new Object() { + | public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) { + | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < $argsName[y].numElements(); z++) { + | $setterCode + | $counter++; + | } + | } + | return $arrayData; + | } + |}""".stripMargin.stripPrefix("\n") + } + + private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayData = ctx.freshName("arrayObjects") + val counter = ctx.freshName("counter") + val argsName = ctx.freshName("args") + val numElemName = ctx.freshName("numElements") + + s""" + |new Object() { + | public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) {; + | Object[] $arrayData = new Object[$numElemName]; + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < $argsName[y].numElements(); z++) { + | $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")}; + | $counter++; + | } + | } + | return new $genericArrayClass($arrayData); + | } + |}""".stripMargin.stripPrefix("\n") + } + + override def prettyName: String = "map_concat" +} + /** * Returns a map created from the given array of entries. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 496ee1d496a36..173c98af323b1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -98,6 +98,132 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapEntries(ms2), null) } + test("Map Concat") { + val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType, + valueContainsNull = false)) + val m1 = Literal.create(Map("c" -> "3", "a" -> "4"), MapType(StringType, StringType, + valueContainsNull = false)) + val m2 = Literal.create(Map("d" -> "4", "e" -> "5"), MapType(StringType, StringType)) + val m3 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) + val m4 = Literal.create(Map("a" -> null, "c" -> "3"), MapType(StringType, StringType)) + val m5 = Literal.create(Map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType)) + val m6 = Literal.create(Map("a" -> null, "c" -> 3), MapType(StringType, IntegerType)) + val m7 = Literal.create(Map(List(1, 2) -> 1, List(3, 4) -> 2), + MapType(ArrayType(IntegerType), IntegerType)) + val m8 = Literal.create(Map(List(5, 6) -> 3, List(1, 2) -> 4), + MapType(ArrayType(IntegerType), IntegerType)) + val m9 = Literal.create(Map(Map(1 -> 2, 3 -> 4) -> 1, Map(5 -> 6, 7 -> 8) -> 2), + MapType(MapType(IntegerType, IntegerType), IntegerType)) + val m10 = Literal.create(Map(Map(9 -> 10, 11 -> 12) -> 3, Map(1 -> 2, 3 -> 4) -> 4), + MapType(MapType(IntegerType, IntegerType), IntegerType)) + val m11 = Literal.create(Map(1 -> "1", 2 -> "2"), MapType(IntegerType, StringType, + valueContainsNull = false)) + val m12 = Literal.create(Map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType, + valueContainsNull = false)) + val mNull = Literal.create(null, MapType(StringType, StringType)) + + // overlapping maps + checkEvaluation(MapConcat(Seq(m0, m1)), + ( + Array("a", "b", "c", "a"), // keys + Array("1", "2", "3", "4") // values + ) + ) + + // maps with no overlap + checkEvaluation(MapConcat(Seq(m0, m2)), + Map("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5")) + + // 3 maps + checkEvaluation(MapConcat(Seq(m0, m1, m2)), + ( + Array("a", "b", "c", "a", "d", "e"), // keys + Array("1", "2", "3", "4", "4", "5") // values + ) + ) + + // null reference values + checkEvaluation(MapConcat(Seq(m3, m4)), + ( + Array("a", "b", "a", "c"), // keys + Array("1", "2", null, "3") // values + ) + ) + + // null primitive values + checkEvaluation(MapConcat(Seq(m5, m6)), + ( + Array("a", "b", "a", "c"), // keys + Array(1, 2, null, 3) // values + ) + ) + + // keys that are primitive + checkEvaluation(MapConcat(Seq(m11, m12)), + ( + Array(1, 2, 3, 4), // keys + Array("1", "2", "3", "4") // values + ) + ) + + // keys that are arrays, with overlap + checkEvaluation(MapConcat(Seq(m7, m8)), + ( + Array(List(1, 2), List(3, 4), List(5, 6), List(1, 2)), // keys + Array(1, 2, 3, 4) // values + ) + ) + + // keys that are maps, with overlap + checkEvaluation(MapConcat(Seq(m9, m10)), + ( + Array(Map(1 -> 2, 3 -> 4), Map(5 -> 6, 7 -> 8), Map(9 -> 10, 11 -> 12), + Map(1 -> 2, 3 -> 4)), // keys + Array(1, 2, 3, 4) // values + ) + ) + + // null map + checkEvaluation(MapConcat(Seq(m0, mNull)), null) + checkEvaluation(MapConcat(Seq(mNull, m0)), null) + checkEvaluation(MapConcat(Seq(mNull, mNull)), null) + checkEvaluation(MapConcat(Seq(mNull)), null) + + // single map + checkEvaluation(MapConcat(Seq(m0)), Map("a" -> "1", "b" -> "2")) + + // no map + checkEvaluation(MapConcat(Seq.empty), Map.empty) + + // force split expressions for input in generated code + val expectedKeys = Array.fill(65)(Seq("a", "b")).flatten ++ Array("d", "e") + val expectedValues = Array.fill(65)(Seq("1", "2")).flatten ++ Array("4", "5") + checkEvaluation(MapConcat( + Seq( + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m2 + )), + (expectedKeys, expectedValues)) + + // argument checking + assert(MapConcat(Seq(m0, m1)).checkInputDataTypes().isSuccess) + assert(MapConcat(Seq(m5, m6)).checkInputDataTypes().isSuccess) + assert(MapConcat(Seq(m0, m5)).checkInputDataTypes().isFailure) + assert(MapConcat(Seq(m0, Literal(12))).checkInputDataTypes().isFailure) + assert(MapConcat(Seq(m0, m1)).dataType.keyType == StringType) + assert(MapConcat(Seq(m0, m1)).dataType.valueType == StringType) + assert(!MapConcat(Seq(m0, m1)).dataType.valueContainsNull) + assert(MapConcat(Seq(m5, m6)).dataType.keyType == StringType) + assert(MapConcat(Seq(m5, m6)).dataType.valueType == IntegerType) + assert(MapConcat(Seq.empty).dataType.keyType == StringType) + assert(MapConcat(Seq.empty).dataType.valueType == StringType) + assert(MapConcat(Seq(m5, m6)).dataType.valueContainsNull) + assert(MapConcat(Seq(m6, m5)).dataType.valueContainsNull) + assert(!MapConcat(Seq(m1, m2)).nullable) + assert(MapConcat(Seq(m1, mNull)).nullable) + } + test("MapFromEntries") { def arrayType(keyType: DataType, valueType: DataType) : DataType = { ArrayType( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f2627e69939cd..89dbba10a6bf1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3627,6 +3627,14 @@ object functions { @scala.annotation.varargs def arrays_zip(e: Column*): Column = withExpr { ArraysZip(e.map(_.expr)) } + /** + * Returns the union of all the given maps. + * @group collection_funcs + * @since 2.4.0 + */ + @scala.annotation.varargs + def map_concat(cols: Column*): Column = withExpr { MapConcat(cols.map(_.expr)) } + ////////////////////////////////////////////////////////////////////////////////////////////// // Mask functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql new file mode 100644 index 0000000000000..fc26397b881b5 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql @@ -0,0 +1,94 @@ +CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( + map(true, false), map(false, true), + map(1Y, 2Y), map(3Y, 4Y), + map(1S, 2S), map(3S, 4S), + map(4, 6), map(7, 8), + map(6L, 7L), map(8L, 9L), + map(9223372036854775809, 9223372036854775808), map(9223372036854775808, 9223372036854775809), + map(1.0D, 2.0D), map(3.0D, 4.0D), + map(float(1.0D), float(2.0D)), map(float(3.0D), float(4.0D)), + map(date '2016-03-14', date '2016-03-13'), map(date '2016-03-12', date '2016-03-11'), + map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + map(timestamp '2016-11-11 20:54:00.000', timestamp '2016-11-09 20:54:00.000'), + map('a', 'b'), map('c', 'd'), + map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')), + map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)), + map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)), + map('a', 1), map('c', 2), + map(1, 'a'), map(2, 'c') +) AS various_maps ( + boolean_map1, boolean_map2, + tinyint_map1, tinyint_map2, + smallint_map1, smallint_map2, + int_map1, int_map2, + bigint_map1, bigint_map2, + decimal_map1, decimal_map2, + double_map1, double_map2, + float_map1, float_map2, + date_map1, date_map2, + timestamp_map1, + timestamp_map2, + string_map1, string_map2, + array_map1, array_map2, + struct_map1, struct_map2, + map_map1, map_map2, + string_int_map1, string_int_map2, + int_string_map1, int_string_map2 +); + +-- Concatenate maps of the same type +SELECT + map_concat(boolean_map1, boolean_map2) boolean_map, + map_concat(tinyint_map1, tinyint_map2) tinyint_map, + map_concat(smallint_map1, smallint_map2) smallint_map, + map_concat(int_map1, int_map2) int_map, + map_concat(bigint_map1, bigint_map2) bigint_map, + map_concat(decimal_map1, decimal_map2) decimal_map, + map_concat(float_map1, float_map2) float_map, + map_concat(double_map1, double_map2) double_map, + map_concat(date_map1, date_map2) date_map, + map_concat(timestamp_map1, timestamp_map2) timestamp_map, + map_concat(string_map1, string_map2) string_map, + map_concat(array_map1, array_map2) array_map, + map_concat(struct_map1, struct_map2) struct_map, + map_concat(map_map1, map_map2) map_map, + map_concat(string_int_map1, string_int_map2) string_int_map, + map_concat(int_string_map1, int_string_map2) int_string_map +FROM various_maps; + +-- Concatenate maps of different types +SELECT + map_concat(tinyint_map1, smallint_map2) ts_map, + map_concat(smallint_map1, int_map2) si_map, + map_concat(int_map1, bigint_map2) ib_map, + map_concat(decimal_map1, float_map2) df_map, + map_concat(string_map1, date_map2) std_map, + map_concat(timestamp_map1, string_map2) tst_map, + map_concat(string_map1, int_map2) sti_map, + map_concat(int_string_map1, tinyint_map2) istt_map +FROM various_maps; + +-- Concatenate map of incompatible types 1 +SELECT + map_concat(tinyint_map1, map_map2) tm_map +FROM various_maps; + +-- Concatenate map of incompatible types 2 +SELECT + map_concat(boolean_map1, int_map2) bi_map +FROM various_maps; + +-- Concatenate map of incompatible types 3 +SELECT + map_concat(int_map1, struct_map2) is_map +FROM various_maps; + +-- Concatenate map of incompatible types 4 +SELECT + map_concat(map_map1, array_map2) ma_map +FROM various_maps; + +-- Concatenate map of incompatible types 5 +SELECT + map_concat(map_map1, struct_map2) ms_map +FROM various_maps; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out new file mode 100644 index 0000000000000..d352b7284ae87 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out @@ -0,0 +1,143 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 8 + + +-- !query 0 +CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( + map(true, false), map(false, true), + map(1Y, 2Y), map(3Y, 4Y), + map(1S, 2S), map(3S, 4S), + map(4, 6), map(7, 8), + map(6L, 7L), map(8L, 9L), + map(9223372036854775809, 9223372036854775808), map(9223372036854775808, 9223372036854775809), + map(1.0D, 2.0D), map(3.0D, 4.0D), + map(float(1.0D), float(2.0D)), map(float(3.0D), float(4.0D)), + map(date '2016-03-14', date '2016-03-13'), map(date '2016-03-12', date '2016-03-11'), + map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + map(timestamp '2016-11-11 20:54:00.000', timestamp '2016-11-09 20:54:00.000'), + map('a', 'b'), map('c', 'd'), + map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')), + map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)), + map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)), + map('a', 1), map('c', 2), + map(1, 'a'), map(2, 'c') +) AS various_maps ( + boolean_map1, boolean_map2, + tinyint_map1, tinyint_map2, + smallint_map1, smallint_map2, + int_map1, int_map2, + bigint_map1, bigint_map2, + decimal_map1, decimal_map2, + double_map1, double_map2, + float_map1, float_map2, + date_map1, date_map2, + timestamp_map1, + timestamp_map2, + string_map1, string_map2, + array_map1, array_map2, + struct_map1, struct_map2, + map_map1, map_map2, + string_int_map1, string_int_map2, + int_string_map1, int_string_map2 +) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT + map_concat(boolean_map1, boolean_map2) boolean_map, + map_concat(tinyint_map1, tinyint_map2) tinyint_map, + map_concat(smallint_map1, smallint_map2) smallint_map, + map_concat(int_map1, int_map2) int_map, + map_concat(bigint_map1, bigint_map2) bigint_map, + map_concat(decimal_map1, decimal_map2) decimal_map, + map_concat(float_map1, float_map2) float_map, + map_concat(double_map1, double_map2) double_map, + map_concat(date_map1, date_map2) date_map, + map_concat(timestamp_map1, timestamp_map2) timestamp_map, + map_concat(string_map1, string_map2) string_map, + map_concat(array_map1, array_map2) array_map, + map_concat(struct_map1, struct_map2) struct_map, + map_concat(map_map1, map_map2) map_map, + map_concat(string_int_map1, string_int_map2) string_int_map, + map_concat(int_string_map1, int_string_map2) int_string_map +FROM various_maps +-- !query 1 schema +struct,tinyint_map:map,smallint_map:map,int_map:map,bigint_map:map,decimal_map:map,float_map:map,double_map:map,date_map:map,timestamp_map:map,string_map:map,array_map:map,array>,struct_map:map,struct>,map_map:map,map>,string_int_map:map,int_string_map:map> +-- !query 1 output +{false:true,true:false} {1:2,3:4} {1:2,3:4} {4:6,7:8} {6:7,8:9} {9223372036854775808:9223372036854775809,9223372036854775809:9223372036854775808} {1.0:2.0,3.0:4.0} {1.0:2.0,3.0:4.0} {2016-03-12:2016-03-11,2016-03-14:2016-03-13} {2016-11-11 20:54:00.0:2016-11-09 20:54:00.0,2016-11-15 20:54:00.0:2016-11-12 20:54:00.0} {"a":"b","c":"d"} {["a","b"]:["c","d"],["e"]:["f"]} {{"col1":"a","col2":1}:{"col1":"b","col2":2},{"col1":"c","col2":3}:{"col1":"d","col2":4}} {{"a":1}:{"b":2},{"c":3}:{"d":4}} {"a":1,"c":2} {1:"a",2:"c"} + + +-- !query 2 +SELECT + map_concat(tinyint_map1, smallint_map2) ts_map, + map_concat(smallint_map1, int_map2) si_map, + map_concat(int_map1, bigint_map2) ib_map, + map_concat(decimal_map1, float_map2) df_map, + map_concat(string_map1, date_map2) std_map, + map_concat(timestamp_map1, string_map2) tst_map, + map_concat(string_map1, int_map2) sti_map, + map_concat(int_string_map1, tinyint_map2) istt_map +FROM various_maps +-- !query 2 schema +struct,si_map:map,ib_map:map,df_map:map,std_map:map,tst_map:map,sti_map:map,istt_map:map> +-- !query 2 output +{1:2,3:4} {1:2,7:8} {4:6,8:9} {3.0:4.0,9.223372036854776E18:9.223372036854776E18} {"2016-03-12":"2016-03-11","a":"b"} {"2016-11-15 20:54:00":"2016-11-12 20:54:00","c":"d"} {"7":"8","a":"b"} {1:"a",3:"4"} + + +-- !query 3 +SELECT + map_concat(tinyint_map1, map_map2) tm_map +FROM various_maps +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`tinyint_map1`, various_maps.`map_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,map>]; line 2 pos 4 + + +-- !query 4 +SELECT + map_concat(boolean_map1, int_map2) bi_map +FROM various_maps +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`boolean_map1`, various_maps.`int_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map]; line 2 pos 4 + + +-- !query 5 +SELECT + map_concat(int_map1, struct_map2) is_map +FROM various_maps +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`int_map1`, various_maps.`struct_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,struct>]; line 2 pos 4 + + +-- !query 6 +SELECT + map_concat(map_map1, array_map2) ma_map +FROM various_maps +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`array_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map,map>, map,array>]; line 2 pos 4 + + +-- !query 7 +SELECT + map_concat(map_map1, struct_map2) ms_map +FROM various_maps +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`struct_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map,map>, map,struct>]; line 2 pos 4 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 4c28e2f1cd909..d60ed7a5ef0d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -657,6 +657,84 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected) } + test("map_concat function") { + val df1 = Seq( + (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 4 -> 400)), + (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 1 -> 400)), + (null, Map[Int, Int](3 -> 300, 4 -> 400)) + ).toDF("map1", "map2") + + val expected1a = Seq( + Row(Map(1 -> 100, 2 -> 200, 3 -> 300, 4 -> 400)), + Row(Map(1 -> 400, 2 -> 200, 3 -> 300)), + Row(null) + ) + + checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1a) + checkAnswer(df1.select(map_concat('map1, 'map2)), expected1a) + + val expected1b = Seq( + Row(Map(1 -> 100, 2 -> 200)), + Row(Map(1 -> 100, 2 -> 200)), + Row(null) + ) + + checkAnswer(df1.selectExpr("map_concat(map1)"), expected1b) + checkAnswer(df1.select(map_concat('map1)), expected1b) + + val df2 = Seq( + ( + Map[Array[Int], Int](Array(1) -> 100, Array(2) -> 200), + Map[String, Int]("3" -> 300, "4" -> 400) + ) + ).toDF("map1", "map2") + + val expected2 = Seq(Row(Map())) + + checkAnswer(df2.selectExpr("map_concat()"), expected2) + checkAnswer(df2.select(map_concat()), expected2) + + val df3 = { + val schema = StructType( + StructField("map1", MapType(StringType, IntegerType, true), false) :: + StructField("map2", MapType(StringType, IntegerType, false), false) :: Nil + ) + val data = Seq( + Row(Map[String, Any]("a" -> 1, "b" -> null), Map[String, Any]("c" -> 3, "d" -> 4)), + Row(Map[String, Any]("a" -> 1, "b" -> 2), Map[String, Any]("c" -> 3, "d" -> 4)) + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + + val expected3 = Seq( + Row(Map[String, Any]("a" -> 1, "b" -> null, "c" -> 3, "d" -> 4)), + Row(Map[String, Any]("a" -> 1, "b" -> 2, "c" -> 3, "d" -> 4)) + ) + + checkAnswer(df3.selectExpr("map_concat(map1, map2)"), expected3) + checkAnswer(df3.select(map_concat('map1, 'map2)), expected3) + + val expectedMessage1 = "input to function map_concat should all be the same type" + + assert(intercept[AnalysisException] { + df2.selectExpr("map_concat(map1, map2)").collect() + }.getMessage().contains(expectedMessage1)) + + assert(intercept[AnalysisException] { + df2.select(map_concat('map1, 'map2)).collect() + }.getMessage().contains(expectedMessage1)) + + val expectedMessage2 = "input to function map_concat should all be of type map" + + assert(intercept[AnalysisException] { + df2.selectExpr("map_concat(map1, 12)").collect() + }.getMessage().contains(expectedMessage2)) + + assert(intercept[AnalysisException] { + df2.select(map_concat('map1, lit(12))).collect() + }.getMessage().contains(expectedMessage2)) + } + test("map_from_entries function") { def dummyFilter(c: Column): Column = c.isNull || c.isNotNull val oneRowDF = Seq(3215).toDF("i") From 1bd3d61f4191767a94b71b42f4d00706b703e84f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 9 Jul 2018 22:59:05 +0800 Subject: [PATCH 1083/2461] [SPARK-24268][SQL] Use datatype.simpleString in error messages ## What changes were proposed in this pull request? SPARK-22893 tried to unify error messages about dataTypes. Unfortunately, still many places were missing the `simpleString` method in other to have the same representation everywhere. The PR unified the messages using alway the simpleString representation of the dataTypes in the messages. ## How was this patch tested? existing/modified UTs Author: Marco Gaido Closes #21321 from mgaido91/SPARK-24268. --- .../spark/sql/kafka010/KafkaWriteTask.scala | 6 +++--- .../apache/spark/sql/kafka010/KafkaWriter.scala | 6 +++--- .../sql/kafka010/KafkaContinuousSinkSuite.scala | 4 ++-- .../spark/sql/kafka010/KafkaSinkSuite.scala | 4 ++-- .../scala/org/apache/spark/ml/feature/DCT.scala | 3 ++- .../apache/spark/ml/feature/FeatureHasher.scala | 5 +++-- .../org/apache/spark/ml/feature/HashingTF.scala | 2 +- .../apache/spark/ml/feature/Interaction.scala | 3 ++- .../org/apache/spark/ml/feature/NGram.scala | 2 +- .../apache/spark/ml/feature/OneHotEncoder.scala | 3 ++- .../org/apache/spark/ml/feature/RFormula.scala | 2 +- .../spark/ml/feature/StopWordsRemover.scala | 4 ++-- .../org/apache/spark/ml/feature/Tokenizer.scala | 3 ++- .../spark/ml/feature/VectorAssembler.scala | 2 +- .../scala/org/apache/spark/ml/fpm/FPGrowth.scala | 2 +- .../org/apache/spark/ml/util/SchemaUtils.scala | 11 +++++++---- .../BinaryClassificationEvaluatorSuite.scala | 4 ++-- .../apache/spark/ml/feature/RFormulaSuite.scala | 2 +- .../spark/ml/feature/VectorAssemblerSuite.scala | 6 +++--- .../spark/ml/recommendation/ALSSuite.scala | 2 +- .../regression/AFTSurvivalRegressionSuite.scala | 2 +- .../apache/spark/ml/util/MLTestingUtils.scala | 6 +++--- .../expressions/complexTypeCreator.scala | 4 ++-- .../catalyst/expressions/jsonExpressions.scala | 2 +- .../catalyst/expressions/stringExpressions.scala | 5 +++-- .../sql/catalyst/json/JacksonGenerator.scala | 4 ++-- .../spark/sql/catalyst/json/JacksonParser.scala | 6 ++++-- .../sql/catalyst/json/JsonInferSchema.scala | 6 ++++-- .../spark/sql/catalyst/util/TypeUtils.scala | 5 +++-- .../spark/sql/types/AbstractDataType.scala | 9 +++++---- .../org/apache/spark/sql/types/ArrayType.scala | 5 +++-- .../org/apache/spark/sql/types/DecimalType.scala | 3 ++- .../org/apache/spark/sql/types/ObjectType.scala | 3 ++- .../org/apache/spark/sql/types/StructType.scala | 5 +++-- .../catalyst/analysis/AnalysisErrorSuite.scala | 2 +- .../analysis/ExpressionTypeCheckingSuite.scala | 16 ++++++++-------- .../catalyst/parser/ExpressionParserSuite.scala | 2 +- .../apache/spark/sql/types/DataTypeSuite.scala | 2 +- .../parquet/VectorizedColumnReader.java | 2 +- .../spark/sql/RelationalGroupedDataset.scala | 2 +- .../spark/sql/execution/arrow/ArrowUtils.scala | 3 ++- .../execution/datasources/orc/OrcFilters.scala | 2 +- .../parquet/ParquetSchemaConverter.scala | 2 +- .../spark/sql/execution/stat/StatFunctions.scala | 2 +- .../sql-tests/results/json-functions.sql.out | 4 ++-- .../resources/sql-tests/results/literals.sql.out | 6 +++--- .../datasources/parquet/ParquetSchemaSuite.scala | 4 ++-- .../sql/hive/execution/HiveTableScanExec.scala | 6 +++--- 48 files changed, 108 insertions(+), 88 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index d90630a8adc93..59a84706d4f55 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -110,7 +110,7 @@ private[kafka010] abstract class KafkaRowWriter( case t => throw new IllegalStateException(s"${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + s"attribute unsupported type $t. ${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + - "must be a StringType") + s"must be a ${StringType.simpleString}") } val keyExpression = inputSchema.find(_.name == KafkaWriter.KEY_ATTRIBUTE_NAME) .getOrElse(Literal(null, BinaryType)) @@ -118,7 +118,7 @@ private[kafka010] abstract class KafkaRowWriter( case StringType | BinaryType => // good case t => throw new IllegalStateException(s"${KafkaWriter.KEY_ATTRIBUTE_NAME} " + - s"attribute unsupported type $t") + s"attribute unsupported type ${t.simpleString}") } val valueExpression = inputSchema .find(_.name == KafkaWriter.VALUE_ATTRIBUTE_NAME).getOrElse( @@ -129,7 +129,7 @@ private[kafka010] abstract class KafkaRowWriter( case StringType | BinaryType => // good case t => throw new IllegalStateException(s"${KafkaWriter.VALUE_ATTRIBUTE_NAME} " + - s"attribute unsupported type $t") + s"attribute unsupported type ${t.simpleString}") } UnsafeProjection.create( Seq(topicExpression, Cast(keyExpression, BinaryType), diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 15cd44812cb0c..3ec26e9edd353 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -57,7 +57,7 @@ private[kafka010] object KafkaWriter extends Logging { ).dataType match { case StringType => // good case _ => - throw new AnalysisException(s"Topic type must be a String") + throw new AnalysisException(s"Topic type must be a ${StringType.simpleString}") } schema.find(_.name == KEY_ATTRIBUTE_NAME).getOrElse( Literal(null, StringType) @@ -65,7 +65,7 @@ private[kafka010] object KafkaWriter extends Logging { case StringType | BinaryType => // good case _ => throw new AnalysisException(s"$KEY_ATTRIBUTE_NAME attribute type " + - s"must be a String or BinaryType") + s"must be a ${StringType.simpleString} or ${BinaryType.simpleString}") } schema.find(_.name == VALUE_ATTRIBUTE_NAME).getOrElse( throw new AnalysisException(s"Required attribute '$VALUE_ATTRIBUTE_NAME' not found") @@ -73,7 +73,7 @@ private[kafka010] object KafkaWriter extends Logging { case StringType | BinaryType => // good case _ => throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " + - s"must be a String or BinaryType") + s"must be a ${StringType.simpleString} or ${BinaryType.simpleString}") } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala index ddfc0c1a4be2d..0e1492ac27449 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -314,7 +314,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "value attribute type must be a string or binarytype")) + "value attribute type must be a string or binary")) try { /* key field wrong type */ @@ -330,7 +330,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "key attribute type must be a string or binarytype")) + "key attribute type must be a string or binary")) } test("streaming - write to non-existing topic") { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 7079ac6453ffc..70ffd7dee89d7 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -303,7 +303,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "value attribute type must be a string or binarytype")) + "value attribute type must be a string or binary")) try { ex = intercept[StreamingQueryException] { @@ -318,7 +318,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "key attribute type must be a string or binarytype")) + "key attribute type must be a string or binary")) } test("streaming - write to non-existing topic") { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index 682787a830113..1eac1d1613d2a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -69,7 +69,8 @@ class DCT @Since("1.5.0") (@Since("1.5.0") override val uid: String) } override protected def validateInputType(inputType: DataType): Unit = { - require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.") + require(inputType.isInstanceOf[VectorUDT], + s"Input type must be ${(new VectorUDT).simpleString} but got ${inputType.simpleString}.") } override protected def outputDataType: DataType = new VectorUDT diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index d67e4819b161a..405ea467cb02a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -208,8 +208,9 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme require(dataType.isInstanceOf[NumericType] || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[BooleanType], - s"FeatureHasher requires columns to be of NumericType, BooleanType or StringType. " + - s"Column $fieldName was $dataType") + s"FeatureHasher requires columns to be of ${NumericType.simpleString}, " + + s"${BooleanType.simpleString} or ${StringType.simpleString}. " + + s"Column $fieldName was ${dataType.simpleString}") } val attrGroup = new AttributeGroup($(outputCol), $(numFeatures)) SchemaUtils.appendColumn(schema, attrGroup.toStructField()) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index db432b6fefaff..403b0a813aedd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -104,7 +104,7 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String) override def transformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[ArrayType], - s"The input column must be ArrayType, but got $inputType.") + s"The input column must be ${ArrayType.simpleString}, but got ${inputType.simpleString}.") val attrGroup = new AttributeGroup($(outputCol), $(numFeatures)) SchemaUtils.appendColumn(schema, attrGroup.toStructField()) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 4ff1d0ef356f3..5e01ec30bb2eb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -261,7 +261,8 @@ private[ml] class FeatureEncoder(numFeatures: Array[Int]) extends Serializable { */ def foreachNonzeroOutput(value: Any, f: (Int, Double) => Unit): Unit = value match { case d: Double => - assert(numFeatures.length == 1, "DoubleType columns should only contain one feature.") + assert(numFeatures.length == 1, + s"${DoubleType.simpleString} columns should only contain one feature.") val numOutputCols = numFeatures.head if (numOutputCols > 1) { assert( diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index c8760f9dc178f..6445360f7fd90 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -65,7 +65,7 @@ class NGram @Since("1.5.0") (@Since("1.5.0") override val uid: String) override protected def validateInputType(inputType: DataType): Unit = { require(inputType.sameType(ArrayType(StringType)), - s"Input type must be ArrayType(StringType) but got $inputType.") + s"Input type must be ${ArrayType(StringType).simpleString} but got $inputType.") } override protected def outputDataType: DataType = new ArrayType(StringType, false) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 5ab6c2dde667a..24045f0448c81 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -85,7 +85,8 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e val inputFields = schema.fields require(schema(inputColName).dataType.isInstanceOf[NumericType], - s"Input column must be of type NumericType but got ${schema(inputColName).dataType}") + s"Input column must be of type ${NumericType.simpleString} but got " + + schema(inputColName).dataType.simpleString) require(!inputFields.exists(_.name == outputColName), s"Output column $outputColName already exists.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 55e595eee6ffb..346e1823f00b8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -394,7 +394,7 @@ class RFormulaModel private[feature]( require(!columnNames.contains($(featuresCol)), "Features column already exists.") require( !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType.isInstanceOf[NumericType], - "Label column already exists and is not of type NumericType.") + s"Label column already exists and is not of type ${NumericType.simpleString}.") } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 0f946dd2e015b..ead75d5b8def3 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -131,8 +131,8 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType - require(inputType.sameType(ArrayType(StringType)), - s"Input type must be ArrayType(StringType) but got $inputType.") + require(inputType.sameType(ArrayType(StringType)), "Input type must be " + + s"${ArrayType(StringType).simpleString} but got ${inputType.simpleString}.") SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index cfaf6c0e610b3..5132f63af1796 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -40,7 +40,8 @@ class Tokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) } override protected def validateInputType(inputType: DataType): Unit = { - require(inputType == StringType, s"Input type must be string type but got $inputType.") + require(inputType == StringType, + s"Input type must be ${StringType.simpleString} type but got ${inputType.simpleString}.") } override protected def outputDataType: DataType = new ArrayType(StringType, true) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 4061154b39c14..ed3b36ee5ab2f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -162,7 +162,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) schema(name).dataType match { case _: NumericType | BooleanType => None case t if t.isInstanceOf[VectorUDT] => None - case other => Some(s"Data type $other of column $name is not supported.") + case other => Some(s"Data type ${other.simpleString} of column $name is not supported.") } } if (incorrectColumns.nonEmpty) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index d7fbe28ae7a64..51b88b3117f4e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -106,7 +106,7 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { protected def validateAndTransformSchema(schema: StructType): StructType = { val inputType = schema($(itemsCol)).dataType require(inputType.isInstanceOf[ArrayType], - s"The input column must be ArrayType, but got $inputType.") + s"The input column must be ${ArrayType.simpleString}, but got ${inputType.simpleString}.") SchemaUtils.appendColumn(schema, $(predictionCol), schema($(itemsCol)).dataType) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index d9a3f85ef9a24..b500582074398 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -41,7 +41,8 @@ private[spark] object SchemaUtils { val actualDataType = schema(colName).dataType val message = if (msg != null && msg.trim.length > 0) " " + msg else "" require(actualDataType.equals(dataType), - s"Column $colName must be of type $dataType but was actually $actualDataType.$message") + s"Column $colName must be of type ${dataType.simpleString} but was actually " + + s"${actualDataType.simpleString}.$message") } /** @@ -58,7 +59,8 @@ private[spark] object SchemaUtils { val message = if (msg != null && msg.trim.length > 0) " " + msg else "" require(dataTypes.exists(actualDataType.equals), s"Column $colName must be of type equal to one of the following types: " + - s"${dataTypes.mkString("[", ", ", "]")} but was actually of type $actualDataType.$message") + s"${dataTypes.map(_.simpleString).mkString("[", ", ", "]")} but was actually of type " + + s"${actualDataType.simpleString}.$message") } /** @@ -71,8 +73,9 @@ private[spark] object SchemaUtils { msg: String = ""): Unit = { val actualDataType = schema(colName).dataType val message = if (msg != null && msg.trim.length > 0) " " + msg else "" - require(actualDataType.isInstanceOf[NumericType], s"Column $colName must be of type " + - s"NumericType but was actually of type $actualDataType.$message") + require(actualDataType.isInstanceOf[NumericType], + s"Column $colName must be of type ${NumericType.simpleString} but was actually of type " + + s"${actualDataType.simpleString}.$message") } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index ede284712b1c0..2b0909acf69c3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -67,8 +67,8 @@ class BinaryClassificationEvaluatorSuite evaluator.evaluate(stringDF) } assert(thrown.getMessage.replace("\n", "") contains "Column rawPrediction must be of type " + - "equal to one of the following types: [DoubleType, ") - assert(thrown.getMessage.replace("\n", "") contains "but was actually of type StringType.") + "equal to one of the following types: [double, ") + assert(thrown.getMessage.replace("\n", "") contains "but was actually of type string.") } test("should support all NumericType labels and not support other types") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index a250331efeb1d..0de6528c4cf22 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -105,7 +105,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { testTransformerByInterceptingException[(Int, Boolean)]( original, model, - "Label column already exists and is not of type NumericType.", + "Label column already exists and is not of type numeric.", "x") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 91fb24a268b8c..ed15a1d88a269 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -99,9 +99,9 @@ class VectorAssemblerSuite assembler.transform(df) } assert(thrown.getMessage contains - "Data type StringType of column a is not supported.\n" + - "Data type StringType of column b is not supported.\n" + - "Data type StringType of column c is not supported.") + "Data type string of column a is not supported.\n" + + "Data type string of column b is not supported.\n" + + "Data type string of column c is not supported.") } test("ML attributes") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index e3dfe2faf5698..65bee4edc4965 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -612,7 +612,7 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { estimator.fit(strDF) } assert(thrown.getMessage.contains( - s"$column must be of type NumericType but was actually of type StringType")) + s"$column must be of type numeric but was actually of type string")) } private class NumericTypeWithEncoder[A](val numericType: NumericType) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 4e4ff71c9de90..6cc73e040e82c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -385,7 +385,7 @@ class AFTSurvivalRegressionSuite extends MLTest with DefaultReadWriteTest { aft.fit(dfWithStringCensors) } assert(thrown.getMessage.contains( - "Column censor must be of type NumericType but was actually of type StringType")) + "Column censor must be of type numeric but was actually of type string")) } test("numerical stability of standardization") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index 5e72b4d864c1d..91a8b14625a86 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -74,7 +74,7 @@ object MLTestingUtils extends SparkFunSuite { estimator.fit(dfWithStringLabels) } assert(thrown.getMessage.contains( - "Column label must be of type NumericType but was actually of type StringType")) + "Column label must be of type numeric but was actually of type string")) estimator match { case weighted: Estimator[M] with HasWeightCol => @@ -86,7 +86,7 @@ object MLTestingUtils extends SparkFunSuite { weighted.fit(dfWithStringWeights) } assert(thrown.getMessage.contains( - "Column weight must be of type NumericType but was actually of type StringType")) + "Column weight must be of type numeric but was actually of type string")) case _ => } } @@ -104,7 +104,7 @@ object MLTestingUtils extends SparkFunSuite { evaluator.evaluate(dfWithStringLabels) } assert(thrown.getMessage.contains( - "Column label must be of type NumericType but was actually of type StringType")) + "Column label must be of type numeric but was actually of type string")) } def genClassifDFWithNumericLabelCol( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 0a5f8a907b50a..cf0e3765de80f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -385,8 +385,8 @@ trait CreateNamedStructLike extends Expression { val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( - "Only foldable StringType expressions are allowed to appear at odd position, got:" + - s" ${invalidNames.mkString(",")}") + s"Only foldable ${StringType.simpleString} expressions are allowed to appear at odd" + + s" position, got: ${invalidNames.mkString(",")}") } else if (!names.contains(null)) { TypeCheckResult.TypeCheckSuccess } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 8cd86053a01c7..1bcf11d7ee737 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -796,7 +796,7 @@ object JsonExprUtils { } case m: CreateMap => throw new AnalysisException( - s"A type of keys and values in map() must be string, but got ${m.dataType}") + s"A type of keys and values in map() must be string, but got ${m.dataType.simpleString}") case _ => throw new AnalysisException("Must use a map() function for options") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index bedad7da334ae..70dd4df9df511 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -222,11 +222,12 @@ case class Elt(children: Seq[Expression]) extends Expression { val (indexType, inputTypes) = (indexExpr.dataType, inputExprs.map(_.dataType)) if (indexType != IntegerType) { return TypeCheckResult.TypeCheckFailure(s"first input to function $prettyName should " + - s"have IntegerType, but it's $indexType") + s"have ${IntegerType.simpleString}, but it's ${indexType.simpleString}") } if (inputTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { return TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName should have StringType or BinaryType, but it's " + + s"input to function $prettyName should have ${StringType.simpleString} or " + + s"${BinaryType.simpleString}, but it's " + inputTypes.map(_.simpleString).mkString("[", ", ", "]")) } TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index 9c413de752a8c..00086abbefd08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -45,8 +45,8 @@ private[sql] class JacksonGenerator( // `JackGenerator` can only be initialized with a `StructType` or a `MapType`. require(dataType.isInstanceOf[StructType] || dataType.isInstanceOf[MapType], - "JacksonGenerator only supports to be initialized with a StructType " + - s"or MapType but got ${dataType.simpleString}") + s"JacksonGenerator only supports to be initialized with a ${StructType.simpleString} " + + s"or ${MapType.simpleString} but got ${dataType.simpleString}") // `ValueWriter`s for all fields of the schema private lazy val rootFieldWriters: Array[ValueWriter] = dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index c3a4ca8f64bf6..aa1691bb40d93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -143,7 +143,8 @@ class JacksonParser( case "NaN" => Float.NaN case "Infinity" => Float.PositiveInfinity case "-Infinity" => Float.NegativeInfinity - case other => throw new RuntimeException(s"Cannot parse $other as FloatType.") + case other => throw new RuntimeException( + s"Cannot parse $other as ${FloatType.simpleString}.") } } @@ -158,7 +159,8 @@ class JacksonParser( case "NaN" => Double.NaN case "Infinity" => Double.PositiveInfinity case "-Infinity" => Double.NegativeInfinity - case other => throw new RuntimeException(s"Cannot parse $other as DoubleType.") + case other => + throw new RuntimeException(s"Cannot parse $other as ${DoubleType.simpleString}.") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index 491ca005877f8..5f70e062d46c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -294,8 +294,10 @@ private[sql] object JsonInferSchema { // Both fields1 and fields2 should be sorted by name, since inferField performs sorting. // Therefore, we can take advantage of the fact that we're merging sorted lists and skip // building a hash map or performing additional sorting. - assert(isSorted(fields1), s"StructType's fields were not sorted: ${fields1.toSeq}") - assert(isSorted(fields2), s"StructType's fields were not sorted: ${fields2.toSeq}") + assert(isSorted(fields1), + s"${StructType.simpleString}'s fields were not sorted: ${fields1.toSeq}") + assert(isSorted(fields2), + s"${StructType.simpleString}'s fields were not sorted: ${fields2.toSeq}") val newFields = new java.util.ArrayList[StructField]() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 1dcda49a3af6a..a9aaf617f7837 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -29,7 +29,7 @@ object TypeUtils { if (dt.isInstanceOf[NumericType] || dt == NullType) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not $dt") + TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not ${dt.simpleString}") } } @@ -37,7 +37,8 @@ object TypeUtils { if (RowOrdering.isOrderable(dt)) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure(s"$caller does not support ordering on type $dt") + TypeCheckResult.TypeCheckFailure( + s"$caller does not support ordering on type ${dt.simpleString}") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 3041f44b116ea..c43cc748655e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -145,7 +145,7 @@ abstract class NumericType extends AtomicType { } -private[sql] object NumericType extends AbstractDataType { +private[spark] object NumericType extends AbstractDataType { /** * Enables matching against NumericType for expressions: * {{{ @@ -155,11 +155,12 @@ private[sql] object NumericType extends AbstractDataType { */ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] - override private[sql] def defaultConcreteType: DataType = DoubleType + override private[spark] def defaultConcreteType: DataType = DoubleType - override private[sql] def simpleString: String = "numeric" + override private[spark] def simpleString: String = "numeric" - override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType] + override private[spark] def acceptsType(other: DataType): Boolean = + other.isInstanceOf[NumericType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 38c40482fa4d9..8f118624f6d2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -42,7 +42,7 @@ object ArrayType extends AbstractDataType { other.isInstanceOf[ArrayType] } - override private[sql] def simpleString: String = "array" + override private[spark] def simpleString: String = "array" } /** @@ -103,7 +103,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT case a : ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] case other => - throw new IllegalArgumentException(s"Type $other does not support ordered operations") + throw new IllegalArgumentException( + s"Type ${other.simpleString} does not support ordered operations") } def compare(x: ArrayData, y: ArrayData): Int = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index dbf51c398fa47..f780ffd46a876 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -48,7 +48,8 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { } if (precision > DecimalType.MAX_PRECISION) { - throw new AnalysisException(s"DecimalType can only support precision up to 38") + throw new AnalysisException( + s"${DecimalType.simpleString} can only support precision up to ${DecimalType.MAX_PRECISION}") } // default constructor for Java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala index 2d49fe076786a..203e85e1c99bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -24,7 +24,8 @@ import org.apache.spark.annotation.InterfaceStability @InterfaceStability.Evolving object ObjectType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = - throw new UnsupportedOperationException("null literals can't be casted to ObjectType") + throw new UnsupportedOperationException( + s"null literals can't be casted to ${ObjectType.simpleString}") override private[sql] def acceptsType(other: DataType): Boolean = other match { case ObjectType(_) => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 362676b252126..0e69ef8ba73e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -426,7 +426,7 @@ object StructType extends AbstractDataType { private[sql] def fromString(raw: String): StructType = { Try(DataType.fromJson(raw)).getOrElse(LegacyTypeStringParser.parse(raw)) match { case t: StructType => t - case _ => throw new RuntimeException(s"Failed parsing StructType: $raw") + case _ => throw new RuntimeException(s"Failed parsing ${StructType.simpleString}: $raw") } } @@ -528,7 +528,8 @@ object StructType extends AbstractDataType { leftType case _ => - throw new SparkException(s"Failed to merge incompatible data types $left and $right") + throw new SparkException(s"Failed to merge incompatible data types ${left.simpleString} " + + s"and ${right.simpleString}") } private[sql] def fieldsMap(fields: Array[StructField]): Map[String, StructField] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 5d2f8e735e3d4..5e503be416a1f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -514,7 +514,7 @@ class AnalysisErrorSuite extends AnalysisTest { right, joinType = Cross, condition = Some('b === 'd)) - assertAnalysisError(plan2, "EqualTo does not support ordering on type MapType" :: Nil) + assertAnalysisError(plan2, "EqualTo does not support ordering on type map" :: Nil) } test("PredicateSubQuery is used outside of a filter") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 36714bd631b0e..8eec14842c7e7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -109,17 +109,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) - assertError(EqualTo('mapField, 'mapField), "EqualTo does not support ordering on type MapType") + assertError(EqualTo('mapField, 'mapField), "EqualTo does not support ordering on type map") assertError(EqualNullSafe('mapField, 'mapField), - "EqualNullSafe does not support ordering on type MapType") + "EqualNullSafe does not support ordering on type map") assertError(LessThan('mapField, 'mapField), - "LessThan does not support ordering on type MapType") + "LessThan does not support ordering on type map") assertError(LessThanOrEqual('mapField, 'mapField), - "LessThanOrEqual does not support ordering on type MapType") + "LessThanOrEqual does not support ordering on type map") assertError(GreaterThan('mapField, 'mapField), - "GreaterThan does not support ordering on type MapType") + "GreaterThan does not support ordering on type map") assertError(GreaterThanOrEqual('mapField, 'mapField), - "GreaterThanOrEqual does not support ordering on type MapType") + "GreaterThanOrEqual does not support ordering on type map") assertError(If('intField, 'stringField, 'stringField), "type of predicate expression in If should be boolean") @@ -169,10 +169,10 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments") assertError( CreateNamedStruct(Seq(1, "a", "b", 2.0)), - "Only foldable StringType expressions are allowed to appear at odd position") + "Only foldable string expressions are allowed to appear at odd position") assertError( CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), - "Only foldable StringType expressions are allowed to appear at odd position") + "Only foldable string expressions are allowed to appear at odd position") assertError( CreateNamedStruct(Seq(Literal.create(null, StringType), "a")), "Field name should not be null") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index cb8a1fecb80a7..b4d422d8506fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -469,7 +469,7 @@ class ExpressionParserSuite extends PlanTest { Literal(BigDecimal("90912830918230182310293801923652346786").underlying())) assertEqual("123.0E-28BD", Literal(BigDecimal("123.0E-28").underlying())) assertEqual("123.08BD", Literal(BigDecimal("123.08").underlying())) - intercept("1.20E-38BD", "DecimalType can only support precision up to 38") + intercept("1.20E-38BD", "decimal can only support precision up to 38") } test("strings") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 5a86f4055dce7..fccd057e577d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -154,7 +154,7 @@ class DataTypeSuite extends SparkFunSuite { left.merge(right) }.getMessage assert(message.equals("Failed to merge fields 'b' and 'b'. " + - "Failed to merge incompatible data types FloatType and LongType")) + "Failed to merge incompatible data types float and bigint")) } test("existsRecursively") { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index d5969b55eef96..060e2ec068053 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -244,7 +244,7 @@ private SchemaColumnConvertNotSupportedException constructConvertNotSupportedExc return new SchemaColumnConvertNotSupportedException( Arrays.toString(descriptor.getPath()), descriptor.getType().toString(), - column.dataType().toString()); + column.dataType().simpleString()); } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index c6449cd5a16b0..b068493f2dd17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -452,7 +452,7 @@ class RelationalGroupedDataset protected[sql]( require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, "Must pass a grouped map udf") require(expr.dataType.isInstanceOf[StructType], - "The returnType of the udf must be a StructType") + s"The returnType of the udf must be a ${StructType.simpleString}") val groupingNamedExpressions = groupingExprs.map { case ne: NamedExpression => ne diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index 93c8127681b3e..1274abffaa116 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -47,7 +47,8 @@ object ArrowUtils { case DateType => new ArrowType.Date(DateUnit.DAY) case TimestampType => if (timeZoneId == null) { - throw new UnsupportedOperationException("TimestampType must supply timeZoneId parameter") + throw new UnsupportedOperationException( + s"${TimestampType.simpleString} must supply timeZoneId parameter") } else { new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 4f44ae4fa1d71..c90328f7ad43f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -98,7 +98,7 @@ private[orc] object OrcFilters { case DateType => PredicateLeaf.Type.DATE case TimestampType => PredicateLeaf.Type.TIMESTAMP case _: DecimalType => PredicateLeaf.Type.DECIMAL - case _ => throw new UnsupportedOperationException(s"DataType: $dataType") + case _ => throw new UnsupportedOperationException(s"DataType: ${dataType.simpleString}") } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index c61be077d309f..18decad3f62f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -555,7 +555,7 @@ class SparkToParquetSchemaConverter( convertField(field.copy(dataType = udt.sqlType)) case _ => - throw new AnalysisException(s"Unsupported data type $field.dataType") + throw new AnalysisException(s"Unsupported data type ${field.dataType.simpleString}") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 685d5841ab551..f772a3336d6af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -157,7 +157,7 @@ object StatFunctions extends Logging { cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) => require(data.nonEmpty, s"Couldn't find column with name $name") require(data.get.dataType.isInstanceOf[NumericType], s"Currently $functionName calculation " + - s"for columns with dataType ${data.get.dataType} not supported.") + s"for columns with dataType ${data.get.dataType.simpleString} not supported.") } val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) df.select(columns: _*).queryExecution.toRdd.treeAggregate(new CovarianceCounter)( diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 3d49323751a10..827931d74138d 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -120,7 +120,7 @@ select to_json(named_struct('a', 1, 'b', 2), map('mode', 1)) struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 +A type of keys and values in map() must be string, but got map;; line 1 pos 7 -- !query 12 @@ -216,7 +216,7 @@ select from_json('{"a":1}', 'a INT', map('mode', 1)) struct<> -- !query 20 output org.apache.spark.sql.AnalysisException -A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 +A type of keys and values in map() must be string, but got map;; line 1 pos 7 -- !query 21 diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index b8c91dc8b59a4..7f301614523b2 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -147,7 +147,7 @@ struct<> -- !query 15 output org.apache.spark.sql.catalyst.parser.ParseException -DecimalType can only support precision up to 38 +decimal can only support precision up to 38 == SQL == select 1234567890123456789012345678901234567890 @@ -159,7 +159,7 @@ struct<> -- !query 16 output org.apache.spark.sql.catalyst.parser.ParseException -DecimalType can only support precision up to 38 +decimal can only support precision up to 38 == SQL == select 1234567890123456789012345678901234567890.0 @@ -379,7 +379,7 @@ struct<> -- !query 39 output org.apache.spark.sql.catalyst.parser.ParseException -DecimalType can only support precision up to 38(line 1, pos 7) +decimal can only support precision up to 38(line 1, pos 7) == SQL == select 1.20E-38BD diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 9d3dfae348beb..368e52cfbda9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -430,9 +430,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { val col = spark.read.parquet(file).schema.fields.filter(_.name.equals("a")) assert(col.length == 1) if (col(0).dataType == StringType) { - assert(errMsg.contains("Column: [a], Expected: IntegerType, Found: BINARY")) + assert(errMsg.contains("Column: [a], Expected: int, Found: BINARY")) } else { - assert(errMsg.endsWith("Column: [a], Expected: StringType, Found: INT32")) + assert(errMsg.endsWith("Column: [a], Expected: string, Found: INT32")) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 7dcaf170f9693..40be4e8c1f5be 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -78,9 +78,9 @@ case class HiveTableScanExec( // Bind all partition key attribute references in the partition pruning predicate for later // evaluation. private lazy val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => - require( - pred.dataType == BooleanType, - s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") + require(pred.dataType == BooleanType, + s"Data type of predicate $pred must be ${BooleanType.simpleString} rather than " + + s"${pred.dataType.simpleString}.") BindReferences.bindReference(pred, relation.partitionCols) } From aec966b05e8df9d459dae88d091de1923e50e2dc Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 9 Jul 2018 14:24:23 -0700 Subject: [PATCH 1084/2461] Revert "[SPARK-24268][SQL] Use datatype.simpleString in error messages" This reverts commit 1bd3d61f4191767a94b71b42f4d00706b703e84f. --- .../spark/sql/kafka010/KafkaWriteTask.scala | 6 +++--- .../apache/spark/sql/kafka010/KafkaWriter.scala | 6 +++--- .../sql/kafka010/KafkaContinuousSinkSuite.scala | 4 ++-- .../spark/sql/kafka010/KafkaSinkSuite.scala | 4 ++-- .../scala/org/apache/spark/ml/feature/DCT.scala | 3 +-- .../apache/spark/ml/feature/FeatureHasher.scala | 5 ++--- .../org/apache/spark/ml/feature/HashingTF.scala | 2 +- .../apache/spark/ml/feature/Interaction.scala | 3 +-- .../org/apache/spark/ml/feature/NGram.scala | 2 +- .../apache/spark/ml/feature/OneHotEncoder.scala | 3 +-- .../org/apache/spark/ml/feature/RFormula.scala | 2 +- .../spark/ml/feature/StopWordsRemover.scala | 4 ++-- .../org/apache/spark/ml/feature/Tokenizer.scala | 3 +-- .../spark/ml/feature/VectorAssembler.scala | 2 +- .../scala/org/apache/spark/ml/fpm/FPGrowth.scala | 2 +- .../org/apache/spark/ml/util/SchemaUtils.scala | 11 ++++------- .../BinaryClassificationEvaluatorSuite.scala | 4 ++-- .../apache/spark/ml/feature/RFormulaSuite.scala | 2 +- .../spark/ml/feature/VectorAssemblerSuite.scala | 6 +++--- .../spark/ml/recommendation/ALSSuite.scala | 2 +- .../regression/AFTSurvivalRegressionSuite.scala | 2 +- .../apache/spark/ml/util/MLTestingUtils.scala | 6 +++--- .../expressions/complexTypeCreator.scala | 4 ++-- .../catalyst/expressions/jsonExpressions.scala | 2 +- .../catalyst/expressions/stringExpressions.scala | 5 ++--- .../sql/catalyst/json/JacksonGenerator.scala | 4 ++-- .../spark/sql/catalyst/json/JacksonParser.scala | 6 ++---- .../sql/catalyst/json/JsonInferSchema.scala | 6 ++---- .../spark/sql/catalyst/util/TypeUtils.scala | 5 ++--- .../spark/sql/types/AbstractDataType.scala | 9 ++++----- .../org/apache/spark/sql/types/ArrayType.scala | 5 ++--- .../org/apache/spark/sql/types/DecimalType.scala | 3 +-- .../org/apache/spark/sql/types/ObjectType.scala | 3 +-- .../org/apache/spark/sql/types/StructType.scala | 5 ++--- .../catalyst/analysis/AnalysisErrorSuite.scala | 2 +- .../analysis/ExpressionTypeCheckingSuite.scala | 16 ++++++++-------- .../catalyst/parser/ExpressionParserSuite.scala | 2 +- .../apache/spark/sql/types/DataTypeSuite.scala | 2 +- .../parquet/VectorizedColumnReader.java | 2 +- .../spark/sql/RelationalGroupedDataset.scala | 2 +- .../spark/sql/execution/arrow/ArrowUtils.scala | 3 +-- .../execution/datasources/orc/OrcFilters.scala | 2 +- .../parquet/ParquetSchemaConverter.scala | 2 +- .../spark/sql/execution/stat/StatFunctions.scala | 2 +- .../sql-tests/results/json-functions.sql.out | 4 ++-- .../resources/sql-tests/results/literals.sql.out | 6 +++--- .../datasources/parquet/ParquetSchemaSuite.scala | 4 ++-- .../sql/hive/execution/HiveTableScanExec.scala | 6 +++--- 48 files changed, 88 insertions(+), 108 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index 59a84706d4f55..d90630a8adc93 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -110,7 +110,7 @@ private[kafka010] abstract class KafkaRowWriter( case t => throw new IllegalStateException(s"${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + s"attribute unsupported type $t. ${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + - s"must be a ${StringType.simpleString}") + "must be a StringType") } val keyExpression = inputSchema.find(_.name == KafkaWriter.KEY_ATTRIBUTE_NAME) .getOrElse(Literal(null, BinaryType)) @@ -118,7 +118,7 @@ private[kafka010] abstract class KafkaRowWriter( case StringType | BinaryType => // good case t => throw new IllegalStateException(s"${KafkaWriter.KEY_ATTRIBUTE_NAME} " + - s"attribute unsupported type ${t.simpleString}") + s"attribute unsupported type $t") } val valueExpression = inputSchema .find(_.name == KafkaWriter.VALUE_ATTRIBUTE_NAME).getOrElse( @@ -129,7 +129,7 @@ private[kafka010] abstract class KafkaRowWriter( case StringType | BinaryType => // good case t => throw new IllegalStateException(s"${KafkaWriter.VALUE_ATTRIBUTE_NAME} " + - s"attribute unsupported type ${t.simpleString}") + s"attribute unsupported type $t") } UnsafeProjection.create( Seq(topicExpression, Cast(keyExpression, BinaryType), diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 3ec26e9edd353..15cd44812cb0c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -57,7 +57,7 @@ private[kafka010] object KafkaWriter extends Logging { ).dataType match { case StringType => // good case _ => - throw new AnalysisException(s"Topic type must be a ${StringType.simpleString}") + throw new AnalysisException(s"Topic type must be a String") } schema.find(_.name == KEY_ATTRIBUTE_NAME).getOrElse( Literal(null, StringType) @@ -65,7 +65,7 @@ private[kafka010] object KafkaWriter extends Logging { case StringType | BinaryType => // good case _ => throw new AnalysisException(s"$KEY_ATTRIBUTE_NAME attribute type " + - s"must be a ${StringType.simpleString} or ${BinaryType.simpleString}") + s"must be a String or BinaryType") } schema.find(_.name == VALUE_ATTRIBUTE_NAME).getOrElse( throw new AnalysisException(s"Required attribute '$VALUE_ATTRIBUTE_NAME' not found") @@ -73,7 +73,7 @@ private[kafka010] object KafkaWriter extends Logging { case StringType | BinaryType => // good case _ => throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " + - s"must be a ${StringType.simpleString} or ${BinaryType.simpleString}") + s"must be a String or BinaryType") } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala index 0e1492ac27449..ddfc0c1a4be2d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -314,7 +314,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "value attribute type must be a string or binary")) + "value attribute type must be a string or binarytype")) try { /* key field wrong type */ @@ -330,7 +330,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "key attribute type must be a string or binary")) + "key attribute type must be a string or binarytype")) } test("streaming - write to non-existing topic") { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 70ffd7dee89d7..7079ac6453ffc 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -303,7 +303,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "value attribute type must be a string or binary")) + "value attribute type must be a string or binarytype")) try { ex = intercept[StreamingQueryException] { @@ -318,7 +318,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "key attribute type must be a string or binary")) + "key attribute type must be a string or binarytype")) } test("streaming - write to non-existing topic") { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index 1eac1d1613d2a..682787a830113 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -69,8 +69,7 @@ class DCT @Since("1.5.0") (@Since("1.5.0") override val uid: String) } override protected def validateInputType(inputType: DataType): Unit = { - require(inputType.isInstanceOf[VectorUDT], - s"Input type must be ${(new VectorUDT).simpleString} but got ${inputType.simpleString}.") + require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.") } override protected def outputDataType: DataType = new VectorUDT diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index 405ea467cb02a..d67e4819b161a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -208,9 +208,8 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme require(dataType.isInstanceOf[NumericType] || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[BooleanType], - s"FeatureHasher requires columns to be of ${NumericType.simpleString}, " + - s"${BooleanType.simpleString} or ${StringType.simpleString}. " + - s"Column $fieldName was ${dataType.simpleString}") + s"FeatureHasher requires columns to be of NumericType, BooleanType or StringType. " + + s"Column $fieldName was $dataType") } val attrGroup = new AttributeGroup($(outputCol), $(numFeatures)) SchemaUtils.appendColumn(schema, attrGroup.toStructField()) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 403b0a813aedd..db432b6fefaff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -104,7 +104,7 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String) override def transformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[ArrayType], - s"The input column must be ${ArrayType.simpleString}, but got ${inputType.simpleString}.") + s"The input column must be ArrayType, but got $inputType.") val attrGroup = new AttributeGroup($(outputCol), $(numFeatures)) SchemaUtils.appendColumn(schema, attrGroup.toStructField()) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 5e01ec30bb2eb..4ff1d0ef356f3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -261,8 +261,7 @@ private[ml] class FeatureEncoder(numFeatures: Array[Int]) extends Serializable { */ def foreachNonzeroOutput(value: Any, f: (Int, Double) => Unit): Unit = value match { case d: Double => - assert(numFeatures.length == 1, - s"${DoubleType.simpleString} columns should only contain one feature.") + assert(numFeatures.length == 1, "DoubleType columns should only contain one feature.") val numOutputCols = numFeatures.head if (numOutputCols > 1) { assert( diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index 6445360f7fd90..c8760f9dc178f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -65,7 +65,7 @@ class NGram @Since("1.5.0") (@Since("1.5.0") override val uid: String) override protected def validateInputType(inputType: DataType): Unit = { require(inputType.sameType(ArrayType(StringType)), - s"Input type must be ${ArrayType(StringType).simpleString} but got $inputType.") + s"Input type must be ArrayType(StringType) but got $inputType.") } override protected def outputDataType: DataType = new ArrayType(StringType, false) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 24045f0448c81..5ab6c2dde667a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -85,8 +85,7 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e val inputFields = schema.fields require(schema(inputColName).dataType.isInstanceOf[NumericType], - s"Input column must be of type ${NumericType.simpleString} but got " + - schema(inputColName).dataType.simpleString) + s"Input column must be of type NumericType but got ${schema(inputColName).dataType}") require(!inputFields.exists(_.name == outputColName), s"Output column $outputColName already exists.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 346e1823f00b8..55e595eee6ffb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -394,7 +394,7 @@ class RFormulaModel private[feature]( require(!columnNames.contains($(featuresCol)), "Features column already exists.") require( !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType.isInstanceOf[NumericType], - s"Label column already exists and is not of type ${NumericType.simpleString}.") + "Label column already exists and is not of type NumericType.") } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index ead75d5b8def3..0f946dd2e015b 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -131,8 +131,8 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType - require(inputType.sameType(ArrayType(StringType)), "Input type must be " + - s"${ArrayType(StringType).simpleString} but got ${inputType.simpleString}.") + require(inputType.sameType(ArrayType(StringType)), + s"Input type must be ArrayType(StringType) but got $inputType.") SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 5132f63af1796..cfaf6c0e610b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -40,8 +40,7 @@ class Tokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) } override protected def validateInputType(inputType: DataType): Unit = { - require(inputType == StringType, - s"Input type must be ${StringType.simpleString} type but got ${inputType.simpleString}.") + require(inputType == StringType, s"Input type must be string type but got $inputType.") } override protected def outputDataType: DataType = new ArrayType(StringType, true) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index ed3b36ee5ab2f..4061154b39c14 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -162,7 +162,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) schema(name).dataType match { case _: NumericType | BooleanType => None case t if t.isInstanceOf[VectorUDT] => None - case other => Some(s"Data type ${other.simpleString} of column $name is not supported.") + case other => Some(s"Data type $other of column $name is not supported.") } } if (incorrectColumns.nonEmpty) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 51b88b3117f4e..d7fbe28ae7a64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -106,7 +106,7 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { protected def validateAndTransformSchema(schema: StructType): StructType = { val inputType = schema($(itemsCol)).dataType require(inputType.isInstanceOf[ArrayType], - s"The input column must be ${ArrayType.simpleString}, but got ${inputType.simpleString}.") + s"The input column must be ArrayType, but got $inputType.") SchemaUtils.appendColumn(schema, $(predictionCol), schema($(itemsCol)).dataType) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index b500582074398..d9a3f85ef9a24 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -41,8 +41,7 @@ private[spark] object SchemaUtils { val actualDataType = schema(colName).dataType val message = if (msg != null && msg.trim.length > 0) " " + msg else "" require(actualDataType.equals(dataType), - s"Column $colName must be of type ${dataType.simpleString} but was actually " + - s"${actualDataType.simpleString}.$message") + s"Column $colName must be of type $dataType but was actually $actualDataType.$message") } /** @@ -59,8 +58,7 @@ private[spark] object SchemaUtils { val message = if (msg != null && msg.trim.length > 0) " " + msg else "" require(dataTypes.exists(actualDataType.equals), s"Column $colName must be of type equal to one of the following types: " + - s"${dataTypes.map(_.simpleString).mkString("[", ", ", "]")} but was actually of type " + - s"${actualDataType.simpleString}.$message") + s"${dataTypes.mkString("[", ", ", "]")} but was actually of type $actualDataType.$message") } /** @@ -73,9 +71,8 @@ private[spark] object SchemaUtils { msg: String = ""): Unit = { val actualDataType = schema(colName).dataType val message = if (msg != null && msg.trim.length > 0) " " + msg else "" - require(actualDataType.isInstanceOf[NumericType], - s"Column $colName must be of type ${NumericType.simpleString} but was actually of type " + - s"${actualDataType.simpleString}.$message") + require(actualDataType.isInstanceOf[NumericType], s"Column $colName must be of type " + + s"NumericType but was actually of type $actualDataType.$message") } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index 2b0909acf69c3..ede284712b1c0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -67,8 +67,8 @@ class BinaryClassificationEvaluatorSuite evaluator.evaluate(stringDF) } assert(thrown.getMessage.replace("\n", "") contains "Column rawPrediction must be of type " + - "equal to one of the following types: [double, ") - assert(thrown.getMessage.replace("\n", "") contains "but was actually of type string.") + "equal to one of the following types: [DoubleType, ") + assert(thrown.getMessage.replace("\n", "") contains "but was actually of type StringType.") } test("should support all NumericType labels and not support other types") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 0de6528c4cf22..a250331efeb1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -105,7 +105,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { testTransformerByInterceptingException[(Int, Boolean)]( original, model, - "Label column already exists and is not of type numeric.", + "Label column already exists and is not of type NumericType.", "x") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index ed15a1d88a269..91fb24a268b8c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -99,9 +99,9 @@ class VectorAssemblerSuite assembler.transform(df) } assert(thrown.getMessage contains - "Data type string of column a is not supported.\n" + - "Data type string of column b is not supported.\n" + - "Data type string of column c is not supported.") + "Data type StringType of column a is not supported.\n" + + "Data type StringType of column b is not supported.\n" + + "Data type StringType of column c is not supported.") } test("ML attributes") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 65bee4edc4965..e3dfe2faf5698 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -612,7 +612,7 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { estimator.fit(strDF) } assert(thrown.getMessage.contains( - s"$column must be of type numeric but was actually of type string")) + s"$column must be of type NumericType but was actually of type StringType")) } private class NumericTypeWithEncoder[A](val numericType: NumericType) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 6cc73e040e82c..4e4ff71c9de90 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -385,7 +385,7 @@ class AFTSurvivalRegressionSuite extends MLTest with DefaultReadWriteTest { aft.fit(dfWithStringCensors) } assert(thrown.getMessage.contains( - "Column censor must be of type numeric but was actually of type string")) + "Column censor must be of type NumericType but was actually of type StringType")) } test("numerical stability of standardization") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index 91a8b14625a86..5e72b4d864c1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -74,7 +74,7 @@ object MLTestingUtils extends SparkFunSuite { estimator.fit(dfWithStringLabels) } assert(thrown.getMessage.contains( - "Column label must be of type numeric but was actually of type string")) + "Column label must be of type NumericType but was actually of type StringType")) estimator match { case weighted: Estimator[M] with HasWeightCol => @@ -86,7 +86,7 @@ object MLTestingUtils extends SparkFunSuite { weighted.fit(dfWithStringWeights) } assert(thrown.getMessage.contains( - "Column weight must be of type numeric but was actually of type string")) + "Column weight must be of type NumericType but was actually of type StringType")) case _ => } } @@ -104,7 +104,7 @@ object MLTestingUtils extends SparkFunSuite { evaluator.evaluate(dfWithStringLabels) } assert(thrown.getMessage.contains( - "Column label must be of type numeric but was actually of type string")) + "Column label must be of type NumericType but was actually of type StringType")) } def genClassifDFWithNumericLabelCol( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index cf0e3765de80f..0a5f8a907b50a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -385,8 +385,8 @@ trait CreateNamedStructLike extends Expression { val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( - s"Only foldable ${StringType.simpleString} expressions are allowed to appear at odd" + - s" position, got: ${invalidNames.mkString(",")}") + "Only foldable StringType expressions are allowed to appear at odd position, got:" + + s" ${invalidNames.mkString(",")}") } else if (!names.contains(null)) { TypeCheckResult.TypeCheckSuccess } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 1bcf11d7ee737..8cd86053a01c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -796,7 +796,7 @@ object JsonExprUtils { } case m: CreateMap => throw new AnalysisException( - s"A type of keys and values in map() must be string, but got ${m.dataType.simpleString}") + s"A type of keys and values in map() must be string, but got ${m.dataType}") case _ => throw new AnalysisException("Must use a map() function for options") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 70dd4df9df511..bedad7da334ae 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -222,12 +222,11 @@ case class Elt(children: Seq[Expression]) extends Expression { val (indexType, inputTypes) = (indexExpr.dataType, inputExprs.map(_.dataType)) if (indexType != IntegerType) { return TypeCheckResult.TypeCheckFailure(s"first input to function $prettyName should " + - s"have ${IntegerType.simpleString}, but it's ${indexType.simpleString}") + s"have IntegerType, but it's $indexType") } if (inputTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { return TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName should have ${StringType.simpleString} or " + - s"${BinaryType.simpleString}, but it's " + + s"input to function $prettyName should have StringType or BinaryType, but it's " + inputTypes.map(_.simpleString).mkString("[", ", ", "]")) } TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index 00086abbefd08..9c413de752a8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -45,8 +45,8 @@ private[sql] class JacksonGenerator( // `JackGenerator` can only be initialized with a `StructType` or a `MapType`. require(dataType.isInstanceOf[StructType] || dataType.isInstanceOf[MapType], - s"JacksonGenerator only supports to be initialized with a ${StructType.simpleString} " + - s"or ${MapType.simpleString} but got ${dataType.simpleString}") + "JacksonGenerator only supports to be initialized with a StructType " + + s"or MapType but got ${dataType.simpleString}") // `ValueWriter`s for all fields of the schema private lazy val rootFieldWriters: Array[ValueWriter] = dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index aa1691bb40d93..c3a4ca8f64bf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -143,8 +143,7 @@ class JacksonParser( case "NaN" => Float.NaN case "Infinity" => Float.PositiveInfinity case "-Infinity" => Float.NegativeInfinity - case other => throw new RuntimeException( - s"Cannot parse $other as ${FloatType.simpleString}.") + case other => throw new RuntimeException(s"Cannot parse $other as FloatType.") } } @@ -159,8 +158,7 @@ class JacksonParser( case "NaN" => Double.NaN case "Infinity" => Double.PositiveInfinity case "-Infinity" => Double.NegativeInfinity - case other => - throw new RuntimeException(s"Cannot parse $other as ${DoubleType.simpleString}.") + case other => throw new RuntimeException(s"Cannot parse $other as DoubleType.") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index 5f70e062d46c8..491ca005877f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -294,10 +294,8 @@ private[sql] object JsonInferSchema { // Both fields1 and fields2 should be sorted by name, since inferField performs sorting. // Therefore, we can take advantage of the fact that we're merging sorted lists and skip // building a hash map or performing additional sorting. - assert(isSorted(fields1), - s"${StructType.simpleString}'s fields were not sorted: ${fields1.toSeq}") - assert(isSorted(fields2), - s"${StructType.simpleString}'s fields were not sorted: ${fields2.toSeq}") + assert(isSorted(fields1), s"StructType's fields were not sorted: ${fields1.toSeq}") + assert(isSorted(fields2), s"StructType's fields were not sorted: ${fields2.toSeq}") val newFields = new java.util.ArrayList[StructField]() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index a9aaf617f7837..1dcda49a3af6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -29,7 +29,7 @@ object TypeUtils { if (dt.isInstanceOf[NumericType] || dt == NullType) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not ${dt.simpleString}") + TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not $dt") } } @@ -37,8 +37,7 @@ object TypeUtils { if (RowOrdering.isOrderable(dt)) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure( - s"$caller does not support ordering on type ${dt.simpleString}") + TypeCheckResult.TypeCheckFailure(s"$caller does not support ordering on type $dt") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index c43cc748655e8..3041f44b116ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -145,7 +145,7 @@ abstract class NumericType extends AtomicType { } -private[spark] object NumericType extends AbstractDataType { +private[sql] object NumericType extends AbstractDataType { /** * Enables matching against NumericType for expressions: * {{{ @@ -155,12 +155,11 @@ private[spark] object NumericType extends AbstractDataType { */ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] - override private[spark] def defaultConcreteType: DataType = DoubleType + override private[sql] def defaultConcreteType: DataType = DoubleType - override private[spark] def simpleString: String = "numeric" + override private[sql] def simpleString: String = "numeric" - override private[spark] def acceptsType(other: DataType): Boolean = - other.isInstanceOf[NumericType] + override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 8f118624f6d2f..38c40482fa4d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -42,7 +42,7 @@ object ArrayType extends AbstractDataType { other.isInstanceOf[ArrayType] } - override private[spark] def simpleString: String = "array" + override private[sql] def simpleString: String = "array" } /** @@ -103,8 +103,7 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT case a : ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] case other => - throw new IllegalArgumentException( - s"Type ${other.simpleString} does not support ordered operations") + throw new IllegalArgumentException(s"Type $other does not support ordered operations") } def compare(x: ArrayData, y: ArrayData): Int = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index f780ffd46a876..dbf51c398fa47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -48,8 +48,7 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { } if (precision > DecimalType.MAX_PRECISION) { - throw new AnalysisException( - s"${DecimalType.simpleString} can only support precision up to ${DecimalType.MAX_PRECISION}") + throw new AnalysisException(s"DecimalType can only support precision up to 38") } // default constructor for Java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala index 203e85e1c99bd..2d49fe076786a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -24,8 +24,7 @@ import org.apache.spark.annotation.InterfaceStability @InterfaceStability.Evolving object ObjectType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = - throw new UnsupportedOperationException( - s"null literals can't be casted to ${ObjectType.simpleString}") + throw new UnsupportedOperationException("null literals can't be casted to ObjectType") override private[sql] def acceptsType(other: DataType): Boolean = other match { case ObjectType(_) => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 0e69ef8ba73e8..362676b252126 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -426,7 +426,7 @@ object StructType extends AbstractDataType { private[sql] def fromString(raw: String): StructType = { Try(DataType.fromJson(raw)).getOrElse(LegacyTypeStringParser.parse(raw)) match { case t: StructType => t - case _ => throw new RuntimeException(s"Failed parsing ${StructType.simpleString}: $raw") + case _ => throw new RuntimeException(s"Failed parsing StructType: $raw") } } @@ -528,8 +528,7 @@ object StructType extends AbstractDataType { leftType case _ => - throw new SparkException(s"Failed to merge incompatible data types ${left.simpleString} " + - s"and ${right.simpleString}") + throw new SparkException(s"Failed to merge incompatible data types $left and $right") } private[sql] def fieldsMap(fields: Array[StructField]): Map[String, StructField] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 5e503be416a1f..5d2f8e735e3d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -514,7 +514,7 @@ class AnalysisErrorSuite extends AnalysisTest { right, joinType = Cross, condition = Some('b === 'd)) - assertAnalysisError(plan2, "EqualTo does not support ordering on type map" :: Nil) + assertAnalysisError(plan2, "EqualTo does not support ordering on type MapType" :: Nil) } test("PredicateSubQuery is used outside of a filter") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 8eec14842c7e7..36714bd631b0e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -109,17 +109,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) - assertError(EqualTo('mapField, 'mapField), "EqualTo does not support ordering on type map") + assertError(EqualTo('mapField, 'mapField), "EqualTo does not support ordering on type MapType") assertError(EqualNullSafe('mapField, 'mapField), - "EqualNullSafe does not support ordering on type map") + "EqualNullSafe does not support ordering on type MapType") assertError(LessThan('mapField, 'mapField), - "LessThan does not support ordering on type map") + "LessThan does not support ordering on type MapType") assertError(LessThanOrEqual('mapField, 'mapField), - "LessThanOrEqual does not support ordering on type map") + "LessThanOrEqual does not support ordering on type MapType") assertError(GreaterThan('mapField, 'mapField), - "GreaterThan does not support ordering on type map") + "GreaterThan does not support ordering on type MapType") assertError(GreaterThanOrEqual('mapField, 'mapField), - "GreaterThanOrEqual does not support ordering on type map") + "GreaterThanOrEqual does not support ordering on type MapType") assertError(If('intField, 'stringField, 'stringField), "type of predicate expression in If should be boolean") @@ -169,10 +169,10 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments") assertError( CreateNamedStruct(Seq(1, "a", "b", 2.0)), - "Only foldable string expressions are allowed to appear at odd position") + "Only foldable StringType expressions are allowed to appear at odd position") assertError( CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), - "Only foldable string expressions are allowed to appear at odd position") + "Only foldable StringType expressions are allowed to appear at odd position") assertError( CreateNamedStruct(Seq(Literal.create(null, StringType), "a")), "Field name should not be null") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index b4d422d8506fc..cb8a1fecb80a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -469,7 +469,7 @@ class ExpressionParserSuite extends PlanTest { Literal(BigDecimal("90912830918230182310293801923652346786").underlying())) assertEqual("123.0E-28BD", Literal(BigDecimal("123.0E-28").underlying())) assertEqual("123.08BD", Literal(BigDecimal("123.08").underlying())) - intercept("1.20E-38BD", "decimal can only support precision up to 38") + intercept("1.20E-38BD", "DecimalType can only support precision up to 38") } test("strings") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index fccd057e577d4..5a86f4055dce7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -154,7 +154,7 @@ class DataTypeSuite extends SparkFunSuite { left.merge(right) }.getMessage assert(message.equals("Failed to merge fields 'b' and 'b'. " + - "Failed to merge incompatible data types float and bigint")) + "Failed to merge incompatible data types FloatType and LongType")) } test("existsRecursively") { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 060e2ec068053..d5969b55eef96 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -244,7 +244,7 @@ private SchemaColumnConvertNotSupportedException constructConvertNotSupportedExc return new SchemaColumnConvertNotSupportedException( Arrays.toString(descriptor.getPath()), descriptor.getType().toString(), - column.dataType().simpleString()); + column.dataType().toString()); } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index b068493f2dd17..c6449cd5a16b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -452,7 +452,7 @@ class RelationalGroupedDataset protected[sql]( require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, "Must pass a grouped map udf") require(expr.dataType.isInstanceOf[StructType], - s"The returnType of the udf must be a ${StructType.simpleString}") + "The returnType of the udf must be a StructType") val groupingNamedExpressions = groupingExprs.map { case ne: NamedExpression => ne diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index 1274abffaa116..93c8127681b3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -47,8 +47,7 @@ object ArrowUtils { case DateType => new ArrowType.Date(DateUnit.DAY) case TimestampType => if (timeZoneId == null) { - throw new UnsupportedOperationException( - s"${TimestampType.simpleString} must supply timeZoneId parameter") + throw new UnsupportedOperationException("TimestampType must supply timeZoneId parameter") } else { new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index c90328f7ad43f..4f44ae4fa1d71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -98,7 +98,7 @@ private[orc] object OrcFilters { case DateType => PredicateLeaf.Type.DATE case TimestampType => PredicateLeaf.Type.TIMESTAMP case _: DecimalType => PredicateLeaf.Type.DECIMAL - case _ => throw new UnsupportedOperationException(s"DataType: ${dataType.simpleString}") + case _ => throw new UnsupportedOperationException(s"DataType: $dataType") } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index 18decad3f62f0..c61be077d309f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -555,7 +555,7 @@ class SparkToParquetSchemaConverter( convertField(field.copy(dataType = udt.sqlType)) case _ => - throw new AnalysisException(s"Unsupported data type ${field.dataType.simpleString}") + throw new AnalysisException(s"Unsupported data type $field.dataType") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index f772a3336d6af..685d5841ab551 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -157,7 +157,7 @@ object StatFunctions extends Logging { cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) => require(data.nonEmpty, s"Couldn't find column with name $name") require(data.get.dataType.isInstanceOf[NumericType], s"Currently $functionName calculation " + - s"for columns with dataType ${data.get.dataType.simpleString} not supported.") + s"for columns with dataType ${data.get.dataType} not supported.") } val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) df.select(columns: _*).queryExecution.toRdd.treeAggregate(new CovarianceCounter)( diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 827931d74138d..3d49323751a10 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -120,7 +120,7 @@ select to_json(named_struct('a', 1, 'b', 2), map('mode', 1)) struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -A type of keys and values in map() must be string, but got map;; line 1 pos 7 +A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 -- !query 12 @@ -216,7 +216,7 @@ select from_json('{"a":1}', 'a INT', map('mode', 1)) struct<> -- !query 20 output org.apache.spark.sql.AnalysisException -A type of keys and values in map() must be string, but got map;; line 1 pos 7 +A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 -- !query 21 diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index 7f301614523b2..b8c91dc8b59a4 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -147,7 +147,7 @@ struct<> -- !query 15 output org.apache.spark.sql.catalyst.parser.ParseException -decimal can only support precision up to 38 +DecimalType can only support precision up to 38 == SQL == select 1234567890123456789012345678901234567890 @@ -159,7 +159,7 @@ struct<> -- !query 16 output org.apache.spark.sql.catalyst.parser.ParseException -decimal can only support precision up to 38 +DecimalType can only support precision up to 38 == SQL == select 1234567890123456789012345678901234567890.0 @@ -379,7 +379,7 @@ struct<> -- !query 39 output org.apache.spark.sql.catalyst.parser.ParseException -decimal can only support precision up to 38(line 1, pos 7) +DecimalType can only support precision up to 38(line 1, pos 7) == SQL == select 1.20E-38BD diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 368e52cfbda9c..9d3dfae348beb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -430,9 +430,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { val col = spark.read.parquet(file).schema.fields.filter(_.name.equals("a")) assert(col.length == 1) if (col(0).dataType == StringType) { - assert(errMsg.contains("Column: [a], Expected: int, Found: BINARY")) + assert(errMsg.contains("Column: [a], Expected: IntegerType, Found: BINARY")) } else { - assert(errMsg.endsWith("Column: [a], Expected: string, Found: INT32")) + assert(errMsg.endsWith("Column: [a], Expected: StringType, Found: INT32")) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 40be4e8c1f5be..7dcaf170f9693 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -78,9 +78,9 @@ case class HiveTableScanExec( // Bind all partition key attribute references in the partition pruning predicate for later // evaluation. private lazy val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => - require(pred.dataType == BooleanType, - s"Data type of predicate $pred must be ${BooleanType.simpleString} rather than " + - s"${pred.dataType.simpleString}.") + require( + pred.dataType == BooleanType, + s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") BindReferences.bindReference(pred, relation.partitionCols) } From eb6e9880397dbac8b0b9ebc0796150b6924fc566 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 9 Jul 2018 14:53:14 -0700 Subject: [PATCH 1085/2461] [SPARK-24759][SQL] No reordering keys for broadcast hash join ## What changes were proposed in this pull request? As the implementation of the broadcast hash join is independent of the input hash partitioning, reordering keys is not necessary. Thus, we solve this issue by simply removing the broadcast hash join from the reordering rule in EnsureRequirements. ## How was this patch tested? N/A Author: Xiao Li Closes #21728 from gatorsmile/cleanER. --- .../spark/sql/execution/exchange/EnsureRequirements.scala | 7 ------- 1 file changed, 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index ad95879d86f42..d96ecbaa48029 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -279,13 +279,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { */ private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = { plan match { - case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, - right) => - val (reorderedLeftKeys, reorderedRightKeys) = - reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) - BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, - left, right) - case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => val (reorderedLeftKeys, reorderedRightKeys) = reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) From 4984f1af7e48dab1ae08021a3b17c5ad6d47a87e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 10 Jul 2018 13:54:04 +0800 Subject: [PATCH 1086/2461] [MINOR] Add Sphinx into dev/requirements.txt ## What changes were proposed in this pull request? Not a big deal but this PR adds `sphinx` into `dev/requirements.txt` since we found it needed - https://github.com/apache/spark-website/pull/122#discussion_r200896018 ## How was this patch tested? manually: ``` pip install -r requirements.txt ``` Author: hyukjinkwon Closes #21735 from HyukjinKwon/minor-dev. --- dev/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/dev/requirements.txt b/dev/requirements.txt index 79782279f8fbd..fa833ab96b8e7 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -2,3 +2,4 @@ jira==1.0.3 PyGithub==1.26.0 Unidecode==0.04.19 pypandoc==1.3.3 +sphinx From a289009567c1566a1df4bcdfdf0111e82ae3d81d Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 10 Jul 2018 15:58:14 +0800 Subject: [PATCH 1087/2461] [SPARK-24706][SQL] ByteType and ShortType support pushdown to parquet ## What changes were proposed in this pull request? `ByteType` and `ShortType` support pushdown to parquet data source. [Benchmark result](https://issues.apache.org/jira/browse/SPARK-24706?focusedCommentId=16528878&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-16528878). ## How was this patch tested? unit tests Author: Yuming Wang Closes #21682 from wangyum/SPARK-24706. --- .../FilterPushdownBenchmark-results.txt | 32 +++++------ .../datasources/parquet/ParquetFilters.scala | 34 +++++++---- .../parquet/ParquetFilterSuite.scala | 56 +++++++++++++++++++ 3 files changed, 94 insertions(+), 28 deletions(-) diff --git a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt index 29fe4345d69da..110669b69a00d 100644 --- a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt +++ b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt @@ -542,39 +542,39 @@ Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 1 tinyint row (value = CAST(63 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 3726 / 3775 4.2 236.9 1.0X -Parquet Vectorized (Pushdown) 3741 / 3789 4.2 237.9 1.0X -Native ORC Vectorized 2793 / 2909 5.6 177.6 1.3X -Native ORC Vectorized (Pushdown) 530 / 561 29.7 33.7 7.0X +Parquet Vectorized 3461 / 3997 4.5 220.1 1.0X +Parquet Vectorized (Pushdown) 270 / 315 58.4 17.1 12.8X +Native ORC Vectorized 4107 / 5372 3.8 261.1 0.8X +Native ORC Vectorized (Pushdown) 778 / 1553 20.2 49.5 4.4X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 10% tinyint rows (value < CAST(12 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 4385 / 4406 3.6 278.8 1.0X -Parquet Vectorized (Pushdown) 4398 / 4454 3.6 279.6 1.0X -Native ORC Vectorized 3420 / 3501 4.6 217.4 1.3X -Native ORC Vectorized (Pushdown) 1395 / 1432 11.3 88.7 3.1X +Parquet Vectorized 4771 / 6655 3.3 303.3 1.0X +Parquet Vectorized (Pushdown) 1322 / 1606 11.9 84.0 3.6X +Native ORC Vectorized 4437 / 4572 3.5 282.1 1.1X +Native ORC Vectorized (Pushdown) 1781 / 1976 8.8 113.2 2.7X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 50% tinyint rows (value < CAST(63 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7307 / 7394 2.2 464.6 1.0X -Parquet Vectorized (Pushdown) 7411 / 7461 2.1 471.2 1.0X -Native ORC Vectorized 6501 / 7814 2.4 413.4 1.1X -Native ORC Vectorized (Pushdown) 7341 / 8637 2.1 466.7 1.0X +Parquet Vectorized 7433 / 7752 2.1 472.6 1.0X +Parquet Vectorized (Pushdown) 5863 / 5913 2.7 372.8 1.3X +Native ORC Vectorized 7986 / 8084 2.0 507.7 0.9X +Native ORC Vectorized (Pushdown) 6522 / 6608 2.4 414.6 1.1X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 90% tinyint rows (value < CAST(114 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 11886 / 13122 1.3 755.7 1.0X -Parquet Vectorized (Pushdown) 12557 / 14173 1.3 798.4 0.9X -Native ORC Vectorized 10758 / 11971 1.5 684.0 1.1X -Native ORC Vectorized (Pushdown) 10564 / 10713 1.5 671.6 1.1X +Parquet Vectorized 11190 / 11519 1.4 711.4 1.0X +Parquet Vectorized (Pushdown) 10861 / 11206 1.4 690.5 1.0X +Native ORC Vectorized 11622 / 12196 1.4 738.9 1.0X +Native ORC Vectorized (Pushdown) 11377 / 11654 1.4 723.3 1.0X diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 4827f706e6016..4c9b940db2b30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -45,6 +45,8 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: decimalMetadata: DecimalMetadata) private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, null) + private val ParquetByteType = ParquetSchemaType(INT_8, INT32, null) + private val ParquetShortType = ParquetSchemaType(INT_16, INT32, null) private val ParquetIntegerType = ParquetSchemaType(null, INT32, null) private val ParquetLongType = ParquetSchemaType(null, INT64, null) private val ParquetFloatType = ParquetSchemaType(null, FLOAT, null) @@ -60,8 +62,10 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: private val makeEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { case ParquetBooleanType => (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) - case ParquetIntegerType => - (n: String, v: Any) => FilterApi.eq(intColumn(n), v.asInstanceOf[Integer]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => FilterApi.eq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) case ParquetLongType => (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[java.lang.Long]) case ParquetFloatType => @@ -87,8 +91,10 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: private val makeNotEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { case ParquetBooleanType => (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) - case ParquetIntegerType => - (n: String, v: Any) => FilterApi.notEq(intColumn(n), v.asInstanceOf[Integer]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => FilterApi.notEq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) case ParquetLongType => (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[java.lang.Long]) case ParquetFloatType => @@ -111,8 +117,9 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: } private val makeLt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { - case ParquetIntegerType => - (n: String, v: Any) => FilterApi.lt(intColumn(n), v.asInstanceOf[Integer]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => + FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) case ParquetLongType => (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[java.lang.Long]) case ParquetFloatType => @@ -132,8 +139,9 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: } private val makeLtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { - case ParquetIntegerType => - (n: String, v: Any) => FilterApi.ltEq(intColumn(n), v.asInstanceOf[Integer]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => + FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) case ParquetLongType => (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[java.lang.Long]) case ParquetFloatType => @@ -153,8 +161,9 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: } private val makeGt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { - case ParquetIntegerType => - (n: String, v: Any) => FilterApi.gt(intColumn(n), v.asInstanceOf[Integer]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => + FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) case ParquetLongType => (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[java.lang.Long]) case ParquetFloatType => @@ -174,8 +183,9 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: } private val makeGtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { - case ParquetIntegerType => - (n: String, v: Any) => FilterApi.gtEq(intColumn(n), v.asInstanceOf[Integer]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => + FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) case ParquetLongType => (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[java.lang.Long]) case ParquetFloatType => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index f2c0bda256239..067d2fea14fd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -179,6 +179,62 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + test("filter pushdown - tinyint") { + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toByte)))) { implicit df => + assert(df.schema.head.dataType === ByteType) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1.toByte, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1.toByte, classOf[Eq[_]], 1) + checkFilterPredicate('_1 =!= 1.toByte, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2.toByte, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3.toByte, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1.toByte, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4.toByte, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1.toByte) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1.toByte) <=> '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2.toByte) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3.toByte) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1.toByte) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4.toByte) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4.toByte), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2.toByte || '_1 > 3.toByte, + classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + + test("filter pushdown - smallint") { + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit df => + assert(df.schema.head.dataType === ShortType) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1.toShort, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1.toShort, classOf[Eq[_]], 1) + checkFilterPredicate('_1 =!= 1.toShort, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2.toShort, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3.toShort, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1.toShort, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4.toShort, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1.toShort) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1.toShort) <=> '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2.toShort) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3.toShort) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1.toShort) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4.toShort) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4.toShort), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2.toShort || '_1 > 3.toShort, + classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + test("filter pushdown - integer") { withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) From 6fe32869ccb17933e77a4dbe883e36d382fbeeec Mon Sep 17 00:00:00 2001 From: sharkdtu Date: Tue, 10 Jul 2018 20:18:34 +0800 Subject: [PATCH 1088/2461] [SPARK-24678][SPARK-STREAMING] Give priority in use of 'PROCESS_LOCAL' for spark-streaming ## What changes were proposed in this pull request? Currently, `BlockRDD.getPreferredLocations` only get hosts info of blocks, which results in subsequent schedule level is not better than 'NODE_LOCAL'. We can just make a small changes, the schedule level can be improved to 'PROCESS_LOCAL' ## How was this patch tested? manual test Author: sharkdtu Closes #21658 from sharkdtu/master. --- .../main/scala/org/apache/spark/rdd/BlockRDD.scala | 2 +- .../org/apache/spark/storage/BlockManager.scala | 7 +++++-- .../apache/spark/storage/BlockManagerSuite.scala | 13 +++++++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index 4e036c2ed49b5..23cf19d55b4ae 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -30,7 +30,7 @@ private[spark] class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[BlockId]) extends RDD[T](sc, Nil) { - @transient lazy val _locations = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get) + @transient lazy val _locations = BlockManager.blockIdsToLocations(blockIds, SparkEnv.get) @volatile private var _isValid = true override def getPartitions: Array[Partition] = { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index df1a4bef616b2..0e1c7d5fd3fa2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -45,6 +45,7 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.{ExternalShuffleClient, TempFileManager} import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.rpc.RpcEnv +import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.serializer.{SerializerInstance, SerializerManager} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage.memory._ @@ -1554,7 +1555,7 @@ private[spark] class BlockManager( private[spark] object BlockManager { private val ID_GENERATOR = new IdGenerator - def blockIdsToHosts( + def blockIdsToLocations( blockIds: Array[BlockId], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[String]] = { @@ -1569,7 +1570,9 @@ private[spark] object BlockManager { val blockManagers = new HashMap[BlockId, Seq[String]] for (i <- 0 until blockIds.length) { - blockManagers(blockIds(i)) = blockLocations(i).map(_.host) + blockManagers(blockIds(i)) = blockLocations(i).map { loc => + ExecutorCacheTaskLocation(loc.host, loc.executorId).toString + } } blockManagers.toMap } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index b19d8ebf72c61..08172f0b07b75 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1422,6 +1422,19 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(mockBlockTransferService.tempFileManager === store.remoteBlockTempFileManager) } + test("query locations of blockIds") { + val mockBlockManagerMaster = mock(classOf[BlockManagerMaster]) + val blockLocations = Seq(BlockManagerId("1", "host1", 100), BlockManagerId("2", "host2", 200)) + when(mockBlockManagerMaster.getLocations(mc.any[Array[BlockId]])) + .thenReturn(Array(blockLocations)) + val env = mock(classOf[SparkEnv]) + + val blockIds: Array[BlockId] = Array(StreamBlockId(1, 2)) + val locs = BlockManager.blockIdsToLocations(blockIds, env, mockBlockManagerMaster) + val expectedLocs = Seq("executor_host1_1", "executor_host2_2") + assert(locs(blockIds(0)) == expectedLocs) + } + class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { var numCalls = 0 var tempFileManager: TempFileManager = null From e0559f238009e02c40f65678fec691c07904e8c0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 10 Jul 2018 23:07:10 +0800 Subject: [PATCH 1089/2461] [SPARK-21743][SQL][FOLLOWUP] free aggregate map when task ends ## What changes were proposed in this pull request? This is the first follow-up of https://github.com/apache/spark/pull/21573 , which was only merged to 2.3. This PR fixes the memory leak in another way: free the `UnsafeExternalMap` when the task ends. All the data buffers in Spark SQL are using `UnsafeExternalMap` and `UnsafeExternalSorter` under the hood, e.g. sort, aggregate, window, SMJ, etc. `UnsafeExternalSorter` registers a task completion listener to free the resource, we should apply the same thing to `UnsafeExternalMap`. TODO in the next PR: do not consume all the inputs when having limit in whole stage codegen. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #21738 from cloud-fan/limit. --- .../UnsafeFixedWidthAggregationMap.java | 17 ++++++++++----- .../spark/sql/execution/SparkStrategies.scala | 7 +------ .../aggregate/HashAggregateExec.scala | 2 +- .../TungstenAggregationIterator.scala | 2 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 21 ++++++++++++------- 5 files changed, 28 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index c7c4c7b3e7715..c8cf44b51df77 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -20,8 +20,8 @@ import java.io.IOException; import org.apache.spark.SparkEnv; +import org.apache.spark.TaskContext; import org.apache.spark.internal.config.package$; -import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; @@ -82,7 +82,7 @@ public static boolean supportsAggregationBufferSchema(StructType schema) { * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function) * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. * @param groupingKeySchema the schema of the grouping key, used for row conversion. - * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures. + * @param taskContext the current task context. * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). * @param pageSizeBytes the data page size, in bytes; limits the maximum record size. */ @@ -90,19 +90,26 @@ public UnsafeFixedWidthAggregationMap( InternalRow emptyAggregationBuffer, StructType aggregationBufferSchema, StructType groupingKeySchema, - TaskMemoryManager taskMemoryManager, + TaskContext taskContext, int initialCapacity, long pageSizeBytes) { this.aggregationBufferSchema = aggregationBufferSchema; this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length()); this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); this.groupingKeySchema = groupingKeySchema; - this.map = - new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, true); + this.map = new BytesToBytesMap( + taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes, true); // Initialize the buffer for aggregation value final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema); this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); + + // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at + // the end of the task. This is necessary to avoid memory leaks in when the downstream operator + // does not fully consume the aggregation map's output (e.g. aggregate followed by limit). + taskContext.addTaskCompletionListener(context -> { + free(); + }); } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 07a6fcae83b70..cfbcb9aad65c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -73,12 +73,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil case Limit(IntegerLiteral(limit), child) => - // With whole stage codegen, Spark releases resources only when all the output data of the - // query plan are consumed. It's possible that `CollectLimitExec` only consumes a little - // data from child plan and finishes the query without releasing resources. Here we wrap - // the child plan with `LocalLimitExec`, to stop the processing of whole stage codegen and - // trigger the resource releasing work, after we consume `limit` rows. - CollectLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil + CollectLimitExec(limit, planLater(child)) :: Nil case other => planLater(other) :: Nil } case Limit(IntegerLiteral(limit), Sort(order, true, child)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 8c7b2c187cccd..2cac0cfce28de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -328,7 +328,7 @@ case class HashAggregateExec( initialBuffer, bufferSchema, groupingKeySchema, - TaskContext.get().taskMemoryManager(), + TaskContext.get(), 1024 * 16, // initial capacity TaskContext.get().taskMemoryManager().pageSizeBytes ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 9dc334c1ead3c..c1911235f8df3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -166,7 +166,7 @@ class TungstenAggregationIterator( initialAggregationBuffer, StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)), StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), - TaskContext.get().taskMemoryManager(), + TaskContext.get(), 1024 * 16, // initial capacity TaskContext.get().taskMemoryManager().pageSizeBytes ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 3e31d22e15c0e..5c15ecd42fa0c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -23,6 +23,7 @@ import scala.collection.mutable import scala.util.{Random, Try} import scala.util.control.NonFatal +import org.mockito.Mockito._ import org.scalatest.Matchers import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl} @@ -54,6 +55,8 @@ class UnsafeFixedWidthAggregationMapSuite private var memoryManager: TestMemoryManager = null private var taskMemoryManager: TaskMemoryManager = null + private var taskContext: TaskContext = null + def testWithMemoryLeakDetection(name: String)(f: => Unit) { def cleanup(): Unit = { if (taskMemoryManager != null) { @@ -67,6 +70,8 @@ class UnsafeFixedWidthAggregationMapSuite val conf = new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false") memoryManager = new TestMemoryManager(conf) taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + taskContext = mock(classOf[TaskContext]) + when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, @@ -111,7 +116,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 1024, // initial capacity, PAGE_SIZE_BYTES ) @@ -124,7 +129,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 1024, // initial capacity PAGE_SIZE_BYTES ) @@ -151,7 +156,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -176,7 +181,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -223,7 +228,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -263,7 +268,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, StructType(Nil), StructType(Nil), - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -307,7 +312,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity pageSize ) @@ -344,7 +349,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity pageSize ) From 32cb50835e7258625afff562939872be002232f2 Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Tue, 10 Jul 2018 11:08:04 -0700 Subject: [PATCH 1090/2461] [SPARK-24662][SQL][SS] Support limit in structured streaming ## What changes were proposed in this pull request? Support the LIMIT operator in structured streaming. For streams in append or complete output mode, a stream with a LIMIT operator will return no more than the specified number of rows. LIMIT is still unsupported for the update output mode. This change reverts https://github.com/apache/spark/commit/e4fee395ecd93ad4579d9afbf0861f82a303e563 as part of it because it is a better and more complete implementation. ## How was this patch tested? New and existing unit tests. Author: Mukul Murthy Closes #21662 from mukulmurthy/SPARK-24662. --- .../UnsupportedOperationChecker.scala | 6 +- .../spark/sql/execution/SparkStrategies.scala | 26 ++++- .../streaming/IncrementalExecution.scala | 11 +- .../streaming/StreamingGlobalLimitExec.scala | 102 +++++++++++++++++ .../sql/execution/streaming/memory.scala | 70 +----------- .../streaming/sources/memoryV2.scala | 44 ++----- .../sql/streaming/DataStreamWriter.scala | 4 +- .../execution/streaming/MemorySinkSuite.scala | 62 +--------- .../streaming/MemorySinkV2Suite.scala | 80 +------------ .../spark/sql/streaming/StreamSuite.scala | 108 ++++++++++++++++++ .../spark/sql/streaming/StreamTest.scala | 4 +- 11 files changed, 272 insertions(+), 245 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 5ced1ca200daa..f68df5d29b545 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -315,8 +315,10 @@ object UnsupportedOperationChecker { case GroupingSets(_, _, child, _) if child.isStreaming => throwError("GroupingSets is not supported on streaming DataFrames/Datasets") - case GlobalLimit(_, _) | LocalLimit(_, _) if subPlan.children.forall(_.isStreaming) => - throwError("Limits are not supported on streaming DataFrames/Datasets") + case GlobalLimit(_, _) | LocalLimit(_, _) + if subPlan.children.forall(_.isStreaming) && outputMode == InternalOutputModes.Update => + throwError("Limits are not supported on streaming DataFrames/Datasets in Update " + + "output mode") case Sort(_, _, _) if !containsCompleteData(subPlan) => throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index cfbcb9aad65c4..02e095b42a506 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -34,7 +35,7 @@ import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2 import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.StreamingQuery +import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery} import org.apache.spark.sql.types.StructType /** @@ -349,6 +350,29 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Used to plan the streaming global limit operator for streams in append mode. + * We need to check for either a direct Limit or a Limit wrapped in a ReturnAnswer operator, + * following the example of the SpecialLimits Strategy above. + * Streams with limit in Append mode use the stateful StreamingGlobalLimitExec. + * Streams with limit in Complete mode use the stateless CollectLimitExec operator. + * Limit is unsupported for streams in Update mode. + */ + case class StreamingGlobalLimitStrategy(outputMode: OutputMode) extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ReturnAnswer(rootPlan) => rootPlan match { + case Limit(IntegerLiteral(limit), child) + if plan.isStreaming && outputMode == InternalOutputModes.Append => + StreamingGlobalLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil + case _ => Nil + } + case Limit(IntegerLiteral(limit), child) + if plan.isStreaming && outputMode == InternalOutputModes.Append => + StreamingGlobalLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil + case _ => Nil + } + } + object StreamingJoinStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = { plan match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index c480b96626f84..6ae7f2869b0f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -59,7 +59,8 @@ class IncrementalExecution( StatefulAggregationStrategy :: FlatMapGroupsWithStateStrategy :: StreamingRelationStrategy :: - StreamingDeduplicationStrategy :: Nil + StreamingDeduplicationStrategy :: + StreamingGlobalLimitStrategy(outputMode) :: Nil } private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key) @@ -134,8 +135,12 @@ class IncrementalExecution( stateWatermarkPredicates = StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates( j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full, - Some(offsetSeqMetadata.batchWatermarkMs)) - ) + Some(offsetSeqMetadata.batchWatermarkMs))) + + case l: StreamingGlobalLimitExec => + l.copy( + stateInfo = Some(nextStatefulOperationStateInfo), + outputMode = Some(outputMode)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala new file mode 100644 index 0000000000000..bf4af60c8cf03 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.TimeUnit.NANOSECONDS + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.streaming.state.StateStoreOps +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.{LongType, NullType, StructField, StructType} +import org.apache.spark.util.CompletionIterator + +/** + * A physical operator for executing a streaming limit, which makes sure no more than streamLimit + * rows are returned. This operator is meant for streams in Append mode only. + */ +case class StreamingGlobalLimitExec( + streamLimit: Long, + child: SparkPlan, + stateInfo: Option[StatefulOperatorStateInfo] = None, + outputMode: Option[OutputMode] = None) + extends UnaryExecNode with StateStoreWriter { + + private val keySchema = StructType(Array(StructField("key", NullType))) + private val valueSchema = StructType(Array(StructField("value", LongType))) + + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver + + assert(outputMode.isDefined && outputMode.get == InternalOutputModes.Append, + "StreamingGlobalLimitExec is only valid for streams in Append output mode") + + child.execute().mapPartitionsWithStateStore( + getStateInfo, + keySchema, + valueSchema, + indexOrdinal = None, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => + val key = UnsafeProjection.create(keySchema)(new GenericInternalRow(Array[Any](null))) + val numOutputRows = longMetric("numOutputRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") + val commitTimeMs = longMetric("commitTimeMs") + val updatesStartTimeNs = System.nanoTime + + val preBatchRowCount: Long = Option(store.get(key)).map(_.getLong(0)).getOrElse(0L) + var cumulativeRowCount = preBatchRowCount + + val result = iter.filter { r => + val x = cumulativeRowCount < streamLimit + if (x) { + cumulativeRowCount += 1 + } + x + } + + CompletionIterator[InternalRow, Iterator[InternalRow]](result, { + if (cumulativeRowCount > preBatchRowCount) { + numUpdatedStateRows += 1 + numOutputRows += cumulativeRowCount - preBatchRowCount + store.put(key, getValueRow(cumulativeRowCount)) + } + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + commitTimeMs += timeTakenMs { store.commit() } + setStoreMetrics(store) + }) + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = AllTuples :: Nil + + private def getValueRow(value: Long): UnsafeRow = { + UnsafeProjection.create(valueSchema)(new GenericInternalRow(Array[Any](value))) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 7fa13c4aa2c01..b137f98045c5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode @@ -222,60 +221,19 @@ class MemoryStreamInputPartition(records: Array[UnsafeRow]) } /** A common trait for MemorySinks with methods used for testing */ -trait MemorySinkBase extends BaseStreamingSink with Logging { +trait MemorySinkBase extends BaseStreamingSink { def allData: Seq[Row] def latestBatchData: Seq[Row] def dataSinceBatch(sinceBatchId: Long): Seq[Row] def latestBatchId: Option[Long] - - /** - * Truncates the given rows to return at most maxRows rows. - * @param rows The data that may need to be truncated. - * @param batchLimit Number of rows to keep in this batch; the rest will be truncated - * @param sinkLimit Total number of rows kept in this sink, for logging purposes. - * @param batchId The ID of the batch that sent these rows, for logging purposes. - * @return Truncated rows. - */ - protected def truncateRowsIfNeeded( - rows: Array[Row], - batchLimit: Int, - sinkLimit: Int, - batchId: Long): Array[Row] = { - if (rows.length > batchLimit && batchLimit >= 0) { - logWarning(s"Truncating batch $batchId to $batchLimit rows because of sink limit $sinkLimit") - rows.take(batchLimit) - } else { - rows - } - } -} - -/** - * Companion object to MemorySinkBase. - */ -object MemorySinkBase { - val MAX_MEMORY_SINK_ROWS = "maxRows" - val MAX_MEMORY_SINK_ROWS_DEFAULT = -1 - - /** - * Gets the max number of rows a MemorySink should store. This number is based on the memory - * sink row limit option if it is set. If not, we use a large value so that data truncates - * rather than causing out of memory errors. - * @param options Options for writing from which we get the max rows option - * @return The maximum number of rows a memorySink should store. - */ - def getMemorySinkCapacity(options: DataSourceOptions): Int = { - val maxRows = options.getInt(MAX_MEMORY_SINK_ROWS, MAX_MEMORY_SINK_ROWS_DEFAULT) - if (maxRows >= 0) maxRows else Int.MaxValue - 10 - } } /** * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSourceOptions) - extends Sink with MemorySinkBase with Logging { +class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink + with MemorySinkBase with Logging { private case class AddedData(batchId: Long, data: Array[Row]) @@ -283,12 +241,6 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo @GuardedBy("this") private val batches = new ArrayBuffer[AddedData]() - /** The number of rows in this MemorySink. */ - private var numRows = 0 - - /** The capacity in rows of this sink. */ - val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) - /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { batches.flatMap(_.data) @@ -321,23 +273,14 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo logDebug(s"Committing batch $batchId to $this") outputMode match { case Append | Update => - var rowsToAdd = data.collect() - synchronized { - rowsToAdd = - truncateRowsIfNeeded(rowsToAdd, sinkCapacity - numRows, sinkCapacity, batchId) - val rows = AddedData(batchId, rowsToAdd) - batches += rows - numRows += rowsToAdd.length - } + val rows = AddedData(batchId, data.collect()) + synchronized { batches += rows } case Complete => - var rowsToAdd = data.collect() + val rows = AddedData(batchId, data.collect()) synchronized { - rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity, sinkCapacity, batchId) - val rows = AddedData(batchId, rowsToAdd) batches.clear() batches += rows - numRows = rowsToAdd.length } case _ => @@ -351,7 +294,6 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo def clear(): Unit = synchronized { batches.clear() - numRows = 0 } override def toString(): String = "MemorySink" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 29f8cca476722..f2a35a90af24a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -46,7 +46,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB schema: StructType, mode: OutputMode, options: DataSourceOptions): StreamWriter = { - new MemoryStreamWriter(this, mode, options) + new MemoryStreamWriter(this, mode) } private case class AddedData(batchId: Long, data: Array[Row]) @@ -55,9 +55,6 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB @GuardedBy("this") private val batches = new ArrayBuffer[AddedData]() - /** The number of rows in this MemorySink. */ - private var numRows = 0 - /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { batches.flatMap(_.data) @@ -84,11 +81,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB }.mkString("\n") } - def write( - batchId: Long, - outputMode: OutputMode, - newRows: Array[Row], - sinkCapacity: Int): Unit = { + def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row]): Unit = { val notCommitted = synchronized { latestBatchId.isEmpty || batchId > latestBatchId.get } @@ -96,21 +89,14 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB logDebug(s"Committing batch $batchId to $this") outputMode match { case Append | Update => - synchronized { - val rowsToAdd = - truncateRowsIfNeeded(newRows, sinkCapacity - numRows, sinkCapacity, batchId) - val rows = AddedData(batchId, rowsToAdd) - batches += rows - numRows += rowsToAdd.length - } + val rows = AddedData(batchId, newRows) + synchronized { batches += rows } case Complete => + val rows = AddedData(batchId, newRows) synchronized { - val rowsToAdd = truncateRowsIfNeeded(newRows, sinkCapacity, sinkCapacity, batchId) - val rows = AddedData(batchId, rowsToAdd) batches.clear() batches += rows - numRows = rowsToAdd.length } case _ => @@ -124,7 +110,6 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB def clear(): Unit = synchronized { batches.clear() - numRows = 0 } override def toString(): String = "MemorySinkV2" @@ -132,22 +117,16 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} -class MemoryWriter( - sink: MemorySinkV2, - batchId: Long, - outputMode: OutputMode, - options: DataSourceOptions) +class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) extends DataSourceWriter with Logging { - val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) - override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) def commit(messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { case message: MemoryWriterCommitMessage => message.data } - sink.write(batchId, outputMode, newRows, sinkCapacity) + sink.write(batchId, outputMode, newRows) } override def abort(messages: Array[WriterCommitMessage]): Unit = { @@ -155,21 +134,16 @@ class MemoryWriter( } } -class MemoryStreamWriter( - val sink: MemorySinkV2, - outputMode: OutputMode, - options: DataSourceOptions) +class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) extends StreamWriter { - val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) - override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { case message: MemoryWriterCommitMessage => message.data } - sink.write(epochId, outputMode, newRows, sinkCapacity) + sink.write(epochId, outputMode, newRows) } override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 926c0b69a03fd..3b9a56ffdde4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources._ -import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.StreamWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -250,7 +250,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes)) (s, r) case _ => - val s = new MemorySink(df.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) + val s = new MemorySink(df.schema, outputMode) val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s)) (s, r) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index b2fd6ba27ebb8..3bc36ce55d902 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.execution.streaming -import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.{OutputMode, StreamTest} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -38,7 +36,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Append output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty()) + val sink = new MemorySink(schema, OutputMode.Append) // Before adding data, check output assert(sink.latestBatchId === None) @@ -70,35 +68,9 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, 1 to 9) } - test("directly add data in Append output mode with row limit") { - implicit val schema = new StructType().add(new StructField("value", IntegerType)) - - var optionsMap = new scala.collection.mutable.HashMap[String, String] - optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) - var options = new DataSourceOptions(optionsMap.toMap.asJava) - val sink = new MemorySink(schema, OutputMode.Append, options) - - // Before adding data, check output - assert(sink.latestBatchId === None) - checkAnswer(sink.latestBatchData, Seq.empty) - checkAnswer(sink.allData, Seq.empty) - - // Add batch 0 and check outputs - sink.addBatch(0, 1 to 3) - assert(sink.latestBatchId === Some(0)) - checkAnswer(sink.latestBatchData, 1 to 3) - checkAnswer(sink.allData, 1 to 3) - - // Add batch 1 and check outputs - sink.addBatch(1, 4 to 6) - assert(sink.latestBatchId === Some(1)) - checkAnswer(sink.latestBatchData, 4 to 5) - checkAnswer(sink.allData, 1 to 5) // new data should not go over the limit - } - test("directly add data in Update output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Update, DataSourceOptions.empty()) + val sink = new MemorySink(schema, OutputMode.Update) // Before adding data, check output assert(sink.latestBatchId === None) @@ -132,7 +104,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Complete output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Complete, DataSourceOptions.empty()) + val sink = new MemorySink(schema, OutputMode.Complete) // Before adding data, check output assert(sink.latestBatchId === None) @@ -164,32 +136,6 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, 7 to 9) } - test("directly add data in Complete output mode with row limit") { - implicit val schema = new StructType().add(new StructField("value", IntegerType)) - - var optionsMap = new scala.collection.mutable.HashMap[String, String] - optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) - var options = new DataSourceOptions(optionsMap.toMap.asJava) - val sink = new MemorySink(schema, OutputMode.Complete, options) - - // Before adding data, check output - assert(sink.latestBatchId === None) - checkAnswer(sink.latestBatchData, Seq.empty) - checkAnswer(sink.allData, Seq.empty) - - // Add batch 0 and check outputs - sink.addBatch(0, 1 to 3) - assert(sink.latestBatchId === Some(0)) - checkAnswer(sink.latestBatchData, 1 to 3) - checkAnswer(sink.allData, 1 to 3) - - // Add batch 1 and check outputs - sink.addBatch(1, 4 to 10) - assert(sink.latestBatchId === Some(1)) - checkAnswer(sink.latestBatchData, 4 to 8) - checkAnswer(sink.allData, 4 to 8) // new data should replace old data - } - test("registering as a table in Append output mode") { val input = MemoryStream[Int] @@ -265,7 +211,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("MemoryPlan statistics") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty()) + val sink = new MemorySink(schema, OutputMode.Append) val plan = new MemoryPlan(sink) // Before adding data, check output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index e539510e15755..9be22d94b5654 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -17,16 +17,11 @@ package org.apache.spark.sql.execution.streaming -import scala.collection.JavaConverters._ - import org.scalatest.BeforeAndAfter import org.apache.spark.sql.Row import org.apache.spark.sql.execution.streaming.sources._ -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.{OutputMode, StreamTest} -import org.apache.spark.sql.types.IntegerType -import org.apache.spark.sql.types.StructType class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("data writer") { @@ -45,7 +40,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("continuous writer") { val sink = new MemorySinkV2 - val writer = new MemoryStreamWriter(sink, OutputMode.Append(), DataSourceOptions.empty()) + val writer = new MemoryStreamWriter(sink, OutputMode.Append()) writer.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), @@ -67,7 +62,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("microbatch writer") { val sink = new MemorySinkV2 - new MemoryWriter(sink, 0, OutputMode.Append(), DataSourceOptions.empty()).commit( + new MemoryWriter(sink, 0, OutputMode.Append()).commit( Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), @@ -75,7 +70,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - new MemoryWriter(sink, 19, OutputMode.Append(), DataSourceOptions.empty()).commit( + new MemoryWriter(sink, 19, OutputMode.Append()).commit( Array( MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), MemoryWriterCommitMessage(0, Seq(Row(33))) @@ -85,73 +80,4 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) } - - test("continuous writer with row limit") { - val sink = new MemorySinkV2 - val optionsMap = new scala.collection.mutable.HashMap[String, String] - optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 7.toString()) - val options = new DataSourceOptions(optionsMap.toMap.asJava) - val appendWriter = new MemoryStreamWriter(sink, OutputMode.Append(), options) - appendWriter.commit(0, Array( - MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), - MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), - MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))))) - assert(sink.latestBatchId.contains(0)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - appendWriter.commit(19, Array( - MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), - MemoryWriterCommitMessage(0, Seq(Row(33))))) - assert(sink.latestBatchId.contains(19)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11)) - - assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11)) - - val completeWriter = new MemoryStreamWriter(sink, OutputMode.Complete(), options) - completeWriter.commit(20, Array( - MemoryWriterCommitMessage(4, Seq(Row(11), Row(22))), - MemoryWriterCommitMessage(5, Seq(Row(33))))) - assert(sink.latestBatchId.contains(20)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) - completeWriter.commit(21, Array( - MemoryWriterCommitMessage(0, Seq(Row(1), Row(2), Row(3))), - MemoryWriterCommitMessage(1, Seq(Row(4), Row(5), Row(6))), - MemoryWriterCommitMessage(2, Seq(Row(7), Row(8), Row(9))))) - assert(sink.latestBatchId.contains(21)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5, 6, 7)) - - assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5, 6, 7)) - } - - test("microbatch writer with row limit") { - val sink = new MemorySinkV2 - val optionsMap = new scala.collection.mutable.HashMap[String, String] - optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) - val options = new DataSourceOptions(optionsMap.toMap.asJava) - - new MemoryWriter(sink, 25, OutputMode.Append(), options).commit(Array( - MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), - MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))))) - assert(sink.latestBatchId.contains(25)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) - assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) - new MemoryWriter(sink, 26, OutputMode.Append(), options).commit(Array( - MemoryWriterCommitMessage(2, Seq(Row(5), Row(6))), - MemoryWriterCommitMessage(3, Seq(Row(7), Row(8))))) - assert(sink.latestBatchId.contains(26)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(5)) - assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5)) - - new MemoryWriter(sink, 27, OutputMode.Complete(), options).commit(Array( - MemoryWriterCommitMessage(4, Seq(Row(9), Row(10))), - MemoryWriterCommitMessage(5, Seq(Row(11), Row(12))))) - assert(sink.latestBatchId.contains(27)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) - assert(sink.allData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) - new MemoryWriter(sink, 28, OutputMode.Complete(), options).commit(Array( - MemoryWriterCommitMessage(4, Seq(Row(13), Row(14), Row(15))), - MemoryWriterCommitMessage(5, Seq(Row(16), Row(17), Row(18))))) - assert(sink.latestBatchId.contains(28)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) - assert(sink.allData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index c1ec1eba69fb2..ca38f04136c7d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -805,6 +805,114 @@ class StreamSuite extends StreamTest { } } + test("streaming limit without state") { + val inputData1 = MemoryStream[Int] + testStream(inputData1.toDF().limit(0))( + AddData(inputData1, 1 to 8: _*), + CheckAnswer()) + + val inputData2 = MemoryStream[Int] + testStream(inputData2.toDF().limit(4))( + AddData(inputData2, 1 to 8: _*), + CheckAnswer(1 to 4: _*)) + } + + test("streaming limit with state") { + val inputData = MemoryStream[Int] + testStream(inputData.toDF().limit(4))( + AddData(inputData, 1 to 2: _*), + CheckAnswer(1 to 2: _*), + AddData(inputData, 3 to 6: _*), + CheckAnswer(1 to 4: _*), + AddData(inputData, 7 to 9: _*), + CheckAnswer(1 to 4: _*)) + } + + test("streaming limit with other operators") { + val inputData = MemoryStream[Int] + testStream(inputData.toDF().where("value % 2 = 1").limit(4))( + AddData(inputData, 1 to 5: _*), + CheckAnswer(1, 3, 5), + AddData(inputData, 6 to 9: _*), + CheckAnswer(1, 3, 5, 7), + AddData(inputData, 10 to 12: _*), + CheckAnswer(1, 3, 5, 7)) + } + + test("streaming limit with multiple limits") { + val inputData1 = MemoryStream[Int] + testStream(inputData1.toDF().limit(4).limit(2))( + AddData(inputData1, 1), + CheckAnswer(1), + AddData(inputData1, 2 to 8: _*), + CheckAnswer(1, 2)) + + val inputData2 = MemoryStream[Int] + testStream(inputData2.toDF().limit(4).limit(100).limit(3))( + AddData(inputData2, 1, 2), + CheckAnswer(1, 2), + AddData(inputData2, 3 to 8: _*), + CheckAnswer(1 to 3: _*)) + } + + test("streaming limit in complete mode") { + val inputData = MemoryStream[Int] + val limited = inputData.toDF().limit(5).groupBy("value").count() + testStream(limited, OutputMode.Complete())( + AddData(inputData, 1 to 3: _*), + CheckAnswer(Row(1, 1), Row(2, 1), Row(3, 1)), + AddData(inputData, 1 to 9: _*), + CheckAnswer(Row(1, 2), Row(2, 2), Row(3, 2), Row(4, 1), Row(5, 1))) + } + + test("streaming limits in complete mode") { + val inputData = MemoryStream[Int] + val limited = inputData.toDF().limit(4).groupBy("value").count().orderBy("value").limit(3) + testStream(limited, OutputMode.Complete())( + AddData(inputData, 1 to 9: _*), + CheckAnswer(Row(1, 1), Row(2, 1), Row(3, 1)), + AddData(inputData, 2 to 6: _*), + CheckAnswer(Row(1, 1), Row(2, 2), Row(3, 2))) + } + + test("streaming limit in update mode") { + val inputData = MemoryStream[Int] + val e = intercept[AnalysisException] { + testStream(inputData.toDF().limit(5), OutputMode.Update())( + AddData(inputData, 1 to 3: _*) + ) + } + assert(e.getMessage.contains( + "Limits are not supported on streaming DataFrames/Datasets in Update output mode")) + } + + test("streaming limit in multiple partitions") { + val inputData = MemoryStream[Int] + testStream(inputData.toDF().repartition(2).limit(7))( + AddData(inputData, 1 to 10: _*), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 7 && rows.forall(r => r.getInt(0) <= 10)), + false), + AddData(inputData, 11 to 20: _*), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 7 && rows.forall(r => r.getInt(0) <= 10)), + false)) + } + + test("streaming limit in multiple partitions by column") { + val inputData = MemoryStream[(Int, Int)] + val df = inputData.toDF().repartition(2, $"_2").limit(7) + testStream(df)( + AddData(inputData, (1, 0), (2, 0), (3, 1), (4, 1)), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 4 && rows.forall(r => r.getInt(0) <= 4)), + false), + AddData(inputData, (5, 0), (6, 0), (7, 1), (8, 1)), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 7 && rows.forall(r => r.getInt(0) <= 8)), + false)) + } + for (e <- Seq( new InterruptedException, new InterruptedIOException, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index e41b4534ed51d..4c3fd58cb2e45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -45,7 +45,6 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -338,8 +337,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be var currentStream: StreamExecution = null var lastStream: StreamExecution = null val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for - val sink = if (useV2Sink) new MemorySinkV2 - else new MemorySink(stream.schema, outputMode, DataSourceOptions.empty()) + val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode) val resetConfValues = mutable.Map[String, Option[String]]() val defaultCheckpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath From 6078b891da8fe7fc36579699473168ae7443284c Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 10 Jul 2018 18:03:40 -0700 Subject: [PATCH 1091/2461] [SPARK-24730][SS] Add policy to choose max as global watermark when streaming query has multiple watermarks ## What changes were proposed in this pull request? Currently, when a streaming query has multiple watermark, the policy is to choose the min of them as the global watermark. This is safe to do as the global watermark moves with the slowest stream, and is therefore is safe as it does not unexpectedly drop some data as late, etc. While this is indeed the safe thing to do, in some cases, you may want the watermark to advance with the fastest stream, that is, take the max of multiple watermarks. This PR is to add that configuration. It makes the following changes. - Adds a configuration to specify max as the policy. - Saves the configuration in OffsetSeqMetadata because changing it in the middle can lead to unpredictable results. - For old checkpoints without the configuration, it assumes the default policy as min (irrespective of the policy set at the session where the query is being restarted). This is to ensure that existing queries are affected in any way. TODO - [ ] Add a test for recovery from existing checkpoints. ## How was this patch tested? New unit test Author: Tathagata Das Closes #21701 from tdas/SPARK-24730. --- .../apache/spark/sql/internal/SQLConf.scala | 15 ++ .../streaming/MicroBatchExecution.scala | 4 +- .../sql/execution/streaming/OffsetSeq.scala | 37 ++++- .../streaming/WatermarkTracker.scala | 90 ++++++++++-- .../commits/0 | 2 + .../commits/1 | 2 + .../metadata | 1 + .../offsets/0 | 4 + .../offsets/1 | 4 + .../streaming/EventTimeWatermarkSuite.scala | 136 +++++++++++++++++- 10 files changed, 276 insertions(+), 19 deletions(-) create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/0 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/1 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/metadata create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/0 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/1 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 50965c1abc68c..ae56cc97581a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -875,6 +875,21 @@ object SQLConf { .stringConf .createWithDefault("org.apache.spark.sql.execution.streaming.ManifestFileCommitProtocol") + val STREAMING_MULTIPLE_WATERMARK_POLICY = + buildConf("spark.sql.streaming.multipleWatermarkPolicy") + .doc("Policy to calculate the global watermark value when there are multiple watermark " + + "operators in a streaming query. The default value is 'min' which chooses " + + "the minimum watermark reported across multiple operators. Other alternative value is" + + "'max' which chooses the maximum across multiple operators." + + "Note: This configuration cannot be changed between query restarts from the same " + + "checkpoint location.") + .stringConf + .checkValue( + str => Set("min", "max").contains(str.toLowerCase), + "Invalid value for 'spark.sql.streaming.multipleWatermarkPolicy'. " + + "Valid values are 'min' and 'max'") + .createWithDefault("min") // must be same as MultipleWatermarkPolicy.DEFAULT_POLICY_NAME + val OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD = buildConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 17ffa2a517312..16651dd060d73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -61,7 +61,7 @@ class MicroBatchExecution( case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger") } - private val watermarkTracker = new WatermarkTracker() + private var watermarkTracker: WatermarkTracker = _ override lazy val logicalPlan: LogicalPlan = { assert(queryExecutionThread eq Thread.currentThread, @@ -257,6 +257,7 @@ class MicroBatchExecution( OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf) offsetSeqMetadata = OffsetSeqMetadata( metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) + watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf) watermarkTracker.setWatermark(metadata.batchWatermarkMs) } @@ -295,6 +296,7 @@ class MicroBatchExecution( case None => // We are starting this stream for the first time. logInfo(s"Starting new streaming query.") currentBatchId = 0 + watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 787174481ff08..1ae3f36c152cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -22,7 +22,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.internal.Logging import org.apache.spark.sql.RuntimeConfig -import org.apache.spark.sql.internal.SQLConf.{SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS} +import org.apache.spark.sql.internal.SQLConf._ /** * An ordered collection of offsets, used to track the progress of processing data from one or more @@ -86,7 +86,22 @@ case class OffsetSeqMetadata( object OffsetSeqMetadata extends Logging { private implicit val format = Serialization.formats(NoTypeHints) - private val relevantSQLConfs = Seq(SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS) + private val relevantSQLConfs = Seq( + SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY) + + /** + * Default values of relevant configurations that are used for backward compatibility. + * As new configurations are added to the metadata, existing checkpoints may not have those + * confs. The values in this list ensures that the confs without recovered values are + * set to a default value that ensure the same behavior of the streaming query as it was before + * the restart. + * + * Note, that this is optional; set values here if you *have* to override existing session conf + * with a specific default value for ensuring same behavior of the query as before. + */ + private val relevantSQLConfDefaultValues = Map[String, String]( + STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME + ) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) @@ -115,8 +130,22 @@ object OffsetSeqMetadata extends Logging { case None => // For backward compatibility, if a config was not recorded in the offset log, - // then log it, and let the existing conf value in SparkSession prevail. - logWarning (s"Conf '$confKey' was not found in the offset log, using existing value") + // then either inject a default value (if specified in `relevantSQLConfDefaultValues`) or + // let the existing conf value in SparkSession prevail. + relevantSQLConfDefaultValues.get(confKey) match { + + case Some(defaultValue) => + sessionConf.set(confKey, defaultValue) + logWarning(s"Conf '$confKey' was not found in the offset log, " + + s"using default value '$defaultValue'") + + case None => + val valueStr = sessionConf.getOption(confKey).map { v => + s" Using existing session conf value '$v'." + }.getOrElse { " No value set in session conf." } + logWarning(s"Conf '$confKey' was not found in the offset log. $valueStr") + + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala index 80865669558dd..7b30db44a2090 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala @@ -20,15 +20,68 @@ package org.apache.spark.sql.execution.streaming import scala.collection.mutable import org.apache.spark.internal.Logging +import org.apache.spark.sql.RuntimeConfig import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.internal.SQLConf -class WatermarkTracker extends Logging { +/** + * Policy to define how to choose a new global watermark value if there are + * multiple watermark operators in a streaming query. + */ +sealed trait MultipleWatermarkPolicy { + def chooseGlobalWatermark(operatorWatermarks: Seq[Long]): Long +} + +object MultipleWatermarkPolicy { + val DEFAULT_POLICY_NAME = "min" + + def apply(policyName: String): MultipleWatermarkPolicy = { + policyName.toLowerCase match { + case DEFAULT_POLICY_NAME => MinWatermark + case "max" => MaxWatermark + case _ => + throw new IllegalArgumentException(s"Could not recognize watermark policy '$policyName'") + } + } +} + +/** + * Policy to choose the *min* of the operator watermark values as the global watermark value. + * Note that this is the safe (hence default) policy as the global watermark will advance + * only if all the individual operator watermarks have advanced. In other words, in a + * streaming query with multiple input streams and watermarks defined on all of them, + * the global watermark will advance as slowly as the slowest input. So if there is watermark + * based state cleanup or late-data dropping, then this policy is the most conservative one. + */ +case object MinWatermark extends MultipleWatermarkPolicy { + def chooseGlobalWatermark(operatorWatermarks: Seq[Long]): Long = { + assert(operatorWatermarks.nonEmpty) + operatorWatermarks.min + } +} + +/** + * Policy to choose the *min* of the operator watermark values as the global watermark value. So the + * global watermark will advance if any of the individual operator watermarks has advanced. + * In other words, in a streaming query with multiple input streams and watermarks defined on all + * of them, the global watermark will advance as fast as the fastest input. So if there is watermark + * based state cleanup or late-data dropping, then this policy is the most aggressive one and + * may lead to unexpected behavior if the data of the slow stream is delayed. + */ +case object MaxWatermark extends MultipleWatermarkPolicy { + def chooseGlobalWatermark(operatorWatermarks: Seq[Long]): Long = { + assert(operatorWatermarks.nonEmpty) + operatorWatermarks.max + } +} + +/** Tracks the watermark value of a streaming query based on a given `policy` */ +case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging { private val operatorToWatermarkMap = mutable.HashMap[Int, Long]() - private var watermarkMs: Long = 0 - private var updated = false + private var globalWatermarkMs: Long = 0 def setWatermark(newWatermarkMs: Long): Unit = synchronized { - watermarkMs = newWatermarkMs + globalWatermarkMs = newWatermarkMs } def updateWatermark(executedPlan: SparkPlan): Unit = synchronized { @@ -37,7 +90,6 @@ class WatermarkTracker extends Logging { } if (watermarkOperators.isEmpty) return - watermarkOperators.zipWithIndex.foreach { case (e, index) if e.eventTimeStats.value.count > 0 => logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") @@ -58,16 +110,28 @@ class WatermarkTracker extends Logging { // This is the safest option, because only the global watermark is fault-tolerant. Making // it the minimum of all individual watermarks guarantees it will never advance past where // any individual watermark operator would be if it were in a plan by itself. - val newWatermarkMs = operatorToWatermarkMap.minBy(_._2)._2 - if (newWatermarkMs > watermarkMs) { - logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") - watermarkMs = newWatermarkMs - updated = true + val chosenGlobalWatermark = policy.chooseGlobalWatermark(operatorToWatermarkMap.values.toSeq) + if (chosenGlobalWatermark > globalWatermarkMs) { + logInfo(s"Updating event-time watermark from $globalWatermarkMs to $chosenGlobalWatermark ms") + globalWatermarkMs = chosenGlobalWatermark } else { - logDebug(s"Event time didn't move: $newWatermarkMs < $watermarkMs") - updated = false + logDebug(s"Event time watermark didn't move: $chosenGlobalWatermark < $globalWatermarkMs") } } - def currentWatermark: Long = synchronized { watermarkMs } + def currentWatermark: Long = synchronized { globalWatermarkMs } +} + +object WatermarkTracker { + def apply(conf: RuntimeConfig): WatermarkTracker = { + // If the session has been explicitly configured to use non-default policy then use it, + // otherwise use the default `min` policy as thats the safe thing to do. + // When recovering from a checkpoint location, it is expected that the `conf` will already + // be configured with the value present in the checkpoint. If there is no policy explicitly + // saved in the checkpoint (e.g., old checkpoints), then the default `min` policy is enforced + // through defaults specified in OffsetSeqMetadata.setSessionConf(). + val policyName = conf.get( + SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY, MultipleWatermarkPolicy.DEFAULT_POLICY_NAME) + new WatermarkTracker(MultipleWatermarkPolicy(policyName)) + } } diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/0 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/0 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/1 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/1 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/metadata new file mode 100644 index 0000000000000..d6be7fbffa9b7 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/metadata @@ -0,0 +1 @@ +{"id":"549eeb1a-d762-420c-bb44-3fd6d73a5268"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/0 new file mode 100644 index 0000000000000..43db49d052894 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/0 @@ -0,0 +1,4 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1531172902041,"conf":{"spark.sql.shuffle.partitions":"10","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +0 +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/1 new file mode 100644 index 0000000000000..8cc898e81017f --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/1 @@ -0,0 +1,4 @@ +v1 +{"batchWatermarkMs":10000,"batchTimestampMs":1531172902217,"conf":{"spark.sql.shuffle.partitions":"10","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +1 +0 \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index 7e8fde1ff8e56..58ed9790ea123 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -18,18 +18,22 @@ package org.apache.spark.sql.streaming import java.{util => ju} +import java.io.File import java.text.SimpleDateFormat import java.util.{Calendar, Date} +import org.apache.commons.io.FileUtils import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.internal.Logging -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.{AnalysisException, Dataset} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ +import org.apache.spark.util.Utils class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matchers with Logging { @@ -484,6 +488,136 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testWithFlag(false) } + test("MultipleWatermarkPolicy: max") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "max") { + testStream(dfWithMultipleWatermarks(input1, input2))( + MultiAddData(input1, 20)(input2, 30), + CheckLastBatch(20, 30), + checkWatermark(input1, 15), // max(20 - 10, 30 - 15) = 15 + StopStream, + StartStream(), + checkWatermark(input1, 15), // watermark recovered correctly + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input1, 115), // max(120 - 10, 130 - 15) = 115, policy recovered correctly + AddData(input1, 150), + CheckLastBatch(150), + checkWatermark(input1, 140) // should advance even if one of the input has data + ) + } + } + + test("MultipleWatermarkPolicy: min") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "min") { + testStream(dfWithMultipleWatermarks(input1, input2))( + MultiAddData(input1, 20)(input2, 30), + CheckLastBatch(20, 30), + checkWatermark(input1, 10), // min(20 - 10, 30 - 15) = 10 + StopStream, + StartStream(), + checkWatermark(input1, 10), // watermark recovered correctly + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input2, 110), // min(120 - 10, 130 - 15) = 110, policy recovered correctly + AddData(input2, 150), + CheckLastBatch(150), + checkWatermark(input2, 110) // does not advance when only one of the input has data + ) + } + } + + test("MultipleWatermarkPolicy: recovery from checkpoints ignores session conf") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val checkpointDir = Utils.createTempDir().getCanonicalFile + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "max") { + testStream(dfWithMultipleWatermarks(input1, input2))( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + MultiAddData(input1, 20)(input2, 30), + CheckLastBatch(20, 30), + checkWatermark(input1, 15) // max(20 - 10, 30 - 15) = 15 + ) + } + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "min") { + testStream(dfWithMultipleWatermarks(input1, input2))( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + checkWatermark(input1, 15), // watermark recovered correctly + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input1, 115), // max(120 - 10, 130 - 15) = 115, policy recovered correctly + AddData(input1, 150), + CheckLastBatch(150), + checkWatermark(input1, 140) // should advance even if one of the input has data + ) + } + } + + test("MultipleWatermarkPolicy: recovery from Spark ver 2.3.1 checkpoints ensures min policy") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + input1.addData(20) + input2.addData(30) + input1.addData(10) + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "max") { + testStream(dfWithMultipleWatermarks(input1, input2))( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + Execute { _.processAllAvailable() }, + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input2, 110), // should calculate 'min' even if session conf has 'max' policy + AddData(input2, 150), + CheckLastBatch(150), + checkWatermark(input2, 110) + ) + } + } + + test("MultipleWatermarkPolicy: fail on incorrect conf values") { + val invalidValues = Seq("", "random") + invalidValues.foreach { value => + val e = intercept[IllegalArgumentException] { + spark.conf.set(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key, value) + } + assert(e.getMessage.toLowerCase.contains("valid values are 'min' and 'max'")) + } + } + + private def dfWithMultipleWatermarks( + input1: MemoryStream[Int], + input2: MemoryStream[Int]): Dataset[_] = { + val df1 = input1.toDF + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + val df2 = input2.toDF + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "15 seconds") + df1.union(df2).select($"eventTime".cast("int")) + } + + private def checkWatermark(input: MemoryStream[Int], watermark: Long) = Execute { q => + input.addData(1) + q.processAllAvailable() + assert(q.lastProgress.eventTime.get("watermark") == formatTimestamp(watermark)) + } + private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q => q.processAllAvailable() val progressWithData = q.recentProgress.lastOption.get From 1f94bf492c3bce3b61f7fec6132b50e06dea94a8 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 11 Jul 2018 10:10:07 +0800 Subject: [PATCH 1092/2461] [SPARK-24530][PYTHON] Add a control to force Python version in Sphinx via environment variable, SPHINXPYTHON ## What changes were proposed in this pull request? This PR proposes to add `SPHINXPYTHON` environment variable to control the Python version used by Sphinx. The motivation of this environment variable is, it seems not properly rendering some signatures in the Python documentation when Python 2 is used by Sphinx. See the JIRA's case. It should be encouraged to use Python 3, but looks we will probably live with this problem for a long while in any event. For the default case of `make html`, it keeps previous behaviour and use `SPHINXBUILD` as it was. If `SPHINXPYTHON` is set, then it forces Sphinx to use the specific Python version. ``` $ SPHINXPYTHON=python3 make html python3 -msphinx -b html -d _build/doctrees . _build/html Running Sphinx v1.7.5 ... ``` 1. if `SPHINXPYTHON` is set, use Python. If `SPHINXBUILD` is set, use sphinx-build. 2. If both are set, `SPHINXBUILD` has a higher priority over `SPHINXPYTHON` 3. By default, `SPHINXBUILD` is used as 'sphinx-build'. Probably, we can somehow work around this via explicitly setting `SPHINXBUILD` but `sphinx-build` can't be easily distinguished since it (at least in my environment and up to my knowledge) doesn't replace `sphinx-build` when newer Sphinx is installed in different Python version. It confuses and doesn't warn for its Python version. ## How was this patch tested? Manually tested: **`python` (Python 2.7) in the path with Sphinx:** ``` $ make html sphinx-build -b html -d _build/doctrees . _build/html Running Sphinx v1.7.5 ... ``` **`python` (Python 2.7) in the path without Sphinx:** ``` $ make html Makefile:8: *** The 'sphinx-build' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the 'sphinx-build' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/. Stop. ``` **`SPHINXPYTHON` set `python` (Python 2.7) with Sphinx:** ``` $ SPHINXPYTHON=python make html Makefile:35: *** Note that Python 3 is required to generate PySpark documentation correctly for now. Current Python executable was less than Python 3. See SPARK-24530. To force Sphinx to use a specific Python executable, please set SPHINXPYTHON to point to the Python 3 executable.. Stop. ``` **`SPHINXPYTHON` set `python` (Python 2.7) without Sphinx:** ``` $ SPHINXPYTHON=python make html Makefile:35: *** Note that Python 3 is required to generate PySpark documentation correctly for now. Current Python executable was less than Python 3. See SPARK-24530. To force Sphinx to use a specific Python executable, please set SPHINXPYTHON to point to the Python 3 executable.. Stop. ``` **`SPHINXPYTHON` set `python3` with Sphinx:** ``` $ SPHINXPYTHON=python3 make html python3 -msphinx -b html -d _build/doctrees . _build/html Running Sphinx v1.7.5 ... ``` **`SPHINXPYTHON` set `python3` without Sphinx:** ``` $ SPHINXPYTHON=python3 make html Makefile:39: *** Python executable 'python3' did not have Sphinx installed. Make sure you have Sphinx installed, then set the SPHINXPYTHON environment variable to point to the Python executable having Sphinx installed. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/. Stop. ``` **`SPHINXBUILD` set:** ``` $ SPHINXBUILD=sphinx-build make html sphinx-build -b html -d _build/doctrees . _build/html Running Sphinx v1.7.5 ... ``` **Both `SPHINXPYTHON` and `SPHINXBUILD` are set:** ``` $ SPHINXBUILD=sphinx-build SPHINXPYTHON=python make html sphinx-build -b html -d _build/doctrees . _build/html Running Sphinx v1.7.5 ... ``` Author: hyukjinkwon Closes #21659 from HyukjinKwon/SPARK-24530. --- python/docs/Makefile | 37 +++++++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/python/docs/Makefile b/python/docs/Makefile index b8e079483c90c..1ed1f33af2326 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -1,19 +1,44 @@ # Makefile for Sphinx documentation # +ifndef SPHINXBUILD +ifndef SPHINXPYTHON +SPHINXBUILD = sphinx-build +endif +endif + +ifdef SPHINXBUILD +# User-friendly check for sphinx-build. +ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) +$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) +endif +else +# Note that there is an issue with Python version and Sphinx in PySpark documentation generation. +# Please remove this check below when this issue is fixed. See SPARK-24530 for more details. +PYTHON_VERSION_CHECK = $(shell $(SPHINXPYTHON) -c 'import sys; print(sys.version_info < (3, 0, 0))') +ifeq ($(PYTHON_VERSION_CHECK), True) +$(error Note that Python 3 is required to generate PySpark documentation correctly for now. Current Python executable was less than Python 3. See SPARK-24530. To force Sphinx to use a specific Python executable, please set SPHINXPYTHON to point to the Python 3 executable.) +endif +# Check if Sphinx is installed. +ifeq ($(shell $(SPHINXPYTHON) -c 'import sphinx' >/dev/null 2>&1; echo $$?), 1) +$(error Python executable '$(SPHINXPYTHON)' did not have Sphinx installed. Make sure you have Sphinx installed, then set the SPHINXPYTHON environment variable to point to the Python executable having Sphinx installed. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) +endif +# Use 'SPHINXPYTHON -msphinx' instead of 'sphinx-build'. See https://github.com/sphinx-doc/sphinx/pull/3523 for more details. +SPHINXBUILD = $(SPHINXPYTHON) -msphinx +endif + # You can set these variables from the command line. SPHINXOPTS ?= -SPHINXBUILD ?= sphinx-build PAPER ?= BUILDDIR ?= _build +# You can set SPHINXBUILD to specify Sphinx build executable or SPHINXPYTHON to specify the Python executable used in Sphinx. +# They follow: +# 1. if SPHINXPYTHON is set, use Python. If SPHINXBUILD is set, use sphinx-build. +# 2. If both are set, SPHINXBUILD has a higher priority over SPHINXPYTHON +# 3. By default, SPHINXBUILD is used as 'sphinx-build'. export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.7-src.zip) -# User-friendly check for sphinx-build -ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) -$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) -endif - # Internal variables. PAPEROPT_a4 = -D latex_paper_size=a4 PAPEROPT_letter = -D latex_paper_size=letter From 74a8d6308bfa6e7ed4c64e1175c77eb3114baed5 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Wed, 11 Jul 2018 12:21:03 +0800 Subject: [PATCH 1093/2461] [SPARK-24165][SQL] Fixing conditional expressions to handle nullability of nested types ## What changes were proposed in this pull request? This PR is proposing a fix for the output data type of ```If``` and ```CaseWhen``` expression. Upon till now, the implementation of exprassions has ignored nullability of nested types from different execution branches and returned the type of the first branch. This could lead to an unwanted ```NullPointerException``` from other expressions depending on a ```If```/```CaseWhen``` expression. Example: ``` val rows = new util.ArrayList[Row]() rows.add(Row(true, ("a", 1))) rows.add(Row(false, (null, 2))) val schema = StructType(Seq( StructField("cond", BooleanType, false), StructField("s", StructType(Seq( StructField("val1", StringType, true), StructField("val2", IntegerType, false) )), false) )) val df = spark.createDataFrame(rows, schema) df .select(when('cond, struct(lit("x").as("val1"), lit(10).as("val2"))).otherwise('s) as "res") .select('res.getField("val1")) .show() ``` Exception: ``` Exception in thread "main" java.lang.NullPointerException at org.apache.spark.sql.catalyst.expressions.codegen.UnsafeWriter.write(UnsafeWriter.java:109) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source) at org.apache.spark.sql.execution.LocalTableScanExec$$anonfun$unsafeRows$1.apply(LocalTableScanExec.scala:44) at org.apache.spark.sql.execution.LocalTableScanExec$$anonfun$unsafeRows$1.apply(LocalTableScanExec.scala:44) ... ``` Output schema: ``` root |-- res.val1: string (nullable = false) ``` ## How was this patch tested? New test cases added into - DataFrameSuite.scala - conditionalExpressions.scala Author: Marek Novotny Closes #21687 from mn-mikke/SPARK-24165. --- .../sql/catalyst/analysis/TypeCoercion.scala | 22 ++++-- .../sql/catalyst/expressions/Expression.scala | 37 +++++++++- .../expressions/conditionalExpressions.scala | 28 ++++---- .../ConditionalExpressionSuite.scala | 70 +++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 58 +++++++++++++++ 5 files changed, 195 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 72908c1f433ee..e8331c90ea0f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -172,6 +172,18 @@ object TypeCoercion { case _ => None } + /** + * The method finds a common type for data types that differ only in nullable, containsNull + * and valueContainsNull flags. If the input types are too different, None is returned. + */ + def findCommonTypeDifferentOnlyInNullFlags(t1: DataType, t2: DataType): Option[DataType] = { + if (t1 == t2) { + Some(t1) + } else { + findTypeForComplex(t1, t2, findCommonTypeDifferentOnlyInNullFlags) + } + } + /** * Case 2 type widening (see the classdoc comment above for TypeCoercion). * @@ -660,8 +672,8 @@ object TypeCoercion { object CaseWhenCoercion extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => - val maybeCommonType = findWiderCommonType(c.valueTypes) + case c: CaseWhen if c.childrenResolved && !c.areInputTypesForMergingEqual => + val maybeCommonType = findWiderCommonType(c.inputTypesForMerging) maybeCommonType.map { commonType => var changed = false val newBranches = c.branches.map { case (condition, value) => @@ -693,10 +705,10 @@ object TypeCoercion { plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. - case i @ If(pred, left, right) if left.dataType != right.dataType => + case i @ If(pred, left, right) if !i.areInputTypesForMergingEqual => findWiderTypeForTwo(left.dataType, right.dataType).map { widestType => - val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) - val newRight = if (right.dataType == widestType) right else Cast(right, widestType) + val newLeft = if (left.dataType.sameType(widestType)) left else Cast(left, widestType) + val newRight = if (right.dataType.sameType(widestType)) right else Cast(right, widestType) If(pred, newLeft, newRight) }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. case If(Literal(null, NullType), left, right) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 9b9fa41a47d0f..44c5556ff9ccf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreeNode @@ -695,6 +695,41 @@ abstract class TernaryExpression extends Expression { } } +/** + * A trait resolving nullable, containsNull, valueContainsNull flags of the output date type. + * This logic is usually utilized by expressions combining data from multiple child expressions + * of non-primitive types (e.g. [[CaseWhen]]). + */ +trait ComplexTypeMergingExpression extends Expression { + + /** + * A collection of data types used for resolution the output type of the expression. By default, + * data types of all child expressions. The collection must not be empty. + */ + @transient + lazy val inputTypesForMerging: Seq[DataType] = children.map(_.dataType) + + /** + * A method determining whether the input types are equal ignoring nullable, containsNull and + * valueContainsNull flags and thus convenient for resolution of the final data type. + */ + def areInputTypesForMergingEqual: Boolean = { + inputTypesForMerging.length <= 1 || inputTypesForMerging.sliding(2, 1).forall { + case Seq(dt1, dt2) => dt1.sameType(dt2) + } + } + + override def dataType: DataType = { + require( + inputTypesForMerging.nonEmpty, + "The collection of input data types must not be empty.") + require( + areInputTypesForMergingEqual, + "All input types must be the same except nullable, containsNull, valueContainsNull flags.") + inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get) + } +} + /** * Common base trait for user-defined functions, including UDF/UDAF/UDTF of different languages * and Hive function wrappers. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 77ac6c088022e..e6377b7d87b53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ @@ -33,7 +33,12 @@ import org.apache.spark.sql.types._ """) // scalastyle:on line.size.limit case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) - extends Expression { + extends ComplexTypeMergingExpression { + + @transient + override lazy val inputTypesForMerging: Seq[DataType] = { + Seq(trueValue.dataType, falseValue.dataType) + } override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil override def nullable: Boolean = trueValue.nullable || falseValue.nullable @@ -43,7 +48,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi TypeCheckResult.TypeCheckFailure( "type of predicate expression in If should be boolean, " + s"not ${predicate.dataType.simpleString}") - } else if (!trueValue.dataType.sameType(falseValue.dataType)) { + } else if (!areInputTypesForMergingEqual) { TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).") } else { @@ -51,8 +56,6 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } } - override def dataType: DataType = trueValue.dataType - override def eval(input: InternalRow): Any = { if (java.lang.Boolean.TRUE.equals(predicate.eval(input))) { trueValue.eval(input) @@ -118,27 +121,24 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi case class CaseWhen( branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None) - extends Expression with Serializable { + extends ComplexTypeMergingExpression with Serializable { override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue // both then and else expressions should be considered. - def valueTypes: Seq[DataType] = branches.map(_._2.dataType) ++ elseValue.map(_.dataType) - - def valueTypesEqual: Boolean = valueTypes.size <= 1 || valueTypes.sliding(2, 1).forall { - case Seq(dt1, dt2) => dt1.sameType(dt2) + @transient + override lazy val inputTypesForMerging: Seq[DataType] = { + branches.map(_._2.dataType) ++ elseValue.map(_.dataType) } - override def dataType: DataType = branches.head._2.dataType - override def nullable: Boolean = { // Result is nullable if any of the branch is nullable, or if the else value is nullable branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true) } override def checkInputDataTypes(): TypeCheckResult = { - // Make sure all branch conditions are boolean types. - if (valueTypesEqual) { + if (areInputTypesForMergingEqual) { + // Make sure all branch conditions are boolean types. if (branches.forall(_._1.dataType == BooleanType)) { TypeCheckResult.TypeCheckSuccess } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index a099119732e25..e068c32500cfc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -113,6 +113,76 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5))).nullable === true) } + test("if/case when - null flags of non-primitive types") { + val arrayWithNulls = Literal.create(Seq("a", null, "b"), ArrayType(StringType, true)) + val arrayWithoutNulls = Literal.create(Seq("c", "d"), ArrayType(StringType, false)) + val structWithNulls = Literal.create( + create_row(null, null), + StructType(Seq(StructField("a", IntegerType, true), StructField("b", StringType, true)))) + val structWithoutNulls = Literal.create( + create_row(1, "a"), + StructType(Seq(StructField("a", IntegerType, false), StructField("b", StringType, false)))) + val mapWithNulls = Literal.create(Map(1 -> null), MapType(IntegerType, StringType, true)) + val mapWithoutNulls = Literal.create(Map(1 -> "a"), MapType(IntegerType, StringType, false)) + + val arrayIf1 = If(Literal.FalseLiteral, arrayWithNulls, arrayWithoutNulls) + val arrayIf2 = If(Literal.FalseLiteral, arrayWithoutNulls, arrayWithNulls) + val arrayIf3 = If(Literal.TrueLiteral, arrayWithNulls, arrayWithoutNulls) + val arrayIf4 = If(Literal.TrueLiteral, arrayWithoutNulls, arrayWithNulls) + val structIf1 = If(Literal.FalseLiteral, structWithNulls, structWithoutNulls) + val structIf2 = If(Literal.FalseLiteral, structWithoutNulls, structWithNulls) + val structIf3 = If(Literal.TrueLiteral, structWithNulls, structWithoutNulls) + val structIf4 = If(Literal.TrueLiteral, structWithoutNulls, structWithNulls) + val mapIf1 = If(Literal.FalseLiteral, mapWithNulls, mapWithoutNulls) + val mapIf2 = If(Literal.FalseLiteral, mapWithoutNulls, mapWithNulls) + val mapIf3 = If(Literal.TrueLiteral, mapWithNulls, mapWithoutNulls) + val mapIf4 = If(Literal.TrueLiteral, mapWithoutNulls, mapWithNulls) + + val arrayCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, arrayWithNulls)), arrayWithoutNulls) + val arrayCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, arrayWithoutNulls)), arrayWithNulls) + val arrayCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, arrayWithNulls)), arrayWithoutNulls) + val arrayCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, arrayWithoutNulls)), arrayWithNulls) + val structCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, structWithNulls)), structWithoutNulls) + val structCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, structWithoutNulls)), structWithNulls) + val structCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, structWithNulls)), structWithoutNulls) + val structCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, structWithoutNulls)), structWithNulls) + val mapCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, mapWithNulls)), mapWithoutNulls) + val mapCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, mapWithoutNulls)), mapWithNulls) + val mapCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, mapWithNulls)), mapWithoutNulls) + val mapCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, mapWithoutNulls)), mapWithNulls) + + def checkResult(expectedType: DataType, expectedValue: Any, result: Expression): Unit = { + assert(expectedType == result.dataType) + checkEvaluation(result, expectedValue) + } + + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayIf1) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayIf2) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayIf3) + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayIf4) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structIf1) + checkResult(structWithNulls.dataType, structWithNulls.value, structIf2) + checkResult(structWithNulls.dataType, structWithNulls.value, structIf3) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structIf4) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapIf1) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapIf2) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapIf3) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapIf4) + + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayCaseWhen1) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayCaseWhen2) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayCaseWhen3) + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayCaseWhen4) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structCaseWhen1) + checkResult(structWithNulls.dataType, structWithNulls.value, structCaseWhen2) + checkResult(structWithNulls.dataType, structWithNulls.value, structCaseWhen3) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structCaseWhen4) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapCaseWhen1) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapCaseWhen2) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapCaseWhen3) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapCaseWhen4) + } + test("case key when") { val row = create_row(null, 1, 2, "a", "b", "c") val c1 = 'a.int.at(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ea00d22bff001..9d7645d232d08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2320,6 +2320,64 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) } + test("SPARK-24165: CaseWhen/If - nullability of nested types") { + val rows = new java.util.ArrayList[Row]() + rows.add(Row(true, ("x", 1), Seq("x", "y"), Map(0 -> "x"))) + rows.add(Row(false, (null, 2), Seq(null, "z"), Map(0 -> null))) + val schema = StructType(Seq( + StructField("cond", BooleanType, true), + StructField("s", StructType(Seq( + StructField("val1", StringType, true), + StructField("val2", IntegerType, false) + )), false), + StructField("a", ArrayType(StringType, true)), + StructField("m", MapType(IntegerType, StringType, true)) + )) + + val sourceDF = spark.createDataFrame(rows, schema) + + val structWhenDF = sourceDF + .select(when('cond, struct(lit("a").as("val1"), lit(10).as("val2"))).otherwise('s) as "res") + .select('res.getField("val1")) + val arrayWhenDF = sourceDF + .select(when('cond, array(lit("a"), lit("b"))).otherwise('a) as "res") + .select('res.getItem(0)) + val mapWhenDF = sourceDF + .select(when('cond, map(lit(0), lit("a"))).otherwise('m) as "res") + .select('res.getItem(0)) + + val structIfDF = sourceDF + .select(expr("if(cond, struct('a' as val1, 10 as val2), s)") as "res") + .select('res.getField("val1")) + val arrayIfDF = sourceDF + .select(expr("if(cond, array('a', 'b'), a)") as "res") + .select('res.getItem(0)) + val mapIfDF = sourceDF + .select(expr("if(cond, map(0, 'a'), m)") as "res") + .select('res.getItem(0)) + + def checkResult(df: DataFrame, codegenExpected: Boolean): Unit = { + assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] == codegenExpected) + checkAnswer(df, Seq(Row("a"), Row(null))) + } + + // without codegen + checkResult(structWhenDF, false) + checkResult(arrayWhenDF, false) + checkResult(mapWhenDF, false) + checkResult(structIfDF, false) + checkResult(arrayIfDF, false) + checkResult(mapIfDF, false) + + // with codegen + checkResult(structWhenDF.filter('cond.isNotNull), true) + checkResult(arrayWhenDF.filter('cond.isNotNull), true) + checkResult(mapWhenDF.filter('cond.isNotNull), true) + checkResult(structIfDF.filter('cond.isNotNull), true) + checkResult(arrayIfDF.filter('cond.isNotNull), true) + checkResult(mapIfDF.filter('cond.isNotNull), true) + } + test("Uuid expressions should produce same results at retries in the same DataFrame") { val df = spark.range(1).select($"id", new Column(Uuid())) checkAnswer(df, df.collect()) From 5ff1b9ba1983d5601add62aef64a3e87d07050eb Mon Sep 17 00:00:00 2001 From: Andrew Korzhuev Date: Tue, 10 Jul 2018 22:53:44 -0700 Subject: [PATCH 1094/2461] [SPARK-23529][K8S] Support mounting volumes This PR continues #21095 and intersects with #21238. I've added volume mounts as a separate step and added PersistantVolumeClaim support. There is a fundamental problem with how we pass the options through spark conf to fabric8. For each volume type and all possible volume options we would have to implement some custom code to map config values to fabric8 calls. This will result in big body of code we would have to support and means that Spark will always be somehow out of sync with k8s. I think there needs to be a discussion on how to proceed correctly (eg use PodPreset instead) ---- Due to the complications of provisioning and managing actual resources this PR addresses only volume mounting of already present resources. ---- - [x] emptyDir support - [x] Testing - [x] Documentation - [x] KubernetesVolumeUtils tests Author: Andrew Korzhuev Author: madanadit Closes #21260 from andrusha/k8s-vol. --- docs/running-on-kubernetes.md | 48 ++++++ .../org/apache/spark/deploy/k8s/Config.scala | 12 ++ .../spark/deploy/k8s/KubernetesConf.scala | 11 ++ .../spark/deploy/k8s/KubernetesUtils.scala | 2 - .../deploy/k8s/KubernetesVolumeSpec.scala | 38 +++++ .../deploy/k8s/KubernetesVolumeUtils.scala | 110 +++++++++++++ .../k8s/features/BasicDriverFeatureStep.scala | 5 +- .../features/BasicExecutorFeatureStep.scala | 5 +- .../features/MountVolumesFeatureStep.scala | 79 ++++++++++ .../k8s/submit/KubernetesDriverBuilder.scala | 31 ++-- .../k8s/KubernetesExecutorBuilder.scala | 38 ++--- .../k8s/KubernetesVolumeUtilsSuite.scala | 106 +++++++++++++ .../BasicDriverFeatureStepSuite.scala | 23 +-- .../BasicExecutorFeatureStepSuite.scala | 3 + ...ubernetesCredentialsFeatureStepSuite.scala | 3 + .../DriverServiceFeatureStepSuite.scala | 6 + .../features/EnvSecretsFeatureStepSuite.scala | 1 + .../features/LocalDirsFeatureStepSuite.scala | 3 +- .../MountSecretsFeatureStepSuite.scala | 1 + .../MountVolumesFeatureStepSuite.scala | 144 ++++++++++++++++++ .../bindings/JavaDriverFeatureStepSuite.scala | 1 + .../PythonDriverFeatureStepSuite.scala | 2 + .../spark/deploy/k8s/submit/ClientSuite.scala | 1 + .../submit/KubernetesDriverBuilderSuite.scala | 45 +++++- .../k8s/KubernetesExecutorBuilderSuite.scala | 38 ++++- 25 files changed, 705 insertions(+), 51 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 408e446ea4822..7149616e534aa 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -629,6 +629,54 @@ specific to Spark on Kubernetes. Add as an environment variable to the executor container with name EnvName (case sensitive), the value referenced by key key in the data of the referenced Kubernetes Secret. For example, spark.kubernetes.executor.secrets.ENV_VAR=spark-secret:key. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index bf33179ae3dab..f9a77e71ad618 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -220,11 +220,23 @@ private[spark] object Config extends Logging { val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation." val KUBERNETES_DRIVER_SECRETS_PREFIX = "spark.kubernetes.driver.secrets." val KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX = "spark.kubernetes.driver.secretKeyRef." + val KUBERNETES_DRIVER_VOLUMES_PREFIX = "spark.kubernetes.driver.volumes." val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label." val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets." val KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX = "spark.kubernetes.executor.secretKeyRef." + val KUBERNETES_EXECUTOR_VOLUMES_PREFIX = "spark.kubernetes.executor.volumes." + + val KUBERNETES_VOLUMES_HOSTPATH_TYPE = "hostPath" + val KUBERNETES_VOLUMES_PVC_TYPE = "persistentVolumeClaim" + val KUBERNETES_VOLUMES_EMPTYDIR_TYPE = "emptyDir" + val KUBERNETES_VOLUMES_MOUNT_PATH_KEY = "mount.path" + val KUBERNETES_VOLUMES_MOUNT_READONLY_KEY = "mount.readOnly" + val KUBERNETES_VOLUMES_OPTIONS_PATH_KEY = "options.path" + val KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY = "options.claimName" + val KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY = "options.medium" + val KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY = "options.sizeLimit" val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index b0ccaa36b01ed..51d205fdb68d1 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -59,6 +59,7 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( roleSecretNamesToMountPaths: Map[String, String], roleSecretEnvNamesToKeyRefs: Map[String, String], roleEnvs: Map[String, String], + roleVolumes: Iterable[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]], sparkFiles: Seq[String]) { def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE) @@ -155,6 +156,12 @@ private[spark] object KubernetesConf { sparkConf, KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX) val driverEnvs = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_DRIVER_ENV_PREFIX) + val driverVolumes = KubernetesVolumeUtils.parseVolumesWithPrefix( + sparkConf, KUBERNETES_DRIVER_VOLUMES_PREFIX).map(_.get) + // Also parse executor volumes in order to verify configuration + // before the driver pod is created + KubernetesVolumeUtils.parseVolumesWithPrefix( + sparkConf, KUBERNETES_EXECUTOR_VOLUMES_PREFIX).map(_.get) val sparkFiles = sparkConf .getOption("spark.files") @@ -171,6 +178,7 @@ private[spark] object KubernetesConf { driverSecretNamesToMountPaths, driverSecretEnvNamesToKeyRefs, driverEnvs, + driverVolumes, sparkFiles) } @@ -203,6 +211,8 @@ private[spark] object KubernetesConf { val executorEnvSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX) val executorEnv = sparkConf.getExecutorEnv.toMap + val executorVolumes = KubernetesVolumeUtils.parseVolumesWithPrefix( + sparkConf, KUBERNETES_EXECUTOR_VOLUMES_PREFIX).map(_.get) KubernetesConf( sparkConf.clone(), @@ -214,6 +224,7 @@ private[spark] object KubernetesConf { executorMountSecrets, executorEnvSecrets, executorEnv, + executorVolumes, Seq.empty[String]) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 593fb531a004d..66fff267545dc 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -16,8 +16,6 @@ */ package org.apache.spark.deploy.k8s -import io.fabric8.kubernetes.api.model.LocalObjectReference - import org.apache.spark.SparkConf import org.apache.spark.util.Utils diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala new file mode 100644 index 0000000000000..b1762d1efe2ea --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +private[spark] sealed trait KubernetesVolumeSpecificConf + +private[spark] case class KubernetesHostPathVolumeConf( + hostPath: String) + extends KubernetesVolumeSpecificConf + +private[spark] case class KubernetesPVCVolumeConf( + claimName: String) + extends KubernetesVolumeSpecificConf + +private[spark] case class KubernetesEmptyDirVolumeConf( + medium: Option[String], + sizeLimit: Option[String]) + extends KubernetesVolumeSpecificConf + +private[spark] case class KubernetesVolumeSpec[T <: KubernetesVolumeSpecificConf]( + volumeName: String, + mountPath: String, + mountReadOnly: Boolean, + volumeConf: T) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala new file mode 100644 index 0000000000000..713df5fffc3a2 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import java.util.NoSuchElementException + +import scala.util.{Failure, Success, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ + +private[spark] object KubernetesVolumeUtils { + /** + * Extract Spark volume configuration properties with a given name prefix. + * + * @param sparkConf Spark configuration + * @param prefix the given property name prefix + * @return a Map storing with volume name as key and spec as value + */ + def parseVolumesWithPrefix( + sparkConf: SparkConf, + prefix: String): Iterable[Try[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]]] = { + val properties = sparkConf.getAllWithPrefix(prefix).toMap + + getVolumeTypesAndNames(properties).map { case (volumeType, volumeName) => + val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_PATH_KEY" + val readOnlyKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_READONLY_KEY" + + for { + path <- properties.getTry(pathKey) + volumeConf <- parseVolumeSpecificConf(properties, volumeType, volumeName) + } yield KubernetesVolumeSpec( + volumeName = volumeName, + mountPath = path, + mountReadOnly = properties.get(readOnlyKey).exists(_.toBoolean), + volumeConf = volumeConf + ) + } + } + + /** + * Get unique pairs of volumeType and volumeName, + * assuming options are formatted in this way: + * `volumeType`.`volumeName`.`property` = `value` + * @param properties flat mapping of property names to values + * @return Set[(volumeType, volumeName)] + */ + private def getVolumeTypesAndNames( + properties: Map[String, String] + ): Set[(String, String)] = { + properties.keys.flatMap { k => + k.split('.').toList match { + case tpe :: name :: _ => Some((tpe, name)) + case _ => None + } + }.toSet + } + + private def parseVolumeSpecificConf( + options: Map[String, String], + volumeType: String, + volumeName: String): Try[KubernetesVolumeSpecificConf] = { + volumeType match { + case KUBERNETES_VOLUMES_HOSTPATH_TYPE => + val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_PATH_KEY" + for { + path <- options.getTry(pathKey) + } yield KubernetesHostPathVolumeConf(path) + + case KUBERNETES_VOLUMES_PVC_TYPE => + val claimNameKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY" + for { + claimName <- options.getTry(claimNameKey) + } yield KubernetesPVCVolumeConf(claimName) + + case KUBERNETES_VOLUMES_EMPTYDIR_TYPE => + val mediumKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY" + val sizeLimitKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY" + Success(KubernetesEmptyDirVolumeConf(options.get(mediumKey), options.get(sizeLimitKey))) + + case _ => + Failure(new RuntimeException(s"Kubernetes Volume type `$volumeType` is not supported")) + } + } + + /** + * Convenience wrapper to accumulate key lookup errors + */ + implicit private class MapOps[A, B](m: Map[A, B]) { + def getTry(key: A): Try[B] = { + m + .get(key) + .fold[Try[B]](Failure(new NoSuchElementException(key.toString)))(Success(_)) + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala index 143dc8a12304e..7e67b51de6e04 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala @@ -19,10 +19,10 @@ package org.apache.spark.deploy.k8s.features import scala.collection.JavaConverters._ import scala.collection.mutable -import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} +import io.fabric8.kubernetes.api.model._ import org.apache.spark.SparkException -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit._ @@ -103,6 +103,7 @@ private[spark] class BasicDriverFeatureStep( .addToImagePullSecrets(conf.imagePullSecrets(): _*) .endSpec() .build() + SparkPod(driverPod, driverContainer) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index 91c54a9776982..abaeff0313a79 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -18,10 +18,10 @@ package org.apache.spark.deploy.k8s.features import scala.collection.JavaConverters._ -import io.fabric8.kubernetes.api.model.{ContainerBuilder, ContainerPortBuilder, EnvVar, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} +import io.fabric8.kubernetes.api.model._ import org.apache.spark.SparkException -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} @@ -173,6 +173,7 @@ private[spark] class BasicExecutorFeatureStep( .addToImagePullSecrets(kubernetesConf.imagePullSecrets(): _*) .endSpec() .build() + SparkPod(executorPod, containerWithLimitCores) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala new file mode 100644 index 0000000000000..bb0e2b3128efd --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model._ + +import org.apache.spark.deploy.k8s._ + +private[spark] class MountVolumesFeatureStep( + kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) + extends KubernetesFeatureConfigStep { + + override def configurePod(pod: SparkPod): SparkPod = { + val (volumeMounts, volumes) = constructVolumes(kubernetesConf.roleVolumes).unzip + + val podWithVolumes = new PodBuilder(pod.pod) + .editSpec() + .addToVolumes(volumes.toSeq: _*) + .endSpec() + .build() + + val containerWithVolumeMounts = new ContainerBuilder(pod.container) + .addToVolumeMounts(volumeMounts.toSeq: _*) + .build() + + SparkPod(podWithVolumes, containerWithVolumeMounts) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty + + private def constructVolumes( + volumeSpecs: Iterable[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]] + ): Iterable[(VolumeMount, Volume)] = { + volumeSpecs.map { spec => + val volumeMount = new VolumeMountBuilder() + .withMountPath(spec.mountPath) + .withReadOnly(spec.mountReadOnly) + .withName(spec.volumeName) + .build() + + val volumeBuilder = spec.volumeConf match { + case KubernetesHostPathVolumeConf(hostPath) => + new VolumeBuilder() + .withHostPath(new HostPathVolumeSource(hostPath)) + + case KubernetesPVCVolumeConf(claimName) => + new VolumeBuilder() + .withPersistentVolumeClaim( + new PersistentVolumeClaimVolumeSource(claimName, spec.mountReadOnly)) + + case KubernetesEmptyDirVolumeConf(medium, sizeLimit) => + new VolumeBuilder() + .withEmptyDir( + new EmptyDirVolumeSource(medium.getOrElse(""), + new Quantity(sizeLimit.orNull))) + } + + val volume = volumeBuilder.withName(spec.volumeName).build() + + (volumeMount, volume) + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index 0dd1c37661707..7208e3d377593 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} -import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeatureConfigStep, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features._ import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep} private[spark] class KubernetesDriverBuilder( @@ -33,10 +33,13 @@ private[spark] class KubernetesDriverBuilder( new MountSecretsFeatureStep(_), provideEnvSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => EnvSecretsFeatureStep) = - new EnvSecretsFeatureStep(_), - provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] - => LocalDirsFeatureStep) = + new EnvSecretsFeatureStep(_), + provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) + => LocalDirsFeatureStep = new LocalDirsFeatureStep(_), + provideVolumesStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => MountVolumesFeatureStep) = + new MountVolumesFeatureStep(_), provideJavaStep: ( KubernetesConf[KubernetesDriverSpecificConf] => JavaDriverFeatureStep) = @@ -54,11 +57,15 @@ private[spark] class KubernetesDriverBuilder( provideServiceStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) - val maybeRoleSecretNamesStep = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { - Some(provideSecretsStep(kubernetesConf)) } else None - - val maybeProvideSecretsStep = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { - Some(provideEnvSecretsStep(kubernetesConf)) } else None + val secretFeature = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + Seq(provideSecretsStep(kubernetesConf)) + } else Nil + val envSecretFeature = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + Seq(provideEnvSecretsStep(kubernetesConf)) + } else Nil + val volumesFeature = if (kubernetesConf.roleVolumes.nonEmpty) { + Seq(provideVolumesStep(kubernetesConf)) + } else Nil val bindingsStep = kubernetesConf.roleSpecificConf.mainAppResource.map { case JavaMainAppResource(_) => @@ -67,10 +74,8 @@ private[spark] class KubernetesDriverBuilder( providePythonStep(kubernetesConf)} .getOrElse(provideJavaStep(kubernetesConf)) - val allFeatures: Seq[KubernetesFeatureConfigStep] = - (baseFeatures :+ bindingsStep) ++ - maybeRoleSecretNamesStep.toSeq ++ - maybeProvideSecretsStep.toSeq + val allFeatures = (baseFeatures :+ bindingsStep) ++ + secretFeature ++ envSecretFeature ++ volumesFeature var spec = KubernetesDriverSpec.initialSpec(kubernetesConf.sparkConf.getAll.toMap) for (feature <- allFeatures) { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index 769a0a5a63047..364b6fb367722 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -17,37 +17,41 @@ package org.apache.spark.scheduler.cluster.k8s import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, KubernetesFeatureConfigStep, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features._ +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} private[spark] class KubernetesExecutorBuilder( - provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep = + provideBasicStep: (KubernetesConf [KubernetesExecutorSpecificConf]) + => BasicExecutorFeatureStep = new BasicExecutorFeatureStep(_), - provideSecretsStep: - (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => MountSecretsFeatureStep = + provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) + => MountSecretsFeatureStep = new MountSecretsFeatureStep(_), provideEnvSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => EnvSecretsFeatureStep) = new EnvSecretsFeatureStep(_), provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => LocalDirsFeatureStep = - new LocalDirsFeatureStep(_)) { + new LocalDirsFeatureStep(_), + provideVolumesStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => MountVolumesFeatureStep) = + new MountVolumesFeatureStep(_)) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { - val baseFeatures = Seq( - provideBasicStep(kubernetesConf), - provideLocalDirsStep(kubernetesConf)) - val maybeRoleSecretNamesStep = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { - Some(provideSecretsStep(kubernetesConf)) } else None + val baseFeatures = Seq(provideBasicStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) + val secretFeature = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + Seq(provideSecretsStep(kubernetesConf)) + } else Nil + val secretEnvFeature = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + Seq(provideEnvSecretsStep(kubernetesConf)) + } else Nil + val volumesFeature = if (kubernetesConf.roleVolumes.nonEmpty) { + Seq(provideVolumesStep(kubernetesConf)) + } else Nil - val maybeProvideSecretsStep = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { - Some(provideEnvSecretsStep(kubernetesConf)) } else None - - val allFeatures: Seq[KubernetesFeatureConfigStep] = - baseFeatures ++ - maybeRoleSecretNamesStep.toSeq ++ - maybeProvideSecretsStep.toSeq + val allFeatures = baseFeatures ++ secretFeature ++ secretEnvFeature ++ volumesFeature var executorPod = SparkPod.initialPod() for (feature <- allFeatures) { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala new file mode 100644 index 0000000000000..d795d159773a8 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import org.apache.spark.{SparkConf, SparkFunSuite} + +class KubernetesVolumeUtilsSuite extends SparkFunSuite { + test("Parses hostPath volumes correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.hostPath.volumeName.mount.path", "/path") + sparkConf.set("test.hostPath.volumeName.mount.readOnly", "true") + sparkConf.set("test.hostPath.volumeName.options.path", "/hostPath") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesHostPathVolumeConf] === + KubernetesHostPathVolumeConf("/hostPath")) + } + + test("Parses persistentVolumeClaim volumes correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.path", "/path") + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.readOnly", "true") + sparkConf.set("test.persistentVolumeClaim.volumeName.options.claimName", "claimeName") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] === + KubernetesPVCVolumeConf("claimeName")) + } + + test("Parses emptyDir volumes correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mount.path", "/path") + sparkConf.set("test.emptyDir.volumeName.mount.readOnly", "true") + sparkConf.set("test.emptyDir.volumeName.options.medium", "medium") + sparkConf.set("test.emptyDir.volumeName.options.sizeLimit", "5G") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesEmptyDirVolumeConf] === + KubernetesEmptyDirVolumeConf(Some("medium"), Some("5G"))) + } + + test("Parses emptyDir volume options can be optional") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mount.path", "/path") + sparkConf.set("test.emptyDir.volumeName.mount.readOnly", "true") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesEmptyDirVolumeConf] === + KubernetesEmptyDirVolumeConf(None, None)) + } + + test("Defaults optional readOnly to false") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.hostPath.volumeName.mount.path", "/path") + sparkConf.set("test.hostPath.volumeName.options.path", "/hostPath") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.mountReadOnly === false) + } + + test("Gracefully fails on missing mount key") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mnt.path", "/path") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + assert(volumeSpec.isFailure === true) + assert(volumeSpec.failed.get.getMessage === "emptyDir.volumeName.mount.path") + } + + test("Gracefully fails on missing option key") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.hostPath.volumeName.mount.path", "/path") + sparkConf.set("test.hostPath.volumeName.mount.readOnly", "true") + sparkConf.set("test.hostPath.volumeName.options.pth", "/hostPath") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + assert(volumeSpec.isFailure === true) + assert(volumeSpec.failed.get.getMessage === "hostPath.volumeName.options.path") + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala index 04b909db9d9f3..165f46a07df2f 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -50,6 +50,12 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { TEST_IMAGE_PULL_SECRETS.map { secret => new LocalObjectReferenceBuilder().withName(secret).build() } + private val emptyDriverSpecificConf = KubernetesDriverSpecificConf( + None, + APP_NAME, + MAIN_CLASS, + APP_ARGS) + test("Check the pod respects all configurations from the user.") { val sparkConf = new SparkConf() @@ -62,11 +68,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) val kubernetesConf = KubernetesConf( sparkConf, - KubernetesDriverSpecificConf( - Some(JavaMainAppResource("")), - APP_NAME, - MAIN_CLASS, - APP_ARGS), + emptyDriverSpecificConf, RESOURCE_NAME_PREFIX, APP_ID, DRIVER_LABELS, @@ -74,6 +76,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { Map.empty, Map.empty, DRIVER_ENVS, + Nil, Seq.empty[String]) val featureStep = new BasicDriverFeatureStep(kubernetesConf) @@ -143,6 +146,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { Map.empty, Map.empty, DRIVER_ENVS, + Nil, Seq.empty[String]) val pythonKubernetesConf = KubernetesConf( pythonSparkConf, @@ -158,6 +162,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { Map.empty, Map.empty, DRIVER_ENVS, + Nil, Seq.empty[String]) val javaFeatureStep = new BasicDriverFeatureStep(javaKubernetesConf) val pythonFeatureStep = new BasicDriverFeatureStep(pythonKubernetesConf) @@ -176,11 +181,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { .set(CONTAINER_IMAGE, "spark-driver:latest") val kubernetesConf = KubernetesConf( sparkConf, - KubernetesDriverSpecificConf( - Some(JavaMainAppResource("")), - APP_NAME, - MAIN_CLASS, - APP_ARGS), + emptyDriverSpecificConf, RESOURCE_NAME_PREFIX, APP_ID, DRIVER_LABELS, @@ -188,7 +189,9 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { Map.empty, Map.empty, DRIVER_ENVS, + Nil, allFiles) + val step = new BasicDriverFeatureStep(kubernetesConf) val additionalProperties = step.getAdditionalPodSystemProperties() val expectedSparkConf = Map( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index f06030aa55c0c..a44fa1f2ffc63 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -89,6 +89,7 @@ class BasicExecutorFeatureStepSuite Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String])) val executor = step.configurePod(SparkPod.initialPod()) @@ -128,6 +129,7 @@ class BasicExecutorFeatureStepSuite Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String])) assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63) } @@ -148,6 +150,7 @@ class BasicExecutorFeatureStepSuite Map.empty, Map.empty, Map("qux" -> "quux"), + Nil, Seq.empty[String])) val executor = step.configurePod(SparkPod.initialPod()) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala index 7cea83591f3e8..7e916b3854404 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala @@ -61,6 +61,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) @@ -92,6 +93,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) @@ -130,6 +132,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala index 77d38bf19cd10..8b91e93eecd8c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala @@ -67,6 +67,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String])) assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod()) assert(configurationStep.getAdditionalKubernetesResources().size === 1) @@ -98,6 +99,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String])) val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX @@ -119,6 +121,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String])) val resolvedService = configurationStep .getAdditionalKubernetesResources() @@ -149,6 +152,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]), clock) val driverService = configurationStep @@ -176,6 +180,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]), clock) fail("The driver bind address should not be allowed.") @@ -201,6 +206,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]), clock) fail("The driver host address should not be allowed.") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala index af6b35eae484a..1c8d84b76c56b 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala @@ -45,6 +45,7 @@ class EnvSecretsFeatureStepSuite extends SparkFunSuite{ Map.empty, envVarsToKeys, Map.empty, + Nil, Seq.empty[String]) val step = new EnvSecretsFeatureStep(kubernetesConf) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala index bd6ce4b42fc8e..a339827b819a9 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala @@ -21,7 +21,7 @@ import org.mockito.Mockito import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, SparkPod} class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { private val defaultLocalDir = "/var/data/default-local-dir" @@ -45,6 +45,7 @@ class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala index eff75b8a15daa..2b49b72dfa569 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala @@ -43,6 +43,7 @@ class MountSecretsFeatureStepSuite extends SparkFunSuite { secretNamesToMountPaths, Map.empty, Map.empty, + Nil, Seq.empty[String]) val step = new MountSecretsFeatureStep(kubernetesConf) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala new file mode 100644 index 0000000000000..d309aa94ec115 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s._ + +class MountVolumesFeatureStepSuite extends SparkFunSuite { + private val sparkConf = new SparkConf(false) + private val emptyKubernetesConf = KubernetesConf( + sparkConf = sparkConf, + roleSpecificConf = KubernetesDriverSpecificConf( + None, + "app-name", + "main", + Seq.empty), + appResourceNamePrefix = "resource", + appId = "app-id", + roleLabels = Map.empty, + roleAnnotations = Map.empty, + roleSecretNamesToMountPaths = Map.empty, + roleSecretEnvNamesToKeyRefs = Map.empty, + roleEnvs = Map.empty, + roleVolumes = Nil, + sparkFiles = Nil) + + test("Mounts hostPath volumes") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/hostPath/tmp") + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + assert(configuredPod.pod.getSpec.getVolumes.get(0).getHostPath.getPath === "/hostPath/tmp") + assert(configuredPod.container.getVolumeMounts.size() === 1) + assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") + assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") + assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === false) + } + + test("Mounts pesistentVolumeClaims") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + true, + KubernetesPVCVolumeConf("pvcClaim") + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val pvcClaim = configuredPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim + assert(pvcClaim.getClaimName === "pvcClaim") + assert(configuredPod.container.getVolumeMounts.size() === 1) + assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") + assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") + assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === true) + + } + + test("Mounts emptyDir") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + false, + KubernetesEmptyDirVolumeConf(Some("Memory"), Some("6G")) + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val emptyDir = configuredPod.pod.getSpec.getVolumes.get(0).getEmptyDir + assert(emptyDir.getMedium === "Memory") + assert(emptyDir.getSizeLimit.getAmount === "6G") + assert(configuredPod.container.getVolumeMounts.size() === 1) + assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") + assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") + assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === false) + } + + test("Mounts emptyDir with no options") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + false, + KubernetesEmptyDirVolumeConf(None, None) + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val emptyDir = configuredPod.pod.getSpec.getVolumes.get(0).getEmptyDir + assert(emptyDir.getMedium === "") + assert(emptyDir.getSizeLimit.getAmount === null) + assert(configuredPod.container.getVolumeMounts.size() === 1) + assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") + assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") + assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === false) + } + + test("Mounts multiple volumes") { + val hpVolumeConf = KubernetesVolumeSpec( + "hpVolume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/hostPath/tmp") + ) + val pvcVolumeConf = KubernetesVolumeSpec( + "checkpointVolume", + "/checkpoints", + true, + KubernetesPVCVolumeConf("pvcClaim") + ) + val volumesConf = hpVolumeConf :: pvcVolumeConf :: Nil + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumesConf) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 2) + assert(configuredPod.container.getVolumeMounts.size() === 2) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala index 0f2bf2fa1d9b5..18874afe6e53a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala @@ -42,6 +42,7 @@ class JavaDriverFeatureStepSuite extends SparkFunSuite { roleSecretNamesToMountPaths = Map.empty, roleSecretEnvNamesToKeyRefs = Map.empty, roleEnvs = Map.empty, + roleVolumes = Nil, sparkFiles = Seq.empty[String]) val step = new JavaDriverFeatureStep(kubernetesConf) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala index a1f9a5d9e264e..a5dac6869327d 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala @@ -52,6 +52,7 @@ class PythonDriverFeatureStepSuite extends SparkFunSuite { roleSecretNamesToMountPaths = Map.empty, roleSecretEnvNamesToKeyRefs = Map.empty, roleEnvs = Map.empty, + roleVolumes = Nil, sparkFiles = Seq.empty[String]) val step = new PythonDriverFeatureStep(kubernetesConf) @@ -88,6 +89,7 @@ class PythonDriverFeatureStepSuite extends SparkFunSuite { roleSecretNamesToMountPaths = Map.empty, roleSecretEnvNamesToKeyRefs = Map.empty, roleEnvs = Map.empty, + roleVolumes = Nil, sparkFiles = Seq.empty[String]) val step = new PythonDriverFeatureStep(kubernetesConf) val driverContainerwithPySpark = step.configurePod(baseDriverPod).container diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index d045d9ae89c07..4d8e79189ff32 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -141,6 +141,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) when(driverBuilder.buildFromFeatures(kubernetesConf)).thenReturn(BUILT_KUBERNETES_SPEC) when(kubernetesClient.pods()).thenReturn(podOperations) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index 4e8c300543430..046e578b94629 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf} +import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.features._ import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep} @@ -31,6 +32,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val JAVA_STEP_TYPE = "java-bindings" private val PYSPARK_STEP_TYPE = "pyspark-bindings" private val ENV_SECRETS_STEP_TYPE = "env-secrets" + private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicDriverFeatureStep]) @@ -56,6 +58,9 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) + private val mountVolumesStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + MOUNT_VOLUMES_STEP_TYPE, classOf[MountVolumesFeatureStep]) + private val builderUnderTest: KubernetesDriverBuilder = new KubernetesDriverBuilder( _ => basicFeatureStep, @@ -64,6 +69,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { _ => secretsStep, _ => envSecretsStep, _ => localDirsStep, + _ => mountVolumesStep, _ => javaStep, _ => pythonStep) @@ -82,6 +88,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -107,6 +114,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map("secret" -> "secretMountPath"), Map("EnvName" -> "SecretName:secretKey"), Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -134,6 +142,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -159,6 +168,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -169,6 +179,39 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { PYSPARK_STEP_TYPE) } + test("Apply volumes step if mounts are present.") { + val volumeSpec = KubernetesVolumeSpec( + "volume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/path")) + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + None, + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + volumeSpec :: Nil, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + MOUNT_VOLUMES_STEP_TYPE, + JAVA_STEP_TYPE) + } + + private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) : Unit = { assert(resolvedSpec.systemProperties.size === stepTypes.size) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index a6bc8bce32926..d0b4127065eb7 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -19,14 +19,15 @@ package org.apache.spark.scheduler.cluster.k8s import io.fabric8.kubernetes.api.model.PodBuilder import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.features._ class KubernetesExecutorBuilderSuite extends SparkFunSuite { private val BASIC_STEP_TYPE = "basic" private val SECRETS_STEP_TYPE = "mount-secrets" private val ENV_SECRETS_STEP_TYPE = "env-secrets" private val LOCAL_DIRS_STEP_TYPE = "local-dirs" + private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep]) @@ -36,12 +37,15 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep]) + private val mountVolumesStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + MOUNT_VOLUMES_STEP_TYPE, classOf[MountVolumesFeatureStep]) private val builderUnderTest = new KubernetesExecutorBuilder( _ => basicFeatureStep, _ => mountSecretsStep, _ => envSecretsStep, - _ => localDirsStep) + _ => localDirsStep, + _ => mountVolumesStep) test("Basic steps are consistently applied.") { val conf = KubernetesConf( @@ -55,6 +59,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE) @@ -72,6 +77,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map("secret" -> "secretMountPath"), Map("secret-name" -> "secret-key"), Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -81,6 +87,32 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { ENV_SECRETS_STEP_TYPE) } + test("Apply volumes step if mounts are present.") { + val volumeSpec = KubernetesVolumeSpec( + "volume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/checkpoint")) + val conf = KubernetesConf( + new SparkConf(false), + KubernetesExecutorSpecificConf( + "executor-id", new PodBuilder().build()), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + volumeSpec :: Nil, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + MOUNT_VOLUMES_STEP_TYPE) + } + private def validateStepTypesApplied(resolvedPod: SparkPod, stepTypes: String*): Unit = { assert(resolvedPod.pod.getMetadata.getLabels.size === stepTypes.size) stepTypes.foreach { stepType => From 006e798e477b6871ad3ba4417d354d23f45e4013 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 10 Jul 2018 23:18:07 -0700 Subject: [PATCH 1095/2461] [SPARK-23461][R] vignettes should include model predictions for some ML models ## What changes were proposed in this pull request? Add model predictions for Linear Support Vector Machine (SVM) Classifier, Logistic Regression, GBT, RF and DecisionTree in vignettes. ## How was this patch tested? Manually ran the test and checked the result. Author: Huaxin Gao Closes #21678 from huaxingao/spark-23461. --- R/pkg/vignettes/sparkr-vignettes.Rmd | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index d4713de7806a1..68a18ab57b28d 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -590,6 +590,7 @@ summary(model) Predict values on training data ```{r} prediction <- predict(model, training) +head(select(prediction, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Logistic Regression @@ -613,6 +614,7 @@ summary(model) Predict values on training data ```{r} fitted <- predict(model, training) +head(select(fitted, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` Multinomial logistic regression against three classes @@ -807,6 +809,7 @@ df <- createDataFrame(t) dtModel <- spark.decisionTree(df, Survived ~ ., type = "classification", maxDepth = 2) summary(dtModel) predictions <- predict(dtModel, df) +head(select(predictions, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Gradient-Boosted Trees @@ -822,6 +825,7 @@ df <- createDataFrame(t) gbtModel <- spark.gbt(df, Survived ~ ., type = "classification", maxDepth = 2, maxIter = 2) summary(gbtModel) predictions <- predict(gbtModel, df) +head(select(predictions, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Random Forest @@ -837,6 +841,7 @@ df <- createDataFrame(t) rfModel <- spark.randomForest(df, Survived ~ ., type = "classification", maxDepth = 2, numTrees = 2) summary(rfModel) predictions <- predict(rfModel, df) +head(select(predictions, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Bisecting k-Means From 592cc84583d74c78e4cdf34a3b82692c8de8f4a9 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 11 Jul 2018 23:43:06 +0800 Subject: [PATCH 1096/2461] [SPARK-24562][TESTS] Support different configs for same test in SQLQueryTestSuite ## What changes were proposed in this pull request? The PR proposes to add support for running the same SQL test input files against different configs leading to the same result. ## How was this patch tested? Involved UTs Author: Marco Gaido Closes #21568 from mgaido91/SPARK-24562. --- .../sql-tests/inputs/join-empty-relation.sql | 5 ++ .../sql-tests/inputs/natural-join.sql | 5 ++ .../resources/sql-tests/inputs/outer-join.sql | 5 ++ .../exists-joins-and-set-ops.sql | 4 ++ .../inputs/subquery/in-subquery/in-joins.sql | 4 ++ .../subquery/in-subquery/not-in-joins.sql | 4 ++ .../apache/spark/sql/SQLQueryTestSuite.scala | 53 ++++++++++++++++--- 7 files changed, 74 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql b/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql index 8afa3270f4de4..2e6a5f362a8fa 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql @@ -1,3 +1,8 @@ +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false + CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1) AS GROUPING(a); CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1) AS GROUPING(a); diff --git a/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql b/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql index 71a50157b766c..e0abeda3eb44f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql @@ -1,3 +1,8 @@ +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false + create temporary view nt1 as select * from values ("one", 1), ("two", 2), diff --git a/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql index cdc6c81e10047..ce09c21568f13 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql @@ -1,3 +1,8 @@ +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false + -- SPARK-17099: Incorrect result when HAVING clause is added to group by query CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (-234), (145), (367), (975), (298) diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql index cc4ed64affec7..cefc3fe6272ab 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql @@ -1,5 +1,9 @@ -- Tests EXISTS subquery support. Tests Exists subquery -- used in Joins (Both when joins occurs in outer and suquery blocks) +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false CREATE TEMPORARY VIEW EMP AS SELECT * FROM VALUES (100, "emp 1", date "2005-01-01", 100.00D, 10), diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql index 880175fd7add0..22f3eafd6a02d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql @@ -1,5 +1,9 @@ -- A test suite for IN JOINS in parent side, subquery, and both predicate subquery -- It includes correlated cases. +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql index e09b91f18de0a..4f8ca8bfb27c1 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql @@ -1,5 +1,9 @@ -- A test suite for not-in-joins in parent side, subquery, and both predicate subquery -- It includes correlated cases. +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index beac9699585d5..826408c7161e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -54,6 +54,7 @@ import org.apache.spark.sql.types.StructType * The format for input files is simple: * 1. A list of SQL queries separated by semicolon. * 2. Lines starting with -- are treated as comments and ignored. + * 3. Lines starting with --SET are used to run the file with the following set of configs. * * For example: * {{{ @@ -138,18 +139,58 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { private def runTest(testCase: TestCase): Unit = { val input = fileToString(new File(testCase.inputFile)) + val (comments, code) = input.split("\n").partition(_.startsWith("--")) + val configSets = { + val configLines = comments.filter(_.startsWith("--SET")).map(_.substring(5)) + val configs = configLines.map(_.split(",").map { confAndValue => + val (conf, value) = confAndValue.span(_ != '=') + conf.trim -> value.substring(1).trim + }) + // When we are regenerating the golden files we don't need to run all the configs as they + // all need to return the same result + if (regenerateGoldenFiles && configs.nonEmpty) { + configs.take(1) + } else { + configs + } + } // List of SQL queries to run - val queries: Seq[String] = { - val cleaned = input.split("\n").filterNot(_.startsWith("--")).mkString("\n") - // note: this is not a robust way to split queries using semicolon, but works for now. - cleaned.split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq + // note: this is not a robust way to split queries using semicolon, but works for now. + val queries = code.mkString("\n").split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq + + if (configSets.isEmpty) { + runQueries(queries, testCase.resultFile, None) + } else { + configSets.foreach { configSet => + try { + runQueries(queries, testCase.resultFile, Some(configSet)) + } catch { + case e: Throwable => + val configs = configSet.map { + case (k, v) => s"$k=$v" + } + logError(s"Error using configs: ${configs.mkString(",")}") + throw e + } + } } + } + private def runQueries( + queries: Seq[String], + resultFileName: String, + configSet: Option[Seq[(String, String)]]): Unit = { // Create a local SparkSession to have stronger isolation between different test cases. // This does not isolate catalog changes. val localSparkSession = spark.newSession() loadTestData(localSparkSession) + if (configSet.isDefined) { + // Execute the list of set operation in order to add the desired configs + val setOperations = configSet.get.map { case (key, value) => s"set $key=$value" } + logInfo(s"Setting configs: ${setOperations.mkString(", ")}") + setOperations.foreach(localSparkSession.sql) + } // Run the SQL queries preparing them for comparison. val outputs: Seq[QueryOutput] = queries.map { sql => val (schema, output) = getNormalizedResult(localSparkSession, sql) @@ -167,7 +208,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { s"-- Number of queries: ${outputs.size}\n\n\n" + outputs.zipWithIndex.map{case (qr, i) => qr.toString(i)}.mkString("\n\n\n") + "\n" } - val resultFile = new File(testCase.resultFile) + val resultFile = new File(resultFileName) val parent = resultFile.getParentFile if (!parent.exists()) { assert(parent.mkdirs(), "Could not create directory: " + parent) @@ -177,7 +218,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { // Read back the golden file. val expectedOutputs: Seq[QueryOutput] = { - val goldenOutput = fileToString(new File(testCase.resultFile)) + val goldenOutput = fileToString(new File(resultFileName)) val segments = goldenOutput.split("-- !query.+\n") // each query has 3 segments, plus the header From ebf4bfb966389342bfd9bdb8e3b612828c18730c Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 11 Jul 2018 09:29:19 -0700 Subject: [PATCH 1097/2461] [SPARK-24208][SQL] Fix attribute deduplication for FlatMapGroupsInPandas ## What changes were proposed in this pull request? A self-join on a dataset which contains a `FlatMapGroupsInPandas` fails because of duplicate attributes. This happens because we are not dealing with this specific case in our `dedupAttr` rules. The PR fix the issue by adding the management of the specific case ## How was this patch tested? added UT + manual tests Author: Marco Gaido Author: Marco Gaido Closes #21737 from mgaido91/SPARK-24208. --- python/pyspark/sql/tests.py | 16 ++++++++++++++++ .../spark/sql/catalyst/analysis/Analyzer.scala | 4 ++++ .../apache/spark/sql/GroupedDatasetSuite.scala | 12 ++++++++++++ 3 files changed, 32 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 8d738069adb3d..4404dbe40590a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -5925,6 +5925,22 @@ def test_invalid_args(self): 'mixture.*aggregate function.*group aggregate pandas UDF'): df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() + def test_self_join_with_pandas(self): + import pyspark.sql.functions as F + + @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP) + def dummy_pandas_udf(df): + return df[['key', 'col']] + + df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'), + Row(key=2, col='C')]) + dfWithPandas = df.groupBy('key').apply(dummy_pandas_udf) + + # this was throwing an AnalysisException before SPARK-24208 + res = dfWithPandas.alias('temp0').join(dfWithPandas.alias('temp1'), + F.col('temp0.key') == F.col('temp1.key')) + self.assertEquals(res.count(), 5) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e187133d03b17..c078efdfc0000 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -738,6 +738,10 @@ class Analyzer( if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) + case oldVersion @ FlatMapGroupsInPandas(_, _, output, _) + if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => + (oldVersion, oldVersion.copy(output = output.map(_.newInstance()))) + case oldVersion: Generate if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty => val newOutput = oldVersion.generatorOutput.map(_.newInstance()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala index 147c0b61f5017..bd54ea415ca88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala @@ -93,4 +93,16 @@ class GroupedDatasetSuite extends QueryTest with SharedSQLContext { } datasetWithUDF.unpersist(true) } + + test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") { + val df = datasetWithUDF.groupBy("s").flatMapGroupsInPandas(PythonUDF( + "pyUDF", + null, + StructType(Seq(StructField("s", LongType))), + Seq.empty, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + true)) + val df1 = df.alias("temp0").join(df.alias("temp1"), $"temp0.s" === $"temp1.s") + df1.queryExecution.assertAnalyzed() + } } From 290c30a53fc2f46001846ab8abafcc69b853ba98 Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Wed, 11 Jul 2018 13:48:28 -0500 Subject: [PATCH 1098/2461] [SPARK-24470][CORE] RestSubmissionClient to be robust against 404 & non json responses ## What changes were proposed in this pull request? Added check for 404, to avoid json parsing on not found response and to avoid returning malformed or bad request when it was a not found http response. Not sure if I need to add an additional check on non json response [if(connection.getHeaderField("Content-Type").contains("text/html")) then exception] as non-json is a subset of malformed json and covered in flow. ## How was this patch tested? ./dev/run-tests Author: Rekha Joshi Closes #21684 from rekhajoshm/SPARK-24470. --- .../deploy/rest/RestSubmissionClient.scala | 60 ++++++++++++------- 1 file changed, 37 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 742a95841a138..31a8e3e60c067 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -233,30 +233,44 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { private[rest] def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = { import scala.concurrent.ExecutionContext.Implicits.global val responseFuture = Future { - val dataStream = - if (connection.getResponseCode == HttpServletResponse.SC_OK) { - connection.getInputStream - } else { - connection.getErrorStream + val responseCode = connection.getResponseCode + + if (responseCode != HttpServletResponse.SC_OK) { + val errString = Some(Source.fromInputStream(connection.getErrorStream()) + .getLines().mkString("\n")) + if (responseCode == HttpServletResponse.SC_INTERNAL_SERVER_ERROR && + !connection.getContentType().contains("application/json")) { + throw new SubmitRestProtocolException(s"Server responded with exception:\n${errString}") + } + logError(s"Server responded with error:\n${errString}") + val error = new ErrorResponse + if (responseCode == RestSubmissionServer.SC_UNKNOWN_PROTOCOL_VERSION) { + error.highestProtocolVersion = RestSubmissionServer.PROTOCOL_VERSION + } + error.message = errString.get + error + } else { + val dataStream = connection.getInputStream + + // If the server threw an exception while writing a response, it will not have a body + if (dataStream == null) { + throw new SubmitRestProtocolException("Server returned empty body") + } + val responseJson = Source.fromInputStream(dataStream).mkString + logDebug(s"Response from the server:\n$responseJson") + val response = SubmitRestProtocolMessage.fromJson(responseJson) + response.validate() + response match { + // If the response is an error, log the message + case error: ErrorResponse => + logError(s"Server responded with error:\n${error.message}") + error + // Otherwise, simply return the response + case response: SubmitRestProtocolResponse => response + case unexpected => + throw new SubmitRestProtocolException( + s"Message received from server was not a response:\n${unexpected.toJson}") } - // If the server threw an exception while writing a response, it will not have a body - if (dataStream == null) { - throw new SubmitRestProtocolException("Server returned empty body") - } - val responseJson = Source.fromInputStream(dataStream).mkString - logDebug(s"Response from the server:\n$responseJson") - val response = SubmitRestProtocolMessage.fromJson(responseJson) - response.validate() - response match { - // If the response is an error, log the message - case error: ErrorResponse => - logError(s"Server responded with error:\n${error.message}") - error - // Otherwise, simply return the response - case response: SubmitRestProtocolResponse => response - case unexpected => - throw new SubmitRestProtocolException( - s"Message received from server was not a response:\n${unexpected.toJson}") } } From 59c3c233f4366809b6b4db39b3d32c194c98d5ab Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 11 Jul 2018 13:56:09 -0500 Subject: [PATCH 1099/2461] [SPARK-23254][ML] Add user guide entry and example for DataFrame multivariate summary ## What changes were proposed in this pull request? Add user guide and scala/java/python examples for `ml.stat.Summarizer` ## How was this patch tested? Doc generated snapshot: ![image](https://user-images.githubusercontent.com/19235986/38987108-45646044-4401-11e8-9ba8-ae94ba96cbf9.png) ![image](https://user-images.githubusercontent.com/19235986/38987096-36dcc73c-4401-11e8-87f9-5b91e7f9e27b.png) ![image](https://user-images.githubusercontent.com/19235986/38987088-2d1c1eaa-4401-11e8-80b5-8c40d529a120.png) ![image](https://user-images.githubusercontent.com/19235986/38987077-22ce8be0-4401-11e8-8199-c3a4d8d23201.png) Author: WeichenXu Closes #20446 from WeichenXu123/summ_guide. --- docs/ml-statistics.md | 28 ++++++++ .../examples/ml/JavaSummarizerExample.java | 71 +++++++++++++++++++ .../src/main/python/ml/summarizer_example.py | 59 +++++++++++++++ .../spark/examples/ml/SummarizerExample.scala | 61 ++++++++++++++++ 4 files changed, 219 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaSummarizerExample.java create mode 100644 examples/src/main/python/ml/summarizer_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/SummarizerExample.scala diff --git a/docs/ml-statistics.md b/docs/ml-statistics.md index abfb3cab1e566..6c82b3bb94b24 100644 --- a/docs/ml-statistics.md +++ b/docs/ml-statistics.md @@ -89,4 +89,32 @@ Refer to the [`ChiSquareTest` Python docs](api/python/index.html#pyspark.ml.stat {% include_example python/ml/chi_square_test_example.py %} + + +## Summarizer + +We provide vector column summary statistics for `Dataframe` through `Summarizer`. +Available metrics are the column-wise max, min, mean, variance, and number of nonzeros, as well as the total count. + +
      +
      +The following example demonstrates using [`Summarizer`](api/scala/index.html#org.apache.spark.ml.stat.Summarizer$) +to compute the mean and variance for a vector column of the input dataframe, with and without a weight column. + +{% include_example scala/org/apache/spark/examples/ml/SummarizerExample.scala %} +
      + +
      +The following example demonstrates using [`Summarizer`](api/java/org/apache/spark/ml/stat/Summarizer.html) +to compute the mean and variance for a vector column of the input dataframe, with and without a weight column. + +{% include_example java/org/apache/spark/examples/ml/JavaSummarizerExample.java %} +
      + +
      +Refer to the [`Summarizer` Python docs](api/python/index.html#pyspark.ml.stat.Summarizer$) for details on the API. + +{% include_example python/ml/summarizer_example.py %} +
      +
      \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSummarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSummarizerExample.java new file mode 100644 index 0000000000000..e9b84365d86ed --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSummarizerExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.sql.*; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.linalg.Vector; +import org.apache.spark.ml.linalg.Vectors; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.stat.Summarizer; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaSummarizerExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaSummarizerExample") + .getOrCreate(); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(Vectors.dense(2.0, 3.0, 5.0), 1.0), + RowFactory.create(Vectors.dense(4.0, 6.0, 7.0), 2.0) + ); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + new StructField("weight", DataTypes.DoubleType, false, Metadata.empty()) + }); + + Dataset df = spark.createDataFrame(data, schema); + + Row result1 = df.select(Summarizer.metrics("mean", "variance") + .summary(new Column("features"), new Column("weight")).as("summary")) + .select("summary.mean", "summary.variance").first(); + System.out.println("with weight: mean = " + result1.getAs(0).toString() + + ", variance = " + result1.getAs(1).toString()); + + Row result2 = df.select( + Summarizer.mean(new Column("features")), + Summarizer.variance(new Column("features")) + ).first(); + System.out.println("without weight: mean = " + result2.getAs(0).toString() + + ", variance = " + result2.getAs(1).toString()); + // $example off$ + spark.stop(); + } +} diff --git a/examples/src/main/python/ml/summarizer_example.py b/examples/src/main/python/ml/summarizer_example.py new file mode 100644 index 0000000000000..8835f189a1ad4 --- /dev/null +++ b/examples/src/main/python/ml/summarizer_example.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +An example for summarizer. +Run with: + bin/spark-submit examples/src/main/python/ml/summarizer_example.py +""" +from __future__ import print_function + +from pyspark.sql import SparkSession +# $example on$ +from pyspark.ml.stat import Summarizer +from pyspark.sql import Row +from pyspark.ml.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + spark = SparkSession \ + .builder \ + .appName("SummarizerExample") \ + .getOrCreate() + sc = spark.sparkContext + + # $example on$ + df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)), + Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF() + + # create summarizer for multiple metrics "mean" and "count" + summarizer = Summarizer.metrics("mean", "count") + + # compute statistics for multiple metrics with weight + df.select(summarizer.summary(df.features, df.weight)).show(truncate=False) + + # compute statistics for multiple metrics without weight + df.select(summarizer.summary(df.features)).show(truncate=False) + + # compute statistics for single metric "mean" with weight + df.select(Summarizer.mean(df.features, df.weight)).show(truncate=False) + + # compute statistics for single metric "mean" without weight + df.select(Summarizer.mean(df.features)).show(truncate=False) + # $example off$ + + spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SummarizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SummarizerExample.scala new file mode 100644 index 0000000000000..2f54d1d81bc48 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SummarizerExample.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.stat.Summarizer +// $example off$ +import org.apache.spark.sql.SparkSession + +object SummarizerExample { + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName("SummarizerExample") + .getOrCreate() + + import spark.implicits._ + import Summarizer._ + + // $example on$ + val data = Seq( + (Vectors.dense(2.0, 3.0, 5.0), 1.0), + (Vectors.dense(4.0, 6.0, 7.0), 2.0) + ) + + val df = data.toDF("features", "weight") + + val (meanVal, varianceVal) = df.select(metrics("mean", "variance") + .summary($"features", $"weight").as("summary")) + .select("summary.mean", "summary.variance") + .as[(Vector, Vector)].first() + + println(s"with weight: mean = ${meanVal}, variance = ${varianceVal}") + + val (meanVal2, varianceVal2) = df.select(mean($"features"), variance($"features")) + .as[(Vector, Vector)].first() + + println(s"without weight: mean = ${meanVal2}, sum = ${varianceVal2}") + // $example off$ + + spark.stop() + } +} +// scalastyle:on println From ff7f6ef75c80633480802d537e66432e3bea4785 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 11 Jul 2018 12:44:42 -0700 Subject: [PATCH 1100/2461] [SPARK-24697][SS] Fix the reported start offsets in streaming query progress ## What changes were proposed in this pull request? In ProgressReporter for streams, we use the `committedOffsets` as the startOffset and `availableOffsets` as the end offset when reporting the status of a trigger in `finishTrigger`. This is a bad pattern that has existed since the beginning of ProgressReporter and it is bad because its super hard to reason about when `availableOffsets` and `committedOffsets` are updated, and when they are recorded. Case in point, this bug silently existed in ContinuousExecution, since before MicroBatchExecution was refactored. The correct fix it to record the offsets explicitly. This PR adds a simple method which is explicitly called from MicroBatch/ContinuousExecition before updating the `committedOffsets`. ## How was this patch tested? Added new tests Author: Tathagata Das Closes #21744 from tdas/SPARK-24697. --- .../streaming/MicroBatchExecution.scala | 3 +++ .../streaming/ProgressReporter.scala | 21 +++++++++++++++---- .../continuous/ContinuousExecution.scala | 3 +++ .../sql/streaming/StreamingQuerySuite.scala | 6 ++++-- 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 16651dd060d73..45c43f549d24f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -184,6 +184,9 @@ class MicroBatchExecution( isCurrentBatchConstructed = constructNextBatch(noDataBatchesEnabled) } + // Record the trigger offset range for progress reporting *before* processing the batch + recordTriggerOffsets(from = committedOffsets, to = availableOffsets) + // Remember whether the current batch has data or not. This will be required later // for bookkeeping after running the batch, when `isNewDataAvailable` will have changed // to false as the batch would have already processed the available data. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 16ad3ef9a3d4a..47f4b52e6e34c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -56,8 +56,6 @@ trait ProgressReporter extends Logging { protected def logicalPlan: LogicalPlan protected def lastExecution: QueryExecution protected def newData: Map[BaseStreamingSource, LogicalPlan] - protected def availableOffsets: StreamProgress - protected def committedOffsets: StreamProgress protected def sources: Seq[BaseStreamingSource] protected def sink: BaseStreamingSink protected def offsetSeqMetadata: OffsetSeqMetadata @@ -68,8 +66,11 @@ trait ProgressReporter extends Logging { // Local timestamps and counters. private var currentTriggerStartTimestamp = -1L private var currentTriggerEndTimestamp = -1L + private var currentTriggerStartOffsets: Map[BaseStreamingSource, String] = _ + private var currentTriggerEndOffsets: Map[BaseStreamingSource, String] = _ // TODO: Restore this from the checkpoint when possible. private var lastTriggerStartTimestamp = -1L + private val currentDurationsMs = new mutable.HashMap[String, Long]() /** Flag that signals whether any error with input metrics have already been logged */ @@ -114,9 +115,20 @@ trait ProgressReporter extends Logging { lastTriggerStartTimestamp = currentTriggerStartTimestamp currentTriggerStartTimestamp = triggerClock.getTimeMillis() currentStatus = currentStatus.copy(isTriggerActive = true) + currentTriggerStartOffsets = null + currentTriggerEndOffsets = null currentDurationsMs.clear() } + /** + * Record the offsets range this trigger will process. Call this before updating + * `committedOffsets` in `StreamExecution` to make sure that the correct range is recorded. + */ + protected def recordTriggerOffsets(from: StreamProgress, to: StreamProgress): Unit = { + currentTriggerStartOffsets = from.mapValues(_.json) + currentTriggerEndOffsets = to.mapValues(_.json) + } + private def updateProgress(newProgress: StreamingQueryProgress): Unit = { progressBuffer.synchronized { progressBuffer += newProgress @@ -130,6 +142,7 @@ trait ProgressReporter extends Logging { /** Finalizes the query progress and adds it to list of recent status updates. */ protected def finishTrigger(hasNewData: Boolean): Unit = { + assert(currentTriggerStartOffsets != null && currentTriggerEndOffsets != null) currentTriggerEndTimestamp = triggerClock.getTimeMillis() val executionStats = extractExecutionStats(hasNewData) @@ -147,8 +160,8 @@ trait ProgressReporter extends Logging { val numRecords = executionStats.inputRows.getOrElse(source, 0L) new SourceProgress( description = source.toString, - startOffset = committedOffsets.get(source).map(_.json).orNull, - endOffset = availableOffsets.get(source).map(_.json).orNull, + startOffset = currentTriggerStartOffsets.get(source).orNull, + endOffset = currentTriggerEndOffsets.get(source).orNull, numInputRows = numRecords, inputRowsPerSecond = numRecords / inputTimeSec, processedRowsPerSecond = numRecords / processingTimeSec diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index a0bb8292d7766..e991dbc81696d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -309,7 +309,10 @@ class ContinuousExecution( def commit(epoch: Long): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") assert(offsetLog.get(epoch).isDefined, s"offset for epoch $epoch not reported before commit") + synchronized { + // Record offsets before updating `committedOffsets` + recordTriggerOffsets(from = committedOffsets, to = availableOffsets) if (queryExecutionThread.isAlive) { commitLog.add(epoch) val offset = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index dcf6cb5d609ee..936a076d647b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -335,8 +335,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.sources.length === 1) assert(progress.sources(0).description contains "MemoryStream") - assert(progress.sources(0).startOffset === "0") - assert(progress.sources(0).endOffset !== null) + assert(progress.sources(0).startOffset === null) // no prior offset + assert(progress.sources(0).endOffset === "0") assert(progress.sources(0).processedRowsPerSecond === 4.0) // 2 rows processed in 500 ms assert(progress.stateOperators.length === 1) @@ -362,6 +362,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(query.lastProgress.batchId === 1) assert(query.lastProgress.inputRowsPerSecond === 2.0) assert(query.lastProgress.sources(0).inputRowsPerSecond === 2.0) + assert(query.lastProgress.sources(0).startOffset === "0") + assert(query.lastProgress.sources(0).endOffset === "1") true }, From e008ad175256a3192fdcbd2c4793044d52f46d57 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 11 Jul 2018 17:30:43 -0700 Subject: [PATCH 1101/2461] [SPARK-24782][SQL] Simplify conf retrieval in SQL expressions ## What changes were proposed in this pull request? The PR simplifies the retrieval of config in `size`, as we can access them from tasks too thanks to SPARK-24250. ## How was this patch tested? existing UTs Author: Marco Gaido Closes #21736 from mgaido91/SPARK-24605_followup. --- .../expressions/collectionOperations.scala | 10 +-- .../expressions/jsonExpressions.scala | 16 ++--- .../spark/sql/catalyst/plans/QueryPlan.scala | 2 - .../CollectionExpressionsSuite.scala | 27 ++++---- .../expressions/JsonExpressionsSuite.scala | 65 +++++++++---------- .../org/apache/spark/sql/functions.scala | 4 +- 6 files changed, 57 insertions(+), 67 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 879603b66b314..e217d37184511 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -89,15 +89,9 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression > SELECT _FUNC_(NULL); -1 """) -case class Size( - child: Expression, - legacySizeOfNull: Boolean) - extends UnaryExpression with ExpectsInputTypes { +case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes { - def this(child: Expression) = - this( - child, - legacySizeOfNull = SQLConf.get.getConf(SQLConf.LEGACY_SIZE_OF_NULL)) + val legacySizeOfNull = SQLConf.get.legacySizeOfNull override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 8cd86053a01c7..63943b1a4351a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -514,10 +514,11 @@ case class JsonToStructs( schema: DataType, options: Map[String, String], child: Expression, - timeZoneId: Option[String], - forceNullableSchema: Boolean) + timeZoneId: Option[String] = None) extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + val forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA) + // The JSON input data might be missing certain fields. We force the nullability // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder // can generate incorrect files if values are missing in columns declared as non-nullable. @@ -531,8 +532,7 @@ case class JsonToStructs( schema = JsonExprUtils.evalSchemaExpr(schema), options = options, child = child, - timeZoneId = None, - forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) + timeZoneId = None) def this(child: Expression, schema: Expression) = this(child, schema, Map.empty[String, String]) @@ -541,13 +541,7 @@ case class JsonToStructs( schema = JsonExprUtils.evalSchemaExpr(schema), options = JsonExprUtils.convertToMapData(options), child = child, - timeZoneId = None, - forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) - - // Used in `org.apache.spark.sql.functions` - def this(schema: DataType, options: Map[String, String], child: Expression) = - this(schema, options, child, timeZoneId = None, - forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) + timeZoneId = None) override def checkInputDataTypes(): TypeCheckResult = nullableSchema match { case _: StructType | ArrayType(_: StructType, _) | _: MapType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index e431c9523a9da..4b4722b0b2117 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -27,8 +27,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT /** * The active config object within the current scope. - * Note that if you want to refer config values during execution, you have to capture them - * in Driver and use the captured values in Executors. * See [[SQLConf.get]] for more information. */ def conf: SQLConf = SQLConf.get diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 173c98af323b1..a838a2eedc03c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -24,43 +24,48 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeTestUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH import org.apache.spark.unsafe.types.CalendarInterval class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { - def testSize(legacySizeOfNull: Boolean, sizeOfNull: Any): Unit = { + def testSize(sizeOfNull: Any): Unit = { val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType)) - checkEvaluation(Size(a0, legacySizeOfNull), 3) - checkEvaluation(Size(a1, legacySizeOfNull), 0) - checkEvaluation(Size(a2, legacySizeOfNull), 2) + checkEvaluation(Size(a0), 3) + checkEvaluation(Size(a1), 0) + checkEvaluation(Size(a2), 2) val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType)) val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType)) - checkEvaluation(Size(m0, legacySizeOfNull), 2) - checkEvaluation(Size(m1, legacySizeOfNull), 0) - checkEvaluation(Size(m2, legacySizeOfNull), 1) + checkEvaluation(Size(m0), 2) + checkEvaluation(Size(m1), 0) + checkEvaluation(Size(m2), 1) checkEvaluation( - Size(Literal.create(null, MapType(StringType, StringType)), legacySizeOfNull), + Size(Literal.create(null, MapType(StringType, StringType))), expected = sizeOfNull) checkEvaluation( - Size(Literal.create(null, ArrayType(StringType)), legacySizeOfNull), + Size(Literal.create(null, ArrayType(StringType))), expected = sizeOfNull) } test("Array and Map Size - legacy") { - testSize(legacySizeOfNull = true, sizeOfNull = -1) + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { + testSize(sizeOfNull = -1) + } } test("Array and Map Size") { - testSize(legacySizeOfNull = false, sizeOfNull = null) + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "false") { + testSize(sizeOfNull = null) + } } test("MapKeys/MapValues") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 52203b9e337ba..04f1c8ce0b83d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -392,7 +392,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val jsonData = """{"a": 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), InternalRow(1) ) } @@ -401,13 +401,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val jsonData = """{"a" 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), null ) // Other modes should still return `null`. checkEvaluation( - JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId, true), + JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId), null ) } @@ -416,62 +416,62 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val input = """[{"a": 1}, {"a": 2}]""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: InternalRow(2) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=object, schema=array, output=array of single row") { val input = """{"a": 1}""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty array, schema=array, output=empty array") { val input = "[ ]" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty object, schema=array, output=array of single row with null") { val input = "{ }" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(null) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=array of single object, schema=struct, output=single row") { val input = """[{"a": 1}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(1) - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=array, schema=struct, output=null") { val input = """[{"a": 1}, {"a": 2}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty array, schema=struct, output=null") { val input = """[]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty object, schema=struct, output=single row with null") { val input = """{ }""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(null) - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json null input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId, true), + JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId), null ) } @@ -479,7 +479,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("SPARK-20549: from_json bad UTF-8") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(badJson), gmtId, true), + JsonToStructs(schema, Map.empty, Literal(badJson), gmtId), null) } @@ -491,14 +491,14 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with c.set(2016, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 123) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId, true), + JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId), InternalRow(c.getTimeInMillis * 1000L) ) // The result doesn't change because the json string includes timezone string ("Z" here), // which means the string represents the timestamp string in the timezone regardless of // the timeZoneId parameter. checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST"), true), + JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST")), InternalRow(c.getTimeInMillis * 1000L) ) @@ -512,8 +512,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with schema, Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"), Literal(jsonData2), - Option(tz.getID), - true), + Option(tz.getID)), InternalRow(c.getTimeInMillis * 1000L) ) checkEvaluation( @@ -522,8 +521,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", DateTimeUtils.TIMEZONE_OPTION -> tz.getID), Literal(jsonData2), - gmtId, - true), + gmtId), InternalRow(c.getTimeInMillis * 1000L) ) } @@ -532,7 +530,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("SPARK-19543: from_json empty input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId, true), + JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId), null ) } @@ -687,23 +685,24 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("from_json missing fields") { for (forceJsonNullableSchema <- Seq(false, true)) { - val input = - """{ + withSQLConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA.key -> forceJsonNullableSchema.toString) { + val input = + """{ | "a": 1, | "c": "foo" |} |""".stripMargin - val jsonSchema = new StructType() - .add("a", LongType, nullable = false) - .add("b", StringType, nullable = false) - .add("c", StringType, nullable = false) - val output = InternalRow(1L, null, UTF8String.fromString("foo")) - val expr = JsonToStructs( - jsonSchema, Map.empty, Literal.create(input, StringType), gmtId, forceJsonNullableSchema) - checkEvaluation(expr, output) - val schema = expr.dataType - val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema - assert(schemaToCompare == schema) + val jsonSchema = new StructType() + .add("a", LongType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + val output = InternalRow(1L, null, UTF8String.fromString("foo")) + val expr = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId) + checkEvaluation(expr, output) + val schema = expr.dataType + val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema + assert(schemaToCompare == schema) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 89dbba10a6bf1..6b956ddb48561 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3304,7 +3304,7 @@ object functions { * @since 2.2.0 */ def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr { - new JsonToStructs(schema, options, e.expr) + JsonToStructs(schema, options, e.expr) } /** @@ -3495,7 +3495,7 @@ object functions { * @group collection_funcs * @since 1.5.0 */ - def size(e: Column): Column = withExpr { new Size(e.expr) } + def size(e: Column): Column = withExpr { Size(e.expr) } /** * Sorts the input array for the given column in ascending order, From 3ab48f985c7f96bc9143caad99bf3df7cc984583 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 11 Jul 2018 17:38:43 -0700 Subject: [PATCH 1102/2461] [SPARK-24761][SQL] Adding of isModifiable() to RuntimeConfig ## What changes were proposed in this pull request? In the PR, I propose to extend `RuntimeConfig` by new method `isModifiable()` which returns `true` if a config parameter can be modified at runtime (for current session state). For static SQL and core parameters, the method returns `false`. ## How was this patch tested? Added new test to `RuntimeConfigSuite` for checking Spark core and SQL parameters. Author: Maxim Gekk Closes #21730 from MaxGekk/is-modifiable. --- python/pyspark/sql/conf.py | 8 ++++++++ .../org/apache/spark/sql/internal/SQLConf.scala | 4 ++++ .../scala/org/apache/spark/sql/RuntimeConfig.scala | 11 +++++++++++ .../org/apache/spark/sql/RuntimeConfigSuite.scala | 14 ++++++++++++++ 4 files changed, 37 insertions(+) diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index db49040e17b63..f80bf598c2211 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -63,6 +63,14 @@ def _checkType(self, obj, identifier): raise TypeError("expected %s '%s' to be a string (was '%s')" % (identifier, obj, type(obj).__name__)) + @ignore_unicode_prefix + @since(2.4) + def isModifiable(self, key): + """Indicates whether the configuration property with the given key + is modifiable in the current session. + """ + return self._jconf.isModifiable(key) + def _test(): import os diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ae56cc97581a5..14dd5281fbcb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1907,4 +1907,8 @@ class SQLConf extends Serializable with Logging { } cloned } + + def isModifiable(key: String): Boolean = { + sqlConfEntries.containsKey(key) && !staticConfKeys.contains(key) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala index b352e332bc7e0..3c39579149fff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala @@ -132,6 +132,17 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { sqlConf.unsetConf(key) } + /** + * Indicates whether the configuration property with the given key + * is modifiable in the current session. + * + * @return `true` if the configuration property is modifiable. For static SQL, Spark Core, + * invalid (not existing) and other non-modifiable configuration properties, + * the returned value is `false`. + * @since 2.4.0 + */ + def isModifiable(key: String): Boolean = sqlConf.isModifiable(key) + /** * Returns whether a particular key is set. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala index cfe2e9f2dbc44..cdcea09ad9758 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala @@ -54,4 +54,18 @@ class RuntimeConfigSuite extends SparkFunSuite { conf.get("k1") } } + + test("SPARK-24761: is a config parameter modifiable") { + val conf = newConf() + + // SQL configs + assert(!conf.isModifiable("spark.sql.sources.schemaStringLengthThreshold")) + assert(conf.isModifiable("spark.sql.streaming.checkpointLocation")) + // Core configs + assert(!conf.isModifiable("spark.task.cpus")) + assert(!conf.isModifiable("spark.executor.cores")) + // Invalid config parameters + assert(!conf.isModifiable("")) + assert(!conf.isModifiable("invalid config parameter")) + } } From 5ad4735bdad558fe564a0391e207c62743647ab1 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 12 Jul 2018 09:52:23 +0800 Subject: [PATCH 1103/2461] [SPARK-24529][BUILD][TEST-MAVEN] Add spotbugs into maven build process ## What changes were proposed in this pull request? This PR enables a Java bytecode check tool [spotbugs](https://spotbugs.github.io/) to avoid possible integer overflow at multiplication. When an violation is detected, the build process is stopped. Due to the tool limitation, some other checks will be enabled. In this PR, [these patterns](http://spotbugs-in-kengo-toda.readthedocs.io/en/lqc-list-detectors/detectors.html#findpuzzlers) in `FindPuzzlers` can be detected. This check is enabled at `compile` phase. Thus, `mvn compile` or `mvn package` launches this check. ## How was this patch tested? Existing UTs Author: Kazuaki Ishizaki Closes #21542 from kiszk/SPARK-24529. --- .../util/collection/ExternalSorter.scala | 4 ++-- .../apache/spark/ml/image/HadoopUtils.scala | 8 ++++--- pom.xml | 22 +++++++++++++++++++ resource-managers/kubernetes/core/pom.xml | 6 +++++ .../kubernetes/integration-tests/pom.xml | 5 +++++ .../expressions/collectionOperations.scala | 2 +- .../expressions/conditionalExpressions.scala | 4 ++-- .../sql/catalyst/parser/AstBuilder.scala | 2 +- 8 files changed, 44 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 176f84fa2a0d2..b159200d79222 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -368,8 +368,8 @@ private[spark] class ExternalSorter[K, V, C]( val bufferedIters = iterators.filter(_.hasNext).map(_.buffered) type Iter = BufferedIterator[Product2[K, C]] val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] { - // Use the reverse of comparator.compare because PriorityQueue dequeues the max - override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1) + // Use the reverse order because PriorityQueue dequeues the max + override def compare(x: Iter, y: Iter): Int = comparator.compare(y.head._1, x.head._1) }) heap.enqueue(bufferedIters: _*) // Will contain only the iterators with hasNext = true new Iterator[Product2[K, C]] { diff --git a/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala index 8c975a2fba8ca..f1579ec5844a4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala @@ -42,9 +42,11 @@ private object RecursiveFlag { val old = Option(hadoopConf.get(flagName)) hadoopConf.set(flagName, value.toString) try f finally { - old match { - case Some(v) => hadoopConf.set(flagName, v) - case None => hadoopConf.unset(flagName) + // avoid false positive of DLS_DEAD_LOCAL_STORE_IN_RETURN by SpotBugs + if (old.isDefined) { + hadoopConf.set(flagName, old.get) + } else { + hadoopConf.unset(flagName) } } } diff --git a/pom.xml b/pom.xml index cd567e227f331..6dee6fce3ffc4 100644 --- a/pom.xml +++ b/pom.xml @@ -2606,6 +2606,28 @@ + + com.github.spotbugs + spotbugs-maven-plugin + 3.1.3 + + ${basedir}/target/scala-${scala.binary.version}/classes + ${basedir}/target/scala-${scala.binary.version}/test-classes + Max + Low + true + FindPuzzlers + false + + + + + check + + compile + + + diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index a6dd47a6b7d95..920f0f6ebf2c8 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -47,6 +47,12 @@ test + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + + io.fabric8 kubernetes-client diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 6a2fff891098b..29334cc6d891d 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -63,6 +63,11 @@ kubernetes-client ${kubernetes-client.version} + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e217d37184511..b8f2aa3e624ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -993,7 +993,7 @@ trait ArraySortLike extends ExpectsInputTypes { } else if (o2 == null) { nullOrder } else { - -ordering.compare(o1, o2) + ordering.compare(o2, o1) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index e6377b7d87b53..30ce9e4743da9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -294,7 +294,7 @@ object CaseWhen { case cond :: value :: Nil => Some((cond, value)) case value :: Nil => None }.toArray.toSeq // force materialization to make the seq serializable - val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None + val elseValue = if (branches.size % 2 != 0) Some(branches.last) else None CaseWhen(cases, elseValue) } } @@ -309,7 +309,7 @@ object CaseKeyWhen { case Seq(cond, value) => Some((EqualTo(key, cond), value)) case Seq(value) => None }.toArray.toSeq // force materialization to make the seq serializable - val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None + val elseValue = if (branches.size % 2 != 0) Some(branches.last) else None CaseWhen(cases, elseValue) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 383ebde3229d6..f398b479dc273 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1507,7 +1507,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case "TIMESTAMP" => Literal(Timestamp.valueOf(value)) case "X" => - val padding = if (value.length % 2 == 1) "0" else "" + val padding = if (value.length % 2 != 0) "0" else "" Literal(DatatypeConverter.parseHexBinary(padding + value)) case other => throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx) From 301bff70637983426d76b106b7c659c1f28ed7bf Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 12 Jul 2018 17:42:29 +0900 Subject: [PATCH 1104/2461] [SPARK-23914][SQL] Add array_union function ## What changes were proposed in this pull request? The PR adds the SQL function `array_union`. The behavior of the function is based on Presto's one. This function returns returns an array of the elements in the union of array1 and array2. Note: The order of elements in the result is not defined. ## How was this patch tested? Added UTs Author: Kazuaki Ishizaki Closes #21061 from kiszk/SPARK-23914. --- python/pyspark/sql/functions.py | 19 ++ .../catalyst/expressions/UnsafeArrayData.java | 19 +- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 319 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 81 +++++ .../org/apache/spark/sql/functions.scala | 11 + .../spark/sql/DataFrameFunctionsSuite.scala | 52 +++ 7 files changed, 499 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9f61e29f9cd42..5ef73987a66a6 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2033,6 +2033,25 @@ def array_distinct(col): return Column(sc._jvm.functions.array_distinct(_to_java_column(col))) +@ignore_unicode_prefix +@since(2.4) +def array_union(col1, col2): + """ + Collection function: returns an array of the elements in the union of col1 and col2, + without duplicates. + + :param col1: name of column containing array + :param col2: name of column containing array + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) + >>> df.select(array_union(df.c1, df.c2)).collect() + [Row(array_union(c1, c2)=[u'b', u'a', u'c', u'd', u'f'])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_union(_to_java_column(col1), _to_java_column(col2))) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 4dd2b7365652a..cf2a5ed2e27f9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -450,7 +450,7 @@ public double[] toDoubleArray() { return values; } - private static UnsafeArrayData fromPrimitiveArray( + public static UnsafeArrayData fromPrimitiveArray( Object arr, int offset, int length, int elementSize) { final long headerInBytes = calculateHeaderPortionInBytes(length); final long valueRegionInBytes = (long)elementSize * length; @@ -463,14 +463,27 @@ private static UnsafeArrayData fromPrimitiveArray( final long[] data = new long[(int)totalSizeInLongs]; Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, length); - Platform.copyMemory(arr, offset, data, - Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes); + if (arr != null) { + Platform.copyMemory(arr, offset, data, + Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes); + } UnsafeArrayData result = new UnsafeArrayData(); result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 8); return result; } + public static UnsafeArrayData forPrimitiveArray(int offset, int length, int elementSize) { + return fromPrimitiveArray(null, offset, length, elementSize); + } + + public static boolean shouldUseGenericArrayData(int elementSize, int length) { + final long headerInBytes = calculateHeaderPortionInBytes(length); + final long valueRegionInBytes = (long)elementSize * length; + final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8; + return totalSizeInLongs > Integer.MAX_VALUE / 8; + } + public static UnsafeArrayData fromPrimitiveArray(boolean[] arr) { return fromPrimitiveArray(arr, Platform.BOOLEAN_ARRAY_OFFSET, arr.length, 1); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e7517e8c676e3..1d9e470c9ba83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -414,6 +414,7 @@ object FunctionRegistry { expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), expression[ArraySort]("array_sort"), + expression[ArrayUnion]("array_union"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[ElementAt]("element_at"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b8f2aa3e624ce..0f4f4f1601b4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3486,3 +3486,322 @@ case class ArrayDistinct(child: Expression) override def prettyName: String = "array_distinct" } + +/** + * Will become common base class for [[ArrayUnion]], ArrayIntersect, and ArrayExcept. + */ +abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { + override def dataType: DataType = { + val dataTypes = children.map(_.dataType.asInstanceOf[ArrayType]) + ArrayType(elementType, dataTypes.exists(_.containsNull)) + } + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType.asInstanceOf[ArrayType].elementType, + s"function $prettyName") + } else { + typeCheckResult + } + } + + @transient protected lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + @transient protected lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } +} + +object ArraySetLike { + def throwUnionLengthOverflowException(length: Int): Unit = { + throw new RuntimeException(s"Unsuccessful try to union arrays with $length " + + s"elements due to exceeding the array size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + } +} + + +/** + * Returns an array of the elements in the union of x and y, without duplicates + */ +@ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2, + without duplicates. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(1, 2, 3, 5) + """, + since = "2.4.0") +case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike { + var hsInt: OpenHashSet[Int] = _ + var hsLong: OpenHashSet[Long] = _ + + def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { + val elem = array.getInt(idx) + if (!hsInt.contains(elem)) { + if (resultArray != null) { + resultArray.setInt(pos, elem) + } + hsInt.add(elem) + true + } else { + false + } + } + + def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { + val elem = array.getLong(idx) + if (!hsLong.contains(elem)) { + if (resultArray != null) { + resultArray.setLong(pos, elem) + } + hsLong.add(elem) + true + } else { + false + } + } + + def evalIntLongPrimitiveType( + array1: ArrayData, + array2: ArrayData, + resultArray: ArrayData, + isLongType: Boolean): Int = { + // store elements into resultArray + var nullElementSize = 0 + var pos = 0 + Seq(array1, array2).foreach { array => + var i = 0 + while (i < array.numElements()) { + val size = if (!isLongType) hsInt.size else hsLong.size + if (size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(size) + } + if (array.isNullAt(i)) { + if (nullElementSize == 0) { + if (resultArray != null) { + resultArray.setNullAt(pos) + } + pos += 1 + nullElementSize = 1 + } + } else { + val assigned = if (!isLongType) { + assignInt(array, i, resultArray, pos) + } else { + assignLong(array, i, resultArray, pos) + } + if (assigned) { + pos += 1 + } + } + i += 1 + } + } + pos + } + + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] + + if (elementTypeSupportEquals) { + elementType match { + case IntegerType => + // avoid boxing of primitive int array elements + // calculate result array size + hsInt = new OpenHashSet[Int] + val elements = evalIntLongPrimitiveType(array1, array2, null, false) + hsInt = new OpenHashSet[Int] + val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( + IntegerType.defaultSize, elements)) { + new GenericArrayData(new Array[Any](elements)) + } else { + UnsafeArrayData.forPrimitiveArray( + Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize) + } + evalIntLongPrimitiveType(array1, array2, resultArray, false) + resultArray + case LongType => + // avoid boxing of primitive long array elements + // calculate result array size + hsLong = new OpenHashSet[Long] + val elements = evalIntLongPrimitiveType(array1, array2, null, true) + hsLong = new OpenHashSet[Long] + val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( + LongType.defaultSize, elements)) { + new GenericArrayData(new Array[Any](elements)) + } else { + UnsafeArrayData.forPrimitiveArray( + Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize) + } + evalIntLongPrimitiveType(array1, array2, resultArray, true) + resultArray + case _ => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + val hs = new OpenHashSet[Any] + var foundNullElement = false + Seq(array1, array2).foreach { array => + var i = 0 + while (i < array.numElements()) { + if (array.isNullAt(i)) { + if (!foundNullElement) { + arrayBuffer += null + foundNullElement = true + } + } else { + val elem = array.get(i, elementType) + if (!hs.contains(elem)) { + if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size) + } + arrayBuffer += elem + hs.add(elem) + } + } + i += 1 + } + } + new GenericArrayData(arrayBuffer) + } + } else { + ArrayUnion.unionOrdering(array1, array2, elementType, ordering) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val value = ctx.freshName("value") + val size = ctx.freshName("size") + val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) = + if (elementTypeSupportEquals) { + elementType match { + case ByteType | ShortType | IntegerType | LongType => + val ptName = CodeGenerator.primitiveTypeName(elementType) + val unsafeArray = ctx.freshName("unsafeArray") + (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp", + if (elementType == LongType) "Long" else "Int", + s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), + if (elementType == LongType) "(long)" else "(int)", + s""" + |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} + |${ev.value} = $unsafeArray; + """.stripMargin) + case _ => + val genericArrayData = classOf[GenericArrayData].getName + val et = ctx.addReferenceObj("elementType", elementType) + ("", "Object", + s"get($i, $et)", s"update($pos, $value)", "Object", "", + s"${ev.value} = new $genericArrayData(new Object[$size]);") + } + } else { + ("", "", "", "", "", "", "") + } + + nullSafeCodeGen(ctx, ev, (array1, array2) => { + if (openHashElementType != "") { + // Here, we ensure elementTypeSupportEquals is true + val foundNullElement = ctx.freshName("foundNullElement") + val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" + val hs = ctx.freshName("hs") + val arrayData = classOf[ArrayData].getName + val arrays = ctx.freshName("arrays") + val array = ctx.freshName("array") + val arrayDataIdx = ctx.freshName("arrayDataIdx") + s""" + |$openHashSet $hs = new $openHashSet$postFix($classTag); + |boolean $foundNullElement = false; + |$arrayData[] $arrays = new $arrayData[]{$array1, $array2}; + |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) { + | $arrayData $array = $arrays[$arrayDataIdx]; + | for (int $i = 0; $i < $array.numElements(); $i++) { + | if ($array.isNullAt($i)) { + | $foundNullElement = true; + | } else { + | $hs.add$postFix($array.$getter); + | } + | } + |} + |int $size = $hs.size() + ($foundNullElement ? 1 : 0); + |$arrayBuilder + |$hs = new $openHashSet$postFix($classTag); + |$foundNullElement = false; + |int $pos = 0; + |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) { + | $arrayData $array = $arrays[$arrayDataIdx]; + | for (int $i = 0; $i < $array.numElements(); $i++) { + | if ($array.isNullAt($i)) { + | if (!$foundNullElement) { + | ${ev.value}.setNullAt($pos++); + | $foundNullElement = true; + | } + | } else { + | $javaTypeName $value = $array.$getter; + | if (!$hs.contains($castOp $value)) { + | $hs.add$postFix($value); + | ${ev.value}.$setter; + | $pos++; + | } + | } + | } + |} + """.stripMargin + } else { + val arrayUnion = classOf[ArrayUnion].getName + val et = ctx.addReferenceObj("elementTypeUnion", elementType) + val order = ctx.addReferenceObj("orderingUnion", ordering) + val method = "unionOrdering" + s"${ev.value} = $arrayUnion$$.MODULE$$.$method($array1, $array2, $et, $order);" + } + }) + } + + override def prettyName: String = "array_union" +} + +object ArrayUnion { + def unionOrdering( + array1: ArrayData, + array2: ArrayData, + elementType: DataType, + ordering: Ordering[Any]): ArrayData = { + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadyIncludeNull = false + Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => { + var found = false + if (elem == null) { + if (alreadyIncludeNull) { + found = true + } else { + alreadyIncludeNull = true + } + } else { + // check elem is already stored in arrayBuffer or not? + var j = 0 + while (!found && j < arrayBuffer.size) { + val va = arrayBuffer(j) + if (va != null && ordering.equiv(va, elem)) { + found = true + } + j = j + 1 + } + } + if (!found) { + if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length) + } + arrayBuffer += elem + } + })) + new GenericArrayData(arrayBuffer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index a838a2eedc03c..85d6a1befed6b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1304,4 +1304,85 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1))) } + + test("Array Union") { + val a00 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, containsNull = false)) + val a02 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType, containsNull = true)) + val a03 = Literal.create(Seq(-5, 4, -3, 2, 4), ArrayType(IntegerType, containsNull = false)) + val a04 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false)) + val a05 = Literal.create(Seq[Byte](1, 2, 3), ArrayType(ByteType, containsNull = false)) + val a06 = Literal.create(Seq[Byte](4, 2), ArrayType(ByteType, containsNull = false)) + val a07 = Literal.create(Seq[Short](1, 2, 3), ArrayType(ShortType, containsNull = false)) + val a08 = Literal.create(Seq[Short](4, 2), ArrayType(ShortType, containsNull = false)) + + val a10 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType, containsNull = false)) + val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, containsNull = false)) + val a12 = Literal.create(Seq(1L, 2L, null, 4L, 5L), ArrayType(LongType, containsNull = true)) + val a13 = Literal.create(Seq(-5L, 4L, -3L, 2L, -1L), ArrayType(LongType, containsNull = false)) + val a14 = Literal.create(Seq.empty[Long], ArrayType(LongType, containsNull = false)) + + val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, containsNull = false)) + val a21 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType, containsNull = false)) + val a22 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType, containsNull = true)) + + val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType)) + val a31 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayUnion(a00, a01), Seq(1, 2, 3, 4)) + checkEvaluation(ArrayUnion(a02, a03), Seq(1, 2, null, 4, 5, -5, -3)) + checkEvaluation(ArrayUnion(a03, a02), Seq(-5, 4, -3, 2, 1, null, 5)) + checkEvaluation(ArrayUnion(a02, a04), Seq(1, 2, null, 4, 5)) + checkEvaluation(ArrayUnion(a05, a06), Seq[Byte](1, 2, 3, 4)) + checkEvaluation(ArrayUnion(a07, a08), Seq[Short](1, 2, 3, 4)) + + checkEvaluation(ArrayUnion(a10, a11), Seq(1L, 2L, 3L, 4L)) + checkEvaluation(ArrayUnion(a12, a13), Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L)) + checkEvaluation(ArrayUnion(a13, a12), Seq(-5L, 4L, -3L, 2L, -1L, 1L, null, 5L)) + checkEvaluation(ArrayUnion(a12, a14), Seq(1L, 2L, null, 4L, 5L)) + + checkEvaluation(ArrayUnion(a20, a21), Seq("b", "a", "c", "d", "f")) + checkEvaluation(ArrayUnion(a20, a22), Seq("b", "a", "c", null, "g")) + + checkEvaluation(ArrayUnion(a30, a30), Seq(null)) + checkEvaluation(ArrayUnion(a20, a31), null) + checkEvaluation(ArrayUnion(a31, a20), null) + + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](4, 3)), + ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]]( + Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](1, 2)), ArrayType(BinaryType)) + val b4 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), null), ArrayType(BinaryType)) + val b5 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), ArrayType(BinaryType)) + val b6 = Literal.create(Seq.empty, ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) + + checkEvaluation(ArrayUnion(b0, b1), + Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](2, 1), Array[Byte](4, 3))) + checkEvaluation(ArrayUnion(b0, b2), + Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](4, 3))) + checkEvaluation(ArrayUnion(b2, b4), Seq(Array[Byte](1, 2), Array[Byte](4, 3), null)) + checkEvaluation(ArrayUnion(b3, b0), + Seq(Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](5, 6))) + checkEvaluation(ArrayUnion(b4, b0), Seq(Array[Byte](1, 2), null, Array[Byte](5, 6))) + checkEvaluation(ArrayUnion(b4, b5), Seq(Array[Byte](1, 2), null)) + checkEvaluation(ArrayUnion(b6, b4), Seq(Array[Byte](1, 2), null)) + checkEvaluation(ArrayUnion(b4, arrayWithBinaryNull), Seq(Array[Byte](1, 2), null)) + + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayUnion(aa0, aa1), + Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](5, 6), Seq[Int](2, 1))) + + assert(ArrayUnion(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayUnion(a00, a02).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayUnion(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayUnion(a20, a22).dataType.asInstanceOf[ArrayType].containsNull === true) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6b956ddb48561..b98ab11e56feb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3204,6 +3204,7 @@ object functions { /** * Remove all elements that equal to element from the given array. + * * @group collection_funcs * @since 2.4.0 */ @@ -3218,6 +3219,16 @@ object functions { */ def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) } + /** + * Returns an array of the elements in the union of the given two arrays, without duplicates. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_union(col1: Column, col2: Column): Column = withExpr { + ArrayUnion(col1.expr, col2.expr) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index d60ed7a5ef0d9..d4615714cff03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1198,6 +1198,58 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { "argument 1 requires (array or map) type, however, '`_1`' is of string type")) } + test("array_union functions") { + val df1 = Seq((Array(1, 2, 3), Array(4, 2))).toDF("a", "b") + val ans1 = Row(Seq(1, 2, 3, 4)) + checkAnswer(df1.select(array_union($"a", $"b")), ans1) + checkAnswer(df1.selectExpr("array_union(a, b)"), ans1) + + val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array(-5, 4, -3, 2, -1))).toDF("a", "b") + val ans2 = Row(Seq(1, 2, null, 4, 5, -5, -3, -1)) + checkAnswer(df2.select(array_union($"a", $"b")), ans2) + checkAnswer(df2.selectExpr("array_union(a, b)"), ans2) + + val df3 = Seq((Array(1L, 2L, 3L), Array(4L, 2L))).toDF("a", "b") + val ans3 = Row(Seq(1L, 2L, 3L, 4L)) + checkAnswer(df3.select(array_union($"a", $"b")), ans3) + checkAnswer(df3.selectExpr("array_union(a, b)"), ans3) + + val df4 = Seq((Array[java.lang.Long](1L, 2L, null, 4L, 5L), Array(-5L, 4L, -3L, 2L, -1L))) + .toDF("a", "b") + val ans4 = Row(Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L)) + checkAnswer(df4.select(array_union($"a", $"b")), ans4) + checkAnswer(df4.selectExpr("array_union(a, b)"), ans4) + + val df5 = Seq((Array("b", "a", "c"), Array("b", null, "a", "g"))).toDF("a", "b") + val ans5 = Row(Seq("b", "a", "c", null, "g")) + checkAnswer(df5.select(array_union($"a", $"b")), ans5) + checkAnswer(df5.selectExpr("array_union(a, b)"), ans5) + + val df6 = Seq((null, Array("a"))).toDF("a", "b") + intercept[AnalysisException] { + df6.select(array_union($"a", $"b")) + } + intercept[AnalysisException] { + df6.selectExpr("array_union(a, b)") + } + + val df7 = Seq((null, null)).toDF("a", "b") + intercept[AnalysisException] { + df7.select(array_union($"a", $"b")) + } + intercept[AnalysisException] { + df7.selectExpr("array_union(a, b)") + } + + val df8 = Seq((Array(Array(1)), Array("a"))).toDF("a", "b") + intercept[AnalysisException] { + df8.select(array_union($"a", $"b")) + } + intercept[AnalysisException] { + df8.selectExpr("array_union(a, b)") + } + } + test("concat function - arrays") { val nseqi : Seq[Int] = null val nseqs : Seq[String] = null From e6c6f90a55241905c420afbc803dd3bd6961d66b Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 13 Jul 2018 00:26:49 +0800 Subject: [PATCH 1105/2461] [SPARK-24691][SQL] Dispatch the type support check in FileFormat implementation ## What changes were proposed in this pull request? With https://github.com/apache/spark/pull/21389, data source schema is validated on driver side before launching read/write tasks. However, 1. Putting all the validations together in `DataSourceUtils` is tricky and hard to maintain. On second thought after review, I find that the `OrcFileFormat` in hive package is not matched, so that its validation wrong. 2. `DataSourceUtils.verifyWriteSchema` and `DataSourceUtils.verifyReadSchema` is not supposed to be called in every file format. We can move them to some upper entry. So, I propose we can add a new method `validateDataType` in FileFormat. File format implementation can override the method to specify its supported/non-supported data types. Although we should focus on data source V2 API, `FileFormat` should remain workable for some time. Adding this new method should be helpful. ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21667 from gengliangwang/refactorSchemaValidate. --- .../execution/datasources/DataSource.scala | 1 + .../datasources/DataSourceUtils.scala | 68 ++-------- .../execution/datasources/FileFormat.scala | 9 +- .../datasources/FileFormatWriter.scala | 4 +- .../datasources/csv/CSVFileFormat.scala | 11 +- .../datasources/json/JsonFileFormat.scala | 23 +++- .../datasources/orc/OrcFileFormat.scala | 21 +++- .../parquet/ParquetFileFormat.scala | 19 ++- .../datasources/text/TextFileFormat.scala | 10 +- .../spark/sql/FileBasedDataSourceSuite.scala | 119 ++++++++++++------ .../sql/execution/command/DDLSuite.scala | 2 +- .../datasources/FileSourceStrategySuite.scala | 2 +- .../spark/sql/hive/orc/OrcFileFormat.scala | 21 +++- .../sql/hive/orc/HiveOrcSourceSuite.scala | 19 +-- 14 files changed, 189 insertions(+), 140 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index f16d824201e77..0c3d9a4895fe2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -396,6 +396,7 @@ case class DataSource( hs.partitionSchema.map(_.name), "in the partition schema", equality) + DataSourceUtils.verifyReadSchema(hs.fileFormat, hs.dataSchema) case _ => SchemaUtils.checkColumnNameDuplication( relation.schema.map(_.name), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala index c5347218c4b40..82e99190ecf14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -17,10 +17,7 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat -import org.apache.spark.sql.execution.datasources.json.JsonFileFormat -import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat -import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.types._ @@ -42,65 +39,14 @@ object DataSourceUtils { /** * Verify if the schema is supported in datasource. This verification should be done - * in a driver side, e.g., `prepareWrite`, `buildReader`, and `buildReaderWithPartitionValues` - * in `FileFormat`. - * - * Unsupported data types of csv, json, orc, and parquet are as follows; - * csv -> R/W: Interval, Null, Array, Map, Struct - * json -> W: Interval - * orc -> W: Interval, Null - * parquet -> R/W: Interval, Null + * in a driver side. */ private def verifySchema(format: FileFormat, schema: StructType, isReadPath: Boolean): Unit = { - def throwUnsupportedException(dataType: DataType): Unit = { - throw new UnsupportedOperationException( - s"$format data source does not support ${dataType.simpleString} data type.") + schema.foreach { field => + if (!format.supportDataType(field.dataType, isReadPath)) { + throw new AnalysisException( + s"$format data source does not support ${field.dataType.simpleString} data type.") + } } - - def verifyType(dataType: DataType): Unit = dataType match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | - StringType | BinaryType | DateType | TimestampType | _: DecimalType => - - // All the unsupported types for CSV - case _: NullType | _: CalendarIntervalType | _: StructType | _: ArrayType | _: MapType - if format.isInstanceOf[CSVFileFormat] => - throwUnsupportedException(dataType) - - case st: StructType => st.foreach { f => verifyType(f.dataType) } - - case ArrayType(elementType, _) => verifyType(elementType) - - case MapType(keyType, valueType, _) => - verifyType(keyType) - verifyType(valueType) - - case udt: UserDefinedType[_] => verifyType(udt.sqlType) - - // Interval type not supported in all the write path - case _: CalendarIntervalType if !isReadPath => - throwUnsupportedException(dataType) - - // JSON and ORC don't support an Interval type, but we pass it in read pass - // for back-compatibility. - case _: CalendarIntervalType if format.isInstanceOf[JsonFileFormat] || - format.isInstanceOf[OrcFileFormat] => - - // Interval type not supported in the other read path - case _: CalendarIntervalType => - throwUnsupportedException(dataType) - - // For JSON & ORC backward-compatibility - case _: NullType if format.isInstanceOf[JsonFileFormat] || - (isReadPath && format.isInstanceOf[OrcFileFormat]) => - - // Null type not supported in the other path - case _: NullType => - throwUnsupportedException(dataType) - - // We keep this default case for safeguards - case _ => throwUnsupportedException(dataType) - } - - schema.foreach(field => verifyType(field.dataType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index 023e127888290..2c162e23644ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} /** @@ -57,7 +57,7 @@ trait FileFormat { dataSchema: StructType): OutputWriterFactory /** - * Returns whether this format support returning columnar batch or not. + * Returns whether this format supports returning columnar batch or not. * * TODO: we should just have different traits for the different formats. */ @@ -152,6 +152,11 @@ trait FileFormat { } } + /** + * Returns whether this format supports the given [[DataType]] in read/write path. + * By default all data types are supported. + */ + def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = true } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 52da8356ab835..7c6ab4bc922fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -96,9 +96,11 @@ object FileFormatWriter extends Logging { val caseInsensitiveOptions = CaseInsensitiveMap(options) + val dataSchema = dataColumns.toStructType + DataSourceUtils.verifyWriteSchema(fileFormat, dataSchema) // Note: prepareWrite has side effect. It sets "job". val outputWriterFactory = - fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataColumns.toStructType) + fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataSchema) val description = new WriteJobDescription( uuid = UUID.randomUUID().toString, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index fa366ccce6b61..aeb40e5a4131d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -66,7 +66,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - DataSourceUtils.verifyWriteSchema(this, dataSchema) val conf = job.getConfiguration val csvOptions = new CSVOptions( options, @@ -98,7 +97,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - DataSourceUtils.verifyReadSchema(this, dataSchema) val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) @@ -153,6 +151,15 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { override def hashCode(): Int = getClass.hashCode() override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat] + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _ => false + } + } private[csv] class CsvOutputWriter( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 383bff1375a93..a9241afba537b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSON import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { @@ -65,8 +65,6 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - DataSourceUtils.verifyWriteSchema(this, dataSchema) - val conf = job.getConfiguration val parsedOptions = new JSONOptions( options, @@ -98,8 +96,6 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { - DataSourceUtils.verifyReadSchema(this, dataSchema) - val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) @@ -148,6 +144,23 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { override def hashCode(): Int = getClass.hashCode() override def equals(other: Any): Boolean = other.isInstanceOf[JsonFileFormat] + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _: NullType => true + + case _ => false + } } private[json] class JsonOutputWriter( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index df488a748e3e5..3a8c0add8c2f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -89,8 +89,6 @@ class OrcFileFormat job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - DataSourceUtils.verifyWriteSchema(this, dataSchema) - val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) val conf = job.getConfiguration @@ -143,8 +141,6 @@ class OrcFileFormat filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - DataSourceUtils.verifyReadSchema(this, dataSchema) - if (sparkSession.sessionState.conf.orcFilterPushDown) { OrcFilters.createFilter(dataSchema, filters).foreach { f => OrcInputFormat.setSearchArgument(hadoopConf, f, dataSchema.fieldNames) @@ -228,4 +224,21 @@ class OrcFileFormat } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _: NullType => isReadPath + + case _ => false + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 52a18abb55241..b86b97ec7b103 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -78,8 +78,6 @@ class ParquetFileFormat job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - DataSourceUtils.verifyWriteSchema(this, dataSchema) - val parquetOptions = new ParquetOptions(options, sparkSession.sessionState.conf) val conf = ContextUtil.getConfiguration(job) @@ -303,8 +301,6 @@ class ParquetFileFormat filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - DataSourceUtils.verifyReadSchema(this, dataSchema) - hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) hadoopConf.set( ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, @@ -454,6 +450,21 @@ class ParquetFileFormat } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _ => false + } } object ParquetFileFormat extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index e93908da43535..8661a5395ac44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types.{DataType, StringType, StructType} import org.apache.spark.util.SerializableConfiguration /** @@ -47,11 +47,6 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { throw new AnalysisException( s"Text data source supports only a single column, and you have ${schema.size} columns.") } - val tpe = schema(0).dataType - if (tpe != StringType) { - throw new AnalysisException( - s"Text data source supports only a string column, but you have ${tpe.simpleString}.") - } } override def isSplitable( @@ -141,6 +136,9 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = + dataType == StringType } class TextOutputWriter( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 86f9647b4ac4c..a7ce952b70ac1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -205,63 +205,121 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } } + // Text file format only supports string type + test("SPARK-24691 error handling for unsupported types - text") { + withTempDir { dir => + // write path + val textDir = new File(dir, "text").getCanonicalPath + var msg = intercept[AnalysisException] { + Seq(1).toDF.write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support int data type")) + + msg = intercept[AnalysisException] { + Seq(1.2).toDF.write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support double data type")) + + msg = intercept[AnalysisException] { + Seq(true).toDF.write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support boolean data type")) + + msg = intercept[AnalysisException] { + Seq(1).toDF("a").selectExpr("struct(a)").write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support struct data type")) + + msg = intercept[AnalysisException] { + Seq((Map("Tesla" -> 3))).toDF("cars").write.mode("overwrite").text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support map data type")) + + msg = intercept[AnalysisException] { + Seq((Array("Tesla", "Chevy", "Ford"))).toDF("brands") + .write.mode("overwrite").text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support array data type")) + + // read path + Seq("aaa").toDF.write.mode("overwrite").text(textDir) + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", IntegerType, true) :: Nil) + spark.read.schema(schema).text(textDir).collect() + }.getMessage + assert(msg.contains("Text data source does not support int data type")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", DoubleType, true) :: Nil) + spark.read.schema(schema).text(textDir).collect() + }.getMessage + assert(msg.contains("Text data source does not support double data type")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", BooleanType, true) :: Nil) + spark.read.schema(schema).text(textDir).collect() + }.getMessage + assert(msg.contains("Text data source does not support boolean data type")) + } + } + // Unsupported data types of csv, json, orc, and parquet are as follows; - // csv -> R/W: Interval, Null, Array, Map, Struct - // json -> W: Interval - // orc -> W: Interval, Null + // csv -> R/W: Null, Array, Map, Struct + // json -> R/W: Interval + // orc -> R/W: Interval, W: Null // parquet -> R/W: Interval, Null test("SPARK-24204 error handling for unsupported Array/Map/Struct types - csv") { withTempDir { dir => val csvDir = new File(dir, "csv").getCanonicalPath - var msg = intercept[UnsupportedOperationException] { + var msg = intercept[AnalysisException] { Seq((1, "Tesla")).toDF("a", "b").selectExpr("struct(a, b)").write.csv(csvDir) }.getMessage assert(msg.contains("CSV data source does not support struct data type")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType.fromDDL("a struct") spark.range(1).write.mode("overwrite").csv(csvDir) spark.read.schema(schema).csv(csvDir).collect() }.getMessage assert(msg.contains("CSV data source does not support struct data type")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { Seq((1, Map("Tesla" -> 3))).toDF("id", "cars").write.mode("overwrite").csv(csvDir) }.getMessage assert(msg.contains("CSV data source does not support map data type")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType.fromDDL("a map") spark.range(1).write.mode("overwrite").csv(csvDir) spark.read.schema(schema).csv(csvDir).collect() }.getMessage assert(msg.contains("CSV data source does not support map data type")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { Seq((1, Array("Tesla", "Chevy", "Ford"))).toDF("id", "brands") .write.mode("overwrite").csv(csvDir) }.getMessage assert(msg.contains("CSV data source does not support array data type")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType.fromDDL("a array") spark.range(1).write.mode("overwrite").csv(csvDir) spark.read.schema(schema).csv(csvDir).collect() }.getMessage assert(msg.contains("CSV data source does not support array data type")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors") .write.mode("overwrite").csv(csvDir) }.getMessage - assert(msg.contains("CSV data source does not support array data type")) + assert(msg.contains("CSV data source does not support mydensevector data type")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil) spark.range(1).write.mode("overwrite").csv(csvDir) spark.read.schema(schema).csv(csvDir).collect() }.getMessage - assert(msg.contains("CSV data source does not support array data type.")) + assert(msg.contains("CSV data source does not support mydensevector data type.")) } } @@ -276,17 +334,17 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo }.getMessage assert(msg.contains("Cannot save interval data type into external storage.")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { spark.udf.register("testType", () => new IntervalData()) sql("select testType()").write.format(format).mode("overwrite").save(tempDir) }.getMessage assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support calendarinterval data type.")) + .contains(s"$format data source does not support interval data type.")) } // read path Seq("parquet", "csv").foreach { format => - var msg = intercept[UnsupportedOperationException] { + var msg = intercept[AnalysisException] { val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) spark.range(1).write.format(format).mode("overwrite").save(tempDir) spark.read.schema(schema).format(format).load(tempDir).collect() @@ -294,26 +352,13 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo assert(msg.toLowerCase(Locale.ROOT) .contains(s"$format data source does not support calendarinterval data type.")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) spark.range(1).write.format(format).mode("overwrite").save(tempDir) spark.read.schema(schema).format(format).load(tempDir).collect() }.getMessage assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support calendarinterval data type.")) - } - - // We expect the types below should be passed for backward-compatibility - Seq("orc", "json").foreach { format => - // Interval type - var schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() - - // UDT having interval data - schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() + .contains(s"$format data source does not support interval data type.")) } } } @@ -324,13 +369,13 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo Seq("orc").foreach { format => // write path - var msg = intercept[UnsupportedOperationException] { + var msg = intercept[AnalysisException] { sql("select null").write.format(format).mode("overwrite").save(tempDir) }.getMessage assert(msg.toLowerCase(Locale.ROOT) .contains(s"$format data source does not support null data type.")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { spark.udf.register("testType", () => new NullData()) sql("select testType()").write.format(format).mode("overwrite").save(tempDir) }.getMessage @@ -353,13 +398,13 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo Seq("parquet", "csv").foreach { format => // write path - var msg = intercept[UnsupportedOperationException] { + var msg = intercept[AnalysisException] { sql("select null").write.format(format).mode("overwrite").save(tempDir) }.getMessage assert(msg.toLowerCase(Locale.ROOT) .contains(s"$format data source does not support null data type.")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { spark.udf.register("testType", () => new NullData()) sql("select testType()").write.format(format).mode("overwrite").save(tempDir) }.getMessage @@ -367,7 +412,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo .contains(s"$format data source does not support null data type.")) // read path - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType(StructField("a", NullType, true) :: Nil) spark.range(1).write.format(format).mode("overwrite").save(tempDir) spark.read.schema(schema).format(format).load(tempDir).collect() @@ -375,7 +420,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo assert(msg.toLowerCase(Locale.ROOT) .contains(s"$format data source does not support null data type.")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType(StructField("a", new NullUDT(), true) :: Nil) spark.range(1).write.format(format).mode("overwrite").save(tempDir) spark.read.schema(schema).format(format).load(tempDir).collect() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 270ed7f80197c..ca95aad3976e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -2513,7 +2513,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { test("alter datasource table add columns - text format not supported") { withTable("t1") { - sql("CREATE TABLE t1 (c1 int) USING text") + sql("CREATE TABLE t1 (c1 string) USING text") val e = intercept[AnalysisException] { sql("ALTER TABLE t1 ADD COLUMNS (c2 int)") }.getMessage diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 8764f0c42cf9f..bceaf1a9ec061 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path, RawLocalFileSystem} import org.apache.hadoop.mapreduce.Job -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.BucketSpec diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index dd2144c5fcea8..20090696ec3fc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.orc.OrcOptions import org.apache.spark.sql.hive.{HiveInspectors, HiveShim} import org.apache.spark.sql.sources.{Filter, _} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration /** @@ -72,7 +72,6 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - DataSourceUtils.verifyWriteSchema(this, dataSchema) val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) @@ -123,7 +122,6 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - DataSourceUtils.verifyReadSchema(this, dataSchema) if (sparkSession.sessionState.conf.orcFilterPushDown) { // Sets pushed predicates @@ -178,6 +176,23 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _: NullType => isReadPath + + case _ => false + } } private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala index 69009e1b520c2..fb4957ed943a7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -146,38 +146,31 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { }.getMessage assert(msg.contains("Cannot save interval data type into external storage.")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { sql("select null").write.mode("overwrite").orc(orcDir) }.getMessage assert(msg.contains("ORC data source does not support null data type.")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { spark.udf.register("testType", () => new IntervalData()) sql("select testType()").write.mode("overwrite").orc(orcDir) }.getMessage - assert(msg.contains("ORC data source does not support calendarinterval data type.")) + assert(msg.contains("ORC data source does not support interval data type.")) // read path - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) spark.range(1).write.mode("overwrite").orc(orcDir) spark.read.schema(schema).orc(orcDir).collect() }.getMessage assert(msg.contains("ORC data source does not support calendarinterval data type.")) - msg = intercept[UnsupportedOperationException] { - val schema = StructType(StructField("a", NullType, true) :: Nil) - spark.range(1).write.mode("overwrite").orc(orcDir) - spark.read.schema(schema).orc(orcDir).collect() - }.getMessage - assert(msg.contains("ORC data source does not support null data type.")) - - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) spark.range(1).write.mode("overwrite").orc(orcDir) spark.read.schema(schema).orc(orcDir).collect() }.getMessage - assert(msg.contains("ORC data source does not support calendarinterval data type.")) + assert(msg.contains("ORC data source does not support interval data type.")) } } } From 9fa4a1ed38713e2d18a3320d3fc56f9f6db07b06 Mon Sep 17 00:00:00 2001 From: Yash Sharma Date: Thu, 12 Jul 2018 10:04:47 -0700 Subject: [PATCH 1106/2461] =?UTF-8?q?[SPARK-20168][STREAMING=20KINESIS]=20?= =?UTF-8?q?Setting=20the=20timestamp=20directly=20would=20cause=20exceptio?= =?UTF-8?q?n=20on=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Setting the timestamp directly would cause exception on reading stream, it can be set directly only if the mode is not AT_TIMESTAMP ## What changes were proposed in this pull request? The last patch in the kinesis streaming receiver sets the timestamp for the mode AT_TIMESTAMP, but this mode can only be set via the `baseClientLibConfiguration.withTimestampAtInitialPositionInStream() ` and can't be set directly using `.withInitialPositionInStream()` This patch fixes the issue. ## How was this patch tested? Kinesis Receiver doesn't expose the internal state outside, so couldn't find the right way to test this change. Seeking for tips from other contributors here. Author: Yash Sharma Closes #21541 from yashs360/ysharma/fix_kinesis_bug. --- .../org/apache/spark/streaming/kinesis/KinesisReceiver.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index fa0de6298a5f1..69c52365b1bf8 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -160,7 +160,6 @@ private[kinesis] class KinesisReceiver[T]( cloudWatchCreds.map(_.provider).getOrElse(kinesisProvider), workerId) .withKinesisEndpoint(endpointUrl) - .withInitialPositionInStream(initialPosition.getPosition) .withTaskBackoffTimeMillis(500) .withRegionName(regionName) @@ -169,7 +168,8 @@ private[kinesis] class KinesisReceiver[T]( initialPosition match { case ts: AtTimestamp => baseClientLibConfiguration.withTimestampAtInitialPositionInStream(ts.getTimestamp) - case _ => baseClientLibConfiguration + case _ => + baseClientLibConfiguration.withInitialPositionInStream(initialPosition.getPosition) } } From 1055c94cdf072bfce5e36bb6552fe9b148bb9d17 Mon Sep 17 00:00:00 2001 From: Dhruve Ashar Date: Thu, 12 Jul 2018 15:36:02 -0500 Subject: [PATCH 1107/2461] [SPARK-24610] fix reading small files via wholeTextFiles ## What changes were proposed in this pull request? The `WholeTextFileInputFormat` determines the `maxSplitSize` for the file/s being read using the `wholeTextFiles` method. While this works well for large files, for smaller files where the maxSplitSize is smaller than the defaults being used with configs like hive-site.xml or explicitly passed in the form of `mapreduce.input.fileinputformat.split.minsize.per.node` or `mapreduce.input.fileinputformat.split.minsize.per.rack` , it just throws up an exception. ```java java.io.IOException: Minimum split size pernode 123456 cannot be larger than maximum split size 9962 at org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat.getSplits(CombineFileInputFormat.java:200) at org.apache.spark.rdd.WholeTextFileRDD.getPartitions(WholeTextFileRDD.scala:50) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:252) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:250) at scala.Option.getOrElse(Option.scala:121) at org.apache.spark.rdd.RDD.partitions(RDD.scala:250) at org.apache.spark.rdd.MapPartitionsRDD.getPartitions(MapPartitionsRDD.scala:35) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:252) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:250) at scala.Option.getOrElse(Option.scala:121) at org.apache.spark.rdd.RDD.partitions(RDD.scala:250) at org.apache.spark.SparkContext.runJob(SparkContext.scala:2096) at org.apache.spark.rdd.RDD.count(RDD.scala:1158) ... 48 elided ` This change checks the maxSplitSize against the minSplitSizePerNode and minSplitSizePerRack and set them if `maxSplitSize < minSplitSizePerNode/Rack` ## How was this patch tested? Test manually setting the conf while launching the job and added unit test. Author: Dhruve Ashar Closes #21601 from dhruve/bug/SPARK-24610. --- .../input/WholeTextFileInputFormat.scala | 13 +++ .../input/WholeTextFileInputFormatSuite.scala | 96 +++++++++++++++++++ 2 files changed, 109 insertions(+) create mode 100644 core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index f47cd38d712c3..04c5c4b90e8a1 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -53,6 +53,19 @@ private[spark] class WholeTextFileInputFormat val totalLen = files.map(file => if (file.isDirectory) 0L else file.getLen).sum val maxSplitSize = Math.ceil(totalLen * 1.0 / (if (minPartitions == 0) 1 else minPartitions)).toLong + + // For small files we need to ensure the min split size per node & rack <= maxSplitSize + val config = context.getConfiguration + val minSplitSizePerNode = config.getLong(CombineFileInputFormat.SPLIT_MINSIZE_PERNODE, 0L) + val minSplitSizePerRack = config.getLong(CombineFileInputFormat.SPLIT_MINSIZE_PERRACK, 0L) + + if (maxSplitSize < minSplitSizePerNode) { + super.setMinSplitSizeNode(maxSplitSize) + } + + if (maxSplitSize < minSplitSizePerRack) { + super.setMinSplitSizeRack(maxSplitSize) + } super.setMaxSplitSize(maxSplitSize) } } diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala new file mode 100644 index 0000000000000..817dc082b7d38 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.input + +import java.io.{DataOutputStream, File, FileOutputStream} + +import scala.collection.immutable.IndexedSeq + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +/** + * Tests the correctness of + * [[org.apache.spark.input.WholeTextFileInputFormat WholeTextFileInputFormat]]. A temporary + * directory containing files is created as fake input which is deleted in the end. + */ +class WholeTextFileInputFormatSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { + private var sc: SparkContext = _ + + override def beforeAll() { + super.beforeAll() + val conf = new SparkConf() + sc = new SparkContext("local", "test", conf) + } + + override def afterAll() { + try { + sc.stop() + } finally { + super.afterAll() + } + } + + private def createNativeFile(inputDir: File, fileName: String, contents: Array[Byte], + compress: Boolean) = { + val path = s"${inputDir.toString}/$fileName" + val out = new DataOutputStream(new FileOutputStream(path)) + out.write(contents, 0, contents.length) + out.close() + } + + test("for small files minimum split size per node and per rack should be less than or equal to " + + "maximum split size.") { + var dir : File = null; + try { + dir = Utils.createTempDir() + logInfo(s"Local disk address is ${dir.toString}.") + + // Set the minsize per node and rack to be larger than the size of the input file. + sc.hadoopConfiguration.setLong( + "mapreduce.input.fileinputformat.split.minsize.per.node", 123456) + sc.hadoopConfiguration.setLong( + "mapreduce.input.fileinputformat.split.minsize.per.rack", 123456) + + WholeTextFileInputFormatSuite.files.foreach { case (filename, contents) => + createNativeFile(dir, filename, contents, false) + } + // ensure spark job runs successfully without exceptions from the CombineFileInputFormat + assert(sc.wholeTextFiles(dir.toString).count == 3) + } finally { + Utils.deleteRecursively(dir) + } + } +} + +/** + * Files to be tested are defined here. + */ +object WholeTextFileInputFormatSuite { + private val testWords: IndexedSeq[Byte] = "Spark is easy to use.\n".map(_.toByte) + + private val fileNames = Array("part-00000", "part-00001", "part-00002") + private val fileLengths = Array(10, 100, 1000) + + private val files = fileLengths.zip(fileNames).map { case (upperBound, filename) => + filename -> Stream.continually(testWords.toList.toStream).flatten.take(upperBound).toArray + }.toMap +} From 395860a986987886df6d60fd9b26afd818b2cb39 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 12 Jul 2018 13:55:25 -0700 Subject: [PATCH 1108/2461] [SPARK-24768][SQL] Have a built-in AVRO data source implementation ## What changes were proposed in this pull request? Apache Avro (https://avro.apache.org) is a popular data serialization format. It is widely used in the Spark and Hadoop ecosystem, especially for Kafka-based data pipelines. Using the external package https://github.com/databricks/spark-avro, Spark SQL can read and write the avro data. Making spark-Avro built-in can provide a better experience for first-time users of Spark SQL and structured streaming. We expect the built-in Avro data source can further improve the adoption of structured streaming. The proposal is to inline code from spark-avro package (https://github.com/databricks/spark-avro). The target release is Spark 2.4. [Built-in AVRO Data Source In Spark 2.4.pdf](https://github.com/apache/spark/files/2181511/Built-in.AVRO.Data.Source.In.Spark.2.4.pdf) ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21742 from gengliangwang/export_avro. --- dev/run-tests.py | 2 +- dev/sparktestsupport/modules.py | 10 + external/avro/pom.xml | 73 ++ ...pache.spark.sql.sources.DataSourceRegister | 1 + .../spark/sql/avro/AvroFileFormat.scala | 289 +++++++ .../spark/sql/avro/AvroOutputWriter.scala | 164 ++++ .../sql/avro/AvroOutputWriterFactory.scala | 38 + .../spark/sql/avro/SchemaConverters.scala | 406 +++++++++ .../org/apache/spark/sql/avro/package.scala | 39 + .../avro/src/test/resources/episodes.avro | Bin 0 -> 597 bytes .../avro/src/test/resources/log4j.properties | 49 ++ .../test-random-partitioned/part-r-00000.avro | Bin 0 -> 1768 bytes .../test-random-partitioned/part-r-00001.avro | Bin 0 -> 2313 bytes .../test-random-partitioned/part-r-00002.avro | Bin 0 -> 1621 bytes .../test-random-partitioned/part-r-00003.avro | Bin 0 -> 2117 bytes .../test-random-partitioned/part-r-00004.avro | Bin 0 -> 3282 bytes .../test-random-partitioned/part-r-00005.avro | Bin 0 -> 1550 bytes .../test-random-partitioned/part-r-00006.avro | Bin 0 -> 1729 bytes .../test-random-partitioned/part-r-00007.avro | Bin 0 -> 1897 bytes .../test-random-partitioned/part-r-00008.avro | Bin 0 -> 3420 bytes .../test-random-partitioned/part-r-00009.avro | Bin 0 -> 1796 bytes .../test-random-partitioned/part-r-00010.avro | Bin 0 -> 3872 bytes external/avro/src/test/resources/test.avro | Bin 0 -> 1365 bytes external/avro/src/test/resources/test.avsc | 53 ++ external/avro/src/test/resources/test.json | 42 + .../org/apache/spark/sql/avro/AvroSuite.scala | 812 ++++++++++++++++++ .../avro/SerializableConfigurationSuite.scala | 50 ++ .../org/apache/spark/sql/avro/TestUtils.scala | 156 ++++ pom.xml | 1 + project/SparkBuild.scala | 12 +- 30 files changed, 2191 insertions(+), 6 deletions(-) create mode 100644 external/avro/pom.xml create mode 100644 external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister create mode 100755 external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala create mode 100644 external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala create mode 100644 external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala create mode 100644 external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala create mode 100755 external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala create mode 100644 external/avro/src/test/resources/episodes.avro create mode 100644 external/avro/src/test/resources/log4j.properties create mode 100755 external/avro/src/test/resources/test-random-partitioned/part-r-00000.avro create mode 100755 external/avro/src/test/resources/test-random-partitioned/part-r-00001.avro create mode 100755 external/avro/src/test/resources/test-random-partitioned/part-r-00002.avro create mode 100755 external/avro/src/test/resources/test-random-partitioned/part-r-00003.avro create mode 100755 external/avro/src/test/resources/test-random-partitioned/part-r-00004.avro create mode 100755 external/avro/src/test/resources/test-random-partitioned/part-r-00005.avro create mode 100755 external/avro/src/test/resources/test-random-partitioned/part-r-00006.avro create mode 100755 external/avro/src/test/resources/test-random-partitioned/part-r-00007.avro create mode 100755 external/avro/src/test/resources/test-random-partitioned/part-r-00008.avro create mode 100755 external/avro/src/test/resources/test-random-partitioned/part-r-00009.avro create mode 100755 external/avro/src/test/resources/test-random-partitioned/part-r-00010.avro create mode 100644 external/avro/src/test/resources/test.avro create mode 100644 external/avro/src/test/resources/test.avsc create mode 100644 external/avro/src/test/resources/test.json create mode 100644 external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala create mode 100755 external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableConfigurationSuite.scala create mode 100755 external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala diff --git a/dev/run-tests.py b/dev/run-tests.py index cd4590864b7d7..d9d3789ac1255 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -110,7 +110,7 @@ def determine_modules_to_test(changed_modules): ['graphx', 'examples'] >>> x = [x.name for x in determine_modules_to_test([modules.sql])] >>> x # doctest: +NORMALIZE_WHITESPACE - ['sql', 'hive', 'mllib', 'sql-kafka-0-10', 'examples', 'hive-thriftserver', + ['sql', 'avro', 'hive', 'mllib', 'sql-kafka-0-10', 'examples', 'hive-thriftserver', 'pyspark-sql', 'repl', 'sparkr', 'pyspark-mllib', 'pyspark-ml'] """ modules_to_test = set() diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index dfea762db98c6..2aa355504bf29 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -170,6 +170,16 @@ def __hash__(self): ] ) +avro = Module( + name="avro", + dependencies=[sql], + source_file_regexes=[ + "external/avro", + ], + sbt_test_goals=[ + "avro/test", + ] +) sql_kafka = Module( name="sql-kafka-0-10", diff --git a/external/avro/pom.xml b/external/avro/pom.xml new file mode 100644 index 0000000000000..42e865bc38824 --- /dev/null +++ b/external/avro/pom.xml @@ -0,0 +1,73 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.4.0-SNAPSHOT + ../../pom.xml + + + spark-sql-avro_2.11 + + avro + + jar + Spark Avro + http://spark.apache.org/ + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-tags_${scala.binary.version} + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000000..95835f0d4ca49 --- /dev/null +++ b/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.sql.avro.AvroFileFormat diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala new file mode 100755 index 0000000000000..46e5a189c5eb3 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io._ +import java.net.URI +import java.util.zip.Deflater + +import scala.util.control.NonFatal + +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.avro.{Schema, SchemaBuilder} +import org.apache.avro.file.{DataFileConstants, DataFileReader} +import org.apache.avro.generic.{GenericDatumReader, GenericRecord} +import org.apache.avro.mapred.{AvroOutputFormat, FsInput} +import org.apache.avro.mapreduce.AvroJob +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.mapreduce.Job +import org.slf4j.LoggerFactory + +import org.apache.spark.TaskContext +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile} +import org.apache.spark.sql.sources.{DataSourceRegister, Filter} +import org.apache.spark.sql.types.StructType + +private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { + private val log = LoggerFactory.getLogger(getClass) + + override def equals(other: Any): Boolean = other match { + case _: AvroFileFormat => true + case _ => false + } + + // Dummy hashCode() to appease ScalaStyle. + override def hashCode(): Int = super.hashCode() + + override def inferSchema( + spark: SparkSession, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + val conf = spark.sparkContext.hadoopConfiguration + + // Schema evolution is not supported yet. Here we only pick a single random sample file to + // figure out the schema of the whole dataset. + val sampleFile = + if (conf.getBoolean(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, true)) { + files.find(_.getPath.getName.endsWith(".avro")).getOrElse { + throw new FileNotFoundException( + "No Avro files found. Hadoop option \"avro.mapred.ignore.inputs.without.extension\" " + + " is set to true. Do all input files have \".avro\" extension?" + ) + } + } else { + files.headOption.getOrElse { + throw new FileNotFoundException("No Avro files found.") + } + } + + // User can specify an optional avro json schema. + val avroSchema = options.get(AvroFileFormat.AvroSchema) + .map(new Schema.Parser().parse) + .getOrElse { + val in = new FsInput(sampleFile.getPath, conf) + try { + val reader = DataFileReader.openReader(in, new GenericDatumReader[GenericRecord]()) + try { + reader.getSchema + } finally { + reader.close() + } + } finally { + in.close() + } + } + + SchemaConverters.toSqlType(avroSchema).dataType match { + case t: StructType => Some(t) + case _ => throw new RuntimeException( + s"""Avro schema cannot be converted to a Spark SQL StructType: + | + |${avroSchema.toString(true)} + |""".stripMargin) + } + } + + override def shortName(): String = "avro" + + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = true + + override def prepareWrite( + spark: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + val recordName = options.getOrElse("recordName", "topLevelRecord") + val recordNamespace = options.getOrElse("recordNamespace", "") + val build = SchemaBuilder.record(recordName).namespace(recordNamespace) + val outputAvroSchema = SchemaConverters.convertStructToAvro(dataSchema, build, recordNamespace) + + AvroJob.setOutputKeySchema(job, outputAvroSchema) + val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec" + val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level" + val COMPRESS_KEY = "mapred.output.compress" + + spark.conf.get(AVRO_COMPRESSION_CODEC, "snappy") match { + case "uncompressed" => + log.info("writing uncompressed Avro records") + job.getConfiguration.setBoolean(COMPRESS_KEY, false) + + case "snappy" => + log.info("compressing Avro output using Snappy") + job.getConfiguration.setBoolean(COMPRESS_KEY, true) + job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, DataFileConstants.SNAPPY_CODEC) + + case "deflate" => + val deflateLevel = spark.conf.get( + AVRO_DEFLATE_LEVEL, Deflater.DEFAULT_COMPRESSION.toString).toInt + log.info(s"compressing Avro output using deflate (level=$deflateLevel)") + job.getConfiguration.setBoolean(COMPRESS_KEY, true) + job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, DataFileConstants.DEFLATE_CODEC) + job.getConfiguration.setInt(AvroOutputFormat.DEFLATE_LEVEL_KEY, deflateLevel) + + case unknown: String => + log.error(s"unsupported compression codec $unknown") + } + + new AvroOutputWriterFactory(dataSchema, recordName, recordNamespace) + } + + override def buildReader( + spark: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + + val broadcastedConf = + spark.sparkContext.broadcast(new AvroFileFormat.SerializableConfiguration(hadoopConf)) + + (file: PartitionedFile) => { + val log = LoggerFactory.getLogger(classOf[AvroFileFormat]) + val conf = broadcastedConf.value.value + val userProvidedSchema = options.get(AvroFileFormat.AvroSchema).map(new Schema.Parser().parse) + + // TODO Removes this check once `FileFormat` gets a general file filtering interface method. + // Doing input file filtering is improper because we may generate empty tasks that process no + // input files but stress the scheduler. We should probably add a more general input file + // filtering mechanism for `FileFormat` data sources. See SPARK-16317. + if ( + conf.getBoolean(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, true) && + !file.filePath.endsWith(".avro") + ) { + Iterator.empty + } else { + val reader = { + val in = new FsInput(new Path(new URI(file.filePath)), conf) + try { + val datumReader = userProvidedSchema match { + case Some(userSchema) => new GenericDatumReader[GenericRecord](userSchema) + case _ => new GenericDatumReader[GenericRecord]() + } + DataFileReader.openReader(in, datumReader) + } catch { + case NonFatal(e) => + log.error("Exception while opening DataFileReader", e) + in.close() + throw e + } + } + + // Ensure that the reader is closed even if the task fails or doesn't consume the entire + // iterator of records. + Option(TaskContext.get()).foreach { taskContext => + taskContext.addTaskCompletionListener { _ => + reader.close() + } + } + + reader.sync(file.start) + val stop = file.start + file.length + + val rowConverter = SchemaConverters.createConverterToSQL( + userProvidedSchema.getOrElse(reader.getSchema), requiredSchema) + + new Iterator[InternalRow] { + // Used to convert `Row`s containing data columns into `InternalRow`s. + private val encoderForDataColumns = RowEncoder(requiredSchema) + + private[this] var completed = false + + override def hasNext: Boolean = { + if (completed) { + false + } else { + val r = reader.hasNext && !reader.pastSync(stop) + if (!r) { + reader.close() + completed = true + } + r + } + } + + override def next(): InternalRow = { + if (reader.pastSync(stop)) { + throw new NoSuchElementException("next on empty iterator") + } + val record = reader.next() + val safeDataRow = rowConverter(record).asInstanceOf[GenericRow] + + // The safeDataRow is reused, we must do a copy + encoderForDataColumns.toRow(safeDataRow) + } + } + } + } + } +} + +private[avro] object AvroFileFormat { + val IgnoreFilesWithoutExtensionProperty = "avro.mapred.ignore.inputs.without.extension" + + val AvroSchema = "avroSchema" + + class SerializableConfiguration(@transient var value: Configuration) + extends Serializable with KryoSerializable { + @transient private[avro] lazy val log = LoggerFactory.getLogger(getClass) + + private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException { + out.defaultWriteObject() + value.write(out) + } + + private def readObject(in: ObjectInputStream): Unit = tryOrIOException { + value = new Configuration(false) + value.readFields(in) + } + + private def tryOrIOException[T](block: => T): T = { + try { + block + } catch { + case e: IOException => + log.error("Exception encountered", e) + throw e + case NonFatal(e) => + log.error("Exception encountered", e) + throw new IOException(e) + } + } + + def write(kryo: Kryo, out: Output): Unit = { + val dos = new DataOutputStream(out) + value.write(dos) + dos.flush() + } + + def read(kryo: Kryo, in: Input): Unit = { + value = new Configuration(false) + value.readFields(new DataInputStream(in)) + } + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala new file mode 100644 index 0000000000000..830bf3c0570bf --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io.{IOException, OutputStream} +import java.nio.ByteBuffer +import java.sql.{Date, Timestamp} +import java.util.HashMap + +import scala.collection.immutable.Map + +import org.apache.avro.{Schema, SchemaBuilder} +import org.apache.avro.generic.GenericData.Record +import org.apache.avro.generic.GenericRecord +import org.apache.avro.mapred.AvroKey +import org.apache.avro.mapreduce.AvroKeyOutputFormat +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.NullWritable +import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext} + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.execution.datasources.OutputWriter +import org.apache.spark.sql.types._ + +// NOTE: This class is instantiated and used on executor side only, no need to be serializable. +private[avro] class AvroOutputWriter( + path: String, + context: TaskAttemptContext, + schema: StructType, + recordName: String, + recordNamespace: String) extends OutputWriter { + + private lazy val converter = createConverterToAvro(schema, recordName, recordNamespace) + // copy of the old conversion logic after api change in SPARK-19085 + private lazy val internalRowConverter = + CatalystTypeConverters.createToScalaConverter(schema).asInstanceOf[InternalRow => Row] + + /** + * Overrides the couple of methods responsible for generating the output streams / files so + * that the data can be correctly partitioned + */ + private val recordWriter: RecordWriter[AvroKey[GenericRecord], NullWritable] = + new AvroKeyOutputFormat[GenericRecord]() { + + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + new Path(path) + } + + @throws(classOf[IOException]) + override def getAvroFileOutputStream(c: TaskAttemptContext): OutputStream = { + val path = getDefaultWorkFile(context, ".avro") + path.getFileSystem(context.getConfiguration).create(path) + } + + }.getRecordWriter(context) + + override def write(internalRow: InternalRow): Unit = { + val row = internalRowConverter(internalRow) + val key = new AvroKey(converter(row).asInstanceOf[GenericRecord]) + recordWriter.write(key, NullWritable.get()) + } + + override def close(): Unit = recordWriter.close(context) + + /** + * This function constructs converter function for a given sparkSQL datatype. This is used in + * writing Avro records out to disk + */ + private def createConverterToAvro( + dataType: DataType, + structName: String, + recordNamespace: String): (Any) => Any = { + dataType match { + case BinaryType => (item: Any) => item match { + case null => null + case bytes: Array[Byte] => ByteBuffer.wrap(bytes) + } + case ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType | StringType | BooleanType => identity + case _: DecimalType => (item: Any) => if (item == null) null else item.toString + case TimestampType => (item: Any) => + if (item == null) null else item.asInstanceOf[Timestamp].getTime + case DateType => (item: Any) => + if (item == null) null else item.asInstanceOf[Date].getTime + case ArrayType(elementType, _) => + val elementConverter = createConverterToAvro( + elementType, + structName, + SchemaConverters.getNewRecordNamespace(elementType, recordNamespace, structName)) + (item: Any) => { + if (item == null) { + null + } else { + val sourceArray = item.asInstanceOf[Seq[Any]] + val sourceArraySize = sourceArray.size + val targetArray = new Array[Any](sourceArraySize) + var idx = 0 + while (idx < sourceArraySize) { + targetArray(idx) = elementConverter(sourceArray(idx)) + idx += 1 + } + targetArray + } + } + case MapType(StringType, valueType, _) => + val valueConverter = createConverterToAvro( + valueType, + structName, + SchemaConverters.getNewRecordNamespace(valueType, recordNamespace, structName)) + (item: Any) => { + if (item == null) { + null + } else { + val javaMap = new HashMap[String, Any]() + item.asInstanceOf[Map[String, Any]].foreach { case (key, value) => + javaMap.put(key, valueConverter(value)) + } + javaMap + } + } + case structType: StructType => + val builder = SchemaBuilder.record(structName).namespace(recordNamespace) + val schema: Schema = SchemaConverters.convertStructToAvro( + structType, builder, recordNamespace) + val fieldConverters = structType.fields.map(field => + createConverterToAvro( + field.dataType, + field.name, + SchemaConverters.getNewRecordNamespace(field.dataType, recordNamespace, field.name))) + (item: Any) => { + if (item == null) { + null + } else { + val record = new Record(schema) + val convertersIterator = fieldConverters.iterator + val fieldNamesIterator = dataType.asInstanceOf[StructType].fieldNames.iterator + val rowIterator = item.asInstanceOf[Row].toSeq.iterator + + while (convertersIterator.hasNext) { + val converter = convertersIterator.next() + record.put(fieldNamesIterator.next(), converter(rowIterator.next())) + } + record + } + } + } + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala new file mode 100644 index 0000000000000..5b2ce7d7d8e0f --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.hadoop.mapreduce.TaskAttemptContext + +import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory} +import org.apache.spark.sql.types.StructType + +private[avro] class AvroOutputWriterFactory( + schema: StructType, + recordName: String, + recordNamespace: String) extends OutputWriterFactory { + + override def getFileExtension(context: TaskAttemptContext): String = ".avro" + + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new AvroOutputWriter(path, context, schema, recordName, recordNamespace) + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala new file mode 100644 index 0000000000000..01f8c74982535 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -0,0 +1,406 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.nio.ByteBuffer +import java.sql.{Date, Timestamp} + +import scala.collection.JavaConverters._ + +import org.apache.avro.{Schema, SchemaBuilder} +import org.apache.avro.Schema.Type._ +import org.apache.avro.SchemaBuilder._ +import org.apache.avro.generic.{GenericData, GenericRecord} +import org.apache.avro.generic.GenericFixed + +import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.types._ + +/** + * This object contains method that are used to convert sparkSQL schemas to avro schemas and vice + * versa. + */ +object SchemaConverters { + + class IncompatibleSchemaException(msg: String, ex: Throwable = null) extends Exception(msg, ex) + + case class SchemaType(dataType: DataType, nullable: Boolean) + + /** + * This function takes an avro schema and returns a sql schema. + */ + def toSqlType(avroSchema: Schema): SchemaType = { + avroSchema.getType match { + case INT => SchemaType(IntegerType, nullable = false) + case STRING => SchemaType(StringType, nullable = false) + case BOOLEAN => SchemaType(BooleanType, nullable = false) + case BYTES => SchemaType(BinaryType, nullable = false) + case DOUBLE => SchemaType(DoubleType, nullable = false) + case FLOAT => SchemaType(FloatType, nullable = false) + case LONG => SchemaType(LongType, nullable = false) + case FIXED => SchemaType(BinaryType, nullable = false) + case ENUM => SchemaType(StringType, nullable = false) + + case RECORD => + val fields = avroSchema.getFields.asScala.map { f => + val schemaType = toSqlType(f.schema()) + StructField(f.name, schemaType.dataType, schemaType.nullable) + } + + SchemaType(StructType(fields), nullable = false) + + case ARRAY => + val schemaType = toSqlType(avroSchema.getElementType) + SchemaType( + ArrayType(schemaType.dataType, containsNull = schemaType.nullable), + nullable = false) + + case MAP => + val schemaType = toSqlType(avroSchema.getValueType) + SchemaType( + MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable), + nullable = false) + + case UNION => + if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { + // In case of a union with null, eliminate it and make a recursive call + val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL) + if (remainingUnionTypes.size == 1) { + toSqlType(remainingUnionTypes.head).copy(nullable = true) + } else { + toSqlType(Schema.createUnion(remainingUnionTypes.asJava)).copy(nullable = true) + } + } else avroSchema.getTypes.asScala.map(_.getType) match { + case Seq(t1) => + toSqlType(avroSchema.getTypes.get(0)) + case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => + SchemaType(LongType, nullable = false) + case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => + SchemaType(DoubleType, nullable = false) + case _ => + // Convert complex unions to struct types where field names are member0, member1, etc. + // This is consistent with the behavior when converting between Avro and Parquet. + val fields = avroSchema.getTypes.asScala.zipWithIndex.map { + case (s, i) => + val schemaType = toSqlType(s) + // All fields are nullable because only one of them is set at a time + StructField(s"member$i", schemaType.dataType, nullable = true) + } + + SchemaType(StructType(fields), nullable = false) + } + + case other => throw new IncompatibleSchemaException(s"Unsupported type $other") + } + } + + /** + * This function converts sparkSQL StructType into avro schema. This method uses two other + * converter methods in order to do the conversion. + */ + def convertStructToAvro[T]( + structType: StructType, + schemaBuilder: RecordBuilder[T], + recordNamespace: String): T = { + val fieldsAssembler: FieldAssembler[T] = schemaBuilder.fields() + structType.fields.foreach { field => + val newField = fieldsAssembler.name(field.name).`type`() + + if (field.nullable) { + convertFieldTypeToAvro(field.dataType, newField.nullable(), field.name, recordNamespace) + .noDefault + } else { + convertFieldTypeToAvro(field.dataType, newField, field.name, recordNamespace) + .noDefault + } + } + fieldsAssembler.endRecord() + } + + /** + * Returns a converter function to convert row in avro format to GenericRow of catalyst. + * + * @param sourceAvroSchema Source schema before conversion inferred from avro file by passed in + * by user. + * @param targetSqlType Target catalyst sql type after the conversion. + * @return returns a converter function to convert row in avro format to GenericRow of catalyst. + */ + private[avro] def createConverterToSQL( + sourceAvroSchema: Schema, + targetSqlType: DataType): AnyRef => AnyRef = { + + def createConverter(avroSchema: Schema, + sqlType: DataType, path: List[String]): AnyRef => AnyRef = { + val avroType = avroSchema.getType + (sqlType, avroType) match { + // Avro strings are in Utf8, so we have to call toString on them + case (StringType, STRING) | (StringType, ENUM) => + (item: AnyRef) => item.toString + // Byte arrays are reused by avro, so we have to make a copy of them. + case (IntegerType, INT) | (BooleanType, BOOLEAN) | (DoubleType, DOUBLE) | + (FloatType, FLOAT) | (LongType, LONG) => + identity + case (TimestampType, LONG) => + (item: AnyRef) => new Timestamp(item.asInstanceOf[Long]) + case (DateType, LONG) => + (item: AnyRef) => new Date(item.asInstanceOf[Long]) + case (BinaryType, FIXED) => + (item: AnyRef) => item.asInstanceOf[GenericFixed].bytes().clone() + case (BinaryType, BYTES) => + (item: AnyRef) => + val byteBuffer = item.asInstanceOf[ByteBuffer] + val bytes = new Array[Byte](byteBuffer.remaining) + byteBuffer.get(bytes) + bytes + case (struct: StructType, RECORD) => + val length = struct.fields.length + val converters = new Array[AnyRef => AnyRef](length) + val avroFieldIndexes = new Array[Int](length) + var i = 0 + while (i < length) { + val sqlField = struct.fields(i) + val avroField = avroSchema.getField(sqlField.name) + if (avroField != null) { + val converter = (item: AnyRef) => { + if (item == null) { + item + } else { + createConverter(avroField.schema, sqlField.dataType, path :+ sqlField.name)(item) + } + } + converters(i) = converter + avroFieldIndexes(i) = avroField.pos() + } else if (!sqlField.nullable) { + throw new IncompatibleSchemaException( + s"Cannot find non-nullable field ${sqlField.name} at path ${path.mkString(".")} " + + "in Avro schema\n" + + s"Source Avro schema: $sourceAvroSchema.\n" + + s"Target Catalyst type: $targetSqlType") + } + i += 1 + } + + (item: AnyRef) => + val record = item.asInstanceOf[GenericRecord] + val result = new Array[Any](length) + var i = 0 + while (i < converters.length) { + if (converters(i) != null) { + val converter = converters(i) + result(i) = converter(record.get(avroFieldIndexes(i))) + } + i += 1 + } + new GenericRow(result) + case (arrayType: ArrayType, ARRAY) => + val elementConverter = createConverter(avroSchema.getElementType, arrayType.elementType, + path) + val allowsNull = arrayType.containsNull + (item: AnyRef) => + item.asInstanceOf[java.lang.Iterable[AnyRef]].asScala.map { element => + if (element == null && !allowsNull) { + throw new RuntimeException(s"Array value at path ${path.mkString(".")} is not " + + "allowed to be null") + } else { + elementConverter(element) + } + } + case (mapType: MapType, MAP) if mapType.keyType == StringType => + val valueConverter = createConverter(avroSchema.getValueType, mapType.valueType, path) + val allowsNull = mapType.valueContainsNull + (item: AnyRef) => + item.asInstanceOf[java.util.Map[AnyRef, AnyRef]].asScala.map { case (k, v) => + if (v == null && !allowsNull) { + throw new RuntimeException(s"Map value at path ${path.mkString(".")} is not " + + "allowed to be null") + } else { + (k.toString, valueConverter(v)) + } + }.toMap + case (sqlType, UNION) => + if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { + val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL) + if (remainingUnionTypes.size == 1) { + createConverter(remainingUnionTypes.head, sqlType, path) + } else { + createConverter(Schema.createUnion(remainingUnionTypes.asJava), sqlType, path) + } + } else avroSchema.getTypes.asScala.map(_.getType) match { + case Seq(t1) => createConverter(avroSchema.getTypes.get(0), sqlType, path) + case Seq(a, b) if Set(a, b) == Set(INT, LONG) && sqlType == LongType => + (item: AnyRef) => + item match { + case l: java.lang.Long => l + case i: java.lang.Integer => new java.lang.Long(i.longValue()) + } + case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && sqlType == DoubleType => + (item: AnyRef) => + item match { + case d: java.lang.Double => d + case f: java.lang.Float => new java.lang.Double(f.doubleValue()) + } + case other => + sqlType match { + case t: StructType if t.fields.length == avroSchema.getTypes.size => + val fieldConverters = t.fields.zip(avroSchema.getTypes.asScala).map { + case (field, schema) => + createConverter(schema, field.dataType, path :+ field.name) + } + (item: AnyRef) => + val i = GenericData.get().resolveUnion(avroSchema, item) + val converted = new Array[Any](fieldConverters.length) + converted(i) = fieldConverters(i)(item) + new GenericRow(converted) + case _ => throw new IncompatibleSchemaException( + s"Cannot convert Avro schema to catalyst type because schema at path " + + s"${path.mkString(".")} is not compatible " + + s"(avroType = $other, sqlType = $sqlType). \n" + + s"Source Avro schema: $sourceAvroSchema.\n" + + s"Target Catalyst type: $targetSqlType") + } + } + case (left, right) => + throw new IncompatibleSchemaException( + s"Cannot convert Avro schema to catalyst type because schema at path " + + s"${path.mkString(".")} is not compatible (avroType = $right, sqlType = $left). \n" + + s"Source Avro schema: $sourceAvroSchema.\n" + + s"Target Catalyst type: $targetSqlType") + } + } + createConverter(sourceAvroSchema, targetSqlType, List.empty[String]) + } + + /** + * This function is used to convert some sparkSQL type to avro type. Note that this function won't + * be used to construct fields of avro record (convertFieldTypeToAvro is used for that). + */ + private def convertTypeToAvro[T]( + dataType: DataType, + schemaBuilder: BaseTypeBuilder[T], + structName: String, + recordNamespace: String): T = { + dataType match { + case ByteType => schemaBuilder.intType() + case ShortType => schemaBuilder.intType() + case IntegerType => schemaBuilder.intType() + case LongType => schemaBuilder.longType() + case FloatType => schemaBuilder.floatType() + case DoubleType => schemaBuilder.doubleType() + case _: DecimalType => schemaBuilder.stringType() + case StringType => schemaBuilder.stringType() + case BinaryType => schemaBuilder.bytesType() + case BooleanType => schemaBuilder.booleanType() + case TimestampType => schemaBuilder.longType() + case DateType => schemaBuilder.longType() + + case ArrayType(elementType, _) => + val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull) + val elementSchema = convertTypeToAvro(elementType, builder, structName, recordNamespace) + schemaBuilder.array().items(elementSchema) + + case MapType(StringType, valueType, _) => + val builder = getSchemaBuilder(dataType.asInstanceOf[MapType].valueContainsNull) + val valueSchema = convertTypeToAvro(valueType, builder, structName, recordNamespace) + schemaBuilder.map().values(valueSchema) + + case structType: StructType => + convertStructToAvro( + structType, + schemaBuilder.record(structName).namespace(recordNamespace), + recordNamespace) + + case other => throw new IncompatibleSchemaException(s"Unexpected type $dataType.") + } + } + + /** + * This function is used to construct fields of the avro record, where schema of the field is + * specified by avro representation of dataType. Since builders for record fields are different + * from those for everything else, we have to use a separate method. + */ + private def convertFieldTypeToAvro[T]( + dataType: DataType, + newFieldBuilder: BaseFieldTypeBuilder[T], + structName: String, + recordNamespace: String): FieldDefault[T, _] = { + dataType match { + case ByteType => newFieldBuilder.intType() + case ShortType => newFieldBuilder.intType() + case IntegerType => newFieldBuilder.intType() + case LongType => newFieldBuilder.longType() + case FloatType => newFieldBuilder.floatType() + case DoubleType => newFieldBuilder.doubleType() + case _: DecimalType => newFieldBuilder.stringType() + case StringType => newFieldBuilder.stringType() + case BinaryType => newFieldBuilder.bytesType() + case BooleanType => newFieldBuilder.booleanType() + case TimestampType => newFieldBuilder.longType() + case DateType => newFieldBuilder.longType() + + case ArrayType(elementType, _) => + val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull) + val elementSchema = convertTypeToAvro( + elementType, + builder, + structName, + getNewRecordNamespace(elementType, recordNamespace, structName)) + newFieldBuilder.array().items(elementSchema) + + case MapType(StringType, valueType, _) => + val builder = getSchemaBuilder(dataType.asInstanceOf[MapType].valueContainsNull) + val valueSchema = convertTypeToAvro( + valueType, + builder, + structName, + getNewRecordNamespace(valueType, recordNamespace, structName)) + newFieldBuilder.map().values(valueSchema) + + case structType: StructType => + convertStructToAvro( + structType, + newFieldBuilder.record(structName).namespace(s"$recordNamespace.$structName"), + s"$recordNamespace.$structName") + + case other => throw new IncompatibleSchemaException(s"Unexpected type $dataType.") + } + } + + /** + * Returns a new namespace depending on the data type of the element. + * If the data type is a StructType it returns the current namespace concatenated + * with the element name, otherwise it returns the current namespace as it is. + */ + private[avro] def getNewRecordNamespace( + elementDataType: DataType, + currentRecordNamespace: String, + elementName: String): String = { + + elementDataType match { + case StructType(_) => s"$currentRecordNamespace.$elementName" + case _ => currentRecordNamespace + } + } + + private def getSchemaBuilder(isNullable: Boolean): BaseTypeBuilder[Schema] = { + if (isNullable) { + SchemaBuilder.builder().nullable() + } else { + SchemaBuilder.builder() + } + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala new file mode 100755 index 0000000000000..b3c8a669cf820 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +package object avro { + /** + * Adds a method, `avro`, to DataFrameWriter that allows you to write avro files using + * the DataFileWriter + */ + implicit class AvroDataFrameWriter[T](writer: DataFrameWriter[T]) { + def avro: String => Unit = writer.format("avro").save + } + + /** + * Adds a method, `avro`, to DataFrameReader that allows you to read avro files using + * the DataFileReader + */ + implicit class AvroDataFrameReader(reader: DataFrameReader) { + def avro: String => DataFrame = reader.format("avro").load + + @scala.annotation.varargs + def avro(sources: String*): DataFrame = reader.format("avro").load(sources: _*) + } +} diff --git a/external/avro/src/test/resources/episodes.avro b/external/avro/src/test/resources/episodes.avro new file mode 100644 index 0000000000000000000000000000000000000000..58a028ce19e6a1e964465dba8f9fbf5c84b3c3e5 GIT binary patch literal 597 zcmZ`$$w~u35Ov8x#Dj=L5s_jL1qmVRM7@a%A}%2+f)b^isW@$Vx`*!0$Pn`jp8Nt& zeudv=PZl@uSoPMfKD&P$pU7gYWL|p#h4`N7Iwpz8*>)6pQu$8K5g4X3MNCVd^l+mi z^wPBcH1bcl6SZw%Sr`PO_y|f#X$L9-g z;dAr)w);_-ea$!*mc7p@CSd|NlpVELhMh<;4y8h|knQ6Gw{;CytVP*k1x_$Y;bL~} zP%34!WeX0_W;dkQhBBN}WGK8R1;wpeZEAH#z@;EmCg2I|28{bqD#NLaM@Nh~YM*GtY%NloTUNlnX1EJ+mu3l%44q~<0r;;U9FsVqoUvQjEaP0lY$ zQPNS$OUwoFOHzwV;vuSlf@ztlIVr_TR?*ck`QnnI%)E4<6jVQ)pOT*p)b5*_m4 zNi8l`fJiCirzs?7Bo-wmm!uXIE7j^CjLOU{$VrXQO)P*L2X_`kuq-jBG! zQZ16!g>a~NE!5`q4)_ z4$liO^cG152#C1yMoOBpsIolT{kr(v`xw3Q^aHBW4<*kTe|R?Ed|m!p0};Vgja0tv z^0NQSvV@mTJX+?d_R`?juLpDYZ?l{K!|lO@#j3M9R!GPRCFkb&$j^M&lzz*=zqXM# zsNrZ`5s?Tm34X+aF~ge-=lK+ z2L&D(4&@8abtV@!zo-l}lk%Cm$yRMq@$>t&Ket}CzUHQT%Vx^AYX|R6cVK-rRd8j` zOTR5yPp$nu6^}@m>DI|rUwOUhz@FC}U7Lg=b6d_{+`1-vk5BKc(#2Doxw0d#H~rkS z>B_}r%bU$JUshk+u`iG#hrO-m)Zq|*ZMK;OR@s{;vvnSzoJ1dQtC(q4u@wlk?es!+mtYw`gb3HR! zlFhy~`RA@k4N1A}alNE7Cs<_Xv%EKZvz-q_I6ZH2t+4v}K}xeex=j9SE@xNu_tbCuxjBt5tMljYE-QGu oc=J`GqRT&@xU6vJWtjc`@x7({=kF?=`0vk0?&b4-%A@BG085M0ApigX literal 0 HcmV?d00001 diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00001.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00001.avro new file mode 100755 index 0000000000000000000000000000000000000000..1ca623a07dcf36e222f502144212cf90a03bec5a GIT binary patch literal 2313 zcma)6c~}$o78OK9KqLrC6%`^;sR3jW5o8G>$QDr67$6ME1QN1fl3_6fWhtway&}sC zqGey|M%V-`OGT*CqEP|G9l;{{fM{5x3D7)RzxRFpYv!JN&$;I}b7ouv$xGDGSj^bQ*4{gQGzd zio_TezFaQT{8pG*Vu>u`D0EuTKY3#7NVvEox-5!(%_UOk01HQ;LxB`<#mYrkR4+GH z@`$7ekYFU4l^k*r7cW|Ro02gm>6Ga08m&C+V$bm3Nr=amBn($dfHa8uwZJmEY{4BO zi5~znk{U>-h@%8|cSG48aTj4nkD!iwh;M8iP%f@$Tk-8-XHOeI+;3LN-L#S&Qr(v8 zR1Vs!C4_U?3{(GFDznj7FR>IAswtO}q?ux;ocxU`<5|P@s;3j-Fq_a4s?zNs2m&^cy=$fyhTFzEv6sTZLfgWDR|cL3 zG48Ci+FwrC{Ujyx8e!}X0qtQ{9fVHofs%mruJk~Cy=($!pO$L~M^6y-3peruZ^V-I zw)e(IuGNj1%!0zWRJDzEaxEYD^Rq#l&4iY>^;;+FxEgx+-<}*%nKE~^`pgH+{qJ7a z`JDgqOP2H2H36CF{kvH-h4L^*imYQQpGa=kgX>RzY~PTvTV}g?)Nqukp{k&H zUSZ68`1Bv+_HDxF>0jmsXUC(Ouh)Ef;~BkB^dYTBKVu3@4|!WxMo92X$J& zd#0*t56R;WHs&Fmd(herUdpB(8xyD&u5BM*j%kxpT7@Df45&r-lEf zUmcIIag=_re-@+o2X0{fKslPOj;FlFoi^{B#?X0T*?VXlXc{Xn2b9{GXJTJ@koN6q1l*b<2dWF z&b(bO@}jc5qht1&+atDh$F~$YLf_?CD|_rnA4F+Gm958047*C9sP2sZ>sj&Plf)R~ zf~uQ0vd;Z}5_#rI*uBS1cfVaN%!8d9bO2DJSDG-MzrF9$pE=)=pY1FN;JKo=ur)JW zl6I!s#sW##8r#CEYcjp|wolmy`Ka4_Q~|g88%8bc159;moWCVxs05qXURB{>A6(6C zwqf%t-}W7GkP50S+g&nn*H+<{0NTZ)^r@K~SYsT&$*3GNh}h=lqkI;R@VY;$*x_8d zn#;#4zS{}W5-%htpjt|@RO@AZ?v-OrJPAKwe|FKiWU|XTR15DFKe@m4eghWvhWr}K z3-*Z;I5#4Y@!*CLwBS_((o#BRzjsd9(`q#1WRZ@dLRL~CK@M{df7zo(*LRB8ogH^A z|M}DW^=MLcOjlq28#O7Kx_+_UC!e!IOX@pWIp5CL&W?^pfAD4**!%O<`k{G(A)>U3 zv7yT~AF1PHq}>>zPPiFc-lUlC>2Yy7_}K|RcVWH%L-{dMTW>usaw3&(F>|>t`b7no l=Dn%ucfNM=JvwJ!aUC~mGJJtO{<_X0i7T92koxBj^iS{0*{lEn literal 0 HcmV?d00001 diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00002.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00002.avro new file mode 100755 index 0000000000000000000000000000000000000000..a12e9459e7461cc6a7d6bcdda9f6a5b76a6a840a GIT binary patch literal 1621 zcmeZI%3@>@Nh~YM*GtY%NloTUNlnX1EJ+mu3l%44q~<0r;;U9FsVqoUvQjEaP0lY$ zQPNS$OUwoFOHzwV;vuSlf@ztlIVr_TR?*ck`QnnI%)E4<6jVQ)pOT*p)b5*_m4 zNi8l`fJiCirzs?7Bo-wmm!uXIE7j^CjLOU{$VrXQO)P*L2X_`kuq-jBG! zQZ16!g>a~NE!5`h%K!X7&ekGnQ^-`@$(Y$#!+Dt3>Pb zy#Dsb4-0n0@I4gL6aT{hgPliQ``{nW^tAJ5=Jnggzpsg6bDbix@=|7Ms$$m!B_&S< zPO)1V4zm|1MpSO(cG*(!`x^JXt@F0}?a?nje5>BNcy1!2z{?pm8AmZ?*^C7Oak59|jSH8WVHPuDp(v{xm z^ZV_0J<48xdUa$};qTe}b+;CuU0D?S-QroGw)GtUzDq0%y5=m|^lU;k=i;`DNhOOI zCS6>2<)P2nmUga&%}QnWF3GVjXnq*coIY#DS3~u6W^O-4J&Mhyb>@87xl*vF>4A>6 z;gVRh$=N^GJX*vmZ2Ux_diSfEXFAUp6>j94`^|9Ef?D5`hc=|zT-x$1Jxkr9Yr$7KlU}z^&QOkJ+VyjO zfBwI%Zm(+(&P=-#{-00zX;e~XcUo8Z)5tkeh5?INo<5lp<29$Rhh@?NPN~W^ug=Ri z&Ig2YaNZ6$mLOg@`H;bKai&|#7fd=Wc*5scmQ6{b_t{uUZAIxvF^he*AJ0`y$uzP$ z=<%v0if^LzQOObmDe0&=T8zw}FNlbA+$(7Qtnur`VT<=NeO4EY3zF1S{PZL;p1QCL zNqRe6Y%I{OUhZ?^>c;fqAUnV3-KLRVmm@slvSPT+@yvSI%xANEBpHJ)26C(hAypM@Nh~YM*GtY%NloTUNlnX1EJ+mu3l%44q~<0r;;U9FsVqoUvQjEaP0lY$ zQPNS$OUwoFOHzwV;vuSlf@ztlIVr_TR?*ck`QnnI%)E4<6jVQ)pOT*p)b5*_m4 zNi8l`fJiCirzs?7Bo-wmm!uXIE7j^CjLOU{$VrXQO)P*L2X_`kuq-jBG! zQZ16!g>a~NE!5`` ztqlL{$Gh|rBb_3`Hg4j)#i_LHx}Pqu(f8Z--*<;{bZ}@bGgx+Ufr1e0Q4zMzM@~H} z6mCS^(2USvJorFW!e>re+rw15eU;CD-t%3Z8MVJSe2!>)fV+qDM9ZWVJ|}&{*t0KL zCAyVd2??9w+J8nhb!)Rs#{J@T77Fq%3+7cH3q2|5bf5dwy5}lst&U2uO7FXL!Dsbc9OlY=GA|Eh8uoJ?^bGE~;LD?+gmPv z^Afklj;Y@Id;ULr<~YHlz?MffGATuW&##9+m91?xq*|EfN2 ze{a6*YvJ38eZ2>kO`IQ|9mag^_?eguPlG0|I%;(+=Eu4v#fAWky|n zWA^>n*&M^Iy4{N_!!+C`{nYC_I4eN;>~5R%vyTci^@Hz)NNtHd(6srK!%VZ20n6>h zuYK4#^OcLa(}7p4+b-@pvi5_=1c?O+Q-63geYn=)rlKMfw6JsfDM9TIIy0oU&g3=X z2=&RG`(w)7SF!=ux;YPvyuQsKP;f@)<2Loll3^!;&UCY1(Yn4N@Zz_O&WTQGVn&X7 zA+~%|PrC>PFdS&Hbd3~AE)YAtv_HFN=f&Sl-l3%#E*Uys%brXup17=a-n1)zmy>p% zO}b{MyU*_X%l`bD-#3qHOLt$6|5|L$cj#oiUrg<LZo|eU=EAeREYq`*S9BVc`tu@)c;q=kZ`FDQ4YZt!mU-#|3 zaes~B+8E8#y}d6V?A!eHslC}5@qF8Y&sVz7*BY&z9wz_q&+p0XXA>nl7AU4N%(D7? z>TC4rTyqC`+dDshe0sZf?WO9pZ?jjYufMnXd6vk^k2#4ypFeu}{=FPS6tB-aU;fo1 zk#85@*&1w^D#FlynbE83;(`t!D z$8FCF_A1Wx=FYTwC-Qt}vryy1vaQ+zT&fQKUl_$@zBhKO?+88eU5Qz2lF>fh^+8wN zpGu^y5mt>BG)up{rPXHs*FCqK44%l9cln92dO)UCiOTjZJ)!# zdt!>(tQo0WMYgR9ZkQlt@VrUm^NnnSM=Km+7Yc7YbE*7dW!^&9DcXq(l2%uDp z++Dd_D#<1NvYnRlG>N1*rgu+^Zf#T$*G;^8`iP*;W`__NW@Y!BCRw+OY>B5|e)!Ye sU-|#a()9MT&QI2Uj%avsbZwiRZ2jBLE&b=_+WdL>P4nM}|LBDe08rC;u>b%7 literal 0 HcmV?d00001 diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00004.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00004.avro new file mode 100755 index 0000000000000000000000000000000000000000..af56dfc8083dc409b2afb5a1970897b148dd8ef8 GIT binary patch literal 3282 zcma)8c{o)2A8tX}B}>+H2s5agJ#raa$QF{dF~(q+1+zqqEh9@&mTXhmqAa%<`;wH9 ztR>7KSrXYol9aCP`nmT$zvsDso%4A=pZEQ|-*evcJjcx0V=n^_jOc?s0mr%^;2bUp zR}>77;M;_7aCZb6mdW7{;QhQ1fEwU~fMb1J09gP7LvPf01P%|~npy{4kqDFv4p6iA z|ErI~`yerH>#c46PVIt)uhUHsFwA-%g}~v&wpwCXS24Id%m)U?BYbcGN%ntINVFFU z0Y$^ScI0+!ZGl7>ihx*O4^93W{b6M5sJcM-2Tbiur3R1bk;AMBLC*PqOdSL)CEg$Mj>{s+=SQb?Z$4N zE7BL?qOdc&bI476W6{>z+!@0mudKgah24kU?8*N(iH)%>3HjX;2n+%JZ-H%e+kzzl zx$yvSerRVbY9mU3r8z`b&vG3z(1Tc5ZQSS@LQKp|bvJHxj4x_l`cayw_; z!M8W<)c;QHW{{wp`1+2&N0;O%3&2oF7;YnJ_JGCa2Lagt{%!;2;{)@9VqKxmSS)Hi z<~u?=hBhbGVI&@b-bn3VgS!vo?6>}f?NEWOlH^F*lTeZF%RO0TDkohJExBeo$22Q0 zl4KH*eUUf34%w1}4$lz?haGH_o&eM#q3=D)W_bp{!)w1NtG|L6k1>d|KpY}YgmSY% z^mM%r#fyPYD2ApVbV1N zM=(9`jMMem)3t{*L@o~M6}#3{G-dAXFDFRZ%{E#j@h3%_yu%0jokNEh>0vjs|p4kG1LJS-Z#&G6`F>(?lxsb z3hzsbiv$yjm++3_Zz$uWws32Cd?2p}&@fir_)WjZn=*K?iG1o-K8JA^RV8Slk#)## z+!L!o_)6!8ey@){y7F-}8XGYFBG*hMAy~0ms6UwbviE_bP0!70f5r|C`7bYN22H<& zu6>p~*w4x~kPa1TH){`=OuTrv)^|eitr2U0>*|1iNA2g3YV@h;*^ZzuZ|f*WlVZz; z^)I356Y39Ry;nX+6%fPYN|5hM1M2= zJ*Cn?Iq{3uVG%la)_we$VIf@3KHj|%s+y*v?FZudvlRm~;vZd%Nh_TbO8g zFiYNo`VeVS)2)}zMvWI4TAM7tr0%~Kqa>J``}*YTwA4-q&Vh+Rmj#f3bQXh=treiw zKag&hHee9mzU08b1UjDst^3tl&ENE)?so?GjH|FffWZ3--%i;RSF0_khrx{=F*DgD zpyq*=9o=;?2Nh{JXOjFIl{vk$Ybd^_WrdCtxNDQC*epl*K%C9NDZk*RFG7y*K`j535LX#%e4KEyk9 zKT6Q8dLMa7*Uk-LJaZ}%j2SuF1Y3>sdl=@TZ1=MC976F?Loq<1(|EL_<{3#bWrgn< z-`%=tPDY2yS+nwLc$u?@ z?Y=SzMt)N@D^9{W^vsITYmh74B=x8RS6!Fg$zgU!i!5igB$B9e%0V&aFC4!RHSn|pwHbP#9j2Hc76&kHGeL++bQ)1rU2nY^tno!u{8!g9NS z(ieF?F^eOy=Xj`ziEYUQTg;IlbX?gP%06q$s+)SrnD@olZj*G0NP$+058z9>JS`NILDeS0{+s z5OHrUQd+3Rfl3Q`>vrWlR!}rHlLnv+R!u18j2^Aloyg2?^CY6yTNy?U;j~c6G&&gaK7Ufe^U|}^4>z%L-A*$Qq z&tn$RW;|R?aTl8Ne;(aa#*tAJ(sFFSdd{Jd=(92AX_1}2{x`IjD#$7iDFkFUxIF%8 zb|qNPNm8ktn|SOrg1KMeh6CJGUfr!Jm_kMukItUvvvR5;#JUCuH8awjK{jHQ0xax; zRGT)RrxME7ucR1f)UpYKl+yJ!}mD*ZV;hGATFX9RGOBA}=Kf0%>;+zwY5^ z07=<1*~?XV@wb8!FBB@985DHyJ>mft?8$fx`*1?H;u1|u+CVtrONnrCH^<4v#+YL- z9(XURX_vQhD(f`0S2Iz$_$~{*%d4czmq>cHg=Ja^`x_YZ(diFFp!zA!dxtY6`SbIP zc}G;(OnKooQVycgqRaxQRTd$(Sz2|y8t=0ahX z(@d9=qd390A50A{l+m4w3Y-faJtLT}UujXp5}5~hW-`0@$F#%xI-*7g^zowr6WCN2 zxTkTep%qQ^I#qQAnb8*DAZETytBf1v`85Ci&h;)u-=QNs5}5Jn_EvwZy=DOmHmU_T z)>R^C~s-Zv7@T0(NO9YVvvz%4PQg$;fPVlv}byI z;qo~C>W3n_dMA1R)c}b|Db-&6koZOCmAH{G$Rk@1irJ@FRz8HL$nr}YrPAUIO@Nh~YM*GtY%NloTUNlnX1EJ+mu3l%44q~<0r;;U9FsVqoUvQjEaP0lY$ zQPNS$OUwoFOHzwV;vuSlf@ztlIVr_TR?*ck`QnnI%)E4<6jVQ)pOT*p)b5*_m4 zNi8l`fJiCirzs?7Bo-wmm!uXIE7j^CjLOU{$VrXQO)P*L2X_`kuq-jBG! zQZ16!g>a~NE!5`9vH?HLE}M8j+mM~2qt zX0sn#oOgD%;j>w{al?d+L}f=Fj$izL7&jYRTU!J1W+3WTSk$V~G;6_Bj(K5@!k%27 zi!-l?m^C%+JYgzm+IlEa`PiZAug%~0et%V{G27QqcKU4Z!;fpFGNpD2Wu?c=Q`;^7 zZ(eSSf&1(4aWV6=1eIet*74cf=X_P|VVbQTb8w2=kv%sfWlbEJ8RVDNO@4GJ#CTo9 zwLo1LcXe0Q3K@siU&|i#6?JPj6))6Tv3w?zZScu0Itu(LZc{(q)@n~ojWPE(dL!X- zCG}4D5qD08xVe!J*S@vex%Fq+>(|?*ENWGwB5G{s*)d7+g`QdMYvU^1Gq3*EysTrc zPrD8t`{#e3|K`8m-`QXHm55}v`L-FDOyPBxUi?x{TyYtC+S2pm2J^T&^O`*+zpuveu95jy_fFv5UCHL39j$%>vqVyL zCIwu+acT$G5=S1@6^vQJNA9gZx>MfW{?uM>asL?a{aYrUygcJrT=(>8Gao#=>;8JO ztL>ZZbCeQYZU;W~kel46crI!6qb_5HH18bUl0+$v2Z8x&;yoK$rtb;S(byUGZ;#Cb-g$A~!u#Lqyg4T?Gymf*vjEL0R{SL?whX)ee>nMct@rY} N^K;AJ|7S)|MF7-8R#E@} literal 0 HcmV?d00001 diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00006.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00006.avro new file mode 100755 index 0000000000000000000000000000000000000000..c326fc434bf181b897372b8ab750a03320145617 GIT binary patch literal 1729 zcmeZI%3@>@Nh~YM*GtY%NloTUNlnX1EJ+mu3l%44q~<0r;;U9FsVqoUvQjEaP0lY$ zQPNS$OUwoFOHzwV;vuSlf@ztlIVr_TR?*ck`QnnI%)E4<6jVQ)pOT*p)b5*_m4 zNi8l`fJiCirzs?7Bo-wmm!uXIE7j^CjLOU{$VrXQO)P*L2X_`kuq-jBG! zQZ16!g>a~NE!5`=-%k3&)dh9=mr>?d2A)g-gYZ?w7wiGk50y8vm7z8cPISTrNj;3TzQOqqIcu znv!?Nq*MAL%vq|Qf)}5!tUvm&+^_$-_{@CI8SU@y{{4N^nd` zmSerb>Q9Bc70y@mHFRFMd(3T$l8bd-ny}d006C$x?Lw;gGB3?0iE0U4xtHg7I>vTt zfQd}@UYGNVXMBG}?OtbJuqbfPyY1iYYabttE?*Z@{Wb7UnZ+rco1SOvdRFIcebrVz zZ~umi)fZMI{YsXc!@&4^SL=db=f7U|*DHB^!=Uu_+_Hlvheg+&X`SnyI5WC!P1+&7 zG!~BOkKJqo^~_r=9XDBLB+kn%U-4F^XwI*9=elM;^SXFh<=}0WHLF~gL{00OI&s<( z%Z)44TvkrK_DHGtLACMJC#t?K!Ru!6B`WHr^oKGuO9o7>Rb6J7KDY5&?@3{!Kbwpu zGfd$L6WrPMI463`wBOSTkIuRkvs^;pT6chI_Ob~NTa{kvOj~(L!+0i(kmz*J6Gz{0 zPF~YJSumt>k>lh}f6qxdrjN|p1bFOL+K7fSKQ(N%xcYq7qg&0<3^#UU#PfB`sjxDM z;CA!ee!l;C`uz33j@$p$tgHU}TKhil^ta-NJ(HQX9yN*8<`w!Cx_$4*!+$rk+uygY z`*J+{x_(-CW?SKdguJ@{2j8Z-WSjSo3%aTtlTa_o^(SKZXLO`ivT-3Y|Y58ndZ=RUUWq4-t5c=uUbB|wL}zkT9Jmp(#Q#&!pE{9^2`x#2|7Y*{ z`{1wo>vvD9&Gq(|{`)y2hMU%h<)pAkLb0|2== B!hHY$ literal 0 HcmV?d00001 diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00007.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00007.avro new file mode 100755 index 0000000000000000000000000000000000000000..279f36c317eb8763be6cd3202a88d852550d381e GIT binary patch literal 1897 zcmeZI%3@>@Nh~YM*GtY%NloTUNlnX1EJ+mu3l%44q~<0r;;U9FsVqoUvQjEaP0lY$ zQPNS$OUwoFOHzwV;vuSlf@ztlIVr_TR?*ck`QnnI%)E4<6jVQ)pOT*p)b5*_m4 zNi8l`fJiCirzs?7Bo-wmm!uXIE7j^CjLOU{$VrXQO)P*L2X_`kuq-jBG! zQZ16!g>a~NE!5`p)McTZ*Uy6Uy$j-bjwwT zC4VZ5B+o4$)x3ziS=T z%0fzRdy1_m*D1!@iy?>OCM{d_Ln5h(Gcr~)?-Y|&l*cxgrb7Zo+$k+VY=-C9^hGyM zFg;T7B;cjN+GVp+kVq}^)JF+3&0 zEF~(+Uvg4GO4l}d<(9JtO?=JsYjxwRzPUF)xS5xEz9({W64$1safVFi#X1I5-@oD>yCf!!48l1 zIj1K07SG+js-Z^0X7PQS=GiaR(qxV90YLNAAaN%zu`ce5G07_M-b5on|{uzbrH6y(lMS(DC^zZ-fGaFl%SX zhP3t5r#=bMd=^mXymHy8o^wL$d?u^x7MXSW>aDMn)Zcvl{_f@J&;0*ue_cEJWe+pY Rj1d0to!27ue*Q<#pa29S0zLo$ literal 0 HcmV?d00001 diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00008.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00008.avro new file mode 100755 index 0000000000000000000000000000000000000000..8d70f5d1274d4f68f9e6cef0fdb20c8857b81e52 GIT binary patch literal 3420 zcma)8c{r4N8!lTSvXc;#o$SjbyCHiW%!IL=F*6vA8O@BHD6&OE23eAwX)MW>wFMPf zQixDSwjx57sC>5LbI!TG>-zqB-{*dw`+n}<``+($y%$_O4%2bLd~jGfI2M6~vm%hL zXcz&>vjgGb?nqBq4r?HQ;O~V5XaI0XI2MNh$O14h&+VFk#1ou%r?x;>6cUZV12mw4 zfA#SM917#M)!O6l*9a_pi*A90VYd2cBpxrk+Y-jQ3d7xDI2fFO#Nh$Kvj3Grd3vFd z&Ym!@eYt&GyPyvY?Ty@84?tlEfZzjEICh)r_y2JJm*k%D;DfT>7!(%cyxVccc%#w# z#-M<`a*yVCz|1ad%c3(Hi*frWZ`-;oV7s$**%A3S*AXG7G`XyXJxZ}12WOGxL^s|zBMoh>stdH4{q6y=x*fp z&$^HAZQ8H@o!XrsIq%?GJN_QsU`JU142^=}x0411SnYfedfUL?ZNP9in7=dD)!7A$ zMQ_D?UufUZ&cqgsA|O4tQ~TH8!Na)tZ+&6=RKRW->==9~enaSj7=xHT?IHYe$!MA* z4A;d4Q_PC)x^YLWo<$R1sD!iWM`hE|u?qsZ+{gf+)s$uU{8i}3U)sf4iI!;9B>k$Q zC@@!pn5`ZCBOaXCEt5WV0lV z_m>(*S>b(ACz9~|4EdA?fiO`+M-lEiO=$WbU!M=>iLwtKLy30!WfVl8X`p)M4F@PB z5X5c{8qAOy?JSH&lIA)Y&j)36*zzydnVe9EBoClELNt5so0}23g9gF$pEXR%R7;+N zMR;ub8<3Qnv+l8#mi5;?(w558a!l?dgUivku6;_luTXcIw(FA|&b0`{!HCOyUp%om z`M$3*b!__bNOf&^6%wW$H84pFr6-GwR~o!ru$eWmcU6C+g?EEy~}kUm7g?Ieqfetnut%TU#6TCv$hm*xg*pl-pW- z&F|&z4GHy;SOa`xUTRjt7|tHkmv+ImBJ zjc7^NeqS=w#__gb0nt{o-2P4UNue1gaQZuyvi@?AG1MV4>%`DGSzqAS$R#>8fm4Q3 z^oj^}RIemvkz~ym)An%9+p5?swxi${SU?U{PcPsVF8_GBe#3WO$8U++Mx8meF$yOJ zkLI4jzMus0gUKQ(tiYyb>N>Oh@0G}VwLd;j&Hd7hAL)#-5;u^Y5shX?AC>N&^DSxc z$iBR)$!P$&#&r{-8>|z;Wg1^nqYbUrXj&WZcG9Uelwu>q_(0{Ceo}~!35!2Cp*G_= z(vcqBh?5WE(mJesuZFjqk_20;rFi?Z@ZlC+i6`oYn^iExltIdpcWS9h3q7Vi^S)%o zln)>HhxW5fEmlEuF)b!SamvF_Lbm)5$KFZKt2~{Wn08Gsi#gru7_Yl5i-Gd@RExb8hAW zI6U3my&m?#;P)L{9cGQ}-Z*>%8R7)iOr@iXW0#uTnEd+fuTW6O+wdZW@g_`$eXHU{ zRr!Zf9TY{AXa{|cr@m207J@I6_1khiiRwpa9gd3odl{*t8knnq5~#0Sbkf#Kh|M2uR`e@QLj^u*D2FmY0g^KRV`VPT<9wjXkTXkqjkN<>*?CrM%l zoNf)}wX{Qw+u}G9flXPL$!Uu05}tRnj7ZbA_H96nhe!VMooTK*`fM1-8$ zr4(ux_Y@S1M#ZGskWv{gQs@m@Z97aY zj|F(;Jr9;RyC9U3^Nek}VM%pqARM3S@+V2Aqj2C7`*f2aF1MYzSR?GLGEX9V2A*XV z2OC#94`J-I8cAlCm<%d6N>IyYBW|p={i&*tuhy#J(X%=qenCj&D~IzdR5IArHt)XE zfW7{wVZELk%KYuDrEzEZgd(J9t-_2PCYYq?o?= zJ~ha2kF?m^Tia4ip)2O4nSa#$fiHf@g15pWr*4xu^d=Be)YT{V7Ib~Zg?2r^8BxKe zW}6ULUktoD%H74s<^gg`mo{>5hER_xc|q!eKh_$BiMiJ00>721xx!{4IgY$v3>MN7lsf|&_OA>L^zU`pYW4K&QvNhlGwwUJ1b+8Se7JGooWR` zze;p|NIt?TACv5*ToB&b5yMr%dw($VkBLz5gUMD(8O!Y?_@ka{!ZVYEQkxdg`4$e* zX`3)M1fgjSf~-R^URMYc6ix`8YvcK$sKTT(cBoQ_Z|Rmj4Png{+!C9e6+ud<=p-Pj z<)Y`WmEPh{U`3jlIEM?5hVJ}1W_qWA#CH0^KvO-;`ygGSowP%G&>A{UhYUHDu~ z#)h~ZIeTqXA%;S)nb|zu?7yg?dL=YZjk%&1%+%&w`DeR|s-{L^V=KDl4%aOtDxwa2&!3I80O~1fV4P{#T7`uW9tukJ<)!mMVKOEPqz% zx1}WX?$wQ|!rvQjPyPI{a6zbPEk$BJXu9!Uz-A=zF>r)WlPudm%s=k(`F2`oJ;^F; zCL17RB4r!5VNB6D8Ykz_uKNPh!ZRnwhFBb{r6PNXAYi1h=l-zgV)8mpjYf{OTSzMWC@EdPR{12J8cZteE9v6^JDE&pZw@Nh~YM*GtY%NloTUNlnX1EJ+mu3l%44q~<0r;;U9FsVqoUvQjEaP0lY$ zQPNS$OUwoFOHzwV;vuSlf@ztlIVr_TR?*ck`QnnI%)E4<6jVQ)pOT*p)b5*_m4 zNi8l`fJiCirzs?7Bo-wmm!uXIE7j^CjLOU{$VrXQO)P*L2X_`kuq-jBG! zQZ16!g>a~NE!5`22aQTmZy8)zx%yUgv*mfR9Hz#OKZgr5jX2zt|(rq?gT}Z zscnrc+y`eQEL2+Yy=1xN{;zjtW~*8(zc9;Sm+}JH%!YOW?ea*4tr15=4oo~^^-axX zIuB!i;L2&2UKDocZ0t!t*%9nmIJ@iP#T5cJ+x0rn6$Nu`2(*6ed{Ss(SV zF5}$oRomS9v{vkL*TJ7(m=cZ7XkFJxW!HF}mtfLx)8p$GhGTbb9LR|Dmli7#++3+q zJdZuD;08zbS2iW5S9jN{>YH9zGpS%x#tj#~O&{_eDrKL%b@bepgM8m+74Zrr9O!B2 z<-4xTuN1jZJfQfJt9ZuD1dD{nEMd01o7+uNQ+*=&-WQsS>aEJ~uDKH$P?~;g&azCq z45!x~i?^H$V1F*Zul9Ghw)pvfJ7%4;{c+^s-sR_V{nqg*#})&n1-&XgeZ6<0OG5Tm zmi+$y)8Wy!l*cQ4&KT5v-kbk!-ksKytWyv9Z@2rAQXTbnwn!+0TgCf}CZ{WJDO@j= z{r65&Cwe-|R*U;Te(ZegEnZjm^UR!ie_m`oF2A$v$Mg2{cAr1&X?_#X`YhAx>$~gE z&)d}2{0$U5nz<@5V{iX4X~pZ+%8ego=M@zG{kpjR+GT5Z{lB}mo}BRE+smh)k4IZy zY-~Gu)xK`$cDWr+iCr--rtyn?x0cU*B;lj~r}pVu;m=#e_xnG&AKgB`{@bdZg-Opo SU)xgoulMfk?f=-(^9}$|N9DQz literal 0 HcmV?d00001 diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00010.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00010.avro new file mode 100755 index 0000000000000000000000000000000000000000..aedc7f7e0e61ce84828b3046992b525ec029f473 GIT binary patch literal 3872 zcma)8c{r5)8n*8oMByWhEJKLwq-2Ybi0oq;GqzbQGxnwWvLr^Ih^$4FT^ReGrNP*> z$QnX+V;hoV={U~$u5(@IulIfK=eh6a{=M)0Ue|l;o(nk@2iVI4B?Un_Kp`v+P#7GH zhMqlwJRwd{1UQ|=AAt6Cg920l9#9C%!vSy+fCM8BYc$jo4LY7W0AbEhxPvD^#oGU0 zeNVK9Gt%*(^_zcEJD?y3bVDc@dC-SLJv{}ETLLJU0K^IG0fwNV9-e@}i~p8#M!3PD zAOzU$MDE1aG3W(`W1t7?0nSJ?An+6wf;y!7{6F0PCHdQU;AvS5(iw#W9d|%T3>+n|7!0fsy!1nZ(lZYNi z?&Pc|`0q_8_1{x_6eQ3Qez4;|qZ?>@5delegFO$EW(_bo`XDq8f&a7t_V57vf>1Ef zJroLl5c3nE6GKN62e305ia1Q|zXneq=AQ4t7j{Ag9G8K%fn-TnuFE)AmF!%o7m8_>3Kau5L3y-I=(ED>mm!M$yCdDzskoqW+!#uGPI)j_u_-%sfAL&B z$#FSA>^&A!lC-c@_iJ#U)Hyk5-PlyDn*#YXiK|qv3mT(}e@$4<1PZ0J>2_2#h>>0n z3#^*Ix0K_D+#E*mgW7nt@;o3yECAjB>tZ_HXFLWFGPGc-~9?A&DRI` z)&!}ICJCRb2L;<~6$CECFm=b5Y_ZP_lDfICZgHvg0y0Jkcg7(4)4;k-Y`)s;>_-uQ z(=K5qvviavX(`OQReZ92T|&+-Au2#cy{lb(0dIkl>8=bENNNA|_R_h&C`5u8`d9LJ z^FW0lVrXquTH0#m>(Z)GthVw;zXWminHwTg5ecy(uiz$Cn~9pkj&MPCVtS!TwB?Aa zYtv|Fk)RP)N}sJJQ>xgjm^+xEa+bN}E%qhYC^IHEs<=H_AL3rhhM0esu9d_kbH-it z%}D7mRL~or6i!v^hns4XBQ3z1#j)!nHL;(1FZ|kW*zFTnU8)O$z+D$yL0mF}J=SY# z8yO^ZLLr+%c6$5k9Aa?_9=!efWo&O*UO%Ep{nIGq#oY}4w&6;{v_4hYP@u?@0n7zn zNm{oU?#DgvJ6%P+7%vsoHqp8Cwa53gs$1RGaeduD#VGN9c%F>eQVFLZ&OEnK)RzOl z6=%Ci_z|v|=-LELvB?qX9u^nDMNU20VSU85`UCgG@saDI!h_szXfFA#lrFfrzUWf^EW6rTQqoh? zoM}|GSyAANRvr7G?3fu(K_@%hQ1pjb&j?>&+3YKm5WPZ}m2RqzQnIvfEpO8hFxpzr z#Q?YBRj?d}1JaVLo=N+a#;iRyw>0;(z2>s4?;ddtjI5rpZA9r6#Is5>M@+1kVWdas z^1s~9V_g}jCEOD^Q;DrmB6g;bvCUQHX5?iDwF}tueBb%)n{m~QbqVpPrtJ(nr4is zG&EX%DA=lqUj&YQn~vm`Ode#)_Tk34jdZaF4h+a7N_FcAl{GwxDyPsJip@-O^k}%m zwhH$?XT&ZLoD~iy#@(fK1NsQk+^mo;nJnrF@n$O;E1t(KZ{~R$ZMa}|>4K6S-g>oo z<(e=Fl~boPb}$!-w=HE~h46CWw3XqOuG zF(;?gVH^?aEC{PWS4dK{7}!2ntmAMt@n*f0K~NHD&(FK;eJA7L!jpgWFNLyxpAC#- z*I!tuV(_>x;sN_a!_-;1;Q(@`(N9Ve6)}wB`_@@i>5;?OGK<&Fid!P{T_3lXQWcY# z;d%a`bN8C%xGdIE0PDmw4Be~rFyLxE%Xdi2XTCBj%%8A|4gOZ^i_9s==8yS^n#cwyofuuR4o0My#MX6xa_@nuKc?#R{b}*Q$`?aXO&jh#yoge4L{(TunPb9DL!Ujg7ZofW?}f=js1`+B z+ohgJRC4v?|1=tuLiTHi($g|I=rM6l+Rgw!(Ngp}1I?8A{C%%dvOu_&fK-JKOr>>L zeB-+Z@_0Apcqgo?pMFtyN@*%ZX;{-HE2i$5&^dO%Bw;1}UA%>b_%a4!;3mhj6?47N zYF_HcrfS8wh-mwWSQW{K8r0;D*%2|=r@I_qW-kk%5;B+2Uaa<(?Mabt2y3)+M#!I?v~QX3q^)h${tKE`!kRQY{*prL6bQ^jL`OyJD|du+3pv zpB+8&n_tdt#I|eYPy2$OTiKqh7)fqYq+<5qsEuVyPSDA$|JP|e0rNa zhF#IchM+ECYaGc!jet@O;;!jjr;2Fw9#>~bYP$DDs(R6aj$-q?#Puw(kv>z=WvI?w z|5)Bpu<(YCM}_f3pC~07{?$gGN#z-4ZWak|o7#iE0&!wVh=oz*%$-k1pi>&3Tsk%PPJfP^Q zls5EEbpp}5odC|@%M9yx$)V%Yl(IlY8Y(q?Etk;c@V=bR_GymBpDhbl_5tHvFrK2b z5=sZCp3>~)pwhyZu6~*tXEptKDUjM}FOqJ1C9#+R|HjAuvn@b&UcB32qe*!!?p#@Z zrlxUEqBL#Vj6Pv5!Bqk&RDJ)K1&igwqGaU(e1OGOXE~J}hDoZoBy@U~lFB5;^m>S} zw}PHu3~r4)lPdvc3{%1~r86|$;H+4el~FlAQ95NqcjpJDaH;UQ7W1he*X}jSUZA+CxZ;v6Onj`D2T)KPZylIPP3Mczr+zYMi@o!q zSh5Un+YD~}bl!_J?=HV&78257oWSt8iIsyf_FAcwC_A=TZU&Q^IfzEpeI5zbs-rx^ zuqH$?6f&i)caigHc?f0i{kaz!A3=<5-@ltXaXC_0bTq$#Gs~O*0A7>%}^qFfSY}w-7m)2%x0sPz~40*`PCb*-hn~*N}r}xF%#^p0) z=1D7hDrd3~;jfF9U!&v``FE^zg%t)&-f{f3C1Zq@J6p7oz*7AnG;_F$+&Jb!3)fR@ zO-SNjoTO?fF50rGAdcqepG$K)C3@Yxw?0ODOG+$9N8NGYE?bMgx--4MS3kD>K6t-U zT1}m33fh^PsozKo-kQeTu0MRS3H)>U-ZwQ*Cu?cR@RI@om-1YR0Yml>hx-{ZCxu7Nh_G literal 0 HcmV?d00001 diff --git a/external/avro/src/test/resources/test.avro b/external/avro/src/test/resources/test.avro new file mode 100644 index 0000000000000000000000000000000000000000..6425e2107e3043aa9cfa201e274994def048f4c5 GIT binary patch literal 1365 zcma)5F>ljA6n0E-VF?tWA~B%r2?0YCiG-;z)NLaOGz7 jeHj?Q?U~=SzKdX{aKW zikXR#fsq;UW8gQiGB6R*bpC|b&G z=}#yppBslbolPlT!3p(665u9|30HPXW$G4D0EUc4fy67@hsS=ICM@0oSIO6QAbg&lu;;;S)Af|h3X4MJva~d ze<@4h^J>~GW+HYAkE;$%3){w}S<=Q8F$D`Gx{-)?PV)ib~Oc!Gk!KfiIx(a zjHv^VGwz8d< zwfP{qISw^Wj_!Qi#3W)ws!7|%!+arZ1)P*Yl7!4$5xSlb5sbM`qy^;>0JD^GHMPfq z)n>dIY?!9v!kmxi#-FD*-hE9L8z1j9zCZf;{+IRp;=MutF@n`n(?hRdLvLeft{3yBowhYWHAU_ z0gxgX-F?_fibx!wNybSGTboT;z|z^n9PHiYC>AM_8IXx5vQ(2=R?RRB%U)Z*HKK4h zF=7(+`fK)b-P+sRJE~cnb0sPl*)boO_szDl@3%YVQO$hyMLjoH7Zw(3ruC_|$wG<( z7UcC{Z766;1>$5Egi17}Nl5*)g|;Swf@)Q*#E?hTf=TEO5yUe|Gu|?c+aqYvS9Ay^ lXtUQ{F4TbRx(ToRS;UvGeyA^vy3WXTMnfeIL^`MM<1a>lvsC~9 literal 0 HcmV?d00001 diff --git a/external/avro/src/test/resources/test.avsc b/external/avro/src/test/resources/test.avsc new file mode 100644 index 0000000000000..d7119a01f6aa0 --- /dev/null +++ b/external/avro/src/test/resources/test.avsc @@ -0,0 +1,53 @@ +{ + "type" : "record", + "name" : "test_schema", + "fields" : [{ + "name" : "string", + "type" : "string", + "doc" : "Meaningless string of characters" + }, { + "name" : "simple_map", + "type" : {"type": "map", "values": "int"} + }, { + "name" : "complex_map", + "type" : {"type": "map", "values": {"type": "map", "values": "string"}} + }, { + "name" : "union_string_null", + "type" : ["null", "string"] + }, { + "name" : "union_int_long_null", + "type" : ["int", "long", "null"] + }, { + "name" : "union_float_double", + "type" : ["float", "double"] + }, { + "name": "fixed3", + "type": {"type": "fixed", "size": 3, "name": "fixed3"} + }, { + "name": "fixed2", + "type": {"type": "fixed", "size": 2, "name": "fixed2"} + }, { + "name": "enum", + "type": { "type": "enum", + "name": "Suit", + "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"] + } + }, { + "name": "record", + "type": { + "type": "record", + "name": "record", + "aliases": ["RecordAlias"], + "fields" : [{ + "name": "value_field", + "type": "string" + }] + } + }, { + "name": "array_of_boolean", + "type": {"type": "array", "items": "boolean"} + }, { + "name": "bytes", + "type": "bytes" + }] +} diff --git a/external/avro/src/test/resources/test.json b/external/avro/src/test/resources/test.json new file mode 100644 index 0000000000000..780189a92b378 --- /dev/null +++ b/external/avro/src/test/resources/test.json @@ -0,0 +1,42 @@ +{ + "string": "OMG SPARK IS AWESOME", + "simple_map": {"abc": 1, "bcd": 7}, + "complex_map": {"key": {"a": "b", "c": "d"}}, + "union_string_null": {"string": "abc"}, + "union_int_long_null": {"int": 1}, + "union_float_double": {"float": 3.1415926535}, + "fixed3":"\u0002\u0003\u0004", + "fixed2":"\u0011\u0012", + "enum": "SPADES", + "record": {"value_field": "Two things are infinite: the universe and human stupidity; and I'm not sure about universe."}, + "array_of_boolean": [true, false, false], + "bytes": "\u0041\u0042\u0043" +} +{ + "string": "Terran is IMBA!", + "simple_map": {"mmm": 0, "qqq": 66}, + "complex_map": {"key": {"1": "2", "3": "4"}}, + "union_string_null": {"string": "123"}, + "union_int_long_null": {"long": 66}, + "union_float_double": {"double": 6.6666666666666}, + "fixed3":"\u0007\u0007\u0007", + "fixed2":"\u0001\u0002", + "enum": "CLUBS", + "record": {"value_field": "Life did not intend to make us perfect. Whoever is perfect belongs in a museum."}, + "array_of_boolean": [], + "bytes": "" +} +{ + "string": "The cake is a LIE!", + "simple_map": {}, + "complex_map": {"key": {}}, + "union_string_null": {"null": null}, + "union_int_long_null": {"null": null}, + "union_float_double": {"double": 0}, + "fixed3":"\u0011\u0022\u0009", + "fixed2":"\u0010\u0090", + "enum": "DIAMONDS", + "record": {"value_field": "TEST_STR123"}, + "array_of_boolean": [false], + "bytes": "\u0053" +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala new file mode 100644 index 0000000000000..c6c1e4051a4b3 --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -0,0 +1,812 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io._ +import java.nio.file.Files +import java.sql.{Date, Timestamp} +import java.util.{TimeZone, UUID} + +import scala.collection.JavaConverters._ + +import org.apache.avro.Schema +import org.apache.avro.Schema.{Field, Type} +import org.apache.avro.file.DataFileWriter +import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord} +import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} +import org.apache.commons.io.FileUtils + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql._ +import org.apache.spark.sql.avro.SchemaConverters.IncompatibleSchemaException +import org.apache.spark.sql.types._ + +class AvroSuite extends SparkFunSuite { + val episodesFile = "src/test/resources/episodes.avro" + val testFile = "src/test/resources/test.avro" + + private var spark: SparkSession = _ + + override protected def beforeAll(): Unit = { + super.beforeAll() + spark = SparkSession.builder() + .master("local[2]") + .appName("AvroSuite") + .config("spark.sql.files.maxPartitionBytes", 1024) + .getOrCreate() + } + + override protected def afterAll(): Unit = { + try { + spark.sparkContext.stop() + } finally { + super.afterAll() + } + } + + test("reading from multiple paths") { + val df = spark.read.avro(episodesFile, episodesFile) + assert(df.count == 16) + } + + test("reading and writing partitioned data") { + val df = spark.read.avro(episodesFile) + val fields = List("title", "air_date", "doctor") + for (field <- fields) { + TestUtils.withTempDir { dir => + val outputDir = s"$dir/${UUID.randomUUID}" + df.write.partitionBy(field).avro(outputDir) + val input = spark.read.avro(outputDir) + // makes sure that no fields got dropped. + // We convert Rows to Seqs in order to work around SPARK-10325 + assert(input.select(field).collect().map(_.toSeq).toSet === + df.select(field).collect().map(_.toSeq).toSet) + } + } + } + + test("request no fields") { + val df = spark.read.avro(episodesFile) + df.registerTempTable("avro_table") + assert(spark.sql("select count(*) from avro_table").collect().head === Row(8)) + } + + test("convert formats") { + TestUtils.withTempDir { dir => + val df = spark.read.avro(episodesFile) + df.write.parquet(dir.getCanonicalPath) + assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count) + } + } + + test("rearrange internal schema") { + TestUtils.withTempDir { dir => + val df = spark.read.avro(episodesFile) + df.select("doctor", "title").write.avro(dir.getCanonicalPath) + } + } + + test("test NULL avro type") { + TestUtils.withTempDir { dir => + val fields = Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + avroRec.put("null", null) + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + intercept[IncompatibleSchemaException] { + spark.read.avro(s"$dir.avro") + } + } + } + + test("union(int, long) is read as long") { + TestUtils.withTempDir { dir => + val avroSchema: Schema = { + val union = + Schema.createUnion(List(Schema.create(Type.INT), Schema.create(Type.LONG)).asJava) + val fields = Seq(new Field("field1", union, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + schema + } + + val datumWriter = new GenericDatumWriter[GenericRecord](avroSchema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(avroSchema, new File(s"$dir.avro")) + val rec1 = new GenericData.Record(avroSchema) + rec1.put("field1", 1.toLong) + dataFileWriter.append(rec1) + val rec2 = new GenericData.Record(avroSchema) + rec2.put("field1", 2) + dataFileWriter.append(rec2) + dataFileWriter.flush() + dataFileWriter.close() + val df = spark.read.avro(s"$dir.avro") + assert(df.schema.fields === Seq(StructField("field1", LongType, nullable = true))) + assert(df.collect().toSet == Set(Row(1L), Row(2L))) + } + } + + test("union(float, double) is read as double") { + TestUtils.withTempDir { dir => + val avroSchema: Schema = { + val union = + Schema.createUnion(List(Schema.create(Type.FLOAT), Schema.create(Type.DOUBLE)).asJava) + val fields = Seq(new Field("field1", union, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + schema + } + + val datumWriter = new GenericDatumWriter[GenericRecord](avroSchema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(avroSchema, new File(s"$dir.avro")) + val rec1 = new GenericData.Record(avroSchema) + rec1.put("field1", 1.toFloat) + dataFileWriter.append(rec1) + val rec2 = new GenericData.Record(avroSchema) + rec2.put("field1", 2.toDouble) + dataFileWriter.append(rec2) + dataFileWriter.flush() + dataFileWriter.close() + val df = spark.read.avro(s"$dir.avro") + assert(df.schema.fields === Seq(StructField("field1", DoubleType, nullable = true))) + assert(df.collect().toSet == Set(Row(1.toDouble), Row(2.toDouble))) + } + } + + test("union(float, double, null) is read as nullable double") { + TestUtils.withTempDir { dir => + val avroSchema: Schema = { + val union = Schema.createUnion( + List(Schema.create(Type.FLOAT), + Schema.create(Type.DOUBLE), + Schema.create(Type.NULL) + ).asJava + ) + val fields = Seq(new Field("field1", union, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + schema + } + + val datumWriter = new GenericDatumWriter[GenericRecord](avroSchema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(avroSchema, new File(s"$dir.avro")) + val rec1 = new GenericData.Record(avroSchema) + rec1.put("field1", 1.toFloat) + dataFileWriter.append(rec1) + val rec2 = new GenericData.Record(avroSchema) + rec2.put("field1", null) + dataFileWriter.append(rec2) + dataFileWriter.flush() + dataFileWriter.close() + val df = spark.read.avro(s"$dir.avro") + assert(df.schema.fields === Seq(StructField("field1", DoubleType, nullable = true))) + assert(df.collect().toSet == Set(Row(1.toDouble), Row(null))) + } + } + + test("Union of a single type") { + TestUtils.withTempDir { dir => + val UnionOfOne = Schema.createUnion(List(Schema.create(Type.INT)).asJava) + val fields = Seq(new Field("field1", UnionOfOne, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + + avroRec.put("field1", 8) + + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + val df = spark.read.avro(s"$dir.avro") + assert(df.first() == Row(8)) + } + } + + test("Complex Union Type") { + TestUtils.withTempDir { dir => + val fixedSchema = Schema.createFixed("fixed_name", "doc", "namespace", 4) + val enumSchema = Schema.createEnum("enum_name", "doc", "namespace", List("e1", "e2").asJava) + val complexUnionType = Schema.createUnion( + List(Schema.create(Type.INT), Schema.create(Type.STRING), fixedSchema, enumSchema).asJava) + val fields = Seq( + new Field("field1", complexUnionType, "doc", null), + new Field("field2", complexUnionType, "doc", null), + new Field("field3", complexUnionType, "doc", null), + new Field("field4", complexUnionType, "doc", null) + ).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + val field1 = 1234 + val field2 = "Hope that was not load bearing" + val field3 = Array[Byte](1, 2, 3, 4) + val field4 = "e2" + avroRec.put("field1", field1) + avroRec.put("field2", field2) + avroRec.put("field3", new Fixed(fixedSchema, field3)) + avroRec.put("field4", new EnumSymbol(enumSchema, field4)) + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + val df = spark.sqlContext.read.avro(s"$dir.avro") + assertResult(field1)(df.selectExpr("field1.member0").first().get(0)) + assertResult(field2)(df.selectExpr("field2.member1").first().get(0)) + assertResult(field3)(df.selectExpr("field3.member2").first().get(0)) + assertResult(field4)(df.selectExpr("field4.member3").first().get(0)) + } + } + + test("Lots of nulls") { + TestUtils.withTempDir { dir => + val schema = StructType(Seq( + StructField("binary", BinaryType, true), + StructField("timestamp", TimestampType, true), + StructField("array", ArrayType(ShortType), true), + StructField("map", MapType(StringType, StringType), true), + StructField("struct", StructType(Seq(StructField("int", IntegerType, true)))))) + val rdd = spark.sparkContext.parallelize(Seq[Row]( + Row(null, new Timestamp(1), Array[Short](1, 2, 3), null, null), + Row(null, null, null, null, null), + Row(null, null, null, null, null), + Row(null, null, null, null, null))) + val df = spark.createDataFrame(rdd, schema) + df.write.avro(dir.toString) + assert(spark.read.avro(dir.toString).count == rdd.count) + } + } + + test("Struct field type") { + TestUtils.withTempDir { dir => + val schema = StructType(Seq( + StructField("float", FloatType, true), + StructField("short", ShortType, true), + StructField("byte", ByteType, true), + StructField("boolean", BooleanType, true) + )) + val rdd = spark.sparkContext.parallelize(Seq( + Row(1f, 1.toShort, 1.toByte, true), + Row(2f, 2.toShort, 2.toByte, true), + Row(3f, 3.toShort, 3.toByte, true) + )) + val df = spark.createDataFrame(rdd, schema) + df.write.avro(dir.toString) + assert(spark.read.avro(dir.toString).count == rdd.count) + } + } + + test("Date field type") { + TestUtils.withTempDir { dir => + val schema = StructType(Seq( + StructField("float", FloatType, true), + StructField("date", DateType, true) + )) + TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + val rdd = spark.sparkContext.parallelize(Seq( + Row(1f, null), + Row(2f, new Date(1451948400000L)), + Row(3f, new Date(1460066400500L)) + )) + val df = spark.createDataFrame(rdd, schema) + df.write.avro(dir.toString) + assert(spark.read.avro(dir.toString).count == rdd.count) + assert(spark.read.avro(dir.toString).select("date").collect().map(_(0)).toSet == + Array(null, 1451865600000L, 1459987200000L).toSet) + } + } + + test("Array data types") { + TestUtils.withTempDir { dir => + val testSchema = StructType(Seq( + StructField("byte_array", ArrayType(ByteType), true), + StructField("short_array", ArrayType(ShortType), true), + StructField("float_array", ArrayType(FloatType), true), + StructField("bool_array", ArrayType(BooleanType), true), + StructField("long_array", ArrayType(LongType), true), + StructField("double_array", ArrayType(DoubleType), true), + StructField("decimal_array", ArrayType(DecimalType(10, 0)), true), + StructField("bin_array", ArrayType(BinaryType), true), + StructField("timestamp_array", ArrayType(TimestampType), true), + StructField("array_array", ArrayType(ArrayType(StringType), true), true), + StructField("struct_array", ArrayType( + StructType(Seq(StructField("name", StringType, true))))))) + + val arrayOfByte = new Array[Byte](4) + for (i <- arrayOfByte.indices) { + arrayOfByte(i) = i.toByte + } + + val rdd = spark.sparkContext.parallelize(Seq( + Row(arrayOfByte, Array[Short](1, 2, 3, 4), Array[Float](1f, 2f, 3f, 4f), + Array[Boolean](true, false, true, false), Array[Long](1L, 2L), Array[Double](1.0, 2.0), + Array[BigDecimal](BigDecimal.valueOf(3)), Array[Array[Byte]](arrayOfByte, arrayOfByte), + Array[Timestamp](new Timestamp(0)), + Array[Array[String]](Array[String]("CSH, tearing down the walls that divide us", "-jd")), + Array[Row](Row("Bobby G. can't swim"))))) + val df = spark.createDataFrame(rdd, testSchema) + df.write.avro(dir.toString) + assert(spark.read.avro(dir.toString).count == rdd.count) + } + } + + test("write with compression") { + TestUtils.withTempDir { dir => + val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec" + val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level" + val uncompressDir = s"$dir/uncompress" + val deflateDir = s"$dir/deflate" + val snappyDir = s"$dir/snappy" + val fakeDir = s"$dir/fake" + + val df = spark.read.avro(testFile) + spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed") + df.write.avro(uncompressDir) + spark.conf.set(AVRO_COMPRESSION_CODEC, "deflate") + spark.conf.set(AVRO_DEFLATE_LEVEL, "9") + df.write.avro(deflateDir) + spark.conf.set(AVRO_COMPRESSION_CODEC, "snappy") + df.write.avro(snappyDir) + + val uncompressSize = FileUtils.sizeOfDirectory(new File(uncompressDir)) + val deflateSize = FileUtils.sizeOfDirectory(new File(deflateDir)) + val snappySize = FileUtils.sizeOfDirectory(new File(snappyDir)) + + assert(uncompressSize > deflateSize) + assert(snappySize > deflateSize) + } + } + + test("dsl test") { + val results = spark.read.avro(episodesFile).select("title").collect() + assert(results.length === 8) + } + + test("support of various data types") { + // This test uses data from test.avro. You can see the data and the schema of this file in + // test.json and test.avsc + val all = spark.read.avro(testFile).collect() + assert(all.length == 3) + + val str = spark.read.avro(testFile).select("string").collect() + assert(str.map(_(0)).toSet.contains("Terran is IMBA!")) + + val simple_map = spark.read.avro(testFile).select("simple_map").collect() + assert(simple_map(0)(0).getClass.toString.contains("Map")) + assert(simple_map.map(_(0).asInstanceOf[Map[String, Some[Int]]].size).toSet == Set(2, 0)) + + val union0 = spark.read.avro(testFile).select("union_string_null").collect() + assert(union0.map(_(0)).toSet == Set("abc", "123", null)) + + val union1 = spark.read.avro(testFile).select("union_int_long_null").collect() + assert(union1.map(_(0)).toSet == Set(66, 1, null)) + + val union2 = spark.read.avro(testFile).select("union_float_double").collect() + assert( + union2 + .map(x => new java.lang.Double(x(0).toString)) + .exists(p => Math.abs(p - Math.PI) < 0.001)) + + val fixed = spark.read.avro(testFile).select("fixed3").collect() + assert(fixed.map(_(0).asInstanceOf[Array[Byte]]).exists(p => p(1) == 3)) + + val enum = spark.read.avro(testFile).select("enum").collect() + assert(enum.map(_(0)).toSet == Set("SPADES", "CLUBS", "DIAMONDS")) + + val record = spark.read.avro(testFile).select("record").collect() + assert(record(0)(0).getClass.toString.contains("Row")) + assert(record.map(_(0).asInstanceOf[Row](0)).contains("TEST_STR123")) + + val array_of_boolean = spark.read.avro(testFile).select("array_of_boolean").collect() + assert(array_of_boolean.map(_(0).asInstanceOf[Seq[Boolean]].size).toSet == Set(3, 1, 0)) + + val bytes = spark.read.avro(testFile).select("bytes").collect() + assert(bytes.map(_(0).asInstanceOf[Array[Byte]].length).toSet == Set(3, 1, 0)) + } + + test("sql test") { + spark.sql( + s""" + |CREATE TEMPORARY TABLE avroTable + |USING avro + |OPTIONS (path "$episodesFile") + """.stripMargin.replaceAll("\n", " ")) + + assert(spark.sql("SELECT * FROM avroTable").collect().length === 8) + } + + test("conversion to avro and back") { + // Note that test.avro includes a variety of types, some of which are nullable. We expect to + // get the same values back. + TestUtils.withTempDir { dir => + val avroDir = s"$dir/avro" + spark.read.avro(testFile).write.avro(avroDir) + TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir) + } + } + + test("conversion to avro and back with namespace") { + // Note that test.avro includes a variety of types, some of which are nullable. We expect to + // get the same values back. + TestUtils.withTempDir { tempDir => + val name = "AvroTest" + val namespace = "com.databricks.spark.avro" + val parameters = Map("recordName" -> name, "recordNamespace" -> namespace) + + val avroDir = tempDir + "/namedAvro" + spark.read.avro(testFile).write.options(parameters).avro(avroDir) + TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir) + + // Look at raw file and make sure has namespace info + val rawSaved = spark.sparkContext.textFile(avroDir) + val schema = rawSaved.collect().mkString("") + assert(schema.contains(name)) + assert(schema.contains(namespace)) + } + } + + test("converting some specific sparkSQL types to avro") { + TestUtils.withTempDir { tempDir => + val testSchema = StructType(Seq( + StructField("Name", StringType, false), + StructField("Length", IntegerType, true), + StructField("Time", TimestampType, false), + StructField("Decimal", DecimalType(10, 2), true), + StructField("Binary", BinaryType, false))) + + val arrayOfByte = new Array[Byte](4) + for (i <- arrayOfByte.indices) { + arrayOfByte(i) = i.toByte + } + val cityRDD = spark.sparkContext.parallelize(Seq( + Row("San Francisco", 12, new Timestamp(666), null, arrayOfByte), + Row("Palo Alto", null, new Timestamp(777), null, arrayOfByte), + Row("Munich", 8, new Timestamp(42), Decimal(3.14), arrayOfByte))) + val cityDataFrame = spark.createDataFrame(cityRDD, testSchema) + + val avroDir = tempDir + "/avro" + cityDataFrame.write.avro(avroDir) + assert(spark.read.avro(avroDir).collect().length == 3) + + // TimesStamps are converted to longs + val times = spark.read.avro(avroDir).select("Time").collect() + assert(times.map(_(0)).toSet == Set(666, 777, 42)) + + // DecimalType should be converted to string + val decimals = spark.read.avro(avroDir).select("Decimal").collect() + assert(decimals.map(_(0)).contains("3.14")) + + // There should be a null entry + val length = spark.read.avro(avroDir).select("Length").collect() + assert(length.map(_(0)).contains(null)) + + val binary = spark.read.avro(avroDir).select("Binary").collect() + for (i <- arrayOfByte.indices) { + assert(binary(1)(0).asInstanceOf[Array[Byte]](i) == arrayOfByte(i)) + } + } + } + + test("correctly read long as date/timestamp type") { + TestUtils.withTempDir { tempDir => + val sparkSession = spark + import sparkSession.implicits._ + + val currentTime = new Timestamp(System.currentTimeMillis()) + val currentDate = new Date(System.currentTimeMillis()) + val schema = StructType(Seq( + StructField("_1", DateType, false), StructField("_2", TimestampType, false))) + val writeDs = Seq((currentDate, currentTime)).toDS + + val avroDir = tempDir + "/avro" + writeDs.write.avro(avroDir) + assert(spark.read.avro(avroDir).collect().length == 1) + + val readDs = spark.read.schema(schema).avro(avroDir).as[(Date, Timestamp)] + + assert(readDs.collect().sameElements(writeDs.collect())) + } + } + + test("support of globbed paths") { + val e1 = spark.read.avro("*/test/resources/episodes.avro").collect() + assert(e1.length == 8) + + val e2 = spark.read.avro("src/*/*/episodes.avro").collect() + assert(e2.length == 8) + } + + test("does not coerce null date/timestamp value to 0 epoch.") { + TestUtils.withTempDir { tempDir => + val sparkSession = spark + import sparkSession.implicits._ + + val nullTime: Timestamp = null + val nullDate: Date = null + val schema = StructType(Seq( + StructField("_1", DateType, nullable = true), + StructField("_2", TimestampType, nullable = true)) + ) + val writeDs = Seq((nullDate, nullTime)).toDS + + val avroDir = tempDir + "/avro" + writeDs.write.avro(avroDir) + val readValues = spark.read.schema(schema).avro(avroDir).as[(Date, Timestamp)].collect + + assert(readValues.size == 1) + assert(readValues.head == ((nullDate, nullTime))) + } + } + + test("support user provided avro schema") { + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name" : "string", + | "type" : "string", + | "doc" : "Meaningless string of characters" + | }] + |} + """.stripMargin + val result = spark.read.option(AvroFileFormat.AvroSchema, avroSchema).avro(testFile).collect() + val expected = spark.read.avro(testFile).select("string").collect() + assert(result.sameElements(expected)) + } + + test("support user provided avro schema with defaults for missing fields") { + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name" : "missingField", + | "type" : "string", + | "default" : "foo" + | }] + |} + """.stripMargin + val result = spark.read.option(AvroFileFormat.AvroSchema, avroSchema) + .avro(testFile).select("missingField").first + assert(result === Row("foo")) + } + + test("reading from invalid path throws exception") { + + // Directory given has no avro files + intercept[AnalysisException] { + TestUtils.withTempDir(dir => spark.read.avro(dir.getCanonicalPath)) + } + + intercept[AnalysisException] { + spark.read.avro("very/invalid/path/123.avro") + } + + // In case of globbed path that can't be matched to anything, another exception is thrown (and + // exception message is helpful) + intercept[AnalysisException] { + spark.read.avro("*/*/*/*/*/*/*/something.avro") + } + + intercept[FileNotFoundException] { + TestUtils.withTempDir { dir => + FileUtils.touch(new File(dir, "test")) + spark.read.avro(dir.toString) + } + } + + } + + test("SQL test insert overwrite") { + TestUtils.withTempDir { tempDir => + val tempEmptyDir = s"$tempDir/sqlOverwrite" + // Create a temp directory for table that will be overwritten + new File(tempEmptyDir).mkdirs() + spark.sql( + s""" + |CREATE TEMPORARY TABLE episodes + |USING avro + |OPTIONS (path "$episodesFile") + """.stripMargin.replaceAll("\n", " ")) + spark.sql( + s""" + |CREATE TEMPORARY TABLE episodesEmpty + |(name string, air_date string, doctor int) + |USING avro + |OPTIONS (path "$tempEmptyDir") + """.stripMargin.replaceAll("\n", " ")) + + assert(spark.sql("SELECT * FROM episodes").collect().length === 8) + assert(spark.sql("SELECT * FROM episodesEmpty").collect().isEmpty) + + spark.sql( + s""" + |INSERT OVERWRITE TABLE episodesEmpty + |SELECT * FROM episodes + """.stripMargin.replaceAll("\n", " ")) + assert(spark.sql("SELECT * FROM episodesEmpty").collect().length == 8) + } + } + + test("test save and load") { + // Test if load works as expected + TestUtils.withTempDir { tempDir => + val df = spark.read.avro(episodesFile) + assert(df.count == 8) + + val tempSaveDir = s"$tempDir/save/" + + df.write.avro(tempSaveDir) + val newDf = spark.read.avro(tempSaveDir) + assert(newDf.count == 8) + } + } + + test("test load with non-Avro file") { + // Test if load works as expected + TestUtils.withTempDir { tempDir => + val df = spark.read.avro(episodesFile) + assert(df.count == 8) + + val tempSaveDir = s"$tempDir/save/" + df.write.avro(tempSaveDir) + + Files.createFile(new File(tempSaveDir, "non-avro").toPath) + + val newDf = spark + .read + .option(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true") + .avro(tempSaveDir) + + assert(newDf.count == 8) + } + } + + test("read avro with user defined schema: read partial columns") { + val partialColumns = StructType(Seq( + StructField("string", StringType, false), + StructField("simple_map", MapType(StringType, IntegerType), false), + StructField("complex_map", MapType(StringType, MapType(StringType, StringType)), false), + StructField("union_string_null", StringType, true), + StructField("union_int_long_null", LongType, true), + StructField("fixed3", BinaryType, true), + StructField("fixed2", BinaryType, true), + StructField("enum", StringType, false), + StructField("record", StructType(Seq(StructField("value_field", StringType, false))), false), + StructField("array_of_boolean", ArrayType(BooleanType), false), + StructField("bytes", BinaryType, true))) + val withSchema = spark.read.schema(partialColumns).avro(testFile).collect() + val withOutSchema = spark + .read + .avro(testFile) + .select("string", "simple_map", "complex_map", "union_string_null", "union_int_long_null", + "fixed3", "fixed2", "enum", "record", "array_of_boolean", "bytes") + .collect() + assert(withSchema.sameElements(withOutSchema)) + } + + test("read avro with user defined schema: read non-exist columns") { + val schema = + StructType( + Seq( + StructField("non_exist_string", StringType, true), + StructField( + "record", + StructType(Seq( + StructField("non_exist_field", StringType, false), + StructField("non_exist_field2", StringType, false))), + false))) + val withEmptyColumn = spark.read.schema(schema).avro(testFile).collect() + + assert(withEmptyColumn.forall(_ == Row(null: String, Row(null: String, null: String)))) + } + + test("read avro file partitioned") { + TestUtils.withTempDir { dir => + val sparkSession = spark + import sparkSession.implicits._ + val df = (0 to 1024 * 3).toDS.map(i => s"record${i}").toDF("records") + val outputDir = s"$dir/${UUID.randomUUID}" + df.write.avro(outputDir) + val input = spark.read.avro(outputDir) + assert(input.collect.toSet.size === 1024 * 3 + 1) + assert(input.rdd.partitions.size > 2) + } + } + + case class NestedBottom(id: Int, data: String) + + case class NestedMiddle(id: Int, data: NestedBottom) + + case class NestedTop(id: Int, data: NestedMiddle) + + test("saving avro that has nested records with the same name") { + TestUtils.withTempDir { tempDir => + // Save avro file on output folder path + val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1"))))) + val outputFolder = s"$tempDir/duplicate_names/" + writeDf.write.avro(outputFolder) + // Read avro file saved on the last step + val readDf = spark.read.avro(outputFolder) + // Check if the written DataFrame is equals than read DataFrame + assert(readDf.collect().sameElements(writeDf.collect())) + } + } + + case class NestedMiddleArray(id: Int, data: Array[NestedBottom]) + + case class NestedTopArray(id: Int, data: NestedMiddleArray) + + test("saving avro that has nested records with the same name inside an array") { + TestUtils.withTempDir { tempDir => + // Save avro file on output folder path + val writeDf = spark.createDataFrame( + List(NestedTopArray(1, NestedMiddleArray(2, Array( + NestedBottom(3, "1"), NestedBottom(4, "2") + )))) + ) + val outputFolder = s"$tempDir/duplicate_names_array/" + writeDf.write.avro(outputFolder) + // Read avro file saved on the last step + val readDf = spark.read.avro(outputFolder) + // Check if the written DataFrame is equals than read DataFrame + assert(readDf.collect().sameElements(writeDf.collect())) + } + } + + case class NestedMiddleMap(id: Int, data: Map[String, NestedBottom]) + + case class NestedTopMap(id: Int, data: NestedMiddleMap) + + test("saving avro that has nested records with the same name inside a map") { + TestUtils.withTempDir { tempDir => + // Save avro file on output folder path + val writeDf = spark.createDataFrame( + List(NestedTopMap(1, NestedMiddleMap(2, Map( + "1" -> NestedBottom(3, "1"), "2" -> NestedBottom(4, "2") + )))) + ) + val outputFolder = s"$tempDir/duplicate_names_map/" + writeDf.write.avro(outputFolder) + // Read avro file saved on the last step + val readDf = spark.read.avro(outputFolder) + // Check if the written DataFrame is equals than read DataFrame + assert(readDf.collect().sameElements(writeDf.collect())) + } + } +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableConfigurationSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableConfigurationSuite.scala new file mode 100755 index 0000000000000..a0f88515ed9d4 --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableConfigurationSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerInstance} + +class SerializableConfigurationSuite extends SparkFunSuite { + + private def testSerialization(serializer: SerializerInstance): Unit = { + import AvroFileFormat.SerializableConfiguration + val conf = new SerializableConfiguration(new Configuration()) + + val serialized = serializer.serialize(conf) + + serializer.deserialize[Any](serialized) match { + case c: SerializableConfiguration => + assert(c.log != null, "log was null") + assert(c.value != null, "value was null") + case other => fail( + s"Expecting ${classOf[SerializableConfiguration]}, but got ${other.getClass}.") + } + } + + test("serialization with JavaSerializer") { + testSerialization(new JavaSerializer(new SparkConf()).newInstance()) + } + + test("serialization with KryoSerializer") { + testSerialization(new KryoSerializer(new SparkConf()).newInstance()) + } + +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala new file mode 100755 index 0000000000000..4ae9b14d9ad0d --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io.{File, IOException} +import java.nio.ByteBuffer + +import scala.collection.immutable.HashSet +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +import com.google.common.io.Files +import java.util + +import org.apache.spark.sql.SparkSession + +private[avro] object TestUtils { + + /** + * This function checks that all records in a file match the original + * record. + */ + def checkReloadMatchesSaved(spark: SparkSession, testFile: String, avroDir: String): Unit = { + + def convertToString(elem: Any): String = { + elem match { + case null => "NULL" // HashSets can't have null in them, so we use a string instead + case arrayBuf: ArrayBuffer[_] => + arrayBuf.asInstanceOf[ArrayBuffer[Any]].toArray.deep.mkString(" ") + case arrayByte: Array[Byte] => arrayByte.deep.mkString(" ") + case other => other.toString + } + } + + val originalEntries = spark.read.avro(testFile).collect() + val newEntries = spark.read.avro(avroDir).collect() + + assert(originalEntries.length == newEntries.length) + + val origEntrySet = Array.fill(originalEntries(0).size)(new HashSet[Any]()) + for (origEntry <- originalEntries) { + var idx = 0 + for (origElement <- origEntry.toSeq) { + origEntrySet(idx) += convertToString(origElement) + idx += 1 + } + } + + for (newEntry <- newEntries) { + var idx = 0 + for (newElement <- newEntry.toSeq) { + assert(origEntrySet(idx).contains(convertToString(newElement))) + idx += 1 + } + } + } + + def withTempDir(f: File => Unit): Unit = { + val dir = Files.createTempDir() + dir.delete() + try f(dir) finally deleteRecursively(dir) + } + + /** + * This function deletes a file or a directory with everything that's in it. This function is + * copied from Spark with minor modifications made to it. See original source at: + * github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/util/Utils.scala + */ + + def deleteRecursively(file: File) { + def listFilesSafely(file: File): Seq[File] = { + if (file.exists()) { + val files = file.listFiles() + if (files == null) { + throw new IOException("Failed to list files for dir: " + file) + } + files + } else { + List() + } + } + + if (file != null) { + try { + if (file.isDirectory) { + var savedIOException: IOException = null + for (child <- listFilesSafely(file)) { + try { + deleteRecursively(child) + } catch { + // In case of multiple exceptions, only last one will be thrown + case ioe: IOException => savedIOException = ioe + } + } + if (savedIOException != null) { + throw savedIOException + } + } + } finally { + if (!file.delete()) { + // Delete can also fail if the file simply did not exist + if (file.exists()) { + throw new IOException("Failed to delete: " + file.getAbsolutePath) + } + } + } + } + } + + /** + * This function generates a random map(string, int) of a given size. + */ + private[avro] def generateRandomMap(rand: Random, size: Int): java.util.Map[String, Int] = { + val jMap = new util.HashMap[String, Int]() + for (i <- 0 until size) { + jMap.put(rand.nextString(5), i) + } + jMap + } + + /** + * This function generates a random array of booleans of a given size. + */ + private[avro] def generateRandomArray(rand: Random, size: Int): util.ArrayList[Boolean] = { + val vec = new util.ArrayList[Boolean]() + for (i <- 0 until size) { + vec.add(rand.nextBoolean()) + } + vec + } + + /** + * This function generates a random ByteBuffer of a given size. + */ + private[avro] def generateRandomByteBuffer(rand: Random, size: Int): ByteBuffer = { + val bb = ByteBuffer.allocate(size) + val arrayOfBytes = new Array[Byte](size) + rand.nextBytes(arrayOfBytes) + bb.put(arrayOfBytes) + } +} diff --git a/pom.xml b/pom.xml index 6dee6fce3ffc4..039292337eaa0 100644 --- a/pom.xml +++ b/pom.xml @@ -104,6 +104,7 @@ external/kafka-0-10 external/kafka-0-10-assembly external/kafka-0-10-sql + external/avro diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f887e4570c85d..247b6fee394bc 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -40,8 +40,8 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, sqlKafka010) = Seq( - "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10" + val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, sqlKafka010, avro) = Seq( + "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10", "avro" ).map(ProjectRef(buildLocation, _)) val streamingProjects@Seq(streaming, streamingKafka010) = @@ -326,7 +326,7 @@ object SparkBuild extends PomBuild { val mimaProjects = allProjects.filterNot { x => Seq( spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, - unsafe, tags, sqlKafka010, kvstore + unsafe, tags, sqlKafka010, kvstore, avro ).contains(x) } @@ -688,9 +688,11 @@ object Unidoc { publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, yarn, tags, streamingKafka010, sqlKafka010), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, + yarn, tags, streamingKafka010, sqlKafka010, avro), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, yarn, tags, streamingKafka010, sqlKafka010), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, + yarn, tags, streamingKafka010, sqlKafka010, avro), unidocAllClasspaths in (ScalaUnidoc, unidoc) := { ignoreClasspaths((unidocAllClasspaths in (ScalaUnidoc, unidoc)).value) From 07704c971cbc92bff15e15f8c42fab9afaab3ef7 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 12 Jul 2018 14:08:49 -0700 Subject: [PATCH 1109/2461] [SPARK-23007][SQL][TEST] Add read schema suite for file-based data sources ## What changes were proposed in this pull request? The reader schema is said to be evolved (or projected) when it changed after the data is written. The followings are already supported in file-based data sources. Note that partition columns are not maintained in files. In this PR, `column` means `non-partition column`. 1. Add a column 2. Hide a column 3. Change a column position 4. Change a column type (upcast) This issue aims to guarantee users a backward-compatible read-schema test coverage on file-based data sources and to prevent future regressions by *adding read schema tests explicitly*. Here, we consider safe changes without data loss. For example, data type change should be from small types to larger types like `int`-to-`long`, not vice versa. As of today, in the master branch, file-based data sources have the following coverage. File Format | Coverage | Note ----------- | ---------- | ------------------------------------------------ TEXT | N/A | Schema consists of a single string column. CSV | 1, 2, 4 | JSON | 1, 2, 3, 4 | ORC | 1, 2, 3, 4 | Native vectorized ORC reader has the widest coverage among ORC formats. PARQUET | 1, 2, 3 | ## How was this patch tested? Pass the Jenkins with newly added test suites. Author: Dongjoon Hyun Closes #20208 from dongjoon-hyun/SPARK-SCHEMA-EVOLUTION. --- .../datasources/ReadSchemaSuite.scala | 181 +++++++ .../datasources/ReadSchemaTest.scala | 493 ++++++++++++++++++ 2 files changed, 674 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaTest.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala new file mode 100644 index 0000000000000..23c58e175fe5e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.internal.SQLConf + +/** + * Read schema suites have the following hierarchy and aims to guarantee users + * a backward-compatible read-schema change coverage on file-based data sources, and + * to prevent future regressions. + * + * ReadSchemaSuite + * -> CSVReadSchemaSuite + * -> HeaderCSVReadSchemaSuite + * + * -> JsonReadSchemaSuite + * + * -> OrcReadSchemaSuite + * -> VectorizedOrcReadSchemaSuite + * + * -> ParquetReadSchemaSuite + * -> VectorizedParquetReadSchemaSuite + * -> MergedParquetReadSchemaSuite + */ + +/** + * All file-based data sources supports column addition and removal at the end. + */ +abstract class ReadSchemaSuite + extends AddColumnTest + with HideColumnAtTheEndTest { + + var originalConf: Boolean = _ +} + +class CSVReadSchemaSuite + extends ReadSchemaSuite + with IntegralTypeTest + with ToDoubleTypeTest + with ToDecimalTypeTest + with ToStringTypeTest { + + override val format: String = "csv" +} + +class HeaderCSVReadSchemaSuite + extends ReadSchemaSuite + with IntegralTypeTest + with ToDoubleTypeTest + with ToDecimalTypeTest + with ToStringTypeTest { + + override val format: String = "csv" + + override val options = Map("header" -> "true") +} + +class JsonReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest + with IntegralTypeTest + with ToDoubleTypeTest + with ToDecimalTypeTest + with ToStringTypeTest { + + override val format: String = "json" +} + +class OrcReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "orc" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.ORC_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, "false") + } + + override def afterAll() { + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class VectorizedOrcReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest + with BooleanTypeTest + with IntegralTypeTest + with ToDoubleTypeTest { + + override val format: String = "orc" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.ORC_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, "true") + } + + override def afterAll() { + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class ParquetReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "parquet" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "false") + } + + override def afterAll() { + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class VectorizedParquetReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "parquet" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + } + + override def afterAll() { + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class MergedParquetReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "parquet" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED) + spark.conf.set(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key, "true") + } + + override def afterAll() { + spark.conf.set(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key, originalConf) + super.afterAll() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaTest.scala new file mode 100644 index 0000000000000..2a5457e00b4ef --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaTest.scala @@ -0,0 +1,493 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.File + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} + +/** + * The reader schema is said to be evolved (or projected) when it changed after the data is + * written by writers. The followings are supported in file-based data sources. + * Note that partition columns are not maintained in files. Here, `column` means non-partition + * column. + * + * 1. Add a column + * 2. Hide a column + * 3. Change a column position + * 4. Change a column type (Upcast) + * + * Here, we consider safe changes without data loss. For example, data type changes should be + * from small types to larger types like `int`-to-`long`, not vice versa. + * + * So far, file-based data sources have the following coverages. + * + * | File Format | Coverage | Note | + * | ------------ | ------------ | ------------------------------------------------------ | + * | TEXT | N/A | Schema consists of a single string column. | + * | CSV | 1, 2, 4 | | + * | JSON | 1, 2, 3, 4 | | + * | ORC | 1, 2, 3, 4 | Native vectorized ORC reader has the widest coverage. | + * | PARQUET | 1, 2, 3 | | + * + * This aims to provide an explicit test coverage for reader schema change on file-based data + * sources. Since a file format has its own coverage, we need a test suite for each file-based + * data source with corresponding supported test case traits. + * + * The following is a hierarchy of test traits. + * + * ReadSchemaTest + * -> AddColumnTest + * -> HideColumnTest + * -> ChangePositionTest + * -> BooleanTypeTest + * -> IntegralTypeTest + * -> ToDoubleTypeTest + * -> ToDecimalTypeTest + */ + +trait ReadSchemaTest extends QueryTest with SQLTestUtils with SharedSQLContext { + val format: String + val options: Map[String, String] = Map.empty[String, String] +} + +/** + * Add column (Case 1). + * This test suite assumes that the missing column should be `null`. + */ +trait AddColumnTest extends ReadSchemaTest { + import testImplicits._ + + test("append column at the end") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq("a", "b").toDF("col1") + val df2 = df1.withColumn("col2", lit("x")) + val df3 = df2.withColumn("col3", lit("y")) + + val dir1 = s"$path${File.separator}part=one" + val dir2 = s"$path${File.separator}part=two" + val dir3 = s"$path${File.separator}part=three" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + df3.write.format(format).options(options).save(dir3) + + val df = spark.read + .schema(df3.schema) + .format(format) + .options(options) + .load(path) + + checkAnswer(df, Seq( + Row("a", null, null, "one"), + Row("b", null, null, "one"), + Row("a", "x", null, "two"), + Row("b", "x", null, "two"), + Row("a", "x", "y", "three"), + Row("b", "x", "y", "three"))) + } + } +} + +/** + * Hide column (Case 2-1). + */ +trait HideColumnAtTheEndTest extends ReadSchemaTest { + import testImplicits._ + + test("hide column at the end") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq(("1", "a"), ("2", "b")).toDF("col1", "col2") + val df2 = df1.withColumn("col3", lit("y")) + + val dir1 = s"$path${File.separator}part=two" + val dir2 = s"$path${File.separator}part=three" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + + val df = spark.read + .schema(df1.schema) + .format(format) + .options(options) + .load(path) + + checkAnswer(df, Seq( + Row("1", "a", "two"), + Row("2", "b", "two"), + Row("1", "a", "three"), + Row("2", "b", "three"))) + + val df3 = spark.read + .schema("col1 string") + .format(format) + .options(options) + .load(path) + + checkAnswer(df3, Seq( + Row("1", "two"), + Row("2", "two"), + Row("1", "three"), + Row("2", "three"))) + } + } +} + +/** + * Hide column in the middle (Case 2-2). + */ +trait HideColumnInTheMiddleTest extends ReadSchemaTest { + import testImplicits._ + + test("hide column in the middle") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq(("1", "a"), ("2", "b")).toDF("col1", "col2") + val df2 = df1.withColumn("col3", lit("y")) + + val dir1 = s"$path${File.separator}part=two" + val dir2 = s"$path${File.separator}part=three" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + + val df = spark.read + .schema("col2 string") + .format(format) + .options(options) + .load(path) + + checkAnswer(df, Seq( + Row("a", "two"), + Row("b", "two"), + Row("a", "three"), + Row("b", "three"))) + } + } +} + +/** + * Change column positions (Case 3). + * This suite assumes that all data set have the same number of columns. + */ +trait ChangePositionTest extends ReadSchemaTest { + import testImplicits._ + + test("change column position") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq(("1", "a"), ("2", "b"), ("3", "c")).toDF("col1", "col2") + val df2 = Seq(("d", "4"), ("e", "5"), ("f", "6")).toDF("col2", "col1") + val unionDF = df1.unionByName(df2) + + val dir1 = s"$path${File.separator}part=one" + val dir2 = s"$path${File.separator}part=two" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1", "col2") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait BooleanTypeTest extends ReadSchemaTest { + import testImplicits._ + + test("change column type from boolean to byte/short/int/long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val values = (1 to 10).map(_ % 2) + val booleanDF = (1 to 10).map(_ % 2 == 1).toDF("col1") + val byteDF = values.map(_.toByte).toDF("col1") + val shortDF = values.map(_.toShort).toDF("col1") + val intDF = values.toDF("col1") + val longDF = values.map(_.toLong).toDF("col1") + + booleanDF.write.mode("overwrite").format(format).options(options).save(path) + + Seq( + ("col1 byte", byteDF), + ("col1 short", shortDF), + ("col1 int", intDF), + ("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait ToStringTypeTest extends ReadSchemaTest { + import testImplicits._ + + test("read as string") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val byteDF = (Byte.MaxValue - 2 to Byte.MaxValue).map(_.toByte).toDF("col1") + val shortDF = (Short.MaxValue - 2 to Short.MaxValue).map(_.toShort).toDF("col1") + val intDF = (Int.MaxValue - 2 to Int.MaxValue).toDF("col1") + val longDF = (Long.MaxValue - 2 to Long.MaxValue).toDF("col1") + val unionDF = byteDF.union(shortDF).union(intDF).union(longDF) + .selectExpr("cast(col1 AS STRING) col1") + + val byteDir = s"$path${File.separator}part=byte" + val shortDir = s"$path${File.separator}part=short" + val intDir = s"$path${File.separator}part=int" + val longDir = s"$path${File.separator}part=long" + + byteDF.write.format(format).options(options).save(byteDir) + shortDF.write.format(format).options(options).save(shortDir) + intDF.write.format(format).options(options).save(intDir) + longDF.write.format(format).options(options).save(longDir) + + val df = spark.read + .schema("col1 string") + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait IntegralTypeTest extends ReadSchemaTest { + + import testImplicits._ + + private lazy val values = 1 to 10 + private lazy val byteDF = values.map(_.toByte).toDF("col1") + private lazy val shortDF = values.map(_.toShort).toDF("col1") + private lazy val intDF = values.toDF("col1") + private lazy val longDF = values.map(_.toLong).toDF("col1") + + test("change column type from byte to short/int/long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + byteDF.write.format(format).options(options).save(path) + + Seq( + ("col1 short", shortDF), + ("col1 int", intDF), + ("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } + + test("change column type from short to int/long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + shortDF.write.format(format).options(options).save(path) + + Seq(("col1 int", intDF), ("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } + + test("change column type from int to long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + intDF.write.format(format).options(options).save(path) + + Seq(("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } + + test("read byte, int, short, long together") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val byteDF = (Byte.MaxValue - 2 to Byte.MaxValue).map(_.toByte).toDF("col1") + val shortDF = (Short.MaxValue - 2 to Short.MaxValue).map(_.toShort).toDF("col1") + val intDF = (Int.MaxValue - 2 to Int.MaxValue).toDF("col1") + val longDF = (Long.MaxValue - 2 to Long.MaxValue).toDF("col1") + val unionDF = byteDF.union(shortDF).union(intDF).union(longDF) + + val byteDir = s"$path${File.separator}part=byte" + val shortDir = s"$path${File.separator}part=short" + val intDir = s"$path${File.separator}part=int" + val longDir = s"$path${File.separator}part=long" + + byteDF.write.format(format).options(options).save(byteDir) + shortDF.write.format(format).options(options).save(shortDir) + intDF.write.format(format).options(options).save(intDir) + longDF.write.format(format).options(options).save(longDir) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait ToDoubleTypeTest extends ReadSchemaTest { + import testImplicits._ + + private lazy val values = 1 to 10 + private lazy val floatDF = values.map(_.toFloat).toDF("col1") + private lazy val doubleDF = values.map(_.toDouble).toDF("col1") + private lazy val unionDF = floatDF.union(doubleDF) + + test("change column type from float to double") { + withTempPath { dir => + val path = dir.getCanonicalPath + + floatDF.write.format(format).options(options).save(path) + + val df = spark.read.schema("col1 double").format(format).options(options).load(path) + + checkAnswer(df, doubleDF) + } + } + + test("read float and double together") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val floatDir = s"$path${File.separator}part=float" + val doubleDir = s"$path${File.separator}part=double" + + floatDF.write.format(format).options(options).save(floatDir) + doubleDF.write.format(format).options(options).save(doubleDir) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait ToDecimalTypeTest extends ReadSchemaTest { + import testImplicits._ + + private lazy val values = 1 to 10 + private lazy val floatDF = values.map(_.toFloat).toDF("col1") + private lazy val doubleDF = values.map(_.toDouble).toDF("col1") + private lazy val decimalDF = values.map(BigDecimal(_)).toDF("col1") + private lazy val unionDF = floatDF.union(doubleDF).union(decimalDF) + + test("change column type from float to decimal") { + withTempPath { dir => + val path = dir.getCanonicalPath + + floatDF.write.format(format).options(options).save(path) + + val df = spark.read + .schema("col1 decimal(38,18)") + .format(format) + .options(options) + .load(path) + + checkAnswer(df, decimalDF) + } + } + + test("change column type from double to decimal") { + withTempPath { dir => + val path = dir.getCanonicalPath + + doubleDF.write.format(format).options(options).save(path) + + val df = spark.read + .schema("col1 decimal(38,18)") + .format(format) + .options(options) + .load(path) + + checkAnswer(df, decimalDF) + } + } + + test("read float, double, decimal together") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val floatDir = s"$path${File.separator}part=float" + val doubleDir = s"$path${File.separator}part=double" + val decimalDir = s"$path${File.separator}part=decimal" + + floatDF.write.format(format).options(options).save(floatDir) + doubleDF.write.format(format).options(options).save(doubleDir) + decimalDF.write.format(format).options(options).save(decimalDir) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} From 11384893b6ad09c0c8bc6a350bb9540d0d704bb4 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 12 Jul 2018 15:13:26 -0700 Subject: [PATCH 1110/2461] [SPARK-24208][SQL][FOLLOWUP] Move test cases to proper locations ## What changes were proposed in this pull request? The PR is a followup to move the test cases introduced by the original PR in their proper location. ## How was this patch tested? moved UTs Author: Marco Gaido Closes #21751 from mgaido91/SPARK-24208_followup. --- python/pyspark/sql/tests.py | 32 +++++++++---------- .../sql/catalyst/analysis/AnalysisSuite.scala | 18 +++++++++++ .../spark/sql/GroupedDatasetSuite.scala | 12 ------- 3 files changed, 34 insertions(+), 28 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4404dbe40590a..565654e7f03bb 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -5471,6 +5471,22 @@ def foo(_): self.assertEqual(r.a, 'hi') self.assertEqual(r.b, 1) + def test_self_join_with_pandas(self): + import pyspark.sql.functions as F + + @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP) + def dummy_pandas_udf(df): + return df[['key', 'col']] + + df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'), + Row(key=2, col='C')]) + df_with_pandas = df.groupBy('key').apply(dummy_pandas_udf) + + # this was throwing an AnalysisException before SPARK-24208 + res = df_with_pandas.alias('temp0').join(df_with_pandas.alias('temp1'), + F.col('temp0.key') == F.col('temp1.key')) + self.assertEquals(res.count(), 5) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, @@ -5925,22 +5941,6 @@ def test_invalid_args(self): 'mixture.*aggregate function.*group aggregate pandas UDF'): df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() - def test_self_join_with_pandas(self): - import pyspark.sql.functions as F - - @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP) - def dummy_pandas_udf(df): - return df[['key', 'col']] - - df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'), - Row(key=2, col='C')]) - dfWithPandas = df.groupBy('key').apply(dummy_pandas_udf) - - # this was throwing an AnalysisException before SPARK-24208 - res = dfWithPandas.alias('temp0').join(dfWithPandas.alias('temp1'), - F.col('temp0.key') == F.col('temp1.key')) - self.assertEquals(res.count(), 5) - @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index cd8579584eada..bbcdf6c1b8481 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -21,6 +21,7 @@ import java.util.TimeZone import org.scalatest.Matchers +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -557,4 +558,21 @@ class AnalysisSuite extends AnalysisTest with Matchers { SubqueryAlias("tbl", testRelation))) assertAnalysisError(barrier, Seq("cannot resolve '`tbl.b`'")) } + + test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") { + val pythonUdf = PythonUDF("pyUDF", null, + StructType(Seq(StructField("a", LongType))), + Seq.empty, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + true) + val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes + val project = Project(Seq(UnresolvedAttribute("a")), testRelation) + val flatMapGroupsInPandas = FlatMapGroupsInPandas( + Seq(UnresolvedAttribute("a")), pythonUdf, output, project) + val left = SubqueryAlias("temp0", flatMapGroupsInPandas) + val right = SubqueryAlias("temp1", flatMapGroupsInPandas) + val join = Join(left, right, Inner, None) + assertAnalysisSuccess( + Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala index bd54ea415ca88..147c0b61f5017 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala @@ -93,16 +93,4 @@ class GroupedDatasetSuite extends QueryTest with SharedSQLContext { } datasetWithUDF.unpersist(true) } - - test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") { - val df = datasetWithUDF.groupBy("s").flatMapGroupsInPandas(PythonUDF( - "pyUDF", - null, - StructType(Seq(StructField("s", LongType))), - Seq.empty, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - true)) - val df1 = df.alias("temp0").join(df.alias("temp1"), $"temp0.s" === $"temp1.s") - df1.queryExecution.assertAnalyzed() - } } From 75725057b3ffdb0891844b10bd707bb0830f92ca Mon Sep 17 00:00:00 2001 From: maryannxue Date: Thu, 12 Jul 2018 16:54:03 -0700 Subject: [PATCH 1111/2461] [SPARK-24790][SQL] Allow complex aggregate expressions in Pivot ## What changes were proposed in this pull request? Relax the check to allow complex aggregate expressions, like `ceil(sum(col1))` or `sum(col1) + 1`, which roughly means any aggregate expression that could appear in an Aggregate plan except pandas UDF (due to the fact that it is not supported in pivot yet). ## How was this patch tested? Added 2 tests in pivot.sql Author: maryannxue Closes #21753 from maryannxue/pivot-relax-syntax. --- .../sql/catalyst/analysis/Analyzer.scala | 24 ++++++------- .../test/resources/sql-tests/inputs/pivot.sql | 18 ++++++++++ .../resources/sql-tests/results/pivot.sql.out | 34 +++++++++++++++++-- 3 files changed, 62 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c078efdfc0000..9749893dbdb0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -509,12 +509,7 @@ class Analyzer( || !p.pivotColumn.resolved => p case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => // Check all aggregate expressions. - aggregates.foreach { e => - if (!isAggregateExpression(e)) { - throw new AnalysisException( - s"Aggregate expression required for pivot, found '$e'") - } - } + aggregates.foreach(checkValidAggregateExpression) // Group-by expressions coming from SQL are implicit and need to be deduced. val groupByExprs = groupByExprsOpt.getOrElse( (child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq) @@ -586,12 +581,17 @@ class Analyzer( } } - private def isAggregateExpression(expr: Expression): Boolean = { - expr match { - case Alias(e, _) => isAggregateExpression(e) - case AggregateExpression(_, _, _, _) => true - case _ => false - } + // Support any aggregate expression that can appear in an Aggregate plan except Pandas UDF. + // TODO: Support Pandas UDF. + private def checkValidAggregateExpression(expr: Expression): Unit = expr match { + case _: AggregateExpression => // OK and leave the argument check to CheckAnalysis. + case expr: PythonUDF if PythonUDF.isGroupedAggPandasUDF(expr) => + failAnalysis("Pandas UDF aggregate expressions are currently not supported in pivot.") + case e: Attribute => + failAnalysis( + s"Aggregate expression required for pivot, but '${e.sql}' " + + s"did not appear in any aggregate function.") + case e => e.children.foreach(checkValidAggregateExpression) } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql index 01dea6c81c11b..b3d53adfbebe7 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql @@ -111,3 +111,21 @@ PIVOT ( sum(earnings) FOR year IN (2012, 2013) ); + +-- pivot with complex aggregate expressions +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + ceil(sum(earnings)), avg(earnings) + 1 as a1 + FOR course IN ('dotNET', 'Java') +); + +-- pivot with invalid arguments in aggregate expressions +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(avg(earnings)) + FOR course IN ('dotNET', 'Java') +); diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out index 85e3488990e20..922d8b9f9152c 100644 --- a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 13 +-- Number of queries: 15 -- !query 0 @@ -176,7 +176,7 @@ PIVOT ( struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -Aggregate expression required for pivot, found 'abs(earnings#x)'; +Aggregate expression required for pivot, but 'coursesales.`earnings`' did not appear in any aggregate function.; -- !query 12 @@ -192,3 +192,33 @@ struct<> -- !query 12 output org.apache.spark.sql.AnalysisException cannot resolve '`year`' given input columns: [__auto_generated_subquery_name.course, __auto_generated_subquery_name.earnings]; line 4 pos 0 + + +-- !query 13 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + ceil(sum(earnings)), avg(earnings) + 1 as a1 + FOR course IN ('dotNET', 'Java') +) +-- !query 13 schema +struct +-- !query 13 output +2012 15000 7501.0 20000 20001.0 +2013 48000 48001.0 30000 30001.0 + + +-- !query 14 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(avg(earnings)) + FOR course IN ('dotNET', 'Java') +) +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +It is not allowed to use an aggregate function in the argument of another aggregate function. Please use the inner aggregate function in a sub-query.; From e0f4f206b737c62da307c1b1b8e6d2eae832696e Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 13 Jul 2018 10:40:58 +0800 Subject: [PATCH 1112/2461] [SPARK-24537][R] Add array_remove / array_zip / map_from_arrays / array_distinct ## What changes were proposed in this pull request? Add array_remove / array_zip / map_from_arrays / array_distinct functions in SparkR. ## How was this patch tested? Add tests in test_sparkSQL.R Author: Huaxin Gao Closes #21645 from huaxingao/spark-24537. --- R/pkg/NAMESPACE | 4 ++ R/pkg/R/functions.R | 70 +++++++++++++++++++++++++-- R/pkg/R/generics.R | 16 ++++++ R/pkg/tests/fulltests/test_sparkSQL.R | 21 ++++++++ 4 files changed, 107 insertions(+), 4 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 9696f6987ad78..adfd3871f3426 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -201,13 +201,16 @@ exportMethods("%<=>%", "approxCountDistinct", "approxQuantile", "array_contains", + "array_distinct", "array_join", "array_max", "array_min", "array_position", + "array_remove", "array_repeat", "array_sort", "arrays_overlap", + "arrays_zip", "asc", "ascii", "asin", @@ -306,6 +309,7 @@ exportMethods("%<=>%", "lpad", "ltrim", "map_entries", + "map_from_arrays", "map_keys", "map_values", "max", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 3bff633fbc1ff..2929a00330c62 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -194,10 +194,12 @@ NULL #' \itemize{ #' \item \code{array_contains}: a value to be checked if contained in the column. #' \item \code{array_position}: a value to locate in the given array. +#' \item \code{array_remove}: a value to remove in the given array. #' } #' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains #' additional named properties to control how it is converted, accepts the same -#' options as the JSON data source. +#' options as the JSON data source. In \code{arrays_zip}, this contains additional +#' Columns of arrays to be merged. #' @name column_collection_functions #' @rdname column_collection_functions #' @family collection functions @@ -207,9 +209,9 @@ NULL #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) -#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) +#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1), array_distinct(tmp$v1))) #' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1))) -#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1))) +#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1), array_remove(tmp$v1, 21))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) @@ -221,6 +223,7 @@ NULL #' head(select(tmp3, element_at(tmp3$v3, "Valiant"))) #' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp)) #' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5))) +#' head(select(tmp4, arrays_zip(tmp4$v4, tmp4$v5), map_from_arrays(tmp4$v4, tmp4$v5))) #' head(select(tmp, concat(df$mpg, df$cyl, df$hp))) #' tmp5 <- mutate(df, v6 = create_array(df$model, df$model)) #' head(select(tmp5, array_join(tmp5$v6, "#"), array_join(tmp5$v6, "#", "NULL")))} @@ -1978,7 +1981,7 @@ setMethod("levenshtein", signature(y = "Column"), }) #' @details -#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}. +#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}. #' If \code{y} is later than \code{x}, then the result is positive. If \code{y} and \code{x} #' are on the same day of month, or both are the last day of month, time of day will be ignored. #' Otherwise, the difference is calculated based on 31 days per month, and rounded to 8 digits. @@ -3008,6 +3011,19 @@ setMethod("array_contains", column(jc) }) +#' @details +#' \code{array_distinct}: Removes duplicate values from the array. +#' +#' @rdname column_collection_functions +#' @aliases array_distinct array_distinct,Column-method +#' @note array_distinct since 2.4.0 +setMethod("array_distinct", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_distinct", x@jc) + column(jc) + }) + #' @details #' \code{array_join}: Concatenates the elements of column using the delimiter. #' Null values are replaced with nullReplacement if set, otherwise they are ignored. @@ -3071,6 +3087,19 @@ setMethod("array_position", column(jc) }) +#' @details +#' \code{array_remove}: Removes all elements that equal to element from the given array. +#' +#' @rdname column_collection_functions +#' @aliases array_remove array_remove,Column-method +#' @note array_remove since 2.4.0 +setMethod("array_remove", + signature(x = "Column", value = "ANY"), + function(x, value) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_remove", x@jc, value) + column(jc) + }) + #' @details #' \code{array_repeat}: Creates an array containing \code{x} repeated the number of times #' given by \code{count}. @@ -3120,6 +3149,24 @@ setMethod("arrays_overlap", column(jc) }) +#' @details +#' \code{arrays_zip}: Returns a merged array of structs in which the N-th struct contains all N-th +#' values of input arrays. +#' +#' @rdname column_collection_functions +#' @aliases arrays_zip arrays_zip,Column-method +#' @note arrays_zip since 2.4.0 +setMethod("arrays_zip", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function(arg) { + stopifnot(class(arg) == "Column") + arg@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "arrays_zip", jcols) + column(jc) + }) + #' @details #' \code{flatten}: Creates a single array from an array of arrays. #' If a structure of nested arrays is deeper than two levels, only one level of nesting is removed. @@ -3147,6 +3194,21 @@ setMethod("map_entries", column(jc) }) +#' @details +#' \code{map_from_arrays}: Creates a new map column. The array in the first column is used for +#' keys. The array in the second column is used for values. All elements in the array for key +#' should not be null. +#' +#' @rdname column_collection_functions +#' @aliases map_from_arrays map_from_arrays,Column-method +#' @note map_from_arrays since 2.4.0 +setMethod("map_from_arrays", + signature(x = "Column", y = "Column"), + function(x, y) { + jc <- callJStatic("org.apache.spark.sql.functions", "map_from_arrays", x@jc, y@jc) + column(jc) + }) + #' @details #' \code{map_keys}: Returns an unordered array containing the keys of the map. #' diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 9321bbaf96ff8..4a7210bf1b902 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -757,6 +757,10 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_distinct", function(x) { standardGeneric("array_distinct") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_join", function(x, delimiter, ...) { standardGeneric("array_join") }) @@ -773,6 +777,10 @@ setGeneric("array_min", function(x) { standardGeneric("array_min") }) #' @name NULL setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_remove", function(x, value) { standardGeneric("array_remove") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_repeat", function(x, count) { standardGeneric("array_repeat") }) @@ -785,6 +793,10 @@ setGeneric("array_sort", function(x) { standardGeneric("array_sort") }) #' @name NULL setGeneric("arrays_overlap", function(x, y) { standardGeneric("arrays_overlap") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("arrays_zip", function(x, ...) { standardGeneric("arrays_zip") }) + #' @rdname column_string_functions #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -1050,6 +1062,10 @@ setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") }) #' @name NULL setGeneric("map_entries", function(x) { standardGeneric("map_entries") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("map_from_arrays", function(x, y) { standardGeneric("map_from_arrays") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("map_keys", function(x) { standardGeneric("map_keys") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 36e0f78bb0599..adcbbff823a2d 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1503,6 +1503,27 @@ test_that("column functions", { result <- collect(select(df2, reverse(df2[[1]])))[[1]] expect_equal(result, "cba") + # Test array_distinct() and array_remove() + df <- createDataFrame(list(list(list(1L, 2L, 3L, 1L, 2L)), list(list(6L, 5L, 5L, 4L, 6L)))) + result <- collect(select(df, array_distinct(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L), list(6L, 5L, 4L))) + + result <- collect(select(df, array_remove(df[[1]], 2L)))[[1]] + expect_equal(result, list(list(1L, 3L, 1L), list(6L, 5L, 5L, 4L, 6L))) + + # Test arrays_zip() + df <- createDataFrame(list(list(list(1L, 2L), list(3L, 4L))), schema = c("c1", "c2")) + result <- collect(select(df, arrays_zip(df[[1]], df[[2]])))[[1]] + expected_entries <- list(listToStruct(list(c1 = 1L, c2 = 3L)), + listToStruct(list(c1 = 2L, c2 = 4L))) + expect_equal(result, list(expected_entries)) + + # Test map_from_arrays() + df <- createDataFrame(list(list(list("x", "y"), list(1, 2))), schema = c("k", "v")) + result <- collect(select(df, map_from_arrays(df$k, df$v)))[[1]] + expected_entries <- list(as.environment(list(x = 1, y = 2))) + expect_equal(result, expected_entries) + # Test array_repeat() df <- createDataFrame(list(list("a", 3L), list("b", 2L))) result <- collect(select(df, array_repeat(df[[1]], df[[2]])))[[1]] From 0ce11d0e3a7c8c48d9f7305d2dd39c7b281b2a53 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Thu, 12 Jul 2018 22:20:06 -0700 Subject: [PATCH 1113/2461] [SPARK-23486] cache the function name from the external catalog for lookupFunctions ## What changes were proposed in this pull request? This PR will cache the function name from external catalog, it is used by lookupFunctions in the analyzer, and it is cached for each query plan. The original problem is reported in the [ spark-19737](https://issues.apache.org/jira/browse/SPARK-19737) ## How was this patch tested? create new test file LookupFunctionsSuite and add test case in SessionCatalogSuite Author: Kevin Yu Closes #20795 from kevinyu98/spark-23486. --- .../sql/catalyst/analysis/Analyzer.scala | 45 +++++++- .../sql/catalyst/catalog/SessionCatalog.scala | 16 +++ .../analysis/LookupFunctionsSuite.scala | 104 ++++++++++++++++++ .../catalog/SessionCatalogSuite.scala | 36 ++++++ .../spark/sql/hive/HiveSessionCatalog.scala | 4 + 5 files changed, 199 insertions(+), 6 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9749893dbdb0c..960ee27aec7be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.Locale + +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.Random @@ -1208,16 +1211,46 @@ class Analyzer( * only performs simple existence check according to the function identifier to quickly identify * undefined functions without triggering relation resolution, which may incur potentially * expensive partition/schema discovery process in some cases. - * + * In order to avoid duplicate external functions lookup, the external function identifier will + * store in the local hash set externalFunctionNameSet. * @see [[ResolveFunctions]] * @see https://issues.apache.org/jira/browse/SPARK-19737 */ object LookupFunctions extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { - case f: UnresolvedFunction if !catalog.functionExists(f.name) => - withPosition(f) { - throw new NoSuchFunctionException(f.name.database.getOrElse("default"), f.name.funcName) - } + override def apply(plan: LogicalPlan): LogicalPlan = { + val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]() + plan.transformAllExpressions { + case f: UnresolvedFunction + if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f + case f: UnresolvedFunction if catalog.isRegisteredFunction(f.name) => f + case f: UnresolvedFunction if catalog.isPersistentFunction(f.name) => + externalFunctionNameSet.add(normalizeFuncName(f.name)) + f + case f: UnresolvedFunction => + withPosition(f) { + throw new NoSuchFunctionException(f.name.database.getOrElse(catalog.getCurrentDatabase), + f.name.funcName) + } + } + } + + def normalizeFuncName(name: FunctionIdentifier): FunctionIdentifier = { + val funcName = if (conf.caseSensitiveAnalysis) { + name.funcName + } else { + name.funcName.toLowerCase(Locale.ROOT) + } + + val databaseName = name.database match { + case Some(a) => formatDatabaseName(a) + case None => catalog.getCurrentDatabase + } + + FunctionIdentifier(funcName, Some(databaseName)) + } + + protected def formatDatabaseName(name: String): String = { + if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index c26a34528c162..b09b81eabf60d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1193,6 +1193,22 @@ class SessionCatalog( !hiveFunctions.contains(name.funcName.toLowerCase(Locale.ROOT)) } + /** + * Return whether this function has been registered in the function registry of the current + * session. If not existed, return false. + */ + def isRegisteredFunction(name: FunctionIdentifier): Boolean = { + functionRegistry.functionExists(name) + } + + /** + * Returns whether it is a persistent function. If not existed, returns false. + */ + def isPersistentFunction(name: FunctionIdentifier): Boolean = { + val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) + databaseExists(db) && externalCatalog.functionExists(db, name.funcName) + } + protected def failFunctionLookup(name: FunctionIdentifier): Nothing = { throw new NoSuchFunctionException( db = name.database.getOrElse(getCurrentDatabase), func = name.funcName) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala new file mode 100644 index 0000000000000..cea0f2a9cbc97 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import java.net.URI + +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf + +class LookupFunctionsSuite extends PlanTest { + + test("SPARK-23486: the functionExists for the Persistent function check") { + val externalCatalog = new CustomInMemoryCatalog + val conf = new SQLConf() + val catalog = new SessionCatalog(externalCatalog, FunctionRegistry.builtin, conf) + val analyzer = { + catalog.createDatabase( + CatalogDatabase("default", "", new URI("loc"), Map.empty), + ignoreIfExists = false) + new Analyzer(catalog, conf) + } + + def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref)) + val unresolvedPersistentFunc = UnresolvedFunction("func", Seq.empty, false) + val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false) + val plan = Project( + Seq(Alias(unresolvedPersistentFunc, "call1")(), Alias(unresolvedPersistentFunc, "call2")(), + Alias(unresolvedPersistentFunc, "call3")(), Alias(unresolvedRegisteredFunc, "call4")(), + Alias(unresolvedRegisteredFunc, "call5")()), + table("TaBlE")) + analyzer.LookupFunctions.apply(plan) + + assert(externalCatalog.getFunctionExistsCalledTimes == 1) + assert(analyzer.LookupFunctions.normalizeFuncName + (unresolvedPersistentFunc.name).database == Some("default")) + } + + test("SPARK-23486: the functionExists for the Registered function check") { + val externalCatalog = new InMemoryCatalog + val conf = new SQLConf() + val customerFunctionReg = new CustomerFunctionRegistry + val catalog = new SessionCatalog(externalCatalog, customerFunctionReg, conf) + val analyzer = { + catalog.createDatabase( + CatalogDatabase("default", "", new URI("loc"), Map.empty), + ignoreIfExists = false) + new Analyzer(catalog, conf) + } + + def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref)) + val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false) + val plan = Project( + Seq(Alias(unresolvedRegisteredFunc, "call1")(), Alias(unresolvedRegisteredFunc, "call2")()), + table("TaBlE")) + analyzer.LookupFunctions.apply(plan) + + assert(customerFunctionReg.getIsRegisteredFunctionCalledTimes == 2) + assert(analyzer.LookupFunctions.normalizeFuncName + (unresolvedRegisteredFunc.name).database == Some("default")) + } +} + +class CustomerFunctionRegistry extends SimpleFunctionRegistry { + + private var isRegisteredFunctionCalledTimes: Int = 0; + + override def functionExists(funcN: FunctionIdentifier): Boolean = synchronized { + isRegisteredFunctionCalledTimes = isRegisteredFunctionCalledTimes + 1 + true + } + + def getIsRegisteredFunctionCalledTimes: Int = isRegisteredFunctionCalledTimes +} + +class CustomInMemoryCatalog extends InMemoryCatalog { + + private var functionExistsCalledTimes: Int = 0 + + override def functionExists(db: String, funcName: String): Boolean = synchronized { + functionExistsCalledTimes = functionExistsCalledTimes + 1 + true + } + + def getFunctionExistsCalledTimes: Int = functionExistsCalledTimes +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 6a7375ee186fa..50496a0410528 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -1217,6 +1217,42 @@ abstract class SessionCatalogSuite extends AnalysisTest { } } + test("isRegisteredFunction") { + withBasicCatalog { catalog => + // Returns false when the function does not register + assert(!catalog.isRegisteredFunction(FunctionIdentifier("temp1"))) + + // Returns true when the function does register + val tempFunc1 = (e: Seq[Expression]) => e.head + catalog.registerFunction(newFunc("iff", None), overrideIfExists = false, + functionBuilder = Some(tempFunc1) ) + assert(catalog.isRegisteredFunction(FunctionIdentifier("iff"))) + + // Returns false when using the createFunction + catalog.createFunction(newFunc("sum", Some("db2")), ignoreIfExists = false) + assert(!catalog.isRegisteredFunction(FunctionIdentifier("sum"))) + assert(!catalog.isRegisteredFunction(FunctionIdentifier("sum", Some("db2")))) + } + } + + test("isPersistentFunction") { + withBasicCatalog { catalog => + // Returns false when the function does not register + assert(!catalog.isPersistentFunction(FunctionIdentifier("temp2"))) + + // Returns false when the function does register + val tempFunc2 = (e: Seq[Expression]) => e.head + catalog.registerFunction(newFunc("iff", None), overrideIfExists = false, + functionBuilder = Some(tempFunc2)) + assert(!catalog.isPersistentFunction(FunctionIdentifier("iff"))) + + // Return true when using the createFunction + catalog.createFunction(newFunc("sum", Some("db2")), ignoreIfExists = false) + assert(catalog.isPersistentFunction(FunctionIdentifier("sum", Some("db2")))) + assert(!catalog.isPersistentFunction(FunctionIdentifier("db2.sum"))) + } + } + test("drop function") { withBasicCatalog { catalog => assert(catalog.externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 94ddeae1bf547..de41bb418181d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -175,6 +175,10 @@ private[sql] class HiveSessionCatalog( super.functionExists(name) || hiveFunctions.contains(name.funcName) } + override def isPersistentFunction(name: FunctionIdentifier): Boolean = { + super.isPersistentFunction(name) || hiveFunctions.contains(name.funcName) + } + /** List of functions we pass over to Hive. Note that over time this list should go to 0. */ // We have a list of Hive built-in functions that we do not support. So, we will check // Hive's function registry and lazily load needed functions into our own function registry. From 0f24c6f8abc11c4525c21ac5cd25991bfed36dc4 Mon Sep 17 00:00:00 2001 From: Yuanbo Liu Date: Fri, 13 Jul 2018 07:37:24 -0600 Subject: [PATCH 1114/2461] =?UTF-8?q?[SPARK-24713]=20AppMatser=20of=20spar?= =?UTF-8?q?k=20streaming=20kafka=20OOM=20if=20there=20are=20hund=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We have hundreds of kafka topics need to be consumed in one application. The application master will throw OOM exception after hanging for nearly half of an hour. OOM happens in the env with a lot of topics, and it's not convenient to set up such kind of env in the unit test. So I didn't change/add test case. Author: Yuanbo Liu Author: yuanbo Closes #21690 from yuanboliu/master. --- .../spark/streaming/kafka010/DirectKafkaInputDStream.scala | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index c3221481556f5..0246006acf0bd 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -166,6 +166,8 @@ private[spark] class DirectKafkaInputDStream[K, V]( * which would throw off consumer position. Fix position if this happens. */ private def paranoidPoll(c: Consumer[K, V]): Unit = { + // don't actually want to consume any messages, so pause all partitions + c.pause(c.assignment()) val msgs = c.poll(0) if (!msgs.isEmpty) { // position should be minimum offset per topicpartition @@ -204,8 +206,6 @@ private[spark] class DirectKafkaInputDStream[K, V]( // position for new partitions determined by auto.offset.reset if no commit currentOffsets = currentOffsets ++ newPartitions.map(tp => tp -> c.position(tp)).toMap - // don't want to consume messages, so pause - c.pause(newPartitions.asJava) // find latest available offsets c.seekToEnd(currentOffsets.keySet.asJava) parts.map(tp => tp -> c.position(tp)).toMap @@ -262,9 +262,6 @@ private[spark] class DirectKafkaInputDStream[K, V]( tp -> c.position(tp) }.toMap } - - // don't actually want to consume any messages, so pause all partitions - c.pause(currentOffsets.keySet.asJava) } override def stop(): Unit = this.synchronized { From dfd7ac9887f89b9b51b7b143ab54d01f11cfcdb5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 13 Jul 2018 08:25:00 -0700 Subject: [PATCH 1115/2461] [SPARK-24781][SQL] Using a reference from Dataset in Filter/Sort might not work ## What changes were proposed in this pull request? When we use a reference from Dataset in filter or sort, which was not used in the prior select, an AnalysisException occurs, e.g., ```scala val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id") df.select(df("name")).filter(df("id") === 0).show() ``` ```scala org.apache.spark.sql.AnalysisException: Resolved attribute(s) id#6 missing from name#5 in operator !Filter (id#6 = 0).;; !Filter (id#6 = 0) +- AnalysisBarrier +- Project [name#5] +- Project [_1#2 AS name#5, _2#3 AS id#6] +- LocalRelation [_1#2, _2#3] ``` This change updates the rule `ResolveMissingReferences` so `Filter` and `Sort` with non-empty `missingInputs` will also be transformed. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh Closes #21745 from viirya/SPARK-24781. --- .../sql/catalyst/analysis/Analyzer.scala | 30 ++++++++++++++----- .../org/apache/spark/sql/DataFrameSuite.scala | 25 ++++++++++++++++ 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 960ee27aec7be..36f14ccdc6989 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1132,7 +1132,8 @@ class Analyzer( case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa case sa @ Sort(_, _, child: Aggregate) => sa - case s @ Sort(order, _, child) if !s.resolved && child.resolved => + case s @ Sort(order, _, child) + if (!s.resolved || s.missingInput.nonEmpty) && child.resolved => val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child) val ordering = newOrder.map(_.asInstanceOf[SortOrder]) if (child.output == newChild.output) { @@ -1143,7 +1144,7 @@ class Analyzer( Project(child.output, newSort) } - case f @ Filter(cond, child) if !f.resolved && child.resolved => + case f @ Filter(cond, child) if (!f.resolved || f.missingInput.nonEmpty) && child.resolved => val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child) if (child.output == newChild.output) { f.copy(condition = newCond.head) @@ -1154,10 +1155,17 @@ class Analyzer( } } + /** + * This method tries to resolve expressions and find missing attributes recursively. Specially, + * when the expressions used in `Sort` or `Filter` contain unresolved attributes or resolved + * attributes which are missed from child output. This method tries to find the missing + * attributes out and add into the projection. + */ private def resolveExprsAndAddMissingAttrs( exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = { - if (exprs.forall(_.resolved)) { - // All given expressions are resolved, no need to continue anymore. + // Missing attributes can be unresolved attributes or resolved attributes which are not in + // the output attributes of the plan. + if (exprs.forall(e => e.resolved && e.references.subsetOf(plan.outputSet))) { (exprs, plan) } else { plan match { @@ -1168,15 +1176,19 @@ class Analyzer( (newExprs, AnalysisBarrier(newChild)) case p: Project => + // Resolving expressions against current plan. val maybeResolvedExprs = exprs.map(resolveExpression(_, p)) + // Recursively resolving expressions on the child of current plan. val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child) - val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs) + // If some attributes used by expressions are resolvable only on the rewritten child + // plan, we need to add them into original projection. + val missingAttrs = (AttributeSet(newExprs) -- p.outputSet).intersect(newChild.outputSet) (newExprs, Project(p.projectList ++ missingAttrs, newChild)) case a @ Aggregate(groupExprs, aggExprs, child) => val maybeResolvedExprs = exprs.map(resolveExpression(_, a)) val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child) - val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs) + val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet) if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) { // All the missing attributes are grouping expressions, valid case. (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild)) @@ -1526,7 +1538,11 @@ class Analyzer( // Try resolving the ordering as though it is in the aggregate clause. try { - val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s)) + // If a sort order is unresolved, containing references not in aggregate, or containing + // `AggregateExpression`, we need to push down it to the underlying aggregate operator. + val unresolvedSortOrders = sortOrder.filter { s => + !s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s) + } val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 9d7645d232d08..5babdf6f33b99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2387,4 +2387,29 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1)) checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1)) } + + test("SPARK-24781: Using a reference from Dataset in Filter/Sort") { + val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id") + val filter1 = df.select(df("name")).filter(df("id") === 0) + val filter2 = df.select(col("name")).filter(col("id") === 0) + checkAnswer(filter1, filter2.collect()) + + val sort1 = df.select(df("name")).orderBy(df("id")) + val sort2 = df.select(col("name")).orderBy(col("id")) + checkAnswer(sort1, sort2.collect()) + } + + test("SPARK-24781: Using a reference not in aggregation in Filter/Sort") { + withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") { + val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id") + + val aggPlusSort1 = df.groupBy(df("name")).agg(count(df("name"))).orderBy(df("name")) + val aggPlusSort2 = df.groupBy(col("name")).agg(count(col("name"))).orderBy(col("name")) + checkAnswer(aggPlusSort1, aggPlusSort2.collect()) + + val aggPlusFilter1 = df.groupBy(df("name")).agg(count(df("name"))).filter(df("name") === 0) + val aggPlusFilter2 = df.groupBy(col("name")).agg(count(col("name"))).filter(col("name") === 0) + checkAnswer(aggPlusFilter1, aggPlusFilter2.collect()) + } + } } From c1b62e420a43aa7da36733ccdbec057d87ac1b43 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 13 Jul 2018 08:55:46 -0700 Subject: [PATCH 1116/2461] [SPARK-24776][SQL] Avro unit test: use SQLTestUtils and replace deprecated methods ## What changes were proposed in this pull request? Improve Avro unit test: 1. use QueryTest/SharedSQLContext/SQLTestUtils, instead of the duplicated test utils. 2. replace deprecated methods ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21760 from gengliangwang/improve_avro_test. --- .../org/apache/spark/sql/avro/AvroSuite.scala | 114 ++++++------- .../org/apache/spark/sql/avro/TestUtils.scala | 156 ------------------ 2 files changed, 53 insertions(+), 217 deletions(-) delete mode 100755 external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index c6c1e4051a4b3..108b347ca0f56 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -31,32 +31,24 @@ import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord} import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} import org.apache.commons.io.FileUtils -import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.avro.SchemaConverters.IncompatibleSchemaException +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ -class AvroSuite extends SparkFunSuite { +class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val episodesFile = "src/test/resources/episodes.avro" val testFile = "src/test/resources/test.avro" - private var spark: SparkSession = _ - override protected def beforeAll(): Unit = { super.beforeAll() - spark = SparkSession.builder() - .master("local[2]") - .appName("AvroSuite") - .config("spark.sql.files.maxPartitionBytes", 1024) - .getOrCreate() - } - - override protected def afterAll(): Unit = { - try { - spark.sparkContext.stop() - } finally { - super.afterAll() - } + spark.conf.set("spark.sql.files.maxPartitionBytes", 1024) + } + + def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = { + val originalEntries = spark.read.avro(testFile).collect() + val newEntries = spark.read.avro(newFile) + checkAnswer(newEntries, originalEntries) } test("reading from multiple paths") { @@ -68,7 +60,7 @@ class AvroSuite extends SparkFunSuite { val df = spark.read.avro(episodesFile) val fields = List("title", "air_date", "doctor") for (field <- fields) { - TestUtils.withTempDir { dir => + withTempPath { dir => val outputDir = s"$dir/${UUID.randomUUID}" df.write.partitionBy(field).avro(outputDir) val input = spark.read.avro(outputDir) @@ -82,12 +74,12 @@ class AvroSuite extends SparkFunSuite { test("request no fields") { val df = spark.read.avro(episodesFile) - df.registerTempTable("avro_table") + df.createOrReplaceTempView("avro_table") assert(spark.sql("select count(*) from avro_table").collect().head === Row(8)) } test("convert formats") { - TestUtils.withTempDir { dir => + withTempPath { dir => val df = spark.read.avro(episodesFile) df.write.parquet(dir.getCanonicalPath) assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count) @@ -95,15 +87,16 @@ class AvroSuite extends SparkFunSuite { } test("rearrange internal schema") { - TestUtils.withTempDir { dir => + withTempPath { dir => val df = spark.read.avro(episodesFile) df.select("doctor", "title").write.avro(dir.getCanonicalPath) } } test("test NULL avro type") { - TestUtils.withTempDir { dir => - val fields = Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava + withTempPath { dir => + val fields = + Seq(new Field("null", Schema.create(Type.NULL), "doc", null.asInstanceOf[Any])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) val datumWriter = new GenericDatumWriter[GenericRecord](schema) @@ -122,11 +115,11 @@ class AvroSuite extends SparkFunSuite { } test("union(int, long) is read as long") { - TestUtils.withTempDir { dir => + withTempPath { dir => val avroSchema: Schema = { val union = Schema.createUnion(List(Schema.create(Type.INT), Schema.create(Type.LONG)).asJava) - val fields = Seq(new Field("field1", union, "doc", null)).asJava + val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -150,11 +143,11 @@ class AvroSuite extends SparkFunSuite { } test("union(float, double) is read as double") { - TestUtils.withTempDir { dir => + withTempPath { dir => val avroSchema: Schema = { val union = Schema.createUnion(List(Schema.create(Type.FLOAT), Schema.create(Type.DOUBLE)).asJava) - val fields = Seq(new Field("field1", union, "doc", null)).asJava + val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -178,7 +171,7 @@ class AvroSuite extends SparkFunSuite { } test("union(float, double, null) is read as nullable double") { - TestUtils.withTempDir { dir => + withTempPath { dir => val avroSchema: Schema = { val union = Schema.createUnion( List(Schema.create(Type.FLOAT), @@ -186,7 +179,7 @@ class AvroSuite extends SparkFunSuite { Schema.create(Type.NULL) ).asJava ) - val fields = Seq(new Field("field1", union, "doc", null)).asJava + val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -210,9 +203,9 @@ class AvroSuite extends SparkFunSuite { } test("Union of a single type") { - TestUtils.withTempDir { dir => + withTempPath { dir => val UnionOfOne = Schema.createUnion(List(Schema.create(Type.INT)).asJava) - val fields = Seq(new Field("field1", UnionOfOne, "doc", null)).asJava + val fields = Seq(new Field("field1", UnionOfOne, "doc", null.asInstanceOf[Any])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) @@ -233,16 +226,16 @@ class AvroSuite extends SparkFunSuite { } test("Complex Union Type") { - TestUtils.withTempDir { dir => + withTempPath { dir => val fixedSchema = Schema.createFixed("fixed_name", "doc", "namespace", 4) val enumSchema = Schema.createEnum("enum_name", "doc", "namespace", List("e1", "e2").asJava) val complexUnionType = Schema.createUnion( List(Schema.create(Type.INT), Schema.create(Type.STRING), fixedSchema, enumSchema).asJava) val fields = Seq( - new Field("field1", complexUnionType, "doc", null), - new Field("field2", complexUnionType, "doc", null), - new Field("field3", complexUnionType, "doc", null), - new Field("field4", complexUnionType, "doc", null) + new Field("field1", complexUnionType, "doc", null.asInstanceOf[Any]), + new Field("field2", complexUnionType, "doc", null.asInstanceOf[Any]), + new Field("field3", complexUnionType, "doc", null.asInstanceOf[Any]), + new Field("field4", complexUnionType, "doc", null.asInstanceOf[Any]) ).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) @@ -271,7 +264,7 @@ class AvroSuite extends SparkFunSuite { } test("Lots of nulls") { - TestUtils.withTempDir { dir => + withTempPath { dir => val schema = StructType(Seq( StructField("binary", BinaryType, true), StructField("timestamp", TimestampType, true), @@ -290,7 +283,7 @@ class AvroSuite extends SparkFunSuite { } test("Struct field type") { - TestUtils.withTempDir { dir => + withTempPath { dir => val schema = StructType(Seq( StructField("float", FloatType, true), StructField("short", ShortType, true), @@ -309,7 +302,7 @@ class AvroSuite extends SparkFunSuite { } test("Date field type") { - TestUtils.withTempDir { dir => + withTempPath { dir => val schema = StructType(Seq( StructField("float", FloatType, true), StructField("date", DateType, true) @@ -329,7 +322,7 @@ class AvroSuite extends SparkFunSuite { } test("Array data types") { - TestUtils.withTempDir { dir => + withTempPath { dir => val testSchema = StructType(Seq( StructField("byte_array", ArrayType(ByteType), true), StructField("short_array", ArrayType(ShortType), true), @@ -363,13 +356,12 @@ class AvroSuite extends SparkFunSuite { } test("write with compression") { - TestUtils.withTempDir { dir => + withTempPath { dir => val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec" val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level" val uncompressDir = s"$dir/uncompress" val deflateDir = s"$dir/deflate" val snappyDir = s"$dir/snappy" - val fakeDir = s"$dir/fake" val df = spark.read.avro(testFile) spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed") @@ -439,7 +431,7 @@ class AvroSuite extends SparkFunSuite { test("sql test") { spark.sql( s""" - |CREATE TEMPORARY TABLE avroTable + |CREATE TEMPORARY VIEW avroTable |USING avro |OPTIONS (path "$episodesFile") """.stripMargin.replaceAll("\n", " ")) @@ -450,24 +442,24 @@ class AvroSuite extends SparkFunSuite { test("conversion to avro and back") { // Note that test.avro includes a variety of types, some of which are nullable. We expect to // get the same values back. - TestUtils.withTempDir { dir => + withTempPath { dir => val avroDir = s"$dir/avro" spark.read.avro(testFile).write.avro(avroDir) - TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir) + checkReloadMatchesSaved(testFile, avroDir) } } test("conversion to avro and back with namespace") { // Note that test.avro includes a variety of types, some of which are nullable. We expect to // get the same values back. - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val name = "AvroTest" val namespace = "com.databricks.spark.avro" val parameters = Map("recordName" -> name, "recordNamespace" -> namespace) val avroDir = tempDir + "/namedAvro" spark.read.avro(testFile).write.options(parameters).avro(avroDir) - TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir) + checkReloadMatchesSaved(testFile, avroDir) // Look at raw file and make sure has namespace info val rawSaved = spark.sparkContext.textFile(avroDir) @@ -478,7 +470,7 @@ class AvroSuite extends SparkFunSuite { } test("converting some specific sparkSQL types to avro") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val testSchema = StructType(Seq( StructField("Name", StringType, false), StructField("Length", IntegerType, true), @@ -520,7 +512,7 @@ class AvroSuite extends SparkFunSuite { } test("correctly read long as date/timestamp type") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val sparkSession = spark import sparkSession.implicits._ @@ -549,7 +541,7 @@ class AvroSuite extends SparkFunSuite { } test("does not coerce null date/timestamp value to 0 epoch.") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val sparkSession = spark import sparkSession.implicits._ @@ -610,7 +602,7 @@ class AvroSuite extends SparkFunSuite { // Directory given has no avro files intercept[AnalysisException] { - TestUtils.withTempDir(dir => spark.read.avro(dir.getCanonicalPath)) + withTempPath(dir => spark.read.avro(dir.getCanonicalPath)) } intercept[AnalysisException] { @@ -624,7 +616,7 @@ class AvroSuite extends SparkFunSuite { } intercept[FileNotFoundException] { - TestUtils.withTempDir { dir => + withTempPath { dir => FileUtils.touch(new File(dir, "test")) spark.read.avro(dir.toString) } @@ -633,19 +625,19 @@ class AvroSuite extends SparkFunSuite { } test("SQL test insert overwrite") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val tempEmptyDir = s"$tempDir/sqlOverwrite" // Create a temp directory for table that will be overwritten new File(tempEmptyDir).mkdirs() spark.sql( s""" - |CREATE TEMPORARY TABLE episodes + |CREATE TEMPORARY VIEW episodes |USING avro |OPTIONS (path "$episodesFile") """.stripMargin.replaceAll("\n", " ")) spark.sql( s""" - |CREATE TEMPORARY TABLE episodesEmpty + |CREATE TEMPORARY VIEW episodesEmpty |(name string, air_date string, doctor int) |USING avro |OPTIONS (path "$tempEmptyDir") @@ -665,7 +657,7 @@ class AvroSuite extends SparkFunSuite { test("test save and load") { // Test if load works as expected - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val df = spark.read.avro(episodesFile) assert(df.count == 8) @@ -679,7 +671,7 @@ class AvroSuite extends SparkFunSuite { test("test load with non-Avro file") { // Test if load works as expected - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val df = spark.read.avro(episodesFile) assert(df.count == 8) @@ -737,7 +729,7 @@ class AvroSuite extends SparkFunSuite { } test("read avro file partitioned") { - TestUtils.withTempDir { dir => + withTempPath { dir => val sparkSession = spark import sparkSession.implicits._ val df = (0 to 1024 * 3).toDS.map(i => s"record${i}").toDF("records") @@ -756,7 +748,7 @@ class AvroSuite extends SparkFunSuite { case class NestedTop(id: Int, data: NestedMiddle) test("saving avro that has nested records with the same name") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => // Save avro file on output folder path val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1"))))) val outputFolder = s"$tempDir/duplicate_names/" @@ -773,7 +765,7 @@ class AvroSuite extends SparkFunSuite { case class NestedTopArray(id: Int, data: NestedMiddleArray) test("saving avro that has nested records with the same name inside an array") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => // Save avro file on output folder path val writeDf = spark.createDataFrame( List(NestedTopArray(1, NestedMiddleArray(2, Array( @@ -794,7 +786,7 @@ class AvroSuite extends SparkFunSuite { case class NestedTopMap(id: Int, data: NestedMiddleMap) test("saving avro that has nested records with the same name inside a map") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => // Save avro file on output folder path val writeDf = spark.createDataFrame( List(NestedTopMap(1, NestedMiddleMap(2, Map( diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala deleted file mode 100755 index 4ae9b14d9ad0d..0000000000000 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.avro - -import java.io.{File, IOException} -import java.nio.ByteBuffer - -import scala.collection.immutable.HashSet -import scala.collection.mutable.ArrayBuffer -import scala.util.Random - -import com.google.common.io.Files -import java.util - -import org.apache.spark.sql.SparkSession - -private[avro] object TestUtils { - - /** - * This function checks that all records in a file match the original - * record. - */ - def checkReloadMatchesSaved(spark: SparkSession, testFile: String, avroDir: String): Unit = { - - def convertToString(elem: Any): String = { - elem match { - case null => "NULL" // HashSets can't have null in them, so we use a string instead - case arrayBuf: ArrayBuffer[_] => - arrayBuf.asInstanceOf[ArrayBuffer[Any]].toArray.deep.mkString(" ") - case arrayByte: Array[Byte] => arrayByte.deep.mkString(" ") - case other => other.toString - } - } - - val originalEntries = spark.read.avro(testFile).collect() - val newEntries = spark.read.avro(avroDir).collect() - - assert(originalEntries.length == newEntries.length) - - val origEntrySet = Array.fill(originalEntries(0).size)(new HashSet[Any]()) - for (origEntry <- originalEntries) { - var idx = 0 - for (origElement <- origEntry.toSeq) { - origEntrySet(idx) += convertToString(origElement) - idx += 1 - } - } - - for (newEntry <- newEntries) { - var idx = 0 - for (newElement <- newEntry.toSeq) { - assert(origEntrySet(idx).contains(convertToString(newElement))) - idx += 1 - } - } - } - - def withTempDir(f: File => Unit): Unit = { - val dir = Files.createTempDir() - dir.delete() - try f(dir) finally deleteRecursively(dir) - } - - /** - * This function deletes a file or a directory with everything that's in it. This function is - * copied from Spark with minor modifications made to it. See original source at: - * github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/util/Utils.scala - */ - - def deleteRecursively(file: File) { - def listFilesSafely(file: File): Seq[File] = { - if (file.exists()) { - val files = file.listFiles() - if (files == null) { - throw new IOException("Failed to list files for dir: " + file) - } - files - } else { - List() - } - } - - if (file != null) { - try { - if (file.isDirectory) { - var savedIOException: IOException = null - for (child <- listFilesSafely(file)) { - try { - deleteRecursively(child) - } catch { - // In case of multiple exceptions, only last one will be thrown - case ioe: IOException => savedIOException = ioe - } - } - if (savedIOException != null) { - throw savedIOException - } - } - } finally { - if (!file.delete()) { - // Delete can also fail if the file simply did not exist - if (file.exists()) { - throw new IOException("Failed to delete: " + file.getAbsolutePath) - } - } - } - } - } - - /** - * This function generates a random map(string, int) of a given size. - */ - private[avro] def generateRandomMap(rand: Random, size: Int): java.util.Map[String, Int] = { - val jMap = new util.HashMap[String, Int]() - for (i <- 0 until size) { - jMap.put(rand.nextString(5), i) - } - jMap - } - - /** - * This function generates a random array of booleans of a given size. - */ - private[avro] def generateRandomArray(rand: Random, size: Int): util.ArrayList[Boolean] = { - val vec = new util.ArrayList[Boolean]() - for (i <- 0 until size) { - vec.add(rand.nextBoolean()) - } - vec - } - - /** - * This function generates a random ByteBuffer of a given size. - */ - private[avro] def generateRandomByteBuffer(rand: Random, size: Int): ByteBuffer = { - val bb = ByteBuffer.allocate(size) - val arrayOfBytes = new Array[Byte](size) - rand.nextBytes(arrayOfBytes) - bb.put(arrayOfBytes) - } -} From 3bcb1b481423aedf1ac531ad582c7cb8685f1e3c Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 13 Jul 2018 10:06:26 -0700 Subject: [PATCH 1117/2461] Revert "[SPARK-24776][SQL] Avro unit test: use SQLTestUtils and replace deprecated methods" This reverts commit c1b62e420a43aa7da36733ccdbec057d87ac1b43. --- .../org/apache/spark/sql/avro/AvroSuite.scala | 114 +++++++------ .../org/apache/spark/sql/avro/TestUtils.scala | 156 ++++++++++++++++++ 2 files changed, 217 insertions(+), 53 deletions(-) create mode 100755 external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 108b347ca0f56..c6c1e4051a4b3 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -31,24 +31,32 @@ import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord} import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} import org.apache.commons.io.FileUtils +import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.avro.SchemaConverters.IncompatibleSchemaException -import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ -class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { +class AvroSuite extends SparkFunSuite { val episodesFile = "src/test/resources/episodes.avro" val testFile = "src/test/resources/test.avro" + private var spark: SparkSession = _ + override protected def beforeAll(): Unit = { super.beforeAll() - spark.conf.set("spark.sql.files.maxPartitionBytes", 1024) - } - - def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = { - val originalEntries = spark.read.avro(testFile).collect() - val newEntries = spark.read.avro(newFile) - checkAnswer(newEntries, originalEntries) + spark = SparkSession.builder() + .master("local[2]") + .appName("AvroSuite") + .config("spark.sql.files.maxPartitionBytes", 1024) + .getOrCreate() + } + + override protected def afterAll(): Unit = { + try { + spark.sparkContext.stop() + } finally { + super.afterAll() + } } test("reading from multiple paths") { @@ -60,7 +68,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val df = spark.read.avro(episodesFile) val fields = List("title", "air_date", "doctor") for (field <- fields) { - withTempPath { dir => + TestUtils.withTempDir { dir => val outputDir = s"$dir/${UUID.randomUUID}" df.write.partitionBy(field).avro(outputDir) val input = spark.read.avro(outputDir) @@ -74,12 +82,12 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("request no fields") { val df = spark.read.avro(episodesFile) - df.createOrReplaceTempView("avro_table") + df.registerTempTable("avro_table") assert(spark.sql("select count(*) from avro_table").collect().head === Row(8)) } test("convert formats") { - withTempPath { dir => + TestUtils.withTempDir { dir => val df = spark.read.avro(episodesFile) df.write.parquet(dir.getCanonicalPath) assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count) @@ -87,16 +95,15 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("rearrange internal schema") { - withTempPath { dir => + TestUtils.withTempDir { dir => val df = spark.read.avro(episodesFile) df.select("doctor", "title").write.avro(dir.getCanonicalPath) } } test("test NULL avro type") { - withTempPath { dir => - val fields = - Seq(new Field("null", Schema.create(Type.NULL), "doc", null.asInstanceOf[Any])).asJava + TestUtils.withTempDir { dir => + val fields = Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) val datumWriter = new GenericDatumWriter[GenericRecord](schema) @@ -115,11 +122,11 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("union(int, long) is read as long") { - withTempPath { dir => + TestUtils.withTempDir { dir => val avroSchema: Schema = { val union = Schema.createUnion(List(Schema.create(Type.INT), Schema.create(Type.LONG)).asJava) - val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava + val fields = Seq(new Field("field1", union, "doc", null)).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -143,11 +150,11 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("union(float, double) is read as double") { - withTempPath { dir => + TestUtils.withTempDir { dir => val avroSchema: Schema = { val union = Schema.createUnion(List(Schema.create(Type.FLOAT), Schema.create(Type.DOUBLE)).asJava) - val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava + val fields = Seq(new Field("field1", union, "doc", null)).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -171,7 +178,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("union(float, double, null) is read as nullable double") { - withTempPath { dir => + TestUtils.withTempDir { dir => val avroSchema: Schema = { val union = Schema.createUnion( List(Schema.create(Type.FLOAT), @@ -179,7 +186,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Schema.create(Type.NULL) ).asJava ) - val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava + val fields = Seq(new Field("field1", union, "doc", null)).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -203,9 +210,9 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Union of a single type") { - withTempPath { dir => + TestUtils.withTempDir { dir => val UnionOfOne = Schema.createUnion(List(Schema.create(Type.INT)).asJava) - val fields = Seq(new Field("field1", UnionOfOne, "doc", null.asInstanceOf[Any])).asJava + val fields = Seq(new Field("field1", UnionOfOne, "doc", null)).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) @@ -226,16 +233,16 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Complex Union Type") { - withTempPath { dir => + TestUtils.withTempDir { dir => val fixedSchema = Schema.createFixed("fixed_name", "doc", "namespace", 4) val enumSchema = Schema.createEnum("enum_name", "doc", "namespace", List("e1", "e2").asJava) val complexUnionType = Schema.createUnion( List(Schema.create(Type.INT), Schema.create(Type.STRING), fixedSchema, enumSchema).asJava) val fields = Seq( - new Field("field1", complexUnionType, "doc", null.asInstanceOf[Any]), - new Field("field2", complexUnionType, "doc", null.asInstanceOf[Any]), - new Field("field3", complexUnionType, "doc", null.asInstanceOf[Any]), - new Field("field4", complexUnionType, "doc", null.asInstanceOf[Any]) + new Field("field1", complexUnionType, "doc", null), + new Field("field2", complexUnionType, "doc", null), + new Field("field3", complexUnionType, "doc", null), + new Field("field4", complexUnionType, "doc", null) ).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) @@ -264,7 +271,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Lots of nulls") { - withTempPath { dir => + TestUtils.withTempDir { dir => val schema = StructType(Seq( StructField("binary", BinaryType, true), StructField("timestamp", TimestampType, true), @@ -283,7 +290,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Struct field type") { - withTempPath { dir => + TestUtils.withTempDir { dir => val schema = StructType(Seq( StructField("float", FloatType, true), StructField("short", ShortType, true), @@ -302,7 +309,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Date field type") { - withTempPath { dir => + TestUtils.withTempDir { dir => val schema = StructType(Seq( StructField("float", FloatType, true), StructField("date", DateType, true) @@ -322,7 +329,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Array data types") { - withTempPath { dir => + TestUtils.withTempDir { dir => val testSchema = StructType(Seq( StructField("byte_array", ArrayType(ByteType), true), StructField("short_array", ArrayType(ShortType), true), @@ -356,12 +363,13 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("write with compression") { - withTempPath { dir => + TestUtils.withTempDir { dir => val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec" val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level" val uncompressDir = s"$dir/uncompress" val deflateDir = s"$dir/deflate" val snappyDir = s"$dir/snappy" + val fakeDir = s"$dir/fake" val df = spark.read.avro(testFile) spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed") @@ -431,7 +439,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("sql test") { spark.sql( s""" - |CREATE TEMPORARY VIEW avroTable + |CREATE TEMPORARY TABLE avroTable |USING avro |OPTIONS (path "$episodesFile") """.stripMargin.replaceAll("\n", " ")) @@ -442,24 +450,24 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("conversion to avro and back") { // Note that test.avro includes a variety of types, some of which are nullable. We expect to // get the same values back. - withTempPath { dir => + TestUtils.withTempDir { dir => val avroDir = s"$dir/avro" spark.read.avro(testFile).write.avro(avroDir) - checkReloadMatchesSaved(testFile, avroDir) + TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir) } } test("conversion to avro and back with namespace") { // Note that test.avro includes a variety of types, some of which are nullable. We expect to // get the same values back. - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => val name = "AvroTest" val namespace = "com.databricks.spark.avro" val parameters = Map("recordName" -> name, "recordNamespace" -> namespace) val avroDir = tempDir + "/namedAvro" spark.read.avro(testFile).write.options(parameters).avro(avroDir) - checkReloadMatchesSaved(testFile, avroDir) + TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir) // Look at raw file and make sure has namespace info val rawSaved = spark.sparkContext.textFile(avroDir) @@ -470,7 +478,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("converting some specific sparkSQL types to avro") { - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => val testSchema = StructType(Seq( StructField("Name", StringType, false), StructField("Length", IntegerType, true), @@ -512,7 +520,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("correctly read long as date/timestamp type") { - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => val sparkSession = spark import sparkSession.implicits._ @@ -541,7 +549,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("does not coerce null date/timestamp value to 0 epoch.") { - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => val sparkSession = spark import sparkSession.implicits._ @@ -602,7 +610,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { // Directory given has no avro files intercept[AnalysisException] { - withTempPath(dir => spark.read.avro(dir.getCanonicalPath)) + TestUtils.withTempDir(dir => spark.read.avro(dir.getCanonicalPath)) } intercept[AnalysisException] { @@ -616,7 +624,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } intercept[FileNotFoundException] { - withTempPath { dir => + TestUtils.withTempDir { dir => FileUtils.touch(new File(dir, "test")) spark.read.avro(dir.toString) } @@ -625,19 +633,19 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("SQL test insert overwrite") { - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => val tempEmptyDir = s"$tempDir/sqlOverwrite" // Create a temp directory for table that will be overwritten new File(tempEmptyDir).mkdirs() spark.sql( s""" - |CREATE TEMPORARY VIEW episodes + |CREATE TEMPORARY TABLE episodes |USING avro |OPTIONS (path "$episodesFile") """.stripMargin.replaceAll("\n", " ")) spark.sql( s""" - |CREATE TEMPORARY VIEW episodesEmpty + |CREATE TEMPORARY TABLE episodesEmpty |(name string, air_date string, doctor int) |USING avro |OPTIONS (path "$tempEmptyDir") @@ -657,7 +665,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("test save and load") { // Test if load works as expected - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => val df = spark.read.avro(episodesFile) assert(df.count == 8) @@ -671,7 +679,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("test load with non-Avro file") { // Test if load works as expected - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => val df = spark.read.avro(episodesFile) assert(df.count == 8) @@ -729,7 +737,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("read avro file partitioned") { - withTempPath { dir => + TestUtils.withTempDir { dir => val sparkSession = spark import sparkSession.implicits._ val df = (0 to 1024 * 3).toDS.map(i => s"record${i}").toDF("records") @@ -748,7 +756,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { case class NestedTop(id: Int, data: NestedMiddle) test("saving avro that has nested records with the same name") { - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => // Save avro file on output folder path val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1"))))) val outputFolder = s"$tempDir/duplicate_names/" @@ -765,7 +773,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { case class NestedTopArray(id: Int, data: NestedMiddleArray) test("saving avro that has nested records with the same name inside an array") { - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => // Save avro file on output folder path val writeDf = spark.createDataFrame( List(NestedTopArray(1, NestedMiddleArray(2, Array( @@ -786,7 +794,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { case class NestedTopMap(id: Int, data: NestedMiddleMap) test("saving avro that has nested records with the same name inside a map") { - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => // Save avro file on output folder path val writeDf = spark.createDataFrame( List(NestedTopMap(1, NestedMiddleMap(2, Map( diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala new file mode 100755 index 0000000000000..4ae9b14d9ad0d --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io.{File, IOException} +import java.nio.ByteBuffer + +import scala.collection.immutable.HashSet +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +import com.google.common.io.Files +import java.util + +import org.apache.spark.sql.SparkSession + +private[avro] object TestUtils { + + /** + * This function checks that all records in a file match the original + * record. + */ + def checkReloadMatchesSaved(spark: SparkSession, testFile: String, avroDir: String): Unit = { + + def convertToString(elem: Any): String = { + elem match { + case null => "NULL" // HashSets can't have null in them, so we use a string instead + case arrayBuf: ArrayBuffer[_] => + arrayBuf.asInstanceOf[ArrayBuffer[Any]].toArray.deep.mkString(" ") + case arrayByte: Array[Byte] => arrayByte.deep.mkString(" ") + case other => other.toString + } + } + + val originalEntries = spark.read.avro(testFile).collect() + val newEntries = spark.read.avro(avroDir).collect() + + assert(originalEntries.length == newEntries.length) + + val origEntrySet = Array.fill(originalEntries(0).size)(new HashSet[Any]()) + for (origEntry <- originalEntries) { + var idx = 0 + for (origElement <- origEntry.toSeq) { + origEntrySet(idx) += convertToString(origElement) + idx += 1 + } + } + + for (newEntry <- newEntries) { + var idx = 0 + for (newElement <- newEntry.toSeq) { + assert(origEntrySet(idx).contains(convertToString(newElement))) + idx += 1 + } + } + } + + def withTempDir(f: File => Unit): Unit = { + val dir = Files.createTempDir() + dir.delete() + try f(dir) finally deleteRecursively(dir) + } + + /** + * This function deletes a file or a directory with everything that's in it. This function is + * copied from Spark with minor modifications made to it. See original source at: + * github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/util/Utils.scala + */ + + def deleteRecursively(file: File) { + def listFilesSafely(file: File): Seq[File] = { + if (file.exists()) { + val files = file.listFiles() + if (files == null) { + throw new IOException("Failed to list files for dir: " + file) + } + files + } else { + List() + } + } + + if (file != null) { + try { + if (file.isDirectory) { + var savedIOException: IOException = null + for (child <- listFilesSafely(file)) { + try { + deleteRecursively(child) + } catch { + // In case of multiple exceptions, only last one will be thrown + case ioe: IOException => savedIOException = ioe + } + } + if (savedIOException != null) { + throw savedIOException + } + } + } finally { + if (!file.delete()) { + // Delete can also fail if the file simply did not exist + if (file.exists()) { + throw new IOException("Failed to delete: " + file.getAbsolutePath) + } + } + } + } + } + + /** + * This function generates a random map(string, int) of a given size. + */ + private[avro] def generateRandomMap(rand: Random, size: Int): java.util.Map[String, Int] = { + val jMap = new util.HashMap[String, Int]() + for (i <- 0 until size) { + jMap.put(rand.nextString(5), i) + } + jMap + } + + /** + * This function generates a random array of booleans of a given size. + */ + private[avro] def generateRandomArray(rand: Random, size: Int): util.ArrayList[Boolean] = { + val vec = new util.ArrayList[Boolean]() + for (i <- 0 until size) { + vec.add(rand.nextBoolean()) + } + vec + } + + /** + * This function generates a random ByteBuffer of a given size. + */ + private[avro] def generateRandomByteBuffer(rand: Random, size: Int): ByteBuffer = { + val bb = ByteBuffer.allocate(size) + val arrayOfBytes = new Array[Byte](size) + rand.nextBytes(arrayOfBytes) + bb.put(arrayOfBytes) + } +} From 3b6005b8a276e646c0785d924f139a48238a7c87 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 13 Jul 2018 11:23:42 -0700 Subject: [PATCH 1118/2461] [SPARK-23528][ML] Add numIter to ClusteringSummary ## What changes were proposed in this pull request? Added the number of iterations in `ClusteringSummary`. This is an helpful information in evaluating how to eventually modify the parameters in order to get a better model. ## How was this patch tested? modified existing UTs Author: Marco Gaido Closes #20701 from mgaido91/SPARK-23528. --- .../org/apache/spark/ml/clustering/BisectingKMeans.scala | 6 ++++-- .../apache/spark/ml/clustering/ClusteringSummary.scala | 6 ++++-- .../org/apache/spark/ml/clustering/GaussianMixture.scala | 8 +++++--- .../scala/org/apache/spark/ml/clustering/KMeans.scala | 6 ++++-- .../scala/org/apache/spark/mllib/clustering/KMeans.scala | 2 +- .../org/apache/spark/mllib/clustering/KMeansModel.scala | 9 +++++++-- .../spark/ml/clustering/BisectingKMeansSuite.scala | 1 + .../spark/ml/clustering/GaussianMixtureSuite.scala | 1 + .../org/apache/spark/ml/clustering/KMeansSuite.scala | 1 + project/MimaExcludes.scala | 5 +++++ 10 files changed, 33 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 9c9614509c64f..de564471c2b4d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -274,7 +274,7 @@ class BisectingKMeans @Since("2.0.0") ( val parentModel = bkm.run(rdd) val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this)) val summary = new BisectingKMeansSummary( - model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) + model.transform(dataset), $(predictionCol), $(featuresCol), $(k), $(maxIter)) model.setSummary(Some(summary)) instr.logNamedValue("clusterSizes", summary.clusterSizes) instr.logSuccess(model) @@ -304,6 +304,7 @@ object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] { * @param predictionCol Name for column of predicted clusters in `predictions`. * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. + * @param numIter Number of iterations. */ @Since("2.1.0") @Experimental @@ -311,4 +312,5 @@ class BisectingKMeansSummary private[clustering] ( predictions: DataFrame, predictionCol: String, featuresCol: String, - k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k) + k: Int, + numIter: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala index 44e832b058b62..7da4c43a1abf3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.clustering -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.sql.{DataFrame, Row} /** @@ -28,13 +28,15 @@ import org.apache.spark.sql.{DataFrame, Row} * @param predictionCol Name for column of predicted clusters in `predictions`. * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. + * @param numIter Number of iterations. */ @Experimental class ClusteringSummary private[clustering] ( @transient val predictions: DataFrame, val predictionCol: String, val featuresCol: String, - val k: Int) extends Serializable { + val k: Int, + @Since("2.4.0") val numIter: Int) extends Serializable { /** * Cluster centers of the transformed data. diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 64ecc1ebda589..dae64ba9a515d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -423,7 +423,7 @@ class GaussianMixture @Since("2.0.0") ( val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists)).setParent(this) val summary = new GaussianMixtureSummary(model.transform(dataset), - $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood) + $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iter) model.setSummary(Some(summary)) instr.logNamedValue("logLikelihood", logLikelihood) instr.logNamedValue("clusterSizes", summary.clusterSizes) @@ -687,6 +687,7 @@ private class ExpectationAggregator( * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. * @param logLikelihood Total log-likelihood for this model on the given data. + * @param numIter Number of iterations. */ @Since("2.0.0") @Experimental @@ -696,8 +697,9 @@ class GaussianMixtureSummary private[clustering] ( @Since("2.0.0") val probabilityCol: String, featuresCol: String, k: Int, - @Since("2.2.0") val logLikelihood: Double) - extends ClusteringSummary(predictions, predictionCol, featuresCol, k) { + @Since("2.2.0") val logLikelihood: Double, + numIter: Int) + extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) { /** * Probability of each cluster. diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 1704412741d49..f40037a8d9aa9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -356,7 +356,7 @@ class KMeans @Since("1.5.0") ( val parentModel = algo.run(instances, Option(instr)) val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( - model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) + model.transform(dataset), $(predictionCol), $(featuresCol), $(k), parentModel.numIter) model.setSummary(Some(summary)) instr.logNamedValue("clusterSizes", summary.clusterSizes) @@ -388,6 +388,7 @@ object KMeans extends DefaultParamsReadable[KMeans] { * @param predictionCol Name for column of predicted clusters in `predictions`. * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. + * @param numIter Number of iterations. */ @Since("2.0.0") @Experimental @@ -395,4 +396,5 @@ class KMeansSummary private[clustering] ( predictions: DataFrame, predictionCol: String, featuresCol: String, - k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k) + k: Int, + numIter: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index b5b1be3490497..37ae8b1a6171a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -348,7 +348,7 @@ class KMeans private ( logInfo(s"The cost is $cost.") - new KMeansModel(centers.map(_.vector), distanceMeasure) + new KMeansModel(centers.map(_.vector), distanceMeasure, iteration) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index a78c21e838e44..e3a88b42fbf73 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -36,8 +36,9 @@ import org.apache.spark.sql.{Row, SparkSession} * A clustering model for K-means. Each point belongs to the cluster with the closest center. */ @Since("0.8.0") -class KMeansModel @Since("2.4.0") (@Since("1.0.0") val clusterCenters: Array[Vector], - @Since("2.4.0") val distanceMeasure: String) +class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector], + @Since("2.4.0") val distanceMeasure: String, + private[spark] val numIter: Int) extends Saveable with Serializable with PMMLExportable { private val distanceMeasureInstance: DistanceMeasure = @@ -46,6 +47,10 @@ class KMeansModel @Since("2.4.0") (@Since("1.0.0") val clusterCenters: Array[Vec private val clusterCentersWithNorm = if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_)) + @Since("2.4.0") + private[spark] def this(clusterCenters: Array[Vector], distanceMeasure: String) = + this(clusterCenters: Array[Vector], distanceMeasure, -1) + @Since("1.1.0") def this(clusterCenters: Array[Vector]) = this(clusterCenters: Array[Vector], DistanceMeasure.EUCLIDEAN) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 81842afbddbbb..1b7780e171e77 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -133,6 +133,7 @@ class BisectingKMeansSuite extends MLTest with DefaultReadWriteTest { assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + assert(summary.numIter == 20) model.setSummary(None) assert(!model.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 0b91f502f615b..13bed9dbe3e89 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -145,6 +145,7 @@ class GaussianMixtureSuite extends MLTest with DefaultReadWriteTest { assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + assert(summary.numIter == 2) model.setSummary(None) assert(!model.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 2569e7a432ca4..829c90fe34e94 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -135,6 +135,7 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + assert(summary.numIter == 1) model.setSummary(None) assert(!model.hasSummary) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index eeb097ef153ad..8f96bb0f33849 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,11 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-23528] Add numIter to ClusteringSummary + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.ClusteringSummary.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.KMeansSummary.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.BisectingKMeansSummary.this"), + // [SPARK-6237][NETWORK] Network-layer changes to allow stream upload ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.receive"), From a75571b46f813005a6d4b076ec39081ffab11844 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 13 Jul 2018 14:07:52 -0700 Subject: [PATCH 1119/2461] [SPARK-23831][SQL] Add org.apache.derby to IsolatedClientLoader ## What changes were proposed in this pull request? Add `org.apache.derby` to `IsolatedClientLoader`, otherwise it may throw an exception: ```scala ... [info] Cause: java.sql.SQLException: Failed to start database 'metastore_db' with class loader org.apache.spark.sql.hive.client.IsolatedClientLoader$$anon$12439ab23, see the next exception for details. [info] at org.apache.derby.impl.jdbc.SQLExceptionFactory.getSQLException(Unknown Source) [info] at org.apache.derby.impl.jdbc.SQLExceptionFactory.getSQLException(Unknown Source) [info] at org.apache.derby.impl.jdbc.Util.seeNextException(Unknown Source) [info] at org.apache.derby.impl.jdbc.EmbedConnection.bootDatabase(Unknown Source) [info] at org.apache.derby.impl.jdbc.EmbedConnection.(Unknown Source) [info] at org.apache.derby.jdbc.InternalDriver$1.run(Unknown Source) ... ``` ## How was this patch tested? unit tests and manual tests Author: Yuming Wang Closes #20944 from wangyum/SPARK-23831. --- .../apache/spark/sql/hive/client/IsolatedClientLoader.scala | 1 + .../apache/spark/sql/hive/HiveExternalCatalogSuite.scala | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 2f34f69b5cf48..6a90c44a2633d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -182,6 +182,7 @@ private[hive] class IsolatedClientLoader( name.startsWith("org.slf4j") || name.startsWith("org.apache.log4j") || // log4j1.x name.startsWith("org.apache.logging.log4j") || // log4j2 + name.startsWith("org.apache.derby.") || name.startsWith("org.apache.spark.") || (sharesHadoopClasses && isHadoopClass) || name.startsWith("scala.") || diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index 0a522b6a11c80..1de258f060943 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -113,4 +113,10 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { catalog.createDatabase(newDb("dbWithNullDesc").copy(description = null), ignoreIfExists = false) assert(catalog.getDatabase("dbWithNullDesc").description == "") } + + test("SPARK-23831: Add org.apache.derby to IsolatedClientLoader") { + val client1 = HiveUtils.newClientForMetadata(new SparkConf, new Configuration) + val client2 = HiveUtils.newClientForMetadata(new SparkConf, new Configuration) + assert(!client1.equals(client2)) + } } From f1a99ad5825daf1b4cc275146ba8460cbcdf9701 Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Fri, 13 Jul 2018 17:19:28 -0700 Subject: [PATCH 1120/2461] [SPARK-23984][K8S][TEST] Added Integration Tests for PySpark on Kubernetes ## What changes were proposed in this pull request? I added integration tests for PySpark ( + checking JVM options + RemoteFileTest) which wasn't properly merged in the initial integration test PR. ## How was this patch tested? I tested this with integration tests using: `dev/dev-run-integration-tests.sh --spark-tgz spark-2.4.0-SNAPSHOT-bin-2.7.3.tgz` Author: Ilan Filonenko Closes #21583 from ifilonenko/master. --- .../k8s/integrationtest/KubernetesSuite.scala | 171 +++++++++++++++--- .../KubernetesTestComponents.scala | 28 ++- 2 files changed, 167 insertions(+), 32 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 6e334c83fbde8..774c3936b877c 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -22,7 +22,7 @@ import java.util.UUID import java.util.regex.Pattern import com.google.common.io.PatternFilenameFilter -import io.fabric8.kubernetes.api.model.{Container, Pod} +import io.fabric8.kubernetes.api.model.Pod import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.{Eventually, PatienceConfiguration} import org.scalatest.time.{Minutes, Seconds, Span} @@ -43,6 +43,7 @@ private[spark] class KubernetesSuite extends SparkFunSuite private var kubernetesTestComponents: KubernetesTestComponents = _ private var sparkAppConf: SparkAppConf = _ private var image: String = _ + private var pyImage: String = _ private var containerLocalSparkDistroExamplesJar: String = _ private var appLocator: String = _ private var driverPodName: String = _ @@ -65,6 +66,7 @@ private[spark] class KubernetesSuite extends SparkFunSuite val imageTag = getTestImageTag val imageRepo = getTestImageRepo image = s"$imageRepo/spark:$imageTag" + pyImage = s"$imageRepo/spark-py:$imageTag" val sparkDistroExamplesJarFile: File = sparkHomeDir.resolve(Paths.get("examples", "jars")) .toFile @@ -156,22 +158,77 @@ private[spark] class KubernetesSuite extends SparkFunSuite }) } - // TODO(ssuchter): Enable the below after debugging - // test("Run PageRank using remote data file") { - // sparkAppConf - // .set("spark.kubernetes.mountDependencies.filesDownloadDir", - // CONTAINER_LOCAL_FILE_DOWNLOAD_PATH) - // .set("spark.files", REMOTE_PAGE_RANK_DATA_FILE) - // runSparkPageRankAndVerifyCompletion( - // appArgs = Array(CONTAINER_LOCAL_DOWNLOADED_PAGE_RANK_DATA_FILE)) - // } + test("Run extraJVMOptions check on driver") { + sparkAppConf + .set("spark.driver.extraJavaOptions", "-Dspark.test.foo=spark.test.bar") + runSparkJVMCheckAndVerifyCompletion( + expectedJVMValue = Seq("(spark.test.foo,spark.test.bar)")) + } + + test("Run SparkRemoteFileTest using a remote data file") { + sparkAppConf + .set("spark.files", REMOTE_PAGE_RANK_DATA_FILE) + runSparkRemoteCheckAndVerifyCompletion( + appArgs = Array(REMOTE_PAGE_RANK_FILE_NAME)) + } + + test("Run PySpark on simple pi.py example") { + sparkAppConf + .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_PI, + mainClass = "", + expectedLogOnCompletion = Seq("Pi is roughly 3"), + appArgs = Array("5"), + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false) + } + + test("Run PySpark with Python2 to test a pyfiles example") { + sparkAppConf + .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") + .set("spark.kubernetes.pyspark.pythonversion", "2") + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_FILES, + mainClass = "", + expectedLogOnCompletion = Seq( + "Python runtime version check is: True", + "Python environment version check is: True"), + appArgs = Array("python"), + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false, + pyFiles = Some(PYSPARK_CONTAINER_TESTS)) + } + + test("Run PySpark with Python3 to test a pyfiles example") { + sparkAppConf + .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") + .set("spark.kubernetes.pyspark.pythonversion", "3") + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_FILES, + mainClass = "", + expectedLogOnCompletion = Seq( + "Python runtime version check is: True", + "Python environment version check is: True"), + appArgs = Array("python3"), + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false, + pyFiles = Some(PYSPARK_CONTAINER_TESTS)) + } private def runSparkPiAndVerifyCompletion( appResource: String = containerLocalSparkDistroExamplesJar, driverPodChecker: Pod => Unit = doBasicDriverPodCheck, executorPodChecker: Pod => Unit = doBasicExecutorPodCheck, appArgs: Array[String] = Array.empty[String], - appLocator: String = appLocator): Unit = { + appLocator: String = appLocator, + isJVM: Boolean = true ): Unit = { runSparkApplicationAndVerifyCompletion( appResource, SPARK_PI_MAIN_CLASS, @@ -179,10 +236,11 @@ private[spark] class KubernetesSuite extends SparkFunSuite appArgs, driverPodChecker, executorPodChecker, - appLocator) + appLocator, + isJVM) } - private def runSparkPageRankAndVerifyCompletion( + private def runSparkRemoteCheckAndVerifyCompletion( appResource: String = containerLocalSparkDistroExamplesJar, driverPodChecker: Pod => Unit = doBasicDriverPodCheck, executorPodChecker: Pod => Unit = doBasicExecutorPodCheck, @@ -190,12 +248,50 @@ private[spark] class KubernetesSuite extends SparkFunSuite appLocator: String = appLocator): Unit = { runSparkApplicationAndVerifyCompletion( appResource, - SPARK_PAGE_RANK_MAIN_CLASS, - Seq("1 has rank", "2 has rank", "3 has rank", "4 has rank"), + SPARK_REMOTE_MAIN_CLASS, + Seq(s"Mounting of ${appArgs.head} was true"), appArgs, driverPodChecker, executorPodChecker, - appLocator) + appLocator, + true) + } + + private def runSparkJVMCheckAndVerifyCompletion( + appResource: String = containerLocalSparkDistroExamplesJar, + mainClass: String = SPARK_DRIVER_MAIN_CLASS, + driverPodChecker: Pod => Unit = doBasicDriverPodCheck, + appArgs: Array[String] = Array("5"), + expectedJVMValue: Seq[String]): Unit = { + val appArguments = SparkAppArguments( + mainAppResource = appResource, + mainClass = mainClass, + appArgs = appArgs) + SparkAppLauncher.launch( + appArguments, + sparkAppConf, + TIMEOUT.value.toSeconds.toInt, + sparkHomeDir, + true) + + val driverPod = kubernetesTestComponents.kubernetesClient + .pods() + .withLabel("spark-app-locator", appLocator) + .withLabel("spark-role", "driver") + .list() + .getItems + .get(0) + doBasicDriverPodCheck(driverPod) + + Eventually.eventually(TIMEOUT, INTERVAL) { + expectedJVMValue.foreach { e => + assert(kubernetesTestComponents.kubernetesClient + .pods() + .withName(driverPod.getMetadata.getName) + .getLog + .contains(e), "The application did not complete.") + } + } } private def runSparkApplicationAndVerifyCompletion( @@ -205,12 +301,20 @@ private[spark] class KubernetesSuite extends SparkFunSuite appArgs: Array[String], driverPodChecker: Pod => Unit, executorPodChecker: Pod => Unit, - appLocator: String): Unit = { + appLocator: String, + isJVM: Boolean, + pyFiles: Option[String] = None): Unit = { val appArguments = SparkAppArguments( mainAppResource = appResource, mainClass = mainClass, appArgs = appArgs) - SparkAppLauncher.launch(appArguments, sparkAppConf, TIMEOUT.value.toSeconds.toInt, sparkHomeDir) + SparkAppLauncher.launch( + appArguments, + sparkAppConf, + TIMEOUT.value.toSeconds.toInt, + sparkHomeDir, + isJVM, + pyFiles) val driverPod = kubernetesTestComponents.kubernetesClient .pods() @@ -248,11 +352,22 @@ private[spark] class KubernetesSuite extends SparkFunSuite assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") } + private def doBasicDriverPyPodCheck(driverPod: Pod): Unit = { + assert(driverPod.getMetadata.getName === driverPodName) + assert(driverPod.getSpec.getContainers.get(0).getImage === pyImage) + assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") + } + private def doBasicExecutorPodCheck(executorPod: Pod): Unit = { assert(executorPod.getSpec.getContainers.get(0).getImage === image) assert(executorPod.getSpec.getContainers.get(0).getName === "executor") } + private def doBasicExecutorPyPodCheck(executorPod: Pod): Unit = { + assert(executorPod.getSpec.getContainers.get(0).getImage === pyImage) + assert(executorPod.getSpec.getContainers.get(0).getName === "executor") + } + private def checkCustomSettings(pod: Pod): Unit = { assert(pod.getMetadata.getLabels.get("label1") === "label1-value") assert(pod.getMetadata.getLabels.get("label2") === "label2-value") @@ -287,14 +402,22 @@ private[spark] object KubernetesSuite { val TIMEOUT = PatienceConfiguration.Timeout(Span(2, Minutes)) val INTERVAL = PatienceConfiguration.Interval(Span(2, Seconds)) val SPARK_PI_MAIN_CLASS: String = "org.apache.spark.examples.SparkPi" + val SPARK_REMOTE_MAIN_CLASS: String = "org.apache.spark.examples.SparkRemoteFileTest" + val SPARK_DRIVER_MAIN_CLASS: String = "org.apache.spark.examples.DriverSubmissionTest" val SPARK_PAGE_RANK_MAIN_CLASS: String = "org.apache.spark.examples.SparkPageRank" + val CONTAINER_LOCAL_PYSPARK: String = "local:///opt/spark/examples/src/main/python/" + val PYSPARK_PI: String = CONTAINER_LOCAL_PYSPARK + "pi.py" + val PYSPARK_FILES: String = CONTAINER_LOCAL_PYSPARK + "pyfiles.py" + val PYSPARK_CONTAINER_TESTS: String = CONTAINER_LOCAL_PYSPARK + "py_container_checks.py" - // val CONTAINER_LOCAL_FILE_DOWNLOAD_PATH = "/var/spark-data/spark-files" + val TEST_SECRET_NAME_PREFIX = "test-secret-" + val TEST_SECRET_KEY = "test-key" + val TEST_SECRET_VALUE = "test-data" + val TEST_SECRET_MOUNT_PATH = "/etc/secrets" - // val REMOTE_PAGE_RANK_DATA_FILE = - // "https://storage.googleapis.com/spark-k8s-integration-tests/files/pagerank_data.txt" - // val CONTAINER_LOCAL_DOWNLOADED_PAGE_RANK_DATA_FILE = - // s"$CONTAINER_LOCAL_FILE_DOWNLOAD_PATH/pagerank_data.txt" + val REMOTE_PAGE_RANK_DATA_FILE = + "https://storage.googleapis.com/spark-k8s-integration-tests/files/pagerank_data.txt" + val REMOTE_PAGE_RANK_FILE_NAME = "pagerank_data.txt" - // case object ShuffleNotReadyException extends Exception + case object ShuffleNotReadyException extends Exception } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala index b2471e51116cb..a9b49a8e5a610 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala @@ -97,21 +97,33 @@ private[spark] case class SparkAppArguments( appArgs: Array[String]) private[spark] object SparkAppLauncher extends Logging { - def launch( appArguments: SparkAppArguments, appConf: SparkAppConf, timeoutSecs: Int, - sparkHomeDir: Path): Unit = { + sparkHomeDir: Path, + isJVM: Boolean, + pyFiles: Option[String] = None): Unit = { val sparkSubmitExecutable = sparkHomeDir.resolve(Paths.get("bin", "spark-submit")) logInfo(s"Launching a spark app with arguments $appArguments and conf $appConf") - val commandLine = (Array(sparkSubmitExecutable.toFile.getAbsolutePath, + val preCommandLine = if (isJVM) { + mutable.ArrayBuffer(sparkSubmitExecutable.toFile.getAbsolutePath, "--deploy-mode", "cluster", "--class", appArguments.mainClass, - "--master", appConf.get("spark.master") - ) ++ appConf.toStringArray :+ - appArguments.mainAppResource) ++ - appArguments.appArgs - ProcessUtils.executeProcess(commandLine, timeoutSecs) + "--master", appConf.get("spark.master")) + } else { + mutable.ArrayBuffer(sparkSubmitExecutable.toFile.getAbsolutePath, + "--deploy-mode", "cluster", + "--master", appConf.get("spark.master")) + } + val commandLine = + pyFiles.map(s => preCommandLine ++ Array("--py-files", s)).getOrElse(preCommandLine) ++ + appConf.toStringArray :+ appArguments.mainAppResource + + if (appArguments.appArgs.nonEmpty) { + commandLine += appArguments.appArgs.mkString(" ") + } + logInfo(s"Launching a spark app with command line: ${commandLine.mkString(" ")}") + ProcessUtils.executeProcess(commandLine.toArray, timeoutSecs) } } From e1de34113e057707dfc5ff54a8109b3ec7c16dfb Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 14 Jul 2018 17:50:54 +0800 Subject: [PATCH 1121/2461] [SPARK-17091][SQL] Add rule to convert IN predicate to equivalent Parquet filter ## What changes were proposed in this pull request? The original pr is: https://github.com/apache/spark/pull/18424 Add a new optimizer rule to convert an IN predicate to an equivalent Parquet filter and add `spark.sql.parquet.pushdown.inFilterThreshold` to control limit thresholds. Different data types have different limit thresholds, this is a copy of data for reference: Type | limit threshold -- | -- string | 370 int | 210 long | 285 double | 270 float | 220 decimal | Won't provide better performance before [SPARK-24549](https://issues.apache.org/jira/browse/SPARK-24549) ## How was this patch tested? unit tests and manual tests Author: Yuming Wang Closes #21603 from wangyum/SPARK-17091. --- .../apache/spark/sql/internal/SQLConf.scala | 15 +++ .../FilterPushdownBenchmark-results.txt | 96 +++++++++---------- .../parquet/ParquetFileFormat.scala | 15 ++- .../datasources/parquet/ParquetFilters.scala | 20 +++- .../parquet/ParquetFilterSuite.scala | 66 ++++++++++++- 5 files changed, 153 insertions(+), 59 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 14dd5281fbcb1..699e9394f5be5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -386,6 +386,18 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD = + buildConf("spark.sql.parquet.pushdown.inFilterThreshold") + .doc("The maximum number of values to filter push-down optimization for IN predicate. " + + "Large threshold won't necessarily provide much better performance. " + + "The experiment argued that 300 is the limit threshold. " + + "By setting this value to 0 this feature can be disabled. " + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + .internal() + .intConf + .checkValue(threshold => threshold >= 0, "The threshold must not be negative.") + .createWithDefault(10) + val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") .doc("Whether to be compatible with the legacy Parquet format adopted by Spark 1.4 and prior " + "versions, when converting Parquet schema to Spark SQL schema and vice versa.") @@ -1485,6 +1497,9 @@ class SQLConf extends Serializable with Logging { def parquetFilterPushDownStringStartWith: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED) + def parquetFilterPushDownInFilterThreshold: Int = + getConf(PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD) + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) diff --git a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt index 110669b69a00d..c44908b3b5406 100644 --- a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt +++ b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt @@ -417,120 +417,120 @@ Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 5, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7477 / 7587 2.1 475.4 1.0X -Parquet Vectorized (Pushdown) 7862 / 8346 2.0 499.9 1.0X -Native ORC Vectorized 6447 / 7021 2.4 409.9 1.2X -Native ORC Vectorized (Pushdown) 983 / 1003 16.0 62.5 7.6X +Parquet Vectorized 7993 / 8104 2.0 508.2 1.0X +Parquet Vectorized (Pushdown) 507 / 532 31.0 32.2 15.8X +Native ORC Vectorized 6922 / 7163 2.3 440.1 1.2X +Native ORC Vectorized (Pushdown) 1017 / 1058 15.5 64.6 7.9X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 5, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7107 / 7290 2.2 451.9 1.0X -Parquet Vectorized (Pushdown) 7196 / 7258 2.2 457.5 1.0X -Native ORC Vectorized 6102 / 6222 2.6 388.0 1.2X -Native ORC Vectorized (Pushdown) 926 / 958 17.0 58.9 7.7X +Parquet Vectorized 7855 / 7963 2.0 499.4 1.0X +Parquet Vectorized (Pushdown) 503 / 516 31.3 32.0 15.6X +Native ORC Vectorized 6825 / 6954 2.3 433.9 1.2X +Native ORC Vectorized (Pushdown) 1019 / 1044 15.4 64.8 7.7X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 5, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7374 / 7692 2.1 468.8 1.0X -Parquet Vectorized (Pushdown) 7771 / 7848 2.0 494.1 0.9X -Native ORC Vectorized 6184 / 6356 2.5 393.2 1.2X -Native ORC Vectorized (Pushdown) 920 / 963 17.1 58.5 8.0X +Parquet Vectorized 7858 / 7928 2.0 499.6 1.0X +Parquet Vectorized (Pushdown) 490 / 519 32.1 31.1 16.0X +Native ORC Vectorized 7079 / 7966 2.2 450.1 1.1X +Native ORC Vectorized (Pushdown) 1276 / 1673 12.3 81.1 6.2X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 10, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7073 / 7326 2.2 449.7 1.0X -Parquet Vectorized (Pushdown) 7304 / 7647 2.2 464.4 1.0X -Native ORC Vectorized 6222 / 6579 2.5 395.6 1.1X -Native ORC Vectorized (Pushdown) 958 / 994 16.4 60.9 7.4X +Parquet Vectorized 8007 / 11155 2.0 509.0 1.0X +Parquet Vectorized (Pushdown) 519 / 540 30.3 33.0 15.4X +Native ORC Vectorized 6848 / 7072 2.3 435.4 1.2X +Native ORC Vectorized (Pushdown) 1026 / 1050 15.3 65.2 7.8X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 10, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7121 / 7501 2.2 452.7 1.0X -Parquet Vectorized (Pushdown) 7751 / 8334 2.0 492.8 0.9X -Native ORC Vectorized 6225 / 6680 2.5 395.8 1.1X -Native ORC Vectorized (Pushdown) 998 / 1020 15.8 63.5 7.1X +Parquet Vectorized 7876 / 7956 2.0 500.7 1.0X +Parquet Vectorized (Pushdown) 521 / 535 30.2 33.1 15.1X +Native ORC Vectorized 7051 / 7368 2.2 448.3 1.1X +Native ORC Vectorized (Pushdown) 1014 / 1035 15.5 64.5 7.8X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 10, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7157 / 7399 2.2 455.1 1.0X -Parquet Vectorized (Pushdown) 7806 / 7911 2.0 496.3 0.9X -Native ORC Vectorized 6548 / 6720 2.4 416.3 1.1X -Native ORC Vectorized (Pushdown) 1016 / 1050 15.5 64.6 7.0X +Parquet Vectorized 7897 / 8229 2.0 502.1 1.0X +Parquet Vectorized (Pushdown) 513 / 530 30.7 32.6 15.4X +Native ORC Vectorized 6730 / 6990 2.3 427.9 1.2X +Native ORC Vectorized (Pushdown) 1003 / 1036 15.7 63.8 7.9X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 50, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7662 / 7805 2.1 487.1 1.0X -Parquet Vectorized (Pushdown) 7590 / 7861 2.1 482.5 1.0X -Native ORC Vectorized 6840 / 8073 2.3 434.9 1.1X -Native ORC Vectorized (Pushdown) 1041 / 1075 15.1 66.2 7.4X +Parquet Vectorized 7967 / 8175 2.0 506.5 1.0X +Parquet Vectorized (Pushdown) 8155 / 8434 1.9 518.5 1.0X +Native ORC Vectorized 7002 / 7107 2.2 445.2 1.1X +Native ORC Vectorized (Pushdown) 1092 / 1139 14.4 69.4 7.3X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 50, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 8230 / 9266 1.9 523.2 1.0X -Parquet Vectorized (Pushdown) 7735 / 7960 2.0 491.8 1.1X -Native ORC Vectorized 6945 / 7109 2.3 441.6 1.2X -Native ORC Vectorized (Pushdown) 1123 / 1144 14.0 71.4 7.3X +Parquet Vectorized 8032 / 8122 2.0 510.7 1.0X +Parquet Vectorized (Pushdown) 8141 / 8908 1.9 517.6 1.0X +Native ORC Vectorized 7140 / 7387 2.2 454.0 1.1X +Native ORC Vectorized (Pushdown) 1156 / 1220 13.6 73.5 6.9X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 50, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7656 / 8058 2.1 486.7 1.0X -Parquet Vectorized (Pushdown) 7860 / 8247 2.0 499.7 1.0X -Native ORC Vectorized 6684 / 7003 2.4 424.9 1.1X -Native ORC Vectorized (Pushdown) 1085 / 1172 14.5 69.0 7.1X +Parquet Vectorized 8088 / 8350 1.9 514.2 1.0X +Parquet Vectorized (Pushdown) 8629 / 8702 1.8 548.6 0.9X +Native ORC Vectorized 7480 / 7886 2.1 475.6 1.1X +Native ORC Vectorized (Pushdown) 1106 / 1145 14.2 70.3 7.3X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 100, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7594 / 8128 2.1 482.8 1.0X -Parquet Vectorized (Pushdown) 7845 / 7923 2.0 498.8 1.0X -Native ORC Vectorized 5859 / 6421 2.7 372.5 1.3X -Native ORC Vectorized (Pushdown) 1037 / 1054 15.2 66.0 7.3X +Parquet Vectorized 8028 / 8165 2.0 510.4 1.0X +Parquet Vectorized (Pushdown) 8349 / 8674 1.9 530.8 1.0X +Native ORC Vectorized 7107 / 7354 2.2 451.8 1.1X +Native ORC Vectorized (Pushdown) 1175 / 1207 13.4 74.7 6.8X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 100, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 6762 / 6775 2.3 429.9 1.0X -Parquet Vectorized (Pushdown) 6911 / 6970 2.3 439.4 1.0X -Native ORC Vectorized 5884 / 5960 2.7 374.1 1.1X -Native ORC Vectorized (Pushdown) 1028 / 1052 15.3 65.4 6.6X +Parquet Vectorized 8041 / 8195 2.0 511.2 1.0X +Parquet Vectorized (Pushdown) 8466 / 8604 1.9 538.2 0.9X +Native ORC Vectorized 7116 / 7286 2.2 452.4 1.1X +Native ORC Vectorized (Pushdown) 1197 / 1214 13.1 76.1 6.7X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 100, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 6718 / 6767 2.3 427.1 1.0X -Parquet Vectorized (Pushdown) 6812 / 6909 2.3 433.1 1.0X -Native ORC Vectorized 5842 / 5883 2.7 371.4 1.1X -Native ORC Vectorized (Pushdown) 1040 / 1058 15.1 66.1 6.5X +Parquet Vectorized 7998 / 8311 2.0 508.5 1.0X +Parquet Vectorized (Pushdown) 9366 / 11257 1.7 595.5 0.9X +Native ORC Vectorized 7856 / 9273 2.0 499.5 1.0X +Native ORC Vectorized (Pushdown) 1350 / 1747 11.7 85.8 5.9X ================================================================================================ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index b86b97ec7b103..efddf8d68eb8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -334,17 +334,15 @@ class ParquetFileFormat val enableVectorizedReader: Boolean = sqlConf.parquetVectorizedReaderEnabled && resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) - val enableRecordFilter: Boolean = - sparkSession.sessionState.conf.parquetRecordFilterEnabled - val timestampConversion: Boolean = - sparkSession.sessionState.conf.isParquetINT96TimestampConversion + val enableRecordFilter: Boolean = sqlConf.parquetRecordFilterEnabled + val timestampConversion: Boolean = sqlConf.isParquetINT96TimestampConversion val capacity = sqlConf.parquetVectorizedReaderBatchSize - val enableParquetFilterPushDown: Boolean = - sparkSession.sessionState.conf.parquetFilterPushDown + val enableParquetFilterPushDown: Boolean = sqlConf.parquetFilterPushDown // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) val pushDownDate = sqlConf.parquetFilterPushDownDate val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith + val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) @@ -368,12 +366,13 @@ class ParquetFileFormat val pushed = if (enableParquetFilterPushDown) { val parquetSchema = ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS) .getFileMetaData.getSchema + val parquetFilters = new ParquetFilters(pushDownDate, + pushDownStringStartWith, pushDownInFilterThreshold) filters // Collects all converted Parquet filter predicates. Notice that not all predicates can be // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` // is used here. - .flatMap(new ParquetFilters(pushDownDate, pushDownStringStartWith) - .createFilter(parquetSchema, _)) + .flatMap(parquetFilters.createFilter(parquetSchema, _)) .reduceOption(FilterApi.and) } else { None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 4c9b940db2b30..e590c153c4151 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -37,7 +37,10 @@ import org.apache.spark.unsafe.types.UTF8String /** * Some utility function to convert Spark data source filters to Parquet filters. */ -private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: Boolean) { +private[parquet] class ParquetFilters( + pushDownDate: Boolean, + pushDownStartWith: Boolean, + pushDownInFilterThreshold: Int) { private case class ParquetSchemaType( originalType: OriginalType, @@ -232,6 +235,15 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: // See SPARK-20364. def canMakeFilterOn(name: String): Boolean = nameToType.contains(name) && !name.contains(".") + // All DataTypes that support `makeEq` can provide better performance. + def shouldConvertInPredicate(name: String): Boolean = nameToType(name) match { + case ParquetBooleanType | ParquetByteType | ParquetShortType | ParquetIntegerType + | ParquetLongType | ParquetFloatType | ParquetDoubleType | ParquetStringType + | ParquetBinaryType => true + case ParquetDateType if pushDownDate => true + case _ => false + } + // NOTE: // // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, @@ -295,6 +307,12 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: case sources.Not(pred) => createFilter(schema, pred).map(FilterApi.not) + case sources.In(name, values) if canMakeFilterOn(name) && shouldConvertInPredicate(name) + && values.distinct.length <= pushDownInFilterThreshold => + values.distinct.flatMap { v => + makeEq.lift(nameToType(name)).map(_(name, v)) + }.reduceLeftOption(FilterApi.or) + case sources.StringStartsWith(name, prefix) if pushDownStartWith && canMakeFilterOn(name) => Option(prefix).map { v => FilterApi.userDefined(binaryColumn(name), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 067d2fea14fd7..00c191f755520 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.nio.charset.StandardCharsets import java.sql.Date -import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} +import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators} import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.filter2.predicate.Operators.{Column => _, _} @@ -56,7 +56,8 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { private lazy val parquetFilters = - new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownStringStartWith) + new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownStringStartWith, + conf.parquetFilterPushDownInFilterThreshold) override def beforeEach(): Unit = { super.beforeEach() @@ -803,6 +804,67 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // Test inverseCanDrop() has taken effect testStringStartsWith(spark.range(1024).map(c => "100").toDF(), "value not like '10%'") } + + test("SPARK-17091: Convert IN predicate to Parquet filter push-down") { + val schema = StructType(Seq( + StructField("a", IntegerType, nullable = false) + )) + + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + + assertResult(Some(FilterApi.eq(intColumn("a"), null: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(null))) + } + + assertResult(Some(FilterApi.eq(intColumn("a"), 10: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10))) + } + + // Remove duplicates + assertResult(Some(FilterApi.eq(intColumn("a"), 10: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10, 10))) + } + + assertResult(Some(or(or( + FilterApi.eq(intColumn("a"), 10: Integer), + FilterApi.eq(intColumn("a"), 20: Integer)), + FilterApi.eq(intColumn("a"), 30: Integer))) + ) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10, 20, 30))) + } + + assert(parquetFilters.createFilter(parquetSchema, sources.In("a", + Range(0, conf.parquetFilterPushDownInFilterThreshold).toArray)).isDefined) + assert(parquetFilters.createFilter(parquetSchema, sources.In("a", + Range(0, conf.parquetFilterPushDownInFilterThreshold + 1).toArray)).isEmpty) + + import testImplicits._ + withTempPath { path => + val data = 0 to 1024 + data.toDF("a").selectExpr("if (a = 1024, null, a) AS a") // convert 1024 to null + .coalesce(1).write.option("parquet.block.size", 512) + .parquet(path.getAbsolutePath) + val df = spark.read.parquet(path.getAbsolutePath) + Seq(true, false).foreach { pushEnabled => + withSQLConf( + SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> pushEnabled.toString) { + Seq(1, 5, 10, 11).foreach { count => + val filter = s"a in(${Range(0, count).mkString(",")})" + assert(df.where(filter).count() === count) + val actual = stripSparkFilter(df.where(filter)).collect().length + if (pushEnabled && count <= conf.parquetFilterPushDownInFilterThreshold) { + assert(actual > 1 && actual < data.length) + } else { + assert(actual === data.length) + } + } + assert(df.where("a in(null)").count() === 0) + assert(df.where("a = null").count() === 0) + assert(df.where("a is null").count() === 1) + } + } + } + } } class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] { From 8aceb961c3b8e462c6002dbe03be61b4fe194f47 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 14 Jul 2018 15:59:17 -0500 Subject: [PATCH 1122/2461] [SPARK-24754][ML] Minhash integer overflow ## What changes were proposed in this pull request? Use longs in calculating min hash to avoid bias due to int overflow. ## How was this patch tested? Existing tests. Author: Sean Owen Closes #21750 from srowen/SPARK-24754. --- .../main/scala/org/apache/spark/ml/feature/MinHashLSH.scala | 2 +- python/pyspark/ml/feature.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index a67a3b0abbc1f..a043033e96724 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -66,7 +66,7 @@ class MinHashLSHModel private[ml]( val elemsList = elems.toSparse.indices.toList val hashValues = randCoefficients.map { case (a, b) => elemsList.map { elem: Int => - ((1 + elem) * a + b) % MinHashLSH.HASH_PRIME + ((1L + elem) * a + b) % MinHashLSH.HASH_PRIME }.min.toDouble } // TODO: Output vectors of dimension numHashFunctions in SPARK-18450 diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 14800d4d9327a..ddba7389145e3 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1294,14 +1294,14 @@ class MinHashLSH(JavaEstimator, LSHParams, HasInputCol, HasOutputCol, HasSeed, >>> mh = MinHashLSH(inputCol="features", outputCol="hashes", seed=12345) >>> model = mh.fit(df) >>> model.transform(df).head() - Row(id=0, features=SparseVector(6, {0: 1.0, 1: 1.0, 2: 1.0}), hashes=[DenseVector([-1638925... + Row(id=0, features=SparseVector(6, {0: 1.0, 1: 1.0, 2: 1.0}), hashes=[DenseVector([6179668... >>> data2 = [(3, Vectors.sparse(6, [1, 3, 5], [1.0, 1.0, 1.0]),), ... (4, Vectors.sparse(6, [2, 3, 5], [1.0, 1.0, 1.0]),), ... (5, Vectors.sparse(6, [1, 2, 4], [1.0, 1.0, 1.0]),)] >>> df2 = spark.createDataFrame(data2, ["id", "features"]) >>> key = Vectors.sparse(6, [1, 2], [1.0, 1.0]) >>> model.approxNearestNeighbors(df2, key, 1).collect() - [Row(id=5, features=SparseVector(6, {1: 1.0, 2: 1.0, 4: 1.0}), hashes=[DenseVector([-163892... + [Row(id=5, features=SparseVector(6, {1: 1.0, 2: 1.0, 4: 1.0}), hashes=[DenseVector([6179668... >>> model.approxSimilarityJoin(df, df2, 0.6, distCol="JaccardDistance").select( ... col("datasetA.id").alias("idA"), ... col("datasetB.id").alias("idB"), @@ -1309,8 +1309,8 @@ class MinHashLSH(JavaEstimator, LSHParams, HasInputCol, HasOutputCol, HasSeed, +---+---+---------------+ |idA|idB|JaccardDistance| +---+---+---------------+ - | 1| 4| 0.5| | 0| 5| 0.5| + | 1| 4| 0.5| +---+---+---------------+ ... >>> mhPath = temp_path + "/mh" From 43e4e851b642bbee535d22e1b9e72ec6b99f6ed4 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 15 Jul 2018 11:13:49 +0800 Subject: [PATCH 1123/2461] [SPARK-24718][SQL] Timestamp support pushdown to parquet data source ## What changes were proposed in this pull request? `Timestamp` support pushdown to parquet data source. Only `TIMESTAMP_MICROS` and `TIMESTAMP_MILLIS` support push down. ## How was this patch tested? unit tests and benchmark tests Author: Yuming Wang Closes #21741 from wangyum/SPARK-24718. --- .../apache/spark/sql/internal/SQLConf.scala | 11 ++ .../FilterPushdownBenchmark-results.txt | 124 ++++++++++++++++++ .../parquet/ParquetFileFormat.scala | 3 +- .../datasources/parquet/ParquetFilters.scala | 59 ++++++++- .../benchmark/FilterPushdownBenchmark.scala | 37 +++++- .../parquet/ParquetFilterSuite.scala | 74 ++++++++++- 6 files changed, 301 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 699e9394f5be5..07d33fa7d52ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -378,6 +378,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED = + buildConf("spark.sql.parquet.filterPushdown.timestamp") + .doc("If true, enables Parquet filter push-down optimization for Timestamp. " + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is " + + "enabled and Timestamp stored as TIMESTAMP_MICROS or TIMESTAMP_MILLIS type.") + .internal() + .booleanConf + .createWithDefault(true) + val PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED = buildConf("spark.sql.parquet.filterPushdown.string.startsWith") .doc("If true, enables Parquet filter push-down optimization for string startsWith function. " + @@ -1494,6 +1503,8 @@ class SQLConf extends Serializable with Logging { def parquetFilterPushDownDate: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DATE_ENABLED) + def parquetFilterPushDownTimestamp: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED) + def parquetFilterPushDownStringStartWith: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED) diff --git a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt index c44908b3b5406..4f38cc4cee96d 100644 --- a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt +++ b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt @@ -578,3 +578,127 @@ Native ORC Vectorized 11622 / 12196 1.4 7 Native ORC Vectorized (Pushdown) 11377 / 11654 1.4 723.3 1.0X +================================================================================================ +Pushdown benchmark for Timestamp +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 timestamp stored as INT96 row (value = CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4784 / 4956 3.3 304.2 1.0X +Parquet Vectorized (Pushdown) 4838 / 4917 3.3 307.6 1.0X +Native ORC Vectorized 3923 / 4173 4.0 249.4 1.2X +Native ORC Vectorized (Pushdown) 894 / 943 17.6 56.8 5.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% timestamp stored as INT96 rows (value < CAST(1572864 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5686 / 5901 2.8 361.5 1.0X +Parquet Vectorized (Pushdown) 5555 / 5895 2.8 353.2 1.0X +Native ORC Vectorized 4844 / 4957 3.2 308.0 1.2X +Native ORC Vectorized (Pushdown) 2141 / 2230 7.3 136.1 2.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% timestamp stored as INT96 rows (value < CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9100 / 9421 1.7 578.6 1.0X +Parquet Vectorized (Pushdown) 9122 / 9496 1.7 580.0 1.0X +Native ORC Vectorized 8365 / 8874 1.9 531.9 1.1X +Native ORC Vectorized (Pushdown) 7128 / 7376 2.2 453.2 1.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% timestamp stored as INT96 rows (value < CAST(14155776 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12764 / 13120 1.2 811.5 1.0X +Parquet Vectorized (Pushdown) 12656 / 13003 1.2 804.7 1.0X +Native ORC Vectorized 13096 / 13233 1.2 832.6 1.0X +Native ORC Vectorized (Pushdown) 12710 / 15611 1.2 808.1 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 timestamp stored as TIMESTAMP_MICROS row (value = CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4381 / 4796 3.6 278.5 1.0X +Parquet Vectorized (Pushdown) 122 / 137 129.3 7.7 36.0X +Native ORC Vectorized 3913 / 3988 4.0 248.8 1.1X +Native ORC Vectorized (Pushdown) 905 / 945 17.4 57.6 4.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% timestamp stored as TIMESTAMP_MICROS rows (value < CAST(1572864 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5145 / 5184 3.1 327.1 1.0X +Parquet Vectorized (Pushdown) 1426 / 1519 11.0 90.7 3.6X +Native ORC Vectorized 4827 / 4901 3.3 306.9 1.1X +Native ORC Vectorized (Pushdown) 2133 / 2210 7.4 135.6 2.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% timestamp stored as TIMESTAMP_MICROS rows (value < CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9234 / 9516 1.7 587.1 1.0X +Parquet Vectorized (Pushdown) 6752 / 7046 2.3 429.3 1.4X +Native ORC Vectorized 8418 / 8998 1.9 535.2 1.1X +Native ORC Vectorized (Pushdown) 7199 / 7314 2.2 457.7 1.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% timestamp stored as TIMESTAMP_MICROS rows (value < CAST(14155776 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12414 / 12458 1.3 789.2 1.0X +Parquet Vectorized (Pushdown) 12094 / 12249 1.3 768.9 1.0X +Native ORC Vectorized 12198 / 13755 1.3 775.5 1.0X +Native ORC Vectorized (Pushdown) 12205 / 12431 1.3 776.0 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 timestamp stored as TIMESTAMP_MILLIS row (value = CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4369 / 4515 3.6 277.8 1.0X +Parquet Vectorized (Pushdown) 116 / 125 136.2 7.3 37.8X +Native ORC Vectorized 3965 / 4703 4.0 252.1 1.1X +Native ORC Vectorized (Pushdown) 892 / 1162 17.6 56.7 4.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% timestamp stored as TIMESTAMP_MILLIS rows (value < CAST(1572864 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5211 / 5409 3.0 331.3 1.0X +Parquet Vectorized (Pushdown) 1427 / 1438 11.0 90.7 3.7X +Native ORC Vectorized 4719 / 4883 3.3 300.1 1.1X +Native ORC Vectorized (Pushdown) 2191 / 2228 7.2 139.3 2.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% timestamp stored as TIMESTAMP_MILLIS rows (value < CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8716 / 8953 1.8 554.2 1.0X +Parquet Vectorized (Pushdown) 6632 / 6968 2.4 421.7 1.3X +Native ORC Vectorized 8376 / 9118 1.9 532.5 1.0X +Native ORC Vectorized (Pushdown) 7218 / 7609 2.2 458.9 1.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% timestamp stored as TIMESTAMP_MILLIS rows (value < CAST(14155776 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12264 / 12452 1.3 779.7 1.0X +Parquet Vectorized (Pushdown) 11766 / 11927 1.3 748.0 1.0X +Native ORC Vectorized 12101 / 12301 1.3 769.3 1.0X +Native ORC Vectorized (Pushdown) 11983 / 12651 1.3 761.9 1.0X + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index efddf8d68eb8f..3ec33b2f4b540 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -341,6 +341,7 @@ class ParquetFileFormat // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) val pushDownDate = sqlConf.parquetFilterPushDownDate + val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold @@ -366,7 +367,7 @@ class ParquetFileFormat val pushed = if (enableParquetFilterPushDown) { val parquetSchema = ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS) .getFileMetaData.getSchema - val parquetFilters = new ParquetFilters(pushDownDate, + val parquetFilters = new ParquetFilters(pushDownDate, pushDownTimestamp, pushDownStringStartWith, pushDownInFilterThreshold) filters // Collects all converted Parquet filter predicates. Notice that not all predicates can be diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index e590c153c4151..0c146f2f6f915 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -17,14 +17,15 @@ package org.apache.spark.sql.execution.datasources.parquet -import java.sql.Date +import java.lang.{Long => JLong} +import java.sql.{Date, Timestamp} import scala.collection.JavaConverters.asScalaBufferConverter import org.apache.parquet.filter2.predicate._ import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.io.api.Binary -import org.apache.parquet.schema.{DecimalMetadata, MessageType, OriginalType, PrimitiveComparator, PrimitiveType} +import org.apache.parquet.schema.{DecimalMetadata, MessageType, OriginalType, PrimitiveComparator} import org.apache.parquet.schema.OriginalType._ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ @@ -39,6 +40,7 @@ import org.apache.spark.unsafe.types.UTF8String */ private[parquet] class ParquetFilters( pushDownDate: Boolean, + pushDownTimestamp: Boolean, pushDownStartWith: Boolean, pushDownInFilterThreshold: Int) { @@ -57,6 +59,8 @@ private[parquet] class ParquetFilters( private val ParquetStringType = ParquetSchemaType(UTF8, BINARY, null) private val ParquetBinaryType = ParquetSchemaType(null, BINARY, null) private val ParquetDateType = ParquetSchemaType(DATE, INT32, null) + private val ParquetTimestampMicrosType = ParquetSchemaType(TIMESTAMP_MICROS, INT64, null) + private val ParquetTimestampMillisType = ParquetSchemaType(TIMESTAMP_MILLIS, INT64, null) private def dateToDays(date: Date): SQLDate = { DateTimeUtils.fromJavaDate(date) @@ -89,6 +93,15 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.eq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.eq( + longColumn(n), + Option(v).map(t => DateTimeUtils.fromJavaTimestamp(t.asInstanceOf[Timestamp]) + .asInstanceOf[JLong]).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.eq( + longColumn(n), + Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]).orNull) } private val makeNotEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { @@ -117,6 +130,15 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.notEq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.notEq( + longColumn(n), + Option(v).map(t => DateTimeUtils.fromJavaTimestamp(t.asInstanceOf[Timestamp]) + .asInstanceOf[JLong]).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.notEq( + longColumn(n), + Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]).orNull) } private val makeLt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { @@ -139,6 +161,14 @@ private[parquet] class ParquetFilters( case ParquetDateType if pushDownDate => (n: String, v: Any) => FilterApi.lt(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.lt( + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.lt( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) } private val makeLtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { @@ -161,6 +191,14 @@ private[parquet] class ParquetFilters( case ParquetDateType if pushDownDate => (n: String, v: Any) => FilterApi.ltEq(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.ltEq( + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.ltEq( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) } private val makeGt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { @@ -183,6 +221,14 @@ private[parquet] class ParquetFilters( case ParquetDateType if pushDownDate => (n: String, v: Any) => FilterApi.gt(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.gt( + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.gt( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) } private val makeGtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { @@ -205,6 +251,14 @@ private[parquet] class ParquetFilters( case ParquetDateType if pushDownDate => (n: String, v: Any) => FilterApi.gtEq(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.gtEq( + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.gtEq( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) } /** @@ -241,6 +295,7 @@ private[parquet] class ParquetFilters( | ParquetLongType | ParquetFloatType | ParquetDoubleType | ParquetStringType | ParquetBinaryType => true case ParquetDateType if pushDownDate => true + case ParquetTimestampMicrosType | ParquetTimestampMillisType if pushDownTimestamp => true case _ => false } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala index fc716dec9f337..567a8ebf9d102 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala @@ -28,7 +28,8 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.functions.monotonically_increasing_id import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType} +import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType +import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, TimestampType} import org.apache.spark.util.{Benchmark, Utils} /** @@ -359,6 +360,40 @@ class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfter } } } + + ignore(s"Pushdown benchmark for Timestamp") { + withTempPath { dir => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED.key -> true.toString) { + ParquetOutputTimestampType.values.toSeq.map(_.toString).foreach { fileType => + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> fileType) { + val columns = (1 to width).map(i => s"CAST(id AS string) c$i") + val df = spark.range(numRows).selectExpr(columns: _*) + .withColumn("value", monotonically_increasing_id().cast(TimestampType)) + withTempTable("orcTable", "patquetTable") { + saveAsTable(df, dir) + + Seq(s"value = CAST($mid AS timestamp)").foreach { whereExpr => + val title = s"Select 1 timestamp stored as $fileType row ($whereExpr)" + .replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% timestamp stored as $fileType rows " + + s"(value < CAST(${numRows * percent / 100} AS timestamp))", + s"value < CAST(${numRows * percent / 100} as timestamp)", + selectExpr + ) + } + } + } + } + } + } + } } trait BenchmarkBeforeAndAfterEachTest extends BeforeAndAfterEachTestData { this: Suite => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 00c191f755520..924f136503656 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.nio.charset.StandardCharsets -import java.sql.Date +import java.sql.{Date, Timestamp} import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators} import org.apache.parquet.filter2.predicate.FilterApi._ @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} @@ -56,8 +57,8 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { private lazy val parquetFilters = - new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownStringStartWith, - conf.parquetFilterPushDownInFilterThreshold) + new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp, + conf.parquetFilterPushDownStringStartWith, conf.parquetFilterPushDownInFilterThreshold) override def beforeEach(): Unit = { super.beforeEach() @@ -84,6 +85,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex withSQLConf( SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", SQLConf.PARQUET_FILTER_PUSHDOWN_DATE_ENABLED.key -> "true", + SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED.key -> "true", SQLConf.PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED.key -> "true", SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { val query = df @@ -144,6 +146,39 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) } + private def testTimestampPushdown(data: Seq[Timestamp]): Unit = { + assert(data.size === 4) + val ts1 = data.head + val ts2 = data(1) + val ts3 = data(2) + val ts4 = data(3) + + withParquetDataFrame(data.map(i => Tuple1(i))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], data.map(i => Row.apply(i))) + + checkFilterPredicate('_1 === ts1, classOf[Eq[_]], ts1) + checkFilterPredicate('_1 <=> ts1, classOf[Eq[_]], ts1) + checkFilterPredicate('_1 =!= ts1, classOf[NotEq[_]], + Seq(ts2, ts3, ts4).map(i => Row.apply(i))) + + checkFilterPredicate('_1 < ts2, classOf[Lt[_]], ts1) + checkFilterPredicate('_1 > ts1, classOf[Gt[_]], Seq(ts2, ts3, ts4).map(i => Row.apply(i))) + checkFilterPredicate('_1 <= ts1, classOf[LtEq[_]], ts1) + checkFilterPredicate('_1 >= ts4, classOf[GtEq[_]], ts4) + + checkFilterPredicate(Literal(ts1) === '_1, classOf[Eq[_]], ts1) + checkFilterPredicate(Literal(ts1) <=> '_1, classOf[Eq[_]], ts1) + checkFilterPredicate(Literal(ts2) > '_1, classOf[Lt[_]], ts1) + checkFilterPredicate(Literal(ts3) < '_1, classOf[Gt[_]], ts4) + checkFilterPredicate(Literal(ts1) >= '_1, classOf[LtEq[_]], ts1) + checkFilterPredicate(Literal(ts4) <= '_1, classOf[GtEq[_]], ts4) + + checkFilterPredicate(!('_1 < ts4), classOf[GtEq[_]], ts4) + checkFilterPredicate('_1 < ts2 || '_1 > ts3, classOf[Operators.Or], Seq(Row(ts1), Row(ts4))) + } + } + // This function tests that exactly go through the `canDrop` and `inverseCanDrop`. private def testStringStartsWith(dataFrame: DataFrame, filter: String): Unit = { withTempPath { dir => @@ -444,6 +479,39 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + test("filter pushdown - timestamp") { + // spark.sql.parquet.outputTimestampType = TIMESTAMP_MILLIS + val millisData = Seq(Timestamp.valueOf("2018-06-14 08:28:53.123"), + Timestamp.valueOf("2018-06-15 08:28:53.123"), + Timestamp.valueOf("2018-06-16 08:28:53.123"), + Timestamp.valueOf("2018-06-17 08:28:53.123")) + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> + ParquetOutputTimestampType.TIMESTAMP_MILLIS.toString) { + testTimestampPushdown(millisData) + } + + // spark.sql.parquet.outputTimestampType = TIMESTAMP_MICROS + val microsData = Seq(Timestamp.valueOf("2018-06-14 08:28:53.123456"), + Timestamp.valueOf("2018-06-15 08:28:53.123456"), + Timestamp.valueOf("2018-06-16 08:28:53.123456"), + Timestamp.valueOf("2018-06-17 08:28:53.123456")) + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> + ParquetOutputTimestampType.TIMESTAMP_MICROS.toString) { + testTimestampPushdown(microsData) + } + + // spark.sql.parquet.outputTimestampType = INT96 doesn't support pushdown + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> + ParquetOutputTimestampType.INT96.toString) { + withParquetDataFrame(millisData.map(i => Tuple1(i))) { implicit df => + assertResult(None) { + parquetFilters.createFilter( + new SparkToParquetSchemaConverter(conf).convert(df.schema), sources.IsNull("_1")) + } + } + } + } + test("SPARK-6554: don't push down predicates which reference partition columns") { import testImplicits._ From 3e7dc82960fd3339eee16d83df66761ae6e3fe3d Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sat, 14 Jul 2018 21:36:56 -0700 Subject: [PATCH 1124/2461] [SPARK-24776][SQL] Avro unit test: deduplicate code and replace deprecated methods ## What changes were proposed in this pull request? Improve Avro unit test: 1. use QueryTest/SharedSQLContext/SQLTestUtils, instead of the duplicated test utils. 2. replace deprecated methods This is a follow up PR for #21760, the PR passes pull request tests but failed in: https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Compile/job/spark-master-compile-maven-hadoop-2.6/7842/ This PR is to fix it. ## How was this patch tested? Unit test. Compile with different commands: ``` ./build/mvn --force -DzincPort=3643 -DskipTests -Phadoop-2.6 -Phive-thriftserver -Pkinesis-asl -Pspark-ganglia-lgpl -Pmesos -Pyarn compile test-compile ./build/mvn --force -DzincPort=3643 -DskipTests -Phadoop-2.7 -Phive-thriftserver -Pkinesis-asl -Pspark-ganglia-lgpl -Pmesos -Pyarn compile test-compile ./build/mvn --force -DzincPort=3643 -DskipTests -Phadoop-3.1 -Phive-thriftserver -Pkinesis-asl -Pspark-ganglia-lgpl -Pmesos -Pyarn compile test-compile ``` Author: Gengliang Wang Closes #21768 from gengliangwang/improve_avro_test. --- .../org/apache/spark/sql/avro/AvroSuite.scala | 98 +++++------ .../org/apache/spark/sql/avro/TestUtils.scala | 156 ------------------ 2 files changed, 45 insertions(+), 209 deletions(-) delete mode 100755 external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index c6c1e4051a4b3..4f94d827e3127 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -31,32 +31,24 @@ import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord} import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} import org.apache.commons.io.FileUtils -import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.avro.SchemaConverters.IncompatibleSchemaException +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ -class AvroSuite extends SparkFunSuite { +class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val episodesFile = "src/test/resources/episodes.avro" val testFile = "src/test/resources/test.avro" - private var spark: SparkSession = _ - override protected def beforeAll(): Unit = { super.beforeAll() - spark = SparkSession.builder() - .master("local[2]") - .appName("AvroSuite") - .config("spark.sql.files.maxPartitionBytes", 1024) - .getOrCreate() - } - - override protected def afterAll(): Unit = { - try { - spark.sparkContext.stop() - } finally { - super.afterAll() - } + spark.conf.set("spark.sql.files.maxPartitionBytes", 1024) + } + + def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = { + val originalEntries = spark.read.avro(testFile).collect() + val newEntries = spark.read.avro(newFile) + checkAnswer(newEntries, originalEntries) } test("reading from multiple paths") { @@ -68,7 +60,7 @@ class AvroSuite extends SparkFunSuite { val df = spark.read.avro(episodesFile) val fields = List("title", "air_date", "doctor") for (field <- fields) { - TestUtils.withTempDir { dir => + withTempPath { dir => val outputDir = s"$dir/${UUID.randomUUID}" df.write.partitionBy(field).avro(outputDir) val input = spark.read.avro(outputDir) @@ -82,12 +74,12 @@ class AvroSuite extends SparkFunSuite { test("request no fields") { val df = spark.read.avro(episodesFile) - df.registerTempTable("avro_table") + df.createOrReplaceTempView("avro_table") assert(spark.sql("select count(*) from avro_table").collect().head === Row(8)) } test("convert formats") { - TestUtils.withTempDir { dir => + withTempPath { dir => val df = spark.read.avro(episodesFile) df.write.parquet(dir.getCanonicalPath) assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count) @@ -95,15 +87,16 @@ class AvroSuite extends SparkFunSuite { } test("rearrange internal schema") { - TestUtils.withTempDir { dir => + withTempPath { dir => val df = spark.read.avro(episodesFile) df.select("doctor", "title").write.avro(dir.getCanonicalPath) } } test("test NULL avro type") { - TestUtils.withTempDir { dir => - val fields = Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava + withTempPath { dir => + val fields = + Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) val datumWriter = new GenericDatumWriter[GenericRecord](schema) @@ -122,7 +115,7 @@ class AvroSuite extends SparkFunSuite { } test("union(int, long) is read as long") { - TestUtils.withTempDir { dir => + withTempPath { dir => val avroSchema: Schema = { val union = Schema.createUnion(List(Schema.create(Type.INT), Schema.create(Type.LONG)).asJava) @@ -150,7 +143,7 @@ class AvroSuite extends SparkFunSuite { } test("union(float, double) is read as double") { - TestUtils.withTempDir { dir => + withTempPath { dir => val avroSchema: Schema = { val union = Schema.createUnion(List(Schema.create(Type.FLOAT), Schema.create(Type.DOUBLE)).asJava) @@ -178,7 +171,7 @@ class AvroSuite extends SparkFunSuite { } test("union(float, double, null) is read as nullable double") { - TestUtils.withTempDir { dir => + withTempPath { dir => val avroSchema: Schema = { val union = Schema.createUnion( List(Schema.create(Type.FLOAT), @@ -210,7 +203,7 @@ class AvroSuite extends SparkFunSuite { } test("Union of a single type") { - TestUtils.withTempDir { dir => + withTempPath { dir => val UnionOfOne = Schema.createUnion(List(Schema.create(Type.INT)).asJava) val fields = Seq(new Field("field1", UnionOfOne, "doc", null)).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) @@ -233,7 +226,7 @@ class AvroSuite extends SparkFunSuite { } test("Complex Union Type") { - TestUtils.withTempDir { dir => + withTempPath { dir => val fixedSchema = Schema.createFixed("fixed_name", "doc", "namespace", 4) val enumSchema = Schema.createEnum("enum_name", "doc", "namespace", List("e1", "e2").asJava) val complexUnionType = Schema.createUnion( @@ -271,7 +264,7 @@ class AvroSuite extends SparkFunSuite { } test("Lots of nulls") { - TestUtils.withTempDir { dir => + withTempPath { dir => val schema = StructType(Seq( StructField("binary", BinaryType, true), StructField("timestamp", TimestampType, true), @@ -290,7 +283,7 @@ class AvroSuite extends SparkFunSuite { } test("Struct field type") { - TestUtils.withTempDir { dir => + withTempPath { dir => val schema = StructType(Seq( StructField("float", FloatType, true), StructField("short", ShortType, true), @@ -309,7 +302,7 @@ class AvroSuite extends SparkFunSuite { } test("Date field type") { - TestUtils.withTempDir { dir => + withTempPath { dir => val schema = StructType(Seq( StructField("float", FloatType, true), StructField("date", DateType, true) @@ -329,7 +322,7 @@ class AvroSuite extends SparkFunSuite { } test("Array data types") { - TestUtils.withTempDir { dir => + withTempPath { dir => val testSchema = StructType(Seq( StructField("byte_array", ArrayType(ByteType), true), StructField("short_array", ArrayType(ShortType), true), @@ -363,13 +356,12 @@ class AvroSuite extends SparkFunSuite { } test("write with compression") { - TestUtils.withTempDir { dir => + withTempPath { dir => val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec" val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level" val uncompressDir = s"$dir/uncompress" val deflateDir = s"$dir/deflate" val snappyDir = s"$dir/snappy" - val fakeDir = s"$dir/fake" val df = spark.read.avro(testFile) spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed") @@ -439,7 +431,7 @@ class AvroSuite extends SparkFunSuite { test("sql test") { spark.sql( s""" - |CREATE TEMPORARY TABLE avroTable + |CREATE TEMPORARY VIEW avroTable |USING avro |OPTIONS (path "$episodesFile") """.stripMargin.replaceAll("\n", " ")) @@ -450,24 +442,24 @@ class AvroSuite extends SparkFunSuite { test("conversion to avro and back") { // Note that test.avro includes a variety of types, some of which are nullable. We expect to // get the same values back. - TestUtils.withTempDir { dir => + withTempPath { dir => val avroDir = s"$dir/avro" spark.read.avro(testFile).write.avro(avroDir) - TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir) + checkReloadMatchesSaved(testFile, avroDir) } } test("conversion to avro and back with namespace") { // Note that test.avro includes a variety of types, some of which are nullable. We expect to // get the same values back. - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val name = "AvroTest" val namespace = "com.databricks.spark.avro" val parameters = Map("recordName" -> name, "recordNamespace" -> namespace) val avroDir = tempDir + "/namedAvro" spark.read.avro(testFile).write.options(parameters).avro(avroDir) - TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir) + checkReloadMatchesSaved(testFile, avroDir) // Look at raw file and make sure has namespace info val rawSaved = spark.sparkContext.textFile(avroDir) @@ -478,7 +470,7 @@ class AvroSuite extends SparkFunSuite { } test("converting some specific sparkSQL types to avro") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val testSchema = StructType(Seq( StructField("Name", StringType, false), StructField("Length", IntegerType, true), @@ -520,7 +512,7 @@ class AvroSuite extends SparkFunSuite { } test("correctly read long as date/timestamp type") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val sparkSession = spark import sparkSession.implicits._ @@ -549,7 +541,7 @@ class AvroSuite extends SparkFunSuite { } test("does not coerce null date/timestamp value to 0 epoch.") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val sparkSession = spark import sparkSession.implicits._ @@ -610,7 +602,7 @@ class AvroSuite extends SparkFunSuite { // Directory given has no avro files intercept[AnalysisException] { - TestUtils.withTempDir(dir => spark.read.avro(dir.getCanonicalPath)) + withTempPath(dir => spark.read.avro(dir.getCanonicalPath)) } intercept[AnalysisException] { @@ -624,7 +616,7 @@ class AvroSuite extends SparkFunSuite { } intercept[FileNotFoundException] { - TestUtils.withTempDir { dir => + withTempPath { dir => FileUtils.touch(new File(dir, "test")) spark.read.avro(dir.toString) } @@ -633,19 +625,19 @@ class AvroSuite extends SparkFunSuite { } test("SQL test insert overwrite") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val tempEmptyDir = s"$tempDir/sqlOverwrite" // Create a temp directory for table that will be overwritten new File(tempEmptyDir).mkdirs() spark.sql( s""" - |CREATE TEMPORARY TABLE episodes + |CREATE TEMPORARY VIEW episodes |USING avro |OPTIONS (path "$episodesFile") """.stripMargin.replaceAll("\n", " ")) spark.sql( s""" - |CREATE TEMPORARY TABLE episodesEmpty + |CREATE TEMPORARY VIEW episodesEmpty |(name string, air_date string, doctor int) |USING avro |OPTIONS (path "$tempEmptyDir") @@ -665,7 +657,7 @@ class AvroSuite extends SparkFunSuite { test("test save and load") { // Test if load works as expected - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val df = spark.read.avro(episodesFile) assert(df.count == 8) @@ -679,7 +671,7 @@ class AvroSuite extends SparkFunSuite { test("test load with non-Avro file") { // Test if load works as expected - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val df = spark.read.avro(episodesFile) assert(df.count == 8) @@ -737,7 +729,7 @@ class AvroSuite extends SparkFunSuite { } test("read avro file partitioned") { - TestUtils.withTempDir { dir => + withTempPath { dir => val sparkSession = spark import sparkSession.implicits._ val df = (0 to 1024 * 3).toDS.map(i => s"record${i}").toDF("records") @@ -756,7 +748,7 @@ class AvroSuite extends SparkFunSuite { case class NestedTop(id: Int, data: NestedMiddle) test("saving avro that has nested records with the same name") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => // Save avro file on output folder path val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1"))))) val outputFolder = s"$tempDir/duplicate_names/" @@ -773,7 +765,7 @@ class AvroSuite extends SparkFunSuite { case class NestedTopArray(id: Int, data: NestedMiddleArray) test("saving avro that has nested records with the same name inside an array") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => // Save avro file on output folder path val writeDf = spark.createDataFrame( List(NestedTopArray(1, NestedMiddleArray(2, Array( @@ -794,7 +786,7 @@ class AvroSuite extends SparkFunSuite { case class NestedTopMap(id: Int, data: NestedMiddleMap) test("saving avro that has nested records with the same name inside a map") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => // Save avro file on output folder path val writeDf = spark.createDataFrame( List(NestedTopMap(1, NestedMiddleMap(2, Map( diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala deleted file mode 100755 index 4ae9b14d9ad0d..0000000000000 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.avro - -import java.io.{File, IOException} -import java.nio.ByteBuffer - -import scala.collection.immutable.HashSet -import scala.collection.mutable.ArrayBuffer -import scala.util.Random - -import com.google.common.io.Files -import java.util - -import org.apache.spark.sql.SparkSession - -private[avro] object TestUtils { - - /** - * This function checks that all records in a file match the original - * record. - */ - def checkReloadMatchesSaved(spark: SparkSession, testFile: String, avroDir: String): Unit = { - - def convertToString(elem: Any): String = { - elem match { - case null => "NULL" // HashSets can't have null in them, so we use a string instead - case arrayBuf: ArrayBuffer[_] => - arrayBuf.asInstanceOf[ArrayBuffer[Any]].toArray.deep.mkString(" ") - case arrayByte: Array[Byte] => arrayByte.deep.mkString(" ") - case other => other.toString - } - } - - val originalEntries = spark.read.avro(testFile).collect() - val newEntries = spark.read.avro(avroDir).collect() - - assert(originalEntries.length == newEntries.length) - - val origEntrySet = Array.fill(originalEntries(0).size)(new HashSet[Any]()) - for (origEntry <- originalEntries) { - var idx = 0 - for (origElement <- origEntry.toSeq) { - origEntrySet(idx) += convertToString(origElement) - idx += 1 - } - } - - for (newEntry <- newEntries) { - var idx = 0 - for (newElement <- newEntry.toSeq) { - assert(origEntrySet(idx).contains(convertToString(newElement))) - idx += 1 - } - } - } - - def withTempDir(f: File => Unit): Unit = { - val dir = Files.createTempDir() - dir.delete() - try f(dir) finally deleteRecursively(dir) - } - - /** - * This function deletes a file or a directory with everything that's in it. This function is - * copied from Spark with minor modifications made to it. See original source at: - * github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/util/Utils.scala - */ - - def deleteRecursively(file: File) { - def listFilesSafely(file: File): Seq[File] = { - if (file.exists()) { - val files = file.listFiles() - if (files == null) { - throw new IOException("Failed to list files for dir: " + file) - } - files - } else { - List() - } - } - - if (file != null) { - try { - if (file.isDirectory) { - var savedIOException: IOException = null - for (child <- listFilesSafely(file)) { - try { - deleteRecursively(child) - } catch { - // In case of multiple exceptions, only last one will be thrown - case ioe: IOException => savedIOException = ioe - } - } - if (savedIOException != null) { - throw savedIOException - } - } - } finally { - if (!file.delete()) { - // Delete can also fail if the file simply did not exist - if (file.exists()) { - throw new IOException("Failed to delete: " + file.getAbsolutePath) - } - } - } - } - } - - /** - * This function generates a random map(string, int) of a given size. - */ - private[avro] def generateRandomMap(rand: Random, size: Int): java.util.Map[String, Int] = { - val jMap = new util.HashMap[String, Int]() - for (i <- 0 until size) { - jMap.put(rand.nextString(5), i) - } - jMap - } - - /** - * This function generates a random array of booleans of a given size. - */ - private[avro] def generateRandomArray(rand: Random, size: Int): util.ArrayList[Boolean] = { - val vec = new util.ArrayList[Boolean]() - for (i <- 0 until size) { - vec.add(rand.nextBoolean()) - } - vec - } - - /** - * This function generates a random ByteBuffer of a given size. - */ - private[avro] def generateRandomByteBuffer(rand: Random, size: Int): ByteBuffer = { - val bb = ByteBuffer.allocate(size) - val arrayOfBytes = new Array[Byte](size) - rand.nextBytes(arrayOfBytes) - bb.put(arrayOfBytes) - } -} From 69993217fc4f5e5e41a297702389e86fe534dc2f Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 14 Jul 2018 22:07:49 -0700 Subject: [PATCH 1125/2461] [SPARK-24807][CORE] Adding files/jars twice: output a warning and add a note ## What changes were proposed in this pull request? In the PR, I propose to output an warning if the `addFile()` or `addJar()` methods are callled more than once for the same path. Currently, overwriting of already added files is not supported. New comments and warning are reflected the existing behaviour. Author: Maxim Gekk Closes #21771 from MaxGekk/warning-on-adding-file. --- R/pkg/R/context.R | 2 ++ .../main/scala/org/apache/spark/SparkContext.scala | 12 ++++++++++++ .../org/apache/spark/api/java/JavaSparkContext.scala | 6 ++++++ python/pyspark/context.py | 4 ++++ 4 files changed, 24 insertions(+) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 8ec727dd042bc..3e996a5ba26fc 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -305,6 +305,8 @@ setCheckpointDirSC <- function(sc, dirName) { #' Currently directories are only supported for Hadoop-supported filesystems. #' Refer Hadoop-supported filesystems at \url{https://wiki.apache.org/hadoop/HCFS}. #' +#' Note: A path can be added only once. Subsequent additions of the same path are ignored. +#' #' @rdname spark.addFile #' @param path The path of the file to be added #' @param recursive Whether to add files recursively from the path. Default is FALSE. diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 74bfb5d6d2ea3..531384ab57305 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1496,6 +1496,8 @@ class SparkContext(config: SparkConf) extends Logging { * @param path can be either a local file, a file in HDFS (or other Hadoop-supported * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, * use `SparkFiles.get(fileName)` to find its download location. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addFile(path: String): Unit = { addFile(path, false) @@ -1516,6 +1518,8 @@ class SparkContext(config: SparkConf) extends Logging { * use `SparkFiles.get(fileName)` to find its download location. * @param recursive if true, a directory can be given in `path`. Currently directories are * only supported for Hadoop-supported filesystems. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addFile(path: String, recursive: Boolean): Unit = { val uri = new Path(path).toUri @@ -1555,6 +1559,9 @@ class SparkContext(config: SparkConf) extends Logging { Utils.fetchFile(uri.toString, new File(SparkFiles.getRootDirectory()), conf, env.securityManager, hadoopConfiguration, timestamp, useCache = false) postEnvironmentUpdate() + } else { + logWarning(s"The path $path has been added already. Overwriting of added paths " + + "is not supported in the current version.") } } @@ -1803,6 +1810,8 @@ class SparkContext(config: SparkConf) extends Logging { * * @param path can be either a local file, a file in HDFS (or other Hadoop-supported filesystems), * an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addJar(path: String) { def addJarFile(file: File): String = { @@ -1849,6 +1858,9 @@ class SparkContext(config: SparkConf) extends Logging { if (addedJars.putIfAbsent(key, timestamp).isEmpty) { logInfo(s"Added JAR $path at $key with timestamp $timestamp") postEnvironmentUpdate() + } else { + logWarning(s"The jar $path has been added already. Overwriting of added jars " + + "is not supported in the current version.") } } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index f1936bf587282..09c83849e26b2 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -668,6 +668,8 @@ class JavaSparkContext(val sc: SparkContext) * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, * use `SparkFiles.get(fileName)` to find its download location. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addFile(path: String) { sc.addFile(path) @@ -681,6 +683,8 @@ class JavaSparkContext(val sc: SparkContext) * * A directory can be given if the recursive option is set to true. Currently directories are only * supported for Hadoop-supported filesystems. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addFile(path: String, recursive: Boolean): Unit = { sc.addFile(path, recursive) @@ -690,6 +694,8 @@ class JavaSparkContext(val sc: SparkContext) * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future. * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported * filesystems), or an HTTP, HTTPS or FTP URI. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addJar(path: String) { sc.addJar(path) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index ede3b6af0a8cf..2cb3117184334 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -847,6 +847,8 @@ def addFile(self, path, recursive=False): A directory can be given if the recursive option is set to True. Currently directories are only supported for Hadoop-supported filesystems. + .. note:: A path can be added only once. Subsequent additions of the same path are ignored. + >>> from pyspark import SparkFiles >>> path = os.path.join(tempdir, "test.txt") >>> with open(path, "w") as testFile: @@ -867,6 +869,8 @@ def addPyFile(self, path): SparkContext in the future. The C{path} passed can be either a local file, a file in HDFS (or other Hadoop-supported filesystems), or an HTTP, HTTPS or FTP URI. + + .. note:: A path can be added only once. Subsequent additions of the same path are ignored. """ self.addFile(path) (dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix From 96030876383822645a5b35698ee407a8d4eb76af Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sun, 15 Jul 2018 22:06:33 +0800 Subject: [PATCH 1126/2461] [SPARK-24800][SQL] Refactor Avro Serializer and Deserializer ## What changes were proposed in this pull request? Currently the Avro Deserializer converts input Avro format data to `Row`, and then convert the `Row` to `InternalRow`. While the Avro Serializer converts `InternalRow` to `Row`, and then output Avro format data. This PR allows direct conversion between `InternalRow` and Avro format data. ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21762 from gengliangwang/avro_io. --- .../spark/sql/avro/AvroDeserializer.scala | 348 ++++++++++++++++++ .../spark/sql/avro/AvroFileFormat.scala | 24 +- .../spark/sql/avro/AvroOutputWriter.scala | 109 +----- .../sql/avro/AvroOutputWriterFactory.scala | 5 +- .../spark/sql/avro/AvroSerializer.scala | 180 +++++++++ .../spark/sql/avro/SchemaConverters.scala | 333 ++--------------- .../spark/sql/avro/SerializableSchema.scala | 69 ++++ .../org/apache/spark/sql/avro/AvroSuite.scala | 1 - .../sql/avro/SerializableSchemaSuite.scala | 56 +++ 9 files changed, 704 insertions(+), 421 deletions(-) create mode 100644 external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala create mode 100644 external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala create mode 100644 external/avro/src/main/scala/org/apache/spark/sql/avro/SerializableSchema.scala create mode 100644 external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableSchemaSuite.scala diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala new file mode 100644 index 0000000000000..b31149a2c74c2 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.nio.ByteBuffer + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.avro.{Schema, SchemaBuilder} +import org.apache.avro.Schema.Type._ +import org.apache.avro.generic._ +import org.apache.avro.util.Utf8 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * A deserializer to deserialize data in avro format to data in catalyst format. + */ +class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) { + private val converter: Any => Any = rootCatalystType match { + // A shortcut for empty schema. + case st: StructType if st.isEmpty => + (data: Any) => InternalRow.empty + + case st: StructType => + val resultRow = new SpecificInternalRow(st.map(_.dataType)) + val fieldUpdater = new RowUpdater(resultRow) + val writer = getRecordWriter(rootAvroType, st, Nil) + (data: Any) => { + val record = data.asInstanceOf[GenericRecord] + writer(fieldUpdater, record) + resultRow + } + + case _ => + val tmpRow = new SpecificInternalRow(Seq(rootCatalystType)) + val fieldUpdater = new RowUpdater(tmpRow) + val writer = newWriter(rootAvroType, rootCatalystType, Nil) + (data: Any) => { + writer(fieldUpdater, 0, data) + tmpRow.get(0, rootCatalystType) + } + } + + def deserialize(data: Any): Any = converter(data) + + /** + * Creates a writer to write avro values to Catalyst values at the given ordinal with the given + * updater. + */ + private def newWriter( + avroType: Schema, + catalystType: DataType, + path: List[String]): (CatalystDataUpdater, Int, Any) => Unit = + (avroType.getType, catalystType) match { + case (NULL, NullType) => (updater, ordinal, _) => + updater.setNullAt(ordinal) + + // TODO: we can avoid boxing if future version of avro provide primitive accessors. + case (BOOLEAN, BooleanType) => (updater, ordinal, value) => + updater.setBoolean(ordinal, value.asInstanceOf[Boolean]) + + case (INT, IntegerType) => (updater, ordinal, value) => + updater.setInt(ordinal, value.asInstanceOf[Int]) + + case (LONG, LongType) => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long]) + + case (LONG, TimestampType) => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long] * 1000) + + case (LONG, DateType) => (updater, ordinal, value) => + updater.setInt(ordinal, (value.asInstanceOf[Long] / DateTimeUtils.MILLIS_PER_DAY).toInt) + + case (FLOAT, FloatType) => (updater, ordinal, value) => + updater.setFloat(ordinal, value.asInstanceOf[Float]) + + case (DOUBLE, DoubleType) => (updater, ordinal, value) => + updater.setDouble(ordinal, value.asInstanceOf[Double]) + + case (STRING, StringType) => (updater, ordinal, value) => + val str = value match { + case s: String => UTF8String.fromString(s) + case s: Utf8 => + val bytes = new Array[Byte](s.getByteLength) + System.arraycopy(s.getBytes, 0, bytes, 0, s.getByteLength) + UTF8String.fromBytes(bytes) + } + updater.set(ordinal, str) + + case (ENUM, StringType) => (updater, ordinal, value) => + updater.set(ordinal, UTF8String.fromString(value.toString)) + + case (FIXED, BinaryType) => (updater, ordinal, value) => + updater.set(ordinal, value.asInstanceOf[GenericFixed].bytes().clone()) + + case (BYTES, BinaryType) => (updater, ordinal, value) => + val bytes = value match { + case b: ByteBuffer => + val bytes = new Array[Byte](b.remaining) + b.get(bytes) + bytes + case b: Array[Byte] => b + case other => throw new RuntimeException(s"$other is not a valid avro binary.") + + } + updater.set(ordinal, bytes) + + case (RECORD, st: StructType) => + val writeRecord = getRecordWriter(avroType, st, path) + (updater, ordinal, value) => + val row = new SpecificInternalRow(st) + writeRecord(new RowUpdater(row), value.asInstanceOf[GenericRecord]) + updater.set(ordinal, row) + + case (ARRAY, ArrayType(elementType, containsNull)) => + val elementWriter = newWriter(avroType.getElementType, elementType, path) + (updater, ordinal, value) => + val array = value.asInstanceOf[GenericData.Array[Any]] + val len = array.size() + val result = createArrayData(elementType, len) + val elementUpdater = new ArrayDataUpdater(result) + + var i = 0 + while (i < len) { + val element = array.get(i) + if (element == null) { + if (!containsNull) { + throw new RuntimeException(s"Array value at path ${path.mkString(".")} is not " + + "allowed to be null") + } else { + elementUpdater.setNullAt(i) + } + } else { + elementWriter(elementUpdater, i, element) + } + i += 1 + } + + updater.set(ordinal, result) + + case (MAP, MapType(keyType, valueType, valueContainsNull)) if keyType == StringType => + val keyWriter = newWriter(SchemaBuilder.builder().stringType(), StringType, path) + val valueWriter = newWriter(avroType.getValueType, valueType, path) + (updater, ordinal, value) => + val map = value.asInstanceOf[java.util.Map[AnyRef, AnyRef]] + val keyArray = createArrayData(keyType, map.size()) + val keyUpdater = new ArrayDataUpdater(keyArray) + val valueArray = createArrayData(valueType, map.size()) + val valueUpdater = new ArrayDataUpdater(valueArray) + val iter = map.entrySet().iterator() + var i = 0 + while (iter.hasNext) { + val entry = iter.next() + assert(entry.getKey != null) + keyWriter(keyUpdater, i, entry.getKey) + if (entry.getValue == null) { + if (!valueContainsNull) { + throw new RuntimeException(s"Map value at path ${path.mkString(".")} is not " + + "allowed to be null") + } else { + valueUpdater.setNullAt(i) + } + } else { + valueWriter(valueUpdater, i, entry.getValue) + } + i += 1 + } + + updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) + + case (UNION, _) => + val allTypes = avroType.getTypes.asScala + val nonNullTypes = allTypes.filter(_.getType != NULL) + if (nonNullTypes.nonEmpty) { + if (nonNullTypes.length == 1) { + newWriter(nonNullTypes.head, catalystType, path) + } else { + nonNullTypes.map(_.getType) match { + case Seq(a, b) if Set(a, b) == Set(INT, LONG) && catalystType == LongType => + (updater, ordinal, value) => value match { + case null => updater.setNullAt(ordinal) + case l: java.lang.Long => updater.setLong(ordinal, l) + case i: java.lang.Integer => updater.setLong(ordinal, i.longValue()) + } + + case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && catalystType == DoubleType => + (updater, ordinal, value) => value match { + case null => updater.setNullAt(ordinal) + case d: java.lang.Double => updater.setDouble(ordinal, d) + case f: java.lang.Float => updater.setDouble(ordinal, f.doubleValue()) + } + + case _ => + catalystType match { + case st: StructType if st.length == nonNullTypes.size => + val fieldWriters = nonNullTypes.zip(st.fields).map { + case (schema, field) => newWriter(schema, field.dataType, path :+ field.name) + }.toArray + (updater, ordinal, value) => { + val row = new SpecificInternalRow(st) + val fieldUpdater = new RowUpdater(row) + val i = GenericData.get().resolveUnion(avroType, value) + fieldWriters(i)(fieldUpdater, i, value) + updater.set(ordinal, row) + } + + case _ => + throw new IncompatibleSchemaException( + s"Cannot convert Avro to catalyst because schema at path " + + s"${path.mkString(".")} is not compatible " + + s"(avroType = $avroType, sqlType = $catalystType).\n" + + s"Source Avro schema: $rootAvroType.\n" + + s"Target Catalyst type: $rootCatalystType") + } + } + } + } else { + (updater, ordinal, value) => updater.setNullAt(ordinal) + } + + case _ => + throw new IncompatibleSchemaException( + s"Cannot convert Avro to catalyst because schema at path ${path.mkString(".")} " + + s"is not compatible (avroType = $avroType, sqlType = $catalystType).\n" + + s"Source Avro schema: $rootAvroType.\n" + + s"Target Catalyst type: $rootCatalystType") + } + + private def getRecordWriter( + avroType: Schema, + sqlType: StructType, + path: List[String]): (CatalystDataUpdater, GenericRecord) => Unit = { + val validFieldIndexes = ArrayBuffer.empty[Int] + val fieldWriters = ArrayBuffer.empty[(CatalystDataUpdater, Any) => Unit] + + val length = sqlType.length + var i = 0 + while (i < length) { + val sqlField = sqlType.fields(i) + val avroField = avroType.getField(sqlField.name) + if (avroField != null) { + validFieldIndexes += avroField.pos() + + val baseWriter = newWriter(avroField.schema(), sqlField.dataType, path :+ sqlField.name) + val ordinal = i + val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => { + if (value == null) { + fieldUpdater.setNullAt(ordinal) + } else { + baseWriter(fieldUpdater, ordinal, value) + } + } + fieldWriters += fieldWriter + } else if (!sqlField.nullable) { + throw new IncompatibleSchemaException( + s""" + |Cannot find non-nullable field ${path.mkString(".")}.${sqlField.name} in Avro schema. + |Source Avro schema: $rootAvroType. + |Target Catalyst type: $rootCatalystType. + """.stripMargin) + } + i += 1 + } + + (fieldUpdater, record) => { + var i = 0 + while (i < validFieldIndexes.length) { + fieldWriters(i)(fieldUpdater, record.get(validFieldIndexes(i))) + i += 1 + } + } + } + + private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match { + case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length)) + case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length)) + case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length)) + case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length)) + case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length)) + case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length)) + case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length)) + case _ => new GenericArrayData(new Array[Any](length)) + } + + /** + * A base interface for updating values inside catalyst data structure like `InternalRow` and + * `ArrayData`. + */ + sealed trait CatalystDataUpdater { + def set(ordinal: Int, value: Any): Unit + + def setNullAt(ordinal: Int): Unit = set(ordinal, null) + def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value) + def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value) + def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value) + def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value) + def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value) + def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value) + def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value) + } + + final class RowUpdater(row: InternalRow) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value) + + override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal) + override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value) + } + + final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value) + + override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal) + override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value) + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index 46e5a189c5eb3..fb93033bb15d4 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -25,7 +25,7 @@ import scala.util.control.NonFatal import com.esotericsoftware.kryo.{Kryo, KryoSerializable} import com.esotericsoftware.kryo.io.{Input, Output} -import org.apache.avro.{Schema, SchemaBuilder} +import org.apache.avro.Schema import org.apache.avro.file.{DataFileConstants, DataFileReader} import org.apache.avro.generic.{GenericDatumReader, GenericRecord} import org.apache.avro.mapred.{AvroOutputFormat, FsInput} @@ -38,8 +38,6 @@ import org.slf4j.LoggerFactory import org.apache.spark.TaskContext import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile} import org.apache.spark.sql.sources.{DataSourceRegister, Filter} import org.apache.spark.sql.types.StructType @@ -118,8 +116,8 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { dataSchema: StructType): OutputWriterFactory = { val recordName = options.getOrElse("recordName", "topLevelRecord") val recordNamespace = options.getOrElse("recordNamespace", "") - val build = SchemaBuilder.record(recordName).namespace(recordNamespace) - val outputAvroSchema = SchemaConverters.convertStructToAvro(dataSchema, build, recordNamespace) + val outputAvroSchema = SchemaConverters.toAvroType( + dataSchema, nullable = false, recordName, recordNamespace) AvroJob.setOutputKeySchema(job, outputAvroSchema) val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec" @@ -148,7 +146,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { log.error(s"unsupported compression codec $unknown") } - new AvroOutputWriterFactory(dataSchema, recordName, recordNamespace) + new AvroOutputWriterFactory(dataSchema, new SerializableSchema(outputAvroSchema)) } override def buildReader( @@ -205,13 +203,10 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { reader.sync(file.start) val stop = file.start + file.length - val rowConverter = SchemaConverters.createConverterToSQL( - userProvidedSchema.getOrElse(reader.getSchema), requiredSchema) + val deserializer = + new AvroDeserializer(userProvidedSchema.getOrElse(reader.getSchema), requiredSchema) new Iterator[InternalRow] { - // Used to convert `Row`s containing data columns into `InternalRow`s. - private val encoderForDataColumns = RowEncoder(requiredSchema) - private[this] var completed = false override def hasNext: Boolean = { @@ -228,14 +223,11 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { } override def next(): InternalRow = { - if (reader.pastSync(stop)) { + if (!hasNext) { throw new NoSuchElementException("next on empty iterator") } val record = reader.next() - val safeDataRow = rowConverter(record).asInstanceOf[GenericRow] - - // The safeDataRow is reused, we must do a copy - encoderForDataColumns.toRow(safeDataRow) + deserializer.deserialize(record).asInstanceOf[InternalRow] } } } diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala index 830bf3c0570bf..06507115f5ed8 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala @@ -18,14 +18,8 @@ package org.apache.spark.sql.avro import java.io.{IOException, OutputStream} -import java.nio.ByteBuffer -import java.sql.{Date, Timestamp} -import java.util.HashMap -import scala.collection.immutable.Map - -import org.apache.avro.{Schema, SchemaBuilder} -import org.apache.avro.generic.GenericData.Record +import org.apache.avro.Schema import org.apache.avro.generic.GenericRecord import org.apache.avro.mapred.AvroKey import org.apache.avro.mapreduce.AvroKeyOutputFormat @@ -33,8 +27,7 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.io.NullWritable import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext} -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.OutputWriter import org.apache.spark.sql.types._ @@ -43,13 +36,10 @@ private[avro] class AvroOutputWriter( path: String, context: TaskAttemptContext, schema: StructType, - recordName: String, - recordNamespace: String) extends OutputWriter { + avroSchema: Schema) extends OutputWriter { - private lazy val converter = createConverterToAvro(schema, recordName, recordNamespace) - // copy of the old conversion logic after api change in SPARK-19085 - private lazy val internalRowConverter = - CatalystTypeConverters.createToScalaConverter(schema).asInstanceOf[InternalRow => Row] + // The input rows will never be null. + private lazy val serializer = new AvroSerializer(schema, avroSchema, nullable = false) /** * Overrides the couple of methods responsible for generating the output streams / files so @@ -70,95 +60,10 @@ private[avro] class AvroOutputWriter( }.getRecordWriter(context) - override def write(internalRow: InternalRow): Unit = { - val row = internalRowConverter(internalRow) - val key = new AvroKey(converter(row).asInstanceOf[GenericRecord]) + override def write(row: InternalRow): Unit = { + val key = new AvroKey(serializer.serialize(row).asInstanceOf[GenericRecord]) recordWriter.write(key, NullWritable.get()) } override def close(): Unit = recordWriter.close(context) - - /** - * This function constructs converter function for a given sparkSQL datatype. This is used in - * writing Avro records out to disk - */ - private def createConverterToAvro( - dataType: DataType, - structName: String, - recordNamespace: String): (Any) => Any = { - dataType match { - case BinaryType => (item: Any) => item match { - case null => null - case bytes: Array[Byte] => ByteBuffer.wrap(bytes) - } - case ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | StringType | BooleanType => identity - case _: DecimalType => (item: Any) => if (item == null) null else item.toString - case TimestampType => (item: Any) => - if (item == null) null else item.asInstanceOf[Timestamp].getTime - case DateType => (item: Any) => - if (item == null) null else item.asInstanceOf[Date].getTime - case ArrayType(elementType, _) => - val elementConverter = createConverterToAvro( - elementType, - structName, - SchemaConverters.getNewRecordNamespace(elementType, recordNamespace, structName)) - (item: Any) => { - if (item == null) { - null - } else { - val sourceArray = item.asInstanceOf[Seq[Any]] - val sourceArraySize = sourceArray.size - val targetArray = new Array[Any](sourceArraySize) - var idx = 0 - while (idx < sourceArraySize) { - targetArray(idx) = elementConverter(sourceArray(idx)) - idx += 1 - } - targetArray - } - } - case MapType(StringType, valueType, _) => - val valueConverter = createConverterToAvro( - valueType, - structName, - SchemaConverters.getNewRecordNamespace(valueType, recordNamespace, structName)) - (item: Any) => { - if (item == null) { - null - } else { - val javaMap = new HashMap[String, Any]() - item.asInstanceOf[Map[String, Any]].foreach { case (key, value) => - javaMap.put(key, valueConverter(value)) - } - javaMap - } - } - case structType: StructType => - val builder = SchemaBuilder.record(structName).namespace(recordNamespace) - val schema: Schema = SchemaConverters.convertStructToAvro( - structType, builder, recordNamespace) - val fieldConverters = structType.fields.map(field => - createConverterToAvro( - field.dataType, - field.name, - SchemaConverters.getNewRecordNamespace(field.dataType, recordNamespace, field.name))) - (item: Any) => { - if (item == null) { - null - } else { - val record = new Record(schema) - val convertersIterator = fieldConverters.iterator - val fieldNamesIterator = dataType.asInstanceOf[StructType].fieldNames.iterator - val rowIterator = item.asInstanceOf[Row].toSeq.iterator - - while (convertersIterator.hasNext) { - val converter = convertersIterator.next() - record.put(fieldNamesIterator.next(), converter(rowIterator.next())) - } - record - } - } - } - } } diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala index 5b2ce7d7d8e0f..18a6d93951408 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala @@ -24,8 +24,7 @@ import org.apache.spark.sql.types.StructType private[avro] class AvroOutputWriterFactory( schema: StructType, - recordName: String, - recordNamespace: String) extends OutputWriterFactory { + avroSchema: SerializableSchema) extends OutputWriterFactory { override def getFileExtension(context: TaskAttemptContext): String = ".avro" @@ -33,6 +32,6 @@ private[avro] class AvroOutputWriterFactory( path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new AvroOutputWriter(path, context, schema, recordName, recordNamespace) + new AvroOutputWriter(path, context, schema, avroSchema.value) } } diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala new file mode 100644 index 0000000000000..2b4c5813a535b --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.nio.ByteBuffer + +import scala.collection.JavaConverters._ + +import org.apache.avro.Schema +import org.apache.avro.Schema.Type.NULL +import org.apache.avro.generic.GenericData.Record +import org.apache.avro.util.Utf8 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ + +/** + * A serializer to serialize data in catalyst format to data in avro format. + */ +class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean) { + + def serialize(catalystData: Any): Any = { + converter.apply(catalystData) + } + + private val converter: Any => Any = { + val actualAvroType = resolveNullableType(rootAvroType, nullable) + val baseConverter = rootCatalystType match { + case st: StructType => + newStructConverter(st, actualAvroType).asInstanceOf[Any => Any] + case _ => + val tmpRow = new SpecificInternalRow(Seq(rootCatalystType)) + val converter = newConverter(rootCatalystType, actualAvroType) + (data: Any) => + tmpRow.update(0, data) + converter.apply(tmpRow, 0) + } + if (nullable) { + (data: Any) => + if (data == null) { + null + } else { + baseConverter.apply(data) + } + } else { + baseConverter + } + } + + private type Converter = (SpecializedGetters, Int) => Any + + private def newConverter(catalystType: DataType, avroType: Schema): Converter = { + catalystType match { + case NullType => + (getter, ordinal) => null + case BooleanType => + (getter, ordinal) => getter.getBoolean(ordinal) + case ByteType => + (getter, ordinal) => getter.getByte(ordinal).toInt + case ShortType => + (getter, ordinal) => getter.getShort(ordinal).toInt + case IntegerType => + (getter, ordinal) => getter.getInt(ordinal) + case LongType => + (getter, ordinal) => getter.getLong(ordinal) + case FloatType => + (getter, ordinal) => getter.getFloat(ordinal) + case DoubleType => + (getter, ordinal) => getter.getDouble(ordinal) + case d: DecimalType => + (getter, ordinal) => getter.getDecimal(ordinal, d.precision, d.scale).toString + case StringType => + (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes) + case BinaryType => + (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal)) + case DateType => + (getter, ordinal) => getter.getInt(ordinal) * DateTimeUtils.MILLIS_PER_DAY + case TimestampType => + (getter, ordinal) => getter.getLong(ordinal) / 1000 + + case ArrayType(et, containsNull) => + val elementConverter = newConverter( + et, resolveNullableType(avroType.getElementType, containsNull)) + (getter, ordinal) => { + val arrayData = getter.getArray(ordinal) + val result = new java.util.ArrayList[Any] + var i = 0 + while (i < arrayData.numElements()) { + if (arrayData.isNullAt(i)) { + result.add(null) + } else { + result.add(elementConverter(arrayData, i)) + } + i += 1 + } + result + } + + case st: StructType => + val structConverter = newStructConverter(st, avroType) + val numFields = st.length + (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields)) + + case MapType(kt, vt, valueContainsNull) if kt == StringType => + val valueConverter = newConverter( + vt, resolveNullableType(avroType.getValueType, valueContainsNull)) + (getter, ordinal) => + val mapData = getter.getMap(ordinal) + val result = new java.util.HashMap[String, Any](mapData.numElements()) + val keyArray = mapData.keyArray() + val valueArray = mapData.valueArray() + var i = 0 + while (i < mapData.numElements()) { + val key = keyArray.getUTF8String(i).toString + if (valueArray.isNullAt(i)) { + result.put(key, null) + } else { + result.put(key, valueConverter(valueArray, i)) + } + i += 1 + } + result + + case other => + throw new IncompatibleSchemaException(s"Unexpected type: $other") + } + } + + private def newStructConverter( + catalystStruct: StructType, avroStruct: Schema): InternalRow => Record = { + val avroFields = avroStruct.getFields + assert(avroFields.size() == catalystStruct.length) + val fieldConverters = catalystStruct.zip(avroFields.asScala).map { + case (f1, f2) => newConverter(f1.dataType, resolveNullableType(f2.schema(), f1.nullable)) + } + val numFields = catalystStruct.length + (row: InternalRow) => + val result = new Record(avroStruct) + var i = 0 + while (i < numFields) { + if (row.isNullAt(i)) { + result.put(i, null) + } else { + result.put(i, fieldConverters(i).apply(row, i)) + } + i += 1 + } + result + } + + private def resolveNullableType(avroType: Schema, nullable: Boolean): Schema = { + if (nullable) { + // avro uses union to represent nullable type. + val fields = avroType.getTypes.asScala + assert(fields.length == 2) + val actualType = fields.filter(_.getType != NULL) + assert(actualType.length == 1) + actualType.head + } else { + avroType + } + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala index 01f8c74982535..87fae63aeff2b 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -17,18 +17,11 @@ package org.apache.spark.sql.avro -import java.nio.ByteBuffer -import java.sql.{Date, Timestamp} - import scala.collection.JavaConverters._ import org.apache.avro.{Schema, SchemaBuilder} import org.apache.avro.Schema.Type._ -import org.apache.avro.SchemaBuilder._ -import org.apache.avro.generic.{GenericData, GenericRecord} -import org.apache.avro.generic.GenericFixed -import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types._ /** @@ -36,9 +29,6 @@ import org.apache.spark.sql.types._ * versa. */ object SchemaConverters { - - class IncompatibleSchemaException(msg: String, ex: Throwable = null) extends Exception(msg, ex) - case class SchemaType(dataType: DataType, nullable: Boolean) /** @@ -109,298 +99,43 @@ object SchemaConverters { } } - /** - * This function converts sparkSQL StructType into avro schema. This method uses two other - * converter methods in order to do the conversion. - */ - def convertStructToAvro[T]( - structType: StructType, - schemaBuilder: RecordBuilder[T], - recordNamespace: String): T = { - val fieldsAssembler: FieldAssembler[T] = schemaBuilder.fields() - structType.fields.foreach { field => - val newField = fieldsAssembler.name(field.name).`type`() - - if (field.nullable) { - convertFieldTypeToAvro(field.dataType, newField.nullable(), field.name, recordNamespace) - .noDefault - } else { - convertFieldTypeToAvro(field.dataType, newField, field.name, recordNamespace) - .noDefault - } - } - fieldsAssembler.endRecord() - } - - /** - * Returns a converter function to convert row in avro format to GenericRow of catalyst. - * - * @param sourceAvroSchema Source schema before conversion inferred from avro file by passed in - * by user. - * @param targetSqlType Target catalyst sql type after the conversion. - * @return returns a converter function to convert row in avro format to GenericRow of catalyst. - */ - private[avro] def createConverterToSQL( - sourceAvroSchema: Schema, - targetSqlType: DataType): AnyRef => AnyRef = { - - def createConverter(avroSchema: Schema, - sqlType: DataType, path: List[String]): AnyRef => AnyRef = { - val avroType = avroSchema.getType - (sqlType, avroType) match { - // Avro strings are in Utf8, so we have to call toString on them - case (StringType, STRING) | (StringType, ENUM) => - (item: AnyRef) => item.toString - // Byte arrays are reused by avro, so we have to make a copy of them. - case (IntegerType, INT) | (BooleanType, BOOLEAN) | (DoubleType, DOUBLE) | - (FloatType, FLOAT) | (LongType, LONG) => - identity - case (TimestampType, LONG) => - (item: AnyRef) => new Timestamp(item.asInstanceOf[Long]) - case (DateType, LONG) => - (item: AnyRef) => new Date(item.asInstanceOf[Long]) - case (BinaryType, FIXED) => - (item: AnyRef) => item.asInstanceOf[GenericFixed].bytes().clone() - case (BinaryType, BYTES) => - (item: AnyRef) => - val byteBuffer = item.asInstanceOf[ByteBuffer] - val bytes = new Array[Byte](byteBuffer.remaining) - byteBuffer.get(bytes) - bytes - case (struct: StructType, RECORD) => - val length = struct.fields.length - val converters = new Array[AnyRef => AnyRef](length) - val avroFieldIndexes = new Array[Int](length) - var i = 0 - while (i < length) { - val sqlField = struct.fields(i) - val avroField = avroSchema.getField(sqlField.name) - if (avroField != null) { - val converter = (item: AnyRef) => { - if (item == null) { - item - } else { - createConverter(avroField.schema, sqlField.dataType, path :+ sqlField.name)(item) - } - } - converters(i) = converter - avroFieldIndexes(i) = avroField.pos() - } else if (!sqlField.nullable) { - throw new IncompatibleSchemaException( - s"Cannot find non-nullable field ${sqlField.name} at path ${path.mkString(".")} " + - "in Avro schema\n" + - s"Source Avro schema: $sourceAvroSchema.\n" + - s"Target Catalyst type: $targetSqlType") - } - i += 1 - } - - (item: AnyRef) => - val record = item.asInstanceOf[GenericRecord] - val result = new Array[Any](length) - var i = 0 - while (i < converters.length) { - if (converters(i) != null) { - val converter = converters(i) - result(i) = converter(record.get(avroFieldIndexes(i))) - } - i += 1 - } - new GenericRow(result) - case (arrayType: ArrayType, ARRAY) => - val elementConverter = createConverter(avroSchema.getElementType, arrayType.elementType, - path) - val allowsNull = arrayType.containsNull - (item: AnyRef) => - item.asInstanceOf[java.lang.Iterable[AnyRef]].asScala.map { element => - if (element == null && !allowsNull) { - throw new RuntimeException(s"Array value at path ${path.mkString(".")} is not " + - "allowed to be null") - } else { - elementConverter(element) - } - } - case (mapType: MapType, MAP) if mapType.keyType == StringType => - val valueConverter = createConverter(avroSchema.getValueType, mapType.valueType, path) - val allowsNull = mapType.valueContainsNull - (item: AnyRef) => - item.asInstanceOf[java.util.Map[AnyRef, AnyRef]].asScala.map { case (k, v) => - if (v == null && !allowsNull) { - throw new RuntimeException(s"Map value at path ${path.mkString(".")} is not " + - "allowed to be null") - } else { - (k.toString, valueConverter(v)) - } - }.toMap - case (sqlType, UNION) => - if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { - val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL) - if (remainingUnionTypes.size == 1) { - createConverter(remainingUnionTypes.head, sqlType, path) - } else { - createConverter(Schema.createUnion(remainingUnionTypes.asJava), sqlType, path) - } - } else avroSchema.getTypes.asScala.map(_.getType) match { - case Seq(t1) => createConverter(avroSchema.getTypes.get(0), sqlType, path) - case Seq(a, b) if Set(a, b) == Set(INT, LONG) && sqlType == LongType => - (item: AnyRef) => - item match { - case l: java.lang.Long => l - case i: java.lang.Integer => new java.lang.Long(i.longValue()) - } - case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && sqlType == DoubleType => - (item: AnyRef) => - item match { - case d: java.lang.Double => d - case f: java.lang.Float => new java.lang.Double(f.doubleValue()) - } - case other => - sqlType match { - case t: StructType if t.fields.length == avroSchema.getTypes.size => - val fieldConverters = t.fields.zip(avroSchema.getTypes.asScala).map { - case (field, schema) => - createConverter(schema, field.dataType, path :+ field.name) - } - (item: AnyRef) => - val i = GenericData.get().resolveUnion(avroSchema, item) - val converted = new Array[Any](fieldConverters.length) - converted(i) = fieldConverters(i)(item) - new GenericRow(converted) - case _ => throw new IncompatibleSchemaException( - s"Cannot convert Avro schema to catalyst type because schema at path " + - s"${path.mkString(".")} is not compatible " + - s"(avroType = $other, sqlType = $sqlType). \n" + - s"Source Avro schema: $sourceAvroSchema.\n" + - s"Target Catalyst type: $targetSqlType") - } - } - case (left, right) => - throw new IncompatibleSchemaException( - s"Cannot convert Avro schema to catalyst type because schema at path " + - s"${path.mkString(".")} is not compatible (avroType = $right, sqlType = $left). \n" + - s"Source Avro schema: $sourceAvroSchema.\n" + - s"Target Catalyst type: $targetSqlType") - } - } - createConverter(sourceAvroSchema, targetSqlType, List.empty[String]) - } - - /** - * This function is used to convert some sparkSQL type to avro type. Note that this function won't - * be used to construct fields of avro record (convertFieldTypeToAvro is used for that). - */ - private def convertTypeToAvro[T]( - dataType: DataType, - schemaBuilder: BaseTypeBuilder[T], - structName: String, - recordNamespace: String): T = { - dataType match { - case ByteType => schemaBuilder.intType() - case ShortType => schemaBuilder.intType() - case IntegerType => schemaBuilder.intType() - case LongType => schemaBuilder.longType() - case FloatType => schemaBuilder.floatType() - case DoubleType => schemaBuilder.doubleType() - case _: DecimalType => schemaBuilder.stringType() - case StringType => schemaBuilder.stringType() - case BinaryType => schemaBuilder.bytesType() - case BooleanType => schemaBuilder.booleanType() - case TimestampType => schemaBuilder.longType() - case DateType => schemaBuilder.longType() - - case ArrayType(elementType, _) => - val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull) - val elementSchema = convertTypeToAvro(elementType, builder, structName, recordNamespace) - schemaBuilder.array().items(elementSchema) - - case MapType(StringType, valueType, _) => - val builder = getSchemaBuilder(dataType.asInstanceOf[MapType].valueContainsNull) - val valueSchema = convertTypeToAvro(valueType, builder, structName, recordNamespace) - schemaBuilder.map().values(valueSchema) - - case structType: StructType => - convertStructToAvro( - structType, - schemaBuilder.record(structName).namespace(recordNamespace), - recordNamespace) - - case other => throw new IncompatibleSchemaException(s"Unexpected type $dataType.") - } - } - - /** - * This function is used to construct fields of the avro record, where schema of the field is - * specified by avro representation of dataType. Since builders for record fields are different - * from those for everything else, we have to use a separate method. - */ - private def convertFieldTypeToAvro[T]( - dataType: DataType, - newFieldBuilder: BaseFieldTypeBuilder[T], - structName: String, - recordNamespace: String): FieldDefault[T, _] = { - dataType match { - case ByteType => newFieldBuilder.intType() - case ShortType => newFieldBuilder.intType() - case IntegerType => newFieldBuilder.intType() - case LongType => newFieldBuilder.longType() - case FloatType => newFieldBuilder.floatType() - case DoubleType => newFieldBuilder.doubleType() - case _: DecimalType => newFieldBuilder.stringType() - case StringType => newFieldBuilder.stringType() - case BinaryType => newFieldBuilder.bytesType() - case BooleanType => newFieldBuilder.booleanType() - case TimestampType => newFieldBuilder.longType() - case DateType => newFieldBuilder.longType() - - case ArrayType(elementType, _) => - val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull) - val elementSchema = convertTypeToAvro( - elementType, - builder, - structName, - getNewRecordNamespace(elementType, recordNamespace, structName)) - newFieldBuilder.array().items(elementSchema) - - case MapType(StringType, valueType, _) => - val builder = getSchemaBuilder(dataType.asInstanceOf[MapType].valueContainsNull) - val valueSchema = convertTypeToAvro( - valueType, - builder, - structName, - getNewRecordNamespace(valueType, recordNamespace, structName)) - newFieldBuilder.map().values(valueSchema) - - case structType: StructType => - convertStructToAvro( - structType, - newFieldBuilder.record(structName).namespace(s"$recordNamespace.$structName"), - s"$recordNamespace.$structName") - - case other => throw new IncompatibleSchemaException(s"Unexpected type $dataType.") - } - } - - /** - * Returns a new namespace depending on the data type of the element. - * If the data type is a StructType it returns the current namespace concatenated - * with the element name, otherwise it returns the current namespace as it is. - */ - private[avro] def getNewRecordNamespace( - elementDataType: DataType, - currentRecordNamespace: String, - elementName: String): String = { - - elementDataType match { - case StructType(_) => s"$currentRecordNamespace.$elementName" - case _ => currentRecordNamespace - } - } - - private def getSchemaBuilder(isNullable: Boolean): BaseTypeBuilder[Schema] = { - if (isNullable) { + def toAvroType( + catalystType: DataType, + nullable: Boolean = false, + recordName: String = "topLevelRecord", + prevNameSpace: String = ""): Schema = { + val builder = if (nullable) { SchemaBuilder.builder().nullable() } else { SchemaBuilder.builder() } + catalystType match { + case BooleanType => builder.booleanType() + case ByteType | ShortType | IntegerType => builder.intType() + case LongType => builder.longType() + case DateType => builder.longType() + case TimestampType => builder.longType() + case FloatType => builder.floatType() + case DoubleType => builder.doubleType() + case _: DecimalType | StringType => builder.stringType() + case BinaryType => builder.bytesType() + case ArrayType(et, containsNull) => + builder.array().items(toAvroType(et, containsNull, recordName, prevNameSpace)) + case MapType(StringType, vt, valueContainsNull) => + builder.map().values(toAvroType(vt, valueContainsNull, recordName, prevNameSpace)) + case st: StructType => + val nameSpace = s"$prevNameSpace.$recordName" + val fieldsAssembler = builder.record(recordName).namespace(nameSpace).fields() + st.foreach { f => + val fieldAvroType = toAvroType(f.dataType, f.nullable, f.name, nameSpace) + fieldsAssembler.name(f.name).`type`(fieldAvroType).noDefault() + } + fieldsAssembler.endRecord() + + // This should never happen. + case other => throw new IncompatibleSchemaException(s"Unexpected type $other.") + } } } + +class IncompatibleSchemaException(msg: String, ex: Throwable = null) extends Exception(msg, ex) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SerializableSchema.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SerializableSchema.scala new file mode 100644 index 0000000000000..ec0ddc778c8f6 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/SerializableSchema.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io._ + +import scala.util.control.NonFatal + +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.avro.Schema +import org.slf4j.LoggerFactory + +class SerializableSchema(@transient var value: Schema) + extends Serializable with KryoSerializable { + + @transient private[avro] lazy val log = LoggerFactory.getLogger(getClass) + + private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException { + out.defaultWriteObject() + out.writeUTF(value.toString()) + out.flush() + } + + private def readObject(in: ObjectInputStream): Unit = tryOrIOException { + val json = in.readUTF() + value = new Schema.Parser().parse(json) + } + + private def tryOrIOException[T](block: => T): T = { + try { + block + } catch { + case e: IOException => + log.error("Exception encountered", e) + throw e + case NonFatal(e) => + log.error("Exception encountered", e) + throw new IOException(e) + } + } + + def write(kryo: Kryo, out: Output): Unit = { + val dos = new DataOutputStream(out) + dos.writeUTF(value.toString()) + dos.flush() + } + + def read(kryo: Kryo, in: Input): Unit = { + val dis = new DataInputStream(in) + val json = dis.readUTF() + value = new Schema.Parser().parse(json) + } +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 4f94d827e3127..6ed66563b987a 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -32,7 +32,6 @@ import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} import org.apache.commons.io.FileUtils import org.apache.spark.sql._ -import org.apache.spark.sql.avro.SchemaConverters.IncompatibleSchemaException import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableSchemaSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableSchemaSuite.scala new file mode 100644 index 0000000000000..510bcbdd31929 --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableSchemaSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.avro.Schema + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerInstance} + +class SerializableSchemaSuite extends SparkFunSuite { + + private def testSerialization(serializer: SerializerInstance): Unit = { + val avroTypeJson = + s""" + |{ + | "type": "string", + | "name": "my_string" + |} + """.stripMargin + val avroSchema = new Schema.Parser().parse(avroTypeJson) + val serializableSchema = new SerializableSchema(avroSchema) + val serialized = serializer.serialize(serializableSchema) + + serializer.deserialize[Any](serialized) match { + case c: SerializableSchema => + assert(c.log != null, "log was null") + assert(c.value != null, "value was null") + assert(c.value == avroSchema) + case other => fail( + s"Expecting ${classOf[SerializableSchema]}, but got ${other.getClass}.") + } + } + + test("serialization with JavaSerializer") { + testSerialization(new JavaSerializer(new SparkConf()).newInstance()) + } + + test("serialization with KryoSerializer") { + testSerialization(new KryoSerializer(new SparkConf()).newInstance()) + } +} From 5d62a985dca9280f884e13e29fc7166ef13c459f Mon Sep 17 00:00:00 2001 From: "Zoltan C. Toth" Date: Sun, 15 Jul 2018 17:08:26 -0500 Subject: [PATCH 1127/2461] Doc fix: The Imputer is an Estimator Fixing the doc as the imputer is not a `Transformer` but an `Estimator`. https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala#L96-L97 ## What changes were proposed in this pull request? Simple documentation fix ## How was this patch tested? manual testing Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Zoltan C. Toth Closes #21755 from zoltanctoth/doc-imputer-is-estimator. --- docs/ml-features.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 7aed2341584fc..ad6e718b37f1b 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1429,7 +1429,7 @@ for more details on the API. ## Imputer -The `Imputer` transformer completes missing values in a dataset, either using the mean or the +The `Imputer` estimator completes missing values in a dataset, either using the mean or the median of the columns in which the missing values are located. The input columns should be of `DoubleType` or `FloatType`. Currently `Imputer` does not support categorical features and possibly creates incorrect values for columns containing categorical features. Imputer can impute custom values From bbc2ffc8ab27192384def9847c36b873efd87234 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 16 Jul 2018 09:29:51 +0800 Subject: [PATCH 1128/2461] [SPARK-24813][TESTS][HIVE][HOTFIX] HiveExternalCatalogVersionsSuite still flaky; fall back to Apache archive ## What changes were proposed in this pull request? Try only unique ASF mirrors to download Spark release; fall back to Apache archive if no mirrors available or release is not mirrored ## How was this patch tested? Existing HiveExternalCatalogVersionsSuite Author: Sean Owen Closes #21776 from srowen/SPARK-24813. --- .../HiveExternalCatalogVersionsSuite.scala | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 514921875f1f9..f8212684d5335 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -56,14 +56,21 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { } private def tryDownloadSpark(version: String, path: String): Unit = { - // Try mirrors a few times until one succeeds - for (i <- 0 until 3) { - // we don't retry on a failure to get mirror url. If we can't get a mirror url, - // the test fails (getStringFromUrl will throw an exception) - val preferredMirror = - getStringFromUrl("https://www.apache.org/dyn/closer.lua?preferred=true") + // Try a few mirrors first; fall back to Apache archive + val mirrors = + (0 until 2).flatMap { _ => + try { + Some(getStringFromUrl("https://www.apache.org/dyn/closer.lua?preferred=true")) + } catch { + // If we can't get a mirror URL, skip it. No retry. + case _: Exception => None + } + } + val sites = mirrors.distinct :+ "https://archive.apache.org/dist" + logInfo(s"Trying to download Spark $version from $sites") + for (site <- sites) { val filename = s"spark-$version-bin-hadoop2.7.tgz" - val url = s"$preferredMirror/spark/spark-$version/$filename" + val url = s"$site/spark/spark-$version/$filename" logInfo(s"Downloading Spark $version from $url") try { getFileFromUrl(url, path, filename) @@ -83,7 +90,8 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { Seq("rm", "-rf", targetDir).! } } catch { - case ex: Exception => logWarning(s"Failed to download Spark $version from $url", ex) + case ex: Exception => + logWarning(s"Failed to download Spark $version from $url: ${ex.getMessage}") } } fail(s"Unable to download Spark $version") From bcf7121ed2283d88424863ac1d35393870eaae6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E7=91=9E=E5=B3=B0?= Date: Sun, 15 Jul 2018 20:14:17 -0700 Subject: [PATCH 1129/2461] [TRIVIAL][ML] GMM unpersist RDD after training MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? unpersist `instances` after training ## How was this patch tested? existing tests Author: 郑瑞峰 Closes #21562 from zhengruifeng/gmm_unpersist. --- .../scala/org/apache/spark/ml/clustering/GaussianMixture.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index dae64ba9a515d..f0707b380c673 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -341,7 +341,7 @@ class GaussianMixture @Since("2.0.0") ( val sc = dataset.sparkSession.sparkContext val numClusters = $(k) - val instances: RDD[Vector] = dataset + val instances = dataset .select(DatasetUtils.columnToVector(dataset, getFeaturesCol)).rdd.map { case Row(features: Vector) => features }.cache() @@ -416,6 +416,7 @@ class GaussianMixture @Since("2.0.0") ( iter += 1 } + instances.unpersist(false) val gaussianDists = gaussians.map { case (mean, covVec) => val cov = GaussianMixture.unpackUpperTriangularMatrix(numFeatures, covVec.values) new MultivariateGaussian(mean, cov) From d463533ded89a05e9f77e590fd3de2ffa212d68b Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 15 Jul 2018 20:22:09 -0700 Subject: [PATCH 1130/2461] [SPARK-24676][SQL] Project required data from CSV parsed data when column pruning disabled ## What changes were proposed in this pull request? This pr modified code to project required data from CSV parsed data when column pruning disabled. In the current master, an exception below happens if `spark.sql.csv.parser.columnPruning.enabled` is false. This is because required formats and CSV parsed formats are different from each other; ``` ./bin/spark-shell --conf spark.sql.csv.parser.columnPruning.enabled=false scala> val dir = "/tmp/spark-csv/csv" scala> spark.range(10).selectExpr("id % 2 AS p", "id").write.mode("overwrite").partitionBy("p").csv(dir) scala> spark.read.csv(dir).selectExpr("sum(p)").collect() 18/06/25 13:48:46 ERROR Executor: Exception in task 2.0 in stage 2.0 (TID 7) java.lang.ClassCastException: org.apache.spark.unsafe.types.UTF8String cannot be cast to java.lang.Integer at scala.runtime.BoxesRunTime.unboxToInt(BoxesRunTime.java:101) at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow$class.getInt(rows.scala:41) ... ``` ## How was this patch tested? Added tests in `CSVSuite`. Author: Takeshi Yamamuro Closes #21657 from maropu/SPARK-24676. --- .../datasources/csv/UnivocityParser.scala | 54 ++++++++++++++----- .../execution/datasources/csv/CSVSuite.scala | 29 ++++++++++ 2 files changed, 70 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index aa545e1a0c00a..79143cce4a380 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -33,29 +33,49 @@ import org.apache.spark.sql.execution.datasources.FailureSafeParser import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String + +/** + * Constructs a parser for a given schema that translates CSV data to an [[InternalRow]]. + * + * @param dataSchema The CSV data schema that is specified by the user, or inferred from underlying + * data files. + * @param requiredSchema The schema of the data that should be output for each row. This should be a + * subset of the columns in dataSchema. + * @param options Configuration options for a CSV parser. + */ class UnivocityParser( dataSchema: StructType, requiredSchema: StructType, val options: CSVOptions) extends Logging { require(requiredSchema.toSet.subsetOf(dataSchema.toSet), - "requiredSchema should be the subset of schema.") + s"requiredSchema (${requiredSchema.catalogString}) should be the subset of " + + s"dataSchema (${dataSchema.catalogString}).") def this(schema: StructType, options: CSVOptions) = this(schema, schema, options) // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any + // This index is used to reorder parsed tokens + private val tokenIndexArr = + requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))).toArray + + // When column pruning is enabled, the parser only parses the required columns based on + // their positions in the data schema. + private val parsedSchema = if (options.columnPruning) requiredSchema else dataSchema + val tokenizer = { val parserSetting = options.asParserSettings - if (options.columnPruning && requiredSchema.length < dataSchema.length) { - val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))) + // When to-be-parsed schema is shorter than the to-be-read data schema, we let Univocity CSV + // parser select a sequence of fields for reading by their positions. + // if (options.columnPruning && requiredSchema.length < dataSchema.length) { + if (parsedSchema.length < dataSchema.length) { parserSetting.selectIndexes(tokenIndexArr: _*) } new CsvParser(parserSetting) } - private val schema = if (options.columnPruning) requiredSchema else dataSchema - private val row = new GenericInternalRow(schema.length) + private val row = new GenericInternalRow(requiredSchema.length) // Retrieve the raw record string. private def getCurrentInput: UTF8String = { @@ -82,7 +102,7 @@ class UnivocityParser( // // output row - ["A", 2] private val valueConverters: Array[ValueConverter] = { - schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray + requiredSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray } /** @@ -183,7 +203,7 @@ class UnivocityParser( } } - private val doParse = if (schema.nonEmpty) { + private val doParse = if (requiredSchema.nonEmpty) { (input: String) => convert(tokenizer.parseLine(input)) } else { // If `columnPruning` enabled and partition attributes scanned only, @@ -197,15 +217,21 @@ class UnivocityParser( */ def parse(input: String): InternalRow = doParse(input) + private val getToken = if (options.columnPruning) { + (tokens: Array[String], index: Int) => tokens(index) + } else { + (tokens: Array[String], index: Int) => tokens(tokenIndexArr(index)) + } + private def convert(tokens: Array[String]): InternalRow = { - if (tokens.length != schema.length) { + if (tokens.length != parsedSchema.length) { // If the number of tokens doesn't match the schema, we should treat it as a malformed record. // However, we still have chance to parse some of the tokens, by adding extra null tokens in // the tail if the number is smaller, or by dropping extra tokens if the number is larger. - val checkedTokens = if (schema.length > tokens.length) { - tokens ++ new Array[String](schema.length - tokens.length) + val checkedTokens = if (parsedSchema.length > tokens.length) { + tokens ++ new Array[String](parsedSchema.length - tokens.length) } else { - tokens.take(schema.length) + tokens.take(parsedSchema.length) } def getPartialResult(): Option[InternalRow] = { try { @@ -222,9 +248,11 @@ class UnivocityParser( new RuntimeException("Malformed CSV record")) } else { try { + // When the length of the returned tokens is identical to the length of the parsed schema, + // we just need to convert the tokens that correspond to the required columns. var i = 0 - while (i < schema.length) { - row(i) = valueConverters(i).apply(tokens(i)) + while (i < requiredSchema.length) { + row(i) = valueConverters(i).apply(getToken(tokens, i)) i += 1 } row diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 84b91f6309fe8..ae8110fdf1709 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1579,4 +1579,33 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } } } + + test("SPARK-24676 project required data from parsed data when columnPruning disabled") { + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "false") { + withTempPath { path => + val dir = path.getAbsolutePath + spark.range(10).selectExpr("id % 2 AS p", "id AS c0", "id AS c1").write.partitionBy("p") + .option("header", "true").csv(dir) + val df1 = spark.read.option("header", true).csv(dir).selectExpr("sum(p)", "count(c0)") + checkAnswer(df1, Row(5, 10)) + + // empty required column case + val df2 = spark.read.option("header", true).csv(dir).selectExpr("sum(p)") + checkAnswer(df2, Row(5)) + } + + // the case where tokens length != parsedSchema length + withTempPath { path => + val dir = path.getAbsolutePath + Seq("1,2").toDF().write.text(dir) + // more tokens + val df1 = spark.read.schema("c0 int").format("csv").option("mode", "permissive").load(dir) + checkAnswer(df1, Row(1)) + // less tokens + val df2 = spark.read.schema("c0 int, c1 int, c2 int").format("csv") + .option("mode", "permissive").load(dir) + checkAnswer(df2, Row(1, 2, null)) + } + } + } } From 9f929458fb0a8a106f3b5a6ed3ee2cd3faa85770 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 15 Jul 2018 23:01:36 -0700 Subject: [PATCH 1131/2461] [SPARK-24810][SQL] Fix paths to test files in AvroSuite ## What changes were proposed in this pull request? In the PR, I propose to move `testFile()` to the common trait `SQLTestUtilsBase` and wrap test files in `AvroSuite` by the method `testFile()` which returns full paths to test files in the resource folder. Author: Maxim Gekk Closes #21773 from MaxGekk/test-file. --- .../org/apache/spark/sql/avro/AvroSuite.scala | 79 ++++++++++--------- .../execution/datasources/csv/CSVSuite.scala | 4 - .../datasources/json/JsonSuite.scala | 4 - .../datasources/text/WholeTextFileSuite.scala | 11 +-- .../apache/spark/sql/test/SQLTestUtils.scala | 7 ++ 5 files changed, 53 insertions(+), 52 deletions(-) diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 6ed66563b987a..9c6526b29dca3 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -36,8 +36,8 @@ import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { - val episodesFile = "src/test/resources/episodes.avro" - val testFile = "src/test/resources/test.avro" + val episodesAvro = testFile("episodes.avro") + val testAvro = testFile("test.avro") override protected def beforeAll(): Unit = { super.beforeAll() @@ -45,18 +45,18 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = { - val originalEntries = spark.read.avro(testFile).collect() + val originalEntries = spark.read.avro(testAvro).collect() val newEntries = spark.read.avro(newFile) checkAnswer(newEntries, originalEntries) } test("reading from multiple paths") { - val df = spark.read.avro(episodesFile, episodesFile) + val df = spark.read.avro(episodesAvro, episodesAvro) assert(df.count == 16) } test("reading and writing partitioned data") { - val df = spark.read.avro(episodesFile) + val df = spark.read.avro(episodesAvro) val fields = List("title", "air_date", "doctor") for (field <- fields) { withTempPath { dir => @@ -72,14 +72,14 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("request no fields") { - val df = spark.read.avro(episodesFile) + val df = spark.read.avro(episodesAvro) df.createOrReplaceTempView("avro_table") assert(spark.sql("select count(*) from avro_table").collect().head === Row(8)) } test("convert formats") { withTempPath { dir => - val df = spark.read.avro(episodesFile) + val df = spark.read.avro(episodesAvro) df.write.parquet(dir.getCanonicalPath) assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count) } @@ -87,7 +87,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("rearrange internal schema") { withTempPath { dir => - val df = spark.read.avro(episodesFile) + val df = spark.read.avro(episodesAvro) df.select("doctor", "title").write.avro(dir.getCanonicalPath) } } @@ -362,7 +362,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val deflateDir = s"$dir/deflate" val snappyDir = s"$dir/snappy" - val df = spark.read.avro(testFile) + val df = spark.read.avro(testAvro) spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed") df.write.avro(uncompressDir) spark.conf.set(AVRO_COMPRESSION_CODEC, "deflate") @@ -381,49 +381,49 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("dsl test") { - val results = spark.read.avro(episodesFile).select("title").collect() + val results = spark.read.avro(episodesAvro).select("title").collect() assert(results.length === 8) } test("support of various data types") { // This test uses data from test.avro. You can see the data and the schema of this file in // test.json and test.avsc - val all = spark.read.avro(testFile).collect() + val all = spark.read.avro(testAvro).collect() assert(all.length == 3) - val str = spark.read.avro(testFile).select("string").collect() + val str = spark.read.avro(testAvro).select("string").collect() assert(str.map(_(0)).toSet.contains("Terran is IMBA!")) - val simple_map = spark.read.avro(testFile).select("simple_map").collect() + val simple_map = spark.read.avro(testAvro).select("simple_map").collect() assert(simple_map(0)(0).getClass.toString.contains("Map")) assert(simple_map.map(_(0).asInstanceOf[Map[String, Some[Int]]].size).toSet == Set(2, 0)) - val union0 = spark.read.avro(testFile).select("union_string_null").collect() + val union0 = spark.read.avro(testAvro).select("union_string_null").collect() assert(union0.map(_(0)).toSet == Set("abc", "123", null)) - val union1 = spark.read.avro(testFile).select("union_int_long_null").collect() + val union1 = spark.read.avro(testAvro).select("union_int_long_null").collect() assert(union1.map(_(0)).toSet == Set(66, 1, null)) - val union2 = spark.read.avro(testFile).select("union_float_double").collect() + val union2 = spark.read.avro(testAvro).select("union_float_double").collect() assert( union2 .map(x => new java.lang.Double(x(0).toString)) .exists(p => Math.abs(p - Math.PI) < 0.001)) - val fixed = spark.read.avro(testFile).select("fixed3").collect() + val fixed = spark.read.avro(testAvro).select("fixed3").collect() assert(fixed.map(_(0).asInstanceOf[Array[Byte]]).exists(p => p(1) == 3)) - val enum = spark.read.avro(testFile).select("enum").collect() + val enum = spark.read.avro(testAvro).select("enum").collect() assert(enum.map(_(0)).toSet == Set("SPADES", "CLUBS", "DIAMONDS")) - val record = spark.read.avro(testFile).select("record").collect() + val record = spark.read.avro(testAvro).select("record").collect() assert(record(0)(0).getClass.toString.contains("Row")) assert(record.map(_(0).asInstanceOf[Row](0)).contains("TEST_STR123")) - val array_of_boolean = spark.read.avro(testFile).select("array_of_boolean").collect() + val array_of_boolean = spark.read.avro(testAvro).select("array_of_boolean").collect() assert(array_of_boolean.map(_(0).asInstanceOf[Seq[Boolean]].size).toSet == Set(3, 1, 0)) - val bytes = spark.read.avro(testFile).select("bytes").collect() + val bytes = spark.read.avro(testAvro).select("bytes").collect() assert(bytes.map(_(0).asInstanceOf[Array[Byte]].length).toSet == Set(3, 1, 0)) } @@ -432,7 +432,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { s""" |CREATE TEMPORARY VIEW avroTable |USING avro - |OPTIONS (path "$episodesFile") + |OPTIONS (path "${episodesAvro}") """.stripMargin.replaceAll("\n", " ")) assert(spark.sql("SELECT * FROM avroTable").collect().length === 8) @@ -443,8 +443,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { // get the same values back. withTempPath { dir => val avroDir = s"$dir/avro" - spark.read.avro(testFile).write.avro(avroDir) - checkReloadMatchesSaved(testFile, avroDir) + spark.read.avro(testAvro).write.avro(avroDir) + checkReloadMatchesSaved(testAvro, avroDir) } } @@ -457,8 +457,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val parameters = Map("recordName" -> name, "recordNamespace" -> namespace) val avroDir = tempDir + "/namedAvro" - spark.read.avro(testFile).write.options(parameters).avro(avroDir) - checkReloadMatchesSaved(testFile, avroDir) + spark.read.avro(testAvro).write.options(parameters).avro(avroDir) + checkReloadMatchesSaved(testAvro, avroDir) // Look at raw file and make sure has namespace info val rawSaved = spark.sparkContext.textFile(avroDir) @@ -532,10 +532,11 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("support of globbed paths") { - val e1 = spark.read.avro("*/test/resources/episodes.avro").collect() + val resourceDir = testFile(".") + val e1 = spark.read.avro(resourceDir + "../*/episodes.avro").collect() assert(e1.length == 8) - val e2 = spark.read.avro("src/*/*/episodes.avro").collect() + val e2 = spark.read.avro(resourceDir + "../../*/*/episodes.avro").collect() assert(e2.length == 8) } @@ -574,8 +575,12 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { | }] |} """.stripMargin - val result = spark.read.option(AvroFileFormat.AvroSchema, avroSchema).avro(testFile).collect() - val expected = spark.read.avro(testFile).select("string").collect() + val result = spark + .read + .option(AvroFileFormat.AvroSchema, avroSchema) + .avro(testAvro) + .collect() + val expected = spark.read.avro(testAvro).select("string").collect() assert(result.sameElements(expected)) } @@ -593,7 +598,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { |} """.stripMargin val result = spark.read.option(AvroFileFormat.AvroSchema, avroSchema) - .avro(testFile).select("missingField").first + .avro(testAvro).select("missingField").first assert(result === Row("foo")) } @@ -632,7 +637,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { s""" |CREATE TEMPORARY VIEW episodes |USING avro - |OPTIONS (path "$episodesFile") + |OPTIONS (path "${episodesAvro}") """.stripMargin.replaceAll("\n", " ")) spark.sql( s""" @@ -657,7 +662,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("test save and load") { // Test if load works as expected withTempPath { tempDir => - val df = spark.read.avro(episodesFile) + val df = spark.read.avro(episodesAvro) assert(df.count == 8) val tempSaveDir = s"$tempDir/save/" @@ -671,7 +676,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("test load with non-Avro file") { // Test if load works as expected withTempPath { tempDir => - val df = spark.read.avro(episodesFile) + val df = spark.read.avro(episodesAvro) assert(df.count == 8) val tempSaveDir = s"$tempDir/save/" @@ -701,10 +706,10 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { StructField("record", StructType(Seq(StructField("value_field", StringType, false))), false), StructField("array_of_boolean", ArrayType(BooleanType), false), StructField("bytes", BinaryType, true))) - val withSchema = spark.read.schema(partialColumns).avro(testFile).collect() + val withSchema = spark.read.schema(partialColumns).avro(testAvro).collect() val withOutSchema = spark .read - .avro(testFile) + .avro(testAvro) .select("string", "simple_map", "complex_map", "union_string_null", "union_int_long_null", "fixed3", "fixed2", "enum", "record", "array_of_boolean", "bytes") .collect() @@ -722,7 +727,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { StructField("non_exist_field", StringType, false), StructField("non_exist_field2", StringType, false))), false))) - val withEmptyColumn = spark.read.schema(schema).avro(testFile).collect() + val withEmptyColumn = spark.read.schema(schema).avro(testAvro).collect() assert(withEmptyColumn.forall(_ == Row(null: String, Row(null: String, null: String)))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index ae8110fdf1709..63cc5985040c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -60,10 +60,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te private val unescapedQuotesFile = "test-data/unescaped-quotes.csv" private val valueMalformedFile = "test-data/value-malformed.csv" - private def testFile(fileName: String): String = { - Thread.currentThread().getContextClassLoader.getResource(fileName).toString - } - /** Verifies data and schema. */ private def verifyCars( df: DataFrame, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index eab15b35c97d3..655f40ad549e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -48,10 +48,6 @@ class TestFileFilter extends PathFilter { class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ - def testFile(fileName: String): String = { - Thread.currentThread().getContextClassLoader.getResource(fileName).toString - } - test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) { assert(expected.getClass == actual.getClass, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala index fff0f82f9bc2b..a302d67b5cbf7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala @@ -21,10 +21,10 @@ import java.io.File import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types.{StringType, StructType} -class WholeTextFileSuite extends QueryTest with SharedSQLContext { +class WholeTextFileSuite extends QueryTest with SharedSQLContext with SQLTestUtils { // Hadoop's FileSystem caching does not use the Configuration as part of its cache key, which // can cause Filesystem.get(Configuration) to return a cached instance created with a different @@ -35,13 +35,10 @@ class WholeTextFileSuite extends QueryTest with SharedSQLContext { protected override def sparkConf = super.sparkConf.set("spark.hadoop.fs.file.impl.disable.cache", "true") - private def testFile: String = { - Thread.currentThread().getContextClassLoader.getResource("test-data/text-suite.txt").toString - } - test("reading text file with option wholetext=true") { val df = spark.read.option("wholetext", "true") - .format("text").load(testFile) + .format("text") + .load(testFile("test-data/text-suite.txt")) // schema assert(df.schema == new StructType().add("value", StringType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index bc4a120f7042f..e562be83822e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -391,6 +391,13 @@ private[sql] trait SQLTestUtilsBase val fs = hadoopPath.getFileSystem(spark.sessionState.newHadoopConf()) fs.makeQualified(hadoopPath).toUri } + + /** + * Returns full path to the given file in the resouce folder + */ + protected def testFile(fileName: String): String = { + Thread.currentThread().getContextClassLoader.getResource(fileName).toString + } } private[sql] object SQLTestUtils { From 2603ae30be78c6cb24a67c26fb781fae8763f229 Mon Sep 17 00:00:00 2001 From: sandeep-katta Date: Mon, 16 Jul 2018 14:52:49 +0800 Subject: [PATCH 1132/2461] [SPARK-24558][CORE] wrong Idle Timeout value is used in case of the cacheBlock. It is corrected as per the configuration. ## What changes were proposed in this pull request? IdleTimeout info used to print in the logs is taken based on the cacheBlock. If it is cacheBlock then cachedExecutorIdleTimeoutS is considered else executorIdleTimeoutS ## How was this patch tested? Manual Test spark-sql> cache table sample; 2018-05-15 14:44:02 INFO DAGScheduler:54 - Submitting 3 missing tasks from ShuffleMapStage 0 (MapPartitionsRDD[8] at processCmd at CliDriver.java:376) (first 15 tasks are for partitions Vector(0, 1, 2)) 2018-05-15 14:44:02 INFO YarnScheduler:54 - Adding task set 0.0 with 3 tasks 2018-05-15 14:44:03 INFO ExecutorAllocationManager:54 - Requesting 1 new executor because tasks are backlogged (new desired total will be 1) ... ... 2018-05-15 14:46:10 INFO YarnClientSchedulerBackend:54 - Actual list of executor(s) to be killed is 1 2018-05-15 14:46:10 INFO **ExecutorAllocationManager:54 - Removing executor 1 because it has been idle for 120 seconds (new desired total will be 0)** 2018-05-15 14:46:11 INFO YarnSchedulerBackend$YarnDriverEndpoint:54 - Disabling executor 1. 2018-05-15 14:46:11 INFO DAGScheduler:54 - Executor lost: 1 (epoch 1) Author: sandeep-katta Closes #21565 from sandeep-katta/loginfoBug. --- .../org/apache/spark/ExecutorAllocationManager.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index aa363eeffffb8..17b88631bcb4c 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -488,9 +488,15 @@ private[spark] class ExecutorAllocationManager( newExecutorTotal = numExistingExecutors if (testing || executorsRemoved.nonEmpty) { executorsRemoved.foreach { removedExecutorId => + // If it is a cached block, it uses cachedExecutorIdleTimeoutS for timeout + val idleTimeout = if (blockManagerMaster.hasCachedBlocks(removedExecutorId)) { + cachedExecutorIdleTimeoutS + } else { + executorIdleTimeoutS + } newExecutorTotal -= 1 logInfo(s"Removing executor $removedExecutorId because it has been idle for " + - s"$executorIdleTimeoutS seconds (new desired total will be $newExecutorTotal)") + s"$idleTimeout seconds (new desired total will be $newExecutorTotal)") executorsPendingToRemove.add(removedExecutorId) } executorsRemoved From 9549a2814951f9ba969955d78ac4bd2240f85989 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 16 Jul 2018 15:44:51 +0800 Subject: [PATCH 1133/2461] [SPARK-24549][SQL] Support Decimal type push down to the parquet data sources ## What changes were proposed in this pull request? Support Decimal type push down to the parquet data sources. The Decimal comparator used is: [`BINARY_AS_SIGNED_INTEGER_COMPARATOR`](https://github.com/apache/parquet-mr/blob/c6764c4a0848abf1d581e22df8b33e28ee9f2ced/parquet-column/src/main/java/org/apache/parquet/schema/PrimitiveComparator.java#L224-L292). ## How was this patch tested? unit tests and manual tests. **manual tests**: ```scala spark.range(10000000).selectExpr("id", "cast(id as decimal(9)) as d1", "cast(id as decimal(9, 2)) as d2", "cast(id as decimal(18)) as d3", "cast(id as decimal(18, 4)) as d4", "cast(id as decimal(38)) as d5", "cast(id as decimal(38, 18)) as d6").coalesce(1).write.option("parquet.block.size", 1048576).parquet("/tmp/spark/parquet/decimal") val df = spark.read.parquet("/tmp/spark/parquet/decimal/") spark.sql("set spark.sql.parquet.filterPushdown.decimal=true") // Only read about 1 MB data df.filter("d2 = 10000").show // Only read about 1 MB data df.filter("d4 = 10000").show spark.sql("set spark.sql.parquet.filterPushdown.decimal=false") // Read 174.3 MB data df.filter("d2 = 10000").show // Read 174.3 MB data df.filter("d4 = 10000").show ``` Author: Yuming Wang Closes #21556 from wangyum/SPARK-24549. --- .../apache/spark/sql/internal/SQLConf.scala | 10 + .../FilterPushdownBenchmark-results.txt | 96 ++++---- .../parquet/ParquetFileFormat.scala | 3 +- .../datasources/parquet/ParquetFilters.scala | 225 +++++++++++++----- .../benchmark/FilterPushdownBenchmark.scala | 8 +- .../parquet/ParquetFilterSuite.scala | 90 ++++++- 6 files changed, 324 insertions(+), 108 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 07d33fa7d52ae..41fe0c3b60d9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -387,6 +387,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED = + buildConf("spark.sql.parquet.filterPushdown.decimal") + .doc("If true, enables Parquet filter push-down optimization for Decimal. " + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + .internal() + .booleanConf + .createWithDefault(true) + val PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED = buildConf("spark.sql.parquet.filterPushdown.string.startsWith") .doc("If true, enables Parquet filter push-down optimization for string startsWith function. " + @@ -1505,6 +1513,8 @@ class SQLConf extends Serializable with Logging { def parquetFilterPushDownTimestamp: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED) + def parquetFilterPushDownDecimal: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED) + def parquetFilterPushDownStringStartWith: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED) diff --git a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt index 4f38cc4cee96d..2215ed91e2018 100644 --- a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt +++ b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt @@ -292,120 +292,120 @@ Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 1 decimal(9, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 3785 / 3867 4.2 240.6 1.0X -Parquet Vectorized (Pushdown) 3820 / 3928 4.1 242.9 1.0X -Native ORC Vectorized 3981 / 4049 4.0 253.1 1.0X -Native ORC Vectorized (Pushdown) 702 / 735 22.4 44.6 5.4X +Parquet Vectorized 4546 / 4743 3.5 289.0 1.0X +Parquet Vectorized (Pushdown) 161 / 175 98.0 10.2 28.3X +Native ORC Vectorized 5721 / 5842 2.7 363.7 0.8X +Native ORC Vectorized (Pushdown) 1019 / 1070 15.4 64.8 4.5X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 10% decimal(9, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 4694 / 4813 3.4 298.4 1.0X -Parquet Vectorized (Pushdown) 4839 / 4907 3.3 307.6 1.0X -Native ORC Vectorized 4943 / 5032 3.2 314.2 0.9X -Native ORC Vectorized (Pushdown) 2043 / 2085 7.7 129.9 2.3X +Parquet Vectorized 6340 / 7236 2.5 403.1 1.0X +Parquet Vectorized (Pushdown) 3052 / 3164 5.2 194.1 2.1X +Native ORC Vectorized 8370 / 9214 1.9 532.1 0.8X +Native ORC Vectorized (Pushdown) 4137 / 4242 3.8 263.0 1.5X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 50% decimal(9, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 8321 / 8472 1.9 529.0 1.0X -Parquet Vectorized (Pushdown) 8125 / 8471 1.9 516.6 1.0X -Native ORC Vectorized 8524 / 8616 1.8 541.9 1.0X -Native ORC Vectorized (Pushdown) 7961 / 8383 2.0 506.1 1.0X +Parquet Vectorized 12976 / 13249 1.2 825.0 1.0X +Parquet Vectorized (Pushdown) 12655 / 13570 1.2 804.6 1.0X +Native ORC Vectorized 15562 / 15950 1.0 989.4 0.8X +Native ORC Vectorized (Pushdown) 15042 / 15668 1.0 956.3 0.9X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 90% decimal(9, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 9587 / 10112 1.6 609.5 1.0X -Parquet Vectorized (Pushdown) 9726 / 10370 1.6 618.3 1.0X -Native ORC Vectorized 10119 / 11147 1.6 643.4 0.9X -Native ORC Vectorized (Pushdown) 9366 / 9497 1.7 595.5 1.0X +Parquet Vectorized 14303 / 14616 1.1 909.3 1.0X +Parquet Vectorized (Pushdown) 14380 / 14649 1.1 914.3 1.0X +Native ORC Vectorized 16964 / 17358 0.9 1078.5 0.8X +Native ORC Vectorized (Pushdown) 17255 / 17874 0.9 1097.0 0.8X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 1 decimal(18, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 4060 / 4093 3.9 258.1 1.0X -Parquet Vectorized (Pushdown) 4037 / 4125 3.9 256.6 1.0X -Native ORC Vectorized 4756 / 4811 3.3 302.4 0.9X -Native ORC Vectorized (Pushdown) 824 / 889 19.1 52.4 4.9X +Parquet Vectorized 4701 / 6416 3.3 298.9 1.0X +Parquet Vectorized (Pushdown) 128 / 164 122.8 8.1 36.7X +Native ORC Vectorized 5698 / 7904 2.8 362.3 0.8X +Native ORC Vectorized (Pushdown) 913 / 942 17.2 58.0 5.2X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 10% decimal(18, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 5157 / 5271 3.0 327.9 1.0X -Parquet Vectorized (Pushdown) 5051 / 5141 3.1 321.1 1.0X -Native ORC Vectorized 5723 / 6146 2.7 363.9 0.9X -Native ORC Vectorized (Pushdown) 2198 / 2317 7.2 139.8 2.3X +Parquet Vectorized 5376 / 5461 2.9 341.8 1.0X +Parquet Vectorized (Pushdown) 1479 / 1543 10.6 94.0 3.6X +Native ORC Vectorized 6640 / 6748 2.4 422.2 0.8X +Native ORC Vectorized (Pushdown) 2438 / 2479 6.5 155.0 2.2X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 50% decimal(18, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 8608 / 8647 1.8 547.3 1.0X -Parquet Vectorized (Pushdown) 8471 / 8584 1.9 538.6 1.0X -Native ORC Vectorized 9249 / 10048 1.7 588.0 0.9X -Native ORC Vectorized (Pushdown) 7645 / 8091 2.1 486.1 1.1X +Parquet Vectorized 9224 / 9356 1.7 586.5 1.0X +Parquet Vectorized (Pushdown) 7172 / 7415 2.2 456.0 1.3X +Native ORC Vectorized 11017 / 11408 1.4 700.4 0.8X +Native ORC Vectorized (Pushdown) 8771 / 10218 1.8 557.7 1.1X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 90% decimal(18, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 11658 / 11888 1.3 741.2 1.0X -Parquet Vectorized (Pushdown) 11812 / 12098 1.3 751.0 1.0X -Native ORC Vectorized 12943 / 13312 1.2 822.9 0.9X -Native ORC Vectorized (Pushdown) 13139 / 13465 1.2 835.4 0.9X +Parquet Vectorized 13933 / 15990 1.1 885.8 1.0X +Parquet Vectorized (Pushdown) 12683 / 12942 1.2 806.4 1.1X +Native ORC Vectorized 16344 / 20196 1.0 1039.1 0.9X +Native ORC Vectorized (Pushdown) 15162 / 16627 1.0 964.0 0.9X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 1 decimal(38, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 5491 / 5716 2.9 349.1 1.0X -Parquet Vectorized (Pushdown) 5515 / 5615 2.9 350.6 1.0X -Native ORC Vectorized 4582 / 4654 3.4 291.3 1.2X -Native ORC Vectorized (Pushdown) 815 / 861 19.3 51.8 6.7X +Parquet Vectorized 7102 / 8282 2.2 451.5 1.0X +Parquet Vectorized (Pushdown) 124 / 150 126.4 7.9 57.1X +Native ORC Vectorized 5811 / 6883 2.7 369.5 1.2X +Native ORC Vectorized (Pushdown) 1121 / 1502 14.0 71.3 6.3X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 10% decimal(38, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 6432 / 6527 2.4 409.0 1.0X -Parquet Vectorized (Pushdown) 6513 / 6607 2.4 414.1 1.0X -Native ORC Vectorized 5618 / 6085 2.8 357.2 1.1X -Native ORC Vectorized (Pushdown) 2403 / 2443 6.5 152.8 2.7X +Parquet Vectorized 6894 / 7562 2.3 438.3 1.0X +Parquet Vectorized (Pushdown) 1863 / 1980 8.4 118.4 3.7X +Native ORC Vectorized 6812 / 6848 2.3 433.1 1.0X +Native ORC Vectorized (Pushdown) 2511 / 2598 6.3 159.7 2.7X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 50% decimal(38, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 11041 / 11467 1.4 701.9 1.0X -Parquet Vectorized (Pushdown) 10909 / 11484 1.4 693.5 1.0X -Native ORC Vectorized 9860 / 10436 1.6 626.9 1.1X -Native ORC Vectorized (Pushdown) 7908 / 8069 2.0 502.8 1.4X +Parquet Vectorized 11732 / 12183 1.3 745.9 1.0X +Parquet Vectorized (Pushdown) 8912 / 9945 1.8 566.6 1.3X +Native ORC Vectorized 11499 / 12387 1.4 731.1 1.0X +Native ORC Vectorized (Pushdown) 9328 / 9382 1.7 593.1 1.3X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 90% decimal(38, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 14816 / 16877 1.1 942.0 1.0X -Parquet Vectorized (Pushdown) 15383 / 15740 1.0 978.0 1.0X -Native ORC Vectorized 14408 / 14771 1.1 916.0 1.0X -Native ORC Vectorized (Pushdown) 13968 / 14805 1.1 888.1 1.1X +Parquet Vectorized 16272 / 16328 1.0 1034.6 1.0X +Parquet Vectorized (Pushdown) 15714 / 18100 1.0 999.1 1.0X +Native ORC Vectorized 16539 / 18897 1.0 1051.5 1.0X +Native ORC Vectorized (Pushdown) 16328 / 17306 1.0 1038.1 1.0X ================================================================================================ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 3ec33b2f4b540..295960b1c2d30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -342,6 +342,7 @@ class ParquetFileFormat val returningBatch = supportBatch(sparkSession, resultSchema) val pushDownDate = sqlConf.parquetFilterPushDownDate val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp + val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold @@ -367,7 +368,7 @@ class ParquetFileFormat val pushed = if (enableParquetFilterPushDown) { val parquetSchema = ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS) .getFileMetaData.getSchema - val parquetFilters = new ParquetFilters(pushDownDate, pushDownTimestamp, + val parquetFilters = new ParquetFilters(pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStringStartWith, pushDownInFilterThreshold) filters // Collects all converted Parquet filter predicates. Notice that not all predicates can be diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 0c146f2f6f915..58b4a769fcb62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.execution.datasources.parquet -import java.lang.{Long => JLong} +import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong} +import java.math.{BigDecimal => JBigDecimal} import java.sql.{Date, Timestamp} import scala.collection.JavaConverters.asScalaBufferConverter @@ -41,44 +42,65 @@ import org.apache.spark.unsafe.types.UTF8String private[parquet] class ParquetFilters( pushDownDate: Boolean, pushDownTimestamp: Boolean, + pushDownDecimal: Boolean, pushDownStartWith: Boolean, pushDownInFilterThreshold: Int) { private case class ParquetSchemaType( originalType: OriginalType, primitiveTypeName: PrimitiveTypeName, + length: Int, decimalMetadata: DecimalMetadata) - private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, null) - private val ParquetByteType = ParquetSchemaType(INT_8, INT32, null) - private val ParquetShortType = ParquetSchemaType(INT_16, INT32, null) - private val ParquetIntegerType = ParquetSchemaType(null, INT32, null) - private val ParquetLongType = ParquetSchemaType(null, INT64, null) - private val ParquetFloatType = ParquetSchemaType(null, FLOAT, null) - private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, null) - private val ParquetStringType = ParquetSchemaType(UTF8, BINARY, null) - private val ParquetBinaryType = ParquetSchemaType(null, BINARY, null) - private val ParquetDateType = ParquetSchemaType(DATE, INT32, null) - private val ParquetTimestampMicrosType = ParquetSchemaType(TIMESTAMP_MICROS, INT64, null) - private val ParquetTimestampMillisType = ParquetSchemaType(TIMESTAMP_MILLIS, INT64, null) + private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, 0, null) + private val ParquetByteType = ParquetSchemaType(INT_8, INT32, 0, null) + private val ParquetShortType = ParquetSchemaType(INT_16, INT32, 0, null) + private val ParquetIntegerType = ParquetSchemaType(null, INT32, 0, null) + private val ParquetLongType = ParquetSchemaType(null, INT64, 0, null) + private val ParquetFloatType = ParquetSchemaType(null, FLOAT, 0, null) + private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, 0, null) + private val ParquetStringType = ParquetSchemaType(UTF8, BINARY, 0, null) + private val ParquetBinaryType = ParquetSchemaType(null, BINARY, 0, null) + private val ParquetDateType = ParquetSchemaType(DATE, INT32, 0, null) + private val ParquetTimestampMicrosType = ParquetSchemaType(TIMESTAMP_MICROS, INT64, 0, null) + private val ParquetTimestampMillisType = ParquetSchemaType(TIMESTAMP_MILLIS, INT64, 0, null) private def dateToDays(date: Date): SQLDate = { DateTimeUtils.fromJavaDate(date) } + private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue() + + private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue() + + private def decimalToByteArray(decimal: JBigDecimal, numBytes: Int): Binary = { + val decimalBuffer = new Array[Byte](numBytes) + val bytes = decimal.unscaledValue().toByteArray + + val fixedLengthBytes = if (bytes.length == numBytes) { + bytes + } else { + val signByte = if (bytes.head < 0) -1: Byte else 0: Byte + java.util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) + System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) + decimalBuffer + } + Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes) + } + private val makeEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { case ParquetBooleanType => - (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) + (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[JBoolean]) case ParquetByteType | ParquetShortType | ParquetIntegerType => (n: String, v: Any) => FilterApi.eq( intColumn(n), Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) case ParquetLongType => - (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[java.lang.Long]) + (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[JLong]) case ParquetFloatType => - (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[java.lang.Float]) + (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[JFloat]) case ParquetDoubleType => - (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[JDouble]) // Binary.fromString and Binary.fromByteArray don't accept null values case ParquetStringType => @@ -102,21 +124,34 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.eq( longColumn(n), Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]).orNull) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.eq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.eq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.eq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) } private val makeNotEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { case ParquetBooleanType => - (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) + (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[JBoolean]) case ParquetByteType | ParquetShortType | ParquetIntegerType => (n: String, v: Any) => FilterApi.notEq( intColumn(n), Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) case ParquetLongType => - (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[java.lang.Long]) + (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[JLong]) case ParquetFloatType => - (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) + (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[JFloat]) case ParquetDoubleType => - (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[JDouble]) case ParquetStringType => (n: String, v: Any) => FilterApi.notEq( @@ -139,6 +174,19 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.notEq( longColumn(n), Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]).orNull) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.notEq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.notEq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.notEq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) } private val makeLt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { @@ -146,11 +194,11 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) case ParquetLongType => - (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[java.lang.Long]) + (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[JLong]) case ParquetFloatType => - (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[java.lang.Float]) + (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[JFloat]) case ParquetDoubleType => - (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[JDouble]) case ParquetStringType => (n: String, v: Any) => @@ -169,6 +217,16 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.lt( longColumn(n), v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.lt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.lt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.lt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } private val makeLtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { @@ -176,11 +234,11 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) case ParquetLongType => - (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[java.lang.Long]) + (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[JLong]) case ParquetFloatType => - (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) + (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[JFloat]) case ParquetDoubleType => - (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[JDouble]) case ParquetStringType => (n: String, v: Any) => @@ -199,6 +257,16 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.ltEq( longColumn(n), v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.ltEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.ltEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.ltEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } private val makeGt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { @@ -206,11 +274,11 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) case ParquetLongType => - (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[java.lang.Long]) + (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[JLong]) case ParquetFloatType => - (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[java.lang.Float]) + (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[JFloat]) case ParquetDoubleType => - (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[JDouble]) case ParquetStringType => (n: String, v: Any) => @@ -229,6 +297,16 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.gt( longColumn(n), v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } private val makeGtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { @@ -236,11 +314,11 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) case ParquetLongType => - (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[java.lang.Long]) + (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[JLong]) case ParquetFloatType => - (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) + (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[JFloat]) case ParquetDoubleType => - (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[JDouble]) case ParquetStringType => (n: String, v: Any) => @@ -259,6 +337,16 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.gtEq( longColumn(n), v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gtEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gtEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gtEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } /** @@ -271,7 +359,7 @@ private[parquet] class ParquetFilters( // and it does not support to create filters for them. m.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f => f.getName -> ParquetSchemaType( - f.getOriginalType, f.getPrimitiveTypeName, f.getDecimalMetadata) + f.getOriginalType, f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata) }.toMap case _ => Map.empty[String, ParquetSchemaType] } @@ -282,21 +370,45 @@ private[parquet] class ParquetFilters( def createFilter(schema: MessageType, predicate: sources.Filter): Option[FilterPredicate] = { val nameToType = getFieldMap(schema) + // Decimal type must make sure that filter value's scale matched the file. + // If doesn't matched, which would cause data corruption. + def isDecimalMatched(value: Any, decimalMeta: DecimalMetadata): Boolean = value match { + case decimal: JBigDecimal => + decimal.scale == decimalMeta.getScale + case _ => false + } + + // Parquet's type in the given file should be matched to the value's type + // in the pushed filter in order to push down the filter to Parquet. + def valueCanMakeFilterOn(name: String, value: Any): Boolean = { + value == null || (nameToType(name) match { + case ParquetBooleanType => value.isInstanceOf[JBoolean] + case ParquetByteType | ParquetShortType | ParquetIntegerType => value.isInstanceOf[Number] + case ParquetLongType => value.isInstanceOf[JLong] + case ParquetFloatType => value.isInstanceOf[JFloat] + case ParquetDoubleType => value.isInstanceOf[JDouble] + case ParquetStringType => value.isInstanceOf[String] + case ParquetBinaryType => value.isInstanceOf[Array[Byte]] + case ParquetDateType => value.isInstanceOf[Date] + case ParquetTimestampMicrosType | ParquetTimestampMillisType => + value.isInstanceOf[Timestamp] + case ParquetSchemaType(DECIMAL, INT32, _, decimalMeta) => + isDecimalMatched(value, decimalMeta) + case ParquetSchemaType(DECIMAL, INT64, _, decimalMeta) => + isDecimalMatched(value, decimalMeta) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, _, decimalMeta) => + isDecimalMatched(value, decimalMeta) + case _ => false + }) + } + // Parquet does not allow dots in the column name because dots are used as a column path // delimiter. Since Parquet 1.8.2 (PARQUET-389), Parquet accepts the filter predicates // with missing columns. The incorrect results could be got from Parquet when we push down // filters for the column having dots in the names. Thus, we do not push down such filters. // See SPARK-20364. - def canMakeFilterOn(name: String): Boolean = nameToType.contains(name) && !name.contains(".") - - // All DataTypes that support `makeEq` can provide better performance. - def shouldConvertInPredicate(name: String): Boolean = nameToType(name) match { - case ParquetBooleanType | ParquetByteType | ParquetShortType | ParquetIntegerType - | ParquetLongType | ParquetFloatType | ParquetDoubleType | ParquetStringType - | ParquetBinaryType => true - case ParquetDateType if pushDownDate => true - case ParquetTimestampMicrosType | ParquetTimestampMillisType if pushDownTimestamp => true - case _ => false + def canMakeFilterOn(name: String, value: Any): Boolean = { + nameToType.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value) } // NOTE: @@ -315,29 +427,29 @@ private[parquet] class ParquetFilters( // Probably I missed something and obviously this should be changed. predicate match { - case sources.IsNull(name) if canMakeFilterOn(name) => + case sources.IsNull(name) if canMakeFilterOn(name, null) => makeEq.lift(nameToType(name)).map(_(name, null)) - case sources.IsNotNull(name) if canMakeFilterOn(name) => + case sources.IsNotNull(name) if canMakeFilterOn(name, null) => makeNotEq.lift(nameToType(name)).map(_(name, null)) - case sources.EqualTo(name, value) if canMakeFilterOn(name) => + case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => makeEq.lift(nameToType(name)).map(_(name, value)) - case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name) => + case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => makeNotEq.lift(nameToType(name)).map(_(name, value)) - case sources.EqualNullSafe(name, value) if canMakeFilterOn(name) => + case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => makeEq.lift(nameToType(name)).map(_(name, value)) - case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name) => + case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => makeNotEq.lift(nameToType(name)).map(_(name, value)) - case sources.LessThan(name, value) if canMakeFilterOn(name) => + case sources.LessThan(name, value) if canMakeFilterOn(name, value) => makeLt.lift(nameToType(name)).map(_(name, value)) - case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name) => + case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) => makeLtEq.lift(nameToType(name)).map(_(name, value)) - case sources.GreaterThan(name, value) if canMakeFilterOn(name) => + case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) => makeGt.lift(nameToType(name)).map(_(name, value)) - case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name) => + case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) => makeGtEq.lift(nameToType(name)).map(_(name, value)) case sources.And(lhs, rhs) => @@ -362,13 +474,14 @@ private[parquet] class ParquetFilters( case sources.Not(pred) => createFilter(schema, pred).map(FilterApi.not) - case sources.In(name, values) if canMakeFilterOn(name) && shouldConvertInPredicate(name) + case sources.In(name, values) if canMakeFilterOn(name, values.head) && values.distinct.length <= pushDownInFilterThreshold => values.distinct.flatMap { v => makeEq.lift(nameToType(name)).map(_(name, v)) }.reduceLeftOption(FilterApi.or) - case sources.StringStartsWith(name, prefix) if pushDownStartWith && canMakeFilterOn(name) => + case sources.StringStartsWith(name, prefix) + if pushDownStartWith && canMakeFilterOn(name, prefix) => Option(prefix).map { v => FilterApi.userDefined(binaryColumn(name), new UserDefinedPredicate[Binary] with Serializable { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala index 567a8ebf9d102..bdb60b44750c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala @@ -290,8 +290,12 @@ class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfter s"decimal(${DecimalType.MAX_PRECISION}, 2)" ).foreach { dt => val columns = (1 to width).map(i => s"CAST(id AS string) c$i") - val df = spark.range(numRows).selectExpr(columns: _*) - .withColumn("value", monotonically_increasing_id().cast(dt)) + val valueCol = if (dt.equalsIgnoreCase(s"decimal(${Decimal.MAX_INT_DIGITS}, 2)")) { + monotonically_increasing_id() % 9999999 + } else { + monotonically_increasing_id() + } + val df = spark.range(numRows).selectExpr(columns: _*).withColumn("value", valueCol.cast(dt)) withTempTable("orcTable", "patquetTable") { saveAsTable(df, dir) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 924f136503656..be4f498c921ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.math.{BigDecimal => JBigDecimal} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} @@ -58,7 +59,8 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex private lazy val parquetFilters = new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp, - conf.parquetFilterPushDownStringStartWith, conf.parquetFilterPushDownInFilterThreshold) + conf.parquetFilterPushDownDecimal, conf.parquetFilterPushDownStringStartWith, + conf.parquetFilterPushDownInFilterThreshold) override def beforeEach(): Unit = { super.beforeEach() @@ -86,6 +88,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", SQLConf.PARQUET_FILTER_PUSHDOWN_DATE_ENABLED.key -> "true", SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED.key -> "true", + SQLConf.PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED.key -> "true", SQLConf.PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED.key -> "true", SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { val query = df @@ -179,6 +182,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + private def testDecimalPushDown(data: DataFrame)(f: DataFrame => Unit): Unit = { + withTempPath { file => + data.write.parquet(file.getCanonicalPath) + readParquetFile(file.toString)(f) + } + } + // This function tests that exactly go through the `canDrop` and `inverseCanDrop`. private def testStringStartsWith(dataFrame: DataFrame, filter: String): Unit = { withTempPath { dir => @@ -512,6 +522,84 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + test("filter pushdown - decimal") { + Seq(true, false).foreach { legacyFormat => + withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> legacyFormat.toString) { + Seq( + s"a decimal(${Decimal.MAX_INT_DIGITS}, 2)", // 32BitDecimalType + s"a decimal(${Decimal.MAX_LONG_DIGITS}, 2)", // 64BitDecimalType + "a decimal(38, 18)" // ByteArrayDecimalType + ).foreach { schemaDDL => + val schema = StructType.fromDDL(schemaDDL) + val rdd = + spark.sparkContext.parallelize((1 to 4).map(i => Row(new java.math.BigDecimal(i)))) + val dataFrame = spark.createDataFrame(rdd, schema) + testDecimalPushDown(dataFrame) { implicit df => + assert(df.schema === schema) + checkFilterPredicate('a.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('a.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('a === 1, classOf[Eq[_]], 1) + checkFilterPredicate('a <=> 1, classOf[Eq[_]], 1) + checkFilterPredicate('a =!= 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('a < 2, classOf[Lt[_]], 1) + checkFilterPredicate('a > 3, classOf[Gt[_]], 4) + checkFilterPredicate('a <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('a >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === 'a, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> 'a, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > 'a, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < 'a, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= 'a, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= 'a, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('a < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('a < 2 || 'a > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + } + } + } + + test("Ensure that filter value matched the parquet file schema") { + val scale = 2 + val schema = StructType(Seq( + StructField("cint", IntegerType), + StructField("cdecimal1", DecimalType(Decimal.MAX_INT_DIGITS, scale)), + StructField("cdecimal2", DecimalType(Decimal.MAX_LONG_DIGITS, scale)), + StructField("cdecimal3", DecimalType(DecimalType.MAX_PRECISION, scale)) + )) + + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + + val decimal = new JBigDecimal(10).setScale(scale) + val decimal1 = new JBigDecimal(10).setScale(scale + 1) + assert(decimal.scale() === scale) + assert(decimal1.scale() === scale + 1) + + assertResult(Some(lt(intColumn("cdecimal1"), 1000: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal1", decimal)) + } + assertResult(None) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal1", decimal1)) + } + + assertResult(Some(lt(longColumn("cdecimal2"), 1000L: java.lang.Long))) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal2", decimal)) + } + assertResult(None) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal2", decimal1)) + } + + assert(parquetFilters.createFilter( + parquetSchema, sources.LessThan("cdecimal3", decimal)).isDefined) + assertResult(None) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal3", decimal1)) + } + } + test("SPARK-6554: don't push down predicates which reference partition columns") { import testImplicits._ From cf9704534903b5bbd9bd4834728c92953e45293e Mon Sep 17 00:00:00 2001 From: Shahid Date: Mon, 16 Jul 2018 09:50:43 -0500 Subject: [PATCH 1134/2461] [SPARK-18230][MLLIB] Throw a better exception, if the user or product doesn't exist When invoking MatrixFactorizationModel.recommendProducts(Int, Int) with a non-existing user, a java.util.NoSuchElementException is thrown: > java.util.NoSuchElementException: next on empty iterator at scala.collection.Iterator$$anon$2.next(Iterator.scala:39) at scala.collection.Iterator$$anon$2.next(Iterator.scala:37) at scala.collection.IndexedSeqLike$Elements.next(IndexedSeqLike.scala:63) at scala.collection.IterableLike$class.head(IterableLike.scala:107) at scala.collection.mutable.WrappedArray.scala$collection$IndexedSeqOptimized$$super$head(WrappedArray.scala:35) at scala.collection.IndexedSeqOptimized$class.head(IndexedSeqOptimized.scala:126) at scala.collection.mutable.WrappedArray.head(WrappedArray.scala:35) at org.apache.spark.mllib.recommendation.MatrixFactorizationModel.recommendProducts(MatrixFactorizationModel.scala:169) ## What changes were proposed in this pull request? Throw a better exception, like "user-id/product-id doesn't found in the model", for a non-existent user/product ## How was this patch tested? Added UT Author: Shahid Closes #21740 from shahidki31/checkInvalidUserProduct. --- .../MatrixFactorizationModel.scala | 23 ++++++++++++++----- .../MatrixFactorizationModelSuite.scala | 21 +++++++++++++++++ 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index ac709ad72f0c0..7b49d4d0812f9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -78,8 +78,13 @@ class MatrixFactorizationModel @Since("0.8.0") ( /** Predict the rating of one user for one product. */ @Since("0.8.0") def predict(user: Int, product: Int): Double = { - val userVector = userFeatures.lookup(user).head - val productVector = productFeatures.lookup(product).head + val userFeatureSeq = userFeatures.lookup(user) + require(userFeatureSeq.nonEmpty, s"userId: $user not found in the model") + val productFeatureSeq = productFeatures.lookup(product) + require(productFeatureSeq.nonEmpty, s"productId: $product not found in the model") + + val userVector = userFeatureSeq.head + val productVector = productFeatureSeq.head blas.ddot(rank, userVector, 1, productVector, 1) } @@ -164,9 +169,12 @@ class MatrixFactorizationModel @Since("0.8.0") ( * recommended the product is. */ @Since("1.1.0") - def recommendProducts(user: Int, num: Int): Array[Rating] = - MatrixFactorizationModel.recommend(userFeatures.lookup(user).head, productFeatures, num) + def recommendProducts(user: Int, num: Int): Array[Rating] = { + val userFeatureSeq = userFeatures.lookup(user) + require(userFeatureSeq.nonEmpty, s"userId: $user not found in the model") + MatrixFactorizationModel.recommend(userFeatureSeq.head, productFeatures, num) .map(t => Rating(user, t._1, t._2)) + } /** * Recommends users to a product. That is, this returns users who are most likely to be @@ -181,9 +189,12 @@ class MatrixFactorizationModel @Since("0.8.0") ( * recommended the user is. */ @Since("1.1.0") - def recommendUsers(product: Int, num: Int): Array[Rating] = - MatrixFactorizationModel.recommend(productFeatures.lookup(product).head, userFeatures, num) + def recommendUsers(product: Int, num: Int): Array[Rating] = { + val productFeatureSeq = productFeatures.lookup(product) + require(productFeatureSeq.nonEmpty, s"productId: $product not found in the model") + MatrixFactorizationModel.recommend(productFeatureSeq.head, userFeatures, num) .map(t => Rating(t._1, product, t._2)) + } protected override val formatVersion: String = "1.0" diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala index 2c8ed057a516a..5ed9d077afe78 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala @@ -72,6 +72,27 @@ class MatrixFactorizationModelSuite extends SparkFunSuite with MLlibTestSparkCon } } + test("invalid user and product") { + val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) + + intercept[IllegalArgumentException] { + // invalid user + model.predict(5, 2) + } + intercept[IllegalArgumentException] { + // invalid product + model.predict(0, 5) + } + intercept[IllegalArgumentException] { + // invalid user + model.recommendProducts(5, 2) + } + intercept[IllegalArgumentException] { + // invalid product + model.recommendUsers(5, 2) + } + } + test("batch predict API recommendProductsForUsers") { val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) val topK = 10 From b045315e5d87b7ea3588436053aaa4d5a7bd103f Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 16 Jul 2018 23:16:25 +0800 Subject: [PATCH 1135/2461] [SPARK-24734][SQL] Fix type coercions and nullabilities of nested data types of some functions. ## What changes were proposed in this pull request? We have some functions which need to aware the nullabilities of all children, such as `CreateArray`, `CreateMap`, `Concat`, and so on. Currently we add casts to fix the nullabilities, but the casts might be removed during the optimization phase. After the discussion, we decided to not add extra casts for just fixing the nullabilities of the nested types, but handle them by functions themselves. ## How was this patch tested? Modified and added some tests. Author: Takuya UESHIN Closes #21704 from ueshin/issues/SPARK-24734/concat_containsnull. --- .../sql/catalyst/analysis/TypeCoercion.scala | 113 ++++++++++-------- .../sql/catalyst/expressions/Expression.scala | 12 +- .../sql/catalyst/expressions/arithmetic.scala | 14 +-- .../expressions/collectionOperations.scala | 22 ++-- .../expressions/complexTypeCreator.scala | 15 ++- .../expressions/conditionalExpressions.scala | 4 +- .../sql/catalyst/expressions/literals.scala | 2 +- .../expressions/nullExpressions.scala | 6 +- .../spark/sql/catalyst/util/TypeUtils.scala | 16 +-- .../catalyst/analysis/TypeCoercionSuite.scala | 43 ++++--- .../ArithmeticExpressionSuite.scala | 12 ++ .../CollectionExpressionsSuite.scala | 60 ++++++++-- .../expressions/ComplexTypeSuite.scala | 19 +++ .../expressions/NullExpressionsSuite.scala | 7 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 8 ++ 15 files changed, 211 insertions(+), 142 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e8331c90ea0f6..316aebdeaffa1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -184,6 +184,17 @@ object TypeCoercion { } } + def findCommonTypeDifferentOnlyInNullFlags(types: Seq[DataType]): Option[DataType] = { + if (types.isEmpty) { + None + } else { + types.tail.foldLeft[Option[DataType]](Some(types.head)) { + case (Some(t1), t2) => findCommonTypeDifferentOnlyInNullFlags(t1, t2) + case _ => None + } + } + } + /** * Case 2 type widening (see the classdoc comment above for TypeCoercion). * @@ -259,8 +270,25 @@ object TypeCoercion { } } - private def haveSameType(exprs: Seq[Expression]): Boolean = - exprs.map(_.dataType).distinct.length == 1 + /** + * Check whether the given types are equal ignoring nullable, containsNull and valueContainsNull. + */ + def haveSameType(types: Seq[DataType]): Boolean = { + if (types.size <= 1) { + true + } else { + val head = types.head + types.tail.forall(_.sameType(head)) + } + } + + private def castIfNotSameType(expr: Expression, dt: DataType): Expression = { + if (!expr.dataType.sameType(dt)) { + Cast(expr, dt) + } else { + expr + } + } /** * Widens numeric types and converts strings to numbers when appropriate. @@ -525,23 +553,24 @@ object TypeCoercion { * This ensure that the types for various functions are as expected. */ object FunctionArgumentConversion extends TypeCoercionRule { + override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case a @ CreateArray(children) if !haveSameType(children) => + case a @ CreateArray(children) if !haveSameType(children.map(_.dataType)) => val types = children.map(_.dataType) findWiderCommonType(types) match { - case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) + case Some(finalDataType) => CreateArray(children.map(castIfNotSameType(_, finalDataType))) case None => a } case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && - !haveSameType(children) => + !haveSameType(c.inputTypesForMerging) => val types = children.map(_.dataType) findWiderCommonType(types) match { - case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType))) + case Some(finalDataType) => Concat(children.map(castIfNotSameType(_, finalDataType))) case None => c } @@ -553,7 +582,8 @@ object TypeCoercion { case None => aj } - case s @ Sequence(_, _, _, timeZoneId) if !haveSameType(s.coercibleChildren) => + case s @ Sequence(_, _, _, timeZoneId) + if !haveSameType(s.coercibleChildren.map(_.dataType)) => val types = s.coercibleChildren.map(_.dataType) findWiderCommonType(types) match { case Some(widerDataType) => s.castChildrenTo(widerDataType) @@ -561,33 +591,25 @@ object TypeCoercion { } case m @ MapConcat(children) if children.forall(c => MapType.acceptsType(c.dataType)) && - !haveSameType(children) => + !haveSameType(m.inputTypesForMerging) => val types = children.map(_.dataType) findWiderCommonType(types) match { - case Some(finalDataType) => MapConcat(children.map(Cast(_, finalDataType))) + case Some(finalDataType) => MapConcat(children.map(castIfNotSameType(_, finalDataType))) case None => m } case m @ CreateMap(children) if m.keys.length == m.values.length && - (!haveSameType(m.keys) || !haveSameType(m.values)) => - val newKeys = if (haveSameType(m.keys)) { - m.keys - } else { - val types = m.keys.map(_.dataType) - findWiderCommonType(types) match { - case Some(finalDataType) => m.keys.map(Cast(_, finalDataType)) - case None => m.keys - } + (!haveSameType(m.keys.map(_.dataType)) || !haveSameType(m.values.map(_.dataType))) => + val keyTypes = m.keys.map(_.dataType) + val newKeys = findWiderCommonType(keyTypes) match { + case Some(finalDataType) => m.keys.map(castIfNotSameType(_, finalDataType)) + case None => m.keys } - val newValues = if (haveSameType(m.values)) { - m.values - } else { - val types = m.values.map(_.dataType) - findWiderCommonType(types) match { - case Some(finalDataType) => m.values.map(Cast(_, finalDataType)) - case None => m.values - } + val valueTypes = m.values.map(_.dataType) + val newValues = findWiderCommonType(valueTypes) match { + case Some(finalDataType) => m.values.map(castIfNotSameType(_, finalDataType)) + case None => m.values } CreateMap(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) }) @@ -610,27 +632,27 @@ object TypeCoercion { // Coalesce should return the first non-null value, which could be any column // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. - case c @ Coalesce(es) if !haveSameType(es) => + case c @ Coalesce(es) if !haveSameType(c.inputTypesForMerging) => val types = es.map(_.dataType) findWiderCommonType(types) match { - case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) + case Some(finalDataType) => Coalesce(es.map(castIfNotSameType(_, finalDataType))) case None => c } // When finding wider type for `Greatest` and `Least`, we should handle decimal types even if // we need to truncate, but we should not promote one side to string if the other side is // string.g - case g @ Greatest(children) if !haveSameType(children) => + case g @ Greatest(children) if !haveSameType(g.inputTypesForMerging) => val types = children.map(_.dataType) findWiderTypeWithoutStringPromotion(types) match { - case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) + case Some(finalDataType) => Greatest(children.map(castIfNotSameType(_, finalDataType))) case None => g } - case l @ Least(children) if !haveSameType(children) => + case l @ Least(children) if !haveSameType(l.inputTypesForMerging) => val types = children.map(_.dataType) findWiderTypeWithoutStringPromotion(types) match { - case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) + case Some(finalDataType) => Least(children.map(castIfNotSameType(_, finalDataType))) case None => l } @@ -672,27 +694,14 @@ object TypeCoercion { object CaseWhenCoercion extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case c: CaseWhen if c.childrenResolved && !c.areInputTypesForMergingEqual => + case c: CaseWhen if c.childrenResolved && !haveSameType(c.inputTypesForMerging) => val maybeCommonType = findWiderCommonType(c.inputTypesForMerging) maybeCommonType.map { commonType => - var changed = false val newBranches = c.branches.map { case (condition, value) => - if (value.dataType.sameType(commonType)) { - (condition, value) - } else { - changed = true - (condition, Cast(value, commonType)) - } - } - val newElseValue = c.elseValue.map { value => - if (value.dataType.sameType(commonType)) { - value - } else { - changed = true - Cast(value, commonType) - } + (condition, castIfNotSameType(value, commonType)) } - if (changed) CaseWhen(newBranches, newElseValue) else c + val newElseValue = c.elseValue.map(castIfNotSameType(_, commonType)) + CaseWhen(newBranches, newElseValue) }.getOrElse(c) } } @@ -705,10 +714,10 @@ object TypeCoercion { plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. - case i @ If(pred, left, right) if !i.areInputTypesForMergingEqual => + case i @ If(pred, left, right) if !haveSameType(i.inputTypesForMerging) => findWiderTypeForTwo(left.dataType, right.dataType).map { widestType => - val newLeft = if (left.dataType.sameType(widestType)) left else Cast(left, widestType) - val newRight = if (right.dataType.sameType(widestType)) right else Cast(right, widestType) + val newLeft = castIfNotSameType(left, widestType) + val newRight = castIfNotSameType(right, widestType) If(pred, newLeft, newRight) }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. case If(Literal(null, NullType), left, right) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 44c5556ff9ccf..f7d1b105964d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -709,22 +709,12 @@ trait ComplexTypeMergingExpression extends Expression { @transient lazy val inputTypesForMerging: Seq[DataType] = children.map(_.dataType) - /** - * A method determining whether the input types are equal ignoring nullable, containsNull and - * valueContainsNull flags and thus convenient for resolution of the final data type. - */ - def areInputTypesForMergingEqual: Boolean = { - inputTypesForMerging.length <= 1 || inputTypesForMerging.sliding(2, 1).forall { - case Seq(dt1, dt2) => dt1.sameType(dt2) - } - } - override def dataType: DataType = { require( inputTypesForMerging.nonEmpty, "The collection of input data types must not be empty.") require( - areInputTypesForMergingEqual, + TypeCoercion.haveSameType(inputTypesForMerging), "All input types must be the same except nullable, containsNull, valueContainsNull flags.") inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index fe91e520169b4..55940410cc4d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils @@ -514,7 +514,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { > SELECT _FUNC_(10, 9, 2, 4, 3); 2 """) -case class Least(children: Seq[Expression]) extends Expression { +case class Least(children: Seq[Expression]) extends ComplexTypeMergingExpression { override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -525,7 +525,7 @@ case class Least(children: Seq[Expression]) extends Expression { if (children.length <= 1) { TypeCheckResult.TypeCheckFailure( s"input to function $prettyName requires at least two arguments") - } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).") @@ -534,8 +534,6 @@ case class Least(children: Seq[Expression]) extends Expression { } } - override def dataType: DataType = children.head.dataType - override def eval(input: InternalRow): Any = { children.foldLeft[Any](null)((r, c) => { val evalc = c.eval(input) @@ -589,7 +587,7 @@ case class Least(children: Seq[Expression]) extends Expression { > SELECT _FUNC_(10, 9, 2, 4, 3); 10 """) -case class Greatest(children: Seq[Expression]) extends Expression { +case class Greatest(children: Seq[Expression]) extends ComplexTypeMergingExpression { override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -600,7 +598,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { if (children.length <= 1) { TypeCheckResult.TypeCheckFailure( s"input to function $prettyName requires at least two arguments") - } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).") @@ -609,8 +607,6 @@ case class Greatest(children: Seq[Expression]) extends Expression { } } - override def dataType: DataType = children.head.dataType - override def eval(input: InternalRow): Any = { children.foldLeft[Any](null)((r, c) => { val evalc = c.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 0f4f4f1601b4a..972bc6e57892c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -507,7 +507,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd')); [[1 -> "a"], [2 -> "b"], [2 -> "c"], [3 -> "d"]] """, since = "2.4.0") -case class MapConcat(children: Seq[Expression]) extends Expression { +case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpression { override def checkInputDataTypes(): TypeCheckResult = { var funcName = s"function $prettyName" @@ -521,14 +521,10 @@ case class MapConcat(children: Seq[Expression]) extends Expression { } override def dataType: MapType = { - val dt = children.map(_.dataType.asInstanceOf[MapType]).headOption - .getOrElse(MapType(StringType, StringType)) - val valueContainsNull = children.map(_.dataType.asInstanceOf[MapType]) - .exists(_.valueContainsNull) - if (dt.valueContainsNull != valueContainsNull) { - dt.copy(valueContainsNull = valueContainsNull) + if (children.isEmpty) { + MapType(StringType, StringType) } else { - dt + super.dataType.asInstanceOf[MapType] } } @@ -2211,7 +2207,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); | [1,2,3,4,5,6] """) -case class Concat(children: Seq[Expression]) extends Expression { +case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression { private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH @@ -2232,7 +2228,13 @@ case class Concat(children: Seq[Expression]) extends Expression { } } - override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + override def dataType: DataType = { + if (children.isEmpty) { + StringType + } else { + super.dataType + } + } lazy val javaType: String = CodeGenerator.javaType(dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 0a5f8a907b50a..a43de028360b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ @@ -48,7 +48,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def dataType: ArrayType = { ArrayType( - children.headOption.map(_.dataType).getOrElse(StringType), + TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(children.map(_.dataType)) + .getOrElse(StringType), containsNull = children.exists(_.nullable)) } @@ -179,11 +180,11 @@ case class CreateMap(children: Seq[Expression]) extends Expression { if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure( s"$prettyName expects a positive even number of arguments.") - } else if (keys.map(_.dataType).distinct.length > 1) { + } else if (!TypeCoercion.haveSameType(keys.map(_.dataType))) { TypeCheckResult.TypeCheckFailure( "The given keys of function map should all be the same type, but they are " + keys.map(_.dataType.simpleString).mkString("[", ", ", "]")) - } else if (values.map(_.dataType).distinct.length > 1) { + } else if (!TypeCoercion.haveSameType(values.map(_.dataType))) { TypeCheckResult.TypeCheckFailure( "The given values of function map should all be the same type, but they are " + values.map(_.dataType.simpleString).mkString("[", ", ", "]")) @@ -194,8 +195,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def dataType: DataType = { MapType( - keyType = keys.headOption.map(_.dataType).getOrElse(StringType), - valueType = values.headOption.map(_.dataType).getOrElse(StringType), + keyType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(keys.map(_.dataType)) + .getOrElse(StringType), + valueType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(values.map(_.dataType)) + .getOrElse(StringType), valueContainsNull = values.exists(_.nullable)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 30ce9e4743da9..3b597e8b5263b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -48,7 +48,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi TypeCheckResult.TypeCheckFailure( "type of predicate expression in If should be boolean, " + s"not ${predicate.dataType.simpleString}") - } else if (!areInputTypesForMergingEqual) { + } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) { TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).") } else { @@ -137,7 +137,7 @@ case class CaseWhen( } override def checkInputDataTypes(): TypeCheckResult = { - if (areInputTypesForMergingEqual) { + if (TypeCoercion.haveSameType(inputTypesForMerging)) { // Make sure all branch conditions are boolean types. if (branches.forall(_._1.dataType == BooleanType)) { TypeCheckResult.TypeCheckSuccess diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 0cc2a332f2c30..0efd1224f1bca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -186,7 +186,7 @@ object Literal { case map: MapType => create(Map(), map) case struct: StructType => create(InternalRow.fromSeq(struct.fields.map(f => default(f.dataType).value)), struct) - case udt: UserDefinedType[_] => default(udt.sqlType) + case udt: UserDefinedType[_] => Literal(default(udt.sqlType).value, udt) case other => throw new RuntimeException(s"no default for type $dataType") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 2eeed3bbb2d91..b683d2a7e9ef3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils @@ -44,7 +44,7 @@ import org.apache.spark.sql.types._ 1 """) // scalastyle:on line.size.limit -case class Coalesce(children: Seq[Expression]) extends Expression { +case class Coalesce(children: Seq[Expression]) extends ComplexTypeMergingExpression { /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ override def nullable: Boolean = children.forall(_.nullable) @@ -61,8 +61,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } } - override def dataType: DataType = children.head.dataType - override def eval(input: InternalRow): Any = { var result: Any = null val childIterator = children.iterator diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 1dcda49a3af6a..b795abea95a74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.RowOrdering import org.apache.spark.sql.types._ @@ -42,18 +42,12 @@ object TypeUtils { } def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = { - if (types.size <= 1) { + if (TypeCoercion.haveSameType(types)) { TypeCheckResult.TypeCheckSuccess } else { - val firstType = types.head - types.foreach { t => - if (!t.sameType(firstType)) { - return TypeCheckResult.TypeCheckFailure( - s"input to $caller should all be the same type, but it's " + - types.map(_.simpleString).mkString("[", ", ", "]")) - } - } - TypeCheckResult.TypeCheckSuccess + return TypeCheckResult.TypeCheckFailure( + s"input to $caller should all be the same type, but it's " + + types.map(_.simpleString).mkString("[", ", ", "]")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 8cc5a23779a2a..4161f09c63190 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -687,46 +687,43 @@ class TypeCoercionSuite extends AnalysisTest { ruleTest(rule, Coalesce(Seq(doubleLit, intLit, floatLit)), - Coalesce(Seq(Cast(doubleLit, DoubleType), - Cast(intLit, DoubleType), Cast(floatLit, DoubleType)))) + Coalesce(Seq(doubleLit, Cast(intLit, DoubleType), Cast(floatLit, DoubleType)))) ruleTest(rule, Coalesce(Seq(longLit, intLit, decimalLit)), Coalesce(Seq(Cast(longLit, DecimalType(22, 0)), - Cast(intLit, DecimalType(22, 0)), Cast(decimalLit, DecimalType(22, 0))))) + Cast(intLit, DecimalType(22, 0)), decimalLit))) ruleTest(rule, Coalesce(Seq(nullLit, intLit)), - Coalesce(Seq(Cast(nullLit, IntegerType), Cast(intLit, IntegerType)))) + Coalesce(Seq(Cast(nullLit, IntegerType), intLit))) ruleTest(rule, Coalesce(Seq(timestampLit, stringLit)), - Coalesce(Seq(Cast(timestampLit, StringType), Cast(stringLit, StringType)))) + Coalesce(Seq(Cast(timestampLit, StringType), stringLit))) ruleTest(rule, Coalesce(Seq(nullLit, floatNullLit, intLit)), - Coalesce(Seq(Cast(nullLit, FloatType), Cast(floatNullLit, FloatType), - Cast(intLit, FloatType)))) + Coalesce(Seq(Cast(nullLit, FloatType), floatNullLit, Cast(intLit, FloatType)))) ruleTest(rule, Coalesce(Seq(nullLit, intLit, decimalLit, doubleLit)), Coalesce(Seq(Cast(nullLit, DoubleType), Cast(intLit, DoubleType), - Cast(decimalLit, DoubleType), Cast(doubleLit, DoubleType)))) + Cast(decimalLit, DoubleType), doubleLit))) ruleTest(rule, Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit)), Coalesce(Seq(Cast(nullLit, StringType), Cast(floatNullLit, StringType), - Cast(doubleLit, StringType), Cast(stringLit, StringType)))) + Cast(doubleLit, StringType), stringLit))) ruleTest(rule, Coalesce(Seq(timestampLit, intLit, stringLit)), - Coalesce(Seq(Cast(timestampLit, StringType), Cast(intLit, StringType), - Cast(stringLit, StringType)))) + Coalesce(Seq(Cast(timestampLit, StringType), Cast(intLit, StringType), stringLit))) ruleTest(rule, Coalesce(Seq(tsArrayLit, intArrayLit, strArrayLit)), Coalesce(Seq(Cast(tsArrayLit, ArrayType(StringType)), - Cast(intArrayLit, ArrayType(StringType)), Cast(strArrayLit, ArrayType(StringType))))) + Cast(intArrayLit, ArrayType(StringType)), strArrayLit))) } test("CreateArray casts") { @@ -735,7 +732,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(1) :: Literal.create(1.0, FloatType) :: Nil), - CreateArray(Cast(Literal(1.0), DoubleType) + CreateArray(Literal(1.0) :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) @@ -747,7 +744,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Nil), CreateArray(Cast(Literal(1.0), StringType) :: Cast(Literal(1), StringType) - :: Cast(Literal("a"), StringType) + :: Literal("a") :: Nil)) ruleTest(TypeCoercion.FunctionArgumentConversion, @@ -765,7 +762,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Nil), CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(38, 38)) :: Literal.create(null, DecimalType(22, 10)).cast(DecimalType(38, 38)) - :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) + :: Literal.create(null, DecimalType(38, 38)) :: Nil)) } @@ -779,7 +776,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Nil), CreateMap(Cast(Literal(1), FloatType) :: Literal("a") - :: Cast(Literal.create(2.0, FloatType), FloatType) + :: Literal.create(2.0, FloatType) :: Literal("b") :: Nil)) ruleTest(TypeCoercion.FunctionArgumentConversion, @@ -801,7 +798,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(3.0) :: Nil), CreateMap(Literal(1) - :: Cast(Literal("a"), StringType) + :: Literal("a") :: Literal(2) :: Cast(Literal(3.0), StringType) :: Nil)) @@ -814,7 +811,7 @@ class TypeCoercionSuite extends AnalysisTest { CreateMap(Literal(1) :: Literal.create(null, DecimalType(38, 0)).cast(DecimalType(38, 38)) :: Literal(2) - :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) + :: Literal.create(null, DecimalType(38, 38)) :: Nil)) // type coercion for both map keys and values ruleTest(TypeCoercion.FunctionArgumentConversion, @@ -824,8 +821,8 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(3.0) :: Nil), CreateMap(Cast(Literal(1), DoubleType) - :: Cast(Literal("a"), StringType) - :: Cast(Literal(2.0), DoubleType) + :: Literal("a") + :: Literal(2.0) :: Cast(Literal(3.0), StringType) :: Nil)) } @@ -837,7 +834,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(1) :: Literal.create(1.0, FloatType) :: Nil), - operator(Cast(Literal(1.0), DoubleType) + operator(Literal(1.0) :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) @@ -848,14 +845,14 @@ class TypeCoercionSuite extends AnalysisTest { :: Nil), operator(Cast(Literal(1L), DecimalType(22, 0)) :: Cast(Literal(1), DecimalType(22, 0)) - :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) + :: Literal(new java.math.BigDecimal("1000000000000000000000")) :: Nil)) ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1.0) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) :: Nil), - operator(Literal(1.0).cast(DoubleType) + operator(Literal(1.0) :: Literal.create(null, DecimalType(10, 5)).cast(DoubleType) :: Literal(1).cast(DoubleType) :: Nil)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 6edb4348f8309..021217606dc03 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -282,6 +282,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper DataTypeTestUtils.ordered.foreach { dt => checkConsistencyBetweenInterpretedAndCodegen(Least, dt, 2) } + + val least = Least(Seq( + Literal.create(Seq(1, 2), ArrayType(IntegerType, containsNull = false)), + Literal.create(Seq(1, 3, null), ArrayType(IntegerType, containsNull = true)))) + assert(least.dataType === ArrayType(IntegerType, containsNull = true)) + checkEvaluation(least, Seq(1, 2)) } test("function greatest") { @@ -334,6 +340,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper DataTypeTestUtils.ordered.foreach { dt => checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2) } + + val greatest = Greatest(Seq( + Literal.create(Seq(1, 2), ArrayType(IntegerType, containsNull = false)), + Literal.create(Seq(1, 3, null), ArrayType(IntegerType, containsNull = true)))) + assert(greatest.dataType === ArrayType(IntegerType, containsNull = true)) + checkEvaluation(greatest, Seq(1, 3, null)) } test("SPARK-22499: Least and greatest should not generate codes beyond 64KB") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 85d6a1befed6b..f1e3bd091565d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -227,6 +227,27 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper assert(MapConcat(Seq(m6, m5)).dataType.valueContainsNull) assert(!MapConcat(Seq(m1, m2)).nullable) assert(MapConcat(Seq(m1, mNull)).nullable) + + val mapConcat = MapConcat(Seq( + Literal.create(Map(Seq(1, 2) -> Seq("a", "b")), + MapType( + ArrayType(IntegerType, containsNull = false), + ArrayType(StringType, containsNull = false), + valueContainsNull = false)), + Literal.create(Map(Seq(3, 4, null) -> Seq("c", "d", null), Seq(6) -> null), + MapType( + ArrayType(IntegerType, containsNull = true), + ArrayType(StringType, containsNull = true), + valueContainsNull = true)))) + assert(mapConcat.dataType === + MapType( + ArrayType(IntegerType, containsNull = true), + ArrayType(StringType, containsNull = true), + valueContainsNull = true)) + checkEvaluation(mapConcat, Map( + Seq(1, 2) -> Seq("a", "b"), + Seq(3, 4, null) -> Seq("c", "d", null), + Seq(6) -> null)) } test("MapFromEntries") { @@ -1050,11 +1071,11 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper test("Concat") { // Primitive-type elements - val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) - val ai1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) - val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType)) - val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType)) - val ai4 = Literal.create(null, ArrayType(IntegerType)) + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType, containsNull = false)) + val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType, containsNull = true)) + val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType, containsNull = true)) + val ai4 = Literal.create(null, ArrayType(IntegerType, containsNull = false)) checkEvaluation(Concat(Seq(ai0)), Seq(1, 2, 3)) checkEvaluation(Concat(Seq(ai0, ai1)), Seq(1, 2, 3)) @@ -1067,14 +1088,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Concat(Seq(ai4, ai0)), null) // Non-primitive-type elements - val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType)) - val as1 = Literal.create(Seq.empty[String], ArrayType(StringType)) - val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType)) - val as3 = Literal.create(Seq(null, null), ArrayType(StringType)) - val as4 = Literal.create(null, ArrayType(StringType)) - - val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")), ArrayType(ArrayType(StringType))) - val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")), ArrayType(ArrayType(StringType))) + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq.empty[String], ArrayType(StringType, containsNull = false)) + val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType, containsNull = true)) + val as3 = Literal.create(Seq(null, null), ArrayType(StringType, containsNull = true)) + val as4 = Literal.create(null, ArrayType(StringType, containsNull = false)) + + val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")), + ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)) + val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")), + ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)) + val aa2 = Literal.create(Seq(Seq("g", null), null), + ArrayType(ArrayType(StringType, containsNull = true), containsNull = true)) checkEvaluation(Concat(Seq(as0)), Seq("a", "b", "c")) checkEvaluation(Concat(Seq(as0, as1)), Seq("a", "b", "c")) @@ -1087,6 +1112,15 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Concat(Seq(as4, as0)), null) checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f"))) + + assert(Concat(Seq(ai0, ai1)).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(Concat(Seq(ai0, ai2)).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(Concat(Seq(as0, as1)).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(Concat(Seq(as0, as2)).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(Concat(Seq(aa0, aa1)).dataType === + ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)) + assert(Concat(Seq(aa0, aa2)).dataType === + ArrayType(ArrayType(StringType, containsNull = true), containsNull = true)) } test("Flatten") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 726193b411737..77aaf55480ec2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -144,6 +144,13 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(CreateArray(byteWithNull), byteSeq :+ null, EmptyRow) checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow) checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil) + + val array = CreateArray(Seq( + Literal.create(intSeq, ArrayType(IntegerType, containsNull = false)), + Literal.create(intSeq :+ null, ArrayType(IntegerType, containsNull = true)))) + assert(array.dataType === + ArrayType(ArrayType(IntegerType, containsNull = true), containsNull = false)) + checkEvaluation(array, Seq(intSeq, intSeq :+ null)) } test("CreateMap") { @@ -184,6 +191,18 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), null, null) } + + val map = CreateMap(Seq( + Literal.create(intSeq, ArrayType(IntegerType, containsNull = false)), + Literal.create(strSeq, ArrayType(StringType, containsNull = false)), + Literal.create(intSeq :+ null, ArrayType(IntegerType, containsNull = true)), + Literal.create(strSeq :+ null, ArrayType(StringType, containsNull = true)))) + assert(map.dataType === + MapType( + ArrayType(IntegerType, containsNull = true), + ArrayType(StringType, containsNull = true), + valueContainsNull = false)) + checkEvaluation(map, createMap(Seq(intSeq, intSeq :+ null), Seq(strSeq, strSeq :+ null))) } test("MapFromArrays") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index 424c3a4696077..6e07f7a59b730 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -86,6 +86,13 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Coalesce(Seq(nullLit, lit, lit)), value) checkEvaluation(Coalesce(Seq(nullLit, nullLit, lit)), value) } + + val coalesce = Coalesce(Seq( + Literal.create(null, ArrayType(IntegerType, containsNull = false)), + Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)), + Literal.create(Seq(1, 2, 3, null), ArrayType(IntegerType, containsNull = true)))) + assert(coalesce.dataType === ArrayType(IntegerType, containsNull = true)) + checkEvaluation(coalesce, Seq(1, 2, 3)) } test("SPARK-16602 Nvl should support numeric-string cases") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index d4615714cff03..3f6f4556e2d6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1622,6 +1622,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(errMsg.contains(s"input to function $name requires at least two arguments")) } } + + test("SPARK-24734: Fix containsNull of Concat for array type") { + val df = Seq((Seq(1), Seq[Integer](null), Seq("a", "b"))).toDF("k1", "k2", "v") + val ex = intercept[RuntimeException] { + df.select(map_from_arrays(concat($"k1", $"k2"), $"v")).show() + } + assert(ex.getMessage.contains("Cannot use null as map key")) + } } object DataFrameFunctionsSuite { From b0c95a1d698888df752bd62e49838a98268f6847 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Mon, 16 Jul 2018 14:28:35 -0700 Subject: [PATCH 1136/2461] [SPARK-23901][SQL] Removing masking functions The PR reverts #21246. Author: Marek Novotny Closes #21786 from mn-mikke/SPARK-23901. --- .../expressions/MaskExpressionsUtils.java | 80 --- .../catalyst/analysis/FunctionRegistry.scala | 8 - .../expressions/maskExpressions.scala | 569 ------------------ .../expressions/MaskExpressionsSuite.scala | 236 -------- .../org/apache/spark/sql/functions.scala | 119 ---- .../spark/sql/DataFrameFunctionsSuite.scala | 107 ---- 6 files changed, 1119 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java deleted file mode 100644 index 05879902a4ed9..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions; - -/** - * Contains all the Utils methods used in the masking expressions. - */ -public class MaskExpressionsUtils { - static final int UNMASKED_VAL = -1; - - /** - * Returns the masking character for {@param c} or {@param c} is it should not be masked. - * @param c the character to transform - * @param maskedUpperChar the character to use instead of a uppercase letter - * @param maskedLowerChar the character to use instead of a lowercase letter - * @param maskedDigitChar the character to use instead of a digit - * @param maskedOtherChar the character to use instead of a any other character - * @return masking character for {@param c} - */ - public static int transformChar( - final int c, - int maskedUpperChar, - int maskedLowerChar, - int maskedDigitChar, - int maskedOtherChar) { - switch(Character.getType(c)) { - case Character.UPPERCASE_LETTER: - if(maskedUpperChar != UNMASKED_VAL) { - return maskedUpperChar; - } - break; - - case Character.LOWERCASE_LETTER: - if(maskedLowerChar != UNMASKED_VAL) { - return maskedLowerChar; - } - break; - - case Character.DECIMAL_DIGIT_NUMBER: - if(maskedDigitChar != UNMASKED_VAL) { - return maskedDigitChar; - } - break; - - default: - if(maskedOtherChar != UNMASKED_VAL) { - return maskedOtherChar; - } - break; - } - - return c; - } - - /** - * Returns the replacement char to use according to the {@param rep} specified by the user and - * the {@param def} default. - */ - public static int getReplacementChar(String rep, int def) { - if (rep != null && rep.length() > 0) { - return rep.codePointAt(0); - } - return def; - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 1d9e470c9ba83..d696ce9a766d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -440,14 +440,6 @@ object FunctionRegistry { expression[ArrayDistinct]("array_distinct"), CreateStruct.registryEntry, - // mask functions - expression[Mask]("mask"), - expression[MaskFirstN]("mask_first_n"), - expression[MaskLastN]("mask_last_n"), - expression[MaskShowFirstN]("mask_show_first_n"), - expression[MaskShowLastN]("mask_show_last_n"), - expression[MaskHash]("mask_hash"), - // misc functions expression[AssertTrue]("assert_true"), expression[Crc32]("crc32"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala deleted file mode 100644 index 276a57266a6e0..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala +++ /dev/null @@ -1,569 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.commons.codec.digest.DigestUtils - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.MaskExpressionsUtils._ -import org.apache.spark.sql.catalyst.expressions.MaskLike._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - - -trait MaskLike { - def upper: String - def lower: String - def digit: String - - protected lazy val upperReplacement: Int = getReplacementChar(upper, defaultMaskedUppercase) - protected lazy val lowerReplacement: Int = getReplacementChar(lower, defaultMaskedLowercase) - protected lazy val digitReplacement: Int = getReplacementChar(digit, defaultMaskedDigit) - - protected val maskUtilsClassName: String = classOf[MaskExpressionsUtils].getName - - def inputStringLengthCode(inputString: String, length: String): String = { - s"${CodeGenerator.JAVA_INT} $length = $inputString.codePointCount(0, $inputString.length());" - } - - def appendMaskedToStringBuilderCode( - ctx: CodegenContext, - sb: String, - inputString: String, - offset: String, - numChars: String): String = { - val i = ctx.freshName("i") - val codePoint = ctx.freshName("codePoint") - s""" - |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) { - | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset); - | $sb.appendCodePoint($maskUtilsClassName.transformChar($codePoint, - | $upperReplacement, $lowerReplacement, - | $digitReplacement, $defaultMaskedOther)); - | $offset += Character.charCount($codePoint); - |} - """.stripMargin - } - - def appendUnchangedToStringBuilderCode( - ctx: CodegenContext, - sb: String, - inputString: String, - offset: String, - numChars: String): String = { - val i = ctx.freshName("i") - val codePoint = ctx.freshName("codePoint") - s""" - |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) { - | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset); - | $sb.appendCodePoint($codePoint); - | $offset += Character.charCount($codePoint); - |} - """.stripMargin - } - - def appendMaskedToStringBuilder( - sb: java.lang.StringBuilder, - inputString: String, - startOffset: Int, - numChars: Int): Int = { - var offset = startOffset - (1 to numChars) foreach { _ => - val codePoint = inputString.codePointAt(offset) - sb.appendCodePoint(transformChar( - codePoint, - upperReplacement, - lowerReplacement, - digitReplacement, - defaultMaskedOther)) - offset += Character.charCount(codePoint) - } - offset - } - - def appendUnchangedToStringBuilder( - sb: java.lang.StringBuilder, - inputString: String, - startOffset: Int, - numChars: Int): Int = { - var offset = startOffset - (1 to numChars) foreach { _ => - val codePoint = inputString.codePointAt(offset) - sb.appendCodePoint(codePoint) - offset += Character.charCount(codePoint) - } - offset - } -} - -trait MaskLikeWithN extends MaskLike { - def n: Int - protected lazy val charCount: Int = if (n < 0) 0 else n -} - -/** - * Utils for mask operations. - */ -object MaskLike { - val defaultCharCount = 4 - val defaultMaskedUppercase: Int = 'X' - val defaultMaskedLowercase: Int = 'x' - val defaultMaskedDigit: Int = 'n' - val defaultMaskedOther: Int = MaskExpressionsUtils.UNMASKED_VAL - - def extractCharCount(e: Expression): Int = e match { - case Literal(i, IntegerType | NullType) => - if (i == null) defaultCharCount else i.asInstanceOf[Int] - case Literal(_, dt) => throw new AnalysisException("Expected literal expression of type " + - s"${IntegerType.simpleString}, but got literal of ${dt.simpleString}") - case other => throw new AnalysisException(s"Expected literal expression, but got ${other.sql}") - } - - def extractReplacement(e: Expression): String = e match { - case Literal(s, StringType | NullType) => if (s == null) null else s.toString - case Literal(_, dt) => throw new AnalysisException("Expected literal expression of type " + - s"${StringType.simpleString}, but got literal of ${dt.simpleString}") - case other => throw new AnalysisException(s"Expected literal expression, but got ${other.sql}") - } -} - -/** - * Masks the input string. Additional parameters can be set to change the masking chars for - * uppercase letters, lowercase letters and digits. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str[, upper[, lower[, digit]]]) - Masks str. By default, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", - examples = """ - Examples: - > SELECT _FUNC_("abcd-EFGH-8765-4321", "U", "l", "#"); - llll-UUUU-####-#### - """) -// scalastyle:on line.size.limit -case class Mask(child: Expression, upper: String, lower: String, digit: String) - extends UnaryExpression with ExpectsInputTypes with MaskLike { - - def this(child: Expression) = this(child, null.asInstanceOf[String], null, null) - - def this(child: Expression, upper: Expression) = - this(child, extractReplacement(upper), null, null) - - def this(child: Expression, upper: Expression, lower: Expression) = - this(child, extractReplacement(upper), extractReplacement(lower), null) - - def this(child: Expression, upper: Expression, lower: Expression, digit: Expression) = - this(child, extractReplacement(upper), extractReplacement(lower), extractReplacement(digit)) - - override def nullSafeEval(input: Any): Any = { - val str = input.asInstanceOf[UTF8String].toString - val length = str.codePointCount(0, str.length()) - val sb = new java.lang.StringBuilder(length) - appendMaskedToStringBuilder(sb, str, 0, length) - UTF8String.fromString(sb.toString) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - s""" - |String $inputString = $input.toString(); - |${inputStringLengthCode(inputString, length)} - |StringBuilder $sb = new StringBuilder($length); - |${CodeGenerator.JAVA_INT} $offset = 0; - |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, length)} - |${ev.value} = UTF8String.fromString($sb.toString()); - """.stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) -} - -/** - * Masks the first N chars of the input string. N defaults to 4. Additional parameters can be set - * to change the masking chars for uppercase letters, lowercase letters and digits. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks the first n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", - examples = """ - Examples: - > SELECT _FUNC_("1234-5678-8765-4321", 4); - nnnn-5678-8765-4321 - """) -// scalastyle:on line.size.limit -case class MaskFirstN( - child: Expression, - n: Int, - upper: String, - lower: String, - digit: String) - extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { - - def this(child: Expression) = - this(child, defaultCharCount, null, null, null) - - def this(child: Expression, n: Expression) = - this(child, extractCharCount(n), null, null, null) - - def this(child: Expression, n: Expression, upper: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), null, null) - - def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) - - def this( - child: Expression, - n: Expression, - upper: Expression, - lower: Expression, - digit: Expression) = - this(child, - extractCharCount(n), - extractReplacement(upper), - extractReplacement(lower), - extractReplacement(digit)) - - override def nullSafeEval(input: Any): Any = { - val str = input.asInstanceOf[UTF8String].toString - val length = str.codePointCount(0, str.length()) - val endOfMask = if (charCount > length) length else charCount - val sb = new java.lang.StringBuilder(length) - val offset = appendMaskedToStringBuilder(sb, str, 0, endOfMask) - appendUnchangedToStringBuilder(sb, str, offset, length - endOfMask) - UTF8String.fromString(sb.toString) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - val endOfMask = ctx.freshName("endOfMask") - s""" - |String $inputString = $input.toString(); - |${inputStringLengthCode(inputString, length)} - |${CodeGenerator.JAVA_INT} $endOfMask = $charCount > $length ? $length : $charCount; - |${CodeGenerator.JAVA_INT} $offset = 0; - |StringBuilder $sb = new StringBuilder($length); - |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)} - |${appendUnchangedToStringBuilderCode( - ctx, sb, inputString, offset, s"$length - $endOfMask")} - |${ev.value} = UTF8String.fromString($sb.toString()); - |""".stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def prettyName: String = "mask_first_n" -} - -/** - * Masks the last N chars of the input string. N defaults to 4. Additional parameters can be set - * to change the masking chars for uppercase letters, lowercase letters and digits. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks the last n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", - examples = """ - Examples: - > SELECT _FUNC_("1234-5678-8765-4321", 4); - 1234-5678-8765-nnnn - """, since = "2.4.0") -// scalastyle:on line.size.limit -case class MaskLastN( - child: Expression, - n: Int, - upper: String, - lower: String, - digit: String) - extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { - - def this(child: Expression) = - this(child, defaultCharCount, null, null, null) - - def this(child: Expression, n: Expression) = - this(child, extractCharCount(n), null, null, null) - - def this(child: Expression, n: Expression, upper: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), null, null) - - def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) - - def this( - child: Expression, - n: Expression, - upper: Expression, - lower: Expression, - digit: Expression) = - this(child, - extractCharCount(n), - extractReplacement(upper), - extractReplacement(lower), - extractReplacement(digit)) - - override def nullSafeEval(input: Any): Any = { - val str = input.asInstanceOf[UTF8String].toString - val length = str.codePointCount(0, str.length()) - val startOfMask = if (charCount >= length) 0 else length - charCount - val sb = new java.lang.StringBuilder(length) - val offset = appendUnchangedToStringBuilder(sb, str, 0, startOfMask) - appendMaskedToStringBuilder(sb, str, offset, length - startOfMask) - UTF8String.fromString(sb.toString) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - val startOfMask = ctx.freshName("startOfMask") - s""" - |String $inputString = $input.toString(); - |${inputStringLengthCode(inputString, length)} - |${CodeGenerator.JAVA_INT} $startOfMask = $charCount >= $length ? - | 0 : $length - $charCount; - |${CodeGenerator.JAVA_INT} $offset = 0; - |StringBuilder $sb = new StringBuilder($length); - |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)} - |${appendMaskedToStringBuilderCode( - ctx, sb, inputString, offset, s"$length - $startOfMask")} - |${ev.value} = UTF8String.fromString($sb.toString()); - |""".stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def prettyName: String = "mask_last_n" -} - -/** - * Masks all but the first N chars of the input string. N defaults to 4. Additional parameters can - * be set to change the masking chars for uppercase letters, lowercase letters and digits. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks all but the first n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", - examples = """ - Examples: - > SELECT _FUNC_("1234-5678-8765-4321", 4); - 1234-nnnn-nnnn-nnnn - """, since = "2.4.0") -// scalastyle:on line.size.limit -case class MaskShowFirstN( - child: Expression, - n: Int, - upper: String, - lower: String, - digit: String) - extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { - - def this(child: Expression) = - this(child, defaultCharCount, null, null, null) - - def this(child: Expression, n: Expression) = - this(child, extractCharCount(n), null, null, null) - - def this(child: Expression, n: Expression, upper: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), null, null) - - def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) - - def this( - child: Expression, - n: Expression, - upper: Expression, - lower: Expression, - digit: Expression) = - this(child, - extractCharCount(n), - extractReplacement(upper), - extractReplacement(lower), - extractReplacement(digit)) - - override def nullSafeEval(input: Any): Any = { - val str = input.asInstanceOf[UTF8String].toString - val length = str.codePointCount(0, str.length()) - val startOfMask = if (charCount > length) length else charCount - val sb = new java.lang.StringBuilder(length) - val offset = appendUnchangedToStringBuilder(sb, str, 0, startOfMask) - appendMaskedToStringBuilder(sb, str, offset, length - startOfMask) - UTF8String.fromString(sb.toString) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - val startOfMask = ctx.freshName("startOfMask") - s""" - |String $inputString = $input.toString(); - |${inputStringLengthCode(inputString, length)} - |${CodeGenerator.JAVA_INT} $startOfMask = $charCount > $length ? $length : $charCount; - |${CodeGenerator.JAVA_INT} $offset = 0; - |StringBuilder $sb = new StringBuilder($length); - |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)} - |${appendMaskedToStringBuilderCode( - ctx, sb, inputString, offset, s"$length - $startOfMask")} - |${ev.value} = UTF8String.fromString($sb.toString()); - |""".stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def prettyName: String = "mask_show_first_n" -} - -/** - * Masks all but the last N chars of the input string. N defaults to 4. Additional parameters can - * be set to change the masking chars for uppercase letters, lowercase letters and digits. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks all but the last n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", - examples = """ - Examples: - > SELECT _FUNC_("1234-5678-8765-4321", 4); - nnnn-nnnn-nnnn-4321 - """, since = "2.4.0") -// scalastyle:on line.size.limit -case class MaskShowLastN( - child: Expression, - n: Int, - upper: String, - lower: String, - digit: String) - extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { - - def this(child: Expression) = - this(child, defaultCharCount, null, null, null) - - def this(child: Expression, n: Expression) = - this(child, extractCharCount(n), null, null, null) - - def this(child: Expression, n: Expression, upper: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), null, null) - - def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) - - def this( - child: Expression, - n: Expression, - upper: Expression, - lower: Expression, - digit: Expression) = - this(child, - extractCharCount(n), - extractReplacement(upper), - extractReplacement(lower), - extractReplacement(digit)) - - override def nullSafeEval(input: Any): Any = { - val str = input.asInstanceOf[UTF8String].toString - val length = str.codePointCount(0, str.length()) - val endOfMask = if (charCount >= length) 0 else length - charCount - val sb = new java.lang.StringBuilder(length) - val offset = appendMaskedToStringBuilder(sb, str, 0, endOfMask) - appendUnchangedToStringBuilder(sb, str, offset, length - endOfMask) - UTF8String.fromString(sb.toString) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - val endOfMask = ctx.freshName("endOfMask") - s""" - |String $inputString = $input.toString(); - |${inputStringLengthCode(inputString, length)} - |${CodeGenerator.JAVA_INT} $endOfMask = $charCount >= $length ? 0 : $length - $charCount; - |${CodeGenerator.JAVA_INT} $offset = 0; - |StringBuilder $sb = new StringBuilder($length); - |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)} - |${appendUnchangedToStringBuilderCode( - ctx, sb, inputString, offset, s"$length - $endOfMask")} - |${ev.value} = UTF8String.fromString($sb.toString()); - |""".stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def prettyName: String = "mask_show_last_n" -} - -/** - * Returns a hashed value based on str. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str) - Returns a hashed value based on str. The hash is consistent and can be used to join masked values together across tables.", - examples = """ - Examples: - > SELECT _FUNC_("abcd-EFGH-8765-4321"); - 60c713f5ec6912229d2060df1c322776 - """) -// scalastyle:on line.size.limit -case class MaskHash(child: Expression) - extends UnaryExpression with ExpectsInputTypes { - - override def nullSafeEval(input: Any): Any = { - UTF8String.fromString(DigestUtils.md5Hex(input.asInstanceOf[UTF8String].toString)) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val digestUtilsClass = classOf[DigestUtils].getName.stripSuffix("$") - s""" - |${ev.value} = UTF8String.fromString($digestUtilsClass.md5Hex($input.toString())); - |""".stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def prettyName: String = "mask_hash" -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala deleted file mode 100644 index 4d69dc32ace82..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala +++ /dev/null @@ -1,236 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.types.{IntegerType, StringType} - -class MaskExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { - - test("mask") { - checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), "U", "l", "#"), "llll-UUUU-####-####") - checkEvaluation( - new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U"), Literal("l"), Literal("#")), - "llll-UUUU-####-####") - checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U"), Literal("l")), - "llll-UUUU-nnnn-nnnn") - checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U")), "xxxx-UUUU-nnnn-nnnn") - checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321")), "xxxx-XXXX-nnnn-nnnn") - checkEvaluation(new Mask(Literal(null, StringType)), null) - checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), null, "l", "#"), "llll-XXXX-####-####") - checkEvaluation(new Mask( - Literal("abcd-EFGH-8765-4321"), - Literal(null, StringType), - Literal(null, StringType), - Literal(null, StringType)), "xxxx-XXXX-nnnn-nnnn") - checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("Upper")), - "xxxx-UUUU-nnnn-nnnn") - checkEvaluation(new Mask(Literal("")), "") - checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("")), "xxxx-XXXX-nnnn-nnnn") - checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), "", "", ""), "xxxx-XXXX-nnnn-nnnn") - // scalastyle:off nonascii - checkEvaluation(Mask(Literal("Ul9U"), "\u2200", null, null), "\u2200xn\u2200") - checkEvaluation(new Mask(Literal("Hello World, こんにちは, 𠀋"), Literal("あ"), Literal("𡈽")), - "あ𡈽𡈽𡈽𡈽 あ𡈽𡈽𡈽𡈽, こんにちは, 𠀋") - // scalastyle:on nonascii - intercept[AnalysisException] { - checkEvaluation(new Mask(Literal(""), Literal(1)), "") - } - } - - test("mask_first_n") { - checkEvaluation(MaskFirstN(Literal("aB3d-EFGH-8765"), 6, "U", "l", "#"), - "lU#l-UFGH-8765") - checkEvaluation(new MaskFirstN( - Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), - "llll-UFGH-8765-4321") - checkEvaluation( - new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), - "llll-UFGH-8765-4321") - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), - "xxxx-UFGH-8765-4321") - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6)), - "xxxx-XFGH-8765-4321") - intercept[AnalysisException] { - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") - } - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321")), "xxxx-EFGH-8765-4321") - checkEvaluation(new MaskFirstN(Literal(null, StringType)), null) - checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), - "llll-EFGH-8765-4321") - checkEvaluation(new MaskFirstN( - Literal("abcd-EFGH-8765-4321"), - Literal(null, IntegerType), - Literal(null, StringType), - Literal(null, StringType), - Literal(null, StringType)), "xxxx-EFGH-8765-4321") - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), - "xxxx-UFGH-8765-4321") - checkEvaluation(new MaskFirstN(Literal("")), "") - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), - "xxxx-EFGH-8765-4321") - checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), - "xxxx-XXXX-nnnn-nnnn") - checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), - "abcd-EFGH-8765-4321") - // scalastyle:off nonascii - checkEvaluation(MaskFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U") - checkEvaluation(new MaskFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)), - "あ, 𠀋, Xxxxo World") - // scalastyle:on nonascii - } - - test("mask_last_n") { - checkEvaluation(MaskLastN(Literal("abcd-EFGH-aB3d"), 6, "U", "l", "#"), - "abcd-EFGU-lU#l") - checkEvaluation(new MaskLastN( - Literal("abcd-EFGH-8765"), Literal(6), Literal("U"), Literal("l"), Literal("#")), - "abcd-EFGU-####") - checkEvaluation( - new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6), Literal("U"), Literal("l")), - "abcd-EFGU-nnnn") - checkEvaluation( - new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6), Literal("U")), - "abcd-EFGU-nnnn") - checkEvaluation( - new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6)), - "abcd-EFGX-nnnn") - intercept[AnalysisException] { - checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765"), Literal("U")), "") - } - checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321")), "abcd-EFGH-8765-nnnn") - checkEvaluation(new MaskLastN(Literal(null, StringType)), null) - checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), - "abcd-EFGH-8765-nnnn") - checkEvaluation(new MaskLastN( - Literal("abcd-EFGH-8765-4321"), - Literal(null, IntegerType), - Literal(null, StringType), - Literal(null, StringType), - Literal(null, StringType)), "abcd-EFGH-8765-nnnn") - checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321"), Literal(12), Literal("Upper")), - "abcd-EFUU-nnnn-nnnn") - checkEvaluation(new MaskLastN(Literal("")), "") - checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321"), Literal(16), Literal("")), - "abcx-XXXX-nnnn-nnnn") - checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), - "xxxx-XXXX-nnnn-nnnn") - checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), - "abcd-EFGH-8765-4321") - // scalastyle:off nonascii - checkEvaluation(MaskLastN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200") - checkEvaluation(new MaskLastN(Literal("あ, 𠀋, Hello World あ 𠀋"), Literal(10)), - "あ, 𠀋, Hello Xxxxx あ 𠀋") - // scalastyle:on nonascii - } - - test("mask_show_first_n") { - checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-aB3d"), 6, "U", "l", "#"), - "abcd-EUUU-####-lU#l") - checkEvaluation(new MaskShowFirstN( - Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), - "abcd-EUUU-####-####") - checkEvaluation( - new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), - "abcd-EUUU-nnnn-nnnn") - checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), - "abcd-EUUU-nnnn-nnnn") - checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6)), - "abcd-EXXX-nnnn-nnnn") - intercept[AnalysisException] { - checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") - } - checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321")), "abcd-XXXX-nnnn-nnnn") - checkEvaluation(new MaskShowFirstN(Literal(null, StringType)), null) - checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), - "abcd-UUUU-nnnn-nnnn") - checkEvaluation(new MaskShowFirstN( - Literal("abcd-EFGH-8765-4321"), - Literal(null, IntegerType), - Literal(null, StringType), - Literal(null, StringType), - Literal(null, StringType)), "abcd-XXXX-nnnn-nnnn") - checkEvaluation( - new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), - "abcd-EUUU-nnnn-nnnn") - checkEvaluation(new MaskShowFirstN(Literal("")), "") - checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), - "abcd-XXXX-nnnn-nnnn") - checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), - "abcd-EFGH-8765-4321") - checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), - "xxxx-XXXX-nnnn-nnnn") - // scalastyle:off nonascii - checkEvaluation(MaskShowFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200") - checkEvaluation(new MaskShowFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)), - "あ, 𠀋, Hellx Xxxxx") - // scalastyle:on nonascii - } - - test("mask_show_last_n") { - checkEvaluation(MaskShowLastN(Literal("aB3d-EFGH-8765"), 6, "U", "l", "#"), - "lU#l-UUUH-8765") - checkEvaluation(new MaskShowLastN( - Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), - "llll-UUUU-###5-4321") - checkEvaluation( - new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), - "llll-UUUU-nnn5-4321") - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), - "xxxx-UUUU-nnn5-4321") - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6)), - "xxxx-XXXX-nnn5-4321") - intercept[AnalysisException] { - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") - } - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321")), "xxxx-XXXX-nnnn-4321") - checkEvaluation(new MaskShowLastN(Literal(null, StringType)), null) - checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), - "llll-UUUU-nnnn-4321") - checkEvaluation(new MaskShowLastN( - Literal("abcd-EFGH-8765-4321"), - Literal(null, IntegerType), - Literal(null, StringType), - Literal(null, StringType), - Literal(null, StringType)), "xxxx-XXXX-nnnn-4321") - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), - "xxxx-UUUU-nnn5-4321") - checkEvaluation(new MaskShowLastN(Literal("")), "") - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), - "xxxx-XXXX-nnnn-4321") - checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), - "abcd-EFGH-8765-4321") - checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), - "xxxx-XXXX-nnnn-nnnn") - // scalastyle:off nonascii - checkEvaluation(MaskShowLastN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U") - checkEvaluation(new MaskShowLastN(Literal("あ, 𠀋, Hello World"), Literal(10)), - "あ, 𠀋, Xello World") - // scalastyle:on nonascii - } - - test("mask_hash") { - checkEvaluation(MaskHash(Literal("abcd-EFGH-8765-4321")), "60c713f5ec6912229d2060df1c322776") - checkEvaluation(MaskHash(Literal("")), "d41d8cd98f00b204e9800998ecf8427e") - checkEvaluation(MaskHash(Literal(null, StringType)), null) - // scalastyle:off nonascii - checkEvaluation(MaskHash(Literal("\u2200x9U")), "f1243ef123d516b1f32a3a75309e5711") - // scalastyle:on nonascii - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b98ab11e56feb..de1d422856ba9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3646,125 +3646,6 @@ object functions { @scala.annotation.varargs def map_concat(cols: Column*): Column = withExpr { MapConcat(cols.map(_.expr)) } - ////////////////////////////////////////////////////////////////////////////////////////////// - // Mask functions - ////////////////////////////////////////////////////////////////////////////////////////////// - /** - * Returns a string which is the masked representation of the input. - * @group mask_funcs - * @since 2.4.0 - */ - def mask(e: Column): Column = withExpr { new Mask(e.expr) } - - /** - * Returns a string which is the masked representation of the input, using `upper`, `lower` and - * `digit` as replacement characters. - * @group mask_funcs - * @since 2.4.0 - */ - def mask(e: Column, upper: String, lower: String, digit: String): Column = withExpr { - Mask(e.expr, upper, lower, digit) - } - - /** - * Returns a string with the first `n` characters masked. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_first_n(e: Column, n: Int): Column = withExpr { new MaskFirstN(e.expr, Literal(n)) } - - /** - * Returns a string with the first `n` characters masked, using `upper`, `lower` and `digit` as - * replacement characters. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_first_n( - e: Column, - n: Int, - upper: String, - lower: String, - digit: String): Column = withExpr { - MaskFirstN(e.expr, n, upper, lower, digit) - } - - /** - * Returns a string with the last `n` characters masked. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_last_n(e: Column, n: Int): Column = withExpr { new MaskLastN(e.expr, Literal(n)) } - - /** - * Returns a string with the last `n` characters masked, using `upper`, `lower` and `digit` as - * replacement characters. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_last_n( - e: Column, - n: Int, - upper: String, - lower: String, - digit: String): Column = withExpr { - MaskLastN(e.expr, n, upper, lower, digit) - } - - /** - * Returns a string with all but the first `n` characters masked. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_show_first_n(e: Column, n: Int): Column = withExpr { - new MaskShowFirstN(e.expr, Literal(n)) - } - - /** - * Returns a string with all but the first `n` characters masked, using `upper`, `lower` and - * `digit` as replacement characters. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_show_first_n( - e: Column, - n: Int, - upper: String, - lower: String, - digit: String): Column = withExpr { - MaskShowFirstN(e.expr, n, upper, lower, digit) - } - - /** - * Returns a string with all but the last `n` characters masked. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_show_last_n(e: Column, n: Int): Column = withExpr { - new MaskShowLastN(e.expr, Literal(n)) - } - - /** - * Returns a string with all but the last `n` characters masked, using `upper`, `lower` and - * `digit` as replacement characters. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_show_last_n( - e: Column, - n: Int, - upper: String, - lower: String, - digit: String): Column = withExpr { - MaskShowLastN(e.expr, n, upper, lower, digit) - } - - /** - * Returns a hashed value based on the input column. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_hash(e: Column): Column = withExpr { MaskHash(e.expr) } - // scalastyle:off line.size.limit // scalastyle:off parameter.number diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 3f6f4556e2d6b..a39aef18d27e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -309,113 +309,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } - test("mask functions") { - val df = Seq("TestString-123", "", null).toDF("a") - checkAnswer(df.select(mask($"a")), Seq(Row("XxxxXxxxxx-nnn"), Row(""), Row(null))) - checkAnswer(df.select(mask_first_n($"a", 4)), Seq(Row("XxxxString-123"), Row(""), Row(null))) - checkAnswer(df.select(mask_last_n($"a", 4)), Seq(Row("TestString-nnn"), Row(""), Row(null))) - checkAnswer(df.select(mask_show_first_n($"a", 4)), - Seq(Row("TestXxxxxx-nnn"), Row(""), Row(null))) - checkAnswer(df.select(mask_show_last_n($"a", 4)), - Seq(Row("XxxxXxxxxx-123"), Row(""), Row(null))) - checkAnswer(df.select(mask_hash($"a")), - Seq(Row("dd78d68ad1b23bde126812482dd70ac6"), - Row("d41d8cd98f00b204e9800998ecf8427e"), - Row(null))) - - checkAnswer(df.select(mask($"a", "U", "l", "#")), - Seq(Row("UlllUlllll-###"), Row(""), Row(null))) - checkAnswer(df.select(mask_first_n($"a", 4, "U", "l", "#")), - Seq(Row("UlllString-123"), Row(""), Row(null))) - checkAnswer(df.select(mask_last_n($"a", 4, "U", "l", "#")), - Seq(Row("TestString-###"), Row(""), Row(null))) - checkAnswer(df.select(mask_show_first_n($"a", 4, "U", "l", "#")), - Seq(Row("TestUlllll-###"), Row(""), Row(null))) - checkAnswer(df.select(mask_show_last_n($"a", 4, "U", "l", "#")), - Seq(Row("UlllUlllll-123"), Row(""), Row(null))) - - checkAnswer( - df.selectExpr("mask(a)", "mask(a, 'U')", "mask(a, 'U', 'l')", "mask(a, 'U', 'l', '#')"), - Seq(Row("XxxxXxxxxx-nnn", "UxxxUxxxxx-nnn", "UlllUlllll-nnn", "UlllUlllll-###"), - Row("", "", "", ""), - Row(null, null, null, null))) - checkAnswer(sql("select mask(null)"), Row(null)) - checkAnswer(sql("select mask('AAaa11', null, null, null)"), Row("XXxxnn")) - intercept[AnalysisException] { - checkAnswer(df.selectExpr("mask(a, a)"), Seq(Row("XxxxXxxxxx-nnn"), Row(""), Row(null))) - } - - checkAnswer( - df.selectExpr( - "mask_first_n(a)", - "mask_first_n(a, 6)", - "mask_first_n(a, 6, 'U')", - "mask_first_n(a, 6, 'U', 'l')", - "mask_first_n(a, 6, 'U', 'l', '#')"), - Seq(Row("XxxxString-123", "XxxxXxring-123", "UxxxUxring-123", "UlllUlring-123", - "UlllUlring-123"), - Row("", "", "", "", ""), - Row(null, null, null, null, null))) - checkAnswer(sql("select mask_first_n(null)"), Row(null)) - checkAnswer(sql("select mask_first_n('A1aA1a', null, null, null, null)"), Row("XnxX1a")) - intercept[AnalysisException] { - checkAnswer(spark.range(1).selectExpr("mask_first_n('A1aA1a', id)"), Row("XnxX1a")) - } - - checkAnswer( - df.selectExpr( - "mask_last_n(a)", - "mask_last_n(a, 6)", - "mask_last_n(a, 6, 'U')", - "mask_last_n(a, 6, 'U', 'l')", - "mask_last_n(a, 6, 'U', 'l', '#')"), - Seq(Row("TestString-nnn", "TestStrixx-nnn", "TestStrixx-nnn", "TestStrill-nnn", - "TestStrill-###"), - Row("", "", "", "", ""), - Row(null, null, null, null, null))) - checkAnswer(sql("select mask_last_n(null)"), Row(null)) - checkAnswer(sql("select mask_last_n('A1aA1a', null, null, null, null)"), Row("A1xXnx")) - intercept[AnalysisException] { - checkAnswer(spark.range(1).selectExpr("mask_last_n('A1aA1a', id)"), Row("A1xXnx")) - } - - checkAnswer( - df.selectExpr( - "mask_show_first_n(a)", - "mask_show_first_n(a, 6)", - "mask_show_first_n(a, 6, 'U')", - "mask_show_first_n(a, 6, 'U', 'l')", - "mask_show_first_n(a, 6, 'U', 'l', '#')"), - Seq(Row("TestXxxxxx-nnn", "TestStxxxx-nnn", "TestStxxxx-nnn", "TestStllll-nnn", - "TestStllll-###"), - Row("", "", "", "", ""), - Row(null, null, null, null, null))) - checkAnswer(sql("select mask_show_first_n(null)"), Row(null)) - checkAnswer(sql("select mask_show_first_n('A1aA1a', null, null, null, null)"), Row("A1aAnx")) - intercept[AnalysisException] { - checkAnswer(spark.range(1).selectExpr("mask_show_first_n('A1aA1a', id)"), Row("A1aAnx")) - } - - checkAnswer( - df.selectExpr( - "mask_show_last_n(a)", - "mask_show_last_n(a, 6)", - "mask_show_last_n(a, 6, 'U')", - "mask_show_last_n(a, 6, 'U', 'l')", - "mask_show_last_n(a, 6, 'U', 'l', '#')"), - Seq(Row("XxxxXxxxxx-123", "XxxxXxxxng-123", "UxxxUxxxng-123", "UlllUlllng-123", - "UlllUlllng-123"), - Row("", "", "", "", ""), - Row(null, null, null, null, null))) - checkAnswer(sql("select mask_show_last_n(null)"), Row(null)) - checkAnswer(sql("select mask_show_last_n('A1aA1a', null, null, null, null)"), Row("XnaA1a")) - intercept[AnalysisException] { - checkAnswer(spark.range(1).selectExpr("mask_show_last_n('A1aA1a', id)"), Row("XnaA1a")) - } - - checkAnswer(sql("select mask_hash(null)"), Row(null)) - } - test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), From ba437fc5c73b95ee4c59327abf3161c58f64cb12 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 16 Jul 2018 14:35:44 -0700 Subject: [PATCH 1137/2461] [SPARK-24805][SQL] Do not ignore avro files without extensions by default ## What changes were proposed in this pull request? In the PR, I propose to change default behaviour of AVRO datasource which currently ignores files without `.avro` extension in read by default. This PR sets the default value for `avro.mapred.ignore.inputs.without.extension` to `false` in the case if the parameter is not set by an user. ## How was this patch tested? Added a test file without extension in AVRO format, and new test for reading the file with and wihout specified schema. Author: Maxim Gekk Author: Maxim Gekk Closes #21769 from MaxGekk/avro-without-extension. --- .../spark/sql/avro/AvroFileFormat.scala | 14 +++--- .../org/apache/spark/sql/avro/AvroSuite.scala | 45 ++++++++++++++++--- 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index fb93033bb15d4..9eb206457809c 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -62,7 +62,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { // Schema evolution is not supported yet. Here we only pick a single random sample file to // figure out the schema of the whole dataset. val sampleFile = - if (conf.getBoolean(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, true)) { + if (AvroFileFormat.ignoreFilesWithoutExtensions(conf)) { files.find(_.getPath.getName.endsWith(".avro")).getOrElse { throw new FileNotFoundException( "No Avro files found. Hadoop option \"avro.mapred.ignore.inputs.without.extension\" " + @@ -170,10 +170,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { // Doing input file filtering is improper because we may generate empty tasks that process no // input files but stress the scheduler. We should probably add a more general input file // filtering mechanism for `FileFormat` data sources. See SPARK-16317. - if ( - conf.getBoolean(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, true) && - !file.filePath.endsWith(".avro") - ) { + if (AvroFileFormat.ignoreFilesWithoutExtensions(conf) && !file.filePath.endsWith(".avro")) { Iterator.empty } else { val reader = { @@ -278,4 +275,11 @@ private[avro] object AvroFileFormat { value.readFields(new DataInputStream(in)) } } + + def ignoreFilesWithoutExtensions(conf: Configuration): Boolean = { + // Files without .avro extensions are not ignored by default + val defaultValue = false + + conf.getBoolean(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, defaultValue) + } } diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 9c6526b29dca3..446b42124ceca 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.avro import java.io._ -import java.nio.file.Files +import java.net.URL +import java.nio.file.{Files, Path, Paths} import java.sql.{Date, Timestamp} import java.util.{TimeZone, UUID} @@ -622,7 +623,12 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { intercept[FileNotFoundException] { withTempPath { dir => FileUtils.touch(new File(dir, "test")) - spark.read.avro(dir.toString) + val hadoopConf = spark.sqlContext.sparkContext.hadoopConfiguration + try { + hadoopConf.set(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true") + spark.read.avro(dir.toString) + } finally { + hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty) } } } @@ -684,12 +690,18 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Files.createFile(new File(tempSaveDir, "non-avro").toPath) - val newDf = spark - .read - .option(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true") - .avro(tempSaveDir) + val hadoopConf = spark.sqlContext.sparkContext.hadoopConfiguration + val count = try { + hadoopConf.set(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true") + val newDf = spark + .read + .avro(tempSaveDir) + newDf.count() + } finally { + hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty) + } - assert(newDf.count == 8) + assert(count == 8) } } @@ -805,4 +817,23 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(readDf.collect().sameElements(writeDf.collect())) } } + + test("SPARK-24805: do not ignore files without .avro extension by default") { + withTempDir { dir => + Files.copy( + Paths.get(new URL(episodesAvro).toURI), + Paths.get(dir.getCanonicalPath, "episodes")) + + val fileWithoutExtension = s"${dir.getCanonicalPath}/episodes" + val df1 = spark.read.avro(fileWithoutExtension) + assert(df1.count == 8) + + val schema = new StructType() + .add("title", StringType) + .add("air_date", StringType) + .add("doctor", IntegerType) + val df2 = spark.read.schema(schema).avro(fileWithoutExtension) + assert(df2.count == 8) + } + } } From 0f0d1865f581a9158d73505471953656b173beba Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Mon, 16 Jul 2018 15:33:39 -0700 Subject: [PATCH 1138/2461] [SPARK-24402][SQL] Optimize `In` expression when only one element in the collection or collection is empty ## What changes were proposed in this pull request? Two new rules in the logical plan optimizers are added. 1. When there is only one element in the **`Collection`**, the physical plan will be optimized to **`EqualTo`**, so predicate pushdown can be used. ```scala profileDF.filter( $"profileID".isInCollection(Set(6))).explain(true) """ |== Physical Plan == |*(1) Project [profileID#0] |+- *(1) Filter (isnotnull(profileID#0) && (profileID#0 = 6)) | +- *(1) FileScan parquet [profileID#0] Batched: true, Format: Parquet, | PartitionFilters: [], | PushedFilters: [IsNotNull(profileID), EqualTo(profileID,6)], | ReadSchema: struct """.stripMargin ``` 2. When the **`Collection`** is empty, and the input is nullable, the logical plan will be simplified to ```scala profileDF.filter( $"profileID".isInCollection(Set())).explain(true) """ |== Optimized Logical Plan == |Filter if (isnull(profileID#0)) null else false |+- Relation[profileID#0] parquet """.stripMargin ``` TODO: 1. For multiple conditions with numbers less than certain thresholds, we should still allow predicate pushdown. 2. Optimize the **`In`** using **`tableswitch`** or **`lookupswitch`** when the numbers of the categories are low, and they are **`Int`**, **`Long`**. 3. The default immutable hash trees set is slow for query, and we should do benchmark for using different set implementation for faster query. 4. **`filter(if (condition) null else false)`** can be optimized to false. ## How was this patch tested? Couple new tests are added. Author: DB Tsai Closes #21442 from dbtsai/optimize-in. --- .../sql/catalyst/optimizer/expressions.scala | 13 +++++--- .../catalyst/optimizer/OptimizeInSuite.scala | 32 +++++++++++++++++++ 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 1d363b8146e3f..f78a0ff95f382 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -218,15 +218,20 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - case In(v, list) if list.isEmpty && !v.nullable => FalseLiteral + case In(v, list) if list.isEmpty => + // When v is not nullable, the following expression will be optimized + // to FalseLiteral which is tested in OptimizeInSuite.scala + If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq - if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { + if (newList.length == 1 && !newList.isInstanceOf[ListQuery]) { + EqualTo(v, newList.head) + } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) { val hSet = newList.map(e => e.eval(EmptyRow)) InSet(v, HashSet() ++ hSet) - } else if (newList.size < list.size) { + } else if (newList.length < list.length) { expr.copy(list = newList) - } else { // newList.length == list.length + } else { // newList.length == list.length && newList.length > 1 expr } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 478118ed709f7..86522a6a54ed5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -176,6 +176,21 @@ class OptimizeInSuite extends PlanTest { } } + test("OptimizedIn test: one element in list gets transformed to EqualTo.") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Seq(Literal(1)))) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + testRelation + .where(EqualTo(UnresolvedAttribute("a"), Literal(1))) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("OptimizedIn test: In empty list gets transformed to FalseLiteral " + "when value is not nullable") { val originalQuery = @@ -191,4 +206,21 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("OptimizedIn test: In empty list gets transformed to `If` expression " + + "when value is nullable") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Nil)) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + testRelation + .where(If(IsNotNull(UnresolvedAttribute("a")), + Literal(false), Literal.create(null, BooleanType))) + .analyze + + comparePlans(optimized, correctAnswer) + } } From d57a267b79f4015508c3686c34a0f438bad41ea1 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Tue, 17 Jul 2018 09:13:35 +0800 Subject: [PATCH 1139/2461] [SPARK-23259][SQL] Clean up legacy code around hive external catalog and HiveClientImpl ## What changes were proposed in this pull request? Three legacy statements are removed by this patch: - in HiveExternalCatalog: The withClient wrapper is not necessary for the private method getRawTable. - in HiveClientImpl: There are some redundant code in both the tableExists and getTableOption method. This PR takes over https://github.com/apache/spark/pull/20425 ## How was this patch tested? Existing tests Closes #20425 Author: hyukjinkwon Closes #21780 from HyukjinKwon/SPARK-23259. --- .../org/apache/spark/sql/hive/HiveExternalCatalog.scala | 2 +- .../org/apache/spark/sql/hive/client/HiveClientImpl.scala | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 011a3ba553cb2..44480cecf0039 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -114,7 +114,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * should interpret these special data source properties and restore the original table metadata * before returning it. */ - private[hive] def getRawTable(db: String, table: String): CatalogTable = withClient { + private[hive] def getRawTable(db: String, table: String): CatalogTable = { client.getTable(db, table) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 1df46d7431a21..db8fd5a43d842 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -353,15 +353,19 @@ private[hive] class HiveClientImpl( client.getDatabasesByPattern(pattern).asScala } + private def getRawTableOption(dbName: String, tableName: String): Option[HiveTable] = { + Option(client.getTable(dbName, tableName, false /* do not throw exception */)) + } + override def tableExists(dbName: String, tableName: String): Boolean = withHiveState { - Option(client.getTable(dbName, tableName, false /* do not throw exception */)).nonEmpty + getRawTableOption(dbName, tableName).nonEmpty } override def getTableOption( dbName: String, tableName: String): Option[CatalogTable] = withHiveState { logDebug(s"Looking up $dbName.$tableName") - Option(client.getTable(dbName, tableName, false)).map { h => + getRawTableOption(dbName, tableName).map { h => // Note: Hive separates partition columns and the schema, but for us the // partition columns are part of the schema val cols = h.getCols.asScala.map(fromHiveColumn) From f876d3fa800ae04ec33f27295354669bb1db911e Mon Sep 17 00:00:00 2001 From: Miklos C Date: Tue, 17 Jul 2018 09:22:16 +0800 Subject: [PATCH 1140/2461] [SPARK-20220][DOCS] Documentation Add thrift scheduling pool config to scheduling docs ## What changes were proposed in this pull request? The thrift scheduling pool configuration was removed from a previous release. Adding this back to the job scheduling configuration docs. This PR takes over #17536 and handle some comments here. ## How was this patch tested? Manually. Closes #17536 Author: hyukjinkwon Closes #21778 from HyukjinKwon/SPARK-20220. --- docs/job-scheduling.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index da90342406c84..2316f175676ee 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -264,3 +264,11 @@ within it for the various settings. For example: A full example is also available in `conf/fairscheduler.xml.template`. Note that any pools not configured in the XML file will simply get default values for all settings (scheduling mode FIFO, weight 1, and minShare 0). + +## Scheduling using JDBC Connections +To set a [Fair Scheduler](job-scheduling.html#fair-scheduler-pools) pool for a JDBC client session, +users can set the `spark.sql.thriftserver.scheduler.pool` variable: + +{% highlight SQL %} +SET spark.sql.thriftserver.scheduler.pool=accounting; +{% endhighlight %} From 0ca16f6e143768f0c96b5310c1f81b3b51dcbbc8 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 17 Jul 2018 11:30:53 +0800 Subject: [PATCH 1141/2461] Revert "[SPARK-24402][SQL] Optimize `In` expression when only one element in the collection or collection is empty" This reverts commit 0f0d1865f581a9158d73505471953656b173beba. --- .../sql/catalyst/optimizer/expressions.scala | 13 +++----- .../catalyst/optimizer/OptimizeInSuite.scala | 32 ------------------- 2 files changed, 4 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index f78a0ff95f382..1d363b8146e3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -218,20 +218,15 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - case In(v, list) if list.isEmpty => - // When v is not nullable, the following expression will be optimized - // to FalseLiteral which is tested in OptimizeInSuite.scala - If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) + case In(v, list) if list.isEmpty && !v.nullable => FalseLiteral case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq - if (newList.length == 1 && !newList.isInstanceOf[ListQuery]) { - EqualTo(v, newList.head) - } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) { + if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { val hSet = newList.map(e => e.eval(EmptyRow)) InSet(v, HashSet() ++ hSet) - } else if (newList.length < list.length) { + } else if (newList.size < list.size) { expr.copy(list = newList) - } else { // newList.length == list.length && newList.length > 1 + } else { // newList.length == list.length expr } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 86522a6a54ed5..478118ed709f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -176,21 +176,6 @@ class OptimizeInSuite extends PlanTest { } } - test("OptimizedIn test: one element in list gets transformed to EqualTo.") { - val originalQuery = - testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1)))) - .analyze - - val optimized = Optimize.execute(originalQuery) - val correctAnswer = - testRelation - .where(EqualTo(UnresolvedAttribute("a"), Literal(1))) - .analyze - - comparePlans(optimized, correctAnswer) - } - test("OptimizedIn test: In empty list gets transformed to FalseLiteral " + "when value is not nullable") { val originalQuery = @@ -206,21 +191,4 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - - test("OptimizedIn test: In empty list gets transformed to `If` expression " + - "when value is nullable") { - val originalQuery = - testRelation - .where(In(UnresolvedAttribute("a"), Nil)) - .analyze - - val optimized = Optimize.execute(originalQuery) - val correctAnswer = - testRelation - .where(If(IsNotNull(UnresolvedAttribute("a")), - Literal(false), Literal.create(null, BooleanType))) - .analyze - - comparePlans(optimized, correctAnswer) - } } From 4cf1bec4dc574c541d03ea2f49db4de8b76ef6d2 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Tue, 17 Jul 2018 23:07:18 +0800 Subject: [PATCH 1142/2461] [SPARK-24305][SQL][FOLLOWUP] Avoid serialization of private fields in collection expressions. ## What changes were proposed in this pull request? The PR tries to avoid serialization of private fields of already added collection functions and follows up on comments in [SPARK-23922](https://github.com/apache/spark/pull/21028) and [SPARK-23935](https://github.com/apache/spark/pull/21236) ## How was this patch tested? Run tests from: - CollectionExpressionSuite.scala - DataFrameFunctionsSuite.scala Author: Marek Novotny Closes #21352 from mn-mikke/SPARK-24305. --- .../expressions/collectionOperations.scala | 132 +++++++++--------- 1 file changed, 64 insertions(+), 68 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 972bc6e57892c..d60f4c36fa214 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -168,27 +168,22 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType) - override def dataType: DataType = ArrayType(mountSchema) - - override def nullable: Boolean = children.exists(_.nullable) - - private lazy val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType]) - - private lazy val arrayElementTypes = arrayTypes.map(_.elementType) - - @transient private lazy val mountSchema: StructType = { + @transient override lazy val dataType: DataType = { val fields = children.zip(arrayElementTypes).zipWithIndex.map { case ((expr: NamedExpression, elementType), _) => StructField(expr.name, elementType, nullable = true) case ((_, elementType), idx) => StructField(idx.toString, elementType, nullable = true) } - StructType(fields) + ArrayType(StructType(fields), containsNull = false) } - @transient lazy val numberOfArrays: Int = children.length + override def nullable: Boolean = children.exists(_.nullable) + + @transient private lazy val arrayElementTypes = + children.map(_.dataType.asInstanceOf[ArrayType].elementType) - @transient lazy val genericArrayData = classOf[GenericArrayData].getName + private def genericArrayData = classOf[GenericArrayData].getName def emptyInputGenCode(ev: ExprCode): ExprCode = { ev.copy(code""" @@ -256,7 +251,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI ("ArrayData[]", arrVals) :: Nil) val initVariables = s""" - |ArrayData[] $arrVals = new ArrayData[$numberOfArrays]; + |ArrayData[] $arrVals = new ArrayData[${children.length}]; |int $biggestCardinality = 0; |${CodeGenerator.javaType(dataType)} ${ev.value} = null; """.stripMargin @@ -268,7 +263,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI |if (!${ev.isNull}) { | Object[] $args = new Object[$biggestCardinality]; | for (int $i = 0; $i < $biggestCardinality; $i ++) { - | Object[] $currentRow = new Object[$numberOfArrays]; + | Object[] $currentRow = new Object[${children.length}]; | $getValueForTypeSplitted | $args[$i] = new $genericInternalRow($currentRow); | } @@ -278,7 +273,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (numberOfArrays == 0) { + if (children.length == 0) { emptyInputGenCode(ev) } else { nonEmptyInputGenCode(ctx, ev) @@ -360,7 +355,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp override def inputTypes: Seq[AbstractDataType] = Seq(MapType) - lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType] + @transient private lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType] override def dataType: DataType = { ArrayType( @@ -520,7 +515,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres } } - override def dataType: MapType = { + @transient override lazy val dataType: MapType = { if (children.isEmpty) { MapType(StringType, StringType) } else { @@ -747,11 +742,11 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { case _ => None } - private def nullEntries: Boolean = dataTypeDetails.get._3 + @transient private lazy val nullEntries: Boolean = dataTypeDetails.get._3 override def nullable: Boolean = child.nullable || nullEntries - override def dataType: MapType = dataTypeDetails.get._1 + @transient override lazy val dataType: MapType = dataTypeDetails.get._1 override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { case Some(_) => TypeCheckResult.TypeCheckSuccess @@ -949,8 +944,7 @@ trait ArraySortLike extends ExpectsInputTypes { protected def nullOrder: NullOrder - @transient - private lazy val lt: Comparator[Any] = { + @transient private lazy val lt: Comparator[Any] = { val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -972,8 +966,7 @@ trait ArraySortLike extends ExpectsInputTypes { } } - @transient - private lazy val gt: Comparator[Any] = { + @transient private lazy val gt: Comparator[Any] = { val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -995,7 +988,9 @@ trait ArraySortLike extends ExpectsInputTypes { } } - def elementType: DataType = arrayExpression.dataType.asInstanceOf[ArrayType].elementType + @transient lazy val elementType: DataType = + arrayExpression.dataType.asInstanceOf[ArrayType].elementType + def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull def sortEval(array: Any, ascending: Boolean): Any = { @@ -1211,7 +1206,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI override def dataType: DataType = child.dataType - lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType override def nullSafeEval(input: Any): Any = input match { case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse) @@ -1601,9 +1596,9 @@ case class Slice(x: Expression, start: Expression, length: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) - override def children: Seq[Expression] = Seq(x, start, length) + @transient override lazy val children: Seq[Expression] = Seq(x, start, length) // called from eval - lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { val startInt = startVal.asInstanceOf[Int] @@ -1889,7 +1884,7 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) - private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + @transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { val typeCheckResult = super.checkInputDataTypes() @@ -1930,7 +1925,7 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast min } - override def dataType: DataType = child.dataType match { + @transient override lazy val dataType: DataType = child.dataType match { case ArrayType(dt, _) => dt case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") } @@ -1954,7 +1949,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) - private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + @transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { val typeCheckResult = super.checkInputDataTypes() @@ -1995,7 +1990,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast max } - override def dataType: DataType = child.dataType match { + @transient override lazy val dataType: DataType = child.dataType match { case ArrayType(dt, _) => dt case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") } @@ -2097,10 +2092,13 @@ case class ArrayPosition(left: Expression, right: Expression) since = "2.4.0") case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { - @transient private lazy val ordering: Ordering[Any] = - TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType) + @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType + + @transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull + + @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(mapKeyType) - override def dataType: DataType = left.dataType match { + @transient override lazy val dataType: DataType = left.dataType match { case ArrayType(elementType, _) => elementType case MapType(_, valueType, _) => valueType } @@ -2109,7 +2107,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti Seq(TypeCollection(ArrayType, MapType), left.dataType match { case _: ArrayType => IntegerType - case _: MapType => left.dataType.asInstanceOf[MapType].keyType + case _: MapType => mapKeyType case _ => AnyDataType // no match for a wrong 'left' expression type } ) @@ -2119,8 +2117,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti super.checkInputDataTypes() match { case f: TypeCheckResult.TypeCheckFailure => f case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] => - TypeUtils.checkForOrderingExpr( - left.dataType.asInstanceOf[MapType].keyType, s"function $prettyName") + TypeUtils.checkForOrderingExpr(mapKeyType, s"function $prettyName") case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess } } @@ -2142,14 +2139,14 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti } else { array.numElements() + index } - if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) { + if (arrayContainsNull && array.isNullAt(idx)) { null } else { array.get(idx, dataType) } } case _: MapType => - getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType, ordering) + getValueEval(value, ordinal, mapKeyType, ordering) } } @@ -2158,7 +2155,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti case _: ArrayType => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { val index = ctx.freshName("elementAtIndex") - val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + val nullCheck = if (arrayContainsNull) { s""" |if ($eval1.isNullAt($index)) { | ${ev.isNull} = true; @@ -2209,9 +2206,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti """) case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression { - private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - - val allowedTypes = Seq(StringType, BinaryType, ArrayType) + private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType) override def checkInputDataTypes(): TypeCheckResult = { if (children.isEmpty) { @@ -2228,7 +2223,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio } } - override def dataType: DataType = { + @transient override lazy val dataType: DataType = { if (children.isEmpty) { StringType } else { @@ -2236,7 +2231,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio } } - lazy val javaType: String = CodeGenerator.javaType(dataType) + private def javaType: String = CodeGenerator.javaType(dataType) override def nullable: Boolean = children.exists(_.nullable) @@ -2256,9 +2251,10 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio } else { val arrayData = inputs.map(_.asInstanceOf[ArrayData]) val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) - if (numberOfElements > MAX_ARRAY_LENGTH) { + if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" + - s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + " elements due to exceeding the array size limit " + + ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".") } val finalData = new Array[AnyRef](numberOfElements.toInt) var position = 0 @@ -2316,9 +2312,10 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio |for (int z = 0; z < ${children.length}; z++) { | $numElements += args[z].numElements(); |} - |if ($numElements > $MAX_ARRAY_LENGTH) { + |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + - | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + | " elements due to exceeding the array size limit" + + | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); |} """.stripMargin @@ -2413,15 +2410,13 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio since = "2.4.0") case class Flatten(child: Expression) extends UnaryExpression { - private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - - private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] + private def childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] override def nullable: Boolean = child.nullable || childDataType.containsNull - override def dataType: DataType = childDataType.elementType + @transient override lazy val dataType: DataType = childDataType.elementType - lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType override def checkInputDataTypes(): TypeCheckResult = child.dataType match { case ArrayType(_: ArrayType, _) => @@ -2441,9 +2436,10 @@ case class Flatten(child: Expression) extends UnaryExpression { } else { val arrayData = elements.map(_.asInstanceOf[ArrayData]) val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements()) - if (numberOfElements > MAX_ARRAY_LENGTH) { + if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + - s"$numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + s"$numberOfElements elements due to exceeding the array size limit " + + ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".") } val flattenedData = new Array(numberOfElements.toInt) var position = 0 @@ -2476,9 +2472,10 @@ case class Flatten(child: Expression) extends UnaryExpression { |for (int z = 0; z < $childVariableName.numElements(); z++) { | $variableName += $childVariableName.getArray(z).numElements(); |} - |if ($variableName > $MAX_ARRAY_LENGTH) { + |if ($variableName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + - | $variableName + " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + | $variableName + " elements due to exceeding the array size limit" + + | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); |} """.stripMargin (code, variableName) @@ -2602,7 +2599,7 @@ case class Sequence( override def nullable: Boolean = children.exists(_.nullable) - override lazy val dataType: ArrayType = ArrayType(start.dataType, containsNull = false) + override def dataType: ArrayType = ArrayType(start.dataType, containsNull = false) override def checkInputDataTypes(): TypeCheckResult = { val startType = start.dataType @@ -2633,7 +2630,7 @@ case class Sequence( stepOpt.map(step => if (step.dataType != CalendarIntervalType) Cast(step, widerType) else step), timeZoneId) - private lazy val impl: SequenceImpl = dataType.elementType match { + @transient private lazy val impl: SequenceImpl = dataType.elementType match { case iType: IntegralType => type T = iType.InternalType val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe)) @@ -2953,8 +2950,6 @@ object Sequence { case class ArrayRepeat(left: Expression, right: Expression) extends BinaryExpression with ExpectsInputTypes { - private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - override def dataType: ArrayType = ArrayType(left.dataType, left.nullable) override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType) @@ -2966,9 +2961,9 @@ case class ArrayRepeat(left: Expression, right: Expression) if (count == null) { null } else { - if (count.asInstanceOf[Int] > MAX_ARRAY_LENGTH) { + if (count.asInstanceOf[Int] > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw new RuntimeException(s"Unsuccessful try to create array with $count elements " + - s"due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + s"due to exceeding the array size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); } val element = left.eval(input) new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element)) @@ -3027,9 +3022,10 @@ case class ArrayRepeat(left: Expression, right: Expression) |if ($count > 0) { | $numElements = $count; |} - |if ($numElements > $MAX_ARRAY_LENGTH) { + |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | throw new RuntimeException("Unsuccessful try to create array with " + $numElements + - | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + | " elements due to exceeding the array size limit" + + | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); |} """.stripMargin @@ -3111,7 +3107,7 @@ case class ArrayRemove(left: Expression, right: Expression) Seq(ArrayType, elementType) } - lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType + private def elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(right.dataType) @@ -3228,7 +3224,7 @@ case class ArrayDistinct(child: Expression) override def dataType: DataType = child.dataType - @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(elementType) From 5215344deaa5533e593c62aba3fcdfa1a2901801 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 17 Jul 2018 11:23:34 -0500 Subject: [PATCH 1143/2461] [SPARK-24813][BUILD][FOLLOW-UP][HOTFIX] HiveExternalCatalogVersionsSuite still flaky; fall back to Apache archive ## What changes were proposed in this pull request? Test HiveExternalCatalogVersionsSuite vs only current Spark releases ## How was this patch tested? `HiveExternalCatalogVersionsSuite` Author: Sean Owen Closes #21793 from srowen/SPARK-24813.3. --- .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index f8212684d5335..5103aa8a207db 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -203,7 +203,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.0.2", "2.1.2", "2.2.1", "2.3.1") + val testingVersions = Seq("2.1.3", "2.2.2", "2.3.1") protected var spark: SparkSession = _ From 7688ce88b2ea514054200845ae860fbccc25a927 Mon Sep 17 00:00:00 2001 From: HanShuliang Date: Tue, 17 Jul 2018 11:25:23 -0500 Subject: [PATCH 1144/2461] [SPARK-21590][SS] Window start time should support negative values ## What changes were proposed in this pull request? Remove the non-negative checks of window start time to make window support negative start time, and add a check to guarantee the absolute value of start time is less than slide duration. ## How was this patch tested? New unit tests. Author: HanShuliang Closes #18903 from KevinZwx/dev. --- .../sql/catalyst/expressions/TimeWindow.scala | 9 ++-- .../analysis/AnalysisErrorSuite.scala | 25 +++++++---- .../expressions/TimeWindowSuite.scala | 13 ++++++ .../sql/DataFrameTimeWindowingSuite.scala | 45 +++++++++++++++++++ 4 files changed, 77 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 84e38a8b2711e..8e48856d4607c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -80,16 +80,13 @@ case class TimeWindow( if (slideDuration <= 0) { return TypeCheckFailure(s"The slide duration ($slideDuration) must be greater than 0.") } - if (startTime < 0) { - return TypeCheckFailure(s"The start time ($startTime) must be greater than or equal to 0.") - } if (slideDuration > windowDuration) { return TypeCheckFailure(s"The slide duration ($slideDuration) must be less than or equal" + s" to the windowDuration ($windowDuration).") } - if (startTime >= slideDuration) { - return TypeCheckFailure(s"The start time ($startTime) must be less than the " + - s"slideDuration ($slideDuration).") + if (startTime.abs >= slideDuration) { + return TypeCheckFailure(s"The absolute value of start time ($startTime) must be less " + + s"than the slideDuration ($slideDuration).") } } dataTypeCheck diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 5d2f8e735e3d4..0ce94d39e994a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -334,14 +334,28 @@ class AnalysisErrorSuite extends AnalysisTest { "start time greater than slide duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 minute").as("window")), - "The start time " :: " must be less than the slideDuration " :: Nil + "The absolute value of start time " :: " must be less than the slideDuration " :: Nil ) errorTest( "start time equal to slide duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 second").as("window")), - "The start time " :: " must be less than the slideDuration " :: Nil + "The absolute value of start time " :: " must be less than the slideDuration " :: Nil + ) + + errorTest( + "SPARK-21590: absolute value of start time greater than slide duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "-1 minute").as("window")), + "The absolute value of start time " :: " must be less than the slideDuration " :: Nil + ) + + errorTest( + "SPARK-21590: absolute value of start time equal to slide duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "-1 second").as("window")), + "The absolute value of start time " :: " must be less than the slideDuration " :: Nil ) errorTest( @@ -372,13 +386,6 @@ class AnalysisErrorSuite extends AnalysisTest { "The slide duration" :: " must be greater than 0." :: Nil ) - errorTest( - "negative start time in time window", - testRelation.select( - TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "-5 second").as("window")), - "The start time" :: "must be greater than or equal to 0." :: Nil - ) - errorTest( "generator nested in expressions", listRelation.select(Explode('list) + 1), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala index 351d4d0c2eac9..d46135c02bc01 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala @@ -77,6 +77,19 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva } } + test("SPARK-21590: Start time works with negative values and return microseconds") { + val validDuration = "10 minutes" + for ((text, seconds) <- Seq( + ("-10 seconds", -10000000), // -1e7 + ("-1 minute", -60000000), + ("-1 hour", -3600000000L))) { // -6e7 + assert(TimeWindow(Literal(10L), validDuration, validDuration, "interval " + text).startTime + === seconds) + assert(TimeWindow(Literal(10L), validDuration, validDuration, text).startTime + === seconds) + } + } + private val parseExpression = PrivateMethod[Long]('parseExpression) test("parse sql expression for duration in microseconds - string") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index 6fe356877c268..2953425b1db49 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -43,6 +43,22 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B ) } + test("SPARK-21590: tumbling window using negative start time") { + val df = Seq( + ("2016-03-27 19:39:30", 1, "a"), + ("2016-03-27 19:39:25", 2, "a")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(window($"time", "10 seconds", "10 seconds", "-5 seconds")) + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select($"window.start".cast("string"), $"window.end".cast("string"), $"counts"), + Seq( + Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 2) + ) + ) + } + test("tumbling window groupBy statement") { val df = Seq( ("2016-03-27 19:39:34", 1, "a"), @@ -72,6 +88,20 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B Seq(Row(1), Row(1), Row(1))) } + test("SPARK-21590: tumbling window groupBy statement with negative startTime") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(window($"time", "10 seconds", "10 seconds", "-5 seconds"), $"id") + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select("counts"), + Seq(Row(1), Row(1), Row(1))) + } + test("tumbling window with multi-column projection") { val df = Seq( ("2016-03-27 19:39:34", 1, "a"), @@ -309,4 +339,19 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B ) } } + + test("SPARK-21590: time window in SQL with three expressions including negative start time") { + withTempTable { table => + checkAnswer( + spark.sql( + s"""select window(time, "10 seconds", 10000000, "-5 seconds"), value from $table""") + .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), + Seq( + Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 1), + Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 4), + Row("2016-03-27 19:39:55", "2016-03-27 19:40:05", 2) + ) + ) + } + } } From 912634b004c2302533a8a8501b4ecb803d17e335 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Tue, 17 Jul 2018 13:11:52 -0700 Subject: [PATCH 1145/2461] [SPARK-24747][ML] Make Instrumentation class more flexible ## What changes were proposed in this pull request? This PR updates the Instrumentation class to make it more flexible and a little bit easier to use. When these APIs are merged, I'll followup with a PR to update the training code to use these new APIs so we can remove the old APIs. These changes are all to private APIs so this PR doesn't make any user facing changes. ## How was this patch tested? Existing tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Bago Amirbekian Closes #21719 from MrBago/new-instrumentation-apis. --- .../classification/LogisticRegression.scala | 8 +- .../spark/ml/tree/impl/RandomForest.scala | 2 +- .../spark/ml/tuning/ValidatorParams.scala | 2 +- .../spark/ml/util/Instrumentation.scala | 128 ++++++++++++------ .../spark/mllib/clustering/KMeans.scala | 4 +- 5 files changed, 93 insertions(+), 51 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 92e342ed4a464..25fb9c8aab0bc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -35,6 +35,7 @@ import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer @@ -490,7 +491,7 @@ class LogisticRegression @Since("1.2.0") ( protected[spark] def train( dataset: Dataset[_], - handlePersistence: Boolean): LogisticRegressionModel = { + handlePersistence: Boolean): LogisticRegressionModel = instrumented { instr => val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { @@ -500,7 +501,8 @@ class LogisticRegression @Since("1.2.0") ( if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val instr = Instrumentation.create(this, dataset) + instr.logPipelineStage(this) + instr.logDataset(dataset) instr.logParams(regParam, elasticNetParam, standardization, threshold, maxIter, tol, fitIntercept) @@ -905,8 +907,6 @@ class LogisticRegression @Since("1.2.0") ( objectiveHistory) } model.setSummary(Some(logRegSummary)) - instr.logSuccess(model) - model } @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 905870178e549..bb3f3a015c715 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -91,7 +91,7 @@ private[spark] object RandomForest extends Logging { numTrees: Int, featureSubsetStrategy: String, seed: Long, - instr: Option[Instrumentation[_]], + instr: Option[Instrumentation], prune: Boolean = true, // exposed for testing only, real trees are always pruned parentUID: Option[String] = None): Array[DecisionTreeModel] = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala index 363304ef10147..135828815504a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -80,7 +80,7 @@ private[ml] trait ValidatorParams extends HasSeed with Params { /** * Instrumentation logging for tuning params including the inner estimator and evaluator info. */ - protected def logTuningParams(instrumentation: Instrumentation[_]): Unit = { + protected def logTuningParams(instrumentation: Instrumentation): Unit = { instrumentation.logNamedValue("estimator", $(estimator).getClass.getCanonicalName) instrumentation.logNamedValue("evaluator", $(evaluator).getClass.getCanonicalName) instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorParamMaps).length) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index 11f46eb9e4359..2e43a9ef49ee1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -19,15 +19,16 @@ package org.apache.spark.ml.util import java.util.UUID -import scala.reflect.ClassTag +import scala.util.{Failure, Success, Try} +import scala.util.control.NonFatal import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.internal.Logging -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.param.Param +import org.apache.spark.ml.{Estimator, Model, PipelineStage} +import org.apache.spark.ml.param.{Param, Params} import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset import org.apache.spark.util.Utils @@ -35,29 +36,44 @@ import org.apache.spark.util.Utils /** * A small wrapper that defines a training session for an estimator, and some methods to log * useful information during this session. - * - * A new instance is expected to be created within fit(). - * - * @param estimator the estimator that is being fit - * @param dataset the training dataset - * @tparam E the type of the estimator */ -private[spark] class Instrumentation[E <: Estimator[_]] private ( - val estimator: E, - val dataset: RDD[_]) extends Logging { +private[spark] class Instrumentation extends Logging { private val id = UUID.randomUUID() - private val prefix = { + private val shortId = id.toString.take(8) + private val prefix = s"[$shortId] " + + // TODO: remove stage + var stage: Params = _ + // TODO: update spark.ml to use new Instrumentation APIs and remove this constructor + private def this(estimator: Estimator[_], dataset: RDD[_]) = { + this() + logPipelineStage(estimator) + logDataset(dataset) + } + + /** + * Log some info about the pipeline stage being fit. + */ + def logPipelineStage(stage: PipelineStage): Unit = { + this.stage = stage // estimator.getClass.getSimpleName can cause Malformed class name error, // call safer `Utils.getSimpleName` instead - val className = Utils.getSimpleName(estimator.getClass) - s"$className-${estimator.uid}-${dataset.hashCode()}-$id: " + val className = Utils.getSimpleName(stage.getClass) + logInfo(s"Stage class: $className") + logInfo(s"Stage uid: ${stage.uid}") } - init() + /** + * Log some data about the dataset being fit. + */ + def logDataset(dataset: Dataset[_]): Unit = logDataset(dataset.rdd) - private def init(): Unit = { - log(s"training: numPartitions=${dataset.partitions.length}" + + /** + * Log some data about the dataset being fit. + */ + def logDataset(dataset: RDD[_]): Unit = { + logInfo(s"training: numPartitions=${dataset.partitions.length}" + s" storageLevel=${dataset.getStorageLevel}") } @@ -89,23 +105,25 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( super.logInfo(prefix + msg) } - /** - * Alias for logInfo, see above. - */ - def log(msg: String): Unit = logInfo(msg) - /** * Logs the value of the given parameters for the estimator being used in this session. */ - def logParams(params: Param[_]*): Unit = { + def logParams(hasParams: Params, params: Param[_]*): Unit = { val pairs: Seq[(String, JValue)] = for { p <- params - value <- estimator.get(p) + value <- hasParams.get(p) } yield { val cast = p.asInstanceOf[Param[Any]] p.name -> parse(cast.jsonEncode(value)) } - log(compact(render(map2jvalue(pairs.toMap)))) + logInfo(compact(render(map2jvalue(pairs.toMap)))) + } + + // TODO: remove this + def logParams(params: Param[_]*): Unit = { + require(stage != null, "`logStageParams` must be called before `logParams` (or an instance of" + + " Params must be provided explicitly).") + logParams(stage, params: _*) } def logNumFeatures(num: Long): Unit = { @@ -124,35 +142,48 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( * Logs the value with customized name field. */ def logNamedValue(name: String, value: String): Unit = { - log(compact(render(name -> value))) + logInfo(compact(render(name -> value))) } def logNamedValue(name: String, value: Long): Unit = { - log(compact(render(name -> value))) + logInfo(compact(render(name -> value))) } def logNamedValue(name: String, value: Double): Unit = { - log(compact(render(name -> value))) + logInfo(compact(render(name -> value))) } def logNamedValue(name: String, value: Array[String]): Unit = { - log(compact(render(name -> compact(render(value.toSeq))))) + logInfo(compact(render(name -> compact(render(value.toSeq))))) } def logNamedValue(name: String, value: Array[Long]): Unit = { - log(compact(render(name -> compact(render(value.toSeq))))) + logInfo(compact(render(name -> compact(render(value.toSeq))))) } def logNamedValue(name: String, value: Array[Double]): Unit = { - log(compact(render(name -> compact(render(value.toSeq))))) + logInfo(compact(render(name -> compact(render(value.toSeq))))) } + // TODO: Remove this (possibly replace with logModel?) /** * Logs the successful completion of the training session. */ def logSuccess(model: Model[_]): Unit = { - log(s"training finished") + logInfo(s"training finished") + } + + def logSuccess(): Unit = { + logInfo("training finished") + } + + /** + * Logs an exception raised during a training session. + */ + def logFailure(e: Throwable): Unit = { + val msg = e.getStackTrace.mkString("\n") + super.logError(msg) } } @@ -169,22 +200,33 @@ private[spark] object Instrumentation { val varianceOfLabels = "varianceOfLabels" } + // TODO: Remove these /** * Creates an instrumentation object for a training session. */ - def create[E <: Estimator[_]]( - estimator: E, dataset: Dataset[_]): Instrumentation[E] = { - create[E](estimator, dataset.rdd) + def create(estimator: Estimator[_], dataset: Dataset[_]): Instrumentation = { + create(estimator, dataset.rdd) } /** * Creates an instrumentation object for a training session. */ - def create[E <: Estimator[_]]( - estimator: E, dataset: RDD[_]): Instrumentation[E] = { - new Instrumentation[E](estimator, dataset) + def create(estimator: Estimator[_], dataset: RDD[_]): Instrumentation = { + new Instrumentation(estimator, dataset) + } + // end remove + + def instrumented[T](body: (Instrumentation => T)): T = { + val instr = new Instrumentation() + Try(body(instr)) match { + case Failure(NonFatal(e)) => + instr.logFailure(e) + throw e + case Success(result) => + instr.logSuccess() + result + } } - } /** @@ -193,7 +235,7 @@ private[spark] object Instrumentation { * will log via it, otherwise will log via common logger. */ private[spark] class OptionalInstrumentation private( - val instrumentation: Option[Instrumentation[_ <: Estimator[_]]], + val instrumentation: Option[Instrumentation], val className: String) extends Logging { protected override def logName: String = className @@ -225,9 +267,9 @@ private[spark] object OptionalInstrumentation { /** * Creates an `OptionalInstrumentation` object from an existing `Instrumentation` object. */ - def create[E <: Estimator[_]](instr: Instrumentation[E]): OptionalInstrumentation = { + def create(instr: Instrumentation): OptionalInstrumentation = { new OptionalInstrumentation(Some(instr), - instr.estimator.getClass.getName.stripSuffix("$")) + instr.stage.getClass.getName.stripSuffix("$")) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 37ae8b1a6171a..4f554f420b903 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -235,7 +235,7 @@ class KMeans private ( private[spark] def run( data: RDD[Vector], - instr: Option[Instrumentation[NewKMeans]]): KMeansModel = { + instr: Option[Instrumentation]): KMeansModel = { if (data.getStorageLevel == StorageLevel.NONE) { logWarning("The input data is not directly cached, which may hurt performance if its" @@ -264,7 +264,7 @@ class KMeans private ( */ private def runAlgorithm( data: RDD[VectorWithNorm], - instr: Option[Instrumentation[NewKMeans]]): KMeansModel = { + instr: Option[Instrumentation]): KMeansModel = { val sc = data.sparkContext From 2a4dd6f06cfd2f58fda9786c88809e6de695444e Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 17 Jul 2018 14:15:30 -0700 Subject: [PATCH 1146/2461] [SPARK-24681][SQL] Verify nested column names in Hive metastore ## What changes were proposed in this pull request? This pr added code to check if nested column names do not include ',', ':', and ';' because Hive metastore can't handle these characters in nested column names; ref: https://github.com/apache/hive/blob/release-1.2.1/serde/src/java/org/apache/hadoop/hive/serde2/typeinfo/TypeInfoUtils.java#L239 ## How was this patch tested? Added tests in `HiveDDLSuite`. Author: Takeshi Yamamuro Closes #21711 from maropu/SPARK-24681. --- .../spark/sql/hive/HiveExternalCatalog.scala | 34 +++++++++++++++---- .../sql/hive/execution/HiveDDLSuite.scala | 19 +++++++++++ 2 files changed, 46 insertions(+), 7 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 44480cecf0039..7f28fc40b4469 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -138,17 +138,37 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } /** - * Checks the validity of data column names. Hive metastore disallows the table to use comma in - * data column names. Partition columns do not have such a restriction. Views do not have such - * a restriction. + * Checks the validity of data column names. Hive metastore disallows the table to use some + * special characters (',', ':', and ';') in data column names, including nested column names. + * Partition columns do not have such a restriction. Views do not have such a restriction. */ private def verifyDataSchema( tableName: TableIdentifier, tableType: CatalogTableType, dataSchema: StructType): Unit = { if (tableType != VIEW) { - dataSchema.map(_.name).foreach { colName => - if (colName.contains(",")) { - throw new AnalysisException("Cannot create a table having a column whose name contains " + - s"commas in Hive metastore. Table: $tableName; Column: $colName") + val invalidChars = Seq(",", ":", ";") + def verifyNestedColumnNames(schema: StructType): Unit = schema.foreach { f => + f.dataType match { + case st: StructType => verifyNestedColumnNames(st) + case _ if invalidChars.exists(f.name.contains) => + val invalidCharsString = invalidChars.map(c => s"'$c'").mkString(", ") + val errMsg = "Cannot create a table having a nested column whose name contains " + + s"invalid characters ($invalidCharsString) in Hive metastore. Table: $tableName; " + + s"Column: ${f.name}" + throw new AnalysisException(errMsg) + case _ => + } + } + + dataSchema.foreach { f => + f.dataType match { + // Checks top-level column names + case _ if f.name.contains(",") => + throw new AnalysisException("Cannot create a table having a column whose name " + + s"contains commas in Hive metastore. Table: $tableName; Column: ${f.name}") + // Checks nested column names + case st: StructType => + verifyNestedColumnNames(st) + case _ => } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 0341c3b378918..31fd4c5a1f996 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.HiveUtils.{CONVERT_METASTORE_ORC, CONVERT_METASTORE_PARQUET} import org.apache.spark.sql.hive.orc.OrcFileOperator @@ -2248,4 +2249,22 @@ class HiveDDLSuite checkAnswer(spark.table("t4"), Row(0, 0)) } } + + test("SPARK-24681 checks if nested column names do not include ',', ':', and ';'") { + val expectedMsg = "Cannot create a table having a nested column whose name contains invalid " + + "characters (',', ':', ';') in Hive metastore." + + Seq("nested,column", "nested:column", "nested;column").foreach { nestedColumnName => + withTable("t") { + val e = intercept[AnalysisException] { + spark.range(1) + .select(struct(lit(0).as(nestedColumnName)).as("toplevel")) + .write + .format("hive") + .saveAsTable("t") + }.getMessage + assert(e.contains(expectedMsg)) + } + } + } } From 681845fd62bbcbbf1d9309b7d8a252198d96c738 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 17 Jul 2018 17:33:52 -0700 Subject: [PATCH 1147/2461] [SPARK-24402][SQL] Optimize `In` expression when only one element in the collection or collection is empty ## What changes were proposed in this pull request? Two new rules in the logical plan optimizers are added. 1. When there is only one element in the **`Collection`**, the physical plan will be optimized to **`EqualTo`**, so predicate pushdown can be used. ```scala profileDF.filter( $"profileID".isInCollection(Set(6))).explain(true) """ |== Physical Plan == |*(1) Project [profileID#0] |+- *(1) Filter (isnotnull(profileID#0) && (profileID#0 = 6)) | +- *(1) FileScan parquet [profileID#0] Batched: true, Format: Parquet, | PartitionFilters: [], | PushedFilters: [IsNotNull(profileID), EqualTo(profileID,6)], | ReadSchema: struct """.stripMargin ``` 2. When the **`Collection`** is empty, and the input is nullable, the logical plan will be simplified to ```scala profileDF.filter( $"profileID".isInCollection(Set())).explain(true) """ |== Optimized Logical Plan == |Filter if (isnull(profileID#0)) null else false |+- Relation[profileID#0] parquet """.stripMargin ``` TODO: 1. For multiple conditions with numbers less than certain thresholds, we should still allow predicate pushdown. 2. Optimize the **`In`** using **`tableswitch`** or **`lookupswitch`** when the numbers of the categories are low, and they are **`Int`**, **`Long`**. 3. The default immutable hash trees set is slow for query, and we should do benchmark for using different set implementation for faster query. 4. **`filter(if (condition) null else false)`** can be optimized to false. ## How was this patch tested? Couple new tests are added. Author: DB Tsai Closes #21797 from dbtsai/optimize-in. --- .../sql/catalyst/optimizer/expressions.scala | 17 +++++++--- .../catalyst/optimizer/OptimizeInSuite.scala | 32 +++++++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 1d363b8146e3f..cf17f59599968 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -218,15 +218,24 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - case In(v, list) if list.isEmpty && !v.nullable => FalseLiteral + case In(v, list) if list.isEmpty => + // When v is not nullable, the following expression will be optimized + // to FalseLiteral which is tested in OptimizeInSuite.scala + If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq - if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { + if (newList.length == 1 + // TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed, + // TODO: we exclude them in this rule. + && !v.isInstanceOf[CreateNamedStructLike] + && !newList.head.isInstanceOf[CreateNamedStructLike]) { + EqualTo(v, newList.head) + } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) { val hSet = newList.map(e => e.eval(EmptyRow)) InSet(v, HashSet() ++ hSet) - } else if (newList.size < list.size) { + } else if (newList.length < list.length) { expr.copy(list = newList) - } else { // newList.length == list.length + } else { // newList.length == list.length && newList.length > 1 expr } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 478118ed709f7..86522a6a54ed5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -176,6 +176,21 @@ class OptimizeInSuite extends PlanTest { } } + test("OptimizedIn test: one element in list gets transformed to EqualTo.") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Seq(Literal(1)))) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + testRelation + .where(EqualTo(UnresolvedAttribute("a"), Literal(1))) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("OptimizedIn test: In empty list gets transformed to FalseLiteral " + "when value is not nullable") { val originalQuery = @@ -191,4 +206,21 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("OptimizedIn test: In empty list gets transformed to `If` expression " + + "when value is nullable") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Nil)) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + testRelation + .where(If(IsNotNull(UnresolvedAttribute("a")), + Literal(false), Literal.create(null, BooleanType))) + .analyze + + comparePlans(optimized, correctAnswer) + } } From fc2e18963efdf4b50258f85c8779122742876910 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 18 Jul 2018 10:00:13 +0800 Subject: [PATCH 1148/2461] [SPARK-24529][BUILD][TEST-MAVEN][FOLLOW-UP] Set spotbugs-maven-plugin's fork to true ## What changes were proposed in this pull request? Set `spotbugs-maven-plugin`'s fork to `true`, otherwise will throw exception when make distribution: ``` ./dev/make-distribution.sh --name SPARK-24529 --tgz -Phadoop-2.7 -Phive -Phive-thriftserver -Pyarn -Phadoop-provided ``` exception: ```java ... [INFO] Reactor Summary: [INFO] [INFO] Spark Project Parent POM ........................... SUCCESS [ 8.753 s] [INFO] Spark Project Tags ................................. SUCCESS [ 9.334 s] [INFO] Spark Project Sketch ............................... SUCCESS [ 12.029 s] [INFO] Spark Project Local DB ............................. SUCCESS [ 13.641 s] [INFO] Spark Project Networking ........................... FAILURE [10:10 min] [INFO] Spark Project Shuffle Streaming Service ............ SKIPPED [INFO] Spark Project Unsafe ............................... SUCCESS [ 16.415 s] [INFO] Spark Project Launcher ............................. SKIPPED [INFO] Spark Project Core ................................. SKIPPED [INFO] Spark Project ML Local Library ..................... SKIPPED [INFO] Spark Project GraphX ............................... SKIPPED [INFO] Spark Project Streaming ............................ SKIPPED [INFO] Spark Project Catalyst ............................. SKIPPED [INFO] Spark Project SQL .................................. SKIPPED [INFO] Spark Project ML Library ........................... SKIPPED [INFO] Spark Project Tools ................................ SUCCESS [ 8.750 s] [INFO] Spark Project Hive ................................. SKIPPED [INFO] Spark Project REPL ................................. SKIPPED [INFO] Spark Project YARN Shuffle Service ................. SKIPPED [INFO] Spark Project YARN ................................. SKIPPED [INFO] Spark Project Hive Thrift Server ................... SKIPPED [INFO] Spark Project Assembly ............................. SKIPPED [INFO] Spark Integration for Kafka 0.10 ................... SKIPPED [INFO] Kafka 0.10 Source for Structured Streaming ......... SKIPPED [INFO] Spark Project Examples ............................. SKIPPED [INFO] Spark Integration for Kafka 0.10 Assembly .......... SKIPPED [INFO] Spark Avro ......................................... SKIPPED [INFO] ------------------------------------------------------------------------ [INFO] BUILD FAILURE [INFO] ------------------------------------------------------------------------ [INFO] Total time: 10:29 min (Wall Clock) [INFO] Finished at: 2018-07-16T21:39:46+08:00 [INFO] Final Memory: 61M/885M [INFO] ------------------------------------------------------------------------ Timeout: sub-process interrupted [ERROR] Failed to execute goal com.github.spotbugs:spotbugs-maven-plugin:3.1.3:spotbugs (spotbugs) on project spark-network-common_2.11: Execution spotbugs of goal com.github.spotbugs:spotbugs-maven-plugin:3.1.3:spotbugs failed: Timeout: killed the sub-process -> [Help 1] [ERROR] [ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch. [ERROR] Re-run Maven using the -X switch to enable full debug logging. [ERROR] [ERROR] For more information about the errors and possible solutions, please read the following articles: [ERROR] [Help 1] http://cwiki.apache.org/confluence/display/MAVEN/PluginExecutionException [ERROR] [ERROR] After correcting the problems, you can resume the build with the command [ERROR] mvn -rf :spark-network-common_2.11 org.apache.tools.ant.ExitException: Permission ("java.lang.RuntimePermission" "exitVM") was not granted. at org.apache.tools.ant.types.Permissions$MySM.checkExit(Permissions.java:194) at java.lang.Runtime.exit(Runtime.java:107) at java.lang.System.exit(System.java:971) at org.codehaus.plexus.classworlds.launcher.Launcher.main(Launcher.java:358) Exception in thread "main" org.apache.tools.ant.ExitException: Permission ("java.lang.RuntimePermission" "exitVM") was not granted. at org.apache.tools.ant.types.Permissions$MySM.checkExit(Permissions.java:194) at java.lang.Runtime.exit(Runtime.java:107) at java.lang.System.exit(System.java:971) at org.codehaus.plexus.classworlds.launcher.Launcher.main(Launcher.java:364) Timeout: sub-process interrupted ``` ## How was this patch tested? manual tests Author: Yuming Wang Closes #21785 from wangyum/SPARK-24529. --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 039292337eaa0..1892bbee4e97b 100644 --- a/pom.xml +++ b/pom.xml @@ -2618,7 +2618,7 @@ Low true FindPuzzlers - false + true From 3b59d326c77bec96e5fb856d827139e0389394ba Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 17 Jul 2018 23:52:17 -0700 Subject: [PATCH 1149/2461] [SPARK-24576][BUILD] Upgrade Apache ORC to 1.5.2 ## What changes were proposed in this pull request? This issue aims to upgrade Apache ORC library from 1.4.4 to 1.5.2 in order to bring the following benefits into Apache Spark. - [ORC-91](https://issues.apache.org/jira/browse/ORC-91) Support for variable length blocks in HDFS (The current space wasted in ORC to padding is known to be 5%.) - [ORC-344](https://issues.apache.org/jira/browse/ORC-344) Support for using Decimal64ColumnVector In addition to that, Apache Hive 3.1 and 3.2 will use ORC 1.5.1 ([HIVE-19669](https://issues.apache.org/jira/browse/HIVE-19465)) and 1.5.2 ([HIVE-19792](https://issues.apache.org/jira/browse/HIVE-19792)) respectively. This will improve the compatibility between Apache Spark and Apache Hive by sharing the common library. ## How was this patch tested? Pass the Jenkins with all existing tests. Author: Dongjoon Hyun Closes #21582 from dongjoon-hyun/SPARK-24576. --- dev/deps/spark-deps-hadoop-2.6 | 7 +++-- dev/deps/spark-deps-hadoop-2.7 | 7 +++-- dev/deps/spark-deps-hadoop-3.1 | 7 +++-- pom.xml | 2 +- sql/core/pom.xml | 28 +++++++++++++++++++ .../datasources/orc/OrcFileFormat.scala | 15 +++++++++- .../datasources/orc/OrcSerializer.scala | 2 +- 7 files changed, 56 insertions(+), 12 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index f50a0aac0aefc..ff6d5c30c1eb4 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -2,7 +2,7 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -aircompressor-0.8.jar +aircompressor-0.10.jar antlr-2.7.7.jar antlr-runtime-3.4.jar antlr4-runtime-4.7.jar @@ -157,8 +157,9 @@ objenesis-2.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.4-nohive.jar -orc-mapreduce-1.4.4-nohive.jar +orc-core-1.5.2-nohive.jar +orc-mapreduce-1.5.2-nohive.jar +orc-shims-1.5.2.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 774f9dc39ce4d..72a94f8953c6c 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -2,7 +2,7 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -aircompressor-0.8.jar +aircompressor-0.10.jar antlr-2.7.7.jar antlr-runtime-3.4.jar antlr4-runtime-4.7.jar @@ -158,8 +158,9 @@ objenesis-2.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.4-nohive.jar -orc-mapreduce-1.4.4-nohive.jar +orc-core-1.5.2-nohive.jar +orc-mapreduce-1.5.2-nohive.jar +orc-shims-1.5.2.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 19c05ad1e991f..3409dc4613324 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -4,7 +4,7 @@ RoaringBitmap-0.5.11.jar ST4-4.0.4.jar accessors-smart-1.2.jar activation-1.1.1.jar -aircompressor-0.8.jar +aircompressor-0.10.jar antlr-2.7.7.jar antlr-runtime-3.4.jar antlr4-runtime-4.7.jar @@ -176,8 +176,9 @@ okhttp-2.7.5.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.4-nohive.jar -orc-mapreduce-1.4.4-nohive.jar +orc-core-1.5.2-nohive.jar +orc-mapreduce-1.5.2-nohive.jar +orc-shims-1.5.2.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/pom.xml b/pom.xml index 1892bbee4e97b..649221d526086 100644 --- a/pom.xml +++ b/pom.xml @@ -131,7 +131,7 @@ 1.2.1 10.12.1.1 1.10.0 - 1.4.4 + 1.5.2 nohive 1.6.0 9.3.20.v20170531 diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 18ae314309d7b..8873b00e7117a 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -90,11 +90,39 @@ org.apache.orc orc-core ${orc.classifier} + + + org.apache.hadoop + hadoop-hdfs + + + + org.apache.hive + hive-storage-api + + org.apache.orc orc-mapreduce ${orc.classifier} + + + org.apache.hadoop + hadoop-hdfs + + + + org.apache.hive + hive-storage-api + + org.apache.parquet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 3a8c0add8c2f9..df1cebed5bd0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -59,6 +59,19 @@ private[sql] object OrcFileFormat { def checkFieldNames(names: Seq[String]): Unit = { names.foreach(checkFieldName) } + + def getQuotedSchemaString(dataType: DataType): String = dataType match { + case _: AtomicType => dataType.catalogString + case StructType(fields) => + fields.map(f => s"`${f.name}`:${getQuotedSchemaString(f.dataType)}") + .mkString("struct<", ",", ">") + case ArrayType(elementType, _) => + s"array<${getQuotedSchemaString(elementType)}>" + case MapType(keyType, valueType, _) => + s"map<${getQuotedSchemaString(keyType)},${getQuotedSchemaString(valueType)}>" + case _ => // UDT and others + dataType.catalogString + } } /** @@ -93,7 +106,7 @@ class OrcFileFormat val conf = job.getConfiguration - conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, dataSchema.catalogString) + conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, OrcFileFormat.getQuotedSchemaString(dataSchema)) conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala index 899af0750cadf..90d1268028096 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala @@ -223,6 +223,6 @@ class OrcSerializer(dataSchema: StructType) { * Return a Orc value object for the given Spark schema. */ private def createOrcValue(dataType: DataType) = { - OrcStruct.createValue(TypeDescription.fromString(dataType.catalogString)) + OrcStruct.createValue(TypeDescription.fromString(OrcFileFormat.getQuotedSchemaString(dataType))) } } From 34cb3b54e9b7d4c5739cd446a9a66ec3fe59516b Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 18 Jul 2018 19:17:18 +0800 Subject: [PATCH 1150/2461] [SPARK-24386][SPARK-24768][BUILD][FOLLOWUP] Fix lint-java and Scala 2.12 build. ## What changes were proposed in this pull request? This pr fixes lint-java and Scala 2.12 build. lint-java: ``` [ERROR] src/test/resources/log4j.properties:[0] (misc) NewlineAtEndOfFile: File does not end with a newline. ``` Scala 2.12 build: ``` [error] /.../sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala:121: overloaded method value addTaskCompletionListener with alternatives: [error] (f: org.apache.spark.TaskContext => Unit)org.apache.spark.TaskContext [error] (listener: org.apache.spark.util.TaskCompletionListener)org.apache.spark.TaskContext [error] cannot be applied to (org.apache.spark.TaskContext => java.util.List[Runnable]) [error] context.addTaskCompletionListener { ctx => [error] ^ ``` ## How was this patch tested? Manually executed lint-java and Scala 2.12 build in my local environment. Author: Takuya UESHIN Closes #21801 from ueshin/issues/SPARK-24386_24768/fix_build. --- external/avro/src/test/resources/log4j.properties | 2 +- .../execution/streaming/continuous/ContinuousCoalesceRDD.scala | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/external/avro/src/test/resources/log4j.properties b/external/avro/src/test/resources/log4j.properties index c18a724007c27..f80a5291bc078 100644 --- a/external/avro/src/test/resources/log4j.properties +++ b/external/avro/src/test/resources/log4j.properties @@ -46,4 +46,4 @@ log4j.additivity.org.apache.hadoop.hive.metastore.RetryingHMSHandler=false log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=OFF log4j.additivity.hive.ql.metadata.Hive=false -log4j.logger.hive.ql.metadata.Hive=OFF \ No newline at end of file +log4j.logger.hive.ql.metadata.Hive=OFF diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala index ba85b355f974f..10d8fc553fede 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala @@ -120,6 +120,7 @@ class ContinuousCoalesceRDD( context.addTaskCompletionListener { ctx => threadPool.shutdownNow() + () } part.writersInitialized = true From 2694dd2bf084410ff346d21aaf74025b587d46a8 Mon Sep 17 00:00:00 2001 From: Nihar Sheth Date: Wed, 18 Jul 2018 09:14:36 -0500 Subject: [PATCH 1151/2461] [MINOR][CORE] Add test cases for RDD.cartesian ## What changes were proposed in this pull request? While looking through the codebase, it appeared that the scala code for RDD.cartesian does not have any tests for correctness. This adds a couple basic tests to verify cartesian yields correct values. While the implementation for RDD.cartesian is pretty simple, it always helps to have a few tests! ## How was this patch tested? The new test cases pass, and the scala style tests from running dev/run-tests all pass. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Nihar Sheth Closes #21765 from NiharS/cartesianTests. --- .../scala/org/apache/spark/rdd/RDDSuite.scala | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 5148ce05bd918..b143a468a1baf 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -443,7 +443,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { map{x => List(x)}.toList, "Tried coalescing 9 partitions to 20 but didn't get 9 back") } - test("coalesced RDDs with partial locality") { + test("coalesced RDDs with partial locality") { // Make an RDD that has some locality preferences and some without. This can happen // with UnionRDD val data = sc.makeRDD((1 to 9).map(i => { @@ -846,6 +846,28 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(partitions(1) === Seq((1, 3), (3, 8), (3, 8))) } + test("cartesian on empty RDD") { + val a = sc.emptyRDD[Int] + val b = sc.parallelize(1 to 3) + val cartesian_result = Array.empty[(Int, Int)] + assert(a.cartesian(a).collect().toList === cartesian_result) + assert(a.cartesian(b).collect().toList === cartesian_result) + assert(b.cartesian(a).collect().toList === cartesian_result) + } + + test("cartesian on non-empty RDDs") { + val a = sc.parallelize(1 to 3) + val b = sc.parallelize(2 to 4) + val c = sc.parallelize(1 to 1) + val a_cartesian_b = + Array((1, 2), (1, 3), (1, 4), (2, 2), (2, 3), (2, 4), (3, 2), (3, 3), (3, 4)) + val a_cartesian_c = Array((1, 1), (2, 1), (3, 1)) + val c_cartesian_a = Array((1, 1), (1, 2), (1, 3)) + assert(a.cartesian[Int](b).collect().toList.sorted === a_cartesian_b) + assert(a.cartesian[Int](c).collect().toList.sorted === a_cartesian_c) + assert(c.cartesian[Int](a).collect().toList.sorted === c_cartesian_a) + } + test("intersection") { val all = sc.parallelize(1 to 10) val evens = sc.parallelize(2 to 10 by 2) From 002300dd41ccc2d1c38351cb09c430f0ded6ab85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9F=A9=E7=94=B0=E7=94=B000222924?= Date: Wed, 18 Jul 2018 09:40:36 -0500 Subject: [PATCH 1152/2461] [SPARK-24804] There are duplicate words in the test title in the DatasetSuite MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In DatasetSuite.scala, in the 1299 line, test("SPARK-19896: cannot have circular references in in case class") , there are duplicate words "in in". We can get rid of one. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: 韩田田00222924 Closes #21767 from httfighter/inin. --- sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index ce8db99d4e2f1..cf24eba128012 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1296,7 +1296,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { new java.sql.Timestamp(100000)) } - test("SPARK-19896: cannot have circular references in in case class") { + test("SPARK-19896: cannot have circular references in case class") { val errMsg1 = intercept[UnsupportedOperationException] { Seq(CircularReferenceClassA(null)).toDS } From ebe9e28488a30b40f85f79c69be69185a1c4e4f5 Mon Sep 17 00:00:00 2001 From: Huangweizhe Date: Wed, 18 Jul 2018 09:45:56 -0500 Subject: [PATCH 1153/2461] [SPARK-24628][DOC] Typos of the example code in docs/mllib-data-types.md ## What changes were proposed in this pull request? The example wants to create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)), but the list is given as [1, 2, 3, 4, 5, 6]. Now it is changed as [1, 3, 5, 2, 4, 6]. And the example wants to create an RDD of coordinate entries like: entries = sc.parallelize([(0, 0, 1.2), (1, 0, 2.1), (2, 1, 3.7)]). However, it is done with the MatrixEntry class like: entries = sc.parallelize([MatrixEntry(0, 0, 1.2), MatrixEntry(1, 0, 2.1), MatrixEntry(6, 1, 3.7)]), where the third MatrixEntry has a different row index. Now it is changed as MatrixEntry(2, 1, 3.7). ## How was this patch tested? This is trivial enough that it should not affect tests. Author: Weizhe Huang Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Huangweizhe Closes #21612 from huangweizhe123/my_change. --- docs/mllib-data-types.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 5066bb29387dc..eca101132d2e5 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -317,7 +317,7 @@ Refer to the [`Matrix` Python docs](api/python/pyspark.mllib.html#pyspark.mllib. from pyspark.mllib.linalg import Matrix, Matrices # Create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)) -dm2 = Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6]) +dm2 = Matrices.dense(3, 2, [1, 3, 5, 2, 4, 6]) # Create a sparse matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)) sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 2, 1], [9, 6, 8]) @@ -624,7 +624,7 @@ from pyspark.mllib.linalg.distributed import CoordinateMatrix, MatrixEntry # Create an RDD of coordinate entries. # - This can be done explicitly with the MatrixEntry class: -entries = sc.parallelize([MatrixEntry(0, 0, 1.2), MatrixEntry(1, 0, 2.1), MatrixEntry(6, 1, 3.7)]) +entries = sc.parallelize([MatrixEntry(0, 0, 1.2), MatrixEntry(1, 0, 2.1), MatrixEntry(2, 1, 3.7)]) # - or using (long, long, float) tuples: entries = sc.parallelize([(0, 0, 1.2), (1, 0, 2.1), (2, 1, 3.7)]) From fc0c8c97173e641b50eb7cea80c63262d5ba4180 Mon Sep 17 00:00:00 2001 From: mcheah Date: Wed, 18 Jul 2018 10:01:39 -0700 Subject: [PATCH 1154/2461] [SPARK-24825][K8S][TEST] Kubernetes integration tests build the whole reactor ## What changes were proposed in this pull request? Make the integration test script build all modules. In order to not run all the non-Kubernetes integration tests in the build, support specifying tags and tag all integration tests specifically with "k8s". Supply the k8s tag in the dev/dev-run-integration-tests.sh script. ## How was this patch tested? The build system will test this. Author: mcheah Closes #21800 from mccheah/k8s-integration-tests-maven-fix. --- pom.xml | 3 +++ .../dev/dev-run-integration-tests.sh | 21 ++++++---------- .../k8s/integrationtest/KubernetesSuite.scala | 25 ++++++++++--------- 3 files changed, 23 insertions(+), 26 deletions(-) diff --git a/pom.xml b/pom.xml index 649221d526086..81a53eee14f29 100644 --- a/pom.xml +++ b/pom.xml @@ -194,6 +194,7 @@ ${java.home} + org.spark_project @@ -2162,6 +2163,7 @@ false ${test.exclude.tags} + ${test.include.tags} @@ -2209,6 +2211,7 @@ __not_used__ ${test.exclude.tags} + ${test.include.tags} diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh index 3acd0f5cd3349..b28b8b82ca016 100755 --- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -16,9 +16,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # -TEST_ROOT_DIR=$(git rev-parse --show-toplevel)/resource-managers/kubernetes/integration-tests - -cd "${TEST_ROOT_DIR}" +set -xo errexit +TEST_ROOT_DIR=$(git rev-parse --show-toplevel) DEPLOY_MODE="minikube" IMAGE_REPO="docker.io/kubespark" @@ -27,7 +26,7 @@ IMAGE_TAG="N/A" SPARK_MASTER= NAMESPACE= SERVICE_ACCOUNT= -INCLUDE_TAGS= +INCLUDE_TAGS="k8s" EXCLUDE_TAGS= # Parse arguments @@ -62,7 +61,7 @@ while (( "$#" )); do shift ;; --include-tags) - INCLUDE_TAGS="$2" + INCLUDE_TAGS="k8s,$2" shift ;; --exclude-tags) @@ -76,13 +75,12 @@ while (( "$#" )); do shift done -cd $TEST_ROOT_DIR - properties=( -Dspark.kubernetes.test.sparkTgz=$SPARK_TGZ \ -Dspark.kubernetes.test.imageTag=$IMAGE_TAG \ -Dspark.kubernetes.test.imageRepo=$IMAGE_REPO \ - -Dspark.kubernetes.test.deployMode=$DEPLOY_MODE + -Dspark.kubernetes.test.deployMode=$DEPLOY_MODE \ + -Dtest.include.tags=$INCLUDE_TAGS ) if [ -n $NAMESPACE ]; @@ -105,9 +103,4 @@ then properties=( ${properties[@]} -Dtest.exclude.tags=$EXCLUDE_TAGS ) fi -if [ -n $INCLUDE_TAGS ]; -then - properties=( ${properties[@]} -Dtest.include.tags=$INCLUDE_TAGS ) -fi - -../../../build/mvn integration-test ${properties[@]} +$TEST_ROOT_DIR/build/mvn integration-test -f $TEST_ROOT_DIR/pom.xml -pl resource-managers/kubernetes/integration-tests -am -Pkubernetes -Phadoop-2.7 ${properties[@]} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 774c3936b877c..daabfaaac8c7e 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -23,7 +23,7 @@ import java.util.regex.Pattern import com.google.common.io.PatternFilenameFilter import io.fabric8.kubernetes.api.model.Pod -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, Tag} import org.scalatest.concurrent.{Eventually, PatienceConfiguration} import org.scalatest.time.{Minutes, Seconds, Span} import scala.collection.JavaConverters._ @@ -47,6 +47,7 @@ private[spark] class KubernetesSuite extends SparkFunSuite private var containerLocalSparkDistroExamplesJar: String = _ private var appLocator: String = _ private var driverPodName: String = _ + private val k8sTestTag = Tag("k8s") override def beforeAll(): Unit = { // The scalatest-maven-plugin gives system properties that are referenced but not set null @@ -102,22 +103,22 @@ private[spark] class KubernetesSuite extends SparkFunSuite deleteDriverPod() } - test("Run SparkPi with no resources") { + test("Run SparkPi with no resources", k8sTestTag) { runSparkPiAndVerifyCompletion() } - test("Run SparkPi with a very long application name.") { + test("Run SparkPi with a very long application name.", k8sTestTag) { sparkAppConf.set("spark.app.name", "long" * 40) runSparkPiAndVerifyCompletion() } - test("Use SparkLauncher.NO_RESOURCE") { + test("Use SparkLauncher.NO_RESOURCE", k8sTestTag) { sparkAppConf.setJars(Seq(containerLocalSparkDistroExamplesJar)) runSparkPiAndVerifyCompletion( appResource = SparkLauncher.NO_RESOURCE) } - test("Run SparkPi with a master URL without a scheme.") { + test("Run SparkPi with a master URL without a scheme.", k8sTestTag) { val url = kubernetesTestComponents.kubernetesClient.getMasterUrl val k8sMasterUrl = if (url.getPort < 0) { s"k8s://${url.getHost}" @@ -128,11 +129,11 @@ private[spark] class KubernetesSuite extends SparkFunSuite runSparkPiAndVerifyCompletion() } - test("Run SparkPi with an argument.") { + test("Run SparkPi with an argument.", k8sTestTag) { runSparkPiAndVerifyCompletion(appArgs = Array("5")) } - test("Run SparkPi with custom labels, annotations, and environment variables.") { + test("Run SparkPi with custom labels, annotations, and environment variables.", k8sTestTag) { sparkAppConf .set("spark.kubernetes.driver.label.label1", "label1-value") .set("spark.kubernetes.driver.label.label2", "label2-value") @@ -158,21 +159,21 @@ private[spark] class KubernetesSuite extends SparkFunSuite }) } - test("Run extraJVMOptions check on driver") { + test("Run extraJVMOptions check on driver", k8sTestTag) { sparkAppConf .set("spark.driver.extraJavaOptions", "-Dspark.test.foo=spark.test.bar") runSparkJVMCheckAndVerifyCompletion( expectedJVMValue = Seq("(spark.test.foo,spark.test.bar)")) } - test("Run SparkRemoteFileTest using a remote data file") { + test("Run SparkRemoteFileTest using a remote data file", k8sTestTag) { sparkAppConf .set("spark.files", REMOTE_PAGE_RANK_DATA_FILE) runSparkRemoteCheckAndVerifyCompletion( appArgs = Array(REMOTE_PAGE_RANK_FILE_NAME)) } - test("Run PySpark on simple pi.py example") { + test("Run PySpark on simple pi.py example", k8sTestTag) { sparkAppConf .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") runSparkApplicationAndVerifyCompletion( @@ -186,7 +187,7 @@ private[spark] class KubernetesSuite extends SparkFunSuite isJVM = false) } - test("Run PySpark with Python2 to test a pyfiles example") { + test("Run PySpark with Python2 to test a pyfiles example", k8sTestTag) { sparkAppConf .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") .set("spark.kubernetes.pyspark.pythonversion", "2") @@ -204,7 +205,7 @@ private[spark] class KubernetesSuite extends SparkFunSuite pyFiles = Some(PYSPARK_CONTAINER_TESTS)) } - test("Run PySpark with Python3 to test a pyfiles example") { + test("Run PySpark with Python3 to test a pyfiles example", k8sTestTag) { sparkAppConf .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") .set("spark.kubernetes.pyspark.pythonversion", "3") From c8bee932cb644627c4049b5a07dd8028968572d9 Mon Sep 17 00:00:00 2001 From: sychen Date: Wed, 18 Jul 2018 13:24:41 -0500 Subject: [PATCH 1155/2461] [SPARK-24677][CORE] Avoid NoSuchElementException from MedianHeap ## What changes were proposed in this pull request? When speculation is enabled, TaskSetManager#markPartitionCompleted should write successful task duration to MedianHeap, not just increase tasksSuccessful. Otherwise when TaskSetManager#checkSpeculatableTasks,tasksSuccessful non-zero, but MedianHeap is empty. Then throw an exception successfulTaskDurations.median java.util.NoSuchElementException: MedianHeap is empty. Finally led to stopping SparkContext. ## How was this patch tested? TaskSetManagerSuite.scala unit test:[SPARK-24677] MedianHeap should not be empty when speculation is enabled Author: sychen Closes #21656 from cxzl25/fix_MedianHeap_empty. --- .../spark/scheduler/TaskSchedulerImpl.scala | 7 ++- .../spark/scheduler/TaskSetManager.scala | 7 ++- .../spark/scheduler/TaskSetManagerSuite.scala | 49 +++++++++++++++++++ 3 files changed, 59 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 598b62f85a1fa..56c0bf6c09351 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -697,9 +697,12 @@ private[spark] class TaskSchedulerImpl( * do not also submit those same tasks. That also means that a task completion from an earlier * attempt can lead to the entire stage getting marked as successful. */ - private[scheduler] def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int) = { + private[scheduler] def markPartitionCompletedInAllTaskSets( + stageId: Int, + partitionId: Int, + taskInfo: TaskInfo) = { taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm => - tsm.markPartitionCompleted(partitionId) + tsm.markPartitionCompleted(partitionId, taskInfo) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a18c66596852a..6071605ad7f9d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -758,7 +758,7 @@ private[spark] class TaskSetManager( } // There may be multiple tasksets for this stage -- we let all of them know that the partition // was completed. This may result in some of the tasksets getting completed. - sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId) + sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId, info) // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not // "deserialize" the value when holding a lock to avoid blocking other threads. So we call @@ -769,9 +769,12 @@ private[spark] class TaskSetManager( maybeFinishTaskSet() } - private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = { + private[scheduler] def markPartitionCompleted(partitionId: Int, taskInfo: TaskInfo): Unit = { partitionToIndex.get(partitionId).foreach { index => if (!successful(index)) { + if (speculationEnabled && !isZombie) { + successfulTaskDurations.insert(taskInfo.duration) + } tasksSuccessful += 1 successful(index) = true if (tasksSuccessful == numTasks) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index ca6a7e5db3b17..ae571e5a3583a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -1365,6 +1365,55 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(taskOption4.get.addedJars === addedJarsMidTaskSet) } + test("[SPARK-24677] Avoid NoSuchElementException from MedianHeap") { + val conf = new SparkConf().set("spark.speculation", "true") + sc = new SparkContext("local", "test", conf) + // Set the speculation multiplier to be 0 so speculative tasks are launched immediately + sc.conf.set("spark.speculation.multiplier", "0.0") + sc.conf.set("spark.speculation.quantile", "0.1") + sc.conf.set("spark.speculation", "true") + + sched = new FakeTaskScheduler(sc) + sched.initialize(new FakeSchedulerBackend()) + + val dagScheduler = new FakeDAGScheduler(sc, sched) + sched.setDAGScheduler(dagScheduler) + + val taskSet1 = FakeTask.createTaskSet(10) + val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet1.tasks.map { task => + task.metrics.internalAccums + } + + sched.submitTasks(taskSet1) + sched.resourceOffers( + (0 until 10).map { idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }) + + val taskSetManager1 = sched.taskSetManagerForAttempt(0, 0).get + + // fail fetch + taskSetManager1.handleFailedTask( + taskSetManager1.taskAttempts.head.head.taskId, TaskState.FAILED, + FetchFailed(null, 0, 0, 0, "fetch failed")) + + assert(taskSetManager1.isZombie) + assert(taskSetManager1.runningTasks === 9) + + val taskSet2 = FakeTask.createTaskSet(10, stageAttemptId = 1) + sched.submitTasks(taskSet2) + sched.resourceOffers( + (11 until 20).map { idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }) + + // Complete the 2 tasks and leave 8 task in running + for (id <- Set(0, 1)) { + taskSetManager1.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id))) + assert(sched.endedTasks(id) === Success) + } + + val taskSetManager2 = sched.taskSetManagerForAttempt(0, 1).get + assert(!taskSetManager2.successfulTaskDurations.isEmpty()) + taskSetManager2.checkSpeculatableTasks(0) + } + private def createTaskResult( id: Int, accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = { From 1272b2034d4eed4bfe60a49e1065871b3a3f96e0 Mon Sep 17 00:00:00 2001 From: pgandhi Date: Wed, 18 Jul 2018 14:07:03 -0500 Subject: [PATCH 1156/2461] =?UTF-8?q?[SPARK-22151]=20PYTHONPATH=20not=20pi?= =?UTF-8?q?cked=20up=20from=20the=20spark.yarn.appMaste=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …rEnv properly Running in yarn cluster mode and trying to set pythonpath via spark.yarn.appMasterEnv.PYTHONPATH doesn't work. the yarn Client code looks at the env variables: val pythonPathStr = (sys.env.get("PYTHONPATH") ++ pythonPath) But when you set spark.yarn.appMasterEnv it puts it into the local env. So the python path set in spark.yarn.appMasterEnv isn't properly set. You can work around if you are running in cluster mode by setting it on the client like: PYTHONPATH=./addon/python/ spark-submit ## What changes were proposed in this pull request? In Client.scala, PYTHONPATH was being overridden, so changed code to append values to PYTHONPATH instead of overriding them. ## How was this patch tested? Added log statements to ApplicationMaster.scala to check for environment variable PYTHONPATH, ran a spark job in cluster mode before the change and verified the issue. Performed the same test after the change and verified the fix. Author: pgandhi Closes #21468 from pgandhi999/SPARK-22151. --- .../main/scala/org/apache/spark/deploy/yarn/Client.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 793d012218490..ed9879c06968d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -811,10 +811,12 @@ private[spark] class Client( // Finally, update the Spark config to propagate PYTHONPATH to the AM and executors. if (pythonPath.nonEmpty) { - val pythonPathStr = (sys.env.get("PYTHONPATH") ++ pythonPath) + val pythonPathList = (sys.env.get("PYTHONPATH") ++ pythonPath) + env("PYTHONPATH") = (env.get("PYTHONPATH") ++ pythonPathList) .mkString(ApplicationConstants.CLASS_PATH_SEPARATOR) - env("PYTHONPATH") = pythonPathStr - sparkConf.setExecutorEnv("PYTHONPATH", pythonPathStr) + val pythonPathExecutorEnv = (sparkConf.getExecutorEnv.toMap.get("PYTHONPATH") ++ + pythonPathList).mkString(ApplicationConstants.CLASS_PATH_SEPARATOR) + sparkConf.setExecutorEnv("PYTHONPATH", pythonPathExecutorEnv) } if (isClusterMode) { From cd203e0dfc0758a2a90297e8c74c22a1212db846 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Wed, 18 Jul 2018 13:33:26 -0700 Subject: [PATCH 1157/2461] [SPARK-24163][SPARK-24164][SQL] Support column list as the pivot column in Pivot ## What changes were proposed in this pull request? 1. Extend the Parser to enable parsing a column list as the pivot column. 2. Extend the Parser and the Pivot node to enable parsing complex expressions with aliases as the pivot value. 3. Add type check and constant check in Analyzer for Pivot node. ## How was this patch tested? Add tests in pivot.sql Author: maryannxue Closes #21720 from maryannxue/spark-24164. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 11 +- .../sql/catalyst/analysis/Analyzer.scala | 47 +++- .../sql/catalyst/parser/AstBuilder.scala | 22 +- .../plans/logical/basicLogicalOperators.scala | 2 +- .../test/resources/sql-tests/inputs/pivot.sql | 92 +++++++ .../resources/sql-tests/results/pivot.sql.out | 230 +++++++++++++++--- 6 files changed, 348 insertions(+), 56 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index dc95751bf905c..1b43874af6feb 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -414,7 +414,16 @@ groupingSet ; pivotClause - : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn=identifier IN '(' pivotValues+=constant (',' pivotValues+=constant)* ')' ')' + : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn IN '(' pivotValues+=pivotValue (',' pivotValues+=pivotValue)* ')' ')' + ; + +pivotColumn + : identifiers+=identifier + | '(' identifiers+=identifier (',' identifiers+=identifier)* ')' + ; + +pivotValue + : expression (AS? identifier)? ; lateralView diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 36f14ccdc6989..59c371eb1557b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -509,17 +509,39 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan transform { case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved) || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) - || !p.pivotColumn.resolved => p + || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => // Check all aggregate expressions. aggregates.foreach(checkValidAggregateExpression) + // Check all pivot values are literal and match pivot column data type. + val evalPivotValues = pivotValues.map { value => + val foldable = value match { + case Alias(v, _) => v.foldable + case _ => value.foldable + } + if (!foldable) { + throw new AnalysisException( + s"Literal expressions required for pivot values, found '$value'") + } + if (!Cast.canCast(value.dataType, pivotColumn.dataType)) { + throw new AnalysisException(s"Invalid pivot value '$value': " + + s"value data type ${value.dataType.simpleString} does not match " + + s"pivot column data type ${pivotColumn.dataType.catalogString}") + } + Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) + } // Group-by expressions coming from SQL are implicit and need to be deduced. val groupByExprs = groupByExprsOpt.getOrElse( (child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq) val singleAgg = aggregates.size == 1 - def outputName(value: Literal, aggregate: Expression): String = { - val utf8Value = Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) - val stringValue: String = Option(utf8Value).map(_.toString).getOrElse("null") + def outputName(value: Expression, aggregate: Expression): String = { + val stringValue = value match { + case n: NamedExpression => n.name + case _ => + val utf8Value = + Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) + Option(utf8Value).map(_.toString).getOrElse("null") + } if (singleAgg) { stringValue } else { @@ -534,15 +556,10 @@ class Analyzer( // Since evaluating |pivotValues| if statements for each input row can get slow this is an // alternate plan that instead uses two steps of aggregation. val namedAggExps: Seq[NamedExpression] = aggregates.map(a => Alias(a, a.sql)()) - val namedPivotCol = pivotColumn match { - case n: NamedExpression => n - case _ => Alias(pivotColumn, "__pivot_col")() - } - val bigGroup = groupByExprs :+ namedPivotCol + val bigGroup = groupByExprs ++ pivotColumn.references val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child) - val castPivotValues = pivotValues.map(Cast(_, pivotColumn.dataType).eval(EmptyRow)) val pivotAggs = namedAggExps.map { a => - Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, castPivotValues) + Alias(PivotFirst(pivotColumn, a.toAttribute, evalPivotValues) .toAggregateExpression() , "__pivot_" + a.sql)() } @@ -557,8 +574,12 @@ class Analyzer( Project(groupByExprsAttr ++ pivotOutputs, secondAgg) } else { val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => - def ifExpr(expr: Expression) = { - If(EqualNullSafe(pivotColumn, value), expr, Literal(null)) + def ifExpr(e: Expression) = { + If( + EqualNullSafe( + pivotColumn, + Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone))), + e, Literal(null)) } aggregates.map { aggregate => val filteredAggregate = aggregate.transformDown { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index f398b479dc273..49f578a24aaeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -630,11 +630,29 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val aggregates = Option(ctx.aggregates).toSeq .flatMap(_.namedExpression.asScala) .map(typedVisit[Expression]) - val pivotColumn = UnresolvedAttribute.quoted(ctx.pivotColumn.getText) - val pivotValues = ctx.pivotValues.asScala.map(typedVisit[Expression]).map(Literal.apply) + val pivotColumn = if (ctx.pivotColumn.identifiers.size == 1) { + UnresolvedAttribute.quoted(ctx.pivotColumn.identifier.getText) + } else { + CreateStruct( + ctx.pivotColumn.identifiers.asScala.map( + identifier => UnresolvedAttribute.quoted(identifier.getText))) + } + val pivotValues = ctx.pivotValues.asScala.map(visitPivotValue) Pivot(None, pivotColumn, pivotValues, aggregates, query) } + /** + * Create a Pivot column value with or without an alias. + */ + override def visitPivotValue(ctx: PivotValueContext): Expression = withOrigin(ctx) { + val e = expression(ctx.expression) + if (ctx.identifier != null) { + Alias(e, ctx.identifier.getText)() + } else { + e + } + } + /** * Add a [[Generate]] (Lateral View) to a logical plan. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 3bf32ef7884e5..ea5a9b8ed5542 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -700,7 +700,7 @@ case class GroupingSets( case class Pivot( groupByExprsOpt: Option[Seq[NamedExpression]], pivotColumn: Expression, - pivotValues: Seq[Literal], + pivotValues: Seq[Expression], aggregates: Seq[Expression], child: LogicalPlan) extends UnaryNode { override lazy val resolved = false // Pivot will be replaced after being resolved. diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql index b3d53adfbebe7..a6c8d4854ff38 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql @@ -11,6 +11,11 @@ create temporary view years as select * from values (2013, 2) as years(y, s); +create temporary view yearsWithArray as select * from values + (2012, array(1, 1)), + (2013, array(2, 2)) + as yearsWithArray(y, a); + -- pivot courses SELECT * FROM ( SELECT year, course, earnings FROM courseSales @@ -96,6 +101,15 @@ PIVOT ( FOR y IN (2012, 2013) ); +-- pivot with projection and value aliases +SELECT firstYear_s, secondYear_s, firstYear_a, secondYear_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012 as firstYear, 2013 secondYear) +); + -- pivot years with non-aggregate function SELECT * FROM courseSales PIVOT ( @@ -103,6 +117,15 @@ PIVOT ( FOR year IN (2012, 2013) ); +-- pivot with one of the expressions as non-aggregate function +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), year + FOR course IN ('dotNET', 'Java') +); + -- pivot with unresolvable columns SELECT * FROM ( SELECT course, earnings FROM courseSales @@ -129,3 +152,72 @@ PIVOT ( sum(avg(earnings)) FOR course IN ('dotNET', 'Java') ); + +-- pivot on multiple pivot columns +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN (('dotNET', 2012), ('Java', 2013)) +); + +-- pivot on multiple pivot columns with aliased values +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, s) IN (('dotNET', 2) as c1, ('Java', 1) as c2) +); + +-- pivot on multiple pivot columns with values of wrong data types +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN ('dotNET', 'Java') +); + +-- pivot with unresolvable values +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (s, 2013) +); + +-- pivot with non-literal values +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (course, 2013) +); + +-- pivot on join query with columns of complex data types +SELECT * FROM ( + SELECT course, year, a + FROM courseSales + JOIN yearsWithArray ON year = y +) +PIVOT ( + min(a) + FOR course IN ('dotNET', 'Java') +); + +-- pivot on multiple pivot columns with agg columns of complex data types +SELECT * FROM ( + SELECT course, year, y, a + FROM courseSales + JOIN yearsWithArray ON year = y +) +PIVOT ( + max(a) + FOR (y, course) IN ((2012, 'dotNET'), (2013, 'Java')) +); diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out index 922d8b9f9152c..6bb51b946f960 100644 --- a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 15 +-- Number of queries: 25 -- !query 0 @@ -28,6 +28,17 @@ struct<> -- !query 2 +create temporary view yearsWithArray as select * from values + (2012, array(1, 1)), + (2013, array(2, 2)) + as yearsWithArray(y, a) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 SELECT * FROM ( SELECT year, course, earnings FROM courseSales ) @@ -35,27 +46,27 @@ PIVOT ( sum(earnings) FOR course IN ('dotNET', 'Java') ) --- !query 2 schema +-- !query 3 schema struct --- !query 2 output +-- !query 3 output 2012 15000 20000 2013 48000 30000 --- !query 3 +-- !query 4 SELECT * FROM courseSales PIVOT ( sum(earnings) FOR year IN (2012, 2013) ) --- !query 3 schema +-- !query 4 schema struct --- !query 3 output +-- !query 4 output Java 20000 30000 dotNET 15000 48000 --- !query 4 +-- !query 5 SELECT * FROM ( SELECT year, course, earnings FROM courseSales ) @@ -63,14 +74,14 @@ PIVOT ( sum(earnings), avg(earnings) FOR course IN ('dotNET', 'Java') ) --- !query 4 schema +-- !query 5 schema struct --- !query 4 output +-- !query 5 output 2012 15000 7500.0 20000 20000.0 2013 48000 48000.0 30000 30000.0 --- !query 5 +-- !query 6 SELECT * FROM ( SELECT course, earnings FROM courseSales ) @@ -78,13 +89,13 @@ PIVOT ( sum(earnings) FOR course IN ('dotNET', 'Java') ) --- !query 5 schema +-- !query 6 schema struct --- !query 5 output +-- !query 6 output 63000 50000 --- !query 6 +-- !query 7 SELECT * FROM ( SELECT year, course, earnings FROM courseSales ) @@ -92,13 +103,13 @@ PIVOT ( sum(earnings), min(year) FOR course IN ('dotNET', 'Java') ) --- !query 6 schema +-- !query 7 schema struct --- !query 6 output +-- !query 7 output 63000 2012 50000 2012 --- !query 7 +-- !query 8 SELECT * FROM ( SELECT course, year, earnings, s FROM courseSales @@ -108,16 +119,16 @@ PIVOT ( sum(earnings) FOR s IN (1, 2) ) --- !query 7 schema +-- !query 8 schema struct --- !query 7 output +-- !query 8 output Java 2012 20000 NULL Java 2013 NULL 30000 dotNET 2012 15000 NULL dotNET 2013 NULL 48000 --- !query 8 +-- !query 9 SELECT * FROM ( SELECT course, year, earnings, s FROM courseSales @@ -127,14 +138,14 @@ PIVOT ( sum(earnings), min(s) FOR course IN ('dotNET', 'Java') ) --- !query 8 schema +-- !query 9 schema struct --- !query 8 output +-- !query 9 output 2012 15000 1 20000 1 2013 48000 2 30000 2 --- !query 9 +-- !query 10 SELECT * FROM ( SELECT course, year, earnings, s FROM courseSales @@ -144,14 +155,14 @@ PIVOT ( sum(earnings * s) FOR course IN ('dotNET', 'Java') ) --- !query 9 schema +-- !query 10 schema struct --- !query 9 output +-- !query 10 output 2012 15000 20000 2013 96000 60000 --- !query 10 +-- !query 11 SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM ( SELECT year y, course c, earnings e FROM courseSales ) @@ -159,27 +170,57 @@ PIVOT ( sum(e) s, avg(e) a FOR y IN (2012, 2013) ) --- !query 10 schema +-- !query 11 schema struct<2012_s:bigint,2013_s:bigint,2012_a:double,2013_a:double,c:string> --- !query 10 output +-- !query 11 output 15000 48000 7500.0 48000.0 dotNET 20000 30000 20000.0 30000.0 Java --- !query 11 +-- !query 12 +SELECT firstYear_s, secondYear_s, firstYear_a, secondYear_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012 as firstYear, 2013 secondYear) +) +-- !query 12 schema +struct +-- !query 12 output +15000 48000 7500.0 48000.0 dotNET +20000 30000 20000.0 30000.0 Java + + +-- !query 13 SELECT * FROM courseSales PIVOT ( abs(earnings) FOR year IN (2012, 2013) ) --- !query 11 schema +-- !query 13 schema struct<> --- !query 11 output +-- !query 13 output org.apache.spark.sql.AnalysisException Aggregate expression required for pivot, but 'coursesales.`earnings`' did not appear in any aggregate function.; --- !query 12 +-- !query 14 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), year + FOR course IN ('dotNET', 'Java') +) +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +Aggregate expression required for pivot, but '__auto_generated_subquery_name.`year`' did not appear in any aggregate function.; + + +-- !query 15 SELECT * FROM ( SELECT course, earnings FROM courseSales ) @@ -187,14 +228,14 @@ PIVOT ( sum(earnings) FOR year IN (2012, 2013) ) --- !query 12 schema +-- !query 15 schema struct<> --- !query 12 output +-- !query 15 output org.apache.spark.sql.AnalysisException cannot resolve '`year`' given input columns: [__auto_generated_subquery_name.course, __auto_generated_subquery_name.earnings]; line 4 pos 0 --- !query 13 +-- !query 16 SELECT * FROM ( SELECT year, course, earnings FROM courseSales ) @@ -202,14 +243,14 @@ PIVOT ( ceil(sum(earnings)), avg(earnings) + 1 as a1 FOR course IN ('dotNET', 'Java') ) --- !query 13 schema +-- !query 16 schema struct --- !query 13 output +-- !query 16 output 2012 15000 7501.0 20000 20001.0 2013 48000 48001.0 30000 30001.0 --- !query 14 +-- !query 17 SELECT * FROM ( SELECT year, course, earnings FROM courseSales ) @@ -217,8 +258,119 @@ PIVOT ( sum(avg(earnings)) FOR course IN ('dotNET', 'Java') ) --- !query 14 schema +-- !query 17 schema struct<> --- !query 14 output +-- !query 17 output org.apache.spark.sql.AnalysisException It is not allowed to use an aggregate function in the argument of another aggregate function. Please use the inner aggregate function in a sub-query.; + + +-- !query 18 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN (('dotNET', 2012), ('Java', 2013)) +) +-- !query 18 schema +struct +-- !query 18 output +1 15000 NULL +2 NULL 30000 + + +-- !query 19 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, s) IN (('dotNET', 2) as c1, ('Java', 1) as c2) +) +-- !query 19 schema +struct +-- !query 19 output +2012 NULL 20000 +2013 48000 NULL + + +-- !query 20 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN ('dotNET', 'Java') +) +-- !query 20 schema +struct<> +-- !query 20 output +org.apache.spark.sql.AnalysisException +Invalid pivot value 'dotNET': value data type string does not match pivot column data type struct; + + +-- !query 21 +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (s, 2013) +) +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve '`s`' given input columns: [coursesales.course, coursesales.year, coursesales.earnings]; line 4 pos 15 + + +-- !query 22 +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (course, 2013) +) +-- !query 22 schema +struct<> +-- !query 22 output +org.apache.spark.sql.AnalysisException +Literal expressions required for pivot values, found 'course#x'; + + +-- !query 23 +SELECT * FROM ( + SELECT course, year, a + FROM courseSales + JOIN yearsWithArray ON year = y +) +PIVOT ( + min(a) + FOR course IN ('dotNET', 'Java') +) +-- !query 23 schema +struct,Java:array> +-- !query 23 output +2012 [1,1] [1,1] +2013 [2,2] [2,2] + + +-- !query 24 +SELECT * FROM ( + SELECT course, year, y, a + FROM courseSales + JOIN yearsWithArray ON year = y +) +PIVOT ( + max(a) + FOR (y, course) IN ((2012, 'dotNET'), (2013, 'Java')) +) +-- !query 24 schema +struct,[2013, Java]:array> +-- !query 24 output +2012 [1,1] NULL +2013 NULL [2,2] From d404e54e644ec7c6873234b9fed72e7f1f41a8b1 Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Wed, 18 Jul 2018 16:18:29 -0500 Subject: [PATCH 1158/2461] [SPARK-24129][K8S] Add option to pass --build-arg's to docker-image-tool.sh ## What changes were proposed in this pull request? Adding `-b arg` option to take `--build-arg` parameters to pass into the docker command ## How was this patch tested? I verified by passing proxy details which fails without this change and succeeds with the changes. Author: Devaraj K Closes #21202 from devaraj-kavali/SPARK-24129. --- bin/docker-image-tool.sh | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index a3f1bcffaea57..f36fb43692cf4 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -49,6 +49,7 @@ function build { # Set image build arguments accordingly if this is a source repo and not a distribution archive. IMG_PATH=resource-managers/kubernetes/docker/src/main/dockerfiles BUILD_ARGS=( + ${BUILD_PARAMS} --build-arg img_path=$IMG_PATH --build-arg @@ -57,13 +58,14 @@ function build { else # Not passed as an argument to docker, but used to validate the Spark directory. IMG_PATH="kubernetes/dockerfiles" - BUILD_ARGS=() + BUILD_ARGS=(${BUILD_PARAMS}) fi if [ ! -d "$IMG_PATH" ]; then error "Cannot find docker image. This script must be run from a runnable distribution of Apache Spark." fi local BINDING_BUILD_ARGS=( + ${BUILD_PARAMS} --build-arg base_img=$(image_ref spark) ) @@ -101,6 +103,8 @@ Options: -t tag Tag to apply to the built image, or to identify the image to be pushed. -m Use minikube's Docker daemon. -n Build docker image with --no-cache + -b arg Build arg to build or push the image. For multiple build args, this option needs to + be used separately for each build arg. Using minikube when building images will do so directly into minikube's Docker daemon. There is no need to push the images into minikube in that case, they'll be automatically @@ -130,7 +134,8 @@ TAG= BASEDOCKERFILE= PYDOCKERFILE= NOCACHEARG= -while getopts f:mr:t:n option +BUILD_PARAMS= +while getopts f:mr:t:n:b: option do case "${option}" in @@ -139,6 +144,7 @@ do r) REPO=${OPTARG};; t) TAG=${OPTARG};; n) NOCACHEARG="--no-cache";; + b) BUILD_PARAMS=${BUILD_PARAMS}" --build-arg "${OPTARG};; m) if ! which minikube 1>/dev/null; then error "Cannot find minikube." From 753f11516226d0ddf542f95bd67d145a4e899451 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 18 Jul 2018 18:39:23 -0500 Subject: [PATCH 1159/2461] [SPARK-21261][DOCS][SQL] SQL Regex document fix ## What changes were proposed in this pull request? Fix regexes in spark-sql command examples. This takes over https://github.com/apache/spark/pull/18477 ## How was this patch tested? Existing tests. I verified the existing example doesn't work in spark-sql, but new ones does. Author: Sean Owen Closes #21808 from srowen/SPARK-21261. --- .../spark/sql/catalyst/expressions/regexpExpressions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 7b68bb771faf3..bf0c35fe61018 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -272,7 +272,7 @@ case class StringSplit(str: Expression, pattern: Expression) usage = "_FUNC_(str, regexp, rep) - Replaces all substrings of `str` that match `regexp` with `rep`.", examples = """ Examples: - > SELECT _FUNC_('100-200', '(\d+)', 'num'); + > SELECT _FUNC_('100-200', '(\\d+)', 'num'); num-num """) // scalastyle:on line.size.limit @@ -371,7 +371,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio usage = "_FUNC_(str, regexp[, idx]) - Extracts a group that matches `regexp`.", examples = """ Examples: - > SELECT _FUNC_('100-200', '(\d+)-(\d+)', 1); + > SELECT _FUNC_('100-200', '(\\d+)-(\\d+)', 1); 100 """) case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) From cd5d93c0e4ec4573126c6cdda3362814976d11eb Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 19 Jul 2018 09:16:16 +0800 Subject: [PATCH 1160/2461] [SPARK-24854][SQL] Gathering all Avro options into the AvroOptions class ## What changes were proposed in this pull request? In the PR, I propose to put all `Avro` options in new class `AvroOptions` in the same way as for other datasources `JSON` and `CSV`. ## How was this patch tested? It was tested by `AvroSuite` Author: Maxim Gekk Closes #21810 from MaxGekk/avro-options. --- .../spark/sql/avro/AvroFileFormat.scala | 13 +++-- .../apache/spark/sql/avro/AvroOptions.scala | 48 +++++++++++++++++++ .../org/apache/spark/sql/avro/AvroSuite.scala | 6 ++- 3 files changed, 58 insertions(+), 9 deletions(-) create mode 100644 external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index 9eb206457809c..1d0f40e1ce92a 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -58,6 +58,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { val conf = spark.sparkContext.hadoopConfiguration + val parsedOptions = new AvroOptions(options) // Schema evolution is not supported yet. Here we only pick a single random sample file to // figure out the schema of the whole dataset. @@ -76,7 +77,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { } // User can specify an optional avro json schema. - val avroSchema = options.get(AvroFileFormat.AvroSchema) + val avroSchema = parsedOptions.schema .map(new Schema.Parser().parse) .getOrElse { val in = new FsInput(sampleFile.getPath, conf) @@ -114,10 +115,9 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - val recordName = options.getOrElse("recordName", "topLevelRecord") - val recordNamespace = options.getOrElse("recordNamespace", "") + val parsedOptions = new AvroOptions(options) val outputAvroSchema = SchemaConverters.toAvroType( - dataSchema, nullable = false, recordName, recordNamespace) + dataSchema, nullable = false, parsedOptions.recordName, parsedOptions.recordNamespace) AvroJob.setOutputKeySchema(job, outputAvroSchema) val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec" @@ -160,11 +160,12 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { val broadcastedConf = spark.sparkContext.broadcast(new AvroFileFormat.SerializableConfiguration(hadoopConf)) + val parsedOptions = new AvroOptions(options) (file: PartitionedFile) => { val log = LoggerFactory.getLogger(classOf[AvroFileFormat]) val conf = broadcastedConf.value.value - val userProvidedSchema = options.get(AvroFileFormat.AvroSchema).map(new Schema.Parser().parse) + val userProvidedSchema = parsedOptions.schema.map(new Schema.Parser().parse) // TODO Removes this check once `FileFormat` gets a general file filtering interface method. // Doing input file filtering is improper because we may generate empty tasks that process no @@ -235,8 +236,6 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { private[avro] object AvroFileFormat { val IgnoreFilesWithoutExtensionProperty = "avro.mapred.ignore.inputs.without.extension" - val AvroSchema = "avroSchema" - class SerializableConfiguration(@transient var value: Configuration) extends Serializable with KryoSerializable { @transient private[avro] lazy val log = LoggerFactory.getLogger(getClass) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala new file mode 100644 index 0000000000000..8721eae3481da --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap + +/** + * Options for Avro Reader and Writer stored in case insensitive manner. + */ +class AvroOptions(@transient val parameters: CaseInsensitiveMap[String]) + extends Logging with Serializable { + + def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) + + /** + * Optional schema provided by an user in JSON format. + */ + val schema: Option[String] = parameters.get("avroSchema") + + /** + * Top level record name in write result, which is required in Avro spec. + * See https://avro.apache.org/docs/1.8.2/spec.html#schema_record . + * Default value is "topLevelRecord" + */ + val recordName: String = parameters.getOrElse("recordName", "topLevelRecord") + + /** + * Record namespace in write result. Default value is "". + * See Avro spec for details: https://avro.apache.org/docs/1.8.2/spec.html#schema_record . + */ + val recordNamespace: String = parameters.getOrElse("recordNamespace", "") +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 446b42124ceca..f7e9877b7744b 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -578,7 +578,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { """.stripMargin val result = spark .read - .option(AvroFileFormat.AvroSchema, avroSchema) + .option("avroSchema", avroSchema) .avro(testAvro) .collect() val expected = spark.read.avro(testAvro).select("string").collect() @@ -598,7 +598,9 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { | }] |} """.stripMargin - val result = spark.read.option(AvroFileFormat.AvroSchema, avroSchema) + val result = spark + .read + .option("avroSchema", avroSchema) .avro(testAvro).select("missingField").first assert(result === Row("foo")) } From 1a4fda88685bf62d7d8c639ec2d9762107cd447b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 19 Jul 2018 10:48:10 +0800 Subject: [PATCH 1161/2461] [INFRA] Close stale PR Closes #17422 Closes #17619 Closes #18034 Closes #18229 Closes #18268 Closes #17973 Closes #18125 Closes #18918 Closes #19274 Closes #19456 Closes #19510 Closes #19420 Closes #20090 Closes #20177 Closes #20304 Closes #20319 Closes #20543 Closes #20437 Closes #21261 Closes #21726 Closes #14653 Closes #13143 Closes #17894 Closes #19758 Closes #12951 Closes #17092 Closes #21240 Closes #16910 Closes #12904 Closes #21731 Closes #21095 Added: Closes #19233 Closes #20100 Closes #21453 Closes #21455 Closes #18477 Added: Closes #21812 Closes #21787 Author: hyukjinkwon Closes #21781 from HyukjinKwon/closing-prs. From d05a926e78cf78253e895ca1d5d6f61a0538ee47 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 19 Jul 2018 11:54:41 +0800 Subject: [PATCH 1162/2461] [SPARK-24840][SQL] do not use dummy filter to switch codegen on/of ## What changes were proposed in this pull request? It's a little tricky and fragile to use a dummy filter to switch codegen on/off. For now we should use local/cached relation to switch. In the future when we are able to use a config to turn off codegen, we shall use that. ## How was this patch tested? test only PR. Author: Wenchen Fan Closes #21795 from cloud-fan/follow. --- .../spark/sql/DataFrameFunctionsSuite.scala | 400 ++++++++++-------- .../org/apache/spark/sql/DataFrameSuite.scala | 42 +- 2 files changed, 239 insertions(+), 203 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index a39aef18d27e1..bf04251e655ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -507,8 +507,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("map_entries") { - val dummyFilter = (c: Column) => c.isNotNull || c.isNull - // Primitive-type elements val idf = Seq( Map[Int, Int](1 -> 100, 2 -> 200, 3 -> 300), @@ -521,15 +519,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null) ) - checkAnswer(idf.select(map_entries('m)), iExpected) - checkAnswer(idf.selectExpr("map_entries(m)"), iExpected) - checkAnswer(idf.filter(dummyFilter('m)).select(map_entries('m)), iExpected) - checkAnswer( - spark.range(1).selectExpr("map_entries(map(1, null, 2, null))"), - Seq(Row(Seq(Row(1, null), Row(2, null))))) - checkAnswer( - spark.range(1).filter(dummyFilter('id)).selectExpr("map_entries(map(1, null, 2, null))"), - Seq(Row(Seq(Row(1, null), Row(2, null))))) + def testPrimitiveType(): Unit = { + checkAnswer(idf.select(map_entries('m)), iExpected) + checkAnswer(idf.selectExpr("map_entries(m)"), iExpected) + checkAnswer(idf.selectExpr("map_entries(map(1, null, 2, null))"), + Seq.fill(iExpected.length)(Row(Seq(Row(1, null), Row(2, null))))) + } + + // Test with local relation, the Project will be evaluated without codegen + testPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + idf.cache() + testPrimitiveType() // Non-primitive-type elements val sdf = Seq( @@ -545,9 +546,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null) ) - checkAnswer(sdf.select(map_entries('m)), sExpected) - checkAnswer(sdf.selectExpr("map_entries(m)"), sExpected) - checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected) + def testNonPrimitiveType(): Unit = { + checkAnswer(sdf.select(map_entries('m)), sExpected) + checkAnswer(sdf.selectExpr("map_entries(m)"), sExpected) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + sdf.cache() + testNonPrimitiveType() } test("map_concat function") { @@ -629,9 +637,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("map_from_entries function") { - def dummyFilter(c: Column): Column = c.isNull || c.isNotNull - val oneRowDF = Seq(3215).toDF("i") - // Test cases with primitive-type keys and values val idf = Seq( Seq((1, 10), (2, 20), (3, 10)), @@ -645,18 +650,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Map.empty), Row(null)) - checkAnswer(idf.select(map_from_entries('a)), iExpected) - checkAnswer(idf.selectExpr("map_from_entries(a)"), iExpected) - checkAnswer(idf.filter(dummyFilter('a)).select(map_from_entries('a)), iExpected) - checkAnswer( - oneRowDF.selectExpr("map_from_entries(array(struct(1, null), struct(2, null)))"), - Seq(Row(Map(1 -> null, 2 -> null))) - ) - checkAnswer( - oneRowDF.filter(dummyFilter('i)) - .selectExpr("map_from_entries(array(struct(1, null), struct(2, null)))"), - Seq(Row(Map(1 -> null, 2 -> null))) - ) + def testPrimitiveType(): Unit = { + checkAnswer(idf.select(map_from_entries('a)), iExpected) + checkAnswer(idf.selectExpr("map_from_entries(a)"), iExpected) + checkAnswer(idf.selectExpr("map_from_entries(array(struct(1, null), struct(2, null)))"), + Seq.fill(iExpected.length)(Row(Map(1 -> null, 2 -> null)))) + } + + // Test with local relation, the Project will be evaluated without codegen + testPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + idf.cache() + testPrimitiveType() // Test cases with non-primitive-type keys and values val sdf = Seq( @@ -673,9 +678,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Map.empty), Row(null)) - checkAnswer(sdf.select(map_from_entries('a)), sExpected) - checkAnswer(sdf.selectExpr("map_from_entries(a)"), sExpected) - checkAnswer(sdf.filter(dummyFilter('a)).select(map_from_entries('a)), sExpected) + def testNonPrimitiveType(): Unit = { + checkAnswer(sdf.select(map_from_entries('a)), sExpected) + checkAnswer(sdf.selectExpr("map_from_entries(a)"), sExpected) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + sdf.cache() + testNonPrimitiveType() } test("array contains function") { @@ -890,31 +902,21 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("reverse function") { - val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on - // String test cases val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i") + def testString(): Unit = { + checkAnswer(oneRowDF.select(reverse('s)), Seq(Row("krapS"))) + checkAnswer(oneRowDF.selectExpr("reverse(s)"), Seq(Row("krapS"))) + checkAnswer(oneRowDF.select(reverse('i)), Seq(Row("5123"))) + checkAnswer(oneRowDF.selectExpr("reverse(i)"), Seq(Row("5123"))) + checkAnswer(oneRowDF.selectExpr("reverse(null)"), Seq(Row(null))) + } - checkAnswer( - oneRowDF.select(reverse('s)), - Seq(Row("krapS")) - ) - checkAnswer( - oneRowDF.selectExpr("reverse(s)"), - Seq(Row("krapS")) - ) - checkAnswer( - oneRowDF.select(reverse('i)), - Seq(Row("5123")) - ) - checkAnswer( - oneRowDF.selectExpr("reverse(i)"), - Seq(Row("5123")) - ) - checkAnswer( - oneRowDF.selectExpr("reverse(null)"), - Seq(Row(null)) - ) + // Test with local relation, the Project will be evaluated without codegen + testString() + // Test with cached relation, the Project will be evaluated with codegen + oneRowDF.cache() + testString() // Array test cases (primitive-type elements) val idf = Seq( @@ -924,26 +926,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { null ).toDF("i") - checkAnswer( - idf.select(reverse('i)), - Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) - ) - checkAnswer( - idf.filter(dummyFilter('i)).select(reverse('i)), - Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) - ) - checkAnswer( - idf.selectExpr("reverse(i)"), - Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) - ) - checkAnswer( - oneRowDF.selectExpr("reverse(array(1, null, 2, null))"), - Seq(Row(Seq(null, 2, null, 1))) - ) - checkAnswer( - oneRowDF.filter(dummyFilter('i)).selectExpr("reverse(array(1, null, 2, null))"), - Seq(Row(Seq(null, 2, null, 1))) - ) + def testArray(): Unit = { + checkAnswer( + idf.select(reverse('i)), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.selectExpr("reverse(i)"), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.selectExpr("reverse(array(1, null, 2, null))"), + Seq.fill(idf.count().toInt)(Row(Seq(null, 2, null, 1))) + ) + } + + // Test with local relation, the Project will be evaluated without codegen + testArray() + // Test with cached relation, the Project will be evaluated with codegen + idf.cache() + testArray() // Array test cases (non-primitive-type elements) val sdf = Seq( @@ -953,26 +955,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { null ).toDF("s") - checkAnswer( - sdf.select(reverse('s)), - Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) - ) - checkAnswer( - sdf.filter(dummyFilter('s)).select(reverse('s)), - Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) - ) - checkAnswer( - sdf.selectExpr("reverse(s)"), - Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) - ) - checkAnswer( - oneRowDF.selectExpr("reverse(array(array(1, 2), array(3, 4)))"), - Seq(Row(Seq(Seq(3, 4), Seq(1, 2)))) - ) - checkAnswer( - oneRowDF.filter(dummyFilter('s)).selectExpr("reverse(array(array(1, 2), array(3, 4)))"), - Seq(Row(Seq(Seq(3, 4), Seq(1, 2)))) - ) + def testArrayOfNonPrimitiveType(): Unit = { + checkAnswer( + sdf.select(reverse('s)), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + sdf.selectExpr("reverse(s)"), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + sdf.selectExpr("reverse(array(array(1, 2), array(3, 4)))"), + Seq.fill(sdf.count().toInt)(Row(Seq(Seq(3, 4), Seq(1, 2)))) + ) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + sdf.cache() + testArrayOfNonPrimitiveType() // Error test cases intercept[AnalysisException] { @@ -1147,65 +1149,66 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { val nseqi : Seq[Int] = null val nseqs : Seq[String] = null val df = Seq( - (Seq(1), Seq(2, 3), Seq(5L, 6L), nseqi, Seq("a", "b", "c"), Seq("d", "e"), Seq("f"), nseqs), (Seq(1, 0), Seq.empty[Int], Seq(2L), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs) ).toDF("i1", "i2", "i3", "in", "s1", "s2", "s3", "sn") - val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codeGen on - // Simple test cases - checkAnswer( - df.selectExpr("array(1, 2, 3L)"), - Seq(Row(Seq(1L, 2L, 3L)), Row(Seq(1L, 2L, 3L))) - ) + def simpleTest(): Unit = { + checkAnswer ( + df.select(concat($"i1", $"s1")), + Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a"))) + ) + checkAnswer( + df.select(concat($"i1", $"i2", $"i3")), + Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) + ) + checkAnswer( + df.selectExpr("concat(array(1, null), i2, i3)"), + Seq(Row(Seq(1, null, 2, 3, 5, 6)), Row(Seq(1, null, 2))) + ) + checkAnswer( + df.select(concat($"s1", $"s2", $"s3")), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + checkAnswer( + df.selectExpr("concat(s1, s2, s3)"), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + } - checkAnswer ( - df.select(concat($"i1", $"s1")), - Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a"))) - ) - checkAnswer( - df.select(concat($"i1", $"i2", $"i3")), - Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) - ) - checkAnswer( - df.filter(dummyFilter($"i1")).select(concat($"i1", $"i2", $"i3")), - Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) - ) - checkAnswer( - df.selectExpr("concat(array(1, null), i2, i3)"), - Seq(Row(Seq(1, null, 2, 3, 5, 6)), Row(Seq(1, null, 2))) - ) - checkAnswer( - df.select(concat($"s1", $"s2", $"s3")), - Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) - ) - checkAnswer( - df.selectExpr("concat(s1, s2, s3)"), - Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) - ) - checkAnswer( - df.filter(dummyFilter($"s1"))select(concat($"s1", $"s2", $"s3")), - Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) - ) + // Test with local relation, the Project will be evaluated without codegen + simpleTest() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + simpleTest() // Null test cases - checkAnswer( - df.select(concat($"i1", $"in")), - Seq(Row(null), Row(null)) - ) - checkAnswer( - df.select(concat($"in", $"i1")), - Seq(Row(null), Row(null)) - ) - checkAnswer( - df.select(concat($"s1", $"sn")), - Seq(Row(null), Row(null)) - ) - checkAnswer( - df.select(concat($"sn", $"s1")), - Seq(Row(null), Row(null)) - ) + def nullTest(): Unit = { + checkAnswer( + df.select(concat($"i1", $"in")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"in", $"i1")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"s1", $"sn")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"sn", $"s1")), + Seq(Row(null), Row(null)) + ) + } + + // Test with local relation, the Project will be evaluated without codegen + df.unpersist() + nullTest() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + nullTest() // Type error test cases intercept[AnalysisException] { @@ -1223,9 +1226,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("flatten function") { - val dummyFilter = (c: Column) => c.isNull || c.isNotNull // to switch codeGen on - val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") - // Test cases with a primitive type val intDF = Seq( (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))), @@ -1248,12 +1248,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null), Row(null)) - checkAnswer(intDF.select(flatten($"i")), intDFResult) - checkAnswer(intDF.filter(dummyFilter($"i"))select(flatten($"i")), intDFResult) - checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult) - checkAnswer( - oneRowDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"), - Seq(Row(Seq(1, 2, 3, null, 5, 6, null)))) + def testInt(): Unit = { + checkAnswer(intDF.select(flatten($"i")), intDFResult) + checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult) + } + + // Test with local relation, the Project will be evaluated without codegen + testInt() + // Test with cached relation, the Project will be evaluated with codegen + intDF.cache() + testInt() // Test cases with non-primitive types val strDF = Seq( @@ -1279,14 +1283,36 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null), Row(null)) - checkAnswer(strDF.select(flatten($"s")), strDFResult) - checkAnswer(strDF.filter(dummyFilter($"s")).select(flatten($"s")), strDFResult) - checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult) - checkAnswer( - oneRowDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"), - Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3))))) + def testString(): Unit = { + checkAnswer(strDF.select(flatten($"s")), strDFResult) + checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult) + } + + // Test with local relation, the Project will be evaluated without codegen + testString() + // Test with cached relation, the Project will be evaluated with codegen + strDF.cache() + testString() + + val arrDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") + + def testArray(): Unit = { + checkAnswer( + arrDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"), + Seq(Row(Seq(1, 2, 3, null, 5, 6, null)))) + checkAnswer( + arrDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"), + Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3))))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArray() + // Test with cached relation, the Project will be evaluated with codegen + arrDF.cache() + testArray() // Error test cases + val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") intercept[AnalysisException] { oneRowDF.select(flatten($"arr")) } @@ -1302,7 +1328,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("array_repeat function") { - val dummyFilter = (c: Column) => c.isNull || c.isNotNull // to switch codeGen on val strDF = Seq( ("hi", 2), (null, 2) @@ -1313,12 +1338,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(null, null)) ) - checkAnswer(strDF.select(array_repeat($"a", 2)), strDFTwiceResult) - checkAnswer(strDF.filter(dummyFilter($"a")).select(array_repeat($"a", 2)), strDFTwiceResult) - checkAnswer(strDF.select(array_repeat($"a", $"b")), strDFTwiceResult) - checkAnswer(strDF.filter(dummyFilter($"a")).select(array_repeat($"a", $"b")), strDFTwiceResult) - checkAnswer(strDF.selectExpr("array_repeat(a, 2)"), strDFTwiceResult) - checkAnswer(strDF.selectExpr("array_repeat(a, b)"), strDFTwiceResult) + def testString(): Unit = { + checkAnswer(strDF.select(array_repeat($"a", 2)), strDFTwiceResult) + checkAnswer(strDF.select(array_repeat($"a", $"b")), strDFTwiceResult) + checkAnswer(strDF.selectExpr("array_repeat(a, 2)"), strDFTwiceResult) + checkAnswer(strDF.selectExpr("array_repeat(a, b)"), strDFTwiceResult) + } + + // Test with local relation, the Project will be evaluated without codegen + testString() + // Test with cached relation, the Project will be evaluated with codegen + strDF.cache() + testString() val intDF = { val schema = StructType(Seq( @@ -1336,12 +1367,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(null, null)) ) - checkAnswer(intDF.select(array_repeat($"a", 2)), intDFTwiceResult) - checkAnswer(intDF.filter(dummyFilter($"a")).select(array_repeat($"a", 2)), intDFTwiceResult) - checkAnswer(intDF.select(array_repeat($"a", $"b")), intDFTwiceResult) - checkAnswer(intDF.filter(dummyFilter($"a")).select(array_repeat($"a", $"b")), intDFTwiceResult) - checkAnswer(intDF.selectExpr("array_repeat(a, 2)"), intDFTwiceResult) - checkAnswer(intDF.selectExpr("array_repeat(a, b)"), intDFTwiceResult) + def testInt(): Unit = { + checkAnswer(intDF.select(array_repeat($"a", 2)), intDFTwiceResult) + checkAnswer(intDF.select(array_repeat($"a", $"b")), intDFTwiceResult) + checkAnswer(intDF.selectExpr("array_repeat(a, 2)"), intDFTwiceResult) + checkAnswer(intDF.selectExpr("array_repeat(a, b)"), intDFTwiceResult) + } + + // Test with local relation, the Project will be evaluated without codegen + testInt() + // Test with cached relation, the Project will be evaluated with codegen + intDF.cache() + testInt() val nullCountDF = { val schema = StructType(Seq( @@ -1354,13 +1391,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { spark.createDataFrame(spark.sparkContext.parallelize(data), schema) } - checkAnswer( - nullCountDF.select(array_repeat($"a", $"b")), - Seq( - Row(null), - Row(null) + def testNull(): Unit = { + checkAnswer( + nullCountDF.select(array_repeat($"a", $"b")), + Seq(Row(null), Row(null)) ) - ) + } + + // Test with local relation, the Project will be evaluated without codegen + testNull() + // Test with cached relation, the Project will be evaluated with codegen + nullCountDF.cache() + testNull() // Error test cases val invalidTypeDF = Seq(("hi", "1")).toDF("a", "b") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5babdf6f33b99..9cf8c47fa6cf1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2336,46 +2336,40 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val sourceDF = spark.createDataFrame(rows, schema) - val structWhenDF = sourceDF + def structWhenDF: DataFrame = sourceDF .select(when('cond, struct(lit("a").as("val1"), lit(10).as("val2"))).otherwise('s) as "res") .select('res.getField("val1")) - val arrayWhenDF = sourceDF + def arrayWhenDF: DataFrame = sourceDF .select(when('cond, array(lit("a"), lit("b"))).otherwise('a) as "res") .select('res.getItem(0)) - val mapWhenDF = sourceDF + def mapWhenDF: DataFrame = sourceDF .select(when('cond, map(lit(0), lit("a"))).otherwise('m) as "res") .select('res.getItem(0)) - val structIfDF = sourceDF + def structIfDF: DataFrame = sourceDF .select(expr("if(cond, struct('a' as val1, 10 as val2), s)") as "res") .select('res.getField("val1")) - val arrayIfDF = sourceDF + def arrayIfDF: DataFrame = sourceDF .select(expr("if(cond, array('a', 'b'), a)") as "res") .select('res.getItem(0)) - val mapIfDF = sourceDF + def mapIfDF: DataFrame = sourceDF .select(expr("if(cond, map(0, 'a'), m)") as "res") .select('res.getItem(0)) - def checkResult(df: DataFrame, codegenExpected: Boolean): Unit = { - assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] == codegenExpected) - checkAnswer(df, Seq(Row("a"), Row(null))) + def checkResult(): Unit = { + checkAnswer(structWhenDF, Seq(Row("a"), Row(null))) + checkAnswer(arrayWhenDF, Seq(Row("a"), Row(null))) + checkAnswer(mapWhenDF, Seq(Row("a"), Row(null))) + checkAnswer(structIfDF, Seq(Row("a"), Row(null))) + checkAnswer(arrayIfDF, Seq(Row("a"), Row(null))) + checkAnswer(mapIfDF, Seq(Row("a"), Row(null))) } - // without codegen - checkResult(structWhenDF, false) - checkResult(arrayWhenDF, false) - checkResult(mapWhenDF, false) - checkResult(structIfDF, false) - checkResult(arrayIfDF, false) - checkResult(mapIfDF, false) - - // with codegen - checkResult(structWhenDF.filter('cond.isNotNull), true) - checkResult(arrayWhenDF.filter('cond.isNotNull), true) - checkResult(mapWhenDF.filter('cond.isNotNull), true) - checkResult(structIfDF.filter('cond.isNotNull), true) - checkResult(arrayIfDF.filter('cond.isNotNull), true) - checkResult(mapIfDF.filter('cond.isNotNull), true) + // Test with local relation, the Project will be evaluated without codegen + checkResult() + // Test with cached relation, the Project will be evaluated with codegen + sourceDF.cache() + checkResult() } test("Uuid expressions should produce same results at retries in the same DataFrame") { From 8b7d4f842fdc90b8d1c37080bdd9b5e1d070f5c0 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 19 Jul 2018 00:07:35 -0700 Subject: [PATCH 1163/2461] [SPARK-24717][SS] Split out max retain version of state for memory in HDFSBackedStateStoreProvider ## What changes were proposed in this pull request? This patch proposes breaking down configuration of retaining batch size on state into two pieces: files and in memory (cache). While this patch reuses existing configuration for files, it introduces new configuration, "spark.sql.streaming.maxBatchesToRetainInMemory" to configure max count of batch to retain in memory. ## How was this patch tested? Apply this patch on top of SPARK-24441 (https://github.com/apache/spark/pull/21469), and manually tested in various workloads to ensure overall size of states in memory is around 2x or less of the size of latest version of state, while it was 10x ~ 80x before applying the patch. Author: Jungtaek Lim Closes #21700 from HeartSaVioR/SPARK-24717. --- .../apache/spark/sql/internal/SQLConf.scala | 11 ++ .../state/HDFSBackedStateStoreProvider.scala | 57 +++++-- .../streaming/state/StateStoreConf.scala | 3 + .../streaming/state/StateStoreSuite.scala | 150 ++++++++++++++++-- 4 files changed, 196 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 41fe0c3b60d9e..9239d4ef45d36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -854,6 +854,15 @@ object SQLConf { .intConf .createWithDefault(100) + val MAX_BATCHES_TO_RETAIN_IN_MEMORY = buildConf("spark.sql.streaming.maxBatchesToRetainInMemory") + .internal() + .doc("The maximum number of batches which will be retained in memory to avoid " + + "loading from files. The value adjusts a trade-off between memory usage vs cache miss: " + + "'2' covers both success and direct failure cases, '1' covers only success case, " + + "and '0' covers extreme case - disable cache to maximize memory size of executors.") + .intConf + .createWithDefault(2) + val UNSUPPORTED_OPERATION_CHECK_ENABLED = buildConf("spark.sql.streaming.unsupportedOperationCheck") .internal() @@ -1507,6 +1516,8 @@ class SQLConf extends Serializable with Logging { def minBatchesToRetain: Int = getConf(MIN_BATCHES_TO_RETAIN) + def maxBatchesToRetainInMemory: Int = getConf(MAX_BATCHES_TO_RETAIN_IN_MEMORY) + def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) def parquetFilterPushDownDate: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DATE_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 118c82aa75e68..523acef34ca61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming.state import java.io._ +import java.util import java.util.Locale import scala.collection.JavaConverters._ @@ -203,6 +204,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit this.valueSchema = valueSchema this.storeConf = storeConf this.hadoopConf = hadoopConf + this.numberOfVersionsToRetainInMemory = storeConf.maxVersionsToRetainInMemory fm.mkdirs(baseDir) } @@ -220,7 +222,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } override def close(): Unit = { - loadedMaps.values.foreach(_.clear()) + loadedMaps.values.asScala.foreach(_.clear()) } override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = { @@ -239,8 +241,9 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit @volatile private var valueSchema: StructType = _ @volatile private var storeConf: StateStoreConf = _ @volatile private var hadoopConf: Configuration = _ + @volatile private var numberOfVersionsToRetainInMemory: Int = _ - private lazy val loadedMaps = new mutable.HashMap[Long, MapType] + private lazy val loadedMaps = new util.TreeMap[Long, MapType](Ordering[Long].reverse) private lazy val baseDir = stateStoreId.storeCheckpointLocation() private lazy val fm = CheckpointFileManager.create(baseDir, hadoopConf) private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) @@ -250,7 +253,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit private def commitUpdates(newVersion: Long, map: MapType, output: DataOutputStream): Unit = { synchronized { finalizeDeltaFile(output) - loadedMaps.put(newVersion, map) + putStateIntoStateCacheMap(newVersion, map) } } @@ -260,7 +263,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit */ private[state] def latestIterator(): Iterator[UnsafeRowPair] = synchronized { val versionsInFiles = fetchFiles().map(_.version).toSet - val versionsLoaded = loadedMaps.keySet + val versionsLoaded = loadedMaps.keySet.asScala val allKnownVersions = versionsInFiles ++ versionsLoaded val unsafeRowTuple = new UnsafeRowPair() if (allKnownVersions.nonEmpty) { @@ -270,11 +273,43 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } else Iterator.empty } + /** This method is intended to be only used for unit test(s). DO NOT TOUCH ELEMENTS IN MAP! */ + private[state] def getLoadedMaps(): util.SortedMap[Long, MapType] = synchronized { + // shallow copy as a minimal guard + loadedMaps.clone().asInstanceOf[util.SortedMap[Long, MapType]] + } + + private def putStateIntoStateCacheMap(newVersion: Long, map: MapType): Unit = synchronized { + if (numberOfVersionsToRetainInMemory <= 0) { + if (loadedMaps.size() > 0) loadedMaps.clear() + return + } + + while (loadedMaps.size() > numberOfVersionsToRetainInMemory) { + loadedMaps.remove(loadedMaps.lastKey()) + } + + val size = loadedMaps.size() + if (size == numberOfVersionsToRetainInMemory) { + val versionIdForLastKey = loadedMaps.lastKey() + if (versionIdForLastKey > newVersion) { + // this is the only case which we can avoid putting, because new version will be placed to + // the last key and it should be evicted right away + return + } else if (versionIdForLastKey < newVersion) { + // this case needs removal of the last key before putting new one + loadedMaps.remove(versionIdForLastKey) + } + } + + loadedMaps.put(newVersion, map) + } + /** Load the required version of the map data from the backing files */ private def loadMap(version: Long): MapType = { // Shortcut if the map for this version is already there to avoid a redundant put. - val loadedCurrentVersionMap = synchronized { loadedMaps.get(version) } + val loadedCurrentVersionMap = synchronized { Option(loadedMaps.get(version)) } if (loadedCurrentVersionMap.isDefined) { return loadedCurrentVersionMap.get } @@ -286,7 +321,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit val (result, elapsedMs) = Utils.timeTakenMs { val snapshotCurrentVersionMap = readSnapshotFile(version) if (snapshotCurrentVersionMap.isDefined) { - synchronized { loadedMaps.put(version, snapshotCurrentVersionMap.get) } + synchronized { putStateIntoStateCacheMap(version, snapshotCurrentVersionMap.get) } return snapshotCurrentVersionMap.get } @@ -302,7 +337,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit lastAvailableMap = Some(new MapType) } else { lastAvailableMap = - synchronized { loadedMaps.get(lastAvailableVersion) } + synchronized { Option(loadedMaps.get(lastAvailableVersion)) } .orElse(readSnapshotFile(lastAvailableVersion)) } } @@ -314,7 +349,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit updateFromDeltaFile(deltaVersion, resultMap) } - synchronized { loadedMaps.put(version, resultMap) } + synchronized { putStateIntoStateCacheMap(version, resultMap) } resultMap } @@ -506,7 +541,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit val lastVersion = files.last.version val deltaFilesForLastVersion = filesForVersion(files, lastVersion).filter(_.isSnapshot == false) - synchronized { loadedMaps.get(lastVersion) } match { + synchronized { Option(loadedMaps.get(lastVersion)) } match { case Some(map) => if (deltaFilesForLastVersion.size > storeConf.minDeltasForSnapshot) { val (_, e2) = Utils.timeTakenMs(writeSnapshotFile(lastVersion, map)) @@ -536,10 +571,6 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit val earliestVersionToRetain = files.last.version - storeConf.minVersionsToRetain if (earliestVersionToRetain > 0) { val earliestFileToRetain = filesForVersion(files, earliestVersionToRetain).head - synchronized { - val mapsToRemove = loadedMaps.keys.filter(_ < earliestVersionToRetain).toSeq - mapsToRemove.foreach(loadedMaps.remove) - } val filesToDelete = files.filter(_.version < earliestFileToRetain.version) val (_, e2) = Utils.timeTakenMs { filesToDelete.foreach { f => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index 765ff076cb467..d145082a39b57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -34,6 +34,9 @@ class StateStoreConf(@transient private val sqlConf: SQLConf) /** Minimum versions a State Store implementation should retain to allow rollbacks */ val minVersionsToRetain: Int = sqlConf.minBatchesToRetain + /** Maximum count of versions a State Store implementation should retain in memory */ + val maxVersionsToRetainInMemory: Int = sqlConf.maxBatchesToRetainInMemory + /** * Optional fully qualified name of the subclass of [[StateStoreProvider]] * managing state data. That is, the implementation of the State Store to use. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 73f8705060402..bfeb2b16ff7be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{File, IOException} import java.net.URI +import java.util import java.util.UUID import scala.collection.JavaConverters._ @@ -47,6 +48,7 @@ import org.apache.spark.util.Utils class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] with BeforeAndAfter with PrivateMethodTester { type MapType = mutable.HashMap[UnsafeRow, UnsafeRow] + type ProviderMapType = java.util.concurrent.ConcurrentHashMap[UnsafeRow, UnsafeRow] import StateStoreCoordinatorSuite._ import StateStoreTestsHelper._ @@ -64,21 +66,143 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] require(!StateStore.isMaintenanceRunning) } + def updateVersionTo( + provider: StateStoreProvider, + currentVersion: Int, + targetVersion: Int): Int = { + var newCurrentVersion = currentVersion + for (i <- newCurrentVersion until targetVersion) { + newCurrentVersion = incrementVersion(provider, i) + } + require(newCurrentVersion === targetVersion) + newCurrentVersion + } + + def incrementVersion(provider: StateStoreProvider, currentVersion: Int): Int = { + val store = provider.getStore(currentVersion) + put(store, "a", currentVersion + 1) + store.commit() + currentVersion + 1 + } + + def checkLoadedVersions( + loadedMaps: util.SortedMap[Long, ProviderMapType], + count: Int, + earliestKey: Long, + latestKey: Long): Unit = { + assert(loadedMaps.size() === count) + assert(loadedMaps.firstKey() === earliestKey) + assert(loadedMaps.lastKey() === latestKey) + } + + def checkVersion( + loadedMaps: util.SortedMap[Long, ProviderMapType], + version: Long, + expectedData: Map[String, Int]): Unit = { + + val originValueMap = loadedMaps.get(version).asScala.map { entry => + rowToString(entry._1) -> rowToInt(entry._2) + }.toMap + + assert(originValueMap === expectedData) + } + + test("retaining only two latest versions when MAX_BATCHES_TO_RETAIN_IN_MEMORY set to 2") { + val provider = newStoreProvider(opId = Random.nextInt, partition = 0, + numOfVersToRetainInMemory = 2) + + var currentVersion = 0 + + // commit the ver 1 : cache will have one element + currentVersion = incrementVersion(provider, currentVersion) + assert(getData(provider) === Set("a" -> 1)) + var loadedMaps = provider.getLoadedMaps() + checkLoadedVersions(loadedMaps, count = 1, earliestKey = 1, latestKey = 1) + checkVersion(loadedMaps, 1, Map("a" -> 1)) + + // commit the ver 2 : cache will have two elements + currentVersion = incrementVersion(provider, currentVersion) + assert(getData(provider) === Set("a" -> 2)) + loadedMaps = provider.getLoadedMaps() + checkLoadedVersions(loadedMaps, count = 2, earliestKey = 2, latestKey = 1) + checkVersion(loadedMaps, 2, Map("a" -> 2)) + checkVersion(loadedMaps, 1, Map("a" -> 1)) + + // commit the ver 3 : cache has already two elements and adding ver 3 incurs exceeding cache, + // and ver 3 will be added but ver 1 will be evicted + currentVersion = incrementVersion(provider, currentVersion) + assert(getData(provider) === Set("a" -> 3)) + loadedMaps = provider.getLoadedMaps() + checkLoadedVersions(loadedMaps, count = 2, earliestKey = 3, latestKey = 2) + checkVersion(loadedMaps, 3, Map("a" -> 3)) + checkVersion(loadedMaps, 2, Map("a" -> 2)) + } + + test("failure after committing with MAX_BATCHES_TO_RETAIN_IN_MEMORY set to 1") { + val provider = newStoreProvider(opId = Random.nextInt, partition = 0, + numOfVersToRetainInMemory = 1) + + var currentVersion = 0 + + // commit the ver 1 : cache will have one element + currentVersion = incrementVersion(provider, currentVersion) + assert(getData(provider) === Set("a" -> 1)) + var loadedMaps = provider.getLoadedMaps() + checkLoadedVersions(loadedMaps, count = 1, earliestKey = 1, latestKey = 1) + checkVersion(loadedMaps, 1, Map("a" -> 1)) + + // commit the ver 2 : cache has already one elements and adding ver 2 incurs exceeding cache, + // and ver 2 will be added but ver 1 will be evicted + // this fact ensures cache miss will occur when this partition succeeds commit + // but there's a failure afterwards so have to reprocess previous batch + currentVersion = incrementVersion(provider, currentVersion) + assert(getData(provider) === Set("a" -> 2)) + loadedMaps = provider.getLoadedMaps() + checkLoadedVersions(loadedMaps, count = 1, earliestKey = 2, latestKey = 2) + checkVersion(loadedMaps, 2, Map("a" -> 2)) + + // suppose there has been failure after committing, and it decided to reprocess previous batch + currentVersion = 1 + + // committing to existing version which is committed partially but abandoned globally + val store = provider.getStore(currentVersion) + // negative value to represent reprocessing + put(store, "a", -2) + store.commit() + currentVersion += 1 + + // make sure newly committed version is reflected to the cache (overwritten) + assert(getData(provider) === Set("a" -> -2)) + loadedMaps = provider.getLoadedMaps() + checkLoadedVersions(loadedMaps, count = 1, earliestKey = 2, latestKey = 2) + checkVersion(loadedMaps, 2, Map("a" -> -2)) + } + + test("no cache data with MAX_BATCHES_TO_RETAIN_IN_MEMORY set to 0") { + val provider = newStoreProvider(opId = Random.nextInt, partition = 0, + numOfVersToRetainInMemory = 0) + + var currentVersion = 0 + + // commit the ver 1 : never cached + currentVersion = incrementVersion(provider, currentVersion) + assert(getData(provider) === Set("a" -> 1)) + var loadedMaps = provider.getLoadedMaps() + assert(loadedMaps.size() === 0) + + // commit the ver 2 : never cached + currentVersion = incrementVersion(provider, currentVersion) + assert(getData(provider) === Set("a" -> 2)) + loadedMaps = provider.getLoadedMaps() + assert(loadedMaps.size() === 0) + } + test("snapshotting") { val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5) var currentVersion = 0 - def updateVersionTo(targetVersion: Int): Unit = { - for (i <- currentVersion + 1 to targetVersion) { - val store = provider.getStore(currentVersion) - put(store, "a", i) - store.commit() - currentVersion += 1 - } - require(currentVersion === targetVersion) - } - updateVersionTo(2) + currentVersion = updateVersionTo(provider, currentVersion, 2) require(getData(provider) === Set("a" -> 2)) provider.doMaintenance() // should not generate snapshot files assert(getData(provider) === Set("a" -> 2)) @@ -89,7 +213,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } // After version 6, snapshotting should generate one snapshot file - updateVersionTo(6) + currentVersion = updateVersionTo(provider, currentVersion, 6) require(getData(provider) === Set("a" -> 6), "store not updated correctly") provider.doMaintenance() // should generate snapshot files @@ -104,7 +228,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] "snapshotting messed up the data of the final version") // After version 20, snapshotting should generate newer snapshot files - updateVersionTo(20) + currentVersion = updateVersionTo(provider, currentVersion, 20) require(getData(provider) === Set("a" -> 20), "store not updated correctly") provider.doMaintenance() // do snapshot @@ -535,9 +659,11 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] partition: Int, dir: String = newDir(), minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get, + numOfVersToRetainInMemory: Int = SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get, hadoopConf: Configuration = new Configuration): HDFSBackedStateStoreProvider = { val sqlConf = new SQLConf() sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot) + sqlConf.setConf(SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY, numOfVersToRetainInMemory) sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) val provider = new HDFSBackedStateStoreProvider() provider.init( From 6a9a058e09abb1b629680a546c3d6358b49f723a Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 19 Jul 2018 22:24:53 +0800 Subject: [PATCH 1164/2461] [SPARK-24858][SQL] Avoid unnecessary parquet footer reads ## What changes were proposed in this pull request? Currently the same Parquet footer is read twice in the function `buildReaderWithPartitionValues` of ParquetFileFormat if filter push down is enabled. Fix it with simple changes. ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21814 from gengliangwang/parquetFooter. --- .../datasources/parquet/ParquetFileFormat.scala | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 295960b1c2d30..2d4ac7686d4c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -364,10 +364,11 @@ class ParquetFileFormat val sharedConf = broadcastedHadoopConf.value.value + lazy val footerFileMetaData = + ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS).getFileMetaData // Try to push down filters when filter push-down is enabled. val pushed = if (enableParquetFilterPushDown) { - val parquetSchema = ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS) - .getFileMetaData.getSchema + val parquetSchema = footerFileMetaData.getSchema val parquetFilters = new ParquetFilters(pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStringStartWith, pushDownInFilterThreshold) filters @@ -384,12 +385,12 @@ class ParquetFileFormat // *only* if the file was created by something other than "parquet-mr", so check the actual // writer here for this file. We have to do this per-file, as each file in the table may // have different writers. - def isCreatedByParquetMr(): Boolean = { - val footer = ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS) - footer.getFileMetaData().getCreatedBy().startsWith("parquet-mr") - } + // Define isCreatedByParquetMr as function to avoid unnecessary parquet footer reads. + def isCreatedByParquetMr: Boolean = + footerFileMetaData.getCreatedBy().startsWith("parquet-mr") + val convertTz = - if (timestampConversion && !isCreatedByParquetMr()) { + if (timestampConversion && !isCreatedByParquetMr) { Some(DateTimeUtils.getTimeZone(sharedConf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key))) } else { None From 8d707b06003bc97d06630b22e6ae7c35f99b3cdd Mon Sep 17 00:00:00 2001 From: Hieu Huynh <“Hieu.huynh@oath.com”> Date: Thu, 19 Jul 2018 09:52:07 -0500 Subject: [PATCH 1165/2461] [SPARK-24755][CORE] Executor loss can cause task to not be resubmitted MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Description** As described in [SPARK-24755](https://issues.apache.org/jira/browse/SPARK-24755), when speculation is enabled, there is scenario that executor loss can cause task to not be resubmitted. This patch changes the variable killedByOtherAttempt to keeps track of the taskId of tasks that are killed by other attempt. By doing this, we can still prevent resubmitting task killed by other attempt while resubmit successful attempt when executor lost. **How was this patch tested?** A UT is added based on the UT written by xuanyuanking with modification to simulate the scenario described in SPARK-24755. Author: Hieu Huynh <“Hieu.huynh@oath.com”> Closes #21729 from hthuynh2/SPARK_24755. --- .../spark/scheduler/TaskSetManager.scala | 10 +- .../spark/scheduler/TaskSetManagerSuite.scala | 112 ++++++++++++++++++ 2 files changed, 117 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 6071605ad7f9d..defed1e0f9c6c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -84,10 +84,10 @@ private[spark] class TaskSetManager( val successful = new Array[Boolean](numTasks) private val numFailures = new Array[Int](numTasks) - // Set the coresponding index of Boolean var when the task killed by other attempt tasks, - // this happened while we set the `spark.speculation` to true. The task killed by others + // Add the tid of task into this HashSet when the task is killed by other attempt tasks. + // This happened while we set the `spark.speculation` to true. The task killed by others // should not resubmit while executor lost. - private val killedByOtherAttempt: Array[Boolean] = new Array[Boolean](numTasks) + private val killedByOtherAttempt = new HashSet[Long] val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) private[scheduler] var tasksSuccessful = 0 @@ -735,7 +735,7 @@ private[spark] class TaskSetManager( logInfo(s"Killing attempt ${attemptInfo.attemptNumber} for task ${attemptInfo.id} " + s"in stage ${taskSet.id} (TID ${attemptInfo.taskId}) on ${attemptInfo.host} " + s"as the attempt ${info.attemptNumber} succeeded on ${info.host}") - killedByOtherAttempt(index) = true + killedByOtherAttempt += attemptInfo.taskId sched.backend.killTask( attemptInfo.taskId, attemptInfo.executorId, @@ -947,7 +947,7 @@ private[spark] class TaskSetManager( && !isZombie) { for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index - if (successful(index) && !killedByOtherAttempt(index)) { + if (successful(index) && !killedByOtherAttempt.contains(tid)) { successful(index) = false copiesRunning(index) -= 1 tasksSuccessful -= 1 diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index ae571e5a3583a..206b9f47eed4f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -1414,6 +1414,118 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg taskSetManager2.checkSpeculatableTasks(0) } + + test("SPARK-24755 Executor loss can cause task to not be resubmitted") { + val conf = new SparkConf().set("spark.speculation", "true") + sc = new SparkContext("local", "test", conf) + // Set the speculation multiplier to be 0 so speculative tasks are launched immediately + sc.conf.set("spark.speculation.multiplier", "0.0") + + sc.conf.set("spark.speculation.quantile", "0.5") + sc.conf.set("spark.speculation", "true") + + var killTaskCalled = false + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), + ("exec2", "host2"), ("exec3", "host3")) + sched.initialize(new FakeSchedulerBackend() { + override def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = { + // Check the only one killTask event in this case, which triggered by + // task 2.1 completed. + assert(taskId === 2) + assert(executorId === "exec3") + assert(interruptThread) + assert(reason === "another attempt succeeded") + killTaskCalled = true + } + }) + + // Keep track of the index of tasks that are resubmitted, + // so that the test can check that task is resubmitted correctly + var resubmittedTasks = new mutable.HashSet[Int] + val dagScheduler = new FakeDAGScheduler(sc, sched) { + override def taskEnded( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: Seq[AccumulatorV2[_, _]], + taskInfo: TaskInfo): Unit = { + super.taskEnded(task, reason, result, accumUpdates, taskInfo) + reason match { + case Resubmitted => resubmittedTasks += taskInfo.index + case _ => + } + } + } + sched.dagScheduler.stop() + sched.setDAGScheduler(dagScheduler) + + val taskSet = FakeTask.createShuffleMapTaskSet(4, 0, 0, + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host3", "exec3")), + Seq(TaskLocation("host2", "exec2"))) + + val clock = new ManualClock() + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) + val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => + task.metrics.internalAccums + } + // Offer resources for 4 tasks to start + for ((exec, host) <- Seq( + "exec1" -> "host1", + "exec1" -> "host1", + "exec3" -> "host3", + "exec2" -> "host2")) { + val taskOption = manager.resourceOffer(exec, host, NO_PREF) + assert(taskOption.isDefined) + val task = taskOption.get + assert(task.executorId === exec) + // Add an extra assert to make sure task 2.0 is running on exec3 + if (task.index == 2) { + assert(task.attemptNumber === 0) + assert(task.executorId === "exec3") + } + } + assert(sched.startedTasks.toSet === Set(0, 1, 2, 3)) + clock.advance(1) + // Complete the 2 tasks and leave 2 task in running + for (id <- Set(0, 1)) { + manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id))) + assert(sched.endedTasks(id) === Success) + } + + // checkSpeculatableTasks checks that the task runtime is greater than the threshold for + // speculating. Since we use a threshold of 0 for speculation, tasks need to be running for + // > 0ms, so advance the clock by 1ms here. + clock.advance(1) + assert(manager.checkSpeculatableTasks(0)) + assert(sched.speculativeTasks.toSet === Set(2, 3)) + + // Offer resource to start the speculative attempt for the running task 2.0 + val taskOption = manager.resourceOffer("exec2", "host2", ANY) + assert(taskOption.isDefined) + val task4 = taskOption.get + assert(task4.index === 2) + assert(task4.taskId === 4) + assert(task4.executorId === "exec2") + assert(task4.attemptNumber === 1) + // Complete the speculative attempt for the running task + manager.handleSuccessfulTask(4, createTaskResult(2, accumUpdatesByTask(2))) + // Make sure schedBackend.killTask(2, "exec3", true, "another attempt succeeded") gets called + assert(killTaskCalled) + + assert(resubmittedTasks.isEmpty) + // Host 2 Losts, meaning we lost the map output task4 + manager.executorLost("exec2", "host2", SlaveLost()) + // Make sure that task with index 2 is re-submitted + assert(resubmittedTasks.contains(2)) + + } + private def createTaskResult( id: Int, accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = { From b3d88ac02940eff4c867d3acb79fe5ff9d724e83 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 19 Jul 2018 13:17:28 -0700 Subject: [PATCH 1166/2461] [SPARK-22187][SS] Update unsaferow format for saved state in flatMapGroupsWithState to allow timeouts with deleted state ## What changes were proposed in this pull request? Currently, the group state of user-defined-type is encoded as top-level columns in the UnsafeRows stores in the state store. The timeout timestamp is also saved as (when needed) as the last top-level column. Since the group state is serialized to top-level columns, you cannot save "null" as a value of state (setting null in all the top-level columns is not equivalent). So we don't let the user set the timeout without initializing the state for a key. Based on user experience, this leads to confusion. This PR is to change the row format such that the state is saved as nested columns. This would allow the state to be set to null, and avoid these confusing corner cases. However, queries recovering from existing checkpoint will use the previous format to maintain compatibility with existing production queries. ## How was this patch tested? Refactored existing end-to-end tests and added new tests for explicitly testing obj-to-row conversion for both state formats. Author: Tathagata Das Closes #21739 from tdas/SPARK-22187-1. --- .../sql/catalyst/expressions/Expression.scala | 3 +- .../apache/spark/sql/internal/SQLConf.scala | 8 + .../spark/sql/execution/SparkStrategies.scala | 5 +- .../FlatMapGroupsWithStateExec.scala | 136 ++-------- .../sql/execution/streaming/OffsetSeq.scala | 10 +- .../FlatMapGroupsWithStateExecHelper.scala | 247 +++++++++++++++++ .../commits/0 | 2 + .../commits/1 | 2 + .../metadata | 1 + .../offsets/0 | 3 + .../offsets/1 | 3 + .../state/0/0/1.delta | Bin 0 -> 84 bytes .../state/0/0/2.delta | Bin 0 -> 46 bytes .../state/0/1/1.delta | Bin 0 -> 46 bytes .../state/0/1/2.delta | Bin 0 -> 46 bytes .../state/0/2/1.delta | Bin 0 -> 46 bytes .../state/0/2/2.delta | Bin 0 -> 46 bytes .../state/0/3/1.delta | Bin 0 -> 46 bytes .../state/0/3/2.delta | Bin 0 -> 46 bytes .../state/0/4/1.delta | Bin 0 -> 46 bytes .../state/0/4/2.delta | Bin 0 -> 46 bytes ...latMapGroupsWithStateExecHelperSuite.scala | 218 +++++++++++++++ .../FlatMapGroupsWithStateSuite.scala | 250 +++++++++++++----- 23 files changed, 708 insertions(+), 180 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/0 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/1 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/metadata create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/0 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/1 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/2.delta create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index f7d1b105964d5..a69b80428472a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -715,7 +715,8 @@ trait ComplexTypeMergingExpression extends Expression { "The collection of input data types must not be empty.") require( TypeCoercion.haveSameType(inputTypesForMerging), - "All input types must be the same except nullable, containsNull, valueContainsNull flags.") + "All input types must be the same except nullable, containsNull, valueContainsNull flags." + + s" The input types found are\n\t${inputTypesForMerging.mkString("\n\t")}") inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9239d4ef45d36..fbb9a8cfae2e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -843,6 +843,14 @@ object SQLConf { .intConf .createWithDefault(10) + val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION = + buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion") + .internal() + .doc("State format version used by flatMapGroupsWithState operation in a streaming query") + .intConf + .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2") + .createWithDefault(2) + val CHECKPOINT_LOCATION = buildConf("spark.sql.streaming.checkpointLocation") .doc("The default location for storing checkpoint data for streaming queries.") .stringConf diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 02e095b42a506..0c4ea857fd1d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -504,9 +504,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case FlatMapGroupsWithState( func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _, timeout, child) => + val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) val execPlan = FlatMapGroupsWithStateExec( - func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, outputMode, - timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child)) + func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, stateVersion, + outputMode, timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child)) execPlan :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 8e82cccbc8fa3..bfe7d00f56048 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -23,10 +23,8 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Attribut import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} -import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.CompletionIterator /** @@ -52,6 +50,7 @@ case class FlatMapGroupsWithStateExec( outputObjAttr: Attribute, stateInfo: Option[StatefulOperatorStateInfo], stateEncoder: ExpressionEncoder[Any], + stateFormatVersion: Int, outputMode: OutputMode, timeoutConf: GroupStateTimeout, batchTimestampMs: Option[Long], @@ -60,32 +59,15 @@ case class FlatMapGroupsWithStateExec( ) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport { import GroupStateImpl._ + import FlatMapGroupsWithStateExecHelper._ private val isTimeoutEnabled = timeoutConf != NoTimeout - private val timestampTimeoutAttribute = - AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() - private val stateAttributes: Seq[Attribute] = { - val encSchemaAttribs = stateEncoder.schema.toAttributes - if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs - } - // Get the serializer for the state, taking into account whether we need to save timestamps - private val stateSerializer = { - val encoderSerializer = stateEncoder.namedExpressions - if (isTimeoutEnabled) { - encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) - } else { - encoderSerializer - } - } - // Get the deserializer for the state. Note that this must be done in the driver, as - // resolving and binding of deserializer expressions to the encoded type can be safely done - // only in the driver. - private val stateDeserializer = stateEncoder.resolveAndBind().deserializer - private val watermarkPresent = child.output.exists { case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true case _ => false } + private[sql] val stateManager = + createStateManager(stateEncoder, isTimeoutEnabled, stateFormatVersion) /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = @@ -125,11 +107,11 @@ case class FlatMapGroupsWithStateExec( child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, groupingAttributes.toStructType, - stateAttributes.toStructType, + stateManager.stateSchema, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => - val updater = new StateStoreUpdater(store) + val processor = new InputProcessor(store) // If timeout is based on event time, then filter late data based on watermark val filteredIter = watermarkPredicateForData match { @@ -143,7 +125,7 @@ case class FlatMapGroupsWithStateExec( // all the data has been processed. This is to ensure that the timeout information of all // the keys with data is updated before they are processed for timeouts. val outputIterator = - updater.updateStateForKeysWithData(filteredIter) ++ updater.updateStateForTimedOutKeys() + processor.processNewData(filteredIter) ++ processor.processTimedOutState() // Return an iterator of all the rows generated by all the keys, such that when fully // consumed, all the state updates will be committed by the state store @@ -158,7 +140,7 @@ case class FlatMapGroupsWithStateExec( } /** Helper class to update the state store */ - class StateStoreUpdater(store: StateStore) { + class InputProcessor(store: StateStore) { // Converters for translating input keys, values, output data between rows and Java objects private val getKeyObj = @@ -167,14 +149,6 @@ case class FlatMapGroupsWithStateExec( ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) - // Converters for translating state between rows and Java objects - private val getStateObjFromRow = ObjectOperator.deserializeRowToObject( - stateDeserializer, stateAttributes) - private val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) - - // Index of the additional metadata fields in the state row - private val timeoutTimestampIndex = stateAttributes.indexOf(timestampTimeoutAttribute) - // Metrics private val numUpdatedStateRows = longMetric("numUpdatedStateRows") private val numOutputRows = longMetric("numOutputRows") @@ -183,20 +157,19 @@ case class FlatMapGroupsWithStateExec( * For every group, get the key, values and corresponding state and call the function, * and return an iterator of rows */ - def updateStateForKeysWithData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { + def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) groupedIter.flatMap { case (keyRow, valueRowIter) => val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] callFunctionAndUpdateState( - keyUnsafeRow, + stateManager.getState(store, keyUnsafeRow), valueRowIter, - store.get(keyUnsafeRow), hasTimedOut = false) } } /** Find the groups that have timeout set and are timing out right now, and call the function */ - def updateStateForTimedOutKeys(): Iterator[InternalRow] = { + def processTimedOutState(): Iterator[InternalRow] = { if (isTimeoutEnabled) { val timeoutThreshold = timeoutConf match { case ProcessingTimeTimeout => batchTimestampMs.get @@ -205,12 +178,11 @@ case class FlatMapGroupsWithStateExec( throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") } - val timingOutPairs = store.getRange(None, None).filter { rowPair => - val timeoutTimestamp = getTimeoutTimestamp(rowPair.value) - timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold + val timingOutPairs = stateManager.getAllState(store).filter { state => + state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold } - timingOutPairs.flatMap { rowPair => - callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true) + timingOutPairs.flatMap { stateData => + callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true) } } else Iterator.empty } @@ -220,22 +192,19 @@ case class FlatMapGroupsWithStateExec( * iterator. Note that the store updating is lazy, that is, the store will be updated only * after the returned iterator is fully consumed. * - * @param keyRow Row representing the key, cannot be null + * @param stateData All the data related to the state to be updated * @param valueRowIter Iterator of values as rows, cannot be null, but can be empty - * @param prevStateRow Row representing the previous state, can be null * @param hasTimedOut Whether this function is being called for a key timeout */ private def callFunctionAndUpdateState( - keyRow: UnsafeRow, + stateData: StateData, valueRowIter: Iterator[InternalRow], - prevStateRow: UnsafeRow, hasTimedOut: Boolean): Iterator[InternalRow] = { - val keyObj = getKeyObj(keyRow) // convert key to objects + val keyObj = getKeyObj(stateData.keyRow) // convert key to objects val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects - val stateObj = getStateObj(prevStateRow) - val keyedState = GroupStateImpl.createForStreaming( - Option(stateObj), + val groupState = GroupStateImpl.createForStreaming( + Option(stateData.stateObj), batchTimestampMs.getOrElse(NO_TIMESTAMP), eventTimeWatermark.getOrElse(NO_TIMESTAMP), timeoutConf, @@ -243,50 +212,24 @@ case class FlatMapGroupsWithStateExec( watermarkPresent) // Call function, get the returned objects and convert them to rows - val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj => + val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj => numOutputRows += 1 getOutputRow(obj) } // When the iterator is consumed, then write changes to state def onIteratorCompletion: Unit = { - - val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp - // If the state has not yet been set but timeout has been set, then - // we have to generate a row to save the timeout. However, attempting serialize - // null using case class encoder throws - - // java.lang.NullPointerException: Null value appeared in non-nullable field: - // If the schema is inferred from a Scala tuple / case class, or a Java bean, please - // try to use scala.Option[_] or other nullable types. - if (!keyedState.exists && currentTimeoutTimestamp != NO_TIMESTAMP) { - throw new IllegalStateException( - "Cannot set timeout when state is not defined, that is, state has not been" + - "initialized or has been removed") - } - - if (keyedState.hasRemoved) { - store.remove(keyRow) + if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP) { + stateManager.removeState(store, stateData.keyRow) numUpdatedStateRows += 1 - } else { - val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow) - val stateRowToWrite = if (keyedState.hasUpdated) { - getStateRow(keyedState.get) - } else { - prevStateRow - } - - val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp - val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged + val currentTimeoutTimestamp = groupState.getTimeoutTimestamp + val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp + val shouldWriteState = groupState.hasUpdated || groupState.hasRemoved || hasTimeoutChanged if (shouldWriteState) { - if (stateRowToWrite == null) { - // This should never happen because checks in GroupStateImpl should avoid cases - // where empty state would need to be written - throw new IllegalStateException("Attempting to write empty state") - } - setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp) - store.put(keyRow, stateRowToWrite) + val updatedStateObj = if (groupState.exists) groupState.get else null + stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp) numUpdatedStateRows += 1 } } @@ -295,28 +238,5 @@ case class FlatMapGroupsWithStateExec( // Return an iterator of rows such that fully consumed, the updated state value will be saved CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion) } - - /** Returns the state as Java object if defined */ - def getStateObj(stateRow: UnsafeRow): Any = { - if (stateRow != null) getStateObjFromRow(stateRow) else null - } - - /** Returns the row for an updated state */ - def getStateRow(obj: Any): UnsafeRow = { - assert(obj != null) - getStateRowFromObj(obj) - } - - /** Returns the timeout timestamp of a state row is set */ - def getTimeoutTimestamp(stateRow: UnsafeRow): Long = { - if (isTimeoutEnabled && stateRow != null) { - stateRow.getLong(timeoutTimestampIndex) - } else NO_TIMESTAMP - } - - /** Set the timestamp in a state row */ - def setTimeoutTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { - if (isTimeoutEnabled) stateRow.setLong(timeoutTimestampIndex, timeoutTimestamps) - } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 1ae3f36c152cf..9847756f22d4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -22,7 +22,8 @@ import org.json4s.jackson.Serialization import org.apache.spark.internal.Logging import org.apache.spark.sql.RuntimeConfig -import org.apache.spark.sql.internal.SQLConf._ +import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper +import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _} /** * An ordered collection of offsets, used to track the progress of processing data from one or more @@ -87,7 +88,8 @@ case class OffsetSeqMetadata( object OffsetSeqMetadata extends Logging { private implicit val format = Serialization.formats(NoTypeHints) private val relevantSQLConfs = Seq( - SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY) + SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY, + FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) /** * Default values of relevant configurations that are used for backward compatibility. @@ -100,7 +102,9 @@ object OffsetSeqMetadata extends Logging { * with a specific default value for ensuring same behavior of the query as before. */ private val relevantSQLConfDefaultValues = Map[String, String]( - STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME + STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME, + FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> + FlatMapGroupsWithStateExecHelper.legacyVersion.toString ) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala new file mode 100644 index 0000000000000..0a16a3819b778 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.ObjectOperator +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP +import org.apache.spark.sql.types._ + + +object FlatMapGroupsWithStateExecHelper { + + val supportedVersions = Seq(1, 2) + val legacyVersion = 1 + + /** + * Class to capture deserialized state and timestamp return by the state manager. + * This is intended for reuse. + */ + case class StateData( + var keyRow: UnsafeRow = null, + var stateRow: UnsafeRow = null, + var stateObj: Any = null, + var timeoutTimestamp: Long = -1) { + + private[FlatMapGroupsWithStateExecHelper] def withNew( + newKeyRow: UnsafeRow, + newStateRow: UnsafeRow, + newStateObj: Any, + newTimeout: Long): this.type = { + keyRow = newKeyRow + stateRow = newStateRow + stateObj = newStateObj + timeoutTimestamp = newTimeout + this + } + } + + /** Interface for interacting with state data of FlatMapGroupsWithState */ + sealed trait StateManager extends Serializable { + def stateSchema: StructType + def getState(store: StateStore, keyRow: UnsafeRow): StateData + def putState(store: StateStore, keyRow: UnsafeRow, state: Any, timeoutTimestamp: Long): Unit + def removeState(store: StateStore, keyRow: UnsafeRow): Unit + def getAllState(store: StateStore): Iterator[StateData] + } + + def createStateManager( + stateEncoder: ExpressionEncoder[Any], + shouldStoreTimestamp: Boolean, + stateFormatVersion: Int): StateManager = { + stateFormatVersion match { + case 1 => new StateManagerImplV1(stateEncoder, shouldStoreTimestamp) + case 2 => new StateManagerImplV2(stateEncoder, shouldStoreTimestamp) + case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid") + } + } + + // =============================================================================================== + // =========================== Private implementations of StateManager =========================== + // =============================================================================================== + + /** Commmon methods for StateManager implementations */ + private abstract class StateManagerImplBase(shouldStoreTimestamp: Boolean) + extends StateManager { + + protected def stateSerializerExprs: Seq[Expression] + protected def stateDeserializerExpr: Expression + protected def timeoutTimestampOrdinalInRow: Int + + /** Get deserialized state and corresponding timeout timestamp for a key */ + override def getState(store: StateStore, keyRow: UnsafeRow): StateData = { + val stateRow = store.get(keyRow) + stateDataForGets.withNew(keyRow, stateRow, getStateObject(stateRow), getTimestamp(stateRow)) + } + + /** Put state and timeout timestamp for a key */ + override def putState(store: StateStore, key: UnsafeRow, state: Any, timestamp: Long): Unit = { + val stateRow = getStateRow(state) + setTimestamp(stateRow, timestamp) + store.put(key, stateRow) + } + + override def removeState(store: StateStore, keyRow: UnsafeRow): Unit = { + store.remove(keyRow) + } + + override def getAllState(store: StateStore): Iterator[StateData] = { + val stateData = StateData() + store.getRange(None, None).map { p => + stateData.withNew(p.key, p.value, getStateObject(p.value), getTimestamp(p.value)) + } + } + + private lazy val stateSerializerFunc = ObjectOperator.serializeObjectToRow(stateSerializerExprs) + private lazy val stateDeserializerFunc = { + ObjectOperator.deserializeRowToObject(stateDeserializerExpr, stateSchema.toAttributes) + } + private lazy val stateDataForGets = StateData() + + protected def getStateObject(row: UnsafeRow): Any = { + if (row != null) stateDeserializerFunc(row) else null + } + + protected def getStateRow(obj: Any): UnsafeRow = { + stateSerializerFunc(obj) + } + + /** Returns the timeout timestamp of a state row is set */ + private def getTimestamp(stateRow: UnsafeRow): Long = { + if (shouldStoreTimestamp && stateRow != null) { + stateRow.getLong(timeoutTimestampOrdinalInRow) + } else NO_TIMESTAMP + } + + /** Set the timestamp in a state row */ + private def setTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { + if (shouldStoreTimestamp) stateRow.setLong(timeoutTimestampOrdinalInRow, timeoutTimestamps) + } + } + + /** + * Version 1 of the StateManager which stores the user-defined state as flattened columns in + * the UnsafeRow. Say the user-defined state has 3 fields - col1, col2, col3. The + * unsafe rows will look like this. + * + * UnsafeRow[ col1 | col2 | col3 | timestamp ] + * + * The limitation of this format is that timestamp cannot be set when the user-defined + * state has been removed. This is because the columns cannot be collectively marked to be + * empty/null. + */ + private class StateManagerImplV1( + stateEncoder: ExpressionEncoder[Any], + shouldStoreTimestamp: Boolean) extends StateManagerImplBase(shouldStoreTimestamp) { + + private val timestampTimeoutAttribute = + AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() + + private val stateAttributes: Seq[Attribute] = { + val encSchemaAttribs = stateEncoder.schema.toAttributes + if (shouldStoreTimestamp) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs + } + + override val stateSchema: StructType = stateAttributes.toStructType + + override val timeoutTimestampOrdinalInRow: Int = { + stateAttributes.indexOf(timestampTimeoutAttribute) + } + + override val stateSerializerExprs: Seq[Expression] = { + val encoderSerializer = stateEncoder.namedExpressions + if (shouldStoreTimestamp) { + encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) + } else { + encoderSerializer + } + } + + override val stateDeserializerExpr: Expression = { + // Note that this must be done in the driver, as resolving and binding of deserializer + // expressions to the encoded type can be safely done only in the driver. + stateEncoder.resolveAndBind().deserializer + } + + override protected def getStateRow(obj: Any): UnsafeRow = { + require(obj != null, "State object cannot be null") + super.getStateRow(obj) + } + } + + /** + * Version 2 of the StateManager which stores the user-defined state as a nested struct + * in the UnsafeRow. Say the user-defined state has 3 fields - col1, col2, col3. The + * unsafe rows will look like this. + * ___________________________ + * | | + * | V + * UnsafeRow[ nested-struct | timestamp | UnsafeRow[ col1 | col2 | col3 ] ] + * + * This allows the entire user-defined state to be collectively marked as empty/null, + * thus allowing timestamp to be set without requiring the state to be present. + */ + private class StateManagerImplV2( + stateEncoder: ExpressionEncoder[Any], + shouldStoreTimestamp: Boolean) extends StateManagerImplBase(shouldStoreTimestamp) { + + /** Schema of the state rows saved in the state store */ + override val stateSchema: StructType = { + var schema = new StructType().add("groupState", stateEncoder.schema, nullable = true) + if (shouldStoreTimestamp) schema = schema.add("timeoutTimestamp", LongType, nullable = false) + schema + } + + // Ordinals of the information stored in the state row + private val nestedStateOrdinal = 0 + override val timeoutTimestampOrdinalInRow = 1 + + override val stateSerializerExprs: Seq[Expression] = { + val boundRefToSpecificInternalRow = BoundReference( + 0, stateEncoder.serializer.head.collect { case b: BoundReference => b.dataType }.head, true) + + val nestedStateSerExpr = + CreateNamedStruct(stateEncoder.namedExpressions.flatMap(e => Seq(Literal(e.name), e))) + + val nullSafeNestedStateSerExpr = { + val nullLiteral = Literal(null, nestedStateSerExpr.dataType) + CaseWhen(Seq(IsNull(boundRefToSpecificInternalRow) -> nullLiteral), nestedStateSerExpr) + } + + if (shouldStoreTimestamp) { + Seq(nullSafeNestedStateSerExpr, Literal(GroupStateImpl.NO_TIMESTAMP)) + } else { + Seq(nullSafeNestedStateSerExpr) + } + } + + override val stateDeserializerExpr: Expression = { + // Note that this must be done in the driver, as resolving and binding of deserializer + // expressions to the encoded type can be safely done only in the driver. + val boundRefToNestedState = + BoundReference(nestedStateOrdinal, stateEncoder.schema, nullable = true) + val deserExpr = stateEncoder.resolveAndBind().deserializer.transformUp { + case BoundReference(ordinal, _, _) => GetStructField(boundRefToNestedState, ordinal) + } + val nullLiteral = Literal(null, deserExpr.dataType) + CaseWhen(Seq(IsNull(boundRefToNestedState) -> nullLiteral), elseValue = deserExpr) + } + } +} diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/0 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/0 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/1 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/1 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/metadata new file mode 100644 index 0000000000000..372180b2096ee --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/metadata @@ -0,0 +1 @@ +{"id":"04d960cd-d38f-4ce6-b8d0-ebcf84c9dccc"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/0 new file mode 100644 index 0000000000000..807d7b0063b96 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1531292029003,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/1 new file mode 100644 index 0000000000000..cce541073fb4b --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/1 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":5000,"batchTimestampMs":1531292030005,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +1 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..193524ffe15b51c941eb08906f274e7708616f37 GIT binary patch literal 84 zcmeZ?GI7euPtI1=Vqjpf0pdRCZw$deT7rR*VKO6-AppdQ0t_5741)ap3=bF>6#Rf9 QK=2<3e4yGzAwm!m0E+|;=l}o! literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala new file mode 100644 index 0000000000000..dec30fd01f7e2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.streaming.GroupStateImpl._ +import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite._ +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types._ + + +class FlatMapGroupsWithStateExecHelperSuite extends StreamTest { + + import testImplicits._ + import FlatMapGroupsWithStateExecHelper._ + + // ============================ StateManagerImplV1 ============================ + + test(s"StateManager v1 - primitive type - without timestamp") { + val schema = new StructType().add("value", IntegerType, nullable = false) + testStateManagerWithoutTimestamp[Int](version = 1, schema, Seq(0, 10)) + } + + test(s"StateManager v1 - primitive type - with timestamp") { + val schema = new StructType() + .add("value", IntegerType, nullable = false) + .add("timeoutTimestamp", IntegerType, nullable = false) + testStateManagerWithTimestamp[Int](version = 1, schema, Seq(0, 10)) + } + + test(s"StateManager v1 - nested type - without timestamp") { + val schema = StructType(Seq( + StructField("i", IntegerType, nullable = false), + StructField("nested", StructType(Seq( + StructField("d", DoubleType, nullable = false), + StructField("str", StringType)) + )) + )) + + val testValues = Seq( + NestedStruct(1, Struct(1.0, "someString")), + NestedStruct(0, Struct(0.0, "")), + NestedStruct(0, null)) + + testStateManagerWithoutTimestamp[NestedStruct](version = 1, schema, testValues) + + // Verify the limitation of v1 with null state + intercept[Exception] { + testStateManagerWithoutTimestamp[NestedStruct](version = 1, schema, testValues = Seq(null)) + } + } + + test(s"StateManager v1 - nested type - with timestamp") { + val schema = StructType(Seq( + StructField("i", IntegerType, nullable = false), + StructField("nested", StructType(Seq( + StructField("d", DoubleType, nullable = false), + StructField("str", StringType)) + )), + StructField("timeoutTimestamp", IntegerType, nullable = false) + )) + + val testValues = Seq( + NestedStruct(1, Struct(1.0, "someString")), + NestedStruct(0, Struct(0.0, "")), + NestedStruct(0, null)) + + testStateManagerWithTimestamp[NestedStruct](version = 1, schema, testValues) + + // Verify the limitation of v1 with null state + intercept[Exception] { + testStateManagerWithTimestamp[NestedStruct](version = 1, schema, testValues = Seq(null)) + } + } + + // ============================ StateManagerImplV2 ============================ + + test(s"StateManager v2 - primitive type - without timestamp") { + val schema = new StructType() + .add("groupState", new StructType().add("value", IntegerType, nullable = false)) + testStateManagerWithoutTimestamp[Int](version = 2, schema, Seq(0, 10)) + } + + test(s"StateManager v2 - primitive type - with timestamp") { + val schema = new StructType() + .add("groupState", new StructType().add("value", IntegerType, nullable = false)) + .add("timeoutTimestamp", LongType, nullable = false) + testStateManagerWithTimestamp[Int](version = 2, schema, Seq(0, 10)) + } + + test(s"StateManager v2 - nested type - without timestamp") { + val schema = StructType(Seq( + StructField("groupState", StructType(Seq( + StructField("i", IntegerType, nullable = false), + StructField("nested", StructType(Seq( + StructField("d", DoubleType, nullable = false), + StructField("str", StringType) + ))) + ))) + )) + + val testValues = Seq( + NestedStruct(1, Struct(1.0, "someString")), + NestedStruct(0, Struct(0.0, "")), + NestedStruct(0, null), + null) + + testStateManagerWithoutTimestamp[NestedStruct](version = 2, schema, testValues) + } + + test(s"StateManager v2 - nested type - with timestamp") { + val schema = StructType(Seq( + StructField("groupState", StructType(Seq( + StructField("i", IntegerType, nullable = false), + StructField("nested", StructType(Seq( + StructField("d", DoubleType, nullable = false), + StructField("str", StringType) + ))) + ))), + StructField("timeoutTimestamp", LongType, nullable = false) + )) + + val testValues = Seq( + NestedStruct(1, Struct(1.0, "someString")), + NestedStruct(0, Struct(0.0, "")), + NestedStruct(0, null), + null) + + testStateManagerWithTimestamp[NestedStruct](version = 2, schema, testValues) + } + + + def testStateManagerWithoutTimestamp[T: Encoder]( + version: Int, + expectedStateSchema: StructType, + testValues: Seq[T]): Unit = { + val stateManager = newStateManager[T](version, withTimestamp = false) + assert(stateManager.stateSchema === expectedStateSchema) + testStateManager(stateManager, testValues, NO_TIMESTAMP) + } + + def testStateManagerWithTimestamp[T: Encoder]( + version: Int, + expectedStateSchema: StructType, + testValues: Seq[T]): Unit = { + val stateManager = newStateManager[T](version, withTimestamp = true) + assert(stateManager.stateSchema === expectedStateSchema) + for (timestamp <- Seq(NO_TIMESTAMP, 1000)) { + testStateManager(stateManager, testValues, timestamp) + } + } + + private def testStateManager[T: Encoder]( + stateManager: StateManager, + values: Seq[T], + timestamp: Long): Unit = { + val keys = (1 to values.size).map(_ => newKey()) + val store = new MemoryStateStore() + + // Test stateManager.getState(), putState(), removeState() + keys.zip(values).foreach { case (key, value) => + try { + stateManager.putState(store, key, value, timestamp) + val data = stateManager.getState(store, key) + assert(data.stateObj == value) + assert(data.timeoutTimestamp === timestamp) + stateManager.removeState(store, key) + assert(stateManager.getState(store, key).stateObj == null) + } catch { + case e: Throwable => + fail(s"put/get/remove test with '$value' failed", e) + } + } + + // Test stateManager.getAllState() + for (i <- keys.indices) { + stateManager.putState(store, keys(i), values(i), timestamp) + } + val allData = stateManager.getAllState(store).map(_.copy()).toArray + assert(allData.map(_.timeoutTimestamp).toSet == Set(timestamp)) + assert(allData.map(_.stateObj).toSet == values.toSet) + } + + private def newStateManager[T: Encoder](version: Int, withTimestamp: Boolean): StateManager = { + FlatMapGroupsWithStateExecHelper.createStateManager( + implicitly[Encoder[T]].asInstanceOf[ExpressionEncoder[Any]], + withTimestamp, + version) + } + + private val proj = UnsafeProjection.create(Array[DataType](IntegerType)) + private val keyCounter = new AtomicInteger(0) + private def newKey(): UnsafeRow = { + proj.apply(new GenericInternalRow(Array[Any](keyCounter.getAndDecrement()))).copy() + } +} + +case class Struct(d: Double, str: String) +case class NestedStruct(i: Int, nested: Struct) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 988c8e6753e25..82d7755aef5f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.streaming +import java.io.File import java.sql.Date import java.util.concurrent.ConcurrentHashMap +import org.apache.commons.io.FileUtils import org.scalatest.BeforeAndAfterAll import org.scalatest.exceptions.TestFailedException @@ -31,10 +33,12 @@ import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.RDDScanExec -import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream} -import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.util.Utils /** Class to check custom state types */ case class RunningCount(count: Long) @@ -359,13 +363,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } } - // Values used for testing StateStoreUpdater + // Values used for testing InputProcessor val currentBatchTimestamp = 1000 val currentBatchWatermark = 1000 val beforeTimeoutThreshold = 999 val afterTimeoutThreshold = 1001 - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout + // Tests for InputProcessor.processNewData() when timeout = NoTimeout for (priorState <- Seq(None, Some(0))) { val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state" val testName = s"NoTimeout - $priorStateStr - " @@ -396,7 +400,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = None) // should be removed } - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout != NoTimeout + // Tests for InputProcessor.processTimedOutState() when timeout != NoTimeout for (priorState <- Seq(None, Some(0))) { for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { var testName = "" @@ -443,6 +447,18 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = None) // state should be removed } + // Tests with ProcessingTimeTimeout + if (priorState == None) { + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - timeout updated without initializing state", + stateUpdates = state => { state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = currentBatchTimestamp + 5000) + } + testStateUpdateWithData( s"ProcessingTimeTimeout - $testName - state and timeout duration updated", stateUpdates = @@ -453,6 +469,30 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = Some(5), // state should change expectedTimeoutTimestamp = currentBatchTimestamp + 5000) // timestamp should change + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - timeout updated after state removed", + stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = currentBatchTimestamp + 5000) + + // Tests with EventTimeTimeout + + if (priorState == None) { + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout without init state not allowed", + stateUpdates = state => { + state.setTimeoutTimestamp(10000) + }, + timeoutConf = EventTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = 10000) + } + testStateUpdateWithData( s"EventTimeTimeout - $testName - state and timeout timestamp updated", stateUpdates = @@ -477,48 +517,21 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest priorTimeoutTimestamp = priorTimeoutTimestamp, expectedState = Some(5), // state should change expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should not update - } - } - // Currently disallowed cases for StateStoreUpdater.updateStateForKeysWithData(), - // Try to remove these cases in the future - for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { - val testName = - if (priorTimeoutTimestamp != NO_TIMESTAMP) "prior timeout set" else "no prior timeout" - testStateUpdateWithData( - s"ProcessingTimeTimeout - $testName - setting timeout without init state not allowed", - stateUpdates = state => { state.setTimeoutDuration(5000) }, - timeoutConf = ProcessingTimeTimeout, - priorState = None, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - - testStateUpdateWithData( - s"ProcessingTimeTimeout - $testName - setting timeout with state removal not allowed", - stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, - timeoutConf = ProcessingTimeTimeout, - priorState = Some(5), - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - - testStateUpdateWithData( - s"EventTimeTimeout - $testName - setting timeout without init state not allowed", - stateUpdates = state => { state.setTimeoutTimestamp(10000) }, - timeoutConf = EventTimeTimeout, - priorState = None, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - - testStateUpdateWithData( - s"EventTimeTimeout - $testName - setting timeout with state removal not allowed", - stateUpdates = state => { state.remove(); state.setTimeoutTimestamp(10000) }, - timeoutConf = EventTimeTimeout, - priorState = Some(5), - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout with state removal not allowed", + stateUpdates = state => { + state.remove(); state.setTimeoutTimestamp(10000) + }, + timeoutConf = EventTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = 10000) + } } - // Tests for StateStoreUpdater.updateStateForTimedOutKeys() + // Tests for InputProcessor.processTimedOutState() val preTimeoutState = Some(5) for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) { testStateUpdateWithTimeout( @@ -590,7 +603,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState = Some(5), // state should change expectedTimeoutTimestamp = 5000) // timestamp should change - test("flatMapGroupsWithState - streaming") { + testWithAllStateVersions("flatMapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count if state is defined, otherwise does not return anything val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { @@ -669,7 +682,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest ) } - test("flatMapGroupsWithState - streaming + aggregation") { + testWithAllStateVersions("flatMapGroupsWithState - streaming + aggregation") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { @@ -728,7 +741,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest checkAnswer(df, Seq(("a", 2), ("b", 1)).toDF) } - test("flatMapGroupsWithState - streaming with processing time timeout") { + testWithAllStateVersions("flatMapGroupsWithState - streaming with processing time timeout") { // Function to maintain the count as state and set the proc. time timeout delay of 10 seconds. // It returns the count if changed, or -1 if the state was removed by timeout. val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { @@ -792,7 +805,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest ) } - test("flatMapGroupsWithState - streaming with event time timeout + watermark") { + testWithAllStateVersions("flatMapGroupsWithState - streaming w/ event time timeout + watermark") { // Function to maintain the max event time as state and set the timeout timestamp based on the // current max event time seen. It returns the max event time in the state, or -1 if the state // was removed by timeout. @@ -843,6 +856,105 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest ) } + test("flatMapGroupsWithState - uses state format version 2 by default") { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + state.update(RunningCount(count)) + Iterator((key, count.toString)) + } + + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) + + testStream(result, Update)( + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + Execute { query => + // Verify state format = 2 + val f = query.lastExecution.executedPlan.collect { case f: FlatMapGroupsWithStateExec => f } + assert(f.size == 1) + assert(f.head.stateFormatVersion == 2) + } + ) + } + + test("flatMapGroupsWithState - recovery from checkpoint uses state format version 1") { + // Function to maintain the max event time as state and set the timeout timestamp based on the + // current max event time seen. It returns the max event time in the state, or -1 if the state + // was removed by timeout. + val stateFunc = (key: String, values: Iterator[(String, Long)], state: GroupState[Long]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCanGetWatermark { state.getCurrentWatermarkMs() >= -1 } + + val timeoutDelaySec = 5 + if (state.hasTimedOut) { + state.remove() + Iterator((key, -1)) + } else { + val valuesSeq = values.toSeq + val maxEventTimeSec = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L)) + val timeoutTimestampSec = maxEventTimeSec + timeoutDelaySec + state.update(maxEventTimeSec) + state.setTimeoutTimestamp(timeoutTimestampSec * 1000) + Iterator((key, maxEventTimeSec.toInt)) + } + } + val inputData = MemoryStream[(String, Int)] + val result = + inputData.toDS + .select($"_1".as("key"), $"_2".cast("timestamp").as("eventTime")) + .withWatermark("eventTime", "10 seconds") + .as[(String, Long)] + .groupByKey(_._1) + .flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc) + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + inputData.addData(("a", 11), ("a", 13), ("a", 15)) + inputData.addData(("a", 4)) + + testStream(result, Update)( + StartStream( + checkpointLocation = checkpointDir.getAbsolutePath, + additionalConfs = Map(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> "2")), + /* + Note: The checkpoint was generated using the following input in Spark version 2.3.1 + + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), + // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. + CheckNewAnswer(("a", 15)), // Output = max event time of a + + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" + CheckNewAnswer(), // No output as data should get filtered by watermark + */ + + AddData(inputData, ("a", 10)), // Add data newer than watermark for "a" + CheckNewAnswer(("a", 15)), // Max event time is still the same + // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15. + // Watermark is still 5 as max event time for all data is still 15. + + Execute { query => + // Verify state format = 1 + val f = query.lastExecution.executedPlan.collect { case f: FlatMapGroupsWithStateExec => f } + assert(f.size == 1) + assert(f.head.stateFormatVersion == 1) + }, + + AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a" + // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is 20. + CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1 + ) + } + + test("mapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) @@ -1032,7 +1144,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest if (priorState.isEmpty && priorTimeoutTimestamp != NO_TIMESTAMP) { return // there can be no prior timestamp, when there is no prior state } - test(s"StateStoreUpdater - updates with data - $testName") { + test(s"InputProcessor - process new data - $testName") { val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { assert(state.hasTimedOut === false, "hasTimedOut not false") assert(values.nonEmpty, "Some value is expected") @@ -1054,7 +1166,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest expectedState: Option[Int], expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = { - test(s"StateStoreUpdater - updates for timeout - $testName") { + test(s"InputProcessor - process timed out state - $testName") { val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { assert(state.hasTimedOut === true, "hasTimedOut not true") assert(values.isEmpty, "values not empty") @@ -1081,21 +1193,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest val store = newStateStore() val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec( mapGroupsFunc, timeoutConf, currentBatchTimestamp) - val updater = new mapGroupsSparkPlan.StateStoreUpdater(store) + val inputProcessor = new mapGroupsSparkPlan.InputProcessor(store) + val stateManager = mapGroupsSparkPlan.stateManager val key = intToRow(0) // Prepare store with prior state configs - if (priorState.nonEmpty) { - val row = updater.getStateRow(priorState.get) - updater.setTimeoutTimestamp(row, priorTimeoutTimestamp) - store.put(key.copy(), row.copy()) + if (priorState.nonEmpty || priorTimeoutTimestamp != NO_TIMESTAMP) { + stateManager.putState(store, key, priorState.orNull, priorTimeoutTimestamp) } // Call updating function to update state store def callFunction() = { val returnedIter = if (testTimeoutUpdates) { - updater.updateStateForTimedOutKeys() + inputProcessor.processTimedOutState() } else { - updater.updateStateForKeysWithData(Iterator(key)) + inputProcessor.processNewData(Iterator(key)) } returnedIter.size // consume the iterator to force state updates } @@ -1106,15 +1217,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } else { // Call function to update and verify updated state in store callFunction() - val updatedStateRow = store.get(key) - assert( - Option(updater.getStateObj(updatedStateRow)).map(_.toString.toInt) === expectedState, + val updatedState = stateManager.getState(store, key) + assert(Option(updatedState.stateObj).map(_.toString.toInt) === expectedState, "final state not as expected") - if (updatedStateRow != null) { - assert( - updater.getTimeoutTimestamp(updatedStateRow) === expectedTimeoutTimestamp, - "final timeout timestamp not as expected") - } + assert(updatedState.timeoutTimestamp === expectedTimeoutTimestamp, + "final timeout timestamp not as expected") } } @@ -1122,6 +1229,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest func: (Int, Iterator[Int], GroupState[Int]) => Iterator[Int], timeoutType: GroupStateTimeout = GroupStateTimeout.NoTimeout, batchTimestampMs: Long = NO_TIMESTAMP): FlatMapGroupsWithStateExec = { + val stateFormatVersion = spark.conf.get(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) MemoryStream[Int] .toDS .groupByKey(x => x) @@ -1129,7 +1237,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest .logicalPlan.collectFirst { case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) => FlatMapGroupsWithStateExec( - f, k, v, g, d, o, None, s, m, t, + f, k, v, g, d, o, None, s, stateFormatVersion, m, t, Some(currentBatchTimestamp), Some(currentBatchWatermark), RDDScanExec(g, null, "rdd")) }.get } @@ -1162,6 +1270,16 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } def rowToInt(row: UnsafeRow): Int = row.getInt(0) + + def testWithAllStateVersions(name: String)(func: => Unit): Unit = { + for (version <- FlatMapGroupsWithStateExecHelper.supportedVersions) { + test(s"$name - state format version $version") { + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> version.toString) { + func + } + } + } + } } object FlatMapGroupsWithStateSuite { From 67e108daa6f324e7f4f7db2bda980a9945b59396 Mon Sep 17 00:00:00 2001 From: Ger van Rossum Date: Thu, 19 Jul 2018 23:28:16 +0200 Subject: [PATCH 1167/2461] [SPARK-24846][SQL] Made hashCode ExprId independent of jvmId ## What changes were proposed in this pull request? Made ExprId hashCode independent of jvmId to make canonicalization independent of JVM, by overriding hashCode (and necessarily also equality) to depend on id only ## How was this patch tested? Created a unit test ExprIdSuite Ran all unit tests of sql/catalyst Author: Ger van Rossum Closes #21806 from gvr/spark24846-canonicalization. --- .../expressions/namedExpressions.scala | 11 +++- .../catalyst/expressions/ExprIdSuite.scala | 50 +++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExprIdSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 8df870468c2ad..ce5c2804d08ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -40,7 +40,16 @@ object NamedExpression { * * The `id` field is unique within a given JVM, while the `uuid` is used to uniquely identify JVMs. */ -case class ExprId(id: Long, jvmId: UUID) +case class ExprId(id: Long, jvmId: UUID) { + + override def equals(other: Any): Boolean = other match { + case ExprId(id, jvmId) => this.id == id && this.jvmId == jvmId + case _ => false + } + + override def hashCode(): Int = id.hashCode() + +} object ExprId { def apply(id: Long): ExprId = ExprId(id, NamedExpression.jvmId) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExprIdSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExprIdSuite.scala new file mode 100644 index 0000000000000..2352db405b1a8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExprIdSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.util.UUID + +import org.apache.spark.SparkFunSuite + +class ExprIdSuite extends SparkFunSuite { + + private val jvmId = UUID.randomUUID() + private val otherJvmId = UUID.randomUUID() + + test("hashcode independent of jvmId") { + val exprId1 = ExprId(12, jvmId) + val exprId2 = ExprId(12, otherJvmId) + assert(exprId1 != exprId2) + assert(exprId1.hashCode() == exprId2.hashCode()) + } + + test("equality should depend on both id and jvmId") { + val exprId1 = ExprId(1, jvmId) + val exprId2 = ExprId(1, jvmId) + assert(exprId1 == exprId2) + + val exprId3 = ExprId(1, jvmId) + val exprId4 = ExprId(2, jvmId) + assert(exprId3 != exprId4) + + val exprId5 = ExprId(1, jvmId) + val exprId6 = ExprId(1, otherJvmId) + assert(exprId5 != exprId6) + } + +} From 7e847646d1f377f46dc3154dea37148d4e557a03 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 20 Jul 2018 11:16:53 +0800 Subject: [PATCH 1168/2461] [SPARK-24307][CORE] Support reading remote cached partitions > 2gb (1) Netty's ByteBuf cannot support data > 2gb. So to transfer data from a ChunkedByteBuffer over the network, we use a custom version of FileRegion which is backed by the ChunkedByteBuffer. (2) On the receiving end, we need to expose all the data in a FileSegmentManagedBuffer as a ChunkedByteBuffer. We do that by memory mapping the entire file in chunks. Added unit tests. Ran the randomized test a couple of hundred times on my laptop. Tests cover the equivalent of SPARK-24107 for the ChunkedByteBufferFileRegion. Also tested on a cluster with remote cache reads >2gb (in memory and on disk). Author: Imran Rashid Closes #21440 from squito/chunked_bb_file_region. --- .../apache/spark/storage/BlockManager.scala | 11 +- .../spark/util/io/ChunkedByteBuffer.scala | 44 ++++- .../util/io/ChunkedByteBufferFileRegion.scala | 86 ++++++++++ .../io/ChunkedByteBufferFileRegionSuite.scala | 152 ++++++++++++++++++ .../spark/io/ChunkedByteBufferSuite.scala | 2 +- 5 files changed, 286 insertions(+), 9 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferFileRegion.scala create mode 100644 core/src/test/scala/org/apache/spark/io/ChunkedByteBufferFileRegionSuite.scala diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 0e1c7d5fd3fa2..1db032711ce42 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -130,6 +130,8 @@ private[spark] class BlockManager( private[spark] val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) + private val chunkSize = + conf.getSizeAsBytes("spark.storage.memoryMapLimitForTests", Int.MaxValue.toString).toInt val diskBlockManager = { // Only perform cleanup if an external service is not serving our shuffle files. @@ -660,6 +662,11 @@ private[spark] class BlockManager( * Get block from remote block managers as serialized bytes. */ def getRemoteBytes(blockId: BlockId): Option[ChunkedByteBuffer] = { + // TODO if we change this method to return the ManagedBuffer, then getRemoteValues + // could just use the inputStream on the temp file, rather than memory-mapping the file. + // Until then, replication can cause the process to use too much memory and get killed + // by the OS / cluster manager (not a java OOM, since its a memory-mapped file) even though + // we've read the data to disk. logDebug(s"Getting remote block $blockId") require(blockId != null, "BlockId is null") var runningFailureCount = 0 @@ -690,7 +697,7 @@ private[spark] class BlockManager( logDebug(s"Getting remote block $blockId from $loc") val data = try { blockTransferService.fetchBlockSync( - loc.host, loc.port, loc.executorId, blockId.toString, tempFileManager).nioByteBuffer() + loc.host, loc.port, loc.executorId, blockId.toString, tempFileManager) } catch { case NonFatal(e) => runningFailureCount += 1 @@ -724,7 +731,7 @@ private[spark] class BlockManager( } if (data != null) { - return Some(new ChunkedByteBuffer(data)) + return Some(ChunkedByteBuffer.fromManagedBuffer(data, chunkSize)) } logDebug(s"The value of block $blockId is null") } diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 700ce56466c35..efed90cb7678e 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -17,17 +17,21 @@ package org.apache.spark.util.io -import java.io.InputStream +import java.io.{File, FileInputStream, InputStream} import java.nio.ByteBuffer -import java.nio.channels.WritableByteChannel +import java.nio.channels.{FileChannel, WritableByteChannel} +import java.nio.file.StandardOpenOption + +import scala.collection.mutable.ListBuffer import com.google.common.primitives.UnsignedBytes -import io.netty.buffer.{ByteBuf, Unpooled} import org.apache.spark.SparkEnv import org.apache.spark.internal.config +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.util.ByteArrayWritableChannel import org.apache.spark.storage.StorageUtils +import org.apache.spark.util.Utils /** * Read-only byte buffer which is physically stored as multiple chunks rather than a single @@ -81,10 +85,10 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { } /** - * Wrap this buffer to view it as a Netty ByteBuf. + * Wrap this in a custom "FileRegion" which allows us to transfer over 2 GB. */ - def toNetty: ByteBuf = { - Unpooled.wrappedBuffer(chunks.length, getChunks(): _*) + def toNetty: ChunkedByteBufferFileRegion = { + new ChunkedByteBufferFileRegion(this, bufferWriteChunkSize) } /** @@ -166,6 +170,34 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { } +object ChunkedByteBuffer { + // TODO eliminate this method if we switch BlockManager to getting InputStreams + def fromManagedBuffer(data: ManagedBuffer, maxChunkSize: Int): ChunkedByteBuffer = { + data match { + case f: FileSegmentManagedBuffer => + map(f.getFile, maxChunkSize, f.getOffset, f.getLength) + case other => + new ChunkedByteBuffer(other.nioByteBuffer()) + } + } + + def map(file: File, maxChunkSize: Int, offset: Long, length: Long): ChunkedByteBuffer = { + Utils.tryWithResource(FileChannel.open(file.toPath, StandardOpenOption.READ)) { channel => + var remaining = length + var pos = offset + val chunks = new ListBuffer[ByteBuffer]() + while (remaining > 0) { + val chunkSize = math.min(remaining, maxChunkSize) + val chunk = channel.map(FileChannel.MapMode.READ_ONLY, pos, chunkSize) + pos += chunkSize + remaining -= chunkSize + chunks += chunk + } + new ChunkedByteBuffer(chunks.toArray) + } + } +} + /** * Reads data from a ChunkedByteBuffer. * diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferFileRegion.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferFileRegion.scala new file mode 100644 index 0000000000000..9622d0ac05368 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferFileRegion.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util.io + +import java.nio.channels.WritableByteChannel + +import io.netty.channel.FileRegion +import io.netty.util.AbstractReferenceCounted + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.AbstractFileRegion + + +/** + * This exposes a ChunkedByteBuffer as a netty FileRegion, just to allow sending > 2gb in one netty + * message. This is because netty cannot send a ByteBuf > 2g, but it can send a large FileRegion, + * even though the data is not backed by a file. + */ +private[io] class ChunkedByteBufferFileRegion( + private val chunkedByteBuffer: ChunkedByteBuffer, + private val ioChunkSize: Int) extends AbstractFileRegion { + + private var _transferred: Long = 0 + // this duplicates the original chunks, so we're free to modify the position, limit, etc. + private val chunks = chunkedByteBuffer.getChunks() + private val size = chunks.foldLeft(0L) { _ + _.remaining() } + + protected def deallocate: Unit = {} + + override def count(): Long = size + + // this is the "start position" of the overall Data in the backing file, not our current position + override def position(): Long = 0 + + override def transferred(): Long = _transferred + + private var currentChunkIdx = 0 + + def transferTo(target: WritableByteChannel, position: Long): Long = { + assert(position == _transferred) + if (position == size) return 0L + var keepGoing = true + var written = 0L + var currentChunk = chunks(currentChunkIdx) + while (keepGoing) { + while (currentChunk.hasRemaining && keepGoing) { + val ioSize = Math.min(currentChunk.remaining(), ioChunkSize) + val originalLimit = currentChunk.limit() + currentChunk.limit(currentChunk.position() + ioSize) + val thisWriteSize = target.write(currentChunk) + currentChunk.limit(originalLimit) + written += thisWriteSize + if (thisWriteSize < ioSize) { + // the channel did not accept our entire write. We do *not* keep trying -- netty wants + // us to just stop, and report how much we've written. + keepGoing = false + } + } + if (keepGoing) { + // advance to the next chunk (if there are any more) + currentChunkIdx += 1 + if (currentChunkIdx == chunks.size) { + keepGoing = false + } else { + currentChunk = chunks(currentChunkIdx) + } + } + } + _transferred += written + written + } +} diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferFileRegionSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferFileRegionSuite.scala new file mode 100644 index 0000000000000..a6b0654204f34 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferFileRegionSuite.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.io + +import java.nio.ByteBuffer +import java.nio.channels.WritableByteChannel + +import scala.util.Random + +import org.mockito.Mockito.when +import org.scalatest.BeforeAndAfterEach +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite} +import org.apache.spark.internal.config +import org.apache.spark.util.io.ChunkedByteBuffer + +class ChunkedByteBufferFileRegionSuite extends SparkFunSuite with MockitoSugar + with BeforeAndAfterEach { + + override protected def beforeEach(): Unit = { + super.beforeEach() + val conf = new SparkConf() + val env = mock[SparkEnv] + SparkEnv.set(env) + when(env.conf).thenReturn(conf) + } + + override protected def afterEach(): Unit = { + SparkEnv.set(null) + } + + private def generateChunkedByteBuffer(nChunks: Int, perChunk: Int): ChunkedByteBuffer = { + val bytes = (0 until nChunks).map { chunkIdx => + val bb = ByteBuffer.allocate(perChunk) + (0 until perChunk).foreach { idx => + bb.put((chunkIdx * perChunk + idx).toByte) + } + bb.position(0) + bb + }.toArray + new ChunkedByteBuffer(bytes) + } + + test("transferTo can stop and resume correctly") { + SparkEnv.get.conf.set(config.BUFFER_WRITE_CHUNK_SIZE, 9L) + val cbb = generateChunkedByteBuffer(4, 10) + val fileRegion = cbb.toNetty + + val targetChannel = new LimitedWritableByteChannel(40) + + var pos = 0L + // write the fileregion to the channel, but with the transfer limited at various spots along + // the way. + + // limit to within the first chunk + targetChannel.acceptNBytes = 5 + pos = fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 5) + + // a little bit further within the first chunk + targetChannel.acceptNBytes = 2 + pos += fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 7) + + // past the first chunk, into the 2nd + targetChannel.acceptNBytes = 6 + pos += fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 13) + + // right to the end of the 2nd chunk + targetChannel.acceptNBytes = 7 + pos += fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 20) + + // rest of 2nd chunk, all of 3rd, some of 4th + targetChannel.acceptNBytes = 15 + pos += fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 35) + + // now till the end + targetChannel.acceptNBytes = 5 + pos += fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 40) + + // calling again at the end should be OK + targetChannel.acceptNBytes = 20 + fileRegion.transferTo(targetChannel, pos) + assert(targetChannel.pos === 40) + } + + test(s"transfer to with random limits") { + val rng = new Random() + val seed = System.currentTimeMillis() + logInfo(s"seed = $seed") + rng.setSeed(seed) + val chunkSize = 1e4.toInt + SparkEnv.get.conf.set(config.BUFFER_WRITE_CHUNK_SIZE, rng.nextInt(chunkSize).toLong) + + val cbb = generateChunkedByteBuffer(50, chunkSize) + val fileRegion = cbb.toNetty + val transferLimit = 1e5.toInt + val targetChannel = new LimitedWritableByteChannel(transferLimit) + while (targetChannel.pos < cbb.size) { + val nextTransferSize = rng.nextInt(transferLimit) + targetChannel.acceptNBytes = nextTransferSize + fileRegion.transferTo(targetChannel, targetChannel.pos) + } + assert(0 === fileRegion.transferTo(targetChannel, targetChannel.pos)) + } + + /** + * This mocks a channel which only accepts a limited number of bytes at a time. It also verifies + * the written data matches our expectations as the data is received. + */ + private class LimitedWritableByteChannel(maxWriteSize: Int) extends WritableByteChannel { + val bytes = new Array[Byte](maxWriteSize) + var acceptNBytes = 0 + var pos = 0 + + override def write(src: ByteBuffer): Int = { + val length = math.min(acceptNBytes, src.remaining()) + src.get(bytes, 0, length) + acceptNBytes -= length + // verify we got the right data + (0 until length).foreach { idx => + assert(bytes(idx) === (pos + idx).toByte, s"; wrong data at ${pos + idx}") + } + pos += length + length + } + + override def isOpen: Boolean = true + + override def close(): Unit = {} + } + +} diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala index 2107559572d78..ff117b1c21cb1 100644 --- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala @@ -34,7 +34,7 @@ class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext { assert(emptyChunkedByteBuffer.getChunks().isEmpty) assert(emptyChunkedByteBuffer.toArray === Array.empty) assert(emptyChunkedByteBuffer.toByteBuffer.capacity() === 0) - assert(emptyChunkedByteBuffer.toNetty.capacity() === 0) + assert(emptyChunkedByteBuffer.toNetty.count() === 0) emptyChunkedByteBuffer.toInputStream(dispose = false).close() emptyChunkedByteBuffer.toInputStream(dispose = true).close() } From 7db81ac8a2d6c3c19db387d3d25053750b1404dd Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Fri, 20 Jul 2018 11:25:51 +0800 Subject: [PATCH 1169/2461] [SPARK-24195][CORE] Ignore the files with "local" scheme in SparkContext.addFile ## What changes were proposed in this pull request? In Spark "local" scheme means resources are already on the driver/executor nodes, this pr ignore the files with "local" scheme in `SparkContext.addFile` for fixing potential bug. ## How was this patch tested? Existing tests. Author: Yuanjian Li Closes #21533 from xuanyuanking/SPARK-24195. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 531384ab57305..78ba0b31fc6bb 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1524,7 +1524,11 @@ class SparkContext(config: SparkConf) extends Logging { def addFile(path: String, recursive: Boolean): Unit = { val uri = new Path(path).toUri val schemeCorrectedPath = uri.getScheme match { - case null | "local" => new File(path).getCanonicalFile.toURI.toString + case null => new File(path).getCanonicalFile.toURI.toString + case "local" => + logWarning("File with 'local' scheme is not supported to add to file server, since " + + "it is already available on every node.") + return case _ => path } From 1462b17666729cd6c9e8dfa2a1fe9c2020d3f25b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 20 Jul 2018 13:40:26 +0800 Subject: [PATCH 1170/2461] [SPARK-24861][SS][TEST] create corrected temp directories in RateSourceSuite ## What changes were proposed in this pull request? `RateSourceSuite` may leave garbage files under `sql/core/dummy`, we should use a corrected temp directory ## How was this patch tested? test only Author: Wenchen Fan Closes #21817 from cloud-fan/minor. --- .../sources/RateStreamProviderSuite.scala | 127 +++++++++--------- 1 file changed, 67 insertions(+), 60 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index bf72e5c99689f..9115a384d0790 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.sql.execution.streaming.sources -import java.nio.file.Files import java.util.Optional import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ @@ -54,12 +52,15 @@ class RateSourceSuite extends StreamTest { } test("microbatch in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - val reader = ds.createMicroBatchReader(Optional.empty(), "dummy", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamMicroBatchReader]) - case _ => - throw new IllegalStateException("Could not find read support for rate") + withTempDir { temp => + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + val reader = ds.createMicroBatchReader( + Optional.empty(), temp.getCanonicalPath, DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamMicroBatchReader]) + case _ => + throw new IllegalStateException("Could not find read support for rate") + } } } @@ -108,69 +109,75 @@ class RateSourceSuite extends StreamTest { } test("microbatch - set offset") { - val temp = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - assert(reader.getStartOffset() == startOffset) - assert(reader.getEndOffset() == endOffset) + withTempDir { temp => + val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp.getCanonicalPath) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + assert(reader.getStartOffset() == startOffset) + assert(reader.getEndOffset() == endOffset) + } } test("microbatch - infer offsets") { - val tempFolder = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions( - Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), - tempFolder) - reader.clock.asInstanceOf[ManualClock].advance(100000) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - reader.getStartOffset() match { - case r: LongOffset => assert(r.offset === 0L) - case _ => throw new IllegalStateException("unexpected offset type") - } - reader.getEndOffset() match { - case r: LongOffset => assert(r.offset >= 100) - case _ => throw new IllegalStateException("unexpected offset type") + withTempDir { temp => + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions( + Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), + temp.getCanonicalPath) + reader.clock.asInstanceOf[ManualClock].advance(100000) + reader.setOffsetRange(Optional.empty(), Optional.empty()) + reader.getStartOffset() match { + case r: LongOffset => assert(r.offset === 0L) + case _ => throw new IllegalStateException("unexpected offset type") + } + reader.getEndOffset() match { + case r: LongOffset => assert(r.offset >= 100) + case _ => throw new IllegalStateException("unexpected offset type") + } } } test("microbatch - predetermined batch size") { - val temp = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.planInputPartitions() - assert(tasks.size == 1) - val dataReader = tasks.get(0).createPartitionReader() - val data = ArrayBuffer[Row]() - while (dataReader.next()) { - data.append(dataReader.get()) + withTempDir { temp => + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), + temp.getCanonicalPath) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.planInputPartitions() + assert(tasks.size == 1) + val dataReader = tasks.get(0).createPartitionReader() + val data = ArrayBuffer[Row]() + while (dataReader.next()) { + data.append(dataReader.get()) + } + assert(data.size === 20) } - assert(data.size === 20) } test("microbatch - data read") { - val temp = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.planInputPartitions() - assert(tasks.size == 11) - - val readData = tasks.asScala - .map(_.createPartitionReader()) - .flatMap { reader => - val buf = scala.collection.mutable.ListBuffer[Row]() - while (reader.next()) buf.append(reader.get()) - buf - } + withTempDir { temp => + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), + temp.getCanonicalPath) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.planInputPartitions() + assert(tasks.size == 11) + + val readData = tasks.asScala + .map(_.createPartitionReader()) + .flatMap { reader => + val buf = scala.collection.mutable.ListBuffer[Row]() + while (reader.next()) buf.append(reader.get()) + buf + } - assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) + assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) + } } test("valueAtSecond") { From a5925c1631e25c2dcc3c2948cea31e993ce66a97 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 19 Jul 2018 23:29:29 -0700 Subject: [PATCH 1171/2461] [SPARK-24268][SQL] Use datatype.catalogString in error messages ## What changes were proposed in this pull request? As stated in https://github.com/apache/spark/pull/21321, in the error messages we should use `catalogString`. This is not the case, as SPARK-22893 used `simpleString` in order to have the same representation everywhere and it missed some places. The PR unifies the messages using alway the `catalogString` representation of the dataTypes in the messages. ## How was this patch tested? existing/modified UTs Author: Marco Gaido Closes #21804 from mgaido91/SPARK-24268_catalog. --- .../spark/sql/kafka010/KafkaWriteTask.scala | 6 +++--- .../spark/sql/kafka010/KafkaWriter.scala | 6 +++--- .../kafka010/KafkaContinuousSinkSuite.scala | 4 ++-- .../spark/sql/kafka010/KafkaSinkSuite.scala | 4 ++-- .../org/apache/spark/ml/feature/DCT.scala | 3 ++- .../spark/ml/feature/FeatureHasher.scala | 5 +++-- .../apache/spark/ml/feature/HashingTF.scala | 2 +- .../apache/spark/ml/feature/Interaction.scala | 3 ++- .../org/apache/spark/ml/feature/NGram.scala | 3 ++- .../spark/ml/feature/OneHotEncoder.scala | 3 ++- .../org/apache/spark/ml/feature/RFormula.scala | 2 +- .../spark/ml/feature/StopWordsRemover.scala | 4 ++-- .../apache/spark/ml/feature/Tokenizer.scala | 3 ++- .../spark/ml/feature/VectorAssembler.scala | 2 +- .../org/apache/spark/ml/fpm/FPGrowth.scala | 2 +- .../org/apache/spark/ml/util/SchemaUtils.scala | 11 +++++++---- .../BinaryClassificationEvaluatorSuite.scala | 4 ++-- .../spark/ml/feature/RFormulaSuite.scala | 2 +- .../ml/feature/VectorAssemblerSuite.scala | 6 +++--- .../spark/ml/recommendation/ALSSuite.scala | 2 +- .../AFTSurvivalRegressionSuite.scala | 2 +- .../apache/spark/ml/util/MLTestingUtils.scala | 6 +++--- .../spark/sql/catalyst/analysis/Analyzer.scala | 8 ++++---- .../sql/catalyst/analysis/CheckAnalysis.scala | 18 +++++++++--------- .../spark/sql/catalyst/analysis/view.scala | 3 ++- .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../expressions/ExpectsInputTypes.scala | 2 +- .../sql/catalyst/expressions/Expression.scala | 4 ++-- .../sql/catalyst/expressions/ScalaUDF.scala | 5 +++-- .../sql/catalyst/expressions/SortOrder.scala | 2 +- .../aggregate/ApproximatePercentile.scala | 4 ++-- .../sql/catalyst/expressions/arithmetic.scala | 4 ++-- .../expressions/codegen/CodeGenerator.scala | 4 ++-- .../expressions/collectionOperations.scala | 14 +++++++------- .../expressions/complexTypeCreator.scala | 8 ++++---- .../expressions/complexTypeExtractors.scala | 2 +- .../expressions/conditionalExpressions.scala | 4 ++-- .../sql/catalyst/expressions/generators.scala | 8 ++++---- .../catalyst/expressions/jsonExpressions.scala | 6 +++--- .../catalyst/expressions/objects/objects.scala | 2 +- .../sql/catalyst/expressions/predicates.scala | 2 +- .../expressions/stringExpressions.scala | 7 ++++--- .../expressions/windowExpressions.scala | 6 +++--- .../sql/catalyst/json/JacksonGenerator.scala | 8 ++++---- .../sql/catalyst/json/JacksonParser.scala | 6 ++++-- .../spark/sql/catalyst/json/JacksonUtils.scala | 2 +- .../sql/catalyst/json/JsonInferSchema.scala | 6 ++++-- .../spark/sql/catalyst/util/TypeUtils.scala | 7 ++++--- .../spark/sql/types/AbstractDataType.scala | 9 +++++---- .../org/apache/spark/sql/types/ArrayType.scala | 5 +++-- .../apache/spark/sql/types/DecimalType.scala | 3 ++- .../apache/spark/sql/types/ObjectType.scala | 3 ++- .../apache/spark/sql/types/StructType.scala | 5 +++-- .../catalyst/analysis/AnalysisErrorSuite.scala | 2 +- .../analysis/ExpressionTypeCheckingSuite.scala | 16 ++++++++-------- .../parser/ExpressionParserSuite.scala | 2 +- .../apache/spark/sql/types/DataTypeSuite.scala | 2 +- .../parquet/VectorizedColumnReader.java | 2 +- .../spark/sql/RelationalGroupedDataset.scala | 2 +- .../spark/sql/execution/arrow/ArrowUtils.scala | 6 ++++-- .../sql/execution/arrow/ArrowWriter.scala | 2 +- .../sql/execution/columnar/ColumnType.scala | 2 +- .../datasources/DataSourceUtils.scala | 2 +- .../execution/datasources/jdbc/JdbcUtils.scala | 6 +++--- .../execution/datasources/orc/OrcFilters.scala | 2 +- .../execution/datasources/orc/OrcUtils.scala | 2 +- .../parquet/ParquetSchemaConverter.scala | 2 +- .../sql/execution/datasources/rules.scala | 4 ++-- .../sql/execution/stat/StatFunctions.scala | 2 +- .../sql-tests/results/json-functions.sql.out | 4 ++-- .../sql-tests/results/literals.sql.out | 6 +++--- .../spark/sql/FileBasedDataSourceSuite.scala | 8 ++++---- .../parquet/ParquetSchemaSuite.scala | 4 ++-- .../spark/sql/hive/HiveExternalCatalog.scala | 4 ++-- .../sql/hive/execution/HiveTableScanExec.scala | 6 +++--- .../sql/hive/orc/HiveOrcSourceSuite.scala | 4 ++-- 76 files changed, 185 insertions(+), 161 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index d90630a8adc93..041fac7717635 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -110,7 +110,7 @@ private[kafka010] abstract class KafkaRowWriter( case t => throw new IllegalStateException(s"${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + s"attribute unsupported type $t. ${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + - "must be a StringType") + s"must be a ${StringType.catalogString}") } val keyExpression = inputSchema.find(_.name == KafkaWriter.KEY_ATTRIBUTE_NAME) .getOrElse(Literal(null, BinaryType)) @@ -118,7 +118,7 @@ private[kafka010] abstract class KafkaRowWriter( case StringType | BinaryType => // good case t => throw new IllegalStateException(s"${KafkaWriter.KEY_ATTRIBUTE_NAME} " + - s"attribute unsupported type $t") + s"attribute unsupported type ${t.catalogString}") } val valueExpression = inputSchema .find(_.name == KafkaWriter.VALUE_ATTRIBUTE_NAME).getOrElse( @@ -129,7 +129,7 @@ private[kafka010] abstract class KafkaRowWriter( case StringType | BinaryType => // good case t => throw new IllegalStateException(s"${KafkaWriter.VALUE_ATTRIBUTE_NAME} " + - s"attribute unsupported type $t") + s"attribute unsupported type ${t.catalogString}") } UnsafeProjection.create( Seq(topicExpression, Cast(keyExpression, BinaryType), diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 15cd44812cb0c..fc09938a43a8c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -57,7 +57,7 @@ private[kafka010] object KafkaWriter extends Logging { ).dataType match { case StringType => // good case _ => - throw new AnalysisException(s"Topic type must be a String") + throw new AnalysisException(s"Topic type must be a ${StringType.catalogString}") } schema.find(_.name == KEY_ATTRIBUTE_NAME).getOrElse( Literal(null, StringType) @@ -65,7 +65,7 @@ private[kafka010] object KafkaWriter extends Logging { case StringType | BinaryType => // good case _ => throw new AnalysisException(s"$KEY_ATTRIBUTE_NAME attribute type " + - s"must be a String or BinaryType") + s"must be a ${StringType.catalogString} or ${BinaryType.catalogString}") } schema.find(_.name == VALUE_ATTRIBUTE_NAME).getOrElse( throw new AnalysisException(s"Required attribute '$VALUE_ATTRIBUTE_NAME' not found") @@ -73,7 +73,7 @@ private[kafka010] object KafkaWriter extends Logging { case StringType | BinaryType => // good case _ => throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " + - s"must be a String or BinaryType") + s"must be a ${StringType.catalogString} or ${BinaryType.catalogString}") } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala index ddfc0c1a4be2d..0e1492ac27449 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -314,7 +314,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "value attribute type must be a string or binarytype")) + "value attribute type must be a string or binary")) try { /* key field wrong type */ @@ -330,7 +330,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "key attribute type must be a string or binarytype")) + "key attribute type must be a string or binary")) } test("streaming - write to non-existing topic") { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 7079ac6453ffc..70ffd7dee89d7 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -303,7 +303,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "value attribute type must be a string or binarytype")) + "value attribute type must be a string or binary")) try { ex = intercept[StreamingQueryException] { @@ -318,7 +318,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "key attribute type must be a string or binarytype")) + "key attribute type must be a string or binary")) } test("streaming - write to non-existing topic") { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index 682787a830113..32d98151bdcff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -69,7 +69,8 @@ class DCT @Since("1.5.0") (@Since("1.5.0") override val uid: String) } override protected def validateInputType(inputType: DataType): Unit = { - require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.") + require(inputType.isInstanceOf[VectorUDT], + s"Input type must be ${(new VectorUDT).catalogString} but got ${inputType.catalogString}.") } override protected def outputDataType: DataType = new VectorUDT diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index d67e4819b161a..dc38ee326e5e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -208,8 +208,9 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme require(dataType.isInstanceOf[NumericType] || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[BooleanType], - s"FeatureHasher requires columns to be of NumericType, BooleanType or StringType. " + - s"Column $fieldName was $dataType") + s"FeatureHasher requires columns to be of ${NumericType.simpleString}, " + + s"${BooleanType.catalogString} or ${StringType.catalogString}. " + + s"Column $fieldName was ${dataType.catalogString}") } val attrGroup = new AttributeGroup($(outputCol), $(numFeatures)) SchemaUtils.appendColumn(schema, attrGroup.toStructField()) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index db432b6fefaff..dbda5b8d8fd4a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -104,7 +104,7 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String) override def transformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[ArrayType], - s"The input column must be ArrayType, but got $inputType.") + s"The input column must be ${ArrayType.simpleString}, but got ${inputType.catalogString}.") val attrGroup = new AttributeGroup($(outputCol), $(numFeatures)) SchemaUtils.appendColumn(schema, attrGroup.toStructField()) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 4ff1d0ef356f3..611f1b691b782 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -261,7 +261,8 @@ private[ml] class FeatureEncoder(numFeatures: Array[Int]) extends Serializable { */ def foreachNonzeroOutput(value: Any, f: (Int, Double) => Unit): Unit = value match { case d: Double => - assert(numFeatures.length == 1, "DoubleType columns should only contain one feature.") + assert(numFeatures.length == 1, + s"${DoubleType.catalogString} columns should only contain one feature.") val numOutputCols = numFeatures.head if (numOutputCols > 1) { assert( diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index c8760f9dc178f..e0772d5af20a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -65,7 +65,8 @@ class NGram @Since("1.5.0") (@Since("1.5.0") override val uid: String) override protected def validateInputType(inputType: DataType): Unit = { require(inputType.sameType(ArrayType(StringType)), - s"Input type must be ArrayType(StringType) but got $inputType.") + s"Input type must be ${ArrayType(StringType).catalogString} but got " + + inputType.catalogString) } override protected def outputDataType: DataType = new ArrayType(StringType, false) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 5ab6c2dde667a..27e4869a020b7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -85,7 +85,8 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e val inputFields = schema.fields require(schema(inputColName).dataType.isInstanceOf[NumericType], - s"Input column must be of type NumericType but got ${schema(inputColName).dataType}") + s"Input column must be of type ${NumericType.simpleString} but got " + + schema(inputColName).dataType.catalogString) require(!inputFields.exists(_.name == outputColName), s"Output column $outputColName already exists.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 55e595eee6ffb..346e1823f00b8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -394,7 +394,7 @@ class RFormulaModel private[feature]( require(!columnNames.contains($(featuresCol)), "Features column already exists.") require( !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType.isInstanceOf[NumericType], - "Label column already exists and is not of type NumericType.") + s"Label column already exists and is not of type ${NumericType.simpleString}.") } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 0f946dd2e015b..94640a5cbe310 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -131,8 +131,8 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType - require(inputType.sameType(ArrayType(StringType)), - s"Input type must be ArrayType(StringType) but got $inputType.") + require(inputType.sameType(ArrayType(StringType)), "Input type must be " + + s"${ArrayType(StringType).catalogString} but got ${inputType.catalogString}.") SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index cfaf6c0e610b3..aede1f812a552 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -40,7 +40,8 @@ class Tokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) } override protected def validateInputType(inputType: DataType): Unit = { - require(inputType == StringType, s"Input type must be string type but got $inputType.") + require(inputType == StringType, + s"Input type must be ${StringType.catalogString} type but got ${inputType.catalogString}.") } override protected def outputDataType: DataType = new ArrayType(StringType, true) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 4061154b39c14..57e23d5072b88 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -162,7 +162,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) schema(name).dataType match { case _: NumericType | BooleanType => None case t if t.isInstanceOf[VectorUDT] => None - case other => Some(s"Data type $other of column $name is not supported.") + case other => Some(s"Data type ${other.catalogString} of column $name is not supported.") } } if (incorrectColumns.nonEmpty) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index d7fbe28ae7a64..9d664b6ca6d2a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -106,7 +106,7 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { protected def validateAndTransformSchema(schema: StructType): StructType = { val inputType = schema($(itemsCol)).dataType require(inputType.isInstanceOf[ArrayType], - s"The input column must be ArrayType, but got $inputType.") + s"The input column must be ${ArrayType.simpleString}, but got ${inputType.catalogString}.") SchemaUtils.appendColumn(schema, $(predictionCol), schema($(itemsCol)).dataType) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index d9a3f85ef9a24..c3894ebdd1785 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -41,7 +41,8 @@ private[spark] object SchemaUtils { val actualDataType = schema(colName).dataType val message = if (msg != null && msg.trim.length > 0) " " + msg else "" require(actualDataType.equals(dataType), - s"Column $colName must be of type $dataType but was actually $actualDataType.$message") + s"Column $colName must be of type ${dataType.catalogString} but was actually " + + s"${actualDataType.catalogString}.$message") } /** @@ -58,7 +59,8 @@ private[spark] object SchemaUtils { val message = if (msg != null && msg.trim.length > 0) " " + msg else "" require(dataTypes.exists(actualDataType.equals), s"Column $colName must be of type equal to one of the following types: " + - s"${dataTypes.mkString("[", ", ", "]")} but was actually of type $actualDataType.$message") + s"${dataTypes.map(_.catalogString).mkString("[", ", ", "]")} but was actually of type " + + s"${actualDataType.catalogString}.$message") } /** @@ -71,8 +73,9 @@ private[spark] object SchemaUtils { msg: String = ""): Unit = { val actualDataType = schema(colName).dataType val message = if (msg != null && msg.trim.length > 0) " " + msg else "" - require(actualDataType.isInstanceOf[NumericType], s"Column $colName must be of type " + - s"NumericType but was actually of type $actualDataType.$message") + require(actualDataType.isInstanceOf[NumericType], + s"Column $colName must be of type ${NumericType.simpleString} but was actually of type " + + s"${actualDataType.catalogString}.$message") } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index ede284712b1c0..2b0909acf69c3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -67,8 +67,8 @@ class BinaryClassificationEvaluatorSuite evaluator.evaluate(stringDF) } assert(thrown.getMessage.replace("\n", "") contains "Column rawPrediction must be of type " + - "equal to one of the following types: [DoubleType, ") - assert(thrown.getMessage.replace("\n", "") contains "but was actually of type StringType.") + "equal to one of the following types: [double, ") + assert(thrown.getMessage.replace("\n", "") contains "but was actually of type string.") } test("should support all NumericType labels and not support other types") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index a250331efeb1d..0de6528c4cf22 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -105,7 +105,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { testTransformerByInterceptingException[(Int, Boolean)]( original, model, - "Label column already exists and is not of type NumericType.", + "Label column already exists and is not of type numeric.", "x") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 91fb24a268b8c..ed15a1d88a269 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -99,9 +99,9 @@ class VectorAssemblerSuite assembler.transform(df) } assert(thrown.getMessage contains - "Data type StringType of column a is not supported.\n" + - "Data type StringType of column b is not supported.\n" + - "Data type StringType of column c is not supported.") + "Data type string of column a is not supported.\n" + + "Data type string of column b is not supported.\n" + + "Data type string of column c is not supported.") } test("ML attributes") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index e3dfe2faf5698..65bee4edc4965 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -612,7 +612,7 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { estimator.fit(strDF) } assert(thrown.getMessage.contains( - s"$column must be of type NumericType but was actually of type StringType")) + s"$column must be of type numeric but was actually of type string")) } private class NumericTypeWithEncoder[A](val numericType: NumericType) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 4e4ff71c9de90..6cc73e040e82c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -385,7 +385,7 @@ class AFTSurvivalRegressionSuite extends MLTest with DefaultReadWriteTest { aft.fit(dfWithStringCensors) } assert(thrown.getMessage.contains( - "Column censor must be of type NumericType but was actually of type StringType")) + "Column censor must be of type numeric but was actually of type string")) } test("numerical stability of standardization") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index 5e72b4d864c1d..91a8b14625a86 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -74,7 +74,7 @@ object MLTestingUtils extends SparkFunSuite { estimator.fit(dfWithStringLabels) } assert(thrown.getMessage.contains( - "Column label must be of type NumericType but was actually of type StringType")) + "Column label must be of type numeric but was actually of type string")) estimator match { case weighted: Estimator[M] with HasWeightCol => @@ -86,7 +86,7 @@ object MLTestingUtils extends SparkFunSuite { weighted.fit(dfWithStringWeights) } assert(thrown.getMessage.contains( - "Column weight must be of type NumericType but was actually of type StringType")) + "Column weight must be of type numeric but was actually of type string")) case _ => } } @@ -104,7 +104,7 @@ object MLTestingUtils extends SparkFunSuite { evaluator.evaluate(dfWithStringLabels) } assert(thrown.getMessage.contains( - "Column label must be of type NumericType but was actually of type StringType")) + "Column label must be of type numeric but was actually of type string")) } def genClassifDFWithNumericLabelCol( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 59c371eb1557b..7c5504d90433f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2273,7 +2273,7 @@ class Analyzer( } expr case other => - throw new AnalysisException("need an array field but got " + other.simpleString) + throw new AnalysisException("need an array field but got " + other.catalogString) } } validateNestedTupleFields(result) @@ -2282,8 +2282,8 @@ class Analyzer( } private def fail(schema: StructType, maxOrdinal: Int): Unit = { - throw new AnalysisException(s"Try to map ${schema.simpleString} to Tuple${maxOrdinal + 1}, " + - "but failed as the number of fields does not line up.") + throw new AnalysisException(s"Try to map ${schema.catalogString} to Tuple${maxOrdinal + 1}" + + ", but failed as the number of fields does not line up.") } /** @@ -2362,7 +2362,7 @@ class Analyzer( case e => e.sql } throw new AnalysisException(s"Cannot up cast $fromStr from " + - s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" + + s"${from.dataType.catalogString} to ${to.catalogString} as it may truncate\n" + "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + "You can either add an explicit cast to the input data or choose a higher precision " + "type of the field in the target object") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index af256b98b34f3..49fe625b8fc6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -67,7 +67,7 @@ trait CheckAnalysis extends PredicateHelper { limitExpr.sql) case e if e.dataType != IntegerType => failAnalysis( s"The limit expression must be integer type, but got " + - e.dataType.simpleString) + e.dataType.catalogString) case e if e.eval().asInstanceOf[Int] < 0 => failAnalysis( "The limit expression must be equal to or greater than 0, but got " + e.eval().asInstanceOf[Int]) @@ -96,8 +96,8 @@ trait CheckAnalysis extends PredicateHelper { } case c: Cast if !c.resolved => - failAnalysis( - s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") + failAnalysis(s"invalid cast from ${c.child.dataType.catalogString} to " + + c.dataType.catalogString) case g: Grouping => failAnalysis("grouping() can only be used with GroupingSets/Cube/Rollup") @@ -144,12 +144,12 @@ trait CheckAnalysis extends PredicateHelper { case _ => failAnalysis( s"Event time must be defined on a window or a timestamp, but " + - s"${etw.eventTime.name} is of type ${etw.eventTime.dataType.simpleString}") + s"${etw.eventTime.name} is of type ${etw.eventTime.dataType.catalogString}") } case f: Filter if f.condition.dataType != BooleanType => failAnalysis( s"filter expression '${f.condition.sql}' " + - s"of type ${f.condition.dataType.simpleString} is not a boolean.") + s"of type ${f.condition.dataType.catalogString} is not a boolean.") case Filter(condition, _) if hasNullAwarePredicateWithinNot(condition) => failAnalysis("Null-aware predicate sub-queries cannot be used in nested " + @@ -158,7 +158,7 @@ trait CheckAnalysis extends PredicateHelper { case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => failAnalysis( s"join condition '${condition.sql}' " + - s"of type ${condition.dataType.simpleString} is not a boolean.") + s"of type ${condition.dataType.catalogString} is not a boolean.") case Aggregate(groupingExprs, aggregateExprs, child) => def isAggregateExpression(expr: Expression) = { @@ -219,7 +219,7 @@ trait CheckAnalysis extends PredicateHelper { if (!RowOrdering.isOrderable(expr.dataType)) { failAnalysis( s"expression ${expr.sql} cannot be used as a grouping expression " + - s"because its data type ${expr.dataType.simpleString} is not an orderable " + + s"because its data type ${expr.dataType.catalogString} is not an orderable " + s"data type.") } @@ -239,7 +239,7 @@ trait CheckAnalysis extends PredicateHelper { orders.foreach { order => if (!RowOrdering.isOrderable(order.dataType)) { failAnalysis( - s"sorting is not supported for columns of type ${order.dataType.simpleString}") + s"sorting is not supported for columns of type ${order.dataType.catalogString}") } } @@ -342,7 +342,7 @@ trait CheckAnalysis extends PredicateHelper { val mapCol = mapColumnInSetOperation(o).get failAnalysis("Cannot have map type columns in DataFrame which calls " + s"set operations(intersect, except, etc.), but the type of column ${mapCol.name} " + - "is " + mapCol.dataType.simpleString) + "is " + mapCol.dataType.catalogString) case o if o.expressions.exists(!_.deterministic) && !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] && diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index 20216087b0158..23eb78f914656 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -76,7 +76,8 @@ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupp // Will throw an AnalysisException if the cast can't perform or might truncate. if (Cast.mayTruncate(originAttr.dataType, attr.dataType)) { throw new AnalysisException(s"Cannot up cast ${originAttr.sql} from " + - s"${originAttr.dataType.simpleString} to ${attr.simpleString} as it may truncate\n") + s"${originAttr.dataType.catalogString} to ${attr.dataType.catalogString} as it " + + s"may truncate\n") } else { Alias(cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId, qualifier = attr.qualifier, explicitMetadata = Some(attr.metadata)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 7971ae602bd37..ba4d1314bab2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -202,7 +202,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String TypeCheckResult.TypeCheckSuccess } else { TypeCheckResult.TypeCheckFailure( - s"cannot cast ${child.dataType.simpleString} to ${dataType.simpleString}") + s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 98f25a9ad7597..464566b0cb7d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -44,7 +44,7 @@ trait ExpectsInputTypes extends Expression { val mismatches = children.zip(inputTypes).zipWithIndex.collect { case ((child, expected), idx) if !expected.acceptsType(child.dataType) => s"argument ${idx + 1} requires ${expected.simpleString} type, " + - s"however, '${child.sql}' is of ${child.dataType.simpleString} type." + s"however, '${child.sql}' is of ${child.dataType.catalogString} type." } if (mismatches.isEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index a69b80428472a..dcb9c96ca3b2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -580,10 +580,10 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { // First check whether left and right have the same type, then check if the type is acceptable. if (!left.dataType.sameType(right.dataType)) { TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + - s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).") + s"(${left.dataType.catalogString} and ${right.dataType.catalogString}).") } else if (!inputType.acceptsType(left.dataType)) { TypeCheckResult.TypeCheckFailure(s"'$sql' requires ${inputType.simpleString} type," + - s" not ${left.dataType.simpleString}") + s" not ${left.dataType.catalogString}") } else { TypeCheckResult.TypeCheckSuccess } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 3e7ca88249737..4b09978e75081 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -1048,8 +1048,9 @@ case class ScalaUDF( lazy val udfErrorMessage = { val funcCls = function.getClass.getSimpleName - val inputTypes = children.map(_.dataType.simpleString).mkString(", ") - s"Failed to execute user defined function($funcCls: ($inputTypes) => ${dataType.simpleString})" + val inputTypes = children.map(_.dataType.catalogString).mkString(", ") + val outputType = dataType.catalogString + s"Failed to execute user defined function($funcCls: ($inputTypes) => $outputType)" } override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 76a881146a146..536276b5cb29f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -73,7 +73,7 @@ case class SortOrder( if (RowOrdering.isOrderable(dataType)) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure(s"cannot sort data type ${dataType.simpleString}") + TypeCheckResult.TypeCheckFailure(s"cannot sort data type ${dataType.catalogString}") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index f1bbbdabb41f3..c790d87492c73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -132,7 +132,7 @@ case class ApproximatePercentile( case TimestampType => value.asInstanceOf[Long].toDouble case n: NumericType => n.numeric.toDouble(value.asInstanceOf[n.InternalType]) case other: DataType => - throw new UnsupportedOperationException(s"Unexpected data type ${other.simpleString}") + throw new UnsupportedOperationException(s"Unexpected data type ${other.catalogString}") } buffer.add(doubleValue) } @@ -157,7 +157,7 @@ case class ApproximatePercentile( case DoubleType => doubleResult case _: DecimalType => doubleResult.map(Decimal(_)) case other: DataType => - throw new UnsupportedOperationException(s"Unexpected data type ${other.simpleString}") + throw new UnsupportedOperationException(s"Unexpected data type ${other.catalogString}") } if (result.length == 0) { null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 55940410cc4d4..c827226d58420 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -528,7 +528,7 @@ case class Least(children: Seq[Expression]) extends ComplexTypeMergingExpression } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + - s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).") + s" got LEAST(${children.map(_.dataType.catalogString).mkString(", ")}).") } else { TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") } @@ -601,7 +601,7 @@ case class Greatest(children: Seq[Expression]) extends ComplexTypeMergingExpress } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + - s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).") + s" got GREATEST(${children.map(_.dataType.catalogString).mkString(", ")}).") } else { TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 838c045d5bcce..05500f5923e94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -596,7 +596,7 @@ class CodegenContext { case NullType => "false" case _ => throw new IllegalArgumentException( - "cannot generate equality code for un-comparable type: " + dataType.simpleString) + "cannot generate equality code for un-comparable type: " + dataType.catalogString) } /** @@ -683,7 +683,7 @@ class CodegenContext { case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2) case _ => throw new IllegalArgumentException( - "cannot generate compare code for un-comparable type: " + dataType.simpleString) + "cannot generate compare code for un-comparable type: " + dataType.catalogString) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index d60f4c36fa214..92635417e9666 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -64,7 +64,7 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression TypeCheckResult.TypeCheckSuccess case _ => TypeCheckResult.TypeCheckFailure(s"input to function $prettyName should have " + s"been two ${ArrayType.simpleString}s with same element type, but it's " + - s"[${left.dataType.simpleString}, ${right.dataType.simpleString}]") + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}]") } } } @@ -509,7 +509,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres if (children.exists(!_.dataType.isInstanceOf[MapType])) { TypeCheckResult.TypeCheckFailure( s"input to $funcName should all be of type map, but it's " + - children.map(_.dataType.simpleString).mkString("[", ", ", "]")) + children.map(_.dataType.catalogString).mkString("[", ", ", "]")) } else { TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName) } @@ -751,7 +751,7 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { case Some(_) => TypeCheckResult.TypeCheckSuccess case None => TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " + - s"${child.dataType.simpleString} type. $prettyName accepts only arrays of pair structs.") + s"${child.dataType.catalogString} type. $prettyName accepts only arrays of pair structs.") } override protected def nullSafeEval(input: Any): Any = { @@ -1118,7 +1118,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) "Sort order in second argument requires a boolean literal.") } case ArrayType(dt, _) => - val dtSimple = dt.simpleString + val dtSimple = dt.catalogString TypeCheckResult.TypeCheckFailure( s"$prettyName does not support sorting array of type $dtSimple which is not orderable") case _ => @@ -1166,7 +1166,7 @@ case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLi case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => TypeCheckResult.TypeCheckSuccess case ArrayType(dt, _) => - val dtSimple = dt.simpleString + val dtSimple = dt.catalogString TypeCheckResult.TypeCheckFailure( s"$prettyName does not support sorting array of type $dtSimple which is not orderable") case _ => @@ -2217,7 +2217,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio return TypeCheckResult.TypeCheckFailure( s"input to function $prettyName should have been ${StringType.simpleString}," + s" ${BinaryType.simpleString} or ${ArrayType.simpleString}, but it's " + - childTypes.map(_.simpleString).mkString("[", ", ", "]")) + childTypes.map(_.catalogString).mkString("[", ", ", "]")) } TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") } @@ -2424,7 +2424,7 @@ case class Flatten(child: Expression) extends UnaryExpression { case _ => TypeCheckResult.TypeCheckFailure( s"The argument should be an array of arrays, " + - s"but '${child.sql}' is of ${child.dataType.simpleString} type." + s"but '${child.sql}' is of ${child.dataType.catalogString} type." ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index a43de028360b1..077a6dc93bd17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -183,11 +183,11 @@ case class CreateMap(children: Seq[Expression]) extends Expression { } else if (!TypeCoercion.haveSameType(keys.map(_.dataType))) { TypeCheckResult.TypeCheckFailure( "The given keys of function map should all be the same type, but they are " + - keys.map(_.dataType.simpleString).mkString("[", ", ", "]")) + keys.map(_.dataType.catalogString).mkString("[", ", ", "]")) } else if (!TypeCoercion.haveSameType(values.map(_.dataType))) { TypeCheckResult.TypeCheckFailure( "The given values of function map should all be the same type, but they are " + - values.map(_.dataType.simpleString).mkString("[", ", ", "]")) + values.map(_.dataType.catalogString).mkString("[", ", ", "]")) } else { TypeCheckResult.TypeCheckSuccess } @@ -388,8 +388,8 @@ trait CreateNamedStructLike extends Expression { val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( - "Only foldable StringType expressions are allowed to appear at odd position, got:" + - s" ${invalidNames.mkString(",")}") + s"Only foldable ${StringType.catalogString} expressions are allowed to appear at odd" + + s" position, got: ${invalidNames.mkString(",")}") } else if (!names.contains(null)) { TypeCheckResult.TypeCheckSuccess } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 99671d5b863c4..8994eeff92c7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -68,7 +68,7 @@ object ExtractValue { case StructType(_) => s"Field name should be String Literal, but it's $extraction" case other => - s"Can't extract value from $child: need struct type but got ${other.simpleString}" + s"Can't extract value from $child: need struct type but got ${other.catalogString}" } throw new AnalysisException(errorMsg) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 3b597e8b5263b..bed581a61b2dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -47,10 +47,10 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi if (predicate.dataType != BooleanType) { TypeCheckResult.TypeCheckFailure( "type of predicate expression in If should be boolean, " + - s"not ${predicate.dataType.simpleString}") + s"not ${predicate.dataType.catalogString}") } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) { TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + - s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).") + s"(${trueValue.dataType.catalogString} and ${falseValue.dataType.catalogString}).") } else { TypeCheckResult.TypeCheckSuccess } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index b7c52f1d7b40a..b6e0d364d3a96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -156,8 +156,8 @@ case class Stack(children: Seq[Expression]) extends Generator { val j = (i - 1) % numFields if (children(i).dataType != elementSchema.fields(j).dataType) { return TypeCheckResult.TypeCheckFailure( - s"Argument ${j + 1} (${elementSchema.fields(j).dataType.simpleString}) != " + - s"Argument $i (${children(i).dataType.simpleString})") + s"Argument ${j + 1} (${elementSchema.fields(j).dataType.catalogString}) != " + + s"Argument $i (${children(i).dataType.catalogString})") } } TypeCheckResult.TypeCheckSuccess @@ -251,7 +251,7 @@ abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with case _ => TypeCheckResult.TypeCheckFailure( "input to function explode should be array or map type, " + - s"not ${child.dataType.simpleString}") + s"not ${child.dataType.catalogString}") } // hive-compatible default alias for explode function ("col" for array, "key", "value" for map) @@ -381,7 +381,7 @@ case class Inline(child: Expression) extends UnaryExpression with CollectionGene case _ => TypeCheckResult.TypeCheckFailure( s"input to function $prettyName should be array of struct type, " + - s"not ${child.dataType.simpleString}") + s"not ${child.dataType.catalogString}") } override def elementSchema: StructType = child.dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 63943b1a4351a..abe88754f3a1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -547,7 +547,7 @@ case class JsonToStructs( case _: StructType | ArrayType(_: StructType, _) | _: MapType => super.checkInputDataTypes() case _ => TypeCheckResult.TypeCheckFailure( - s"Input schema ${nullableSchema.simpleString} must be a struct or an array of structs.") + s"Input schema ${nullableSchema.catalogString} must be a struct or an array of structs.") } @transient @@ -729,7 +729,7 @@ case class StructsToJson( TypeCheckResult.TypeCheckFailure(e.getMessage) } case _ => TypeCheckResult.TypeCheckFailure( - s"Input type ${child.dataType.simpleString} must be a struct, array of structs or " + + s"Input type ${child.dataType.catalogString} must be a struct, array of structs or " + "a map or array of map.") } @@ -790,7 +790,7 @@ object JsonExprUtils { } case m: CreateMap => throw new AnalysisException( - s"A type of keys and values in map() must be string, but got ${m.dataType}") + s"A type of keys and values in map() must be string, but got ${m.dataType.catalogString}") case _ => throw new AnalysisException("Must use a map() function for options") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 2bf4203d0fec3..3189e6841a525 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1727,7 +1727,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) override val dataType: DataType = RowEncoder.externalDataTypeForInput(expected) - private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" + private val errMsg = s" is not a valid external type for schema of ${expected.catalogString}" private lazy val checkType: (Any) => Boolean = expected match { case _: DecimalType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f54103c4fbfba..699601e64dd14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -205,7 +205,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } case _ => TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " + - s"${value.dataType.simpleString} != ${mismatchOpt.get.dataType.simpleString}") + s"${value.dataType.catalogString} != ${mismatchOpt.get.dataType.catalogString}") } } else { TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index bedad7da334ae..1838b9fca02db 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -222,12 +222,13 @@ case class Elt(children: Seq[Expression]) extends Expression { val (indexType, inputTypes) = (indexExpr.dataType, inputExprs.map(_.dataType)) if (indexType != IntegerType) { return TypeCheckResult.TypeCheckFailure(s"first input to function $prettyName should " + - s"have IntegerType, but it's $indexType") + s"have ${IntegerType.catalogString}, but it's ${indexType.catalogString}") } if (inputTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { return TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName should have StringType or BinaryType, but it's " + - inputTypes.map(_.simpleString).mkString("[", ", ", "]")) + s"input to function $prettyName should have ${StringType.catalogString} or " + + s"${BinaryType.catalogString}, but it's " + + inputTypes.map(_.catalogString).mkString("[", ", ", "]")) } TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index f957aaa96e98c..53c6f01c2459e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -70,9 +70,9 @@ case class WindowSpecDefinition( case f: SpecifiedWindowFrame if f.frameType == RangeFrame && f.isValueBound && !isValidFrameType(f.valueBoundary.head.dataType) => TypeCheckFailure( - s"The data type '${orderSpec.head.dataType.simpleString}' used in the order " + + s"The data type '${orderSpec.head.dataType.catalogString}' used in the order " + "specification does not match the data type " + - s"'${f.valueBoundary.head.dataType.simpleString}' which is used in the range frame.") + s"'${f.valueBoundary.head.dataType.catalogString}' which is used in the range frame.") case _ => TypeCheckSuccess } } @@ -251,7 +251,7 @@ case class SpecifiedWindowFrame( TypeCheckFailure(s"Window frame $location bound '$e' is not a literal.") case e: Expression if !frameType.inputType.acceptsType(e.dataType) => TypeCheckFailure( - s"The data type of the $location bound '${e.dataType.simpleString}' does not match " + + s"The data type of the $location bound '${e.dataType.catalogString}' does not match " + s"the expected data type '${frameType.inputType.simpleString}'.") case _ => TypeCheckSuccess } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index 9c413de752a8c..738947766adda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -45,14 +45,14 @@ private[sql] class JacksonGenerator( // `JackGenerator` can only be initialized with a `StructType` or a `MapType`. require(dataType.isInstanceOf[StructType] || dataType.isInstanceOf[MapType], - "JacksonGenerator only supports to be initialized with a StructType " + - s"or MapType but got ${dataType.simpleString}") + s"JacksonGenerator only supports to be initialized with a ${StructType.simpleString} " + + s"or ${MapType.simpleString} but got ${dataType.catalogString}") // `ValueWriter`s for all fields of the schema private lazy val rootFieldWriters: Array[ValueWriter] = dataType match { case st: StructType => st.map(_.dataType).map(makeWriter).toArray case _ => throw new UnsupportedOperationException( - s"Initial type ${dataType.simpleString} must be a struct") + s"Initial type ${dataType.catalogString} must be a struct") } // `ValueWriter` for array data storing rows of the schema. @@ -70,7 +70,7 @@ private[sql] class JacksonGenerator( private lazy val mapElementWriter: ValueWriter = dataType match { case mt: MapType => makeWriter(mt.valueType) case _ => throw new UnsupportedOperationException( - s"Initial type ${dataType.simpleString} must be a map") + s"Initial type ${dataType.catalogString} must be a map") } private val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index c3a4ca8f64bf6..4d409caddd33d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -143,7 +143,8 @@ class JacksonParser( case "NaN" => Float.NaN case "Infinity" => Float.PositiveInfinity case "-Infinity" => Float.NegativeInfinity - case other => throw new RuntimeException(s"Cannot parse $other as FloatType.") + case other => throw new RuntimeException( + s"Cannot parse $other as ${FloatType.catalogString}.") } } @@ -158,7 +159,8 @@ class JacksonParser( case "NaN" => Double.NaN case "Infinity" => Double.PositiveInfinity case "-Infinity" => Double.NegativeInfinity - case other => throw new RuntimeException(s"Cannot parse $other as DoubleType.") + case other => + throw new RuntimeException(s"Cannot parse $other as ${DoubleType.catalogString}.") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala index 134d16e981a15..f26b194e7a7ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala @@ -52,7 +52,7 @@ object JacksonUtils { case _ => throw new UnsupportedOperationException( - s"Unable to convert column $name of type ${dataType.simpleString} to JSON.") + s"Unable to convert column $name of type ${dataType.catalogString} to JSON.") } schema.foreach(field => verifyType(field.name, field.dataType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index 491ca005877f8..5f70e062d46c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -294,8 +294,10 @@ private[sql] object JsonInferSchema { // Both fields1 and fields2 should be sorted by name, since inferField performs sorting. // Therefore, we can take advantage of the fact that we're merging sorted lists and skip // building a hash map or performing additional sorting. - assert(isSorted(fields1), s"StructType's fields were not sorted: ${fields1.toSeq}") - assert(isSorted(fields2), s"StructType's fields were not sorted: ${fields2.toSeq}") + assert(isSorted(fields1), + s"${StructType.simpleString}'s fields were not sorted: ${fields1.toSeq}") + assert(isSorted(fields2), + s"${StructType.simpleString}'s fields were not sorted: ${fields2.toSeq}") val newFields = new java.util.ArrayList[StructField]() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index b795abea95a74..5214cdce861d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -29,7 +29,7 @@ object TypeUtils { if (dt.isInstanceOf[NumericType] || dt == NullType) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not $dt") + TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not ${dt.catalogString}") } } @@ -37,7 +37,8 @@ object TypeUtils { if (RowOrdering.isOrderable(dt)) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure(s"$caller does not support ordering on type $dt") + TypeCheckResult.TypeCheckFailure( + s"$caller does not support ordering on type ${dt.catalogString}") } } @@ -47,7 +48,7 @@ object TypeUtils { } else { return TypeCheckResult.TypeCheckFailure( s"input to $caller should all be the same type, but it's " + - types.map(_.simpleString).mkString("[", ", ", "]")) + types.map(_.catalogString).mkString("[", ", ", "]")) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 3041f44b116ea..c43cc748655e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -145,7 +145,7 @@ abstract class NumericType extends AtomicType { } -private[sql] object NumericType extends AbstractDataType { +private[spark] object NumericType extends AbstractDataType { /** * Enables matching against NumericType for expressions: * {{{ @@ -155,11 +155,12 @@ private[sql] object NumericType extends AbstractDataType { */ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] - override private[sql] def defaultConcreteType: DataType = DoubleType + override private[spark] def defaultConcreteType: DataType = DoubleType - override private[sql] def simpleString: String = "numeric" + override private[spark] def simpleString: String = "numeric" - override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType] + override private[spark] def acceptsType(other: DataType): Boolean = + other.isInstanceOf[NumericType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 38c40482fa4d9..58c75b5dc7a35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -42,7 +42,7 @@ object ArrayType extends AbstractDataType { other.isInstanceOf[ArrayType] } - override private[sql] def simpleString: String = "array" + override private[spark] def simpleString: String = "array" } /** @@ -103,7 +103,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT case a : ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] case other => - throw new IllegalArgumentException(s"Type $other does not support ordered operations") + throw new IllegalArgumentException( + s"Type ${other.catalogString} does not support ordered operations") } def compare(x: ArrayData, y: ArrayData): Int = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index dbf51c398fa47..f780ffd46a876 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -48,7 +48,8 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { } if (precision > DecimalType.MAX_PRECISION) { - throw new AnalysisException(s"DecimalType can only support precision up to 38") + throw new AnalysisException( + s"${DecimalType.simpleString} can only support precision up to ${DecimalType.MAX_PRECISION}") } // default constructor for Java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala index 2d49fe076786a..203e85e1c99bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -24,7 +24,8 @@ import org.apache.spark.annotation.InterfaceStability @InterfaceStability.Evolving object ObjectType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = - throw new UnsupportedOperationException("null literals can't be casted to ObjectType") + throw new UnsupportedOperationException( + s"null literals can't be casted to ${ObjectType.simpleString}") override private[sql] def acceptsType(other: DataType): Boolean = other match { case ObjectType(_) => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 362676b252126..b13e95f83bc58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -426,7 +426,7 @@ object StructType extends AbstractDataType { private[sql] def fromString(raw: String): StructType = { Try(DataType.fromJson(raw)).getOrElse(LegacyTypeStringParser.parse(raw)) match { case t: StructType => t - case _ => throw new RuntimeException(s"Failed parsing StructType: $raw") + case _ => throw new RuntimeException(s"Failed parsing ${StructType.simpleString}: $raw") } } @@ -528,7 +528,8 @@ object StructType extends AbstractDataType { leftType case _ => - throw new SparkException(s"Failed to merge incompatible data types $left and $right") + throw new SparkException(s"Failed to merge incompatible data types ${left.catalogString}" + + s" and ${right.catalogString}") } private[sql] def fieldsMap(fields: Array[StructField]): Map[String, StructField] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 0ce94d39e994a..f4cfed4a91594 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -521,7 +521,7 @@ class AnalysisErrorSuite extends AnalysisTest { right, joinType = Cross, condition = Some('b === 'd)) - assertAnalysisError(plan2, "EqualTo does not support ordering on type MapType" :: Nil) + assertAnalysisError(plan2, "EqualTo does not support ordering on type map" :: Nil) } test("PredicateSubQuery is used outside of a filter") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 36714bd631b0e..8eec14842c7e7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -109,17 +109,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) - assertError(EqualTo('mapField, 'mapField), "EqualTo does not support ordering on type MapType") + assertError(EqualTo('mapField, 'mapField), "EqualTo does not support ordering on type map") assertError(EqualNullSafe('mapField, 'mapField), - "EqualNullSafe does not support ordering on type MapType") + "EqualNullSafe does not support ordering on type map") assertError(LessThan('mapField, 'mapField), - "LessThan does not support ordering on type MapType") + "LessThan does not support ordering on type map") assertError(LessThanOrEqual('mapField, 'mapField), - "LessThanOrEqual does not support ordering on type MapType") + "LessThanOrEqual does not support ordering on type map") assertError(GreaterThan('mapField, 'mapField), - "GreaterThan does not support ordering on type MapType") + "GreaterThan does not support ordering on type map") assertError(GreaterThanOrEqual('mapField, 'mapField), - "GreaterThanOrEqual does not support ordering on type MapType") + "GreaterThanOrEqual does not support ordering on type map") assertError(If('intField, 'stringField, 'stringField), "type of predicate expression in If should be boolean") @@ -169,10 +169,10 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments") assertError( CreateNamedStruct(Seq(1, "a", "b", 2.0)), - "Only foldable StringType expressions are allowed to appear at odd position") + "Only foldable string expressions are allowed to appear at odd position") assertError( CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), - "Only foldable StringType expressions are allowed to appear at odd position") + "Only foldable string expressions are allowed to appear at odd position") assertError( CreateNamedStruct(Seq(Literal.create(null, StringType), "a")), "Field name should not be null") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index cb8a1fecb80a7..b4d422d8506fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -469,7 +469,7 @@ class ExpressionParserSuite extends PlanTest { Literal(BigDecimal("90912830918230182310293801923652346786").underlying())) assertEqual("123.0E-28BD", Literal(BigDecimal("123.0E-28").underlying())) assertEqual("123.08BD", Literal(BigDecimal("123.08").underlying())) - intercept("1.20E-38BD", "DecimalType can only support precision up to 38") + intercept("1.20E-38BD", "decimal can only support precision up to 38") } test("strings") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 5a86f4055dce7..fccd057e577d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -154,7 +154,7 @@ class DataTypeSuite extends SparkFunSuite { left.merge(right) }.getMessage assert(message.equals("Failed to merge fields 'b' and 'b'. " + - "Failed to merge incompatible data types FloatType and LongType")) + "Failed to merge incompatible data types float and bigint")) } test("existsRecursively") { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index d5969b55eef96..31ef090ac4b45 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -244,7 +244,7 @@ private SchemaColumnConvertNotSupportedException constructConvertNotSupportedExc return new SchemaColumnConvertNotSupportedException( Arrays.toString(descriptor.getPath()), descriptor.getType().toString(), - column.dataType().toString()); + column.dataType().catalogString()); } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index c6449cd5a16b0..b068493f2dd17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -452,7 +452,7 @@ class RelationalGroupedDataset protected[sql]( require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, "Must pass a grouped map udf") require(expr.dataType.isInstanceOf[StructType], - "The returnType of the udf must be a StructType") + s"The returnType of the udf must be a ${StructType.simpleString}") val groupingNamedExpressions = groupingExprs.map { case ne: NamedExpression => ne diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index 93c8127681b3e..533097ac399e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -47,11 +47,13 @@ object ArrowUtils { case DateType => new ArrowType.Date(DateUnit.DAY) case TimestampType => if (timeZoneId == null) { - throw new UnsupportedOperationException("TimestampType must supply timeZoneId parameter") + throw new UnsupportedOperationException( + s"${TimestampType.catalogString} must supply timeZoneId parameter") } else { new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) } - case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}") + case _ => + throw new UnsupportedOperationException(s"Unsupported data type: ${dt.catalogString}") } def fromArrowType(dt: ArrowType): DataType = dt match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 66888fce7f9f5..3de6ea8bb2577 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -68,7 +68,7 @@ object ArrowWriter { } new StructWriter(vector, children.toArray) case (dt, _) => - throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}") + throw new UnsupportedOperationException(s"Unsupported data type: ${dt.catalogString}") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index e9b150fd86095..542a10fc175c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -717,7 +717,7 @@ private[columnar] object ColumnType { case struct: StructType => STRUCT(struct) case udt: UserDefinedType[_] => apply(udt.sqlType) case other => - throw new Exception(s"Unsupported type: ${other.simpleString}") + throw new Exception(s"Unsupported type: ${other.catalogString}") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala index 82e99190ecf14..cccd6c08ae460 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -45,7 +45,7 @@ object DataSourceUtils { schema.foreach { field => if (!format.supportDataType(field.dataType, isReadPath)) { throw new AnalysisException( - s"$format data source does not support ${field.dataType.simpleString} data type.") + s"$format data source does not support ${field.dataType.catalogString} data type.") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index b81737eda475b..6cc7922396d45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -175,7 +175,7 @@ object JdbcUtils extends Logging { private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = { dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse( - throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}")) + throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.catalogString}")) } /** @@ -480,7 +480,7 @@ object JdbcUtils extends Logging { case LongType if metadata.contains("binarylong") => throw new IllegalArgumentException(s"Unsupported array element " + - s"type ${dt.simpleString} based on binary") + s"type ${dt.catalogString} based on binary") case ArrayType(_, _) => throw new IllegalArgumentException("Nested arrays unsupported") @@ -494,7 +494,7 @@ object JdbcUtils extends Logging { array => new GenericArrayData(elementConversion.apply(array.getArray))) row.update(pos, array) - case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}") + case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.catalogString}") } private def nullSafeConvert[T](input: T, f: T => Any): Any = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 4f44ae4fa1d71..c4c3b3053a3b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -98,7 +98,7 @@ private[orc] object OrcFilters { case DateType => PredicateLeaf.Type.DATE case TimestampType => PredicateLeaf.Type.TIMESTAMP case _: DecimalType => PredicateLeaf.Type.DECIMAL - case _ => throw new UnsupportedOperationException(s"DataType: $dataType") + case _ => throw new UnsupportedOperationException(s"DataType: ${dataType.catalogString}") } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 460194ba61c8b..b404cfa61f41e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -104,7 +104,7 @@ object OrcUtils extends Logging { // This is a ORC file written by Hive, no field names in the physical schema, assume the // physical schema maps to the data scheme by index. assert(orcFieldNames.length <= dataSchema.length, "The given data schema " + - s"${dataSchema.simpleString} has less fields than the actual ORC physical schema, " + + s"${dataSchema.catalogString} has less fields than the actual ORC physical schema, " + "no idea which columns were dropped, fail to read.") Some(requiredSchema.fieldNames.map { name => val index = dataSchema.fieldIndex(name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index c61be077d309f..70f42f2c4ad79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -555,7 +555,7 @@ class SparkToParquetSchemaConverter( convertField(field.copy(dataType = udt.sqlType)) case _ => - throw new AnalysisException(s"Unsupported data type $field.dataType") + throw new AnalysisException(s"Unsupported data type ${field.dataType.catalogString}") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index cab00251622b8..dfcf6c14fbef1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -281,7 +281,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi schema.filter(f => normalizedPartitionCols.contains(f.name)).map(_.dataType).foreach { case _: AtomicType => // OK - case other => failAnalysis(s"Cannot use ${other.simpleString} for partition column") + case other => failAnalysis(s"Cannot use ${other.catalogString} for partition column") } normalizedPartitionCols @@ -307,7 +307,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi normalizedBucketSpec.sortColumnNames.map(schema(_)).map(_.dataType).foreach { case dt if RowOrdering.isOrderable(dt) => // OK - case other => failAnalysis(s"Cannot use ${other.simpleString} for sorting column") + case other => failAnalysis(s"Cannot use ${other.catalogString} for sorting column") } Some(normalizedBucketSpec) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 685d5841ab551..bea652cc33076 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -157,7 +157,7 @@ object StatFunctions extends Logging { cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) => require(data.nonEmpty, s"Couldn't find column with name $name") require(data.get.dataType.isInstanceOf[NumericType], s"Currently $functionName calculation " + - s"for columns with dataType ${data.get.dataType} not supported.") + s"for columns with dataType ${data.get.dataType.catalogString} not supported.") } val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) df.select(columns: _*).queryExecution.toRdd.treeAggregate(new CovarianceCounter)( diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 3d49323751a10..827931d74138d 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -120,7 +120,7 @@ select to_json(named_struct('a', 1, 'b', 2), map('mode', 1)) struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 +A type of keys and values in map() must be string, but got map;; line 1 pos 7 -- !query 12 @@ -216,7 +216,7 @@ select from_json('{"a":1}', 'a INT', map('mode', 1)) struct<> -- !query 20 output org.apache.spark.sql.AnalysisException -A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 +A type of keys and values in map() must be string, but got map;; line 1 pos 7 -- !query 21 diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index b8c91dc8b59a4..7f301614523b2 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -147,7 +147,7 @@ struct<> -- !query 15 output org.apache.spark.sql.catalyst.parser.ParseException -DecimalType can only support precision up to 38 +decimal can only support precision up to 38 == SQL == select 1234567890123456789012345678901234567890 @@ -159,7 +159,7 @@ struct<> -- !query 16 output org.apache.spark.sql.catalyst.parser.ParseException -DecimalType can only support precision up to 38 +decimal can only support precision up to 38 == SQL == select 1234567890123456789012345678901234567890.0 @@ -379,7 +379,7 @@ struct<> -- !query 39 output org.apache.spark.sql.catalyst.parser.ParseException -DecimalType can only support precision up to 38(line 1, pos 7) +decimal can only support precision up to 38(line 1, pos 7) == SQL == select 1.20E-38BD diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index a7ce952b70ac1..9f9af89570789 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -312,14 +312,14 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors") .write.mode("overwrite").csv(csvDir) }.getMessage - assert(msg.contains("CSV data source does not support mydensevector data type")) + assert(msg.contains("CSV data source does not support array data type")) msg = intercept[AnalysisException] { val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil) spark.range(1).write.mode("overwrite").csv(csvDir) spark.read.schema(schema).csv(csvDir).collect() }.getMessage - assert(msg.contains("CSV data source does not support mydensevector data type.")) + assert(msg.contains("CSV data source does not support array data type.")) } } @@ -339,7 +339,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo sql("select testType()").write.format(format).mode("overwrite").save(tempDir) }.getMessage assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support interval data type.")) + .contains(s"$format data source does not support calendarinterval data type.")) } // read path @@ -358,7 +358,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo spark.read.schema(schema).format(format).load(tempDir).collect() }.getMessage assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support interval data type.")) + .contains(s"$format data source does not support calendarinterval data type.")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 9d3dfae348beb..368e52cfbda9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -430,9 +430,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { val col = spark.read.parquet(file).schema.fields.filter(_.name.equals("a")) assert(col.length == 1) if (col(0).dataType == StringType) { - assert(errMsg.contains("Column: [a], Expected: IntegerType, Found: BINARY")) + assert(errMsg.contains("Column: [a], Expected: int, Found: BINARY")) } else { - assert(errMsg.endsWith("Column: [a], Expected: StringType, Found: INT32")) + assert(errMsg.endsWith("Column: [a], Expected: string, Found: INT32")) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 7f28fc40b4469..5cc1047fc067b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -785,9 +785,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // schema we read back is different(ignore case and nullability) from the one in table // properties which was written when creating table, we should respect the table schema // from hive. - logWarning(s"The table schema given by Hive metastore(${table.schema.simpleString}) is " + + logWarning(s"The table schema given by Hive metastore(${table.schema.catalogString}) is " + "different from the schema when this table was created by Spark SQL" + - s"(${schemaFromTableProps.simpleString}). We have to fall back to the table schema " + + s"(${schemaFromTableProps.catalogString}). We have to fall back to the table schema " + "from Hive metastore which is not case preserving.") hiveTable.copy(schemaPreservesCase = false) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 7dcaf170f9693..6052486c47da2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -78,9 +78,9 @@ case class HiveTableScanExec( // Bind all partition key attribute references in the partition pruning predicate for later // evaluation. private lazy val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => - require( - pred.dataType == BooleanType, - s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") + require(pred.dataType == BooleanType, + s"Data type of predicate $pred must be ${BooleanType.catalogString} rather than " + + s"${pred.dataType.catalogString}.") BindReferences.bindReference(pred, relation.partitionCols) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala index fb4957ed943a7..d84f9a3828207 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -155,7 +155,7 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { spark.udf.register("testType", () => new IntervalData()) sql("select testType()").write.mode("overwrite").orc(orcDir) }.getMessage - assert(msg.contains("ORC data source does not support interval data type.")) + assert(msg.contains("ORC data source does not support calendarinterval data type.")) // read path msg = intercept[AnalysisException] { @@ -170,7 +170,7 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { spark.range(1).write.mode("overwrite").orc(orcDir) spark.read.schema(schema).orc(orcDir).collect() }.getMessage - assert(msg.contains("ORC data source does not support interval data type.")) + assert(msg.contains("ORC data source does not support calendarinterval data type.")) } } } From 2b91d9918c8eaec6c32a502e2f08b63c475d3335 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 19 Jul 2018 23:52:53 -0700 Subject: [PATCH 1172/2461] [SPARK-24424][SQL] Support ANSI-SQL compliant syntax for GROUPING SET ## What changes were proposed in this pull request? Enhances the parser and analyzer to support ANSI compliant syntax for GROUPING SET. As part of this change we derive the grouping expressions from user supplied groupings in the grouping sets clause. ```SQL SELECT c1, c2, max(c3) FROM t1 GROUP BY GROUPING SETS ((c1), (c1, c2)) ``` ## How was this patch tested? Added tests in SQLQueryTestSuite and ResolveGroupingAnalyticsSuite. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Dilip Biswal Closes #21813 from dilipbiswal/spark-24424. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 1 + .../sql/catalyst/analysis/Analyzer.scala | 22 ++- .../ResolveGroupingAnalyticsSuite.scala | 28 ++++ .../sql-tests/inputs/grouping_set.sql | 36 +++++ .../sql-tests/results/grouping_set.sql.out | 126 +++++++++++++++++- 5 files changed, 210 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 1b43874af6feb..2aca10f1bfbc7 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -406,6 +406,7 @@ aggregation WITH kind=ROLLUP | WITH kind=CUBE | kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')? + | GROUP BY kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')' ; groupingSet diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7c5504d90433f..957c468d5f8ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -442,17 +442,35 @@ class Analyzer( child: LogicalPlan): LogicalPlan = { val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() + // In case of ANSI-SQL compliant syntax for GROUPING SETS, groupByExprs is optional and + // can be null. In such case, we derive the groupByExprs from the user supplied values for + // grouping sets. + val finalGroupByExpressions = if (groupByExprs == Nil) { + selectedGroupByExprs.flatten.foldLeft(Seq.empty[Expression]) { (result, currentExpr) => + // Only unique expressions are included in the group by expressions and is determined + // based on their semantic equality. Example. grouping sets ((a * b), (b * a)) results + // in grouping expression (a * b) + if (result.find(_.semanticEquals(currentExpr)).isDefined) { + result + } else { + result :+ currentExpr + } + } + } else { + groupByExprs + } + // Expand works by setting grouping expressions to null as determined by the // `selectedGroupByExprs`. To prevent these null values from being used in an aggregate // instead of the original value we need to create new aliases for all group by expressions // that will only be used for the intended purpose. - val groupByAliases = constructGroupByAlias(groupByExprs) + val groupByAliases = constructGroupByAlias(finalGroupByExpressions) val expand = constructExpand(selectedGroupByExprs, child, groupByAliases, gid) val groupingAttrs = expand.output.drop(child.output.length) val aggregations = constructAggregateExprs( - groupByExprs, aggregationExprs, groupByAliases, groupingAttrs, gid) + finalGroupByExpressions, aggregationExprs, groupByAliases, groupingAttrs, gid) Aggregate(groupingAttrs, aggregations, expand) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala index 553b1598e7750..8da4d7e3aa372 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala @@ -91,6 +91,34 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { assertAnalysisError(originalPlan3, Seq("doesn't show up in the GROUP BY list")) } + test("grouping sets with no explicit group by expressions") { + val originalPlan = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), + Nil, r1, + Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)))) + val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")), + Expand( + Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)), + Seq(a, b, c, a, b, gid), + Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) + checkAnalysis(originalPlan, expected) + + // Computation of grouping expression should remove duplicate expression based on their + // semantics (semanticEqual). + val originalPlan2 = GroupingSets(Seq(Seq(Multiply(unresolved_a, Literal(2))), + Seq(Multiply(Literal(2), unresolved_a), unresolved_b)), Nil, r1, + Seq(UnresolvedAlias(Multiply(unresolved_a, Literal(2))), + unresolved_b, UnresolvedAlias(count(unresolved_c)))) + + val resultPlan = getAnalyzer(true).executeAndCheck(originalPlan2) + val gExpressions = resultPlan.asInstanceOf[Aggregate].groupingExpressions + assert(gExpressions.size == 3) + val firstGroupingExprAttrName = + gExpressions(0).asInstanceOf[AttributeReference].name.replaceAll("#[0-9]*", "#0") + assert(firstGroupingExprAttrName == "(a#0 * 2)") + assert(gExpressions(1).asInstanceOf[AttributeReference].name == "b") + assert(gExpressions(2).asInstanceOf[AttributeReference].name == VirtualColumn.groupingIdName) + } + test("cube") { val originalPlan = Aggregate(Seq(Cube(Seq(unresolved_a, unresolved_b))), Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1) diff --git a/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql b/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql index 3594283505280..6bbde9f38d657 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql @@ -13,5 +13,41 @@ SELECT a, b, c, count(d) FROM grouping GROUP BY a, b, c GROUPING SETS ((a)); -- SPARK-17849: grouping set throws NPE #3 SELECT a, b, c, count(d) FROM grouping GROUP BY a, b, c GROUPING SETS ((c)); +-- Group sets without explicit group by +SELECT c1, sum(c2) FROM (VALUES ('x', 10, 0), ('y', 20, 0)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1); +-- Group sets without group by and with grouping +SELECT c1, sum(c2), grouping(c1) FROM (VALUES ('x', 10, 0), ('y', 20, 0)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1); + +-- Mutiple grouping within a grouping set +SELECT c1, c2, Sum(c3), grouping__id +FROM (VALUES ('x', 'a', 10), ('y', 'b', 20) ) AS t (c1, c2, c3) +GROUP BY GROUPING SETS ( ( c1 ), ( c2 ) ) +HAVING GROUPING__ID > 1; + +-- Group sets without explicit group by +SELECT grouping(c1) FROM (VALUES ('x', 'a', 10), ('y', 'b', 20)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1,c2); + +-- Mutiple grouping within a grouping set +SELECT -c1 AS c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS ((c1), (c1, c2)); + +-- complex expression in grouping sets +SELECT a + b, b, sum(c) FROM (VALUES (1,1,1),(2,2,2)) AS t(a,b,c) GROUP BY GROUPING SETS ( (a + b), (b)); + +-- complex expression in grouping sets +SELECT a + b, b, sum(c) FROM (VALUES (1,1,1),(2,2,2)) AS t(a,b,c) GROUP BY GROUPING SETS ( (a + b), (b + a), (b)); + +-- more query constructs with grouping sets +SELECT c1 AS col1, c2 AS col2 +FROM (VALUES (1, 2), (3, 2)) t(c1, c2) +GROUP BY GROUPING SETS ( ( c1 ), ( c1, c2 ) ) +HAVING col2 IS NOT NULL +ORDER BY -col1; + +-- negative tests - must have at least one grouping expression +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH ROLLUP; + +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH CUBE; + +SELECT c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS (()); diff --git a/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out b/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out index edb38a52b7514..34ab09c5e3bba 100644 --- a/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 4 +-- Number of queries: 15 -- !query 0 @@ -40,3 +40,127 @@ struct NULL NULL 3 1 NULL NULL 6 1 NULL NULL 9 1 + + +-- !query 4 +SELECT c1, sum(c2) FROM (VALUES ('x', 10, 0), ('y', 20, 0)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1) +-- !query 4 schema +struct +-- !query 4 output +x 10 +y 20 + + +-- !query 5 +SELECT c1, sum(c2), grouping(c1) FROM (VALUES ('x', 10, 0), ('y', 20, 0)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1) +-- !query 5 schema +struct +-- !query 5 output +x 10 0 +y 20 0 + + +-- !query 6 +SELECT c1, c2, Sum(c3), grouping__id +FROM (VALUES ('x', 'a', 10), ('y', 'b', 20) ) AS t (c1, c2, c3) +GROUP BY GROUPING SETS ( ( c1 ), ( c2 ) ) +HAVING GROUPING__ID > 1 +-- !query 6 schema +struct +-- !query 6 output +NULL a 10 2 +NULL b 20 2 + + +-- !query 7 +SELECT grouping(c1) FROM (VALUES ('x', 'a', 10), ('y', 'b', 20)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1,c2) +-- !query 7 schema +struct +-- !query 7 output +0 +0 +1 +1 + + +-- !query 8 +SELECT -c1 AS c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS ((c1), (c1, c2)) +-- !query 8 schema +struct +-- !query 8 output +-1 +-1 +-3 +-3 + + +-- !query 9 +SELECT a + b, b, sum(c) FROM (VALUES (1,1,1),(2,2,2)) AS t(a,b,c) GROUP BY GROUPING SETS ( (a + b), (b)) +-- !query 9 schema +struct<(a + b):int,b:int,sum(c):bigint> +-- !query 9 output +2 NULL 1 +4 NULL 2 +NULL 1 1 +NULL 2 2 + + +-- !query 10 +SELECT a + b, b, sum(c) FROM (VALUES (1,1,1),(2,2,2)) AS t(a,b,c) GROUP BY GROUPING SETS ( (a + b), (b + a), (b)) +-- !query 10 schema +struct<(a + b):int,b:int,sum(c):bigint> +-- !query 10 output +2 NULL 2 +4 NULL 4 +NULL 1 1 +NULL 2 2 + + +-- !query 11 +SELECT c1 AS col1, c2 AS col2 +FROM (VALUES (1, 2), (3, 2)) t(c1, c2) +GROUP BY GROUPING SETS ( ( c1 ), ( c1, c2 ) ) +HAVING col2 IS NOT NULL +ORDER BY -col1 +-- !query 11 schema +struct +-- !query 11 output +3 2 +1 2 + + +-- !query 12 +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH ROLLUP +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.catalyst.parser.ParseException + +extraneous input 'ROLLUP' expecting (line 1, pos 53) + +== SQL == +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH ROLLUP +-----------------------------------------------------^^^ + + +-- !query 13 +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH CUBE +-- !query 13 schema +struct<> +-- !query 13 output +org.apache.spark.sql.catalyst.parser.ParseException + +extraneous input 'CUBE' expecting (line 1, pos 53) + +== SQL == +SELECT a, b, c, count(d) FROM grouping GROUP BY WITH CUBE +-----------------------------------------------------^^^ + + +-- !query 14 +SELECT c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS (()) +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +expression '`c1`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.; From 0ab07b357b5ddae29f815734237013c21d2d2b4e Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 20 Jul 2018 17:53:14 +0800 Subject: [PATCH 1173/2461] [SPARK-24868][PYTHON] add sequence function in Python ## What changes were proposed in this pull request? Add ```sequence``` in functions.py ## How was this patch tested? Add doctest. Author: Huaxin Gao Closes #21820 from huaxingao/spark-24868. --- python/pyspark/sql/functions.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 5ef73987a66a6..f2e66337257be 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2551,6 +2551,28 @@ def map_concat(*cols): return Column(jc) +@since(2.4) +def sequence(start, stop, step=None): + """ + Generate a sequence of integers from `start` to `stop`, incrementing by `step`. + If `step` is not set, incrementing by 1 if `start` is less than or equal to `stop`, + otherwise -1. + + >>> df1 = spark.createDataFrame([(-2, 2)], ('C1', 'C2')) + >>> df1.select(sequence('C1', 'C2').alias('r')).collect() + [Row(r=[-2, -1, 0, 1, 2])] + >>> df2 = spark.createDataFrame([(4, -4, -2)], ('C1', 'C2', 'C3')) + >>> df2.select(sequence('C1', 'C2', 'C3').alias('r')).collect() + [Row(r=[4, 2, 0, -2, -4])] + """ + sc = SparkContext._active_spark_context + if step is None: + return Column(sc._jvm.functions.sequence(_to_java_column(start), _to_java_column(stop))) + else: + return Column(sc._jvm.functions.sequence( + _to_java_column(start), _to_java_column(stop), _to_java_column(step))) + + # ---------------------------- User Defined Function ---------------------------------- class PandasUDFType(object): From 7b6d36bc9ef7a0af5e7461bba31c0e2518e3ce8d Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 20 Jul 2018 20:08:42 +0800 Subject: [PATCH 1174/2461] [SPARK-24871][SQL] Refactor Concat and MapConcat to avoid creating concatenator object for each row. ## What changes were proposed in this pull request? Refactor `Concat` and `MapConcat` to: - avoid creating concatenator object for each row. - make `Concat` handle `containsNull` properly. - make `Concat` shortcut if `null` child is found. ## How was this patch tested? Added some tests and existing tests. Author: Takuya UESHIN Closes #21824 from ueshin/issues/SPARK-24871/refactor_concat_mapconcat. --- .../expressions/collectionOperations.scala | 299 +++++++++++------- .../CollectionExpressionsSuite.scala | 15 + 2 files changed, 192 insertions(+), 122 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 92635417e9666..f438748d9a4ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -571,16 +571,25 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres |$mapDataClass ${ev.value} = null; """.stripMargin - val assignments = mapCodes.zipWithIndex.map { case (m, i) => - s""" - |if (!$hasNullName) { - | ${m.code} - | $argsName[$i] = ${m.value}; - | if (${m.isNull}) { - | $hasNullName = true; - | } - |} - """.stripMargin + val assignments = mapCodes.zip(children.map(_.nullable)).zipWithIndex.map { + case ((m, true), i) => + s""" + |if (!$hasNullName) { + | ${m.code} + | if (!${m.isNull}) { + | $argsName[$i] = ${m.value}; + | } else { + | $hasNullName = true; + | } + |} + """.stripMargin + case ((m, false), i) => + s""" + |if (!$hasNullName) { + | ${m.code} + | $argsName[$i] = ${m.value}; + |} + """.stripMargin } val codes = ctx.splitExpressionsWithCurrentInputs( @@ -601,17 +610,21 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres val finKeysName = ctx.freshName("finalKeys") val finValsName = ctx.freshName("finalValues") - val keyConcatenator = if (CodeGenerator.isPrimitiveType(keyType)) { + val keyConcat = if (CodeGenerator.isPrimitiveType(keyType)) { genCodeForPrimitiveArrays(ctx, keyType, false) } else { genCodeForNonPrimitiveArrays(ctx, keyType) } - val valueConcatenator = if (CodeGenerator.isPrimitiveType(valueType)) { - genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull) - } else { - genCodeForNonPrimitiveArrays(ctx, valueType) - } + val valueConcat = + if (valueType.sameType(keyType) && + !(CodeGenerator.isPrimitiveType(valueType) && dataType.valueContainsNull)) { + keyConcat + } else if (CodeGenerator.isPrimitiveType(valueType)) { + genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull) + } else { + genCodeForNonPrimitiveArrays(ctx, valueType) + } val keyArgsName = ctx.freshName("keyArgs") val valArgsName = ctx.freshName("valArgs") @@ -633,9 +646,9 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres | $numElementsName + " elements due to exceeding the map size limit " + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); | } - | $arrayDataClass $finKeysName = $keyConcatenator.concat($keyArgsName, + | $arrayDataClass $finKeysName = $keyConcat($keyArgsName, | (int) $numElementsName); - | $arrayDataClass $finValsName = $valueConcatenator.concat($valArgsName, + | $arrayDataClass $finValsName = $valueConcat($valArgsName, | (int) $numElementsName); | ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName); |} @@ -677,20 +690,23 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres setterCode1 } - s""" - |new Object() { - | public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) { - | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} - | int $counter = 0; - | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < $argsName[y].numElements(); z++) { - | $setterCode - | $counter++; - | } - | } - | return $arrayData; - | } - |}""".stripMargin.stripPrefix("\n") + val concat = ctx.freshName("concat") + val concatDef = + s""" + |private ArrayData $concat(ArrayData[] $argsName, int $numElemName) { + | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < $argsName[y].numElements(); z++) { + | $setterCode + | $counter++; + | } + | } + | return $arrayData; + |} + """.stripMargin + + ctx.addNewFunction(concat, concatDef) } private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { @@ -700,20 +716,23 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres val argsName = ctx.freshName("args") val numElemName = ctx.freshName("numElements") - s""" - |new Object() { - | public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) {; - | Object[] $arrayData = new Object[$numElemName]; - | int $counter = 0; - | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < $argsName[y].numElements(); z++) { - | $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")}; - | $counter++; - | } - | } - | return new $genericArrayClass($arrayData); - | } - |}""".stripMargin.stripPrefix("\n") + val concat = ctx.freshName("concat") + val concatDef = + s""" + |private ArrayData $concat(ArrayData[] $argsName, int $numElemName) { + | Object[] $arrayData = new Object[$numElemName]; + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < $argsName[y].numElements(); z++) { + | $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")}; + | $counter++; + | } + | } + | return new $genericArrayClass($arrayData); + |} + """.stripMargin + + ctx.addNewFunction(concat, concatDef) } override def prettyName: String = "map_concat" @@ -2270,39 +2289,67 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evals = children.map(_.genCode(ctx)) val args = ctx.freshName("args") + val hasNull = ctx.freshName("hasNull") - val inputs = evals.zipWithIndex.map { case (eval, index) => - s""" - ${eval.code} - if (!${eval.isNull}) { - $args[$index] = ${eval.value}; - } - """ + val inputs = evals.zip(children.map(_.nullable)).zipWithIndex.map { + case ((eval, true), index) => + s""" + |if (!$hasNull) { + | ${eval.code} + | if (!${eval.isNull}) { + | $args[$index] = ${eval.value}; + | } else { + | $hasNull = true; + | } + |} + """.stripMargin + case ((eval, false), index) => + s""" + |if (!$hasNull) { + | ${eval.code} + | $args[$index] = ${eval.value}; + |} + """.stripMargin } - val (concatenator, initCode) = dataType match { + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = inputs, + funcName = "valueConcat", + extraArguments = (s"$javaType[]", args) :: ("boolean", hasNull) :: Nil, + returnType = "boolean", + makeSplitFunction = body => + s""" + |$body + |return $hasNull; + """.stripMargin, + foldFunctions = _.map(funcCall => s"$hasNull = $funcCall;").mkString("\n") + ) + + val (concat, initCode) = dataType match { case BinaryType => - (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") + (s"${classOf[ByteArray].getName}.concat", s"byte[][] $args = new byte[${evals.length}][];") case StringType => - ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") - case ArrayType(elementType, _) => - val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { - genCodeForPrimitiveArrays(ctx, elementType) + ("UTF8String.concat", s"UTF8String[] $args = new UTF8String[${evals.length}];") + case ArrayType(elementType, containsNull) => + val concat = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForPrimitiveArrays(ctx, elementType, containsNull) } else { genCodeForNonPrimitiveArrays(ctx, elementType) } - (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") + (concat, s"ArrayData[] $args = new ArrayData[${evals.length}];") } - val codes = ctx.splitExpressionsWithCurrentInputs( - expressions = inputs, - funcName = "valueConcat", - extraArguments = (s"$javaType[]", args) :: Nil) - ev.copy(code""" - $initCode - $codes - $javaType ${ev.value} = $concatenator.concat($args); - boolean ${ev.isNull} = ${ev.value} == null; - """) + + ev.copy(code = + code""" + |boolean $hasNull = false; + |$initCode + |$codes + |$javaType ${ev.value} = null; + |if (!$hasNull) { + | ${ev.value} = $concat($args); + |} + |boolean ${ev.isNull} = ${ev.value} == null; + """.stripMargin) } private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { @@ -2322,19 +2369,10 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio (code, numElements) } - private def nullArgumentProtection() : String = { - if (nullable) { - s""" - |for (int z = 0; z < ${children.length}; z++) { - | if (args[z] == null) return null; - |} - """.stripMargin - } else { - "" - } - } - - private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + private def genCodeForPrimitiveArrays( + ctx: CodegenContext, + elementType: DataType, + checkForNull: Boolean): String = { val counter = ctx.freshName("counter") val arrayData = ctx.freshName("arrayData") @@ -2342,29 +2380,44 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - s""" - |new Object() { - | public ArrayData concat($javaType[] args) { - | ${nullArgumentProtection()} - | $numElemCode - | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} - | int $counter = 0; - | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < args[y].numElements(); z++) { - | if (args[y].isNullAt(z)) { - | $arrayData.setNullAt($counter); - | } else { - | $arrayData.set$primitiveValueTypeName( - | $counter, - | ${CodeGenerator.getValue(s"args[y]", elementType, "z")} - | ); - | } - | $counter++; - | } - | } - | return $arrayData; - | } - |}""".stripMargin.stripPrefix("\n") + val setterCode = + s""" + |$arrayData.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue(s"args[y]", elementType, "z")} + |); + """.stripMargin + + val nullSafeSetterCode = if (checkForNull) { + s""" + |if (args[y].isNullAt(z)) { + | $arrayData.setNullAt($counter); + |} else { + | $setterCode + |} + """.stripMargin + } else { + setterCode + } + + val concat = ctx.freshName("concat") + val concatDef = + s""" + |private ArrayData $concat(ArrayData[] args) { + | $numElemCode + | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < args[y].numElements(); z++) { + | $nullSafeSetterCode + | $counter++; + | } + | } + | return $arrayData; + |} + """.stripMargin + + ctx.addNewFunction(concat, concatDef) } private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { @@ -2374,22 +2427,24 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) - s""" - |new Object() { - | public ArrayData concat($javaType[] args) { - | ${nullArgumentProtection()} - | $numElemCode - | Object[] $arrayData = new Object[(int)$numElemName]; - | int $counter = 0; - | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < args[y].numElements(); z++) { - | $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")}; - | $counter++; - | } - | } - | return new $genericArrayClass($arrayData); - | } - |}""".stripMargin.stripPrefix("\n") + val concat = ctx.freshName("concat") + val concatDef = + s""" + |private ArrayData $concat(ArrayData[] args) { + | $numElemCode + | Object[] $arrayData = new Object[(int)$numElemName]; + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < args[y].numElements(); z++) { + | $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")}; + | $counter++; + | } + | } + | return new $genericArrayClass($arrayData); + |} + """.stripMargin + + ctx.addNewFunction(concat, concatDef) } override def toString: String = s"concat(${children.mkString(", ")})" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index f1e3bd091565d..c7f0da71e1440 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -125,6 +125,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper valueContainsNull = false)) val m12 = Literal.create(Map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType, valueContainsNull = false)) + val m13 = Literal.create(Map(1 -> 2, 3 -> 4), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val m14 = Literal.create(Map(5 -> 6), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val m15 = Literal.create(Map(7 -> null), + MapType(IntegerType, IntegerType, valueContainsNull = true)) val mNull = Literal.create(null, MapType(StringType, StringType)) // overlapping maps @@ -188,6 +194,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ) ) + // both keys and value are primitive and valueContainsNull = false + checkEvaluation(MapConcat(Seq(m13, m14)), Map(1 -> 2, 3 -> 4, 5 -> 6)) + + // both keys and value are primitive and valueContainsNull = true + checkEvaluation(MapConcat(Seq(m13, m15)), Map(1 -> 2, 3 -> 4, 7 -> null)) + // null map checkEvaluation(MapConcat(Seq(m0, mNull)), null) checkEvaluation(MapConcat(Seq(mNull, m0)), null) @@ -1121,6 +1133,9 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)) assert(Concat(Seq(aa0, aa2)).dataType === ArrayType(ArrayType(StringType, containsNull = true), containsNull = true)) + + // force split expressions for input in generated code + checkEvaluation(Concat(Seq.fill(100)(ai0)), Seq.fill(100)(Seq(1, 2, 3)).flatten) } test("Flatten") { From 20ce1a8f8b2d8d8b41afdd4a8b498b6502aa5e24 Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Fri, 20 Jul 2018 07:55:58 -0500 Subject: [PATCH 1175/2461] [SPARK-24551][K8S] Add integration tests for secrets ## What changes were proposed in this pull request? - Adds integration tests for env and mount secrets. ## How was this patch tested? Manually by checking that secrets were added to the containers and by tuning the tests. ![image](https://user-images.githubusercontent.com/7945591/42968472-fee3740a-8bab-11e8-9eac-573f67d861fc.png) Author: Stavros Kontopoulos Closes #21652 from skonto/add-secret-its. --- bin/docker-image-tool.sh | 2 +- .../k8s/integrationtest/BasicTestsSuite.scala | 106 +++++++++++ .../k8s/integrationtest/KubernetesSuite.scala | 177 +++--------------- .../integrationtest/PythonTestsSuite.scala | 83 ++++++++ .../integrationtest/SecretsTestsSuite.scala | 122 ++++++++++++ .../{config.scala => TestConfig.scala} | 2 +- .../{constants.scala => TestConstants.scala} | 2 +- 7 files changed, 335 insertions(+), 159 deletions(-) create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SecretsTestsSuite.scala rename resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/{config.scala => TestConfig.scala} (98%) rename resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/{constants.scala => TestConstants.scala} (97%) diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index f36fb43692cf4..cd22e75402f56 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -135,7 +135,7 @@ BASEDOCKERFILE= PYDOCKERFILE= NOCACHEARG= BUILD_PARAMS= -while getopts f:mr:t:n:b: option +while getopts f:p:mr:t:n:b: option do case "${option}" in diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala new file mode 100644 index 0000000000000..4e749c40563dc --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import io.fabric8.kubernetes.api.model.Pod + +import org.apache.spark.launcher.SparkLauncher + +private[spark] trait BasicTestsSuite { k8sSuite: KubernetesSuite => + + import BasicTestsSuite._ + import KubernetesSuite.k8sTestTag + + test("Run SparkPi with no resources", k8sTestTag) { + runSparkPiAndVerifyCompletion() + } + + test("Run SparkPi with a very long application name.", k8sTestTag) { + sparkAppConf.set("spark.app.name", "long" * 40) + runSparkPiAndVerifyCompletion() + } + + test("Use SparkLauncher.NO_RESOURCE", k8sTestTag) { + sparkAppConf.setJars(Seq(containerLocalSparkDistroExamplesJar)) + runSparkPiAndVerifyCompletion( + appResource = SparkLauncher.NO_RESOURCE) + } + + test("Run SparkPi with a master URL without a scheme.", k8sTestTag) { + val url = kubernetesTestComponents.kubernetesClient.getMasterUrl + val k8sMasterUrl = if (url.getPort < 0) { + s"k8s://${url.getHost}" + } else { + s"k8s://${url.getHost}:${url.getPort}" + } + sparkAppConf.set("spark.master", k8sMasterUrl) + runSparkPiAndVerifyCompletion() + } + + test("Run SparkPi with an argument.", k8sTestTag) { + runSparkPiAndVerifyCompletion(appArgs = Array("5")) + } + + test("Run SparkPi with custom labels, annotations, and environment variables.", k8sTestTag) { + sparkAppConf + .set("spark.kubernetes.driver.label.label1", "label1-value") + .set("spark.kubernetes.driver.label.label2", "label2-value") + .set("spark.kubernetes.driver.annotation.annotation1", "annotation1-value") + .set("spark.kubernetes.driver.annotation.annotation2", "annotation2-value") + .set("spark.kubernetes.driverEnv.ENV1", "VALUE1") + .set("spark.kubernetes.driverEnv.ENV2", "VALUE2") + .set("spark.kubernetes.executor.label.label1", "label1-value") + .set("spark.kubernetes.executor.label.label2", "label2-value") + .set("spark.kubernetes.executor.annotation.annotation1", "annotation1-value") + .set("spark.kubernetes.executor.annotation.annotation2", "annotation2-value") + .set("spark.executorEnv.ENV1", "VALUE1") + .set("spark.executorEnv.ENV2", "VALUE2") + + runSparkPiAndVerifyCompletion( + driverPodChecker = (driverPod: Pod) => { + doBasicDriverPodCheck(driverPod) + checkCustomSettings(driverPod) + }, + executorPodChecker = (executorPod: Pod) => { + doBasicExecutorPodCheck(executorPod) + checkCustomSettings(executorPod) + }) + } + + test("Run extraJVMOptions check on driver", k8sTestTag) { + sparkAppConf + .set("spark.driver.extraJavaOptions", "-Dspark.test.foo=spark.test.bar") + runSparkJVMCheckAndVerifyCompletion( + expectedJVMValue = Seq("(spark.test.foo,spark.test.bar)")) + } + + test("Run SparkRemoteFileTest using a remote data file", k8sTestTag) { + sparkAppConf + .set("spark.files", REMOTE_PAGE_RANK_DATA_FILE) + runSparkRemoteCheckAndVerifyCompletion(appArgs = Array(REMOTE_PAGE_RANK_FILE_NAME)) + } +} + +private[spark] object BasicTestsSuite { + val SPARK_PAGE_RANK_MAIN_CLASS: String = "org.apache.spark.examples.SparkPageRank" + val CONTAINER_LOCAL_FILE_DOWNLOAD_PATH = "/var/spark-data/spark-files" + val CONTAINER_LOCAL_DOWNLOADED_PAGE_RANK_DATA_FILE = + s"$CONTAINER_LOCAL_FILE_DOWNLOAD_PATH/pagerank_data.txt" + val REMOTE_PAGE_RANK_DATA_FILE = + "https://storage.googleapis.com/spark-k8s-integration-tests/files/pagerank_data.txt" + val REMOTE_PAGE_RANK_FILE_NAME = "pagerank_data.txt" +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index daabfaaac8c7e..95694aa93d5b5 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -29,25 +29,25 @@ import org.scalatest.time.{Minutes, Seconds, Span} import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.integrationtest.TestConfig._ import org.apache.spark.deploy.k8s.integrationtest.backend.{IntegrationTestBackend, IntegrationTestBackendFactory} -import org.apache.spark.deploy.k8s.integrationtest.config._ -import org.apache.spark.launcher.SparkLauncher private[spark] class KubernetesSuite extends SparkFunSuite - with BeforeAndAfterAll with BeforeAndAfter { + with BeforeAndAfterAll with BeforeAndAfter with BasicTestsSuite with SecretsTestsSuite + with PythonTestsSuite { import KubernetesSuite._ private var testBackend: IntegrationTestBackend = _ private var sparkHomeDir: Path = _ - private var kubernetesTestComponents: KubernetesTestComponents = _ - private var sparkAppConf: SparkAppConf = _ private var image: String = _ private var pyImage: String = _ - private var containerLocalSparkDistroExamplesJar: String = _ - private var appLocator: String = _ private var driverPodName: String = _ - private val k8sTestTag = Tag("k8s") + + protected var kubernetesTestComponents: KubernetesTestComponents = _ + protected var sparkAppConf: SparkAppConf = _ + protected var containerLocalSparkDistroExamplesJar: String = _ + protected var appLocator: String = _ override def beforeAll(): Unit = { // The scalatest-maven-plugin gives system properties that are referenced but not set null @@ -103,127 +103,7 @@ private[spark] class KubernetesSuite extends SparkFunSuite deleteDriverPod() } - test("Run SparkPi with no resources", k8sTestTag) { - runSparkPiAndVerifyCompletion() - } - - test("Run SparkPi with a very long application name.", k8sTestTag) { - sparkAppConf.set("spark.app.name", "long" * 40) - runSparkPiAndVerifyCompletion() - } - - test("Use SparkLauncher.NO_RESOURCE", k8sTestTag) { - sparkAppConf.setJars(Seq(containerLocalSparkDistroExamplesJar)) - runSparkPiAndVerifyCompletion( - appResource = SparkLauncher.NO_RESOURCE) - } - - test("Run SparkPi with a master URL without a scheme.", k8sTestTag) { - val url = kubernetesTestComponents.kubernetesClient.getMasterUrl - val k8sMasterUrl = if (url.getPort < 0) { - s"k8s://${url.getHost}" - } else { - s"k8s://${url.getHost}:${url.getPort}" - } - sparkAppConf.set("spark.master", k8sMasterUrl) - runSparkPiAndVerifyCompletion() - } - - test("Run SparkPi with an argument.", k8sTestTag) { - runSparkPiAndVerifyCompletion(appArgs = Array("5")) - } - - test("Run SparkPi with custom labels, annotations, and environment variables.", k8sTestTag) { - sparkAppConf - .set("spark.kubernetes.driver.label.label1", "label1-value") - .set("spark.kubernetes.driver.label.label2", "label2-value") - .set("spark.kubernetes.driver.annotation.annotation1", "annotation1-value") - .set("spark.kubernetes.driver.annotation.annotation2", "annotation2-value") - .set("spark.kubernetes.driverEnv.ENV1", "VALUE1") - .set("spark.kubernetes.driverEnv.ENV2", "VALUE2") - .set("spark.kubernetes.executor.label.label1", "label1-value") - .set("spark.kubernetes.executor.label.label2", "label2-value") - .set("spark.kubernetes.executor.annotation.annotation1", "annotation1-value") - .set("spark.kubernetes.executor.annotation.annotation2", "annotation2-value") - .set("spark.executorEnv.ENV1", "VALUE1") - .set("spark.executorEnv.ENV2", "VALUE2") - - runSparkPiAndVerifyCompletion( - driverPodChecker = (driverPod: Pod) => { - doBasicDriverPodCheck(driverPod) - checkCustomSettings(driverPod) - }, - executorPodChecker = (executorPod: Pod) => { - doBasicExecutorPodCheck(executorPod) - checkCustomSettings(executorPod) - }) - } - - test("Run extraJVMOptions check on driver", k8sTestTag) { - sparkAppConf - .set("spark.driver.extraJavaOptions", "-Dspark.test.foo=spark.test.bar") - runSparkJVMCheckAndVerifyCompletion( - expectedJVMValue = Seq("(spark.test.foo,spark.test.bar)")) - } - - test("Run SparkRemoteFileTest using a remote data file", k8sTestTag) { - sparkAppConf - .set("spark.files", REMOTE_PAGE_RANK_DATA_FILE) - runSparkRemoteCheckAndVerifyCompletion( - appArgs = Array(REMOTE_PAGE_RANK_FILE_NAME)) - } - - test("Run PySpark on simple pi.py example", k8sTestTag) { - sparkAppConf - .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") - runSparkApplicationAndVerifyCompletion( - appResource = PYSPARK_PI, - mainClass = "", - expectedLogOnCompletion = Seq("Pi is roughly 3"), - appArgs = Array("5"), - driverPodChecker = doBasicDriverPyPodCheck, - executorPodChecker = doBasicExecutorPyPodCheck, - appLocator = appLocator, - isJVM = false) - } - - test("Run PySpark with Python2 to test a pyfiles example", k8sTestTag) { - sparkAppConf - .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") - .set("spark.kubernetes.pyspark.pythonversion", "2") - runSparkApplicationAndVerifyCompletion( - appResource = PYSPARK_FILES, - mainClass = "", - expectedLogOnCompletion = Seq( - "Python runtime version check is: True", - "Python environment version check is: True"), - appArgs = Array("python"), - driverPodChecker = doBasicDriverPyPodCheck, - executorPodChecker = doBasicExecutorPyPodCheck, - appLocator = appLocator, - isJVM = false, - pyFiles = Some(PYSPARK_CONTAINER_TESTS)) - } - - test("Run PySpark with Python3 to test a pyfiles example", k8sTestTag) { - sparkAppConf - .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") - .set("spark.kubernetes.pyspark.pythonversion", "3") - runSparkApplicationAndVerifyCompletion( - appResource = PYSPARK_FILES, - mainClass = "", - expectedLogOnCompletion = Seq( - "Python runtime version check is: True", - "Python environment version check is: True"), - appArgs = Array("python3"), - driverPodChecker = doBasicDriverPyPodCheck, - executorPodChecker = doBasicExecutorPyPodCheck, - appLocator = appLocator, - isJVM = false, - pyFiles = Some(PYSPARK_CONTAINER_TESTS)) - } - - private def runSparkPiAndVerifyCompletion( + protected def runSparkPiAndVerifyCompletion( appResource: String = containerLocalSparkDistroExamplesJar, driverPodChecker: Pod => Unit = doBasicDriverPodCheck, executorPodChecker: Pod => Unit = doBasicExecutorPodCheck, @@ -241,7 +121,7 @@ private[spark] class KubernetesSuite extends SparkFunSuite isJVM) } - private def runSparkRemoteCheckAndVerifyCompletion( + protected def runSparkRemoteCheckAndVerifyCompletion( appResource: String = containerLocalSparkDistroExamplesJar, driverPodChecker: Pod => Unit = doBasicDriverPodCheck, executorPodChecker: Pod => Unit = doBasicExecutorPodCheck, @@ -258,7 +138,7 @@ private[spark] class KubernetesSuite extends SparkFunSuite true) } - private def runSparkJVMCheckAndVerifyCompletion( + protected def runSparkJVMCheckAndVerifyCompletion( appResource: String = containerLocalSparkDistroExamplesJar, mainClass: String = SPARK_DRIVER_MAIN_CLASS, driverPodChecker: Pod => Unit = doBasicDriverPodCheck, @@ -295,7 +175,7 @@ private[spark] class KubernetesSuite extends SparkFunSuite } } - private def runSparkApplicationAndVerifyCompletion( + protected def runSparkApplicationAndVerifyCompletion( appResource: String, mainClass: String, expectedLogOnCompletion: Seq[String], @@ -347,29 +227,30 @@ private[spark] class KubernetesSuite extends SparkFunSuite } } - private def doBasicDriverPodCheck(driverPod: Pod): Unit = { + protected def doBasicDriverPodCheck(driverPod: Pod): Unit = { assert(driverPod.getMetadata.getName === driverPodName) assert(driverPod.getSpec.getContainers.get(0).getImage === image) assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") } - private def doBasicDriverPyPodCheck(driverPod: Pod): Unit = { + + protected def doBasicDriverPyPodCheck(driverPod: Pod): Unit = { assert(driverPod.getMetadata.getName === driverPodName) assert(driverPod.getSpec.getContainers.get(0).getImage === pyImage) assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") } - private def doBasicExecutorPodCheck(executorPod: Pod): Unit = { + protected def doBasicExecutorPodCheck(executorPod: Pod): Unit = { assert(executorPod.getSpec.getContainers.get(0).getImage === image) assert(executorPod.getSpec.getContainers.get(0).getName === "executor") } - private def doBasicExecutorPyPodCheck(executorPod: Pod): Unit = { + protected def doBasicExecutorPyPodCheck(executorPod: Pod): Unit = { assert(executorPod.getSpec.getContainers.get(0).getImage === pyImage) assert(executorPod.getSpec.getContainers.get(0).getName === "executor") } - private def checkCustomSettings(pod: Pod): Unit = { + protected def checkCustomSettings(pod: Pod): Unit = { assert(pod.getMetadata.getLabels.get("label1") === "label1-value") assert(pod.getMetadata.getLabels.get("label2") === "label2-value") assert(pod.getMetadata.getAnnotations.get("annotation1") === "annotation1-value") @@ -399,26 +280,10 @@ private[spark] class KubernetesSuite extends SparkFunSuite } private[spark] object KubernetesSuite { - - val TIMEOUT = PatienceConfiguration.Timeout(Span(2, Minutes)) - val INTERVAL = PatienceConfiguration.Interval(Span(2, Seconds)) + val k8sTestTag = Tag("k8s") val SPARK_PI_MAIN_CLASS: String = "org.apache.spark.examples.SparkPi" val SPARK_REMOTE_MAIN_CLASS: String = "org.apache.spark.examples.SparkRemoteFileTest" val SPARK_DRIVER_MAIN_CLASS: String = "org.apache.spark.examples.DriverSubmissionTest" - val SPARK_PAGE_RANK_MAIN_CLASS: String = "org.apache.spark.examples.SparkPageRank" - val CONTAINER_LOCAL_PYSPARK: String = "local:///opt/spark/examples/src/main/python/" - val PYSPARK_PI: String = CONTAINER_LOCAL_PYSPARK + "pi.py" - val PYSPARK_FILES: String = CONTAINER_LOCAL_PYSPARK + "pyfiles.py" - val PYSPARK_CONTAINER_TESTS: String = CONTAINER_LOCAL_PYSPARK + "py_container_checks.py" - - val TEST_SECRET_NAME_PREFIX = "test-secret-" - val TEST_SECRET_KEY = "test-key" - val TEST_SECRET_VALUE = "test-data" - val TEST_SECRET_MOUNT_PATH = "/etc/secrets" - - val REMOTE_PAGE_RANK_DATA_FILE = - "https://storage.googleapis.com/spark-k8s-integration-tests/files/pagerank_data.txt" - val REMOTE_PAGE_RANK_FILE_NAME = "pagerank_data.txt" - - case object ShuffleNotReadyException extends Exception + val TIMEOUT = PatienceConfiguration.Timeout(Span(2, Minutes)) + val INTERVAL = PatienceConfiguration.Interval(Span(2, Seconds)) } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala new file mode 100644 index 0000000000000..0254cc99de268 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import org.apache.spark.deploy.k8s.integrationtest.TestConfig.{getTestImageRepo, getTestImageTag} + +private[spark] trait PythonTestsSuite { k8sSuite: KubernetesSuite => + + import PythonTestsSuite._ + import KubernetesSuite.k8sTestTag + + test("Run PySpark on simple pi.py example", k8sTestTag) { + sparkAppConf + .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_PI, + mainClass = "", + expectedLogOnCompletion = Seq("Pi is roughly 3"), + appArgs = Array("5"), + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false) + } + + test("Run PySpark with Python2 to test a pyfiles example", k8sTestTag) { + sparkAppConf + .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") + .set("spark.kubernetes.pyspark.pythonversion", "2") + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_FILES, + mainClass = "", + expectedLogOnCompletion = Seq( + "Python runtime version check is: True", + "Python environment version check is: True"), + appArgs = Array("python"), + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false, + pyFiles = Some(PYSPARK_CONTAINER_TESTS)) + } + + test("Run PySpark with Python3 to test a pyfiles example", k8sTestTag) { + sparkAppConf + .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") + .set("spark.kubernetes.pyspark.pythonversion", "3") + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_FILES, + mainClass = "", + expectedLogOnCompletion = Seq( + "Python runtime version check is: True", + "Python environment version check is: True"), + appArgs = Array("python3"), + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false, + pyFiles = Some(PYSPARK_CONTAINER_TESTS)) + } +} + +private[spark] object PythonTestsSuite { + val CONTAINER_LOCAL_PYSPARK: String = "local:///opt/spark/examples/src/main/python/" + val PYSPARK_PI: String = CONTAINER_LOCAL_PYSPARK + "pi.py" + val PYSPARK_FILES: String = CONTAINER_LOCAL_PYSPARK + "pyfiles.py" + val PYSPARK_CONTAINER_TESTS: String = CONTAINER_LOCAL_PYSPARK + "py_container_checks.py" +} + diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SecretsTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SecretsTestsSuite.scala new file mode 100644 index 0000000000000..7b05c1355ca24 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SecretsTestsSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{Pod, Secret, SecretBuilder} +import org.apache.commons.codec.binary.Base64 +import org.apache.commons.io.output.ByteArrayOutputStream +import org.scalatest.concurrent.Eventually + +import org.apache.spark.deploy.k8s.integrationtest.KubernetesSuite._ + +private[spark] trait SecretsTestsSuite { k8sSuite: KubernetesSuite => + + import SecretsTestsSuite._ + + private def createTestSecret(): Unit = { + val sb = new SecretBuilder() + sb.withNewMetadata() + .withName(ENV_SECRET_NAME) + .endMetadata() + val secUsername = Base64.encodeBase64String(ENV_SECRET_VALUE_1.getBytes()) + val secPassword = Base64.encodeBase64String(ENV_SECRET_VALUE_2.getBytes()) + val envSecretData = Map(ENV_SECRET_KEY_1 -> secUsername, ENV_SECRET_KEY_2 -> secPassword) + sb.addToData(envSecretData.asJava) + val envSecret = sb.build() + val sec = kubernetesTestComponents + .kubernetesClient + .secrets() + .createOrReplace(envSecret) + } + + private def deleteTestSecret(): Unit = { + kubernetesTestComponents + .kubernetesClient + .secrets() + .withName(ENV_SECRET_NAME) + .delete() + } + + test("Run SparkPi with env and mount secrets.", k8sTestTag) { + createTestSecret() + sparkAppConf + .set(s"spark.kubernetes.driver.secrets.$ENV_SECRET_NAME", SECRET_MOUNT_PATH) + .set(s"spark.kubernetes.driver.secretKeyRef.USERNAME", s"$ENV_SECRET_NAME:username") + .set(s"spark.kubernetes.driver.secretKeyRef.PASSWORD", s"$ENV_SECRET_NAME:password") + .set(s"spark.kubernetes.executor.secrets.$ENV_SECRET_NAME", SECRET_MOUNT_PATH) + .set(s"spark.kubernetes.executor.secretKeyRef.USERNAME", s"$ENV_SECRET_NAME:username") + .set(s"spark.kubernetes.executor.secretKeyRef.PASSWORD", s"$ENV_SECRET_NAME:password") + try { + runSparkPiAndVerifyCompletion( + driverPodChecker = (driverPod: Pod) => { + doBasicDriverPodCheck(driverPod) + checkSecrets(driverPod) + }, + executorPodChecker = (executorPod: Pod) => { + doBasicExecutorPodCheck(executorPod) + checkSecrets(executorPod) + }, + appArgs = Array("1000") // give it enough time for all execs to be visible + ) + } finally { + // make sure this always run + deleteTestSecret() + } + } + + private def checkSecrets(pod: Pod): Unit = { + Eventually.eventually(TIMEOUT, INTERVAL) { + implicit val podName: String = pod.getMetadata.getName + val env = executeCommand("env") + assert(env.toString.contains(ENV_SECRET_VALUE_1)) + assert(env.toString.contains(ENV_SECRET_VALUE_2)) + val fileUsernameContents = executeCommand("cat", s"$SECRET_MOUNT_PATH/$ENV_SECRET_KEY_1") + val filePasswordContents = executeCommand("cat", s"$SECRET_MOUNT_PATH/$ENV_SECRET_KEY_2") + assert(fileUsernameContents.toString.trim.equals(ENV_SECRET_VALUE_1)) + assert(filePasswordContents.toString.trim.equals(ENV_SECRET_VALUE_2)) + } + } + + private def executeCommand(cmd: String*)(implicit podName: String): String = { + val out = new ByteArrayOutputStream() + val watch = kubernetesTestComponents + .kubernetesClient + .pods() + .withName(podName) + .readingInput(System.in) + .writingOutput(out) + .writingError(System.err) + .withTTY() + .exec(cmd.toArray: _*) + // wait to get some result back + Thread.sleep(1000) + watch.close() + out.flush() + out.toString() + } +} + +private[spark] object SecretsTestsSuite { + val ENV_SECRET_NAME = "mysecret" + val SECRET_MOUNT_PATH = "/etc/secret" + val ENV_SECRET_KEY_1 = "username" + val ENV_SECRET_KEY_2 = "password" + val ENV_SECRET_VALUE_1 = "secretusername" + val ENV_SECRET_VALUE_2 = "secretpassword" +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/config.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConfig.scala similarity index 98% rename from resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/config.scala rename to resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConfig.scala index a81ef455c6766..5a49e0779160c 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/config.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConfig.scala @@ -21,7 +21,7 @@ import java.io.File import com.google.common.base.Charsets import com.google.common.io.Files -package object config { +object TestConfig { def getTestImageTag: String = { val imageTagFileProp = System.getProperty("spark.kubernetes.test.imageTagFile") require(imageTagFileProp != null, "Image tag file must be provided in system properties.") diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/constants.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConstants.scala similarity index 97% rename from resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/constants.scala rename to resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConstants.scala index 0807a68cd823c..8595d0eab1126 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/constants.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConstants.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.deploy.k8s.integrationtest -package object constants { +object TestConstants { val MINIKUBE_TEST_BACKEND = "minikube" val GCE_TEST_BACKEND = "gce" } From e0b63832181464453f753649623a24cb567a73d4 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 20 Jul 2018 20:59:48 +0800 Subject: [PATCH 1176/2461] [SPARK-23731][SQL] Make FileSourceScanExec canonicalizable after being (de)serialized MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? ### What's problem? In some cases, sub scalar query could throw a NPE, which is caused in execution side. ``` java.lang.NullPointerException at org.apache.spark.sql.execution.FileSourceScanExec.(DataSourceScanExec.scala:169) at org.apache.spark.sql.execution.FileSourceScanExec.doCanonicalize(DataSourceScanExec.scala:526) at org.apache.spark.sql.execution.FileSourceScanExec.doCanonicalize(DataSourceScanExec.scala:159) at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized$lzycompute(QueryPlan.scala:211) at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized(QueryPlan.scala:210) at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$3.apply(QueryPlan.scala:225) at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$3.apply(QueryPlan.scala:225) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.immutable.List.foreach(List.scala:392) at scala.collection.TraversableLike$class.map(TraversableLike.scala:234) at scala.collection.immutable.List.map(List.scala:296) at org.apache.spark.sql.catalyst.plans.QueryPlan.doCanonicalize(QueryPlan.scala:225) at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized$lzycompute(QueryPlan.scala:211) at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized(QueryPlan.scala:210) at org.apache.spark.sql.catalyst.plans.QueryPlan.sameResult(QueryPlan.scala:258) at org.apache.spark.sql.execution.ScalarSubquery.semanticEquals(subquery.scala:58) at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions$Expr.equals(EquivalentExpressions.scala:36) at scala.collection.mutable.HashTable$class.elemEquals(HashTable.scala:364) at scala.collection.mutable.HashMap.elemEquals(HashMap.scala:40) at scala.collection.mutable.HashTable$class.scala$collection$mutable$HashTable$$findEntry0(HashTable.scala:139) at scala.collection.mutable.HashTable$class.findEntry(HashTable.scala:135) at scala.collection.mutable.HashMap.findEntry(HashMap.scala:40) at scala.collection.mutable.HashMap.get(HashMap.scala:70) at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.addExpr(EquivalentExpressions.scala:56) at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.addExprTree(EquivalentExpressions.scala:97) at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions$$anonfun$addExprTree$1.apply(EquivalentExpressions.scala:98) at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions$$anonfun$addExprTree$1.apply(EquivalentExpressions.scala:98) at scala.collection.immutable.List.foreach(List.scala:392) at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.addExprTree(EquivalentExpressions.scala:98) at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext$$anonfun$subexpressionElimination$1.apply(CodeGenerator.scala:1102) at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext$$anonfun$subexpressionElimination$1.apply(CodeGenerator.scala:1102) at scala.collection.immutable.List.foreach(List.scala:392) at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext.subexpressionElimination(CodeGenerator.scala:1102) at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext.generateExpressions(CodeGenerator.scala:1154) at org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection$.createCode(GenerateUnsafeProjection.scala:270) at org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection$.create(GenerateUnsafeProjection.scala:319) at org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection$.generate(GenerateUnsafeProjection.scala:308) at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:181) at org.apache.spark.sql.execution.ProjectExec$$anonfun$9.apply(basicPhysicalOperators.scala:71) at org.apache.spark.sql.execution.ProjectExec$$anonfun$9.apply(basicPhysicalOperators.scala:70) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndexInternal$1$$anonfun$apply$24.apply(RDD.scala:818) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndexInternal$1$$anonfun$apply$24.apply(RDD.scala:818) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324) at org.apache.spark.rdd.RDD.iterator(RDD.scala:288) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324) at org.apache.spark.rdd.RDD.iterator(RDD.scala:288) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87) at org.apache.spark.scheduler.Task.run(Task.scala:109) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:367) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) at java.lang.Thread.run(Thread.java:748) ``` ### How does this happen? Here looks what happen now: 1. Sub scalar query was made (for instance `SELECT (SELECT id FROM foo)`). 2. Try to extract some common expressions (via `CodeGenerator.subexpressionElimination`) so that it can generates some common codes and can be reused. 3. During this, seems it extracts some expressions that can be reused (via `EquivalentExpressions.addExprTree`) https://github.com/apache/spark/blob/b2deef64f604ddd9502a31105ed47cb63470ec85/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala#L1102 4. During this, if the hash (`EquivalentExpressions.Expr.hashCode`) happened to be the same at `EquivalentExpressions.addExpr` anyhow, `EquivalentExpressions.Expr.equals` is called to identify object in the same hash, which eventually calls `semanticEquals` in `ScalarSubquery` https://github.com/apache/spark/blob/087879a77acb37b790c36f8da67355b90719c2dc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala#L54 https://github.com/apache/spark/blob/087879a77acb37b790c36f8da67355b90719c2dc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala#L36 5. `ScalarSubquery`'s `semanticEquals` needs `SubqueryExec`'s `sameResult` https://github.com/apache/spark/blob/77a2fc5b521788b406bb32bcc3c637c1d7406e58/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala#L58 6. `SubqueryExec`'s `sameResult` requires a canonicalized plan which calls `FileSourceScanExec`'s `doCanonicalize` https://github.com/apache/spark/blob/e008ad175256a3192fdcbd2c4793044d52f46d57/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala#L258 7. In `FileSourceScanExec`'s `doCanonicalize`, `FileSourceScanExec`'s `relation` is required but seems `transient` so it becomes `null`. https://github.com/apache/spark/blob/e76b0124fbe463def00b1dffcfd8fd47e04772fe/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala#L527 https://github.com/apache/spark/blob/e76b0124fbe463def00b1dffcfd8fd47e04772fe/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala#L160 8. NPE is thrown. \*1. driver side \*2., 3., 4., 5., 6., 7., 8. executor side Note that most of cases, it looks fine because we will usually call: https://github.com/apache/spark/blob/087879a77acb37b790c36f8da67355b90719c2dc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala#L40 which make a canonicalized plan via: https://github.com/apache/spark/blob/b045315e5d87b7ea3588436053aaa4d5a7bd103f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala#L192 https://github.com/apache/spark/blob/77a2fc5b521788b406bb32bcc3c637c1d7406e58/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala#L52 ### How to reproduce? This looks what happened now. I can reproduce this by a bit of messy way: ```diff diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 8d06804ce1e..d25fc9a7ba9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala -37,7 +37,9 class EquivalentExpressions { case _ => false } - override def hashCode: Int = e.semanticHash() + override def hashCode: Int = { + 1 + } } ``` ```scala spark.range(1).write.mode("overwrite").parquet("/tmp/foo") spark.read.parquet("/tmp/foo").createOrReplaceTempView("foo") spark.conf.set("spark.sql.codegen.wholeStage", false) sql("SELECT (SELECT id FROM foo) == (SELECT id FROM foo)").collect() ``` ### How does this PR fix? - Make all variables that access to `FileSourceScanExec`'s `relation` as `lazy val` so that we avoid NPE. This is a temporary fix. - Allow `makeCopy` in `SparkPlan` without Spark session too. This looks still able to be accessed within executor side. For instance: ``` at org.apache.spark.sql.execution.SparkPlan.makeCopy(SparkPlan.scala:70) at org.apache.spark.sql.execution.SparkPlan.makeCopy(SparkPlan.scala:47) at org.apache.spark.sql.catalyst.trees.TreeNode.withNewChildren(TreeNode.scala:233) at org.apache.spark.sql.catalyst.plans.QueryPlan.doCanonicalize(QueryPlan.scala:243) at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized$lzycompute(QueryPlan.scala:211) at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized(QueryPlan.scala:210) at org.apache.spark.sql.catalyst.plans.QueryPlan.sameResult(QueryPlan.scala:258) at org.apache.spark.sql.execution.ScalarSubquery.semanticEquals(subquery.scala:58) at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions$Expr.equals(EquivalentExpressions.scala:36) at scala.collection.mutable.HashTable$class.elemEquals(HashTable.scala:364) at scala.collection.mutable.HashMap.elemEquals(HashMap.scala:40) at scala.collection.mutable.HashTable$class.scala$collection$mutable$HashTable$$findEntry0(HashTable.scala:139) at scala.collection.mutable.HashTable$class.findEntry(HashTable.scala:135) at scala.collection.mutable.HashMap.findEntry(HashMap.scala:40) at scala.collection.mutable.HashMap.get(HashMap.scala:70) at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.addExpr(EquivalentExpressions.scala:54) at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.addExprTree(EquivalentExpressions.scala:95) at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions$$anonfun$addExprTree$1.apply(EquivalentExpressions.scala:96) at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions$$anonfun$addExprTree$1.apply(EquivalentExpressions.scala:96) at scala.collection.immutable.List.foreach(List.scala:392) at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.addExprTree(EquivalentExpressions.scala:96) at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext$$anonfun$subexpressionElimination$1.apply(CodeGenerator.scala:1102) at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext$$anonfun$subexpressionElimination$1.apply(CodeGenerator.scala:1102) at scala.collection.immutable.List.foreach(List.scala:392) at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext.subexpressionElimination(CodeGenerator.scala:1102) at org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext.generateExpressions(CodeGenerator.scala:1154) at org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection$.createCode(GenerateUnsafeProjection.scala:270) at org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection$.create(GenerateUnsafeProjection.scala:319) at org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection$.generate(GenerateUnsafeProjection.scala:308) at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:181) at org.apache.spark.sql.execution.ProjectExec$$anonfun$9.apply(basicPhysicalOperators.scala:71) at org.apache.spark.sql.execution.ProjectExec$$anonfun$9.apply(basicPhysicalOperators.scala:70) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndexInternal$1$$anonfun$apply$24.apply(RDD.scala:818) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndexInternal$1$$anonfun$apply$24.apply(RDD.scala:818) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324) at org.apache.spark.rdd.RDD.iterator(RDD.scala:288) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324) at org.apache.spark.rdd.RDD.iterator(RDD.scala:288) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87) at org.apache.spark.scheduler.Task.run(Task.scala:109) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:367) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) at java.lang.Thread.run(Thread.java:748) ``` This PR takes over https://github.com/apache/spark/pull/20856. ## How was this patch tested? Manually tested and unit test was added. Closes #20856 Author: hyukjinkwon Closes #21815 from HyukjinKwon/SPARK-23731. --- .../sql/execution/DataSourceScanExec.scala | 10 ++++++---- .../apache/spark/sql/execution/SparkPlan.scala | 10 +++++----- .../spark/sql/execution/SparkPlanSuite.scala | 17 +++++++++++++++++ 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index d7f2654be0451..36ed016773b67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -166,10 +166,12 @@ case class FileSourceScanExec( override val tableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with ColumnarBatchScan { - override val supportsBatch: Boolean = relation.fileFormat.supportBatch( + // Note that some vals referring the file-based relation are lazy intentionally + // so that this plan can be canonicalized on executor side too. See SPARK-23731. + override lazy val supportsBatch: Boolean = relation.fileFormat.supportBatch( relation.sparkSession, StructType.fromAttributes(output)) - override val needsUnsafeRowConversion: Boolean = { + override lazy val needsUnsafeRowConversion: Boolean = { if (relation.fileFormat.isInstanceOf[ParquetSource]) { SparkSession.getActiveSession.get.sessionState.conf.parquetVectorizedReaderEnabled } else { @@ -199,7 +201,7 @@ case class FileSourceScanExec( ret } - override val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = { + override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = { val bucketSpec = if (relation.sparkSession.sessionState.conf.bucketingEnabled) { relation.bucketSpec } else { @@ -270,7 +272,7 @@ case class FileSourceScanExec( private val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter) logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}") - override val metadata: Map[String, String] = { + override lazy val metadata: Map[String, String] = { def seqToString(seq: Seq[Any]) = seq.mkString("[", ", ", "]") val location = relation.location val locationDesc = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 398758a3331b4..1f97993e20458 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -47,17 +47,15 @@ import org.apache.spark.util.ThreadUtils abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable { /** - * A handle to the SQL Context that was used to create this plan. Since many operators need + * A handle to the SQL Context that was used to create this plan. Since many operators need * access to the sqlContext for RDD operations or configuration this field is automatically * populated by the query planning infrastructure. */ - @transient - final val sqlContext = SparkSession.getActiveSession.map(_.sqlContext).orNull + @transient final val sqlContext = SparkSession.getActiveSession.map(_.sqlContext).orNull protected def sparkContext = sqlContext.sparkContext // sqlContext will be null when SparkPlan nodes are created without the active sessions. - // So far, this only happens in the test cases. val subexpressionEliminationEnabled: Boolean = if (sqlContext != null) { sqlContext.conf.subexpressionEliminationEnabled } else { @@ -69,7 +67,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Overridden make copy also propagates sqlContext to copied plan. */ override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = { - SparkSession.setActiveSession(sqlContext.sparkSession) + if (sqlContext != null) { + SparkSession.setActiveSession(sqlContext.sparkSession) + } super.makeCopy(newArgs) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala index 750d9e4adf8b4..34dc6f37c0e4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.SparkEnv import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.SharedSQLContext @@ -33,4 +34,20 @@ class SparkPlanSuite extends QueryTest with SharedSQLContext { intercept[IllegalStateException] { plan.executeTake(1) } } + test("SPARK-23731 plans should be canonicalizable after being (de)serialized") { + withTempPath { path => + spark.range(1).write.parquet(path.getAbsolutePath) + val df = spark.read.parquet(path.getAbsolutePath) + val fileSourceScanExec = + df.queryExecution.sparkPlan.collectFirst { case p: FileSourceScanExec => p }.get + val serializer = SparkEnv.get.serializer.newInstance() + val readback = + serializer.deserialize[FileSourceScanExec](serializer.serialize(fileSourceScanExec)) + try { + readback.canonicalized + } catch { + case e: Throwable => fail("FileSourceScanExec was not canonicalizable", e) + } + } + } } From cc4d64bb16987eb5a41d4198bf4a5882e549a94f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 20 Jul 2018 09:18:57 -0700 Subject: [PATCH 1177/2461] [SPARK-23451][ML] Deprecate KMeans.computeCost ## What changes were proposed in this pull request? Deprecate `KMeans.computeCost` which was introduced as a temp fix and now it is not needed anymore, since we introduced `ClusteringEvaluator`. ## How was this patch tested? manual test (deprecation warning displayed) Scala ``` ... scala> model.computeCost(dataset) warning: there was one deprecation warning; re-run with -deprecation for details res1: Double = 0.0 ``` Python ``` >>> import warnings >>> warnings.simplefilter('always', DeprecationWarning) ... >>> model.computeCost(df) /Users/mgaido/apache/spark/python/pyspark/ml/clustering.py:330: DeprecationWarning: Deprecated in 2.4.0. It will be removed in 3.0.0. Use ClusteringEvaluator instead. " instead.", DeprecationWarning) ``` Author: Marco Gaido Closes #20629 from mgaido91/SPARK-23451. --- .../apache/spark/ml/clustering/KMeans.scala | 19 ++++++++++++++++--- .../spark/mllib/clustering/KMeans.scala | 2 +- .../spark/mllib/clustering/KMeansModel.scala | 11 +++++++---- .../spark/ml/clustering/KMeansSuite.scala | 2 ++ python/pyspark/ml/clustering.py | 19 ++++++++++++++++++- 5 files changed, 44 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index f40037a8d9aa9..6f4a30d4595a1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -145,8 +145,12 @@ class KMeansModel private[ml] ( /** * Return the K-means cost (sum of squared distances of points to their nearest center) for this * model on the given data. + * + * @deprecated This method is deprecated and will be removed in 3.0.0. Use ClusteringEvaluator + * instead. You can also get the cost on the training dataset in the summary. */ - // TODO: Replace the temp fix when we have proper evaluators defined for clustering. + @deprecated("This method is deprecated and will be removed in 3.0.0. Use ClusteringEvaluator " + + "instead. You can also get the cost on the training dataset in the summary.", "2.4.0") @Since("2.0.0") def computeCost(dataset: Dataset[_]): Double = { SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol) @@ -356,7 +360,12 @@ class KMeans @Since("1.5.0") ( val parentModel = algo.run(instances, Option(instr)) val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( - model.transform(dataset), $(predictionCol), $(featuresCol), $(k), parentModel.numIter) + model.transform(dataset), + $(predictionCol), + $(featuresCol), + $(k), + parentModel.numIter, + parentModel.trainingCost) model.setSummary(Some(summary)) instr.logNamedValue("clusterSizes", summary.clusterSizes) @@ -389,6 +398,8 @@ object KMeans extends DefaultParamsReadable[KMeans] { * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. * @param numIter Number of iterations. + * @param trainingCost K-means cost (sum of squared distances to the nearest centroid for all + * points in the training dataset). This is equivalent to sklearn's inertia. */ @Since("2.0.0") @Experimental @@ -397,4 +408,6 @@ class KMeansSummary private[clustering] ( predictionCol: String, featuresCol: String, k: Int, - numIter: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) + numIter: Int, + @Since("2.4.0") val trainingCost: Double) + extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 4f554f420b903..55df8a34fbfc7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -348,7 +348,7 @@ class KMeans private ( logInfo(s"The cost is $cost.") - new KMeansModel(centers.map(_.vector), distanceMeasure, iteration) + new KMeansModel(centers.map(_.vector), distanceMeasure, cost, iteration) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index e3a88b42fbf73..d5c8188144ce2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.{Row, SparkSession} @Since("0.8.0") class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector], @Since("2.4.0") val distanceMeasure: String, + @Since("2.4.0") val trainingCost: Double, private[spark] val numIter: Int) extends Saveable with Serializable with PMMLExportable { @@ -49,11 +50,11 @@ class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector], @Since("2.4.0") private[spark] def this(clusterCenters: Array[Vector], distanceMeasure: String) = - this(clusterCenters: Array[Vector], distanceMeasure, -1) + this(clusterCenters: Array[Vector], distanceMeasure, 0.0, -1) @Since("1.1.0") def this(clusterCenters: Array[Vector]) = - this(clusterCenters: Array[Vector], DistanceMeasure.EUCLIDEAN) + this(clusterCenters: Array[Vector], DistanceMeasure.EUCLIDEAN, 0.0, -1) /** * A Java-friendly constructor that takes an Iterable of Vectors. @@ -187,7 +188,8 @@ object KMeansModel extends Loader[KMeansModel] { val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) - ~ ("k" -> model.k) ~ ("distanceMeasure" -> model.distanceMeasure))) + ~ ("k" -> model.k) ~ ("distanceMeasure" -> model.distanceMeasure) + ~ ("trainingCost" -> model.trainingCost))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex).map { case (p, id) => Cluster(id, p.vector) @@ -207,7 +209,8 @@ object KMeansModel extends Loader[KMeansModel] { val localCentroids = centroids.rdd.map(Cluster.apply).collect() assert(k == localCentroids.length) val distanceMeasure = (metadata \ "distanceMeasure").extract[String] - new KMeansModel(localCentroids.sortBy(_.id).map(_.point), distanceMeasure) + val trainingCost = (metadata \ "trainingCost").extract[Double] + new KMeansModel(localCentroids.sortBy(_.id).map(_.point), distanceMeasure, trainingCost, -1) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 829c90fe34e94..9b0b52617755c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -131,6 +131,8 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes assert(summary.predictions.columns.contains(c)) } assert(summary.cluster.columns === Array(predictionColName)) + assert(summary.trainingCost < 0.1) + assert(model.computeCost(dataset) == summary.trainingCost) val clusterSizes = summary.clusterSizes assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 2f0660040dc7c..8a58d838819e2 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -16,6 +16,7 @@ # import sys +import warnings from pyspark import since, keyword_only from pyspark.ml.util import * @@ -303,7 +304,15 @@ class KMeansSummary(ClusteringSummary): .. versionadded:: 2.1.0 """ - pass + + @property + @since("2.4.0") + def trainingCost(self): + """ + K-means cost (sum of squared distances to the nearest centroid for all points in the + training dataset). This is equivalent to sklearn's inertia. + """ + return self._call_java("trainingCost") class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): @@ -323,7 +332,13 @@ def computeCost(self, dataset): """ Return the K-means cost (sum of squared distances of points to their nearest center) for this model on the given data. + + ..note:: Deprecated in 2.4.0. It will be removed in 3.0.0. Use ClusteringEvaluator instead. + You can also get the cost on the training dataset in the summary. """ + warnings.warn("Deprecated in 2.4.0. It will be removed in 3.0.0. Use ClusteringEvaluator " + "instead. You can also get the cost on the training dataset in the summary.", + DeprecationWarning) return self._call_java("computeCost", dataset) @property @@ -379,6 +394,8 @@ class KMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol 2 >>> summary.clusterSizes [2, 2] + >>> summary.trainingCost + 2.000... >>> kmeans_path = temp_path + "/kmeans" >>> kmeans.save(kmeans_path) >>> kmeans2 = KMeans.load(kmeans_path) From 244bcff19463d82ec72baf15bc0a5209f21f2ef3 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 20 Jul 2018 09:19:29 -0700 Subject: [PATCH 1178/2461] [SPARK-24811][SQL] Avro: add new function from_avro and to_avro ## What changes were proposed in this pull request? Add a new function from_avro for parsing a binary column of avro format and converting it into its corresponding catalyst value. Add a new function to_avro for converting a column into binary of avro format with the specified schema. This PR is in progress. Will add test cases. ## How was this patch tested? Author: Gengliang Wang Closes #21774 from gengliangwang/from_and_to_avro. --- .../spark/sql/avro/AvroDataToCatalyst.scala | 68 +++++++ .../spark/sql/avro/CatalystDataToAvro.scala | 69 +++++++ .../org/apache/spark/sql/avro/package.scala | 31 ++++ .../AvroCatalystDataConversionSuite.scala | 175 ++++++++++++++++++ .../spark/sql/avro/AvroFunctionsSuite.scala | 83 +++++++++ .../expressions/ExpressionEvalHelper.scala | 6 + 6 files changed, 432 insertions(+) create mode 100644 external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala create mode 100644 external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala create mode 100644 external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala create mode 100644 external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala new file mode 100644 index 0000000000000..6671b3fb8705c --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.avro.Schema +import org.apache.avro.generic.GenericDatumReader +import org.apache.avro.io.{BinaryDecoder, DecoderFactory} + +import org.apache.spark.sql.avro.{AvroDeserializer, SchemaConverters} +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType} + +case class AvroDataToCatalyst(child: Expression, jsonFormatSchema: String) + extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType) + + override lazy val dataType: DataType = SchemaConverters.toSqlType(avroSchema).dataType + + override def nullable: Boolean = true + + @transient private lazy val avroSchema = new Schema.Parser().parse(jsonFormatSchema) + + @transient private lazy val reader = new GenericDatumReader[Any](avroSchema) + + @transient private lazy val deserializer = new AvroDeserializer(avroSchema, dataType) + + @transient private var decoder: BinaryDecoder = _ + + @transient private var result: Any = _ + + override def nullSafeEval(input: Any): Any = { + val binary = input.asInstanceOf[Array[Byte]] + decoder = DecoderFactory.get().binaryDecoder(binary, 0, binary.length, decoder) + result = reader.read(result, decoder) + deserializer.deserialize(result) + } + + override def simpleString: String = { + s"from_avro(${child.sql}, ${dataType.simpleString})" + } + + override def sql: String = { + s"from_avro(${child.sql}, ${dataType.catalogString})" + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val expr = ctx.addReferenceObj("this", this) + defineCodeGen(ctx, ev, input => + s"(${CodeGenerator.boxedType(dataType)})$expr.nullSafeEval($input)") + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala new file mode 100644 index 0000000000000..a669388e88258 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.io.ByteArrayOutputStream + +import org.apache.avro.generic.GenericDatumWriter +import org.apache.avro.io.{BinaryEncoder, EncoderFactory} + +import org.apache.spark.sql.avro.{AvroSerializer, SchemaConverters} +import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.types.{BinaryType, DataType} + +case class CatalystDataToAvro(child: Expression) extends UnaryExpression { + + override def dataType: DataType = BinaryType + + @transient private lazy val avroType = + SchemaConverters.toAvroType(child.dataType, child.nullable) + + @transient private lazy val serializer = + new AvroSerializer(child.dataType, avroType, child.nullable) + + @transient private lazy val writer = + new GenericDatumWriter[Any](avroType) + + @transient private var encoder: BinaryEncoder = _ + + @transient private lazy val out = new ByteArrayOutputStream + + override def nullSafeEval(input: Any): Any = { + out.reset() + encoder = EncoderFactory.get().directBinaryEncoder(out, encoder) + val avroData = serializer.serialize(input) + writer.write(avroData, encoder) + encoder.flush() + out.toByteArray + } + + override def simpleString: String = { + s"to_avro(${child.sql}, ${child.dataType.simpleString})" + } + + override def sql: String = { + s"to_avro(${child.sql}, ${child.dataType.catalogString})" + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val expr = ctx.addReferenceObj("this", this) + defineCodeGen(ctx, ev, input => + s"(byte[]) $expr.nullSafeEval($input)") + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala index b3c8a669cf820..e82651d96a03d 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import org.apache.avro.Schema + +import org.apache.spark.annotation.Experimental + package object avro { /** * Adds a method, `avro`, to DataFrameWriter that allows you to write avro files using @@ -36,4 +40,31 @@ package object avro { @scala.annotation.varargs def avro(sources: String*): DataFrame = reader.format("avro").load(sources: _*) } + + /** + * Converts a binary column of avro format into its corresponding catalyst value. The specified + * schema must match the read data, otherwise the behavior is undefined: it may fail or return + * arbitrary result. + * + * @param data the binary column. + * @param jsonFormatSchema the avro schema in JSON string format. + * + * @since 2.4.0 + */ + @Experimental + def from_avro(data: Column, jsonFormatSchema: String): Column = { + new Column(AvroDataToCatalyst(data.expr, jsonFormatSchema)) + } + + /** + * Converts a column into binary of avro format. + * + * @param data the data column. + * + * @since 2.4.0 + */ + @Experimental + def to_avro(data: Column): Column = { + new Column(CatalystDataToAvro(data.expr)) + } } diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala new file mode 100644 index 0000000000000..06d5477b2ea45 --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.avro.Schema + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{AvroDataToCatalyst, CatalystDataToAvro, RandomDataGenerator} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, GenericInternalRow, Literal} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class AvroCatalystDataConversionSuite extends SparkFunSuite with ExpressionEvalHelper { + + private def roundTripTest(data: Literal): Unit = { + val avroType = SchemaConverters.toAvroType(data.dataType, data.nullable) + checkResult(data, avroType.toString, data.eval()) + } + + private def checkResult(data: Literal, schema: String, expected: Any): Unit = { + checkEvaluation( + AvroDataToCatalyst(CatalystDataToAvro(data), schema), + prepareExpectedResult(expected)) + } + + private def assertFail(data: Literal, schema: String): Unit = { + intercept[java.io.EOFException] { + AvroDataToCatalyst(CatalystDataToAvro(data), schema).eval() + } + } + + private val testingTypes = Seq( + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + DecimalType(8, 0), // 32 bits decimal without fraction + DecimalType(8, 4), // 32 bits decimal + DecimalType(16, 0), // 64 bits decimal without fraction + DecimalType(16, 11), // 64 bits decimal + DecimalType(38, 0), + DecimalType(38, 38), + StringType, + BinaryType) + + protected def prepareExpectedResult(expected: Any): Any = expected match { + // Spark decimal is converted to avro string= + case d: Decimal => UTF8String.fromString(d.toString) + // Spark byte and short both map to avro int + case b: Byte => b.toInt + case s: Short => s.toInt + case row: GenericInternalRow => InternalRow.fromSeq(row.values.map(prepareExpectedResult)) + case array: GenericArrayData => new GenericArrayData(array.array.map(prepareExpectedResult)) + case map: MapData => + val keys = new GenericArrayData( + map.keyArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult)) + val values = new GenericArrayData( + map.valueArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult)) + new ArrayBasedMapData(keys, values) + case other => other + } + + testingTypes.foreach { dt => + val seed = scala.util.Random.nextLong() + test(s"single $dt with seed $seed") { + val rand = new scala.util.Random(seed) + val data = RandomDataGenerator.forType(dt, rand = rand).get.apply() + val converter = CatalystTypeConverters.createToCatalystConverter(dt) + val input = Literal.create(converter(data), dt) + roundTripTest(input) + } + } + + for (_ <- 1 to 5) { + val seed = scala.util.Random.nextLong() + val rand = new scala.util.Random(seed) + val schema = RandomDataGenerator.randomSchema(rand, 5, testingTypes) + test(s"flat schema ${schema.catalogString} with seed $seed") { + val data = RandomDataGenerator.randomRow(rand, schema) + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + val input = Literal.create(converter(data), schema) + roundTripTest(input) + } + } + + for (_ <- 1 to 5) { + val seed = scala.util.Random.nextLong() + val rand = new scala.util.Random(seed) + val schema = RandomDataGenerator.randomNestedSchema(rand, 10, testingTypes) + test(s"nested schema ${schema.catalogString} with seed $seed") { + val data = RandomDataGenerator.randomRow(rand, schema) + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + val input = Literal.create(converter(data), schema) + roundTripTest(input) + } + } + + test("read int as string") { + val data = Literal(1) + val avroTypeJson = + s""" + |{ + | "type": "string", + | "name": "my_string" + |} + """.stripMargin + + // When read int as string, avro reader is not able to parse the binary and fail. + assertFail(data, avroTypeJson) + } + + test("read string as int") { + val data = Literal("abc") + val avroTypeJson = + s""" + |{ + | "type": "int", + | "name": "my_int" + |} + """.stripMargin + + // When read string data as int, avro reader is not able to find the type mismatch and read + // the string length as int value. + checkResult(data, avroTypeJson, 3) + } + + test("read float as double") { + val data = Literal(1.23f) + val avroTypeJson = + s""" + |{ + | "type": "double", + | "name": "my_double" + |} + """.stripMargin + + // When read float data as double, avro reader fails(trying to read 8 bytes while the data have + // only 4 bytes). + assertFail(data, avroTypeJson) + } + + test("read double as float") { + val data = Literal(1.23) + val avroTypeJson = + s""" + |{ + | "type": "float", + | "name": "my_float" + |} + """.stripMargin + + // avro reader reads the first 4 bytes of a double as a float, the result is totally undefined. + checkResult(data, avroTypeJson, 5.848603E35f) + } +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala new file mode 100644 index 0000000000000..90a4cd6ccf9dd --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.avro.Schema + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions.struct +import org.apache.spark.sql.test.SharedSQLContext + +class AvroFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("roundtrip in to_avro and from_avro - int and string") { + val df = spark.range(10).select('id, 'id.cast("string").as("str")) + + val avroDF = df.select(to_avro('id).as("a"), to_avro('str).as("b")) + val avroTypeLong = s""" + |{ + | "type": "int", + | "name": "id" + |} + """.stripMargin + val avroTypeStr = s""" + |{ + | "type": "string", + | "name": "str" + |} + """.stripMargin + checkAnswer(avroDF.select(from_avro('a, avroTypeLong), from_avro('b, avroTypeStr)), df) + } + + test("roundtrip in to_avro and from_avro - struct") { + val df = spark.range(10).select(struct('id, 'id.cast("string").as("str")).as("struct")) + val avroStructDF = df.select(to_avro('struct).as("avro")) + val avroTypeStruct = s""" + |{ + | "type": "record", + | "name": "struct", + | "fields": [ + | {"name": "col1", "type": "long"}, + | {"name": "col2", "type": "string"} + | ] + |} + """.stripMargin + checkAnswer(avroStructDF.select(from_avro('avro, avroTypeStruct)), df) + } + + test("roundtrip in to_avro and from_avro - array with null") { + val dfOne = Seq(Tuple1(Tuple1(1) :: Nil), Tuple1(null :: Nil)).toDF("array") + val avroTypeArrStruct = s""" + |[ { + | "type" : "array", + | "items" : [ { + | "type" : "record", + | "name" : "x", + | "fields" : [ { + | "name" : "y", + | "type" : "int" + | } ] + | }, "null" ] + |}, "null" ] + """.stripMargin + val readBackOne = dfOne.select(to_avro($"array").as("avro")) + .select(from_avro($"avro", avroTypeArrStruct).as("array")) + checkAnswer(dfOne, readBackOne) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 14bfa212b5496..d045267ef5d9e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -79,6 +79,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa java.util.Arrays.equals(result, expected) case (result: Double, expected: Spread[Double @unchecked]) => expected.asInstanceOf[Spread[Double]].isWithin(result) + case (result: InternalRow, expected: InternalRow) => + val st = dataType.asInstanceOf[StructType] + assert(result.numFields == st.length && expected.numFields == st.length) + st.zipWithIndex.forall { case (f, i) => + checkResult(result.get(i, f.dataType), expected.get(i, f.dataType), f.dataType) + } case (result: ArrayData, expected: ArrayData) => result.numElements == expected.numElements && { val et = dataType.asInstanceOf[ArrayType].elementType From 3cb1b57809d0b4a93223669f5c10cea8fc53eff6 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Fri, 20 Jul 2018 12:13:15 -0700 Subject: [PATCH 1179/2461] [SPARK-24852][ML] Update spark.ml to use Instrumentation.instrumented. ## What changes were proposed in this pull request? Followup for #21719. Update spark.ml training code to fully wrap instrumented methods and remove old instrumentation APIs. ## How was this patch tested? existing tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Bago Amirbekian Closes #21799 from MrBago/new-instrumentation-apis2. --- .../DecisionTreeClassifier.scala | 24 +++++----- .../ml/classification/GBTClassifier.scala | 14 +++--- .../spark/ml/classification/LinearSVC.scala | 12 ++--- .../classification/LogisticRegression.scala | 2 +- .../MultilayerPerceptronClassifier.scala | 14 +++--- .../spark/ml/classification/NaiveBayes.scala | 12 ++--- .../spark/ml/classification/OneVsRest.scala | 9 ++-- .../RandomForestClassifier.scala | 13 ++--- .../spark/ml/clustering/BisectingKMeans.scala | 12 ++--- .../spark/ml/clustering/GaussianMixture.scala | 12 ++--- .../apache/spark/ml/clustering/KMeans.scala | 9 ++-- .../org/apache/spark/ml/clustering/LDA.scala | 12 ++--- .../org/apache/spark/ml/fpm/FPGrowth.scala | 12 ++--- .../apache/spark/ml/recommendation/ALS.scala | 9 ++-- .../ml/regression/AFTSurvivalRegression.scala | 13 +++-- .../ml/regression/DecisionTreeRegressor.scala | 24 +++++----- .../spark/ml/regression/GBTRegressor.scala | 12 ++--- .../GeneralizedLinearRegression.scala | 12 +++-- .../ml/regression/IsotonicRegression.scala | 12 ++--- .../ml/regression/LinearRegression.scala | 21 ++++----- .../ml/regression/RandomForestRegressor.scala | 14 +++--- .../spark/ml/tuning/CrossValidator.scala | 9 ++-- .../ml/tuning/TrainValidationSplit.scala | 9 ++-- .../spark/ml/util/Instrumentation.scala | 47 ++----------------- 24 files changed, 153 insertions(+), 186 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index c9786f1f7ceb1..8a57bfc029d14 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -29,6 +29,7 @@ import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._ import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD @@ -96,8 +97,10 @@ class DecisionTreeClassifier @Since("1.4.0") ( @Since("1.6.0") override def setSeed(value: Long): this.type = set(seed, value) - override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = { - val instr = Instrumentation.create(this, dataset) + override protected def train( + dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(dataset) val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = getNumClasses(dataset) @@ -112,30 +115,27 @@ class DecisionTreeClassifier @Since("1.4.0") ( val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) val strategy = getOldStrategy(categoricalFeatures, numClasses) - instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, + instr.logParams(this, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, cacheNodeIds, checkpointInterval, impurity, seed) val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", seed = $(seed), instr = Some(instr), parentUID = Some(uid)) - val m = trees.head.asInstanceOf[DecisionTreeClassificationModel] - instr.logSuccess(m) - m + trees.head.asInstanceOf[DecisionTreeClassificationModel] } /** (private[ml]) Train a decision tree on an RDD */ private[ml] def train(data: RDD[LabeledPoint], - oldStrategy: OldStrategy): DecisionTreeClassificationModel = { - val instr = Instrumentation.create(this, data) - instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, + oldStrategy: OldStrategy): DecisionTreeClassificationModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(data) + instr.logParams(this, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, cacheNodeIds, checkpointInterval, impurity, seed) val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID = Some(uid)) - val m = trees.head.asInstanceOf[DecisionTreeClassificationModel] - instr.logSuccess(m) - m + trees.head.asInstanceOf[DecisionTreeClassificationModel] } /** (private[ml]) Create a Strategy instance to use with the old API. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 337133a2e2326..33acd9914073f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -31,9 +31,9 @@ import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.GradientBoostedTrees import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -152,7 +152,8 @@ class GBTClassifier @Since("1.4.0") ( set(validationIndicatorCol, value) } - override protected def train(dataset: Dataset[_]): GBTClassificationModel = { + override protected def train( + dataset: Dataset[_]): GBTClassificationModel = instrumented { instr => val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) @@ -189,8 +190,9 @@ class GBTClassifier @Since("1.4.0") ( s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") } - val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy, validationIndicatorCol) @@ -204,9 +206,7 @@ class GBTClassifier @Since("1.4.0") ( GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) } - val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) - instr.logSuccess(m) - m + new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) } @Since("1.4.1") diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 38eb04556b775..20f9366862bb8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -33,6 +33,7 @@ import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD @@ -162,7 +163,7 @@ class LinearSVC @Since("2.2.0") ( @Since("2.2.0") override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra) - override protected def train(dataset: Dataset[_]): LinearSVCModel = { + override protected def train(dataset: Dataset[_]): LinearSVCModel = instrumented { instr => val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { @@ -170,8 +171,9 @@ class LinearSVC @Since("2.2.0") ( Instance(label, weight, features) } - val instr = Instrumentation.create(this, dataset) - instr.logParams(regParam, maxIter, fitIntercept, tol, standardization, threshold, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, regParam, maxIter, fitIntercept, tol, standardization, threshold, aggregationDepth) val (summarizer, labelSummarizer) = { @@ -276,9 +278,7 @@ class LinearSVC @Since("2.2.0") ( (Vectors.dense(coefficientArray), intercept, scaledObjectiveHistory.result()) } - val model = copyValues(new LinearSVCModel(uid, coefficientVector, interceptVector)) - instr.logSuccess(model) - model + copyValues(new LinearSVCModel(uid, coefficientVector, interceptVector)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 25fb9c8aab0bc..af651b056f2f7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -503,7 +503,7 @@ class LogisticRegression @Since("1.2.0") ( instr.logPipelineStage(this) instr.logDataset(dataset) - instr.logParams(regParam, elasticNetParam, standardization, threshold, + instr.logParams(this, regParam, elasticNetParam, standardization, threshold, maxIter, tol, fitIntercept) val (summarizer, labelSummarizer) = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 57ba47e596a97..65e3b2d3beb14 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -28,6 +28,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.sql.Dataset /** Params for Multilayer Perceptron. */ @@ -230,9 +231,11 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( * @param dataset Training dataset * @return Fitted model */ - override protected def train(dataset: Dataset[_]): MultilayerPerceptronClassificationModel = { - val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, predictionCol, layers, maxIter, tol, + override protected def train( + dataset: Dataset[_]): MultilayerPerceptronClassificationModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, predictionCol, layers, maxIter, tol, blockSize, solver, stepSize, seed) val myLayers = $(layers) @@ -264,10 +267,7 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( } trainer.setStackSize($(blockSize)) val mlpModel = trainer.train(data) - val model = new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights) - - instr.logSuccess(model) - model + new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 1dde18d2d1a31..f65d3979791a6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -25,6 +25,7 @@ import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.HasWeightCol import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions.{col, lit} @@ -125,8 +126,9 @@ class NaiveBayes @Since("1.5.0") ( */ private[spark] def trainWithLabelCheck( dataset: Dataset[_], - positiveLabel: Boolean): NaiveBayesModel = { - val instr = Instrumentation.create(this, dataset) + positiveLabel: Boolean): NaiveBayesModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(dataset) if (positiveLabel && isDefined(thresholds)) { val numClasses = getNumClasses(dataset) instr.logNumClasses(numClasses) @@ -148,7 +150,7 @@ class NaiveBayes @Since("1.5.0") ( } } - instr.logParams(labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol, + instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol, probabilityCol, modelType, smoothing, thresholds) val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size @@ -204,9 +206,7 @@ class NaiveBayes @Since("1.5.0") ( val pi = Vectors.dense(piArray) val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true) - val model = new NaiveBayesModel(uid, pi, theta).setOldLabels(labelArray) - instr.logSuccess(model) - model + new NaiveBayesModel(uid, pi, theta).setOldLabels(labelArray) } @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 3474b61e40136..1835a91775e0a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -36,6 +36,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol} import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -362,11 +363,12 @@ final class OneVsRest @Since("1.4.0") ( } @Since("2.0.0") - override def fit(dataset: Dataset[_]): OneVsRestModel = { + override def fit(dataset: Dataset[_]): OneVsRestModel = instrumented { instr => transformSchema(dataset.schema) - val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, predictionCol, parallelism, rawPredictionCol) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, predictionCol, parallelism, rawPredictionCol) instr.logNamedValue("classifier", $(classifier).getClass.getCanonicalName) // determine number of classes either from metadata if provided, or via computation. @@ -440,7 +442,6 @@ final class OneVsRest @Since("1.4.0") ( case attr: Attribute => attr } val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this) - instr.logSuccess(model) copyValues(model) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 040db3b94b041..94887ac346fec 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -28,6 +28,7 @@ import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD @@ -115,8 +116,10 @@ class RandomForestClassifier @Since("1.4.0") ( override def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) - override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = { - val instr = Instrumentation.create(this, dataset) + override protected def train( + dataset: Dataset[_]): RandomForestClassificationModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(dataset) val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = getNumClasses(dataset) @@ -131,7 +134,7 @@ class RandomForestClassifier @Since("1.4.0") ( val strategy = super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) - instr.logParams(labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol, + instr.logParams(this, labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol, impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval) @@ -140,11 +143,9 @@ class RandomForestClassifier @Since("1.4.0") ( .map(_.asInstanceOf[DecisionTreeClassificationModel]) val numFeatures = oldDataset.first().features.size - val m = new RandomForestClassificationModel(uid, trees, numFeatures, numClasses) instr.logNumClasses(numClasses) instr.logNumFeatures(numFeatures) - instr.logSuccess(m) - m + new RandomForestClassificationModel(uid, trees, numFeatures, numClasses) } @Since("1.4.1") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index de564471c2b4d..48b8c52dbffd7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -26,6 +26,7 @@ import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} import org.apache.spark.mllib.linalg.VectorImplicits._ @@ -257,12 +258,13 @@ class BisectingKMeans @Since("2.0.0") ( def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): BisectingKMeansModel = { + override def fit(dataset: Dataset[_]): BisectingKMeansModel = instrumented { instr => transformSchema(dataset.schema, logging = true) val rdd = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) - val instr = Instrumentation.create(this, dataset) - instr.logParams(featuresCol, predictionCol, k, maxIter, seed, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize, distanceMeasure) val bkm = new MLlibBisectingKMeans() @@ -275,10 +277,8 @@ class BisectingKMeans @Since("2.0.0") ( val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this)) val summary = new BisectingKMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k), $(maxIter)) - model.setSummary(Some(summary)) instr.logNamedValue("clusterSizes", summary.clusterSizes) - instr.logSuccess(model) - model + model.setSummary(Some(summary)) } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index f0707b380c673..310b03b15822c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -29,6 +29,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.stat.distribution.MultivariateGaussian import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix, Vector => OldVector, Vectors => OldVectors} import org.apache.spark.rdd.RDD @@ -335,7 +336,7 @@ class GaussianMixture @Since("2.0.0") ( private val numSamples = 5 @Since("2.0.0") - override def fit(dataset: Dataset[_]): GaussianMixtureModel = { + override def fit(dataset: Dataset[_]): GaussianMixtureModel = instrumented { instr => transformSchema(dataset.schema, logging = true) val sc = dataset.sparkSession.sparkContext @@ -352,8 +353,9 @@ class GaussianMixture @Since("2.0.0") ( s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" + s" matrix is quadratic in the number of features.") - val instr = Instrumentation.create(this, dataset) - instr.logParams(featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol) instr.logNumFeatures(numFeatures) val shouldDistributeGaussians = GaussianMixture.shouldDistributeGaussians( @@ -425,11 +427,9 @@ class GaussianMixture @Since("2.0.0") ( val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists)).setParent(this) val summary = new GaussianMixtureSummary(model.transform(dataset), $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iter) - model.setSummary(Some(summary)) instr.logNamedValue("logLikelihood", logLikelihood) instr.logNamedValue("clusterSizes", summary.clusterSizes) - instr.logSuccess(model) - model + model.setSummary(Some(summary)) } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 6f4a30d4595a1..498310d6644e1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -28,6 +28,7 @@ import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ @@ -336,7 +337,7 @@ class KMeans @Since("1.5.0") ( def setSeed(value: Long): this.type = set(seed, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): KMeansModel = { + override def fit(dataset: Dataset[_]): KMeansModel = instrumented { instr => transformSchema(dataset.schema, logging = true) val handlePersistence = dataset.storageLevel == StorageLevel.NONE @@ -346,8 +347,9 @@ class KMeans @Since("1.5.0") ( instances.persist(StorageLevel.MEMORY_AND_DISK) } - val instr = Instrumentation.create(this, dataset) - instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure, maxIter, seed, tol) val algo = new MLlibKMeans() .setK($(k)) @@ -369,7 +371,6 @@ class KMeans @Since("1.5.0") ( model.setSummary(Some(summary)) instr.logNamedValue("clusterSizes", summary.clusterSizes) - instr.logSuccess(model) if (handlePersistence) { instances.unpersist() } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index fed42c959b5ef..50867f776c522 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -32,6 +32,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed} import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, @@ -896,11 +897,12 @@ class LDA @Since("1.6.0") ( override def copy(extra: ParamMap): LDA = defaultCopy(extra) @Since("2.0.0") - override def fit(dataset: Dataset[_]): LDAModel = { + override def fit(dataset: Dataset[_]): LDAModel = instrumented { instr => transformSchema(dataset.schema, logging = true) - val instr = Instrumentation.create(this, dataset) - instr.logParams(featuresCol, topicDistributionCol, k, maxIter, subsamplingRate, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, featuresCol, topicDistributionCol, k, maxIter, subsamplingRate, checkpointInterval, keepLastCheckpoint, optimizeDocConcentration, topicConcentration, learningDecay, optimizer, learningOffset, seed) @@ -923,9 +925,7 @@ class LDA @Since("1.6.0") ( } instr.logNumFeatures(newModel.vocabSize) - val model = copyValues(newModel).setParent(this) - instr.logSuccess(model) - model + copyValues(newModel).setParent(this) } @Since("1.6.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 9d664b6ca6d2a..85c483c387ad8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -26,6 +26,7 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.HasPredictionCol import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules, FPGrowth => MLlibFPGrowth} import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset @@ -158,11 +159,12 @@ class FPGrowth @Since("2.2.0") ( genericFit(dataset) } - private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = { + private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = instrumented { instr => val handlePersistence = dataset.storageLevel == StorageLevel.NONE - val instr = Instrumentation.create(this, dataset) - instr.logParams(params: _*) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, params: _*) val data = dataset.select($(itemsCol)) val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[Any](0).toArray) val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport)) @@ -185,9 +187,7 @@ class FPGrowth @Since("2.2.0") ( items.unpersist() } - val model = copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) - instr.logSuccess(model) - model + copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) } @Since("2.2.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index a23f9552b9e5f..ffe592789b3cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -39,6 +39,7 @@ import org.apache.spark.ml.linalg.BLAS import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD @@ -654,7 +655,7 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] } @Since("2.0.0") - override def fit(dataset: Dataset[_]): ALSModel = { + override def fit(dataset: Dataset[_]): ALSModel = instrumented { instr => transformSchema(dataset.schema) import dataset.sparkSession.implicits._ @@ -666,8 +667,9 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) } - val instr = Instrumentation.create(this, ratings) - instr.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha, userCol, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha, userCol, itemCol, ratingCol, predictionCol, maxIter, regParam, nonnegative, checkpointInterval, seed, intermediateStorageLevel, finalStorageLevel) @@ -681,7 +683,6 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] val userDF = userFactors.toDF("id", "features") val itemDF = itemFactors.toDF("id", "features") val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this) - instr.logSuccess(model) copyValues(model) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index e27a96e1f5dfc..3cd07063ee32d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -32,6 +32,7 @@ import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils @@ -210,7 +211,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S } @Since("2.0.0") - override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = { + override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = instrumented { instr => transformSchema(dataset.schema, logging = true) val instances = extractAFTPoints(dataset) val handlePersistence = dataset.storageLevel == StorageLevel.NONE @@ -229,8 +230,9 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) val numFeatures = featuresStd.size - val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, censorCol, predictionCol, quantilesCol, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, censorCol, predictionCol, quantilesCol, fitIntercept, maxIter, tol, aggregationDepth) instr.logNamedValue("quantileProbabilities.size", $(quantileProbabilities).length) instr.logNumFeatures(numFeatures) @@ -284,10 +286,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S val coefficients = Vectors.dense(rawCoefficients) val intercept = parameters(1) val scale = math.exp(parameters(0)) - val model = copyValues(new AFTSurvivalRegressionModel(uid, coefficients, - intercept, scale).setParent(this)) - instr.logSuccess(model) - model + copyValues(new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale).setParent(this)) } @Since("1.6.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 8bcf0793a64c1..018290f81842f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -30,6 +30,7 @@ import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._ import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD @@ -99,37 +100,36 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S @Since("2.0.0") def setVarianceCol(value: String): this.type = set(varianceCol, value) - override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = { + override protected def train( + dataset: Dataset[_]): DecisionTreeRegressionModel = instrumented { instr => val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = getOldStrategy(categoricalFeatures) - val instr = Instrumentation.create(this, oldDataset) - instr.logParams(params: _*) + instr.logPipelineStage(this) + instr.logDataset(oldDataset) + instr.logParams(this, params: _*) val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", seed = $(seed), instr = Some(instr), parentUID = Some(uid)) - val m = trees.head.asInstanceOf[DecisionTreeRegressionModel] - instr.logSuccess(m) - m + trees.head.asInstanceOf[DecisionTreeRegressionModel] } /** (private[ml]) Train a decision tree on an RDD */ private[ml] def train( data: RDD[LabeledPoint], oldStrategy: OldStrategy, - featureSubsetStrategy: String): DecisionTreeRegressionModel = { - val instr = Instrumentation.create(this, data) - instr.logParams(params: _*) + featureSubsetStrategy: String): DecisionTreeRegressionModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(data) + instr.logParams(this, params: _*) val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy, seed = $(seed), instr = Some(instr), parentUID = Some(uid)) - val m = trees.head.asInstanceOf[DecisionTreeRegressionModel] - instr.logSuccess(m) - m + trees.head.asInstanceOf[DecisionTreeRegressionModel] } /** (private[ml]) Create a Strategy instance to use with the old API. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index eb8b3c001436a..3305881b0ccc6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -31,6 +31,7 @@ import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.GradientBoostedTrees import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD @@ -151,7 +152,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) set(validationIndicatorCol, value) } - override protected def train(dataset: Dataset[_]): GBTRegressionModel = { + override protected def train(dataset: Dataset[_]): GBTRegressionModel = instrumented { instr => val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) @@ -168,8 +169,9 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) val numFeatures = trainDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) - val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) instr.logNumFeatures(numFeatures) @@ -181,9 +183,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) } - val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) - instr.logSuccess(m) - m + new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) } @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 143c8a3548b1f..20878b6448920 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -34,6 +34,7 @@ import org.apache.spark.ml.optim._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -373,13 +374,15 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val @Since("2.0.0") def setLinkPredictionCol(value: String): this.type = set(linkPredictionCol, value) - override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = { + override protected def train( + dataset: Dataset[_]): GeneralizedLinearRegressionModel = instrumented { instr => val familyAndLink = FamilyAndLink(this) val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size - val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, weightCol, offsetCol, predictionCol, linkPredictionCol, - family, solver, fitIntercept, link, maxIter, regParam, tol) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, weightCol, offsetCol, predictionCol, + linkPredictionCol, family, solver, fitIntercept, link, maxIter, regParam, tol) instr.logNumFeatures(numFeatures) if (numFeatures > WeightedLeastSquares.MAX_NUM_FEATURES) { @@ -431,7 +434,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val model.setSummary(Some(trainingSummary)) } - instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index b046897ab2b7e..8b9233dcdc4d1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -27,6 +27,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.regression.IsotonicRegressionModel.IsotonicRegressionModelWriter import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} import org.apache.spark.rdd.RDD @@ -161,15 +162,16 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) @Since("2.0.0") - override def fit(dataset: Dataset[_]): IsotonicRegressionModel = { + override def fit(dataset: Dataset[_]): IsotonicRegressionModel = instrumented { instr => transformSchema(dataset.schema, logging = true) // Extract columns from data. If dataset is persisted, do not persist oldDataset. val instances = extractWeightedLabeledPoints(dataset) val handlePersistence = dataset.storageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, weightCol, predictionCol, featureIndex, isotonic) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, featureIndex, isotonic) instr.logNumFeatures(1) val isotonicRegression = new MLlibIsotonicRegression().setIsotonic($(isotonic)) @@ -177,9 +179,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri if (handlePersistence) instances.unpersist() - val model = copyValues(new IsotonicRegressionModel(uid, oldModel).setParent(this)) - instr.logSuccess(model) - model + copyValues(new IsotonicRegressionModel(uid, oldModel).setParent(this)) } @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index c45ade94a4e33..ce6c12cc368dd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -37,6 +37,7 @@ import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel} @@ -315,7 +316,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String def setEpsilon(value: Double): this.type = set(epsilon, value) setDefault(epsilon -> 1.35) - override protected def train(dataset: Dataset[_]): LinearRegressionModel = { + override protected def train(dataset: Dataset[_]): LinearRegressionModel = instrumented { instr => // Extract the number of features before deciding optimization solver. val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) @@ -326,9 +327,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String Instance(label, weight, features) } - val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, weightCol, predictionCol, solver, tol, elasticNetParam, - fitIntercept, maxIter, regParam, standardization, aggregationDepth, loss, epsilon) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, solver, tol, + elasticNetParam, fitIntercept, maxIter, regParam, standardization, aggregationDepth, loss, + epsilon) instr.logNumFeatures(numFeatures) if ($(loss) == SquaredError && (($(solver) == Auto && @@ -353,9 +356,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model.diagInvAtWA.toArray, model.objectiveHistory) - lrModel.setSummary(Some(trainingSummary)) - instr.logSuccess(lrModel) - return lrModel + return lrModel.setSummary(Some(trainingSummary)) } val handlePersistence = dataset.storageLevel == StorageLevel.NONE @@ -415,9 +416,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String Array(0D), Array(0D)) - model.setSummary(Some(trainingSummary)) - instr.logSuccess(model) - return model + return model.setSummary(Some(trainingSummary)) } else { require($(regParam) == 0.0, "The standard deviation of the label is zero. " + "Model cannot be regularized.") @@ -596,8 +595,6 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String objectiveHistory) model.setSummary(Some(trainingSummary)) - instr.logSuccess(model) - model } @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 4509f85aafd12..35875724b3cfa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -29,6 +29,7 @@ import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD @@ -114,15 +115,17 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S override def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) - override protected def train(dataset: Dataset[_]): RandomForestRegressionModel = { + override protected def train( + dataset: Dataset[_]): RandomForestRegressionModel = instrumented { instr => val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity) - val instr = Instrumentation.create(this, oldDataset) - instr.logParams(labelCol, featuresCol, predictionCol, impurity, numTrees, + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, subsamplingRate, cacheNodeIds, checkpointInterval) @@ -131,9 +134,8 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S .map(_.asInstanceOf[DecisionTreeRegressionModel]) val numFeatures = oldDataset.first().features.size - val m = new RandomForestRegressionModel(uid, trees, numFeatures) - instr.logSuccess(m) - m + instr.logNamedValue(Instrumentation.loggerTags.numFeatures, numFeatures) + new RandomForestRegressionModel(uid, trees, numFeatures) } @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index f327f37bad204..e60a14f976a5c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -33,6 +33,7 @@ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism} import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType @@ -118,7 +119,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): CrossValidatorModel = { + override def fit(dataset: Dataset[_]): CrossValidatorModel = instrumented { instr => val schema = dataset.schema transformSchema(schema, logging = true) val sparkSession = dataset.sparkSession @@ -129,8 +130,9 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) // Create execution context based on $(parallelism) val executionContext = getExecutionContext - val instr = Instrumentation.create(this, dataset) - instr.logParams(numFolds, seed, parallelism) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, numFolds, seed, parallelism) logTuningParams(instr) val collectSubModelsParam = $(collectSubModels) @@ -176,7 +178,6 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}") instr.logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] - instr.logSuccess(bestModel) copyValues(new CrossValidatorModel(uid, bestModel, metrics) .setSubModels(subModels).setParent(this)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 14d6a69c36747..8b251197afbef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -34,6 +34,7 @@ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism} import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.ThreadUtils @@ -117,7 +118,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): TrainValidationSplitModel = { + override def fit(dataset: Dataset[_]): TrainValidationSplitModel = instrumented { instr => val schema = dataset.schema transformSchema(schema, logging = true) val est = $(estimator) @@ -127,8 +128,9 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St // Create execution context based on $(parallelism) val executionContext = getExecutionContext - val instr = Instrumentation.create(this, dataset) - instr.logParams(trainRatio, seed, parallelism) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, trainRatio, seed, parallelism) logTuningParams(instr) val Array(trainingDataset, validationDataset) = @@ -172,7 +174,6 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}") instr.logInfo(s"Best train validation split metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] - instr.logSuccess(bestModel) copyValues(new TrainValidationSplitModel(uid, bestModel, metrics) .setSubModels(subModels).setParent(this)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index 2e43a9ef49ee1..49654918bd8f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -27,7 +27,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.internal.Logging -import org.apache.spark.ml.{Estimator, Model, PipelineStage} +import org.apache.spark.ml.PipelineStage import org.apache.spark.ml.param.{Param, Params} import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset @@ -37,26 +37,16 @@ import org.apache.spark.util.Utils * A small wrapper that defines a training session for an estimator, and some methods to log * useful information during this session. */ -private[spark] class Instrumentation extends Logging { +private[spark] class Instrumentation private () extends Logging { private val id = UUID.randomUUID() private val shortId = id.toString.take(8) - private val prefix = s"[$shortId] " - - // TODO: remove stage - var stage: Params = _ - // TODO: update spark.ml to use new Instrumentation APIs and remove this constructor - private def this(estimator: Estimator[_], dataset: RDD[_]) = { - this() - logPipelineStage(estimator) - logDataset(dataset) - } + private[util] val prefix = s"[$shortId] " /** * Log some info about the pipeline stage being fit. */ def logPipelineStage(stage: PipelineStage): Unit = { - this.stage = stage // estimator.getClass.getSimpleName can cause Malformed class name error, // call safer `Utils.getSimpleName` instead val className = Utils.getSimpleName(stage.getClass) @@ -119,13 +109,6 @@ private[spark] class Instrumentation extends Logging { logInfo(compact(render(map2jvalue(pairs.toMap)))) } - // TODO: remove this - def logParams(params: Param[_]*): Unit = { - require(stage != null, "`logStageParams` must be called before `logParams` (or an instance of" + - " Params must be provided explicitly).") - logParams(stage, params: _*) - } - def logNumFeatures(num: Long): Unit = { logNamedValue(Instrumentation.loggerTags.numFeatures, num) } @@ -166,14 +149,9 @@ private[spark] class Instrumentation extends Logging { } - // TODO: Remove this (possibly replace with logModel?) /** * Logs the successful completion of the training session. */ - def logSuccess(model: Model[_]): Unit = { - logInfo(s"training finished") - } - def logSuccess(): Unit = { logInfo("training finished") } @@ -200,22 +178,6 @@ private[spark] object Instrumentation { val varianceOfLabels = "varianceOfLabels" } - // TODO: Remove these - /** - * Creates an instrumentation object for a training session. - */ - def create(estimator: Estimator[_], dataset: Dataset[_]): Instrumentation = { - create(estimator, dataset.rdd) - } - - /** - * Creates an instrumentation object for a training session. - */ - def create(estimator: Estimator[_], dataset: RDD[_]): Instrumentation = { - new Instrumentation(estimator, dataset) - } - // end remove - def instrumented[T](body: (Instrumentation => T)): T = { val instr = new Instrumentation() Try(body(instr)) match { @@ -268,8 +230,7 @@ private[spark] object OptionalInstrumentation { * Creates an `OptionalInstrumentation` object from an existing `Instrumentation` object. */ def create(instr: Instrumentation): OptionalInstrumentation = { - new OptionalInstrumentation(Some(instr), - instr.stage.getClass.getName.stripSuffix("$")) + new OptionalInstrumentation(Some(instr), instr.prefix) } /** From 9ad77b3037b476b726b773c38d1cd264d89d51e2 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 20 Jul 2018 12:55:38 -0700 Subject: [PATCH 1180/2461] Revert "[SPARK-24811][SQL] Avro: add new function from_avro and to_avro" This reverts commit 244bcff19463d82ec72baf15bc0a5209f21f2ef3. --- .../spark/sql/avro/AvroDataToCatalyst.scala | 68 ------- .../spark/sql/avro/CatalystDataToAvro.scala | 69 ------- .../org/apache/spark/sql/avro/package.scala | 31 ---- .../AvroCatalystDataConversionSuite.scala | 175 ------------------ .../spark/sql/avro/AvroFunctionsSuite.scala | 83 --------- .../expressions/ExpressionEvalHelper.scala | 6 - 6 files changed, 432 deletions(-) delete mode 100644 external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala delete mode 100644 external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala delete mode 100644 external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala delete mode 100644 external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala deleted file mode 100644 index 6671b3fb8705c..0000000000000 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.avro.Schema -import org.apache.avro.generic.GenericDatumReader -import org.apache.avro.io.{BinaryDecoder, DecoderFactory} - -import org.apache.spark.sql.avro.{AvroDeserializer, SchemaConverters} -import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType} - -case class AvroDataToCatalyst(child: Expression, jsonFormatSchema: String) - extends UnaryExpression with ExpectsInputTypes { - - override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType) - - override lazy val dataType: DataType = SchemaConverters.toSqlType(avroSchema).dataType - - override def nullable: Boolean = true - - @transient private lazy val avroSchema = new Schema.Parser().parse(jsonFormatSchema) - - @transient private lazy val reader = new GenericDatumReader[Any](avroSchema) - - @transient private lazy val deserializer = new AvroDeserializer(avroSchema, dataType) - - @transient private var decoder: BinaryDecoder = _ - - @transient private var result: Any = _ - - override def nullSafeEval(input: Any): Any = { - val binary = input.asInstanceOf[Array[Byte]] - decoder = DecoderFactory.get().binaryDecoder(binary, 0, binary.length, decoder) - result = reader.read(result, decoder) - deserializer.deserialize(result) - } - - override def simpleString: String = { - s"from_avro(${child.sql}, ${dataType.simpleString})" - } - - override def sql: String = { - s"from_avro(${child.sql}, ${dataType.catalogString})" - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val expr = ctx.addReferenceObj("this", this) - defineCodeGen(ctx, ev, input => - s"(${CodeGenerator.boxedType(dataType)})$expr.nullSafeEval($input)") - } -} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala deleted file mode 100644 index a669388e88258..0000000000000 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.io.ByteArrayOutputStream - -import org.apache.avro.generic.GenericDatumWriter -import org.apache.avro.io.{BinaryEncoder, EncoderFactory} - -import org.apache.spark.sql.avro.{AvroSerializer, SchemaConverters} -import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.types.{BinaryType, DataType} - -case class CatalystDataToAvro(child: Expression) extends UnaryExpression { - - override def dataType: DataType = BinaryType - - @transient private lazy val avroType = - SchemaConverters.toAvroType(child.dataType, child.nullable) - - @transient private lazy val serializer = - new AvroSerializer(child.dataType, avroType, child.nullable) - - @transient private lazy val writer = - new GenericDatumWriter[Any](avroType) - - @transient private var encoder: BinaryEncoder = _ - - @transient private lazy val out = new ByteArrayOutputStream - - override def nullSafeEval(input: Any): Any = { - out.reset() - encoder = EncoderFactory.get().directBinaryEncoder(out, encoder) - val avroData = serializer.serialize(input) - writer.write(avroData, encoder) - encoder.flush() - out.toByteArray - } - - override def simpleString: String = { - s"to_avro(${child.sql}, ${child.dataType.simpleString})" - } - - override def sql: String = { - s"to_avro(${child.sql}, ${child.dataType.catalogString})" - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val expr = ctx.addReferenceObj("this", this) - defineCodeGen(ctx, ev, input => - s"(byte[]) $expr.nullSafeEval($input)") - } -} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala index e82651d96a03d..b3c8a669cf820 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql -import org.apache.avro.Schema - -import org.apache.spark.annotation.Experimental - package object avro { /** * Adds a method, `avro`, to DataFrameWriter that allows you to write avro files using @@ -40,31 +36,4 @@ package object avro { @scala.annotation.varargs def avro(sources: String*): DataFrame = reader.format("avro").load(sources: _*) } - - /** - * Converts a binary column of avro format into its corresponding catalyst value. The specified - * schema must match the read data, otherwise the behavior is undefined: it may fail or return - * arbitrary result. - * - * @param data the binary column. - * @param jsonFormatSchema the avro schema in JSON string format. - * - * @since 2.4.0 - */ - @Experimental - def from_avro(data: Column, jsonFormatSchema: String): Column = { - new Column(AvroDataToCatalyst(data.expr, jsonFormatSchema)) - } - - /** - * Converts a column into binary of avro format. - * - * @param data the data column. - * - * @since 2.4.0 - */ - @Experimental - def to_avro(data: Column): Column = { - new Column(CatalystDataToAvro(data.expr)) - } } diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala deleted file mode 100644 index 06d5477b2ea45..0000000000000 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.avro - -import org.apache.avro.Schema - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{AvroDataToCatalyst, CatalystDataToAvro, RandomDataGenerator} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, GenericInternalRow, Literal} -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -class AvroCatalystDataConversionSuite extends SparkFunSuite with ExpressionEvalHelper { - - private def roundTripTest(data: Literal): Unit = { - val avroType = SchemaConverters.toAvroType(data.dataType, data.nullable) - checkResult(data, avroType.toString, data.eval()) - } - - private def checkResult(data: Literal, schema: String, expected: Any): Unit = { - checkEvaluation( - AvroDataToCatalyst(CatalystDataToAvro(data), schema), - prepareExpectedResult(expected)) - } - - private def assertFail(data: Literal, schema: String): Unit = { - intercept[java.io.EOFException] { - AvroDataToCatalyst(CatalystDataToAvro(data), schema).eval() - } - } - - private val testingTypes = Seq( - BooleanType, - ByteType, - ShortType, - IntegerType, - LongType, - FloatType, - DoubleType, - DecimalType(8, 0), // 32 bits decimal without fraction - DecimalType(8, 4), // 32 bits decimal - DecimalType(16, 0), // 64 bits decimal without fraction - DecimalType(16, 11), // 64 bits decimal - DecimalType(38, 0), - DecimalType(38, 38), - StringType, - BinaryType) - - protected def prepareExpectedResult(expected: Any): Any = expected match { - // Spark decimal is converted to avro string= - case d: Decimal => UTF8String.fromString(d.toString) - // Spark byte and short both map to avro int - case b: Byte => b.toInt - case s: Short => s.toInt - case row: GenericInternalRow => InternalRow.fromSeq(row.values.map(prepareExpectedResult)) - case array: GenericArrayData => new GenericArrayData(array.array.map(prepareExpectedResult)) - case map: MapData => - val keys = new GenericArrayData( - map.keyArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult)) - val values = new GenericArrayData( - map.valueArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult)) - new ArrayBasedMapData(keys, values) - case other => other - } - - testingTypes.foreach { dt => - val seed = scala.util.Random.nextLong() - test(s"single $dt with seed $seed") { - val rand = new scala.util.Random(seed) - val data = RandomDataGenerator.forType(dt, rand = rand).get.apply() - val converter = CatalystTypeConverters.createToCatalystConverter(dt) - val input = Literal.create(converter(data), dt) - roundTripTest(input) - } - } - - for (_ <- 1 to 5) { - val seed = scala.util.Random.nextLong() - val rand = new scala.util.Random(seed) - val schema = RandomDataGenerator.randomSchema(rand, 5, testingTypes) - test(s"flat schema ${schema.catalogString} with seed $seed") { - val data = RandomDataGenerator.randomRow(rand, schema) - val converter = CatalystTypeConverters.createToCatalystConverter(schema) - val input = Literal.create(converter(data), schema) - roundTripTest(input) - } - } - - for (_ <- 1 to 5) { - val seed = scala.util.Random.nextLong() - val rand = new scala.util.Random(seed) - val schema = RandomDataGenerator.randomNestedSchema(rand, 10, testingTypes) - test(s"nested schema ${schema.catalogString} with seed $seed") { - val data = RandomDataGenerator.randomRow(rand, schema) - val converter = CatalystTypeConverters.createToCatalystConverter(schema) - val input = Literal.create(converter(data), schema) - roundTripTest(input) - } - } - - test("read int as string") { - val data = Literal(1) - val avroTypeJson = - s""" - |{ - | "type": "string", - | "name": "my_string" - |} - """.stripMargin - - // When read int as string, avro reader is not able to parse the binary and fail. - assertFail(data, avroTypeJson) - } - - test("read string as int") { - val data = Literal("abc") - val avroTypeJson = - s""" - |{ - | "type": "int", - | "name": "my_int" - |} - """.stripMargin - - // When read string data as int, avro reader is not able to find the type mismatch and read - // the string length as int value. - checkResult(data, avroTypeJson, 3) - } - - test("read float as double") { - val data = Literal(1.23f) - val avroTypeJson = - s""" - |{ - | "type": "double", - | "name": "my_double" - |} - """.stripMargin - - // When read float data as double, avro reader fails(trying to read 8 bytes while the data have - // only 4 bytes). - assertFail(data, avroTypeJson) - } - - test("read double as float") { - val data = Literal(1.23) - val avroTypeJson = - s""" - |{ - | "type": "float", - | "name": "my_float" - |} - """.stripMargin - - // avro reader reads the first 4 bytes of a double as a float, the result is totally undefined. - checkResult(data, avroTypeJson, 5.848603E35f) - } -} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala deleted file mode 100644 index 90a4cd6ccf9dd..0000000000000 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.avro - -import org.apache.avro.Schema - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.functions.struct -import org.apache.spark.sql.test.SharedSQLContext - -class AvroFunctionsSuite extends QueryTest with SharedSQLContext { - import testImplicits._ - - test("roundtrip in to_avro and from_avro - int and string") { - val df = spark.range(10).select('id, 'id.cast("string").as("str")) - - val avroDF = df.select(to_avro('id).as("a"), to_avro('str).as("b")) - val avroTypeLong = s""" - |{ - | "type": "int", - | "name": "id" - |} - """.stripMargin - val avroTypeStr = s""" - |{ - | "type": "string", - | "name": "str" - |} - """.stripMargin - checkAnswer(avroDF.select(from_avro('a, avroTypeLong), from_avro('b, avroTypeStr)), df) - } - - test("roundtrip in to_avro and from_avro - struct") { - val df = spark.range(10).select(struct('id, 'id.cast("string").as("str")).as("struct")) - val avroStructDF = df.select(to_avro('struct).as("avro")) - val avroTypeStruct = s""" - |{ - | "type": "record", - | "name": "struct", - | "fields": [ - | {"name": "col1", "type": "long"}, - | {"name": "col2", "type": "string"} - | ] - |} - """.stripMargin - checkAnswer(avroStructDF.select(from_avro('avro, avroTypeStruct)), df) - } - - test("roundtrip in to_avro and from_avro - array with null") { - val dfOne = Seq(Tuple1(Tuple1(1) :: Nil), Tuple1(null :: Nil)).toDF("array") - val avroTypeArrStruct = s""" - |[ { - | "type" : "array", - | "items" : [ { - | "type" : "record", - | "name" : "x", - | "fields" : [ { - | "name" : "y", - | "type" : "int" - | } ] - | }, "null" ] - |}, "null" ] - """.stripMargin - val readBackOne = dfOne.select(to_avro($"array").as("avro")) - .select(from_avro($"avro", avroTypeArrStruct).as("array")) - checkAnswer(dfOne, readBackOne) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index d045267ef5d9e..14bfa212b5496 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -79,12 +79,6 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa java.util.Arrays.equals(result, expected) case (result: Double, expected: Spread[Double @unchecked]) => expected.asInstanceOf[Spread[Double]].isWithin(result) - case (result: InternalRow, expected: InternalRow) => - val st = dataType.asInstanceOf[StructType] - assert(result.numFields == st.length && expected.numFields == st.length) - st.zipWithIndex.forall { case (f, i) => - checkResult(result.get(i, f.dataType), expected.get(i, f.dataType), f.dataType) - } case (result: ArrayData, expected: ArrayData) => result.numElements == expected.numElements && { val et = dataType.asInstanceOf[ArrayType].elementType From 2333a34d390f2fa19b939b8007be0deb31f31d3c Mon Sep 17 00:00:00 2001 From: Daniel van der Ende Date: Fri, 20 Jul 2018 13:03:57 -0700 Subject: [PATCH 1181/2461] [SPARK-22880][SQL] Add cascadeTruncate option to JDBC datasource This commit adds the `cascadeTruncate` option to the JDBC datasource API, for databases that support this functionality (PostgreSQL and Oracle at the moment). This allows for applying a cascading truncate that affects tables that have foreign key constraints on the table being truncated. ## What changes were proposed in this pull request? Add `cascadeTruncate` option to JDBC datasource API. Allow this to affect the `TRUNCATE` query for databases that support this option. ## How was this patch tested? Existing tests for `truncateQuery` were updated. Also, an additional test was added to ensure that the correct syntax was applied, and that enabling the config for databases that do not support this option does not result in invalid queries. Author: Daniel van der Ende Closes #20057 from danielvdende/SPARK-22880. --- docs/sql-programming-guide.md | 7 +++ .../datasources/jdbc/JDBCOptions.scala | 3 ++ .../datasources/jdbc/JdbcUtils.scala | 7 ++- .../spark/sql/jdbc/AggregatedDialect.scala | 13 +++++- .../apache/spark/sql/jdbc/DerbyDialect.scala | 2 + .../apache/spark/sql/jdbc/JdbcDialects.scala | 20 +++++++- .../apache/spark/sql/jdbc/OracleDialect.scala | 16 +++++++ .../spark/sql/jdbc/PostgresDialect.scala | 29 ++++++++---- .../spark/sql/jdbc/TeradataDialect.scala | 18 ++++++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 46 ++++++++++++++++--- 10 files changed, 140 insertions(+), 21 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index ad23dae7c6b7c..4bab58aff0067 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1407,6 +1407,13 @@ the following case-insensitive options: This is a JDBC writer related option. When SaveMode.Overwrite is enabled, this option causes Spark to truncate an existing table instead of dropping and recreating it. This can be more efficient, and prevents the table metadata (e.g., indices) from being removed. However, it will not work in some cases, such as when the new data has a different schema. It defaults to false. This option applies only to writing.
      + + + + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index eea966d30948b..574aed4958fd7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -157,6 +157,8 @@ class JDBCOptions( // ------------------------------------------------------------ // if to truncate the table from the JDBC database val isTruncate = parameters.getOrElse(JDBC_TRUNCATE, "false").toBoolean + + val isCascadeTruncate: Option[Boolean] = parameters.get(JDBC_CASCADE_TRUNCATE).map(_.toBoolean) // the create table option , which can be table_options or partition_options. // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" // TODO: to reuse the existing partition parameters for those partition specific options @@ -225,6 +227,7 @@ object JDBCOptions { val JDBC_QUERY_TIMEOUT = newOption("queryTimeout") val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize") val JDBC_TRUNCATE = newOption("truncate") + val JDBC_CASCADE_TRUNCATE = newOption("cascadeTruncate") val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes") val JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES = newOption("customSchema") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 6cc7922396d45..edea549748b47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -105,7 +105,12 @@ object JdbcUtils extends Logging { val statement = conn.createStatement try { statement.setQueryTimeout(options.queryTimeout) - statement.executeUpdate(dialect.getTruncateQuery(options.table)) + val truncateQuery = if (options.isCascadeTruncate.isDefined) { + dialect.getTruncateQuery(options.table, options.isCascadeTruncate) + } else { + dialect.getTruncateQuery(options.table) + } + statement.executeUpdate(truncateQuery) } finally { statement.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index 8b92c8b4f56b5..3a3246a1b1d13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -64,7 +64,16 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect } } - override def getTruncateQuery(table: String): String = { - dialects.head.getTruncateQuery(table) + /** + * The SQL query used to truncate a table. + * @param table The table to truncate. + * @param cascade Whether or not to cascade the truncation. Default value is the + * value of isCascadingTruncateTable() + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + dialects.head.getTruncateQuery(table, cascade) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index 84f68e779c38c..d13c29ed46bd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -41,4 +41,6 @@ private object DerbyDialect extends JdbcDialect { Option(JdbcType("DECIMAL(31,5)", java.sql.Types.DECIMAL)) case _ => None } + + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 83d87a11810c1..f76c1fae562c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -22,6 +22,7 @@ import java.sql.{Connection, Date, Timestamp} import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} +import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions import org.apache.spark.sql.types._ /** @@ -120,12 +121,27 @@ abstract class JdbcDialect extends Serializable { * The SQL query that should be used to truncate a table. Dialects can override this method to * return a query that is suitable for a particular database. For PostgreSQL, for instance, * a different query is used to prevent "TRUNCATE" affecting other tables. - * @param table The name of the table. + * @param table The table to truncate * @return The SQL query to use for truncating a table */ @Since("2.3.0") def getTruncateQuery(table: String): String = { - s"TRUNCATE TABLE $table" + getTruncateQuery(table, isCascadingTruncateTable) + } + + /** + * The SQL query that should be used to truncate a table. Dialects can override this method to + * return a query that is suitable for a particular database. For PostgreSQL, for instance, + * a different query is used to prevent "TRUNCATE" affecting other tables. + * @param table The table to truncate + * @param cascade Whether or not to cascade the truncation + * @return The SQL query to use for truncating a table + */ + @Since("2.4.0") + def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + s"TRUNCATE TABLE $table" } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 6ef77f24460be..f4a6d0a4d2e44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -95,4 +95,20 @@ private case object OracleDialect extends JdbcDialect { } override def isCascadingTruncateTable(): Option[Boolean] = Some(false) + + /** + * The SQL query used to truncate a table. + * @param table The table to truncate + * @param cascade Whether or not to cascade the truncation. Default value is the + * value of isCascadingTruncateTable() + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + cascade match { + case Some(true) => s"TRUNCATE TABLE $table CASCADE" + case _ => s"TRUNCATE TABLE $table" + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 13a2035f4d0c4..f8d2bc8e0f13f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -85,15 +85,27 @@ private object PostgresDialect extends JdbcDialect { s"SELECT 1 FROM $table LIMIT 1" } + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) + /** - * The SQL query used to truncate a table. For Postgres, the default behaviour is to - * also truncate any descendant tables. As this is a (possibly unwanted) side-effect, - * the Postgres dialect adds 'ONLY' to truncate only the table in question - * @param table The name of the table. - * @return The SQL query to use for truncating a table - */ - override def getTruncateQuery(table: String): String = { - s"TRUNCATE TABLE ONLY $table" + * The SQL query used to truncate a table. For Postgres, the default behaviour is to + * also truncate any descendant tables. As this is a (possibly unwanted) side-effect, + * the Postgres dialect adds 'ONLY' to truncate only the table in question + * @param table The table to truncate + * @param cascade Whether or not to cascade the truncation. Default value is the value of + * isCascadingTruncateTable(). Cascading a truncation will truncate tables + * with a foreign key relationship to the target table. However, it will not + * truncate tables with an inheritance relationship to the target table, as + * the truncate query always includes "ONLY" to prevent this behaviour. + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + cascade match { + case Some(true) => s"TRUNCATE TABLE ONLY $table CASCADE" + case _ => s"TRUNCATE TABLE ONLY $table" + } } override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { @@ -110,5 +122,4 @@ private object PostgresDialect extends JdbcDialect { } } - override def isCascadingTruncateTable(): Option[Boolean] = Some(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index 5749b791fca25..6c17bd7ed9ec4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -31,4 +31,22 @@ private case object TeradataDialect extends JdbcDialect { case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR)) case _ => None } + + // Teradata does not support cascading a truncation + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) + + /** + * The SQL query used to truncate a table. Teradata does not support the 'TRUNCATE' syntax that + * other dialects use. Instead, we need to use a 'DELETE FROM' statement. + * @param table The table to truncate. + * @param cascade Whether or not to cascade the truncation. Default value is the + * value of isCascadingTruncateTable(). Teradata does not support cascading a + * 'DELETE FROM' statement (and as mentioned, does not support 'TRUNCATE' syntax) + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + s"DELETE FROM $table ALL" + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 0389273d6cdfa..09facb9bef8dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -861,19 +861,51 @@ class JDBCSuite extends QueryTest } test("truncate table query by jdbc dialect") { - val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") - val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val mysql = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") val h2 = JdbcDialects.get(url) val derby = JdbcDialects.get("jdbc:derby:db") + val oracle = JdbcDialects.get("jdbc:oracle://127.0.0.1/db") + val teradata = JdbcDialects.get("jdbc:teradata://127.0.0.1/db") + val table = "weblogs" val defaultQuery = s"TRUNCATE TABLE $table" val postgresQuery = s"TRUNCATE TABLE ONLY $table" - assert(MySQL.getTruncateQuery(table) == defaultQuery) - assert(Postgres.getTruncateQuery(table) == postgresQuery) - assert(db2.getTruncateQuery(table) == defaultQuery) - assert(h2.getTruncateQuery(table) == defaultQuery) - assert(derby.getTruncateQuery(table) == defaultQuery) + val teradataQuery = s"DELETE FROM $table ALL" + + Seq(mysql, db2, h2, derby).foreach{ dialect => + assert(dialect.getTruncateQuery(table, Some(true)) == defaultQuery) + } + + assert(postgres.getTruncateQuery(table) == postgresQuery) + assert(oracle.getTruncateQuery(table) == defaultQuery) + assert(teradata.getTruncateQuery(table) == teradataQuery) + } + + test("SPARK-22880: Truncate table with CASCADE by jdbc dialect") { + // cascade in a truncate should only be applied for databases that support this, + // even if the parameter is passed. + val mysql = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") + val h2 = JdbcDialects.get(url) + val derby = JdbcDialects.get("jdbc:derby:db") + val oracle = JdbcDialects.get("jdbc:oracle://127.0.0.1/db") + val teradata = JdbcDialects.get("jdbc:teradata://127.0.0.1/db") + + val table = "weblogs" + val defaultQuery = s"TRUNCATE TABLE $table" + val postgresQuery = s"TRUNCATE TABLE ONLY $table CASCADE" + val oracleQuery = s"TRUNCATE TABLE $table CASCADE" + val teradataQuery = s"DELETE FROM $table ALL" + + Seq(mysql, db2, h2, derby).foreach{ dialect => + assert(dialect.getTruncateQuery(table, Some(true)) == defaultQuery) + } + assert(postgres.getTruncateQuery(table, Some(true)) == postgresQuery) + assert(oracle.getTruncateQuery(table, Some(true)) == oracleQuery) + assert(teradata.getTruncateQuery(table, Some(true)) == teradataQuery) } test("Test DataFrame.where for Date and Timestamp") { From 00b864aa7054a34f3d7a118d92eae0b3c28b86e5 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 20 Jul 2018 14:57:59 -0700 Subject: [PATCH 1182/2461] [SPARK-24876][SQL] Avro: simplify schema serialization ## What changes were proposed in this pull request? Previously in the refactoring of Avro Serializer and Deserializer, a new class SerializableSchema is created for serializing the Avro schema: https://github.com/apache/spark/pull/21762/files#diff-01fea32e6ec6bcf6f34d06282e08705aR37 On second thought, we can use `toString` method for serialization. After that, parse the JSON format schema on executor. This makes the code much simpler. ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21829 from gengliangwang/removeSerializableSchema. --- .../spark/sql/avro/AvroFileFormat.scala | 2 +- .../sql/avro/AvroOutputWriterFactory.scala | 14 +++- .../spark/sql/avro/SerializableSchema.scala | 69 ------------------- .../sql/avro/SerializableSchemaSuite.scala | 56 --------------- 4 files changed, 12 insertions(+), 129 deletions(-) delete mode 100644 external/avro/src/main/scala/org/apache/spark/sql/avro/SerializableSchema.scala delete mode 100644 external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableSchemaSuite.scala diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index 1d0f40e1ce92a..780e4570f697e 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -146,7 +146,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { log.error(s"unsupported compression codec $unknown") } - new AvroOutputWriterFactory(dataSchema, new SerializableSchema(outputAvroSchema)) + new AvroOutputWriterFactory(dataSchema, outputAvroSchema.toString) } override def buildReader( diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala index 18a6d93951408..116020ed5c433 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala @@ -17,14 +17,22 @@ package org.apache.spark.sql.avro +import org.apache.avro.Schema import org.apache.hadoop.mapreduce.TaskAttemptContext import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory} import org.apache.spark.sql.types.StructType +/** + * A factory that produces [[AvroOutputWriter]]. + * @param catalystSchema Catalyst schema of input data. + * @param avroSchemaAsJsonString Avro schema of output result, in JSON string format. + */ private[avro] class AvroOutputWriterFactory( - schema: StructType, - avroSchema: SerializableSchema) extends OutputWriterFactory { + catalystSchema: StructType, + avroSchemaAsJsonString: String) extends OutputWriterFactory { + + private lazy val avroSchema = new Schema.Parser().parse(avroSchemaAsJsonString) override def getFileExtension(context: TaskAttemptContext): String = ".avro" @@ -32,6 +40,6 @@ private[avro] class AvroOutputWriterFactory( path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new AvroOutputWriter(path, context, schema, avroSchema.value) + new AvroOutputWriter(path, context, catalystSchema, avroSchema) } } diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SerializableSchema.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SerializableSchema.scala deleted file mode 100644 index ec0ddc778c8f6..0000000000000 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/SerializableSchema.scala +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.avro - -import java.io._ - -import scala.util.control.NonFatal - -import com.esotericsoftware.kryo.{Kryo, KryoSerializable} -import com.esotericsoftware.kryo.io.{Input, Output} -import org.apache.avro.Schema -import org.slf4j.LoggerFactory - -class SerializableSchema(@transient var value: Schema) - extends Serializable with KryoSerializable { - - @transient private[avro] lazy val log = LoggerFactory.getLogger(getClass) - - private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException { - out.defaultWriteObject() - out.writeUTF(value.toString()) - out.flush() - } - - private def readObject(in: ObjectInputStream): Unit = tryOrIOException { - val json = in.readUTF() - value = new Schema.Parser().parse(json) - } - - private def tryOrIOException[T](block: => T): T = { - try { - block - } catch { - case e: IOException => - log.error("Exception encountered", e) - throw e - case NonFatal(e) => - log.error("Exception encountered", e) - throw new IOException(e) - } - } - - def write(kryo: Kryo, out: Output): Unit = { - val dos = new DataOutputStream(out) - dos.writeUTF(value.toString()) - dos.flush() - } - - def read(kryo: Kryo, in: Input): Unit = { - val dis = new DataInputStream(in) - val json = dis.readUTF() - value = new Schema.Parser().parse(json) - } -} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableSchemaSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableSchemaSuite.scala deleted file mode 100644 index 510bcbdd31929..0000000000000 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableSchemaSuite.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.avro - -import org.apache.avro.Schema - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerInstance} - -class SerializableSchemaSuite extends SparkFunSuite { - - private def testSerialization(serializer: SerializerInstance): Unit = { - val avroTypeJson = - s""" - |{ - | "type": "string", - | "name": "my_string" - |} - """.stripMargin - val avroSchema = new Schema.Parser().parse(avroTypeJson) - val serializableSchema = new SerializableSchema(avroSchema) - val serialized = serializer.serialize(serializableSchema) - - serializer.deserialize[Any](serialized) match { - case c: SerializableSchema => - assert(c.log != null, "log was null") - assert(c.value != null, "value was null") - assert(c.value == avroSchema) - case other => fail( - s"Expecting ${classOf[SerializableSchema]}, but got ${other.getClass}.") - } - } - - test("serialization with JavaSerializer") { - testSerialization(new JavaSerializer(new SparkConf()).newInstance()) - } - - test("serialization with KryoSerializer") { - testSerialization(new KryoSerializer(new SparkConf()).newInstance()) - } -} From f765bb7823a60440cb42819edd98b14f65e13b18 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 20 Jul 2018 15:23:04 -0700 Subject: [PATCH 1183/2461] [SPARK-24880][BUILD] Fix the group id for spark-kubernetes-integration-tests ## What changes were proposed in this pull request? The correct group id should be `org.apache.spark`. This is causing the nightly build failure: https://amplab.cs.berkeley.edu/jenkins/job/spark-master-maven-snapshots/2295/console ` [ERROR] Failed to execute goal org.apache.maven.plugins:maven-deploy-plugin:2.8.2:deploy (default-deploy) on project spark-kubernetes-integration-tests_2.11: Failed to deploy artifacts: Could not transfer artifact spark-kubernetes-integration-tests:spark-kubernetes-integration-tests_2.11:jar:2.4.0-20180720.101629-1 from/to apache.snapshots.https (https://repository.apache.org/content/repositories/snapshots): Access denied to: https://repository.apache.org/content/repositories/snapshots/spark-kubernetes-integration-tests/spark-kubernetes-integration-tests_2.11/2.4.0-SNAPSHOT/spark-kubernetes-integration-tests_2.11-2.4.0-20180720.101629-1.jar, ReasonPhrase: Forbidden. -> [Help 1] [ERROR] ` ## How was this patch tested? Jenkins. Author: zsxwing Closes #21831 from zsxwing/fix-k8s-test. --- resource-managers/kubernetes/integration-tests/pom.xml | 1 - 1 file changed, 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 29334cc6d891d..614705c1ed668 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -25,7 +25,6 @@ spark-kubernetes-integration-tests_2.11 - spark-kubernetes-integration-tests 1.3.0 1.4.0 From 597bdeff2d07d690287526ab0e722f80749014d2 Mon Sep 17 00:00:00 2001 From: Brandon Krieger Date: Sat, 21 Jul 2018 00:44:00 +0200 Subject: [PATCH 1184/2461] [SPARK-24488][SQL] Fix issue when generator is aliased multiple times ## What changes were proposed in this pull request? Currently, the Analyzer throws an exception if your try to nest a generator. However, it special cases generators "nested" in an alias, and allows that. If you try to alias a generator twice, it is not caught by the special case, so an exception is thrown. This PR trims the unnecessary, non-top-level aliases, so that the generator is allowed. ## How was this patch tested? new tests in AnalysisSuite. Author: Brandon Krieger Closes #21508 from bkrieger/bk/SPARK-24488. --- .../sql/catalyst/analysis/Analyzer.scala | 53 +++++++++++-------- .../sql/catalyst/analysis/AnalysisSuite.scala | 8 +++ 2 files changed, 38 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 957c468d5f8ee..866396c42f9d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1660,11 +1660,13 @@ class Analyzer( expr.find(_.isInstanceOf[Generator]).isDefined } - private def hasNestedGenerator(expr: NamedExpression): Boolean = expr match { - case UnresolvedAlias(_: Generator, _) => false - case Alias(_: Generator, _) => false - case MultiAlias(_: Generator, _) => false - case other => hasGenerator(other) + private def hasNestedGenerator(expr: NamedExpression): Boolean = { + CleanupAliases.trimNonTopLevelAliases(expr) match { + case UnresolvedAlias(_: Generator, _) => false + case Alias(_: Generator, _) => false + case MultiAlias(_: Generator, _) => false + case other => hasGenerator(other) + } } private def trimAlias(expr: NamedExpression): Expression = expr match { @@ -1705,24 +1707,26 @@ class Analyzer( // Holds the resolved generator, if one exists in the project list. var resolvedGenerator: Generate = null - val newProjectList = projectList.flatMap { - case AliasedGenerator(generator, names, outer) if generator.childrenResolved => - // It's a sanity check, this should not happen as the previous case will throw - // exception earlier. - assert(resolvedGenerator == null, "More than one generator found in SELECT.") - - resolvedGenerator = - Generate( - generator, - unrequiredChildIndex = Nil, - outer = outer, - qualifier = None, - generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names), - child) - - resolvedGenerator.generatorOutput - case other => other :: Nil - } + val newProjectList = projectList + .map(CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) + .flatMap { + case AliasedGenerator(generator, names, outer) if generator.childrenResolved => + // It's a sanity check, this should not happen as the previous case will throw + // exception earlier. + assert(resolvedGenerator == null, "More than one generator found in SELECT.") + + resolvedGenerator = + Generate( + generator, + unrequiredChildIndex = Nil, + outer = outer, + qualifier = None, + generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names), + child) + + resolvedGenerator.generatorOutput + case other => other :: Nil + } if (resolvedGenerator != null) { Project(newProjectList, resolvedGenerator) @@ -2433,6 +2437,7 @@ object CleanupAliases extends Rule[LogicalPlan] { private def trimAliases(e: Expression): Expression = { e.transformDown { case Alias(child, _) => child + case MultiAlias(child, _) => child } } @@ -2442,6 +2447,8 @@ object CleanupAliases extends Rule[LogicalPlan] { exprId = a.exprId, qualifier = a.qualifier, explicitMetadata = Some(a.metadata)) + case a: MultiAlias => + a.copy(child = trimAliases(a.child)) case other => trimAliases(other) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index bbcdf6c1b8481..9e0db8dbf8f3a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -575,4 +575,12 @@ class AnalysisSuite extends AnalysisTest with Matchers { assertAnalysisSuccess( Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join)) } + + test("SPARK-24488 Generator with multiple aliases") { + assertAnalysisSuccess( + listRelation.select(Explode('list).as("first_alias").as("second_alias"))) + assertAnalysisSuccess( + listRelation.select(MultiAlias(MultiAlias( + PosExplode('list), Seq("first_pos", "first_val")), Seq("second_pos", "second_val")))) + } } From 96f3120760ba0a83ef6347327ecfb130487e02dd Mon Sep 17 00:00:00 2001 From: William Sheu Date: Fri, 20 Jul 2018 19:48:32 -0700 Subject: [PATCH 1185/2461] [PYSPARK][TEST][MINOR] Fix UDFInitializationTests ## What changes were proposed in this pull request? Fix a typo in pyspark sql tests Author: William Sheu Closes #21833 from PenguinToast/fix-test-typo. --- python/pyspark/sql/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 565654e7f03bb..2d6b9f01e6525 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3595,7 +3595,7 @@ def tearDown(self): SparkSession._instantiatedSession.stop() if SparkContext._active_spark_context is not None: - SparkContext._active_spark_contex.stop() + SparkContext._active_spark_context.stop() def test_udf_init_shouldnt_initalize_context(self): from pyspark.sql.functions import UserDefinedFunction From bbd6f0c25fe19dc6c946e63cac7b98d0f78b3463 Mon Sep 17 00:00:00 2001 From: William Sheu Date: Fri, 20 Jul 2018 19:59:28 -0700 Subject: [PATCH 1186/2461] [SPARK-24879][SQL] Fix NPE in Hive partition pruning filter pushdown ## What changes were proposed in this pull request? We get a NPE when we have a filter on a partition column of the form `col in (x, null)`. This is due to the filter converter in HiveShim not handling `null`s correctly. This patch fixes this bug while still pushing down as much of the partition pruning predicates as possible, by filtering out `null`s from any `in` predicate. Since Hive only supports very simple partition pruning filters, this change should preserve correctness. ## How was this patch tested? Unit tests, manual tests Author: William Sheu Closes #21832 from PenguinToast/partition-pruning-npe. --- .../spark/sql/hive/client/HiveShim.scala | 19 ++++++++++++++++++- .../spark/sql/hive/client/FiltersSuite.scala | 14 ++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 933384ed43e98..bc9d4cd7f4181 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -598,6 +598,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { object ExtractableLiteral { def unapply(expr: Expression): Option[String] = expr match { + case Literal(null, _) => None // `null`s can be cast as other types; we want to avoid NPEs. case Literal(value, _: IntegralType) => Some(value.toString) case Literal(value, _: StringType) => Some(quoteStringLiteral(value.toString)) case _ => None @@ -606,7 +607,23 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { object ExtractableLiterals { def unapply(exprs: Seq[Expression]): Option[Seq[String]] = { - val extractables = exprs.map(ExtractableLiteral.unapply) + // SPARK-24879: The Hive metastore filter parser does not support "null", but we still want + // to push down as many predicates as we can while still maintaining correctness. + // In SQL, the `IN` expression evaluates as follows: + // > `1 in (2, NULL)` -> NULL + // > `1 in (1, NULL)` -> true + // > `1 in (2)` -> false + // Since Hive metastore filters are NULL-intolerant binary operations joined only by + // `AND` and `OR`, we can treat `NULL` as `false` and thus rewrite `1 in (2, NULL)` as + // `1 in (2)`. + // If the Hive metastore begins supporting NULL-tolerant predicates and Spark starts + // pushing down these predicates, then this optimization will become incorrect and need + // to be changed. + val extractables = exprs + .filter { + case Literal(null, _) => false + case _ => true + }.map(ExtractableLiteral.unapply) if (extractables.nonEmpty && extractables.forall(_.isDefined)) { Some(extractables.map(_.get)) } else { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala index 19765695fbcb4..2a4efd0cce6e0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -72,6 +72,20 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest { (Literal("p2\" and q=\"q2") === a("stringcol", StringType)) :: Nil, """stringcol = 'p1" and q="q1' and 'p2" and q="q2' = stringcol""") + filterTest("SPARK-24879 null literals should be ignored for IN constructs", + (a("intcol", IntegerType) in (Literal(1), Literal(null))) :: Nil, + "(intcol = 1)") + + // Applying the predicate `x IN (NULL)` should return an empty set, but since this optimization + // will be applied by Catalyst, this filter converter does not need to account for this. + filterTest("SPARK-24879 IN predicates with only NULLs will not cause a NPE", + (a("intcol", IntegerType) in Literal(null)) :: Nil, + "") + + filterTest("typecast null literals should not be pushed down in simple predicates", + (a("intcol", IntegerType) === Literal(null, IntegerType)) :: Nil, + "") + private def filterTest(name: String, filters: Seq[Expression], result: String) = { test(name) { withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> "true") { From 106880edcd67bc20e8610a16f8ce6aa250268eeb Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Fri, 20 Jul 2018 20:04:40 -0700 Subject: [PATCH 1187/2461] [SPARK-24836][SQL] New option for Avro datasource - ignoreExtension ## What changes were proposed in this pull request? I propose to add new option for AVRO datasource which should control ignoring of files without `.avro` extension in read. The option has name `ignoreExtension` with default value `true`. If both options `ignoreExtension` and `avro.mapred.ignore.inputs.without.extension` are set, `ignoreExtension` overrides the former one. Here is an example of usage: ``` spark .read .option("ignoreExtension", false) .avro("path to avro files") ``` ## How was this patch tested? I added a test which checks the option directly and a test for checking that new option overrides hadoop's config. Author: Maxim Gekk Closes #21798 from MaxGekk/avro-ignore-extension. --- .../spark/sql/avro/AvroFileFormat.scala | 33 ++++------- .../apache/spark/sql/avro/AvroOptions.scala | 29 +++++++++- .../org/apache/spark/sql/avro/AvroSuite.scala | 55 ++++++++++++++++++- 3 files changed, 91 insertions(+), 26 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index 780e4570f697e..078efabbeeb4e 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -58,21 +58,19 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { val conf = spark.sparkContext.hadoopConfiguration - val parsedOptions = new AvroOptions(options) + val parsedOptions = new AvroOptions(options, conf) // Schema evolution is not supported yet. Here we only pick a single random sample file to // figure out the schema of the whole dataset. val sampleFile = - if (AvroFileFormat.ignoreFilesWithoutExtensions(conf)) { - files.find(_.getPath.getName.endsWith(".avro")).getOrElse { - throw new FileNotFoundException( - "No Avro files found. Hadoop option \"avro.mapred.ignore.inputs.without.extension\" " + - " is set to true. Do all input files have \".avro\" extension?" - ) + if (parsedOptions.ignoreExtension) { + files.headOption.getOrElse { + throw new FileNotFoundException("Files for schema inferring have been not found.") } } else { - files.headOption.getOrElse { - throw new FileNotFoundException("No Avro files found.") + files.find(_.getPath.getName.endsWith(".avro")).getOrElse { + throw new FileNotFoundException( + "No Avro files found. If files don't have .avro extension, set ignoreExtension to true") } } @@ -115,7 +113,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - val parsedOptions = new AvroOptions(options) + val parsedOptions = new AvroOptions(options, spark.sessionState.newHadoopConf()) val outputAvroSchema = SchemaConverters.toAvroType( dataSchema, nullable = false, parsedOptions.recordName, parsedOptions.recordNamespace) @@ -160,7 +158,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { val broadcastedConf = spark.sparkContext.broadcast(new AvroFileFormat.SerializableConfiguration(hadoopConf)) - val parsedOptions = new AvroOptions(options) + val parsedOptions = new AvroOptions(options, hadoopConf) (file: PartitionedFile) => { val log = LoggerFactory.getLogger(classOf[AvroFileFormat]) @@ -171,9 +169,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { // Doing input file filtering is improper because we may generate empty tasks that process no // input files but stress the scheduler. We should probably add a more general input file // filtering mechanism for `FileFormat` data sources. See SPARK-16317. - if (AvroFileFormat.ignoreFilesWithoutExtensions(conf) && !file.filePath.endsWith(".avro")) { - Iterator.empty - } else { + if (parsedOptions.ignoreExtension || file.filePath.endsWith(".avro")) { val reader = { val in = new FsInput(new Path(new URI(file.filePath)), conf) try { @@ -228,6 +224,8 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { deserializer.deserialize(record).asInstanceOf[InternalRow] } } + } else { + Iterator.empty } } } @@ -274,11 +272,4 @@ private[avro] object AvroFileFormat { value.readFields(new DataInputStream(in)) } } - - def ignoreFilesWithoutExtensions(conf: Configuration): Boolean = { - // Files without .avro extensions are not ignored by default - val defaultValue = false - - conf.getBoolean(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, defaultValue) - } } diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala index 8721eae3481da..cd9a911a14bfa 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala @@ -17,16 +17,21 @@ package org.apache.spark.sql.avro +import org.apache.hadoop.conf.Configuration + import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap /** * Options for Avro Reader and Writer stored in case insensitive manner. */ -class AvroOptions(@transient val parameters: CaseInsensitiveMap[String]) - extends Logging with Serializable { +class AvroOptions( + @transient val parameters: CaseInsensitiveMap[String], + @transient val conf: Configuration) extends Logging with Serializable { - def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) + def this(parameters: Map[String, String], conf: Configuration) = { + this(CaseInsensitiveMap(parameters), conf) + } /** * Optional schema provided by an user in JSON format. @@ -45,4 +50,22 @@ class AvroOptions(@transient val parameters: CaseInsensitiveMap[String]) * See Avro spec for details: https://avro.apache.org/docs/1.8.2/spec.html#schema_record . */ val recordNamespace: String = parameters.getOrElse("recordNamespace", "") + + /** + * The `ignoreExtension` option controls ignoring of files without `.avro` extensions in read. + * If the option is enabled, all files (with and without `.avro` extension) are loaded. + * If the option is not set, the Hadoop's config `avro.mapred.ignore.inputs.without.extension` + * is taken into account. If the former one is not set too, file extensions are ignored. + */ + val ignoreExtension: Boolean = { + val ignoreFilesWithoutExtensionByDefault = false + val ignoreFilesWithoutExtension = conf.getBoolean( + AvroFileFormat.IgnoreFilesWithoutExtensionProperty, + ignoreFilesWithoutExtensionByDefault) + + parameters + .get("ignoreExtension") + .map(_.toBoolean) + .getOrElse(!ignoreFilesWithoutExtension) + } } diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index f7e9877b7744b..dad56aacf9326 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -630,10 +630,21 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { hadoopConf.set(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true") spark.read.avro(dir.toString) } finally { - hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty) } + hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty) + } } } + intercept[FileNotFoundException] { + withTempPath { dir => + FileUtils.touch(new File(dir, "test")) + + spark + .read + .option("ignoreExtension", false) + .avro(dir.toString) + } + } } test("SQL test insert overwrite") { @@ -702,7 +713,6 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } finally { hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty) } - assert(count == 8) } } @@ -838,4 +848,45 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(df2.count == 8) } } + + test("SPARK-24836: checking the ignoreExtension option") { + withTempPath { tempDir => + val df = spark.read.avro(episodesAvro) + assert(df.count == 8) + + val tempSaveDir = s"$tempDir/save/" + df.write.avro(tempSaveDir) + + Files.createFile(new File(tempSaveDir, "non-avro").toPath) + + val newDf = spark + .read + .option("ignoreExtension", false) + .avro(tempSaveDir) + + assert(newDf.count == 8) + } + } + + test("SPARK-24836: ignoreExtension must override hadoop's config") { + withTempDir { dir => + Files.copy( + Paths.get(new URL(episodesAvro).toURI), + Paths.get(dir.getCanonicalPath, "episodes")) + + val hadoopConf = spark.sqlContext.sparkContext.hadoopConfiguration + val count = try { + hadoopConf.set(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true") + val newDf = spark + .read + .option("ignoreExtension", "true") + .avro(s"${dir.getCanonicalPath}/episodes") + newDf.count() + } finally { + hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty) + } + + assert(count == 8) + } + } } From d7ae4247ea8754dcb5fa03c2e0bf2d9aab7828e5 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 21 Jul 2018 16:43:10 +0800 Subject: [PATCH 1188/2461] [SPARK-24873][YARN] Turn off spark-shell noisy log output ## What changes were proposed in this pull request? [SPARK-24182](https://github.com/apache/spark/pull/21243) changed the `logApplicationReport` from `false` to `true`. This pr revert it to `false`. otherwise `spark-shell` will show noisy log output: ```java ... 18/07/16 04:46:25 INFO Client: Application report for application_1530676576026_54551 (state: RUNNING) 18/07/16 04:46:26 INFO Client: Application report for application_1530676576026_54551 (state: RUNNING) ... ``` Closes https://github.com/apache/spark/pull/21827 ## How was this patch tested? manual tests Author: Yuming Wang Closes #21784 from wangyum/SPARK-24182. --- .../spark/scheduler/cluster/YarnClientSchedulerBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index f1a8df00f9c5b..9397a1e3de9ac 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -111,7 +111,7 @@ private[spark] class YarnClientSchedulerBackend( override def run() { try { val YarnAppReport(_, state, diags) = - client.monitorApplication(appId.get, logApplicationReport = true) + client.monitorApplication(appId.get, logApplicationReport = false) logError(s"YARN application has exited unexpectedly with state $state! " + "Check the YARN application logs for more details.") diags.foreach { err => From 81af88687f97f70b30828ac63239129637852526 Mon Sep 17 00:00:00 2001 From: zhengruifeng3 Date: Sat, 21 Jul 2018 08:26:45 -0500 Subject: [PATCH 1189/2461] [SPARK-23231][ML][DOC] Add doc for string indexer ordering to user guide (also to RFormula guide) ## What changes were proposed in this pull request? add doc for string indexer ordering ## How was this patch tested? existing tests Author: zhengruifeng3 Author: zhengruifeng Closes #21792 from zhengruifeng/doc_string_indexer_ordering. --- docs/ml-features.md | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index ad6e718b37f1b..882b895a9d154 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -585,7 +585,11 @@ for more details on the API. ## StringIndexer `StringIndexer` encodes a string column of labels to a column of label indices. -The indices are in `[0, numLabels)`, ordered by label frequencies, so the most frequent label gets index `0`. +The indices are in `[0, numLabels)`, and four ordering options are supported: +"frequencyDesc": descending order by label frequency (most frequent label assigned 0), +"frequencyAsc": ascending order by label frequency (least frequent label assigned 0), +"alphabetDesc": descending alphabetical order, and "alphabetAsc": ascending alphabetical order +(default = "frequencyDesc"). The unseen labels will be put at index numLabels if user chooses to keep them. If the input column is numeric, we cast it to string and index the string values. When downstream pipeline components such as `Estimator` or @@ -1593,10 +1597,25 @@ Suppose `a` and `b` are double columns, we use the following simple examples to * `y ~ a + b + a:b - 1` means model `y ~ w1 * a + w2 * b + w3 * a * b` where `w1, w2, w3` are coefficients. `RFormula` produces a vector column of features and a double or string column of label. -Like when formulas are used in R for linear regression, string input columns will be one-hot encoded, and numeric columns will be cast to doubles. -If the label column is of type string, it will be first transformed to double with `StringIndexer`. +Like when formulas are used in R for linear regression, numeric columns will be cast to doubles. +As to string input columns, they will first be transformed with [StringIndexer](ml-features.html#stringindexer) using ordering determined by `stringOrderType`, +and the last category after ordering is dropped, then the doubles will be one-hot encoded. + +Suppose a string feature column containing values `{'b', 'a', 'b', 'a', 'c', 'b'}`, we set `stringOrderType` to control the encoding: +~~~ +stringOrderType | Category mapped to 0 by StringIndexer | Category dropped by RFormula +----------------|---------------------------------------|--------------------------------- +'frequencyDesc' | most frequent category ('b') | least frequent category ('c') +'frequencyAsc' | least frequent category ('c') | most frequent category ('b') +'alphabetDesc' | last alphabetical category ('c') | first alphabetical category ('a') +'alphabetAsc' | first alphabetical category ('a') | last alphabetical category ('c') +~~~ + +If the label column is of type string, it will be first transformed to double with [StringIndexer](ml-features.html#stringindexer) using `frequencyDesc` ordering. If the label column does not exist in the DataFrame, the output label column will be created from the specified response variable in the formula. +**Note:** The ordering option `stringOrderType` is NOT used for the label column. When the label column is indexed, it uses the default descending frequency ordering in `StringIndexer`. + **Examples** Assume that we have a DataFrame with the columns `id`, `country`, `hour`, and `clicked`: From 8817c68f5099901753585716e00281736938bca0 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sun, 22 Jul 2018 17:36:57 -0700 Subject: [PATCH 1190/2461] [SPARK-24811][SQL] Avro: add new function from_avro and to_avro ## What changes were proposed in this pull request? 1. Add a new function from_avro for parsing a binary column of avro format and converting it into its corresponding catalyst value. 2. Add a new function to_avro for converting a column into binary of avro format with the specified schema. I created #21774 for this, but it failed the build https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Compile/job/spark-master-compile-maven-hadoop-2.6/7902/ Additional changes In this PR: 1. Add `scalacheck` dependency in pom.xml to resolve the failure. 2. Update the `log4j.properties` to make it consistent with other modules. ## How was this patch tested? Unit test Compile with different commands: ``` ./build/mvn --force -DzincPort=3643 -DskipTests -Phadoop-2.6 -Phive-thriftserver -Pkinesis-asl -Pspark-ganglia-lgpl -Pmesos -Pyarn compile test-compile ./build/mvn --force -DzincPort=3643 -DskipTests -Phadoop-2.7 -Phive-thriftserver -Pkinesis-asl -Pspark-ganglia-lgpl -Pmesos -Pyarn compile test-compile ./build/mvn --force -DzincPort=3643 -DskipTests -Phadoop-3.1 -Phive-thriftserver -Pkinesis-asl -Pspark-ganglia-lgpl -Pmesos -Pyarn compile test-compile ``` Author: Gengliang Wang Closes #21838 from gengliangwang/from_and_to_avro. --- external/avro/pom.xml | 5 + .../spark/sql/avro/AvroDataToCatalyst.scala | 68 +++++++ .../spark/sql/avro/CatalystDataToAvro.scala | 69 +++++++ .../org/apache/spark/sql/avro/package.scala | 31 ++++ .../avro/src/test/resources/log4j.properties | 39 +--- .../AvroCatalystDataConversionSuite.scala | 175 ++++++++++++++++++ .../spark/sql/avro/AvroFunctionsSuite.scala | 83 +++++++++ .../expressions/ExpressionEvalHelper.scala | 6 + 8 files changed, 446 insertions(+), 30 deletions(-) create mode 100644 external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala create mode 100644 external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala create mode 100644 external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala create mode 100644 external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala diff --git a/external/avro/pom.xml b/external/avro/pom.xml index 42e865bc38824..ad7df1f49ac45 100644 --- a/external/avro/pom.xml +++ b/external/avro/pom.xml @@ -61,6 +61,11 @@ test-jar test + + org.scalacheck + scalacheck_${scala.binary.version} + test + org.apache.spark spark-tags_${scala.binary.version} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala new file mode 100644 index 0000000000000..6671b3fb8705c --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.avro.Schema +import org.apache.avro.generic.GenericDatumReader +import org.apache.avro.io.{BinaryDecoder, DecoderFactory} + +import org.apache.spark.sql.avro.{AvroDeserializer, SchemaConverters} +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType} + +case class AvroDataToCatalyst(child: Expression, jsonFormatSchema: String) + extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType) + + override lazy val dataType: DataType = SchemaConverters.toSqlType(avroSchema).dataType + + override def nullable: Boolean = true + + @transient private lazy val avroSchema = new Schema.Parser().parse(jsonFormatSchema) + + @transient private lazy val reader = new GenericDatumReader[Any](avroSchema) + + @transient private lazy val deserializer = new AvroDeserializer(avroSchema, dataType) + + @transient private var decoder: BinaryDecoder = _ + + @transient private var result: Any = _ + + override def nullSafeEval(input: Any): Any = { + val binary = input.asInstanceOf[Array[Byte]] + decoder = DecoderFactory.get().binaryDecoder(binary, 0, binary.length, decoder) + result = reader.read(result, decoder) + deserializer.deserialize(result) + } + + override def simpleString: String = { + s"from_avro(${child.sql}, ${dataType.simpleString})" + } + + override def sql: String = { + s"from_avro(${child.sql}, ${dataType.catalogString})" + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val expr = ctx.addReferenceObj("this", this) + defineCodeGen(ctx, ev, input => + s"(${CodeGenerator.boxedType(dataType)})$expr.nullSafeEval($input)") + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala new file mode 100644 index 0000000000000..a669388e88258 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.io.ByteArrayOutputStream + +import org.apache.avro.generic.GenericDatumWriter +import org.apache.avro.io.{BinaryEncoder, EncoderFactory} + +import org.apache.spark.sql.avro.{AvroSerializer, SchemaConverters} +import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.types.{BinaryType, DataType} + +case class CatalystDataToAvro(child: Expression) extends UnaryExpression { + + override def dataType: DataType = BinaryType + + @transient private lazy val avroType = + SchemaConverters.toAvroType(child.dataType, child.nullable) + + @transient private lazy val serializer = + new AvroSerializer(child.dataType, avroType, child.nullable) + + @transient private lazy val writer = + new GenericDatumWriter[Any](avroType) + + @transient private var encoder: BinaryEncoder = _ + + @transient private lazy val out = new ByteArrayOutputStream + + override def nullSafeEval(input: Any): Any = { + out.reset() + encoder = EncoderFactory.get().directBinaryEncoder(out, encoder) + val avroData = serializer.serialize(input) + writer.write(avroData, encoder) + encoder.flush() + out.toByteArray + } + + override def simpleString: String = { + s"to_avro(${child.sql}, ${child.dataType.simpleString})" + } + + override def sql: String = { + s"to_avro(${child.sql}, ${child.dataType.catalogString})" + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val expr = ctx.addReferenceObj("this", this) + defineCodeGen(ctx, ev, input => + s"(byte[]) $expr.nullSafeEval($input)") + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala index b3c8a669cf820..e82651d96a03d 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import org.apache.avro.Schema + +import org.apache.spark.annotation.Experimental + package object avro { /** * Adds a method, `avro`, to DataFrameWriter that allows you to write avro files using @@ -36,4 +40,31 @@ package object avro { @scala.annotation.varargs def avro(sources: String*): DataFrame = reader.format("avro").load(sources: _*) } + + /** + * Converts a binary column of avro format into its corresponding catalyst value. The specified + * schema must match the read data, otherwise the behavior is undefined: it may fail or return + * arbitrary result. + * + * @param data the binary column. + * @param jsonFormatSchema the avro schema in JSON string format. + * + * @since 2.4.0 + */ + @Experimental + def from_avro(data: Column, jsonFormatSchema: String): Column = { + new Column(AvroDataToCatalyst(data.expr, jsonFormatSchema)) + } + + /** + * Converts a column into binary of avro format. + * + * @param data the data column. + * + * @since 2.4.0 + */ + @Experimental + def to_avro(data: Column): Column = { + new Column(CatalystDataToAvro(data.expr)) + } } diff --git a/external/avro/src/test/resources/log4j.properties b/external/avro/src/test/resources/log4j.properties index f80a5291bc078..75e3b53a093f6 100644 --- a/external/avro/src/test/resources/log4j.properties +++ b/external/avro/src/test/resources/log4j.properties @@ -15,35 +15,14 @@ # limitations under the License. # -# Set everything to be logged to the file core/target/unit-tests.log -log4j.rootLogger=DEBUG, CA, FA +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n -#Console Appender -log4j.appender.CA=org.apache.log4j.ConsoleAppender -log4j.appender.CA.layout=org.apache.log4j.PatternLayout -log4j.appender.CA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c: %m%n -log4j.appender.CA.Threshold = WARN +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.spark-project.jetty=WARN - -#File Appender -log4j.appender.FA=org.apache.log4j.FileAppender -log4j.appender.FA.append=false -log4j.appender.FA.file=target/unit-tests.log -log4j.appender.FA.layout=org.apache.log4j.PatternLayout -log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %t %p %c{1}: %m%n - -# Set the logger level of File Appender to WARN -log4j.appender.FA.Threshold = INFO - -# Some packages are noisy for no good reason. -log4j.additivity.parquet.hadoop.ParquetRecordReader=false -log4j.logger.parquet.hadoop.ParquetRecordReader=OFF - -log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false -log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF - -log4j.additivity.org.apache.hadoop.hive.metastore.RetryingHMSHandler=false -log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=OFF - -log4j.additivity.hive.ql.metadata.Hive=false -log4j.logger.hive.ql.metadata.Hive=OFF diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala new file mode 100644 index 0000000000000..06d5477b2ea45 --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.avro.Schema + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{AvroDataToCatalyst, CatalystDataToAvro, RandomDataGenerator} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, GenericInternalRow, Literal} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class AvroCatalystDataConversionSuite extends SparkFunSuite with ExpressionEvalHelper { + + private def roundTripTest(data: Literal): Unit = { + val avroType = SchemaConverters.toAvroType(data.dataType, data.nullable) + checkResult(data, avroType.toString, data.eval()) + } + + private def checkResult(data: Literal, schema: String, expected: Any): Unit = { + checkEvaluation( + AvroDataToCatalyst(CatalystDataToAvro(data), schema), + prepareExpectedResult(expected)) + } + + private def assertFail(data: Literal, schema: String): Unit = { + intercept[java.io.EOFException] { + AvroDataToCatalyst(CatalystDataToAvro(data), schema).eval() + } + } + + private val testingTypes = Seq( + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + DecimalType(8, 0), // 32 bits decimal without fraction + DecimalType(8, 4), // 32 bits decimal + DecimalType(16, 0), // 64 bits decimal without fraction + DecimalType(16, 11), // 64 bits decimal + DecimalType(38, 0), + DecimalType(38, 38), + StringType, + BinaryType) + + protected def prepareExpectedResult(expected: Any): Any = expected match { + // Spark decimal is converted to avro string= + case d: Decimal => UTF8String.fromString(d.toString) + // Spark byte and short both map to avro int + case b: Byte => b.toInt + case s: Short => s.toInt + case row: GenericInternalRow => InternalRow.fromSeq(row.values.map(prepareExpectedResult)) + case array: GenericArrayData => new GenericArrayData(array.array.map(prepareExpectedResult)) + case map: MapData => + val keys = new GenericArrayData( + map.keyArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult)) + val values = new GenericArrayData( + map.valueArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult)) + new ArrayBasedMapData(keys, values) + case other => other + } + + testingTypes.foreach { dt => + val seed = scala.util.Random.nextLong() + test(s"single $dt with seed $seed") { + val rand = new scala.util.Random(seed) + val data = RandomDataGenerator.forType(dt, rand = rand).get.apply() + val converter = CatalystTypeConverters.createToCatalystConverter(dt) + val input = Literal.create(converter(data), dt) + roundTripTest(input) + } + } + + for (_ <- 1 to 5) { + val seed = scala.util.Random.nextLong() + val rand = new scala.util.Random(seed) + val schema = RandomDataGenerator.randomSchema(rand, 5, testingTypes) + test(s"flat schema ${schema.catalogString} with seed $seed") { + val data = RandomDataGenerator.randomRow(rand, schema) + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + val input = Literal.create(converter(data), schema) + roundTripTest(input) + } + } + + for (_ <- 1 to 5) { + val seed = scala.util.Random.nextLong() + val rand = new scala.util.Random(seed) + val schema = RandomDataGenerator.randomNestedSchema(rand, 10, testingTypes) + test(s"nested schema ${schema.catalogString} with seed $seed") { + val data = RandomDataGenerator.randomRow(rand, schema) + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + val input = Literal.create(converter(data), schema) + roundTripTest(input) + } + } + + test("read int as string") { + val data = Literal(1) + val avroTypeJson = + s""" + |{ + | "type": "string", + | "name": "my_string" + |} + """.stripMargin + + // When read int as string, avro reader is not able to parse the binary and fail. + assertFail(data, avroTypeJson) + } + + test("read string as int") { + val data = Literal("abc") + val avroTypeJson = + s""" + |{ + | "type": "int", + | "name": "my_int" + |} + """.stripMargin + + // When read string data as int, avro reader is not able to find the type mismatch and read + // the string length as int value. + checkResult(data, avroTypeJson, 3) + } + + test("read float as double") { + val data = Literal(1.23f) + val avroTypeJson = + s""" + |{ + | "type": "double", + | "name": "my_double" + |} + """.stripMargin + + // When read float data as double, avro reader fails(trying to read 8 bytes while the data have + // only 4 bytes). + assertFail(data, avroTypeJson) + } + + test("read double as float") { + val data = Literal(1.23) + val avroTypeJson = + s""" + |{ + | "type": "float", + | "name": "my_float" + |} + """.stripMargin + + // avro reader reads the first 4 bytes of a double as a float, the result is totally undefined. + checkResult(data, avroTypeJson, 5.848603E35f) + } +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala new file mode 100644 index 0000000000000..90a4cd6ccf9dd --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.avro.Schema + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions.struct +import org.apache.spark.sql.test.SharedSQLContext + +class AvroFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("roundtrip in to_avro and from_avro - int and string") { + val df = spark.range(10).select('id, 'id.cast("string").as("str")) + + val avroDF = df.select(to_avro('id).as("a"), to_avro('str).as("b")) + val avroTypeLong = s""" + |{ + | "type": "int", + | "name": "id" + |} + """.stripMargin + val avroTypeStr = s""" + |{ + | "type": "string", + | "name": "str" + |} + """.stripMargin + checkAnswer(avroDF.select(from_avro('a, avroTypeLong), from_avro('b, avroTypeStr)), df) + } + + test("roundtrip in to_avro and from_avro - struct") { + val df = spark.range(10).select(struct('id, 'id.cast("string").as("str")).as("struct")) + val avroStructDF = df.select(to_avro('struct).as("avro")) + val avroTypeStruct = s""" + |{ + | "type": "record", + | "name": "struct", + | "fields": [ + | {"name": "col1", "type": "long"}, + | {"name": "col2", "type": "string"} + | ] + |} + """.stripMargin + checkAnswer(avroStructDF.select(from_avro('avro, avroTypeStruct)), df) + } + + test("roundtrip in to_avro and from_avro - array with null") { + val dfOne = Seq(Tuple1(Tuple1(1) :: Nil), Tuple1(null :: Nil)).toDF("array") + val avroTypeArrStruct = s""" + |[ { + | "type" : "array", + | "items" : [ { + | "type" : "record", + | "name" : "x", + | "fields" : [ { + | "name" : "y", + | "type" : "int" + | } ] + | }, "null" ] + |}, "null" ] + """.stripMargin + val readBackOne = dfOne.select(to_avro($"array").as("avro")) + .select(from_avro($"avro", avroTypeArrStruct).as("array")) + checkAnswer(dfOne, readBackOne) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 14bfa212b5496..d045267ef5d9e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -79,6 +79,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa java.util.Arrays.equals(result, expected) case (result: Double, expected: Spread[Double @unchecked]) => expected.asInstanceOf[Spread[Double]].isWithin(result) + case (result: InternalRow, expected: InternalRow) => + val st = dataType.asInstanceOf[StructType] + assert(result.numFields == st.length && expected.numFields == st.length) + st.zipWithIndex.forall { case (f, i) => + checkResult(result.get(i, f.dataType), expected.get(i, f.dataType), f.dataType) + } case (result: ArrayData, expected: ArrayData) => result.numElements == expected.numElements && { val et = dataType.asInstanceOf[ArrayType].elementType From f59de52a2a2fc8b8c596230b76f5fd2aa9fedd58 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 23 Jul 2018 15:27:33 +0800 Subject: [PATCH 1191/2461] [SPARK-24883][SQL] Avro: remove implicit class AvroDataFrameWriter/AvroDataFrameReader ## What changes were proposed in this pull request? As per Reynold's comment: https://github.com/apache/spark/pull/21742#discussion_r203496489 It makes sense to remove the implicit class AvroDataFrameWriter/AvroDataFrameReader, since the Avro package is external module. ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21841 from gengliangwang/removeImplicit. --- .../org/apache/spark/sql/avro/package.scala | 21 -- .../org/apache/spark/sql/avro/AvroSuite.scala | 185 +++++++++--------- 2 files changed, 96 insertions(+), 110 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala index e82651d96a03d..97f9427f96c55 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala @@ -17,30 +17,9 @@ package org.apache.spark.sql -import org.apache.avro.Schema - import org.apache.spark.annotation.Experimental package object avro { - /** - * Adds a method, `avro`, to DataFrameWriter that allows you to write avro files using - * the DataFileWriter - */ - implicit class AvroDataFrameWriter[T](writer: DataFrameWriter[T]) { - def avro: String => Unit = writer.format("avro").save - } - - /** - * Adds a method, `avro`, to DataFrameReader that allows you to read avro files using - * the DataFileReader - */ - implicit class AvroDataFrameReader(reader: DataFrameReader) { - def avro: String => DataFrame = reader.format("avro").load - - @scala.annotation.varargs - def avro(sources: String*): DataFrame = reader.format("avro").load(sources: _*) - } - /** * Converts a binary column of avro format into its corresponding catalyst value. The specified * schema must match the read data, otherwise the behavior is undefined: it may fail or return diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index dad56aacf9326..ec1627a3898bf 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -46,24 +46,24 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = { - val originalEntries = spark.read.avro(testAvro).collect() - val newEntries = spark.read.avro(newFile) + val originalEntries = spark.read.format("avro").load(testAvro).collect() + val newEntries = spark.read.format("avro").load(newFile) checkAnswer(newEntries, originalEntries) } test("reading from multiple paths") { - val df = spark.read.avro(episodesAvro, episodesAvro) + val df = spark.read.format("avro").load(episodesAvro, episodesAvro) assert(df.count == 16) } test("reading and writing partitioned data") { - val df = spark.read.avro(episodesAvro) + val df = spark.read.format("avro").load(episodesAvro) val fields = List("title", "air_date", "doctor") for (field <- fields) { withTempPath { dir => val outputDir = s"$dir/${UUID.randomUUID}" - df.write.partitionBy(field).avro(outputDir) - val input = spark.read.avro(outputDir) + df.write.partitionBy(field).format("avro").save(outputDir) + val input = spark.read.format("avro").load(outputDir) // makes sure that no fields got dropped. // We convert Rows to Seqs in order to work around SPARK-10325 assert(input.select(field).collect().map(_.toSeq).toSet === @@ -73,14 +73,14 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("request no fields") { - val df = spark.read.avro(episodesAvro) + val df = spark.read.format("avro").load(episodesAvro) df.createOrReplaceTempView("avro_table") assert(spark.sql("select count(*) from avro_table").collect().head === Row(8)) } test("convert formats") { withTempPath { dir => - val df = spark.read.avro(episodesAvro) + val df = spark.read.format("avro").load(episodesAvro) df.write.parquet(dir.getCanonicalPath) assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count) } @@ -88,8 +88,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("rearrange internal schema") { withTempPath { dir => - val df = spark.read.avro(episodesAvro) - df.select("doctor", "title").write.avro(dir.getCanonicalPath) + val df = spark.read.format("avro").load(episodesAvro) + df.select("doctor", "title").write.format("avro").save(dir.getCanonicalPath) } } @@ -109,7 +109,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { dataFileWriter.close() intercept[IncompatibleSchemaException] { - spark.read.avro(s"$dir.avro") + spark.read.format("avro").load(s"$dir.avro") } } } @@ -136,7 +136,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { dataFileWriter.append(rec2) dataFileWriter.flush() dataFileWriter.close() - val df = spark.read.avro(s"$dir.avro") + val df = spark.read.format("avro").load(s"$dir.avro") assert(df.schema.fields === Seq(StructField("field1", LongType, nullable = true))) assert(df.collect().toSet == Set(Row(1L), Row(2L))) } @@ -164,7 +164,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { dataFileWriter.append(rec2) dataFileWriter.flush() dataFileWriter.close() - val df = spark.read.avro(s"$dir.avro") + val df = spark.read.format("avro").load(s"$dir.avro") assert(df.schema.fields === Seq(StructField("field1", DoubleType, nullable = true))) assert(df.collect().toSet == Set(Row(1.toDouble), Row(2.toDouble))) } @@ -196,7 +196,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { dataFileWriter.append(rec2) dataFileWriter.flush() dataFileWriter.close() - val df = spark.read.avro(s"$dir.avro") + val df = spark.read.format("avro").load(s"$dir.avro") assert(df.schema.fields === Seq(StructField("field1", DoubleType, nullable = true))) assert(df.collect().toSet == Set(Row(1.toDouble), Row(null))) } @@ -220,7 +220,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { dataFileWriter.flush() dataFileWriter.close() - val df = spark.read.avro(s"$dir.avro") + val df = spark.read.format("avro").load(s"$dir.avro") assert(df.first() == Row(8)) } } @@ -255,7 +255,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { dataFileWriter.flush() dataFileWriter.close() - val df = spark.sqlContext.read.avro(s"$dir.avro") + val df = spark.sqlContext.read.format("avro").load(s"$dir.avro") assertResult(field1)(df.selectExpr("field1.member0").first().get(0)) assertResult(field2)(df.selectExpr("field2.member1").first().get(0)) assertResult(field3)(df.selectExpr("field3.member2").first().get(0)) @@ -277,8 +277,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Row(null, null, null, null, null), Row(null, null, null, null, null))) val df = spark.createDataFrame(rdd, schema) - df.write.avro(dir.toString) - assert(spark.read.avro(dir.toString).count == rdd.count) + df.write.format("avro").save(dir.toString) + assert(spark.read.format("avro").load(dir.toString).count == rdd.count) } } @@ -296,8 +296,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Row(3f, 3.toShort, 3.toByte, true) )) val df = spark.createDataFrame(rdd, schema) - df.write.avro(dir.toString) - assert(spark.read.avro(dir.toString).count == rdd.count) + df.write.format("avro").save(dir.toString) + assert(spark.read.format("avro").load(dir.toString).count == rdd.count) } } @@ -314,9 +314,10 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Row(3f, new Date(1460066400500L)) )) val df = spark.createDataFrame(rdd, schema) - df.write.avro(dir.toString) - assert(spark.read.avro(dir.toString).count == rdd.count) - assert(spark.read.avro(dir.toString).select("date").collect().map(_(0)).toSet == + df.write.format("avro").save(dir.toString) + assert(spark.read.format("avro").load(dir.toString).count == rdd.count) + assert( + spark.read.format("avro").load(dir.toString).select("date").collect().map(_(0)).toSet == Array(null, 1451865600000L, 1459987200000L).toSet) } } @@ -350,8 +351,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Array[Array[String]](Array[String]("CSH, tearing down the walls that divide us", "-jd")), Array[Row](Row("Bobby G. can't swim"))))) val df = spark.createDataFrame(rdd, testSchema) - df.write.avro(dir.toString) - assert(spark.read.avro(dir.toString).count == rdd.count) + df.write.format("avro").save(dir.toString) + assert(spark.read.format("avro").load(dir.toString).count == rdd.count) } } @@ -363,14 +364,14 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val deflateDir = s"$dir/deflate" val snappyDir = s"$dir/snappy" - val df = spark.read.avro(testAvro) + val df = spark.read.format("avro").load(testAvro) spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed") - df.write.avro(uncompressDir) + df.write.format("avro").save(uncompressDir) spark.conf.set(AVRO_COMPRESSION_CODEC, "deflate") spark.conf.set(AVRO_DEFLATE_LEVEL, "9") - df.write.avro(deflateDir) + df.write.format("avro").save(deflateDir) spark.conf.set(AVRO_COMPRESSION_CODEC, "snappy") - df.write.avro(snappyDir) + df.write.format("avro").save(snappyDir) val uncompressSize = FileUtils.sizeOfDirectory(new File(uncompressDir)) val deflateSize = FileUtils.sizeOfDirectory(new File(deflateDir)) @@ -382,49 +383,50 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("dsl test") { - val results = spark.read.avro(episodesAvro).select("title").collect() + val results = spark.read.format("avro").load(episodesAvro).select("title").collect() assert(results.length === 8) } test("support of various data types") { // This test uses data from test.avro. You can see the data and the schema of this file in // test.json and test.avsc - val all = spark.read.avro(testAvro).collect() + val all = spark.read.format("avro").load(testAvro).collect() assert(all.length == 3) - val str = spark.read.avro(testAvro).select("string").collect() + val str = spark.read.format("avro").load(testAvro).select("string").collect() assert(str.map(_(0)).toSet.contains("Terran is IMBA!")) - val simple_map = spark.read.avro(testAvro).select("simple_map").collect() + val simple_map = spark.read.format("avro").load(testAvro).select("simple_map").collect() assert(simple_map(0)(0).getClass.toString.contains("Map")) assert(simple_map.map(_(0).asInstanceOf[Map[String, Some[Int]]].size).toSet == Set(2, 0)) - val union0 = spark.read.avro(testAvro).select("union_string_null").collect() + val union0 = spark.read.format("avro").load(testAvro).select("union_string_null").collect() assert(union0.map(_(0)).toSet == Set("abc", "123", null)) - val union1 = spark.read.avro(testAvro).select("union_int_long_null").collect() + val union1 = spark.read.format("avro").load(testAvro).select("union_int_long_null").collect() assert(union1.map(_(0)).toSet == Set(66, 1, null)) - val union2 = spark.read.avro(testAvro).select("union_float_double").collect() + val union2 = spark.read.format("avro").load(testAvro).select("union_float_double").collect() assert( union2 .map(x => new java.lang.Double(x(0).toString)) .exists(p => Math.abs(p - Math.PI) < 0.001)) - val fixed = spark.read.avro(testAvro).select("fixed3").collect() + val fixed = spark.read.format("avro").load(testAvro).select("fixed3").collect() assert(fixed.map(_(0).asInstanceOf[Array[Byte]]).exists(p => p(1) == 3)) - val enum = spark.read.avro(testAvro).select("enum").collect() + val enum = spark.read.format("avro").load(testAvro).select("enum").collect() assert(enum.map(_(0)).toSet == Set("SPADES", "CLUBS", "DIAMONDS")) - val record = spark.read.avro(testAvro).select("record").collect() + val record = spark.read.format("avro").load(testAvro).select("record").collect() assert(record(0)(0).getClass.toString.contains("Row")) assert(record.map(_(0).asInstanceOf[Row](0)).contains("TEST_STR123")) - val array_of_boolean = spark.read.avro(testAvro).select("array_of_boolean").collect() + val array_of_boolean = + spark.read.format("avro").load(testAvro).select("array_of_boolean").collect() assert(array_of_boolean.map(_(0).asInstanceOf[Seq[Boolean]].size).toSet == Set(3, 1, 0)) - val bytes = spark.read.avro(testAvro).select("bytes").collect() + val bytes = spark.read.format("avro").load(testAvro).select("bytes").collect() assert(bytes.map(_(0).asInstanceOf[Array[Byte]].length).toSet == Set(3, 1, 0)) } @@ -444,7 +446,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { // get the same values back. withTempPath { dir => val avroDir = s"$dir/avro" - spark.read.avro(testAvro).write.avro(avroDir) + spark.read.format("avro").load(testAvro).write.format("avro").save(avroDir) checkReloadMatchesSaved(testAvro, avroDir) } } @@ -458,7 +460,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val parameters = Map("recordName" -> name, "recordNamespace" -> namespace) val avroDir = tempDir + "/namedAvro" - spark.read.avro(testAvro).write.options(parameters).avro(avroDir) + spark.read.format("avro").load(testAvro) + .write.options(parameters).format("avro").save(avroDir) checkReloadMatchesSaved(testAvro, avroDir) // Look at raw file and make sure has namespace info @@ -489,22 +492,22 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val cityDataFrame = spark.createDataFrame(cityRDD, testSchema) val avroDir = tempDir + "/avro" - cityDataFrame.write.avro(avroDir) - assert(spark.read.avro(avroDir).collect().length == 3) + cityDataFrame.write.format("avro").save(avroDir) + assert(spark.read.format("avro").load(avroDir).collect().length == 3) // TimesStamps are converted to longs - val times = spark.read.avro(avroDir).select("Time").collect() + val times = spark.read.format("avro").load(avroDir).select("Time").collect() assert(times.map(_(0)).toSet == Set(666, 777, 42)) // DecimalType should be converted to string - val decimals = spark.read.avro(avroDir).select("Decimal").collect() + val decimals = spark.read.format("avro").load(avroDir).select("Decimal").collect() assert(decimals.map(_(0)).contains("3.14")) // There should be a null entry - val length = spark.read.avro(avroDir).select("Length").collect() + val length = spark.read.format("avro").load(avroDir).select("Length").collect() assert(length.map(_(0)).contains(null)) - val binary = spark.read.avro(avroDir).select("Binary").collect() + val binary = spark.read.format("avro").load(avroDir).select("Binary").collect() for (i <- arrayOfByte.indices) { assert(binary(1)(0).asInstanceOf[Array[Byte]](i) == arrayOfByte(i)) } @@ -523,10 +526,10 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val writeDs = Seq((currentDate, currentTime)).toDS val avroDir = tempDir + "/avro" - writeDs.write.avro(avroDir) - assert(spark.read.avro(avroDir).collect().length == 1) + writeDs.write.format("avro").save(avroDir) + assert(spark.read.format("avro").load(avroDir).collect().length == 1) - val readDs = spark.read.schema(schema).avro(avroDir).as[(Date, Timestamp)] + val readDs = spark.read.schema(schema).format("avro").load(avroDir).as[(Date, Timestamp)] assert(readDs.collect().sameElements(writeDs.collect())) } @@ -534,10 +537,10 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("support of globbed paths") { val resourceDir = testFile(".") - val e1 = spark.read.avro(resourceDir + "../*/episodes.avro").collect() + val e1 = spark.read.format("avro").load(resourceDir + "../*/episodes.avro").collect() assert(e1.length == 8) - val e2 = spark.read.avro(resourceDir + "../../*/*/episodes.avro").collect() + val e2 = spark.read.format("avro").load(resourceDir + "../../*/*/episodes.avro").collect() assert(e2.length == 8) } @@ -555,8 +558,9 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val writeDs = Seq((nullDate, nullTime)).toDS val avroDir = tempDir + "/avro" - writeDs.write.avro(avroDir) - val readValues = spark.read.schema(schema).avro(avroDir).as[(Date, Timestamp)].collect + writeDs.write.format("avro").save(avroDir) + val readValues = + spark.read.schema(schema).format("avro").load(avroDir).as[(Date, Timestamp)].collect assert(readValues.size == 1) assert(readValues.head == ((nullDate, nullTime))) @@ -579,9 +583,10 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val result = spark .read .option("avroSchema", avroSchema) - .avro(testAvro) + .format("avro") + .load(testAvro) .collect() - val expected = spark.read.avro(testAvro).select("string").collect() + val expected = spark.read.format("avro").load(testAvro).select("string").collect() assert(result.sameElements(expected)) } @@ -601,7 +606,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val result = spark .read .option("avroSchema", avroSchema) - .avro(testAvro).select("missingField").first + .format("avro").load(testAvro).select("missingField").first assert(result === Row("foo")) } @@ -609,17 +614,17 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { // Directory given has no avro files intercept[AnalysisException] { - withTempPath(dir => spark.read.avro(dir.getCanonicalPath)) + withTempPath(dir => spark.read.format("avro").load(dir.getCanonicalPath)) } intercept[AnalysisException] { - spark.read.avro("very/invalid/path/123.avro") + spark.read.format("avro").load("very/invalid/path/123.avro") } // In case of globbed path that can't be matched to anything, another exception is thrown (and // exception message is helpful) intercept[AnalysisException] { - spark.read.avro("*/*/*/*/*/*/*/something.avro") + spark.read.format("avro").load("*/*/*/*/*/*/*/something.avro") } intercept[FileNotFoundException] { @@ -628,7 +633,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val hadoopConf = spark.sqlContext.sparkContext.hadoopConfiguration try { hadoopConf.set(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true") - spark.read.avro(dir.toString) + spark.read.format("avro").load(dir.toString) } finally { hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty) } @@ -642,7 +647,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { spark .read .option("ignoreExtension", false) - .avro(dir.toString) + .format("avro") + .load(dir.toString) } } } @@ -681,13 +687,13 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("test save and load") { // Test if load works as expected withTempPath { tempDir => - val df = spark.read.avro(episodesAvro) + val df = spark.read.format("avro").load(episodesAvro) assert(df.count == 8) val tempSaveDir = s"$tempDir/save/" - df.write.avro(tempSaveDir) - val newDf = spark.read.avro(tempSaveDir) + df.write.format("avro").save(tempSaveDir) + val newDf = spark.read.format("avro").load(tempSaveDir) assert(newDf.count == 8) } } @@ -695,20 +701,18 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("test load with non-Avro file") { // Test if load works as expected withTempPath { tempDir => - val df = spark.read.avro(episodesAvro) + val df = spark.read.format("avro").load(episodesAvro) assert(df.count == 8) val tempSaveDir = s"$tempDir/save/" - df.write.avro(tempSaveDir) + df.write.format("avro").save(tempSaveDir) Files.createFile(new File(tempSaveDir, "non-avro").toPath) val hadoopConf = spark.sqlContext.sparkContext.hadoopConfiguration val count = try { hadoopConf.set(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true") - val newDf = spark - .read - .avro(tempSaveDir) + val newDf = spark.read.format("avro").load(tempSaveDir) newDf.count() } finally { hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty) @@ -730,10 +734,11 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { StructField("record", StructType(Seq(StructField("value_field", StringType, false))), false), StructField("array_of_boolean", ArrayType(BooleanType), false), StructField("bytes", BinaryType, true))) - val withSchema = spark.read.schema(partialColumns).avro(testAvro).collect() + val withSchema = spark.read.schema(partialColumns).format("avro").load(testAvro).collect() val withOutSchema = spark .read - .avro(testAvro) + .format("avro") + .load(testAvro) .select("string", "simple_map", "complex_map", "union_string_null", "union_int_long_null", "fixed3", "fixed2", "enum", "record", "array_of_boolean", "bytes") .collect() @@ -751,7 +756,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { StructField("non_exist_field", StringType, false), StructField("non_exist_field2", StringType, false))), false))) - val withEmptyColumn = spark.read.schema(schema).avro(testAvro).collect() + val withEmptyColumn = spark.read.schema(schema).format("avro").load(testAvro).collect() assert(withEmptyColumn.forall(_ == Row(null: String, Row(null: String, null: String)))) } @@ -762,8 +767,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { import sparkSession.implicits._ val df = (0 to 1024 * 3).toDS.map(i => s"record${i}").toDF("records") val outputDir = s"$dir/${UUID.randomUUID}" - df.write.avro(outputDir) - val input = spark.read.avro(outputDir) + df.write.format("avro").save(outputDir) + val input = spark.read.format("avro").load(outputDir) assert(input.collect.toSet.size === 1024 * 3 + 1) assert(input.rdd.partitions.size > 2) } @@ -780,9 +785,9 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { // Save avro file on output folder path val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1"))))) val outputFolder = s"$tempDir/duplicate_names/" - writeDf.write.avro(outputFolder) + writeDf.write.format("avro").save(outputFolder) // Read avro file saved on the last step - val readDf = spark.read.avro(outputFolder) + val readDf = spark.read.format("avro").load(outputFolder) // Check if the written DataFrame is equals than read DataFrame assert(readDf.collect().sameElements(writeDf.collect())) } @@ -801,9 +806,9 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { )))) ) val outputFolder = s"$tempDir/duplicate_names_array/" - writeDf.write.avro(outputFolder) + writeDf.write.format("avro").save(outputFolder) // Read avro file saved on the last step - val readDf = spark.read.avro(outputFolder) + val readDf = spark.read.format("avro").load(outputFolder) // Check if the written DataFrame is equals than read DataFrame assert(readDf.collect().sameElements(writeDf.collect())) } @@ -822,9 +827,9 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { )))) ) val outputFolder = s"$tempDir/duplicate_names_map/" - writeDf.write.avro(outputFolder) + writeDf.write.format("avro").save(outputFolder) // Read avro file saved on the last step - val readDf = spark.read.avro(outputFolder) + val readDf = spark.read.format("avro").load(outputFolder) // Check if the written DataFrame is equals than read DataFrame assert(readDf.collect().sameElements(writeDf.collect())) } @@ -837,32 +842,33 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Paths.get(dir.getCanonicalPath, "episodes")) val fileWithoutExtension = s"${dir.getCanonicalPath}/episodes" - val df1 = spark.read.avro(fileWithoutExtension) + val df1 = spark.read.format("avro").load(fileWithoutExtension) assert(df1.count == 8) val schema = new StructType() .add("title", StringType) .add("air_date", StringType) .add("doctor", IntegerType) - val df2 = spark.read.schema(schema).avro(fileWithoutExtension) + val df2 = spark.read.schema(schema).format("avro").load(fileWithoutExtension) assert(df2.count == 8) } } test("SPARK-24836: checking the ignoreExtension option") { withTempPath { tempDir => - val df = spark.read.avro(episodesAvro) + val df = spark.read.format("avro").load(episodesAvro) assert(df.count == 8) val tempSaveDir = s"$tempDir/save/" - df.write.avro(tempSaveDir) + df.write.format("avro").save(tempSaveDir) Files.createFile(new File(tempSaveDir, "non-avro").toPath) val newDf = spark .read .option("ignoreExtension", false) - .avro(tempSaveDir) + .format("avro") + .load(tempSaveDir) assert(newDf.count == 8) } @@ -880,7 +886,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val newDf = spark .read .option("ignoreExtension", "true") - .avro(s"${dir.getCanonicalPath}/episodes") + .format("avro") + .load(s"${dir.getCanonicalPath}/episodes") newDf.count() } finally { hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty) From ab18b02e66fd04bc8f1a4fb7b6a7f2773902a494 Mon Sep 17 00:00:00 2001 From: SongYadong Date: Mon, 23 Jul 2018 19:10:53 +0800 Subject: [PATCH 1192/2461] [SQL][HIVE] Correct an assert message in function makeRDDForTable ## What changes were proposed in this pull request? according to the context, "makeRDDForTablePartitions" in assert message should be "makeRDDForPartitionedTable", because "makeRDDForTablePartitions" does't exist in spark code. ## How was this patch tested? unit tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: SongYadong Closes #21836 from SongYadong/assert_info_modify. --- .../main/scala/org/apache/spark/sql/hive/TableReader.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index b5444a4217924..7d57389947576 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -110,8 +110,9 @@ class HadoopTableReader( deserializerClass: Class[_ <: Deserializer], filterOpt: Option[PathFilter]): RDD[InternalRow] = { - assert(!hiveTable.isPartitioned, """makeRDDForTable() cannot be called on a partitioned table, - since input formats may differ across partitions. Use makeRDDForTablePartitions() instead.""") + assert(!hiveTable.isPartitioned, + "makeRDDForTable() cannot be called on a partitioned table, since input formats may " + + "differ across partitions. Use makeRDDForPartitionedTable() instead.") // Create local references to member variables, so that the entire `this` object won't be // serialized in the closure below. From 434319e73f8cb6e080671bdde42a72228bd814ef Mon Sep 17 00:00:00 2001 From: maryannxue Date: Mon, 23 Jul 2018 08:25:24 -0700 Subject: [PATCH 1193/2461] [SPARK-24802][SQL] Add a new config for Optimization Rule Exclusion ## What changes were proposed in this pull request? Since Spark has provided fairly clear interfaces for adding user-defined optimization rules, it would be nice to have an easy-to-use interface for excluding an optimization rule from the Spark query optimizer as well. This would make customizing Spark optimizer easier and sometimes could debugging issues too. - Add a new config spark.sql.optimizer.excludedRules, with the value being a list of rule names separated by comma. - Modify the current batches method to remove the excluded rules from the default batches. Log the rules that have been excluded. - Split the existing default batches into "post-analysis batches" and "optimization batches" so that only rules in the "optimization batches" can be excluded. ## How was this patch tested? Add a new test suite: OptimizerRuleExclusionSuite Author: maryannxue Closes #21764 from maryannxue/rule-exclusion. --- .../sql/catalyst/optimizer/Optimizer.scala | 53 ++++++++- .../apache/spark/sql/internal/SQLConf.scala | 10 ++ .../OptimizerRuleExclusionSuite.scala | 101 ++++++++++++++++++ 3 files changed, 163 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2cc27d82f7d20..6faecd3efc40d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -46,7 +46,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations) - def batches: Seq[Batch] = { + def defaultBatches: Seq[Batch] = { val operatorOptimizationRuleSet = Seq( // Operator push down @@ -160,6 +160,22 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) UpdateNullabilityInAttributeReferences) } + def nonExcludableRules: Seq[String] = + EliminateDistinct.ruleName :: + EliminateSubqueryAliases.ruleName :: + EliminateView.ruleName :: + ReplaceExpressions.ruleName :: + ComputeCurrentTime.ruleName :: + GetCurrentDatabase(sessionCatalog).ruleName :: + RewriteDistinctAggregates.ruleName :: + ReplaceDeduplicateWithAggregate.ruleName :: + ReplaceIntersectWithSemiJoin.ruleName :: + ReplaceExceptWithFilter.ruleName :: + ReplaceExceptWithAntiJoin.ruleName :: + ReplaceDistinctWithAggregate.ruleName :: + PullupCorrelatedPredicates.ruleName :: + RewritePredicateSubquery.ruleName :: Nil + /** * Optimize all the subqueries inside expression. */ @@ -175,6 +191,41 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) * Override to provide additional rules for the operator optimization batch. */ def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil + + override def batches: Seq[Batch] = { + val excludedRulesConf = + SQLConf.get.optimizerExcludedRules.toSeq.flatMap(Utils.stringToSeq) + val excludedRules = excludedRulesConf.filter { ruleName => + val nonExcludable = nonExcludableRules.contains(ruleName) + if (nonExcludable) { + logWarning(s"Optimization rule '${ruleName}' was not excluded from the optimizer " + + s"because this rule is a non-excludable rule.") + } + !nonExcludable + } + if (excludedRules.isEmpty) { + defaultBatches + } else { + defaultBatches.flatMap { batch => + val filteredRules = batch.rules.filter { rule => + val exclude = excludedRules.contains(rule.ruleName) + if (exclude) { + logInfo(s"Optimization rule '${rule.ruleName}' is excluded from the optimizer.") + } + !exclude + } + if (batch.rules == filteredRules) { + Some(batch) + } else if (filteredRules.nonEmpty) { + Some(Batch(batch.name, batch.strategy, filteredRules: _*)) + } else { + logInfo(s"Optimization batch '${batch.name}' is excluded from the optimizer " + + s"as all enclosed rules have been excluded.") + None + } + } + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index fbb9a8cfae2e1..d7c830dfa0454 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -127,6 +127,14 @@ object SQLConf { } } + val OPTIMIZER_EXCLUDED_RULES = buildConf("spark.sql.optimizer.excludedRules") + .doc("Configures a list of rules to be disabled in the optimizer, in which the rules are " + + "specified by their rule names and separated by comma. It is not guaranteed that all the " + + "rules in this configuration will eventually be excluded, as some rules are necessary " + + "for correctness. The optimizer will log the rules that have indeed been excluded.") + .stringConf + .createOptional + val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() .doc("The max number of iterations the optimizer and analyzer runs.") @@ -1444,6 +1452,8 @@ class SQLConf extends Serializable with Logging { /** ************************ Spark SQL Params/Hints ******************* */ + def optimizerExcludedRules: Option[String] = getConf(OPTIMIZER_EXCLUDED_RULES) + def optimizerMaxIterations: Int = getConf(OPTIMIZER_MAX_ITERATIONS) def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala new file mode 100644 index 0000000000000..5a5396e6f58b0 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_EXCLUDED_RULES + + +class OptimizerRuleExclusionSuite extends PlanTest { + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + private def verifyExcludedRules(excludedRuleNames: Seq[String]) { + val optimizer = new SimpleTestOptimizer() + // Batches whose rules are all to be excluded should be removed as a whole. + val excludedBatchNames = optimizer.batches + .filter(batch => batch.rules.forall(rule => excludedRuleNames.contains(rule.ruleName))) + .map(_.name) + + withSQLConf( + OPTIMIZER_EXCLUDED_RULES.key -> excludedRuleNames.foldLeft("")((l, r) => l + "," + r)) { + val batches = optimizer.batches + assert(batches.forall(batch => !excludedBatchNames.contains(batch.name))) + assert( + batches + .forall(batch => batch.rules.forall(rule => !excludedRuleNames.contains(rule.ruleName)))) + } + } + + test("Exclude a single rule from multiple batches") { + verifyExcludedRules( + Seq( + PushPredicateThroughJoin.ruleName)) + } + + test("Exclude multiple rules from single or multiple batches") { + verifyExcludedRules( + Seq( + CombineUnions.ruleName, + RemoveLiteralFromGroupExpressions.ruleName, + RemoveRepetitionFromGroupExpressions.ruleName)) + } + + test("Exclude non-existent rule with other valid rules") { + verifyExcludedRules( + Seq( + LimitPushDown.ruleName, + InferFiltersFromConstraints.ruleName, + "DummyRuleName")) + } + + test("Try to exclude a non-excludable rule") { + val excludedRules = Seq( + ReplaceIntersectWithSemiJoin.ruleName, + PullupCorrelatedPredicates.ruleName) + + val optimizer = new SimpleTestOptimizer() + + withSQLConf( + OPTIMIZER_EXCLUDED_RULES.key -> excludedRules.foldLeft("")((l, r) => l + "," + r)) { + excludedRules.foreach { excludedRule => + assert( + optimizer.batches + .exists(batch => batch.rules.exists(rule => rule.ruleName == excludedRule))) + } + } + } + + test("Verify optimized plan after excluding CombineUnions rule") { + val excludedRules = Seq( + ConvertToLocalRelation.ruleName, + PropagateEmptyRelation.ruleName, + CombineUnions.ruleName) + + withSQLConf( + OPTIMIZER_EXCLUDED_RULES.key -> excludedRules.foldLeft("")((l, r) => l + "," + r)) { + val optimizer = new SimpleTestOptimizer() + val originalQuery = testRelation.union(testRelation.union(testRelation)).analyze + val optimized = optimizer.execute(originalQuery) + comparePlans(originalQuery, optimized) + } + } +} From 08e315f6330984b757f241079dfc9e1028e5cd0a Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 23 Jul 2018 08:31:48 -0700 Subject: [PATCH 1194/2461] [SPARK-24887][SQL] Avro: use SerializableConfiguration in Spark utils to deduplicate code ## What changes were proposed in this pull request? To implement the method `buildReader` in `FileFormat`, it is required to serialize the hadoop configuration for executors. Previous spark-avro uses its own class `SerializableConfiguration` for the serialization. As now it is part of Spark, we can use SerializableConfiguration in Spark util to deduplicate the code. ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21846 from gengliangwang/removeSerializableConfiguration. --- .../spark/sql/avro/AvroFileFormat.scala | 44 +--------------- .../avro/SerializableConfigurationSuite.scala | 50 ------------------- 2 files changed, 2 insertions(+), 92 deletions(-) delete mode 100755 external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableConfigurationSuite.scala diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index 078efabbeeb4e..b043252f49afa 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -23,8 +23,6 @@ import java.util.zip.Deflater import scala.util.control.NonFatal -import com.esotericsoftware.kryo.{Kryo, KryoSerializable} -import com.esotericsoftware.kryo.io.{Input, Output} import org.apache.avro.Schema import org.apache.avro.file.{DataFileConstants, DataFileReader} import org.apache.avro.generic.{GenericDatumReader, GenericRecord} @@ -41,6 +39,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile} import org.apache.spark.sql.sources.{DataSourceRegister, Filter} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { private val log = LoggerFactory.getLogger(getClass) @@ -157,7 +156,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { val broadcastedConf = - spark.sparkContext.broadcast(new AvroFileFormat.SerializableConfiguration(hadoopConf)) + spark.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) val parsedOptions = new AvroOptions(options, hadoopConf) (file: PartitionedFile) => { @@ -233,43 +232,4 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { private[avro] object AvroFileFormat { val IgnoreFilesWithoutExtensionProperty = "avro.mapred.ignore.inputs.without.extension" - - class SerializableConfiguration(@transient var value: Configuration) - extends Serializable with KryoSerializable { - @transient private[avro] lazy val log = LoggerFactory.getLogger(getClass) - - private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException { - out.defaultWriteObject() - value.write(out) - } - - private def readObject(in: ObjectInputStream): Unit = tryOrIOException { - value = new Configuration(false) - value.readFields(in) - } - - private def tryOrIOException[T](block: => T): T = { - try { - block - } catch { - case e: IOException => - log.error("Exception encountered", e) - throw e - case NonFatal(e) => - log.error("Exception encountered", e) - throw new IOException(e) - } - } - - def write(kryo: Kryo, out: Output): Unit = { - val dos = new DataOutputStream(out) - value.write(dos) - dos.flush() - } - - def read(kryo: Kryo, in: Input): Unit = { - value = new Configuration(false) - value.readFields(new DataInputStream(in)) - } - } } diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableConfigurationSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableConfigurationSuite.scala deleted file mode 100755 index a0f88515ed9d4..0000000000000 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableConfigurationSuite.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.avro - -import org.apache.hadoop.conf.Configuration - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerInstance} - -class SerializableConfigurationSuite extends SparkFunSuite { - - private def testSerialization(serializer: SerializerInstance): Unit = { - import AvroFileFormat.SerializableConfiguration - val conf = new SerializableConfiguration(new Configuration()) - - val serialized = serializer.serialize(conf) - - serializer.deserialize[Any](serialized) match { - case c: SerializableConfiguration => - assert(c.log != null, "log was null") - assert(c.value != null, "value was null") - case other => fail( - s"Expecting ${classOf[SerializableConfiguration]}, but got ${other.getClass}.") - } - } - - test("serialization with JavaSerializer") { - testSerialization(new JavaSerializer(new SparkConf()).newInstance()) - } - - test("serialization with KryoSerializer") { - testSerialization(new KryoSerializer(new SparkConf()).newInstance()) - } - -} From 2edf17effd8b0aba61c95dddd5823ad7277d6c7d Mon Sep 17 00:00:00 2001 From: Onur Satici Date: Mon, 23 Jul 2018 09:52:28 -0700 Subject: [PATCH 1195/2461] [SPARK-24850][SQL] fix str representation of CachedRDDBuilder ## What changes were proposed in this pull request? As of https://github.com/apache/spark/pull/21018, InMemoryRelation includes its cacheBuilder when logging query plans. This PR changes the string representation of the CachedRDDBuilder to not include the cached spark plan. ## How was this patch tested? spark-shell, query: ``` var df_cached = spark.read.format("csv").option("header", "true").load("test.csv").cache() 0 to 1 foreach { _ => df_cached = df_cached.join(spark.read.format("csv").option("header", "true").load("test.csv"), "A").cache() } df_cached.explain ``` as of master results in: ``` == Physical Plan == InMemoryTableScan [A#10, B#11, B#35, B#87] +- InMemoryRelation [A#10, B#11, B#35, B#87], CachedRDDBuilder(true,10000,StorageLevel(disk, memory, deserialized, 1 replicas),*(2) Project [A#10, B#11, B#35, B#87] +- *(2) BroadcastHashJoin [A#10], [A#86], Inner, BuildRight :- *(2) Filter isnotnull(A#10) : +- InMemoryTableScan [A#10, B#11, B#35], [isnotnull(A#10)] : +- InMemoryRelation [A#10, B#11, B#35], CachedRDDBuilder(true,10000,StorageLevel(disk, memory, deserialized, 1 replicas),*(2) Project [A#10, B#11, B#35] +- *(2) BroadcastHashJoin [A#10], [A#34], Inner, BuildRight :- *(2) Filter isnotnull(A#10) : +- InMemoryTableScan [A#10, B#11], [isnotnull(A#10)] : +- InMemoryRelation [A#10, B#11], CachedRDDBuilder(true,10000,StorageLevel(disk, memory, deserialized, 1 replicas),*(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct ,None) : +- *(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, false])) +- *(1) Filter isnotnull(A#34) +- InMemoryTableScan [A#34, B#35], [isnotnull(A#34)] +- InMemoryRelation [A#34, B#35], CachedRDDBuilder(true,10000,StorageLevel(disk, memory, deserialized, 1 replicas),*(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct ,None) +- *(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct ,None) : +- *(2) Project [A#10, B#11, B#35] : +- *(2) BroadcastHashJoin [A#10], [A#34], Inner, BuildRight : :- *(2) Filter isnotnull(A#10) : : +- InMemoryTableScan [A#10, B#11], [isnotnull(A#10)] : : +- InMemoryRelation [A#10, B#11], CachedRDDBuilder(true,10000,StorageLevel(disk, memory, deserialized, 1 replicas),*(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct ,None) : : +- *(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct : +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, false])) : +- *(1) Filter isnotnull(A#34) : +- InMemoryTableScan [A#34, B#35], [isnotnull(A#34)] : +- InMemoryRelation [A#34, B#35], CachedRDDBuilder(true,10000,StorageLevel(disk, memory, deserialized, 1 replicas),*(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct ,None) : +- *(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, false])) +- *(1) Filter isnotnull(A#86) +- InMemoryTableScan [A#86, B#87], [isnotnull(A#86)] +- InMemoryRelation [A#86, B#87], CachedRDDBuilder(true,10000,StorageLevel(disk, memory, deserialized, 1 replicas),*(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct ,None) +- *(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct ,None) +- *(2) Project [A#10, B#11, B#35, B#87] +- *(2) BroadcastHashJoin [A#10], [A#86], Inner, BuildRight :- *(2) Filter isnotnull(A#10) : +- InMemoryTableScan [A#10, B#11, B#35], [isnotnull(A#10)] : +- InMemoryRelation [A#10, B#11, B#35], CachedRDDBuilder(true,10000,StorageLevel(disk, memory, deserialized, 1 replicas),*(2) Project [A#10, B#11, B#35] +- *(2) BroadcastHashJoin [A#10], [A#34], Inner, BuildRight :- *(2) Filter isnotnull(A#10) : +- InMemoryTableScan [A#10, B#11], [isnotnull(A#10)] : +- InMemoryRelation [A#10, B#11], CachedRDDBuilder(true,10000,StorageLevel(disk, memory, deserialized, 1 replicas),*(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct ,None) : +- *(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, false])) +- *(1) Filter isnotnull(A#34) +- InMemoryTableScan [A#34, B#35], [isnotnull(A#34)] +- InMemoryRelation [A#34, B#35], CachedRDDBuilder(true,10000,StorageLevel(disk, memory, deserialized, 1 replicas),*(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct ,None) +- *(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct ,None) : +- *(2) Project [A#10, B#11, B#35] : +- *(2) BroadcastHashJoin [A#10], [A#34], Inner, BuildRight : :- *(2) Filter isnotnull(A#10) : : +- InMemoryTableScan [A#10, B#11], [isnotnull(A#10)] : : +- InMemoryRelation [A#10, B#11], CachedRDDBuilder(true,10000,StorageLevel(disk, memory, deserialized, 1 replicas),*(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct ,None) : : +- *(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct : +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, false])) : +- *(1) Filter isnotnull(A#34) : +- InMemoryTableScan [A#34, B#35], [isnotnull(A#34)] : +- InMemoryRelation [A#34, B#35], CachedRDDBuilder(true,10000,StorageLevel(disk, memory, deserialized, 1 replicas),*(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct ,None) : +- *(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, false])) +- *(1) Filter isnotnull(A#86) +- InMemoryTableScan [A#86, B#87], [isnotnull(A#86)] +- InMemoryRelation [A#86, B#87], CachedRDDBuilder(true,10000,StorageLevel(disk, memory, deserialized, 1 replicas),*(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct ,None) +- *(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct ``` with this patch results in: ``` == Physical Plan == InMemoryTableScan [A#10, B#11, B#35, B#87] +- InMemoryRelation [A#10, B#11, B#35, B#87], CachedRDDBuilder(true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)) +- *(2) Project [A#10, B#11, B#35, B#87] +- *(2) BroadcastHashJoin [A#10], [A#86], Inner, BuildRight :- *(2) Filter isnotnull(A#10) : +- InMemoryTableScan [A#10, B#11, B#35], [isnotnull(A#10)] : +- InMemoryRelation [A#10, B#11, B#35], CachedRDDBuilder(true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)) : +- *(2) Project [A#10, B#11, B#35] : +- *(2) BroadcastHashJoin [A#10], [A#34], Inner, BuildRight : :- *(2) Filter isnotnull(A#10) : : +- InMemoryTableScan [A#10, B#11], [isnotnull(A#10)] : : +- InMemoryRelation [A#10, B#11], CachedRDDBuilder(true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)) : : +- *(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct : +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, false])) : +- *(1) Filter isnotnull(A#34) : +- InMemoryTableScan [A#34, B#35], [isnotnull(A#34)] : +- InMemoryRelation [A#34, B#35], CachedRDDBuilder(true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)) : +- *(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, false])) +- *(1) Filter isnotnull(A#86) +- InMemoryTableScan [A#86, B#87], [isnotnull(A#86)] +- InMemoryRelation [A#86, B#87], CachedRDDBuilder(true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)) +- *(1) FileScan csv [A#10,B#11] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:test.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct ``` Author: Onur Satici Closes #21805 from onursatici/os/inmemoryrelation-str. --- .../sql/execution/columnar/InMemoryRelation.scala | 5 ++++- .../org/apache/spark/sql/DatasetCacheSuite.scala | 11 +++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 7c8faec53a828..1a8fbaca53f59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, Statistics} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.LongAccumulator +import org.apache.spark.util.{LongAccumulator, Utils} /** @@ -207,4 +207,7 @@ case class InMemoryRelation( } override protected def otherCopyArgs: Seq[AnyRef] = Seq(statsOfPlanToCache) + + override def simpleString: String = + s"InMemoryRelation [${Utils.truncatedString(output, ", ")}], ${cacheBuilder.storageLevel}" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 5c6a021d5b767..44177e36caa01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -206,4 +206,15 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits // first time use, load cache checkDataset(df5, Row(10)) } + + test("SPARK-24850 InMemoryRelation string representation does not include cached plan") { + val df = Seq(1).toDF("a").cache() + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + df.explain(false) + } + assert(outputStream.toString.replaceAll("#\\d+", "#x").contains( + "InMemoryRelation [a#x], StorageLevel(disk, memory, deserialized, 1 replicas)" + )) + } } From 61f0ca4f1c4f1498c0b6ad02370839619871d6c5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 23 Jul 2018 13:03:32 -0700 Subject: [PATCH 1196/2461] [SPARK-24699][SS] Make watermarks work with Trigger.Once by saving updated watermark to commit log ## What changes were proposed in this pull request? Streaming queries with watermarks do not work with Trigger.Once because of the following. - Watermark is updated in the driver memory after a batch completes, but it is persisted to checkpoint (in the offset log) only when the next batch is planned - In trigger.once, the query terminated as soon as one batch has completed. Hence, the updated watermark is never persisted anywhere. The simple solution is to persist the updated watermark value in the commit log when a batch is marked as completed. Then the next batch, in the next trigger.once run can pick it up from the commit log. ## How was this patch tested? new unit tests Co-authored-by: Tathagata Das Co-authored-by: c-horn Author: Tathagata Das Closes #21746 from tdas/SPARK-24699. --- .../sql/execution/streaming/CommitLog.scala | 33 ++-- .../streaming/MicroBatchExecution.scala | 9 +- .../continuous/ContinuousExecution.scala | 2 +- .../commits/0 | 2 + .../commits/1 | 2 + .../metadata | 1 + .../offsets/0 | 3 + .../offsets/1 | 3 + .../state/0/0/1.delta | Bin 0 -> 46 bytes .../state/0/0/2.delta | Bin 0 -> 46 bytes .../state/0/1/1.delta | Bin 0 -> 46 bytes .../state/0/1/2.delta | Bin 0 -> 46 bytes .../state/0/2/1.delta | Bin 0 -> 103 bytes .../state/0/2/2.delta | Bin 0 -> 46 bytes .../state/0/3/1.delta | Bin 0 -> 46 bytes .../state/0/3/2.delta | Bin 0 -> 46 bytes .../state/0/4/1.delta | Bin 0 -> 46 bytes .../state/0/4/2.delta | Bin 0 -> 103 bytes .../streaming/EventTimeWatermarkSuite.scala | 156 +++++++++++++++--- .../spark/sql/streaming/StreamTest.scala | 8 +- 20 files changed, 177 insertions(+), 42 deletions(-) create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/commits/0 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/commits/1 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/metadata create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/offsets/0 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/offsets/1 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/0/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/0/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/1/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/1/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/2/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/2/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/3/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/3/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/4/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/4/2.delta diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala index 5b114242558dc..0063318db332d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CommitLog.scala @@ -22,6 +22,9 @@ import java.nio.charset.StandardCharsets._ import scala.io.{Source => IOSource} +import org.json4s.NoTypeHints +import org.json4s.jackson.Serialization + import org.apache.spark.sql.SparkSession /** @@ -43,36 +46,28 @@ import org.apache.spark.sql.SparkSession * line 2: metadata (optional json string) */ class CommitLog(sparkSession: SparkSession, path: String) - extends HDFSMetadataLog[String](sparkSession, path) { + extends HDFSMetadataLog[CommitMetadata](sparkSession, path) { import CommitLog._ - def add(batchId: Long): Unit = { - super.add(batchId, EMPTY_JSON) - } - - override def add(batchId: Long, metadata: String): Boolean = { - throw new UnsupportedOperationException( - "CommitLog does not take any metadata, use 'add(batchId)' instead") - } - - override protected def deserialize(in: InputStream): String = { + override protected def deserialize(in: InputStream): CommitMetadata = { // called inside a try-finally where the underlying stream is closed in the caller val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines() if (!lines.hasNext) { throw new IllegalStateException("Incomplete log file in the offset commit log") } parseVersion(lines.next.trim, VERSION) - EMPTY_JSON + val metadataJson = if (lines.hasNext) lines.next else EMPTY_JSON + CommitMetadata(metadataJson) } - override protected def serialize(metadata: String, out: OutputStream): Unit = { + override protected def serialize(metadata: CommitMetadata, out: OutputStream): Unit = { // called inside a try-finally where the underlying stream is closed in the caller out.write(s"v${VERSION}".getBytes(UTF_8)) out.write('\n') // write metadata - out.write(EMPTY_JSON.getBytes(UTF_8)) + out.write(metadata.json.getBytes(UTF_8)) } } @@ -81,3 +76,13 @@ object CommitLog { private val EMPTY_JSON = "{}" } + +case class CommitMetadata(nextBatchWatermarkMs: Long = 0) { + def json: String = Serialization.write(this)(CommitMetadata.format) +} + +object CommitMetadata { + implicit val format = Serialization.formats(NoTypeHints) + + def apply(json: String): CommitMetadata = Serialization.read[CommitMetadata](json) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 45c43f549d24f..abb807def6239 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -268,7 +268,7 @@ class MicroBatchExecution( * latest batch id in the offset log, then we can safely move to the next batch * i.e., committedBatchId + 1 */ commitLog.getLatest() match { - case Some((latestCommittedBatchId, _)) => + case Some((latestCommittedBatchId, commitMetadata)) => if (latestBatchId == latestCommittedBatchId) { /* The last batch was successfully committed, so we can safely process a * new next batch but first: @@ -286,7 +286,8 @@ class MicroBatchExecution( currentBatchId = latestCommittedBatchId + 1 isCurrentBatchConstructed = false committedOffsets ++= availableOffsets - // Construct a new batch be recomputing availableOffsets + watermarkTracker.setWatermark( + math.max(watermarkTracker.currentWatermark, commitMetadata.nextBatchWatermarkMs)) } else if (latestCommittedBatchId < latestBatchId - 1) { logWarning(s"Batch completion log latest batch id is " + s"${latestCommittedBatchId}, which is not trailing " + @@ -536,11 +537,11 @@ class MicroBatchExecution( } withProgressLocked { - commitLog.add(currentBatchId) + watermarkTracker.updateWatermark(lastExecution.executedPlan) + commitLog.add(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark)) committedOffsets ++= availableOffsets awaitProgressLockCondition.signalAll() } - watermarkTracker.updateWatermark(lastExecution.executedPlan) logDebug(s"Completed batch ${currentBatchId}") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index e991dbc81696d..140cec64fffb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -314,7 +314,7 @@ class ContinuousExecution( // Record offsets before updating `committedOffsets` recordTriggerOffsets(from = committedOffsets, to = availableOffsets) if (queryExecutionThread.isAlive) { - commitLog.add(epoch) + commitLog.add(epoch, CommitMetadata()) val offset = continuousSources(0).deserializeOffset(offsetLog.get(epoch).get.offsets(0).get.json) committedOffsets ++= Seq(continuousSources(0) -> offset) diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/commits/0 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/commits/0 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/commits/1 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/commits/1 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/metadata new file mode 100644 index 0000000000000..f205857e6876f --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/metadata @@ -0,0 +1 @@ +{"id":"73f7f943-0a08-4ffb-a504-9fa88ff7612a"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/offsets/0 new file mode 100644 index 0000000000000..8fa80bedc2285 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1531991874513,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/offsets/1 new file mode 100644 index 0000000000000..2248a58fea006 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/offsets/1 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":5000,"batchTimestampMs":1531991878604,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +1 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/0/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/0/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/0/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/0/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/1/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/1/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/1/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/1/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/2/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/state/0/2/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..171aa58a06e215f25d2acf0e621f85e9f4c26aa1 GIT binary patch literal 103 zcmeZ?GI7euPtI1gWnf@P0b-jdCFWott--*^5G(;?2=Fj4FfkY$c=&{!A%KBF*N~B& k!4W8Kz{bF=1Ed&OBpFzk7`RPiUJkE0RUDs6gB_= literal 0 HcmV?d00001 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index 58ed9790ea123..026af17c7b23f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -127,31 +127,133 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(aggWithWatermark)( AddData(inputData2, 15), CheckAnswer(), - assertEventStats { e => - assert(e.get("max") === formatTimestamp(15)) - assert(e.get("min") === formatTimestamp(15)) - assert(e.get("avg") === formatTimestamp(15)) - assert(e.get("watermark") === formatTimestamp(0)) - }, + assertEventStats(min = 15, max = 15, avg = 15, wtrmark = 0), AddData(inputData2, 10, 12, 14), CheckAnswer(), - assertEventStats { e => - assert(e.get("max") === formatTimestamp(14)) - assert(e.get("min") === formatTimestamp(10)) - assert(e.get("avg") === formatTimestamp(12)) - assert(e.get("watermark") === formatTimestamp(5)) - }, + assertEventStats(min = 10, max = 14, avg = 12, wtrmark = 5), AddData(inputData2, 25), CheckAnswer((10, 3)), - assertEventStats { e => - assert(e.get("max") === formatTimestamp(25)) - assert(e.get("min") === formatTimestamp(25)) - assert(e.get("avg") === formatTimestamp(25)) - assert(e.get("watermark") === formatTimestamp(5)) - } + assertEventStats(min = 25, max = 25, avg = 25, wtrmark = 5) ) } + test("event time and watermark metrics with Trigger.Once (SPARK-24699)") { + // All event time metrics where watermarking is set + val inputData = MemoryStream[Int] + val aggWithWatermark = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + // Unlike the ProcessingTime trigger, Trigger.Once only runs one trigger every time + // the query is started and it does not run no-data batches. Hence the answer generated + // by the updated watermark is only generated the next time the query is started. + // Also, the data to process in the next trigger is added *before* starting the stream in + // Trigger.Once to ensure that first and only trigger picks up the new data. + + testStream(aggWithWatermark)( + StartStream(Trigger.Once), // to make sure the query is not running when adding data 1st time + awaitTermination(), + + AddData(inputData, 15), + StartStream(Trigger.Once), + awaitTermination(), + CheckNewAnswer(), + assertEventStats(min = 15, max = 15, avg = 15, wtrmark = 0), + // watermark should be updated to 15 - 10 = 5 + + AddData(inputData, 10, 12, 14), + StartStream(Trigger.Once), + awaitTermination(), + CheckNewAnswer(), + assertEventStats(min = 10, max = 14, avg = 12, wtrmark = 5), + // watermark should stay at 5 + + AddData(inputData, 25), + StartStream(Trigger.Once), + awaitTermination(), + CheckNewAnswer(), + assertEventStats(min = 25, max = 25, avg = 25, wtrmark = 5), + // watermark should be updated to 25 - 10 = 15 + + AddData(inputData, 50), + StartStream(Trigger.Once), + awaitTermination(), + CheckNewAnswer((10, 3)), // watermark = 15 is used to generate this + assertEventStats(min = 50, max = 50, avg = 50, wtrmark = 15), + // watermark should be updated to 50 - 10 = 40 + + AddData(inputData, 50), + StartStream(Trigger.Once), + awaitTermination(), + CheckNewAnswer((15, 1), (25, 1)), // watermark = 40 is used to generate this + assertEventStats(min = 50, max = 50, avg = 50, wtrmark = 40)) + } + + test("recovery from Spark ver 2.3.1 commit log without commit metadata (SPARK-24699)") { + // All event time metrics where watermarking is set + val inputData = MemoryStream[Int] + val aggWithWatermark = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-2.3.1-without-commit-log-metadata/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + inputData.addData(15) + inputData.addData(10, 12, 14) + + testStream(aggWithWatermark)( + /* + + Note: The checkpoint was generated using the following input in Spark version 2.3.1 + + StartStream(checkpointLocation = "./sql/core/src/test/resources/structured-streaming/" + + "checkpoint-version-2.3.1-without-commit-log-metadata/")), + AddData(inputData, 15), // watermark should be updated to 15 - 10 = 5 + CheckAnswer(), + AddData(inputData, 10, 12, 14), // watermark should stay at 5 + CheckAnswer(), + StopStream, + + // Offset log should have watermark recorded as 5. + */ + + StartStream(Trigger.Once), + awaitTermination(), + + AddData(inputData, 25), + StartStream(Trigger.Once, checkpointLocation = checkpointDir.getAbsolutePath), + awaitTermination(), + CheckNewAnswer(), + assertEventStats(min = 25, max = 25, avg = 25, wtrmark = 5), + // watermark should be updated to 25 - 10 = 15 + + AddData(inputData, 50), + StartStream(Trigger.Once, checkpointLocation = checkpointDir.getAbsolutePath), + awaitTermination(), + CheckNewAnswer((10, 3)), // watermark = 15 is used to generate this + assertEventStats(min = 50, max = 50, avg = 50, wtrmark = 15), + // watermark should be updated to 50 - 10 = 40 + + AddData(inputData, 50), + StartStream(Trigger.Once, checkpointLocation = checkpointDir.getAbsolutePath), + awaitTermination(), + CheckNewAnswer((15, 1), (25, 1)), // watermark = 40 is used to generate this + assertEventStats(min = 50, max = 50, avg = 50, wtrmark = 40)) + } + test("append mode") { val inputData = MemoryStream[Int] @@ -625,10 +727,20 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche true } + /** Assert event stats generated on that last batch with data in it */ private def assertEventStats(body: ju.Map[String, String] => Unit): AssertOnQuery = { - AssertOnQuery { q => + Execute("AssertEventStats") { q => body(q.recentProgress.filter(_.numInputRows > 0).lastOption.get.eventTime) - true + } + } + + /** Assert event stats generated on that last batch with data in it */ + private def assertEventStats(min: Long, max: Long, avg: Double, wtrmark: Long): AssertOnQuery = { + assertEventStats { e => + assert(e.get("min") === formatTimestamp(min), s"min value mismatch") + assert(e.get("max") === formatTimestamp(max), s"max value mismatch") + assert(e.get("avg") === formatTimestamp(avg.toLong), s"avg value mismatch") + assert(e.get("watermark") === formatTimestamp(wtrmark), s"watermark value mismatch") } } @@ -638,4 +750,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche private def formatTimestamp(sec: Long): String = { timestampFormat.format(new ju.Date(sec * 1000)) } + + private def awaitTermination(): AssertOnQuery = Execute("AwaitTermination") { q => + q.awaitTermination() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 4c3fd58cb2e45..df22bc1315b7d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -291,8 +291,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be /** Execute arbitrary code */ object Execute { - def apply(func: StreamExecution => Any): AssertOnQuery = - AssertOnQuery(query => { func(query); true }, "Execute") + def apply(name: String)(func: StreamExecution => Any): AssertOnQuery = + AssertOnQuery(query => { func(query); true }, "name") + + def apply(func: StreamExecution => Any): AssertOnQuery = apply("Execute")(func) } object AwaitEpoch { @@ -512,7 +514,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be logInfo(s"Processing test stream action: $action") action match { case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) => - verify(currentStream == null, "stream already running") + verify(currentStream == null || !currentStream.isActive, "stream already running") verify(triggerClock.isInstanceOf[SystemClock] || triggerClock.isInstanceOf[StreamManualClock], "Use either SystemClock or StreamManualClock to start the stream") From cfc3e1aaa44b58da660be3378effdc48e088b9d3 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Mon, 23 Jul 2018 13:04:39 -0700 Subject: [PATCH 1197/2461] [SPARK-24339][SQL] Prunes the unused columns from child of ScriptTransformation ## What changes were proposed in this pull request? Modify the strategy in ColumnPruning to add a Project between ScriptTransformation and its child, this strategy can reduce the scan time especially in the scenario of the table has many columns. ## How was this patch tested? Add UT in ColumnPruningSuite and ScriptTransformationSuite. Author: Yuanjian Li Closes #21839 from xuanyuanking/SPARK-24339. --- .../sql/catalyst/optimizer/Optimizer.scala | 5 +++- .../optimizer/ColumnPruningSuite.scala | 24 +++++++++++++++++++ .../execution/ScriptTransformationSuite.scala | 19 +++++++++++++++ 3 files changed, 47 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 6faecd3efc40d..5ed7412c106fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -501,13 +501,16 @@ object ColumnPruning extends Rule[LogicalPlan] { case d @ DeserializeToObject(_, _, child) if (child.outputSet -- d.references).nonEmpty => d.copy(child = prunedChild(child, d.references)) - // Prunes the unused columns from child of Aggregate/Expand/Generate + // Prunes the unused columns from child of Aggregate/Expand/Generate/ScriptTransformation case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = prunedChild(child, a.references)) case f @ FlatMapGroupsInPandas(_, _, _, child) if (child.outputSet -- f.references).nonEmpty => f.copy(child = prunedChild(child, f.references)) case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => e.copy(child = prunedChild(child, e.references)) + case s @ ScriptTransformation(_, _, _, child, _) + if (child.outputSet -- s.references).nonEmpty => + s.copy(child = prunedChild(child, s.references)) // prune unrequired references case p @ Project(_, g: Generate) if p.references != g.outputSet => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 8b05ba32e6eef..f6db3c90ad96c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -140,6 +140,30 @@ class ColumnPruningSuite extends PlanTest { comparePlans(optimized, expected) } + test("Column pruning for ScriptTransformation") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + val query = + ScriptTransformation( + Seq('a, 'b), + "func", + Seq.empty, + input, + null).analyze + val optimized = Optimize.execute(query) + + val expected = + ScriptTransformation( + Seq('a, 'b), + "func", + Seq.empty, + Project( + Seq('a, 'b), + input), + null).analyze + + comparePlans(optimized, expected) + } + test("Column pruning on Filter") { val input = LocalRelation('a.int, 'b.string, 'c.double) val plan1 = Filter('a > 1, input).analyze diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index 5318b4650b01f..5f73b7170c612 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -136,6 +136,25 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton { } assert(e.getMessage.contains("Subprocess exited with status")) } + + test("SPARK-24339 verify the result after pruning the unused columns") { + val rowsDf = Seq( + ("Bob", 16, 176), + ("Alice", 32, 164), + ("David", 60, 192), + ("Amy", 24, 180)).toDF("name", "age", "height") + + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformationExec( + input = Seq(rowsDf.col("name").expr), + script = "cat", + output = Seq(AttributeReference("name", StringType)()), + child = child, + ioschema = serdeIOSchema + ), + rowsDf.select("name").collect()) + } } private case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode { From d2436a85294a178398525c37833dae79d45c1452 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Tue, 24 Jul 2018 09:33:10 +0800 Subject: [PATCH 1198/2461] [SPARK-24594][YARN] Introducing metrics for YARN MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In this PR metrics are introduced for YARN. As up to now there was no metrics in the YARN module a new metric system is created with the name "applicationMaster". To support both client and cluster mode the metric system lifecycle is bound to the AM. ## How was this patch tested? Both client and cluster mode was tested manually. Before the test on one of the YARN node spark-core was removed to cause the allocation failure. Spark was started as (in case of client mode): ``` spark2-submit \ --class org.apache.spark.examples.SparkPi \ --conf "spark.yarn.blacklist.executor.launch.blacklisting.enabled=true" --conf "spark.blacklist.application.maxFailedExecutorsPerNode=2" --conf "spark.dynamicAllocation.enabled=true" --conf "spark.metrics.conf.*.sink.console.class=org.apache.spark.metrics.sink.ConsoleSink" \ --master yarn \ --deploy-mode client \ original-spark-examples_2.11-2.4.0-SNAPSHOT.jar \ 1000 ``` In both cases the YARN logs contained the new metrics as: ``` $ yarn logs --applicationId application_1529926424933_0015 ... -- Gauges ---------------------------------------------------------------------- application_1531751594108_0046.applicationMaster.numContainersPendingAllocate value = 0 application_1531751594108_0046.applicationMaster.numExecutorsFailed value = 3 application_1531751594108_0046.applicationMaster.numExecutorsRunning value = 9 application_1531751594108_0046.applicationMaster.numLocalityAwareTasks value = 0 application_1531751594108_0046.applicationMaster.numReleasedContainers value = 0 ... ``` Author: “attilapiros” Author: Attila Zsolt Piros <2017933+attilapiros@users.noreply.github.com> Closes #21635 from attilapiros/SPARK-24594. --- docs/monitoring.md | 1 + docs/running-on-yarn.md | 9 +++- .../spark/deploy/yarn/ApplicationMaster.scala | 18 +++++++ .../deploy/yarn/ApplicationMasterSource.scala | 50 +++++++++++++++++++ .../spark/deploy/yarn/YarnAllocator.scala | 8 ++- .../yarn/YarnAllocatorBlacklistTracker.scala | 2 +- .../org/apache/spark/deploy/yarn/config.scala | 5 ++ 7 files changed, 90 insertions(+), 3 deletions(-) create mode 100644 resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterSource.scala diff --git a/docs/monitoring.md b/docs/monitoring.md index 6eaf33135744d..2717dd091c751 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -435,6 +435,7 @@ set of sinks to which metrics are reported. The following instances are currentl * `executor`: A Spark executor. * `driver`: The Spark driver process (the process in which your SparkContext is created). * `shuffleService`: The Spark shuffle service. +* `applicationMaster`: The Spark ApplicationMaster when running on YARN. Each instance can report to zero or more _sinks_. Sinks are contained in the `org.apache.spark.metrics.sink` package: diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 0b265b0cb1b31..1c1f40c028a97 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -421,7 +421,14 @@ To use a custom metrics.properties for the application master and executors, upd spark.blacklist.application.maxFailedExecutorsPerNode. - + + + + +
      keyvalue
      {session.userName} spark.yarn.dist.forceDownloadSchemes (none) - Comma-separated list of schemes for which files will be downloaded to the local disk prior to + Comma-separated list of schemes for which resources will be downloaded to the local disk prior to being added to YARN's distributed cache. For use in cases where the YARN service does not - support schemes that are supported by Spark, like http, https and ftp. + support schemes that are supported by Spark, like http, https and ftp, or jars required to be in the + local YARN client's classpath. Wildcard '*' is denoted to download resources for all the schemes.
      spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.path(none) + Add the Kubernetes Volume named VolumeName of the VolumeType type to the driver pod on the path specified in the value. For example, + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.path=/checkpoint. +
      spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.readOnly(none) + Specify if the mounted volume is read only or not. For example, + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.readOnly=false. +
      spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].options.[OptionName](none) + Configure Kubernetes Volume options passed to the Kubernetes with OptionName as key having specified value, must conform with Kubernetes option format. For example, + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.options.claimName=spark-pvc-claim. +
      spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.path(none) + Add the Kubernetes Volume named VolumeName of the VolumeType type to the executor pod on the path specified in the value. For example, + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.mount.path=/checkpoint. +
      spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.readOnlyfalse + Specify if the mounted volume is read only or not. For example, + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.mount.readOnly=false. +
      spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].options.[OptionName](none) + Configure Kubernetes Volume options passed to the Kubernetes with OptionName as key having specified value. For example, + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.options.claimName=spark-pvc-claim. +
      spark.kubernetes.memoryOverheadFactor
      cascadeTruncate + This is a JDBC writer related option. If enabled and supported by the JDBC database (PostgreSQL and Oracle at the moment), this options allows execution of a TRUNCATE TABLE t CASCADE (in the case of PostgreSQL a TRUNCATE TABLE ONLY t CASCADE is executed to prevent inadvertently truncating descendant tables). This will affect other tables, and thus should be used with care. This option applies only to writing. It defaults to the default cascading truncate behaviour of the JDBC database in question, specified in the isCascadeTruncate in each JDBCDialect. +
      createTableOptions
      spark.yarn.metrics.namespace(none) + The root namespace for AM metrics reporting. + If it is not set then the YARN application ID is used. +
      # Important notes diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index ecc576910db9e..55ed114f8500f 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -43,6 +43,7 @@ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.deploy.yarn.security.AMCredentialRenewer import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.metrics.MetricsSystem import org.apache.spark.rpc._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ @@ -67,6 +68,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends private val securityMgr = new SecurityManager(sparkConf) + private var metricsSystem: Option[MetricsSystem] = None + // Set system properties for each config entry. This covers two use cases: // - The default configuration stored by the SparkHadoopUtil class // - The user application creating a new SparkConf in cluster mode @@ -309,6 +312,16 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION, "Uncaught exception: " + StringUtils.stringifyException(e)) + } finally { + try { + metricsSystem.foreach { ms => + ms.report() + ms.stop() + } + } catch { + case e: Exception => + logWarning("Exception during stopping of the metric system: ", e) + } } } @@ -434,6 +447,11 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverRef)) allocator.allocateResources() + val ms = MetricsSystem.createMetricsSystem("applicationMaster", sparkConf, securityMgr) + val prefix = _sparkConf.get(YARN_METRICS_NAMESPACE).getOrElse(appId) + ms.registerSource(new ApplicationMasterSource(prefix, allocator)) + ms.start() + metricsSystem = Some(ms) reporterThread = launchReporterThread() } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterSource.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterSource.scala new file mode 100644 index 0000000000000..0fec916582602 --- /dev/null +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterSource.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import com.codahale.metrics.{Gauge, MetricRegistry} + +import org.apache.spark.metrics.source.Source + +private[spark] class ApplicationMasterSource(prefix: String, yarnAllocator: YarnAllocator) + extends Source { + + override val sourceName: String = prefix + ".applicationMaster" + override val metricRegistry: MetricRegistry = new MetricRegistry() + + metricRegistry.register(MetricRegistry.name("numExecutorsFailed"), new Gauge[Int] { + override def getValue: Int = yarnAllocator.getNumExecutorsFailed + }) + + metricRegistry.register(MetricRegistry.name("numExecutorsRunning"), new Gauge[Int] { + override def getValue: Int = yarnAllocator.getNumExecutorsRunning + }) + + metricRegistry.register(MetricRegistry.name("numReleasedContainers"), new Gauge[Int] { + override def getValue: Int = yarnAllocator.getNumReleasedContainers + }) + + metricRegistry.register(MetricRegistry.name("numLocalityAwareTasks"), new Gauge[Int] { + override def getValue: Int = yarnAllocator.numLocalityAwareTasks + }) + + metricRegistry.register(MetricRegistry.name("numContainersPendingAllocate"), new Gauge[Int] { + override def getValue: Int = yarnAllocator.numContainersPendingAllocate + }) + +} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index fae054e0eea00..40f1222fcd83f 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -150,7 +150,7 @@ private[yarn] class YarnAllocator( private var hostToLocalTaskCounts: Map[String, Int] = Map.empty // Number of tasks that have locality preferences in active stages - private var numLocalityAwareTasks: Int = 0 + private[yarn] var numLocalityAwareTasks: Int = 0 // A container placement strategy based on pending tasks' locality preference private[yarn] val containerPlacementStrategy = @@ -158,6 +158,8 @@ private[yarn] class YarnAllocator( def getNumExecutorsRunning: Int = runningExecutors.size() + def getNumReleasedContainers: Int = releasedContainers.size() + def getNumExecutorsFailed: Int = failureTracker.numFailedExecutors def isAllNodeBlacklisted: Boolean = allocatorBlacklistTracker.isAllNodeBlacklisted @@ -167,6 +169,10 @@ private[yarn] class YarnAllocator( */ def getPendingAllocate: Seq[ContainerRequest] = getPendingAtLocation(ANY_HOST) + def numContainersPendingAllocate: Int = synchronized { + getPendingAllocate.size + } + /** * A sequence of pending container requests at the given location that have not yet been * fulfilled. diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala index 1b48a0ee7ad32..ceac7cda5f8be 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala @@ -28,7 +28,7 @@ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.scheduler.BlacklistTracker -import org.apache.spark.util.{Clock, SystemClock, Utils} +import org.apache.spark.util.{Clock, SystemClock} /** * YarnAllocatorBlacklistTracker is responsible for tracking the blacklisted nodes diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 129084a86597a..1013fd2cc4a82 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -152,6 +152,11 @@ package object config { .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("100s") + private[spark] val YARN_METRICS_NAMESPACE = ConfigBuilder("spark.yarn.metrics.namespace") + .doc("The root namespace for AM metrics reporting.") + .stringConf + .createOptional + private[spark] val AM_NODE_LABEL_EXPRESSION = ConfigBuilder("spark.yarn.am.nodeLabelExpression") .doc("Node label expression for the AM.") .stringConf From 13a67b070d335bb257d13dacadea3450885c3d81 Mon Sep 17 00:00:00 2001 From: 10129659 Date: Mon, 23 Jul 2018 23:05:08 -0700 Subject: [PATCH 1199/2461] [SPARK-24870][SQL] Cache can't work normally if there are case letters in SQL ## What changes were proposed in this pull request? Modified the canonicalized to not case-insensitive. Before the PR, cache can't work normally if there are case letters in SQL, for example: sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING) USING hive") sql("select key, sum(case when Key > 0 then 1 else 0 end) as positiveNum " + "from src group by key").cache().createOrReplaceTempView("src_cache") sql( s"""select a.key from (select key from src_cache where positiveNum = 1)a left join (select key from src_cache )b on a.key=b.key """).explain The physical plan of the sql is: ![image](https://user-images.githubusercontent.com/26834091/42979518-3decf0fa-8c05-11e8-9837-d5e4c334cb1f.png) The subquery "select key from src_cache where positiveNum = 1" on the left of join can use the cache data, but the subquery "select key from src_cache" on the right of join cannot use the cache data. ## How was this patch tested? new added test Author: 10129659 Closes #21823 from eatoncys/canonicalized. --- .../spark/sql/catalyst/plans/QueryPlan.scala | 2 +- .../spark/sql/execution/SameResultSuite.scala | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 4b4722b0b2117..b1ffdca091461 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -284,7 +284,7 @@ object QueryPlan extends PredicateHelper { if (ordinal == -1) { ar } else { - ar.withExprId(ExprId(ordinal)) + ar.withExprId(ExprId(ordinal)).canonicalized } }.canonicalized.asInstanceOf[T] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala index aaf51b5b90111..d088e24e53bfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala @@ -18,8 +18,11 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.IntegerType /** * Tests for the sameResult function for [[SparkPlan]]s. @@ -58,4 +61,16 @@ class SameResultSuite extends QueryTest with SharedSQLContext { val df4 = spark.range(10).agg(sumDistinct($"id")) assert(df3.queryExecution.executedPlan.sameResult(df4.queryExecution.executedPlan)) } + + test("Canonicalized result is case-insensitive") { + val a = AttributeReference("A", IntegerType)() + val b = AttributeReference("B", IntegerType)() + val planUppercase = Project(Seq(a), LocalRelation(a, b)) + + val c = AttributeReference("a", IntegerType)() + val d = AttributeReference("b", IntegerType)() + val planLowercase = Project(Seq(c), LocalRelation(c, d)) + + assert(planUppercase.sameResult(planLowercase)) + } } From 3d5c61e5fd24f07302e39b5d61294da79aa0c2f9 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 24 Jul 2018 19:51:09 +0800 Subject: [PATCH 1200/2461] [SPARK-22499][FOLLOWUP][SQL] Reduce input string expressions for Least and Greatest to reduce time in its test ## What changes were proposed in this pull request? It's minor and trivial but looks 2000 input is good enough to reproduce and test in SPARK-22499. ## How was this patch tested? Manually brought the change and tested. Locally tested: Before: 3m 21s 288ms After: 1m 29s 134ms Given the latest successful build took: ``` ArithmeticExpressionSuite: - SPARK-22499: Least and greatest should not generate codes beyond 64KB (7 minutes, 49 seconds) ``` I expect it's going to save 4ish mins. Author: hyukjinkwon Closes #21855 from HyukjinKwon/minor-fix-suite. --- .../sql/catalyst/expressions/ArithmeticExpressionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 021217606dc03..9a752af523ffc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -349,7 +349,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } test("SPARK-22499: Least and greatest should not generate codes beyond 64KB") { - val N = 3000 + val N = 2000 val strings = (1 to N).map(x => "s" * x) val inputsExpr = strings.map(Literal.create(_, StringType)) From 9d27541a856d95635386cbc98f2bb1f1f2f30c13 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Tue, 24 Jul 2018 10:46:36 -0700 Subject: [PATCH 1201/2461] [SPARK-23325] Use InternalRow when reading with DataSourceV2. ## What changes were proposed in this pull request? This updates the DataSourceV2 API to use InternalRow instead of Row for the default case with no scan mix-ins. Support for readers that produce Row is added through SupportsDeprecatedScanRow, which matches the previous API. Readers that used Row now implement this class and should be migrated to InternalRow. Readers that previously implemented SupportsScanUnsafeRow have been migrated to use no SupportsScan mix-ins and produce InternalRow. ## How was this patch tested? This uses existing tests. Author: Ryan Blue Closes #21118 from rdblue/SPARK-23325-datasource-v2-internal-row. --- .../sql/kafka010/KafkaContinuousReader.scala | 16 +++++----- .../sql/kafka010/KafkaMicroBatchReader.scala | 21 ++++++------- .../kafka010/KafkaMicroBatchSourceSuite.scala | 2 +- .../sources/v2/reader/DataSourceReader.java | 6 ++-- .../v2/reader/InputPartitionReader.java | 7 +++-- ...ow.java => SupportsDeprecatedScanRow.java} | 25 ++++++---------- .../v2/reader/SupportsScanColumnarBatch.java | 4 +-- .../datasources/v2/DataSourceRDD.scala | 1 - .../datasources/v2/DataSourceV2ScanExec.scala | 17 ++++++----- .../datasources/v2/DataSourceV2Strategy.scala | 13 ++++---- .../continuous/ContinuousDataSourceRDD.scala | 26 ++++++++-------- .../ContinuousQueuedDataReader.scala | 8 ++--- .../ContinuousRateStreamSource.scala | 4 +-- .../sql/execution/streaming/memory.scala | 16 +++++----- .../sources/ContinuousMemoryStream.scala | 7 +++-- .../sources/RateStreamMicroBatchReader.scala | 4 +-- .../execution/streaming/sources/socket.scala | 7 +++-- .../sources/v2/JavaAdvancedDataSourceV2.java | 4 +-- .../v2/JavaPartitionAwareDataSource.java | 4 +-- .../v2/JavaSchemaRequiredDataSource.java | 5 ++-- .../sources/v2/JavaSimpleDataSourceV2.java | 5 ++-- .../sources/v2/JavaUnsafeRowDataSourceV2.java | 9 +++--- .../sources/RateStreamProviderSuite.scala | 6 ++-- .../sql/sources/v2/DataSourceV2Suite.scala | 30 ++++++++++--------- .../sources/v2/SimpleWritableDataSource.scala | 7 +++-- .../sql/streaming/StreamingQuerySuite.scala | 6 ++-- .../ContinuousQueuedDataReaderSuite.scala | 6 ++-- .../sources/StreamingDataSourceV2Suite.scala | 7 +++-- 28 files changed, 138 insertions(+), 135 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{SupportsScanUnsafeRow.java => SupportsDeprecatedScanRow.java} (62%) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index badaa69cc303c..48b91dfe764e9 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -26,6 +26,7 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ @@ -53,7 +54,7 @@ class KafkaContinuousReader( metadataPath: String, initialOffsets: KafkaOffsetRangeLimit, failOnDataLoss: Boolean) - extends ContinuousReader with SupportsScanUnsafeRow with Logging { + extends ContinuousReader with Logging { private lazy val session = SparkSession.getActiveSession.get private lazy val sc = session.sparkContext @@ -86,7 +87,7 @@ class KafkaContinuousReader( KafkaSourceOffset(JsonUtils.partitionOffsets(json)) } - override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { + override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { import scala.collection.JavaConverters._ val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) @@ -107,8 +108,8 @@ class KafkaContinuousReader( startOffsets.toSeq.map { case (topicPartition, start) => KafkaContinuousInputPartition( - topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) - .asInstanceOf[InputPartition[UnsafeRow]] + topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss + ): InputPartition[InternalRow] }.asJava } @@ -161,9 +162,10 @@ case class KafkaContinuousInputPartition( startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousInputPartition[UnsafeRow] { + failOnDataLoss: Boolean) extends ContinuousInputPartition[InternalRow] { - override def createContinuousReader(offset: PartitionOffset): InputPartitionReader[UnsafeRow] = { + override def createContinuousReader( + offset: PartitionOffset): InputPartitionReader[InternalRow] = { val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset] require(kafkaOffset.topicPartition == topicPartition, s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}") @@ -192,7 +194,7 @@ class KafkaContinuousInputPartitionReader( startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[UnsafeRow] { + failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[InternalRow] { private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false) private val converter = new KafkaRecordToUnsafeRowConverter diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index 737da2e51b125..6c95b2b2560c4 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -29,11 +29,12 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset} import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.UninterruptibleThread @@ -61,7 +62,7 @@ private[kafka010] class KafkaMicroBatchReader( metadataPath: String, startingOffsets: KafkaOffsetRangeLimit, failOnDataLoss: Boolean) - extends MicroBatchReader with SupportsScanUnsafeRow with Logging { + extends MicroBatchReader with Logging { private var startPartitionOffsets: PartitionOffsetMap = _ private var endPartitionOffsets: PartitionOffsetMap = _ @@ -101,7 +102,7 @@ private[kafka010] class KafkaMicroBatchReader( } } - override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { + override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { // Find the new partitions, and get their earliest offsets val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) @@ -142,11 +143,11 @@ private[kafka010] class KafkaMicroBatchReader( val reuseKafkaConsumer = offsetRanges.map(_.topicPartition).toSet.size == offsetRanges.size // Generate factories based on the offset ranges - val factories = offsetRanges.map { range => + offsetRanges.map { range => new KafkaMicroBatchInputPartition( - range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) - } - factories.map(_.asInstanceOf[InputPartition[UnsafeRow]]).asJava + range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer + ): InputPartition[InternalRow] + }.asJava } override def getStartOffset: Offset = { @@ -305,11 +306,11 @@ private[kafka010] case class KafkaMicroBatchInputPartition( executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends InputPartition[UnsafeRow] { + reuseKafkaConsumer: Boolean) extends InputPartition[InternalRow] { override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray - override def createPartitionReader(): InputPartitionReader[UnsafeRow] = + override def createPartitionReader(): InputPartitionReader[InternalRow] = new KafkaMicroBatchInputPartitionReader(offsetRange, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) } @@ -320,7 +321,7 @@ private[kafka010] case class KafkaMicroBatchInputPartitionReader( executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends InputPartitionReader[UnsafeRow] with Logging { + reuseKafkaConsumer: Boolean) extends InputPartitionReader[InternalRow] with Logging { private val consumer = KafkaDataConsumer.acquire( offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index c6412eac97dba..5d5e57323cff5 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -678,7 +678,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))), Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L))) ) - val factories = reader.planUnsafeInputPartitions().asScala + val factories = reader.planInputPartitions().asScala .map(_.asInstanceOf[KafkaMicroBatchInputPartition]) withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") { assert(factories.size == numPartitionsGenerated) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java index 36a3e542b5a11..ad9c838992fa8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java @@ -20,7 +20,7 @@ import java.util.List; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; @@ -43,7 +43,7 @@ * Names of these interfaces start with `SupportsScan`. Note that a reader should only * implement at most one of the special scans, if more than one special scans are implemented, * only one of them would be respected, according to the priority list from high to low: - * {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}. + * {@link SupportsScanColumnarBatch}, {@link SupportsDeprecatedScanRow}. * * If an exception was throw when applying any of these query optimizations, the action will fail * and no Spark job will be submitted. @@ -76,5 +76,5 @@ public interface DataSourceReader { * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ - List> planInputPartitions(); + List> planInputPartitions(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java index 33fa7be4c1b20..7cf382e52f67e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java @@ -26,9 +26,10 @@ * An input partition reader returned by {@link InputPartition#createPartitionReader()} and is * responsible for outputting data for a RDD partition. * - * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal input - * partition readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for input - * partition readers that mix in {@link SupportsScanUnsafeRow}. + * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow} + * for normal data source readers, {@link org.apache.spark.sql.vectorized.ColumnarBatch} for data + * source readers that mix in {@link SupportsScanColumnarBatch}, or {@link org.apache.spark.sql.Row} + * for data source readers that mix in {@link SupportsDeprecatedScanRow}. */ @InterfaceStability.Evolving public interface InputPartitionReader extends Closeable { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsDeprecatedScanRow.java similarity index 62% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsDeprecatedScanRow.java index f2220f6d31093..595943cf4d8ac 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsDeprecatedScanRow.java @@ -17,30 +17,23 @@ package org.apache.spark.sql.sources.v2.reader; -import java.util.List; - import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.InternalRow; + +import java.util.List; /** * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to output {@link UnsafeRow} directly and avoid the row copy at Spark side. - * This is an experimental and unstable interface, as {@link UnsafeRow} is not public and may get - * changed in the future Spark versions. + * interface to output {@link Row} instead of {@link InternalRow}. + * This is an experimental and unstable interface. */ @InterfaceStability.Unstable -public interface SupportsScanUnsafeRow extends DataSourceReader { - - @Override - default List> planInputPartitions() { +public interface SupportsDeprecatedScanRow extends DataSourceReader { + default List> planInputPartitions() { throw new IllegalStateException( - "planInputPartitions not supported by default within SupportsScanUnsafeRow"); + "planInputPartitions not supported by default within SupportsDeprecatedScanRow"); } - /** - * Similar to {@link DataSourceReader#planInputPartitions()}, - * but returns data in unsafe row format. - */ - List> planUnsafeInputPartitions(); + List> planRowInputPartitions(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java index 0faf81db24605..f4da686740d11 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java @@ -20,7 +20,7 @@ import java.util.List; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.vectorized.ColumnarBatch; /** @@ -30,7 +30,7 @@ @InterfaceStability.Evolving public interface SupportsScanColumnarBatch extends DataSourceReader { @Override - default List> planInputPartitions() { + default List> planInputPartitions() { throw new IllegalStateException( "planInputPartitions not supported by default within SupportsScanColumnarBatch."); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 8d6fb3820d420..7ea53424ae100 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.datasources.v2 -import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index c6a7684bf6ab0..b030b9a929b08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -75,12 +75,13 @@ case class DataSourceV2ScanExec( case _ => super.outputPartitioning } - private lazy val partitions: Seq[InputPartition[UnsafeRow]] = reader match { - case r: SupportsScanUnsafeRow => r.planUnsafeInputPartitions().asScala - case _ => - reader.planInputPartitions().asScala.map { - new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[UnsafeRow] + private lazy val partitions: Seq[InputPartition[InternalRow]] = reader match { + case r: SupportsDeprecatedScanRow => + r.planRowInputPartitions().asScala.map { + new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[InternalRow] } + case _ => + reader.planInputPartitions().asScala } private lazy val batchPartitions: Seq[InputPartition[ColumnarBatch]] = reader match { @@ -132,11 +133,11 @@ case class DataSourceV2ScanExec( } class RowToUnsafeRowInputPartition(partition: InputPartition[Row], schema: StructType) - extends InputPartition[UnsafeRow] { + extends InputPartition[InternalRow] { override def preferredLocations: Array[String] = partition.preferredLocations - override def createPartitionReader: InputPartitionReader[UnsafeRow] = { + override def createPartitionReader: InputPartitionReader[InternalRow] = { new RowToUnsafeInputPartitionReader( partition.createPartitionReader, RowEncoder.apply(schema).resolveAndBind()) } @@ -146,7 +147,7 @@ class RowToUnsafeInputPartitionReader( val rowReader: InputPartitionReader[Row], encoder: ExpressionEncoder[Row]) - extends InputPartitionReader[UnsafeRow] { + extends InputPartitionReader[InternalRow] { override def next: Boolean = rowReader.next diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 2a7f1de2c7c19..9414e68155b98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -125,16 +125,13 @@ object DataSourceV2Strategy extends Strategy { val filterCondition = postScanFilters.reduceLeftOption(And) val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) - val withProjection = if (withFilter.output != project) { - ProjectExec(project, withFilter) - } else { - withFilter - } - - withProjection :: Nil + // always add the projection, which will produce unsafe rows required by some operators + ProjectExec(project, withFilter) :: Nil case r: StreamingDataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil + // ensure there is a projection, which will produce unsafe rows required by some operators + ProjectExec(r.output, + DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader)) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala index 73868d5967e90..1ffa1d02f1432 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -19,16 +19,16 @@ package org.apache.spark.sql.execution.streaming.continuous import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeInputPartitionReader} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.v2.RowToUnsafeInputPartitionReader import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, PartitionOffset} -import org.apache.spark.util.{NextIterator, ThreadUtils} +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader +import org.apache.spark.util.NextIterator class ContinuousDataSourceRDDPartition( val index: Int, - val inputPartition: InputPartition[UnsafeRow]) + val inputPartition: InputPartition[InternalRow]) extends Partition with Serializable { // This is semantically a lazy val - it's initialized once the first time a call to @@ -51,8 +51,8 @@ class ContinuousDataSourceRDD( sc: SparkContext, dataQueueSize: Int, epochPollIntervalMs: Long, - private val readerInputPartitions: Seq[InputPartition[UnsafeRow]]) - extends RDD[UnsafeRow](sc, Nil) { + private val readerInputPartitions: Seq[InputPartition[InternalRow]]) + extends RDD[InternalRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { readerInputPartitions.zipWithIndex.map { @@ -64,7 +64,7 @@ class ContinuousDataSourceRDD( * Initialize the shared reader for this partition if needed, then read rows from it until * it returns null to signal the end of the epoch. */ - override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { // If attempt number isn't 0, this is a task retry, which we don't support. if (context.attemptNumber() != 0) { throw new ContinuousTaskRetryException() @@ -80,8 +80,8 @@ class ContinuousDataSourceRDD( partition.queueReader } - new NextIterator[UnsafeRow] { - override def getNext(): UnsafeRow = { + new NextIterator[InternalRow] { + override def getNext(): InternalRow = { readerForPartition.next() match { case null => finished = true @@ -101,9 +101,9 @@ class ContinuousDataSourceRDD( object ContinuousDataSourceRDD { private[continuous] def getContinuousReader( - reader: InputPartitionReader[UnsafeRow]): ContinuousInputPartitionReader[_] = { + reader: InputPartitionReader[InternalRow]): ContinuousInputPartitionReader[_] = { reader match { - case r: ContinuousInputPartitionReader[UnsafeRow] => r + case r: ContinuousInputPartitionReader[InternalRow] => r case wrapped: RowToUnsafeInputPartitionReader => wrapped.rowReader.asInstanceOf[ContinuousInputPartitionReader[Row]] case _ => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala index 8c74b8244d096..bfb87053db475 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -24,7 +24,7 @@ import scala.util.control.NonFatal import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset import org.apache.spark.util.ThreadUtils @@ -52,7 +52,7 @@ class ContinuousQueuedDataReader( */ sealed trait ContinuousRecord case object EpochMarker extends ContinuousRecord - case class ContinuousRow(row: UnsafeRow, offset: PartitionOffset) extends ContinuousRecord + case class ContinuousRow(row: InternalRow, offset: PartitionOffset) extends ContinuousRecord private val queue = new ArrayBlockingQueue[ContinuousRecord](dataQueueSize) @@ -79,12 +79,12 @@ class ContinuousQueuedDataReader( } /** - * Return the next UnsafeRow to be read in the current epoch, or null if the epoch is done. + * Return the next row to be read in the current epoch, or null if the epoch is done. * * After returning null, the [[ContinuousDataSourceRDD]] compute() for the following epoch * will call next() again to start getting rows. */ - def next(): UnsafeRow = { + def next(): InternalRow = { val POLL_TIMEOUT_MS = 1000 var currentEntry: ContinuousRecord = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 516a563bdcc7a..55ce3ae38ee3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -35,7 +35,7 @@ case class RateStreamPartitionOffset( partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset class RateStreamContinuousReader(options: DataSourceOptions) - extends ContinuousReader { + extends ContinuousReader with SupportsDeprecatedScanRow { implicit val defaultFormats: DefaultFormats = DefaultFormats val creationTime = System.currentTimeMillis() @@ -67,7 +67,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) override def getStartOffset(): Offset = offset - override def planInputPartitions(): java.util.List[InputPartition[Row]] = { + override def planRowInputPartitions(): java.util.List[InputPartition[Row]] = { val partitionStartMap = offset match { case off: RateStreamOffset => off.partitionToValueAndRunTimeMs case off => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index b137f98045c5a..f81abdcc3711a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -28,12 +28,13 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -79,8 +80,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas * available. */ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) - extends MemoryStreamBase[A](sqlContext) - with MicroBatchReader with SupportsScanUnsafeRow with Logging { + extends MemoryStreamBase[A](sqlContext) with MicroBatchReader with Logging { protected val logicalPlan: LogicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) @@ -139,7 +139,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) if (endOffset.offset == -1) null else endOffset } - override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { + override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { synchronized { // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) val startOrdinal = startOffset.offset.toInt + 1 @@ -156,7 +156,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) newBlocks.map { block => - new MemoryStreamInputPartition(block).asInstanceOf[InputPartition[UnsafeRow]] + new MemoryStreamInputPartition(block): InputPartition[InternalRow] }.asJava } } @@ -202,9 +202,9 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) class MemoryStreamInputPartition(records: Array[UnsafeRow]) - extends InputPartition[UnsafeRow] { - override def createPartitionReader(): InputPartitionReader[UnsafeRow] = { - new InputPartitionReader[UnsafeRow] { + extends InputPartition[InternalRow] { + override def createPartitionReader(): InputPartitionReader[InternalRow] = { + new InputPartitionReader[InternalRow] { private var currentIndex = -1 override def next(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index 0bf90b8063326..e776ebc08e30d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.{Encoder, Row, SQLContext} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions} -import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.sources.v2.reader.{InputPartition, SupportsDeprecatedScanRow} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.RpcUtils @@ -49,7 +49,8 @@ import org.apache.spark.util.RpcUtils * the specified offset within the list, or null if that offset doesn't yet have a record. */ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2) - extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { + extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport + with SupportsDeprecatedScanRow { private implicit val formats = Serialization.formats(NoTypeHints) protected val logicalPlan = @@ -99,7 +100,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa ) } - override def planInputPartitions(): ju.List[InputPartition[Row]] = { + override def planRowInputPartitions(): ju.List[InputPartition[Row]] = { synchronized { val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id" endpointRef = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index b393c48baee8d..7a3452aa315cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.util.{ManualClock, SystemClock} class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) - extends MicroBatchReader with Logging { + extends MicroBatchReader with SupportsDeprecatedScanRow with Logging { import RateStreamProvider._ private[sources] val clock = { @@ -134,7 +134,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: LongOffset(json.toLong) } - override def planInputPartitions(): java.util.List[InputPartition[Row]] = { + override def planRowInputPartitions(): java.util.List[InputPartition[Row]] = { val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 91e3b7179c34a..e3a2c007a9ce4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.LongOffset import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsDeprecatedScanRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} @@ -50,7 +50,8 @@ object TextSocketMicroBatchReader { * debugging. This MicroBatchReader will *not* work in production applications due to multiple * reasons, including no support for fault recovery. */ -class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader with Logging { +class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader + with SupportsDeprecatedScanRow with Logging { private var startOffset: Offset = _ private var endOffset: Offset = _ @@ -141,7 +142,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR } } - override def planInputPartitions(): JList[InputPartition[Row]] = { + override def planRowInputPartitions(): JList[InputPartition[Row]] = { assert(startOffset != null && endOffset != null, "start offset and end offset should already be set before create read tasks.") diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 445cb29f5ee3a..c130b5f1e2513 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -33,7 +33,7 @@ public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport { public class Reader implements DataSourceReader, SupportsPushDownRequiredColumns, - SupportsPushDownFilters { + SupportsPushDownFilters, SupportsDeprecatedScanRow { // Exposed for testing. public StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); @@ -79,7 +79,7 @@ public Filter[] pushedFilters() { } @Override - public List> planInputPartitions() { + public List> planRowInputPartitions() { List> res = new ArrayList<>(); Integer lowerBound = null; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index e49c8cf8b9e16..35aafb532d80d 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -34,7 +34,7 @@ public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { - class Reader implements DataSourceReader, SupportsReportPartitioning { + class Reader implements DataSourceReader, SupportsReportPartitioning, SupportsDeprecatedScanRow { private final StructType schema = new StructType().add("a", "int").add("b", "int"); @Override @@ -43,7 +43,7 @@ public StructType readSchema() { } @Override - public List> planInputPartitions() { + public List> planRowInputPartitions() { return java.util.Arrays.asList( new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}), new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2})); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index 80eeffd95f83b..6dee94c34e21c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -25,11 +25,12 @@ import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.SupportsDeprecatedScanRow; import org.apache.spark.sql.types.StructType; public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema { - class Reader implements DataSourceReader { + class Reader implements DataSourceReader, SupportsDeprecatedScanRow { private final StructType schema; Reader(StructType schema) { @@ -42,7 +43,7 @@ public StructType readSchema() { } @Override - public List> planInputPartitions() { + public List> planRowInputPartitions() { return java.util.Collections.emptyList(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 8522a63898a3b..5c2f351975c74 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -28,11 +28,12 @@ import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; +import org.apache.spark.sql.sources.v2.reader.SupportsDeprecatedScanRow; import org.apache.spark.sql.types.StructType; public class JavaSimpleDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceReader { + class Reader implements DataSourceReader, SupportsDeprecatedScanRow { private final StructType schema = new StructType().add("i", "int").add("j", "int"); @Override @@ -41,7 +42,7 @@ public StructType readSchema() { } @Override - public List> planInputPartitions() { + public List> planRowInputPartitions() { return java.util.Arrays.asList( new JavaSimpleInputPartition(0, 5), new JavaSimpleInputPartition(5, 10)); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java index 3ad8e7a0104ce..25b89c7fd36a9 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.List; +import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; @@ -29,7 +30,7 @@ public class JavaUnsafeRowDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceReader, SupportsScanUnsafeRow { + class Reader implements DataSourceReader { private final StructType schema = new StructType().add("i", "int").add("j", "int"); @Override @@ -38,7 +39,7 @@ public StructType readSchema() { } @Override - public List> planUnsafeInputPartitions() { + public List> planInputPartitions() { return java.util.Arrays.asList( new JavaUnsafeRowInputPartition(0, 5), new JavaUnsafeRowInputPartition(5, 10)); @@ -46,7 +47,7 @@ public List> planUnsafeInputPartitions() { } static class JavaUnsafeRowInputPartition - implements InputPartition, InputPartitionReader { + implements InputPartition, InputPartitionReader { private int start; private int end; private UnsafeRow row; @@ -59,7 +60,7 @@ static class JavaUnsafeRowInputPartition } @Override - public InputPartitionReader createPartitionReader() { + public InputPartitionReader createPartitionReader() { return new JavaUnsafeRowInputPartition(start - 1, end); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index 9115a384d0790..260a0376daeb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -146,7 +146,7 @@ class RateSourceSuite extends StreamTest { val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.planInputPartitions() + val tasks = reader.planRowInputPartitions() assert(tasks.size == 1) val dataReader = tasks.get(0).createPartitionReader() val data = ArrayBuffer[Row]() @@ -165,7 +165,7 @@ class RateSourceSuite extends StreamTest { val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.planInputPartitions() + val tasks = reader.planRowInputPartitions() assert(tasks.size == 11) val readData = tasks.asScala @@ -311,7 +311,7 @@ class RateSourceSuite extends StreamTest { val reader = new RateStreamContinuousReader( new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) reader.setStartOffset(Optional.empty()) - val tasks = reader.planInputPartitions() + val tasks = reader.planRowInputPartitions() assert(tasks.size == 2) val data = scala.collection.mutable.ListBuffer[Row]() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index e96cd4500458d..d73eebbc84b71 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -23,6 +23,7 @@ import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec} import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} @@ -344,10 +345,10 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader { + class Reader extends DataSourceReader with SupportsDeprecatedScanRow { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def planInputPartitions(): JList[InputPartition[Row]] = { + override def planRowInputPartitions(): JList[InputPartition[Row]] = { java.util.Arrays.asList(new SimpleInputPartition(0, 5)) } } @@ -357,10 +358,10 @@ class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader { + class Reader extends DataSourceReader with SupportsDeprecatedScanRow { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def planInputPartitions(): JList[InputPartition[Row]] = { + override def planRowInputPartitions(): JList[InputPartition[Row]] = { java.util.Arrays.asList(new SimpleInputPartition(0, 5), new SimpleInputPartition(5, 10)) } } @@ -390,7 +391,7 @@ class SimpleInputPartition(start: Int, end: Int) class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader + class Reader extends DataSourceReader with SupportsDeprecatedScanRow with SupportsPushDownRequiredColumns with SupportsPushDownFilters { var requiredSchema = new StructType().add("i", "int").add("j", "int") @@ -415,7 +416,7 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { requiredSchema } - override def planInputPartitions(): JList[InputPartition[Row]] = { + override def planRowInputPartitions(): JList[InputPartition[Row]] = { val lowerBound = filters.collect { case GreaterThan("i", v: Int) => v }.headOption @@ -467,10 +468,10 @@ class AdvancedInputPartition(start: Int, end: Int, requiredSchema: StructType) class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader with SupportsScanUnsafeRow { + class Reader extends DataSourceReader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def planUnsafeInputPartitions(): JList[InputPartition[UnsafeRow]] = { + override def planInputPartitions(): JList[InputPartition[InternalRow]] = { java.util.Arrays.asList(new UnsafeRowInputPartitionReader(0, 5), new UnsafeRowInputPartitionReader(5, 10)) } @@ -480,14 +481,14 @@ class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { } class UnsafeRowInputPartitionReader(start: Int, end: Int) - extends InputPartition[UnsafeRow] with InputPartitionReader[UnsafeRow] { + extends InputPartition[InternalRow] with InputPartitionReader[InternalRow] { private val row = new UnsafeRow(2) row.pointTo(new Array[Byte](8 * 3), 8 * 3) private var current = start - 1 - override def createPartitionReader(): InputPartitionReader[UnsafeRow] = this + override def createPartitionReader(): InputPartitionReader[InternalRow] = this override def next(): Boolean = { current += 1 @@ -504,8 +505,8 @@ class UnsafeRowInputPartitionReader(start: Int, end: Int) class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { - class Reader(val readSchema: StructType) extends DataSourceReader { - override def planInputPartitions(): JList[InputPartition[Row]] = + class Reader(val readSchema: StructType) extends DataSourceReader with SupportsDeprecatedScanRow { + override def planRowInputPartitions(): JList[InputPartition[Row]] = java.util.Collections.emptyList() } @@ -568,10 +569,11 @@ class BatchInputPartitionReader(start: Int, end: Int) class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader with SupportsReportPartitioning { + class Reader extends DataSourceReader with SupportsReportPartitioning + with SupportsDeprecatedScanRow { override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") - override def planInputPartitions(): JList[InputPartition[Row]] = { + override def planRowInputPartitions(): JList[InputPartition[Row]] = { // Note that we don't have same value of column `a` across partitions. java.util.Arrays.asList( new SpecificInputPartitionReader(Array(1, 1, 3), Array(4, 4, 6)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 1334cf71ae988..98d7eedbcb9c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader, SupportsDeprecatedScanRow} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -42,10 +42,11 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS private val schema = new StructType().add("i", "long").add("j", "long") - class Reader(path: String, conf: Configuration) extends DataSourceReader { + class Reader(path: String, conf: Configuration) extends DataSourceReader + with SupportsDeprecatedScanRow { override def readSchema(): StructType = schema - override def planInputPartitions(): JList[InputPartition[Row]] = { + override def planRowInputPartitions(): JList[InputPartition[Row]] = { val dataPath = new Path(path) val fs = dataPath.getFileSystem(conf) if (fs.exists(dataPath)) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 936a076d647b6..78199b0a1c19a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -30,7 +30,7 @@ import org.scalatest.mockito.MockitoSugar import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ @@ -227,10 +227,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } // getBatch should take 100 ms the first time it is called - override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { + override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { synchronized { clock.waitTillTime(1350) - super.planUnsafeInputPartitions() + super.planInputPartitions() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index 0e7e6febb53df..4f198819b58d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.streaming.continuous import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} -import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Mockito._ import org.scalatest.mockito.MockitoSugar import org.apache.spark.{SparkEnv, SparkFunSuite, TaskContext} import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader.InputPartition @@ -73,8 +73,8 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { */ private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = { val queue = new ArrayBlockingQueue[UnsafeRow](1024) - val factory = new InputPartition[UnsafeRow] { - override def createPartitionReader() = new ContinuousInputPartitionReader[UnsafeRow] { + val factory = new InputPartition[InternalRow] { + override def createPartitionReader() = new ContinuousInputPartitionReader[InternalRow] { var index = -1 var curr: UnsafeRow = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index c1a28b9bc75ef..7c012158bd751 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -26,14 +26,15 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.sources.v2.reader.{InputPartition, SupportsDeprecatedScanRow} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils -case class FakeReader() extends MicroBatchReader with ContinuousReader { +case class FakeReader() extends MicroBatchReader with ContinuousReader + with SupportsDeprecatedScanRow { def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = {} def getStartOffset: Offset = RateStreamOffset(Map()) def getEndOffset: Offset = RateStreamOffset(Map()) @@ -44,7 +45,7 @@ case class FakeReader() extends MicroBatchReader with ContinuousReader { def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) def setStartOffset(start: Optional[Offset]): Unit = {} - def planInputPartitions(): java.util.ArrayList[InputPartition[Row]] = { + def planRowInputPartitions(): java.util.ArrayList[InputPartition[Row]] = { throw new IllegalStateException("fake source - cannot actually read") } } From d4a277f0ce2d6e1832d87cae8faec38c5bc730f4 Mon Sep 17 00:00:00 2001 From: s71955 Date: Tue, 24 Jul 2018 11:31:27 -0700 Subject: [PATCH 1202/2461] [SPARK-24812][SQL] Last Access Time in the table description is not valid ## What changes were proposed in this pull request? Last Access Time will always displayed wrong date Thu Jan 01 05:30:00 IST 1970 when user run DESC FORMATTED table command In hive its displayed as "UNKNOWN" which makes more sense than displaying wrong date. seems to be a limitation as of now even from hive, better we can follow the hive behavior unless the limitation has been resolved from hive. spark client output ![spark_desc table](https://user-images.githubusercontent.com/12999161/42753448-ddeea66a-88a5-11e8-94aa-ef8d017f94c5.png) Hive client output ![hive_behaviour](https://user-images.githubusercontent.com/12999161/42753489-f4fd366e-88a5-11e8-83b0-0f3a53ce83dd.png) ## How was this patch tested? UT has been added which makes sure that the wrong date "Thu Jan 01 05:30:00 IST 1970 " shall not be added as value for the Last Access property Author: s71955 Closes #21775 from sujith71955/master_hive. --- docs/sql-programming-guide.md | 1 + .../spark/sql/catalyst/catalog/interface.scala | 5 ++++- .../spark/sql/hive/execution/HiveDDLSuite.scala | 13 +++++++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 4bab58aff0067..e815e5bd516e2 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1850,6 +1850,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.3 to 2.4 + - Since Spark 2.4, Spark will display table description column Last Access value as UNKNOWN when the value was Jan 01 1970. - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unable to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index c6105c5526049..a4ead538bb51a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -114,7 +114,10 @@ case class CatalogTablePartition( map.put("Partition Parameters", s"{${parameters.map(p => p._1 + "=" + p._2).mkString(", ")}}") } map.put("Created Time", new Date(createTime).toString) - map.put("Last Access", new Date(lastAccessTime).toString) + val lastAccess = { + if (-1 == lastAccessTime) "UNKNOWN" else new Date(lastAccessTime).toString + } + map.put("Last Access", lastAccess) stats.foreach(s => map.put("Partition Statistics", s.simpleString)) map } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 31fd4c5a1f996..0b3de3d4cd599 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.net.URI +import java.util.Date import scala.language.existentials @@ -2250,6 +2251,18 @@ class HiveDDLSuite } } + test("SPARK-24812: desc formatted table for last access verification") { + withTable("t1") { + sql( + "CREATE TABLE IF NOT EXISTS t1 (c1_int INT, c2_string STRING, c3_float FLOAT)") + val desc = sql("DESC FORMATTED t1").filter($"col_name".startsWith("Last Access")) + .select("data_type") + // check if the last access time doesnt have the default date of year + // 1970 as its a wrong access time + assert(!(desc.first.toString.contains("1970"))) + } + } + test("SPARK-24681 checks if nested column names do not include ',', ':', and ';'") { val expectedMsg = "Cannot create a table having a nested column whose name contains invalid " + "characters (',', ':', ';') in Hive metastore." From fc21f192a302e48e5c321852e2a25639c5a182b5 Mon Sep 17 00:00:00 2001 From: Eric Chang Date: Tue, 24 Jul 2018 15:53:50 -0700 Subject: [PATCH 1203/2461] [SPARK-24895] Remove spotbugs plugin ## What changes were proposed in this pull request? Spotbugs maven plugin was a recently added plugin before 2.4.0 snapshot artifacts were broken. To ensure it does not affect the maven deploy plugin, this change removes it. ## How was this patch tested? Local build was ran, but this patch will be actually tested by monitoring the apache repo artifacts and making sure metadata is correctly uploaded after this job is ran: https://amplab.cs.berkeley.edu/jenkins/view/Spark%20Packaging/job/spark-master-maven-snapshots/ Author: Eric Chang Closes #21865 from ericfchang/SPARK-24895. --- pom.xml | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/pom.xml b/pom.xml index 81a53eee14f29..d75db0f080e64 100644 --- a/pom.xml +++ b/pom.xml @@ -2610,28 +2610,6 @@ - - com.github.spotbugs - spotbugs-maven-plugin - 3.1.3 - - ${basedir}/target/scala-${scala.binary.version}/classes - ${basedir}/target/scala-${scala.binary.version}/test-classes - Max - Low - true - FindPuzzlers - true - - - - - check - - compile - - -
      From 3efdf35327be38115b04b08e9c8d0aa282a904ab Mon Sep 17 00:00:00 2001 From: shane knapp Date: Tue, 24 Jul 2018 16:13:57 -0700 Subject: [PATCH 1204/2461] [SPARK-24908][R][STYLE] removing spaces to make lintr happy ## What changes were proposed in this pull request? during my travails in porting spark builds to run on our centos worker, i managed to recreate (as best i could) the centos environment on our new ubuntu-testing machine. while running my initial builds, lintr was crashing on some extraneous spaces in test_basic.R (see: https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.6-ubuntu-test/862/console) after removing those spaces, the ubuntu build happily passed the lintr tests. ## How was this patch tested? i then tested this against a modified spark-master-test-sbt-hadoop-2.6 build (see https://amplab.cs.berkeley.edu/jenkins/view/RISELab%20Infra/job/testing-spark-master-test-with-updated-R-crap/4/), which scp'ed a copy of test_basic.R in to the repo after the git clone. everything seems to be working happily. Author: shane knapp Closes #21864 from shaneknapp/fixing-R-lint-spacing. --- R/pkg/inst/tests/testthat/test_basic.R | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_basic.R b/R/pkg/inst/tests/testthat/test_basic.R index 243f5f0298284..80df3d8ce6e59 100644 --- a/R/pkg/inst/tests/testthat/test_basic.R +++ b/R/pkg/inst/tests/testthat/test_basic.R @@ -18,9 +18,9 @@ context("basic tests for CRAN") test_that("create DataFrame from list or data.frame", { - tryCatch( checkJavaVersion(), + tryCatch(checkJavaVersion(), error = function(e) { skip("error on Java check") }, - warning = function(e) { skip("warning on Java check") } ) + warning = function(e) { skip("warning on Java check") }) sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, sparkConfig = sparkRTestConfig) @@ -54,9 +54,9 @@ test_that("create DataFrame from list or data.frame", { }) test_that("spark.glm and predict", { - tryCatch( checkJavaVersion(), + tryCatch(checkJavaVersion(), error = function(e) { skip("error on Java check") }, - warning = function(e) { skip("warning on Java check") } ) + warning = function(e) { skip("warning on Java check") }) sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, sparkConfig = sparkRTestConfig) From 15fff79032f6d708d8570b5e83144f1f84519552 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 25 Jul 2018 09:08:42 +0800 Subject: [PATCH 1205/2461] [SPARK-24297][CORE] Fetch-to-disk by default for > 2gb Fetch-to-mem is guaranteed to fail if the message is bigger than 2 GB, so we might as well use fetch-to-disk in that case. The message includes some metadata in addition to the block data itself (in particular UploadBlock has a lot of metadata), so we leave a little room. Author: Imran Rashid Closes #21474 from squito/SPARK-24297. --- .../org/apache/spark/internal/config/package.scala | 6 +++++- docs/configuration.md | 10 ++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index ba892bf7f60d6..8fef2aa6863c5 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -432,7 +432,11 @@ package object config { "external shuffle service, this feature can only be worked when external shuffle" + "service is newer than Spark 2.2.") .bytesConf(ByteUnit.BYTE) - .createWithDefault(Long.MaxValue) + // fetch-to-mem is guaranteed to fail if the message is bigger than 2 GB, so we might + // as well use fetch-to-disk in that case. The message includes some metadata in addition + // to the block data itself (in particular UploadBlock has a lot of metadata), so we leave + // extra room. + .createWithDefault(Int.MaxValue - 512) private[spark] val TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES = ConfigBuilder("spark.taskMetrics.trackUpdatedBlockStatuses") diff --git a/docs/configuration.md b/docs/configuration.md index 0c7c4472be643..60c0358c0e938 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -580,13 +580,15 @@ Apart from these, the following properties are also available, and may be useful spark.maxRemoteBlockSizeFetchToMem - Long.MaxValue + Int.MaxValue - 512 The remote block will be fetched to disk when size of the block is above this threshold in bytes. - This is to avoid a giant request takes too much memory. We can enable this config by setting - a specific value(e.g. 200m). Note this configuration will affect both shuffle fetch + This is to avoid a giant request that takes too much memory. By default, this is only enabled + for blocks > 2GB, as those cannot be fetched directly into memory, no matter what resources are + available. But it can be turned down to a much lower value (eg. 200m) to avoid using too much + memory on smaller blocks as well. Note this configuration will affect both shuffle fetch and block manager remote block fetch. For users who enabled external shuffle service, - this feature can only be worked when external shuffle service is newer than Spark 2.2. + this feature can only be used when external shuffle service is newer than Spark 2.2. From c26b0921693814f0726507f16b836d82e2e8cfe0 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Tue, 24 Jul 2018 19:35:34 -0700 Subject: [PATCH 1206/2461] [SPARK-24891][SQL] Fix HandleNullInputsForUDF rule ## What changes were proposed in this pull request? The HandleNullInputsForUDF would always add a new `If` node every time it is applied. That would cause a difference between the same plan being analyzed once and being analyzed twice (or more), thus raising issues like plan not matched in the cache manager. The solution is to mark the arguments as null-checked, which is to add a "KnownNotNull" node above those arguments, when adding the UDF under an `If` node, because clearly the UDF will not be called when any of those arguments is null. ## How was this patch tested? Add new tests under sql/UDFSuite and AnalysisSuite. Author: maryannxue Closes #21851 from maryannxue/spark-24891. --- .../sql/catalyst/analysis/Analyzer.scala | 22 ++++++++---- .../expressions/constraintExpressions.scala | 35 +++++++++++++++++++ .../sql/catalyst/analysis/AnalysisSuite.scala | 16 +++++++-- .../scala/org/apache/spark/sql/UDFSuite.scala | 31 +++++++++++++++- 4 files changed, 94 insertions(+), 10 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 866396c42f9d8..4f474f4987dcf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.objects.{LambdaVariable, MapObjects, NewInstance, UnresolvedMapObjects} +import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -2145,14 +2145,24 @@ class Analyzer( val parameterTypes = ScalaReflection.getParameterTypes(func) assert(parameterTypes.length == inputs.length) + // TODO: skip null handling for not-nullable primitive inputs after we can completely + // trust the `nullable` information. + // (cls, expr) => cls.isPrimitive && expr.nullable + val needsNullCheck = (cls: Class[_], expr: Expression) => + cls.isPrimitive && !expr.isInstanceOf[KnowNotNull] val inputsNullCheck = parameterTypes.zip(inputs) - // TODO: skip null handling for not-nullable primitive inputs after we can completely - // trust the `nullable` information. - // .filter { case (cls, expr) => cls.isPrimitive && expr.nullable } - .filter { case (cls, _) => cls.isPrimitive } + .filter { case (cls, expr) => needsNullCheck(cls, expr) } .map { case (_, expr) => IsNull(expr) } .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2)) - inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf) + // Once we add an `If` check above the udf, it is safe to mark those checked inputs + // as not nullable (i.e., wrap them with `KnownNotNull`), because the null-returning + // branch of `If` will be called if any of these checked inputs is null. Thus we can + // prevent this rule from being applied repeatedly. + val newInputs = parameterTypes.zip(inputs).map{ case (cls, expr) => + if (needsNullCheck(cls, expr)) KnowNotNull(expr) else expr } + inputsNullCheck + .map(If(_, Literal.create(null, udf.dataType), udf.copy(children = newInputs))) + .getOrElse(udf) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala new file mode 100644 index 0000000000000..53936aa914c8f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral} +import org.apache.spark.sql.types.DataType + +case class KnowNotNull(child: Expression) extends UnaryExpression { + override def nullable: Boolean = false + override def dataType: DataType = child.dataType + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + child.genCode(ctx).copy(isNull = FalseLiteral) + } + + override def eval(input: InternalRow): Any = { + child.eval(input) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 9e0db8dbf8f3a..31f703d018aed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -316,7 +316,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { // only primitive parameter needs special null handling val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil) - val expected2 = If(IsNull(double), nullResult, udf2) + val expected2 = + If(IsNull(double), nullResult, udf2.copy(children = string :: KnowNotNull(double) :: Nil)) checkUDF(udf2, expected2) // special null handling should apply to all primitive parameters @@ -324,7 +325,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { val expected3 = If( IsNull(short) || IsNull(double), nullResult, - udf3) + udf3.copy(children = KnowNotNull(short) :: KnowNotNull(double) :: Nil)) checkUDF(udf3, expected3) // we can skip special null handling for primitive parameters that are not nullable @@ -336,10 +337,19 @@ class AnalysisSuite extends AnalysisTest with Matchers { val expected4 = If( IsNull(short), nullResult, - udf4) + udf4.copy(children = KnowNotNull(short) :: double.withNullability(false) :: Nil)) // checkUDF(udf4, expected4) } + test("SPARK-24891 Fix HandleNullInputsForUDF rule") { + val a = testRelation.output(0) + val func = (x: Int, y: Int) => x + y + val udf1 = ScalaUDF(func, IntegerType, a :: a :: Nil) + val udf2 = ScalaUDF(func, IntegerType, a :: udf1 :: Nil) + val plan = Project(Alias(udf2, "")() :: Nil, testRelation) + comparePlans(plan.analyze, plan.analyze.analyze) + } + test("SPARK-11863 mixture of aliases and real columns in order by clause - tpcds 19,55,71") { val a = testRelation2.output(0) val c = testRelation2.output(2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 21afdc7e2a33f..d8074571ffc65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.command.ExplainCommand -import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.functions.{lit, udf} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types.{DataTypes, DoubleType} @@ -324,4 +324,33 @@ class UDFSuite extends QueryTest with SharedSQLContext { assert(outputStream.toString.contains("UDF:f(a._1 AS `_1`)")) } } + + test("SPARK-24891 Fix HandleNullInputsForUDF rule") { + val udf1 = udf({(x: Int, y: Int) => x + y}) + val df = spark.range(0, 3).toDF("a") + .withColumn("b", udf1($"a", udf1($"a", lit(10)))) + .withColumn("c", udf1($"a", lit(null))) + val plan = spark.sessionState.executePlan(df.logicalPlan).analyzed + + comparePlans(df.logicalPlan, plan) + checkAnswer( + df, + Seq( + Row(0, 10, null), + Row(1, 12, null), + Row(2, 14, null))) + } + + test("SPARK-24891 Fix HandleNullInputsForUDF rule - with table") { + withTable("x") { + Seq((1, "2"), (2, "4")).toDF("a", "b").write.format("json").saveAsTable("x") + sql("insert into table x values(3, null)") + sql("insert into table x values(null, '4')") + spark.udf.register("f", (a: Int, b: String) => a + b) + val df = spark.sql("SELECT f(a, b) FROM x") + val plan = spark.sessionState.executePlan(df.logicalPlan).analyzed + comparePlans(df.logicalPlan, plan) + checkAnswer(df, Seq(Row("12"), Row("24"), Row("3null"), Row(null))) + } + } } From d4c341589499099654ed4febf235f19897a21601 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 24 Jul 2018 20:21:11 -0700 Subject: [PATCH 1207/2461] [SPARK-24890][SQL] Short circuiting the `if` condition when `trueValue` and `falseValue` are the same ## What changes were proposed in this pull request? When `trueValue` and `falseValue` are semantic equivalence, the condition expression in `if` can be removed to avoid extra computation in runtime. ## How was this patch tested? Test added. Author: DB Tsai Closes #21848 from dbtsai/short-circuit-if. --- .../sql/catalyst/optimizer/expressions.scala | 7 ++++-- .../optimizer/SimplifyConditionalSuite.scala | 24 ++++++++++++++++++- .../apache/spark/sql/test/SQLTestUtils.scala | 2 +- 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index cf17f59599968..4696699337c9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -390,6 +390,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case If(TrueLiteral, trueValue, _) => trueValue case If(FalseLiteral, _, falseValue) => falseValue case If(Literal(null, _), _, falseValue) => falseValue + case If(cond, trueValue, falseValue) + if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) => // If there are branches that are always false, remove them. @@ -403,14 +405,14 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { e.copy(branches = newBranches) } - case e @ CaseWhen(branches, _) if branches.headOption.map(_._1) == Some(TrueLiteral) => + case CaseWhen(branches, _) if branches.headOption.map(_._1).contains(TrueLiteral) => // If the first branch is a true literal, remove the entire CaseWhen and use the value // from that. Note that CaseWhen.branches should never be empty, and as a result the // headOption (rather than head) added above is just an extra (and unnecessary) safeguard. branches.head._2 case CaseWhen(branches, _) if branches.exists(_._1 == TrueLiteral) => - // a branc with a TRue condition eliminates all following branches, + // a branch with a true condition eliminates all following branches, // these branches can be pruned away val (h, t) = branches.span(_._1 != TrueLiteral) CaseWhen( h :+ t.head, None) @@ -651,6 +653,7 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { } } + /** * Combine nested [[Concat]] expressions. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index b597c8e162c83..e210874a55d87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} @@ -29,7 +31,8 @@ import org.apache.spark.sql.types.{IntegerType, NullType} class SimplifyConditionalSuite extends PlanTest with PredicateHelper { object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil + val batches = Batch("SimplifyConditionals", FixedPoint(50), + BooleanSimplification, ConstantFolding, SimplifyConditionals) :: Nil } protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { @@ -43,6 +46,8 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { private val unreachableBranch = (FalseLiteral, Literal(20)) private val nullBranch = (Literal.create(null, NullType), Literal(30)) + private val testRelation = LocalRelation('a.int) + test("simplify if") { assertEquivalent( If(TrueLiteral, Literal(10), Literal(20)), @@ -57,6 +62,23 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { Literal(20)) } + test("remove unnecessary if when the outputs are semantic equivalence") { + assertEquivalent( + If(IsNotNull(UnresolvedAttribute("a")), + Subtract(Literal(10), Literal(1)), + Add(Literal(6), Literal(3))), + Literal(9)) + + // For non-deterministic condition, we don't remove the `If` statement. + assertEquivalent( + If(GreaterThan(Rand(0), Literal(0.5)), + Subtract(Literal(10), Literal(1)), + Add(Literal(6), Literal(3))), + If(GreaterThan(Rand(0), Literal(0.5)), + Literal(9), + Literal(9))) + } + test("remove unreachable branches") { // i.e. removing branches whose conditions are always false assertEquivalent( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index e562be83822e9..ac70488febc5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -393,7 +393,7 @@ private[sql] trait SQLTestUtilsBase } /** - * Returns full path to the given file in the resouce folder + * Returns full path to the given file in the resource folder */ protected def testFile(fileName: String): String = { Thread.currentThread().getContextClassLoader.getResource(fileName).toString From afb0627536494c654ce5dd72db648f1ee7da641c Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 24 Jul 2018 20:46:27 -0700 Subject: [PATCH 1208/2461] [SPARK-23957][SQL] Sorts in subqueries are redundant and can be removed ## What changes were proposed in this pull request? Thanks to henryr for the original idea at https://github.com/apache/spark/pull/21049 Description from the original PR : Subqueries (at least in SQL) have 'bag of tuples' semantics. Ordering them is therefore redundant (unless combined with a limit). This patch removes the top sort operators from the subquery plans. This closes https://github.com/apache/spark/pull/21049. ## How was this patch tested? Added test cases in SubquerySuite to cover in, exists and scalar subqueries. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Dilip Biswal Closes #21853 from dilipbiswal/SPARK-23957. --- .../sql/catalyst/optimizer/Optimizer.scala | 12 +- .../org/apache/spark/sql/SubquerySuite.scala | 300 +++++++++++++++++- 2 files changed, 310 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5ed7412c106fd..adb1350adc261 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -180,10 +180,20 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) * Optimize all the subqueries inside expression. */ object OptimizeSubqueries extends Rule[LogicalPlan] { + private def removeTopLevelSort(plan: LogicalPlan): LogicalPlan = { + plan match { + case Sort(_, _, child) => child + case Project(fields, child) => Project(fields, removeTopLevelSort(child)) + case other => other + } + } def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case s: SubqueryExpression => val Subquery(newPlan) = Optimizer.this.execute(Subquery(s.plan)) - s.withNewPlan(newPlan) + // At this point we have an optimized subquery plan that we are going to attach + // to this subquery expression. Here we can safely remove any top level sort + // in the plan as tuples produced by a subquery are un-ordered. + s.withNewPlan(removeTopLevelSort(newPlan)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index acef62d81ee12..cbffed994bb4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.plans.logical.Join +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort} import org.apache.spark.sql.test.SharedSQLContext class SubquerySuite extends QueryTest with SharedSQLContext { @@ -970,4 +973,299 @@ class SubquerySuite extends QueryTest with SharedSQLContext { Row("3", "b") :: Row("4", "b") :: Nil) } } + + private def getNumSortsInQuery(query: String): Int = { + val plan = sql(query).queryExecution.optimizedPlan + getNumSorts(plan) + getSubqueryExpressions(plan).map{s => getNumSorts(s.plan)}.sum + } + + private def getSubqueryExpressions(plan: LogicalPlan): Seq[SubqueryExpression] = { + val subqueryExpressions = ArrayBuffer.empty[SubqueryExpression] + plan transformAllExpressions { + case s: SubqueryExpression => + subqueryExpressions ++= (getSubqueryExpressions(s.plan) :+ s) + s + } + subqueryExpressions + } + + private def getNumSorts(plan: LogicalPlan): Int = { + plan.collect { case s: Sort => s }.size + } + + test("SPARK-23957 Remove redundant sort from subquery plan(in subquery)") { + withTempView("t1", "t2", "t3") { + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") + Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3") + + // Simple order by + val query1 = + """ + |SELECT c1 FROM t1 + |WHERE + |c1 IN (SELECT c1 FROM t2 ORDER BY c1) + """.stripMargin + assert(getNumSortsInQuery(query1) == 0) + + // Nested order bys + val query2 = + """ + |SELECT c1 + |FROM t1 + |WHERE c1 IN (SELECT c1 + | FROM (SELECT * + | FROM t2 + | ORDER BY c2) + | ORDER BY c1) + """.stripMargin + assert(getNumSortsInQuery(query2) == 0) + + + // nested IN + val query3 = + """ + |SELECT c1 + |FROM t1 + |WHERE c1 IN (SELECT c1 + | FROM t2 + | WHERE c1 IN (SELECT c1 + | FROM t3 + | WHERE c1 = 1 + | ORDER BY c3) + | ORDER BY c2) + """.stripMargin + assert(getNumSortsInQuery(query3) == 0) + + // Complex subplan and multiple sorts + val query4 = + """ + |SELECT c1 + |FROM t1 + |WHERE c1 IN (SELECT c1 + | FROM (SELECT c1, c2, count(*) + | FROM t2 + | GROUP BY c1, c2 + | HAVING count(*) > 0 + | ORDER BY c2) + | ORDER BY c1) + """.stripMargin + assert(getNumSortsInQuery(query4) == 0) + + // Join in subplan + val query5 = + """ + |SELECT c1 FROM t1 + |WHERE + |c1 IN (SELECT t2.c1 FROM t2, t3 + | WHERE t2.c1 = t3.c1 + | ORDER BY t2.c1) + """.stripMargin + assert(getNumSortsInQuery(query5) == 0) + + val query6 = + """ + |SELECT c1 + |FROM t1 + |WHERE (c1, c2) IN (SELECT c1, max(c2) + | FROM (SELECT c1, c2, count(*) + | FROM t2 + | GROUP BY c1, c2 + | HAVING count(*) > 0 + | ORDER BY c2) + | GROUP BY c1 + | HAVING max(c2) > 0 + | ORDER BY c1) + """.stripMargin + // The rule to remove redundant sorts is not able to remove the inner sort under + // an Aggregate operator. We only remove the top level sort. + assert(getNumSortsInQuery(query6) == 1) + + // Cases when sort is not removed from the plan + // Limit on top of sort + val query7 = + """ + |SELECT c1 FROM t1 + |WHERE + |c1 IN (SELECT c1 FROM t2 ORDER BY c1 limit 1) + """.stripMargin + assert(getNumSortsInQuery(query7) == 1) + + // Sort below a set operations (intersect, union) + val query8 = + """ + |SELECT c1 FROM t1 + |WHERE + |c1 IN (( + | SELECT c1 FROM t2 + | ORDER BY c1 + | ) + | UNION + | ( + | SELECT c1 FROM t2 + | ORDER BY c1 + | )) + """.stripMargin + assert(getNumSortsInQuery(query8) == 2) + } + } + + test("SPARK-23957 Remove redundant sort from subquery plan(exists subquery)") { + withTempView("t1", "t2", "t3") { + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") + Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3") + + // Simple order by exists correlated + val query1 = + """ + |SELECT c1 FROM t1 + |WHERE + |EXISTS (SELECT t2.c1 FROM t2 WHERE t1.c1 = t2.c1 ORDER BY t2.c1) + """.stripMargin + assert(getNumSortsInQuery(query1) == 0) + + // Nested order by and correlated. + val query2 = + """ + |SELECT c1 + |FROM t1 + |WHERE EXISTS (SELECT c1 + | FROM (SELECT * + | FROM t2 + | WHERE t2.c1 = t1.c1 + | ORDER BY t2.c2) t2 + | ORDER BY t2.c1) + """.stripMargin + assert(getNumSortsInQuery(query2) == 0) + + // nested EXISTS + val query3 = + """ + |SELECT c1 + |FROM t1 + |WHERE EXISTS (SELECT c1 + | FROM t2 + | WHERE EXISTS (SELECT c1 + | FROM t3 + | WHERE t3.c1 = t2.c1 + | ORDER BY c3) + | AND t2.c1 = t1.c1 + | ORDER BY c2) + """.stripMargin + assert(getNumSortsInQuery(query3) == 0) + + // Cases when sort is not removed from the plan + // Limit on top of sort + val query4 = + """ + |SELECT c1 FROM t1 + |WHERE + |EXISTS (SELECT t2.c1 FROM t2 WHERE t2.c1 = 1 ORDER BY t2.c1 limit 1) + """.stripMargin + assert(getNumSortsInQuery(query4) == 1) + + // Sort below a set operations (intersect, union) + val query5 = + """ + |SELECT c1 FROM t1 + |WHERE + |EXISTS (( + | SELECT c1 FROM t2 + | WHERE t2.c1 = 1 + | ORDER BY t2.c1 + | ) + | UNION + | ( + | SELECT c1 FROM t2 + | WHERE t2.c1 = 2 + | ORDER BY t2.c1 + | )) + """.stripMargin + assert(getNumSortsInQuery(query5) == 2) + } + } + + test("SPARK-23957 Remove redundant sort from subquery plan(scalar subquery)") { + withTempView("t1", "t2", "t3") { + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") + Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3") + + // Two scalar subqueries in OR + val query1 = + """ + |SELECT * FROM t1 + |WHERE c1 = (SELECT max(t2.c1) + | FROM t2 + | ORDER BY max(t2.c1)) + |OR c2 = (SELECT min(t3.c2) + | FROM t3 + | WHERE t3.c1 = 1 + | ORDER BY min(t3.c2)) + """.stripMargin + assert(getNumSortsInQuery(query1) == 0) + + // scalar subquery - groupby and having + val query2 = + """ + |SELECT * + |FROM t1 + |WHERE c1 = (SELECT max(t2.c1) + | FROM t2 + | GROUP BY t2.c1 + | HAVING count(*) >= 1 + | ORDER BY max(t2.c1)) + """.stripMargin + assert(getNumSortsInQuery(query2) == 0) + + // nested scalar subquery + val query3 = + """ + |SELECT * + |FROM t1 + |WHERE c1 = (SELECT max(t2.c1) + | FROM t2 + | WHERE c1 = (SELECT max(t3.c1) + | FROM t3 + | WHERE t3.c1 = 1 + | GROUP BY t3.c1 + | ORDER BY max(t3.c1) + | ) + | GROUP BY t2.c1 + | HAVING count(*) >= 1 + | ORDER BY max(t2.c1)) + """.stripMargin + assert(getNumSortsInQuery(query3) == 0) + + // Scalar subquery in projection + val query4 = + """ + |SELECT (SELECT min(c1) from t1 group by c1 order by c1) + |FROM t1 + |WHERE t1.c1 = 1 + """.stripMargin + assert(getNumSortsInQuery(query4) == 0) + + // Limit on top of sort prevents it from being pruned. + val query5 = + """ + |SELECT * + |FROM t1 + |WHERE c1 = (SELECT max(t2.c1) + | FROM t2 + | WHERE c1 = (SELECT max(t3.c1) + | FROM t3 + | WHERE t3.c1 = 1 + | GROUP BY t3.c1 + | ORDER BY max(t3.c1) + | ) + | GROUP BY t2.c1 + | HAVING count(*) >= 1 + | ORDER BY max(t2.c1) + | LIMIT 1) + """.stripMargin + assert(getNumSortsInQuery(query5) == 1) + } + } } From 78e0a725e06665cf92d4b8f987ee01947a1d620c Mon Sep 17 00:00:00 2001 From: crafty-coder Date: Wed, 25 Jul 2018 14:17:20 +0800 Subject: [PATCH 1209/2461] [SPARK-19018][SQL] Add support for custom encoding on csv writer ## What changes were proposed in this pull request? Add support for custom encoding on csv writer, see https://issues.apache.org/jira/browse/SPARK-19018 ## How was this patch tested? Added two unit tests in CSVSuite Author: crafty-coder Author: Carlos Closes #20949 from crafty-coder/master. --- python/pyspark/sql/readwriter.py | 7 +++- .../apache/spark/sql/DataFrameWriter.scala | 2 + .../datasources/csv/CSVFileFormat.scala | 6 ++- .../execution/datasources/csv/CSVSuite.scala | 39 ++++++++++++++++++- 4 files changed, 50 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 3efe2adb6e2a4..98b2cd9968407 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -859,7 +859,7 @@ def text(self, path, compression=None, lineSep=None): def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None, header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None, timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, - charToEscapeQuoteEscaping=None): + charToEscapeQuoteEscaping=None, encoding=None): """Saves the content of the :class:`DataFrame` in CSV format at the specified path. :param path: the path in any Hadoop supported file system @@ -909,6 +909,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No the quote character. If None is set, the default value is escape character when escape and quote characters are different, ``\0`` otherwise.. + :param encoding: sets the encoding (charset) of saved csv files. If None is set, + the default UTF-8 charset will be used. >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) """ @@ -918,7 +920,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No dateFormat=dateFormat, timestampFormat=timestampFormat, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace, ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, - charToEscapeQuoteEscaping=charToEscapeQuoteEscaping) + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, + encoding=encoding) self._jwrite.csv(path) @since(1.5) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 90bea2d676e22..b9fa43f1f9fbd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -629,6 +629,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * enclosed in quotes. Default is to only escape values containing a quote character. *
    • `header` (default `false`): writes the names of columns as the first line.
    • *
    • `nullValue` (default empty string): sets the string representation of a null value.
    • + *
    • `encoding` (by default it is not set): specifies encoding (charset) of saved csv + * files. If it is not set, the UTF-8 charset will be used.
    • *
    • `compression` (default `null`): compression codec to use when saving to file. This can be * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`).
    • diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index aeb40e5a4131d..d59b9820bdeef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.csv +import java.nio.charset.Charset + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce._ @@ -168,7 +170,9 @@ private[csv] class CsvOutputWriter( context: TaskAttemptContext, params: CSVOptions) extends OutputWriter with Logging { - private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path)) + private val charset = Charset.forName(params.charset) + + private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path), charset) private val gen = new UnivocityGenerator(dataSchema, writer, params) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 63cc5985040c9..456b4535a0dcc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -18,12 +18,14 @@ package org.apache.spark.sql.execution.datasources.csv import java.io.File -import java.nio.charset.UnsupportedCharsetException +import java.nio.charset.{Charset, UnsupportedCharsetException} +import java.nio.file.Files import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat import java.util.Locale import scala.collection.JavaConverters._ +import scala.util.Properties import org.apache.commons.lang3.time.FastDateFormat import org.apache.hadoop.io.SequenceFile.CompressionType @@ -514,6 +516,41 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } } + test("SPARK-19018: Save csv with custom charset") { + + // scalastyle:off nonascii + val content = "µß áâä ÁÂÄ" + // scalastyle:on nonascii + + Seq("iso-8859-1", "utf-8", "utf-16", "utf-32", "windows-1250").foreach { encoding => + withTempPath { path => + val csvDir = new File(path, "csv") + Seq(content).toDF().write + .option("encoding", encoding) + .csv(csvDir.getCanonicalPath) + + csvDir.listFiles().filter(_.getName.endsWith("csv")).foreach({ csvFile => + val readback = Files.readAllBytes(csvFile.toPath) + val expected = (content + Properties.lineSeparator).getBytes(Charset.forName(encoding)) + assert(readback === expected) + }) + } + } + } + + test("SPARK-19018: error handling for unsupported charsets") { + val exception = intercept[SparkException] { + withTempPath { path => + val csvDir = new File(path, "csv").getCanonicalPath + Seq("a,A,c,A,b,B").toDF().write + .option("encoding", "1-9588-osi") + .csv(csvDir) + } + } + + assert(exception.getCause.getMessage.contains("1-9588-osi")) + } + test("commented lines in CSV data") { Seq("false", "true").foreach { multiLine => From 7a5fd4a91e19ee32b365eaf5678c627ad6c6d4c2 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 24 Jul 2018 23:59:13 -0700 Subject: [PATCH 1210/2461] [SPARK-18874][SQL][FOLLOW-UP] Improvement type mismatched message ## What changes were proposed in this pull request? Improvement `IN` predicate type mismatched message: ```sql Mismatched columns: [(, t, 4, ., `, t, 4, a, `, :, d, o, u, b, l, e, ,, , t, 5, ., `, t, 5, a, `, :, d, e, c, i, m, a, l, (, 1, 8, ,, 0, ), ), (, t, 4, ., `, t, 4, c, `, :, s, t, r, i, n, g, ,, , t, 5, ., `, t, 5, c, `, :, b, i, g, i, n, t, )] ``` After this patch: ```sql Mismatched columns: [(t4.`t4a`:double, t5.`t5a`:decimal(18,0)), (t4.`t4c`:string, t5.`t5c`:bigint)] ``` ## How was this patch tested? unit tests Author: Yuming Wang Closes #21863 from wangyum/SPARK-18874. --- .../sql/catalyst/expressions/predicates.scala | 2 +- .../negative-cases/subq-input-typecheck.sql | 16 ++++- .../subq-input-typecheck.sql.out | 66 +++++++++++++++---- 3 files changed, 70 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 699601e64dd14..f4077f78006b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -189,7 +189,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } else { val mismatchedColumns = valExprs.zip(childOutputs).flatMap { case (l, r) if l.dataType != r.dataType => - s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})" + Seq(s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})") case _ => None } TypeCheckResult.TypeCheckFailure( diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql index b15f4da81dd93..95b115a8dd094 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql @@ -13,6 +13,14 @@ CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES (3, 1, 2) AS t3(t3a, t3b, t3c); +CREATE TEMPORARY VIEW t4 AS SELECT * FROM VALUES + (CAST(1 AS DOUBLE), CAST(2 AS STRING), CAST(3 AS STRING)) +AS t1(t4a, t4b, t4c); + +CREATE TEMPORARY VIEW t5 AS SELECT * FROM VALUES + (CAST(1 AS DECIMAL(18, 0)), CAST(2 AS STRING), CAST(3 AS BIGINT)) +AS t1(t5a, t5b, t5c); + -- TC 01.01 SELECT ( SELECT max(t2b), min(t2b) @@ -44,4 +52,10 @@ WHERE (t1a, t1b) IN (SELECT t2a FROM t2 WHERE t1a = t2a); - +-- TC 01.05 +SELECT * FROM t4 +WHERE +(t4a, t4b, t4c) IN (SELECT t5a, + t5b, + t5c + FROM t5); diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out index 70aeb9373f3c7..dcd30055bca19 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 7 +-- Number of queries: 10 -- !query 0 @@ -33,6 +33,26 @@ struct<> -- !query 3 +CREATE TEMPORARY VIEW t4 AS SELECT * FROM VALUES + (CAST(1 AS DOUBLE), CAST(2 AS STRING), CAST(3 AS STRING)) +AS t1(t4a, t4b, t4c) +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +CREATE TEMPORARY VIEW t5 AS SELECT * FROM VALUES + (CAST(1 AS DECIMAL(18, 0)), CAST(2 AS STRING), CAST(3 AS BIGINT)) +AS t1(t5a, t5b, t5c) +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 SELECT ( SELECT max(t2b), min(t2b) FROM t2 @@ -40,14 +60,14 @@ SELECT GROUP BY t2.t2b ) FROM t1 --- !query 3 schema +-- !query 5 schema struct<> --- !query 3 output +-- !query 5 output org.apache.spark.sql.AnalysisException Scalar subquery must return only one column, but got 2; --- !query 4 +-- !query 6 SELECT ( SELECT max(t2b), min(t2b) FROM t2 @@ -55,22 +75,22 @@ SELECT GROUP BY t2.t2b ) FROM t1 --- !query 4 schema +-- !query 6 schema struct<> --- !query 4 output +-- !query 6 output org.apache.spark.sql.AnalysisException Scalar subquery must return only one column, but got 2; --- !query 5 +-- !query 7 SELECT * FROM t1 WHERE t1a IN (SELECT t2a, t2b FROM t2 WHERE t1a = t2a) --- !query 5 schema +-- !query 7 schema struct<> --- !query 5 output +-- !query 7 output org.apache.spark.sql.AnalysisException cannot resolve '(t1.`t1a` IN (listquery(t1.`t1a`)))' due to data type mismatch: The number of columns in the left hand side of an IN subquery does not match the @@ -83,15 +103,15 @@ Right side columns: [t2.`t2a`, t2.`t2b`].; --- !query 6 +-- !query 8 SELECT * FROM T1 WHERE (t1a, t1b) IN (SELECT t2a FROM t2 WHERE t1a = t2a) --- !query 6 schema +-- !query 8 schema struct<> --- !query 6 output +-- !query 8 output org.apache.spark.sql.AnalysisException cannot resolve '(named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`)))' due to data type mismatch: The number of columns in the left hand side of an IN subquery does not match the @@ -102,3 +122,25 @@ Left side columns: [t1.`t1a`, t1.`t1b`]. Right side columns: [t2.`t2a`].; + + +-- !query 9 +SELECT * FROM t4 +WHERE +(t4a, t4b, t4c) IN (SELECT t5a, + t5b, + t5c + FROM t5) +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +cannot resolve '(named_struct('t4a', t4.`t4a`, 't4b', t4.`t4b`, 't4c', t4.`t4c`) IN (listquery()))' due to data type mismatch: +The data type of one or more elements in the left hand side of an IN subquery +is not compatible with the data type of the output of the subquery +Mismatched columns: +[(t4.`t4a`:double, t5.`t5a`:decimal(18,0)), (t4.`t4c`:string, t5.`t5c`:bigint)] +Left side: +[double, string, string]. +Right side: +[decimal(18,0), string, bigint].; From c44eb561ec371af0405710d2e9358f9797655145 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 25 Jul 2018 08:42:45 -0700 Subject: [PATCH 1211/2461] [SPARK-24768][FOLLOWUP][SQL] Avro migration followup: change artifactId to spark-avro ## What changes were proposed in this pull request? After rethinking on the artifactId, I think it should be `spark-avro` instead of `spark-sql-avro`, which is simpler, and consistent with the previous artifactId. I think we need to change it before Spark 2.4 release. Also a tiny change: use `spark.sessionState.newHadoopConf()` to get the hadoop configuration, thus the related hadoop configurations in SQLConf will come into effect. ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21866 from gengliangwang/avro_followup. --- external/avro/pom.xml | 2 +- .../main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/external/avro/pom.xml b/external/avro/pom.xml index ad7df1f49ac45..8f118ba48201b 100644 --- a/external/avro/pom.xml +++ b/external/avro/pom.xml @@ -25,7 +25,7 @@ ../../pom.xml - spark-sql-avro_2.11 + spark-avro_2.11 avro diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index b043252f49afa..c6b3c13be5140 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -56,7 +56,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { spark: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - val conf = spark.sparkContext.hadoopConfiguration + val conf = spark.sessionState.newHadoopConf() val parsedOptions = new AvroOptions(options, conf) // Schema evolution is not supported yet. Here we only pick a single random sample file to From 571a6f0574e50e53cea403624ec3795cd03aa204 Mon Sep 17 00:00:00 2001 From: mcheah Date: Wed, 25 Jul 2018 11:08:41 -0700 Subject: [PATCH 1212/2461] [SPARK-23146][K8S] Support client mode. ## What changes were proposed in this pull request? Support client mode for the Kubernetes scheduler. Client mode works more or less identically to cluster mode. However, in client mode, the Spark Context needs to be manually bootstrapped with certain properties which would have otherwise been set up by spark-submit in cluster mode. Specifically: - If the user doesn't provide a driver pod name, we don't add an owner reference. This is for usage when the driver is not running in a pod in the cluster. In such a case, the driver can only provide a best effort to clean up the executors when the driver exits, but cleaning up the resources is not guaranteed. The executor JVMs should exit if the driver JVM exits, but the pods will still remain in the cluster in a COMPLETED or FAILED state. - The user must provide a host (spark.driver.host) and port (spark.driver.port) that the executors can connect to. When using spark-submit in cluster mode, spark-submit generates the headless service automatically; in client mode, the user is responsible for setting up their own connectivity. We also change the authentication configuration prefixes for client mode. ## How was this patch tested? Adding an integration test to exercise client mode support. Author: mcheah Closes #21748 from mccheah/k8s-client-mode. --- docs/running-on-kubernetes.md | 138 ++++++++++++++---- .../org/apache/spark/deploy/k8s/Config.scala | 3 +- .../spark/deploy/k8s/KubernetesConf.scala | 4 +- .../spark/deploy/k8s/KubernetesUtils.scala | 2 + .../features/BasicExecutorFeatureStep.scala | 17 ++- .../submit/KubernetesClientApplication.scala | 4 +- .../cluster/k8s/ExecutorPodsAllocator.scala | 13 +- .../k8s/KubernetesClusterManager.scala | 35 +++-- .../deploy/k8s/KubernetesConfSuite.scala | 8 +- .../BasicExecutorFeatureStepSuite.scala | 6 +- .../features/EnvSecretsFeatureStepSuite.scala | 2 +- .../MountSecretsFeatureStepSuite.scala | 2 +- .../k8s/ExecutorPodsAllocatorSuite.scala | 2 +- .../k8s/KubernetesExecutorBuilderSuite.scala | 6 +- .../src/main/dockerfiles/spark/Dockerfile | 2 +- .../ClientModeTestsSuite.scala | 111 ++++++++++++++ .../k8s/integrationtest/KubernetesSuite.scala | 11 +- 17 files changed, 290 insertions(+), 76 deletions(-) create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 7149616e534aa..97c650d0f80aa 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -117,6 +117,45 @@ If the local proxy is running at localhost:8001, `--master k8s://http://127.0.0. spark-submit. Finally, notice that in the above example we specify a jar with a specific URI with a scheme of `local://`. This URI is the location of the example jar that is already in the Docker image. +## Client Mode + +Starting with Spark 2.4.0, it is possible to run Spark applications on Kubernetes in client mode. When your application +runs in client mode, the driver can run inside a pod or on a physical host. When running an application in client mode, +it is recommended to account for the following factors: + +### Client Mode Networking + +Spark executors must be able to connect to the Spark driver over a hostname and a port that is routable from the Spark +executors. The specific network configuration that will be required for Spark to work in client mode will vary per +setup. If you run your driver inside a Kubernetes pod, you can use a +[headless service](https://kubernetes.io/docs/concepts/services-networking/service/#headless-services) to allow your +driver pod to be routable from the executors by a stable hostname. When deploying your headless service, ensure that +the service's label selector will only match the driver pod and no other pods; it is recommended to assign your driver +pod a sufficiently unique label and to use that label in the label selector of the headless service. Specify the driver's +hostname via `spark.driver.host` and your spark driver's port to `spark.driver.port`. + +### Client Mode Executor Pod Garbage Collection + +If you run your Spark driver in a pod, it is highly recommended to set `spark.driver.pod.name` to the name of that pod. +When this property is set, the Spark scheduler will deploy the executor pods with an +[OwnerReference](https://kubernetes.io/docs/concepts/workloads/controllers/garbage-collection/), which in turn will +ensure that once the driver pod is deleted from the cluster, all of the application's executor pods will also be deleted. +The driver will look for a pod with the given name in the namespace specified by `spark.kubernetes.namespace`, and +an OwnerReference pointing to that pod will be added to each executor pod's OwnerReferences list. Be careful to avoid +setting the OwnerReference to a pod that is not actually that driver pod, or else the executors may be terminated +prematurely when the wrong pod is deleted. + +If your application is not running inside a pod, or if `spark.driver.pod.name` is not set when your application is +actually running in a pod, keep in mind that the executor pods may not be properly deleted from the cluster when the +application exits. The Spark scheduler attempts to delete these pods, but if the network request to the API server fails +for any reason, these pods will remain in the cluster. The executor processes should exit when they cannot reach the +driver, so the executor pods should not consume compute resources (cpu and memory) in the cluster after your application +exits. + +### Authentication Parameters + +Use the exact prefix `spark.kubernetes.authenticate` for Kubernetes authentication parameters in client mode. + ## Dependency Management If your application's dependencies are all hosted in remote locations like HDFS or HTTP servers, they may be referred to @@ -258,10 +297,6 @@ RBAC authorization and how to configure Kubernetes service accounts for pods, pl [Using RBAC Authorization](https://kubernetes.io/docs/admin/authorization/rbac/) and [Configure Service Accounts for Pods](https://kubernetes.io/docs/tasks/configure-pod-container/configure-service-account/). -## Client Mode - -Client mode is not currently supported. - ## Future Work There are several Spark on Kubernetes features that are currently being incubated in a fork - @@ -354,7 +389,7 @@ specific to Spark on Kubernetes. Path to the CA cert file for connecting to the Kubernetes API server over TLS when starting the driver. This file must be located on the submitting machine's disk. Specify this as a path as opposed to a URI (i.e. do not provide - a scheme). + a scheme). In client mode, use spark.kubernetes.authenticate.caCertFile instead. @@ -363,7 +398,7 @@ specific to Spark on Kubernetes. Path to the client key file for authenticating against the Kubernetes API server when starting the driver. This file must be located on the submitting machine's disk. Specify this as a path as opposed to a URI (i.e. do not provide - a scheme). + a scheme). In client mode, use spark.kubernetes.authenticate.clientKeyFile instead. @@ -372,7 +407,7 @@ specific to Spark on Kubernetes. Path to the client cert file for authenticating against the Kubernetes API server when starting the driver. This file must be located on the submitting machine's disk. Specify this as a path as opposed to a URI (i.e. do not - provide a scheme). + provide a scheme). In client mode, use spark.kubernetes.authenticate.clientCertFile instead. @@ -381,7 +416,7 @@ specific to Spark on Kubernetes. OAuth token to use when authenticating against the Kubernetes API server when starting the driver. Note that unlike the other authentication options, this is expected to be the exact string value of the token to use for - the authentication. + the authentication. In client mode, use spark.kubernetes.authenticate.oauthToken instead. @@ -390,7 +425,7 @@ specific to Spark on Kubernetes. Path to the OAuth token file containing the token to use when authenticating against the Kubernetes API server when starting the driver. This file must be located on the submitting machine's disk. Specify this as a path as opposed to a URI (i.e. do not - provide a scheme). + provide a scheme). In client mode, use spark.kubernetes.authenticate.oauthTokenFile instead. @@ -399,7 +434,8 @@ specific to Spark on Kubernetes. Path to the CA cert file for connecting to the Kubernetes API server over TLS from the driver pod when requesting executors. This file must be located on the submitting machine's disk, and will be uploaded to the driver pod. - Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + Specify this as a path as opposed to a URI (i.e. do not provide a scheme). In client mode, use + spark.kubernetes.authenticate.caCertFile instead. @@ -407,10 +443,9 @@ specific to Spark on Kubernetes. (none) Path to the client key file for authenticating against the Kubernetes API server from the driver pod when requesting - executors. This file must be located on the submitting machine's disk, and will be uploaded to the driver pod. - Specify this as a path as opposed to a URI (i.e. do not provide a scheme). If this is specified, it is highly - recommended to set up TLS for the driver submission server, as this value is sensitive information that would be - passed to the driver pod in plaintext otherwise. + executors. This file must be located on the submitting machine's disk, and will be uploaded to the driver pod as + a Kubernetes secret. Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + In client mode, use spark.kubernetes.authenticate.clientKeyFile instead. @@ -419,7 +454,8 @@ specific to Spark on Kubernetes. Path to the client cert file for authenticating against the Kubernetes API server from the driver pod when requesting executors. This file must be located on the submitting machine's disk, and will be uploaded to the - driver pod. Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + driver pod as a Kubernetes secret. Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + In client mode, use spark.kubernetes.authenticate.clientCertFile instead. @@ -428,9 +464,8 @@ specific to Spark on Kubernetes. OAuth token to use when authenticating against the Kubernetes API server from the driver pod when requesting executors. Note that unlike the other authentication options, this must be the exact string value of - the token to use for the authentication. This token value is uploaded to the driver pod. If this is specified, it is - highly recommended to set up TLS for the driver submission server, as this value is sensitive information that would - be passed to the driver pod in plaintext otherwise. + the token to use for the authentication. This token value is uploaded to the driver pod as a Kubernetes secret. + In client mode, use spark.kubernetes.authenticate.oauthToken instead. @@ -439,9 +474,8 @@ specific to Spark on Kubernetes. Path to the OAuth token file containing the token to use when authenticating against the Kubernetes API server from the driver pod when requesting executors. Note that unlike the other authentication options, this file must contain the exact string value of - the token to use for the authentication. This token value is uploaded to the driver pod. If this is specified, it is - highly recommended to set up TLS for the driver submission server, as this value is sensitive information that would - be passed to the driver pod in plaintext otherwise. + the token to use for the authentication. This token value is uploaded to the driver pod as a secret. In client mode, use + spark.kubernetes.authenticate.oauthTokenFile instead. @@ -450,7 +484,8 @@ specific to Spark on Kubernetes. Path to the CA cert file for connecting to the Kubernetes API server over TLS from the driver pod when requesting executors. This path must be accessible from the driver pod. - Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + Specify this as a path as opposed to a URI (i.e. do not provide a scheme). In client mode, use + spark.kubernetes.authenticate.caCertFile instead. @@ -459,7 +494,8 @@ specific to Spark on Kubernetes. Path to the client key file for authenticating against the Kubernetes API server from the driver pod when requesting executors. This path must be accessible from the driver pod. - Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + Specify this as a path as opposed to a URI (i.e. do not provide a scheme). In client mode, use + spark.kubernetes.authenticate.clientKeyFile instead. @@ -468,7 +504,8 @@ specific to Spark on Kubernetes. Path to the client cert file for authenticating against the Kubernetes API server from the driver pod when requesting executors. This path must be accessible from the driver pod. - Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + Specify this as a path as opposed to a URI (i.e. do not provide a scheme). In client mode, use + spark.kubernetes.authenticate.clientCertFile instead. @@ -477,7 +514,8 @@ specific to Spark on Kubernetes. Path to the file containing the OAuth token to use when authenticating against the Kubernetes API server from the driver pod when requesting executors. This path must be accessible from the driver pod. - Note that unlike the other authentication options, this file must contain the exact string value of the token to use for the authentication. + Note that unlike the other authentication options, this file must contain the exact string value of the token to use + for the authentication. In client mode, use spark.kubernetes.authenticate.oauthTokenFile instead. @@ -486,7 +524,48 @@ specific to Spark on Kubernetes. Service account that is used when running the driver pod. The driver pod uses this service account when requesting executor pods from the API server. Note that this cannot be specified alongside a CA cert file, client key file, - client cert file, and/or OAuth token. + client cert file, and/or OAuth token. In client mode, use spark.kubernetes.authenticate.serviceAccountName instead. + + + + spark.kubernetes.authenticate.caCertFile + (none) + + In client mode, path to the CA cert file for connecting to the Kubernetes API server over TLS when + requesting executors. Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + + + + spark.kubernetes.authenticate.clientKeyFile + (none) + + In client mode, path to the client key file for authenticating against the Kubernetes API server + when requesting executors. Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + + + + spark.kubernetes.authenticate.clientCertFile + (none) + + In client mode, path to the client cert file for authenticating against the Kubernetes API server + when requesting executors. Specify this as a path as opposed to a URI (i.e. do not provide a scheme). + + + + spark.kubernetes.authenticate.oauthToken + (none) + + In client mode, the OAuth token to use when authenticating against the Kubernetes API server when + requesting executors. Note that unlike the other authentication options, this must be the exact string value of + the token to use for the authentication. + + + + spark.kubernetes.authenticate.oauthTokenFile + (none) + + In client mode, path to the file containing the OAuth token to use when authenticating against the Kubernetes API + server when requesting executors. @@ -529,8 +608,11 @@ specific to Spark on Kubernetes. spark.kubernetes.driver.pod.name (none) - Name of the driver pod. If not set, the driver pod name is set to "spark.app.name" suffixed by the current timestamp - to avoid name conflicts. + Name of the driver pod. In cluster mode, if this is not set, the driver pod name is set to "spark.app.name" + suffixed by the current timestamp to avoid name conflicts. In client mode, if your application is running + inside a pod, it is highly recommended to set this to the name of the pod your driver is running in. Setting this + value in client mode allows the driver to become the owner of its executor pods, which in turn allows the executor + pods to be garbage collected by the cluster. diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index f9a77e71ad618..968679df60367 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -65,6 +65,7 @@ private[spark] object Config extends Logging { "spark.kubernetes.authenticate.driver" val KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX = "spark.kubernetes.authenticate.driver.mounted" + val KUBERNETES_AUTH_CLIENT_MODE_PREFIX = "spark.kubernetes.authenticate" val OAUTH_TOKEN_CONF_SUFFIX = "oauthToken" val OAUTH_TOKEN_FILE_CONF_SUFFIX = "oauthTokenFile" val CLIENT_KEY_FILE_CONF_SUFFIX = "clientKeyFile" @@ -90,7 +91,7 @@ private[spark] object Config extends Logging { ConfigBuilder("spark.kubernetes.submitInDriver") .internal() .booleanConf - .createOptional + .createWithDefault(false) val KUBERNETES_EXECUTOR_LIMIT_CORES = ConfigBuilder("spark.kubernetes.executor.limit.cores") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index 51d205fdb68d1..866ba3cbaa9c3 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -43,7 +43,7 @@ private[spark] case class KubernetesDriverSpecificConf( */ private[spark] case class KubernetesExecutorSpecificConf( executorId: String, - driverPod: Pod) + driverPod: Option[Pod]) extends KubernetesRoleSpecificConf /** @@ -186,7 +186,7 @@ private[spark] object KubernetesConf { sparkConf: SparkConf, executorId: String, appId: String, - driverPod: Pod): KubernetesConf[KubernetesExecutorSpecificConf] = { + driverPod: Option[Pod]): KubernetesConf[KubernetesExecutorSpecificConf] = { val executorCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_LABEL_PREFIX) require( diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 66fff267545dc..588cd9d40f9a0 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -58,4 +58,6 @@ private[spark] object KubernetesUtils { case _ => uri } } + + def parseMasterUrl(url: String): String = url.substring("k8s://".length) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index abaeff0313a79..c37f713c56de1 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -152,19 +152,20 @@ private[spark] class BasicExecutorFeatureStep( .build() }.getOrElse(executorContainer) val driverPod = kubernetesConf.roleSpecificConf.driverPod + val ownerReference = driverPod.map(pod => + new OwnerReferenceBuilder() + .withController(true) + .withApiVersion(pod.getApiVersion) + .withKind(pod.getKind) + .withName(pod.getMetadata.getName) + .withUid(pod.getMetadata.getUid) + .build()) val executorPod = new PodBuilder(pod.pod) .editOrNewMetadata() .withName(name) .withLabels(kubernetesConf.roleLabels.asJava) .withAnnotations(kubernetesConf.roleAnnotations.asJava) - .withOwnerReferences() - .addNewOwnerReference() - .withController(true) - .withApiVersion(driverPod.getApiVersion) - .withKind(driverPod.getKind) - .withName(driverPod.getMetadata.getName) - .withUid(driverPod.getMetadata.getUid) - .endOwnerReference() + .addToOwnerReferences(ownerReference.toSeq: _*) .endMetadata() .editOrNewSpec() .withHostname(hostname) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index eaff47205dbbc..9398faee2ea5c 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -27,7 +27,7 @@ import scala.util.control.NonFatal import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkApplication -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkKubernetesClientFactory} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging @@ -228,7 +228,7 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val namespace = kubernetesConf.namespace() // The master URL has been checked for validity already in SparkSubmit. // We just need to get rid of the "k8s://" prefix here. - val master = sparkConf.get("spark.master").substring("k8s://".length) + val master = KubernetesUtils.parseMasterUrl(sparkConf.get("spark.master")) val loggingInterval = if (waitForAppCompletion) Some(sparkConf.get(REPORT_INTERVAL)) else None val watcher = new LoggingPodStatusWatcherImpl(kubernetesAppId, loggingInterval) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala index 5a143ad3600fd..77bb9c3fcc9f4 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala @@ -46,13 +46,18 @@ private[spark] class ExecutorPodsAllocator( private val podCreationTimeout = math.max(podAllocationDelay * 5, 60000) + private val namespace = conf.get(KUBERNETES_NAMESPACE) + private val kubernetesDriverPodName = conf .get(KUBERNETES_DRIVER_POD_NAME) - .getOrElse(throw new SparkException("Must specify the driver pod name")) - private val driverPod = kubernetesClient.pods() - .withName(kubernetesDriverPodName) - .get() + private val driverPod = kubernetesDriverPodName + .map(name => Option(kubernetesClient.pods() + .withName(name) + .get()) + .getOrElse(throw new SparkException( + s"No pod was found named $kubernetesDriverPodName in the cluster in the " + + s"namespace $namespace (this was supposed to be the driver pod.)."))) // Executor IDs that have been requested from Kubernetes but have not been detected in any // snapshot yet. Mapped to the timestamp when they were created. diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index de2a52bc7a0b8..9999c62c878df 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit import com.google.common.cache.CacheBuilder import io.fabric8.kubernetes.client.Config -import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.SparkContext import org.apache.spark.deploy.k8s.{KubernetesUtils, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ @@ -35,12 +35,6 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit override def canCreate(masterURL: String): Boolean = masterURL.startsWith("k8s") override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler = { - if (masterURL.startsWith("k8s") && - sc.deployMode == "client" && - !sc.conf.get(KUBERNETES_DRIVER_SUBMIT_CHECK).getOrElse(false)) { - throw new SparkException("Client mode is currently not supported for Kubernetes.") - } - new TaskSchedulerImpl(sc) } @@ -48,13 +42,32 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit sc: SparkContext, masterURL: String, scheduler: TaskScheduler): SchedulerBackend = { + val wasSparkSubmittedInClusterMode = sc.conf.get(KUBERNETES_DRIVER_SUBMIT_CHECK) + val (authConfPrefix, + apiServerUri, + defaultServiceAccountToken, + defaultServiceAccountCaCrt) = if (wasSparkSubmittedInClusterMode) { + require(sc.conf.get(KUBERNETES_DRIVER_POD_NAME).isDefined, + "If the application is deployed using spark-submit in cluster mode, the driver pod name " + + "must be provided.") + (KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX, + KUBERNETES_MASTER_INTERNAL_URL, + Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), + Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) + } else { + (KUBERNETES_AUTH_CLIENT_MODE_PREFIX, + KubernetesUtils.parseMasterUrl(masterURL), + None, + None) + } + val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( - KUBERNETES_MASTER_INTERNAL_URL, + apiServerUri, Some(sc.conf.get(KUBERNETES_NAMESPACE)), - KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX, + authConfPrefix, sc.conf, - Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), - Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) + defaultServiceAccountToken, + defaultServiceAccountCaCrt) val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( "kubernetes-executor-requests") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala index 661f942435921..ecdb71359c5bb 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala @@ -184,9 +184,9 @@ class KubernetesConfSuite extends SparkFunSuite { new SparkConf(false), EXECUTOR_ID, APP_ID, - DRIVER_POD) + Some(DRIVER_POD)) assert(conf.roleSpecificConf.executorId === EXECUTOR_ID) - assert(conf.roleSpecificConf.driverPod === DRIVER_POD) + assert(conf.roleSpecificConf.driverPod.get === DRIVER_POD) } test("Image pull secrets.") { @@ -195,7 +195,7 @@ class KubernetesConfSuite extends SparkFunSuite { .set(IMAGE_PULL_SECRETS, "my-secret-1,my-secret-2 "), EXECUTOR_ID, APP_ID, - DRIVER_POD) + Some(DRIVER_POD)) assert(conf.imagePullSecrets() === Seq( new LocalObjectReferenceBuilder().withName("my-secret-1").build(), @@ -221,7 +221,7 @@ class KubernetesConfSuite extends SparkFunSuite { sparkConf, EXECUTOR_ID, APP_ID, - DRIVER_POD) + Some(DRIVER_POD)) assert(conf.roleLabels === Map( SPARK_EXECUTOR_ID_LABEL -> EXECUTOR_ID, SPARK_APP_ID_LABEL -> APP_ID, diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index a44fa1f2ffc63..95d373f791649 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -81,7 +81,7 @@ class BasicExecutorFeatureStepSuite val step = new BasicExecutorFeatureStep( KubernetesConf( baseConf, - KubernetesExecutorSpecificConf("1", DRIVER_POD), + KubernetesExecutorSpecificConf("1", Some(DRIVER_POD)), RESOURCE_NAME_PREFIX, APP_ID, LABELS, @@ -121,7 +121,7 @@ class BasicExecutorFeatureStepSuite val step = new BasicExecutorFeatureStep( KubernetesConf( conf, - KubernetesExecutorSpecificConf("1", DRIVER_POD), + KubernetesExecutorSpecificConf("1", Some(DRIVER_POD)), longPodNamePrefix, APP_ID, LABELS, @@ -142,7 +142,7 @@ class BasicExecutorFeatureStepSuite val step = new BasicExecutorFeatureStep( KubernetesConf( conf, - KubernetesExecutorSpecificConf("1", DRIVER_POD), + KubernetesExecutorSpecificConf("1", Some(DRIVER_POD)), RESOURCE_NAME_PREFIX, APP_ID, LABELS, diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala index 1c8d84b76c56b..85c6cb282d2b0 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala @@ -37,7 +37,7 @@ class EnvSecretsFeatureStepSuite extends SparkFunSuite{ val sparkConf = new SparkConf(false) val kubernetesConf = KubernetesConf( sparkConf, - KubernetesExecutorSpecificConf("1", new PodBuilder().build()), + KubernetesExecutorSpecificConf("1", Some(new PodBuilder().build())), "resource-name-prefix", "app-id", Map.empty, diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala index 2b49b72dfa569..dad610c443acc 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala @@ -35,7 +35,7 @@ class MountSecretsFeatureStepSuite extends SparkFunSuite { val sparkConf = new SparkConf(false) val kubernetesConf = KubernetesConf( sparkConf, - KubernetesExecutorSpecificConf("1", new PodBuilder().build()), + KubernetesExecutorSpecificConf("1", Some(new PodBuilder().build())), "resource-name-prefix", "app-id", Map.empty, diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala index 0c19f5946b75f..e847f8590d353 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala @@ -166,7 +166,7 @@ class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter { conf, executorSpecificConf.executorId, TEST_SPARK_APP_ID, - driverPod) + Some(driverPod)) k8sConf.sparkConf.getAll.toMap == conf.getAll.toMap && // Since KubernetesConf.createExecutorConf clones the SparkConf object, force // deep equality comparison for the SparkConf object and use object equality diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index d0b4127065eb7..44fe4a24e1102 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -51,7 +51,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { val conf = KubernetesConf( new SparkConf(false), KubernetesExecutorSpecificConf( - "executor-id", new PodBuilder().build()), + "executor-id", Some(new PodBuilder().build())), "prefix", "appId", Map.empty, @@ -69,7 +69,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { val conf = KubernetesConf( new SparkConf(false), KubernetesExecutorSpecificConf( - "executor-id", new PodBuilder().build()), + "executor-id", Some(new PodBuilder().build())), "prefix", "appId", Map.empty, @@ -96,7 +96,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { val conf = KubernetesConf( new SparkConf(false), KubernetesExecutorSpecificConf( - "executor-id", new PodBuilder().build()), + "executor-id", Some(new PodBuilder().build())), "prefix", "appId", Map.empty, diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index 9badf8556afc3..42a670174eae1 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -31,7 +31,7 @@ RUN set -ex && \ apk upgrade --no-cache && \ apk add --no-cache bash tini libc6-compat && \ mkdir -p /opt/spark && \ - mkdir -p /opt/spark/work-dir \ + mkdir -p /opt/spark/work-dir && \ touch /opt/spark/RELEASE && \ rm /bin/sh && \ ln -sv /bin/bash /bin/sh && \ diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala new file mode 100644 index 0000000000000..0690db7f67cf0 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import org.scalatest.concurrent.Eventually +import scala.collection.JavaConverters._ + +import org.apache.spark.deploy.k8s.integrationtest.KubernetesSuite.{k8sTestTag, INTERVAL, TIMEOUT} + +trait ClientModeTestsSuite { k8sSuite: KubernetesSuite => + + test("Run in client mode.", k8sTestTag) { + val labels = Map("spark-app-selector" -> driverPodName) + val driverPort = 7077 + val blockManagerPort = 10000 + val driverService = testBackend + .getKubernetesClient + .services() + .inNamespace(kubernetesTestComponents.namespace) + .createNew() + .withNewMetadata() + .withName(s"$driverPodName-svc") + .endMetadata() + .withNewSpec() + .withClusterIP("None") + .withSelector(labels.asJava) + .addNewPort() + .withName("driver-port") + .withPort(driverPort) + .withNewTargetPort(driverPort) + .endPort() + .addNewPort() + .withName("block-manager") + .withPort(blockManagerPort) + .withNewTargetPort(blockManagerPort) + .endPort() + .endSpec() + .done() + try { + val driverPod = testBackend + .getKubernetesClient + .pods() + .inNamespace(kubernetesTestComponents.namespace) + .createNew() + .withNewMetadata() + .withName(driverPodName) + .withLabels(labels.asJava) + .endMetadata() + .withNewSpec() + .withServiceAccountName("default") + .addNewContainer() + .withName("spark-example") + .withImage(image) + .withImagePullPolicy("IfNotPresent") + .withCommand("/opt/spark/bin/run-example") + .addToArgs("--master", s"k8s://https://kubernetes.default.svc") + .addToArgs("--deploy-mode", "client") + .addToArgs("--conf", s"spark.kubernetes.container.image=$image") + .addToArgs( + "--conf", + s"spark.kubernetes.namespace=${kubernetesTestComponents.namespace}") + .addToArgs("--conf", "spark.kubernetes.authenticate.oauthTokenFile=" + + "/var/run/secrets/kubernetes.io/serviceaccount/token") + .addToArgs("--conf", "spark.kubernetes.authenticate.caCertFile=" + + "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt") + .addToArgs("--conf", s"spark.kubernetes.driver.pod.name=$driverPodName") + .addToArgs("--conf", "spark.executor.memory=500m") + .addToArgs("--conf", "spark.executor.cores=1") + .addToArgs("--conf", "spark.executor.instances=1") + .addToArgs("--conf", + s"spark.driver.host=" + + s"${driverService.getMetadata.getName}.${kubernetesTestComponents.namespace}.svc") + .addToArgs("--conf", s"spark.driver.port=$driverPort") + .addToArgs("--conf", s"spark.driver.blockManager.port=$blockManagerPort") + .addToArgs("SparkPi") + .addToArgs("10") + .endContainer() + .endSpec() + .done() + Eventually.eventually(TIMEOUT, INTERVAL) { + assert(kubernetesTestComponents.kubernetesClient + .pods() + .withName(driverPodName) + .getLog + .contains("Pi is roughly 3"), "The application did not complete.") + } + } finally { + // Have to delete the service manually since it doesn't have an owner reference + kubernetesTestComponents + .kubernetesClient + .services() + .inNamespace(kubernetesTestComponents.namespace) + .delete(driverService) + } + } + +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 95694aa93d5b5..0d829218cf774 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -38,12 +38,11 @@ private[spark] class KubernetesSuite extends SparkFunSuite import KubernetesSuite._ - private var testBackend: IntegrationTestBackend = _ - private var sparkHomeDir: Path = _ - private var image: String = _ - private var pyImage: String = _ - private var driverPodName: String = _ - + protected var testBackend: IntegrationTestBackend = _ + protected var sparkHomeDir: Path = _ + protected var image: String = _ + protected var pyImage: String = _ + protected var driverPodName: String = _ protected var kubernetesTestComponents: KubernetesTestComponents = _ protected var sparkAppConf: SparkAppConf = _ protected var containerLocalSparkDistroExamplesJar: String = _ From 2f77616e1dc593fb9c376a8bd72416198cf3d6f5 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 25 Jul 2018 11:09:12 -0700 Subject: [PATCH 1213/2461] [SPARK-24849][SPARK-24911][SQL] Converting a value of StructType to a DDL string ## What changes were proposed in this pull request? In the PR, I propose to extend the `StructType`/`StructField` classes by new method `toDDL` which converts a value of the `StructType`/`StructField` type to a string formatted in DDL style. The resulted string can be used in a table creation. The `toDDL` method of `StructField` is reused in `SHOW CREATE TABLE`. In this way the PR fixes the bug of unquoted names of nested fields. ## How was this patch tested? I add a test for checking the new method and 2 round trip tests: `fromDDL` -> `toDDL` and `toDDL` -> `fromDDL` Author: Maxim Gekk Closes #21803 from MaxGekk/to-ddl. --- .../spark/sql/catalyst/util/package.scala | 12 +++++++ .../apache/spark/sql/types/StructField.scala | 13 ++++++++ .../apache/spark/sql/types/StructType.scala | 10 +++++- .../spark/sql/types/StructTypeSuite.scala | 33 +++++++++++++++++++ .../spark/sql/execution/command/tables.scala | 23 +++---------- .../spark/sql/hive/ShowCreateTableSuite.scala | 15 +++++++++ 6 files changed, 86 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 4005087dad05a..0978e92dd4f72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -155,6 +155,18 @@ package object util { def toPrettySQL(e: Expression): String = usePrettyExpression(e).sql + + def escapeSingleQuotedString(str: String): String = { + val builder = StringBuilder.newBuilder + + str.foreach { + case '\'' => builder ++= s"\\\'" + case ch => builder += ch + } + + builder.toString() + } + /* FIX ME implicit class debugLogging(a: Any) { def debugLogging() { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala index 2c18fdcc497fe..902cae9150ede 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala @@ -21,6 +21,7 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier} /** * A field inside a StructType. @@ -74,4 +75,16 @@ case class StructField( def getComment(): Option[String] = { if (metadata.contains("comment")) Option(metadata.getString("comment")) else None } + + /** + * Returns a string containing a schema in DDL format. For example, the following value: + * `StructField("eventId", IntegerType)` will be converted to `eventId` INT. + */ + def toDDL: String = { + val comment = getComment() + .map(escapeSingleQuotedString) + .map(" COMMENT '" + _ + "'") + + s"${quoteIdentifier(name)} ${dataType.sql}${comment.getOrElse("")}" + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index b13e95f83bc58..c5ca169c955dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -27,7 +27,7 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser} -import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier} import org.apache.spark.util.Utils /** @@ -360,6 +360,14 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru s"STRUCT<${fieldTypes.mkString(", ")}>" } + /** + * Returns a string containing a schema in DDL format. For example, the following value: + * `StructType(Seq(StructField("eventId", IntegerType), StructField("s", StringType)))` + * will be converted to `eventId` INT, `s` STRING. + * The returned DDL schema can be used in a table creation. + */ + def toDDL: String = fields.map(_.toDDL).mkString(",") + private[sql] override def simpleString(maxNumberFields: Int): String = { val builder = new StringBuilder val fieldTypes = fields.take(maxNumberFields).map { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala index c6ca8bb005429..53a78c94aa6fb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.types import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.StructType.fromDDL class StructTypeSuite extends SparkFunSuite { @@ -37,4 +38,36 @@ class StructTypeSuite extends SparkFunSuite { val e = intercept[IllegalArgumentException](s.fieldIndex("c")).getMessage assert(e.contains("Available fields: a, b")) } + + test("SPARK-24849: toDDL - simple struct") { + val struct = StructType(Seq(StructField("a", IntegerType))) + + assert(struct.toDDL == "`a` INT") + } + + test("SPARK-24849: round trip toDDL - fromDDL") { + val struct = new StructType().add("a", IntegerType).add("b", StringType) + + assert(fromDDL(struct.toDDL) === struct) + } + + test("SPARK-24849: round trip fromDDL - toDDL") { + val struct = "`a` MAP,`b` INT" + + assert(fromDDL(struct).toDDL === struct) + } + + test("SPARK-24849: toDDL must take into account case of fields.") { + val struct = new StructType() + .add("metaData", new StructType().add("eventId", StringType)) + + assert(struct.toDDL == "`metaData` STRUCT<`eventId`: STRING>") + } + + test("SPARK-24849: toDDL should output field's comment") { + val struct = StructType(Seq( + StructField("b", BooleanType).withComment("Field's comment"))) + + assert(struct.toDDL == """`b` BOOLEAN COMMENT 'Field\'s comment'""") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index ec3961f84bd8d..56f48b7dc00ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTableType._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.Histogram -import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier} import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils} import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.json.JsonFileFormat @@ -982,7 +982,7 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman private def showHiveTableHeader(metadata: CatalogTable, builder: StringBuilder): Unit = { val columns = metadata.schema.filterNot { column => metadata.partitionColumnNames.contains(column.name) - }.map(columnToDDLFragment) + }.map(_.toDDL) if (columns.nonEmpty) { builder ++= columns.mkString("(", ", ", ")\n") @@ -994,14 +994,10 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman .foreach(builder.append) } - private def columnToDDLFragment(column: StructField): String = { - val comment = column.getComment().map(escapeSingleQuotedString).map(" COMMENT '" + _ + "'") - s"${quoteIdentifier(column.name)} ${column.dataType.catalogString}${comment.getOrElse("")}" - } private def showHiveTableNonDataColumns(metadata: CatalogTable, builder: StringBuilder): Unit = { if (metadata.partitionColumnNames.nonEmpty) { - val partCols = metadata.partitionSchema.map(columnToDDLFragment) + val partCols = metadata.partitionSchema.map(_.toDDL) builder ++= partCols.mkString("PARTITIONED BY (", ", ", ")\n") } @@ -1072,7 +1068,7 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman private def showDataSourceTableDataColumns( metadata: CatalogTable, builder: StringBuilder): Unit = { - val columns = metadata.schema.fields.map(f => s"${quoteIdentifier(f.name)} ${f.dataType.sql}") + val columns = metadata.schema.fields.map(_.toDDL) builder ++= columns.mkString("(", ", ", ")\n") } @@ -1117,15 +1113,4 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman } } } - - private def escapeSingleQuotedString(str: String): String = { - val builder = StringBuilder.newBuilder - - str.foreach { - case '\'' => builder ++= s"\\\'" - case ch => builder += ch - } - - builder.toString() - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala index 473bbced41b31..34ca790299859 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala @@ -288,6 +288,21 @@ class ShowCreateTableSuite extends QueryTest with SQLTestUtils with TestHiveSing } } + test("SPARK-24911: keep quotes for nested fields") { + withTable("t1") { + val createTable = "CREATE TABLE `t1`(`a` STRUCT<`b`: STRING>)" + sql(createTable) + val shownDDL = sql(s"SHOW CREATE TABLE t1") + .head() + .getString(0) + .split("\n") + .head + assert(shownDDL == createTable) + + checkCreateTable("t1") + } + } + private def createRawHiveTable(ddl: String): Unit = { hiveContext.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] .client.runSqlHive(ddl) From 0c83f718ee8544708e421374d46ef1b95d93973e Mon Sep 17 00:00:00 2001 From: mcheah Date: Wed, 25 Jul 2018 12:10:23 -0700 Subject: [PATCH 1214/2461] [SPARK-23146][K8S][TESTS] Enable client mode integration test. ## What changes were proposed in this pull request? Enable client mode integration test after merging from master. ## How was this patch tested? Check the integration test runs in the build. Author: mcheah Closes #21874 from mccheah/enable-client-mode-test. --- .../spark/deploy/k8s/integrationtest/KubernetesSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 0d829218cf774..13ce2efecbef7 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.deploy.k8s.integrationtest.backend.{IntegrationTestBacke private[spark] class KubernetesSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter with BasicTestsSuite with SecretsTestsSuite - with PythonTestsSuite { + with PythonTestsSuite with ClientModeTestsSuite { import KubernetesSuite._ From 17f469bc808e076b45fffcedb0147991fa4c41f3 Mon Sep 17 00:00:00 2001 From: Koert Kuipers Date: Wed, 25 Jul 2018 13:06:03 -0700 Subject: [PATCH 1215/2461] [SPARK-24860][SQL] Support setting of partitionOverWriteMode in output options for writing DataFrame ## What changes were proposed in this pull request? Besides spark setting spark.sql.sources.partitionOverwriteMode also allow setting partitionOverWriteMode per write ## How was this patch tested? Added unit test in InsertSuite Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Koert Kuipers Closes #21818 from koertkuipers/feat-partition-overwrite-mode-per-write. --- .../apache/spark/sql/internal/SQLConf.scala | 6 +++++- .../InsertIntoHadoopFsRelationCommand.scala | 9 +++++++-- .../spark/sql/sources/InsertSuite.scala | 20 +++++++++++++++++++ 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d7c830dfa0454..53423e03b6b2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1360,7 +1360,11 @@ object SQLConf { "overwriting. In dynamic mode, Spark doesn't delete partitions ahead, and only overwrite " + "those partitions that have data written into it at runtime. By default we use static " + "mode to keep the same behavior of Spark prior to 2.3. Note that this config doesn't " + - "affect Hive serde tables, as they are always overwritten with dynamic mode.") + "affect Hive serde tables, as they are always overwritten with dynamic mode. This can " + + "also be set as an output option for a data source using key partitionOverwriteMode " + + "(which takes precendence over this setting), e.g. " + + "dataframe.write.option(\"partitionOverwriteMode\", \"dynamic\").save(path)." + ) .stringConf .transform(_.toUpperCase(Locale.ROOT)) .checkValues(PartitionOverwriteMode.values.map(_.toString)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index dd7ef0d15c140..8a2e00d9780e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogT import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode @@ -91,8 +92,12 @@ case class InsertIntoHadoopFsRelationCommand( val pathExists = fs.exists(qualifiedOutputPath) - val enableDynamicOverwrite = - sparkSession.sessionState.conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC + val parameters = CaseInsensitiveMap(options) + + val partitionOverwriteMode = parameters.get("partitionOverwriteMode") + .map(mode => PartitionOverwriteMode.withName(mode.toUpperCase)) + .getOrElse(sparkSession.sessionState.conf.partitionOverwriteMode) + val enableDynamicOverwrite = partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC // This config only makes sense when we are overwriting a partitioned dataset with dynamic // partition columns. val dynamicPartitionOverwrite = enableDynamicOverwrite && mode == SaveMode.Overwrite && diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 438d5d8176b8b..0b6d93975daef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -545,6 +545,26 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { } } + test("SPARK-24860: dynamic partition overwrite specified per source without catalog table") { + withTempPath { path => + Seq((1, 1), (2, 2)).toDF("i", "part") + .write.partitionBy("part") + .parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(1, 1) :: Row(2, 2) :: Nil) + + Seq((1, 2), (1, 3)).toDF("i", "part") + .write.partitionBy("part").mode("overwrite") + .option("partitionOverwriteMode", "dynamic").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), + Row(1, 1) :: Row(1, 2) :: Row(1, 3) :: Nil) + + Seq((1, 2), (1, 3)).toDF("i", "part") + .write.partitionBy("part").mode("overwrite") + .option("partitionOverwriteMode", "static").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(1, 2) :: Row(1, 3) :: Nil) + } + } + test("SPARK-24583 Wrong schema type in InsertIntoDataSourceCommand") { withTable("test_table") { val schema = new StructType() From d2e7deb59f641e93778b763d5396f73d38f9a785 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Wed, 25 Jul 2018 17:22:37 -0700 Subject: [PATCH 1216/2461] [SPARK-24867][SQL] Add AnalysisBarrier to DataFrameWriter ## What changes were proposed in this pull request? ```Scala val udf1 = udf({(x: Int, y: Int) => x + y}) val df = spark.range(0, 3).toDF("a") .withColumn("b", udf1($"a", udf1($"a", lit(10)))) df.cache() df.write.saveAsTable("t") ``` Cache is not being used because the plans do not match with the cached plan. This is a regression caused by the changes we made in AnalysisBarrier, since not all the Analyzer rules are idempotent. ## How was this patch tested? Added a test. Also found a bug in the DSV1 write path. This is not a regression. Thus, opened a separate JIRA https://issues.apache.org/jira/browse/SPARK-24869 Author: Xiao Li Closes #21821 from gatorsmile/testMaster22. --- .../apache/spark/sql/DataFrameWriter.scala | 10 +++-- .../spark/sql/execution/command/ddl.scala | 7 ++-- .../scala/org/apache/spark/sql/UDFSuite.scala | 42 ++++++++++++++++++- 3 files changed, 51 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index b9fa43f1f9fbd..39c0e102b69b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -254,7 +254,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val writer = ws.createWriter(jobId, df.logicalPlan.schema, mode, options) if (writer.isPresent) { runCommand(df.sparkSession, "save") { - WriteToDataSourceV2(writer.get(), df.logicalPlan) + WriteToDataSourceV2(writer.get(), df.planWithBarrier) } } @@ -275,7 +275,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { sparkSession = df.sparkSession, className = source, partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) + options = extraOptions.toMap).planForWriting(mode, df.planWithBarrier) } } @@ -323,7 +323,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { InsertIntoTable( table = UnresolvedRelation(tableIdent), partition = Map.empty[String, Option[String]], - query = df.logicalPlan, + query = df.planWithBarrier, overwrite = mode == SaveMode.Overwrite, ifPartitionNotExists = false) } @@ -459,7 +459,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { partitionColumnNames = partitioningColumns.getOrElse(Nil), bucketSpec = getBucketSpec) - runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, Some(df.logicalPlan))) + runCommand(df.sparkSession, "saveAsTable") { + CreateTable(tableDesc, mode, Some(df.planWithBarrier)) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 04bf8c6dd917f..c7f7e4d755cfd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Resolver} +import org.apache.spark.sql.catalyst.analysis.{EliminateBarriers, NoSuchTableException, Resolver} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} @@ -891,8 +891,9 @@ object DDLUtils { * Throws exception if outputPath tries to overwrite inputpath. */ def verifyNotReadPath(query: LogicalPlan, outputPath: Path) : Unit = { - val inputPaths = query.collect { - case LogicalRelation(r: HadoopFsRelation, _, _, _) => r.location.rootPaths + val inputPaths = EliminateBarriers(query).collect { + case LogicalRelation(r: HadoopFsRelation, _, _, _) => + r.location.rootPaths }.flatten if (inputPaths.contains(outputPath)) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index d8074571ffc65..30dca9497ddde 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -19,11 +19,16 @@ package org.apache.spark.sql import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.plans.logical.Project -import org.apache.spark.sql.execution.command.ExplainCommand +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, ExplainCommand} +import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand import org.apache.spark.sql.functions.{lit, udf} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types.{DataTypes, DoubleType} +import org.apache.spark.sql.util.QueryExecutionListener + private case class FunctionResult(f1: String, f2: String) @@ -325,6 +330,41 @@ class UDFSuite extends QueryTest with SharedSQLContext { } } + test("cached Data should be used in the write path") { + withTable("t") { + withTempPath { path => + var numTotalCachedHit = 0 + val listener = new QueryExecutionListener { + override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {} + + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + qe.withCachedData match { + case c: CreateDataSourceTableAsSelectCommand + if c.query.isInstanceOf[InMemoryRelation] => + numTotalCachedHit += 1 + case i: InsertIntoHadoopFsRelationCommand + if i.query.isInstanceOf[InMemoryRelation] => + numTotalCachedHit += 1 + case _ => + } + } + } + spark.listenerManager.register(listener) + + val udf1 = udf({ (x: Int, y: Int) => x + y }) + val df = spark.range(0, 3).toDF("a") + .withColumn("b", udf1($"a", lit(10))) + df.cache() + df.write.saveAsTable("t") + assert(numTotalCachedHit == 1, "expected to be cached in saveAsTable") + df.write.insertInto("t") + assert(numTotalCachedHit == 2, "expected to be cached in insertInto") + df.write.save(path.getCanonicalPath) + assert(numTotalCachedHit == 3, "expected to be cached in save for native") + } + } + } + test("SPARK-24891 Fix HandleNullInputsForUDF rule") { val udf1 = udf({(x: Int, y: Int) => x + y}) val df = spark.range(0, 3).toDF("a") From c9b233d4144790c3e57e1a1d1602ad5dc354e8a8 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 26 Jul 2018 15:06:13 +0800 Subject: [PATCH 1217/2461] [SPARK-24878][SQL] Fix reverse function for array type of primitive type containing null. ## What changes were proposed in this pull request? If we use `reverse` function for array type of primitive type containing `null` and the child array is `UnsafeArrayData`, the function returns a wrong result because `UnsafeArrayData` doesn't define the behavior of re-assignment, especially we can't set a valid value after we set `null`. ## How was this patch tested? Added some tests. Author: Takuya UESHIN Closes #21830 from ueshin/issues/SPARK-24878/fix_reverse. --- .../expressions/collectionOperations.scala | 66 ++++++++++--------- .../spark/sql/DataFrameFunctionsSuite.scala | 63 +++++++++++++----- 2 files changed, 80 insertions(+), 49 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index f438748d9a4ff..b3d04bfa86455 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1244,46 +1244,50 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI } private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = { - val length = ctx.freshName("length") - val javaElementType = CodeGenerator.javaType(elementType) + val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) + val numElements = ctx.freshName("numElements") + val arrayData = ctx.freshName("arrayData") + val initialization = if (isPrimitiveType) { - s"$childName.copy()" + ctx.createUnsafeArray(arrayData, numElements, elementType, s" $prettyName failed.") } else { - s"new ${classOf[GenericArrayData].getName()}(new Object[$length])" - } - - val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length - - val swapAssigments = if (isPrimitiveType) { - val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType) - val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index) - s"""|boolean isNullAtK = ${ev.value}.isNullAt(k); - |boolean isNullAtL = ${ev.value}.isNullAt(l); - |if(!isNullAtK) { - | $javaElementType el = ${getCall("k")}; - | if(!isNullAtL) { - | ${ev.value}.$setFunc(k, ${getCall("l")}); - | } else { - | ${ev.value}.setNullAt(k); - | } - | ${ev.value}.$setFunc(l, el); - |} else if (!isNullAtL) { - | ${ev.value}.$setFunc(k, ${getCall("l")}); - | ${ev.value}.setNullAt(l); - |}""".stripMargin + val arrayDataClass = classOf[GenericArrayData].getName + s"$arrayDataClass $arrayData = new $arrayDataClass(new Object[$numElements]);" + } + + val i = ctx.freshName("i") + val j = ctx.freshName("j") + + val getValue = CodeGenerator.getValue(childName, elementType, i) + + val setFunc = if (isPrimitiveType) { + s"set${CodeGenerator.primitiveTypeName(elementType)}" + } else { + "update" + } + + val assignment = if (isPrimitiveType && dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($childName.isNullAt($i)) { + | $arrayData.setNullAt($j); + |} else { + | $arrayData.$setFunc($j, $getValue); + |} + """.stripMargin } else { - s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});" + s"$arrayData.$setFunc($j, $getValue);" } s""" - |final int $length = $childName.numElements(); - |${ev.value} = $initialization; - |for(int k = 0; k < $numberOfIterations; k++) { - | int l = $length - k - 1; - | $swapAssigments + |final int $numElements = $childName.numElements(); + |$initialization + |for (int $i = 0; $i < $numElements; $i++) { + | int $j = $numElements - $i - 1; + | $assignment |} + |${ev.value} = $arrayData; """.stripMargin } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index bf04251e655ed..5a7bd45a4b5f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -901,8 +901,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } } - test("reverse function") { - // String test cases + test("reverse function - string") { val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i") def testString(): Unit = { checkAnswer(oneRowDF.select(reverse('s)), Seq(Row("krapS"))) @@ -917,37 +916,61 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { // Test with cached relation, the Project will be evaluated with codegen oneRowDF.cache() testString() + } - // Array test cases (primitive-type elements) - val idf = Seq( + test("reverse function - array for primitive type not containing null") { + val idfNotContainsNull = Seq( Seq(1, 9, 8, 7), Seq(5, 8, 9, 7, 2), Seq.empty, null ).toDF("i") - def testArray(): Unit = { + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { checkAnswer( - idf.select(reverse('i)), + idfNotContainsNull.select(reverse('i)), Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) ) checkAnswer( - idf.selectExpr("reverse(i)"), + idfNotContainsNull.selectExpr("reverse(i)"), Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) ) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeNotContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + idfNotContainsNull.cache() + testArrayOfPrimitiveTypeNotContainsNull() + } + + test("reverse function - array for primitive type containing null") { + val idfContainsNull = Seq[Seq[Integer]]( + Seq(1, 9, 8, null, 7), + Seq(null, 5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeContainsNull(): Unit = { + checkAnswer( + idfContainsNull.select(reverse('i)), + Seq(Row(Seq(7, null, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5, null)), Row(Seq.empty), Row(null)) + ) checkAnswer( - idf.selectExpr("reverse(array(1, null, 2, null))"), - Seq.fill(idf.count().toInt)(Row(Seq(null, 2, null, 1))) + idfContainsNull.selectExpr("reverse(i)"), + Seq(Row(Seq(7, null, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5, null)), Row(Seq.empty), Row(null)) ) } // Test with local relation, the Project will be evaluated without codegen - testArray() + testArrayOfPrimitiveTypeContainsNull() // Test with cached relation, the Project will be evaluated with codegen - idf.cache() - testArray() + idfContainsNull.cache() + testArrayOfPrimitiveTypeContainsNull() + } - // Array test cases (non-primitive-type elements) + test("reverse function - array for non-primitive type") { val sdf = Seq( Seq("c", "a", "b"), Seq("b", null, "c", null), @@ -975,14 +998,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { // Test with cached relation, the Project will be evaluated with codegen sdf.cache() testArrayOfNonPrimitiveType() + } - // Error test cases - intercept[AnalysisException] { - oneRowDF.selectExpr("reverse(struct(1, 'a'))") + test("reverse function - data type mismatch") { + val ex1 = intercept[AnalysisException] { + sql("select reverse(struct(1, 'a'))") } - intercept[AnalysisException] { - oneRowDF.selectExpr("reverse(map(1, 'a'))") + assert(ex1.getMessage.contains("data type mismatch")) + + val ex2 = intercept[AnalysisException] { + sql("select reverse(map(1, 'a'))") } + assert(ex2.getMessage.contains("data type mismatch")) } test("array position function") { From 58353d7f4baa8102c3d2f4777a5c407f14993306 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 26 Jul 2018 16:11:03 +0800 Subject: [PATCH 1218/2461] [SPARK-24924][SQL] Add mapping for built-in Avro data source ## What changes were proposed in this pull request? This PR aims to the followings. 1. Like `com.databricks.spark.csv` mapping, we had better map `com.databricks.spark.avro` to built-in Avro data source. 2. Remove incorrect error message, `Please find an Avro package at ...`. ## How was this patch tested? Pass the newly added tests. Author: Dongjoon Hyun Closes #21878 from dongjoon-hyun/SPARK-24924. --- .../org/apache/spark/sql/avro/AvroSuite.scala | 10 +++++++++- .../sql/execution/datasources/DataSource.scala | 8 ++------ .../org/apache/spark/sql/SQLQuerySuite.scala | 16 ---------------- .../sql/sources/ResolvedDataSourceSuite.scala | 14 ++------------ 4 files changed, 13 insertions(+), 35 deletions(-) diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index ec1627a3898bf..865a145094853 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -33,6 +33,7 @@ import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} import org.apache.commons.io.FileUtils import org.apache.spark.sql._ +import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ @@ -51,6 +52,13 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { checkAnswer(newEntries, originalEntries) } + test("resolve avro data source") { + Seq("avro", "com.databricks.spark.avro").foreach { provider => + assert(DataSource.lookupDataSource(provider, spark.sessionState.conf) === + classOf[org.apache.spark.sql.avro.AvroFileFormat]) + } + } + test("reading from multiple paths") { val df = spark.read.format("avro").load(episodesAvro, episodesAvro) assert(df.count == 16) @@ -456,7 +464,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { // get the same values back. withTempPath { tempDir => val name = "AvroTest" - val namespace = "com.databricks.spark.avro" + val namespace = "org.apache.spark.avro" val parameters = Map("recordName" -> name, "recordNamespace" -> namespace) val avroDir = tempDir + "/namedAvro" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 0c3d9a4895fe2..b1a10fdb60207 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -571,6 +571,7 @@ object DataSource extends Logging { val nativeOrc = classOf[OrcFileFormat].getCanonicalName val socket = classOf[TextSocketSourceProvider].getCanonicalName val rate = classOf[RateStreamProvider].getCanonicalName + val avro = "org.apache.spark.sql.avro.AvroFileFormat" Map( "org.apache.spark.sql.jdbc" -> jdbc, @@ -592,6 +593,7 @@ object DataSource extends Logging { "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm, "org.apache.spark.ml.source.libsvm" -> libsvm, "com.databricks.spark.csv" -> csv, + "com.databricks.spark.avro" -> avro, "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket, "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate ) @@ -635,12 +637,6 @@ object DataSource extends Logging { "Hive built-in ORC data source must be used with Hive support enabled. " + "Please use the native ORC data source by setting 'spark.sql.orc.impl' to " + "'native'") - } else if (provider1.toLowerCase(Locale.ROOT) == "avro" || - provider1 == "com.databricks.spark.avro") { - throw new AnalysisException( - s"Failed to find data source: ${provider1.toLowerCase(Locale.ROOT)}. " + - "Please find an Avro package at " + - "http://spark.apache.org/third-party-projects.html") } else { throw new ClassNotFoundException( s"Failed to find data source: $provider1. Please find packages at " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index dfb9c137b74f0..86083d1701c2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1689,22 +1689,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } assert(e.message.contains("Hive built-in ORC data source must be used with Hive support")) - e = intercept[AnalysisException] { - sql(s"select id from `com.databricks.spark.avro`.`file_path`") - } - assert(e.message.contains("Failed to find data source: com.databricks.spark.avro.")) - - // data source type is case insensitive - e = intercept[AnalysisException] { - sql(s"select id from Avro.`file_path`") - } - assert(e.message.contains("Failed to find data source: avro.")) - - e = intercept[AnalysisException] { - sql(s"select id from avro.`file_path`") - } - assert(e.message.contains("Failed to find data source: avro.")) - e = intercept[AnalysisException] { sql(s"select id from `org.apache.spark.sql.sources.HadoopFsRelationProvider`.`file_path`") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 4adbff5c663bc..95460fa70d8f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -77,19 +77,9 @@ class ResolvedDataSourceSuite extends SparkFunSuite with SharedSQLContext { } test("error message for unknown data sources") { - val error1 = intercept[AnalysisException] { - getProvidingClass("avro") - } - assert(error1.getMessage.contains("Failed to find data source: avro.")) - - val error2 = intercept[AnalysisException] { - getProvidingClass("com.databricks.spark.avro") - } - assert(error2.getMessage.contains("Failed to find data source: com.databricks.spark.avro.")) - - val error3 = intercept[ClassNotFoundException] { + val error = intercept[ClassNotFoundException] { getProvidingClass("asfdwefasdfasdf") } - assert(error3.getMessage.contains("Failed to find data source: asfdwefasdfasdf.")) + assert(error.getMessage.contains("Failed to find data source: asfdwefasdfasdf.")) } } From 5ed7660d14022eb65396e28496c06e47c1dbab1d Mon Sep 17 00:00:00 2001 From: maryannxue Date: Thu, 26 Jul 2018 11:06:23 -0700 Subject: [PATCH 1219/2461] [SPARK-24802][SQL][FOLLOW-UP] Add a new config for Optimization Rule Exclusion ## What changes were proposed in this pull request? This is an extension to the original PR, in which rule exclusion did not work for classes derived from Optimizer, e.g., SparkOptimizer. To solve this issue, Optimizer and its derived classes will define/override `defaultBatches` and `nonExcludableRules` in order to define its default rule set as well as rules that cannot be excluded by the SQL config. In the meantime, Optimizer's `batches` method is dedicated to the rule exclusion logic and is defined "final". ## How was this patch tested? Added UT. Author: maryannxue Closes #21876 from maryannxue/rule-exclusion. --- .../sql/catalyst/optimizer/Optimizer.scala | 24 ++++++++- .../optimizer/OptimizerExtendableSuite.scala | 2 +- .../OptimizerRuleExclusionSuite.scala | 53 ++++++++++++++----- ...mizerStructuralIntegrityCheckerSuite.scala | 2 +- .../spark/sql/execution/SparkOptimizer.scala | 5 +- 5 files changed, 69 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index adb1350adc261..3c264eb8586b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -46,6 +46,13 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations) + /** + * Defines the default rule batches in the Optimizer. + * + * Implementations of this class should override this method, and [[nonExcludableRules]] if + * necessary, instead of [[batches]]. The rule batches that eventually run in the Optimizer, + * i.e., returned by [[batches]], will be (defaultBatches - (excludedRules - nonExcludableRules)). + */ def defaultBatches: Seq[Batch] = { val operatorOptimizationRuleSet = Seq( @@ -160,6 +167,14 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) UpdateNullabilityInAttributeReferences) } + /** + * Defines rules that cannot be excluded from the Optimizer even if they are specified in + * SQL config "excludedRules". + * + * Implementations of this class can override this method if necessary. The rule batches + * that eventually run in the Optimizer, i.e., returned by [[batches]], will be + * (defaultBatches - (excludedRules - nonExcludableRules)). + */ def nonExcludableRules: Seq[String] = EliminateDistinct.ruleName :: EliminateSubqueryAliases.ruleName :: @@ -202,7 +217,14 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) */ def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil - override def batches: Seq[Batch] = { + /** + * Returns (defaultBatches - (excludedRules - nonExcludableRules)), the rule batches that + * eventually run in the Optimizer. + * + * Implementations of this class should override [[defaultBatches]], and [[nonExcludableRules]] + * if necessary, instead of this method. + */ + final override def batches: Seq[Batch] = { val excludedRulesConf = SQLConf.get.optimizerExcludedRules.toSeq.flatMap(Utils.stringToSeq) val excludedRules = excludedRulesConf.filter { ruleName => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala index 7112c033eabce..36b083a540c3c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala @@ -47,7 +47,7 @@ class OptimizerExtendableSuite extends SparkFunSuite { DummyRule) :: Nil } - override def batches: Seq[Batch] = super.batches ++ myBatches + override def defaultBatches: Seq[Batch] = super.defaultBatches ++ myBatches } test("Extending batches possible") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala index 5a5396e6f58b0..30c80d26b67a1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala @@ -28,8 +28,10 @@ class OptimizerRuleExclusionSuite extends PlanTest { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - private def verifyExcludedRules(excludedRuleNames: Seq[String]) { - val optimizer = new SimpleTestOptimizer() + private def verifyExcludedRules(optimizer: Optimizer, rulesToExclude: Seq[String]) { + val nonExcludableRules = optimizer.nonExcludableRules + + val excludedRuleNames = rulesToExclude.filter(!nonExcludableRules.contains(_)) // Batches whose rules are all to be excluded should be removed as a whole. val excludedBatchNames = optimizer.batches .filter(batch => batch.rules.forall(rule => excludedRuleNames.contains(rule.ruleName))) @@ -38,21 +40,31 @@ class OptimizerRuleExclusionSuite extends PlanTest { withSQLConf( OPTIMIZER_EXCLUDED_RULES.key -> excludedRuleNames.foldLeft("")((l, r) => l + "," + r)) { val batches = optimizer.batches + // Verify removed batches. assert(batches.forall(batch => !excludedBatchNames.contains(batch.name))) + // Verify removed rules. assert( batches .forall(batch => batch.rules.forall(rule => !excludedRuleNames.contains(rule.ruleName)))) + // Verify non-excludable rules retained. + nonExcludableRules.foreach { nonExcludableRule => + assert( + optimizer.batches + .exists(batch => batch.rules.exists(rule => rule.ruleName == nonExcludableRule))) + } } } test("Exclude a single rule from multiple batches") { verifyExcludedRules( + new SimpleTestOptimizer(), Seq( PushPredicateThroughJoin.ruleName)) } test("Exclude multiple rules from single or multiple batches") { verifyExcludedRules( + new SimpleTestOptimizer(), Seq( CombineUnions.ruleName, RemoveLiteralFromGroupExpressions.ruleName, @@ -61,6 +73,7 @@ class OptimizerRuleExclusionSuite extends PlanTest { test("Exclude non-existent rule with other valid rules") { verifyExcludedRules( + new SimpleTestOptimizer(), Seq( LimitPushDown.ruleName, InferFiltersFromConstraints.ruleName, @@ -68,20 +81,34 @@ class OptimizerRuleExclusionSuite extends PlanTest { } test("Try to exclude a non-excludable rule") { - val excludedRules = Seq( - ReplaceIntersectWithSemiJoin.ruleName, - PullupCorrelatedPredicates.ruleName) + verifyExcludedRules( + new SimpleTestOptimizer(), + Seq( + ReplaceIntersectWithSemiJoin.ruleName, + PullupCorrelatedPredicates.ruleName)) + } - val optimizer = new SimpleTestOptimizer() + test("Custom optimizer") { + val optimizer = new SimpleTestOptimizer() { + override def defaultBatches: Seq[Batch] = + Batch("push", Once, + PushDownPredicate, + PushPredicateThroughJoin, + PushProjectionThroughUnion) :: + Batch("pull", Once, + PullupCorrelatedPredicates) :: Nil - withSQLConf( - OPTIMIZER_EXCLUDED_RULES.key -> excludedRules.foldLeft("")((l, r) => l + "," + r)) { - excludedRules.foreach { excludedRule => - assert( - optimizer.batches - .exists(batch => batch.rules.exists(rule => rule.ruleName == excludedRule))) - } + override def nonExcludableRules: Seq[String] = + PushDownPredicate.ruleName :: + PullupCorrelatedPredicates.ruleName :: Nil } + + verifyExcludedRules( + optimizer, + Seq( + PushDownPredicate.ruleName, + PushProjectionThroughUnion.ruleName, + PullupCorrelatedPredicates.ruleName)) } test("Verify optimized plan after excluding CombineUnions rule") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala index 6e183d81b7265..a22a81e9844d3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala @@ -44,7 +44,7 @@ class OptimizerStructuralIntegrityCheckerSuite extends PlanTest { EmptyFunctionRegistry, new SQLConf())) { val newBatch = Batch("OptimizeRuleBreakSI", Once, OptimizeRuleBreakSI) - override def batches: Seq[Batch] = Seq(newBatch) ++ super.batches + override def defaultBatches: Seq[Batch] = Seq(newBatch) ++ super.defaultBatches } test("check for invalid plan after execution of rule") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 00ff4c8ac310b..64d3f2cdbfa82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -28,13 +28,16 @@ class SparkOptimizer( experimentalMethods: ExperimentalMethods) extends Optimizer(catalog) { - override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ + override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++ postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) + override def nonExcludableRules: Seq[String] = + super.nonExcludableRules :+ ExtractPythonUDFFromAggregate.ruleName + /** * Optimization batches that are executed before the regular optimization batches (also before * the finish analysis batch). From e3486e1b9556e00bc9c392a5b8440ab366780f9b Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 26 Jul 2018 12:09:01 -0700 Subject: [PATCH 1220/2461] [SPARK-24795][CORE] Implement barrier execution mode ## What changes were proposed in this pull request? Propose new APIs and modify job/task scheduling to support barrier execution mode, which requires all tasks in a same barrier stage start at the same time, and retry all tasks in case some tasks fail in the middle. The barrier execution mode is useful for some ML/DL workloads. The proposed API changes include: - `RDDBarrier` that marks an RDD as barrier (Spark must launch all the tasks together for the current stage). - `BarrierTaskContext` that support global sync of all tasks in a barrier stage, and provide extra `BarrierTaskInfo`s. In DAGScheduler, we retry all tasks of a barrier stage in case some tasks fail in the middle, this is achieved by unregistering map outputs for a shuffleId (for ShuffleMapStage) or clear the finished partitions in an active job (for ResultStage). ## How was this patch tested? Add `RDDBarrierSuite` to ensure we convert RDDs correctly; Add new test cases in `DAGSchedulerSuite` to ensure we do task scheduling correctly; Add new test cases in `SparkContextSuite` to ensure the barrier execution mode actually works (both under local mode and local cluster mode). Add new test cases in `TaskSchedulerImplSuite` to ensure we schedule tasks for barrier taskSet together. Author: Xingbo Jiang Closes #21758 from jiangxb1987/barrier-execution-mode. --- .../org/apache/spark/BarrierTaskContext.scala | 42 ++++++ .../apache/spark/BarrierTaskContextImpl.scala | 49 +++++++ .../org/apache/spark/BarrierTaskInfo.scala | 31 +++++ .../org/apache/spark/MapOutputTracker.scala | 12 ++ .../apache/spark/rdd/MapPartitionsRDD.scala | 15 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 27 +++- .../org/apache/spark/rdd/RDDBarrier.scala | 52 +++++++ .../org/apache/spark/rdd/ShuffledRDD.scala | 2 + .../apache/spark/scheduler/ActiveJob.scala | 6 + .../apache/spark/scheduler/DAGScheduler.scala | 131 +++++++++++++++--- .../apache/spark/scheduler/ResultTask.scala | 9 +- .../spark/scheduler/ShuffleMapTask.scala | 7 +- .../org/apache/spark/scheduler/Stage.scala | 8 +- .../org/apache/spark/scheduler/Task.scala | 41 ++++-- .../spark/scheduler/TaskDescription.scala | 7 +- .../spark/scheduler/TaskSchedulerImpl.scala | 66 +++++++-- .../spark/scheduler/TaskSetManager.scala | 9 +- .../apache/spark/scheduler/WorkerOffer.scala | 8 +- .../CoarseGrainedSchedulerBackend.scala | 6 +- .../local/LocalSchedulerBackend.scala | 3 +- .../org/apache/spark/SparkContextSuite.scala | 42 ++++++ .../apache/spark/executor/ExecutorSuite.scala | 1 + .../apache/spark/rdd/RDDBarrierSuite.scala | 43 ++++++ .../spark/scheduler/DAGSchedulerSuite.scala | 58 ++++++++ .../org/apache/spark/scheduler/FakeTask.scala | 24 +++- .../scheduler/TaskDescriptionSuite.scala | 2 + .../scheduler/TaskSchedulerImplSuite.scala | 34 +++++ ...esosFineGrainedSchedulerBackendSuite.scala | 2 + 28 files changed, 673 insertions(+), 64 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/BarrierTaskContext.scala create mode 100644 core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala create mode 100644 core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala create mode 100644 core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala new file mode 100644 index 0000000000000..4c358629dee96 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.annotation.{Experimental, Since} + +/** A [[TaskContext]] with extra info and tooling for a barrier stage. */ +trait BarrierTaskContext extends TaskContext { + + /** + * :: Experimental :: + * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to + * MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same + * stage have reached this routine. + */ + @Experimental + @Since("2.4.0") + def barrier(): Unit + + /** + * :: Experimental :: + * Returns the all task infos in this barrier stage, the task infos are ordered by partitionId. + */ + @Experimental + @Since("2.4.0") + def getTaskInfos(): Array[BarrierTaskInfo] +} diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala new file mode 100644 index 0000000000000..8ac705757a382 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.Properties + +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.metrics.MetricsSystem + +/** A [[BarrierTaskContext]] implementation. */ +private[spark] class BarrierTaskContextImpl( + override val stageId: Int, + override val stageAttemptNumber: Int, + override val partitionId: Int, + override val taskAttemptId: Long, + override val attemptNumber: Int, + override val taskMemoryManager: TaskMemoryManager, + localProperties: Properties, + @transient private val metricsSystem: MetricsSystem, + // The default value is only used in tests. + override val taskMetrics: TaskMetrics = TaskMetrics.empty) + extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber, + taskMemoryManager, localProperties, metricsSystem, taskMetrics) + with BarrierTaskContext { + + // TODO SPARK-24817 implement global barrier. + override def barrier(): Unit = {} + + override def getTaskInfos(): Array[BarrierTaskInfo] = { + val addressesStr = localProperties.getProperty("addresses", "") + addressesStr.split(",").map(_.trim()).map(new BarrierTaskInfo(_)) + } +} diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala b/core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala new file mode 100644 index 0000000000000..ce2653df2e845 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.annotation.{Experimental, Since} + + +/** + * :: Experimental :: + * Carries all task infos of a barrier task. + * + * @param address the IPv4 address(host:port) of the executor that a barrier task is running on + */ +@Experimental +@Since("2.4.0") +class BarrierTaskInfo(val address: String) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 73646051f264c..1c4fa4bc6541f 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -434,6 +434,18 @@ private[spark] class MapOutputTrackerMaster( } } + /** Unregister all map output information of the given shuffle. */ + def unregisterAllMapOutput(shuffleId: Int) { + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + shuffleStatus.removeOutputsByFilter(x => true) + incrementEpoch() + case None => + throw new SparkException( + s"unregisterAllMapOutput called for nonexistent shuffle ID $shuffleId.") + } + } + /** Unregister shuffle data */ def unregisterShuffle(shuffleId: Int) { shuffleStatuses.remove(shuffleId).foreach { shuffleStatus => diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index e4587c96eae1c..904d9c025629f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -23,11 +23,21 @@ import org.apache.spark.{Partition, TaskContext} /** * An RDD that applies the provided function to every partition of the parent RDD. + * + * @param prev the parent RDD. + * @param f The function used to map a tuple of (TaskContext, partition index, input iterator) to + * an output iterator. + * @param preservesPartitioning Whether the input function preserves the partitioner, which should + * be `false` unless `prev` is a pair RDD and the input function + * doesn't modify the keys. + * @param isFromBarrier Indicates whether this RDD is transformed from an RDDBarrier, a stage + * containing at least one RDDBarrier shall be turned into a barrier stage. */ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( var prev: RDD[T], f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator) - preservesPartitioning: Boolean = false) + preservesPartitioning: Boolean = false, + isFromBarrier: Boolean = false) extends RDD[U](prev) { override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None @@ -41,4 +51,7 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( super.clearDependencies() prev = null } + + @transient protected lazy override val isBarrier_ : Boolean = + isFromBarrier || dependencies.exists(_.rdd.isBarrier()) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 0574abdca32ac..cbc1143126d8e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -33,7 +33,7 @@ import org.apache.hadoop.mapred.TextOutputFormat import org.apache.spark._ import org.apache.spark.Partitioner._ -import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.partial.BoundedDouble @@ -1647,6 +1647,14 @@ abstract class RDD[T: ClassTag]( } } + /** + * :: Experimental :: + * Indicates that Spark must launch the tasks together for the current stage. + */ + @Experimental + @Since("2.4.0") + def barrier(): RDDBarrier[T] = withScope(new RDDBarrier[T](this)) + // ======================================================================= // Other internal methods and fields // ======================================================================= @@ -1839,6 +1847,23 @@ abstract class RDD[T: ClassTag]( def toJavaRDD() : JavaRDD[T] = { new JavaRDD(this)(elementClassTag) } + + /** + * Whether the RDD is in a barrier stage. Spark must launch all the tasks at the same time for a + * barrier stage. + * + * An RDD is in a barrier stage, if at least one of its parent RDD(s), or itself, are mapped from + * an [[RDDBarrier]]. This function always returns false for a [[ShuffledRDD]], since a + * [[ShuffledRDD]] indicates start of a new stage. + * + * A [[MapPartitionsRDD]] can be transformed from an [[RDDBarrier]], under that case the + * [[MapPartitionsRDD]] shall be marked as barrier. + */ + private[spark] def isBarrier(): Boolean = isBarrier_ + + // From performance concern, cache the value to avoid repeatedly compute `isBarrier()` on a long + // RDD chain. + @transient protected lazy val isBarrier_ : Boolean = dependencies.exists(_.rdd.isBarrier()) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala new file mode 100644 index 0000000000000..85565d16e2717 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.reflect.ClassTag + +import org.apache.spark.BarrierTaskContext +import org.apache.spark.TaskContext +import org.apache.spark.annotation.{Experimental, Since} + +/** Represents an RDD barrier, which forces Spark to launch tasks of this stage together. */ +class RDDBarrier[T: ClassTag](rdd: RDD[T]) { + + /** + * :: Experimental :: + * Maps partitions together with a provided BarrierTaskContext. + * + * `preservesPartitioning` indicates whether the input function preserves the partitioner, which + * should be `false` unless `rdd` is a pair RDD and the input function doesn't modify the keys. + */ + @Experimental + @Since("2.4.0") + def mapPartitions[S: ClassTag]( + f: (Iterator[T], BarrierTaskContext) => Iterator[S], + preservesPartitioning: Boolean = false): RDD[S] = rdd.withScope { + val cleanedF = rdd.sparkContext.clean(f) + new MapPartitionsRDD( + rdd, + (context: TaskContext, index: Int, iter: Iterator[T]) => + cleanedF(iter, context.asInstanceOf[BarrierTaskContext]), + preservesPartitioning, + isFromBarrier = true + ) + } + + /** TODO extra conf(e.g. timeout) */ +} diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 26eaa9aa3d03f..e8f9b27b7eb55 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -110,4 +110,6 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag]( super.clearDependencies() prev = null } + + private[spark] override def isBarrier(): Boolean = false } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala index 949e88f606275..6e4d062749d5f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala @@ -60,4 +60,10 @@ private[spark] class ActiveJob( val finished = Array.fill[Boolean](numPartitions)(false) var numFinished = 0 + + /** Resets the status of all partitions in this stage so they are marked as not finished. */ + def resetAllPartitions(): Unit = { + (0 until numPartitions).foreach(finished.update(_, false)) + numFinished = 0 + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index f74425d73b392..003d64f78e853 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1062,7 +1062,7 @@ class DAGScheduler( stage.pendingPartitions += id new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId), - Option(sc.applicationId), sc.applicationAttemptId) + Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier()) } case stage: ResultStage => @@ -1072,7 +1072,8 @@ class DAGScheduler( val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, id, properties, serializedTaskMetrics, - Option(jobId), Option(sc.applicationId), sc.applicationAttemptId) + Option(jobId), Option(sc.applicationId), sc.applicationAttemptId, + stage.rdd.isBarrier()) } } } catch { @@ -1311,17 +1312,6 @@ class DAGScheduler( } } - case Resubmitted => - logInfo("Resubmitted " + task + ", so marking it as still running") - stage match { - case sms: ShuffleMapStage => - sms.pendingPartitions += task.partitionId - - case _ => - assert(false, "TaskSetManagers should only send Resubmitted task statuses for " + - "tasks in ShuffleMapStages.") - } - case FetchFailed(bmAddress, shuffleId, mapId, _, failureMessage) => val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleIdToMapStage(shuffleId) @@ -1331,9 +1321,9 @@ class DAGScheduler( s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + s"(attempt ${failedStage.latestInfo.attemptNumber}) running") } else { - failedStage.fetchFailedAttemptIds.add(task.stageAttemptId) + failedStage.failedAttemptIds.add(task.stageAttemptId) val shouldAbortStage = - failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts || + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || disallowStageRetryForTest // It is likely that we receive multiple FetchFailed for a single stage (because we have @@ -1349,6 +1339,29 @@ class DAGScheduler( s"longer running") } + if (mapStage.rdd.isBarrier()) { + // Mark all the map as broken in the map stage, to ensure retry all the tasks on + // resubmitted stage attempt. + mapOutputTracker.unregisterAllMapOutput(shuffleId) + } else if (mapId != -1) { + // Mark the map whose fetch failed as broken in the map stage + mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) + } + + if (failedStage.rdd.isBarrier()) { + failedStage match { + case failedMapStage: ShuffleMapStage => + // Mark all the map as broken in the map stage, to ensure retry all the tasks on + // resubmitted stage attempt. + mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId) + + case failedResultStage: ResultStage => + // Mark all the partitions of the result stage to be not finished, to ensure retry + // all the tasks on resubmitted stage attempt. + failedResultStage.activeJob.map(_.resetAllPartitions()) + } + } + if (shouldAbortStage) { val abortMessage = if (disallowStageRetryForTest) { "Fetch failure will not retry stage due to testing config" @@ -1375,7 +1388,7 @@ class DAGScheduler( // simpler while not producing an overwhelming number of scheduler events. logInfo( s"Resubmitting $mapStage (${mapStage.name}) and " + - s"$failedStage (${failedStage.name}) due to fetch failure" + s"$failedStage (${failedStage.name}) due to fetch failure" ) messageScheduler.schedule( new Runnable { @@ -1386,10 +1399,6 @@ class DAGScheduler( ) } } - // Mark the map whose fetch failed as broken in the map stage - if (mapId != -1) { - mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) - } // TODO: mark the executor as failed only if there were lots of fetch failures on it if (bmAddress != null) { @@ -1411,6 +1420,76 @@ class DAGScheduler( } } + case failure: TaskFailedReason if task.isBarrier => + // Also handle the task failed reasons here. + failure match { + case Resubmitted => + handleResubmittedFailure(task, stage) + + case _ => // Do nothing. + } + + // Always fail the current stage and retry all the tasks when a barrier task fail. + val failedStage = stageIdToStage(task.stageId) + logInfo(s"Marking $failedStage (${failedStage.name}) as failed due to a barrier task " + + "failed.") + val message = s"Stage failed because barrier task $task finished unsuccessfully. " + + failure.toErrorString + try { + // cancelTasks will fail if a SchedulerBackend does not implement killTask + taskScheduler.cancelTasks(stageId, interruptThread = false) + } catch { + case e: UnsupportedOperationException => + // Cannot continue with barrier stage if failed to cancel zombie barrier tasks. + // TODO SPARK-24877 leave the zombie tasks and ignore their completion events. + logWarning(s"Could not cancel tasks for stage $stageId", e) + abortStage(failedStage, "Could not cancel zombie barrier tasks for stage " + + s"$failedStage (${failedStage.name})", Some(e)) + } + markStageAsFinished(failedStage, Some(message)) + + failedStage.failedAttemptIds.add(task.stageAttemptId) + // TODO Refactor the failure handling logic to combine similar code with that of + // FetchFailed. + val shouldAbortStage = + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || + disallowStageRetryForTest + + if (shouldAbortStage) { + val abortMessage = if (disallowStageRetryForTest) { + "Barrier stage will not retry stage due to testing config" + } else { + s"""$failedStage (${failedStage.name}) + |has failed the maximum allowable number of + |times: $maxConsecutiveStageAttempts. + |Most recent failure reason: $message""".stripMargin.replaceAll("\n", " ") + } + abortStage(failedStage, abortMessage, None) + } else { + failedStage match { + case failedMapStage: ShuffleMapStage => + // Mark all the map as broken in the map stage, to ensure retry all the tasks on + // resubmitted stage attempt. + mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId) + + case failedResultStage: ResultStage => + // Mark all the partitions of the result stage to be not finished, to ensure retry + // all the tasks on resubmitted stage attempt. + failedResultStage.activeJob.map(_.resetAllPartitions()) + } + + // update failedStages and make sure a ResubmitFailedStages event is enqueued + failedStages += failedStage + logInfo(s"Resubmitting $failedStage (${failedStage.name}) due to barrier stage " + + "failure.") + messageScheduler.schedule(new Runnable { + override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) + }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) + } + + case Resubmitted => + handleResubmittedFailure(task, stage) + case _: TaskCommitDenied => // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits @@ -1426,6 +1505,18 @@ class DAGScheduler( } } + private def handleResubmittedFailure(task: Task[_], stage: Stage): Unit = { + logInfo(s"Resubmitted $task, so marking it as still running.") + stage match { + case sms: ShuffleMapStage => + sms.pendingPartitions += task.partitionId + + case _ => + throw new SparkException("TaskSetManagers should only send Resubmitted task " + + "statuses for tasks in ShuffleMapStages.") + } + } + private[scheduler] def markMapStageJobsAsFinished(shuffleStage: ShuffleMapStage): Unit = { // Mark any map-stage jobs waiting on this stage as finished if (shuffleStage.isAvailable && shuffleStage.mapStageJobs.nonEmpty) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index e36c759a42556..aafeae05b566c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -48,7 +48,9 @@ import org.apache.spark.rdd.RDD * @param jobId id of the job this task belongs to * @param appId id of the app this task belongs to * @param appAttemptId attempt id of the app this task belongs to - */ + * @param isBarrier whether this task belongs to a barrier stage. Spark must launch all the tasks + * at the same time for a barrier stage. + */ private[spark] class ResultTask[T, U]( stageId: Int, stageAttemptId: Int, @@ -60,9 +62,10 @@ private[spark] class ResultTask[T, U]( serializedTaskMetrics: Array[Byte], jobId: Option[Int] = None, appId: Option[String] = None, - appAttemptId: Option[String] = None) + appAttemptId: Option[String] = None, + isBarrier: Boolean = false) extends Task[U](stageId, stageAttemptId, partition.index, localProperties, serializedTaskMetrics, - jobId, appId, appAttemptId) + jobId, appId, appAttemptId, isBarrier) with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 7a25c47e2cab3..f2cd65fd523ab 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -49,6 +49,8 @@ import org.apache.spark.shuffle.ShuffleWriter * @param jobId id of the job this task belongs to * @param appId id of the app this task belongs to * @param appAttemptId attempt id of the app this task belongs to + * @param isBarrier whether this task belongs to a barrier stage. Spark must launch all the tasks + * at the same time for a barrier stage. */ private[spark] class ShuffleMapTask( stageId: Int, @@ -60,9 +62,10 @@ private[spark] class ShuffleMapTask( serializedTaskMetrics: Array[Byte], jobId: Option[Int] = None, appId: Option[String] = None, - appAttemptId: Option[String] = None) + appAttemptId: Option[String] = None, + isBarrier: Boolean = false) extends Task[MapStatus](stageId, stageAttemptId, partition.index, localProperties, - serializedTaskMetrics, jobId, appId, appAttemptId) + serializedTaskMetrics, jobId, appId, appAttemptId, isBarrier) with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 290fd073caf27..26cca334d3bd5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -82,15 +82,15 @@ private[scheduler] abstract class Stage( private var _latestInfo: StageInfo = StageInfo.fromStage(this, nextAttemptId) /** - * Set of stage attempt IDs that have failed with a FetchFailure. We keep track of these - * failures in order to avoid endless retries if a stage keeps failing with a FetchFailure. + * Set of stage attempt IDs that have failed. We keep track of these failures in order to avoid + * endless retries if a stage keeps failing. * We keep track of each attempt ID that has failed to avoid recording duplicate failures if * multiple tasks from the same stage attempt fail (SPARK-5945). */ - val fetchFailedAttemptIds = new HashSet[Int] + val failedAttemptIds = new HashSet[Int] private[scheduler] def clearFailures() : Unit = { - fetchFailedAttemptIds.clear() + failedAttemptIds.clear() } /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index f536fc2a5f0a1..89ff2038e5f8a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -49,6 +49,8 @@ import org.apache.spark.util._ * @param jobId id of the job this task belongs to * @param appId id of the app this task belongs to * @param appAttemptId attempt id of the app this task belongs to + * @param isBarrier whether this task belongs to a barrier stage. Spark must launch all the tasks + * at the same time for a barrier stage. */ private[spark] abstract class Task[T]( val stageId: Int, @@ -60,7 +62,8 @@ private[spark] abstract class Task[T]( SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array(), val jobId: Option[Int] = None, val appId: Option[String] = None, - val appAttemptId: Option[String] = None) extends Serializable { + val appAttemptId: Option[String] = None, + val isBarrier: Boolean = false) extends Serializable { @transient lazy val metrics: TaskMetrics = SparkEnv.get.closureSerializer.newInstance().deserialize(ByteBuffer.wrap(serializedTaskMetrics)) @@ -77,16 +80,32 @@ private[spark] abstract class Task[T]( attemptNumber: Int, metricsSystem: MetricsSystem): T = { SparkEnv.get.blockManager.registerTask(taskAttemptId) - context = new TaskContextImpl( - stageId, - stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal - partitionId, - taskAttemptId, - attemptNumber, - taskMemoryManager, - localProperties, - metricsSystem, - metrics) + // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether + // the stage is barrier. + context = if (isBarrier) { + new BarrierTaskContextImpl( + stageId, + stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal + partitionId, + taskAttemptId, + attemptNumber, + taskMemoryManager, + localProperties, + metricsSystem, + metrics) + } else { + new TaskContextImpl( + stageId, + stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal + partitionId, + taskAttemptId, + attemptNumber, + taskMemoryManager, + localProperties, + metricsSystem, + metrics) + } + TaskContext.setTaskContext(context) taskThread = Thread.currentThread() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index c98b87148e404..bb4a4442b9433 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -50,6 +50,7 @@ private[spark] class TaskDescription( val executorId: String, val name: String, val index: Int, // Index within this task's TaskSet + val partitionId: Int, val addedFiles: Map[String, Long], val addedJars: Map[String, Long], val properties: Properties, @@ -76,6 +77,7 @@ private[spark] object TaskDescription { dataOut.writeUTF(taskDescription.executorId) dataOut.writeUTF(taskDescription.name) dataOut.writeInt(taskDescription.index) + dataOut.writeInt(taskDescription.partitionId) // Write files. serializeStringLongMap(taskDescription.addedFiles, dataOut) @@ -117,6 +119,7 @@ private[spark] object TaskDescription { val executorId = dataIn.readUTF() val name = dataIn.readUTF() val index = dataIn.readInt() + val partitionId = dataIn.readInt() // Read files. val taskFiles = deserializeStringLongMap(dataIn) @@ -138,7 +141,7 @@ private[spark] object TaskDescription { // Create a sub-buffer for the serialized task into its own buffer (to be deserialized later). val serializedTask = byteBuffer.slice() - new TaskDescription(taskId, attemptNumber, executorId, name, index, taskFiles, taskJars, - properties, serializedTask) + new TaskDescription(taskId, attemptNumber, executorId, name, index, partitionId, taskFiles, + taskJars, properties, serializedTask) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 56c0bf6c09351..587ed4b5243b7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -274,7 +274,8 @@ private[spark] class TaskSchedulerImpl( maxLocality: TaskLocality, shuffledOffers: Seq[WorkerOffer], availableCpus: Array[Int], - tasks: IndexedSeq[ArrayBuffer[TaskDescription]]) : Boolean = { + tasks: IndexedSeq[ArrayBuffer[TaskDescription]], + addressesWithDescs: ArrayBuffer[(String, TaskDescription)]) : Boolean = { var launchedTask = false // nodes and executors that are blacklisted for the entire application have already been // filtered out by this point @@ -291,6 +292,11 @@ private[spark] class TaskSchedulerImpl( executorIdToRunningTaskIds(execId).add(tid) availableCpus(i) -= CPUS_PER_TASK assert(availableCpus(i) >= 0) + // Only update hosts for a barrier task. + if (taskSet.isBarrier) { + // The executor address is expected to be non empty. + addressesWithDescs += (shuffledOffers(i).address.get -> task) + } launchedTask = true } } catch { @@ -346,6 +352,7 @@ private[spark] class TaskSchedulerImpl( // Build a list of tasks to assign to each worker. val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores / CPUS_PER_TASK)) val availableCpus = shuffledOffers.map(o => o.cores).toArray + val availableSlots = shuffledOffers.map(o => o.cores / CPUS_PER_TASK).sum val sortedTaskSets = rootPool.getSortedTaskSetQueue for (taskSet <- sortedTaskSets) { logDebug("parentName: %s, name: %s, runningTasks: %s".format( @@ -359,20 +366,55 @@ private[spark] class TaskSchedulerImpl( // of locality levels so that it gets a chance to launch local tasks on all of them. // NOTE: the preferredLocality order: PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY for (taskSet <- sortedTaskSets) { - var launchedAnyTask = false - var launchedTaskAtCurrentMaxLocality = false - for (currentMaxLocality <- taskSet.myLocalityLevels) { - do { - launchedTaskAtCurrentMaxLocality = resourceOfferSingleTaskSet( - taskSet, currentMaxLocality, shuffledOffers, availableCpus, tasks) - launchedAnyTask |= launchedTaskAtCurrentMaxLocality - } while (launchedTaskAtCurrentMaxLocality) - } - if (!launchedAnyTask) { - taskSet.abortIfCompletelyBlacklisted(hostToExecutors) + // Skip the barrier taskSet if the available slots are less than the number of pending tasks. + if (taskSet.isBarrier && availableSlots < taskSet.numTasks) { + // Skip the launch process. + // TODO SPARK-24819 If the job requires more slots than available (both busy and free + // slots), fail the job on submit. + logInfo(s"Skip current round of resource offers for barrier stage ${taskSet.stageId} " + + s"because the barrier taskSet requires ${taskSet.numTasks} slots, while the total " + + s"number of available slots is $availableSlots.") + } else { + var launchedAnyTask = false + // Record all the executor IDs assigned barrier tasks on. + val addressesWithDescs = ArrayBuffer[(String, TaskDescription)]() + for (currentMaxLocality <- taskSet.myLocalityLevels) { + var launchedTaskAtCurrentMaxLocality = false + do { + launchedTaskAtCurrentMaxLocality = resourceOfferSingleTaskSet(taskSet, + currentMaxLocality, shuffledOffers, availableCpus, tasks, addressesWithDescs) + launchedAnyTask |= launchedTaskAtCurrentMaxLocality + } while (launchedTaskAtCurrentMaxLocality) + } + if (!launchedAnyTask) { + taskSet.abortIfCompletelyBlacklisted(hostToExecutors) + } + if (launchedAnyTask && taskSet.isBarrier) { + // Check whether the barrier tasks are partially launched. + // TODO SPARK-24818 handle the assert failure case (that can happen when some locality + // requirements are not fulfilled, and we should revert the launched tasks). + require(addressesWithDescs.size == taskSet.numTasks, + s"Skip current round of resource offers for barrier stage ${taskSet.stageId} " + + s"because only ${addressesWithDescs.size} out of a total number of " + + s"${taskSet.numTasks} tasks got resource offers. The resource offers may have " + + "been blacklisted or cannot fulfill task locality requirements.") + + // Update the taskInfos into all the barrier task properties. + val addressesStr = addressesWithDescs + // Addresses ordered by partitionId + .sortBy(_._2.partitionId) + .map(_._1) + .mkString(",") + addressesWithDescs.foreach(_._2.properties.setProperty("addresses", addressesStr)) + + logInfo(s"Successfully scheduled all the ${addressesWithDescs.size} tasks for barrier " + + s"stage ${taskSet.stageId}.") + } } } + // TODO SPARK-24823 Cancel a job that contains barrier stage(s) if the barrier tasks don't get + // launched within a configured time. if (tasks.size > 0) { hasLaunchedTask = true } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index defed1e0f9c6c..0b21256ab6cce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -123,6 +123,10 @@ private[spark] class TaskSetManager( // TODO: We should kill any running task attempts when the task set manager becomes a zombie. private[scheduler] var isZombie = false + // Whether the taskSet run tasks from a barrier stage. Spark must launch all the tasks at the + // same time for a barrier stage. + private[scheduler] def isBarrier = taskSet.tasks.nonEmpty && taskSet.tasks(0).isBarrier + // Set of pending tasks for each executor. These collections are actually // treated as stacks, in which new tasks are added to the end of the // ArrayBuffer and removed from the end. This makes it faster to detect @@ -512,6 +516,7 @@ private[spark] class TaskSetManager( execId, taskName, index, + task.partitionId, addedFiles, addedJars, task.localProperties, @@ -979,8 +984,8 @@ private[spark] class TaskSetManager( */ override def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean = { // Can't speculate if we only have one task, and no need to speculate if the task set is a - // zombie. - if (isZombie || numTasks == 1) { + // zombie or is from a barrier stage. + if (isZombie || isBarrier || numTasks == 1) { return false } var foundTasks = false diff --git a/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala index 810b36cddf835..6ec74913e42f2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala @@ -21,4 +21,10 @@ package org.apache.spark.scheduler * Represents free resources available on an executor. */ private[spark] -case class WorkerOffer(executorId: String, host: String, cores: Int) +case class WorkerOffer( + executorId: String, + host: String, + cores: Int, + // `address` is an optional hostPort string, it provide more useful information than `host` + // when multiple executors are launched on the same host. + address: Option[String] = None) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 9b90e309d2e04..375aeb0c34661 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -242,7 +242,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val activeExecutors = executorDataMap.filterKeys(executorIsAlive) val workOffers = activeExecutors.map { case (id, executorData) => - new WorkerOffer(id, executorData.executorHost, executorData.freeCores) + new WorkerOffer(id, executorData.executorHost, executorData.freeCores, + Some(executorData.executorAddress.hostPort)) }.toIndexedSeq scheduler.resourceOffers(workOffers) } @@ -267,7 +268,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp if (executorIsAlive(executorId)) { val executorData = executorDataMap(executorId) val workOffers = IndexedSeq( - new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)) + new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores, + Some(executorData.executorAddress.hostPort))) scheduler.resourceOffers(workOffers) } else { Seq.empty diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala index 4c614c5c0f602..cf8b0ff4f7019 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala @@ -81,7 +81,8 @@ private[spark] class LocalEndpoint( } def reviveOffers() { - val offers = IndexedSeq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) + val offers = IndexedSeq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores, + Some(rpcEnv.address.hostPort))) for (task <- scheduler.resourceOffers(offers).flatten) { freeCores -= scheduler.CPUS_PER_TASK executor.launchTask(executorBackend, task) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index ce9f2be1c02dd..e5f31a04c6e7e 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -627,6 +627,48 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu assert(exc.getCause() != null) stream.close() } + + test("support barrier execution mode under local mode") { + val conf = new SparkConf().setAppName("test").setMaster("local[2]") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(Seq(1, 2, 3, 4), 2) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + // If we don't get the expected taskInfos, the job shall abort due to stage failure. + if (context.getTaskInfos().length != 2) { + throw new SparkException("Expected taksInfos length is 2, actual length is " + + s"${context.getTaskInfos().length}.") + } + context.barrier() + it + } + rdd2.collect() + + eventually(timeout(10.seconds)) { + assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } + } + + test("support barrier execution mode under local-cluster mode") { + val conf = new SparkConf() + .setMaster("local-cluster[3, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(Seq(1, 2, 3, 4), 2) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + // If we don't get the expected taskInfos, the job shall abort due to stage failure. + if (context.getTaskInfos().length != 2) { + throw new SparkException("Expected taksInfos length is 2, actual length is " + + s"${context.getTaskInfos().length}.") + } + context.barrier() + it + } + rdd2.collect() + + eventually(timeout(10.seconds)) { + assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } + } } object SparkContextSuite { diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 1a7bebe2c53cd..77a7668d3a1d1 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -275,6 +275,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug executorId = "", name = "", index = 0, + partitionId = 0, addedFiles = Map[String, Long](), addedJars = Map[String, Long](), properties = new Properties, diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala new file mode 100644 index 0000000000000..39d4618e4c6c5 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.{SharedSparkContext, SparkFunSuite} + +class RDDBarrierSuite extends SparkFunSuite with SharedSparkContext { + + test("create an RDDBarrier") { + val rdd = sc.parallelize(1 to 10, 4) + assert(rdd.isBarrier() === false) + + val rdd2 = rdd.barrier().mapPartitions((iter, context) => iter) + assert(rdd2.isBarrier() === true) + } + + test("create an RDDBarrier in the middle of a chain of RDDs") { + val rdd = sc.parallelize(1 to 10, 4).map(x => x * 2) + val rdd2 = rdd.barrier().mapPartitions((iter, context) => iter).map(x => (x, x + 1)) + assert(rdd2.isBarrier() === true) + } + + test("RDDBarrier with shuffle") { + val rdd = sc.parallelize(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions((iter, context) => iter).repartition(2) + assert(rdd2.isBarrier() === false) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 2987170bf5026..b3db5e29fb82e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1055,6 +1055,64 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(sparkListener.failedStages.size == 1) } + test("Retry all the tasks on a resubmitted attempt of a barrier stage caused by FetchFailure") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions((it, context) => it) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + submit(reduceRdd, Array(0, 1)) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty)) + + // The first result task fails, with a fetch failure for the output from the first mapper. + runEvent(makeCompletionEvent( + taskSets(1).tasks(0), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + null)) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq(0, 1))) + + scheduler.resubmitFailedStages() + // Complete the map stage. + completeShuffleMapStageSuccessfully(0, 1, numShufflePartitions = 2) + + // Complete the result stage. + completeNextResultStageWithSuccess(1, 1) + + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assertDataStructuresEmpty() + } + + test("Retry all the tasks on a resubmitted attempt of a barrier stage caused by TaskKilled") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions((it, context) => it) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + submit(reduceRdd, Array(0, 1)) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq(1))) + + // The second map task fails with TaskKilled. + runEvent(makeCompletionEvent( + taskSets(0).tasks(1), + TaskKilled("test"), + null)) + assert(sparkListener.failedStages === Seq(0)) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq(0, 1))) + + scheduler.resubmitFailedStages() + // Complete the map stage. + completeShuffleMapStageSuccessfully(0, 1, numShufflePartitions = 2) + + // Complete the result stage. + completeNextResultStageWithSuccess(1, 0) + + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assertDataStructuresEmpty() + } + /** * This tests the case where another FetchFailed comes in while the map stage is getting * re-run. diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index 109d4a0a870b8..b29d32f7b35c5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -27,8 +27,10 @@ class FakeTask( partitionId: Int, prefLocs: Seq[TaskLocation] = Nil, serializedTaskMetrics: Array[Byte] = - SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array()) - extends Task[Int](stageId, 0, partitionId, new Properties, serializedTaskMetrics) { + SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array(), + isBarrier: Boolean = false) + extends Task[Int](stageId, 0, partitionId, new Properties, serializedTaskMetrics, + isBarrier = isBarrier) { override def runTask(context: TaskContext): Int = 0 override def preferredLocations: Seq[TaskLocation] = prefLocs @@ -74,4 +76,22 @@ object FakeTask { } new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null) } + + def createBarrierTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { + createBarrierTaskSet(numTasks, stageId = 0, stageAttempId = 0, prefLocs: _*) + } + + def createBarrierTaskSet( + numTasks: Int, + stageId: Int, + stageAttempId: Int, + prefLocs: Seq[TaskLocation]*): TaskSet = { + if (prefLocs.size != 0 && prefLocs.size != numTasks) { + throw new IllegalArgumentException("Wrong number of task locations") + } + val tasks = Array.tabulate[Task[_]](numTasks) { i => + new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil, isBarrier = true) + } + new TaskSet(tasks, stageId, stageAttempId, priority = 0, null) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala index 97487ce1d2ca8..ba62eec0522db 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala @@ -62,6 +62,7 @@ class TaskDescriptionSuite extends SparkFunSuite { executorId = "testExecutor", name = "task for test", index = 19, + partitionId = 1, originalFiles, originalJars, originalProperties, @@ -77,6 +78,7 @@ class TaskDescriptionSuite extends SparkFunSuite { assert(decodedTaskDescription.executorId === originalTaskDescription.executorId) assert(decodedTaskDescription.name === originalTaskDescription.name) assert(decodedTaskDescription.index === originalTaskDescription.index) + assert(decodedTaskDescription.partitionId === originalTaskDescription.partitionId) assert(decodedTaskDescription.addedFiles.equals(originalFiles)) assert(decodedTaskDescription.addedJars.equals(originalJars)) assert(decodedTaskDescription.properties.equals(originalTaskDescription.properties)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 33f2ea1c94e75..624384abcd71d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -1021,4 +1021,38 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B verify(blacklist).updateBlacklistForSuccessfulTaskSet(meq(0), meq(stageAttempt), anyObject()) } } + + test("don't schedule for a barrier taskSet if available slots are less than pending tasks") { + val taskCpus = 2 + val taskScheduler = setupScheduler("spark.task.cpus" -> taskCpus.toString) + + val numFreeCores = 3 + val workerOffers = IndexedSeq( + new WorkerOffer("executor0", "host0", numFreeCores, Some("192.168.0.101:49625")), + new WorkerOffer("executor1", "host1", numFreeCores, Some("192.168.0.101:49627"))) + val attempt1 = FakeTask.createBarrierTaskSet(3) + + // submit attempt 1, offer some resources, since the available slots are less than pending + // tasks, don't schedule barrier tasks on the resource offer. + taskScheduler.submitTasks(attempt1) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(0 === taskDescriptions.length) + } + + test("schedule tasks for a barrier taskSet if all tasks can be launched together") { + val taskCpus = 2 + val taskScheduler = setupScheduler("spark.task.cpus" -> taskCpus.toString) + + val numFreeCores = 3 + val workerOffers = IndexedSeq( + new WorkerOffer("executor0", "host0", numFreeCores, Some("192.168.0.101:49625")), + new WorkerOffer("executor1", "host1", numFreeCores, Some("192.168.0.101:49627")), + new WorkerOffer("executor2", "host2", numFreeCores, Some("192.168.0.101:49629"))) + val attempt1 = FakeTask.createBarrierTaskSet(3) + + // submit attempt 1, offer some resources, all tasks get launched together + taskScheduler.submitTasks(attempt1) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(3 === taskDescriptions.length) + } } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala index 2d2f90c63a309..31f84310485a0 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala @@ -253,6 +253,7 @@ class MesosFineGrainedSchedulerBackendSuite executorId = "s1", name = "n1", index = 0, + partitionId = 0, addedFiles = mutable.Map.empty[String, Long], addedJars = mutable.Map.empty[String, Long], properties = new Properties(), @@ -361,6 +362,7 @@ class MesosFineGrainedSchedulerBackendSuite executorId = "s1", name = "n1", index = 0, + partitionId = 0, addedFiles = mutable.Map.empty[String, Long], addedJars = mutable.Map.empty[String, Long], properties = new Properties(), From 2c82745686f4456c4d5c84040a431dcb5b6cb60b Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 26 Jul 2018 12:13:27 -0700 Subject: [PATCH 1221/2461] [SPARK-24307][CORE] Add conf to revert to old code. In case there are any issues in converting FileSegmentManagedBuffer to ChunkedByteBuffer, add a conf to go back to old code path. Followup to 7e847646d1f377f46dc3154dea37148d4e557a03 Author: Imran Rashid Closes #21867 from squito/SPARK-24307-p2. --- .../scala/org/apache/spark/storage/BlockManager.scala | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 1db032711ce42..5cd21e31c9554 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -132,6 +132,8 @@ private[spark] class BlockManager( conf.getBoolean("spark.shuffle.service.enabled", false) private val chunkSize = conf.getSizeAsBytes("spark.storage.memoryMapLimitForTests", Int.MaxValue.toString).toInt + private val remoteReadNioBufferConversion = + conf.getBoolean("spark.network.remoteReadNioBufferConversion", false) val diskBlockManager = { // Only perform cleanup if an external service is not serving our shuffle files. @@ -731,7 +733,14 @@ private[spark] class BlockManager( } if (data != null) { - return Some(ChunkedByteBuffer.fromManagedBuffer(data, chunkSize)) + // SPARK-24307 undocumented "escape-hatch" in case there are any issues in converting to + // ChunkedByteBuffer, to go back to old code-path. Can be removed post Spark 2.4 if + // new path is stable. + if (remoteReadNioBufferConversion) { + return Some(new ChunkedByteBuffer(data.nioByteBuffer())) + } else { + return Some(ChunkedByteBuffer.fromManagedBuffer(data, chunkSize)) + } } logDebug(s"The value of block $blockId is null") } From fa09d91925c07a58dea285d6cf85a751664f89ff Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 26 Jul 2018 16:50:59 -0700 Subject: [PATCH 1222/2461] [SPARK-24919][BUILD] New linter rule for sparkContext.hadoopConfiguration ## What changes were proposed in this pull request? In most cases, we should use `spark.sessionState.newHadoopConf()` instead of `sparkContext.hadoopConfiguration`, so that the hadoop configurations specified in Spark session configuration will come into effect. Add a rule matching `spark.sparkContext.hadoopConfiguration` or `spark.sqlContext.sparkContext.hadoopConfiguration` to prevent the usage. ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21873 from gengliangwang/linterRule. --- .../org/apache/spark/sql/avro/AvroSuite.scala | 26 +++++-------------- .../apache/spark/ml/image/HadoopUtils.scala | 4 +++ .../apache/spark/ml/clustering/LDASuite.scala | 2 +- scalastyle-config.xml | 13 ++++++++++ .../HadoopFileLinesReaderSuite.scala | 22 ++++++++-------- .../sql/hive/execution/HiveDDLSuite.scala | 2 +- .../sql/hive/execution/HiveQuerySuite.scala | 11 +++++--- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- 8 files changed, 45 insertions(+), 37 deletions(-) diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 865a145094853..a93309e8ed9b5 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -638,12 +638,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { intercept[FileNotFoundException] { withTempPath { dir => FileUtils.touch(new File(dir, "test")) - val hadoopConf = spark.sqlContext.sparkContext.hadoopConfiguration - try { - hadoopConf.set(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true") + withSQLConf(AvroFileFormat.IgnoreFilesWithoutExtensionProperty -> "true") { spark.read.format("avro").load(dir.toString) - } finally { - hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty) } } } @@ -717,15 +713,10 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Files.createFile(new File(tempSaveDir, "non-avro").toPath) - val hadoopConf = spark.sqlContext.sparkContext.hadoopConfiguration - val count = try { - hadoopConf.set(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true") + withSQLConf(AvroFileFormat.IgnoreFilesWithoutExtensionProperty -> "true") { val newDf = spark.read.format("avro").load(tempSaveDir) - newDf.count() - } finally { - hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty) + assert(newDf.count() == 8) } - assert(count == 8) } } @@ -888,20 +879,15 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Paths.get(new URL(episodesAvro).toURI), Paths.get(dir.getCanonicalPath, "episodes")) - val hadoopConf = spark.sqlContext.sparkContext.hadoopConfiguration - val count = try { - hadoopConf.set(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true") + val hadoopConf = spark.sessionState.newHadoopConf() + withSQLConf(AvroFileFormat.IgnoreFilesWithoutExtensionProperty -> "true") { val newDf = spark .read .option("ignoreExtension", "true") .format("avro") .load(s"${dir.getCanonicalPath}/episodes") - newDf.count() - } finally { - hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty) + assert(newDf.count() == 8) } - - assert(count == 8) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala index f1579ec5844a4..1fae1dc04ad7b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala @@ -38,7 +38,9 @@ private object RecursiveFlag { */ def withRecursiveFlag[T](value: Boolean, spark: SparkSession)(f: => T): T = { val flagName = FileInputFormat.INPUT_DIR_RECURSIVE + // scalastyle:off hadoopconfiguration val hadoopConf = spark.sparkContext.hadoopConfiguration + // scalastyle:on hadoopconfiguration val old = Option(hadoopConf.get(flagName)) hadoopConf.set(flagName, value.toString) try f finally { @@ -98,7 +100,9 @@ private object SamplePathFilter { val sampleImages = sampleRatio < 1 if (sampleImages) { val flagName = FileInputFormat.PATHFILTER_CLASS + // scalastyle:off hadoopconfiguration val hadoopConf = spark.sparkContext.hadoopConfiguration + // scalastyle:on hadoopconfiguration val old = Option(hadoopConf.getClass(flagName, null)) hadoopConf.setDouble(SamplePathFilter.ratioParam, sampleRatio) hadoopConf.setLong(SamplePathFilter.seedParam, seed) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index db92132d18b7b..bbd5408c9fce3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -285,7 +285,7 @@ class LDASuite extends MLTest with DefaultReadWriteTest { // There should be 1 checkpoint remaining. assert(model.getCheckpointFiles.length === 1) val checkpointFile = new Path(model.getCheckpointFiles.head) - val fs = checkpointFile.getFileSystem(spark.sparkContext.hadoopConfiguration) + val fs = checkpointFile.getFileSystem(spark.sessionState.newHadoopConf()) assert(fs.exists(checkpointFile)) model.deleteCheckpointFiles() assert(model.getCheckpointFiles.isEmpty) diff --git a/scalastyle-config.xml b/scalastyle-config.xml index e65e3aafe5b5b..da5c3f29c32dc 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -150,6 +150,19 @@ This file is divided into 3 sections: // scalastyle:on println]]> + + spark(.sqlContext)?.sparkContext.hadoopConfiguration + + + @VisibleForTesting val file = PartitionedFile(InternalRow.empty, path.getCanonicalPath, start, length) - val hadoopConf = conf.getOrElse(spark.sparkContext.hadoopConfiguration) + val hadoopConf = conf.getOrElse(spark.sessionState.newHadoopConf()) val reader = new HadoopFileLinesReader(file, delimOpt, hadoopConf) reader.map(_.toString) @@ -111,20 +111,20 @@ class HadoopFileLinesReaderSuite extends SharedSQLContext { } test("io.file.buffer.size is less than line length") { - val conf = spark.sparkContext.hadoopConfiguration - conf.set("io.file.buffer.size", "2") - withTempPath { path => - val lines = getLines(path, text = "abcdef\n123456", ranges = Seq((4, 4), (8, 5))) - assert(lines == Seq("123456")) + withSQLConf("io.file.buffer.size" -> "2") { + withTempPath { path => + val lines = getLines(path, text = "abcdef\n123456", ranges = Seq((4, 4), (8, 5))) + assert(lines == Seq("123456")) + } } } test("line cannot be longer than line.maxlength") { - val conf = spark.sparkContext.hadoopConfiguration - conf.set("mapreduce.input.linerecordreader.line.maxlength", "5") - withTempPath { path => - val lines = getLines(path, text = "abcdef\n1234", ranges = Seq((0, 15))) - assert(lines == Seq("1234")) + withSQLConf("mapreduce.input.linerecordreader.line.maxlength" -> "5") { + withTempPath { path => + val lines = getLines(path, text = "abcdef\n1234", ranges = Seq((0, 15))) + assert(lines == Seq("1234")) + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 0b3de3d4cd599..728817729dcf7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -783,7 +783,7 @@ class HiveDDLSuite val part1 = Map("a" -> "1", "b" -> "5") val part2 = Map("a" -> "2", "b" -> "6") val root = new Path(catalog.getTableMetadata(tableIdent).location) - val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) + val fs = root.getFileSystem(spark.sessionState.newHadoopConf()) // valid fs.mkdirs(new Path(new Path(root, "a=1"), "b=5")) fs.createNewFile(new Path(new Path(root, "a=1/b=5"), "a.csv")) // file diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2ea51791d0f79..741b0124c83b9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -1177,13 +1177,18 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd assert(spark.table("with_parts").filter($"p" === 2).collect().head == Row(1, 2)) } - val originalValue = spark.sparkContext.hadoopConfiguration.get(modeConfKey, "nonstrict") + // Turn off style check since the following test is to modify hadoop configuration on purpose. + // scalastyle:off hadoopconfiguration + val hadoopConf = spark.sparkContext.hadoopConfiguration + // scalastyle:on hadoopconfiguration + + val originalValue = hadoopConf.get(modeConfKey, "nonstrict") try { - spark.sparkContext.hadoopConfiguration.set(modeConfKey, "nonstrict") + hadoopConf.set(modeConfKey, "nonstrict") sql("INSERT OVERWRITE TABLE with_parts partition(p) select 3, 4") assert(spark.table("with_parts").filter($"p" === 4).collect().head == Row(3, 4)) } finally { - spark.sparkContext.hadoopConfiguration.set(modeConfKey, originalValue) + hadoopConf.set(modeConfKey, originalValue) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 828c18a770c80..1a916824c5d9e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2053,7 +2053,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { val deleteOnExitField = classOf[FileSystem].getDeclaredField("deleteOnExit") deleteOnExitField.setAccessible(true) - val fs = FileSystem.get(spark.sparkContext.hadoopConfiguration) + val fs = FileSystem.get(spark.sessionState.newHadoopConf()) val setOfPath = deleteOnExitField.get(fs).asInstanceOf[Set[Path]] val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF() From 094aa597155dfcbf41a2490c9e462415e3824901 Mon Sep 17 00:00:00 2001 From: Misha Dmitriev Date: Thu, 26 Jul 2018 22:15:12 -0500 Subject: [PATCH 1223/2461] [SPARK-24801][CORE] Avoid memory waste by empty byte[] arrays in SaslEncryption$EncryptedMessage ## What changes were proposed in this pull request? Initialize SaslEncryption$EncryptedMessage.byteChannel lazily, so that empty, not yet used instances of ByteArrayWritableChannel referenced by this field don't use up memory. I analyzed a heap dump from Yarn Node Manager where this code is used, and found that there are over 40,000 of the above objects in memory, each with a big empty byte[] array. The reason they are all there is because of Netty queued up a large number of messages in memory before transferTo() is called. There is a small number of netty ChannelOutboundBuffer objects, and then collectively , via linked lists starting from their flushedEntry data fields, they end up referencing over 40K ChannelOutboundBuffer$Entry objects, which ultimately reference EncryptedMessage objects. ## How was this patch tested? Ran all the tests locally. Author: Misha Dmitriev Closes #21811 from countmdm/misha/spark-24801. --- .../org/apache/spark/network/sasl/SaslEncryption.java | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java index 3ac9081d78a75..d3b2a334baadd 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java @@ -135,13 +135,14 @@ static class EncryptedMessage extends AbstractFileRegion { private final boolean isByteBuf; private final ByteBuf buf; private final FileRegion region; + private final int maxOutboundBlockSize; /** * A channel used to buffer input data for encryption. The channel has an upper size bound * so that if the input is larger than the allowed buffer, it will be broken into multiple - * chunks. + * chunks. Made non-final to enable lazy initialization, which saves memory. */ - private final ByteArrayWritableChannel byteChannel; + private ByteArrayWritableChannel byteChannel; private ByteBuf currentHeader; private ByteBuffer currentChunk; @@ -157,7 +158,7 @@ static class EncryptedMessage extends AbstractFileRegion { this.isByteBuf = msg instanceof ByteBuf; this.buf = isByteBuf ? (ByteBuf) msg : null; this.region = isByteBuf ? null : (FileRegion) msg; - this.byteChannel = new ByteArrayWritableChannel(maxOutboundBlockSize); + this.maxOutboundBlockSize = maxOutboundBlockSize; } /** @@ -292,6 +293,9 @@ public long transferTo(final WritableByteChannel target, final long position) } private void nextChunk() throws IOException { + if (byteChannel == null) { + byteChannel = new ByteArrayWritableChannel(maxOutboundBlockSize); + } byteChannel.reset(); if (isByteBuf) { int copied = byteChannel.write(buf.nioBuffer()); From dc3713cca22e2ac66ad0a4206b0a09e289167137 Mon Sep 17 00:00:00 2001 From: zuotingbing Date: Fri, 27 Jul 2018 13:27:17 +0800 Subject: [PATCH 1224/2461] [SPARK-24829][STS] In Spark Thrift Server, CAST AS FLOAT inconsistent with spark-shell or spark-sql ## What changes were proposed in this pull request? SELECT CAST('4.56' AS FLOAT) the result is 4.559999942779541 ![2018-07-18_110944](https://user-images.githubusercontent.com/24823338/42857199-7c6783da-8a7b-11e8-8c69-1e9302102525.png) it should be 4.56 as same as in spark-shell or spark-sql. ![2018-07-18_111111](https://user-images.githubusercontent.com/24823338/42857210-80c89e96-8a7b-11e8-9f8c-de1a79a73752.png) ## How was this patch tested? add unit tests Author: zuotingbing Closes #21789 from zuotingbing/SPARK-24829. --- .../java/org/apache/hive/service/cli/Column.java | 2 +- .../thriftserver/HiveThriftServer2Suites.scala | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java index 2e21f18d61268..adb269aa235ea 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java @@ -349,7 +349,7 @@ public void addValue(Type type, Object field) { break; case FLOAT_TYPE: nulls.set(size, field == null); - doubleVars()[size] = field == null ? 0 : ((Float)field).doubleValue(); + doubleVars()[size] = field == null ? 0 : new Double(field.toString()); break; case DOUBLE_TYPE: nulls.set(size, field == null); diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 192f33a45e273..70eb28cdd0c64 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -636,6 +636,14 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { assert(pipeoutFileList(sessionID).length == 0) } } + + test("SPARK-24829 Checks cast as float") { + withJdbcStatement() { statement => + val resultSet = statement.executeQuery("SELECT CAST('4.56' AS FLOAT)") + resultSet.next() + assert(resultSet.getString(1) === "4.56") + } + } } class SingleSessionSuite extends HiveThriftJdbcTest { @@ -766,6 +774,14 @@ class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { assert(resultSet.getString(2) === HiveUtils.builtinHiveVersion) } } + + test("SPARK-24829 Checks cast as float") { + withJdbcStatement() { statement => + val resultSet = statement.executeQuery("SELECT CAST('4.56' AS FLOAT)") + resultSet.next() + assert(resultSet.getString(1) === "4.56") + } + } } object ServerMode extends Enumeration { From f9c9d80e46001852ef3568bad0b141a840a34ae2 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 27 Jul 2018 13:29:54 +0800 Subject: [PATCH 1225/2461] [SPARK-24929][INFRA] Make merge script don't swallow KeyboardInterrupt ## What changes were proposed in this pull request? If you want to get out of the loop to assign JIRA's user by command+c (KeyboardInterrupt), I am unable to get out. I faced this problem when the user doesn't have a contributor role and I just wanted to cancel and manually take an action to the JIRA. **Before:** ``` JIRA is unassigned, choose assignee [0] todd.chen (Reporter) Enter number of user, or userid, to assign to (blank to leave unassigned):Traceback (most recent call last): File "./dev/merge_spark_pr.py", line 322, in choose_jira_assignee "Enter number of user, or userid, to assign to (blank to leave unassigned):") KeyboardInterrupt Error assigning JIRA, try again (or leave blank and fix manually) JIRA is unassigned, choose assignee [0] todd.chen (Reporter) Enter number of user, or userid, to assign to (blank to leave unassigned):Traceback (most recent call last): File "./dev/merge_spark_pr.py", line 322, in choose_jira_assignee "Enter number of user, or userid, to assign to (blank to leave unassigned):") KeyboardInterrupt ``` **After:** ``` JIRA is unassigned, choose assignee [0] Dongjoon Hyun (Reporter) Enter number of user, or userid to assign to (blank to leave unassigned):Traceback (most recent call last): File "./dev/merge_spark_pr.py", line 322, in choose_jira_assignee "Enter number of user, or userid to assign to (blank to leave unassigned):") KeyboardInterrupt Restoring head pointer to master git checkout master Already on 'master' git branch ``` ## How was this patch tested? I tested this manually (I use my own merging script with few fixes). Author: hyukjinkwon Closes #21880 from HyukjinKwon/key-error. --- dev/merge_spark_pr.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 79c7c021fe74a..fd3eeb007a845 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -319,7 +319,7 @@ def choose_jira_assignee(issue, asf_jira): annotations.append("Commentor") print("[%d] %s (%s)" % (idx, author.displayName, ",".join(annotations))) raw_assignee = input( - "Enter number of user, or userid, to assign to (blank to leave unassigned):") + "Enter number of user, or userid, to assign to (blank to leave unassigned):") if raw_assignee == "": return None else: @@ -331,6 +331,8 @@ def choose_jira_assignee(issue, asf_jira): assignee = asf_jira.user(raw_assignee) asf_jira.assign_issue(issue.key, assignee.key) return assignee + except KeyboardInterrupt: + raise except: traceback.print_exc() print("Error assigning JIRA, try again (or leave blank and fix manually)") From e6e9031d7b7458c9d88205d06cdcdd95a98b9537 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 27 Jul 2018 14:29:05 +0800 Subject: [PATCH 1226/2461] [SPARK-24865] Remove AnalysisBarrier ## What changes were proposed in this pull request? AnalysisBarrier was introduced in SPARK-20392 to improve analysis speed (don't re-analyze nodes that have already been analyzed). Before AnalysisBarrier, we already had some infrastructure in place, with analysis specific functions (resolveOperators and resolveExpressions). These functions do not recursively traverse down subplans that are already analyzed (with a mutable boolean flag _analyzed). The issue with the old system was that developers started using transformDown, which does a top-down traversal of the plan tree, because there was not top-down resolution function, and as a result analyzer performance became pretty bad. In order to fix the issue in SPARK-20392, AnalysisBarrier was introduced as a special node and for this special node, transform/transformUp/transformDown don't traverse down. However, the introduction of this special node caused a lot more troubles than it solves. This implicit node breaks assumptions and code in a few places, and it's hard to know when analysis barrier would exist, and when it wouldn't. Just a simple search of AnalysisBarrier in PR discussions demonstrates it is a source of bugs and additional complexity. Instead, this pull request removes AnalysisBarrier and reverts back to the old approach. We added infrastructure in tests that fail explicitly if transform methods are used in the analyzer. ## How was this patch tested? Added a test suite AnalysisHelperSuite for testing the resolve* methods and transform* methods. Author: Reynold Xin Author: Xiao Li Closes #21822 from rxin/SPARK-24865. --- .../sql/catalyst/analysis/Analyzer.scala | 118 +++++------ .../sql/catalyst/analysis/CheckAnalysis.scala | 13 +- .../catalyst/analysis/DecimalPrecision.scala | 2 +- .../sql/catalyst/analysis/ResolveHints.scala | 4 +- .../analysis/ResolveInlineTables.scala | 2 +- .../ResolveTableValuedFunctions.scala | 2 +- .../SubstituteUnresolvedOrdinals.scala | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 32 +-- .../catalyst/analysis/timeZoneAnalysis.scala | 2 +- .../spark/sql/catalyst/analysis/view.scala | 2 +- .../plans/logical/AnalysisHelper.scala | 192 ++++++++++++++++++ .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 14 -- .../sql/catalyst/plans/LogicalPlanSuite.scala | 30 +-- .../plans/logical/AnalysisHelperSuite.scala | 159 +++++++++++++++ .../apache/spark/sql/DataFrameWriter.scala | 12 +- .../scala/org/apache/spark/sql/Dataset.scala | 7 +- .../spark/sql/KeyValueGroupedDataset.scala | 2 +- .../spark/sql/execution/CacheManager.scala | 4 +- .../spark/sql/execution/command/ddl.scala | 4 +- .../datasources/DataSourceStrategy.scala | 4 +- .../sql/execution/datasources/rules.scala | 6 +- .../spark/sql/GroupedDatasetSuite.scala | 96 --------- .../spark/sql/hive/HiveStrategies.scala | 8 +- .../sql/hive/execution/HiveExplainSuite.scala | 17 -- 25 files changed, 460 insertions(+), 276 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelperSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 4f474f4987dcf..8e8f8e3e7eda5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -102,11 +102,11 @@ class Analyzer( this(catalog, conf, conf.optimizerMaxIterations) } - def executeAndCheck(plan: LogicalPlan): LogicalPlan = { + def executeAndCheck(plan: LogicalPlan): LogicalPlan = AnalysisHelper.markInAnalyzer { val analyzed = execute(plan) try { checkAnalysis(analyzed) - EliminateBarriers(analyzed) + analyzed } catch { case e: AnalysisException => val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed)) @@ -203,7 +203,7 @@ class Analyzer( * Analyze cte definitions and substitute child plan with analyzed cte definitions. */ object CTESubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => @@ -213,7 +213,7 @@ class Analyzer( } def substituteCTE(plan: LogicalPlan, cteRelations: Seq[(String, LogicalPlan)]): LogicalPlan = { - plan transformDown { + plan resolveOperatorsDown { case u : UnresolvedRelation => cteRelations.find(x => resolver(x._1, u.tableIdentifier.table)) .map(_._2).getOrElse(u) @@ -231,10 +231,10 @@ class Analyzer( * Substitute child plan with WindowSpecDefinitions. */ object WindowsSubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // Lookup WindowSpecDefinitions. This rule works with unresolved children. case WithWindowDefinition(windowDefinitions, child) => - child.transform { + child.resolveOperators { case p => p.transformExpressions { case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) => val errorMessage = @@ -271,7 +271,7 @@ class Analyzer( private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined) - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => Aggregate(groups, assignAliases(aggs), child) @@ -491,7 +491,7 @@ class Analyzer( } // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case a if !a.childrenResolved => a // be sure all of the children are resolved. // Ensure group by expressions and aggregate expressions have been resolved. @@ -524,7 +524,7 @@ class Analyzer( } object ResolvePivot extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved) || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p @@ -694,7 +694,7 @@ class Analyzer( case _ => plan } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => EliminateSubqueryAliases(lookupTableFromCatalog(u)) match { case v: View => @@ -755,12 +755,6 @@ class Analyzer( s"between $left and $right") right.collect { - // For `AnalysisBarrier`, recursively de-duplicate its child. - case oldVersion: AnalysisBarrier - if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => - val newVersion = dedupRight(left, oldVersion.child) - (oldVersion, AnalysisBarrier(newVersion)) - // Handle base relations that might appear more than once. case oldVersion: MultiInstanceRelation if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => @@ -805,6 +799,7 @@ class Analyzer( right case Some((oldRelation, newRelation)) => val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) + // TODO(rxin): Why do we need transformUp here? right transformUp { case r if r == oldRelation => newRelation } transformUp { @@ -865,7 +860,7 @@ class Analyzer( private def dedupOuterReferencesInSubquery( plan: LogicalPlan, attrMap: AttributeMap[Attribute]): LogicalPlan = { - plan transformDown { case currentFragment => + plan resolveOperatorsDown { case currentFragment => currentFragment transformExpressions { case OuterReference(a: Attribute) => OuterReference(dedupAttr(a, attrMap)) @@ -891,7 +886,7 @@ class Analyzer( case _ => e.mapChildren(resolve(_, q)) } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. @@ -1086,7 +1081,7 @@ class Analyzer( * have no effect on the results. */ object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. @@ -1142,7 +1137,7 @@ class Analyzer( }} } - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case agg @ Aggregate(groups, aggs, child) if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && groups.exists(!_.resolved) => @@ -1166,9 +1161,8 @@ class Analyzer( * The HAVING clause could also used a grouping columns that is not presented in the SELECT. */ object ResolveMissingReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions - case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa case sa @ Sort(_, _, child: Aggregate) => sa case s @ Sort(order, _, child) @@ -1208,12 +1202,6 @@ class Analyzer( (exprs, plan) } else { plan match { - // For `AnalysisBarrier`, recursively resolve expressions and add missing attributes via - // its child. - case barrier: AnalysisBarrier => - val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(exprs, barrier.child) - (newExprs, AnalysisBarrier(newChild)) - case p: Project => // Resolving expressions against current plan. val maybeResolvedExprs = exprs.map(resolveExpression(_, p)) @@ -1270,7 +1258,7 @@ class Analyzer( object LookupFunctions extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]() - plan.transformAllExpressions { + plan.resolveExpressions { case f: UnresolvedFunction if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f case f: UnresolvedFunction if catalog.isRegisteredFunction(f.name) => f @@ -1309,7 +1297,7 @@ class Analyzer( * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. @@ -1364,7 +1352,7 @@ class Analyzer( * resolved outer references are wrapped in an [[OuterReference]] */ private def resolveOuterReferences(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = { - plan transformDown { + plan resolveOperatorsDown { case q: LogicalPlan if q.childrenResolved && !q.resolved => q transformExpressions { case u @ UnresolvedAttribute(nameParts) => @@ -1446,7 +1434,7 @@ class Analyzer( /** * Resolve and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // In case of HAVING (a filter after an aggregate) we use both the aggregate and // its child for resolution. case f @ Filter(_, a: Aggregate) if f.childrenResolved => @@ -1462,7 +1450,7 @@ class Analyzer( */ object ResolveSubqueryColumnAliases extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case u @ UnresolvedSubqueryColumnAliases(columnNames, child) if child.resolved => // Resolves output attributes if a query has alias names in its subquery: // e.g., SELECT * FROM (SELECT 1 AS a, 1 AS b) t(col1, col2) @@ -1485,7 +1473,7 @@ class Analyzer( * Turns projections that contain aggregate expressions into aggregations. */ object GlobalAggregates extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) if containsAggregates(projectList) => Aggregate(Nil, projectList, child) } @@ -1511,9 +1499,7 @@ class Analyzer( * underlying aggregate operator and then projected away after the original operator. */ object ResolveAggregateFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { - case Filter(cond, AnalysisBarrier(agg: Aggregate)) => - apply(Filter(cond, agg)).mapChildren(AnalysisBarrier) + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs, child)) if agg.resolved => // Try resolving the condition of the filter as though it is in the aggregate clause @@ -1571,8 +1557,6 @@ class Analyzer( case ae: AnalysisException => f } - case Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) => - apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier) case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => // Try resolving the ordering as though it is in the aggregate clause. @@ -1692,7 +1676,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get throw new AnalysisException("Generators are not supported when it's nested in " + @@ -1752,7 +1736,7 @@ class Analyzer( * that wrap the [[Generator]]. */ object ResolveGenerate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case g: Generate if !g.child.resolved || !g.generator.resolved => g case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) @@ -1793,7 +1777,7 @@ class Analyzer( */ object FixNullability extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p: LogicalPlan if p.resolved => val childrenOutput = p.children.flatMap(c => c.output).groupBy(_.exprId).flatMap { @@ -2017,7 +2001,7 @@ class Analyzer( // We have to use transformDown at here to make sure the rule of // "Aggregate with Having clause" will be triggered. - def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case Filter(condition, _) if hasWindowFunction(condition) => failAnalysis("It is not allowed to use window functions inside WHERE and HAVING clauses") @@ -2077,7 +2061,7 @@ class Analyzer( * put them into an inner Project and finally project them away at the outer Project. */ object PullOutNondeterministic extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p: Project => p case f: Filter => f @@ -2121,7 +2105,7 @@ class Analyzer( object ResolvedUuidExpressions extends Rule[LogicalPlan] { private lazy val random = new Random() - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if p.resolved => p case p => p transformExpressionsUp { case Uuid(None) => Uuid(Some(random.nextLong())) @@ -2136,7 +2120,7 @@ class Analyzer( * and we should return null if the input is null. */ object HandleNullInputsForUDF extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p => p transformExpressionsUp { @@ -2171,7 +2155,7 @@ class Analyzer( * Check and add proper window frames for all window functions. */ object ResolveWindowFrame extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case logical: LogicalPlan => logical transformExpressions { case WindowExpression(wf: WindowFunction, WindowSpecDefinition(_, _, f: SpecifiedWindowFrame)) @@ -2197,7 +2181,7 @@ class Analyzer( * Check and add order to [[AggregateWindowFunction]]s. */ object ResolveWindowOrder extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case logical: LogicalPlan => logical transformExpressions { case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty => failAnalysis(s"Window function $wf requires window to be ordered, please add ORDER BY " + @@ -2215,7 +2199,7 @@ class Analyzer( * Then apply a Project on a normal Join to eliminate natural or using join. */ object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case j @ Join(left, right, UsingJoin(joinType, usingCols), condition) if left.resolved && right.resolved && j.duplicateResolved => commonNaturalJoinProcessing(left, right, joinType, usingCols, None) @@ -2280,7 +2264,7 @@ class Analyzer( * to the given input attributes. */ object ResolveDeserializer extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2366,7 +2350,7 @@ class Analyzer( * constructed is an inner class. */ object ResolveNewInstance extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2400,7 +2384,7 @@ class Analyzer( "type of the field in the target object") } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2422,8 +2406,13 @@ class Analyzer( * scoping information for attributes and can be removed once analysis is complete. */ object EliminateSubqueryAliases extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case SubqueryAlias(_, child) => child + // This is actually called in the beginning of the optimization phase, and as a result + // is using transformUp rather than resolveOperators. This is also often called in the + // + def apply(plan: LogicalPlan): LogicalPlan = AnalysisHelper.allowInvokingTransformsInAnalyzer { + plan transformUp { + case SubqueryAlias(_, child) => child + } } } @@ -2431,7 +2420,7 @@ object EliminateSubqueryAliases extends Rule[LogicalPlan] { * Removes [[Union]] operators from the plan if it just has one child. */ object EliminateUnions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case Union(children) if children.size == 1 => children.head } } @@ -2462,7 +2451,7 @@ object CleanupAliases extends Rule[LogicalPlan] { case other => trimAliases(other) } - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) => val cleanedProjectList = projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) @@ -2491,19 +2480,12 @@ object CleanupAliases extends Rule[LogicalPlan] { } } -/** Remove the barrier nodes of analysis */ -object EliminateBarriers extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { - case AnalysisBarrier(child) => child - } -} - /** * Ignore event time watermark in batch query, which is only supported in Structured Streaming. * TODO: add this rule into analyzer rule list. */ object EliminateEventTimeWatermark extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case EventTimeWatermark(_, _, child) if !child.isStreaming => child } } @@ -2548,7 +2530,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * @return the logical plan that will generate the time windows using the Expand operator, with * the Filter operator for correctness and Project for usability. */ - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if p.children.size == 1 => val child = p.children.head val windowExpressions = @@ -2636,7 +2618,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * Resolve a [[CreateNamedStruct]] if it contains [[NamePlaceholder]]s. */ object ResolveCreateNamedStruct extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { case e: CreateNamedStruct if !e.resolved => val children = e.children.grouped(2).flatMap { case Seq(NamePlaceholder, e: NamedExpression) if e.resolved => @@ -2688,7 +2670,7 @@ object UpdateOuterReferences extends Rule[LogicalPlan] { private def updateOuterReferenceInSubquery( plan: LogicalPlan, refExprs: Seq[Expression]): LogicalPlan = { - plan transformAllExpressions { case e => + plan resolveExpressions { case e => val outerAlias = refExprs.find(stripAlias(_).semanticEquals(stripOuterReference(e))) outerAlias match { @@ -2699,7 +2681,7 @@ object UpdateOuterReferences extends Rule[LogicalPlan] { } def apply(plan: LogicalPlan): LogicalPlan = { - plan transform { + plan resolveOperatorsDown { case f @ Filter(_, a: Aggregate) if f.resolved => f transformExpressions { case s: SubqueryExpression if s.children.nonEmpty => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 49fe625b8fc6c..f9478a1c3cf4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -79,6 +79,9 @@ trait CheckAnalysis extends PredicateHelper { // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. plan.foreachUp { + + case p if p.analyzed => // Skip already analyzed sub-plans + case u: UnresolvedRelation => u.failAnalysis(s"Table or view not found: ${u.tableIdentifier}") @@ -364,10 +367,11 @@ trait CheckAnalysis extends PredicateHelper { } extendedCheckRules.foreach(_(plan)) plan.foreachUp { - case AnalysisBarrier(child) if !child.resolved => checkAnalysis(child) case o if !o.resolved => failAnalysis(s"unresolved operator ${o.simpleString}") case _ => } + + plan.setAnalyzed() } /** @@ -531,9 +535,8 @@ trait CheckAnalysis extends PredicateHelper { var foundNonEqualCorrelatedPred: Boolean = false - // Simplify the predicates before validating any unsupported correlation patterns - // in the plan. - BooleanSimplification(sub).foreachUp { + // Simplify the predicates before validating any unsupported correlation patterns in the plan. + AnalysisHelper.allowInvokingTransformsInAnalyzer { BooleanSimplification(sub).foreachUp { // Whitelist operators allowed in a correlated subquery // There are 4 categories: // 1. Operators that are allowed anywhere in a correlated subquery, and, @@ -635,6 +638,6 @@ trait CheckAnalysis extends PredicateHelper { // are not allowed to have any correlated expressions. case p => failOnOuterReferenceInSubTree(p) - } + }} } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index ab63131b07573..65a5888222f2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -82,7 +82,7 @@ object DecimalPrecision extends TypeCoercionRule { PromotePrecision(Cast(e, dataType)) } - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformUp { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // fix decimal precision for expressions case q => q.transformExpressionsUp( decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index f068bce3e9b69..bfe5169c25900 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -85,7 +85,7 @@ object ResolveHints { } } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case h: UnresolvedHint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => if (h.parameters.isEmpty) { // If there is no table alias specified, turn the entire subtree into a BroadcastHint. @@ -107,7 +107,7 @@ object ResolveHints { * This must be executed after all the other hint rules are executed. */ object RemoveAllHints extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case h: UnresolvedHint => h.child } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 71ed75454cd4d..4edfe507a7580 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.types.{StructField, StructType} * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]]. */ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case table: UnresolvedInlineTable if table.expressionsResolved => validateInputDimension(table) validateInputEvaluable(table) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index a214e59302cd9..7358f9ee36921 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -103,7 +103,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { }) ) - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { case Some(tvf) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala index f9fd0df9e4010..860d20f897690 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala @@ -33,7 +33,7 @@ class SubstituteUnresolvedOrdinals(conf: SQLConf) extends Rule[LogicalPlan] { case _ => false } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) => val newOrders = s.order.map { case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 316aebdeaffa1..6bdb639011a17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -318,7 +318,7 @@ object TypeCoercion { */ object WidenSetOperationTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case s @ SetOperation(left, right) if s.childrenResolved && left.output.length == right.output.length && !s.resolved => val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) @@ -391,7 +391,7 @@ object TypeCoercion { } override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -453,7 +453,7 @@ object TypeCoercion { } override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -512,7 +512,7 @@ object TypeCoercion { private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE) private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO) - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -555,7 +555,7 @@ object TypeCoercion { object FunctionArgumentConversion extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -670,7 +670,7 @@ object TypeCoercion { */ object Division extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who has not been resolved yet, // as this is an extra rule which should be applied at last. case e if !e.childrenResolved => e @@ -693,7 +693,7 @@ object TypeCoercion { */ object CaseWhenCoercion extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case c: CaseWhen if c.childrenResolved && !haveSameType(c.inputTypesForMerging) => val maybeCommonType = findWiderCommonType(c.inputTypesForMerging) maybeCommonType.map { commonType => @@ -711,7 +711,7 @@ object TypeCoercion { */ object IfCoercion extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if !haveSameType(i.inputTypesForMerging) => @@ -731,7 +731,7 @@ object TypeCoercion { * Coerces NullTypes in the Stack expression to the column types of the corresponding positions. */ object StackCoercion extends TypeCoercionRule { - override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case s @ Stack(children) if s.childrenResolved && s.hasFoldableNumRows => Stack(children.zipWithIndex.map { // The first child is the number of rows for stack. @@ -751,7 +751,8 @@ object TypeCoercion { */ case class ConcatCoercion(conf: SQLConf) extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case p => p transformExpressionsUp { // Skip nodes if unresolved or empty children case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c @@ -773,7 +774,8 @@ object TypeCoercion { */ case class EltCoercion(conf: SQLConf) extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case p => p transformExpressionsUp { // Skip nodes if unresolved or not enough children case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c @@ -801,7 +803,7 @@ object TypeCoercion { private val acceptedTypes = Seq(DateType, TimestampType, StringType) - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -822,7 +824,7 @@ object TypeCoercion { private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING) override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -961,7 +963,7 @@ object TypeCoercion { */ object WindowFrameCoercion extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case s @ WindowSpecDefinition(_, Seq(order), SpecifiedWindowFrame(RangeFrame, lower, upper)) if order.resolved => s.copy(frameSpecification = SpecifiedWindowFrame( @@ -999,7 +1001,7 @@ trait TypeCoercionRule extends Rule[LogicalPlan] with Logging { protected def coerceTypes(plan: LogicalPlan): LogicalPlan - private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan transformUp { + private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // No propagation required for leaf nodes. case q: LogicalPlan if q.children.isEmpty => q diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala index af1f9165b0044..a27aa845bf0ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala @@ -38,7 +38,7 @@ case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] { } override def apply(plan: LogicalPlan): LogicalPlan = - plan.transformAllExpressions(transformTimeZoneExprs) + plan.resolveExpressions(transformTimeZoneExprs) def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index 23eb78f914656..feeb6553d1066 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.internal.SQLConf * completely resolved during the batch of Resolution. */ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case v @ View(desc, output, child) if child.resolved && output != child.output => val resolver = conf.resolver val queryColumnNames = desc.viewQueryColumnNames diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala new file mode 100644 index 0000000000000..039acc1ea4fa8 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.analysis.CheckAnalysis +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode} +import org.apache.spark.util.Utils + + +/** + * [[AnalysisHelper]] defines some infrastructure for the query analyzer. In particular, in query + * analysis we don't want to repeatedly re-analyze sub-plans that have previously been analyzed. + * + * This trait defines a flag `analyzed` that can be set to true once analysis is done on the tree. + * This also provides a set of resolve methods that do not recurse down to sub-plans that have the + * analyzed flag set to true. + * + * The analyzer rules should use the various resolve methods, in lieu of the various transform + * methods defined in [[TreeNode]] and [[QueryPlan]]. + * + * To prevent accidental use of the transform methods, this trait also overrides the transform + * methods to throw exceptions in test mode, if they are used in the analyzer. + */ +trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan => + + private var _analyzed: Boolean = false + + /** + * Recursively marks all nodes in this plan tree as analyzed. + * This should only be called by [[CheckAnalysis]]. + */ + private[catalyst] def setAnalyzed(): Unit = { + if (!_analyzed) { + _analyzed = true + children.foreach(_.setAnalyzed()) + } + } + + /** + * Returns true if this node and its children have already been gone through analysis and + * verification. Note that this is only an optimization used to avoid analyzing trees that + * have already been analyzed, and can be reset by transformations. + */ + def analyzed: Boolean = _analyzed + + /** + * Returns a copy of this node where `rule` has been recursively applied first to all of its + * children and then itself (post-order, bottom-up). When `rule` does not apply to a given node, + * it is left unchanged. This function is similar to `transformUp`, but skips sub-trees that + * have already been marked as analyzed. + * + * @param rule the function use to transform this nodes children + */ + def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + if (!analyzed) { + AnalysisHelper.allowInvokingTransformsInAnalyzer { + val afterRuleOnChildren = mapChildren(_.resolveOperators(rule)) + if (self fastEquals afterRuleOnChildren) { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(self, identity[LogicalPlan]) + } + } else { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan]) + } + } + } + } else { + self + } + } + + /** Similar to [[resolveOperators]], but does it top-down. */ + def resolveOperatorsDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + if (!analyzed) { + AnalysisHelper.allowInvokingTransformsInAnalyzer { + val afterRule = CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(self, identity[LogicalPlan]) + } + + // Check if unchanged and then possibly return old copy to avoid gc churn. + if (self fastEquals afterRule) { + mapChildren(_.resolveOperatorsDown(rule)) + } else { + afterRule.mapChildren(_.resolveOperatorsDown(rule)) + } + } + } else { + self + } + } + + /** + * Recursively transforms the expressions of a tree, skipping nodes that have already + * been analyzed. + */ + def resolveExpressions(r: PartialFunction[Expression, Expression]): LogicalPlan = { + resolveOperators { + case p => p.transformExpressions(r) + } + } + + protected def assertNotAnalysisRule(): Unit = { + if (Utils.isTesting && + AnalysisHelper.inAnalyzer.get > 0 && + AnalysisHelper.resolveOperatorDepth.get == 0) { + throw new RuntimeException("This method should not be called in the analyzer") + } + } + + /** + * In analyzer, use [[resolveOperatorsDown()]] instead. If this is used in the analyzer, + * an exception will be thrown in test mode. It is however OK to call this function within + * the scope of a [[resolveOperatorsDown()]] call. + * @see [[TreeNode.transformDown()]]. + */ + override def transformDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + assertNotAnalysisRule() + super.transformDown(rule) + } + + /** + * Use [[resolveOperators()]] in the analyzer. + * @see [[TreeNode.transformUp()]] + */ + override def transformUp(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + assertNotAnalysisRule() + super.transformUp(rule) + } + + /** + * Use [[resolveExpressions()]] in the analyzer. + * @see [[QueryPlan.transformAllExpressions()]] + */ + override def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = { + assertNotAnalysisRule() + super.transformAllExpressions(rule) + } + +} + + +object AnalysisHelper { + + /** + * A thread local to track whether we are in a resolveOperator call (for the purpose of analysis). + * This is an int because resolve* calls might be be nested (e.g. a rule might trigger another + * query compilation within the rule itself), so we are tracking the depth here. + */ + private val resolveOperatorDepth: ThreadLocal[Int] = new ThreadLocal[Int] { + override def initialValue(): Int = 0 + } + + /** + * A thread local to track whether we are in the analysis phase of query compilation. This is an + * int rather than a boolean in case our analyzer recursively calls itself. + */ + private val inAnalyzer: ThreadLocal[Int] = new ThreadLocal[Int] { + override def initialValue(): Int = 0 + } + + def allowInvokingTransformsInAnalyzer[T](f: => T): T = { + resolveOperatorDepth.set(resolveOperatorDepth.get + 1) + try f finally { + resolveOperatorDepth.set(resolveOperatorDepth.get - 1) + } + } + + def markInAnalyzer[T](f: => T): T = { + inAnalyzer.set(inAnalyzer.get + 1) + try f finally { + inAnalyzer.set(inAnalyzer.get - 1) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index c486ad700f362..0e4456ac0e6a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -23,12 +23,12 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.LogicalPlanStats -import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.types.StructType abstract class LogicalPlan extends QueryPlan[LogicalPlan] + with AnalysisHelper with LogicalPlanStats with QueryPlanConstraints with Logging { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 31f703d018aed..9fb50a5e565e0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -555,20 +555,6 @@ class AnalysisSuite extends AnalysisTest with Matchers { } } - test("SPARK-20392: analysis barrier") { - // [[AnalysisBarrier]] will be removed after analysis - checkAnalysis( - Project(Seq(UnresolvedAttribute("tbl.a")), - AnalysisBarrier(SubqueryAlias("tbl", testRelation))), - Project(testRelation.output, SubqueryAlias("tbl", testRelation))) - - // Verify we won't go through a plan wrapped in a barrier. - // Since we wrap an unresolved plan and analyzer won't go through it. It remains unresolved. - val barrier = AnalysisBarrier(Project(Seq(UnresolvedAttribute("tbl.b")), - SubqueryAlias("tbl", testRelation))) - assertAnalysisError(barrier, Seq("cannot resolve '`tbl.b`'")) - } - test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") { val pythonUdf = PythonUDF("pyUDF", null, StructType(Seq(StructField("a", LongType))), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index bf569cb869428..aaab3ff1bf128 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -18,13 +18,12 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Coalesce, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Literal, NamedExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.IntegerType /** - * This suite is used to test [[LogicalPlan]]'s `transformUp/transformDown` plus analysis barrier - * and make sure it can correctly skip sub-trees that have already been analyzed. + * This suite is used to test [[LogicalPlan]]'s `transformUp/transformDown`. */ class LogicalPlanSuite extends SparkFunSuite { private var invocationCount = 0 @@ -60,31 +59,6 @@ class LogicalPlanSuite extends SparkFunSuite { assert(invocationCount === 2) } - test("transformUp skips all ready resolved plans wrapped in analysis barrier") { - invocationCount = 0 - val plan = AnalysisBarrier(Project(Nil, Project(Nil, testRelation))) - plan transformUp function - - assert(invocationCount === 0) - - invocationCount = 0 - plan transformDown function - assert(invocationCount === 0) - } - - test("transformUp skips partially resolved plans wrapped in analysis barrier") { - invocationCount = 0 - val plan1 = AnalysisBarrier(Project(Nil, testRelation)) - val plan2 = Project(Nil, plan1) - plan2 transformUp function - - assert(invocationCount === 1) - - invocationCount = 0 - plan2 transformDown function - assert(invocationCount === 1) - } - test("isStreaming") { val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) val incrementalRelation = LocalRelation( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelperSuite.scala new file mode 100644 index 0000000000000..9100e10ca0c09 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelperSuite.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, Literal, NamedExpression} + + +class AnalysisHelperSuite extends SparkFunSuite { + + private var invocationCount = 0 + private val function: PartialFunction[LogicalPlan, LogicalPlan] = { + case p: Project => + invocationCount += 1 + p + } + + private val exprFunction: PartialFunction[Expression, Expression] = { + case e: Literal => + invocationCount += 1 + Literal.TrueLiteral + } + + private def projectExprs: Seq[NamedExpression] = Alias(Literal.TrueLiteral, "A")() :: Nil + + test("setAnalyze is recursive") { + val plan = Project(Nil, LocalRelation()) + plan.setAnalyzed() + assert(plan.find(!_.analyzed).isEmpty) + } + + test("resolveOperator runs on operators recursively") { + invocationCount = 0 + val plan = Project(Nil, Project(Nil, LocalRelation())) + plan.resolveOperators(function) + assert(invocationCount === 2) + } + + test("resolveOperatorsDown runs on operators recursively") { + invocationCount = 0 + val plan = Project(Nil, Project(Nil, LocalRelation())) + plan.resolveOperatorsDown(function) + assert(invocationCount === 2) + } + + test("resolveExpressions runs on operators recursively") { + invocationCount = 0 + val plan = Project(projectExprs, Project(projectExprs, LocalRelation())) + plan.resolveExpressions(exprFunction) + assert(invocationCount === 2) + } + + test("resolveOperator skips all ready resolved plans") { + invocationCount = 0 + val plan = Project(Nil, Project(Nil, LocalRelation())) + plan.setAnalyzed() + plan.resolveOperators(function) + assert(invocationCount === 0) + } + + test("resolveOperatorsDown skips all ready resolved plans") { + invocationCount = 0 + val plan = Project(Nil, Project(Nil, LocalRelation())) + plan.setAnalyzed() + plan.resolveOperatorsDown(function) + assert(invocationCount === 0) + } + + test("resolveExpressions skips all ready resolved plans") { + invocationCount = 0 + val plan = Project(projectExprs, Project(projectExprs, LocalRelation())) + plan.setAnalyzed() + plan.resolveExpressions(exprFunction) + assert(invocationCount === 0) + } + + test("resolveOperator skips partially resolved plans") { + invocationCount = 0 + val plan1 = Project(Nil, LocalRelation()) + val plan2 = Project(Nil, plan1) + plan1.setAnalyzed() + plan2.resolveOperators(function) + assert(invocationCount === 1) + } + + test("resolveOperatorsDown skips partially resolved plans") { + invocationCount = 0 + val plan1 = Project(Nil, LocalRelation()) + val plan2 = Project(Nil, plan1) + plan1.setAnalyzed() + plan2.resolveOperatorsDown(function) + assert(invocationCount === 1) + } + + test("resolveExpressions skips partially resolved plans") { + invocationCount = 0 + val plan1 = Project(projectExprs, LocalRelation()) + val plan2 = Project(projectExprs, plan1) + plan1.setAnalyzed() + plan2.resolveExpressions(exprFunction) + assert(invocationCount === 1) + } + + test("do not allow transform in analyzer") { + val plan = Project(Nil, LocalRelation()) + // These should be OK since we are not in the analzyer + plan.transform { case p: Project => p } + plan.transformUp { case p: Project => p } + plan.transformDown { case p: Project => p } + plan.transformAllExpressions { case lit: Literal => lit } + + // The following should fail in the analyzer scope + AnalysisHelper.markInAnalyzer { + intercept[RuntimeException] { plan.transform { case p: Project => p } } + intercept[RuntimeException] { plan.transformUp { case p: Project => p } } + intercept[RuntimeException] { plan.transformDown { case p: Project => p } } + intercept[RuntimeException] { plan.transformAllExpressions { case lit: Literal => lit } } + } + } + + test("allow transform in resolveOperators in the analyzer") { + val plan = Project(Nil, LocalRelation()) + AnalysisHelper.markInAnalyzer { + plan.resolveOperators { case p: Project => p.transform { case p: Project => p } } + plan.resolveOperatorsDown { case p: Project => p.transform { case p: Project => p } } + plan.resolveExpressions { case lit: Literal => + Project(Nil, LocalRelation()).transform { case p: Project => p } + lit + } + } + } + + test("allow transform with allowInvokingTransformsInAnalyzer in the analyzer") { + val plan = Project(Nil, LocalRelation()) + AnalysisHelper.markInAnalyzer { + AnalysisHelper.allowInvokingTransformsInAnalyzer { + plan.transform { case p: Project => p } + plan.transformUp { case p: Project => p } + plan.transformDown { case p: Project => p } + plan.transformAllExpressions { case lit: Literal => lit } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 39c0e102b69b2..3c9e743106260 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -254,7 +254,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val writer = ws.createWriter(jobId, df.logicalPlan.schema, mode, options) if (writer.isPresent) { runCommand(df.sparkSession, "save") { - WriteToDataSourceV2(writer.get(), df.planWithBarrier) + WriteToDataSourceV2(writer.get(), df.logicalPlan) } } @@ -275,7 +275,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { sparkSession = df.sparkSession, className = source, partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, df.planWithBarrier) + options = extraOptions.toMap).planForWriting(mode, df.logicalPlan) } } @@ -323,7 +323,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { InsertIntoTable( table = UnresolvedRelation(tableIdent), partition = Map.empty[String, Option[String]], - query = df.planWithBarrier, + query = df.logicalPlan, overwrite = mode == SaveMode.Overwrite, ifPartitionNotExists = false) } @@ -351,7 +351,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def assertNotPartitioned(operation: String): Unit = { if (partitioningColumns.isDefined) { - throw new AnalysisException( s"'$operation' does not support partitioning") + throw new AnalysisException(s"'$operation' does not support partitioning") } } @@ -459,9 +459,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { partitionColumnNames = partitioningColumns.getOrElse(Nil), bucketSpec = getBucketSpec) - runCommand(df.sparkSession, "saveAsTable") { - CreateTable(tableDesc, mode, Some(df.planWithBarrier)) - } + runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, Some(df.logicalPlan))) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c97246f30220d..b63235ec827c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -196,7 +196,8 @@ class Dataset[T] private[sql]( } // Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again. - @transient private[sql] val planWithBarrier = AnalysisBarrier(logicalPlan) + // TODO(rxin): remove this later. + @transient private[sql] val planWithBarrier = logicalPlan /** * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the @@ -1857,7 +1858,7 @@ class Dataset[T] private[sql]( def union(other: Dataset[T]): Dataset[T] = withSetOperator { // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. - CombineUnions(Union(logicalPlan, other.logicalPlan)).mapChildren(AnalysisBarrier) + CombineUnions(Union(logicalPlan, other.logicalPlan)) } /** @@ -1916,7 +1917,7 @@ class Dataset[T] private[sql]( // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. - CombineUnions(Union(logicalPlan, rightChild)).mapChildren(AnalysisBarrier) + CombineUnions(Union(logicalPlan, rightChild)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 36f6038aa9485..6bab21dca0cbd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -49,7 +49,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( private implicit val kExprEnc = encoderFor(kEncoder) private implicit val vExprEnc = encoderFor(vEncoder) - private def logicalPlan = AnalysisBarrier(queryExecution.analyzed) + private def logicalPlan = queryExecution.analyzed private def sparkSession = queryExecution.sparkSession /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 39d9a95ca4710..ed130dc57ee5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -97,7 +97,7 @@ class CacheManager extends Logging { val inMemoryRelation = InMemoryRelation( sparkSession.sessionState.conf.useCompression, sparkSession.sessionState.conf.columnBatchSize, storageLevel, - sparkSession.sessionState.executePlan(AnalysisBarrier(planToCache)).executedPlan, + sparkSession.sessionState.executePlan(planToCache).executedPlan, tableName, planToCache) cachedData.add(CachedData(planToCache, inMemoryRelation)) @@ -173,7 +173,7 @@ class CacheManager extends Logging { // Remove the cache entry before we create a new one, so that we can have a different // physical plan. it.remove() - val plan = spark.sessionState.executePlan(AnalysisBarrier(cd.plan)).executedPlan + val plan = spark.sessionState.executePlan(cd.plan).executedPlan val newCache = InMemoryRelation( cacheBuilder = cd.cachedRepresentation.cacheBuilder.withCachedPlan(plan), logicalPlan = cd.plan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index c7f7e4d755cfd..e1faecedd20ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{EliminateBarriers, NoSuchTableException, Resolver} +import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Resolver} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} @@ -891,7 +891,7 @@ object DDLUtils { * Throws exception if outputPath tries to overwrite inputpath. */ def verifyNotReadPath(query: LogicalPlan, outputPath: Path) : Unit = { - val inputPaths = EliminateBarriers(query).collect { + val inputPaths = query.collect { case LogicalRelation(r: HadoopFsRelation, _, _, _) => r.location.rootPaths }.flatten diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 7b129435c45db..e1b049b6ceaba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -131,7 +131,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast projectList } - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case CreateTable(tableDesc, mode, None) if DDLUtils.isDatasourceTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc) CreateDataSourceTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore) @@ -252,7 +252,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] table.partitionSchema.asNullable.toAttributes) } - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case i @ InsertIntoTable(UnresolvedCatalogRelation(tableMeta), _, _, _, _) if DDLUtils.isDatasourceTable(tableMeta) => i.copy(table = readDataSourceTable(tableMeta)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index dfcf6c14fbef1..3170180b32b83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -39,7 +39,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { sparkSession.sessionState.conf.runSQLonFile && u.tableIdentifier.database.isDefined } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedRelation if maybeSQLFile(u) => try { val dataSource = DataSource( @@ -73,7 +73,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi // catalog is a def and not a val/lazy val as the latter would introduce a circular reference private def catalog = sparkSession.sessionState.catalog - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { // When we CREATE TABLE without specifying the table schema, we should fail the query if // bucketing information is specified, as we can't infer bucketing from data files currently. // Since the runtime inferred partition columns could be different from what user specified, @@ -365,7 +365,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { } } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case i @ InsertIntoTable(table, _, query, _, _) if table.resolved && query.resolved => table match { case relation: HiveTableRelation => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala deleted file mode 100644 index 147c0b61f5017..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.api.python.PythonEvalType -import org.apache.spark.sql.catalyst.expressions.PythonUDF -import org.apache.spark.sql.catalyst.plans.logical.AnalysisBarrier -import org.apache.spark.sql.functions.udf -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{LongType, StructField, StructType} - -class GroupedDatasetSuite extends QueryTest with SharedSQLContext { - import testImplicits._ - - private val scalaUDF = udf((x: Long) => { x + 1 }) - private lazy val datasetWithUDF = spark.range(1).toDF("s").select($"s", scalaUDF($"s")) - - private def assertContainsAnalysisBarrier(ds: Dataset[_], atLevel: Int = 1): Unit = { - assert(atLevel >= 0) - var children = Seq(ds.queryExecution.logical) - (1 to atLevel).foreach { _ => - children = children.flatMap(_.children) - } - val barriers = children.collect { - case ab: AnalysisBarrier => ab - } - assert(barriers.nonEmpty, s"Plan does not contain AnalysisBarrier at level $atLevel:\n" + - ds.queryExecution.logical) - } - - test("SPARK-24373: avoid running Analyzer rules twice on RelationalGroupedDataset") { - val groupByDataset = datasetWithUDF.groupBy() - val rollupDataset = datasetWithUDF.rollup("s") - val cubeDataset = datasetWithUDF.cube("s") - val pivotDataset = datasetWithUDF.groupBy().pivot("s", Seq(1, 2)) - datasetWithUDF.cache() - Seq(groupByDataset, rollupDataset, cubeDataset, pivotDataset).foreach { rgDS => - val df = rgDS.count() - assertContainsAnalysisBarrier(df) - assertCached(df) - } - - val flatMapGroupsInRDF = datasetWithUDF.groupBy().flatMapGroupsInR( - Array.emptyByteArray, - Array.emptyByteArray, - Array.empty, - StructType(Seq(StructField("s", LongType)))) - val flatMapGroupsInPandasDF = datasetWithUDF.groupBy().flatMapGroupsInPandas(PythonUDF( - "pyUDF", - null, - StructType(Seq(StructField("s", LongType))), - Seq.empty, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - true)) - Seq(flatMapGroupsInRDF, flatMapGroupsInPandasDF).foreach { df => - assertContainsAnalysisBarrier(df, 2) - assertCached(df) - } - datasetWithUDF.unpersist(true) - } - - test("SPARK-24373: avoid running Analyzer rules twice on KeyValueGroupedDataset") { - val kvDasaset = datasetWithUDF.groupByKey(_.getLong(0)) - datasetWithUDF.cache() - val mapValuesKVDataset = kvDasaset.mapValues(_.getLong(0)).reduceGroups(_ + _) - val keysKVDataset = kvDasaset.keys - val flatMapGroupsKVDataset = kvDasaset.flatMapGroups((k, _) => Seq(k)) - val aggKVDataset = kvDasaset.count() - val otherKVDataset = spark.range(1).groupByKey(_ + 1) - val cogroupKVDataset = kvDasaset.cogroup(otherKVDataset)((k, _, _) => Seq(k)) - Seq((mapValuesKVDataset, 1), - (keysKVDataset, 2), - (flatMapGroupsKVDataset, 2), - (aggKVDataset, 1), - (cogroupKVDataset, 2)).foreach { case (df, analysisBarrierDepth) => - assertContainsAnalysisBarrier(df, analysisBarrierDepth) - assertCached(df) - } - datasetWithUDF.unpersist(true) - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index a0c197b06ddab..9fe83bb332a9a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -87,7 +87,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { } } - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case c @ CreateTable(t, _, query) if DDLUtils.isHiveTable(t) => // Finds the database name if the name does not exist. val dbName = t.identifier.database.getOrElse(session.catalog.currentDatabase) @@ -114,7 +114,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { } class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case relation: HiveTableRelation if DDLUtils.isHiveTable(relation.tableMeta) && relation.tableMeta.stats.isEmpty => val table = relation.tableMeta @@ -145,7 +145,7 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { * `PreprocessTableInsertion`. */ object HiveAnalysis extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case InsertIntoTable(r: HiveTableRelation, partSpec, query, overwrite, ifPartitionNotExists) if DDLUtils.isHiveTable(r.tableMeta) => InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, @@ -225,7 +225,7 @@ case class RelationConversions( } override def apply(plan: LogicalPlan): LogicalPlan = { - plan transformUp { + plan resolveOperators { // Write path case InsertIntoTable(r: HiveTableRelation, partition, query, overwrite, ifPartitionNotExists) // Inserting into partitioned table is not supported in Parquet/Orc data source (yet). diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 5d56f89c2271c..a1ce1ea936bbf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -170,21 +170,4 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto sql("EXPLAIN EXTENDED CODEGEN SELECT 1") } } - - test("SPARK-23021 AnalysisBarrier should not cut off explain output for parsed logical plans") { - val df = Seq((1, 1)).toDF("a", "b").groupBy("a").count().limit(1) - val outputStream = new java.io.ByteArrayOutputStream() - Console.withOut(outputStream) { - df.explain(true) - } - assert(outputStream.toString.replaceAll("""#\d+""", "#0").contains( - s"""== Parsed Logical Plan == - |GlobalLimit 1 - |+- LocalLimit 1 - | +- AnalysisBarrier - | +- Aggregate [a#0], [a#0, count(1) AS count#0L] - | +- Project [_1#0 AS a#0, _2#0 AS b#0] - | +- LocalRelation [_1#0, _2#0] - |""".stripMargin)) - } } From 21fcac1645bf01c453ddd4cb64c566895e66ea4f Mon Sep 17 00:00:00 2001 From: maryannxue Date: Thu, 26 Jul 2018 23:47:32 -0700 Subject: [PATCH 1227/2461] [SPARK-24288][SQL] Add a JDBC Option to enable preventing predicate pushdown ## What changes were proposed in this pull request? Add a JDBC Option "pushDownPredicate" (default `true`) to allow/disallow predicate push-down in JDBC data source. ## How was this patch tested? Add a test in `JDBCSuite` Author: maryannxue Closes #21875 from maryannxue/spark-24288. --- docs/sql-programming-guide.md | 7 ++ .../datasources/jdbc/JDBCOptions.scala | 4 ++ .../datasources/jdbc/JDBCRelation.scala | 6 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 66 +++++++++++++------ 4 files changed, 63 insertions(+), 20 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index e815e5bd516e2..4b013c633e27c 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1435,6 +1435,13 @@ the following case-insensitive options: The custom schema to use for reading data from JDBC connectors. For example, "id DECIMAL(38, 0), name STRING". You can also specify partial fields, and the others use the default type mapping. For example, "id DECIMAL(38, 0)". The column names should be identical to the corresponding column names of JDBC table. Users can specify the corresponding data types of Spark SQL instead of using the defaults. This option applies only to reading. + + + pushDownPredicate + + The option to enable or disable predicate push-down into the JDBC data source. The default value is true, in which case Spark will push down filters to the JDBC data source as much as possible. Otherwise, if set to false, no filter will be pushed down to the JDBC data source and thus all filters will be handled by Spark. Predicate push-down is usually turned off when the predicate filtering is performed faster by Spark than by the JDBC data source. + +
      diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 574aed4958fd7..d80efcedf8c2d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -183,6 +183,9 @@ class JDBCOptions( } // An option to execute custom SQL before fetching data from the remote DB val sessionInitStatement = parameters.get(JDBC_SESSION_INIT_STATEMENT) + + // An option to allow/disallow pushing down predicate into JDBC data source + val pushDownPredicate = parameters.getOrElse(JDBC_PUSHDOWN_PREDICATE, "true").toBoolean } class JdbcOptionsInWrite( @@ -234,4 +237,5 @@ object JDBCOptions { val JDBC_BATCH_INSERT_SIZE = newOption("batchsize") val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel") val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement") + val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 97e2d255cb7be..4f78f593fa4af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -172,7 +172,11 @@ private[sql] case class JDBCRelation( // Check if JDBCRDD.compileFilter can accept input filters override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { - filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty) + if (jdbcOptions.pushDownPredicate) { + filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty) + } else { + filters + } } override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 09facb9bef8dc..0edbd3a55e17e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -261,21 +261,32 @@ class JDBCSuite extends QueryTest s"Expecting a JDBCRelation with $expectedNumPartitions partitions, but got:`$jdbcRelations`") } + private def checkPushdown(df: DataFrame): DataFrame = { + val parentPlan = df.queryExecution.executedPlan + // Check if SparkPlan Filter is removed in a physical plan and + // the plan only has PhysicalRDD to scan JDBCRelation. + assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec]) + val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec] + assert(node.child.isInstanceOf[org.apache.spark.sql.execution.DataSourceScanExec]) + assert(node.child.asInstanceOf[DataSourceScanExec].nodeName.contains("JDBCRelation")) + df + } + + private def checkNotPushdown(df: DataFrame): DataFrame = { + val parentPlan = df.queryExecution.executedPlan + // Check if SparkPlan Filter is not removed in a physical plan because JDBCRDD + // cannot compile given predicates. + assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec]) + val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec] + assert(node.child.isInstanceOf[org.apache.spark.sql.execution.FilterExec]) + df + } + test("SELECT *") { assert(sql("SELECT * FROM foobar").collect().size === 3) } test("SELECT * WHERE (simple predicates)") { - def checkPushdown(df: DataFrame): DataFrame = { - val parentPlan = df.queryExecution.executedPlan - // Check if SparkPlan Filter is removed in a physical plan and - // the plan only has PhysicalRDD to scan JDBCRelation. - assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec]) - val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec] - assert(node.child.isInstanceOf[org.apache.spark.sql.execution.DataSourceScanExec]) - assert(node.child.asInstanceOf[DataSourceScanExec].nodeName.contains("JDBCRelation")) - df - } assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size == 0) assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID != 2")).collect().size == 2) assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size == 1) @@ -308,15 +319,6 @@ class JDBCSuite extends QueryTest "WHERE (THEID > 0 AND TRIM(NAME) = 'mary') OR (NAME = 'fred')") assert(df2.collect.toSet === Set(Row("fred", 1), Row("mary", 2))) - def checkNotPushdown(df: DataFrame): DataFrame = { - val parentPlan = df.queryExecution.executedPlan - // Check if SparkPlan Filter is not removed in a physical plan because JDBCRDD - // cannot compile given predicates. - assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec]) - val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegenExec] - assert(node.child.isInstanceOf[org.apache.spark.sql.execution.FilterExec]) - df - } assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 1) < 2")).collect().size == 0) assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 2) != 4")).collect().size == 2) } @@ -1375,4 +1377,30 @@ class JDBCSuite extends QueryTest Row("fred", 1) :: Nil) } + + test("SPARK-24288: Enable preventing predicate pushdown") { + val table = "test.people" + + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("dbTable", table) + .option("pushDownPredicate", false) + .load() + .filter("theid = 1") + .select("name", "theid") + checkAnswer( + checkNotPushdown(df), + Row("fred", 1) :: Nil) + + // pushDownPredicate option in the create table path. + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW predicateOption + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$urlWithUserAndPass', dbTable '$table', pushDownPredicate 'false') + """.stripMargin.replaceAll("\n", " ")) + checkAnswer( + checkNotPushdown(sql("SELECT name, theid FROM predicateOption WHERE theid = 1")), + Row("fred", 1) :: Nil) + } } From ef6c8395c483954857547cbd75f84a6814317d13 Mon Sep 17 00:00:00 2001 From: pkuwm Date: Fri, 27 Jul 2018 23:02:48 +0900 Subject: [PATCH 1228/2461] [SPARK-23928][SQL] Add shuffle collection function. ## What changes were proposed in this pull request? This PR adds a new collection function: shuffle. It generates a random permutation of the given array. This implementation uses the "inside-out" version of Fisher-Yates algorithm. ## How was this patch tested? New tests are added to CollectionExpressionsSuite.scala and DataFrameFunctionsSuite.scala. Author: Takuya UESHIN Author: pkuwm Closes #21802 from ueshin/issues/SPARK-23928/shuffle. --- python/pyspark/sql/functions.py | 17 +++ .../sql/catalyst/analysis/Analyzer.scala | 7 +- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 106 ++++++++++++++++++ .../util/RandomIndicesGenerator.scala | 45 ++++++++ .../CollectionExpressionsSuite.scala | 69 ++++++++++++ .../org/apache/spark/sql/functions.scala | 10 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 65 +++++++++++ 8 files changed, 317 insertions(+), 3 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomIndicesGenerator.scala diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f2e66337257be..0a88e482787ff 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2382,6 +2382,23 @@ def array_sort(col): return Column(sc._jvm.functions.array_sort(_to_java_column(col))) +@since(2.4) +def shuffle(col): + """ + Collection function: Generates a random permutation of the given array. + + .. note:: The function is non-deterministic. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([1, 20, 3, 5],), ([1, 20, None, 3],)], ['data']) + >>> df.select(shuffle(df.data).alias('s')).collect() # doctest: +SKIP + [Row(s=[3, 1, 5, 20]), Row(s=[20, None, 3, 1])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.shuffle(_to_java_column(col))) + + @since(1.5) @ignore_unicode_prefix def reverse(col): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8e8f8e3e7eda5..d18509fe8d91b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -181,7 +181,7 @@ class Analyzer( TimeWindowing :: ResolveInlineTables(conf) :: ResolveTimeZone(conf) :: - ResolvedUuidExpressions :: + ResolveRandomSeed :: TypeCoercion.typeCoercionRules(conf) ++ extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), @@ -2100,15 +2100,16 @@ class Analyzer( } /** - * Set the seed for random number generation in Uuid expressions. + * Set the seed for random number generation. */ - object ResolvedUuidExpressions extends Rule[LogicalPlan] { + object ResolveRandomSeed extends Rule[LogicalPlan] { private lazy val random = new Random() override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if p.resolved => p case p => p transformExpressionsUp { case Uuid(None) => Uuid(Some(random.nextLong())) + case Shuffle(child, None) => Shuffle(child, Some(random.nextLong())) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d696ce9a766d5..adc4837276793 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -429,6 +429,7 @@ object FunctionRegistry { expression[Size]("cardinality"), expression[ArraysZip]("arrays_zip"), expression[SortArray]("sort_array"), + expression[Shuffle]("shuffle"), expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), expression[Reverse]("reverse"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b3d04bfa86455..b1d91ffbe86e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1203,6 +1203,112 @@ case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLi override def prettyName: String = "array_sort" } +/** + * Returns a random permutation of the given array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns a random permutation of the given array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 20, 3, 5)); + [3, 1, 5, 20] + > SELECT _FUNC_(array(1, 20, null, 3)); + [20, null, 3, 1] + """, + note = "The function is non-deterministic.", + since = "2.4.0") +case class Shuffle(child: Expression, randomSeed: Option[Long] = None) + extends UnaryExpression with ExpectsInputTypes with Stateful { + + def this(child: Expression) = this(child, None) + + override lazy val resolved: Boolean = + childrenResolved && checkInputDataTypes().isSuccess && randomSeed.isDefined + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + override def dataType: DataType = child.dataType + + @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + @transient private[this] var random: RandomIndicesGenerator = _ + + override protected def initializeInternal(partitionIndex: Int): Unit = { + random = RandomIndicesGenerator(randomSeed.get + partitionIndex) + } + + override protected def evalInternal(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + null + } else { + val source = value.asInstanceOf[ArrayData] + val numElements = source.numElements() + val indices = random.getNextIndices(numElements) + new GenericArrayData(indices.map(source.get(_, elementType))) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => shuffleArrayCodeGen(ctx, ev, c)) + } + + private def shuffleArrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = { + val randomClass = classOf[RandomIndicesGenerator].getName + + val rand = ctx.addMutableState(randomClass, "rand", forceInline = true) + ctx.addPartitionInitializationStatement( + s"$rand = new $randomClass(${randomSeed.get}L + partitionIndex);") + + val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) + + val numElements = ctx.freshName("numElements") + val arrayData = ctx.freshName("arrayData") + + val initialization = if (isPrimitiveType) { + ctx.createUnsafeArray(arrayData, numElements, elementType, s" $prettyName failed.") + } else { + val arrayDataClass = classOf[GenericArrayData].getName() + s"$arrayDataClass $arrayData = new $arrayDataClass(new Object[$numElements]);" + } + + val indices = ctx.freshName("indices") + val i = ctx.freshName("i") + + val getValue = CodeGenerator.getValue(childName, elementType, s"$indices[$i]") + + val setFunc = if (isPrimitiveType) { + s"set${CodeGenerator.primitiveTypeName(elementType)}" + } else { + "update" + } + + val assignment = if (isPrimitiveType && dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($childName.isNullAt($indices[$i])) { + | $arrayData.setNullAt($i); + |} else { + | $arrayData.$setFunc($i, $getValue); + |} + """.stripMargin + } else { + s"$arrayData.$setFunc($i, $getValue);" + } + + s""" + |int $numElements = $childName.numElements(); + |int[] $indices = $rand.getNextIndices($numElements); + |$initialization + |for (int $i = 0; $i < $numElements; $i++) { + | $assignment + |} + |${ev.value} = $arrayData; + """.stripMargin + } + + override def freshCopy(): Shuffle = Shuffle(child, randomSeed) +} + /** * Returns a reversed string or an array with reverse order of elements. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomIndicesGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomIndicesGenerator.scala new file mode 100644 index 0000000000000..ae05128f94777 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomIndicesGenerator.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.commons.math3.random.MersenneTwister + +/** + * This class is used to generate a random indices of given length. + * + * This implementation uses the "inside-out" version of Fisher-Yates algorithm. + * Reference: + * https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_%22inside-out%22_algorithm + */ +case class RandomIndicesGenerator(randomSeed: Long) { + private val random = new MersenneTwister(randomSeed) + + def getNextIndices(length: Int): Array[Int] = { + val indices = new Array[Int](length) + var i = 0 + while (i < length) { + val j = random.nextInt(i + 1) + if (j != i) { + indices(i) = indices(j) + } + indices(j) = i + i += 1 + } + indices + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index c7f0da71e1440..5c5728548e646 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.util.TimeZone +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow @@ -1434,4 +1436,71 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper assert(ArrayUnion(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) assert(ArrayUnion(a20, a22).dataType.asInstanceOf[ArrayType].containsNull === true) } + + test("Shuffle") { + // Primitive-type elements + val ai0 = Literal.create(Seq(1, 2, 3, 4, 5), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai2 = Literal.create(Seq(null, 1, null, 3), ArrayType(IntegerType, containsNull = true)) + val ai3 = Literal.create(Seq(2, null, 4, null), ArrayType(IntegerType, containsNull = true)) + val ai4 = Literal.create(Seq(null, null, null), ArrayType(IntegerType, containsNull = true)) + val ai5 = Literal.create(Seq(1), ArrayType(IntegerType, containsNull = false)) + val ai6 = Literal.create(Seq.empty, ArrayType(IntegerType, containsNull = false)) + val ai7 = Literal.create(null, ArrayType(IntegerType, containsNull = true)) + + checkEvaluation(Shuffle(ai0, Some(0)), Seq(4, 1, 2, 3, 5)) + checkEvaluation(Shuffle(ai1, Some(0)), Seq(3, 1, 2)) + checkEvaluation(Shuffle(ai2, Some(0)), Seq(3, null, 1, null)) + checkEvaluation(Shuffle(ai3, Some(0)), Seq(null, 2, null, 4)) + checkEvaluation(Shuffle(ai4, Some(0)), Seq(null, null, null)) + checkEvaluation(Shuffle(ai5, Some(0)), Seq(1)) + checkEvaluation(Shuffle(ai6, Some(0)), Seq.empty) + checkEvaluation(Shuffle(ai7, Some(0)), null) + + // Non-primitive-type elements + val as0 = Literal.create(Seq("a", "b", "c", "d"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) + val as2 = Literal.create(Seq(null, "a", null, "c"), ArrayType(StringType, containsNull = true)) + val as3 = Literal.create(Seq("b", null, "d", null), ArrayType(StringType, containsNull = true)) + val as4 = Literal.create(Seq(null, null, null), ArrayType(StringType, containsNull = true)) + val as5 = Literal.create(Seq("a"), ArrayType(StringType, containsNull = false)) + val as6 = Literal.create(Seq.empty, ArrayType(StringType, containsNull = false)) + val as7 = Literal.create(null, ArrayType(StringType, containsNull = true)) + val aa = Literal.create( + Seq(Seq("a", "b"), Seq("c", "d"), Seq("e")), + ArrayType(ArrayType(StringType))) + + checkEvaluation(Shuffle(as0, Some(0)), Seq("d", "a", "b", "c")) + checkEvaluation(Shuffle(as1, Some(0)), Seq("c", "a", "b")) + checkEvaluation(Shuffle(as2, Some(0)), Seq("c", null, "a", null)) + checkEvaluation(Shuffle(as3, Some(0)), Seq(null, "b", null, "d")) + checkEvaluation(Shuffle(as4, Some(0)), Seq(null, null, null)) + checkEvaluation(Shuffle(as5, Some(0)), Seq("a")) + checkEvaluation(Shuffle(as6, Some(0)), Seq.empty) + checkEvaluation(Shuffle(as7, Some(0)), null) + checkEvaluation(Shuffle(aa, Some(0)), Seq(Seq("e"), Seq("a", "b"), Seq("c", "d"))) + + val r = new Random() + val seed1 = Some(r.nextLong()) + assert(evaluateWithoutCodegen(Shuffle(ai0, seed1)) === + evaluateWithoutCodegen(Shuffle(ai0, seed1))) + assert(evaluateWithGeneratedMutableProjection(Shuffle(ai0, seed1)) === + evaluateWithGeneratedMutableProjection(Shuffle(ai0, seed1))) + assert(evaluateWithUnsafeProjection(Shuffle(ai0, seed1)) === + evaluateWithUnsafeProjection(Shuffle(ai0, seed1))) + + val seed2 = Some(r.nextLong()) + assert(evaluateWithoutCodegen(Shuffle(ai0, seed1)) !== + evaluateWithoutCodegen(Shuffle(ai0, seed2))) + assert(evaluateWithGeneratedMutableProjection(Shuffle(ai0, seed1)) !== + evaluateWithGeneratedMutableProjection(Shuffle(ai0, seed2))) + assert(evaluateWithUnsafeProjection(Shuffle(ai0, seed1)) !== + evaluateWithUnsafeProjection(Shuffle(ai0, seed2))) + + val shuffle = Shuffle(ai0, seed1) + assert(shuffle.fastEquals(shuffle)) + assert(!shuffle.fastEquals(Shuffle(ai0, seed1))) + assert(!shuffle.fastEquals(shuffle.freshCopy())) + assert(!shuffle.fastEquals(Shuffle(ai0, seed2))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index de1d422856ba9..bcd0c946ab996 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3545,6 +3545,16 @@ object functions { */ def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) } + /** + * Returns a random permutation of the given array. + * + * @note The function is non-deterministic. + * + * @group collection_funcs + * @since 2.4.0 + */ + def shuffle(e: Column): Column = withExpr { Shuffle(e.expr) } + /** * Returns a reversed string or an array with reverse order of elements. * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 5a7bd45a4b5f3..299c96f74af22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1513,6 +1513,71 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + // Shuffle expressions should produce same results at retries in the same DataFrame. + private def checkShuffleResult(df: DataFrame): Unit = { + checkAnswer(df, df.collect()) + } + + test("shuffle function - array for primitive type not containing null") { + val idfNotContainsNull = Seq( + Seq(1, 9, 8, 7), + Seq(5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { + checkShuffleResult(idfNotContainsNull.select(shuffle('i))) + checkShuffleResult(idfNotContainsNull.selectExpr("shuffle(i)")) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeNotContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + idfNotContainsNull.cache() + testArrayOfPrimitiveTypeNotContainsNull() + } + + test("shuffle function - array for primitive type containing null") { + val idfContainsNull = Seq[Seq[Integer]]( + Seq(1, 9, 8, null, 7), + Seq(null, 5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeContainsNull(): Unit = { + checkShuffleResult(idfContainsNull.select(shuffle('i))) + checkShuffleResult(idfContainsNull.selectExpr("shuffle(i)")) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + idfContainsNull.cache() + testArrayOfPrimitiveTypeContainsNull() + } + + test("shuffle function - array for non-primitive type") { + val sdf = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("s") + + def testNonPrimitiveType(): Unit = { + checkShuffleResult(sdf.select(shuffle('s))) + checkShuffleResult(sdf.selectExpr("shuffle(s)")) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + sdf.cache() + testNonPrimitiveType() + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From c9bec1d3717e5aeb6d3ec95d4c78111bfc33e0ca Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 27 Jul 2018 08:57:48 -0700 Subject: [PATCH 1229/2461] [SPARK-24927][BUILD][BRANCH-2.3] The scope of snappy-java cannot be "provided" ## What changes were proposed in this pull request? Please see [SPARK-24927][1] for more details. [1]: https://issues.apache.org/jira/browse/SPARK-24927 ## How was this patch tested? Manually tested. Author: Cheng Lian Closes #21879 from liancheng/spark-24927. (cherry picked from commit d5f340f27706bd9767f23ac9726f904028916814) Signed-off-by: Xiao Li --- pom.xml | 1 - 1 file changed, 1 deletion(-) diff --git a/pom.xml b/pom.xml index d75db0f080e64..f320844faf222 100644 --- a/pom.xml +++ b/pom.xml @@ -538,7 +538,6 @@ org.xerial.snappy snappy-java ${snappy.version} - ${hadoop.deps.scope} org.lz4 From 0a0f68bae6c0a1bf30184b1e9ac6bf3805bd7511 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 28 Jul 2018 00:11:32 +0800 Subject: [PATCH 1230/2461] [SPARK-24881][SQL] New Avro option - compression ## What changes were proposed in this pull request? In the PR, I added new option for Avro datasource - `compression`. The option allows to specify compression codec for saved Avro files. This option is similar to `compression` option in another datasources like `JSON` and `CSV`. Also I added the SQL configs `spark.sql.avro.compression.codec` and `spark.sql.avro.deflate.level`. I put the configs into `SQLConf`. If the `compression` option is not specified by an user, the first SQL config is taken into account. ## How was this patch tested? I added new test which read meta info from written avro files and checks `avro.codec` property. Author: Maxim Gekk Closes #21837 from MaxGekk/avro-compression. --- .../spark/sql/avro/AvroFileFormat.scala | 7 +-- .../apache/spark/sql/avro/AvroOptions.scala | 11 +++++ .../org/apache/spark/sql/avro/AvroSuite.scala | 44 +++++++++++++++---- .../apache/spark/sql/internal/SQLConf.scala | 19 ++++++++ 4 files changed, 67 insertions(+), 14 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index c6b3c13be5140..1df1c8b4af2e9 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -117,11 +117,9 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { dataSchema, nullable = false, parsedOptions.recordName, parsedOptions.recordNamespace) AvroJob.setOutputKeySchema(job, outputAvroSchema) - val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec" - val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level" val COMPRESS_KEY = "mapred.output.compress" - spark.conf.get(AVRO_COMPRESSION_CODEC, "snappy") match { + parsedOptions.compression match { case "uncompressed" => log.info("writing uncompressed Avro records") job.getConfiguration.setBoolean(COMPRESS_KEY, false) @@ -132,8 +130,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, DataFileConstants.SNAPPY_CODEC) case "deflate" => - val deflateLevel = spark.conf.get( - AVRO_DEFLATE_LEVEL, Deflater.DEFAULT_COMPRESSION.toString).toInt + val deflateLevel = spark.sessionState.conf.avroDeflateLevel log.info(s"compressing Avro output using deflate (level=$deflateLevel)") job.getConfiguration.setBoolean(COMPRESS_KEY, true) job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, DataFileConstants.DEFLATE_CODEC) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala index cd9a911a14bfa..0f59007e7f72c 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala @@ -21,6 +21,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.internal.SQLConf /** * Options for Avro Reader and Writer stored in case insensitive manner. @@ -68,4 +69,14 @@ class AvroOptions( .map(_.toBoolean) .getOrElse(!ignoreFilesWithoutExtension) } + + /** + * The `compression` option allows to specify a compression codec used in write. + * Currently supported codecs are `uncompressed`, `snappy` and `deflate`. + * If the option is not set, the `spark.sql.avro.compression.codec` config is taken into + * account. If the former one is not set too, the `snappy` codec is used by default. + */ + val compression: String = { + parameters.get("compression").getOrElse(SQLConf.get.avroCompressionCodec) + } } diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index a93309e8ed9b5..2f478c76113eb 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -27,13 +27,14 @@ import scala.collection.JavaConverters._ import org.apache.avro.Schema import org.apache.avro.Schema.{Field, Type} -import org.apache.avro.file.DataFileWriter -import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord} +import org.apache.avro.file.{DataFileReader, DataFileWriter} +import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord} import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} import org.apache.commons.io.FileUtils import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ @@ -364,21 +365,19 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } } - test("write with compression") { + test("write with compression - sql configs") { withTempPath { dir => - val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec" - val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level" val uncompressDir = s"$dir/uncompress" val deflateDir = s"$dir/deflate" val snappyDir = s"$dir/snappy" val df = spark.read.format("avro").load(testAvro) - spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed") + spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "uncompressed") df.write.format("avro").save(uncompressDir) - spark.conf.set(AVRO_COMPRESSION_CODEC, "deflate") - spark.conf.set(AVRO_DEFLATE_LEVEL, "9") + spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "deflate") + spark.conf.set(SQLConf.AVRO_DEFLATE_LEVEL.key, "9") df.write.format("avro").save(deflateDir) - spark.conf.set(AVRO_COMPRESSION_CODEC, "snappy") + spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "snappy") df.write.format("avro").save(snappyDir) val uncompressSize = FileUtils.sizeOfDirectory(new File(uncompressDir)) @@ -890,4 +889,31 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } } } + + test("SPARK-24881: write with compression - avro options") { + def getCodec(dir: String): Option[String] = { + val files = new File(dir) + .listFiles() + .filter(_.isFile) + .filter(_.getName.endsWith("avro")) + files.map { file => + val reader = new DataFileReader(file, new GenericDatumReader[Any]()) + val r = reader.getMetaString("avro.codec") + r + }.map(v => if (v == "null") "uncompressed" else v).headOption + } + def checkCodec(df: DataFrame, dir: String, codec: String): Unit = { + val subdir = s"$dir/$codec" + df.write.option("compression", codec).format("avro").save(subdir) + assert(getCodec(subdir) == Some(codec)) + } + withTempPath { dir => + val path = dir.toString + val df = spark.read.format("avro").load(testAvro) + + checkCodec(df, path, "uncompressed") + checkCodec(df, path, "deflate") + checkCodec(df, path, "snappy") + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 53423e03b6b2b..a269e218c4efd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.internal import java.util.{Locale, NoSuchElementException, Properties, TimeZone} import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicReference +import java.util.zip.Deflater import scala.collection.JavaConverters._ import scala.collection.immutable @@ -1434,6 +1435,20 @@ object SQLConf { "This only takes effect when spark.sql.repl.eagerEval.enabled is set to true.") .intConf .createWithDefault(20) + + val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec") + .doc("Compression codec used in writing of AVRO files. Default codec is snappy.") + .stringConf + .checkValues(Set("uncompressed", "deflate", "snappy")) + .createWithDefault("snappy") + + val AVRO_DEFLATE_LEVEL = buildConf("spark.sql.avro.deflate.level") + .doc("Compression level for the deflate codec used in writing of AVRO files. " + + "Valid value must be in the range of from 1 to 9 inclusive or -1. " + + "The default value is -1 which corresponds to 6 level in the current implementation.") + .intConf + .checkValues((1 to 9).toSet + Deflater.DEFAULT_COMPRESSION) + .createWithDefault(Deflater.DEFAULT_COMPRESSION) } /** @@ -1820,6 +1835,10 @@ class SQLConf extends Serializable with Logging { def replEagerEvalTruncate: Int = getConf(SQLConf.REPL_EAGER_EVAL_TRUNCATE) + def avroCompressionCodec: String = getConf(SQLConf.AVRO_COMPRESSION_CODEC) + + def avroDeflateLevel: Int = getConf(SQLConf.AVRO_DEFLATE_LEVEL) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ From ee5a5a092517c2ec06b005b11001dd9a4ae60db6 Mon Sep 17 00:00:00 2001 From: Karthik Palaniappan Date: Fri, 27 Jul 2018 12:18:56 -0500 Subject: [PATCH 1231/2461] [SPARK-21960][STREAMING] Spark Streaming Dynamic Allocation should respect spark.executor.instances ## What changes were proposed in this pull request? Removes check that `spark.executor.instances` is set to 0 when using Streaming DRA. ## How was this patch tested? Manual tests My only concern with this PR is that `spark.executor.instances` (or the actual initial number of executors that the cluster manager gives Spark) can be outside of `spark.streaming.dynamicAllocation.minExecutors` to `spark.streaming.dynamicAllocation.maxExecutors`. I don't see a good way around that, because this code only runs after the SparkContext has been created. Author: Karthik Palaniappan Closes #19183 from karth295/master. --- .../scheduler/ExecutorAllocationManager.scala | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala index 7b29b40668def..8717555dea491 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala @@ -26,7 +26,7 @@ import org.apache.spark.streaming.util.RecurringTimer import org.apache.spark.util.{Clock, Utils} /** - * Class that manages executor allocated to a StreamingContext, and dynamically request or kill + * Class that manages executors allocated to a StreamingContext, and dynamically requests or kills * executors based on the statistics of the streaming computation. This is different from the core * dynamic allocation policy; the core policy relies on executors being idle for a while, but the * micro-batch model of streaming prevents any particular executors from being idle for a long @@ -43,6 +43,10 @@ import org.apache.spark.util.{Clock, Utils} * * This features should ideally be used in conjunction with backpressure, as backpressure ensures * system stability, while executors are being readjusted. + * + * Note that an initial set of executors (spark.executor.instances) was allocated when the + * SparkContext was created. This class scales executors up/down after the StreamingContext + * has started. */ private[streaming] class ExecutorAllocationManager( client: ExecutorAllocationClient, @@ -202,12 +206,7 @@ private[streaming] object ExecutorAllocationManager extends Logging { val MAX_EXECUTORS_KEY = "spark.streaming.dynamicAllocation.maxExecutors" def isDynamicAllocationEnabled(conf: SparkConf): Boolean = { - val numExecutor = conf.getInt("spark.executor.instances", 0) val streamingDynamicAllocationEnabled = conf.getBoolean(ENABLED_KEY, false) - if (numExecutor != 0 && streamingDynamicAllocationEnabled) { - throw new IllegalArgumentException( - "Dynamic Allocation for streaming cannot be enabled while spark.executor.instances is set.") - } if (Utils.isDynamicAllocationEnabled(conf) && streamingDynamicAllocationEnabled) { throw new IllegalArgumentException( """ @@ -217,7 +216,7 @@ private[streaming] object ExecutorAllocationManager extends Logging { """.stripMargin) } val testing = conf.getBoolean("spark.streaming.dynamicAllocation.testing", false) - numExecutor == 0 && streamingDynamicAllocationEnabled && (!Utils.isLocalMaster(conf) || testing) + streamingDynamicAllocationEnabled && (!Utils.isLocalMaster(conf) || testing) } def createIfEnabled( From 5828f41a52c446b774a909e96eff8d8c5831b394 Mon Sep 17 00:00:00 2001 From: Hieu Huynh <“Hieu.huynh@oath.com”> Date: Fri, 27 Jul 2018 12:34:14 -0500 Subject: [PATCH 1232/2461] [SPARK-13343] speculative tasks that didn't commit shouldn't be marked as success MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Description** Currently Speculative tasks that didn't commit can show up as success (depending on timing of commit). This is a bit confusing because that task didn't really succeed in the sense it didn't write anything. I think these tasks should be marked as KILLED or something that is more obvious to the user exactly what happened. it is happened to hit the timing where it got a commit denied exception then it shows up as failed and counts against your task failures. It shouldn't count against task failures since that failure really doesn't matter. MapReduce handles these situation so perhaps we can look there for a model. unknown **How can this issue happen?** When both attempts of a task finish before the driver sends command to kill one of them, both of them send the status update FINISHED to the driver. The driver calls TaskSchedulerImpl to handle one successful task at a time. When it handles the first successful task, it sends the command to kill the other copy of the task, however, because that task is already finished, the executor will ignore the command. After finishing handling the first attempt, it processes the second one, although all actions on the result of this task are skipped, this copy of the task is still marked as SUCCESS. As a result, even though this issue does not affect the result of the job, it might cause confusing to user because both of them appear to be successful. **How does this PR fix the issue?** The simple way to fix this issue is that when taskSetManager handles successful task, it checks if any other attempt succeeded. If this is the case, it will call handleFailedTask with state==KILLED and reason==TaskKilled(“another attempt succeeded”) to handle this task as begin killed. **How was this patch tested?** I tested this manually by running applications, that caused the issue before, a few times, and observed that the issue does not happen again. Also, I added a unit test in TaskSetManagerSuite to test that if we call handleSuccessfulTask to handle status update for 2 copies of a task, only the one that is handled first will be mark as SUCCESS Author: Hieu Huynh <“Hieu.huynh@oath.com”> Author: hthuynh2 Closes #21653 from hthuynh2/SPARK_13343. --- .../spark/scheduler/TaskSetManager.scala | 19 ++++- .../spark/scheduler/TaskSetManagerSuite.scala | 70 +++++++++++++++++++ 2 files changed, 88 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 0b21256ab6cce..8b77641e85b76 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -29,7 +29,7 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.internal.{config, Logging} import org.apache.spark.scheduler.SchedulingMode._ -import org.apache.spark.util.{AccumulatorV2, Clock, SystemClock, Utils} +import org.apache.spark.util.{AccumulatorV2, Clock, LongAccumulator, SystemClock, Utils} import org.apache.spark.util.collection.MedianHeap /** @@ -728,6 +728,23 @@ private[spark] class TaskSetManager( def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]): Unit = { val info = taskInfos(tid) val index = info.index + // Check if any other attempt succeeded before this and this attempt has not been handled + if (successful(index) && killedByOtherAttempt.contains(tid)) { + // Undo the effect on calculatedTasks and totalResultSize made earlier when + // checking if can fetch more results + calculatedTasks -= 1 + val resultSizeAcc = result.accumUpdates.find(a => + a.name == Some(InternalAccumulator.RESULT_SIZE)) + if (resultSizeAcc.isDefined) { + totalResultSize -= resultSizeAcc.get.asInstanceOf[LongAccumulator].value + } + + // Handle this task as a killed task + handleFailedTask(tid, TaskState.KILLED, + TaskKilled("Finish but did not commit due to another attempt succeeded")) + return + } + info.markFinished(TaskState.FINISHED, clock.getTimeMillis()) if (speculationEnabled) { successfulTaskDurations.insert(info.duration) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 206b9f47eed4f..cf05434aee301 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -1532,4 +1532,74 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val valueSer = SparkEnv.get.serializer.newInstance() new DirectTaskResult[Int](valueSer.serialize(id), accumUpdates) } + + test("SPARK-13343 speculative tasks that didn't commit shouldn't be marked as success") { + sc = new SparkContext("local", "test") + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = FakeTask.createTaskSet(4) + // Set the speculation multiplier to be 0 so speculative tasks are launched immediately + sc.conf.set("spark.speculation.multiplier", "0.0") + sc.conf.set("spark.speculation", "true") + val clock = new ManualClock() + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) + val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => + task.metrics.internalAccums + } + // Offer resources for 4 tasks to start + for ((k, v) <- List( + "exec1" -> "host1", + "exec1" -> "host1", + "exec2" -> "host2", + "exec2" -> "host2")) { + val taskOption = manager.resourceOffer(k, v, NO_PREF) + assert(taskOption.isDefined) + val task = taskOption.get + assert(task.executorId === k) + } + assert(sched.startedTasks.toSet === Set(0, 1, 2, 3)) + clock.advance(1) + // Complete the 3 tasks and leave 1 task in running + for (id <- Set(0, 1, 2)) { + manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id))) + assert(sched.endedTasks(id) === Success) + } + // checkSpeculatableTasks checks that the task runtime is greater than the threshold for + // speculating. Since we use a threshold of 0 for speculation, tasks need to be running for + // > 0ms, so advance the clock by 1ms here. + clock.advance(1) + assert(manager.checkSpeculatableTasks(0)) + assert(sched.speculativeTasks.toSet === Set(3)) + + // Offer resource to start the speculative attempt for the running task + val taskOption5 = manager.resourceOffer("exec1", "host1", NO_PREF) + assert(taskOption5.isDefined) + val task5 = taskOption5.get + assert(task5.index === 3) + assert(task5.taskId === 4) + assert(task5.executorId === "exec1") + assert(task5.attemptNumber === 1) + sched.backend = mock(classOf[SchedulerBackend]) + sched.dagScheduler.stop() + sched.dagScheduler = mock(classOf[DAGScheduler]) + // Complete one attempt for the running task + val result = createTaskResult(3, accumUpdatesByTask(3)) + manager.handleSuccessfulTask(3, result) + // There is a race between the scheduler asking to kill the other task, and that task + // actually finishing. We simulate what happens if the other task finishes before we kill it. + verify(sched.backend).killTask(4, "exec1", true, "another attempt succeeded") + manager.handleSuccessfulTask(4, result) + + val info3 = manager.taskInfos(3) + val info4 = manager.taskInfos(4) + assert(info3.successful) + assert(info4.killed) + verify(sched.dagScheduler).taskEnded( + manager.tasks(3), + TaskKilled("Finish but did not commit due to another attempt succeeded"), + null, + Seq.empty, + info4) + verify(sched.dagScheduler).taskEnded(manager.tasks(3), Success, result.value(), + result.accumUpdates, info3) + } } From 10f1f196595df66cb82d1fb9e27cc7ef0a176766 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 27 Jul 2018 13:47:33 -0700 Subject: [PATCH 1233/2461] [SPARK-21274][SQL] Implement EXCEPT ALL clause. ## What changes were proposed in this pull request? Implements EXCEPT ALL clause through query rewrites using existing operators in Spark. In this PR, an internal UDTF (replicate_rows) is added to aid in preserving duplicate rows. Please refer to [Link](https://drive.google.com/open?id=1nyW0T0b_ajUduQoPgZLAsyHK8s3_dko3ulQuxaLpUXE) for the design. **Note** This proposed UDTF is kept as a internal function that is purely used to aid with this particular rewrite to give us flexibility to change to a more generalized UDTF in future. Input Query ``` SQL SELECT c1 FROM ut1 EXCEPT ALL SELECT c1 FROM ut2 ``` Rewritten Query ```SQL SELECT c1 FROM ( SELECT replicate_rows(sum_val, c1) FROM ( SELECT c1, sum_val FROM ( SELECT c1, sum(vcol) AS sum_val FROM ( SELECT 1L as vcol, c1 FROM ut1 UNION ALL SELECT -1L as vcol, c1 FROM ut2 ) AS union_all GROUP BY union_all.c1 ) WHERE sum_val > 0 ) ) ``` ## How was this patch tested? Added test cases in SQLQueryTestSuite, DataFrameSuite and SetOperationSuite Author: Dilip Biswal Closes #21857 from dilipbiswal/dkb_except_all_final. --- python/pyspark/sql/dataframe.py | 25 ++ .../sql/catalyst/analysis/Analyzer.scala | 5 +- .../sql/catalyst/analysis/TypeCoercion.scala | 12 +- .../UnsupportedOperationChecker.scala | 2 +- .../sql/catalyst/expressions/generators.scala | 26 ++ .../sql/catalyst/optimizer/Optimizer.scala | 61 +++- .../optimizer/ReplaceExceptWithFilter.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../plans/logical/basicLogicalOperators.scala | 7 +- .../optimizer/SetOperationSuite.scala | 24 +- .../catalyst/parser/ErrorParserSuite.scala | 3 - .../sql/catalyst/parser/PlanParserSuite.scala | 1 - .../scala/org/apache/spark/sql/Dataset.scala | 16 + .../spark/sql/execution/SparkStrategies.scala | 6 +- .../resources/sql-tests/inputs/except-all.sql | 146 ++++++++ .../sql-tests/results/except-all.sql.out | 319 ++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 70 +++- 17 files changed, 708 insertions(+), 19 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/except-all.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/except-all.sql.out diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index c40aea9bcef0a..b2e0a5b2390c2 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -293,6 +293,31 @@ def explain(self, extended=False): else: print(self._jdf.queryExecution().simpleString()) + @since(2.4) + def exceptAll(self, other): + """Return a new :class:`DataFrame` containing rows in this :class:`DataFrame` but + not in another :class:`DataFrame` while preserving duplicates. + + This is equivalent to `EXCEPT ALL` in SQL. + + >>> df1 = spark.createDataFrame( + ... [("a", 1), ("a", 1), ("a", 1), ("a", 2), ("b", 3), ("c", 4)], ["C1", "C2"]) + >>> df2 = spark.createDataFrame([("a", 1), ("b", 3)], ["C1", "C2"]) + + >>> df1.exceptAll(df2).show() + +---+---+ + | C1| C2| + +---+---+ + | a| 1| + | a| 1| + | a| 2| + | c| 4| + +---+---+ + + Also as standard in SQL, this function resolves columns by position (not by name). + """ + return DataFrame(self._jdf.exceptAll(other._jdf), self.sql_ctx) + @since(1.3) def isLocal(self): """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d18509fe8d91b..8abb1c70d4919 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -916,9 +916,8 @@ class Analyzer( j.copy(right = dedupRight(left, right)) case i @ Intersect(left, right) if !i.duplicateResolved => i.copy(right = dedupRight(left, right)) - case i @ Except(left, right) if !i.duplicateResolved => - i.copy(right = dedupRight(left, right)) - + case e @ Except(left, right, _) if !e.duplicateResolved => + e.copy(right = dedupRight(left, right)) // When resolve `SortOrder`s in Sort based on child, don't report errors as // we still have chance to resolve it based on its descendants case s @ Sort(ordering, global, child) if child.resolved && !s.resolved => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 6bdb639011a17..f9edca53d571e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -319,11 +319,17 @@ object TypeCoercion { object WidenSetOperationTypes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case s @ SetOperation(left, right) if s.childrenResolved && - left.output.length == right.output.length && !s.resolved => + case s @ Except(left, right, isAll) if s.childrenResolved && + left.output.length == right.output.length && !s.resolved => val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) assert(newChildren.length == 2) - s.makeCopy(Array(newChildren.head, newChildren.last)) + Except(newChildren.head, newChildren.last, isAll) + + case s @ Intersect(left, right) if s.childrenResolved && + left.output.length == right.output.length && !s.resolved => + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) + assert(newChildren.length == 2) + Intersect(newChildren.head, newChildren.last) case s: Union if s.childrenResolved && s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index f68df5d29b545..c9a3ee47a02be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -306,7 +306,7 @@ object UnsupportedOperationChecker { case u: Union if u.children.map(_.isStreaming).distinct.size == 2 => throwError("Union between streaming and batch DataFrames/Datasets is not supported") - case Except(left, right) if right.isStreaming => + case Except(left, right, _) if right.isStreaming => throwError("Except on a streaming DataFrame/Dataset on the right is not supported") case Intersect(left, right) if left.isStreaming && right.isStreaming => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index b6e0d364d3a96..d6e67b9ac3d10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -223,6 +223,32 @@ case class Stack(children: Seq[Expression]) extends Generator { } } +/** + * Replicate the row N times. N is specified as the first argument to the function. + * This is an internal function solely used by optimizer to rewrite EXCEPT ALL AND + * INTERSECT ALL queries. + */ +case class ReplicateRows(children: Seq[Expression]) extends Generator with CodegenFallback { + private lazy val numColumns = children.length - 1 // remove the multiplier value from output. + + override def elementSchema: StructType = + StructType(children.tail.zipWithIndex.map { + case (e, index) => StructField(s"col$index", e.dataType) + }) + + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { + val numRows = children.head.eval(input).asInstanceOf[Long] + val values = children.tail.map(_.eval(input)).toArray + Range.Long(0, numRows, 1).map { _ => + val fields = new Array[Any](numColumns) + for (col <- 0 until numColumns) { + fields.update(col, values(col)) + } + InternalRow(fields: _*) + } + } +} + /** * Wrapper around another generator to specify outer behavior. This is used to implement functions * such as explode_outer. This expression gets replaced during analysis. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3c264eb8586b5..193f6591c9a8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -135,6 +135,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) Batch("Subquery", Once, OptimizeSubqueries) :: Batch("Replace Operators", fixedPoint, + RewriteExcepAll, ReplaceIntersectWithSemiJoin, ReplaceExceptWithFilter, ReplaceExceptWithAntiJoin, @@ -1422,13 +1423,71 @@ object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] { */ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Except(left, right) => + case Except(left, right, false) => assert(left.output.size == right.output.size) val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) } Distinct(Join(left, right, LeftAnti, joinCond.reduceLeftOption(And))) } } +/** + * Replaces logical [[Except]] operator using a combination of Union, Aggregate + * and Generate operator. + * + * Input Query : + * {{{ + * SELECT c1 FROM ut1 EXCEPT ALL SELECT c1 FROM ut2 + * }}} + * + * Rewritten Query: + * {{{ + * SELECT c1 + * FROM ( + * SELECT replicate_rows(sum_val, c1) + * FROM ( + * SELECT c1, sum_val + * FROM ( + * SELECT c1, sum(vcol) AS sum_val + * FROM ( + * SELECT 1L as vcol, c1 FROM ut1 + * UNION ALL + * SELECT -1L as vcol, c1 FROM ut2 + * ) AS union_all + * GROUP BY union_all.c1 + * ) + * WHERE sum_val > 0 + * ) + * ) + * }}} + */ + +object RewriteExcepAll extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Except(left, right, true) => + assert(left.output.size == right.output.size) + + val newColumnLeft = Alias(Literal(1L), "vcol")() + val newColumnRight = Alias(Literal(-1L), "vcol")() + val modifiedLeftPlan = Project(Seq(newColumnLeft) ++ left.output, left) + val modifiedRightPlan = Project(Seq(newColumnRight) ++ right.output, right) + val unionPlan = Union(modifiedLeftPlan, modifiedRightPlan) + val aggSumCol = + Alias(AggregateExpression(Sum(unionPlan.output.head.toAttribute), Complete, false), "sum")() + val aggOutputColumns = left.output ++ Seq(aggSumCol) + val aggregatePlan = Aggregate(left.output, aggOutputColumns, unionPlan) + val filteredAggPlan = Filter(GreaterThan(aggSumCol.toAttribute, Literal(0L)), aggregatePlan) + val genRowPlan = Generate( + ReplicateRows(Seq(aggSumCol.toAttribute) ++ left.output), + unrequiredChildIndex = Nil, + outer = false, + qualifier = None, + left.output, + filteredAggPlan + ) + Project(left.output, genRowPlan) + } +} + /** * Removes literals from group expressions in [[Aggregate]], as they have no effect to the result * but only makes the grouping key bigger. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala index 45edf266bbce4..efd3944eba7f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala @@ -46,7 +46,7 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] { } plan.transform { - case e @ Except(left, right) if isEligible(left, right) => + case e @ Except(left, right, false) if isEligible(left, right) => val newCondition = transformCondition(left, skipProject(right)) newCondition.map { c => Distinct(Filter(Not(c), left)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 49f578a24aaeb..8b3c0686181fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -537,7 +537,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case SqlBaseParser.INTERSECT => Intersect(left, right) case SqlBaseParser.EXCEPT if all => - throw new ParseException("EXCEPT ALL is not supported.", ctx) + Except(left, right, isAll = true) case SqlBaseParser.EXCEPT => Except(left, right) case SqlBaseParser.SETMINUS if all => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index ea5a9b8ed5542..498a13a62bd22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -183,8 +183,11 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation } } -case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { - +case class Except( + left: LogicalPlan, + right: LogicalPlan, + isAll: Boolean = false) extends SetOperation(left, right) { + override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) "All" else "" ) /** We don't use right.output because those rows get excluded from the set. */ override def output: Seq[Attribute] = left.output diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index aa8841109329c..f002aa3aacaba 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.{Alias, GreaterThan, Literal, ReplicateRows} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -144,4 +144,26 @@ class SetOperationSuite extends PlanTest { Distinct(Union(query3 :: query4 :: Nil))).analyze comparePlans(distinctUnionCorrectAnswer2, optimized2) } + + test("EXCEPT ALL rewrite") { + val input = Except(testRelation, testRelation2, isAll = true) + val rewrittenPlan = RewriteExcepAll(input) + + val planFragment = testRelation.select(Literal(1L).as("vcol"), 'a, 'b, 'c) + .union(testRelation2.select(Literal(-1L).as("vcol"), 'd, 'e, 'f)) + .groupBy('a, 'b, 'c)('a, 'b, 'c, sum('vcol).as("sum")) + .where(GreaterThan('sum, Literal(0L))).analyze + val multiplerAttr = planFragment.output.last + val output = planFragment.output.dropRight(1) + val expectedPlan = Project(output, + Generate( + ReplicateRows(Seq(multiplerAttr) ++ output), + Nil, + false, + None, + output, + planFragment + )) + comparePlans(expectedPlan, rewrittenPlan) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala index f67697eb86c26..baaf01800b33b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala @@ -58,8 +58,5 @@ class ErrorParserSuite extends SparkFunSuite { intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0, "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", "^^^") - intercept("select * from r except all select * from t", 1, 0, - "EXCEPT ALL is not supported", - "^^^") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index fb51376c6163f..629e3c4f3fcfb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -65,7 +65,6 @@ class PlanParserSuite extends AnalysisTest { assertEqual("select * from a union distinct select * from b", Distinct(a.union(b))) assertEqual("select * from a union all select * from b", a.union(b)) assertEqual("select * from a except select * from b", a.except(b)) - intercept("select * from a except all select * from b", "EXCEPT ALL is not supported.") assertEqual("select * from a except distinct select * from b", a.except(b)) assertEqual("select * from a minus select * from b", a.except(b)) intercept("select * from a minus all select * from b", "MINUS ALL is not supported.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b63235ec827c9..e6a3b0adcdaa6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1948,6 +1948,22 @@ class Dataset[T] private[sql]( Except(planWithBarrier, other.planWithBarrier) } + /** + * Returns a new Dataset containing rows in this Dataset but not in another Dataset while + * preserving the duplicates. + * This is equivalent to `EXCEPT ALL` in SQL. + * + * @note Equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. Also as standard in + * SQL, this function resolves columns by position (not by name). + * + * @group typedrel + * @since 2.4.0 + */ + def exceptAll(other: Dataset[T]): Dataset[T] = withSetOperator { + Except(planWithBarrier, other.planWithBarrier, isAll = true) + } + /** * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement), * using a user-supplied seed. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0c4ea857fd1d7..3f5fd3dbb9e2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -532,9 +532,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Intersect(left, right) => throw new IllegalStateException( "logical intersect operator should have been replaced by semi-join in the optimizer") - case logical.Except(left, right) => + case logical.Except(left, right, false) => throw new IllegalStateException( "logical except operator should have been replaced by anti-join in the optimizer") + case logical.Except(left, right, true) => + throw new IllegalStateException( + "logical except (all) operator should have been replaced by union, aggregate" + + "and generate operators in the optimizer") case logical.DeserializeToObject(deserializer, objAttr, child) => execution.DeserializeToObjectExec(deserializer, objAttr, planLater(child)) :: Nil diff --git a/sql/core/src/test/resources/sql-tests/inputs/except-all.sql b/sql/core/src/test/resources/sql-tests/inputs/except-all.sql new file mode 100644 index 0000000000000..08b9a437b3d14 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/except-all.sql @@ -0,0 +1,146 @@ +CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + (0), (1), (2), (2), (2), (2), (3), (null), (null) AS tab1(c1); +CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES + (1), (2), (2), (3), (5), (5), (null) AS tab2(c1); +CREATE TEMPORARY VIEW tab3 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (1, 3), + (2, 3), + (2, 2) + AS tab3(k, v); +CREATE TEMPORARY VIEW tab4 AS SELECT * FROM VALUES + (1, 2), + (2, 3), + (2, 2), + (2, 2), + (2, 20) + AS tab4(k, v); + +-- Basic ExceptAll +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2; + +-- ExceptAll same table in both branches +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2 WHERE c1 IS NOT NULL; + +-- Empty left relation +SELECT * FROM tab1 WHERE c1 > 5 +EXCEPT ALL +SELECT * FROM tab2; + +-- Empty right relation +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2 WHERE c1 > 6; + +-- Type Coerced ExceptAll +SELECT * FROM tab1 +EXCEPT ALL +SELECT CAST(1 AS BIGINT); + +-- Error as types of two side are not compatible +SELECT * FROM tab1 +EXCEPT ALL +SELECT array(1); + +-- Basic +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4; + +-- Basic +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3; + +-- ExceptAll + Intersect +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3 +INTERSECT DISTINCT +SELECT * FROM tab4; + +-- ExceptAll + Except +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4; + +-- Chain of set operations +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +UNION ALL +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4; + +-- Mismatch on number of columns across both branches +SELECT k FROM tab3 +EXCEPT ALL +SELECT k, v FROM tab4; + +-- Chain of set operations +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +UNION +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4; + +-- Chain of set operations +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +EXCEPT DISTINCT +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4; + +-- Join under except all. Should produce empty resultset since both left and right sets +-- are same. +SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +EXCEPT ALL +SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k); + +-- Join under except all (2) +SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +EXCEPT ALL +SELECT * +FROM (SELECT tab4.v AS k, + tab3.k AS v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k); + +-- Group by under ExceptAll +SELECT v FROM tab3 GROUP BY v +EXCEPT ALL +SELECT k FROM tab4 GROUP BY k; + +-- Clean-up +DROP VIEW IF EXISTS tab1; +DROP VIEW IF EXISTS tab2; +DROP VIEW IF EXISTS tab3; +DROP VIEW IF EXISTS tab4; diff --git a/sql/core/src/test/resources/sql-tests/results/except-all.sql.out b/sql/core/src/test/resources/sql-tests/results/except-all.sql.out new file mode 100644 index 0000000000000..2a21c1505350c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/except-all.sql.out @@ -0,0 +1,319 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 25 + + +-- !query 0 +CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + (0), (1), (2), (2), (2), (2), (3), (null), (null) AS tab1(c1) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES + (1), (2), (2), (3), (5), (5), (null) AS tab2(c1) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TEMPORARY VIEW tab3 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (1, 3), + (2, 3), + (2, 2) + AS tab3(k, v) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +CREATE TEMPORARY VIEW tab4 AS SELECT * FROM VALUES + (1, 2), + (2, 3), + (2, 2), + (2, 2), + (2, 20) + AS tab4(k, v) +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2 +-- !query 4 schema +struct +-- !query 4 output +0 +2 +2 +NULL + + +-- !query 5 +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2 WHERE c1 IS NOT NULL +-- !query 5 schema +struct +-- !query 5 output +0 +2 +2 +NULL +NULL + + +-- !query 6 +SELECT * FROM tab1 WHERE c1 > 5 +EXCEPT ALL +SELECT * FROM tab2 +-- !query 6 schema +struct +-- !query 6 output + + + +-- !query 7 +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2 WHERE c1 > 6 +-- !query 7 schema +struct +-- !query 7 output +0 +1 +2 +2 +2 +2 +3 +NULL +NULL + + +-- !query 8 +SELECT * FROM tab1 +EXCEPT ALL +SELECT CAST(1 AS BIGINT) +-- !query 8 schema +struct +-- !query 8 output +0 +2 +2 +2 +2 +3 +NULL +NULL + + +-- !query 9 +SELECT * FROM tab1 +EXCEPT ALL +SELECT array(1) +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +ExceptAll can only be performed on tables with the compatible column types. array <> int at the first column of the second table; + + +-- !query 10 +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +-- !query 10 schema +struct +-- !query 10 output +1 2 +1 3 + + +-- !query 11 +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3 +-- !query 11 schema +struct +-- !query 11 output +2 2 +2 20 + + +-- !query 12 +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3 +INTERSECT DISTINCT +SELECT * FROM tab4 +-- !query 12 schema +struct +-- !query 12 output +2 2 +2 20 + + +-- !query 13 +SELECT * FROM tab4 +EXCEPT ALL +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4 +-- !query 13 schema +struct +-- !query 13 output + + + +-- !query 14 +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +UNION ALL +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4 +-- !query 14 schema +struct +-- !query 14 output +1 3 + + +-- !query 15 +SELECT k FROM tab3 +EXCEPT ALL +SELECT k, v FROM tab4 +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.AnalysisException +ExceptAll can only be performed on tables with the same number of columns, but the first table has 1 columns and the second table has 2 columns; + + +-- !query 16 +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +UNION +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4 +-- !query 16 schema +struct +-- !query 16 output +1 3 + + +-- !query 17 +SELECT * FROM tab3 +EXCEPT ALL +SELECT * FROM tab4 +EXCEPT DISTINCT +SELECT * FROM tab3 +EXCEPT DISTINCT +SELECT * FROM tab4 +-- !query 17 schema +struct +-- !query 17 output + + + +-- !query 18 +SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +EXCEPT ALL +SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +-- !query 18 schema +struct +-- !query 18 output + + + +-- !query 19 +SELECT * +FROM (SELECT tab3.k, + tab4.v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +EXCEPT ALL +SELECT * +FROM (SELECT tab4.v AS k, + tab3.k AS v + FROM tab3 + JOIN tab4 + ON tab3.k = tab4.k) +-- !query 19 schema +struct +-- !query 19 output +1 2 +1 2 +1 2 +2 20 +2 20 +2 3 +2 3 + + +-- !query 20 +SELECT v FROM tab3 GROUP BY v +EXCEPT ALL +SELECT k FROM tab4 GROUP BY k +-- !query 20 schema +struct +-- !query 20 output +3 + + +-- !query 21 +DROP VIEW IF EXISTS tab1 +-- !query 21 schema +struct<> +-- !query 21 output + + + +-- !query 22 +DROP VIEW IF EXISTS tab2 +-- !query 22 schema +struct<> +-- !query 22 output + + + +-- !query 23 +DROP VIEW IF EXISTS tab3 +-- !query 23 schema +struct<> +-- !query 23 output + + + +-- !query 24 +DROP VIEW IF EXISTS tab4 +-- !query 24 schema +struct<> +-- !query 24 output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 9cf8c47fa6cf1..af0735920cc29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExc import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} -import org.apache.spark.sql.test.SQLTestData.TestData2 +import org.apache.spark.sql.test.SQLTestData.{NullInts, NullStrings, TestData2} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -629,6 +629,74 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df4.schema.forall(!_.nullable)) } + test("except all") { + checkAnswer( + lowerCaseData.exceptAll(upperCaseData), + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.exceptAll(lowerCaseData), Nil) + checkAnswer(upperCaseData.exceptAll(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.exceptAll(nullInts.filter("0 = 1")), + nullInts) + checkAnswer( + nullInts.exceptAll(nullInts), + Nil) + + // check that duplicate values are preserved + checkAnswer( + allNulls.exceptAll(allNulls.filter("0 = 1")), + Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) + checkAnswer( + allNulls.exceptAll(allNulls.limit(2)), + Row(null) :: Row(null) :: Nil) + + // check that duplicates are retained. + val df = spark.sparkContext.parallelize( + NullStrings(1, "id1") :: + NullStrings(1, "id1") :: + NullStrings(2, "id1") :: + NullStrings(3, null) :: Nil).toDF("id", "value") + + checkAnswer( + df.exceptAll(df.filter("0 = 1")), + Row(1, "id1") :: + Row(1, "id1") :: + Row(2, "id1") :: + Row(3, null) :: Nil) + + // check if the empty set on the left side works + checkAnswer( + allNulls.filter("0 = 1").exceptAll(allNulls), + Nil) + + } + + test("exceptAll - nullability") { + val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.exceptAll(nullInts) + checkAnswer(df1, Row(11) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.exceptAll(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil) + assert(df2.schema.forall(_.nullable)) + + val df3 = nullInts.exceptAll(nullInts) + checkAnswer(df3, Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.exceptAll(nonNullableInts) + checkAnswer(df4, Nil) + assert(df4.schema.forall(!_.nullable)) + } + test("intersect") { checkAnswer( lowerCaseData.intersect(lowerCaseData), From 34ebcc6b5246c1a47e6d3b2dbb23e368de25219e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 27 Jul 2018 15:34:06 -0700 Subject: [PATCH 1234/2461] [MINOR] Improve documentation for HiveStringType's The diff should be self-explanatory. Author: Reynold Xin Closes #21897 from rxin/hivestringtypedoc. --- .../scala/org/apache/spark/sql/types/HiveStringType.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala index e0bca937d1d84..4eb3226c5786e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala @@ -56,14 +56,18 @@ object HiveStringType { } /** - * Hive char type. + * Hive char type. Similar to other HiveStringType's, these datatypes should only used for + * parsing, and should NOT be used anywhere else. Any instance of these data types should be + * replaced by a [[StringType]] before analysis. */ case class CharType(length: Int) extends HiveStringType { override def simpleString: String = s"char($length)" } /** - * Hive varchar type. + * Hive varchar type. Similar to other HiveStringType's, these datatypes should only used for + * parsing, and should NOT be used anywhere else. Any instance of these data types should be + * replaced by a [[StringType]] before analysis. */ case class VarcharType(length: Int) extends HiveStringType { override def simpleString: String = s"varchar($length)" From 6424b146c91fdca734a3ec972067e8e1f88e8b9e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 27 Jul 2018 17:24:55 -0700 Subject: [PATCH 1235/2461] [MINOR] Update docs for functions.scala to make it clear not all the built-in functions are defined there The title summarizes the change. Author: Reynold Xin Closes #21318 from rxin/functions. --- .../scala/org/apache/spark/sql/functions.scala | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index bcd0c946ab996..277295816c978 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -39,7 +39,21 @@ import org.apache.spark.util.Utils /** - * Functions available for DataFrame operations. + * Commonly used functions available for DataFrame operations. Using functions defined here provides + * a little bit more compile-time safety to make sure the function exists. + * + * Spark also includes more built-in functions that are less common and are not defined here. + * You can still access them (and all the functions defined here) using the `functions.expr()` API + * and calling them through a SQL expression string. You can find the entire list of functions for + * the latest version of Spark at https://spark.apache.org/docs/latest/api/sql/index.html. + * + * As an example, `isnan` is a function that is defined here. You can use `isnan(col("myCol"))` + * to invoke the `isnan` function. This way the programming language's compiler ensures `isnan` + * exists and is of the proper form. You can also use `expr("isnan(myCol)")` function to invoke the + * same function. In this case, Spark itself will ensure `isnan` exists when it analyzes the query. + * + * `regr_count` is an example of a function that is built-in but not defined here, because it is + * less commonly used. To invoke it, use `expr("regr_count(yCol, xCol)")`. * * @groupname udf_funcs UDF functions * @groupname agg_funcs Aggregate functions From e8752095a00aba453a92bc822131c001602f0829 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Sat, 28 Jul 2018 13:41:07 +0800 Subject: [PATCH 1236/2461] [SPARK-24624][SQL][PYTHON] Support mixture of Python UDF and Scalar Pandas UDF ## What changes were proposed in this pull request? This PR add supports for using mixed Python UDF and Scalar Pandas UDF, in the following two cases: (1) ``` from pyspark.sql.functions import udf, pandas_udf udf('int') def f1(x): return x + 1 pandas_udf('int') def f2(x): return x + 1 df = spark.range(0, 1).toDF('v') \ .withColumn('foo', f1(col('v'))) \ .withColumn('bar', f2(col('v'))) ``` QueryPlan: ``` >>> df.explain(True) == Parsed Logical Plan == 'Project [v#2L, foo#5, f2('v) AS bar#9] +- AnalysisBarrier +- Project [v#2L, f1(v#2L) AS foo#5] +- Project [id#0L AS v#2L] +- Range (0, 1, step=1, splits=Some(4)) == Analyzed Logical Plan == v: bigint, foo: int, bar: int Project [v#2L, foo#5, f2(v#2L) AS bar#9] +- Project [v#2L, f1(v#2L) AS foo#5] +- Project [id#0L AS v#2L] +- Range (0, 1, step=1, splits=Some(4)) == Optimized Logical Plan == Project [id#0L AS v#2L, f1(id#0L) AS foo#5, f2(id#0L) AS bar#9] +- Range (0, 1, step=1, splits=Some(4)) == Physical Plan == *(2) Project [id#0L AS v#2L, pythonUDF0#13 AS foo#5, pythonUDF0#14 AS bar#9] +- ArrowEvalPython [f2(id#0L)], [id#0L, pythonUDF0#13, pythonUDF0#14] +- BatchEvalPython [f1(id#0L)], [id#0L, pythonUDF0#13] +- *(1) Range (0, 1, step=1, splits=4) ``` (2) ``` from pyspark.sql.functions import udf, pandas_udf udf('int') def f1(x): return x + 1 pandas_udf('int') def f2(x): return x + 1 df = spark.range(0, 1).toDF('v') df = df.withColumn('foo', f2(f1(df['v']))) ``` QueryPlan: ``` >>> df.explain(True) == Parsed Logical Plan == Project [v#21L, f2(f1(v#21L)) AS foo#46] +- AnalysisBarrier +- Project [v#21L, f1(f2(v#21L)) AS foo#39] +- Project [v#21L, ((v#21L)) AS foo#32] +- Project [v#21L, ((v#21L)) AS foo#25] +- Project [id#19L AS v#21L] +- Range (0, 1, step=1, splits=Some(4)) == Analyzed Logical Plan == v: bigint, foo: int Project [v#21L, f2(f1(v#21L)) AS foo#46] +- Project [v#21L, f1(f2(v#21L)) AS foo#39] +- Project [v#21L, ((v#21L)) AS foo#32] +- Project [v#21L, ((v#21L)) AS foo#25] +- Project [id#19L AS v#21L] +- Range (0, 1, step=1, splits=Some(4)) == Optimized Logical Plan == Project [id#19L AS v#21L, f2(f1(id#19L)) AS foo#46] +- Range (0, 1, step=1, splits=Some(4)) == Physical Plan == *(2) Project [id#19L AS v#21L, pythonUDF0#50 AS foo#46] +- ArrowEvalPython [f2(pythonUDF0#49)], [id#19L, pythonUDF0#49, pythonUDF0#50] +- BatchEvalPython [f1(id#19L)], [id#19L, pythonUDF0#49] +- *(1) Range (0, 1, step=1, splits=4) ``` ## How was this patch tested? New tests are added to BatchEvalPythonExecSuite and ScalarPandasUDFTests Author: Li Jin Closes #21650 from icexelloss/SPARK-24624-mix-udf. --- python/pyspark/sql/tests.py | 186 ++++++++++++++++-- .../execution/python/ExtractPythonUDFs.scala | 42 ++-- .../python/BatchEvalPythonExecSuite.scala | 7 + .../python/ExtractPythonUDFsSuite.scala | 92 +++++++++ 4 files changed, 304 insertions(+), 23 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2d6b9f01e6525..a294d70119d0b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4763,17 +4763,6 @@ def test_vectorized_udf_invalid_length(self): 'Result vector from pandas_udf was not the required length'): df.select(raise_exception(col('id'))).collect() - def test_vectorized_udf_mix_udf(self): - from pyspark.sql.functions import pandas_udf, udf, col - df = self.spark.range(10) - row_by_row_udf = udf(lambda x: x, LongType()) - pd_udf = pandas_udf(lambda x: x, LongType()) - with QuietTest(self.sc): - with self.assertRaisesRegexp( - Exception, - 'Can not mix vectorized and non-vectorized UDFs'): - df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect() - def test_vectorized_udf_chained(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10) @@ -5060,6 +5049,166 @@ def test_type_annotation(self): df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id')) self.assertEqual(df.first()[0], 0) + def test_mixed_udf(self): + import pandas as pd + from pyspark.sql.functions import col, udf, pandas_udf + + df = self.spark.range(0, 1).toDF('v') + + # Test mixture of multiple UDFs and Pandas UDFs. + + @udf('int') + def f1(x): + assert type(x) == int + return x + 1 + + @pandas_udf('int') + def f2(x): + assert type(x) == pd.Series + return x + 10 + + @udf('int') + def f3(x): + assert type(x) == int + return x + 100 + + @pandas_udf('int') + def f4(x): + assert type(x) == pd.Series + return x + 1000 + + # Test single expression with chained UDFs + df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v']))) + df_chained_2 = df.withColumn('f3_f2_f1', f3(f2(f1(df['v'])))) + df_chained_3 = df.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(df['v']))))) + df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v'])))) + df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v'])))) + + expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11) + expected_chained_2 = df.withColumn('f3_f2_f1', df['v'] + 111) + expected_chained_3 = df.withColumn('f4_f3_f2_f1', df['v'] + 1111) + expected_chained_4 = df.withColumn('f4_f2_f1', df['v'] + 1011) + expected_chained_5 = df.withColumn('f4_f3_f1', df['v'] + 1101) + + self.assertEquals(expected_chained_1.collect(), df_chained_1.collect()) + self.assertEquals(expected_chained_2.collect(), df_chained_2.collect()) + self.assertEquals(expected_chained_3.collect(), df_chained_3.collect()) + self.assertEquals(expected_chained_4.collect(), df_chained_4.collect()) + self.assertEquals(expected_chained_5.collect(), df_chained_5.collect()) + + # Test multiple mixed UDF expressions in a single projection + df_multi_1 = df \ + .withColumn('f1', f1(col('v'))) \ + .withColumn('f2', f2(col('v'))) \ + .withColumn('f3', f3(col('v'))) \ + .withColumn('f4', f4(col('v'))) \ + .withColumn('f2_f1', f2(col('f1'))) \ + .withColumn('f3_f1', f3(col('f1'))) \ + .withColumn('f4_f1', f4(col('f1'))) \ + .withColumn('f3_f2', f3(col('f2'))) \ + .withColumn('f4_f2', f4(col('f2'))) \ + .withColumn('f4_f3', f4(col('f3'))) \ + .withColumn('f3_f2_f1', f3(col('f2_f1'))) \ + .withColumn('f4_f2_f1', f4(col('f2_f1'))) \ + .withColumn('f4_f3_f1', f4(col('f3_f1'))) \ + .withColumn('f4_f3_f2', f4(col('f3_f2'))) \ + .withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1'))) + + # Test mixed udfs in a single expression + df_multi_2 = df \ + .withColumn('f1', f1(col('v'))) \ + .withColumn('f2', f2(col('v'))) \ + .withColumn('f3', f3(col('v'))) \ + .withColumn('f4', f4(col('v'))) \ + .withColumn('f2_f1', f2(f1(col('v')))) \ + .withColumn('f3_f1', f3(f1(col('v')))) \ + .withColumn('f4_f1', f4(f1(col('v')))) \ + .withColumn('f3_f2', f3(f2(col('v')))) \ + .withColumn('f4_f2', f4(f2(col('v')))) \ + .withColumn('f4_f3', f4(f3(col('v')))) \ + .withColumn('f3_f2_f1', f3(f2(f1(col('v'))))) \ + .withColumn('f4_f2_f1', f4(f2(f1(col('v'))))) \ + .withColumn('f4_f3_f1', f4(f3(f1(col('v'))))) \ + .withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \ + .withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v')))))) + + expected = df \ + .withColumn('f1', df['v'] + 1) \ + .withColumn('f2', df['v'] + 10) \ + .withColumn('f3', df['v'] + 100) \ + .withColumn('f4', df['v'] + 1000) \ + .withColumn('f2_f1', df['v'] + 11) \ + .withColumn('f3_f1', df['v'] + 101) \ + .withColumn('f4_f1', df['v'] + 1001) \ + .withColumn('f3_f2', df['v'] + 110) \ + .withColumn('f4_f2', df['v'] + 1010) \ + .withColumn('f4_f3', df['v'] + 1100) \ + .withColumn('f3_f2_f1', df['v'] + 111) \ + .withColumn('f4_f2_f1', df['v'] + 1011) \ + .withColumn('f4_f3_f1', df['v'] + 1101) \ + .withColumn('f4_f3_f2', df['v'] + 1110) \ + .withColumn('f4_f3_f2_f1', df['v'] + 1111) + + self.assertEquals(expected.collect(), df_multi_1.collect()) + self.assertEquals(expected.collect(), df_multi_2.collect()) + + def test_mixed_udf_and_sql(self): + import pandas as pd + from pyspark.sql import Column + from pyspark.sql.functions import udf, pandas_udf + + df = self.spark.range(0, 1).toDF('v') + + # Test mixture of UDFs, Pandas UDFs and SQL expression. + + @udf('int') + def f1(x): + assert type(x) == int + return x + 1 + + def f2(x): + assert type(x) == Column + return x + 10 + + @pandas_udf('int') + def f3(x): + assert type(x) == pd.Series + return x + 100 + + df1 = df.withColumn('f1', f1(df['v'])) \ + .withColumn('f2', f2(df['v'])) \ + .withColumn('f3', f3(df['v'])) \ + .withColumn('f1_f2', f1(f2(df['v']))) \ + .withColumn('f1_f3', f1(f3(df['v']))) \ + .withColumn('f2_f1', f2(f1(df['v']))) \ + .withColumn('f2_f3', f2(f3(df['v']))) \ + .withColumn('f3_f1', f3(f1(df['v']))) \ + .withColumn('f3_f2', f3(f2(df['v']))) \ + .withColumn('f1_f2_f3', f1(f2(f3(df['v'])))) \ + .withColumn('f1_f3_f2', f1(f3(f2(df['v'])))) \ + .withColumn('f2_f1_f3', f2(f1(f3(df['v'])))) \ + .withColumn('f2_f3_f1', f2(f3(f1(df['v'])))) \ + .withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \ + .withColumn('f3_f2_f1', f3(f2(f1(df['v'])))) + + expected = df.withColumn('f1', df['v'] + 1) \ + .withColumn('f2', df['v'] + 10) \ + .withColumn('f3', df['v'] + 100) \ + .withColumn('f1_f2', df['v'] + 11) \ + .withColumn('f1_f3', df['v'] + 101) \ + .withColumn('f2_f1', df['v'] + 11) \ + .withColumn('f2_f3', df['v'] + 110) \ + .withColumn('f3_f1', df['v'] + 101) \ + .withColumn('f3_f2', df['v'] + 110) \ + .withColumn('f1_f2_f3', df['v'] + 111) \ + .withColumn('f1_f3_f2', df['v'] + 111) \ + .withColumn('f2_f1_f3', df['v'] + 111) \ + .withColumn('f2_f3_f1', df['v'] + 111) \ + .withColumn('f3_f1_f2', df['v'] + 111) \ + .withColumn('f3_f2_f1', df['v'] + 111) + + self.assertEquals(expected.collect(), df1.collect()) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, @@ -5487,6 +5636,21 @@ def dummy_pandas_udf(df): F.col('temp0.key') == F.col('temp1.key')) self.assertEquals(res.count(), 5) + def test_mixed_scalar_udfs_followed_by_grouby_apply(self): + import pandas as pd + from pyspark.sql.functions import udf, pandas_udf, PandasUDFType + + df = self.spark.range(0, 10).toDF('v1') + df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \ + .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1'])) + + result = df.groupby() \ + .apply(pandas_udf(lambda x: pd.DataFrame([x.sum().sum()]), + 'sum int', + PandasUDFType.GROUPED_MAP)) + + self.assertEquals(result.collect()[0]['sum'], 165) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 1e096100f7f43..cb75874be32ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} @@ -94,28 +95,44 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { */ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { - private def hasPythonUDF(e: Expression): Boolean = { + private type EvalType = Int + private type EvalTypeChecker = EvalType => Boolean + + private def hasScalarPythonUDF(e: Expression): Boolean = { e.find(PythonUDF.isScalarPythonUDF).isDefined } private def canEvaluateInPython(e: PythonUDF): Boolean = { e.children match { // single PythonUDF child could be chained and evaluated in Python - case Seq(u: PythonUDF) => canEvaluateInPython(u) + case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u) // Python UDF can't be evaluated directly in JVM - case children => !children.exists(hasPythonUDF) + case children => !children.exists(hasScalarPythonUDF) } } - private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match { - case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf) - case e => e.children.flatMap(collectEvaluatableUDF) + private def collectEvaluableUDFsFromExpressions(expressions: Seq[Expression]): Seq[PythonUDF] = { + // Eval type checker is set once when we find the first evaluable UDF and its value + // shouldn't change later. + // Used to check if subsequent UDFs are of the same type as the first UDF. (since we can only + // extract UDFs of the same eval type) + var evalTypeChecker: Option[EvalTypeChecker] = None + + def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match { + case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) + && evalTypeChecker.isEmpty => + evalTypeChecker = Some((otherEvalType: EvalType) => otherEvalType == udf.evalType) + Seq(udf) + case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) + && evalTypeChecker.get(udf.evalType) => + Seq(udf) + case e => e.children.flatMap(collectEvaluableUDFs) + } + + expressions.flatMap(collectEvaluableUDFs) } def apply(plan: SparkPlan): SparkPlan = plan transformUp { - // AggregateInPandasExec and FlatMapGroupsInPandas can be evaluated directly in python worker - // Therefore we don't need to extract the UDFs - case plan: FlatMapGroupsInPandasExec => plan case plan: SparkPlan => extract(plan) } @@ -123,7 +140,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { * Extract all the PythonUDFs from the current operator and evaluate them before the operator. */ private def extract(plan: SparkPlan): SparkPlan = { - val udfs = plan.expressions.flatMap(collectEvaluatableUDF) + val udfs = collectEvaluableUDFsFromExpressions(plan.expressions) // ignore the PythonUDF that come from second/third aggregate, which is not used .filter(udf => udf.references.subsetOf(plan.inputSet)) if (udfs.isEmpty) { @@ -167,7 +184,8 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty => BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child) case _ => - throw new IllegalArgumentException("Can not mix vectorized and non-vectorized UDFs") + throw new AnalysisException( + "Expected either Scalar Pandas UDFs or Batched UDFs but got both") } attributeMap ++= validUdfs.zip(resultAttrs) @@ -205,7 +223,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { case filter: FilterExec => val (candidates, nonDeterministic) = splitConjunctivePredicates(filter.condition).partition(_.deterministic) - val (pushDown, rest) = candidates.partition(!hasPythonUDF(_)) + val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_)) if (pushDown.nonEmpty) { val newChild = FilterExec(pushDown.reduceLeft(And), filter.child) FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index d456c931f5275..2cc55ff88b983 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -115,3 +115,10 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction( dataType = BooleanType, pythonEvalType = PythonEvalType.SQL_BATCHED_UDF, udfDeterministic = true) + +class MyDummyScalarPandasUDF extends UserDefinedPythonFunction( + name = "dummyScalarPandasUDF", + func = new DummyUDF, + dataType = BooleanType, + pythonEvalType = PythonEvalType.SQL_SCALAR_PANDAS_UDF, + udfDeterministic = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala new file mode 100644 index 0000000000000..76b609d111acd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.test.SharedSQLContext + +class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.newProductEncoder + import testImplicits.localSeqToDatasetHolder + + val batchedPythonUDF = new MyDummyPythonUDF + val scalarPandasUDF = new MyDummyScalarPandasUDF + + private def collectBatchExec(plan: SparkPlan): Seq[BatchEvalPythonExec] = plan.collect { + case b: BatchEvalPythonExec => b + } + + private def collectArrowExec(plan: SparkPlan): Seq[ArrowEvalPythonExec] = plan.collect { + case b: ArrowEvalPythonExec => b + } + + test("Chained Batched Python UDFs should be combined to a single physical node") { + val df = Seq(("Hello", 4)).toDF("a", "b") + val df2 = df.withColumn("c", batchedPythonUDF(col("a"))) + .withColumn("d", batchedPythonUDF(col("c"))) + val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan) + assert(pythonEvalNodes.size == 1) + } + + test("Chained Scalar Pandas UDFs should be combined to a single physical node") { + val df = Seq(("Hello", 4)).toDF("a", "b") + val df2 = df.withColumn("c", scalarPandasUDF(col("a"))) + .withColumn("d", scalarPandasUDF(col("c"))) + val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan) + assert(arrowEvalNodes.size == 1) + } + + test("Mixed Batched Python UDFs and Pandas UDF should be separate physical node") { + val df = Seq(("Hello", 4)).toDF("a", "b") + val df2 = df.withColumn("c", batchedPythonUDF(col("a"))) + .withColumn("d", scalarPandasUDF(col("b"))) + + val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan) + val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan) + assert(pythonEvalNodes.size == 1) + assert(arrowEvalNodes.size == 1) + } + + test("Independent Batched Python UDFs and Scalar Pandas UDFs should be combined separately") { + val df = Seq(("Hello", 4)).toDF("a", "b") + val df2 = df.withColumn("c1", batchedPythonUDF(col("a"))) + .withColumn("c2", batchedPythonUDF(col("c1"))) + .withColumn("d1", scalarPandasUDF(col("a"))) + .withColumn("d2", scalarPandasUDF(col("d1"))) + + val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan) + val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan) + assert(pythonEvalNodes.size == 1) + assert(arrowEvalNodes.size == 1) + } + + test("Dependent Batched Python UDFs and Scalar Pandas UDFs should not be combined") { + val df = Seq(("Hello", 4)).toDF("a", "b") + val df2 = df.withColumn("c1", batchedPythonUDF(col("a"))) + .withColumn("d1", scalarPandasUDF(col("c1"))) + .withColumn("c2", batchedPythonUDF(col("d1"))) + .withColumn("d2", scalarPandasUDF(col("c2"))) + + val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan) + val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan) + assert(pythonEvalNodes.size == 2) + assert(arrowEvalNodes.size == 2) + } +} + From c6a3db2fb6d9df1a377a1d3385343f70f9e237e4 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Sat, 28 Jul 2018 13:43:32 +0800 Subject: [PATCH 1237/2461] [SPARK-24924][SQL][FOLLOW-UP] Add mapping for built-in Avro data source ## What changes were proposed in this pull request? Add one more test case for `com.databricks.spark.avro`. ## How was this patch tested? N/A Author: Xiao Li Closes #21906 from gatorsmile/avro. --- .../test/scala/org/apache/spark/sql/avro/AvroSuite.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 2f478c76113eb..f59c2cc6ffaaf 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -394,6 +394,13 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(results.length === 8) } + test("old avro data source name works") { + val results = + spark.read.format("com.databricks.spark.avro") + .load(episodesAvro).select("title").collect() + assert(results.length === 8) + } + test("support of various data types") { // This test uses data from test.avro. You can see the data and the schema of this file in // test.json and test.avsc From c5b8d54c61780af6e9e157e6c855718df972efad Mon Sep 17 00:00:00 2001 From: Chris Martin Date: Sat, 28 Jul 2018 10:40:10 -0500 Subject: [PATCH 1238/2461] [SPARK-24950][SQL] DateTimeUtilsSuite daysToMillis and millisToDays fails w/java 8 181-b13 ## What changes were proposed in this pull request? - Update DateTimeUtilsSuite so that when testing roundtripping in daysToMillis and millisToDays multiple skipdates can be specified. - Updated test so that both new years eve 2014 and new years day 2015 are skipped for kiribati time zones. This is necessary as java versions pre 181-b13 considered new years day 2015 to be skipped while susequent versions corrected this to new years eve. ## How was this patch tested? Unit tests Author: Chris Martin Closes #21901 from d80tb7/SPARK-24950_datetimeUtilsSuite_failures. --- .../catalyst/util/DateTimeUtilsSuite.scala | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index cbf6106697f30..2423668392231 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -662,18 +662,18 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(daysToMillis(16800, TimeZoneGMT) === c.getTimeInMillis) // There are some days are skipped entirely in some timezone, skip them here. - val skipped_days = Map[String, Int]( - "Kwajalein" -> 8632, - "Pacific/Apia" -> 15338, - "Pacific/Enderbury" -> 9131, - "Pacific/Fakaofo" -> 15338, - "Pacific/Kiritimati" -> 9131, - "Pacific/Kwajalein" -> 8632, - "MIT" -> 15338) + val skipped_days = Map[String, Set[Int]]( + "Kwajalein" -> Set(8632), + "Pacific/Apia" -> Set(15338), + "Pacific/Enderbury" -> Set(9130, 9131), + "Pacific/Fakaofo" -> Set(15338), + "Pacific/Kiritimati" -> Set(9130, 9131), + "Pacific/Kwajalein" -> Set(8632), + "MIT" -> Set(15338)) for (tz <- DateTimeTestUtils.ALL_TIMEZONES) { - val skipped = skipped_days.getOrElse(tz.getID, Int.MinValue) + val skipped = skipped_days.getOrElse(tz.getID, Set.empty) (-20000 to 20000).foreach { d => - if (d != skipped) { + if (!skipped.contains(d)) { assert(millisToDays(daysToMillis(d, tz), tz) === d, s"Round trip of ${d} did not work in tz ${tz}") } From 8fe5d2c393f035b9e82ba42202421c9ba66d6c78 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 29 Jul 2018 08:31:16 -0500 Subject: [PATCH 1239/2461] [SPARK-24956][Build][test-maven] Upgrade maven version to 3.5.4 ## What changes were proposed in this pull request? This PR updates maven version from 3.3.9 to 3.5.4. The current build process uses mvn 3.3.9 that was release on 2015, which looks pretty old. We met [an issue](https://issues.apache.org/jira/browse/SPARK-24895) to need the maven 3.5.2 or later. The release note of the 3.5.4 is [here](https://maven.apache.org/docs/3.5.4/release-notes.html). Note version 3.4 was skipped. From [the release note of the 3.5.0](https://maven.apache.org/docs/3.5.0/release-notes.html), the followings are new features: 1. ANSI color logging for improved output visibility 1. add support for module name != artifactId in every calculated URLs (project, SCM, site): special project.directory property 1. create a slf4j-simple provider extension that supports level color rendering 1. ModelResolver interface enhancement: addition of resolveModel(Dependency) supporting version ranges ## How was this patch tested? Existing tests Author: Kazuaki Ishizaki Closes #21905 from kiszk/SPARK-24956. --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index f320844faf222..9f60edcdc023e 100644 --- a/pom.xml +++ b/pom.xml @@ -114,7 +114,7 @@ 1.8 ${java.version} ${java.version} - 3.3.9 + 3.5.4 spark 1.7.16 1.2.17 From 2c54aae1bc2fa3da26917c89e6201fb2108d9fab Mon Sep 17 00:00:00 2001 From: liulijia Date: Sun, 29 Jul 2018 13:13:00 -0700 Subject: [PATCH 1240/2461] [SPARK-24809][SQL] Serializing LongToUnsafeRowMap in executor may result in data error When join key is long or int in broadcast join, Spark will use `LongToUnsafeRowMap` to store key-values of the table witch will be broadcasted. But, when `LongToUnsafeRowMap` is broadcasted to executors, and it is too big to hold in memory, it will be stored in disk. At that time, because `write` uses a variable `cursor` to determine how many bytes in `page` of `LongToUnsafeRowMap` will be write out and the `cursor` was not restore when deserializing, executor will write out nothing from page into disk. ## What changes were proposed in this pull request? Restore cursor value when deserializing. Author: liulijia Closes #21772 from liutang123/SPARK-24809. --- .../sql/execution/joins/HashedRelation.scala | 2 ++ .../execution/joins/HashedRelationSuite.scala | 29 +++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 20ce01f4ce8cc..86eb47a70f1ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -772,6 +772,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap array = readLongArray(readBuffer, length) val pageLength = readLong().toInt page = readLongArray(readBuffer, pageLength) + // Restore cursor variable to make this map able to be serialized again on executors. + cursor = pageLength * 8 + Platform.LONG_ARRAY_OFFSET } override def readExternal(in: ObjectInput): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 037cc2e3ccad7..d9b34dcd16476 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -278,6 +278,35 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { map.free() } + test("SPARK-24809: Serializing LongToUnsafeRowMap in executor may result in data error") { + val unsafeProj = UnsafeProjection.create(Array[DataType](LongType)) + val originalMap = new LongToUnsafeRowMap(mm, 1) + + val key1 = 1L + val value1 = 4852306286022334418L + + val key2 = 2L + val value2 = 8813607448788216010L + + originalMap.append(key1, unsafeProj(InternalRow(value1))) + originalMap.append(key2, unsafeProj(InternalRow(value2))) + originalMap.optimize() + + val ser = sparkContext.env.serializer.newInstance() + // Simulate serialize/deserialize twice on driver and executor + val firstTimeSerialized = ser.deserialize[LongToUnsafeRowMap](ser.serialize(originalMap)) + val secondTimeSerialized = + ser.deserialize[LongToUnsafeRowMap](ser.serialize(firstTimeSerialized)) + + val resultRow = new UnsafeRow(1) + assert(secondTimeSerialized.getValue(key1, resultRow).getLong(0) === value1) + assert(secondTimeSerialized.getValue(key2, resultRow).getLong(0) === value2) + + originalMap.free() + firstTimeSerialized.free() + secondTimeSerialized.free() + } + test("Spark-14521") { val ser = new KryoSerializer( (new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance() From 3695ba57731a669ed20e7f676edee602c292fbed Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Mon, 30 Jul 2018 09:58:28 +0800 Subject: [PATCH 1241/2461] [MINOR][CORE][TEST] Fix afterEach() in TastSetManagerSuite and TaskSchedulerImplSuite ## What changes were proposed in this pull request? In the `afterEach()` method of both `TastSetManagerSuite` and `TaskSchedulerImplSuite`, `super.afterEach()` shall be called at the end, because it shall stop the SparkContext. https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/93706/testReport/org.apache.spark.scheduler/TaskSchedulerImplSuite/_It_is_not_a_test_it_is_a_sbt_testing_SuiteSelector_/ The test failure is caused by the above reason, the newly added `barrierCoordinator` required `rpcEnv` which has been stopped before `TaskSchedulerImpl` doing cleanup. ## How was this patch tested? Existing tests. Author: Xingbo Jiang Closes #21908 from jiangxb1987/afterEach. --- .../org/apache/spark/scheduler/TaskSchedulerImplSuite.scala | 2 +- .../scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 624384abcd71d..16c273b7bc8a4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -62,7 +62,6 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B } override def afterEach(): Unit = { - super.afterEach() if (taskScheduler != null) { taskScheduler.stop() taskScheduler = null @@ -71,6 +70,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B dagScheduler.stop() dagScheduler = null } + super.afterEach() } def setupScheduler(confs: (String, String)*): TaskSchedulerImpl = { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index cf05434aee301..d264adaef90a5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -178,12 +178,12 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg } override def afterEach(): Unit = { - super.afterEach() if (sched != null) { sched.dagScheduler.stop() sched.stop() sched = null } + super.afterEach() } From 3210121fed0ba256667f18f990c1a11d32c306ea Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 30 Jul 2018 10:01:18 +0800 Subject: [PATCH 1242/2461] [MINOR][BUILD] Remove -Phive-thriftserver profile within appveyor.yml ## What changes were proposed in this pull request? This PR propose to remove `-Phive-thriftserver` profile which seems not affecting the SparkR tests in AppVeyor. Originally wanted to check if there's a meaningful build time decrease but seems not. It will have but seems not meaningfully decreased. ## How was this patch tested? AppVeyor tests: ``` [00:40:49] Attaching package: 'SparkR' [00:40:49] [00:40:49] The following objects are masked from 'package:testthat': [00:40:49] [00:40:49] describe, not [00:40:49] [00:40:49] The following objects are masked from 'package:stats': [00:40:49] [00:40:49] cov, filter, lag, na.omit, predict, sd, var, window [00:40:49] [00:40:49] The following objects are masked from 'package:base': [00:40:49] [00:40:49] as.data.frame, colnames, colnames<-, drop, endsWith, intersect, [00:40:49] rank, rbind, sample, startsWith, subset, summary, transform, union [00:40:49] [00:40:49] Spark package found in SPARK_HOME: C:\projects\spark\bin\.. [00:41:43] basic tests for CRAN: ............. [00:41:43] [00:41:43] DONE =========================================================================== [00:41:43] binary functions: Spark package found in SPARK_HOME: C:\projects\spark\bin\.. [00:42:05] ........... [00:42:05] functions on binary files: Spark package found in SPARK_HOME: C:\projects\spark\bin\.. [00:42:10] .... [00:42:10] broadcast variables: Spark package found in SPARK_HOME: C:\projects\spark\bin\.. [00:42:12] .. [00:42:12] functions in client.R: ..... [00:42:30] test functions in sparkR.R: .............................................. [00:42:30] include R packages: Spark package found in SPARK_HOME: C:\projects\spark\bin\.. [00:42:31] [00:42:31] JVM API: Spark package found in SPARK_HOME: C:\projects\spark\bin\.. [00:42:31] .. [00:42:31] MLlib classification algorithms, except for tree-based algorithms: Spark package found in SPARK_HOME: C:\projects\spark\bin\.. [00:48:48] ...................................................................... [00:48:48] MLlib clustering algorithms: Spark package found in SPARK_HOME: C:\projects\spark\bin\.. [00:50:12] ..................................................................... [00:50:12] MLlib frequent pattern mining: Spark package found in SPARK_HOME: C:\projects\spark\bin\.. [00:50:18] ..... [00:50:18] MLlib recommendation algorithms: Spark package found in SPARK_HOME: C:\projects\spark\bin\.. [00:50:27] ........ [00:50:27] MLlib regression algorithms, except for tree-based algorithms: Spark package found in SPARK_HOME: C:\projects\spark\bin\.. [00:56:00] ................................................................................................................................ [00:56:00] MLlib statistics algorithms: Spark package found in SPARK_HOME: C:\projects\spark\bin\.. [00:56:04] ........ [00:56:04] MLlib tree-based algorithms: Spark package found in SPARK_HOME: C:\projects\spark\bin\.. [00:58:20] .............................................................................................. [00:58:20] parallelize() and collect(): Spark package found in SPARK_HOME: C:\projects\spark\bin\.. [00:58:20] ............................. [00:58:20] basic RDD functions: Spark package found in SPARK_HOME: C:\projects\spark\bin\.. [01:03:35] ............................................................................................................................................................................................................................................................................................................................................................................................................................................ [01:03:35] SerDe functionality: Spark package found in SPARK_HOME: C:\projects\spark\bin\.. [01:03:39] ............................... [01:03:39] partitionBy, groupByKey, reduceByKey etc.: Spark package found in SPARK_HOME: C:\projects\spark\bin\.. [01:04:20] .................... [01:04:20] functions in sparkR.R: .... [01:04:20] SparkSQL functions: Spark package found in SPARK_HOME: C:\projects\spark\bin\.. [01:04:50] ........................................................................................................................................-chgrp: 'APPVYR-WIN\None' does not match expected pattern for group [01:04:50] Usage: hadoop fs [generic options] -chgrp [-R] GROUP PATH... [01:04:50] -chgrp: 'APPVYR-WIN\None' does not match expected pattern for group [01:04:50] Usage: hadoop fs [generic options] -chgrp [-R] GROUP PATH... [01:04:51] -chgrp: 'APPVYR-WIN\None' does not match expected pattern for group [01:04:51] Usage: hadoop fs [generic options] -chgrp [-R] GROUP PATH... [01:06:13] ............................................................................................................................................................................................................................................................................................................................................................-chgrp: 'APPVYR-WIN\None' does not match expected pattern for group [01:06:13] Usage: hadoop fs [generic options] -chgrp [-R] GROUP PATH... [01:06:14] .-chgrp: 'APPVYR-WIN\None' does not match expected pattern for group [01:06:14] Usage: hadoop fs [generic options] -chgrp [-R] GROUP PATH... [01:06:14] ....-chgrp: 'APPVYR-WIN\None' does not match expected pattern for group [01:06:14] Usage: hadoop fs [generic options] -chgrp [-R] GROUP PATH... [01:12:30] ................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................... [01:12:30] Structured Streaming: Spark package found in SPARK_HOME: C:\projects\spark\bin\.. [01:14:27] .......................................... [01:14:27] tests RDD function take(): Spark package found in SPARK_HOME: C:\projects\spark\bin\.. [01:14:28] ................ [01:14:28] the textFile() function: Spark package found in SPARK_HOME: C:\projects\spark\bin\.. [01:14:44] ............. [01:14:44] functions in utils.R: Spark package found in SPARK_HOME: C:\projects\spark\bin\.. [01:14:46] ............................................ [01:14:46] Windows-specific tests: . [01:14:46] [01:14:46] DONE =========================================================================== [01:15:29] Build success ``` Author: hyukjinkwon Closes #21894 from HyukjinKwon/wip-build. --- appveyor.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index aee94c59612d2..7fb45745a036f 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -48,7 +48,7 @@ install: - cmd: R -e "packageVersion('knitr'); packageVersion('rmarkdown'); packageVersion('testthat'); packageVersion('e1071'); packageVersion('survival')" build_script: - - cmd: mvn -DskipTests -Psparkr -Phive -Phive-thriftserver package + - cmd: mvn -DskipTests -Psparkr -Phive package environment: NOT_CRAN: true From 6690924c49a443cd629fcc1a4460cf443fb0a918 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 30 Jul 2018 10:02:29 +0800 Subject: [PATCH 1243/2461] [MINOR] Avoid the 'latest' link that might vary per release in functions.scala's comment ## What changes were proposed in this pull request? This PR propose to address https://github.com/apache/spark/pull/21318#discussion_r187843125 comment. This is rather a nit but looks we better avoid to update the link for each release since it always points the latest (it doesn't look like worth enough updating release guide on the other hand as well). ## How was this patch tested? N/A Author: hyukjinkwon Closes #21907 from HyukjinKwon/minor-fix. --- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 277295816c978..a2d37928bff59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -44,8 +44,8 @@ import org.apache.spark.util.Utils * * Spark also includes more built-in functions that are less common and are not defined here. * You can still access them (and all the functions defined here) using the `functions.expr()` API - * and calling them through a SQL expression string. You can find the entire list of functions for - * the latest version of Spark at https://spark.apache.org/docs/latest/api/sql/index.html. + * and calling them through a SQL expression string. You can find the entire list of functions + * at SQL API documentation. * * As an example, `isnan` is a function that is defined here. You can use `isnan(col("myCol"))` * to invoke the `isnan` function. This way the programming language's compiler ensures `isnan` From 65a4bc143ab5dc2ced589dc107bbafa8a7290931 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Sun, 29 Jul 2018 22:11:01 -0700 Subject: [PATCH 1244/2461] [SPARK-21274][SQL] Implement INTERSECT ALL clause ## What changes were proposed in this pull request? Implements INTERSECT ALL clause through query rewrites using existing operators in Spark. Please refer to [Link](https://drive.google.com/open?id=1nyW0T0b_ajUduQoPgZLAsyHK8s3_dko3ulQuxaLpUXE) for the design. Input Query ``` SQL SELECT c1 FROM ut1 INTERSECT ALL SELECT c1 FROM ut2 ``` Rewritten Query ```SQL SELECT c1 FROM ( SELECT replicate_row(min_count, c1) FROM ( SELECT c1, IF (vcol1_cnt > vcol2_cnt, vcol2_cnt, vcol1_cnt) AS min_count FROM ( SELECT c1, count(vcol1) as vcol1_cnt, count(vcol2) as vcol2_cnt FROM ( SELECT c1, true as vcol1, null as vcol2 FROM ut1 UNION ALL SELECT c1, null as vcol1, true as vcol2 FROM ut2 ) AS union_all GROUP BY c1 HAVING vcol1_cnt >= 1 AND vcol2_cnt >= 1 ) ) ) ``` ## How was this patch tested? Added test cases in SQLQueryTestSuite, DataFrameSuite, SetOperationSuite Author: Dilip Biswal Closes #21886 from dilipbiswal/dkb_intersect_all_final. --- python/pyspark/sql/dataframe.py | 22 ++ .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 4 +- .../UnsupportedOperationChecker.scala | 2 +- .../sql/catalyst/optimizer/Optimizer.scala | 81 +++++- .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../plans/logical/basicLogicalOperators.scala | 7 +- .../optimizer/SetOperationSuite.scala | 32 ++- .../sql/catalyst/parser/PlanParserSuite.scala | 1 - .../scala/org/apache/spark/sql/Dataset.scala | 19 +- .../spark/sql/execution/SparkStrategies.scala | 8 +- .../sql-tests/inputs/intersect-all.sql | 123 +++++++++ .../sql-tests/results/intersect-all.sql.out | 241 ++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 54 ++++ .../apache/spark/sql/test/SQLTestData.scala | 13 + 15 files changed, 599 insertions(+), 12 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b2e0a5b2390c2..07fb260a77ea0 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1500,6 +1500,28 @@ def intersect(self, other): """ return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) + @since(2.4) + def intersectAll(self, other): + """ Return a new :class:`DataFrame` containing rows in both this dataframe and other + dataframe while preserving duplicates. + + This is equivalent to `INTERSECT ALL` in SQL. + >>> df1 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3), ("c", 4)], ["C1", "C2"]) + >>> df2 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3)], ["C1", "C2"]) + + >>> df1.intersectAll(df2).sort("C1", "C2").show() + +---+---+ + | C1| C2| + +---+---+ + | a| 1| + | a| 1| + | b| 3| + +---+---+ + + Also as standard in SQL, this function resolves columns by position (not by name). + """ + return DataFrame(self._jdf.intersectAll(other._jdf), self.sql_ctx) + @since(1.3) def subtract(self, other): """ Return a new :class:`DataFrame` containing rows in this frame diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8abb1c70d4919..9965cd654bcb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -914,7 +914,7 @@ class Analyzer( // To resolve duplicate expression IDs for Join and Intersect case j @ Join(left, right, _, _) if !j.duplicateResolved => j.copy(right = dedupRight(left, right)) - case i @ Intersect(left, right) if !i.duplicateResolved => + case i @ Intersect(left, right, _) if !i.duplicateResolved => i.copy(right = dedupRight(left, right)) case e @ Except(left, right, _) if !e.duplicateResolved => e.copy(right = dedupRight(left, right)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index f9edca53d571e..7dd26b62b1fc4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -325,11 +325,11 @@ object TypeCoercion { assert(newChildren.length == 2) Except(newChildren.head, newChildren.last, isAll) - case s @ Intersect(left, right) if s.childrenResolved && + case s @ Intersect(left, right, isAll) if s.childrenResolved && left.output.length == right.output.length && !s.resolved => val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) assert(newChildren.length == 2) - Intersect(newChildren.head, newChildren.last) + Intersect(newChildren.head, newChildren.last, isAll) case s: Union if s.childrenResolved && s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index c9a3ee47a02be..cff4cee09427f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -309,7 +309,7 @@ object UnsupportedOperationChecker { case Except(left, right, _) if right.isStreaming => throwError("Except on a streaming DataFrame/Dataset on the right is not supported") - case Intersect(left, right) if left.isStreaming && right.isStreaming => + case Intersect(left, right, _) if left.isStreaming && right.isStreaming => throwError("Intersect between two streaming DataFrames/Datasets is not supported") case GroupingSets(_, _, child, _) if child.isStreaming => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 193f6591c9a8b..105623c767d66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -136,6 +136,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) OptimizeSubqueries) :: Batch("Replace Operators", fixedPoint, RewriteExcepAll, + RewriteIntersectAll, ReplaceIntersectWithSemiJoin, ReplaceExceptWithFilter, ReplaceExceptWithAntiJoin, @@ -1402,7 +1403,7 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] { */ object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Intersect(left, right) => + case Intersect(left, right, false) => assert(left.output.size == right.output.size) val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) } Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And))) @@ -1488,6 +1489,84 @@ object RewriteExcepAll extends Rule[LogicalPlan] { } } +/** + * Replaces logical [[Intersect]] operator using a combination of Union, Aggregate + * and Generate operator. + * + * Input Query : + * {{{ + * SELECT c1 FROM ut1 INTERSECT ALL SELECT c1 FROM ut2 + * }}} + * + * Rewritten Query: + * {{{ + * SELECT c1 + * FROM ( + * SELECT replicate_row(min_count, c1) + * FROM ( + * SELECT c1, If (vcol1_cnt > vcol2_cnt, vcol2_cnt, vcol1_cnt) AS min_count + * FROM ( + * SELECT c1, count(vcol1) as vcol1_cnt, count(vcol2) as vcol2_cnt + * FROM ( + * SELECT true as vcol1, null as , c1 FROM ut1 + * UNION ALL + * SELECT null as vcol1, true as vcol2, c1 FROM ut2 + * ) AS union_all + * GROUP BY c1 + * HAVING vcol1_cnt >= 1 AND vcol2_cnt >= 1 + * ) + * ) + * ) + * }}} + */ +object RewriteIntersectAll extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Intersect(left, right, true) => + assert(left.output.size == right.output.size) + + val trueVcol1 = Alias(Literal(true), "vcol1")() + val nullVcol1 = Alias(Literal(null, BooleanType), "vcol1")() + + val trueVcol2 = Alias(Literal(true), "vcol2")() + val nullVcol2 = Alias(Literal(null, BooleanType), "vcol2")() + + // Add a projection on the top of left and right plans to project out + // the additional virtual columns. + val leftPlanWithAddedVirtualCols = Project(Seq(trueVcol1, nullVcol2) ++ left.output, left) + val rightPlanWithAddedVirtualCols = Project(Seq(nullVcol1, trueVcol2) ++ right.output, right) + + val unionPlan = Union(leftPlanWithAddedVirtualCols, rightPlanWithAddedVirtualCols) + + // Expressions to compute count and minimum of both the counts. + val vCol1AggrExpr = + Alias(AggregateExpression(Count(unionPlan.output(0)), Complete, false), "vcol1_count")() + val vCol2AggrExpr = + Alias(AggregateExpression(Count(unionPlan.output(1)), Complete, false), "vcol2_count")() + val ifExpression = Alias(If( + GreaterThan(vCol1AggrExpr.toAttribute, vCol2AggrExpr.toAttribute), + vCol2AggrExpr.toAttribute, + vCol1AggrExpr.toAttribute + ), "min_count")() + + val aggregatePlan = Aggregate(left.output, + Seq(vCol1AggrExpr, vCol2AggrExpr) ++ left.output, unionPlan) + val filterPlan = Filter(And(GreaterThanOrEqual(vCol1AggrExpr.toAttribute, Literal(1L)), + GreaterThanOrEqual(vCol2AggrExpr.toAttribute, Literal(1L))), aggregatePlan) + val projectMinPlan = Project(left.output ++ Seq(ifExpression), filterPlan) + + // Apply the replicator to replicate rows based on min_count + val genRowPlan = Generate( + ReplicateRows(Seq(ifExpression.toAttribute) ++ left.output), + unrequiredChildIndex = Nil, + outer = false, + qualifier = None, + left.output, + projectMinPlan + ) + Project(left.output, genRowPlan) + } +} + /** * Removes literals from group expressions in [[Aggregate]], as they have no effect to the result * but only makes the grouping key bigger. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 8b3c0686181fd..8a8db6df37094 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -533,7 +533,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case SqlBaseParser.UNION => Distinct(Union(left, right)) case SqlBaseParser.INTERSECT if all => - throw new ParseException("INTERSECT ALL is not supported.", ctx) + Intersect(left, right, isAll = true) case SqlBaseParser.INTERSECT => Intersect(left, right) case SqlBaseParser.EXCEPT if all => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 498a13a62bd22..13b51304d7f89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -164,7 +164,12 @@ object SetOperation { def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) } -case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { +case class Intersect( + left: LogicalPlan, + right: LogicalPlan, + isAll: Boolean = false) extends SetOperation(left, right) { + + override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) "All" else "" ) override def output: Seq[Attribute] = left.output.zip(right.output).map { case (leftAttr, rightAttr) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index f002aa3aacaba..cb744be400603 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, GreaterThan, Literal, ReplicateRows} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, GreaterThan, GreaterThanOrEqual, If, Literal, ReplicateRows} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.BooleanType class SetOperationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -166,4 +167,33 @@ class SetOperationSuite extends PlanTest { )) comparePlans(expectedPlan, rewrittenPlan) } + + test("INTERSECT ALL rewrite") { + val input = Intersect(testRelation, testRelation2, isAll = true) + val rewrittenPlan = RewriteIntersectAll(input) + val leftRelation = testRelation + .select(Literal(true).as("vcol1"), Literal(null, BooleanType).as("vcol2"), 'a, 'b, 'c) + val rightRelation = testRelation2 + .select(Literal(null, BooleanType).as("vcol1"), Literal(true).as("vcol2"), 'd, 'e, 'f) + val planFragment = leftRelation.union(rightRelation) + .groupBy('a, 'b, 'c)(count('vcol1).as("vcol1_count"), + count('vcol2).as("vcol2_count"), 'a, 'b, 'c) + .where(And(GreaterThanOrEqual('vcol1_count, Literal(1L)), + GreaterThanOrEqual('vcol2_count, Literal(1L)))) + .select('a, 'b, 'c, + If(GreaterThan('vcol1_count, 'vcol2_count), 'vcol2_count, 'vcol1_count).as("min_count")) + .analyze + val multiplerAttr = planFragment.output.last + val output = planFragment.output.dropRight(1) + val expectedPlan = Project(output, + Generate( + ReplicateRows(Seq(multiplerAttr) ++ output), + Nil, + false, + None, + output, + planFragment + )) + comparePlans(expectedPlan, rewrittenPlan) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 629e3c4f3fcfb..9be0ec5af78ef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -70,7 +70,6 @@ class PlanParserSuite extends AnalysisTest { intercept("select * from a minus all select * from b", "MINUS ALL is not supported.") assertEqual("select * from a minus distinct select * from b", a.except(b)) assertEqual("select * from a intersect select * from b", a.intersect(b)) - intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.") assertEqual("select * from a intersect distinct select * from b", a.intersect(b)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index e6a3b0adcdaa6..d36c8d13acca9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1934,6 +1934,23 @@ class Dataset[T] private[sql]( Intersect(planWithBarrier, other.planWithBarrier) } + /** + * Returns a new Dataset containing rows only in both this Dataset and another Dataset while + * preserving the duplicates. + * This is equivalent to `INTERSECT ALL` in SQL. + * + * @note Equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. Also as standard + * in SQL, this function resolves columns by position (not by name). + * + * @group typedrel + * @since 2.4.0 + */ + def intersectAll(other: Dataset[T]): Dataset[T] = withSetOperator { + Intersect(logicalPlan, other.logicalPlan, isAll = true) + } + + /** * Returns a new Dataset containing rows in this Dataset but not in another Dataset. * This is equivalent to `EXCEPT DISTINCT` in SQL. @@ -1961,7 +1978,7 @@ class Dataset[T] private[sql]( * @since 2.4.0 */ def exceptAll(other: Dataset[T]): Dataset[T] = withSetOperator { - Except(planWithBarrier, other.planWithBarrier, isAll = true) + Except(logicalPlan, other.logicalPlan, isAll = true) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 3f5fd3dbb9e2f..75eff8a88312b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -529,9 +529,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Distinct(child) => throw new IllegalStateException( "logical distinct operator should have been replaced by aggregate in the optimizer") - case logical.Intersect(left, right) => + case logical.Intersect(left, right, false) => throw new IllegalStateException( - "logical intersect operator should have been replaced by semi-join in the optimizer") + "logical intersect operator should have been replaced by semi-join in the optimizer") + case logical.Intersect(left, right, true) => + throw new IllegalStateException( + "logical intersect operator should have been replaced by union, aggregate" + + "and generate operators in the optimizer") case logical.Except(left, right, false) => throw new IllegalStateException( "logical except operator should have been replaced by anti-join in the optimizer") diff --git a/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql b/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql new file mode 100644 index 0000000000000..ff4395c3e7447 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql @@ -0,0 +1,123 @@ +CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (1, 3), + (1, 3), + (2, 3), + (null, null), + (null, null) + AS tab1(k, v); +CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (2, 3), + (3, 4), + (null, null), + (null, null) + AS tab2(k, v); + +-- Basic INTERSECT ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2; + +-- INTERSECT ALL same table in both branches +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab1 WHERE k = 1; + +-- Empty left relation +SELECT * FROM tab1 WHERE k > 2 +INTERSECT ALL +SELECT * FROM tab2; + +-- Empty right relation +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 WHERE k > 3; + +-- Type Coerced INTERSECT ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT CAST(1 AS BIGINT), CAST(2 AS BIGINT); + +-- Error as types of two side are not compatible +SELECT * FROM tab1 +INTERSECT ALL +SELECT array(1), 2; + +-- Mismatch on number of columns across both branches +SELECT k FROM tab1 +INTERSECT ALL +SELECT k, v FROM tab2; + +-- Basic +SELECT * FROM tab2 +INTERSECT ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2; + +-- Chain of different `set operations +-- We need to parenthesize the following two queries to enforce +-- certain order of evaluation of operators. After fix to +-- SPARK-24966 this can be removed. +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +UNION ALL +( +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +); + +-- Chain of different `set operations +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +EXCEPT +( +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +); + +-- Join under intersect all +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +INTERSECT ALL +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k); + +-- Join under intersect all (2) +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +INTERSECT ALL +SELECT * +FROM (SELECT tab2.v AS k, + tab1.k AS v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k); + +-- Group by under intersect all +SELECT v FROM tab1 GROUP BY v +INTERSECT ALL +SELECT k FROM tab2 GROUP BY k; + +-- Clean-up +DROP VIEW IF EXISTS tab1; +DROP VIEW IF EXISTS tab2; diff --git a/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out new file mode 100644 index 0000000000000..792791bc51628 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out @@ -0,0 +1,241 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 17 + + +-- !query 0 +CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (1, 3), + (1, 3), + (2, 3), + (null, null), + (null, null) + AS tab1(k, v) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW tab2 AS SELECT * FROM VALUES + (1, 2), + (1, 2), + (2, 3), + (3, 4), + (null, null), + (null, null) + AS tab2(k, v) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +-- !query 2 schema +struct +-- !query 2 output +1 2 +1 2 +2 3 +NULL NULL +NULL NULL + + +-- !query 3 +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab1 WHERE k = 1 +-- !query 3 schema +struct +-- !query 3 output +1 2 +1 2 +1 3 +1 3 + + +-- !query 4 +SELECT * FROM tab1 WHERE k > 2 +INTERSECT ALL +SELECT * FROM tab2 +-- !query 4 schema +struct +-- !query 4 output + + + +-- !query 5 +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 WHERE k > 3 +-- !query 5 schema +struct +-- !query 5 output + + + +-- !query 6 +SELECT * FROM tab1 +INTERSECT ALL +SELECT CAST(1 AS BIGINT), CAST(2 AS BIGINT) +-- !query 6 schema +struct +-- !query 6 output +1 2 + + +-- !query 7 +SELECT * FROM tab1 +INTERSECT ALL +SELECT array(1), 2 +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +IntersectAll can only be performed on tables with the compatible column types. array <> int at the first column of the second table; + + +-- !query 8 +SELECT k FROM tab1 +INTERSECT ALL +SELECT k, v FROM tab2 +-- !query 8 schema +struct<> +-- !query 8 output +org.apache.spark.sql.AnalysisException +IntersectAll can only be performed on tables with the same number of columns, but the first table has 1 columns and the second table has 2 columns; + + +-- !query 9 +SELECT * FROM tab2 +INTERSECT ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +-- !query 9 schema +struct +-- !query 9 output +1 2 +1 2 +2 3 +NULL NULL +NULL NULL + + +-- !query 10 +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +UNION ALL +( +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +) +-- !query 10 schema +struct +-- !query 10 output +1 2 +1 2 +1 3 +2 3 +NULL NULL +NULL NULL + + +-- !query 11 +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +EXCEPT +( +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +) +-- !query 11 schema +struct +-- !query 11 output +1 3 + + +-- !query 12 +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +INTERSECT ALL +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +-- !query 12 schema +struct +-- !query 12 output +1 2 +1 2 +1 2 +1 2 +1 2 +1 2 +1 2 +1 2 +2 3 + + +-- !query 13 +SELECT * +FROM (SELECT tab1.k, + tab2.v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +INTERSECT ALL +SELECT * +FROM (SELECT tab2.v AS k, + tab1.k AS v + FROM tab1 + JOIN tab2 + ON tab1.k = tab2.k) +-- !query 13 schema +struct +-- !query 13 output + + + +-- !query 14 +SELECT v FROM tab1 GROUP BY v +INTERSECT ALL +SELECT k FROM tab2 GROUP BY k +-- !query 14 schema +struct +-- !query 14 output +2 +3 +NULL + + +-- !query 15 +DROP VIEW IF EXISTS tab1 +-- !query 15 schema +struct<> +-- !query 15 output + + + +-- !query 16 +DROP VIEW IF EXISTS tab2 +-- !query 16 schema +struct<> +-- !query 16 output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index af0735920cc29..b0e22a51e7611 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -749,6 +749,60 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df4.schema.forall(!_.nullable)) } + test("intersectAll") { + checkAnswer( + lowerCaseDataWithDuplicates.intersectAll(lowerCaseDataWithDuplicates), + Row(1, "a") :: + Row(2, "b") :: + Row(2, "b") :: + Row(3, "c") :: + Row(3, "c") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.intersectAll(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.intersectAll(nullInts), + Row(1) :: + Row(2) :: + Row(3) :: + Row(null) :: Nil) + + // Duplicate nulls are preserved. + checkAnswer( + allNulls.intersectAll(allNulls), + Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) + + val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") + val df_right = Seq(1, 2, 2, 3).toDF("id") + + checkAnswer( + df_left.intersectAll(df_right), + Row(1) :: Row(2) :: Row(2) :: Row(3) :: Nil) + } + + test("intersectAll - nullability") { + val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.intersectAll(nullInts) + checkAnswer(df1, Row(1) :: Row(3) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.intersectAll(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(3) :: Nil) + assert(df2.schema.forall(!_.nullable)) + + val df3 = nullInts.intersectAll(nullInts) + checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.intersectAll(nonNullableInts) + checkAnswer(df4, Row(1) :: Row(3) :: Nil) + assert(df4.schema.forall(!_.nullable)) + } + test("udf") { val foo = udf((a: Int, b: String) => a.toString + b) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 0cfe260e52152..deea9dbb30aae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -136,6 +136,19 @@ private[sql] trait SQLTestData { self => df } + protected lazy val lowerCaseDataWithDuplicates: DataFrame = { + val df = spark.sparkContext.parallelize( + LowerCaseData(1, "a") :: + LowerCaseData(2, "b") :: + LowerCaseData(2, "b") :: + LowerCaseData(3, "c") :: + LowerCaseData(3, "c") :: + LowerCaseData(3, "c") :: + LowerCaseData(4, "d") :: Nil).toDF() + df.createOrReplaceTempView("lowerCaseData") + df + } + protected lazy val arrayData: RDD[ArrayData] = { val rdd = spark.sparkContext.parallelize( ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: From bfe60fcdb49aa48534060c38e36e06119900140d Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 30 Jul 2018 13:20:03 +0800 Subject: [PATCH 1245/2461] [SPARK-24934][SQL] Explicitly whitelist supported types in upper/lower bounds for in-memory partition pruning ## What changes were proposed in this pull request? Looks we intentionally set `null` for upper/lower bounds for complex types and don't use it. However, these look used in in-memory partition pruning, which ends up with incorrect results. This PR proposes to explicitly whitelist the supported types. ```scala val df = Seq(Array("a", "b"), Array("c", "d")).toDF("arrayCol") df.cache().filter("arrayCol > array('a', 'b')").show() ``` ```scala val df = sql("select cast('a' as binary) as a") df.cache().filter("a == cast('a' as binary)").show() ``` **Before:** ``` +--------+ |arrayCol| +--------+ +--------+ ``` ``` +---+ | a| +---+ +---+ ``` **After:** ``` +--------+ |arrayCol| +--------+ | [c, d]| +--------+ ``` ``` +----+ | a| +----+ |[61]| +----+ ``` ## How was this patch tested? Unit tests were added and manually tested. Author: hyukjinkwon Closes #21882 from HyukjinKwon/stats-filter. --- .../columnar/InMemoryTableScanExec.scala | 42 +++++++++++++------ .../columnar/PartitionBatchPruningSuite.scala | 30 ++++++++++++- 2 files changed, 58 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 997cf92449c68..6012aba1acbca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -183,6 +183,18 @@ case class InMemoryTableScanExec( private val stats = relation.partitionStatistics private def statsFor(a: Attribute) = stats.forAttribute(a) + // Currently, only use statistics from atomic types except binary type only. + private object ExtractableLiteral { + def unapply(expr: Expression): Option[Literal] = expr match { + case lit: Literal => lit.dataType match { + case BinaryType => None + case _: AtomicType => Some(lit) + case _ => None + } + case _ => None + } + } + // Returned filter predicate should return false iff it is impossible for the input expression // to evaluate to `true' based on statistics collected about this partition batch. @transient lazy val buildFilter: PartialFunction[Expression, Expression] = { @@ -194,33 +206,37 @@ case class InMemoryTableScanExec( if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) => buildFilter(lhs) || buildFilter(rhs) - case EqualTo(a: AttributeReference, l: Literal) => + case EqualTo(a: AttributeReference, ExtractableLiteral(l)) => statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound - case EqualTo(l: Literal, a: AttributeReference) => + case EqualTo(ExtractableLiteral(l), a: AttributeReference) => statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound - case EqualNullSafe(a: AttributeReference, l: Literal) => + case EqualNullSafe(a: AttributeReference, ExtractableLiteral(l)) => statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound - case EqualNullSafe(l: Literal, a: AttributeReference) => + case EqualNullSafe(ExtractableLiteral(l), a: AttributeReference) => statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound - case LessThan(a: AttributeReference, l: Literal) => statsFor(a).lowerBound < l - case LessThan(l: Literal, a: AttributeReference) => l < statsFor(a).upperBound + case LessThan(a: AttributeReference, ExtractableLiteral(l)) => statsFor(a).lowerBound < l + case LessThan(ExtractableLiteral(l), a: AttributeReference) => l < statsFor(a).upperBound - case LessThanOrEqual(a: AttributeReference, l: Literal) => statsFor(a).lowerBound <= l - case LessThanOrEqual(l: Literal, a: AttributeReference) => l <= statsFor(a).upperBound + case LessThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) => + statsFor(a).lowerBound <= l + case LessThanOrEqual(ExtractableLiteral(l), a: AttributeReference) => + l <= statsFor(a).upperBound - case GreaterThan(a: AttributeReference, l: Literal) => l < statsFor(a).upperBound - case GreaterThan(l: Literal, a: AttributeReference) => statsFor(a).lowerBound < l + case GreaterThan(a: AttributeReference, ExtractableLiteral(l)) => l < statsFor(a).upperBound + case GreaterThan(ExtractableLiteral(l), a: AttributeReference) => statsFor(a).lowerBound < l - case GreaterThanOrEqual(a: AttributeReference, l: Literal) => l <= statsFor(a).upperBound - case GreaterThanOrEqual(l: Literal, a: AttributeReference) => statsFor(a).lowerBound <= l + case GreaterThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) => + l <= statsFor(a).upperBound + case GreaterThanOrEqual(ExtractableLiteral(l), a: AttributeReference) => + statsFor(a).lowerBound <= l case IsNull(a: Attribute) => statsFor(a).nullCount > 0 case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 case In(a: AttributeReference, list: Seq[Expression]) - if list.forall(_.isInstanceOf[Literal]) && list.nonEmpty => + if list.forall(ExtractableLiteral.unapply(_).isDefined) && list.nonEmpty => list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] && l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala index 9d862cfdecb21..af493e93b5192 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.columnar import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ @@ -35,6 +36,12 @@ class PartitionBatchPruningSuite private lazy val originalColumnBatchSize = spark.conf.get(SQLConf.COLUMN_BATCH_SIZE) private lazy val originalInMemoryPartitionPruning = spark.conf.get(SQLConf.IN_MEMORY_PARTITION_PRUNING) + private val testArrayData = (1 to 100).map { key => + Tuple1(Array.fill(key)(key)) + } + private val testBinaryData = (1 to 100).map { key => + Tuple1(Array.fill(key)(key.toByte)) + } override protected def beforeAll(): Unit = { super.beforeAll() @@ -71,12 +78,22 @@ class PartitionBatchPruningSuite }, 5).toDF() pruningStringData.createOrReplaceTempView("pruningStringData") spark.catalog.cacheTable("pruningStringData") + + val pruningArrayData = sparkContext.makeRDD(testArrayData, 5).toDF() + pruningArrayData.createOrReplaceTempView("pruningArrayData") + spark.catalog.cacheTable("pruningArrayData") + + val pruningBinaryData = sparkContext.makeRDD(testBinaryData, 5).toDF() + pruningBinaryData.createOrReplaceTempView("pruningBinaryData") + spark.catalog.cacheTable("pruningBinaryData") } override protected def afterEach(): Unit = { try { spark.catalog.uncacheTable("pruningData") spark.catalog.uncacheTable("pruningStringData") + spark.catalog.uncacheTable("pruningArrayData") + spark.catalog.uncacheTable("pruningBinaryData") } finally { super.afterEach() } @@ -95,6 +112,14 @@ class PartitionBatchPruningSuite checkBatchPruning("SELECT key FROM pruningData WHERE 11 >= key", 1, 2)(1 to 11) checkBatchPruning("SELECT key FROM pruningData WHERE 88 < key", 1, 2)(89 to 100) checkBatchPruning("SELECT key FROM pruningData WHERE 89 <= key", 1, 2)(89 to 100) + // Do not filter on array type + checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 = array(1)", 5, 10)(Seq(Array(1))) + checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 <= array(1)", 5, 10)(Seq(Array(1))) + checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 >= array(1)", 5, 10)( + testArrayData.map(_._1)) + // Do not filter on binary type + checkBatchPruning( + "SELECT _1 FROM pruningBinaryData WHERE _1 == binary(chr(1))", 5, 10)(Seq(Array(1.toByte))) // IS NULL checkBatchPruning("SELECT key FROM pruningData WHERE value IS NULL", 5, 5) { @@ -131,6 +156,9 @@ class PartitionBatchPruningSuite checkBatchPruning( "SELECT CAST(s AS INT) FROM pruningStringData WHERE s IN ('99', '150', '201')", 1, 1)( Seq(150)) + // Do not filter on array type + checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 IN (array(1), array(2, 2))", 5, 10)( + Seq(Array(1), Array(2, 2))) // With unsupported `InSet` predicate { @@ -161,7 +189,7 @@ class PartitionBatchPruningSuite query: String, expectedReadPartitions: Int, expectedReadBatches: Int)( - expectedQueryResult: => Seq[Int]): Unit = { + expectedQueryResult: => Seq[Any]): Unit = { test(query) { val df = sql(query) From 85505fc8a58ca229bbaf240c6bc23ea876d594db Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 30 Jul 2018 20:53:45 +0800 Subject: [PATCH 1246/2461] [SPARK-24957][SQL] Average with decimal followed by aggregation returns wrong result ## What changes were proposed in this pull request? When we do an average, the result is computed dividing the sum of the values by their count. In the case the result is a DecimalType, the way we are casting/managing the precision and scale is not really optimized and it is not coherent with what we do normally. In particular, a problem can happen when the `Divide` operand returns a result which contains a precision and scale different by the ones which are expected as output of the `Divide` operand. In the case reported in the JIRA, for instance, the result of the `Divide` operand is a `Decimal(38, 36)`, while the output data type for `Divide` is 38, 22. This is not an issue when the `Divide` is followed by a `CheckOverflow` or a `Cast` to the right data type, as these operations return a decimal with the defined precision and scale. Despite in the `Average` operator we do have a `Cast`, this may be bypassed if the result of `Divide` is the same type which it is casted to, hence the issue reported in the JIRA may arise. The PR proposes to use the normal rules/handling of the arithmetic operators with Decimal data type, so we both reuse the existing code (having a single logic for operations between decimals) and we fix this problem as the result is always guarded by `CheckOverflow`. ## How was this patch tested? added UT Author: Marco Gaido Closes #21910 from mgaido91/SPARK-24957. --- .../sql/catalyst/analysis/DecimalPrecision.scala | 2 +- .../catalyst/expressions/aggregate/Average.scala | 9 ++++----- .../sql/hive/execution/AggregationQuerySuite.scala | 13 +++++++++++++ 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index 65a5888222f2e..23d146e71ed19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -89,7 +89,7 @@ object DecimalPrecision extends TypeCoercionRule { } /** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. */ - private val decimalAndDecimal: PartialFunction[Expression, Expression] = { + private[catalyst] val decimalAndDecimal: PartialFunction[Expression, Expression] = { // Skip nodes whose children have not been resolved yet case e if !e.childrenResolved => e diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index a133bc2361eb5..9ccf5aa092d11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, TypeCheckResult} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils @@ -57,10 +57,9 @@ abstract class AverageLike(child: Expression) extends DeclarativeAggregate { // If all input are nulls, count will be 0 and we will get null after the division. override lazy val evaluateExpression = child.dataType match { - case DecimalType.Fixed(p, s) => - // increase the precision and scale to prevent precision loss - val dt = DecimalType.bounded(p + 14, s + 4) - Cast(Cast(sum, dt) / Cast(count, DecimalType.bounded(DecimalType.MAX_PRECISION, 0)), + case _: DecimalType => + Cast( + DecimalPrecision.decimalAndDecimal(sum / Cast(count, DecimalType.LongDecimal)), resultType) case _ => Cast(sum, resultType) / Cast(count, resultType) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index ae675149df5e2..c65bf7c14c7a5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -1005,6 +1005,19 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te ) ) } + + test("SPARK-24957: average with decimal followed by aggregation returning wrong result") { + val df = Seq(("a", BigDecimal("12.0")), + ("a", BigDecimal("12.0")), + ("a", BigDecimal("11.9999999988")), + ("a", BigDecimal("12.0")), + ("a", BigDecimal("12.0")), + ("a", BigDecimal("11.9999999988")), + ("a", BigDecimal("11.9999999988"))).toDF("text", "number") + val agg1 = df.groupBy($"text").agg(avg($"number").as("avg_res")) + val agg2 = agg1.groupBy($"text").agg(sum($"avg_res")) + checkAnswer(agg2, Row("a", BigDecimal("11.9999999994857142860000"))) + } } From fca0b8528e704cfe62863a34f8bb5dcee850b046 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 30 Jul 2018 21:13:08 +0800 Subject: [PATCH 1247/2461] [SPARK-24967][SQL] Avro: Use internal.Logging instead for logging ## What changes were proposed in this pull request? Looks Avro uses direct `getLogger` to create a SLF4J logger. Should better use `internal.Logging` instead. ## How was this patch tested? Exiting tests. Author: hyukjinkwon Closes #21914 from HyukjinKwon/avro-log. --- .../apache/spark/sql/avro/AvroFileFormat.scala | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index 1df1c8b4af2e9..e0159b9320276 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.avro import java.io._ import java.net.URI -import java.util.zip.Deflater import scala.util.control.NonFatal @@ -31,9 +30,9 @@ import org.apache.avro.mapreduce.AvroJob import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.Job -import org.slf4j.LoggerFactory import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile} @@ -41,8 +40,8 @@ import org.apache.spark.sql.sources.{DataSourceRegister, Filter} import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration -private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { - private val log = LoggerFactory.getLogger(getClass) +private[avro] class AvroFileFormat extends FileFormat + with DataSourceRegister with Logging with Serializable { override def equals(other: Any): Boolean = other match { case _: AvroFileFormat => true @@ -121,23 +120,23 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { parsedOptions.compression match { case "uncompressed" => - log.info("writing uncompressed Avro records") + logInfo("writing uncompressed Avro records") job.getConfiguration.setBoolean(COMPRESS_KEY, false) case "snappy" => - log.info("compressing Avro output using Snappy") + logInfo("compressing Avro output using Snappy") job.getConfiguration.setBoolean(COMPRESS_KEY, true) job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, DataFileConstants.SNAPPY_CODEC) case "deflate" => val deflateLevel = spark.sessionState.conf.avroDeflateLevel - log.info(s"compressing Avro output using deflate (level=$deflateLevel)") + logInfo(s"compressing Avro output using deflate (level=$deflateLevel)") job.getConfiguration.setBoolean(COMPRESS_KEY, true) job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, DataFileConstants.DEFLATE_CODEC) job.getConfiguration.setInt(AvroOutputFormat.DEFLATE_LEVEL_KEY, deflateLevel) case unknown: String => - log.error(s"unsupported compression codec $unknown") + logError(s"unsupported compression codec $unknown") } new AvroOutputWriterFactory(dataSchema, outputAvroSchema.toString) @@ -157,7 +156,6 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { val parsedOptions = new AvroOptions(options, hadoopConf) (file: PartitionedFile) => { - val log = LoggerFactory.getLogger(classOf[AvroFileFormat]) val conf = broadcastedConf.value.value val userProvidedSchema = parsedOptions.schema.map(new Schema.Parser().parse) @@ -176,7 +174,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { DataFileReader.openReader(in, datumReader) } catch { case NonFatal(e) => - log.error("Exception while opening DataFileReader", e) + logError("Exception while opening DataFileReader", e) in.close() throw e } From b90bfe3c42eb9b51e6131a8f8923bcddfccd75bb Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 30 Jul 2018 07:30:47 -0700 Subject: [PATCH 1248/2461] [SPARK-24771][BUILD] Upgrade Apache AVRO to 1.8.2 ## What changes were proposed in this pull request? Upgrade Apache Avro from 1.7.7 to 1.8.2. The major new features: 1. More logical types. From the spec of 1.8.2 https://avro.apache.org/docs/1.8.2/spec.html#Logical+Types we can see comparing to [1.7.7](https://avro.apache.org/docs/1.7.7/spec.html#Logical+Types), the new version support: - Date - Time (millisecond precision) - Time (microsecond precision) - Timestamp (millisecond precision) - Timestamp (microsecond precision) - Duration 2. Single-object encoding: https://avro.apache.org/docs/1.8.2/spec.html#single_object_encoding This PR aims to update Apache Spark to support these new features. ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21761 from gengliangwang/upgrade_avro_1.8. --- dev/deps/spark-deps-hadoop-2.6 | 10 +++++----- dev/deps/spark-deps-hadoop-2.7 | 10 +++++----- dev/deps/spark-deps-hadoop-3.1 | 10 +++++----- pom.xml | 2 +- sql/core/pom.xml | 13 ------------- 5 files changed, 16 insertions(+), 29 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index ff6d5c30c1eb4..4ef61b2ab8cb7 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -18,9 +18,9 @@ arrow-format-0.8.0.jar arrow-memory-0.8.0.jar arrow-vector-0.8.0.jar automaton-1.11-8.jar -avro-1.7.7.jar -avro-ipc-1.7.7.jar -avro-mapred-1.7.7-hadoop2.jar +avro-1.8.2.jar +avro-ipc-1.8.2.jar +avro-mapred-1.8.2-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.58.jar bonecp-0.8.0.RELEASE.jar @@ -37,7 +37,7 @@ commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar commons-compiler-3.0.8.jar -commons-compress-1.4.1.jar +commons-compress-1.8.1.jar commons-configuration-1.6.jar commons-crypto-1.0.0.jar commons-dbcp-1.4.jar @@ -196,7 +196,7 @@ validation-api-1.1.0.Final.jar xbean-asm6-shaded-4.8.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar -xz-1.0.jar +xz-1.5.jar zjsonpatch-0.3.0.jar zookeeper-3.4.6.jar zstd-jni-1.3.2-2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 72a94f8953c6c..a74ce1f26b146 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -18,9 +18,9 @@ arrow-format-0.8.0.jar arrow-memory-0.8.0.jar arrow-vector-0.8.0.jar automaton-1.11-8.jar -avro-1.7.7.jar -avro-ipc-1.7.7.jar -avro-mapred-1.7.7-hadoop2.jar +avro-1.8.2.jar +avro-ipc-1.8.2.jar +avro-mapred-1.8.2-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.58.jar bonecp-0.8.0.RELEASE.jar @@ -37,7 +37,7 @@ commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar commons-compiler-3.0.8.jar -commons-compress-1.4.1.jar +commons-compress-1.8.1.jar commons-configuration-1.6.jar commons-crypto-1.0.0.jar commons-dbcp-1.4.jar @@ -197,7 +197,7 @@ validation-api-1.1.0.Final.jar xbean-asm6-shaded-4.8.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar -xz-1.0.jar +xz-1.5.jar zjsonpatch-0.3.0.jar zookeeper-3.4.6.jar zstd-jni-1.3.2-2.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 3409dc4613324..e0fcca0eeb31e 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -16,9 +16,9 @@ arrow-format-0.8.0.jar arrow-memory-0.8.0.jar arrow-vector-0.8.0.jar automaton-1.11-8.jar -avro-1.7.7.jar -avro-ipc-1.7.7.jar -avro-mapred-1.7.7-hadoop2.jar +avro-1.8.2.jar +avro-ipc-1.8.2.jar +avro-mapred-1.8.2-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.58.jar bonecp-0.8.0.RELEASE.jar @@ -34,7 +34,7 @@ commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar commons-compiler-3.0.8.jar -commons-compress-1.4.1.jar +commons-compress-1.8.1.jar commons-configuration2-2.1.1.jar commons-crypto-1.0.0.jar commons-daemon-1.0.13.jar @@ -216,7 +216,7 @@ univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar woodstox-core-5.0.3.jar xbean-asm6-shaded-4.8.jar -xz-1.0.jar +xz-1.5.jar zjsonpatch-0.3.0.jar zookeeper-3.4.9.jar zstd-jni-1.3.2-2.jar diff --git a/pom.xml b/pom.xml index 9f60edcdc023e..be84661a50dcc 100644 --- a/pom.xml +++ b/pom.xml @@ -140,7 +140,7 @@ 2.4.0 2.0.8 3.1.5 - 1.7.7 + 1.8.2 hadoop2 0.9.4 1.7.3 diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 8873b00e7117a..9cd6776a18bcb 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -174,19 +174,6 @@ parquet-avro test - - - org.apache.avro - avro - 1.8.1 - test - org.mockito mockito-core From 47d84e4d0e56e14f9402770dceaf0b4302c00e98 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 30 Jul 2018 07:42:00 -0700 Subject: [PATCH 1249/2461] [SPARK-22814][SQL] Support Date/Timestamp in a JDBC partition column ## What changes were proposed in this pull request? This pr supported Date/Timestamp in a JDBC partition column (a numeric column is only supported in the master). This pr also modified code to verify a partition column type; ``` val jdbcTable = spark.read .option("partitionColumn", "text") .option("lowerBound", "aaa") .option("upperBound", "zzz") .option("numPartitions", 2) .jdbc("jdbc:postgresql:postgres", "t", options) // with this pr org.apache.spark.sql.AnalysisException: Partition column type should be numeric, date, or timestamp, but string found.; at org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation$.verifyAndGetNormalizedPartitionColumn(JDBCRelation.scala:165) at org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation$.columnPartition(JDBCRelation.scala:85) at org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider.createRelation(JdbcRelationProvider.scala:36) at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:317) // without this pr java.lang.NumberFormatException: For input string: "aaa" at java.lang.NumberFormatException.forInputString(NumberFormatException.java:65) at java.lang.Long.parseLong(Long.java:589) at java.lang.Long.parseLong(Long.java:631) at scala.collection.immutable.StringLike$class.toLong(StringLike.scala:277) ``` Closes #19999 ## How was this patch tested? Added tests in `JDBCSuite`. Author: Takeshi Yamamuro Closes #21834 from maropu/SPARK-22814. --- docs/sql-programming-guide.md | 4 +- .../sql/jdbc/OracleIntegrationSuite.scala | 86 +++++++++++++- .../sql/catalyst/util/DateTimeUtils.scala | 10 +- .../datasources/PartitioningUtils.scala | 2 +- .../datasources/jdbc/JDBCOptions.scala | 4 +- .../datasources/jdbc/JDBCRelation.scala | 107 ++++++++++++++---- .../jdbc/JdbcRelationProvider.scala | 21 +--- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 77 ++++++++++++- 8 files changed, 258 insertions(+), 53 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 4b013c633e27c..cff521c06a242 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1345,8 +1345,8 @@ the following case-insensitive options: These options must all be specified if any of them is specified. In addition, numPartitions must be specified. They describe how to partition the table when reading in parallel from multiple workers. - partitionColumn must be a numeric column from the table in question. Notice - that lowerBound and upperBound are just used to decide the + partitionColumn must be a numeric, date, or timestamp column from the table in question. + Notice that lowerBound and upperBound are just used to decide the partition stride, not for filtering the rows in table. So all rows in the table will be partitioned and returned. This option applies only to reading. diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 8512496e5fe52..09a2cd83aed6b 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.jdbc +import java.math.BigDecimal import java.sql.{Connection, Date, Timestamp} import java.util.{Properties, TimeZone} -import java.math.BigDecimal -import org.apache.spark.sql.{DataFrame, QueryTest, Row, SaveMode} +import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.execution.{RowDataSourceScanExec, WholeStageCodegenExec} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCRelation} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -86,7 +88,8 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo conn.prepareStatement( "CREATE TABLE tableWithCustomSchema (id NUMBER, n1 NUMBER(1), n2 NUMBER(1))").executeUpdate() conn.prepareStatement( - "INSERT INTO tableWithCustomSchema values(12312321321321312312312312123, 1, 0)").executeUpdate() + "INSERT INTO tableWithCustomSchema values(12312321321321312312312312123, 1, 0)") + .executeUpdate() conn.commit() sql( @@ -108,15 +111,36 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo """.stripMargin.replaceAll("\n", " ")) - conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))").executeUpdate() + conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))") + .executeUpdate() conn.prepareStatement( "INSERT INTO numerics VALUES (4, 1.23, 9999999999)").executeUpdate() conn.commit() - conn.prepareStatement("CREATE TABLE oracle_types (d BINARY_DOUBLE, f BINARY_FLOAT)").executeUpdate() + conn.prepareStatement("CREATE TABLE oracle_types (d BINARY_DOUBLE, f BINARY_FLOAT)") + .executeUpdate() conn.commit() - } + conn.prepareStatement("CREATE TABLE datetimePartitionTest (id NUMBER(10), d DATE, t TIMESTAMP)") + .executeUpdate() + conn.prepareStatement( + """INSERT INTO datetimePartitionTest VALUES + |(1, {d '2018-07-06'}, {ts '2018-07-06 05:50:00'}) + """.stripMargin.replaceAll("\n", " ")).executeUpdate() + conn.prepareStatement( + """INSERT INTO datetimePartitionTest VALUES + |(2, {d '2018-07-06'}, {ts '2018-07-06 08:10:08'}) + """.stripMargin.replaceAll("\n", " ")).executeUpdate() + conn.prepareStatement( + """INSERT INTO datetimePartitionTest VALUES + |(3, {d '2018-07-08'}, {ts '2018-07-08 13:32:01'}) + """.stripMargin.replaceAll("\n", " ")).executeUpdate() + conn.prepareStatement( + """INSERT INTO datetimePartitionTest VALUES + |(4, {d '2018-07-12'}, {ts '2018-07-12 09:51:15'}) + """.stripMargin.replaceAll("\n", " ")).executeUpdate() + conn.commit() + } test("SPARK-16625 : Importing Oracle numeric types") { val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties) @@ -399,4 +423,54 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo assert(values.getDouble(0) === 1.1) assert(values.getFloat(1) === 2.2f) } + + test("SPARK-22814 support date/timestamp types in partitionColumn") { + val expectedResult = Set( + (1, "2018-07-06", "2018-07-06 05:50:00"), + (2, "2018-07-06", "2018-07-06 08:10:08"), + (3, "2018-07-08", "2018-07-08 13:32:01"), + (4, "2018-07-12", "2018-07-12 09:51:15") + ).map { case (id, date, timestamp) => + Row(BigDecimal.valueOf(id), Date.valueOf(date), Timestamp.valueOf(timestamp)) + } + + // DateType partition column + val df1 = spark.read.format("jdbc") + .option("url", jdbcUrl) + .option("dbtable", "datetimePartitionTest") + .option("partitionColumn", "d") + .option("lowerBound", "2018-07-06") + .option("upperBound", "2018-07-20") + .option("numPartitions", 3) + .load() + + df1.logicalPlan match { + case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) => + val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet + assert(whereClauses === Set( + """"D" < '2018-07-10' or "D" is null""", + """"D" >= '2018-07-10' AND "D" < '2018-07-14'""", + """"D" >= '2018-07-14'""")) + } + assert(df1.collect.toSet === expectedResult) + + // TimestampType partition column + val df2 = spark.read.format("jdbc") + .option("url", jdbcUrl) + .option("dbtable", "datetimePartitionTest") + .option("partitionColumn", "t") + .option("lowerBound", "2018-07-04 03:30:00.0") + .option("upperBound", "2018-07-27 14:11:05.0") + .option("numPartitions", 2) + .load() + + df2.logicalPlan match { + case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) => + val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet + assert(whereClauses === Set( + """"T" < '2018-07-15 20:50:32.5' or "T" is null""", + """"T" >= '2018-07-15 20:50:32.5'""")) + } + assert(df2.collect.toSet === expectedResult) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 80f15053005ff..02813d3939796 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -96,9 +96,9 @@ object DateTimeUtils { } } - def getThreadLocalDateFormat(): DateFormat = { + def getThreadLocalDateFormat(timeZone: TimeZone): DateFormat = { val sdf = threadLocalDateFormat.get() - sdf.setTimeZone(defaultTimeZone()) + sdf.setTimeZone(timeZone) sdf } @@ -144,7 +144,11 @@ object DateTimeUtils { } def dateToString(days: SQLDate): String = - getThreadLocalDateFormat.format(toJavaDate(days)) + getThreadLocalDateFormat(defaultTimeZone()).format(toJavaDate(days)) + + def dateToString(days: SQLDate, timeZone: TimeZone): String = { + getThreadLocalDateFormat(timeZone).format(toJavaDate(days)) + } // Converts Timestamp to string according to Hive TimestampWritable convention. def timestampToString(us: SQLTimestamp): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index f9a24806953e6..c8a5f9864a602 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -410,7 +410,7 @@ object PartitioningUtils { val dateTry = Try { // try and parse the date, if no exception occurs this is a candidate to be resolved as // DateType - DateTimeUtils.getThreadLocalDateFormat.parse(raw) + DateTimeUtils.getThreadLocalDateFormat(DateTimeUtils.defaultTimeZone()).parse(raw) // SPARK-23436: Casting the string to date may still return null if a bad Date is provided. // This can happen since DateFormat.parse may not use the entire text of the given string: // so if there are extra-characters after the date, it returns correctly. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index d80efcedf8c2d..7dfbb9d8b5c05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -119,9 +119,9 @@ class JDBCOptions( // the column used to partition val partitionColumn = parameters.get(JDBC_PARTITION_COLUMN) // the lower bound of partition column - val lowerBound = parameters.get(JDBC_LOWER_BOUND).map(_.toLong) + val lowerBound = parameters.get(JDBC_LOWER_BOUND) // the upper bound of the partition column - val upperBound = parameters.get(JDBC_UPPER_BOUND).map(_.toLong) + val upperBound = parameters.get(JDBC_UPPER_BOUND) // numPartitions is also used for data source writing require((partitionColumn.isEmpty && lowerBound.isEmpty && upperBound.isEmpty) || (partitionColumn.isDefined && lowerBound.isDefined && upperBound.isDefined && diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 4f78f593fa4af..f15014442e3fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.jdbc +import java.sql.{Date, Timestamp} + import scala.collection.mutable.ArrayBuffer import org.apache.spark.Partition @@ -24,9 +26,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, DateType, NumericType, StructType, TimestampType} import org.apache.spark.util.Utils /** @@ -34,6 +37,7 @@ import org.apache.spark.util.Utils */ private[sql] case class JDBCPartitioningInfo( column: String, + columnType: DataType, lowerBound: Long, upperBound: Long, numPartitions: Int) @@ -51,16 +55,43 @@ private[sql] object JDBCRelation extends Logging { * the rows with null value for the partitions column. * * @param schema resolved schema of a JDBC table - * @param partitioning partition information to generate the where clause for each partition * @param resolver function used to determine if two identifiers are equal + * @param timeZoneId timezone ID to be used if a partition column type is date or timestamp * @param jdbcOptions JDBC options that contains url * @return an array of partitions with where clause for each partition */ def columnPartition( schema: StructType, - partitioning: JDBCPartitioningInfo, resolver: Resolver, + timeZoneId: String, jdbcOptions: JDBCOptions): Array[Partition] = { + val partitioning = { + import JDBCOptions._ + + val partitionColumn = jdbcOptions.partitionColumn + val lowerBound = jdbcOptions.lowerBound + val upperBound = jdbcOptions.upperBound + val numPartitions = jdbcOptions.numPartitions + + if (partitionColumn.isEmpty) { + assert(lowerBound.isEmpty && upperBound.isEmpty, "When 'partitionColumn' is not " + + s"specified, '$JDBC_LOWER_BOUND' and '$JDBC_UPPER_BOUND' are expected to be empty") + null + } else { + assert(lowerBound.nonEmpty && upperBound.nonEmpty && numPartitions.nonEmpty, + s"When 'partitionColumn' is specified, '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', and " + + s"'$JDBC_NUM_PARTITIONS' are also required") + + val (column, columnType) = verifyAndGetNormalizedPartitionColumn( + schema, partitionColumn.get, resolver, jdbcOptions) + + val lowerBoundValue = toInternalBoundValue(lowerBound.get, columnType) + val upperBoundValue = toInternalBoundValue(upperBound.get, columnType) + JDBCPartitioningInfo( + column, columnType, lowerBoundValue, upperBoundValue, numPartitions.get) + } + } + if (partitioning == null || partitioning.numPartitions <= 1 || partitioning.lowerBound == partitioning.upperBound) { return Array[Partition](JDBCPartition(null, 0)) @@ -72,6 +103,8 @@ private[sql] object JDBCRelation extends Logging { "Operation not allowed: the lower bound of partitioning column is larger than the upper " + s"bound. Lower bound: $lowerBound; Upper bound: $upperBound") + val boundValueToString: Long => String = + toBoundValueInWhereClause(_, partitioning.columnType, timeZoneId) val numPartitions = if ((upperBound - lowerBound) >= partitioning.numPartitions || /* check for overflow */ (upperBound - lowerBound) < 0) { @@ -80,24 +113,25 @@ private[sql] object JDBCRelation extends Logging { logWarning("The number of partitions is reduced because the specified number of " + "partitions is less than the difference between upper bound and lower bound. " + s"Updated number of partitions: ${upperBound - lowerBound}; Input number of " + - s"partitions: ${partitioning.numPartitions}; Lower bound: $lowerBound; " + - s"Upper bound: $upperBound.") + s"partitions: ${partitioning.numPartitions}; " + + s"Lower bound: ${boundValueToString(lowerBound)}; " + + s"Upper bound: ${boundValueToString(upperBound)}.") upperBound - lowerBound } // Overflow and silliness can happen if you subtract then divide. // Here we get a little roundoff, but that's (hopefully) OK. val stride: Long = upperBound / numPartitions - lowerBound / numPartitions - val column = verifyAndGetNormalizedColumnName( - schema, partitioning.column, resolver, jdbcOptions) - var i: Int = 0 - var currentValue: Long = lowerBound + val column = partitioning.column + var currentValue = lowerBound val ans = new ArrayBuffer[Partition]() while (i < numPartitions) { - val lBound = if (i != 0) s"$column >= $currentValue" else null + val lBoundValue = boundValueToString(currentValue) + val lBound = if (i != 0) s"$column >= $lBoundValue" else null currentValue += stride - val uBound = if (i != numPartitions - 1) s"$column < $currentValue" else null + val uBoundValue = boundValueToString(currentValue) + val uBound = if (i != numPartitions - 1) s"$column < $uBoundValue" else null val whereClause = if (uBound == null) { lBound @@ -109,23 +143,58 @@ private[sql] object JDBCRelation extends Logging { ans += JDBCPartition(whereClause, i) i = i + 1 } - ans.toArray + val partitions = ans.toArray + logInfo(s"Number of partitions: $numPartitions, WHERE clauses of these partitions: " + + partitions.map(_.asInstanceOf[JDBCPartition].whereClause).mkString(", ")) + partitions } - // Verify column name based on the JDBC resolved schema - private def verifyAndGetNormalizedColumnName( + // Verify column name and type based on the JDBC resolved schema + private def verifyAndGetNormalizedPartitionColumn( schema: StructType, columnName: String, resolver: Resolver, - jdbcOptions: JDBCOptions): String = { + jdbcOptions: JDBCOptions): (String, DataType) = { val dialect = JdbcDialects.get(jdbcOptions.url) - schema.map(_.name).find { fieldName => - resolver(fieldName, columnName) || - resolver(dialect.quoteIdentifier(fieldName), columnName) - }.map(dialect.quoteIdentifier).getOrElse { + val column = schema.find { f => + resolver(f.name, columnName) || resolver(dialect.quoteIdentifier(f.name), columnName) + }.getOrElse { throw new AnalysisException(s"User-defined partition column $columnName not " + s"found in the JDBC relation: ${schema.simpleString(Utils.maxNumToStringFields)}") } + column.dataType match { + case _: NumericType | DateType | TimestampType => + case _ => + throw new AnalysisException( + s"Partition column type should be ${NumericType.simpleString}, " + + s"${DateType.catalogString}, or ${TimestampType.catalogString}, but " + + s"${column.dataType.catalogString} found.") + } + (dialect.quoteIdentifier(column.name), column.dataType) + } + + private def toInternalBoundValue(value: String, columnType: DataType): Long = columnType match { + case _: NumericType => value.toLong + case DateType => DateTimeUtils.fromJavaDate(Date.valueOf(value)).toLong + case TimestampType => DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(value)) + } + + private def toBoundValueInWhereClause( + value: Long, + columnType: DataType, + timeZoneId: String): String = { + def dateTimeToString(): String = { + val timeZone = DateTimeUtils.getTimeZone(timeZoneId) + val dateTimeStr = columnType match { + case DateType => DateTimeUtils.dateToString(value.toInt, timeZone) + case TimestampType => DateTimeUtils.timestampToString(value, timeZone) + } + s"'$dateTimeStr'" + } + columnType match { + case _: NumericType => value.toString + case DateType | TimestampType => dateTimeToString() + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index 782d626c1573c..e7456f9c8ed0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -29,28 +29,11 @@ class JdbcRelationProvider extends CreatableRelationProvider override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - import JDBCOptions._ - val jdbcOptions = new JDBCOptions(parameters) - val partitionColumn = jdbcOptions.partitionColumn - val lowerBound = jdbcOptions.lowerBound - val upperBound = jdbcOptions.upperBound - val numPartitions = jdbcOptions.numPartitions - - val partitionInfo = if (partitionColumn.isEmpty) { - assert(lowerBound.isEmpty && upperBound.isEmpty, "When 'partitionColumn' is not specified, " + - s"'$JDBC_LOWER_BOUND' and '$JDBC_UPPER_BOUND' are expected to be empty") - null - } else { - assert(lowerBound.nonEmpty && upperBound.nonEmpty && numPartitions.nonEmpty, - s"When 'partitionColumn' is specified, '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', and " + - s"'$JDBC_NUM_PARTITIONS' are also required") - JDBCPartitioningInfo( - partitionColumn.get, lowerBound.get, upperBound.get, numPartitions.get) - } val resolver = sqlContext.conf.resolver + val timeZoneId = sqlContext.conf.sessionLocalTimeZone val schema = JDBCRelation.getSchema(resolver, jdbcOptions) - val parts = JDBCRelation.columnPartition(schema, partitionInfo, resolver, jdbcOptions) + val parts = JDBCRelation.columnPartition(schema, resolver, timeZoneId, jdbcOptions) JDBCRelation(schema, parts, jdbcOptions)(sqlContext.sparkSession) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 0edbd3a55e17e..7fa0e7fc162ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -24,7 +24,7 @@ import java.util.{Calendar, GregorianCalendar, Properties} import org.h2.jdbc.JdbcSQLException import org.scalatest.{BeforeAndAfter, PrivateMethodTester} -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -244,6 +244,17 @@ class JDBCSuite extends QueryTest .executeUpdate() conn.commit() + conn.prepareStatement("CREATE TABLE test.datetime (d DATE, t TIMESTAMP)").executeUpdate() + conn.prepareStatement( + "INSERT INTO test.datetime VALUES ('2018-07-06', '2018-07-06 05:50:00.0')").executeUpdate() + conn.prepareStatement( + "INSERT INTO test.datetime VALUES ('2018-07-06', '2018-07-06 08:10:08.0')").executeUpdate() + conn.prepareStatement( + "INSERT INTO test.datetime VALUES ('2018-07-08', '2018-07-08 13:32:01.0')").executeUpdate() + conn.prepareStatement( + "INSERT INTO test.datetime VALUES ('2018-07-12', '2018-07-12 09:51:15.0')").executeUpdate() + conn.commit() + // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types. } @@ -1375,7 +1386,71 @@ class JDBCSuite extends QueryTest checkAnswer( sql("select name, theid from queryOption"), Row("fred", 1) :: Nil) + } + + test("SPARK-22814 support date/timestamp types in partitionColumn") { + val expectedResult = Seq( + ("2018-07-06", "2018-07-06 05:50:00.0"), + ("2018-07-06", "2018-07-06 08:10:08.0"), + ("2018-07-08", "2018-07-08 13:32:01.0"), + ("2018-07-12", "2018-07-12 09:51:15.0") + ).map { case (date, timestamp) => + Row(Date.valueOf(date), Timestamp.valueOf(timestamp)) + } + + // DateType partition column + val df1 = spark.read.format("jdbc") + .option("url", urlWithUserAndPass) + .option("dbtable", "TEST.DATETIME") + .option("partitionColumn", "d") + .option("lowerBound", "2018-07-06") + .option("upperBound", "2018-07-20") + .option("numPartitions", 3) + .load() + df1.logicalPlan match { + case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) => + val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet + assert(whereClauses === Set( + """"D" < '2018-07-10' or "D" is null""", + """"D" >= '2018-07-10' AND "D" < '2018-07-14'""", + """"D" >= '2018-07-14'""")) + } + checkAnswer(df1, expectedResult) + + // TimestampType partition column + val df2 = spark.read.format("jdbc") + .option("url", urlWithUserAndPass) + .option("dbtable", "TEST.DATETIME") + .option("partitionColumn", "t") + .option("lowerBound", "2018-07-04 03:30:00.0") + .option("upperBound", "2018-07-27 14:11:05.0") + .option("numPartitions", 2) + .load() + + df2.logicalPlan match { + case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) => + val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet + assert(whereClauses === Set( + """"T" < '2018-07-15 20:50:32.5' or "T" is null""", + """"T" >= '2018-07-15 20:50:32.5'""")) + } + checkAnswer(df2, expectedResult) + } + + test("throws an exception for unsupported partition column types") { + val errMsg = intercept[AnalysisException] { + spark.read.format("jdbc") + .option("url", urlWithUserAndPass) + .option("dbtable", "TEST.PEOPLE") + .option("partitionColumn", "name") + .option("lowerBound", "aaa") + .option("upperBound", "zzz") + .option("numPartitions", 2) + .load() + }.getMessage + assert(errMsg.contains( + "Partition column type should be numeric, date, or timestamp, but string found.")) } test("SPARK-24288: Enable preventing predicate pushdown") { From d6b7545b5f495a496d40a982e0ab0f8053e1a4f5 Mon Sep 17 00:00:00 2001 From: mcheah Date: Mon, 30 Jul 2018 11:41:02 -0700 Subject: [PATCH 1250/2461] [SPARK-24963][K8S][TESTS] Don't set service account name for client mode test ## What changes were proposed in this pull request? Don't set service account name for the pod created in client mode ## How was this patch tested? Test should continue running smoothly in Jenkins. Author: mcheah Closes #21900 from mccheah/fix-integration-test-service-account. --- .../spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala index 0690db7f67cf0..6affee2e5141a 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala @@ -61,7 +61,6 @@ trait ClientModeTestsSuite { k8sSuite: KubernetesSuite => .withLabels(labels.asJava) .endMetadata() .withNewSpec() - .withServiceAccountName("default") .addNewContainer() .withName("spark-example") .withImage(image) From abbb4ab4d8b12ba2d94b16407c0d62ae207ee4fa Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 30 Jul 2018 14:05:45 -0700 Subject: [PATCH 1251/2461] [SPARK-24865][SQL] Remove AnalysisBarrier addendum ## What changes were proposed in this pull request? I didn't want to pollute the diff in the previous PR and left some TODOs. This is a follow-up to address those TODOs. ## How was this patch tested? Should be covered by existing tests. Author: Reynold Xin Closes #21896 from rxin/SPARK-24865-addendum. --- .../sql/catalyst/analysis/Analyzer.scala | 1 - .../scala/org/apache/spark/sql/Dataset.scala | 88 +++++++++---------- .../spark/sql/RelationalGroupedDataset.scala | 13 ++- .../ContinuousAggregationSuite.scala | 4 +- 4 files changed, 50 insertions(+), 56 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9965cd654bcb5..1488ededa38d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -799,7 +799,6 @@ class Analyzer( right case Some((oldRelation, newRelation)) => val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) - // TODO(rxin): Why do we need transformUp here? right transformUp { case r if r == oldRelation => newRelation } transformUp { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d36c8d13acca9..3b0a6d8840f1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -195,10 +195,6 @@ class Dataset[T] private[sql]( } } - // Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again. - // TODO(rxin): remove this later. - @transient private[sql] val planWithBarrier = logicalPlan - /** * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the * passed in encoder to [[ExpressionEncoder]] explicitly, and mark it implicit so that we can use @@ -427,7 +423,7 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, planWithBarrier) + def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan) /** * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. @@ -681,7 +677,7 @@ class Dataset[T] private[sql]( require(parsedDelay.milliseconds >= 0 && parsedDelay.months >= 0, s"delay threshold ($delayThreshold) should not be negative.") EliminateEventTimeWatermark( - EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, planWithBarrier)) + EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)) } /** @@ -854,7 +850,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def join(right: Dataset[_]): DataFrame = withPlan { - Join(planWithBarrier, right.planWithBarrier, joinType = Inner, None) + Join(logicalPlan, right.logicalPlan, joinType = Inner, None) } /** @@ -932,7 +928,7 @@ class Dataset[T] private[sql]( // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sparkSession.sessionState.executePlan( - Join(planWithBarrier, right.planWithBarrier, joinType = JoinType(joinType), None)) + Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None)) .analyzed.asInstanceOf[Join] withPlan { @@ -993,7 +989,7 @@ class Dataset[T] private[sql]( // Trigger analysis so in the case of self-join, the analyzer will clone the plan. // After the cloning, left and right side will have distinct expression ids. val plan = withPlan( - Join(planWithBarrier, right.planWithBarrier, JoinType(joinType), Some(joinExprs.expr))) + Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))) .queryExecution.analyzed.asInstanceOf[Join] // If auto self join alias is disabled, return the plan. @@ -1002,8 +998,8 @@ class Dataset[T] private[sql]( } // If left/right have no output set intersection, return the plan. - val lanalyzed = withPlan(this.planWithBarrier).queryExecution.analyzed - val ranalyzed = withPlan(right.planWithBarrier).queryExecution.analyzed + val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed + val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) { return withPlan(plan) } @@ -1040,7 +1036,7 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ def crossJoin(right: Dataset[_]): DataFrame = withPlan { - Join(planWithBarrier, right.planWithBarrier, joinType = Cross, None) + Join(logicalPlan, right.logicalPlan, joinType = Cross, None) } /** @@ -1072,8 +1068,8 @@ class Dataset[T] private[sql]( // etc. val joined = sparkSession.sessionState.executePlan( Join( - this.planWithBarrier, - other.planWithBarrier, + this.logicalPlan, + other.logicalPlan, JoinType(joinType), Some(condition.expr))).analyzed.asInstanceOf[Join] @@ -1294,7 +1290,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def as(alias: String): Dataset[T] = withTypedPlan { - SubqueryAlias(alias, planWithBarrier) + SubqueryAlias(alias, logicalPlan) } /** @@ -1332,7 +1328,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def select(cols: Column*): DataFrame = withPlan { - Project(cols.map(_.named), planWithBarrier) + Project(cols.map(_.named), logicalPlan) } /** @@ -1387,8 +1383,8 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { implicit val encoder = c1.encoder - val project = Project(c1.withInputType(exprEnc, planWithBarrier.output).named :: Nil, - planWithBarrier) + val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, + logicalPlan) if (encoder.flat) { new Dataset[U1](sparkSession, project, encoder) @@ -1406,8 +1402,8 @@ class Dataset[T] private[sql]( protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(exprEnc, planWithBarrier.output).named) - val execution = new QueryExecution(sparkSession, Project(namedColumns, planWithBarrier)) + columns.map(_.withInputType(exprEnc, logicalPlan.output).named) + val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan)) new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders)) } @@ -1483,7 +1479,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def filter(condition: Column): Dataset[T] = withTypedPlan { - Filter(condition.expr, planWithBarrier) + Filter(condition.expr, logicalPlan) } /** @@ -1662,7 +1658,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { - val inputPlan = planWithBarrier + val inputPlan = logicalPlan val withGroupingKey = AppendColumns(func, inputPlan) val executed = sparkSession.sessionState.executePlan(withGroupingKey) @@ -1808,7 +1804,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def limit(n: Int): Dataset[T] = withTypedPlan { - Limit(Literal(n), planWithBarrier) + Limit(Literal(n), logicalPlan) } /** @@ -1931,7 +1927,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def intersect(other: Dataset[T]): Dataset[T] = withSetOperator { - Intersect(planWithBarrier, other.planWithBarrier) + Intersect(logicalPlan, other.logicalPlan) } /** @@ -1962,7 +1958,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def except(other: Dataset[T]): Dataset[T] = withSetOperator { - Except(planWithBarrier, other.planWithBarrier) + Except(logicalPlan, other.logicalPlan) } /** @@ -2029,7 +2025,7 @@ class Dataset[T] private[sql]( */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = { withTypedPlan { - Sample(0.0, fraction, withReplacement, seed, planWithBarrier) + Sample(0.0, fraction, withReplacement, seed, logicalPlan) } } @@ -2071,15 +2067,15 @@ class Dataset[T] private[sql]( // overlapping splits. To prevent this, we explicitly sort each input partition to make the // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out // from the sort order. - val sortOrder = planWithBarrier.output + val sortOrder = logicalPlan.output .filter(attr => RowOrdering.isOrderable(attr.dataType)) .map(SortOrder(_, Ascending)) val plan = if (sortOrder.nonEmpty) { - Sort(sortOrder, global = false, planWithBarrier) + Sort(sortOrder, global = false, logicalPlan) } else { // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism cache() - planWithBarrier + logicalPlan } val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) @@ -2163,7 +2159,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, unrequiredChildIndex = Nil, outer = false, - qualifier = None, generatorOutput = Nil, planWithBarrier) + qualifier = None, generatorOutput = Nil, logicalPlan) } } @@ -2204,7 +2200,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, unrequiredChildIndex = Nil, outer = false, - qualifier = None, generatorOutput = Nil, planWithBarrier) + qualifier = None, generatorOutput = Nil, logicalPlan) } } @@ -2355,7 +2351,7 @@ class Dataset[T] private[sql]( u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u) case Column(expr: Expression) => expr } - val attrs = this.planWithBarrier.output + val attrs = this.logicalPlan.output val colsAfterDrop = attrs.filter { attr => attr != expression }.map(attr => Column(attr)) @@ -2403,7 +2399,7 @@ class Dataset[T] private[sql]( } cols } - Deduplicate(groupCols, planWithBarrier) + Deduplicate(groupCols, logicalPlan) } /** @@ -2585,7 +2581,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def filter(func: T => Boolean): Dataset[T] = { - withTypedPlan(TypedFilter(func, planWithBarrier)) + withTypedPlan(TypedFilter(func, logicalPlan)) } /** @@ -2599,7 +2595,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def filter(func: FilterFunction[T]): Dataset[T] = { - withTypedPlan(TypedFilter(func, planWithBarrier)) + withTypedPlan(TypedFilter(func, logicalPlan)) } /** @@ -2613,7 +2609,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { - MapElements[T, U](func, planWithBarrier) + MapElements[T, U](func, logicalPlan) } /** @@ -2628,7 +2624,7 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { implicit val uEnc = encoder - withTypedPlan(MapElements[T, U](func, planWithBarrier)) + withTypedPlan(MapElements[T, U](func, logicalPlan)) } /** @@ -2644,7 +2640,7 @@ class Dataset[T] private[sql]( def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { new Dataset[U]( sparkSession, - MapPartitions[T, U](func, planWithBarrier), + MapPartitions[T, U](func, logicalPlan), implicitly[Encoder[U]]) } @@ -2675,7 +2671,7 @@ class Dataset[T] private[sql]( val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]] Dataset.ofRows( sparkSession, - MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, planWithBarrier)) + MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, logicalPlan)) } /** @@ -2839,7 +2835,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def repartition(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = true, planWithBarrier) + Repartition(numPartitions, shuffle = true, logicalPlan) } /** @@ -2862,7 +2858,7 @@ class Dataset[T] private[sql]( |For range partitioning use repartitionByRange(...) instead. """.stripMargin) withTypedPlan { - RepartitionByExpression(partitionExprs.map(_.expr), planWithBarrier, numPartitions) + RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions) } } @@ -2900,7 +2896,7 @@ class Dataset[T] private[sql]( case expr: Expression => SortOrder(expr, Ascending) }) withTypedPlan { - RepartitionByExpression(sortOrder, planWithBarrier, numPartitions) + RepartitionByExpression(sortOrder, logicalPlan, numPartitions) } } @@ -2939,7 +2935,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = false, planWithBarrier) + Repartition(numPartitions, shuffle = false, logicalPlan) } /** @@ -3024,7 +3020,7 @@ class Dataset[T] private[sql]( // Represents the `QueryExecution` used to produce the content of the Dataset as an `RDD`. @transient private lazy val rddQueryExecution: QueryExecution = { - val deserialized = CatalystSerde.deserialize[T](planWithBarrier) + val deserialized = CatalystSerde.deserialize[T](logicalPlan) sparkSession.sessionState.executePlan(deserialized) } @@ -3150,7 +3146,7 @@ class Dataset[T] private[sql]( comment = None, properties = Map.empty, originalText = None, - child = planWithBarrier, + child = logicalPlan, allowExisting = false, replace = replace, viewType = viewType) @@ -3363,7 +3359,7 @@ class Dataset[T] private[sql]( } } withTypedPlan { - Sort(sortOrder, global = global, planWithBarrier) + Sort(sortOrder, global = global, logicalPlan) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index b068493f2dd17..8412219b1250b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -62,18 +62,17 @@ class RelationalGroupedDataset protected[sql]( groupType match { case RelationalGroupedDataset.GroupByType => - Dataset.ofRows( - df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.planWithBarrier)) + Dataset.ofRows(df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) case RelationalGroupedDataset.RollupType => Dataset.ofRows( - df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.planWithBarrier)) + df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) case RelationalGroupedDataset.CubeType => Dataset.ofRows( - df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.planWithBarrier)) + df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)) case RelationalGroupedDataset.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) Dataset.ofRows( - df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.planWithBarrier)) + df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.logicalPlan)) } } @@ -433,7 +432,7 @@ class RelationalGroupedDataset protected[sql]( df.exprEnc.schema, groupingAttributes, df.logicalPlan.output, - df.planWithBarrier)) + df.logicalPlan)) } /** @@ -459,7 +458,7 @@ class RelationalGroupedDataset protected[sql]( case other => Alias(other, other.toString)() } val groupingAttributes = groupingNamedExpressions.map(_.toAttribute) - val child = df.planWithBarrier + val child = df.logicalPlan val project = Project(groupingNamedExpressions ++ child.output, child) val output = expr.dataType.asInstanceOf[StructType].toAttributes val plan = FlatMapGroupsInPandas(groupingAttributes, expr, output, project) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala index 0223812600961..c5b95fa9b64a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala @@ -74,7 +74,7 @@ class ContinuousAggregationSuite extends ContinuousSuiteBase { val df = input.toDF() .select('value as 'copy, 'value) .where('copy =!= 1) - .planWithBarrier + .logicalPlan .coalesce(1) .where('copy =!= 2) .agg(max('value)) @@ -95,7 +95,7 @@ class ContinuousAggregationSuite extends ContinuousSuiteBase { val df = input.toDF() .coalesce(1) - .planWithBarrier + .logicalPlan .coalesce(1) .select('value as 'copy, 'value) .agg(max('value)) From 2fbe294cf01f78b34498553d9228b57e2f992bce Mon Sep 17 00:00:00 2001 From: mcheah Date: Mon, 30 Jul 2018 15:57:54 -0700 Subject: [PATCH 1252/2461] [SPARK-24963][K8S][TESTS] Add user-specified service account name for client mode test driver pod ## What changes were proposed in this pull request? Adds the user-set service account name for the driver pod in the client mode integration test ## How was this patch tested? Manual test against a custom Kubernetes cluster Author: mcheah Closes #21924 from mccheah/fix-service-account. --- .../spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala | 1 + .../deploy/k8s/integrationtest/KubernetesTestComponents.scala | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala index 6affee2e5141a..159cfd97ff403 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala @@ -61,6 +61,7 @@ trait ClientModeTestsSuite { k8sSuite: KubernetesSuite => .withLabels(labels.asJava) .endMetadata() .withNewSpec() + .withServiceAccountName(kubernetesTestComponents.serviceAccountName) .addNewContainer() .withName("spark-example") .withImage(image) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala index a9b49a8e5a610..b602fdf39731f 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala @@ -32,7 +32,7 @@ private[spark] class KubernetesTestComponents(defaultClient: DefaultKubernetesCl val namespaceOption = Option(System.getProperty("spark.kubernetes.test.namespace")) val hasUserSpecifiedNamespace = namespaceOption.isDefined val namespace = namespaceOption.getOrElse(UUID.randomUUID().toString.replaceAll("-", "")) - private val serviceAccountName = + val serviceAccountName = Option(System.getProperty("spark.kubernetes.test.serviceAccountName")) .getOrElse("default") val kubernetesClient = defaultClient.inNamespace(namespace) From d20c10fdf382acf43a7e6a541923bd078e19ca75 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 31 Jul 2018 09:12:57 +0800 Subject: [PATCH 1253/2461] [SPARK-24952][SQL] Support LZMA2 compression by Avro datasource ## What changes were proposed in this pull request? In the PR, I propose to support `LZMA2` (`XZ`) and `BZIP2` compressions by `AVRO` datasource in write since the codecs may have better characteristics like compression ratio and speed comparing to already supported `snappy` and `deflate` codecs. ## How was this patch tested? It was tested manually and by an existing test which was extended to check the `xz` and `bzip2` compressions. Author: Maxim Gekk Closes #21902 from MaxGekk/avro-xz-bzip2. --- .../spark/sql/avro/AvroFileFormat.scala | 40 +++++++++---------- .../apache/spark/sql/avro/AvroOptions.scala | 2 +- .../org/apache/spark/sql/avro/AvroSuite.scala | 14 ++++++- .../apache/spark/sql/internal/SQLConf.scala | 6 ++- 4 files changed, 36 insertions(+), 26 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index e0159b9320276..7db452bb6b09a 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -23,7 +23,8 @@ import java.net.URI import scala.util.control.NonFatal import org.apache.avro.Schema -import org.apache.avro.file.{DataFileConstants, DataFileReader} +import org.apache.avro.file.DataFileConstants._ +import org.apache.avro.file.DataFileReader import org.apache.avro.generic.{GenericDatumReader, GenericRecord} import org.apache.avro.mapred.{AvroOutputFormat, FsInput} import org.apache.avro.mapreduce.AvroJob @@ -116,27 +117,22 @@ private[avro] class AvroFileFormat extends FileFormat dataSchema, nullable = false, parsedOptions.recordName, parsedOptions.recordNamespace) AvroJob.setOutputKeySchema(job, outputAvroSchema) - val COMPRESS_KEY = "mapred.output.compress" - - parsedOptions.compression match { - case "uncompressed" => - logInfo("writing uncompressed Avro records") - job.getConfiguration.setBoolean(COMPRESS_KEY, false) - - case "snappy" => - logInfo("compressing Avro output using Snappy") - job.getConfiguration.setBoolean(COMPRESS_KEY, true) - job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, DataFileConstants.SNAPPY_CODEC) - - case "deflate" => - val deflateLevel = spark.sessionState.conf.avroDeflateLevel - logInfo(s"compressing Avro output using deflate (level=$deflateLevel)") - job.getConfiguration.setBoolean(COMPRESS_KEY, true) - job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, DataFileConstants.DEFLATE_CODEC) - job.getConfiguration.setInt(AvroOutputFormat.DEFLATE_LEVEL_KEY, deflateLevel) - - case unknown: String => - logError(s"unsupported compression codec $unknown") + + if (parsedOptions.compression == "uncompressed") { + job.getConfiguration.setBoolean("mapred.output.compress", false) + } else { + job.getConfiguration.setBoolean("mapred.output.compress", true) + logInfo(s"Compressing Avro output using the ${parsedOptions.compression} codec") + val codec = parsedOptions.compression match { + case DEFLATE_CODEC => + val deflateLevel = spark.sessionState.conf.avroDeflateLevel + logInfo(s"Avro compression level $deflateLevel will be used for $DEFLATE_CODEC codec.") + job.getConfiguration.setInt(AvroOutputFormat.DEFLATE_LEVEL_KEY, deflateLevel) + DEFLATE_CODEC + case codec @ (SNAPPY_CODEC | BZIP2_CODEC | XZ_CODEC) => codec + case unknown => throw new IllegalArgumentException(s"Invalid compression codec: $unknown") + } + job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, codec) } new AvroOutputWriterFactory(dataSchema, outputAvroSchema.toString) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala index 0f59007e7f72c..67f56343b4524 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala @@ -72,7 +72,7 @@ class AvroOptions( /** * The `compression` option allows to specify a compression codec used in write. - * Currently supported codecs are `uncompressed`, `snappy` and `deflate`. + * Currently supported codecs are `uncompressed`, `snappy`, `deflate`, `bzip2` and `xz`. * If the option is not set, the `spark.sql.avro.compression.codec` config is taken into * account. If the former one is not set too, the `snappy` codec is used by default. */ diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index f59c2cc6ffaaf..c221c4fd07de7 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.avro import java.io._ import java.net.URL -import java.nio.file.{Files, Path, Paths} +import java.nio.file.{Files, Paths} import java.sql.{Date, Timestamp} import java.util.{TimeZone, UUID} @@ -368,12 +368,18 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("write with compression - sql configs") { withTempPath { dir => val uncompressDir = s"$dir/uncompress" + val bzip2Dir = s"$dir/bzip2" + val xzDir = s"$dir/xz" val deflateDir = s"$dir/deflate" val snappyDir = s"$dir/snappy" val df = spark.read.format("avro").load(testAvro) spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "uncompressed") df.write.format("avro").save(uncompressDir) + spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "bzip2") + df.write.format("avro").save(bzip2Dir) + spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "xz") + df.write.format("avro").save(xzDir) spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "deflate") spark.conf.set(SQLConf.AVRO_DEFLATE_LEVEL.key, "9") df.write.format("avro").save(deflateDir) @@ -381,11 +387,15 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { df.write.format("avro").save(snappyDir) val uncompressSize = FileUtils.sizeOfDirectory(new File(uncompressDir)) + val bzip2Size = FileUtils.sizeOfDirectory(new File(bzip2Dir)) + val xzSize = FileUtils.sizeOfDirectory(new File(xzDir)) val deflateSize = FileUtils.sizeOfDirectory(new File(deflateDir)) val snappySize = FileUtils.sizeOfDirectory(new File(snappyDir)) assert(uncompressSize > deflateSize) assert(snappySize > deflateSize) + assert(snappySize > bzip2Size) + assert(bzip2Size > xzSize) } } @@ -921,6 +931,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { checkCodec(df, path, "uncompressed") checkCodec(df, path, "deflate") checkCodec(df, path, "snappy") + checkCodec(df, path, "bzip2") + checkCodec(df, path, "xz") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a269e218c4efd..edc1a488150c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,6 +27,7 @@ import scala.collection.immutable import scala.util.matching.Regex import org.apache.hadoop.fs.Path +import org.tukaani.xz.LZMA2Options import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.internal.Logging @@ -1437,9 +1438,10 @@ object SQLConf { .createWithDefault(20) val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec") - .doc("Compression codec used in writing of AVRO files. Default codec is snappy.") + .doc("Compression codec used in writing of AVRO files. Supported codecs: " + + "uncompressed, deflate, snappy, bzip2 and xz. Default codec is snappy.") .stringConf - .checkValues(Set("uncompressed", "deflate", "snappy")) + .checkValues(Set("uncompressed", "deflate", "snappy", "bzip2", "xz")) .createWithDefault("snappy") val AVRO_DEFLATE_LEVEL = buildConf("spark.sql.avro.deflate.level") From f1550aaf1506c0115c8d95cd8bc784ed6c734ea5 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 31 Jul 2018 09:14:29 +0800 Subject: [PATCH 1254/2461] [SPARK-24956][BUILD][FOLLOWUP] Upgrade Maven version to 3.5.4 for AppVeyor as well ## What changes were proposed in this pull request? Maven version was upgraded and AppVeyor should also use upgraded maven version. Currently, it looks broken by this: https://ci.appveyor.com/project/ApacheSoftwareFoundation/spark/build/2458-master ``` [WARNING] Rule 0: org.apache.maven.plugins.enforcer.RequireMavenVersion failed with message: Detected Maven Version: 3.3.9 is not in the allowed range 3.5.4. [INFO] ------------------------------------------------------------------------ [INFO] Reactor Summary: ``` ## How was this patch tested? AppVeyor tests Author: hyukjinkwon Closes #21920 from HyukjinKwon/SPARK-24956. --- dev/appveyor-install-dependencies.ps1 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/appveyor-install-dependencies.ps1 b/dev/appveyor-install-dependencies.ps1 index e6afb18558852..8a04b621f8ce4 100644 --- a/dev/appveyor-install-dependencies.ps1 +++ b/dev/appveyor-install-dependencies.ps1 @@ -81,7 +81,7 @@ if (!(Test-Path $tools)) { # ========================== Maven Push-Location $tools -$mavenVer = "3.3.9" +$mavenVer = "3.5.4" Start-FileDownload "https://archive.apache.org/dist/maven/maven-3/$mavenVer/binaries/apache-maven-$mavenVer-bin.zip" "maven.zip" # extract From 8141d55926e95c06cd66bf82098895e1ed419449 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 31 Jul 2018 10:10:38 +0800 Subject: [PATCH 1255/2461] [SPARK-23633][SQL] Update Pandas UDFs section in sql-programming-guide ## What changes were proposed in this pull request? Update Pandas UDFs section in sql-programming-guide. Add section for grouped aggregate pandas UDF. ## How was this patch tested? Author: Li Jin Closes #21887 from icexelloss/SPARK-23633-sql-programming-guide. --- docs/sql-programming-guide.md | 19 ++++++++++++++ examples/src/main/python/sql/arrow.py | 37 +++++++++++++++++++++++++++ python/pyspark/sql/functions.py | 5 ++-- 3 files changed, 59 insertions(+), 2 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index cff521c06a242..5f1eee85b5154 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1811,6 +1811,25 @@ The following example shows how to use `groupby().apply()` to subtract the mean For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) and [`pyspark.sql.GroupedData.apply`](api/python/pyspark.sql.html#pyspark.sql.GroupedData.apply). +### Grouped Aggregate + +Grouped aggregate Pandas UDFs are similar to Spark aggregate functions. Grouped aggregate Pandas UDFs are used with `groupBy().agg()` and +[`pyspark.sql.Window`](api/python/pyspark.sql.html#pyspark.sql.Window). It defines an aggregation from one or more `pandas.Series` +to a scalar value, where each `pandas.Series` represents a column within the group or window. + +Note that this type of UDF does not support partial aggregation and all data for a group or window will be loaded into memory. Also, +only unbounded window is supported with Grouped aggregate Pandas UDFs currently. + +The following example shows how to use this type of UDF to compute mean with groupBy and window operations: + +
      +
      +{% include_example grouped_agg_pandas_udf python/sql/arrow.py %} +
      +
      + +For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) + ## Usage Notes ### Supported SQL Types diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py index 4c5aefb6ff4a6..6c4510d9e3c01 100644 --- a/examples/src/main/python/sql/arrow.py +++ b/examples/src/main/python/sql/arrow.py @@ -113,6 +113,43 @@ def substract_mean(pdf): # $example off:grouped_map_pandas_udf$ +def grouped_agg_pandas_udf_example(spark): + # $example on:grouped_agg_pandas_udf$ + from pyspark.sql.functions import pandas_udf, PandasUDFType + from pyspark.sql import Window + + df = spark.createDataFrame( + [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ("id", "v")) + + @pandas_udf("double", PandasUDFType.GROUPED_AGG) + def mean_udf(v): + return v.mean() + + df.groupby("id").agg(mean_udf(df['v'])).show() + # +---+-----------+ + # | id|mean_udf(v)| + # +---+-----------+ + # | 1| 1.5| + # | 2| 6.0| + # +---+-----------+ + + w = Window \ + .partitionBy('id') \ + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() + # +---+----+------+ + # | id| v|mean_v| + # +---+----+------+ + # | 1| 1.0| 1.5| + # | 1| 2.0| 1.5| + # | 2| 3.0| 6.0| + # | 2| 5.0| 6.0| + # | 2|10.0| 6.0| + # +---+----+------+ + # $example off:grouped_agg_pandas_udf$ + + if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 0a88e482787ff..dd7daf946dd41 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2810,8 +2810,9 @@ def pandas_udf(f=None, returnType=None, functionType=None): >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP ... def mean_udf(v): ... return v.mean() - >>> w = Window.partitionBy('id') \\ - ... .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + >>> w = Window \\ + ... .partitionBy('id') \\ + ... .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP +---+----+------+ | id| v|mean_v| From b4fd75fb9b615cfe592ad269cf20d02b483a0d33 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Mon, 30 Jul 2018 23:43:53 -0700 Subject: [PATCH 1256/2461] [SPARK-24972][SQL] PivotFirst could not handle pivot columns of complex types ## What changes were proposed in this pull request? When the pivot column is of a complex type, the eval() result will be an UnsafeRow, while the keys of the HashMap for column value matching is a GenericInternalRow. As a result, there will be no match and the result will always be empty. So for a pivot column of complex-types, we should: 1) If the complex-type is not comparable (orderable), throw an Exception. It cannot be a pivot column. 2) Otherwise, if it goes through the `PivotFirst` code path, `PivotFirst` should use a TreeMap instead of HashMap for such columns. This PR has also reverted the walk-around in Analyzer that had been introduced to avoid this `PivotFirst` issue. ## How was this patch tested? Added UT. Author: maryannxue Closes #21926 from maryannxue/pivot_followup. --- .../sql/catalyst/analysis/Analyzer.scala | 12 +- .../expressions/aggregate/PivotFirst.scala | 11 +- .../test/resources/sql-tests/inputs/pivot.sql | 78 +++++++++++- .../resources/sql-tests/results/pivot.sql.out | 116 ++++++++++++++++-- 4 files changed, 199 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1488ededa38d1..76dc86710909e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -529,6 +529,10 @@ class Analyzer( || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => + if (!RowOrdering.isOrderable(pivotColumn.dataType)) { + throw new AnalysisException( + s"Invalid pivot column '${pivotColumn}'. Pivot columns must be comparable.") + } // Check all aggregate expressions. aggregates.foreach(checkValidAggregateExpression) // Check all pivot values are literal and match pivot column data type. @@ -574,10 +578,14 @@ class Analyzer( // Since evaluating |pivotValues| if statements for each input row can get slow this is an // alternate plan that instead uses two steps of aggregation. val namedAggExps: Seq[NamedExpression] = aggregates.map(a => Alias(a, a.sql)()) - val bigGroup = groupByExprs ++ pivotColumn.references + val namedPivotCol = pivotColumn match { + case n: NamedExpression => n + case _ => Alias(pivotColumn, "__pivot_col")() + } + val bigGroup = groupByExprs :+ namedPivotCol val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child) val pivotAggs = namedAggExps.map { a => - Alias(PivotFirst(pivotColumn, a.toAttribute, evalPivotValues) + Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, evalPivotValues) .toAggregateExpression() , "__pivot_" + a.sql)() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala index 523714869242d..33bc5b5821b36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import scala.collection.immutable.HashMap +import scala.collection.immutable.{HashMap, TreeMap} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ object PivotFirst { @@ -83,7 +83,12 @@ case class PivotFirst( override val dataType: DataType = ArrayType(valueDataType) - val pivotIndex = HashMap(pivotColumnValues.zipWithIndex: _*) + val pivotIndex = if (pivotColumn.dataType.isInstanceOf[AtomicType]) { + HashMap(pivotColumnValues.zipWithIndex: _*) + } else { + TreeMap(pivotColumnValues.zipWithIndex: _*)( + TypeUtils.getInterpretedOrdering(pivotColumn.dataType)) + } val indexSize = pivotIndex.size diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql index a6c8d4854ff38..1f607b334dc18 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql @@ -11,10 +11,10 @@ create temporary view years as select * from values (2013, 2) as years(y, s); -create temporary view yearsWithArray as select * from values - (2012, array(1, 1)), - (2013, array(2, 2)) - as yearsWithArray(y, a); +create temporary view yearsWithComplexTypes as select * from values + (2012, array(1, 1), map('1', 1), struct(1, 'a')), + (2013, array(2, 2), map('2', 2), struct(2, 'b')) + as yearsWithComplexTypes(y, a, m, s); -- pivot courses SELECT * FROM ( @@ -204,7 +204,7 @@ PIVOT ( SELECT * FROM ( SELECT course, year, a FROM courseSales - JOIN yearsWithArray ON year = y + JOIN yearsWithComplexTypes ON year = y ) PIVOT ( min(a) @@ -215,9 +215,75 @@ PIVOT ( SELECT * FROM ( SELECT course, year, y, a FROM courseSales - JOIN yearsWithArray ON year = y + JOIN yearsWithComplexTypes ON year = y ) PIVOT ( max(a) FOR (y, course) IN ((2012, 'dotNET'), (2013, 'Java')) ); + +-- pivot on pivot column of array type +SELECT * FROM ( + SELECT earnings, year, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR a IN (array(1, 1), array(2, 2)) +); + +-- pivot on multiple pivot columns containing array type +SELECT * FROM ( + SELECT course, earnings, year, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, a) IN (('dotNET', array(1, 1)), ('Java', array(2, 2))) +); + +-- pivot on pivot column of struct type +SELECT * FROM ( + SELECT earnings, year, s + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR s IN ((1, 'a'), (2, 'b')) +); + +-- pivot on multiple pivot columns containing struct type +SELECT * FROM ( + SELECT course, earnings, year, s + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, s) IN (('dotNET', (1, 'a')), ('Java', (2, 'b'))) +); + +-- pivot on pivot column of map type +SELECT * FROM ( + SELECT earnings, year, m + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR m IN (map('1', 1), map('2', 2)) +); + +-- pivot on multiple pivot columns containing map type +SELECT * FROM ( + SELECT course, earnings, year, m + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, m) IN (('dotNET', map('1', 1)), ('Java', map('2', 2))) +); diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out index 6bb51b946f960..2dd92930f92aa 100644 --- a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 25 +-- Number of queries: 31 -- !query 0 @@ -28,10 +28,10 @@ struct<> -- !query 2 -create temporary view yearsWithArray as select * from values - (2012, array(1, 1)), - (2013, array(2, 2)) - as yearsWithArray(y, a) +create temporary view yearsWithComplexTypes as select * from values + (2012, array(1, 1), map('1', 1), struct(1, 'a')), + (2013, array(2, 2), map('2', 2), struct(2, 'b')) + as yearsWithComplexTypes(y, a, m, s) -- !query 2 schema struct<> -- !query 2 output @@ -346,7 +346,7 @@ Literal expressions required for pivot values, found 'course#x'; SELECT * FROM ( SELECT course, year, a FROM courseSales - JOIN yearsWithArray ON year = y + JOIN yearsWithComplexTypes ON year = y ) PIVOT ( min(a) @@ -363,7 +363,7 @@ struct,Java:array> SELECT * FROM ( SELECT course, year, y, a FROM courseSales - JOIN yearsWithArray ON year = y + JOIN yearsWithComplexTypes ON year = y ) PIVOT ( max(a) @@ -374,3 +374,105 @@ struct,[2013, Java]:array> -- !query 24 output 2012 [1,1] NULL 2013 NULL [2,2] + + +-- !query 25 +SELECT * FROM ( + SELECT earnings, year, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR a IN (array(1, 1), array(2, 2)) +) +-- !query 25 schema +struct +-- !query 25 output +2012 35000 NULL +2013 NULL 78000 + + +-- !query 26 +SELECT * FROM ( + SELECT course, earnings, year, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, a) IN (('dotNET', array(1, 1)), ('Java', array(2, 2))) +) +-- !query 26 schema +struct +-- !query 26 output +2012 15000 NULL +2013 NULL 30000 + + +-- !query 27 +SELECT * FROM ( + SELECT earnings, year, s + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR s IN ((1, 'a'), (2, 'b')) +) +-- !query 27 schema +struct +-- !query 27 output +2012 35000 NULL +2013 NULL 78000 + + +-- !query 28 +SELECT * FROM ( + SELECT course, earnings, year, s + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, s) IN (('dotNET', (1, 'a')), ('Java', (2, 'b'))) +) +-- !query 28 schema +struct +-- !query 28 output +2012 15000 NULL +2013 NULL 30000 + + +-- !query 29 +SELECT * FROM ( + SELECT earnings, year, m + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR m IN (map('1', 1), map('2', 2)) +) +-- !query 29 schema +struct<> +-- !query 29 output +org.apache.spark.sql.AnalysisException +Invalid pivot column 'm#x'. Pivot columns must be comparable.; + + +-- !query 30 +SELECT * FROM ( + SELECT course, earnings, year, m + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, m) IN (('dotNET', map('1', 1)), ('Java', map('2', 2))) +) +-- !query 30 schema +struct<> +-- !query 30 output +org.apache.spark.sql.AnalysisException +Invalid pivot column 'named_struct(course, course#x, m, m#x)'. Pivot columns must be comparable.; From 4ac2126bc64bad1b4cbe1c697b4bcafacd67c96c Mon Sep 17 00:00:00 2001 From: Mauro Palsgraaf Date: Tue, 31 Jul 2018 08:18:08 -0700 Subject: [PATCH 1257/2461] [SPARK-24536] Validate that an evaluated limit clause cannot be null ## What changes were proposed in this pull request? It proposes a version in which nullable expressions are not valid in the limit clause ## How was this patch tested? It was tested with unit and e2e tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Mauro Palsgraaf Closes #21807 from mauropalsgraaf/SPARK-24536. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 12 +++-- .../analysis/AnalysisErrorSuite.scala | 6 +++ .../test/resources/sql-tests/inputs/limit.sql | 5 +++ .../resources/sql-tests/results/limit.sql.out | 45 +++++++++++++------ 4 files changed, 50 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index f9478a1c3cf4b..4addc83add3e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -68,10 +68,14 @@ trait CheckAnalysis extends PredicateHelper { case e if e.dataType != IntegerType => failAnalysis( s"The limit expression must be integer type, but got " + e.dataType.catalogString) - case e if e.eval().asInstanceOf[Int] < 0 => failAnalysis( - "The limit expression must be equal to or greater than 0, but got " + - e.eval().asInstanceOf[Int]) - case e => // OK + case e => + e.eval() match { + case null => failAnalysis( + s"The evaluated limit expression must not be null, but got ${limitExpr.sql}") + case v: Int if v < 0 => failAnalysis( + s"The limit expression must be equal to or greater than 0, but got $v") + case _ => // OK + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index f4cfed4a91594..ae8d77bbbf9a8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -399,6 +399,12 @@ class AnalysisErrorSuite extends AnalysisTest { "Generators are not supported outside the SELECT clause, but got: Sort" :: Nil ) + errorTest( + "an evaluated limit class must not be null", + testRelation.limit(Literal(null, IntegerType)), + "The evaluated limit expression must not be null, but got " :: Nil + ) + errorTest( "num_rows in limit clause must be equal to or greater than 0", listRelation.limit(-1), diff --git a/sql/core/src/test/resources/sql-tests/inputs/limit.sql b/sql/core/src/test/resources/sql-tests/inputs/limit.sql index f21912a042716..b4c73cf33e53a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/limit.sql @@ -13,6 +13,11 @@ SELECT * FROM testdata LIMIT CAST(1 AS int); SELECT * FROM testdata LIMIT -1; SELECT * FROM testData TABLESAMPLE (-1 ROWS); + +SELECT * FROM testdata LIMIT CAST(1 AS INT); +-- evaluated limit must not be null +SELECT * FROM testdata LIMIT CAST(NULL AS INT); + -- limit must be foldable SELECT * FROM testdata LIMIT key > 3; diff --git a/sql/core/src/test/resources/sql-tests/results/limit.sql.out b/sql/core/src/test/resources/sql-tests/results/limit.sql.out index 146abe6cbd058..02fe1de84f753 100644 --- a/sql/core/src/test/resources/sql-tests/results/limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/limit.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 12 +-- Number of queries: 14 -- !query 0 @@ -66,44 +66,61 @@ The limit expression must be equal to or greater than 0, but got -1; -- !query 7 -SELECT * FROM testdata LIMIT key > 3 +SELECT * FROM testdata LIMIT CAST(1 AS INT) -- !query 7 schema -struct<> +struct -- !query 7 output -org.apache.spark.sql.AnalysisException -The limit expression must evaluate to a constant value, but got (testdata.`key` > 3); +1 1 -- !query 8 -SELECT * FROM testdata LIMIT true +SELECT * FROM testdata LIMIT CAST(NULL AS INT) -- !query 8 schema struct<> -- !query 8 output org.apache.spark.sql.AnalysisException -The limit expression must be integer type, but got boolean; +The evaluated limit expression must not be null, but got CAST(NULL AS INT); -- !query 9 -SELECT * FROM testdata LIMIT 'a' +SELECT * FROM testdata LIMIT key > 3 -- !query 9 schema struct<> -- !query 9 output org.apache.spark.sql.AnalysisException -The limit expression must be integer type, but got string; +The limit expression must evaluate to a constant value, but got (testdata.`key` > 3); -- !query 10 -SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3 +SELECT * FROM testdata LIMIT true -- !query 10 schema -struct +struct<> -- !query 10 output -4 +org.apache.spark.sql.AnalysisException +The limit expression must be integer type, but got boolean; -- !query 11 -SELECT * FROM testdata WHERE key < 3 LIMIT ALL +SELECT * FROM testdata LIMIT 'a' -- !query 11 schema -struct +struct<> -- !query 11 output +org.apache.spark.sql.AnalysisException +The limit expression must be integer type, but got string; + + +-- !query 12 +SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3 +-- !query 12 schema +struct +-- !query 12 output +4 + + +-- !query 13 +SELECT * FROM testdata WHERE key < 3 LIMIT ALL +-- !query 13 schema +struct +-- !query 13 output 1 1 2 2 From 1223a201fcb2c2f211ad96997ebb00c3554aa822 Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Tue, 31 Jul 2018 13:37:13 -0500 Subject: [PATCH 1258/2461] [SPARK-24609][ML][DOC] PySpark/SparkR doc doesn't explain RandomForestClassifier.featureSubsetStrategy well ## What changes were proposed in this pull request? update doc of RandomForestClassifier.featureSubsetStrategy ## How was this patch tested? local built doc rdoc: ![default](https://user-images.githubusercontent.com/7322292/42807787-4dda6362-89e4-11e8-839f-a8519b7c1f1c.png) pydoc: ![default](https://user-images.githubusercontent.com/7322292/43112817-5f1d4d88-8f2a-11e8-93ff-de90db8afdca.png) Author: zhengruifeng Closes #21788 from zhengruifeng/rf_doc_py_r. --- R/pkg/R/mllib_tree.R | 13 ++++++++++++- python/pyspark/ml/regression.py | 9 +++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 6769be038efa9..0e60842dd44c8 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -362,7 +362,18 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' For regression, must be "variance". For classification, must be one of #' "entropy" and "gini", default is "gini". #' @param featureSubsetStrategy The number of features to consider for splits at each tree node. -#' Supported options: "auto", "all", "onethird", "sqrt", "log2", (0.0-1.0], [1-n]. +#' Supported options: "auto" (choose automatically for task: If +#' numTrees == 1, set to "all." If numTrees > 1 +#' (forest), set to "sqrt" for classification and +#' to "onethird" for regression), +#' "all" (use all features), +#' "onethird" (use 1/3 of the features), +#' "sqrt" (use sqrt(number of features)), +#' "log2" (use log2(number of features)), +#' "n": (when n is in the range (0, 1.0], use +#' n * number of features. When n is in the range +#' (1, number of features), use n features). +#' Default is "auto". #' @param seed integer seed for random number generation. #' @param subsamplingRate Fraction of the training data used for learning each decision tree, in #' range (0, 1]. diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 83f0edb397271..564c9f1b8f729 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -608,8 +608,13 @@ class TreeEnsembleParams(DecisionTreeParams): featureSubsetStrategy = \ Param(Params._dummy(), "featureSubsetStrategy", "The number of features to consider for splits at each tree node. Supported " + - "options: " + ", ".join(supportedFeatureSubsetStrategies) + ", (0.0-1.0], [1-n].", - typeConverter=TypeConverters.toString) + "options: 'auto' (choose automatically for task: If numTrees == 1, set to " + + "'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to " + + "'onethird' for regression), 'all' (use all features), 'onethird' (use " + + "1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use " + + "log2(number of features)), 'n' (when n is in the range (0, 1.0], use " + + "n * number of features. When n is in the range (1, number of features), use" + + " n features). default = 'auto'", typeConverter=TypeConverters.toString) def __init__(self): super(TreeEnsembleParams, self).__init__() From e82784d13fac7d45164dfadb00d3fa43e64e0bde Mon Sep 17 00:00:00 2001 From: tedyu Date: Tue, 31 Jul 2018 13:14:14 -0700 Subject: [PATCH 1259/2461] [SPARK-18057][SS] Update Kafka client version from 0.10.0.1 to 2.0.0 ## What changes were proposed in this pull request? This PR upgrades to the Kafka 2.0.0 release where KIP-266 is integrated. ## How was this patch tested? This PR uses existing Kafka related unit tests (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: tedyu Closes #21488 from tedyu/master. --- external/kafka-0-10-sql/pom.xml | 24 +++++++++++-- .../kafka010/KafkaContinuousSourceSuite.scala | 1 + .../kafka010/KafkaMicroBatchSourceSuite.scala | 7 +++- .../spark/sql/kafka010/KafkaTestUtils.scala | 36 ++++++++++++------- 4 files changed, 53 insertions(+), 15 deletions(-) diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 16bbc6db641ca..95500037c1473 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -29,10 +29,10 @@ spark-sql-kafka-0-10_2.11 sql-kafka-0-10 - 0.10.0.1 + 2.0.0 jar - Kafka 0.10 Source for Structured Streaming + Kafka 0.10+ Source for Structured Streaming http://spark.apache.org/ @@ -73,6 +73,20 @@ kafka_${scala.binary.version} ${kafka.version} test + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.core + jackson-annotations + +
      net.sf.jopt-simple @@ -80,6 +94,12 @@ 3.2 test + + org.eclipse.jetty + jetty-servlet + ${jetty.version} + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index aab8ec42189fb..ea2a2a84d22c6 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -42,6 +42,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") + .option("kafka.default.api.timeout.ms", "3000") .option("subscribePattern", s"$topicPrefix-.*") .option("failOnDataLoss", "false") diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 5d5e57323cff5..aa898686c77ca 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -290,6 +290,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") + .option("kafka.default.api.timeout.ms", "3000") .option("subscribePattern", s"$topicPrefix-.*") .option("failOnDataLoss", "false") @@ -467,6 +468,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") + .option("kafka.default.api.timeout.ms", "3000") .option("subscribe", topic) // If a topic is deleted and we try to poll data starting from offset 0, // the Kafka consumer will just block until timeout and return an empty result. @@ -1103,6 +1105,7 @@ class KafkaSourceStressSuite extends KafkaSourceTest { .option("kafka.metadata.max.age.ms", "1") .option("subscribePattern", "stress.*") .option("failOnDataLoss", "false") + .option("kafka.default.api.timeout.ms", "3000") .load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] @@ -1173,7 +1176,8 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared // 30 seconds delay (kafka.log.LogManager.InitialTaskDelayMs) so this test should run at // least 30 seconds. props.put("log.cleaner.backoff.ms", "100") - props.put("log.segment.bytes", "40") + // The size of RecordBatch V2 increases to support transactional write. + props.put("log.segment.bytes", "70") props.put("log.retention.bytes", "40") props.put("log.retention.check.interval.ms", "100") props.put("delete.retention.ms", "10") @@ -1215,6 +1219,7 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") + .option("kafka.default.api.timeout.ms", "3000") .option("subscribePattern", "failOnDataLoss.*") .option("startingOffsets", "earliest") .option("failOnDataLoss", "false") diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index 75245943c4936..82294905c24b9 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -29,12 +29,15 @@ import scala.util.Random import kafka.admin.AdminUtils import kafka.api.Request -import kafka.common.TopicAndPartition -import kafka.server.{KafkaConfig, KafkaServer, OffsetCheckpoint} +import kafka.server.{KafkaConfig, KafkaServer} +import kafka.server.checkpoints.OffsetCheckpointFile import kafka.utils.ZkUtils +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.clients.admin.{AdminClient, CreatePartitionsOptions, NewPartitions} import org.apache.kafka.clients.consumer.KafkaConsumer import org.apache.kafka.clients.producer._ import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.network.ListenerName import org.apache.kafka.common.serialization.{StringDeserializer, StringSerializer} import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} import org.scalatest.concurrent.Eventually._ @@ -61,6 +64,7 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L private var zookeeper: EmbeddedZookeeper = _ private var zkUtils: ZkUtils = _ + private var adminClient: AdminClient = null // Kafka broker related configurations private val brokerHost = "localhost" @@ -113,17 +117,23 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L brokerConf = new KafkaConfig(brokerConfiguration, doLog = false) server = new KafkaServer(brokerConf) server.startup() - brokerPort = server.boundPort() + brokerPort = server.boundPort(new ListenerName("PLAINTEXT")) (server, brokerPort) }, new SparkConf(), "KafkaBroker") brokerReady = true + val props = new Properties() + props.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, s"$brokerHost:$brokerPort") + adminClient = AdminClient.create(props) } /** setup the whole embedded servers, including Zookeeper and Kafka brokers */ def setup(): Unit = { setupEmbeddedZookeeper() setupEmbeddedKafkaServer() + eventually(timeout(60.seconds)) { + assert(zkUtils.getAllBrokersInCluster().nonEmpty, "Broker was not up in 60 seconds") + } } /** Teardown the whole servers, including Kafka broker and Zookeeper */ @@ -203,7 +213,9 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L /** Add new partitions to a Kafka topic */ def addPartitions(topic: String, partitions: Int): Unit = { - AdminUtils.addPartitions(zkUtils, topic, partitions) + adminClient.createPartitions( + Map(topic -> NewPartitions.increaseTo(partitions)).asJava, + new CreatePartitionsOptions) // wait until metadata is propagated (0 until partitions).foreach { p => waitUntilMetadataIsPropagated(topic, p) @@ -296,6 +308,8 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L props.put("replica.socket.timeout.ms", "1500") props.put("delete.topic.enable", "true") props.put("offsets.topic.num.partitions", "1") + props.put("offsets.topic.replication.factor", "1") + props.put("group.initial.rebalance.delay.ms", "10") // Can not use properties.putAll(propsMap.asJava) in scala-2.12 // See https://github.com/scala/bug/issues/10418 withBrokerProps.foreach { case (k, v) => props.put(k, v) } @@ -327,7 +341,7 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L topic: String, numPartitions: Int, servers: Seq[KafkaServer]): Unit = { - val topicAndPartitions = (0 until numPartitions).map(TopicAndPartition(topic, _)) + val topicAndPartitions = (0 until numPartitions).map(new TopicPartition(topic, _)) import ZkUtils._ // wait until admin path for delete topic is deleted, signaling completion of topic deletion @@ -337,7 +351,7 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L assert(!zkUtils.pathExists(getTopicPath(topic)), s"${getTopicPath(topic)} still exists") // ensure that the topic-partition has been deleted from all brokers' replica managers assert(servers.forall(server => topicAndPartitions.forall(tp => - server.replicaManager.getPartition(tp.topic, tp.partition) == None)), + server.replicaManager.getPartition(tp) == None)), s"topic $topic still exists in the replica manager") // ensure that logs from all replicas are deleted if delete topic is marked successful assert(servers.forall(server => topicAndPartitions.forall(tp => @@ -345,8 +359,8 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L s"topic $topic still exists in log mananger") // ensure that topic is removed from all cleaner offsets assert(servers.forall(server => topicAndPartitions.forall { tp => - val checkpoints = server.getLogManager().logDirs.map { logDir => - new OffsetCheckpoint(new File(logDir, "cleaner-offset-checkpoint")).read() + val checkpoints = server.getLogManager().liveLogDirs.map { logDir => + new OffsetCheckpointFile(new File(logDir, "cleaner-offset-checkpoint")).read() } checkpoints.forall(checkpointsPerLogDir => !checkpointsPerLogDir.contains(tp)) }), s"checkpoint for topic $topic still exists") @@ -379,11 +393,9 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match { case Some(partitionState) => - val leaderAndInSyncReplicas = partitionState.leaderIsrAndControllerEpoch.leaderAndIsr - zkUtils.getLeaderForPartition(topic, partition).isDefined && - Request.isValidBrokerId(leaderAndInSyncReplicas.leader) && - leaderAndInSyncReplicas.isr.nonEmpty + Request.isValidBrokerId(partitionState.basePartitionState.leader) && + !partitionState.basePartitionState.replicas.isEmpty case _ => false From 42dfe4f1593767eae355e27bf969339f4ab03f56 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 31 Jul 2018 15:23:11 -0500 Subject: [PATCH 1260/2461] [SPARK-24973][PYTHON] Add numIter to Python ClusteringSummary ## What changes were proposed in this pull request? Add numIter to Python version of ClusteringSummary ## How was this patch tested? Modified existing UT test_multiclass_logistic_regression_summary Author: Huaxin Gao Closes #21925 from huaxingao/spark-24973. --- python/pyspark/ml/clustering.py | 8 ++++++++ python/pyspark/ml/tests.py | 3 +++ 2 files changed, 11 insertions(+) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 8a58d838819e2..ef9822d0ca5a5 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -88,6 +88,14 @@ def clusterSizes(self): """ return self._call_java("clusterSizes") + @property + @since("2.4.0") + def numIter(self): + """ + Number of iterations. + """ + return self._call_java("numIter") + class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index bc782138292bf..3d8883b486e4c 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1888,6 +1888,7 @@ def test_gaussian_mixture_summary(self): self.assertTrue(isinstance(s.cluster, DataFrame)) self.assertEqual(len(s.clusterSizes), 2) self.assertEqual(s.k, 2) + self.assertEqual(s.numIter, 3) def test_bisecting_kmeans_summary(self): data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), @@ -1903,6 +1904,7 @@ def test_bisecting_kmeans_summary(self): self.assertTrue(isinstance(s.cluster, DataFrame)) self.assertEqual(len(s.clusterSizes), 2) self.assertEqual(s.k, 2) + self.assertEqual(s.numIter, 20) def test_kmeans_summary(self): data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), @@ -1918,6 +1920,7 @@ def test_kmeans_summary(self): self.assertTrue(isinstance(s.cluster, DataFrame)) self.assertEqual(len(s.clusterSizes), 2) self.assertEqual(s.k, 2) + self.assertEqual(s.numIter, 1) class KMeansTests(SparkSessionTestCase): From f4772fd26f32b11ae54e7721924b5cf6eb27298a Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 31 Jul 2018 17:24:24 -0700 Subject: [PATCH 1261/2461] [SPARK-24976][PYTHON] Allow None for Decimal type conversion (specific to PyArrow 0.9.0) ## What changes were proposed in this pull request? See [ARROW-2432](https://jira.apache.org/jira/browse/ARROW-2432). Seems using `from_pandas` to convert decimals fails if encounters a value of `None`: ```python import pyarrow as pa import pandas as pd from decimal import Decimal pa.Array.from_pandas(pd.Series([Decimal('3.14'), None]), type=pa.decimal128(3, 2)) ``` **Arrow 0.8.0** ``` [ Decimal('3.14'), NA ] ``` **Arrow 0.9.0** ``` Traceback (most recent call last): File "", line 1, in File "array.pxi", line 383, in pyarrow.lib.Array.from_pandas File "array.pxi", line 177, in pyarrow.lib.array File "error.pxi", line 77, in pyarrow.lib.check_status File "error.pxi", line 77, in pyarrow.lib.check_status pyarrow.lib.ArrowInvalid: Error converting from Python objects to Decimal: Got Python object of type NoneType but can only handle these types: decimal.Decimal ``` This PR propose to work around this via Decimal NaN: ```python pa.Array.from_pandas(pd.Series([Decimal('3.14'), Decimal('NaN')]), type=pa.decimal128(3, 2)) ``` ``` [ Decimal('3.14'), NA ] ``` ## How was this patch tested? Manually tested: ```bash SPARK_TESTING=1 ./bin/pyspark pyspark.sql.tests ScalarPandasUDFTests ``` **Before** ``` Traceback (most recent call last): File "/.../spark/python/pyspark/sql/tests.py", line 4672, in test_vectorized_udf_null_decimal self.assertEquals(df.collect(), res.collect()) File "/.../spark/python/pyspark/sql/dataframe.py", line 533, in collect sock_info = self._jdf.collectToPython() File "/.../spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__ answer, self.gateway_client, self.target_id, self.name) File "/.../spark/python/pyspark/sql/utils.py", line 63, in deco return f(*a, **kw) File "/.../spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py", line 328, in get_return_value format(target_id, ".", name), value) Py4JJavaError: An error occurred while calling o51.collectToPython. : org.apache.spark.SparkException: Job aborted due to stage failure: Task 3 in stage 1.0 failed 1 times, most recent failure: Lost task 3.0 in stage 1.0 (TID 7, localhost, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last): File "/.../spark/python/pyspark/worker.py", line 320, in main process() File "/.../spark/python/pyspark/worker.py", line 315, in process serializer.dump_stream(func(split_index, iterator), outfile) File "/.../spark/python/pyspark/serializers.py", line 274, in dump_stream batch = _create_batch(series, self._timezone) File "/.../spark/python/pyspark/serializers.py", line 243, in _create_batch arrs = [create_array(s, t) for s, t in series] File "/.../spark/python/pyspark/serializers.py", line 241, in create_array return pa.Array.from_pandas(s, mask=mask, type=t) File "array.pxi", line 383, in pyarrow.lib.Array.from_pandas File "array.pxi", line 177, in pyarrow.lib.array File "error.pxi", line 77, in pyarrow.lib.check_status File "error.pxi", line 77, in pyarrow.lib.check_status ArrowInvalid: Error converting from Python objects to Decimal: Got Python object of type NoneType but can only handle these types: decimal.Decimal ``` **After** ``` Running tests... ---------------------------------------------------------------------- Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel). .......S............................. ---------------------------------------------------------------------- Ran 37 tests in 21.980s ``` Author: hyukjinkwon Closes #21928 from HyukjinKwon/SPARK-24976. --- python/pyspark/serializers.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 4c16b5fc26f3d..82abf1947c818 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -216,9 +216,10 @@ def _create_batch(series, timezone): :param timezone: A timezone to respect when handling timestamp values :return: Arrow RecordBatch """ - - from pyspark.sql.types import _check_series_convert_timestamps_internal + import decimal + from distutils.version import LooseVersion import pyarrow as pa + from pyspark.sql.types import _check_series_convert_timestamps_internal # Make input conform to [(series1, type1), (series2, type2), ...] if not isinstance(series, (list, tuple)) or \ (len(series) == 2 and isinstance(series[1], pa.DataType)): @@ -236,6 +237,11 @@ def create_array(s, t): # TODO: need decode before converting to Arrow in Python 2 return pa.Array.from_pandas(s.apply( lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t) + elif t is not None and pa.types.is_decimal(t) and \ + LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + # TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0. + return pa.Array.from_pandas(s.apply( + lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t) return pa.Array.from_pandas(s, mask=mask, type=t) arrs = [create_array(s, t) for s, t in series] From 5f3441e542bfacd81d70bd8b34c22044c8928bff Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 1 Aug 2018 10:31:02 +0800 Subject: [PATCH 1262/2461] [SPARK-24893][SQL] Remove the entire CaseWhen if all the outputs are semantic equivalence ## What changes were proposed in this pull request? Similar to SPARK-24890, if all the outputs of `CaseWhen` are semantic equivalence, `CaseWhen` can be removed. ## How was this patch tested? Tests added. Author: DB Tsai Closes #21852 from dbtsai/short-circuit-when. --- .../sql/catalyst/optimizer/expressions.scala | 18 +++++++ .../optimizer/SimplifyConditionalSuite.scala | 48 ++++++++++++++++++- 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 4696699337c9d..e7b4730e11115 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -416,6 +416,24 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { // these branches can be pruned away val (h, t) = branches.span(_._1 != TrueLiteral) CaseWhen( h :+ t.head, None) + + case e @ CaseWhen(branches, Some(elseValue)) + if branches.forall(_._2.semanticEquals(elseValue)) => + // For non-deterministic conditions with side effect, we can not remove it, or change + // the ordering. As a result, we try to remove the deterministic conditions from the tail. + var hitNonDeterministicCond = false + var i = branches.length + while (i > 0 && !hitNonDeterministicCond) { + hitNonDeterministicCond = !branches(i - 1)._1.deterministic + if (!hitNonDeterministicCond) { + i -= 1 + } + } + if (i == 0) { + elseValue + } else { + e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue))) + } } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index e210874a55d87..8ad7c12020b82 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} @@ -46,7 +45,9 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { private val unreachableBranch = (FalseLiteral, Literal(20)) private val nullBranch = (Literal.create(null, NullType), Literal(30)) - private val testRelation = LocalRelation('a.int) + val isNotNullCond = IsNotNull(UnresolvedAttribute(Seq("a"))) + val isNullCond = IsNull(UnresolvedAttribute("b")) + val notCond = Not(UnresolvedAttribute("c")) test("simplify if") { assertEquivalent( @@ -122,4 +123,47 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { None), CaseWhen(normalBranch :: trueBranch :: Nil, None)) } + + test("simplify CaseWhen if all the outputs are semantic equivalence") { + // When the conditions in `CaseWhen` are all deterministic, `CaseWhen` can be removed. + assertEquivalent( + CaseWhen((isNotNullCond, Subtract(Literal(3), Literal(2))) :: + (isNullCond, Literal(1)) :: + (notCond, Add(Literal(6), Literal(-5))) :: + Nil, + Add(Literal(2), Literal(-1))), + Literal(1) + ) + + // For non-deterministic conditions, we don't remove the `CaseWhen` statement. + assertEquivalent( + CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Subtract(Literal(3), Literal(2))) :: + (LessThan(Rand(1), Literal(0.5)), Literal(1)) :: + (EqualTo(Rand(2), Literal(0.5)), Add(Literal(6), Literal(-5))) :: + Nil, + Add(Literal(2), Literal(-1))), + CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Literal(1)) :: + (LessThan(Rand(1), Literal(0.5)), Literal(1)) :: + (EqualTo(Rand(2), Literal(0.5)), Literal(1)) :: + Nil, + Literal(1)) + ) + + // When we have mixture of deterministic and non-deterministic conditions, we remove + // the deterministic conditions from the tail until a non-deterministic one is seen. + assertEquivalent( + CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Subtract(Literal(3), Literal(2))) :: + (NonFoldableLiteral(true), Add(Literal(2), Literal(-1))) :: + (LessThan(Rand(1), Literal(0.5)), Literal(1)) :: + (NonFoldableLiteral(true), Add(Literal(6), Literal(-5))) :: + (NonFoldableLiteral(false), Literal(1)) :: + Nil, + Add(Literal(2), Literal(-1))), + CaseWhen((GreaterThan(Rand(0), Literal(0.5)), Literal(1)) :: + (NonFoldableLiteral(true), Literal(1)) :: + (LessThan(Rand(1), Literal(0.5)), Literal(1)) :: + Nil, + Literal(1)) + ) + } } From 1f7e22c72c89fc2c0e729dde0948bc6bdf8f7628 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 31 Jul 2018 22:25:40 -0700 Subject: [PATCH 1263/2461] [SPARK-24951][SQL] Table valued functions should throw AnalysisException ## What changes were proposed in this pull request? Previously TVF resolution could throw IllegalArgumentException if the data type is null type. This patch replaces that exception with AnalysisException, enriched with positional information, to improve error message reporting and to be more consistent with rest of Spark SQL. ## How was this patch tested? Updated the test case in table-valued-functions.sql.out, which is how I identified this problem in the first place. Author: Reynold Xin Closes #21934 from rxin/SPARK-24951. --- .../ResolveTableValuedFunctions.scala | 34 ++++++++++++++----- .../results/table-valued-functions.sql.out | 9 +++-- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index 7358f9ee36921..983e4b0e901cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Alias, Expression} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Range} import org.apache.spark.sql.catalyst.rules._ @@ -68,9 +69,11 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { : (ArgumentList, Seq[Any] => LogicalPlan) = { (ArgumentList(args: _*), pf orElse { - case args => - throw new IllegalArgumentException( - "Invalid arguments for resolved function: " + args.mkString(", ")) + case arguments => + // This is caught again by the apply function and rethrow with richer information about + // position, etc, for a better error message. + throw new AnalysisException( + "Invalid arguments for resolved function: " + arguments.mkString(", ")) }) } @@ -105,22 +108,35 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => + // The whole resolution is somewhat difficult to understand here due to too much abstractions. + // We should probably rewrite the following at some point. Reynold was just here to improve + // error messages and didn't have time to do a proper rewrite. val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { case Some(tvf) => + + def failAnalysis(): Nothing = { + val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ") + u.failAnalysis( + s"""error: table-valued function ${u.functionName} with alternatives: + |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")} + |cannot be applied to: ($argTypes)""".stripMargin) + } + val resolved = tvf.flatMap { case (argList, resolver) => argList.implicitCast(u.functionArgs) match { case Some(casted) => - Some(resolver(casted.map(_.eval()))) + try { + Some(resolver(casted.map(_.eval()))) + } catch { + case e: AnalysisException => + failAnalysis() + } case _ => None } } resolved.headOption.getOrElse { - val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ") - u.failAnalysis( - s"""error: table-valued function ${u.functionName} with alternatives: - |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")} - |cannot be applied to: (${argTypes})""".stripMargin) + failAnalysis() } case _ => u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function") diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out index a8bc6faf11262..94af9181225d6 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -83,8 +83,13 @@ select * from range(1, null) -- !query 6 schema struct<> -- !query 6 output -java.lang.IllegalArgumentException -Invalid arguments for resolved function: 1, null +org.apache.spark.sql.AnalysisException +error: table-valued function range with alternatives: + (end: long) + (start: long, end: long) + (start: long, end: long, step: long) + (start: long, end: long, step: long, numPartitions: integer) +cannot be applied to: (integer, null); line 1 pos 14 -- !query 7 From 1efffb7993ecebe5dc1f9ebd924e7503bfd9668c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 Aug 2018 00:15:31 -0700 Subject: [PATCH 1264/2461] [SPARK-24982][SQL] UDAF resolution should not throw AssertionError ## What changes were proposed in this pull request? When user calls anUDAF with the wrong number of arguments, Spark previously throws an AssertionError, which is not supposed to be a user-facing exception. This patch updates it to throw AnalysisException instead, so it is consistent with a regular UDF. ## How was this patch tested? Updated test case udaf.sql. Author: Reynold Xin Closes #21938 from rxin/SPARK-24982. --- .../sql/catalyst/catalog/SessionCatalog.scala | 15 ++++++++++++--- .../test/resources/sql-tests/results/udaf.sql.out | 4 ++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index b09b81eabf60d..2f60eb30f7240 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, ImplicitCastInputTypes} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View} import org.apache.spark.sql.catalyst.util.StringUtils @@ -1124,13 +1124,22 @@ class SessionCatalog( name: String, clazz: Class[_], input: Seq[Expression]): Expression = { + // Unfortunately we need to use reflection here because UserDefinedAggregateFunction + // and ScalaUDAF are defined in sql/core module. val clsForUDAF = Utils.classForName("org.apache.spark.sql.expressions.UserDefinedAggregateFunction") if (clsForUDAF.isAssignableFrom(clazz)) { val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaUDAF") - cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int]) + val e = cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int]) .newInstance(input, clazz.newInstance().asInstanceOf[Object], Int.box(1), Int.box(1)) - .asInstanceOf[Expression] + .asInstanceOf[ImplicitCastInputTypes] + + // Check input argument size + if (e.inputTypes.size != input.size) { + throw new AnalysisException(s"Invalid number of arguments for function $name. " + + s"Expected: ${e.inputTypes.size}; Found: ${input.size}") + } + e } else { throw new AnalysisException(s"No handler for UDAF '${clazz.getCanonicalName}'. " + s"Use sparkSession.udf.register(...) instead.") diff --git a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out index 4815a578b1029..87824ab81cdf7 100644 --- a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out @@ -33,8 +33,8 @@ SELECT default.myDoubleAvg(int_col1, 3) as my_avg from t1 -- !query 3 schema struct<> -- !query 3 output -java.lang.AssertionError -assertion failed: Incorrect number of children +org.apache.spark.sql.AnalysisException +Invalid number of arguments for function default.myDoubleAvg. Expected: 1; Found: 2; line 1 pos 7 -- !query 4 From 1122754bd9c5aa1b434c2b0ad856bc8511cd2ee2 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 1 Aug 2018 15:47:46 +0800 Subject: [PATCH 1265/2461] [SPARK-24653][TESTS] Avoid cross-job pollution in TestUtils / SpillListener. There is a narrow race in this code that is caused when the code being run in assertSpilled / assertNotSpilled runs more than a single job. SpillListener assumed that only a single job was run, and so would only block waiting for that single job to finish when `numSpilledStages` was called. But some tests (like SQL tests that call `checkAnswer`) run more than one job, and so that wait was basically a no-op. This could cause the next test to install a listener to receive events from the previous job. Which could cause test failures in certain cases. The change fixes that race, and also uninstalls listeners after the test runs, so they don't accumulate when the SparkContext is shared among multiple tests. Author: Marcelo Vanzin Closes #21639 from vanzin/SPARK-24653. --- .../scala/org/apache/spark/TestUtils.scala | 51 +++++++++++-------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index b5c4c705dcbc7..6cc8fe1173d2e 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -23,7 +23,7 @@ import java.nio.charset.StandardCharsets import java.security.SecureRandom import java.security.cert.X509Certificate import java.util.{Arrays, Properties} -import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} +import java.util.concurrent.{TimeoutException, TimeUnit} import java.util.jar.{JarEntry, JarOutputStream} import javax.net.ssl._ import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} @@ -172,22 +172,22 @@ private[spark] object TestUtils { /** * Run some code involving jobs submitted to the given context and assert that the jobs spilled. */ - def assertSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = { - val spillListener = new SpillListener - sc.addSparkListener(spillListener) - body - assert(spillListener.numSpilledStages > 0, s"expected $identifier to spill, but did not") + def assertSpilled(sc: SparkContext, identifier: String)(body: => Unit): Unit = { + withListener(sc, new SpillListener) { listener => + body + assert(listener.numSpilledStages > 0, s"expected $identifier to spill, but did not") + } } /** * Run some code involving jobs submitted to the given context and assert that the jobs * did not spill. */ - def assertNotSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = { - val spillListener = new SpillListener - sc.addSparkListener(spillListener) - body - assert(spillListener.numSpilledStages == 0, s"expected $identifier to not spill, but did") + def assertNotSpilled(sc: SparkContext, identifier: String)(body: => Unit): Unit = { + withListener(sc, new SpillListener) { listener => + body + assert(listener.numSpilledStages == 0, s"expected $identifier to not spill, but did") + } } /** @@ -233,6 +233,21 @@ private[spark] object TestUtils { } } + /** + * Runs some code with the given listener installed in the SparkContext. After the code runs, + * this method will wait until all events posted to the listener bus are processed, and then + * remove the listener from the bus. + */ + def withListener[L <: SparkListener](sc: SparkContext, listener: L) (body: L => Unit): Unit = { + sc.addSparkListener(listener) + try { + body(listener) + } finally { + sc.listenerBus.waitUntilEmpty(TimeUnit.SECONDS.toMillis(10)) + sc.listenerBus.removeListener(listener) + } + } + /** * Wait until at least `numExecutors` executors are up, or throw `TimeoutException` if the waiting * time elapsed before `numExecutors` executors up. Exposed for testing. @@ -289,21 +304,17 @@ private[spark] object TestUtils { private class SpillListener extends SparkListener { private val stageIdToTaskMetrics = new mutable.HashMap[Int, ArrayBuffer[TaskMetrics]] private val spilledStageIds = new mutable.HashSet[Int] - private val stagesDone = new CountDownLatch(1) - def numSpilledStages: Int = { - // Long timeout, just in case somehow the job end isn't notified. - // Fails if a timeout occurs - assert(stagesDone.await(10, TimeUnit.SECONDS)) + def numSpilledStages: Int = synchronized { spilledStageIds.size } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { stageIdToTaskMetrics.getOrElseUpdate( taskEnd.stageId, new ArrayBuffer[TaskMetrics]) += taskEnd.taskMetrics } - override def onStageCompleted(stageComplete: SparkListenerStageCompleted): Unit = { + override def onStageCompleted(stageComplete: SparkListenerStageCompleted): Unit = synchronized { val stageId = stageComplete.stageInfo.stageId val metrics = stageIdToTaskMetrics.remove(stageId).toSeq.flatten val spilled = metrics.map(_.memoryBytesSpilled).sum > 0 @@ -311,8 +322,4 @@ private class SpillListener extends SparkListener { spilledStageIds += stageId } } - - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { - stagesDone.countDown() - } } From defc54c69aadc510c6f77e13e57f003646c461bc Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 1 Aug 2018 21:39:35 +0800 Subject: [PATCH 1266/2461] [SPARK-24971][SQL] remove SupportsDeprecatedScanRow ## What changes were proposed in this pull request? This is a follow up of https://github.com/apache/spark/pull/21118 . In https://github.com/apache/spark/pull/21118 we added `SupportsDeprecatedScanRow`. Ideally data source should produce `InternalRow` instead of `Row` for better performance. We should remove `SupportsDeprecatedScanRow` and encourage data sources to produce `InternalRow`, which is also very easy to build. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #21921 from cloud-fan/row. --- .../sources/v2/reader/DataSourceReader.java | 6 +- .../v2/reader/InputPartitionReader.java | 3 +- .../datasources/v2/DataSourceV2ScanExec.scala | 36 +------ .../continuous/ContinuousDataSourceRDD.scala | 4 - .../ContinuousRateStreamSource.scala | 26 ++--- .../sources/ContinuousMemoryStream.scala | 27 +++-- .../sources/RateStreamMicroBatchReader.scala | 21 ++-- .../execution/streaming/sources/socket.scala | 31 +++--- .../sources/v2/JavaAdvancedDataSourceV2.java | 20 ++-- .../v2/JavaPartitionAwareDataSource.java | 18 ++-- .../v2/JavaSchemaRequiredDataSource.java | 7 +- .../sources/v2/JavaSimpleDataSourceV2.java | 19 ++-- .../sources/v2/JavaUnsafeRowDataSourceV2.java | 90 ----------------- .../sources/RateStreamProviderSuite.scala | 15 +-- .../sql/sources/v2/DataSourceV2Suite.scala | 99 +++++-------------- .../sources/v2/SimpleWritableDataSource.scala | 15 ++- .../sources/StreamingDataSourceV2Suite.scala | 10 +- 17 files changed, 133 insertions(+), 314 deletions(-) delete mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java index ad9c838992fa8..4a7462096db16 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java @@ -39,11 +39,7 @@ * pruning), etc. Names of these interfaces start with `SupportsPushDown`. * 2. Information Reporting. E.g., statistics reporting, ordering reporting, etc. * Names of these interfaces start with `SupportsReporting`. - * 3. Special scans. E.g, columnar scan, unsafe row scan, etc. - * Names of these interfaces start with `SupportsScan`. Note that a reader should only - * implement at most one of the special scans, if more than one special scans are implemented, - * only one of them would be respected, according to the priority list from high to low: - * {@link SupportsScanColumnarBatch}, {@link SupportsDeprecatedScanRow}. + * 3. Columnar scan if implements {@link SupportsScanColumnarBatch}. * * If an exception was throw when applying any of these query optimizations, the action will fail * and no Spark job will be submitted. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java index 7cf382e52f67e..f3ff7f5cc0f20 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java @@ -28,8 +28,7 @@ * * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow} * for normal data source readers, {@link org.apache.spark.sql.vectorized.ColumnarBatch} for data - * source readers that mix in {@link SupportsScanColumnarBatch}, or {@link org.apache.spark.sql.Row} - * for data source readers that mix in {@link SupportsDeprecatedScanRow}. + * source readers that mix in {@link SupportsScanColumnarBatch}. */ @InterfaceStability.Evolving public interface InputPartitionReader extends Closeable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index b030b9a929b08..c8494f97f1761 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -20,9 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.JavaConverters._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.plans.physical.SinglePartition @@ -31,7 +29,6 @@ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader -import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch /** @@ -75,13 +72,8 @@ case class DataSourceV2ScanExec( case _ => super.outputPartitioning } - private lazy val partitions: Seq[InputPartition[InternalRow]] = reader match { - case r: SupportsDeprecatedScanRow => - r.planRowInputPartitions().asScala.map { - new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[InternalRow] - } - case _ => - reader.planInputPartitions().asScala + private lazy val partitions: Seq[InputPartition[InternalRow]] = { + reader.planInputPartitions().asScala } private lazy val batchPartitions: Seq[InputPartition[ColumnarBatch]] = reader match { @@ -131,27 +123,3 @@ case class DataSourceV2ScanExec( } } } - -class RowToUnsafeRowInputPartition(partition: InputPartition[Row], schema: StructType) - extends InputPartition[InternalRow] { - - override def preferredLocations: Array[String] = partition.preferredLocations - - override def createPartitionReader: InputPartitionReader[InternalRow] = { - new RowToUnsafeInputPartitionReader( - partition.createPartitionReader, RowEncoder.apply(schema).resolveAndBind()) - } -} - -class RowToUnsafeInputPartitionReader( - val rowReader: InputPartitionReader[Row], - encoder: ExpressionEncoder[Row]) - - extends InputPartitionReader[InternalRow] { - - override def next: Boolean = rowReader.next - - override def get: UnsafeRow = encoder.toRow(rowReader.get).asInstanceOf[UnsafeRow] - - override def close(): Unit = rowReader.close() -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala index 1ffa1d02f1432..554a0b0573f4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -19,9 +19,7 @@ package org.apache.spark.sql.execution.streaming.continuous import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.datasources.v2.RowToUnsafeInputPartitionReader import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader import org.apache.spark.util.NextIterator @@ -104,8 +102,6 @@ object ContinuousDataSourceRDD { reader: InputPartitionReader[InternalRow]): ContinuousInputPartitionReader[_] = { reader match { case r: ContinuousInputPartitionReader[InternalRow] => r - case wrapped: RowToUnsafeInputPartitionReader => - wrapped.rowReader.asInstanceOf[ContinuousInputPartitionReader[Row]] case _ => throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 55ce3ae38ee3b..551e07c3db868 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import org.json4s.DefaultFormats import org.json4s.jackson.Serialization -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider @@ -34,8 +34,7 @@ import org.apache.spark.sql.types.StructType case class RateStreamPartitionOffset( partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset -class RateStreamContinuousReader(options: DataSourceOptions) - extends ContinuousReader with SupportsDeprecatedScanRow { +class RateStreamContinuousReader(options: DataSourceOptions) extends ContinuousReader { implicit val defaultFormats: DefaultFormats = DefaultFormats val creationTime = System.currentTimeMillis() @@ -67,7 +66,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) override def getStartOffset(): Offset = offset - override def planRowInputPartitions(): java.util.List[InputPartition[Row]] = { + override def planInputPartitions(): java.util.List[InputPartition[InternalRow]] = { val partitionStartMap = offset match { case off: RateStreamOffset => off.partitionToValueAndRunTimeMs case off => @@ -91,7 +90,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) i, numPartitions, perPartitionRate) - .asInstanceOf[InputPartition[Row]] + .asInstanceOf[InputPartition[InternalRow]] }.asJava } @@ -119,9 +118,10 @@ case class RateStreamContinuousInputPartition( partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousInputPartition[Row] { + extends ContinuousInputPartition[InternalRow] { - override def createContinuousReader(offset: PartitionOffset): InputPartitionReader[Row] = { + override def createContinuousReader( + offset: PartitionOffset): InputPartitionReader[InternalRow] = { val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset] require(rateStreamOffset.partition == partitionIndex, s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}") @@ -133,7 +133,7 @@ case class RateStreamContinuousInputPartition( rowsPerSecond) } - override def createPartitionReader(): InputPartitionReader[Row] = + override def createPartitionReader(): InputPartitionReader[InternalRow] = new RateStreamContinuousInputPartitionReader( startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) } @@ -144,12 +144,12 @@ class RateStreamContinuousInputPartitionReader( partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousInputPartitionReader[Row] { + extends ContinuousInputPartitionReader[InternalRow] { private var nextReadTime: Long = startTimeMs private val readTimeIncrement: Long = (1000 / rowsPerSecond).toLong private var currentValue = startValue - private var currentRow: Row = null + private var currentRow: InternalRow = null override def next(): Boolean = { currentValue += increment @@ -165,14 +165,14 @@ class RateStreamContinuousInputPartitionReader( return false } - currentRow = Row( - DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(nextReadTime)), + currentRow = InternalRow( + DateTimeUtils.fromMillis(nextReadTime), currentValue) true } - override def get: Row = currentRow + override def get: InternalRow = currentRow override def close(): Unit = {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index e776ebc08e30d..711f0941fe731 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -23,7 +23,6 @@ import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ -import scala.collection.SortedMap import scala.collection.mutable.ListBuffer import org.json4s.NoTypeHints @@ -31,11 +30,12 @@ import org.json4s.jackson.Serialization import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} -import org.apache.spark.sql.{Encoder, Row, SQLContext} +import org.apache.spark.sql.{Encoder, SQLContext} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions} -import org.apache.spark.sql.sources.v2.reader.{InputPartition, SupportsDeprecatedScanRow} +import org.apache.spark.sql.sources.v2.reader.InputPartition import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.RpcUtils @@ -49,8 +49,7 @@ import org.apache.spark.util.RpcUtils * the specified offset within the list, or null if that offset doesn't yet have a record. */ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2) - extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport - with SupportsDeprecatedScanRow { + extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { private implicit val formats = Serialization.formats(NoTypeHints) protected val logicalPlan = @@ -100,7 +99,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa ) } - override def planRowInputPartitions(): ju.List[InputPartition[Row]] = { + override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { synchronized { val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id" endpointRef = @@ -109,7 +108,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa startOffset.partitionNums.map { case (part, index) => new ContinuousMemoryStreamInputPartition( - endpointName, part, index): InputPartition[Row] + endpointName, part, index): InputPartition[InternalRow] }.toList.asJava } } @@ -141,7 +140,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa val buf = records(part) val record = if (buf.size <= index) None else Some(buf(index)) - context.reply(record.map(Row(_))) + context.reply(record.map(r => encoder.toRow(r).copy())) } } } @@ -164,7 +163,7 @@ object ContinuousMemoryStream { class ContinuousMemoryStreamInputPartition( driverEndpointName: String, partition: Int, - startOffset: Int) extends InputPartition[Row] { + startOffset: Int) extends InputPartition[InternalRow] { override def createPartitionReader: ContinuousMemoryStreamInputPartitionReader = new ContinuousMemoryStreamInputPartitionReader(driverEndpointName, partition, startOffset) } @@ -177,14 +176,14 @@ class ContinuousMemoryStreamInputPartition( class ContinuousMemoryStreamInputPartitionReader( driverEndpointName: String, partition: Int, - startOffset: Int) extends ContinuousInputPartitionReader[Row] { + startOffset: Int) extends ContinuousInputPartitionReader[InternalRow] { private val endpoint = RpcUtils.makeDriverRef( driverEndpointName, SparkEnv.get.conf, SparkEnv.get.rpcEnv) private var currentOffset = startOffset - private var current: Option[Row] = None + private var current: Option[InternalRow] = None // Defense-in-depth against failing to propagate the task context. Since it's not inheritable, // we have to do a bit of error prone work to get it into every thread used by continuous @@ -204,15 +203,15 @@ class ContinuousMemoryStreamInputPartitionReader( true } - override def get(): Row = current.get + override def get(): InternalRow = current.get override def close(): Unit = {} override def getOffset: ContinuousMemoryStreamPartitionOffset = ContinuousMemoryStreamPartitionOffset(partition, currentOffset) - private def getRecord: Option[Row] = - endpoint.askSync[Option[Row]]( + private def getRecord: Option[InternalRow] = + endpoint.askSync[Option[InternalRow]]( GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index 7a3452aa315cf..9e0d954932163 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -29,6 +29,7 @@ import org.apache.commons.io.IOUtils import org.apache.spark.internal.Logging import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources.v2.DataSourceOptions @@ -38,7 +39,7 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.util.{ManualClock, SystemClock} class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) - extends MicroBatchReader with SupportsDeprecatedScanRow with Logging { + extends MicroBatchReader with Logging { import RateStreamProvider._ private[sources] val clock = { @@ -134,7 +135,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: LongOffset(json.toLong) } - override def planRowInputPartitions(): java.util.List[InputPartition[Row]] = { + override def planInputPartitions(): java.util.List[InputPartition[InternalRow]] = { val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") @@ -169,7 +170,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: (0 until numPartitions).map { p => new RateStreamMicroBatchInputPartition( p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) - : InputPartition[Row] + : InputPartition[InternalRow] }.toList.asJava } @@ -188,9 +189,9 @@ class RateStreamMicroBatchInputPartition( rangeStart: Long, rangeEnd: Long, localStartTimeMs: Long, - relativeMsPerValue: Double) extends InputPartition[Row] { + relativeMsPerValue: Double) extends InputPartition[InternalRow] { - override def createPartitionReader(): InputPartitionReader[Row] = + override def createPartitionReader(): InputPartitionReader[InternalRow] = new RateStreamMicroBatchInputPartitionReader(partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) } @@ -201,22 +202,18 @@ class RateStreamMicroBatchInputPartitionReader( rangeStart: Long, rangeEnd: Long, localStartTimeMs: Long, - relativeMsPerValue: Double) extends InputPartitionReader[Row] { + relativeMsPerValue: Double) extends InputPartitionReader[InternalRow] { private var count: Long = 0 override def next(): Boolean = { rangeStart + partitionId + numPartitions * count < rangeEnd } - override def get(): Row = { + override def get(): InternalRow = { val currValue = rangeStart + partitionId + numPartitions * count count += 1 val relative = math.round((currValue - rangeStart) * relativeMsPerValue) - Row( - DateTimeUtils.toJavaTimestamp( - DateTimeUtils.fromMillis(relative + localStartTimeMs)), - currValue - ) + InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), currValue) } override def close(): Unit = {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index e3a2c007a9ce4..9f53a1849b33d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.streaming.sources import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket -import java.sql.Timestamp import java.text.SimpleDateFormat import java.util.{Calendar, List => JList, Locale, Optional} import java.util.concurrent.atomic.AtomicBoolean @@ -31,12 +30,15 @@ import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.LongOffset import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsDeprecatedScanRow} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} +import org.apache.spark.unsafe.types.UTF8String object TextSocketMicroBatchReader { val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) @@ -50,8 +52,7 @@ object TextSocketMicroBatchReader { * debugging. This MicroBatchReader will *not* work in production applications due to multiple * reasons, including no support for fault recovery. */ -class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader - with SupportsDeprecatedScanRow with Logging { +class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader with Logging { private var startOffset: Offset = _ private var endOffset: Offset = _ @@ -70,7 +71,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR * Stored in a ListBuffer to facilitate removing committed batches. */ @GuardedBy("this") - private val batches = new ListBuffer[(String, Timestamp)] + private val batches = new ListBuffer[(UTF8String, Long)] @GuardedBy("this") private var currentOffset: LongOffset = LongOffset(-1L) @@ -101,9 +102,9 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR return } TextSocketMicroBatchReader.this.synchronized { - val newData = (line, - Timestamp.valueOf( - TextSocketMicroBatchReader.DATE_FORMAT.format(Calendar.getInstance().getTime())) + val newData = ( + UTF8String.fromString(line), + DateTimeUtils.fromMillis(Calendar.getInstance().getTimeInMillis) ) currentOffset += 1 batches.append(newData) @@ -142,7 +143,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR } } - override def planRowInputPartitions(): JList[InputPartition[Row]] = { + override def planInputPartitions(): JList[InputPartition[InternalRow]] = { assert(startOffset != null && endOffset != null, "start offset and end offset should already be set before create read tasks.") @@ -164,16 +165,16 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR val spark = SparkSession.getActiveSession.get val numPartitions = spark.sparkContext.defaultParallelism - val slices = Array.fill(numPartitions)(new ListBuffer[(String, Timestamp)]) + val slices = Array.fill(numPartitions)(new ListBuffer[(UTF8String, Long)]) rawList.zipWithIndex.foreach { case (r, idx) => slices(idx % numPartitions).append(r) } (0 until numPartitions).map { i => val slice = slices(i) - new InputPartition[Row] { - override def createPartitionReader(): InputPartitionReader[Row] = - new InputPartitionReader[Row] { + new InputPartition[InternalRow] { + override def createPartitionReader(): InputPartitionReader[InternalRow] = + new InputPartitionReader[InternalRow] { private var currentIdx = -1 override def next(): Boolean = { @@ -181,8 +182,8 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR currentIdx < slice.size } - override def get(): Row = { - Row(slice(currentIdx)._1, slice(currentIdx)._2) + override def get(): InternalRow = { + InternalRow(slice(currentIdx)._1, slice(currentIdx)._2) } override def close(): Unit = {} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index c130b5f1e2513..e4cead9df429c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -20,8 +20,8 @@ import java.io.IOException; import java.util.*; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.sources.GreaterThan; import org.apache.spark.sql.sources.v2.DataSourceOptions; @@ -33,7 +33,7 @@ public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport { public class Reader implements DataSourceReader, SupportsPushDownRequiredColumns, - SupportsPushDownFilters, SupportsDeprecatedScanRow { + SupportsPushDownFilters { // Exposed for testing. public StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); @@ -79,8 +79,8 @@ public Filter[] pushedFilters() { } @Override - public List> planRowInputPartitions() { - List> res = new ArrayList<>(); + public List> planInputPartitions() { + List> res = new ArrayList<>(); Integer lowerBound = null; for (Filter filter : filters) { @@ -107,8 +107,8 @@ public List> planRowInputPartitions() { } } - static class JavaAdvancedInputPartition implements InputPartition, - InputPartitionReader { + static class JavaAdvancedInputPartition implements InputPartition, + InputPartitionReader { private int start; private int end; private StructType requiredSchema; @@ -120,7 +120,7 @@ static class JavaAdvancedInputPartition implements InputPartition, } @Override - public InputPartitionReader createPartitionReader() { + public InputPartitionReader createPartitionReader() { return new JavaAdvancedInputPartition(start - 1, end, requiredSchema); } @@ -131,7 +131,7 @@ public boolean next() { } @Override - public Row get() { + public InternalRow get() { Object[] values = new Object[requiredSchema.size()]; for (int i = 0; i < values.length; i++) { if ("i".equals(requiredSchema.apply(i).name())) { @@ -140,7 +140,7 @@ public Row get() { values[i] = -start; } } - return new GenericRow(values); + return new GenericInternalRow(values); } @Override diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index 35aafb532d80d..2d21324f5ece3 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -21,8 +21,8 @@ import java.util.Arrays; import java.util.List; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.ReadSupport; @@ -34,7 +34,7 @@ public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { - class Reader implements DataSourceReader, SupportsReportPartitioning, SupportsDeprecatedScanRow { + class Reader implements DataSourceReader, SupportsReportPartitioning { private final StructType schema = new StructType().add("a", "int").add("b", "int"); @Override @@ -43,7 +43,7 @@ public StructType readSchema() { } @Override - public List> planRowInputPartitions() { + public List> planInputPartitions() { return java.util.Arrays.asList( new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}), new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2})); @@ -73,7 +73,9 @@ public boolean satisfy(Distribution distribution) { } } - static class SpecificInputPartition implements InputPartition, InputPartitionReader { + static class SpecificInputPartition implements InputPartition, + InputPartitionReader { + private int[] i; private int[] j; private int current = -1; @@ -91,8 +93,8 @@ public boolean next() throws IOException { } @Override - public Row get() { - return new GenericRow(new Object[] {i[current], j[current]}); + public InternalRow get() { + return new GenericInternalRow(new Object[] {i[current], j[current]}); } @Override @@ -101,7 +103,7 @@ public void close() throws IOException { } @Override - public InputPartitionReader createPartitionReader() { + public InputPartitionReader createPartitionReader() { return this; } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index 6dee94c34e21c..ca5abd24abe8f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -19,18 +19,17 @@ import java.util.List; -import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.sources.v2.reader.InputPartition; -import org.apache.spark.sql.sources.v2.reader.SupportsDeprecatedScanRow; import org.apache.spark.sql.types.StructType; public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema { - class Reader implements DataSourceReader, SupportsDeprecatedScanRow { + class Reader implements DataSourceReader { private final StructType schema; Reader(StructType schema) { @@ -43,7 +42,7 @@ public StructType readSchema() { } @Override - public List> planRowInputPartitions() { + public List> planInputPartitions() { return java.util.Collections.emptyList(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 5c2f351975c74..274dc3745bcf9 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -20,20 +20,19 @@ import java.io.IOException; import java.util.List; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import org.apache.spark.sql.sources.v2.reader.SupportsDeprecatedScanRow; import org.apache.spark.sql.types.StructType; public class JavaSimpleDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceReader, SupportsDeprecatedScanRow { + class Reader implements DataSourceReader { private final StructType schema = new StructType().add("i", "int").add("j", "int"); @Override @@ -42,14 +41,16 @@ public StructType readSchema() { } @Override - public List> planRowInputPartitions() { + public List> planInputPartitions() { return java.util.Arrays.asList( new JavaSimpleInputPartition(0, 5), new JavaSimpleInputPartition(5, 10)); } } - static class JavaSimpleInputPartition implements InputPartition, InputPartitionReader { + static class JavaSimpleInputPartition implements InputPartition, + InputPartitionReader { + private int start; private int end; @@ -59,7 +60,7 @@ static class JavaSimpleInputPartition implements InputPartition, InputParti } @Override - public InputPartitionReader createPartitionReader() { + public InputPartitionReader createPartitionReader() { return new JavaSimpleInputPartition(start - 1, end); } @@ -70,8 +71,8 @@ public boolean next() { } @Override - public Row get() { - return new GenericRow(new Object[] {start, -start}); + public InternalRow get() { + return new GenericInternalRow(new Object[] {start, -start}); } @Override diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java deleted file mode 100644 index 25b89c7fd36a9..0000000000000 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package test.org.apache.spark.sql.sources.v2; - -import java.io.IOException; -import java.util.List; - -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.reader.*; -import org.apache.spark.sql.types.StructType; - -public class JavaUnsafeRowDataSourceV2 implements DataSourceV2, ReadSupport { - - class Reader implements DataSourceReader { - private final StructType schema = new StructType().add("i", "int").add("j", "int"); - - @Override - public StructType readSchema() { - return schema; - } - - @Override - public List> planInputPartitions() { - return java.util.Arrays.asList( - new JavaUnsafeRowInputPartition(0, 5), - new JavaUnsafeRowInputPartition(5, 10)); - } - } - - static class JavaUnsafeRowInputPartition - implements InputPartition, InputPartitionReader { - private int start; - private int end; - private UnsafeRow row; - - JavaUnsafeRowInputPartition(int start, int end) { - this.start = start; - this.end = end; - this.row = new UnsafeRow(2); - row.pointTo(new byte[8 * 3], 8 * 3); - } - - @Override - public InputPartitionReader createPartitionReader() { - return new JavaUnsafeRowInputPartition(start - 1, end); - } - - @Override - public boolean next() { - start += 1; - return start < end; - } - - @Override - public UnsafeRow get() { - row.setInt(0, start); - row.setInt(1, -start); - return row; - } - - @Override - public void close() throws IOException { - - } - } - - @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index 260a0376daeb3..7e53da1f312cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -23,7 +23,8 @@ import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ @@ -146,10 +147,10 @@ class RateSourceSuite extends StreamTest { val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.planRowInputPartitions() + val tasks = reader.planInputPartitions() assert(tasks.size == 1) val dataReader = tasks.get(0).createPartitionReader() - val data = ArrayBuffer[Row]() + val data = ArrayBuffer[InternalRow]() while (dataReader.next()) { data.append(dataReader.get()) } @@ -165,13 +166,13 @@ class RateSourceSuite extends StreamTest { val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.planRowInputPartitions() + val tasks = reader.planInputPartitions() assert(tasks.size == 11) val readData = tasks.asScala .map(_.createPartitionReader()) .flatMap { reader => - val buf = scala.collection.mutable.ListBuffer[Row]() + val buf = scala.collection.mutable.ListBuffer[InternalRow]() while (reader.next()) buf.append(reader.get()) buf } @@ -311,10 +312,10 @@ class RateSourceSuite extends StreamTest { val reader = new RateStreamContinuousReader( new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) reader.setStartOffset(Optional.empty()) - val tasks = reader.planRowInputPartitions() + val tasks = reader.planInputPartitions() assert(tasks.size == 2) - val data = scala.collection.mutable.ListBuffer[Row]() + val data = scala.collection.mutable.ListBuffer[InternalRow]() tasks.asScala.foreach { case t: RateStreamContinuousInputPartition => val startTimeMs = reader.getStartOffset() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index d73eebbc84b71..c7da137219894 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -121,17 +121,6 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } - test("unsafe row scan implementation") { - Seq(classOf[UnsafeRowDataSourceV2], classOf[JavaUnsafeRowDataSourceV2]).foreach { cls => - withClue(cls.getName) { - val df = spark.read.format(cls.getName).load() - checkAnswer(df, (0 until 10).map(i => Row(i, -i))) - checkAnswer(df.select('j), (0 until 10).map(i => Row(-i))) - checkAnswer(df.filter('i > 5), (6 until 10).map(i => Row(i, -i))) - } - } - } - test("columnar batch scan implementation") { Seq(classOf[BatchDataSourceV2], classOf[JavaBatchDataSourceV2]).foreach { cls => withClue(cls.getName) { @@ -345,10 +334,10 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader with SupportsDeprecatedScanRow { + class Reader extends DataSourceReader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def planRowInputPartitions(): JList[InputPartition[Row]] = { + override def planInputPartitions(): JList[InputPartition[InternalRow]] = { java.util.Arrays.asList(new SimpleInputPartition(0, 5)) } } @@ -358,10 +347,10 @@ class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader with SupportsDeprecatedScanRow { + class Reader extends DataSourceReader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def planRowInputPartitions(): JList[InputPartition[Row]] = { + override def planInputPartitions(): JList[InputPartition[InternalRow]] = { java.util.Arrays.asList(new SimpleInputPartition(0, 5), new SimpleInputPartition(5, 10)) } } @@ -370,11 +359,11 @@ class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { } class SimpleInputPartition(start: Int, end: Int) - extends InputPartition[Row] - with InputPartitionReader[Row] { + extends InputPartition[InternalRow] + with InputPartitionReader[InternalRow] { private var current = start - 1 - override def createPartitionReader(): InputPartitionReader[Row] = + override def createPartitionReader(): InputPartitionReader[InternalRow] = new SimpleInputPartition(start, end) override def next(): Boolean = { @@ -382,7 +371,7 @@ class SimpleInputPartition(start: Int, end: Int) current < end } - override def get(): Row = Row(current, -current) + override def get(): InternalRow = InternalRow(current, -current) override def close(): Unit = {} } @@ -391,7 +380,7 @@ class SimpleInputPartition(start: Int, end: Int) class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader with SupportsDeprecatedScanRow + class Reader extends DataSourceReader with SupportsPushDownRequiredColumns with SupportsPushDownFilters { var requiredSchema = new StructType().add("i", "int").add("j", "int") @@ -416,12 +405,12 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { requiredSchema } - override def planRowInputPartitions(): JList[InputPartition[Row]] = { - val lowerBound = filters.collect { + override def planInputPartitions(): JList[InputPartition[InternalRow]] = { + val lowerBound = filters.collectFirst { case GreaterThan("i", v: Int) => v - }.headOption + } - val res = new ArrayList[InputPartition[Row]] + val res = new ArrayList[InputPartition[InternalRow]] if (lowerBound.isEmpty) { res.add(new AdvancedInputPartition(0, 5, requiredSchema)) @@ -441,11 +430,11 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { } class AdvancedInputPartition(start: Int, end: Int, requiredSchema: StructType) - extends InputPartition[Row] with InputPartitionReader[Row] { + extends InputPartition[InternalRow] with InputPartitionReader[InternalRow] { private var current = start - 1 - override def createPartitionReader(): InputPartitionReader[Row] = { + override def createPartitionReader(): InputPartitionReader[InternalRow] = { new AdvancedInputPartition(start, end, requiredSchema) } @@ -456,57 +445,20 @@ class AdvancedInputPartition(start: Int, end: Int, requiredSchema: StructType) current < end } - override def get(): Row = { + override def get(): InternalRow = { val values = requiredSchema.map(_.name).map { case "i" => current case "j" => -current } - Row.fromSeq(values) + InternalRow.fromSeq(values) } } -class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { - - class Reader extends DataSourceReader { - override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - - override def planInputPartitions(): JList[InputPartition[InternalRow]] = { - java.util.Arrays.asList(new UnsafeRowInputPartitionReader(0, 5), - new UnsafeRowInputPartitionReader(5, 10)) - } - } - - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader -} - -class UnsafeRowInputPartitionReader(start: Int, end: Int) - extends InputPartition[InternalRow] with InputPartitionReader[InternalRow] { - - private val row = new UnsafeRow(2) - row.pointTo(new Array[Byte](8 * 3), 8 * 3) - - private var current = start - 1 - - override def createPartitionReader(): InputPartitionReader[InternalRow] = this - - override def next(): Boolean = { - current += 1 - current < end - } - override def get(): UnsafeRow = { - row.setInt(0, current) - row.setInt(1, -current) - row - } - - override def close(): Unit = {} -} - class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { - class Reader(val readSchema: StructType) extends DataSourceReader with SupportsDeprecatedScanRow { - override def planRowInputPartitions(): JList[InputPartition[Row]] = + class Reader(val readSchema: StructType) extends DataSourceReader { + override def planInputPartitions(): JList[InputPartition[InternalRow]] = java.util.Collections.emptyList() } @@ -569,11 +521,10 @@ class BatchInputPartitionReader(start: Int, end: Int) class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader with SupportsReportPartitioning - with SupportsDeprecatedScanRow { + class Reader extends DataSourceReader with SupportsReportPartitioning { override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") - override def planRowInputPartitions(): JList[InputPartition[Row]] = { + override def planInputPartitions(): JList[InputPartition[InternalRow]] = { // Note that we don't have same value of column `a` across partitions. java.util.Arrays.asList( new SpecificInputPartitionReader(Array(1, 1, 3), Array(4, 4, 6)), @@ -596,20 +547,20 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { } class SpecificInputPartitionReader(i: Array[Int], j: Array[Int]) - extends InputPartition[Row] - with InputPartitionReader[Row] { + extends InputPartition[InternalRow] + with InputPartitionReader[InternalRow] { assert(i.length == j.length) private var current = -1 - override def createPartitionReader(): InputPartitionReader[Row] = this + override def createPartitionReader(): InputPartitionReader[InternalRow] = this override def next(): Boolean = { current += 1 current < i.length } - override def get(): Row = Row(i(current), j(current)) + override def get(): InternalRow = InternalRow(i(current), j(current)) override def close(): Unit = {} } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 98d7eedbcb9c6..183d0399d3bcd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader, SupportsDeprecatedScanRow} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -42,11 +42,10 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS private val schema = new StructType().add("i", "long").add("j", "long") - class Reader(path: String, conf: Configuration) extends DataSourceReader - with SupportsDeprecatedScanRow { + class Reader(path: String, conf: Configuration) extends DataSourceReader { override def readSchema(): StructType = schema - override def planRowInputPartitions(): JList[InputPartition[Row]] = { + override def planInputPartitions(): JList[InputPartition[InternalRow]] = { val dataPath = new Path(path) val fs = dataPath.getFileSystem(conf) if (fs.exists(dataPath)) { @@ -57,7 +56,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS val serializableConf = new SerializableConfiguration(conf) new SimpleCSVInputPartitionReader( f.getPath.toUri.toString, - serializableConf): InputPartition[Row] + serializableConf): InputPartition[InternalRow] }.toList.asJava } else { Collections.emptyList() @@ -158,13 +157,13 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } class SimpleCSVInputPartitionReader(path: String, conf: SerializableConfiguration) - extends InputPartition[Row] with InputPartitionReader[Row] { + extends InputPartition[InternalRow] with InputPartitionReader[InternalRow] { @transient private var lines: Iterator[String] = _ @transient private var currentLine: String = _ @transient private var inputStream: FSDataInputStream = _ - override def createPartitionReader(): InputPartitionReader[Row] = { + override def createPartitionReader(): InputPartitionReader[InternalRow] = { val filePath = new Path(path) val fs = filePath.getFileSystem(conf.value) inputStream = fs.open(filePath) @@ -182,7 +181,7 @@ class SimpleCSVInputPartitionReader(path: String, conf: SerializableConfiguratio } } - override def get(): Row = Row(currentLine.split(",").map(_.trim.toLong): _*) + override def get(): InternalRow = InternalRow(currentLine.split(",").map(_.trim.toLong): _*) override def close(): Unit = { inputStream.close() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 7c012158bd751..52b833a19c236 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -19,22 +19,22 @@ package org.apache.spark.sql.streaming.sources import java.util.Optional -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, StreamingQueryWrapper} import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.{InputPartition, SupportsDeprecatedScanRow} +import org.apache.spark.sql.sources.v2.reader.InputPartition import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils -case class FakeReader() extends MicroBatchReader with ContinuousReader - with SupportsDeprecatedScanRow { +case class FakeReader() extends MicroBatchReader with ContinuousReader { def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = {} def getStartOffset: Offset = RateStreamOffset(Map()) def getEndOffset: Offset = RateStreamOffset(Map()) @@ -45,7 +45,7 @@ case class FakeReader() extends MicroBatchReader with ContinuousReader def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) def setStartOffset(start: Optional[Offset]): Unit = {} - def planRowInputPartitions(): java.util.ArrayList[InputPartition[Row]] = { + def planInputPartitions(): java.util.ArrayList[InputPartition[InternalRow]] = { throw new IllegalStateException("fake source - cannot actually read") } } From 95a9d5e3a5ad22c2126fee0ffc7fc789edd18a59 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 2 Aug 2018 02:52:30 +0800 Subject: [PATCH 1267/2461] [SPARK-23915][SQL] Add array_except function ## What changes were proposed in this pull request? The PR adds the SQL function `array_except`. The behavior of the function is based on Presto's one. This function returns returns an array of the elements in array1 but not in array2. Note: The order of elements in the result is not defined. ## How was this patch tested? Added UTs. Author: Kazuaki Ishizaki Closes #21103 from kiszk/SPARK-23915. --- .../spark/util/collection/OpenHashSet.scala | 25 +- .../util/collection/OpenHashSetSuite.scala | 74 +++++ python/pyspark/sql/functions.py | 19 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/expressions/Expression.scala | 6 +- .../expressions/collectionOperations.scala | 302 +++++++++++++++++- .../CollectionExpressionsSuite.scala | 115 +++++++ .../org/apache/spark/sql/functions.scala | 11 + .../spark/sql/DataFrameFunctionsSuite.scala | 69 ++++ 9 files changed, 609 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 60f6f537c1d54..8883e17bf3164 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -28,9 +28,9 @@ import org.apache.spark.annotation.Private * removed. * * The underlying implementation uses Scala compiler's specialization to generate optimized - * storage for two primitive types (Long and Int). It is much faster than Java's standard HashSet - * while incurring much less memory overhead. This can serve as building blocks for higher level - * data structures such as an optimized HashMap. + * storage for four primitive types (Long, Int, Double, and Float). It is much faster than Java's + * standard HashSet while incurring much less memory overhead. This can serve as building blocks + * for higher level data structures such as an optimized HashMap. * * This OpenHashSet is designed to serve as building blocks for higher level data structures * such as an optimized hash map. Compared with standard hash set implementations, this class @@ -41,7 +41,7 @@ import org.apache.spark.annotation.Private * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing). */ @Private -class OpenHashSet[@specialized(Long, Int) T: ClassTag]( +class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag]( initialCapacity: Int, loadFactor: Double) extends Serializable { @@ -77,6 +77,10 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( (new LongHasher).asInstanceOf[Hasher[T]] } else if (mt == ClassTag.Int) { (new IntHasher).asInstanceOf[Hasher[T]] + } else if (mt == ClassTag.Double) { + (new DoubleHasher).asInstanceOf[Hasher[T]] + } else if (mt == ClassTag.Float) { + (new FloatHasher).asInstanceOf[Hasher[T]] } else { new Hasher[T] } @@ -293,7 +297,7 @@ object OpenHashSet { * A set of specialized hash function implementation to avoid boxing hash code computation * in the specialized implementation of OpenHashSet. */ - sealed class Hasher[@specialized(Long, Int) T] extends Serializable { + sealed class Hasher[@specialized(Long, Int, Double, Float) T] extends Serializable { def hash(o: T): Int = o.hashCode() } @@ -305,6 +309,17 @@ object OpenHashSet { override def hash(o: Int): Int = o } + class DoubleHasher extends Hasher[Double] { + override def hash(o: Double): Int = { + val bits = java.lang.Double.doubleToLongBits(o) + (bits ^ (bits >>> 32)).toInt + } + } + + class FloatHasher extends Hasher[Float] { + override def hash(o: Float): Int = java.lang.Float.floatToIntBits(o) + } + private def grow1(newSize: Int) {} private def move1(oldPos: Int, newPos: Int) { } diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index 210bc5c099742..b887f937a9da9 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -112,6 +112,80 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(!set.contains(10000L)) } + test("primitive float") { + val set = new OpenHashSet[Float] + assert(set.size === 0) + assert(!set.contains(10.1F)) + assert(!set.contains(50.5F)) + assert(!set.contains(999.9F)) + assert(!set.contains(10000.1F)) + + set.add(10.1F) + assert(set.size === 1) + assert(set.contains(10.1F)) + assert(!set.contains(50.5F)) + assert(!set.contains(999.9F)) + assert(!set.contains(10000.1F)) + + set.add(50.5F) + assert(set.size === 2) + assert(set.contains(10.1F)) + assert(set.contains(50.5F)) + assert(!set.contains(999.9F)) + assert(!set.contains(10000.1F)) + + set.add(999.9F) + assert(set.size === 3) + assert(set.contains(10.1F)) + assert(set.contains(50.5F)) + assert(set.contains(999.9F)) + assert(!set.contains(10000.1F)) + + set.add(50.5F) + assert(set.size === 3) + assert(set.contains(10.1F)) + assert(set.contains(50.5F)) + assert(set.contains(999.9F)) + assert(!set.contains(10000.1F)) + } + + test("primitive double") { + val set = new OpenHashSet[Double] + assert(set.size === 0) + assert(!set.contains(10.1D)) + assert(!set.contains(50.5D)) + assert(!set.contains(999.9D)) + assert(!set.contains(10000.1D)) + + set.add(10.1D) + assert(set.size === 1) + assert(set.contains(10.1D)) + assert(!set.contains(50.5D)) + assert(!set.contains(999.9D)) + assert(!set.contains(10000.1D)) + + set.add(50.5D) + assert(set.size === 2) + assert(set.contains(10.1D)) + assert(set.contains(50.5D)) + assert(!set.contains(999.9D)) + assert(!set.contains(10000.1D)) + + set.add(999.9D) + assert(set.size === 3) + assert(set.contains(10.1D)) + assert(set.contains(50.5D)) + assert(set.contains(999.9D)) + assert(!set.contains(10000.1D)) + + set.add(50.5D) + assert(set.size === 3) + assert(set.contains(10.1D)) + assert(set.contains(50.5D)) + assert(set.contains(999.9D)) + assert(!set.contains(10000.1D)) + } + test("non-primitive") { val set = new OpenHashSet[String] assert(set.size === 0) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index dd7daf946dd41..ec014a5b39c31 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2052,6 +2052,25 @@ def array_union(col1, col2): return Column(sc._jvm.functions.array_union(_to_java_column(col1), _to_java_column(col2))) +@ignore_unicode_prefix +@since(2.4) +def array_except(col1, col2): + """ + Collection function: returns an array of the elements in col1 but not in col2, + without duplicates. + + :param col1: name of column containing array + :param col2: name of column containing array + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) + >>> df.select(array_except(df.c1, df.c2)).collect() + [Row(array_except(c1, c2)=[u'b'])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_except(_to_java_column(col1), _to_java_column(col2))) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index adc4837276793..b8b311219ca8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -414,6 +414,7 @@ object FunctionRegistry { expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), expression[ArraySort]("array_sort"), + expression[ArrayExcept]("array_except"), expression[ArrayUnion]("array_union"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index dcb9c96ca3b2d..773aefc0ac1f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -709,7 +709,7 @@ trait ComplexTypeMergingExpression extends Expression { @transient lazy val inputTypesForMerging: Seq[DataType] = children.map(_.dataType) - override def dataType: DataType = { + def dataTypeCheck: Unit = { require( inputTypesForMerging.nonEmpty, "The collection of input data types must not be empty.") @@ -717,6 +717,10 @@ trait ComplexTypeMergingExpression extends Expression { TypeCoercion.haveSameType(inputTypesForMerging), "All input types must be the same except nullable, containsNull, valueContainsNull flags." + s" The input types found are\n\t${inputTypesForMerging.mkString("\n\t")}") + } + + override def dataType: DataType = { + dataTypeCheck inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b1d91ffbe86e0..b03bd7d942d72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3651,14 +3651,9 @@ case class ArrayDistinct(child: Expression) } /** - * Will become common base class for [[ArrayUnion]], ArrayIntersect, and ArrayExcept. + * Will become common base class for [[ArrayUnion]], ArrayIntersect, and [[ArrayExcept]]. */ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { - override def dataType: DataType = { - val dataTypes = children.map(_.dataType.asInstanceOf[ArrayType]) - ArrayType(elementType, dataTypes.exists(_.containsNull)) - } - override def checkInputDataTypes(): TypeCheckResult = { val typeCheckResult = super.checkInputDataTypes() if (typeCheckResult.isSuccess) { @@ -3702,7 +3697,8 @@ object ArraySetLike { array(1, 2, 3, 5) """, since = "2.4.0") -case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike { +case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike + with ComplexTypeMergingExpression { var hsInt: OpenHashSet[Int] = _ var hsLong: OpenHashSet[Long] = _ @@ -3968,3 +3964,295 @@ object ArrayUnion { new GenericArrayData(arrayBuffer) } } + +/** + * Returns an array of the elements in the intersect of x and y, without duplicates + */ +@ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in array1 but not in array2, + without duplicates. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(2) + """, + since = "2.4.0") +case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike + with ComplexTypeMergingExpression { + override def dataType: DataType = { + dataTypeCheck + left.dataType + } + + @transient lazy val evalExcept: (ArrayData, ArrayData) => ArrayData = { + if (elementTypeSupportEquals) { + (array1, array2) => + val hs = new OpenHashSet[Any] + var notFoundNullElement = true + var i = 0 + while (i < array2.numElements()) { + if (array2.isNullAt(i)) { + notFoundNullElement = false + } else { + val elem = array2.get(i, elementType) + hs.add(elem) + } + i += 1 + } + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + i = 0 + while (i < array1.numElements()) { + if (array1.isNullAt(i)) { + if (notFoundNullElement) { + arrayBuffer += null + notFoundNullElement = false + } + } else { + val elem = array1.get(i, elementType) + if (!hs.contains(elem)) { + arrayBuffer += elem + hs.add(elem) + } + } + i += 1 + } + new GenericArrayData(arrayBuffer) + } else { + (array1, array2) => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var scannedNullElements = false + var i = 0 + while (i < array1.numElements()) { + var found = false + val elem1 = array1.get(i, elementType) + if (elem1 == null) { + if (!scannedNullElements) { + var j = 0 + while (!found && j < array2.numElements()) { + found = array2.isNullAt(j) + j += 1 + } + // array2 is scanned only once for null element + scannedNullElements = true + } else { + found = true + } + } else { + var j = 0 + while (!found && j < array2.numElements()) { + val elem2 = array2.get(j, elementType) + if (elem2 != null) { + found = ordering.equiv(elem1, elem2) + } + j += 1 + } + if (!found) { + // check whether elem1 is already stored in arrayBuffer + var k = 0 + while (!found && k < arrayBuffer.size) { + val va = arrayBuffer(k) + found = (va != null) && ordering.equiv(va, elem1) + k += 1 + } + } + } + if (!found) { + arrayBuffer += elem1 + } + i += 1 + } + new GenericArrayData(arrayBuffer) + } + } + + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] + + evalExcept(array1, array2) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val arrayData = classOf[ArrayData].getName + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val value = ctx.freshName("value") + val hsValue = ctx.freshName("hsValue") + val size = ctx.freshName("size") + if (elementTypeSupportEquals) { + val ptName = CodeGenerator.primitiveTypeName(elementType) + val unsafeArray = ctx.freshName("unsafeArray") + val (postFix, openHashElementType, hsJavaTypeName, genHsValue, + getter, setter, javaTypeName, primitiveTypeName, arrayDataBuilder) = + elementType match { + case ByteType | ShortType | IntegerType => + ("$mcI$sp", "Int", "int", s"(int) $value", + s"get$ptName($i)", s"set$ptName($pos, $value)", + CodeGenerator.javaType(elementType), ptName, + s""" + |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} + |${ev.value} = $unsafeArray; + """.stripMargin) + case LongType | FloatType | DoubleType => + val signature = elementType match { + case LongType => "$mcJ$sp" + case FloatType => "$mcF$sp" + case DoubleType => "$mcD$sp" + } + (signature, CodeGenerator.boxedType(elementType), + CodeGenerator.javaType(elementType), value, + s"get$ptName($i)", s"set$ptName($pos, $value)", + CodeGenerator.javaType(elementType), ptName, + s""" + |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} + |${ev.value} = $unsafeArray; + """.stripMargin) + case _ => + val genericArrayData = classOf[GenericArrayData].getName + val et = ctx.addReferenceObj("elementType", elementType) + ("", "Object", "Object", value, + s"get($i, $et)", s"update($pos, $value)", "Object", "Ref", + s"${ev.value} = new $genericArrayData(new Object[$size]);") + } + + nullSafeCodeGen(ctx, ev, (array1, array2) => { + val notFoundNullElement = ctx.freshName("notFoundNullElement") + val nullElementIndex = ctx.freshName("nullElementIndex") + val builder = ctx.freshName("builder") + val array = ctx.freshName("array") + val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" + val hs = ctx.freshName("hs") + val genericArrayData = classOf[GenericArrayData].getName + val arrayBuilder = "scala.collection.mutable.ArrayBuilder" + val arrayBuilderClass = s"$arrayBuilder$$of$primitiveTypeName" + val arrayBuilderClassTag = if (primitiveTypeName != "Ref") { + s"scala.reflect.ClassTag$$.MODULE$$.$primitiveTypeName()" + } else { + s"scala.reflect.ClassTag$$.MODULE$$.AnyRef()" + } + + def withArray2NullCheck(body: String) = + if (right.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array2.isNullAt($i)) { + | $notFoundNullElement = false; + |} else { + | $body + |} + """.stripMargin + } else { + body + } + val array2Body = + s""" + |$javaTypeName $value = $array2.$getter; + |$hsJavaTypeName $hsValue = $genHsValue; + |$hs.add$postFix($hsValue); + """.stripMargin + + def withArray1NullAssignment(body: String) = + if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array1.isNullAt($i)) { + | if ($notFoundNullElement) { + | $nullElementIndex = $size; + | $notFoundNullElement = false; + | $size++; + | } + |} else { + | $body + |} + """.stripMargin + } else { + body + } + val array1Body = + s""" + |$javaTypeName $value = $array1.$getter; + |$hsJavaTypeName $hsValue = $genHsValue; + |if (!$hs.contains($hsValue)) { + | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | break; + | } + | $hs.add$postFix($hsValue); + | $builder.$$plus$$eq($value); + |} + """.stripMargin + + val nonNullArrayDataBuild = { + val build = if (postFix != "") { + val defaultSize = elementType.defaultSize + s""" + |if (!UnsafeArrayData.shouldUseGenericArrayData($defaultSize, $size)) { + | ${ev.value} = UnsafeArrayData.fromPrimitiveArray($builder.result()); + |} else { + | ${ev.value} = new $genericArrayData($builder.result()); + |} + """.stripMargin + } else { + s"${ev.value} = new $genericArrayData($builder.result());" + } + s""" + |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful try create array with " + $size + + | " bytes of data due to exceeding the limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for GenericArrayData." + + | " $prettyName failed."); + |} + |$build + """.stripMargin + } + + def buildResultArrayData(nonNullArrayDataBuild: String) = + if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($nullElementIndex < 0) { + | // result has no null element + | $nonNullArrayDataBuild + |} else { + | // result has null element + | $arrayDataBuilder + | $javaTypeName[] $array = $builder.result(); + | for (int $i = 0, $pos = 0; $pos < $size; $pos++) { + | if ($pos == $nullElementIndex) { + | ${ev.value}.setNullAt($pos); + | } else { + | $javaTypeName $value = $array[$i++]; + | ${ev.value}.$setter; + | } + | } + |} + """.stripMargin + } else { + nonNullArrayDataBuild + } + + s""" + |$openHashSet $hs = new $openHashSet$postFix($classTag); + |boolean $notFoundNullElement = true; + |for (int $i = 0; $i < $array2.numElements(); $i++) { + | ${withArray2NullCheck(array2Body)} + |} + |$arrayBuilderClass $builder = + | ($arrayBuilderClass)$arrayBuilder.make($arrayBuilderClassTag); + |int $nullElementIndex = -1; + |int $size = 0; + |for (int $i = 0; $i < $array1.numElements(); $i++) { + | ${withArray1NullAssignment(array1Body)} + |} + |${buildResultArrayData(nonNullArrayDataBuild)} + """.stripMargin + }) + } else { + nullSafeCodeGen(ctx, ev, (array1, array2) => { + val expr = ctx.addReferenceObj("arrayExceptExpr", this) + s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);" + }) + } + } + + override def prettyName: String = "array_except" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 5c5728548e646..2f6f9064f9e62 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1503,4 +1503,119 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper assert(!shuffle.fastEquals(shuffle.freshCopy())) assert(!shuffle.fastEquals(Shuffle(ai0, seed2))) } + + test("Array Except") { + val a00 = Literal.create(Seq(1, 2, 4, 3), ArrayType(IntegerType, false)) + val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, false)) + val a02 = Literal.create(Seq(1, 2, 4, 2), ArrayType(IntegerType, false)) + val a03 = Literal.create(Seq(4, 2, 4), ArrayType(IntegerType, false)) + val a04 = Literal.create(Seq(1, 2, null, 4, 5, 1), ArrayType(IntegerType, true)) + val a05 = Literal.create(Seq(-5, 4, null, 2, -1), ArrayType(IntegerType, true)) + val a06 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, false)) + val abl0 = Literal.create(Seq[Boolean](true, true), ArrayType(BooleanType, false)) + val abl1 = Literal.create(Seq[Boolean](false, false), ArrayType(BooleanType, false)) + val ab0 = Literal.create(Seq[Byte](1, 2, 3, 2), ArrayType(ByteType, false)) + val ab1 = Literal.create(Seq[Byte](4, 2, 4), ArrayType(ByteType, false)) + val as0 = Literal.create(Seq[Short](1, 2, 3, 2), ArrayType(ShortType, false)) + val as1 = Literal.create(Seq[Short](4, 2, 4), ArrayType(ShortType, false)) + val af0 = Literal.create(Seq[Float](1.1F, 2.2F, 3.3F, 2.2F), ArrayType(FloatType, false)) + val af1 = Literal.create(Seq[Float](4.4F, 2.2F, 4.4F), ArrayType(FloatType, false)) + val ad0 = Literal.create(Seq[Double](1.1, 2.2, 3.3, 2.2), ArrayType(DoubleType, false)) + val ad1 = Literal.create(Seq[Double](4.4, 2.2, 4.4), ArrayType(DoubleType, false)) + + val a10 = Literal.create(Seq(1L, 2L, 4L, 3L), ArrayType(LongType, false)) + val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, false)) + val a12 = Literal.create(Seq(1L, 2L, 4L, 2L), ArrayType(LongType, false)) + val a13 = Literal.create(Seq(4L, 2L), ArrayType(LongType, false)) + val a14 = Literal.create(Seq(1L, 2L, null, 4L, 5L, 1L), ArrayType(LongType, true)) + val a15 = Literal.create(Seq(-5L, 4L, null, 2L, -1L), ArrayType(LongType, true)) + val a16 = Literal.create(Seq.empty[Long], ArrayType(LongType, false)) + + val a20 = Literal.create(Seq("b", "a", "c", "d"), ArrayType(StringType, false)) + val a21 = Literal.create(Seq("c", "a"), ArrayType(StringType, false)) + val a22 = Literal.create(Seq("b", "a", "c", "a"), ArrayType(StringType, false)) + val a23 = Literal.create(Seq("c", "a", "c"), ArrayType(StringType, false)) + val a24 = Literal.create(Seq("c", null, "a", "f", "c"), ArrayType(StringType, true)) + val a25 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType, true)) + val a26 = Literal.create(Seq.empty[String], ArrayType(StringType, false)) + + val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType)) + val a31 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayExcept(a00, a01), Seq(1, 3)) + checkEvaluation(ArrayExcept(a02, a01), Seq(1)) + checkEvaluation(ArrayExcept(a02, a02), Seq.empty) + checkEvaluation(ArrayExcept(a02, a03), Seq(1)) + checkEvaluation(ArrayExcept(a04, a02), Seq(null, 5)) + checkEvaluation(ArrayExcept(a04, a05), Seq(1, 5)) + checkEvaluation(ArrayExcept(a04, a06), Seq(1, 2, null, 4, 5)) + checkEvaluation(ArrayExcept(a06, a04), Seq.empty) + checkEvaluation(ArrayExcept(abl0, abl1), Seq[Boolean](true)) + checkEvaluation(ArrayExcept(ab0, ab1), Seq[Byte](1, 3)) + checkEvaluation(ArrayExcept(as0, as1), Seq[Short](1, 3)) + checkEvaluation(ArrayExcept(af0, af1), Seq[Float](1.1F, 3.3F)) + checkEvaluation(ArrayExcept(ad0, ad1), Seq[Double](1.1, 3.3)) + + checkEvaluation(ArrayExcept(a10, a11), Seq(1L, 3L)) + checkEvaluation(ArrayExcept(a12, a11), Seq(1L)) + checkEvaluation(ArrayExcept(a12, a12), Seq.empty) + checkEvaluation(ArrayExcept(a12, a13), Seq(1L)) + checkEvaluation(ArrayExcept(a14, a12), Seq(null, 5L)) + checkEvaluation(ArrayExcept(a14, a15), Seq(1L, 5L)) + checkEvaluation(ArrayExcept(a14, a16), Seq(1L, 2L, null, 4L, 5L)) + checkEvaluation(ArrayExcept(a16, a14), Seq.empty) + + checkEvaluation(ArrayExcept(a20, a21), Seq("b", "d")) + checkEvaluation(ArrayExcept(a22, a21), Seq("b")) + checkEvaluation(ArrayExcept(a22, a22), Seq.empty) + checkEvaluation(ArrayExcept(a22, a23), Seq("b")) + checkEvaluation(ArrayExcept(a24, a22), Seq(null, "f")) + checkEvaluation(ArrayExcept(a24, a25), Seq("c", "f")) + checkEvaluation(ArrayExcept(a24, a26), Seq("c", null, "a", "f")) + checkEvaluation(ArrayExcept(a26, a24), Seq.empty) + + checkEvaluation(ArrayExcept(a30, a30), Seq.empty) + checkEvaluation(ArrayExcept(a20, a31), null) + checkEvaluation(ArrayExcept(a31, a20), null) + + val b0 = Literal.create( + Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](3, 4), Array[Byte](7, 8)), + ArrayType(BinaryType)) + val b1 = Literal.create( + Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](3, 4), Array[Byte](5, 6)), + ArrayType(BinaryType)) + val b2 = Literal.create( + Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](3, 4), null), + ArrayType(BinaryType)) + val b4 = Literal.create(Seq[Array[Byte]](null, Array[Byte](3, 4), null), ArrayType(BinaryType)) + val b5 = Literal.create(Seq.empty, ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) + + checkEvaluation(ArrayExcept(b0, b1), Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](7, 8))) + checkEvaluation(ArrayExcept(b1, b0), Seq[Array[Byte]](Array[Byte](2, 1))) + checkEvaluation(ArrayExcept(b0, b2), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](7, 8))) + checkEvaluation(ArrayExcept(b2, b0), Seq.empty) + checkEvaluation(ArrayExcept(b2, b3), Seq[Array[Byte]](Array[Byte](1, 2))) + checkEvaluation(ArrayExcept(b3, b2), Seq[Array[Byte]](Array[Byte](2, 1), null)) + checkEvaluation(ArrayExcept(b3, b4), Seq[Array[Byte]](Array[Byte](2, 1))) + checkEvaluation(ArrayExcept(b4, b3), Seq.empty) + checkEvaluation(ArrayExcept(b4, b5), Seq[Array[Byte]](null, Array[Byte](3, 4))) + checkEvaluation(ArrayExcept(b5, b4), Seq.empty) + checkEvaluation(ArrayExcept(b4, arrayWithBinaryNull), Seq[Array[Byte]](Array[Byte](3, 4))) + + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](3, 4), Seq[Int](2, 1), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayExcept(aa0, aa1), Seq[Seq[Int]](Seq[Int](1, 2))) + checkEvaluation(ArrayExcept(aa1, aa0), Seq[Seq[Int]](Seq[Int](2, 1))) + + assert(ArrayExcept(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayExcept(a04, a02).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayExcept(a04, a05).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayExcept(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayExcept(a24, a22).dataType.asInstanceOf[ArrayType].containsNull === true) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a2d37928bff59..cc739b85f555c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3243,6 +3243,17 @@ object functions { ArrayUnion(col1.expr, col2.expr) } + /** + * Returns an array of the elements in the first array but not in the second array, + * without duplicates. The order of elements in the result is not determined + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_except(col1: Column, col2: Column): Column = withExpr { + ArrayExcept(col1.expr, col2.expr) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 299c96f74af22..e550b142c738d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1578,6 +1578,75 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { testNonPrimitiveType() } + test("array_except functions") { + val df1 = Seq((Array(1, 2, 4), Array(4, 2))).toDF("a", "b") + val ans1 = Row(Seq(1)) + checkAnswer(df1.select(array_except($"a", $"b")), ans1) + checkAnswer(df1.selectExpr("array_except(a, b)"), ans1) + + val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array[Integer](-5, 4, null, 2, -1))) + .toDF("a", "b") + val ans2 = Row(Seq(1, 5)) + checkAnswer(df2.select(array_except($"a", $"b")), ans2) + checkAnswer(df2.selectExpr("array_except(a, b)"), ans2) + + val df3 = Seq((Array(1L, 2L, 4L), Array(4L, 2L))).toDF("a", "b") + val ans3 = Row(Seq(1L)) + checkAnswer(df3.select(array_except($"a", $"b")), ans3) + checkAnswer(df3.selectExpr("array_except(a, b)"), ans3) + + val df4 = Seq( + (Array[java.lang.Long](1L, 2L, null, 4L, 5L), Array[java.lang.Long](-5L, 4L, null, 2L, -1L))) + .toDF("a", "b") + val ans4 = Row(Seq(1L, 5L)) + checkAnswer(df4.select(array_except($"a", $"b")), ans4) + checkAnswer(df4.selectExpr("array_except(a, b)"), ans4) + + val df5 = Seq((Array("c", null, "a", "f"), Array("b", null, "a", "g"))).toDF("a", "b") + val ans5 = Row(Seq("c", "f")) + checkAnswer(df5.select(array_except($"a", $"b")), ans5) + checkAnswer(df5.selectExpr("array_except(a, b)"), ans5) + + val df6 = Seq((null, null)).toDF("a", "b") + intercept[AnalysisException] { + df6.select(array_except($"a", $"b")) + } + intercept[AnalysisException] { + df6.selectExpr("array_except(a, b)") + } + val df7 = Seq((Array(1), Array("a"))).toDF("a", "b") + intercept[AnalysisException] { + df7.select(array_except($"a", $"b")) + } + intercept[AnalysisException] { + df7.selectExpr("array_except(a, b)") + } + val df8 = Seq((Array("a"), null)).toDF("a", "b") + intercept[AnalysisException] { + df8.select(array_except($"a", $"b")) + } + intercept[AnalysisException] { + df8.selectExpr("array_except(a, b)") + } + val df9 = Seq((null, Array("a"))).toDF("a", "b") + intercept[AnalysisException] { + df9.select(array_except($"a", $"b")) + } + intercept[AnalysisException] { + df9.selectExpr("array_except(a, b)") + } + + val df10 = Seq( + (Array[Integer](1, 2), Array[Integer](2)), + (Array[Integer](1, 2), Array[Integer](1, null)), + (Array[Integer](1, null, 3), Array[Integer](1, 2)), + (Array[Integer](1, null), Array[Integer](2, null)) + ).toDF("a", "b") + val result10 = df10.select(array_except($"a", $"b")) + val expectedType10 = ArrayType(IntegerType, containsNull = true) + assert(result10.first.schema(0).dataType === expectedType10) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From f5113ea8d79de724ec1579bc81a7abb61e44eeef Mon Sep 17 00:00:00 2001 From: Adelbert Chang Date: Wed, 1 Aug 2018 13:57:33 -0700 Subject: [PATCH 1268/2461] [SPARK-24960][K8S] explicitly expose ports on driver container https://issues.apache.org/jira/browse/SPARK-24960 ## What changes were proposed in this pull request? Expose ports explicitly in the driver container. The driver Service created expects to reach the driver Pod at specific ports which before this change, were not explicitly exposed and would likely cause connection issues (see https://github.com/apache-spark-on-k8s/spark/issues/617). This is a port of the original PR created in the now-deprecated Kubernetes fork: https://github.com/apache-spark-on-k8s/spark/pull/618 ## How was this patch tested? Failure in https://github.com/apache-spark-on-k8s/spark/issues/617 reproduced on Kubernetes 1.6.x and 1.8.x. Built the driver image with this patch and observed fixed https://github.com/apache-spark-on-k8s/spark/issues/617 on Kubernetes 1.6.x. Author: Adelbert Chang Closes #21884 from adelbertc/k8s-expose-driver-ports. --- .../apache/spark/deploy/k8s/Constants.scala | 1 + .../k8s/features/BasicDriverFeatureStep.scala | 22 +++++++++++++++++++ .../BasicDriverFeatureStepSuite.scala | 18 ++++++++++++++- 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 5ecdd3a04d77b..f82cd7fd02e12 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -47,6 +47,7 @@ private[spark] object Constants { val DEFAULT_BLOCKMANAGER_PORT = 7079 val DRIVER_PORT_NAME = "driver-rpc-port" val BLOCK_MANAGER_PORT_NAME = "blockmanager" + val UI_PORT_NAME = "spark-ui" // Environment Variables val ENV_DRIVER_URL = "SPARK_DRIVER_URL" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala index 7e67b51de6e04..575bc54ffe2bb 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala @@ -27,6 +27,7 @@ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit._ import org.apache.spark.internal.config._ +import org.apache.spark.ui.SparkUI private[spark] class BasicDriverFeatureStep( conf: KubernetesConf[KubernetesDriverSpecificConf]) @@ -72,10 +73,31 @@ private[spark] class BasicDriverFeatureStep( ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) } + val driverPort = conf.sparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT) + val driverBlockManagerPort = conf.sparkConf.getInt( + DRIVER_BLOCK_MANAGER_PORT.key, + DEFAULT_BLOCKMANAGER_PORT + ) + val driverUIPort = SparkUI.getUIPort(conf.sparkConf) val driverContainer = new ContainerBuilder(pod.container) .withName(DRIVER_CONTAINER_NAME) .withImage(driverContainerImage) .withImagePullPolicy(conf.imagePullPolicy()) + .addNewPort() + .withName(DRIVER_PORT_NAME) + .withContainerPort(driverPort) + .withProtocol("TCP") + .endPort() + .addNewPort() + .withName(BLOCK_MANAGER_PORT_NAME) + .withContainerPort(driverBlockManagerPort) + .withProtocol("TCP") + .endPort() + .addNewPort() + .withName(UI_PORT_NAME) + .withContainerPort(driverUIPort) + .withProtocol("TCP") + .endPort() .addAllToEnv(driverCustomEnvs.asJava) .addNewEnv() .withName(ENV_DRIVER_BIND_ADDRESS) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala index 165f46a07df2f..d98e113554648 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.k8s.features import scala.collection.JavaConverters._ -import io.fabric8.kubernetes.api.model.LocalObjectReferenceBuilder +import io.fabric8.kubernetes.api.model.{ContainerPort, ContainerPortBuilder, LocalObjectReferenceBuilder} import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} @@ -26,6 +26,7 @@ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit.JavaMainAppResource import org.apache.spark.deploy.k8s.submit.PythonMainAppResource +import org.apache.spark.ui.SparkUI class BasicDriverFeatureStepSuite extends SparkFunSuite { @@ -87,6 +88,14 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { assert(configuredPod.container.getImage === "spark-driver:latest") assert(configuredPod.container.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY) + val expectedPortNames = Set( + containerPort(DRIVER_PORT_NAME, DEFAULT_DRIVER_PORT), + containerPort(BLOCK_MANAGER_PORT_NAME, DEFAULT_BLOCKMANAGER_PORT), + containerPort(UI_PORT_NAME, SparkUI.DEFAULT_PORT) + ) + val foundPortNames = configuredPod.container.getPorts.asScala.toSet + assert(expectedPortNames === foundPortNames) + assert(configuredPod.container.getEnv.size === 3) val envs = configuredPod.container .getEnv @@ -203,4 +212,11 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { "spark.files" -> "https://localhost:9000/file1.txt,/opt/spark/file2.txt") assert(additionalProperties === expectedSparkConf) } + + def containerPort(name: String, portNumber: Int): ContainerPort = + new ContainerPortBuilder() + .withName(name) + .withContainerPort(portNumber) + .withProtocol("TCP") + .build() } From 9f558601e822b7596e4bcc141d5c91a5a8859628 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 1 Aug 2018 13:58:29 -0700 Subject: [PATCH 1269/2461] [SPARK-24937][SQL] Datasource partition table should load empty static partitions ## What changes were proposed in this pull request? How to reproduce: ```sql spark-sql> CREATE TABLE tbl AS SELECT 1; spark-sql> CREATE TABLE tbl1 (c1 BIGINT, day STRING, hour STRING) > USING parquet > PARTITIONED BY (day, hour); spark-sql> INSERT INTO TABLE tbl1 PARTITION (day = '2018-07-25', hour='01') SELECT * FROM tbl where 1=0; spark-sql> SHOW PARTITIONS tbl1; spark-sql> CREATE TABLE tbl2 (c1 BIGINT) > PARTITIONED BY (day STRING, hour STRING); spark-sql> INSERT INTO TABLE tbl2 PARTITION (day = '2018-07-25', hour='01') SELECT * FROM tbl where 1=0; spark-sql> SHOW PARTITIONS tbl2; day=2018-07-25/hour=01 spark-sql> ``` 1. Users will be confused about whether the partition data of `tbl1` is generated. 2. Inconsistent with Hive table behavior. This pr fix this issues. ## How was this patch tested? unit tests Author: Yuming Wang Closes #21883 from wangyum/SPARK-24937. --- .../InsertIntoHadoopFsRelationCommand.scala | 10 ++- .../datasources/PartitioningUtils.scala | 6 +- .../sql/execution/command/DDLSuite.scala | 62 +++++++++++++++++++ 3 files changed, 76 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 8a2e00d9780e8..2ae21b7df9823 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -171,7 +171,15 @@ case class InsertIntoHadoopFsRelationCommand( // update metastore partition metadata - refreshUpdatedPartitions(updatedPartitionPaths) + if (updatedPartitionPaths.isEmpty && staticPartitions.nonEmpty + && partitionColumns.length == staticPartitions.size) { + // Avoid empty static partition can't loaded to datasource table. + val staticPathFragment = + PartitioningUtils.getPathFragment(staticPartitions, partitionColumns) + refreshUpdatedPartitions(Set(staticPathFragment)) + } else { + refreshUpdatedPartitions(updatedPartitionPaths) + } // refresh cached files in FileIndex fileIndex.foreach(_.refresh()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index c8a5f9864a602..3183fd30e5e0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils @@ -284,6 +284,10 @@ object PartitioningUtils { }.mkString("/") } + def getPathFragment(spec: TablePartitionSpec, partitionColumns: Seq[Attribute]): String = { + getPathFragment(spec, StructType.fromAttributes(partitionColumns)) + } + /** * Normalize the column names in partition specification, w.r.t. the real partition column names * and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index ca95aad3976e6..78df1db93692b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -2249,6 +2249,68 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("Partition table should load empty static partitions") { + // All static partitions + withTable("t", "t1", "t2") { + withTempPath { dir => + spark.sql("CREATE TABLE t(a int) USING parquet") + spark.sql("CREATE TABLE t1(a int, c string, b string) " + + s"USING parquet PARTITIONED BY(c, b) LOCATION '${dir.toURI}'") + + // datasource table + validateStaticPartitionTable("t1") + + // hive table + if (isUsingHiveMetastore) { + spark.sql("CREATE TABLE t2(a int) " + + s"PARTITIONED BY(c string, b string) LOCATION '${dir.toURI}'") + validateStaticPartitionTable("t2") + } + + def validateStaticPartitionTable(tableName: String): Unit = { + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + assert(spark.sql(s"SHOW PARTITIONS $tableName").count() == 0) + spark.sql( + s"INSERT INTO TABLE $tableName PARTITION(b='b', c='c') SELECT * FROM t WHERE 1 = 0") + assert(spark.sql(s"SHOW PARTITIONS $tableName").count() == 1) + assert(new File(dir, "c=c/b=b").exists()) + checkAnswer(spark.table(tableName), Nil) + } + } + } + + // Partial dynamic partitions + withTable("t", "t1", "t2") { + withTempPath { dir => + spark.sql("CREATE TABLE t(a int) USING parquet") + spark.sql("CREATE TABLE t1(a int, b string, c string) " + + s"USING parquet PARTITIONED BY(c, b) LOCATION '${dir.toURI}'") + + // datasource table + validatePartialStaticPartitionTable("t1") + + // hive table + if (isUsingHiveMetastore) { + spark.sql("CREATE TABLE t2(a int) " + + s"PARTITIONED BY(c string, b string) LOCATION '${dir.toURI}'") + validatePartialStaticPartitionTable("t2") + } + + def validatePartialStaticPartitionTable(tableName: String): Unit = { + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + assert(spark.sql(s"SHOW PARTITIONS $tableName").count() == 0) + spark.sql( + s"INSERT INTO TABLE $tableName PARTITION(c='c', b) SELECT *, 'b' FROM t WHERE 1 = 0") + assert(spark.sql(s"SHOW PARTITIONS $tableName").count() == 0) + assert(!new File(dir, "c=c/b=b").exists()) + checkAnswer(spark.table(tableName), Nil) + } + } + } + } + Seq(true, false).foreach { shouldDelete => val tcName = if (shouldDelete) "non-existing" else "existed" test(s"CTAS for external data source table with a $tcName location") { From ce084d3e06b14897174426665dada0464260da89 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 1 Aug 2018 15:57:54 -0700 Subject: [PATCH 1270/2461] [SPARK-24990][SQL] merge ReadSupport and ReadSupportWithSchema ## What changes were proposed in this pull request? Regarding user-specified schema, data sources may have 3 different behaviors: 1. must have a user-specified schema 2. can't have a user-specified schema 3. can accept the user-specified if it's given, or infer the schema. I added `ReadSupportWithSchema` to support these behaviors, following data source v1. But it turns out we don't need this extra interface. We can just add a `createReader(schema, options)` to `ReadSupport` and make it call `createReader(options)` by default. TODO: also fix the streaming API in followup PRs. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #21946 from cloud-fan/ds-schema. --- .../spark/sql/sources/v2/ReadSupport.java | 25 ++++++++++ .../sql/sources/v2/ReadSupportWithSchema.java | 49 ------------------- .../sources/v2/reader/DataSourceReader.java | 3 +- .../apache/spark/sql/DataFrameReader.scala | 4 +- .../datasources/v2/DataSourceV2Relation.scala | 20 +------- .../v2/JavaSchemaRequiredDataSource.java | 9 +++- .../sql/sources/v2/DataSourceV2Suite.scala | 16 +++--- 7 files changed, 47 insertions(+), 79 deletions(-) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java index b2526ded53d92..80ac08ee5ff52 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java @@ -18,7 +18,9 @@ package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.DataSourceRegister; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; +import org.apache.spark.sql.types.StructType; /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to @@ -27,6 +29,29 @@ @InterfaceStability.Evolving public interface ReadSupport extends DataSourceV2 { + /** + * Creates a {@link DataSourceReader} to scan the data from this data source. + * + * If this method fails (by throwing an exception), the action will fail and no Spark job will be + * submitted. + * + * @param schema the user specified schema. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. + * + * By default this method throws {@link UnsupportedOperationException}, implementations should + * override this method to handle user specified schema. + */ + default DataSourceReader createReader(StructType schema, DataSourceOptions options) { + String name; + if (this instanceof DataSourceRegister) { + name = ((DataSourceRegister) this).shortName(); + } else { + name = this.getClass().getName(); + } + throw new UnsupportedOperationException(name + " does not support user specified schema"); + } + /** * Creates a {@link DataSourceReader} to scan the data from this data source. * diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java deleted file mode 100644 index f31659904cc53..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data reading ability and scan the data from the data source. - * - * This is a variant of {@link ReadSupport} that accepts user-specified schema when reading data. - * A data source can implement both {@link ReadSupport} and {@link ReadSupportWithSchema} if it - * supports both schema inference and user-specified schema. - */ -@InterfaceStability.Evolving -public interface ReadSupportWithSchema extends DataSourceV2 { - - /** - * Create a {@link DataSourceReader} to scan the data from this data source. - * - * If this method fails (by throwing an exception), the action will fail and no Spark job will be - * submitted. - * - * @param schema the full schema of this data source reader. Full schema usually maps to the - * physical schema of the underlying storage of this data source reader, e.g. - * CSV files, JSON files, etc, while this reader may not read data with full - * schema, as column pruning or other optimizations may happen. - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. - */ - DataSourceReader createReader(StructType schema, DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java index 4a7462096db16..da98fab1284ef 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java @@ -23,13 +23,12 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; import org.apache.spark.sql.types.StructType; /** * A data source reader that is returned by * {@link ReadSupport#createReader(DataSourceOptions)} or - * {@link ReadSupportWithSchema#createReader(StructType, DataSourceOptions)}. + * {@link ReadSupport#createReader(StructType, DataSourceOptions)}. * It can mix in various query optimization interfaces to speed up the data scan. The actual scan * logic is delegated to {@link InputPartition}s, which are returned by * {@link #planInputPartitions()}. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index ec9352a7fa055..9bd113419ae4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -194,7 +194,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { val ds = cls.newInstance().asInstanceOf[DataSourceV2] - if (ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema]) { + if (ds.isInstanceOf[ReadSupport]) { val sessionOptions = DataSourceV2Utils.extractSessionConfigs( ds = ds, conf = sparkSession.sessionState.conf) val pathsOption = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 7613eb210c659..46166928f449d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport} import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsReportStatistics} import org.apache.spark.sql.types.StructType @@ -110,22 +110,6 @@ object DataSourceV2Relation { source match { case support: ReadSupport => support - case _: ReadSupportWithSchema => - // this method is only called if there is no user-supplied schema. if there is no - // user-supplied schema and ReadSupport was not implemented, throw a helpful exception. - throw new AnalysisException(s"Data source requires a user-supplied schema: $name") - case _ => - throw new AnalysisException(s"Data source is not readable: $name") - } - } - - def asReadSupportWithSchema: ReadSupportWithSchema = { - source match { - case support: ReadSupportWithSchema => - support - case _: ReadSupport => - throw new AnalysisException( - s"Data source does not support user-supplied schema: $name") case _ => throw new AnalysisException(s"Data source is not readable: $name") } @@ -146,7 +130,7 @@ object DataSourceV2Relation { val v2Options = new DataSourceOptions(options.asJava) userSpecifiedSchema match { case Some(s) => - asReadSupportWithSchema.createReader(s, v2Options) + asReadSupport.createReader(s, v2Options) case _ => asReadSupport.createReader(v2Options) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index ca5abd24abe8f..6fd6a44d2c4d5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -22,12 +22,12 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; +import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.types.StructType; -public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema { +public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupport { class Reader implements DataSourceReader { private final StructType schema; @@ -47,6 +47,11 @@ public List> planInputPartitions() { } } + @Override + public DataSourceReader createReader(DataSourceOptions options) { + throw new IllegalArgumentException("requires a user-supplied schema"); + } + @Override public DataSourceReader createReader(StructType schema, DataSourceOptions options) { return new Reader(schema); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index c7da137219894..b6e594dc29cef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -22,9 +22,8 @@ import java.util.{ArrayList, List => JList} import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} +import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec} import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector @@ -135,8 +134,8 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { test("schema required data source") { Seq(classOf[SchemaRequiredDataSource], classOf[JavaSchemaRequiredDataSource]).foreach { cls => withClue(cls.getName) { - val e = intercept[AnalysisException](spark.read.format(cls.getName).load()) - assert(e.message.contains("requires a user-supplied schema")) + val e = intercept[IllegalArgumentException](spark.read.format(cls.getName).load()) + assert(e.getMessage.contains("requires a user-supplied schema")) val schema = new StructType().add("i", "int").add("s", "string") val df = spark.read.format(cls.getName).schema(schema).load() @@ -455,15 +454,20 @@ class AdvancedInputPartition(start: Int, end: Int, requiredSchema: StructType) } -class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { +class SchemaRequiredDataSource extends DataSourceV2 with ReadSupport { class Reader(val readSchema: StructType) extends DataSourceReader { override def planInputPartitions(): JList[InputPartition[InternalRow]] = java.util.Collections.emptyList() } - override def createReader(schema: StructType, options: DataSourceOptions): DataSourceReader = + override def createReader(options: DataSourceOptions): DataSourceReader = { + throw new IllegalArgumentException("requires a user-supplied schema") + } + + override def createReader(schema: StructType, options: DataSourceOptions): DataSourceReader = { new Reader(schema) + } } class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { From c5fe412928b76b468713742497d3ccc010596516 Mon Sep 17 00:00:00 2001 From: liuxian Date: Wed, 1 Aug 2018 21:19:24 -0500 Subject: [PATCH 1271/2461] [SPARK-18188][DOC][FOLLOW-UP] Add `spark.broadcast.checksum` to configuration ## What changes were proposed in this pull request? This pr add `spark.broadcast.checksum` to configuration. ## How was this patch tested? manually tested Author: liuxian Closes #21825 from 10110346/checksum_config. --- docs/configuration.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index 60c0358c0e938..4911abb0f5cfc 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1217,6 +1217,15 @@ Apart from these, the following properties are also available, and may be useful if it is too small, BlockManager might take a performance hit. + + spark.broadcast.checksum + true + + Whether to enable checksum for broadcast. If enabled, broadcasts will include a checksum, which can + help detect corrupted blocks, at the cost of computing and sending a little more data. It's possible + to disable it if the network has other mechanisms to guarantee data won't be corrupted during broadcast. + + spark.executor.cores From c9914cf0490d13820fb4081eb05188b4903eb980 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 2 Aug 2018 10:22:52 +0800 Subject: [PATCH 1272/2461] [MINOR][DOCS] Add note about Spark network security ## What changes were proposed in this pull request? In response to a recent question, this reiterates that network access to a Spark cluster should be disabled by default, and that access to its hosts and services from outside a private network should be added back explicitly. Also, some minor touch-ups while I was at it. ## How was this patch tested? N/A Author: Sean Owen Closes #21947 from srowen/SecurityNote. --- docs/security.md | 23 ++++++++++++++++++----- docs/spark-standalone.md | 15 +++++++++++---- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/docs/security.md b/docs/security.md index 6ef3a808e0471..1de1d6318939a 100644 --- a/docs/security.md +++ b/docs/security.md @@ -278,7 +278,7 @@ To enable authorization in the SHS, a few extra options are used: - + - + - + + + + + + + + + diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 14d742de5655c..7975b0c8b11ca 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -362,8 +362,15 @@ You can run Spark alongside your existing Hadoop cluster by just launching it as # Configuring Ports for Network Security -Spark makes heavy use of the network, and some environments have strict requirements for using -tight firewall settings. For a complete list of ports to configure, see the +Generally speaking, a Spark cluster and its services are not deployed on the public internet. +They are generally private services, and should only be accessible within the network of the +organization that deploys Spark. Access to the hosts and ports used by Spark services should +be limited to origin hosts that need to access the services. + +This is particularly important for clusters using the standalone resource manager, as they do +not support fine-grained access control in a way that other resource managers do. + +For a complete list of ports to configure, see the [security page](security.html#configuring-ports-for-network-security). # High Availability @@ -376,7 +383,7 @@ By default, standalone scheduling clusters are resilient to Worker failures (ins Utilizing ZooKeeper to provide leader election and some state storage, you can launch multiple Masters in your cluster connected to the same ZooKeeper instance. One will be elected "leader" and the others will remain in standby mode. If the current leader dies, another Master will be elected, recover the old Master's state, and then resume scheduling. The entire recovery process (from the time the first leader goes down) should take between 1 and 2 minutes. Note that this delay only affects scheduling _new_ applications -- applications that were already running during Master failover are unaffected. -Learn more about getting started with ZooKeeper [here](http://zookeeper.apache.org/doc/current/zookeeperStarted.html). +Learn more about getting started with ZooKeeper [here](https://zookeeper.apache.org/doc/current/zookeeperStarted.html). **Configuration** @@ -419,6 +426,6 @@ In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spa **Details** -* This solution can be used in tandem with a process monitor/manager like [monit](http://mmonit.com/monit/), or just to enable manual recovery via restart. +* This solution can be used in tandem with a process monitor/manager like [monit](https://mmonit.com/monit/), or just to enable manual recovery via restart. * While filesystem recovery seems straightforwardly better than not doing any recovery at all, this mode may be suboptimal for certain development or experimental purposes. In particular, killing a master via stop-master.sh does not clean up its recovery state, so whenever you start a new Master, it will enter recovery mode. This could increase the startup time by up to 1 minute if it needs to wait for all previously-registered Workers/clients to timeout. * While it's not officially supported, you could mount an NFS directory as the recovery directory. If the original Master node dies completely, you could then start a Master on a different node, which would correctly recover all previously registered Workers/applications (equivalent to ZooKeeper recovery). Future applications will have to be able to find the new Master, however, in order to register. From 166f346185cc0b27a7e2b2a3b42df277e5901f2f Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Wed, 1 Aug 2018 23:00:17 -0700 Subject: [PATCH 1273/2461] [SPARK-24957][SQL][FOLLOW-UP] Clean the code for AVERAGE ## What changes were proposed in this pull request? This PR is to refactor the code in AVERAGE by dsl. ## How was this patch tested? N/A Author: Xiao Li Closes #21951 from gatorsmile/refactor1. --- .../org/apache/spark/sql/catalyst/dsl/package.scala | 1 + .../sql/catalyst/expressions/aggregate/Average.scala | 10 ++++------ 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 89e8c998f740d..98708545c4bfc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -166,6 +166,7 @@ package object dsl { def maxDistinct(e: Expression): Expression = Max(e).toAggregateExpression(isDistinct = true) def upper(e: Expression): Expression = Upper(e) def lower(e: Expression): Expression = Lower(e) + def coalesce(args: Expression*): Expression = Coalesce(args) def sqrt(e: Expression): Expression = Sqrt(e) def abs(e: Expression): Expression = Abs(e) def star(names: String*): Expression = names match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 9ccf5aa092d11..f1fad770b637f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -46,7 +46,7 @@ abstract class AverageLike(child: Expression) extends DeclarativeAggregate { override lazy val aggBufferAttributes = sum :: count :: Nil override lazy val initialValues = Seq( - /* sum = */ Cast(Literal(0), sumDataType), + /* sum = */ Literal(0).cast(sumDataType), /* count = */ Literal(0L) ) @@ -58,18 +58,16 @@ abstract class AverageLike(child: Expression) extends DeclarativeAggregate { // If all input are nulls, count will be 0 and we will get null after the division. override lazy val evaluateExpression = child.dataType match { case _: DecimalType => - Cast( - DecimalPrecision.decimalAndDecimal(sum / Cast(count, DecimalType.LongDecimal)), - resultType) + DecimalPrecision.decimalAndDecimal(sum / count.cast(DecimalType.LongDecimal)).cast(resultType) case _ => - Cast(sum, resultType) / Cast(count, resultType) + sum.cast(resultType) / count.cast(resultType) } protected def updateExpressionsDef: Seq[Expression] = Seq( /* sum = */ Add( sum, - Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), + coalesce(child.cast(sumDataType), Literal(0).cast(sumDataType))), /* count = */ If(IsNull(child), count, count + 1L) ) From 57d994994d27154f57f2724924c42beb2ab2e0e7 Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Wed, 1 Aug 2018 23:46:01 -0700 Subject: [PATCH 1274/2461] [SPARK-24557][ML] ClusteringEvaluator support array input ## What changes were proposed in this pull request? ClusteringEvaluator support array input ## How was this patch tested? added tests Author: zhengruifeng Closes #21563 from zhengruifeng/clu_eval_support_array. --- .../spark/ml/evaluation/ClusteringEvaluator.scala | 15 +++++++++------ .../ml/evaluation/ClusteringEvaluatorSuite.scala | 15 ++++++++++++++- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala index 4353c46781e9d..a6d6b4ea8b965 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala @@ -21,11 +21,10 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.attribute.AttributeGroup -import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors, VectorUDT} +import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} -import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, - SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.sql.{Column, DataFrame, Dataset} import org.apache.spark.sql.functions.{avg, col, udf} import org.apache.spark.sql.types.DoubleType @@ -107,15 +106,19 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str @Since("2.3.0") override def evaluate(dataset: Dataset[_]): Double = { - SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(dataset.schema, $(featuresCol)) SchemaUtils.checkNumericType(dataset.schema, $(predictionCol)) + val vectorCol = DatasetUtils.columnToVector(dataset, $(featuresCol)) + val df = dataset.select(col($(predictionCol)), + vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata)) + ($(metricName), $(distanceMeasure)) match { case ("silhouette", "squaredEuclidean") => SquaredEuclideanSilhouette.computeSilhouetteScore( - dataset, $(predictionCol), $(featuresCol)) + df, $(predictionCol), $(featuresCol)) case ("silhouette", "cosine") => - CosineSilhouette.computeSilhouetteScore(dataset, $(predictionCol), $(featuresCol)) + CosineSilhouette.computeSilhouetteScore(df, $(predictionCol), $(featuresCol)) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala index 2c175ff68e0b8..e2d77560293fa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -33,10 +33,17 @@ class ClusteringEvaluatorSuite import testImplicits._ @transient var irisDataset: Dataset[_] = _ + @transient var newIrisDataset: Dataset[_] = _ + @transient var newIrisDatasetD: Dataset[_] = _ + @transient var newIrisDatasetF: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() irisDataset = spark.read.format("libsvm").load("../data/mllib/iris_libsvm.txt") + val datasets = MLTestingUtils.generateArrayFeatureDataset(irisDataset) + newIrisDataset = datasets._1 + newIrisDatasetD = datasets._2 + newIrisDatasetF = datasets._3 } test("params") { @@ -66,6 +73,9 @@ class ClusteringEvaluatorSuite .setPredictionCol("label") assert(evaluator.evaluate(irisDataset) ~== 0.6564679231 relTol 1e-5) + assert(evaluator.evaluate(newIrisDataset) ~== 0.6564679231 relTol 1e-5) + assert(evaluator.evaluate(newIrisDatasetD) ~== 0.6564679231 relTol 1e-5) + assert(evaluator.evaluate(newIrisDatasetF) ~== 0.6564679231 relTol 1e-5) } /* @@ -85,6 +95,9 @@ class ClusteringEvaluatorSuite .setDistanceMeasure("cosine") assert(evaluator.evaluate(irisDataset) ~== 0.7222369298 relTol 1e-5) + assert(evaluator.evaluate(newIrisDataset) ~== 0.7222369298 relTol 1e-5) + assert(evaluator.evaluate(newIrisDatasetD) ~== 0.7222369298 relTol 1e-5) + assert(evaluator.evaluate(newIrisDatasetF) ~== 0.7222369298 relTol 1e-5) } test("number of clusters must be greater than one") { From 275415777b84b82aa5409e6577e1efaff1d989e7 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 2 Aug 2018 20:54:36 +0800 Subject: [PATCH 1275/2461] [SPARK-24795][CORE][FOLLOWUP] Kill all running tasks when a task in a barrier stage fail ## What changes were proposed in this pull request? Kill all running tasks when a task in a barrier stage fail in the middle. `TaskScheduler`.`cancelTasks()` will also fail the job, so we implemented a new method `killAllTaskAttempts()` to just kill all running tasks of a stage without cancel the stage/job. ## How was this patch tested? Add new test cases in `TaskSchedulerImplSuite`. Author: Xingbo Jiang Closes #21943 from jiangxb1987/killAllTasks. --- .../apache/spark/scheduler/DAGScheduler.scala | 14 +++-- .../spark/scheduler/TaskScheduler.scala | 8 ++- .../spark/scheduler/TaskSchedulerImpl.scala | 34 +++++++--- .../spark/scheduler/DAGSchedulerSuite.scala | 6 ++ .../ExternalClusterManagerSuite.scala | 2 + .../scheduler/TaskSchedulerImplSuite.scala | 62 +++++++++++++++++++ 6 files changed, 109 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 003d64f78e853..4858af71c1a9c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1433,17 +1433,18 @@ class DAGScheduler( val failedStage = stageIdToStage(task.stageId) logInfo(s"Marking $failedStage (${failedStage.name}) as failed due to a barrier task " + "failed.") - val message = s"Stage failed because barrier task $task finished unsuccessfully. " + + val message = s"Stage failed because barrier task $task finished unsuccessfully.\n" + failure.toErrorString try { - // cancelTasks will fail if a SchedulerBackend does not implement killTask - taskScheduler.cancelTasks(stageId, interruptThread = false) + // killAllTaskAttempts will fail if a SchedulerBackend does not implement killTask. + val reason = s"Task $task from barrier stage $failedStage (${failedStage.name}) failed." + taskScheduler.killAllTaskAttempts(stageId, interruptThread = false, reason) } catch { case e: UnsupportedOperationException => // Cannot continue with barrier stage if failed to cancel zombie barrier tasks. // TODO SPARK-24877 leave the zombie tasks and ignore their completion events. - logWarning(s"Could not cancel tasks for stage $stageId", e) - abortStage(failedStage, "Could not cancel zombie barrier tasks for stage " + + logWarning(s"Could not kill all tasks for stage $stageId", e) + abortStage(failedStage, "Could not kill zombie barrier tasks for stage " + s"$failedStage (${failedStage.name})", Some(e)) } markStageAsFinished(failedStage, Some(message)) @@ -1457,7 +1458,8 @@ class DAGScheduler( if (shouldAbortStage) { val abortMessage = if (disallowStageRetryForTest) { - "Barrier stage will not retry stage due to testing config" + "Barrier stage will not retry stage due to testing config. Most recent failure " + + s"reason: $message" } else { s"""$failedStage (${failedStage.name}) |has failed the maximum allowable number of diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 90644fea23ab1..95f7ae4fd39a2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -51,16 +51,22 @@ private[spark] trait TaskScheduler { // Submit a sequence of tasks to run. def submitTasks(taskSet: TaskSet): Unit - // Cancel a stage. + // Kill all the tasks in a stage and fail the stage and all the jobs that depend on the stage. + // Throw UnsupportedOperationException if the backend doesn't support kill tasks. def cancelTasks(stageId: Int, interruptThread: Boolean): Unit /** * Kills a task attempt. + * Throw UnsupportedOperationException if the backend doesn't support kill a task. * * @return Whether the task was successfully killed. */ def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean + // Kill all the running task attempts in a stage. + // Throw UnsupportedOperationException if the backend doesn't support kill tasks. + def killAllTaskAttempts(stageId: Int, interruptThread: Boolean, reason: String): Unit + // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called. def setDAGScheduler(dagScheduler: DAGScheduler): Unit diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 587ed4b5243b7..72691389d271c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -222,18 +222,11 @@ private[spark] class TaskSchedulerImpl( override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized { logInfo("Cancelling stage " + stageId) + // Kill all running tasks for the stage. + killAllTaskAttempts(stageId, interruptThread, reason = "Stage cancelled") + // Cancel all attempts for the stage. taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts => attempts.foreach { case (_, tsm) => - // There are two possible cases here: - // 1. The task set manager has been created and some tasks have been scheduled. - // In this case, send a kill signal to the executors to kill the task and then abort - // the stage. - // 2. The task set manager has been created but no tasks have been scheduled. In this case, - // simply abort the stage. - tsm.runningTasksSet.foreach { tid => - taskIdToExecutorId.get(tid).foreach(execId => - backend.killTask(tid, execId, interruptThread, reason = "Stage cancelled")) - } tsm.abort("Stage %s cancelled".format(stageId)) logInfo("Stage %d was cancelled".format(stageId)) } @@ -252,6 +245,27 @@ private[spark] class TaskSchedulerImpl( } } + override def killAllTaskAttempts( + stageId: Int, + interruptThread: Boolean, + reason: String): Unit = synchronized { + logInfo(s"Killing all running tasks in stage $stageId: $reason") + taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts => + attempts.foreach { case (_, tsm) => + // There are two possible cases here: + // 1. The task set manager has been created and some tasks have been scheduled. + // In this case, send a kill signal to the executors to kill the task. + // 2. The task set manager has been created but no tasks have been scheduled. In this case, + // simply continue. + tsm.runningTasksSet.foreach { tid => + taskIdToExecutorId.get(tid).foreach { execId => + backend.killTask(tid, execId, interruptThread, reason) + } + } + } + } + } + /** * Called to indicate that all task attempts (including speculated tasks) associated with the * given TaskSetManager have completed, so state associated with the TaskSetManager should be diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index b3db5e29fb82e..dad339e2cdb91 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -131,6 +131,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } override def killTaskAttempt( taskId: Long, interruptThread: Boolean, reason: String): Boolean = false + override def killAllTaskAttempts( + stageId: Int, interruptThread: Boolean, reason: String): Unit = {} override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} @@ -629,6 +631,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi taskId: Long, interruptThread: Boolean, reason: String): Boolean = { throw new UnsupportedOperationException } + override def killAllTaskAttempts( + stageId: Int, interruptThread: Boolean, reason: String): Unit = { + throw new UnsupportedOperationException + } override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} override def defaultParallelism(): Int = 2 override def executorHeartbeatReceived( diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index a4e4ea7cd2894..02b19e01ce7a0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -81,6 +81,8 @@ private class DummyTaskScheduler extends TaskScheduler { override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {} override def killTaskAttempt( taskId: Long, interruptThread: Boolean, reason: String): Boolean = false + override def killAllTaskAttempts( + stageId: Int, interruptThread: Boolean, reason: String): Unit = {} override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} override def defaultParallelism(): Int = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 16c273b7bc8a4..38e26a82e750f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -1055,4 +1055,66 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten assert(3 === taskDescriptions.length) } + + test("cancelTasks shall kill all the running tasks and fail the stage") { + val taskScheduler = setupScheduler() + + taskScheduler.initialize(new FakeSchedulerBackend { + override def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = { + // Since we only submit one stage attempt, the following call is sufficient to mark the + // task as killed. + taskScheduler.taskSetManagerForAttempt(0, 0).get.runningTasksSet.remove(taskId) + } + }) + + val attempt1 = FakeTask.createTaskSet(10, 0) + taskScheduler.submitTasks(attempt1) + + val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 1), + new WorkerOffer("executor1", "host1", 1)) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(2 === taskDescriptions.length) + val tsm = taskScheduler.taskSetManagerForAttempt(0, 0).get + assert(2 === tsm.runningTasks) + + taskScheduler.cancelTasks(0, false) + assert(0 === tsm.runningTasks) + assert(tsm.isZombie) + assert(taskScheduler.taskSetManagerForAttempt(0, 0).isEmpty) + } + + test("killAllTaskAttempts shall kill all the running tasks and not fail the stage") { + val taskScheduler = setupScheduler() + + taskScheduler.initialize(new FakeSchedulerBackend { + override def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = { + // Since we only submit one stage attempt, the following call is sufficient to mark the + // task as killed. + taskScheduler.taskSetManagerForAttempt(0, 0).get.runningTasksSet.remove(taskId) + } + }) + + val attempt1 = FakeTask.createTaskSet(10, 0) + taskScheduler.submitTasks(attempt1) + + val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 1), + new WorkerOffer("executor1", "host1", 1)) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(2 === taskDescriptions.length) + val tsm = taskScheduler.taskSetManagerForAttempt(0, 0).get + assert(2 === tsm.runningTasks) + + taskScheduler.killAllTaskAttempts(0, false, "test") + assert(0 === tsm.runningTasks) + assert(!tsm.isZombie) + assert(taskScheduler.taskSetManagerForAttempt(0, 0).isDefined) + } } From a65736996b2b506f61cc8d599ec9f4c52a1b5312 Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Thu, 2 Aug 2018 09:17:09 -0500 Subject: [PATCH 1276/2461] [SPARK-14540][CORE] Fix remaining major issues for Scala 2.12 Support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR addresses issues 2,3 in this [document](https://docs.google.com/document/d/1fbkjEL878witxVQpOCbjlvOvadHtVjYXeB-2mgzDTvk). * We modified the closure cleaner to identify closures that are implemented via the LambdaMetaFactory mechanism (serializedLambdas) (issue2). * We also fix the issue due to scala/bug#11016. There are two options for solving the Unit issue, either add () at the end of the closure or use the trick described in the doc. Otherwise overloading resolution does not work (we are not going to eliminate either of the methods) here. Compiler tries to adapt to Unit and makes these two methods candidates for overloading, when there is polymorphic overloading there is no ambiguity (that is the workaround implemented). This does not look that good but it serves its purpose as we need to support two different uses for method: `addTaskCompletionListener`. One that passes a TaskCompletionListener and one that passes a closure that is wrapped with a TaskCompletionListener later on (issue3). Note: regarding issue 1 in the doc the plan is: > Do Nothing. Don’t try to fix this as this is only a problem for Java users who would want to use 2.11 binaries. In that case they can cast to MapFunction to be able to utilize lambdas. In Spark 3.0.0 the API should be simplified so that this issue is removed. ## How was this patch tested? This was manually tested: ```./dev/change-scala-version.sh 2.12 ./build/mvn -DskipTests -Pscala-2.12 clean package ./build/mvn -Pscala-2.12 clean package -DwildcardSuites=org.apache.spark.serializer.ProactiveClosureSerializationSuite -Dtest=None ./build/mvn -Pscala-2.12 clean package -DwildcardSuites=org.apache.spark.util.ClosureCleanerSuite -Dtest=None ./build/mvn -Pscala-2.12 clean package -DwildcardSuites=org.apache.spark.streaming.DStreamClosureSuite -Dtest=None``` Author: Stavros Kontopoulos Closes #21930 from skonto/scala2.12-sup. --- .../scala/org/apache/spark/TaskContext.scala | 5 +- .../spark/api/python/PythonRunner.scala | 2 +- .../spark/broadcast/TorrentBroadcast.scala | 2 +- .../org/apache/spark/rdd/HadoopRDD.scala | 2 +- .../scala/org/apache/spark/rdd/JdbcRDD.scala | 2 +- .../org/apache/spark/rdd/NewHadoopRDD.scala | 2 +- .../spark/rdd/ReliableCheckpointRDD.scala | 2 +- .../shuffle/BlockStoreShuffleReader.scala | 2 +- .../storage/ShuffleBlockFetcherIterator.scala | 2 +- .../spark/storage/memory/MemoryStore.scala | 2 +- .../apache/spark/util/ClosureCleaner.scala | 278 +++++++++++------- .../collection/ExternalAppendOnlyMap.scala | 2 +- .../spark/util/ClosureCleanerSuite.scala | 3 + .../spark/util/ClosureCleanerSuite2.scala | 53 +++- .../spark/sql/avro/AvroFileFormat.scala | 2 +- .../spark/sql/kafka010/KafkaSourceRDD.scala | 2 +- .../spark/streaming/kafka010/KafkaRDD.scala | 2 +- .../ml/source/libsvm/LibSVMRelation.scala | 2 +- .../TungstenAggregationIterator.scala | 2 +- .../sql/execution/arrow/ArrowConverters.scala | 4 +- .../columnar/InMemoryTableScanExec.scala | 2 +- .../execution/datasources/CodecStreams.scala | 2 +- .../execution/datasources/FileScanRDD.scala | 2 +- .../datasources/csv/CSVDataSource.scala | 2 +- .../execution/datasources/jdbc/JDBCRDD.scala | 2 +- .../datasources/json/JsonDataSource.scala | 2 +- .../datasources/orc/OrcFileFormat.scala | 4 +- .../parquet/ParquetFileFormat.scala | 4 +- .../datasources/text/TextFileFormat.scala | 2 +- .../datasources/v2/DataSourceRDD.scala | 2 +- .../spark/sql/execution/joins/HashJoin.scala | 2 +- .../joins/ShuffledHashJoinExec.scala | 2 +- .../python/AggregateInPandasExec.scala | 2 +- .../execution/python/ArrowPythonRunner.scala | 2 +- .../sql/execution/python/EvalPythonExec.scala | 2 +- .../python/PythonForeachWriter.scala | 2 +- .../execution/python/WindowInPandasExec.scala | 2 +- .../continuous/ContinuousCoalesceRDD.scala | 5 +- .../ContinuousQueuedDataReader.scala | 2 +- .../shuffle/ContinuousShuffleReadRDD.scala | 2 +- .../state/SymmetricHashJoinStateManager.scala | 2 +- .../execution/streaming/state/package.scala | 2 +- .../spark/sql/hive/orc/OrcFileFormat.scala | 3 +- 43 files changed, 266 insertions(+), 161 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 69739745aa6cf..ceadf108c86cd 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -123,7 +123,10 @@ abstract class TaskContext extends Serializable { * * Exceptions thrown by the listener will result in failure of the task. */ - def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext = { + def addTaskCompletionListener[U](f: (TaskContext) => U): TaskContext = { + // Note that due to this scala bug: https://github.com/scala/bug/issues/11016, we need to make + // this function polymorphic for every scala version >= 2.12, otherwise an overloaded method + // resolution error occurs at compile time. addTaskCompletionListener(new TaskCompletionListener { override def onTaskCompletion(context: TaskContext): Unit = f(context) }) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index ebabedf950e39..7b31857588252 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -94,7 +94,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( // Start a thread to feed the process input from our parent's iterator val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => writerThread.shutdownOnTaskCompletion() if (!reuseWorker || !released.get) { try { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index e125095cf4777..cbd49e070f2eb 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -262,7 +262,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) val blockManager = SparkEnv.get.blockManager Option(TaskContext.get()) match { case Some(taskContext) => - taskContext.addTaskCompletionListener(_ => blockManager.releaseLock(blockId)) + taskContext.addTaskCompletionListener[Unit](_ => blockManager.releaseLock(blockId)) case None => // This should only happen on the driver, where broadcast variables may be accessed // outside of running tasks (e.g. when computing rdd.partitions()). In order to allow diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 44895abc7bd4d..3974580cfaa11 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -278,7 +278,7 @@ class HadoopRDD[K, V]( null } // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener { context => + context.addTaskCompletionListener[Unit] { context => // Update the bytes read before closing is to make sure lingering bytesRead statistics in // this thread get correctly added. updateBytesRead() diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index aab46b8954bf7..56ef3e107a980 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -77,7 +77,7 @@ class JdbcRDD[T: ClassTag]( override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T] { - context.addTaskCompletionListener{ context => closeIfNeeded() } + context.addTaskCompletionListener[Unit]{ context => closeIfNeeded() } val part = thePart.asInstanceOf[JdbcPartition] val conn = getConnection() val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index ff66a04859d10..2d66d25ba39fa 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -214,7 +214,7 @@ class NewHadoopRDD[K, V]( } // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener { context => + context.addTaskCompletionListener[Unit] { context => // Update the bytesRead before closing is to make sure lingering bytesRead statistics in // this thread get correctly added. updateBytesRead() diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index 979152b55f957..8273d8a9eb476 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -300,7 +300,7 @@ private[spark] object ReliableCheckpointRDD extends Logging { val deserializeStream = serializer.deserializeStream(fileInputStream) // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener(context => deserializeStream.close()) + context.addTaskCompletionListener[Unit](context => deserializeStream.close()) deserializeStream.asIterator.asInstanceOf[Iterator[T]] } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 4103dfb10175e..74b0e0b3a741a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -104,7 +104,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) // Use completion callback to stop sorter if task was finished/cancelled. - context.addTaskCompletionListener(_ => { + context.addTaskCompletionListener[Unit](_ => { sorter.stop() }) CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index b31862323a895..00d01dd28afb5 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -346,7 +346,7 @@ final class ShuffleBlockFetcherIterator( private[this] def initialize(): Unit = { // Add a task completion callback (called in both success case and failure case) to cleanup. - context.addTaskCompletionListener(_ => cleanup()) + context.addTaskCompletionListener[Unit](_ => cleanup()) // Split local and remote blocks. val remoteRequests = splitLocalRemoteBlocks() diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 4cc5bcb7f9baf..06fd56e54d9c8 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -827,7 +827,7 @@ private[storage] class PartiallySerializedBlock[T]( // completion listener here in order to ensure that `unrolled.dispose()` is called at least once. // The dispose() method is idempotent, so it's safe to call it unconditionally. Option(TaskContext.get()).foreach { taskContext => - taskContext.addTaskCompletionListener { _ => + taskContext.addTaskCompletionListener[Unit] { _ => // When a task completes, its unroll memory will automatically be freed. Thus we do not call // releaseUnrollMemoryForThisTask() here because we want to avoid double-freeing. unrolledBuffer.dispose() diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 073d71c63b0c7..d8c840c356527 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.lang.invoke.SerializedLambda import scala.collection.mutable.{Map, Set, Stack} import scala.language.existentials @@ -33,6 +34,8 @@ import org.apache.spark.internal.Logging */ private[spark] object ClosureCleaner extends Logging { + private val isScala2_11 = scala.util.Properties.versionString.contains("2.11") + // Get an ASM class reader for a given class from the JAR that loaded it private[util] def getClassReader(cls: Class[_]): ClassReader = { // Copy data over, before delegating to ClassReader - else we can run out of open file handles. @@ -159,6 +162,42 @@ private[spark] object ClosureCleaner extends Logging { clean(closure, checkSerializable, cleanTransitively, Map.empty) } + /** + * Try to get a serialized Lambda from the closure. + * + * @param closure the closure to check. + */ + private def getSerializedLambda(closure: AnyRef): Option[SerializedLambda] = { + if (isScala2_11) { + return None + } + val isClosureCandidate = + closure.getClass.isSynthetic && + closure + .getClass + .getInterfaces.exists(_.getName.equals("scala.Serializable")) + + if (isClosureCandidate) { + try { + Option(inspect(closure)) + } catch { + case e: Exception => + // no need to check if debug is enabled here the Spark + // logging api covers this. + logDebug("Closure is not a serialized lambda.", e) + None + } + } else { + None + } + } + + private def inspect(closure: AnyRef): SerializedLambda = { + val writeReplace = closure.getClass.getDeclaredMethod("writeReplace") + writeReplace.setAccessible(true) + writeReplace.invoke(closure).asInstanceOf[java.lang.invoke.SerializedLambda] + } + /** * Helper method to clean the given closure in place. * @@ -206,7 +245,12 @@ private[spark] object ClosureCleaner extends Logging { cleanTransitively: Boolean, accessedFields: Map[Class[_], Set[String]]): Unit = { - if (!isClosure(func.getClass)) { + // most likely to be the case with 2.12, 2.13 + // so we check first + // non LMF-closures should be less frequent from now on + val lambdaFunc = getSerializedLambda(func) + + if (!isClosure(func.getClass) && lambdaFunc.isEmpty) { logDebug(s"Expected a closure; got ${func.getClass.getName}") return } @@ -218,118 +262,132 @@ private[spark] object ClosureCleaner extends Logging { return } - logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++") - - // A list of classes that represents closures enclosed in the given one - val innerClasses = getInnerClosureClasses(func) - - // A list of enclosing objects and their respective classes, from innermost to outermost - // An outer object at a given index is of type outer class at the same index - val (outerClasses, outerObjects) = getOuterClassesAndObjects(func) - - // For logging purposes only - val declaredFields = func.getClass.getDeclaredFields - val declaredMethods = func.getClass.getDeclaredMethods - - if (log.isDebugEnabled) { - logDebug(" + declared fields: " + declaredFields.size) - declaredFields.foreach { f => logDebug(" " + f) } - logDebug(" + declared methods: " + declaredMethods.size) - declaredMethods.foreach { m => logDebug(" " + m) } - logDebug(" + inner classes: " + innerClasses.size) - innerClasses.foreach { c => logDebug(" " + c.getName) } - logDebug(" + outer classes: " + outerClasses.size) - outerClasses.foreach { c => logDebug(" " + c.getName) } - logDebug(" + outer objects: " + outerObjects.size) - outerObjects.foreach { o => logDebug(" " + o) } - } + if (lambdaFunc.isEmpty) { + logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++") + + // A list of classes that represents closures enclosed in the given one + val innerClasses = getInnerClosureClasses(func) + + // A list of enclosing objects and their respective classes, from innermost to outermost + // An outer object at a given index is of type outer class at the same index + val (outerClasses, outerObjects) = getOuterClassesAndObjects(func) + + // For logging purposes only + val declaredFields = func.getClass.getDeclaredFields + val declaredMethods = func.getClass.getDeclaredMethods + + if (log.isDebugEnabled) { + logDebug(s" + declared fields: ${declaredFields.size}") + declaredFields.foreach { f => logDebug(s" $f") } + logDebug(s" + declared methods: ${declaredMethods.size}") + declaredMethods.foreach { m => logDebug(s" $m") } + logDebug(s" + inner classes: ${innerClasses.size}") + innerClasses.foreach { c => logDebug(s" ${c.getName}") } + logDebug(s" + outer classes: ${outerClasses.size}" ) + outerClasses.foreach { c => logDebug(s" ${c.getName}") } + logDebug(s" + outer objects: ${outerObjects.size}") + outerObjects.foreach { o => logDebug(s" $o") } + } - // Fail fast if we detect return statements in closures - getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0) - - // If accessed fields is not populated yet, we assume that - // the closure we are trying to clean is the starting one - if (accessedFields.isEmpty) { - logDebug(s" + populating accessed fields because this is the starting closure") - // Initialize accessed fields with the outer classes first - // This step is needed to associate the fields to the correct classes later - initAccessedFields(accessedFields, outerClasses) - - // Populate accessed fields by visiting all fields and methods accessed by this and - // all of its inner closures. If transitive cleaning is enabled, this may recursively - // visits methods that belong to other classes in search of transitively referenced fields. - for (cls <- func.getClass :: innerClasses) { - getClassReader(cls).accept(new FieldAccessFinder(accessedFields, cleanTransitively), 0) + // Fail fast if we detect return statements in closures + getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0) + + // If accessed fields is not populated yet, we assume that + // the closure we are trying to clean is the starting one + if (accessedFields.isEmpty) { + logDebug(" + populating accessed fields because this is the starting closure") + // Initialize accessed fields with the outer classes first + // This step is needed to associate the fields to the correct classes later + initAccessedFields(accessedFields, outerClasses) + + // Populate accessed fields by visiting all fields and methods accessed by this and + // all of its inner closures. If transitive cleaning is enabled, this may recursively + // visits methods that belong to other classes in search of transitively referenced fields. + for (cls <- func.getClass :: innerClasses) { + getClassReader(cls).accept(new FieldAccessFinder(accessedFields, cleanTransitively), 0) + } } - } - logDebug(s" + fields accessed by starting closure: " + accessedFields.size) - accessedFields.foreach { f => logDebug(" " + f) } - - // List of outer (class, object) pairs, ordered from outermost to innermost - // Note that all outer objects but the outermost one (first one in this list) must be closures - var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse - var parent: AnyRef = null - if (outerPairs.size > 0) { - val (outermostClass, outermostObject) = outerPairs.head - if (isClosure(outermostClass)) { - logDebug(s" + outermost object is a closure, so we clone it: ${outerPairs.head}") - } else if (outermostClass.getName.startsWith("$line")) { - // SPARK-14558: if the outermost object is a REPL line object, we should clone and clean it - // as it may carray a lot of unnecessary information, e.g. hadoop conf, spark conf, etc. - logDebug(s" + outermost object is a REPL line object, so we clone it: ${outerPairs.head}") + logDebug(s" + fields accessed by starting closure: " + accessedFields.size) + accessedFields.foreach { f => logDebug(" " + f) } + + // List of outer (class, object) pairs, ordered from outermost to innermost + // Note that all outer objects but the outermost one (first one in this list) must be closures + var outerPairs: List[(Class[_], AnyRef)] = outerClasses.zip(outerObjects).reverse + var parent: AnyRef = null + if (outerPairs.nonEmpty) { + val (outermostClass, outermostObject) = outerPairs.head + if (isClosure(outermostClass)) { + logDebug(s" + outermost object is a closure, so we clone it: ${outerPairs.head}") + } else if (outermostClass.getName.startsWith("$line")) { + // SPARK-14558: if the outermost object is a REPL line object, we should clone + // and clean it as it may carray a lot of unnecessary information, + // e.g. hadoop conf, spark conf, etc. + logDebug(s" + outermost object is a REPL line object, so we clone it: ${outerPairs.head}") + } else { + // The closure is ultimately nested inside a class; keep the object of that + // class without cloning it since we don't want to clone the user's objects. + // Note that we still need to keep around the outermost object itself because + // we need it to clone its child closure later (see below). + logDebug(" + outermost object is not a closure or REPL line object," + + "so do not clone it: " + outerPairs.head) + parent = outermostObject // e.g. SparkContext + outerPairs = outerPairs.tail + } } else { - // The closure is ultimately nested inside a class; keep the object of that - // class without cloning it since we don't want to clone the user's objects. - // Note that we still need to keep around the outermost object itself because - // we need it to clone its child closure later (see below). - logDebug(" + outermost object is not a closure or REPL line object, so do not clone it: " + - outerPairs.head) - parent = outermostObject // e.g. SparkContext - outerPairs = outerPairs.tail + logDebug(" + there are no enclosing objects!") } - } else { - logDebug(" + there are no enclosing objects!") - } - // Clone the closure objects themselves, nulling out any fields that are not - // used in the closure we're working on or any of its inner closures. - for ((cls, obj) <- outerPairs) { - logDebug(s" + cloning the object $obj of class ${cls.getName}") - // We null out these unused references by cloning each object and then filling in all - // required fields from the original object. We need the parent here because the Java - // language specification requires the first constructor parameter of any closure to be - // its enclosing object. - val clone = cloneAndSetFields(parent, obj, cls, accessedFields) - - // If transitive cleaning is enabled, we recursively clean any enclosing closure using - // the already populated accessed fields map of the starting closure - if (cleanTransitively && isClosure(clone.getClass)) { - logDebug(s" + cleaning cloned closure $clone recursively (${cls.getName})") - // No need to check serializable here for the outer closures because we're - // only interested in the serializability of the starting closure - clean(clone, checkSerializable = false, cleanTransitively, accessedFields) + // Clone the closure objects themselves, nulling out any fields that are not + // used in the closure we're working on or any of its inner closures. + for ((cls, obj) <- outerPairs) { + logDebug(s" + cloning the object $obj of class ${cls.getName}") + // We null out these unused references by cloning each object and then filling in all + // required fields from the original object. We need the parent here because the Java + // language specification requires the first constructor parameter of any closure to be + // its enclosing object. + val clone = cloneAndSetFields(parent, obj, cls, accessedFields) + + // If transitive cleaning is enabled, we recursively clean any enclosing closure using + // the already populated accessed fields map of the starting closure + if (cleanTransitively && isClosure(clone.getClass)) { + logDebug(s" + cleaning cloned closure $clone recursively (${cls.getName})") + // No need to check serializable here for the outer closures because we're + // only interested in the serializability of the starting closure + clean(clone, checkSerializable = false, cleanTransitively, accessedFields) + } + parent = clone } - parent = clone - } - // Update the parent pointer ($outer) of this closure - if (parent != null) { - val field = func.getClass.getDeclaredField("$outer") - field.setAccessible(true) - // If the starting closure doesn't actually need our enclosing object, then just null it out - if (accessedFields.contains(func.getClass) && - !accessedFields(func.getClass).contains("$outer")) { - logDebug(s" + the starting closure doesn't actually need $parent, so we null it out") - field.set(func, null) - } else { - // Update this closure's parent pointer to point to our enclosing object, - // which could either be a cloned closure or the original user object - field.set(func, parent) + // Update the parent pointer ($outer) of this closure + if (parent != null) { + val field = func.getClass.getDeclaredField("$outer") + field.setAccessible(true) + // If the starting closure doesn't actually need our enclosing object, then just null it out + if (accessedFields.contains(func.getClass) && + !accessedFields(func.getClass).contains("$outer")) { + logDebug(s" + the starting closure doesn't actually need $parent, so we null it out") + field.set(func, null) + } else { + // Update this closure's parent pointer to point to our enclosing object, + // which could either be a cloned closure or the original user object + field.set(func, parent) + } } - } - logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned +++") + logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned +++") + } else { + logDebug(s"Cleaning lambda: ${lambdaFunc.get.getImplMethodName}") + + // scalastyle:off classforname + val captClass = Class.forName(lambdaFunc.get.getCapturingClass.replace('/', '.'), + false, Thread.currentThread.getContextClassLoader) + // scalastyle:on classforname + // Fail fast if we detect return statements in closures + getClassReader(captClass) + .accept(new ReturnStatementFinder(Some(lambdaFunc.get.getImplMethodName)), 0) + logDebug(s" +++ Lambda closure (${lambdaFunc.get.getImplMethodName}) is now cleaned +++") + } if (checkSerializable) { ensureSerializable(func) @@ -366,14 +424,24 @@ private[spark] object ClosureCleaner extends Logging { private[spark] class ReturnStatementInClosureException extends SparkException("Return statements aren't allowed in Spark closures") -private class ReturnStatementFinder extends ClassVisitor(ASM5) { +private class ReturnStatementFinder(targetMethodName: Option[String] = None) + extends ClassVisitor(ASM5) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { + // $anonfun$ covers Java 8 lambdas if (name.contains("apply") || name.contains("$anonfun$")) { + // A method with suffix "$adapted" will be generated in cases like + // { _:Int => return; Seq()} but not { _:Int => return; true} + // closure passed is $anonfun$t$1$adapted while actual code resides in $anonfun$s$1 + // visitor will see only $anonfun$s$1$adapted, so we remove the suffix, see + // https://github.com/scala/scala-dev/issues/109 + val isTargetMethod = targetMethodName.isEmpty || + name == targetMethodName.get || name == targetMethodName.get.stripSuffix("$adapted") + new MethodVisitor(ASM5) { override def visitTypeInsn(op: Int, tp: String) { - if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) { + if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl") && isTargetMethod) { throw new ReturnStatementInClosureException } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 5c6dd45ec58e3..d83da0d126d89 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -565,7 +565,7 @@ class ExternalAppendOnlyMap[K, V, C]( } } - context.addTaskCompletionListener(context => cleanup()) + context.addTaskCompletionListener[Unit](context => cleanup()) } private[this] class SpillableIterator(var upstream: Iterator[(K, C)]) diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 9a19baee9569e..a0010f18c18a1 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -121,6 +121,7 @@ class ClosureCleanerSuite extends SparkFunSuite { } test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 1") { + assume(!ClosureCleanerSuite2.supportsLMFs) val concreteObject = new TestAbstractClass { val n2 = 222 val s2 = "bbb" @@ -141,6 +142,7 @@ class ClosureCleanerSuite extends SparkFunSuite { } test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 2") { + assume(!ClosureCleanerSuite2.supportsLMFs) val concreteObject = new TestAbstractClass2 { val n2 = 222 val s2 = "bbb" @@ -154,6 +156,7 @@ class ClosureCleanerSuite extends SparkFunSuite { } test("SPARK-22328: multiple outer classes have the same parent class") { + assume(!ClosureCleanerSuite2.supportsLMFs) val concreteObject = new TestAbstractClass2 { val innerObject = new TestAbstractClass2 { diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala index 278fada83d78c..96da8ec3b2a1c 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala @@ -145,6 +145,7 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("get inner closure classes") { + assume(!ClosureCleanerSuite2.supportsLMFs) val closure1 = () => 1 val closure2 = () => { () => 1 } val closure3 = (i: Int) => { @@ -171,6 +172,7 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("get outer classes and objects") { + assume(!ClosureCleanerSuite2.supportsLMFs) val localValue = someSerializableValue val closure1 = () => 1 val closure2 = () => localValue @@ -207,6 +209,7 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("get outer classes and objects with nesting") { + assume(!ClosureCleanerSuite2.supportsLMFs) val localValue = someSerializableValue val test1 = () => { @@ -258,6 +261,7 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("find accessed fields") { + assume(!ClosureCleanerSuite2.supportsLMFs) val localValue = someSerializableValue val closure1 = () => 1 val closure2 = () => localValue @@ -296,6 +300,7 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("find accessed fields with nesting") { + assume(!ClosureCleanerSuite2.supportsLMFs) val localValue = someSerializableValue val test1 = () => { @@ -538,17 +543,22 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri // As before, this closure is neither serializable nor cleanable verifyCleaning(inner1, serializableBefore = false, serializableAfter = false) - // This closure is no longer serializable because it now has a pointer to the outer closure, - // which is itself not serializable because it has a pointer to the ClosureCleanerSuite2. - // If we do not clean transitively, we will not null out this indirect reference. - verifyCleaning( - inner2, serializableBefore = false, serializableAfter = false, transitive = false) - - // If we clean transitively, we will find that method `a` does not actually reference the - // outer closure's parent (i.e. the ClosureCleanerSuite), so we can additionally null out - // the outer closure's parent pointer. This will make `inner2` serializable. - verifyCleaning( - inner2, serializableBefore = false, serializableAfter = true, transitive = true) + if (ClosureCleanerSuite2.supportsLMFs) { + verifyCleaning( + inner2, serializableBefore = true, serializableAfter = true) + } else { + // This closure is no longer serializable because it now has a pointer to the outer closure, + // which is itself not serializable because it has a pointer to the ClosureCleanerSuite2. + // If we do not clean transitively, we will not null out this indirect reference. + verifyCleaning( + inner2, serializableBefore = false, serializableAfter = false, transitive = false) + + // If we clean transitively, we will find that method `a` does not actually reference the + // outer closure's parent (i.e. the ClosureCleanerSuite), so we can additionally null out + // the outer closure's parent pointer. This will make `inner2` serializable. + verifyCleaning( + inner2, serializableBefore = false, serializableAfter = true, transitive = true) + } } // Same as above, but with more levels of nesting @@ -565,4 +575,25 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri test6()()() } + test("verify nested non-LMF closures") { + assume(ClosureCleanerSuite2.supportsLMFs) + class A1(val f: Int => Int) + class A2(val f: Int => Int => Int) + class B extends A1(x => x*x) + class C extends A2(x => new B().f ) + val closure1 = new B().f + val closure2 = new C().f + // serializable already + verifyCleaning(closure1, serializableBefore = true, serializableAfter = true) + // brings in deps that can't be cleaned + verifyCleaning(closure2, serializableBefore = false, serializableAfter = false) + } +} + +object ClosureCleanerSuite2 { + // Scala 2.12 allows better interop with Java 8 via lambda syntax. This is supported + // by implementing FunctionN classes in Scala’s standard library as Single Abstract + // Method (SAM) types. Lambdas are implemented via the invokedynamic instruction and + // the use of the LambdaMwtaFactory (LMF) machanism. + val supportsLMFs = scala.util.Properties.versionString.contains("2.12") } diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index 7db452bb6b09a..67765162d634b 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -179,7 +179,7 @@ private[avro] class AvroFileFormat extends FileFormat // Ensure that the reader is closed even if the task fails or doesn't consume the entire // iterator of records. Option(TaskContext.get()).foreach { taskContext => - taskContext.addTaskCompletionListener { _ => + taskContext.addTaskCompletionListener[Unit] { _ => reader.close() } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala index 498e344ea39f4..53bd9a96d1d68 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -166,7 +166,7 @@ private[kafka010] class KafkaSourceRDD( } } // Release consumer, either by removing it or indicating we're no longer using it - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => underlying.closeIfNeeded() } underlying diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index 3efc90fe466b2..4513dca44c7c6 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -237,7 +237,7 @@ private class KafkaRDDIterator[K, V]( cacheLoadFactor: Float ) extends Iterator[ConsumerRecord[K, V]] { - context.addTaskCompletionListener(_ => closeIfNeeded()) + context.addTaskCompletionListener[Unit](_ => closeIfNeeded()) val consumer = { KafkaDataConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 4e84ff044f55e..39dcd911a0814 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -154,7 +154,7 @@ private[libsvm] class LibSVMFileFormat (file: PartitionedFile) => { val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close())) val points = linesReader .map(_.toString.trim) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index c1911235f8df3..72505f7fac0c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -372,7 +372,7 @@ class TungstenAggregationIterator( } } - TaskContext.get().addTaskCompletionListener(_ => { + TaskContext.get().addTaskCompletionListener[Unit](_ => { // At the end of the task, update the task's peak memory usage. Since we destroy // the map to create the sorter, their memory usages should not overlap, so it is safe // to just use the max of the two. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 7487564ed64da..501520c0e085e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -86,7 +86,7 @@ private[sql] object ArrowConverters { val root = VectorSchemaRoot.create(arrowSchema, allocator) val arrowWriter = ArrowWriter.create(root) - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => root.close() allocator.close() } @@ -137,7 +137,7 @@ private[sql] object ArrowConverters { private var schemaRead = StructType(Seq.empty) private var rowIter = if (payloadIter.hasNext) nextBatch() else Iterator.empty - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => closeReader() allocator.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 6012aba1acbca..196d057c2de1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -97,7 +97,7 @@ case class InMemoryTableScanExec( columnarBatch.column(i).asInstanceOf[WritableColumnVector], columnarBatchSchema.fields(i).dataType, rowCount) } - taskContext.foreach(_.addTaskCompletionListener(_ => columnarBatch.close())) + taskContext.foreach(_.addTaskCompletionListener[Unit](_ => columnarBatch.close())) columnarBatch } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala index c0df6c779d7bd..9fddfad249e5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala @@ -50,7 +50,7 @@ object CodecStreams { */ def createInputStreamWithCloseResource(config: Configuration, path: Path): InputStream = { val inputStream = createInputStream(config, path) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => inputStream.close())) inputStream } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 28c36b6020d33..99fc78ff3e49b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -214,7 +214,7 @@ class FileScanRDD( } // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener(_ => iterator.close()) + context.addTaskCompletionListener[Unit](_ => iterator.close()) iterator.asInstanceOf[Iterator[InternalRow]] // This is an erasure hack. } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 82322df407521..b7b46c7c86a29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -214,7 +214,7 @@ object TextInputCSVDataSource extends CSVDataSource { caseSensitive: Boolean): Iterator[InternalRow] = { val lines = { val linesReader = new HadoopFileLinesReader(file, conf) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close())) linesReader.map { line => new String(line.getBytes, 0, line.getLength, parser.options.charset) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 1b3b17c75e756..16b493892e3be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -265,7 +265,7 @@ private[jdbc] class JDBCRDD( closed = true } - context.addTaskCompletionListener{ context => close() } + context.addTaskCompletionListener[Unit]{ context => close() } val inputMetrics = context.taskMetrics().inputMetrics val part = thePart.asInstanceOf[JDBCPartition] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 2fee2128ba1f9..d6c588894d7f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -130,7 +130,7 @@ object TextInputJsonDataSource extends JsonDataSource { parser: JacksonParser, schema: StructType): Iterator[InternalRow] = { val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close())) val textParser = parser.options.encoding .map(enc => CreateJacksonParser.text(enc, _: JsonFactory, _: Text)) .getOrElse(CreateJacksonParser.text(_: JsonFactory, _: Text)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index df1cebed5bd0a..4574f8247af54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -205,7 +205,7 @@ class OrcFileFormat // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM) // after opening a file. val iter = new RecordReaderIterator(batchReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) batchReader.initialize(fileSplit, taskAttemptContext) batchReader.initBatch( @@ -220,7 +220,7 @@ class OrcFileFormat val orcRecordReader = new OrcInputFormat[OrcStruct] .createRecordReader(fileSplit, taskAttemptContext) val iter = new RecordReaderIterator[OrcStruct](orcRecordReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 2d4ac7686d4c4..283d7761d22d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -411,7 +411,7 @@ class ParquetFileFormat convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined, capacity) val iter = new RecordReaderIterator(vectorizedReader) // SPARK-23457 Register a task completion lister before `initialization`. - taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) + taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) vectorizedReader.initialize(split, hadoopAttemptContext) logDebug(s"Appending $partitionSchema ${file.partitionValues}") vectorizedReader.initBatch(partitionSchema, file.partitionValues) @@ -432,7 +432,7 @@ class ParquetFileFormat } val iter = new RecordReaderIterator(reader) // SPARK-23457 Register a task completion lister before `initialization`. - taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) + taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) reader.initialize(split, hadoopAttemptContext) val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index 8661a5395ac44..268297148b522 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -120,7 +120,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { } else { new HadoopFileWholeTextReader(file, confValue) } - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => reader.close())) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => reader.close())) if (requiredSchema.isEmpty) { val emptyUnsafeRow = new UnsafeRow(0) reader.map(_ => emptyUnsafeRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 7ea53424ae100..782829887c446 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -40,7 +40,7 @@ class DataSourceRDD[T: ClassTag]( override def compute(split: Partition, context: TaskContext): Iterator[T] = { val reader = split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition .createPartitionReader() - context.addTaskCompletionListener(_ => reader.close()) + context.addTaskCompletionListener[Unit](_ => reader.close()) val iter = new Iterator[T] { private[this] var valuePrepared = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 0396168d3f311..dab873bf9b9a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -214,7 +214,7 @@ trait HashJoin { } // At the end of the task, we update the avg hash probe. - TaskContext.get().addTaskCompletionListener(_ => + TaskContext.get().addTaskCompletionListener[Unit](_ => avgHashProbe.set(hashed.getAverageProbesPerLookup)) val resultProj = createResultProjection diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 897a4dae39f32..2b59ed6e4d16b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -57,7 +57,7 @@ case class ShuffledHashJoinExec( buildTime += (System.nanoTime() - start) / 1000000 buildDataSize += relation.estimatedSize // This relation is usually used until the end of task. - context.addTaskCompletionListener(_ => relation.close()) + context.addTaskCompletionListener[Unit](_ => relation.close()) relation } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index d00f6f042d6e0..88c9c026928e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -125,7 +125,7 @@ case class AggregateInPandasExec( // combine input with output from Python. val queue = HybridRowQueue(context.taskMemoryManager(), new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingExpressions.length) - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => queue.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index ca665652f204d..85b187159a3e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -131,7 +131,7 @@ class ArrowPythonRunner( private var schema: StructType = _ private var vectors: Array[ColumnVector] = _ - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => if (reader != null) { reader.close(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala index 860dc78c1dd1b..04c7dfdd4e204 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala @@ -97,7 +97,7 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil // combine input with output from Python. val queue = HybridRowQueue(context.taskMemoryManager(), new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) - context.addTaskCompletionListener { ctx => + context.addTaskCompletionListener[Unit] { ctx => queue.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala index a58773122922f..f08f816cbcca9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala @@ -56,7 +56,7 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType) override def open(partitionId: Long, version: Long): Boolean = { outputIterator // initialize everything - TaskContext.get.addTaskCompletionListener { _ => buffer.close() } + TaskContext.get.addTaskCompletionListener[Unit] { _ => buffer.close() } true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index 628029b13a6c3..47bfbde56bb3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -142,7 +142,7 @@ case class WindowInPandasExec( // combine input with output from Python. val queue = HybridRowQueue(context.taskMemoryManager(), new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) - context.addTaskCompletionListener { _ => + context.addTaskCompletionListener[Unit] { _ => queue.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala index 10d8fc553fede..aec756c0eb2a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala @@ -40,7 +40,7 @@ case class ContinuousCoalesceRDDPartition( queueSize, numShuffleWriters, epochIntervalMs, env) val endpoint = env.setupEndpoint(endpointName, receiver) - TaskContext.get().addTaskCompletionListener { ctx => + TaskContext.get().addTaskCompletionListener[Unit] { ctx => env.stop(endpoint) } (receiver, endpoint) @@ -118,9 +118,8 @@ class ContinuousCoalesceRDD( } } - context.addTaskCompletionListener { ctx => + context.addTaskCompletionListener[Unit] { ctx => threadPool.shutdownNow() - () } part.writersInitialized = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala index bfb87053db475..ec1dabd7da3e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -70,7 +70,7 @@ class ContinuousQueuedDataReader( dataReaderThread.setDaemon(true) dataReaderThread.start() - context.addTaskCompletionListener(_ => { + context.addTaskCompletionListener[Unit](_ => { this.close() }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 518223f3cd008..9b13f6398d837 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -40,7 +40,7 @@ case class ContinuousShuffleReadPartition( queueSize, numShuffleWriters, epochIntervalMs, env) val endpoint = env.setupEndpoint(endpointName, receiver) - TaskContext.get().addTaskCompletionListener { ctx => + TaskContext.get().addTaskCompletionListener[Unit] { ctx => env.stop(endpoint) } (receiver, endpoint) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 6b386308c79fb..55d783e023246 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -290,7 +290,7 @@ class SymmetricHashJoinStateManager( private val keyWithIndexToValue = new KeyWithIndexToValueStore() // Clean up any state store resources if necessary at the end of the task - Option(TaskContext.get()).foreach { _.addTaskCompletionListener { _ => abortIfNeeded() } } + Option(TaskContext.get()).foreach { _.addTaskCompletionListener[Unit] { _ => abortIfNeeded() } } /** Helper trait for invoking common functionalities of a state store. */ private abstract class StateStoreHandler(stateStoreType: StateStoreType) extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 0b32327e51dbf..b6021438e902b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -61,7 +61,7 @@ package object state { val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) val wrappedF = (store: StateStore, iter: Iterator[T]) => { // Abort the state store in case of error - TaskContext.get().addTaskCompletionListener(_ => { + TaskContext.get().addTaskCompletionListener[Unit](_ => { if (!store.hasCommitted) store.abort() }) cleanedF(store, iter) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 20090696ec3fc..de8085f07db19 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -164,7 +164,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable } val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => recordsIterator.close())) + Option(TaskContext.get()) + .foreach(_.addTaskCompletionListener[Unit](_ => recordsIterator.close())) // Unwraps `OrcStruct`s to `UnsafeRow`s OrcFileFormat.unwrapOrcStructs( From 46110a589f4e91cd7605c5a2c34c3db6b2635830 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Thu, 2 Aug 2018 22:20:41 +0800 Subject: [PATCH 1277/2461] [SPARK-24865][FOLLOW-UP] Remove AnalysisBarrier LogicalPlan Node ## What changes were proposed in this pull request? Remove the AnalysisBarrier LogicalPlan node, which is useless now. ## How was this patch tested? N/A Author: Xiao Li Closes #21962 from gatorsmile/refactor2. --- .../plans/logical/basicLogicalOperators.scala | 20 ------------------- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../spark/sql/execution/CacheManager.scala | 2 +- 3 files changed, 2 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 13b51304d7f89..68413d7fd10f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -924,23 +924,3 @@ case class Deduplicate( override def output: Seq[Attribute] = child.output } - -/** - * A logical plan for setting a barrier of analysis. - * - * The SQL Analyzer goes through a whole query plan even most part of it is analyzed. This - * increases the time spent on query analysis for long pipelines in ML, especially. - * - * This logical plan wraps an analyzed logical plan to prevent it from analysis again. The barrier - * is applied to the analyzed logical plan in Dataset. It won't change the output of wrapped - * logical plan and just acts as a wrapper to hide it from analyzer. New operations on the dataset - * will be put on the barrier, so only the new nodes created will be analyzed. - * - * This analysis barrier will be removed at the end of analysis stage. - */ -case class AnalysisBarrier(child: LogicalPlan) extends LeafNode { - override protected def innerChildren: Seq[LogicalPlan] = Seq(child) - override def output: Seq[Attribute] = child.output - override def isStreaming: Boolean = child.isStreaming - override def doCanonicalize(): LogicalPlan = child.canonicalized -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 3c9e743106260..cd7dc2a2727e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -26,7 +26,7 @@ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, InsertIntoTable, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index ed130dc57ee5b..c9929935fb8ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.expressions.SubqueryExpression -import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, LogicalPlan, ResolvedHint} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ResolvedHint} import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.storage.StorageLevel From 7be6fc3c77b00f0eefd276676524ec4e36bab868 Mon Sep 17 00:00:00 2001 From: Kaya Kupferschmidt Date: Thu, 2 Aug 2018 09:22:21 -0500 Subject: [PATCH 1278/2461] [SPARK-24742] Fix NullPointerexception in Field Metadata ## What changes were proposed in this pull request? This pull request provides a fix for SPARK-24742: SQL Field MetaData was throwing an Exception in the hashCode method when a "null" Metadata was added via "putNull" ## How was this patch tested? A new unittest is provided in org/apache/spark/sql/types/MetadataSuite.scala Author: Kaya Kupferschmidt Closes #21722 from kupferk/SPARK-24742. --- .../org/apache/spark/sql/types/Metadata.scala | 2 + .../spark/sql/types/MetadataSuite.scala | 74 +++++++++++++++++++ 2 files changed, 76 insertions(+) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/types/MetadataSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 352fb545f4b6b..7c15dc0de4b6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -215,6 +215,8 @@ object Metadata { x.## case x: Metadata => hash(x.map) + case null => + 0 case other => throw new RuntimeException(s"Do not support type ${other.getClass}.") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/MetadataSuite.scala new file mode 100644 index 0000000000000..210e65708170f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/MetadataSuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.apache.spark.SparkFunSuite + +class MetadataSuite extends SparkFunSuite { + test("String Metadata") { + val meta = new MetadataBuilder().putString("key", "value").build() + assert(meta === meta) + assert(meta.## !== 0) + assert(meta.getString("key") === "value") + assert(meta.contains("key")) + intercept[NoSuchElementException](meta.getString("no_such_key")) + intercept[ClassCastException](meta.getBoolean("key")) + } + + test("Long Metadata") { + val meta = new MetadataBuilder().putLong("key", 12).build() + assert(meta === meta) + assert(meta.## !== 0) + assert(meta.getLong("key") === 12) + assert(meta.contains("key")) + intercept[NoSuchElementException](meta.getLong("no_such_key")) + intercept[ClassCastException](meta.getBoolean("key")) + } + + test("Double Metadata") { + val meta = new MetadataBuilder().putDouble("key", 12).build() + assert(meta === meta) + assert(meta.## !== 0) + assert(meta.getDouble("key") === 12) + assert(meta.contains("key")) + intercept[NoSuchElementException](meta.getDouble("no_such_key")) + intercept[ClassCastException](meta.getBoolean("key")) + } + + test("Boolean Metadata") { + val meta = new MetadataBuilder().putBoolean("key", true).build() + assert(meta === meta) + assert(meta.## !== 0) + assert(meta.getBoolean("key") === true) + assert(meta.contains("key")) + intercept[NoSuchElementException](meta.getBoolean("no_such_key")) + intercept[ClassCastException](meta.getString("key")) + } + + test("Null Metadata") { + val meta = new MetadataBuilder().putNull("key").build() + assert(meta === meta) + assert(meta.## !== 0) + assert(meta.getString("key") === null) + assert(meta.getDouble("key") === 0) + assert(meta.getLong("key") === 0) + assert(meta.getBoolean("key") === false) + assert(meta.contains("key")) + intercept[NoSuchElementException](meta.getLong("no_such_key")) + } +} From d182b3d34d6afade401b8a455b774059bae9d90f Mon Sep 17 00:00:00 2001 From: Kaya Kupferschmidt Date: Thu, 2 Aug 2018 22:23:24 +0800 Subject: [PATCH 1279/2461] [SPARK-24742] Fix NullPointerexception in Field Metadata ## What changes were proposed in this pull request? This pull request provides a fix for SPARK-24742: SQL Field MetaData was throwing an Exception in the hashCode method when a "null" Metadata was added via "putNull" ## How was this patch tested? A new unittest is provided in org/apache/spark/sql/types/MetadataSuite.scala Author: Kaya Kupferschmidt Closes #21722 from kupferk/SPARK-24742. From f04cd670943d0eb6eb688a0f50d56293cda554ef Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 2 Aug 2018 09:26:27 -0500 Subject: [PATCH 1280/2461] [MINOR] remove dead code in ExpressionEvalHelper ## What changes were proposed in this pull request? This addresses https://github.com/apache/spark/pull/21236/files#r207078480 both https://github.com/apache/spark/pull/21236 and https://github.com/apache/spark/pull/21838 add a InternalRow result check to ExpressionEvalHelper and becomes duplicated. ## How was this patch tested? N/A Author: Wenchen Fan Closes #21958 from cloud-fan/minor. --- .../spark/sql/catalyst/expressions/ExpressionEvalHelper.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index d045267ef5d9e..6684e5ce18d4c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -105,9 +105,6 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa if (expected.isNaN) result.isNaN else expected == result case (result: Float, expected: Float) => if (expected.isNaN) result.isNaN else expected == result - case (result: UnsafeRow, expected: GenericInternalRow) => - val structType = exprDataType.asInstanceOf[StructType] - result.toSeq(structType) == expected.toSeq(structType) case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema) case _ => result == expected From 15fc2372269159ea2556b028d4eb8860c4108650 Mon Sep 17 00:00:00 2001 From: LucaCanali Date: Wed, 18 Jul 2018 23:19:02 +0200 Subject: [PATCH 1281/2461] Updates to Accumulators --- .../apache/spark/api/python/PythonRDD.scala | 12 +++-- python/pyspark/accumulators.py | 53 ++++++++++++++----- python/pyspark/context.py | 5 +- 3 files changed, 51 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index a1ee2f7d1b119..8bc0ff7936daf 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -586,8 +586,9 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By */ private[spark] class PythonAccumulatorV2( @transient private val serverHost: String, - private val serverPort: Int) - extends CollectionAccumulator[Array[Byte]] { + private val serverPort: Int, + private val secretToken: String) + extends CollectionAccumulator[Array[Byte]] with Logging{ Utils.checkHost(serverHost) @@ -602,12 +603,17 @@ private[spark] class PythonAccumulatorV2( private def openSocket(): Socket = synchronized { if (socket == null || socket.isClosed) { socket = new Socket(serverHost, serverPort) + logInfo(s"Connected to AccumulatorServer at host: $serverHost port: $serverPort") + // send the secret just for the initial authentication when opening a new connection + socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8)) } socket } // Need to override so the types match with PythonFunction - override def copyAndReset(): PythonAccumulatorV2 = new PythonAccumulatorV2(serverHost, serverPort) + override def copyAndReset(): PythonAccumulatorV2 = { + new PythonAccumulatorV2(serverHost, serverPort, secretToken) + } override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized { val otherPythonAccumulator = other.asInstanceOf[PythonAccumulatorV2] diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index f730d290273fe..1276c31b33737 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -227,20 +227,46 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler): def handle(self): from pyspark.accumulators import _accumulatorRegistry - while not self.server.server_shutdown: - # Poll every 1 second for new data -- don't block in case of shutdown. - r, _, _ = select.select([self.rfile], [], [], 1) - if self.rfile in r: - num_updates = read_int(self.rfile) - for _ in range(num_updates): - (aid, update) = pickleSer._read_with_length(self.rfile) - _accumulatorRegistry[aid] += update - # Write a byte in acknowledgement - self.wfile.write(struct.pack("!b", 1)) - + auth_token = self.server.auth_token + def poll(func): + while not self.server.server_shutdown: + # Poll every 1 second for new data -- don't block in case of shutdown. + r, _, _ = select.select([self.rfile], [], [], 1) + if self.rfile in r: + if func(): + break + + def accum_updates(): + num_updates = read_int(self.rfile) + for _ in range(num_updates): + (aid, update) = pickleSer._read_with_length(self.rfile) + _accumulatorRegistry[aid] += update + # Write a byte in acknowledgement + self.wfile.write(struct.pack("!b", 1)) + return False + + def authenticate_and_accum_updates(): + received_token = self.rfile.read(len(auth_token)) + if isinstance(received_token, bytes): + received_token = received_token.decode("utf-8") + if (received_token == auth_token): + accum_updates() + # we've authenticated, we can break out of the first loop now + return True + else: + raise Exception("The value of the provided token to the AccumulatorServer is not correct.") + + # first we keep polling till we've received the authentication token + poll(authenticate_and_accum_updates) + # now we've authenticated, don't need to check for the token anymore + poll(accum_updates) class AccumulatorServer(SocketServer.TCPServer): + def __init__(self, server_address, RequestHandlerClass, auth_token): + SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass) + self.auth_token = auth_token + """ A simple TCP server that intercepts shutdown() in order to interrupt our continuous polling on the handler. @@ -252,10 +278,9 @@ def shutdown(self): SocketServer.TCPServer.shutdown(self) self.server_close() - -def _start_update_server(): +def _start_update_server(auth_token): """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" - server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler) + server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler, auth_token) thread = threading.Thread(target=server.serve_forever) thread.daemon = True thread.start() diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 2cb3117184334..0ff4f5be0a228 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -183,9 +183,10 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, # Create a single Accumulator in Java that we'll send all our updates through; # they will be passed back to us through a TCP server - self._accumulatorServer = accumulators._start_update_server() + auth_token = self._gateway.gateway_parameters.auth_token + self._accumulatorServer = accumulators._start_update_server(auth_token) (host, port) = self._accumulatorServer.server_address - self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port) + self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token) self._jsc.sc().register(self._javaAccumulator) self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') From ad2e63662885b67b1e94030b13fdae4f7366dc4a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 2 Aug 2018 09:28:13 -0700 Subject: [PATCH 1282/2461] [SPARK-24598][DOCS] State in the documentation the behavior when arithmetic operations cause overflow ## What changes were proposed in this pull request? According to the discussion in https://github.com/apache/spark/pull/21599, changing the behavior of arithmetic operations so that they can check for overflow is not nice in a minor release. What we can do for 2.4 is warn users about the current behavior in the documentation, so that they are aware of the issue and can take proper actions. ## How was this patch tested? NA Author: Marco Gaido Closes #21967 from mgaido91/SPARK-24598_doc. --- docs/sql-programming-guide.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 5f1eee85b5154..0900f8317d635 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -3072,3 +3072,10 @@ Specifically: - In aggregations, all NaN values are grouped together. - NaN is treated as a normal value in join keys. - NaN values go last when in ascending order, larger than any other numeric value. + + ## Arithmetic operations + +Operations performed on numeric types (with the exception of `decimal`) are not checked for overflow. +This means that in case an operation causes an overflow, the result is the same that the same operation +returns in a Java/Scala program (eg. if the sum of 2 integers is higher than the maximum value representable, +the result is a negative number). From 38e4699c978e56a0f24b8efb94fd3206cdd8b3fe Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 2 Aug 2018 09:36:26 -0700 Subject: [PATCH 1283/2461] [SPARK-24820][SPARK-24821][CORE] Fail fast when submitted job contains a barrier stage with unsupported RDD chain pattern ## What changes were proposed in this pull request? Check on job submit to make sure we don't launch a barrier stage with unsupported RDD chain pattern. The following patterns are not supported: - Ancestor RDDs that have different number of partitions from the resulting RDD (eg. union()/coalesce()/first()/PartitionPruningRDD); - An RDD that depends on multiple barrier RDDs (eg. barrierRdd1.zip(barrierRdd2)). ## How was this patch tested? Add test cases in `BarrierStageOnSubmittedSuite`. Author: Xingbo Jiang Closes #21927 from jiangxb1987/SPARK-24820. --- .../apache/spark/scheduler/DAGScheduler.scala | 55 ++++++- .../spark/BarrierStageOnSubmittedSuite.scala | 153 ++++++++++++++++++ 2 files changed, 207 insertions(+), 1 deletion(-) create mode 100644 core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 4858af71c1a9c..3dd0718ac673d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -39,7 +39,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} -import org.apache.spark.rdd.{RDD, RDDCheckpointData} +import org.apache.spark.rdd.{PartitionPruningRDD, RDD, RDDCheckpointData} import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -340,6 +340,22 @@ class DAGScheduler( } } + /** + * Check to make sure we don't launch a barrier stage with unsupported RDD chain pattern. The + * following patterns are not supported: + * 1. Ancestor RDDs that have different number of partitions from the resulting RDD (eg. + * union()/coalesce()/first()/take()/PartitionPruningRDD); + * 2. An RDD that depends on multiple barrier RDDs (eg. barrierRdd1.zip(barrierRdd2)). + */ + private def checkBarrierStageWithRDDChainPattern(rdd: RDD[_], numTasksInStage: Int): Unit = { + val predicate: RDD[_] => Boolean = (r => + r.getNumPartitions == numTasksInStage && r.dependencies.filter(_.rdd.isBarrier()).size <= 1) + if (rdd.isBarrier() && !traverseParentRDDsWithinStage(rdd, predicate)) { + throw new SparkException( + DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + } + } + /** * Creates a ShuffleMapStage that generates the given shuffle dependency's partitions. If a * previously run stage generated the same shuffle data, this function will copy the output @@ -348,6 +364,7 @@ class DAGScheduler( */ def createShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): ShuffleMapStage = { val rdd = shuffleDep.rdd + checkBarrierStageWithRDDChainPattern(rdd, rdd.getNumPartitions) val numTasks = rdd.partitions.length val parents = getOrCreateParentStages(rdd, jobId) val id = nextStageId.getAndIncrement() @@ -376,6 +393,7 @@ class DAGScheduler( partitions: Array[Int], jobId: Int, callSite: CallSite): ResultStage = { + checkBarrierStageWithRDDChainPattern(rdd, partitions.toSet.size) val parents = getOrCreateParentStages(rdd, jobId) val id = nextStageId.getAndIncrement() val stage = new ResultStage(id, rdd, func, partitions, parents, jobId, callSite) @@ -451,6 +469,32 @@ class DAGScheduler( parents } + /** + * Traverses the given RDD and its ancestors within the same stage and checks whether all of the + * RDDs satisfy a given predicate. + */ + private def traverseParentRDDsWithinStage(rdd: RDD[_], predicate: RDD[_] => Boolean): Boolean = { + val visited = new HashSet[RDD[_]] + val waitingForVisit = new ArrayStack[RDD[_]] + waitingForVisit.push(rdd) + while (waitingForVisit.nonEmpty) { + val toVisit = waitingForVisit.pop() + if (!visited(toVisit)) { + if (!predicate(toVisit)) { + return false + } + visited += toVisit + toVisit.dependencies.foreach { + case _: ShuffleDependency[_, _, _] => + // Not within the same stage with current rdd, do nothing. + case dependency => + waitingForVisit.push(dependency.rdd) + } + } + } + true + } + private def getMissingParentStages(stage: Stage): List[Stage] = { val missing = new HashSet[Stage] val visited = new HashSet[RDD[_]] @@ -1948,4 +1992,13 @@ private[spark] object DAGScheduler { // Number of consecutive stage attempts allowed before a stage is aborted val DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS = 4 + + // Error message when running a barrier stage that have unsupported RDD chain pattern. + val ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN = + "[SPARK-24820][SPARK-24821]: Barrier execution mode does not allow the following pattern of " + + "RDD chain within a barrier stage:\n1. Ancestor RDDs that have different number of " + + "partitions from the resulting RDD (eg. union()/coalesce()/first()/take()/" + + "PartitionPruningRDD). A workaround for first()/take() can be barrierRdd.collect().head " + + "(scala) or barrierRdd.collect()[0] (python).\n" + + "2. An RDD that depends on multiple barrier RDDs (eg. barrierRdd1.zip(barrierRdd2))." } diff --git a/core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala b/core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala new file mode 100644 index 0000000000000..f2b3884e25ffa --- /dev/null +++ b/core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.rdd.{PartitionPruningRDD, RDD} +import org.apache.spark.scheduler.DAGScheduler +import org.apache.spark.util.ThreadUtils + +/** + * This test suite covers all the cases that shall fail fast on job submitted that contains one + * of more barrier stages. + */ +class BarrierStageOnSubmittedSuite extends SparkFunSuite with BeforeAndAfterEach + with LocalSparkContext { + + override def beforeEach(): Unit = { + super.beforeEach() + + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName("test") + sc = new SparkContext(conf) + } + + private def testSubmitJob( + sc: SparkContext, + rdd: RDD[Int], + partitions: Option[Seq[Int]] = None, + message: String): Unit = { + val futureAction = sc.submitJob( + rdd, + (iter: Iterator[Int]) => iter.toArray, + partitions.getOrElse(0 until rdd.partitions.length), + { case (_, _) => return }: (Int, Array[Int]) => Unit, + { return } + ) + + val error = intercept[SparkException] { + ThreadUtils.awaitResult(futureAction, 5 seconds) + }.getCause.getMessage + assert(error.contains(message)) + } + + test("submit a barrier ResultStage that contains PartitionPruningRDD") { + val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1) + val rdd = prunedRdd + .barrier() + .mapPartitions((iter, context) => iter) + testSubmitJob(sc, rdd, + message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + } + + test("submit a barrier ShuffleMapStage that contains PartitionPruningRDD") { + val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1) + val rdd = prunedRdd + .barrier() + .mapPartitions((iter, context) => iter) + .repartition(2) + .map(x => x + 1) + testSubmitJob(sc, rdd, + message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + } + + test("submit a barrier stage that doesn't contain PartitionPruningRDD") { + val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1) + val rdd = prunedRdd + .repartition(2) + .barrier() + .mapPartitions((iter, context) => iter) + // Should be able to submit job and run successfully. + val result = rdd.collect().sorted + assert(result === Seq(6, 7, 8, 9, 10)) + } + + test("submit a barrier stage with partial partitions") { + val rdd = sc.parallelize(1 to 10, 4) + .barrier() + .mapPartitions((iter, context) => iter) + testSubmitJob(sc, rdd, Some(Seq(1, 3)), + message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + } + + test("submit a barrier stage with union()") { + val rdd1 = sc.parallelize(1 to 10, 2) + .barrier() + .mapPartitions((iter, context) => iter) + val rdd2 = sc.parallelize(1 to 20, 2) + val rdd3 = rdd1 + .union(rdd2) + .map(x => x * 2) + // Fail the job on submit because the barrier RDD (rdd1) may be not assigned Task 0. + testSubmitJob(sc, rdd3, + message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + } + + test("submit a barrier stage with coalesce()") { + val rdd = sc.parallelize(1 to 10, 4) + .barrier() + .mapPartitions((iter, context) => iter) + .coalesce(1) + // Fail the job on submit because the barrier RDD requires to run on 4 tasks, but the stage + // only launches 1 task. + testSubmitJob(sc, rdd, + message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + } + + test("submit a barrier stage that contains an RDD that depends on multiple barrier RDDs") { + val rdd1 = sc.parallelize(1 to 10, 4) + .barrier() + .mapPartitions((iter, context) => iter) + val rdd2 = sc.parallelize(11 to 20, 4) + .barrier() + .mapPartitions((iter, context) => iter) + val rdd3 = rdd1 + .zip(rdd2) + .map(x => x._1 + x._2) + testSubmitJob(sc, rdd3, + message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + } + + test("submit a barrier stage with zip()") { + val rdd1 = sc.parallelize(1 to 10, 4) + .barrier() + .mapPartitions((iter, context) => iter) + val rdd2 = sc.parallelize(11 to 20, 4) + val rdd3 = rdd1 + .zip(rdd2) + .map(x => x._1 + x._2) + // Should be able to submit job and run successfully. + val result = rdd3.collect().sorted + assert(result === Seq(12, 14, 16, 18, 20, 22, 24, 26, 28, 30)) + } +} From 0df6bf882907d7d76572f513168a144067d0e0ec Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 3 Aug 2018 03:18:46 +0900 Subject: [PATCH 1284/2461] [BUILD] Fix lint-python. ## What changes were proposed in this pull request? This pr fixes lint-python. ``` ./python/pyspark/accumulators.py:231:9: E306 expected 1 blank line before a nested definition, found 0 ./python/pyspark/accumulators.py:257:101: E501 line too long (107 > 100 characters) ./python/pyspark/accumulators.py:264:1: E302 expected 2 blank lines, found 1 ./python/pyspark/accumulators.py:281:1: E302 expected 2 blank lines, found 1 ``` ## How was this patch tested? Executed lint-python manually. Author: Takuya UESHIN Closes #21973 from ueshin/issues/build/1/fix_lint-python. --- python/pyspark/accumulators.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 1276c31b33737..30ad04297c682 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -228,6 +228,7 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler): def handle(self): from pyspark.accumulators import _accumulatorRegistry auth_token = self.server.auth_token + def poll(func): while not self.server.server_shutdown: # Poll every 1 second for new data -- don't block in case of shutdown. @@ -254,13 +255,15 @@ def authenticate_and_accum_updates(): # we've authenticated, we can break out of the first loop now return True else: - raise Exception("The value of the provided token to the AccumulatorServer is not correct.") + raise Exception( + "The value of the provided token to the AccumulatorServer is not correct.") # first we keep polling till we've received the authentication token poll(authenticate_and_accum_updates) # now we've authenticated, don't need to check for the token anymore poll(accum_updates) + class AccumulatorServer(SocketServer.TCPServer): def __init__(self, server_address, RequestHandlerClass, auth_token): @@ -278,6 +281,7 @@ def shutdown(self): SocketServer.TCPServer.shutdown(self) self.server_close() + def _start_update_server(auth_token): """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler, auth_token) From 02f967795b7e8ccf2738d567928e47c38c1134e1 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 2 Aug 2018 13:00:33 -0700 Subject: [PATCH 1285/2461] [SPARK-23908][SQL] Add transform function. ## What changes were proposed in this pull request? This pr adds `transform` function which transforms elements in an array using the function. Optionally we can take the index of each element as the second argument. ```sql > SELECT transform(array(1, 2, 3), x -> x + 1); array(2, 3, 4) > SELECT transform(array(1, 2, 3), (x, i) -> x + i); array(1, 3, 5) ``` ## How was this patch tested? Added tests. Author: Takuya UESHIN Closes #21954 from ueshin/issues/SPARK-23908/transform. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 + .../sql/catalyst/analysis/Analyzer.scala | 3 + .../catalyst/analysis/FunctionRegistry.scala | 1 + .../analysis/higherOrderFunctions.scala | 166 ++++++++++++++ .../expressions/higherOrderFunctions.scala | 212 ++++++++++++++++++ .../sql/catalyst/parser/AstBuilder.scala | 10 + .../ResolveLambdaVariablesSuite.scala | 89 ++++++++ .../HigherOrderFunctionsSuite.scala | 97 ++++++++ .../parser/ExpressionParserSuite.scala | 5 + .../spark/sql/catalyst/plans/PlanTest.scala | 2 + .../inputs/higher-order-functions.sql | 26 +++ .../results/higher-order-functions.sql.out | 81 +++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 153 +++++++++++++ 13 files changed, 847 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala create mode 100644 sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 2aca10f1bfbc7..9ad6f30c40a88 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -591,6 +591,8 @@ primaryExpression (OVER windowSpec)? #functionCall | qualifiedName '(' trimOption=(BOTH | LEADING | TRAILING) argument+=expression FROM argument+=expression ')' #functionCall + | IDENTIFIER '->' expression #lambda + | '(' IDENTIFIER (',' IDENTIFIER)+ ')' '->' expression #lambda | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference | base=primaryExpression '.' fieldName=identifier #dereference diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 76dc86710909e..7f235ac560299 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -180,6 +180,8 @@ class Analyzer( ResolveAggregateFunctions :: TimeWindowing :: ResolveInlineTables(conf) :: + ResolveHigherOrderFunctions(catalog) :: + ResolveLambdaVariables(conf) :: ResolveTimeZone(conf) :: ResolveRandomSeed :: TypeCoercion.typeCoercionRules(conf) ++ @@ -878,6 +880,7 @@ class Analyzer( } private def resolve(e: Expression, q: LogicalPlan): Expression = e match { + case f: LambdaFunction if !f.bound => f case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index b8b311219ca8d..f7517486e5411 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -440,6 +440,7 @@ object FunctionRegistry { expression[ArrayRepeat]("array_repeat"), expression[ArrayRemove]("array_remove"), expression[ArrayDistinct]("array_distinct"), + expression[ArrayTransform]("transform"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala new file mode 100644 index 0000000000000..063ca0fc3252d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.catalog.SessionCatalog +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DataType + +/** + * Resolve a higher order functions from the catalog. This is different from regular function + * resolution because lambda functions can only be resolved after the function has been resolved; + * so we need to resolve higher order function when all children are either resolved or a lambda + * function. + */ +case class ResolveHigherOrderFunctions(catalog: SessionCatalog) extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + case q: LogicalPlan => + q.transformExpressions { + case u @ UnresolvedFunction(fn, children, false) + if hasLambdaAndResolvedArguments(children) => + withPosition(u) { + catalog.lookupFunction(fn, children) match { + case func: HigherOrderFunction => func + case other => other.failAnalysis( + "A lambda function should only be used in a higher order function. However, " + + s"its class is ${other.getClass.getCanonicalName}, which is not a " + + s"higher order function.") + } + } + } + } + + /** + * Check if the arguments of a function are either resolved or a lambda function. + */ + private def hasLambdaAndResolvedArguments(expressions: Seq[Expression]): Boolean = { + val (lambdas, others) = expressions.partition(_.isInstanceOf[LambdaFunction]) + lambdas.nonEmpty && others.forall(_.resolved) + } +} + +/** + * Resolve the lambda variables exposed by a higher order functions. + * + * This rule works in two steps: + * [1]. Bind the anonymous variables exposed by the higher order function to the lambda function's + * arguments; this creates named and typed lambda variables. The argument names are checked + * for duplicates and the number of arguments are checked during this step. + * [2]. Resolve the used lambda variables used in the lambda function's function expression tree. + * Note that we allow the use of variables from outside the current lambda, this can either + * be a lambda function defined in an outer scope, or a attribute in produced by the plan's + * child. If names are duplicate, the name defined in the most inner scope is used. + */ +case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] { + + type LambdaVariableMap = Map[String, NamedExpression] + + private val canonicalizer = { + if (!conf.caseSensitiveAnalysis) { + s: String => s.toLowerCase + } else { + s: String => s + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + plan.resolveOperators { + case q: LogicalPlan => + q.mapExpressions(resolve(_, Map.empty)) + } + } + + /** + * Create a bound lambda function by binding the arguments of a lambda function to the given + * partial arguments (dataType and nullability only). If the expression happens to be an already + * bound lambda function then we assume it has been bound to the correct arguments and do + * nothing. This function will produce a lambda function with hidden arguments when it is passed + * an arbitrary expression. + */ + private def createLambda( + e: Expression, + partialArguments: Seq[(DataType, Boolean)]): LambdaFunction = e match { + case f: LambdaFunction if f.bound => f + + case LambdaFunction(function, names, _) => + if (names.size != partialArguments.size) { + e.failAnalysis( + s"The number of lambda function arguments '${names.size}' does not " + + "match the number of arguments expected by the higher order function " + + s"'${partialArguments.size}'.") + } + + if (names.map(a => canonicalizer(a.name)).distinct.size < names.size) { + e.failAnalysis( + "Lambda function arguments should not have names that are semantically the same.") + } + + val arguments = partialArguments.zip(names).map { + case ((dataType, nullable), ne) => + NamedLambdaVariable(ne.name, dataType, nullable) + } + LambdaFunction(function, arguments) + + case _ => + // This expression does not consume any of the lambda's arguments (it is independent). We do + // create a lambda function with default parameters because this is expected by the higher + // order function. Note that we hide the lambda variables produced by this function in order + // to prevent accidental naming collisions. + val arguments = partialArguments.zipWithIndex.map { + case ((dataType, nullable), i) => + NamedLambdaVariable(s"col$i", dataType, nullable) + } + LambdaFunction(e, arguments, hidden = true) + } + + /** + * Resolve lambda variables in the expression subtree, using the passed lambda variable registry. + */ + private def resolve(e: Expression, parentLambdaMap: LambdaVariableMap): Expression = e match { + case _ if e.resolved => e + + case h: HigherOrderFunction if h.inputResolved => + h.bind(createLambda).mapChildren(resolve(_, parentLambdaMap)) + + case l: LambdaFunction if !l.bound => + // Do not resolve an unbound lambda function. If we see such a lambda function this means + // that either the higher order function has yet to be resolved, or that we are seeing + // dangling lambda function. + l + + case l: LambdaFunction if !l.hidden => + val lambdaMap = l.arguments.map(v => canonicalizer(v.name) -> v).toMap + l.mapChildren(resolve(_, parentLambdaMap ++ lambdaMap)) + + case u @ UnresolvedAttribute(name +: nestedFields) => + parentLambdaMap.get(canonicalizer(name)) match { + case Some(lambda) => + nestedFields.foldLeft(lambda: Expression) { (expr, fieldName) => + ExtractValue(expr, Literal(fieldName), conf.resolver) + } + case None => u + } + + case _ => + e.mapChildren(resolve(_, parentLambdaMap)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala new file mode 100644 index 0000000000000..c5c3482afa134 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.util.concurrent.atomic.AtomicReference + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} +import org.apache.spark.sql.types._ + +/** + * A named lambda variable. + */ +case class NamedLambdaVariable( + name: String, + dataType: DataType, + nullable: Boolean, + value: AtomicReference[Any] = new AtomicReference(), + exprId: ExprId = NamedExpression.newExprId) + extends LeafExpression + with NamedExpression + with CodegenFallback { + + override def qualifier: Option[String] = None + + override def newInstance(): NamedExpression = + copy(value = new AtomicReference(), exprId = NamedExpression.newExprId) + + override def toAttribute: Attribute = { + AttributeReference(name, dataType, nullable, Metadata.empty)(exprId, None) + } + + override def eval(input: InternalRow): Any = value.get + + override def toString: String = s"lambda $name#${exprId.id}$typeSuffix" + + override def simpleString: String = s"lambda $name#${exprId.id}: ${dataType.simpleString}" +} + +/** + * A lambda function and its arguments. A lambda function can be hidden when a user wants to + * process an completely independent expression in a [[HigherOrderFunction]], the lambda function + * and its variables are then only used for internal bookkeeping within the higher order function. + */ +case class LambdaFunction( + function: Expression, + arguments: Seq[NamedExpression], + hidden: Boolean = false) + extends Expression with CodegenFallback { + + override def children: Seq[Expression] = function +: arguments + override def dataType: DataType = function.dataType + override def nullable: Boolean = function.nullable + + lazy val bound: Boolean = arguments.forall(_.resolved) + + override def eval(input: InternalRow): Any = function.eval(input) +} + +/** + * A higher order function takes one or more (lambda) functions and applies these to some objects. + * The function produces a number of variables which can be consumed by some lambda function. + */ +trait HigherOrderFunction extends Expression { + + override def children: Seq[Expression] = inputs ++ functions + + /** + * Inputs to the higher ordered function. + */ + def inputs: Seq[Expression] + + /** + * All inputs have been resolved. This means that the types and nullabilty of (most of) the + * lambda function arguments is known, and that we can start binding the lambda functions. + */ + lazy val inputResolved: Boolean = inputs.forall(_.resolved) + + /** + * Functions applied by the higher order function. + */ + def functions: Seq[Expression] + + /** + * All inputs must be resolved and all functions must be resolved lambda functions. + */ + override lazy val resolved: Boolean = inputResolved && functions.forall { + case l: LambdaFunction => l.resolved + case _ => false + } + + /** + * Bind the lambda functions to the [[HigherOrderFunction]] using the given bind function. The + * bind function takes the potential lambda and it's (partial) arguments and converts this into + * a bound lambda function. + */ + def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): HigherOrderFunction + + @transient lazy val functionsForEval: Seq[Expression] = functions.map { + case LambdaFunction(function, arguments, hidden) => + val argumentMap = arguments.map { arg => arg.exprId -> arg }.toMap + function.transformUp { + case variable: NamedLambdaVariable if argumentMap.contains(variable.exprId) => + argumentMap(variable.exprId) + } + } +} + +trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes { + + def input: Expression + + override def inputs: Seq[Expression] = input :: Nil + + def function: Expression + + override def functions: Seq[Expression] = function :: Nil + + def expectingFunctionType: AbstractDataType = AnyDataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, expectingFunctionType) + + @transient lazy val functionForEval: Expression = functionsForEval.head +} + +/** + * Transform elements in an array using the transform function. This is similar to + * a `map` in functional programming. + */ +@ExpressionDescription( + usage = "_FUNC_(expr, func) - Transforms elements in an array using the function.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), x -> x + 1); + array(2, 3, 4) + > SELECT _FUNC_(array(1, 2, 3), (x, i) -> x + i); + array(1, 3, 5) + """, + since = "2.4.0") +case class ArrayTransform( + input: Expression, + function: Expression) + extends ArrayBasedHigherOrderFunction with CodegenFallback { + + override def nullable: Boolean = input.nullable + + override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = { + val (elementType, containsNull) = input.dataType match { + case ArrayType(elementType, containsNull) => (elementType, containsNull) + case _ => + val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType + (elementType, containsNull) + } + function match { + case LambdaFunction(_, arguments, _) if arguments.size == 2 => + copy(function = f(function, (elementType, containsNull) :: (IntegerType, false) :: Nil)) + case _ => + copy(function = f(function, (elementType, containsNull) :: Nil)) + } + } + + @transient lazy val (elementVar, indexVar) = { + val LambdaFunction(_, (elementVar: NamedLambdaVariable) +: tail, _) = function + val indexVar = if (tail.nonEmpty) { + Some(tail.head.asInstanceOf[NamedLambdaVariable]) + } else { + None + } + (elementVar, indexVar) + } + + override def eval(input: InternalRow): Any = { + val arr = this.input.eval(input).asInstanceOf[ArrayData] + if (arr == null) { + null + } else { + val f = functionForEval + val result = new GenericArrayData(new Array[Any](arr.numElements)) + var i = 0 + while (i < arr.numElements) { + elementVar.value.set(arr.get(i, elementVar.dataType)) + if (indexVar.isDefined) { + indexVar.get.value.set(i) + } + result.update(i, f.eval(input)) + i += 1 + } + result + } + } + + override def prettyName: String = "transform" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 8a8db6df37094..0ceeb53e1d7a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1312,6 +1312,16 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } } + /** + * Create an [[LambdaFunction]]. + */ + override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) { + val arguments = ctx.IDENTIFIER().asScala.map { name => + UnresolvedAttribute.quoted(name.getText) + } + LambdaFunction(expression(ctx.expression), arguments) + } + /** * Create a reference to a window frame, i.e. [[WindowSpecReference]]. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala new file mode 100644 index 0000000000000..c4171c75ecd03 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.{ArrayType, IntegerType} + +/** + * Test suite for [[ResolveLambdaVariables]]. + */ +class ResolveLambdaVariablesSuite extends PlanTest { + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + object Analyzer extends RuleExecutor[LogicalPlan] { + val batches = Batch("Resolution", FixedPoint(4), ResolveLambdaVariables(conf)) :: Nil + } + + private val key = 'key.int + private val values1 = 'values1.array(IntegerType) + private val values2 = 'values2.array(ArrayType(ArrayType(IntegerType))) + private val data = LocalRelation(Seq(key, values1, values2)) + private val lvInt = NamedLambdaVariable("x", IntegerType, nullable = true) + private val lvHiddenInt = NamedLambdaVariable("col0", IntegerType, nullable = true) + private val lvArray = NamedLambdaVariable("x", ArrayType(IntegerType), nullable = true) + + private def plan(e: Expression): LogicalPlan = data.select(e.as("res")) + + private def checkExpression(e1: Expression, e2: Expression): Unit = { + comparePlans(Analyzer.execute(plan(e1)), plan(e2)) + } + + test("resolution - no op") { + checkExpression(key, key) + } + + test("resolution - simple") { + val in = ArrayTransform(values1, LambdaFunction('x.attr + 1, 'x.attr :: Nil)) + val out = ArrayTransform(values1, LambdaFunction(lvInt + 1, lvInt :: Nil)) + checkExpression(in, out) + } + + test("resolution - nested") { + val in = ArrayTransform(values2, LambdaFunction( + ArrayTransform('x.attr, LambdaFunction('x.attr + 1, 'x.attr :: Nil)), 'x.attr :: Nil)) + val out = ArrayTransform(values2, LambdaFunction( + ArrayTransform(lvArray, LambdaFunction(lvInt + 1, lvInt :: Nil)), lvArray :: Nil)) + checkExpression(in, out) + } + + test("resolution - hidden") { + val in = ArrayTransform(values1, key) + val out = ArrayTransform(values1, LambdaFunction(key, lvHiddenInt :: Nil, hidden = true)) + checkExpression(in, out) + } + + test("fail - name collisions") { + val p = plan(ArrayTransform(values1, + LambdaFunction('x.attr + 'X.attr, 'x.attr :: 'X.attr :: Nil))) + val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage + assert(msg.contains("arguments should not have names that are semantically the same")) + } + + test("fail - lambda arguments") { + val p = plan(ArrayTransform(values1, + LambdaFunction('x.attr + 'y.attr + 'z.attr, 'x.attr :: 'y.attr :: 'z.attr :: Nil))) + val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage + assert(msg.contains("does not match the number of arguments expected")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala new file mode 100644 index 0000000000000..e987ea5b8a4d1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + import org.apache.spark.sql.catalyst.dsl.expressions._ + + private def createLambda( + dt: DataType, + nullable: Boolean, + f: Expression => Expression): Expression = { + val lv = NamedLambdaVariable("arg", dt, nullable) + val function = f(lv) + LambdaFunction(function, Seq(lv)) + } + + private def createLambda( + dt1: DataType, + nullable1: Boolean, + dt2: DataType, + nullable2: Boolean, + f: (Expression, Expression) => Expression): Expression = { + val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) + val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) + val function = f(lv1, lv2) + LambdaFunction(function, Seq(lv1, lv2)) + } + + def transform(expr: Expression, f: Expression => Expression): Expression = { + val at = expr.dataType.asInstanceOf[ArrayType] + ArrayTransform(expr, createLambda(at.elementType, at.containsNull, f)) + } + + def transform(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val at = expr.dataType.asInstanceOf[ArrayType] + ArrayTransform(expr, createLambda(at.elementType, at.containsNull, IntegerType, false, f)) + } + + test("ArrayTransform") { + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) + val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) + + val plusOne: Expression => Expression = x => x + 1 + val plusIndex: (Expression, Expression) => Expression = (x, i) => x + i + + checkEvaluation(transform(ai0, plusOne), Seq(2, 3, 4)) + checkEvaluation(transform(ai0, plusIndex), Seq(1, 3, 5)) + checkEvaluation(transform(transform(ai0, plusIndex), plusOne), Seq(2, 4, 6)) + checkEvaluation(transform(ai1, plusOne), Seq(2, null, 4)) + checkEvaluation(transform(ai1, plusIndex), Seq(1, null, 5)) + checkEvaluation(transform(transform(ai1, plusIndex), plusOne), Seq(2, null, 6)) + checkEvaluation(transform(ain, plusOne), null) + + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) + val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) + + val repeatTwice: Expression => Expression = x => Concat(Seq(x, x)) + val repeatIndexTimes: (Expression, Expression) => Expression = (x, i) => StringRepeat(x, i) + + checkEvaluation(transform(as0, repeatTwice), Seq("aa", "bb", "cc")) + checkEvaluation(transform(as0, repeatIndexTimes), Seq("", "b", "cc")) + checkEvaluation(transform(transform(as0, repeatIndexTimes), repeatTwice), + Seq("", "bb", "cccc")) + checkEvaluation(transform(as1, repeatTwice), Seq("aa", null, "cc")) + checkEvaluation(transform(as1, repeatIndexTimes), Seq("", null, "cc")) + checkEvaluation(transform(transform(as1, repeatIndexTimes), repeatTwice), + Seq("", null, "cccc")) + checkEvaluation(transform(asn, repeatTwice), null) + + val aai = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)), + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + checkEvaluation(transform(aai, array => Cast(transform(array, plusOne), StringType)), + Seq("[2, 3, 4]", null, "[5, 6]")) + checkEvaluation(transform(aai, array => Cast(transform(array, plusIndex), StringType)), + Seq("[1, 3, 5]", null, "[4, 6]")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index b4d422d8506fc..c37b9f148cf48 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -234,6 +234,11 @@ class ExpressionParserSuite extends PlanTest { intercept("foo(a x)", "extraneous input 'x'") } + test("lambda functions") { + assertEqual("x -> x + 1", LambdaFunction('x + 1, Seq('x.attr))) + assertEqual("(x, y) -> x + y", LambdaFunction('x + 'y, Seq('x.attr, 'y.attr))) + } + test("window function expressions") { val func = 'foo.function(star()) def windowed( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 6241d5cbb1d25..139785719fec7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -60,6 +60,8 @@ trait PlanTestBase extends PredicateHelper { self: Suite => Alias(a.child, a.name)(exprId = ExprId(0)) case ae: AggregateExpression => ae.copy(resultId = ExprId(0)) + case lv: NamedLambdaVariable => + lv.copy(value = null, exprId = ExprId(0)) } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql new file mode 100644 index 0000000000000..8e928a41f08e0 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -0,0 +1,26 @@ +create or replace temporary view nested as values + (1, array(32, 97), array(array(12, 99), array(123, 42), array(1))), + (2, array(77, -76), array(array(6, 96, 65), array(-1, -2))), + (3, array(12), array(array(17))) + as t(x, ys, zs); + +-- Only allow lambda's in higher order functions. +select upper(x -> x) as v; + +-- Identity transform an array +select transform(zs, z -> z) as v from nested; + +-- Transform an array +select transform(ys, y -> y * y) as v from nested; + +-- Transform an array with index +select transform(ys, (y, i) -> y + i) as v from nested; + +-- Transform an array with reference +select transform(zs, z -> concat(ys, z)) as v from nested; + +-- Transform an array to an array of 0's +select transform(ys, 0) as v from nested; + +-- Transform a null array +select transform(cast(null as array), x -> x + 1) as v; diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out new file mode 100644 index 0000000000000..ca2c3c35333cc --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -0,0 +1,81 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 8 + + +-- !query 0 +create or replace temporary view nested as values + (1, array(32, 97), array(array(12, 99), array(123, 42), array(1))), + (2, array(77, -76), array(array(6, 96, 65), array(-1, -2))), + (3, array(12), array(array(17))) + as t(x, ys, zs) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select upper(x -> x) as v +-- !query 1 schema +struct<> +-- !query 1 output +org.apache.spark.sql.AnalysisException +A lambda function should only be used in a higher order function. However, its class is org.apache.spark.sql.catalyst.expressions.Upper, which is not a higher order function.; line 1 pos 7 + + +-- !query 2 +select transform(zs, z -> z) as v from nested +-- !query 2 schema +struct>> +-- !query 2 output +[[12,99],[123,42],[1]] +[[17]] +[[6,96,65],[-1,-2]] + + +-- !query 3 +select transform(ys, y -> y * y) as v from nested +-- !query 3 schema +struct> +-- !query 3 output +[1024,9409] +[144] +[5929,5776] + + +-- !query 4 +select transform(ys, (y, i) -> y + i) as v from nested +-- !query 4 schema +struct> +-- !query 4 output +[12] +[32,98] +[77,-75] + + +-- !query 5 +select transform(zs, z -> concat(ys, z)) as v from nested +-- !query 5 schema +struct>> +-- !query 5 output +[[12,17]] +[[32,97,12,99],[32,97,123,42],[32,97,1]] +[[77,-76,6,96,65],[77,-76,-1,-2]] + + +-- !query 6 +select transform(ys, 0) as v from nested +-- !query 6 schema +struct> +-- !query 6 output +[0,0] +[0,0] +[0] + + +-- !query 7 +select transform(cast(null as array), x -> x + 1) as v +-- !query 7 schema +struct> +-- !query 7 output +NULL diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index e550b142c738d..923482024b033 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1647,6 +1647,159 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(result10.first.schema(0).dataType === expectedType10) } + test("transform function - array for primitive type not containing null") { + val df = Seq( + Seq(1, 9, 8, 7), + Seq(5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { + checkAnswer(df.selectExpr("transform(i, x -> x + 1)"), + Seq( + Row(Seq(2, 10, 9, 8)), + Row(Seq(6, 9, 10, 8, 3)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(i, (x, i) -> x + i)"), + Seq( + Row(Seq(1, 10, 10, 10)), + Row(Seq(5, 9, 11, 10, 6)), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeNotContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeNotContainsNull() + } + + test("transform function - array for primitive type containing null") { + val df = Seq[Seq[Integer]]( + Seq(1, 9, 8, null, 7), + Seq(5, null, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeContainsNull(): Unit = { + checkAnswer(df.selectExpr("transform(i, x -> x + 1)"), + Seq( + Row(Seq(2, 10, 9, null, 8)), + Row(Seq(6, null, 9, 10, 8, 3)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(i, (x, i) -> x + i)"), + Seq( + Row(Seq(1, 10, 10, null, 11)), + Row(Seq(5, null, 10, 12, 11, 7)), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeContainsNull() + } + + test("transform function - array for non-primitive type") { + val df = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("s") + + def testNonPrimitiveType(): Unit = { + checkAnswer(df.selectExpr("transform(s, x -> concat(x, x))"), + Seq( + Row(Seq("cc", "aa", "bb")), + Row(Seq("bb", null, "cc", null)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(s, (x, i) -> concat(x, i))"), + Seq( + Row(Seq("c0", "a1", "b2")), + Row(Seq("b0", null, "c2", null)), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testNonPrimitiveType() + } + + test("transform function - special cases") { + val df = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("arg") + + def testSpecialCases(): Unit = { + checkAnswer(df.selectExpr("transform(arg, arg -> arg)"), + Seq( + Row(Seq("c", "a", "b")), + Row(Seq("b", null, "c", null)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(arg, arg)"), + Seq( + Row(Seq(Seq("c", "a", "b"), Seq("c", "a", "b"), Seq("c", "a", "b"))), + Row(Seq( + Seq("b", null, "c", null), + Seq("b", null, "c", null), + Seq("b", null, "c", null), + Seq("b", null, "c", null))), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(arg, x -> concat(arg, array(x)))"), + Seq( + Row(Seq(Seq("c", "a", "b", "c"), Seq("c", "a", "b", "a"), Seq("c", "a", "b", "b"))), + Row(Seq( + Seq("b", null, "c", null, "b"), + Seq("b", null, "c", null, null), + Seq("b", null, "c", null, "c"), + Seq("b", null, "c", null, null))), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testSpecialCases() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testSpecialCases() + } + + test("transform function - invalid") { + val df = Seq( + (Seq("c", "a", "b"), 1), + (Seq("b", null, "c", null), 2), + (Seq.empty, 3), + (null, 4) + ).toDF("s", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("transform(s, (x, y, z) -> x + y + z)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '3' does not match")) + + val ex2 = intercept[AnalysisException] { + df.selectExpr("transform(i, x -> x)") + } + assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From efef55388fedef3f7954a385776e666ad4597a58 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 2 Aug 2018 13:05:36 -0700 Subject: [PATCH 1286/2461] [SPARK-24705][SQL] ExchangeCoordinator broken when duplicate exchanges reused ## What changes were proposed in this pull request? In the current master, `EnsureRequirements` sets the number of exchanges in `ExchangeCoordinator` before `ReuseExchange`. Then, `ReuseExchange` removes some duplicate exchange and the actual number of registered exchanges changes. Finally, the assertion in `ExchangeCoordinator` fails because the logical number of exchanges and the actual number of registered exchanges become different; https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala#L201 This pr fixed the issue and the code to reproduce this is as follows; ``` scala> sql("SET spark.sql.adaptive.enabled=true") scala> sql("SET spark.sql.autoBroadcastJoinThreshold=-1") scala> val df = spark.range(1).selectExpr("id AS key", "id AS value") scala> val resultDf = df.join(df, "key").join(df, "key") scala> resultDf.show ... at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$$anonfun$doExecute$1.apply(ShuffleExchangeExec.scala:119) at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52) ... 101 more Caused by: java.lang.AssertionError: assertion failed at scala.Predef$.assert(Predef.scala:156) at org.apache.spark.sql.execution.exchange.ExchangeCoordinator.doEstimationIfNecessary(ExchangeCoordinator.scala:201) at org.apache.spark.sql.execution.exchange.ExchangeCoordinator.postShuffleRDD(ExchangeCoordinator.scala:259) at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$$anonfun$doExecute$1.apply(ShuffleExchangeExec.scala:124) at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$$anonfun$doExecute$1.apply(ShuffleExchangeExec.scala:119) at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52) ... ``` ## How was this patch tested? Added tests in `ExchangeCoordinatorSuite`. Author: Takeshi Yamamuro Closes #21754 from maropu/SPARK-24705-2. --- .../exchange/EnsureRequirements.scala | 1 - .../exchange/ExchangeCoordinator.scala | 17 +++++++++------ .../execution/ExchangeCoordinatorSuite.scala | 21 +++++++++++++++---- 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index d96ecbaa48029..d2d5011bbcb97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -82,7 +82,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { if (adaptiveExecutionEnabled && supportsCoordinator) { val coordinator = new ExchangeCoordinator( - children.length, targetPostShuffleInputSize, minNumPostShufflePartitions) children.zip(requiredChildDistributions).map { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala index 051e610eb2705..f5d93ee5fa914 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala @@ -83,7 +83,6 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan} * - post-shuffle partition 3: pre-shuffle partition 3 and 4 (size 50 MB) */ class ExchangeCoordinator( - numExchanges: Int, advisoryTargetPostShuffleInputSize: Long, minNumPostShufflePartitions: Option[Int] = None) extends Logging { @@ -91,8 +90,14 @@ class ExchangeCoordinator( // The registered Exchange operators. private[this] val exchanges = ArrayBuffer[ShuffleExchangeExec]() + // `lazy val` is used here so that we could notice the wrong use of this class, e.g., all the + // exchanges should be registered before `postShuffleRDD` called first time. If a new exchange is + // registered after the `postShuffleRDD` call, `assert(exchanges.length == numExchanges)` fails + // in `doEstimationIfNecessary`. + private[this] lazy val numExchanges = exchanges.size + // This map is used to lookup the post-shuffle ShuffledRowRDD for an Exchange operator. - private[this] val postShuffleRDDs: JMap[ShuffleExchangeExec, ShuffledRowRDD] = + private[this] lazy val postShuffleRDDs: JMap[ShuffleExchangeExec, ShuffledRowRDD] = new JHashMap[ShuffleExchangeExec, ShuffledRowRDD](numExchanges) // A boolean that indicates if this coordinator has made decision on how to shuffle data. @@ -117,10 +122,6 @@ class ExchangeCoordinator( */ def estimatePartitionStartIndices( mapOutputStatistics: Array[MapOutputStatistics]): Array[Int] = { - // If we have mapOutputStatistics.length < numExchange, it is because we do not submit - // a stage when the number of partitions of this dependency is 0. - assert(mapOutputStatistics.length <= numExchanges) - // If minNumPostShufflePartitions is defined, it is possible that we need to use a // value less than advisoryTargetPostShuffleInputSize as the target input size of // a post shuffle task. @@ -228,6 +229,10 @@ class ExchangeCoordinator( j += 1 } + // If we have mapOutputStatistics.length < numExchange, it is because we do not submit + // a stage when the number of partitions of this dependency is 0. + assert(mapOutputStatistics.length <= numExchanges) + // Now, we estimate partitionStartIndices. partitionStartIndices.length will be the // number of post-shuffle partitions. val partitionStartIndices = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 737eeb0af586e..b736d43bfc6ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -21,7 +21,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.{MapOutputStatistics, SparkConf, SparkFunSuite} import org.apache.spark.sql._ -import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -58,7 +58,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } test("test estimatePartitionStartIndices - 1 Exchange") { - val coordinator = new ExchangeCoordinator(1, 100L) + val coordinator = new ExchangeCoordinator(100L) { // All bytes per partition are 0. @@ -105,7 +105,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } test("test estimatePartitionStartIndices - 2 Exchanges") { - val coordinator = new ExchangeCoordinator(2, 100L) + val coordinator = new ExchangeCoordinator(100L) { // If there are multiple values of the number of pre-shuffle partitions, @@ -199,7 +199,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } test("test estimatePartitionStartIndices and enforce minimal number of reducers") { - val coordinator = new ExchangeCoordinator(2, 100L, Some(2)) + val coordinator = new ExchangeCoordinator(100L, Some(2)) { // The minimal number of post-shuffle partitions is not enforced because @@ -480,4 +480,17 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { withSparkSession(test, 6144, minNumPostShufflePartitions) } } + + test("SPARK-24705 adaptive query execution works correctly when exchange reuse enabled") { + val test = { spark: SparkSession => + spark.sql("SET spark.sql.exchange.reuse=true") + val df = spark.range(1).selectExpr("id AS key", "id AS value") + val resultDf = df.join(df, "key").join(df, "key") + val sparkPlan = resultDf.queryExecution.executedPlan + assert(sparkPlan.collect { case p: ReusedExchangeExec => p }.length == 1) + assert(sparkPlan.collect { case p @ ShuffleExchangeExec(_, _, Some(c)) => p }.length == 3) + checkAnswer(resultDf, Row(0, 0, 0, 0) :: Nil) + } + withSparkSession(test, 4, None) + } } From d0bc3ed6797e0c06f688b7b2ef6c26282a25b175 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 2 Aug 2018 15:35:46 -0700 Subject: [PATCH 1287/2461] [SPARK-24896][SQL] Uuid should produce different values for each execution in streaming query ## What changes were proposed in this pull request? `Uuid`'s results depend on random seed given during analysis. Thus under streaming query, we will have the same uuids in each execution. This seems to be incorrect for streaming query execution. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21854 from viirya/uuid_in_streaming. --- .../streaming/IncrementalExecution.scala | 8 ++++++- .../sql/streaming/StreamingQuerySuite.scala | 22 ++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 6ae7f2869b0f3..e9ffe129ca310 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -20,9 +20,11 @@ package org.apache.spark.sql.execution.streaming import java.util.UUID import java.util.concurrent.atomic.AtomicInteger +import scala.util.Random + import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession, Strategy} -import org.apache.spark.sql.catalyst.expressions.CurrentBatchTimestamp +import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, Uuid} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, HashPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.rules.Rule @@ -73,10 +75,14 @@ class IncrementalExecution( * with the desired literal */ override lazy val optimizedPlan: LogicalPlan = { + val random = new Random() + sparkSession.sessionState.optimizer.execute(withCachedData) transformAllExpressions { case ts @ CurrentBatchTimestamp(timestamp, _, _) => logInfo(s"Current batch timestamp = $timestamp") ts.toLiteral + // SPARK-24896: Set the seed for random number generation in Uuid expressions. + case _: Uuid => Uuid(Some(random.nextLong())) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 78199b0a1c19a..f37f3682b03b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -21,6 +21,8 @@ import java.{util => ju} import java.util.Optional import java.util.concurrent.CountDownLatch +import scala.collection.mutable + import org.apache.commons.lang3.RandomStringUtils import org.scalactic.TolerantNumerics import org.scalatest.BeforeAndAfter @@ -29,8 +31,9 @@ import org.scalatest.mockito.MockitoSugar import org.apache.spark.SparkException import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Uuid import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ @@ -834,6 +837,23 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi CheckLastBatch(("A", 1))) } + test("Uuid in streaming query should not produce same uuids in each execution") { + val uuids = mutable.ArrayBuffer[String]() + def collectUuid: Seq[Row] => Unit = { rows: Seq[Row] => + rows.foreach(r => uuids += r.getString(0)) + } + + val stream = MemoryStream[Int] + val df = stream.toDF().select(new Column(Uuid())) + testStream(df)( + AddData(stream, 1), + CheckAnswer(collectUuid), + AddData(stream, 2), + CheckAnswer(collectUuid) + ) + assert(uuids.distinct.size == 2) + } + test("StreamingRelationV2/StreamingExecutionRelation/ContinuousExecutionRelation.toJSON " + "should not fail") { val df = spark.readStream.format("rate").load() From bbdcc3bf61da39704650d4570c6307b5a46f7100 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 2 Aug 2018 18:19:04 -0500 Subject: [PATCH 1288/2461] [SPARK-22219][SQL] Refactor code to get a value for "spark.sql.codegen.comments" ## What changes were proposed in this pull request? This PR refactors code to get a value for "spark.sql.codegen.comments" by avoiding `SparkEnv.get.conf`. This PR uses `SQLConf.get.codegenComments` since `SQLConf.get` always returns an instance of `SQLConf`. ## How was this patch tested? Added test case to `DebuggingSuite` Author: Kazuaki Ishizaki Closes #19449 from kiszk/SPARK-22219. --- .../expressions/codegen/CodeGenerator.scala | 7 +------ .../org/apache/spark/sql/internal/SQLConf.scala | 2 ++ .../spark/sql/internal/StaticSQLConf.scala | 8 ++++++++ .../sql/internal/ExecutorSideSQLConfSuite.scala | 16 ++++++++++++++++ 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 05500f5923e94..498dd2639f423 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1173,12 +1173,7 @@ class CodegenContext { text: => String, placeholderId: String = "", force: Boolean = false): Block = { - // By default, disable comments in generated code because computing the comments themselves can - // be extremely expensive in certain cases, such as deeply-nested expressions which operate over - // inputs with wide schemas. For more details on the performance issues that motivated this - // flat, see SPARK-15680. - if (force || - SparkEnv.get != null && SparkEnv.get.conf.getBoolean("spark.sql.codegen.comments", false)) { + if (force || SQLConf.get.codegenComments) { val name = if (placeholderId != "") { assert(!placeHolderToComments.contains(placeholderId)) placeholderId diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index edc1a488150c2..8f303f7316b7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1599,6 +1599,8 @@ class SQLConf extends Serializable with Logging { def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK) + def codegenComments: Boolean = getConf(StaticSQLConf.CODEGEN_COMMENTS) + def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES) def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index 384b1917a1f79..d9c354b165e52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -74,6 +74,14 @@ object StaticSQLConf { .checkValue(maxEntries => maxEntries >= 0, "The maximum must not be negative") .createWithDefault(100) + val CODEGEN_COMMENTS = buildStaticConf("spark.sql.codegen.comments") + .internal() + .doc("When true, put comment in the generated code. Since computing huge comments " + + "can be extremely expensive in certain cases, such as deeply-nested expressions which " + + "operate over inputs with wide schemas, default is false.") + .booleanConf + .createWithDefault(false) + // When enabling the debug, Spark SQL internal table properties are not filtered out; however, // some related DDL commands (e.g., ANALYZE TABLE and CREATE TABLE LIKE) might not work properly. val DEBUG_MODE = buildStaticConf("spark.sql.debug") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala index 855fe4f4523f2..5b4736ef4f7f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.execution.debug.codegenStringSeq +import org.apache.spark.sql.functions.col import org.apache.spark.sql.test.SQLTestUtils class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { @@ -82,4 +84,18 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { assert(checks.forall(_ == true)) } } + + test("SPARK-22219: refactor to control to generate comment") { + Seq(true, false).foreach { flag => + withSQLConf(StaticSQLConf.CODEGEN_COMMENTS.key -> flag.toString) { + val res = codegenStringSeq(spark.range(10).groupBy(col("id") * 2).count() + .queryExecution.executedPlan) + assert(res.length == 2) + assert(res.forall { case (_, code) => + (code.contains("* Codegend pipeline") == flag) && + (code.contains("// input[") == flag) + }) + } + } + } } From 29077a1d15e49dfafe7f2eab963830ba9cc6b29a Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 2 Aug 2018 17:19:42 -0700 Subject: [PATCH 1289/2461] [SPARK-24795][CORE][FOLLOWUP] Combine BarrierTaskContext with BarrierTaskContextImpl ## What changes were proposed in this pull request? According to https://github.com/apache/spark/pull/21758#discussion_r206746905 , current declaration of `BarrierTaskContext` didn't extend methods from `TaskContext`. Since `TaskContext` is an abstract class and we don't want to change it to a trait, we have to define class `BarrierTaskContext` directly. ## How was this patch tested? Existing tests. Author: Xingbo Jiang Closes #21972 from jiangxb1987/BarrierTaskContext. --- .../org/apache/spark/BarrierTaskContext.scala | 60 ++++++++++++++++++- .../apache/spark/BarrierTaskContextImpl.scala | 49 --------------- .../org/apache/spark/rdd/RDDBarrier.scala | 2 +- .../org/apache/spark/scheduler/Task.scala | 2 +- 4 files changed, 59 insertions(+), 54 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index 4c358629dee96..ba303680d1a0f 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -17,20 +17,71 @@ package org.apache.spark +import java.util.Properties + import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.metrics.MetricsSystem /** A [[TaskContext]] with extra info and tooling for a barrier stage. */ -trait BarrierTaskContext extends TaskContext { +class BarrierTaskContext( + override val stageId: Int, + override val stageAttemptNumber: Int, + override val partitionId: Int, + override val taskAttemptId: Long, + override val attemptNumber: Int, + override val taskMemoryManager: TaskMemoryManager, + localProperties: Properties, + @transient private val metricsSystem: MetricsSystem, + // The default value is only used in tests. + override val taskMetrics: TaskMetrics = TaskMetrics.empty) + extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber, + taskMemoryManager, localProperties, metricsSystem, taskMetrics) { /** * :: Experimental :: * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to * MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same * stage have reached this routine. + * + * CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all + * possible code branches. Otherwise, you may get the job hanging or a SparkException after + * timeout. Some examples of misuses listed below: + * 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it + * shall lead to timeout of the function call. + * {{{ + * rdd.barrier().mapPartitions { (iter, context) => + * if (context.partitionId() == 0) { + * // Do nothing. + * } else { + * context.barrier() + * } + * iter + * } + * }}} + * + * 2. Include barrier() function in a try-catch code block, this may lead to timeout of the + * second function call. + * {{{ + * rdd.barrier().mapPartitions { (iter, context) => + * try { + * // Do something that might throw an Exception. + * doSomething() + * context.barrier() + * } catch { + * case e: Exception => logWarning("...", e) + * } + * context.barrier() + * iter + * } + * }}} */ @Experimental @Since("2.4.0") - def barrier(): Unit + def barrier(): Unit = { + // TODO SPARK-24817 implement global barrier. + } /** * :: Experimental :: @@ -38,5 +89,8 @@ trait BarrierTaskContext extends TaskContext { */ @Experimental @Since("2.4.0") - def getTaskInfos(): Array[BarrierTaskInfo] + def getTaskInfos(): Array[BarrierTaskInfo] = { + val addressesStr = localProperties.getProperty("addresses", "") + addressesStr.split(",").map(_.trim()).map(new BarrierTaskInfo(_)) + } } diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala deleted file mode 100644 index 8ac705757a382..0000000000000 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import java.util.Properties - -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.memory.TaskMemoryManager -import org.apache.spark.metrics.MetricsSystem - -/** A [[BarrierTaskContext]] implementation. */ -private[spark] class BarrierTaskContextImpl( - override val stageId: Int, - override val stageAttemptNumber: Int, - override val partitionId: Int, - override val taskAttemptId: Long, - override val attemptNumber: Int, - override val taskMemoryManager: TaskMemoryManager, - localProperties: Properties, - @transient private val metricsSystem: MetricsSystem, - // The default value is only used in tests. - override val taskMetrics: TaskMetrics = TaskMetrics.empty) - extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber, - taskMemoryManager, localProperties, metricsSystem, taskMetrics) - with BarrierTaskContext { - - // TODO SPARK-24817 implement global barrier. - override def barrier(): Unit = {} - - override def getTaskInfos(): Array[BarrierTaskInfo] = { - val addressesStr = localProperties.getProperty("addresses", "") - addressesStr.split(",").map(_.trim()).map(new BarrierTaskInfo(_)) - } -} diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala index 85565d16e2717..71f38bf6967bc 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala @@ -28,7 +28,7 @@ class RDDBarrier[T: ClassTag](rdd: RDD[T]) { /** * :: Experimental :: - * Maps partitions together with a provided BarrierTaskContext. + * Maps partitions together with a provided [[org.apache.spark.BarrierTaskContext]]. * * `preservesPartitioning` indicates whether the input function preserves the partitioner, which * should be `false` unless `rdd` is a pair RDD and the input function doesn't modify the keys. diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 89ff2038e5f8a..11f85fd91ba08 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -83,7 +83,7 @@ private[spark] abstract class Task[T]( // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether // the stage is barrier. context = if (isBarrier) { - new BarrierTaskContextImpl( + new BarrierTaskContext( stageId, stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal partitionId, From 7cf16a7fa4eb4145c0c5d1dd2555f78a2fdd8d8b Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 3 Aug 2018 08:32:08 +0800 Subject: [PATCH 1290/2461] [SPARK-24773] Avro: support logical timestamp type with different precisions ## What changes were proposed in this pull request? Support reading/writing Avro logical timestamp type with different precisions https://avro.apache.org/docs/1.8.2/spec.html#Timestamp+%28millisecond+precision%29 To specify the output timestamp type, use Dataframe option `outputTimestampType` or SQL config `spark.sql.avro.outputTimestampType`. The supported values are * `TIMESTAMP_MICROS` * `TIMESTAMP_MILLIS` The default output type is `TIMESTAMP_MICROS` ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21935 from gengliangwang/avro_timestamp. --- .../spark/sql/avro/AvroDeserializer.scala | 15 ++- .../spark/sql/avro/AvroFileFormat.scala | 4 +- .../apache/spark/sql/avro/AvroOptions.scala | 11 ++ .../spark/sql/avro/AvroSerializer.scala | 12 +- .../spark/sql/avro/SchemaConverters.scala | 33 ++++-- .../avro/src/test/resources/timestamp.avro | Bin 0 -> 375 bytes .../org/apache/spark/sql/avro/AvroSuite.scala | 107 ++++++++++++++++-- .../apache/spark/sql/internal/SQLConf.scala | 18 +++ 8 files changed, 178 insertions(+), 22 deletions(-) create mode 100644 external/avro/src/test/resources/timestamp.avro diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index b31149a2c74c2..394a62bf82795 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -23,6 +23,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.avro.{Schema, SchemaBuilder} +import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis} import org.apache.avro.Schema.Type._ import org.apache.avro.generic._ import org.apache.avro.util.Utf8 @@ -86,8 +87,18 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) { case (LONG, LongType) => (updater, ordinal, value) => updater.setLong(ordinal, value.asInstanceOf[Long]) - case (LONG, TimestampType) => (updater, ordinal, value) => - updater.setLong(ordinal, value.asInstanceOf[Long] * 1000) + case (LONG, TimestampType) => avroType.getLogicalType match { + case _: TimestampMillis => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long] * 1000) + case _: TimestampMicros => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long]) + case null => (updater, ordinal, value) => + // For backward compatibility, if the Avro type is Long and it is not logical type, + // the value is processed as timestamp type with millisecond precision. + updater.setLong(ordinal, value.asInstanceOf[Long] * 1000) + case other => throw new IncompatibleSchemaException( + s"Cannot convert Avro logical type ${other} to Catalyst Timestamp type.") + } case (LONG, DateType) => (updater, ordinal, value) => updater.setInt(ordinal, (value.asInstanceOf[Long] / DateTimeUtils.MILLIS_PER_DAY).toInt) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index 67765162d634b..6ffcf375af678 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -113,8 +113,8 @@ private[avro] class AvroFileFormat extends FileFormat options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { val parsedOptions = new AvroOptions(options, spark.sessionState.newHadoopConf()) - val outputAvroSchema = SchemaConverters.toAvroType( - dataSchema, nullable = false, parsedOptions.recordName, parsedOptions.recordNamespace) + val outputAvroSchema = SchemaConverters.toAvroType(dataSchema, nullable = false, + parsedOptions.recordName, parsedOptions.recordNamespace, parsedOptions.outputTimestampType) AvroJob.setOutputKeySchema(job, outputAvroSchema) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala index 67f56343b4524..8c62d5db7ae24 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala @@ -22,6 +22,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.AvroOutputTimestampType /** * Options for Avro Reader and Writer stored in case insensitive manner. @@ -79,4 +80,14 @@ class AvroOptions( val compression: String = { parameters.get("compression").getOrElse(SQLConf.get.avroCompressionCodec) } + + /** + * Avro timestamp type used when Spark writes data to Avro files. + * Currently supported types are `TIMESTAMP_MICROS` and `TIMESTAMP_MILLIS`. + * TIMESTAMP_MICROS is a logical timestamp type in Avro, which stores number of microseconds + * from the Unix epoch. TIMESTAMP_MILLIS is also logical, but with millisecond precision, + * which means Spark has to truncate the microsecond portion of its timestamp value. + * The related configuration is set via SQLConf, and it is not exposed as an option. + */ + val outputTimestampType: AvroOutputTimestampType.Value = SQLConf.get.avroOutputTimestampType } diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala index 2b4c5813a535b..a744d54e43e53 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -21,6 +21,7 @@ import java.nio.ByteBuffer import scala.collection.JavaConverters._ +import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis} import org.apache.avro.Schema import org.apache.avro.Schema.Type.NULL import org.apache.avro.generic.GenericData.Record @@ -92,8 +93,15 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal)) case DateType => (getter, ordinal) => getter.getInt(ordinal) * DateTimeUtils.MILLIS_PER_DAY - case TimestampType => - (getter, ordinal) => getter.getLong(ordinal) / 1000 + case TimestampType => avroType.getLogicalType match { + case _: TimestampMillis => (getter, ordinal) => getter.getLong(ordinal) / 1000 + case _: TimestampMicros => (getter, ordinal) => getter.getLong(ordinal) + // For backward compatibility, if the Avro type is Long and it is not logical type, + // output the timestamp value as with millisecond precision. + case null => (getter, ordinal) => getter.getLong(ordinal) / 1000 + case other => throw new IncompatibleSchemaException( + s"Cannot convert Catalyst Timestamp type to Avro logical type ${other}") + } case ArrayType(et, containsNull) => val elementConverter = newConverter( diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala index 87fae63aeff2b..1e912073f12b4 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.avro import scala.collection.JavaConverters._ -import org.apache.avro.{Schema, SchemaBuilder} +import org.apache.avro.{LogicalType, LogicalTypes, Schema, SchemaBuilder} +import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis} import org.apache.avro.Schema.Type._ +import org.apache.spark.sql.internal.SQLConf.AvroOutputTimestampType import org.apache.spark.sql.types._ /** @@ -42,7 +44,10 @@ object SchemaConverters { case BYTES => SchemaType(BinaryType, nullable = false) case DOUBLE => SchemaType(DoubleType, nullable = false) case FLOAT => SchemaType(FloatType, nullable = false) - case LONG => SchemaType(LongType, nullable = false) + case LONG => avroSchema.getLogicalType match { + case _: TimestampMillis | _: TimestampMicros => SchemaType(TimestampType, nullable = false) + case _ => SchemaType(LongType, nullable = false) + } case FIXED => SchemaType(BinaryType, nullable = false) case ENUM => SchemaType(StringType, nullable = false) @@ -103,31 +108,45 @@ object SchemaConverters { catalystType: DataType, nullable: Boolean = false, recordName: String = "topLevelRecord", - prevNameSpace: String = ""): Schema = { + prevNameSpace: String = "", + outputTimestampType: AvroOutputTimestampType.Value = AvroOutputTimestampType.TIMESTAMP_MICROS) + : Schema = { val builder = if (nullable) { SchemaBuilder.builder().nullable() } else { SchemaBuilder.builder() } + catalystType match { case BooleanType => builder.booleanType() case ByteType | ShortType | IntegerType => builder.intType() case LongType => builder.longType() case DateType => builder.longType() - case TimestampType => builder.longType() + case TimestampType => + val timestampType = outputTimestampType match { + case AvroOutputTimestampType.TIMESTAMP_MILLIS => LogicalTypes.timestampMillis() + case AvroOutputTimestampType.TIMESTAMP_MICROS => LogicalTypes.timestampMicros() + case other => + throw new IncompatibleSchemaException(s"Unexpected output timestamp type $other.") + } + builder.longBuilder().prop(LogicalType.LOGICAL_TYPE_PROP, timestampType.getName).endLong() + case FloatType => builder.floatType() case DoubleType => builder.doubleType() case _: DecimalType | StringType => builder.stringType() case BinaryType => builder.bytesType() case ArrayType(et, containsNull) => - builder.array().items(toAvroType(et, containsNull, recordName, prevNameSpace)) + builder.array() + .items(toAvroType(et, containsNull, recordName, prevNameSpace, outputTimestampType)) case MapType(StringType, vt, valueContainsNull) => - builder.map().values(toAvroType(vt, valueContainsNull, recordName, prevNameSpace)) + builder.map() + .values(toAvroType(vt, valueContainsNull, recordName, prevNameSpace, outputTimestampType)) case st: StructType => val nameSpace = s"$prevNameSpace.$recordName" val fieldsAssembler = builder.record(recordName).namespace(nameSpace).fields() st.foreach { f => - val fieldAvroType = toAvroType(f.dataType, f.nullable, f.name, nameSpace) + val fieldAvroType = + toAvroType(f.dataType, f.nullable, f.name, nameSpace, outputTimestampType) fieldsAssembler.name(f.name).`type`(fieldAvroType).noDefault() } fieldsAssembler.endRecord() diff --git a/external/avro/src/test/resources/timestamp.avro b/external/avro/src/test/resources/timestamp.avro new file mode 100644 index 0000000000000000000000000000000000000000..daef50b78b87494b75c0309bfdbb917f228edc01 GIT binary patch literal 375 zcmeZI%3@>@ODrqO*DFrWNX<>$!&0qOQdy9yWTl`~l$xAhl%k}gpp=)Gn_66um<$%q z$xqKrPRxOcgH)EJ7MFndX_=`xDaAmMXt*iWN>KG7P*Y1Xfo7E?<`(GYX6EE%7K8M` zY|P2eOINCeS_n26rZ^s|7$`}c(aA;m#2XD(jBGT}(Lk3VIRxUe*jf>ASS9DDq$YFZ ymFDCycpDx~so;CZqcP7bglXq>7Z$Y({0)=7Fn-Wmuq?1){+%7{7v997D*^zXy@3t@ literal 0 HcmV?d00001 diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index c221c4fd07de7..085c8c8e5fc1c 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -39,9 +39,34 @@ import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { + import testImplicits._ + val episodesAvro = testFile("episodes.avro") val testAvro = testFile("test.avro") + // The test file timestamp.avro is generated via following Python code: + // import json + // import avro.schema + // from avro.datafile import DataFileWriter + // from avro.io import DatumWriter + // + // write_schema = avro.schema.parse(json.dumps({ + // "namespace": "logical", + // "type": "record", + // "name": "test", + // "fields": [ + // {"name": "timestamp_millis", "type": {"type": "long","logicalType": "timestamp-millis"}}, + // {"name": "timestamp_micros", "type": {"type": "long","logicalType": "timestamp-micros"}}, + // {"name": "long", "type": "long"} + // ] + // })) + // + // writer = DataFileWriter(open("timestamp.avro", "wb"), DatumWriter(), write_schema) + // writer.append({"timestamp_millis": 1000, "timestamp_micros": 2000000, "long": 3000}) + // writer.append({"timestamp_millis": 666000, "timestamp_micros": 999000000, "long": 777000}) + // writer.close() + val timestampAvro = testFile("timestamp.avro") + override protected def beforeAll(): Unit = { super.beforeAll() spark.conf.set("spark.sql.files.maxPartitionBytes", 1024) @@ -331,6 +356,77 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } } + test("Logical type: timestamp_millis") { + val expected = Seq(1000L, 666000L).map(t => Row(new Timestamp(t))) + val df = spark.read.format("avro").load(timestampAvro).select('timestamp_millis) + + checkAnswer(df, expected) + + withTempPath { dir => + df.write.format("avro").save(dir.toString) + checkAnswer(spark.read.format("avro").load(dir.toString), expected) + } + } + + test("Logical type: timestamp_micros") { + val expected = Seq(2000L, 999000L).map(t => Row(new Timestamp(t))) + val df = spark.read.format("avro").load(timestampAvro).select('timestamp_micros) + + checkAnswer(df, expected) + + withTempPath { dir => + df.write.format("avro").save(dir.toString) + checkAnswer(spark.read.format("avro").load(dir.toString), expected) + } + } + + test("Logical type: specify different output timestamp types") { + val df = + spark.read.format("avro").load(timestampAvro).select('timestamp_millis, 'timestamp_micros) + + val expected = Seq((1000L, 2000L), (666000L, 999000L)) + .map(t => Row(new Timestamp(t._1), new Timestamp(t._2))) + + Seq("TIMESTAMP_MILLIS", "TIMESTAMP_MICROS").foreach { timestampType => + withSQLConf(SQLConf.AVRO_OUTPUT_TIMESTAMP_TYPE.key -> timestampType) { + withTempPath { dir => + df.write.format("avro").save(dir.toString) + checkAnswer(spark.read.format("avro").load(dir.toString), expected) + } + } + } + } + + test("Read Long type as Timestamp") { + val schema = StructType(StructField("long", TimestampType, true) :: Nil) + val df = spark.read.format("avro").schema(schema).load(timestampAvro).select('long) + + val expected = Seq(3000L, 777000L).map(t => Row(new Timestamp(t))) + + checkAnswer(df, expected) + } + + test("Logical type: user specified schema") { + val expected = Seq((1000L, 2000L, 3000L), (666000L, 999000L, 777000L)) + .map(t => Row(new Timestamp(t._1), new Timestamp(t._2), t._3)) + + val avroSchema = s""" + { + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [ + {"name": "timestamp_millis", "type": {"type": "long","logicalType": "timestamp-millis"}}, + {"name": "timestamp_micros", "type": {"type": "long","logicalType": "timestamp-micros"}}, + {"name": "long", "type": "long"} + ] + } + """ + val df = spark.read.format("avro").option("avroSchema", avroSchema).load(timestampAvro) + + checkAnswer(df, expected) + } + test("Array data types") { withTempPath { dir => val testSchema = StructType(Seq( @@ -521,7 +617,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { // TimesStamps are converted to longs val times = spark.read.format("avro").load(avroDir).select("Time").collect() - assert(times.map(_(0)).toSet == Set(666, 777, 42)) + assert(times.map(_(0)).toSet == + Set(new Timestamp(666), new Timestamp(777), new Timestamp(42))) // DecimalType should be converted to string val decimals = spark.read.format("avro").load(avroDir).select("Decimal").collect() @@ -540,9 +637,6 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("correctly read long as date/timestamp type") { withTempPath { tempDir => - val sparkSession = spark - import sparkSession.implicits._ - val currentTime = new Timestamp(System.currentTimeMillis()) val currentDate = new Date(System.currentTimeMillis()) val schema = StructType(Seq( @@ -570,9 +664,6 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("does not coerce null date/timestamp value to 0 epoch.") { withTempPath { tempDir => - val sparkSession = spark - import sparkSession.implicits._ - val nullTime: Timestamp = null val nullDate: Date = null val schema = StructType(Seq( @@ -778,8 +869,6 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("read avro file partitioned") { withTempPath { dir => - val sparkSession = spark - import sparkSession.implicits._ val df = (0 to 1024 * 3).toDS.map(i => s"record${i}").toDF("records") val outputDir = s"$dir/${UUID.randomUUID}" df.write.format("avro").save(outputDir) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8f303f7316b7d..2aba2f7de54a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1437,6 +1437,21 @@ object SQLConf { .intConf .createWithDefault(20) + object AvroOutputTimestampType extends Enumeration { + val TIMESTAMP_MICROS, TIMESTAMP_MILLIS = Value + } + + val AVRO_OUTPUT_TIMESTAMP_TYPE = buildConf("spark.sql.avro.outputTimestampType") + .doc("Sets which Avro timestamp type to use when Spark writes data to Avro files. " + + "TIMESTAMP_MICROS is a logical timestamp type in Avro, which stores number of " + + "microseconds from the Unix epoch. TIMESTAMP_MILLIS is also logical, but with " + + "millisecond precision, which means Spark has to truncate the microsecond portion of its " + + "timestamp value.") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(AvroOutputTimestampType.values.map(_.toString)) + .createWithDefault(AvroOutputTimestampType.TIMESTAMP_MICROS.toString) + val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec") .doc("Compression codec used in writing of AVRO files. Supported codecs: " + "uncompressed, deflate, snappy, bzip2 and xz. Default codec is snappy.") @@ -1839,6 +1854,9 @@ class SQLConf extends Serializable with Logging { def replEagerEvalTruncate: Int = getConf(SQLConf.REPL_EAGER_EVAL_TRUNCATE) + def avroOutputTimestampType: AvroOutputTimestampType.Value = + AvroOutputTimestampType.withName(getConf(SQLConf.AVRO_OUTPUT_TIMESTAMP_TYPE)) + def avroCompressionCodec: String = getConf(SQLConf.AVRO_COMPRESSION_CODEC) def avroDeflateLevel: Int = getConf(SQLConf.AVRO_DEFLATE_LEVEL) From b3f2911eebeb418631ce296f68a7cc68083659cd Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Fri, 3 Aug 2018 08:33:28 +0800 Subject: [PATCH 1291/2461] [SPARK-24945][SQL] Switching to uniVocity 2.7.3 ## What changes were proposed in this pull request? In the PR, I propose to upgrade uniVocity parser from **2.6.3** to **2.7.3**. The recent version includes a fix for the SPARK-24645 issue and has better performance. Before changes: ``` Parsing quoted values: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ One quoted string 33336 / 34122 0.0 666727.0 1.0X Wide rows with 1000 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Select 1000 columns 90287 / 91713 0.0 90286.9 1.0X Select 100 columns 31826 / 36589 0.0 31826.4 2.8X Select one column 25738 / 25872 0.0 25737.9 3.5X count() 6931 / 7269 0.1 6931.5 13.0X ``` after: ``` Parsing quoted values: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ One quoted string 33411 / 33510 0.0 668211.4 1.0X Wide rows with 1000 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Select 1000 columns 88028 / 89311 0.0 88028.1 1.0X Select 100 columns 29010 / 32755 0.0 29010.1 3.0X Select one column 22936 / 22953 0.0 22936.5 3.8X count() 6657 / 6740 0.2 6656.6 13.5X ``` Closes #21892 ## How was this patch tested? It was tested by `CSVSuite` and `CSVBenchmarks` Author: Maxim Gekk Closes #21969 from MaxGekk/univocity-2_7_3. --- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- dev/deps/spark-deps-hadoop-3.1 | 2 +- sql/core/pom.xml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 4ef61b2ab8cb7..54cdcfcaf8aa1 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -191,7 +191,7 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.6.3.jar +univocity-parsers-2.7.3.jar validation-api-1.1.0.Final.jar xbean-asm6-shaded-4.8.jar xercesImpl-2.9.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index a74ce1f26b146..fda13db52ba3d 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -192,7 +192,7 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.6.3.jar +univocity-parsers-2.7.3.jar validation-api-1.1.0.Final.jar xbean-asm6-shaded-4.8.jar xercesImpl-2.9.1.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index e0fcca0eeb31e..90602fce59a7d 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -212,7 +212,7 @@ stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar token-provider-1.0.1.jar -univocity-parsers-2.6.3.jar +univocity-parsers-2.7.3.jar validation-api-1.1.0.Final.jar woodstox-core-5.0.3.jar xbean-asm6-shaded-4.8.jar diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 9cd6776a18bcb..68b42a4c51ec4 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -38,7 +38,7 @@ com.univocity univocity-parsers - 2.6.3 + 2.7.3 jar From 73dd6cf9b558f9d752e1f3c13584344257ad7863 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 2 Aug 2018 22:04:17 -0700 Subject: [PATCH 1292/2461] [SPARK-24966][SQL] Implement precedence rules for set operations. ## What changes were proposed in this pull request? Currently the set operations INTERSECT, UNION and EXCEPT are assigned the same precedence. This PR fixes the problem by giving INTERSECT higher precedence than UNION and EXCEPT. UNION and EXCEPT operators are evaluated in the order in which they appear in the query from left to right. This results in change in behavior because of the change in order of evaluations of set operators in a query. The old behavior is still preserved under a newly added config parameter. Query `:` ``` SELECT * FROM t1 UNION SELECT * FROM t2 EXCEPT SELECT * FROM t3 INTERSECT SELECT * FROM t4 ``` Parsed plan before the change `:` ``` == Parsed Logical Plan == 'Intersect false :- 'Except false : :- 'Distinct : : +- 'Union : : :- 'Project [*] : : : +- 'UnresolvedRelation `t1` : : +- 'Project [*] : : +- 'UnresolvedRelation `t2` : +- 'Project [*] : +- 'UnresolvedRelation `t3` +- 'Project [*] +- 'UnresolvedRelation `t4` ``` Parsed plan after the change `:` ``` == Parsed Logical Plan == 'Except false :- 'Distinct : +- 'Union : :- 'Project [*] : : +- 'UnresolvedRelation `t1` : +- 'Project [*] : +- 'UnresolvedRelation `t2` +- 'Intersect false :- 'Project [*] : +- 'UnresolvedRelation `t3` +- 'Project [*] +- 'UnresolvedRelation `t4` ``` ## How was this patch tested? Added tests in PlanParserSuite, SQLQueryTestSuite. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Dilip Biswal Closes #21941 from dilipbiswal/SPARK-24966. --- docs/sql-programming-guide.md | 1 + .../spark/sql/catalyst/parser/SqlBase.g4 | 15 ++- .../spark/sql/catalyst/dsl/package.scala | 6 +- .../sql/catalyst/parser/ParseDriver.scala | 2 + .../plans/logical/basicLogicalOperators.scala | 6 +- .../apache/spark/sql/internal/SQLConf.scala | 12 ++ .../sql/catalyst/parser/PlanParserSuite.scala | 45 ++++++++ .../spark/sql/execution/SparkStrategies.scala | 4 +- .../sql-tests/inputs/intersect-all.sql | 51 +++++++-- .../sql-tests/results/intersect-all.sql.out | 104 ++++++++++++++---- 10 files changed, 211 insertions(+), 35 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 0900f8317d635..a1e019cbec4d2 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1876,6 +1876,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.3 to 2.4 + - Since Spark 2.4, Spark will evaluate the set operations referenced in a query by following a precedence rule as per the SQL standard. If the order is not specified by parentheses, set operations are performed from left to right with the exception that all INTERSECT operations are performed before any UNION, EXCEPT or MINUS operations. The old behaviour of giving equal precedence to all the set operations are preserved under a newly added configuaration `spark.sql.legacy.setopsPrecedence.enabled` with a default value of `false`. When this property is set to `true`, spark will evaluate the set operators from left to right as they appear in the query given no explicit ordering is enforced by usage of parenthesis. - Since Spark 2.4, Spark will display table description column Last Access value as UNKNOWN when the value was Jan 01 1970. - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unable to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 9ad6f30c40a88..94283f59011a8 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -17,6 +17,12 @@ grammar SqlBase; @members { + /** + * When false, INTERSECT is given the greater precedence over the other set + * operations (UNION, EXCEPT and MINUS) as per the SQL standard. + */ + public boolean legacy_setops_precedence_enbled = false; + /** * Verify whether current token is a valid decimal token (which contains dot). * Returns true if the character that follows the token is not a digit or letter or underscore. @@ -352,8 +358,13 @@ multiInsertQueryBody ; queryTerm - : queryPrimary #queryTermDefault - | left=queryTerm operator=(INTERSECT | UNION | EXCEPT | SETMINUS) setQuantifier? right=queryTerm #setOperation + : queryPrimary #queryTermDefault + | left=queryTerm {legacy_setops_precedence_enbled}? + operator=(INTERSECT | UNION | EXCEPT | SETMINUS) setQuantifier? right=queryTerm #setOperation + | left=queryTerm {!legacy_setops_precedence_enbled}? + operator=INTERSECT setQuantifier? right=queryTerm #setOperation + | left=queryTerm {!legacy_setops_precedence_enbled}? + operator=(UNION | EXCEPT | SETMINUS) setQuantifier? right=queryTerm #setOperation ; queryPrimary diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 98708545c4bfc..7997e79003b12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -356,9 +356,11 @@ package object dsl { def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan) - def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, otherPlan) + def except(otherPlan: LogicalPlan, isAll: Boolean = false): LogicalPlan = + Except(logicalPlan, otherPlan, isAll) - def intersect(otherPlan: LogicalPlan): LogicalPlan = Intersect(logicalPlan, otherPlan) + def intersect(otherPlan: LogicalPlan, isAll: Boolean = false): LogicalPlan = + Intersect(logicalPlan, otherPlan, isAll) def union(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 4c20f2368bded..7d8cb1f18b4b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -84,12 +84,14 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { val lexer = new SqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command))) lexer.removeErrorListeners() lexer.addErrorListener(ParseErrorListener) + lexer.legacy_setops_precedence_enbled = SQLConf.get.setOpsPrecedenceEnforced val tokenStream = new CommonTokenStream(lexer) val parser = new SqlBaseParser(tokenStream) parser.addParseListener(PostProcessor) parser.removeErrorListeners() parser.addErrorListener(ParseErrorListener) + parser.legacy_setops_precedence_enbled = SQLConf.get.setOpsPrecedenceEnforced try { try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 68413d7fd10f1..d7dbdb39a9afb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -165,9 +165,9 @@ object SetOperation { } case class Intersect( - left: LogicalPlan, - right: LogicalPlan, - isAll: Boolean = false) extends SetOperation(left, right) { + left: LogicalPlan, + right: LogicalPlan, + isAll: Boolean = false) extends SetOperation(left, right) { override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) "All" else "" ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2aba2f7de54a7..67c3abb80c2c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1466,6 +1466,16 @@ object SQLConf { .intConf .checkValues((1 to 9).toSet + Deflater.DEFAULT_COMPRESSION) .createWithDefault(Deflater.DEFAULT_COMPRESSION) + + val LEGACY_SETOPS_PRECEDENCE_ENABLED = + buildConf("spark.sql.legacy.setopsPrecedence.enabled") + .internal() + .doc("When set to true and the order of evaluation is not specified by parentheses, the " + + "set operations are performed from left to right as they appear in the query. When set " + + "to false and order of evaluation is not specified by parentheses, INTERSECT operations " + + "are performed before any UNION, EXCEPT and MINUS operations.") + .booleanConf + .createWithDefault(false) } /** @@ -1861,6 +1871,8 @@ class SQLConf extends Serializable with Logging { def avroDeflateLevel: Int = getConf(SQLConf.AVRO_DEFLATE_LEVEL) + def setOpsPrecedenceEnforced: Boolean = getConf(SQLConf.LEGACY_SETOPS_PRECEDENCE_ENABLED) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 9be0ec5af78ef..38efd89156de6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.IntegerType /** @@ -676,4 +677,48 @@ class PlanParserSuite extends AnalysisTest { OneRowRelation().select('rtrim.function("c&^,.", "bc...,,,&&&ccc")) ) } + + test("precedence of set operations") { + val a = table("a").select(star()) + val b = table("b").select(star()) + val c = table("c").select(star()) + val d = table("d").select(star()) + + val query1 = + """ + |SELECT * FROM a + |UNION + |SELECT * FROM b + |EXCEPT + |SELECT * FROM c + |INTERSECT + |SELECT * FROM d + """.stripMargin + + val query2 = + """ + |SELECT * FROM a + |UNION + |SELECT * FROM b + |EXCEPT ALL + |SELECT * FROM c + |INTERSECT ALL + |SELECT * FROM d + """.stripMargin + + assertEqual(query1, Distinct(a.union(b)).except(c.intersect(d))) + assertEqual(query2, Distinct(a.union(b)).except(c.intersect(d, isAll = true), isAll = true)) + + // Now disable precedence enforcement to verify the old behaviour. + withSQLConf(SQLConf.LEGACY_SETOPS_PRECEDENCE_ENABLED.key -> "true") { + assertEqual(query1, Distinct(a.union(b)).except(c).intersect(d)) + assertEqual(query2, Distinct(a.union(b)).except(c, isAll = true).intersect(d, isAll = true)) + } + + // Explicitly enable the precedence enforcement + withSQLConf(SQLConf.LEGACY_SETOPS_PRECEDENCE_ENABLED.key -> "false") { + assertEqual(query1, Distinct(a.union(b)).except(c.intersect(d))) + assertEqual(query2, Distinct(a.union(b)).except(c.intersect(d, isAll = true), isAll = true)) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 75eff8a88312b..b4179f4d12d35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -535,14 +535,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Intersect(left, right, true) => throw new IllegalStateException( "logical intersect operator should have been replaced by union, aggregate" + - "and generate operators in the optimizer") + " and generate operators in the optimizer") case logical.Except(left, right, false) => throw new IllegalStateException( "logical except operator should have been replaced by anti-join in the optimizer") case logical.Except(left, right, true) => throw new IllegalStateException( "logical except (all) operator should have been replaced by union, aggregate" + - "and generate operators in the optimizer") + " and generate operators in the optimizer") case logical.DeserializeToObject(deserializer, objAttr, child) => execution.DeserializeToObjectExec(deserializer, objAttr, planLater(child)) :: Nil diff --git a/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql b/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql index ff4395c3e7447..b0b2244048caa 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql @@ -59,29 +59,40 @@ INTERSECT ALL SELECT * FROM tab2; -- Chain of different `set operations --- We need to parenthesize the following two queries to enforce --- certain order of evaluation of operators. After fix to --- SPARK-24966 this can be removed. SELECT * FROM tab1 EXCEPT SELECT * FROM tab2 UNION ALL -( SELECT * FROM tab1 INTERSECT ALL SELECT * FROM tab2 -); +; -- Chain of different `set operations SELECT * FROM tab1 EXCEPT SELECT * FROM tab2 EXCEPT -( SELECT * FROM tab1 INTERSECT ALL SELECT * FROM tab2 -); +; + +-- test use parenthesis to control order of evaluation +( + ( + ( + SELECT * FROM tab1 + EXCEPT + SELECT * FROM tab2 + ) + EXCEPT + SELECT * FROM tab1 + ) + INTERSECT ALL + SELECT * FROM tab2 +) +; -- Join under intersect all SELECT * @@ -118,6 +129,32 @@ SELECT v FROM tab1 GROUP BY v INTERSECT ALL SELECT k FROM tab2 GROUP BY k; +-- Test pre spark2.4 behaviour of set operation precedence +-- All the set operators are given equal precedence and are evaluated +-- from left to right as they appear in the query. + +-- Set the property +SET spark.sql.legacy.setopsPrecedence.enabled= true; + +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +UNION ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2; + +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +UNION ALL +SELECT * FROM tab1 +INTERSECT +SELECT * FROM tab2; + +-- Restore the property +SET spark.sql.legacy.setopsPrecedence.enabled = false; + -- Clean-up DROP VIEW IF EXISTS tab1; DROP VIEW IF EXISTS tab2; diff --git a/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out index 792791bc51628..63dd56ce468bc 100644 --- a/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 17 +-- Number of queries: 22 -- !query 0 @@ -133,11 +133,9 @@ SELECT * FROM tab1 EXCEPT SELECT * FROM tab2 UNION ALL -( SELECT * FROM tab1 INTERSECT ALL SELECT * FROM tab2 -) -- !query 10 schema struct -- !query 10 output @@ -154,11 +152,9 @@ SELECT * FROM tab1 EXCEPT SELECT * FROM tab2 EXCEPT -( SELECT * FROM tab1 INTERSECT ALL SELECT * FROM tab2 -) -- !query 11 schema struct -- !query 11 output @@ -166,6 +162,26 @@ struct -- !query 12 +( + ( + ( + SELECT * FROM tab1 + EXCEPT + SELECT * FROM tab2 + ) + EXCEPT + SELECT * FROM tab1 + ) + INTERSECT ALL + SELECT * FROM tab2 +) +-- !query 12 schema +struct +-- !query 12 output + + + +-- !query 13 SELECT * FROM (SELECT tab1.k, tab2.v @@ -179,9 +195,9 @@ FROM (SELECT tab1.k, FROM tab1 JOIN tab2 ON tab1.k = tab2.k) --- !query 12 schema +-- !query 13 schema struct --- !query 12 output +-- !query 13 output 1 2 1 2 1 2 @@ -193,7 +209,7 @@ struct 2 3 --- !query 13 +-- !query 14 SELECT * FROM (SELECT tab1.k, tab2.v @@ -207,35 +223,85 @@ FROM (SELECT tab2.v AS k, FROM tab1 JOIN tab2 ON tab1.k = tab2.k) --- !query 13 schema +-- !query 14 schema struct --- !query 13 output +-- !query 14 output --- !query 14 +-- !query 15 SELECT v FROM tab1 GROUP BY v INTERSECT ALL SELECT k FROM tab2 GROUP BY k --- !query 14 schema +-- !query 15 schema struct --- !query 14 output +-- !query 15 output 2 3 NULL --- !query 15 +-- !query 16 +SET spark.sql.legacy.setopsPrecedence.enabled= true +-- !query 16 schema +struct +-- !query 16 output +spark.sql.legacy.setopsPrecedence.enabled true + + +-- !query 17 +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +UNION ALL +SELECT * FROM tab1 +INTERSECT ALL +SELECT * FROM tab2 +-- !query 17 schema +struct +-- !query 17 output +1 2 +1 2 +2 3 +NULL NULL +NULL NULL + + +-- !query 18 +SELECT * FROM tab1 +EXCEPT +SELECT * FROM tab2 +UNION ALL +SELECT * FROM tab1 +INTERSECT +SELECT * FROM tab2 +-- !query 18 schema +struct +-- !query 18 output +1 2 +2 3 +NULL NULL + + +-- !query 19 +SET spark.sql.legacy.setopsPrecedence.enabled = false +-- !query 19 schema +struct +-- !query 19 output +spark.sql.legacy.setopsPrecedence.enabled false + + +-- !query 20 DROP VIEW IF EXISTS tab1 --- !query 15 schema +-- !query 20 schema struct<> --- !query 15 output +-- !query 20 output --- !query 16 +-- !query 21 DROP VIEW IF EXISTS tab2 --- !query 16 schema +-- !query 21 schema struct<> --- !query 16 output +-- !query 21 output From f45d60a5a1f1e97ecde36eda8202034d78f93d53 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 3 Aug 2018 13:28:44 +0800 Subject: [PATCH 1293/2461] [SPARK-25002][SQL] Avro: revise the output record namespace ## What changes were proposed in this pull request? Currently the output namespace is starting with ".", e.g. `.topLevelRecord` Although it is valid according to Avro spec, we should remove the starting dot in case of failures when the output Avro file is read by other lib: https://github.com/linkedin/goavro/pull/96 ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21974 from gengliangwang/avro_namespace. --- .../spark/sql/avro/SchemaConverters.scala | 6 +++++- .../org/apache/spark/sql/avro/AvroSuite.scala | 17 +++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala index 1e912073f12b4..69295398775e6 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -142,7 +142,11 @@ object SchemaConverters { builder.map() .values(toAvroType(vt, valueContainsNull, recordName, prevNameSpace, outputTimestampType)) case st: StructType => - val nameSpace = s"$prevNameSpace.$recordName" + val nameSpace = prevNameSpace match { + case "" => recordName + case _ => s"$prevNameSpace.$recordName" + } + val fieldsAssembler = builder.record(recordName).namespace(nameSpace).fields() st.foreach { f => val fieldAvroType = diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 085c8c8e5fc1c..b4dcf6c6c9330 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -884,6 +884,23 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { case class NestedTop(id: Int, data: NestedMiddle) + test("Validate namespace in avro file that has nested records with the same name") { + withTempPath { dir => + val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1"))))) + writeDf.write.format("avro").save(dir.toString) + val file = new File(dir.toString) + .listFiles() + .filter(_.isFile) + .filter(_.getName.endsWith("avro")) + .head + val reader = new DataFileReader(file, new GenericDatumReader[Any]()) + val schema = reader.getSchema.toString() + assert(schema.contains("\"namespace\":\"topLevelRecord\"")) + assert(schema.contains("\"namespace\":\"topLevelRecord.data\"")) + assert(schema.contains("\"namespace\":\"topLevelRecord.data.data\"")) + } + } + test("saving avro that has nested records with the same name") { withTempPath { tempDir => // Save avro file on output folder path From b0d6967d45f3260ed4ee9b2a49f801d799e81283 Mon Sep 17 00:00:00 2001 From: Chris Horn Date: Thu, 2 Aug 2018 22:40:58 -0700 Subject: [PATCH 1294/2461] [SPARK-24788][SQL] RelationalGroupedDataset.toString with unresolved exprs should not fail ## What changes were proposed in this pull request? In the current master, `toString` throws an exception when `RelationalGroupedDataset` has unresolved expressions; ``` scala> spark.range(0, 10).groupBy("id") res4: org.apache.spark.sql.RelationalGroupedDataset = RelationalGroupedDataset: [grouping expressions: [id: bigint], value: [id: bigint], type: GroupBy] scala> spark.range(0, 10).groupBy('id) org.apache.spark.sql.catalyst.analysis.UnresolvedException: Invalid call to dataType on unresolved object, tree: 'id at org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute.dataType(unresolved.scala:105) at org.apache.spark.sql.RelationalGroupedDataset$$anonfun$12.apply(RelationalGroupedDataset.scala:474) at org.apache.spark.sql.RelationalGroupedDataset$$anonfun$12.apply(RelationalGroupedDataset.scala:473) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48) at scala.collection.TraversableLike$class.map(TraversableLike.scala:234) at scala.collection.AbstractTraversable.map(Traversable.scala:104) at org.apache.spark.sql.RelationalGroupedDataset.toString(RelationalGroupedDataset.scala:473) at scala.runtime.ScalaRunTime$.scala$runtime$ScalaRunTime$$inner$1(ScalaRunTime.scala:332) at scala.runtime.ScalaRunTime$.stringOf(ScalaRunTime.scala:337) at scala.runtime.ScalaRunTime$.replStringOf(ScalaRunTime.scala:345) ``` This pr fixed code to handle the unresolved case in `RelationalGroupedDataset.toString`. Closes #21752 ## How was this patch tested? Added tests in `DataFrameAggregateSuite`. Author: Chris Horn Author: Takeshi Yamamuro Closes #21964 from maropu/SPARK-24788. --- .../apache/spark/sql/RelationalGroupedDataset.scala | 7 +++++-- .../org/apache/spark/sql/DataFrameAggregateSuite.scala | 10 ++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 8412219b1250b..4e73b3657b4f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -469,8 +469,11 @@ class RelationalGroupedDataset protected[sql]( override def toString: String = { val builder = new StringBuilder builder.append("RelationalGroupedDataset: [grouping expressions: [") - val kFields = groupingExprs.map(_.asInstanceOf[NamedExpression]).map { - case f => s"${f.name}: ${f.dataType.simpleString(2)}" + val kFields = groupingExprs.collect { + case expr: NamedExpression if expr.resolved => + s"${expr.name}: ${expr.dataType.simpleString(2)}" + case expr: NamedExpression => expr.name + case o => o.toString } builder.append(kFields.take(2).mkString(", ")) if (kFields.length > 2) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index f495a949ebc5a..d0106c44b7db2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -717,4 +717,14 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(1, 2, 1) :: Row(2, 2, 2) :: Row(3, 2, 3) :: Nil) } + test("SPARK-24788: RelationalGroupedDataset.toString with unresolved exprs should not fail") { + // Checks if these raise no exception + assert(testData.groupBy('key).toString.contains( + "[grouping expressions: [key], value: [key: int, value: string], type: GroupBy]")) + assert(testData.groupBy(col("key")).toString.contains( + "[grouping expressions: [key], value: [key: int, value: string], type: GroupBy]")) + assert(testData.groupBy(current_date()).toString.contains( + "grouping expressions: [current_date(None)], value: [key: int, value: string], " + + "type: GroupBy]")) + } } From 19a45319130d618a173f5f3b4dde59356b39089b Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 2 Aug 2018 22:45:10 -0700 Subject: [PATCH 1295/2461] [SPARK-24997][SQL] Enable support of MINUS ALL ## What changes were proposed in this pull request? Enable support for MINUS ALL which was gated at AstBuilder. ## How was this patch tested? Added tests in SQLQueryTestSuite and modify PlanParserSuite. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Dilip Biswal Closes #21963 from dilipbiswal/minus-all. --- .../sql/catalyst/parser/AstBuilder.scala | 11 +- .../sql/catalyst/parser/PlanParserSuite.scala | 4 +- .../resources/sql-tests/inputs/except-all.sql | 22 ++- .../sql-tests/results/except-all.sql.out | 147 +++++++++++------- 4 files changed, 113 insertions(+), 71 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 0ceeb53e1d7a6..9906a30b488b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -517,11 +517,10 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * Connect two queries by a Set operator. * * Supported Set operators are: - * - UNION [DISTINCT] - * - UNION ALL - * - EXCEPT [DISTINCT] - * - MINUS [DISTINCT] - * - INTERSECT [DISTINCT] + * - UNION [ DISTINCT | ALL ] + * - EXCEPT [ DISTINCT | ALL ] + * - MINUS [ DISTINCT | ALL ] + * - INTERSECT [DISTINCT | ALL] */ override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) { val left = plan(ctx.left) @@ -541,7 +540,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case SqlBaseParser.EXCEPT => Except(left, right) case SqlBaseParser.SETMINUS if all => - throw new ParseException("MINUS ALL is not supported.", ctx) + Except(left, right, isAll = true) case SqlBaseParser.SETMINUS => Except(left, right) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 38efd89156de6..924700483dbe4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -67,11 +67,13 @@ class PlanParserSuite extends AnalysisTest { assertEqual("select * from a union all select * from b", a.union(b)) assertEqual("select * from a except select * from b", a.except(b)) assertEqual("select * from a except distinct select * from b", a.except(b)) + assertEqual("select * from a except all select * from b", a.except(b, isAll = true)) assertEqual("select * from a minus select * from b", a.except(b)) - intercept("select * from a minus all select * from b", "MINUS ALL is not supported.") + assertEqual("select * from a minus all select * from b", a.except(b, isAll = true)) assertEqual("select * from a minus distinct select * from b", a.except(b)) assertEqual("select * from a intersect select * from b", a.intersect(b)) assertEqual("select * from a intersect distinct select * from b", a.intersect(b)) + assertEqual("select * from a intersect all select * from b", a.intersect(b, isAll = true)) } test("common table expressions") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/except-all.sql b/sql/core/src/test/resources/sql-tests/inputs/except-all.sql index 08b9a437b3d14..e28f0721a6449 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/except-all.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/except-all.sql @@ -17,12 +17,17 @@ CREATE TEMPORARY VIEW tab4 AS SELECT * FROM VALUES (2, 20) AS tab4(k, v); --- Basic ExceptAll +-- Basic EXCEPT ALL SELECT * FROM tab1 EXCEPT ALL SELECT * FROM tab2; --- ExceptAll same table in both branches +-- MINUS ALL (synonym for EXCEPT) +SELECT * FROM tab1 +MINUS ALL +SELECT * FROM tab2; + +-- EXCEPT ALL same table in both branches SELECT * FROM tab1 EXCEPT ALL SELECT * FROM tab2 WHERE c1 IS NOT NULL; @@ -57,14 +62,14 @@ SELECT * FROM tab4 EXCEPT ALL SELECT * FROM tab3; --- ExceptAll + Intersect +-- EXCEPT ALL + INTERSECT SELECT * FROM tab4 EXCEPT ALL SELECT * FROM tab3 INTERSECT DISTINCT SELECT * FROM tab4; --- ExceptAll + Except +-- EXCEPT ALL + EXCEPT SELECT * FROM tab4 EXCEPT ALL SELECT * FROM tab3 @@ -94,6 +99,15 @@ SELECT * FROM tab3 EXCEPT DISTINCT SELECT * FROM tab4; +-- Using MINUS ALL +SELECT * FROM tab3 +MINUS ALL +SELECT * FROM tab4 +UNION +SELECT * FROM tab3 +MINUS DISTINCT +SELECT * FROM tab4; + -- Chain of set operations SELECT * FROM tab3 EXCEPT ALL diff --git a/sql/core/src/test/resources/sql-tests/results/except-all.sql.out b/sql/core/src/test/resources/sql-tests/results/except-all.sql.out index 2a21c1505350c..01091a2f751ce 100644 --- a/sql/core/src/test/resources/sql-tests/results/except-all.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/except-all.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 25 +-- Number of queries: 27 -- !query 0 @@ -63,8 +63,8 @@ NULL -- !query 5 SELECT * FROM tab1 -EXCEPT ALL -SELECT * FROM tab2 WHERE c1 IS NOT NULL +MINUS ALL +SELECT * FROM tab2 -- !query 5 schema struct -- !query 5 output @@ -72,26 +72,39 @@ struct 2 2 NULL -NULL -- !query 6 -SELECT * FROM tab1 WHERE c1 > 5 +SELECT * FROM tab1 EXCEPT ALL -SELECT * FROM tab2 +SELECT * FROM tab2 WHERE c1 IS NOT NULL -- !query 6 schema struct -- !query 6 output - +0 +2 +2 +NULL +NULL -- !query 7 -SELECT * FROM tab1 +SELECT * FROM tab1 WHERE c1 > 5 EXCEPT ALL -SELECT * FROM tab2 WHERE c1 > 6 +SELECT * FROM tab2 -- !query 7 schema struct -- !query 7 output + + + +-- !query 8 +SELECT * FROM tab1 +EXCEPT ALL +SELECT * FROM tab2 WHERE c1 > 6 +-- !query 8 schema +struct +-- !query 8 output 0 1 2 @@ -103,13 +116,13 @@ NULL NULL --- !query 8 +-- !query 9 SELECT * FROM tab1 EXCEPT ALL SELECT CAST(1 AS BIGINT) --- !query 8 schema +-- !query 9 schema struct --- !query 8 output +-- !query 9 output 0 2 2 @@ -120,65 +133,65 @@ NULL NULL --- !query 9 +-- !query 10 SELECT * FROM tab1 EXCEPT ALL SELECT array(1) --- !query 9 schema +-- !query 10 schema struct<> --- !query 9 output +-- !query 10 output org.apache.spark.sql.AnalysisException ExceptAll can only be performed on tables with the compatible column types. array <> int at the first column of the second table; --- !query 10 +-- !query 11 SELECT * FROM tab3 EXCEPT ALL SELECT * FROM tab4 --- !query 10 schema +-- !query 11 schema struct --- !query 10 output +-- !query 11 output 1 2 1 3 --- !query 11 +-- !query 12 SELECT * FROM tab4 EXCEPT ALL SELECT * FROM tab3 --- !query 11 schema +-- !query 12 schema struct --- !query 11 output +-- !query 12 output 2 2 2 20 --- !query 12 +-- !query 13 SELECT * FROM tab4 EXCEPT ALL SELECT * FROM tab3 INTERSECT DISTINCT SELECT * FROM tab4 --- !query 12 schema +-- !query 13 schema struct --- !query 12 output +-- !query 13 output 2 2 2 20 --- !query 13 +-- !query 14 SELECT * FROM tab4 EXCEPT ALL SELECT * FROM tab3 EXCEPT DISTINCT SELECT * FROM tab4 --- !query 13 schema +-- !query 14 schema struct --- !query 13 output +-- !query 14 output --- !query 14 +-- !query 15 SELECT * FROM tab3 EXCEPT ALL SELECT * FROM tab4 @@ -186,24 +199,24 @@ UNION ALL SELECT * FROM tab3 EXCEPT DISTINCT SELECT * FROM tab4 --- !query 14 schema +-- !query 15 schema struct --- !query 14 output +-- !query 15 output 1 3 --- !query 15 +-- !query 16 SELECT k FROM tab3 EXCEPT ALL SELECT k, v FROM tab4 --- !query 15 schema +-- !query 16 schema struct<> --- !query 15 output +-- !query 16 output org.apache.spark.sql.AnalysisException ExceptAll can only be performed on tables with the same number of columns, but the first table has 1 columns and the second table has 2 columns; --- !query 16 +-- !query 17 SELECT * FROM tab3 EXCEPT ALL SELECT * FROM tab4 @@ -211,13 +224,27 @@ UNION SELECT * FROM tab3 EXCEPT DISTINCT SELECT * FROM tab4 --- !query 16 schema +-- !query 17 schema struct --- !query 16 output +-- !query 17 output 1 3 --- !query 17 +-- !query 18 +SELECT * FROM tab3 +MINUS ALL +SELECT * FROM tab4 +UNION +SELECT * FROM tab3 +MINUS DISTINCT +SELECT * FROM tab4 +-- !query 18 schema +struct +-- !query 18 output +1 3 + + +-- !query 19 SELECT * FROM tab3 EXCEPT ALL SELECT * FROM tab4 @@ -225,13 +252,13 @@ EXCEPT DISTINCT SELECT * FROM tab3 EXCEPT DISTINCT SELECT * FROM tab4 --- !query 17 schema +-- !query 19 schema struct --- !query 17 output +-- !query 19 output --- !query 18 +-- !query 20 SELECT * FROM (SELECT tab3.k, tab4.v @@ -245,13 +272,13 @@ FROM (SELECT tab3.k, FROM tab3 JOIN tab4 ON tab3.k = tab4.k) --- !query 18 schema +-- !query 20 schema struct --- !query 18 output +-- !query 20 output --- !query 19 +-- !query 21 SELECT * FROM (SELECT tab3.k, tab4.v @@ -265,9 +292,9 @@ FROM (SELECT tab4.v AS k, FROM tab3 JOIN tab4 ON tab3.k = tab4.k) --- !query 19 schema +-- !query 21 schema struct --- !query 19 output +-- !query 21 output 1 2 1 2 1 2 @@ -277,43 +304,43 @@ struct 2 3 --- !query 20 +-- !query 22 SELECT v FROM tab3 GROUP BY v EXCEPT ALL SELECT k FROM tab4 GROUP BY k --- !query 20 schema +-- !query 22 schema struct --- !query 20 output +-- !query 22 output 3 --- !query 21 +-- !query 23 DROP VIEW IF EXISTS tab1 --- !query 21 schema +-- !query 23 schema struct<> --- !query 21 output +-- !query 23 output --- !query 22 +-- !query 24 DROP VIEW IF EXISTS tab2 --- !query 22 schema +-- !query 24 schema struct<> --- !query 22 output +-- !query 24 output --- !query 23 +-- !query 25 DROP VIEW IF EXISTS tab3 --- !query 23 schema +-- !query 25 schema struct<> --- !query 23 output +-- !query 25 output --- !query 24 +-- !query 26 DROP VIEW IF EXISTS tab4 --- !query 24 schema +-- !query 26 schema struct<> --- !query 24 output +-- !query 26 output From ebf33a333e9f7ad46f37233eee843e31028a1d62 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Fri, 3 Aug 2018 15:02:41 +0800 Subject: [PATCH 1296/2461] [SAPRK-25011][ML] add prefix to __all__ in fpm.py ## What changes were proposed in this pull request? jira: https://issues.apache.org/jira/browse/SPARK-25011 add prefix to __all__ in fpm.py ## How was this patch tested? existing unit test. Author: Yuhao Yang Closes #21981 from hhbyyh/prefixall. --- python/pyspark/ml/fpm.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index fd19fd96c4df6..f9394421e0cc4 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -21,7 +21,7 @@ from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, _jvm from pyspark.ml.param.shared import * -__all__ = ["FPGrowth", "FPGrowthModel"] +__all__ = ["FPGrowth", "FPGrowthModel", "PrefixSpan"] class HasMinSupport(Params): @@ -313,14 +313,15 @@ def setParams(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=3200 def findFrequentSequentialPatterns(self, dataset): """ .. note:: Experimental + Finds the complete set of frequent sequential patterns in the input sequences of itemsets. :param dataset: A dataframe containing a sequence column which is `ArrayType(ArrayType(T))` type, T is the item type for the input dataset. :return: A `DataFrame` that contains columns of sequence and corresponding frequency. The schema of it will be: - - `sequence: ArrayType(ArrayType(T))` (T is the item type) - - `freq: Long` + - `sequence: ArrayType(ArrayType(T))` (T is the item type) + - `freq: Long` >>> from pyspark.ml.fpm import PrefixSpan >>> from pyspark.sql import Row From 53ca9755dbb3b952b16b198d31b7964d56bb5ef9 Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Fri, 3 Aug 2018 07:23:56 +0000 Subject: [PATCH 1297/2461] [SPARK-25009][CORE] Standalone Cluster mode application submit is not working ## What changes were proposed in this pull request? It seems 'doRunMain()' has been removed accidentally by other PR and due to that the application submission is not happening, this PR adds back the 'doRunMain()' for standalone cluster submission. ## How was this patch tested? I verified it manually by submitting application in standalone cluster mode, all the applications are submitting to the Master with the change. Author: Devaraj K Closes #21979 from devaraj-kavali/SPARK-25009. --- core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index e7310ee886103..6e70bcd7fc088 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -181,6 +181,7 @@ private[spark] class SparkSubmit extends Logging { if (args.isStandaloneCluster && args.useRest) { try { logInfo("Running Spark using the REST application submission protocol.") + doRunMain() } catch { // Fail over to use the legacy submission gateway case e: SubmitRestConnectionException => From 273b28404ca8bcdb07878be8fb0053e6625046bf Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Fri, 3 Aug 2018 07:43:54 +0000 Subject: [PATCH 1298/2461] [SPARK-24993][SQL] Make Avro Fast Again ## What changes were proposed in this pull request? When lindblombr at apple developed [SPARK-24855](https://github.com/apache/spark/pull/21847) to support specified schema on write, we found a performance regression in Avro writer for our dataset. With this PR, the performance is improved, but not as good as Spark 2.3 + the old avro writer. There must be something we miss which we need to investigate further. Spark 2.4 ``` spark git:(master) ./build/mvn -DskipTests clean package spark git:(master) bin/spark-shell --jars external/avro/target/spark-avro_2.11-2.4.0-SNAPSHOT.jar ``` Spark 2.3 + databricks avro ``` spark git:(branch-2.3) ./build/mvn -DskipTests clean package spark git:(branch-2.3) bin/spark-shell --packages com.databricks:spark-avro_2.11:4.0.0 ``` Current master: ``` +-------+--------------------+ |summary| writeTimes| +-------+--------------------+ | count| 100| | mean| 2.95621| | stddev|0.030895815479469294| | min| 2.915| | max| 3.049| +-------+--------------------+ +-------+--------------------+ |summary| readTimes| +-------+--------------------+ | count| 100| | mean| 0.31072999999999995| | stddev|0.054139709842390006| | min| 0.259| | max| 0.692| +-------+--------------------+ ``` Current master with this PR: ``` +-------+--------------------+ |summary| writeTimes| +-------+--------------------+ | count| 100| | mean| 2.5804300000000002| | stddev|0.011175600225672079| | min| 2.558| | max| 2.62| +-------+--------------------+ +-------+--------------------+ |summary| readTimes| +-------+--------------------+ | count| 100| | mean| 0.29922000000000004| | stddev|0.058261961532514166| | min| 0.251| | max| 0.732| +-------+--------------------+ ``` Spark 2.3 + databricks avro: ``` +-------+--------------------+ |summary| writeTimes| +-------+--------------------+ | count| 100| | mean| 1.7730500000000005| | stddev|0.025199156230863575| | min| 1.729| | max| 1.833| +-------+--------------------+ +-------+-------------------+ |summary| readTimes| +-------+-------------------+ | count| 100| | mean| 0.29715| | stddev|0.05685643358850465| | min| 0.258| | max| 0.718| +-------+-------------------+ ``` The following is the test code to reproduce the result. ```scala spark.sqlContext.setConf("spark.sql.avro.compression.codec", "uncompressed") val sparkSession = spark import sparkSession.implicits._ val df = spark.sparkContext.range(1, 3000).repartition(1).map { uid => val features = Array.fill(16000)(scala.math.random) (uid, scala.math.random, java.util.UUID.randomUUID().toString, java.util.UUID.randomUUID().toString, features) }.toDF("uid", "random", "uuid1", "uuid2", "features").cache() val size = df.count() // Write into ramdisk to rule out the disk IO impact val tempSaveDir = s"/Volumes/ramdisk/${java.util.UUID.randomUUID()}/" val n = 150 val writeTimes = new Array[Double](n) var i = 0 while (i < n) { val t1 = System.currentTimeMillis() df.write .format("com.databricks.spark.avro") .mode("overwrite") .save(tempSaveDir) val t2 = System.currentTimeMillis() writeTimes(i) = (t2 - t1) / 1000.0 i += 1 } df.unpersist() // The first 50 runs are for warm-up val readTimes = new Array[Double](n) i = 0 while (i < n) { val t1 = System.currentTimeMillis() val readDF = spark.read.format("com.databricks.spark.avro").load(tempSaveDir) assert(readDF.count() == size) val t2 = System.currentTimeMillis() readTimes(i) = (t2 - t1) / 1000.0 i += 1 } spark.sparkContext.parallelize(writeTimes.slice(50, 150)).toDF("writeTimes").describe("writeTimes").show() spark.sparkContext.parallelize(readTimes.slice(50, 150)).toDF("readTimes").describe("readTimes").show() ``` ## How was this patch tested? Existing tests. Author: DB Tsai Author: Brian Lindblom Closes #21952 from dbtsai/avro-performance-fix. --- .../spark/sql/avro/AvroSerializer.scala | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala index a744d54e43e53..382f9a750c16c 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -108,17 +108,20 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: et, resolveNullableType(avroType.getElementType, containsNull)) (getter, ordinal) => { val arrayData = getter.getArray(ordinal) - val result = new java.util.ArrayList[Any] + val len = arrayData.numElements() + val result = new Array[Any](len) var i = 0 - while (i < arrayData.numElements()) { - if (arrayData.isNullAt(i)) { - result.add(null) + while (i < len) { + if (containsNull && arrayData.isNullAt(i)) { + result(i) = null } else { - result.add(elementConverter(arrayData, i)) + result(i) = elementConverter(arrayData, i) } i += 1 } - result + // avro writer is expecting a Java Collection, so we convert it into + // `ArrayList` backed by the specified array without data copying. + java.util.Arrays.asList(result: _*) } case st: StructType => @@ -131,13 +134,14 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: vt, resolveNullableType(avroType.getValueType, valueContainsNull)) (getter, ordinal) => val mapData = getter.getMap(ordinal) - val result = new java.util.HashMap[String, Any](mapData.numElements()) + val len = mapData.numElements() + val result = new java.util.HashMap[String, Any](len) val keyArray = mapData.keyArray() val valueArray = mapData.valueArray() var i = 0 - while (i < mapData.numElements()) { + while (i < len) { val key = keyArray.getUTF8String(i).toString - if (valueArray.isNullAt(i)) { + if (valueContainsNull && valueArray.isNullAt(i)) { result.put(key, null) } else { result.put(key, valueConverter(valueArray, i)) From c32dbd6bd55cdff4d73408ba5fd6fe18056048fe Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 3 Aug 2018 08:17:18 -0500 Subject: [PATCH 1299/2461] [SPARK-18057][FOLLOW-UP][SS] Update Kafka client version from 0.10.0.1 to 2.0.0 ## What changes were proposed in this pull request? Update to kafka 2.0.0 in streaming-kafka module, and remove override for Scala 2.12. It won't compile for 2.12 otherwise. ## How was this patch tested? Existing tests. Author: Sean Owen Closes #21955 from srowen/SPARK-18057.2. --- external/kafka-0-10-sql/pom.xml | 10 +----- external/kafka-0-10/pom.xml | 26 +++++++++------ .../streaming/kafka010/KafkaRDDSuite.scala | 32 +++++++++++-------- .../streaming/kafka010/KafkaTestUtils.scala | 12 +++---- .../kafka010/mocks/MockScheduler.scala | 3 +- .../streaming/kafka010/mocks/MockTime.scala | 10 +++--- 6 files changed, 50 insertions(+), 43 deletions(-) diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 95500037c1473..8588e8be052eb 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -29,6 +29,7 @@ spark-sql-kafka-0-10_2.11 sql-kafka-0-10 + 2.0.0 jar @@ -128,13 +129,4 @@ target/scala-${scala.binary.version}/test-classes - - - scala-2.12 - - 0.10.1.1 - - - - diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index 3b124b2a69d50..a97fd35bfbb73 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -28,7 +28,8 @@ spark-streaming-kafka-0-10_2.11 streaming-kafka-0-10 - 0.10.0.1 + + 2.0.0 jar Spark Integration for Kafka 0.10 @@ -58,6 +59,20 @@ kafka_${scala.binary.version} ${kafka.version} test + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.core + jackson-annotations + + net.sf.jopt-simple @@ -93,13 +108,4 @@ target/scala-${scala.binary.version}/test-classes - - - scala-2.12 - - 0.10.1.1 - - - - diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala index 271adea1df731..3ac6509b04707 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala @@ -23,11 +23,11 @@ import java.io.File import scala.collection.JavaConverters._ import scala.util.Random -import kafka.common.TopicAndPartition -import kafka.log._ -import kafka.message._ +import kafka.log.{CleanerConfig, Log, LogCleaner, LogConfig, ProducerStateManager} +import kafka.server.{BrokerTopicStats, LogDirFailureChannel} import kafka.utils.Pool import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.record.{CompressionType, MemoryRecords, SimpleRecord} import org.apache.kafka.common.serialization.StringDeserializer import org.scalatest.BeforeAndAfterAll @@ -72,33 +72,39 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { private def compactLogs(topic: String, partition: Int, messages: Array[(String, String)]) { val mockTime = new MockTime() - // LogCleaner in 0.10 version of Kafka is still expecting the old TopicAndPartition api - val logs = new Pool[TopicAndPartition, Log]() + val logs = new Pool[TopicPartition, Log]() val logDir = kafkaTestUtils.brokerLogDir val dir = new File(logDir, topic + "-" + partition) dir.mkdirs() val logProps = new ju.Properties() logProps.put(LogConfig.CleanupPolicyProp, LogConfig.Compact) logProps.put(LogConfig.MinCleanableDirtyRatioProp, java.lang.Float.valueOf(0.1f)) + val logDirFailureChannel = new LogDirFailureChannel(1) + val topicPartition = new TopicPartition(topic, partition) val log = new Log( dir, LogConfig(logProps), 0L, + 0L, mockTime.scheduler, - mockTime + new BrokerTopicStats(), + mockTime, + Int.MaxValue, + Int.MaxValue, + topicPartition, + new ProducerStateManager(topicPartition, dir), + logDirFailureChannel ) messages.foreach { case (k, v) => - val msg = new ByteBufferMessageSet( - NoCompressionCodec, - new Message(v.getBytes, k.getBytes, Message.NoTimestamp, Message.CurrentMagicValue)) - log.append(msg) + val record = new SimpleRecord(k.getBytes, v.getBytes) + log.appendAsLeader(MemoryRecords.withRecords(CompressionType.NONE, record), 0); } log.roll() - logs.put(TopicAndPartition(topic, partition), log) + logs.put(topicPartition, log) - val cleaner = new LogCleaner(CleanerConfig(), logDirs = Array(dir), logs = logs) + val cleaner = new LogCleaner(CleanerConfig(), Array(dir), logs, logDirFailureChannel) cleaner.startup() - cleaner.awaitCleaned(topic, partition, log.activeSegment.baseOffset, 1000) + cleaner.awaitCleaned(new TopicPartition(topic, partition), log.activeSegment.baseOffset, 1000) cleaner.shutdown() mockTime.scheduler.shutdown() diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala index 70b579d96d692..2315baf7bc944 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala @@ -109,7 +109,7 @@ private[kafka010] class KafkaTestUtils extends Logging { brokerConf = new KafkaConfig(brokerConfiguration, doLog = false) server = new KafkaServer(brokerConf) server.startup() - brokerPort = server.boundPort() + brokerPort = server.boundPort(brokerConf.interBrokerListenerName) (server, brokerPort) }, new SparkConf(), "KafkaBroker") @@ -222,6 +222,8 @@ private[kafka010] class KafkaTestUtils extends Logging { props.put("zookeeper.connect", zkAddress) props.put("log.flush.interval.messages", "1") props.put("replica.socket.timeout.ms", "1500") + props.put("offsets.topic.replication.factor", "1") + props.put("group.initial.rebalance.delay.ms", "10") props } @@ -270,12 +272,10 @@ private[kafka010] class KafkaTestUtils extends Logging { private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match { case Some(partitionState) => - val leaderAndInSyncReplicas = partitionState.leaderIsrAndControllerEpoch.leaderAndIsr - + val leader = partitionState.basePartitionState.leader + val isr = partitionState.basePartitionState.isr zkUtils.getLeaderForPartition(topic, partition).isDefined && - Request.isValidBrokerId(leaderAndInSyncReplicas.leader) && - leaderAndInSyncReplicas.isr.nonEmpty - + Request.isValidBrokerId(leader) && !isr.isEmpty case _ => false } diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala index 928e1a6ef54b9..4811d041e7e9e 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala @@ -21,7 +21,8 @@ import java.util.concurrent.TimeUnit import scala.collection.mutable.PriorityQueue -import kafka.utils.{Scheduler, Time} +import kafka.utils.Scheduler +import org.apache.kafka.common.utils.Time /** * A mock scheduler that executes tasks synchronously using a mock time instance. diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala index a68f94db1f689..8a8646ee4eb94 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.kafka010.mocks import java.util.concurrent._ -import kafka.utils.Time +import org.apache.kafka.common.utils.Time /** * A class used for unit testing things which depend on the Time interface. @@ -36,12 +36,14 @@ private[kafka010] class MockTime(@volatile private var currentMs: Long) extends def this() = this(System.currentTimeMillis) - def milliseconds: Long = currentMs + override def milliseconds: Long = currentMs - def nanoseconds: Long = + override def hiResClockMs(): Long = milliseconds + + override def nanoseconds: Long = TimeUnit.NANOSECONDS.convert(currentMs, TimeUnit.MILLISECONDS) - def sleep(ms: Long) { + override def sleep(ms: Long) { this.currentMs += ms scheduler.tick() } From 92b48842b944a3e430472294cdc3c481bad6b804 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 3 Aug 2018 09:36:56 -0700 Subject: [PATCH 1300/2461] [SPARK-24954][CORE] Fail fast on job submit if run a barrier stage with dynamic resource allocation enabled ## What changes were proposed in this pull request? We don't support run a barrier stage with dynamic resource allocation enabled, it shall lead to some confusing behaviors (eg. with dynamic resource allocation enabled, it may happen that we acquire some executors (but not enough to launch all the tasks in a barrier stage) and later release them due to executor idle time expire, and then acquire again). We perform the check on job submit and fail fast if running a barrier stage with dynamic resource allocation enabled. ## How was this patch tested? Added new test suite `BarrierStageOnSubmittedSuite` to cover all the fail fast cases that submitted a job containing one or more barrier stages. Author: Xingbo Jiang Closes #21915 from jiangxb1987/SPARK-24954. --- .../apache/spark/scheduler/DAGScheduler.scala | 25 ++++++++ .../spark/BarrierStageOnSubmittedSuite.scala | 57 +++++++++++++++---- 2 files changed, 71 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 3dd0718ac673d..cf1fcbce78b30 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -364,6 +364,7 @@ class DAGScheduler( */ def createShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): ShuffleMapStage = { val rdd = shuffleDep.rdd + checkBarrierStageWithDynamicAllocation(rdd) checkBarrierStageWithRDDChainPattern(rdd, rdd.getNumPartitions) val numTasks = rdd.partitions.length val parents = getOrCreateParentStages(rdd, jobId) @@ -384,6 +385,23 @@ class DAGScheduler( stage } + /** + * We don't support run a barrier stage with dynamic resource allocation enabled, it shall lead + * to some confusing behaviors (eg. with dynamic resource allocation enabled, it may happen that + * we acquire some executors (but not enough to launch all the tasks in a barrier stage) and + * later release them due to executor idle time expire, and then acquire again). + * + * We perform the check on job submit and fail fast if running a barrier stage with dynamic + * resource allocation enabled. + * + * TODO SPARK-24942 Improve cluster resource management with jobs containing barrier stage + */ + private def checkBarrierStageWithDynamicAllocation(rdd: RDD[_]): Unit = { + if (rdd.isBarrier() && Utils.isDynamicAllocationEnabled(sc.getConf)) { + throw new SparkException(DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION) + } + } + /** * Create a ResultStage associated with the provided jobId. */ @@ -393,6 +411,7 @@ class DAGScheduler( partitions: Array[Int], jobId: Int, callSite: CallSite): ResultStage = { + checkBarrierStageWithDynamicAllocation(rdd) checkBarrierStageWithRDDChainPattern(rdd, partitions.toSet.size) val parents = getOrCreateParentStages(rdd, jobId) val id = nextStageId.getAndIncrement() @@ -2001,4 +2020,10 @@ private[spark] object DAGScheduler { "PartitionPruningRDD). A workaround for first()/take() can be barrierRdd.collect().head " + "(scala) or barrierRdd.collect()[0] (python).\n" + "2. An RDD that depends on multiple barrier RDDs (eg. barrierRdd1.zip(barrierRdd2))." + + // Error message when running a barrier stage with dynamic resource allocation enabled. + val ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION = + "[SPARK-24942]: Barrier execution mode does not support dynamic resource allocation for " + + "now. You can disable dynamic resource allocation by setting Spark conf " + + "\"spark.dynamicAllocation.enabled\" to \"false\"." } diff --git a/core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala b/core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala index f2b3884e25ffa..75e13a9bec105 100644 --- a/core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala +++ b/core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala @@ -20,8 +20,6 @@ package org.apache.spark import scala.concurrent.duration._ import scala.language.postfixOps -import org.scalatest.BeforeAndAfterEach - import org.apache.spark.rdd.{PartitionPruningRDD, RDD} import org.apache.spark.scheduler.DAGScheduler import org.apache.spark.util.ThreadUtils @@ -30,16 +28,13 @@ import org.apache.spark.util.ThreadUtils * This test suite covers all the cases that shall fail fast on job submitted that contains one * of more barrier stages. */ -class BarrierStageOnSubmittedSuite extends SparkFunSuite with BeforeAndAfterEach - with LocalSparkContext { - - override def beforeEach(): Unit = { - super.beforeEach() +class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext { - val conf = new SparkConf() - .setMaster("local[4]") - .setAppName("test") - sc = new SparkContext(conf) + private def createSparkContext(conf: Option[SparkConf] = None): SparkContext = { + new SparkContext(conf.getOrElse( + new SparkConf() + .setMaster("local[4]") + .setAppName("test"))) } private def testSubmitJob( @@ -62,6 +57,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with BeforeAndAfterEach } test("submit a barrier ResultStage that contains PartitionPruningRDD") { + sc = createSparkContext() val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1) val rdd = prunedRdd .barrier() @@ -71,6 +67,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with BeforeAndAfterEach } test("submit a barrier ShuffleMapStage that contains PartitionPruningRDD") { + sc = createSparkContext() val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1) val rdd = prunedRdd .barrier() @@ -82,6 +79,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with BeforeAndAfterEach } test("submit a barrier stage that doesn't contain PartitionPruningRDD") { + sc = createSparkContext() val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1) val rdd = prunedRdd .repartition(2) @@ -93,6 +91,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with BeforeAndAfterEach } test("submit a barrier stage with partial partitions") { + sc = createSparkContext() val rdd = sc.parallelize(1 to 10, 4) .barrier() .mapPartitions((iter, context) => iter) @@ -101,6 +100,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with BeforeAndAfterEach } test("submit a barrier stage with union()") { + sc = createSparkContext() val rdd1 = sc.parallelize(1 to 10, 2) .barrier() .mapPartitions((iter, context) => iter) @@ -114,6 +114,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with BeforeAndAfterEach } test("submit a barrier stage with coalesce()") { + sc = createSparkContext() val rdd = sc.parallelize(1 to 10, 4) .barrier() .mapPartitions((iter, context) => iter) @@ -125,6 +126,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with BeforeAndAfterEach } test("submit a barrier stage that contains an RDD that depends on multiple barrier RDDs") { + sc = createSparkContext() val rdd1 = sc.parallelize(1 to 10, 4) .barrier() .mapPartitions((iter, context) => iter) @@ -139,6 +141,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with BeforeAndAfterEach } test("submit a barrier stage with zip()") { + sc = createSparkContext() val rdd1 = sc.parallelize(1 to 10, 4) .barrier() .mapPartitions((iter, context) => iter) @@ -150,4 +153,36 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with BeforeAndAfterEach val result = rdd3.collect().sorted assert(result === Seq(12, 14, 16, 18, 20, 22, 24, 26, 28, 30)) } + + test("submit a barrier ResultStage with dynamic resource allocation enabled") { + val conf = new SparkConf() + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.dynamicAllocation.testing", "true") + .setMaster("local[4]") + .setAppName("test") + sc = createSparkContext(Some(conf)) + + val rdd = sc.parallelize(1 to 10, 4) + .barrier() + .mapPartitions((iter, context) => iter) + testSubmitJob(sc, rdd, + message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION) + } + + test("submit a barrier ShuffleMapStage with dynamic resource allocation enabled") { + val conf = new SparkConf() + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.dynamicAllocation.testing", "true") + .setMaster("local[4]") + .setAppName("test") + sc = createSparkContext(Some(conf)) + + val rdd = sc.parallelize(1 to 10, 4) + .barrier() + .mapPartitions((iter, context) => iter) + .repartition(2) + .map(x => x + 1) + testSubmitJob(sc, rdd, + message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION) + } } From 8c14276c3362798b030db7a9fcdc31a10d04b643 Mon Sep 17 00:00:00 2001 From: Onwuka Gideon Date: Fri, 3 Aug 2018 17:39:40 -0500 Subject: [PATCH 1301/2461] Little typo ## What changes were proposed in this pull request? Fixed little typo for a comment ## How was this patch tested? Manual test Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Onwuka Gideon Closes #21992 from dongido001/patch-1. --- python/pyspark/streaming/context.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index a4515828d180c..3fa57ca85b37b 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -222,7 +222,7 @@ def remember(self, duration): Set each DStreams in this context to remember RDDs it generated in the last given duration. DStreams remember RDDs only for a limited duration of time and releases them for garbage collection. - This method allows the developer to specify how to long to remember + This method allows the developer to specify how long to remember the RDDs (if the developer wishes to query old data outside the DStream computation). @@ -287,7 +287,7 @@ def _check_serializers(self, rdds): def queueStream(self, rdds, oneAtATime=True, default=None): """ - Create an input stream from an queue of RDDs or list. In each batch, + Create an input stream from a queue of RDDs or list. In each batch, it will process either one or all of the RDDs returned by the queue. .. note:: Changes to the queue after the stream is created will not be recognized. From 4c27663cb20f3cde7317ffcb2c9d42257a40057f Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 3 Aug 2018 16:22:54 -0700 Subject: [PATCH 1302/2461] [SPARK-18057][FOLLOW-UP][SS] Update Kafka client version from 0.10.0.1 to 2.0.0 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Increase ZK timeout and harmonize configs across Kafka tests to resol…ve potentially flaky test failure ## How was this patch tested? Existing tests Author: Sean Owen Closes #21995 from srowen/SPARK-18057.3. --- .../org/apache/spark/sql/kafka010/KafkaTestUtils.scala | 1 + .../apache/spark/streaming/kafka010/KafkaTestUtils.scala | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index 82294905c24b9..d89cccd3c5215 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -304,6 +304,7 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L props.put("port", brokerPort.toString) props.put("log.dir", Utils.createTempDir().getAbsolutePath) props.put("zookeeper.connect", zkAddress) + props.put("zookeeper.connection.timeout.ms", "60000") props.put("log.flush.interval.messages", "1") props.put("replica.socket.timeout.ms", "1500") props.put("delete.topic.enable", "true") diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala index 2315baf7bc944..eef4c55d27d51 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala @@ -32,6 +32,7 @@ import kafka.api.Request import kafka.server.{KafkaConfig, KafkaServer} import kafka.utils.ZkUtils import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord} +import org.apache.kafka.common.network.ListenerName import org.apache.kafka.common.serialization.StringSerializer import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} @@ -109,7 +110,7 @@ private[kafka010] class KafkaTestUtils extends Logging { brokerConf = new KafkaConfig(brokerConfiguration, doLog = false) server = new KafkaServer(brokerConf) server.startup() - brokerPort = server.boundPort(brokerConf.interBrokerListenerName) + brokerPort = server.boundPort(new ListenerName("PLAINTEXT")) (server, brokerPort) }, new SparkConf(), "KafkaBroker") @@ -220,8 +221,11 @@ private[kafka010] class KafkaTestUtils extends Logging { props.put("port", brokerPort.toString) props.put("log.dir", brokerLogDir) props.put("zookeeper.connect", zkAddress) + props.put("zookeeper.connection.timeout.ms", "60000") props.put("log.flush.interval.messages", "1") props.put("replica.socket.timeout.ms", "1500") + props.put("delete.topic.enable", "true") + props.put("offsets.topic.num.partitions", "1") props.put("offsets.topic.replication.factor", "1") props.put("group.initial.rebalance.delay.ms", "10") props From 41c2227a2318029709553a588e44dee28f106350 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 4 Aug 2018 14:17:32 +0800 Subject: [PATCH 1303/2461] [SPARK-24722][SQL] pivot() with Column type argument ## What changes were proposed in this pull request? In the PR, I propose column-based API for the `pivot()` function. It allows using of any column expressions as the pivot column. Also this makes it consistent with how groupBy() works. ## How was this patch tested? I added new tests to `DataFramePivotSuite` and updated PySpark examples for the `pivot()` function. Author: Maxim Gekk Closes #21699 from MaxGekk/pivot-column. --- python/pyspark/sql/group.py | 8 ++ .../spark/sql/RelationalGroupedDataset.scala | 100 +++++++++++++----- .../spark/sql/DataFramePivotSuite.scala | 88 ++++++++++++--- .../apache/spark/sql/test/SQLTestData.scala | 12 +++ 4 files changed, 167 insertions(+), 41 deletions(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 0906c9c6b329a..cc1da8e7c1f72 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -211,6 +211,8 @@ def pivot(self, pivot_col, values=None): >>> df4.groupBy("year").pivot("course").sum("earnings").collect() [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] + >>> df5.groupBy("sales.year").pivot("sales.course").sum("sales.earnings").collect() + [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] """ if values is None: jgd = self._jgd.pivot(pivot_col) @@ -296,6 +298,12 @@ def _test(): Row(course="dotNET", year=2012, earnings=5000), Row(course="dotNET", year=2013, earnings=48000), Row(course="Java", year=2013, earnings=30000)]).toDF() + globs['df5'] = sc.parallelize([ + Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=10000)), + Row(training="junior", sales=Row(course="Java", year=2012, earnings=20000)), + Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=5000)), + Row(training="junior", sales=Row(course="dotNET", year=2013, earnings=48000)), + Row(training="expert", sales=Row(course="Java", year=2013, earnings=30000))]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.group, globs=globs, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 4e73b3657b4f5..d700fb83b9b70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -314,7 +314,67 @@ class RelationalGroupedDataset protected[sql]( * @param pivotColumn Name of the column to pivot. * @since 1.6.0 */ - def pivot(pivotColumn: String): RelationalGroupedDataset = { + def pivot(pivotColumn: String): RelationalGroupedDataset = pivot(Column(pivotColumn)) + + /** + * Pivots a column of the current `DataFrame` and performs the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings") + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = { + pivot(Column(pivotColumn), values) + } + + /** + * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified + * aggregation. + * + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Arrays.asList("dotNET", "Java")).sum("earnings"); + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings"); + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = { + pivot(Column(pivotColumn), values) + } + + /** + * Pivots a column of the current `DataFrame` and performs the specified aggregation. + * This is an overloaded version of the `pivot` method with `pivotColumn` of the `String` type. + * + * {{{ + * // Or without specifying column values (less efficient) + * df.groupBy($"year").pivot($"course").sum($"earnings"); + * }}} + * + * @param pivotColumn he column to pivot. + * @since 2.4.0 + */ + def pivot(pivotColumn: Column): RelationalGroupedDataset = { // This is to prevent unintended OOM errors when the number of distinct values is large val maxValues = df.sparkSession.sessionState.conf.dataFramePivotMaxValues // Get the distinct values of the column and sort them so its consistent @@ -339,29 +399,24 @@ class RelationalGroupedDataset protected[sql]( /** * Pivots a column of the current `DataFrame` and performs the specified aggregation. - * There are two versions of pivot function: one that requires the caller to specify the list - * of distinct values to pivot on, and one that does not. The latter is more concise but less - * efficient, because Spark needs to first compute the list of distinct values internally. + * This is an overloaded version of the `pivot` method with `pivotColumn` of the `String` type. * * {{{ * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") - * - * // Or without specifying column values (less efficient) - * df.groupBy("year").pivot("course").sum("earnings") + * df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings") * }}} * - * @param pivotColumn Name of the column to pivot. + * @param pivotColumn the column to pivot. * @param values List of values that will be translated to columns in the output DataFrame. - * @since 1.6.0 + * @since 2.4.0 */ - def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = { + def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = { groupType match { case RelationalGroupedDataset.GroupByType => new RelationalGroupedDataset( df, groupingExprs, - RelationalGroupedDataset.PivotType(df.resolve(pivotColumn), values.map(Literal.apply))) + RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply))) case _: RelationalGroupedDataset.PivotType => throw new UnsupportedOperationException("repeated pivots are not supported") case _ => @@ -371,25 +426,14 @@ class RelationalGroupedDataset protected[sql]( /** * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified - * aggregation. - * - * There are two versions of pivot function: one that requires the caller to specify the list - * of distinct values to pivot on, and one that does not. The latter is more concise but less - * efficient, because Spark needs to first compute the list of distinct values internally. + * aggregation. This is an overloaded version of the `pivot` method with `pivotColumn` of + * the `String` type. * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course", Arrays.asList("dotNET", "Java")).sum("earnings"); - * - * // Or without specifying column values (less efficient) - * df.groupBy("year").pivot("course").sum("earnings"); - * }}} - * - * @param pivotColumn Name of the column to pivot. + * @param pivotColumn the column to pivot. * @param values List of values that will be translated to columns in the output DataFrame. - * @since 1.6.0 + * @since 2.4.0 */ - def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = { + def pivot(pivotColumn: Column, values: java.util.List[Any]): RelationalGroupedDataset = { pivot(pivotColumn, values.asScala) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 6ca9ee57e8f49..b972b9ef93e5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -27,28 +27,40 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("pivot courses") { + val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil checkAnswer( courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java")) .agg(sum($"earnings")), - Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil - ) + expected) + checkAnswer( + courseSales.groupBy($"year").pivot($"course", Seq("dotNET", "Java")) + .agg(sum($"earnings")), + expected) } test("pivot year") { + val expected = Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil checkAnswer( courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")), - Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil - ) + expected) + checkAnswer( + courseSales.groupBy('course).pivot('year, Seq(2012, 2013)).agg(sum('earnings)), + expected) } test("pivot courses with multiple aggregations") { + val expected = Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) :: + Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil checkAnswer( courseSales.groupBy($"year") .pivot("course", Seq("dotNET", "Java")) .agg(sum($"earnings"), avg($"earnings")), - Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) :: - Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil - ) + expected) + checkAnswer( + courseSales.groupBy($"year") + .pivot($"course", Seq("dotNET", "Java")) + .agg(sum($"earnings"), avg($"earnings")), + expected) } test("pivot year with string values (cast)") { @@ -67,17 +79,23 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { test("pivot courses with no values") { // Note Java comes before dotNet in sorted order + val expected = Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil checkAnswer( courseSales.groupBy("year").pivot("course").agg(sum($"earnings")), - Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil - ) + expected) + checkAnswer( + courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")), + expected) } test("pivot year with no values") { + val expected = Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil checkAnswer( courseSales.groupBy("course").pivot("year").agg(sum($"earnings")), - Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil - ) + expected) + checkAnswer( + courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")), + expected) } test("pivot max values enforced") { @@ -181,10 +199,13 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { } test("pivot with datatype not supported by PivotFirst") { + val expected = Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil checkAnswer( complexData.groupBy().pivot("b", Seq(true, false)).agg(max("a")), - Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil - ) + expected) + checkAnswer( + complexData.groupBy().pivot('b, Seq(true, false)).agg(max('a)), + expected) } test("pivot with datatype not supported by PivotFirst 2") { @@ -246,4 +267,45 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { checkAnswer(df.select($"a".cast(StringType)), Row(tsWithZone)) } } + + test("SPARK-24722: pivoting nested columns") { + val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil + val df = trainingSales + .groupBy($"sales.year") + .pivot(lower($"sales.course"), Seq("dotNet", "Java").map(_.toLowerCase)) + .agg(sum($"sales.earnings")) + + checkAnswer(df, expected) + } + + test("SPARK-24722: references to multiple columns in the pivot column") { + val expected = Row(2012, 10000.0) :: Row(2013, 48000.0) :: Nil + val df = trainingSales + .groupBy($"sales.year") + .pivot(concat_ws("-", $"training", $"sales.course"), Seq("Experts-dotNET")) + .agg(sum($"sales.earnings")) + + checkAnswer(df, expected) + } + + test("SPARK-24722: pivoting by a constant") { + val expected = Row(2012, 35000.0) :: Row(2013, 78000.0) :: Nil + val df1 = trainingSales + .groupBy($"sales.year") + .pivot(lit(123), Seq(123)) + .agg(sum($"sales.earnings")) + + checkAnswer(df1, expected) + } + + test("SPARK-24722: aggregate as the pivot column") { + val exception = intercept[AnalysisException] { + trainingSales + .groupBy($"sales.year") + .pivot(min($"training"), Seq("Experts")) + .agg(sum($"sales.earnings")) + } + + assert(exception.getMessage.contains("aggregate functions are not allowed")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index deea9dbb30aae..615923fe02d6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -268,6 +268,17 @@ private[sql] trait SQLTestData { self => df } + protected lazy val trainingSales: DataFrame = { + val df = spark.sparkContext.parallelize( + TrainingSales("Experts", CourseSales("dotNET", 2012, 10000)) :: + TrainingSales("Experts", CourseSales("JAVA", 2012, 20000)) :: + TrainingSales("Dummies", CourseSales("dotNet", 2012, 5000)) :: + TrainingSales("Experts", CourseSales("dotNET", 2013, 48000)) :: + TrainingSales("Dummies", CourseSales("Java", 2013, 30000)) :: Nil).toDF() + df.createOrReplaceTempView("trainingSales") + df + } + /** * Initialize all test data such that all temp tables are properly registered. */ @@ -323,4 +334,5 @@ private[sql] object SQLTestData { case class Salary(personId: Int, salary: Double) case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) case class CourseSales(course: String, year: Int, earnings: Double) + case class TrainingSales(training: String, sales: CourseSales) } From 36ea55e97e609d25de5d8cd47ce8d2a7ae990d62 Mon Sep 17 00:00:00 2001 From: John Zhuge Date: Sat, 4 Aug 2018 02:27:15 -0400 Subject: [PATCH 1304/2461] [SPARK-24940][SQL] Coalesce and Repartition Hint for SQL Queries ## What changes were proposed in this pull request? Many Spark SQL users in my company have asked for a way to control the number of output files in Spark SQL. The users prefer not to use function repartition(n) or coalesce(n, shuffle) that require them to write and deploy Scala/Java/Python code. We propose adding the following Hive-style Coalesce and Repartition Hint to Spark SQL: ``` ... SELECT /*+ COALESCE(numPartitions) */ ... ... SELECT /*+ REPARTITION(numPartitions) */ ... ``` Multiple such hints are allowed. Multiple nodes are inserted into the logical plan, and the optimizer will pick the leftmost hint. ``` INSERT INTO s SELECT /*+ REPARTITION(100), COALESCE(500), COALESCE(10) */ * FROM t == Logical Plan == 'InsertIntoTable 'UnresolvedRelation `s`, false, false +- 'UnresolvedHint REPARTITION, [100] +- 'UnresolvedHint COALESCE, [500] +- 'UnresolvedHint COALESCE, [10] +- 'Project [*] +- 'UnresolvedRelation `t` == Optimized Logical Plan == InsertIntoHadoopFsRelationCommand ... +- Repartition 100, true +- HiveTableRelation ... ``` ## How was this patch tested? All unit tests. Manual tests using explain. Author: John Zhuge Closes #21911 from jzhuge/SPARK-24940. --- .../sql/catalyst/analysis/Analyzer.scala | 1 + .../sql/catalyst/analysis/ResolveHints.scala | 28 +++++++++++++++ .../catalyst/analysis/ResolveHintsSuite.scala | 35 +++++++++++++++++++ .../sql/catalyst/parser/PlanParserSuite.scala | 27 ++++++++++++++ .../apache/spark/sql/DataFrameHintSuite.scala | 10 ++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 33 +++++++++++++++++ 6 files changed, 134 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7f235ac560299..b5016fdb29d92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -145,6 +145,7 @@ class Analyzer( lazy val batches: Seq[Batch] = Seq( Batch("Hints", fixedPoint, new ResolveHints.ResolveBroadcastHints(conf), + ResolveHints.ResolveCoalesceHints, ResolveHints.RemoveAllHints), Batch("Simple Sanity Check", Once, LookupFunctions), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index bfe5169c25900..1ef482b0e9f5b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -20,10 +20,12 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.IntegerType /** @@ -102,6 +104,32 @@ object ResolveHints { } } + /** + * COALESCE Hint accepts name "COALESCE" and "REPARTITION". + * Its parameter includes a partition number. + */ + object ResolveCoalesceHints extends Rule[LogicalPlan] { + private val COALESCE_HINT_NAMES = Set("COALESCE", "REPARTITION") + + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + case h: UnresolvedHint if COALESCE_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => + val hintName = h.name.toUpperCase(Locale.ROOT) + val shuffle = hintName match { + case "REPARTITION" => true + case "COALESCE" => false + } + val numPartitions = h.parameters match { + case Seq(Literal(numPartitions: Int, IntegerType)) => + numPartitions + case Seq(numPartitions: Int) => + numPartitions + case _ => + throw new AnalysisException(s"$hintName Hint expects a partition number as parameter") + } + Repartition(numPartitions, shuffle, h.child) + } + } + /** * Removes all the hints, used to remove invalid hints provided by the user. * This must be executed after all the other hint rules are executed. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index 9782b5fb0d266..bd66ee5355f45 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical._ @@ -120,4 +121,38 @@ class ResolveHintsSuite extends AnalysisTest { testRelation.where('a > 1).select('a).select('a).analyze, caseSensitive = false) } + + test("coalesce and repartition hint") { + checkAnalysis( + UnresolvedHint("COALESCE", Seq(Literal(10)), table("TaBlE")), + Repartition(numPartitions = 10, shuffle = false, child = testRelation)) + checkAnalysis( + UnresolvedHint("coalesce", Seq(Literal(20)), table("TaBlE")), + Repartition(numPartitions = 20, shuffle = false, child = testRelation)) + checkAnalysis( + UnresolvedHint("REPARTITION", Seq(Literal(100)), table("TaBlE")), + Repartition(numPartitions = 100, shuffle = true, child = testRelation)) + checkAnalysis( + UnresolvedHint("RePARTITion", Seq(Literal(200)), table("TaBlE")), + Repartition(numPartitions = 200, shuffle = true, child = testRelation)) + + val errMsgCoal = "COALESCE Hint expects a partition number as parameter" + assertAnalysisError( + UnresolvedHint("COALESCE", Seq.empty, table("TaBlE")), + Seq(errMsgCoal)) + assertAnalysisError( + UnresolvedHint("COALESCE", Seq(Literal(10), Literal(false)), table("TaBlE")), + Seq(errMsgCoal)) + assertAnalysisError( + UnresolvedHint("COALESCE", Seq(Literal(1.0)), table("TaBlE")), + Seq(errMsgCoal)) + + val errMsgRepa = "REPARTITION Hint expects a partition number as parameter" + assertAnalysisError( + UnresolvedHint("REPARTITION", Seq(UnresolvedAttribute("a")), table("TaBlE")), + Seq(errMsgRepa)) + assertAnalysisError( + UnresolvedHint("REPARTITION", Seq(Literal(true)), table("TaBlE")), + Seq(errMsgRepa)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 924700483dbe4..d7200d0bff5d6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -593,6 +593,33 @@ class PlanParserSuite extends AnalysisTest { parsePlan("SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a"), UnresolvedHint("MAPJOIN", Seq($"t"), table("t").where(Literal(true)).groupBy('a)('a)).orderBy('a.asc)) + + comparePlans( + parsePlan("SELECT /*+ COALESCE(10) */ * FROM t"), + UnresolvedHint("COALESCE", Seq(Literal(10)), + table("t").select(star()))) + + comparePlans( + parsePlan("SELECT /*+ REPARTITION(100) */ * FROM t"), + UnresolvedHint("REPARTITION", Seq(Literal(100)), + table("t").select(star()))) + + comparePlans( + parsePlan( + "INSERT INTO s SELECT /*+ REPARTITION(100), COALESCE(500), COALESCE(10) */ * FROM t"), + InsertIntoTable(table("s"), Map.empty, + UnresolvedHint("REPARTITION", Seq(Literal(100)), + UnresolvedHint("COALESCE", Seq(Literal(500)), + UnresolvedHint("COALESCE", Seq(Literal(10)), + table("t").select(star())))), overwrite = false, ifPartitionNotExists = false)) + + comparePlans( + parsePlan("SELECT /*+ BROADCASTJOIN(u), REPARTITION(100) */ * FROM t"), + UnresolvedHint("BROADCASTJOIN", Seq($"u"), + UnresolvedHint("REPARTITION", Seq(Literal(100)), + table("t").select(star())))) + + intercept("SELECT /*+ COALESCE(30 + 50) */ * FROM t", "mismatched input") } test("SPARK-20854: select hint syntax with expressions") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala index 0dd5bdcba2e4c..7ef8b542c79a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala @@ -59,4 +59,14 @@ class DataFrameHintSuite extends AnalysisTest with SharedSQLContext { ) ) } + + test("coalesce and repartition hint") { + check( + df.hint("COALESCE", 10), + UnresolvedHint("COALESCE", Seq(10), df.logicalPlan)) + + check( + df.hint("REPARTITION", 100), + UnresolvedHint("REPARTITION", Seq(100), df.logicalPlan)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 86083d1701c2c..2cb7a04714a52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.execution.datasources.FilePartition import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -2797,4 +2798,36 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(df, Seq(Row(3, 99, 1))) } } + + + test("SPARK-24940: coalesce and repartition hint") { + withTempView("nums1") { + val numPartitionsSrc = 10 + spark.range(0, 100, 1, numPartitionsSrc).createOrReplaceTempView("nums1") + assert(spark.table("nums1").rdd.getNumPartitions == numPartitionsSrc) + + withTable("nums") { + sql("CREATE TABLE nums (id INT) USING parquet") + + Seq(5, 20, 2).foreach { numPartitions => + sql( + s""" + |INSERT OVERWRITE TABLE nums + |SELECT /*+ REPARTITION($numPartitions) */ * + |FROM nums1 + """.stripMargin) + assert(spark.table("nums").inputFiles.length == numPartitions) + + sql( + s""" + |INSERT OVERWRITE TABLE nums + |SELECT /*+ COALESCE($numPartitions) */ * + |FROM nums1 + """.stripMargin) + // Coalesce can not increase the number of partitions + assert(spark.table("nums").inputFiles.length == Seq(numPartitions, numPartitionsSrc).min) + } + } + } + } } From 0ecc132d6b68f78d61a1ef68d3ae386670df9322 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sat, 4 Aug 2018 16:08:53 +0900 Subject: [PATCH 1305/2461] [SPARK-23909][SQL] Add filter function. ## What changes were proposed in this pull request? This pr adds `filter` function which filters the input array using the given predicate. ```sql > SELECT filter(array(1, 2, 3), x -> x % 2 == 1); array(1, 3) ``` ## How was this patch tested? Added tests. Author: Takuya UESHIN Closes #21965 from ueshin/issues/SPARK-23909/filter. --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/higherOrderFunctions.scala | 76 +++++++++++++-- .../HigherOrderFunctionsSuite.scala | 37 +++++++ .../inputs/higher-order-functions.sql | 9 ++ .../results/higher-order-functions.sql.out | 30 +++++- .../spark/sql/DataFrameFunctionsSuite.scala | 96 +++++++++++++++++++ 6 files changed, 240 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index f7517486e5411..d0efe975f81ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -441,6 +441,7 @@ object FunctionRegistry { expression[ArrayRemove]("array_remove"), expression[ArrayDistinct]("array_distinct"), expression[ArrayTransform]("transform"), + expression[ArrayFilter]("filter"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index c5c3482afa134..e15225ffbd2d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import java.util.concurrent.atomic.AtomicReference +import scala.collection.mutable + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ @@ -140,6 +142,18 @@ trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInpu @transient lazy val functionForEval: Expression = functionsForEval.head } +object ArrayBasedHigherOrderFunction { + + def elementArgumentType(dt: DataType): (DataType, Boolean) = { + dt match { + case ArrayType(elementType, containsNull) => (elementType, containsNull) + case _ => + val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType + (elementType, containsNull) + } + } +} + /** * Transform elements in an array using the transform function. This is similar to * a `map` in functional programming. @@ -164,17 +178,12 @@ case class ArrayTransform( override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = { - val (elementType, containsNull) = input.dataType match { - case ArrayType(elementType, containsNull) => (elementType, containsNull) - case _ => - val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType - (elementType, containsNull) - } + val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType) function match { case LambdaFunction(_, arguments, _) if arguments.size == 2 => - copy(function = f(function, (elementType, containsNull) :: (IntegerType, false) :: Nil)) + copy(function = f(function, elem :: (IntegerType, false) :: Nil)) case _ => - copy(function = f(function, (elementType, containsNull) :: Nil)) + copy(function = f(function, elem :: Nil)) } } @@ -210,3 +219,54 @@ case class ArrayTransform( override def prettyName: String = "transform" } + +/** + * Filters the input array using the given lambda function. + */ +@ExpressionDescription( + usage = "_FUNC_(expr, func) - Filters the input array using the given predicate.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 1); + array(1, 3) + """, + since = "2.4.0") +case class ArrayFilter( + input: Expression, + function: Expression) + extends ArrayBasedHigherOrderFunction with CodegenFallback { + + override def nullable: Boolean = input.nullable + + override def dataType: DataType = input.dataType + + override def expectingFunctionType: AbstractDataType = BooleanType + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = { + val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType) + copy(function = f(function, elem :: Nil)) + } + + @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function + + override def eval(input: InternalRow): Any = { + val arr = this.input.eval(input).asInstanceOf[ArrayData] + if (arr == null) { + null + } else { + val f = functionForEval + val buffer = new mutable.ArrayBuffer[Any](arr.numElements) + var i = 0 + while (i < arr.numElements) { + elementVar.value.set(arr.get(i, elementVar.dataType)) + if (f.eval(input).asInstanceOf[Boolean]) { + buffer += elementVar.value.get + } + i += 1 + } + new GenericArrayData(buffer) + } + } + + override def prettyName: String = "filter" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index e987ea5b8a4d1..d1330c7aad219 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -54,6 +54,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayTransform(expr, createLambda(at.elementType, at.containsNull, IntegerType, false, f)) } + def filter(expr: Expression, f: Expression => Expression): Expression = { + val at = expr.dataType.asInstanceOf[ArrayType] + ArrayFilter(expr, createLambda(at.elementType, at.containsNull, f)) + } + test("ArrayTransform") { val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) @@ -94,4 +99,36 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(transform(aai, array => Cast(transform(array, plusIndex), StringType)), Seq("[1, 3, 5]", null, "[4, 6]")) } + + test("ArrayFilter") { + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) + val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) + + val isEven: Expression => Expression = x => x % 2 === 0 + val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 + + checkEvaluation(filter(ai0, isEven), Seq(2)) + checkEvaluation(filter(ai0, isNullOrOdd), Seq(1, 3)) + checkEvaluation(filter(ai1, isEven), Seq.empty) + checkEvaluation(filter(ai1, isNullOrOdd), Seq(1, null, 3)) + checkEvaluation(filter(ain, isEven), null) + checkEvaluation(filter(ain, isNullOrOdd), null) + + val as0 = + Literal.create(Seq("a0", "b1", "a2", "c3"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) + val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) + + val startsWithA: Expression => Expression = x => x.startsWith("a") + + checkEvaluation(filter(as0, startsWithA), Seq("a0", "a2")) + checkEvaluation(filter(as1, startsWithA), Seq("a")) + checkEvaluation(filter(asn, startsWithA), null) + + val aai = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)), + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + checkEvaluation(transform(aai, ix => filter(ix, isNullOrOdd)), + Seq(Seq(1, 3), null, Seq(5))) + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql index 8e928a41f08e0..f833aa5818bc1 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -24,3 +24,12 @@ select transform(ys, 0) as v from nested; -- Transform a null array select transform(cast(null as array), x -> x + 1) as v; + +-- Filter. +select filter(ys, y -> y > 30) as v from nested; + +-- Filter a null array +select filter(cast(null as array), y -> true) as v; + +-- Filter nested arrays +select transform(zs, z -> filter(z, zz -> zz > 50)) as v from nested; diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out index ca2c3c35333cc..4c5d972378b31 100644 --- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 8 +-- Number of queries: 11 -- !query 0 @@ -79,3 +79,31 @@ select transform(cast(null as array), x -> x + 1) as v struct> -- !query 7 output NULL + + +-- !query 8 +select filter(ys, y -> y > 30) as v from nested +-- !query 8 schema +struct> +-- !query 8 output +[32,97] +[77] +[] + + +-- !query 9 +select filter(cast(null as array), y -> true) as v +-- !query 9 schema +struct> +-- !query 9 output +NULL + + +-- !query 10 +select transform(zs, z -> filter(z, zz -> zz > 50)) as v from nested +-- !query 10 schema +struct>> +-- !query 10 output +[[96,65],[]] +[[99],[123],[]] +[[]] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 923482024b033..1d5707a2c7047 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1800,6 +1800,102 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) } + test("filter function - array for primitive type not containing null") { + val df = Seq( + Seq(1, 9, 8, 7), + Seq(5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { + checkAnswer(df.selectExpr("filter(i, x -> x % 2 == 0)"), + Seq( + Row(Seq(8)), + Row(Seq(8, 2)), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeNotContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeNotContainsNull() + } + + test("filter function - array for primitive type containing null") { + val df = Seq[Seq[Integer]]( + Seq(1, 9, 8, null, 7), + Seq(5, null, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeContainsNull(): Unit = { + checkAnswer(df.selectExpr("filter(i, x -> x % 2 == 0)"), + Seq( + Row(Seq(8)), + Row(Seq(8, 2)), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeContainsNull() + } + + test("filter function - array for non-primitive type") { + val df = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("s") + + def testNonPrimitiveType(): Unit = { + checkAnswer(df.selectExpr("filter(s, x -> x is not null)"), + Seq( + Row(Seq("c", "a", "b")), + Row(Seq("b", "c")), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testNonPrimitiveType() + } + + test("filter function - invalid") { + val df = Seq( + (Seq("c", "a", "b"), 1), + (Seq("b", null, "c", null), 2), + (Seq.empty, 3), + (null, 4) + ).toDF("s", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("filter(s, (x, y) -> x + y)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match")) + + val ex2 = intercept[AnalysisException] { + df.selectExpr("filter(i, x -> x)") + } + assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("filter(s, x -> x)") + } + assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 684c719cc01d48d9f67983ae208f3cdc519bf654 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 4 Aug 2018 16:35:14 +0900 Subject: [PATCH 1306/2461] [SPARK-23915][SQL][FOLLOWUP] Add array_except function ## What changes were proposed in this pull request? simplify the codegen: 1. only do real codegen if the type can be specialized by the hash set 2. change the null handling. Before: track the nullElementIndex, and create a new ArrayData to insert the null in the middle. After: track the nullElementIndex, put a null placeholder in the ArrayBuilder, at the end create ArrayData from ArrayBuilder directly. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #21966 from cloud-fan/minor2. --- .../expressions/collectionOperations.scala | 205 +++++++++--------- 1 file changed, 98 insertions(+), 107 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b03bd7d942d72..3f94f25796634 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3980,7 +3980,8 @@ object ArrayUnion { """, since = "2.4.0") case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike - with ComplexTypeMergingExpression { + with ComplexTypeMergingExpression { + override def dataType: DataType = { dataTypeCheck left.dataType @@ -4077,81 +4078,80 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val arrayData = classOf[ArrayData].getName val i = ctx.freshName("i") - val pos = ctx.freshName("pos") val value = ctx.freshName("value") - val hsValue = ctx.freshName("hsValue") val size = ctx.freshName("size") - if (elementTypeSupportEquals) { - val ptName = CodeGenerator.primitiveTypeName(elementType) - val unsafeArray = ctx.freshName("unsafeArray") - val (postFix, openHashElementType, hsJavaTypeName, genHsValue, - getter, setter, javaTypeName, primitiveTypeName, arrayDataBuilder) = - elementType match { - case ByteType | ShortType | IntegerType => - ("$mcI$sp", "Int", "int", s"(int) $value", - s"get$ptName($i)", s"set$ptName($pos, $value)", - CodeGenerator.javaType(elementType), ptName, - s""" - |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} - |${ev.value} = $unsafeArray; - """.stripMargin) - case LongType | FloatType | DoubleType => - val signature = elementType match { - case LongType => "$mcJ$sp" - case FloatType => "$mcF$sp" - case DoubleType => "$mcD$sp" - } - (signature, CodeGenerator.boxedType(elementType), - CodeGenerator.javaType(elementType), value, - s"get$ptName($i)", s"set$ptName($pos, $value)", - CodeGenerator.javaType(elementType), ptName, - s""" - |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} - |${ev.value} = $unsafeArray; - """.stripMargin) - case _ => - val genericArrayData = classOf[GenericArrayData].getName - val et = ctx.addReferenceObj("elementType", elementType) - ("", "Object", "Object", value, - s"get($i, $et)", s"update($pos, $value)", "Object", "Ref", - s"${ev.value} = new $genericArrayData(new Object[$size]);") - } + val canUseSpecializedHashSet = elementType match { + case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true + case _ => false + } + if (canUseSpecializedHashSet) { + val jt = CodeGenerator.javaType(elementType) + val ptName = CodeGenerator.primitiveTypeName(jt) + + def genGetValue(array: String): String = + CodeGenerator.getValue(array, elementType, i) + + val (hsPostFix, hsTypeName) = elementType match { + // we cast byte/short to int when writing to the hash set. + case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int") + case LongType => ("$mcJ$sp", ptName) + case FloatType => ("$mcF$sp", ptName) + case DoubleType => ("$mcD$sp", ptName) + } + + // we cast byte/short to int when writing to the hash set. + val hsValueCast = elementType match { + case ByteType | ShortType => "(int) " + case _ => "" + } nullSafeCodeGen(ctx, ev, (array1, array2) => { val notFoundNullElement = ctx.freshName("notFoundNullElement") val nullElementIndex = ctx.freshName("nullElementIndex") val builder = ctx.freshName("builder") - val array = ctx.freshName("array") val openHashSet = classOf[OpenHashSet[_]].getName - val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" - val hs = ctx.freshName("hs") + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" + val hashSet = ctx.freshName("hashSet") val genericArrayData = classOf[GenericArrayData].getName val arrayBuilder = "scala.collection.mutable.ArrayBuilder" - val arrayBuilderClass = s"$arrayBuilder$$of$primitiveTypeName" - val arrayBuilderClassTag = if (primitiveTypeName != "Ref") { - s"scala.reflect.ClassTag$$.MODULE$$.$primitiveTypeName()" - } else { - s"scala.reflect.ClassTag$$.MODULE$$.AnyRef()" - } + val arrayBuilderClass = s"$arrayBuilder$$of$ptName" + val arrayBuilderClassTag = s"scala.reflect.ClassTag$$.MODULE$$.$ptName()" - def withArray2NullCheck(body: String) = + def withArray2NullCheck(body: String): String = if (right.dataType.asInstanceOf[ArrayType].containsNull) { - s""" - |if ($array2.isNullAt($i)) { - | $notFoundNullElement = false; - |} else { - | $body - |} + if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array2.isNullAt($i)) { + | $notFoundNullElement = false; + |} else { + | $body + |} """.stripMargin + } else { + // if array1's element is not nullable, we don't need to track the null element index. + s""" + |if (!$array2.isNullAt($i)) { + | $body + |} + """.stripMargin + } } else { body } - val array2Body = + + val writeArray2ToHashSet = withArray2NullCheck( s""" - |$javaTypeName $value = $array2.$getter; - |$hsJavaTypeName $hsValue = $genHsValue; - |$hs.add$postFix($hsValue); - """.stripMargin + |$jt $value = ${genGetValue(array2)}; + |$hashSet.add$hsPostFix($hsValueCast$value); + """.stripMargin) + + // When hitting a null value, put a null holder in the ArrayBuilder. Finally we will + // convert ArrayBuilder to ArrayData and setNull on the slot with null holder. + val nullValueHolder = elementType match { + case ByteType => "(byte) 0" + case ShortType => "(short) 0" + case _ => "0" + } def withArray1NullAssignment(body: String) = if (left.dataType.asInstanceOf[ArrayType].containsNull) { @@ -4161,6 +4161,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike | $nullElementIndex = $size; | $notFoundNullElement = false; | $size++; + | $builder.$$plus$$eq($nullValueHolder); | } |} else { | $body @@ -4169,81 +4170,71 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike } else { body } - val array1Body = + + val processArray1 = withArray1NullAssignment( s""" - |$javaTypeName $value = $array1.$getter; - |$hsJavaTypeName $hsValue = $genHsValue; - |if (!$hs.contains($hsValue)) { + |$jt $value = ${genGetValue(array1)}; + |if (!$hashSet.contains($hsValueCast$value)) { | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | break; | } - | $hs.add$postFix($hsValue); + | $hashSet.add$hsPostFix($hsValueCast$value); | $builder.$$plus$$eq($value); |} - """.stripMargin + """.stripMargin) - val nonNullArrayDataBuild = { - val build = if (postFix != "") { - val defaultSize = elementType.defaultSize + def withResultArrayNullCheck(body: String): String = { + if (dataType.asInstanceOf[ArrayType].containsNull) { s""" - |if (!UnsafeArrayData.shouldUseGenericArrayData($defaultSize, $size)) { - | ${ev.value} = UnsafeArrayData.fromPrimitiveArray($builder.result()); - |} else { - | ${ev.value} = new $genericArrayData($builder.result()); + |$body + |if ($nullElementIndex >= 0) { + | // result has null element + | ${ev.value}.setNullAt($nullElementIndex); |} """.stripMargin } else { - s"${ev.value} = new $genericArrayData($builder.result());" + body } + } + + val buildResultArray = withResultArrayNullCheck( s""" |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw new RuntimeException("Unsuccessful try create array with " + $size + + | throw new RuntimeException("Cannot create array with " + $size + | " bytes of data due to exceeding the limit " + - | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for GenericArrayData." + - | " $prettyName failed."); + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for ArrayData."); |} - |$build + | + |if (!UnsafeArrayData.shouldUseGenericArrayData(${elementType.defaultSize}, $size)) { + | ${ev.value} = UnsafeArrayData.fromPrimitiveArray($builder.result()); + |} else { + | ${ev.value} = new $genericArrayData($builder.result()); + |} + """.stripMargin) + + // Only need to track null element index when array1's element is nullable. + val declareNullTrackVariables = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |boolean $notFoundNullElement = true; + |int $nullElementIndex = -1; """.stripMargin + } else { + "" } - def buildResultArrayData(nonNullArrayDataBuild: String) = - if (dataType.asInstanceOf[ArrayType].containsNull) { - s""" - |if ($nullElementIndex < 0) { - | // result has no null element - | $nonNullArrayDataBuild - |} else { - | // result has null element - | $arrayDataBuilder - | $javaTypeName[] $array = $builder.result(); - | for (int $i = 0, $pos = 0; $pos < $size; $pos++) { - | if ($pos == $nullElementIndex) { - | ${ev.value}.setNullAt($pos); - | } else { - | $javaTypeName $value = $array[$i++]; - | ${ev.value}.$setter; - | } - | } - |} - """.stripMargin - } else { - nonNullArrayDataBuild - } - s""" - |$openHashSet $hs = new $openHashSet$postFix($classTag); - |boolean $notFoundNullElement = true; + |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag); + |$declareNullTrackVariables |for (int $i = 0; $i < $array2.numElements(); $i++) { - | ${withArray2NullCheck(array2Body)} + | $writeArray2ToHashSet |} |$arrayBuilderClass $builder = | ($arrayBuilderClass)$arrayBuilder.make($arrayBuilderClassTag); - |int $nullElementIndex = -1; |int $size = 0; |for (int $i = 0; $i < $array1.numElements(); $i++) { - | ${withArray1NullAssignment(array1Body)} + | $processArray1 |} - |${buildResultArrayData(nonNullArrayDataBuild)} + |$buildResultArray """.stripMargin }) } else { From 70462f291bf046c648f36063b0161861e6d11898 Mon Sep 17 00:00:00 2001 From: Nihar Sheth Date: Sat, 4 Aug 2018 10:27:34 -0500 Subject: [PATCH 1307/2461] [SPARK-24926][CORE] Ensure numCores is used consistently in all netty configurations ## What changes were proposed in this pull request? Netty could just ignore user-provided configurations. In particular, spark.driver.cores would be ignored when considering the number of cores available to netty (which would usually just default to Runtime.availableProcessors() ). In transport configurations, the number of threads are based directly on how many cores the system believes it has available, and in yarn cluster mode this would generally overshoot the user-preferred value. ## How was this patch tested? As this is mostly a configuration change, tests were done manually by adding spark-submit confs and verifying the number of threads started by netty was what was expected. Passes scalastyle checks from dev/run-tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Nihar Sheth Closes #21885 from NiharS/usableCores. --- .../scala/org/apache/spark/SparkContext.scala | 19 ++++++++++++++++--- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 2 +- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 78ba0b31fc6bb..03e91cdd310ed 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -254,7 +254,7 @@ class SparkContext(config: SparkConf) extends Logging { conf: SparkConf, isLocal: Boolean, listenerBus: LiveListenerBus): SparkEnv = { - SparkEnv.createDriverEnv(conf, isLocal, listenerBus, SparkContext.numDriverCores(master)) + SparkEnv.createDriverEnv(conf, isLocal, listenerBus, SparkContext.numDriverCores(master, conf)) } private[spark] def env: SparkEnv = _env @@ -2668,9 +2668,16 @@ object SparkContext extends Logging { } /** - * The number of driver cores to use for execution in local mode, 0 otherwise. + * The number of cores available to the driver to use for tasks such as I/O with Netty */ private[spark] def numDriverCores(master: String): Int = { + numDriverCores(master, null) + } + + /** + * The number of cores available to the driver to use for tasks such as I/O with Netty + */ + private[spark] def numDriverCores(master: String, conf: SparkConf): Int = { def convertToInt(threads: String): Int = { if (threads == "*") Runtime.getRuntime.availableProcessors() else threads.toInt } @@ -2678,7 +2685,13 @@ object SparkContext extends Logging { case "local" => 1 case SparkMasterRegex.LOCAL_N_REGEX(threads) => convertToInt(threads) case SparkMasterRegex.LOCAL_N_FAILURES_REGEX(threads, _) => convertToInt(threads) - case _ => 0 // driver is not used for execution + case "yarn" => + if (conf != null && conf.getOption("spark.submit.deployMode").contains("cluster")) { + conf.getInt("spark.driver.cores", 0) + } else { + 0 + } + case _ => 0 // Either driver is not being used, or its core count will be interpolated later } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index a2936d6ad539c..47576959322d1 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -50,7 +50,7 @@ private[netty] class NettyRpcEnv( private[netty] val transportConf = SparkTransportConf.fromSparkConf( conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"), "rpc", - conf.getInt("spark.rpc.io.threads", 0)) + conf.getInt("spark.rpc.io.threads", numUsableCores)) private val dispatcher: Dispatcher = new Dispatcher(this, numUsableCores) From 55e3ae6930af4730b01956ac5f60d0b0d2931134 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 4 Aug 2018 11:52:49 -0500 Subject: [PATCH 1308/2461] [SPARK-25001][BUILD] Fix miscellaneous build warnings ## What changes were proposed in this pull request? There are many warnings in the current build (for instance see https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.7/4734/console). **common**: ``` [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java:237: warning: [rawtypes] found raw type: LevelDBIterator [warn] void closeIterator(LevelDBIterator it) throws IOException { [warn] ^ [warn] missing type arguments for generic class LevelDBIterator [warn] where T is a type-variable: [warn] T extends Object declared in class LevelDBIterator [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java:151: warning: [deprecation] group() in AbstractBootstrap has been deprecated [warn] if (bootstrap != null && bootstrap.group() != null) { [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java:152: warning: [deprecation] group() in AbstractBootstrap has been deprecated [warn] bootstrap.group().shutdownGracefully(); [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java:154: warning: [deprecation] childGroup() in ServerBootstrap has been deprecated [warn] if (bootstrap != null && bootstrap.childGroup() != null) { [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java:155: warning: [deprecation] childGroup() in ServerBootstrap has been deprecated [warn] bootstrap.childGroup().shutdownGracefully(); [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java:112: warning: [deprecation] PooledByteBufAllocator(boolean,int,int,int,int,int,int,int) in PooledByteBufAllocator has been deprecated [warn] return new PooledByteBufAllocator( [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java:321: warning: [rawtypes] found raw type: Future [warn] public void operationComplete(Future future) throws Exception { [warn] ^ [warn] missing type arguments for generic class Future [warn] where V is a type-variable: [warn] V extends Object declared in interface Future [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java:215: warning: [rawtypes] found raw type: StreamInterceptor [warn] StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, [warn] ^ [warn] missing type arguments for generic class StreamInterceptor [warn] where T is a type-variable: [warn] T extends Message declared in class StreamInterceptor [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java:215: warning: [rawtypes] found raw type: StreamInterceptor [warn] StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, [warn] ^ [warn] missing type arguments for generic class StreamInterceptor [warn] where T is a type-variable: [warn] T extends Message declared in class StreamInterceptor [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java:215: warning: [unchecked] unchecked call to StreamInterceptor(MessageHandler,String,long,StreamCallback) as a member of the raw type StreamInterceptor [warn] StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, [warn] ^ [warn] where T is a type-variable: [warn] T extends Message declared in class StreamInterceptor [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java:255: warning: [rawtypes] found raw type: StreamInterceptor [warn] StreamInterceptor interceptor = new StreamInterceptor(this, wrappedCallback.getID(), [warn] ^ [warn] missing type arguments for generic class StreamInterceptor [warn] where T is a type-variable: [warn] T extends Message declared in class StreamInterceptor [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java:255: warning: [rawtypes] found raw type: StreamInterceptor [warn] StreamInterceptor interceptor = new StreamInterceptor(this, wrappedCallback.getID(), [warn] ^ [warn] missing type arguments for generic class StreamInterceptor [warn] where T is a type-variable: [warn] T extends Message declared in class StreamInterceptor [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java:255: warning: [unchecked] unchecked call to StreamInterceptor(MessageHandler,String,long,StreamCallback) as a member of the raw type StreamInterceptor [warn] StreamInterceptor interceptor = new StreamInterceptor(this, wrappedCallback.getID(), [warn] ^ [warn] where T is a type-variable: [warn] T extends Message declared in class StreamInterceptor [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java:270: warning: [deprecation] transfered() in FileRegion has been deprecated [warn] region.transferTo(byteRawChannel, region.transfered()); [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java:304: warning: [deprecation] transfered() in FileRegion has been deprecated [warn] region.transferTo(byteChannel, region.transfered()); [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java:119: warning: [deprecation] transfered() in FileRegion has been deprecated [warn] while (in.transfered() < in.count()) { [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java:120: warning: [deprecation] transfered() in FileRegion has been deprecated [warn] in.transferTo(channel, in.transfered()); [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java:80: warning: [static] static method should be qualified by type name, Murmur3_x86_32, instead of by an expression [warn] Assert.assertEquals(-300363099, hasher.hashUnsafeWords(bytes, offset, 16, 42)); [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java:84: warning: [static] static method should be qualified by type name, Murmur3_x86_32, instead of by an expression [warn] Assert.assertEquals(-1210324667, hasher.hashUnsafeWords(bytes, offset, 16, 42)); [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java:88: warning: [static] static method should be qualified by type name, Murmur3_x86_32, instead of by an expression [warn] Assert.assertEquals(-634919701, hasher.hashUnsafeWords(bytes, offset, 16, 42)); [warn] ^ ``` **launcher**: ``` [warn] Pruning sources from previous analysis, due to incompatible CompileSetup. [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java:31: warning: [rawtypes] found raw type: AbstractLauncher [warn] public abstract class AbstractLauncher { [warn] ^ [warn] missing type arguments for generic class AbstractLauncher [warn] where T is a type-variable: [warn] T extends AbstractLauncher declared in class AbstractLauncher ``` **core**: ``` [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/core/src/main/scala/org/apache/spark/api/r/RBackend.scala:99: method group in class AbstractBootstrap is deprecated: see corresponding Javadoc for more information. [warn] if (bootstrap != null && bootstrap.group() != null) { [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/core/src/main/scala/org/apache/spark/api/r/RBackend.scala:100: method group in class AbstractBootstrap is deprecated: see corresponding Javadoc for more information. [warn] bootstrap.group().shutdownGracefully() [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/core/src/main/scala/org/apache/spark/api/r/RBackend.scala:102: method childGroup in class ServerBootstrap is deprecated: see corresponding Javadoc for more information. [warn] if (bootstrap != null && bootstrap.childGroup() != null) { [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/core/src/main/scala/org/apache/spark/api/r/RBackend.scala:103: method childGroup in class ServerBootstrap is deprecated: see corresponding Javadoc for more information. [warn] bootstrap.childGroup().shutdownGracefully() [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala:151: reflective access of structural type member method getData should be enabled [warn] by making the implicit value scala.language.reflectiveCalls visible. [warn] This can be achieved by adding the import clause 'import scala.language.reflectiveCalls' [warn] or by setting the compiler option -language:reflectiveCalls. [warn] See the Scaladoc for value scala.language.reflectiveCalls for a discussion [warn] why the feature should be explicitly enabled. [warn] val rdd = sc.parallelize(1 to 1).map(concreteObject.getData) [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala:175: reflective access of structural type member value innerObject2 should be enabled [warn] by making the implicit value scala.language.reflectiveCalls visible. [warn] val rdd = sc.parallelize(1 to 1).map(concreteObject.innerObject2.getData) [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala:175: reflective access of structural type member method getData should be enabled [warn] by making the implicit value scala.language.reflectiveCalls visible. [warn] val rdd = sc.parallelize(1 to 1).map(concreteObject.innerObject2.getData) [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/core/src/test/scala/org/apache/spark/LocalSparkContext.scala:32: constructor Slf4JLoggerFactory in class Slf4JLoggerFactory is deprecated: see corresponding Javadoc for more information. [warn] InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()) [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala:218: value attemptId in class StageInfo is deprecated: Use attemptNumber instead [warn] assert(wrapper.stageAttemptId === stages.head.attemptId) [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala:261: value attemptId in class StageInfo is deprecated: Use attemptNumber instead [warn] stageAttemptId = stages.head.attemptId)) [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala:287: value attemptId in class StageInfo is deprecated: Use attemptNumber instead [warn] stageAttemptId = stages.head.attemptId)) [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala:471: value attemptId in class StageInfo is deprecated: Use attemptNumber instead [warn] stageAttemptId = stages.last.attemptId)) [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala:966: value attemptId in class StageInfo is deprecated: Use attemptNumber instead [warn] listener.onTaskStart(SparkListenerTaskStart(dropped.stageId, dropped.attemptId, task)) [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala:972: value attemptId in class StageInfo is deprecated: Use attemptNumber instead [warn] listener.onTaskEnd(SparkListenerTaskEnd(dropped.stageId, dropped.attemptId, [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala:976: value attemptId in class StageInfo is deprecated: Use attemptNumber instead [warn] .taskSummary(dropped.stageId, dropped.attemptId, Array(0.25d, 0.50d, 0.75d)) [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala:1146: value attemptId in class StageInfo is deprecated: Use attemptNumber instead [warn] SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(1), null)) [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala:1150: value attemptId in class StageInfo is deprecated: Use attemptNumber instead [warn] SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(0), null)) [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala:197: method transfered in trait FileRegion is deprecated: see corresponding Javadoc for more information. [warn] while (region.transfered() < region.count()) { [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala:198: method transfered in trait FileRegion is deprecated: see corresponding Javadoc for more information. [warn] region.transferTo(byteChannel, region.transfered()) [warn] ^ ``` **sql**: ``` [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala:534: abstract type T is unchecked since it is eliminated by erasure [warn] assert(partitioning.isInstanceOf[T]) [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala:534: abstract type T is unchecked since it is eliminated by erasure [warn] assert(partitioning.isInstanceOf[T]) [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala:323: inferred existential type Option[Class[_$1]]( forSome { type _$1 }), which cannot be expressed by wildcards, should be enabled [warn] by making the implicit value scala.language.existentials visible. [warn] This can be achieved by adding the import clause 'import scala.language.existentials' [warn] or by setting the compiler option -language:existentials. [warn] See the Scaladoc for value scala.language.existentials for a discussion [warn] why the feature should be explicitly enabled. [warn] val optClass = Option(collectionCls) [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java:226: warning: [deprecation] ParquetFileReader(Configuration,FileMetaData,Path,List,List) in ParquetFileReader has been deprecated [warn] this.reader = new ParquetFileReader( [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java:178: warning: [deprecation] getType() in ColumnDescriptor has been deprecated [warn] (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT32 || [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java:179: warning: [deprecation] getType() in ColumnDescriptor has been deprecated [warn] (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT64 && [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java:181: warning: [deprecation] getType() in ColumnDescriptor has been deprecated [warn] descriptor.getType() == PrimitiveType.PrimitiveTypeName.FLOAT || [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java:182: warning: [deprecation] getType() in ColumnDescriptor has been deprecated [warn] descriptor.getType() == PrimitiveType.PrimitiveTypeName.DOUBLE || [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java:183: warning: [deprecation] getType() in ColumnDescriptor has been deprecated [warn] descriptor.getType() == PrimitiveType.PrimitiveTypeName.BINARY))) { [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java:198: warning: [deprecation] getType() in ColumnDescriptor has been deprecated [warn] switch (descriptor.getType()) { [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java:221: warning: [deprecation] getTypeLength() in ColumnDescriptor has been deprecated [warn] readFixedLenByteArrayBatch(rowId, num, column, descriptor.getTypeLength()); [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java:224: warning: [deprecation] getType() in ColumnDescriptor has been deprecated [warn] throw new IOException("Unsupported type: " + descriptor.getType()); [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java:246: warning: [deprecation] getType() in ColumnDescriptor has been deprecated [warn] descriptor.getType().toString(), [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java:258: warning: [deprecation] getType() in ColumnDescriptor has been deprecated [warn] switch (descriptor.getType()) { [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java:384: warning: [deprecation] getType() in ColumnDescriptor has been deprecated [warn] throw new UnsupportedOperationException("Unsupported type: " + descriptor.getType()); [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java:458: warning: [static] static variable should be qualified by type name, BaseRepeatedValueVector, instead of by an expression [warn] int index = rowId * accessor.OFFSET_WIDTH; [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java:460: warning: [static] static variable should be qualified by type name, BaseRepeatedValueVector, instead of by an expression [warn] int end = offsets.getInt(index + accessor.OFFSET_WIDTH); [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala:57: a pure expression does nothing in statement position; you may be omitting necessary parentheses [warn] case s => s [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala:182: inferred existential type org.apache.parquet.column.statistics.Statistics[?0]( forSome { type ?0 <: Comparable[?0] }), which cannot be expressed by wildcards, should be enabled [warn] by making the implicit value scala.language.existentials visible. [warn] This can be achieved by adding the import clause 'import scala.language.existentials' [warn] or by setting the compiler option -language:existentials. [warn] See the Scaladoc for value scala.language.existentials for a discussion [warn] why the feature should be explicitly enabled. [warn] val columnStats = oneBlockColumnMeta.getStatistics [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala:146: implicit conversion method conv should be enabled [warn] by making the implicit value scala.language.implicitConversions visible. [warn] This can be achieved by adding the import clause 'import scala.language.implicitConversions' [warn] or by setting the compiler option -language:implicitConversions. [warn] See the Scaladoc for value scala.language.implicitConversions for a discussion [warn] why the feature should be explicitly enabled. [warn] implicit def conv(x: (Int, Long)): KV = KV(x._1, x._2) [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala:48: implicit conversion method unsafeRow should be enabled [warn] by making the implicit value scala.language.implicitConversions visible. [warn] private implicit def unsafeRow(value: Int) = { [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala:178: method getType in class ColumnDescriptor is deprecated: see corresponding Javadoc for more information. [warn] assert(oneFooter.getFileMetaData.getSchema.getColumns.get(0).getType() === [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala:154: method readAllFootersInParallel in object ParquetFileReader is deprecated: see corresponding Javadoc for more information. [warn] ParquetFileReader.readAllFootersInParallel(configuration, fs.getFileStatus(path)).asScala.toSeq [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java:679: warning: [cast] redundant cast to Complex [warn] Complex typedOther = (Complex)other; [warn] ^ ``` **mllib**: ``` [warn] Pruning sources from previous analysis, due to incompatible CompileSetup. [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala:597: match may not be exhaustive. [warn] It would fail on the following inputs: None, Some((x: Tuple2[?, ?] forSome x not in (?, ?))) [warn] val df = dfs.find { [warn] ^ ``` This PR does not target fix all of them since some look pretty tricky to fix and there look too many warnings including false positive (like deprecated API but it's used in its test, etc.) ## How was this patch tested? Existing tests should cover this. Author: hyukjinkwon Closes #21975 from HyukjinKwon/remove-build-warnings. --- .../apache/spark/util/kvstore/LevelDB.java | 2 +- .../spark/network/client/TransportClient.java | 2 +- .../client/TransportResponseHandler.java | 4 +-- .../spark/network/crypto/TransportCipher.java | 2 +- .../spark/network/sasl/SaslEncryption.java | 2 +- .../server/TransportRequestHandler.java | 4 +-- .../spark/network/server/TransportServer.java | 8 +++--- .../apache/spark/network/util/NettyUtils.java | 27 ++++++------------- .../apache/spark/network/ProtocolSuite.java | 4 +-- .../unsafe/hash/Murmur3_x86_32Suite.java | 6 ++--- .../org/apache/spark/api/r/RBackend.scala | 6 ++--- .../org/apache/spark/LocalSparkContext.scala | 2 +- .../spark/status/AppStatusListenerSuite.scala | 14 +++++----- .../apache/spark/storage/DiskStoreSuite.scala | 4 +-- .../spark/util/ClosureCleanerSuite.scala | 2 ++ .../spark/launcher/AbstractLauncher.java | 2 +- .../spark/ml/recommendation/ALSSuite.scala | 9 ++++--- .../sql/catalyst/analysis/AnalysisSuite.scala | 8 ++++-- .../expressions/ObjectExpressionsSuite.scala | 1 + .../parquet/VectorizedColumnReader.java | 26 ++++++++++-------- .../sql/vectorized/ArrowColumnVector.java | 4 +-- .../apache/spark/sql/BenchmarkQueryTest.scala | 2 +- .../ParquetInteroperabilitySuite.scala | 7 +++-- .../sources/ForeachBatchSinkSuite.scala | 1 + .../shuffle/ContinuousShuffleSuite.scala | 2 ++ .../apache/spark/sql/hive/test/Complex.java | 2 +- 26 files changed, 80 insertions(+), 73 deletions(-) diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java index 0e491efac9181..58e2a8f25f34f 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java @@ -234,7 +234,7 @@ public void close() throws IOException { * Closes the given iterator if the DB is still open. Trying to close a JNI LevelDB handle * with a closed DB can cause JVM crashes, so this ensures that situation does not happen. */ - void closeIterator(LevelDBIterator it) throws IOException { + void closeIterator(LevelDBIterator it) throws IOException { synchronized (this._db) { DB _db = this._db.get(); if (_db != null) { diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 325225dc0ea2c..20d840baeaf6c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -318,7 +318,7 @@ private class StdChannelListener } @Override - public void operationComplete(Future future) throws Exception { + public void operationComplete(Future future) throws Exception { if (future.isSuccess()) { if (logger.isTraceEnabled()) { long timeTaken = System.currentTimeMillis() - startTime; diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 7a3d96ceaef0c..596b0ea5dba9b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -212,8 +212,8 @@ public void handle(ResponseMessage message) throws Exception { if (entry != null) { StreamCallback callback = entry.getValue(); if (resp.byteCount > 0) { - StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, - callback); + StreamInterceptor interceptor = new StreamInterceptor<>( + this, resp.streamId, resp.byteCount, callback); try { TransportFrameDecoder frameDecoder = (TransportFrameDecoder) channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java index e04524dde0a75..452408df19061 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java @@ -267,7 +267,7 @@ private void encryptMore() throws IOException { int copied = byteRawChannel.write(buf.nioBuffer()); buf.skipBytes(copied); } else { - region.transferTo(byteRawChannel, region.transfered()); + region.transferTo(byteRawChannel, region.transferred()); } cos.write(byteRawChannel.getData(), 0, byteRawChannel.length()); cos.flush(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java index d3b2a334baadd..1dcf1324839eb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java @@ -301,7 +301,7 @@ private void nextChunk() throws IOException { int copied = byteChannel.write(buf.nioBuffer()); buf.skipBytes(copied); } else { - region.transferTo(byteChannel, region.transfered()); + region.transferTo(byteChannel, region.transferred()); } byte[] encrypted = backend.wrap(byteChannel.getData(), 0, byteChannel.length()); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index e1d7b2dbff60f..c6fd56b9291e5 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -252,8 +252,8 @@ public String getID() { } }; if (req.bodyByteCount > 0) { - StreamInterceptor interceptor = new StreamInterceptor(this, wrappedCallback.getID(), - req.bodyByteCount, wrappedCallback); + StreamInterceptor interceptor = new StreamInterceptor<>( + this, wrappedCallback.getID(), req.bodyByteCount, wrappedCallback); frameDecoder.setInterceptor(interceptor); } else { wrappedCallback.onComplete(wrappedCallback.getID()); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java index 60f51125c07fd..d95ed22912507 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -148,11 +148,11 @@ public void close() { channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS); channelFuture = null; } - if (bootstrap != null && bootstrap.group() != null) { - bootstrap.group().shutdownGracefully(); + if (bootstrap != null && bootstrap.config().group() != null) { + bootstrap.config().group().shutdownGracefully(); } - if (bootstrap != null && bootstrap.childGroup() != null) { - bootstrap.childGroup().shutdownGracefully(); + if (bootstrap != null && bootstrap.config().childGroup() != null) { + bootstrap.config().childGroup().shutdownGracefully(); } bootstrap = null; } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java index 5e85180bd6f9f..33d6eb4a83a0c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -17,7 +17,6 @@ package org.apache.spark.network.util; -import java.lang.reflect.Field; import java.util.concurrent.ThreadFactory; import io.netty.buffer.PooledByteBufAllocator; @@ -111,24 +110,14 @@ public static PooledByteBufAllocator createPooledByteBufAllocator( } return new PooledByteBufAllocator( allowDirectBufs && PlatformDependent.directBufferPreferred(), - Math.min(getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), numCores), - Math.min(getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), allowDirectBufs ? numCores : 0), - getPrivateStaticField("DEFAULT_PAGE_SIZE"), - getPrivateStaticField("DEFAULT_MAX_ORDER"), - allowCache ? getPrivateStaticField("DEFAULT_TINY_CACHE_SIZE") : 0, - allowCache ? getPrivateStaticField("DEFAULT_SMALL_CACHE_SIZE") : 0, - allowCache ? getPrivateStaticField("DEFAULT_NORMAL_CACHE_SIZE") : 0 + Math.min(PooledByteBufAllocator.defaultNumHeapArena(), numCores), + Math.min(PooledByteBufAllocator.defaultNumDirectArena(), allowDirectBufs ? numCores : 0), + PooledByteBufAllocator.defaultPageSize(), + PooledByteBufAllocator.defaultMaxOrder(), + allowCache ? PooledByteBufAllocator.defaultTinyCacheSize() : 0, + allowCache ? PooledByteBufAllocator.defaultSmallCacheSize() : 0, + allowCache ? PooledByteBufAllocator.defaultNormalCacheSize() : 0, + allowCache ? PooledByteBufAllocator.defaultUseCacheForAllThreads() : false ); } - - /** Used to get defaults from Netty's private static fields. */ - private static int getPrivateStaticField(String name) { - try { - Field f = PooledByteBufAllocator.DEFAULT.getClass().getDeclaredField(name); - f.setAccessible(true); - return f.getInt(null); - } catch (Exception e) { - throw new RuntimeException(e); - } - } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java index bc94f7ca63a96..6fb44fea8c5a4 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -116,8 +116,8 @@ public void encode(ChannelHandlerContext ctx, FileRegion in, List out) throws Exception { ByteArrayWritableChannel channel = new ByteArrayWritableChannel(Ints.checkedCast(in.count())); - while (in.transfered() < in.count()) { - in.transferTo(channel, in.transfered()); + while (in.transferred() < in.count()) { + in.transferTo(channel, in.transferred()); } out.add(Unpooled.wrappedBuffer(channel.getData())); } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java index d7ed005db1891..d9898771720ae 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -77,15 +77,15 @@ public void testKnownWordsInputs() { for (int i = 0; i < 16; i++) { bytes[i] = 0; } - Assert.assertEquals(-300363099, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + Assert.assertEquals(-300363099, Murmur3_x86_32.hashUnsafeWords(bytes, offset, 16, 42)); for (int i = 0; i < 16; i++) { bytes[i] = -1; } - Assert.assertEquals(-1210324667, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + Assert.assertEquals(-1210324667, Murmur3_x86_32.hashUnsafeWords(bytes, offset, 16, 42)); for (int i = 0; i < 16; i++) { bytes[i] = (byte)i; } - Assert.assertEquals(-634919701, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + Assert.assertEquals(-634919701, Murmur3_x86_32.hashUnsafeWords(bytes, offset, 16, 42)); } @Test diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 3b2e809408e0f..7ce2581555014 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -96,11 +96,11 @@ private[spark] class RBackend { channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS) channelFuture = null } - if (bootstrap != null && bootstrap.group() != null) { - bootstrap.group().shutdownGracefully() + if (bootstrap != null && bootstrap.config().group() != null) { + bootstrap.config().group().shutdownGracefully() } if (bootstrap != null && bootstrap.childGroup() != null) { - bootstrap.childGroup().shutdownGracefully() + bootstrap.config().childGroup().shutdownGracefully() } bootstrap = null jvmObjectTracker.clear() diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 1dd89bcbe36bc..05aaaa11451b4 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -29,7 +29,7 @@ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self override def beforeAll() { super.beforeAll() - InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()) + InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE) } override def afterEach() { diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 1cd71955ad4d9..1b3639ad64a73 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -215,7 +215,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { check[TaskDataWrapper](task.taskId) { wrapper => assert(wrapper.taskId === task.taskId) assert(wrapper.stageId === stages.head.stageId) - assert(wrapper.stageAttemptId === stages.head.attemptId) + assert(wrapper.stageAttemptId === stages.head.attemptNumber) assert(wrapper.index === task.index) assert(wrapper.attempt === task.attemptNumber) assert(wrapper.launchTime === task.launchTime) @@ -258,7 +258,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { executorId = execIds.head, taskFailures = 2, stageId = stages.head.stageId, - stageAttemptId = stages.head.attemptId)) + stageAttemptId = stages.head.attemptNumber)) val executorStageSummaryWrappers = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") @@ -284,7 +284,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { hostId = "2.example.com", // this is where the second executor is hosted executorFailures = 1, stageId = stages.head.stageId, - stageAttemptId = stages.head.attemptId)) + stageAttemptId = stages.head.attemptNumber)) val executorStageSummaryWrappersForNode = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") @@ -468,7 +468,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { hostId = "1.example.com", executorFailures = 1, stageId = stages.last.stageId, - stageAttemptId = stages.last.attemptId)) + stageAttemptId = stages.last.attemptNumber)) check[ExecutorSummaryWrapper](execIds.head) { exec => assert(exec.info.blacklistedInStages === Set(stages.last.stageId)) @@ -963,17 +963,17 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // task end event. time += 1 val task = createTasks(1, Array("1")).head - listener.onTaskStart(SparkListenerTaskStart(dropped.stageId, dropped.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(dropped.stageId, dropped.attemptNumber, task)) time += 1 task.markFinished(TaskState.FINISHED, time) val metrics = TaskMetrics.empty metrics.setExecutorRunTime(42L) - listener.onTaskEnd(SparkListenerTaskEnd(dropped.stageId, dropped.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(dropped.stageId, dropped.attemptNumber, "taskType", Success, task, metrics)) new AppStatusStore(store) - .taskSummary(dropped.stageId, dropped.attemptId, Array(0.25d, 0.50d, 0.75d)) + .taskSummary(dropped.stageId, dropped.attemptNumber, Array(0.25d, 0.50d, 0.75d)) assert(store.count(classOf[CachedQuantile], "stage", key(dropped)) === 3) stages.drop(1).foreach { s => diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala index efdd02fff7871..2f880a3be33d3 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala @@ -194,8 +194,8 @@ class DiskStoreSuite extends SparkFunSuite { val region = data.toNetty().asInstanceOf[FileRegion] val byteChannel = new ByteArrayWritableChannel(data.size.toInt) - while (region.transfered() < region.count()) { - region.transferTo(byteChannel, region.transfered()) + while (region.transferred() < region.count()) { + region.transferTo(byteChannel, region.transferred()) } byteChannel.close() diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index a0010f18c18a1..3c6660800f170 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.util import java.io.NotSerializableException +import scala.language.reflectiveCalls + import org.apache.spark.{SparkContext, SparkException, SparkFunSuite, TaskContext} import org.apache.spark.LocalSparkContext._ import org.apache.spark.partial.CountEvaluator diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java index 4e02843480e8f..8a1256f73416e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java @@ -28,7 +28,7 @@ * * @since Spark 2.3.0 */ -public abstract class AbstractLauncher { +public abstract class AbstractLauncher> { final SparkSubmitCommandBuilder builder; diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 65bee4edc4965..9a59c41740daf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -594,11 +594,12 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { (check: (ALSModel, ALSModel) => Unit) (check2: (ALSModel, ALSModel, DataFrame, Encoder[_]) => Unit): Unit = { val dfs = genRatingsDFWithNumericCols(spark, column) - val df = dfs.find { - case (numericTypeWithEncoder, _) => numericTypeWithEncoder.numericType == baseType - } match { - case Some((_, df)) => df + val maybeDf = dfs.find { case (numericTypeWithEncoder, _) => + numericTypeWithEncoder.numericType == baseType } + assert(maybeDf.isDefined) + val df = maybeDf.get._2 + val expected = estimator.fit(df) val actuals = dfs.filter(_ != baseType).map(t => (t, estimator.fit(t._2))) actuals.foreach { case (_, actual) => check(expected, actual) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 9fb50a5e565e0..ac33248241a25 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.analysis import java.util.TimeZone +import scala.reflect.ClassTag + import org.scalatest.Matchers import org.apache.spark.api.python.PythonEvalType @@ -529,9 +531,11 @@ class AnalysisSuite extends AnalysisTest with Matchers { } test("SPARK-22614 RepartitionByExpression partitioning") { - def checkPartitioning[T <: Partitioning](numPartitions: Int, exprs: Expression*): Unit = { + def checkPartitioning[T <: Partitioning: ClassTag]( + numPartitions: Int, exprs: Expression*): Unit = { val partitioning = RepartitionByExpression(exprs, testRelation2, numPartitions).partitioning - assert(partitioning.isInstanceOf[T]) + val clazz = implicitly[ClassTag[T]].runtimeClass + assert(clazz.isInstance(partitioning)) } checkPartitioning[HashPartitioning](numPartitions = 10, exprs = Literal(20)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 20d568c44258f..b0af9e07d1d1d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ +import scala.language.existentials import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import scala.util.Random diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 31ef090ac4b45..ba26b57567e64 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -167,6 +167,8 @@ void readBatch(int total, WritableColumnVector column) throws IOException { leftInPage = (int) (endOfPageValueCount - valuesRead); } int num = Math.min(total, leftInPage); + PrimitiveType.PrimitiveTypeName typeName = + descriptor.getPrimitiveType().getPrimitiveTypeName(); if (isCurrentPageDictionaryEncoded) { // Read and decode dictionary ids. defColumn.readIntegers( @@ -175,12 +177,12 @@ void readBatch(int total, WritableColumnVector column) throws IOException { // TIMESTAMP_MILLIS encoded as INT64 can't be lazily decoded as we need to post process // the values to add microseconds precision. if (column.hasDictionary() || (rowId == 0 && - (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT32 || - (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT64 && + (typeName == PrimitiveType.PrimitiveTypeName.INT32 || + (typeName == PrimitiveType.PrimitiveTypeName.INT64 && originalType != OriginalType.TIMESTAMP_MILLIS) || - descriptor.getType() == PrimitiveType.PrimitiveTypeName.FLOAT || - descriptor.getType() == PrimitiveType.PrimitiveTypeName.DOUBLE || - descriptor.getType() == PrimitiveType.PrimitiveTypeName.BINARY))) { + typeName == PrimitiveType.PrimitiveTypeName.FLOAT || + typeName == PrimitiveType.PrimitiveTypeName.DOUBLE || + typeName == PrimitiveType.PrimitiveTypeName.BINARY))) { // Column vector supports lazy decoding of dictionary values so just set the dictionary. // We can't do this if rowId != 0 AND the column doesn't have a dictionary (i.e. some // non-dictionary encoded values have already been added). @@ -195,7 +197,7 @@ void readBatch(int total, WritableColumnVector column) throws IOException { decodeDictionaryIds(0, rowId, column, column.getDictionaryIds()); } column.setDictionary(null); - switch (descriptor.getType()) { + switch (typeName) { case BOOLEAN: readBooleanBatch(rowId, num, column); break; @@ -218,10 +220,11 @@ void readBatch(int total, WritableColumnVector column) throws IOException { readBinaryBatch(rowId, num, column); break; case FIXED_LEN_BYTE_ARRAY: - readFixedLenByteArrayBatch(rowId, num, column, descriptor.getTypeLength()); + readFixedLenByteArrayBatch( + rowId, num, column, descriptor.getPrimitiveType().getTypeLength()); break; default: - throw new IOException("Unsupported type: " + descriptor.getType()); + throw new IOException("Unsupported type: " + typeName); } } @@ -243,7 +246,7 @@ private SchemaColumnConvertNotSupportedException constructConvertNotSupportedExc WritableColumnVector column) { return new SchemaColumnConvertNotSupportedException( Arrays.toString(descriptor.getPath()), - descriptor.getType().toString(), + descriptor.getPrimitiveType().getPrimitiveTypeName().toString(), column.dataType().catalogString()); } @@ -255,7 +258,7 @@ private void decodeDictionaryIds( int num, WritableColumnVector column, WritableColumnVector dictionaryIds) { - switch (descriptor.getType()) { + switch (descriptor.getPrimitiveType().getPrimitiveTypeName()) { case INT32: if (column.dataType() == DataTypes.IntegerType || DecimalType.is32BitDecimalType(column.dataType())) { @@ -381,7 +384,8 @@ private void decodeDictionaryIds( break; default: - throw new UnsupportedOperationException("Unsupported type: " + descriptor.getType()); + throw new UnsupportedOperationException( + "Unsupported type: " + descriptor.getPrimitiveType().getPrimitiveTypeName()); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 227a16f7e69e9..5aed87f88a298 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -455,9 +455,9 @@ final boolean isNullAt(int rowId) { @Override final ColumnarArray getArray(int rowId) { ArrowBuf offsets = accessor.getOffsetBuffer(); - int index = rowId * accessor.OFFSET_WIDTH; + int index = rowId * ListVector.OFFSET_WIDTH; int start = offsets.getInt(index); - int end = offsets.getInt(index + accessor.OFFSET_WIDTH); + int end = offsets.getInt(index + ListVector.OFFSET_WIDTH); return new ColumnarArray(arrayData, start, end - start); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala index e51aad021fcbf..d95794d624033 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala @@ -54,7 +54,7 @@ abstract class BenchmarkQueryTest extends QueryTest with SharedSQLContext with B plan foreach { case s: WholeStageCodegenExec => codegenSubtrees += s - case s => s + case _ => } codegenSubtrees.toSeq.foreach { subtree => val code = subtree.doCodeGen()._2 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala index 9c75965639d8a..f06e1867151e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File +import scala.language.existentials + import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER @@ -175,8 +177,9 @@ class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedS val oneFooter = ParquetFileReader.readFooter(hadoopConf, part.getPath, NO_FILTER) assert(oneFooter.getFileMetaData.getSchema.getColumns.size === 1) - assert(oneFooter.getFileMetaData.getSchema.getColumns.get(0).getType() === - PrimitiveTypeName.INT96) + val typeName = oneFooter + .getFileMetaData.getSchema.getColumns.get(0).getPrimitiveType.getPrimitiveTypeName + assert(typeName === PrimitiveTypeName.INT96) val oneBlockMeta = oneFooter.getBlocks().get(0) val oneBlockColumnMeta = oneBlockMeta.getColumns().get(0) val columnStats = oneBlockColumnMeta.getStatistics diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala index a4233e15e4ffd..71dff443e8836 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming.sources import scala.collection.mutable +import scala.language.implicitConversions import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.MemoryStream diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala index f84f3d49707bf..b42f8267916b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle import java.util.UUID +import scala.language.implicitConversions + import org.apache.spark.{HashPartitioner, Partition, TaskContext, TaskContextImpl} import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java index a8cbd4fab15bb..48891fdcb1d80 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java @@ -676,7 +676,7 @@ public int compareTo(Complex other) { } int lastComparison = 0; - Complex typedOther = (Complex)other; + Complex typedOther = other; lastComparison = Boolean.valueOf(isSetAint()).compareTo(typedOther.isSetAint()); if (lastComparison != 0) { From b7fdf8eb2011ae76f0161caa9da91e29f52f05e4 Mon Sep 17 00:00:00 2001 From: Yuval Itzchakov Date: Sat, 4 Aug 2018 14:44:10 -0500 Subject: [PATCH 1309/2461] [SPARK-24987][SS] - Fix Kafka consumer leak when no new offsets for TopicPartition ## What changes were proposed in this pull request? This small fix adds a `consumer.release()` call to `KafkaSourceRDD` in the case where we've retrieved offsets from Kafka, but the `fromOffset` is equal to the `lastOffset`, meaning there is no new data to read for a particular topic partition. Up until now, we'd just return an empty iterator without closing the consumer which would cause a FD leak. If accepted, this pull request should be merged into master as well. ## How was this patch tested? Haven't ran any specific tests, would love help on how to test methods running inside `RDD.compute`. Author: Yuval Itzchakov Closes #21997 from YuvalItzchakov/master. --- .../scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala index 53bd9a96d1d68..8b4494d2e9a25 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -124,8 +124,6 @@ private[kafka010] class KafkaSourceRDD( thePart: Partition, context: TaskContext): Iterator[ConsumerRecord[Array[Byte], Array[Byte]]] = { val sourcePartition = thePart.asInstanceOf[KafkaSourceRDDPartition] - val topic = sourcePartition.offsetRange.topic - val kafkaPartition = sourcePartition.offsetRange.partition val consumer = KafkaDataConsumer.acquire( sourcePartition.offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) @@ -138,6 +136,7 @@ private[kafka010] class KafkaSourceRDD( if (range.fromOffset == range.untilOffset) { logInfo(s"Beginning offset ${range.fromOffset} is the same as ending offset " + s"skipping ${range.topic} ${range.partition}") + consumer.release() Iterator.empty } else { val underlying = new NextIterator[ConsumerRecord[Array[Byte], Array[Byte]]]() { From 5f9633dc97ad5f78dd17cad39945ea32f3441f06 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 4 Aug 2018 14:59:13 -0500 Subject: [PATCH 1310/2461] [SPARK-25015][BUILD] Update Hadoop 2.7 to 2.7.7 ## What changes were proposed in this pull request? Update Hadoop 2.7 to 2.7.7 to pull in bug and security fixes. ## How was this patch tested? Existing tests. Author: Sean Owen Closes #21987 from srowen/SPARK-25015. --- assembly/README | 2 +- dev/deps/spark-deps-hadoop-2.7 | 31 ++++++++++++++++--------------- docs/building-spark.md | 2 +- pom.xml | 2 +- 4 files changed, 19 insertions(+), 18 deletions(-) diff --git a/assembly/README b/assembly/README index d5dafab477410..affd281a1385c 100644 --- a/assembly/README +++ b/assembly/README @@ -9,4 +9,4 @@ This module is off by default. To activate it specify the profile in the command If you need to build an assembly for a different version of Hadoop the hadoop-version system property needs to be set as in this example: - -Dhadoop.version=2.7.3 + -Dhadoop.version=2.7.7 diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index fda13db52ba3d..113639946f7d6 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -66,21 +66,21 @@ gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar guice-servlet-3.0.jar -hadoop-annotations-2.7.3.jar -hadoop-auth-2.7.3.jar -hadoop-client-2.7.3.jar -hadoop-common-2.7.3.jar -hadoop-hdfs-2.7.3.jar -hadoop-mapreduce-client-app-2.7.3.jar -hadoop-mapreduce-client-common-2.7.3.jar -hadoop-mapreduce-client-core-2.7.3.jar -hadoop-mapreduce-client-jobclient-2.7.3.jar -hadoop-mapreduce-client-shuffle-2.7.3.jar -hadoop-yarn-api-2.7.3.jar -hadoop-yarn-client-2.7.3.jar -hadoop-yarn-common-2.7.3.jar -hadoop-yarn-server-common-2.7.3.jar -hadoop-yarn-server-web-proxy-2.7.3.jar +hadoop-annotations-2.7.7.jar +hadoop-auth-2.7.7.jar +hadoop-client-2.7.7.jar +hadoop-common-2.7.7.jar +hadoop-hdfs-2.7.7.jar +hadoop-mapreduce-client-app-2.7.7.jar +hadoop-mapreduce-client-common-2.7.7.jar +hadoop-mapreduce-client-core-2.7.7.jar +hadoop-mapreduce-client-jobclient-2.7.7.jar +hadoop-mapreduce-client-shuffle-2.7.7.jar +hadoop-yarn-api-2.7.7.jar +hadoop-yarn-client-2.7.7.jar +hadoop-yarn-common-2.7.7.jar +hadoop-yarn-server-common-2.7.7.jar +hadoop-yarn-server-web-proxy-2.7.7.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar @@ -121,6 +121,7 @@ jersey-media-jaxb-2.22.2.jar jersey-server-2.22.2.jar jets3t-0.9.4.jar jetty-6.1.26.jar +jetty-sslengine-6.1.26.jar jetty-util-6.1.26.jar jline-2.14.3.jar joda-time-2.9.3.jar diff --git a/docs/building-spark.md b/docs/building-spark.md index c3bcd90ccc78f..affd7df17b001 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -67,7 +67,7 @@ Examples: ./build/mvn -Pyarn -DskipTests clean package # Apache Hadoop 2.7.X and later - ./build/mvn -Pyarn -Phadoop-2.7 -Dhadoop.version=2.7.3 -DskipTests clean package + ./build/mvn -Pyarn -Phadoop-2.7 -Dhadoop.version=2.7.7 -DskipTests clean package ## Building With Hive and JDBC Support diff --git a/pom.xml b/pom.xml index be84661a50dcc..c46eb31715eab 100644 --- a/pom.xml +++ b/pom.xml @@ -2681,7 +2681,7 @@ hadoop-2.7 - 2.7.3 + 2.7.7 2.7.1 From 327bb30075834c873cdb78061c9b647e5e13b8a6 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sun, 5 Aug 2018 08:58:35 +0900 Subject: [PATCH 1311/2461] [SPARK-23911][SQL] Add aggregate function. ## What changes were proposed in this pull request? This pr adds `aggregate` function which applies a binary operator to an initial state and all elements in the array, and reduces this to a single state. The final state is converted into the final result by applying a finish function. ```sql > SELECT aggregate(array(1, 2, 3), (acc, x) -> acc + x); 6 > SELECT aggregate(array(1, 2, 3), (acc, x) -> acc + x, acc -> acc * 10); 60 ``` ## How was this patch tested? Added tests. Author: Takuya UESHIN Closes #21982 from ueshin/issues/SPARK-23911/aggregate. --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/higherOrderFunctions.scala | 95 ++++++++++++++ .../HigherOrderFunctionsSuite.scala | 50 ++++++++ .../inputs/higher-order-functions.sql | 12 ++ .../results/higher-order-functions.sql.out | 40 +++++- .../spark/sql/DataFrameFunctionsSuite.scala | 121 ++++++++++++++++++ 6 files changed, 318 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d0efe975f81ce..35f8de1328b50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -442,6 +442,7 @@ object FunctionRegistry { expression[ArrayDistinct]("array_distinct"), expression[ArrayTransform]("transform"), expression[ArrayFilter]("filter"), + expression[ArrayAggregate]("aggregate"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index e15225ffbd2d2..20c7f7d43b9dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicReference import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} @@ -76,6 +77,13 @@ case class LambdaFunction( override def eval(input: InternalRow): Any = function.eval(input) } +object LambdaFunction { + val identity: LambdaFunction = { + val id = UnresolvedAttribute.quoted("id") + LambdaFunction(id, Seq(id)) + } +} + /** * A higher order function takes one or more (lambda) functions and applies these to some objects. * The function produces a number of variables which can be consumed by some lambda function. @@ -270,3 +278,90 @@ case class ArrayFilter( override def prettyName: String = "filter" } + +/** + * Applies a binary operator to a start value and all elements in the array. + */ +@ExpressionDescription( + usage = + """ + _FUNC_(expr, start, merge, finish) - Applies a binary operator to an initial state and all + elements in the array, and reduces this to a single state. The final state is converted + into the final result by applying a finish function. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), (acc, x) -> acc + x); + 6 + > SELECT _FUNC_(array(1, 2, 3), (acc, x) -> acc + x, acc -> acc * 10); + 60 + """, + since = "2.4.0") +case class ArrayAggregate( + input: Expression, + zero: Expression, + merge: Expression, + finish: Expression) + extends HigherOrderFunction with CodegenFallback { + + def this(input: Expression, zero: Expression, merge: Expression) = { + this(input, zero, merge, LambdaFunction.identity) + } + + override def inputs: Seq[Expression] = input :: zero :: Nil + + override def functions: Seq[Expression] = merge :: finish :: Nil + + override def nullable: Boolean = input.nullable || finish.nullable + + override def dataType: DataType = finish.dataType + + override def checkInputDataTypes(): TypeCheckResult = { + if (!ArrayType.acceptsType(input.dataType)) { + TypeCheckResult.TypeCheckFailure( + s"argument 1 requires ${ArrayType.simpleString} type, " + + s"however, '${input.sql}' is of ${input.dataType.catalogString} type.") + } else if (!DataType.equalsStructurally( + zero.dataType, merge.dataType, ignoreNullability = true)) { + TypeCheckResult.TypeCheckFailure( + s"argument 3 requires ${zero.dataType.simpleString} type, " + + s"however, '${merge.sql}' is of ${merge.dataType.catalogString} type.") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayAggregate = { + // Be very conservative with nullable. We cannot be sure that the accumulator does not + // evaluate to null. So we always set nullable to true here. + val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType) + val acc = zero.dataType -> true + val newMerge = f(merge, acc :: elem :: Nil) + val newFinish = f(finish, acc :: Nil) + copy(merge = newMerge, finish = newFinish) + } + + @transient lazy val LambdaFunction(_, + Seq(accForMergeVar: NamedLambdaVariable, elementVar: NamedLambdaVariable), _) = merge + @transient lazy val LambdaFunction(_, Seq(accForFinishVar: NamedLambdaVariable), _) = finish + + override def eval(input: InternalRow): Any = { + val arr = this.input.eval(input).asInstanceOf[ArrayData] + if (arr == null) { + null + } else { + val Seq(mergeForEval, finishForEval) = functionsForEval + accForMergeVar.value.set(zero.eval(input)) + var i = 0 + while (i < arr.numElements()) { + elementVar.value.set(arr.get(i, elementVar.dataType)) + accForMergeVar.value.set(mergeForEval.eval(input)) + i += 1 + } + accForFinishVar.value.set(accForMergeVar.value.get) + finishForEval.eval(input) + } + } + + override def prettyName: String = "aggregate" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index d1330c7aad219..40cfc0ccc7c07 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -59,6 +59,27 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayFilter(expr, createLambda(at.elementType, at.containsNull, f)) } + def aggregate( + expr: Expression, + zero: Expression, + merge: (Expression, Expression) => Expression, + finish: Expression => Expression): Expression = { + val at = expr.dataType.asInstanceOf[ArrayType] + val zeroType = zero.dataType + ArrayAggregate( + expr, + zero, + createLambda(zeroType, true, at.elementType, at.containsNull, merge), + createLambda(zeroType, true, finish)) + } + + def aggregate( + expr: Expression, + zero: Expression, + merge: (Expression, Expression) => Expression): Expression = { + aggregate(expr, zero, merge, identity) + } + test("ArrayTransform") { val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) @@ -131,4 +152,33 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(transform(aai, ix => filter(ix, isNullOrOdd)), Seq(Seq(1, 3), null, Seq(5))) } + + test("ArrayAggregate") { + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) + val ai2 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false)) + val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) + + checkEvaluation(aggregate(ai0, 0, (acc, elem) => acc + elem, acc => acc * 10), 60) + checkEvaluation(aggregate(ai1, 0, (acc, elem) => acc + coalesce(elem, 0), acc => acc * 10), 40) + checkEvaluation(aggregate(ai2, 0, (acc, elem) => acc + elem, acc => acc * 10), 0) + checkEvaluation(aggregate(ain, 0, (acc, elem) => acc + elem, acc => acc * 10), null) + + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) + val as2 = Literal.create(Seq.empty[String], ArrayType(StringType, containsNull = false)) + val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) + + checkEvaluation(aggregate(as0, "", (acc, elem) => Concat(Seq(acc, elem))), "abc") + checkEvaluation(aggregate(as1, "", (acc, elem) => Concat(Seq(acc, coalesce(elem, "x")))), "axc") + checkEvaluation(aggregate(as2, "", (acc, elem) => Concat(Seq(acc, elem))), "") + checkEvaluation(aggregate(asn, "", (acc, elem) => Concat(Seq(acc, elem))), null) + + val aai = Literal.create(Seq[Seq[Integer]](Seq(1, 2, 3), null, Seq(4, 5)), + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + checkEvaluation( + aggregate(aai, 0, + (acc, array) => coalesce(aggregate(array, acc, (acc, elem) => acc + elem), acc)), + 15) + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql index f833aa5818bc1..136396d9553db 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -33,3 +33,15 @@ select filter(cast(null as array), y -> true) as v; -- Filter nested arrays select transform(zs, z -> filter(z, zz -> zz > 50)) as v from nested; + +-- Aggregate. +select aggregate(ys, 0, (y, a) -> y + a + x) as v from nested; + +-- Aggregate average. +select aggregate(ys, (0 as sum, 0 as n), (acc, x) -> (acc.sum + x, acc.n + 1), acc -> acc.sum / acc.n) as v from nested; + +-- Aggregate nested arrays +select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as v from nested; + +-- Aggregate a null array +select aggregate(cast(null as array), 0, (a, y) -> a + y + 1, a -> a + 2) as v; diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out index 4c5d972378b31..e6f62f2e1bb67 100644 --- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 11 +-- Number of queries: 15 -- !query 0 @@ -107,3 +107,41 @@ struct>> [[96,65],[]] [[99],[123],[]] [[]] + + +-- !query 11 +select aggregate(ys, 0, (y, a) -> y + a + x) as v from nested +-- !query 11 schema +struct +-- !query 11 output +131 +15 +5 + + +-- !query 12 +select aggregate(ys, (0 as sum, 0 as n), (acc, x) -> (acc.sum + x, acc.n + 1), acc -> acc.sum / acc.n) as v from nested +-- !query 12 schema +struct +-- !query 12 output +0.5 +12.0 +64.5 + + +-- !query 13 +select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as v from nested +-- !query 13 schema +struct> +-- !query 13 output +[1010880,8] +[17] +[4752,20664,1] + + +-- !query 14 +select aggregate(cast(null as array), 0, (a, y) -> a + y + 1, a -> a + 2) as v +-- !query 14 schema +struct +-- !query 14 output +NULL diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 1d5707a2c7047..af3301b1599a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1896,6 +1896,127 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) } + test("aggregate function - array for primitive type not containing null") { + val df = Seq( + Seq(1, 9, 8, 7), + Seq(5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { + checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x)"), + Seq( + Row(25), + Row(31), + Row(0), + Row(null))) + checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x, acc -> acc * 10)"), + Seq( + Row(250), + Row(310), + Row(0), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeNotContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeNotContainsNull() + } + + test("aggregate function - array for primitive type containing null") { + val df = Seq[Seq[Integer]]( + Seq(1, 9, 8, 7), + Seq(5, null, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeContainsNull(): Unit = { + checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x)"), + Seq( + Row(25), + Row(null), + Row(0), + Row(null))) + checkAnswer( + df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x, acc -> coalesce(acc, 0) * 10)"), + Seq( + Row(250), + Row(0), + Row(0), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeContainsNull() + } + + test("aggregate function - array for non-primitive type") { + val df = Seq( + (Seq("c", "a", "b"), "a"), + (Seq("b", null, "c", null), "b"), + (Seq.empty, "c"), + (null, "d") + ).toDF("ss", "s") + + def testNonPrimitiveType(): Unit = { + checkAnswer(df.selectExpr("aggregate(ss, s, (acc, x) -> concat(acc, x))"), + Seq( + Row("acab"), + Row(null), + Row("c"), + Row(null))) + checkAnswer( + df.selectExpr("aggregate(ss, s, (acc, x) -> concat(acc, x), acc -> coalesce(acc , ''))"), + Seq( + Row("acab"), + Row(""), + Row("c"), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testNonPrimitiveType() + } + + test("aggregate function - invalid") { + val df = Seq( + (Seq("c", "a", "b"), 1), + (Seq("b", null, "c", null), 2), + (Seq.empty, 3), + (null, 4) + ).toDF("s", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("aggregate(s, '', x -> x)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match")) + + val ex2 = intercept[AnalysisException] { + df.selectExpr("aggregate(s, '', (acc, x) -> x, (acc, x) -> x)") + } + assert(ex2.getMessage.contains("The number of lambda function arguments '2' does not match")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("aggregate(i, 0, (acc, x) -> x)") + } + assert(ex3.getMessage.contains("data type mismatch: argument 1 requires array type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("aggregate(s, 0, (acc, x) -> x)") + } + assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type")) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From ac527b5205ec2826677e2b7ad0d424aa976bce81 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 6 Aug 2018 15:52:01 +0800 Subject: [PATCH 1312/2461] [SPARK-24991][SQL] use InternalRow in DataSourceWriter ## What changes were proposed in this pull request? A follow up of #21118 Since we use `InternalRow` in the read API of data source v2, we should do the same thing for the write API. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #21948 from cloud-fan/row-write. --- .../sql/kafka010/KafkaStreamWriter.scala | 4 +- .../sources/v2/writer/DataSourceWriter.java | 4 +- .../sql/sources/v2/writer/DataWriter.java | 4 +- .../sources/v2/writer/DataWriterFactory.java | 5 +- .../v2/writer/SupportsWriteInternalRow.java | 41 ----------- .../datasources/v2/WriteToDataSourceV2.scala | 30 +------- .../streaming/MicroBatchExecution.scala | 10 +-- .../continuous/ContinuousWriteRDD.scala | 6 +- .../WriteToContinuousDataSourceExec.scala | 12 +--- .../streaming/sources/ConsoleWriter.scala | 11 ++- .../sources/ForeachWriterProvider.scala | 10 +-- .../streaming/sources/MicroBatchWriter.scala | 21 +----- .../sources/PackedRowWriterFactory.scala | 15 ++-- .../streaming/sources/memoryV2.scala | 33 +++++---- .../streaming/MemorySinkV2Suite.scala | 18 +++-- .../sql/sources/v2/DataSourceV2Suite.scala | 7 -- .../sources/v2/SimpleWritableDataSource.scala | 72 ++----------------- 17 files changed, 73 insertions(+), 230 deletions(-) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala index 32923dc9f5a6b..5f0802b466039 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala @@ -42,11 +42,11 @@ case object KafkaWriterCommitMessage extends WriterCommitMessage */ class KafkaStreamWriter( topic: Option[String], producerParams: Map[String, String], schema: StructType) - extends StreamWriter with SupportsWriteInternalRow { + extends StreamWriter { validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) - override def createInternalRowWriterFactory(): KafkaStreamWriterFactory = + override def createWriterFactory(): KafkaStreamWriterFactory = KafkaStreamWriterFactory(topic, producerParams, schema) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index 7eedc85a5d6f3..385fc294fea82 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -18,8 +18,8 @@ package org.apache.spark.sql.sources.v2.writer; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.Row; import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.StreamWriteSupport; import org.apache.spark.sql.sources.v2.WriteSupport; @@ -61,7 +61,7 @@ public interface DataSourceWriter { * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ - DataWriterFactory createWriterFactory(); + DataWriterFactory createWriterFactory(); /** * Returns whether Spark should use the commit coordinator to ensure that at most one task for diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 1626c0013e4e7..27dc5ea224fe2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -53,9 +53,7 @@ * successfully, and have a way to revert committed data writers without the commit message, because * Spark only accepts the commit message that arrives first and ignore others. * - * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data - * source writers, or {@link org.apache.spark.sql.catalyst.InternalRow} for data source writers - * that mix in {@link SupportsWriteInternalRow}. + * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}. */ @InterfaceStability.Evolving public interface DataWriter { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index 0932ff8f8f8a7..3d337b6e0bdfd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -33,7 +33,10 @@ public interface DataWriterFactory extends Serializable { /** - * Returns a data writer to do the actual writing work. + * Returns a data writer to do the actual writing work. Note that, Spark will reuse the same data + * object instance when sending data to the data writer, for better performance. Data writers + * are responsible for defensive copies if necessary, e.g. copy the data before buffer it in a + * list. * * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java deleted file mode 100644 index d2cf7e01c08c8..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.writer; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.InternalRow; - -/** - * A mix-in interface for {@link DataSourceWriter}. Data source writers can implement this - * interface to write {@link InternalRow} directly and avoid the row conversion at Spark side. - * This is an experimental and unstable interface, as {@link InternalRow} is not public and may get - * changed in the future Spark versions. - */ - -@InterfaceStability.Unstable -public interface SupportsWriteInternalRow extends DataSourceWriter { - - @Override - default DataWriterFactory createWriterFactory() { - throw new IllegalStateException( - "createWriterFactory should not be called with SupportsWriteInternalRow."); - } - - DataWriterFactory createInternalRowWriterFactory(); -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index b1148c0f62f7c..0399970495bec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -50,11 +50,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e override def output: Seq[Attribute] = Nil override protected def doExecute(): RDD[InternalRow] = { - val writeTask = writer match { - case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() - case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) - } - + val writeTask = writer.createWriterFactory() val useCommitCoordinator = writer.useCommitCoordinator val rdd = query.execute() val messages = new Array[WriterCommitMessage](rdd.partitions.length) @@ -155,27 +151,3 @@ object DataWritingSparkTask extends Logging { }) } } - -class InternalRowDataWriterFactory( - rowWriterFactory: DataWriterFactory[Row], - schema: StructType) extends DataWriterFactory[InternalRow] { - - override def createDataWriter( - partitionId: Int, - taskId: Long, - epochId: Long): DataWriter[InternalRow] = { - new InternalRowDataWriter( - rowWriterFactory.createDataWriter(partitionId, taskId, epochId), - RowEncoder.apply(schema).resolveAndBind()) - } -} - -class InternalRowDataWriter(rowWriter: DataWriter[Row], encoder: ExpressionEncoder[Row]) - extends DataWriter[InternalRow] { - - override def write(record: InternalRow): Unit = rowWriter.write(encoder.fromRow(record)) - - override def commit(): WriterCommitMessage = rowWriter.commit() - - override def abort(): Unit = rowWriter.abort() -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index abb807def6239..c759f5be8ba35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -28,10 +28,9 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} -import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} +import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} -import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} @@ -498,12 +497,7 @@ class MicroBatchExecution( newAttributePlan.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) - if (writer.isInstanceOf[SupportsWriteInternalRow]) { - WriteToDataSourceV2( - new InternalRowMicroBatchWriter(currentBatchId, writer), newAttributePlan) - } else { - WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan) - } + WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan) case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala index 76f3f5baa8d56..967dbe24a3705 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -17,13 +17,10 @@ package org.apache.spark.sql.execution.streaming.continuous -import java.util.concurrent.atomic.AtomicLong - import org.apache.spark.{Partition, SparkEnv, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo} -import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory} import org.apache.spark.util.Utils /** @@ -47,7 +44,6 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor SparkEnv.get) EpochTracker.initializeCurrentEpoch( context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong) - while (!context.isInterrupted() && !context.isCompleted()) { var dataWriter: DataWriter[InternalRow] = null // write the data and commit this writer. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala index e0af3a2f1b85d..927d3a84e296b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -19,18 +19,14 @@ package org.apache.spark.sql.execution.streaming.continuous import scala.util.control.NonFatal -import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.datasources.v2.{DataWritingSparkTask, InternalRowDataWriterFactory} -import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo} import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter -import org.apache.spark.util.Utils /** * The physical plan for writing data into a continuous processing [[StreamWriter]]. @@ -41,11 +37,7 @@ case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPla override def output: Seq[Attribute] = Nil override protected def doExecute(): RDD[InternalRow] = { - val writerFactory = writer match { - case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() - case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) - } - + val writerFactory = writer.createWriterFactory() val rdd = new ContinuousWriteRDD(query.execute(), writerFactory) logInfo(s"Start processing data source writer: $writer. " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala index d276403190b3c..fd45ba509091e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.execution.streaming.sources -import scala.collection.JavaConverters._ - import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.{Dataset, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter @@ -39,7 +39,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions) assert(SparkSession.getActiveSession.isDefined) protected val spark = SparkSession.getActiveSession.get - def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory + def createWriterFactory(): DataWriterFactory[InternalRow] = PackedRowWriterFactory override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2 @@ -62,8 +62,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions) println(printMessage) println("-------------------------------------------") // scalastyle:off println - spark - .createDataFrame(rows.toList.asJava, schema) + Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows)) .show(numRowsToShow, isTruncated) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala index bc9b6d93ce7d9..e8ce21cc12044 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.execution.streaming.sources -import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession} +import org.apache.spark.sql.{ForeachWriter, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.python.PythonForeachWriter import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -46,11 +46,11 @@ case class ForeachWriterProvider[T]( schema: StructType, mode: OutputMode, options: DataSourceOptions): StreamWriter = { - new StreamWriter with SupportsWriteInternalRow { + new StreamWriter { override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { + override def createWriterFactory(): DataWriterFactory[InternalRow] = { val rowConverter: InternalRow => T = converter match { case Left(enc) => val boundEnc = enc.resolveAndBind( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala index 56f7ff25cbed0..d023a35ea20b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala @@ -17,9 +17,8 @@ package org.apache.spark.sql.execution.streaming.sources -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, WriterCommitMessage} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter /** @@ -34,21 +33,5 @@ class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceWr override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages) - override def createWriterFactory(): DataWriterFactory[Row] = writer.createWriterFactory() -} - -class InternalRowMicroBatchWriter(batchId: Long, writer: StreamWriter) - extends DataSourceWriter with SupportsWriteInternalRow { - override def commit(messages: Array[WriterCommitMessage]): Unit = { - writer.commit(batchId, messages) - } - - override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages) - - override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = - writer match { - case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() - case _ => throw new IllegalStateException( - "InternalRowMicroBatchWriter should only be created with base writer support") - } + override def createWriterFactory(): DataWriterFactory[InternalRow] = writer.createWriterFactory() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala index b501d90c81f06..f26e11d842b29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.sources import scala.collection.mutable import org.apache.spark.internal.Logging -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage} /** @@ -30,11 +30,11 @@ import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, Dat * Note that, because it sends all rows to the driver, this factory will generally be unsuitable * for production-quality sinks. It's intended for use in tests. */ -case object PackedRowWriterFactory extends DataWriterFactory[Row] { +case object PackedRowWriterFactory extends DataWriterFactory[InternalRow] { override def createDataWriter( partitionId: Int, taskId: Long, - epochId: Long): DataWriter[Row] = { + epochId: Long): DataWriter[InternalRow] = { new PackedRowDataWriter() } } @@ -43,15 +43,16 @@ case object PackedRowWriterFactory extends DataWriterFactory[Row] { * Commit message for a [[PackedRowDataWriter]], containing all the rows written in the most * recent interval. */ -case class PackedRowCommitMessage(rows: Array[Row]) extends WriterCommitMessage +case class PackedRowCommitMessage(rows: Array[InternalRow]) extends WriterCommitMessage /** * A simple [[DataWriter]] that just sends all the rows it's received as a commit message. */ -class PackedRowDataWriter() extends DataWriter[Row] with Logging { - private val data = mutable.Buffer[Row]() +class PackedRowDataWriter() extends DataWriter[InternalRow] with Logging { + private val data = mutable.Buffer[InternalRow]() - override def write(row: Row): Unit = data.append(row) + // Spark reuses the same `InternalRow` instance, here we copy it before buffer it. + override def write(row: InternalRow): Unit = data.append(row.copy()) override def commit(): PackedRowCommitMessage = { val msg = PackedRowCommitMessage(data.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index f2a35a90af24a..afacb2f72c926 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -25,6 +25,8 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils @@ -46,7 +48,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB schema: StructType, mode: OutputMode, options: DataSourceOptions): StreamWriter = { - new MemoryStreamWriter(this, mode) + new MemoryStreamWriter(this, mode, schema) } private case class AddedData(batchId: Long, data: Array[Row]) @@ -115,12 +117,13 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB override def toString(): String = "MemorySinkV2" } -case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} +case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) + extends WriterCommitMessage {} -class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) +class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode, schema: StructType) extends DataSourceWriter with Logging { - override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema) def commit(messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { @@ -134,10 +137,10 @@ class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) } } -class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) +class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) extends StreamWriter { - override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { @@ -151,22 +154,26 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) } } -case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] { +case class MemoryWriterFactory(outputMode: OutputMode, schema: StructType) + extends DataWriterFactory[InternalRow] { + override def createDataWriter( partitionId: Int, taskId: Long, - epochId: Long): DataWriter[Row] = { - new MemoryDataWriter(partitionId, outputMode) + epochId: Long): DataWriter[InternalRow] = { + new MemoryDataWriter(partitionId, outputMode, schema) } } -class MemoryDataWriter(partition: Int, outputMode: OutputMode) - extends DataWriter[Row] with Logging { +class MemoryDataWriter(partition: Int, outputMode: OutputMode, schema: StructType) + extends DataWriter[InternalRow] with Logging { private val data = mutable.Buffer[Row]() - override def write(row: Row): Unit = { - data.append(row) + private val encoder = RowEncoder(schema).resolveAndBind() + + override def write(row: InternalRow): Unit = { + data.append(encoder.fromRow(row)) } override def commit(): MemoryWriterCommitMessage = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 9be22d94b5654..b4d9b68c78152 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -20,16 +20,19 @@ package org.apache.spark.sql.execution.streaming import org.scalatest.BeforeAndAfter import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.streaming.sources._ import org.apache.spark.sql.streaming.{OutputMode, StreamTest} +import org.apache.spark.sql.types.StructType class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("data writer") { val partition = 1234 - val writer = new MemoryDataWriter(partition, OutputMode.Append()) - writer.write(Row(1)) - writer.write(Row(2)) - writer.write(Row(44)) + val writer = new MemoryDataWriter( + partition, OutputMode.Append(), new StructType().add("i", "int")) + writer.write(InternalRow(1)) + writer.write(InternalRow(2)) + writer.write(InternalRow(44)) val msg = writer.commit() assert(msg.data.map(_.getInt(0)) == Seq(1, 2, 44)) assert(msg.partition == partition) @@ -40,7 +43,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("continuous writer") { val sink = new MemorySinkV2 - val writer = new MemoryStreamWriter(sink, OutputMode.Append()) + val writer = new MemoryStreamWriter(sink, OutputMode.Append(), new StructType().add("i", "int")) writer.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), @@ -62,7 +65,8 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("microbatch writer") { val sink = new MemorySinkV2 - new MemoryWriter(sink, 0, OutputMode.Append()).commit( + val schema = new StructType().add("i", "int") + new MemoryWriter(sink, 0, OutputMode.Append(), schema).commit( Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), @@ -70,7 +74,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - new MemoryWriter(sink, 19, OutputMode.Append()).commit( + new MemoryWriter(sink, 19, OutputMode.Append(), schema).commit( Array( MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), MemoryWriterCommitMessage(0, Seq(Row(33))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index b6e594dc29cef..fef53e6f7b6fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -242,13 +242,6 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { assert(e2.getMessage.contains("Writing job aborted")) // make sure we don't have partial data. assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) - - // test internal row writer - spark.range(5).select('id, -'id).write.format(cls.getName) - .option("path", path).option("internal", "true").mode("overwrite").save() - checkAnswer( - spark.read.format(cls.getName).option("path", path).load(), - spark.range(5).select('id, -'id)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 183d0399d3bcd..e1b8e9c44d725 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} import org.apache.spark.SparkContext -import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.writer._ @@ -65,9 +65,9 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter { - override def createWriterFactory(): DataWriterFactory[Row] = { + override def createWriterFactory(): DataWriterFactory[InternalRow] = { SimpleCounter.resetCounter - new SimpleCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) + new CSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) } override def onDataWriterCommit(message: WriterCommitMessage): Unit = { @@ -97,18 +97,6 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } } - class InternalRowWriter(jobId: String, path: String, conf: Configuration) - extends Writer(jobId, path, conf) with SupportsWriteInternalRow { - - override def createWriterFactory(): DataWriterFactory[Row] = { - throw new IllegalArgumentException("not expected!") - } - - override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { - new InternalRowCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) - } - } - override def createReader(options: DataSourceOptions): DataSourceReader = { val path = new Path(options.get("path").get()) val conf = SparkContext.getActive.get.hadoopConfiguration @@ -124,7 +112,6 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false)) val path = new Path(options.get("path").get()) - val internal = options.get("internal").isPresent val conf = SparkContext.getActive.get.hadoopConfiguration val fs = path.getFileSystem(conf) @@ -142,17 +129,8 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS fs.delete(path, true) } - Optional.of(createWriter(jobId, path, conf, internal)) - } - - private def createWriter( - jobId: String, path: Path, conf: Configuration, internal: Boolean): DataSourceWriter = { val pathStr = path.toUri.toString - if (internal) { - new InternalRowWriter(jobId, pathStr, conf) - } else { - new Writer(jobId, pathStr, conf) - } + Optional.of(new Writer(jobId, pathStr, conf)) } } @@ -204,43 +182,7 @@ private[v2] object SimpleCounter { } } -class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) - extends DataWriterFactory[Row] { - - override def createDataWriter( - partitionId: Int, - taskId: Long, - epochId: Long): DataWriter[Row] = { - val jobPath = new Path(new Path(path, "_temporary"), jobId) - val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") - val fs = filePath.getFileSystem(conf.value) - new SimpleCSVDataWriter(fs, filePath) - } -} - -class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[Row] { - - private val out = fs.create(file) - - override def write(record: Row): Unit = { - out.writeBytes(s"${record.getLong(0)},${record.getLong(1)}\n") - } - - override def commit(): WriterCommitMessage = { - out.close() - null - } - - override def abort(): Unit = { - try { - out.close() - } finally { - fs.delete(file, false) - } - } -} - -class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) +class CSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) extends DataWriterFactory[InternalRow] { override def createDataWriter( @@ -250,11 +192,11 @@ class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: Seriali val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") val fs = filePath.getFileSystem(conf.value) - new InternalRowCSVDataWriter(fs, filePath) + new CSVDataWriter(fs, filePath) } } -class InternalRowCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow] { +class CSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow] { private val out = fs.create(file) From 64ad7b841d1efa979041358ee2a19aea7382d737 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 6 Aug 2018 16:46:55 +0800 Subject: [PATCH 1313/2461] [SPARK-23772][FOLLOW-UP][SQL] Provide an option to ignore column of all null values or empty array during JSON schema inference ## What changes were proposed in this pull request? The `dropFieldIfAllNull` parameter of the `json` method wasn't set as an option. This PR fixes that. ## How was this patch tested? I added a test to `sql/test.py` Author: Maxim Gekk Closes #22002 from MaxGekk/drop-field-if-all-null. --- python/pyspark/sql/readwriter.py | 2 +- python/pyspark/sql/tests.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 98b2cd9968407..abf878ae709a5 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -267,7 +267,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, - samplingRatio=samplingRatio, encoding=encoding) + samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding) if isinstance(path, basestring): path = [path] if type(path) == list: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a294d70119d0b..ed97a6394a98c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3351,6 +3351,22 @@ def test_checking_csv_header(self): finally: shutil.rmtree(path) + def test_ignore_column_of_all_nulls(self): + path = tempfile.mkdtemp() + shutil.rmtree(path) + try: + df = self.spark.createDataFrame([["""{"a":null, "b":1, "c":3.0}"""], + ["""{"a":null, "b":null, "c":"string"}"""], + ["""{"a":null, "b":null, "c":null}"""]]) + df.write.text(path) + schema = StructType([ + StructField('b', LongType(), nullable=True), + StructField('c', StringType(), nullable=True)]) + readback = self.spark.read.json(path, dropFieldIfAllNull=True) + self.assertEquals(readback.schema, schema) + finally: + shutil.rmtree(path) + def test_repr_behaviors(self): import re pattern = re.compile(r'^ *\|', re.MULTILINE) From d063e3a478221c836a0aa74a69828a526a6207bb Mon Sep 17 00:00:00 2001 From: John Zhuge Date: Mon, 6 Aug 2018 06:41:55 -0400 Subject: [PATCH 1314/2461] [SPARK-24940][SQL] Use IntegerLiteral in ResolveCoalesceHints ## What changes were proposed in this pull request? Follow up to fix an unmerged review comment. ## How was this patch tested? Unit test ResolveHintsSuite. Author: John Zhuge Closes #21998 from jzhuge/SPARK-24940. --- .../apache/spark/sql/catalyst/analysis/ResolveHints.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index 1ef482b0e9f5b..80d5105c2de8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -20,12 +20,11 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.IntegerLiteral import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.IntegerType /** @@ -119,7 +118,7 @@ object ResolveHints { case "COALESCE" => false } val numPartitions = h.parameters match { - case Seq(Literal(numPartitions: Int, IntegerType)) => + case Seq(IntegerLiteral(numPartitions)) => numPartitions case Seq(numPartitions: Int) => numPartitions From c1760da5dd5576c52be4f9dd9ecd06589a6153e4 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 6 Aug 2018 06:56:36 -0400 Subject: [PATCH 1315/2461] [SPARK-25025][SQL] Remove the default value of isAll in INTERSECT/EXCEPT ## What changes were proposed in this pull request? Having the default value of isAll in the logical plan nodes INTERSECT/EXCEPT could introduce bugs when the callers are not aware of it. This PR removes the default value and makes caller explicitly specify them. ## How was this patch tested? This is a refactoring change. Existing tests test the functionality already. Author: Dilip Biswal Closes #22000 from dilipbiswal/SPARK-25025. --- .../spark/sql/catalyst/dsl/package.scala | 4 ++-- .../sql/catalyst/parser/AstBuilder.scala | 6 ++--- .../plans/logical/basicLogicalOperators.scala | 4 ++-- .../analysis/AnalysisErrorSuite.scala | 12 +++++----- .../sql/catalyst/analysis/AnalysisSuite.scala | 6 ++--- .../catalyst/analysis/TypeCoercionSuite.scala | 24 ++++++++++++------- .../analysis/UnsupportedOperationsSuite.scala | 4 ++-- .../optimizer/ColumnPruningSuite.scala | 4 ++-- .../optimizer/ReplaceOperatorSuite.scala | 16 ++++++------- .../sql/catalyst/parser/PlanParserSuite.scala | 21 +++++++++------- .../plans/ConstraintPropagationSuite.scala | 4 ++-- .../scala/org/apache/spark/sql/Dataset.scala | 4 ++-- 12 files changed, 60 insertions(+), 49 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 7997e79003b12..75387fac64ed8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -356,10 +356,10 @@ package object dsl { def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan) - def except(otherPlan: LogicalPlan, isAll: Boolean = false): LogicalPlan = + def except(otherPlan: LogicalPlan, isAll: Boolean): LogicalPlan = Except(logicalPlan, otherPlan, isAll) - def intersect(otherPlan: LogicalPlan, isAll: Boolean = false): LogicalPlan = + def intersect(otherPlan: LogicalPlan, isAll: Boolean): LogicalPlan = Intersect(logicalPlan, otherPlan, isAll) def union(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 9906a30b488b8..732d762335f1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -534,15 +534,15 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case SqlBaseParser.INTERSECT if all => Intersect(left, right, isAll = true) case SqlBaseParser.INTERSECT => - Intersect(left, right) + Intersect(left, right, isAll = false) case SqlBaseParser.EXCEPT if all => Except(left, right, isAll = true) case SqlBaseParser.EXCEPT => - Except(left, right) + Except(left, right, isAll = false) case SqlBaseParser.SETMINUS if all => Except(left, right, isAll = true) case SqlBaseParser.SETMINUS => - Except(left, right) + Except(left, right, isAll = false) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index d7dbdb39a9afb..9d18ce5c7b80f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -167,7 +167,7 @@ object SetOperation { case class Intersect( left: LogicalPlan, right: LogicalPlan, - isAll: Boolean = false) extends SetOperation(left, right) { + isAll: Boolean) extends SetOperation(left, right) { override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) "All" else "" ) @@ -191,7 +191,7 @@ case class Intersect( case class Except( left: LogicalPlan, right: LogicalPlan, - isAll: Boolean = false) extends SetOperation(left, right) { + isAll: Boolean) extends SetOperation(left, right) { override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) "All" else "" ) /** We don't use right.output because those rows get excluded from the set. */ override def output: Seq[Attribute] = left.output diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index ae8d77bbbf9a8..0a5194a287ecc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -277,13 +277,13 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "intersect with unequal number of columns", - testRelation.intersect(testRelation2), + testRelation.intersect(testRelation2, isAll = false), "intersect" :: "number of columns" :: testRelation2.output.length.toString :: testRelation.output.length.toString :: Nil) errorTest( "except with unequal number of columns", - testRelation.except(testRelation2), + testRelation.except(testRelation2, isAll = false), "except" :: "number of columns" :: testRelation2.output.length.toString :: testRelation.output.length.toString :: Nil) @@ -299,22 +299,22 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "intersect with incompatible column types", - testRelation.intersect(nestedRelation), + testRelation.intersect(nestedRelation, isAll = false), "intersect" :: "the compatible column types" :: Nil) errorTest( "intersect with a incompatible column type and compatible column types", - testRelation3.intersect(testRelation4), + testRelation3.intersect(testRelation4, isAll = false), "intersect" :: "the compatible column types" :: "map" :: "decimal" :: Nil) errorTest( "except with incompatible column types", - testRelation.except(nestedRelation), + testRelation.except(nestedRelation, isAll = false), "except" :: "the compatible column types" :: Nil) errorTest( "except with a incompatible column type and compatible column types", - testRelation3.except(testRelation4), + testRelation3.except(testRelation4, isAll = false), "except" :: "the compatible column types" :: "map" :: "decimal" :: Nil) errorTest( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index ac33248241a25..ba44484b946ff 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -273,7 +273,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { } test("self intersect should resolve duplicate expression IDs") { - val plan = testRelation.intersect(testRelation) + val plan = testRelation.intersect(testRelation, isAll = false) assertAnalysisSuccess(plan) } @@ -439,8 +439,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { val unionPlan = Union(firstTable, secondTable) assertAnalysisSuccess(unionPlan) - val r1 = Except(firstTable, secondTable) - val r2 = Intersect(firstTable, secondTable) + val r1 = Except(firstTable, secondTable, isAll = false) + val r2 = Intersect(firstTable, secondTable, isAll = false) assertAnalysisSuccess(r1) assertAnalysisSuccess(r2) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 4161f09c63190..d71bbb3227134 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1223,8 +1223,10 @@ class TypeCoercionSuite extends AnalysisTest { val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) - val r1 = widenSetOperationTypes(Except(firstTable, secondTable)).asInstanceOf[Except] - val r2 = widenSetOperationTypes(Intersect(firstTable, secondTable)).asInstanceOf[Intersect] + val r1 = widenSetOperationTypes( + Except(firstTable, secondTable, isAll = false)).asInstanceOf[Except] + val r2 = widenSetOperationTypes( + Intersect(firstTable, secondTable, isAll = false)).asInstanceOf[Intersect] checkOutput(r1.left, expectedTypes) checkOutput(r1.right, expectedTypes) checkOutput(r2.left, expectedTypes) @@ -1289,8 +1291,10 @@ class TypeCoercionSuite extends AnalysisTest { val expectedType1 = Seq(DecimalType(10, 8)) val r1 = widenSetOperationTypes(Union(left1, right1)).asInstanceOf[Union] - val r2 = widenSetOperationTypes(Except(left1, right1)).asInstanceOf[Except] - val r3 = widenSetOperationTypes(Intersect(left1, right1)).asInstanceOf[Intersect] + val r2 = widenSetOperationTypes( + Except(left1, right1, isAll = false)).asInstanceOf[Except] + val r3 = widenSetOperationTypes( + Intersect(left1, right1, isAll = false)).asInstanceOf[Intersect] checkOutput(r1.children.head, expectedType1) checkOutput(r1.children.last, expectedType1) @@ -1310,16 +1314,20 @@ class TypeCoercionSuite extends AnalysisTest { AttributeReference("r", rType)()) val r1 = widenSetOperationTypes(Union(plan1, plan2)).asInstanceOf[Union] - val r2 = widenSetOperationTypes(Except(plan1, plan2)).asInstanceOf[Except] - val r3 = widenSetOperationTypes(Intersect(plan1, plan2)).asInstanceOf[Intersect] + val r2 = widenSetOperationTypes( + Except(plan1, plan2, isAll = false)).asInstanceOf[Except] + val r3 = widenSetOperationTypes( + Intersect(plan1, plan2, isAll = false)).asInstanceOf[Intersect] checkOutput(r1.children.last, Seq(expectedType)) checkOutput(r2.right, Seq(expectedType)) checkOutput(r3.right, Seq(expectedType)) val r4 = widenSetOperationTypes(Union(plan2, plan1)).asInstanceOf[Union] - val r5 = widenSetOperationTypes(Except(plan2, plan1)).asInstanceOf[Except] - val r6 = widenSetOperationTypes(Intersect(plan2, plan1)).asInstanceOf[Intersect] + val r5 = widenSetOperationTypes( + Except(plan2, plan1, isAll = false)).asInstanceOf[Except] + val r6 = widenSetOperationTypes( + Intersect(plan2, plan1, isAll = false)).asInstanceOf[Intersect] checkOutput(r4.children.last, Seq(expectedType)) checkOutput(r5.left, Seq(expectedType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index cb487c8893541..197d7c7668ef1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -575,14 +575,14 @@ class UnsupportedOperationsSuite extends SparkFunSuite { // Except: *-stream not supported testBinaryOperationInStreamingPlan( "except", - _.except(_), + _.except(_, isAll = false), streamStreamSupported = false, batchStreamSupported = false) // Intersect: stream-stream not supported testBinaryOperationInStreamingPlan( "intersect", - _.intersect(_), + _.intersect(_, isAll = false), streamStreamSupported = false) // Sort: supported only on batch subplans and after aggregation on streaming plan + complete mode diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index f6db3c90ad96c..8d7c9bf220bc2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -180,10 +180,10 @@ class ColumnPruningSuite extends PlanTest { test("Column pruning on except/intersect/distinct") { val input = LocalRelation('a.int, 'b.string, 'c.double) - val query = Project('a :: Nil, Except(input, input)).analyze + val query = Project('a :: Nil, Except(input, input, isAll = false)).analyze comparePlans(Optimize.execute(query), query) - val query2 = Project('a :: Nil, Intersect(input, input)).analyze + val query2 = Project('a :: Nil, Intersect(input, input, isAll = false)).analyze comparePlans(Optimize.execute(query2), query2) val query3 = Project('a :: Nil, Distinct(input)).analyze comparePlans(Optimize.execute(query3), query3) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index 52dc2e9fb076c..3b1b2d588ef67 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -42,7 +42,7 @@ class ReplaceOperatorSuite extends PlanTest { val table1 = LocalRelation('a.int, 'b.int) val table2 = LocalRelation('c.int, 'd.int) - val query = Intersect(table1, table2) + val query = Intersect(table1, table2, isAll = false) val optimized = Optimize.execute(query.analyze) val correctAnswer = @@ -60,7 +60,7 @@ class ReplaceOperatorSuite extends PlanTest { val table2 = Filter(attributeB === 2, Filter(attributeA === 1, table1)) val table3 = Filter(attributeB < 1, Filter(attributeA >= 2, table1)) - val query = Except(table2, table3) + val query = Except(table2, table3, isAll = false) val optimized = Optimize.execute(query.analyze) val correctAnswer = @@ -79,7 +79,7 @@ class ReplaceOperatorSuite extends PlanTest { val table1 = LocalRelation.fromExternalRows(Seq(attributeA, attributeB), data = Seq(Row(1, 2))) val table2 = Filter(attributeB < 1, Filter(attributeA >= 2, table1)) - val query = Except(table1, table2) + val query = Except(table1, table2, isAll = false) val optimized = Optimize.execute(query.analyze) val correctAnswer = @@ -99,7 +99,7 @@ class ReplaceOperatorSuite extends PlanTest { val table3 = Project(Seq(attributeA, attributeB), Filter(attributeB < 1, Filter(attributeA >= 2, table1))) - val query = Except(table2, table3) + val query = Except(table2, table3, isAll = false) val optimized = Optimize.execute(query.analyze) val correctAnswer = @@ -120,7 +120,7 @@ class ReplaceOperatorSuite extends PlanTest { val table3 = Project(Seq(attributeA, attributeB), Filter(attributeB < 1, Filter(attributeA >= 2, table1))) - val query = Except(table2, table3) + val query = Except(table2, table3, isAll = false) val optimized = Optimize.execute(query.analyze) val correctAnswer = @@ -141,7 +141,7 @@ class ReplaceOperatorSuite extends PlanTest { Filter(attributeB < 1, Filter(attributeA >= 2, table1))) val table3 = Filter(attributeB === 2, Filter(attributeA === 1, table1)) - val query = Except(table2, table3) + val query = Except(table2, table3, isAll = false) val optimized = Optimize.execute(query.analyze) val correctAnswer = @@ -158,7 +158,7 @@ class ReplaceOperatorSuite extends PlanTest { val table1 = LocalRelation('a.int, 'b.int) val table2 = LocalRelation('c.int, 'd.int) - val query = Except(table1, table2) + val query = Except(table1, table2, isAll = false) val optimized = Optimize.execute(query.analyze) val correctAnswer = @@ -173,7 +173,7 @@ class ReplaceOperatorSuite extends PlanTest { val left = table.where('b < 1).select('a).as("left") val right = table.where('b < 3).select('a).as("right") - val query = Except(left, right) + val query = Except(left, right, isAll = false) val optimized = Optimize.execute(query.analyze) val correctAnswer = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index d7200d0bff5d6..422bf97e30e7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -65,14 +65,15 @@ class PlanParserSuite extends AnalysisTest { assertEqual("select * from a union select * from b", Distinct(a.union(b))) assertEqual("select * from a union distinct select * from b", Distinct(a.union(b))) assertEqual("select * from a union all select * from b", a.union(b)) - assertEqual("select * from a except select * from b", a.except(b)) - assertEqual("select * from a except distinct select * from b", a.except(b)) + assertEqual("select * from a except select * from b", a.except(b, isAll = false)) + assertEqual("select * from a except distinct select * from b", a.except(b, isAll = false)) assertEqual("select * from a except all select * from b", a.except(b, isAll = true)) - assertEqual("select * from a minus select * from b", a.except(b)) + assertEqual("select * from a minus select * from b", a.except(b, isAll = false)) assertEqual("select * from a minus all select * from b", a.except(b, isAll = true)) - assertEqual("select * from a minus distinct select * from b", a.except(b)) - assertEqual("select * from a intersect select * from b", a.intersect(b)) - assertEqual("select * from a intersect distinct select * from b", a.intersect(b)) + assertEqual("select * from a minus distinct select * from b", a.except(b, isAll = false)) + assertEqual("select * from a " + + "intersect select * from b", a.intersect(b, isAll = false)) + assertEqual("select * from a intersect distinct select * from b", a.intersect(b, isAll = false)) assertEqual("select * from a intersect all select * from b", a.intersect(b, isAll = true)) } @@ -735,18 +736,20 @@ class PlanParserSuite extends AnalysisTest { |SELECT * FROM d """.stripMargin - assertEqual(query1, Distinct(a.union(b)).except(c.intersect(d))) + assertEqual(query1, Distinct(a.union(b)).except(c.intersect(d, isAll = false), isAll = false)) assertEqual(query2, Distinct(a.union(b)).except(c.intersect(d, isAll = true), isAll = true)) // Now disable precedence enforcement to verify the old behaviour. withSQLConf(SQLConf.LEGACY_SETOPS_PRECEDENCE_ENABLED.key -> "true") { - assertEqual(query1, Distinct(a.union(b)).except(c).intersect(d)) + assertEqual(query1, + Distinct(a.union(b)).except(c, isAll = false).intersect(d, isAll = false)) assertEqual(query2, Distinct(a.union(b)).except(c, isAll = true).intersect(d, isAll = true)) } // Explicitly enable the precedence enforcement withSQLConf(SQLConf.LEGACY_SETOPS_PRECEDENCE_ENABLED.key -> "false") { - assertEqual(query1, Distinct(a.union(b)).except(c.intersect(d))) + assertEqual(query1, + Distinct(a.union(b)).except(c.intersect(d, isAll = false), isAll = false)) assertEqual(query2, Distinct(a.union(b)).except(c.intersect(d, isAll = true), isAll = true)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index a37e06d922642..5ad748b6113d6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -187,7 +187,7 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { verifyConstraints(tr1 .where('a.attr > 10) - .intersect(tr2.where('b.attr < 100)) + .intersect(tr2.where('b.attr < 100), isAll = false) .analyze.constraints, ExpressionSet(Seq(resolveColumn(tr1, "a") > 10, resolveColumn(tr1, "b") < 100, @@ -200,7 +200,7 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { val tr2 = LocalRelation('a.int, 'b.int, 'c.int) verifyConstraints(tr1 .where('a.attr > 10) - .except(tr2.where('b.attr < 100)) + .except(tr2.where('b.attr < 100), isAll = false) .analyze.constraints, ExpressionSet(Seq(resolveColumn(tr1, "a") > 10, IsNotNull(resolveColumn(tr1, "a"))))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 3b0a6d8840f1e..a4bf990ea9d6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1927,7 +1927,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def intersect(other: Dataset[T]): Dataset[T] = withSetOperator { - Intersect(logicalPlan, other.logicalPlan) + Intersect(logicalPlan, other.logicalPlan, isAll = false) } /** @@ -1958,7 +1958,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def except(other: Dataset[T]): Dataset[T] = withSetOperator { - Except(logicalPlan, other.logicalPlan) + Except(logicalPlan, other.logicalPlan, isAll = false) } /** From 35700bb7f2e3008ff781a1b3a1da8147d26371be Mon Sep 17 00:00:00 2001 From: Hieu Huynh <“Hieu.huynh@oath.com”> Date: Mon, 6 Aug 2018 09:01:51 -0500 Subject: [PATCH 1316/2461] [SPARK-24981][CORE] ShutdownHook timeout causes job to fail when succeeded when SparkContext stop() not called by user program MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Description** The issue is described in [SPARK-24981](https://issues.apache.org/jira/browse/SPARK-24981). **How does this PR fix the issue?** This PR catch the Exception that is thrown while the Sparkcontext.stop() is running (when it is called by the ShutdownHookManager). **How was this patch tested?** I manually tested it by adding delay (60s) inside the stop(). This make the shutdownHookManger interrupt the thread that is running stop(). The Interrupted Exception was catched and the job succeed. Author: Hieu Huynh <“Hieu.huynh@oath.com”> Author: Hieu Tri Huynh Closes #21936 from hthuynh2/SPARK_24981. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 03e91cdd310ed..e8bacee3b0215 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -571,7 +571,12 @@ class SparkContext(config: SparkConf) extends Logging { _shutdownHookRef = ShutdownHookManager.addShutdownHook( ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () => logInfo("Invoking stop() from shutdown hook") - stop() + try { + stop() + } catch { + case e: Throwable => + logWarning("Ignoring Exception while stopping SparkContext from shutdown hook", e) + } } } catch { case NonFatal(e) => From 1a5e460762593c61b7ff2c5f3641d406706616ff Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 6 Aug 2018 23:27:57 +0900 Subject: [PATCH 1317/2461] [SPARK-23913][SQL] Add array_intersect function ## What changes were proposed in this pull request? The PR adds the SQL function `array_intersect`. The behavior of the function is based on Presto's one. This function returns returns an array of the elements in the intersection of array1 and array2. Note: The order of elements in the result is not defined. ## How was this patch tested? Added UTs Author: Kazuaki Ishizaki Closes #21102 from kiszk/SPARK-23913. --- python/pyspark/sql/functions.py | 19 + .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 386 +++++++++++++++--- .../CollectionExpressionsSuite.scala | 112 +++++ .../org/apache/spark/sql/functions.scala | 11 + .../spark/sql/DataFrameFunctionsSuite.scala | 54 +++ 6 files changed, 515 insertions(+), 68 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ec014a5b39c31..eaecf284b51f1 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2033,6 +2033,25 @@ def array_distinct(col): return Column(sc._jvm.functions.array_distinct(_to_java_column(col))) +@ignore_unicode_prefix +@since(2.4) +def array_intersect(col1, col2): + """ + Collection function: returns an array of the elements in the intersection of col1 and col2, + without duplicates. + + :param col1: name of column containing array + :param col2: name of column containing array + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) + >>> df.select(array_intersect(df.c1, df.c2)).collect() + [Row(array_intersect(c1, c2)=[u'a', u'c'])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_intersect(_to_java_column(col1), _to_java_column(col2))) + + @ignore_unicode_prefix @since(2.4) def array_union(col1, col2): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 35f8de1328b50..ed2f67da6f2bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -411,6 +411,7 @@ object FunctionRegistry { expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), expression[ArraysOverlap]("arrays_overlap"), + expression[ArrayIntersect]("array_intersect"), expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), expression[ArraySort]("array_sort"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 3f94f25796634..e385c2d9782e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3651,7 +3651,7 @@ case class ArrayDistinct(child: Expression) } /** - * Will become common base class for [[ArrayUnion]], ArrayIntersect, and [[ArrayExcept]]. + * Will become common base class for [[ArrayUnion]], [[ArrayIntersect]], and [[ArrayExcept]]. */ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { override def checkInputDataTypes(): TypeCheckResult = { @@ -3672,6 +3672,75 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { case _: AtomicType => true case _ => false } + + @transient protected lazy val canUseSpecializedHashSet = elementType match { + case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true + case _ => false + } + + protected def genGetValue(array: String, i: String): String = + CodeGenerator.getValue(array, elementType, i) + + @transient protected lazy val (hsPostFix, hsTypeName) = { + val ptName = CodeGenerator.primitiveTypeName (elementType) + elementType match { + // we cast byte/short to int when writing to the hash set. + case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int") + case LongType => ("$mcJ$sp", ptName) + case FloatType => ("$mcF$sp", ptName) + case DoubleType => ("$mcD$sp", ptName) + } + } + + // we cast byte/short to int when writing to the hash set. + @transient protected lazy val hsValueCast = elementType match { + case ByteType | ShortType => "(int) " + case _ => "" + } + + // When hitting a null value, put a null holder in the ArrayBuilder. Finally we will + // convert ArrayBuilder to ArrayData and setNull on the slot with null holder. + @transient protected lazy val nullValueHolder = elementType match { + case ByteType => "(byte) 0" + case ShortType => "(short) 0" + case _ => "0" + } + + protected def withResultArrayNullCheck( + body: String, + value: String, + nullElementIndex: String): String = { + if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |$body + |if ($nullElementIndex >= 0) { + | // result has null element + | $value.setNullAt($nullElementIndex); + |} + """.stripMargin + } else { + body + } + } + + def buildResultArray( + builder: String, + value : String, + size : String, + nullElementIndex : String): String = withResultArrayNullCheck( + s""" + |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Cannot create array with " + $size + + | " bytes of data due to exceeding the limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for ArrayData."); + |} + | + |if (!UnsafeArrayData.shouldUseGenericArrayData(${elementType.defaultSize}, $size)) { + | $value = UnsafeArrayData.fromPrimitiveArray($builder.result()); + |} else { + | $value = new ${classOf[GenericArrayData].getName}($builder.result()); + |} + """.stripMargin, value, nullElementIndex) } object ArraySetLike { @@ -3965,6 +4034,248 @@ object ArrayUnion { } } +/** + * Returns an array of the elements in the intersect of x and y, without duplicates + */ +@ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in the intersection of array1 and + array2, without duplicates. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(1, 3) + """, + since = "2.4.0") +case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetLike + with ComplexTypeMergingExpression { + override def dataType: DataType = { + dataTypeCheck + ArrayType(elementType, + left.dataType.asInstanceOf[ArrayType].containsNull && + right.dataType.asInstanceOf[ArrayType].containsNull) + } + + @transient lazy val evalIntersect: (ArrayData, ArrayData) => ArrayData = { + if (elementTypeSupportEquals) { + (array1, array2) => + if (array1.numElements() != 0 && array2.numElements() != 0) { + val hs = new OpenHashSet[Any] + val hsResult = new OpenHashSet[Any] + var foundNullElement = false + var i = 0 + while (i < array2.numElements()) { + if (array2.isNullAt(i)) { + foundNullElement = true + } else { + val elem = array2.get(i, elementType) + hs.add(elem) + } + i += 1 + } + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + i = 0 + while (i < array1.numElements()) { + if (array1.isNullAt(i)) { + if (foundNullElement) { + arrayBuffer += null + foundNullElement = false + } + } else { + val elem = array1.get(i, elementType) + if (hs.contains(elem) && !hsResult.contains(elem)) { + arrayBuffer += elem + hsResult.add(elem) + } + } + i += 1 + } + new GenericArrayData(arrayBuffer) + } else { + new GenericArrayData(Array.emptyObjectArray) + } + } else { + (array1, array2) => + if (array1.numElements() != 0 && array2.numElements() != 0) { + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadySeenNull = false + var i = 0 + while (i < array1.numElements()) { + var found = false + val elem1 = array1.get(i, elementType) + if (array1.isNullAt(i)) { + if (!alreadySeenNull) { + var j = 0 + while (!found && j < array2.numElements()) { + found = array2.isNullAt(j) + j += 1 + } + // array2 is scanned only once for null element + alreadySeenNull = true + } + } else { + var j = 0 + while (!found && j < array2.numElements()) { + if (!array2.isNullAt(j)) { + val elem2 = array2.get(j, elementType) + if (ordering.equiv(elem1, elem2)) { + // check whether elem1 is already stored in arrayBuffer + var foundArrayBuffer = false + var k = 0 + while (!foundArrayBuffer && k < arrayBuffer.size) { + val va = arrayBuffer(k) + foundArrayBuffer = (va != null) && ordering.equiv(va, elem1) + k += 1 + } + found = !foundArrayBuffer + } + } + j += 1 + } + } + if (found) { + arrayBuffer += elem1 + } + i += 1 + } + new GenericArrayData(arrayBuffer) + } else { + new GenericArrayData(Array.emptyObjectArray) + } + } + } + + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] + + evalIntersect(array1, array2) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val arrayData = classOf[ArrayData].getName + val i = ctx.freshName("i") + val value = ctx.freshName("value") + val size = ctx.freshName("size") + if (canUseSpecializedHashSet) { + val jt = CodeGenerator.javaType(elementType) + val ptName = CodeGenerator.primitiveTypeName(jt) + + nullSafeCodeGen(ctx, ev, (array1, array2) => { + val foundNullElement = ctx.freshName("foundNullElement") + val nullElementIndex = ctx.freshName("nullElementIndex") + val builder = ctx.freshName("builder") + val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" + val hashSet = ctx.freshName("hashSet") + val hashSetResult = ctx.freshName("hashSetResult") + val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName + val arrayBuilderClass = s"$arrayBuilder$$of$ptName" + + def withArray2NullCheck(body: String): String = + if (right.dataType.asInstanceOf[ArrayType].containsNull) { + if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array2.isNullAt($i)) { + | $foundNullElement = true; + |} else { + | $body + |} + """.stripMargin + } else { + // if array1's element is not nullable, we don't need to track the null element index. + s""" + |if (!$array2.isNullAt($i)) { + | $body + |} + """.stripMargin + } + } else { + body + } + + val writeArray2ToHashSet = withArray2NullCheck( + s""" + |$jt $value = ${genGetValue(array2, i)}; + |$hashSet.add$hsPostFix($hsValueCast$value); + """.stripMargin) + + def withArray1NullAssignment(body: String) = + if (left.dataType.asInstanceOf[ArrayType].containsNull) { + if (right.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array1.isNullAt($i)) { + | if ($foundNullElement) { + | $nullElementIndex = $size; + | $foundNullElement = false; + | $size++; + | $builder.$$plus$$eq($nullValueHolder); + | } + |} else { + | $body + |} + """.stripMargin + } else { + s""" + |if (!$array1.isNullAt($i)) { + | $body + |} + """.stripMargin + } + } else { + body + } + + val processArray1 = withArray1NullAssignment( + s""" + |$jt $value = ${genGetValue(array1, i)}; + |if ($hashSet.contains($hsValueCast$value) && + | !$hashSetResult.contains($hsValueCast$value)) { + | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | break; + | } + | $hashSetResult.add$hsPostFix($hsValueCast$value); + | $builder.$$plus$$eq($value); + |} + """.stripMargin) + + // Only need to track null element index when result array's element is nullable. + val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |boolean $foundNullElement = false; + |int $nullElementIndex = -1; + """.stripMargin + } else { + "" + } + + s""" + |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag); + |$openHashSet $hashSetResult = new $openHashSet$hsPostFix($classTag); + |$declareNullTrackVariables + |for (int $i = 0; $i < $array2.numElements(); $i++) { + | $writeArray2ToHashSet + |} + |$arrayBuilderClass $builder = new $arrayBuilderClass(); + |int $size = 0; + |for (int $i = 0; $i < $array1.numElements(); $i++) { + | $processArray1 + |} + |${buildResultArray(builder, ev.value, size, nullElementIndex)} + """.stripMargin + }) + } else { + nullSafeCodeGen(ctx, ev, (array1, array2) => { + val expr = ctx.addReferenceObj("arrayIntersectExpr", this) + s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);" + }) + } + } + + override def prettyName: String = "array_intersect" +} + /** * Returns an array of the elements in the intersect of x and y, without duplicates */ @@ -4065,7 +4376,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike i += 1 } new GenericArrayData(arrayBuffer) - } + } } override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -4080,31 +4391,10 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike val i = ctx.freshName("i") val value = ctx.freshName("value") val size = ctx.freshName("size") - val canUseSpecializedHashSet = elementType match { - case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true - case _ => false - } if (canUseSpecializedHashSet) { val jt = CodeGenerator.javaType(elementType) val ptName = CodeGenerator.primitiveTypeName(jt) - def genGetValue(array: String): String = - CodeGenerator.getValue(array, elementType, i) - - val (hsPostFix, hsTypeName) = elementType match { - // we cast byte/short to int when writing to the hash set. - case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int") - case LongType => ("$mcJ$sp", ptName) - case FloatType => ("$mcF$sp", ptName) - case DoubleType => ("$mcD$sp", ptName) - } - - // we cast byte/short to int when writing to the hash set. - val hsValueCast = elementType match { - case ByteType | ShortType => "(int) " - case _ => "" - } - nullSafeCodeGen(ctx, ev, (array1, array2) => { val notFoundNullElement = ctx.freshName("notFoundNullElement") val nullElementIndex = ctx.freshName("nullElementIndex") @@ -4112,10 +4402,8 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike val openHashSet = classOf[OpenHashSet[_]].getName val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" val hashSet = ctx.freshName("hashSet") - val genericArrayData = classOf[GenericArrayData].getName - val arrayBuilder = "scala.collection.mutable.ArrayBuilder" + val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName val arrayBuilderClass = s"$arrayBuilder$$of$ptName" - val arrayBuilderClassTag = s"scala.reflect.ClassTag$$.MODULE$$.$ptName()" def withArray2NullCheck(body: String): String = if (right.dataType.asInstanceOf[ArrayType].containsNull) { @@ -4141,18 +4429,10 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike val writeArray2ToHashSet = withArray2NullCheck( s""" - |$jt $value = ${genGetValue(array2)}; + |$jt $value = ${genGetValue(array2, i)}; |$hashSet.add$hsPostFix($hsValueCast$value); """.stripMargin) - // When hitting a null value, put a null holder in the ArrayBuilder. Finally we will - // convert ArrayBuilder to ArrayData and setNull on the slot with null holder. - val nullValueHolder = elementType match { - case ByteType => "(byte) 0" - case ShortType => "(short) 0" - case _ => "0" - } - def withArray1NullAssignment(body: String) = if (left.dataType.asInstanceOf[ArrayType].containsNull) { s""" @@ -4173,7 +4453,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike val processArray1 = withArray1NullAssignment( s""" - |$jt $value = ${genGetValue(array1)}; + |$jt $value = ${genGetValue(array1, i)}; |if (!$hashSet.contains($hsValueCast$value)) { | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | break; @@ -4183,35 +4463,6 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike |} """.stripMargin) - def withResultArrayNullCheck(body: String): String = { - if (dataType.asInstanceOf[ArrayType].containsNull) { - s""" - |$body - |if ($nullElementIndex >= 0) { - | // result has null element - | ${ev.value}.setNullAt($nullElementIndex); - |} - """.stripMargin - } else { - body - } - } - - val buildResultArray = withResultArrayNullCheck( - s""" - |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw new RuntimeException("Cannot create array with " + $size + - | " bytes of data due to exceeding the limit " + - | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for ArrayData."); - |} - | - |if (!UnsafeArrayData.shouldUseGenericArrayData(${elementType.defaultSize}, $size)) { - | ${ev.value} = UnsafeArrayData.fromPrimitiveArray($builder.result()); - |} else { - | ${ev.value} = new $genericArrayData($builder.result()); - |} - """.stripMargin) - // Only need to track null element index when array1's element is nullable. val declareNullTrackVariables = if (left.dataType.asInstanceOf[ArrayType].containsNull) { s""" @@ -4228,13 +4479,12 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike |for (int $i = 0; $i < $array2.numElements(); $i++) { | $writeArray2ToHashSet |} - |$arrayBuilderClass $builder = - | ($arrayBuilderClass)$arrayBuilder.make($arrayBuilderClassTag); + |$arrayBuilderClass $builder = new $arrayBuilderClass(); |int $size = 0; |for (int $i = 0; $i < $array1.numElements(); $i++) { | $processArray1 |} - |$buildResultArray + |${buildResultArray(builder, ev.value, size, nullElementIndex)} """.stripMargin }) } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 2f6f9064f9e62..4daa113869b5d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1618,4 +1618,116 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper assert(ArrayExcept(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) assert(ArrayExcept(a24, a22).dataType.asInstanceOf[ArrayType].containsNull === true) } + + test("Array Intersect") { + val a00 = Literal.create(Seq(1, 2, 4), ArrayType(IntegerType, false)) + val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, false)) + val a02 = Literal.create(Seq(1, 2, 1, 4), ArrayType(IntegerType, false)) + val a03 = Literal.create(Seq(4, 2, 4), ArrayType(IntegerType, false)) + val a04 = Literal.create(Seq(1, 2, null, 4, 5, null), ArrayType(IntegerType, true)) + val a05 = Literal.create(Seq(-5, 4, null, 2, -1, null), ArrayType(IntegerType, true)) + val a06 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, false)) + val abl0 = Literal.create(Seq[Boolean](true, false, true), ArrayType(BooleanType, false)) + val abl1 = Literal.create(Seq[Boolean](true, true), ArrayType(BooleanType, false)) + val ab0 = Literal.create(Seq[Byte](1, 2, 3, 2), ArrayType(ByteType, containsNull = false)) + val ab1 = Literal.create(Seq[Byte](4, 2, 4), ArrayType(ByteType, containsNull = false)) + val as0 = Literal.create(Seq[Short](1, 2, 3, 2), ArrayType(ShortType, containsNull = false)) + val as1 = Literal.create(Seq[Short](4, 2, 4), ArrayType(ShortType, containsNull = false)) + val af0 = Literal.create(Seq[Float](1.1F, 2.2F, 3.3F, 2.2F), ArrayType(FloatType, false)) + val af1 = Literal.create(Seq[Float](4.4F, 2.2F, 4.4F), ArrayType(FloatType, false)) + val ad0 = Literal.create(Seq[Double](1.1, 2.2, 3.3, 2.2), ArrayType(DoubleType, false)) + val ad1 = Literal.create(Seq[Double](4.4, 2.2, 4.4), ArrayType(DoubleType, false)) + + val a10 = Literal.create(Seq(1L, 2L, 4L), ArrayType(LongType, false)) + val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, false)) + val a12 = Literal.create(Seq(1L, 2L, 1L, 4L), ArrayType(LongType, false)) + val a13 = Literal.create(Seq(4L, 2L, 4L), ArrayType(LongType, false)) + val a14 = Literal.create(Seq(1L, 2L, null, 4L, 5L, null), ArrayType(LongType, true)) + val a15 = Literal.create(Seq(-5L, 4L, null, 2L, -1L, null), ArrayType(LongType, true)) + val a16 = Literal.create(Seq.empty[Long], ArrayType(LongType, false)) + + val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false)) + val a21 = Literal.create(Seq("c", "a"), ArrayType(StringType, false)) + val a22 = Literal.create(Seq("b", "a", "c", "a"), ArrayType(StringType, false)) + val a23 = Literal.create(Seq("c", "a", null, "f"), ArrayType(StringType, true)) + val a24 = Literal.create(Seq("b", null, "a", "g", null), ArrayType(StringType, true)) + val a25 = Literal.create(Seq.empty[String], ArrayType(StringType, false)) + + val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType)) + val a31 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayIntersect(a00, a01), Seq(2, 4)) + checkEvaluation(ArrayIntersect(a01, a00), Seq(4, 2)) + checkEvaluation(ArrayIntersect(a02, a03), Seq(2, 4)) + checkEvaluation(ArrayIntersect(a03, a02), Seq(4, 2)) + checkEvaluation(ArrayIntersect(a00, a04), Seq(1, 2, 4)) + checkEvaluation(ArrayIntersect(a04, a05), Seq(2, null, 4)) + checkEvaluation(ArrayIntersect(a02, a06), Seq.empty) + checkEvaluation(ArrayIntersect(a06, a04), Seq.empty) + checkEvaluation(ArrayIntersect(abl0, abl1), Seq[Boolean](true)) + checkEvaluation(ArrayIntersect(ab0, ab1), Seq[Byte](2)) + checkEvaluation(ArrayIntersect(as0, as1), Seq[Short](2)) + checkEvaluation(ArrayIntersect(af0, af1), Seq[Float](2.2F)) + checkEvaluation(ArrayIntersect(ad0, ad1), Seq[Double](2.2D)) + + checkEvaluation(ArrayIntersect(a10, a11), Seq(2L, 4L)) + checkEvaluation(ArrayIntersect(a11, a10), Seq(4L, 2L)) + checkEvaluation(ArrayIntersect(a12, a13), Seq(2L, 4L)) + checkEvaluation(ArrayIntersect(a13, a12), Seq(4L, 2L)) + checkEvaluation(ArrayIntersect(a14, a15), Seq(2L, null, 4L)) + checkEvaluation(ArrayIntersect(a12, a16), Seq.empty) + checkEvaluation(ArrayIntersect(a16, a14), Seq.empty) + + checkEvaluation(ArrayIntersect(a20, a21), Seq("a", "c")) + checkEvaluation(ArrayIntersect(a21, a20), Seq("c", "a")) + checkEvaluation(ArrayIntersect(a22, a21), Seq("a", "c")) + checkEvaluation(ArrayIntersect(a21, a22), Seq("c", "a")) + checkEvaluation(ArrayIntersect(a23, a24), Seq("a", null)) + checkEvaluation(ArrayIntersect(a24, a23), Seq(null, "a")) + checkEvaluation(ArrayIntersect(a24, a25), Seq.empty) + checkEvaluation(ArrayIntersect(a25, a24), Seq.empty) + + checkEvaluation(ArrayIntersect(a30, a30), Seq(null)) + checkEvaluation(ArrayIntersect(a20, a31), null) + checkEvaluation(ArrayIntersect(a31, a20), null) + + val b0 = Literal.create( + Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](3, 4)), + ArrayType(BinaryType)) + val b1 = Literal.create( + Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](3, 4), Array[Byte](5, 6)), + ArrayType(BinaryType)) + val b2 = Literal.create( + Seq[Array[Byte]](Array[Byte](3, 4), Array[Byte](1, 2), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4), null), + ArrayType(BinaryType)) + val b4 = Literal.create(Seq[Array[Byte]](null, Array[Byte](3, 4), null), ArrayType(BinaryType)) + val b5 = Literal.create(Seq.empty, ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) + checkEvaluation(ArrayIntersect(b0, b1), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b1, b0), Seq[Array[Byte]](Array[Byte](3, 4), Array[Byte](5, 6))) + checkEvaluation(ArrayIntersect(b0, b2), Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b2, b0), Seq[Array[Byte]](Array[Byte](3, 4), Array[Byte](1, 2))) + checkEvaluation(ArrayIntersect(b2, b3), Seq[Array[Byte]](Array[Byte](3, 4), Array[Byte](1, 2))) + checkEvaluation(ArrayIntersect(b3, b2), Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b3, b4), Seq[Array[Byte]](Array[Byte](3, 4), null)) + checkEvaluation(ArrayIntersect(b4, b3), Seq[Array[Byte]](null, Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b4, b5), Seq.empty) + checkEvaluation(ArrayIntersect(b5, b4), Seq.empty) + checkEvaluation(ArrayIntersect(b4, arrayWithBinaryNull), Seq[Array[Byte]](null)) + + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](3, 4), Seq[Int](2, 1), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayIntersect(aa0, aa1), Seq[Seq[Int]](Seq[Int](3, 4))) + checkEvaluation(ArrayIntersect(aa1, aa0), Seq[Seq[Int]](Seq[Int](3, 4))) + + assert(ArrayIntersect(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayIntersect(a00, a04).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayIntersect(a04, a05).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayIntersect(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayIntersect(a23, a24).dataType.asInstanceOf[ArrayType].containsNull === true) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index cc739b85f555c..310e428b69819 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3233,6 +3233,17 @@ object functions { */ def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) } + /** + * Returns an array of the elements in the intersection of the given two arrays, + * without duplicates. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_intersect(col1: Column, col2: Column): Column = withExpr { + ArrayIntersect(col1.expr, col2.expr) + } + /** * Returns an array of the elements in the union of the given two arrays, without duplicates. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index af3301b1599a9..3c5831f33b23c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1647,6 +1647,60 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(result10.first.schema(0).dataType === expectedType10) } + test("array_intersect functions") { + val df1 = Seq((Array(1, 2, 4), Array(4, 2))).toDF("a", "b") + val ans1 = Row(Seq(2, 4)) + checkAnswer(df1.select(array_intersect($"a", $"b")), ans1) + checkAnswer(df1.selectExpr("array_intersect(a, b)"), ans1) + + val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array[Integer](-5, 4, null, 2, -1))) + .toDF("a", "b") + val ans2 = Row(Seq(2, null, 4)) + checkAnswer(df2.select(array_intersect($"a", $"b")), ans2) + checkAnswer(df2.selectExpr("array_intersect(a, b)"), ans2) + + val df3 = Seq((Array(1L, 2L, 4L), Array(4L, 2L))).toDF("a", "b") + val ans3 = Row(Seq(2L, 4L)) + checkAnswer(df3.select(array_intersect($"a", $"b")), ans3) + checkAnswer(df3.selectExpr("array_intersect(a, b)"), ans3) + + val df4 = Seq( + (Array[java.lang.Long](1L, 2L, null, 4L, 5L), Array[java.lang.Long](-5L, 4L, null, 2L, -1L))) + .toDF("a", "b") + val ans4 = Row(Seq(2L, null, 4L)) + checkAnswer(df4.select(array_intersect($"a", $"b")), ans4) + checkAnswer(df4.selectExpr("array_intersect(a, b)"), ans4) + + val df5 = Seq((Array("c", null, "a", "f"), Array("b", "a", null, "g"))).toDF("a", "b") + val ans5 = Row(Seq(null, "a")) + checkAnswer(df5.select(array_intersect($"a", $"b")), ans5) + checkAnswer(df5.selectExpr("array_intersect(a, b)"), ans5) + + val df6 = Seq((null, null)).toDF("a", "b") + assert(intercept[AnalysisException] { + df6.select(array_intersect($"a", $"b")) + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { + df6.selectExpr("array_intersect(a, b)") + }.getMessage.contains("data type mismatch")) + + val df7 = Seq((Array(1), Array("a"))).toDF("a", "b") + assert(intercept[AnalysisException] { + df7.select(array_intersect($"a", $"b")) + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { + df7.selectExpr("array_intersect(a, b)") + }.getMessage.contains("data type mismatch")) + + val df8 = Seq((null, Array("a"))).toDF("a", "b") + assert(intercept[AnalysisException] { + df8.select(array_intersect($"a", $"b")) + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { + df8.selectExpr("array_intersect(a, b)") + }.getMessage.contains("data type mismatch")) + } + test("transform function - array for primitive type not containing null") { val df = Seq( Seq(1, 9, 8, 7), From 51e2b38d93df8cb0cc151d5e68a2190eab52644c Mon Sep 17 00:00:00 2001 From: Hieu Huynh <“Hieu.huynh@oath.com”> Date: Mon, 6 Aug 2018 13:58:28 -0500 Subject: [PATCH 1318/2461] [SPARK-24992][CORE] spark should randomize yarn local dir selection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Description: [SPARK-24992](https://issues.apache.org/jira/browse/SPARK-24992)** Utils.getLocalDir is used to get path of a temporary directory. However, it always returns the the same directory, which is the first element in the array localRootDirs. When running on YARN, this might causes the case that we always write to one disk, which makes it busy while other disks are free. We should randomize the selection to spread out the loads. **What changes were proposed in this pull request?** This PR randomized the selection of local directory inside the method Utils.getLocalDir. This change affects the Utils.fetchFile method since it based on the fact that Utils.getLocalDir always return the same directory to cache file. Therefore, a new variable cachedLocalDir is used to cache the first localDirectory that it gets from Utils.getLocalDir. Also, when getting the configured local directories (inside Utils. getConfiguredLocalDirs), in case we are in yarn mode, the array of directories are also randomized before return. Author: Hieu Huynh <“Hieu.huynh@oath.com”> Closes #21953 from hthuynh2/SPARK_24992. --- .../scala/org/apache/spark/util/Utils.scala | 21 +++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index a6fd3637663e8..7ec707d94ed87 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -83,6 +83,7 @@ private[spark] object Utils extends Logging { val random = new Random() private val sparkUncaughtExceptionHandler = new SparkUncaughtExceptionHandler + @volatile private var cachedLocalDir: String = "" /** * Define a default value for driver memory here since this value is referenced across the code @@ -462,7 +463,15 @@ private[spark] object Utils extends Logging { if (useCache && fetchCacheEnabled) { val cachedFileName = s"${url.hashCode}${timestamp}_cache" val lockFileName = s"${url.hashCode}${timestamp}_lock" - val localDir = new File(getLocalDir(conf)) + // Set the cachedLocalDir for the first time and re-use it later + if (cachedLocalDir.isEmpty) { + this.synchronized { + if (cachedLocalDir.isEmpty) { + cachedLocalDir = getLocalDir(conf) + } + } + } + val localDir = new File(cachedLocalDir) val lockFile = new File(localDir, lockFileName) val lockFileChannel = new RandomAccessFile(lockFile, "rw").getChannel() // Only one executor entry. @@ -767,13 +776,17 @@ private[spark] object Utils extends Logging { * - Otherwise, this will return java.io.tmpdir. * * Some of these configuration options might be lists of multiple paths, but this method will - * always return a single directory. + * always return a single directory. The return directory is chosen randomly from the array + * of directories it gets from getOrCreateLocalRootDirs. */ def getLocalDir(conf: SparkConf): String = { - getOrCreateLocalRootDirs(conf).headOption.getOrElse { + val localRootDirs = getOrCreateLocalRootDirs(conf) + if (localRootDirs.isEmpty) { val configuredLocalDirs = getConfiguredLocalDirs(conf) throw new IOException( s"Failed to get a temp directory under [${configuredLocalDirs.mkString(",")}].") + } else { + localRootDirs(scala.util.Random.nextInt(localRootDirs.length)) } } @@ -815,7 +828,7 @@ private[spark] object Utils extends Logging { // to what Yarn on this system said was available. Note this assumes that Yarn has // created the directories already, and that they are secured so that only the // user has access to them. - getYarnLocalDirs(conf).split(",") + randomizeInPlace(getYarnLocalDirs(conf).split(",")) } else if (conf.getenv("SPARK_EXECUTOR_DIRS") != null) { conf.getenv("SPARK_EXECUTOR_DIRS").split(File.pathSeparator) } else if (conf.getenv("SPARK_LOCAL_DIRS") != null) { From 278984d5a5e56136c9f940f2d0e3d2040fad180b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 6 Aug 2018 12:00:39 -0700 Subject: [PATCH 1319/2461] [SPARK-25019][BUILD] Fix orc dependency to use the same exclusion rules ## What changes were proposed in this pull request? During upgrading Apache ORC to 1.5.2 ([SPARK-24576](https://issues.apache.org/jira/browse/SPARK-24576)), `sql/core` module overrides the exclusion rules of parent pom file and it causes published `spark-sql_2.1X` artifacts have incomplete exclusion rules ([SPARK-25019](https://issues.apache.org/jira/browse/SPARK-25019)). This PR fixes it by moving the newly added exclusion rule to the parent pom. This also fixes the sbt build hack introduced at that time. ## How was this patch tested? Pass the existing dependency check and the tests. Author: Dongjoon Hyun Closes #22003 from dongjoon-hyun/SPARK-25019. --- pom.xml | 4 ++++ sql/core/pom.xml | 28 ---------------------------- 2 files changed, 4 insertions(+), 28 deletions(-) diff --git a/pom.xml b/pom.xml index c46eb31715eab..8abdb700dc4c0 100644 --- a/pom.xml +++ b/pom.xml @@ -1743,6 +1743,10 @@ org.apache.hadoop hadoop-common + + org.apache.hadoop + hadoop-hdfs + org.apache.hive hive-storage-api diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 68b42a4c51ec4..ba17f5f33f2b6 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -90,39 +90,11 @@ org.apache.orc orc-core ${orc.classifier} - - - org.apache.hadoop - hadoop-hdfs - - - - org.apache.hive - hive-storage-api - - org.apache.orc orc-mapreduce ${orc.classifier} - - - org.apache.hadoop - hadoop-hdfs - - - - org.apache.hive - hive-storage-api - - org.apache.parquet From 3c96937c7b1d7a010b630f4b98fd22dafc37808b Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 6 Aug 2018 14:29:05 -0700 Subject: [PATCH 1320/2461] [SPARK-24948][SHS] Delegate check access permissions to the file system ## What changes were proposed in this pull request? In `SparkHadoopUtil. checkAccessPermission`, we consider only basic permissions in order to check wether a user can access a file or not. This is not a complete check, as it ignores ACLs and other policies a file system may apply in its internal. So this can result in returning wrongly that a user cannot access a file (despite he actually can). The PR proposes to delegate to the filesystem the check whether a file is accessible or not, in order to return the right result. A caching layer is added for performance reasons. ## How was this patch tested? modified UTs Author: Marco Gaido Closes #21895 from mgaido91/SPARK-24948. --- .../apache/spark/deploy/SparkHadoopUtil.scala | 23 ----- .../deploy/history/FsHistoryProvider.scala | 67 +++++++++---- .../spark/deploy/SparkHadoopUtilSuite.scala | 97 ------------------- .../history/FsHistoryProviderSuite.scala | 42 +++++++- 4 files changed, 89 insertions(+), 140 deletions(-) delete mode 100644 core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 8353e64a619cf..70a8c659bbdd3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -31,7 +31,6 @@ import scala.util.control.NonFatal import com.google.common.primitives.Longs import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} -import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.{Token, TokenIdentifier} @@ -367,28 +366,6 @@ class SparkHadoopUtil extends Logging { buffer.toString } - private[spark] def checkAccessPermission(status: FileStatus, mode: FsAction): Boolean = { - val perm = status.getPermission - val ugi = UserGroupInformation.getCurrentUser - - if (ugi.getShortUserName == status.getOwner) { - if (perm.getUserAction.implies(mode)) { - return true - } - } else if (ugi.getGroupNames.contains(status.getGroup)) { - if (perm.getGroupAction.implies(mode)) { - return true - } - } else if (perm.getOtherAction.implies(mode)) { - return true - } - - logDebug(s"Permission denied: user=${ugi.getShortUserName}, " + - s"path=${status.getPath}:${status.getOwner}:${status.getGroup}" + - s"${if (status.isDirectory) "d" else "-"}$perm") - false - } - def serialize(creds: Credentials): Array[Byte] = { val byteStream = new ByteArrayOutputStream val dataStream = new DataOutputStream(byteStream) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index bf1eeb0c1bf59..44d23908146c7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -21,11 +21,12 @@ import java.io.{File, FileNotFoundException, IOException} import java.nio.file.Files import java.nio.file.attribute.PosixFilePermissions import java.util.{Date, ServiceLoader} -import java.util.concurrent.{ExecutorService, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, ExecutorService, Future, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.concurrent.ExecutionException import scala.io.Source import scala.util.Try import scala.xml.Node @@ -33,8 +34,7 @@ import scala.xml.Node import com.fasterxml.jackson.annotation.JsonIgnore import com.google.common.io.ByteStreams import com.google.common.util.concurrent.MoreExecutors -import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.fs.permission.FsAction +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.hdfs.DistributedFileSystem import org.apache.hadoop.hdfs.protocol.HdfsConstants import org.apache.hadoop.security.AccessControlException @@ -114,7 +114,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) "; groups with admin permissions" + HISTORY_UI_ADMIN_ACLS_GROUPS.toString) private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - private val fs = new Path(logDir).getFileSystem(hadoopConf) + // Visible for testing + private[history] val fs: FileSystem = new Path(logDir).getFileSystem(hadoopConf) // Used by check event thread and clean log thread. // Scheduled thread pool size must be one, otherwise it will have concurrent issues about fs @@ -161,6 +162,25 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) new HistoryServerDiskManager(conf, path, listing, clock) } + private val blacklist = new ConcurrentHashMap[String, Long] + + // Visible for testing + private[history] def isBlacklisted(path: Path): Boolean = { + blacklist.containsKey(path.getName) + } + + private def blacklist(path: Path): Unit = { + blacklist.put(path.getName, clock.getTimeMillis()) + } + + /** + * Removes expired entries in the blacklist, according to the provided `expireTimeInSeconds`. + */ + private def clearBlacklist(expireTimeInSeconds: Long): Unit = { + val expiredThreshold = clock.getTimeMillis() - expireTimeInSeconds * 1000 + blacklist.asScala.retain((_, creationTime) => creationTime >= expiredThreshold) + } + private val activeUIs = new mutable.HashMap[(String, Option[String]), LoadedAppUI]() /** @@ -418,7 +438,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // reading a garbage file is safe, but we would log an error which can be scary to // the end-user. !entry.getPath().getName().startsWith(".") && - SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) + !isBlacklisted(entry.getPath) } .filter { entry => try { @@ -461,32 +481,37 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) logDebug(s"New/updated attempts found: ${updated.size} ${updated.map(_.getPath)}") } - val tasks = updated.map { entry => + val tasks = updated.flatMap { entry => try { - replayExecutor.submit(new Runnable { + val task: Future[Unit] = replayExecutor.submit(new Runnable { override def run(): Unit = mergeApplicationListing(entry, newLastScanTime, true) - }) + }, Unit) + Some(task -> entry.getPath) } catch { // let the iteration over the updated entries break, since an exception on // replayExecutor.submit (..) indicates the ExecutorService is unable // to take any more submissions at this time case e: Exception => logError(s"Exception while submitting event log for replay", e) - null + None } - }.filter(_ != null) + } pendingReplayTasksCount.addAndGet(tasks.size) // Wait for all tasks to finish. This makes sure that checkForLogs // is not scheduled again while some tasks are already running in // the replayExecutor. - tasks.foreach { task => + tasks.foreach { case (task, path) => try { task.get() } catch { case e: InterruptedException => throw e + case e: ExecutionException if e.getCause.isInstanceOf[AccessControlException] => + // We don't have read permissions on the log file + logWarning(s"Unable to read log $path", e.getCause) + blacklist(path) case e: Exception => logError("Exception while merging application listings", e) } finally { @@ -779,6 +804,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) listing.delete(classOf[LogInfo], log.logPath) } } + // Clean the blacklist from the expired entries. + clearBlacklist(CLEAN_INTERVAL_S) } /** @@ -938,13 +965,17 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } private def deleteLog(log: Path): Unit = { - try { - fs.delete(log, true) - } catch { - case _: AccessControlException => - logInfo(s"No permission to delete $log, ignoring.") - case ioe: IOException => - logError(s"IOException in cleaning $log", ioe) + if (isBlacklisted(log)) { + logDebug(s"Skipping deleting $log as we don't have permissions on it.") + } else { + try { + fs.delete(log, true) + } catch { + case _: AccessControlException => + logInfo(s"No permission to delete $log, ignoring.") + case ioe: IOException => + logError(s"IOException in cleaning $log", ioe) + } } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala deleted file mode 100644 index ab24a76e20a30..0000000000000 --- a/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy - -import java.security.PrivilegedExceptionAction - -import scala.util.Random - -import org.apache.hadoop.fs.FileStatus -import org.apache.hadoop.fs.permission.{FsAction, FsPermission} -import org.apache.hadoop.security.UserGroupInformation -import org.scalatest.Matchers - -import org.apache.spark.SparkFunSuite - -class SparkHadoopUtilSuite extends SparkFunSuite with Matchers { - test("check file permission") { - import FsAction._ - val testUser = s"user-${Random.nextInt(100)}" - val testGroups = Array(s"group-${Random.nextInt(100)}") - val testUgi = UserGroupInformation.createUserForTesting(testUser, testGroups) - - testUgi.doAs(new PrivilegedExceptionAction[Void] { - override def run(): Void = { - val sparkHadoopUtil = new SparkHadoopUtil - - // If file is owned by user and user has access permission - var status = fileStatus(testUser, testGroups.head, READ_WRITE, READ_WRITE, NONE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) - - // If file is owned by user but user has no access permission - status = fileStatus(testUser, testGroups.head, NONE, READ_WRITE, NONE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) - - val otherUser = s"test-${Random.nextInt(100)}" - val otherGroup = s"test-${Random.nextInt(100)}" - - // If file is owned by user's group and user's group has access permission - status = fileStatus(otherUser, testGroups.head, NONE, READ_WRITE, NONE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) - - // If file is owned by user's group but user's group has no access permission - status = fileStatus(otherUser, testGroups.head, READ_WRITE, NONE, NONE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) - - // If file is owned by other user and this user has access permission - status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, READ_WRITE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) - - // If file is owned by other user but this user has no access permission - status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, NONE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) - - null - } - }) - } - - private def fileStatus( - owner: String, - group: String, - userAction: FsAction, - groupAction: FsAction, - otherAction: FsAction): FileStatus = { - new FileStatus(0L, - false, - 0, - 0L, - 0L, - 0L, - new FsPermission(userAction, groupAction, otherAction), - owner, - group, - null) - } -} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 77b239489d489..b4eba755eccbf 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -29,9 +29,11 @@ import scala.language.postfixOps import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hdfs.DistributedFileSystem +import org.apache.hadoop.security.AccessControlException import org.json4s.jackson.JsonMethods._ -import org.mockito.Matchers.any -import org.mockito.Mockito.{mock, spy, verify} +import org.mockito.ArgumentMatcher +import org.mockito.Matchers.{any, argThat} +import org.mockito.Mockito.{doThrow, mock, spy, verify, when} import org.scalatest.BeforeAndAfter import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ @@ -818,6 +820,42 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } + test("SPARK-24948: blacklist files we don't have read permission on") { + val clock = new ManualClock(1533132471) + val provider = new FsHistoryProvider(createTestConf(), clock) + val accessDenied = newLogFile("accessDenied", None, inProgress = false) + writeFile(accessDenied, true, None, + SparkListenerApplicationStart("accessDenied", Some("accessDenied"), 1L, "test", None)) + val accessGranted = newLogFile("accessGranted", None, inProgress = false) + writeFile(accessGranted, true, None, + SparkListenerApplicationStart("accessGranted", Some("accessGranted"), 1L, "test", None), + SparkListenerApplicationEnd(5L)) + val mockedFs = spy(provider.fs) + doThrow(new AccessControlException("Cannot read accessDenied file")).when(mockedFs).open( + argThat(new ArgumentMatcher[Path]() { + override def matches(path: Any): Boolean = { + path.asInstanceOf[Path].getName.toLowerCase == "accessdenied" + } + })) + val mockedProvider = spy(provider) + when(mockedProvider.fs).thenReturn(mockedFs) + updateAndCheck(mockedProvider) { list => + list.size should be(1) + } + writeFile(accessDenied, true, None, + SparkListenerApplicationStart("accessDenied", Some("accessDenied"), 1L, "test", None), + SparkListenerApplicationEnd(5L)) + // Doing 2 times in order to check the blacklist filter too + updateAndCheck(mockedProvider) { list => + list.size should be(1) + } + val accessDeniedPath = new Path(accessDenied.getPath) + assert(mockedProvider.isBlacklisted(accessDeniedPath)) + clock.advance(24 * 60 * 60 * 1000 + 1) // add a bit more than 1d + mockedProvider.cleanLogs() + assert(!mockedProvider.isBlacklisted(accessDeniedPath)) + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: From 87ca7396c7b21a87874d8ceb32e53119c609002c Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 6 Aug 2018 15:23:47 -0700 Subject: [PATCH 1321/2461] [SPARK-24161][SS] Enable debug package feature on structured streaming ## What changes were proposed in this pull request? Currently, debug package has a implicit class "DebugQuery" which matches Dataset to provide debug features on Dataset class. It doesn't work with structured streaming: it requires query is already started, and the information can be retrieved from StreamingQuery, not Dataset. I guess that's why "explain" had to be placed to StreamingQuery whereas it already exists on Dataset. This patch adds a new implicit class "DebugStreamQuery" which matches StreamingQuery to provide similar debug features on StreamingQuery class. ## How was this patch tested? Added relevant unit tests. Author: Jungtaek Lim Closes #21222 from HeartSaVioR/SPARK-24161. --- .../spark/sql/execution/debug/package.scala | 59 ++++++++- .../spark/sql/streaming/StreamSuite.scala | 116 ++++++++++++++++++ 2 files changed, 173 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index a717cbd4a7df9..366e1fe6a4aaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -29,6 +29,9 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.trees.TreeNodeRef +import org.apache.spark.sql.execution.streaming.{StreamExecution, StreamingQueryWrapper} +import org.apache.spark.sql.execution.streaming.continuous.WriteToContinuousDataSourceExec +import org.apache.spark.sql.streaming.StreamingQuery import org.apache.spark.util.{AccumulatorV2, LongAccumulator} /** @@ -40,6 +43,16 @@ import org.apache.spark.util.{AccumulatorV2, LongAccumulator} * sql("SELECT 1").debug() * sql("SELECT 1").debugCodegen() * }}} + * + * or for streaming case (structured streaming): + * {{{ + * import org.apache.spark.sql.execution.debug._ + * val query = df.writeStream.<...>.start() + * query.debugCodegen() + * }}} + * + * Note that debug in structured streaming is not supported, because it doesn't make sense for + * streaming to execute batch once while main query is running concurrently. */ package object debug { @@ -88,14 +101,50 @@ package object debug { } } + /** + * Get WholeStageCodegenExec subtrees and the codegen in a query plan into one String + * + * @param query the streaming query for codegen + * @return single String containing all WholeStageCodegen subtrees and corresponding codegen + */ + def codegenString(query: StreamingQuery): String = { + val w = asStreamExecution(query) + if (w.lastExecution != null) { + codegenString(w.lastExecution.executedPlan) + } else { + "No physical plan. Waiting for data." + } + } + + /** + * Get WholeStageCodegenExec subtrees and the codegen in a query plan + * + * @param query the streaming query for codegen + * @return Sequence of WholeStageCodegen subtrees and corresponding codegen + */ + def codegenStringSeq(query: StreamingQuery): Seq[(String, String)] = { + val w = asStreamExecution(query) + if (w.lastExecution != null) { + codegenStringSeq(w.lastExecution.executedPlan) + } else { + Seq.empty + } + } + + private def asStreamExecution(query: StreamingQuery): StreamExecution = query match { + case wrapper: StreamingQueryWrapper => wrapper.streamingQuery + case q: StreamExecution => q + case _ => throw new IllegalArgumentException("Parameter should be an instance of " + + "StreamExecution!") + } + /** * Augments [[Dataset]]s with debug methods. */ implicit class DebugQuery(query: Dataset[_]) extends Logging { def debug(): Unit = { - val plan = query.queryExecution.executedPlan val visited = new collection.mutable.HashSet[TreeNodeRef]() - val debugPlan = plan transform { + val debugPlan = query.queryExecution.executedPlan transform { case s: SparkPlan if !visited.contains(new TreeNodeRef(s)) => visited += new TreeNodeRef(s) DebugExec(s) @@ -116,6 +165,12 @@ package object debug { } } + implicit class DebugStreamQuery(query: StreamingQuery) extends Logging { + def debugCodegen(): Unit = { + debugPrint(codegenString(query)) + } + } + case class DebugExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport { def output: Seq[Attribute] = child.output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index ca38f04136c7d..bf509b1976ed8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -27,6 +27,7 @@ import scala.util.control.ControlThrowable import com.google.common.util.concurrent.UncheckedExecutionException import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration +import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} @@ -35,6 +36,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Range import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -513,6 +515,120 @@ class StreamSuite extends StreamTest { } } + test("explain-continuous") { + val inputData = ContinuousMemoryStream[Int] + val df = inputData.toDS().map(_ * 2).filter(_ > 5) + + // Test `df.explain` + val explain = ExplainCommand(df.queryExecution.logical, extended = false) + val explainString = + spark.sessionState + .executePlan(explain) + .executedPlan + .executeCollect() + .map(_.getString(0)) + .mkString("\n") + assert(explainString.contains("Filter")) + assert(explainString.contains("MapElements")) + assert(!explainString.contains("LocalTableScan")) + + // Test StreamingQuery.display + val q = df.writeStream.queryName("memory_continuous_explain") + .outputMode(OutputMode.Update()).format("memory") + .trigger(Trigger.Continuous("1 seconds")) + .start() + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery + try { + // in continuous mode, the query will be run even there's no data + // sleep a bit to ensure initialization + eventually(timeout(2.seconds), interval(100.milliseconds)) { + assert(q.lastExecution != null) + } + + val explainWithoutExtended = q.explainInternal(false) + + // `extended = false` only displays the physical plan. + assert("Streaming RelationV2 ContinuousMemoryStream".r + .findAllMatchIn(explainWithoutExtended).size === 0) + assert("ScanV2 ContinuousMemoryStream".r + .findAllMatchIn(explainWithoutExtended).size === 1) + + val explainWithExtended = q.explainInternal(true) + // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical + // plan. + assert("Streaming RelationV2 ContinuousMemoryStream".r + .findAllMatchIn(explainWithExtended).size === 3) + assert("ScanV2 ContinuousMemoryStream".r + .findAllMatchIn(explainWithExtended).size === 1) + } finally { + q.stop() + } + } + + test("codegen-microbatch") { + val inputData = MemoryStream[Int] + val df = inputData.toDS().map(_ * 2).filter(_ > 5) + + // Test StreamingQuery.codegen + val q = df.writeStream.queryName("memory_microbatch_codegen") + .outputMode(OutputMode.Update) + .format("memory") + .trigger(Trigger.ProcessingTime("1 seconds")) + .start() + + try { + import org.apache.spark.sql.execution.debug._ + assert("No physical plan. Waiting for data." === codegenString(q)) + assert(codegenStringSeq(q).isEmpty) + + inputData.addData(1, 2, 3, 4, 5) + q.processAllAvailable() + + assertDebugCodegenResult(q) + } finally { + q.stop() + } + } + + test("codegen-continuous") { + val inputData = ContinuousMemoryStream[Int] + val df = inputData.toDS().map(_ * 2).filter(_ > 5) + + // Test StreamingQuery.codegen + val q = df.writeStream.queryName("memory_continuous_codegen") + .outputMode(OutputMode.Update) + .format("memory") + .trigger(Trigger.Continuous("1 seconds")) + .start() + + try { + // in continuous mode, the query will be run even there's no data + // sleep a bit to ensure initialization + eventually(timeout(2.seconds), interval(100.milliseconds)) { + assert(q.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution != null) + } + + assertDebugCodegenResult(q) + } finally { + q.stop() + } + } + + private def assertDebugCodegenResult(query: StreamingQuery): Unit = { + import org.apache.spark.sql.execution.debug._ + + val codegenStr = codegenString(query) + assert(codegenStr.contains("Found 1 WholeStageCodegen subtrees.")) + // assuming that code is generated for the test query + assert(codegenStr.contains("Generated code:")) + + val codegenStrSeq = codegenStringSeq(query) + assert(codegenStrSeq.nonEmpty) + assert(codegenStrSeq.head._1.contains("*(1)")) + assert(codegenStrSeq.head._2.contains("codegenStageId=1")) + } + test("SPARK-19065: dropDuplicates should not create expressions using the same id") { withTempPath { testPath => val data = Seq((1, 2), (2, 3), (3, 4)) From 408a3ff2c484fba5734c03dbc570b654dcbc1f23 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 6 Aug 2018 19:43:21 -0400 Subject: [PATCH 1322/2461] [SPARK-25036][SQL] Should compare ExprValue.isNull with LiteralTrue/LiteralFalse ## What changes were proposed in this pull request? This PR fixes a comparison of `ExprValue.isNull` with `String`. `ExprValue.isNull` should be compared with `LiteralTrue` or `LiteralFalse`. This causes the following compilation error using scala-2.12 with sbt. In addition, this code may also generate incorrect code in Spark 2.3. ``` /home/ishizaki/Spark/PR/scala212/spark/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala:94: org.apache.spark.sql.catalyst.expressions.codegen.ExprValue and String are unrelated: they will most likely always compare unequal [error] [warn] if (eval.isNull != "true") { [error] [warn] [error] [warn] /home/ishizaki/Spark/PR/scala212/spark/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala:126: org.apache.spark.sql.catalyst.expressions.codegen.ExprValue and String are unrelated: they will most likely never compare equal [error] [warn] if (eval.isNull == "true") { [error] [warn] [error] [warn] /home/ishizaki/Spark/PR/scala212/spark/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala:133: org.apache.spark.sql.catalyst.expressions.codegen.ExprValue and String are unrelated: they will most likely never compare equal [error] [warn] if (eval.isNull == "true") { [error] [warn] [error] [warn] /home/ishizaki/Spark/PR/scala212/spark/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala:90: org.apache.spark.sql.catalyst.expressions.codegen.ExprValue and String are unrelated: they will most likely never compare equal [error] [warn] if (inputs.map(_.isNull).forall(_ == "false")) { [error] [warn] ``` ## How was this patch tested? Existing UTs Author: Kazuaki Ishizaki Closes #22012 from kiszk/SPARK-25036a. --- .../expressions/codegen/GenerateUnsafeProjection.scala | 2 +- .../spark/sql/catalyst/expressions/stringExpressions.scala | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 8f2a5a0dce943..998a675eecc62 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -87,7 +87,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // For top level row writer, it always writes to the beginning of the global buffer holder, // which means its fixed-size region always in the same position, so we don't need to call // `reset` to set up its fixed-size region every time. - if (inputs.map(_.isNull).forall(_ == "false")) { + if (inputs.map(_.isNull).forall(_ == FalseLiteral)) { // If all fields are not nullable, which means the null bits never changes, then we don't // need to clear it out every time. "" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 1838b9fca02db..e1549d3dee539 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -91,7 +91,7 @@ case class ConcatWs(children: Seq[Expression]) val args = ctx.freshName("args") val inputs = strings.zipWithIndex.map { case (eval, index) => - if (eval.isNull != "true") { + if (eval.isNull != TrueLiteral) { s""" ${eval.code} if (!${eval.isNull}) { @@ -123,14 +123,14 @@ case class ConcatWs(children: Seq[Expression]) child.dataType match { case StringType => ("", // we count all the StringType arguments num at once below. - if (eval.isNull == "true") { + if (eval.isNull == TrueLiteral) { "" } else { s"$array[$idxVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};" }) case _: ArrayType => val size = ctx.freshName("n") - if (eval.isNull == "true") { + if (eval.isNull == TrueLiteral) { ("", "") } else { (s""" From 0f3fa2f289f53a8ceea3b0a52fa6dc319001b10b Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 6 Aug 2018 19:46:51 -0400 Subject: [PATCH 1323/2461] [SPARK-24996][SQL] Use DSL in DeclarativeAggregate ## What changes were proposed in this pull request? The PR refactors the aggregate expressions which were not using DSL in order to simplify them. ## How was this patch tested? NA Author: Marco Gaido Closes #21970 from mgaido91/SPARK-24996. --- .../spark/sql/catalyst/dsl/package.scala | 2 + .../expressions/aggregate/Average.scala | 2 +- .../aggregate/CentralMomentAgg.scala | 40 ++++++++----------- .../catalyst/expressions/aggregate/Corr.scala | 13 +++--- .../expressions/aggregate/Covariance.scala | 16 ++++---- .../expressions/aggregate/First.scala | 7 ++-- .../catalyst/expressions/aggregate/Last.scala | 7 ++-- .../catalyst/expressions/aggregate/Max.scala | 5 ++- .../catalyst/expressions/aggregate/Min.scala | 5 ++- .../catalyst/expressions/aggregate/Sum.scala | 7 ++-- .../expressions/windowExpressions.scala | 30 +++++++------- 11 files changed, 65 insertions(+), 69 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 75387fac64ed8..2b582b5be61a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -167,6 +167,8 @@ package object dsl { def upper(e: Expression): Expression = Upper(e) def lower(e: Expression): Expression = Lower(e) def coalesce(args: Expression*): Expression = Coalesce(args) + def greatest(args: Expression*): Expression = Greatest(args) + def least(args: Expression*): Expression = Least(args) def sqrt(e: Expression): Expression = Sqrt(e) def abs(e: Expression): Expression = Abs(e) def star(names: String*): Expression = names match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index f1fad770b637f..5ecb77be5965e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -68,7 +68,7 @@ abstract class AverageLike(child: Expression) extends DeclarativeAggregate { Add( sum, coalesce(child.cast(sumDataType), Literal(0).cast(sumDataType))), - /* count = */ If(IsNull(child), count, count + 1L) + /* count = */ If(child.isNull, count, count + 1L) ) override lazy val updateExpressions = updateExpressionsDef diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 6bbb083f1e18e..e2ff0efba07ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -75,7 +75,7 @@ abstract class CentralMomentAgg(child: Expression) val n2 = n.right val newN = n1 + n2 val delta = avg.right - avg.left - val deltaN = If(newN === Literal(0.0), Literal(0.0), delta / newN) + val deltaN = If(newN === 0.0, 0.0, delta / newN) val newAvg = avg.left + deltaN * n2 // higher order moments computed according to: @@ -102,7 +102,7 @@ abstract class CentralMomentAgg(child: Expression) } protected def updateExpressionsDef: Seq[Expression] = { - val newN = n + Literal(1.0) + val newN = n + 1.0 val delta = child - avg val deltaN = delta / newN val newAvg = avg + deltaN @@ -123,11 +123,11 @@ abstract class CentralMomentAgg(child: Expression) } trimHigherOrder(Seq( - If(IsNull(child), n, newN), - If(IsNull(child), avg, newAvg), - If(IsNull(child), m2, newM2), - If(IsNull(child), m3, newM3), - If(IsNull(child), m4, newM4) + If(child.isNull, n, newN), + If(child.isNull, avg, newAvg), + If(child.isNull, m2, newM2), + If(child.isNull, m3, newM3), + If(child.isNull, m4, newM4) )) } } @@ -142,8 +142,7 @@ case class StddevPop(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - Sqrt(m2 / n)) + If(n === 0.0, Literal.create(null, DoubleType), sqrt(m2 / n)) } override def prettyName: String = "stddev_pop" @@ -159,9 +158,8 @@ case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(n === Literal(1.0), Literal(Double.NaN), - Sqrt(m2 / (n - Literal(1.0))))) + If(n === 0.0, Literal.create(null, DoubleType), + If(n === 1.0, Double.NaN, sqrt(m2 / (n - 1.0)))) } override def prettyName: String = "stddev_samp" @@ -175,8 +173,7 @@ case class VariancePop(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - m2 / n) + If(n === 0.0, Literal.create(null, DoubleType), m2 / n) } override def prettyName: String = "var_pop" @@ -190,9 +187,8 @@ case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(n === Literal(1.0), Literal(Double.NaN), - m2 / (n - Literal(1.0)))) + If(n === 0.0, Literal.create(null, DoubleType), + If(n === 1.0, Double.NaN, m2 / (n - 1.0))) } override def prettyName: String = "var_samp" @@ -207,9 +203,8 @@ case class Skewness(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 3 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(m2 === Literal(0.0), Literal(Double.NaN), - Sqrt(n) * m3 / Sqrt(m2 * m2 * m2))) + If(n === 0.0, Literal.create(null, DoubleType), + If(m2 === 0.0, Double.NaN, sqrt(n) * m3 / sqrt(m2 * m2 * m2))) } } @@ -220,9 +215,8 @@ case class Kurtosis(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 4 override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(m2 === Literal(0.0), Literal(Double.NaN), - n * m4 / (m2 * m2) - Literal(3.0))) + If(n === 0.0, Literal.create(null, DoubleType), + If(m2 === 0.0, Double.NaN, n * m4 / (m2 * m2) - 3.0)) } override def prettyName: String = "kurtosis" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index 3cdef72c1f2c4..e14cc716ea223 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -54,9 +54,9 @@ abstract class PearsonCorrelation(x: Expression, y: Expression) val n2 = n.right val newN = n1 + n2 val dx = xAvg.right - xAvg.left - val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) + val dxN = If(newN === 0.0, 0.0, dx / newN) val dy = yAvg.right - yAvg.left - val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) + val dyN = If(newN === 0.0, 0.0, dy / newN) val newXAvg = xAvg.left + dxN * n2 val newYAvg = yAvg.left + dyN * n2 val newCk = ck.left + ck.right + dx * dyN * n1 * n2 @@ -67,7 +67,7 @@ abstract class PearsonCorrelation(x: Expression, y: Expression) } protected def updateExpressionsDef: Seq[Expression] = { - val newN = n + Literal(1.0) + val newN = n + 1.0 val dx = x - xAvg val dxN = dx / newN val dy = y - yAvg @@ -78,7 +78,7 @@ abstract class PearsonCorrelation(x: Expression, y: Expression) val newXMk = xMk + dx * (x - newXAvg) val newYMk = yMk + dy * (y - newYAvg) - val isNull = IsNull(x) || IsNull(y) + val isNull = x.isNull || y.isNull Seq( If(isNull, n, newN), If(isNull, xAvg, newXAvg), @@ -99,9 +99,8 @@ case class Corr(x: Expression, y: Expression) extends PearsonCorrelation(x, y) { override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(n === Literal(1.0), Literal(Double.NaN), - ck / Sqrt(xMk * yMk))) + If(n === 0.0, Literal.create(null, DoubleType), + If(n === 1.0, Double.NaN, ck / sqrt(xMk * yMk))) } override def prettyName: String = "corr" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index 72a7c62b328ee..ee28eb591882f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -50,9 +50,9 @@ abstract class Covariance(x: Expression, y: Expression) val n2 = n.right val newN = n1 + n2 val dx = xAvg.right - xAvg.left - val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) + val dxN = If(newN === 0.0, 0.0, dx / newN) val dy = yAvg.right - yAvg.left - val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) + val dyN = If(newN === 0.0, 0.0, dy / newN) val newXAvg = xAvg.left + dxN * n2 val newYAvg = yAvg.left + dyN * n2 val newCk = ck.left + ck.right + dx * dyN * n1 * n2 @@ -61,7 +61,7 @@ abstract class Covariance(x: Expression, y: Expression) } protected def updateExpressionsDef: Seq[Expression] = { - val newN = n + Literal(1.0) + val newN = n + 1.0 val dx = x - xAvg val dy = y - yAvg val dyN = dy / newN @@ -69,7 +69,7 @@ abstract class Covariance(x: Expression, y: Expression) val newYAvg = yAvg + dyN val newCk = ck + dx * (y - newYAvg) - val isNull = IsNull(x) || IsNull(y) + val isNull = x.isNull || y.isNull Seq( If(isNull, n, newN), If(isNull, xAvg, newXAvg), @@ -83,8 +83,7 @@ abstract class Covariance(x: Expression, y: Expression) usage = "_FUNC_(expr1, expr2) - Returns the population covariance of a set of number pairs.") case class CovPopulation(left: Expression, right: Expression) extends Covariance(left, right) { override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - ck / n) + If(n === 0.0, Literal.create(null, DoubleType), ck / n) } override def prettyName: String = "covar_pop" } @@ -94,9 +93,8 @@ case class CovPopulation(left: Expression, right: Expression) extends Covariance usage = "_FUNC_(expr1, expr2) - Returns the sample covariance of a set of number pairs.") case class CovSample(left: Expression, right: Expression) extends Covariance(left, right) { override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), - If(n === Literal(1.0), Literal(Double.NaN), - ck / (n - Literal(1.0)))) + If(n === 0.0, Literal.create(null, DoubleType), + If(n === 1.0, Double.NaN, ck / (n - 1.0))) } override def prettyName: String = "covar_samp" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index 4e671e1f3e6eb..f51bfd591204a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -80,8 +81,8 @@ case class First(child: Expression, ignoreNullsExpr: Expression) override lazy val updateExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( - /* first = */ If(Or(valueSet, IsNull(child)), first, child), - /* valueSet = */ Or(valueSet, IsNotNull(child)) + /* first = */ If(valueSet || child.isNull, first, child), + /* valueSet = */ valueSet || child.isNotNull ) } else { Seq( @@ -97,7 +98,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) // false, we are safe to do so because first.right will be null in this case). Seq( /* first = */ If(valueSet.left, first.left, first.right), - /* valueSet = */ Or(valueSet.left, valueSet.right) + /* valueSet = */ valueSet.left || valueSet.right ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index 0ccabb9d98914..2650d7b5908fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -80,8 +81,8 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) override lazy val updateExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( - /* last = */ If(IsNull(child), last, child), - /* valueSet = */ Or(valueSet, IsNotNull(child)) + /* last = */ If(child.isNull, last, child), + /* valueSet = */ valueSet || child.isNotNull ) } else { Seq( @@ -95,7 +96,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) // Prefer the right hand expression if it has been set. Seq( /* last = */ If(valueSet.right, last.right, last.left), - /* valueSet = */ Or(valueSet.right, valueSet.left) + /* valueSet = */ valueSet.right || valueSet.left ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index 58fd1d8620e16..71099eba0fc75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -45,12 +46,12 @@ case class Max(child: Expression) extends DeclarativeAggregate { ) override lazy val updateExpressions: Seq[Expression] = Seq( - /* max = */ Greatest(Seq(max, child)) + /* max = */ greatest(max, child) ) override lazy val mergeExpressions: Seq[Expression] = { Seq( - /* max = */ Greatest(Seq(max.left, max.right)) + /* max = */ greatest(max.left, max.right) ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index b2724ee76827c..8c4ba93231cbe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -45,12 +46,12 @@ case class Min(child: Expression) extends DeclarativeAggregate { ) override lazy val updateExpressions: Seq[Expression] = Seq( - /* min = */ Least(Seq(min, child)) + /* min = */ least(min, child) ) override lazy val mergeExpressions: Seq[Expression] = { Seq( - /* min = */ Least(Seq(min.left, min.right)) + /* min = */ least(min.left, min.right) ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 86e40a9713b36..761dba111c074 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -61,12 +62,12 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast if (child.nullable) { Seq( /* sum = */ - Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) + coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) ) } else { Seq( /* sum = */ - Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)) + coalesce(sum, zero) + child.cast(sumDataType) ) } } @@ -74,7 +75,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast override lazy val mergeExpressions: Seq[Expression] = { Seq( /* sum = */ - Coalesce(Seq(Add(Coalesce(Seq(sum.left, zero)), sum.right), sum.left)) + coalesce(coalesce(sum.left, zero) + sum.right, sum.left) ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 53c6f01c2459e..707f312499734 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -21,6 +21,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, NoOp} import org.apache.spark.sql.types._ @@ -476,7 +477,7 @@ abstract class RowNumberLike extends AggregateWindowFunction { protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)() override val aggBufferAttributes: Seq[AttributeReference] = rowNumber :: Nil override val initialValues: Seq[Expression] = zero :: Nil - override val updateExpressions: Seq[Expression] = Add(rowNumber, one) :: Nil + override val updateExpressions: Seq[Expression] = rowNumber + one :: Nil } /** @@ -527,7 +528,7 @@ case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction { // The frame for CUME_DIST is Range based instead of Row based, because CUME_DIST must // return the same value for equal values in the partition. override val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) - override val evaluateExpression = Divide(Cast(rowNumber, DoubleType), Cast(n, DoubleType)) + override val evaluateExpression = rowNumber.cast(DoubleType) / n.cast(DoubleType) override def prettyName: String = "cume_dist" } @@ -587,8 +588,7 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindow private val bucketSize = AttributeReference("bucketSize", IntegerType, nullable = false)() private val bucketsWithPadding = AttributeReference("bucketsWithPadding", IntegerType, nullable = false)() - private def bucketOverflow(e: Expression) = - If(GreaterThanOrEqual(rowNumber, bucketThreshold), e, zero) + private def bucketOverflow(e: Expression) = If(rowNumber >= bucketThreshold, e, zero) override val aggBufferAttributes = Seq( rowNumber, @@ -602,15 +602,14 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindow zero, zero, zero, - Cast(Divide(n, buckets), IntegerType), - Cast(Remainder(n, buckets), IntegerType) + (n / buckets).cast(IntegerType), + (n % buckets).cast(IntegerType) ) override val updateExpressions = Seq( - Add(rowNumber, one), - Add(bucket, bucketOverflow(one)), - Add(bucketThreshold, bucketOverflow( - Add(bucketSize, If(LessThan(bucket, bucketsWithPadding), one, zero)))), + rowNumber + one, + bucket + bucketOverflow(one), + bucketThreshold + bucketOverflow(bucketSize + If(bucket < bucketsWithPadding, one, zero)), NoOp, NoOp ) @@ -644,7 +643,7 @@ abstract class RankLike extends AggregateWindowFunction { protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)() protected val zero = Literal(0) protected val one = Literal(1) - protected val increaseRowNumber = Add(rowNumber, one) + protected val increaseRowNumber = rowNumber + one /** * Different RankLike implementations use different source expressions to update their rank value. @@ -653,7 +652,7 @@ abstract class RankLike extends AggregateWindowFunction { protected def rankSource: Expression = rowNumber /** Increase the rank when the current rank == 0 or when the one of order attributes changes. */ - protected val increaseRank = If(And(orderEquals, Not(EqualTo(rank, zero))), rank, rankSource) + protected val increaseRank = If(orderEquals && rank =!= zero, rank, rankSource) override val aggBufferAttributes: Seq[AttributeReference] = rank +: rowNumber +: orderAttrs override val initialValues = zero +: one +: orderInit @@ -707,7 +706,7 @@ case class Rank(children: Seq[Expression]) extends RankLike { case class DenseRank(children: Seq[Expression]) extends RankLike { def this() = this(Nil) override def withOrder(order: Seq[Expression]): DenseRank = DenseRank(order) - override protected def rankSource = Add(rank, one) + override protected def rankSource = rank + one override val updateExpressions = increaseRank +: children override val aggBufferAttributes = rank +: orderAttrs override val initialValues = zero +: orderInit @@ -736,8 +735,7 @@ case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBase def this() = this(Nil) override def withOrder(order: Seq[Expression]): PercentRank = PercentRank(order) override def dataType: DataType = DoubleType - override val evaluateExpression = If(GreaterThan(n, one), - Divide(Cast(Subtract(rank, one), DoubleType), Cast(Subtract(n, one), DoubleType)), - Literal(0.0d)) + override val evaluateExpression = + If(n > one, (rank - one).cast(DoubleType) / (n - one).cast(DoubleType), 0.0d) override def prettyName: String = "percent_rank" } From 1076e4f0026914804b5948ff0da0c84def1315cc Mon Sep 17 00:00:00 2001 From: deshanxiao <42019462+deshanxiao@users.noreply.github.com> Date: Tue, 7 Aug 2018 09:36:37 +0800 Subject: [PATCH 1324/2461] [MINOR][DOCS] Fix grammatical error in SortShuffleManager ## What changes were proposed in this pull request? Fix a grammatical error in the comment of SortShuffleManager. ## How was this patch tested? N/A Closes #21956 from deshanxiao/master. Authored-by: deshanxiao <42019462+deshanxiao@users.noreply.github.com> Signed-off-by: hyukjinkwon --- .../org/apache/spark/shuffle/sort/SortShuffleManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index d9fad64f34c7c..0caf84c6050a8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -27,7 +27,7 @@ import org.apache.spark.shuffle._ * In sort-based shuffle, incoming records are sorted according to their target partition ids, then * written to a single map output file. Reducers fetch contiguous regions of this file in order to * read their portion of the map output. In cases where the map output data is too large to fit in - * memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged + * memory, sorted subsets of the output can be spilled to disk and those on-disk files are merged * to produce the final output file. * * Sort-based shuffle has two different write paths for producing its map output files: From 6afe6f32ca2880b13bb5fb4397b2058eef12952b Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 7 Aug 2018 10:12:22 +0800 Subject: [PATCH 1325/2461] [SPARK-24637][SS] Add metrics regarding state and watermark to dropwizard metrics ## What changes were proposed in this pull request? The patch adds metrics regarding state and watermark to dropwizard metrics, so that watermark and state rows/size can be tracked via time-series manner. ## How was this patch tested? Manually tested with CSV metric sink. Closes #21622 from HeartSaVioR/SPARK-24637. Authored-by: Jungtaek Lim Signed-off-by: hyukjinkwon --- .../execution/streaming/MetricsReporter.scala | 20 +++++++++++++++++++ .../sql/streaming/StreamingQuerySuite.scala | 3 +++ 2 files changed, 23 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala index 66b11ecddf233..8709822acff12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.execution.streaming +import java.text.SimpleDateFormat + import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.{Source => CodahaleSource} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.streaming.StreamingQueryProgress /** @@ -39,6 +42,23 @@ class MetricsReporter( registerGauge("processingRate-total", _.processedRowsPerSecond, 0.0) registerGauge("latency", _.durationMs.get("triggerExecution").longValue(), 0L) + private val timestampFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") // ISO8601 + timestampFormat.setTimeZone(DateTimeUtils.getTimeZone("UTC")) + + registerGauge("eventTime-watermark", + progress => convertStringDateToMillis(progress.eventTime.get("watermark")), 0L) + + registerGauge("states-rowsTotal", _.stateOperators.map(_.numRowsTotal).sum, 0L) + registerGauge("states-usedBytes", _.stateOperators.map(_.memoryUsedBytes).sum, 0L) + + private def convertStringDateToMillis(isoUtcDateStr: String) = { + if (isoUtcDateStr != null) { + timestampFormat.parse(isoUtcDateStr).getTime + } else { + 0L + } + } + private def registerGauge[T]( name: String, f: StreamingQueryProgress => T, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index f37f3682b03b9..9cceec90e4d51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -467,6 +467,9 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(gauges.get("latency").getValue.asInstanceOf[Long] == 0) assert(gauges.get("processingRate-total").getValue.asInstanceOf[Double] == 0.0) assert(gauges.get("inputRate-total").getValue.asInstanceOf[Double] == 0.0) + assert(gauges.get("eventTime-watermark").getValue.asInstanceOf[Long] == 0) + assert(gauges.get("states-rowsTotal").getValue.asInstanceOf[Long] == 0) + assert(gauges.get("states-usedBytes").getValue.asInstanceOf[Long] == 0) sq.stop() } } From 18b6ec14716bfafc25ae281b190547ea58b59af1 Mon Sep 17 00:00:00 2001 From: Arun Mahadevan Date: Tue, 7 Aug 2018 10:28:26 +0800 Subject: [PATCH 1326/2461] [SPARK-24748][SS] Support for reporting custom metrics via StreamingQuery Progress ## What changes were proposed in this pull request? Currently the Structured Streaming sources and sinks does not have a way to report custom metrics. Providing an option to report custom metrics and making it available via Streaming Query progress can enable sources and sinks to report custom progress information (E.g. the lag metrics for Kafka source). Similar metrics can be reported for Sinks as well, but would like to get initial feedback before proceeding further. ## How was this patch tested? New and existing unit tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #21721 from arunmahadevan/SPARK-24748. Authored-by: Arun Mahadevan Signed-off-by: hyukjinkwon --- .../spark/sql/sources/v2/CustomMetrics.java | 33 ++++++++++ .../SupportsCustomReaderMetrics.java | 47 ++++++++++++++ .../SupportsCustomWriterMetrics.java | 47 ++++++++++++++ .../streaming/ProgressReporter.scala | 63 +++++++++++++++++-- .../streaming/sources/MicroBatchWriter.scala | 2 +- .../streaming/sources/memoryV2.scala | 32 ++++++++-- .../apache/spark/sql/streaming/progress.scala | 46 ++++++++++++-- .../streaming/MemorySinkV2Suite.scala | 22 +++++++ .../sql/streaming/StreamingQuerySuite.scala | 28 +++++++++ 9 files changed, 306 insertions(+), 14 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/CustomMetrics.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/CustomMetrics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/CustomMetrics.java new file mode 100644 index 0000000000000..7011a70e515e2 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/CustomMetrics.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * An interface for reporting custom metrics from streaming sources and sinks + */ +@InterfaceStability.Evolving +public interface CustomMetrics { + /** + * Returns a JSON serialized representation of custom metrics + * + * @return JSON serialized representation of custom metrics + */ + String json(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java new file mode 100644 index 0000000000000..3b293d925c91d --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.sources.v2.reader.streaming; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.CustomMetrics; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; + +/** + * A mix in interface for {@link DataSourceReader}. Data source readers can implement this + * interface to report custom metrics that gets reported under the + * {@link org.apache.spark.sql.streaming.SourceProgress} + * + */ +@InterfaceStability.Evolving +public interface SupportsCustomReaderMetrics extends DataSourceReader { + /** + * Returns custom metrics specific to this data source. + */ + CustomMetrics getCustomMetrics(); + + /** + * Invoked if the custom metrics returned by {@link #getCustomMetrics()} is invalid + * (e.g. Invalid data that cannot be parsed). Throwing an error here would ensure that + * your custom metrics work right and correct values are reported always. The default action + * on invalid metrics is to ignore it. + * + * @param ex the exception + */ + default void onInvalidMetrics(Exception ex) { + // default is to ignore invalid custom metrics + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java new file mode 100644 index 0000000000000..0cd36501320fd --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.sources.v2.writer.streaming; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.CustomMetrics; +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; + +/** + * A mix in interface for {@link DataSourceWriter}. Data source writers can implement this + * interface to report custom metrics that gets reported under the + * {@link org.apache.spark.sql.streaming.SinkProgress} + * + */ +@InterfaceStability.Evolving +public interface SupportsCustomWriterMetrics extends DataSourceWriter { + /** + * Returns custom metrics specific to this data source. + */ + CustomMetrics getCustomMetrics(); + + /** + * Invoked if the custom metrics returned by {@link #getCustomMetrics()} is invalid + * (e.g. Invalid data that cannot be parsed). Throwing an error here would ensure that + * your custom metrics work right and correct values are reported always. The default action + * on invalid metrics is to ignore it. + * + * @param ex the exception + */ + default void onInvalidMetrics(Exception ex) { + // default is to ignore invalid custom metrics + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 47f4b52e6e34c..1e158323d2020 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -22,14 +22,22 @@ import java.util.{Date, UUID} import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.util.control.NonFatal + +import org.json4s.JsonAST.JValue +import org.json4s.jackson.JsonMethods.parse import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec -import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} +import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter +import org.apache.spark.sql.sources.v2.CustomMetrics +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, SupportsCustomReaderMetrics} +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter +import org.apache.spark.sql.sources.v2.writer.streaming.SupportsCustomWriterMetrics import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent import org.apache.spark.util.Clock @@ -156,7 +164,31 @@ trait ProgressReporter extends Logging { } logDebug(s"Execution stats: $executionStats") + // extracts and validates custom metrics from readers and writers + def extractMetrics( + getMetrics: () => Option[CustomMetrics], + onInvalidMetrics: (Exception) => Unit): Option[String] = { + try { + getMetrics().map(m => { + val json = m.json() + parse(json) + json + }) + } catch { + case ex: Exception if NonFatal(ex) => + onInvalidMetrics(ex) + None + } + } + val sourceProgress = sources.distinct.map { source => + val customReaderMetrics = source match { + case s: SupportsCustomReaderMetrics => + extractMetrics(() => Option(s.getCustomMetrics), s.onInvalidMetrics) + + case _ => None + } + val numRecords = executionStats.inputRows.getOrElse(source, 0L) new SourceProgress( description = source.toString, @@ -164,10 +196,19 @@ trait ProgressReporter extends Logging { endOffset = currentTriggerEndOffsets.get(source).orNull, numInputRows = numRecords, inputRowsPerSecond = numRecords / inputTimeSec, - processedRowsPerSecond = numRecords / processingTimeSec + processedRowsPerSecond = numRecords / processingTimeSec, + customReaderMetrics.orNull ) } - val sinkProgress = new SinkProgress(sink.toString) + + val customWriterMetrics = dataSourceWriter match { + case Some(s: SupportsCustomWriterMetrics) => + extractMetrics(() => Option(s.getCustomMetrics), s.onInvalidMetrics) + + case _ => None + } + + val sinkProgress = new SinkProgress(sink.toString, customWriterMetrics.orNull) val newProgress = new StreamingQueryProgress( id = id, @@ -196,6 +237,18 @@ trait ProgressReporter extends Logging { currentStatus = currentStatus.copy(isTriggerActive = false) } + /** Extract writer from the executed query plan. */ + private def dataSourceWriter: Option[DataSourceWriter] = { + if (lastExecution == null) return None + lastExecution.executedPlan.collect { + case p if p.isInstanceOf[WriteToDataSourceV2Exec] => + p.asInstanceOf[WriteToDataSourceV2Exec].writer + }.headOption match { + case Some(w: MicroBatchWriter) => Some(w.writer) + case _ => None + } + } + /** Extract statistics about stateful operators from the executed query plan. */ private def extractStateOperatorMetrics(hasNewData: Boolean): Seq[StateOperatorProgress] = { if (lastExecution == null) return Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala index d023a35ea20b6..2d43a7bb77872 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter * the non-streaming interface, forwarding the batch ID determined at construction to a wrapped * streaming writer. */ -class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceWriter { +class MicroBatchWriter(batchId: Long, val writer: StreamWriter) extends DataSourceWriter { override def commit(messages: Array[WriterCommitMessage]): Unit = { writer.commit(batchId, messages) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index afacb2f72c926..2a5d21f330541 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -23,6 +23,9 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal +import org.json4s.NoTypeHints +import org.json4s.jackson.Serialization + import org.apache.spark.internal.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow @@ -32,9 +35,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.{CustomMetrics, DataSourceOptions, DataSourceV2, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamWriter, SupportsCustomWriterMetrics} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -114,14 +117,25 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB batches.clear() } + def numRows: Int = synchronized { + batches.foldLeft(0)(_ + _.data.length) + } + override def toString(): String = "MemorySinkV2" } case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} +class MemoryV2CustomMetrics(sink: MemorySinkV2) extends CustomMetrics { + private implicit val formats = Serialization.formats(NoTypeHints) + override def json(): String = Serialization.write(Map("numRows" -> sink.numRows)) +} + class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode, schema: StructType) - extends DataSourceWriter with Logging { + extends DataSourceWriter with SupportsCustomWriterMetrics with Logging { + + private val memoryV2CustomMetrics = new MemoryV2CustomMetrics(sink) override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema) @@ -135,10 +149,16 @@ class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode, sc override def abort(messages: Array[WriterCommitMessage]): Unit = { // Don't accept any of the new input. } + + override def getCustomMetrics: CustomMetrics = { + memoryV2CustomMetrics + } } class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) - extends StreamWriter { + extends StreamWriter with SupportsCustomWriterMetrics { + + private val customMemoryV2Metrics = new MemoryV2CustomMetrics(sink) override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema) @@ -152,6 +172,10 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode, schema: override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { // Don't accept any of the new input. } + + override def getCustomMetrics: CustomMetrics = { + customMemoryV2Metrics + } } case class MemoryWriterFactory(outputMode: OutputMode, schema: StructType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala index 0dcb666e2c3e4..2fb87960ccb04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala @@ -163,7 +163,27 @@ class SourceProgress protected[sql]( val endOffset: String, val numInputRows: Long, val inputRowsPerSecond: Double, - val processedRowsPerSecond: Double) extends Serializable { + val processedRowsPerSecond: Double, + val customMetrics: String) extends Serializable { + + /** SourceProgress without custom metrics. */ + protected[sql] def this( + description: String, + startOffset: String, + endOffset: String, + numInputRows: Long, + inputRowsPerSecond: Double, + processedRowsPerSecond: Double) { + + this( + description, + startOffset, + endOffset, + numInputRows, + inputRowsPerSecond, + processedRowsPerSecond, + null) + } /** The compact JSON representation of this progress. */ def json: String = compact(render(jsonValue)) @@ -178,12 +198,18 @@ class SourceProgress protected[sql]( if (value.isNaN || value.isInfinity) JNothing else JDouble(value) } - ("description" -> JString(description)) ~ + val jsonVal = ("description" -> JString(description)) ~ ("startOffset" -> tryParse(startOffset)) ~ ("endOffset" -> tryParse(endOffset)) ~ ("numInputRows" -> JInt(numInputRows)) ~ ("inputRowsPerSecond" -> safeDoubleToJValue(inputRowsPerSecond)) ~ ("processedRowsPerSecond" -> safeDoubleToJValue(processedRowsPerSecond)) + + if (customMetrics != null) { + jsonVal ~ ("customMetrics" -> parse(customMetrics)) + } else { + jsonVal + } } private def tryParse(json: String) = try { @@ -202,7 +228,13 @@ class SourceProgress protected[sql]( */ @InterfaceStability.Evolving class SinkProgress protected[sql]( - val description: String) extends Serializable { + val description: String, + val customMetrics: String) extends Serializable { + + /** SinkProgress without custom metrics. */ + protected[sql] def this(description: String) { + this(description, null) + } /** The compact JSON representation of this progress. */ def json: String = compact(render(jsonValue)) @@ -213,6 +245,12 @@ class SinkProgress protected[sql]( override def toString: String = prettyJson private[sql] def jsonValue: JValue = { - ("description" -> JString(description)) + val jsonVal = ("description" -> JString(description)) + + if (customMetrics != null) { + jsonVal ~ ("customMetrics" -> parse(customMetrics)) + } else { + jsonVal + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index b4d9b68c78152..1efaead0845db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -84,4 +84,26 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) } + + test("writer metrics") { + val sink = new MemorySinkV2 + val schema = new StructType().add("i", "int") + // batch 0 + var writer = new MemoryWriter(sink, 0, OutputMode.Append(), schema) + writer.commit( + Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), + MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), + MemoryWriterCommitMessage(2, Seq(Row(5), Row(6))) + )) + assert(writer.getCustomMetrics.json() == "{\"numRows\":6}") + // batch 1 + writer = new MemoryWriter(sink, 1, OutputMode.Append(), schema + ) + writer.commit( + Array( + MemoryWriterCommitMessage(0, Seq(Row(7), Row(8))) + )) + assert(writer.getCustomMetrics.json() == "{\"numRows\":8}") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 9cceec90e4d51..a379569a96d41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -24,6 +24,9 @@ import java.util.concurrent.CountDownLatch import scala.collection.mutable import org.apache.commons.lang3.RandomStringUtils +import org.json4s.NoTypeHints +import org.json4s.jackson.JsonMethods._ +import org.json4s.jackson.Serialization import org.scalactic.TolerantNumerics import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.PatienceConfiguration.Timeout @@ -475,6 +478,31 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } + test("Check if custom metrics are reported") { + val streamInput = MemoryStream[Int] + implicit val formats = Serialization.formats(NoTypeHints) + testStream(streamInput.toDF(), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + CheckAnswer(1, 2, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sink.customMetrics == "{\"numRows\":3}") + true + }, + AddData(streamInput, 4, 5, 6, 7), + CheckAnswer(1, 2, 3, 4, 5, 6, 7), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 4) + assert(lastProgress.get.sink.customMetrics == "{\"numRows\":7}") + true + } + ) + } + test("input row calculation with same V1 source used twice in self-join") { val streamingTriggerDF = spark.createDataset(1 to 10).toDF val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") From 51bee7aca13451167fa3e701fcd60f023eae5e61 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 7 Aug 2018 10:31:11 +0800 Subject: [PATCH 1327/2461] [SPARK-25018][INFRA] Use `Co-authored-by` and `Signed-off-by` git trailer in `merge_spark_pr.py` ## What changes were proposed in this pull request? In [Linux community](https://git.wiki.kernel.org/index.php/CommitMessageConventions), `Co-authored-by` and `Signed-off-by` git trailer have been used for awhile. Until recently, Github adopted `Co-authored-by` to include the work of co-authors in the profile contributions graph and the repository's statistics. It's a convention for recognizing multiple authors, and can encourage people to collaborate in OSS communities. Git provides a command line tools to read the metadata to know who commits the code to upstream, but it's not as easy as having `Signed-off-by` as part of the message so developers can find who is the relevant committers who can help with certain part of the codebase easier. For a single author PR, I purpose to use `Authored-by` and `Signed-off-by`, so the message will look like ``` Authored-by: Author's name Signed-off-by: Committer's name ``` For a multi-author PR, I purpose to use `Lead-authored-by:` and `Co-authored-by:` for the lead author and co-authors. The message will look like ``` Lead-authored-by: Lead Author's name Co-authored-by: CoAuthor's name Signed-off-by: Committer's name ``` It's also useful to include `Reviewed-by:` to give credits to the people who participate on the code reviewing. We can add this in the next iteration. Closes #21991 from dbtsai/script. Lead-authored-by: DB Tsai Co-authored-by: Liang-Chi Hsieh Co-authored-by: Brian Lindblom Co-authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- dev/merge_spark_pr.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index fd3eeb007a845..7a6f7d2b891d3 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -142,6 +142,11 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): distinct_authors[0]) if primary_author == "": primary_author = distinct_authors[0] + else: + # When primary author is specified manually, de-dup it from author list and + # put it at the head of author list. + distinct_authors = list(filter(lambda x: x != primary_author, distinct_authors)) + distinct_authors.insert(0, primary_author) commits = run_cmd(['git', 'log', 'HEAD..%s' % pr_branch_name, '--pretty=format:%h [%an] %s']).split("\n\n") @@ -154,13 +159,10 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): # to people every time someone creates a public fork of Spark. merge_message_flags += ["-m", body.replace("@", "")] - authors = "\n".join(["Author: %s" % a for a in distinct_authors]) - - merge_message_flags += ["-m", authors] + committer_name = run_cmd("git config --get user.name").strip() + committer_email = run_cmd("git config --get user.email").strip() if had_conflicts: - committer_name = run_cmd("git config --get user.name").strip() - committer_email = run_cmd("git config --get user.email").strip() message = "This patch had conflicts when merged, resolved by\nCommitter: %s <%s>" % ( committer_name, committer_email) merge_message_flags += ["-m", message] @@ -168,6 +170,14 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): # The string "Closes #%s" string is required for GitHub to correctly close the PR merge_message_flags += ["-m", "Closes #%s from %s." % (pr_num, pr_repo_desc)] + authors = "Authored-by:" if len(distinct_authors) == 1 else "Lead-authored-by:" + authors += " %s" % (distinct_authors.pop(0)) + if len(distinct_authors) > 0: + authors += "\n" + "\n".join(["Co-authored-by: %s" % a for a in distinct_authors]) + authors += "\n" + "Signed-off-by: %s <%s>" % (committer_name, committer_email) + + merge_message_flags += ["-m", authors] + run_cmd(['git', 'commit', '--author="%s"' % primary_author] + merge_message_flags) continue_maybe("Merge complete (local ref %s). Push to %s?" % ( From 4446a0b0d9bd830f0e903d6780dedac4db572b5a Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 7 Aug 2018 12:07:56 +0900 Subject: [PATCH 1328/2461] [SPARK-23914][SQL][FOLLOW-UP] refactor ArrayUnion ## What changes were proposed in this pull request? This PR refactors `ArrayUnion` based on [this suggestion](https://github.com/apache/spark/pull/21103#discussion_r205668821). 1. Generate optimized code for all of the primitive types except `boolean` 1. Generate code using `ArrayBuilder` or `ArrayBuffer` 1. Leave only a generic path in the interpreted path ## How was this patch tested? Existing tests Author: Kazuaki Ishizaki Closes #21937 from kiszk/SPARK-23914-follow. --- .../expressions/collectionOperations.scala | 325 +++++++----------- .../CollectionExpressionsSuite.scala | 21 +- .../spark/sql/DataFrameFunctionsSuite.scala | 24 +- 3 files changed, 153 insertions(+), 217 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e385c2d9782e8..fbb182631eefa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3767,230 +3767,159 @@ object ArraySetLike { """, since = "2.4.0") case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike - with ComplexTypeMergingExpression { - var hsInt: OpenHashSet[Int] = _ - var hsLong: OpenHashSet[Long] = _ - - def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { - val elem = array.getInt(idx) - if (!hsInt.contains(elem)) { - if (resultArray != null) { - resultArray.setInt(pos, elem) - } - hsInt.add(elem) - true - } else { - false - } - } - - def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { - val elem = array.getLong(idx) - if (!hsLong.contains(elem)) { - if (resultArray != null) { - resultArray.setLong(pos, elem) - } - hsLong.add(elem) - true - } else { - false - } - } + with ComplexTypeMergingExpression { - def evalIntLongPrimitiveType( - array1: ArrayData, - array2: ArrayData, - resultArray: ArrayData, - isLongType: Boolean): Int = { - // store elements into resultArray - var nullElementSize = 0 - var pos = 0 - Seq(array1, array2).foreach { array => - var i = 0 - while (i < array.numElements()) { - val size = if (!isLongType) hsInt.size else hsLong.size - if (size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - ArraySetLike.throwUnionLengthOverflowException(size) - } - if (array.isNullAt(i)) { - if (nullElementSize == 0) { - if (resultArray != null) { - resultArray.setNullAt(pos) + @transient lazy val evalUnion: (ArrayData, ArrayData) => ArrayData = { + if (elementTypeSupportEquals) { + (array1, array2) => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + val hs = new OpenHashSet[Any] + var foundNullElement = false + Seq(array1, array2).foreach { array => + var i = 0 + while (i < array.numElements()) { + if (array.isNullAt(i)) { + if (!foundNullElement) { + arrayBuffer += null + foundNullElement = true + } + } else { + val elem = array.get(i, elementType) + if (!hs.contains(elem)) { + if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size) + } + arrayBuffer += elem + hs.add(elem) + } } - pos += 1 - nullElementSize = 1 + i += 1 } - } else { - val assigned = if (!isLongType) { - assignInt(array, i, resultArray, pos) + } + new GenericArrayData(arrayBuffer) + } else { + (array1, array2) => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadyIncludeNull = false + Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => { + var found = false + if (elem == null) { + if (alreadyIncludeNull) { + found = true + } else { + alreadyIncludeNull = true + } } else { - assignLong(array, i, resultArray, pos) + // check elem is already stored in arrayBuffer or not? + var j = 0 + while (!found && j < arrayBuffer.size) { + val va = arrayBuffer(j) + if (va != null && ordering.equiv(va, elem)) { + found = true + } + j = j + 1 + } } - if (assigned) { - pos += 1 + if (!found) { + if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length) + } + arrayBuffer += elem } - } - i += 1 - } + })) + new GenericArrayData(arrayBuffer) } - pos } override def nullSafeEval(input1: Any, input2: Any): Any = { val array1 = input1.asInstanceOf[ArrayData] val array2 = input2.asInstanceOf[ArrayData] - if (elementTypeSupportEquals) { - elementType match { - case IntegerType => - // avoid boxing of primitive int array elements - // calculate result array size - hsInt = new OpenHashSet[Int] - val elements = evalIntLongPrimitiveType(array1, array2, null, false) - hsInt = new OpenHashSet[Int] - val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( - IntegerType.defaultSize, elements)) { - new GenericArrayData(new Array[Any](elements)) - } else { - UnsafeArrayData.forPrimitiveArray( - Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize) - } - evalIntLongPrimitiveType(array1, array2, resultArray, false) - resultArray - case LongType => - // avoid boxing of primitive long array elements - // calculate result array size - hsLong = new OpenHashSet[Long] - val elements = evalIntLongPrimitiveType(array1, array2, null, true) - hsLong = new OpenHashSet[Long] - val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( - LongType.defaultSize, elements)) { - new GenericArrayData(new Array[Any](elements)) - } else { - UnsafeArrayData.forPrimitiveArray( - Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize) - } - evalIntLongPrimitiveType(array1, array2, resultArray, true) - resultArray - case _ => - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - val hs = new OpenHashSet[Any] - var foundNullElement = false - Seq(array1, array2).foreach { array => - var i = 0 - while (i < array.numElements()) { - if (array.isNullAt(i)) { - if (!foundNullElement) { - arrayBuffer += null - foundNullElement = true - } - } else { - val elem = array.get(i, elementType) - if (!hs.contains(elem)) { - if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size) - } - arrayBuffer += elem - hs.add(elem) - } - } - i += 1 - } - } - new GenericArrayData(arrayBuffer) - } - } else { - ArrayUnion.unionOrdering(array1, array2, elementType, ordering) - } + evalUnion(array1, array2) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val i = ctx.freshName("i") - val pos = ctx.freshName("pos") val value = ctx.freshName("value") val size = ctx.freshName("size") - val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) = - if (elementTypeSupportEquals) { - elementType match { - case ByteType | ShortType | IntegerType | LongType => - val ptName = CodeGenerator.primitiveTypeName(elementType) - val unsafeArray = ctx.freshName("unsafeArray") - (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp", - if (elementType == LongType) "Long" else "Int", - s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), - if (elementType == LongType) "(long)" else "(int)", - s""" - |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} - |${ev.value} = $unsafeArray; - """.stripMargin) - case _ => - val genericArrayData = classOf[GenericArrayData].getName - val et = ctx.addReferenceObj("elementType", elementType) - ("", "Object", - s"get($i, $et)", s"update($pos, $value)", "Object", "", - s"${ev.value} = new $genericArrayData(new Object[$size]);") - } - } else { - ("", "", "", "", "", "", "") - } + if (canUseSpecializedHashSet) { + val jt = CodeGenerator.javaType(elementType) + val ptName = CodeGenerator.primitiveTypeName(jt) - nullSafeCodeGen(ctx, ev, (array1, array2) => { - if (openHashElementType != "") { - // Here, we ensure elementTypeSupportEquals is true + nullSafeCodeGen(ctx, ev, (array1, array2) => { val foundNullElement = ctx.freshName("foundNullElement") - val openHashSet = classOf[OpenHashSet[_]].getName - val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" - val hs = ctx.freshName("hs") - val arrayData = classOf[ArrayData].getName - val arrays = ctx.freshName("arrays") + val nullElementIndex = ctx.freshName("nullElementIndex") + val builder = ctx.freshName("builder") val array = ctx.freshName("array") + val arrays = ctx.freshName("arrays") val arrayDataIdx = ctx.freshName("arrayDataIdx") + val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" + val hashSet = ctx.freshName("hashSet") + val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName + val arrayBuilderClass = s"$arrayBuilder$$of$ptName" + + def withArrayNullAssignment(body: String) = + if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array.isNullAt($i)) { + | if (!$foundNullElement) { + | $nullElementIndex = $size; + | $foundNullElement = true; + | $size++; + | $builder.$$plus$$eq($nullValueHolder); + | } + |} else { + | $body + |} + """.stripMargin + } else { + body + } + + val processArray = withArrayNullAssignment( + s""" + |$jt $value = ${genGetValue(array, i)}; + |if (!$hashSet.contains($hsValueCast$value)) { + | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | break; + | } + | $hashSet.add$hsPostFix($hsValueCast$value); + | $builder.$$plus$$eq($value); + |} + """.stripMargin) + + // Only need to track null element index when result array's element is nullable. + val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |boolean $foundNullElement = false; + |int $nullElementIndex = -1; + """.stripMargin + } else { + "" + } + s""" - |$openHashSet $hs = new $openHashSet$postFix($classTag); - |boolean $foundNullElement = false; - |$arrayData[] $arrays = new $arrayData[]{$array1, $array2}; - |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) { - | $arrayData $array = $arrays[$arrayDataIdx]; - | for (int $i = 0; $i < $array.numElements(); $i++) { - | if ($array.isNullAt($i)) { - | $foundNullElement = true; - | } else { - | $hs.add$postFix($array.$getter); - | } - | } - |} - |int $size = $hs.size() + ($foundNullElement ? 1 : 0); - |$arrayBuilder - |$hs = new $openHashSet$postFix($classTag); - |$foundNullElement = false; - |int $pos = 0; + |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag); + |$declareNullTrackVariables + |int $size = 0; + |$arrayBuilderClass $builder = new $arrayBuilderClass(); + |ArrayData[] $arrays = new ArrayData[]{$array1, $array2}; |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) { - | $arrayData $array = $arrays[$arrayDataIdx]; + | ArrayData $array = $arrays[$arrayDataIdx]; | for (int $i = 0; $i < $array.numElements(); $i++) { - | if ($array.isNullAt($i)) { - | if (!$foundNullElement) { - | ${ev.value}.setNullAt($pos++); - | $foundNullElement = true; - | } - | } else { - | $javaTypeName $value = $array.$getter; - | if (!$hs.contains($castOp $value)) { - | $hs.add$postFix($value); - | ${ev.value}.$setter; - | $pos++; - | } - | } + | $processArray | } |} + |${buildResultArray(builder, ev.value, size, nullElementIndex)} """.stripMargin - } else { - val arrayUnion = classOf[ArrayUnion].getName - val et = ctx.addReferenceObj("elementTypeUnion", elementType) - val order = ctx.addReferenceObj("orderingUnion", ordering) - val method = "unionOrdering" - s"${ev.value} = $arrayUnion$$.MODULE$$.$method($array1, $array2, $et, $order);" - } - }) + }) + } else { + nullSafeCodeGen(ctx, ev, (array1, array2) => { + val expr = ctx.addReferenceObj("arrayUnionExpr", this) + s"${ev.value} = (ArrayData)$expr.nullSafeEval($array1, $array2);" + }) + } } override def prettyName: String = "array_union" @@ -4154,7 +4083,6 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val arrayData = classOf[ArrayData].getName val i = ctx.freshName("i") val value = ctx.freshName("value") val size = ctx.freshName("size") @@ -4268,7 +4196,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL } else { nullSafeCodeGen(ctx, ev, (array1, array2) => { val expr = ctx.addReferenceObj("arrayIntersectExpr", this) - s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);" + s"${ev.value} = (ArrayData)$expr.nullSafeEval($array1, $array2);" }) } } @@ -4387,7 +4315,6 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val arrayData = classOf[ArrayData].getName val i = ctx.freshName("i") val value = ctx.freshName("value") val size = ctx.freshName("size") @@ -4490,7 +4417,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike } else { nullSafeCodeGen(ctx, ev, (array1, array2) => { val expr = ctx.addReferenceObj("arrayExceptExpr", this) - s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);" + s"${ev.value} = (ArrayData)$expr.nullSafeEval($array1, $array2);" }) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 4daa113869b5d..c6b3f9502f2bf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1362,10 +1362,16 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a02 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType, containsNull = true)) val a03 = Literal.create(Seq(-5, 4, -3, 2, 4), ArrayType(IntegerType, containsNull = false)) val a04 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false)) - val a05 = Literal.create(Seq[Byte](1, 2, 3), ArrayType(ByteType, containsNull = false)) - val a06 = Literal.create(Seq[Byte](4, 2), ArrayType(ByteType, containsNull = false)) - val a07 = Literal.create(Seq[Short](1, 2, 3), ArrayType(ShortType, containsNull = false)) - val a08 = Literal.create(Seq[Short](4, 2), ArrayType(ShortType, containsNull = false)) + val abl0 = Literal.create(Seq[Boolean](true, true), ArrayType(BooleanType, false)) + val abl1 = Literal.create(Seq[Boolean](false, false), ArrayType(BooleanType, false)) + val ab0 = Literal.create(Seq[Byte](1, 2, 3, 2), ArrayType(ByteType, false)) + val ab1 = Literal.create(Seq[Byte](4, 2, 4), ArrayType(ByteType, false)) + val as0 = Literal.create(Seq[Short](1, 2, 3, 2), ArrayType(ShortType, false)) + val as1 = Literal.create(Seq[Short](4, 2, 4), ArrayType(ShortType, false)) + val af0 = Literal.create(Seq[Float](1.1F, 2.2F, 3.3F, 2.2F), ArrayType(FloatType, false)) + val af1 = Literal.create(Seq[Float](4.4F, 2.2F, 4.4F), ArrayType(FloatType, false)) + val ad0 = Literal.create(Seq[Double](1.1, 2.2, 3.3, 2.2), ArrayType(DoubleType, false)) + val ad1 = Literal.create(Seq[Double](4.4, 2.2, 4.4), ArrayType(DoubleType, false)) val a10 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType, containsNull = false)) val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, containsNull = false)) @@ -1384,8 +1390,11 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayUnion(a02, a03), Seq(1, 2, null, 4, 5, -5, -3)) checkEvaluation(ArrayUnion(a03, a02), Seq(-5, 4, -3, 2, 1, null, 5)) checkEvaluation(ArrayUnion(a02, a04), Seq(1, 2, null, 4, 5)) - checkEvaluation(ArrayUnion(a05, a06), Seq[Byte](1, 2, 3, 4)) - checkEvaluation(ArrayUnion(a07, a08), Seq[Short](1, 2, 3, 4)) + checkEvaluation(ArrayUnion(abl0, abl1), Seq[Boolean](true, false)) + checkEvaluation(ArrayUnion(ab0, ab1), Seq[Byte](1, 2, 3, 4)) + checkEvaluation(ArrayUnion(as0, as1), Seq[Short](1, 2, 3, 4)) + checkEvaluation(ArrayUnion(af0, af1), Seq[Float](1.1F, 2.2F, 3.3F, 4.4F)) + checkEvaluation(ArrayUnion(ad0, ad1), Seq[Double](1.1, 2.2, 3.3, 4.4)) checkEvaluation(ArrayUnion(a10, a11), Seq(1L, 2L, 3L, 4L)) checkEvaluation(ArrayUnion(a12, a13), Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 3c5831f33b23c..c04780db4e525 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1148,28 +1148,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df5.selectExpr("array_union(a, b)"), ans5) val df6 = Seq((null, Array("a"))).toDF("a", "b") - intercept[AnalysisException] { + assert(intercept[AnalysisException] { df6.select(array_union($"a", $"b")) - } - intercept[AnalysisException] { + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { df6.selectExpr("array_union(a, b)") - } + }.getMessage.contains("data type mismatch")) val df7 = Seq((null, null)).toDF("a", "b") - intercept[AnalysisException] { + assert(intercept[AnalysisException] { df7.select(array_union($"a", $"b")) - } - intercept[AnalysisException] { + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { df7.selectExpr("array_union(a, b)") - } + }.getMessage.contains("data type mismatch")) val df8 = Seq((Array(Array(1)), Array("a"))).toDF("a", "b") - intercept[AnalysisException] { + assert(intercept[AnalysisException] { df8.select(array_union($"a", $"b")) - } - intercept[AnalysisException] { + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { df8.selectExpr("array_union(a, b)") - } + }.getMessage.contains("data type mismatch")) } test("concat function - arrays") { From 43763629f1d1a220cd91e2aed89152d065dfba24 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 7 Aug 2018 14:28:14 +0800 Subject: [PATCH 1329/2461] [SPARK-25010][SQL] Rand/Randn should produce different values for each execution in streaming query ## What changes were proposed in this pull request? Like Uuid in SPARK-24896, Rand and Randn expressions now produce the same results for each execution in streaming query. It doesn't make too much sense for streaming queries. We should make them produce different results as Uuid. In this change, similar to Uuid, we assign new random seeds to Rand/Randn when returning optimized plan from `IncrementalExecution`. Note: Different to Uuid, Rand/Randn can be created with initial seed. Because we replace this initial seed at `IncrementalExecution`, it doesn't use the initial seed anymore. For now it seems to me not a big issue for streaming query. But need to confirm with others. cc zsxwing cloud-fan ## How was this patch tested? Added test. Closes #21980 from viirya/SPARK-25010. Authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/expressions/misc.scala | 5 ++++- .../expressions/randomExpressions.scala | 16 ++++++++++++-- .../streaming/IncrementalExecution.scala | 10 +++------ .../sql/streaming/StreamingQuerySuite.scala | 22 ++++++++++++++++++- 4 files changed, 42 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 5d98dac46cf17..0cdeda9b10516 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -126,10 +126,13 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { """, note = "The function is non-deterministic.") // scalastyle:on line.size.limit -case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Stateful { +case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Stateful + with ExpressionWithRandomSeed { def this() = this(None) + override def withNewSeed(seed: Long): Uuid = Uuid(Some(seed)) + override lazy val resolved: Boolean = randomSeed.isDefined override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 926c2f00d430d..b70c34141b97d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -57,6 +57,14 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType)) } +/** + * Represents the behavior of expressions which have a random seed and can renew the seed. + * Usually the random seed needs to be renewed at each execution under streaming queries. + */ +trait ExpressionWithRandomSeed { + def withNewSeed(seed: Long): Expression +} + /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */ // scalastyle:off line.size.limit @ExpressionDescription( @@ -72,10 +80,12 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful """, note = "The function is non-deterministic in general case.") // scalastyle:on line.size.limit -case class Rand(child: Expression) extends RDG { +case class Rand(child: Expression) extends RDG with ExpressionWithRandomSeed { def this() = this(Literal(Utils.random.nextLong(), LongType)) + override def withNewSeed(seed: Long): Rand = Rand(Literal(seed, LongType)) + override protected def evalInternal(input: InternalRow): Double = rng.nextDouble() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -110,10 +120,12 @@ object Rand { """, note = "The function is non-deterministic in general case.") // scalastyle:on line.size.limit -case class Randn(child: Expression) extends RDG { +case class Randn(child: Expression) extends RDG with ExpressionWithRandomSeed { def this() = this(Literal(Utils.random.nextLong(), LongType)) + override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType)) + override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index e9ffe129ca310..725abb318baa0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -20,11 +20,9 @@ package org.apache.spark.sql.execution.streaming import java.util.UUID import java.util.concurrent.atomic.AtomicInteger -import scala.util.Random - import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession, Strategy} -import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, Uuid} +import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, ExpressionWithRandomSeed} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, HashPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.rules.Rule @@ -32,6 +30,7 @@ import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.util.Utils /** * A variant of [[QueryExecution]] that allows the execution of the given [[LogicalPlan]] @@ -75,14 +74,11 @@ class IncrementalExecution( * with the desired literal */ override lazy val optimizedPlan: LogicalPlan = { - val random = new Random() - sparkSession.sessionState.optimizer.execute(withCachedData) transformAllExpressions { case ts @ CurrentBatchTimestamp(timestamp, _, _) => logInfo(s"Current batch timestamp = $timestamp") ts.toLiteral - // SPARK-24896: Set the seed for random number generation in Uuid expressions. - case _: Uuid => Uuid(Some(random.nextLong())) + case e: ExpressionWithRandomSeed => e.withNewSeed(Utils.random.nextLong()) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index a379569a96d41..848924dde296e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -36,7 +36,7 @@ import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Uuid +import org.apache.spark.sql.catalyst.expressions.{Rand, Randn, Uuid} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ @@ -885,6 +885,26 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(uuids.distinct.size == 2) } + test("Rand/Randn in streaming query should not produce same results in each execution") { + val rands = mutable.ArrayBuffer[Double]() + def collectRand: Seq[Row] => Unit = { rows: Seq[Row] => + rows.foreach { r => + rands += r.getDouble(0) + rands += r.getDouble(1) + } + } + + val stream = MemoryStream[Int] + val df = stream.toDF().select(new Column(new Rand()), new Column(new Randn())) + testStream(df)( + AddData(stream, 1), + CheckAnswer(collectRand), + AddData(stream, 2), + CheckAnswer(collectRand) + ) + assert(rands.distinct.size == 4) + } + test("StreamingRelationV2/StreamingExecutionRelation/ContinuousExecutionRelation.toJSON " + "should not fail") { val df = spark.readStream.format("rate").load() From 388f5a0635a2812cd71b08352e3ddc20293ec189 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 7 Aug 2018 15:06:32 +0800 Subject: [PATCH 1330/2461] [SPARK-24817][CORE] Implement BarrierTaskContext.barrier() ## What changes were proposed in this pull request? Implement BarrierTaskContext.barrier(), to support global sync between all the tasks in a barrier stage. The function set a global barrier and waits until all tasks in this stage hit this barrier. Similar to MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same stage have reached this routine. The global sync shall finish immediately once all tasks in the same barrier stage reaches the same barrier. This PR implements BarrierTaskContext.barrier() based on netty-based RPC client, introduces new `BarrierCoordinator` and new `BarrierCoordinatorMessage`, and new config to handle timeout issue. ## How was this patch tested? Add `BarrierTaskContextSuite` to test `BarrierTaskContext.barrier()` Closes #21898 from jiangxb1987/taskcontext.barrier. Authored-by: Xingbo Jiang Signed-off-by: Wenchen Fan --- .../org/apache/spark/BarrierCoordinator.scala | 235 ++++++++++++++++++ .../org/apache/spark/BarrierTaskContext.scala | 62 ++++- .../scala/org/apache/spark/SparkContext.scala | 12 +- .../spark/internal/config/package.scala | 10 + .../spark/scheduler/TaskSchedulerImpl.scala | 20 ++ .../scheduler/BarrierTaskContextSuite.scala | 150 +++++++++++ 6 files changed, 481 insertions(+), 8 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/BarrierCoordinator.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala new file mode 100644 index 0000000000000..5e546c694e8d9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.{Timer, TimerTask} +import java.util.concurrent.ConcurrentHashMap +import java.util.function.{Consumer, Function} + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted} + +/** + * For each barrier stage attempt, only at most one barrier() call can be active at any time, thus + * we can use (stageId, stageAttemptId) to identify the stage attempt where the barrier() call is + * from. + */ +private case class ContextBarrierId(stageId: Int, stageAttemptId: Int) { + override def toString: String = s"Stage $stageId (Attempt $stageAttemptId)" +} + +/** + * A coordinator that handles all global sync requests from BarrierTaskContext. Each global sync + * request is generated by `BarrierTaskContext.barrier()`, and identified by + * stageId + stageAttemptId + barrierEpoch. Reply all the blocking global sync requests upon + * all the requests for a group of `barrier()` calls are received. If the coordinator is unable to + * collect enough global sync requests within a configured time, fail all the requests and return + * an Exception with timeout message. + */ +private[spark] class BarrierCoordinator( + timeoutInSecs: Long, + listenerBus: LiveListenerBus, + override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + + // TODO SPARK-25030 Create a Timer() in the mainClass submitted to SparkSubmit makes it unable to + // fetch result, we shall fix the issue. + private lazy val timer = new Timer("BarrierCoordinator barrier epoch increment timer") + + // Listen to StageCompleted event, clear corresponding ContextBarrierState. + private val listener = new SparkListener { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + val stageInfo = stageCompleted.stageInfo + val barrierId = ContextBarrierId(stageInfo.stageId, stageInfo.attemptNumber) + // Clear ContextBarrierState from a finished stage attempt. + cleanupBarrierStage(barrierId) + } + } + + // Record all active stage attempts that make barrier() call(s), and the corresponding internal + // state. + private val states = new ConcurrentHashMap[ContextBarrierId, ContextBarrierState] + + override def onStart(): Unit = { + super.onStart() + listenerBus.addToStatusQueue(listener) + } + + override def onStop(): Unit = { + try { + states.forEachValue(1, clearStateConsumer) + states.clear() + listenerBus.removeListener(listener) + } finally { + super.onStop() + } + } + + /** + * Provide the current state of a barrier() call. A state is created when a new stage attempt + * sends out a barrier() call, and recycled on stage completed. + * + * @param barrierId Identifier of the barrier stage that make a barrier() call. + * @param numTasks Number of tasks of the barrier stage, all barrier() calls from the stage shall + * collect `numTasks` requests to succeed. + */ + private class ContextBarrierState( + val barrierId: ContextBarrierId, + val numTasks: Int) { + + // There may be multiple barrier() calls from a barrier stage attempt, `barrierEpoch` is used + // to identify each barrier() call. It shall get increased when a barrier() call succeeds, or + // reset when a barrier() call fails due to timeout. + private var barrierEpoch: Int = 0 + + // An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier() + // call. + private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks) + + // A timer task that ensures we may timeout for a barrier() call. + private var timerTask: TimerTask = null + + // Init a TimerTask for a barrier() call. + private def initTimerTask(): Unit = { + timerTask = new TimerTask { + override def run(): Unit = synchronized { + // Timeout current barrier() call, fail all the sync requests. + requesters.foreach(_.sendFailure(new SparkException("The coordinator didn't get all " + + s"barrier sync requests for barrier epoch $barrierEpoch from $barrierId within " + + s"$timeoutInSecs second(s)."))) + cleanupBarrierStage(barrierId) + } + } + } + + // Cancel the current active TimerTask and release resources. + private def cancelTimerTask(): Unit = { + if (timerTask != null) { + timerTask.cancel() + timerTask = null + } + } + + // Process the global sync request. The barrier() call succeed if collected enough requests + // within a configured time, otherwise fail all the pending requests. + def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized { + val taskId = request.taskAttemptId + val epoch = request.barrierEpoch + + // Require the number of tasks is correctly set from the BarrierTaskContext. + require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " + + s"${request.numTasks} from Task $taskId, previously it was $numTasks.") + + // Check whether the epoch from the barrier tasks matches current barrierEpoch. + logInfo(s"Current barrier epoch for $barrierId is $barrierEpoch.") + if (epoch != barrierEpoch) { + requester.sendFailure(new SparkException(s"The request to sync of $barrierId with " + + s"barrier epoch $barrierEpoch has already finished. Maybe task $taskId is not " + + "properly killed.")) + } else { + // If this is the first sync message received for a barrier() call, start timer to ensure + // we may timeout for the sync. + if (requesters.isEmpty) { + initTimerTask() + timer.schedule(timerTask, timeoutInSecs * 1000) + } + // Add the requester to array of RPCCallContexts pending for reply. + requesters += requester + logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " + + s"$taskId, current progress: ${requesters.size}/$numTasks.") + if (maybeFinishAllRequesters(requesters, numTasks)) { + // Finished current barrier() call successfully, clean up ContextBarrierState and + // increase the barrier epoch. + logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received all updates from " + + s"tasks, finished successfully.") + barrierEpoch += 1 + requesters.clear() + cancelTimerTask() + } + } + } + + // Finish all the blocking barrier sync requests from a stage attempt successfully if we + // have received all the sync requests. + private def maybeFinishAllRequesters( + requesters: ArrayBuffer[RpcCallContext], + numTasks: Int): Boolean = { + if (requesters.size == numTasks) { + requesters.foreach(_.reply(())) + true + } else { + false + } + } + + // Cleanup the internal state of a barrier stage attempt. + def clear(): Unit = synchronized { + // The global sync fails so the stage is expected to retry another attempt, all sync + // messages come from current stage attempt shall fail. + barrierEpoch = -1 + requesters.clear() + cancelTimerTask() + } + } + + // Clean up the [[ContextBarrierState]] that correspond to a specific stage attempt. + private def cleanupBarrierStage(barrierId: ContextBarrierId): Unit = { + val barrierState = states.remove(barrierId) + if (barrierState != null) { + barrierState.clear() + } + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) => + // Get or init the ContextBarrierState correspond to the stage attempt. + val barrierId = ContextBarrierId(stageId, stageAttemptId) + states.computeIfAbsent(barrierId, new Function[ContextBarrierId, ContextBarrierState] { + override def apply(key: ContextBarrierId): ContextBarrierState = + new ContextBarrierState(key, numTasks) + }) + val barrierState = states.get(barrierId) + + barrierState.handleRequest(context, request) + } + + private val clearStateConsumer = new Consumer[ContextBarrierState] { + override def accept(state: ContextBarrierState) = state.clear() + } +} + +private[spark] sealed trait BarrierCoordinatorMessage extends Serializable + +/** + * A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is + * identified by stageId + stageAttemptId + barrierEpoch. + * + * @param numTasks The number of global sync requests the BarrierCoordinator shall receive + * @param stageId ID of current stage + * @param stageAttemptId ID of current stage attempt + * @param taskAttemptId Unique ID of current task + * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls. + */ +private[spark] case class RequestToSync( + numTasks: Int, + stageId: Int, + stageAttemptId: Int, + taskAttemptId: Long, + barrierEpoch: Int) extends BarrierCoordinatorMessage diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index ba303680d1a0f..8e2b15599b674 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -17,12 +17,17 @@ package org.apache.spark -import java.util.Properties +import java.util.{Properties, Timer, TimerTask} + +import scala.concurrent.duration._ +import scala.language.postfixOps import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.rpc.{RpcEndpointRef, RpcTimeout} +import org.apache.spark.util.{RpcUtils, Utils} /** A [[TaskContext]] with extra info and tooling for a barrier stage. */ class BarrierTaskContext( @@ -39,6 +44,22 @@ class BarrierTaskContext( extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber, taskMemoryManager, localProperties, metricsSystem, taskMetrics) { + // Find the driver side RPCEndpointRef of the coordinator that handles all the barrier() calls. + private val barrierCoordinator: RpcEndpointRef = { + val env = SparkEnv.get + RpcUtils.makeDriverRef("barrierSync", env.conf, env.rpcEnv) + } + + private val timer = new Timer("Barrier task timer for barrier() calls.") + + // Local barrierEpoch that identify a barrier() call from current task, it shall be identical + // with the driver side epoch. + private var barrierEpoch = 0 + + // Number of tasks of the current barrier stage, a barrier() call must collect enough requests + // from different tasks within the same barrier stage attempt to succeed. + private lazy val numTasks = getTaskInfos().size + /** * :: Experimental :: * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to @@ -80,7 +101,44 @@ class BarrierTaskContext( @Experimental @Since("2.4.0") def barrier(): Unit = { - // TODO SPARK-24817 implement global barrier. + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " + + s"the global sync, current barrier epoch is $barrierEpoch.") + logTrace("Current callSite: " + Utils.getCallSite()) + + val startTime = System.currentTimeMillis() + val timerTask = new TimerTask { + override def run(): Unit = { + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) waiting " + + s"under the global sync since $startTime, has been waiting for " + + s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch " + + s"is $barrierEpoch.") + } + } + // Log the update of global sync every 60 seconds. + timer.schedule(timerTask, 60000, 60000) + + try { + barrierCoordinator.askSync[Unit]( + message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, + barrierEpoch), + // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by + // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework. + timeout = new RpcTimeout(31536000 /* = 3600 * 24 * 365 */ seconds, "barrierTimeout")) + barrierEpoch += 1 + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " + + "global sync successfully, waited for " + + s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch is " + + s"$barrierEpoch.") + } catch { + case e: SparkException => + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) failed " + + "to perform global sync, waited for " + + s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch " + + s"is $barrierEpoch.") + throw e + } finally { + timerTask.cancel() + } } /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e8bacee3b0215..a7ffb354c09ca 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1935,6 +1935,12 @@ class SparkContext(config: SparkConf) extends Logging { Utils.tryLogNonFatalError { _executorAllocationManager.foreach(_.stop()) } + if (_dagScheduler != null) { + Utils.tryLogNonFatalError { + _dagScheduler.stop() + } + _dagScheduler = null + } if (_listenerBusStarted) { Utils.tryLogNonFatalError { listenerBus.stop() @@ -1944,12 +1950,6 @@ class SparkContext(config: SparkConf) extends Logging { Utils.tryLogNonFatalError { _eventLogger.foreach(_.stop()) } - if (_dagScheduler != null) { - Utils.tryLogNonFatalError { - _dagScheduler.stop() - } - _dagScheduler = null - } if (env != null && _heartbeatReceiver != null) { Utils.tryLogNonFatalError { env.rpcEnv.stop(_heartbeatReceiver) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 8fef2aa6863c5..eb08628ce1112 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -567,4 +567,14 @@ package object config { .intConf .checkValue(v => v > 0, "The value should be a positive integer.") .createWithDefault(2000) + + private[spark] val BARRIER_SYNC_TIMEOUT = + ConfigBuilder("spark.barrier.sync.timeout") + .doc("The timeout in seconds for each barrier() call from a barrier task. If the " + + "coordinator didn't receive all the sync messages from barrier tasks within the " + + "configed time, throw a SparkException to fail all the tasks. The default value is set " + + "to 31536000(3600 * 24 * 365) so the barrier() call shall wait for one year.") + .timeConf(TimeUnit.SECONDS) + .checkValue(v => v > 0, "The value should be a positive time value.") + .createWithDefaultString("365d") } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 72691389d271c..8992d7e2284a4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -30,6 +30,7 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.internal.Logging import org.apache.spark.internal.config +import org.apache.spark.rpc.RpcEndpoint import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality import org.apache.spark.storage.BlockManagerId @@ -138,6 +139,19 @@ private[spark] class TaskSchedulerImpl( // This is a var so that we can reset it for testing purposes. private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this) + private lazy val barrierSyncTimeout = conf.get(config.BARRIER_SYNC_TIMEOUT) + + private[scheduler] var barrierCoordinator: RpcEndpoint = null + + private def maybeInitBarrierCoordinator(): Unit = { + if (barrierCoordinator == null) { + barrierCoordinator = new BarrierCoordinator(barrierSyncTimeout, sc.listenerBus, + sc.env.rpcEnv) + sc.env.rpcEnv.setupEndpoint("barrierSync", barrierCoordinator) + logInfo("Registered BarrierCoordinator endpoint") + } + } + override def setDAGScheduler(dagScheduler: DAGScheduler) { this.dagScheduler = dagScheduler } @@ -413,6 +427,9 @@ private[spark] class TaskSchedulerImpl( s"${taskSet.numTasks} tasks got resource offers. The resource offers may have " + "been blacklisted or cannot fulfill task locality requirements.") + // materialize the barrier coordinator. + maybeInitBarrierCoordinator() + // Update the taskInfos into all the barrier task properties. val addressesStr = addressesWithDescs // Addresses ordered by partitionId @@ -566,6 +583,9 @@ private[spark] class TaskSchedulerImpl( if (taskResultGetter != null) { taskResultGetter.stop() } + if (barrierCoordinator != null) { + barrierCoordinator.stop() + } starvationTimer.cancel() } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala new file mode 100644 index 0000000000000..5f96d6fb0cdb6 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import scala.util.Random + +import org.apache.spark._ + +class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { + + test("global sync by barrier() call") { + val conf = new SparkConf() + // Init local cluster here so each barrier task runs in a separated process, thus `barrier()` + // call is actually useful. + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + // Sleep for a random time before global sync. + Thread.sleep(Random.nextInt(1000)) + context.barrier() + Seq(System.currentTimeMillis()).iterator + } + + val times = rdd2.collect() + // All the tasks shall finish global sync within a short time slot. + assert(times.max - times.min <= 1000) + } + + test("support multiple barrier() call within a single task") { + val conf = new SparkConf() + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + // Sleep for a random time before global sync. + Thread.sleep(Random.nextInt(1000)) + context.barrier() + val time1 = System.currentTimeMillis() + // Sleep for a random time between two global syncs. + Thread.sleep(Random.nextInt(1000)) + context.barrier() + val time2 = System.currentTimeMillis() + Seq((time1, time2)).iterator + } + + val times = rdd2.collect() + // All the tasks shall finish the first round of global sync within a short time slot. + val times1 = times.map(_._1) + assert(times1.max - times1.min <= 1000) + + // All the tasks shall finish the second round of global sync within a short time slot. + val times2 = times.map(_._2) + assert(times2.max - times2.min <= 1000) + } + + test("throw exception on barrier() call timeout") { + val conf = new SparkConf() + .set("spark.barrier.sync.timeout", "1") + .set("spark.test.noStageRetry", "true") + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + // Task 3 shall sleep 2000ms to ensure barrier() call timeout + if (context.taskAttemptId == 3) { + Thread.sleep(2000) + } + context.barrier() + it + } + + val error = intercept[SparkException] { + rdd2.collect() + }.getMessage + assert(error.contains("The coordinator didn't get all barrier sync requests")) + assert(error.contains("within 1 second(s)")) + } + + test("throw exception if barrier() call doesn't happen on every task") { + val conf = new SparkConf() + .set("spark.barrier.sync.timeout", "1") + .set("spark.test.noStageRetry", "true") + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + if (context.taskAttemptId != 0) { + context.barrier() + } + it + } + + val error = intercept[SparkException] { + rdd2.collect() + }.getMessage + assert(error.contains("The coordinator didn't get all barrier sync requests")) + assert(error.contains("within 1 second(s)")) + } + + test("throw exception if the number of barrier() calls are not the same on every task") { + val conf = new SparkConf() + .set("spark.barrier.sync.timeout", "1") + .set("spark.test.noStageRetry", "true") + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + try { + if (context.taskAttemptId == 0) { + // Due to some non-obvious reason, the code can trigger an Exception and skip the + // following statements within the try ... catch block, including the first barrier() + // call. + throw new SparkException("test") + } + context.barrier() + } catch { + case e: Exception => // Do nothing + } + context.barrier() + it + } + + val error = intercept[SparkException] { + rdd2.collect() + }.getMessage + assert(error.contains("The coordinator didn't get all barrier sync requests")) + assert(error.contains("within 1 second(s)")) + } +} From 88e0c7bbd566240d182332299cf6695890a953ad Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 7 Aug 2018 15:43:41 +0800 Subject: [PATCH 1331/2461] [SPARK-24341][SQL] Support only IN subqueries with the same number of items per row ## What changes were proposed in this pull request? Using struct types in subqueries with the `IN` clause can generate invalid plans in `RewritePredicateSubquery`. Indeed, we are not handling clearly the cases when the outer value is a struct or the output of the inner subquery is a struct. The PR aims to make Spark's behavior the same as the one of the other RDBMS - namely Oracle and Postgres behavior were checked. So we consider valid only queries having the same number of fields in the outer value and in the subquery. This means that: - `(a, b) IN (select c, d from ...)` is a valid query; - `(a, b) IN (select (c, d) from ...)` throws an AnalysisException, as in the subquery we have only one field of type struct while in the outer value we have 2 fields; - `a IN (select (c, d) from ...)` - where `a` is a struct - is a valid query. ## How was this patch tested? Added UT Closes #21403 from mgaido91/SPARK-24313. Authored-by: Marco Gaido Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/Analyzer.scala | 19 +++- .../sql/catalyst/analysis/TypeCoercion.scala | 28 +---- .../spark/sql/catalyst/dsl/package.scala | 8 +- .../catalyst/expressions/Canonicalize.scala | 2 - .../sql/catalyst/expressions/predicates.scala | 105 +++++++++++------- .../sql/catalyst/expressions/subquery.scala | 4 +- .../sql/catalyst/optimizer/expressions.scala | 1 + .../sql/catalyst/optimizer/subquery.scala | 20 ++-- .../sql/catalyst/parser/AstBuilder.scala | 7 +- .../analysis/AnalysisErrorSuite.scala | 7 +- .../analysis/ResolveSubquerySuite.scala | 5 +- .../catalyst/optimizer/OptimizeInSuite.scala | 15 +++ .../PullupCorrelatedPredicatesSuite.scala | 4 +- .../parser/ExpressionParserSuite.scala | 14 ++- .../inputs/subquery/in-subquery/in-basic.sql | 14 +++ .../subquery/in-subquery/in-basic.sql.out | 70 ++++++++++++ .../subq-input-typecheck.sql.out | 20 ++-- 17 files changed, 240 insertions(+), 103 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-basic.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b5016fdb29d92..d391a93144c45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1433,11 +1433,26 @@ class Analyzer( resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, _, exprId) if !sub.resolved => resolveSubQuery(e, plans)(Exists(_, _, exprId)) - case In(value, Seq(l @ ListQuery(sub, _, exprId, _))) if value.resolved && !l.resolved => + case InSubquery(values, l @ ListQuery(_, _, exprId, _)) + if values.forall(_.resolved) && !l.resolved => val expr = resolveSubQuery(l, plans)((plan, exprs) => { ListQuery(plan, exprs, exprId, plan.output) }) - In(value, Seq(expr)) + val subqueryOutput = expr.plan.output + val resolvedIn = InSubquery(values, expr.asInstanceOf[ListQuery]) + if (values.length != subqueryOutput.length) { + throw new AnalysisException( + s"""Cannot analyze ${resolvedIn.sql}. + |The number of columns in the left hand side of an IN subquery does not match the + |number of columns in the output of subquery. + |#columns in left hand side: ${values.length} + |#columns in right hand side: ${subqueryOutput.length} + |Left side columns: + |[${values.map(_.sql).mkString(", ")}] + |Right side columns: + |[${subqueryOutput.map(_.sql).mkString(", ")}]""".stripMargin) + } + resolvedIn } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 7dd26b62b1fc4..648aa9ee8fa0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -449,15 +449,6 @@ object TypeCoercion { * Analysis Exception will be raised at the type checking phase. */ case class InConversion(conf: SQLConf) extends TypeCoercionRule { - private def flattenExpr(expr: Expression): Seq[Expression] = { - expr match { - // Multi columns in IN clause is represented as a CreateNamedStruct. - // flatten the named struct to get the list of expressions. - case cns: CreateNamedStruct => cns.valExprs - case expr => Seq(expr) - } - } - override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. @@ -465,11 +456,9 @@ object TypeCoercion { // Handle type casting required between value expression and subquery output // in IN subquery. - case i @ In(a, Seq(ListQuery(sub, children, exprId, _))) - if !i.resolved && flattenExpr(a).length == sub.output.length => - // LHS is the value expression of IN subquery. - val lhs = flattenExpr(a) - + case i @ InSubquery(lhs, ListQuery(sub, children, exprId, _)) + if !i.resolved && lhs.length == sub.output.length => + // LHS is the value expressions of IN subquery. // RHS is the subquery output. val rhs = sub.output @@ -485,20 +474,13 @@ object TypeCoercion { case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)() case (e, _) => e } - val castedLhs = lhs.zip(commonTypes).map { + val newLhs = lhs.zip(commonTypes).map { case (e, dt) if e.dataType != dt => Cast(e, dt) case (e, _) => e } - // Before constructing the In expression, wrap the multi values in LHS - // in a CreatedNamedStruct. - val newLhs = castedLhs match { - case Seq(lhs) => lhs - case _ => CreateStruct(castedLhs) - } - val newSub = Project(castedRhs, sub) - In(newLhs, Seq(ListQuery(newSub, children, exprId, newSub.output))) + InSubquery(newLhs, ListQuery(newSub, children, exprId, newSub.output)) } else { i } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 2b582b5be61a7..d3ccd18d0245e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -88,7 +88,13 @@ package object dsl { def <=> (other: Expression): Predicate = EqualNullSafe(expr, other) def =!= (other: Expression): Predicate = Not(EqualTo(expr, other)) - def in(list: Expression*): Expression = In(expr, list) + def in(list: Expression*): Expression = list match { + case Seq(l: ListQuery) => expr match { + case c: CreateNamedStruct => InSubquery(c.valExprs, l) + case other => InSubquery(Seq(other), l) + } + case _ => In(expr, list) + } def like(other: Expression): Expression = Like(expr, other) def rlike(other: Expression): Expression = RLike(expr, other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index 7541f527a52a8..fe6db8b344d3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -87,8 +87,6 @@ object Canonicalize { case Not(LessThanOrEqual(l, r)) => GreaterThan(l, r) // order the list in the In operator - // In subqueries contain only one element of type ListQuery. So checking that the length > 1 - // we are not reordering In subqueries. case In(value, list) if list.length > 1 => In(value, list.sortBy(_.hashCode())) case _ => e diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f4077f78006b1..149bd79278a54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -138,6 +138,66 @@ case class Not(child: Expression) override def sql: String = s"(NOT ${child.sql})" } +/** + * Evaluates to `true` if `values` are returned in `query`'s result set. + */ +case class InSubquery(values: Seq[Expression], query: ListQuery) + extends Predicate with Unevaluable { + + @transient lazy val value: Expression = if (values.length > 1) { + CreateNamedStruct(values.zipWithIndex.flatMap { + case (v: NamedExpression, _) => Seq(Literal(v.name), v) + case (v, idx) => Seq(Literal(s"_$idx"), v) + }) + } else { + values.head + } + + + override def checkInputDataTypes(): TypeCheckResult = { + val mismatchOpt = !DataType.equalsStructurally(query.dataType, value.dataType, + ignoreNullability = true) + if (mismatchOpt) { + if (values.length != query.childOutputs.length) { + TypeCheckResult.TypeCheckFailure( + s""" + |The number of columns in the left hand side of an IN subquery does not match the + |number of columns in the output of subquery. + |#columns in left hand side: ${values.length}. + |#columns in right hand side: ${query.childOutputs.length}. + |Left side columns: + |[${values.map(_.sql).mkString(", ")}]. + |Right side columns: + |[${query.childOutputs.map(_.sql).mkString(", ")}].""".stripMargin) + } else { + val mismatchedColumns = values.zip(query.childOutputs).flatMap { + case (l, r) if l.dataType != r.dataType => + Seq(s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})") + case _ => None + } + TypeCheckResult.TypeCheckFailure( + s""" + |The data type of one or more elements in the left hand side of an IN subquery + |is not compatible with the data type of the output of the subquery + |Mismatched columns: + |[${mismatchedColumns.mkString(", ")}] + |Left side: + |[${values.map(_.dataType.catalogString).mkString(", ")}]. + |Right side: + |[${query.childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin) + } + } else { + TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") + } + } + + override def children: Seq[Expression] = values :+ query + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + override def toString: String = s"$value IN ($query)" + override def sql: String = s"(${value.sql} IN (${query.sql}))" +} + /** * Evaluates to `true` if `list` contains `value`. @@ -169,44 +229,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType, ignoreNullability = true)) if (mismatchOpt.isDefined) { - list match { - case ListQuery(_, _, _, childOutputs) :: Nil => - val valExprs = value match { - case cns: CreateNamedStruct => cns.valExprs - case expr => Seq(expr) - } - if (valExprs.length != childOutputs.length) { - TypeCheckResult.TypeCheckFailure( - s""" - |The number of columns in the left hand side of an IN subquery does not match the - |number of columns in the output of subquery. - |#columns in left hand side: ${valExprs.length}. - |#columns in right hand side: ${childOutputs.length}. - |Left side columns: - |[${valExprs.map(_.sql).mkString(", ")}]. - |Right side columns: - |[${childOutputs.map(_.sql).mkString(", ")}].""".stripMargin) - } else { - val mismatchedColumns = valExprs.zip(childOutputs).flatMap { - case (l, r) if l.dataType != r.dataType => - Seq(s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})") - case _ => None - } - TypeCheckResult.TypeCheckFailure( - s""" - |The data type of one or more elements in the left hand side of an IN subquery - |is not compatible with the data type of the output of the subquery - |Mismatched columns: - |[${mismatchedColumns.mkString(", ")}] - |Left side: - |[${valExprs.map(_.dataType.catalogString).mkString(", ")}]. - |Right side: - |[${childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin) - } - case _ => - TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " + - s"${value.dataType.catalogString} != ${mismatchOpt.get.dataType.catalogString}") - } + TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " + + s"${value.dataType.catalogString} != ${mismatchOpt.get.dataType.catalogString}") } else { TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") } @@ -307,9 +331,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } override def sql: String = { - val childrenSQL = children.map(_.sql) - val valueSQL = childrenSQL.head - val listSQL = childrenSQL.tail.mkString(", ") + val valueSQL = value.sql + val listSQL = list.map(_.sql).mkString(", ") s"($valueSQL IN ($listSQL))" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 6acc87a3e7367..fc1caed84e272 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -117,10 +117,10 @@ object SubExprUtils extends PredicateHelper { def hasNullAwarePredicateWithinNot(condition: Expression): Boolean = { splitConjunctivePredicates(condition).exists { case _: Exists | Not(_: Exists) => false - case In(_, Seq(_: ListQuery)) | Not(In(_, Seq(_: ListQuery))) => false + case _: InSubquery | Not(_: InSubquery) => false case e => e.find { x => x.isInstanceOf[Not] && e.find { - case In(_, Seq(_: ListQuery)) => true + case _: InSubquery => true case _ => false }.isDefined }.isDefined diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index e7b4730e11115..5629b72894225 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -523,6 +523,7 @@ object NullPropagation extends Rule[LogicalPlan] { // If the value expression is NULL then transform the In expression to null literal. case In(Literal(null, _), _) => Literal.create(null, BooleanType) + case InSubquery(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType) // Non-leaf NullIntolerant expressions will return null, if at least one of its children is // a null literal. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index de89e17e51f1b..e9b7a8b76e683 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -42,13 +43,6 @@ import org.apache.spark.sql.types._ * condition. */ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { - private def getValueExpression(e: Expression): Seq[Expression] = { - e match { - case cns : CreateNamedStruct => cns.valExprs - case expr => Seq(expr) - } - } - private def dedupJoin(joinPlan: LogicalPlan): LogicalPlan = joinPlan match { // SPARK-21835: It is possibly that the two sides of the join have conflicting attributes, // the produced join then becomes unresolved and break structural integrity. We should @@ -97,19 +91,19 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) // Deduplicate conflicting attributes if any. dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond)) - case (p, In(value, Seq(ListQuery(sub, conditions, _, _)))) => - val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) => + val inConditions = values.zip(sub.output).map(EqualTo.tupled) val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) // Deduplicate conflicting attributes if any. dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond)) - case (p, Not(In(value, Seq(ListQuery(sub, conditions, _, _))))) => + case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. A NULL in one of the conditions is regarded as a positive // result; such a row will be filtered out by the Anti-Join operator. // Note that will almost certainly be planned as a Broadcast Nested Loop join. // Use EXISTS if performance matters to you. - val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val inConditions = values.zip(sub.output).map(EqualTo.tupled) val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p) // Expand the NOT IN expression with the NULL-aware semantic // to its full form. That is from: @@ -150,9 +144,9 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { newPlan = dedupJoin( Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))) exists - case In(value, Seq(ListQuery(sub, conditions, _, _))) => + case InSubquery(values, ListQuery(sub, conditions, _, _)) => val exists = AttributeReference("exists", BooleanType, nullable = false)() - val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val inConditions = values.zip(sub.output).map(EqualTo.tupled) val newConditions = (inConditions ++ conditions).reduceLeftOption(And) // Deduplicate conflicting attributes if any. newPlan = dedupJoin(Join(newPlan, sub, ExistenceJoin(exists), newConditions)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 732d762335f1e..7bc1f63e30540 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1103,6 +1103,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case not => Not(e) } + def getValueExpressions(e: Expression): Seq[Expression] = e match { + case c: CreateNamedStruct => c.valExprs + case other => Seq(other) + } + // Create the predicate. ctx.kind.getType match { case SqlBaseParser.BETWEEN => @@ -1111,7 +1116,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging GreaterThanOrEqual(e, expression(ctx.lower)), LessThanOrEqual(e, expression(ctx.upper)))) case SqlBaseParser.IN if ctx.query != null => - invertIfNotDefined(In(e, Seq(ListQuery(plan(ctx.query))))) + invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query)))) case SqlBaseParser.IN => invertIfNotDefined(In(e, ctx.expression.asScala.map(expression))) case SqlBaseParser.LIKE => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 0a5194a287ecc..94778840d706b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -534,7 +534,7 @@ class AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val plan = Project( - Seq(a, Alias(In(a, Seq(ListQuery(LocalRelation(b)))), "c")()), + Seq(a, Alias(InSubquery(Seq(a), ListQuery(LocalRelation(b))), "c")()), LocalRelation(a)) assertAnalysisError(plan, "Predicate sub-queries can only be used in a Filter" :: Nil) } @@ -543,12 +543,13 @@ class AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val c = AttributeReference("c", BooleanType)() - val plan1 = Filter(Cast(Not(In(a, Seq(ListQuery(LocalRelation(b))))), BooleanType), + val plan1 = Filter(Cast(Not(InSubquery(Seq(a), ListQuery(LocalRelation(b)))), BooleanType), LocalRelation(a)) assertAnalysisError(plan1, "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) - val plan2 = Filter(Or(Not(In(a, Seq(ListQuery(LocalRelation(b))))), c), LocalRelation(a, c)) + val plan2 = Filter( + Or(Not(InSubquery(Seq(a), ListQuery(LocalRelation(b)))), c), LocalRelation(a, c)) assertAnalysisError(plan2, "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 1bf8d76da04d8..74a8590b5eefe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{In, ListQuery, OuterReference} +import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery} import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, Project} /** @@ -33,7 +33,8 @@ class ResolveSubquerySuite extends AnalysisTest { val t2 = LocalRelation(b) test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") { - val expr = Filter(In(a, Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1) + val expr = Filter( + InSubquery(Seq(a), ListQuery(Project(Seq(UnresolvedAttribute("a")), t2))), t1) val m = intercept[AnalysisException] { SimpleAnalyzer.checkAnalysis(SimpleAnalyzer.ResolveSubquery(expr)) }.getMessage diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 86522a6a54ed5..a36083b847043 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -121,6 +121,21 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("OptimizedIn test: NULL IN (subquery) gets transformed to Filter(null)") { + val subquery = ListQuery(testRelation.select(UnresolvedAttribute("a"))) + val originalQuery = + testRelation + .where(InSubquery(Seq(Literal.create(null, NullType)), subquery)) + .analyze + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .where(Literal.create(null, BooleanType)) + .analyze + comparePlans(optimized, correctAnswer) + } + test("OptimizedIn test: Inset optimization disabled as " + "list expression contains attribute)") { val originalQuery = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala index 169b8737d808b..8a5a55146726e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{In, ListQuery} +import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -42,7 +42,7 @@ class PullupCorrelatedPredicatesSuite extends PlanTest { .select('c) val outerQuery = testRelation - .where(In('a, Seq(ListQuery(correlatedSubquery)))) + .where(InSubquery(Seq('a), ListQuery(correlatedSubquery))) .select('a).analyze assert(outerQuery.resolved) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index c37b9f148cf48..781fc1e957ae0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -154,7 +154,19 @@ class ExpressionParserSuite extends PlanTest { test("in sub-query") { assertEqual( "a in (select b from c)", - In('a, Seq(ListQuery(table("c").select('b))))) + InSubquery(Seq('a), ListQuery(table("c").select('b)))) + + assertEqual( + "(a, b, c) in (select d, e, f from g)", + InSubquery(Seq('a, 'b, 'c), ListQuery(table("g").select('d, 'e, 'f)))) + + assertEqual( + "(a, b) in (select c from d)", + InSubquery(Seq('a, 'b), ListQuery(table("d").select('c)))) + + assertEqual( + "(a) in (select b from c)", + InSubquery(Seq('a), ListQuery(table("c").select('b)))) } test("like expressions") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-basic.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-basic.sql new file mode 100644 index 0000000000000..f4ffc20086386 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-basic.sql @@ -0,0 +1,14 @@ +create temporary view tab_a as select * from values (1, 1) as tab_a(a1, b1); +create temporary view tab_b as select * from values (1, 1) as tab_b(a2, b2); +create temporary view struct_tab as select struct(col1 as a, col2 as b) as record from + values (1, 1), (1, 2), (2, 1), (2, 2); + +select 1 from tab_a where (a1, b1) not in (select a2, b2 from tab_b); +-- Invalid query, see SPARK-24341 +select 1 from tab_a where (a1, b1) not in (select (a2, b2) from tab_b); + +-- Aliasing is needed as a workaround for SPARK-24443 +select count(*) from struct_tab where record in + (select (a2 as a, b2 as b) from tab_b); +select count(*) from struct_tab where record not in + (select (a2 as a, b2 as b) from tab_b); diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out new file mode 100644 index 0000000000000..088db55d66406 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out @@ -0,0 +1,70 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 7 + + +-- !query 0 +create temporary view tab_a as select * from values (1, 1) as tab_a(a1, b1) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +create temporary view tab_b as select * from values (1, 1) as tab_b(a2, b2) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +create temporary view struct_tab as select struct(col1 as a, col2 as b) as record from + values (1, 1), (1, 2), (2, 1), (2, 2) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +select 1 from tab_a where (a1, b1) not in (select a2, b2 from tab_b) +-- !query 3 schema +struct<1:int> +-- !query 3 output + + + +-- !query 4 +select 1 from tab_a where (a1, b1) not in (select (a2, b2) from tab_b) +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +Cannot analyze (named_struct('a1', tab_a.`a1`, 'b1', tab_a.`b1`) IN (listquery())). +The number of columns in the left hand side of an IN subquery does not match the +number of columns in the output of subquery. +#columns in left hand side: 2 +#columns in right hand side: 1 +Left side columns: +[tab_a.`a1`, tab_a.`b1`] +Right side columns: +[`named_struct(a2, a2, b2, b2)`]; + + +-- !query 5 +select count(*) from struct_tab where record in + (select (a2 as a, b2 as b) from tab_b) +-- !query 5 schema +struct +-- !query 5 output +1 + + +-- !query 6 +select count(*) from struct_tab where record not in + (select (a2 as a, b2 as b) from tab_b) +-- !query 6 schema +struct +-- !query 6 output +3 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out index dcd30055bca19..c52e5706deeee 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out @@ -92,15 +92,15 @@ t1a IN (SELECT t2a, t2b struct<> -- !query 7 output org.apache.spark.sql.AnalysisException -cannot resolve '(t1.`t1a` IN (listquery(t1.`t1a`)))' due to data type mismatch: +Cannot analyze (t1.`t1a` IN (listquery(t1.`t1a`))). The number of columns in the left hand side of an IN subquery does not match the number of columns in the output of subquery. -#columns in left hand side: 1. -#columns in right hand side: 2. +#columns in left hand side: 1 +#columns in right hand side: 2 Left side columns: -[t1.`t1a`]. +[t1.`t1a`] Right side columns: -[t2.`t2a`, t2.`t2b`].; +[t2.`t2a`, t2.`t2b`]; -- !query 8 @@ -113,15 +113,15 @@ WHERE struct<> -- !query 8 output org.apache.spark.sql.AnalysisException -cannot resolve '(named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`)))' due to data type mismatch: +Cannot analyze (named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`))). The number of columns in the left hand side of an IN subquery does not match the number of columns in the output of subquery. -#columns in left hand side: 2. -#columns in right hand side: 1. +#columns in left hand side: 2 +#columns in right hand side: 1 Left side columns: -[t1.`t1a`, t1.`t1b`]. +[t1.`t1a`, t1.`t1b`] Right side columns: -[t2.`t2a`].; +[t2.`t2a`]; -- !query 9 From 131ca146ed390cd0109cd6e8c95b61e418507080 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 7 Aug 2018 17:14:30 +0800 Subject: [PATCH 1332/2461] =?UTF-8?q?[SPARK-24005][CORE]=20Remove=20usage?= =?UTF-8?q?=20of=20Scala=E2=80=99s=20parallel=20collection?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In the PR, I propose to replace Scala parallel collections by new methods `parmap()`. The methods use futures to transform a sequential collection by applying a lambda function to each element in parallel. The result of `parmap` is another regular (sequential) collection. The proposed `parmap` method aims to solve the problem of impossibility to interrupt parallel Scala collection. This possibility is needed for reliable task preemption. ## How was this patch tested? A test was added to `ThreadUtilsSuite` Closes #21913 from MaxGekk/par-map. Authored-by: Maxim Gekk Signed-off-by: Wenchen Fan --- .../scala/org/apache/spark/rdd/UnionRDD.scala | 17 +++-- .../org/apache/spark/util/ThreadUtils.scala | 64 ++++++++++++++++++- .../apache/spark/util/ThreadUtilsSuite.scala | 33 ++++++++++ .../spark/sql/execution/command/ddl.scala | 34 +++++----- .../parquet/ParquetFileFormat.scala | 40 +++++------- .../parquet/ParquetFileFormatSuite.scala | 4 +- .../util/FileBasedWriteAheadLog.scala | 6 +- 7 files changed, 142 insertions(+), 56 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 60e383afadf1c..4b6f73235a57a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -20,12 +20,13 @@ package org.apache.spark.rdd import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer -import scala.collection.parallel.ForkJoinTaskSupport +import scala.concurrent.ExecutionContext import scala.concurrent.forkjoin.ForkJoinPool import scala.reflect.ClassTag import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.util.ThreadUtils.parmap import org.apache.spark.util.Utils /** @@ -59,8 +60,7 @@ private[spark] class UnionPartition[T: ClassTag]( } object UnionRDD { - private[spark] lazy val partitionEvalTaskSupport = - new ForkJoinTaskSupport(new ForkJoinPool(8)) + private[spark] lazy val threadPool = new ForkJoinPool(8) } @DeveloperApi @@ -74,14 +74,13 @@ class UnionRDD[T: ClassTag]( rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10) override def getPartitions: Array[Partition] = { - val parRDDs = if (isPartitionListingParallel) { - val parArray = rdds.par - parArray.tasksupport = UnionRDD.partitionEvalTaskSupport - parArray + val partitionLengths = if (isPartitionListingParallel) { + implicit val ec = ExecutionContext.fromExecutor(UnionRDD.threadPool) + parmap(rdds)(_.partitions.length) } else { - rdds + rdds.map(_.partitions.length) } - val array = new Array[Partition](parRDDs.map(_.partitions.length).seq.sum) + val array = new Array[Partition](partitionLengths.sum) var pos = 0 for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) { array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index) diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 0f08a2b0ad895..f0e5addbe5b56 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -19,8 +19,12 @@ package org.apache.spark.util import java.util.concurrent._ +import scala.collection.TraversableLike +import scala.collection.generic.CanBuildFrom +import scala.language.higherKinds + import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} -import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor} +import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor, Future} import scala.concurrent.duration.{Duration, FiniteDuration} import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread} import scala.util.control.NonFatal @@ -254,4 +258,62 @@ private[spark] object ThreadUtils { executor.shutdownNow() } } + + /** + * Transforms input collection by applying the given function to each element in parallel fashion. + * Comparing to the map() method of Scala parallel collections, this method can be interrupted + * at any time. This is useful on canceling of task execution, for example. + * + * @param in - the input collection which should be transformed in parallel. + * @param prefix - the prefix assigned to the underlying thread pool. + * @param maxThreads - maximum number of thread can be created during execution. + * @param f - the lambda function will be applied to each element of `in`. + * @tparam I - the type of elements in the input collection. + * @tparam O - the type of elements in resulted collection. + * @return new collection in which each element was given from the input collection `in` by + * applying the lambda function `f`. + */ + def parmap[I, O, Col[X] <: TraversableLike[X, Col[X]]] + (in: Col[I], prefix: String, maxThreads: Int) + (f: I => O) + (implicit + cbf: CanBuildFrom[Col[I], Future[O], Col[Future[O]]], // For in.map + cbf2: CanBuildFrom[Col[Future[O]], O, Col[O]] // for Future.sequence + ): Col[O] = { + val pool = newForkJoinPool(prefix, maxThreads) + try { + implicit val ec = ExecutionContext.fromExecutor(pool) + + parmap(in)(f) + } finally { + pool.shutdownNow() + } + } + + /** + * Transforms input collection by applying the given function to each element in parallel fashion. + * Comparing to the map() method of Scala parallel collections, this method can be interrupted + * at any time. This is useful on canceling of task execution, for example. + * + * @param in - the input collection which should be transformed in parallel. + * @param f - the lambda function will be applied to each element of `in`. + * @param ec - an execution context for parallel applying of the given function `f`. + * @tparam I - the type of elements in the input collection. + * @tparam O - the type of elements in resulted collection. + * @return new collection in which each element was given from the input collection `in` by + * applying the lambda function `f`. + */ + def parmap[I, O, Col[X] <: TraversableLike[X, Col[X]]] + (in: Col[I]) + (f: I => O) + (implicit + cbf: CanBuildFrom[Col[I], Future[O], Col[Future[O]]], // For in.map + cbf2: CanBuildFrom[Col[Future[O]], O, Col[O]], // for Future.sequence + ec: ExecutionContext + ): Col[O] = { + val futures = in.map(x => Future(f(x))) + val futureSeq = Future.sequence(futures) + + awaitResult(futureSeq, Duration.Inf) + } } diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala index ae3b3d829f1bb..604f1e1ca3101 100644 --- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -133,4 +133,37 @@ class ThreadUtilsSuite extends SparkFunSuite { "stack trace contains unexpected references to ThreadUtils" ) } + + test("parmap should be interruptible") { + val t = new Thread() { + setDaemon(true) + + override def run() { + try { + // "par" is uninterruptible. The following will keep running even if the thread is + // interrupted. We should prefer to use "ThreadUtils.parmap". + // + // (1 to 10).par.flatMap { i => + // Thread.sleep(100000) + // 1 to i + // } + // + ThreadUtils.parmap(1 to 10, "test", 2) { i => + Thread.sleep(100000) + 1 to i + }.flatten + } catch { + case _: InterruptedException => // excepted + } + } + } + t.start() + eventually(timeout(10.seconds)) { + assert(t.isAlive) + } + t.interrupt() + eventually(timeout(10.seconds)) { + assert(!t.isAlive) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index e1faecedd20ed..7a6f5741862ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command import java.util.Locale import scala.collection.{GenMap, GenSeq} -import scala.collection.parallel.ForkJoinTaskSupport +import scala.concurrent.ExecutionContext import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration @@ -29,7 +29,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Resolver} +import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} @@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.types._ import org.apache.spark.util.{SerializableConfiguration, ThreadUtils} +import org.apache.spark.util.ThreadUtils.parmap // Note: The definition of these commands are based on the ones described in // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL @@ -621,8 +622,9 @@ case class AlterTableRecoverPartitionsCommand( val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8) val partitionSpecsAndLocs: Seq[(TablePartitionSpec, Path)] = try { + implicit val ec = ExecutionContext.fromExecutor(evalPool) scanPartitions(spark, fs, pathFilter, root, Map(), table.partitionColumnNames, threshold, - spark.sessionState.conf.resolver, new ForkJoinTaskSupport(evalPool)).seq + spark.sessionState.conf.resolver) } finally { evalPool.shutdown() } @@ -654,23 +656,13 @@ case class AlterTableRecoverPartitionsCommand( spec: TablePartitionSpec, partitionNames: Seq[String], threshold: Int, - resolver: Resolver, - evalTaskSupport: ForkJoinTaskSupport): GenSeq[(TablePartitionSpec, Path)] = { + resolver: Resolver)(implicit ec: ExecutionContext): Seq[(TablePartitionSpec, Path)] = { if (partitionNames.isEmpty) { return Seq(spec -> path) } - val statuses = fs.listStatus(path, filter) - val statusPar: GenSeq[FileStatus] = - if (partitionNames.length > 1 && statuses.length > threshold || partitionNames.length > 2) { - // parallelize the list of partitions here, then we can have better parallelism later. - val parArray = statuses.par - parArray.tasksupport = evalTaskSupport - parArray - } else { - statuses - } - statusPar.flatMap { st => + val statuses = fs.listStatus(path, filter).toSeq + def handleStatus(st: FileStatus): Seq[(TablePartitionSpec, Path)] = { val name = st.getPath.getName if (st.isDirectory && name.contains("=")) { val ps = name.split("=", 2) @@ -679,7 +671,7 @@ case class AlterTableRecoverPartitionsCommand( val value = ExternalCatalogUtils.unescapePathName(ps(1)) if (resolver(columnName, partitionNames.head)) { scanPartitions(spark, fs, filter, st.getPath, spec ++ Map(partitionNames.head -> value), - partitionNames.drop(1), threshold, resolver, evalTaskSupport) + partitionNames.drop(1), threshold, resolver) } else { logWarning( s"expected partition column ${partitionNames.head}, but got ${ps(0)}, ignoring it") @@ -690,6 +682,14 @@ case class AlterTableRecoverPartitionsCommand( Seq.empty } } + val result = if (partitionNames.length > 1 && + statuses.length > threshold || partitionNames.length > 2) { + parmap(statuses)(handleStatus _) + } else { + statuses.map(handleStatus) + } + + result.flatten } private def gatherPartitionStats( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 283d7761d22d7..b2409f3470e73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -22,7 +22,6 @@ import java.net.URI import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.parallel.ForkJoinTaskSupport import scala.util.{Failure, Try} import org.apache.hadoop.conf.Configuration @@ -532,30 +531,23 @@ object ParquetFileFormat extends Logging { conf: Configuration, partFiles: Seq[FileStatus], ignoreCorruptFiles: Boolean): Seq[Footer] = { - val parFiles = partFiles.par - val pool = ThreadUtils.newForkJoinPool("readingParquetFooters", 8) - parFiles.tasksupport = new ForkJoinTaskSupport(pool) - try { - parFiles.flatMap { currentFile => - try { - // Skips row group information since we only need the schema. - // ParquetFileReader.readFooter throws RuntimeException, instead of IOException, - // when it can't read the footer. - Some(new Footer(currentFile.getPath(), - ParquetFileReader.readFooter( - conf, currentFile, SKIP_ROW_GROUPS))) - } catch { case e: RuntimeException => - if (ignoreCorruptFiles) { - logWarning(s"Skipped the footer in the corrupted file: $currentFile", e) - None - } else { - throw new IOException(s"Could not read footer for file: $currentFile", e) - } + ThreadUtils.parmap(partFiles, "readingParquetFooters", 8) { currentFile => + try { + // Skips row group information since we only need the schema. + // ParquetFileReader.readFooter throws RuntimeException, instead of IOException, + // when it can't read the footer. + Some(new Footer(currentFile.getPath(), + ParquetFileReader.readFooter( + conf, currentFile, SKIP_ROW_GROUPS))) + } catch { case e: RuntimeException => + if (ignoreCorruptFiles) { + logWarning(s"Skipped the footer in the corrupted file: $currentFile", e) + None + } else { + throw new IOException(s"Could not read footer for file: $currentFile", e) } - }.seq - } finally { - pool.shutdown() - } + } + }.flatten } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala index 3a0867fd2b78b..94abf115cef35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala @@ -51,9 +51,9 @@ class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLCo } testReadFooters(true) - val exception = intercept[java.io.IOException] { + val exception = intercept[SparkException] { testReadFooters(false) - } + }.getCause assert(exception.getMessage().contains("Could not read footer for file")) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 2e8599026ea1d..bba071e80c0e4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -312,10 +312,10 @@ private[streaming] object FileBasedWriteAheadLog { handler: I => Iterator[O]): Iterator[O] = { val taskSupport = new ExecutionContextTaskSupport(executionContext) val groupSize = taskSupport.parallelismLevel.max(8) + implicit val ec = executionContext + source.grouped(groupSize).flatMap { group => - val parallelCollection = group.par - parallelCollection.tasksupport = taskSupport - parallelCollection.map(handler) + ThreadUtils.parmap(group)(handler) }.flatten } } From 819c4de45af2fe39bac8363241d0001b2e83f858 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 7 Aug 2018 17:24:25 +0800 Subject: [PATCH 1333/2461] [SPARK-24772][SQL] Avro: support logical date type ## What changes were proposed in this pull request? Support Avro logical date type: https://avro.apache.org/docs/1.8.2/spec.html#Date ## How was this patch tested? Unit test Closes #21984 from gengliangwang/avro_date. Authored-by: Gengliang Wang Signed-off-by: Wenchen Fan --- .../spark/sql/avro/AvroDeserializer.scala | 5 ++ .../spark/sql/avro/AvroSerializer.scala | 2 +- .../spark/sql/avro/SchemaConverters.scala | 12 +++- external/avro/src/test/resources/date.avro | Bin 0 -> 209 bytes .../org/apache/spark/sql/avro/AvroSuite.scala | 54 +++++++++++++++++- 5 files changed, 66 insertions(+), 7 deletions(-) create mode 100644 external/avro/src/test/resources/date.avro diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index 394a62bf82795..74677a29afcb4 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -84,6 +84,9 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) { case (INT, IntegerType) => (updater, ordinal, value) => updater.setInt(ordinal, value.asInstanceOf[Int]) + case (INT, DateType) => (updater, ordinal, value) => + updater.setInt(ordinal, value.asInstanceOf[Int]) + case (LONG, LongType) => (updater, ordinal, value) => updater.setLong(ordinal, value.asInstanceOf[Long]) @@ -100,6 +103,8 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) { s"Cannot convert Avro logical type ${other} to Catalyst Timestamp type.") } + // Before we upgrade Avro to 1.8 for logical type support, spark-avro converts Long to Date. + // For backward compatibility, we still keep this conversion. case (LONG, DateType) => (updater, ordinal, value) => updater.setInt(ordinal, (value.asInstanceOf[Long] / DateTimeUtils.MILLIS_PER_DAY).toInt) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala index 382f9a750c16c..988582698d826 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -92,7 +92,7 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: case BinaryType => (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal)) case DateType => - (getter, ordinal) => getter.getInt(ordinal) * DateTimeUtils.MILLIS_PER_DAY + (getter, ordinal) => getter.getInt(ordinal) case TimestampType => avroType.getLogicalType match { case _: TimestampMillis => (getter, ordinal) => getter.getLong(ordinal) / 1000 case _: TimestampMicros => (getter, ordinal) => getter.getLong(ordinal) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala index 69295398775e6..245e68d242c50 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.avro import scala.collection.JavaConverters._ import org.apache.avro.{LogicalType, LogicalTypes, Schema, SchemaBuilder} -import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis} +import org.apache.avro.LogicalTypes.{Date, TimestampMicros, TimestampMillis} import org.apache.avro.Schema.Type._ import org.apache.spark.sql.internal.SQLConf.AvroOutputTimestampType @@ -38,7 +38,10 @@ object SchemaConverters { */ def toSqlType(avroSchema: Schema): SchemaType = { avroSchema.getType match { - case INT => SchemaType(IntegerType, nullable = false) + case INT => avroSchema.getLogicalType match { + case _: Date => SchemaType(DateType, nullable = false) + case _ => SchemaType(IntegerType, nullable = false) + } case STRING => SchemaType(StringType, nullable = false) case BOOLEAN => SchemaType(BooleanType, nullable = false) case BYTES => SchemaType(BinaryType, nullable = false) @@ -121,7 +124,10 @@ object SchemaConverters { case BooleanType => builder.booleanType() case ByteType | ShortType | IntegerType => builder.intType() case LongType => builder.longType() - case DateType => builder.longType() + case DateType => builder + .intBuilder() + .prop(LogicalType.LOGICAL_TYPE_PROP, LogicalTypes.date().getName) + .endInt() case TimestampType => val timestampType = outputTimestampType match { case AvroOutputTimestampType.TIMESTAMP_MILLIS => LogicalTypes.timestampMillis() diff --git a/external/avro/src/test/resources/date.avro b/external/avro/src/test/resources/date.avro new file mode 100644 index 0000000000000000000000000000000000000000..3a6761704cb169b69e0d8dcf61447e51d90a04c2 GIT binary patch literal 209 zcmeZI%3@>@ODrqO*DFrWNX<>0z*MbNQdy9yWTl`~l$xAhl%k}gpp=)Gn_66um<$%q z$xqKrPRxOcgH)EJ7MFndX_=`xDaAmMXt*iWN>KG7P*YP9OHx5Cx;;@x1E2_6G_LL&BYT}`Ey=nVdJ~SiY^EMS{_DB literal 0 HcmV?d00001 diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index b4dcf6c6c9330..47995bb39a05e 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -33,6 +33,7 @@ import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} import org.apache.commons.io.FileUtils import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} @@ -67,6 +68,27 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { // writer.close() val timestampAvro = testFile("timestamp.avro") + // The test file date.avro is generated via following Python code: + // import json + // import avro.schema + // from avro.datafile import DataFileWriter + // from avro.io import DatumWriter + // + // write_schema = avro.schema.parse(json.dumps({ + // "namespace": "logical", + // "type": "record", + // "name": "test", + // "fields": [ + // {"name": "date", "type": {"type": "int", "logicalType": "date"}} + // ] + // })) + // + // writer = DataFileWriter(open("date.avro", "wb"), DatumWriter(), write_schema) + // writer.append({"date": 7}) + // writer.append({"date": 365}) + // writer.close() + val dateAvro = testFile("date.avro") + override protected def beforeAll(): Unit = { super.beforeAll() spark.conf.set("spark.sql.files.maxPartitionBytes", 1024) @@ -350,9 +372,35 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val df = spark.createDataFrame(rdd, schema) df.write.format("avro").save(dir.toString) assert(spark.read.format("avro").load(dir.toString).count == rdd.count) - assert( - spark.read.format("avro").load(dir.toString).select("date").collect().map(_(0)).toSet == - Array(null, 1451865600000L, 1459987200000L).toSet) + checkAnswer( + spark.read.format("avro").load(dir.toString).select("date"), + Seq(Row(null), Row(new Date(1451865600000L)), Row(new Date(1459987200000L)))) + } + } + + test("Logical type: date") { + val expected = Seq(7, 365).map(t => Row(DateTimeUtils.toJavaDate(t))) + val df = spark.read.format("avro").load(dateAvro) + + checkAnswer(df, expected) + + val avroSchema = s""" + { + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [ + {"name": "date", "type": {"type": "int", "logicalType": "date"}} + ] + } + """ + + checkAnswer(spark.read.format("avro").option("avroSchema", avroSchema).load(dateAvro), + expected) + + withTempPath { dir => + df.write.format("avro").save(dir.toString) + checkAnswer(spark.read.format("avro").load(dir.toString), expected) } } From b4bf8be549afa51e931e48dd79ddd9480f567b13 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Tue, 7 Aug 2018 21:11:08 +0800 Subject: [PATCH 1334/2461] [SPARK-19602][SQL] Support column resolution of fully qualified column name ( 3 part name) ## What changes were proposed in this pull request? The design details is attached to the JIRA issue [here](https://drive.google.com/file/d/1zKm3aNZ3DpsqIuoMvRsf0kkDkXsAasxH/view) High level overview of the changes are: - Enhance the qualifier to be more than one string - Add support to store the qualifier. Enhance the lookupRelation to keep the qualifier appropriately. - Enhance the table matching column resolution algorithm to account for qualifier being more than a string. - Enhance the table matching algorithm in UnresolvedStar.expand - Ensure that we continue to support select t1.i1 from db1.t1 ## How was this patch tested? - New tests are added. - Several test scenarios were added in a separate [test pr 17067](https://github.com/apache/spark/pull/17067). The tests that were not supported earlier are marked with TODO markers and those are now supported with the code changes here. - Existing unit tests ( hive, catalyst and sql) were run successfully. Closes #17185 from skambha/colResolution. Authored-by: Sunitha Kambhampati Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/unresolved.scala | 59 ++++++++++---- .../sql/catalyst/catalog/SessionCatalog.scala | 7 +- .../expressions/higherOrderFunctions.scala | 4 +- .../expressions/namedExpressions.scala | 35 ++++---- .../sql/catalyst/expressions/package.scala | 67 ++++++++++++---- .../spark/sql/catalyst/identifiers.scala | 16 ++++ .../plans/logical/basicLogicalOperators.scala | 29 +++++-- .../catalog/SessionCatalogSuite.scala | 6 +- .../expressions/ExpressionSetSuite.scala | 4 +- .../SubexpressionEliminationSuite.scala | 2 +- .../inputs/columnresolution-views.sql | 2 - .../sql-tests/inputs/columnresolution.sql | 12 +-- .../results/columnresolution-negative.sql.out | 26 +++--- .../results/columnresolution-views.sql.out | 10 +-- .../results/columnresolution.sql.out | 80 +++++++++---------- .../results/string-functions.sql.out | 4 +- .../invalid-correlation.sql.out | 4 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 3 +- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../benchmark/TPCDSQueryBenchmark.scala | 2 +- .../sql/hive/HiveMetastoreCatalogSuite.scala | 4 +- .../sql/hive/execution/SQLQuerySuite.scala | 16 ++++ 22 files changed, 249 insertions(+), 145 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 71e23175168e2..c1ec736c32ed4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -104,12 +104,12 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un override def exprId: ExprId = throw new UnresolvedException(this, "exprId") override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier") + override def qualifier: Seq[String] = throw new UnresolvedException(this, "qualifier") override lazy val resolved = false override def newInstance(): UnresolvedAttribute = this override def withNullability(newNullability: Boolean): UnresolvedAttribute = this - override def withQualifier(newQualifier: Option[String]): UnresolvedAttribute = this + override def withQualifier(newQualifier: Seq[String]): UnresolvedAttribute = this override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) override def withMetadata(newMetadata: Metadata): Attribute = this @@ -240,7 +240,7 @@ abstract class Star extends LeafExpression with NamedExpression { override def exprId: ExprId = throw new UnresolvedException(this, "exprId") override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier") + override def qualifier: Seq[String] = throw new UnresolvedException(this, "qualifier") override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance") override lazy val resolved = false @@ -262,17 +262,46 @@ abstract class Star extends LeafExpression with NamedExpression { */ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevaluable { - override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = { + /** + * Returns true if the nameParts match the qualifier of the attribute + * + * There are two checks: i) Check if the nameParts match the qualifier fully. + * E.g. SELECT db.t1.* FROM db1.t1 In this case, the nameParts is Seq("db1", "t1") and + * qualifier of the attribute is Seq("db1","t1") + * ii) If (i) is not true, then check if nameParts is only a single element and it + * matches the table portion of the qualifier + * + * E.g. SELECT t1.* FROM db1.t1 In this case nameParts is Seq("t1") and + * qualifier is Seq("db1","t1") + * SELECT a.* FROM db1.t1 AS a + * In this case nameParts is Seq("a") and qualifier for + * attribute is Seq("a") + */ + private def matchedQualifier( + attribute: Attribute, + nameParts: Seq[String], + resolver: Resolver): Boolean = { + val qualifierList = attribute.qualifier + + val matched = nameParts.corresponds(qualifierList)(resolver) || { + // check if it matches the table portion of the qualifier + if (nameParts.length == 1 && qualifierList.nonEmpty) { + resolver(nameParts.head, qualifierList.last) + } else { + false + } + } + matched + } + + override def expand( + input: LogicalPlan, + resolver: Resolver): Seq[NamedExpression] = { // If there is no table specified, use all input attributes. if (target.isEmpty) return input.output - val expandedAttributes = - if (target.get.size == 1) { - // If there is a table, pick out attributes that are part of this table. - input.output.filter(_.qualifier.exists(resolver(_, target.get.head))) - } else { - List() - } + val expandedAttributes = input.output.filter(matchedQualifier(_, target.get, resolver)) + if (expandedAttributes.nonEmpty) return expandedAttributes // Try to resolve it as a struct expansion. If there is a conflict and both are possible, @@ -316,8 +345,8 @@ case class UnresolvedRegex(regexPattern: String, table: Option[String], caseSens // If there is no table specified, use all input attributes that match expr case None => input.output.filter(_.name.matches(pattern)) // If there is a table, pick out attributes that are part of this table that match expr - case Some(t) => input.output.filter(_.qualifier.exists(resolver(_, t))) - .filter(_.name.matches(pattern)) + case Some(t) => input.output.filter(a => a.qualifier.nonEmpty && + resolver(a.qualifier.last, t)).filter(_.name.matches(pattern)) } } @@ -345,7 +374,7 @@ case class MultiAlias(child: Expression, names: Seq[String]) override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier") + override def qualifier: Seq[String] = throw new UnresolvedException(this, "qualifier") override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") @@ -403,7 +432,7 @@ case class UnresolvedAlias( extends UnaryExpression with NamedExpression with Unevaluable { override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") - override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier") + override def qualifier: Seq[String] = throw new UnresolvedException(this, "qualifier") override def exprId: ExprId = throw new UnresolvedException(this, "exprId") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override def dataType: DataType = throw new UnresolvedException(this, "dataType") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 2f60eb30f7240..cd243b87652f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -684,6 +684,7 @@ class SessionCatalog( * * If the relation is a view, we generate a [[View]] operator from the view description, and * wrap the logical plan in a [[SubqueryAlias]] which will track the name of the view. + * [[SubqueryAlias]] will also keep track of the name and database(optional) of the table/view * * @param name The name of the table/view that we look up. */ @@ -693,7 +694,7 @@ class SessionCatalog( val table = formatTableName(name.table) if (db == globalTempViewManager.database) { globalTempViewManager.get(table).map { viewDef => - SubqueryAlias(table, viewDef) + SubqueryAlias(table, db, viewDef) }.getOrElse(throw new NoSuchTableException(db, table)) } else if (name.database.isDefined || !tempViews.contains(table)) { val metadata = externalCatalog.getTable(db, table) @@ -706,9 +707,9 @@ class SessionCatalog( desc = metadata, output = metadata.schema.toAttributes, child = parser.parsePlan(viewText)) - SubqueryAlias(table, child) + SubqueryAlias(table, db, child) } else { - SubqueryAlias(table, UnresolvedCatalogRelation(metadata)) + SubqueryAlias(table, db, UnresolvedCatalogRelation(metadata)) } } else { SubqueryAlias(table, tempViews(table)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 20c7f7d43b9dc..5d4665917009b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -41,13 +41,13 @@ case class NamedLambdaVariable( with NamedExpression with CodegenFallback { - override def qualifier: Option[String] = None + override def qualifier: Seq[String] = Seq.empty override def newInstance(): NamedExpression = copy(value = new AtomicReference(), exprId = NamedExpression.newExprId) override def toAttribute: Attribute = { - AttributeReference(name, dataType, nullable, Metadata.empty)(exprId, None) + AttributeReference(name, dataType, nullable, Metadata.empty)(exprId, Seq.empty) } override def eval(input: InternalRow): Any = value.get diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index ce5c2804d08ee..584a2946bd564 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -71,19 +71,22 @@ trait NamedExpression extends Expression { * multiple qualifiers, it is possible that there are other possible way to refer to this * attribute. */ - def qualifiedName: String = (qualifier.toSeq :+ name).mkString(".") + def qualifiedName: String = (qualifier :+ name).mkString(".") /** * Optional qualifier for the expression. + * Qualifier can also contain the fully qualified information, for e.g, Sequence of string + * containing the database and the table name * * For now, since we do not allow using original table name to qualify a column name once the * table is aliased, this can only be: * * 1. Empty Seq: when an attribute doesn't have a qualifier, * e.g. top level attributes aliased in the SELECT clause, or column from a LocalRelation. - * 2. Single element: either the table name or the alias name of the table. + * 2. Seq with a Single element: either the table name or the alias name of the table. + * 3. Seq with 2 elements: database name and table name */ - def qualifier: Option[String] + def qualifier: Seq[String] def toAttribute: Attribute @@ -109,7 +112,7 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn override def references: AttributeSet = AttributeSet(this) def withNullability(newNullability: Boolean): Attribute - def withQualifier(newQualifier: Option[String]): Attribute + def withQualifier(newQualifier: Seq[String]): Attribute def withName(newName: String): Attribute def withMetadata(newMetadata: Metadata): Attribute @@ -130,14 +133,14 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn * @param name The name to be associated with the result of computing [[child]]. * @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this * alias. Auto-assigned if left blank. - * @param qualifier An optional string that can be used to referred to this attribute in a fully - * qualified way. Consider the examples tableName.name, subQueryAlias.name. - * tableName and subQueryAlias are possible qualifiers. + * @param qualifier An optional Seq of string that can be used to refer to this attribute in a + * fully qualified way. Consider the examples tableName.name, subQueryAlias.name. + * tableName and subQueryAlias are possible qualifiers. * @param explicitMetadata Explicit metadata associated with this alias that overwrites child's. */ case class Alias(child: Expression, name: String)( val exprId: ExprId = NamedExpression.newExprId, - val qualifier: Option[String] = None, + val qualifier: Seq[String] = Seq.empty, val explicitMetadata: Option[Metadata] = None) extends UnaryExpression with NamedExpression { @@ -201,7 +204,7 @@ case class Alias(child: Expression, name: String)( } override def sql: String = { - val qualifierPrefix = qualifier.map(_ + ".").getOrElse("") + val qualifierPrefix = if (qualifier.nonEmpty) qualifier.mkString(".") + "." else "" s"${child.sql} AS $qualifierPrefix${quoteIdentifier(name)}" } } @@ -225,9 +228,11 @@ case class AttributeReference( nullable: Boolean = true, override val metadata: Metadata = Metadata.empty)( val exprId: ExprId = NamedExpression.newExprId, - val qualifier: Option[String] = None) + val qualifier: Seq[String] = Seq.empty[String]) extends Attribute with Unevaluable { + // currently can only handle qualifier of length 2 + require(qualifier.length <= 2) /** * Returns true iff the expression id is the same for both attributes. */ @@ -286,7 +291,7 @@ case class AttributeReference( /** * Returns a copy of this [[AttributeReference]] with new qualifier. */ - override def withQualifier(newQualifier: Option[String]): AttributeReference = { + override def withQualifier(newQualifier: Seq[String]): AttributeReference = { if (newQualifier == qualifier) { this } else { @@ -324,7 +329,7 @@ case class AttributeReference( override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}" override def sql: String = { - val qualifierPrefix = qualifier.map(_ + ".").getOrElse("") + val qualifierPrefix = if (qualifier.nonEmpty) qualifier.mkString(".") + "." else "" s"$qualifierPrefix${quoteIdentifier(name)}" } } @@ -350,12 +355,12 @@ case class PrettyAttribute( override def withNullability(newNullability: Boolean): Attribute = throw new UnsupportedOperationException override def newInstance(): Attribute = throw new UnsupportedOperationException - override def withQualifier(newQualifier: Option[String]): Attribute = + override def withQualifier(newQualifier: Seq[String]): Attribute = throw new UnsupportedOperationException override def withName(newName: String): Attribute = throw new UnsupportedOperationException override def withMetadata(newMetadata: Metadata): Attribute = throw new UnsupportedOperationException - override def qualifier: Option[String] = throw new UnsupportedOperationException + override def qualifier: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException override def nullable: Boolean = true } @@ -371,7 +376,7 @@ case class OuterReference(e: NamedExpression) override def prettyName: String = "outer" override def name: String = e.name - override def qualifier: Option[String] = e.qualifier + override def qualifier: Seq[String] = e.qualifier override def exprId: ExprId = e.exprId override def toAttribute: Attribute = e.toAttribute override def newInstance(): NamedExpression = OuterReference(e.newInstance()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 8a06daa37132d..11dcc3ebf798c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -152,10 +152,22 @@ package object expressions { unique(attrs.groupBy(_.name.toLowerCase(Locale.ROOT))) } - /** Map to use for qualified case insensitive attribute lookups. */ - @transient private val qualified: Map[(String, String), Seq[Attribute]] = { - val grouped = attrs.filter(_.qualifier.isDefined).groupBy { a => - (a.qualifier.get.toLowerCase(Locale.ROOT), a.name.toLowerCase(Locale.ROOT)) + /** Map to use for qualified case insensitive attribute lookups with 2 part key */ + @transient private lazy val qualified: Map[(String, String), Seq[Attribute]] = { + // key is 2 part: table/alias and name + val grouped = attrs.filter(_.qualifier.nonEmpty).groupBy { + a => (a.qualifier.last.toLowerCase(Locale.ROOT), a.name.toLowerCase(Locale.ROOT)) + } + unique(grouped) + } + + /** Map to use for qualified case insensitive attribute lookups with 3 part key */ + @transient private val qualified3Part: Map[(String, String, String), Seq[Attribute]] = { + // key is 3 part: database name, table name and name + val grouped = attrs.filter(_.qualifier.length == 2).groupBy { a => + (a.qualifier.head.toLowerCase(Locale.ROOT), + a.qualifier.last.toLowerCase(Locale.ROOT), + a.name.toLowerCase(Locale.ROOT)) } unique(grouped) } @@ -169,25 +181,48 @@ package object expressions { }) } - // Find matches for the given name assuming that the 1st part is a qualifier (i.e. table name, - // alias, or subquery alias) and the 2nd part is the actual name. This returns a tuple of + // Find matches for the given name assuming that the 1st two parts are qualifier + // (i.e. database name and table name) and the 3rd part is the actual column name. + // + // For example, consider an example where "db1" is the database name, "a" is the table name + // and "b" is the column name and "c" is the struct field name. + // If the name parts is db1.a.b.c, then Attribute will match + // Attribute(b, qualifier("db1,"a")) and List("c") will be the second element + var matches: (Seq[Attribute], Seq[String]) = nameParts match { + case dbPart +: tblPart +: name +: nestedFields => + val key = (dbPart.toLowerCase(Locale.ROOT), + tblPart.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT)) + val attributes = collectMatches(name, qualified3Part.get(key)).filter { + a => (resolver(dbPart, a.qualifier.head) && resolver(tblPart, a.qualifier.last)) + } + (attributes, nestedFields) + case all => + (Seq.empty, Seq.empty) + } + + // If there are no matches, then find matches for the given name assuming that + // the 1st part is a qualifier (i.e. table name, alias, or subquery alias) and the + // 2nd part is the actual name. This returns a tuple of // matched attributes and a list of parts that are to be resolved. // // For example, consider an example where "a" is the table name, "b" is the column name, // and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b", // and the second element will be List("c"). - val matches = nameParts match { - case qualifier +: name +: nestedFields => - val key = (qualifier.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT)) - val attributes = collectMatches(name, qualified.get(key)).filter { a => - resolver(qualifier, a.qualifier.get) - } - (attributes, nestedFields) - case all => - (Nil, all) + if (matches._1.isEmpty) { + matches = nameParts match { + case qualifier +: name +: nestedFields => + val key = (qualifier.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT)) + val attributes = collectMatches(name, qualified.get(key)).filter { a => + resolver(qualifier, a.qualifier.last) + } + (attributes, nestedFields) + case all => + (Seq.empty[Attribute], Seq.empty[String]) + } } - // If none of attributes match `table.column` pattern, we try to resolve it as a column. + // If none of attributes match database.table.column pattern or + // `table.column` pattern, we try to resolve it as a column. val (candidates, nestedFields) = matches match { case (Seq(), _) => val name = nameParts.head diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala index a3cc4529b5456..deceec73dda30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala @@ -47,6 +47,22 @@ sealed trait IdentifierWithDatabase { override def toString: String = quotedString } +/** + * Encapsulates an identifier that is either a alias name or an identifier that has table + * name and optionally a database name. + * The SubqueryAlias node keeps track of the qualifier using the information in this structure + * @param identifier - Is an alias name or a table name + * @param database - Is a database name and is optional + */ +case class AliasIdentifier(identifier: String, database: Option[String]) + extends IdentifierWithDatabase { + + def this(identifier: String) = this(identifier, None) +} + +object AliasIdentifier { + def apply(identifier: String): AliasIdentifier = new AliasIdentifier(identifier) +} /** * Identifies a table in a database. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 9d18ce5c7b80f..19779b73d6dba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.{AliasIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} import org.apache.spark.sql.catalyst.expressions._ @@ -113,7 +114,7 @@ case class Generate( def qualifiedGeneratorOutput: Seq[Attribute] = { val qualifiedOutput = qualifier.map { q => // prepend the new qualifier to the existed one - generatorOutput.map(a => a.withQualifier(Some(q))) + generatorOutput.map(a => a.withQualifier(Seq(q))) }.getOrElse(generatorOutput) val nullableOutput = qualifiedOutput.map { // if outer, make all attributes nullable, otherwise keep existing nullability @@ -794,19 +795,37 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPr /** * Aliased subquery. * - * @param alias the alias name for this subquery. + * @param name the alias identifier for this subquery. * @param child the logical plan of this subquery. */ case class SubqueryAlias( - alias: String, + name: AliasIdentifier, child: LogicalPlan) extends OrderPreservingUnaryNode { - override def doCanonicalize(): LogicalPlan = child.canonicalized + def alias: String = name.identifier - override def output: Seq[Attribute] = child.output.map(_.withQualifier(Some(alias))) + override def output: Seq[Attribute] = { + val qualifierList = name.database.map(Seq(_, alias)).getOrElse(Seq(alias)) + child.output.map(_.withQualifier(qualifierList)) + } + override def doCanonicalize(): LogicalPlan = child.canonicalized } +object SubqueryAlias { + def apply( + identifier: String, + child: LogicalPlan): SubqueryAlias = { + SubqueryAlias(AliasIdentifier(identifier), child) + } + + def apply( + identifier: String, + database: String, + child: LogicalPlan): SubqueryAlias = { + SubqueryAlias(AliasIdentifier(identifier, Some(database)), child) + } +} /** * Sample the dataset. * diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 50496a0410528..89fabd4774065 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.catalog import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.{AliasIdentifier, FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -537,11 +537,11 @@ abstract class SessionCatalogSuite extends AnalysisTest { val view = View(desc = metadata, output = metadata.schema.toAttributes, child = CatalystSqlParser.parsePlan(metadata.viewText.get)) comparePlans(catalog.lookupRelation(TableIdentifier("view1", Some("db3"))), - SubqueryAlias("view1", view)) + SubqueryAlias("view1", "db3", view)) // Look up a view using current database of the session catalog. catalog.setCurrentDatabase("db3") comparePlans(catalog.lookupRelation(TableIdentifier("view1")), - SubqueryAlias("view1", view)) + SubqueryAlias("view1", "db3", view)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala index 12eddf557109f..3ccaa5976cc28 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala @@ -41,7 +41,7 @@ class ExpressionSetSuite extends SparkFunSuite { // maxHash's hashcode is calculated based on this exprId's hashcode, so we set this // exprId's hashCode to this specific value to make sure maxHash's hashcode is // `Int.MaxValue` - override def hashCode: Int = -1030353449 + override def hashCode: Int = 1394598635 // We are implementing this equals() only because the style-checking rule "you should // implement equals and hashCode together" requires us to override def equals(obj: Any): Boolean = super.equals(obj) @@ -57,7 +57,7 @@ class ExpressionSetSuite extends SparkFunSuite { // minHash's hashcode is calculated based on this exprId's hashcode, so we set this // exprId's hashCode to this specific value to make sure minHash's hashcode is // `Int.MinValue` - override def hashCode: Int = 1407330692 + override def hashCode: Int = -462684520 // We are implementing this equals() only because the style-checking rule "you should // implement equals and hashCode together" requires us to override def equals(obj: Any): Boolean = super.equals(obj) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index c48730bd9d1cc..1fa185cc77ebb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -30,7 +30,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite { } val b1 = a.withName("name2").withExprId(id) val b2 = a.withExprId(id) - val b3 = a.withQualifier(Some("qualifierName")) + val b3 = a.withQualifier(Seq("qualifierName")) assert(b1 != b2) assert(a != b1) diff --git a/sql/core/src/test/resources/sql-tests/inputs/columnresolution-views.sql b/sql/core/src/test/resources/sql-tests/inputs/columnresolution-views.sql index d3f928751757c..83c32a5bf2435 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/columnresolution-views.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/columnresolution-views.sql @@ -13,10 +13,8 @@ DROP VIEW view1; -- Test scenario with Global Temp view CREATE OR REPLACE GLOBAL TEMPORARY VIEW view1 as SELECT 1 as i1; SELECT * FROM global_temp.view1; --- TODO: Support this scenario SELECT global_temp.view1.* FROM global_temp.view1; SELECT i1 FROM global_temp.view1; --- TODO: Support this scenario SELECT global_temp.view1.i1 FROM global_temp.view1; SELECT view1.i1 FROM global_temp.view1; SELECT a.i1 FROM global_temp.view1 AS a; diff --git a/sql/core/src/test/resources/sql-tests/inputs/columnresolution.sql b/sql/core/src/test/resources/sql-tests/inputs/columnresolution.sql index 79e90ad3de91d..d001185a73931 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/columnresolution.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/columnresolution.sql @@ -14,9 +14,7 @@ SELECT i1 FROM mydb1.t1; SELECT t1.i1 FROM t1; SELECT t1.i1 FROM mydb1.t1; --- TODO: Support this scenario SELECT mydb1.t1.i1 FROM t1; --- TODO: Support this scenario SELECT mydb1.t1.i1 FROM mydb1.t1; USE mydb2; @@ -24,7 +22,6 @@ SELECT i1 FROM t1; SELECT i1 FROM mydb1.t1; SELECT t1.i1 FROM t1; SELECT t1.i1 FROM mydb1.t1; --- TODO: Support this scenario SELECT mydb1.t1.i1 FROM mydb1.t1; -- Scenario: resolve fully qualified table name in star expansion @@ -34,7 +31,6 @@ SELECT mydb1.t1.* FROM mydb1.t1; SELECT t1.* FROM mydb1.t1; USE mydb2; SELECT t1.* FROM t1; --- TODO: Support this scenario SELECT mydb1.t1.* FROM mydb1.t1; SELECT t1.* FROM mydb1.t1; SELECT a.* FROM mydb1.t1 AS a; @@ -47,21 +43,17 @@ CREATE TABLE t4 USING parquet AS SELECT * FROM VALUES (4,1), (2,1) AS t4(c2, c3) SELECT * FROM t3 WHERE c1 IN (SELECT c2 FROM t4 WHERE t4.c3 = t3.c2); --- TODO: Support this scenario SELECT * FROM mydb1.t3 WHERE c1 IN (SELECT mydb1.t4.c2 FROM mydb1.t4 WHERE mydb1.t4.c3 = mydb1.t3.c2); -- Scenario: column resolution scenarios in join queries SET spark.sql.crossJoin.enabled = true; --- TODO: Support this scenario SELECT mydb1.t1.i1 FROM t1, mydb2.t1; --- TODO: Support this scenario SELECT mydb1.t1.i1 FROM mydb1.t1, mydb2.t1; USE mydb2; --- TODO: Support this scenario SELECT mydb1.t1.i1 FROM t1, mydb1.t1; SET spark.sql.crossJoin.enabled = false; @@ -75,12 +67,10 @@ SELECT t5.t5.i1 FROM mydb1.t5; SELECT t5.i1 FROM mydb1.t5; SELECT t5.* FROM mydb1.t5; SELECT t5.t5.* FROM mydb1.t5; --- TODO: Support this scenario SELECT mydb1.t5.t5.i1 FROM mydb1.t5; --- TODO: Support this scenario SELECT mydb1.t5.t5.i2 FROM mydb1.t5; --- TODO: Support this scenario SELECT mydb1.t5.* FROM mydb1.t5; +SELECT mydb1.t5.* FROM t5; -- Cleanup and Reset USE default; diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out index 539f673c9d679..9fc97f0c39149 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out @@ -72,7 +72,7 @@ SELECT i1 FROM t1, mydb1.t1 struct<> -- !query 8 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: mydb1.t1.i1, mydb1.t1.i1.; line 1 pos 7 -- !query 9 @@ -81,7 +81,7 @@ SELECT t1.i1 FROM t1, mydb1.t1 struct<> -- !query 9 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: mydb1.t1.i1, mydb1.t1.i1.; line 1 pos 7 -- !query 10 @@ -90,7 +90,7 @@ SELECT mydb1.t1.i1 FROM t1, mydb1.t1 struct<> -- !query 10 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 +Reference 'mydb1.t1.i1' is ambiguous, could be: mydb1.t1.i1, mydb1.t1.i1.; line 1 pos 7 -- !query 11 @@ -99,7 +99,7 @@ SELECT i1 FROM t1, mydb2.t1 struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: mydb1.t1.i1, mydb2.t1.i1.; line 1 pos 7 -- !query 12 @@ -108,7 +108,7 @@ SELECT t1.i1 FROM t1, mydb2.t1 struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: mydb1.t1.i1, mydb2.t1.i1.; line 1 pos 7 -- !query 13 @@ -125,7 +125,7 @@ SELECT i1 FROM t1, mydb1.t1 struct<> -- !query 14 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: mydb2.t1.i1, mydb1.t1.i1.; line 1 pos 7 -- !query 15 @@ -134,7 +134,7 @@ SELECT t1.i1 FROM t1, mydb1.t1 struct<> -- !query 15 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: mydb2.t1.i1, mydb1.t1.i1.; line 1 pos 7 -- !query 16 @@ -143,7 +143,7 @@ SELECT i1 FROM t1, mydb2.t1 struct<> -- !query 16 output org.apache.spark.sql.AnalysisException -Reference 'i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 'i1' is ambiguous, could be: mydb2.t1.i1, mydb2.t1.i1.; line 1 pos 7 -- !query 17 @@ -152,7 +152,7 @@ SELECT t1.i1 FROM t1, mydb2.t1 struct<> -- !query 17 output org.apache.spark.sql.AnalysisException -Reference 't1.i1' is ambiguous, could be: t1.i1, t1.i1.; line 1 pos 7 +Reference 't1.i1' is ambiguous, could be: mydb2.t1.i1, mydb2.t1.i1.; line 1 pos 7 -- !query 18 @@ -161,7 +161,7 @@ SELECT db1.t1.i1 FROM t1, mydb2.t1 struct<> -- !query 18 output org.apache.spark.sql.AnalysisException -cannot resolve '`db1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 +cannot resolve '`db1.t1.i1`' given input columns: [mydb2.t1.i1, mydb2.t1.i1]; line 1 pos 7 -- !query 19 @@ -186,7 +186,7 @@ SELECT mydb1.t1 FROM t1 struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1`' given input columns: [t1.i1]; line 1 pos 7 +cannot resolve '`mydb1.t1`' given input columns: [mydb1.t1.i1]; line 1 pos 7 -- !query 22 @@ -204,7 +204,7 @@ SELECT t1 FROM mydb1.t1 struct<> -- !query 23 output org.apache.spark.sql.AnalysisException -cannot resolve '`t1`' given input columns: [t1.i1]; line 1 pos 7 +cannot resolve '`t1`' given input columns: [mydb1.t1.i1]; line 1 pos 7 -- !query 24 @@ -221,7 +221,7 @@ SELECT mydb1.t1.i1 FROM t1 struct<> -- !query 25 output org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 +cannot resolve '`mydb1.t1.i1`' given input columns: [mydb2.t1.i1]; line 1 pos 7 -- !query 26 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out index 2092119600954..3d8fb661afe55 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out @@ -85,10 +85,9 @@ struct -- !query 10 SELECT global_temp.view1.* FROM global_temp.view1 -- !query 10 schema -struct<> +struct -- !query 10 output -org.apache.spark.sql.AnalysisException -cannot resolve 'global_temp.view1.*' given input columns 'i1'; +1 -- !query 11 @@ -102,10 +101,9 @@ struct -- !query 12 SELECT global_temp.view1.i1 FROM global_temp.view1 -- !query 12 schema -struct<> +struct -- !query 12 output -org.apache.spark.sql.AnalysisException -cannot resolve '`global_temp.view1.i1`' given input columns: [view1.i1]; line 1 pos 7 +1 -- !query 13 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out index e10f516ad6e5b..73e3fdc08232c 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 54 +-- Number of queries: 55 -- !query 0 @@ -93,19 +93,17 @@ struct -- !query 11 SELECT mydb1.t1.i1 FROM t1 -- !query 11 schema -struct<> +struct -- !query 11 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 +1 -- !query 12 SELECT mydb1.t1.i1 FROM mydb1.t1 -- !query 12 schema -struct<> +struct -- !query 12 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 +1 -- !query 13 @@ -151,10 +149,9 @@ struct -- !query 18 SELECT mydb1.t1.i1 FROM mydb1.t1 -- !query 18 schema -struct<> +struct -- !query 18 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1]; line 1 pos 7 +1 -- !query 19 @@ -176,10 +173,9 @@ struct -- !query 21 SELECT mydb1.t1.* FROM mydb1.t1 -- !query 21 schema -struct<> +struct -- !query 21 output -org.apache.spark.sql.AnalysisException -cannot resolve 'mydb1.t1.*' given input columns 'i1'; +1 -- !query 22 @@ -209,10 +205,9 @@ struct -- !query 25 SELECT mydb1.t1.* FROM mydb1.t1 -- !query 25 schema -struct<> +struct -- !query 25 output -org.apache.spark.sql.AnalysisException -cannot resolve 'mydb1.t1.*' given input columns 'i1'; +1 -- !query 26 @@ -267,10 +262,9 @@ struct SELECT * FROM mydb1.t3 WHERE c1 IN (SELECT mydb1.t4.c2 FROM mydb1.t4 WHERE mydb1.t4.c3 = mydb1.t3.c2) -- !query 32 schema -struct<> +struct -- !query 32 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t4.c3`' given input columns: [t4.c2, t4.c3]; line 2 pos 42 +4 1 -- !query 33 @@ -284,19 +278,17 @@ spark.sql.crossJoin.enabled true -- !query 34 SELECT mydb1.t1.i1 FROM t1, mydb2.t1 -- !query 34 schema -struct<> +struct -- !query 34 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 +1 -- !query 35 SELECT mydb1.t1.i1 FROM mydb1.t1, mydb2.t1 -- !query 35 schema -struct<> +struct -- !query 35 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 +1 -- !query 36 @@ -310,10 +302,9 @@ struct<> -- !query 37 SELECT mydb1.t1.i1 FROM t1, mydb1.t1 -- !query 37 schema -struct<> +struct -- !query 37 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t1.i1`' given input columns: [t1.i1, t1.i1]; line 1 pos 7 +1 -- !query 38 @@ -399,40 +390,37 @@ struct -- !query 48 SELECT mydb1.t5.t5.i1 FROM mydb1.t5 -- !query 48 schema -struct<> +struct -- !query 48 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t5.t5.i1`' given input columns: [t5.i1, t5.t5]; line 1 pos 7 +2 -- !query 49 SELECT mydb1.t5.t5.i2 FROM mydb1.t5 -- !query 49 schema -struct<> +struct -- !query 49 output -org.apache.spark.sql.AnalysisException -cannot resolve '`mydb1.t5.t5.i2`' given input columns: [t5.i1, t5.t5]; line 1 pos 7 +3 -- !query 50 SELECT mydb1.t5.* FROM mydb1.t5 -- !query 50 schema -struct<> +struct> -- !query 50 output -org.apache.spark.sql.AnalysisException -cannot resolve 'mydb1.t5.*' given input columns 'i1, t5'; +1 {"i1":2,"i2":3} -- !query 51 -USE default +SELECT mydb1.t5.* FROM t5 -- !query 51 schema -struct<> +struct> -- !query 51 output - +1 {"i1":2,"i2":3} -- !query 52 -DROP DATABASE mydb1 CASCADE +USE default -- !query 52 schema struct<> -- !query 52 output @@ -440,8 +428,16 @@ struct<> -- !query 53 -DROP DATABASE mydb2 CASCADE +DROP DATABASE mydb1 CASCADE -- !query 53 schema struct<> -- !query 53 output + + +-- !query 54 +DROP DATABASE mydb2 CASCADE +-- !query 54 schema +struct<> +-- !query 54 output + diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index d5f8705a35ed6..7b3dc84388889 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -36,14 +36,14 @@ struct -- !query 3 output == Parsed Logical Plan == 'Project [concat(concat(concat('col1, 'col2), 'col3), 'col4) AS col#x] -+- 'SubqueryAlias __auto_generated_subquery_name ++- 'SubqueryAlias `__auto_generated_subquery_name` +- 'Project ['id AS col1#x, 'id AS col2#x, 'id AS col3#x, 'id AS col4#x] +- 'UnresolvedTableValuedFunction range, [10] == Analyzed Logical Plan == col: string Project [concat(concat(concat(cast(col1#xL as string), cast(col2#xL as string)), cast(col3#xL as string)), cast(col4#xL as string)) AS col#x] -+- SubqueryAlias __auto_generated_subquery_name ++- SubqueryAlias `__auto_generated_subquery_name` +- Project [id#xL AS col1#xL, id#xL AS col2#xL, id#xL AS col3#xL, id#xL AS col4#xL] +- Range (0, 10, step=1, splits=None) diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out index 2586f26f71c35..e49978ddb1ce2 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out @@ -109,8 +109,8 @@ struct<> org.apache.spark.sql.AnalysisException Expressions referencing the outer query are not supported outside of WHERE/HAVING clauses: Aggregate [min(outer(t2a#x)) AS min(outer())#x] -+- SubqueryAlias t3 ++- SubqueryAlias `t3` +- Project [t3a#x, t3b#x, t3c#x] - +- SubqueryAlias t3 + +- SubqueryAlias `t3` +- LocalRelation [t3a#x, t3b#x, t3c#x] ; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 2cb7a04714a52..3a393d766b0bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2689,7 +2689,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val m = intercept[AnalysisException] { sql("SELECT * FROM t, S WHERE c = C") }.message - assert(m.contains("cannot resolve '(t.`c` = S.`C`)' due to data type mismatch")) + assert( + m.contains("cannot resolve '(default.t.`c` = default.S.`C`)' due to data type mismatch")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index d254345e8fa54..bdc106325aa5c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -624,7 +624,7 @@ class PlannerSuite extends SharedSQLContext { dataType = LongType, nullable = false ) (exprId = exprId, - qualifier = Some("col1_qualifier") + qualifier = Seq("col1_qualifier") ) val attribute2 = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala index abe61a2c2b9c4..fccee97820e75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala @@ -72,7 +72,7 @@ object TPCDSQueryBenchmark extends Logging { val queryRelations = scala.collection.mutable.HashSet[String]() spark.sql(queryString).queryExecution.analyzed.foreach { case SubqueryAlias(alias, _: LogicalRelation) => - queryRelations.add(alias) + queryRelations.add(alias.identifier) case LogicalRelation(_, _, Some(catalogTable), _) => queryRelations.add(catalogTable.identifier.table) case HiveTableRelation(tableMeta, _, _) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index ba9b944e4a055..688b619cd1bb5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{QueryTest, Row, SaveMode} -import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.CatalogTableType import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias @@ -62,7 +62,7 @@ class HiveMetastoreCatalogSuite extends TestHiveSingleton with SQLTestUtils { spark.sql("create view vw1 as select 1 as id") val plan = spark.sql("select id from vw1").queryExecution.analyzed val aliases = plan.collect { - case x @ SubqueryAlias("vw1", _) => x + case x @ SubqueryAlias(AliasIdentifier("vw1", Some("default")), _) => x } assert(aliases.size == 1) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 1a916824c5d9e..13aa2b843667c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1967,6 +1967,22 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("column resolution scenarios with hive table") { + val currentDb = spark.catalog.currentDatabase + withTempDatabase { db1 => + try { + spark.catalog.setCurrentDatabase(db1) + spark.sql("CREATE TABLE t1(i1 int) STORED AS parquet") + spark.sql("INSERT INTO t1 VALUES(1)") + checkAnswer(spark.sql(s"SELECT $db1.t1.i1 FROM t1"), Row(1)) + checkAnswer(spark.sql(s"SELECT $db1.t1.i1 FROM $db1.t1"), Row(1)) + checkAnswer(spark.sql(s"SELECT $db1.t1.* FROM $db1.t1"), Row(1)) + } finally { + spark.catalog.setCurrentDatabase(currentDb) + } + } + } + test("SPARK-17409: Do Not Optimize Query in CTAS (Hive Serde Table) More Than Once") { withTable("bar") { withTempView("foo") { From 6a143e3ebf4bcac9464561e67174a7c610d9be1f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 7 Aug 2018 22:23:59 +0800 Subject: [PATCH 1335/2461] [SPARK-23928][TESTS][FOLLOWUP] Set seed to avoid flakiness ## What changes were proposed in this pull request? The tests for shuffle can be flaky (eg. https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/94355/testReport/). This happens because we have not set the seed for `Random`. ## How was this patch tested? running 10000 times the UT (validated that with a different seed eg. 12345 the test fails). Closes #22023 from mgaido91/SPARK-23928_followup. Authored-by: Marco Gaido Signed-off-by: hyukjinkwon --- .../sql/catalyst/expressions/CollectionExpressionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index c6b3f9502f2bf..40487b3fd001c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1489,7 +1489,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Shuffle(as7, Some(0)), null) checkEvaluation(Shuffle(aa, Some(0)), Seq(Seq("e"), Seq("a", "b"), Seq("c", "d"))) - val r = new Random() + val r = new Random(1234) val seed1 = Some(r.nextLong()) assert(evaluateWithoutCodegen(Shuffle(ai0, seed1)) === evaluateWithoutCodegen(Shuffle(ai0, seed1))) From 1a29fec8e278a98e69f2e2b6faa11332e8550f30 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 7 Aug 2018 08:45:20 -0700 Subject: [PATCH 1336/2461] [SPARK-24979][SQL] add AnalysisHelper#resolveOperatorsUp ## What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/21822 Similar to `TreeNode`, `AnalysisHelper` should also provide 3 versions of transformations: `resolveOperatorsUp`, `resolveOperatorsDown` and `resolveOperators`. This PR adds the missing `resolveOperatorsUp`, and also fixes some code style which is missed in #21822 ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #21932 from cloud-fan/follow. --- .../sql/catalyst/analysis/Analyzer.scala | 134 ++++++++---------- .../sql/catalyst/analysis/ResolveHints.scala | 4 +- .../sql/catalyst/analysis/TypeCoercion.scala | 56 ++++---- .../analysis/higherOrderFunctions.scala | 25 ++-- .../spark/sql/catalyst/analysis/view.scala | 2 +- .../plans/logical/AnalysisHelper.scala | 19 ++- .../scala/org/apache/spark/sql/Dataset.scala | 8 +- .../datasources/DataSourceStrategy.scala | 4 +- .../sql/execution/datasources/rules.scala | 4 +- 9 files changed, 128 insertions(+), 128 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d391a93144c45..90c7cf6f082c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -206,7 +206,7 @@ class Analyzer( * Analyze cte definitions and substitute child plan with analyzed cte definitions. */ object CTESubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => @@ -217,7 +217,7 @@ class Analyzer( def substituteCTE(plan: LogicalPlan, cteRelations: Seq[(String, LogicalPlan)]): LogicalPlan = { plan resolveOperatorsDown { - case u : UnresolvedRelation => + case u: UnresolvedRelation => cteRelations.find(x => resolver(x._1, u.tableIdentifier.table)) .map(_._2).getOrElse(u) case other => @@ -234,19 +234,16 @@ class Analyzer( * Substitute child plan with WindowSpecDefinitions. */ object WindowsSubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { // Lookup WindowSpecDefinitions. This rule works with unresolved children. - case WithWindowDefinition(windowDefinitions, child) => - child.resolveOperators { - case p => p.transformExpressions { - case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) => - val errorMessage = - s"Window specification $windowName is not defined in the WINDOW clause." - val windowSpecDefinition = - windowDefinitions.getOrElse(windowName, failAnalysis(errorMessage)) - WindowExpression(c, windowSpecDefinition) - } - } + case WithWindowDefinition(windowDefinitions, child) => child.resolveExpressions { + case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) => + val errorMessage = + s"Window specification $windowName is not defined in the WINDOW clause." + val windowSpecDefinition = + windowDefinitions.getOrElse(windowName, failAnalysis(errorMessage)) + WindowExpression(c, windowSpecDefinition) + } } } @@ -274,7 +271,7 @@ class Analyzer( private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined) - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => Aggregate(groups, assignAliases(aggs), child) @@ -494,7 +491,7 @@ class Analyzer( } // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { case a if !a.childrenResolved => a // be sure all of the children are resolved. // Ensure group by expressions and aggregate expressions have been resolved. @@ -527,7 +524,7 @@ class Analyzer( } object ResolvePivot extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved) || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p @@ -705,7 +702,7 @@ class Analyzer( case _ => plan } - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => EliminateSubqueryAliases(lookupTableFromCatalog(u)) match { case v: View => @@ -897,7 +894,7 @@ class Analyzer( case _ => e.mapChildren(resolve(_, q)) } - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. @@ -1091,7 +1088,7 @@ class Analyzer( * have no effect on the results. */ object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p if !p.childrenResolved => p // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. @@ -1147,7 +1144,7 @@ class Analyzer( }} } - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case agg @ Aggregate(groups, aggs, child) if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && groups.exists(!_.resolved) => @@ -1171,7 +1168,7 @@ class Analyzer( * The HAVING clause could also used a grouping columns that is not presented in the SELECT. */ object ResolveMissingReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions case sa @ Sort(_, _, child: Aggregate) => sa @@ -1307,7 +1304,7 @@ class Analyzer( * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. @@ -1459,7 +1456,7 @@ class Analyzer( /** * Resolve and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { // In case of HAVING (a filter after an aggregate) we use both the aggregate and // its child for resolution. case f @ Filter(_, a: Aggregate) if f.childrenResolved => @@ -1475,7 +1472,7 @@ class Analyzer( */ object ResolveSubqueryColumnAliases extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case u @ UnresolvedSubqueryColumnAliases(columnNames, child) if child.resolved => // Resolves output attributes if a query has alias names in its subquery: // e.g., SELECT * FROM (SELECT 1 AS a, 1 AS b) t(col1, col2) @@ -1524,7 +1521,7 @@ class Analyzer( * underlying aggregate operator and then projected away after the original operator. */ object ResolveAggregateFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs, child)) if agg.resolved => // Try resolving the condition of the filter as though it is in the aggregate clause @@ -1701,7 +1698,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get throw new AnalysisException("Generators are not supported when it's nested in " + @@ -1761,7 +1758,7 @@ class Analyzer( * that wrap the [[Generator]]. */ object ResolveGenerate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case g: Generate if !g.child.resolved || !g.generator.resolved => g case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) @@ -1802,7 +1799,7 @@ class Analyzer( */ object FixNullability extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { case p if !p.resolved => p // Skip unresolved nodes. case p: LogicalPlan if p.resolved => val childrenOutput = p.children.flatMap(c => c.output).groupBy(_.exprId).flatMap { @@ -2086,7 +2083,7 @@ class Analyzer( * put them into an inner Project and finally project them away at the outer Project. */ object PullOutNondeterministic extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p if !p.resolved => p // Skip unresolved nodes. case p: Project => p case f: Filter => f @@ -2130,7 +2127,7 @@ class Analyzer( object ResolveRandomSeed extends Rule[LogicalPlan] { private lazy val random = new Random() - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p if p.resolved => p case p => p transformExpressionsUp { case Uuid(None) => Uuid(Some(random.nextLong())) @@ -2146,7 +2143,7 @@ class Analyzer( * and we should return null if the input is null. */ object HandleNullInputsForUDF extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p if !p.resolved => p // Skip unresolved nodes. case p => p transformExpressionsUp { @@ -2181,25 +2178,21 @@ class Analyzer( * Check and add proper window frames for all window functions. */ object ResolveWindowFrame extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { - case logical: LogicalPlan => logical transformExpressions { - case WindowExpression(wf: WindowFunction, - WindowSpecDefinition(_, _, f: SpecifiedWindowFrame)) + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + case WindowExpression(wf: WindowFunction, WindowSpecDefinition(_, _, f: SpecifiedWindowFrame)) if wf.frame != UnspecifiedFrame && wf.frame != f => - failAnalysis(s"Window Frame $f must match the required frame ${wf.frame}") - case WindowExpression(wf: WindowFunction, - s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) + failAnalysis(s"Window Frame $f must match the required frame ${wf.frame}") + case WindowExpression(wf: WindowFunction, s @ WindowSpecDefinition(_, _, UnspecifiedFrame)) if wf.frame != UnspecifiedFrame => - WindowExpression(wf, s.copy(frameSpecification = wf.frame)) - case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) + WindowExpression(wf, s.copy(frameSpecification = wf.frame)) + case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) if e.resolved => - val frame = if (o.nonEmpty) { - SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) - } else { - SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing) - } - we.copy(windowSpec = s.copy(frameSpecification = frame)) - } + val frame = if (o.nonEmpty) { + SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) + } else { + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing) + } + we.copy(windowSpec = s.copy(frameSpecification = frame)) } } @@ -2207,16 +2200,14 @@ class Analyzer( * Check and add order to [[AggregateWindowFunction]]s. */ object ResolveWindowOrder extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { - case logical: LogicalPlan => logical transformExpressions { - case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty => - failAnalysis(s"Window function $wf requires window to be ordered, please add ORDER BY " + - s"clause. For example SELECT $wf(value_expr) OVER (PARTITION BY window_partition " + - s"ORDER BY window_ordering) from table") - case WindowExpression(rank: RankLike, spec) if spec.resolved => - val order = spec.orderSpec.map(_.child) - WindowExpression(rank.withOrder(order), spec) - } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty => + failAnalysis(s"Window function $wf requires window to be ordered, please add ORDER BY " + + s"clause. For example SELECT $wf(value_expr) OVER (PARTITION BY window_partition " + + s"ORDER BY window_ordering) from table") + case WindowExpression(rank: RankLike, spec) if spec.resolved => + val order = spec.orderSpec.map(_.child) + WindowExpression(rank.withOrder(order), spec) } } @@ -2225,8 +2216,8 @@ class Analyzer( * Then apply a Project on a normal Join to eliminate natural or using join. */ object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { - case j @ Join(left, right, UsingJoin(joinType, usingCols), condition) + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + case j @ Join(left, right, UsingJoin(joinType, usingCols), _) if left.resolved && right.resolved && j.duplicateResolved => commonNaturalJoinProcessing(left, right, joinType, usingCols, None) case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural => @@ -2290,7 +2281,7 @@ class Analyzer( * to the given input attributes. */ object ResolveDeserializer extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2376,7 +2367,7 @@ class Analyzer( * constructed is an inner class. */ object ResolveNewInstance extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2410,7 +2401,7 @@ class Analyzer( "type of the field in the target object") } - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2432,9 +2423,8 @@ class Analyzer( * scoping information for attributes and can be removed once analysis is complete. */ object EliminateSubqueryAliases extends Rule[LogicalPlan] { - // This is actually called in the beginning of the optimization phase, and as a result - // is using transformUp rather than resolveOperators. This is also often called in the - // + // This is also called in the beginning of the optimization phase, and as a result + // is using transformUp rather than resolveOperators. def apply(plan: LogicalPlan): LogicalPlan = AnalysisHelper.allowInvokingTransformsInAnalyzer { plan transformUp { case SubqueryAlias(_, child) => child @@ -2446,7 +2436,7 @@ object EliminateSubqueryAliases extends Rule[LogicalPlan] { * Removes [[Union]] operators from the plan if it just has one child. */ object EliminateUnions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case Union(children) if children.size == 1 => children.head } } @@ -2477,7 +2467,7 @@ object CleanupAliases extends Rule[LogicalPlan] { case other => trimAliases(other) } - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case Project(projectList, child) => val cleanedProjectList = projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) @@ -2487,7 +2477,7 @@ object CleanupAliases extends Rule[LogicalPlan] { val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) Aggregate(grouping.map(trimAliases), cleanedAggs, child) - case w @ Window(windowExprs, partitionSpec, orderSpec, child) => + case Window(windowExprs, partitionSpec, orderSpec, child) => val cleanedWindowExprs = windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression]) Window(cleanedWindowExprs, partitionSpec.map(trimAliases), @@ -2511,7 +2501,7 @@ object CleanupAliases extends Rule[LogicalPlan] { * TODO: add this rule into analyzer rule list. */ object EliminateEventTimeWatermark extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case EventTimeWatermark(_, _, child) if !child.isStreaming => child } } @@ -2556,7 +2546,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * @return the logical plan that will generate the time windows using the Expand operator, with * the Filter operator for correctness and Project for usability. */ - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p: LogicalPlan if p.children.size == 1 => val child = p.children.head val windowExpressions = @@ -2707,7 +2697,7 @@ object UpdateOuterReferences extends Rule[LogicalPlan] { } def apply(plan: LogicalPlan): LogicalPlan = { - plan resolveOperatorsDown { + plan resolveOperators { case f @ Filter(_, a: Aggregate) if f.resolved => f transformExpressions { case s: SubqueryExpression if s.children.nonEmpty => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index 80d5105c2de8f..dbd4ed845e329 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -86,7 +86,7 @@ object ResolveHints { } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { case h: UnresolvedHint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => if (h.parameters.isEmpty) { // If there is no table alias specified, turn the entire subtree into a BroadcastHint. @@ -134,7 +134,7 @@ object ResolveHints { * This must be executed after all the other hint rules are executed. */ object RemoveAllHints extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { case h: UnresolvedHint => h.child } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 648aa9ee8fa0e..27839d72c6306 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -318,7 +318,7 @@ object TypeCoercion { */ object WidenSetOperationTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { case s @ Except(left, right, isAll) if s.childrenResolved && left.output.length == right.output.length && !s.resolved => val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) @@ -739,17 +739,18 @@ object TypeCoercion { */ case class ConcatCoercion(conf: SQLConf) extends TypeCoercionRule { - override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case p => - p transformExpressionsUp { - // Skip nodes if unresolved or empty children - case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c - case c @ Concat(children) if conf.concatBinaryAsString || + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = { + plan resolveOperators { case p => + p transformExpressionsUp { + // Skip nodes if unresolved or empty children + case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c + case c @ Concat(children) if conf.concatBinaryAsString || !children.map(_.dataType).forall(_ == BinaryType) => - val newChildren = c.children.map { e => - ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) - } - c.copy(children = newChildren) + val newChildren = c.children.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + } + c.copy(children = newChildren) + } } } } @@ -762,23 +763,24 @@ object TypeCoercion { */ case class EltCoercion(conf: SQLConf) extends TypeCoercionRule { - override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case p => - p transformExpressionsUp { - // Skip nodes if unresolved or not enough children - case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c - case c @ Elt(children) => - val index = children.head - val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index) - val newInputs = if (conf.eltOutputAsString || + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = { + plan resolveOperators { case p => + p transformExpressionsUp { + // Skip nodes if unresolved or not enough children + case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c + case c @ Elt(children) => + val index = children.head + val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index) + val newInputs = if (conf.eltOutputAsString || !children.tail.map(_.dataType).forall(_ == BinaryType)) { - children.tail.map { e => - ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + children.tail.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + } + } else { + children.tail } - } else { - children.tail - } - c.copy(children = newIndex +: newInputs) + c.copy(children = newIndex +: newInputs) + } } } } @@ -989,7 +991,7 @@ trait TypeCoercionRule extends Rule[LogicalPlan] with Logging { protected def coerceTypes(plan: LogicalPlan): LogicalPlan - private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { // No propagation required for leaf nodes. case q: LogicalPlan if q.children.isEmpty => q diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala index 063ca0fc3252d..5e2029c251ee4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala @@ -32,20 +32,17 @@ import org.apache.spark.sql.types.DataType */ case class ResolveHigherOrderFunctions(catalog: SessionCatalog) extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { - case q: LogicalPlan => - q.transformExpressions { - case u @ UnresolvedFunction(fn, children, false) - if hasLambdaAndResolvedArguments(children) => - withPosition(u) { - catalog.lookupFunction(fn, children) match { - case func: HigherOrderFunction => func - case other => other.failAnalysis( - "A lambda function should only be used in a higher order function. However, " + - s"its class is ${other.getClass.getCanonicalName}, which is not a " + - s"higher order function.") - } - } + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { + case u @ UnresolvedFunction(fn, children, false) + if hasLambdaAndResolvedArguments(children) => + withPosition(u) { + catalog.lookupFunction(fn, children) match { + case func: HigherOrderFunction => func + case other => other.failAnalysis( + "A lambda function should only be used in a higher order function. However, " + + s"its class is ${other.getClass.getCanonicalName}, which is not a " + + s"higher order function.") + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index feeb6553d1066..af74693000c44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.internal.SQLConf * completely resolved during the batch of Resolution. */ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { case v @ View(desc, output, child) if child.resolved && output != child.output => val resolver = conf.resolver val queryColumnNames = desc.viewQueryColumnNames diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala index 039acc1ea4fa8..9404a809b453c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala @@ -60,6 +60,19 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan => */ def analyzed: Boolean = _analyzed + /** + * Returns a copy of this node where `rule` has been recursively applied to the tree. When + * `rule` does not apply to a given node, it is left unchanged. This function is similar to + * `transform`, but skips sub-trees that have already been marked as analyzed. + * Users should not expect a specific directionality. If a specific directionality is needed, + * [[resolveOperatorsUp]] or [[resolveOperatorsDown]] should be used. + * + * @param rule the function use to transform this nodes children + */ + def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + resolveOperatorsDown(rule) + } + /** * Returns a copy of this node where `rule` has been recursively applied first to all of its * children and then itself (post-order, bottom-up). When `rule` does not apply to a given node, @@ -68,10 +81,10 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan => * * @param rule the function use to transform this nodes children */ - def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + def resolveOperatorsUp(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { if (!analyzed) { AnalysisHelper.allowInvokingTransformsInAnalyzer { - val afterRuleOnChildren = mapChildren(_.resolveOperators(rule)) + val afterRuleOnChildren = mapChildren(_.resolveOperatorsUp(rule)) if (self fastEquals afterRuleOnChildren) { CurrentOrigin.withOrigin(origin) { rule.applyOrElse(self, identity[LogicalPlan]) @@ -87,7 +100,7 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan => } } - /** Similar to [[resolveOperators]], but does it top-down. */ + /** Similar to [[resolveOperatorsUp]], but does it top-down. */ def resolveOperatorsDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { if (!analyzed) { AnalysisHelper.allowInvokingTransformsInAnalyzer { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a4bf990ea9d6c..f65948d39a1cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1383,8 +1383,7 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { implicit val encoder = c1.encoder - val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, - logicalPlan) + val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan) if (encoder.flat) { new Dataset[U1](sparkSession, project, encoder) @@ -1658,15 +1657,14 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { - val inputPlan = logicalPlan - val withGroupingKey = AppendColumns(func, inputPlan) + val withGroupingKey = AppendColumns(func, logicalPlan) val executed = sparkSession.sessionState.executePlan(withGroupingKey) new KeyValueGroupedDataset( encoderFor[K], encoderFor[T], executed, - inputPlan.output, + logicalPlan.output, withGroupingKey.newColumns) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index e1b049b6ceaba..6b61e749e3063 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -131,7 +131,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast projectList } - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case CreateTable(tableDesc, mode, None) if DDLUtils.isDatasourceTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc) CreateDataSourceTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore) @@ -252,7 +252,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] table.partitionSchema.asNullable.toAttributes) } - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case i @ InsertIntoTable(UnresolvedCatalogRelation(tableMeta), _, _, _, _) if DDLUtils.isDatasourceTable(tableMeta) => i.copy(table = readDataSourceTable(tableMeta)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 3170180b32b83..949aa665527ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -73,7 +73,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi // catalog is a def and not a val/lazy val as the latter would introduce a circular reference private def catalog = sparkSession.sessionState.catalog - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // When we CREATE TABLE without specifying the table schema, we should fail the query if // bucketing information is specified, as we can't infer bucketing from data files currently. // Since the runtime inferred partition columns could be different from what user specified, @@ -365,7 +365,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case i @ InsertIntoTable(table, _, query, _, _) if table.resolved && query.resolved => table match { case relation: HiveTableRelation => From 298e80f5c7a735206e070e03b03aa52712038128 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 7 Aug 2018 11:58:44 -0500 Subject: [PATCH 1337/2461] [SPARK-25041][BUILD] upgrade genJavaDoc-plugin from 0.10 to 0.11 ## What changes were proposed in this pull request? This PR fixes a build error with sbt using Scala-2.12. Since [`genJavaDoc-plugin`] (https://mvnrepository.com/artifact/com.typesafe.genjavadoc/genjavadoc-plugin) 0.10 is not prepared for Scala-2.12.6, the recent version of `genJavaDoc-plugin` is necessary. The version 0.11 of `genJavaDoc-plugin` is also prepared for Scala-2.11.12. [genJavaDoc-0.10](https://index.scala-lang.org/lightbend/genjavadoc/genjavadoc-plugin/0.10) [genJavaDoc-0.11](https://index.scala-lang.org/lightbend/genjavadoc/genjavadoc-plugin/0.11) ## How was this patch tested? Manually tested for Scala-2.12. Author: Kazuaki Ishizaki Closes #22020 from kiszk/SPARK-25041. --- project/SparkBuild.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 247b6fee394bc..4b4cce6788a96 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -212,7 +212,7 @@ object SparkBuild extends PomBuild { .map(file), incOptions := incOptions.value.withNameHashing(true), publishMavenStyle := true, - unidocGenjavadocVersion := "0.10", + unidocGenjavadocVersion := "0.11", // Override SBT's default resolvers: resolvers := Seq( From cb6cb313637836737f8ec7de34e592e425efd57f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 8 Aug 2018 02:12:19 +0900 Subject: [PATCH 1338/2461] [SPARK-23937][SQL] Add map_filter SQL function ## What changes were proposed in this pull request? The PR adds the high order function `map_filter`, which filters the entries of a map and returns a new map which contains only the entries which satisfied the filter function. ## How was this patch tested? added UTs Closes #21986 from mgaido91/SPARK-23937. Authored-by: Marco Gaido Signed-off-by: Takuya UESHIN --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/higherOrderFunctions.scala | 173 +++++++++++++----- .../HigherOrderFunctionsSuite.scala | 49 +++++ .../spark/sql/DataFrameFunctionsSuite.scala | 46 +++++ 4 files changed, 221 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ed2f67da6f2bf..390debd865eed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -442,6 +442,7 @@ object FunctionRegistry { expression[ArrayRemove]("array_remove"), expression[ArrayDistinct]("array_distinct"), expression[ArrayTransform]("transform"), + expression[MapFilter]("map_filter"), expression[ArrayFilter]("filter"), expression[ArrayAggregate]("aggregate"), CreateStruct.registryEntry, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 5d4665917009b..0a473d2885079 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ /** @@ -133,7 +133,29 @@ trait HigherOrderFunction extends Expression { } } -trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes { +object HigherOrderFunction { + + def arrayArgumentType(dt: DataType): (DataType, Boolean) = { + dt match { + case ArrayType(elementType, containsNull) => (elementType, containsNull) + case _ => + val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType + (elementType, containsNull) + } + } + + def mapKeyValueArgumentType(dt: DataType): (DataType, DataType, Boolean) = dt match { + case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull) + case _ => + val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType + (kType, vType, vContainsNull) + } +} + +/** + * Trait for functions having as input one argument and one function. + */ +trait SimpleHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes { def input: Expression @@ -145,23 +167,33 @@ trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInpu def expectingFunctionType: AbstractDataType = AnyDataType - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, expectingFunctionType) - @transient lazy val functionForEval: Expression = functionsForEval.head -} -object ArrayBasedHigherOrderFunction { + /** + * Called by [[eval]]. If a subclass keeps the default nullability, it can override this method + * in order to save null-check code. + */ + protected def nullSafeEval(inputRow: InternalRow, input: Any): Any = + sys.error(s"UnaryHigherOrderFunction must override either eval or nullSafeEval") - def elementArgumentType(dt: DataType): (DataType, Boolean) = { - dt match { - case ArrayType(elementType, containsNull) => (elementType, containsNull) - case _ => - val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType - (elementType, containsNull) + override def eval(inputRow: InternalRow): Any = { + val value = input.eval(inputRow) + if (value == null) { + null + } else { + nullSafeEval(inputRow, value) } } } +trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, expectingFunctionType) +} + +trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { + override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType) +} + /** * Transform elements in an array using the transform function. This is similar to * a `map` in functional programming. @@ -179,14 +211,14 @@ object ArrayBasedHigherOrderFunction { case class ArrayTransform( input: Expression, function: Expression) - extends ArrayBasedHigherOrderFunction with CodegenFallback { + extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { override def nullable: Boolean = input.nullable override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = { - val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType) + val elem = HigherOrderFunction.arrayArgumentType(input.dataType) function match { case LambdaFunction(_, arguments, _) if arguments.size == 2 => copy(function = f(function, elem :: (IntegerType, false) :: Nil)) @@ -205,29 +237,78 @@ case class ArrayTransform( (elementVar, indexVar) } - override def eval(input: InternalRow): Any = { - val arr = this.input.eval(input).asInstanceOf[ArrayData] - if (arr == null) { - null - } else { - val f = functionForEval - val result = new GenericArrayData(new Array[Any](arr.numElements)) - var i = 0 - while (i < arr.numElements) { - elementVar.value.set(arr.get(i, elementVar.dataType)) - if (indexVar.isDefined) { - indexVar.get.value.set(i) - } - result.update(i, f.eval(input)) - i += 1 + override def nullSafeEval(inputRow: InternalRow, inputValue: Any): Any = { + val arr = inputValue.asInstanceOf[ArrayData] + val f = functionForEval + val result = new GenericArrayData(new Array[Any](arr.numElements)) + var i = 0 + while (i < arr.numElements) { + elementVar.value.set(arr.get(i, elementVar.dataType)) + if (indexVar.isDefined) { + indexVar.get.value.set(i) } - result + result.update(i, f.eval(inputRow)) + i += 1 } + result } override def prettyName: String = "transform" } +/** + * Filters entries in a map using the provided function. + */ +@ExpressionDescription( +usage = "_FUNC_(expr, func) - Filters entries in a map using the function.", +examples = """ + Examples: + > SELECT _FUNC_(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v); + [1 -> 0, 3 -> -1] + """, +since = "2.4.0") +case class MapFilter( + input: Expression, + function: Expression) + extends MapBasedSimpleHigherOrderFunction with CodegenFallback { + + @transient lazy val (keyVar, valueVar) = { + val args = function.asInstanceOf[LambdaFunction].arguments + (args.head.asInstanceOf[NamedLambdaVariable], args.tail.head.asInstanceOf[NamedLambdaVariable]) + } + + @transient val (keyType, valueType, valueContainsNull) = + HigherOrderFunction.mapKeyValueArgumentType(input.dataType) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = { + copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) + } + + override def nullable: Boolean = input.nullable + + override def nullSafeEval(inputRow: InternalRow, value: Any): Any = { + val m = value.asInstanceOf[MapData] + val f = functionForEval + val retKeys = new mutable.ListBuffer[Any] + val retValues = new mutable.ListBuffer[Any] + m.foreach(keyType, valueType, (k, v) => { + keyVar.value.set(k) + valueVar.value.set(v) + if (f.eval(inputRow).asInstanceOf[Boolean]) { + retKeys += k + retValues += v + } + }) + ArrayBasedMapData(retKeys.toArray, retValues.toArray) + } + + override def dataType: DataType = input.dataType + + override def expectingFunctionType: AbstractDataType = BooleanType + + override def prettyName: String = "map_filter" +} + /** * Filters the input array using the given lambda function. */ @@ -242,7 +323,7 @@ case class ArrayTransform( case class ArrayFilter( input: Expression, function: Expression) - extends ArrayBasedHigherOrderFunction with CodegenFallback { + extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { override def nullable: Boolean = input.nullable @@ -251,29 +332,25 @@ case class ArrayFilter( override def expectingFunctionType: AbstractDataType = BooleanType override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = { - val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType) + val elem = HigherOrderFunction.arrayArgumentType(input.dataType) copy(function = f(function, elem :: Nil)) } @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function - override def eval(input: InternalRow): Any = { - val arr = this.input.eval(input).asInstanceOf[ArrayData] - if (arr == null) { - null - } else { - val f = functionForEval - val buffer = new mutable.ArrayBuffer[Any](arr.numElements) - var i = 0 - while (i < arr.numElements) { - elementVar.value.set(arr.get(i, elementVar.dataType)) - if (f.eval(input).asInstanceOf[Boolean]) { - buffer += elementVar.value.get - } - i += 1 + override def nullSafeEval(inputRow: InternalRow, value: Any): Any = { + val arr = value.asInstanceOf[ArrayData] + val f = functionForEval + val buffer = new mutable.ArrayBuffer[Any](arr.numElements) + var i = 0 + while (i < arr.numElements) { + elementVar.value.set(arr.get(i, elementVar.dataType)) + if (f.eval(inputRow).asInstanceOf[Boolean]) { + buffer += elementVar.value.get } - new GenericArrayData(buffer) + i += 1 } + new GenericArrayData(buffer) } override def prettyName: String = "filter" @@ -334,7 +411,7 @@ case class ArrayAggregate( override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayAggregate = { // Be very conservative with nullable. We cannot be sure that the accumulator does not // evaluate to null. So we always set nullable to true here. - val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType) + val elem = HigherOrderFunction.arrayArgumentType(input.dataType) val acc = zero.dataType -> true val newMerge = f(merge, acc :: elem :: Nil) val newFinish = f(finish, acc :: Nil) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 40cfc0ccc7c07..f7e84b8757916 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -121,6 +121,55 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper Seq("[1, 3, 5]", null, "[4, 6]")) } + test("MapFilter") { + def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val mt = expr.dataType.asInstanceOf[MapType] + MapFilter(expr, createLambda(mt.keyType, false, mt.valueType, mt.valueContainsNull, f)) + } + val mii0 = Literal.create(Map(1 -> 0, 2 -> 10, 3 -> -1), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mii1 = Literal.create(Map(1 -> null, 2 -> 10, 3 -> null), + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val miin = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) + + val kGreaterThanV: (Expression, Expression) => Expression = (k, v) => k > v + + checkEvaluation(mapFilter(mii0, kGreaterThanV), Map(1 -> 0, 3 -> -1)) + checkEvaluation(mapFilter(mii1, kGreaterThanV), Map()) + checkEvaluation(mapFilter(miin, kGreaterThanV), null) + + val valueIsNull: (Expression, Expression) => Expression = (_, v) => v.isNull + + checkEvaluation(mapFilter(mii0, valueIsNull), Map()) + checkEvaluation(mapFilter(mii1, valueIsNull), Map(1 -> null, 3 -> null)) + checkEvaluation(mapFilter(miin, valueIsNull), null) + + val msi0 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> 0), + MapType(StringType, IntegerType, valueContainsNull = false)) + val msi1 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> null), + MapType(StringType, IntegerType, valueContainsNull = true)) + val msin = Literal.create(null, MapType(StringType, IntegerType, valueContainsNull = false)) + + val isLengthOfKey: (Expression, Expression) => Expression = (k, v) => Length(k) === v + + checkEvaluation(mapFilter(msi0, isLengthOfKey), Map("abcdf" -> 5, "" -> 0)) + checkEvaluation(mapFilter(msi1, isLengthOfKey), Map("abcdf" -> 5)) + checkEvaluation(mapFilter(msin, isLengthOfKey), null) + + val mia0 = Literal.create(Map(1 -> Seq(0, 1, 2), 2 -> Seq(10), -3 -> Seq(-1, 0, -2, 3)), + MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = false)) + val mia1 = Literal.create(Map(1 -> Seq(0, 1, 2), 2 -> null, -3 -> Seq(-1, 0, -2, 3)), + MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = true)) + val mian = Literal.create( + null, MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = false)) + + val customFunc: (Expression, Expression) => Expression = (k, v) => Size(v) + k > 3 + + checkEvaluation(mapFilter(mia0, customFunc), Map(1 -> Seq(0, 1, 2))) + checkEvaluation(mapFilter(mia1, customFunc), Map(1 -> Seq(0, 1, 2))) + checkEvaluation(mapFilter(mian, customFunc), null) + } + test("ArrayFilter") { val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index c04780db4e525..24091f2128049 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1854,6 +1854,52 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) } + test("map_filter") { + val dfInts = Seq( + Map(1 -> 10, 2 -> 20, 3 -> 30), + Map(1 -> -1, 2 -> -2, 3 -> -3), + Map(1 -> 10, 2 -> 5, 3 -> -3)).toDF("m") + + checkAnswer(dfInts.selectExpr( + "map_filter(m, (k, v) -> k * 10 = v)", "map_filter(m, (k, v) -> k = -v)"), + Seq( + Row(Map(1 -> 10, 2 -> 20, 3 -> 30), Map()), + Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)), + Row(Map(1 -> 10), Map(3 -> -3)))) + + val dfComplex = Seq( + Map(1 -> Seq(Some(1)), 2 -> Seq(Some(1), Some(2)), 3 -> Seq(Some(1), Some(2), Some(3))), + Map(1 -> null, 2 -> Seq(Some(-2), Some(-2)), 3 -> Seq[Option[Int]](None))).toDF("m") + + checkAnswer(dfComplex.selectExpr( + "map_filter(m, (k, v) -> k = v[0])", "map_filter(m, (k, v) -> k = size(v))"), + Seq( + Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))), + Row(Map(), Map(2 -> Seq(-2, -2))))) + + // Invalid use cases + val df = Seq( + (Map(1 -> "a"), 1), + (Map.empty[Int, String], 2), + (null, 3) + ).toDF("s", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("map_filter(s, (x, y, z) -> x + y + z)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '3' does not match")) + + val ex2 = intercept[AnalysisException] { + df.selectExpr("map_filter(s, x -> x)") + } + assert(ex2.getMessage.contains("The number of lambda function arguments '1' does not match")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("map_filter(i, (k, v) -> k > v)") + } + assert(ex3.getMessage.contains("data type mismatch: argument 1 requires map type")) + } + test("filter function - array for primitive type not containing null") { val df = Seq( Seq(1, 9, 8, 7), From 8c13cb2ae4f82c0eb04939a667310b9268b1c2a7 Mon Sep 17 00:00:00 2001 From: invkrh Date: Tue, 7 Aug 2018 11:04:37 -0700 Subject: [PATCH 1339/2461] [SPARK-25031][SQL] Fix MapType schema print ## What changes were proposed in this pull request? The PR fix the bug in `buildFormattedString` function in `MapType`, which makes the printed schema misleading. ## How was this patch tested? Added UT Closes #22006 from invkrh/fix-map-schema-print. Authored-by: invkrh Signed-off-by: Xiao Li --- .../org/apache/spark/sql/types/MapType.scala | 2 +- .../spark/sql/types/DataTypeSuite.scala | 26 +++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 6691b81dcea8d..594e155268bf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -42,9 +42,9 @@ case class MapType( private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"$prefix-- key: ${keyType.typeName}\n") + DataType.buildFormattedString(keyType, s"$prefix |", builder) builder.append(s"$prefix-- value: ${valueType.typeName} " + s"(valueContainsNull = $valueContainsNull)\n") - DataType.buildFormattedString(keyType, s"$prefix |", builder) DataType.buildFormattedString(valueType, s"$prefix |", builder) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index fccd057e577d4..122a3125ee2c4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -452,4 +452,30 @@ class DataTypeSuite extends SparkFunSuite { new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType, false)), new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)), false) + + test("SPARK-25031: MapType should produce current formatted string for complex types") { + val keyType: DataType = StructType(Seq( + StructField("a", DataTypes.IntegerType), + StructField("b", DataTypes.IntegerType))) + + val valueType: DataType = StructType(Seq( + StructField("c", DataTypes.IntegerType), + StructField("d", DataTypes.IntegerType))) + + val builder = new StringBuilder + + MapType(keyType, valueType).buildFormattedString(prefix = "", builder = builder) + + val result = builder.toString() + val expected = + """-- key: struct + | |-- a: integer (nullable = true) + | |-- b: integer (nullable = true) + |-- value: struct (valueContainsNull = true) + | |-- c: integer (nullable = true) + | |-- d: integer (nullable = true) + |""".stripMargin + + assert(result === expected) + } } From f6356f9bc0ddcb84827c468a3da677995abe033f Mon Sep 17 00:00:00 2001 From: Neal Song Date: Tue, 7 Aug 2018 14:51:41 -0700 Subject: [PATCH 1340/2461] [SPARK-25046][SQL] Fix Alter View can excute sql like "ALTER VIEW ... AS INSERT INTO" ## What changes were proposed in this pull request? Alter View can excute sql like "ALTER VIEW ... AS INSERT INTO" . We should throw ParseException(s"Operation not allowed: $message", ctx) as Create View does. ``` override def visitCreateView(ctx: CreateViewContext): LogicalPlan = withOrigin(ctx) { if (ctx.identifierList != null) { operationNotAllowed("CREATE VIEW ... PARTITIONED ON", ctx) } else { // CREATE VIEW ... AS INSERT INTO is not allowed. ctx.query.queryNoWith match { case s: SingleInsertQueryContext if s.insertInto != null => operationNotAllowed("CREATE VIEW ... AS INSERT INTO", ctx) case _: MultiInsertQueryContext => operationNotAllowed("CREATE VIEW ... AS FROM ... [INSERT INTO ...]+", ctx) case _ => // OK } ``` ``` override def visitAlterViewQuery(ctx: AlterViewQueryContext): LogicalPlan = withOrigin(ctx) { // ALTER VIEW ... AS INSERT INTO is not allowed. ctx.query.queryNoWith match { case s: SingleInsertQueryContext if s.insertInto != null => operationNotAllowed("ALTER VIEW ... AS INSERT INTO", ctx) case _: MultiInsertQueryContext => operationNotAllowed("ALTER VIEW ... AS FROM ... [INSERT INTO ...]+", ctx) case _ => // OK } AlterViewAsCommand( name = visitTableIdentifier(ctx.tableIdentifier), originalText = source(ctx.query), query = plan(ctx.query)) } ``` ## How was this patch tested? UT has been added in SparkSqlParserSuite Closes #22028 from sddyljsx/SPARK-25046. Lead-authored-by: Neal Song Co-authored-by: neal Signed-off-by: Xiao Li --- .../apache/spark/sql/execution/SparkSqlParser.scala | 8 ++++++++ .../spark/sql/execution/SparkSqlParserSuite.scala | 11 +++++++++++ 2 files changed, 19 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 4828fa60a7b58..89cb63784c0f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -1458,6 +1458,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * }}} */ override def visitAlterViewQuery(ctx: AlterViewQueryContext): LogicalPlan = withOrigin(ctx) { + // ALTER VIEW ... AS INSERT INTO is not allowed. + ctx.query.queryNoWith match { + case s: SingleInsertQueryContext if s.insertInto != null => + operationNotAllowed("ALTER VIEW ... AS INSERT INTO", ctx) + case _: MultiInsertQueryContext => + operationNotAllowed("ALTER VIEW ... AS FROM ... [INSERT INTO ...]+", ctx) + case _ => // OK + } AlterViewAsCommand( name = visitTableIdentifier(ctx.tableIdentifier), originalText = source(ctx.query), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 107a2f7109793..28a060aff47b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -366,4 +366,15 @@ class SparkSqlParserSuite extends AnalysisTest { "SELECT a || b || c FROM t", Project(UnresolvedAlias(concat) :: Nil, UnresolvedRelation(TableIdentifier("t")))) } + + test("SPARK-25046 Fix Alter View ... As Insert Into Table") { + // Single insert query + intercept("ALTER VIEW testView AS INSERT INTO jt VALUES(1, 1)", + "Operation not allowed: ALTER VIEW ... AS INSERT INTO") + + // Multi insert query + intercept("ALTER VIEW testView AS FROM jt INSERT INTO tbl1 SELECT * WHERE jt.id < 5 " + + "INSERT INTO tbl2 SELECT * WHERE jt.id > 4", + "Operation not allowed: ALTER VIEW ... AS FROM ... [INSERT INTO ...]+") + } } From 66699c5c3061f54463bd1d0f7a8f8e168c2882c9 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 7 Aug 2018 17:30:37 -0500 Subject: [PATCH 1341/2461] [SPARK-25029][TESTS] Scala 2.12 issues: TaskNotSerializable and Janino "Two non-abstract methods ..." errors ## What changes were proposed in this pull request? Fixes for test issues that arose after Scala 2.12 support was added -- ones that only affect the 2.12 build. ## How was this patch tested? Existing tests. Closes #22004 from srowen/SPARK-25029. Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../org/apache/spark/scheduler/DAGSchedulerSuite.scala | 7 ++++--- .../scala/org/apache/spark/util/AccumulatorV2Suite.scala | 3 ++- .../org/apache/spark/graphx/util/BytecodeUtilsSuite.scala | 7 +++++++ .../scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 2 +- repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala | 2 ++ 5 files changed, 16 insertions(+), 5 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index dad339e2cdb91..8b2b6b6bede02 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -2386,9 +2386,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // Runs a job that encounters a single fetch failure but succeeds on the second attempt def runJobWithTemporaryFetchFailure: Unit = { - object FailThisAttempt { - val _fail = new AtomicBoolean(true) - } val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).map(x => (x, 1)).groupByKey() val shuffleHandle = rdd1.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle @@ -2584,3 +2581,7 @@ object DAGSchedulerSuite { def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) } + +object FailThisAttempt { + val _fail = new AtomicBoolean(true) +} diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala index fe0a9a471a651..94c79388e3639 100644 --- a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala +++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala @@ -165,7 +165,6 @@ class AccumulatorV2Suite extends SparkFunSuite { } test("LegacyAccumulatorWrapper with AccumulatorParam that has no equals/hashCode") { - class MyData(val i: Int) extends Serializable val param = new AccumulatorParam[MyData] { override def zero(initialValue: MyData): MyData = new MyData(0) override def addInPlace(r1: MyData, r2: MyData): MyData = new MyData(r1.i + r2.i) @@ -182,3 +181,5 @@ class AccumulatorV2Suite extends SparkFunSuite { ser.serialize(acc) } } + +class MyData(val i: Int) extends Serializable diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala index 61e44dcab578c..5325978a0a1ec 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.graphx.util import org.apache.spark.SparkFunSuite +import org.apache.spark.util.ClosureCleanerSuite2 // scalastyle:off println @@ -26,6 +27,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { import BytecodeUtilsSuite.TestClass test("closure invokes a method") { + assume(!ClosureCleanerSuite2.supportsLMFs) val c1 = {e: TestClass => println(e.foo); println(e.bar); println(e.baz); } assert(BytecodeUtils.invokedMethod(c1, classOf[TestClass], "foo")) assert(BytecodeUtils.invokedMethod(c1, classOf[TestClass], "bar")) @@ -43,6 +45,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { } test("closure inside a closure invokes a method") { + assume(!ClosureCleanerSuite2.supportsLMFs) val c1 = {e: TestClass => println(e.foo); println(e.bar); println(e.baz); } val c2 = {e: TestClass => c1(e); println(e.foo); } assert(BytecodeUtils.invokedMethod(c2, classOf[TestClass], "foo")) @@ -51,6 +54,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { } test("closure inside a closure inside a closure invokes a method") { + assume(!ClosureCleanerSuite2.supportsLMFs) val c1 = {e: TestClass => println(e.baz); } val c2 = {e: TestClass => c1(e); println(e.foo); } val c3 = {e: TestClass => c2(e) } @@ -60,6 +64,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { } test("closure calling a function that invokes a method") { + assume(!ClosureCleanerSuite2.supportsLMFs) def zoo(e: TestClass) { println(e.baz) } @@ -70,6 +75,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { } test("closure calling a function that invokes a method which uses another closure") { + assume(!ClosureCleanerSuite2.supportsLMFs) val c2 = {e: TestClass => println(e.baz)} def zoo(e: TestClass) { c2(e) @@ -81,6 +87,7 @@ class BytecodeUtilsSuite extends SparkFunSuite { } test("nested closure") { + assume(!ClosureCleanerSuite2.supportsLMFs) val c2 = {e: TestClass => println(e.baz)} def zoo(e: TestClass, c: TestClass => Unit) { c(e) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index bb3f3a015c715..918560a5988eb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -77,7 +77,7 @@ import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} * the heaviest part of the computation. In general, this implementation is bound by either * the cost of statistics computation on workers or by communicating the sufficient statistics. */ -private[spark] object RandomForest extends Logging { +private[spark] object RandomForest extends Logging with Serializable { /** * Train a random forest. diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index cdd5cdd841740..4f3df729177fb 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -21,6 +21,7 @@ import java.io._ import java.net.URLClassLoader import scala.collection.mutable.ArrayBuffer +import scala.tools.nsc.interpreter.SimpleReader import org.apache.log4j.{Level, LogManager} @@ -84,6 +85,7 @@ class ReplSuite extends SparkFunSuite { settings = new scala.tools.nsc.Settings settings.usejavacp.value = true org.apache.spark.repl.Main.interp = this + in = SimpleReader() } val out = new StringWriter() From d90f1336d87199aac56fe227a0fe14ab0ae3a332 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 7 Aug 2018 17:32:41 -0700 Subject: [PATCH 1342/2461] [SPARK-25045][CORE] Make `RDDBarrier.mapParititions` similar to `RDD.mapPartitions` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Signature of the function passed to `RDDBarrier.mapPartitions()` is different from that of `RDD.mapPartitions`. The later doesn’t take a `TaskContext`. We shall make the function signature the same to avoid confusion and misusage. This PR proposes the following API changes: - In `RDDBarrier`, migrate `mapPartitions` from ``` def mapPartitions[S: ClassTag]( f: (Iterator[T], BarrierTaskContext) => Iterator[S], preservesPartitioning: Boolean = false): RDD[S] } ``` to ``` def mapPartitions[S: ClassTag]( f: Iterator[T] => Iterator[S], preservesPartitioning: Boolean = false): RDD[S] } ``` - Add new static method to get a `BarrierTaskContext`: ``` object BarrierTaskContext { def get(): BarrierTaskContext } ``` ## How was this patch tested? Existing test cases. Author: Xingbo Jiang Closes #22026 from jiangxb1987/mapPartitions. --- .../org/apache/spark/BarrierTaskContext.scala | 14 ++++++++++-- .../org/apache/spark/rdd/RDDBarrier.scala | 7 +++--- .../spark/BarrierStageOnSubmittedSuite.scala | 22 +++++++++---------- .../org/apache/spark/SparkContextSuite.scala | 6 +++-- .../apache/spark/rdd/RDDBarrierSuite.scala | 6 ++--- .../scheduler/BarrierTaskContextSuite.scala | 15 ++++++++----- .../spark/scheduler/DAGSchedulerSuite.scala | 4 ++-- 7 files changed, 45 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index 8e2b15599b674..de827987f28f9 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -72,7 +72,8 @@ class BarrierTaskContext( * 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it * shall lead to timeout of the function call. * {{{ - * rdd.barrier().mapPartitions { (iter, context) => + * rdd.barrier().mapPartitions { iter => + * val context = BarrierTaskContext.get() * if (context.partitionId() == 0) { * // Do nothing. * } else { @@ -85,7 +86,8 @@ class BarrierTaskContext( * 2. Include barrier() function in a try-catch code block, this may lead to timeout of the * second function call. * {{{ - * rdd.barrier().mapPartitions { (iter, context) => + * rdd.barrier().mapPartitions { iter => + * val context = BarrierTaskContext.get() * try { * // Do something that might throw an Exception. * doSomething() @@ -152,3 +154,11 @@ class BarrierTaskContext( addressesStr.split(",").map(_.trim()).map(new BarrierTaskInfo(_)) } } + +object BarrierTaskContext { + /** + * Return the currently active BarrierTaskContext. This can be called inside of user functions to + * access contextual information about running barrier tasks. + */ + def get(): BarrierTaskContext = TaskContext.get().asInstanceOf[BarrierTaskContext] +} diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala index 71f38bf6967bc..978e7c004e5e6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala @@ -28,7 +28,7 @@ class RDDBarrier[T: ClassTag](rdd: RDD[T]) { /** * :: Experimental :: - * Maps partitions together with a provided [[org.apache.spark.BarrierTaskContext]]. + * Generate a new barrier RDD by applying a function to each partitions of the prev RDD. * * `preservesPartitioning` indicates whether the input function preserves the partitioner, which * should be `false` unless `rdd` is a pair RDD and the input function doesn't modify the keys. @@ -36,13 +36,12 @@ class RDDBarrier[T: ClassTag](rdd: RDD[T]) { @Experimental @Since("2.4.0") def mapPartitions[S: ClassTag]( - f: (Iterator[T], BarrierTaskContext) => Iterator[S], + f: Iterator[T] => Iterator[S], preservesPartitioning: Boolean = false): RDD[S] = rdd.withScope { val cleanedF = rdd.sparkContext.clean(f) new MapPartitionsRDD( rdd, - (context: TaskContext, index: Int, iter: Iterator[T]) => - cleanedF(iter, context.asInstanceOf[BarrierTaskContext]), + (context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(iter), preservesPartitioning, isFromBarrier = true ) diff --git a/core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala b/core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala index 75e13a9bec105..2f21e61ce9c97 100644 --- a/core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala +++ b/core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala @@ -61,7 +61,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1) val rdd = prunedRdd .barrier() - .mapPartitions((iter, context) => iter) + .mapPartitions(iter => iter) testSubmitJob(sc, rdd, message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) } @@ -71,7 +71,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1) val rdd = prunedRdd .barrier() - .mapPartitions((iter, context) => iter) + .mapPartitions(iter => iter) .repartition(2) .map(x => x + 1) testSubmitJob(sc, rdd, @@ -84,7 +84,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext val rdd = prunedRdd .repartition(2) .barrier() - .mapPartitions((iter, context) => iter) + .mapPartitions(iter => iter) // Should be able to submit job and run successfully. val result = rdd.collect().sorted assert(result === Seq(6, 7, 8, 9, 10)) @@ -94,7 +94,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext sc = createSparkContext() val rdd = sc.parallelize(1 to 10, 4) .barrier() - .mapPartitions((iter, context) => iter) + .mapPartitions(iter => iter) testSubmitJob(sc, rdd, Some(Seq(1, 3)), message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) } @@ -103,7 +103,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext sc = createSparkContext() val rdd1 = sc.parallelize(1 to 10, 2) .barrier() - .mapPartitions((iter, context) => iter) + .mapPartitions(iter => iter) val rdd2 = sc.parallelize(1 to 20, 2) val rdd3 = rdd1 .union(rdd2) @@ -117,7 +117,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext sc = createSparkContext() val rdd = sc.parallelize(1 to 10, 4) .barrier() - .mapPartitions((iter, context) => iter) + .mapPartitions(iter => iter) .coalesce(1) // Fail the job on submit because the barrier RDD requires to run on 4 tasks, but the stage // only launches 1 task. @@ -129,10 +129,10 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext sc = createSparkContext() val rdd1 = sc.parallelize(1 to 10, 4) .barrier() - .mapPartitions((iter, context) => iter) + .mapPartitions(iter => iter) val rdd2 = sc.parallelize(11 to 20, 4) .barrier() - .mapPartitions((iter, context) => iter) + .mapPartitions(iter => iter) val rdd3 = rdd1 .zip(rdd2) .map(x => x._1 + x._2) @@ -144,7 +144,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext sc = createSparkContext() val rdd1 = sc.parallelize(1 to 10, 4) .barrier() - .mapPartitions((iter, context) => iter) + .mapPartitions(iter => iter) val rdd2 = sc.parallelize(11 to 20, 4) val rdd3 = rdd1 .zip(rdd2) @@ -164,7 +164,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext val rdd = sc.parallelize(1 to 10, 4) .barrier() - .mapPartitions((iter, context) => iter) + .mapPartitions(iter => iter) testSubmitJob(sc, rdd, message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION) } @@ -179,7 +179,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext val rdd = sc.parallelize(1 to 10, 4) .barrier() - .mapPartitions((iter, context) => iter) + .mapPartitions(iter => iter) .repartition(2) .map(x => x + 1) testSubmitJob(sc, rdd, diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index e5f31a04c6e7e..cb44110e30135 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -632,7 +632,8 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu val conf = new SparkConf().setAppName("test").setMaster("local[2]") sc = new SparkContext(conf) val rdd = sc.makeRDD(Seq(1, 2, 3, 4), 2) - val rdd2 = rdd.barrier().mapPartitions { (it, context) => + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() // If we don't get the expected taskInfos, the job shall abort due to stage failure. if (context.getTaskInfos().length != 2) { throw new SparkException("Expected taksInfos length is 2, actual length is " + @@ -654,7 +655,8 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu .setAppName("test-cluster") sc = new SparkContext(conf) val rdd = sc.makeRDD(Seq(1, 2, 3, 4), 2) - val rdd2 = rdd.barrier().mapPartitions { (it, context) => + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() // If we don't get the expected taskInfos, the job shall abort due to stage failure. if (context.getTaskInfos().length != 2) { throw new SparkException("Expected taksInfos length is 2, actual length is " + diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala index 39d4618e4c6c5..d57ea4d5501e3 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala @@ -25,19 +25,19 @@ class RDDBarrierSuite extends SparkFunSuite with SharedSparkContext { val rdd = sc.parallelize(1 to 10, 4) assert(rdd.isBarrier() === false) - val rdd2 = rdd.barrier().mapPartitions((iter, context) => iter) + val rdd2 = rdd.barrier().mapPartitions(iter => iter) assert(rdd2.isBarrier() === true) } test("create an RDDBarrier in the middle of a chain of RDDs") { val rdd = sc.parallelize(1 to 10, 4).map(x => x * 2) - val rdd2 = rdd.barrier().mapPartitions((iter, context) => iter).map(x => (x, x + 1)) + val rdd2 = rdd.barrier().mapPartitions(iter => iter).map(x => (x, x + 1)) assert(rdd2.isBarrier() === true) } test("RDDBarrier with shuffle") { val rdd = sc.parallelize(1 to 10, 4) - val rdd2 = rdd.barrier().mapPartitions((iter, context) => iter).repartition(2) + val rdd2 = rdd.barrier().mapPartitions(iter => iter).repartition(2) assert(rdd2.isBarrier() === false) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index 5f96d6fb0cdb6..36dd620a56853 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -31,7 +31,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { .setAppName("test-cluster") sc = new SparkContext(conf) val rdd = sc.makeRDD(1 to 10, 4) - val rdd2 = rdd.barrier().mapPartitions { (it, context) => + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() // Sleep for a random time before global sync. Thread.sleep(Random.nextInt(1000)) context.barrier() @@ -49,7 +50,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { .setAppName("test-cluster") sc = new SparkContext(conf) val rdd = sc.makeRDD(1 to 10, 4) - val rdd2 = rdd.barrier().mapPartitions { (it, context) => + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() // Sleep for a random time before global sync. Thread.sleep(Random.nextInt(1000)) context.barrier() @@ -79,7 +81,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { .setAppName("test-cluster") sc = new SparkContext(conf) val rdd = sc.makeRDD(1 to 10, 4) - val rdd2 = rdd.barrier().mapPartitions { (it, context) => + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() // Task 3 shall sleep 2000ms to ensure barrier() call timeout if (context.taskAttemptId == 3) { Thread.sleep(2000) @@ -103,7 +106,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { .setAppName("test-cluster") sc = new SparkContext(conf) val rdd = sc.makeRDD(1 to 10, 4) - val rdd2 = rdd.barrier().mapPartitions { (it, context) => + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() if (context.taskAttemptId != 0) { context.barrier() } @@ -125,7 +129,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { .setAppName("test-cluster") sc = new SparkContext(conf) val rdd = sc.makeRDD(1 to 10, 4) - val rdd2 = rdd.barrier().mapPartitions { (it, context) => + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() try { if (context.taskAttemptId == 0) { // Due to some non-obvious reason, the code can trigger an Exception and skip the diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 8b2b6b6bede02..211002b2b5caa 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1062,7 +1062,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } test("Retry all the tasks on a resubmitted attempt of a barrier stage caused by FetchFailure") { - val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions((it, context) => it) + val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions(iter => iter) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) @@ -1091,7 +1091,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } test("Retry all the tasks on a resubmitted attempt of a barrier stage caused by TaskKilled") { - val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions((it, context) => it) + val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions(iter => iter) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) From 5fef6e3513d6023a837c427d183006d153c7102b Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 8 Aug 2018 09:55:52 +0800 Subject: [PATCH 1343/2461] [SPARK-24251][SQL] Add AppendData logical plan. ## What changes were proposed in this pull request? This adds a new logical plan, AppendData, that was proposed in SPARK-23521: Standardize SQL logical plans. * DataFrameWriter uses the new AppendData plan for DataSourceV2 appends * AppendData is resolved if its output columns match the incoming data frame * A new analyzer rule, ResolveOutputColumns, validates data before it is appended. This rule will add safe casts, rename columns, and checks nullability ## How was this patch tested? Existing tests for v2 appends. Will add AppendData tests to validate logical plan analysis. Closes #21305 from rdblue/SPARK-24251-add-append-data. Lead-authored-by: Ryan Blue Co-authored-by: Ryan Blue Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/Analyzer.scala | 93 ++++ .../sql/catalyst/analysis/NamedRelation.scala | 24 ++ .../plans/logical/basicLogicalOperators.scala | 36 +- .../org/apache/spark/sql/types/DataType.scala | 123 +++++- .../DataTypeWriteCompatibilitySuite.scala | 404 ++++++++++++++++++ .../spark/sql/sources/v2/WriteSupport.java | 9 +- .../apache/spark/sql/DataFrameWriter.scala | 40 +- .../datasources/v2/DataSourceV2Relation.scala | 49 ++- .../datasources/v2/DataSourceV2Strategy.scala | 5 +- ...V2.scala => WriteToDataSourceV2Exec.scala} | 4 +- .../sql/sources/v2/DataSourceV2Suite.scala | 14 +- 11 files changed, 759 insertions(+), 42 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala rename sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/{WriteToDataSourceV2.scala => WriteToDataSourceV2Exec.scala} (95%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 90c7cf6f082c5..d23d43bef76e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -176,6 +176,7 @@ class Analyzer( ResolveWindowOrder :: ResolveWindowFrame :: ResolveNaturalAndUsingJoin :: + ResolveOutputRelation :: ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: @@ -2227,6 +2228,98 @@ class Analyzer( } } + /** + * Resolves columns of an output table from the data in a logical plan. This rule will: + * + * - Reorder columns when the write is by name + * - Insert safe casts when data types do not match + * - Insert aliases when column names do not match + * - Detect plans that are not compatible with the output table and throw AnalysisException + */ + object ResolveOutputRelation extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + case append @ AppendData(table, query, isByName) + if table.resolved && query.resolved && !append.resolved => + val projection = resolveOutputColumns(table.name, table.output, query, isByName) + + if (projection != query) { + append.copy(query = projection) + } else { + append + } + } + + def resolveOutputColumns( + tableName: String, + expected: Seq[Attribute], + query: LogicalPlan, + byName: Boolean): LogicalPlan = { + + if (expected.size < query.output.size) { + throw new AnalysisException( + s"""Cannot write to '$tableName', too many data columns: + |Table columns: ${expected.map(_.name).mkString(", ")} + |Data columns: ${query.output.map(_.name).mkString(", ")}""".stripMargin) + } + + val errors = new mutable.ArrayBuffer[String]() + val resolved: Seq[NamedExpression] = if (byName) { + expected.flatMap { tableAttr => + query.resolveQuoted(tableAttr.name, resolver) match { + case Some(queryExpr) => + checkField(tableAttr, queryExpr, err => errors += err) + case None => + errors += s"Cannot find data for output column '${tableAttr.name}'" + None + } + } + + } else { + if (expected.size > query.output.size) { + throw new AnalysisException( + s"""Cannot write to '$tableName', not enough data columns: + |Table columns: ${expected.map(_.name).mkString(", ")} + |Data columns: ${query.output.map(_.name).mkString(", ")}""".stripMargin) + } + + query.output.zip(expected).flatMap { + case (queryExpr, tableAttr) => + checkField(tableAttr, queryExpr, err => errors += err) + } + } + + if (errors.nonEmpty) { + throw new AnalysisException( + s"Cannot write incompatible data to table '$tableName':\n- ${errors.mkString("\n- ")}") + } + + Project(resolved, query) + } + + private def checkField( + tableAttr: Attribute, + queryExpr: NamedExpression, + addError: String => Unit): Option[NamedExpression] = { + + if (queryExpr.nullable && !tableAttr.nullable) { + addError(s"Cannot write nullable values to non-null column '${tableAttr.name}'") + None + + } else if (!DataType.canWrite( + tableAttr.dataType, queryExpr.dataType, resolver, tableAttr.name, addError)) { + None + + } else { + // always add an UpCast. it will be removed in the optimizer if it is unnecessary. + Some(Alias( + UpCast(queryExpr, tableAttr.dataType, Seq()), tableAttr.name + )( + explicitMetadata = Option(tableAttr.metadata) + )) + } + } + } + private def commonNaturalJoinProcessing( left: LogicalPlan, right: LogicalPlan, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala new file mode 100644 index 0000000000000..ad201f947b671 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +trait NamedRelation extends LogicalPlan { + def name: String +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 19779b73d6dba..0d31c6f6b9c49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -18,13 +18,12 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.{AliasIdentifier} -import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, - RangePartitioning, RoundRobinPartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.RandomSampler @@ -353,6 +352,37 @@ case class Join( } } +/** + * Append data to an existing table. + */ +case class AppendData( + table: NamedRelation, + query: LogicalPlan, + isByName: Boolean) extends LogicalPlan { + override def children: Seq[LogicalPlan] = Seq(query) + override def output: Seq[Attribute] = Seq.empty + + override lazy val resolved: Boolean = { + query.output.size == table.output.size && query.output.zip(table.output).forall { + case (inAttr, outAttr) => + // names and types must match, nullability must be compatible + inAttr.name == outAttr.name && + DataType.equalsIgnoreCompatibleNullability(outAttr.dataType, inAttr.dataType) && + (outAttr.nullable || !inAttr.nullable) + } + } +} + +object AppendData { + def byName(table: NamedRelation, df: LogicalPlan): AppendData = { + new AppendData(table, df, true) + } + + def byPosition(table: NamedRelation, query: LogicalPlan): AppendData = { + new AppendData(table, query, false) + } +} + /** * Insert some data into a table. Note that this plan is unresolved and has to be replaced by the * concrete implementations during analysis. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index fd40741cfb5f1..50f2a9df522c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -27,7 +27,8 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.InterfaceStability -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils @@ -336,4 +337,124 @@ object DataType { case (fromDataType, toDataType) => fromDataType == toDataType } } + + private val SparkGeneratedName = """col\d+""".r + private def isSparkGeneratedName(name: String): Boolean = name match { + case SparkGeneratedName(_*) => true + case _ => false + } + + /** + * Returns true if the write data type can be read using the read data type. + * + * The write type is compatible with the read type if: + * - Both types are arrays, the array element types are compatible, and element nullability is + * compatible (read allows nulls or write does not contain nulls). + * - Both types are maps and the map key and value types are compatible, and value nullability + * is compatible (read allows nulls or write does not contain nulls). + * - Both types are structs and each field in the read struct is present in the write struct and + * compatible (including nullability), or is nullable if the write struct does not contain the + * field. Write-side structs are not compatible if they contain fields that are not present in + * the read-side struct. + * - Both types are atomic and the write type can be safely cast to the read type. + * + * Extra fields in write-side structs are not allowed to avoid accidentally writing data that + * the read schema will not read, and to ensure map key equality is not changed when data is read. + * + * @param write a write-side data type to validate against the read type + * @param read a read-side data type + * @return true if data written with the write type can be read using the read type + */ + def canWrite( + write: DataType, + read: DataType, + resolver: Resolver, + context: String, + addError: String => Unit = (_: String) => {}): Boolean = { + (write, read) match { + case (wArr: ArrayType, rArr: ArrayType) => + // run compatibility check first to produce all error messages + val typesCompatible = + canWrite(wArr.elementType, rArr.elementType, resolver, context + ".element", addError) + + if (wArr.containsNull && !rArr.containsNull) { + addError(s"Cannot write nullable elements to array of non-nulls: '$context'") + false + } else { + typesCompatible + } + + case (wMap: MapType, rMap: MapType) => + // map keys cannot include data fields not in the read schema without changing equality when + // read. map keys can be missing fields as long as they are nullable in the read schema. + + // run compatibility check first to produce all error messages + val keyCompatible = + canWrite(wMap.keyType, rMap.keyType, resolver, context + ".key", addError) + val valueCompatible = + canWrite(wMap.valueType, rMap.valueType, resolver, context + ".value", addError) + val typesCompatible = keyCompatible && valueCompatible + + if (wMap.valueContainsNull && !rMap.valueContainsNull) { + addError(s"Cannot write nullable values to map of non-nulls: '$context'") + false + } else { + typesCompatible + } + + case (StructType(writeFields), StructType(readFields)) => + var fieldCompatible = true + readFields.zip(writeFields).foreach { + case (rField, wField) => + val namesMatch = resolver(wField.name, rField.name) || isSparkGeneratedName(wField.name) + val fieldContext = s"$context.${rField.name}" + val typesCompatible = + canWrite(wField.dataType, rField.dataType, resolver, fieldContext, addError) + + if (!namesMatch) { + addError(s"Struct '$context' field name does not match (may be out of order): " + + s"expected '${rField.name}', found '${wField.name}'") + fieldCompatible = false + } else if (!rField.nullable && wField.nullable) { + addError(s"Cannot write nullable values to non-null field: '$fieldContext'") + fieldCompatible = false + } else if (!typesCompatible) { + // errors are added in the recursive call to canWrite above + fieldCompatible = false + } + } + + if (readFields.size > writeFields.size) { + val missingFieldsStr = readFields.takeRight(readFields.size - writeFields.size) + .map(f => s"'${f.name}'").mkString(", ") + if (missingFieldsStr.nonEmpty) { + addError(s"Struct '$context' missing fields: $missingFieldsStr") + fieldCompatible = false + } + + } else if (writeFields.size > readFields.size) { + val extraFieldsStr = writeFields.takeRight(writeFields.size - readFields.size) + .map(f => s"'${f.name}'").mkString(", ") + addError(s"Cannot write extra fields to struct '$context': $extraFieldsStr") + fieldCompatible = false + } + + fieldCompatible + + case (w: AtomicType, r: AtomicType) => + if (!Cast.canSafeCast(w, r)) { + addError(s"Cannot safely cast '$context': $w to $r") + false + } else { + true + } + + case (w, r) if w.sameType(r) && !w.isInstanceOf[NullType] => + true + + case (w, r) => + addError(s"Cannot write '$context': $w is incompatible with $r") + false + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala new file mode 100644 index 0000000000000..d92f52f3248aa --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala @@ -0,0 +1,404 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.expressions.Cast + +class DataTypeWriteCompatibilitySuite extends SparkFunSuite { + private val atomicTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, + DoubleType, DateType, TimestampType, StringType, BinaryType) + + private val point2 = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", FloatType, nullable = false))) + + private val widerPoint2 = StructType(Seq( + StructField("x", DoubleType, nullable = false), + StructField("y", DoubleType, nullable = false))) + + private val point3 = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", FloatType, nullable = false), + StructField("z", FloatType))) + + private val simpleContainerTypes = Seq( + ArrayType(LongType), ArrayType(LongType, containsNull = false), MapType(StringType, DoubleType), + MapType(StringType, DoubleType, valueContainsNull = false), point2, point3) + + private val nestedContainerTypes = Seq(ArrayType(point2, containsNull = false), + MapType(StringType, point3, valueContainsNull = false)) + + private val allNonNullTypes = Seq( + atomicTypes, simpleContainerTypes, nestedContainerTypes, Seq(CalendarIntervalType)).flatten + + test("Check NullType is incompatible with all other types") { + allNonNullTypes.foreach { t => + assertSingleError(NullType, t, "nulls", s"Should not allow writing None to type $t") { err => + assert(err.contains(s"incompatible with $t")) + } + } + } + + test("Check each type with itself") { + allNonNullTypes.foreach { t => + assertAllowed(t, t, "t", s"Should allow writing type to itself $t") + } + } + + test("Check atomic types: write allowed only when casting is safe") { + atomicTypes.foreach { w => + atomicTypes.foreach { r => + if (Cast.canSafeCast(w, r)) { + assertAllowed(w, r, "t", s"Should allow writing $w to $r because cast is safe") + + } else { + assertSingleError(w, r, "t", + s"Should not allow writing $w to $r because cast is not safe") { err => + assert(err.contains("'t'"), "Should include the field name context") + assert(err.contains("Cannot safely cast"), "Should identify unsafe cast") + assert(err.contains(s"$w"), "Should include write type") + assert(err.contains(s"$r"), "Should include read type") + } + } + } + } + } + + test("Check struct types: missing required field") { + val missingRequiredField = StructType(Seq(StructField("x", FloatType, nullable = false))) + assertSingleError(missingRequiredField, point2, "t", + "Should fail because required field 'y' is missing") { err => + assert(err.contains("'t'"), "Should include the struct name for context") + assert(err.contains("'y'"), "Should include the nested field name") + assert(err.contains("missing field"), "Should call out field missing") + } + } + + test("Check struct types: missing starting field, matched by position") { + val missingRequiredField = StructType(Seq(StructField("y", FloatType, nullable = false))) + + // should have 2 errors: names x and y don't match, and field y is missing + assertNumErrors(missingRequiredField, point2, "t", + "Should fail because field 'x' is matched to field 'y' and required field 'y' is missing", 2) + { errs => + assert(errs(0).contains("'t'"), "Should include the struct name for context") + assert(errs(0).contains("expected 'x', found 'y'"), "Should detect name mismatch") + assert(errs(0).contains("field name does not match"), "Should identify name problem") + + assert(errs(1).contains("'t'"), "Should include the struct name for context") + assert(errs(1).contains("'y'"), "Should include the _last_ nested fields of the read schema") + assert(errs(1).contains("missing field"), "Should call out field missing") + } + } + + test("Check struct types: missing middle field, matched by position") { + val missingMiddleField = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("z", FloatType, nullable = false))) + + val expectedStruct = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", FloatType, nullable = false), + StructField("z", FloatType, nullable = true))) + + // types are compatible: (req int, req int) => (req int, req int, opt int) + // but this should still fail because the names do not match. + + assertNumErrors(missingMiddleField, expectedStruct, "t", + "Should fail because field 'y' is matched to field 'z'", 2) { errs => + assert(errs(0).contains("'t'"), "Should include the struct name for context") + assert(errs(0).contains("expected 'y', found 'z'"), "Should detect name mismatch") + assert(errs(0).contains("field name does not match"), "Should identify name problem") + + assert(errs(1).contains("'t'"), "Should include the struct name for context") + assert(errs(1).contains("'z'"), "Should include the nested field name") + assert(errs(1).contains("missing field"), "Should call out field missing") + } + } + + test("Check struct types: generic colN names are ignored") { + val missingMiddleField = StructType(Seq( + StructField("col1", FloatType, nullable = false), + StructField("col2", FloatType, nullable = false))) + + val expectedStruct = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", FloatType, nullable = false))) + + // types are compatible: (req int, req int) => (req int, req int) + // names don't match, but match the naming convention used by Spark to fill in names + + assertAllowed(missingMiddleField, expectedStruct, "t", + "Should succeed because column names are ignored") + } + + test("Check struct types: required field is optional") { + val requiredFieldIsOptional = StructType(Seq( + StructField("x", FloatType), + StructField("y", FloatType, nullable = false))) + + assertSingleError(requiredFieldIsOptional, point2, "t", + "Should fail because required field 'x' is optional") { err => + assert(err.contains("'t.x'"), "Should include the nested field name context") + assert(err.contains("Cannot write nullable values to non-null field")) + } + } + + test("Check struct types: data field would be dropped") { + assertSingleError(point3, point2, "t", + "Should fail because field 'z' would be dropped") { err => + assert(err.contains("'t'"), "Should include the struct name for context") + assert(err.contains("'z'"), "Should include the extra field name") + assert(err.contains("Cannot write extra fields")) + } + } + + test("Check struct types: unsafe casts are not allowed") { + assertNumErrors(widerPoint2, point2, "t", + "Should fail because types require unsafe casts", 2) { errs => + + assert(errs(0).contains("'t.x'"), "Should include the nested field name context") + assert(errs(0).contains("Cannot safely cast")) + + assert(errs(1).contains("'t.y'"), "Should include the nested field name context") + assert(errs(1).contains("Cannot safely cast")) + } + } + + test("Check struct types: type promotion is allowed") { + assertAllowed(point2, widerPoint2, "t", + "Should allow widening float fields x and y to double") + } + + ignore("Check struct types: missing optional field is allowed") { + // built-in data sources do not yet support missing fields when optional + assertAllowed(point2, point3, "t", + "Should allow writing point (x,y) to point(x,y,z=null)") + } + + test("Check array types: unsafe casts are not allowed") { + val arrayOfLong = ArrayType(LongType) + val arrayOfInt = ArrayType(IntegerType) + + assertSingleError(arrayOfLong, arrayOfInt, "arr", + "Should not allow array of longs to array of ints") { err => + assert(err.contains("'arr.element'"), + "Should identify problem with named array's element type") + assert(err.contains("Cannot safely cast")) + } + } + + test("Check array types: type promotion is allowed") { + val arrayOfLong = ArrayType(LongType) + val arrayOfInt = ArrayType(IntegerType) + assertAllowed(arrayOfInt, arrayOfLong, "arr", + "Should allow array of int written to array of long column") + } + + test("Check array types: cannot write optional to required elements") { + val arrayOfRequired = ArrayType(LongType, containsNull = false) + val arrayOfOptional = ArrayType(LongType) + + assertSingleError(arrayOfOptional, arrayOfRequired, "arr", + "Should not allow array of optional elements to array of required elements") { err => + assert(err.contains("'arr'"), "Should include type name context") + assert(err.contains("Cannot write nullable elements to array of non-nulls")) + } + } + + test("Check array types: writing required to optional elements is allowed") { + val arrayOfRequired = ArrayType(LongType, containsNull = false) + val arrayOfOptional = ArrayType(LongType) + + assertAllowed(arrayOfRequired, arrayOfOptional, "arr", + "Should allow array of required elements to array of optional elements") + } + + test("Check map value types: unsafe casts are not allowed") { + val mapOfLong = MapType(StringType, LongType) + val mapOfInt = MapType(StringType, IntegerType) + + assertSingleError(mapOfLong, mapOfInt, "m", + "Should not allow map of longs to map of ints") { err => + assert(err.contains("'m.value'"), "Should identify problem with named map's value type") + assert(err.contains("Cannot safely cast")) + } + } + + test("Check map value types: type promotion is allowed") { + val mapOfLong = MapType(StringType, LongType) + val mapOfInt = MapType(StringType, IntegerType) + + assertAllowed(mapOfInt, mapOfLong, "m", "Should allow map of int written to map of long column") + } + + test("Check map value types: cannot write optional to required values") { + val mapOfRequired = MapType(StringType, LongType, valueContainsNull = false) + val mapOfOptional = MapType(StringType, LongType) + + assertSingleError(mapOfOptional, mapOfRequired, "m", + "Should not allow map of optional values to map of required values") { err => + assert(err.contains("'m'"), "Should include type name context") + assert(err.contains("Cannot write nullable values to map of non-nulls")) + } + } + + test("Check map value types: writing required to optional values is allowed") { + val mapOfRequired = MapType(StringType, LongType, valueContainsNull = false) + val mapOfOptional = MapType(StringType, LongType) + + assertAllowed(mapOfRequired, mapOfOptional, "m", + "Should allow map of required elements to map of optional elements") + } + + test("Check map key types: unsafe casts are not allowed") { + val mapKeyLong = MapType(LongType, StringType) + val mapKeyInt = MapType(IntegerType, StringType) + + assertSingleError(mapKeyLong, mapKeyInt, "m", + "Should not allow map of long keys to map of int keys") { err => + assert(err.contains("'m.key'"), "Should identify problem with named map's key type") + assert(err.contains("Cannot safely cast")) + } + } + + test("Check map key types: type promotion is allowed") { + val mapKeyLong = MapType(LongType, StringType) + val mapKeyInt = MapType(IntegerType, StringType) + + assertAllowed(mapKeyInt, mapKeyLong, "m", + "Should allow map of int written to map of long column") + } + + test("Check types with multiple errors") { + val readType = StructType(Seq( + StructField("a", ArrayType(DoubleType, containsNull = false)), + StructField("arr_of_structs", ArrayType(point2, containsNull = false)), + StructField("bad_nested_type", ArrayType(StringType)), + StructField("m", MapType(LongType, FloatType, valueContainsNull = false)), + StructField("map_of_structs", MapType(StringType, point3, valueContainsNull = false)), + StructField("x", IntegerType, nullable = false), + StructField("missing1", StringType, nullable = false), + StructField("missing2", StringType) + )) + + val missingMiddleField = StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("z", FloatType, nullable = false))) + + val writeType = StructType(Seq( + StructField("a", ArrayType(StringType)), + StructField("arr_of_structs", ArrayType(point3)), + StructField("bad_nested_type", point3), + StructField("m", MapType(DoubleType, DoubleType)), + StructField("map_of_structs", MapType(StringType, missingMiddleField)), + StructField("y", LongType) + )) + + assertNumErrors(writeType, readType, "top", "Should catch 14 errors", 14) { errs => + assert(errs(0).contains("'top.a.element'"), "Should identify bad type") + assert(errs(0).contains("Cannot safely cast")) + assert(errs(0).contains("StringType to DoubleType")) + + assert(errs(1).contains("'top.a'"), "Should identify bad type") + assert(errs(1).contains("Cannot write nullable elements to array of non-nulls")) + + assert(errs(2).contains("'top.arr_of_structs.element'"), "Should identify bad type") + assert(errs(2).contains("'z'"), "Should identify bad field") + assert(errs(2).contains("Cannot write extra fields to struct")) + + assert(errs(3).contains("'top.arr_of_structs'"), "Should identify bad type") + assert(errs(3).contains("Cannot write nullable elements to array of non-nulls")) + + assert(errs(4).contains("'top.bad_nested_type'"), "Should identify bad type") + assert(errs(4).contains("is incompatible with")) + + assert(errs(5).contains("'top.m.key'"), "Should identify bad type") + assert(errs(5).contains("Cannot safely cast")) + assert(errs(5).contains("DoubleType to LongType")) + + assert(errs(6).contains("'top.m.value'"), "Should identify bad type") + assert(errs(6).contains("Cannot safely cast")) + assert(errs(6).contains("DoubleType to FloatType")) + + assert(errs(7).contains("'top.m'"), "Should identify bad type") + assert(errs(7).contains("Cannot write nullable values to map of non-nulls")) + + assert(errs(8).contains("'top.map_of_structs.value'"), "Should identify bad type") + assert(errs(8).contains("expected 'y', found 'z'"), "Should detect name mismatch") + assert(errs(8).contains("field name does not match"), "Should identify name problem") + + assert(errs(9).contains("'top.map_of_structs.value'"), "Should identify bad type") + assert(errs(9).contains("'z'"), "Should identify missing field") + assert(errs(9).contains("missing fields"), "Should detect missing field") + + assert(errs(10).contains("'top.map_of_structs'"), "Should identify bad type") + assert(errs(10).contains("Cannot write nullable values to map of non-nulls")) + + assert(errs(11).contains("'top.x'"), "Should identify bad type") + assert(errs(11).contains("Cannot safely cast")) + assert(errs(11).contains("LongType to IntegerType")) + + assert(errs(12).contains("'top'"), "Should identify bad type") + assert(errs(12).contains("expected 'x', found 'y'"), "Should detect name mismatch") + assert(errs(12).contains("field name does not match"), "Should identify name problem") + + assert(errs(13).contains("'top'"), "Should identify bad type") + assert(errs(13).contains("'missing1'"), "Should identify missing field") + assert(errs(13).contains("missing fields"), "Should detect missing field") + } + } + + // Helper functions + + def assertAllowed(writeType: DataType, readType: DataType, name: String, desc: String): Unit = { + assert( + DataType.canWrite(writeType, readType, analysis.caseSensitiveResolution, name, + errMsg => fail(s"Should not produce errors but was called with: $errMsg")) === true, desc) + } + + def assertSingleError( + writeType: DataType, + readType: DataType, + name: String, + desc: String) + (errFunc: String => Unit): Unit = { + assertNumErrors(writeType, readType, name, desc, 1) { errs => + errFunc(errs.head) + } + } + + def assertNumErrors( + writeType: DataType, + readType: DataType, + name: String, + desc: String, + numErrs: Int) + (errFunc: Seq[String] => Unit): Unit = { + val errs = new mutable.ArrayBuffer[String]() + assert( + DataType.canWrite(writeType, readType, analysis.caseSensitiveResolution, name, + errMsg => errs += errMsg) === false, desc) + assert(errs.size === numErrs, s"Should produce $numErrs error messages") + errFunc(errs) + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java index 83aeec0c47853..048787a7a0a05 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java @@ -38,15 +38,16 @@ public interface WriteSupport extends DataSourceV2 { * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. * - * @param jobId A unique string for the writing job. It's possible that there are many writing - * jobs running at the same time, and the returned {@link DataSourceWriter} can - * use this job id to distinguish itself from other jobs. + * @param writeUUID A unique string for the writing job. It's possible that there are many writing + * jobs running at the same time, and the returned {@link DataSourceWriter} can + * use this job id to distinguish itself from other jobs. * @param schema the schema of the data to be written. * @param mode the save mode which determines what to do when the data are already in this data * source, please refer to {@link SaveMode} for more details. * @param options the options for the returned data source writer, which is an immutable * case-insensitive string-to-string map. + * @return a writer to append data to this data source */ Optional createWriter( - String jobId, StructType schema, SaveMode mode, DataSourceOptions options); + String writeUUID, StructType schema, SaveMode mode, DataSourceOptions options); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index cd7dc2a2727e4..db2a1e7426197 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql -import java.text.SimpleDateFormat -import java.util.{Date, Locale, Properties, UUID} +import java.util.{Locale, Properties, UUID} import scala.collection.JavaConverters._ @@ -26,12 +25,11 @@ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2 +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, WriteToDataSourceV2} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.types.StructType @@ -240,21 +238,27 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance() - ds match { + val source = cls.newInstance().asInstanceOf[DataSourceV2] + source match { case ws: WriteSupport => - val options = new DataSourceOptions((extraOptions ++ - DataSourceV2Utils.extractSessionConfigs( - ds = ds.asInstanceOf[DataSourceV2], - conf = df.sparkSession.sessionState.conf)).asJava) - // Using a timestamp and a random UUID to distinguish different writing jobs. This is good - // enough as there won't be tons of writing jobs created at the same second. - val jobId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) - .format(new Date()) + "-" + UUID.randomUUID() - val writer = ws.createWriter(jobId, df.logicalPlan.schema, mode, options) - if (writer.isPresent) { + val options = extraOptions ++ + DataSourceV2Utils.extractSessionConfigs(source, df.sparkSession.sessionState.conf) + + val relation = DataSourceV2Relation.create(source, options.toMap) + if (mode == SaveMode.Append) { runCommand(df.sparkSession, "save") { - WriteToDataSourceV2(writer.get(), df.logicalPlan) + AppendData.byName(relation, df.logicalPlan) + } + + } else { + val writer = ws.createWriter( + UUID.randomUUID.toString, df.logicalPlan.output.toStructType, mode, + new DataSourceOptions(options.asJava)) + + if (writer.isPresent) { + runCommand(df.sparkSession, "save") { + WriteToDataSourceV2(writer.get, df.logicalPlan) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 46166928f449d..a4bfc861cc9a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -17,15 +17,19 @@ package org.apache.spark.sql.execution.datasources.v2 +import java.util.UUID + import scala.collection.JavaConverters._ -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, WriteSupport} import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter import org.apache.spark.sql.types.StructType /** @@ -40,17 +44,24 @@ case class DataSourceV2Relation( source: DataSourceV2, output: Seq[AttributeReference], options: Map[String, String], - userSpecifiedSchema: Option[StructType]) - extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { + tableIdent: Option[TableIdentifier] = None, + userSpecifiedSchema: Option[StructType] = None) + extends LeafNode with MultiInstanceRelation with NamedRelation with DataSourceV2StringFormat { import DataSourceV2Relation._ + override def name: String = { + tableIdent.map(_.unquotedString).getOrElse(s"${source.name}:unknown") + } + override def pushedFilters: Seq[Expression] = Seq.empty override def simpleString: String = "RelationV2 " + metadataString def newReader(): DataSourceReader = source.createReader(options, userSpecifiedSchema) + def newWriter(): DataSourceWriter = source.createWriter(options, schema) + override def computeStats(): Statistics = newReader match { case r: SupportsReportStatistics => Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) @@ -115,6 +126,15 @@ object DataSourceV2Relation { } } + def asWriteSupport: WriteSupport = { + source match { + case support: WriteSupport => + support + case _ => + throw new AnalysisException(s"Data source is not writable: $name") + } + } + def name: String = { source match { case registered: DataSourceRegister => @@ -135,14 +155,29 @@ object DataSourceV2Relation { asReadSupport.createReader(v2Options) } } + + def createWriter( + options: Map[String, String], + schema: StructType): DataSourceWriter = { + val v2Options = new DataSourceOptions(options.asJava) + asWriteSupport.createWriter(UUID.randomUUID.toString, schema, SaveMode.Append, v2Options).get + } } def create( source: DataSourceV2, options: Map[String, String], - userSpecifiedSchema: Option[StructType]): DataSourceV2Relation = { + tableIdent: Option[TableIdentifier] = None, + userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { val reader = source.createReader(options, userSpecifiedSchema) + val ident = tableIdent.orElse(tableFromOptions(options)) DataSourceV2Relation( - source, reader.readSchema().toAttributes, options, userSpecifiedSchema) + source, reader.readSchema().toAttributes, options, ident, userSpecifiedSchema) + } + + private def tableFromOptions(options: Map[String, String]): Option[TableIdentifier] = { + options + .get(DataSourceOptions.TABLE_KEY) + .map(TableIdentifier(_, options.get(DataSourceOptions.DATABASE_KEY))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 9414e68155b98..6daaa4c65c335 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.sql.{sources, Strategy} import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Repartition} import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} @@ -136,6 +136,9 @@ object DataSourceV2Strategy extends Strategy { case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil + case AppendData(r: DataSourceV2Relation, query, _) => + WriteToDataSourceV2Exec(r.newWriter(), planLater(query)) :: Nil + case WriteToContinuousDataSource(writer, query) => WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 0399970495bec..59ebb9bc5431b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -35,8 +35,10 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils /** - * The logical plan for writing data into data source v2. + * Deprecated logical plan for writing data into data source v2. This is being replaced by more + * specific logical plans, like [[org.apache.spark.sql.catalyst.plans.logical.AppendData]]. */ +@deprecated("Use specific logical plans like AppendData instead", "2.4.0") case class WriteToDataSourceV2(writer: DataSourceWriter, query: LogicalPlan) extends LogicalPlan { override def children: Seq[LogicalPlan] = Seq(query) override def output: Seq[Attribute] = Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index fef53e6f7b6fa..aa5f723365d5f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -192,33 +192,33 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val path = file.getCanonicalPath assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) - spark.range(10).select('id, -'id).write.format(cls.getName) + spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), spark.range(10).select('id, -'id)) // test with different save modes - spark.range(10).select('id, -'id).write.format(cls.getName) + spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).mode("append").save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), spark.range(10).union(spark.range(10)).select('id, -'id)) - spark.range(5).select('id, -'id).write.format(cls.getName) + spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).mode("overwrite").save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), spark.range(5).select('id, -'id)) - spark.range(5).select('id, -'id).write.format(cls.getName) + spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).mode("ignore").save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), spark.range(5).select('id, -'id)) val e = intercept[Exception] { - spark.range(5).select('id, -'id).write.format(cls.getName) + spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).mode("error").save() } assert(e.getMessage.contains("data already exists")) @@ -235,7 +235,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } // this input data will fail to read middle way. - val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i) + val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i as 'j) val e2 = intercept[SparkException] { input.write.format(cls.getName).option("path", path).mode("overwrite").save() } @@ -253,7 +253,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) val numPartition = 6 - spark.range(0, 10, 1, numPartition).select('id, -'id).write.format(cls.getName) + spark.range(0, 10, 1, numPartition).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), From c7a229d655a0ac0658d9f40d388ad123f4ed8ce2 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 8 Aug 2018 11:05:52 +0800 Subject: [PATCH 1344/2461] [SPARK-25010][SQL][FOLLOWUP] Shuffle should also produce different values for each execution in streaming query. ## What changes were proposed in this pull request? This is a follow-up pr of #21980. `Shuffle` can also be `ExpressionWithRandomSeed` to produce different values for each execution in streaming query. ## How was this patch tested? Added a test. Closes #22027 from ueshin/issues/SPARK-25010/random_seed. Authored-by: Takuya UESHIN Signed-off-by: Wenchen Fan --- .../expressions/collectionOperations.scala | 4 +++- .../sql/streaming/StreamingQuerySuite.scala | 21 ++++++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index fbb182631eefa..ab06a5a544ade 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1218,10 +1218,12 @@ case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLi note = "The function is non-deterministic.", since = "2.4.0") case class Shuffle(child: Expression, randomSeed: Option[Long] = None) - extends UnaryExpression with ExpectsInputTypes with Stateful { + extends UnaryExpression with ExpectsInputTypes with Stateful with ExpressionWithRandomSeed { def this(child: Expression) = this(child, None) + override def withNewSeed(seed: Long): Shuffle = copy(randomSeed = Some(seed)) + override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess && randomSeed.isDefined diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 848924dde296e..268ed58315fdd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -36,7 +36,7 @@ import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Rand, Randn, Uuid} +import org.apache.spark.sql.catalyst.expressions.{Literal, Rand, Randn, Shuffle, Uuid} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ @@ -905,6 +905,25 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(rands.distinct.size == 4) } + test("Shuffle in streaming query should not produce same results in each execution") { + val rands = mutable.ArrayBuffer[Seq[Int]]() + def collectShuffle: Seq[Row] => Unit = { rows: Seq[Row] => + rows.foreach { r => + rands += r.getSeq[Int](0) + } + } + + val stream = MemoryStream[Int] + val df = stream.toDF().select(new Column(new Shuffle(Literal.create[Seq[Int]](0 until 100)))) + testStream(df)( + AddData(stream, 1), + CheckAnswer(collectShuffle), + AddData(stream, 2), + CheckAnswer(collectShuffle) + ) + assert(rands.distinct.size == 2) + } + test("StreamingRelationV2/StreamingExecutionRelation/ContinuousExecutionRelation.toJSON " + "should not fail") { val df = spark.readStream.format("rate").load() From f08f6f4314b16fb09c479f6537f99bda77e4c256 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 8 Aug 2018 14:38:55 +0900 Subject: [PATCH 1345/2461] [SPARK-23935][SQL][FOLLOWUP] mapEntry throws org.codehaus.commons.compiler.CompileException ## What changes were proposed in this pull request? This PR fixes an exception during the compilation of generated code of `mapEntry`. This error occurs since the current code uses `key` type to store a `value` when `key` and `value` types are primitive type. ``` val mid0 = Literal.create(Map(1 -> 1.1, 2 -> 2.2), MapType(IntegerType, DoubleType)) checkEvaluation(MapEntries(mid0), Seq(r(1, 1.1), r(2, 2.2))) ``` ``` [info] Code generation of map_entries(keys: [1,2], values: [1.1,2.2]) failed: [info] java.util.concurrent.ExecutionException: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 80, Column 20: failed to compile: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 80, Column 20: No applicable constructor/method found for actual parameters "int, double"; candidates are: "public void org.apache.spark.sql.catalyst.expressions.UnsafeRow.setInt(int, int)", "public void org.apache.spark.sql.catalyst.InternalRow.setInt(int, int)" [info] java.util.concurrent.ExecutionException: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 80, Column 20: failed to compile: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 80, Column 20: No applicable constructor/method found for actual parameters "int, double"; candidates are: "public void org.apache.spark.sql.catalyst.expressions.UnsafeRow.setInt(int, int)", "public void org.apache.spark.sql.catalyst.InternalRow.setInt(int, int)" [info] at com.google.common.util.concurrent.AbstractFuture$Sync.getValue(AbstractFuture.java:306) [info] at com.google.common.util.concurrent.AbstractFuture$Sync.get(AbstractFuture.java:293) [info] at com.google.common.util.concurrent.AbstractFuture.get(AbstractFuture.java:116) [info] at com.google.common.util.concurrent.Uninterruptibles.getUninterruptibly(Uninterruptibles.java:135) [info] at com.google.common.cache.LocalCache$Segment.getAndRecordStats(LocalCache.java:2410) [info] at com.google.common.cache.LocalCache$Segment.loadSync(LocalCache.java:2380) [info] at com.google.common.cache.LocalCache$Segment.lockedGetOrLoad(LocalCache.java:2342) [info] at com.google.common.cache.LocalCache$Segment.get(LocalCache.java:2257) [info] at com.google.common.cache.LocalCache.get(LocalCache.java:4000) [info] at com.google.common.cache.LocalCache.getOrLoad(LocalCache.java:4004) [info] at com.google.common.cache.LocalCache$LocalLoadingCache.get(LocalCache.java:4874) [info] at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$.compile(CodeGenerator.scala:1290) ... ``` ## How was this patch tested? Added a new test to `CollectionExpressionsSuite` Closes #22033 from kiszk/SPARK-23935-followup. Authored-by: Kazuaki Ishizaki Signed-off-by: Takuya UESHIN --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 2 +- .../sql/catalyst/expressions/CollectionExpressionsSuite.scala | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index ab06a5a544ade..b37fdc6d10fd1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -426,7 +426,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2 val structSizeAsLong = structSize + "L" val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) - val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) + val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.valueType) val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});" val valueAssignmentChecked = if (childDataType.valueContainsNull) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 40487b3fd001c..7b345aabd19c8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -90,10 +90,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val mi0 = Literal.create(Map(1 -> 1, 2 -> null, 3 -> 2), MapType(IntegerType, IntegerType)) val mi1 = Literal.create(Map[Int, Int](), MapType(IntegerType, IntegerType)) val mi2 = Literal.create(null, MapType(IntegerType, IntegerType)) + val mid0 = Literal.create(Map(1 -> 1.1, 2 -> 2.2), MapType(IntegerType, DoubleType)) checkEvaluation(MapEntries(mi0), Seq(r(1, 1), r(2, null), r(3, 2))) checkEvaluation(MapEntries(mi1), Seq.empty) checkEvaluation(MapEntries(mi2), null) + checkEvaluation(MapEntries(mid0), Seq(r(1, 1.1), r(2, 2.2))) // Non-primitive-type keys/values val ms0 = Literal.create(Map("a" -> "c", "b" -> null), MapType(StringType, StringType)) From 960af63913613ed7104cd76e477e325bd3020163 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 8 Aug 2018 14:46:00 +0800 Subject: [PATCH 1346/2461] [SPARK-25036][SQL] avoid match may not be exhaustive in Scala-2.12 ## What changes were proposed in this pull request? The PR remove the following compilation error using scala-2.12 with sbt by adding a default case to `match`. ``` /home/ishizaki/Spark/PR/scala212/spark/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala:63: match may not be exhaustive. [error] It would fail on the following inputs: (NumericValueInterval(_, _), _), (_, NumericValueInterval(_, _)), (_, _) [error] [warn] def isIntersected(r1: ValueInterval, r2: ValueInterval): Boolean = (r1, r2) match { [error] [warn] [error] [warn] /home/ishizaki/Spark/PR/scala212/spark/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala:79: match may not be exhaustive. [error] It would fail on the following inputs: (NumericValueInterval(_, _), _), (_, NumericValueInterval(_, _)), (_, _) [error] [warn] (r1, r2) match { [error] [warn] [error] [warn] /home/ishizaki/Spark/PR/scala212/spark/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala:67: match may not be exhaustive. [error] It would fail on the following inputs: (ArrayType(_, _), _), (_, ArrayData()), (_, _) [error] [warn] (endpointsExpression.dataType, endpointsExpression.eval()) match { [error] [warn] [error] [warn] /home/ishizaki/Spark/PR/scala212/spark/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala:470: match may not be exhaustive. [error] It would fail on the following inputs: NewFunctionSpec(_, None, Some(_)), NewFunctionSpec(_, Some(_), None) [error] [warn] newFunction match { [error] [warn] [error] [warn] [error] [warn] /home/ishizaki/Spark/PR/scala212/spark/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala:709: match may not be exhaustive. [error] It would fail on the following input: Schema((x: org.apache.spark.sql.types.DataType forSome x not in org.apache.spark.sql.types.StructType), _) [error] [warn] def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { [error] [warn] ``` ## How was this patch tested? Existing UTs with Scala-2.11. Manually build with Scala-2.12 Closes #22014 from kiszk/SPARK-25036b. Authored-by: Kazuaki Ishizaki Signed-off-by: hyukjinkwon --- .../apache/spark/sql/catalyst/ScalaReflection.scala | 2 ++ .../aggregate/ApproxCountDistinctForIntervals.scala | 10 +++++----- .../catalyst/expressions/codegen/CodeGenerator.scala | 2 ++ .../plans/logical/statsEstimation/ValueInterval.scala | 4 ++++ 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 4543bba8f6ed4..191c3de965b34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -709,6 +709,8 @@ object ScalaReflection extends ScalaReflection { def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { case Schema(s: StructType, _) => s.toAttributes + case others => + throw new UnsupportedOperationException(s"Attributes for type $others is not supported") } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala index d4421ca20a9bd..f96a087972f1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala @@ -63,11 +63,11 @@ case class ApproxCountDistinctForIntervals( } // Mark as lazy so that endpointsExpression is not evaluated during tree transformation. - lazy val endpoints: Array[Double] = - (endpointsExpression.dataType, endpointsExpression.eval()) match { - case (ArrayType(elementType, _), arrayData: ArrayData) => - arrayData.toObjectArray(elementType).map(_.toString.toDouble) - } + lazy val endpoints: Array[Double] = { + val endpointsType = endpointsExpression.dataType.asInstanceOf[ArrayType] + val endpoints = endpointsExpression.eval().asInstanceOf[ArrayData] + endpoints.toObjectArray(endpointsType.elementType).map(_.toString.toDouble) + } override def checkInputDataTypes(): TypeCheckResult = { val defaultCheck = super.checkInputDataTypes() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 498dd2639f423..4b30de5aeb7ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -471,6 +471,8 @@ class CodegenContext { case NewFunctionSpec(functionName, None, None) => functionName case NewFunctionSpec(functionName, Some(_), Some(innerClassInstance)) => innerClassInstance + "." + functionName + case _ => + throw new IllegalArgumentException(s"$funcName is not matched at addNewFunction") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala index f46b4ed764e27..693d2a7210ab8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala @@ -69,6 +69,8 @@ object ValueInterval { false case (n1: NumericValueInterval, n2: NumericValueInterval) => n1.min.compareTo(n2.max) <= 0 && n1.max.compareTo(n2.min) >= 0 + case _ => + throw new UnsupportedOperationException(s"Not supported pair: $r1, $r2 at isIntersected()") } /** @@ -86,6 +88,8 @@ object ValueInterval { val newMax = if (n1.max <= n2.max) n1.max else n2.max (Some(EstimationUtils.fromDouble(newMin, dt)), Some(EstimationUtils.fromDouble(newMax, dt))) + case _ => + throw new UnsupportedOperationException(s"Not supported pair: $r1, $r2 at intersect()") } } } From 6f6a4200783ec4d9041421c1b5fc59d4c9a58adb Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 9 Aug 2018 00:01:03 +0900 Subject: [PATCH 1347/2461] [SPARK-23911][SQL][FOLLOW-UP] Fix examples of aggregate function. ## What changes were proposed in this pull request? This pr is a follow-up pr of #21982 and fixes the examples. ## How was this patch tested? Existing tests. Closes #22035 from ueshin/issues/SPARK-23911/fup1. Authored-by: Takuya UESHIN Signed-off-by: Takuya UESHIN --- .../spark/sql/catalyst/expressions/higherOrderFunctions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 0a473d2885079..d20673359129b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -368,9 +368,9 @@ case class ArrayFilter( """, examples = """ Examples: - > SELECT _FUNC_(array(1, 2, 3), (acc, x) -> acc + x); + > SELECT _FUNC_(array(1, 2, 3), 0, (acc, x) -> acc + x); 6 - > SELECT _FUNC_(array(1, 2, 3), (acc, x) -> acc + x, acc -> acc * 10); + > SELECT _FUNC_(array(1, 2, 3), 0, (acc, x) -> acc + x, acc -> acc * 10); 60 """, since = "2.4.0") From f62fe435de2228ad6c3a857c214e782b8f308516 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 8 Aug 2018 16:47:22 -0500 Subject: [PATCH 1348/2461] [SPARK-25036][SQL][FOLLOW-UP] Avoid match may not be exhaustive in Scala-2.12. ## What changes were proposed in this pull request? This is a follow-up pr of #22014. We still have some more compilation errors in scala-2.12 with sbt: ``` [error] [warn] /.../sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala:493: match may not be exhaustive. [error] It would fail on the following input: (_, _) [error] [warn] val typeMatches = (targetType, f.dataType) match { [error] [warn] [error] [warn] /.../sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:393: match may not be exhaustive. [error] It would fail on the following input: (_, _) [error] [warn] prevBatchOff.get.toStreamProgress(sources).foreach { [error] [warn] [error] [warn] /.../sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala:173: match may not be exhaustive. [error] It would fail on the following input: AggregateExpression(_, _, false, _) [error] [warn] val rewrittenDistinctFunctions = functionsWithDistinct.map { [error] [warn] [error] [warn] /.../sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala:271: match may not be exhaustive. [error] It would fail on the following input: (_, _) [error] [warn] keyWithIndexToValueMetrics.customMetrics.map { [error] [warn] [error] [warn] /.../sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala:959: match may not be exhaustive. [error] It would fail on the following input: CatalogTableType(_) [error] [warn] val tableTypeString = metadata.tableType match { [error] [warn] [error] [warn] /.../sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala:923: match may not be exhaustive. [error] It would fail on the following input: CatalogTableType(_) [error] [warn] hiveTable.setTableType(table.tableType match { [error] [warn] ``` ## How was this patch tested? Manually build with Scala-2.12. Closes #22039 from ueshin/issues/SPARK-25036/fix_match. Authored-by: Takuya UESHIN Signed-off-by: Sean Owen --- .../scala/org/apache/spark/sql/DataFrameNaFunctions.scala | 2 ++ .../org/apache/spark/sql/execution/aggregate/AggUtils.scala | 4 ++++ .../scala/org/apache/spark/sql/execution/command/tables.scala | 3 +++ .../spark/sql/execution/streaming/MicroBatchExecution.scala | 3 +++ .../streaming/state/SymmetricHashJoinStateManager.scala | 3 +++ .../org/apache/spark/sql/hive/client/HiveClientImpl.scala | 3 +++ 6 files changed, 18 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index f3a2b70657c48..5288907b7d7ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -494,6 +494,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { case (NumericType, dt) => dt.isInstanceOf[NumericType] case (StringType, dt) => dt == StringType case (BooleanType, dt) => dt == BooleanType + case _ => + throw new IllegalArgumentException(s"$targetType is not matched at fillValue") } // Only fill if the column is part of the cols list. if (typeMatches && cols.exists(col => columnEquals(f.name, col))) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index ebbdf1aaa024d..c8ef2b3f6998d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -177,6 +177,10 @@ object AggUtils { case agg @ AggregateExpression(aggregateFunction, mode, true, _) => aggregateFunction.transformDown(distinctColumnAttributeLookup) .asInstanceOf[AggregateFunction] + case agg => + throw new IllegalArgumentException( + "Non-distinct aggregate is found in functionsWithDistinct " + + s"at planAggregateWithOneDistinct: $agg") } val partialDistinctAggregate: SparkPlan = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 56f48b7dc00ee..f4dede9fcc899 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -960,6 +960,9 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman case EXTERNAL => " EXTERNAL TABLE" case VIEW => " VIEW" case MANAGED => " TABLE" + case t => + throw new IllegalArgumentException( + s"Unknown table type is found at showCreateHiveTable: $t") } builder ++= s"CREATE$tableTypeString ${table.quotedString}" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index c759f5be8ba35..b1c91ac94b268 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -394,6 +394,9 @@ class MicroBatchExecution( case (src: Source, off) => src.commit(off) case (reader: MicroBatchReader, off) => reader.commit(reader.deserializeOffset(off.json)) + case (src, _) => + throw new IllegalArgumentException( + s"Unknown source is found at constructNextBatch: $src") } } else { throw new IllegalStateException(s"batch ${currentBatchId - 1} doesn't exist") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 55d783e023246..6e7cd2db213d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -273,6 +273,9 @@ class SymmetricHashJoinStateManager( s.copy(desc = newDesc(desc)) -> value case (s @ StateStoreCustomTimingMetric(_, desc), value) => s.copy(desc = newDesc(desc)) -> value + case (s, _) => + throw new IllegalArgumentException( + s"Unknown state store custom metric is found at metrics: $s") } ) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index db8fd5a43d842..02c1ed93eb2f8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -927,6 +927,9 @@ private[hive] object HiveClientImpl { case CatalogTableType.MANAGED => HiveTableType.MANAGED_TABLE case CatalogTableType.VIEW => HiveTableType.VIRTUAL_VIEW + case t => + throw new IllegalArgumentException( + s"Unknown table type is found at toHiveTable: $t") }) // Note: In Hive the schema and partition columns must be disjoint sets val (partCols, schema) = table.schema.map(toHiveColumn).partition { c => From a40806d2bd84e9a0308165f0d6c97e9cf00aa4a3 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 9 Aug 2018 12:07:57 +0800 Subject: [PATCH 1349/2461] [SPARK-23596][SQL] Test interpreted path on encoders test suites ## What changes were proposed in this pull request? We have completed a significant subset of the object related Expressions to provide an interpreted fallback. This PR is going to modify the tests to also test the interpreted code paths. One concern right now is that by testing the interpreted code paths too, we will double current test time or more. Otherwise, we can only choose to test the interpreted code paths for just few test suites such as encoder related. ## How was this patch tested? Existing tests. Closes #21535 from viirya/SPARK-23596. Authored-by: Liang-Chi Hsieh Signed-off-by: hyukjinkwon --- .../encoders/ExpressionEncoderSuite.scala | 4 ++-- .../catalyst/encoders/RowEncoderSuite.scala | 4 ++-- .../spark/sql/catalyst/plans/PlanTest.scala | 20 +++++++++++++++++++ 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index e6d09bdae67d7..f0d61de97ffcd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} -import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ @@ -112,7 +112,7 @@ object ReferenceValueClass { case class Container(data: Int) } -class ExpressionEncoderSuite extends PlanTest with AnalysisTest { +class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTest { OuterScopes.addOuterScope(this) implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = verifyNotLeakingReflectionObjects { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 6ed175f86ca77..8d89f9c6c41d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.encoders import scala.util.Random -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ @@ -71,7 +71,7 @@ class ExamplePointUDT extends UserDefinedType[ExamplePoint] { private[spark] override def asNullable: ExamplePointUDT = this } -class RowEncoderSuite extends SparkFunSuite { +class RowEncoderSuite extends CodegenInterpretedPlanTest { private val structOfString = new StructType().add("str", StringType) private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 139785719fec7..9e95b192968c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.catalyst.plans +import org.scalactic.source import org.scalatest.Suite +import org.scalatest.Tag import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ @@ -33,6 +36,23 @@ import org.apache.spark.sql.internal.SQLConf */ trait PlanTest extends SparkFunSuite with PlanTestBase +trait CodegenInterpretedPlanTest extends PlanTest { + + override protected def test( + testName: String, + testTags: Tag*)(testFun: => Any)(implicit pos: source.Position): Unit = { + val codegenMode = CodegenObjectFactoryMode.CODEGEN_ONLY.toString + val interpretedMode = CodegenObjectFactoryMode.NO_CODEGEN.toString + + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) { + super.test(testName + " (codegen path)", testTags: _*)(testFun)(pos) + } + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> interpretedMode) { + super.test(testName + " (interpreted path)", testTags: _*)(testFun)(pos) + } + } +} + /** * Provides helper methods for comparing plans, but without the overhead of * mandating a FunSuite. From 519e03d82e52d2d948f3ef25f5b85cc54bb11a75 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 9 Aug 2018 14:06:28 +0900 Subject: [PATCH 1350/2461] [SPARK-25058][SQL] Use Block.isEmpty/nonEmpty to check whether the code is empty or not. ## What changes were proposed in this pull request? We should use `Block.isEmpty/nonEmpty` instead of comparing with empty string to check whether the code is empty or not. ``` [error] [warn] /.../sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala:278: org.apache.spark.sql.catalyst.expressions.codegen.Block and String are unrelated: they will most likely always compare unequal [error] [warn] if (ev.code != "" && required.contains(attributes(i))) { [error] [warn] [error] [warn] /.../sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala:323: org.apache.spark.sql.catalyst.expressions.codegen.Block and String are unrelated: they will most likely never compare equal [error] [warn] | ${buildVars.filter(_.code == "").map(v => s"${v.isNull} = true;").mkString("\n")} [error] [warn] ``` ## How was this patch tested? Existing tests. Closes #22041 from ueshin/issues/SPARK-25058/fix_comparison. Authored-by: Takuya UESHIN Signed-off-by: Takuya UESHIN --- .../spark/sql/catalyst/expressions/codegen/javaCode.scala | 4 +++- .../apache/spark/sql/execution/WholeStageCodegenExec.scala | 2 +- .../spark/sql/execution/joins/BroadcastHashJoinExec.scala | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 2f8c853e836ba..558cbfa560053 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -130,7 +130,9 @@ trait Block extends TreeNode[Block] with JavaCode { def length: Int = toString.length - def nonEmpty: Boolean = toString.nonEmpty + def isEmpty: Boolean = toString.isEmpty + + def nonEmpty: Boolean = !isEmpty // The leading prefix that should be stripped from each line. // By default we strip blanks or control characters followed by '|' from the line. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 372dc3db36ce6..80f886ea1adc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -275,7 +275,7 @@ trait CodegenSupport extends SparkPlan { required: AttributeSet): String = { val evaluateVars = new StringBuilder variables.zipWithIndex.foreach { case (ev, i) => - if (ev.code != "" && required.contains(attributes(i))) { + if (ev.code.nonEmpty && required.contains(attributes(i))) { evaluateVars.append(ev.code.toString + "\n") ev.code = EmptyBlock } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 0da0e8610c392..a6f3ea47c8492 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -320,7 +320,7 @@ case class BroadcastHashJoinExec( |if (!$conditionPassed) { | $matched = null; | // reset the variables those are already evaluated. - | ${buildVars.filter(_.code == "").map(v => s"${v.isNull} = true;").mkString("\n")} + | ${buildVars.filter(_.code.isEmpty).map(v => s"${v.isNull} = true;").mkString("\n")} |} |$numOutput.add(1); |${consume(ctx, resultVars)} From 56e9e97073cf1896e301371b3941c9307e42ff77 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 9 Aug 2018 20:10:17 +0800 Subject: [PATCH 1351/2461] [MINOR][DOC] Fix typo ## What changes were proposed in this pull request? This PR fixes typo regarding `auxiliary verb + verb[s]`. This is a follow-on of #21956. ## How was this patch tested? N/A Closes #22040 from kiszk/spellcheck1. Authored-by: Kazuaki Ishizaki Signed-off-by: hyukjinkwon --- .../java/org/apache/spark/unsafe/map/BytesToBytesMap.java | 2 +- .../util/collection/unsafe/sort/UnsafeSorterSpillMerger.java | 2 +- .../main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala | 2 +- .../java/test/org/apache/spark/JavaSparkContextSuite.java | 2 +- .../org/apache/spark/sql/kafka010/KafkaDataConsumer.scala | 2 +- .../apache/spark/ml/classification/LogisticRegression.scala | 2 +- python/pyspark/sql/types.py | 2 +- .../apache/spark/sql/catalyst/analysis/DecimalPrecision.scala | 2 +- .../expressions/CodeGeneratorWithInterpretedFallback.scala | 2 +- .../spark/sql/catalyst/expressions/ExpectsInputTypes.scala | 2 +- .../sql/catalyst/analysis/UnsupportedOperationsSuite.scala | 4 ++-- .../spark/sql/catalyst/encoders/EncoderResolutionSuite.scala | 2 +- .../org/apache/spark/sql/execution/metric/SQLMetrics.scala | 2 +- .../spark/sql/execution/streaming/FileStreamSource.scala | 2 +- .../spark/sql/execution/streaming/ProgressReporter.scala | 2 +- .../test/scala/org/apache/spark/sql/test/SQLTestUtils.scala | 2 +- .../sql/hive/execution/CreateHiveTableAsSelectCommand.scala | 2 +- .../org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 2 +- 18 files changed, 19 insertions(+), 19 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 9a767dd739b91..9b6cbab38cbcc 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -662,7 +662,7 @@ public int getValueLength() { * It is only valid to call this method immediately after calling `lookup()` using the same key. *

      *

      - * The key and value must be word-aligned (that is, their sizes must multiples of 8). + * The key and value must be word-aligned (that is, their sizes must be a multiple of 8). *

      *

      * After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length` diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java index ff0dcc259a4ad..ab800288dcb43 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -51,7 +51,7 @@ public void addSpillIfNotEmpty(UnsafeSorterIterator spillReader) throws IOExcept if (spillReader.hasNext()) { // We only add the spillReader to the priorityQueue if it is not empty. We do this to // make sure the hasNext method of UnsafeSorterIterator returned by getSortedIterator - // does not return wrong result because hasNext will returns true + // does not return wrong result because hasNext will return true // at least priorityQueue.size() times. If we allow n spillReaders in the // priorityQueue, we will have n extra empty records in the result of UnsafeSorterIterator. spillReader.loadNext(); diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 70a8c659bbdd3..4cc0063d010ef 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -107,7 +107,7 @@ class SparkHadoopUtil extends Logging { } /** - * Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop + * Return an appropriate (subclass) of Configuration. Creating config can initialize some Hadoop * subsystems. */ def newConfiguration(conf: SparkConf): Configuration = { diff --git a/core/src/test/java/test/org/apache/spark/JavaSparkContextSuite.java b/core/src/test/java/test/org/apache/spark/JavaSparkContextSuite.java index 7e9cc70d8651f..0f489fb219010 100644 --- a/core/src/test/java/test/org/apache/spark/JavaSparkContextSuite.java +++ b/core/src/test/java/test/org/apache/spark/JavaSparkContextSuite.java @@ -30,7 +30,7 @@ import org.apache.spark.*; /** - * Java apps can uses both Java-friendly JavaSparkContext and Scala SparkContext. + * Java apps can use both Java-friendly JavaSparkContext and Scala SparkContext. */ public class JavaSparkContextSuite implements Serializable { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala index 941f0ab177e48..65046c175a7e5 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala @@ -155,7 +155,7 @@ private[kafka010] case class InternalKafkaConsumer( var toFetchOffset = offset var consumerRecord: ConsumerRecord[Array[Byte], Array[Byte]] = null // We want to break out of the while loop on a successful fetch to avoid using "return" - // which may causes a NonLocalReturnControl exception when this method is used as a function. + // which may cause a NonLocalReturnControl exception when this method is used as a function. var isFetchComplete = false while (toFetchOffset != UNKNOWN_OFFSET && !isFetchComplete) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index af651b056f2f7..408d92ef180de 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1484,7 +1484,7 @@ sealed trait LogisticRegressionSummary extends Serializable { /** * Convenient method for casting to binary logistic regression summary. - * This method will throws an Exception if the summary is not a binary summary. + * This method will throw an Exception if the summary is not a binary summary. */ @Since("2.3.0") def asBinary: BinaryLogisticRegressionSummary = this match { diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 3cd7a2ef115af..214d8fe6bbbb6 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -206,7 +206,7 @@ class DecimalType(FractionalType): and scale (the number of digits on the right of dot). For example, (5, 2) can support the value from [-999.99 to 999.99]. - The precision can be up to 38, the scale must less or equal to precision. + The precision can be up to 38, the scale must be less or equal to precision. When create a DecimalType, the default precision and scale is (10, 0). When infer schema from decimal.Decimal objects, it will be DecimalType(38, 18). diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index 23d146e71ed19..e511f8064e28a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -286,7 +286,7 @@ object DecimalPrecision extends TypeCoercionRule { // Consider the following example: multiplying a column which is DECIMAL(38, 18) by 2. // If we use the default precision and scale for the integer type, 2 is considered a // DECIMAL(10, 0). According to the rules, the result would be DECIMAL(38 + 10 + 1, 18), - // which is out of range and therefore it will becomes DECIMAL(38, 7), leading to + // which is out of range and therefore it will become DECIMAL(38, 7), leading to // potentially loosing 11 digits of the fractional part. Using only the precision needed // by the Literal, instead, the result would be DECIMAL(38 + 1 + 1, 18), which would // become DECIMAL(38, 16), safely having a much lower precision loss. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala index fb25e781e72e4..0f6d86691b4d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala @@ -44,7 +44,7 @@ object CodegenObjectFactoryMode extends Enumeration { /** * A codegen object generator which creates objects with codegen path first. Once any compile - * error happens, it can fallbacks to interpreted implementation. In tests, we can use a SQL config + * error happens, it can fallback to interpreted implementation. In tests, we can use a SQL config * `SQLConf.CODEGEN_FACTORY_MODE` to control fallback behavior. */ abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 464566b0cb7d9..d8f046c0028a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.types.AbstractDataType * This trait is typically used by operator expressions (e.g. [[Add]], [[Subtract]]) to define * expected input types without any implicit casting. * - * Most function expressions (e.g. [[Substring]] should extends [[ImplicitCastInputTypes]]) instead. + * Most function expressions (e.g. [[Substring]] should extend [[ImplicitCastInputTypes]]) instead. */ trait ExpectsInputTypes extends Expression { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 197d7c7668ef1..28a164b5d0cad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -766,7 +766,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { * * To test this correctly, the given logical plan is wrapped in a fake operator that makes the * whole plan look like a streaming plan. Otherwise, a batch plan may throw not supported - * exception simply for not being a streaming plan, even though that plan could exists as batch + * exception simply for not being a streaming plan, even though that plan could exist as batch * subplan inside some streaming plan. */ def assertSupportedInStreamingPlan( @@ -793,7 +793,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { * * To test this correctly, the given logical plan is wrapped in a fake operator that makes the * whole plan look like a streaming plan. Otherwise, a batch plan may throw not supported - * exception simply for not being a streaming plan, even though that plan could exists as batch + * exception simply for not being a streaming plan, even though that plan could exist as batch * subplan inside some streaming plan. */ def assertNotSupportedInStreamingPlan( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 630113ce2d948..dd20e6497fbb4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -144,7 +144,7 @@ class EncoderResolutionSuite extends PlanTest { // It should pass analysis val bound = encoder.resolveAndBind(attrs) - // If no null values appear, it should works fine + // If no null values appear, it should work fine bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2)))) // If there is null value, it should throw runtime exception diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index b4f0ae1eb1a18..98f58a3056906 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -110,7 +110,7 @@ object SQLMetrics { * spill size, etc. */ def createSizeMetric(sc: SparkContext, name: String): SQLMetric = { - // The final result of this metric in physical operator UI may looks like: + // The final result of this metric in physical operator UI may look like: // data size total (min, med, max): // 100GB (100MB, 1GB, 10GB) val acc = new SQLMetric(SIZE_METRIC, -1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 8c016abc5b643..103fa7ce9066d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -50,7 +50,7 @@ class FileStreamSource( @transient private val fs = new Path(path).getFileSystem(hadoopConf) private val qualifiedBasePath: Path = { - fs.makeQualified(new Path(path)) // can contains glob patterns + fs.makeQualified(new Path(path)) // can contain glob patterns } private val optionsWithPartitionBasePath = sourceOptions.optionMapWithoutPath ++ { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 1e158323d2020..ae1bfa2e499bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -312,7 +312,7 @@ trait ProgressReporter extends Logging { // DataSourceV2ScanExec records the number of rows it has read using SQLMetrics. However, // just collecting all DataSourceV2ScanExec nodes and getting the metric is not correct as // a DataSourceV2ScanExec instance may be referred to in the execution plan from two (or - // even multiple times) points and considering it twice will leads to double counting. We + // even multiple times) points and considering it twice will lead to double counting. We // can't dedup them using their hashcode either because two different instances of // DataSourceV2ScanExec can have the same hashcode but account for separate sets of // records read, and deduping them to consider only one of them would be undercounting the diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index ac70488febc5a..2fb8f70a20791 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -76,7 +76,7 @@ private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with /** * Disable stdout and stderr when running the test. To not output the logs to the console, - * ConsoleAppender's `follow` should be set to `true` so that it will honors reassignments of + * ConsoleAppender's `follow` should be set to `true` so that it will honor reassignments of * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if * we change System.out and System.err. */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 1e801fe1845c4..27d807cc35627 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.command.DataWritingCommand /** * Create table and insert the query result into it. * - * @param tableDesc the Table Describe, which may contains serde, storage handler etc. + * @param tableDesc the Table Describe, which may contain serde, storage handler etc. * @param query the query whose result will be insert into the new relation * @param mode SaveMode */ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 741b0124c83b9..b9c32e789a410 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -84,7 +84,7 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd } // Testing the Broadcast based join for cartesian join (cross join) - // We assume that the Broadcast Join Threshold will works since the src is a small table + // We assume that the Broadcast Join Threshold will work since the src is a small table private val spark_10484_1 = """ | SELECT a.key, b.key | FROM src a LEFT JOIN src b WHERE a.key > b.key + 300 From 386fbd3aff95ce919567b1b94d5b19c5bcef266a Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 9 Aug 2018 20:28:14 +0800 Subject: [PATCH 1352/2461] [SPARK-23415][SQL][TEST] Make behavior of BufferHolderSparkSubmitSuite correct and stable ## What changes were proposed in this pull request? This PR addresses two issues in `BufferHolderSparkSubmitSuite`. 1. While `BufferHolderSparkSubmitSuite` tried to allocate a large object several times, it actually allocated an object once and reused the object. 2. `BufferHolderSparkSubmitSuite` may fail due to timeout To assign a small object before allocating a large object each time solved issue 1 by avoiding reuse. To increasing heap size from 4g to 7g solved issue 2. It can also avoid OOM after fixing issue 1. ## How was this patch tested? Updated existing `BufferHolderSparkSubmitSuite` Closes #20636 from kiszk/SPARK-23415. Authored-by: Kazuaki Ishizaki Signed-off-by: Wenchen Fan --- .../expressions/codegen/BufferHolder.java | 13 +++++-- .../BufferHolderSparkSubmitSuite.scala | 36 +++++++++++-------- .../codegen/BufferHolderSuite.scala | 10 +++--- 3 files changed, 36 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 537ef244b7e81..6a52a5b0e0664 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -35,6 +35,7 @@ final class BufferHolder { private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; + // buffer is guarantee to be word-aligned since UnsafeRow assumes each field is word-aligned. private byte[] buffer; private int cursor = Platform.BYTE_ARRAY_OFFSET; private final UnsafeRow row; @@ -52,7 +53,8 @@ final class BufferHolder { "too many fields (number of fields: " + row.numFields() + ")"); } this.fixedSize = bitsetWidthInBytes + 8 * row.numFields(); - this.buffer = new byte[fixedSize + initialSize]; + int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(fixedSize + initialSize); + this.buffer = new byte[roundedSize]; this.row = row; this.row.pointTo(buffer, buffer.length); } @@ -61,8 +63,12 @@ final class BufferHolder { * Grows the buffer by at least neededSize and points the row to the buffer. */ void grow(int neededSize) { + if (neededSize < 0) { + throw new IllegalArgumentException( + "Cannot grow BufferHolder by size " + neededSize + " because the size is negative"); + } if (neededSize > ARRAY_MAX - totalSize()) { - throw new UnsupportedOperationException( + throw new IllegalArgumentException( "Cannot grow BufferHolder by size " + neededSize + " because the size after growing " + "exceeds size limitation " + ARRAY_MAX); } @@ -70,7 +76,8 @@ void grow(int neededSize) { if (buffer.length < length) { // This will not happen frequently, because the buffer is re-used. int newLength = length < ARRAY_MAX / 2 ? length * 2 : ARRAY_MAX; - final byte[] tmp = new byte[newLength]; + int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(newLength); + final byte[] tmp = new byte[roundedSize]; Platform.copyMemory( buffer, Platform.BYTE_ARRAY_OFFSET, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala index 85682cf6ea670..d2862c8f41d1b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.scalatest.{BeforeAndAfterEach, Matchers} +import org.scalatest.{Assertions, BeforeAndAfterEach, Matchers} import org.apache.spark.{SparkFunSuite, TestUtils} import org.apache.spark.deploy.SparkSubmitSuite @@ -39,7 +39,7 @@ class BufferHolderSparkSubmitSuite val argsForSparkSubmit = Seq( "--class", BufferHolderSparkSubmitSuite.getClass.getName.stripSuffix("$"), "--name", "SPARK-22222", - "--master", "local-cluster[2,1,1024]", + "--master", "local-cluster[1,1,4096]", "--driver-memory", "4g", "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", @@ -49,28 +49,36 @@ class BufferHolderSparkSubmitSuite } } -object BufferHolderSparkSubmitSuite { +object BufferHolderSparkSubmitSuite extends Assertions { def main(args: Array[String]): Unit = { val ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - val holder = new BufferHolder(new UnsafeRow(1000)) + val unsafeRow = new UnsafeRow(1000) + val holder = new BufferHolder(unsafeRow) holder.reset() - holder.grow(roundToWord(ARRAY_MAX / 2)) - holder.reset() - holder.grow(roundToWord(ARRAY_MAX / 2 + 8)) + assert(intercept[IllegalArgumentException] { + holder.grow(-1) + }.getMessage.contains("because the size is negative")) - holder.reset() - holder.grow(roundToWord(Integer.MAX_VALUE / 2)) + // while to reuse a buffer may happen, this test checks whether the buffer can be grown + holder.grow(ARRAY_MAX / 2) + assert(unsafeRow.getSizeInBytes % 8 == 0) - holder.reset() - holder.grow(roundToWord(Integer.MAX_VALUE)) - } + holder.grow(ARRAY_MAX / 2 + 7) + assert(unsafeRow.getSizeInBytes % 8 == 0) + + holder.grow(Integer.MAX_VALUE / 2) + assert(unsafeRow.getSizeInBytes % 8 == 0) + + holder.grow(ARRAY_MAX - holder.totalSize()) + assert(unsafeRow.getSizeInBytes % 8 == 0) - private def roundToWord(len: Int): Int = { - ByteArrayMethods.roundNumberOfBytesToNearestWord(len) + assert(intercept[IllegalArgumentException] { + holder.grow(ARRAY_MAX + 1 - holder.totalSize()) + }.getMessage.contains("because the size after growing")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala index c7c386b5b838a..4e0f903a030aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala @@ -23,17 +23,15 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow class BufferHolderSuite extends SparkFunSuite { test("SPARK-16071 Check the size limit to avoid integer overflow") { - var e = intercept[UnsupportedOperationException] { + assert(intercept[UnsupportedOperationException] { new BufferHolder(new UnsafeRow(Int.MaxValue / 8)) - } - assert(e.getMessage.contains("too many fields")) + }.getMessage.contains("too many fields")) val holder = new BufferHolder(new UnsafeRow(1000)) holder.reset() holder.grow(1000) - e = intercept[UnsupportedOperationException] { + assert(intercept[IllegalArgumentException] { holder.grow(Integer.MAX_VALUE) - } - assert(e.getMessage.contains("exceeds size limitation")) + }.getMessage.contains("exceeds size limitation")) } } From b2950cef3c898f59a2c92e8800ff134c44263b9a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 9 Aug 2018 20:33:59 +0800 Subject: [PATCH 1353/2461] Revert "[SPARK-24648][SQL] SqlMetrics should be threadsafe" This reverts commit 5264164a67df498b73facae207eda12ee133be7d. --- .../sql/execution/metric/SQLMetrics.scala | 33 +++++++---------- .../execution/metric/SQLMetricsSuite.scala | 36 +------------------ 2 files changed, 14 insertions(+), 55 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 98f58a3056906..cbf707f4a9cfd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.metric import java.text.NumberFormat import java.util.Locale -import java.util.concurrent.atomic.LongAdder import org.apache.spark.SparkContext import org.apache.spark.scheduler.AccumulableInfo @@ -33,45 +32,40 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils} * on the driver side must be explicitly posted using [[SQLMetrics.postDriverMetricUpdates()]]. */ class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long] { - // This is a workaround for SPARK-11013. // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will // update it at the end of task and the value will be at least 0. Then we can filter out the -1 // values before calculate max, min, etc. - private[this] val _value = new LongAdder - private val _zeroValue = initValue - _value.add(initValue) + private[this] var _value = initValue + private var _zeroValue = initValue override def copy(): SQLMetric = { - val newAcc = new SQLMetric(metricType, initValue) - newAcc.add(_value.sum()) + val newAcc = new SQLMetric(metricType, _value) + newAcc._zeroValue = initValue newAcc } - override def reset(): Unit = this.set(_zeroValue) + override def reset(): Unit = _value = _zeroValue override def merge(other: AccumulatorV2[Long, Long]): Unit = other match { - case o: SQLMetric => _value.add(o.value) + case o: SQLMetric => _value += o.value case _ => throw new UnsupportedOperationException( s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") } - override def isZero(): Boolean = _value.sum() == _zeroValue + override def isZero(): Boolean = _value == _zeroValue - override def add(v: Long): Unit = _value.add(v) + override def add(v: Long): Unit = _value += v // We can set a double value to `SQLMetric` which stores only long value, if it is // average metrics. def set(v: Double): Unit = SQLMetrics.setDoubleForAverageMetrics(this, v) - def set(v: Long): Unit = { - _value.reset() - _value.add(v) - } + def set(v: Long): Unit = _value = v - def +=(v: Long): Unit = _value.add(v) + def +=(v: Long): Unit = _value += v - override def value: Long = _value.sum() + override def value: Long = _value // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { @@ -159,7 +153,7 @@ object SQLMetrics { Seq.fill(3)(0L) } else { val sorted = validValues.sorted - Seq(sorted.head, sorted(validValues.length / 2), sorted(validValues.length - 1)) + Seq(sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) } metric.map(v => numberFormat.format(v.toDouble / baseForAvgMetric)) } @@ -179,8 +173,7 @@ object SQLMetrics { Seq.fill(4)(0L) } else { val sorted = validValues.sorted - Seq(sorted.sum, sorted.head, sorted(validValues.length / 2), - sorted(validValues.length - 1)) + Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) } metric.map(strFormat) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 8263c9c81c49e..a3a3f3851e21c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.execution.metric import java.io.File -import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future} import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.execution.ui.SQLAppStatusStore import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -504,38 +504,4 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared test("writing data out metrics with dynamic partition: parquet") { testMetricsDynamicPartition("parquet", "parquet", "t1") } - - test("writing metrics from single thread") { - val nAdds = 10 - val acc = new SQLMetric("test", -10) - assert(acc.isZero()) - acc.set(0) - for (i <- 1 to nAdds) acc.add(1) - assert(!acc.isZero()) - assert(nAdds === acc.value) - acc.reset() - assert(acc.isZero()) - } - - test("writing metrics from multiple threads") { - implicit val ec: ExecutionContextExecutor = ExecutionContext.global - val nFutures = 1000 - val nAdds = 100 - val acc = new SQLMetric("test", -10) - assert(acc.isZero() === true) - acc.set(0) - val l = for ( i <- 1 to nFutures ) yield { - Future { - for (j <- 1 to nAdds) acc.add(1) - i - } - } - for (futures <- Future.sequence(l)) { - assert(nFutures === futures.length) - assert(!acc.isZero()) - assert(nFutures * nAdds === acc.value) - acc.reset() - assert(acc.isZero()) - } - } } From 1a7e747ce4f8c5253c5923045d23c62e43a6566b Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 9 Aug 2018 08:07:46 -0500 Subject: [PATCH 1354/2461] [SPARK-25047][ML] Can't assign SerializedLambda to scala.Function1 in deserialization of BucketedRandomProjectionLSHModel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Convert two function fields in ML classes to simple functions to avoi…d odd SerializedLambda deserialization problem ## How was this patch tested? Existing tests. Closes #22032 from srowen/SPARK-25047. Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../feature/BucketedRandomProjectionLSH.scala | 14 ++++++------- .../org/apache/spark/ml/feature/LSH.scala | 4 ++-- .../apache/spark/ml/feature/MinHashLSH.scala | 20 +++++++++---------- .../GeneralizedLinearRegression.scala | 15 +++++++------- 4 files changed, 24 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala index a906e954fecd5..0554455a66d7f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala @@ -82,14 +82,12 @@ class BucketedRandomProjectionLSHModel private[ml]( override def setOutputCol(value: String): this.type = super.set(outputCol, value) @Since("2.1.0") - override protected[ml] val hashFunction: Vector => Array[Vector] = { - key: Vector => { - val hashValues: Array[Double] = randUnitVectors.map({ - randUnitVector => Math.floor(BLAS.dot(key, randUnitVector) / $(bucketLength)) - }) - // TODO: Output vectors of dimension numHashFunctions in SPARK-18450 - hashValues.map(Vectors.dense(_)) - } + override protected[ml] def hashFunction(elems: Vector): Array[Vector] = { + val hashValues = randUnitVectors.map( + randUnitVector => Math.floor(BLAS.dot(elems, randUnitVector) / $(bucketLength)) + ) + // TODO: Output vectors of dimension numHashFunctions in SPARK-18450 + hashValues.map(Vectors.dense(_)) } @Since("2.1.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala index a70931f783f45..b20852383a6ff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala @@ -75,7 +75,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] * The hash function of LSH, mapping an input feature vector to multiple hash vectors. * @return The mapping of LSH function. */ - protected[ml] val hashFunction: Vector => Array[Vector] + protected[ml] def hashFunction(elems: Vector): Array[Vector] /** * Calculate the distance between two different keys using the distance metric corresponding @@ -97,7 +97,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - val transformUDF = udf(hashFunction, DataTypes.createArrayType(new VectorUDT)) + val transformUDF = udf(hashFunction(_: Vector), DataTypes.createArrayType(new VectorUDT)) dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol)))) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index a043033e96724..21cde66d8db6b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -60,18 +60,16 @@ class MinHashLSHModel private[ml]( override def setOutputCol(value: String): this.type = super.set(outputCol, value) @Since("2.1.0") - override protected[ml] val hashFunction: Vector => Array[Vector] = { - elems: Vector => { - require(elems.numNonzeros > 0, "Must have at least 1 non zero entry.") - val elemsList = elems.toSparse.indices.toList - val hashValues = randCoefficients.map { case (a, b) => - elemsList.map { elem: Int => - ((1L + elem) * a + b) % MinHashLSH.HASH_PRIME - }.min.toDouble - } - // TODO: Output vectors of dimension numHashFunctions in SPARK-18450 - hashValues.map(Vectors.dense(_)) + override protected[ml] def hashFunction(elems: Vector): Array[Vector] = { + require(elems.numNonzeros > 0, "Must have at least 1 non zero entry.") + val elemsList = elems.toSparse.indices.toList + val hashValues = randCoefficients.map { case (a, b) => + elemsList.map { elem: Int => + ((1L + elem) * a + b) % MinHashLSH.HASH_PRIME + }.min.toDouble } + // TODO: Output vectors of dimension numHashFunctions in SPARK-18450 + hashValues.map(Vectors.dense(_)) } @Since("2.1.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 20878b6448920..abb60ea205751 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -515,14 +515,13 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * The reweight function used to update working labels and weights * at each iteration of [[IterativelyReweightedLeastSquares]]. */ - val reweightFunc: (OffsetInstance, WeightedLeastSquaresModel) => (Double, Double) = { - (instance: OffsetInstance, model: WeightedLeastSquaresModel) => { - val eta = model.predict(instance.features) + instance.offset - val mu = fitted(eta) - val newLabel = eta - instance.offset + (instance.label - mu) * link.deriv(mu) - val newWeight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu)) - (newLabel, newWeight) - } + def reweightFunc( + instance: OffsetInstance, model: WeightedLeastSquaresModel): (Double, Double) = { + val eta = model.predict(instance.features) + instance.offset + val mu = fitted(eta) + val newLabel = eta - instance.offset + (instance.label - mu) * link.deriv(mu) + val newWeight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu)) + (newLabel, newWeight) } } From 2949a835fae3f4ac6e3dae6f18cd8b6543b74601 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Thu, 9 Aug 2018 08:11:30 -0700 Subject: [PATCH 1355/2461] [SPARK-25063][SQL] Rename class KnowNotNull to KnownNotNull ## What changes were proposed in this pull request? Correct the class name typo checked in through SPARK-24891 ## How was this patch tested? Passed all existing tests. Closes #22049 from maryannxue/known-not-null. Authored-by: maryannxue Signed-off-by: Xiao Li --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 4 ++-- .../sql/catalyst/expressions/constraintExpressions.scala | 2 +- .../apache/spark/sql/catalyst/analysis/AnalysisSuite.scala | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d23d43bef76e9..a7cd96e46d114 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2157,7 +2157,7 @@ class Analyzer( // trust the `nullable` information. // (cls, expr) => cls.isPrimitive && expr.nullable val needsNullCheck = (cls: Class[_], expr: Expression) => - cls.isPrimitive && !expr.isInstanceOf[KnowNotNull] + cls.isPrimitive && !expr.isInstanceOf[KnownNotNull] val inputsNullCheck = parameterTypes.zip(inputs) .filter { case (cls, expr) => needsNullCheck(cls, expr) } .map { case (_, expr) => IsNull(expr) } @@ -2167,7 +2167,7 @@ class Analyzer( // branch of `If` will be called if any of these checked inputs is null. Thus we can // prevent this rule from being applied repeatedly. val newInputs = parameterTypes.zip(inputs).map{ case (cls, expr) => - if (needsNullCheck(cls, expr)) KnowNotNull(expr) else expr } + if (needsNullCheck(cls, expr)) KnownNotNull(expr) else expr } inputsNullCheck .map(If(_, Literal.create(null, udf.dataType), udf.copy(children = newInputs))) .getOrElse(udf) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala index 53936aa914c8f..2917b0b8c9c53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral} import org.apache.spark.sql.types.DataType -case class KnowNotNull(child: Expression) extends UnaryExpression { +case class KnownNotNull(child: Expression) extends UnaryExpression { override def nullable: Boolean = false override def dataType: DataType = child.dataType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index ba44484b946ff..a1c976dd923f1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -319,7 +319,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { // only primitive parameter needs special null handling val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil) val expected2 = - If(IsNull(double), nullResult, udf2.copy(children = string :: KnowNotNull(double) :: Nil)) + If(IsNull(double), nullResult, udf2.copy(children = string :: KnownNotNull(double) :: Nil)) checkUDF(udf2, expected2) // special null handling should apply to all primitive parameters @@ -327,7 +327,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { val expected3 = If( IsNull(short) || IsNull(double), nullResult, - udf3.copy(children = KnowNotNull(short) :: KnowNotNull(double) :: Nil)) + udf3.copy(children = KnownNotNull(short) :: KnownNotNull(double) :: Nil)) checkUDF(udf3, expected3) // we can skip special null handling for primitive parameters that are not nullable @@ -339,7 +339,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { val expected4 = If( IsNull(short), nullResult, - udf4.copy(children = KnowNotNull(short) :: double.withNullability(false) :: Nil)) + udf4.copy(children = KnownNotNull(short) :: double.withNullability(false) :: Nil)) // checkUDF(udf4, expected4) } From d36539741ff6a12a6acde9274e9992a66cdd36e7 Mon Sep 17 00:00:00 2001 From: Achuth17 Date: Thu, 9 Aug 2018 08:29:24 -0700 Subject: [PATCH 1356/2461] [SPARK-24626][SQL] Improve location size calculation in Analyze Table command MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Currently, Analyze table calculates table size sequentially for each partition. We can parallelize size calculations over partitions. Results : Tested on a table with 100 partitions and data stored in S3. With changes : - 10.429s - 10.557s - 10.439s - 9.893s
 Without changes : - 110.034s - 99.510s - 100.743s - 99.106s ## How was this patch tested? Simple unit test. Closes #21608 from Achuth17/improveAnalyze. Lead-authored-by: Achuth17 Co-authored-by: arajagopal17 Signed-off-by: Xiao Li --- docs/sql-programming-guide.md | 2 ++ .../apache/spark/sql/internal/SQLConf.scala | 12 ++++++++ .../command/AnalyzeColumnCommand.scala | 2 +- .../command/AnalyzeTableCommand.scala | 2 +- .../sql/execution/command/CommandUtils.scala | 30 ++++++++++++++----- .../datasources/DataSourceUtils.scala | 10 +++++++ .../datasources/InMemoryFileIndex.scala | 2 +- .../spark/sql/hive/StatisticsSuite.scala | 23 +++++++++++++- 8 files changed, 72 insertions(+), 11 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index a1e019cbec4d2..9adb86aa60bbb 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1892,6 +1892,8 @@ working with timestamps in `pandas_udf`s to get the best performance, see - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. - In version 2.3 and earlier, CSV rows are considered as malformed if at least one column value in the row is malformed. CSV parser dropped such rows in the DROPMALFORMED mode or outputs an error in the FAILFAST mode. Since Spark 2.4, CSV row is considered as malformed only when it contains malformed column values requested from CSV datasource, other values can be ignored. As an example, CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selection of the id column consists of a row with one column value 1234 but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. + - Since Spark 2.4, File listing for compute statistics is done in parallel by default. This can be disabled by setting `spark.sql.parallelFileListingInStatsComputation.enabled` to `False`. + - Since Spark 2.4, Metadata files (e.g. Parquet summary files) and temporary files are not counted as data files when calculating table size during Statistics computation. ## Upgrading From Spark SQL 2.3.0 to 2.3.1 and above diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 67c3abb80c2c5..979a55467ff89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1476,6 +1476,15 @@ object SQLConf { "are performed before any UNION, EXCEPT and MINUS operations.") .booleanConf .createWithDefault(false) + + val PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION = + buildConf("spark.sql.parallelFileListingInStatsComputation.enabled") + .internal() + .doc("When true, SQL commands use parallel file listing, " + + "as opposed to single thread listing." + + "This usually speeds up commands that need to list many directories.") + .booleanConf + .createWithDefault(true) } /** @@ -1873,6 +1882,9 @@ class SQLConf extends Serializable with Logging { def setOpsPrecedenceEnforced: Boolean = getConf(SQLConf.LEGACY_SETOPS_PRECEDENCE_ENABLED) + def parallelFileListingInStatsComputation: Boolean = + getConf(SQLConf.PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 640e01336aa75..3fea6d7c7fbfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -47,7 +47,7 @@ case class AnalyzeColumnCommand( if (tableMeta.tableType == CatalogTableType.VIEW) { throw new AnalysisException("ANALYZE TABLE is not supported on views.") } - val sizeInBytes = CommandUtils.calculateTotalSize(sessionState, tableMeta) + val sizeInBytes = CommandUtils.calculateTotalSize(sparkSession, tableMeta) // Compute stats for each column val (rowCount, newColStats) = computeColumnStats(sparkSession, tableIdentWithDB, columnNames) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 58b53e8b1c551..3076e919dd61f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -39,7 +39,7 @@ case class AnalyzeTableCommand( } // Compute stats for the whole table - val newTotalSize = CommandUtils.calculateTotalSize(sessionState, tableMeta) + val newTotalSize = CommandUtils.calculateTotalSize(sparkSession, tableMeta) val newRowCount = if (noscan) None else Some(BigInt(sparkSession.table(tableIdentWithDB).count())) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index c27048626c8eb..df71bc9effb3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -21,12 +21,13 @@ import java.net.URI import scala.util.control.NonFatal -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTablePartition} +import org.apache.spark.sql.execution.datasources.{DataSourceUtils, InMemoryFileIndex} import org.apache.spark.sql.internal.SessionState @@ -38,7 +39,7 @@ object CommandUtils extends Logging { val catalog = sparkSession.sessionState.catalog if (sparkSession.sessionState.conf.autoSizeUpdateEnabled) { val newTable = catalog.getTableMetadata(table.identifier) - val newSize = CommandUtils.calculateTotalSize(sparkSession.sessionState, newTable) + val newSize = CommandUtils.calculateTotalSize(sparkSession, newTable) val newStats = CatalogStatistics(sizeInBytes = newSize) catalog.alterTableStats(table.identifier, Some(newStats)) } else { @@ -47,15 +48,29 @@ object CommandUtils extends Logging { } } - def calculateTotalSize(sessionState: SessionState, catalogTable: CatalogTable): BigInt = { + def calculateTotalSize(spark: SparkSession, catalogTable: CatalogTable): BigInt = { + val sessionState = spark.sessionState if (catalogTable.partitionColumnNames.isEmpty) { calculateLocationSize(sessionState, catalogTable.identifier, catalogTable.storage.locationUri) } else { // Calculate table size as a sum of the visible partitions. See SPARK-21079 val partitions = sessionState.catalog.listPartitions(catalogTable.identifier) - partitions.map { p => - calculateLocationSize(sessionState, catalogTable.identifier, p.storage.locationUri) - }.sum + if (spark.sessionState.conf.parallelFileListingInStatsComputation) { + val paths = partitions.map(x => new Path(x.storage.locationUri.get)) + val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") + val pathFilter = new PathFilter with Serializable { + override def accept(path: Path): Boolean = { + DataSourceUtils.isDataPath(path) && !path.getName.startsWith(stagingDir) + } + } + val fileStatusSeq = InMemoryFileIndex.bulkListLeafFiles( + paths, sessionState.newHadoopConf(), pathFilter, spark) + fileStatusSeq.flatMap(_._2.map(_.getLen)).sum + } else { + partitions.map { p => + calculateLocationSize(sessionState, catalogTable.identifier, p.storage.locationUri) + }.sum + } } } @@ -78,7 +93,8 @@ object CommandUtils extends Logging { val size = if (fileStatus.isDirectory) { fs.listStatus(path) .map { status => - if (!status.getPath.getName.startsWith(stagingDir)) { + if (!status.getPath.getName.startsWith(stagingDir) && + DataSourceUtils.isDataPath(path)) { getPathSize(fs, status.getPath) } else { 0L diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala index cccd6c08ae460..90cec5e72c1a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import org.apache.hadoop.fs.Path + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.types._ @@ -49,4 +51,12 @@ object DataSourceUtils { } } } + + // SPARK-24626: Metadata files and temporary files should not be + // counted as data files, so that they shouldn't participate in tasks like + // location size calculation. + private[sql] def isDataPath(path: Path): Boolean = { + val name = path.getName + !(name.startsWith("_") || name.startsWith(".")) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 9d9f8bd5bb58e..dc5c2ff927e4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -162,7 +162,7 @@ object InMemoryFileIndex extends Logging { * * @return for each input path, the set of discovered files for the path */ - private def bulkListLeafFiles( + private[sql] def bulkListLeafFiles( paths: Seq[Path], hadoopConf: Configuration, filter: PathFilter, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 61cec82984795..d8ffb29a59317 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -25,13 +25,14 @@ import scala.util.matching.Regex import org.apache.hadoop.hive.common.StatsSetupConst +import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, HiveTableRelation} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, HistogramBin, HistogramSerializer} import org.apache.spark.sql.catalyst.util.{DateTimeUtils, StringUtils} -import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.execution.command.{CommandUtils, DDLUtils} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.HiveExternalCatalog._ @@ -148,6 +149,26 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } + test("SPARK-24626 parallel file listing in Stats computation") { + withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "2", + SQLConf.PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION.key -> "True") { + val checkSizeTable = "checkSizeTable" + withTable(checkSizeTable) { + sql(s"CREATE TABLE $checkSizeTable (key STRING, value STRING) PARTITIONED BY (ds STRING)") + sql(s"INSERT INTO TABLE $checkSizeTable PARTITION (ds='2010-01-01') SELECT * FROM src") + sql(s"INSERT INTO TABLE $checkSizeTable PARTITION (ds='2010-01-02') SELECT * FROM src") + sql(s"INSERT INTO TABLE $checkSizeTable PARTITION (ds='2010-01-03') SELECT * FROM src") + val tableMeta = spark.sessionState.catalog + .getTableMetadata(TableIdentifier(checkSizeTable)) + HiveCatalogMetrics.reset() + assert(HiveCatalogMetrics.METRIC_PARALLEL_LISTING_JOB_COUNT.getCount() == 0) + val size = CommandUtils.calculateTotalSize(spark, tableMeta) + assert(HiveCatalogMetrics.METRIC_PARALLEL_LISTING_JOB_COUNT.getCount() == 1) + assert(size === BigInt(17436)) + } + } + } + test("analyze non hive compatible datasource tables") { val table = "parquet_tab" withTable(table) { From eb9a696dd6f138225708d15bb2383854ed8a6dab Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 9 Aug 2018 13:04:03 -0500 Subject: [PATCH 1357/2461] [MINOR][BUILD] Update Jetty to 9.3.24.v20180605 ## What changes were proposed in this pull request? Update Jetty to 9.3.24.v20180605 to pick up security fix ## How was this patch tested? Existing tests. Closes #22055 from srowen/Jetty9324. Authored-by: Sean Owen Signed-off-by: Sean Owen --- dev/deps/spark-deps-hadoop-3.1 | 4 ++-- pom.xml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 90602fce59a7d..fb42adf95db27 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -120,8 +120,8 @@ jersey-guava-2.22.2.jar jersey-media-jaxb-2.22.2.jar jersey-server-2.22.2.jar jets3t-0.9.4.jar -jetty-webapp-9.3.20.v20170531.jar -jetty-xml-9.3.20.v20170531.jar +jetty-webapp-9.3.24.v20180605.jar +jetty-xml-9.3.24.v20180605.jar jline-2.14.3.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar diff --git a/pom.xml b/pom.xml index 8abdb700dc4c0..b89713fab7291 100644 --- a/pom.xml +++ b/pom.xml @@ -134,7 +134,7 @@ 1.5.2 nohive 1.6.0 - 9.3.20.v20170531 + 9.3.24.v20180605 3.1.0 0.8.4 2.4.0 From bd6db1505fb68737fa1782bd457ddc52eae6652d Mon Sep 17 00:00:00 2001 From: liyuanjian Date: Thu, 9 Aug 2018 13:43:07 -0700 Subject: [PATCH 1358/2461] [SPARK-25077][SQL] Delete unused variable in WindowExec ## What changes were proposed in this pull request? Just delete the unused variable `inputFields` in WindowExec, avoid making others confused while reading the code. ## How was this patch tested? Existing UT. Closes #22057 from xuanyuanking/SPARK-25077. Authored-by: liyuanjian Signed-off-by: Xiao Li --- .../org/apache/spark/sql/execution/window/WindowExec.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 626f39d9e95cc..fede0f3e92d67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -323,8 +323,6 @@ case class WindowExec( fetchNextRow() // Manage the current partition. - val inputFields = child.output.length - val buffer: ExternalAppendOnlyUnsafeRowArray = new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) From fec67ed7e95483c5ea97a7b263ad4bea7d3d42b5 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 9 Aug 2018 14:38:58 -0700 Subject: [PATCH 1359/2461] [SPARK-25076][SQL] SQLConf should not be retrieved from a stopped SparkSession ## What changes were proposed in this pull request? When a `SparkSession` is stopped, `SQLConf.get` should use the fallback conf to avoid weird issues like ``` sbt.ForkMain$ForkError: java.lang.IllegalStateException: LiveListenerBus is stopped. at org.apache.spark.scheduler.LiveListenerBus.addToQueue(LiveListenerBus.scala:97) at org.apache.spark.scheduler.LiveListenerBus.addToStatusQueue(LiveListenerBus.scala:80) at org.apache.spark.sql.internal.SharedState.(SharedState.scala:93) at org.apache.spark.sql.SparkSession$$anonfun$sharedState$1.apply(SparkSession.scala:120) at org.apache.spark.sql.SparkSession$$anonfun$sharedState$1.apply(SparkSession.scala:120) at scala.Option.getOrElse(Option.scala:121) ... ``` ## How was this patch tested? a new test suite Closes #22056 from cloud-fan/session. Authored-by: Wenchen Fan Signed-off-by: Xiao Li --- .../org/apache/spark/sql/SparkSession.scala | 3 +- .../apache/spark/sql/LocalSparkSession.scala | 9 ++--- .../sql/internal/SQLConfGetterSuite.scala | 33 +++++++++++++++++++ 3 files changed, 37 insertions(+), 8 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfGetterSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 565042fcf762e..d9278d8cd23d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -92,7 +92,8 @@ class SparkSession private( // If there is no active SparkSession, uses the default SQL conf. Otherwise, use the session's. SQLConf.setSQLConfGetter(() => { - SparkSession.getActiveSession.map(_.sessionState.conf).getOrElse(SQLConf.getFallbackConf) + SparkSession.getActiveSession.filterNot(_.sparkContext.isStopped).map(_.sessionState.conf) + .getOrElse(SQLConf.getFallbackConf) }) /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala index cbef1c7828319..6b90f20a94fa4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala @@ -36,19 +36,14 @@ trait LocalSparkSession extends BeforeAndAfterEach with BeforeAndAfterAll { self override def afterEach() { try { - resetSparkContext() + LocalSparkSession.stop(spark) SparkSession.clearActiveSession() SparkSession.clearDefaultSession() + spark = null } finally { super.afterEach() } } - - def resetSparkContext(): Unit = { - LocalSparkSession.stop(spark) - spark = null - } - } object LocalSparkSession { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfGetterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfGetterSuite.scala new file mode 100644 index 0000000000000..bb79d3a84e5a3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfGetterSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{LocalSparkSession, SparkSession} + +class SQLConfGetterSuite extends SparkFunSuite with LocalSparkSession { + + test("SPARK-25076: SQLConf should not be retrieved from a stopped SparkSession") { + spark = SparkSession.builder().master("local").getOrCreate() + assert(SQLConf.get eq spark.sessionState.conf, + "SQLConf.get should get the conf from the active spark session.") + spark.stop() + assert(SQLConf.get eq SQLConf.getFallbackConf, + "SQLConf.get should not get conf from a stopped spark session.") + } +} From 9b8521e53e56a53b44c02366a99f8a8ee1307bbf Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 9 Aug 2018 14:41:59 -0700 Subject: [PATCH 1360/2461] [SPARK-25068][SQL] Add exists function. ## What changes were proposed in this pull request? This pr adds `exists` function which tests whether a predicate holds for one or more elements in the array. ```sql > SELECT exists(array(1, 2, 3), x -> x % 2 == 0); true ``` ## How was this patch tested? Added tests. Closes #22052 from ueshin/issues/SPARK-25068/exists. Authored-by: Takuya UESHIN Signed-off-by: Xiao Li --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/higherOrderFunctions.scala | 47 +++++++++ .../HigherOrderFunctionsSuite.scala | 37 +++++++ .../inputs/higher-order-functions.sql | 6 ++ .../results/higher-order-functions.sql.out | 18 ++++ .../spark/sql/DataFrameFunctionsSuite.scala | 96 +++++++++++++++++++ 6 files changed, 205 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 390debd865eed..15543c909a271 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -444,6 +444,7 @@ object FunctionRegistry { expression[ArrayTransform]("transform"), expression[MapFilter]("map_filter"), expression[ArrayFilter]("filter"), + expression[ArrayExists]("exists"), expression[ArrayAggregate]("aggregate"), CreateStruct.registryEntry, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index d20673359129b..7f8203ab92213 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -356,6 +356,53 @@ case class ArrayFilter( override def prettyName: String = "filter" } +/** + * Tests whether a predicate holds for one or more elements in the array. + */ +@ExpressionDescription(usage = + "_FUNC_(expr, pred) - Tests whether a predicate holds for one or more elements in the array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 0); + true + """, + since = "2.4.0") +case class ArrayExists( + input: Expression, + function: Expression) + extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { + + override def nullable: Boolean = input.nullable + + override def dataType: DataType = BooleanType + + override def expectingFunctionType: AbstractDataType = BooleanType + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayExists = { + val elem = HigherOrderFunction.arrayArgumentType(input.dataType) + copy(function = f(function, elem :: Nil)) + } + + @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function + + override def nullSafeEval(inputRow: InternalRow, value: Any): Any = { + val arr = value.asInstanceOf[ArrayData] + val f = functionForEval + var exists = false + var i = 0 + while (i < arr.numElements && !exists) { + elementVar.value.set(arr.get(i, elementVar.dataType)) + if (f.eval(inputRow).asInstanceOf[Boolean]) { + exists = true + } + i += 1 + } + exists + } + + override def prettyName: String = "exists" +} + /** * Applies a binary operator to a start value and all elements in the array. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index f7e84b8757916..bc7d04c77fa9e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -202,6 +202,43 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper Seq(Seq(1, 3), null, Seq(5))) } + test("ArrayExists") { + def exists(expr: Expression, f: Expression => Expression): Expression = { + val at = expr.dataType.asInstanceOf[ArrayType] + ArrayExists(expr, createLambda(at.elementType, at.containsNull, f)) + } + + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) + val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) + + val isEven: Expression => Expression = x => x % 2 === 0 + val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 + + checkEvaluation(exists(ai0, isEven), true) + checkEvaluation(exists(ai0, isNullOrOdd), true) + checkEvaluation(exists(ai1, isEven), false) + checkEvaluation(exists(ai1, isNullOrOdd), true) + checkEvaluation(exists(ain, isEven), null) + checkEvaluation(exists(ain, isNullOrOdd), null) + + val as0 = + Literal.create(Seq("a0", "b1", "a2", "c3"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq(null, "b", "c"), ArrayType(StringType, containsNull = true)) + val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) + + val startsWithA: Expression => Expression = x => x.startsWith("a") + + checkEvaluation(exists(as0, startsWithA), true) + checkEvaluation(exists(as1, startsWithA), false) + checkEvaluation(exists(asn, startsWithA), null) + + val aai = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)), + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + checkEvaluation(transform(aai, ix => exists(ix, isNullOrOdd)), + Seq(true, null, true)) + } + test("ArrayAggregate") { val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql index 136396d9553db..ce1d0daa4d397 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -45,3 +45,9 @@ select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as -- Aggregate a null array select aggregate(cast(null as array), 0, (a, y) -> a + y + 1, a -> a + 2) as v; + +-- Check for element existence +select exists(ys, y -> y > 30) as v from nested; + +-- Check for element existence in a null array +select exists(cast(null as array), y -> y > 30) as v; diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out index e6f62f2e1bb67..e18abce3b617b 100644 --- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -145,3 +145,21 @@ select aggregate(cast(null as array), 0, (a, y) -> a + y + 1, a -> a + 2) a struct -- !query 14 output NULL + + +-- !query 15 +select exists(ys, y -> y > 30) as v from nested +-- !query 15 schema +struct +-- !query 15 output +false +true +true + + +-- !query 16 +select exists(cast(null as array), y -> y > 30) as v +-- !query 16 schema +struct +-- !query 16 output +NULL diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 24091f2128049..2c4238e69ad7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1996,6 +1996,102 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) } + test("exists function - array for primitive type not containing null") { + val df = Seq( + Seq(1, 9, 8, 7), + Seq(5, 9, 7), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { + checkAnswer(df.selectExpr("exists(i, x -> x % 2 == 0)"), + Seq( + Row(true), + Row(false), + Row(false), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeNotContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeNotContainsNull() + } + + test("exists function - array for primitive type containing null") { + val df = Seq[Seq[Integer]]( + Seq(1, 9, 8, null, 7), + Seq(5, null, null, 9, 7, null), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeContainsNull(): Unit = { + checkAnswer(df.selectExpr("exists(i, x -> x % 2 == 0)"), + Seq( + Row(true), + Row(false), + Row(false), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeContainsNull() + } + + test("exists function - array for non-primitive type") { + val df = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("s") + + def testNonPrimitiveType(): Unit = { + checkAnswer(df.selectExpr("exists(s, x -> x is null)"), + Seq( + Row(false), + Row(true), + Row(false), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testNonPrimitiveType() + } + + test("exists function - invalid") { + val df = Seq( + (Seq("c", "a", "b"), 1), + (Seq("b", null, "c", null), 2), + (Seq.empty, 3), + (null, 4) + ).toDF("s", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("exists(s, (x, y) -> x + y)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match")) + + val ex2 = intercept[AnalysisException] { + df.selectExpr("exists(i, x -> x)") + } + assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("exists(s, x -> x)") + } + assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + } + test("aggregate function - array for primitive type not containing null") { val df = Seq( Seq(1, 9, 8, 7), From 6c7bb575bf8b0bfc26f23e0ef449aaded77d3789 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 10 Aug 2018 09:12:17 +0800 Subject: [PATCH 1361/2461] [SPARK-24886][INFRA] Fix the testing script to increase timeout for Jenkins build (from 300m to 340m) ## What changes were proposed in this pull request? Currently, looks we hit the time limit time to time. Looks better increasing the time a bit. For instance, please see https://github.com/apache/spark/pull/21822 For clarification, current Jenkins timeout is 400m. This PR just proposes to fix the test script to increase it correspondingly. *This PR does not target to change the build configuration* ## How was this patch tested? Jenkins tests. Closes #21845 from HyukjinKwon/SPARK-24886. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- dev/run-tests-jenkins.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 3960a0de62530..16af97c7fbeae 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -181,8 +181,8 @@ def main(): short_commit_hash = ghprb_actual_commit[0:7] # format: http://linux.die.net/man/1/timeout - # must be less than the timeout configured on Jenkins (currently 350m) - tests_timeout = "300m" + # must be less than the timeout configured on Jenkins (currently 400m) + tests_timeout = "340m" # Array to capture all test names to run on the pull request. These tests are represented # by their file equivalents in the dev/tests/ directory. From bdd27961c870a3c443686cdbb6dd0eee3ad32012 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 10 Aug 2018 11:10:23 +0800 Subject: [PATCH 1362/2461] [SPARK-24251][SQL] Add analysis tests for AppendData. ## What changes were proposed in this pull request? This is a follow-up to #21305 that adds a test suite for AppendData analysis. This also fixes the following problems uncovered by these tests: * Incorrect order of data types passed to `canWrite` is fixed * The field check calls `canWrite` first to ensure all errors are found * `AppendData#resolved` must check resolution of the query's attributes * Column names are quoted to show empty names ## How was this patch tested? This PR adds a test suite for AppendData analysis. Closes #22043 from rdblue/SPARK-24251-add-append-data-analysis-tests. Authored-by: Ryan Blue Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/Analyzer.scala | 16 +- .../plans/logical/basicLogicalOperators.scala | 15 +- .../analysis/DataSourceV2AnalysisSuite.scala | 379 ++++++++++++++++++ 3 files changed, 397 insertions(+), 13 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a7cd96e46d114..d00b82d35d7d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2258,8 +2258,8 @@ class Analyzer( if (expected.size < query.output.size) { throw new AnalysisException( s"""Cannot write to '$tableName', too many data columns: - |Table columns: ${expected.map(_.name).mkString(", ")} - |Data columns: ${query.output.map(_.name).mkString(", ")}""".stripMargin) + |Table columns: ${expected.map(c => s"'${c.name}'").mkString(", ")} + |Data columns: ${query.output.map(c => s"'${c.name}'").mkString(", ")}""".stripMargin) } val errors = new mutable.ArrayBuffer[String]() @@ -2278,8 +2278,9 @@ class Analyzer( if (expected.size > query.output.size) { throw new AnalysisException( s"""Cannot write to '$tableName', not enough data columns: - |Table columns: ${expected.map(_.name).mkString(", ")} - |Data columns: ${query.output.map(_.name).mkString(", ")}""".stripMargin) + |Table columns: ${expected.map(c => s"'${c.name}'").mkString(", ")} + |Data columns: ${query.output.map(c => s"'${c.name}'").mkString(", ")}""" + .stripMargin) } query.output.zip(expected).flatMap { @@ -2301,12 +2302,15 @@ class Analyzer( queryExpr: NamedExpression, addError: String => Unit): Option[NamedExpression] = { + // run the type check first to ensure type errors are present + val canWrite = DataType.canWrite( + queryExpr.dataType, tableAttr.dataType, resolver, tableAttr.name, addError) + if (queryExpr.nullable && !tableAttr.nullable) { addError(s"Cannot write nullable values to non-null column '${tableAttr.name}'") None - } else if (!DataType.canWrite( - tableAttr.dataType, queryExpr.dataType, resolver, tableAttr.name, addError)) { + } else if (!canWrite) { None } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 0d31c6f6b9c49..a6631a8d444e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -363,13 +363,14 @@ case class AppendData( override def output: Seq[Attribute] = Seq.empty override lazy val resolved: Boolean = { - query.output.size == table.output.size && query.output.zip(table.output).forall { - case (inAttr, outAttr) => - // names and types must match, nullability must be compatible - inAttr.name == outAttr.name && - DataType.equalsIgnoreCompatibleNullability(outAttr.dataType, inAttr.dataType) && - (outAttr.nullable || !inAttr.nullable) - } + table.resolved && query.resolved && query.output.size == table.output.size && + query.output.zip(table.output).forall { + case (inAttr, outAttr) => + // names and types must match, nullability must be compatible + inAttr.name == outAttr.name && + DataType.equalsIgnoreCompatibleNullability(outAttr.dataType, inAttr.dataType) && + (outAttr.nullable || !inAttr.nullable) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala new file mode 100644 index 0000000000000..6c899b610ac5b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala @@ -0,0 +1,379 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import java.util.Locale + +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, UpCast} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LeafNode, LogicalPlan, Project} +import org.apache.spark.sql.types.{DoubleType, FloatType, StructField, StructType} + +case class TestRelation(output: Seq[AttributeReference]) extends LeafNode with NamedRelation { + override def name: String = "table-name" +} + +class DataSourceV2AnalysisSuite extends AnalysisTest { + val table = TestRelation(StructType(Seq( + StructField("x", FloatType), + StructField("y", FloatType))).toAttributes) + + val requiredTable = TestRelation(StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", FloatType, nullable = false))).toAttributes) + + val widerTable = TestRelation(StructType(Seq( + StructField("x", DoubleType), + StructField("y", DoubleType))).toAttributes) + + test("Append.byName: basic behavior") { + val query = TestRelation(table.schema.toAttributes) + + val parsedPlan = AppendData.byName(table, query) + + checkAnalysis(parsedPlan, parsedPlan) + assertResolved(parsedPlan) + } + + test("Append.byName: does not match by position") { + val query = TestRelation(StructType(Seq( + StructField("a", FloatType), + StructField("b", FloatType))).toAttributes) + + val parsedPlan = AppendData.byName(table, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot find data for output column", "'x'", "'y'")) + } + + test("Append.byName: case sensitive column resolution") { + val query = TestRelation(StructType(Seq( + StructField("X", FloatType), // doesn't match case! + StructField("y", FloatType))).toAttributes) + + val parsedPlan = AppendData.byName(table, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot find data for output column", "'x'"), + caseSensitive = true) + } + + test("Append.byName: case insensitive column resolution") { + val query = TestRelation(StructType(Seq( + StructField("X", FloatType), // doesn't match case! + StructField("y", FloatType))).toAttributes) + + val X = query.output.head + val y = query.output.last + + val parsedPlan = AppendData.byName(table, query) + val expectedPlan = AppendData.byName(table, + Project(Seq( + Alias(Cast(toLower(X), FloatType, Some(conf.sessionLocalTimeZone)), "x")(), + Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), + query)) + + assertNotResolved(parsedPlan) + checkAnalysis(parsedPlan, expectedPlan, caseSensitive = false) + assertResolved(expectedPlan) + } + + test("Append.byName: data columns are reordered by name") { + // out of order + val query = TestRelation(StructType(Seq( + StructField("y", FloatType), + StructField("x", FloatType))).toAttributes) + + val y = query.output.head + val x = query.output.last + + val parsedPlan = AppendData.byName(table, query) + val expectedPlan = AppendData.byName(table, + Project(Seq( + Alias(Cast(x, FloatType, Some(conf.sessionLocalTimeZone)), "x")(), + Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), + query)) + + assertNotResolved(parsedPlan) + checkAnalysis(parsedPlan, expectedPlan) + assertResolved(expectedPlan) + } + + test("Append.byName: fail nullable data written to required columns") { + val parsedPlan = AppendData.byName(requiredTable, table) + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot write nullable values to non-null column", "'x'", "'y'")) + } + + test("Append.byName: allow required data written to nullable columns") { + val parsedPlan = AppendData.byName(table, requiredTable) + assertResolved(parsedPlan) + checkAnalysis(parsedPlan, parsedPlan) + } + + test("Append.byName: missing required columns cause failure and are identified by name") { + // missing required field x + val query = TestRelation(StructType(Seq( + StructField("y", FloatType, nullable = false))).toAttributes) + + val parsedPlan = AppendData.byName(requiredTable, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot find data for output column", "'x'")) + } + + test("Append.byName: missing optional columns cause failure and are identified by name") { + // missing optional field x + val query = TestRelation(StructType(Seq( + StructField("y", FloatType))).toAttributes) + + val parsedPlan = AppendData.byName(table, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot find data for output column", "'x'")) + } + + test("Append.byName: fail canWrite check") { + val parsedPlan = AppendData.byName(table, widerTable) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", + "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) + } + + test("Append.byName: insert safe cast") { + val x = table.output.head + val y = table.output.last + + val parsedPlan = AppendData.byName(widerTable, table) + val expectedPlan = AppendData.byName(widerTable, + Project(Seq( + Alias(Cast(x, DoubleType, Some(conf.sessionLocalTimeZone)), "x")(), + Alias(Cast(y, DoubleType, Some(conf.sessionLocalTimeZone)), "y")()), + table)) + + assertNotResolved(parsedPlan) + checkAnalysis(parsedPlan, expectedPlan) + assertResolved(expectedPlan) + } + + test("Append.byName: fail extra data fields") { + val query = TestRelation(StructType(Seq( + StructField("x", FloatType), + StructField("y", FloatType), + StructField("z", FloatType))).toAttributes) + + val parsedPlan = AppendData.byName(table, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", "too many data columns", + "Table columns: 'x', 'y'", + "Data columns: 'x', 'y', 'z'")) + } + + test("Append.byName: multiple field errors are reported") { + val xRequiredTable = TestRelation(StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", DoubleType))).toAttributes) + + val query = TestRelation(StructType(Seq( + StructField("x", DoubleType), + StructField("b", FloatType))).toAttributes) + + val parsedPlan = AppendData.byName(xRequiredTable, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot safely cast", "'x'", "DoubleType to FloatType", + "Cannot write nullable values to non-null column", "'x'", + "Cannot find data for output column", "'y'")) + } + + test("Append.byPosition: basic behavior") { + val query = TestRelation(StructType(Seq( + StructField("a", FloatType), + StructField("b", FloatType))).toAttributes) + + val a = query.output.head + val b = query.output.last + + val parsedPlan = AppendData.byPosition(table, query) + val expectedPlan = AppendData.byPosition(table, + Project(Seq( + Alias(Cast(a, FloatType, Some(conf.sessionLocalTimeZone)), "x")(), + Alias(Cast(b, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), + query)) + + assertNotResolved(parsedPlan) + checkAnalysis(parsedPlan, expectedPlan, caseSensitive = false) + assertResolved(expectedPlan) + } + + test("Append.byPosition: data columns are not reordered") { + // out of order + val query = TestRelation(StructType(Seq( + StructField("y", FloatType), + StructField("x", FloatType))).toAttributes) + + val y = query.output.head + val x = query.output.last + + val parsedPlan = AppendData.byPosition(table, query) + val expectedPlan = AppendData.byPosition(table, + Project(Seq( + Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "x")(), + Alias(Cast(x, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), + query)) + + assertNotResolved(parsedPlan) + checkAnalysis(parsedPlan, expectedPlan) + assertResolved(expectedPlan) + } + + test("Append.byPosition: fail nullable data written to required columns") { + val parsedPlan = AppendData.byPosition(requiredTable, table) + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot write nullable values to non-null column", "'x'", "'y'")) + } + + test("Append.byPosition: allow required data written to nullable columns") { + val parsedPlan = AppendData.byPosition(table, requiredTable) + assertResolved(parsedPlan) + checkAnalysis(parsedPlan, parsedPlan) + } + + test("Append.byPosition: missing required columns cause failure") { + // missing optional field x + val query = TestRelation(StructType(Seq( + StructField("y", FloatType, nullable = false))).toAttributes) + + val parsedPlan = AppendData.byPosition(requiredTable, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", "not enough data columns", + "Table columns: 'x', 'y'", + "Data columns: 'y'")) + } + + test("Append.byPosition: missing optional columns cause failure") { + // missing optional field x + val query = TestRelation(StructType(Seq( + StructField("y", FloatType))).toAttributes) + + val parsedPlan = AppendData.byPosition(table, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", "not enough data columns", + "Table columns: 'x', 'y'", + "Data columns: 'y'")) + } + + test("Append.byPosition: fail canWrite check") { + val widerTable = TestRelation(StructType(Seq( + StructField("a", DoubleType), + StructField("b", DoubleType))).toAttributes) + + val parsedPlan = AppendData.byPosition(table, widerTable) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", + "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) + } + + test("Append.byPosition: insert safe cast") { + val widerTable = TestRelation(StructType(Seq( + StructField("a", DoubleType), + StructField("b", DoubleType))).toAttributes) + + val x = table.output.head + val y = table.output.last + + val parsedPlan = AppendData.byPosition(widerTable, table) + val expectedPlan = AppendData.byPosition(widerTable, + Project(Seq( + Alias(Cast(x, DoubleType, Some(conf.sessionLocalTimeZone)), "a")(), + Alias(Cast(y, DoubleType, Some(conf.sessionLocalTimeZone)), "b")()), + table)) + + assertNotResolved(parsedPlan) + checkAnalysis(parsedPlan, expectedPlan) + assertResolved(expectedPlan) + } + + test("Append.byPosition: fail extra data fields") { + val query = TestRelation(StructType(Seq( + StructField("a", FloatType), + StructField("b", FloatType), + StructField("c", FloatType))).toAttributes) + + val parsedPlan = AppendData.byName(table, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write", "'table-name'", "too many data columns", + "Table columns: 'x', 'y'", + "Data columns: 'a', 'b', 'c'")) + } + + test("Append.byPosition: multiple field errors are reported") { + val xRequiredTable = TestRelation(StructType(Seq( + StructField("x", FloatType, nullable = false), + StructField("y", DoubleType))).toAttributes) + + val query = TestRelation(StructType(Seq( + StructField("x", DoubleType), + StructField("b", FloatType))).toAttributes) + + val parsedPlan = AppendData.byPosition(xRequiredTable, query) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot write nullable values to non-null column", "'x'", + "Cannot safely cast", "'x'", "DoubleType to FloatType")) + } + + def assertNotResolved(logicalPlan: LogicalPlan): Unit = { + assert(!logicalPlan.resolved, s"Plan should not be resolved: $logicalPlan") + } + + def assertResolved(logicalPlan: LogicalPlan): Unit = { + assert(logicalPlan.resolved, s"Plan should be resolved: $logicalPlan") + } + + def toLower(attr: AttributeReference): AttributeReference = { + AttributeReference(attr.name.toLowerCase(Locale.ROOT), attr.dataType)(attr.exprId) + } +} From 0cea9e3cd0a92799bdcc0f9bc2cf96259c343a30 Mon Sep 17 00:00:00 2001 From: Brian Lindblom Date: Fri, 10 Aug 2018 03:35:29 +0000 Subject: [PATCH 1363/2461] [SPARK-24855][SQL][EXTERNAL] Built-in AVRO support should support specified schema on write ## What changes were proposed in this pull request? Allows `avroSchema` option to be specified on write, allowing a user to specify a schema in cases where this is required. A trivial use case is reading in an avro dataset, making some small adjustment to a column or columns and writing out using the same schema. Implicit schema creation from SQL Struct results in a schema that while for the most part, is functionally similar, is not necessarily compatible. Allows `fixed` Field type to be utilized for records of specified `avroSchema` ## How was this patch tested? Unit tests in AvroSuite are extended to test this with enum and fixed types. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #21847 from lindblombr/specify_schema_on_write. Lead-authored-by: Brian Lindblom Co-authored-by: DB Tsai Signed-off-by: DB Tsai --- .../spark/sql/avro/AvroFileFormat.scala | 6 +- .../spark/sql/avro/AvroSerializer.scala | 40 ++- .../org/apache/spark/sql/avro/AvroSuite.scala | 228 +++++++++++++++++- 3 files changed, 257 insertions(+), 17 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index 6ffcf375af678..6df23c93e4c54 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -113,8 +113,10 @@ private[avro] class AvroFileFormat extends FileFormat options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { val parsedOptions = new AvroOptions(options, spark.sessionState.newHadoopConf()) - val outputAvroSchema = SchemaConverters.toAvroType(dataSchema, nullable = false, - parsedOptions.recordName, parsedOptions.recordNamespace, parsedOptions.outputTimestampType) + val outputAvroSchema: Schema = parsedOptions.schema + .map(new Schema.Parser().parse) + .getOrElse(SchemaConverters.toAvroType(dataSchema, nullable = false, + parsedOptions.recordName, parsedOptions.recordNamespace)) AvroJob.setOutputKeySchema(job, outputAvroSchema) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala index 988582698d826..216c52a5cfd26 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -23,8 +23,8 @@ import scala.collection.JavaConverters._ import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis} import org.apache.avro.Schema -import org.apache.avro.Schema.Type.NULL -import org.apache.avro.generic.GenericData.Record +import org.apache.avro.Schema.Type +import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record} import org.apache.avro.util.Utf8 import org.apache.spark.sql.catalyst.InternalRow @@ -87,10 +87,36 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: (getter, ordinal) => getter.getDouble(ordinal) case d: DecimalType => (getter, ordinal) => getter.getDecimal(ordinal, d.precision, d.scale).toString - case StringType => - (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes) - case BinaryType => - (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal)) + case StringType => avroType.getType match { + case Type.ENUM => + import scala.collection.JavaConverters._ + val enumSymbols: Set[String] = avroType.getEnumSymbols.asScala.toSet + (getter, ordinal) => + val data = getter.getUTF8String(ordinal).toString + if (!enumSymbols.contains(data)) { + throw new IncompatibleSchemaException( + "Cannot write \"" + data + "\" since it's not defined in enum \"" + + enumSymbols.mkString("\", \"") + "\"") + } + new EnumSymbol(avroType, data) + case _ => + (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes) + } + case BinaryType => avroType.getType match { + case Type.FIXED => + val size = avroType.getFixedSize() + (getter, ordinal) => + val data: Array[Byte] = getter.getBinary(ordinal) + if (data.length != size) { + throw new IncompatibleSchemaException( + s"Cannot write ${data.length} ${if (data.length > 1) "bytes" else "byte"} of " + + "binary data into FIXED Type with size of " + + s"$size ${if (size > 1) "bytes" else "byte"}") + } + new Fixed(avroType, data) + case _ => + (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal)) + } case DateType => (getter, ordinal) => getter.getInt(ordinal) case TimestampType => avroType.getLogicalType match { @@ -182,7 +208,7 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: // avro uses union to represent nullable type. val fields = avroType.getTypes.asScala assert(fields.length == 2) - val actualType = fields.filter(_.getType != NULL) + val actualType = fields.filter(_.getType != Type.NULL) assert(actualType.length == 1) actualType.head } else { diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 47995bb39a05e..ada9980e65b1c 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -32,6 +32,7 @@ import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWri import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} import org.apache.commons.io.FileUtils +import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.DataSource @@ -100,6 +101,25 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { checkAnswer(newEntries, originalEntries) } + def checkAvroSchemaEquals(avroSchema: String, expectedAvroSchema: String): Unit = { + assert(new Schema.Parser().parse(avroSchema) == + new Schema.Parser().parse(expectedAvroSchema)) + } + + def getAvroSchemaStringFromFiles(filePath: String): String = { + new DataFileReader({ + val file = new File(filePath) + if (file.isFile) { + file + } else { + file.listFiles() + .filter(_.isFile) + .filter(_.getName.endsWith("avro")) + .head + } + }, new GenericDatumReader[Any]()).getSchema.toString(false) + } + test("resolve avro data source") { Seq("avro", "com.databricks.spark.avro").foreach { provider => assert(DataSource.lookupDataSource(provider, spark.sessionState.conf) === @@ -471,7 +491,6 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } """ val df = spark.read.format("avro").option("avroSchema", avroSchema).load(timestampAvro) - checkAnswer(df, expected) } @@ -773,6 +792,205 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(result === Row("foo")) } + test("support user provided avro schema for writing nullable enum type") { + withTempPath { tempDir => + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name": "enum", + | "type": [{ "type": "enum", + | "name": "Suit", + | "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"] + | }, "null"] + | }] + |} + """.stripMargin + + val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row("SPADES"), Row(null), Row("HEARTS"), Row("DIAMONDS"), + Row(null), Row("CLUBS"), Row("HEARTS"), Row("SPADES"))), + StructType(Seq(StructField("Suit", StringType, true)))) + + val tempSaveDir = s"$tempDir/save/" + + df.write.format("avro").option("avroSchema", avroSchema).save(tempSaveDir) + + checkAnswer(df, spark.read.format("avro").load(tempSaveDir)) + checkAvroSchemaEquals(avroSchema, getAvroSchemaStringFromFiles(tempSaveDir)) + + // Writing df containing data not in the enum will throw an exception + val message = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row("SPADES"), Row("NOT-IN-ENUM"), Row("HEARTS"), Row("DIAMONDS"))), + StructType(Seq(StructField("Suit", StringType, true)))) + .write.format("avro").option("avroSchema", avroSchema) + .save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write \"NOT-IN-ENUM\" since it's not defined in enum")) + } + } + + test("support user provided avro schema for writing non-nullable enum type") { + withTempPath { tempDir => + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name": "enum", + | "type": { "type": "enum", + | "name": "Suit", + | "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"] + | } + | }] + |} + """.stripMargin + + val dfWithNull = spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row("SPADES"), Row(null), Row("HEARTS"), Row("DIAMONDS"), + Row(null), Row("CLUBS"), Row("HEARTS"), Row("SPADES"))), + StructType(Seq(StructField("Suit", StringType, true)))) + + val df = spark.createDataFrame(dfWithNull.na.drop().rdd, + StructType(Seq(StructField("Suit", StringType, false)))) + + val tempSaveDir = s"$tempDir/save/" + + df.write.format("avro").option("avroSchema", avroSchema).save(tempSaveDir) + + checkAnswer(df, spark.read.format("avro").load(tempSaveDir)) + checkAvroSchemaEquals(avroSchema, getAvroSchemaStringFromFiles(tempSaveDir)) + + // Writing df containing nulls without using avro union type will + // throw an exception as avro uses union type to handle null. + val message1 = intercept[SparkException] { + dfWithNull.write.format("avro") + .option("avroSchema", avroSchema).save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message1.contains("org.apache.avro.AvroRuntimeException: Not a union:")) + + // Writing df containing data not in the enum will throw an exception + val message2 = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row("SPADES"), Row("NOT-IN-ENUM"), Row("HEARTS"), Row("DIAMONDS"))), + StructType(Seq(StructField("Suit", StringType, false)))) + .write.format("avro").option("avroSchema", avroSchema) + .save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message2.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write \"NOT-IN-ENUM\" since it's not defined in enum")) + } + } + + test("support user provided avro schema for writing nullable fixed type") { + withTempPath { tempDir => + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name": "fixed2", + | "type": [{ "type": "fixed", + | "size": 2, + | "name": "fixed2" + | }, "null"] + | }] + |} + """.stripMargin + + val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192, 168).map(_.toByte)), Row(null))), + StructType(Seq(StructField("fixed2", BinaryType, true)))) + + val tempSaveDir = s"$tempDir/save/" + + df.write.format("avro").option("avroSchema", avroSchema).save(tempSaveDir) + + checkAnswer(df, spark.read.format("avro").load(tempSaveDir)) + checkAvroSchemaEquals(avroSchema, getAvroSchemaStringFromFiles(tempSaveDir)) + + // Writing df containing binary data that doesn't fit FIXED size will throw an exception + val message1 = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192, 168, 1).map(_.toByte)))), + StructType(Seq(StructField("fixed2", BinaryType, true)))) + .write.format("avro").option("avroSchema", avroSchema) + .save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message1.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write 3 bytes of binary data into FIXED Type with size of 2 bytes")) + + // Writing df containing binary data that doesn't fit FIXED size will throw an exception + val message2 = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192).map(_.toByte)))), + StructType(Seq(StructField("fixed2", BinaryType, true)))) + .write.format("avro").option("avroSchema", avroSchema) + .save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message2.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write 1 byte of binary data into FIXED Type with size of 2 bytes")) + } + } + + test("support user provided avro schema for writing non-nullable fixed type") { + withTempPath { tempDir => + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name": "fixed2", + | "type": { "type": "fixed", + | "size": 2, + | "name": "fixed2" + | } + | }] + |} + """.stripMargin + + val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192, 168).map(_.toByte)), Row(Array(1, 1).map(_.toByte)))), + StructType(Seq(StructField("fixed2", BinaryType, false)))) + + val tempSaveDir = s"$tempDir/save/" + + df.write.format("avro").option("avroSchema", avroSchema).save(tempSaveDir) + + checkAnswer(df, spark.read.format("avro").load(tempSaveDir)) + checkAvroSchemaEquals(avroSchema, getAvroSchemaStringFromFiles(tempSaveDir)) + + // Writing df containing binary data that doesn't fit FIXED size will throw an exception + val message1 = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192, 168, 1).map(_.toByte)))), + StructType(Seq(StructField("fixed2", BinaryType, false)))) + .write.format("avro").option("avroSchema", avroSchema) + .save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message1.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write 3 bytes of binary data into FIXED Type with size of 2 bytes")) + + // Writing df containing binary data that doesn't fit FIXED size will throw an exception + val message2 = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192).map(_.toByte)))), + StructType(Seq(StructField("fixed2", BinaryType, false)))) + .write.format("avro").option("avroSchema", avroSchema) + .save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message2.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write 1 byte of binary data into FIXED Type with size of 2 bytes")) + } + } + test("reading from invalid path throws exception") { // Directory given has no avro files @@ -936,13 +1154,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { withTempPath { dir => val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1"))))) writeDf.write.format("avro").save(dir.toString) - val file = new File(dir.toString) - .listFiles() - .filter(_.isFile) - .filter(_.getName.endsWith("avro")) - .head - val reader = new DataFileReader(file, new GenericDatumReader[Any]()) - val schema = reader.getSchema.toString() + val schema = getAvroSchemaStringFromFiles(dir.toString) assert(schema.contains("\"namespace\":\"topLevelRecord\"")) assert(schema.contains("\"namespace\":\"topLevelRecord.data\"")) assert(schema.contains("\"namespace\":\"topLevelRecord.data.data\"")) From ab1029fb8aae586e3af1238048e8b3dcfeb096f4 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 10 Aug 2018 15:41:59 +0900 Subject: [PATCH 1364/2461] [SPARK-23912][SQL][FOLLOWUP] Refactor ArrayDistinct ## What changes were proposed in this pull request? This PR simplified code generation for `ArrayDistinct`. #21966 enabled code generation only if the type can be specialized by the hash set. This PR follows this strategy. Optimization of null handling will be implemented in #21912. ## How was this patch tested? Existing UTs Closes #22044 from kiszk/SPARK-23912-follow. Authored-by: Kazuaki Ishizaki Signed-off-by: Takuya UESHIN --- .../expressions/collectionOperations.scala | 215 +++++------------- 1 file changed, 61 insertions(+), 154 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b37fdc6d10fd1..5e3449d5631b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3410,6 +3410,28 @@ case class ArrayDistinct(child: Expression) case _ => false } + @transient protected lazy val canUseSpecializedHashSet = elementType match { + case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true + case _ => false + } + + @transient protected lazy val (hsPostFix, hsTypeName) = { + val ptName = CodeGenerator.primitiveTypeName(elementType) + elementType match { + // we cast byte/short to int when writing to the hash set. + case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int") + case LongType => ("$mcJ$sp", ptName) + case FloatType => ("$mcF$sp", ptName) + case DoubleType => ("$mcD$sp", ptName) + } + } + + // we cast byte/short to int when writing to the hash set. + @transient protected lazy val hsValueCast = elementType match { + case ByteType | ShortType => "(int) " + case _ => "" + } + override def nullSafeEval(array: Any): Any = { val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) if (elementTypeSupportEquals) { @@ -3442,17 +3464,15 @@ case class ArrayDistinct(child: Expression) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (array) => { - val i = ctx.freshName("i") - val j = ctx.freshName("j") - val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray") - val getValue1 = CodeGenerator.getValue(array, elementType, i) - val getValue2 = CodeGenerator.getValue(array, elementType, j) - val foundNullElement = ctx.freshName("foundNullElement") - val openHashSet = classOf[OpenHashSet[_]].getName - val hs = ctx.freshName("hs") - val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" - if (elementTypeSupportEquals) { + if (canUseSpecializedHashSet) { + nullSafeCodeGen(ctx, ev, (array) => { + val i = ctx.freshName("i") + val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray") + val foundNullElement = ctx.freshName("foundNullElement") + val openHashSet = classOf[OpenHashSet[_]].getName + val hs = ctx.freshName("hs") + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" + val getValue = CodeGenerator.getValue(array, elementType, i) s""" |int $sizeOfDistinctArray = 0; |boolean $foundNullElement = false; @@ -3461,53 +3481,26 @@ case class ArrayDistinct(child: Expression) | if ($array.isNullAt($i)) { | $foundNullElement = true; | } else { - | $hs.add($getValue1); + | $hs.add$hsPostFix($hsValueCast$getValue); | } |} |$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 1 : 0); |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} """.stripMargin - } else { - s""" - |int $sizeOfDistinctArray = 0; - |boolean $foundNullElement = false; - |for (int $i = 0; $i < $array.numElements(); $i ++) { - | if ($array.isNullAt($i)) { - | if (!($foundNullElement)) { - | $sizeOfDistinctArray = $sizeOfDistinctArray + 1; - | $foundNullElement = true; - | } - | } else { - | int $j; - | for ($j = 0; $j < $i; $j ++) { - | if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)}) { - | break; - | } - | } - | if ($i == $j) { - | $sizeOfDistinctArray = $sizeOfDistinctArray + 1; - | } - | } - |} - | - |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} - """.stripMargin - } - }) + }) + } else { + nullSafeCodeGen(ctx, ev, (array) => { + val expr = ctx.addReferenceObj("arrayDistinctExpr", this) + s"${ev.value} = (ArrayData)$expr.nullSafeEval($array);" + }) + } } private def setNull( - isPrimitive: Boolean, foundNullElement: String, distinctArray: String, pos: String): String = { - val setNullValue = - if (!isPrimitive) { - s"$distinctArray[$pos] = null"; - } else { - s"$distinctArray.setNullAt($pos)"; - } - + val setNullValue = s"$distinctArray.setNullAt($pos)" s""" |if (!($foundNullElement)) { | $setNullValue; @@ -3517,57 +3510,16 @@ case class ArrayDistinct(child: Expression) """.stripMargin } - private def setNotNullValue(isPrimitive: Boolean, - distinctArray: String, - pos: String, - getValue1: String, - primitiveValueTypeName: String): String = { - if (!isPrimitive) { - s"$distinctArray[$pos] = $getValue1"; - } else { - s"$distinctArray.set$primitiveValueTypeName($pos, $getValue1)"; - } - } - - private def setValueForFastEval( - isPrimitive: Boolean, + private def setValue( hs: String, distinctArray: String, pos: String, getValue1: String, primitiveValueTypeName: String): String = { - val setValue = setNotNullValue(isPrimitive, - distinctArray, pos, getValue1, primitiveValueTypeName) s""" - |if (!($hs.contains($getValue1))) { - | $hs.add($getValue1); - | $setValue; - | $pos = $pos + 1; - |} - """.stripMargin - } - - private def setValueForBruteForceEval( - isPrimitive: Boolean, - i: String, - j: String, - inputArray: String, - distinctArray: String, - pos: String, - getValue1: String, - isEqual: String, - primitiveValueTypeName: String): String = { - val setValue = setNotNullValue(isPrimitive, - distinctArray, pos, getValue1, primitiveValueTypeName) - s""" - |int $j; - |for ($j = 0; $j < $i; $j ++) { - | if (!$inputArray.isNullAt($j) && $isEqual) { - | break; - | } - |} - |if ($i == $j) { - | $setValue; + |if (!($hs.contains$hsPostFix($hsValueCast$getValue1))) { + | $hs.add$hsPostFix($hsValueCast$getValue1); + | $distinctArray.set$primitiveValueTypeName($pos, $getValue1); | $pos = $pos + 1; |} """.stripMargin @@ -3580,73 +3532,28 @@ case class ArrayDistinct(child: Expression) size: String): String = { val distinctArray = ctx.freshName("distinctArray") val i = ctx.freshName("i") - val j = ctx.freshName("j") val pos = ctx.freshName("pos") val getValue1 = CodeGenerator.getValue(inputArray, elementType, i) - val getValue2 = CodeGenerator.getValue(inputArray, elementType, j) - val isEqual = ctx.genEqual(elementType, getValue1, getValue2) val foundNullElement = ctx.freshName("foundNullElement") val hs = ctx.freshName("hs") val openHashSet = classOf[OpenHashSet[_]].getName - if (!CodeGenerator.isPrimitiveType(elementType)) { - val arrayClass = classOf[GenericArrayData].getName - val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" - val setNullForNonPrimitive = - setNull(false, foundNullElement, distinctArray, pos) - if (elementTypeSupportEquals) { - val setValueForFast = setValueForFastEval(false, hs, distinctArray, pos, getValue1, "") - s""" - |int $pos = 0; - |Object[] $distinctArray = new Object[$size]; - |boolean $foundNullElement = false; - |$openHashSet $hs = new $openHashSet($classTag); - |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { - | if ($inputArray.isNullAt($i)) { - | $setNullForNonPrimitive; - | } else { - | $setValueForFast; - | } - |} - |${ev.value} = new $arrayClass($distinctArray); - """.stripMargin - } else { - val setValueForBruteForce = setValueForBruteForceEval( - false, i, j, inputArray, distinctArray, pos, getValue1, isEqual, "") - s""" - |int $pos = 0; - |Object[] $distinctArray = new Object[$size]; - |boolean $foundNullElement = false; - |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { - | if ($inputArray.isNullAt($i)) { - | $setNullForNonPrimitive; - | } else { - | $setValueForBruteForce; - | } - |} - |${ev.value} = new $arrayClass($distinctArray); - """.stripMargin - } - } else { - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - val setNullForPrimitive = setNull(true, foundNullElement, distinctArray, pos) - val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$primitiveValueTypeName()" - val setValueForFast = - setValueForFastEval(true, hs, distinctArray, pos, getValue1, primitiveValueTypeName) - s""" - |${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")} - |int $pos = 0; - |boolean $foundNullElement = false; - |$openHashSet $hs = new $openHashSet($classTag); - |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { - | if ($inputArray.isNullAt($i)) { - | $setNullForPrimitive; - | } else { - | $setValueForFast; - | } - |} - |${ev.value} = $distinctArray; - """.stripMargin - } + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" + + s""" + |${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")} + |int $pos = 0; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | ${setNull(foundNullElement, distinctArray, pos)} + | } else { + | ${setValue(hs, distinctArray, pos, getValue1, primitiveValueTypeName)} + | } + |} + |${ev.value} = $distinctArray; + """.stripMargin } override def prettyName: String = "array_distinct" From 9abe09bfc18580233acad676d1241684c7d8768d Mon Sep 17 00:00:00 2001 From: Arun Mahadevan Date: Fri, 10 Aug 2018 15:53:31 +0800 Subject: [PATCH 1365/2461] [SPARK-24127][SS] Continuous text socket source ## What changes were proposed in this pull request? Support for text socket stream in spark structured streaming "continuous" mode. This is roughly based on the idea of ContinuousMemoryStream where the executor queries the data from driver over an RPC endpoint. This makes it possible to create Structured streaming continuous pipeline to ingest data via "nc" and run examples. ## How was this patch tested? Unit test and ran spark examples in structured streaming continuous mode. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #21199 from arunmahadevan/SPARK-24127. Authored-by: Arun Mahadevan Signed-off-by: hyukjinkwon --- .../streaming/ContinuousRecordEndpoint.scala | 69 +++++ .../ContinuousTextSocketSource.scala | 292 ++++++++++++++++++ .../sources/ContinuousMemoryStream.scala | 32 +- .../execution/streaming/sources/socket.scala | 25 +- .../sources/TextSocketStreamSuite.scala | 98 +++++- 5 files changed, 482 insertions(+), 34 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousRecordEndpoint.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousRecordEndpoint.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousRecordEndpoint.scala new file mode 100644 index 0000000000000..c9c2ebc875f28 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousRecordEndpoint.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.SparkEnv +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset + +case class ContinuousRecordPartitionOffset(partitionId: Int, offset: Int) extends PartitionOffset +case class GetRecord(offset: ContinuousRecordPartitionOffset) + +/** + * A RPC end point for continuous readers to poll for + * records from the driver. + * + * @param buckets the data buckets. Each bucket contains a sequence of items to be + * returned for a partition. The number of buckets should be equal to + * to the number of partitions. + * @param lock a lock object for locking the buckets for read + */ +class ContinuousRecordEndpoint(buckets: Seq[Seq[Any]], lock: Object) + extends ThreadSafeRpcEndpoint { + + private var startOffsets: Seq[Int] = List.fill(buckets.size)(0) + + /** + * Sets the start offset. + * + * @param offsets the base offset per partition to be used + * while retrieving the data in {#receiveAndReply}. + */ + def setStartOffsets(offsets: Seq[Int]): Unit = { + lock.synchronized { + startOffsets = offsets + } + } + + override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv + + /** + * Process messages from `RpcEndpointRef.ask`. If receiving a unmatched message, + * `SparkException` will be thrown and sent to `onError`. + */ + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case GetRecord(ContinuousRecordPartitionOffset(partitionId, offset)) => + lock.synchronized { + val bufOffset = offset - startOffsets(partitionId) + val buf = buckets(partitionId) + val record = if (buf.size <= bufOffset) None else Some(buf(bufOffset)) + + context.reply(record.map(InternalRow(_))) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala new file mode 100644 index 0000000000000..1dbdfd558de48 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala @@ -0,0 +1,292 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.io.{BufferedReader, InputStreamReader, IOException} +import java.net.Socket +import java.sql.Timestamp +import java.util.{Calendar, List => JList} +import javax.annotation.concurrent.GuardedBy + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ListBuffer + +import org.json4s.{DefaultFormats, NoTypeHints} +import org.json4s.jackson.Serialization + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.streaming.{ContinuousRecordEndpoint, ContinuousRecordPartitionOffset, GetRecord} +import org.apache.spark.sql.execution.streaming.sources.TextSocketReader +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsDeprecatedScanRow} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} +import org.apache.spark.util.RpcUtils + + +/** + * A ContinuousReader that reads text lines through a TCP socket, designed only for tutorials and + * debugging. This ContinuousReader will *not* work in production applications due to multiple + * reasons, including no support for fault recovery. + * + * The driver maintains a socket connection to the host-port, keeps the received messages in + * buckets and serves the messages to the executors via a RPC endpoint. + */ +class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousReader with Logging { + implicit val defaultFormats: DefaultFormats = DefaultFormats + + private val host: String = options.get("host").get() + private val port: Int = options.get("port").get().toInt + + assert(SparkSession.getActiveSession.isDefined) + private val spark = SparkSession.getActiveSession.get + private val numPartitions = spark.sparkContext.defaultParallelism + + @GuardedBy("this") + private var socket: Socket = _ + + @GuardedBy("this") + private var readThread: Thread = _ + + @GuardedBy("this") + private val buckets = Seq.fill(numPartitions)(new ListBuffer[(String, Timestamp)]) + + @GuardedBy("this") + private var currentOffset: Int = -1 + + private var startOffset: TextSocketOffset = _ + + private val recordEndpoint = new ContinuousRecordEndpoint(buckets, this) + @volatile private var endpointRef: RpcEndpointRef = _ + + initialize() + + override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { + assert(offsets.length == numPartitions) + val offs = offsets + .map(_.asInstanceOf[ContinuousRecordPartitionOffset]) + .sortBy(_.partitionId) + .map(_.offset) + .toList + TextSocketOffset(offs) + } + + override def deserializeOffset(json: String): Offset = { + TextSocketOffset(Serialization.read[List[Int]](json)) + } + + override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { + this.startOffset = offset + .orElse(TextSocketOffset(List.fill(numPartitions)(0))) + .asInstanceOf[TextSocketOffset] + recordEndpoint.setStartOffsets(startOffset.offsets) + } + + override def getStartOffset: Offset = startOffset + + override def readSchema(): StructType = { + if (includeTimestamp) { + TextSocketReader.SCHEMA_TIMESTAMP + } else { + TextSocketReader.SCHEMA_REGULAR + } + } + + override def planInputPartitions(): JList[InputPartition[InternalRow]] = { + + val endpointName = s"TextSocketContinuousReaderEndpoint-${java.util.UUID.randomUUID()}" + endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) + + val offsets = startOffset match { + case off: TextSocketOffset => off.offsets + case off => + throw new IllegalArgumentException( + s"invalid offset type ${off.getClass} for TextSocketContinuousReader") + } + + if (offsets.size != numPartitions) { + throw new IllegalArgumentException( + s"The previous run contained ${offsets.size} partitions, but" + + s" $numPartitions partitions are currently configured. The numPartitions option" + + " cannot be changed.") + } + + startOffset.offsets.zipWithIndex.map { + case (offset, i) => + TextSocketContinuousInputPartition( + endpointName, i, offset, includeTimestamp): InputPartition[InternalRow] + }.asJava + + } + + override def commit(end: Offset): Unit = synchronized { + val endOffset = end match { + case off: TextSocketOffset => off + case _ => throw new IllegalArgumentException(s"TextSocketContinuousReader.commit()" + + s"received an offset ($end) that did not originate with an instance of this class") + } + + endOffset.offsets.zipWithIndex.foreach { + case (offset, partition) => + val max = startOffset.offsets(partition) + buckets(partition).size + if (offset > max) { + throw new IllegalStateException("Invalid offset " + offset + " to commit" + + " for partition " + partition + ". Max valid offset: " + max) + } + val n = offset - startOffset.offsets(partition) + buckets(partition).trimStart(n) + } + startOffset = endOffset + recordEndpoint.setStartOffsets(startOffset.offsets) + } + + /** Stop this source. */ + override def stop(): Unit = synchronized { + if (socket != null) { + try { + // Unfortunately, BufferedReader.readLine() cannot be interrupted, so the only way to + // stop the readThread is to close the socket. + socket.close() + } catch { + case e: IOException => + } + socket = null + } + if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef) + } + + private def initialize(): Unit = synchronized { + socket = new Socket(host, port) + val reader = new BufferedReader(new InputStreamReader(socket.getInputStream)) + // Thread continuously reads from a socket and inserts data into buckets + readThread = new Thread(s"TextSocketContinuousReader($host, $port)") { + setDaemon(true) + + override def run(): Unit = { + try { + while (true) { + val line = reader.readLine() + if (line == null) { + // End of file reached + logWarning(s"Stream closed by $host:$port") + return + } + TextSocketContinuousReader.this.synchronized { + currentOffset += 1 + val newData = (line, + Timestamp.valueOf( + TextSocketReader.DATE_FORMAT.format(Calendar.getInstance().getTime())) + ) + buckets(currentOffset % numPartitions) += newData + } + } + } catch { + case e: IOException => + } + } + } + + readThread.start() + } + + override def toString: String = s"TextSocketContinuousReader[host: $host, port: $port]" + + private def includeTimestamp: Boolean = options.getBoolean("includeTimestamp", false) + +} + +/** + * Continuous text socket input partition. + */ +case class TextSocketContinuousInputPartition( + driverEndpointName: String, + partitionId: Int, + startOffset: Int, + includeTimestamp: Boolean) +extends InputPartition[InternalRow] { + + override def createPartitionReader(): InputPartitionReader[InternalRow] = + new TextSocketContinuousInputPartitionReader(driverEndpointName, partitionId, startOffset, + includeTimestamp) +} + +/** + * Continuous text socket input partition reader. + * + * Polls the driver endpoint for new records. + */ +class TextSocketContinuousInputPartitionReader( + driverEndpointName: String, + partitionId: Int, + startOffset: Int, + includeTimestamp: Boolean) + extends ContinuousInputPartitionReader[InternalRow] { + + private val endpoint = RpcUtils.makeDriverRef( + driverEndpointName, + SparkEnv.get.conf, + SparkEnv.get.rpcEnv) + + private var currentOffset = startOffset + private var current: Option[InternalRow] = None + + override def next(): Boolean = { + try { + current = getRecord + while (current.isEmpty) { + Thread.sleep(100) + current = getRecord + } + currentOffset += 1 + } catch { + case _: InterruptedException => + // Someone's trying to end the task; just let them. + return false + } + true + } + + override def get(): InternalRow = { + current.get + } + + override def close(): Unit = {} + + override def getOffset: PartitionOffset = + ContinuousRecordPartitionOffset(partitionId, currentOffset) + + private def getRecord: Option[InternalRow] = + endpoint.askSync[Option[InternalRow]](GetRecord( + ContinuousRecordPartitionOffset(partitionId, currentOffset))).map(rec => + if (includeTimestamp) { + rec + } else { + InternalRow(rec.get(0, TextSocketReader.SCHEMA_TIMESTAMP) + .asInstanceOf[(String, Timestamp)]._1) + } + ) +} + +case class TextSocketOffset(offsets: List[Int]) extends Offset { + private implicit val formats = Serialization.formats(NoTypeHints) + override def json: String = Serialization.write(offsets) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index 711f0941fe731..4a32217f149bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -33,7 +33,6 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeR import org.apache.spark.sql.{Encoder, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions} import org.apache.spark.sql.sources.v2.reader.InputPartition import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} @@ -63,7 +62,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa @GuardedBy("this") private var startOffset: ContinuousMemoryStreamOffset = _ - private val recordEndpoint = new RecordEndpoint() + private val recordEndpoint = new ContinuousRecordEndpoint(records, this) @volatile private var endpointRef: RpcEndpointRef = _ def addData(data: TraversableOnce[A]): Offset = synchronized { @@ -94,7 +93,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa override def mergeOffsets(offsets: Array[PartitionOffset]): ContinuousMemoryStreamOffset = { ContinuousMemoryStreamOffset( offsets.map { - case ContinuousMemoryStreamPartitionOffset(part, num) => (part, num) + case ContinuousRecordPartitionOffset(part, num) => (part, num) }.toMap ) } @@ -127,27 +126,9 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa options: DataSourceOptions): ContinuousReader = { this } - - /** - * Endpoint for executors to poll for records. - */ - private class RecordEndpoint extends ThreadSafeRpcEndpoint { - override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv - - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case GetRecord(ContinuousMemoryStreamPartitionOffset(part, index)) => - ContinuousMemoryStream.this.synchronized { - val buf = records(part) - val record = if (buf.size <= index) None else Some(buf(index)) - - context.reply(record.map(r => encoder.toRow(r).copy())) - } - } - } } object ContinuousMemoryStream { - case class GetRecord(offset: ContinuousMemoryStreamPartitionOffset) protected val memoryStreamId = new AtomicInteger(0) def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = @@ -207,12 +188,12 @@ class ContinuousMemoryStreamInputPartitionReader( override def close(): Unit = {} - override def getOffset: ContinuousMemoryStreamPartitionOffset = - ContinuousMemoryStreamPartitionOffset(partition, currentOffset) + override def getOffset: ContinuousRecordPartitionOffset = + ContinuousRecordPartitionOffset(partition, currentOffset) private def getRecord: Option[InternalRow] = endpoint.askSync[Option[InternalRow]]( - GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset))) + GetRecord(ContinuousRecordPartitionOffset(partition, currentOffset))) } case class ContinuousMemoryStreamOffset(partitionNums: Map[Int, Int]) @@ -220,6 +201,3 @@ case class ContinuousMemoryStreamOffset(partitionNums: Map[Int, Int]) private implicit val formats = Serialization.formats(NoTypeHints) override def json(): String = Serialization.write(partitionNums) } - -case class ContinuousMemoryStreamPartitionOffset(partition: Int, numProcessed: Int) - extends PartitionOffset diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 9f53a1849b33d..874c479db95d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -33,14 +33,16 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.LongOffset +import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousReader import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, DataSourceV2, MicroBatchReadSupport} import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} import org.apache.spark.unsafe.types.UTF8String -object TextSocketMicroBatchReader { +// Shared object for micro-batch and continuous reader +object TextSocketReader { val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: StructField("timestamp", TimestampType) :: Nil) @@ -137,9 +139,9 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR override def readSchema(): StructType = { if (options.getBoolean("includeTimestamp", false)) { - TextSocketMicroBatchReader.SCHEMA_TIMESTAMP + TextSocketReader.SCHEMA_TIMESTAMP } else { - TextSocketMicroBatchReader.SCHEMA_REGULAR + TextSocketReader.SCHEMA_REGULAR } } @@ -226,7 +228,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR } class TextSocketSourceProvider extends DataSourceV2 - with MicroBatchReadSupport with DataSourceRegister with Logging { + with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister with Logging { private def checkParameters(params: DataSourceOptions): Unit = { logWarning("The socket source should not be used for production applications! " + @@ -258,6 +260,17 @@ class TextSocketSourceProvider extends DataSourceV2 new TextSocketMicroBatchReader(options) } + override def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): ContinuousReader = { + checkParameters(options) + if (schema.isPresent) { + throw new AnalysisException("The socket source does not support a user-specified schema.") + } + new TextSocketContinuousReader(options) + } + /** String that represents the format that this data source provider uses. */ override def shortName(): String = "socket" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index 52e8386f6b1fa..48e5cf75bf8bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -32,12 +32,13 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.types._ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with BeforeAndAfterEach { @@ -300,6 +301,101 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before } } + test("continuous data") { + serverThread = new ServerThread() + serverThread.start() + + val reader = new TextSocketContinuousReader( + new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", + "port" -> serverThread.port.toString).asJava)) + reader.setStartOffset(Optional.empty()) + val tasks = reader.planInputPartitions() + assert(tasks.size == 2) + + val numRecords = 10 + val data = scala.collection.mutable.ListBuffer[Int]() + val offsets = scala.collection.mutable.ListBuffer[Int]() + import org.scalatest.time.SpanSugar._ + failAfter(5 seconds) { + // inject rows, read and check the data and offsets + for (i <- 0 until numRecords) { + serverThread.enqueue(i.toString) + } + tasks.asScala.foreach { + case t: TextSocketContinuousInputPartition => + val r = t.createPartitionReader().asInstanceOf[TextSocketContinuousInputPartitionReader] + for (i <- 0 until numRecords / 2) { + r.next() + offsets.append(r.getOffset().asInstanceOf[ContinuousRecordPartitionOffset].offset) + data.append(r.get().get(0, DataTypes.StringType).asInstanceOf[String].toInt) + // commit the offsets in the middle and validate if processing continues + if (i == 2) { + commitOffset(t.partitionId, i + 1) + } + } + assert(offsets.toSeq == Range.inclusive(1, 5)) + assert(data.toSeq == Range(t.partitionId, 10, 2)) + offsets.clear() + data.clear() + case _ => throw new IllegalStateException("Unexpected task type") + } + assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == List(3, 3)) + reader.commit(TextSocketOffset(List(5, 5))) + assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == List(5, 5)) + } + + def commitOffset(partition: Int, offset: Int): Unit = { + val offsetsToCommit = reader.getStartOffset.asInstanceOf[TextSocketOffset] + .offsets.updated(partition, offset) + reader.commit(TextSocketOffset(offsetsToCommit)) + assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == offsetsToCommit) + } + } + + test("continuous data - invalid commit") { + serverThread = new ServerThread() + serverThread.start() + + val reader = new TextSocketContinuousReader( + new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", + "port" -> serverThread.port.toString).asJava)) + reader.setStartOffset(Optional.of(TextSocketOffset(List(5, 5)))) + // ok to commit same offset + reader.setStartOffset(Optional.of(TextSocketOffset(List(5, 5)))) + assertThrows[IllegalStateException] { + reader.commit(TextSocketOffset(List(6, 6))) + } + } + + test("continuous data with timestamp") { + serverThread = new ServerThread() + serverThread.start() + + val reader = new TextSocketContinuousReader( + new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", + "includeTimestamp" -> "true", + "port" -> serverThread.port.toString).asJava)) + reader.setStartOffset(Optional.empty()) + val tasks = reader.planInputPartitions() + assert(tasks.size == 2) + + val numRecords = 4 + // inject rows, read and check the data and offsets + for (i <- 0 until numRecords) { + serverThread.enqueue(i.toString) + } + tasks.asScala.foreach { + case t: TextSocketContinuousInputPartition => + val r = t.createPartitionReader().asInstanceOf[TextSocketContinuousInputPartitionReader] + for (i <- 0 until numRecords / 2) { + r.next() + assert(r.get().get(0, TextSocketReader.SCHEMA_TIMESTAMP) + .isInstanceOf[(String, Timestamp)]) + } + case _ => throw new IllegalStateException("Unexpected task type") + } + } + /** * This class tries to mimic the behavior of netcat, so that we can ensure * TextSocketStream supports netcat, which only accepts the first connection From 4f175850985cfc4c64afb90d784bb292e81dc0b7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 10 Aug 2018 11:32:15 +0200 Subject: [PATCH 1366/2461] [SPARK-19355][SQL] Use map output statistics to improve global limit's parallelism ## What changes were proposed in this pull request? A logical `Limit` is performed physically by two operations `LocalLimit` and `GlobalLimit`. Most of time, we gather all data into a single partition in order to run `GlobalLimit`. If we use a very big limit number, shuffling data causes performance issue also reduces parallelism. We can avoid shuffling into single partition if we don't care data ordering. This patch implements this idea by doing a map stage during global limit. It collects the info of row numbers at each partition. For each partition, we locally retrieves limited data without any shuffling to finish this global limit. For example, we have three partitions with rows (100, 100, 50) respectively. In global limit of 100 rows, we may take (34, 33, 33) rows for each partition locally. After global limit we still have three partitions. If the data partition has certain ordering, we can't distribute required rows evenly to each partitions because it could change data ordering. But we still can avoid shuffling. ## How was this patch tested? Jenkins tests. Author: Liang-Chi Hsieh Closes #16677 from viirya/improve-global-limit-parallelism. --- .../sort/BypassMergeSortShuffleWriter.java | 5 +- .../shuffle/sort/UnsafeShuffleWriter.java | 3 +- .../apache/spark/MapOutputStatistics.scala | 6 +- .../org/apache/spark/MapOutputTracker.scala | 10 +- .../apache/spark/scheduler/MapStatus.scala | 43 +++++--- .../shuffle/sort/SortShuffleWriter.scala | 3 +- .../sort/UnsafeShuffleWriterSuite.java | 2 + .../apache/spark/MapOutputTrackerSuite.scala | 28 ++--- .../scala/org/apache/spark/ShuffleSuite.scala | 1 + .../spark/scheduler/DAGSchedulerSuite.scala | 10 +- .../spark/scheduler/MapStatusSuite.scala | 16 +-- .../serializer/KryoSerializerSuite.scala | 3 +- .../plans/physical/partitioning.scala | 14 +++ .../apache/spark/sql/internal/SQLConf.scala | 9 ++ .../exchange/ShuffleExchangeExec.scala | 8 ++ .../apache/spark/sql/execution/limit.scala | 101 +++++++++++++++--- .../test/resources/sql-tests/inputs/limit.sql | 2 + .../inputs/subquery/in-subquery/in-limit.sql | 5 +- .../resources/sql-tests/results/limit.sql.out | 92 ++++++++-------- .../subquery/in-subquery/in-limit.sql.out | 56 +++++----- .../spark/sql/DataFrameAggregateSuite.scala | 12 ++- .../org/apache/spark/sql/SQLQuerySuite.scala | 11 +- .../execution/ExchangeCoordinatorSuite.scala | 6 +- .../spark/sql/execution/PlannerSuite.scala | 4 +- .../execution/HiveCompatibilitySuite.scala | 4 + .../sql/hive/execution/PruningSuite.scala | 8 ++ 26 files changed, 322 insertions(+), 140 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 323a5d3c52831..e3bd5496cf5ba 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -125,7 +125,7 @@ public void write(Iterator> records) throws IOException { if (!records.hasNext()) { partitionLengths = new long[numPartitions]; shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, 0); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -167,7 +167,8 @@ public void write(Iterator> records) throws IOException { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), partitionLengths, writeMetrics.recordsWritten()); } @VisibleForTesting diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 4839d04522f10..069e6d5f224d7 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -248,7 +248,8 @@ void closeAndWriteOutput() throws IOException { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), partitionLengths, writeMetrics.recordsWritten()); } @VisibleForTesting diff --git a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala index f8a6f1d0d8cbb..ff85e11409e35 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala @@ -23,5 +23,9 @@ package org.apache.spark * @param shuffleId ID of the shuffle * @param bytesByPartitionId approximate number of output bytes for each map output partition * (may be inexact due to use of compressed map statuses) + * @param recordsByPartitionId number of output records for each map output partition */ -private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long]) +private[spark] class MapOutputStatistics( + val shuffleId: Int, + val bytesByPartitionId: Array[Long], + val recordsByPartitionId: Array[Long]) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 1c4fa4bc6541f..41575ce4e6e3d 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -522,16 +522,19 @@ private[spark] class MapOutputTrackerMaster( def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { shuffleStatuses(dep.shuffleId).withMapStatuses { statuses => val totalSizes = new Array[Long](dep.partitioner.numPartitions) + val recordsByMapTask = new Array[Long](statuses.length) + val parallelAggThreshold = conf.get( SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD) val parallelism = math.min( Runtime.getRuntime.availableProcessors(), statuses.length.toLong * totalSizes.length / parallelAggThreshold + 1).toInt if (parallelism <= 1) { - for (s <- statuses) { + statuses.zipWithIndex.foreach { case (s, index) => for (i <- 0 until totalSizes.length) { totalSizes(i) += s.getSizeForBlock(i) } + recordsByMapTask(index) = s.numberOfOutput } } else { val threadPool = ThreadUtils.newDaemonFixedThreadPool(parallelism, "map-output-aggregate") @@ -548,8 +551,11 @@ private[spark] class MapOutputTrackerMaster( } finally { threadPool.shutdown() } + statuses.zipWithIndex.foreach { case (s, index) => + recordsByMapTask(index) = s.numberOfOutput + } } - new MapOutputStatistics(dep.shuffleId, totalSizes) + new MapOutputStatistics(dep.shuffleId, totalSizes, recordsByMapTask) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 659694dd189ad..7e1d75fe723d6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -31,7 +31,8 @@ import org.apache.spark.util.Utils /** * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the - * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks. + * task ran on, the sizes of outputs for each reducer, and the number of outputs of the map task, + * for passing on to the reduce tasks. */ private[spark] sealed trait MapStatus { /** Location where this task was run. */ @@ -44,18 +45,23 @@ private[spark] sealed trait MapStatus { * necessary for correctness, since block fetchers are allowed to skip zero-size blocks. */ def getSizeForBlock(reduceId: Int): Long + + /** + * The number of outputs for the map task. + */ + def numberOfOutput: Long } private[spark] object MapStatus { - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], numOutput: Long): MapStatus = { if (uncompressedSizes.length > Option(SparkEnv.get) .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)) { - HighlyCompressedMapStatus(loc, uncompressedSizes) + HighlyCompressedMapStatus(loc, uncompressedSizes, numOutput) } else { - new CompressedMapStatus(loc, uncompressedSizes) + new CompressedMapStatus(loc, uncompressedSizes, numOutput) } } @@ -98,29 +104,34 @@ private[spark] object MapStatus { */ private[spark] class CompressedMapStatus( private[this] var loc: BlockManagerId, - private[this] var compressedSizes: Array[Byte]) + private[this] var compressedSizes: Array[Byte], + private[this] var numOutput: Long) extends MapStatus with Externalizable { - protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only + protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1) // For deserialization only - def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { - this(loc, uncompressedSizes.map(MapStatus.compressSize)) + def this(loc: BlockManagerId, uncompressedSizes: Array[Long], numOutput: Long) { + this(loc, uncompressedSizes.map(MapStatus.compressSize), numOutput) } override def location: BlockManagerId = loc + override def numberOfOutput: Long = numOutput + override def getSizeForBlock(reduceId: Int): Long = { MapStatus.decompressSize(compressedSizes(reduceId)) } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) + out.writeLong(numOutput) out.writeInt(compressedSizes.length) out.write(compressedSizes) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = BlockManagerId(in) + numOutput = in.readLong() val len = in.readInt() compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) @@ -143,17 +154,20 @@ private[spark] class HighlyCompressedMapStatus private ( private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, - private var hugeBlockSizes: Map[Int, Byte]) + private var hugeBlockSizes: Map[Int, Byte], + private[this] var numOutput: Long) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1, null) // For deserialization only + protected def this() = this(null, -1, null, -1, null, -1) // For deserialization only override def location: BlockManagerId = loc + override def numberOfOutput: Long = numOutput + override def getSizeForBlock(reduceId: Int): Long = { assert(hugeBlockSizes != null) if (emptyBlocks.contains(reduceId)) { @@ -168,6 +182,7 @@ private[spark] class HighlyCompressedMapStatus private ( override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) + out.writeLong(numOutput) emptyBlocks.writeExternal(out) out.writeLong(avgSize) out.writeInt(hugeBlockSizes.size) @@ -179,6 +194,7 @@ private[spark] class HighlyCompressedMapStatus private ( override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = BlockManagerId(in) + numOutput = in.readLong() emptyBlocks = new RoaringBitmap() emptyBlocks.readExternal(in) avgSize = in.readLong() @@ -194,7 +210,10 @@ private[spark] class HighlyCompressedMapStatus private ( } private[spark] object HighlyCompressedMapStatus { - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { + def apply( + loc: BlockManagerId, + uncompressedSizes: Array[Long], + numOutput: Long): HighlyCompressedMapStatus = { // We must keep track of which blocks are empty so that we don't report a zero-sized // block as being non-empty (or vice-versa) when using the average block size. var i = 0 @@ -235,6 +254,6 @@ private[spark] object HighlyCompressedMapStatus { emptyBlocks.trim() emptyBlocks.runOptimize() new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, - hugeBlockSizesArray.toMap) + hugeBlockSizesArray.toMap, numOutput) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 274399b9cc1f3..91fc26762e533 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -70,7 +70,8 @@ private[spark] class SortShuffleWriter[K, V, C]( val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) val partitionLengths = sorter.writePartitionedFile(blockId, tmp) shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, + writeMetrics.recordsWritten) } finally { if (tmp.exists() && !tmp.delete()) { logError(s"Error while deleting temp file ${tmp.getAbsolutePath}") diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 0d5c5ea7903e9..faa70f23b0ac6 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -233,6 +233,7 @@ public void writeEmptyIterator() throws Exception { writer.write(Iterators.emptyIterator()); final Option mapStatus = writer.stop(true); assertTrue(mapStatus.isDefined()); + assertEquals(0, mapStatus.get().numberOfOutput()); assertTrue(mergedOutputFile.exists()); assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile); assertEquals(0, taskMetrics.shuffleWriteMetrics().recordsWritten()); @@ -252,6 +253,7 @@ public void writeWithoutSpilling() throws Exception { writer.write(dataToWrite.iterator()); final Option mapStatus = writer.stop(true); assertTrue(mapStatus.isDefined()); + assertEquals(NUM_PARTITITONS, mapStatus.get().numberOfOutput()); assertTrue(mergedOutputFile.exists()); long sumOfPartitionSizes = 0; diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 21f481d477242..e79739692fe13 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -62,9 +62,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(1000L, 10000L))) + Array(1000L, 10000L), 10)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(10000L, 1000L))) + Array(10000L, 1000L), 10)) val statuses = tracker.getMapSizesByExecutorId(10, 0) assert(statuses.toSet === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), @@ -84,9 +84,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize10000))) + Array(compressedSize1000, compressedSize10000), 10)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000))) + Array(compressedSize10000, compressedSize1000), 10)) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) assert(0 == tracker.getNumCachedSerializedBroadcast) @@ -107,9 +107,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize1000, compressedSize1000))) + Array(compressedSize1000, compressedSize1000, compressedSize1000), 10)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000, compressedSize1000))) + Array(compressedSize10000, compressedSize1000, compressedSize1000), 10)) assert(0 == tracker.getNumCachedSerializedBroadcast) // As if we had two simultaneous fetch failures @@ -145,7 +145,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("a", "hostA", 1000), Array(1000L))) + BlockManagerId("a", "hostA", 1000), Array(1000L), 10)) slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) @@ -182,7 +182,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // Message size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) + BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0), 0)) val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) when(rpcCallContext.senderAddress).thenReturn(senderAddress) @@ -216,11 +216,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { // on hostB with output size 3 tracker.registerShuffle(10, 3) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L))) + Array(2L), 1)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L))) + Array(2L), 1)) tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(3L))) + Array(3L), 1)) // When the threshold is 50%, only host A should be returned as a preferred location // as it has 4 out of 7 bytes of output. @@ -260,7 +260,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(20, 100) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 0)) } val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) @@ -309,9 +309,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(size0, size1000, size0, size10000))) + Array(size0, size1000, size0, size10000), 1)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(size10000, size0, size1000, size0))) + Array(size10000, size0, size1000, size0), 1)) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === Seq( diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index ced5a06516f75..d11eaf8c2749c 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -391,6 +391,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(mapOutput2.isDefined) assert(mapOutput1.get.location === mapOutput2.get.location) assert(mapOutput1.get.getSizeForBlock(0) === mapOutput1.get.getSizeForBlock(0)) + assert(mapOutput1.get.numberOfOutput === mapOutput2.get.numberOfOutput) // register one of the map outputs -- doesn't matter which one mapOutput1.foreach { case mapStatus => diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 211002b2b5caa..5e095ce1ae4e9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -423,17 +423,17 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // map stage1 completes successfully, with one task on each executor complete(taskSets(0), Seq( (Success, - MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))), + MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2), 1)), (Success, - MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))), + MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2), 1)), (Success, makeMapStatus("hostB", 1)) )) // map stage2 completes successfully, with one task on each executor complete(taskSets(1), Seq( (Success, - MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))), + MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2), 1)), (Success, - MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))), + MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2), 1)), (Success, makeMapStatus("hostB", 1)) )) // make sure our test setup is correct @@ -2576,7 +2576,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi object DAGSchedulerSuite { def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2): MapStatus = - MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes)) + MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), 1) def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 354e6386fa60e..555e48bd28aa0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -60,7 +60,7 @@ class MapStatusSuite extends SparkFunSuite { stddev <- Seq(0.0, 0.01, 0.5, 1.0) ) { val sizes = Array.fill[Long](numSizes)(abs(round(Random.nextGaussian() * stddev)) + mean) - val status = MapStatus(BlockManagerId("a", "b", 10), sizes) + val status = MapStatus(BlockManagerId("a", "b", 10), sizes, 1) val status1 = compressAndDecompressMapStatus(status) for (i <- 0 until numSizes) { if (sizes(i) != 0) { @@ -74,7 +74,7 @@ class MapStatusSuite extends SparkFunSuite { test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) { val sizes = Array.fill[Long](2001)(150L) - val status = MapStatus(null, sizes) + val status = MapStatus(null, sizes, 1) assert(status.isInstanceOf[HighlyCompressedMapStatus]) assert(status.getSizeForBlock(10) === 150L) assert(status.getSizeForBlock(50) === 150L) @@ -86,7 +86,7 @@ class MapStatusSuite extends SparkFunSuite { val sizes = Array.tabulate[Long](3000) { i => i.toLong } val avg = sizes.sum / sizes.count(_ != 0) val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes) + val status = MapStatus(loc, sizes, 1) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) assert(status1.location == loc) @@ -108,7 +108,7 @@ class MapStatusSuite extends SparkFunSuite { val smallBlockSizes = sizes.filter(n => n > 0 && n < threshold) val avg = smallBlockSizes.sum / smallBlockSizes.length val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes) + val status = MapStatus(loc, sizes, 1) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) assert(status1.location == loc) @@ -164,7 +164,7 @@ class MapStatusSuite extends SparkFunSuite { SparkEnv.set(env) // Value of element in sizes is equal to the corresponding index. val sizes = (0L to 2000L).toArray - val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes) + val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes, 1) val arrayStream = new ByteArrayOutputStream(102400) val objectOutputStream = new ObjectOutputStream(arrayStream) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) @@ -196,19 +196,19 @@ class MapStatusSuite extends SparkFunSuite { SparkEnv.set(env) val sizes = Array.fill[Long](500)(150L) // Test default value - val status = MapStatus(null, sizes) + val status = MapStatus(null, sizes, 1) assert(status.isInstanceOf[CompressedMapStatus]) // Test Non-positive values for (s <- -1 to 0) { assertThrows[IllegalArgumentException] { conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) - val status = MapStatus(null, sizes) + val status = MapStatus(null, sizes, 1) } } // Test positive values Seq(1, 100, 499, 500, 501).foreach { s => conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) - val status = MapStatus(null, sizes) + val status = MapStatus(null, sizes, 1) if(sizes.length > s) { assert(status.isInstanceOf[HighlyCompressedMapStatus]) } else { diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index fc78655bf52ec..240f8cf800fe8 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -345,7 +345,8 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val denseBlockSizes = new Array[Long](5000) val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L) Seq(denseBlockSizes, sparseBlockSizes).foreach { blockSizes => - ser.serialize(HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes)) + ser.serialize( + HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes, 1)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index cc1a5e835d9cd..cd28c733f3613 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{DataType, IntegerType} @@ -206,6 +208,18 @@ case object SinglePartition extends Partitioning { } } +/** + * Represents a partitioning where rows are only serialized/deserialized locally. The number + * of partitions are not changed and also the distribution of rows. This is mainly used to + * obtain some statistics of map tasks such as number of outputs. + */ +case class LocalPartitioning(childRDD: RDD[InternalRow]) extends Partitioning { + val numPartitions = childRDD.getNumPartitions + + // We will perform this partitioning no matter what the data distribution is. + override def satisfies0(required: Distribution): Boolean = false +} + /** * Represents a partitioning where rows are split up across partitions based on the hash * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 979a55467ff89..603c0708fe7ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -214,6 +214,13 @@ object SQLConf { .intConf .createWithDefault(4) + val LIMIT_FLAT_GLOBAL_LIMIT = buildConf("spark.sql.limit.flatGlobalLimit") + .internal() + .doc("During global limit, try to evenly distribute limited rows across data " + + "partitions. If disabled, scanning data partitions sequentially until reaching limit number.") + .booleanConf + .createWithDefault(true) + val ADVANCED_PARTITION_PREDICATE_PUSHDOWN = buildConf("spark.sql.hive.advancedPartitionPredicatePushdown.enabled") .internal() @@ -1682,6 +1689,8 @@ class SQLConf extends Serializable with Logging { def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR) + def limitFlatGlobalLimit: Boolean = getConf(LIMIT_FLAT_GLOBAL_LIMIT) + def advancedPartitionPredicatePushdownEnabled: Boolean = getConf(ADVANCED_PARTITION_PREDICATE_PUSHDOWN) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index b89203719541b..50f10c31427d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -231,6 +231,11 @@ object ShuffleExchangeExec { override def numPartitions: Int = 1 override def getPartition(key: Any): Int = 0 } + case l: LocalPartitioning => + new Partitioner { + override def numPartitions: Int = l.numPartitions + override def getPartition(key: Any): Int = key.asInstanceOf[Int] + } case _ => sys.error(s"Exchange not implemented for $newPartitioning") // TODO: Handle BroadcastPartitioning. } @@ -247,6 +252,9 @@ object ShuffleExchangeExec { val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) row => projection(row).getInt(0) case RangePartitioning(_, _) | SinglePartition => identity + case _: LocalPartitioning => + val partitionId = TaskContext.get().partitionId() + _ => partitionId case _ => sys.error(s"Exchange not implemented for $newPartitioning") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 66bcda8913738..392ca13724bc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -47,13 +47,16 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode } /** - * Helper trait which defines methods that are shared by both - * [[LocalLimitExec]] and [[GlobalLimitExec]]. + * Take the first `limit` elements of each child partition, but do not collect or shuffle them. */ -trait BaseLimitExec extends UnaryExecNode with CodegenSupport { - val limit: Int +case class LocalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode with CodegenSupport { + override def output: Seq[Attribute] = child.output + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = child.outputPartitioning + protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.take(limit) } @@ -93,25 +96,93 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { } /** - * Take the first `limit` elements of each child partition, but do not collect or shuffle them. + * Take the `limit` elements of the child output. */ -case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { +case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { - override def outputOrdering: Seq[SortOrder] = child.outputOrdering + override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning -} -/** - * Take the first `limit` elements of the child's single output partition. - */ -case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { + override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil + private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) - override def outputPartitioning: Partitioning = child.outputPartitioning + protected override def doExecute(): RDD[InternalRow] = { + val childRDD = child.execute() + val partitioner = LocalPartitioning(childRDD) + val shuffleDependency = ShuffleExchangeExec.prepareShuffleDependency( + childRDD, child.output, partitioner, serializer) + val numberOfOutput: Seq[Long] = if (shuffleDependency.rdd.getNumPartitions != 0) { + // submitMapStage does not accept RDD with 0 partition. + // So, we will not submit this dependency. + val submittedStageFuture = sparkContext.submitMapStage(shuffleDependency) + submittedStageFuture.get().recordsByPartitionId.toSeq + } else { + Nil + } - override def outputOrdering: Seq[SortOrder] = child.outputOrdering + // During global limit, try to evenly distribute limited rows across data + // partitions. If disabled, scanning data partitions sequentially until reaching limit number. + // Besides, if child output has certain ordering, we can't evenly pick up rows from + // each parititon. + val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit && child.outputOrdering == Nil + + val shuffled = new ShuffledRowRDD(shuffleDependency) + + val sumOfOutput = numberOfOutput.sum + if (sumOfOutput <= limit) { + shuffled + } else if (!flatGlobalLimit) { + var numRowTaken = 0 + val takeAmounts = numberOfOutput.map { num => + if (numRowTaken + num < limit) { + numRowTaken += num.toInt + num.toInt + } else { + val toTake = limit - numRowTaken + numRowTaken += toTake + toTake + } + } + val broadMap = sparkContext.broadcast(takeAmounts) + shuffled.mapPartitionsWithIndexInternal { case (index, iter) => + iter.take(broadMap.value(index).toInt) + } + } else { + // We try to evenly require the asked limit number of rows across all child rdd's partitions. + var rowsNeedToTake: Long = limit + val takeAmountByPartition: Array[Long] = Array.fill[Long](numberOfOutput.length)(0L) + val remainingRowsByPartition: Array[Long] = Array(numberOfOutput: _*) + + while (rowsNeedToTake > 0) { + val nonEmptyParts = remainingRowsByPartition.count(_ > 0) + // If the rows needed to take are less the number of non-empty partitions, take one row from + // each non-empty partitions until we reach `limit` rows. + // Otherwise, evenly divide the needed rows to each non-empty partitions. + val takePerPart = math.max(1, rowsNeedToTake / nonEmptyParts) + remainingRowsByPartition.zipWithIndex.foreach { case (num, index) => + // In case `rowsNeedToTake` < `nonEmptyParts`, we may run out of `rowsNeedToTake` during + // the traversal, so we need to add this check. + if (rowsNeedToTake > 0 && num > 0) { + if (num >= takePerPart) { + rowsNeedToTake -= takePerPart + takeAmountByPartition(index) += takePerPart + remainingRowsByPartition(index) -= takePerPart + } else { + rowsNeedToTake -= num + takeAmountByPartition(index) += num + remainingRowsByPartition(index) -= num + } + } + } + } + val broadMap = sparkContext.broadcast(takeAmountByPartition) + shuffled.mapPartitionsWithIndexInternal { case (index, iter) => + iter.take(broadMap.value(index).toInt) + } + } + } } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/limit.sql b/sql/core/src/test/resources/sql-tests/inputs/limit.sql index b4c73cf33e53a..e33cd819f281f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/limit.sql @@ -1,3 +1,5 @@ +-- Disable global limit parallel +set spark.sql.limit.flatGlobalLimit=false; -- limit on various data types SELECT * FROM testdata LIMIT 2; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql index a40ee082ba3b9..a862e0985b20c 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql @@ -1,6 +1,9 @@ -- A test suite for IN LIMIT in parent side, subquery, and both predicate subquery -- It includes correlated cases. +-- Disable global limit optimization +set spark.sql.limit.flatGlobalLimit=false; + create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), ("val1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -97,4 +100,4 @@ WHERE t1d NOT IN (SELECT t2d LIMIT 1) GROUP BY t1b ORDER BY t1b NULLS last -LIMIT 1; \ No newline at end of file +LIMIT 1; diff --git a/sql/core/src/test/resources/sql-tests/results/limit.sql.out b/sql/core/src/test/resources/sql-tests/results/limit.sql.out index 02fe1de84f753..187f3bd6858fe 100644 --- a/sql/core/src/test/resources/sql-tests/results/limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/limit.sql.out @@ -1,63 +1,62 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 14 +-- Number of queries: 15 -- !query 0 -SELECT * FROM testdata LIMIT 2 +set spark.sql.limit.flatGlobalLimit=false -- !query 0 schema -struct +struct -- !query 0 output -1 1 -2 2 +spark.sql.limit.flatGlobalLimit false -- !query 1 -SELECT * FROM arraydata LIMIT 2 +SELECT * FROM testdata LIMIT 2 -- !query 1 schema -struct,nestedarraycol:array>> +struct -- !query 1 output -[1,2,3] [[1,2,3]] -[2,3,4] [[2,3,4]] +1 1 +2 2 -- !query 2 -SELECT * FROM mapdata LIMIT 2 +SELECT * FROM arraydata LIMIT 2 -- !query 2 schema -struct> +struct,nestedarraycol:array>> -- !query 2 output -{1:"a1",2:"b1",3:"c1",4:"d1",5:"e1"} -{1:"a2",2:"b2",3:"c2",4:"d2"} +[1,2,3] [[1,2,3]] +[2,3,4] [[2,3,4]] -- !query 3 -SELECT * FROM testdata LIMIT 2 + 1 +SELECT * FROM mapdata LIMIT 2 -- !query 3 schema -struct +struct> -- !query 3 output -1 1 -2 2 -3 3 +{1:"a1",2:"b1",3:"c1",4:"d1",5:"e1"} +{1:"a2",2:"b2",3:"c2",4:"d2"} -- !query 4 -SELECT * FROM testdata LIMIT CAST(1 AS int) +SELECT * FROM testdata LIMIT 2 + 1 -- !query 4 schema struct -- !query 4 output 1 1 +2 2 +3 3 -- !query 5 -SELECT * FROM testdata LIMIT -1 +SELECT * FROM testdata LIMIT CAST(1 AS int) -- !query 5 schema -struct<> +struct -- !query 5 output -org.apache.spark.sql.AnalysisException -The limit expression must be equal to or greater than 0, but got -1; +1 1 -- !query 6 -SELECT * FROM testData TABLESAMPLE (-1 ROWS) +SELECT * FROM testdata LIMIT -1 -- !query 6 schema struct<> -- !query 6 output @@ -66,61 +65,70 @@ The limit expression must be equal to or greater than 0, but got -1; -- !query 7 -SELECT * FROM testdata LIMIT CAST(1 AS INT) +SELECT * FROM testData TABLESAMPLE (-1 ROWS) -- !query 7 schema -struct +struct<> -- !query 7 output -1 1 +org.apache.spark.sql.AnalysisException +The limit expression must be equal to or greater than 0, but got -1; -- !query 8 -SELECT * FROM testdata LIMIT CAST(NULL AS INT) +SELECT * FROM testdata LIMIT CAST(1 AS INT) -- !query 8 schema -struct<> +struct -- !query 8 output -org.apache.spark.sql.AnalysisException -The evaluated limit expression must not be null, but got CAST(NULL AS INT); +1 1 -- !query 9 -SELECT * FROM testdata LIMIT key > 3 +SELECT * FROM testdata LIMIT CAST(NULL AS INT) -- !query 9 schema struct<> -- !query 9 output org.apache.spark.sql.AnalysisException -The limit expression must evaluate to a constant value, but got (testdata.`key` > 3); +The evaluated limit expression must not be null, but got CAST(NULL AS INT); -- !query 10 -SELECT * FROM testdata LIMIT true +SELECT * FROM testdata LIMIT key > 3 -- !query 10 schema struct<> -- !query 10 output org.apache.spark.sql.AnalysisException -The limit expression must be integer type, but got boolean; +The limit expression must evaluate to a constant value, but got (testdata.`key` > 3); -- !query 11 -SELECT * FROM testdata LIMIT 'a' +SELECT * FROM testdata LIMIT true -- !query 11 schema struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -The limit expression must be integer type, but got string; +The limit expression must be integer type, but got boolean; -- !query 12 -SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3 +SELECT * FROM testdata LIMIT 'a' -- !query 12 schema -struct +struct<> -- !query 12 output -4 +org.apache.spark.sql.AnalysisException +The limit expression must be integer type, but got string; -- !query 13 -SELECT * FROM testdata WHERE key < 3 LIMIT ALL +SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3 -- !query 13 schema -struct +struct -- !query 13 output +4 + + +-- !query 14 +SELECT * FROM testdata WHERE key < 3 LIMIT ALL +-- !query 14 schema +struct +-- !query 14 output 1 1 2 2 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out index 71ca1f8649475..9eb5b3383e734 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out @@ -1,8 +1,16 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 8 +-- Number of queries: 9 -- !query 0 +set spark.sql.limit.flatGlobalLimit=false +-- !query 0 schema +struct +-- !query 0 output +spark.sql.limit.flatGlobalLimit false + + +-- !query 1 create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), ("val1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -17,13 +25,13 @@ create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:02:00.001', date '2014-04-04'), ("val1e", 10S, null, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04') as t1(t1a, t1b, t1c, t1d, t1e, t1f, t1g, t1h, t1i) --- !query 0 schema +-- !query 1 schema struct<> --- !query 0 output +-- !query 1 output --- !query 1 +-- !query 2 create temporary view t2 as select * from values ("val2a", 6S, 12, 14L, float(15), 20D, 20E2, timestamp '2014-04-04 01:01:00.000', date '2014-04-04'), ("val1b", 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -39,13 +47,13 @@ create temporary view t2 as select * from values ("val1f", 19S, null, 19L, float(17), 25D, 26E2, timestamp '2014-10-04 01:01:00.000', date '2014-10-04'), ("val1b", null, 16, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', null) as t2(t2a, t2b, t2c, t2d, t2e, t2f, t2g, t2h, t2i) --- !query 1 schema +-- !query 2 schema struct<> --- !query 1 output +-- !query 2 output --- !query 2 +-- !query 3 create temporary view t3 as select * from values ("val3a", 6S, 12, 110L, float(15), 20D, 20E2, timestamp '2014-04-04 01:02:00.000', date '2014-04-04'), ("val3a", 6S, 12, 10L, float(15), 20D, 20E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), @@ -60,27 +68,27 @@ create temporary view t3 as select * from values ("val3b", 8S, null, 719L, float(17), 25D, 26E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), ("val3b", 8S, null, 19L, float(17), 25D, 26E2, timestamp '2015-05-04 01:02:00.000', date '2015-05-04') as t3(t3a, t3b, t3c, t3d, t3e, t3f, t3g, t3h, t3i) --- !query 2 schema +-- !query 3 schema struct<> --- !query 2 output +-- !query 3 output --- !query 3 +-- !query 4 SELECT * FROM t1 WHERE t1a IN (SELECT t2a FROM t2 WHERE t1d = t2d) LIMIT 2 --- !query 3 schema +-- !query 4 schema struct --- !query 3 output +-- !query 4 output val1b 8 16 19 17.0 25.0 2600 2014-05-04 01:01:00 2014-05-04 val1c 8 16 19 17.0 25.0 2600 2014-05-04 01:02:00.001 2014-05-05 --- !query 4 +-- !query 5 SELECT * FROM t1 WHERE t1c IN (SELECT t2c @@ -88,16 +96,16 @@ WHERE t1c IN (SELECT t2c WHERE t2b >= 8 LIMIT 2) LIMIT 4 --- !query 4 schema +-- !query 5 schema struct --- !query 4 output +-- !query 5 output val1a 16 12 10 15.0 20.0 2000 2014-07-04 01:01:00 2014-07-04 val1a 16 12 21 15.0 20.0 2000 2014-06-04 01:02:00.001 2014-06-04 val1b 8 16 19 17.0 25.0 2600 2014-05-04 01:01:00 2014-05-04 val1c 8 16 19 17.0 25.0 2600 2014-05-04 01:02:00.001 2014-05-05 --- !query 5 +-- !query 6 SELECT Count(DISTINCT( t1a )), t1b FROM t1 @@ -108,29 +116,29 @@ WHERE t1d IN (SELECT t2d GROUP BY t1b ORDER BY t1b DESC NULLS FIRST LIMIT 1 --- !query 5 schema +-- !query 6 schema struct --- !query 5 output +-- !query 6 output 1 NULL --- !query 6 +-- !query 7 SELECT * FROM t1 WHERE t1b NOT IN (SELECT t2b FROM t2 WHERE t2b > 6 LIMIT 2) --- !query 6 schema +-- !query 7 schema struct --- !query 6 output +-- !query 7 output val1a 16 12 10 15.0 20.0 2000 2014-07-04 01:01:00 2014-07-04 val1a 16 12 21 15.0 20.0 2000 2014-06-04 01:02:00.001 2014-06-04 val1a 6 8 10 15.0 20.0 2000 2014-04-04 01:00:00 2014-04-04 val1a 6 8 10 15.0 20.0 2000 2014-04-04 01:02:00.001 2014-04-04 --- !query 7 +-- !query 8 SELECT Count(DISTINCT( t1a )), t1b FROM t1 @@ -141,7 +149,7 @@ WHERE t1d NOT IN (SELECT t2d GROUP BY t1b ORDER BY t1b NULLS last LIMIT 1 --- !query 7 schema +-- !query 8 schema struct --- !query 7 output +-- !query 8 output 1 6 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d0106c44b7db2..85b3ca11383f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -557,11 +557,13 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("SPARK-18004 limit + aggregates") { - val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value") - val limit2Df = df.limit(2) - checkAnswer( - limit2Df.groupBy("id").count().select($"id"), - limit2Df.select($"id")) + withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") { + val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value") + val limit2Df = df.limit(2) + checkAnswer( + limit2Df.groupBy("id").count().select($"id"), + limit2Df.select($"id")) + } } test("SPARK-17237 remove backticks in a pivot result schema") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 3a393d766b0bc..c1a5f50fd82c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -524,6 +524,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sortTest() } + test("limit for skew dataframe") { + // Create a skew dataframe. + val df = testData.repartition(100).union(testData).limit(50) + // Because `rdd` of dataframe will add a `DeserializeToObject` on top of `GlobalLimit`, + // the `GlobalLimit` will not be replaced with `CollectLimit`. So we can test if `GlobalLimit` + // work on skew partitions. + assert(df.rdd.count() == 50L) + } + test("CTE feature") { checkAnswer( sql("with q1 as (select * from testData limit 10) select * from q1"), @@ -1935,7 +1944,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { // TODO: support subexpression elimination in whole stage codegen withSQLConf("spark.sql.codegen.wholeStage" -> "false") { // select from a table to prevent constant folding. - val df = sql("SELECT a, b from testData2 limit 1") + val df = sql("SELECT a, b from testData2 order by a, b limit 1") checkAnswer(df, Row(1, 1)) checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index b736d43bfc6ba..41de731d41f82 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -50,7 +50,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { expectedPartitionStartIndices: Array[Int]): Unit = { val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map { case (bytesByPartitionId, index) => - new MapOutputStatistics(index, bytesByPartitionId) + new MapOutputStatistics(index, bytesByPartitionId, Array[Long](1)) } val estimatedPartitionStartIndices = coordinator.estimatePartitionStartIndices(mapOutputStatistics) @@ -114,8 +114,8 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0, 0) val mapOutputStatistics = Array( - new MapOutputStatistics(0, bytesByPartitionId1), - new MapOutputStatistics(1, bytesByPartitionId2)) + new MapOutputStatistics(0, bytesByPartitionId1, Array[Long](0)), + new MapOutputStatistics(1, bytesByPartitionId2, Array[Long](0))) intercept[AssertionError](coordinator.estimatePartitionStartIndices(mapOutputStatistics)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index bdc106325aa5c..3db89ecfad9fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -262,7 +262,7 @@ class PlannerSuite extends SharedSQLContext { ).queryExecution.executedPlan.collect { case exchange: ShuffleExchangeExec => exchange }.length - assert(numExchanges === 5) + assert(numExchanges === 3) } { @@ -277,7 +277,7 @@ class PlannerSuite extends SharedSQLContext { ).queryExecution.executedPlan.collect { case exchange: ShuffleExchangeExec => exchange }.length - assert(numExchanges === 5) + assert(numExchanges === 3) } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index cebaad5b4ad9b..b9b2b7dbf38e8 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -40,6 +40,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled + private val originalLimitFlatGlobalLimit = TestHive.conf.limitFlatGlobalLimit private val originalSessionLocalTimeZone = TestHive.conf.sessionLocalTimeZone def testCases: Seq[(String, File)] = { @@ -59,6 +60,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Ensures that cross joins are enabled so that we can test them TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true) + // Ensure that limit operation returns rows in the same order as Hive + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false) // Fix session local timezone to America/Los_Angeles for those timezone sensitive tests // (timestamp_*) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, "America/Los_Angeles") @@ -73,6 +76,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled) + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, originalSessionLocalTimeZone) // For debugging dump some statistics about how much time was spent in various optimizer rules diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index cc592cf6ca629..16541295eb453 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -22,21 +22,29 @@ import scala.collection.JavaConverters._ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution} +import org.apache.spark.sql.internal.SQLConf /** * A set of test cases that validate partition and column pruning. */ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { + private val originalLimitFlatGlobalLimit = TestHive.conf.limitFlatGlobalLimit + override def beforeAll(): Unit = { super.beforeAll() TestHive.setCacheTables(false) + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false) // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, // need to reset the environment to ensure all referenced tables in this suites are // not cached in-memory. Refer to https://issues.apache.org/jira/browse/SPARK-2283 // for details. TestHive.reset() } + override def afterAll() { + TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit) + super.afterAll() + } // Column pruning tests From 132bcceebb7723aea9845c9e207e572ecb44a4a2 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 10 Aug 2018 07:32:52 -0500 Subject: [PATCH 1367/2461] [SPARK-25036][SQL] Avoid discarding unmoored doc comment in Scala-2.12. ## What changes were proposed in this pull request? This PR avoid the following compilation error using sbt in Scala-2.12. ``` [error] [warn] /home/ishizaki/Spark/PR/scala212/spark/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala:410: discarding unmoored doc comment [error] [warn] /** [error] [warn] [error] [warn] /home/ishizaki/Spark/PR/scala212/spark/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala:441: discarding unmoored doc comment [error] [warn] /** [error] [warn] ... [error] [warn] /home/ishizaki/Spark/PR/scala212/spark/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala:440: discarding unmoored doc comment [error] [warn] /** [error] [warn] ``` ## How was this patch tested? Existing UTs Closes #22059 from kiszk/SPARK-25036d. Authored-by: Kazuaki Ishizaki Signed-off-by: Sean Owen --- .../scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 4 ++-- .../src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 918560a5988eb..4cdd17266b771 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -407,7 +407,7 @@ private[spark] object RandomForest extends Logging with Serializable { metadata.isMulticlassWithCategoricalFeatures) logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString) - /** + /* * Performs a sequential aggregation over a partition for a particular tree and node. * * For each feature, the aggregate sufficient statistics are updated for the relevant @@ -438,7 +438,7 @@ private[spark] object RandomForest extends Logging with Serializable { } } - /** + /* * Performs a sequential aggregation over a partition. * * Each data point contributes to one node. For each feature, diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index ed9879c06968d..75614a41e0b62 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -437,7 +437,7 @@ private[spark] class Client( } } - /** + /* * Distribute a file to the cluster. * * If the file's path is a "local:" URI, it's actually not distributed. Other files are copied From 1dd0f1744651efadaa349b96cfd3aaafda1e9f57 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 10 Aug 2018 07:34:09 -0500 Subject: [PATCH 1368/2461] [SPARK-25036][SQL][FOLLOW-UP] Avoid match may not be exhaustive in Scala-2.12. ## What changes were proposed in this pull request? This is a follow-up pr of #22014 and #22039 We still have some more compilation errors in mllib with scala-2.12 with sbt: ``` [error] [warn] /home/ishizaki/Spark/PR/scala212/spark/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala:116: match may not be exhaustive. [error] It would fail on the following inputs: ("silhouette", _), (_, "cosine"), (_, "squaredEuclidean"), (_, String()), (_, _) [error] [warn] ($(metricName), $(distanceMeasure)) match { [error] [warn] ``` ## How was this patch tested? Existing UTs Closes #22058 from kiszk/SPARK-25036c. Authored-by: Kazuaki Ishizaki Signed-off-by: Sean Owen --- .../org/apache/spark/ml/evaluation/ClusteringEvaluator.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala index a6d6b4ea8b965..5c1d1aebdc315 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala @@ -119,6 +119,8 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str df, $(predictionCol), $(featuresCol)) case ("silhouette", "cosine") => CosineSilhouette.computeSilhouetteScore(df, $(predictionCol), $(featuresCol)) + case (mn, dm) => + throw new IllegalArgumentException(s"No support for metric $mn, distance $dm") } } } From 91cdab51ccb3a4e3b6d76132d00f3da30598735b Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 10 Aug 2018 11:15:36 -0500 Subject: [PATCH 1369/2461] [MINOR][BUILD] Add ECCN notice required by http://www.apache.org/dev/crypto.html ## What changes were proposed in this pull request? Add ECCN notice required by http://www.apache.org/dev/crypto.html See https://issues.apache.org/jira/browse/LEGAL-398 This should probably be backported to 2.3, 2.2, as that's when the key dep (commons crypto) turned up. BC is actually unused, but still there. ## How was this patch tested? N/A Closes #22064 from srowen/ECCN. Authored-by: Sean Owen Signed-off-by: Sean Owen --- NOTICE | 24 ++++++++++++++++++++++++ NOTICE-binary | 25 +++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/NOTICE b/NOTICE index 9246cc54caa3a..23cb53fe3f367 100644 --- a/NOTICE +++ b/NOTICE @@ -4,3 +4,27 @@ Copyright 2014 and onwards The Apache Software Foundation. This product includes software developed at The Apache Software Foundation (http://www.apache.org/). + +Export Control Notice +--------------------- + +This distribution includes cryptographic software. The country in which you currently reside may have +restrictions on the import, possession, use, and/or re-export to another country, of encryption software. +BEFORE using any encryption software, please check your country's laws, regulations and policies concerning +the import, possession, or use, and re-export of encryption software, to see if this is permitted. See + for more information. + +The U.S. Government Department of Commerce, Bureau of Industry and Security (BIS), has classified this +software as Export Commodity Control Number (ECCN) 5D002.C.1, which includes information security software +using or performing cryptographic functions with asymmetric algorithms. The form and manner of this Apache +Software Foundation distribution makes it eligible for export under the License Exception ENC Technology +Software Unrestricted (TSU) exception (see the BIS Export Administration Regulations, Section 740.13) for +both object code and source code. + +The following provides more details on the included cryptographic software: + +This software uses Apache Commons Crypto (https://commons.apache.org/proper/commons-crypto/) to +support authentication, and encryption and decryption of data sent across the network between +services. + +This software includes Bouncy Castle (http://bouncycastle.org/) to support the jets3t library. diff --git a/NOTICE-binary b/NOTICE-binary index d56f99bdb55a6..3155c3843ee7c 100644 --- a/NOTICE-binary +++ b/NOTICE-binary @@ -5,6 +5,31 @@ This product includes software developed at The Apache Software Foundation (http://www.apache.org/). +Export Control Notice +--------------------- + +This distribution includes cryptographic software. The country in which you currently reside may have +restrictions on the import, possession, use, and/or re-export to another country, of encryption software. +BEFORE using any encryption software, please check your country's laws, regulations and policies concerning +the import, possession, or use, and re-export of encryption software, to see if this is permitted. See + for more information. + +The U.S. Government Department of Commerce, Bureau of Industry and Security (BIS), has classified this +software as Export Commodity Control Number (ECCN) 5D002.C.1, which includes information security software +using or performing cryptographic functions with asymmetric algorithms. The form and manner of this Apache +Software Foundation distribution makes it eligible for export under the License Exception ENC Technology +Software Unrestricted (TSU) exception (see the BIS Export Administration Regulations, Section 740.13) for +both object code and source code. + +The following provides more details on the included cryptographic software: + +This software uses Apache Commons Crypto (https://commons.apache.org/proper/commons-crypto/) to +support authentication, and encryption and decryption of data sent across the network between +services. + +This software includes Bouncy Castle (http://bouncycastle.org/) to support the jets3t library. + + // ------------------------------------------------------------------ // NOTICE file corresponding to the section 4d of The Apache License, // Version 2.0, in this case for From f5aba657396bd4e2e03dd06491a2d169a99592a7 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 10 Aug 2018 10:53:44 -0700 Subject: [PATCH 1370/2461] [SPARK-25081][CORE] Nested spill in ShuffleExternalSorter should not access released memory page ## What changes were proposed in this pull request? This issue is pretty similar to [SPARK-21907](https://issues.apache.org/jira/browse/SPARK-21907). "allocateArray" in [ShuffleInMemorySorter.reset](https://github.com/apache/spark/blob/9b8521e53e56a53b44c02366a99f8a8ee1307bbf/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java#L99) may trigger a spill and cause ShuffleInMemorySorter access the released `array`. Another task may get the same memory page from the pool. This will cause two tasks access the same memory page. When a task reads memory written by another task, many types of failures may happen. Here are some examples I have seen: - JVM crash. (This is easy to reproduce in a unit test as we fill newly allocated and deallocated memory with 0xa5 and 0x5a bytes which usually points to an invalid memory address) - java.lang.IllegalArgumentException: Comparison method violates its general contract! - java.lang.NullPointerException at org.apache.spark.memory.TaskMemoryManager.getPage(TaskMemoryManager.java:384) - java.lang.UnsupportedOperationException: Cannot grow BufferHolder by size -536870912 because the size after growing exceeds size limitation 2147483632 This PR resets states in `ShuffleInMemorySorter.reset` before calling `allocateArray` to fix the issue. ## How was this patch tested? The new unit test will make JVM crash without the fix. Closes #22062 from zsxwing/SPARK-25081. Authored-by: Shixiong Zhu Signed-off-by: Shixiong Zhu --- .../shuffle/sort/ShuffleInMemorySorter.java | 12 +- .../sort/ShuffleExternalSorterSuite.scala | 111 ++++++++++++++++++ 2 files changed, 121 insertions(+), 2 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index 8f49859746b89..4b48599ad311e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -65,7 +65,7 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) { */ private int usableCapacity = 0; - private int initialSize; + private final int initialSize; ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize, boolean useRadixSort) { this.consumer = consumer; @@ -94,12 +94,20 @@ public int numRecords() { } public void reset() { + // Reset `pos` here so that `spill` triggered by the below `allocateArray` will be no-op. + pos = 0; if (consumer != null) { consumer.freeArray(array); + // As `array` has been released, we should set it to `null` to avoid accessing it before + // `allocateArray` returns. `usableCapacity` is also set to `0` to avoid any codes writing + // data to `ShuffleInMemorySorter` when `array` is `null` (e.g., in + // ShuffleExternalSorter.growPointerArrayIfNecessary, we may try to access + // `ShuffleInMemorySorter` when `allocateArray` throws SparkOutOfMemoryError). + array = null; + usableCapacity = 0; array = consumer.allocateArray(initialSize); usableCapacity = getUsableCapacity(); } - pos = 0; } public void expandPointerArray(LongArray newArray) { diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala new file mode 100644 index 0000000000000..b9f0e873375b0 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import java.lang.{Long => JLong} + +import org.mockito.Mockito.when +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark._ +import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} +import org.apache.spark.memory._ +import org.apache.spark.unsafe.Platform + +class ShuffleExternalSorterSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { + + test("nested spill should be no-op") { + val conf = new SparkConf() + .setMaster("local[1]") + .setAppName("ShuffleExternalSorterSuite") + .set("spark.testing", "true") + .set("spark.testing.memory", "1600") + .set("spark.memory.fraction", "1") + sc = new SparkContext(conf) + + val memoryManager = UnifiedMemoryManager(conf, 1) + + var shouldAllocate = false + + // Mock `TaskMemoryManager` to allocate free memory when `shouldAllocate` is true. + // This will trigger a nested spill and expose issues if we don't handle this case properly. + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) { + override def acquireExecutionMemory(required: Long, consumer: MemoryConsumer): Long = { + // ExecutionMemoryPool.acquireMemory will wait until there are 400 bytes for a task to use. + // So we leave 400 bytes for the task. + if (shouldAllocate && + memoryManager.maxHeapMemory - memoryManager.executionMemoryUsed > 400) { + val acquireExecutionMemoryMethod = + memoryManager.getClass.getMethods.filter(_.getName == "acquireExecutionMemory").head + acquireExecutionMemoryMethod.invoke( + memoryManager, + JLong.valueOf( + memoryManager.maxHeapMemory - memoryManager.executionMemoryUsed - 400), + JLong.valueOf(1L), // taskAttemptId + MemoryMode.ON_HEAP + ).asInstanceOf[java.lang.Long] + } + super.acquireExecutionMemory(required, consumer) + } + } + val taskContext = mock[TaskContext] + val taskMetrics = new TaskMetrics + when(taskContext.taskMetrics()).thenReturn(taskMetrics) + val sorter = new ShuffleExternalSorter( + taskMemoryManager, + sc.env.blockManager, + taskContext, + 100, // initialSize - This will require ShuffleInMemorySorter to acquire at least 800 bytes + 1, // numPartitions + conf, + new ShuffleWriteMetrics) + val inMemSorter = { + val field = sorter.getClass.getDeclaredField("inMemSorter") + field.setAccessible(true) + field.get(sorter).asInstanceOf[ShuffleInMemorySorter] + } + // Allocate memory to make the next "insertRecord" call triggers a spill. + val bytes = new Array[Byte](1) + while (inMemSorter.hasSpaceForAnotherRecord) { + sorter.insertRecord(bytes, Platform.BYTE_ARRAY_OFFSET, 1, 0) + } + + // This flag will make the mocked TaskMemoryManager acquire free memory released by spill to + // trigger a nested spill. + shouldAllocate = true + + // Should throw `SparkOutOfMemoryError` as there is no enough memory: `ShuffleInMemorySorter` + // will try to acquire 800 bytes but there are only 400 bytes available. + // + // Before the fix, a nested spill may use a released page and this causes two tasks access the + // same memory page. When a task reads memory written by another task, many types of failures + // may happen. Here are some examples we have seen: + // + // - JVM crash. (This is easy to reproduce in the unit test as we fill newly allocated and + // deallocated memory with 0xa5 and 0x5a bytes which usually points to an invalid memory + // address) + // - java.lang.IllegalArgumentException: Comparison method violates its general contract! + // - java.lang.NullPointerException + // at org.apache.spark.memory.TaskMemoryManager.getPage(TaskMemoryManager.java:384) + // - java.lang.UnsupportedOperationException: Cannot grow BufferHolder by size -536870912 + // because the size after growing exceeds size limitation 2147483632 + intercept[SparkOutOfMemoryError] { + sorter.insertRecord(bytes, Platform.BYTE_ARRAY_OFFSET, 1, 0) + } + } +} From 4b11d909fd9e0f55ecb1f51af64cb4ff4dbd615b Mon Sep 17 00:00:00 2001 From: liuxian Date: Sat, 11 Aug 2018 20:49:52 +0800 Subject: [PATCH 1371/2461] [MINOR][DOC] Add missing compression codec . ## What changes were proposed in this pull request? Parquet file provides six codecs: "snappy", "gzip", "lzo", "lz4", "brotli", "zstd". This pr add missing compression codec :"lz4", "brotli", "zstd" . ## How was this patch tested? N/A Closes #22068 from 10110346/nosupportlz4. Authored-by: liuxian Signed-off-by: hyukjinkwon --- python/pyspark/sql/readwriter.py | 8 ++++---- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- .../main/scala/org/apache/spark/sql/DataFrameWriter.scala | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index abf878ae709a5..49f4e6b2ede1b 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -825,10 +825,10 @@ def parquet(self, path, mode=None, partitionBy=None, compression=None): exists. :param partitionBy: names of partitioning columns :param compression: compression codec to use when saving to file. This can be one of the - known case-insensitive shorten names (none, snappy, gzip, and lzo). - This will override ``spark.sql.parquet.compression.codec``. If None - is set, it uses the value specified in - ``spark.sql.parquet.compression.codec``. + known case-insensitive shorten names (none, uncompressed, snappy, gzip, + lzo, brotli, lz4, and zstd). This will override + ``spark.sql.parquet.compression.codec``. If None is set, it uses the + value specified in ``spark.sql.parquet.compression.codec``. >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 603c0708fe7ad..594952e95dd4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -377,7 +377,7 @@ object SQLConf { "`parquet.compression` is specified in the table-specific options/properties, the " + "precedence would be `compression`, `parquet.compression`, " + "`spark.sql.parquet.compression.codec`. Acceptable values include: none, uncompressed, " + - "snappy, gzip, lzo.") + "snappy, gzip, lzo, brotli, lz4, zstd.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo", "lz4", "brotli", "zstd")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index db2a1e7426197..650c91790a758 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -548,8 +548,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *

        *
      • `compression` (default is the value specified in `spark.sql.parquet.compression.codec`): * compression codec to use when saving to file. This can be one of the known case-insensitive - * shorten names(`none`, `snappy`, `gzip`, and `lzo`). This will override - * `spark.sql.parquet.compression.codec`.
      • + * shorten names(`none`, `uncompressed`, `snappy`, `gzip`, `lzo`, `brotli`, `lz4`, and `zstd`). + * This will override `spark.sql.parquet.compression.codec`. *
      * * @since 1.4.0 From b73eb0efe87ba724733b781bf78cb0214c80c9e7 Mon Sep 17 00:00:00 2001 From: liuxian Date: Sat, 11 Aug 2018 20:53:09 +0800 Subject: [PATCH 1372/2461] [MINOR][DOC] Add missing compression codec . ## What changes were proposed in this pull request? Parquet file provides six codecs: "snappy", "gzip", "lzo", "lz4", "brotli", "zstd". This pr add missing compression codec :"lz4", "brotli", "zstd" . ## How was this patch tested? N/A Closes #22068 from 10110346/nosupportlz4. Authored-by: liuxian Signed-off-by: hyukjinkwon From 41a7de6002d071ba81321bbe02b46db4b3f8cda2 Mon Sep 17 00:00:00 2001 From: yucai Date: Sat, 11 Aug 2018 21:38:31 +0800 Subject: [PATCH 1373/2461] [SPARK-25084][SQL] "distribute by" on multiple columns (wrap in brackets) may lead to codegen issue ## What changes were proposed in this pull request? "distribute by" on multiple columns (wrap in brackets) may lead to codegen issue. Simple way to reproduce: ```scala val df = spark.range(1000) val columns = (0 until 400).map{ i => s"id as id$i" } val distributeExprs = (0 until 100).map(c => s"id$c").mkString(",") df.selectExpr(columns : _*).createTempView("test") spark.sql(s"select * from test distribute by ($distributeExprs)").count() ``` ## How was this patch tested? Add UT. Closes #22066 from yucai/SPARK-25084. Authored-by: yucai Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/expressions/hash.scala | 23 ++++++++++++++----- .../org/apache/spark/sql/SQLQuerySuite.scala | 12 ++++++++++ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index cec00b66f873c..a754e87a17968 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -404,14 +404,15 @@ abstract class HashExpression[E] extends Expression { input: String, result: String, fields: Array[StructField]): String = { + val tmpInput = ctx.freshName("input") val fieldsHash = fields.zipWithIndex.map { case (field, index) => - nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) + nullSafeElementHash(tmpInput, index.toString, field.nullable, field.dataType, result, ctx) } val hashResultType = CodeGenerator.javaType(dataType) - ctx.splitExpressions( + val code = ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", - arguments = Seq("InternalRow" -> input, hashResultType -> result), + arguments = Seq("InternalRow" -> tmpInput, hashResultType -> result), returnType = hashResultType, makeSplitFunction = body => s""" @@ -419,6 +420,10 @@ abstract class HashExpression[E] extends Expression { |return $result; """.stripMargin, foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n")) + s""" + |final InternalRow $tmpInput = $input; + |$code + """.stripMargin } @tailrec @@ -778,10 +783,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { input: String, result: String, fields: Array[StructField]): String = { + val tmpInput = ctx.freshName("input") val childResult = ctx.freshName("childResult") val fieldsHash = fields.zipWithIndex.map { case (field, index) => val computeFieldHash = nullSafeElementHash( - input, index.toString, field.nullable, field.dataType, childResult, ctx) + tmpInput, index.toString, field.nullable, field.dataType, childResult, ctx) s""" |$childResult = 0; |$computeFieldHash @@ -789,10 +795,10 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """.stripMargin } - s"${CodeGenerator.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions( + val code = ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", - arguments = Seq("InternalRow" -> input, CodeGenerator.JAVA_INT -> result), + arguments = Seq("InternalRow" -> tmpInput, CodeGenerator.JAVA_INT -> result), returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" @@ -801,6 +807,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { |return $result; """.stripMargin, foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n")) + s""" + |final InternalRow $tmpInput = $input; + |${CodeGenerator.JAVA_INT} $childResult = 0; + |$code + """.stripMargin } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c1a5f50fd82c0..84efd2b7a1dc6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2840,4 +2840,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-25084: 'distribute by' on multiple columns may lead to codegen issue") { + withView("spark_25084") { + val count = 1000 + val df = spark.range(count) + val columns = (0 until 400).map{ i => s"id as id$i" } + val distributeExprs = (0 until 100).map(c => s"id$c").mkString(",") + df.selectExpr(columns : _*).createTempView("spark_25084") + assert( + spark.sql(s"select * from spark_25084 distribute by ($distributeExprs)").count === count) + } + } } From 4855d5c4b97c21515b629a549d1fc395458f6f8a Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Sat, 11 Aug 2018 21:44:45 +0800 Subject: [PATCH 1374/2461] [SPARK-24822][PYSPARK] Python support for barrier execution mode ## What changes were proposed in this pull request? This PR add python support for barrier execution mode, thus enable launch a job containing barrier stage(s) from PySpark. We just forked the existing `RDDBarrier` and `RDD.barrier()` in Python api. ## How was this patch tested? Manually tested: ``` >>> rdd = sc.parallelize([1, 2, 3, 4]) >>> def f(iterator): yield sum(iterator) ... >>> rdd.barrier().mapPartitions(f).isBarrier() == True True ``` Unit tests will be added in a follow-up PR that implements BarrierTaskContext on python side. Closes #22011 from jiangxb1987/python. Authored-by: Xingbo Jiang Signed-off-by: Wenchen Fan --- .../apache/spark/api/python/PythonRDD.scala | 6 ++- .../org/apache/spark/rdd/RDDBarrier.scala | 1 - python/pyspark/rdd.py | 51 ++++++++++++++++++- 3 files changed, 54 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 8bc0ff7936daf..8c2ce883093c8 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -45,7 +45,8 @@ import org.apache.spark.util._ private[spark] class PythonRDD( parent: RDD[_], func: PythonFunction, - preservePartitoning: Boolean) + preservePartitoning: Boolean, + isFromBarrier: Boolean = false) extends RDD[Array[Byte]](parent) { val bufferSize = conf.getInt("spark.buffer.size", 65536) @@ -63,6 +64,9 @@ private[spark] class PythonRDD( val runner = PythonRunner(func, bufferSize, reuseWorker) runner.compute(firstParent.iterator(split, context), split.index, context) } + + @transient protected lazy override val isBarrier_ : Boolean = + isFromBarrier || dependencies.exists(_.rdd.isBarrier()) } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala index 978e7c004e5e6..b399bf9febae3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala @@ -19,7 +19,6 @@ package org.apache.spark.rdd import scala.reflect.ClassTag -import org.apache.spark.BarrierTaskContext import org.apache.spark.TaskContext import org.apache.spark.annotation.{Experimental, Since} diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 951851804b1d8..d17a8eb76ad48 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2406,6 +2406,22 @@ def toLocalIterator(self): sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd()) return _load_from_socket(sock_info, self._jrdd_deserializer) + def barrier(self): + """ + .. note:: Experimental + + Indicates that Spark must launch the tasks together for the current stage. + + .. versionadded:: 2.4.0 + """ + return RDDBarrier(self) + + def _is_barrier(self): + """ + Whether this RDD is in a barrier stage. + """ + return self._jrdd.rdd().isBarrier() + def _prepare_for_python_RDD(sc, command): # the serialized command will be compressed by broadcast @@ -2429,6 +2445,33 @@ def _wrap_function(sc, func, deserializer, serializer, profiler=None): sc.pythonVer, broadcast_vars, sc._javaAccumulator) +class RDDBarrier(object): + + """ + .. note:: Experimental + + An RDDBarrier turns an RDD into a barrier RDD, which forces Spark to launch tasks of the stage + contains this RDD together. + + .. versionadded:: 2.4.0 + """ + + def __init__(self, rdd): + self.rdd = rdd + + def mapPartitions(self, f, preservesPartitioning=False): + """ + .. note:: Experimental + + Return a new RDD by applying a function to each partition of this RDD. + + .. versionadded:: 2.4.0 + """ + def func(s, iterator): + return f(iterator) + return PipelinedRDD(self.rdd, func, preservesPartitioning, isFromBarrier=True) + + class PipelinedRDD(RDD): """ @@ -2448,7 +2491,7 @@ class PipelinedRDD(RDD): 20 """ - def __init__(self, prev, func, preservesPartitioning=False): + def __init__(self, prev, func, preservesPartitioning=False, isFromBarrier=False): if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable(): # This transformation is the first in its stage: self.func = func @@ -2474,6 +2517,7 @@ def pipeline_func(split, iterator): self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False self.partitioner = prev.partitioner if self.preservesPartitioning else None + self.is_barrier = prev._is_barrier() or isFromBarrier def getNumPartitions(self): return self._prev_jrdd.partitions().size() @@ -2493,7 +2537,7 @@ def _jrdd(self): wrapped_func = _wrap_function(self.ctx, self.func, self._prev_jrdd_deserializer, self._jrdd_deserializer, profiler) python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func, - self.preservesPartitioning) + self.preservesPartitioning, self.is_barrier) self._jrdd_val = python_rdd.asJavaRDD() if profiler: @@ -2509,6 +2553,9 @@ def id(self): def _is_pipelinable(self): return not (self.is_cached or self.is_checkpointed) + def _is_barrier(self): + return self.is_barrier + def _test(): import doctest From 8ec25cd67e7ac4a8165917a4211e17aa8f7b394d Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Sat, 11 Aug 2018 21:23:36 -0500 Subject: [PATCH 1375/2461] Fix typos detected by github.com/client9/misspell ## What changes were proposed in this pull request? Fixing typos is sometimes very hard. It's not so easy to visually review them. Recently, I discovered a very useful tool for it, [misspell](https://github.com/client9/misspell). This pull request fixes minor typos detected by [misspell](https://github.com/client9/misspell) except for the false positives. If you would like me to work on other files as well, let me know. ## How was this patch tested? ### before ``` $ misspell . | grep -v '.js' R/pkg/R/SQLContext.R:354:43: "definiton" is a misspelling of "definition" R/pkg/R/SQLContext.R:424:43: "definiton" is a misspelling of "definition" R/pkg/R/SQLContext.R:445:43: "definiton" is a misspelling of "definition" R/pkg/R/SQLContext.R:495:43: "definiton" is a misspelling of "definition" NOTICE-binary:454:16: "containd" is a misspelling of "contained" R/pkg/R/context.R:46:43: "definiton" is a misspelling of "definition" R/pkg/R/context.R:74:43: "definiton" is a misspelling of "definition" R/pkg/R/DataFrame.R:591:48: "persistance" is a misspelling of "persistence" R/pkg/R/streaming.R:166:44: "occured" is a misspelling of "occurred" R/pkg/inst/worker/worker.R:65:22: "ouput" is a misspelling of "output" R/pkg/tests/fulltests/test_utils.R:106:25: "environemnt" is a misspelling of "environment" common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java:38:39: "existant" is a misspelling of "existent" common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java:83:39: "existant" is a misspelling of "existent" common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java:243:46: "transfered" is a misspelling of "transferred" common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java:234:19: "transfered" is a misspelling of "transferred" common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java:238:63: "transfered" is a misspelling of "transferred" common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java:244:46: "transfered" is a misspelling of "transferred" common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java:276:39: "transfered" is a misspelling of "transferred" common/network-common/src/main/java/org/apache/spark/network/util/AbstractFileRegion.java:27:20: "transfered" is a misspelling of "transferred" common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala:195:15: "orgin" is a misspelling of "origin" core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala:621:39: "gauranteed" is a misspelling of "guaranteed" core/src/main/scala/org/apache/spark/status/storeTypes.scala:113:29: "ect" is a misspelling of "etc" core/src/main/scala/org/apache/spark/storage/DiskStore.scala:282:18: "transfered" is a misspelling of "transferred" core/src/main/scala/org/apache/spark/util/ListenerBus.scala:64:17: "overriden" is a misspelling of "overridden" core/src/test/scala/org/apache/spark/ShuffleSuite.scala:211:7: "substracted" is a misspelling of "subtracted" core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala:1922:49: "agriculteur" is a misspelling of "agriculture" core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala:2468:84: "truely" is a misspelling of "truly" core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala:25:18: "persistance" is a misspelling of "persistence" core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala:26:69: "persistance" is a misspelling of "persistence" data/streaming/AFINN-111.txt:1219:0: "humerous" is a misspelling of "humorous" dev/run-pip-tests:55:28: "enviroments" is a misspelling of "environments" dev/run-pip-tests:91:37: "virutal" is a misspelling of "virtual" dev/merge_spark_pr.py:377:72: "accross" is a misspelling of "across" dev/merge_spark_pr.py:378:66: "accross" is a misspelling of "across" dev/run-pip-tests:126:25: "enviroments" is a misspelling of "environments" docs/configuration.md:1830:82: "overriden" is a misspelling of "overridden" docs/structured-streaming-programming-guide.md:525:45: "processs" is a misspelling of "processes" docs/structured-streaming-programming-guide.md:1165:61: "BETWEN" is a misspelling of "BETWEEN" docs/sql-programming-guide.md:1891:810: "behaivor" is a misspelling of "behavior" examples/src/main/python/sql/arrow.py:98:8: "substract" is a misspelling of "subtract" examples/src/main/python/sql/arrow.py:103:27: "substract" is a misspelling of "subtract" licenses/LICENSE-heapq.txt:5:63: "Stichting" is a misspelling of "Stitching" licenses/LICENSE-heapq.txt:6:2: "Mathematisch" is a misspelling of "Mathematics" licenses/LICENSE-heapq.txt:262:29: "Stichting" is a misspelling of "Stitching" licenses/LICENSE-heapq.txt:262:39: "Mathematisch" is a misspelling of "Mathematics" licenses/LICENSE-heapq.txt:269:49: "Stichting" is a misspelling of "Stitching" licenses/LICENSE-heapq.txt:269:59: "Mathematisch" is a misspelling of "Mathematics" licenses/LICENSE-heapq.txt:274:2: "STICHTING" is a misspelling of "STITCHING" licenses/LICENSE-heapq.txt:274:12: "MATHEMATISCH" is a misspelling of "MATHEMATICS" licenses/LICENSE-heapq.txt:276:29: "STICHTING" is a misspelling of "STITCHING" licenses/LICENSE-heapq.txt:276:39: "MATHEMATISCH" is a misspelling of "MATHEMATICS" licenses-binary/LICENSE-heapq.txt:5:63: "Stichting" is a misspelling of "Stitching" licenses-binary/LICENSE-heapq.txt:6:2: "Mathematisch" is a misspelling of "Mathematics" licenses-binary/LICENSE-heapq.txt:262:29: "Stichting" is a misspelling of "Stitching" licenses-binary/LICENSE-heapq.txt:262:39: "Mathematisch" is a misspelling of "Mathematics" licenses-binary/LICENSE-heapq.txt:269:49: "Stichting" is a misspelling of "Stitching" licenses-binary/LICENSE-heapq.txt:269:59: "Mathematisch" is a misspelling of "Mathematics" licenses-binary/LICENSE-heapq.txt:274:2: "STICHTING" is a misspelling of "STITCHING" licenses-binary/LICENSE-heapq.txt:274:12: "MATHEMATISCH" is a misspelling of "MATHEMATICS" licenses-binary/LICENSE-heapq.txt:276:29: "STICHTING" is a misspelling of "STITCHING" licenses-binary/LICENSE-heapq.txt:276:39: "MATHEMATISCH" is a misspelling of "MATHEMATICS" mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/hungarian.txt:170:0: "teh" is a misspelling of "the" mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/portuguese.txt:53:0: "eles" is a misspelling of "eels" mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala:99:20: "Euclidian" is a misspelling of "Euclidean" mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala:539:11: "Euclidian" is a misspelling of "Euclidean" mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala:77:36: "Teh" is a misspelling of "The" mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala:230:24: "inital" is a misspelling of "initial" mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala:276:9: "Euclidian" is a misspelling of "Euclidean" mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala:237:26: "descripiton" is a misspelling of "descriptions" python/pyspark/find_spark_home.py:30:13: "enviroment" is a misspelling of "environment" python/pyspark/context.py:937:12: "supress" is a misspelling of "suppress" python/pyspark/context.py:938:12: "supress" is a misspelling of "suppress" python/pyspark/context.py:939:12: "supress" is a misspelling of "suppress" python/pyspark/context.py:940:12: "supress" is a misspelling of "suppress" python/pyspark/heapq3.py:6:63: "Stichting" is a misspelling of "Stitching" python/pyspark/heapq3.py:7:2: "Mathematisch" is a misspelling of "Mathematics" python/pyspark/heapq3.py:263:29: "Stichting" is a misspelling of "Stitching" python/pyspark/heapq3.py:263:39: "Mathematisch" is a misspelling of "Mathematics" python/pyspark/heapq3.py:270:49: "Stichting" is a misspelling of "Stitching" python/pyspark/heapq3.py:270:59: "Mathematisch" is a misspelling of "Mathematics" python/pyspark/heapq3.py:275:2: "STICHTING" is a misspelling of "STITCHING" python/pyspark/heapq3.py:275:12: "MATHEMATISCH" is a misspelling of "MATHEMATICS" python/pyspark/heapq3.py:277:29: "STICHTING" is a misspelling of "STITCHING" python/pyspark/heapq3.py:277:39: "MATHEMATISCH" is a misspelling of "MATHEMATICS" python/pyspark/heapq3.py:713:8: "probabilty" is a misspelling of "probability" python/pyspark/ml/clustering.py:1038:8: "Currenlty" is a misspelling of "Currently" python/pyspark/ml/stat.py:339:23: "Euclidian" is a misspelling of "Euclidean" python/pyspark/ml/regression.py:1378:20: "paramter" is a misspelling of "parameter" python/pyspark/mllib/stat/_statistics.py:262:8: "probabilty" is a misspelling of "probability" python/pyspark/rdd.py:1363:32: "paramter" is a misspelling of "parameter" python/pyspark/streaming/tests.py:825:42: "retuns" is a misspelling of "returns" python/pyspark/sql/tests.py:768:29: "initalization" is a misspelling of "initialization" python/pyspark/sql/tests.py:3616:31: "initalize" is a misspelling of "initialize" resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala:120:39: "arbitary" is a misspelling of "arbitrary" resource-managers/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala:26:45: "sucessfully" is a misspelling of "successfully" resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala:358:27: "constaints" is a misspelling of "constraints" resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala:111:24: "senstive" is a misspelling of "sensitive" sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala:1063:5: "overwirte" is a misspelling of "overwrite" sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala:1348:17: "compatability" is a misspelling of "compatibility" sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala:77:36: "paramter" is a misspelling of "parameter" sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:1374:22: "precendence" is a misspelling of "precedence" sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala:238:27: "unnecassary" is a misspelling of "unnecessary" sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala:212:17: "whn" is a misspelling of "when" sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala:147:60: "timestmap" is a misspelling of "timestamp" sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala:150:45: "precentage" is a misspelling of "percentage" sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala:135:29: "infered" is a misspelling of "inferred" sql/hive/src/test/resources/golden/udf_instr-1-2e76f819563dbaba4beb51e3a130b922:1:52: "occurance" is a misspelling of "occurrence" sql/hive/src/test/resources/golden/udf_instr-2-32da357fc754badd6e3898dcc8989182:1:52: "occurance" is a misspelling of "occurrence" sql/hive/src/test/resources/golden/udf_locate-1-6e41693c9c6dceea4d7fab4c02884e4e:1:63: "occurance" is a misspelling of "occurrence" sql/hive/src/test/resources/golden/udf_locate-2-d9b5934457931447874d6bb7c13de478:1:63: "occurance" is a misspelling of "occurrence" sql/hive/src/test/resources/golden/udf_translate-2-f7aa38a33ca0df73b7a1e6b6da4b7fe8:9:79: "occurence" is a misspelling of "occurrence" sql/hive/src/test/resources/golden/udf_translate-2-f7aa38a33ca0df73b7a1e6b6da4b7fe8:13:110: "occurence" is a misspelling of "occurrence" sql/hive/src/test/resources/ql/src/test/queries/clientpositive/annotate_stats_join.q:46:105: "distint" is a misspelling of "distinct" sql/hive/src/test/resources/ql/src/test/queries/clientpositive/auto_sortmerge_join_11.q:29:3: "Currenly" is a misspelling of "Currently" sql/hive/src/test/resources/ql/src/test/queries/clientpositive/avro_partitioned.q:72:15: "existant" is a misspelling of "existent" sql/hive/src/test/resources/ql/src/test/queries/clientpositive/decimal_udf.q:25:3: "substraction" is a misspelling of "subtraction" sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_multi_distinct.q:16:51: "funtion" is a misspelling of "function" sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_8.q:15:30: "issueing" is a misspelling of "issuing" sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala:669:52: "wiht" is a misspelling of "with" sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java:474:9: "Refering" is a misspelling of "Referring" ``` ### after ``` $ misspell . | grep -v '.js' common/network-common/src/main/java/org/apache/spark/network/util/AbstractFileRegion.java:27:20: "transfered" is a misspelling of "transferred" core/src/main/scala/org/apache/spark/status/storeTypes.scala:113:29: "ect" is a misspelling of "etc" core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala:1922:49: "agriculteur" is a misspelling of "agriculture" data/streaming/AFINN-111.txt:1219:0: "humerous" is a misspelling of "humorous" licenses/LICENSE-heapq.txt:5:63: "Stichting" is a misspelling of "Stitching" licenses/LICENSE-heapq.txt:6:2: "Mathematisch" is a misspelling of "Mathematics" licenses/LICENSE-heapq.txt:262:29: "Stichting" is a misspelling of "Stitching" licenses/LICENSE-heapq.txt:262:39: "Mathematisch" is a misspelling of "Mathematics" licenses/LICENSE-heapq.txt:269:49: "Stichting" is a misspelling of "Stitching" licenses/LICENSE-heapq.txt:269:59: "Mathematisch" is a misspelling of "Mathematics" licenses/LICENSE-heapq.txt:274:2: "STICHTING" is a misspelling of "STITCHING" licenses/LICENSE-heapq.txt:274:12: "MATHEMATISCH" is a misspelling of "MATHEMATICS" licenses/LICENSE-heapq.txt:276:29: "STICHTING" is a misspelling of "STITCHING" licenses/LICENSE-heapq.txt:276:39: "MATHEMATISCH" is a misspelling of "MATHEMATICS" licenses-binary/LICENSE-heapq.txt:5:63: "Stichting" is a misspelling of "Stitching" licenses-binary/LICENSE-heapq.txt:6:2: "Mathematisch" is a misspelling of "Mathematics" licenses-binary/LICENSE-heapq.txt:262:29: "Stichting" is a misspelling of "Stitching" licenses-binary/LICENSE-heapq.txt:262:39: "Mathematisch" is a misspelling of "Mathematics" licenses-binary/LICENSE-heapq.txt:269:49: "Stichting" is a misspelling of "Stitching" licenses-binary/LICENSE-heapq.txt:269:59: "Mathematisch" is a misspelling of "Mathematics" licenses-binary/LICENSE-heapq.txt:274:2: "STICHTING" is a misspelling of "STITCHING" licenses-binary/LICENSE-heapq.txt:274:12: "MATHEMATISCH" is a misspelling of "MATHEMATICS" licenses-binary/LICENSE-heapq.txt:276:29: "STICHTING" is a misspelling of "STITCHING" licenses-binary/LICENSE-heapq.txt:276:39: "MATHEMATISCH" is a misspelling of "MATHEMATICS" mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/hungarian.txt:170:0: "teh" is a misspelling of "the" mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/portuguese.txt:53:0: "eles" is a misspelling of "eels" mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala:99:20: "Euclidian" is a misspelling of "Euclidean" mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala:539:11: "Euclidian" is a misspelling of "Euclidean" mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala:77:36: "Teh" is a misspelling of "The" mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala:276:9: "Euclidian" is a misspelling of "Euclidean" python/pyspark/heapq3.py:6:63: "Stichting" is a misspelling of "Stitching" python/pyspark/heapq3.py:7:2: "Mathematisch" is a misspelling of "Mathematics" python/pyspark/heapq3.py:263:29: "Stichting" is a misspelling of "Stitching" python/pyspark/heapq3.py:263:39: "Mathematisch" is a misspelling of "Mathematics" python/pyspark/heapq3.py:270:49: "Stichting" is a misspelling of "Stitching" python/pyspark/heapq3.py:270:59: "Mathematisch" is a misspelling of "Mathematics" python/pyspark/heapq3.py:275:2: "STICHTING" is a misspelling of "STITCHING" python/pyspark/heapq3.py:275:12: "MATHEMATISCH" is a misspelling of "MATHEMATICS" python/pyspark/heapq3.py:277:29: "STICHTING" is a misspelling of "STITCHING" python/pyspark/heapq3.py:277:39: "MATHEMATISCH" is a misspelling of "MATHEMATICS" python/pyspark/ml/stat.py:339:23: "Euclidian" is a misspelling of "Euclidean" ``` Closes #22070 from seratch/fix-typo. Authored-by: Kazuhiro Sera Signed-off-by: Sean Owen --- NOTICE-binary | 4 ++-- R/pkg/R/DataFrame.R | 2 +- R/pkg/R/SQLContext.R | 8 ++++---- R/pkg/R/context.R | 4 ++-- R/pkg/R/streaming.R | 2 +- R/pkg/inst/worker/worker.R | 2 +- R/pkg/tests/fulltests/test_utils.R | 2 +- .../org/apache/spark/util/kvstore/InMemoryStoreSuite.java | 2 +- .../java/org/apache/spark/util/kvstore/LevelDBSuite.java | 2 +- .../org/apache/spark/network/crypto/TransportCipher.java | 2 +- .../org/apache/spark/network/sasl/SaslEncryption.java | 8 ++++---- .../spark/unsafe/types/UTF8StringPropertyCheckSuite.scala | 4 ++-- .../scala/org/apache/spark/api/python/PythonRDD.scala | 2 +- .../main/scala/org/apache/spark/storage/DiskStore.scala | 2 +- .../main/scala/org/apache/spark/util/ListenerBus.scala | 2 +- core/src/test/scala/org/apache/spark/ShuffleSuite.scala | 2 +- .../org/apache/spark/scheduler/DAGSchedulerSuite.scala | 2 +- .../org/apache/spark/storage/FlatmapIteratorSuite.scala | 4 ++-- dev/merge_spark_pr.py | 4 ++-- dev/run-pip-tests | 6 +++--- docs/configuration.md | 2 +- docs/sql-programming-guide.md | 2 +- docs/structured-streaming-programming-guide.md | 4 ++-- examples/src/main/python/sql/arrow.py | 4 ++-- .../apache/spark/mllib/clustering/StreamingKMeans.scala | 2 +- .../org/apache/spark/ml/clustering/KMeansSuite.scala | 2 +- python/pyspark/context.py | 8 ++++---- python/pyspark/find_spark_home.py | 2 +- python/pyspark/heapq3.py | 2 +- python/pyspark/ml/clustering.py | 2 +- python/pyspark/ml/regression.py | 2 +- python/pyspark/mllib/stat/_statistics.py | 2 +- python/pyspark/rdd.py | 2 +- python/pyspark/sql/tests.py | 4 ++-- python/pyspark/streaming/tests.py | 2 +- .../cluster/mesos/MesosSchedulerBackendUtil.scala | 2 +- .../scheduler/cluster/mesos/MesosSchedulerUtils.scala | 2 +- .../mesos/MesosClusterDispatcherArgumentsSuite.scala | 2 +- .../org/apache/spark/deploy/yarn/YarnClusterSuite.scala | 2 +- .../spark/sql/catalyst/catalog/SessionCatalog.scala | 2 +- .../sql/catalyst/expressions/datetimeExpressions.scala | 2 +- .../catalyst/plans/logical/basicLogicalOperators.scala | 2 +- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- .../spark/sql/catalyst/analysis/AnalysisSuite.scala | 2 +- .../catalyst/expressions/ConditionalExpressionSuite.scala | 2 +- .../streaming/StreamingSymmetricHashJoinHelper.scala | 2 +- .../test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala | 2 +- .../execution/datasources/csv/CSVInferSchemaSuite.scala | 2 +- .../apache/hive/service/cli/session/HiveSessionImpl.java | 2 +- .../golden/udf_instr-1-2e76f819563dbaba4beb51e3a130b922 | 2 +- .../golden/udf_instr-2-32da357fc754badd6e3898dcc8989182 | 2 +- .../golden/udf_locate-1-6e41693c9c6dceea4d7fab4c02884e4e | 2 +- .../golden/udf_locate-2-d9b5934457931447874d6bb7c13de478 | 2 +- .../udf_translate-2-f7aa38a33ca0df73b7a1e6b6da4b7fe8 | 4 ++-- .../src/test/queries/clientpositive/annotate_stats_join.q | 2 +- .../test/queries/clientpositive/auto_sortmerge_join_11.q | 2 +- .../ql/src/test/queries/clientpositive/avro_partitioned.q | 2 +- .../ql/src/test/queries/clientpositive/decimal_udf.q | 2 +- .../queries/clientpositive/groupby2_map_multi_distinct.q | 2 +- .../ql/src/test/queries/clientpositive/groupby_sort_8.q | 2 +- .../apache/spark/sql/sources/HadoopFsRelationTest.scala | 2 +- 61 files changed, 81 insertions(+), 81 deletions(-) diff --git a/NOTICE-binary b/NOTICE-binary index 3155c3843ee7c..ad256aaf9f968 100644 --- a/NOTICE-binary +++ b/NOTICE-binary @@ -476,7 +476,7 @@ which has the following notices: PureJavaCrc32C from apache-hadoop-common http://hadoop.apache.org/ (Apache 2.0 license) - This library containd statically linked libstdc++. This inclusion is allowed by + This library contains statically linked libstdc++. This inclusion is allowed by "GCC RUntime Library Exception" http://gcc.gnu.org/onlinedocs/libstdc++/manual/license.html @@ -1192,4 +1192,4 @@ Apache Solr (http://lucene.apache.org/solr/) Copyright 2014 The Apache Software Foundation Apache Mahout (http://mahout.apache.org/) -Copyright 2014 The Apache Software Foundation \ No newline at end of file +Copyright 2014 The Apache Software Foundation diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 70eb7a874b75c..471ada15d655e 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -588,7 +588,7 @@ setMethod("cache", #' \url{http://spark.apache.org/docs/latest/rdd-programming-guide.html#rdd-persistence}. #' #' @param x the SparkDataFrame to persist. -#' @param newLevel storage level chosen for the persistance. See available options in +#' @param newLevel storage level chosen for the persistence. See available options in #' the description. #' #' @family SparkDataFrame functions diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 429dd5d565492..c819a7d14ae98 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -351,7 +351,7 @@ setMethod("toDF", signature(x = "RDD"), read.json.default <- function(path, ...) { sparkSession <- getSparkSession() options <- varargsToStrEnv(...) - # Allow the user to have a more flexible definiton of the text file path + # Allow the user to have a more flexible definition of the text file path paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) @@ -421,7 +421,7 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { read.orc <- function(path, ...) { sparkSession <- getSparkSession() options <- varargsToStrEnv(...) - # Allow the user to have a more flexible definiton of the ORC file path + # Allow the user to have a more flexible definition of the ORC file path path <- suppressWarnings(normalizePath(path)) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) @@ -442,7 +442,7 @@ read.orc <- function(path, ...) { read.parquet.default <- function(path, ...) { sparkSession <- getSparkSession() options <- varargsToStrEnv(...) - # Allow the user to have a more flexible definiton of the Parquet file path + # Allow the user to have a more flexible definition of the Parquet file path paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) @@ -492,7 +492,7 @@ parquetFile <- function(x, ...) { read.text.default <- function(path, ...) { sparkSession <- getSparkSession() options <- varargsToStrEnv(...) - # Allow the user to have a more flexible definiton of the text file path + # Allow the user to have a more flexible definition of the text file path paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 3e996a5ba26fc..7e77ea4e002d9 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -43,7 +43,7 @@ getMinPartitions <- function(sc, minPartitions) { #' lines <- textFile(sc, "myfile.txt") #'} textFile <- function(sc, path, minPartitions = NULL) { - # Allow the user to have a more flexible definiton of the text file path + # Allow the user to have a more flexible definition of the text file path path <- suppressWarnings(normalizePath(path)) # Convert a string vector of paths to a string containing comma separated paths path <- paste(path, collapse = ",") @@ -71,7 +71,7 @@ textFile <- function(sc, path, minPartitions = NULL) { #' rdd <- objectFile(sc, "myfile") #'} objectFile <- function(sc, path, minPartitions = NULL) { - # Allow the user to have a more flexible definiton of the text file path + # Allow the user to have a more flexible definition of the text file path path <- suppressWarnings(normalizePath(path)) # Convert a string vector of paths to a string containing comma separated paths path <- paste(path, collapse = ",") diff --git a/R/pkg/R/streaming.R b/R/pkg/R/streaming.R index fc83463f72cd4..5eccbdc9d3818 100644 --- a/R/pkg/R/streaming.R +++ b/R/pkg/R/streaming.R @@ -163,7 +163,7 @@ setMethod("isActive", #' #' @param x a StreamingQuery. #' @param timeout time to wait in milliseconds, if omitted, wait indefinitely until \code{stopQuery} -#' is called or an error has occured. +#' is called or an error has occurred. #' @return TRUE if query has terminated within the timeout period; nothing if timeout is not #' specified. #' @rdname awaitTermination diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index ba458d2b9ddfb..c2adf613acb02 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -62,7 +62,7 @@ compute <- function(mode, partition, serializer, deserializer, key, # Transform the result data.frame back to a list of rows output <- split(output, seq(nrow(output))) } else { - # Serialize the ouput to a byte array + # Serialize the output to a byte array stopifnot(serializer == "byte") } } else { diff --git a/R/pkg/tests/fulltests/test_utils.R b/R/pkg/tests/fulltests/test_utils.R index f0292ab335592..b2b6f34aaa085 100644 --- a/R/pkg/tests/fulltests/test_utils.R +++ b/R/pkg/tests/fulltests/test_utils.R @@ -103,7 +103,7 @@ test_that("cleanClosure on R functions", { expect_true("l" %in% ls(env)) expect_true("f" %in% ls(env)) expect_equal(get("l", envir = env, inherits = FALSE), l) - # "y" should be in the environemnt of g. + # "y" should be in the environment of g. newG <- get("g", envir = env, inherits = FALSE) env <- environment(newG) expect_equal(length(ls(env)), 1) diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java index 510b3058a4e3c..9abf26f02f7a7 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java @@ -35,7 +35,7 @@ public void testObjectWriteReadDelete() throws Exception { try { store.read(CustomType1.class, t.key); - fail("Expected exception for non-existant object."); + fail("Expected exception for non-existent object."); } catch (NoSuchElementException nsee) { // Expected. } diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java index b8123ac81d29a..205f7df87c5bc 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java @@ -80,7 +80,7 @@ public void testObjectWriteReadDelete() throws Exception { try { db.read(CustomType1.class, t.key); - fail("Expected exception for non-existant object."); + fail("Expected exception for non-existent object."); } catch (NoSuchElementException nsee) { // Expected. } diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java index 452408df19061..b64e4b7a970b5 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java @@ -240,7 +240,7 @@ public boolean release(int decrement) { @Override public long transferTo(WritableByteChannel target, long position) throws IOException { - Preconditions.checkArgument(position == transfered(), "Invalid position."); + Preconditions.checkArgument(position == transferred(), "Invalid position."); do { if (currentEncrypted == null) { diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java index 1dcf1324839eb..e1275689ae6a0 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java @@ -231,17 +231,17 @@ public boolean release(int decrement) { * data into memory at once, and can avoid ballooning memory usage when transferring large * messages such as shuffle blocks. * - * The {@link #transfered()} counter also behaves a little funny, in that it won't go forward + * The {@link #transferred()} counter also behaves a little funny, in that it won't go forward * until a whole chunk has been written. This is done because the code can't use the actual * number of bytes written to the channel as the transferred count (see {@link #count()}). * Instead, once an encrypted chunk is written to the output (including its header), the - * size of the original block will be added to the {@link #transfered()} amount. + * size of the original block will be added to the {@link #transferred()} amount. */ @Override public long transferTo(final WritableByteChannel target, final long position) throws IOException { - Preconditions.checkArgument(position == transfered(), "Invalid position."); + Preconditions.checkArgument(position == transferred(), "Invalid position."); long reportedWritten = 0L; long actuallyWritten = 0L; @@ -273,7 +273,7 @@ public long transferTo(final WritableByteChannel target, final long position) currentChunkSize = 0; currentReportedBytes = 0; } - } while (currentChunk == null && transfered() + reportedWritten < count()); + } while (currentChunk == null && transferred() + reportedWritten < count()); // Returning 0 triggers a backoff mechanism in netty which may harm performance. Instead, // we return 1 until we can (i.e. until the reported count would actually match the size diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala index 48004e812a8bf..7d3331f44f015 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala @@ -192,8 +192,8 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty val nullalbeSeq = Gen.listOf(Gen.oneOf[String](null: String, randomString)) test("concat") { - def concat(orgin: Seq[String]): String = - if (orgin.contains(null)) null else orgin.mkString + def concat(origin: Seq[String]): String = + if (origin.contains(null)) null else origin.mkString forAll { (inputs: Seq[String]) => assert(UTF8String.concat(inputs.map(toUTF8): _*) === toUTF8(inputs.mkString)) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 8c2ce883093c8..c3db60a23f987 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -622,7 +622,7 @@ private[spark] class PythonAccumulatorV2( override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized { val otherPythonAccumulator = other.asInstanceOf[PythonAccumulatorV2] // This conditional isn't strictly speaking needed - merging only currently happens on the - // driver program - but that isn't gauranteed so incase this changes. + // driver program - but that isn't guaranteed so incase this changes. if (serverHost == null) { // We are on the worker super.merge(otherPythonAccumulator) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 39249d411b582..ef526fd884058 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -279,7 +279,7 @@ private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: override def transferred(): Long = _transferred override def transferTo(target: WritableByteChannel, pos: Long): Long = { - assert(pos == transfered(), "Invalid position.") + assert(pos == transferred(), "Invalid position.") var written = 0L var lastWrite = -1L diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index d4474a90b26f1..a8f10684d5a2c 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -61,7 +61,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { } /** - * This can be overriden by subclasses if there is any extra cleanup to do when removing a + * This can be overridden by subclasses if there is any extra cleanup to do when removing a * listener. In particular AsyncEventQueues can clean up queues in the LiveListenerBus. */ def removeListenerOnError(listener: L): Unit = { diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index d11eaf8c2749c..456f97b535ef6 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -208,7 +208,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC val pairs2: RDD[MutablePair[Int, String]] = sc.parallelize(data2, 2) val results = new SubtractedRDD(pairs1, pairs2, new HashPartitioner(2)).collect() results should have length (1) - // substracted rdd return results as Tuple2 + // subtracted rdd return results as Tuple2 results(0) should be ((3, 33)) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 5e095ce1ae4e9..3fbe636607687 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -2465,7 +2465,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi runEvent(makeCompletionEvent( taskSets(1).tasks(1), Success, makeMapStatus("hostA", 2))) - // Both tasks in rddB should be resubmitted, because none of them has succeeded truely. + // Both tasks in rddB should be resubmitted, because none of them has succeeded truly. // Complete the task(stageId=1, stageAttemptId=1, partitionId=0) successfully. // Task(stageId=1, stageAttemptId=1, partitionId=1) of this new active stage attempt // is still running. diff --git a/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala index b21c91f75d5c7..42828506895a7 100644 --- a/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala @@ -22,8 +22,8 @@ import org.apache.spark._ class FlatmapIteratorSuite extends SparkFunSuite with LocalSparkContext { /* Tests the ability of Spark to deal with user provided iterators from flatMap * calls, that may generate more data then available memory. In any - * memory based persistance Spark will unroll the iterator into an ArrayBuffer - * for caching, however in the case that the use defines DISK_ONLY persistance, + * memory based persistence Spark will unroll the iterator into an ArrayBuffer + * for caching, however in the case that the use defines DISK_ONLY persistence, * the iterator will be fed directly to the serializer and written to disk. * * This also tests the ObjectOutputStream reset rate. When serializing using the diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 7a6f7d2b891d3..fe05282efdd4d 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -374,8 +374,8 @@ def standardize_jira_ref(text): >>> standardize_jira_ref("[SPARK-979] a LRU scheduler for load balancing in TaskSchedulerImpl") '[SPARK-979] a LRU scheduler for load balancing in TaskSchedulerImpl' >>> standardize_jira_ref( - ... "SPARK-1094 Support MiMa for reporting binary compatibility accross versions.") - '[SPARK-1094] Support MiMa for reporting binary compatibility accross versions.' + ... "SPARK-1094 Support MiMa for reporting binary compatibility across versions.") + '[SPARK-1094] Support MiMa for reporting binary compatibility across versions.' >>> standardize_jira_ref("[WIP] [SPARK-1146] Vagrant support for Spark") '[SPARK-1146][WIP] Vagrant support for Spark' >>> standardize_jira_ref( diff --git a/dev/run-pip-tests b/dev/run-pip-tests index 7271d1014e4ae..60cf4d8209416 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -52,7 +52,7 @@ if hash virtualenv 2>/dev/null && [ ! -n "$USE_CONDA" ]; then PYTHON_EXECS+=('python3') fi elif hash conda 2>/dev/null; then - echo "Using conda virtual enviroments" + echo "Using conda virtual environments" PYTHON_EXECS=('3.5') USE_CONDA=1 else @@ -88,7 +88,7 @@ for python in "${PYTHON_EXECS[@]}"; do virtualenv --python=$python "$VIRTUALENV_PATH" source "$VIRTUALENV_PATH"/bin/activate fi - # Upgrade pip & friends if using virutal env + # Upgrade pip & friends if using virtual env if [ ! -n "$USE_CONDA" ]; then pip install --upgrade pip pypandoc wheel numpy fi @@ -123,7 +123,7 @@ for python in "${PYTHON_EXECS[@]}"; do cd "$FWDIR" - # conda / virtualenv enviroments need to be deactivated differently + # conda / virtualenv environments need to be deactivated differently if [ -n "$USE_CONDA" ]; then source deactivate else diff --git a/docs/configuration.md b/docs/configuration.md index 4911abb0f5cfc..9c4742a1c0c85 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1827,7 +1827,7 @@ Apart from these, the following properties are also available, and may be useful executors w.r.t. full parallelism. Defaults to 1.0 to give maximum parallelism. 0.5 will divide the target number of executors by 2 - The target number of executors computed by the dynamicAllocation can still be overriden + The target number of executors computed by the dynamicAllocation can still be overridden by the spark.dynamicAllocation.minExecutors and spark.dynamicAllocation.maxExecutors settings diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 9adb86aa60bbb..d9ebc3cfe4674 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1888,7 +1888,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, renaming a managed table to existing location is not allowed. An exception is thrown when attempting to rename a managed table to existing location. - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. - Since Spark 2.4, Spark has enabled non-cascading SQL cache invalidation in addition to the traditional cache invalidation mechanism. The non-cascading cache invalidation mechanism allows users to remove a cache without impacting its dependent caches. This new cache invalidation mechanism is used in scenarios where the data of the cache to be removed is still valid, e.g., calling unpersist() on a Dataset, or dropping a temporary view. This allows users to free up memory and keep the desired caches valid at the same time. - - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. + - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behavior to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. - In version 2.3 and earlier, CSV rows are considered as malformed if at least one column value in the row is malformed. CSV parser dropped such rows in the DROPMALFORMED mode or outputs an error in the FAILFAST mode. Since Spark 2.4, CSV row is considered as malformed only when it contains malformed column values requested from CSV datasource, other values can be ignored. As an example, CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selection of the id column consists of a row with one column value 1234 but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 0842e8dd88672..b832f7197ace6 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -522,7 +522,7 @@ Here are the details of all the sources in Spark.
      maxFilesPerTrigger: maximum number of new files to be considered in every trigger (default: no max)
      - latestFirst: whether to processs the latest new files first, useful when there is a large backlog of files (default: false) + latestFirst: whether to process the latest new files first, useful when there is a large backlog of files (default: false)
      fileNameOnly: whether to check new files based on only the filename instead of on the full path (default: false). With this set to `true`, the following files would be considered as the same file, because their filenames, "dataset.txt", are the same:
      @@ -1162,7 +1162,7 @@ In other words, you will have to do the following additional steps in the join. old rows of one input is not going to be required (i.e. will not satisfy the time constraint) for matches with the other input. This constraint can be defined in one of the two ways. - 1. Time range join conditions (e.g. `...JOIN ON leftTime BETWEN rightTime AND rightTime + INTERVAL 1 HOUR`), + 1. Time range join conditions (e.g. `...JOIN ON leftTime BETWEEN rightTime AND rightTime + INTERVAL 1 HOUR`), 1. Join on event-time windows (e.g. `...JOIN ON leftTimeWindow = rightTimeWindow`). diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py index 6c4510d9e3c01..5eb164b20ad04 100644 --- a/examples/src/main/python/sql/arrow.py +++ b/examples/src/main/python/sql/arrow.py @@ -95,12 +95,12 @@ def grouped_map_pandas_udf_example(spark): ("id", "v")) @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) - def substract_mean(pdf): + def subtract_mean(pdf): # pdf is a pandas.DataFrame v = pdf.v return pdf.assign(v=v - v.mean()) - df.groupby("id").apply(substract_mean).show() + df.groupby("id").apply(subtract_mean).show() # +---+----+ # | id| v| # +---+----+ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 7a5e520d5818e..ed8543da4d4ce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -227,7 +227,7 @@ class StreamingKMeans @Since("1.2.0") ( require(centers.size == k, s"Number of initial centers must be ${k} but got ${centers.size}") require(weights.forall(_ >= 0), - s"Weight for each inital center must be nonnegative but got [${weights.mkString(" ")}]") + s"Weight for each initial center must be nonnegative but got [${weights.mkString(" ")}]") model = new StreamingKMeansModel(centers, weights) this } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 9b0b52617755c..ccbceab53bb66 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -234,7 +234,7 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes val oldKmeansModel = new MLlibKMeansModel(clusterCenters) val kmeansModel = new KMeansModel("", oldKmeansModel) def checkModel(pmml: PMML): Unit = { - // Check the header descripiton is what we expect + // Check the header description is what we expect assert(pmml.getHeader.getDescription === "k-means clustering") // check that the number of fields match the single vector size assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 0ff4f5be0a228..40208ecff75b8 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -934,10 +934,10 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False): >>> def stop_job(): ... sleep(5) ... sc.cancelJobGroup("job_to_cancel") - >>> supress = lock.acquire() - >>> supress = threading.Thread(target=start_job, args=(10,)).start() - >>> supress = threading.Thread(target=stop_job).start() - >>> supress = lock.acquire() + >>> suppress = lock.acquire() + >>> suppress = threading.Thread(target=start_job, args=(10,)).start() + >>> suppress = threading.Thread(target=stop_job).start() + >>> suppress = lock.acquire() >>> print(result) Cancelled diff --git a/python/pyspark/find_spark_home.py b/python/pyspark/find_spark_home.py index 9cf0e8c8d2fe9..9c4ed46598632 100755 --- a/python/pyspark/find_spark_home.py +++ b/python/pyspark/find_spark_home.py @@ -27,7 +27,7 @@ def _find_spark_home(): """Find the SPARK_HOME.""" - # If the enviroment has SPARK_HOME set trust it. + # If the environment has SPARK_HOME set trust it. if "SPARK_HOME" in os.environ: return os.environ["SPARK_HOME"] diff --git a/python/pyspark/heapq3.py b/python/pyspark/heapq3.py index 6af084adcf373..37a2914ebac05 100644 --- a/python/pyspark/heapq3.py +++ b/python/pyspark/heapq3.py @@ -710,7 +710,7 @@ def merge(iterables, key=None, reverse=False): # value seen being in the 100 most extreme values is 100/101. # * If the value is a new extreme value, the cost of inserting it into the # heap is 1 + log(k, 2). -# * The probabilty times the cost gives: +# * The probability times the cost gives: # (k/i) * (1 + log(k, 2)) # * Summing across the remaining n-k elements gives: # sum((k/i) * (1 + log(k, 2)) for i in range(k+1, n+1)) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index ef9822d0ca5a5..ab449bc3f8f51 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -1035,7 +1035,7 @@ def getK(self): def setOptimizer(self, value): """ Sets the value of :py:attr:`optimizer`. - Currenlty only support 'em' and 'online'. + Currently only support 'em' and 'online'. >>> algo = LDA().setOptimizer("em") >>> algo.getOptimizer() diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 564c9f1b8f729..513ca5a9df85e 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -1375,7 +1375,7 @@ def intercept(self): @since("1.6.0") def scale(self): """ - Model scale paramter. + Model scale parameter. """ return self._call_java("scale") diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 937bb154c2356..6e89bfd691d16 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -259,7 +259,7 @@ def kolmogorovSmirnovTest(data, distName="norm", *params): The KS statistic gives us the maximum distance between the ECDF and the CDF. Intuitively if this statistic is large, the - probabilty that the null hypothesis is true becomes small. + probability that the null hypothesis is true becomes small. For specific details of the implementation, please have a look at the Scala documentation. diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d17a8eb76ad48..ba39edbc93d7c 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1360,7 +1360,7 @@ def take(self, num): if len(items) == 0: numPartsToTry = partsScanned * 4 else: - # the first paramter of max is >=1 whenever partsScanned >= 2 + # the first parameter of max is >=1 whenever partsScanned >= 2 numPartsToTry = int(1.5 * num * partsScanned / len(items)) - partsScanned numPartsToTry = min(max(numPartsToTry, 1), partsScanned * 4) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ed97a6394a98c..91ed600afedd7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -765,7 +765,7 @@ def filename(path): row2 = df2.select(sameText(df2['file'])).first() self.assertTrue(row2[0].find("people.json") != -1) - def test_udf_defers_judf_initalization(self): + def test_udf_defers_judf_initialization(self): # This is separate of UDFInitializationTests # to avoid context initialization # when udf is called @@ -3613,7 +3613,7 @@ def tearDown(self): if SparkContext._active_spark_context is not None: SparkContext._active_spark_context.stop() - def test_udf_init_shouldnt_initalize_context(self): + def test_udf_init_shouldnt_initialize_context(self): from pyspark.sql.functions import UserDefinedFunction UserDefinedFunction(lambda x: x, StringType()) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 373784f826677..09af47a597bed 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -822,7 +822,7 @@ def setupFunc(): self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) self.assertTrue(self.setupCalled) - # Verify that getActiveOrCreate() retuns active context and does not call the setupFunc + # Verify that getActiveOrCreate() returns active context and does not call the setupFunc self.ssc.start() self.setupCalled = False self.assertEqual(StreamingContext.getActiveOrCreate(None, setupFunc), self.ssc) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index bfb73611f0530..b4364a5e2eb3a 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -117,7 +117,7 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { case Array(key, value) => Some(param.setKey(key).setValue(value)) case spec => - logWarning(s"Unable to parse arbitary parameters: $params. " + logWarning(s"Unable to parse arbitrary parameters: $params. " + "Expected form: \"key=value(, ...)\"") None } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index ecbcc960fc5a0..8ef1e18f83de3 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -355,7 +355,7 @@ trait MesosSchedulerUtils extends Logging { * https://github.com/apache/mesos/blob/master/src/common/values.cpp * https://github.com/apache/mesos/blob/master/src/common/attributes.cpp * - * @param constraintsVal constaints string consisting of ';' separated key-value pairs (separated + * @param constraintsVal constains string consisting of ';' separated key-value pairs (separated * by ':') * @return Map of constraints to match resources offers. */ diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala index 33e7d69d53d38..057c51db455ef 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.deploy.TestPrematureExit class MesosClusterDispatcherArgumentsSuite extends SparkFunSuite with TestPrematureExit { - test("test if spark config args are passed sucessfully") { + test("test if spark config args are passed successfully") { val args = Array[String]("--master", "mesos://localhost:5050", "--conf", "key1=value1", "--conf", "spark.mesos.key2=value2", "--verbose") val conf = new SparkConf() diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 3b78b88de778d..d67f5d2768e49 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -108,7 +108,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { "spark.executor.cores" -> "1", "spark.executor.memory" -> "512m", "spark.executor.instances" -> "2", - // Sending some senstive information, which we'll make sure gets redacted + // Sending some sensitive information, which we'll make sure gets redacted "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD, "spark.yarn.appMasterEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD )) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index cd243b87652f4..ee3932cc56d01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1060,7 +1060,7 @@ class SessionCatalog( } /** - * overwirte a metastore function in the database specified in `funcDefinition`.. + * overwrite a metastore function in the database specified in `funcDefinition`.. * If no database is specified, assume the function is in the current database. */ def alterFunction(funcDefinition: CatalogFunction): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 08838d2b2c612..f95798d64db19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1345,7 +1345,7 @@ case class ParseToDate(left: Expression, format: Option[Expression], child: Expr } def this(left: Expression) = { - // backwards compatability + // backwards compatibility this(left, None, Cast(left, DateType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index a6631a8d444e9..7ff83a9be3622 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -74,7 +74,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) * their output. * * @param generator the generator expression - * @param unrequiredChildIndex this paramter starts as Nil and gets filled by the Optimizer. + * @param unrequiredChildIndex this parameter starts as Nil and gets filled by the Optimizer. * It's used as an optimization for omitting data generation that will * be discarded next by a projection. * A common use case is when we explode(array(..)) and are interested diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 594952e95dd4e..dbb5bb43b4f1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1371,7 +1371,7 @@ object SQLConf { "mode to keep the same behavior of Spark prior to 2.3. Note that this config doesn't " + "affect Hive serde tables, as they are always overwritten with dynamic mode. This can " + "also be set as an output option for a data source using key partitionOverwriteMode " + - "(which takes precendence over this setting), e.g. " + + "(which takes precedence over this setting), e.g. " + "dataframe.write.option(\"partitionOverwriteMode\", \"dynamic\").save(path)." ) .stringConf diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index a1c976dd923f1..94f37f1aafa78 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -235,7 +235,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { checkAnalysis(plan, expected) } - test("Analysis may leave unnecassary aliases") { + test("Analysis may leave unnecessary aliases") { val att1 = testRelation.output.head var plan = testRelation.select( CreateStruct(Seq(att1, ((att1.as("aa")) + 1).as("a_plus_1"))).as("col"), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index e068c32500cfc..f489d330cf453 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -209,7 +209,7 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), null, row) } - test("case key whn - internal pattern matching expects a List while apply takes a Seq") { + test("case key when - internal pattern matching expects a List while apply takes a Seq") { val indexedSeq = IndexedSeq(Literal(1), Literal(42), Literal(42), Literal(1)) val caseKeyWhaen = CaseKeyWhen(Literal(12), indexedSeq) assert(caseKeyWhaen.branches == diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala index 4aba76cad367e..2d4c3c10e6445 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala @@ -144,7 +144,7 @@ object StreamingSymmetricHashJoinHelper extends Logging { // Join keys of both sides generate rows of the same fields, that is, same sequence of data - // types. If one side (say left side) has a column (say timestmap) that has a watermark on it, + // types. If one side (say left side) has a column (say timestamp) that has a watermark on it, // then it will never consider joining keys that are < state key watermark (i.e. event time // watermark). On the other side (i.e. right side), even if there is no watermark defined, // there has to be an equivalent column (i.e., timestamp). And any right side data that has the diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala index bc95b4696190d..817224d1c28ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -147,7 +147,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`s_company_id` INT, `s_company_name` STRING, `s_street_number` STRING, |`s_street_name` STRING, `s_street_type` STRING, `s_suite_number` STRING, `s_city` STRING, |`s_county` STRING, `s_state` STRING, `s_zip` STRING, `s_country` STRING, - |`s_gmt_offset` DECIMAL(5,2), `s_tax_precentage` DECIMAL(5,2)) + |`s_gmt_offset` DECIMAL(5,2), `s_tax_percentage` DECIMAL(5,2)) |USING parquet """.stripMargin) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index 842251be92c18..57e36e082653c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -132,7 +132,7 @@ class CSVInferSchemaSuite extends SparkFunSuite { == StringType) } - test("DoubleType should be infered when user defined nan/inf are provided") { + test("DoubleType should be inferred when user defined nan/inf are provided") { val options = new CSVOptions(Map("nanValue" -> "nan", "negativeInf" -> "-inf", "positiveInf" -> "inf"), false, "GMT") assert(CSVInferSchema.inferField(NullType, "nan", options) == DoubleType) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java index f59cdcd3188e6..745f385e87f78 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java @@ -471,7 +471,7 @@ private OperationHandle executeStatementInternal(String statement, Map SELECT instr('Facebook', 'boo') FROM src LIMIT 1; 5 diff --git a/sql/hive/src/test/resources/golden/udf_locate-1-6e41693c9c6dceea4d7fab4c02884e4e b/sql/hive/src/test/resources/golden/udf_locate-1-6e41693c9c6dceea4d7fab4c02884e4e index 84bea329540d1..8e70b0c89b594 100644 --- a/sql/hive/src/test/resources/golden/udf_locate-1-6e41693c9c6dceea4d7fab4c02884e4e +++ b/sql/hive/src/test/resources/golden/udf_locate-1-6e41693c9c6dceea4d7fab4c02884e4e @@ -1 +1 @@ -locate(substr, str[, pos]) - Returns the position of the first occurance of substr in str after position pos +locate(substr, str[, pos]) - Returns the position of the first occurrence of substr in str after position pos diff --git a/sql/hive/src/test/resources/golden/udf_locate-2-d9b5934457931447874d6bb7c13de478 b/sql/hive/src/test/resources/golden/udf_locate-2-d9b5934457931447874d6bb7c13de478 index 092e12586b9e8..e103255a31f03 100644 --- a/sql/hive/src/test/resources/golden/udf_locate-2-d9b5934457931447874d6bb7c13de478 +++ b/sql/hive/src/test/resources/golden/udf_locate-2-d9b5934457931447874d6bb7c13de478 @@ -1,4 +1,4 @@ -locate(substr, str[, pos]) - Returns the position of the first occurance of substr in str after position pos +locate(substr, str[, pos]) - Returns the position of the first occurrence of substr in str after position pos Example: > SELECT locate('bar', 'foobarbar', 5) FROM src LIMIT 1; 7 diff --git a/sql/hive/src/test/resources/golden/udf_translate-2-f7aa38a33ca0df73b7a1e6b6da4b7fe8 b/sql/hive/src/test/resources/golden/udf_translate-2-f7aa38a33ca0df73b7a1e6b6da4b7fe8 index 9ced4ee32cf0b..6caa4b679111d 100644 --- a/sql/hive/src/test/resources/golden/udf_translate-2-f7aa38a33ca0df73b7a1e6b6da4b7fe8 +++ b/sql/hive/src/test/resources/golden/udf_translate-2-f7aa38a33ca0df73b7a1e6b6da4b7fe8 @@ -6,8 +6,8 @@ translate('abcdef', 'adc', '19') returns '1b9ef' replacing 'a' with '1', 'd' wit translate('a b c d', ' ', '') return 'abcd' removing all spaces from the input string -If the same character is present multiple times in the input string, the first occurence of the character is the one that's considered for matching. However, it is not recommended to have the same character more than once in the from string since it's not required and adds to confusion. +If the same character is present multiple times in the input string, the first occurrence of the character is the one that's considered for matching. However, it is not recommended to have the same character more than once in the from string since it's not required and adds to confusion. For example, -translate('abcdef', 'ada', '192') returns '1bc9ef' replaces 'a' with '1' and 'd' with '9' ignoring the second occurence of 'a' in the from string mapping it to '2' +translate('abcdef', 'ada', '192') returns '1bc9ef' replaces 'a' with '1' and 'd' with '9' ignoring the second occurrence of 'a' in the from string mapping it to '2' diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/annotate_stats_join.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/annotate_stats_join.q index 965b0b7ed0a3e..633150b5cf544 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/annotate_stats_join.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/annotate_stats_join.q @@ -43,7 +43,7 @@ analyze table loc_orc compute statistics for columns state,locid,zip,year; -- dept_orc - 4 -- loc_orc - 8 --- count distincts for relevant columns (since count distinct values are approximate in some cases count distint values will be greater than number of rows) +-- count distincts for relevant columns (since count distinct values are approximate in some cases count distinct values will be greater than number of rows) -- emp_orc.deptid - 3 -- emp_orc.lastname - 7 -- dept_orc.deptid - 6 diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/auto_sortmerge_join_11.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/auto_sortmerge_join_11.q index da2e26fde7069..e8289772e7544 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/auto_sortmerge_join_11.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/auto_sortmerge_join_11.q @@ -26,7 +26,7 @@ set hive.optimize.bucketmapjoin.sortedmerge=true; -- Since size is being used to find the big table, the order of the tables in the join does not matter -- The tables are only bucketed and not sorted, the join should not be converted --- Currenly, a join is only converted to a sort-merge join without a hint, automatic conversion to +-- Currently, a join is only converted to a sort-merge join without a hint, automatic conversion to -- bucketized mapjoin is not done explain extended select count(*) FROM bucket_small a JOIN bucket_big b ON a.key = b.key; select count(*) FROM bucket_small a JOIN bucket_big b ON a.key = b.key; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/avro_partitioned.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/avro_partitioned.q index 6fe5117026ce8..e4ed7195a0575 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/avro_partitioned.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/avro_partitioned.q @@ -69,5 +69,5 @@ SELECT * FROM episodes_partitioned WHERE doctor_pt > 6 ORDER BY air_date; SELECT * FROM episodes_partitioned ORDER BY air_date LIMIT 5; -- Fetch w/filter to specific partition SELECT * FROM episodes_partitioned WHERE doctor_pt = 6; --- Fetch w/non-existant partition +-- Fetch w/non-existent partition SELECT * FROM episodes_partitioned WHERE doctor_pt = 7 LIMIT 5; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/decimal_udf.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/decimal_udf.q index 0c9f1b86a9e97..39d2d248a311f 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/decimal_udf.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/decimal_udf.q @@ -22,7 +22,7 @@ SELECT key + (value/2) FROM DECIMAL_UDF; EXPLAIN SELECT key + '1.0' FROM DECIMAL_UDF; SELECT key + '1.0' FROM DECIMAL_UDF; --- substraction +-- subtraction EXPLAIN SELECT key - key FROM DECIMAL_UDF; SELECT key - key FROM DECIMAL_UDF; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_multi_distinct.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_multi_distinct.q index 3aeae0d5c33d6..d677fe65245ed 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_multi_distinct.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_multi_distinct.q @@ -13,7 +13,7 @@ INSERT OVERWRITE TABLE dest1 SELECT substr(src.key,1,1), count(DISTINCT substr(s SELECT dest1.* FROM dest1 ORDER BY key; --- HIVE-5560 when group by key is used in distinct funtion, invalid result are returned +-- HIVE-5560 when group by key is used in distinct function, invalid result are returned EXPLAIN FROM src diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_8.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_8.q index f53295e4b2435..69d671aa47116 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_8.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_8.q @@ -12,7 +12,7 @@ LOAD DATA LOCAL INPATH '../../data/files/T1.txt' INTO TABLE T1 PARTITION (ds='1' INSERT OVERWRITE TABLE T1 PARTITION (ds='1') select key, val from T1 where ds = '1'; -- The plan is not converted to a map-side, since although the sorting columns and grouping --- columns match, the user is issueing a distinct. +-- columns match, the user is issuing a distinct. -- However, after HIVE-4310, partial aggregation is performed on the mapper EXPLAIN select count(distinct key) from T1; diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index 53397991e59dc..b9ec940ac4925 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -666,7 +666,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes assert(expectedResult.isRight, s"Was not expecting error with $path: " + e) assert( e.getMessage.contains(expectedResult.right.get), - s"Did not find expected error message wiht $path") + s"Did not find expected error message with $path") } } From c3be2cd347c42972d9c499b6fd9a6f988f80af12 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Sat, 11 Aug 2018 22:51:11 -0700 Subject: [PATCH 1376/2461] [SPARK-25092] Add RewriteExceptAll and RewriteIntersectAll in the list of nonExcludableRules ## What changes were proposed in this pull request? Add RewriteExceptAll and RewriteIntersectAll in the list of nonExcludableRules as the rewrites are essential for the functioning of EXCEPT ALL and INTERSECT ALL feature. ## How was this patch tested? Added test in OptimizerRuleExclusionSuite. Closes #22080 from dilipbiswal/exceptall_rewrite_exclusion. Authored-by: Dilip Biswal Signed-off-by: Xiao Li --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 6 ++++-- .../catalyst/optimizer/OptimizerRuleExclusionSuite.scala | 6 ++++-- .../spark/sql/catalyst/optimizer/SetOperationSuite.scala | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 105623c767d66..2ff67689c3492 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -135,7 +135,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) Batch("Subquery", Once, OptimizeSubqueries) :: Batch("Replace Operators", fixedPoint, - RewriteExcepAll, + RewriteExceptAll, RewriteIntersectAll, ReplaceIntersectWithSemiJoin, ReplaceExceptWithFilter, @@ -189,6 +189,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) ReplaceIntersectWithSemiJoin.ruleName :: ReplaceExceptWithFilter.ruleName :: ReplaceExceptWithAntiJoin.ruleName :: + RewriteExceptAll.ruleName :: + RewriteIntersectAll.ruleName :: ReplaceDistinctWithAggregate.ruleName :: PullupCorrelatedPredicates.ruleName :: RewritePredicateSubquery.ruleName :: Nil @@ -1462,7 +1464,7 @@ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] { * }}} */ -object RewriteExcepAll extends Rule[LogicalPlan] { +object RewriteExceptAll extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Except(left, right, true) => assert(left.output.size == right.output.size) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala index 30c80d26b67a1..eee8dc3b76c34 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala @@ -80,12 +80,14 @@ class OptimizerRuleExclusionSuite extends PlanTest { "DummyRuleName")) } - test("Try to exclude a non-excludable rule") { + test("Try to exclude some non-excludable rules") { verifyExcludedRules( new SimpleTestOptimizer(), Seq( ReplaceIntersectWithSemiJoin.ruleName, - PullupCorrelatedPredicates.ruleName)) + PullupCorrelatedPredicates.ruleName, + RewriteExceptAll.ruleName, + RewriteIntersectAll.ruleName)) } test("Custom optimizer") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index cb744be400603..da3923f8d6477 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -148,7 +148,7 @@ class SetOperationSuite extends PlanTest { test("EXCEPT ALL rewrite") { val input = Except(testRelation, testRelation2, isAll = true) - val rewrittenPlan = RewriteExcepAll(input) + val rewrittenPlan = RewriteExceptAll(input) val planFragment = testRelation.select(Literal(1L).as("vcol"), 'a, 'b, 'c) .union(testRelation2.select(Literal(-1L).as("vcol"), 'd, 'e, 'f)) From d17723479239ed6ca9f043623d229a972d71f8c9 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 12 Aug 2018 19:43:25 +0800 Subject: [PATCH 1377/2461] [SQL][TEST][MINOR] Add missing codes to ParquetCompressionCodecPrecedenceSuite ## What changes were proposed in this pull request? This PR adds codes to ``"Test `spark.sql.parquet.compression.codec` config"` in `ParquetCompressionCodecPrecedenceSuite`. ## How was this patch tested? Existing UTs Closes #22083 from kiszk/ParquetCompressionCodecPrecedenceSuite. Authored-by: Kazuaki Ishizaki Signed-off-by: hyukjinkwon --- .../parquet/ParquetCompressionCodecPrecedenceSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala index ed8fd2b453456..09de715e87a11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.test.SharedSQLContext class ParquetCompressionCodecPrecedenceSuite extends ParquetTest with SharedSQLContext { test("Test `spark.sql.parquet.compression.codec` config") { - Seq("NONE", "UNCOMPRESSED", "SNAPPY", "GZIP", "LZO").foreach { c => + Seq("NONE", "UNCOMPRESSED", "SNAPPY", "GZIP", "LZO", "LZ4", "BROTLI", "ZSTD").foreach { c => withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> c) { val expected = if (c == "NONE") "UNCOMPRESSED" else c val option = new ParquetOptions(Map.empty[String, String], spark.sessionState.conf) From 5bc7598b25dbf5ea4b3e0f149aa31fb03a5310f9 Mon Sep 17 00:00:00 2001 From: Tynan CR Date: Sun, 12 Aug 2018 08:13:09 -0500 Subject: [PATCH 1378/2461] Fix typos ## What changes were proposed in this pull request? Small typo fixes in Pyspark. These were the only ones I stumbled across after looking around for a while. ## How was this patch tested? Manually Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22016 from tynan-cr/typo-fix-pyspark. Authored-by: Tynan CR Signed-off-by: Sean Owen --- python/pyspark/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 40208ecff75b8..b77fa0ee2892b 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -126,7 +126,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self.environment = environment or {} # java gateway must have been launched at this point. if conf is not None and conf._jconf is not None: - # conf has been initialized in JVM properly, so use conf directly. This represent the + # conf has been initialized in JVM properly, so use conf directly. This represents the # scenario that JVM has been launched before SparkConf is created (e.g. SparkContext is # created and then stopped, and we create a new SparkConf and new SparkContext again) self._conf = conf From a90b1f5d93d2eccca46c9c525c03a13ae55fd967 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=8D=E5=86=AC?= Date: Sun, 12 Aug 2018 08:26:21 -0500 Subject: [PATCH 1379/2461] [MINOR][DOC] Fix Java example code in Column's comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Fix scaladoc in Column ## How was this patch tested? None Closes #22069 from sadhen/fix_doc_minor. Authored-by: 忍冬 Signed-off-by: Sean Owen --- .../scala/org/apache/spark/sql/Column.scala | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 4eee3de5f7d4e..ae27690f2e5ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -345,7 +345,7 @@ class Column(val expr: Expression) extends Logging { * * // Java: * import static org.apache.spark.sql.functions.*; - * people.select( people("age").gt(21) ); + * people.select( people.col("age").gt(21) ); * }}} * * @group expr_ops @@ -361,7 +361,7 @@ class Column(val expr: Expression) extends Logging { * * // Java: * import static org.apache.spark.sql.functions.*; - * people.select( people("age").gt(21) ); + * people.select( people.col("age").gt(21) ); * }}} * * @group java_expr_ops @@ -376,7 +376,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("age") < 21 ) * * // Java: - * people.select( people("age").lt(21) ); + * people.select( people.col("age").lt(21) ); * }}} * * @group expr_ops @@ -391,7 +391,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("age") < 21 ) * * // Java: - * people.select( people("age").lt(21) ); + * people.select( people.col("age").lt(21) ); * }}} * * @group java_expr_ops @@ -406,7 +406,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("age") <= 21 ) * * // Java: - * people.select( people("age").leq(21) ); + * people.select( people.col("age").leq(21) ); * }}} * * @group expr_ops @@ -421,7 +421,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("age") <= 21 ) * * // Java: - * people.select( people("age").leq(21) ); + * people.select( people.col("age").leq(21) ); * }}} * * @group java_expr_ops @@ -436,7 +436,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("age") >= 21 ) * * // Java: - * people.select( people("age").geq(21) ) + * people.select( people.col("age").geq(21) ) * }}} * * @group expr_ops @@ -451,7 +451,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("age") >= 21 ) * * // Java: - * people.select( people("age").geq(21) ) + * people.select( people.col("age").geq(21) ) * }}} * * @group java_expr_ops @@ -588,7 +588,7 @@ class Column(val expr: Expression) extends Logging { * people.filter( people("inSchool") || people("isEmployed") ) * * // Java: - * people.filter( people("inSchool").or(people("isEmployed")) ); + * people.filter( people.col("inSchool").or(people.col("isEmployed")) ); * }}} * * @group expr_ops @@ -603,7 +603,7 @@ class Column(val expr: Expression) extends Logging { * people.filter( people("inSchool") || people("isEmployed") ) * * // Java: - * people.filter( people("inSchool").or(people("isEmployed")) ); + * people.filter( people.col("inSchool").or(people.col("isEmployed")) ); * }}} * * @group java_expr_ops @@ -618,7 +618,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("inSchool") && people("isEmployed") ) * * // Java: - * people.select( people("inSchool").and(people("isEmployed")) ); + * people.select( people.col("inSchool").and(people.col("isEmployed")) ); * }}} * * @group expr_ops @@ -633,7 +633,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("inSchool") && people("isEmployed") ) * * // Java: - * people.select( people("inSchool").and(people("isEmployed")) ); + * people.select( people.col("inSchool").and(people.col("isEmployed")) ); * }}} * * @group java_expr_ops @@ -648,7 +648,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") + people("weight") ) * * // Java: - * people.select( people("height").plus(people("weight")) ); + * people.select( people.col("height").plus(people.col("weight")) ); * }}} * * @group expr_ops @@ -663,7 +663,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") + people("weight") ) * * // Java: - * people.select( people("height").plus(people("weight")) ); + * people.select( people.col("height").plus(people.col("weight")) ); * }}} * * @group java_expr_ops @@ -678,7 +678,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") - people("weight") ) * * // Java: - * people.select( people("height").minus(people("weight")) ); + * people.select( people.col("height").minus(people.col("weight")) ); * }}} * * @group expr_ops @@ -693,7 +693,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") - people("weight") ) * * // Java: - * people.select( people("height").minus(people("weight")) ); + * people.select( people.col("height").minus(people.col("weight")) ); * }}} * * @group java_expr_ops @@ -708,7 +708,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") * people("weight") ) * * // Java: - * people.select( people("height").multiply(people("weight")) ); + * people.select( people.col("height").multiply(people.col("weight")) ); * }}} * * @group expr_ops @@ -723,7 +723,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") * people("weight") ) * * // Java: - * people.select( people("height").multiply(people("weight")) ); + * people.select( people.col("height").multiply(people.col("weight")) ); * }}} * * @group java_expr_ops @@ -738,7 +738,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") / people("weight") ) * * // Java: - * people.select( people("height").divide(people("weight")) ); + * people.select( people.col("height").divide(people.col("weight")) ); * }}} * * @group expr_ops @@ -753,7 +753,7 @@ class Column(val expr: Expression) extends Logging { * people.select( people("height") / people("weight") ) * * // Java: - * people.select( people("height").divide(people("weight")) ); + * people.select( people.col("height").divide(people.col("weight")) ); * }}} * * @group java_expr_ops From be2238fb502b0f49a8a1baa6da9bc3e99540b40e Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 13 Aug 2018 08:29:07 +0800 Subject: [PATCH 1380/2461] [SPARK-24774][SQL] Avro: Support logical decimal type ## What changes were proposed in this pull request? Support Avro logical date type: https://avro.apache.org/docs/1.8.2/spec.html#Decimal ## How was this patch tested? Unit test Closes #22037 from gengliangwang/avro_decimal. Authored-by: Gengliang Wang Signed-off-by: Wenchen Fan --- .../spark/sql/avro/AvroDeserializer.scala | 34 +++++- .../spark/sql/avro/AvroSerializer.scala | 12 +- .../spark/sql/avro/SchemaConverters.scala | 54 ++++++--- .../AvroCatalystDataConversionSuite.scala | 2 - .../org/apache/spark/sql/avro/AvroSuite.scala | 103 +++++++++++++++++- .../org/apache/spark/sql/types/Decimal.scala | 19 ++++ .../parquet/ParquetSchemaConverter.scala | 26 +---- .../parquet/ParquetWriteSupport.scala | 6 +- 8 files changed, 207 insertions(+), 49 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index 74677a29afcb4..272e7d5b388d9 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.avro +import java.math.{BigDecimal} import java.nio.ByteBuffer import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import org.apache.avro.{Schema, SchemaBuilder} +import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder} +import org.apache.avro.Conversions.DecimalConversion import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis} import org.apache.avro.Schema.Type._ import org.apache.avro.generic._ @@ -38,6 +40,8 @@ import org.apache.spark.unsafe.types.UTF8String * A deserializer to deserialize data in avro format to data in catalyst format. */ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) { + private lazy val decimalConversions = new DecimalConversion() + private val converter: Any => Any = rootCatalystType match { // A shortcut for empty schema. case st: StructType if st.isEmpty => @@ -138,10 +142,21 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) { bytes case b: Array[Byte] => b case other => throw new RuntimeException(s"$other is not a valid avro binary.") - } updater.set(ordinal, bytes) + case (FIXED, d: DecimalType) => (updater, ordinal, value) => + val bigDecimal = decimalConversions.fromFixed(value.asInstanceOf[GenericFixed], avroType, + LogicalTypes.decimal(d.precision, d.scale)) + val decimal = createDecimal(bigDecimal, d.precision, d.scale) + updater.setDecimal(ordinal, decimal) + + case (BYTES, d: DecimalType) => (updater, ordinal, value) => + val bigDecimal = decimalConversions.fromBytes(value.asInstanceOf[ByteBuffer], avroType, + LogicalTypes.decimal(d.precision, d.scale)) + val decimal = createDecimal(bigDecimal, d.precision, d.scale) + updater.setDecimal(ordinal, decimal) + case (RECORD, st: StructType) => val writeRecord = getRecordWriter(avroType, st, path) (updater, ordinal, value) => @@ -263,6 +278,17 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) { s"Target Catalyst type: $rootCatalystType") } + // TODO: move the following method in Decimal object on creating Decimal from BigDecimal? + private def createDecimal(decimal: BigDecimal, precision: Int, scale: Int): Decimal = { + if (precision <= Decimal.MAX_LONG_DIGITS) { + // Constructs a `Decimal` with an unscaled `Long` value if possible. + Decimal(decimal.unscaledValue().longValue(), precision, scale) + } else { + // Otherwise, resorts to an unscaled `BigInteger` instead. + Decimal(decimal, precision, scale) + } + } + private def getRecordWriter( avroType: Schema, sqlType: StructType, @@ -334,6 +360,7 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) { def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value) def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value) def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value) + def setDecimal(ordinal: Int, value: Decimal): Unit = set(ordinal, value) } final class RowUpdater(row: InternalRow) extends CatalystDataUpdater { @@ -347,6 +374,8 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) { override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value) override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value) override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value) + override def setDecimal(ordinal: Int, value: Decimal): Unit = + row.setDecimal(ordinal, value, value.precision) } final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater { @@ -360,5 +389,6 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) { override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value) override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value) override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value) + override def setDecimal(ordinal: Int, value: Decimal): Unit = array.update(ordinal, value) } } diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala index 216c52a5cfd26..3a9544c3f48cd 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -21,15 +21,17 @@ import java.nio.ByteBuffer import scala.collection.JavaConverters._ +import org.apache.avro.{LogicalTypes, Schema} +import org.apache.avro.Conversions.DecimalConversion import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis} import org.apache.avro.Schema import org.apache.avro.Schema.Type import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record} +import org.apache.avro.generic.GenericData.Record import org.apache.avro.util.Utf8 import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow} -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ /** @@ -67,6 +69,8 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: private type Converter = (SpecializedGetters, Int) => Any + private lazy val decimalConversions = new DecimalConversion() + private def newConverter(catalystType: DataType, avroType: Schema): Converter = { catalystType match { case NullType => @@ -86,7 +90,11 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: case DoubleType => (getter, ordinal) => getter.getDouble(ordinal) case d: DecimalType => - (getter, ordinal) => getter.getDecimal(ordinal, d.precision, d.scale).toString + (getter, ordinal) => + val decimal = getter.getDecimal(ordinal, d.precision, d.scale) + decimalConversions.toFixed(decimal.toJavaBigDecimal, avroType, + LogicalTypes.decimal(d.precision, d.scale)) + case StringType => avroType.getType match { case Type.ENUM => import scala.collection.JavaConverters._ diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala index 245e68d242c50..7b33cf6e6e055 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -18,19 +18,28 @@ package org.apache.spark.sql.avro import scala.collection.JavaConverters._ +import scala.util.Random +import com.fasterxml.jackson.annotation.ObjectIdGenerators.UUIDGenerator import org.apache.avro.{LogicalType, LogicalTypes, Schema, SchemaBuilder} -import org.apache.avro.LogicalTypes.{Date, TimestampMicros, TimestampMillis} +import org.apache.avro.LogicalTypes.{Date, Decimal, TimestampMicros, TimestampMillis} import org.apache.avro.Schema.Type._ +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator import org.apache.spark.sql.internal.SQLConf.AvroOutputTimestampType import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.Decimal.{maxPrecisionForBytes, minBytesForPrecision} /** * This object contains method that are used to convert sparkSQL schemas to avro schemas and vice * versa. */ object SchemaConverters { + private lazy val uuidGenerator = RandomUUIDGenerator(new Random().nextLong()) + + private lazy val nullSchema = Schema.create(Schema.Type.NULL) + case class SchemaType(dataType: DataType, nullable: Boolean) /** @@ -44,14 +53,20 @@ object SchemaConverters { } case STRING => SchemaType(StringType, nullable = false) case BOOLEAN => SchemaType(BooleanType, nullable = false) - case BYTES => SchemaType(BinaryType, nullable = false) + case BYTES | FIXED => avroSchema.getLogicalType match { + // For FIXED type, if the precision requires more bytes than fixed size, the logical + // type will be null, which is handled by Avro library. + case d: Decimal => SchemaType(DecimalType(d.getPrecision, d.getScale), nullable = false) + case _ => SchemaType(BinaryType, nullable = false) + } + case DOUBLE => SchemaType(DoubleType, nullable = false) case FLOAT => SchemaType(FloatType, nullable = false) case LONG => avroSchema.getLogicalType match { case _: TimestampMillis | _: TimestampMicros => SchemaType(TimestampType, nullable = false) case _ => SchemaType(LongType, nullable = false) } - case FIXED => SchemaType(BinaryType, nullable = false) + case ENUM => SchemaType(StringType, nullable = false) case RECORD => @@ -114,20 +129,14 @@ object SchemaConverters { prevNameSpace: String = "", outputTimestampType: AvroOutputTimestampType.Value = AvroOutputTimestampType.TIMESTAMP_MICROS) : Schema = { - val builder = if (nullable) { - SchemaBuilder.builder().nullable() - } else { - SchemaBuilder.builder() - } + val builder = SchemaBuilder.builder() - catalystType match { + val schema = catalystType match { case BooleanType => builder.booleanType() case ByteType | ShortType | IntegerType => builder.intType() case LongType => builder.longType() - case DateType => builder - .intBuilder() - .prop(LogicalType.LOGICAL_TYPE_PROP, LogicalTypes.date().getName) - .endInt() + case DateType => + LogicalTypes.date().addToSchema(builder.intType()) case TimestampType => val timestampType = outputTimestampType match { case AvroOutputTimestampType.TIMESTAMP_MILLIS => LogicalTypes.timestampMillis() @@ -135,11 +144,21 @@ object SchemaConverters { case other => throw new IncompatibleSchemaException(s"Unexpected output timestamp type $other.") } - builder.longBuilder().prop(LogicalType.LOGICAL_TYPE_PROP, timestampType.getName).endLong() + timestampType.addToSchema(builder.longType()) case FloatType => builder.floatType() case DoubleType => builder.doubleType() - case _: DecimalType | StringType => builder.stringType() + case StringType => builder.stringType() + case d: DecimalType => + val avroType = LogicalTypes.decimal(d.precision, d.scale) + val fixedSize = minBytesForPrecision(d.precision) + // Need to avoid naming conflict for the fixed fields + val name = prevNameSpace match { + case "" => s"$recordName.fixed" + case _ => s"$prevNameSpace.$recordName.fixed" + } + avroType.addToSchema(SchemaBuilder.fixed(name).size(fixedSize)) + case BinaryType => builder.bytesType() case ArrayType(et, containsNull) => builder.array() @@ -164,6 +183,11 @@ object SchemaConverters { // This should never happen. case other => throw new IncompatibleSchemaException(s"Unexpected type $other.") } + if (nullable) { + Schema.createUnion(schema, nullSchema) + } else { + schema + } } } diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala index 06d5477b2ea45..4b3bf0cd52957 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala @@ -64,8 +64,6 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite with ExpressionEvalH BinaryType) protected def prepareExpectedResult(expected: Any): Any = expected match { - // Spark decimal is converted to avro string= - case d: Decimal => UTF8String.fromString(d.toString) // Spark byte and short both map to avro int case b: Byte => b.toInt case s: Short => s.toInt diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index ada9980e65b1c..3fa43bf929761 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -25,7 +25,8 @@ import java.util.{TimeZone, UUID} import scala.collection.JavaConverters._ -import org.apache.avro.Schema +import org.apache.avro.{LogicalTypes, Schema} +import org.apache.avro.Conversions.DecimalConversion import org.apache.avro.Schema.{Field, Type} import org.apache.avro.file.{DataFileReader, DataFileWriter} import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord} @@ -494,6 +495,104 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { checkAnswer(df, expected) } + test("Logical type: Decimal") { + val precision = 4 + val scale = 2 + val bytesFieldName = "bytes" + val bytesSchema = s"""{ + "type":"bytes", + "logicalType":"decimal", + "precision":$precision, + "scale":$scale + } + """ + + val fixedFieldName = "fixed" + val fixedSchema = s"""{ + "type":"fixed", + "size":5, + "logicalType":"decimal", + "precision":$precision, + "scale":$scale, + "name":"foo" + } + """ + val avroSchema = s""" + { + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [ + {"name": "$bytesFieldName", "type": $bytesSchema}, + {"name": "$fixedFieldName", "type": $fixedSchema} + ] + } + """ + val schema = new Schema.Parser().parse(avroSchema) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + val decimalConversion = new DecimalConversion + withTempDir { dir => + val avroFile = s"$dir.avro" + dataFileWriter.create(schema, new File(avroFile)) + val logicalType = LogicalTypes.decimal(precision, scale) + val data = Seq("1.23", "4.56", "78.90", "-1", "-2.31") + data.map { x => + val avroRec = new GenericData.Record(schema) + val decimal = new java.math.BigDecimal(x).setScale(scale) + val bytes = + decimalConversion.toBytes(decimal, schema.getField(bytesFieldName).schema, logicalType) + avroRec.put(bytesFieldName, bytes) + val fixed = + decimalConversion.toFixed(decimal, schema.getField(fixedFieldName).schema, logicalType) + avroRec.put(fixedFieldName, fixed) + dataFileWriter.append(avroRec) + } + dataFileWriter.flush() + dataFileWriter.close() + + val expected = data.map { x => Row(new java.math.BigDecimal(x), new java.math.BigDecimal(x)) } + val df = spark.read.format("avro").load(avroFile) + checkAnswer(df, expected) + checkAnswer(spark.read.format("avro").option("avroSchema", avroSchema).load(avroFile), + expected) + + withTempPath { path => + df.write.format("avro").save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + + test("Logical type: Decimal with too large precision") { + withTempDir { dir => + val schema = new Schema.Parser().parse("""{ + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [{ + "name": "decimal", + "type": {"type": "bytes", "logicalType": "decimal", "precision": 4, "scale": 2} + }] + }""") + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + val decimal = new java.math.BigDecimal("0.12345678901234567890123456789012345678") + val bytes = (new DecimalConversion).toBytes(decimal, schema, LogicalTypes.decimal(39, 38)) + avroRec.put("decimal", bytes) + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + val msg = intercept[SparkException] { + spark.read.format("avro").load(s"$dir.avro").collect() + }.getCause.getMessage + assert(msg.contains("Unscaled value too large for precision")) + } + } + test("Array data types") { withTempPath { dir => val testSchema = StructType(Seq( @@ -689,7 +788,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { // DecimalType should be converted to string val decimals = spark.read.format("avro").load(avroDir).select("Decimal").collect() - assert(decimals.map(_(0)).contains("3.14")) + assert(decimals.map(_(0)).contains(new java.math.BigDecimal("3.14"))) // There should be a null entry val length = spark.read.format("avro").load(avroDir).select("Length").collect() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 6da4f28b12962..9eed2eb202045 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -479,6 +479,25 @@ object Decimal { dec } + // Max precision of a decimal value stored in `numBytes` bytes + def maxPrecisionForBytes(numBytes: Int): Int = { + Math.round( // convert double to long + Math.floor(Math.log10( // number of base-10 digits + Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes + .asInstanceOf[Int] + } + + // Returns the minimum number of bytes needed to store a decimal with a given `precision`. + lazy val minBytesForPrecision = Array.tabulate[Int](39)(computeMinBytesForPrecision) + + private def computeMinBytesForPrecision(precision : Int) : Int = { + var numBytes = 1 + while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) { + numBytes += 1 + } + numBytes + } + // Evidence parameters for Decimal considered either as Fractional or Integral. We provide two // parameters inheriting from a common trait since both traits define mkNumericOps. // See scala.math's Numeric.scala for examples for Scala's built-in types. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index 70f42f2c4ad79..8ce8a86d2f026 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -26,7 +26,6 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition._ import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter.maxPrecisionForBytes import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -171,7 +170,7 @@ class ParquetToSparkSchemaConverter( case FIXED_LEN_BYTE_ARRAY => originalType match { - case DECIMAL => makeDecimalType(maxPrecisionForBytes(field.getTypeLength)) + case DECIMAL => makeDecimalType(Decimal.maxPrecisionForBytes(field.getTypeLength)) case INTERVAL => typeNotImplemented() case _ => illegalType() } @@ -411,7 +410,7 @@ class SparkToParquetSchemaConverter( .as(DECIMAL) .precision(precision) .scale(scale) - .length(ParquetSchemaConverter.minBytesForPrecision(precision)) + .length(Decimal.minBytesForPrecision(precision)) .named(field.name) // ======================== @@ -445,7 +444,7 @@ class SparkToParquetSchemaConverter( .as(DECIMAL) .precision(precision) .scale(scale) - .length(ParquetSchemaConverter.minBytesForPrecision(precision)) + .length(Decimal.minBytesForPrecision(precision)) .named(field.name) // =================================== @@ -584,23 +583,4 @@ private[sql] object ParquetSchemaConverter { throw new AnalysisException(message) } } - - private def computeMinBytesForPrecision(precision : Int) : Int = { - var numBytes = 1 - while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) { - numBytes += 1 - } - numBytes - } - - // Returns the minimum number of bytes needed to store a decimal with a given `precision`. - val minBytesForPrecision = Array.tabulate[Int](39)(computeMinBytesForPrecision) - - // Max precision of a decimal value stored in `numBytes` bytes - def maxPrecisionForBytes(numBytes: Int): Int = { - Math.round( // convert double to long - Math.floor(Math.log10( // number of base-10 digits - Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes - .asInstanceOf[Int] - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala index af4e1433c876f..b40b8c2e61f33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala @@ -33,7 +33,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter.minBytesForPrecision import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -73,7 +72,8 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit private val timestampBuffer = new Array[Byte](12) // Reusable byte array used to write decimal values - private val decimalBuffer = new Array[Byte](minBytesForPrecision(DecimalType.MAX_PRECISION)) + private val decimalBuffer = + new Array[Byte](Decimal.minBytesForPrecision(DecimalType.MAX_PRECISION)) override def init(configuration: Configuration): WriteContext = { val schemaString = configuration.get(ParquetWriteSupport.SPARK_ROW_SCHEMA) @@ -212,7 +212,7 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit precision <= DecimalType.MAX_PRECISION, s"Decimal precision $precision exceeds max precision ${DecimalType.MAX_PRECISION}") - val numBytes = minBytesForPrecision(precision) + val numBytes = Decimal.minBytesForPrecision(precision) val int32Writer = (row: SpecializedGetters, ordinal: Int) => { From 02d0a1ffd9adb4bf898905095318eb099ed1807f Mon Sep 17 00:00:00 2001 From: 10129659 Date: Mon, 13 Aug 2018 09:09:25 +0800 Subject: [PATCH 1381/2461] [SPARK-25069][CORE] Using UnsafeAlignedOffset to make the entire record of 8 byte Items aligned like which is used in UnsafeExternalSorter ## What changes were proposed in this pull request? The class of UnsafeExternalSorter used UnsafeAlignedOffset to make the entire record of 8 byte Items aligned, but ShuffleExternalSorter not. The SPARC platform requires this because using a 4 byte Int for record lengths causes the entire record of 8 byte Items to become misaligned by 4 bytes. Using a 8 byte long for record length keeps things 8 byte aligned. ## How was this patch tested? Existing Test. Closes #22053 from eatoncys/UnsafeAlignedOffset. Authored-by: 10129659 Signed-off-by: Wenchen Fan --- .../spark/shuffle/sort/ShuffleExternalSorter.java | 15 +++++++++------ .../unsafe/sort/UnsafeExternalSorter.java | 2 +- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index c3a07b2abf896..c7d2db4217d96 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -43,6 +43,7 @@ import org.apache.spark.storage.FileSegment; import org.apache.spark.storage.TempShuffleBlockId; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.Utils; @@ -184,6 +185,7 @@ private void writeSortedFile(boolean isLastFile) { blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse); int currentPartition = -1; + final int uaoSize = UnsafeAlignedOffset.getUaoSize(); while (sortedRecords.hasNext()) { sortedRecords.loadNext(); final int partition = sortedRecords.packedRecordPointer.getPartitionId(); @@ -200,8 +202,8 @@ private void writeSortedFile(boolean isLastFile) { final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); final Object recordPage = taskMemoryManager.getPage(recordPointer); final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer); - int dataRemaining = Platform.getInt(recordPage, recordOffsetInPage); - long recordReadPosition = recordOffsetInPage + 4; // skip over record length + int dataRemaining = UnsafeAlignedOffset.getSize(recordPage, recordOffsetInPage); + long recordReadPosition = recordOffsetInPage + uaoSize; // skip over record length while (dataRemaining > 0) { final int toTransfer = Math.min(diskWriteBufferSize, dataRemaining); Platform.copyMemory( @@ -389,15 +391,16 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p } growPointerArrayIfNecessary(); - // Need 4 bytes to store the record length. - final int required = length + 4; + final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + // Need 4 or 8 bytes to store the record length. + final int required = length + uaoSize; acquireNewPageIfNecessary(required); assert(currentPage != null); final Object base = currentPage.getBaseObject(); final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor); - Platform.putInt(base, pageCursor, length); - pageCursor += 4; + UnsafeAlignedOffset.putSize(base, pageCursor, length); + pageCursor += uaoSize; Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length); pageCursor += length; inMemSorter.insertRecord(recordAddress, partitionId); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 4fc19b1721518..399251b80e649 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -402,7 +402,7 @@ public void insertRecord( growPointerArrayIfNecessary(); int uaoSize = UnsafeAlignedOffset.getUaoSize(); - // Need 4 bytes to store the record length. + // Need 4 or 8 bytes to store the record length. final int required = length + uaoSize; acquireNewPageIfNecessary(required); From 20fa45693238cd39e162b129214f5d6a93e5552e Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 13 Aug 2018 09:11:37 +0800 Subject: [PATCH 1382/2461] [SPARK-25090][ML] Enforce implicit type coercion in ParamGridBuilder ## What changes were proposed in this pull request? When the grid of the parameters is created in `ParamGridBuilder`, the implicit type coercion is not enforced. So using an integer in the list of parameters to set for a parameter accepting a double can cause a class cast exception. The PR proposes to enforce the type coercion when building the parameters. ## How was this patch tested? added UT Closes #22076 from mgaido91/SPARK-25090. Authored-by: Marco Gaido Signed-off-by: hyukjinkwon --- python/pyspark/ml/tests.py | 7 +++++++ python/pyspark/ml/tuning.py | 6 +++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 3d8883b486e4c..a770bad32ecd2 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -950,6 +950,13 @@ def test_fit_maximize_metric(self): "Best model should have zero induced error") self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") + def test_param_grid_type_coercion(self): + lr = LogisticRegression(maxIter=10) + paramGrid = ParamGridBuilder().addGrid(lr.regParam, [0.5, 1]).build() + for param in paramGrid: + for v in param.values(): + assert(type(v) == float) + def test_save_load_trained_model(self): # This tests saving and loading the trained model only. # Save/load for CrossValidator will be added later: SPARK-13786 diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 0c8029f293cfe..1f4abf5157335 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -115,7 +115,11 @@ def build(self): """ keys = self._param_grid.keys() grid_values = self._param_grid.values() - return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)] + + def to_key_value_pairs(keys, values): + return [(key, key.typeConverter(value)) for key, value in zip(keys, values)] + + return [dict(to_key_value_pairs(keys, prod)) for prod in itertools.product(*grid_values)] class ValidatorParams(HasSeed): From 5d6abad36dc8d8a55dafc04c2022d5c10c1b0ba3 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Mon, 13 Aug 2018 09:14:17 +0800 Subject: [PATCH 1383/2461] [SPARK-25033] Bump Apache commons.{httpclient, httpcore} ## What changes were proposed in this pull request? Bump the versions of Apache commons.{httpclient, httpcore} to make it congruent with Stocator. Changelog httpclient: https://archive.apache.org/dist/httpcomponents/httpclient/RELEASE_NOTES-4.5.x.txt Changelog httpcore: https://archive.apache.org/dist/httpcomponents/httpcore/RELEASE_NOTES.txt ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22007 from Fokko/SPARK-25033. Authored-by: Fokko Driesprong Signed-off-by: hyukjinkwon --- dev/deps/spark-deps-hadoop-2.6 | 4 ++-- dev/deps/spark-deps-hadoop-2.7 | 4 ++-- dev/deps/spark-deps-hadoop-3.1 | 4 ++-- pom.xml | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 54cdcfcaf8aa1..3c0952f36a051 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -86,8 +86,8 @@ hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar hppc-0.7.2.jar htrace-core-3.0.4.jar -httpclient-4.5.4.jar -httpcore-4.4.8.jar +httpclient-4.5.6.jar +httpcore-4.4.10.jar ivy-2.4.0.jar jackson-annotations-2.6.7.jar jackson-core-2.6.7.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 113639946f7d6..310f1e4528374 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -86,8 +86,8 @@ hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar hppc-0.7.2.jar htrace-core-3.1.0-incubating.jar -httpclient-4.5.4.jar -httpcore-4.4.8.jar +httpclient-4.5.6.jar +httpcore-4.4.10.jar ivy-2.4.0.jar jackson-annotations-2.6.7.jar jackson-core-2.6.7.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index fb42adf95db27..9bff2a1013910 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -85,8 +85,8 @@ hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar hppc-0.7.2.jar htrace-core4-4.1.0-incubating.jar -httpclient-4.5.4.jar -httpcore-4.4.8.jar +httpclient-4.5.6.jar +httpcore-4.4.10.jar ivy-2.4.0.jar jackson-annotations-2.6.7.jar jackson-core-2.6.7.jar diff --git a/pom.xml b/pom.xml index b89713fab7291..45fca285ddadb 100644 --- a/pom.xml +++ b/pom.xml @@ -149,8 +149,8 @@ 0.10.2 - 4.5.4 - 4.4.8 + 4.5.6 + 4.4.10 3.1 3.4.1 From a9928277da7f78aab3de17c35ec5f422ef37b644 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 13 Aug 2018 05:59:08 +0000 Subject: [PATCH 1384/2461] [SPARK-24420][BUILD][FOLLOW-UP] Upgrade ASM6 APIs ## What changes were proposed in this pull request? Use ASM 6 APIs after we upgrading it to ASM6. ## How was this patch tested? N/A Closes #22082 from gatorsmile/asm6. Authored-by: Xiao Li Signed-off-by: DB Tsai --- .../org/apache/spark/util/ClosureCleaner.scala | 14 +++++++------- .../apache/spark/graphx/util/BytecodeUtils.scala | 4 ++-- .../apache/spark/repl/ExecutorClassLoader.scala | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index d8c840c356527..b6c300c4778b1 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -425,7 +425,7 @@ private[spark] class ReturnStatementInClosureException extends SparkException("Return statements aren't allowed in Spark closures") private class ReturnStatementFinder(targetMethodName: Option[String] = None) - extends ClassVisitor(ASM5) { + extends ClassVisitor(ASM6) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { @@ -439,7 +439,7 @@ private class ReturnStatementFinder(targetMethodName: Option[String] = None) val isTargetMethod = targetMethodName.isEmpty || name == targetMethodName.get || name == targetMethodName.get.stripSuffix("$adapted") - new MethodVisitor(ASM5) { + new MethodVisitor(ASM6) { override def visitTypeInsn(op: Int, tp: String) { if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl") && isTargetMethod) { throw new ReturnStatementInClosureException @@ -447,7 +447,7 @@ private class ReturnStatementFinder(targetMethodName: Option[String] = None) } } } else { - new MethodVisitor(ASM5) {} + new MethodVisitor(ASM6) {} } } } @@ -471,7 +471,7 @@ private[util] class FieldAccessFinder( findTransitively: Boolean, specificMethod: Option[MethodIdentifier[_]] = None, visitedMethods: Set[MethodIdentifier[_]] = Set.empty) - extends ClassVisitor(ASM5) { + extends ClassVisitor(ASM6) { override def visitMethod( access: Int, @@ -486,7 +486,7 @@ private[util] class FieldAccessFinder( return null } - new MethodVisitor(ASM5) { + new MethodVisitor(ASM6) { override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) { if (op == GETFIELD) { for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) { @@ -526,7 +526,7 @@ private[util] class FieldAccessFinder( } } -private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM5) { +private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM6) { var myName: String = null // TODO: Recursively find inner closures that we indirectly reference, e.g. @@ -541,7 +541,7 @@ private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - new MethodVisitor(ASM5) { + new MethodVisitor(ASM6) { override def visitMethodInsn( op: Int, owner: String, name: String, desc: String, itf: Boolean) { val argTypes = Type.getArgumentTypes(desc) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index a559685b1633c..50b03f71379a1 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -109,14 +109,14 @@ private[graphx] object BytecodeUtils { * determine the actual method invoked by inspecting the bytecode. */ private class MethodInvocationFinder(className: String, methodName: String) - extends ClassVisitor(ASM5) { + extends ClassVisitor(ASM6) { val methodsInvoked = new HashSet[(Class[_], String)] override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { if (name == methodName) { - new MethodVisitor(ASM5) { + new MethodVisitor(ASM6) { override def visitMethodInsn( op: Int, owner: String, name: String, desc: String, itf: Boolean) { if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) { diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 42298b06a2c86..88eb0ad1da3d7 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -187,7 +187,7 @@ class ExecutorClassLoader( } class ConstructorCleaner(className: String, cv: ClassVisitor) -extends ClassVisitor(ASM5, cv) { +extends ClassVisitor(ASM6, cv) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { val mv = cv.visitMethod(access, name, desc, sig, exceptions) From b270bccffffb233331b814e77ae55c1b74bc25d7 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 13 Aug 2018 19:27:17 +0800 Subject: [PATCH 1385/2461] [SPARK-25096][SQL] Loosen nullability if the cast is force-nullable. ## What changes were proposed in this pull request? In type coercion for complex types, if the found type is force-nullable to cast, we should loosen the nullability to be able to cast. Also for map key type, we can't use the type. ## How was this patch tested? Added some test. Closes #22086 from ueshin/issues/SPARK-25096/fix_type_coercion. Authored-by: Takuya UESHIN Signed-off-by: hyukjinkwon --- .../sql/catalyst/analysis/TypeCoercion.scala | 21 ++++++++++++------- .../catalyst/analysis/TypeCoercionSuite.scala | 16 ++++++++++++++ 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 27839d72c6306..10d9ee52facac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -153,19 +153,26 @@ object TypeCoercion { t2: DataType, findTypeFunc: (DataType, DataType) => Option[DataType]): Option[DataType] = (t1, t2) match { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findTypeFunc(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) + findTypeFunc(et1, et2).map { et => + ArrayType(et, containsNull1 || containsNull2 || + Cast.forceNullable(et1, et) || Cast.forceNullable(et2, et)) + } case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) => - findTypeFunc(kt1, kt2).flatMap { kt => - findTypeFunc(vt1, vt2).map { vt => - MapType(kt, vt, valueContainsNull1 || valueContainsNull2) - } + findTypeFunc(kt1, kt2) + .filter { kt => !Cast.forceNullable(kt1, kt) && !Cast.forceNullable(kt2, kt) } + .flatMap { kt => + findTypeFunc(vt1, vt2).map { vt => + MapType(kt, vt, valueContainsNull1 || valueContainsNull2 || + Cast.forceNullable(vt1, vt) || Cast.forceNullable(vt2, vt)) + } } case (StructType(fields1), StructType(fields2)) if fields1.length == fields2.length => val resolver = SQLConf.get.resolver fields1.zip(fields2).foldLeft(Option(new StructType())) { case (Some(struct), (field1, field2)) if resolver(field1.name, field2.name) => - findTypeFunc(field1.dataType, field2.dataType).map { - dt => struct.add(field1.name, dt, field1.nullable || field2.nullable) + findTypeFunc(field1.dataType, field2.dataType).map { dt => + struct.add(field1.name, dt, field1.nullable || field2.nullable || + Cast.forceNullable(field1.dataType, dt) || Cast.forceNullable(field2.dataType, dt)) } case _ => None } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index d71bbb3227134..2c6cb3ae1274a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -499,6 +499,10 @@ class TypeCoercionSuite extends AnalysisTest { ArrayType(new StructType().add("num", ShortType), containsNull = false), ArrayType(new StructType().add("num", LongType), containsNull = false), Some(ArrayType(new StructType().add("num", LongType), containsNull = false))) + widenTestWithStringPromotion( + ArrayType(IntegerType, containsNull = false), + ArrayType(DecimalType.IntDecimal, containsNull = false), + Some(ArrayType(DecimalType.IntDecimal, containsNull = true))) // MapType widenTestWithStringPromotion( @@ -517,6 +521,14 @@ class TypeCoercionSuite extends AnalysisTest { MapType(IntegerType, new StructType().add("num", ShortType), valueContainsNull = false), MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false), Some(MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false))) + widenTestWithStringPromotion( + MapType(StringType, IntegerType, valueContainsNull = false), + MapType(StringType, DecimalType.IntDecimal, valueContainsNull = false), + Some(MapType(StringType, DecimalType.IntDecimal, valueContainsNull = true))) + widenTestWithStringPromotion( + MapType(IntegerType, StringType, valueContainsNull = false), + MapType(DecimalType.IntDecimal, StringType, valueContainsNull = false), + None) // StructType widenTestWithStringPromotion( @@ -540,6 +552,10 @@ class TypeCoercionSuite extends AnalysisTest { .add("map", MapType(DoubleType, StringType, valueContainsNull = false), nullable = false), Some(new StructType() .add("map", MapType(DoubleType, StringType, valueContainsNull = true), nullable = false))) + widenTestWithStringPromotion( + new StructType().add("num", IntegerType, nullable = false), + new StructType().add("num", DecimalType.IntDecimal, nullable = false), + Some(new StructType().add("num", DecimalType.IntDecimal, nullable = true))) widenTestWithStringPromotion( new StructType().add("num", IntegerType), From ab06c25350f8a997bef0c3dd8aa82b709e7dfb3f Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 13 Aug 2018 20:13:09 +0800 Subject: [PATCH 1386/2461] [SPARK-24391][SQL] Support arrays of any types by from_json ## What changes were proposed in this pull request? The PR removes a restriction for element types of array type which exists in `from_json` for the root type. Currently, the function can handle only arrays of structs. Even array of primitive types is disallowed. The PR allows arrays of any types currently supported by JSON datasource. Here is an example of an array of a primitive type: ``` scala> import org.apache.spark.sql.functions._ scala> val df = Seq("[1, 2, 3]").toDF("a") scala> val schema = new ArrayType(IntegerType, false) scala> val arr = df.select(from_json($"a", schema)) scala> arr.printSchema root |-- jsontostructs(a): array (nullable = true) | |-- element: integer (containsNull = true) ``` and result of converting of the json string to the `ArrayType`: ``` scala> arr.show +----------------+ |jsontostructs(a)| +----------------+ | [1, 2, 3]| +----------------+ ``` ## How was this patch tested? I added a few positive and negative tests: - array of primitive types - array of arrays - array of structs - array of maps Closes #21439 from MaxGekk/from_json-array. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- python/pyspark/sql/functions.py | 7 +- .../expressions/jsonExpressions.scala | 19 ++--- .../sql/catalyst/json/JacksonParser.scala | 30 ++++++++ .../org/apache/spark/sql/functions.scala | 10 +-- .../sql-tests/inputs/json-functions.sql | 12 +++ .../sql-tests/results/json-functions.sql.out | 66 +++++++++++++++- .../apache/spark/sql/JsonFunctionsSuite.scala | 76 +++++++++++++++++-- 7 files changed, 194 insertions(+), 26 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index eaecf284b51f1..f5833734103a4 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2241,7 +2241,7 @@ def json_tuple(col, *fields): def from_json(col, schema, options={}): """ Parses a column containing a JSON string into a :class:`MapType` with :class:`StringType` - as keys type, :class:`StructType` or :class:`ArrayType` of :class:`StructType`\\s with + as keys type, :class:`StructType` or :class:`ArrayType` with the specified schema. Returns `null`, in the case of an unparseable string. :param col: string column in json format @@ -2269,6 +2269,11 @@ def from_json(col, schema, options={}): >>> schema = schema_of_json(lit('''{"a": 0}''')) >>> df.select(from_json(df.value, schema).alias("json")).collect() [Row(json=Row(a=1))] + >>> data = [(1, '''[1, 2, 3]''')] + >>> schema = ArrayType(IntegerType()) + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(from_json(df.value, schema).alias("json")).collect() + [Row(json=[1, 2, 3])] """ sc = SparkContext._active_spark_context diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index abe88754f3a1e..ca99100b6d64f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -495,7 +495,7 @@ case class JsonTuple(children: Seq[Expression]) } /** - * Converts an json input string to a [[StructType]] or [[ArrayType]] of [[StructType]]s + * Converts an json input string to a [[StructType]], [[ArrayType]] or [[MapType]] * with the specified schema. */ // scalastyle:off line.size.limit @@ -544,17 +544,10 @@ case class JsonToStructs( timeZoneId = None) override def checkInputDataTypes(): TypeCheckResult = nullableSchema match { - case _: StructType | ArrayType(_: StructType, _) | _: MapType => + case _: StructType | _: ArrayType | _: MapType => super.checkInputDataTypes() case _ => TypeCheckResult.TypeCheckFailure( - s"Input schema ${nullableSchema.catalogString} must be a struct or an array of structs.") - } - - @transient - lazy val rowSchema = nullableSchema match { - case st: StructType => st - case ArrayType(st: StructType, _) => st - case mt: MapType => mt + s"Input schema ${nullableSchema.catalogString} must be a struct, an array or a map.") } // This converts parsed rows to the desired output by the given schema. @@ -562,8 +555,8 @@ case class JsonToStructs( lazy val converter = nullableSchema match { case _: StructType => (rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null - case ArrayType(_: StructType, _) => - (rows: Seq[InternalRow]) => new GenericArrayData(rows) + case _: ArrayType => + (rows: Seq[InternalRow]) => rows.head.getArray(0) case _: MapType => (rows: Seq[InternalRow]) => rows.head.getMap(0) } @@ -571,7 +564,7 @@ case class JsonToStructs( @transient lazy val parser = new JacksonParser( - rowSchema, + nullableSchema, new JSONOptions(options + ("mode" -> FailFastMode.name), timeZoneId.get)) override def dataType: DataType = nullableSchema diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 4d409caddd33d..6feea500b2aa0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -61,6 +61,7 @@ class JacksonParser( dt match { case st: StructType => makeStructRootConverter(st) case mt: MapType => makeMapRootConverter(mt) + case at: ArrayType => makeArrayRootConverter(at) } } @@ -101,6 +102,35 @@ class JacksonParser( } } + private def makeArrayRootConverter(at: ArrayType): JsonParser => Seq[InternalRow] = { + val elemConverter = makeConverter(at.elementType) + (parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, at) { + case START_ARRAY => Seq(InternalRow(convertArray(parser, elemConverter))) + case START_OBJECT if at.elementType.isInstanceOf[StructType] => + // This handles the case when an input JSON object is a structure but + // the specified schema is an array of structures. In that case, the input JSON is + // considered as an array of only one element of struct type. + // This behavior was introduced by changes for SPARK-19595. + // + // For example, if the specified schema is ArrayType(new StructType().add("i", IntegerType)) + // and JSON input as below: + // + // [{"i": 1}, {"i": 2}] + // [{"i": 3}] + // {"i": 4} + // + // The last row is considered as an array with one element, and result of conversion: + // + // Seq(Row(1), Row(2)) + // Seq(Row(3)) + // Seq(Row(4)) + // + val st = at.elementType.asInstanceOf[StructType] + val fieldConverters = st.map(_.dataType).map(makeConverter).toArray + Seq(InternalRow(new GenericArrayData(Seq(convertObject(parser, st, fieldConverters))))) + } + } + /** * Create a converter which converts the JSON documents held by the `JsonParser` * to a value according to a desired schema. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 310e428b69819..5a6ed5964a750 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3339,7 +3339,7 @@ object functions { /** * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` - * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * as keys type, `StructType` or `ArrayType` with the specified schema. * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. @@ -3371,7 +3371,7 @@ object functions { /** * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` - * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * as keys type, `StructType` or `ArrayType` with the specified schema. * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. @@ -3400,7 +3400,7 @@ object functions { /** * Parses a column containing a JSON string into a `MapType` with `StringType` as keys type, - * `StructType` or `ArrayType` of `StructType`s with the specified schema. + * `StructType` or `ArrayType` with the specified schema. * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. @@ -3414,7 +3414,7 @@ object functions { /** * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` - * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * as keys type, `StructType` or `ArrayType` with the specified schema. * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. @@ -3431,7 +3431,7 @@ object functions { /** * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` - * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * as keys type, `StructType` or `ArrayType` with the specified schema. * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql index 79fdd5895e691..0cf370c13e8c0 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -39,3 +39,15 @@ select from_json('{"a":1, "b":"2"}', 'struct'); -- infer schema of json literal select schema_of_json('{"c1":0, "c2":[1]}'); select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}')); + +-- from_json - array type +select from_json('[1, 2, 3]', 'array'); +select from_json('[1, "2", 3]', 'array'); +select from_json('[1, 2, null]', 'array'); + +select from_json('[{"a": 1}, {"a":2}]', 'array>'); +select from_json('{"a": 1}', 'array>'); +select from_json('[null, {"a":2}]', 'array>'); + +select from_json('[{"a": 1}, {"b":2}]', 'array>'); +select from_json('[{"a": 1}, 2]', 'array>'); diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 827931d74138d..b44883b070663 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 30 +-- Number of queries: 38 -- !query 0 @@ -290,3 +290,67 @@ select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}')) struct>> -- !query 29 output {"c1":[1,2,3]} + + +-- !query 30 +select from_json('[1, 2, 3]', 'array') +-- !query 30 schema +struct> +-- !query 30 output +[1,2,3] + + +-- !query 31 +select from_json('[1, "2", 3]', 'array') +-- !query 31 schema +struct> +-- !query 31 output +NULL + + +-- !query 32 +select from_json('[1, 2, null]', 'array') +-- !query 32 schema +struct> +-- !query 32 output +[1,2,null] + + +-- !query 33 +select from_json('[{"a": 1}, {"a":2}]', 'array>') +-- !query 33 schema +struct>> +-- !query 33 output +[{"a":1},{"a":2}] + + +-- !query 34 +select from_json('{"a": 1}', 'array>') +-- !query 34 schema +struct>> +-- !query 34 output +[{"a":1}] + + +-- !query 35 +select from_json('[null, {"a":2}]', 'array>') +-- !query 35 schema +struct>> +-- !query 35 output +[null,{"a":2}] + + +-- !query 36 +select from_json('[{"a": 1}, {"b":2}]', 'array>') +-- !query 36 schema +struct>> +-- !query 36 output +[{"a":1},{"b":2}] + + +-- !query 37 +select from_json('[{"a": 1}, 2]', 'array>') +-- !query 37 schema +struct>> +-- !query 37 output +NULL diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index d3b2701f2558e..f321ab86e9b7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -133,15 +133,11 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row(null) :: Nil) } - test("from_json invalid schema") { + test("from_json - json doesn't conform to the array type") { val df = Seq("""{"a" 1}""").toDS() val schema = ArrayType(StringType) - val message = intercept[AnalysisException] { - df.select(from_json($"value", schema)) - }.getMessage - assert(message.contains( - "Input schema array must be a struct or an array of structs.")) + checkAnswer(df.select(from_json($"value", schema)), Seq(Row(null))) } test("from_json array support") { @@ -405,4 +401,72 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { assert(out.schema == expected) } + + test("from_json - array of primitive types") { + val df = Seq("[1, 2, 3]").toDF("a") + val schema = new ArrayType(IntegerType, false) + + checkAnswer(df.select(from_json($"a", schema)), Seq(Row(Array(1, 2, 3)))) + } + + test("from_json - array of primitive types - malformed row") { + val df = Seq("[1, 2 3]").toDF("a") + val schema = new ArrayType(IntegerType, false) + + checkAnswer(df.select(from_json($"a", schema)), Seq(Row(null))) + } + + test("from_json - array of arrays") { + val jsonDF = Seq("[[1], [2, 3], [4, 5, 6]]").toDF("a") + val schema = new ArrayType(ArrayType(IntegerType, false), false) + jsonDF.select(from_json($"a", schema) as "json").createOrReplaceTempView("jsonTable") + + checkAnswer( + sql("select json[0][0], json[1][1], json[2][2] from jsonTable"), + Seq(Row(1, 3, 6))) + } + + test("from_json - array of arrays - malformed row") { + val jsonDF = Seq("[[1], [2, 3], 4, 5, 6]]").toDF("a") + val schema = new ArrayType(ArrayType(IntegerType, false), false) + jsonDF.select(from_json($"a", schema) as "json").createOrReplaceTempView("jsonTable") + + checkAnswer(sql("select json[0] from jsonTable"), Seq(Row(null))) + } + + test("from_json - array of structs") { + val jsonDF = Seq("""[{"a":1}, {"a":2}, {"a":3}]""").toDF("a") + val schema = new ArrayType(new StructType().add("a", IntegerType), false) + jsonDF.select(from_json($"a", schema) as "json").createOrReplaceTempView("jsonTable") + + checkAnswer( + sql("select json[0], json[1], json[2] from jsonTable"), + Seq(Row(Row(1), Row(2), Row(3)))) + } + + test("from_json - array of structs - malformed row") { + val jsonDF = Seq("""[{"a":1}, {"a:2}, {"a":3}]""").toDF("a") + val schema = new ArrayType(new StructType().add("a", IntegerType), false) + jsonDF.select(from_json($"a", schema) as "json").createOrReplaceTempView("jsonTable") + + checkAnswer(sql("select json[0], json[1]from jsonTable"), Seq(Row(null, null))) + } + + test("from_json - array of maps") { + val jsonDF = Seq("""[{"a":1}, {"b":2}]""").toDF("a") + val schema = new ArrayType(MapType(StringType, IntegerType, false), false) + jsonDF.select(from_json($"a", schema) as "json").createOrReplaceTempView("jsonTable") + + checkAnswer( + sql("""select json[0], json[1] from jsonTable"""), + Seq(Row(Map("a" -> 1), Map("b" -> 2)))) + } + + test("from_json - array of maps - malformed row") { + val jsonDF = Seq("""[{"a":1} "b":2}]""").toDF("a") + val schema = new ArrayType(MapType(StringType, IntegerType, false), false) + jsonDF.select(from_json($"a", schema) as "json").createOrReplaceTempView("jsonTable") + + checkAnswer(sql("""select json[0] from jsonTable"""), Seq(Row(null))) + } } From 26775e3c8ed5bf9028253280b57da64678363f8a Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 13 Aug 2018 20:50:28 +0800 Subject: [PATCH 1387/2461] [SPARK-25099][SQL][TEST] Generate Avro Binary files in test suite ## What changes were proposed in this pull request? In PR https://github.com/apache/spark/pull/21984 and https://github.com/apache/spark/pull/21935 , the related test cases are using binary files created by Python scripts. Generate the binary files in test suite to make it more transparent. Also we can Also move the related test cases to a new file `AvroLogicalTypeSuite.scala`. ## How was this patch tested? Unit test. Closes #22091 from gengliangwang/logicalType_suite. Authored-by: Gengliang Wang Signed-off-by: Wenchen Fan --- external/avro/src/test/resources/date.avro | Bin 209 -> 0 bytes .../avro/src/test/resources/timestamp.avro | Bin 375 -> 0 bytes .../spark/sql/avro/AvroLogicalTypeSuite.scala | 298 ++++++++++++++++++ .../org/apache/spark/sql/avro/AvroSuite.scala | 242 +------------- 4 files changed, 299 insertions(+), 241 deletions(-) delete mode 100644 external/avro/src/test/resources/date.avro delete mode 100644 external/avro/src/test/resources/timestamp.avro create mode 100644 external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala diff --git a/external/avro/src/test/resources/date.avro b/external/avro/src/test/resources/date.avro deleted file mode 100644 index 3a6761704cb169b69e0d8dcf61447e51d90a04c2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 209 zcmeZI%3@>@ODrqO*DFrWNX<>0z*MbNQdy9yWTl`~l$xAhl%k}gpp=)Gn_66um<$%q z$xqKrPRxOcgH)EJ7MFndX_=`xDaAmMXt*iWN>KG7P*YP9OHx5Cx;;@x1E2_6G_LL&BYT}`Ey=nVdJ~SiY^EMS{_DB diff --git a/external/avro/src/test/resources/timestamp.avro b/external/avro/src/test/resources/timestamp.avro deleted file mode 100644 index daef50b78b87494b75c0309bfdbb917f228edc01..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 375 zcmeZI%3@>@ODrqO*DFrWNX<>$!&0qOQdy9yWTl`~l$xAhl%k}gpp=)Gn_66um<$%q z$xqKrPRxOcgH)EJ7MFndX_=`xDaAmMXt*iWN>KG7P*Y1Xfo7E?<`(GYX6EE%7K8M` zY|P2eOINCeS_n26rZ^s|7$`}c(aA;m#2XD(jBGT}(Lk3VIRxUe*jf>ASS9DDq$YFZ ymFDCycpDx~so;CZqcP7bglXq>7Z$Y({0)=7Fn-Wmuq?1){+%7{7v997D*^zXy@3t@ diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala new file mode 100644 index 0000000000000..24d8c53764794 --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.avro + +import java.io.File +import java.sql.Timestamp + +import org.apache.avro.{LogicalTypes, Schema} +import org.apache.avro.Conversions.DecimalConversion +import org.apache.avro.file.DataFileWriter +import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord} + +import org.apache.spark.SparkException +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.types.{StructField, StructType, TimestampType} + +class AvroLogicalTypeSuite extends QueryTest with SharedSQLContext with SQLTestUtils { + import testImplicits._ + + val dateSchema = s""" + { + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [ + {"name": "date", "type": {"type": "int", "logicalType": "date"}} + ] + } + """ + + val dateInputData = Seq(7, 365, 0) + + def dateFile(path: String): String = { + val schema = new Schema.Parser().parse(dateSchema) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + val result = s"$path/test.avro" + dataFileWriter.create(schema, new File(result)) + + dateInputData.foreach { x => + val record = new GenericData.Record(schema) + record.put("date", x) + dataFileWriter.append(record) + } + dataFileWriter.flush() + dataFileWriter.close() + result + } + + test("Logical type: date") { + withTempDir { dir => + val expected = dateInputData.map(t => Row(DateTimeUtils.toJavaDate(t))) + val dateAvro = dateFile(dir.getAbsolutePath) + val df = spark.read.format("avro").load(dateAvro) + + checkAnswer(df, expected) + + checkAnswer(spark.read.format("avro").option("avroSchema", dateSchema).load(dateAvro), + expected) + + withTempPath { path => + df.write.format("avro").save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + + val timestampSchema = s""" + { + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [ + {"name": "timestamp_millis", "type": {"type": "long","logicalType": "timestamp-millis"}}, + {"name": "timestamp_micros", "type": {"type": "long","logicalType": "timestamp-micros"}}, + {"name": "long", "type": "long"} + ] + } + """ + + val timestampInputData = Seq((1000L, 2000L, 3000L), (666000L, 999000L, 777000L)) + + def timestampFile(path: String): String = { + val schema = new Schema.Parser().parse(timestampSchema) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + val result = s"$path/test.avro" + dataFileWriter.create(schema, new File(result)) + + timestampInputData.foreach { t => + val record = new GenericData.Record(schema) + record.put("timestamp_millis", t._1) + // For microsecond precision, we multiple the value by 1000 to match the expected answer as + // timestamp with millisecond precision. + record.put("timestamp_micros", t._2 * 1000) + record.put("long", t._3) + dataFileWriter.append(record) + } + dataFileWriter.flush() + dataFileWriter.close() + result + } + + test("Logical type: timestamp_millis") { + withTempDir { dir => + val expected = timestampInputData.map(t => Row(new Timestamp(t._1))) + val timestampAvro = timestampFile(dir.getAbsolutePath) + val df = spark.read.format("avro").load(timestampAvro).select('timestamp_millis) + + checkAnswer(df, expected) + + withTempPath { path => + df.write.format("avro").save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + + test("Logical type: timestamp_micros") { + withTempDir { dir => + val expected = timestampInputData.map(t => Row(new Timestamp(t._2))) + val timestampAvro = timestampFile(dir.getAbsolutePath) + val df = spark.read.format("avro").load(timestampAvro).select('timestamp_micros) + + checkAnswer(df, expected) + + withTempPath { path => + df.write.format("avro").save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + + test("Logical type: specify different output timestamp types") { + withTempDir { dir => + val timestampAvro = timestampFile(dir.getAbsolutePath) + val df = + spark.read.format("avro").load(timestampAvro).select('timestamp_millis, 'timestamp_micros) + + val expected = timestampInputData.map(t => Row(new Timestamp(t._1), new Timestamp(t._2))) + + Seq("TIMESTAMP_MILLIS", "TIMESTAMP_MICROS").foreach { timestampType => + withSQLConf(SQLConf.AVRO_OUTPUT_TIMESTAMP_TYPE.key -> timestampType) { + withTempPath { path => + df.write.format("avro").save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + } + } + + test("Read Long type as Timestamp") { + withTempDir { dir => + val timestampAvro = timestampFile(dir.getAbsolutePath) + val schema = StructType(StructField("long", TimestampType, true) :: Nil) + val df = spark.read.format("avro").schema(schema).load(timestampAvro).select('long) + + val expected = timestampInputData.map(t => Row(new Timestamp(t._3))) + + checkAnswer(df, expected) + } + } + + test("Logical type: user specified schema") { + withTempDir { dir => + val timestampAvro = timestampFile(dir.getAbsolutePath) + val expected = timestampInputData + .map(t => Row(new Timestamp(t._1), new Timestamp(t._2), t._3)) + + val df = spark.read.format("avro").option("avroSchema", timestampSchema).load(timestampAvro) + checkAnswer(df, expected) + } + } + + val decimalInputData = Seq("1.23", "4.56", "78.90", "-1", "-2.31") + + def decimalSchemaAndFile(path: String): (String, String) = { + val precision = 4 + val scale = 2 + val bytesFieldName = "bytes" + val bytesSchema = s"""{ + "type":"bytes", + "logicalType":"decimal", + "precision":$precision, + "scale":$scale + } + """ + + val fixedFieldName = "fixed" + val fixedSchema = s"""{ + "type":"fixed", + "size":5, + "logicalType":"decimal", + "precision":$precision, + "scale":$scale, + "name":"foo" + } + """ + val avroSchema = s""" + { + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [ + {"name": "$bytesFieldName", "type": $bytesSchema}, + {"name": "$fixedFieldName", "type": $fixedSchema} + ] + } + """ + val schema = new Schema.Parser().parse(avroSchema) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + val decimalConversion = new DecimalConversion + val avroFile = s"$path/test.avro" + dataFileWriter.create(schema, new File(avroFile)) + val logicalType = LogicalTypes.decimal(precision, scale) + + decimalInputData.map { x => + val avroRec = new GenericData.Record(schema) + val decimal = new java.math.BigDecimal(x).setScale(scale) + val bytes = + decimalConversion.toBytes(decimal, schema.getField(bytesFieldName).schema, logicalType) + avroRec.put(bytesFieldName, bytes) + val fixed = + decimalConversion.toFixed(decimal, schema.getField(fixedFieldName).schema, logicalType) + avroRec.put(fixedFieldName, fixed) + dataFileWriter.append(avroRec) + } + dataFileWriter.flush() + dataFileWriter.close() + + (avroSchema, avroFile) + } + + test("Logical type: Decimal") { + withTempDir { dir => + val (avroSchema, avroFile) = decimalSchemaAndFile(dir.getAbsolutePath) + val expected = + decimalInputData.map { x => Row(new java.math.BigDecimal(x), new java.math.BigDecimal(x)) } + val df = spark.read.format("avro").load(avroFile) + checkAnswer(df, expected) + checkAnswer(spark.read.format("avro").option("avroSchema", avroSchema).load(avroFile), + expected) + + withTempPath { path => + df.write.format("avro").save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + + test("Logical type: Decimal with too large precision") { + withTempDir { dir => + val schema = new Schema.Parser().parse("""{ + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [{ + "name": "decimal", + "type": {"type": "bytes", "logicalType": "decimal", "precision": 4, "scale": 2} + }] + }""") + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + val decimal = new java.math.BigDecimal("0.12345678901234567890123456789012345678") + val bytes = (new DecimalConversion).toBytes(decimal, schema, LogicalTypes.decimal(39, 38)) + avroRec.put("decimal", bytes) + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + val msg = intercept[SparkException] { + spark.read.format("avro").load(s"$dir.avro").collect() + }.getCause.getMessage + assert(msg.contains("Unscaled value too large for precision")) + } + } +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 3fa43bf929761..b07b1464ef805 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -25,8 +25,7 @@ import java.util.{TimeZone, UUID} import scala.collection.JavaConverters._ -import org.apache.avro.{LogicalTypes, Schema} -import org.apache.avro.Conversions.DecimalConversion +import org.apache.avro.Schema import org.apache.avro.Schema.{Field, Type} import org.apache.avro.file.{DataFileReader, DataFileWriter} import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord} @@ -35,7 +34,6 @@ import org.apache.commons.io.FileUtils import org.apache.spark.SparkException import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} @@ -47,50 +45,6 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val episodesAvro = testFile("episodes.avro") val testAvro = testFile("test.avro") - // The test file timestamp.avro is generated via following Python code: - // import json - // import avro.schema - // from avro.datafile import DataFileWriter - // from avro.io import DatumWriter - // - // write_schema = avro.schema.parse(json.dumps({ - // "namespace": "logical", - // "type": "record", - // "name": "test", - // "fields": [ - // {"name": "timestamp_millis", "type": {"type": "long","logicalType": "timestamp-millis"}}, - // {"name": "timestamp_micros", "type": {"type": "long","logicalType": "timestamp-micros"}}, - // {"name": "long", "type": "long"} - // ] - // })) - // - // writer = DataFileWriter(open("timestamp.avro", "wb"), DatumWriter(), write_schema) - // writer.append({"timestamp_millis": 1000, "timestamp_micros": 2000000, "long": 3000}) - // writer.append({"timestamp_millis": 666000, "timestamp_micros": 999000000, "long": 777000}) - // writer.close() - val timestampAvro = testFile("timestamp.avro") - - // The test file date.avro is generated via following Python code: - // import json - // import avro.schema - // from avro.datafile import DataFileWriter - // from avro.io import DatumWriter - // - // write_schema = avro.schema.parse(json.dumps({ - // "namespace": "logical", - // "type": "record", - // "name": "test", - // "fields": [ - // {"name": "date", "type": {"type": "int", "logicalType": "date"}} - // ] - // })) - // - // writer = DataFileWriter(open("date.avro", "wb"), DatumWriter(), write_schema) - // writer.append({"date": 7}) - // writer.append({"date": 365}) - // writer.close() - val dateAvro = testFile("date.avro") - override protected def beforeAll(): Unit = { super.beforeAll() spark.conf.set("spark.sql.files.maxPartitionBytes", 1024) @@ -399,200 +353,6 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } } - test("Logical type: date") { - val expected = Seq(7, 365).map(t => Row(DateTimeUtils.toJavaDate(t))) - val df = spark.read.format("avro").load(dateAvro) - - checkAnswer(df, expected) - - val avroSchema = s""" - { - "namespace": "logical", - "type": "record", - "name": "test", - "fields": [ - {"name": "date", "type": {"type": "int", "logicalType": "date"}} - ] - } - """ - - checkAnswer(spark.read.format("avro").option("avroSchema", avroSchema).load(dateAvro), - expected) - - withTempPath { dir => - df.write.format("avro").save(dir.toString) - checkAnswer(spark.read.format("avro").load(dir.toString), expected) - } - } - - test("Logical type: timestamp_millis") { - val expected = Seq(1000L, 666000L).map(t => Row(new Timestamp(t))) - val df = spark.read.format("avro").load(timestampAvro).select('timestamp_millis) - - checkAnswer(df, expected) - - withTempPath { dir => - df.write.format("avro").save(dir.toString) - checkAnswer(spark.read.format("avro").load(dir.toString), expected) - } - } - - test("Logical type: timestamp_micros") { - val expected = Seq(2000L, 999000L).map(t => Row(new Timestamp(t))) - val df = spark.read.format("avro").load(timestampAvro).select('timestamp_micros) - - checkAnswer(df, expected) - - withTempPath { dir => - df.write.format("avro").save(dir.toString) - checkAnswer(spark.read.format("avro").load(dir.toString), expected) - } - } - - test("Logical type: specify different output timestamp types") { - val df = - spark.read.format("avro").load(timestampAvro).select('timestamp_millis, 'timestamp_micros) - - val expected = Seq((1000L, 2000L), (666000L, 999000L)) - .map(t => Row(new Timestamp(t._1), new Timestamp(t._2))) - - Seq("TIMESTAMP_MILLIS", "TIMESTAMP_MICROS").foreach { timestampType => - withSQLConf(SQLConf.AVRO_OUTPUT_TIMESTAMP_TYPE.key -> timestampType) { - withTempPath { dir => - df.write.format("avro").save(dir.toString) - checkAnswer(spark.read.format("avro").load(dir.toString), expected) - } - } - } - } - - test("Read Long type as Timestamp") { - val schema = StructType(StructField("long", TimestampType, true) :: Nil) - val df = spark.read.format("avro").schema(schema).load(timestampAvro).select('long) - - val expected = Seq(3000L, 777000L).map(t => Row(new Timestamp(t))) - - checkAnswer(df, expected) - } - - test("Logical type: user specified schema") { - val expected = Seq((1000L, 2000L, 3000L), (666000L, 999000L, 777000L)) - .map(t => Row(new Timestamp(t._1), new Timestamp(t._2), t._3)) - - val avroSchema = s""" - { - "namespace": "logical", - "type": "record", - "name": "test", - "fields": [ - {"name": "timestamp_millis", "type": {"type": "long","logicalType": "timestamp-millis"}}, - {"name": "timestamp_micros", "type": {"type": "long","logicalType": "timestamp-micros"}}, - {"name": "long", "type": "long"} - ] - } - """ - val df = spark.read.format("avro").option("avroSchema", avroSchema).load(timestampAvro) - checkAnswer(df, expected) - } - - test("Logical type: Decimal") { - val precision = 4 - val scale = 2 - val bytesFieldName = "bytes" - val bytesSchema = s"""{ - "type":"bytes", - "logicalType":"decimal", - "precision":$precision, - "scale":$scale - } - """ - - val fixedFieldName = "fixed" - val fixedSchema = s"""{ - "type":"fixed", - "size":5, - "logicalType":"decimal", - "precision":$precision, - "scale":$scale, - "name":"foo" - } - """ - val avroSchema = s""" - { - "namespace": "logical", - "type": "record", - "name": "test", - "fields": [ - {"name": "$bytesFieldName", "type": $bytesSchema}, - {"name": "$fixedFieldName", "type": $fixedSchema} - ] - } - """ - val schema = new Schema.Parser().parse(avroSchema) - val datumWriter = new GenericDatumWriter[GenericRecord](schema) - val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) - val decimalConversion = new DecimalConversion - withTempDir { dir => - val avroFile = s"$dir.avro" - dataFileWriter.create(schema, new File(avroFile)) - val logicalType = LogicalTypes.decimal(precision, scale) - val data = Seq("1.23", "4.56", "78.90", "-1", "-2.31") - data.map { x => - val avroRec = new GenericData.Record(schema) - val decimal = new java.math.BigDecimal(x).setScale(scale) - val bytes = - decimalConversion.toBytes(decimal, schema.getField(bytesFieldName).schema, logicalType) - avroRec.put(bytesFieldName, bytes) - val fixed = - decimalConversion.toFixed(decimal, schema.getField(fixedFieldName).schema, logicalType) - avroRec.put(fixedFieldName, fixed) - dataFileWriter.append(avroRec) - } - dataFileWriter.flush() - dataFileWriter.close() - - val expected = data.map { x => Row(new java.math.BigDecimal(x), new java.math.BigDecimal(x)) } - val df = spark.read.format("avro").load(avroFile) - checkAnswer(df, expected) - checkAnswer(spark.read.format("avro").option("avroSchema", avroSchema).load(avroFile), - expected) - - withTempPath { path => - df.write.format("avro").save(path.toString) - checkAnswer(spark.read.format("avro").load(path.toString), expected) - } - } - } - - test("Logical type: Decimal with too large precision") { - withTempDir { dir => - val schema = new Schema.Parser().parse("""{ - "namespace": "logical", - "type": "record", - "name": "test", - "fields": [{ - "name": "decimal", - "type": {"type": "bytes", "logicalType": "decimal", "precision": 4, "scale": 2} - }] - }""") - val datumWriter = new GenericDatumWriter[GenericRecord](schema) - val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) - dataFileWriter.create(schema, new File(s"$dir.avro")) - val avroRec = new GenericData.Record(schema) - val decimal = new java.math.BigDecimal("0.12345678901234567890123456789012345678") - val bytes = (new DecimalConversion).toBytes(decimal, schema, LogicalTypes.decimal(39, 38)) - avroRec.put("decimal", bytes) - dataFileWriter.append(avroRec) - dataFileWriter.flush() - dataFileWriter.close() - - val msg = intercept[SparkException] { - spark.read.format("avro").load(s"$dir.avro").collect() - }.getCause.getMessage - assert(msg.contains("Unscaled value too large for precision")) - } - } - test("Array data types") { withTempPath { dir => val testSchema = StructType(Seq( From 2e3abdff23a0725b80992cc30dba2ecf9c2e7fd3 Mon Sep 17 00:00:00 2001 From: Eyal Farago Date: Mon, 13 Aug 2018 20:55:46 +0800 Subject: [PATCH 1388/2461] [SPARK-22713][CORE] ExternalAppendOnlyMap leaks when spilled during iteration ## What changes were proposed in this pull request? This PR solves [SPARK-22713](https://issues.apache.org/jira/browse/SPARK-22713) which describes a memory leak that occurs when and ExternalAppendOnlyMap is spilled during iteration (opposed to insertion). (Please fill in changes proposed in this fix) ExternalAppendOnlyMap's iterator supports spilling but it kept a reference to the internal map (via an internal iterator) after spilling, it seems that the original code was actually supposed to 'get rid' of this reference on the next iteration but according to the elaborate investigation described in the JIRA this didn't happen. the fix was simply replacing the internal iterator immediately after spilling. ## How was this patch tested? I've introduced a new test to test suite ExternalAppendOnlyMapSuite, this test asserts that neither the external map itself nor its iterator hold any reference to the internal map after a spill. These approach required some access relaxation of some members variables and nested classes of ExternalAppendOnlyMap, this members are now package provate and annotated with VisibleForTesting. Closes #21369 from eyalfa/SPARK-22713__ExternalAppendOnlyMap_effective_spill. Authored-by: Eyal Farago Signed-off-by: Wenchen Fan --- .../collection/ExternalAppendOnlyMap.scala | 35 +++--- .../ExternalAppendOnlyMapSuite.scala | 119 +++++++++++++++++- 2 files changed, 138 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index d83da0d126d89..19ff109b673e1 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -80,7 +80,10 @@ class ExternalAppendOnlyMap[K, V, C]( this(createCombiner, mergeValue, mergeCombiners, serializer, blockManager, TaskContext.get()) } - @volatile private var currentMap = new SizeTrackingAppendOnlyMap[K, C] + /** + * Exposed for testing + */ + @volatile private[collection] var currentMap = new SizeTrackingAppendOnlyMap[K, C] private val spilledMaps = new ArrayBuffer[DiskMapIterator] private val sparkConf = SparkEnv.get.conf private val diskBlockManager = blockManager.diskBlockManager @@ -267,7 +270,7 @@ class ExternalAppendOnlyMap[K, V, C]( */ def destructiveIterator(inMemoryIterator: Iterator[(K, C)]): Iterator[(K, C)] = { readingIterator = new SpillableIterator(inMemoryIterator) - readingIterator + readingIterator.toCompletionIterator } /** @@ -280,8 +283,7 @@ class ExternalAppendOnlyMap[K, V, C]( "ExternalAppendOnlyMap.iterator is destructive and should only be called once.") } if (spilledMaps.isEmpty) { - CompletionIterator[(K, C), Iterator[(K, C)]]( - destructiveIterator(currentMap.iterator), freeCurrentMap()) + destructiveIterator(currentMap.iterator) } else { new ExternalIterator() } @@ -305,8 +307,8 @@ class ExternalAppendOnlyMap[K, V, C]( // Input streams are derived both from the in-memory map and spilled maps on disk // The in-memory map is sorted in place, while the spilled maps are already in sorted order - private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](destructiveIterator( - currentMap.destructiveSortedIterator(keyComparator)), freeCurrentMap()) + private val sortedMap = destructiveIterator( + currentMap.destructiveSortedIterator(keyComparator)) private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered) inputStreams.foreach { it => @@ -568,13 +570,11 @@ class ExternalAppendOnlyMap[K, V, C]( context.addTaskCompletionListener[Unit](context => cleanup()) } - private[this] class SpillableIterator(var upstream: Iterator[(K, C)]) + private class SpillableIterator(var upstream: Iterator[(K, C)]) extends Iterator[(K, C)] { private val SPILL_LOCK = new Object() - private var nextUpstream: Iterator[(K, C)] = null - private var cur: (K, C) = readNext() private var hasSpilled: Boolean = false @@ -585,17 +585,24 @@ class ExternalAppendOnlyMap[K, V, C]( } else { logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " + s"it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") - nextUpstream = spillMemoryIteratorToDisk(upstream) + val nextUpstream = spillMemoryIteratorToDisk(upstream) + assert(!upstream.hasNext) hasSpilled = true + upstream = nextUpstream true } } + private def destroy(): Unit = { + freeCurrentMap() + upstream = Iterator.empty + } + + def toCompletionIterator: CompletionIterator[(K, C), SpillableIterator] = { + CompletionIterator[(K, C), SpillableIterator](this, this.destroy) + } + def readNext(): (K, C) = SPILL_LOCK.synchronized { - if (nextUpstream != null) { - upstream = nextUpstream - nextUpstream = null - } if (upstream.hasNext) { upstream.next() } else { diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 35312f2d71131..d542ba0b6640d 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -17,14 +17,24 @@ package org.apache.spark.util.collection +import java.util.Objects + import scala.collection.mutable.ArrayBuffer +import scala.ref.WeakReference + +import org.scalatest.Matchers +import org.scalatest.concurrent.Eventually import org.apache.spark._ import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec import org.apache.spark.memory.MemoryTestingUtils +import org.apache.spark.util.CompletionIterator -class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { +class ExternalAppendOnlyMapSuite extends SparkFunSuite + with LocalSparkContext + with Eventually + with Matchers{ import TestUtils.{assertNotSpilled, assertSpilled} private val allCompressionCodecs = CompressionCodec.ALL_COMPRESSION_CODECS @@ -414,7 +424,112 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } - test("external aggregation updates peak execution memory") { + test("SPARK-22713 spill during iteration leaks internal map") { + val size = 1000 + val conf = createSparkConf(loadDefaults = true) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val map = createExternalMap[Int] + + map.insertAll((0 until size).iterator.map(i => (i / 10, i))) + assert(map.numSpills == 0, "map was not supposed to spill") + + val it = map.iterator + assert(it.isInstanceOf[CompletionIterator[_, _]]) + // org.apache.spark.util.collection.AppendOnlyMap.destructiveSortedIterator returns + // an instance of an annonymous Iterator class. + + val underlyingMapRef = WeakReference(map.currentMap) + + { + // direct asserts introduced some macro generated code that held a reference to the map + val tmpIsNull = null == underlyingMapRef.get.orNull + assert(!tmpIsNull) + } + + val first50Keys = for ( _ <- 0 until 50) yield { + val (k, vs) = it.next + val sortedVs = vs.sorted + assert(sortedVs.seq == (0 until 10).map(10 * k + _)) + k + } + assert(map.numSpills == 0) + map.spill(Long.MaxValue, null) + // these asserts try to show that we're no longer holding references to the underlying map. + // it'd be nice to use something like + // https://github.com/scala/scala/blob/2.13.x/test/junit/scala/tools/testing/AssertUtil.scala + // (lines 69-89) + // assert(map.currentMap == null) + eventually { + System.gc() + // direct asserts introduced some macro generated code that held a reference to the map + val tmpIsNull = null == underlyingMapRef.get.orNull + assert(tmpIsNull) + } + + + val next50Keys = for ( _ <- 0 until 50) yield { + val (k, vs) = it.next + val sortedVs = vs.sorted + assert(sortedVs.seq == (0 until 10).map(10 * k + _)) + k + } + assert(!it.hasNext) + val keys = (first50Keys ++ next50Keys).sorted + assert(keys == (0 until 100)) + } + + test("drop all references to the underlying map once the iterator is exhausted") { + val size = 1000 + val conf = createSparkConf(loadDefaults = true) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val map = createExternalMap[Int] + + map.insertAll((0 until size).iterator.map(i => (i / 10, i))) + assert(map.numSpills == 0, "map was not supposed to spill") + + val underlyingMapRef = WeakReference(map.currentMap) + + { + // direct asserts introduced some macro generated code that held a reference to the map + val tmpIsNull = null == underlyingMapRef.get.orNull + assert(!tmpIsNull) + } + + val it = map.iterator + assert( it.isInstanceOf[CompletionIterator[_, _]]) + + + val keys = it.map{ + case (k, vs) => + val sortedVs = vs.sorted + assert(sortedVs.seq == (0 until 10).map(10 * k + _)) + k + } + .toList + .sorted + + assert(it.isEmpty) + assert(keys == (0 until 100)) + + assert(map.numSpills == 0) + // these asserts try to show that we're no longer holding references to the underlying map. + // it'd be nice to use something like + // https://github.com/scala/scala/blob/2.13.x/test/junit/scala/tools/testing/AssertUtil.scala + // (lines 69-89) + assert(map.currentMap == null) + + eventually { + Thread.sleep(500) + System.gc() + // direct asserts introduced some macro generated code that held a reference to the map + val tmpIsNull = null == underlyingMapRef.get.orNull + assert(tmpIsNull) + } + + assert(it.toList.isEmpty) + } + + test("SPARK-22713 external aggregation updates peak execution memory") { val spillThreshold = 1000 val conf = createSparkConf(loadDefaults = false) .set("spark.shuffle.spill.numElementsForceSpillThreshold", spillThreshold.toString) From b804ca57718ad1568458d8185c8c30118be8275f Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 13 Aug 2018 20:58:29 +0800 Subject: [PATCH 1389/2461] [SPARK-23908][SQL][FOLLOW-UP] Rename inputs to arguments, and add argument type check. ## What changes were proposed in this pull request? This is a follow-up pr of #21954 to address comments. - Rename ambiguous name `inputs` to `arguments`. - Add argument type check and remove hacky workaround. - Address other small comments. ## How was this patch tested? Existing tests and some additional tests. Closes #22075 from ueshin/issues/SPARK-23908/fup1. Authored-by: Takuya UESHIN Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/CheckAnalysis.scala | 14 ++ .../analysis/higherOrderFunctions.scala | 12 +- .../expressions/ExpectsInputTypes.scala | 16 +- .../expressions/higherOrderFunctions.scala | 181 +++++++++--------- .../spark/sql/catalyst/plans/PlanTest.scala | 2 +- .../spark/sql/DataFrameFunctionsSuite.scala | 25 +++ 6 files changed, 152 insertions(+), 98 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 4addc83add3e0..6a91d556b2f3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -90,6 +90,20 @@ trait CheckAnalysis extends PredicateHelper { u.failAnalysis(s"Table or view not found: ${u.tableIdentifier}") case operator: LogicalPlan => + // Check argument data types of higher-order functions downwards first. + // If the arguments of the higher-order functions are resolved but the type check fails, + // the argument functions will not get resolved, but we should report the argument type + // check failure instead of claiming the argument functions are unresolved. + operator transformExpressionsDown { + case hof: HigherOrderFunction + if hof.argumentsResolved && hof.checkArgumentDataTypes().isFailure => + hof.checkArgumentDataTypes() match { + case TypeCheckResult.TypeCheckFailure(message) => + hof.failAnalysis( + s"cannot resolve '${hof.sql}' due to argument data type mismatch: $message") + } + } + operator transformExpressionsUp { case a: Attribute if !a.resolved => val from = operator.inputSet.map(_.qualifiedName).mkString(", ") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala index 5e2029c251ee4..dd08190e1e8a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala @@ -95,15 +95,15 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] { */ private def createLambda( e: Expression, - partialArguments: Seq[(DataType, Boolean)]): LambdaFunction = e match { + argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match { case f: LambdaFunction if f.bound => f case LambdaFunction(function, names, _) => - if (names.size != partialArguments.size) { + if (names.size != argInfo.size) { e.failAnalysis( s"The number of lambda function arguments '${names.size}' does not " + "match the number of arguments expected by the higher order function " + - s"'${partialArguments.size}'.") + s"'${argInfo.size}'.") } if (names.map(a => canonicalizer(a.name)).distinct.size < names.size) { @@ -111,7 +111,7 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] { "Lambda function arguments should not have names that are semantically the same.") } - val arguments = partialArguments.zip(names).map { + val arguments = argInfo.zip(names).map { case ((dataType, nullable), ne) => NamedLambdaVariable(ne.name, dataType, nullable) } @@ -122,7 +122,7 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] { // create a lambda function with default parameters because this is expected by the higher // order function. Note that we hide the lambda variables produced by this function in order // to prevent accidental naming collisions. - val arguments = partialArguments.zipWithIndex.map { + val arguments = argInfo.zipWithIndex.map { case ((dataType, nullable), i) => NamedLambdaVariable(s"col$i", dataType, nullable) } @@ -135,7 +135,7 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] { private def resolve(e: Expression, parentLambdaMap: LambdaVariableMap): Expression = e match { case _ if e.resolved => e - case h: HigherOrderFunction if h.inputResolved => + case h: HigherOrderFunction if h.argumentsResolved && h.checkArgumentDataTypes().isSuccess => h.bind(createLambda).mapChildren(resolve(_, parentLambdaMap)) case l: LambdaFunction if !l.bound => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index d8f046c0028a9..981ce0b6a29fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -41,10 +41,19 @@ trait ExpectsInputTypes extends Expression { def inputTypes: Seq[AbstractDataType] override def checkInputDataTypes(): TypeCheckResult = { - val mismatches = children.zip(inputTypes).zipWithIndex.collect { - case ((child, expected), idx) if !expected.acceptsType(child.dataType) => + ExpectsInputTypes.checkInputDataTypes(children, inputTypes) + } +} + +object ExpectsInputTypes { + + def checkInputDataTypes( + inputs: Seq[Expression], + inputTypes: Seq[AbstractDataType]): TypeCheckResult = { + val mismatches = inputs.zip(inputTypes).zipWithIndex.collect { + case ((input, expected), idx) if !expected.acceptsType(input.dataType) => s"argument ${idx + 1} requires ${expected.simpleString} type, " + - s"however, '${child.sql}' is of ${child.dataType.catalogString} type." + s"however, '${input.sql}' is of ${input.dataType.catalogString} type." } if (mismatches.isEmpty) { @@ -55,7 +64,6 @@ trait ExpectsInputTypes extends Expression { } } - /** * A mixin for the analyzer to perform implicit type casting using * [[org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts]]. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 7f8203ab92213..5d1b8c4da0bda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -35,8 +35,8 @@ case class NamedLambdaVariable( name: String, dataType: DataType, nullable: Boolean, - value: AtomicReference[Any] = new AtomicReference(), - exprId: ExprId = NamedExpression.newExprId) + exprId: ExprId = NamedExpression.newExprId, + value: AtomicReference[Any] = new AtomicReference()) extends LeafExpression with NamedExpression with CodegenFallback { @@ -44,7 +44,7 @@ case class NamedLambdaVariable( override def qualifier: Seq[String] = Seq.empty override def newInstance(): NamedExpression = - copy(value = new AtomicReference(), exprId = NamedExpression.newExprId) + copy(exprId = NamedExpression.newExprId, value = new AtomicReference()) override def toAttribute: Attribute = { AttributeReference(name, dataType, nullable, Metadata.empty)(exprId, Seq.empty) @@ -88,30 +88,45 @@ object LambdaFunction { * A higher order function takes one or more (lambda) functions and applies these to some objects. * The function produces a number of variables which can be consumed by some lambda function. */ -trait HigherOrderFunction extends Expression { +trait HigherOrderFunction extends Expression with ExpectsInputTypes { - override def children: Seq[Expression] = inputs ++ functions + override def children: Seq[Expression] = arguments ++ functions /** - * Inputs to the higher ordered function. + * Arguments of the higher ordered function. */ - def inputs: Seq[Expression] + def arguments: Seq[Expression] + + def argumentTypes: Seq[AbstractDataType] /** - * All inputs have been resolved. This means that the types and nullabilty of (most of) the + * All arguments have been resolved. This means that the types and nullabilty of (most of) the * lambda function arguments is known, and that we can start binding the lambda functions. */ - lazy val inputResolved: Boolean = inputs.forall(_.resolved) + lazy val argumentsResolved: Boolean = arguments.forall(_.resolved) + + /** + * Checks the argument data types, returns `TypeCheckResult.success` if it's valid, + * or returns a `TypeCheckResult` with an error message if invalid. + * Note: it's not valid to call this method until `argumentsResolved == true`. + */ + def checkArgumentDataTypes(): TypeCheckResult = { + ExpectsInputTypes.checkInputDataTypes(arguments, argumentTypes) + } /** * Functions applied by the higher order function. */ def functions: Seq[Expression] + def functionTypes: Seq[AbstractDataType] + + override def inputTypes: Seq[AbstractDataType] = argumentTypes ++ functionTypes + /** * All inputs must be resolved and all functions must be resolved lambda functions. */ - override lazy val resolved: Boolean = inputResolved && functions.forall { + override lazy val resolved: Boolean = argumentsResolved && functions.forall { case l: LambdaFunction => l.resolved case _ => false } @@ -123,6 +138,8 @@ trait HigherOrderFunction extends Expression { */ def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): HigherOrderFunction + // Make sure the lambda variables refer the same instances as of arguments for case that the + // variables in instantiated separately during serialization or for some reason. @transient lazy val functionsForEval: Seq[Expression] = functions.map { case LambdaFunction(function, arguments, hidden) => val argumentMap = arguments.map { arg => arg.exprId -> arg }.toMap @@ -133,51 +150,38 @@ trait HigherOrderFunction extends Expression { } } -object HigherOrderFunction { - - def arrayArgumentType(dt: DataType): (DataType, Boolean) = { - dt match { - case ArrayType(elementType, containsNull) => (elementType, containsNull) - case _ => - val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType - (elementType, containsNull) - } - } - - def mapKeyValueArgumentType(dt: DataType): (DataType, DataType, Boolean) = dt match { - case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull) - case _ => - val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType - (kType, vType, vContainsNull) - } -} - /** * Trait for functions having as input one argument and one function. */ -trait SimpleHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes { +trait SimpleHigherOrderFunction extends HigherOrderFunction { + + def argument: Expression - def input: Expression + override def arguments: Seq[Expression] = argument :: Nil - override def inputs: Seq[Expression] = input :: Nil + def argumentType: AbstractDataType + + override def argumentTypes(): Seq[AbstractDataType] = argumentType :: Nil def function: Expression override def functions: Seq[Expression] = function :: Nil - def expectingFunctionType: AbstractDataType = AnyDataType + def functionType: AbstractDataType = AnyDataType + + override def functionTypes: Seq[AbstractDataType] = functionType :: Nil - @transient lazy val functionForEval: Expression = functionsForEval.head + def functionForEval: Expression = functionsForEval.head /** * Called by [[eval]]. If a subclass keeps the default nullability, it can override this method * in order to save null-check code. */ - protected def nullSafeEval(inputRow: InternalRow, input: Any): Any = + protected def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = sys.error(s"UnaryHigherOrderFunction must override either eval or nullSafeEval") override def eval(inputRow: InternalRow): Any = { - val value = input.eval(inputRow) + val value = argument.eval(inputRow) if (value == null) { null } else { @@ -187,11 +191,11 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction with ExpectsInputTyp } trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, expectingFunctionType) + override def argumentType: AbstractDataType = ArrayType } trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { - override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType) + override def argumentType: AbstractDataType = MapType } /** @@ -209,21 +213,21 @@ trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { """, since = "2.4.0") case class ArrayTransform( - input: Expression, + argument: Expression, function: Expression) extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { - override def nullable: Boolean = input.nullable + override def nullable: Boolean = argument.nullable override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = { - val elem = HigherOrderFunction.arrayArgumentType(input.dataType) + val ArrayType(elementType, containsNull) = argument.dataType function match { case LambdaFunction(_, arguments, _) if arguments.size == 2 => - copy(function = f(function, elem :: (IntegerType, false) :: Nil)) + copy(function = f(function, (elementType, containsNull) :: (IntegerType, false) :: Nil)) case _ => - copy(function = f(function, elem :: Nil)) + copy(function = f(function, (elementType, containsNull) :: Nil)) } } @@ -237,8 +241,8 @@ case class ArrayTransform( (elementVar, indexVar) } - override def nullSafeEval(inputRow: InternalRow, inputValue: Any): Any = { - val arr = inputValue.asInstanceOf[ArrayData] + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val arr = argumentValue.asInstanceOf[ArrayData] val f = functionForEval val result = new GenericArrayData(new Array[Any](arr.numElements)) var i = 0 @@ -268,7 +272,7 @@ examples = """ """, since = "2.4.0") case class MapFilter( - input: Expression, + argument: Expression, function: Expression) extends MapBasedSimpleHigherOrderFunction with CodegenFallback { @@ -277,17 +281,16 @@ case class MapFilter( (args.head.asInstanceOf[NamedLambdaVariable], args.tail.head.asInstanceOf[NamedLambdaVariable]) } - @transient val (keyType, valueType, valueContainsNull) = - HigherOrderFunction.mapKeyValueArgumentType(input.dataType) + @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = { copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) } - override def nullable: Boolean = input.nullable + override def nullable: Boolean = argument.nullable - override def nullSafeEval(inputRow: InternalRow, value: Any): Any = { - val m = value.asInstanceOf[MapData] + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val m = argumentValue.asInstanceOf[MapData] val f = functionForEval val retKeys = new mutable.ListBuffer[Any] val retValues = new mutable.ListBuffer[Any] @@ -302,9 +305,9 @@ case class MapFilter( ArrayBasedMapData(retKeys.toArray, retValues.toArray) } - override def dataType: DataType = input.dataType + override def dataType: DataType = argument.dataType - override def expectingFunctionType: AbstractDataType = BooleanType + override def functionType: AbstractDataType = BooleanType override def prettyName: String = "map_filter" } @@ -321,25 +324,25 @@ case class MapFilter( """, since = "2.4.0") case class ArrayFilter( - input: Expression, + argument: Expression, function: Expression) extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { - override def nullable: Boolean = input.nullable + override def nullable: Boolean = argument.nullable - override def dataType: DataType = input.dataType + override def dataType: DataType = argument.dataType - override def expectingFunctionType: AbstractDataType = BooleanType + override def functionType: AbstractDataType = BooleanType override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = { - val elem = HigherOrderFunction.arrayArgumentType(input.dataType) - copy(function = f(function, elem :: Nil)) + val ArrayType(elementType, containsNull) = argument.dataType + copy(function = f(function, (elementType, containsNull) :: Nil)) } @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function - override def nullSafeEval(inputRow: InternalRow, value: Any): Any = { - val arr = value.asInstanceOf[ArrayData] + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val arr = argumentValue.asInstanceOf[ArrayData] val f = functionForEval val buffer = new mutable.ArrayBuffer[Any](arr.numElements) var i = 0 @@ -368,25 +371,25 @@ case class ArrayFilter( """, since = "2.4.0") case class ArrayExists( - input: Expression, + argument: Expression, function: Expression) extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { - override def nullable: Boolean = input.nullable + override def nullable: Boolean = argument.nullable override def dataType: DataType = BooleanType - override def expectingFunctionType: AbstractDataType = BooleanType + override def functionType: AbstractDataType = BooleanType override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayExists = { - val elem = HigherOrderFunction.arrayArgumentType(input.dataType) - copy(function = f(function, elem :: Nil)) + val ArrayType(elementType, containsNull) = argument.dataType + copy(function = f(function, (elementType, containsNull) :: Nil)) } @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function - override def nullSafeEval(inputRow: InternalRow, value: Any): Any = { - val arr = value.asInstanceOf[ArrayData] + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val arr = argumentValue.asInstanceOf[ArrayData] val f = functionForEval var exists = false var i = 0 @@ -422,45 +425,49 @@ case class ArrayExists( """, since = "2.4.0") case class ArrayAggregate( - input: Expression, + argument: Expression, zero: Expression, merge: Expression, finish: Expression) extends HigherOrderFunction with CodegenFallback { - def this(input: Expression, zero: Expression, merge: Expression) = { - this(input, zero, merge, LambdaFunction.identity) + def this(argument: Expression, zero: Expression, merge: Expression) = { + this(argument, zero, merge, LambdaFunction.identity) } - override def inputs: Seq[Expression] = input :: zero :: Nil + override def arguments: Seq[Expression] = argument :: zero :: Nil + + override def argumentTypes: Seq[AbstractDataType] = ArrayType :: AnyDataType :: Nil override def functions: Seq[Expression] = merge :: finish :: Nil - override def nullable: Boolean = input.nullable || finish.nullable + override def functionTypes: Seq[AbstractDataType] = zero.dataType :: AnyDataType :: Nil + + override def nullable: Boolean = argument.nullable || finish.nullable override def dataType: DataType = finish.dataType override def checkInputDataTypes(): TypeCheckResult = { - if (!ArrayType.acceptsType(input.dataType)) { - TypeCheckResult.TypeCheckFailure( - s"argument 1 requires ${ArrayType.simpleString} type, " + - s"however, '${input.sql}' is of ${input.dataType.catalogString} type.") - } else if (!DataType.equalsStructurally( - zero.dataType, merge.dataType, ignoreNullability = true)) { - TypeCheckResult.TypeCheckFailure( - s"argument 3 requires ${zero.dataType.simpleString} type, " + - s"however, '${merge.sql}' is of ${merge.dataType.catalogString} type.") - } else { - TypeCheckResult.TypeCheckSuccess + checkArgumentDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + if (!DataType.equalsStructurally( + zero.dataType, merge.dataType, ignoreNullability = true)) { + TypeCheckResult.TypeCheckFailure( + s"argument 3 requires ${zero.dataType.simpleString} type, " + + s"however, '${merge.sql}' is of ${merge.dataType.catalogString} type.") + } else { + TypeCheckResult.TypeCheckSuccess + } + case failure => failure } } override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayAggregate = { // Be very conservative with nullable. We cannot be sure that the accumulator does not // evaluate to null. So we always set nullable to true here. - val elem = HigherOrderFunction.arrayArgumentType(input.dataType) + val ArrayType(elementType, containsNull) = argument.dataType val acc = zero.dataType -> true - val newMerge = f(merge, acc :: elem :: Nil) + val newMerge = f(merge, acc :: (elementType, containsNull) :: Nil) val newFinish = f(finish, acc :: Nil) copy(merge = newMerge, finish = newFinish) } @@ -470,7 +477,7 @@ case class ArrayAggregate( @transient lazy val LambdaFunction(_, Seq(accForFinishVar: NamedLambdaVariable), _) = finish override def eval(input: InternalRow): Any = { - val arr = this.input.eval(input).asInstanceOf[ArrayData] + val arr = argument.eval(input).asInstanceOf[ArrayData] if (arr == null) { null } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 9e95b192968c7..67740c3166471 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -81,7 +81,7 @@ trait PlanTestBase extends PredicateHelper { self: Suite => case ae: AggregateExpression => ae.copy(resultId = ExprId(0)) case lv: NamedLambdaVariable => - lv.copy(value = null, exprId = ExprId(0)) + lv.copy(exprId = ExprId(0), value = null) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 2c4238e69ad7c..6401e3fc99783 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1852,6 +1852,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("transform(i, x -> x)") } assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("transform(a, x -> x)") + } + assert(ex3.getMessage.contains("cannot resolve '`a`'")) } test("map_filter") { @@ -1898,6 +1903,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("map_filter(i, (k, v) -> k > v)") } assert(ex3.getMessage.contains("data type mismatch: argument 1 requires map type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("map_filter(a, (k, v) -> k > v)") + } + assert(ex4.getMessage.contains("cannot resolve '`a`'")) } test("filter function - array for primitive type not containing null") { @@ -1994,6 +2004,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("filter(s, x -> x)") } assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("filter(a, x -> x)") + } + assert(ex4.getMessage.contains("cannot resolve '`a`'")) } test("exists function - array for primitive type not containing null") { @@ -2090,6 +2105,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("exists(s, x -> x)") } assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("exists(a, x -> x)") + } + assert(ex4.getMessage.contains("cannot resolve '`a`'")) } test("aggregate function - array for primitive type not containing null") { @@ -2211,6 +2231,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("aggregate(s, 0, (acc, x) -> x)") } assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type")) + + val ex5 = intercept[AnalysisException] { + df.selectExpr("aggregate(a, 0, (acc, x) -> x)") + } + assert(ex5.getMessage.contains("cannot resolve '`a`'")) } private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { From c220cc42abebbc98a6110b50f787eb6d338c2d97 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 14 Aug 2018 00:59:18 +0800 Subject: [PATCH 1390/2461] [SPARK-25028][SQL] Avoid NPE when analyzing partition with NULL values ## What changes were proposed in this pull request? `ANALYZE TABLE ... PARTITION(...) COMPUTE STATISTICS` can fail with a NPE if a partition column contains a NULL value. The PR avoids the NPE, replacing the `NULL` values with the default partition placeholder. ## How was this patch tested? added UT Closes #22036 from mgaido91/SPARK-25028. Authored-by: Marco Gaido Signed-off-by: Wenchen Fan --- .../command/AnalyzePartitionCommand.scala | 10 ++++++++-- .../spark/sql/StatisticsCollectionSuite.scala | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala index 5b54b2270b5ec..18fefa0a6f19f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{AnalysisException, Column, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType, ExternalCatalogUtils} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal} import org.apache.spark.sql.execution.datasources.PartitioningUtils @@ -140,7 +140,13 @@ case class AnalyzePartitionCommand( val df = tableDf.filter(Column(filter)).groupBy(partitionColumns: _*).count() df.collect().map { r => - val partitionColumnValues = partitionColumns.indices.map(r.get(_).toString) + val partitionColumnValues = partitionColumns.indices.map { i => + if (r.isNullAt(i)) { + ExternalCatalogUtils.DEFAULT_PARTITION_NAME + } else { + r.get(i).toString + } + } val spec = tableMeta.partitionColumnNames.zip(partitionColumnValues).toMap val count = BigInt(r.getLong(partitionColumns.size)) (spec, count) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 60fa951e23178..cb562d65b6147 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -204,6 +204,24 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } + test("SPARK-25028: column stats collection for null partitioning columns") { + val table = "analyze_partition_with_null" + withTempDir { dir => + withTable(table) { + sql(s""" + |CREATE TABLE $table (value string, name string) + |USING PARQUET + |PARTITIONED BY (name) + |LOCATION '${dir.toURI}'""".stripMargin) + val df = Seq(("a", null), ("b", null)).toDF("value", "name") + df.write.mode("overwrite").insertInto(table) + sql(s"ANALYZE TABLE $table PARTITION (name) COMPUTE STATISTICS") + val partitions = spark.sessionState.catalog.listPartitions(TableIdentifier(table)) + assert(partitions.head.stats.get.rowCount.get == 2) + } + } + } + test("number format in statistics") { val numbers = Seq( BigInt(0) -> (("0.0 B", "0")), From ab197308a79c74f0a4205a8f60438811b5e0b991 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 14 Aug 2018 04:43:14 +0000 Subject: [PATCH 1391/2461] [SPARK-25104][SQL] Avro: Validate user specified output schema ## What changes were proposed in this pull request? With code changes in https://github.com/apache/spark/pull/21847 , Spark can write out to Avro file as per user provided output schema. To make it more robust and user friendly, we should validate the Avro schema before tasks launched. Also we should support output logical decimal type as BYTES (By default we output as FIXED) ## How was this patch tested? Unit test Closes #22094 from gengliangwang/AvroSerializerMatch. Authored-by: Gengliang Wang Signed-off-by: DB Tsai --- .../spark/sql/avro/AvroSerializer.scala | 108 ++++++++++-------- .../spark/sql/avro/AvroLogicalTypeSuite.scala | 40 +++++++ .../org/apache/spark/sql/avro/AvroSuite.scala | 57 +++++++++ 3 files changed, 158 insertions(+), 47 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala index 3a9544c3f48cd..f551c8360729d 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -26,6 +26,7 @@ import org.apache.avro.Conversions.DecimalConversion import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis} import org.apache.avro.Schema import org.apache.avro.Schema.Type +import org.apache.avro.Schema.Type._ import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record} import org.apache.avro.generic.GenericData.Record import org.apache.avro.util.Utf8 @@ -72,62 +73,70 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: private lazy val decimalConversions = new DecimalConversion() private def newConverter(catalystType: DataType, avroType: Schema): Converter = { - catalystType match { - case NullType => + (catalystType, avroType.getType) match { + case (NullType, NULL) => (getter, ordinal) => null - case BooleanType => + case (BooleanType, BOOLEAN) => (getter, ordinal) => getter.getBoolean(ordinal) - case ByteType => + case (ByteType, INT) => (getter, ordinal) => getter.getByte(ordinal).toInt - case ShortType => + case (ShortType, INT) => (getter, ordinal) => getter.getShort(ordinal).toInt - case IntegerType => + case (IntegerType, INT) => (getter, ordinal) => getter.getInt(ordinal) - case LongType => + case (LongType, LONG) => (getter, ordinal) => getter.getLong(ordinal) - case FloatType => + case (FloatType, FLOAT) => (getter, ordinal) => getter.getFloat(ordinal) - case DoubleType => + case (DoubleType, DOUBLE) => (getter, ordinal) => getter.getDouble(ordinal) - case d: DecimalType => + case (d: DecimalType, FIXED) + if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) => (getter, ordinal) => val decimal = getter.getDecimal(ordinal, d.precision, d.scale) decimalConversions.toFixed(decimal.toJavaBigDecimal, avroType, LogicalTypes.decimal(d.precision, d.scale)) - case StringType => avroType.getType match { - case Type.ENUM => - import scala.collection.JavaConverters._ - val enumSymbols: Set[String] = avroType.getEnumSymbols.asScala.toSet - (getter, ordinal) => - val data = getter.getUTF8String(ordinal).toString - if (!enumSymbols.contains(data)) { - throw new IncompatibleSchemaException( - "Cannot write \"" + data + "\" since it's not defined in enum \"" + - enumSymbols.mkString("\", \"") + "\"") - } - new EnumSymbol(avroType, data) - case _ => - (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes) - } - case BinaryType => avroType.getType match { - case Type.FIXED => - val size = avroType.getFixedSize() - (getter, ordinal) => - val data: Array[Byte] = getter.getBinary(ordinal) - if (data.length != size) { - throw new IncompatibleSchemaException( - s"Cannot write ${data.length} ${if (data.length > 1) "bytes" else "byte"} of " + - "binary data into FIXED Type with size of " + - s"$size ${if (size > 1) "bytes" else "byte"}") - } - new Fixed(avroType, data) - case _ => - (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal)) - } - case DateType => + case (d: DecimalType, BYTES) + if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) => + (getter, ordinal) => + val decimal = getter.getDecimal(ordinal, d.precision, d.scale) + decimalConversions.toBytes(decimal.toJavaBigDecimal, avroType, + LogicalTypes.decimal(d.precision, d.scale)) + + case (StringType, ENUM) => + val enumSymbols: Set[String] = avroType.getEnumSymbols.asScala.toSet + (getter, ordinal) => + val data = getter.getUTF8String(ordinal).toString + if (!enumSymbols.contains(data)) { + throw new IncompatibleSchemaException( + "Cannot write \"" + data + "\" since it's not defined in enum \"" + + enumSymbols.mkString("\", \"") + "\"") + } + new EnumSymbol(avroType, data) + + case (StringType, STRING) => + (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes) + + case (BinaryType, FIXED) => + val size = avroType.getFixedSize() + (getter, ordinal) => + val data: Array[Byte] = getter.getBinary(ordinal) + if (data.length != size) { + throw new IncompatibleSchemaException( + s"Cannot write ${data.length} ${if (data.length > 1) "bytes" else "byte"} of " + + "binary data into FIXED Type with size of " + + s"$size ${if (size > 1) "bytes" else "byte"}") + } + new Fixed(avroType, data) + + case (BinaryType, BYTES) => + (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal)) + + case (DateType, INT) => (getter, ordinal) => getter.getInt(ordinal) - case TimestampType => avroType.getLogicalType match { + + case (TimestampType, LONG) => avroType.getLogicalType match { case _: TimestampMillis => (getter, ordinal) => getter.getLong(ordinal) / 1000 case _: TimestampMicros => (getter, ordinal) => getter.getLong(ordinal) // For backward compatibility, if the Avro type is Long and it is not logical type, @@ -137,7 +146,7 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: s"Cannot convert Catalyst Timestamp type to Avro logical type ${other}") } - case ArrayType(et, containsNull) => + case (ArrayType(et, containsNull), ARRAY) => val elementConverter = newConverter( et, resolveNullableType(avroType.getElementType, containsNull)) (getter, ordinal) => { @@ -158,12 +167,12 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: java.util.Arrays.asList(result: _*) } - case st: StructType => + case (st: StructType, RECORD) => val structConverter = newStructConverter(st, avroType) val numFields = st.length (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields)) - case MapType(kt, vt, valueContainsNull) if kt == StringType => + case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType => val valueConverter = newConverter( vt, resolveNullableType(avroType.getValueType, valueContainsNull)) (getter, ordinal) => @@ -185,12 +194,17 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: result case other => - throw new IncompatibleSchemaException(s"Unexpected type: $other") + throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystType to " + + s"Avro type $avroType.") } } private def newStructConverter( catalystStruct: StructType, avroStruct: Schema): InternalRow => Record = { + if (avroStruct.getType != RECORD) { + throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystStruct to " + + s"Avro type $avroStruct.") + } val avroFields = avroStruct.getFields assert(avroFields.size() == catalystStruct.length) val fieldConverters = catalystStruct.zip(avroFields.asScala).map { @@ -212,7 +226,7 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: } private def resolveNullableType(avroType: Schema, nullable: Boolean): Schema = { - if (nullable) { + if (nullable && avroType.getType != NULL) { // avro uses union to represent nullable type. val fields = avroType.getTypes.asScala assert(fields.length == 2) diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala index 24d8c53764794..ca7eef2a81cf8 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala @@ -267,6 +267,46 @@ class AvroLogicalTypeSuite extends QueryTest with SharedSQLContext with SQLTestU } } + test("Logical type: write Decimal with BYTES type") { + val specifiedSchema = """ + { + "type" : "record", + "name" : "topLevelRecord", + "namespace" : "topLevelRecord", + "fields" : [ { + "name" : "bytes", + "type" : [ { + "type" : "bytes", + "namespace" : "topLevelRecord.bytes", + "logicalType" : "decimal", + "precision" : 4, + "scale" : 2 + }, "null" ] + }, { + "name" : "fixed", + "type" : [ { + "type" : "bytes", + "logicalType" : "decimal", + "precision" : 4, + "scale" : 2 + }, "null" ] + } ] + } + """ + withTempDir { dir => + val (avroSchema, avroFile) = decimalSchemaAndFile(dir.getAbsolutePath) + assert(specifiedSchema != avroSchema) + val expected = + decimalInputData.map { x => Row(new java.math.BigDecimal(x), new java.math.BigDecimal(x)) } + val df = spark.read.format("avro").load(avroFile) + + withTempPath { path => + df.write.format("avro").option("avroSchema", specifiedSchema).save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + test("Logical type: Decimal with too large precision") { withTempDir { dir => val schema = new Schema.Parser().parse("""{ diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index b07b1464ef805..c4f4d8efd6df4 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -27,6 +27,7 @@ import scala.collection.JavaConverters._ import org.apache.avro.Schema import org.apache.avro.Schema.{Field, Type} +import org.apache.avro.Schema.Type._ import org.apache.avro.file.{DataFileReader, DataFileWriter} import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord} import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} @@ -850,6 +851,62 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } } + test("throw exception if unable to write with user provided Avro schema") { + val input: Seq[(DataType, Schema.Type)] = Seq( + (NullType, NULL), + (BooleanType, BOOLEAN), + (ByteType, INT), + (ShortType, INT), + (IntegerType, INT), + (LongType, LONG), + (FloatType, FLOAT), + (DoubleType, DOUBLE), + (BinaryType, BYTES), + (DateType, INT), + (TimestampType, LONG), + (DecimalType(4, 2), BYTES) + ) + def assertException(f: () => AvroSerializer) { + val message = intercept[org.apache.spark.sql.avro.IncompatibleSchemaException] { + f() + }.getMessage + assert(message.contains("Cannot convert Catalyst type")) + } + + def resolveNullable(schema: Schema, nullable: Boolean): Schema = { + if (nullable && schema.getType != NULL) { + Schema.createUnion(schema, Schema.create(NULL)) + } else { + schema + } + } + for { + i <- input + j <- input + nullable <- Seq(true, false) + } if (i._2 != j._2) { + val avroType = resolveNullable(Schema.create(j._2), nullable) + val avroArrayType = resolveNullable(Schema.createArray(avroType), nullable) + val avroMapType = resolveNullable(Schema.createMap(avroType), nullable) + val name = "foo" + val avroField = new Field(name, avroType, "", null) + val recordSchema = Schema.createRecord("name", "doc", "space", true, Seq(avroField).asJava) + val avroRecordType = resolveNullable(recordSchema, nullable) + + val catalystType = i._1 + val catalystArrayType = ArrayType(catalystType, nullable) + val catalystMapType = MapType(StringType, catalystType, nullable) + val catalystStructType = StructType(Seq(StructField(name, catalystType, nullable))) + + for { + avro <- Seq(avroType, avroArrayType, avroMapType, avroRecordType) + catalyst <- Seq(catalystType, catalystArrayType, catalystMapType, catalystStructType) + } { + assertException(() => new AvroSerializer(catalyst, avro, nullable)) + } + } + } + test("reading from invalid path throws exception") { // Directory given has no avro files From 3eb52092b3aa9d7d2fc1e50ac237d47bfb3b9e92 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 14 Aug 2018 05:05:16 +0000 Subject: [PATCH 1392/2461] [SPARK-22974][ML] Attach attributes to output column of CountVectorModel ## What changes were proposed in this pull request? The output column from `CountVectorModel` lacks attribute. So a later transformer like `Interaction` can raise error because no attribute available. ## How was this patch tested? Added test. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #20313 from viirya/SPARK-22974. Authored-by: Liang-Chi Hsieh Signed-off-by: DB Tsai --- .../spark/ml/feature/CountVectorizer.scala | 5 ++++- .../spark/ml/feature/CountVectorizerSuite.scala | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 10c48c3f52085..dc8eb8261dbe2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Since import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} import org.apache.spark.ml.linalg.{Vectors, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} @@ -317,7 +318,9 @@ class CountVectorizerModel( Vectors.sparse(dictBr.value.size, effectiveCounts) } - dataset.withColumn($(outputCol), vectorizer(col($(inputCol)))) + val attrs = vocabulary.map(_ => new NumericAttribute).asInstanceOf[Array[Attribute]] + val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() + dataset.withColumn($(outputCol), vectorizer(col($(inputCol))), metadata) } @Since("1.5.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index 61217669d9277..bca580d411373 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -289,4 +289,20 @@ class CountVectorizerSuite extends MLTest with DefaultReadWriteTest { val newInstance = testDefaultReadWrite(instance) assert(newInstance.vocabulary === instance.vocabulary) } + + test("SPARK-22974: CountVectorModel should attach proper attribute to output column") { + val df = spark.createDataFrame(Seq( + (0, 1.0, Array("a", "b", "c")), + (1, 2.0, Array("a", "b", "b", "c", "a", "d")) + )).toDF("id", "features1", "words") + + val cvm = new CountVectorizerModel(Array("a", "b", "c")) + .setInputCol("words") + .setOutputCol("features2") + + val df1 = cvm.transform(df) + val interaction = new Interaction().setInputCols(Array("features1", "features2")) + .setOutputCol("features") + interaction.transform(df1) + } } From e2ab7deae76d3b6f41b9ad4d0ece14ea28db40ce Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 14 Aug 2018 19:59:39 +0800 Subject: [PATCH 1393/2461] [MINOR][SQL][DOC] Fix `to_json` example in function description and doc ## What changes were proposed in this pull request? This PR fixes the an example for `to_json` in doc and function description. - http://spark.apache.org/docs/2.3.0/api/sql/#to_json - `describe function extended` ## How was this patch tested? Pass the Jenkins with the updated test. Closes #22096 from dongjoon-hyun/minor_json. Authored-by: Dongjoon Hyun Signed-off-by: hyukjinkwon --- .../apache/spark/sql/catalyst/expressions/jsonExpressions.scala | 2 +- .../src/test/resources/sql-tests/results/json-functions.sql.out | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index ca99100b6d64f..11cc88735a9a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -625,7 +625,7 @@ case class JsonToStructs( {"a":1,"b":2} > SELECT _FUNC_(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); {"time":"26/08/2015"} - > SELECT _FUNC_(array(named_struct('a', 1, 'b', 2)); + > SELECT _FUNC_(array(named_struct('a', 1, 'b', 2))); [{"a":1,"b":2}] > SELECT _FUNC_(map('a', named_struct('b', 1))); {"a":{"b":1}} diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index b44883b070663..7444cdbef96e4 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -24,7 +24,7 @@ Extended Usage: {"a":1,"b":2} > SELECT to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); {"time":"26/08/2015"} - > SELECT to_json(array(named_struct('a', 1, 'b', 2)); + > SELECT to_json(array(named_struct('a', 1, 'b', 2))); [{"a":1,"b":2}] > SELECT to_json(map('a', named_struct('b', 1))); {"a":{"b":1}} From 42263fd0cbdc86c68438515ac439a15033b8bbd2 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Tue, 14 Aug 2018 21:14:15 +0900 Subject: [PATCH 1394/2461] [SPARK-23938][SQL] Add map_zip_with function ## What changes were proposed in this pull request? This PR adds a new SQL function called ```map_zip_with```. It merges the two given maps into a single map by applying function to the pair of values with the same key. ## How was this patch tested? Added new tests into: - DataFrameFunctionsSuite.scala - HigherOrderFunctionsSuite.scala Closes #22017 from mn-mikke/SPARK-23938. Authored-by: Marek Novotny Signed-off-by: Takuya UESHIN --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/analysis/TypeCoercion.scala | 25 +++ .../expressions/higherOrderFunctions.scala | 197 +++++++++++++++++- .../HigherOrderFunctionsSuite.scala | 129 ++++++++++++ .../inputs/typeCoercion/native/mapZipWith.sql | 66 ++++++ .../typeCoercion/native/mapZipWith.sql.out | 142 +++++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 64 ++++++ 7 files changed, 621 insertions(+), 3 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 15543c909a271..cc2b758faa43b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -446,6 +446,7 @@ object FunctionRegistry { expression[ArrayFilter]("filter"), expression[ArrayExists]("exists"), expression[ArrayAggregate]("aggregate"), + expression[MapZipWith]("map_zip_with"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 10d9ee52facac..288b6358fbff1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -54,6 +54,7 @@ object TypeCoercion { BooleanEquality :: FunctionArgumentConversion :: ConcatCoercion(conf) :: + MapZipWithCoercion :: EltCoercion(conf) :: CaseWhenCoercion :: IfCoercion :: @@ -762,6 +763,30 @@ object TypeCoercion { } } + /** + * Coerces key types of two different [[MapType]] arguments of the [[MapZipWith]] expression + * to a common type. + */ + object MapZipWithCoercion extends TypeCoercionRule { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + // Lambda function isn't resolved when the rule is executed. + case m @ MapZipWith(left, right, function) if m.arguments.forall(a => a.resolved && + MapType.acceptsType(a.dataType)) && !m.leftKeyType.sameType(m.rightKeyType) => + findWiderTypeForTwo(m.leftKeyType, m.rightKeyType) match { + case Some(finalKeyType) if !Cast.forceNullable(m.leftKeyType, finalKeyType) && + !Cast.forceNullable(m.rightKeyType, finalKeyType) => + val newLeft = castIfNotSameType( + left, + MapType(finalKeyType, m.leftValueType, m.leftValueContainsNull)) + val newRight = castIfNotSameType( + right, + MapType(finalKeyType, m.rightValueType, m.rightValueContainsNull)) + MapZipWith(newLeft, newRight, function) + case _ => m + } + } + } + /** * Coerces the types of [[Elt]] children to expected ones. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 5d1b8c4da0bda..22210f692e755 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -22,11 +22,11 @@ import java.util.concurrent.atomic.AtomicReference import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.array.ByteArrayMethods /** * A named lambda variable. @@ -496,3 +496,194 @@ case class ArrayAggregate( override def prettyName: String = "aggregate" } + +/** + * Merges two given maps into a single map by applying function to the pair of values with + * the same key. + */ +@ExpressionDescription( + usage = + """ + _FUNC_(map1, map2, function) - Merges two given maps into a single map by applying + function to the pair of values with the same key. For keys only presented in one map, + NULL will be passed as the value for the missing key. If an input map contains duplicated + keys, only the first entry of the duplicated key is passed into the lambda function. + """, + examples = """ + Examples: + > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)); + {1:"ax",2:"by"} + """, + since = "2.4.0") +case class MapZipWith(left: Expression, right: Expression, function: Expression) + extends HigherOrderFunction with CodegenFallback { + + def functionForEval: Expression = functionsForEval.head + + @transient lazy val MapType(leftKeyType, leftValueType, leftValueContainsNull) = left.dataType + + @transient lazy val MapType(rightKeyType, rightValueType, rightValueContainsNull) = right.dataType + + @transient lazy val keyType = + TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(leftKeyType, rightKeyType).get + + @transient lazy val ordering = TypeUtils.getInterpretedOrdering(keyType) + + override def arguments: Seq[Expression] = left :: right :: Nil + + override def argumentTypes: Seq[AbstractDataType] = MapType :: MapType :: Nil + + override def functions: Seq[Expression] = function :: Nil + + override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil + + override def nullable: Boolean = left.nullable || right.nullable + + override def dataType: DataType = MapType(keyType, function.dataType, function.nullable) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapZipWith = { + val arguments = Seq((keyType, false), (leftValueType, true), (rightValueType, true)) + copy(function = f(function, arguments)) + } + + override def checkArgumentDataTypes(): TypeCheckResult = { + super.checkArgumentDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + if (leftKeyType.sameType(rightKeyType)) { + TypeUtils.checkForOrderingExpr(leftKeyType, s"function $prettyName") + } else { + TypeCheckResult.TypeCheckFailure(s"The input to function $prettyName should have " + + s"been two ${MapType.simpleString}s with compatible key types, but the key types are " + + s"[${leftKeyType.catalogString}, ${rightKeyType.catalogString}].") + } + case failure => failure + } + } + + override def checkInputDataTypes(): TypeCheckResult = checkArgumentDataTypes() + + override def eval(input: InternalRow): Any = { + val value1 = left.eval(input) + if (value1 == null) { + null + } else { + val value2 = right.eval(input) + if (value2 == null) { + null + } else { + nullSafeEval(input, value1, value2) + } + } + } + + @transient lazy val LambdaFunction(_, Seq( + keyVar: NamedLambdaVariable, + value1Var: NamedLambdaVariable, + value2Var: NamedLambdaVariable), + _) = function + + private def keyTypeSupportsEquals = keyType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } + + /** + * The function accepts two key arrays and returns a collection of keys with indexes + * to value arrays. Indexes are represented as an array of two items. This is a small + * optimization leveraging mutability of arrays. + */ + @transient private lazy val getKeysWithValueIndexes: + (ArrayData, ArrayData) => mutable.Iterable[(Any, Array[Option[Int]])] = { + if (keyTypeSupportsEquals) { + getKeysWithIndexesFast + } else { + getKeysWithIndexesBruteForce + } + } + + private def assertSizeOfArrayBuffer(size: Int): Unit = { + if (size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to zip maps with $size " + + s"unique keys due to exceeding the array size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + } + } + + private def getKeysWithIndexesFast(keys1: ArrayData, keys2: ArrayData) = { + val hashMap = new mutable.LinkedHashMap[Any, Array[Option[Int]]] + for((z, array) <- Array((0, keys1), (1, keys2))) { + var i = 0 + while (i < array.numElements()) { + val key = array.get(i, keyType) + hashMap.get(key) match { + case Some(indexes) => + if (indexes(z).isEmpty) { + indexes(z) = Some(i) + } + case None => + val indexes = Array[Option[Int]](None, None) + indexes(z) = Some(i) + hashMap.put(key, indexes) + } + i += 1 + } + } + hashMap + } + + private def getKeysWithIndexesBruteForce(keys1: ArrayData, keys2: ArrayData) = { + val arrayBuffer = new mutable.ArrayBuffer[(Any, Array[Option[Int]])] + for((z, array) <- Array((0, keys1), (1, keys2))) { + var i = 0 + while (i < array.numElements()) { + val key = array.get(i, keyType) + var found = false + var j = 0 + while (!found && j < arrayBuffer.size) { + val (bufferKey, indexes) = arrayBuffer(j) + if (ordering.equiv(bufferKey, key)) { + found = true + if(indexes(z).isEmpty) { + indexes(z) = Some(i) + } + } + j += 1 + } + if (!found) { + assertSizeOfArrayBuffer(arrayBuffer.size) + val indexes = Array[Option[Int]](None, None) + indexes(z) = Some(i) + arrayBuffer += Tuple2(key, indexes) + } + i += 1 + } + } + arrayBuffer + } + + private def nullSafeEval(inputRow: InternalRow, value1: Any, value2: Any): Any = { + val mapData1 = value1.asInstanceOf[MapData] + val mapData2 = value2.asInstanceOf[MapData] + val keysWithIndexes = getKeysWithValueIndexes(mapData1.keyArray(), mapData2.keyArray()) + val size = keysWithIndexes.size + val keys = new GenericArrayData(new Array[Any](size)) + val values = new GenericArrayData(new Array[Any](size)) + val valueData1 = mapData1.valueArray() + val valueData2 = mapData2.valueArray() + var i = 0 + for ((key, Array(index1, index2)) <- keysWithIndexes) { + val v1 = index1.map(valueData1.get(_, leftValueType)).getOrElse(null) + val v2 = index2.map(valueData2.get(_, rightValueType)).getOrElse(null) + keyVar.value.set(key) + value1Var.value.set(v1) + value2Var.value.set(v2) + keys.update(i, key) + values.update(i, functionForEval.eval(inputRow)) + i += 1 + } + new ArrayBasedMapData(keys, values) + } + + override def prettyName: String = "map_zip_with" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index bc7d04c77fa9e..3137dc9bec49a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -44,6 +44,21 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper LambdaFunction(function, Seq(lv1, lv2)) } + private def createLambda( + dt1: DataType, + nullable1: Boolean, + dt2: DataType, + nullable2: Boolean, + dt3: DataType, + nullable3: Boolean, + f: (Expression, Expression, Expression) => Expression): Expression = { + val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) + val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) + val lv3 = NamedLambdaVariable("arg3", dt3, nullable3) + val function = f(lv1, lv2, lv3) + LambdaFunction(function, Seq(lv1, lv2, lv3)) + } + def transform(expr: Expression, f: Expression => Expression): Expression = { val at = expr.dataType.asInstanceOf[ArrayType] ArrayTransform(expr, createLambda(at.elementType, at.containsNull, f)) @@ -267,4 +282,118 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper (acc, array) => coalesce(aggregate(array, acc, (acc, elem) => acc + elem), acc)), 15) } + + test("MapZipWith") { + def map_zip_with( + left: Expression, + right: Expression, + f: (Expression, Expression, Expression) => Expression): Expression = { + val MapType(kt, vt1, vcn1) = left.dataType.asInstanceOf[MapType] + val MapType(_, vt2, vcn2) = right.dataType.asInstanceOf[MapType] + MapZipWith(left, right, createLambda(kt, false, vt1, vcn1, vt2, vcn2, f)) + } + + val mii0 = Literal.create(Map(1 -> 10, 2 -> 20, 3 -> 30), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mii1 = Literal.create(Map(1 -> -1, 2 -> -2, 4 -> -4), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mii2 = Literal.create(Map(1 -> null, 2 -> -2, 3 -> null), + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val mii3 = Literal.create(Map(), MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mii4 = MapFromArrays( + Literal.create(Seq(2, 2), ArrayType(IntegerType, false)), + Literal.create(Seq(20, 200), ArrayType(IntegerType, false))) + val miin = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) + + val multiplyKeyWithValues: (Expression, Expression, Expression) => Expression = { + (k, v1, v2) => k * v1 * v2 + } + + checkEvaluation( + map_zip_with(mii0, mii1, multiplyKeyWithValues), + Map(1 -> -10, 2 -> -80, 3 -> null, 4 -> null)) + checkEvaluation( + map_zip_with(mii0, mii2, multiplyKeyWithValues), + Map(1 -> null, 2 -> -80, 3 -> null)) + checkEvaluation( + map_zip_with(mii0, mii3, multiplyKeyWithValues), + Map(1 -> null, 2 -> null, 3 -> null)) + checkEvaluation( + map_zip_with(mii0, mii4, multiplyKeyWithValues), + Map(1 -> null, 2 -> 800, 3 -> null)) + checkEvaluation( + map_zip_with(mii4, mii0, multiplyKeyWithValues), + Map(2 -> 800, 1 -> null, 3 -> null)) + checkEvaluation( + map_zip_with(mii0, miin, multiplyKeyWithValues), + null) + + val mss0 = Literal.create(Map("a" -> "x", "b" -> "y", "d" -> "z"), + MapType(StringType, StringType, valueContainsNull = false)) + val mss1 = Literal.create(Map("d" -> "b", "b" -> "d"), + MapType(StringType, StringType, valueContainsNull = false)) + val mss2 = Literal.create(Map("c" -> null, "b" -> "t", "a" -> null), + MapType(StringType, StringType, valueContainsNull = true)) + val mss3 = Literal.create(Map(), MapType(StringType, StringType, valueContainsNull = false)) + val mss4 = MapFromArrays( + Literal.create(Seq("a", "a"), ArrayType(StringType, false)), + Literal.create(Seq("a", "n"), ArrayType(StringType, false))) + val mssn = Literal.create(null, MapType(StringType, StringType, valueContainsNull = false)) + + val concat: (Expression, Expression, Expression) => Expression = { + (k, v1, v2) => Concat(Seq(k, v1, v2)) + } + + checkEvaluation( + map_zip_with(mss0, mss1, concat), + Map("a" -> null, "b" -> "byd", "d" -> "dzb")) + checkEvaluation( + map_zip_with(mss1, mss2, concat), + Map("d" -> null, "b" -> "bdt", "c" -> null, "a" -> null)) + checkEvaluation( + map_zip_with(mss0, mss3, concat), + Map("a" -> null, "b" -> null, "d" -> null)) + checkEvaluation( + map_zip_with(mss0, mss4, concat), + Map("a" -> "axa", "b" -> null, "d" -> null)) + checkEvaluation( + map_zip_with(mss4, mss0, concat), + Map("a" -> "aax", "b" -> null, "d" -> null)) + checkEvaluation( + map_zip_with(mss0, mssn, concat), + null) + + def b(data: Byte*): Array[Byte] = Array[Byte](data: _*) + + val mbb0 = Literal.create(Map(b(1, 2) -> b(4), b(2, 1) -> b(5), b(1, 3) -> b(8)), + MapType(BinaryType, BinaryType, valueContainsNull = false)) + val mbb1 = Literal.create(Map(b(2, 1) -> b(7), b(1, 2) -> b(3), b(1, 1) -> b(6)), + MapType(BinaryType, BinaryType, valueContainsNull = false)) + val mbb2 = Literal.create(Map(b(1, 3) -> null, b(1, 2) -> b(2), b(2, 1) -> null), + MapType(BinaryType, BinaryType, valueContainsNull = true)) + val mbb3 = Literal.create(Map(), MapType(BinaryType, BinaryType, valueContainsNull = false)) + val mbb4 = MapFromArrays( + Literal.create(Seq(b(2, 1), b(2, 1)), ArrayType(BinaryType, false)), + Literal.create(Seq(b(1), b(9)), ArrayType(BinaryType, false))) + val mbbn = Literal.create(null, MapType(BinaryType, BinaryType, valueContainsNull = false)) + + checkEvaluation( + map_zip_with(mbb0, mbb1, concat), + Map(b(1, 2) -> b(1, 2, 4, 3), b(2, 1) -> b(2, 1, 5, 7), b(1, 3) -> null, b(1, 1) -> null)) + checkEvaluation( + map_zip_with(mbb1, mbb2, concat), + Map(b(2, 1) -> null, b(1, 2) -> b(1, 2, 3, 2), b(1, 1) -> null, b(1, 3) -> null)) + checkEvaluation( + map_zip_with(mbb0, mbb3, concat), + Map(b(1, 2) -> null, b(2, 1) -> null, b(1, 3) -> null)) + checkEvaluation( + map_zip_with(mbb0, mbb4, concat), + Map(b(1, 2) -> null, b(2, 1) -> b(2, 1, 5, 1), b(1, 3) -> null)) + checkEvaluation( + map_zip_with(mbb4, mbb0, concat), + Map(b(2, 1) -> b(2, 1, 1, 5), b(1, 2) -> null, b(1, 3) -> null)) + checkEvaluation( + map_zip_with(mbb0, mbbn, concat), + null) + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql new file mode 100644 index 0000000000000..119f868cb48e6 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql @@ -0,0 +1,66 @@ +CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( + map(true, false), + map(2Y, 1Y), + map(2S, 1S), + map(2, 1), + map(2L, 1L), + map(922337203685477897945456575809789456, 922337203685477897945456575809789456), + map(9.22337203685477897945456575809789456, 9.22337203685477897945456575809789456), + map(2.0D, 1.0D), + map(float(2.0), float(1.0)), + map(date '2016-03-14', date '2016-03-13'), + map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + map('true', 'false', '2', '1'), + map('2016-03-14', '2016-03-13'), + map('2016-11-15 20:54:00.000', '2016-11-12 20:54:00.000'), + map('922337203685477897945456575809789456', 'text'), + map(array(1L, 2L), array(1L, 2L)), map(array(1, 2), array(1, 2)), + map(struct(1S, 2L), struct(1S, 2L)), map(struct(1, 2), struct(1, 2)) +) AS various_maps( + boolean_map, + tinyint_map, + smallint_map, + int_map, + bigint_map, + decimal_map1, decimal_map2, + double_map, + float_map, + date_map, + timestamp_map, + string_map1, string_map2, string_map3, string_map4, + array_map1, array_map2, + struct_map1, struct_map2 +); + +SELECT map_zip_with(tinyint_map, smallint_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(smallint_map, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(int_map, bigint_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(double_map, float_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(string_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(string_map2, date_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(timestamp_map, string_map3, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map1, string_map4, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(array_map1, array_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(struct_map1, struct_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out new file mode 100644 index 0000000000000..7f7e2f07b9e74 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out @@ -0,0 +1,142 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 12 + + +-- !query 0 +CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( + map(true, false), + map(2Y, 1Y), + map(2S, 1S), + map(2, 1), + map(2L, 1L), + map(922337203685477897945456575809789456, 922337203685477897945456575809789456), + map(9.22337203685477897945456575809789456, 9.22337203685477897945456575809789456), + map(2.0D, 1.0D), + map(float(2.0), float(1.0)), + map(date '2016-03-14', date '2016-03-13'), + map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + map('true', 'false', '2', '1'), + map('2016-03-14', '2016-03-13'), + map('2016-11-15 20:54:00.000', '2016-11-12 20:54:00.000'), + map('922337203685477897945456575809789456', 'text'), + map(array(1L, 2L), array(1L, 2L)), map(array(1, 2), array(1, 2)), + map(struct(1S, 2L), struct(1S, 2L)), map(struct(1, 2), struct(1, 2)) +) AS various_maps( + boolean_map, + tinyint_map, + smallint_map, + int_map, + bigint_map, + decimal_map1, decimal_map2, + double_map, + float_map, + date_map, + timestamp_map, + string_map1, string_map2, string_map3, string_map4, + array_map1, array_map2, + struct_map1, struct_map2 +) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT map_zip_with(tinyint_map, smallint_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 1 schema +struct>> +-- !query 1 output +{2:{"k":2,"v1":1,"v2":1}} + + +-- !query 2 +SELECT map_zip_with(smallint_map, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 2 schema +struct>> +-- !query 2 output +{2:{"k":2,"v1":1,"v2":1}} + + +-- !query 3 +SELECT map_zip_with(int_map, bigint_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 3 schema +struct>> +-- !query 3 output +{2:{"k":2,"v1":1,"v2":1}} + + +-- !query 4 +SELECT map_zip_with(double_map, float_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 4 schema +struct>> +-- !query 4 output +{2.0:{"k":2.0,"v1":1.0,"v2":1.0}} + + +-- !query 5 +SELECT map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7 + + +-- !query 6 +SELECT map_zip_with(string_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 6 schema +struct>> +-- !query 6 output +{"2":{"k":"2","v1":"1","v2":1},"true":{"k":"true","v1":"false","v2":null}} + + +-- !query 7 +SELECT map_zip_with(string_map2, date_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 7 schema +struct>> +-- !query 7 output +{"2016-03-14":{"k":"2016-03-14","v1":"2016-03-13","v2":2016-03-13}} + + +-- !query 8 +SELECT map_zip_with(timestamp_map, string_map3, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 8 schema +struct>> +-- !query 8 output +{"2016-11-15 20:54:00":{"k":"2016-11-15 20:54:00","v1":2016-11-12 20:54:00.0,"v2":null},"2016-11-15 20:54:00.000":{"k":"2016-11-15 20:54:00.000","v1":null,"v2":"2016-11-12 20:54:00.000"}} + + +-- !query 9 +SELECT map_zip_with(decimal_map1, string_map4, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 9 schema +struct>> +-- !query 9 output +{"922337203685477897945456575809789456":{"k":"922337203685477897945456575809789456","v1":922337203685477897945456575809789456,"v2":"text"}} + + +-- !query 10 +SELECT map_zip_with(array_map1, array_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 10 schema +struct,struct,v1:array,v2:array>>> +-- !query 10 output +{[1,2]:{"k":[1,2],"v1":[1,2],"v2":[1,2]}} + + +-- !query 11 +SELECT map_zip_with(struct_map1, struct_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 11 schema +struct,struct,v1:struct,v2:struct>>> +-- !query 11 output +{{"col1":1,"col2":2}:{"k":{"col1":1,"col2":2},"v1":{"col1":1,"col2":2},"v2":{"col1":1,"col2":2}}} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 6401e3fc99783..8d7695b6ebbcb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2238,6 +2238,70 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex5.getMessage.contains("cannot resolve '`a`'")) } + test("map_zip_with function - map of primitive types") { + val df = Seq( + (Map(8 -> 6L, 3 -> 5L, 6 -> 2L), Map[Integer, Integer]((6, 4), (8, 2), (3, 2))), + (Map(10 -> 6L, 8 -> 3L), Map[Integer, Integer]((8, 4), (4, null))), + (Map.empty[Int, Long], Map[Integer, Integer]((5, 1))), + (Map(5 -> 1L), null) + ).toDF("m1", "m2") + + checkAnswer(df.selectExpr("map_zip_with(m1, m2, (k, v1, v2) -> k == v1 + v2)"), + Seq( + Row(Map(8 -> true, 3 -> false, 6 -> true)), + Row(Map(10 -> null, 8 -> false, 4 -> null)), + Row(Map(5 -> null)), + Row(null))) + } + + test("map_zip_with function - map of non-primitive types") { + val df = Seq( + (Map("z" -> "a", "y" -> "b", "x" -> "c"), Map("x" -> "a", "z" -> "c")), + (Map("b" -> "a", "c" -> "d"), Map("c" -> "a", "b" -> null, "d" -> "k")), + (Map("a" -> "d"), Map.empty[String, String]), + (Map("a" -> "d"), null) + ).toDF("m1", "m2") + + checkAnswer(df.selectExpr("map_zip_with(m1, m2, (k, v1, v2) -> (v1, v2))"), + Seq( + Row(Map("z" -> Row("a", "c"), "y" -> Row("b", null), "x" -> Row("c", "a"))), + Row(Map("b" -> Row("a", null), "c" -> Row("d", "a"), "d" -> Row(null, "k"))), + Row(Map("a" -> Row("d", null))), + Row(null))) + } + + test("map_zip_with function - invalid") { + val df = Seq( + (Map(1 -> 2), Map(1 -> "a"), Map("a" -> "b"), Map(Map(1 -> 2) -> 2), 1) + ).toDF("mii", "mis", "mss", "mmi", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mii, mis, (x, y) -> x + y)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match")) + + val ex2 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mis, mmi, (x, y, z) -> concat(x, y, z))") + } + assert(ex2.getMessage.contains("The input to function map_zip_with should have " + + "been two maps with compatible key types")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(i, mis, (x, y, z) -> concat(x, y, z))") + } + assert(ex3.getMessage.contains("type mismatch: argument 1 requires map type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mis, i, (x, y, z) -> concat(x, y, z))") + } + assert(ex4.getMessage.contains("type mismatch: argument 2 requires map type")) + + val ex5 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mmi, mmi, (x, y, z) -> x)") + } + assert(ex5.getMessage.contains("function map_zip_with does not support ordering on type map")) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 80784a1de8d02536a94f3fd08ef632777478ab14 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 14 Aug 2018 09:57:01 -0700 Subject: [PATCH 1395/2461] [SPARK-18057][FOLLOW-UP] Use 127.0.0.1 to avoid zookeeper picking up an ipv6 address ## What changes were proposed in this pull request? I'm still seeing the Kafka tests failed randomly due to `kafka.zookeeper.ZooKeeperClientTimeoutException: Timed out waiting for connection while in state: CONNECTING`. I checked the test output and saw zookeeper picked up an ipv6 address. Most details can be found in https://issues.apache.org/jira/browse/KAFKA-7193 This PR just uses `127.0.0.1` rather than `localhost` to make sure zookeeper will never use an ipv6 address. ## How was this patch tested? Jenkins Closes #22097 from zsxwing/fix-zookeeper-connect. Authored-by: Shixiong Zhu Signed-off-by: Shixiong Zhu --- .../spark/sql/kafka010/KafkaTestUtils.scala | 80 +++++++++++-------- .../streaming/kafka010/KafkaTestUtils.scala | 79 +++++++++++------- 2 files changed, 96 insertions(+), 63 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index d89cccd3c5215..e58d18361966f 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -39,6 +39,7 @@ import org.apache.kafka.clients.producer._ import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.network.ListenerName import org.apache.kafka.common.serialization.{StringDeserializer, StringSerializer} +import org.apache.kafka.common.utils.Exit import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ @@ -56,7 +57,7 @@ import org.apache.spark.util.Utils class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends Logging { // Zookeeper related configurations - private val zkHost = "localhost" + private val zkHost = "127.0.0.1" private var zkPort: Int = 0 private val zkConnectionTimeout = 60000 private val zkSessionTimeout = 6000 @@ -67,7 +68,7 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L private var adminClient: AdminClient = null // Kafka broker related configurations - private val brokerHost = "localhost" + private val brokerHost = "127.0.0.1" private var brokerPort = 0 private var brokerConf: KafkaConfig = _ @@ -138,40 +139,55 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L /** Teardown the whole servers, including Kafka broker and Zookeeper */ def teardown(): Unit = { - brokerReady = false - zkReady = false - - if (producer != null) { - producer.close() - producer = null + // There is a race condition that may kill JVM when terminating the Kafka cluster. We set + // a custom Procedure here during the termination in order to keep JVM running and not fail the + // tests. + val logExitEvent = new Exit.Procedure { + override def execute(statusCode: Int, message: String): Unit = { + logError(s"Prevent Kafka from killing JVM (statusCode: $statusCode message: $message)") + } } + Exit.setExitProcedure(logExitEvent) + Exit.setHaltProcedure(logExitEvent) + try { + brokerReady = false + zkReady = false - if (server != null) { - server.shutdown() - server.awaitShutdown() - server = null - } + if (producer != null) { + producer.close() + producer = null + } - // On Windows, `logDirs` is left open even after Kafka server above is completely shut down - // in some cases. It leads to test failures on Windows if the directory deletion failure - // throws an exception. - brokerConf.logDirs.foreach { f => - try { - Utils.deleteRecursively(new File(f)) - } catch { - case e: IOException if Utils.isWindows => - logWarning(e.getMessage) + if (server != null) { + server.shutdown() + server.awaitShutdown() + server = null } - } - if (zkUtils != null) { - zkUtils.close() - zkUtils = null - } + // On Windows, `logDirs` is left open even after Kafka server above is completely shut down + // in some cases. It leads to test failures on Windows if the directory deletion failure + // throws an exception. + brokerConf.logDirs.foreach { f => + try { + Utils.deleteRecursively(new File(f)) + } catch { + case e: IOException if Utils.isWindows => + logWarning(e.getMessage) + } + } - if (zookeeper != null) { - zookeeper.shutdown() - zookeeper = null + if (zkUtils != null) { + zkUtils.close() + zkUtils = null + } + + if (zookeeper != null) { + zookeeper.shutdown() + zookeeper = null + } + } finally { + Exit.resetExitProcedure() + Exit.resetHaltProcedure() } } @@ -299,8 +315,8 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L protected def brokerConfiguration: Properties = { val props = new Properties() props.put("broker.id", "0") - props.put("host.name", "localhost") - props.put("advertised.host.name", "localhost") + props.put("host.name", "127.0.0.1") + props.put("advertised.host.name", "127.0.0.1") props.put("port", brokerPort.toString) props.put("log.dir", Utils.createTempDir().getAbsolutePath) props.put("zookeeper.connect", zkAddress) diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala index eef4c55d27d51..bd3cf9abddb5b 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala @@ -34,6 +34,7 @@ import kafka.utils.ZkUtils import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord} import org.apache.kafka.common.network.ListenerName import org.apache.kafka.common.serialization.StringSerializer +import org.apache.kafka.common.utils.Exit import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} import org.apache.spark.SparkConf @@ -50,7 +51,7 @@ import org.apache.spark.util.Utils private[kafka010] class KafkaTestUtils extends Logging { // Zookeeper related configurations - private val zkHost = "localhost" + private val zkHost = "127.0.0.1" private var zkPort: Int = 0 private val zkConnectionTimeout = 60000 private val zkSessionTimeout = 6000 @@ -60,7 +61,7 @@ private[kafka010] class KafkaTestUtils extends Logging { private var zkUtils: ZkUtils = _ // Kafka broker related configurations - private val brokerHost = "localhost" + private val brokerHost = "127.0.0.1" private var brokerPort = 0 private var brokerConf: KafkaConfig = _ @@ -125,40 +126,55 @@ private[kafka010] class KafkaTestUtils extends Logging { /** Teardown the whole servers, including Kafka broker and Zookeeper */ def teardown(): Unit = { - brokerReady = false - zkReady = false - - if (producer != null) { - producer.close() - producer = null + // There is a race condition that may kill JVM when terminating the Kafka cluster. We set + // a custom Procedure here during the termination in order to keep JVM running and not fail the + // tests. + val logExitEvent = new Exit.Procedure { + override def execute(statusCode: Int, message: String): Unit = { + logError(s"Prevent Kafka from killing JVM (statusCode: $statusCode message: $message)") + } } + Exit.setExitProcedure(logExitEvent) + Exit.setHaltProcedure(logExitEvent) + try { + brokerReady = false + zkReady = false + + if (producer != null) { + producer.close() + producer = null + } - if (server != null) { - server.shutdown() - server.awaitShutdown() - server = null - } + if (server != null) { + server.shutdown() + server.awaitShutdown() + server = null + } - // On Windows, `logDirs` is left open even after Kafka server above is completely shut down - // in some cases. It leads to test failures on Windows if the directory deletion failure - // throws an exception. - brokerConf.logDirs.foreach { f => - try { - Utils.deleteRecursively(new File(f)) - } catch { - case e: IOException if Utils.isWindows => - logWarning(e.getMessage) + // On Windows, `logDirs` is left open even after Kafka server above is completely shut down + // in some cases. It leads to test failures on Windows if the directory deletion failure + // throws an exception. + brokerConf.logDirs.foreach { f => + try { + Utils.deleteRecursively(new File(f)) + } catch { + case e: IOException if Utils.isWindows => + logWarning(e.getMessage) + } } - } - if (zkUtils != null) { - zkUtils.close() - zkUtils = null - } + if (zkUtils != null) { + zkUtils.close() + zkUtils = null + } - if (zookeeper != null) { - zookeeper.shutdown() - zookeeper = null + if (zookeeper != null) { + zookeeper.shutdown() + zookeeper = null + } + } finally { + Exit.resetExitProcedure() + Exit.resetHaltProcedure() } } @@ -217,7 +233,8 @@ private[kafka010] class KafkaTestUtils extends Logging { private def brokerConfiguration: Properties = { val props = new Properties() props.put("broker.id", "0") - props.put("host.name", "localhost") + props.put("host.name", "127.0.0.1") + props.put("advertised.host.name", "127.0.0.1") props.put("port", brokerPort.toString) props.put("log.dir", brokerLogDir) props.put("zookeeper.connect", zkAddress) From 10248758438b9ff57f5669a324a716c8c6c8f17b Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 14 Aug 2018 13:02:33 -0500 Subject: [PATCH 1396/2461] [SPARK-25088][CORE][MESOS][DOCS] Update Rest Server docs & defaults. ## What changes were proposed in this pull request? (a) disabled rest submission server by default in standalone mode (b) fails the standalone master if rest server enabled & authentication secret set (c) fails the mesos cluster dispatcher if authentication secret set (d) doc updates (e) when submitting a standalone app, only try the rest submission first if spark.master.rest.enabled=true otherwise you'd see a 10 second pause like 18/08/09 08:13:22 INFO RestSubmissionClient: Submitting a request to launch an application in spark://... 18/08/09 08:13:33 WARN RestSubmissionClient: Unable to connect to server spark://... I also made sure the mesos cluster dispatcher failed with the secret enabled, though I had to do that on slightly different code as I don't have mesos native libs around. ## How was this patch tested? I ran the tests in the mesos module & in core for org.apache.spark.deploy.* I ran a test on a cluster with standalone master to make sure I could still start with the right configs, and would fail the right way too. Closes #22071 from squito/rest_doc_updates. Authored-by: Imran Rashid Signed-off-by: Sean Owen --- .../org/apache/spark/deploy/SparkSubmitArguments.scala | 4 +++- .../scala/org/apache/spark/deploy/master/Master.scala | 10 +++++++++- .../spark/deploy/rest/RestSubmissionServer.scala | 1 + docs/running-on-mesos.md | 2 ++ docs/security.md | 7 ++++++- .../spark/deploy/mesos/MesosClusterDispatcher.scala | 8 ++++++++ 6 files changed, 29 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index fb232101114b9..0998757715457 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -82,7 +82,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var driverCores: String = null var submissionToKill: String = null var submissionToRequestStatusFor: String = null - var useRest: Boolean = true // used internally + var useRest: Boolean = false // used internally /** Default properties present in the currently defined defaults file. */ lazy val defaultSparkProperties: HashMap[String, String] = { @@ -115,6 +115,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S // Use `sparkProperties` map along with env vars to fill in any missing parameters loadEnvironmentArguments() + useRest = sparkProperties.getOrElse("spark.master.rest.enabled", "false").toBoolean + validateArguments() /** diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 2c78c15773af2..e1184248af460 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -121,10 +121,18 @@ private[deploy] class Master( } // Alternative application submission gateway that is stable across Spark versions - private val restServerEnabled = conf.getBoolean("spark.master.rest.enabled", true) + private val restServerEnabled = conf.getBoolean("spark.master.rest.enabled", false) private var restServer: Option[StandaloneRestServer] = None private var restServerBoundPort: Option[Int] = None + { + val authKey = SecurityManager.SPARK_AUTH_SECRET_CONF + require(conf.getOption(authKey).isEmpty || !restServerEnabled, + s"The RestSubmissionServer does not support authentication via ${authKey}. Either turn " + + "off the RestSubmissionServer with spark.master.rest.enabled=false, or do not use " + + "authentication.") + } + override def onStart(): Unit = { logInfo("Starting Spark master at " + masterUrl) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala index 3d99d085408c6..e59bf3f0eaf44 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -51,6 +51,7 @@ private[spark] abstract class RestSubmissionServer( val host: String, val requestedPort: Int, val masterConf: SparkConf) extends Logging { + protected val submitRequestServlet: SubmitRequestServlet protected val killRequestServlet: KillRequestServlet protected val statusRequestServlet: StatusRequestServlet diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 66ffb17949845..3e76d47608c74 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -174,6 +174,8 @@ can find the results of the driver from the Mesos Web UI. To use cluster mode, you must start the `MesosClusterDispatcher` in your cluster via the `sbin/start-mesos-dispatcher.sh` script, passing in the Mesos master URL (e.g: mesos://host:5050). This starts the `MesosClusterDispatcher` as a daemon running on the host. +Note that the `MesosClusterDispatcher` does not support authentication. You should ensure that all network access to it is +protected (port 7077 by default). By setting the Mesos proxy config property (requires mesos version >= 1.4), `--conf spark.mesos.proxy.baseURL=http://localhost:5050` when launching the dispatcher, the mesos sandbox URI for each driver is added to the mesos dispatcher UI. diff --git a/docs/security.md b/docs/security.md index 1de1d6318939a..c8eec730889c1 100644 --- a/docs/security.md +++ b/docs/security.md @@ -22,7 +22,12 @@ secrets to be secure. For other resource managers, `spark.authenticate.secret` must be configured on each of the nodes. This secret will be shared by all the daemons and applications, so this deployment configuration is -not as secure as the above, especially when considering multi-tenant clusters. +not as secure as the above, especially when considering multi-tenant clusters. In this +configuration, a user with the secret can effectively impersonate any other user. + +The Rest Submission Server and the MesosClusterDispatcher do not support authentication. You should +ensure that all network access to the REST API & MesosClusterDispatcher (port 6066 and 7077 +respectively by default) are restricted to hosts that are trusted to submit jobs.
      Property NameDefaultMeaning
      spark.history.ui.acls.enablespark.history.ui.acls.enable false Specifies whether ACLs should be checked to authorize users viewing the applications in @@ -292,7 +292,7 @@ To enable authorization in the SHS, a few extra options are used:
      spark.history.ui.admin.aclsspark.history.ui.admin.acls None Comma separated list of users that have view access to all the Spark applications in history @@ -300,7 +300,7 @@ To enable authorization in the SHS, a few extra options are used:
      spark.history.ui.admin.acls.groupsspark.history.ui.admin.acls.groups None Comma separated list of groups that have view access to all the Spark applications in history @@ -501,6 +501,7 @@ can be accomplished by setting `spark.ssl.useNodeLocalConf` to `true`. In that c provided by the user on the client side are not used. ### Mesos mode + Mesos 1.3.0 and newer supports `Secrets` primitives as both file-based and environment based secrets. Spark allows the specification of file-based and environment variable based secrets with `spark.mesos.driver.secret.filenames` and `spark.mesos.driver.secret.envkeys`, respectively. @@ -562,8 +563,12 @@ Security. # Configuring Ports for Network Security -Spark makes heavy use of the network, and some environments have strict requirements for using tight -firewall settings. Below are the primary ports that Spark uses for its communication and how to +Generally speaking, a Spark cluster and its services are not deployed on the public internet. +They are generally private services, and should only be accessible within the network of the +organization that deploys Spark. Access to the hosts and ports used by Spark services should +be limited to origin hosts that need to access the services. + +Below are the primary ports that Spark uses for its communication and how to configure those ports. ## Standalone mode only @@ -597,6 +602,14 @@ configure those ports. SPARK_MASTER_PORT Set to "0" to choose a port randomly. Standalone mode only.
      External ServiceStandalone Master6066Submit job to cluster via REST APIspark.master.rest.portUse spark.master.rest.enabled to enable/disable this service. Standalone mode only.
      Standalone Master Standalone Worker
      diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index ccf33e8d4283c..64698b55c6bb6 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -51,6 +51,14 @@ private[mesos] class MesosClusterDispatcher( conf: SparkConf) extends Logging { + { + // This doesn't support authentication because the RestSubmissionServer doesn't support it. + val authKey = SecurityManager.SPARK_AUTH_SECRET_CONF + require(conf.getOption(authKey).isEmpty, + s"The MesosClusterDispatcher does not support authentication via ${authKey}. It is not " + + s"currently possible to run jobs in cluster mode with authentication on.") + } + private val publicAddress = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(args.host) private val recoveryMode = conf.get(RECOVERY_MODE).toUpperCase() logInfo("Recovery mode in Mesos dispatcher set to: " + recoveryMode) From b81e3031fd247dfb4b3e02e0a986fb4b19d00f7c Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Tue, 14 Aug 2018 13:15:55 -0500 Subject: [PATCH 1397/2461] [SPARK-25043] print master and appId from spark-sql on startup ## What changes were proposed in this pull request? A small change to print the master and appId from spark-sql as with logging turned down all the way (`log4j.logger.org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver=WARN`), we may not know this information easily. This adds the following string before the `spark-sql>` prompt shows on the screen. `Spark master: yarn, Application Id: application_123456789_12345` ## How was this patch tested? I ran spark-sql locally and saw the appId displayed as expected. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22025 from abellina/SPARK-25043_print_master_and_app_id_from_sparksql. Lead-authored-by: Alessandro Bellina Co-authored-by: Alessandro Bellina Signed-off-by: Thomas Graves --- .../spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index d9fd3ebd3c65d..bb96cea2b0ae1 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -258,6 +258,8 @@ private[hive] object SparkSQLCLIDriver extends Logging { def continuedPromptWithDBSpaces: String = continuedPrompt + ReflectionUtils.invokeStatic( classOf[CliDriver], "spacesForString", classOf[String] -> currentDB) + cli.printMasterAndAppId + var currentPrompt = promptWithCurrentDB var line = reader.readLine(currentPrompt + "> ") @@ -323,6 +325,12 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { hiveVariables.asScala.foreach(kv => SparkSQLEnv.sqlContext.conf.setConfString(kv._1, kv._2)) } + def printMasterAndAppId(): Unit = { + val master = SparkSQLEnv.sparkContext.master + val appId = SparkSQLEnv.sparkContext.applicationId + console.printInfo(s"Spark master: $master, Application Id: $appId") + } + override def processCmd(cmd: String): Int = { val cmd_trimmed: String = cmd.trim() val cmd_lower = cmd_trimmed.toLowerCase(Locale.ROOT) From 3c614d0565a9652a12970dcdf8545432a4ac6f68 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Tue, 14 Aug 2018 16:40:00 -0700 Subject: [PATCH 1398/2461] [SPARK-25113][SQL] Add logging to CodeGenerator when any generated method's bytecode size goes above HugeMethodLimit ## What changes were proposed in this pull request? Add logging for all generated methods from the `CodeGenerator` whose bytecode size goes above 8000 bytes. This is to help with gathering stats on how often Spark is generating methods too big to be JIT'd. It covers all codegen scenarios, include whole-stage codegen and also individual expression codegen, e.g. unsafe projection, mutable projection, etc. ## How was this patch tested? Manually tested that logging did happen when generated method was above 8000 bytes. Also added a new unit test case to `CodeGenerationSuite` to verify that the logging did happen. Author: Kris Mok Closes #22103 from rednaxelafx/codegen-8k-logging. --- .../expressions/codegen/CodeGenerator.scala | 11 +++- .../expressions/CodeGenerationSuite.scala | 64 +++++++++++++++++++ 2 files changed, 73 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4b30de5aeb7ab..e2d74daca56ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1258,7 +1258,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin object CodeGenerator extends Logging { - // This is the value of HugeMethodLimit in the OpenJDK JVM settings + // This is the default value of HugeMethodLimit in the OpenJDK HotSpot JVM, + // beyond which methods will be rejected from JIT compilation final val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000 // The max valid length of method parameters in JVM. @@ -1385,9 +1386,15 @@ object CodeGenerator extends Logging { try { val cf = new ClassFile(new ByteArrayInputStream(classBytes)) val stats = cf.methodInfos.asScala.flatMap { method => - method.getAttributes().filter(_.getClass.getName == codeAttr.getName).map { a => + method.getAttributes().filter(_.getClass eq codeAttr).map { a => val byteCodeSize = codeAttrField.get(a).asInstanceOf[Array[Byte]].length CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update(byteCodeSize) + + if (byteCodeSize > DEFAULT_JVM_HUGE_METHOD_LIMIT) { + logInfo("Generated method too long to be JIT compiled: " + + s"${cf.getThisClassName}.${method.getName} is $byteCodeSize bytes") + } + byteCodeSize } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 5b71becee2de0..c383eec3d56b4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -19,12 +19,16 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.Timestamp +import org.apache.log4j.{Appender, AppenderSkeleton, Logger} +import org.apache.log4j.spi.LoggingEvent + import org.apache.spark.SparkFunSuite import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils} import org.apache.spark.sql.types._ @@ -499,4 +503,64 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { ctx.freshName("a_1") :: ctx.freshName("a_0") :: Nil assert(names2.distinct.length == 4) } + + test("SPARK-25113: should log when there exists generated methods above HugeMethodLimit") { + class MockAppender extends AppenderSkeleton { + var seenMessage = false + + override def append(loggingEvent: LoggingEvent): Unit = { + if (loggingEvent.getRenderedMessage().contains("Generated method too long")) { + seenMessage = true + } + } + + override def close(): Unit = {} + override def requiresLayout(): Boolean = false + } + + val appender = new MockAppender() + withLogAppender(appender) { + val x = 42 + val expr = HugeCodeIntExpression(x) + val proj = GenerateUnsafeProjection.generate(Seq(expr)) + val actual = proj(null) + assert(actual.getInt(0) == x) + } + assert(appender.seenMessage) + } + + private def withLogAppender(appender: Appender)(f: => Unit): Unit = { + val logger = + Logger.getLogger(classOf[CodeGenerator[_, _]].getName) + logger.addAppender(appender) + try f finally { + logger.removeAppender(appender) + } + } +} + +case class HugeCodeIntExpression(value: Int) extends Expression { + override def nullable: Boolean = true + override def dataType: DataType = IntegerType + override def children: Seq[Expression] = Nil + override def eval(input: InternalRow): Any = value + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + // Assuming HugeMethodLimit to be 8000 + val HugeMethodLimit = CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT + // A single "int dummyN = 0;" will be at least 2 bytes of bytecode: + // 0: iconst_0 + // 1: istore_1 + // and it'll become bigger as the number of local variables increases. + // So 4000 such dummy local variable definitions are sufficient to bump the bytecode size + // of a generated method to above 8000 bytes. + val hugeCode = (0 until (HugeMethodLimit / 2)).map(i => s"int dummy$i = 0;").mkString("\n") + val code = + code"""{ + | $hugeCode + |} + |boolean ${ev.isNull} = false; + |int ${ev.value} = $value; + """.stripMargin + ev.copy(code = code) + } } From 92fd7f321c4c1c58e07e74ddaaa4932c7c27bcf4 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Wed, 15 Aug 2018 00:02:46 +0000 Subject: [PATCH 1399/2461] [SPARK-25115][CORE] Eliminate extra memory copy done when a ByteBuf is used that is backed by > 1 ByteBuffer. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …d by > 1 ByteBuffer. ## What changes were proposed in this pull request? Check how many ByteBuffer are used and depending on it do either call nioBuffer(...) or nioBuffers(...) to eliminate extra memory copies. This is related to netty/netty#8176. ## How was this patch tested? Unit tests added. Closes #22105 from normanmaurer/composite_byte_buf_mem_copy. Authored-by: Norman Maurer Signed-off-by: DB Tsai --- .../network/protocol/MessageWithHeader.java | 20 ++++++++++-- .../protocol/MessageWithHeaderSuite.java | 32 ++++++++++++++++++- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java index e7b66a6f33a82..b81c25afc737f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -140,8 +140,24 @@ private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOExcept // SPARK-24578: cap the sub-region's size of returned nio buffer to improve the performance // for the case that the passed-in buffer has too many components. int length = Math.min(buf.readableBytes(), NIO_BUFFER_LIMIT); - ByteBuffer buffer = buf.nioBuffer(buf.readerIndex(), length); - int written = target.write(buffer); + // If the ByteBuf holds more then one ByteBuffer we should better call nioBuffers(...) + // to eliminate extra memory copies. + int written = 0; + if (buf.nioBufferCount() == 1) { + ByteBuffer buffer = buf.nioBuffer(buf.readerIndex(), length); + written = target.write(buffer); + } else { + ByteBuffer[] buffers = buf.nioBuffers(buf.readerIndex(), length); + for (ByteBuffer buffer: buffers) { + int remaining = buffer.remaining(); + int w = target.write(buffer); + written += w; + if (w < remaining) { + // Could not write all, we need to break now. + break; + } + } + } buf.skipBytes(written); return written; } diff --git a/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java index ecb66fcf2ff76..3bff34e210e3c 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -22,6 +22,7 @@ import java.nio.channels.WritableByteChannel; import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled; import org.apache.spark.network.util.AbstractFileRegion; import org.junit.Test; @@ -48,7 +49,36 @@ public void testShortWrite() throws Exception { @Test public void testByteBufBody() throws Exception { + testByteBufBody(Unpooled.copyLong(42)); + } + + @Test + public void testCompositeByteBufBodySingleBuffer() throws Exception { + ByteBuf header = Unpooled.copyLong(42); + CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer(); + compositeByteBuf.addComponent(true, header); + assertEquals(1, compositeByteBuf.nioBufferCount()); + testByteBufBody(compositeByteBuf); + } + + @Test + public void testCompositeByteBufBodyMultipleBuffers() throws Exception { ByteBuf header = Unpooled.copyLong(42); + CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer(); + compositeByteBuf.addComponent(true, header.retainedSlice(0, 4)); + compositeByteBuf.addComponent(true, header.slice(4, 4)); + assertEquals(2, compositeByteBuf.nioBufferCount()); + testByteBufBody(compositeByteBuf); + } + + /** + * Test writing a {@link MessageWithHeader} using the given {@link ByteBuf} as header. + * + * @param header the header to use. + * @throws Exception thrown on error. + */ + private void testByteBufBody(ByteBuf header) throws Exception { + long expectedHeaderValue = header.getLong(header.readerIndex()); ByteBuf bodyPassedToNettyManagedBuffer = Unpooled.copyLong(84); assertEquals(1, header.refCnt()); assertEquals(1, bodyPassedToNettyManagedBuffer.refCnt()); @@ -61,7 +91,7 @@ public void testByteBufBody() throws Exception { MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, managedBuf.size()); ByteBuf result = doWrite(msg, 1); assertEquals(msg.count(), result.readableBytes()); - assertEquals(42, result.readLong()); + assertEquals(expectedHeaderValue, result.readLong()); assertEquals(84, result.readLong()); assertTrue(msg.release()); From ed075e1ff60cbb3e7b80b9d2f2ff37054412b934 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 14 Aug 2018 17:13:38 -0700 Subject: [PATCH 1400/2461] [SPARK-23874][SQL][PYTHON] Upgrade Apache Arrow to 0.10.0 ## What changes were proposed in this pull request? Upgrade Apache Arrow to 0.10.0 Version 0.10.0 has a number of bug fixes and improvements with the following pertaining directly to usage in Spark: * Allow for adding BinaryType support ARROW-2141 * Bug fix related to array serialization ARROW-1973 * Python2 str will be made into an Arrow string instead of bytes ARROW-2101 * Python bytearrays are supported in as input to pyarrow ARROW-2141 * Java has common interface for reset to cleanup complex vectors in Spark ArrowWriter ARROW-1962 * Cleanup pyarrow type equality checks ARROW-2423 * ArrowStreamWriter should not hold references to ArrowBlocks ARROW-2632, ARROW-2645 * Improved low level handling of messages for RecordBatch ARROW-2704 ## How was this patch tested? existing tests Author: Bryan Cutler Closes #21939 from BryanCutler/arrow-upgrade-010. --- dev/deps/spark-deps-hadoop-2.6 | 6 +++--- dev/deps/spark-deps-hadoop-2.7 | 6 +++--- dev/deps/spark-deps-hadoop-3.1 | 6 +++--- pom.xml | 2 +- python/pyspark/serializers.py | 2 ++ .../sql/vectorized/ArrowColumnVector.java | 12 +++++------ .../sql/execution/arrow/ArrowWriter.scala | 20 +++---------------- .../vectorized/ArrowColumnVectorSuite.scala | 4 ++-- 8 files changed, 23 insertions(+), 35 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 3c0952f36a051..bdab79c24bbb6 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -14,9 +14,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar -arrow-format-0.8.0.jar -arrow-memory-0.8.0.jar -arrow-vector-0.8.0.jar +arrow-format-0.10.0.jar +arrow-memory-0.10.0.jar +arrow-vector-0.10.0.jar automaton-1.11-8.jar avro-1.8.2.jar avro-ipc-1.8.2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 310f1e4528374..ddaf9bbfc3cd9 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -14,9 +14,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar -arrow-format-0.8.0.jar -arrow-memory-0.8.0.jar -arrow-vector-0.8.0.jar +arrow-format-0.10.0.jar +arrow-memory-0.10.0.jar +arrow-vector-0.10.0.jar automaton-1.11-8.jar avro-1.8.2.jar avro-ipc-1.8.2.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 9bff2a1013910..d25d7aa862c56 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -12,9 +12,9 @@ aopalliance-1.0.jar aopalliance-repackaged-2.4.0-b34.jar apache-log4j-extras-1.2.17.jar arpack_combined_all-0.1.jar -arrow-format-0.8.0.jar -arrow-memory-0.8.0.jar -arrow-vector-0.8.0.jar +arrow-format-0.10.0.jar +arrow-memory-0.10.0.jar +arrow-vector-0.10.0.jar automaton-1.11-8.jar avro-1.8.2.jar avro-ipc-1.8.2.jar diff --git a/pom.xml b/pom.xml index 45fca285ddadb..979d70919d7a2 100644 --- a/pom.xml +++ b/pom.xml @@ -190,7 +190,7 @@ If you are changing Arrow version specification, please check ./python/pyspark/sql/utils.py, ./python/run-tests.py and ./python/setup.py too. --> - 0.8.0 + 0.10.0 ${java.home} diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 82abf1947c818..47c4c3e663b97 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -229,12 +229,14 @@ def _create_batch(series, timezone): def create_array(s, t): mask = s.isnull() # Ensure timestamp series are in expected form for Spark internal representation + # TODO: maybe don't need None check anymore as of Arrow 0.9.1 if t is not None and pa.types.is_timestamp(t): s = _check_series_convert_timestamps_internal(s.fillna(0), timezone) # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2 return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False) elif t is not None and pa.types.is_string(t) and sys.version < '3': # TODO: need decode before converting to Arrow in Python 2 + # TODO: don't need as of Arrow 0.9.1 return pa.Array.from_pandas(s.apply( lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t) elif t is not None and pa.types.is_decimal(t) and \ diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 5aed87f88a298..1c9beda404356 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -162,13 +162,13 @@ public ArrowColumnVector(ValueVector vector) { } else if (vector instanceof ListVector) { ListVector listVector = (ListVector) vector; accessor = new ArrayAccessor(listVector); - } else if (vector instanceof NullableMapVector) { - NullableMapVector mapVector = (NullableMapVector) vector; - accessor = new StructAccessor(mapVector); + } else if (vector instanceof StructVector) { + StructVector structVector = (StructVector) vector; + accessor = new StructAccessor(structVector); - childColumns = new ArrowColumnVector[mapVector.size()]; + childColumns = new ArrowColumnVector[structVector.size()]; for (int i = 0; i < childColumns.length; ++i) { - childColumns[i] = new ArrowColumnVector(mapVector.getVectorById(i)); + childColumns[i] = new ArrowColumnVector(structVector.getVectorById(i)); } } else { throw new UnsupportedOperationException(); @@ -472,7 +472,7 @@ final ColumnarArray getArray(int rowId) { */ private static class StructAccessor extends ArrowVectorAccessor { - StructAccessor(NullableMapVector vector) { + StructAccessor(StructVector vector) { super(vector); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 3de6ea8bb2577..8dd484af6e908 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -21,7 +21,6 @@ import scala.collection.JavaConverters._ import org.apache.arrow.vector._ import org.apache.arrow.vector.complex._ -import org.apache.arrow.vector.types.pojo.ArrowType import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters @@ -62,7 +61,7 @@ object ArrowWriter { case (ArrayType(_, _), vector: ListVector) => val elementVector = createFieldWriter(vector.getDataVector()) new ArrayWriter(vector, elementVector) - case (StructType(_), vector: NullableMapVector) => + case (StructType(_), vector: StructVector) => val children = (0 until vector.size()).map { ordinal => createFieldWriter(vector.getChildByOrdinal(ordinal)) } @@ -129,20 +128,7 @@ private[arrow] abstract class ArrowFieldWriter { } def reset(): Unit = { - // TODO: reset() should be in a common interface - valueVector match { - case fixedWidthVector: BaseFixedWidthVector => fixedWidthVector.reset() - case variableWidthVector: BaseVariableWidthVector => variableWidthVector.reset() - case listVector: ListVector => - // Manual "reset" the underlying buffer. - // TODO: When we upgrade to Arrow 0.10.0, we can simply remove this and call - // `listVector.reset()`. - val buffers = listVector.getBuffers(false) - buffers.foreach(buf => buf.setZero(0, buf.capacity())) - listVector.setValueCount(0) - listVector.setLastSet(0) - case _ => - } + valueVector.reset() count = 0 } } @@ -323,7 +309,7 @@ private[arrow] class ArrayWriter( } private[arrow] class StructWriter( - val valueVector: NullableMapVector, + val valueVector: StructVector, children: Array[ArrowFieldWriter]) extends ArrowFieldWriter { override def setNull(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index b55489cb2678a..4592a1663faed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -336,7 +336,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) val schema = new StructType().add("int", IntegerType).add("long", LongType) val vector = ArrowUtils.toArrowField("struct", schema, nullable = false, null) - .createVector(allocator).asInstanceOf[NullableMapVector] + .createVector(allocator).asInstanceOf[StructVector] vector.allocateNew() val intVector = vector.getChildByOrdinal(0).asInstanceOf[IntVector] @@ -373,7 +373,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) val schema = new StructType().add("int", IntegerType).add("long", LongType) val vector = ArrowUtils.toArrowField("struct", schema, nullable = true, null) - .createVector(allocator).asInstanceOf[NullableMapVector] + .createVector(allocator).asInstanceOf[StructVector] vector.allocateNew() val intVector = vector.getChildByOrdinal(0).asInstanceOf[IntVector] val longVector = vector.getChildByOrdinal(1).asInstanceOf[BigIntVector] From 19c45db47725a8087bd50d14d1005c53ac52e87d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 15 Aug 2018 14:32:51 +0800 Subject: [PATCH 1401/2461] [SPARK-24505][SQL] Convert strings in codegen to blocks: Cast and BoundAttribute ## What changes were proposed in this pull request? This is split from #21520. This includes changes of `BoundAttribute` and `Cast`. This patch also adds few convenient APIs: ```scala CodeGenerator.freshVariable(name: String, dt: DataType): VariableValue CodeGenerator.freshVariable(name: String, javaClass: Class[_]): VariableValue JavaCode.javaType(javaClass: Class[_]): Inline JavaCode.javaType(dataType: DataType): Inline JavaCode.boxedType(dataType: DataType): Inline ``` ## How was this patch tested? Existing tests. Closes #21537 from viirya/SPARK-24505-1. Authored-by: Liang-Chi Hsieh Signed-off-by: hyukjinkwon --- .../catalyst/expressions/BoundAttribute.scala | 4 +- .../spark/sql/catalyst/expressions/Cast.scala | 372 +++++++++--------- .../expressions/codegen/CodeGenerator.scala | 12 + .../expressions/codegen/javaCode.scala | 36 +- 4 files changed, 241 insertions(+), 183 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index df3ab05e02c76..77582e10f9ff2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, JavaCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ @@ -53,7 +53,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) ev.copy(code = oev.code) } else { assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.") - val javaType = CodeGenerator.javaType(dataType) + val javaType = JavaCode.javaType(dataType) val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (nullable) { ev.copy(code = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index ba4d1314bab2b..100b9cfd70f52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -645,25 +645,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val eval = child.genCode(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) - ev.copy(code = - code""" - ${eval.code} - // This comment is added for manually tracking reference of ${eval.value}, ${eval.isNull} - ${castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)} - """) + ev.copy(code = eval.code + + castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)) } // The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull` // in parameter list, because the returned code will be put in null safe evaluation region. - private[this] type CastFunction = (String, String, String) => String + private[this] type CastFunction = (ExprValue, ExprValue, ExprValue) => Block private[this] def nullSafeCastFunction( from: DataType, to: DataType, ctx: CodegenContext): CastFunction = to match { - case _ if from == NullType => (c, evPrim, evNull) => s"$evNull = true;" - case _ if to == from => (c, evPrim, evNull) => s"$evPrim = $c;" + case _ if from == NullType => (c, evPrim, evNull) => code"$evNull = true;" + case _ if to == from => (c, evPrim, evNull) => code"$evPrim = $c;" case StringType => castToStringCode(from, ctx) case BinaryType => castToBinaryCode(from) case DateType => castToDateCode(from, ctx) @@ -684,18 +680,19 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx) case udt: UserDefinedType[_] if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass => - (c, evPrim, evNull) => s"$evPrim = $c;" + (c, evPrim, evNull) => code"$evPrim = $c;" case _: UserDefinedType[_] => throw new SparkException(s"Cannot cast $from to $to.") } // Since we need to cast input expressions recursively inside ComplexTypes, such as Map's // Key and Value, Struct's field, we need to name out all the variable names involved in a cast. - private[this] def castCode(ctx: CodegenContext, input: String, inputIsNull: String, - result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = { - s""" + private[this] def castCode(ctx: CodegenContext, input: ExprValue, inputIsNull: ExprValue, + result: ExprValue, resultIsNull: ExprValue, resultType: DataType, cast: CastFunction): Block = { + val javaType = JavaCode.javaType(resultType) + code""" boolean $resultIsNull = $inputIsNull; - ${CodeGenerator.javaType(resultType)} $result = ${CodeGenerator.defaultValue(resultType)}; + $javaType $result = ${CodeGenerator.defaultValue(resultType)}; if (!$inputIsNull) { ${cast(input, result, resultIsNull)} } @@ -704,22 +701,24 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private def writeArrayToStringBuilder( et: DataType, - array: String, - buffer: String, - ctx: CodegenContext): String = { + array: ExprValue, + buffer: ExprValue, + ctx: CodegenContext): Block = { val elementToStringCode = castToStringCode(et, ctx) val funcName = ctx.freshName("elementToString") - val elementToStringFunc = ctx.addNewFunction(funcName, + val element = JavaCode.variable("element", et) + val elementStr = JavaCode.variable("elementStr", StringType) + val elementToStringFunc = inline"${ctx.addNewFunction(funcName, s""" - |private UTF8String $funcName(${CodeGenerator.javaType(et)} element) { - | UTF8String elementStr = null; - | ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)} + |private UTF8String $funcName(${CodeGenerator.javaType(et)} $element) { + | UTF8String $elementStr = null; + | ${elementToStringCode(element, elementStr, null /* resultIsNull won't be used */)} | return elementStr; |} - """.stripMargin) + """.stripMargin)}" - val loopIndex = ctx.freshName("loopIndex") - s""" + val loopIndex = ctx.freshVariable("loopIndex", IntegerType) + code""" |$buffer.append("["); |if ($array.numElements() > 0) { | if (!$array.isNullAt(0)) { @@ -740,31 +739,37 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private def writeMapToStringBuilder( kt: DataType, vt: DataType, - map: String, - buffer: String, - ctx: CodegenContext): String = { + map: ExprValue, + buffer: ExprValue, + ctx: CodegenContext): Block = { def dataToStringFunc(func: String, dataType: DataType) = { val funcName = ctx.freshName(func) val dataToStringCode = castToStringCode(dataType, ctx) - ctx.addNewFunction(funcName, + val data = JavaCode.variable("data", dataType) + val dataStr = JavaCode.variable("dataStr", StringType) + val functionCall = ctx.addNewFunction(funcName, s""" - |private UTF8String $funcName(${CodeGenerator.javaType(dataType)} data) { - | UTF8String dataStr = null; - | ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)} + |private UTF8String $funcName(${CodeGenerator.javaType(dataType)} $data) { + | UTF8String $dataStr = null; + | ${dataToStringCode(data, dataStr, null /* resultIsNull won't be used */)} | return dataStr; |} """.stripMargin) + inline"$functionCall" } val keyToStringFunc = dataToStringFunc("keyToString", kt) val valueToStringFunc = dataToStringFunc("valueToString", vt) - val loopIndex = ctx.freshName("loopIndex") - val getMapFirstKey = CodeGenerator.getValue(s"$map.keyArray()", kt, "0") - val getMapFirstValue = CodeGenerator.getValue(s"$map.valueArray()", vt, "0") - val getMapKeyArray = CodeGenerator.getValue(s"$map.keyArray()", kt, loopIndex) - val getMapValueArray = CodeGenerator.getValue(s"$map.valueArray()", vt, loopIndex) - s""" + val loopIndex = ctx.freshVariable("loopIndex", IntegerType) + val mapKeyArray = JavaCode.expression(s"$map.keyArray()", classOf[ArrayData]) + val mapValueArray = JavaCode.expression(s"$map.valueArray()", classOf[ArrayData]) + val getMapFirstKey = CodeGenerator.getValue(mapKeyArray, kt, JavaCode.literal("0", IntegerType)) + val getMapFirstValue = CodeGenerator.getValue(mapValueArray, vt, + JavaCode.literal("0", IntegerType)) + val getMapKeyArray = CodeGenerator.getValue(mapKeyArray, kt, loopIndex) + val getMapValueArray = CodeGenerator.getValue(mapValueArray, vt, loopIndex) + code""" |$buffer.append("["); |if ($map.numElements() > 0) { | $buffer.append($keyToStringFunc($getMapFirstKey)); @@ -789,20 +794,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private def writeStructToStringBuilder( st: Seq[DataType], - row: String, - buffer: String, - ctx: CodegenContext): String = { + row: ExprValue, + buffer: ExprValue, + ctx: CodegenContext): Block = { val structToStringCode = st.zipWithIndex.map { case (ft, i) => val fieldToStringCode = castToStringCode(ft, ctx) - val field = ctx.freshName("field") - val fieldStr = ctx.freshName("fieldStr") - s""" - |${if (i != 0) s"""$buffer.append(",");""" else ""} + val field = ctx.freshVariable("field", ft) + val fieldStr = ctx.freshVariable("fieldStr", StringType) + val javaType = JavaCode.javaType(ft) + code""" + |${if (i != 0) code"""$buffer.append(",");""" else EmptyBlock} |if (!$row.isNullAt($i)) { - | ${if (i != 0) s"""$buffer.append(" ");""" else ""} + | ${if (i != 0) code"""$buffer.append(" ");""" else EmptyBlock} | | // Append $i field into the string buffer - | ${CodeGenerator.javaType(ft)} $field = ${CodeGenerator.getValue(row, ft, s"$i")}; + | $javaType $field = ${CodeGenerator.getValue(row, ft, s"$i")}; | UTF8String $fieldStr = null; | ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)} | $buffer.append($fieldStr); @@ -811,11 +817,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } val writeStructCode = ctx.splitExpressions( - expressions = structToStringCode, + expressions = structToStringCode.map(_.code), funcName = "fieldToString", - arguments = ("InternalRow", row) :: (classOf[UTF8StringBuilder].getName, buffer) :: Nil) + arguments = ("InternalRow", row.code) :: + (classOf[UTF8StringBuilder].getName, buffer.code) :: Nil) - s""" + code""" |$buffer.append("["); |$writeStructCode |$buffer.append("]"); @@ -825,20 +832,20 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => - (c, evPrim, evNull) => s"$evPrim = UTF8String.fromBytes($c);" + (c, evPrim, evNull) => code"$evPrim = UTF8String.fromBytes($c);" case DateType => - (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( + (c, evPrim, evNull) => code"""$evPrim = UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c));""" case TimestampType => - val tz = ctx.addReferenceObj("timeZone", timeZone) - (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), timeZone.getClass) + (c, evPrim, evNull) => code"""$evPrim = UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));""" case ArrayType(et, _) => (c, evPrim, evNull) => { - val buffer = ctx.freshName("buffer") - val bufferClass = classOf[UTF8StringBuilder].getName + val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) + val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder]) val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx) - s""" + code""" |$bufferClass $buffer = new $bufferClass(); |$writeArrayElemCode; |$evPrim = $buffer.build(); @@ -846,10 +853,10 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } case MapType(kt, vt, _) => (c, evPrim, evNull) => { - val buffer = ctx.freshName("buffer") - val bufferClass = classOf[UTF8StringBuilder].getName + val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) + val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder]) val writeMapElemCode = writeMapToStringBuilder(kt, vt, c, buffer, ctx) - s""" + code""" |$bufferClass $buffer = new $bufferClass(); |$writeMapElemCode; |$evPrim = $buffer.build(); @@ -857,11 +864,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } case StructType(fields) => (c, evPrim, evNull) => { - val row = ctx.freshName("row") - val buffer = ctx.freshName("buffer") - val bufferClass = classOf[UTF8StringBuilder].getName + val row = ctx.freshVariable("row", classOf[InternalRow]) + val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) + val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder]) val writeStructCode = writeStructToStringBuilder(fields.map(_.dataType), row, buffer, ctx) - s""" + code""" |InternalRow $row = $c; |$bufferClass $buffer = new $bufferClass(); |$writeStructCode @@ -870,26 +877,26 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } case pudt: PythonUserDefinedType => castToStringCode(pudt.sqlType, ctx) case udt: UserDefinedType[_] => - val udtRef = ctx.addReferenceObj("udt", udt) + val udtRef = JavaCode.global(ctx.addReferenceObj("udt", udt), udt.sqlType) (c, evPrim, evNull) => { - s"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());" + code"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());" } case _ => - (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" + (c, evPrim, evNull) => code"$evPrim = UTF8String.fromString(String.valueOf($c));" } } private[this] def castToBinaryCode(from: DataType): CastFunction = from match { case StringType => - (c, evPrim, evNull) => s"$evPrim = $c.getBytes();" + (c, evPrim, evNull) => code"$evPrim = $c.getBytes();" } private[this] def castToDateCode( from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val intOpt = ctx.freshName("intOpt") - (c, evPrim, evNull) => s""" + val intOpt = ctx.freshVariable("intOpt", classOf[Option[Integer]]) + (c, evPrim, evNull) => code""" scala.Option $intOpt = org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToDate($c); if ($intOpt.isDefined()) { @@ -899,16 +906,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case TimestampType => - val tz = ctx.addReferenceObj("timeZone", timeZone) + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), timeZone.getClass) (c, evPrim, evNull) => - s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L, $tz);" + code"""$evPrim = + org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L, $tz);""" case _ => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" } - private[this] def changePrecision(d: String, decimalType: DecimalType, - evPrim: String, evNull: String): String = - s""" + private[this] def changePrecision(d: ExprValue, decimalType: DecimalType, + evPrim: ExprValue, evNull: ExprValue): Block = + code""" if ($d.changePrecision(${decimalType.precision}, ${decimalType.scale})) { $evPrim = $d; } else { @@ -920,11 +928,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String from: DataType, target: DecimalType, ctx: CodegenContext): CastFunction = { - val tmp = ctx.freshName("tmpDecimal") + val tmp = ctx.freshVariable("tmpDecimal", classOf[Decimal]) from match { case StringType => (c, evPrim, evNull) => - s""" + code""" try { Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString())); ${changePrecision(tmp, target, evPrim, evNull)} @@ -934,37 +942,37 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """ case BooleanType => (c, evPrim, evNull) => - s""" + code""" Decimal $tmp = $c ? Decimal.apply(1) : Decimal.apply(0); ${changePrecision(tmp, target, evPrim, evNull)} """ case DateType => // date can't cast to decimal in Hive - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => // Note that we lose precision here. (c, evPrim, evNull) => - s""" + code""" Decimal $tmp = Decimal.apply( scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); ${changePrecision(tmp, target, evPrim, evNull)} """ case DecimalType() => (c, evPrim, evNull) => - s""" + code""" Decimal $tmp = $c.clone(); ${changePrecision(tmp, target, evPrim, evNull)} """ case x: IntegralType => (c, evPrim, evNull) => - s""" + code""" Decimal $tmp = Decimal.apply((long) $c); ${changePrecision(tmp, target, evPrim, evNull)} """ case x: FractionalType => // All other numeric types can be represented precisely as Doubles (c, evPrim, evNull) => - s""" + code""" try { Decimal $tmp = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c)); ${changePrecision(tmp, target, evPrim, evNull)} @@ -979,10 +987,10 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val tz = ctx.addReferenceObj("timeZone", timeZone) - val longOpt = ctx.freshName("longOpt") + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), timeZone.getClass) + val longOpt = ctx.freshVariable("longOpt", classOf[Option[Long]]) (c, evPrim, evNull) => - s""" + code""" scala.Option $longOpt = org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestamp($c, $tz); if ($longOpt.isDefined()) { @@ -992,18 +1000,19 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;" + (c, evPrim, evNull) => code"$evPrim = $c ? 1L : 0L;" case _: IntegralType => - (c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};" + (c, evPrim, evNull) => code"$evPrim = ${longToTimeStampCode(c)};" case DateType => - val tz = ctx.addReferenceObj("timeZone", timeZone) + val tz = JavaCode.global(ctx.addReferenceObj("timeZone", timeZone), timeZone.getClass) (c, evPrim, evNull) => - s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c, $tz) * 1000;" + code"""$evPrim = + org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c, $tz) * 1000;""" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = ${decimalToTimestampCode(c)};" + (c, evPrim, evNull) => code"$evPrim = ${decimalToTimestampCode(c)};" case DoubleType => (c, evPrim, evNull) => - s""" + code""" if (Double.isNaN($c) || Double.isInfinite($c)) { $evNull = true; } else { @@ -1012,7 +1021,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """ case FloatType => (c, evPrim, evNull) => - s""" + code""" if (Float.isNaN($c) || Float.isInfinite($c)) { $evNull = true; } else { @@ -1024,7 +1033,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToIntervalCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s"""$evPrim = CalendarInterval.fromString($c.toString()); + code"""$evPrim = CalendarInterval.fromString($c.toString()); if(${evPrim} == null) { ${evNull} = true; } @@ -1032,18 +1041,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } - private[this] def decimalToTimestampCode(d: String): String = - s"($d.toBigDecimal().bigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()" - private[this] def longToTimeStampCode(l: String): String = s"$l * 1000000L" - private[this] def timestampToIntegerCode(ts: String): String = - s"java.lang.Math.floor((double) $ts / 1000000L)" - private[this] def timestampToDoubleCode(ts: String): String = s"$ts / 1000000.0" + private[this] def decimalToTimestampCode(d: ExprValue): Block = { + val block = inline"new java.math.BigDecimal(1000000L)" + code"($d.toBigDecimal().bigDecimal().multiply($block)).longValue()" + } + private[this] def longToTimeStampCode(l: ExprValue): Block = code"$l * 1000000L" + private[this] def timestampToIntegerCode(ts: ExprValue): Block = + code"java.lang.Math.floor((double) $ts / 1000000L)" + private[this] def timestampToDoubleCode(ts: ExprValue): Block = + code"$ts / 1000000.0" private[this] def castToBooleanCode(from: DataType): CastFunction = from match { case StringType => - val stringUtils = StringUtils.getClass.getName.stripSuffix("$") + val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}" (c, evPrim, evNull) => - s""" + code""" if ($stringUtils.isTrueString($c)) { $evPrim = true; } else if ($stringUtils.isFalseString($c)) { @@ -1053,21 +1065,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case TimestampType => - (c, evPrim, evNull) => s"$evPrim = $c != 0;" + (c, evPrim, evNull) => code"$evPrim = $c != 0;" case DateType => // Hive would return null when cast from date to boolean - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = !$c.isZero();" + (c, evPrim, evNull) => code"$evPrim = !$c.isZero();" case n: NumericType => - (c, evPrim, evNull) => s"$evPrim = $c != 0;" + (c, evPrim, evNull) => code"$evPrim = $c != 0;" } private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("intWrapper") + val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => - s""" + code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toByte($wrapper)) { $evPrim = (byte) $wrapper.value; @@ -1077,24 +1089,24 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $wrapper = null; """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;" + (c, evPrim, evNull) => code"$evPrim = $c ? (byte) 1 : (byte) 0;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (byte) ${timestampToIntegerCode(c)};" + (c, evPrim, evNull) => code"$evPrim = (byte) ${timestampToIntegerCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toByte();" + (c, evPrim, evNull) => code"$evPrim = $c.toByte();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (byte) $c;" + (c, evPrim, evNull) => code"$evPrim = (byte) $c;" } private[this] def castToShortCode( from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("intWrapper") + val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => - s""" + code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toShort($wrapper)) { $evPrim = (short) $wrapper.value; @@ -1104,22 +1116,22 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $wrapper = null; """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? (short) 1 : (short) 0;" + (c, evPrim, evNull) => code"$evPrim = $c ? (short) 1 : (short) 0;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (short) ${timestampToIntegerCode(c)};" + (c, evPrim, evNull) => code"$evPrim = (short) ${timestampToIntegerCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toShort();" + (c, evPrim, evNull) => code"$evPrim = $c.toShort();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (short) $c;" + (c, evPrim, evNull) => code"$evPrim = (short) $c;" } private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("intWrapper") + val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper]) (c, evPrim, evNull) => - s""" + code""" UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toInt($wrapper)) { $evPrim = $wrapper.value; @@ -1129,23 +1141,23 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $wrapper = null; """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + (c, evPrim, evNull) => code"$evPrim = $c ? 1 : 0;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (int) ${timestampToIntegerCode(c)};" + (c, evPrim, evNull) => code"$evPrim = (int) ${timestampToIntegerCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toInt();" + (c, evPrim, evNull) => code"$evPrim = $c.toInt();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (int) $c;" + (c, evPrim, evNull) => code"$evPrim = (int) $c;" } private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("longWrapper") + val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper]) (c, evPrim, evNull) => - s""" + code""" UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper(); if ($c.toLong($wrapper)) { $evPrim = $wrapper.value; @@ -1155,21 +1167,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $wrapper = null; """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;" + (c, evPrim, evNull) => code"$evPrim = $c ? 1L : 0L;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (long) ${timestampToIntegerCode(c)};" + (c, evPrim, evNull) => code"$evPrim = (long) ${timestampToIntegerCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toLong();" + (c, evPrim, evNull) => code"$evPrim = $c.toLong();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (long) $c;" + (c, evPrim, evNull) => code"$evPrim = (long) $c;" } private[this] def castToFloatCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s""" + code""" try { $evPrim = Float.valueOf($c.toString()); } catch (java.lang.NumberFormatException e) { @@ -1177,21 +1189,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1.0f : 0.0f;" + (c, evPrim, evNull) => code"$evPrim = $c ? 1.0f : 0.0f;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = (float) (${timestampToDoubleCode(c)});" + (c, evPrim, evNull) => code"$evPrim = (float) (${timestampToDoubleCode(c)});" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toFloat();" + (c, evPrim, evNull) => code"$evPrim = $c.toFloat();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (float) $c;" + (c, evPrim, evNull) => code"$evPrim = (float) $c;" } private[this] def castToDoubleCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s""" + code""" try { $evPrim = Double.valueOf($c.toString()); } catch (java.lang.NumberFormatException e) { @@ -1199,31 +1211,32 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1.0d : 0.0d;" + (c, evPrim, evNull) => code"$evPrim = $c ? 1.0d : 0.0d;" case DateType => - (c, evPrim, evNull) => s"$evNull = true;" + (c, evPrim, evNull) => code"$evNull = true;" case TimestampType => - (c, evPrim, evNull) => s"$evPrim = ${timestampToDoubleCode(c)};" + (c, evPrim, evNull) => code"$evPrim = ${timestampToDoubleCode(c)};" case DecimalType() => - (c, evPrim, evNull) => s"$evPrim = $c.toDouble();" + (c, evPrim, evNull) => code"$evPrim = $c.toDouble();" case x: NumericType => - (c, evPrim, evNull) => s"$evPrim = (double) $c;" + (c, evPrim, evNull) => code"$evPrim = (double) $c;" } private[this] def castArrayCode( fromType: DataType, toType: DataType, ctx: CodegenContext): CastFunction = { val elementCast = nullSafeCastFunction(fromType, toType, ctx) - val arrayClass = classOf[GenericArrayData].getName - val fromElementNull = ctx.freshName("feNull") - val fromElementPrim = ctx.freshName("fePrim") - val toElementNull = ctx.freshName("teNull") - val toElementPrim = ctx.freshName("tePrim") - val size = ctx.freshName("n") - val j = ctx.freshName("j") - val values = ctx.freshName("values") + val arrayClass = JavaCode.javaType(classOf[GenericArrayData]) + val fromElementNull = ctx.freshVariable("feNull", BooleanType) + val fromElementPrim = ctx.freshVariable("fePrim", fromType) + val toElementNull = ctx.freshVariable("teNull", BooleanType) + val toElementPrim = ctx.freshVariable("tePrim", toType) + val size = ctx.freshVariable("n", IntegerType) + val j = ctx.freshVariable("j", IntegerType) + val values = ctx.freshVariable("values", classOf[Array[Object]]) + val javaType = JavaCode.javaType(fromType) (c, evPrim, evNull) => - s""" + code""" final int $size = $c.numElements(); final Object[] $values = new Object[$size]; for (int $j = 0; $j < $size; $j ++) { @@ -1231,7 +1244,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $values[$j] = null; } else { boolean $fromElementNull = false; - ${CodeGenerator.javaType(fromType)} $fromElementPrim = + $javaType $fromElementPrim = ${CodeGenerator.getValue(c, fromType, j)}; ${castCode(ctx, fromElementPrim, fromElementNull, toElementPrim, toElementNull, toType, elementCast)} @@ -1250,23 +1263,23 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val keysCast = castArrayCode(from.keyType, to.keyType, ctx) val valuesCast = castArrayCode(from.valueType, to.valueType, ctx) - val mapClass = classOf[ArrayBasedMapData].getName + val mapClass = JavaCode.javaType(classOf[ArrayBasedMapData]) - val keys = ctx.freshName("keys") - val convertedKeys = ctx.freshName("convertedKeys") - val convertedKeysNull = ctx.freshName("convertedKeysNull") + val keys = ctx.freshVariable("keys", ArrayType(from.keyType)) + val convertedKeys = ctx.freshVariable("convertedKeys", ArrayType(to.keyType)) + val convertedKeysNull = ctx.freshVariable("convertedKeysNull", BooleanType) - val values = ctx.freshName("values") - val convertedValues = ctx.freshName("convertedValues") - val convertedValuesNull = ctx.freshName("convertedValuesNull") + val values = ctx.freshVariable("values", ArrayType(from.valueType)) + val convertedValues = ctx.freshVariable("convertedValues", ArrayType(to.valueType)) + val convertedValuesNull = ctx.freshVariable("convertedValuesNull", BooleanType) (c, evPrim, evNull) => - s""" + code""" final ArrayData $keys = $c.keyArray(); final ArrayData $values = $c.valueArray(); - ${castCode(ctx, keys, "false", + ${castCode(ctx, keys, FalseLiteral, convertedKeys, convertedKeysNull, ArrayType(to.keyType), keysCast)} - ${castCode(ctx, values, "false", + ${castCode(ctx, values, FalseLiteral, convertedValues, convertedValuesNull, ArrayType(to.valueType), valuesCast)} $evPrim = new $mapClass($convertedKeys, $convertedValues); @@ -1279,17 +1292,18 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val fieldsCasts = from.fields.zip(to.fields).map { case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx) } - val rowClass = classOf[GenericInternalRow].getName - val tmpResult = ctx.freshName("tmpResult") - val tmpInput = ctx.freshName("tmpInput") + val tmpResult = ctx.freshVariable("tmpResult", classOf[GenericInternalRow]) + val rowClass = JavaCode.javaType(classOf[GenericInternalRow]) + val tmpInput = ctx.freshVariable("tmpInput", classOf[InternalRow]) val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => - val fromFieldPrim = ctx.freshName("ffp") - val fromFieldNull = ctx.freshName("ffn") - val toFieldPrim = ctx.freshName("tfp") - val toFieldNull = ctx.freshName("tfn") - val fromType = CodeGenerator.javaType(from.fields(i).dataType) - s""" + val fromFieldPrim = ctx.freshVariable("ffp", from.fields(i).dataType) + val fromFieldNull = ctx.freshVariable("ffn", BooleanType) + val toFieldPrim = ctx.freshVariable("tfp", to.fields(i).dataType) + val toFieldNull = ctx.freshVariable("tfn", BooleanType) + val fromType = JavaCode.javaType(from.fields(i).dataType) + val setColumn = CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim) + code""" boolean $fromFieldNull = $tmpInput.isNullAt($i); if ($fromFieldNull) { $tmpResult.setNullAt($i); @@ -1301,18 +1315,18 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String if ($toFieldNull) { $tmpResult.setNullAt($i); } else { - ${CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)}; + $setColumn; } } """ } val fieldsEvalCodes = ctx.splitExpressions( - expressions = fieldsEvalCode, + expressions = fieldsEvalCode.map(_.code), funcName = "castStruct", - arguments = ("InternalRow", tmpInput) :: (rowClass, tmpResult) :: Nil) + arguments = ("InternalRow", tmpInput.code) :: (rowClass.code, tmpResult.code) :: Nil) (input, result, resultIsNull) => - s""" + code""" final $rowClass $tmpResult = new $rowClass(${fieldsCasts.length}); final InternalRow $tmpInput = $input; $fieldsEvalCodes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e2d74daca56ce..2c56456cd4dac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -581,6 +581,18 @@ class CodegenContext { s"${fullName}_$id" } + /** + * Creates an `ExprValue` representing a local java variable of required data type. + */ + def freshVariable(name: String, dt: DataType): VariableValue = + JavaCode.variable(freshName(name), dt) + + /** + * Creates an `ExprValue` representing a local java variable of required Java class. + */ + def freshVariable(name: String, javaClass: Class[_]): VariableValue = + JavaCode.variable(freshName(name), javaClass) + /** * Generates code for equal expression in Java. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 558cbfa560053..17d4a0dc4e884 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -114,6 +114,21 @@ object JavaCode { def isNullExpression(code: String): SimpleExprValue = { expression(code, BooleanType) } + + /** + * Create an `Inline` for Java Class name. + */ + def javaType(javaClass: Class[_]): Inline = Inline(javaClass.getName) + + /** + * Create an `Inline` for Java Type name. + */ + def javaType(dataType: DataType): Inline = Inline(CodeGenerator.javaType(dataType)) + + /** + * Create an `Inline` for boxed Java Type name. + */ + def boxedType(dataType: DataType): Inline = Inline(CodeGenerator.boxedType(dataType)) } /** @@ -189,6 +204,16 @@ object Block { val CODE_BLOCK_BUFFER_LENGTH: Int = 512 + /** + * A custom string interpolator which inlines a string into code block. + */ + implicit class InlineHelper(val sc: StringContext) extends AnyVal { + def inline(args: Any*): Inline = { + val inlineString = sc.raw(args: _*) + Inline(inlineString) + } + } + implicit def blocksToBlock(blocks: Seq[Block]): Block = blocks.reduceLeft(_ + _) implicit class BlockHelper(val sc: StringContext) extends AnyVal { @@ -198,9 +223,8 @@ object Block { EmptyBlock } else { args.foreach { - case _: ExprValue => + case _: ExprValue | _: Inline | _: Block => case _: Int | _: Long | _: Float | _: Double | _: String => - case _: Block => case other => throw new IllegalArgumentException( s"Can not interpolate ${other.getClass.getName} into code block.") } @@ -270,6 +294,14 @@ case object EmptyBlock extends Block with Serializable { override def children: Seq[Block] = Seq.empty } +/** + * A piece of java code snippet inlines all types of input arguments into a string without + * tracking any reference of `JavaCode` instances. + */ +case class Inline(codeString: String) extends JavaCode { + override val code: String = codeString +} + /** * A typed java fragment that must be a valid java expression. */ From 4d8ae0d1c846560e1cac3480d73f8439968430a6 Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Wed, 15 Aug 2018 12:06:11 -0500 Subject: [PATCH 1402/2461] [SPARK-25111][BUILD] increment kinesis client/producer & aws-sdk versions This PR has been superceded by #22081 ## What changes were proposed in this pull request? Increment the kinesis client, producer and transient AWS SDK versions to a more recent release. This is to help with the move off bouncy castle of #21146 and #22081; the goal is that moving up to the new SDK will allow a JVM with unlimited JCE but without bouncy castle to work with Kinesis endpoints. Why this specific set of artifacts? it syncs up with the 1.11.271 AWS SDK used by hadoop 3.0.3, hadoop-3.1. and hadoop 3.1.1; that's been stable for the uses there (s3, STS, dynamo). ## How was this patch tested? Running all the external/kinesis-asl tests via maven with java 8.121 & unlimited JCE, without bouncy castle (#21146); default endpoint of us-west.2. Without this SDK update I was getting http cert validation errors, with it they went away. # This PR is not ready without * Jenkins test runs to see what it is happy with * more testing: repeated runs, another endpoint * looking at the new deprecation warnings and selectively addressing them (the AWS SDKs are pretty aggressive about deprecation, but sometimes they increase the complexity of the client code or block some codepaths off completely) Closes #22099 from steveloughran/cloud/SPARK-25111-kinesis. Authored-by: Steve Loughran Signed-off-by: Sean Owen --- pom.xml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pom.xml b/pom.xml index 979d70919d7a2..33c15f20ed404 100644 --- a/pom.xml +++ b/pom.xml @@ -143,11 +143,11 @@ 1.8.2 hadoop2 0.9.4 - 1.7.3 + 1.8.10 - 1.11.76 + 1.11.271 - 0.10.2 + 0.12.8 4.5.6 4.4.10 From bfb74394a5513134ea1da9fcf4a1783b77dd64e4 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Wed, 15 Aug 2018 13:31:28 -0700 Subject: [PATCH 1403/2461] [SPARK-24819][CORE] Fail fast when no enough slots to launch the barrier stage on job submitted ## What changes were proposed in this pull request? We shall check whether the barrier stage requires more slots (to be able to launch all tasks in the barrier stage together) than the total number of active slots currently, and fail fast if trying to submit a barrier stage that requires more slots than current total number. This PR proposes to add a new method `getNumSlots()` to try to get the total number of currently active slots in `SchedulerBackend`, support of this new method has been added to all the first-class scheduler backends except `MesosFineGrainedSchedulerBackend`. ## How was this patch tested? Added new test cases in `BarrierStageOnSubmittedSuite`. Closes #22001 from jiangxb1987/SPARK-24819. Lead-authored-by: Xingbo Jiang Co-authored-by: Xiangrui Meng Signed-off-by: Xiangrui Meng --- .../scala/org/apache/spark/SparkContext.scala | 9 ++ .../spark/internal/config/package.scala | 27 ++++++ .../BarrierJobAllocationFailed.scala | 62 +++++++++++++ .../apache/spark/scheduler/DAGScheduler.scala | 88 +++++++++++++----- .../spark/scheduler/SchedulerBackend.scala | 9 ++ .../CoarseGrainedSchedulerBackend.scala | 6 ++ .../local/LocalSchedulerBackend.scala | 2 + .../spark/BarrierStageOnSubmittedSuite.scala | 91 +++++++++++++++++-- .../ExecutorAllocationManagerSuite.scala | 2 + .../org/apache/spark/SparkContextSuite.scala | 1 + .../CoarseGrainedSchedulerBackendSuite.scala | 89 +++++++++++++++++- .../spark/scheduler/DAGSchedulerSuite.scala | 2 +- .../ExternalClusterManagerSuite.scala | 1 + .../scheduler/SchedulerIntegrationSuite.scala | 2 + .../scheduler/TaskSchedulerImplSuite.scala | 1 + .../MesosFineGrainedSchedulerBackend.scala | 4 + 16 files changed, 364 insertions(+), 32 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/scheduler/BarrierJobAllocationFailed.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a7ffb354c09ca..e5b1e0ecd1586 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1602,6 +1602,15 @@ class SparkContext(config: SparkConf) extends Logging { } } + /** + * Get the max number of tasks that can be concurrent launched currently. + * Note that please don't cache the value returned by this method, because the number can change + * due to add/remove executors. + * + * @return The max number of tasks that can be concurrent launched currently. + */ + private[spark] def maxNumConcurrentTasks(): Int = schedulerBackend.maxNumConcurrentTasks() + /** * Update the cluster manager on our scheduling needs. Three bits of information are included * to help it make decisions. diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index eb08628ce1112..a8aa6914ffdae 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -577,4 +577,31 @@ package object config { .timeConf(TimeUnit.SECONDS) .checkValue(v => v > 0, "The value should be a positive time value.") .createWithDefaultString("365d") + + private[spark] val BARRIER_MAX_CONCURRENT_TASKS_CHECK_INTERVAL = + ConfigBuilder("spark.scheduler.barrier.maxConcurrentTasksCheck.interval") + .doc("Time in seconds to wait between a max concurrent tasks check failure and the next " + + "check. A max concurrent tasks check ensures the cluster can launch more concurrent " + + "tasks than required by a barrier stage on job submitted. The check can fail in case " + + "a cluster has just started and not enough executors have registered, so we wait for a " + + "little while and try to perform the check again. If the check fails more than a " + + "configured max failure times for a job then fail current job submission. Note this " + + "config only applies to jobs that contain one or more barrier stages, we won't perform " + + "the check on non-barrier jobs.") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("15s") + + private[spark] val BARRIER_MAX_CONCURRENT_TASKS_CHECK_MAX_FAILURES = + ConfigBuilder("spark.scheduler.barrier.maxConcurrentTasksCheck.maxFailures") + .doc("Number of max concurrent tasks check failures allowed before fail a job submission. " + + "A max concurrent tasks check ensures the cluster can launch more concurrent tasks than " + + "required by a barrier stage on job submitted. The check can fail in case a cluster " + + "has just started and not enough executors have registered, so we wait for a little " + + "while and try to perform the check again. If the check fails more than a configured " + + "max failure times for a job then fail current job submission. Note this config only " + + "applies to jobs that contain one or more barrier stages, we won't perform the check on " + + "non-barrier jobs.") + .intConf + .checkValue(v => v > 0, "The max failures should be a positive value.") + .createWithDefault(40) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/BarrierJobAllocationFailed.scala b/core/src/main/scala/org/apache/spark/scheduler/BarrierJobAllocationFailed.scala new file mode 100644 index 0000000000000..803a0a1226d6c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/BarrierJobAllocationFailed.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.SparkException + +/** + * Exception thrown when submit a job with barrier stage(s) failing a required check. + */ +private[spark] class BarrierJobAllocationFailed(message: String) extends SparkException(message) + +private[spark] class BarrierJobUnsupportedRDDChainException + extends BarrierJobAllocationFailed( + BarrierJobAllocationFailed.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + +private[spark] class BarrierJobRunWithDynamicAllocationException + extends BarrierJobAllocationFailed( + BarrierJobAllocationFailed.ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION) + +private[spark] class BarrierJobSlotsNumberCheckFailed + extends BarrierJobAllocationFailed( + BarrierJobAllocationFailed.ERROR_MESSAGE_BARRIER_REQUIRE_MORE_SLOTS_THAN_CURRENT_TOTAL_NUMBER) + +private[spark] object BarrierJobAllocationFailed { + + // Error message when running a barrier stage that have unsupported RDD chain pattern. + val ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN = + "[SPARK-24820][SPARK-24821]: Barrier execution mode does not allow the following pattern of " + + "RDD chain within a barrier stage:\n1. Ancestor RDDs that have different number of " + + "partitions from the resulting RDD (eg. union()/coalesce()/first()/take()/" + + "PartitionPruningRDD). A workaround for first()/take() can be barrierRdd.collect().head " + + "(scala) or barrierRdd.collect()[0] (python).\n" + + "2. An RDD that depends on multiple barrier RDDs (eg. barrierRdd1.zip(barrierRdd2))." + + // Error message when running a barrier stage with dynamic resource allocation enabled. + val ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION = + "[SPARK-24942]: Barrier execution mode does not support dynamic resource allocation for " + + "now. You can disable dynamic resource allocation by setting Spark conf " + + "\"spark.dynamicAllocation.enabled\" to \"false\"." + + // Error message when running a barrier stage that requires more slots than current total number. + val ERROR_MESSAGE_BARRIER_REQUIRE_MORE_SLOTS_THAN_CURRENT_TOTAL_NUMBER = + "[SPARK-24819]: Barrier execution mode does not allow run a barrier stage that requires " + + "more slots than the total number of slots in the cluster currently. Please init a new " + + "cluster with more CPU cores or repartition the input RDD(s) to reduce the number of " + + "slots required to run this barrier stage." +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index cf1fcbce78b30..2b0ca13485eb5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -19,8 +19,9 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.util.Properties -import java.util.concurrent.TimeUnit +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import java.util.concurrent.atomic.AtomicInteger +import java.util.function.BiFunction import scala.annotation.tailrec import scala.collection.Map @@ -39,7 +40,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} -import org.apache.spark.rdd.{PartitionPruningRDD, RDD, RDDCheckpointData} +import org.apache.spark.rdd.{RDD, RDDCheckpointData} import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -111,8 +112,7 @@ import org.apache.spark.util._ * - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to * include the new structure. This will help to catch memory leaks. */ -private[spark] -class DAGScheduler( +private[spark] class DAGScheduler( private[scheduler] val sc: SparkContext, private[scheduler] val taskScheduler: TaskScheduler, listenerBus: LiveListenerBus, @@ -203,6 +203,24 @@ class DAGScheduler( sc.getConf.getInt("spark.stage.maxConsecutiveAttempts", DAGScheduler.DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS) + /** + * Number of max concurrent tasks check failures for each barrier job. + */ + private[scheduler] val barrierJobIdToNumTasksCheckFailures = new ConcurrentHashMap[Int, Int] + + /** + * Time in seconds to wait between a max concurrent tasks check failure and the next check. + */ + private val timeIntervalNumTasksCheck = sc.getConf + .get(config.BARRIER_MAX_CONCURRENT_TASKS_CHECK_INTERVAL) + + /** + * Max number of max concurrent tasks check failures allowed for a job before fail the job + * submission. + */ + private val maxFailureNumTasksCheck = sc.getConf + .get(config.BARRIER_MAX_CONCURRENT_TASKS_CHECK_MAX_FAILURES) + private val messageScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message") @@ -351,8 +369,7 @@ class DAGScheduler( val predicate: RDD[_] => Boolean = (r => r.getNumPartitions == numTasksInStage && r.dependencies.filter(_.rdd.isBarrier()).size <= 1) if (rdd.isBarrier() && !traverseParentRDDsWithinStage(rdd, predicate)) { - throw new SparkException( - DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + throw new BarrierJobUnsupportedRDDChainException } } @@ -365,6 +382,7 @@ class DAGScheduler( def createShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): ShuffleMapStage = { val rdd = shuffleDep.rdd checkBarrierStageWithDynamicAllocation(rdd) + checkBarrierStageWithNumSlots(rdd) checkBarrierStageWithRDDChainPattern(rdd, rdd.getNumPartitions) val numTasks = rdd.partitions.length val parents = getOrCreateParentStages(rdd, jobId) @@ -398,7 +416,20 @@ class DAGScheduler( */ private def checkBarrierStageWithDynamicAllocation(rdd: RDD[_]): Unit = { if (rdd.isBarrier() && Utils.isDynamicAllocationEnabled(sc.getConf)) { - throw new SparkException(DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION) + throw new BarrierJobRunWithDynamicAllocationException + } + } + + /** + * Check whether the barrier stage requires more slots (to be able to launch all tasks in the + * barrier stage together) than the total number of active slots currently. Fail current check + * if trying to submit a barrier stage that requires more slots than current total number. If + * the check fails consecutively beyond a configured number for a job, then fail current job + * submission. + */ + private def checkBarrierStageWithNumSlots(rdd: RDD[_]): Unit = { + if (rdd.isBarrier() && rdd.getNumPartitions > sc.maxNumConcurrentTasks) { + throw new BarrierJobSlotsNumberCheckFailed } } @@ -412,6 +443,7 @@ class DAGScheduler( jobId: Int, callSite: CallSite): ResultStage = { checkBarrierStageWithDynamicAllocation(rdd) + checkBarrierStageWithNumSlots(rdd) checkBarrierStageWithRDDChainPattern(rdd, partitions.toSet.size) val parents = getOrCreateParentStages(rdd, jobId) val id = nextStageId.getAndIncrement() @@ -929,11 +961,38 @@ class DAGScheduler( // HadoopRDD whose underlying HDFS files have been deleted. finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite) } catch { + case e: BarrierJobSlotsNumberCheckFailed => + logWarning(s"The job $jobId requires to run a barrier stage that requires more slots " + + "than the total number of slots in the cluster currently.") + // If jobId doesn't exist in the map, Scala coverts its value null to 0: Int automatically. + val numCheckFailures = barrierJobIdToNumTasksCheckFailures.compute(jobId, + new BiFunction[Int, Int, Int] { + override def apply(key: Int, value: Int): Int = value + 1 + }) + if (numCheckFailures <= maxFailureNumTasksCheck) { + messageScheduler.schedule( + new Runnable { + override def run(): Unit = eventProcessLoop.post(JobSubmitted(jobId, finalRDD, func, + partitions, callSite, listener, properties)) + }, + timeIntervalNumTasksCheck, + TimeUnit.SECONDS + ) + return + } else { + // Job failed, clear internal data. + barrierJobIdToNumTasksCheckFailures.remove(jobId) + listener.jobFailed(e) + return + } + case e: Exception => logWarning("Creating new stage failed due to exception - job: " + jobId, e) listener.jobFailed(e) return } + // Job submitted, clear internal data. + barrierJobIdToNumTasksCheckFailures.remove(jobId) val job = new ActiveJob(jobId, finalStage, callSite, listener, properties) clearCacheLocs() @@ -2011,19 +2070,4 @@ private[spark] object DAGScheduler { // Number of consecutive stage attempts allowed before a stage is aborted val DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS = 4 - - // Error message when running a barrier stage that have unsupported RDD chain pattern. - val ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN = - "[SPARK-24820][SPARK-24821]: Barrier execution mode does not allow the following pattern of " + - "RDD chain within a barrier stage:\n1. Ancestor RDDs that have different number of " + - "partitions from the resulting RDD (eg. union()/coalesce()/first()/take()/" + - "PartitionPruningRDD). A workaround for first()/take() can be barrierRdd.collect().head " + - "(scala) or barrierRdd.collect()[0] (python).\n" + - "2. An RDD that depends on multiple barrier RDDs (eg. barrierRdd1.zip(barrierRdd2))." - - // Error message when running a barrier stage with dynamic resource allocation enabled. - val ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION = - "[SPARK-24942]: Barrier execution mode does not support dynamic resource allocation for " + - "now. You can disable dynamic resource allocation by setting Spark conf " + - "\"spark.dynamicAllocation.enabled\" to \"false\"." } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index 22db3350abfa7..c187ee146301b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -69,4 +69,13 @@ private[spark] trait SchedulerBackend { */ def getDriverLogUrls: Option[Map[String, String]] = None + /** + * Get the max number of tasks that can be concurrent launched currently. + * Note that please don't cache the value returned by this method, because the number can change + * due to add/remove executors. + * + * @return The max number of tasks that can be concurrent launched currently. + */ + def maxNumConcurrentTasks(): Int + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 375aeb0c34661..747e8c7dc0fa5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -496,6 +496,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp executorDataMap.keySet.toSeq } + override def maxNumConcurrentTasks(): Int = { + executorDataMap.values.map { executor => + executor.totalCores / scheduler.CPUS_PER_TASK + }.sum + } + /** * Request an additional number of executors from the cluster manager. * @return whether the request is acknowledged. diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala index cf8b0ff4f7019..0de57fbd5600c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala @@ -156,6 +156,8 @@ private[spark] class LocalSchedulerBackend( override def applicationId(): String = appId + override def maxNumConcurrentTasks(): Int = totalCores / scheduler.CPUS_PER_TASK + private def stop(finalState: SparkAppHandle.State): Unit = { localEndpoint.ask(StopExecutor) try { diff --git a/core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala b/core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala index 2f21e61ce9c97..d49ab4aa7df12 100644 --- a/core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala +++ b/core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala @@ -21,6 +21,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import org.apache.spark.rdd.{PartitionPruningRDD, RDD} +import org.apache.spark.scheduler.BarrierJobAllocationFailed._ import org.apache.spark.scheduler.DAGScheduler import org.apache.spark.util.ThreadUtils @@ -63,7 +64,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext .barrier() .mapPartitions(iter => iter) testSubmitJob(sc, rdd, - message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + message = ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) } test("submit a barrier ShuffleMapStage that contains PartitionPruningRDD") { @@ -75,7 +76,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext .repartition(2) .map(x => x + 1) testSubmitJob(sc, rdd, - message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + message = ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) } test("submit a barrier stage that doesn't contain PartitionPruningRDD") { @@ -96,7 +97,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext .barrier() .mapPartitions(iter => iter) testSubmitJob(sc, rdd, Some(Seq(1, 3)), - message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + message = ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) } test("submit a barrier stage with union()") { @@ -110,7 +111,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext .map(x => x * 2) // Fail the job on submit because the barrier RDD (rdd1) may be not assigned Task 0. testSubmitJob(sc, rdd3, - message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + message = ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) } test("submit a barrier stage with coalesce()") { @@ -122,7 +123,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext // Fail the job on submit because the barrier RDD requires to run on 4 tasks, but the stage // only launches 1 task. testSubmitJob(sc, rdd, - message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + message = ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) } test("submit a barrier stage that contains an RDD that depends on multiple barrier RDDs") { @@ -137,7 +138,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext .zip(rdd2) .map(x => x._1 + x._2) testSubmitJob(sc, rdd3, - message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) + message = ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN) } test("submit a barrier stage with zip()") { @@ -166,7 +167,7 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext .barrier() .mapPartitions(iter => iter) testSubmitJob(sc, rdd, - message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION) + message = ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION) } test("submit a barrier ShuffleMapStage with dynamic resource allocation enabled") { @@ -183,6 +184,80 @@ class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext .repartition(2) .map(x => x + 1) testSubmitJob(sc, rdd, - message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION) + message = ERROR_MESSAGE_RUN_BARRIER_WITH_DYN_ALLOCATION) + } + + test("submit a barrier ResultStage that requires more slots than current total under local " + + "mode") { + val conf = new SparkConf() + // Shorten the time interval between two failed checks to make the test fail faster. + .set("spark.scheduler.barrier.maxConcurrentTasksCheck.interval", "1s") + // Reduce max check failures allowed to make the test fail faster. + .set("spark.scheduler.barrier.maxConcurrentTasksCheck.maxFailures", "3") + .setMaster("local[4]") + .setAppName("test") + sc = createSparkContext(Some(conf)) + val rdd = sc.parallelize(1 to 10, 5) + .barrier() + .mapPartitions(iter => iter) + testSubmitJob(sc, rdd, + message = ERROR_MESSAGE_BARRIER_REQUIRE_MORE_SLOTS_THAN_CURRENT_TOTAL_NUMBER) + } + + test("submit a barrier ShuffleMapStage that requires more slots than current total under " + + "local mode") { + val conf = new SparkConf() + // Shorten the time interval between two failed checks to make the test fail faster. + .set("spark.scheduler.barrier.maxConcurrentTasksCheck.interval", "1s") + // Reduce max check failures allowed to make the test fail faster. + .set("spark.scheduler.barrier.maxConcurrentTasksCheck.maxFailures", "3") + .setMaster("local[4]") + .setAppName("test") + sc = createSparkContext(Some(conf)) + val rdd = sc.parallelize(1 to 10, 5) + .barrier() + .mapPartitions(iter => iter) + .repartition(2) + .map(x => x + 1) + testSubmitJob(sc, rdd, + message = ERROR_MESSAGE_BARRIER_REQUIRE_MORE_SLOTS_THAN_CURRENT_TOTAL_NUMBER) + } + + test("submit a barrier ResultStage that requires more slots than current total under " + + "local-cluster mode") { + val conf = new SparkConf() + .set("spark.task.cpus", "2") + // Shorten the time interval between two failed checks to make the test fail faster. + .set("spark.scheduler.barrier.maxConcurrentTasksCheck.interval", "1s") + // Reduce max check failures allowed to make the test fail faster. + .set("spark.scheduler.barrier.maxConcurrentTasksCheck.maxFailures", "3") + .setMaster("local-cluster[4, 3, 1024]") + .setAppName("test") + sc = createSparkContext(Some(conf)) + val rdd = sc.parallelize(1 to 10, 5) + .barrier() + .mapPartitions(iter => iter) + testSubmitJob(sc, rdd, + message = ERROR_MESSAGE_BARRIER_REQUIRE_MORE_SLOTS_THAN_CURRENT_TOTAL_NUMBER) + } + + test("submit a barrier ShuffleMapStage that requires more slots than current total under " + + "local-cluster mode") { + val conf = new SparkConf() + .set("spark.task.cpus", "2") + // Shorten the time interval between two failed checks to make the test fail faster. + .set("spark.scheduler.barrier.maxConcurrentTasksCheck.interval", "1s") + // Reduce max check failures allowed to make the test fail faster. + .set("spark.scheduler.barrier.maxConcurrentTasksCheck.maxFailures", "3") + .setMaster("local-cluster[4, 3, 1024]") + .setAppName("test") + sc = createSparkContext(Some(conf)) + val rdd = sc.parallelize(1 to 10, 5) + .barrier() + .mapPartitions(iter => iter) + .repartition(2) + .map(x => x + 1) + testSubmitJob(sc, rdd, + message = ERROR_MESSAGE_BARRIER_REQUIRE_MORE_SLOTS_THAN_CURRENT_TOTAL_NUMBER) } } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 3cfb0a9feb32b..659ebb60fef86 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -1376,6 +1376,8 @@ private class DummyLocalSchedulerBackend (sc: SparkContext, sb: SchedulerBackend override def defaultParallelism(): Int = sb.defaultParallelism() + override def maxNumConcurrentTasks(): Int = sb.maxNumConcurrentTasks() + override def killExecutorsOnHost(host: String): Boolean = { false } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index cb44110e30135..e1666a35271d3 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -654,6 +654,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu .setMaster("local-cluster[3, 1, 1024]") .setAppName("test-cluster") sc = new SparkContext(conf) + val rdd = sc.makeRDD(Seq(1, 2, 3, 4), 2) val rdd2 = rdd.barrier().mapPartitions { it => val context = BarrierTaskContext.get() diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index 04cccc67e328e..80c9c6f0422a8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -17,10 +17,18 @@ package org.apache.spark.scheduler +import java.util.concurrent.atomic.AtomicBoolean + +import scala.concurrent.duration._ + +import org.scalatest.concurrent.Eventually + import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark.rdd.RDD import org.apache.spark.util.{RpcUtils, SerializableBuffer} -class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext { +class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext + with Eventually { test("serialized task larger than max RPC message size") { val conf = new SparkConf @@ -38,4 +46,83 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo assert(smaller.size === 4) } + test("compute max number of concurrent tasks can be launched") { + val conf = new SparkConf() + .setMaster("local-cluster[4, 3, 1024]") + .setAppName("test") + sc = new SparkContext(conf) + eventually(timeout(10.seconds)) { + // Ensure all executors have been launched. + assert(sc.getExecutorIds().length == 4) + } + assert(sc.maxNumConcurrentTasks() == 12) + } + + test("compute max number of concurrent tasks can be launched when spark.task.cpus > 1") { + val conf = new SparkConf() + .set("spark.task.cpus", "2") + .setMaster("local-cluster[4, 3, 1024]") + .setAppName("test") + sc = new SparkContext(conf) + eventually(timeout(10.seconds)) { + // Ensure all executors have been launched. + assert(sc.getExecutorIds().length == 4) + } + // Each executor can only launch one task since `spark.task.cpus` is 2. + assert(sc.maxNumConcurrentTasks() == 4) + } + + test("compute max number of concurrent tasks can be launched when some executors are busy") { + val conf = new SparkConf() + .set("spark.task.cpus", "2") + .setMaster("local-cluster[4, 3, 1024]") + .setAppName("test") + sc = new SparkContext(conf) + val rdd = sc.parallelize(1 to 10, 4).mapPartitions { iter => + Thread.sleep(5000) + iter + } + var taskStarted = new AtomicBoolean(false) + var taskEnded = new AtomicBoolean(false) + val listener = new SparkListener() { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + taskStarted.set(true) + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + taskEnded.set(true) + } + } + + try { + sc.addSparkListener(listener) + eventually(timeout(10.seconds)) { + // Ensure all executors have been launched. + assert(sc.getExecutorIds().length == 4) + } + + // Submit a job to trigger some tasks on active executors. + testSubmitJob(sc, rdd) + + eventually(timeout(10.seconds)) { + // Ensure some tasks have started and no task finished, so some executors must be busy. + assert(taskStarted.get() == true) + assert(taskEnded.get() == false) + // Assert we count in slots on both busy and free executors. + assert(sc.maxNumConcurrentTasks() == 4) + } + } finally { + sc.removeSparkListener(listener) + } + } + + private def testSubmitJob(sc: SparkContext, rdd: RDD[Int]): Unit = { + sc.submitJob( + rdd, + (iter: Iterator[Int]) => iter.toArray, + 0 until rdd.partitions.length, + { case (_, _) => return }: (Int, Array[Int]) => Unit, + { return } + ) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 3fbe636607687..6eeddbb763172 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -215,7 +215,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } private def init(testConf: SparkConf): Unit = { - sc = new SparkContext("local", "DAGSchedulerSuite", testConf) + sc = new SparkContext("local[2]", "DAGSchedulerSuite", testConf) sparkListener.submittedStageInfos.clear() sparkListener.successfulStages.clear() sparkListener.failedStages.clear() diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index 02b19e01ce7a0..b4705914b999b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -69,6 +69,7 @@ private class DummySchedulerBackend extends SchedulerBackend { def stop() {} def reviveOffers() {} def defaultParallelism(): Int = 1 + def maxNumConcurrentTasks(): Int = 0 } private class DummyTaskScheduler extends TaskScheduler { diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 75ea409e16b4b..cea7f173c8f2f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -385,6 +385,8 @@ private[spark] abstract class MockBackend( }.toIndexedSeq } + override def maxNumConcurrentTasks(): Int = 0 + /** * This is called by the scheduler whenever it has tasks it would like to schedule, when a tasks * completes (which will be in a result-getter thread), and by the reviveOffers thread for delay diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 38e26a82e750f..ca9bf08cee654 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -36,6 +36,7 @@ class FakeSchedulerBackend extends SchedulerBackend { def stop() {} def reviveOffers() {} def defaultParallelism(): Int = 1 + def maxNumConcurrentTasks(): Int = 0 } class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfterEach diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index 71a70ff048ccc..0bb6fe0fa4bdf 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -453,4 +453,8 @@ private[spark] class MesosFineGrainedSchedulerBackend( super.applicationId } + override def maxNumConcurrentTasks(): Int = { + // TODO SPARK-25074 support this method for MesosFineGrainedSchedulerBackend + 0 + } } From 717f58e9ced9b43940719dfc8675216009a4c2e9 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 15 Aug 2018 14:42:48 -0700 Subject: [PATCH 1404/2461] [SPARK-24685][BUILD] Restore support for building old Hadoop versions of 2.1. Update the release scripts to build binary packages for older versions of Hadoop when building Spark 2.1. Also did some minor refactoring of that part of the script so that changing these later is easier. This was used to build the missing packages from 2.1.3-rc2. Author: Marcelo Vanzin Closes #21661 from vanzin/SPARK-24685. --- dev/create-release/release-build.sh | 50 +++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 13 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 24a62a8f4c7d3..73610a3335910 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -178,12 +178,17 @@ if [[ "$1" == "package" ]]; then SHA512 spark-$SPARK_VERSION.tgz > spark-$SPARK_VERSION.tgz.sha512 rm -rf spark-$SPARK_VERSION + ZINC_PORT=3035 + # Updated for each binary build make_binary_release() { NAME=$1 - FLAGS=$2 - ZINC_PORT=$3 - BUILD_PACKAGE=$4 + FLAGS="$MVN_EXTRA_OPTS -B $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES $2" + BUILD_PACKAGE=$3 + + # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds + # share the same Zinc server. + ZINC_PORT=$((ZINC_PORT + 1)) echo "Building binary dist $NAME" cp -r spark spark-$SPARK_VERSION-bin-$NAME @@ -255,20 +260,39 @@ if [[ "$1" == "package" ]]; then spark-$SPARK_VERSION-bin-$NAME.tgz.sha512 } - # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds - # share the same Zinc server. - if ! make_binary_release "hadoop2.6" "$MVN_EXTRA_OPTS -B -Phadoop-2.6 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3035" "withr"; then - error "Failed to build hadoop2.6 package. Check logs for details." - fi + # List of binary packages built. Populates two associative arrays, where the key is the "name" of + # the package being built, and the values are respectively the needed maven arguments for building + # the package, and any extra package needed for that particular combination. + # + # In dry run mode, only build the first one. The keys in BINARY_PKGS_ARGS are used as the + # list of packages to be built, so it's ok for things to be missing in BINARY_PKGS_EXTRA. + + declare -A BINARY_PKGS_ARGS + BINARY_PKGS_ARGS["hadoop2.7"]="-Phadoop-2.7 $HIVE_PROFILES" if ! is_dry_run; then - if ! make_binary_release "hadoop2.7" "$MVN_EXTRA_OPTS -B -Phadoop-2.7 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3036" "withpip"; then - error "Failed to build hadoop2.7 package. Check logs for details." - fi - if ! make_binary_release "without-hadoop" "$MVN_EXTRA_OPTS -B -Phadoop-provided $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3037"; then - error "Failed to build without-hadoop package. Check logs for details." + BINARY_PKGS_ARGS["hadoop2.6"]="-Phadoop-2.6 $HIVE_PROFILES" + BINARY_PKGS_ARGS["without-hadoop"]="-Pwithout-hadoop" + if [[ $SPARK_VERSION < "2.2." ]]; then + BINARY_PKGS_ARGS["hadoop2.4"]="-Phadoop-2.4 $HIVE_PROFILES" + BINARY_PKGS_ARGS["hadoop2.3"]="-Phadoop-2.3 $HIVE_PROFILES" fi fi + declare -A BINARY_PKGS_EXTRA + BINARY_PKGS_EXTRA["hadoop2.7"]="withpip" + if ! is_dry_run; then + BINARY_PKGS_EXTRA["hadoop2.6"]="withr" + fi + + echo "Packages to build: ${!BINARY_PKGS_ARGS[@]}" + for key in ${!BINARY_PKGS_ARGS[@]}; do + args=${BINARY_PKGS_ARGS[$key]} + extra=${BINARY_PKGS_EXTRA[$key]} + if ! make_binary_release "$key" "$args" "$extra"; then + error "Failed to build $key package. Check logs for details." + fi + done + rm -rf spark-$SPARK_VERSION-bin-*/ if ! is_dry_run; then From a791c29bd824adadfb2d85594bc8dad4424df936 Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Wed, 15 Aug 2018 17:52:12 -0700 Subject: [PATCH 1405/2461] [SPARK-23984][K8S] Changed Python Version config to be camelCase ## What changes were proposed in this pull request? Small formatting change to have Python Version be camelCase as per request during PR review. ## How was this patch tested? Tested with unit and integration tests Author: Ilan Filonenko Closes #22095 from ifilonenko/spark-py-edits. --- docs/running-on-kubernetes.md | 2 +- .../src/main/scala/org/apache/spark/deploy/k8s/Config.scala | 2 +- .../spark/deploy/k8s/integrationtest/PythonTestsSuite.scala | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 97c650d0f80aa..8f84ca044e163 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -769,7 +769,7 @@ specific to Spark on Kubernetes. - + diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 1c1f40c028a97..e3d67c34d53eb 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -61,7 +61,7 @@ In `cluster` mode, the driver runs on a different machine than the client, so `S # Preparations Running Spark on YARN requires a binary distribution of Spark which is built with YARN support. -Binary distributions can be downloaded from the [downloads page](http://spark.apache.org/downloads.html) of the project website. +Binary distributions can be downloaded from the [downloads page](https://spark.apache.org/downloads.html) of the project website. To build Spark yourself, refer to [Building Spark](building-spark.html). To make Spark runtime jars accessible from YARN side, you can specify `spark.yarn.archive` or `spark.yarn.jars`. For details please refer to [Spark Properties](running-on-yarn.html#spark-properties). If neither `spark.yarn.archive` nor `spark.yarn.jars` is specified, Spark will create a zip file with all jars under `$SPARK_HOME/jars` and upload it to the distributed cache. diff --git a/docs/security.md b/docs/security.md index c8eec730889c1..7fb3e17de94c9 100644 --- a/docs/security.md +++ b/docs/security.md @@ -49,7 +49,7 @@ respectively by default) are restricted to hosts that are trusted to submit jobs Spark supports AES-based encryption for RPC connections. For encryption to be enabled, RPC authentication must also be enabled and properly configured. AES encryption uses the -[Apache Commons Crypto](http://commons.apache.org/proper/commons-crypto/) library, and Spark's +[Apache Commons Crypto](https://commons.apache.org/proper/commons-crypto/) library, and Spark's configuration system allows access to that library's configuration for advanced users. There is also support for SASL-based encryption, although it should be considered deprecated. It @@ -169,7 +169,7 @@ The following settings cover enabling encryption for data written to disk: ## Authentication and Authorization -Enabling authentication for the Web UIs is done using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html). +Enabling authentication for the Web UIs is done using [javax servlet filters](https://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html). You will need a filter that implements the authentication method you want to deploy. Spark does not provide any built-in authentication filters. @@ -492,7 +492,7 @@ distributed with the application using the `--files` command line argument (or t configuration should just reference the file name with no absolute path. Distributing local key stores this way may require the files to be staged in HDFS (or other similar -distributed file system used by the cluster), so it's recommended that the undelying file system be +distributed file system used by the cluster), so it's recommended that the underlying file system be configured with security in mind (e.g. by enabling authentication and wire encryption). ### Standalone mode diff --git a/docs/sparkr.md b/docs/sparkr.md index 84e9b4ac6db7f..b4248e8bb21de 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -128,7 +128,7 @@ head(df) SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. This section describes the general methods for loading and saving data using Data Sources. You can check the Spark SQL programming guide for more [specific options](sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. The general method for creating SparkDataFrames from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active SparkSession will be used automatically. -SparkR supports reading JSON, CSV and Parquet files natively, and through packages available from sources like [Third Party Projects](http://spark.apache.org/third-party-projects.html), you can find data source connectors for popular file formats like Avro. These packages can either be added by +SparkR supports reading JSON, CSV and Parquet files natively, and through packages available from sources like [Third Party Projects](https://spark.apache.org/third-party-projects.html), you can find data source connectors for popular file formats like Avro. These packages can either be added by specifying `--packages` with `spark-submit` or `sparkR` commands, or if initializing SparkSession with `sparkPackages` parameter when in an interactive R shell or from RStudio.
      diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d9ebc3cfe4674..8e308d5aa05e0 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1796,7 +1796,7 @@ strings, e.g. integer indices. See [pandas.DataFrame](https://pandas.pydata.org/ on how to label columns when constructing a `pandas.DataFrame`. Note that all data for a group will be loaded into memory before the function is applied. This can -lead to out of memory exceptons, especially if the group sizes are skewed. The configuration for +lead to out of memory exceptions, especially if the group sizes are skewed. The configuration for [maxRecordsPerBatch](#setting-arrow-batch-size) is not applied on groups and it is up to the user to ensure that the grouped data will fit into the available memory. @@ -1876,7 +1876,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.3 to 2.4 - - Since Spark 2.4, Spark will evaluate the set operations referenced in a query by following a precedence rule as per the SQL standard. If the order is not specified by parentheses, set operations are performed from left to right with the exception that all INTERSECT operations are performed before any UNION, EXCEPT or MINUS operations. The old behaviour of giving equal precedence to all the set operations are preserved under a newly added configuaration `spark.sql.legacy.setopsPrecedence.enabled` with a default value of `false`. When this property is set to `true`, spark will evaluate the set operators from left to right as they appear in the query given no explicit ordering is enforced by usage of parenthesis. + - Since Spark 2.4, Spark will evaluate the set operations referenced in a query by following a precedence rule as per the SQL standard. If the order is not specified by parentheses, set operations are performed from left to right with the exception that all INTERSECT operations are performed before any UNION, EXCEPT or MINUS operations. The old behaviour of giving equal precedence to all the set operations are preserved under a newly added configuration `spark.sql.legacy.setopsPrecedence.enabled` with a default value of `false`. When this property is set to `true`, spark will evaluate the set operators from left to right as they appear in the query given no explicit ordering is enforced by usage of parenthesis. - Since Spark 2.4, Spark will display table description column Last Access value as UNKNOWN when the value was Jan 01 1970. - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unable to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. @@ -2162,7 +2162,7 @@ See the API docs for `SQLContext.read` ( Python ) and `DataFrame.write` ( Scala, - Java, + Java, Python ) more information. diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index 678b0643fd706..6a52e8a7b0ebd 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -196,7 +196,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m #### Running the Example To run the example, -- Download a Spark binary from the [download site](http://spark.apache.org/downloads.html). +- Download a Spark binary from the [download site](https://spark.apache.org/downloads.html). - Set up Kinesis stream (see earlier section) within AWS. Note the name of the Kinesis stream and the endpoint URL corresponding to the region where the stream was created. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 118b05355c74d..0ca0f2a8b54d5 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -915,8 +915,7 @@ JavaPairDStream runningCounts = pairs.updateStateByKey(updateFu The update function will be called for each word, with `newValues` having a sequence of 1's (from the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete Java code, take a look at the example -[JavaStatefulNetworkWordCount.java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming -/JavaStatefulNetworkWordCount.java). +[JavaStatefulNetworkWordCount.java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java).
      @@ -2470,7 +2469,7 @@ additional effort may be necessary to achieve exactly-once semantics. There are - [Kafka Integration Guide](streaming-kafka-integration.html) - [Kinesis Integration Guide](streaming-kinesis-integration.html) - [Custom Receiver Guide](streaming-custom-receivers.html) -* Third-party DStream data sources can be found in [Third Party Projects](http://spark.apache.org/third-party-projects.html) +* Third-party DStream data sources can be found in [Third Party Projects](https://spark.apache.org/third-party-projects.html) * API documentation - Scala docs * [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) and diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index b832f7197ace6..355a6cc26973e 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -17,7 +17,7 @@ In this guide, we are going to walk you through the programming model and the AP # Quick Example Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. You can see the full code in [Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala)/[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java)/[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount.py)/[R]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/r/streaming/structured_network_wordcount.R). -And if you [download Spark](http://spark.apache.org/downloads.html), you can directly [run the example](index.html#running-the-examples-and-shell). In any case, let’s walk through the example step-by-step and understand how it works. First, we have to import the necessary classes and create a local SparkSession, the starting point of all functionalities related to Spark. +And if you [download Spark](https://spark.apache.org/downloads.html), you can directly [run the example](index.html#running-the-examples-and-shell). In any case, let’s walk through the example step-by-step and understand how it works. First, we have to import the necessary classes and create a local SparkSession, the starting point of all functionalities related to Spark.
      From 99d2e4e00711cffbfaee8cb3da9b6b3feab8ff18 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 21 Aug 2018 11:26:41 -0700 Subject: [PATCH 1440/2461] [SPARK-24296][CORE] Replicate large blocks as a stream. When replicating large cached RDD blocks, it can be helpful to replicate them as a stream, to avoid using large amounts of memory during the transfer. This also allows blocks larger than 2GB to be replicated. Added unit tests in DistributedSuite. Also ran tests on a cluster for blocks > 2gb. Closes #21451 from squito/clean_replication. Authored-by: Imran Rashid Signed-off-by: Marcelo Vanzin --- .../server/TransportRequestHandler.java | 2 +- .../protocol/BlockTransferMessage.java | 3 +- .../shuffle/protocol/UploadBlockStream.java | 89 +++++++++++++++++++ .../org/apache/spark/executor/Executor.scala | 4 +- .../spark/internal/config/package.scala | 7 ++ .../spark/network/BlockDataManager.scala | 12 +++ .../network/netty/NettyBlockRpcServer.scala | 26 +++++- .../netty/NettyBlockTransferService.scala | 39 ++++---- .../apache/spark/storage/BlockManager.scala | 66 +++++++++++++- .../storage/BlockManagerManagedBuffer.scala | 7 +- .../org/apache/spark/storage/DiskStore.scala | 5 +- .../spark/util/io/ChunkedByteBuffer.scala | 4 + .../org/apache/spark/DistributedSuite.scala | 25 +++++- .../spark/security/EncryptionFunSuite.scala | 12 ++- .../apache/spark/storage/DiskStoreSuite.scala | 3 +- project/MimaExcludes.scala | 3 +- 16 files changed, 270 insertions(+), 37 deletions(-) create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlockStream.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index c6fd56b9291e5..9fac96dbe450d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -234,7 +234,7 @@ public void onComplete(String streamId) throws IOException { callback.onSuccess(ByteBuffer.allocate(0)); } catch (Exception ex) { IOException ioExc = new IOException("Failure post-processing complete stream;" + - " failing this rpc and leaving channel active"); + " failing this rpc and leaving channel active", ex); callback.onFailure(ioExc); streamHandler.onFailure(streamId, ioExc); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index 9af6759f5d5f3..a68a297519b66 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -42,7 +42,7 @@ public abstract class BlockTransferMessage implements Encodable { /** Preceding every serialized message is its type, which allows us to deserialize it. */ public enum Type { OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4), - HEARTBEAT(5); + HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6); private final byte id; @@ -67,6 +67,7 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { case 3: return StreamHandle.decode(buf); case 4: return RegisterDriver.decode(buf); case 5: return ShuffleServiceHeartbeat.decode(buf); + case 6: return UploadBlockStream.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlockStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlockStream.java new file mode 100644 index 0000000000000..9df30967d5bb2 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlockStream.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +/** + * A request to Upload a block, which the destination should receive as a stream. + * + * The actual block data is not contained here. It will be passed to the StreamCallbackWithID + * that is returned from RpcHandler.receiveStream() + */ +public class UploadBlockStream extends BlockTransferMessage { + public final String blockId; + public final byte[] metadata; + + public UploadBlockStream(String blockId, byte[] metadata) { + this.blockId = blockId; + this.metadata = metadata; + } + + @Override + protected Type type() { return Type.UPLOAD_BLOCK_STREAM; } + + @Override + public int hashCode() { + int objectsHashCode = Objects.hashCode(blockId); + return objectsHashCode * 41 + Arrays.hashCode(metadata); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("blockId", blockId) + .add("metadata size", metadata.length) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof UploadBlockStream) { + UploadBlockStream o = (UploadBlockStream) other; + return Objects.equal(blockId, o.blockId) + && Arrays.equals(metadata, o.metadata); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(blockId) + + Encoders.ByteArrays.encodedLength(metadata); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, blockId); + Encoders.ByteArrays.encode(buf, metadata); + } + + public static UploadBlockStream decode(ByteBuf buf) { + String blockId = Encoders.Strings.decode(buf); + byte[] metadata = Encoders.ByteArrays.decode(buf); + return new UploadBlockStream(blockId, metadata); + } +} diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index b1856ff0f3247..86b19578037df 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -363,14 +363,14 @@ private[spark] class Executor( threadMXBean.getCurrentThreadCpuTime } else 0L var threwException = true - val value = try { + val value = Utils.tryWithSafeFinally { val res = task.run( taskAttemptId = taskId, attemptNumber = taskDescription.attemptNumber, metricsSystem = env.metricsSystem) threwException = false res - } finally { + } { val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId) val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index a8aa6914ffdae..daf3f070d72e9 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -568,6 +568,13 @@ package object config { .checkValue(v => v > 0, "The value should be a positive integer.") .createWithDefault(2000) + private[spark] val MEMORY_MAP_LIMIT_FOR_TESTS = + ConfigBuilder("spark.storage.memoryMapLimitForTests") + .internal() + .doc("For testing only, controls the size of chunks when memory mapping a file") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(Int.MaxValue) + private[spark] val BARRIER_SYNC_TIMEOUT = ConfigBuilder("spark.barrier.sync.timeout") .doc("The timeout in seconds for each barrier() call from a barrier task. If the " + diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index b3f8bfe8b1d48..e94a01244474c 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -20,6 +20,7 @@ package org.apache.spark.network import scala.reflect.ClassTag import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.client.StreamCallbackWithID import org.apache.spark.storage.{BlockId, StorageLevel} private[spark] @@ -43,6 +44,17 @@ trait BlockDataManager { level: StorageLevel, classTag: ClassTag[_]): Boolean + /** + * Put the given block that will be received as a stream. + * + * When this method is called, the block data itself is not available -- it will be passed to the + * returned StreamCallbackWithID. + */ + def putBlockDataAsStream( + blockId: BlockId, + level: StorageLevel, + classTag: ClassTag[_]): StreamCallbackWithID + /** * Release locks acquired by [[putBlockData()]] and [[getBlockData()]]. */ diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index eb4cf94164fd4..7076701421e2e 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -26,9 +26,9 @@ import scala.reflect.ClassTag import org.apache.spark.internal.Logging import org.apache.spark.network.BlockDataManager import org.apache.spark.network.buffer.NioManagedBuffer -import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} +import org.apache.spark.network.client.{RpcResponseCallback, StreamCallbackWithID, TransportClient} import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager} -import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock} +import org.apache.spark.network.shuffle.protocol._ import org.apache.spark.serializer.Serializer import org.apache.spark.storage.{BlockId, StorageLevel} @@ -73,10 +73,32 @@ class NettyBlockRpcServer( } val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData)) val blockId = BlockId(uploadBlock.blockId) + logDebug(s"Receiving replicated block $blockId with level ${level} " + + s"from ${client.getSocketAddress}") blockManager.putBlockData(blockId, data, level, classTag) responseContext.onSuccess(ByteBuffer.allocate(0)) } } + override def receiveStream( + client: TransportClient, + messageHeader: ByteBuffer, + responseContext: RpcResponseCallback): StreamCallbackWithID = { + val message = + BlockTransferMessage.Decoder.fromByteBuffer(messageHeader).asInstanceOf[UploadBlockStream] + val (level: StorageLevel, classTag: ClassTag[_]) = { + serializer + .newInstance() + .deserialize(ByteBuffer.wrap(message.metadata)) + .asInstanceOf[(StorageLevel, ClassTag[_])] + } + val blockId = BlockId(message.blockId) + logDebug(s"Receiving replicated block $blockId with level ${level} as stream " + + s"from ${client.getSocketAddress}") + // This will return immediately, but will setup a callback on streamData which will still + // do all the processing in the netty thread. + blockManager.putBlockDataAsStream(blockId, level, classTag) + } + override def getStreamManager(): StreamManager = streamManager } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index b7d8c35032763..1905632a936d3 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -27,13 +27,14 @@ import scala.reflect.ClassTag import com.codahale.metrics.{Metric, MetricSet} import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.internal.config import org.apache.spark.network._ -import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory} import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap} import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempFileManager} -import org.apache.spark.network.shuffle.protocol.UploadBlock +import org.apache.spark.network.shuffle.protocol.{UploadBlock, UploadBlockStream} import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.{BlockId, StorageLevel} @@ -148,20 +149,28 @@ private[spark] class NettyBlockTransferService( // Everything else is encoded using our binary protocol. val metadata = JavaUtils.bufferToArray(serializer.newInstance().serialize((level, classTag))) - // Convert or copy nio buffer into array in order to serialize it. - val array = JavaUtils.bufferToArray(blockData.nioByteBuffer()) + val asStream = blockData.size() > conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) + val callback = new RpcResponseCallback { + override def onSuccess(response: ByteBuffer): Unit = { + logTrace(s"Successfully uploaded block $blockId${if (asStream) " as stream" else ""}") + result.success((): Unit) + } - client.sendRpc(new UploadBlock(appId, execId, blockId.name, metadata, array).toByteBuffer, - new RpcResponseCallback { - override def onSuccess(response: ByteBuffer): Unit = { - logTrace(s"Successfully uploaded block $blockId") - result.success((): Unit) - } - override def onFailure(e: Throwable): Unit = { - logError(s"Error while uploading block $blockId", e) - result.failure(e) - } - }) + override def onFailure(e: Throwable): Unit = { + logError(s"Error while uploading $blockId${if (asStream) " as stream" else ""}", e) + result.failure(e) + } + } + if (asStream) { + val streamHeader = new UploadBlockStream(blockId.name, metadata).toByteBuffer + client.uploadStream(new NioManagedBuffer(streamHeader), blockData, callback) + } else { + // Convert or copy nio buffer into array in order to serialize it. + val array = JavaUtils.bufferToArray(blockData.nioByteBuffer()) + + client.sendRpc(new UploadBlock(appId, execId, blockId.name, metadata, array).toByteBuffer, + callback) + } result.future } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 5cd21e31c9554..e7cdfab99b34d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -41,6 +41,7 @@ import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.metrics.source.Source import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.client.StreamCallbackWithID import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.{ExternalShuffleClient, TempFileManager} import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo @@ -406,6 +407,63 @@ private[spark] class BlockManager( putBytes(blockId, new ChunkedByteBuffer(data.nioByteBuffer()), level)(classTag) } + override def putBlockDataAsStream( + blockId: BlockId, + level: StorageLevel, + classTag: ClassTag[_]): StreamCallbackWithID = { + // TODO if we're going to only put the data in the disk store, we should just write it directly + // to the final location, but that would require a deeper refactor of this code. So instead + // we just write to a temp file, and call putBytes on the data in that file. + val tmpFile = diskBlockManager.createTempLocalBlock()._2 + val channel = new CountingWritableChannel( + Channels.newChannel(serializerManager.wrapForEncryption(new FileOutputStream(tmpFile)))) + logTrace(s"Streaming block $blockId to tmp file $tmpFile") + new StreamCallbackWithID { + + override def getID: String = blockId.name + + override def onData(streamId: String, buf: ByteBuffer): Unit = { + while (buf.hasRemaining) { + channel.write(buf) + } + } + + override def onComplete(streamId: String): Unit = { + logTrace(s"Done receiving block $blockId, now putting into local blockManager") + // Read the contents of the downloaded file as a buffer to put into the blockManager. + // Note this is all happening inside the netty thread as soon as it reads the end of the + // stream. + channel.close() + // TODO SPARK-25035 Even if we're only going to write the data to disk after this, we end up + // using a lot of memory here. With encryption, we'll read the whole file into a regular + // byte buffer and OOM. Without encryption, we'll memory map the file and won't get a jvm + // OOM, but might get killed by the OS / cluster manager. We could at least read the tmp + // file as a stream in both cases. + val buffer = securityManager.getIOEncryptionKey() match { + case Some(key) => + // we need to pass in the size of the unencrypted block + val blockSize = channel.getCount + val allocator = level.memoryMode match { + case MemoryMode.ON_HEAP => ByteBuffer.allocate _ + case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _ + } + new EncryptedBlockData(tmpFile, blockSize, conf, key).toChunkedByteBuffer(allocator) + + case None => + ChunkedByteBuffer.map(tmpFile, conf.get(config.MEMORY_MAP_LIMIT_FOR_TESTS).toInt) + } + putBytes(blockId, buffer, level)(classTag) + tmpFile.delete() + } + + override def onFailure(streamId: String, cause: Throwable): Unit = { + // the framework handles the connection itself, we just need to do local cleanup + channel.close() + tmpFile.delete() + } + } + } + /** * Get the BlockStatus for the block identified by the given ID, if it exists. * NOTE: This is mainly for testing. @@ -667,7 +725,7 @@ private[spark] class BlockManager( // TODO if we change this method to return the ManagedBuffer, then getRemoteValues // could just use the inputStream on the temp file, rather than memory-mapping the file. // Until then, replication can cause the process to use too much memory and get killed - // by the OS / cluster manager (not a java OOM, since its a memory-mapped file) even though + // by the OS / cluster manager (not a java OOM, since it's a memory-mapped file) even though // we've read the data to disk. logDebug(s"Getting remote block $blockId") require(blockId != null, "BlockId is null") @@ -1358,12 +1416,16 @@ private[spark] class BlockManager( try { val onePeerStartTime = System.nanoTime logTrace(s"Trying to replicate $blockId of ${data.size} bytes to $peer") + // This thread keeps a lock on the block, so we do not want the netty thread to unlock + // block when it finishes sending the message. + val buffer = new BlockManagerManagedBuffer(blockInfoManager, blockId, data, false, + unlockOnDeallocate = false) blockTransferService.uploadBlockSync( peer.host, peer.port, peer.executorId, blockId, - new BlockManagerManagedBuffer(blockInfoManager, blockId, data, false), + buffer, tLevel, classTag) logTrace(s"Replicated $blockId of ${data.size} bytes to $peer" + diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala index 3d3806126676c..5c12b5cee4d2f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala @@ -38,7 +38,8 @@ private[storage] class BlockManagerManagedBuffer( blockInfoManager: BlockInfoManager, blockId: BlockId, data: BlockData, - dispose: Boolean) extends ManagedBuffer { + dispose: Boolean, + unlockOnDeallocate: Boolean = true) extends ManagedBuffer { private val refCount = new AtomicInteger(1) @@ -58,7 +59,9 @@ private[storage] class BlockManagerManagedBuffer( } override def release(): ManagedBuffer = { - blockInfoManager.unlock(blockId) + if (unlockOnDeallocate) { + blockInfoManager.unlock(blockId) + } if (refCount.decrementAndGet() == 0 && dispose) { data.dispose() } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index ef526fd884058..a820bc70b33b2 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -29,7 +29,7 @@ import com.google.common.io.Closeables import io.netty.channel.DefaultFileRegion import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.network.util.{AbstractFileRegion, JavaUtils} import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.util.Utils @@ -44,8 +44,7 @@ private[spark] class DiskStore( securityManager: SecurityManager) extends Logging { private val minMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") - private val maxMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapLimitForTests", - Int.MaxValue.toString) + private val maxMemoryMapBytes = conf.get(config.MEMORY_MAP_LIMIT_FOR_TESTS) private val blockSizes = new ConcurrentHashMap[BlockId, Long]() def getSize(blockId: BlockId): Long = blockSizes.get(blockId) diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index efed90cb7678e..39f050f6ca5ad 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -181,6 +181,10 @@ object ChunkedByteBuffer { } } + def map(file: File, maxChunkSize: Int): ChunkedByteBuffer = { + map(file, maxChunkSize, 0, file.length()) + } + def map(file: File, maxChunkSize: Int, offset: Long, length: Long): ChunkedByteBuffer = { Utils.tryWithResource(FileChannel.open(file.toPath, StandardOpenOption.READ)) { channel => var remaining = length diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 28ea0c6f0bdba..629a323042ff2 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.Matchers import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.{Millis, Span} +import org.apache.spark.internal.config import org.apache.spark.security.EncryptionFunSuite import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.io.ChunkedByteBuffer @@ -154,6 +155,21 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex sc.parallelize(1 to 10).count() } + private def testCaching(testName: String, conf: SparkConf, storageLevel: StorageLevel): Unit = { + test(testName) { + testCaching(conf, storageLevel) + } + if (storageLevel.replication > 1) { + // also try with block replication as a stream + val uploadStreamConf = new SparkConf() + uploadStreamConf.setAll(conf.getAll) + uploadStreamConf.set(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM, 1L) + test(s"$testName (with replication as stream)") { + testCaching(uploadStreamConf, storageLevel) + } + } + } + private def testCaching(conf: SparkConf, storageLevel: StorageLevel): Unit = { sc = new SparkContext(conf.setMaster(clusterUrl).setAppName("test")) TestUtils.waitUntilExecutorsUp(sc, 2, 30000) @@ -169,7 +185,10 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val blockManager = SparkEnv.get.blockManager val blockTransfer = blockManager.blockTransferService val serializerManager = SparkEnv.get.serializerManager - blockManager.master.getLocations(blockId).foreach { cmId => + val locations = blockManager.master.getLocations(blockId) + assert(locations.size === storageLevel.replication, + s"; got ${locations.size} replicas instead of ${storageLevel.replication}") + locations.foreach { cmId => val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId, blockId.toString, null) val deserialized = serializerManager.dataDeserializeStream(blockId, @@ -189,8 +208,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex "caching in memory and disk, replicated" -> StorageLevel.MEMORY_AND_DISK_2, "caching in memory and disk, serialized, replicated" -> StorageLevel.MEMORY_AND_DISK_SER_2 ).foreach { case (testName, storageLevel) => - encryptionTest(testName) { conf => - testCaching(conf, storageLevel) + encryptionTestHelper(testName) { case (name, conf) => + testCaching(name, conf, storageLevel) } } diff --git a/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala b/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala index 3f52dc41abf6d..be6b8a6b5b108 100644 --- a/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala @@ -28,11 +28,15 @@ trait EncryptionFunSuite { * for the test to modify the provided SparkConf. */ final protected def encryptionTest(name: String)(fn: SparkConf => Unit) { + encryptionTestHelper(name) { case (name, conf) => + test(name)(fn(conf)) + } + } + + final protected def encryptionTestHelper(name: String)(fn: (String, SparkConf) => Unit): Unit = { Seq(false, true).foreach { encrypt => - test(s"$name (encryption = ${ if (encrypt) "on" else "off" })") { - val conf = new SparkConf().set(IO_ENCRYPTION_ENABLED, encrypt) - fn(conf) - } + val conf = new SparkConf().set(IO_ENCRYPTION_ENABLED, encrypt) + fn(s"$name (encryption = ${ if (encrypt) "on" else "off" })", conf) } } diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala index 2f880a3be33d3..eec961a491101 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala @@ -24,6 +24,7 @@ import com.google.common.io.{ByteStreams, Files} import io.netty.channel.FileRegion import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.internal.config import org.apache.spark.network.util.{ByteArrayWritableChannel, JavaUtils} import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.util.Utils @@ -94,7 +95,7 @@ class DiskStoreSuite extends SparkFunSuite { test("blocks larger than 2gb") { val conf = new SparkConf() - .set("spark.storage.memoryMapLimitForTests", "10k" ) + .set(config.MEMORY_MAP_LIMIT_FOR_TESTS.key, "10k") val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) val diskStore = new DiskStore(conf, diskBlockManager, new SecurityManager(conf)) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 080cdd1c3de80..cdc99a48e5b64 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,11 +36,12 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-24296][CORE] Replicate large blocks as a stream. + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.this"), // [SPARK-23528] Add numIter to ClusteringSummary ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.ClusteringSummary.this"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.KMeansSummary.this"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.BisectingKMeansSummary.this"), - // [SPARK-6237][NETWORK] Network-layer changes to allow stream upload ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.receive"), From 72ecfd095062ad61c073f9b97bf3c47644575d60 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Tue, 21 Aug 2018 15:21:55 -0700 Subject: [PATCH 1441/2461] [SPARK-25149][GRAPHX] Update Parallel Personalized Page Rank to test with large vertexIds ## What changes were proposed in this pull request? runParallelPersonalizedPageRank in graphx checks that `sources` are <= Int.MaxValue.toLong, but this is not actually required. This check seems to have been added because we use sparse vectors in the implementation and sparse vectors cannot be indexed by values > MAX_INT. However we do not ever index the sparse vector by the source vertexIds so this isn't an issue. I've added a test with large vertexIds to confirm this works as expected. ## How was this patch tested? Unit tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22139 from MrBago/remove-veretexId-check-pppr. Authored-by: Bago Amirbekian Signed-off-by: Joseph K. Bradley --- .../apache/spark/graphx/lib/PageRank.scala | 28 ++++++---------- .../spark/graphx/lib/PageRankSuite.scala | 32 +++++++++++++++---- 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index ebd65e8320e5c..96b635f9a144e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -184,9 +184,11 @@ object PageRank extends Logging { * indexed by the position of nodes in the sources list) and * edge attributes the normalized edge weight */ - def runParallelPersonalizedPageRank[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], - numIter: Int, resetProb: Double = 0.15, - sources: Array[VertexId]): Graph[Vector, Double] = { + def runParallelPersonalizedPageRank[VD: ClassTag, ED: ClassTag]( + graph: Graph[VD, ED], + numIter: Int, + resetProb: Double = 0.15, + sources: Array[VertexId]): Graph[Vector, Double] = { require(numIter > 0, s"Number of iterations must be greater than 0," + s" but got ${numIter}") require(resetProb >= 0 && resetProb <= 1, s"Random reset probability must belong" + @@ -194,15 +196,11 @@ object PageRank extends Logging { require(sources.nonEmpty, s"The list of sources must be non-empty," + s" but got ${sources.mkString("[", ",", "]")}") - // TODO if one sources vertex id is outside of the int range - // we won't be able to store its activations in a sparse vector - require(sources.max <= Int.MaxValue.toLong, - s"This implementation currently only works for source vertex ids at most ${Int.MaxValue}") val zero = Vectors.sparse(sources.size, List()).asBreeze - val sourcesInitMap = sources.zipWithIndex.map { case (vid, i) => - val v = Vectors.sparse(sources.size, Array(i), Array(1.0)).asBreeze - (vid, v) - }.toMap + // map of vid -> vector where for each vid, the _position of vid in source_ is set to 1.0 + val sourcesInitMap = sources.zipWithIndex.toMap.mapValues { i => + Vectors.sparse(sources.size, Array(i), Array(1.0)).asBreeze + } val sc = graph.vertices.sparkContext val sourcesInitMapBC = sc.broadcast(sourcesInitMap) // Initialize the PageRank graph with each edge attribute having @@ -212,13 +210,7 @@ object PageRank extends Logging { .outerJoinVertices(graph.outDegrees) { (vid, vdata, deg) => deg.getOrElse(0) } // Set the weight on the edges based on the degree .mapTriplets(e => 1.0 / e.srcAttr, TripletFields.Src) - .mapVertices { (vid, attr) => - if (sourcesInitMapBC.value contains vid) { - sourcesInitMapBC.value(vid) - } else { - zero - } - } + .mapVertices((vid, _) => sourcesInitMapBC.value.getOrElse(vid, zero)) var i = 0 while (i < numIter) { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index 9779553ce85d1..1e4c6c74bd184 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -203,24 +203,42 @@ class PageRankSuite extends SparkFunSuite with LocalSparkContext { test("Chain PersonalizedPageRank") { withSpark { sc => - val chain1 = (0 until 9).map(x => (x, x + 1) ) + // Check that implementation can handle large vertexIds, SPARK-25149 + val vertexIdOffset = Int.MaxValue.toLong + 1 + val sourceOffest = 4 + val source = vertexIdOffset + sourceOffest + val numIter = 10 + val vertices = vertexIdOffset until vertexIdOffset + numIter + val chain1 = vertices.zip(vertices.tail) val rawEdges = sc.parallelize(chain1, 1).map { case (s, d) => (s.toLong, d.toLong) } val chain = Graph.fromEdgeTuples(rawEdges, 1.0).cache() val resetProb = 0.15 val tol = 0.0001 - val numIter = 10 val errorTol = 1.0e-1 - val staticRanks = chain.staticPersonalizedPageRank(4, numIter, resetProb).vertices - val dynamicRanks = chain.personalizedPageRank(4, tol, resetProb).vertices + val a = resetProb / (1 - Math.pow(1 - resetProb, numIter - sourceOffest)) + // We expect the rank to decay as (1 - resetProb) ^ distance + val expectedRanks = sc.parallelize(vertices).map { vid => + val rank = if (vid < source) { + 0.0 + } else { + a * Math.pow(1 - resetProb, vid - source) + } + vid -> rank + } + val expected = VertexRDD(expectedRanks) + + val staticRanks = chain.staticPersonalizedPageRank(source, numIter, resetProb).vertices + assert(compareRanks(staticRanks, expected) < errorTol) - assert(compareRanks(staticRanks, dynamicRanks) < errorTol) + val dynamicRanks = chain.personalizedPageRank(source, tol, resetProb).vertices + assert(compareRanks(dynamicRanks, expected) < errorTol) val parallelStaticRanks = chain - .staticParallelPersonalizedPageRank(Array(4), numIter, resetProb).mapVertices { + .staticParallelPersonalizedPageRank(Array(source), numIter, resetProb).mapVertices { case (vertexId, vector) => vector(0) }.vertices.cache() - assert(compareRanks(staticRanks, parallelStaticRanks) < errorTol) + assert(compareRanks(parallelStaticRanks, expected) < errorTol) } } From 6c5cb85856235efd464b109558896f81ae2c4c75 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 21 Aug 2018 15:22:42 -0700 Subject: [PATCH 1442/2461] [SPARK-24763][SS] Remove redundant key data from value in streaming aggregation ## What changes were proposed in this pull request? This patch proposes a new flag option for stateful aggregation: remove redundant key data from value. Enabling new option runs similar with current, and uses less memory for state according to key/value fields of state operator. Please refer below link to see detailed perf. test result: https://issues.apache.org/jira/browse/SPARK-24763?focusedCommentId=16536539&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-16536539 Since the state between enabling the option and disabling the option is not compatible, the option is set to 'disable' by default (to ensure backward compatibility), and OffsetSeqMetadata would prevent modifying the option after executing query. ## How was this patch tested? Modify unit tests to cover both disabling option and enabling option. Also did manual tests to see whether propose patch improves state memory usage. Closes #21733 from HeartSaVioR/SPARK-24763. Authored-by: Jungtaek Lim Signed-off-by: Tathagata Das --- .../apache/spark/sql/internal/SQLConf.scala | 10 + .../spark/sql/execution/SparkStrategies.scala | 3 + .../sql/execution/aggregate/AggUtils.scala | 5 +- .../streaming/IncrementalExecution.scala | 6 +- .../sql/execution/streaming/OffsetSeq.scala | 8 +- .../StreamingAggregationStateManager.scala | 205 ++++++++++++++++++ .../streaming/statefulOperators.scala | 61 ++++-- .../commits/0 | 2 + .../commits/1 | 2 + .../metadata | 1 + .../offsets/0 | 3 + .../offsets/1 | 3 + .../state/0/0/1.delta | Bin 0 -> 46 bytes .../state/0/0/2.delta | Bin 0 -> 46 bytes .../state/0/1/1.delta | Bin 0 -> 77 bytes .../state/0/1/2.delta | Bin 0 -> 77 bytes .../state/0/2/1.delta | Bin 0 -> 46 bytes .../state/0/2/2.delta | Bin 0 -> 46 bytes .../state/0/3/1.delta | Bin 0 -> 46 bytes .../state/0/3/2.delta | Bin 0 -> 46 bytes .../state/0/4/1.delta | Bin 0 -> 46 bytes .../state/0/4/2.delta | Bin 0 -> 77 bytes .../streaming/state/MemoryStateStore.scala | 49 +++++ ...treamingAggregationStateManagerSuite.scala | 126 +++++++++++ .../FlatMapGroupsWithStateSuite.scala | 24 +- .../streaming/StreamingAggregationSuite.scala | 150 ++++++++++--- 26 files changed, 573 insertions(+), 85 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/0 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/1 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/metadata create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/0 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/1 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/2.delta create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index bffdddcf3fdb0..b44bfe7193eae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -888,6 +888,16 @@ object SQLConf { .intConf .createWithDefault(2) + val STREAMING_AGGREGATION_STATE_FORMAT_VERSION = + buildConf("spark.sql.streaming.aggregation.stateFormatVersion") + .internal() + .doc("State format version used by streaming aggregation operations in a streaming query. " + + "State between versions are tend to be incompatible, so state format version shouldn't " + + "be modified after running.") + .intConf + .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2") + .createWithDefault(2) + val UNSUPPORTED_OPERATION_CHECK_ENABLED = buildConf("spark.sql.streaming.unsupportedOperationCheck") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index b4179f4d12d35..4c39990acb627 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -328,10 +328,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { "Streaming aggregation doesn't support group aggregate pandas UDF") } + val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION) + aggregate.AggUtils.planStreamingAggregation( namedGroupingExpressions, aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), rewrittenResultExpressions, + stateVersion, planLater(child)) case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index c8ef2b3f6998d..6be88c463dbd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -260,6 +260,7 @@ object AggUtils { groupingExpressions: Seq[NamedExpression], functionsWithoutDistinct: Seq[AggregateExpression], resultExpressions: Seq[NamedExpression], + stateFormatVersion: Int, child: SparkPlan): Seq[SparkPlan] = { val groupingAttributes = groupingExpressions.map(_.toAttribute) @@ -291,7 +292,8 @@ object AggUtils { child = partialAggregate) } - val restored = StateStoreRestoreExec(groupingAttributes, None, partialMerged1) + val restored = StateStoreRestoreExec(groupingAttributes, None, stateFormatVersion, + partialMerged1) val partialMerged2: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) @@ -315,6 +317,7 @@ object AggUtils { stateInfo = None, outputMode = None, eventTimeWatermark = None, + stateFormatVersion = stateFormatVersion, partialMerged2) val finalAndCompleteAggregate: SparkPlan = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 725abb318baa0..fad287e28877d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -102,19 +102,21 @@ class IncrementalExecution( val state = new Rule[SparkPlan] { override def apply(plan: SparkPlan): SparkPlan = plan transform { - case StateStoreSaveExec(keys, None, None, None, + case StateStoreSaveExec(keys, None, None, None, stateFormatVersion, UnaryExecNode(agg, - StateStoreRestoreExec(_, None, child))) => + StateStoreRestoreExec(_, None, _, child))) => val aggStateInfo = nextStatefulOperationStateInfo StateStoreSaveExec( keys, Some(aggStateInfo), Some(outputMode), Some(offsetSeqMetadata.batchWatermarkMs), + stateFormatVersion, agg.withNewChildren( StateStoreRestoreExec( keys, Some(aggStateInfo), + stateFormatVersion, child) :: Nil)) case StreamingDeduplicateExec(keys, child, None, None) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 9847756f22d4f..73cf355dbe758 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -22,7 +22,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.internal.Logging import org.apache.spark.sql.RuntimeConfig -import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper +import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager} import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _} /** @@ -89,7 +89,7 @@ object OffsetSeqMetadata extends Logging { private implicit val format = Serialization.formats(NoTypeHints) private val relevantSQLConfs = Seq( SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY, - FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) + FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION) /** * Default values of relevant configurations that are used for backward compatibility. @@ -104,7 +104,9 @@ object OffsetSeqMetadata extends Logging { private val relevantSQLConfDefaultValues = Map[String, String]( STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME, FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> - FlatMapGroupsWithStateExecHelper.legacyVersion.toString + FlatMapGroupsWithStateExecHelper.legacyVersion.toString, + STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> + StreamingAggregationStateManager.legacyVersion.toString ) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala new file mode 100644 index 0000000000000..9bfb9561b42a1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner} +import org.apache.spark.sql.types.StructType + +/** + * Base trait for state manager purposed to be used from streaming aggregations. + */ +sealed trait StreamingAggregationStateManager extends Serializable { + + /** Extract columns consisting key from input row, and return the new row for key columns. */ + def getKey(row: UnsafeRow): UnsafeRow + + /** Calculate schema for the value of state. The schema is mainly passed to the StateStoreRDD. */ + def getStateValueSchema: StructType + + /** Get the current value of a non-null key from the target state store. */ + def get(store: StateStore, key: UnsafeRow): UnsafeRow + + /** + * Put a new value for a non-null key to the target state store. Note that key will be + * extracted from the input row, and the key would be same as the result of getKey(inputRow). + */ + def put(store: StateStore, row: UnsafeRow): Unit + + /** + * Commit all the updates that have been made to the target state store, and return the + * new version. + */ + def commit(store: StateStore): Long + + /** Remove a single non-null key from the target state store. */ + def remove(store: StateStore, key: UnsafeRow): Unit + + /** Return an iterator containing all the key-value pairs in target state store. */ + def iterator(store: StateStore): Iterator[UnsafeRowPair] + + /** Return an iterator containing all the keys in target state store. */ + def keys(store: StateStore): Iterator[UnsafeRow] + + /** Return an iterator containing all the values in target state store. */ + def values(store: StateStore): Iterator[UnsafeRow] +} + +object StreamingAggregationStateManager extends Logging { + val supportedVersions = Seq(1, 2) + val legacyVersion = 1 + + def createStateManager( + keyExpressions: Seq[Attribute], + inputRowAttributes: Seq[Attribute], + stateFormatVersion: Int): StreamingAggregationStateManager = { + stateFormatVersion match { + case 1 => new StreamingAggregationStateManagerImplV1(keyExpressions, inputRowAttributes) + case 2 => new StreamingAggregationStateManagerImplV2(keyExpressions, inputRowAttributes) + case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid") + } + } +} + +abstract class StreamingAggregationStateManagerBaseImpl( + protected val keyExpressions: Seq[Attribute], + protected val inputRowAttributes: Seq[Attribute]) extends StreamingAggregationStateManager { + + @transient protected lazy val keyProjector = + GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes) + + override def getKey(row: UnsafeRow): UnsafeRow = keyProjector(row) + + override def commit(store: StateStore): Long = store.commit() + + override def remove(store: StateStore, key: UnsafeRow): Unit = store.remove(key) + + override def keys(store: StateStore): Iterator[UnsafeRow] = { + // discard and don't convert values to avoid computation + store.getRange(None, None).map(_.key) + } +} + +/** + * The implementation of StreamingAggregationStateManager for state version 1. + * In state version 1, the schema of key and value in state are follow: + * + * - key: Same as key expressions. + * - value: Same as input row attributes. The schema of value contains key expressions as well. + * + * @param keyExpressions The attributes of keys. + * @param inputRowAttributes The attributes of input row. + */ +class StreamingAggregationStateManagerImplV1( + keyExpressions: Seq[Attribute], + inputRowAttributes: Seq[Attribute]) + extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) { + + override def getStateValueSchema: StructType = inputRowAttributes.toStructType + + override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { + store.get(key) + } + + override def put(store: StateStore, row: UnsafeRow): Unit = { + store.put(getKey(row), row) + } + + override def iterator(store: StateStore): Iterator[UnsafeRowPair] = { + store.iterator() + } + + override def values(store: StateStore): Iterator[UnsafeRow] = { + store.iterator().map(_.value) + } +} + +/** + * The implementation of StreamingAggregationStateManager for state version 2. + * In state version 2, the schema of key and value in state are follow: + * + * - key: Same as key expressions. + * - value: The diff between input row attributes and key expressions. + * + * The schema of value is changed to optimize the memory/space usage in state, via removing + * duplicated columns in key-value pair. Hence key columns are excluded from the schema of value. + * + * @param keyExpressions The attributes of keys. + * @param inputRowAttributes The attributes of input row. + */ +class StreamingAggregationStateManagerImplV2( + keyExpressions: Seq[Attribute], + inputRowAttributes: Seq[Attribute]) + extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) { + + private val valueExpressions: Seq[Attribute] = inputRowAttributes.diff(keyExpressions) + private val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions + + // flag to check whether the row needs to be project into input row attributes after join + // e.g. if the fields in the joined row are not in the expected order + private val needToProjectToRestoreValue: Boolean = + keyValueJoinedExpressions != inputRowAttributes + + @transient private lazy val valueProjector = + GenerateUnsafeProjection.generate(valueExpressions, inputRowAttributes) + + @transient private lazy val joiner = + GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions), + StructType.fromAttributes(valueExpressions)) + @transient private lazy val restoreValueProjector = GenerateUnsafeProjection.generate( + inputRowAttributes, keyValueJoinedExpressions) + + override def getStateValueSchema: StructType = valueExpressions.toStructType + + override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { + val savedState = store.get(key) + if (savedState == null) { + return savedState + } + + restoreOriginalRow(key, savedState) + } + + override def put(store: StateStore, row: UnsafeRow): Unit = { + val key = keyProjector(row) + val value = valueProjector(row) + store.put(key, value) + } + + override def iterator(store: StateStore): Iterator[UnsafeRowPair] = { + store.iterator().map(rowPair => new UnsafeRowPair(rowPair.key, restoreOriginalRow(rowPair))) + } + + override def values(store: StateStore): Iterator[UnsafeRow] = { + store.iterator().map(rowPair => restoreOriginalRow(rowPair)) + } + + private def restoreOriginalRow(rowPair: UnsafeRowPair): UnsafeRow = { + restoreOriginalRow(rowPair.key, rowPair.value) + } + + private def restoreOriginalRow(key: UnsafeRow, value: UnsafeRow): UnsafeRow = { + val joinedRow = joiner.join(key, value) + if (needToProjectToRestoreValue) { + restoreValueProjector(joinedRow) + } else { + joinedRow + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 6759fb42b4052..34e26d85ae2ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.execution.streaming import java.util.UUID import java.util.concurrent.TimeUnit._ -import scala.collection.JavaConverters._ - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ @@ -167,6 +165,18 @@ trait WatermarkSupport extends UnaryExecNode { } } } + + protected def removeKeysOlderThanWatermark( + storeManager: StreamingAggregationStateManager, + store: StateStore): Unit = { + if (watermarkPredicateForKeys.nonEmpty) { + storeManager.keys(store).foreach { keyRow => + if (watermarkPredicateForKeys.get.eval(keyRow)) { + storeManager.remove(store, keyRow) + } + } + } + } } object WatermarkSupport { @@ -201,20 +211,23 @@ object WatermarkSupport { case class StateStoreRestoreExec( keyExpressions: Seq[Attribute], stateInfo: Option[StatefulOperatorStateInfo], + stateFormatVersion: Int, child: SparkPlan) extends UnaryExecNode with StateStoreReader { + private[sql] val stateManager = StreamingAggregationStateManager.createStateManager( + keyExpressions, child.output, stateFormatVersion) + override protected def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitionsWithStateStore( getStateInfo, keyExpressions.toStructType, - child.output.toStructType, + stateManager.getStateValueSchema, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => - val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) val hasInput = iter.hasNext if (!hasInput && keyExpressions.isEmpty) { // If our `keyExpressions` are empty, we're getting a global aggregation. In that case @@ -224,10 +237,10 @@ case class StateStoreRestoreExec( store.iterator().map(_.value) } else { iter.flatMap { row => - val key = getKey(row) - val savedState = store.get(key) + val key = stateManager.getKey(row.asInstanceOf[UnsafeRow]) + val restoredRow = stateManager.get(store, key) numOutputRows += 1 - Option(savedState).toSeq :+ row + Option(restoredRow).toSeq :+ row } } } @@ -254,9 +267,13 @@ case class StateStoreSaveExec( stateInfo: Option[StatefulOperatorStateInfo] = None, outputMode: Option[OutputMode] = None, eventTimeWatermark: Option[Long] = None, + stateFormatVersion: Int, child: SparkPlan) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { + private[sql] val stateManager = StreamingAggregationStateManager.createStateManager( + keyExpressions, child.output, stateFormatVersion) + override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver assert(outputMode.nonEmpty, @@ -265,11 +282,10 @@ case class StateStoreSaveExec( child.execute().mapPartitionsWithStateStore( getStateInfo, keyExpressions.toStructType, - child.output.toStructType, + stateManager.getStateValueSchema, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => - val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) val numOutputRows = longMetric("numOutputRows") val numUpdatedStateRows = longMetric("numUpdatedStateRows") val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") @@ -282,19 +298,18 @@ case class StateStoreSaveExec( allUpdatesTimeMs += timeTakenMs { while (iter.hasNext) { val row = iter.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key, row) + stateManager.put(store, row) numUpdatedStateRows += 1 } } allRemovalsTimeMs += 0 commitTimeMs += timeTakenMs { - store.commit() + stateManager.commit(store) } setStoreMetrics(store) - store.iterator().map { rowPair => + stateManager.values(store).map { valueRow => numOutputRows += 1 - rowPair.value + valueRow } // Update and output only rows being evicted from the StateStore @@ -304,14 +319,13 @@ case class StateStoreSaveExec( val filteredIter = iter.filter(row => !watermarkPredicateForData.get.eval(row)) while (filteredIter.hasNext) { val row = filteredIter.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key, row) + stateManager.put(store, row) numUpdatedStateRows += 1 } } val removalStartTimeNs = System.nanoTime - val rangeIter = store.getRange(None, None) + val rangeIter = stateManager.iterator(store) new NextIterator[InternalRow] { override protected def getNext(): InternalRow = { @@ -319,7 +333,7 @@ case class StateStoreSaveExec( while(rangeIter.hasNext && removedValueRow == null) { val rowPair = rangeIter.next() if (watermarkPredicateForKeys.get.eval(rowPair.key)) { - store.remove(rowPair.key) + stateManager.remove(store, rowPair.key) removedValueRow = rowPair.value } } @@ -333,7 +347,7 @@ case class StateStoreSaveExec( override protected def close(): Unit = { allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs) - commitTimeMs += timeTakenMs { store.commit() } + commitTimeMs += timeTakenMs { stateManager.commit(store) } setStoreMetrics(store) } } @@ -352,8 +366,7 @@ case class StateStoreSaveExec( override protected def getNext(): InternalRow = { if (baseIterator.hasNext) { val row = baseIterator.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key, row) + stateManager.put(store, row) numOutputRows += 1 numUpdatedStateRows += 1 row @@ -367,8 +380,10 @@ case class StateStoreSaveExec( allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) // Remove old aggregates if watermark specified - allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } - commitTimeMs += timeTakenMs { store.commit() } + allRemovalsTimeMs += timeTakenMs { + removeKeysOlderThanWatermark(stateManager, store) + } + commitTimeMs += timeTakenMs { stateManager.commit(store) } setStoreMetrics(store) } } diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/0 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/0 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/1 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/1 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/metadata new file mode 100644 index 0000000000000..c160d737278e1 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/metadata @@ -0,0 +1 @@ +{"id":"2f32aca2-1b97-458f-a48f-109328724f09"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/0 new file mode 100644 index 0000000000000..acdc6e69e975a --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1533784347136,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/1 new file mode 100644 index 0000000000000..27353e8724507 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/1 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1533784349160,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +1 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..281b21e96090981faa965b468a31a06f73dc293a GIT binary patch literal 77 zcmeZ?GI7euPtI0VW?*120bDc literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..f4fb2520a4ac43f7ac9d87544480a1e7bb5053b6 GIT binary patch literal 77 zcmeZ?GI7euPtI0VW?*120b-T3tt`PnT7ZF(L70hy!4b%oU}InxVK~4DWP-qdAn<|e J6NLytNB|3?4Hp0a literal 0 HcmV?d00001 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala new file mode 100644 index 0000000000000..98586d6492c9e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +class MemoryStateStore extends StateStore() { + import scala.collection.JavaConverters._ + private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow] + + override def iterator(): Iterator[UnsafeRowPair] = { + map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) } + } + + override def get(key: UnsafeRow): UnsafeRow = map.get(key) + + override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = map.put(key.copy(), newValue.copy()) + + override def remove(key: UnsafeRow): Unit = map.remove(key) + + override def commit(): Long = version + 1 + + override def abort(): Unit = {} + + override def id: StateStoreId = null + + override def version: Long = 0 + + override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 0, Map.empty) + + override def hasCommitted: Boolean = true +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala new file mode 100644 index 0000000000000..daacdfd58c7b9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +class StreamingAggregationStateManagerSuite extends StreamTest { + // ============================ fields and method for test data ============================ + + val testKeys: Seq[String] = Seq("key1", "key2") + val testValues: Seq[String] = Seq("sum(key1)", "sum(key2)") + + val testOutputSchema: StructType = StructType( + testKeys.map(createIntegerField) ++ testValues.map(createIntegerField)) + + val testOutputAttributes: Seq[Attribute] = testOutputSchema.toAttributes + val testKeyAttributes: Seq[Attribute] = testOutputAttributes.filter { p => + testKeys.contains(p.name) + } + val testValuesAttributes: Seq[Attribute] = testOutputAttributes.filter { p => + testValues.contains(p.name) + } + val expectedTestValuesSchema: StructType = testValuesAttributes.toStructType + + val testRow: UnsafeRow = { + val unsafeRowProjection = UnsafeProjection.create(testOutputSchema) + val row = unsafeRowProjection(new SpecificInternalRow(testOutputSchema)) + (testKeys ++ testValues).zipWithIndex.foreach { case (_, index) => row.setInt(index, index) } + row + } + + val expectedTestKeyRow: UnsafeRow = { + val keyProjector = GenerateUnsafeProjection.generate(testKeyAttributes, testOutputAttributes) + keyProjector(testRow) + } + + val expectedTestValueRowForV2: UnsafeRow = { + val valueProjector = GenerateUnsafeProjection.generate(testValuesAttributes, + testOutputAttributes) + valueProjector(testRow) + } + + private def createIntegerField(name: String): StructField = { + StructField(name, IntegerType, nullable = false) + } + + // ============================ StateManagerImplV1 ============================ + + test("StateManager v1 - get, put, iter") { + val stateManager = StreamingAggregationStateManager.createStateManager(testKeyAttributes, + testOutputAttributes, 1) + + // in V1, input row is stored as value + testGetPutIterOnStateManager(stateManager, testOutputSchema, testRow, + expectedTestKeyRow, expectedStateValue = testRow) + } + + // ============================ StateManagerImplV2 ============================ + test("StateManager v2 - get, put, iter") { + val stateManager = StreamingAggregationStateManager.createStateManager(testKeyAttributes, + testOutputAttributes, 2) + + // in V2, row for values itself (excluding keys from input row) is stored as value + // so that stored value doesn't have key part, but state manager V2 will provide same output + // as V1 when getting row for key + testGetPutIterOnStateManager(stateManager, expectedTestValuesSchema, testRow, + expectedTestKeyRow, expectedTestValueRowForV2) + } + + private def testGetPutIterOnStateManager( + stateManager: StreamingAggregationStateManager, + expectedValueSchema: StructType, + inputRow: UnsafeRow, + expectedStateKey: UnsafeRow, + expectedStateValue: UnsafeRow): Unit = { + + assert(stateManager.getStateValueSchema === expectedValueSchema) + + val memoryStateStore = new MemoryStateStore() + stateManager.put(memoryStateStore, inputRow) + + assert(memoryStateStore.iterator().size === 1) + assert(stateManager.iterator(memoryStateStore).size === memoryStateStore.iterator().size) + + val keyRow = stateManager.getKey(inputRow) + assert(keyRow === expectedStateKey) + + // iterate state store and verify whether expected format of key and value are stored + val pair = memoryStateStore.iterator().next() + assert(pair.key === keyRow) + assert(pair.value === expectedStateValue) + + // iterate with state manager and see whether original rows are returned as values + val pairFromStateManager = stateManager.iterator(memoryStateStore).next() + assert(pairFromStateManager.key === keyRow) + assert(pairFromStateManager.value === inputRow) + + // following as keys and values + assert(stateManager.keys(memoryStateStore).next() === keyRow) + assert(stateManager.values(memoryStateStore).next() === inputRow) + + // verify the stored value once again via get + assert(memoryStateStore.get(keyRow) === expectedStateValue) + + // state manager should return row which is same as input row regardless of format version + assert(inputRow === stateManager.get(memoryStateStore, keyRow)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 82d7755aef5f0..76511ae2c8362 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.streaming import java.io.File import java.sql.Date -import java.util.concurrent.ConcurrentHashMap import org.apache.commons.io.FileUtils import org.scalatest.BeforeAndAfterAll @@ -34,7 +33,7 @@ import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.RDDScanExec import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} +import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, MemoryStateStore, StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{DataType, IntegerType} @@ -1286,27 +1285,6 @@ object FlatMapGroupsWithStateSuite { var failInTask = true - class MemoryStateStore extends StateStore() { - import scala.collection.JavaConverters._ - private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow] - - override def iterator(): Iterator[UnsafeRowPair] = { - map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) } - } - - override def get(key: UnsafeRow): UnsafeRow = map.get(key) - override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = { - map.put(key.copy(), newValue.copy()) - } - override def remove(key: UnsafeRow): Unit = { map.remove(key) } - override def commit(): Long = version + 1 - override def abort(): Unit = { } - override def id: StateStoreId = null - override def version: Long = 0 - override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 0, Map.empty) - override def hasCommitted: Boolean = true - } - def assertCanGetProcessingTime(predicate: => Boolean): Unit = { if (!predicate) throw new TestFailedException("Could not get processing time", 20) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 382da13430781..1ae6ff3a90989 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.streaming +import java.io.File import java.util.{Locale, TimeZone} -import org.scalatest.Assertions -import org.scalatest.BeforeAndAfterAll +import org.apache.commons.io.FileUtils +import org.scalatest.{Assertions, BeforeAndAfterAll} import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.rdd.BlockRDD @@ -31,13 +32,15 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.exchange.Exchange import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.execution.streaming.state.{StateStore, StreamingAggregationStateManager} import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} +import org.apache.spark.util.Utils object FailureSingleton { var firstTime = true @@ -53,7 +56,35 @@ class StreamingAggregationSuite extends StateStoreMetricsTest import testImplicits._ - test("simple count, update mode") { + def executeFuncWithStateVersionSQLConf( + stateVersion: Int, + confPairs: Seq[(String, String)], + func: => Any): Unit = { + withSQLConf(confPairs ++ + Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString): _*) { + func + } + } + + def testWithAllStateVersions(name: String, confPairs: (String, String)*) + (func: => Any): Unit = { + for (version <- StreamingAggregationStateManager.supportedVersions) { + test(s"$name - state format version $version") { + executeFuncWithStateVersionSQLConf(version, confPairs, func) + } + } + } + + def testQuietlyWithAllStateVersions(name: String, confPairs: (String, String)*) + (func: => Any): Unit = { + for (version <- StreamingAggregationStateManager.supportedVersions) { + testQuietly(s"$name - state format version $version") { + executeFuncWithStateVersionSQLConf(version, confPairs, func) + } + } + } + + testWithAllStateVersions("simple count, update mode") { val inputData = MemoryStream[Int] val aggregated = @@ -77,7 +108,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("count distinct") { + testWithAllStateVersions("count distinct") { val inputData = MemoryStream[(Int, Seq[Int])] val aggregated = @@ -93,7 +124,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("simple count, complete mode") { + testWithAllStateVersions("simple count, complete mode") { val inputData = MemoryStream[Int] val aggregated = @@ -116,7 +147,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("simple count, append mode") { + testWithAllStateVersions("simple count, append mode") { val inputData = MemoryStream[Int] val aggregated = @@ -133,7 +164,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } - test("sort after aggregate in complete mode") { + testWithAllStateVersions("sort after aggregate in complete mode") { val inputData = MemoryStream[Int] val aggregated = @@ -158,7 +189,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("state metrics") { + testWithAllStateVersions("state metrics") { val inputData = MemoryStream[Int] val aggregated = @@ -211,7 +242,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("multiple keys") { + testWithAllStateVersions("multiple keys") { val inputData = MemoryStream[Int] val aggregated = @@ -228,7 +259,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - testQuietly("midbatch failure") { + testQuietlyWithAllStateVersions("midbatch failure") { val inputData = MemoryStream[Int] FailureSingleton.firstTime = true val aggregated = @@ -254,7 +285,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("typed aggregators") { + testWithAllStateVersions("typed aggregators") { val inputData = MemoryStream[(String, Int)] val aggregated = inputData.toDS().groupByKey(_._1).agg(typed.sumLong(_._2)) @@ -264,7 +295,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("prune results by current_time, complete mode") { + testWithAllStateVersions("prune results by current_time, complete mode") { import testImplicits._ val clock = new StreamManualClock val inputData = MemoryStream[Long] @@ -316,7 +347,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("prune results by current_date, complete mode") { + testWithAllStateVersions("prune results by current_date, complete mode") { import testImplicits._ val clock = new StreamManualClock val tz = TimeZone.getDefault.getID @@ -365,7 +396,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("SPARK-19690: do not convert batch aggregation in streaming query to streaming") { + testWithAllStateVersions("SPARK-19690: do not convert batch aggregation in streaming query " + + "to streaming") { val streamInput = MemoryStream[Int] val batchDF = Seq(1, 2, 3, 4, 5) .toDF("value") @@ -429,7 +461,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest true } - test("SPARK-21977: coalesce(1) with 0 partition RDD should be repartitioned to 1") { + testWithAllStateVersions("SPARK-21977: coalesce(1) with 0 partition RDD should be " + + "repartitioned to 1") { val inputSource = new BlockRDDBackedSource(spark) MockSourceProvider.withMockSources(inputSource) { // `coalesce(1)` changes the partitioning of data to `SinglePartition` which by default @@ -467,8 +500,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } - test("SPARK-21977: coalesce(1) with aggregation should still be repartitioned when it " + - "has non-empty grouping keys") { + testWithAllStateVersions("SPARK-21977: coalesce(1) with aggregation should still be " + + "repartitioned when it has non-empty grouping keys") { val inputSource = new BlockRDDBackedSource(spark) MockSourceProvider.withMockSources(inputSource) { withTempDir { tempDir => @@ -520,7 +553,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } - test("SPARK-22230: last should change with new batches") { + testWithAllStateVersions("SPARK-22230: last should change with new batches") { val input = MemoryStream[Int] val aggregated = input.toDF().agg(last('value)) @@ -536,7 +569,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("SPARK-23004: Ensure that TypedImperativeAggregate functions do not throw errors") { + testWithAllStateVersions("SPARK-23004: Ensure that TypedImperativeAggregate functions " + + "do not throw errors", SQLConf.SHUFFLE_PARTITIONS.key -> "1") { // See the JIRA SPARK-23004 for more details. In short, this test reproduces the error // by ensuring the following. // - A streaming query with a streaming aggregation. @@ -545,22 +579,72 @@ class StreamingAggregationSuite extends StateStoreMetricsTest // ObjectHashAggregateExec falls back to sort-based aggregation). This is done by having a // micro-batch with 128 records that shuffle to a single partition. // This test throws the exact error reported in SPARK-23004 without the corresponding fix. - withSQLConf("spark.sql.shuffle.partitions" -> "1") { - val input = MemoryStream[Int] - val df = input.toDF().toDF("value") - .selectExpr("value as group", "value") - .groupBy("group") - .agg(collect_list("value")) - testStream(df, outputMode = OutputMode.Update)( - AddData(input, (1 to spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*), - AssertOnQuery { q => - q.processAllAvailable() - true + val input = MemoryStream[Int] + val df = input.toDF().toDF("value") + .selectExpr("value as group", "value") + .groupBy("group") + .agg(collect_list("value")) + testStream(df, outputMode = OutputMode.Update)( + AddData(input, (1 to spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*), + AssertOnQuery { q => + q.processAllAvailable() + true + } + ) + } + + + test("simple count, update mode - recovery from checkpoint uses state format version 1") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + inputData.addData(3) + inputData.addData(3, 2) + + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath, + additionalConfs = Map(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2")), + /* + Note: The checkpoint was generated using the following input in Spark version 2.3.1 + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)) + */ + + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 3), (2, 2), (1, 1)), + + Execute { query => + // Verify state format = 1 + val stateVersions = query.lastExecution.executedPlan.collect { + case f: StateStoreSaveExec => f.stateFormatVersion + case f: StateStoreRestoreExec => f.stateFormatVersion } - ) - } + assert(stateVersions.size == 2) + assert(stateVersions.forall(_ == 1)) + }, + + // By default we run in new tuple mode. + AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4)) + ) } + /** Add blocks of data to the `BlockRDDBackedSource`. */ case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { override def addData(query: Option[StreamExecution]): (Source, Offset) = { From ac0174e55af2e935d41545721e9f430c942b3a0c Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 21 Aug 2018 15:26:24 -0700 Subject: [PATCH 1443/2461] [SPARK-25129][SQL] Make the mapping of com.databricks.spark.avro to built-in module configurable ## What changes were proposed in this pull request? In https://issues.apache.org/jira/browse/SPARK-24924, the data source provider com.databricks.spark.avro is mapped to the new package org.apache.spark.sql.avro . As per the discussion in the [Jira](https://issues.apache.org/jira/browse/SPARK-24924) and PR #22119, we should make the mapping configurable. This PR also improve the error message when data source of Avro/Kafka is not found. ## How was this patch tested? Unit test Closes #22133 from gengliangwang/configurable_avro_mapping. Authored-by: Gengliang Wang Signed-off-by: Xiao Li --- .../org/apache/spark/sql/avro/AvroSuite.scala | 11 ++++++++++- .../apache/spark/sql/internal/SQLConf.scala | 10 ++++++++++ .../sql/execution/datasources/DataSource.scala | 16 ++++++++++++++-- .../sql/sources/ResolvedDataSourceSuite.scala | 18 ++++++++++++++++++ 4 files changed, 52 insertions(+), 3 deletions(-) diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index c4f4d8efd6df4..72bef9e3aed41 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -77,10 +77,19 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("resolve avro data source") { - Seq("avro", "com.databricks.spark.avro").foreach { provider => + val databricksAvro = "com.databricks.spark.avro" + // By default the backward compatibility for com.databricks.spark.avro is enabled. + Seq("avro", "org.apache.spark.sql.avro.AvroFileFormat", databricksAvro).foreach { provider => assert(DataSource.lookupDataSource(provider, spark.sessionState.conf) === classOf[org.apache.spark.sql.avro.AvroFileFormat]) } + + withSQLConf(SQLConf.LEGACY_REPLACE_DATABRICKS_SPARK_AVRO_ENABLED.key -> "false") { + val message = intercept[AnalysisException] { + DataSource.lookupDataSource(databricksAvro, spark.sessionState.conf) + }.getMessage + assert(message.contains(s"Failed to find data source: $databricksAvro")) + } } test("reading from multiple paths") { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b44bfe7193eae..5913c947f2b61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1469,6 +1469,13 @@ object SQLConf { .checkValues((1 to 9).toSet + Deflater.DEFAULT_COMPRESSION) .createWithDefault(Deflater.DEFAULT_COMPRESSION) + val LEGACY_REPLACE_DATABRICKS_SPARK_AVRO_ENABLED = + buildConf("spark.sql.legacy.replaceDatabricksSparkAvro.enabled") + .doc("If it is set to true, the data source provider com.databricks.spark.avro is mapped " + + "to the built-in but external Avro data source module for backward compatibility.") + .booleanConf + .createWithDefault(true) + val LEGACY_SETOPS_PRECEDENCE_ENABLED = buildConf("spark.sql.legacy.setopsPrecedence.enabled") .internal() @@ -1881,6 +1888,9 @@ class SQLConf extends Serializable with Logging { def avroDeflateLevel: Int = getConf(SQLConf.AVRO_DEFLATE_LEVEL) + def replaceDatabricksSparkAvroEnabled: Boolean = + getConf(SQLConf.LEGACY_REPLACE_DATABRICKS_SPARK_AVRO_ENABLED) + def setOpsPrecedenceEnforced: Boolean = getConf(SQLConf.LEGACY_SETOPS_PRECEDENCE_ENABLED) def parallelFileListingInStatsComputation: Boolean = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b1a10fdb60207..1dcf9f3185de9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -571,7 +571,6 @@ object DataSource extends Logging { val nativeOrc = classOf[OrcFileFormat].getCanonicalName val socket = classOf[TextSocketSourceProvider].getCanonicalName val rate = classOf[RateStreamProvider].getCanonicalName - val avro = "org.apache.spark.sql.avro.AvroFileFormat" Map( "org.apache.spark.sql.jdbc" -> jdbc, @@ -593,7 +592,6 @@ object DataSource extends Logging { "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm, "org.apache.spark.ml.source.libsvm" -> libsvm, "com.databricks.spark.csv" -> csv, - "com.databricks.spark.avro" -> avro, "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket, "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate ) @@ -616,6 +614,8 @@ object DataSource extends Logging { case name if name.equalsIgnoreCase("orc") && conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "hive" => "org.apache.spark.sql.hive.orc.OrcFileFormat" + case "com.databricks.spark.avro" if conf.replaceDatabricksSparkAvroEnabled => + "org.apache.spark.sql.avro.AvroFileFormat" case name => name } val provider2 = s"$provider1.DefaultSource" @@ -637,6 +637,18 @@ object DataSource extends Logging { "Hive built-in ORC data source must be used with Hive support enabled. " + "Please use the native ORC data source by setting 'spark.sql.orc.impl' to " + "'native'") + } else if (provider1.toLowerCase(Locale.ROOT) == "avro" || + provider1 == "com.databricks.spark.avro" || + provider1 == "org.apache.spark.sql.avro") { + throw new AnalysisException( + s"Failed to find data source: $provider1. Avro is built-in but external data " + + "source module since Spark 2.4. Please deploy the application as per " + + "the deployment section of \"Apache Avro Data Source Guide\".") + } else if (provider1.toLowerCase(Locale.ROOT) == "kafka") { + throw new AnalysisException( + s"Failed to find data source: $provider1. Please deploy the application as " + + "per the deployment section of " + + "\"Structured Streaming + Kafka Integration Guide\".") } else { throw new ClassNotFoundException( s"Failed to find data source: $provider1. Please find packages at " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 95460fa70d8f6..0aa67bf1b0d48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -76,6 +76,24 @@ class ResolvedDataSourceSuite extends SparkFunSuite with SharedSQLContext { classOf[org.apache.spark.sql.execution.datasources.csv.CSVFileFormat]) } + test("avro: show deploy guide for loading the external avro module") { + Seq("avro", "org.apache.spark.sql.avro").foreach { provider => + val message = intercept[AnalysisException] { + getProvidingClass(provider) + }.getMessage + assert(message.contains(s"Failed to find data source: $provider")) + assert(message.contains("Please deploy the application as per the deployment section of")) + } + } + + test("kafka: show deploy guide for loading the external kafka module") { + val message = intercept[AnalysisException] { + getProvidingClass("kafka") + }.getMessage + assert(message.contains("Failed to find data source: kafka")) + assert(message.contains("Please deploy the application as per the deployment section of")) + } + test("error message for unknown data sources") { val error = intercept[ClassNotFoundException] { getProvidingClass("asfdwefasdfasdf") From 42035a4fec6eb216427486b5067a45fceb65cc2d Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 21 Aug 2018 15:28:31 -0700 Subject: [PATCH 1444/2461] [SPARK-24441][SS] Expose total estimated size of states in HDFSBackedStateStoreProvider ## What changes were proposed in this pull request? This patch exposes the estimation of size of cache (loadedMaps) in HDFSBackedStateStoreProvider as a custom metric of StateStore. The rationalize of the patch is that state backed by HDFSBackedStateStoreProvider will consume more memory than the number what we can get from query status due to caching multiple versions of states. The memory footprint to be much larger than query status reports in situations where the state store is getting a lot of updates: while shallow-copying map incurs additional small memory usages due to the size of map entities and references, but row objects will still be shared across the versions. If there're lots of updates between batches, less row objects will be shared and more row objects will exist in memory consuming much memory then what we expect. While HDFSBackedStateStore refers loadedMaps in HDFSBackedStateStoreProvider directly, there would be only one `StateStoreWriter` which refers a StateStoreProvider, so the value is not exposed as well as being aggregated multiple times. Current state metrics are safe to aggregate for the same reason. ## How was this patch tested? Tested manually. Below is the snapshot of UI page which is reflected by the patch: screen shot 2018-06-05 at 10 16 16 pm Please refer "estimated size of states cache in provider total" as well as "count of versions in state cache in provider". Closes #21469 from HeartSaVioR/SPARK-24441. Authored-by: Jungtaek Lim Signed-off-by: Tathagata Das --- .../state/HDFSBackedStateStoreProvider.scala | 39 ++++++- .../streaming/state/StateStore.scala | 2 + .../state/SymmetricHashJoinStateManager.scala | 2 + .../streaming/statefulOperators.scala | 12 ++- .../apache/spark/sql/streaming/progress.scala | 15 ++- .../streaming/state/StateStoreSuite.scala | 100 ++++++++++++++++++ .../StreamingQueryListenerSuite.scala | 2 +- ...StreamingQueryStatusAndProgressSuite.scala | 13 ++- 8 files changed, 176 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 523acef34ca61..92a2480e8b017 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.streaming.state import java.io._ import java.util import java.util.Locale +import java.util.concurrent.atomic.LongAdder import scala.collection.JavaConverters._ import scala.collection.mutable @@ -165,7 +166,16 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } override def metrics: StateStoreMetrics = { - StateStoreMetrics(mapToUpdate.size(), SizeEstimator.estimate(mapToUpdate), Map.empty) + // NOTE: we provide estimation of cache size as "memoryUsedBytes", and size of state for + // current version as "stateOnCurrentVersionSizeBytes" + val metricsFromProvider: Map[String, Long] = getMetricsForProvider() + + val customMetrics = metricsFromProvider.flatMap { case (name, value) => + // just allow searching from list cause the list is small enough + supportedCustomMetrics.find(_.name == name).map(_ -> value) + } + (metricStateOnCurrentVersionSizeBytes -> SizeEstimator.estimate(mapToUpdate)) + + StateStoreMetrics(mapToUpdate.size(), metricsFromProvider("memoryUsedBytes"), customMetrics) } /** @@ -180,6 +190,12 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } } + def getMetricsForProvider(): Map[String, Long] = synchronized { + Map("memoryUsedBytes" -> SizeEstimator.estimate(loadedMaps), + metricLoadedMapCacheHit.name -> loadedMapCacheHitCount.sum(), + metricLoadedMapCacheMiss.name -> loadedMapCacheMissCount.sum()) + } + /** Get the state store for making updates to create a new `version` of the store. */ override def getStore(version: Long): StateStore = synchronized { require(version >= 0, "Version cannot be less than 0") @@ -226,7 +242,8 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = { - Nil + metricStateOnCurrentVersionSizeBytes :: metricLoadedMapCacheHit :: metricLoadedMapCacheMiss :: + Nil } override def toString(): String = { @@ -248,6 +265,21 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit private lazy val fm = CheckpointFileManager.create(baseDir, hadoopConf) private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) + private val loadedMapCacheHitCount: LongAdder = new LongAdder + private val loadedMapCacheMissCount: LongAdder = new LongAdder + + private lazy val metricStateOnCurrentVersionSizeBytes: StateStoreCustomSizeMetric = + StateStoreCustomSizeMetric("stateOnCurrentVersionSizeBytes", + "estimated size of state only on current version") + + private lazy val metricLoadedMapCacheHit: StateStoreCustomMetric = + StateStoreCustomSumMetric("loadedMapCacheHitCount", + "count of cache hit on states cache in provider") + + private lazy val metricLoadedMapCacheMiss: StateStoreCustomMetric = + StateStoreCustomSumMetric("loadedMapCacheMissCount", + "count of cache miss on states cache in provider") + private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) private def commitUpdates(newVersion: Long, map: MapType, output: DataOutputStream): Unit = { @@ -311,6 +343,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit // Shortcut if the map for this version is already there to avoid a redundant put. val loadedCurrentVersionMap = synchronized { Option(loadedMaps.get(version)) } if (loadedCurrentVersionMap.isDefined) { + loadedMapCacheHitCount.increment() return loadedCurrentVersionMap.get } @@ -318,6 +351,8 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit "Reading snapshot file and delta files if needed..." + "Note that this is normal for the first batch of starting query.") + loadedMapCacheMissCount.increment() + val (result, elapsedMs) = Utils.timeTakenMs { val snapshotCurrentVersionMap = readSnapshotFile(version) if (snapshotCurrentVersionMap.isDefined) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 7eb68c21569ba..d3313b8a315c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -138,6 +138,8 @@ trait StateStoreCustomMetric { def name: String def desc: String } + +case class StateStoreCustomSumMetric(name: String, desc: String) extends StateStoreCustomMetric case class StateStoreCustomSizeMetric(name: String, desc: String) extends StateStoreCustomMetric case class StateStoreCustomTimingMetric(name: String, desc: String) extends StateStoreCustomMetric diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 6e7cd2db213d2..352b3d3616fba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -269,6 +269,8 @@ class SymmetricHashJoinStateManager( keyWithIndexToValueMetrics.numKeys, // represent each buffered row only once keyToNumValuesMetrics.memoryUsedBytes + keyWithIndexToValueMetrics.memoryUsedBytes, keyWithIndexToValueMetrics.customMetrics.map { + case (s @ StateStoreCustomSumMetric(_, desc), value) => + s.copy(desc = newDesc(desc)) -> value case (s @ StateStoreCustomSizeMetric(_, desc), value) => s.copy(desc = newDesc(desc)) -> value case (s @ StateStoreCustomTimingMetric(_, desc), value) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 34e26d85ae2ae..7351db8c4fbae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -88,10 +88,18 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => * the driver after this SparkPlan has been executed and metrics have been updated. */ def getProgress(): StateOperatorProgress = { + val customMetrics = stateStoreCustomMetrics + .map(entry => entry._1 -> longMetric(entry._1).value) + + val javaConvertedCustomMetrics: java.util.HashMap[String, java.lang.Long] = + new java.util.HashMap(customMetrics.mapValues(long2Long).asJava) + new StateOperatorProgress( numRowsTotal = longMetric("numTotalStateRows").value, numRowsUpdated = longMetric("numUpdatedStateRows").value, - memoryUsedBytes = longMetric("stateMemory").value) + memoryUsedBytes = longMetric("stateMemory").value, + javaConvertedCustomMetrics + ) } /** Records the duration of running `body` for the next query progress update. */ @@ -113,6 +121,8 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => private def stateStoreCustomMetrics: Map[String, SQLMetric] = { val provider = StateStoreProvider.create(sqlContext.conf.stateStoreProviderClass) provider.supportedCustomMetrics.map { + case StateStoreCustomSumMetric(name, desc) => + name -> SQLMetrics.createMetric(sparkContext, desc) case StateStoreCustomSizeMetric(name, desc) => name -> SQLMetrics.createSizeMetric(sparkContext, desc) case StateStoreCustomTimingMetric(name, desc) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala index 2fb87960ccb04..cf9375d39b39d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala @@ -38,7 +38,8 @@ import org.apache.spark.annotation.InterfaceStability class StateOperatorProgress private[sql]( val numRowsTotal: Long, val numRowsUpdated: Long, - val memoryUsedBytes: Long + val memoryUsedBytes: Long, + val customMetrics: ju.Map[String, JLong] = new ju.HashMap() ) extends Serializable { /** The compact JSON representation of this progress. */ @@ -48,12 +49,20 @@ class StateOperatorProgress private[sql]( def prettyJson: String = pretty(render(jsonValue)) private[sql] def copy(newNumRowsUpdated: Long): StateOperatorProgress = - new StateOperatorProgress(numRowsTotal, newNumRowsUpdated, memoryUsedBytes) + new StateOperatorProgress(numRowsTotal, newNumRowsUpdated, memoryUsedBytes, customMetrics) private[sql] def jsonValue: JValue = { ("numRowsTotal" -> JInt(numRowsTotal)) ~ ("numRowsUpdated" -> JInt(numRowsUpdated)) ~ - ("memoryUsedBytes" -> JInt(memoryUsedBytes)) + ("memoryUsedBytes" -> JInt(memoryUsedBytes)) ~ + ("customMetrics" -> { + if (!customMetrics.isEmpty) { + val keys = customMetrics.keySet.asScala.toSeq.sorted + keys.map { k => k -> JInt(customMetrics.get(k).toLong) : JObject }.reduce(_ ~ _) + } else { + JNothing + } + }) } override def toString: String = prettyJson diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index bfeb2b16ff7be..5e973145b0a37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -317,6 +317,22 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(store.metrics.memoryUsedBytes > noDataMemoryUsed) } + test("reports memory usage on current version") { + def getSizeOfStateForCurrentVersion(metrics: StateStoreMetrics): Long = { + val metricPair = metrics.customMetrics.find(_._1.name == "stateOnCurrentVersionSizeBytes") + assert(metricPair.isDefined) + metricPair.get._2 + } + + val provider = newStoreProvider() + val store = provider.getStore(0) + val noDataMemoryUsed = getSizeOfStateForCurrentVersion(store.metrics) + + put(store, "a", 1) + store.commit() + assert(getSizeOfStateForCurrentVersion(store.metrics) > noDataMemoryUsed) + } + test("StateStore.get") { quietly { val dir = newDir() @@ -631,6 +647,90 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(CreateAtomicTestManager.cancelCalledInCreateAtomic) } + test("expose metrics with custom metrics to StateStoreMetrics") { + def getCustomMetric(metrics: StateStoreMetrics, name: String): Long = { + val metricPair = metrics.customMetrics.find(_._1.name == name) + assert(metricPair.isDefined) + metricPair.get._2 + } + + def getLoadedMapSizeMetric(metrics: StateStoreMetrics): Long = { + metrics.memoryUsedBytes + } + + def assertCacheHitAndMiss( + metrics: StateStoreMetrics, + expectedCacheHitCount: Long, + expectedCacheMissCount: Long): Unit = { + val cacheHitCount = getCustomMetric(metrics, "loadedMapCacheHitCount") + val cacheMissCount = getCustomMetric(metrics, "loadedMapCacheMissCount") + assert(cacheHitCount === expectedCacheHitCount) + assert(cacheMissCount === expectedCacheMissCount) + } + + val provider = newStoreProvider() + + // Verify state before starting a new set of updates + assert(getLatestData(provider).isEmpty) + + val store = provider.getStore(0) + assert(!store.hasCommitted) + + assert(store.metrics.numKeys === 0) + + val initialLoadedMapSize = getLoadedMapSizeMetric(store.metrics) + assert(initialLoadedMapSize >= 0) + assertCacheHitAndMiss(store.metrics, expectedCacheHitCount = 0, expectedCacheMissCount = 0) + + put(store, "a", 1) + assert(store.metrics.numKeys === 1) + + put(store, "b", 2) + put(store, "aa", 3) + assert(store.metrics.numKeys === 3) + remove(store, _.startsWith("a")) + assert(store.metrics.numKeys === 1) + assert(store.commit() === 1) + + assert(store.hasCommitted) + + val loadedMapSizeForVersion1 = getLoadedMapSizeMetric(store.metrics) + assert(loadedMapSizeForVersion1 > initialLoadedMapSize) + assertCacheHitAndMiss(store.metrics, expectedCacheHitCount = 0, expectedCacheMissCount = 0) + + val storeV2 = provider.getStore(1) + assert(!storeV2.hasCommitted) + assert(storeV2.metrics.numKeys === 1) + + put(storeV2, "cc", 4) + assert(storeV2.metrics.numKeys === 2) + assert(storeV2.commit() === 2) + + assert(storeV2.hasCommitted) + + val loadedMapSizeForVersion1And2 = getLoadedMapSizeMetric(storeV2.metrics) + assert(loadedMapSizeForVersion1And2 > loadedMapSizeForVersion1) + assertCacheHitAndMiss(storeV2.metrics, expectedCacheHitCount = 1, expectedCacheMissCount = 0) + + val reloadedProvider = newStoreProvider(store.id) + // intended to load version 2 instead of 1 + // version 2 will not be loaded to the cache in provider + val reloadedStore = reloadedProvider.getStore(1) + assert(reloadedStore.metrics.numKeys === 1) + + assert(getLoadedMapSizeMetric(reloadedStore.metrics) === loadedMapSizeForVersion1) + assertCacheHitAndMiss(reloadedStore.metrics, expectedCacheHitCount = 0, + expectedCacheMissCount = 1) + + // now we are loading version 2 + val reloadedStoreV2 = reloadedProvider.getStore(2) + assert(reloadedStoreV2.metrics.numKeys === 2) + + assert(getLoadedMapSizeMetric(reloadedStoreV2.metrics) > loadedMapSizeForVersion1) + assertCacheHitAndMiss(reloadedStoreV2.metrics, expectedCacheHitCount = 0, + expectedCacheMissCount = 2) + } + override def newStoreProvider(): HDFSBackedStateStoreProvider = { newStoreProvider(opId = Random.nextInt(), partition = 0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index b96f2bcbdd644..0f15cd6e5a506 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -231,7 +231,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { test("event ordering") { val listener = new EventCollector withListenerAdded(listener) { - for (i <- 1 to 100) { + for (i <- 1 to 50) { listener.reset() require(listener.startEvent === null) testStream(MemoryStream[Int].toDS)( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala index 79bb827e0de93..7bef687e7e43b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala @@ -58,7 +58,12 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { | "stateOperators" : [ { | "numRowsTotal" : 0, | "numRowsUpdated" : 1, - | "memoryUsedBytes" : 2 + | "memoryUsedBytes" : 3, + | "customMetrics" : { + | "loadedMapCacheHitCount" : 1, + | "loadedMapCacheMissCount" : 0, + | "stateOnCurrentVersionSizeBytes" : 2 + | } | } ], | "sources" : [ { | "description" : "source", @@ -230,7 +235,11 @@ object StreamingQueryStatusAndProgressSuite { "avg" -> "2016-12-05T20:54:20.827Z", "watermark" -> "2016-12-05T20:54:20.827Z").asJava), stateOperators = Array(new StateOperatorProgress( - numRowsTotal = 0, numRowsUpdated = 1, memoryUsedBytes = 2)), + numRowsTotal = 0, numRowsUpdated = 1, memoryUsedBytes = 3, + customMetrics = new java.util.HashMap(Map("stateOnCurrentVersionSizeBytes" -> 2L, + "loadedMapCacheHitCount" -> 1L, "loadedMapCacheMissCount" -> 0L) + .mapValues(long2Long).asJava) + )), sources = Array( new SourceProgress( description = "source", From ad45299d047c10472fd3a86103930fe7c54a4cf1 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 21 Aug 2018 15:54:30 -0700 Subject: [PATCH 1445/2461] [SPARK-25095][PYSPARK] Python support for BarrierTaskContext ## What changes were proposed in this pull request? Add method `barrier()` and `getTaskInfos()` in python TaskContext, these two methods are only allowed for barrier tasks. ## How was this patch tested? Add new tests in `tests.py` Closes #22085 from jiangxb1987/python.barrier. Authored-by: Xingbo Jiang Signed-off-by: Xiangrui Meng --- .../spark/api/python/PythonRunner.scala | 106 +++++++++++++ python/pyspark/serializers.py | 7 + python/pyspark/taskcontext.py | 144 ++++++++++++++++++ python/pyspark/tests.py | 36 ++++- python/pyspark/worker.py | 16 +- 5 files changed, 305 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 7b31857588252..f8241915e4849 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -20,12 +20,14 @@ package org.apache.spark.api.python import java.io._ import java.net._ import java.nio.charset.StandardCharsets +import java.nio.charset.StandardCharsets.UTF_8 import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ import org.apache.spark._ import org.apache.spark.internal.Logging +import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util._ @@ -76,6 +78,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( // TODO: support accumulator in multiple UDF protected val accumulator = funcs.head.funcs.head.accumulator + // Expose a ServerSocket to support method calls via socket from Python side. + private[spark] var serverSocket: Option[ServerSocket] = None + + // Authentication helper used when serving method calls via socket from Python side. + private lazy val authHelper = new SocketAuthHelper(SparkEnv.get.conf) + def compute( inputIterator: Iterator[IN], partitionIndex: Int, @@ -180,7 +188,73 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( dataOut.writeInt(partitionIndex) // Python version of driver PythonRDD.writeUTF(pythonVer, dataOut) + // Init a ServerSocket to accept method calls from Python side. + val isBarrier = context.isInstanceOf[BarrierTaskContext] + if (isBarrier) { + serverSocket = Some(new ServerSocket(/* port */ 0, + /* backlog */ 1, + InetAddress.getByName("localhost"))) + // A call to accept() for ServerSocket shall block infinitely. + serverSocket.map(_.setSoTimeout(0)) + new Thread("accept-connections") { + setDaemon(true) + + override def run(): Unit = { + while (!serverSocket.get.isClosed()) { + var sock: Socket = null + try { + sock = serverSocket.get.accept() + // Wait for function call from python side. + sock.setSoTimeout(10000) + val input = new DataInputStream(sock.getInputStream()) + input.readInt() match { + case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => + // The barrier() function may wait infinitely, socket shall not timeout + // before the function finishes. + sock.setSoTimeout(0) + barrierAndServe(sock) + + case _ => + val out = new DataOutputStream(new BufferedOutputStream( + sock.getOutputStream)) + writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out) + } + } catch { + case e: SocketException if e.getMessage.contains("Socket closed") => + // It is possible that the ServerSocket is not closed, but the native socket + // has already been closed, we shall catch and silently ignore this case. + } finally { + if (sock != null) { + sock.close() + } + } + } + } + }.start() + } + val secret = if (isBarrier) { + authHelper.secret + } else { + "" + } + // Close ServerSocket on task completion. + serverSocket.foreach { server => + context.addTaskCompletionListener(_ => server.close()) + } + val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0) + if (boundPort == -1) { + val message = "ServerSocket failed to bind to Java side." + logError(message) + throw new SparkException(message) + } else if (isBarrier) { + logDebug(s"Started ServerSocket on port $boundPort.") + } // Write out the TaskContextInfo + dataOut.writeBoolean(isBarrier) + dataOut.writeInt(boundPort) + val secretBytes = secret.getBytes(UTF_8) + dataOut.writeInt(secretBytes.length) + dataOut.write(secretBytes, 0, secretBytes.length) dataOut.writeInt(context.stageId()) dataOut.writeInt(context.partitionId()) dataOut.writeInt(context.attemptNumber()) @@ -243,6 +317,32 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( } } } + + /** + * Gateway to call BarrierTaskContext.barrier(). + */ + def barrierAndServe(sock: Socket): Unit = { + require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.") + + authHelper.authClient(sock) + + val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + try { + context.asInstanceOf[BarrierTaskContext].barrier() + writeUTF(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS, out) + } catch { + case e: SparkException => + writeUTF(e.getMessage, out) + } finally { + out.close() + } + } + + def writeUTF(str: String, dataOut: DataOutputStream) { + val bytes = str.getBytes(UTF_8) + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + } } abstract class ReaderIterator( @@ -465,3 +565,9 @@ private[spark] object SpecialLengths { val NULL = -5 val START_ARROW_STREAM = -6 } + +private[spark] object BarrierTaskContextMessageProtocol { + val BARRIER_FUNCTION = 1 + val BARRIER_RESULT_SUCCESS = "success" + val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python side." +} diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 47c4c3e663b97..10385589c4d3b 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -715,6 +715,13 @@ def write_int(value, stream): stream.write(struct.pack("!i", value)) +def read_bool(stream): + length = stream.read(1) + if not length: + raise EOFError + return struct.unpack("!?", length)[0] + + def write_with_length(obj, stream): write_int(len(obj), stream) stream.write(obj) diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index 63ae1f30e17ca..c0312e5265c6e 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -16,6 +16,10 @@ # from __future__ import print_function +import socket + +from pyspark.java_gateway import do_server_auth +from pyspark.serializers import write_int, UTF8Deserializer class TaskContext(object): @@ -95,3 +99,143 @@ def getLocalProperty(self, key): Get a local property set upstream in the driver, or None if it is missing. """ return self._localProperties.get(key, None) + + +BARRIER_FUNCTION = 1 + + +def _load_from_socket(port, auth_secret): + """ + Load data from a given socket, this is a blocking method thus only return when the socket + connection has been closed. + + This is copied from context.py, while modified the message protocol. + """ + sock = None + # Support for both IPv4 and IPv6. + # On most of IPv6-ready systems, IPv6 will take precedence. + for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + sock = socket.socket(af, socktype, proto) + try: + # Do not allow timeout for socket reading operation. + sock.settimeout(None) + sock.connect(sa) + except socket.error: + sock.close() + sock = None + continue + break + if not sock: + raise Exception("could not open socket") + + # We don't really need a socket file here, it's just for convenience that we can reuse the + # do_server_auth() function and data serialization methods. + sockfile = sock.makefile("rwb", 65536) + + # Make a barrier() function call. + write_int(BARRIER_FUNCTION, sockfile) + sockfile.flush() + + # Do server auth. + do_server_auth(sockfile, auth_secret) + + # Collect result. + res = UTF8Deserializer().loads(sockfile) + + # Release resources. + sockfile.close() + sock.close() + + return res + + +class BarrierTaskContext(TaskContext): + + """ + .. note:: Experimental + + A TaskContext with extra info and tooling for a barrier stage. To access the BarrierTaskContext + for a running task, use: + L{BarrierTaskContext.get()}. + + .. versionadded:: 2.4.0 + """ + + _port = None + _secret = None + + def __init__(self): + """Construct a BarrierTaskContext, use get instead""" + pass + + @classmethod + def _getOrCreate(cls): + """Internal function to get or create global BarrierTaskContext.""" + if cls._taskContext is None: + cls._taskContext = BarrierTaskContext() + return cls._taskContext + + @classmethod + def get(cls): + """ + Return the currently active BarrierTaskContext. This can be called inside of user functions + to access contextual information about running tasks. + + .. note:: Must be called on the worker, not the driver. Returns None if not initialized. + """ + return cls._taskContext + + @classmethod + def _initialize(cls, port, secret): + """ + Initialize BarrierTaskContext, other methods within BarrierTaskContext can only be called + after BarrierTaskContext is initialized. + """ + cls._port = port + cls._secret = secret + + def barrier(self): + """ + .. note:: Experimental + + Sets a global barrier and waits until all tasks in this stage hit this barrier. + Note this method is only allowed for a BarrierTaskContext. + + .. versionadded:: 2.4.0 + """ + if self._port is None or self._secret is None: + raise Exception("Not supported to call barrier() before initialize " + + "BarrierTaskContext.") + else: + _load_from_socket(self._port, self._secret) + + def getTaskInfos(self): + """ + .. note:: Experimental + + Returns the all task infos in this barrier stage, the task infos are ordered by + partitionId. + Note this method is only allowed for a BarrierTaskContext. + + .. versionadded:: 2.4.0 + """ + if self._port is None or self._secret is None: + raise Exception("Not supported to call getTaskInfos() before initialize " + + "BarrierTaskContext.") + else: + addresses = self._localProperties.get("addresses", "") + return [BarrierTaskInfo(h.strip()) for h in addresses.split(",")] + + +class BarrierTaskInfo(object): + """ + .. note:: Experimental + + Carries all task infos of a barrier task. + + .. versionadded:: 2.4.0 + """ + + def __init__(self, address): + self.address = address diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index a4c5fb1db8b37..8ac1df52fc597 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -70,7 +70,7 @@ from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter from pyspark import shuffle from pyspark.profiler import BasicProfiler -from pyspark.taskcontext import TaskContext +from pyspark.taskcontext import BarrierTaskContext, TaskContext _have_scipy = False _have_numpy = False @@ -588,6 +588,40 @@ def test_get_local_property(self): finally: self.sc.setLocalProperty(key, None) + def test_barrier(self): + """ + Verify that BarrierTaskContext.barrier() performs global sync among all barrier tasks + within a stage. + """ + rdd = self.sc.parallelize(range(10), 4) + + def f(iterator): + yield sum(iterator) + + def context_barrier(x): + tc = BarrierTaskContext.get() + time.sleep(random.randint(1, 10)) + tc.barrier() + return time.time() + + times = rdd.barrier().mapPartitions(f).map(context_barrier).collect() + self.assertTrue(max(times) - min(times) < 1) + + def test_barrier_infos(self): + """ + Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the + barrier stage. + """ + rdd = self.sc.parallelize(range(10), 4) + + def f(iterator): + yield sum(iterator) + + taskInfos = rdd.barrier().mapPartitions(f).map(lambda x: BarrierTaskContext.get() + .getTaskInfos()).collect() + self.assertTrue(len(taskInfos) == 4) + self.assertTrue(len(taskInfos[0]) == 4) + class RDDTests(ReusedPySparkTestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index eaaae2b14e107..d54a5b8e396ea 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -28,10 +28,10 @@ from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.java_gateway import do_server_auth -from pyspark.taskcontext import TaskContext +from pyspark.taskcontext import BarrierTaskContext, TaskContext from pyspark.files import SparkFiles from pyspark.rdd import PythonEvalType -from pyspark.serializers import write_with_length, write_int, read_long, \ +from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type @@ -259,8 +259,18 @@ def main(infile, outfile): "PYSPARK_DRIVER_PYTHON are correctly set.") % ("%d.%d" % sys.version_info[:2], version)) + # read inputs only for a barrier task + isBarrier = read_bool(infile) + boundPort = read_int(infile) + secret = UTF8Deserializer().loads(infile) # initialize global state - taskContext = TaskContext._getOrCreate() + taskContext = None + if isBarrier: + taskContext = BarrierTaskContext._getOrCreate() + BarrierTaskContext._initialize(boundPort, secret) + else: + taskContext = TaskContext._getOrCreate() + # read inputs for TaskContext info taskContext._stageId = read_int(infile) taskContext._partitionId = read_int(infile) taskContext._attemptNumber = read_int(infile) From a998e9d829bd499dd7c65f973ea4389e0401b001 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 21 Aug 2018 17:08:15 -0700 Subject: [PATCH 1446/2461] [MINOR] Added import to fix compilation ## What changes were proposed in this pull request? Two back to PRs implicitly conflicted by one PR removing an existing import that the other PR needed. This did not cause explicit conflict as the import already existed, but not used. https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Compile/job/spark-master-compile-maven-hadoop-2.7/8226/consoleFull ``` [info] Compiling 342 Scala sources and 97 Java sources to /home/jenkins/workspace/spark-master-compile-maven-hadoop-2.7/sql/core/target/scala-2.11/classes... [warn] /home/jenkins/workspace/spark-master-compile-maven-hadoop-2.7/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala:128: value ENABLE_JOB_SUMMARY in object ParquetOutputFormat is deprecated: see corresponding Javadoc for more information. [warn] && conf.get(ParquetOutputFormat.ENABLE_JOB_SUMMARY) == null) { [warn] ^ [error] /home/jenkins/workspace/spark-master-compile-maven-hadoop-2.7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala:95: value asJava is not a member of scala.collection.immutable.Map[String,Long] [error] new java.util.HashMap(customMetrics.mapValues(long2Long).asJava) [error] ^ [warn] one warning found [error] one error found [error] Compile failed at Aug 21, 2018 4:04:35 PM [12.827s] ``` ## How was this patch tested? It compiles! Closes #22175 from tdas/fix-build. Authored-by: Tathagata Das Signed-off-by: Tathagata Das --- .../spark/sql/execution/streaming/statefulOperators.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 7351db8c4fbae..c11af345b0248 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.streaming import java.util.UUID import java.util.concurrent.TimeUnit._ +import scala.collection.JavaConverters._ + import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ From 07737c87d6086c986785ff0edc43ca94effa4fc6 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 21 Aug 2018 22:17:44 -0700 Subject: [PATCH 1447/2461] [SPARK-23711][SPARK-25140][SQL] Catch correct exceptions when expr codegen fails ## What changes were proposed in this pull request? This pr is to fix bugs when expr codegen fails; we need to catch `java.util.concurrent.ExecutionException` instead of `InternalCompilerException` and `CompileException` . This handling is the same with the `WholeStageCodegenExec ` one: https://github.com/apache/spark/blob/60af2501e1afc00192c779f2736a4e3de12428fa/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala#L585 ## How was this patch tested? Added tests in `CodeGeneratorWithInterpretedFallbackSuite` Closes #22154 from maropu/SPARK-25140. Authored-by: Takeshi Yamamuro Signed-off-by: Xiao Li --- ...CodeGeneratorWithInterpretedFallback.scala | 23 +++------- .../sql/catalyst/expressions/Projection.scala | 7 ++- ...eneratorWithInterpretedFallbackSuite.scala | 44 +++++++++++++++++-- .../sql/execution/WholeStageCodegenExec.scala | 3 +- 4 files changed, 55 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala index 0f6d86691b4d2..07fa813a98922 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala @@ -17,24 +17,12 @@ package org.apache.spark.sql.catalyst.expressions -import org.codehaus.commons.compiler.CompileException -import org.codehaus.janino.InternalCompilerException +import scala.util.control.NonFatal -import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils -/** - * Catches compile error during code generation. - */ -object CodegenError { - def unapply(throwable: Throwable): Option[Exception] = throwable match { - case e: InternalCompilerException => Some(e) - case e: CompileException => Some(e) - case _ => None - } -} - /** * Defines values for `SQLConf` config of fallback mode. Use for test only. */ @@ -47,7 +35,7 @@ object CodegenObjectFactoryMode extends Enumeration { * error happens, it can fallback to interpreted implementation. In tests, we can use a SQL config * `SQLConf.CODEGEN_FACTORY_MODE` to control fallback behavior. */ -abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] { +abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] extends Logging { def createObject(in: IN): OUT = { // We are allowed to choose codegen-only or no-codegen modes if under tests. @@ -63,7 +51,10 @@ abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] { try { createCodeGeneratedObject(in) } catch { - case CodegenError(_) => createInterpretedObject(in) + case NonFatal(_) => + // We should have already seen the error message in `CodeGenerator` + logWarning("Expr codegen error and falling back to interpreter mode") + createInterpretedObject(in) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 6493f09100577..226a4ddcffaa8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.util.control.NonFatal + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.types.{DataType, StructType} @@ -180,7 +182,10 @@ object UnsafeProjection try { GenerateUnsafeProjection.generate(unsafeExprs, subexpressionEliminationEnabled) } catch { - case CodegenError(_) => InterpretedUnsafeProjection.createProjection(unsafeExprs) + case NonFatal(_) => + // We should have already seen the error message in `CodeGenerator` + logWarning("Expr codegen error and falling back to interpreter mode") + InterpretedUnsafeProjection.createProjection(unsafeExprs) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala index 531ca9a87370a..28edd85ab6e87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala @@ -17,17 +17,33 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.concurrent.ExecutionException + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{IntegerType, LongType} +import org.apache.spark.sql.types.IntegerType class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanTestBase { - test("UnsafeProjection with codegen factory mode") { - val input = Seq(LongType, IntegerType) - .zipWithIndex.map(x => BoundReference(x._2, x._1, true)) + object FailedCodegenProjection + extends CodeGeneratorWithInterpretedFallback[Seq[Expression], UnsafeProjection] { + + override protected def createCodeGeneratedObject(in: Seq[Expression]): UnsafeProjection = { + val invalidCode = new CodeAndComment("invalid code", Map.empty) + // We assume this compilation throws an exception + CodeGenerator.compile(invalidCode) + null + } + + override protected def createInterpretedObject(in: Seq[Expression]): UnsafeProjection = { + InterpretedUnsafeProjection.createProjection(in) + } + } + test("UnsafeProjection with codegen factory mode") { + val input = Seq(BoundReference(0, IntegerType, nullable = true)) val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) { val obj = UnsafeProjection.createObject(input) @@ -40,4 +56,24 @@ class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanT assert(obj.isInstanceOf[InterpretedUnsafeProjection]) } } + + test("fallback to the interpreter mode") { + val input = Seq(BoundReference(0, IntegerType, nullable = true)) + val fallback = CodegenObjectFactoryMode.FALLBACK.toString + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallback) { + val obj = FailedCodegenProjection.createObject(input) + assert(obj.isInstanceOf[InterpretedUnsafeProjection]) + } + } + + test("codegen failures in the CODEGEN_ONLY mode") { + val errMsg = intercept[ExecutionException] { + val input = Seq(BoundReference(0, IntegerType, nullable = true)) + val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) { + FailedCodegenProjection.createObject(input) + } + }.getMessage + assert(errMsg.contains("failed to compile: org.codehaus.commons.compiler.CompileException:")) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 80f886ea1adc8..1fc4de9e56015 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -21,6 +21,7 @@ import java.util.Locale import java.util.function.Supplier import scala.collection.mutable +import scala.util.control.NonFatal import org.apache.spark.broadcast import org.apache.spark.rdd.RDD @@ -582,7 +583,7 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) val (_, maxCodeSize) = try { CodeGenerator.compile(cleanedSource) } catch { - case _: Exception if !Utils.isTesting && sqlContext.conf.codegenFallback => + case NonFatal(_) if !Utils.isTesting && sqlContext.conf.codegenFallback => // We should already saw the error message logWarning(s"Whole-stage codegen disabled for plan (id=$codegenStageId):\n $treeString") return child.execute() From 4a9c9d8f9a8f8f165369e121d3b553a3515333d4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 21 Aug 2018 22:21:08 -0700 Subject: [PATCH 1448/2461] [SPARK-25159][SQL] json schema inference should only trigger one job ## What changes were proposed in this pull request? This fixes a perf regression caused by https://github.com/apache/spark/pull/21376 . We should not use `RDD#toLocalIterator`, which triggers one Spark job per RDD partition. This is very bad for RDDs with a lot of small partitions. To fix it, this PR introduces a way to access SQLConf in the scheduler event loop thread, so that we don't need to use `RDD#toLocalIterator` anymore in `JsonInferSchema`. ## How was this patch tested? a new test Closes #22152 from cloud-fan/conf. Authored-by: Wenchen Fan Signed-off-by: Xiao Li --- .../sql/catalyst/json/JsonInferSchema.scala | 16 ++++++--- .../apache/spark/sql/internal/SQLConf.scala | 33 +++++++++++++++---- .../org/apache/spark/sql/DataFrameSuite.scala | 24 ++++++++++++++ 3 files changed, 63 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index 5f70e062d46c8..9999a005106f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -69,10 +70,17 @@ private[sql] object JsonInferSchema { }.reduceOption(typeMerger).toIterator } - // Here we get RDD local iterator then fold, instead of calling `RDD.fold` directly, because - // `RDD.fold` will run the fold function in DAGScheduler event loop thread, which may not have - // active SparkSession and `SQLConf.get` may point to the wrong configs. - val rootType = mergedTypesFromPartitions.toLocalIterator.fold(StructType(Nil))(typeMerger) + // Here we manually submit a fold-like Spark job, so that we can set the SQLConf when running + // the fold functions in the scheduler event loop thread. + val existingConf = SQLConf.get + var rootType: DataType = StructType(Nil) + val foldPartition = (iter: Iterator[DataType]) => iter.fold(StructType(Nil))(typeMerger) + val mergeResult = (index: Int, taskResult: DataType) => { + rootType = SQLConf.withExistingConf(existingConf) { + typeMerger(rootType, taskResult) + } + } + json.sparkContext.runJob(mergedTypesFromPartitions, foldPartition, mergeResult) canonicalizeType(rootType, configOptions) match { case Some(st: StructType) => st diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5913c947f2b61..df2caff902648 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -82,6 +82,19 @@ object SQLConf { /** See [[get]] for more information. */ def getFallbackConf: SQLConf = fallbackConf.get() + private lazy val existingConf = new ThreadLocal[SQLConf] { + override def initialValue: SQLConf = null + } + + def withExistingConf[T](conf: SQLConf)(f: => T): T = { + existingConf.set(conf) + try { + f + } finally { + existingConf.remove() + } + } + /** * Defines a getter that returns the SQLConf within scope. * See [[get]] for more information. @@ -116,16 +129,24 @@ object SQLConf { if (TaskContext.get != null) { new ReadOnlySQLConf(TaskContext.get()) } else { - if (Utils.isTesting && SparkContext.getActive.isDefined) { + val isSchedulerEventLoopThread = SparkContext.getActive + .map(_.dagScheduler.eventProcessLoop.eventThread) + .exists(_.getId == Thread.currentThread().getId) + if (isSchedulerEventLoopThread) { // DAGScheduler event loop thread does not have an active SparkSession, the `confGetter` - // will return `fallbackConf` which is unexpected. Here we prevent it from happening. - val schedulerEventLoopThread = - SparkContext.getActive.get.dagScheduler.eventProcessLoop.eventThread - if (schedulerEventLoopThread.getId == Thread.currentThread().getId) { + // will return `fallbackConf` which is unexpected. Here we require the caller to get the + // conf within `withExistingConf`, otherwise fail the query. + val conf = existingConf.get() + if (conf != null) { + conf + } else if (Utils.isTesting) { throw new RuntimeException("Cannot get SQLConf inside scheduler event loop thread.") + } else { + confGetter.get()() } + } else { + confGetter.get()() } - confGetter.get()() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b0e22a51e7611..7310087cc99f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -27,6 +27,7 @@ import scala.util.Random import org.scalatest.Matchers._ import org.apache.spark.SparkException +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Uuid import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} @@ -2528,4 +2529,27 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(aggPlusFilter1, aggPlusFilter2.collect()) } } + + test("SPARK-25159: json schema inference should only trigger one job") { + withTempPath { path => + // This test is to prove that the `JsonInferSchema` does not use `RDD#toLocalIterator` which + // triggers one Spark job per RDD partition. + Seq(1 -> "a", 2 -> "b").toDF("i", "p") + // The data set has 2 partitions, so Spark will write at least 2 json files. + // Use a non-splittable compression (gzip), to make sure the json scan RDD has at least 2 + // partitions. + .write.partitionBy("p").option("compression", "gzip").json(path.getCanonicalPath) + + var numJobs = 0 + sparkContext.addSparkListener(new SparkListener { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + numJobs += 1 + } + }) + + val df = spark.read.json(path.getCanonicalPath) + assert(df.columns === Array("i", "p")) + assert(numJobs == 1) + } + } } From 55f36641ff20114b892795f100da7efb79b0cc32 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 22 Aug 2018 14:31:51 +0800 Subject: [PATCH 1449/2461] [SPARK-25093][SQL] Avoid recompiling regexp for comments multiple times ## What changes were proposed in this pull request? The PR moves the compilation of the regexp for code formatting outside the method which is called for each code block when splitting expressions, in order to avoid recompiling the regexp every time. Credit should be given to Izek Greenfield. ## How was this patch tested? existing UTs Closes #22135 from mgaido91/SPARK-25093. Authored-by: Marco Gaido Signed-off-by: Wenchen Fan --- .../scala/org/apache/spark/deploy/worker/Worker.scala | 4 ++-- core/src/main/scala/org/apache/spark/util/Utils.scala | 11 ++++++----- .../spark/ml/r/AFTSurvivalRegressionWrapper.scala | 6 +++--- .../spark/sql/catalyst/catalog/SessionCatalog.scala | 3 ++- .../catalyst/expressions/codegen/CodeFormatter.scala | 10 +++++----- .../scala/org/apache/spark/sql/types/DataType.scala | 3 ++- .../org/apache/spark/streaming/dstream/DStream.scala | 10 +++++----- 7 files changed, 25 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index ee1ca0bba5749..cbd812a05a2c6 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -758,6 +758,7 @@ private[deploy] class Worker( private[deploy] object Worker extends Logging { val SYSTEM_NAME = "sparkWorker" val ENDPOINT_NAME = "Worker" + private val SSL_NODE_LOCAL_CONFIG_PATTERN = """\-Dspark\.ssl\.useNodeLocalConf\=(.+)""".r def main(argStrings: Array[String]) { Thread.setDefaultUncaughtExceptionHandler(new SparkUncaughtExceptionHandler( @@ -803,9 +804,8 @@ private[deploy] object Worker extends Logging { } def isUseLocalNodeSSLConfig(cmd: Command): Boolean = { - val pattern = """\-Dspark\.ssl\.useNodeLocalConf\=(.+)""".r val result = cmd.javaOpts.collectFirst { - case pattern(_result) => _result.toBoolean + case SSL_NODE_LOCAL_CONFIG_PATTERN(_result) => _result.toBoolean } result.getOrElse(false) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 7ec707d94ed87..e6646bd073c6b 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1409,13 +1409,14 @@ private[spark] object Utils extends Logging { } } + // A regular expression to match classes of the internal Spark API's + // that we want to skip when finding the call site of a method. + private val SPARK_CORE_CLASS_REGEX = + """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?(\.broadcast)?\.[A-Z]""".r + private val SPARK_SQL_CLASS_REGEX = """^org\.apache\.spark\.sql.*""".r + /** Default filtering function for finding call sites using `getCallSite`. */ private def sparkInternalExclusionFunction(className: String): Boolean = { - // A regular expression to match classes of the internal Spark API's - // that we want to skip when finding the call site of a method. - val SPARK_CORE_CLASS_REGEX = - """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?(\.broadcast)?\.[A-Z]""".r - val SPARK_SQL_CLASS_REGEX = """^org\.apache\.spark\.sql.*""".r val SCALA_CORE_CLASS_PREFIX = "scala" val isSparkClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined || SPARK_SQL_CLASS_REGEX.findFirstIn(className).isDefined diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala index 80d03ab03c87d..48485e02edda8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala @@ -59,13 +59,13 @@ private[r] class AFTSurvivalRegressionWrapper private ( private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalRegressionWrapper] { + private val FORMULA_REGEXP = """Surv\(([^,]+), ([^,]+)\) ~ (.+)""".r + private def formulaRewrite(formula: String): (String, String) = { var rewritedFormula: String = null var censorCol: String = null - - val regex = """Surv\(([^,]+), ([^,]+)\) ~ (.+)""".r try { - val regex(label, censor, features) = formula + val FORMULA_REGEXP(label, censor, features) = formula // TODO: Support dot operator. if (features.contains(".")) { throw new UnsupportedOperationException( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index ee3932cc56d01..afb0f009db05c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -101,6 +101,8 @@ class SessionCatalog( @GuardedBy("this") protected var currentDb: String = formatDatabaseName(DEFAULT_DATABASE) + private val validNameFormat = "([\\w_]+)".r + /** * Checks if the given name conforms the Hive standard ("[a-zA-Z_0-9]+"), * i.e. if this name only contains characters, numbers, and _. @@ -109,7 +111,6 @@ class SessionCatalog( * org.apache.hadoop.hive.metastore.MetaStoreUtils.validateName. */ private def validateName(name: String): Unit = { - val validNameFormat = "([\\w_]+)".r if (!validNameFormat.pattern.matcher(name).matches()) { throw new AnalysisException(s"`$name` is not a valid name for tables/databases. " + "Valid names only contain alphabet characters, numbers and _.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala index 7b398f424cead..ea1bb87d415c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala @@ -27,6 +27,10 @@ import java.util.regex.Matcher */ object CodeFormatter { val commentHolder = """\/\*(.+?)\*\/""".r + val commentRegexp = + ("""([ |\t]*?\/\*[\s|\S]*?\*\/[ |\t]*?)|""" + // strip /*comment*/ + """([ |\t]*?\/\/[\s\S]*?\n)""").r // strip //comment + val extraNewLinesRegexp = """\n\s*\n""".r // strip extra newlines def format(code: CodeAndComment, maxLines: Int = -1): String = { val formatter = new CodeFormatter @@ -91,11 +95,7 @@ object CodeFormatter { } def stripExtraNewLinesAndComments(input: String): String = { - val commentReg = - ("""([ |\t]*?\/\*[\s|\S]*?\*\/[ |\t]*?)|""" + // strip /*comment*/ - """([ |\t]*?\/\/[\s\S]*?\n)""").r // strip //comment - val codeWithoutComment = commentReg.replaceAllIn(input, "") - codeWithoutComment.replaceAll("""\n\s*\n""", "\n") // strip ExtraNewLines + extraNewLinesRegexp.replaceAllIn(commentRegexp.replaceAllIn(input, ""), "\n") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 50f2a9df522c2..e53628d11ccf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -114,6 +114,8 @@ abstract class DataType extends AbstractDataType { @InterfaceStability.Stable object DataType { + private val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r + def fromDDL(ddl: String): DataType = { try { CatalystSqlParser.parseDataType(ddl) @@ -132,7 +134,6 @@ object DataType { /** Given the string representation of a type, return its DataType */ private def nameToType(name: String): DataType = { - val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r name match { case "decimal" => DecimalType.USER_DEFAULT case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index e23edfa506517..4a4d2c5d9d8c8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -940,6 +940,11 @@ abstract class DStream[T: ClassTag] ( object DStream { + private val SPARK_CLASS_REGEX = """^org\.apache\.spark""".r + private val SPARK_STREAMING_TESTCLASS_REGEX = """^org\.apache\.spark\.streaming\.test""".r + private val SPARK_EXAMPLES_CLASS_REGEX = """^org\.apache\.spark\.examples""".r + private val SCALA_CLASS_REGEX = """^scala""".r + // `toPairDStreamFunctions` was in SparkContext before 1.3 and users had to // `import StreamingContext._` to enable it. Now we move it here to make the compiler find // it automatically. However, we still keep the old function in StreamingContext for backward @@ -953,11 +958,6 @@ object DStream { /** Get the creation site of a DStream from the stack trace of when the DStream is created. */ private[streaming] def getCreationSite(): CallSite = { - val SPARK_CLASS_REGEX = """^org\.apache\.spark""".r - val SPARK_STREAMING_TESTCLASS_REGEX = """^org\.apache\.spark\.streaming\.test""".r - val SPARK_EXAMPLES_CLASS_REGEX = """^org\.apache\.spark\.examples""".r - val SCALA_CLASS_REGEX = """^scala""".r - /** Filtering function that excludes non-user classes for a streaming application */ def streamingExclustionFunction(className: String): Boolean = { def doesMatch(r: Regex): Boolean = r.findFirstIn(className).isDefined From e754887182304ad0d622754e33192ebcdd515965 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 22 Aug 2018 00:10:55 -0700 Subject: [PATCH 1450/2461] [SPARK-24882][SQL] improve data source v2 API ## What changes were proposed in this pull request? Improve the data source v2 API according to the [design doc](https://docs.google.com/document/d/1DDXCTCrup4bKWByTalkXWgavcPdvur8a4eEu8x1BzPM/edit?usp=sharing) summary of the changes 1. rename `ReadSupport` -> `DataSourceReader` -> `InputPartition` -> `InputPartitionReader` to `BatchReadSupportProvider` -> `BatchReadSupport` -> `InputPartition`/`PartitionReaderFactory` -> `PartitionReader`. Similar renaming also happens at streaming and write APIs. 2. create `ScanConfig` to store query specific information like operator pushdown result, streaming offsets, etc. This makes batch and streaming `ReadSupport`(previouslly named `DataSourceReader`) immutable. All other methods take `ScanConfig` as input, which implies applying operator pushdown and getting streaming offsets happen before all other things(get input partitions, report statistics, etc.). 3. separate `InputPartition` to `InputPartition` and `PartitionReaderFactory`. This is a natural separation, data splitting and reading are orthogonal and we should not mix them in one interfaces. This also makes the naming consistent between read and write API: `PartitionReaderFactory` vs `DataWriterFactory`. 4. separate the batch and streaming interfaces. Sometimes it's painful to force the streaming interface to extend batch interface, as we may need to override some batch methods to return false, or even leak the streaming concept to batch API(e.g. `DataWriterFactory#createWriter(partitionId, taskId, epochId)`) Some follow-ups we should do after this PR (tracked by https://issues.apache.org/jira/browse/SPARK-25186 ): 1. Revisit the life cycle of `ReadSupport` instances. Currently I keep it same as the previous `DataSourceReader`, i.e. the life cycle is bound to the batch/stream query. This fits streaming very well but may not be perfect for batch source. We can also consider to let `ReadSupport.newScanConfigBuilder` take `DataSourceOptions` as parameter, if we decide to change the life cycle. 2. Add `WriteConfig`. This is similar to `ScanConfig` and makes the write API more flexible. But it's only needed when we add the `replaceWhere` support, and it needs to change the streaming execution engine for this new concept, which I think is better to be done in another PR. 3. Refine the document. This PR adds/changes a lot of document and it's very likely that some people may have better ideas. 4. Figure out the life cycle of `CustomMetrics`. It looks to me that it should be bound to a `ScanConfig`, but we need to change `ProgressReporter` to get the `ScanConfig`. Better to be done in another PR. 5. Better operator pushdown API. This PR keeps the pushdown API as it was, i.e. using the `SupportsPushdownXYZ` traits. We can design a better API using build pattern, but this is a complicated design and deserves an individual JIRA ticket and design doc. 6. Improve the continuous streaming engine to only create a new `ScanConfig` when re-configuring. 7. Remove `SupportsPushdownCatalystFilter`. This is actually not a must-have for file source, we can change the hive partition pruning to use the public `Filter`. ## How was this patch tested? existing tests. Closes #22009 from cloud-fan/redesign. Authored-by: Wenchen Fan Signed-off-by: Xiao Li --- ...scala => KafkaContinuousReadSupport.scala} | 133 +++--- ...scala => KafkaMicroBatchReadSupport.scala} | 109 +++-- .../sql/kafka010/KafkaSourceProvider.scala | 37 +- ...scala => KafkaStreamingWriteSupport.scala} | 14 +- .../kafka010/KafkaContinuousSourceSuite.scala | 8 +- .../sql/kafka010/KafkaContinuousTest.scala | 8 +- .../kafka010/KafkaMicroBatchSourceSuite.scala | 33 +- ...ort.java => BatchReadSupportProvider.java} | 42 +- ...rt.java => BatchWriteSupportProvider.java} | 32 +- .../sql/sources/v2/ContinuousReadSupport.java | 46 -- .../v2/ContinuousReadSupportProvider.java | 70 +++ .../spark/sql/sources/v2/DataSourceV2.java | 10 +- .../sql/sources/v2/MicroBatchReadSupport.java | 52 --- .../v2/MicroBatchReadSupportProvider.java | 70 +++ .../sql/sources/v2/SessionConfigSupport.java | 13 +- .../sql/sources/v2/StreamWriteSupport.java | 52 --- .../v2/StreamingWriteSupportProvider.java | 54 +++ .../sources/v2/reader/BatchReadSupport.java | 51 ++ .../sources/v2/reader/DataSourceReader.java | 75 --- .../sql/sources/v2/reader/InputPartition.java | 26 +- ...titionReader.java => PartitionReader.java} | 18 +- .../v2/reader/PartitionReaderFactory.java | 66 +++ .../sql/sources/v2/reader/ReadSupport.java | 50 ++ .../sql/sources/v2/reader/ScanConfig.java | 45 ++ ...tPartition.java => ScanConfigBuilder.java} | 15 +- .../sql/sources/v2/reader/Statistics.java | 2 +- .../SupportsPushDownCatalystFilters.java | 4 +- .../v2/reader/SupportsPushDownFilters.java | 6 +- .../SupportsPushDownRequiredColumns.java | 8 +- .../v2/reader/SupportsReportPartitioning.java | 12 +- .../v2/reader/SupportsReportStatistics.java | 16 +- .../v2/reader/SupportsScanColumnarBatch.java | 53 --- .../partitioning/ClusteredDistribution.java | 4 +- .../v2/reader/partitioning/Distribution.java | 6 +- .../v2/reader/partitioning/Partitioning.java | 5 +- ...er.java => ContinuousPartitionReader.java} | 23 +- .../ContinuousPartitionReaderFactory.java} | 29 +- .../streaming/ContinuousReadSupport.java | 77 ++++ .../v2/reader/streaming/ContinuousReader.java | 79 ---- .../streaming/MicroBatchReadSupport.java | 60 +++ .../v2/reader/streaming/MicroBatchReader.java | 75 --- .../sources/v2/reader/streaming/Offset.java | 4 +- .../streaming/StreamingReadSupport.java | 49 ++ .../SupportsCustomReaderMetrics.java | 10 +- ...urceWriter.java => BatchWriteSupport.java} | 25 +- .../sql/sources/v2/writer/DataWriter.java | 16 +- .../sources/v2/writer/DataWriterFactory.java | 23 +- .../v2/writer/WriterCommitMessage.java | 9 +- .../streaming/StreamingDataWriterFactory.java | 59 +++ ...Writer.java => StreamingWriteSupport.java} | 32 +- .../SupportsCustomWriterMetrics.java | 10 +- .../apache/spark/sql/DataFrameReader.scala | 4 +- .../apache/spark/sql/DataFrameWriter.scala | 8 +- .../datasources/v2/DataSourceRDD.scala | 44 +- .../datasources/v2/DataSourceV2Relation.scala | 72 +-- .../datasources/v2/DataSourceV2ScanExec.scala | 65 ++- .../datasources/v2/DataSourceV2Strategy.scala | 49 +- .../datasources/v2/DataSourceV2Utils.scala | 9 + .../v2/WriteToDataSourceV2Exec.scala | 40 +- .../streaming/MicroBatchExecution.scala | 91 ++-- .../streaming/ProgressReporter.scala | 22 +- .../SimpleStreamingScanConfigBuilder.scala | 40 ++ .../streaming/StreamingRelation.scala | 6 +- .../sql/execution/streaming/console.scala | 14 +- .../continuous/ContinuousDataSourceRDD.scala | 37 +- .../continuous/ContinuousExecution.scala | 51 +- .../ContinuousQueuedDataReader.scala | 29 +- .../ContinuousRateStreamSource.scala | 60 ++- .../ContinuousTextSocketSource.scala | 74 +-- .../continuous/ContinuousWriteRDD.scala | 7 +- .../continuous/EpochCoordinator.scala | 18 +- .../WriteToContinuousDataSource.scala | 4 +- .../WriteToContinuousDataSourceExec.scala | 10 +- .../sql/execution/streaming/memory.scala | 51 +- ...Writer.scala => ConsoleWriteSupport.scala} | 11 +- .../sources/ContinuousMemoryStream.scala | 76 ++- ...cala => ForeachWriteSupportProvider.scala} | 31 +- .../sources/MicroBatchWritSupport.scala | 51 ++ .../sources/PackedRowWriterFactory.scala | 9 +- ...=> RateControlMicroBatchReadSupport.scala} | 22 +- ... => RateStreamMicroBatchReadSupport.scala} | 79 ++-- .../sources/RateStreamProvider.scala | 27 +- .../streaming/sources/memoryV2.scala | 62 +-- .../execution/streaming/sources/socket.scala | 114 ++--- .../sql/streaming/DataStreamReader.scala | 52 ++- .../sql/streaming/DataStreamWriter.scala | 9 +- .../sql/streaming/StreamingQueryManager.scala | 4 +- .../sources/v2/JavaAdvancedDataSourceV2.java | 147 +++--- .../sql/sources/v2/JavaBatchDataSourceV2.java | 114 ----- .../sources/v2/JavaColumnarDataSourceV2.java | 114 +++++ .../v2/JavaPartitionAwareDataSource.java | 81 ++-- .../v2/JavaSchemaRequiredDataSource.java | 26 +- .../sources/v2/JavaSimpleDataSourceV2.java | 68 +-- .../sql/sources/v2/JavaSimpleReadSupport.java | 99 ++++ ...pache.spark.sql.sources.DataSourceRegister | 4 +- .../streaming/MemorySinkV2Suite.scala | 44 +- ...e.scala => ConsoleWriteSupportSuite.scala} | 4 +- .../sources/RateStreamProviderSuite.scala | 84 ++-- .../sources/TextSocketStreamSuite.scala | 81 ++-- .../sql/sources/v2/DataSourceV2Suite.scala | 435 ++++++++++-------- .../sources/v2/SimpleWritableDataSource.scala | 110 ++--- .../spark/sql/streaming/StreamTest.scala | 2 +- .../StreamingQueryListenerSuite.scala | 4 +- .../sql/streaming/StreamingQuerySuite.scala | 59 +-- .../ContinuousQueuedDataReaderSuite.scala | 45 +- .../continuous/ContinuousSuite.scala | 2 +- .../continuous/EpochCoordinatorSuite.scala | 18 +- .../sources/StreamingDataSourceV2Suite.scala | 95 ++-- 108 files changed, 2589 insertions(+), 2224 deletions(-) rename external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/{KafkaContinuousReader.scala => KafkaContinuousReadSupport.scala} (74%) rename external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/{KafkaMicroBatchReader.scala => KafkaMicroBatchReadSupport.scala} (83%) rename external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/{KafkaStreamWriter.scala => KafkaStreamingWriteSupport.scala} (91%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{ReadSupport.java => BatchReadSupportProvider.java} (59%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{WriteSupport.java => BatchWriteSupportProvider.java} (58%) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{InputPartitionReader.java => PartitionReader.java} (67%) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{ContinuousInputPartition.java => ScanConfigBuilder.java} (61%) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/{ContinuousInputPartitionReader.java => ContinuousPartitionReader.java} (60%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{SupportsDeprecatedScanRow.java => streaming/ContinuousPartitionReaderFactory.java} (51%) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/{DataSourceWriter.java => BatchWriteSupport.java} (79%) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/{StreamWriter.java => StreamingWriteSupport.java} (78%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/{ConsoleWriter.scala => ConsoleWriteSupport.scala} (86%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/{ForeachWriterProvider.scala => ForeachWriteSupportProvider.scala} (82%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/{MicroBatchWriter.scala => RateControlMicroBatchReadSupport.scala} (50%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/{RateStreamMicroBatchReader.scala => RateStreamMicroBatchReadSupport.scala} (78%) delete mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java rename sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/{ConsoleWriterSuite.scala => ConsoleWriteSupportSuite.scala} (98%) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala similarity index 74% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala index be7ce3b3ed757..4a18839e6a77a 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala @@ -25,16 +25,15 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.TaskContext import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming._ import org.apache.spark.sql.types.StructType /** - * A [[ContinuousReader]] for data from kafka. + * A [[ContinuousReadSupport]] for data from kafka. * * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be * read by per-task consumers generated later. @@ -47,70 +46,49 @@ import org.apache.spark.sql.types.StructType * scenarios, where some offsets after the specified initial ones can't be * properly read. */ -class KafkaContinuousReader( +class KafkaContinuousReadSupport( offsetReader: KafkaOffsetReader, kafkaParams: ju.Map[String, Object], sourceOptions: Map[String, String], metadataPath: String, initialOffsets: KafkaOffsetRangeLimit, failOnDataLoss: Boolean) - extends ContinuousReader with Logging { - - private lazy val session = SparkSession.getActiveSession.get - private lazy val sc = session.sparkContext + extends ContinuousReadSupport with Logging { private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong - // Initialized when creating reader factories. If this diverges from the partitions at the latest - // offsets, we need to reconfigure. - // Exposed outside this object only for unit tests. - @volatile private[sql] var knownPartitions: Set[TopicPartition] = _ - - override def readSchema: StructType = KafkaOffsetReader.kafkaSchema - - private var offset: Offset = _ - override def setStartOffset(start: ju.Optional[Offset]): Unit = { - offset = start.orElse { - val offsets = initialOffsets match { - case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) - case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) - case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) - } - logInfo(s"Initial offsets: $offsets") - offsets + override def initialOffset(): Offset = { + val offsets = initialOffsets match { + case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) + case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) + case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) } + logInfo(s"Initial offsets: $offsets") + offsets } - override def getStartOffset(): Offset = offset + override def fullSchema(): StructType = KafkaOffsetReader.kafkaSchema + + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { + new KafkaContinuousScanConfigBuilder(fullSchema(), start, offsetReader, reportDataLoss) + } override def deserializeOffset(json: String): Offset = { KafkaSourceOffset(JsonUtils.partitionOffsets(json)) } - override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { - import scala.collection.JavaConverters._ - - val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) - - val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet - val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) - val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) - - val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) - if (deletedPartitions.nonEmpty) { - reportDataLoss(s"Some partitions were deleted: $deletedPartitions") - } - - val startOffsets = newPartitionOffsets ++ - oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) - knownPartitions = startOffsets.keySet - + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val startOffsets = config.asInstanceOf[KafkaContinuousScanConfig].startOffsets startOffsets.toSeq.map { case (topicPartition, start) => KafkaContinuousInputPartition( - topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss - ): InputPartition[InternalRow] - }.asJava + topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) + }.toArray + } + + override def createContinuousReaderFactory( + config: ScanConfig): ContinuousPartitionReaderFactory = { + KafkaContinuousReaderFactory } /** Stop this source and free any resources it has allocated. */ @@ -127,8 +105,9 @@ class KafkaContinuousReader( KafkaSourceOffset(mergedMap) } - override def needsReconfiguration(): Boolean = { - knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions + override def needsReconfiguration(config: ScanConfig): Boolean = { + val knownPartitions = config.asInstanceOf[KafkaContinuousScanConfig].knownPartitions + offsetReader.fetchLatestOffsets().keySet != knownPartitions } override def toString(): String = s"KafkaSource[$offsetReader]" @@ -162,23 +141,51 @@ case class KafkaContinuousInputPartition( startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousInputPartition[InternalRow] { - - override def createContinuousReader( - offset: PartitionOffset): InputPartitionReader[InternalRow] = { - val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset] - require(kafkaOffset.topicPartition == topicPartition, - s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}") - new KafkaContinuousInputPartitionReader( - topicPartition, kafkaOffset.partitionOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) + failOnDataLoss: Boolean) extends InputPartition + +object KafkaContinuousReaderFactory extends ContinuousPartitionReaderFactory { + override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { + val p = partition.asInstanceOf[KafkaContinuousInputPartition] + new KafkaContinuousPartitionReader( + p.topicPartition, p.startOffset, p.kafkaParams, p.pollTimeoutMs, p.failOnDataLoss) } +} + +class KafkaContinuousScanConfigBuilder( + schema: StructType, + startOffset: Offset, + offsetReader: KafkaOffsetReader, + reportDataLoss: String => Unit) + extends ScanConfigBuilder { + + override def build(): ScanConfig = { + val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(startOffset) + + val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet + val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) + val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) - override def createPartitionReader(): KafkaContinuousInputPartitionReader = { - new KafkaContinuousInputPartitionReader( - topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) + val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"Some partitions were deleted: $deletedPartitions") + } + + val startOffsets = newPartitionOffsets ++ + oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) + KafkaContinuousScanConfig(schema, startOffsets) } } +case class KafkaContinuousScanConfig( + readSchema: StructType, + startOffsets: Map[TopicPartition, Long]) + extends ScanConfig { + + // Created when building the scan config builder. If this diverges from the partitions at the + // latest offsets, we need to reconfigure the kafka read support. + def knownPartitions: Set[TopicPartition] = startOffsets.keySet +} + /** * A per-task data reader for continuous Kafka processing. * @@ -189,12 +196,12 @@ case class KafkaContinuousInputPartition( * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets * are skipped. */ -class KafkaContinuousInputPartitionReader( +class KafkaContinuousPartitionReader( topicPartition: TopicPartition, startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[InternalRow] { + failOnDataLoss: Boolean) extends ContinuousPartitionReader[InternalRow] { private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false) private val converter = new KafkaRecordToUnsafeRowConverter diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala similarity index 83% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala index 900c9f4e7fbf3..c31af60b8a1c2 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala @@ -21,8 +21,6 @@ import java.{util => ju} import java.io._ import java.nio.charset.StandardCharsets -import scala.collection.JavaConverters._ - import org.apache.commons.io.IOUtils import org.apache.kafka.common.TopicPartition @@ -32,16 +30,17 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset} +import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder} +import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchReadSupport import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.{CustomMetrics, DataSourceOptions} -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset, SupportsCustomReaderMetrics} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset, SupportsCustomReaderMetrics} import org.apache.spark.sql.types.StructType import org.apache.spark.util.UninterruptibleThread /** - * A [[MicroBatchReader]] that reads data from Kafka. + * A [[MicroBatchReadSupport]] that reads data from Kafka. * * The [[KafkaSourceOffset]] is the custom [[Offset]] defined for this source that contains * a map of TopicPartition -> offset. Note that this offset is 1 + (available offset). For @@ -56,17 +55,14 @@ import org.apache.spark.util.UninterruptibleThread * To avoid this issue, you should make sure stopping the query before stopping the Kafka brokers * and not use wrong broker addresses. */ -private[kafka010] class KafkaMicroBatchReader( +private[kafka010] class KafkaMicroBatchReadSupport( kafkaOffsetReader: KafkaOffsetReader, executorKafkaParams: ju.Map[String, Object], options: DataSourceOptions, metadataPath: String, startingOffsets: KafkaOffsetRangeLimit, failOnDataLoss: Boolean) - extends MicroBatchReader with SupportsCustomReaderMetrics with Logging { - - private var startPartitionOffsets: PartitionOffsetMap = _ - private var endPartitionOffsets: PartitionOffsetMap = _ + extends RateControlMicroBatchReadSupport with SupportsCustomReaderMetrics with Logging { private val pollTimeoutMs = options.getLong( "kafkaConsumer.pollTimeoutMs", @@ -76,34 +72,40 @@ private[kafka010] class KafkaMicroBatchReader( Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong) private val rangeCalculator = KafkaOffsetRangeCalculator(options) + + private var endPartitionOffsets: KafkaSourceOffset = _ + /** * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only * called in StreamExecutionThread. Otherwise, interrupting a thread while running * `KafkaConsumer.poll` may hang forever (KAFKA-1894). */ - private lazy val initialPartitionOffsets = getOrCreateInitialPartitionOffsets() - - override def setOffsetRange(start: ju.Optional[Offset], end: ju.Optional[Offset]): Unit = { - // Make sure initialPartitionOffsets is initialized - initialPartitionOffsets - - startPartitionOffsets = Option(start.orElse(null)) - .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets) - .getOrElse(initialPartitionOffsets) - - endPartitionOffsets = Option(end.orElse(null)) - .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets) - .getOrElse { - val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets() - maxOffsetsPerTrigger.map { maxOffsets => - rateLimit(maxOffsets, startPartitionOffsets, latestPartitionOffsets) - }.getOrElse { - latestPartitionOffsets - } - } + override def initialOffset(): Offset = { + KafkaSourceOffset(getOrCreateInitialPartitionOffsets()) + } + + override def latestOffset(start: Offset): Offset = { + val startPartitionOffsets = start.asInstanceOf[KafkaSourceOffset].partitionToOffsets + val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets() + endPartitionOffsets = KafkaSourceOffset(maxOffsetsPerTrigger.map { maxOffsets => + rateLimit(maxOffsets, startPartitionOffsets, latestPartitionOffsets) + }.getOrElse { + latestPartitionOffsets + }) + endPartitionOffsets + } + + override def fullSchema(): StructType = KafkaOffsetReader.kafkaSchema + + override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) } - override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val sc = config.asInstanceOf[SimpleStreamingScanConfig] + val startPartitionOffsets = sc.start.asInstanceOf[KafkaSourceOffset].partitionToOffsets + val endPartitionOffsets = sc.end.get.asInstanceOf[KafkaSourceOffset].partitionToOffsets + // Find the new partitions, and get their earliest offsets val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) @@ -145,30 +147,26 @@ private[kafka010] class KafkaMicroBatchReader( // Generate factories based on the offset ranges offsetRanges.map { range => - new KafkaMicroBatchInputPartition( - range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer - ): InputPartition[InternalRow] - }.asJava - } - - override def getStartOffset: Offset = { - KafkaSourceOffset(startPartitionOffsets) + KafkaMicroBatchInputPartition( + range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) + }.toArray } - override def getEndOffset: Offset = { - KafkaSourceOffset(endPartitionOffsets) + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + KafkaMicroBatchReaderFactory } - override def getCustomMetrics: CustomMetrics = { - KafkaCustomMetrics(kafkaOffsetReader.fetchLatestOffsets(), endPartitionOffsets) + // TODO: figure out the life cycle of custom metrics, and make this method take `ScanConfig` as + // a parameter. + override def getCustomMetrics(): CustomMetrics = { + KafkaCustomMetrics( + kafkaOffsetReader.fetchLatestOffsets(), endPartitionOffsets.partitionToOffsets) } override def deserializeOffset(json: String): Offset = { KafkaSourceOffset(JsonUtils.partitionOffsets(json)) } - override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema - override def commit(end: Offset): Unit = {} override def stop(): Unit = { @@ -311,22 +309,23 @@ private[kafka010] case class KafkaMicroBatchInputPartition( executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends InputPartition[InternalRow] { + reuseKafkaConsumer: Boolean) extends InputPartition - override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray - - override def createPartitionReader(): InputPartitionReader[InternalRow] = - new KafkaMicroBatchInputPartitionReader(offsetRange, executorKafkaParams, pollTimeoutMs, - failOnDataLoss, reuseKafkaConsumer) +private[kafka010] object KafkaMicroBatchReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val p = partition.asInstanceOf[KafkaMicroBatchInputPartition] + KafkaMicroBatchPartitionReader(p.offsetRange, p.executorKafkaParams, p.pollTimeoutMs, + p.failOnDataLoss, p.reuseKafkaConsumer) + } } -/** A [[InputPartitionReader]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] case class KafkaMicroBatchInputPartitionReader( +/** A [[PartitionReader]] for reading Kafka data in a micro-batch streaming query. */ +private[kafka010] case class KafkaMicroBatchPartitionReader( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends InputPartitionReader[InternalRow] with Logging { + reuseKafkaConsumer: Boolean) extends PartitionReader[InternalRow] with Logging { private val consumer = KafkaDataConsumer.acquire( offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index d225c1ea6b7f1..28c9853bfea9c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -30,9 +30,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -46,9 +45,9 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider - with StreamWriteSupport - with ContinuousReadSupport - with MicroBatchReadSupport + with StreamingWriteSupportProvider + with ContinuousReadSupportProvider + with MicroBatchReadSupportProvider with Logging { import KafkaSourceProvider._ @@ -108,13 +107,12 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } /** - * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader]] to read batches - * of Kafka data in a micro-batch streaming query. + * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport]] to read + * batches of Kafka data in a micro-batch streaming query. */ - override def createMicroBatchReader( - schema: Optional[StructType], + override def createMicroBatchReadSupport( metadataPath: String, - options: DataSourceOptions): KafkaMicroBatchReader = { + options: DataSourceOptions): KafkaMicroBatchReadSupport = { val parameters = options.asMap().asScala.toMap validateStreamOptions(parameters) @@ -140,7 +138,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister parameters, driverGroupIdPrefix = s"$uniqueGroupId-driver") - new KafkaMicroBatchReader( + new KafkaMicroBatchReadSupport( kafkaOffsetReader, kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), options, @@ -150,13 +148,12 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } /** - * Creates a [[ContinuousInputPartitionReader]] to read + * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport]] to read * Kafka data in a continuous streaming query. */ - override def createContinuousReader( - schema: Optional[StructType], + override def createContinuousReadSupport( metadataPath: String, - options: DataSourceOptions): KafkaContinuousReader = { + options: DataSourceOptions): KafkaContinuousReadSupport = { val parameters = options.asMap().asScala.toMap validateStreamOptions(parameters) // Each running query should use its own group id. Otherwise, the query may be only assigned @@ -181,7 +178,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister parameters, driverGroupIdPrefix = s"$uniqueGroupId-driver") - new KafkaContinuousReader( + new KafkaContinuousReadSupport( kafkaOffsetReader, kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), parameters, @@ -270,11 +267,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - override def createStreamWriter( + override def createStreamingWriteSupport( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamWriter = { + options: DataSourceOptions): StreamingWriteSupport = { import scala.collection.JavaConverters._ val spark = SparkSession.getActiveSession.get @@ -285,7 +282,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister KafkaWriter.validateQuery( schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic) - new KafkaStreamWriter(topic, producerParams, schema) + new KafkaStreamingWriteSupport(topic, producerParams, schema) } private def strategy(caseInsensitiveParams: Map[String, String]) = diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala similarity index 91% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala index 5f0802b466039..dc19312f79a22 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} import org.apache.spark.sql.types.StructType /** @@ -33,20 +33,20 @@ import org.apache.spark.sql.types.StructType case object KafkaWriterCommitMessage extends WriterCommitMessage /** - * A [[StreamWriter]] for Kafka writing. Responsible for generating the writer factory. + * A [[StreamingWriteSupport]] for Kafka writing. Responsible for generating the writer factory. * * @param topic The topic this writer is responsible for. If None, topic will be inferred from * a `topic` field in the incoming data. * @param producerParams Parameters for Kafka producers in each task. * @param schema The schema of the input data. */ -class KafkaStreamWriter( +class KafkaStreamingWriteSupport( topic: Option[String], producerParams: Map[String, String], schema: StructType) - extends StreamWriter { + extends StreamingWriteSupport { validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) - override def createWriterFactory(): KafkaStreamWriterFactory = + override def createStreamingWriterFactory(): KafkaStreamWriterFactory = KafkaStreamWriterFactory(topic, producerParams, schema) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} @@ -63,9 +63,9 @@ class KafkaStreamWriter( */ case class KafkaStreamWriterFactory( topic: Option[String], producerParams: Map[String, String], schema: StructType) - extends DataWriterFactory[InternalRow] { + extends StreamingDataWriterFactory { - override def createDataWriter( + override def createWriter( partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index ea2a2a84d22c6..321665042b8eb 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -61,10 +61,12 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case StreamingDataSourceV2Relation(_, _, _, r: KafkaContinuousReader) => r - }.exists { r => + case r: StreamingDataSourceV2Relation + if r.readSupport.isInstanceOf[KafkaContinuousReadSupport] => + r.scanConfigBuilder.build().asInstanceOf[KafkaContinuousScanConfig] + }.exists { config => // Ensure the new topic is present and the old topic is gone. - r.knownPartitions.exists(_.topic == topic2) + config.knownPartitions.exists(_.topic == topic2) }, s"query never reconfigured to new topic $topic2") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index fa1468a3943c8..fa6bdc20bd4f9 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.SparkContext import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart} -import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.streaming.Trigger @@ -46,8 +46,10 @@ trait KafkaContinuousTest extends KafkaSourceTest { testUtils.addPartitions(topic, newCount) eventually(timeout(streamingTimeout)) { assert( - query.lastExecution.logical.collectFirst { - case StreamingDataSourceV2Relation(_, _, _, r: KafkaContinuousReader) => r + query.lastExecution.executedPlan.collectFirst { + case scan: DataSourceV2ScanExec + if scan.readSupport.isInstanceOf[KafkaContinuousReadSupport] => + scan.scanConfig.asInstanceOf[KafkaContinuousScanConfig] }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index c7b74f305eed2..946b636710f0d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010 import java.io._ import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.{Files, Paths} -import java.util.{Locale, Optional, Properties} +import java.util.{Locale, Properties} import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger @@ -44,11 +44,9 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} -import org.apache.spark.sql.types.StructType abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with KafkaTest { @@ -118,14 +116,16 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with Kaf query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") - val sources = { + val sources: Seq[BaseStreamingSource] = { query.get.logicalPlan.collect { case StreamingExecutionRelation(source: KafkaSource, _) => source - case StreamingExecutionRelation(source: KafkaMicroBatchReader, _) => source + case StreamingExecutionRelation(source: KafkaMicroBatchReadSupport, _) => source } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { - case StreamingDataSourceV2Relation(_, _, _, reader: KafkaContinuousReader) => reader + case r: StreamingDataSourceV2Relation + if r.readSupport.isInstanceOf[KafkaContinuousReadSupport] => + r.readSupport.asInstanceOf[KafkaContinuousReadSupport] } }) }.distinct @@ -650,7 +650,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { makeSureGetOffsetCalled, AssertOnQuery { query => query.logicalPlan.collect { - case StreamingExecutionRelation(_: KafkaMicroBatchReader, _) => true + case StreamingExecutionRelation(_: KafkaMicroBatchReadSupport, _) => true }.nonEmpty } ) @@ -675,17 +675,16 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { "kafka.bootstrap.servers" -> testUtils.brokerAddress, "subscribe" -> topic ) ++ Option(minPartitions).map { p => "minPartitions" -> p} - val reader = provider.createMicroBatchReader( - Optional.empty[StructType], dir.getAbsolutePath, new DataSourceOptions(options.asJava)) - reader.setOffsetRange( - Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))), - Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L))) - ) - val factories = reader.planInputPartitions().asScala + val readSupport = provider.createMicroBatchReadSupport( + dir.getAbsolutePath, new DataSourceOptions(options.asJava)) + val config = readSupport.newScanConfigBuilder( + KafkaSourceOffset(Map(tp -> 0L)), + KafkaSourceOffset(Map(tp -> 100L))).build() + val inputPartitions = readSupport.planInputPartitions(config) .map(_.asInstanceOf[KafkaMicroBatchInputPartition]) - withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") { - assert(factories.size == numPartitionsGenerated) - factories.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) } + withClue(s"minPartitions = $minPartitions generated factories $inputPartitions\n\t") { + assert(inputPartitions.size == numPartitionsGenerated) + inputPartitions.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) } } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java similarity index 59% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java index 80ac08ee5ff52..f403dc619e86c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java @@ -18,48 +18,44 @@ package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.DataSourceRegister; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; +import org.apache.spark.sql.sources.v2.reader.BatchReadSupport; import org.apache.spark.sql.types.StructType; /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data reading ability and scan the data from the data source. + * provide data reading ability for batch processing. + * + * This interface is used to create {@link BatchReadSupport} instances when end users run + * {@code SparkSession.read.format(...).option(...).load()}. */ @InterfaceStability.Evolving -public interface ReadSupport extends DataSourceV2 { +public interface BatchReadSupportProvider extends DataSourceV2 { /** - * Creates a {@link DataSourceReader} to scan the data from this data source. + * Creates a {@link BatchReadSupport} instance to load the data from this data source with a user + * specified schema, which is called by Spark at the beginning of each batch query. + * + * Spark will call this method at the beginning of each batch query to create a + * {@link BatchReadSupport} instance. * - * If this method fails (by throwing an exception), the action will fail and no Spark job will be - * submitted. + * By default this method throws {@link UnsupportedOperationException}, implementations should + * override this method to handle user specified schema. * * @param schema the user specified schema. * @param options the options for the returned data source reader, which is an immutable * case-insensitive string-to-string map. - * - * By default this method throws {@link UnsupportedOperationException}, implementations should - * override this method to handle user specified schema. */ - default DataSourceReader createReader(StructType schema, DataSourceOptions options) { - String name; - if (this instanceof DataSourceRegister) { - name = ((DataSourceRegister) this).shortName(); - } else { - name = this.getClass().getName(); - } - throw new UnsupportedOperationException(name + " does not support user specified schema"); + default BatchReadSupport createBatchReadSupport(StructType schema, DataSourceOptions options) { + return DataSourceV2Utils.failForUserSpecifiedSchema(this); } /** - * Creates a {@link DataSourceReader} to scan the data from this data source. - * - * If this method fails (by throwing an exception), the action will fail and no Spark job will be - * submitted. + * Creates a {@link BatchReadSupport} instance to scan the data from this data source, which is + * called by Spark at the beginning of each batch query. * * @param options the options for the returned data source reader, which is an immutable * case-insensitive string-to-string map. */ - DataSourceReader createReader(DataSourceOptions options); + BatchReadSupport createBatchReadSupport(DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java similarity index 58% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java index 048787a7a0a05..bd10c3353bf12 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java @@ -21,33 +21,39 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.SaveMode; -import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; +import org.apache.spark.sql.sources.v2.writer.BatchWriteSupport; import org.apache.spark.sql.types.StructType; /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data writing ability and save the data to the data source. + * provide data writing ability for batch processing. + * + * This interface is used to create {@link BatchWriteSupport} instances when end users run + * {@code Dataset.write.format(...).option(...).save()}. */ @InterfaceStability.Evolving -public interface WriteSupport extends DataSourceV2 { +public interface BatchWriteSupportProvider extends DataSourceV2 { /** - * Creates an optional {@link DataSourceWriter} to save the data to this data source. Data - * sources can return None if there is no writing needed to be done according to the save mode. + * Creates an optional {@link BatchWriteSupport} instance to save the data to this data source, + * which is called by Spark at the beginning of each batch query. * - * If this method fails (by throwing an exception), the action will fail and no Spark job will be - * submitted. + * Data sources can return None if there is no writing needed to be done according to the save + * mode. * - * @param writeUUID A unique string for the writing job. It's possible that there are many writing - * jobs running at the same time, and the returned {@link DataSourceWriter} can - * use this job id to distinguish itself from other jobs. + * @param queryId A unique string for the writing query. It's possible that there are many + * writing queries running at the same time, and the returned + * {@link BatchWriteSupport} can use this id to distinguish itself from others. * @param schema the schema of the data to be written. * @param mode the save mode which determines what to do when the data are already in this data * source, please refer to {@link SaveMode} for more details. * @param options the options for the returned data source writer, which is an immutable * case-insensitive string-to-string map. - * @return a writer to append data to this data source + * @return a write support to write data to this data source. */ - Optional createWriter( - String writeUUID, StructType schema, SaveMode mode, DataSourceOptions options); + Optional createBatchWriteSupport( + String queryId, + StructType schema, + SaveMode mode, + DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java deleted file mode 100644 index 7df5a451ae5f3..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import java.util.Optional; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data reading ability for continuous stream processing. - */ -@InterfaceStability.Evolving -public interface ContinuousReadSupport extends DataSourceV2 { - /** - * Creates a {@link ContinuousReader} to scan the data from this data source. - * - * @param schema the user provided schema, or empty() if none was provided - * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure - * recovery. Readers for the same logical source in the same query - * will be given the same checkpointLocation. - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. - */ - ContinuousReader createContinuousReader( - Optional schema, - String checkpointLocation, - DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java new file mode 100644 index 0000000000000..824c290518acf --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data reading ability for continuous stream processing. + * + * This interface is used to create {@link ContinuousReadSupport} instances when end users run + * {@code SparkSession.readStream.format(...).option(...).load()} with a continuous trigger. + */ +@InterfaceStability.Evolving +public interface ContinuousReadSupportProvider extends DataSourceV2 { + + /** + * Creates a {@link ContinuousReadSupport} instance to scan the data from this streaming data + * source with a user specified schema, which is called by Spark at the beginning of each + * continuous streaming query. + * + * By default this method throws {@link UnsupportedOperationException}, implementations should + * override this method to handle user specified schema. + * + * @param schema the user provided schema. + * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure + * recovery. Readers for the same logical source in the same query + * will be given the same checkpointLocation. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. + */ + default ContinuousReadSupport createContinuousReadSupport( + StructType schema, + String checkpointLocation, + DataSourceOptions options) { + return DataSourceV2Utils.failForUserSpecifiedSchema(this); + } + + /** + * Creates a {@link ContinuousReadSupport} instance to scan the data from this streaming data + * source, which is called by Spark at the beginning of each continuous streaming query. + * + * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure + * recovery. Readers for the same logical source in the same query + * will be given the same checkpointLocation. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. + */ + ContinuousReadSupport createContinuousReadSupport( + String checkpointLocation, + DataSourceOptions options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java index 6234071320dc9..6e31e84bf6c72 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java @@ -22,9 +22,13 @@ /** * The base interface for data source v2. Implementations must have a public, 0-arg constructor. * - * Note that this is an empty interface. Data source implementations should mix-in at least one of - * the plug-in interfaces like {@link ReadSupport} and {@link WriteSupport}. Otherwise it's just - * a dummy data source which is un-readable/writable. + * Note that this is an empty interface. Data source implementations must mix in interfaces such as + * {@link BatchReadSupportProvider} or {@link BatchWriteSupportProvider}, which can provide + * batch or streaming read/write support instances. Otherwise it's just a dummy data source which + * is un-readable/writable. + * + * If Spark fails to execute any methods in the implementations of this interface (by throwing an + * exception), the read action will fail and no Spark job will be submitted. */ @InterfaceStability.Evolving public interface DataSourceV2 {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java deleted file mode 100644 index 7f4a2c9593c76..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import java.util.Optional; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide streaming micro-batch data reading ability. - */ -@InterfaceStability.Evolving -public interface MicroBatchReadSupport extends DataSourceV2 { - /** - * Creates a {@link MicroBatchReader} to read batches of data from this data source in a - * streaming query. - * - * The execution engine will create a micro-batch reader at the start of a streaming query, - * alternate calls to setOffsetRange and planInputPartitions for each batch to process, and - * then call stop() when the execution is complete. Note that a single query may have multiple - * executions due to restart or failure recovery. - * - * @param schema the user provided schema, or empty() if none was provided - * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure - * recovery. Readers for the same logical source in the same query - * will be given the same checkpointLocation. - * @param options the options for the returned data source reader, which is an immutable - * case-insensitive string-to-string map. - */ - MicroBatchReader createMicroBatchReader( - Optional schema, - String checkpointLocation, - DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java new file mode 100644 index 0000000000000..61c08e7fa89df --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data reading ability for micro-batch stream processing. + * + * This interface is used to create {@link MicroBatchReadSupport} instances when end users run + * {@code SparkSession.readStream.format(...).option(...).load()} with a micro-batch trigger. + */ +@InterfaceStability.Evolving +public interface MicroBatchReadSupportProvider extends DataSourceV2 { + + /** + * Creates a {@link MicroBatchReadSupport} instance to scan the data from this streaming data + * source with a user specified schema, which is called by Spark at the beginning of each + * micro-batch streaming query. + * + * By default this method throws {@link UnsupportedOperationException}, implementations should + * override this method to handle user specified schema. + * + * @param schema the user provided schema. + * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure + * recovery. Readers for the same logical source in the same query + * will be given the same checkpointLocation. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. + */ + default MicroBatchReadSupport createMicroBatchReadSupport( + StructType schema, + String checkpointLocation, + DataSourceOptions options) { + return DataSourceV2Utils.failForUserSpecifiedSchema(this); + } + + /** + * Creates a {@link MicroBatchReadSupport} instance to scan the data from this streaming data + * source, which is called by Spark at the beginning of each micro-batch streaming query. + * + * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure + * recovery. Readers for the same logical source in the same query + * will be given the same checkpointLocation. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. + */ + MicroBatchReadSupport createMicroBatchReadSupport( + String checkpointLocation, + DataSourceOptions options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java index 9d66805d79b9e..bbe430e299261 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java @@ -27,10 +27,11 @@ @InterfaceStability.Evolving public interface SessionConfigSupport extends DataSourceV2 { - /** - * Key prefix of the session configs to propagate. Spark will extract all session configs that - * starts with `spark.datasource.$keyPrefix`, turn `spark.datasource.$keyPrefix.xxx -> yyy` - * into `xxx -> yyy`, and propagate them to all data source operations in this session. - */ - String keyPrefix(); + /** + * Key prefix of the session configs to propagate, which is usually the data source name. Spark + * will extract all session configs that starts with `spark.datasource.$keyPrefix`, turn + * `spark.datasource.$keyPrefix.xxx -> yyy` into `xxx -> yyy`, and propagate them to all + * data source operations in this session. + */ + String keyPrefix(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java deleted file mode 100644 index a77b01497269e..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.streaming.BaseStreamingSink; -import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter; -import org.apache.spark.sql.streaming.OutputMode; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data writing ability for structured streaming. - */ -@InterfaceStability.Evolving -public interface StreamWriteSupport extends DataSourceV2, BaseStreamingSink { - - /** - * Creates an optional {@link StreamWriter} to save the data to this data source. Data - * sources can return None if there is no writing needed to be done. - * - * @param queryId A unique string for the writing query. It's possible that there are many - * writing queries running at the same time, and the returned - * {@link DataSourceWriter} can use this id to distinguish itself from others. - * @param schema the schema of the data to be written. - * @param mode the output mode which determines what successive epoch output means to this - * sink, please refer to {@link OutputMode} for more details. - * @param options the options for the returned data source writer, which is an immutable - * case-insensitive string-to-string map. - */ - StreamWriter createStreamWriter( - String queryId, - StructType schema, - OutputMode mode, - DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java new file mode 100644 index 0000000000000..f9ca85d8089b4 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.streaming.BaseStreamingSink; +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; +import org.apache.spark.sql.streaming.OutputMode; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data writing ability for structured streaming. + * + * This interface is used to create {@link StreamingWriteSupport} instances when end users run + * {@code Dataset.writeStream.format(...).option(...).start()}. + */ +@InterfaceStability.Evolving +public interface StreamingWriteSupportProvider extends DataSourceV2, BaseStreamingSink { + + /** + * Creates a {@link StreamingWriteSupport} instance to save the data to this data source, which is + * called by Spark at the beginning of each streaming query. + * + * @param queryId A unique string for the writing query. It's possible that there are many + * writing queries running at the same time, and the returned + * {@link StreamingWriteSupport} can use this id to distinguish itself from others. + * @param schema the schema of the data to be written. + * @param mode the output mode which determines what successive epoch output means to this + * sink, please refer to {@link OutputMode} for more details. + * @param options the options for the returned data source writer, which is an immutable + * case-insensitive string-to-string map. + */ + StreamingWriteSupport createStreamingWriteSupport( + String queryId, + StructType schema, + OutputMode mode, + DataSourceOptions options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java new file mode 100644 index 0000000000000..452ee86675b42 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * An interface that defines how to load the data from data source for batch processing. + * + * The execution engine will get an instance of this interface from a data source provider + * (e.g. {@link org.apache.spark.sql.sources.v2.BatchReadSupportProvider}) at the start of a batch + * query, then call {@link #newScanConfigBuilder()} and create an instance of {@link ScanConfig}. + * The {@link ScanConfigBuilder} can apply operator pushdown and keep the pushdown result in + * {@link ScanConfig}. The {@link ScanConfig} will be used to create input partitions and reader + * factory to scan data from the data source with a Spark job. + */ +@InterfaceStability.Evolving +public interface BatchReadSupport extends ReadSupport { + + /** + * Returns a builder of {@link ScanConfig}. Spark will call this method and create a + * {@link ScanConfig} for each data scanning job. + * + * The builder can take some query specific information to do operators pushdown, and keep these + * information in the created {@link ScanConfig}. + * + * This is the first step of the data scan. All other methods in {@link BatchReadSupport} needs + * to take {@link ScanConfig} as an input. + */ + ScanConfigBuilder newScanConfigBuilder(); + + /** + * Returns a factory, which produces one {@link PartitionReader} for one {@link InputPartition}. + */ + PartitionReaderFactory createReaderFactory(ScanConfig config); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java deleted file mode 100644 index da98fab1284ef..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader; - -import java.util.List; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.types.StructType; - -/** - * A data source reader that is returned by - * {@link ReadSupport#createReader(DataSourceOptions)} or - * {@link ReadSupport#createReader(StructType, DataSourceOptions)}. - * It can mix in various query optimization interfaces to speed up the data scan. The actual scan - * logic is delegated to {@link InputPartition}s, which are returned by - * {@link #planInputPartitions()}. - * - * There are mainly 3 kinds of query optimizations: - * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column - * pruning), etc. Names of these interfaces start with `SupportsPushDown`. - * 2. Information Reporting. E.g., statistics reporting, ordering reporting, etc. - * Names of these interfaces start with `SupportsReporting`. - * 3. Columnar scan if implements {@link SupportsScanColumnarBatch}. - * - * If an exception was throw when applying any of these query optimizations, the action will fail - * and no Spark job will be submitted. - * - * Spark first applies all operator push-down optimizations that this data source supports. Then - * Spark collects information this data source reported for further optimizations. Finally Spark - * issues the scan request and does the actual data reading. - */ -@InterfaceStability.Evolving -public interface DataSourceReader { - - /** - * Returns the actual schema of this data source reader, which may be different from the physical - * schema of the underlying storage, as column pruning or other optimizations may happen. - * - * If this method fails (by throwing an exception), the action will fail and no Spark job will be - * submitted. - */ - StructType readSchema(); - - /** - * Returns a list of {@link InputPartition}s. Each {@link InputPartition} is responsible for - * creating a data reader to output data of one RDD partition. The number of input partitions - * returned here is the same as the number of RDD partitions this scan outputs. - * - * Note that, this may not be a full scan if the data source reader mixes in other optimization - * interfaces like column pruning, filter push-down, etc. These optimizations are applied before - * Spark issues the scan request. - * - * If this method fails (by throwing an exception), the action will fail and no Spark job will be - * submitted. - */ - List> planInputPartitions(); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java index f2038d0de3ffe..95c30de907e44 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java @@ -22,18 +22,18 @@ import org.apache.spark.annotation.InterfaceStability; /** - * An input partition returned by {@link DataSourceReader#planInputPartitions()} and is - * responsible for creating the actual data reader of one RDD partition. - * The relationship between {@link InputPartition} and {@link InputPartitionReader} - * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. + * A serializable representation of an input partition returned by + * {@link ReadSupport#planInputPartitions(ScanConfig)}. * - * Note that {@link InputPartition}s will be serialized and sent to executors, then - * {@link InputPartitionReader}s will be created on executors to do the actual reading. So - * {@link InputPartition} must be serializable while {@link InputPartitionReader} doesn't need to - * be. + * Note that {@link InputPartition} will be serialized and sent to executors, then + * {@link PartitionReader} will be created by + * {@link PartitionReaderFactory#createReader(InputPartition)} or + * {@link PartitionReaderFactory#createColumnarReader(InputPartition)} on executors to do + * the actual reading. So {@link InputPartition} must be serializable while {@link PartitionReader} + * doesn't need to be. */ @InterfaceStability.Evolving -public interface InputPartition extends Serializable { +public interface InputPartition extends Serializable { /** * The preferred locations where the input partition reader returned by this partition can run @@ -51,12 +51,4 @@ public interface InputPartition extends Serializable { default String[] preferredLocations() { return new String[0]; } - - /** - * Returns an input partition reader to do the actual reading work. - * - * If this method fails (by throwing an exception), the corresponding Spark task would fail and - * get retried until hitting the maximum retry times. - */ - InputPartitionReader createPartitionReader(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java similarity index 67% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java index f3ff7f5cc0f20..04ff8d0a19fc3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java @@ -23,31 +23,27 @@ import org.apache.spark.annotation.InterfaceStability; /** - * An input partition reader returned by {@link InputPartition#createPartitionReader()} and is - * responsible for outputting data for a RDD partition. + * A partition reader returned by {@link PartitionReaderFactory#createReader(InputPartition)} or + * {@link PartitionReaderFactory#createColumnarReader(InputPartition)}. It's responsible for + * outputting data for a RDD partition. * * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow} - * for normal data source readers, {@link org.apache.spark.sql.vectorized.ColumnarBatch} for data - * source readers that mix in {@link SupportsScanColumnarBatch}. + * for normal data sources, or {@link org.apache.spark.sql.vectorized.ColumnarBatch} for columnar + * data sources(whose {@link PartitionReaderFactory#supportColumnarReads(InputPartition)} + * returns true). */ @InterfaceStability.Evolving -public interface InputPartitionReader extends Closeable { +public interface PartitionReader extends Closeable { /** * Proceed to next record, returns false if there is no more records. * - * If this method fails (by throwing an exception), the corresponding Spark task would fail and - * get retried until hitting the maximum retry times. - * * @throws IOException if failure happens during disk/network IO like reading files. */ boolean next() throws IOException; /** * Return the current record. This method should return same value until `next` is called. - * - * If this method fails (by throwing an exception), the corresponding Spark task would fail and - * get retried until hitting the maximum retry times. */ T get(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java new file mode 100644 index 0000000000000..f35de9310eee3 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import java.io.Serializable; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +/** + * A factory used to create {@link PartitionReader} instances. + * + * If Spark fails to execute any methods in the implementations of this interface or in the returned + * {@link PartitionReader} (by throwing an exception), corresponding Spark task would fail and + * get retried until hitting the maximum retry times. + */ +@InterfaceStability.Evolving +public interface PartitionReaderFactory extends Serializable { + + /** + * Returns a row-based partition reader to read data from the given {@link InputPartition}. + * + * Implementations probably need to cast the input partition to the concrete + * {@link InputPartition} class defined for the data source. + */ + PartitionReader createReader(InputPartition partition); + + /** + * Returns a columnar partition reader to read data from the given {@link InputPartition}. + * + * Implementations probably need to cast the input partition to the concrete + * {@link InputPartition} class defined for the data source. + */ + default PartitionReader createColumnarReader(InputPartition partition) { + throw new UnsupportedOperationException("Cannot create columnar reader."); + } + + /** + * Returns true if the given {@link InputPartition} should be read by Spark in a columnar way. + * This means, implementations must also implement {@link #createColumnarReader(InputPartition)} + * for the input partitions that this method returns true. + * + * As of Spark 2.4, Spark can only read all input partition in a columnar way, or none of them. + * Data source can't mix columnar and row-based partitions. This may be relaxed in future + * versions. + */ + default boolean supportColumnarReads(InputPartition partition) { + return false; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java new file mode 100644 index 0000000000000..a58ddb288f1ed --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.types.StructType; + +/** + * The base interface for all the batch and streaming read supports. Data sources should implement + * concrete read support interfaces like {@link BatchReadSupport}. + * + * If Spark fails to execute any methods in the implementations of this interface (by throwing an + * exception), the read action will fail and no Spark job will be submitted. + */ +@InterfaceStability.Evolving +public interface ReadSupport { + + /** + * Returns the full schema of this data source, which is usually the physical schema of the + * underlying storage. This full schema should not be affected by column pruning or other + * optimizations. + */ + StructType fullSchema(); + + /** + * Returns a list of {@link InputPartition input partitions}. Each {@link InputPartition} + * represents a data split that can be processed by one Spark task. The number of input + * partitions returned here is the same as the number of RDD partitions this scan outputs. + * + * Note that, this may not be a full scan if the data source supports optimization like filter + * push-down. Implementations should check the input {@link ScanConfig} and adjust the resulting + * {@link InputPartition input partitions}. + */ + InputPartition[] planInputPartitions(ScanConfig config); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java new file mode 100644 index 0000000000000..7462ce2820585 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.types.StructType; + +/** + * An interface that carries query specific information for the data scanning job, like operator + * pushdown information and streaming query offsets. This is defined as an empty interface, and data + * sources should define their own {@link ScanConfig} classes. + * + * For APIs that take a {@link ScanConfig} as input, like + * {@link ReadSupport#planInputPartitions(ScanConfig)}, + * {@link BatchReadSupport#createReaderFactory(ScanConfig)} and + * {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}, implementations mostly need to + * cast the input {@link ScanConfig} to the concrete {@link ScanConfig} class of the data source. + */ +@InterfaceStability.Evolving +public interface ScanConfig { + + /** + * Returns the actual schema of this data source reader, which may be different from the physical + * schema of the underlying storage, as column pruning or other optimizations may happen. + * + * If this method fails (by throwing an exception), the action will fail and no Spark job will be + * submitted. + */ + StructType readSchema(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java similarity index 61% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java index dcb87715d0b6f..4c0eedfddfe22 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java @@ -18,18 +18,13 @@ package org.apache.spark.sql.sources.v2.reader; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset; /** - * A mix-in interface for {@link InputPartition}. Continuous input partitions can - * implement this interface to provide creating {@link InputPartitionReader} with particular offset. + * An interface for building the {@link ScanConfig}. Implementations can mixin those + * SupportsPushDownXYZ interfaces to do operator pushdown, and keep the operator pushdown result in + * the returned {@link ScanConfig}. */ @InterfaceStability.Evolving -public interface ContinuousInputPartition extends InputPartition { - /** - * Create an input partition reader with particular offset as its startOffset. - * - * @param offset offset want to set as the input partition reader's startOffset. - */ - InputPartitionReader createContinuousReader(PartitionOffset offset); +public interface ScanConfigBuilder { + ScanConfig build(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java index e8cd7adbca071..44799c7d49137 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java @@ -23,7 +23,7 @@ /** * An interface to represent statistics for a data source, which is returned by - * {@link SupportsReportStatistics#getStatistics()}. + * {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}. */ @InterfaceStability.Evolving public interface Statistics { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index 4543c143a9aca..9d79a18d14bcf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression; /** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this + * A mix-in interface for {@link ScanConfigBuilder}. Data source readers can implement this * interface to push down arbitrary expressions as predicates to the data source. * This is an experimental and unstable interface as {@link Expression} is not public and may get * changed in the future Spark versions. @@ -31,7 +31,7 @@ * process this interface. */ @InterfaceStability.Unstable -public interface SupportsPushDownCatalystFilters extends DataSourceReader { +public interface SupportsPushDownCatalystFilters extends ScanConfigBuilder { /** * Pushes down filters, and returns filters that need to be evaluated after scanning. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index b6a90a3d0b681..5d32a8ac60f78 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -21,15 +21,15 @@ import org.apache.spark.sql.sources.Filter; /** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to push down filters to the data source and reduce the size of the data to be read. + * A mix-in interface for {@link ScanConfigBuilder}. Data sources can implement this interface to + * push down filters to the data source and reduce the size of the data to be read. * * Note that, if data source readers implement both this interface and * {@link SupportsPushDownCatalystFilters}, Spark will ignore this interface and only process * {@link SupportsPushDownCatalystFilters}. */ @InterfaceStability.Evolving -public interface SupportsPushDownFilters extends DataSourceReader { +public interface SupportsPushDownFilters extends ScanConfigBuilder { /** * Pushes down filters, and returns filters that need to be evaluated after scanning. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java index 427b4d00a1128..edb164937d6ef 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java @@ -21,12 +21,12 @@ import org.apache.spark.sql.types.StructType; /** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this + * A mix-in interface for {@link ScanConfigBuilder}. Data sources can implement this * interface to push down required columns to the data source and only read these columns during * scan to reduce the size of the data to be read. */ @InterfaceStability.Evolving -public interface SupportsPushDownRequiredColumns extends DataSourceReader { +public interface SupportsPushDownRequiredColumns extends ScanConfigBuilder { /** * Applies column pruning w.r.t. the given requiredSchema. @@ -35,8 +35,8 @@ public interface SupportsPushDownRequiredColumns extends DataSourceReader { * also OK to do the pruning partially, e.g., a data source may not be able to prune nested * fields, and only prune top-level columns. * - * Note that, data source readers should update {@link DataSourceReader#readSchema()} after - * applying column pruning. + * Note that, {@link ScanConfig#readSchema()} implementation should take care of the column + * pruning applied here. */ void pruneColumns(StructType requiredSchema); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index 6b60da7c4dc1d..db62cd4515362 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -21,17 +21,17 @@ import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; /** - * A mix in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to report data partitioning and try to avoid shuffle at Spark side. + * A mix in interface for {@link BatchReadSupport}. Data sources can implement this interface to + * report data partitioning and try to avoid shuffle at Spark side. * - * Note that, when the reader creates exactly one {@link InputPartition}, Spark may avoid - * adding a shuffle even if the reader does not implement this interface. + * Note that, when a {@link ReadSupport} implementation creates exactly one {@link InputPartition}, + * Spark may avoid adding a shuffle even if the reader does not implement this interface. */ @InterfaceStability.Evolving -public interface SupportsReportPartitioning extends DataSourceReader { +public interface SupportsReportPartitioning extends ReadSupport { /** * Returns the output data partitioning that this reader guarantees. */ - Partitioning outputPartitioning(); + Partitioning outputPartitioning(ScanConfig config); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java index 926396414816c..1831488ba096f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -20,18 +20,18 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A mix in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to report statistics to Spark. + * A mix in interface for {@link BatchReadSupport}. Data sources can implement this interface to + * report statistics to Spark. * - * Statistics are reported to the optimizer before any operator is pushed to the DataSourceReader. - * Implementations that return more accurate statistics based on pushed operators will not improve - * query performance until the planner can push operators before getting stats. + * As of Spark 2.4, statistics are reported to the optimizer before any operator is pushed to the + * data source. Implementations that return more accurate statistics based on pushed operators will + * not improve query performance until the planner can push operators before getting stats. */ @InterfaceStability.Evolving -public interface SupportsReportStatistics extends DataSourceReader { +public interface SupportsReportStatistics extends ReadSupport { /** - * Returns the basic statistics of this data source. + * Returns the estimated statistics of this data source scan. */ - Statistics getStatistics(); + Statistics estimateStatistics(ScanConfig config); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java deleted file mode 100644 index f4da686740d11..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader; - -import java.util.List; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.vectorized.ColumnarBatch; - -/** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to output {@link ColumnarBatch} and make the scan faster. - */ -@InterfaceStability.Evolving -public interface SupportsScanColumnarBatch extends DataSourceReader { - @Override - default List> planInputPartitions() { - throw new IllegalStateException( - "planInputPartitions not supported by default within SupportsScanColumnarBatch."); - } - - /** - * Similar to {@link DataSourceReader#planInputPartitions()}, but returns columnar data - * in batches. - */ - List> planBatchInputPartitions(); - - /** - * Returns true if the concrete data source reader can read data in batch according to the scan - * properties like required columns, pushes filters, etc. It's possible that the implementation - * can only support some certain columns with certain types. Users can overwrite this method and - * {@link #planInputPartitions()} to fallback to normal read path under some conditions. - */ - default boolean enableBatchRead() { - return true; - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java index 38ca5fc6387b2..6764d4b7665c7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java @@ -18,12 +18,12 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; +import org.apache.spark.sql.sources.v2.reader.PartitionReader; /** * A concrete implementation of {@link Distribution}. Represents a distribution where records that * share the same values for the {@link #clusteredColumns} will be produced by the same - * {@link InputPartitionReader}. + * {@link PartitionReader}. */ @InterfaceStability.Evolving public class ClusteredDistribution implements Distribution { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java index 5e32ba6952e1c..364a3f553923c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java @@ -18,14 +18,14 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; +import org.apache.spark.sql.sources.v2.reader.PartitionReader; /** * An interface to represent data distribution requirement, which specifies how the records should - * be distributed among the data partitions (one {@link InputPartitionReader} outputs data for one + * be distributed among the data partitions (one {@link PartitionReader} outputs data for one * partition). * Note that this interface has nothing to do with the data ordering inside one - * partition(the output records of a single {@link InputPartitionReader}). + * partition(the output records of a single {@link PartitionReader}). * * The instance of this interface is created and provided by Spark, then consumed by * {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java index f460f6bfe3bb9..fb0b6f1df43bb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java @@ -19,12 +19,13 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.ScanConfig; import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning; /** * An interface to represent the output data partitioning for a data source, which is returned by - * {@link SupportsReportPartitioning#outputPartitioning()}. Note that this should work like a - * snapshot. Once created, it should be deterministic and always report the same number of + * {@link SupportsReportPartitioning#outputPartitioning(ScanConfig)}. Note that this should work + * like a snapshot. Once created, it should be deterministic and always report the same number of * partitions and the same "satisfy" result for a certain distribution. */ @InterfaceStability.Evolving diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java similarity index 60% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java index 7b0ba0bbdda90..9101c8a44d34e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java @@ -18,19 +18,20 @@ package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; +import org.apache.spark.sql.sources.v2.reader.PartitionReader; /** - * A variation on {@link InputPartitionReader} for use with streaming in continuous processing mode. + * A variation on {@link PartitionReader} for use with continuous streaming processing. */ @InterfaceStability.Evolving -public interface ContinuousInputPartitionReader extends InputPartitionReader { - /** - * Get the offset of the current record, or the start offset if no records have been read. - * - * The execution engine will call this method along with get() to keep track of the current - * offset. When an epoch ends, the offset of the previous record in each partition will be saved - * as a restart checkpoint. - */ - PartitionOffset getOffset(); +public interface ContinuousPartitionReader extends PartitionReader { + + /** + * Get the offset of the current record, or the start offset if no records have been read. + * + * The execution engine will call this method along with get() to keep track of the current + * offset. When an epoch ends, the offset of the previous record in each partition will be saved + * as a restart checkpoint. + */ + PartitionOffset getOffset(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsDeprecatedScanRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java similarity index 51% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsDeprecatedScanRow.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java index 595943cf4d8ac..2d9f1ca1686a1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsDeprecatedScanRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java @@ -15,25 +15,26 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.InternalRow; - -import java.util.List; +import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory; +import org.apache.spark.sql.vectorized.ColumnarBatch; /** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to output {@link Row} instead of {@link InternalRow}. - * This is an experimental and unstable interface. + * A variation on {@link PartitionReaderFactory} that returns {@link ContinuousPartitionReader} + * instead of {@link org.apache.spark.sql.sources.v2.reader.PartitionReader}. It's used for + * continuous streaming processing. */ -@InterfaceStability.Unstable -public interface SupportsDeprecatedScanRow extends DataSourceReader { - default List> planInputPartitions() { - throw new IllegalStateException( - "planInputPartitions not supported by default within SupportsDeprecatedScanRow"); - } +@InterfaceStability.Evolving +public interface ContinuousPartitionReaderFactory extends PartitionReaderFactory { + @Override + ContinuousPartitionReader createReader(InputPartition partition); - List> planRowInputPartitions(); + @Override + default ContinuousPartitionReader createColumnarReader(InputPartition partition) { + throw new UnsupportedOperationException("Cannot create columnar reader."); + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java new file mode 100644 index 0000000000000..9a3ad2eb8a801 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader.streaming; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.streaming.BaseStreamingSource; +import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.ScanConfig; +import org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder; + +/** + * An interface that defines how to load the data from data source for continuous streaming + * processing. + * + * The execution engine will get an instance of this interface from a data source provider + * (e.g. {@link org.apache.spark.sql.sources.v2.ContinuousReadSupportProvider}) at the start of a + * streaming query, then call {@link #newScanConfigBuilder(Offset)} and create an instance of + * {@link ScanConfig} for the duration of the streaming query or until + * {@link #needsReconfiguration(ScanConfig)} is true. The {@link ScanConfig} will be used to create + * input partitions and reader factory to scan data with a Spark job for its duration. At the end + * {@link #stop()} will be called when the streaming execution is completed. Note that a single + * query may have multiple executions due to restart or failure recovery. + */ +@InterfaceStability.Evolving +public interface ContinuousReadSupport extends StreamingReadSupport, BaseStreamingSource { + + /** + * Returns a builder of {@link ScanConfig}. Spark will call this method and create a + * {@link ScanConfig} for each data scanning job. + * + * The builder can take some query specific information to do operators pushdown, store streaming + * offsets, etc., and keep these information in the created {@link ScanConfig}. + * + * This is the first step of the data scan. All other methods in {@link ContinuousReadSupport} + * needs to take {@link ScanConfig} as an input. + */ + ScanConfigBuilder newScanConfigBuilder(Offset start); + + /** + * Returns a factory, which produces one {@link ContinuousPartitionReader} for one + * {@link InputPartition}. + */ + ContinuousPartitionReaderFactory createContinuousReaderFactory(ScanConfig config); + + /** + * Merge partitioned offsets coming from {@link ContinuousPartitionReader} instances + * for each partition to a single global offset. + */ + Offset mergeOffsets(PartitionOffset[] offsets); + + /** + * The execution engine will call this method in every epoch to determine if new input + * partitions need to be generated, which may be required if for example the underlying + * source system has had partitions added or removed. + * + * If true, the query will be shut down and restarted with a new {@link ContinuousReadSupport} + * instance. + */ + default boolean needsReconfiguration(ScanConfig config) { + return false; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java deleted file mode 100644 index 6e960bedf8020..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader.streaming; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.streaming.BaseStreamingSource; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; - -import java.util.Optional; - -/** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to allow reading in a continuous processing mode stream. - * - * Implementations must ensure each partition reader is a {@link ContinuousInputPartitionReader}. - * - * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with - * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. - */ -@InterfaceStability.Evolving -public interface ContinuousReader extends BaseStreamingSource, DataSourceReader { - /** - * Merge partitioned offsets coming from {@link ContinuousInputPartitionReader} instances - * for each partition to a single global offset. - */ - Offset mergeOffsets(PartitionOffset[] offsets); - - /** - * Deserialize a JSON string into an Offset of the implementation-defined offset type. - * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader - */ - Offset deserializeOffset(String json); - - /** - * Set the desired start offset for partitions created from this reader. The scan will - * start from the first record after the provided offset, or from an implementation-defined - * inferred starting point if no offset is provided. - */ - void setStartOffset(Optional start); - - /** - * Return the specified or inferred start offset for this reader. - * - * @throws IllegalStateException if setStartOffset has not been called - */ - Offset getStartOffset(); - - /** - * The execution engine will call this method in every epoch to determine if new input - * partitions need to be generated, which may be required if for example the underlying - * source system has had partitions added or removed. - * - * If true, the query will be shut down and restarted with a new reader. - */ - default boolean needsReconfiguration() { - return false; - } - - /** - * Informs the source that Spark has completed processing all data for offsets less than or - * equal to `end` and will only request offsets greater than `end` in the future. - */ - void commit(Offset end); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java new file mode 100644 index 0000000000000..edb0db11bff2c --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader.streaming; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.streaming.BaseStreamingSource; +import org.apache.spark.sql.sources.v2.reader.*; + +/** + * An interface that defines how to scan the data from data source for micro-batch streaming + * processing. + * + * The execution engine will get an instance of this interface from a data source provider + * (e.g. {@link org.apache.spark.sql.sources.v2.MicroBatchReadSupportProvider}) at the start of a + * streaming query, then call {@link #newScanConfigBuilder(Offset, Offset)} and create an instance + * of {@link ScanConfig} for each micro-batch. The {@link ScanConfig} will be used to create input + * partitions and reader factory to scan a micro-batch with a Spark job. At the end {@link #stop()} + * will be called when the streaming execution is completed. Note that a single query may have + * multiple executions due to restart or failure recovery. + */ +@InterfaceStability.Evolving +public interface MicroBatchReadSupport extends StreamingReadSupport, BaseStreamingSource { + + /** + * Returns a builder of {@link ScanConfig}. Spark will call this method and create a + * {@link ScanConfig} for each data scanning job. + * + * The builder can take some query specific information to do operators pushdown, store streaming + * offsets, etc., and keep these information in the created {@link ScanConfig}. + * + * This is the first step of the data scan. All other methods in {@link MicroBatchReadSupport} + * needs to take {@link ScanConfig} as an input. + */ + ScanConfigBuilder newScanConfigBuilder(Offset start, Offset end); + + /** + * Returns a factory, which produces one {@link PartitionReader} for one {@link InputPartition}. + */ + PartitionReaderFactory createReaderFactory(ScanConfig config); + + /** + * Returns the most recent offset available. + */ + Offset latestOffset(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java deleted file mode 100644 index 0159c731762d9..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader.streaming; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import org.apache.spark.sql.execution.streaming.BaseStreamingSource; - -import java.util.Optional; - -/** - * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to indicate they allow micro-batch streaming reads. - * - * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with - * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. - */ -@InterfaceStability.Evolving -public interface MicroBatchReader extends DataSourceReader, BaseStreamingSource { - /** - * Set the desired offset range for input partitions created from this reader. Partition readers - * will generate only data within (`start`, `end`]; that is, from the first record after `start` - * to the record with offset `end`. - * - * @param start The initial offset to scan from. If not specified, scan from an - * implementation-specified start point, such as the earliest available record. - * @param end The last offset to include in the scan. If not specified, scan up to an - * implementation-defined endpoint, such as the last available offset - * or the start offset plus a target batch size. - */ - void setOffsetRange(Optional start, Optional end); - - /** - * Returns the specified (if explicitly set through setOffsetRange) or inferred start offset - * for this reader. - * - * @throws IllegalStateException if setOffsetRange has not been called - */ - Offset getStartOffset(); - - /** - * Return the specified (if explicitly set through setOffsetRange) or inferred end offset - * for this reader. - * - * @throws IllegalStateException if setOffsetRange has not been called - */ - Offset getEndOffset(); - - /** - * Deserialize a JSON string into an Offset of the implementation-defined offset type. - * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader - */ - Offset deserializeOffset(String json); - - /** - * Informs the source that Spark has completed processing all data for offsets less than or - * equal to `end` and will only request offsets greater than `end` in the future. - */ - void commit(Offset end); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java index e41c0351edc82..6cf27734867cb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java @@ -20,8 +20,8 @@ import org.apache.spark.annotation.InterfaceStability; /** - * An abstract representation of progress through a {@link MicroBatchReader} or - * {@link ContinuousReader}. + * An abstract representation of progress through a {@link MicroBatchReadSupport} or + * {@link ContinuousReadSupport}. * During execution, offsets provided by the data source implementation will be logged and used as * restart checkpoints. Each source should provide an offset implementation which the source can use * to reconstruct a position in the stream up to which data has been seen/processed. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java new file mode 100644 index 0000000000000..84872d1ebc26e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader.streaming; + +import org.apache.spark.sql.sources.v2.reader.ReadSupport; + +/** + * A base interface for streaming read support. This is package private and is invisible to data + * sources. Data sources should implement concrete streaming read support interfaces: + * {@link MicroBatchReadSupport} or {@link ContinuousReadSupport}. + */ +interface StreamingReadSupport extends ReadSupport { + + /** + * Returns the initial offset for a streaming query to start reading from. Note that the + * streaming data source should not assume that it will start reading from its initial offset: + * if Spark is restarting an existing query, it will restart from the check-pointed offset rather + * than the initial one. + */ + Offset initialOffset(); + + /** + * Deserialize a JSON string into an Offset of the implementation-defined offset type. + * + * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader + */ + Offset deserializeOffset(String json); + + /** + * Informs the source that Spark has completed processing all data for offsets less than or + * equal to `end` and will only request offsets greater than `end` in the future. + */ + void commit(Offset end); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java index 3b293d925c91d..8693154cb7045 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java @@ -14,20 +14,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.CustomMetrics; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; /** - * A mix in interface for {@link DataSourceReader}. Data source readers can implement this - * interface to report custom metrics that gets reported under the + * A mix in interface for {@link StreamingReadSupport}. Data sources can implement this interface + * to report custom metrics that gets reported under the * {@link org.apache.spark.sql.streaming.SourceProgress} - * */ @InterfaceStability.Evolving -public interface SupportsCustomReaderMetrics extends DataSourceReader { +public interface SupportsCustomReaderMetrics extends StreamingReadSupport { + /** * Returns custom metrics specific to this data source. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java similarity index 79% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java index 385fc294fea82..0ec9e05d6a02b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java @@ -18,28 +18,13 @@ package org.apache.spark.sql.sources.v2.writer; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.SaveMode; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.StreamWriteSupport; -import org.apache.spark.sql.sources.v2.WriteSupport; -import org.apache.spark.sql.streaming.OutputMode; -import org.apache.spark.sql.types.StructType; /** - * A data source writer that is returned by - * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceOptions)}/ - * {@link StreamWriteSupport#createStreamWriter( - * String, StructType, OutputMode, DataSourceOptions)}. - * It can mix in various writing optimization interfaces to speed up the data saving. The actual - * writing logic is delegated to {@link DataWriter}. - * - * If an exception was throw when applying any of these writing optimizations, the action will fail - * and no Spark job will be submitted. + * An interface that defines how to write the data to data source for batch processing. * * The writing procedure is: - * 1. Create a writer factory by {@link #createWriterFactory()}, serialize and send it to all the - * partitions of the input data(RDD). + * 1. Create a writer factory by {@link #createBatchWriterFactory()}, serialize and send it to all + * the partitions of the input data(RDD). * 2. For each partition, create the data writer, and write the data of the partition with this * writer. If all the data are written successfully, call {@link DataWriter#commit()}. If * exception happens during the writing, call {@link DataWriter#abort()}. @@ -53,7 +38,7 @@ * Please refer to the documentation of commit/abort methods for detailed specifications. */ @InterfaceStability.Evolving -public interface DataSourceWriter { +public interface BatchWriteSupport { /** * Creates a writer factory which will be serialized and sent to executors. @@ -61,7 +46,7 @@ public interface DataSourceWriter { * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ - DataWriterFactory createWriterFactory(); + DataWriterFactory createBatchWriterFactory(); /** * Returns whether Spark should use the commit coordinator to ensure that at most one task for diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 27dc5ea224fe2..5fb067966ee67 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data writer returned by {@link DataWriterFactory#createDataWriter(int, long, long)} and is + * A data writer returned by {@link DataWriterFactory#createWriter(int, long)} and is * responsible for writing data for an input RDD partition. * * One Spark task has one exclusive data writer, so there is no thread-safe concern. @@ -36,11 +36,11 @@ * * If this data writer succeeds(all records are successfully written and {@link #commit()} * succeeds), a {@link WriterCommitMessage} will be sent to the driver side and pass to - * {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data + * {@link BatchWriteSupport#commit(WriterCommitMessage[])} with commit messages from other data * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an * exception will be sent to the driver side, and Spark may retry this writing task a few times. - * In each retry, {@link DataWriterFactory#createDataWriter(int, long, long)} will receive a - * different `taskId`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])} + * In each retry, {@link DataWriterFactory#createWriter(int, long)} will receive a + * different `taskId`. Spark will call {@link BatchWriteSupport#abort(WriterCommitMessage[])} * when the configured number of retries is exhausted. * * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task @@ -71,11 +71,11 @@ public interface DataWriter { /** * Commits this writer after all records are written successfully, returns a commit message which * will be sent back to driver side and passed to - * {@link DataSourceWriter#commit(WriterCommitMessage[])}. + * {@link BatchWriteSupport#commit(WriterCommitMessage[])}. * * The written data should only be visible to data source readers after - * {@link DataSourceWriter#commit(WriterCommitMessage[])} succeeds, which means this method - * should still "hide" the written data and ask the {@link DataSourceWriter} at driver side to + * {@link BatchWriteSupport#commit(WriterCommitMessage[])} succeeds, which means this method + * should still "hide" the written data and ask the {@link BatchWriteSupport} at driver side to * do the final commit via {@link WriterCommitMessage}. * * If this method fails (by throwing an exception), {@link #abort()} will be called and this @@ -93,7 +93,7 @@ public interface DataWriter { * failed. * * If this method fails(by throwing an exception), the underlying data source may have garbage - * that need to be cleaned by {@link DataSourceWriter#abort(WriterCommitMessage[])} or manually, + * that need to be cleaned by {@link BatchWriteSupport#abort(WriterCommitMessage[])} or manually, * but these garbage should not be visible to data source readers. * * @throws IOException if failure happens during disk/network IO like writing files. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index 3d337b6e0bdfd..19a36dd232456 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -19,18 +19,20 @@ import java.io.Serializable; +import org.apache.spark.TaskContext; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.catalyst.InternalRow; /** - * A factory of {@link DataWriter} returned by {@link DataSourceWriter#createWriterFactory()}, + * A factory of {@link DataWriter} returned by {@link BatchWriteSupport#createBatchWriterFactory()}, * which is responsible for creating and initializing the actual data writer at executor side. * * Note that, the writer factory will be serialized and sent to executors, then the data writer - * will be created on executors and do the actual writing. So {@link DataWriterFactory} must be + * will be created on executors and do the actual writing. So this interface must be * serializable and {@link DataWriter} doesn't need to be. */ @InterfaceStability.Evolving -public interface DataWriterFactory extends Serializable { +public interface DataWriterFactory extends Serializable { /** * Returns a data writer to do the actual writing work. Note that, Spark will reuse the same data @@ -38,19 +40,16 @@ public interface DataWriterFactory extends Serializable { * are responsible for defensive copies if necessary, e.g. copy the data before buffer it in a * list. * - * If this method fails (by throwing an exception), the action will fail and no Spark job will be - * submitted. + * If this method fails (by throwing an exception), the corresponding Spark write task would fail + * and get retried until hitting the maximum retry times. * * @param partitionId A unique id of the RDD partition that the returned writer will process. * Usually Spark processes many RDD partitions at the same time, * implementations should use the partition id to distinguish writers for * different partitions. - * @param taskId A unique identifier for a task that is performing the write of the partition - * data. Spark may run multiple tasks for the same partition (due to speculation - * or task failures, for example). - * @param epochId A monotonically increasing id for streaming queries that are split in to - * discrete periods of execution. For non-streaming queries, - * this ID will always be 0. + * @param taskId The task id returned by {@link TaskContext#taskAttemptId()}. Spark may run + * multiple tasks for the same partition (due to speculation or task failures, + * for example). */ - DataWriter createDataWriter(int partitionId, long taskId, long epochId); + DataWriter createWriter(int partitionId, long taskId); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java index 9e38836c0edf9..123335c414e9f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java @@ -19,15 +19,16 @@ import java.io.Serializable; +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; import org.apache.spark.annotation.InterfaceStability; /** * A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side - * as the input parameter of {@link DataSourceWriter#commit(WriterCommitMessage[])}. + * as the input parameter of {@link BatchWriteSupport#commit(WriterCommitMessage[])} or + * {@link StreamingWriteSupport#commit(long, WriterCommitMessage[])}. * - * This is an empty interface, data sources should define their own message class and use it in - * their {@link DataWriter#commit()} and {@link DataSourceWriter#commit(WriterCommitMessage[])} - * implementations. + * This is an empty interface, data sources should define their own message class and use it when + * generating messages at executor side and handling the messages at driver side. */ @InterfaceStability.Evolving public interface WriterCommitMessage extends Serializable {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java new file mode 100644 index 0000000000000..a4da24fc5ae68 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer.streaming; + +import java.io.Serializable; + +import org.apache.spark.TaskContext; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.sources.v2.writer.DataWriter; + +/** + * A factory of {@link DataWriter} returned by + * {@link StreamingWriteSupport#createStreamingWriterFactory()}, which is responsible for creating + * and initializing the actual data writer at executor side. + * + * Note that, the writer factory will be serialized and sent to executors, then the data writer + * will be created on executors and do the actual writing. So this interface must be + * serializable and {@link DataWriter} doesn't need to be. + */ +@InterfaceStability.Evolving +public interface StreamingDataWriterFactory extends Serializable { + + /** + * Returns a data writer to do the actual writing work. Note that, Spark will reuse the same data + * object instance when sending data to the data writer, for better performance. Data writers + * are responsible for defensive copies if necessary, e.g. copy the data before buffer it in a + * list. + * + * If this method fails (by throwing an exception), the corresponding Spark write task would fail + * and get retried until hitting the maximum retry times. + * + * @param partitionId A unique id of the RDD partition that the returned writer will process. + * Usually Spark processes many RDD partitions at the same time, + * implementations should use the partition id to distinguish writers for + * different partitions. + * @param taskId The task id returned by {@link TaskContext#taskAttemptId()}. Spark may run + * multiple tasks for the same partition (due to speculation or task failures, + * for example). + * @param epochId A monotonically increasing id for streaming queries that are split in to + * discrete periods of execution. + */ + DataWriter createWriter(int partitionId, long taskId, long epochId); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java similarity index 78% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java index a316b2a4c1d82..3fdfac5e1c84a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java @@ -18,27 +18,36 @@ package org.apache.spark.sql.sources.v2.writer.streaming; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; import org.apache.spark.sql.sources.v2.writer.DataWriter; import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; /** - * A {@link DataSourceWriter} for use with structured streaming. + * An interface that defines how to write the data to data source for streaming processing. * * Streaming queries are divided into intervals of data called epochs, with a monotonically * increasing numeric ID. This writer handles commits and aborts for each successive epoch. */ @InterfaceStability.Evolving -public interface StreamWriter extends DataSourceWriter { +public interface StreamingWriteSupport { + + /** + * Creates a writer factory which will be serialized and sent to executors. + * + * If this method fails (by throwing an exception), the action will fail and no Spark job will be + * submitted. + */ + StreamingDataWriterFactory createStreamingWriterFactory(); + /** * Commits this writing job for the specified epoch with a list of commit messages. The commit * messages are collected from successful data writers and are produced by * {@link DataWriter#commit()}. * * If this method fails (by throwing an exception), this writing job is considered to have been - * failed, and the execution engine will attempt to call {@link #abort(WriterCommitMessage[])}. + * failed, and the execution engine will attempt to call + * {@link #abort(long, WriterCommitMessage[])}. * - * The execution engine may call commit() multiple times for the same epoch in some circumstances. + * The execution engine may call `commit` multiple times for the same epoch in some circumstances. * To support exactly-once data semantics, implementations must ensure that multiple commits for * the same epoch are idempotent. */ @@ -46,7 +55,8 @@ public interface StreamWriter extends DataSourceWriter { /** * Aborts this writing job because some data writers are failed and keep failing when retried, or - * the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} fails. + * the Spark job fails with some unknown reasons, or {@link #commit(long, WriterCommitMessage[])} + * fails. * * If this method fails (by throwing an exception), the underlying data source may require manual * cleanup. @@ -58,14 +68,4 @@ public interface StreamWriter extends DataSourceWriter { * clean up the data left by data writers. */ void abort(long epochId, WriterCommitMessage[] messages); - - default void commit(WriterCommitMessage[] messages) { - throw new UnsupportedOperationException( - "Commit without epoch should not be called with StreamWriter"); - } - - default void abort(WriterCommitMessage[] messages) { - throw new UnsupportedOperationException( - "Abort without epoch should not be called with StreamWriter"); - } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java index 0cd36501320fd..2b018c7d123bb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java @@ -14,20 +14,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.sql.sources.v2.writer.streaming; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.CustomMetrics; -import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; /** - * A mix in interface for {@link DataSourceWriter}. Data source writers can implement this - * interface to report custom metrics that gets reported under the + * A mix in interface for {@link StreamingWriteSupport}. Data sources can implement this interface + * to report custom metrics that gets reported under the * {@link org.apache.spark.sql.streaming.SinkProgress} - * */ @InterfaceStability.Evolving -public interface SupportsCustomWriterMetrics extends DataSourceWriter { +public interface SupportsCustomWriterMetrics extends StreamingWriteSupport { + /** * Returns custom metrics specific to this data source. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 5b3b5c2451aab..0cfcc45fb3d31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport} +import org.apache.spark.sql.sources.v2.{BatchReadSupportProvider, DataSourceOptions, DataSourceV2} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -194,7 +194,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { val ds = cls.newInstance().asInstanceOf[DataSourceV2] - if (ds.isInstanceOf[ReadSupport]) { + if (ds.isInstanceOf[BatchReadSupportProvider]) { val sessionOptions = DataSourceV2Utils.extractSessionConfigs( ds = ds, conf = sparkSession.sessionState.conf) val pathsOption = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 650c91790a758..eca2d5b971905 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -240,7 +240,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { if (classOf[DataSourceV2].isAssignableFrom(cls)) { val source = cls.newInstance().asInstanceOf[DataSourceV2] source match { - case ws: WriteSupport => + case provider: BatchWriteSupportProvider => val options = extraOptions ++ DataSourceV2Utils.extractSessionConfigs(source, df.sparkSession.sessionState.conf) @@ -251,8 +251,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } } else { - val writer = ws.createWriter( - UUID.randomUUID.toString, df.logicalPlan.output.toStructType, mode, + val writer = provider.createBatchWriteSupport( + UUID.randomUUID().toString, + df.logicalPlan.output.toStructType, + mode, new DataSourceOptions(options.asJava)) if (writer.isPresent) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 782829887c446..f62f7349d1da7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -17,19 +17,22 @@ package org.apache.spark.sql.execution.datasources.v2 -import scala.reflect.ClassTag - -import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} +import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReader, PartitionReaderFactory} -class DataSourceRDDPartition[T : ClassTag](val index: Int, val inputPartition: InputPartition[T]) +class DataSourceRDDPartition(val index: Int, val inputPartition: InputPartition) extends Partition with Serializable -class DataSourceRDD[T: ClassTag]( +// TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an `RDD[ColumnarBatch]` for +// columnar scan. +class DataSourceRDD( sc: SparkContext, - @transient private val inputPartitions: Seq[InputPartition[T]]) - extends RDD[T](sc, Nil) { + @transient private val inputPartitions: Seq[InputPartition], + partitionReaderFactory: PartitionReaderFactory, + columnarReads: Boolean) + extends RDD[InternalRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { inputPartitions.zipWithIndex.map { @@ -37,11 +40,21 @@ class DataSourceRDD[T: ClassTag]( }.toArray } - override def compute(split: Partition, context: TaskContext): Iterator[T] = { - val reader = split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition - .createPartitionReader() + private def castPartition(split: Partition): DataSourceRDDPartition = split match { + case p: DataSourceRDDPartition => p + case _ => throw new SparkException(s"[BUG] Not a DataSourceRDDPartition: $split") + } + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + val inputPartition = castPartition(split).inputPartition + val reader: PartitionReader[_] = if (columnarReads) { + partitionReaderFactory.createColumnarReader(inputPartition) + } else { + partitionReaderFactory.createReader(inputPartition) + } + context.addTaskCompletionListener[Unit](_ => reader.close()) - val iter = new Iterator[T] { + val iter = new Iterator[Any] { private[this] var valuePrepared = false override def hasNext: Boolean = { @@ -51,7 +64,7 @@ class DataSourceRDD[T: ClassTag]( valuePrepared } - override def next(): T = { + override def next(): Any = { if (!hasNext) { throw new java.util.NoSuchElementException("End of stream") } @@ -59,10 +72,11 @@ class DataSourceRDD[T: ClassTag]( reader.get() } } - new InterruptibleIterator(context, iter) + // TODO: SPARK-25083 remove the type erasure hack in data source scan + new InterruptibleIterator(context, iter.asInstanceOf[Iterator[InternalRow]]) } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition.preferredLocations() + castPartition(split).inputPartition.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index a4bfc861cc9a4..f7e29593a6353 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -27,21 +27,21 @@ import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelat import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, WriteSupport} -import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsReportStatistics} -import org.apache.spark.sql.sources.v2.writer.DataSourceWriter +import org.apache.spark.sql.sources.v2.{BatchReadSupportProvider, BatchWriteSupportProvider, DataSourceOptions, DataSourceV2} +import org.apache.spark.sql.sources.v2.reader.{BatchReadSupport, ReadSupport, ScanConfigBuilder, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2.writer.BatchWriteSupport import org.apache.spark.sql.types.StructType /** * A logical plan representing a data source v2 scan. * * @param source An instance of a [[DataSourceV2]] implementation. - * @param options The options for this scan. Used to create fresh [[DataSourceReader]]. - * @param userSpecifiedSchema The user-specified schema for this scan. Used to create fresh - * [[DataSourceReader]]. + * @param options The options for this scan. Used to create fresh [[BatchWriteSupport]]. + * @param userSpecifiedSchema The user-specified schema for this scan. */ case class DataSourceV2Relation( source: DataSourceV2, + readSupport: BatchReadSupport, output: Seq[AttributeReference], options: Map[String, String], tableIdent: Option[TableIdentifier] = None, @@ -58,13 +58,12 @@ case class DataSourceV2Relation( override def simpleString: String = "RelationV2 " + metadataString - def newReader(): DataSourceReader = source.createReader(options, userSpecifiedSchema) + def newWriteSupport(): BatchWriteSupport = source.createWriteSupport(options, schema) - def newWriter(): DataSourceWriter = source.createWriter(options, schema) - - override def computeStats(): Statistics = newReader match { + override def computeStats(): Statistics = readSupport match { case r: SupportsReportStatistics => - Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + val statistics = r.estimateStatistics(readSupport.newScanConfigBuilder().build()) + Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) case _ => Statistics(sizeInBytes = conf.defaultSizeInBytes) } @@ -85,7 +84,8 @@ case class StreamingDataSourceV2Relation( output: Seq[AttributeReference], source: DataSourceV2, options: Map[String, String], - reader: DataSourceReader) + readSupport: ReadSupport, + scanConfigBuilder: ScanConfigBuilder) extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { override def isStreaming: Boolean = true @@ -99,7 +99,8 @@ case class StreamingDataSourceV2Relation( // TODO: unify the equal/hashCode implementation for all data source v2 query plans. override def equals(other: Any): Boolean = other match { case other: StreamingDataSourceV2Relation => - output == other.output && reader.getClass == other.reader.getClass && options == other.options + output == other.output && readSupport.getClass == other.readSupport.getClass && + options == other.options case _ => false } @@ -107,9 +108,10 @@ case class StreamingDataSourceV2Relation( Seq(output, source, options).hashCode() } - override def computeStats(): Statistics = reader match { + override def computeStats(): Statistics = readSupport match { case r: SupportsReportStatistics => - Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + val statistics = r.estimateStatistics(scanConfigBuilder.build()) + Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) case _ => Statistics(sizeInBytes = conf.defaultSizeInBytes) } @@ -117,19 +119,19 @@ case class StreamingDataSourceV2Relation( object DataSourceV2Relation { private implicit class SourceHelpers(source: DataSourceV2) { - def asReadSupport: ReadSupport = { + def asReadSupportProvider: BatchReadSupportProvider = { source match { - case support: ReadSupport => - support + case provider: BatchReadSupportProvider => + provider case _ => throw new AnalysisException(s"Data source is not readable: $name") } } - def asWriteSupport: WriteSupport = { + def asWriteSupportProvider: BatchWriteSupportProvider = { source match { - case support: WriteSupport => - support + case provider: BatchWriteSupportProvider => + provider case _ => throw new AnalysisException(s"Data source is not writable: $name") } @@ -144,23 +146,26 @@ object DataSourceV2Relation { } } - def createReader( + def createReadSupport( options: Map[String, String], - userSpecifiedSchema: Option[StructType]): DataSourceReader = { + userSpecifiedSchema: Option[StructType]): BatchReadSupport = { val v2Options = new DataSourceOptions(options.asJava) userSpecifiedSchema match { case Some(s) => - asReadSupport.createReader(s, v2Options) + asReadSupportProvider.createBatchReadSupport(s, v2Options) case _ => - asReadSupport.createReader(v2Options) + asReadSupportProvider.createBatchReadSupport(v2Options) } } - def createWriter( + def createWriteSupport( options: Map[String, String], - schema: StructType): DataSourceWriter = { - val v2Options = new DataSourceOptions(options.asJava) - asWriteSupport.createWriter(UUID.randomUUID.toString, schema, SaveMode.Append, v2Options).get + schema: StructType): BatchWriteSupport = { + asWriteSupportProvider.createBatchWriteSupport( + UUID.randomUUID().toString, + schema, + SaveMode.Append, + new DataSourceOptions(options.asJava)).get } } @@ -169,15 +174,16 @@ object DataSourceV2Relation { options: Map[String, String], tableIdent: Option[TableIdentifier] = None, userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { - val reader = source.createReader(options, userSpecifiedSchema) + val readSupport = source.createReadSupport(options, userSpecifiedSchema) + val output = readSupport.fullSchema().toAttributes val ident = tableIdent.orElse(tableFromOptions(options)) DataSourceV2Relation( - source, reader.readSchema().toAttributes, options, ident, userSpecifiedSchema) + source, readSupport, output, options, ident, userSpecifiedSchema) } private def tableFromOptions(options: Map[String, String]): Option[TableIdentifier] = { options - .get(DataSourceOptions.TABLE_KEY) - .map(TableIdentifier(_, options.get(DataSourceOptions.DATABASE_KEY))) + .get(DataSourceOptions.TABLE_KEY) + .map(TableIdentifier(_, options.get(DataSourceOptions.DATABASE_KEY))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index c8494f97f1761..04a97735d024d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.datasources.v2 -import scala.collection.JavaConverters._ - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -28,8 +26,7 @@ import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeSta import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader -import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReaderFactory, ContinuousReadSupport, MicroBatchReadSupport} /** * Physical plan node for scanning data from a data source. @@ -39,7 +36,8 @@ case class DataSourceV2ScanExec( @transient source: DataSourceV2, @transient options: Map[String, String], @transient pushedFilters: Seq[Expression], - @transient reader: DataSourceReader) + @transient readSupport: ReadSupport, + @transient scanConfig: ScanConfig) extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan { override def simpleString: String = "ScanV2 " + metadataString @@ -47,7 +45,8 @@ case class DataSourceV2ScanExec( // TODO: unify the equal/hashCode implementation for all data source v2 query plans. override def equals(other: Any): Boolean = other match { case other: DataSourceV2ScanExec => - output == other.output && reader.getClass == other.reader.getClass && options == other.options + output == other.output && readSupport.getClass == other.readSupport.getClass && + options == other.options case _ => false } @@ -55,36 +54,39 @@ case class DataSourceV2ScanExec( Seq(output, source, options).hashCode() } - override def outputPartitioning: physical.Partitioning = reader match { - case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchPartitions.size == 1 => - SinglePartition - - case r: SupportsScanColumnarBatch if !r.enableBatchRead() && partitions.size == 1 => - SinglePartition - - case r if !r.isInstanceOf[SupportsScanColumnarBatch] && partitions.size == 1 => + override def outputPartitioning: physical.Partitioning = readSupport match { + case _ if partitions.length == 1 => SinglePartition case s: SupportsReportPartitioning => new DataSourcePartitioning( - s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name))) + s.outputPartitioning(scanConfig), AttributeMap(output.map(a => a -> a.name))) case _ => super.outputPartitioning } - private lazy val partitions: Seq[InputPartition[InternalRow]] = { - reader.planInputPartitions().asScala + private lazy val partitions: Seq[InputPartition] = readSupport.planInputPartitions(scanConfig) + + private lazy val readerFactory = readSupport match { + case r: BatchReadSupport => r.createReaderFactory(scanConfig) + case r: MicroBatchReadSupport => r.createReaderFactory(scanConfig) + case r: ContinuousReadSupport => r.createContinuousReaderFactory(scanConfig) + case _ => throw new IllegalStateException("unknown read support: " + readSupport) } - private lazy val batchPartitions: Seq[InputPartition[ColumnarBatch]] = reader match { - case r: SupportsScanColumnarBatch if r.enableBatchRead() => - assert(!reader.isInstanceOf[ContinuousReader], - "continuous stream reader does not support columnar read yet.") - r.planBatchInputPartitions().asScala + // TODO: clean this up when we have dedicated scan plan for continuous streaming. + override val supportsBatch: Boolean = { + require(partitions.forall(readerFactory.supportColumnarReads) || + !partitions.exists(readerFactory.supportColumnarReads), + "Cannot mix row-based and columnar input partitions.") + + partitions.exists(readerFactory.supportColumnarReads) } - private lazy val inputRDD: RDD[InternalRow] = reader match { - case _: ContinuousReader => + private lazy val inputRDD: RDD[InternalRow] = readSupport match { + case _: ContinuousReadSupport => + assert(!supportsBatch, + "continuous stream reader does not support columnar read yet.") EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) @@ -93,22 +95,17 @@ case class DataSourceV2ScanExec( sparkContext, sqlContext.conf.continuousStreamingExecutorQueueSize, sqlContext.conf.continuousStreamingExecutorPollIntervalMs, - partitions).asInstanceOf[RDD[InternalRow]] - - case r: SupportsScanColumnarBatch if r.enableBatchRead() => - new DataSourceRDD(sparkContext, batchPartitions).asInstanceOf[RDD[InternalRow]] + partitions, + schema, + readerFactory.asInstanceOf[ContinuousPartitionReaderFactory]) case _ => - new DataSourceRDD(sparkContext, partitions).asInstanceOf[RDD[InternalRow]] + new DataSourceRDD( + sparkContext, partitions, readerFactory.asInstanceOf[PartitionReaderFactory], supportsBatch) } override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) - override val supportsBatch: Boolean = reader match { - case r: SupportsScanColumnarBatch if r.enableBatchRead() => true - case _ => false - } - override protected def needsUnsafeRowConversion: Boolean = false override protected def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 6daaa4c65c335..fe713ff6c7850 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -26,8 +26,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Rep import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} -import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns} -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport object DataSourceV2Strategy extends Strategy { @@ -37,9 +37,9 @@ object DataSourceV2Strategy extends Strategy { * @return pushed filter and post-scan filters. */ private def pushFilters( - reader: DataSourceReader, + configBuilder: ScanConfigBuilder, filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { - reader match { + configBuilder match { case r: SupportsPushDownCatalystFilters => val postScanFilters = r.pushCatalystFilters(filters.toArray) val pushedFilters = r.pushedCatalystFilters() @@ -76,41 +76,43 @@ object DataSourceV2Strategy extends Strategy { /** * Applies column pruning to the data source, w.r.t. the references of the given expressions. * - * @return new output attributes after column pruning. + * @return the created `ScanConfig`(since column pruning is the last step of operator pushdown), + * and new output attributes after column pruning. */ // TODO: nested column pruning. private def pruneColumns( - reader: DataSourceReader, + configBuilder: ScanConfigBuilder, relation: DataSourceV2Relation, - exprs: Seq[Expression]): Seq[AttributeReference] = { - reader match { + exprs: Seq[Expression]): (ScanConfig, Seq[AttributeReference]) = { + configBuilder match { case r: SupportsPushDownRequiredColumns => val requiredColumns = AttributeSet(exprs.flatMap(_.references)) val neededOutput = relation.output.filter(requiredColumns.contains) if (neededOutput != relation.output) { r.pruneColumns(neededOutput.toStructType) + val config = r.build() val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap - r.readSchema().toAttributes.map { + config -> config.readSchema().toAttributes.map { // We have to keep the attribute id during transformation. a => a.withExprId(nameToAttr(a.name).exprId) } } else { - relation.output + r.build() -> relation.output } - case _ => relation.output + case _ => configBuilder.build() -> relation.output } } override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => - val reader = relation.newReader() + val configBuilder = relation.readSupport.newScanConfigBuilder() // `pushedFilters` will be pushed down and evaluated in the underlying data sources. // `postScanFilters` need to be evaluated after the scan. // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. - val (pushedFilters, postScanFilters) = pushFilters(reader, filters) - val output = pruneColumns(reader, relation, project ++ postScanFilters) + val (pushedFilters, postScanFilters) = pushFilters(configBuilder, filters) + val (config, output) = pruneColumns(configBuilder, relation, project ++ postScanFilters) logInfo( s""" |Pushing operators to ${relation.source.getClass} @@ -120,7 +122,12 @@ object DataSourceV2Strategy extends Strategy { """.stripMargin) val scan = DataSourceV2ScanExec( - output, relation.source, relation.options, pushedFilters, reader) + output, + relation.source, + relation.options, + pushedFilters, + relation.readSupport, + config) val filterCondition = postScanFilters.reduceLeftOption(And) val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) @@ -129,22 +136,26 @@ object DataSourceV2Strategy extends Strategy { ProjectExec(project, withFilter) :: Nil case r: StreamingDataSourceV2Relation => + // TODO: support operator pushdown for streaming data sources. + val scanConfig = r.scanConfigBuilder.build() // ensure there is a projection, which will produce unsafe rows required by some operators ProjectExec(r.output, - DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader)) :: Nil + DataSourceV2ScanExec( + r.output, r.source, r.options, r.pushedFilters, r.readSupport, scanConfig)) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil case AppendData(r: DataSourceV2Relation, query, _) => - WriteToDataSourceV2Exec(r.newWriter(), planLater(query)) :: Nil + WriteToDataSourceV2Exec(r.newWriteSupport(), planLater(query)) :: Nil case WriteToContinuousDataSource(writer, query) => WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil case Repartition(1, false, child) => - val isContinuous = child.collectFirst { - case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r + val isContinuous = child.find { + case s: StreamingDataSourceV2Relation => s.readSupport.isInstanceOf[ContinuousReadSupport] + case _ => false }.isDefined if (isContinuous) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index 5267f5f1580c3..e9cc3991155c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -21,6 +21,7 @@ import java.util.regex.Pattern import org.apache.spark.internal.Logging import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceV2, SessionConfigSupport} private[sql] object DataSourceV2Utils extends Logging { @@ -55,4 +56,12 @@ private[sql] object DataSourceV2Utils extends Logging { case _ => Map.empty } + + def failForUserSpecifiedSchema[T](ds: DataSourceV2): T = { + val name = ds match { + case register: DataSourceRegister => register.shortName() + case _ => ds.getClass.getName + } + throw new UnsupportedOperationException(name + " source does not support user-specified schema") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 59ebb9bc5431b..c3f7b690ef636 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -23,15 +23,11 @@ import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.MicroBatchExecution import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils /** @@ -39,7 +35,8 @@ import org.apache.spark.util.Utils * specific logical plans, like [[org.apache.spark.sql.catalyst.plans.logical.AppendData]]. */ @deprecated("Use specific logical plans like AppendData instead", "2.4.0") -case class WriteToDataSourceV2(writer: DataSourceWriter, query: LogicalPlan) extends LogicalPlan { +case class WriteToDataSourceV2(writeSupport: BatchWriteSupport, query: LogicalPlan) + extends LogicalPlan { override def children: Seq[LogicalPlan] = Seq(query) override def output: Seq[Attribute] = Nil } @@ -47,46 +44,48 @@ case class WriteToDataSourceV2(writer: DataSourceWriter, query: LogicalPlan) ext /** * The physical plan for writing data into data source v2. */ -case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) extends SparkPlan { +case class WriteToDataSourceV2Exec(writeSupport: BatchWriteSupport, query: SparkPlan) + extends SparkPlan { + override def children: Seq[SparkPlan] = Seq(query) override def output: Seq[Attribute] = Nil override protected def doExecute(): RDD[InternalRow] = { - val writeTask = writer.createWriterFactory() - val useCommitCoordinator = writer.useCommitCoordinator + val writerFactory = writeSupport.createBatchWriterFactory() + val useCommitCoordinator = writeSupport.useCommitCoordinator val rdd = query.execute() val messages = new Array[WriterCommitMessage](rdd.partitions.length) - logInfo(s"Start processing data source writer: $writer. " + + logInfo(s"Start processing data source write support: $writeSupport. " + s"The input RDD has ${messages.length} partitions.") try { sparkContext.runJob( rdd, (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator), + DataWritingSparkTask.run(writerFactory, context, iter, useCommitCoordinator), rdd.partitions.indices, (index, message: WriterCommitMessage) => { messages(index) = message - writer.onDataWriterCommit(message) + writeSupport.onDataWriterCommit(message) } ) - logInfo(s"Data source writer $writer is committing.") - writer.commit(messages) - logInfo(s"Data source writer $writer committed.") + logInfo(s"Data source write support $writeSupport is committing.") + writeSupport.commit(messages) + logInfo(s"Data source write support $writeSupport committed.") } catch { case cause: Throwable => - logError(s"Data source writer $writer is aborting.") + logError(s"Data source write support $writeSupport is aborting.") try { - writer.abort(messages) + writeSupport.abort(messages) } catch { case t: Throwable => - logError(s"Data source writer $writer failed to abort.") + logError(s"Data source write support $writeSupport failed to abort.") cause.addSuppressed(t) throw new SparkException("Writing job failed.", cause) } - logError(s"Data source writer $writer aborted.") + logError(s"Data source write support $writeSupport aborted.") cause match { // Only wrap non fatal exceptions. case NonFatal(e) => throw new SparkException("Writing job aborted.", e) @@ -100,7 +99,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e object DataWritingSparkTask extends Logging { def run( - writeTask: DataWriterFactory[InternalRow], + writerFactory: DataWriterFactory, context: TaskContext, iter: Iterator[InternalRow], useCommitCoordinator: Boolean): WriterCommitMessage = { @@ -109,8 +108,7 @@ object DataWritingSparkTask extends Logging { val partId = context.partitionId() val taskId = context.taskAttemptId() val attemptId = context.attemptNumber() - val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0") - val dataWriter = writeTask.createDataWriter(partId, taskId, epochId.toLong) + val dataWriter = writerFactory.createWriter(partId, taskId) // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index b1c91ac94b268..cf83ba7436d17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.streaming -import java.util.Optional - import scala.collection.JavaConverters._ import scala.collection.mutable.{Map => MutableMap} @@ -28,9 +26,9 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} -import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} +import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWritSupport, RateControlMicroBatchReadSupport} +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} @@ -51,8 +49,8 @@ class MicroBatchExecution( @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty - private val readerToDataSourceMap = - MutableMap.empty[MicroBatchReader, (DataSourceV2, Map[String, String])] + private val readSupportToDataSourceMap = + MutableMap.empty[MicroBatchReadSupport, (DataSourceV2, Map[String, String])] private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) @@ -91,20 +89,19 @@ class MicroBatchExecution( StreamingExecutionRelation(source, output)(sparkSession) }) case s @ StreamingRelationV2( - dataSourceV2: MicroBatchReadSupport, sourceName, options, output, _) if + dataSourceV2: MicroBatchReadSupportProvider, sourceName, options, output, _) if !disabledSources.contains(dataSourceV2.getClass.getCanonicalName) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - val reader = dataSourceV2.createMicroBatchReader( - Optional.empty(), // user specified schema + val readSupport = dataSourceV2.createMicroBatchReadSupport( metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 - readerToDataSourceMap(reader) = dataSourceV2 -> options - logInfo(s"Using MicroBatchReader [$reader] from " + + readSupportToDataSourceMap(readSupport) = dataSourceV2 -> options + logInfo(s"Using MicroBatchReadSupport [$readSupport] from " + s"DataSourceV2 named '$sourceName' [$dataSourceV2]") - StreamingExecutionRelation(reader, output)(sparkSession) + StreamingExecutionRelation(readSupport, output)(sparkSession) }) case s @ StreamingRelationV2(dataSourceV2, sourceName, _, output, v1Relation) => v2ToExecutionRelationMap.getOrElseUpdate(s, { @@ -340,19 +337,19 @@ class MicroBatchExecution( reportTimeTaken("getOffset") { (s, s.getOffset) } - case s: MicroBatchReader => + case s: RateControlMicroBatchReadSupport => updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("setOffsetRange") { - // Once v1 streaming source execution is gone, we can refactor this away. - // For now, we set the range here to get the source to infer the available end offset, - // get that offset, and then set the range again when we later execute. - s.setOffsetRange( - toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), - Optional.empty()) + reportTimeTaken("latestOffset") { + val startOffset = availableOffsets + .get(s).map(off => s.deserializeOffset(off.json)) + .getOrElse(s.initialOffset()) + (s, Option(s.latestOffset(startOffset))) + } + case s: MicroBatchReadSupport => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("latestOffset") { + (s, Option(s.latestOffset())) } - - val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() } - (s, Option(currentOffset)) }.toMap availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) @@ -392,8 +389,8 @@ class MicroBatchExecution( if (prevBatchOff.isDefined) { prevBatchOff.get.toStreamProgress(sources).foreach { case (src: Source, off) => src.commit(off) - case (reader: MicroBatchReader, off) => - reader.commit(reader.deserializeOffset(off.json)) + case (readSupport: MicroBatchReadSupport, off) => + readSupport.commit(readSupport.deserializeOffset(off.json)) case (src, _) => throw new IllegalArgumentException( s"Unknown source is found at constructNextBatch: $src") @@ -437,30 +434,34 @@ class MicroBatchExecution( s"${batch.queryExecution.logical}") logDebug(s"Retrieving data from $source: $current -> $available") Some(source -> batch.logicalPlan) - case (reader: MicroBatchReader, available) - if committedOffsets.get(reader).map(_ != available).getOrElse(true) => - val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json)) - val availableV2: OffsetV2 = available match { - case v1: SerializedOffset => reader.deserializeOffset(v1.json) + + // TODO(cloud-fan): for data source v2, the new batch is just a new `ScanConfigBuilder`, but + // to be compatible with streaming source v1, we return a logical plan as a new batch here. + case (readSupport: MicroBatchReadSupport, available) + if committedOffsets.get(readSupport).map(_ != available).getOrElse(true) => + val current = committedOffsets.get(readSupport).map { + off => readSupport.deserializeOffset(off.json) + } + val endOffset: OffsetV2 = available match { + case v1: SerializedOffset => readSupport.deserializeOffset(v1.json) case v2: OffsetV2 => v2 } - reader.setOffsetRange( - toJava(current), - Optional.of(availableV2)) - logDebug(s"Retrieving data from $reader: $current -> $availableV2") + val startOffset = current.getOrElse(readSupport.initialOffset) + val scanConfigBuilder = readSupport.newScanConfigBuilder(startOffset, endOffset) + logDebug(s"Retrieving data from $readSupport: $current -> $endOffset") - val (source, options) = reader match { + val (source, options) = readSupport match { // `MemoryStream` is special. It's for test only and doesn't have a `DataSourceV2` // implementation. We provide a fake one here for explain. case _: MemoryStream[_] => MemoryStreamDataSource -> Map.empty[String, String] // Provide a fake value here just in case something went wrong, e.g. the reader gives // a wrong `equals` implementation. - case _ => readerToDataSourceMap.getOrElse(reader, { + case _ => readSupportToDataSourceMap.getOrElse(readSupport, { FakeDataSourceV2 -> Map.empty[String, String] }) } - Some(reader -> StreamingDataSourceV2Relation( - reader.readSchema().toAttributes, source, options, reader)) + Some(readSupport -> StreamingDataSourceV2Relation( + readSupport.fullSchema().toAttributes, source, options, readSupport, scanConfigBuilder)) case _ => None } } @@ -494,13 +495,13 @@ class MicroBatchExecution( val triggerLogicalPlan = sink match { case _: Sink => newAttributePlan - case s: StreamWriteSupport => - val writer = s.createStreamWriter( + case s: StreamingWriteSupportProvider => + val writer = s.createStreamingWriteSupport( s"$runId", newAttributePlan.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) - WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan) + WriteToDataSourceV2(new MicroBatchWritSupport(currentBatchId, writer), newAttributePlan) case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } @@ -526,7 +527,7 @@ class MicroBatchExecution( SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { sink match { case s: Sink => s.addBatch(currentBatchId, nextBatch) - case _: StreamWriteSupport => + case _: StreamingWriteSupportProvider => // This doesn't accumulate any data - it just forces execution of the microbatch writer. nextBatch.collect() } @@ -551,10 +552,6 @@ class MicroBatchExecution( awaitProgressLock.unlock() } } - - private def toJava(scalaOption: Option[OffsetV2]): Optional[OffsetV2] = { - Optional.ofNullable(scalaOption.orNull) - } } object MicroBatchExecution { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index ae1bfa2e499bb..417b6b39366ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.control.NonFatal -import org.json4s.JsonAST.JValue import org.json4s.jackson.JsonMethods.parse import org.apache.spark.internal.Logging @@ -33,11 +32,10 @@ import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalP import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} -import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter +import org.apache.spark.sql.execution.streaming.sources.MicroBatchWritSupport import org.apache.spark.sql.sources.v2.CustomMetrics -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, SupportsCustomReaderMetrics} -import org.apache.spark.sql.sources.v2.writer.DataSourceWriter -import org.apache.spark.sql.sources.v2.writer.streaming.SupportsCustomWriterMetrics +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, SupportsCustomReaderMetrics} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingWriteSupport, SupportsCustomWriterMetrics} import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent import org.apache.spark.util.Clock @@ -201,7 +199,7 @@ trait ProgressReporter extends Logging { ) } - val customWriterMetrics = dataSourceWriter match { + val customWriterMetrics = extractWriteSupport() match { case Some(s: SupportsCustomWriterMetrics) => extractMetrics(() => Option(s.getCustomMetrics), s.onInvalidMetrics) @@ -238,13 +236,13 @@ trait ProgressReporter extends Logging { } /** Extract writer from the executed query plan. */ - private def dataSourceWriter: Option[DataSourceWriter] = { + private def extractWriteSupport(): Option[StreamingWriteSupport] = { if (lastExecution == null) return None lastExecution.executedPlan.collect { case p if p.isInstanceOf[WriteToDataSourceV2Exec] => - p.asInstanceOf[WriteToDataSourceV2Exec].writer + p.asInstanceOf[WriteToDataSourceV2Exec].writeSupport }.headOption match { - case Some(w: MicroBatchWriter) => Some(w.writer) + case Some(w: MicroBatchWritSupport) => Some(w.writeSupport) case _ => None } } @@ -303,7 +301,7 @@ trait ProgressReporter extends Logging { // Check whether the streaming query's logical plan has only V2 data sources val allStreamingLeaves = logicalPlan.collect { case s: StreamingExecutionRelation => s } - allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReader] } + allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReadSupport] } } if (onlyDataSourceV2Sources) { @@ -330,7 +328,7 @@ trait ProgressReporter extends Logging { new IdentityHashMap[DataSourceV2ScanExec, DataSourceV2ScanExec]() lastExecution.executedPlan.collectLeaves().foreach { - case s: DataSourceV2ScanExec if s.reader.isInstanceOf[BaseStreamingSource] => + case s: DataSourceV2ScanExec if s.readSupport.isInstanceOf[BaseStreamingSource] => uniqueStreamingExecLeavesMap.put(s, s) case _ => } @@ -338,7 +336,7 @@ trait ProgressReporter extends Logging { val sourceToInputRowsTuples = uniqueStreamingExecLeavesMap.values.asScala.map { execLeaf => val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) - val source = execLeaf.reader.asInstanceOf[BaseStreamingSource] + val source = execLeaf.readSupport.asInstanceOf[BaseStreamingSource] source -> numRows }.toSeq logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala new file mode 100644 index 0000000000000..1be071614d92e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.sources.v2.reader.{ScanConfig, ScanConfigBuilder} +import org.apache.spark.sql.types.StructType + +/** + * A very simple [[ScanConfigBuilder]] implementation that creates a simple [[ScanConfig]] to + * carry schema and offsets for streaming data sources. + */ +class SimpleStreamingScanConfigBuilder( + schema: StructType, + start: Offset, + end: Option[Offset] = None) + extends ScanConfigBuilder { + + override def build(): ScanConfig = SimpleStreamingScanConfig(schema, start, end) +} + +case class SimpleStreamingScanConfig( + readSchema: StructType, + start: Offset, + end: Option[Offset]) + extends ScanConfig diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 24195b5657e8a..4b696dfa57359 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceV2} object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { @@ -83,7 +83,7 @@ case class StreamingExecutionRelation( // We have to pack in the V1 data source as a shim, for the case when a source implements // continuous processing (which is always V2) but only has V1 microbatch support. We don't -// know at read time whether the query is conntinuous or not, so we need to be able to +// know at read time whether the query is continuous or not, so we need to be able to // swap a V1 relation back in. /** * Used to link a [[DataSourceV2]] into a streaming @@ -113,7 +113,7 @@ case class StreamingRelationV2( * Used to link a [[DataSourceV2]] into a continuous processing execution. */ case class ContinuousExecutionRelation( - source: ContinuousReadSupport, + source: ContinuousReadSupportProvider, extraOptions: Map[String, String], output: Seq[Attribute])(session: SparkSession) extends LeafNode with MultiInstanceRelation { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index cfba1001c6de0..9c5c16f4f5d13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter +import org.apache.spark.sql.execution.streaming.sources.ConsoleWriteSupport import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -31,16 +31,16 @@ case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) } class ConsoleSinkProvider extends DataSourceV2 - with StreamWriteSupport + with StreamingWriteSupportProvider with DataSourceRegister with CreatableRelationProvider { - override def createStreamWriter( + override def createStreamingWriteSupport( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamWriter = { - new ConsoleWriter(schema, options) + options: DataSourceOptions): StreamingWriteSupport = { + new ConsoleWriteSupport(schema, options) } def createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala index 554a0b0573f4d..b68f67e0b22d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -21,12 +21,13 @@ import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousPartitionReaderFactory +import org.apache.spark.sql.types.StructType import org.apache.spark.util.NextIterator class ContinuousDataSourceRDDPartition( val index: Int, - val inputPartition: InputPartition[InternalRow]) + val inputPartition: InputPartition) extends Partition with Serializable { // This is semantically a lazy val - it's initialized once the first time a call to @@ -49,15 +50,22 @@ class ContinuousDataSourceRDD( sc: SparkContext, dataQueueSize: Int, epochPollIntervalMs: Long, - private val readerInputPartitions: Seq[InputPartition[InternalRow]]) + private val inputPartitions: Seq[InputPartition], + schema: StructType, + partitionReaderFactory: ContinuousPartitionReaderFactory) extends RDD[InternalRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { - readerInputPartitions.zipWithIndex.map { + inputPartitions.zipWithIndex.map { case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition) }.toArray } + private def castPartition(split: Partition): ContinuousDataSourceRDDPartition = split match { + case p: ContinuousDataSourceRDDPartition => p + case _ => throw new SparkException(s"[BUG] Not a ContinuousDataSourceRDDPartition: $split") + } + /** * Initialize the shared reader for this partition if needed, then read rows from it until * it returns null to signal the end of the epoch. @@ -69,10 +77,12 @@ class ContinuousDataSourceRDD( } val readerForPartition = { - val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition] + val partition = castPartition(split) if (partition.queueReader == null) { - partition.queueReader = - new ContinuousQueuedDataReader(partition, context, dataQueueSize, epochPollIntervalMs) + val partitionReader = partitionReaderFactory.createReader( + partition.inputPartition) + partition.queueReader = new ContinuousQueuedDataReader( + partition.index, partitionReader, schema, context, dataQueueSize, epochPollIntervalMs) } partition.queueReader @@ -93,17 +103,6 @@ class ContinuousDataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[ContinuousDataSourceRDDPartition].inputPartition.preferredLocations() - } -} - -object ContinuousDataSourceRDD { - private[continuous] def getContinuousReader( - reader: InputPartitionReader[InternalRow]): ContinuousInputPartitionReader[_] = { - reader match { - case r: ContinuousInputPartitionReader[InternalRow] => r - case _ => - throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") - } + castPartition(split).inputPartition.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 140cec64fffb2..4ddebb33b79d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -29,13 +29,12 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, StreamingDataSourceV2Relation} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2 -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, StreamingWriteSupportProvider} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} -import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, Utils} class ContinuousExecution( @@ -43,7 +42,7 @@ class ContinuousExecution( name: String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: StreamWriteSupport, + sink: StreamingWriteSupportProvider, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, @@ -53,7 +52,7 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReader] = Seq() + @volatile protected var continuousSources: Seq[ContinuousReadSupport] = Seq() override protected def sources: Seq[BaseStreamingSource] = continuousSources // For use only in test harnesses. @@ -63,7 +62,8 @@ class ContinuousExecution( val toExecutionRelationMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]() analyzedPlan.transform { case r @ StreamingRelationV2( - source: ContinuousReadSupport, _, extraReaderOptions, output, _) => + source: ContinuousReadSupportProvider, _, extraReaderOptions, output, _) => + // TODO: shall we create `ContinuousReadSupport` here instead of each reconfiguration? toExecutionRelationMap.getOrElseUpdate(r, { ContinuousExecutionRelation(source, extraReaderOptions, output)(sparkSession) }) @@ -148,8 +148,7 @@ class ContinuousExecution( val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" nextSourceId += 1 - dataSource.createContinuousReader( - java.util.Optional.empty[StructType](), + dataSource.createContinuousReadSupport( metadataPath, new DataSourceOptions(extraReaderOptions.asJava)) } @@ -160,9 +159,9 @@ class ContinuousExecution( var insertedSourceId = 0 val withNewSources = logicalPlan transform { case ContinuousExecutionRelation(source, options, output) => - val reader = continuousSources(insertedSourceId) + val readSupport = continuousSources(insertedSourceId) insertedSourceId += 1 - val newOutput = reader.readSchema().toAttributes + val newOutput = readSupport.fullSchema().toAttributes assert(output.size == newOutput.size, s"Invalid reader: ${Utils.truncatedString(output, ",")} != " + @@ -170,9 +169,10 @@ class ContinuousExecution( replacements ++= output.zip(newOutput) val loggedOffset = offsets.offsets(0) - val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) - reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) - StreamingDataSourceV2Relation(newOutput, source, options, reader) + val realOffset = loggedOffset.map(off => readSupport.deserializeOffset(off.json)) + val startOffset = realOffset.getOrElse(readSupport.initialOffset) + val scanConfigBuilder = readSupport.newScanConfigBuilder(startOffset) + StreamingDataSourceV2Relation(newOutput, source, options, readSupport, scanConfigBuilder) } // Rewire the plan to use the new attributes that were returned by the source. @@ -185,17 +185,13 @@ class ContinuousExecution( "CurrentTimestamp and CurrentDate not yet supported for continuous processing") } - val writer = sink.createStreamWriter( + val writer = sink.createStreamingWriteSupport( s"$runId", triggerLogicalPlan.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) val withSink = WriteToContinuousDataSource(writer, triggerLogicalPlan) - val reader = withSink.collect { - case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r - }.head - reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( sparkSessionForQuery, @@ -208,6 +204,11 @@ class ContinuousExecution( lastExecution.executedPlan // Force the lazy generation of execution plan } + val (readSupport, scanConfig) = lastExecution.executedPlan.collect { + case scan: DataSourceV2ScanExec if scan.readSupport.isInstanceOf[ContinuousReadSupport] => + scan.readSupport.asInstanceOf[ContinuousReadSupport] -> scan.scanConfig + }.head + sparkSessionForQuery.sparkContext.setLocalProperty( ContinuousExecution.START_EPOCH_KEY, currentBatchId.toString) // Add another random ID on top of the run ID, to distinguish epoch coordinators across @@ -223,14 +224,16 @@ class ContinuousExecution( // Use the parent Spark session for the endpoint since it's where this query ID is registered. val epochEndpoint = EpochCoordinatorRef.create( - writer, reader, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) + writer, readSupport, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) val epochUpdateThread = new Thread(new Runnable { override def run: Unit = { try { triggerExecutor.execute(() => { startTrigger() - if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { + val shouldReconfigure = readSupport.needsReconfiguration(scanConfig) && + state.compareAndSet(ACTIVE, RECONFIGURING) + if (shouldReconfigure) { if (queryExecutionThread.isAlive) { queryExecutionThread.interrupt() } @@ -280,10 +283,12 @@ class ContinuousExecution( * Report ending partition offsets for the given reader at the given epoch. */ def addOffset( - epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = { + epoch: Long, + readSupport: ContinuousReadSupport, + partitionOffsets: Seq[PartitionOffset]): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") - val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) + val globalOffset = readSupport.mergeOffsets(partitionOffsets.toArray) val oldOffset = synchronized { offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) offsetLog.get(epoch - 1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala index ec1dabd7da3e9..65c5fc63c2f46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -25,8 +25,9 @@ import scala.util.control.NonFatal import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} -import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, PartitionOffset} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.ThreadUtils /** @@ -37,15 +38,14 @@ import org.apache.spark.util.ThreadUtils * offsets across epochs. Each compute() should call the next() method here until null is returned. */ class ContinuousQueuedDataReader( - partition: ContinuousDataSourceRDDPartition, + partitionIndex: Int, + reader: ContinuousPartitionReader[InternalRow], + schema: StructType, context: TaskContext, dataQueueSize: Int, epochPollIntervalMs: Long) extends Closeable { - private val reader = partition.inputPartition.createPartitionReader() - // Important sequencing - we must get our starting point before the provider threads start running - private var currentOffset: PartitionOffset = - ContinuousDataSourceRDD.getContinuousReader(reader).getOffset + private var currentOffset: PartitionOffset = reader.getOffset /** * The record types in the read buffer. @@ -66,7 +66,7 @@ class ContinuousQueuedDataReader( epochMarkerExecutor.scheduleWithFixedDelay( epochMarkerGenerator, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) - private val dataReaderThread = new DataReaderThread + private val dataReaderThread = new DataReaderThread(schema) dataReaderThread.setDaemon(true) dataReaderThread.start() @@ -113,7 +113,7 @@ class ContinuousQueuedDataReader( currentEntry match { case EpochMarker => epochCoordEndpoint.send(ReportPartitionOffset( - partition.index, EpochTracker.getCurrentEpoch.get, currentOffset)) + partitionIndex, EpochTracker.getCurrentEpoch.get, currentOffset)) null case ContinuousRow(row, offset) => currentOffset = offset @@ -128,16 +128,16 @@ class ContinuousQueuedDataReader( /** * The data component of [[ContinuousQueuedDataReader]]. Pushes (row, offset) to the queue when - * a new row arrives to the [[InputPartitionReader]]. + * a new row arrives to the [[ContinuousPartitionReader]]. */ - class DataReaderThread extends Thread( + class DataReaderThread(schema: StructType) extends Thread( s"continuous-reader--${context.partitionId()}--" + s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") with Logging { @volatile private[continuous] var failureReason: Throwable = _ + private val toUnsafe = UnsafeProjection.create(schema) override def run(): Unit = { TaskContext.setTaskContext(context) - val baseReader = ContinuousDataSourceRDD.getContinuousReader(reader) try { while (!shouldStop()) { if (!reader.next()) { @@ -149,8 +149,9 @@ class ContinuousQueuedDataReader( return } } - - queue.put(ContinuousRow(reader.get().copy(), baseReader.getOffset)) + // `InternalRow#copy` may not be properly implemented, for safety we convert to unsafe row + // before copy here. + queue.put(ContinuousRow(toUnsafe(reader.get()).copy(), reader.getOffset)) } } catch { case _: InterruptedException => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 551e07c3db868..a6cde2b8a710f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -17,24 +17,22 @@ package org.apache.spark.sql.execution.streaming.continuous -import scala.collection.JavaConverters._ - import org.json4s.DefaultFormats import org.json4s.jackson.Serialization import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder, ValueRunTimeMsPair} import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming._ import org.apache.spark.sql.types.StructType case class RateStreamPartitionOffset( partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset -class RateStreamContinuousReader(options: DataSourceOptions) extends ContinuousReader { +class RateStreamContinuousReadSupport(options: DataSourceOptions) extends ContinuousReadSupport { implicit val defaultFormats: DefaultFormats = DefaultFormats val creationTime = System.currentTimeMillis() @@ -56,18 +54,18 @@ class RateStreamContinuousReader(options: DataSourceOptions) extends ContinuousR RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def readSchema(): StructType = RateStreamProvider.SCHEMA - - private var offset: Offset = _ + override def fullSchema(): StructType = RateStreamProvider.SCHEMA - override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { - this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime)) + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start) } - override def getStartOffset(): Offset = offset + override def initialOffset: Offset = createInitialOffset(numPartitions, creationTime) - override def planInputPartitions(): java.util.List[InputPartition[InternalRow]] = { - val partitionStartMap = offset match { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val startOffset = config.asInstanceOf[SimpleStreamingScanConfig].start + + val partitionStartMap = startOffset match { case off: RateStreamOffset => off.partitionToValueAndRunTimeMs case off => throw new IllegalArgumentException( @@ -90,8 +88,12 @@ class RateStreamContinuousReader(options: DataSourceOptions) extends ContinuousR i, numPartitions, perPartitionRate) - .asInstanceOf[InputPartition[InternalRow]] - }.asJava + }.toArray + } + + override def createContinuousReaderFactory( + config: ScanConfig): ContinuousPartitionReaderFactory = { + RateStreamContinuousReaderFactory } override def commit(end: Offset): Unit = {} @@ -118,33 +120,23 @@ case class RateStreamContinuousInputPartition( partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousInputPartition[InternalRow] { - - override def createContinuousReader( - offset: PartitionOffset): InputPartitionReader[InternalRow] = { - val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset] - require(rateStreamOffset.partition == partitionIndex, - s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}") - new RateStreamContinuousInputPartitionReader( - rateStreamOffset.currentValue, - rateStreamOffset.currentTimeMs, - partitionIndex, - increment, - rowsPerSecond) - } + extends InputPartition - override def createPartitionReader(): InputPartitionReader[InternalRow] = - new RateStreamContinuousInputPartitionReader( - startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) +object RateStreamContinuousReaderFactory extends ContinuousPartitionReaderFactory { + override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { + val p = partition.asInstanceOf[RateStreamContinuousInputPartition] + new RateStreamContinuousPartitionReader( + p.startValue, p.startTimeMs, p.partitionIndex, p.increment, p.rowsPerSecond) + } } -class RateStreamContinuousInputPartitionReader( +class RateStreamContinuousPartitionReader( startValue: Long, startTimeMs: Long, partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousInputPartitionReader[InternalRow] { + extends ContinuousPartitionReader[InternalRow] { private var nextReadTime: Long = startTimeMs private val readTimeIncrement: Long = (1000 / rowsPerSecond).toLong diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala index 1dbdfd558de48..28ab2448a6633 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala @@ -20,10 +20,9 @@ package org.apache.spark.sql.execution.streaming.continuous import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket import java.sql.Timestamp -import java.util.{Calendar, List => JList} +import java.util.Calendar import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import org.json4s.{DefaultFormats, NoTypeHints} @@ -34,24 +33,26 @@ import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.streaming.{ContinuousRecordEndpoint, ContinuousRecordPartitionOffset, GetRecord} +import org.apache.spark.sql.execution.streaming.{Offset => _, _} import org.apache.spark.sql.execution.streaming.sources.TextSocketReader import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsDeprecatedScanRow} -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming._ +import org.apache.spark.sql.types.StructType import org.apache.spark.util.RpcUtils /** - * A ContinuousReader that reads text lines through a TCP socket, designed only for tutorials and - * debugging. This ContinuousReader will *not* work in production applications due to multiple - * reasons, including no support for fault recovery. + * A ContinuousReadSupport that reads text lines through a TCP socket, designed only for tutorials + * and debugging. This ContinuousReadSupport will *not* work in production applications due to + * multiple reasons, including no support for fault recovery. * * The driver maintains a socket connection to the host-port, keeps the received messages in * buckets and serves the messages to the executors via a RPC endpoint. */ -class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousReader with Logging { +class TextSocketContinuousReadSupport(options: DataSourceOptions) + extends ContinuousReadSupport with Logging { + implicit val defaultFormats: DefaultFormats = DefaultFormats private val host: String = options.get("host").get() @@ -73,7 +74,8 @@ class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousR @GuardedBy("this") private var currentOffset: Int = -1 - private var startOffset: TextSocketOffset = _ + // Exposed for tests. + private[spark] var startOffset: TextSocketOffset = _ private val recordEndpoint = new ContinuousRecordEndpoint(buckets, this) @volatile private var endpointRef: RpcEndpointRef = _ @@ -94,16 +96,16 @@ class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousR TextSocketOffset(Serialization.read[List[Int]](json)) } - override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { - this.startOffset = offset - .orElse(TextSocketOffset(List.fill(numPartitions)(0))) - .asInstanceOf[TextSocketOffset] - recordEndpoint.setStartOffsets(startOffset.offsets) + override def initialOffset(): Offset = { + startOffset = TextSocketOffset(List.fill(numPartitions)(0)) + startOffset } - override def getStartOffset: Offset = startOffset + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start) + } - override def readSchema(): StructType = { + override def fullSchema(): StructType = { if (includeTimestamp) { TextSocketReader.SCHEMA_TIMESTAMP } else { @@ -111,8 +113,10 @@ class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousR } } - override def planInputPartitions(): JList[InputPartition[InternalRow]] = { - + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val startOffset = config.asInstanceOf[SimpleStreamingScanConfig] + .start.asInstanceOf[TextSocketOffset] + recordEndpoint.setStartOffsets(startOffset.offsets) val endpointName = s"TextSocketContinuousReaderEndpoint-${java.util.UUID.randomUUID()}" endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) @@ -132,10 +136,13 @@ class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousR startOffset.offsets.zipWithIndex.map { case (offset, i) => - TextSocketContinuousInputPartition( - endpointName, i, offset, includeTimestamp): InputPartition[InternalRow] - }.asJava + TextSocketContinuousInputPartition(endpointName, i, offset, includeTimestamp) + }.toArray + } + override def createContinuousReaderFactory( + config: ScanConfig): ContinuousPartitionReaderFactory = { + TextSocketReaderFactory } override def commit(end: Offset): Unit = synchronized { @@ -190,7 +197,7 @@ class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousR logWarning(s"Stream closed by $host:$port") return } - TextSocketContinuousReader.this.synchronized { + TextSocketContinuousReadSupport.this.synchronized { currentOffset += 1 val newData = (line, Timestamp.valueOf( @@ -221,25 +228,30 @@ case class TextSocketContinuousInputPartition( driverEndpointName: String, partitionId: Int, startOffset: Int, - includeTimestamp: Boolean) -extends InputPartition[InternalRow] { + includeTimestamp: Boolean) extends InputPartition + + +object TextSocketReaderFactory extends ContinuousPartitionReaderFactory { - override def createPartitionReader(): InputPartitionReader[InternalRow] = - new TextSocketContinuousInputPartitionReader(driverEndpointName, partitionId, startOffset, - includeTimestamp) + override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { + val p = partition.asInstanceOf[TextSocketContinuousInputPartition] + new TextSocketContinuousPartitionReader( + p.driverEndpointName, p.partitionId, p.startOffset, p.includeTimestamp) + } } + /** * Continuous text socket input partition reader. * * Polls the driver endpoint for new records. */ -class TextSocketContinuousInputPartitionReader( +class TextSocketContinuousPartitionReader( driverEndpointName: String, partitionId: Int, startOffset: Int, includeTimestamp: Boolean) - extends ContinuousInputPartitionReader[InternalRow] { + extends ContinuousPartitionReader[InternalRow] { private val endpoint = RpcUtils.makeDriverRef( driverEndpointName, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala index 967dbe24a3705..a08411d746abe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.streaming.continuous import org.apache.spark.{Partition, SparkEnv, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory} +import org.apache.spark.sql.sources.v2.writer.DataWriter +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingDataWriterFactory import org.apache.spark.util.Utils /** @@ -31,7 +32,7 @@ import org.apache.spark.util.Utils * * We keep repeating prev.compute() and writing new epochs until the query is shut down. */ -class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactory[InternalRow]) +class ContinuousWriteRDD(var prev: RDD[InternalRow], writerFactory: StreamingDataWriterFactory) extends RDD[Unit](prev) { override val partitioner = prev.partitioner @@ -50,7 +51,7 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor Utils.tryWithSafeFinallyAndFailureCallbacks(block = { try { val dataIterator = prev.compute(split, context) - dataWriter = writeTask.createDataWriter( + dataWriter = writerFactory.createWriter( context.partitionId(), context.taskAttemptId(), EpochTracker.getCurrentEpoch.get) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 8877ebeb26735..2238ce26e7b46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -23,9 +23,9 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.util.RpcUtils private[continuous] sealed trait EpochCoordinatorMessage extends Serializable @@ -82,15 +82,15 @@ private[sql] object EpochCoordinatorRef extends Logging { * Create a reference to a new [[EpochCoordinator]]. */ def create( - writer: StreamWriter, - reader: ContinuousReader, + writeSupport: StreamingWriteSupport, + readSupport: ContinuousReadSupport, query: ContinuousExecution, epochCoordinatorId: String, startEpoch: Long, session: SparkSession, env: SparkEnv): RpcEndpointRef = synchronized { val coordinator = new EpochCoordinator( - writer, reader, query, startEpoch, session, env.rpcEnv) + writeSupport, readSupport, query, startEpoch, session, env.rpcEnv) val ref = env.rpcEnv.setupEndpoint(endpointName(epochCoordinatorId), coordinator) logInfo("Registered EpochCoordinator endpoint") ref @@ -115,8 +115,8 @@ private[sql] object EpochCoordinatorRef extends Logging { * have both committed and reported an end offset for a given epoch. */ private[continuous] class EpochCoordinator( - writer: StreamWriter, - reader: ContinuousReader, + writeSupport: StreamingWriteSupport, + readSupport: ContinuousReadSupport, query: ContinuousExecution, startEpoch: Long, session: SparkSession, @@ -198,7 +198,7 @@ private[continuous] class EpochCoordinator( s"and is ready to be committed. Committing epoch $epoch.") // Sequencing is important here. We must commit to the writer before recording the commit // in the query, or we will end up dropping the commit if we restart in the middle. - writer.commit(epoch, messages.toArray) + writeSupport.commit(epoch, messages.toArray) query.commit(epoch) } @@ -220,7 +220,7 @@ private[continuous] class EpochCoordinator( partitionOffsets.collect { case ((e, _), o) if e == epoch => o } if (thisEpochOffsets.size == numReaderPartitions) { logDebug(s"Epoch $epoch has offsets reported from all partitions: $thisEpochOffsets") - query.addOffset(epoch, reader, thisEpochOffsets.toSeq) + query.addOffset(epoch, readSupport, thisEpochOffsets.toSeq) resolveCommitsAtEpoch(epoch) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala index 943c731a70529..7ad21cc304e7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.execution.streaming.continuous import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport /** * The logical plan for writing data in a continuous stream. */ case class WriteToContinuousDataSource( - writer: StreamWriter, query: LogicalPlan) extends LogicalPlan { + writeSupport: StreamingWriteSupport, query: LogicalPlan) extends LogicalPlan { override def children: Seq[LogicalPlan] = Seq(query) override def output: Seq[Attribute] = Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala index 927d3a84e296b..c216b61383856 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -26,21 +26,21 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport /** - * The physical plan for writing data into a continuous processing [[StreamWriter]]. + * The physical plan for writing data into a continuous processing [[StreamingWriteSupport]]. */ -case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPlan) +case class WriteToContinuousDataSourceExec(writeSupport: StreamingWriteSupport, query: SparkPlan) extends SparkPlan with Logging { override def children: Seq[SparkPlan] = Seq(query) override def output: Seq[Attribute] = Nil override protected def doExecute(): RDD[InternalRow] = { - val writerFactory = writer.createWriterFactory() + val writerFactory = writeSupport.createStreamingWriterFactory() val rdd = new ContinuousWriteRDD(query.execute(), writerFactory) - logInfo(s"Start processing data source writer: $writer. " + + logInfo(s"Start processing data source write support: $writeSupport. " + s"The input RDD has ${rdd.partitions.length} partitions.") EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index f81abdcc3711a..adf52aba21a04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -17,12 +17,9 @@ package org.apache.spark.sql.execution.streaming -import java.{util => ju} -import java.util.Optional import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.util.control.NonFatal @@ -34,8 +31,8 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -67,7 +64,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas addData(data.toTraversable) } - def readSchema(): StructType = encoder.schema + def fullSchema(): StructType = encoder.schema protected def logicalPlan: LogicalPlan @@ -80,7 +77,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas * available. */ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) - extends MemoryStreamBase[A](sqlContext) with MicroBatchReader with Logging { + extends MemoryStreamBase[A](sqlContext) with MicroBatchReadSupport with Logging { protected val logicalPlan: LogicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) @@ -122,24 +119,22 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]" - override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { - synchronized { - startOffset = start.orElse(LongOffset(-1)).asInstanceOf[LongOffset] - endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset] - } - } - override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) - override def getStartOffset: OffsetV2 = synchronized { - if (startOffset.offset == -1) null else startOffset + override def initialOffset: OffsetV2 = LongOffset(-1) + + override def latestOffset(): OffsetV2 = { + if (currentOffset.offset == -1) null else currentOffset } - override def getEndOffset: OffsetV2 = synchronized { - if (endOffset.offset == -1) null else endOffset + override def newScanConfigBuilder(start: OffsetV2, end: OffsetV2): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) } - override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val sc = config.asInstanceOf[SimpleStreamingScanConfig] + val startOffset = sc.start.asInstanceOf[LongOffset] + val endOffset = sc.end.get.asInstanceOf[LongOffset] synchronized { // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) val startOrdinal = startOffset.offset.toInt + 1 @@ -156,11 +151,15 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) newBlocks.map { block => - new MemoryStreamInputPartition(block): InputPartition[InternalRow] - }.asJava + new MemoryStreamInputPartition(block) + }.toArray } } + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + MemoryStreamReaderFactory + } + private def generateDebugString( rows: Seq[UnsafeRow], startOrdinal: Int, @@ -201,10 +200,12 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } -class MemoryStreamInputPartition(records: Array[UnsafeRow]) - extends InputPartition[InternalRow] { - override def createPartitionReader(): InputPartitionReader[InternalRow] = { - new InputPartitionReader[InternalRow] { +class MemoryStreamInputPartition(val records: Array[UnsafeRow]) extends InputPartition + +object MemoryStreamReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val records = partition.asInstanceOf[MemoryStreamInputPartition].records + new PartitionReader[InternalRow] { private var currentIndex = -1 override def next(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala similarity index 86% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala index fd45ba509091e..833e62f35ede1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala @@ -19,16 +19,15 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} import org.apache.spark.sql.types.StructType /** Common methods used to create writes for the the console sink */ -class ConsoleWriter(schema: StructType, options: DataSourceOptions) - extends StreamWriter with Logging { +class ConsoleWriteSupport(schema: StructType, options: DataSourceOptions) + extends StreamingWriteSupport with Logging { // Number of rows to display, by default 20 rows protected val numRowsToShow = options.getInt("numRows", 20) @@ -39,7 +38,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions) assert(SparkSession.getActiveSession.isDefined) protected val spark = SparkSession.getActiveSession.get - def createWriterFactory(): DataWriterFactory[InternalRow] = PackedRowWriterFactory + def createStreamingWriterFactory(): StreamingDataWriterFactory = PackedRowWriterFactory override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index 4a32217f149bd..dbcc4483e5770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -17,26 +17,22 @@ package org.apache.spark.sql.execution.streaming.sources -import java.{util => ju} -import java.util.Optional import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import org.json4s.NoTypeHints import org.json4s.jackson.Serialization import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.{Encoder, SQLContext} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions} -import org.apache.spark.sql.sources.v2.reader.InputPartition -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.execution.streaming.{Offset => _, _} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig, ScanConfigBuilder} +import org.apache.spark.sql.sources.v2.reader.streaming._ import org.apache.spark.util.RpcUtils /** @@ -48,7 +44,9 @@ import org.apache.spark.util.RpcUtils * the specified offset within the list, or null if that offset doesn't yet have a record. */ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2) - extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { + extends MemoryStreamBase[A](sqlContext) + with ContinuousReadSupportProvider with ContinuousReadSupport { + private implicit val formats = Serialization.formats(NoTypeHints) protected val logicalPlan = @@ -59,9 +57,6 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa @GuardedBy("this") private val records = Seq.fill(numPartitions)(new ListBuffer[A]) - @GuardedBy("this") - private var startOffset: ContinuousMemoryStreamOffset = _ - private val recordEndpoint = new ContinuousRecordEndpoint(records, this) @volatile private var endpointRef: RpcEndpointRef = _ @@ -75,15 +70,8 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap) } - override def setStartOffset(start: Optional[Offset]): Unit = synchronized { - // Inferred initial offset is position 0 in each partition. - startOffset = start.orElse { - ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap) - }.asInstanceOf[ContinuousMemoryStreamOffset] - } - - override def getStartOffset: Offset = synchronized { - startOffset + override def initialOffset(): Offset = { + ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap) } override def deserializeOffset(json: String): ContinuousMemoryStreamOffset = { @@ -98,34 +86,40 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa ) } - override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start) + } + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val startOffset = config.asInstanceOf[SimpleStreamingScanConfig] + .start.asInstanceOf[ContinuousMemoryStreamOffset] synchronized { val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id" endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) startOffset.partitionNums.map { - case (part, index) => - new ContinuousMemoryStreamInputPartition( - endpointName, part, index): InputPartition[InternalRow] - }.toList.asJava + case (part, index) => ContinuousMemoryStreamInputPartition(endpointName, part, index) + }.toArray } } + override def createContinuousReaderFactory( + config: ScanConfig): ContinuousPartitionReaderFactory = { + ContinuousMemoryStreamReaderFactory + } + override def stop(): Unit = { if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef) } override def commit(end: Offset): Unit = {} - // ContinuousReadSupport implementation + // ContinuousReadSupportProvider implementation // This is necessary because of how StreamTest finds the source for AddDataMemory steps. - def createContinuousReader( - schema: Optional[StructType], + override def createContinuousReadSupport( checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = { - this - } + options: DataSourceOptions): ContinuousReadSupport = this } object ContinuousMemoryStream { @@ -141,12 +135,16 @@ object ContinuousMemoryStream { /** * An input partition for continuous memory stream. */ -class ContinuousMemoryStreamInputPartition( +case class ContinuousMemoryStreamInputPartition( driverEndpointName: String, partition: Int, - startOffset: Int) extends InputPartition[InternalRow] { - override def createPartitionReader: ContinuousMemoryStreamInputPartitionReader = - new ContinuousMemoryStreamInputPartitionReader(driverEndpointName, partition, startOffset) + startOffset: Int) extends InputPartition + +object ContinuousMemoryStreamReaderFactory extends ContinuousPartitionReaderFactory { + override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { + val p = partition.asInstanceOf[ContinuousMemoryStreamInputPartition] + new ContinuousMemoryStreamPartitionReader(p.driverEndpointName, p.partition, p.startOffset) + } } /** @@ -154,10 +152,10 @@ class ContinuousMemoryStreamInputPartition( * * Polls the driver endpoint for new records. */ -class ContinuousMemoryStreamInputPartitionReader( +class ContinuousMemoryStreamPartitionReader( driverEndpointName: String, partition: Int, - startOffset: Int) extends ContinuousInputPartitionReader[InternalRow] { + startOffset: Int) extends ContinuousPartitionReader[InternalRow] { private val endpoint = RpcUtils.makeDriverRef( driverEndpointName, SparkEnv.get.conf, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala similarity index 82% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala index e8ce21cc12044..4218fd51ad206 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala @@ -22,9 +22,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.python.PythonForeachWriter -import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamingWriteSupportProvider} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -37,20 +37,21 @@ import org.apache.spark.sql.types.StructType * a [[ExpressionEncoder]] or a direct converter function. * @tparam T The expected type of the sink. */ -case class ForeachWriterProvider[T]( +case class ForeachWriteSupportProvider[T]( writer: ForeachWriter[T], - converter: Either[ExpressionEncoder[T], InternalRow => T]) extends StreamWriteSupport { + converter: Either[ExpressionEncoder[T], InternalRow => T]) + extends StreamingWriteSupportProvider { - override def createStreamWriter( + override def createStreamingWriteSupport( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamWriter = { - new StreamWriter { + options: DataSourceOptions): StreamingWriteSupport = { + new StreamingWriteSupport { override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - override def createWriterFactory(): DataWriterFactory[InternalRow] = { + override def createStreamingWriterFactory(): StreamingDataWriterFactory = { val rowConverter: InternalRow => T = converter match { case Left(enc) => val boundEnc = enc.resolveAndBind( @@ -68,16 +69,16 @@ case class ForeachWriterProvider[T]( } } -object ForeachWriterProvider { +object ForeachWriteSupportProvider { def apply[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T]): ForeachWriterProvider[_] = { + encoder: ExpressionEncoder[T]): ForeachWriteSupportProvider[_] = { writer match { case pythonWriter: PythonForeachWriter => - new ForeachWriterProvider[UnsafeRow]( + new ForeachWriteSupportProvider[UnsafeRow]( pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow])) case _ => - new ForeachWriterProvider[T](writer, Left(encoder)) + new ForeachWriteSupportProvider[T](writer, Left(encoder)) } } } @@ -85,8 +86,8 @@ object ForeachWriterProvider { case class ForeachWriterFactory[T]( writer: ForeachWriter[T], rowConverter: InternalRow => T) - extends DataWriterFactory[InternalRow] { - override def createDataWriter( + extends StreamingDataWriterFactory { + override def createWriter( partitionId: Int, taskId: Long, epochId: Long): ForeachDataWriter[T] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala new file mode 100644 index 0000000000000..9f88416871f8e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.writer.{BatchWriteSupport, DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} + +/** + * A [[BatchWriteSupport]] used to hook V2 stream writers into a microbatch plan. It implements + * the non-streaming interface, forwarding the epoch ID determined at construction to a wrapped + * streaming write support. + */ +class MicroBatchWritSupport(eppchId: Long, val writeSupport: StreamingWriteSupport) + extends BatchWriteSupport { + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + writeSupport.commit(eppchId, messages) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = { + writeSupport.abort(eppchId, messages) + } + + override def createBatchWriterFactory(): DataWriterFactory = { + new MicroBatchWriterFactory(eppchId, writeSupport.createStreamingWriterFactory()) + } +} + +class MicroBatchWriterFactory(epochId: Long, streamingWriterFactory: StreamingDataWriterFactory) + extends DataWriterFactory { + + override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { + streamingWriterFactory.createWriter(partitionId, taskId, epochId) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala index f26e11d842b29..ac3c71cc222b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -21,17 +21,18 @@ import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.{BatchWriteSupport, DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingDataWriterFactory /** * A simple [[DataWriterFactory]] whose tasks just pack rows into the commit message for delivery - * to a [[DataSourceWriter]] on the driver. + * to a [[BatchWriteSupport]] on the driver. * * Note that, because it sends all rows to the driver, this factory will generally be unsuitable * for production-quality sinks. It's intended for use in tests. */ -case object PackedRowWriterFactory extends DataWriterFactory[InternalRow] { - override def createDataWriter( +case object PackedRowWriterFactory extends StreamingDataWriterFactory { + override def createWriter( partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala similarity index 50% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala index 2d43a7bb77872..90680ea38fbd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala @@ -17,21 +17,15 @@ package org.apache.spark.sql.execution.streaming.sources -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} -/** - * A [[DataSourceWriter]] used to hook V2 stream writers into a microbatch plan. It implements - * the non-streaming interface, forwarding the batch ID determined at construction to a wrapped - * streaming writer. - */ -class MicroBatchWriter(batchId: Long, val writer: StreamWriter) extends DataSourceWriter { - override def commit(messages: Array[WriterCommitMessage]): Unit = { - writer.commit(batchId, messages) - } +// A special `MicroBatchReadSupport` that can get latestOffset with a start offset. +trait RateControlMicroBatchReadSupport extends MicroBatchReadSupport { - override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages) + override def latestOffset(): Offset = { + throw new IllegalAccessException( + "latestOffset should not be called for RateControlMicroBatchReadSupport") + } - override def createWriterFactory(): DataWriterFactory[InternalRow] = writer.createWriterFactory() + def latestOffset(start: Offset): Offset } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala similarity index 78% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala index 9e0d954932163..f5364047adff1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala @@ -19,27 +19,24 @@ package org.apache.spark.sql.execution.streaming.sources import java.io._ import java.nio.charset.StandardCharsets -import java.util.Optional import java.util.concurrent.TimeUnit -import scala.collection.JavaConverters._ - import org.apache.commons.io.IOUtils import org.apache.spark.internal.Logging import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{ManualClock, SystemClock} -class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) - extends MicroBatchReader with Logging { +class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLocation: String) + extends MicroBatchReadSupport with Logging { import RateStreamProvider._ private[sources] val clock = { @@ -106,38 +103,30 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: @volatile private var lastTimeMs: Long = creationTimeMs - private var start: LongOffset = _ - private var end: LongOffset = _ - - override def readSchema(): StructType = SCHEMA + override def initialOffset(): Offset = LongOffset(0L) - override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { - this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset] - this.end = end.orElse { - val now = clock.getTimeMillis() - if (lastTimeMs < now) { - lastTimeMs = now - } - LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) - }.asInstanceOf[LongOffset] - } - - override def getStartOffset(): Offset = { - if (start == null) throw new IllegalStateException("start offset not set") - start - } - override def getEndOffset(): Offset = { - if (end == null) throw new IllegalStateException("end offset not set") - end + override def latestOffset(): Offset = { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) } override def deserializeOffset(json: String): Offset = { LongOffset(json.toLong) } - override def planInputPartitions(): java.util.List[InputPartition[InternalRow]] = { - val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) - val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) + override def fullSchema(): StructType = SCHEMA + + override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) + } + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val sc = config.asInstanceOf[SimpleStreamingScanConfig] + val startSeconds = sc.start.asInstanceOf[LongOffset].offset + val endSeconds = sc.end.get.asInstanceOf[LongOffset].offset assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") if (endSeconds > maxSeconds) { throw new ArithmeticException("Integer overflow. Max offset with " + @@ -153,7 +142,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") if (rangeStart == rangeEnd) { - return List.empty.asJava + return Array.empty } val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) @@ -170,8 +159,11 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: (0 until numPartitions).map { p => new RateStreamMicroBatchInputPartition( p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) - : InputPartition[InternalRow] - }.toList.asJava + }.toArray + } + + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + RateStreamMicroBatchReaderFactory } override def commit(end: Offset): Unit = {} @@ -183,26 +175,29 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" } -class RateStreamMicroBatchInputPartition( +case class RateStreamMicroBatchInputPartition( partitionId: Int, numPartitions: Int, rangeStart: Long, rangeEnd: Long, localStartTimeMs: Long, - relativeMsPerValue: Double) extends InputPartition[InternalRow] { + relativeMsPerValue: Double) extends InputPartition - override def createPartitionReader(): InputPartitionReader[InternalRow] = - new RateStreamMicroBatchInputPartitionReader(partitionId, numPartitions, rangeStart, rangeEnd, - localStartTimeMs, relativeMsPerValue) +object RateStreamMicroBatchReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val p = partition.asInstanceOf[RateStreamMicroBatchInputPartition] + new RateStreamMicroBatchPartitionReader(p.partitionId, p.numPartitions, p.rangeStart, + p.rangeEnd, p.localStartTimeMs, p.relativeMsPerValue) + } } -class RateStreamMicroBatchInputPartitionReader( +class RateStreamMicroBatchPartitionReader( partitionId: Int, numPartitions: Int, rangeStart: Long, rangeEnd: Long, localStartTimeMs: Long, - relativeMsPerValue: Double) extends InputPartitionReader[InternalRow] { + relativeMsPerValue: Double) extends PartitionReader[InternalRow] { private var count: Long = 0 override def next(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 6bdd492f0cb35..6942dfbfe0ecf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -17,14 +17,11 @@ package org.apache.spark.sql.execution.streaming.sources -import java.util.Optional - import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReadSupport import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.types._ /** @@ -42,13 +39,12 @@ import org.apache.spark.sql.types._ * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. */ class RateStreamProvider extends DataSourceV2 - with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister { + with MicroBatchReadSupportProvider with ContinuousReadSupportProvider with DataSourceRegister { import RateStreamProvider._ - override def createMicroBatchReader( - schema: Optional[StructType], + override def createMicroBatchReadSupport( checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = { + options: DataSourceOptions): MicroBatchReadSupport = { if (options.get(ROWS_PER_SECOND).isPresent) { val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong if (rowsPerSecond <= 0) { @@ -74,17 +70,14 @@ class RateStreamProvider extends DataSourceV2 } } - if (schema.isPresent) { - throw new AnalysisException("The rate source does not support a user-specified schema.") - } - - new RateStreamMicroBatchReader(options, checkpointLocation) + new RateStreamMicroBatchReadSupport(options, checkpointLocation) } - override def createContinuousReader( - schema: Optional[StructType], + override def createContinuousReadSupport( checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options) + options: DataSourceOptions): ContinuousReadSupport = { + new RateStreamContinuousReadSupport(options) + } override def shortName(): String = "rate" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 2a5d21f330541..2509450f0da9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -35,9 +35,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} -import org.apache.spark.sql.sources.v2.{CustomMetrics, DataSourceOptions, DataSourceV2, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.{CustomMetrics, DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamWriter, SupportsCustomWriterMetrics} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport, SupportsCustomWriterMetrics} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -45,13 +45,15 @@ import org.apache.spark.sql.types.StructType * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkBase with Logging { - override def createStreamWriter( +class MemorySinkV2 extends DataSourceV2 with StreamingWriteSupportProvider + with MemorySinkBase with Logging { + + override def createStreamingWriteSupport( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamWriter = { - new MemoryStreamWriter(this, mode, schema) + options: DataSourceOptions): StreamingWriteSupport = { + new MemoryStreamingWriteSupport(this, mode, schema) } private case class AddedData(batchId: Long, data: Array[Row]) @@ -132,35 +134,15 @@ class MemoryV2CustomMetrics(sink: MemorySinkV2) extends CustomMetrics { override def json(): String = Serialization.write(Map("numRows" -> sink.numRows)) } -class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode, schema: StructType) - extends DataSourceWriter with SupportsCustomWriterMetrics with Logging { - - private val memoryV2CustomMetrics = new MemoryV2CustomMetrics(sink) - - override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema) - - def commit(messages: Array[WriterCommitMessage]): Unit = { - val newRows = messages.flatMap { - case message: MemoryWriterCommitMessage => message.data - } - sink.write(batchId, outputMode, newRows) - } - - override def abort(messages: Array[WriterCommitMessage]): Unit = { - // Don't accept any of the new input. - } - - override def getCustomMetrics: CustomMetrics = { - memoryV2CustomMetrics - } -} - -class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) - extends StreamWriter with SupportsCustomWriterMetrics { +class MemoryStreamingWriteSupport( + val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) + extends StreamingWriteSupport with SupportsCustomWriterMetrics { private val customMemoryV2Metrics = new MemoryV2CustomMetrics(sink) - override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema) + override def createStreamingWriterFactory: MemoryWriterFactory = { + MemoryWriterFactory(outputMode, schema) + } override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { @@ -173,19 +155,23 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode, schema: // Don't accept any of the new input. } - override def getCustomMetrics: CustomMetrics = { - customMemoryV2Metrics - } + override def getCustomMetrics: CustomMetrics = customMemoryV2Metrics } case class MemoryWriterFactory(outputMode: OutputMode, schema: StructType) - extends DataWriterFactory[InternalRow] { + extends DataWriterFactory with StreamingDataWriterFactory { - override def createDataWriter( + override def createWriter( + partitionId: Int, + taskId: Long): DataWriter[InternalRow] = { + new MemoryDataWriter(partitionId, outputMode, schema) + } + + override def createWriter( partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = { - new MemoryDataWriter(partitionId, outputMode, schema) + createWriter(partitionId, taskId) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 874c479db95d5..b2a573eae504a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -20,11 +20,10 @@ package org.apache.spark.sql.execution.streaming.sources import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket import java.text.SimpleDateFormat -import java.util.{Calendar, List => JList, Locale, Optional} +import java.util.{Calendar, Locale} import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import scala.util.{Failure, Success, Try} @@ -32,16 +31,15 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.LongOffset -import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousReader +import org.apache.spark.sql.execution.streaming.{LongOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder} +import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousReadSupport import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, DataSourceV2, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, DataSourceV2, MicroBatchReadSupportProvider} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport, Offset} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} import org.apache.spark.unsafe.types.UTF8String -// Shared object for micro-batch and continuous reader object TextSocketReader { val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: @@ -50,14 +48,12 @@ object TextSocketReader { } /** - * A MicroBatchReader that reads text lines through a TCP socket, designed only for tutorials and - * debugging. This MicroBatchReader will *not* work in production applications due to multiple - * reasons, including no support for fault recovery. + * A MicroBatchReadSupport that reads text lines through a TCP socket, designed only for tutorials + * and debugging. This MicroBatchReadSupport will *not* work in production applications due to + * multiple reasons, including no support for fault recovery. */ -class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader with Logging { - - private var startOffset: Offset = _ - private var endOffset: Offset = _ +class TextSocketMicroBatchReadSupport(options: DataSourceOptions) + extends MicroBatchReadSupport with Logging { private val host: String = options.get("host").get() private val port: Int = options.get("port").get().toInt @@ -103,7 +99,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR logWarning(s"Stream closed by $host:$port") return } - TextSocketMicroBatchReader.this.synchronized { + TextSocketMicroBatchReadSupport.this.synchronized { val newData = ( UTF8String.fromString(line), DateTimeUtils.fromMillis(Calendar.getInstance().getTimeInMillis) @@ -120,24 +116,15 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR readThread.start() } - override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = synchronized { - startOffset = start.orElse(LongOffset(-1L)) - endOffset = end.orElse(currentOffset) - } + override def initialOffset(): Offset = LongOffset(-1L) - override def getStartOffset(): Offset = { - Option(startOffset).getOrElse(throw new IllegalStateException("start offset not set")) - } - - override def getEndOffset(): Offset = { - Option(endOffset).getOrElse(throw new IllegalStateException("end offset not set")) - } + override def latestOffset(): Offset = currentOffset override def deserializeOffset(json: String): Offset = { LongOffset(json.toLong) } - override def readSchema(): StructType = { + override def fullSchema(): StructType = { if (options.getBoolean("includeTimestamp", false)) { TextSocketReader.SCHEMA_TIMESTAMP } else { @@ -145,12 +132,14 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR } } - override def planInputPartitions(): JList[InputPartition[InternalRow]] = { - assert(startOffset != null && endOffset != null, - "start offset and end offset should already be set before create read tasks.") + override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) + } - val startOrdinal = LongOffset.convert(startOffset).get.offset.toInt + 1 - val endOrdinal = LongOffset.convert(endOffset).get.offset.toInt + 1 + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val sc = config.asInstanceOf[SimpleStreamingScanConfig] + val startOrdinal = sc.start.asInstanceOf[LongOffset].offset.toInt + 1 + val endOrdinal = sc.end.get.asInstanceOf[LongOffset].offset.toInt + 1 // Internal buffer only holds the batches after lastOffsetCommitted val rawList = synchronized { @@ -172,26 +161,29 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR slices(idx % numPartitions).append(r) } - (0 until numPartitions).map { i => - val slice = slices(i) - new InputPartition[InternalRow] { - override def createPartitionReader(): InputPartitionReader[InternalRow] = - new InputPartitionReader[InternalRow] { - private var currentIdx = -1 + slices.map(TextSocketInputPartition) + } - override def next(): Boolean = { - currentIdx += 1 - currentIdx < slice.size - } + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + new PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val slice = partition.asInstanceOf[TextSocketInputPartition].slice + new PartitionReader[InternalRow] { + private var currentIdx = -1 - override def get(): InternalRow = { - InternalRow(slice(currentIdx)._1, slice(currentIdx)._2) - } + override def next(): Boolean = { + currentIdx += 1 + currentIdx < slice.size + } - override def close(): Unit = {} + override def get(): InternalRow = { + InternalRow(slice(currentIdx)._1, slice(currentIdx)._2) } + + override def close(): Unit = {} + } } - }.toList.asJava + } } override def commit(end: Offset): Unit = synchronized { @@ -227,8 +219,11 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR override def toString: String = s"TextSocketV2[host: $host, port: $port]" } +case class TextSocketInputPartition(slice: ListBuffer[(UTF8String, Long)]) extends InputPartition + class TextSocketSourceProvider extends DataSourceV2 - with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister with Logging { + with MicroBatchReadSupportProvider with ContinuousReadSupportProvider + with DataSourceRegister with Logging { private def checkParameters(params: DataSourceOptions): Unit = { logWarning("The socket source should not be used for production applications! " + @@ -248,27 +243,18 @@ class TextSocketSourceProvider extends DataSourceV2 } } - override def createMicroBatchReader( - schema: Optional[StructType], + override def createMicroBatchReadSupport( checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = { + options: DataSourceOptions): MicroBatchReadSupport = { checkParameters(options) - if (schema.isPresent) { - throw new AnalysisException("The socket source does not support a user-specified schema.") - } - - new TextSocketMicroBatchReader(options) + new TextSocketMicroBatchReadSupport(options) } - override def createContinuousReader( - schema: Optional[StructType], + override def createContinuousReadSupport( checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = { + options: DataSourceOptions): ContinuousReadSupport = { checkParameters(options) - if (schema.isPresent) { - throw new AnalysisException("The socket source does not support a user-specified schema.") - } - new TextSocketContinuousReader(options) + new TextSocketContinuousReadSupport(options) } /** String that represents the format that this data source provider uses. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index ef8dc3a325a33..39e9e1ad426be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import java.util.{Locale, Optional} +import java.util.Locale import scala.collection.JavaConverters._ @@ -28,8 +28,8 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -172,19 +172,21 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo case _ => None } ds match { - case s: MicroBatchReadSupport => - var tempReader: MicroBatchReader = null + case s: MicroBatchReadSupportProvider => + var tempReadSupport: MicroBatchReadSupport = null val schema = try { - tempReader = s.createMicroBatchReader( - Optional.ofNullable(userSpecifiedSchema.orNull), - Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, - options) - tempReader.readSchema() + val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath + tempReadSupport = if (userSpecifiedSchema.isDefined) { + s.createMicroBatchReadSupport(userSpecifiedSchema.get, tmpCheckpointPath, options) + } else { + s.createMicroBatchReadSupport(tmpCheckpointPath, options) + } + tempReadSupport.fullSchema() } finally { // Stop tempReader to avoid side-effect thing - if (tempReader != null) { - tempReader.stop() - tempReader = null + if (tempReadSupport != null) { + tempReadSupport.stop() + tempReadSupport = null } } Dataset.ofRows( @@ -192,16 +194,28 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo StreamingRelationV2( s, source, extraOptions.toMap, schema.toAttributes, v1Relation)(sparkSession)) - case s: ContinuousReadSupport => - val tempReader = s.createContinuousReader( - Optional.ofNullable(userSpecifiedSchema.orNull), - Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, - options) + case s: ContinuousReadSupportProvider => + var tempReadSupport: ContinuousReadSupport = null + val schema = try { + val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath + tempReadSupport = if (userSpecifiedSchema.isDefined) { + s.createContinuousReadSupport(userSpecifiedSchema.get, tmpCheckpointPath, options) + } else { + s.createContinuousReadSupport(tmpCheckpointPath, options) + } + tempReadSupport.fullSchema() + } finally { + // Stop tempReader to avoid side-effect thing + if (tempReadSupport != null) { + tempReadSupport.stop() + tempReadSupport = null + } + } Dataset.ofRows( sparkSession, StreamingRelationV2( s, source, extraOptions.toMap, - tempReader.readSchema().toAttributes, v1Relation)(sparkSession)) + schema.toAttributes, v1Relation)(sparkSession)) case _ => // Code path for data source v1. Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 3b9a56ffdde4b..7866e4f70f14b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -21,7 +21,7 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{InterfaceStability, Since} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.api.java.function.VoidFunction2 import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources._ -import org.apache.spark.sql.sources.v2.StreamWriteSupport +import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -270,7 +270,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = ForeachWriterProvider[T](foreachWriter, ds.exprEnc) + val sink = ForeachWriteSupportProvider[T](foreachWriter, ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), @@ -299,7 +299,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",") val sink = ds.newInstance() match { - case w: StreamWriteSupport if !disabledSources.contains(w.getClass.getCanonicalName) => w + case w: StreamingWriteSupportProvider + if !disabledSources.contains(w.getClass.getCanonicalName) => w case _ => val ds = DataSource( df.sparkSession, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 25bb05212d66f..cd52d991d55c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS -import org.apache.spark.sql.sources.v2.StreamWriteSupport +import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -256,7 +256,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } (sink, trigger) match { - case (v2Sink: StreamWriteSupport, trigger: ContinuousTrigger) => + case (v2Sink: StreamingWriteSupportProvider, trigger: ContinuousTrigger) => if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index e4cead9df429c..5602310219a74 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -24,29 +24,71 @@ import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; -public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport { +public class JavaAdvancedDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { - public class Reader implements DataSourceReader, SupportsPushDownRequiredColumns, - SupportsPushDownFilters { + public class ReadSupport extends JavaSimpleReadSupport { + @Override + public ScanConfigBuilder newScanConfigBuilder() { + return new AdvancedScanConfigBuilder(); + } + + @Override + public InputPartition[] planInputPartitions(ScanConfig config) { + Filter[] filters = ((AdvancedScanConfigBuilder) config).filters; + List res = new ArrayList<>(); + + Integer lowerBound = null; + for (Filter filter : filters) { + if (filter instanceof GreaterThan) { + GreaterThan f = (GreaterThan) filter; + if ("i".equals(f.attribute()) && f.value() instanceof Integer) { + lowerBound = (Integer) f.value(); + break; + } + } + } + + if (lowerBound == null) { + res.add(new JavaRangeInputPartition(0, 5)); + res.add(new JavaRangeInputPartition(5, 10)); + } else if (lowerBound < 4) { + res.add(new JavaRangeInputPartition(lowerBound + 1, 5)); + res.add(new JavaRangeInputPartition(5, 10)); + } else if (lowerBound < 9) { + res.add(new JavaRangeInputPartition(lowerBound + 1, 10)); + } + + return res.stream().toArray(InputPartition[]::new); + } + + @Override + public PartitionReaderFactory createReaderFactory(ScanConfig config) { + StructType requiredSchema = ((AdvancedScanConfigBuilder) config).requiredSchema; + return new AdvancedReaderFactory(requiredSchema); + } + } + + public static class AdvancedScanConfigBuilder implements ScanConfigBuilder, ScanConfig, + SupportsPushDownFilters, SupportsPushDownRequiredColumns { // Exposed for testing. public StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); public Filter[] filters = new Filter[0]; @Override - public StructType readSchema() { - return requiredSchema; + public void pruneColumns(StructType requiredSchema) { + this.requiredSchema = requiredSchema; } @Override - public void pruneColumns(StructType requiredSchema) { - this.requiredSchema = requiredSchema; + public StructType readSchema() { + return requiredSchema; } @Override @@ -79,79 +121,54 @@ public Filter[] pushedFilters() { } @Override - public List> planInputPartitions() { - List> res = new ArrayList<>(); - - Integer lowerBound = null; - for (Filter filter : filters) { - if (filter instanceof GreaterThan) { - GreaterThan f = (GreaterThan) filter; - if ("i".equals(f.attribute()) && f.value() instanceof Integer) { - lowerBound = (Integer) f.value(); - break; - } - } - } - - if (lowerBound == null) { - res.add(new JavaAdvancedInputPartition(0, 5, requiredSchema)); - res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema)); - } else if (lowerBound < 4) { - res.add(new JavaAdvancedInputPartition(lowerBound + 1, 5, requiredSchema)); - res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema)); - } else if (lowerBound < 9) { - res.add(new JavaAdvancedInputPartition(lowerBound + 1, 10, requiredSchema)); - } - - return res; + public ScanConfig build() { + return this; } } - static class JavaAdvancedInputPartition implements InputPartition, - InputPartitionReader { - private int start; - private int end; - private StructType requiredSchema; + static class AdvancedReaderFactory implements PartitionReaderFactory { + StructType requiredSchema; - JavaAdvancedInputPartition(int start, int end, StructType requiredSchema) { - this.start = start; - this.end = end; + AdvancedReaderFactory(StructType requiredSchema) { this.requiredSchema = requiredSchema; } @Override - public InputPartitionReader createPartitionReader() { - return new JavaAdvancedInputPartition(start - 1, end, requiredSchema); - } - - @Override - public boolean next() { - start += 1; - return start < end; - } + public PartitionReader createReader(InputPartition partition) { + JavaRangeInputPartition p = (JavaRangeInputPartition) partition; + return new PartitionReader() { + private int current = p.start - 1; + + @Override + public boolean next() throws IOException { + current += 1; + return current < p.end; + } - @Override - public InternalRow get() { - Object[] values = new Object[requiredSchema.size()]; - for (int i = 0; i < values.length; i++) { - if ("i".equals(requiredSchema.apply(i).name())) { - values[i] = start; - } else if ("j".equals(requiredSchema.apply(i).name())) { - values[i] = -start; + @Override + public InternalRow get() { + Object[] values = new Object[requiredSchema.size()]; + for (int i = 0; i < values.length; i++) { + if ("i".equals(requiredSchema.apply(i).name())) { + values[i] = current; + } else if ("j".equals(requiredSchema.apply(i).name())) { + values[i] = -current; + } + } + return new GenericInternalRow(values); } - } - return new GenericInternalRow(values); - } - @Override - public void close() throws IOException { + @Override + public void close() throws IOException { + } + }; } } @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + return new ReadSupport(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java deleted file mode 100644 index 97d6176d02559..0000000000000 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package test.org.apache.spark.sql.sources.v2; - -import java.io.IOException; -import java.util.List; - -import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.reader.*; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.vectorized.ColumnVector; -import org.apache.spark.sql.vectorized.ColumnarBatch; - - -public class JavaBatchDataSourceV2 implements DataSourceV2, ReadSupport { - - class Reader implements DataSourceReader, SupportsScanColumnarBatch { - private final StructType schema = new StructType().add("i", "int").add("j", "int"); - - @Override - public StructType readSchema() { - return schema; - } - - @Override - public List> planBatchInputPartitions() { - return java.util.Arrays.asList( - new JavaBatchInputPartition(0, 50), new JavaBatchInputPartition(50, 90)); - } - } - - static class JavaBatchInputPartition - implements InputPartition, InputPartitionReader { - private int start; - private int end; - - private static final int BATCH_SIZE = 20; - - private OnHeapColumnVector i; - private OnHeapColumnVector j; - private ColumnarBatch batch; - - JavaBatchInputPartition(int start, int end) { - this.start = start; - this.end = end; - } - - @Override - public InputPartitionReader createPartitionReader() { - this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); - this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); - ColumnVector[] vectors = new ColumnVector[2]; - vectors[0] = i; - vectors[1] = j; - this.batch = new ColumnarBatch(vectors); - return this; - } - - @Override - public boolean next() { - i.reset(); - j.reset(); - int count = 0; - while (start < end && count < BATCH_SIZE) { - i.putInt(count, start); - j.putInt(count, -start); - start += 1; - count += 1; - } - - if (count == 0) { - return false; - } else { - batch.setNumRows(count); - return true; - } - } - - @Override - public ColumnarBatch get() { - return batch; - } - - @Override - public void close() throws IOException { - batch.close(); - } - } - - - @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); - } -} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java new file mode 100644 index 0000000000000..28a9330398310 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + + +public class JavaColumnarDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { + + class ReadSupport extends JavaSimpleReadSupport { + + @Override + public InputPartition[] planInputPartitions(ScanConfig config) { + InputPartition[] partitions = new InputPartition[2]; + partitions[0] = new JavaRangeInputPartition(0, 50); + partitions[1] = new JavaRangeInputPartition(50, 90); + return partitions; + } + + @Override + public PartitionReaderFactory createReaderFactory(ScanConfig config) { + return new ColumnarReaderFactory(); + } + } + + static class ColumnarReaderFactory implements PartitionReaderFactory { + private static final int BATCH_SIZE = 20; + + @Override + public boolean supportColumnarReads(InputPartition partition) { + return true; + } + + @Override + public PartitionReader createReader(InputPartition partition) { + throw new UnsupportedOperationException(""); + } + + @Override + public PartitionReader createColumnarReader(InputPartition partition) { + JavaRangeInputPartition p = (JavaRangeInputPartition) partition; + OnHeapColumnVector i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + OnHeapColumnVector j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + ColumnVector[] vectors = new ColumnVector[2]; + vectors[0] = i; + vectors[1] = j; + ColumnarBatch batch = new ColumnarBatch(vectors); + + return new PartitionReader() { + private int current = p.start; + + @Override + public boolean next() throws IOException { + i.reset(); + j.reset(); + int count = 0; + while (current < p.end && count < BATCH_SIZE) { + i.putInt(count, current); + j.putInt(count, -current); + current += 1; + count += 1; + } + + if (count == 0) { + return false; + } else { + batch.setNumRows(count); + return true; + } + } + + @Override + public ColumnarBatch get() { + return batch; + } + + @Override + public void close() throws IOException { + batch.close(); + } + }; + } + } + + @Override + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + return new ReadSupport(); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index 2d21324f5ece3..18a11dde82198 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -19,38 +19,34 @@ import java.io.IOException; import java.util.Arrays; -import java.util.List; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.*; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.sources.v2.reader.partitioning.ClusteredDistribution; import org.apache.spark.sql.sources.v2.reader.partitioning.Distribution; import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; -import org.apache.spark.sql.types.StructType; -public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { +public class JavaPartitionAwareDataSource implements DataSourceV2, BatchReadSupportProvider { - class Reader implements DataSourceReader, SupportsReportPartitioning { - private final StructType schema = new StructType().add("a", "int").add("b", "int"); + class ReadSupport extends JavaSimpleReadSupport implements SupportsReportPartitioning { @Override - public StructType readSchema() { - return schema; + public InputPartition[] planInputPartitions(ScanConfig config) { + InputPartition[] partitions = new InputPartition[2]; + partitions[0] = new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}); + partitions[1] = new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2}); + return partitions; } @Override - public List> planInputPartitions() { - return java.util.Arrays.asList( - new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}), - new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2})); + public PartitionReaderFactory createReaderFactory(ScanConfig config) { + return new SpecificReaderFactory(); } @Override - public Partitioning outputPartitioning() { + public Partitioning outputPartitioning(ScanConfig config) { return new MyPartitioning(); } } @@ -66,50 +62,53 @@ public int numPartitions() { public boolean satisfy(Distribution distribution) { if (distribution instanceof ClusteredDistribution) { String[] clusteredCols = ((ClusteredDistribution) distribution).clusteredColumns; - return Arrays.asList(clusteredCols).contains("a"); + return Arrays.asList(clusteredCols).contains("i"); } return false; } } - static class SpecificInputPartition implements InputPartition, - InputPartitionReader { - - private int[] i; - private int[] j; - private int current = -1; + static class SpecificInputPartition implements InputPartition { + int[] i; + int[] j; SpecificInputPartition(int[] i, int[] j) { assert i.length == j.length; this.i = i; this.j = j; } + } - @Override - public boolean next() throws IOException { - current += 1; - return current < i.length; - } - - @Override - public InternalRow get() { - return new GenericInternalRow(new Object[] {i[current], j[current]}); - } - - @Override - public void close() throws IOException { - - } + static class SpecificReaderFactory implements PartitionReaderFactory { @Override - public InputPartitionReader createPartitionReader() { - return this; + public PartitionReader createReader(InputPartition partition) { + SpecificInputPartition p = (SpecificInputPartition) partition; + return new PartitionReader() { + private int current = -1; + + @Override + public boolean next() throws IOException { + current += 1; + return current < p.i.length; + } + + @Override + public InternalRow get() { + return new GenericInternalRow(new Object[] {p.i[current], p.j[current]}); + } + + @Override + public void close() throws IOException { + + } + }; } } @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + return new ReadSupport(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index 6fd6a44d2c4d5..cc9ac04a0dad3 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -17,43 +17,39 @@ package test.org.apache.spark.sql.sources.v2; -import java.util.List; - -import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; -public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupport { +public class JavaSchemaRequiredDataSource implements DataSourceV2, BatchReadSupportProvider { - class Reader implements DataSourceReader { + class ReadSupport extends JavaSimpleReadSupport { private final StructType schema; - Reader(StructType schema) { + ReadSupport(StructType schema) { this.schema = schema; } @Override - public StructType readSchema() { + public StructType fullSchema() { return schema; } @Override - public List> planInputPartitions() { - return java.util.Collections.emptyList(); + public InputPartition[] planInputPartitions(ScanConfig config) { + return new InputPartition[0]; } } @Override - public DataSourceReader createReader(DataSourceOptions options) { + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { throw new IllegalArgumentException("requires a user-supplied schema"); } @Override - public DataSourceReader createReader(StructType schema, DataSourceOptions options) { - return new Reader(schema); + public BatchReadSupport createBatchReadSupport(StructType schema, DataSourceOptions options) { + return new ReadSupport(schema); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 274dc3745bcf9..2cdbba84ec4a4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -17,72 +17,26 @@ package test.org.apache.spark.sql.sources.v2; -import java.io.IOException; -import java.util.List; - -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; -import org.apache.spark.sql.sources.v2.reader.InputPartition; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import org.apache.spark.sql.types.StructType; - -public class JavaSimpleDataSourceV2 implements DataSourceV2, ReadSupport { - - class Reader implements DataSourceReader { - private final StructType schema = new StructType().add("i", "int").add("j", "int"); - - @Override - public StructType readSchema() { - return schema; - } - - @Override - public List> planInputPartitions() { - return java.util.Arrays.asList( - new JavaSimpleInputPartition(0, 5), - new JavaSimpleInputPartition(5, 10)); - } - } - - static class JavaSimpleInputPartition implements InputPartition, - InputPartitionReader { +import org.apache.spark.sql.sources.v2.reader.*; - private int start; - private int end; +public class JavaSimpleDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { - JavaSimpleInputPartition(int start, int end) { - this.start = start; - this.end = end; - } - - @Override - public InputPartitionReader createPartitionReader() { - return new JavaSimpleInputPartition(start - 1, end); - } + class ReadSupport extends JavaSimpleReadSupport { @Override - public boolean next() { - start += 1; - return start < end; - } - - @Override - public InternalRow get() { - return new GenericInternalRow(new Object[] {start, -start}); - } - - @Override - public void close() throws IOException { - + public InputPartition[] planInputPartitions(ScanConfig config) { + InputPartition[] partitions = new InputPartition[2]; + partitions[0] = new JavaRangeInputPartition(0, 5); + partitions[1] = new JavaRangeInputPartition(5, 10); + return partitions; } } @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + return new ReadSupport(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java new file mode 100644 index 0000000000000..685f9b9747e85 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.StructType; + +abstract class JavaSimpleReadSupport implements BatchReadSupport { + + @Override + public StructType fullSchema() { + return new StructType().add("i", "int").add("j", "int"); + } + + @Override + public ScanConfigBuilder newScanConfigBuilder() { + return new JavaNoopScanConfigBuilder(fullSchema()); + } + + @Override + public PartitionReaderFactory createReaderFactory(ScanConfig config) { + return new JavaSimpleReaderFactory(); + } +} + +class JavaNoopScanConfigBuilder implements ScanConfigBuilder, ScanConfig { + + private StructType schema; + + JavaNoopScanConfigBuilder(StructType schema) { + this.schema = schema; + } + + @Override + public ScanConfig build() { + return this; + } + + @Override + public StructType readSchema() { + return schema; + } +} + +class JavaSimpleReaderFactory implements PartitionReaderFactory { + + @Override + public PartitionReader createReader(InputPartition partition) { + JavaRangeInputPartition p = (JavaRangeInputPartition) partition; + return new PartitionReader() { + private int current = p.start - 1; + + @Override + public boolean next() throws IOException { + current += 1; + return current < p.end; + } + + @Override + public InternalRow get() { + return new GenericInternalRow(new Object[] {current, -current}); + } + + @Override + public void close() throws IOException { + + } + }; + } +} + +class JavaRangeInputPartition implements InputPartition { + int start; + int end; + + JavaRangeInputPartition(int start, int end) { + this.start = start; + this.end = end; + } +} diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 46b38bed1c0fb..a36b0cfa6ff18 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -9,6 +9,6 @@ org.apache.spark.sql.streaming.sources.FakeReadMicroBatchOnly org.apache.spark.sql.streaming.sources.FakeReadContinuousOnly org.apache.spark.sql.streaming.sources.FakeReadBothModes org.apache.spark.sql.streaming.sources.FakeReadNeitherMode -org.apache.spark.sql.streaming.sources.FakeWrite +org.apache.spark.sql.streaming.sources.FakeWriteSupportProvider org.apache.spark.sql.streaming.sources.FakeNoWrite -org.apache.spark.sql.streaming.sources.FakeWriteV1Fallback +org.apache.spark.sql.streaming.sources.FakeWriteSupportProviderV1Fallback diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 1efaead0845db..50f13bee251ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -41,10 +41,11 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { assert(writer.commit().data.isEmpty) } - test("continuous writer") { + test("streaming writer") { val sink = new MemorySinkV2 - val writer = new MemoryStreamWriter(sink, OutputMode.Append(), new StructType().add("i", "int")) - writer.commit(0, + val writeSupport = new MemoryStreamingWriteSupport( + sink, OutputMode.Append(), new StructType().add("i", "int")) + writeSupport.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), @@ -52,29 +53,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - writer.commit(19, - Array( - MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), - MemoryWriterCommitMessage(0, Seq(Row(33))) - )) - assert(sink.latestBatchId.contains(19)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) - - assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) - } - - test("microbatch writer") { - val sink = new MemorySinkV2 - val schema = new StructType().add("i", "int") - new MemoryWriter(sink, 0, OutputMode.Append(), schema).commit( - Array( - MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), - MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), - MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))) - )) - assert(sink.latestBatchId.contains(0)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - new MemoryWriter(sink, 19, OutputMode.Append(), schema).commit( + writeSupport.commit(19, Array( MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), MemoryWriterCommitMessage(0, Seq(Row(33))) @@ -88,22 +67,21 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("writer metrics") { val sink = new MemorySinkV2 val schema = new StructType().add("i", "int") + val writeSupport = new MemoryStreamingWriteSupport( + sink, OutputMode.Append(), schema) // batch 0 - var writer = new MemoryWriter(sink, 0, OutputMode.Append(), schema) - writer.commit( + writeSupport.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), MemoryWriterCommitMessage(2, Seq(Row(5), Row(6))) )) - assert(writer.getCustomMetrics.json() == "{\"numRows\":6}") + assert(writeSupport.getCustomMetrics.json() == "{\"numRows\":6}") // batch 1 - writer = new MemoryWriter(sink, 1, OutputMode.Append(), schema - ) - writer.commit( + writeSupport.commit(1, Array( MemoryWriterCommitMessage(0, Seq(Row(7), Row(8))) )) - assert(writer.getCustomMetrics.json() == "{\"numRows\":8}") + assert(writeSupport.getCustomMetrics.json() == "{\"numRows\":8}") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala index 55acf2ba28d2f..5884380271f0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala @@ -19,12 +19,10 @@ package org.apache.spark.sql.execution.streaming.sources import java.io.ByteArrayOutputStream -import org.scalatest.time.SpanSugar._ - import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.streaming.{StreamTest, Trigger} -class ConsoleWriterSuite extends StreamTest { +class ConsoleWriteSupportSuite extends StreamTest { import testImplicits._ test("microbatch - default") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index 7e53da1f312cb..9c1756d68ccc4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -17,19 +17,17 @@ package org.apache.spark.sql.execution.streaming.sources -import java.util.Optional import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider} import org.apache.spark.sql.sources.v2.reader.streaming.Offset import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock @@ -42,7 +40,7 @@ class RateSourceSuite extends StreamTest { override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { assert(query.nonEmpty) val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source + case StreamingExecutionRelation(source: RateStreamMicroBatchReadSupport, _) => source }.head rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) @@ -55,10 +53,10 @@ class RateSourceSuite extends StreamTest { test("microbatch in registry") { withTempDir { temp => DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - val reader = ds.createMicroBatchReader( - Optional.empty(), temp.getCanonicalPath, DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamMicroBatchReader]) + case ds: MicroBatchReadSupportProvider => + val readSupport = ds.createMicroBatchReadSupport( + temp.getCanonicalPath, DataSourceOptions.empty()) + assert(readSupport.isInstanceOf[RateStreamMicroBatchReadSupport]) case _ => throw new IllegalStateException("Could not find read support for rate") } @@ -68,7 +66,7 @@ class RateSourceSuite extends StreamTest { test("compatible with old path in registry") { DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => + case ds: MicroBatchReadSupportProvider => assert(ds.isInstanceOf[RateStreamProvider]) case _ => throw new IllegalStateException("Could not find read support for rate") @@ -109,30 +107,19 @@ class RateSourceSuite extends StreamTest { ) } - test("microbatch - set offset") { - withTempDir { temp => - val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp.getCanonicalPath) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - assert(reader.getStartOffset() == startOffset) - assert(reader.getEndOffset() == endOffset) - } - } - test("microbatch - infer offsets") { withTempDir { temp => - val reader = new RateStreamMicroBatchReader( + val readSupport = new RateStreamMicroBatchReadSupport( new DataSourceOptions( Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), temp.getCanonicalPath) - reader.clock.asInstanceOf[ManualClock].advance(100000) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - reader.getStartOffset() match { + readSupport.clock.asInstanceOf[ManualClock].advance(100000) + val startOffset = readSupport.initialOffset() + startOffset match { case r: LongOffset => assert(r.offset === 0L) case _ => throw new IllegalStateException("unexpected offset type") } - reader.getEndOffset() match { + readSupport.latestOffset() match { case r: LongOffset => assert(r.offset >= 100) case _ => throw new IllegalStateException("unexpected offset type") } @@ -141,15 +128,16 @@ class RateSourceSuite extends StreamTest { test("microbatch - predetermined batch size") { withTempDir { temp => - val reader = new RateStreamMicroBatchReader( + val readSupport = new RateStreamMicroBatchReadSupport( new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp.getCanonicalPath) val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.planInputPartitions() + val config = readSupport.newScanConfigBuilder(startOffset, endOffset).build() + val tasks = readSupport.planInputPartitions(config) + val readerFactory = readSupport.createReaderFactory(config) assert(tasks.size == 1) - val dataReader = tasks.get(0).createPartitionReader() + val dataReader = readerFactory.createReader(tasks(0)) val data = ArrayBuffer[InternalRow]() while (dataReader.next()) { data.append(dataReader.get()) @@ -160,24 +148,25 @@ class RateSourceSuite extends StreamTest { test("microbatch - data read") { withTempDir { temp => - val reader = new RateStreamMicroBatchReader( + val readSupport = new RateStreamMicroBatchReadSupport( new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp.getCanonicalPath) val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.planInputPartitions() + val config = readSupport.newScanConfigBuilder(startOffset, endOffset).build() + val tasks = readSupport.planInputPartitions(config) + val readerFactory = readSupport.createReaderFactory(config) assert(tasks.size == 11) - val readData = tasks.asScala - .map(_.createPartitionReader()) + val readData = tasks + .map(readerFactory.createReader) .flatMap { reader => val buf = scala.collection.mutable.ListBuffer[InternalRow]() while (reader.next()) buf.append(reader.get()) buf } - assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) + assert(readData.map(_.getLong(1)).sorted === 0.until(33).toArray) } } @@ -288,41 +277,44 @@ class RateSourceSuite extends StreamTest { } test("user-specified schema given") { - val exception = intercept[AnalysisException] { + val exception = intercept[UnsupportedOperationException] { spark.readStream .format("rate") .schema(spark.range(1).schema) .load() } assert(exception.getMessage.contains( - "rate source does not support a user-specified schema")) + "rate source does not support user-specified schema")) } test("continuous in registry") { DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: ContinuousReadSupport => - val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamContinuousReader]) + case ds: ContinuousReadSupportProvider => + val readSupport = ds.createContinuousReadSupport( + "", DataSourceOptions.empty()) + assert(readSupport.isInstanceOf[RateStreamContinuousReadSupport]) case _ => throw new IllegalStateException("Could not find read support for continuous rate") } } test("continuous data") { - val reader = new RateStreamContinuousReader( + val readSupport = new RateStreamContinuousReadSupport( new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.planInputPartitions() + val config = readSupport.newScanConfigBuilder(readSupport.initialOffset).build() + val tasks = readSupport.planInputPartitions(config) + val readerFactory = readSupport.createContinuousReaderFactory(config) assert(tasks.size == 2) val data = scala.collection.mutable.ListBuffer[InternalRow]() - tasks.asScala.foreach { + tasks.foreach { case t: RateStreamContinuousInputPartition => - val startTimeMs = reader.getStartOffset() + val startTimeMs = readSupport.initialOffset() .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) .runTimeMs - val r = t.createPartitionReader().asInstanceOf[RateStreamContinuousInputPartitionReader] + val r = readerFactory.createReader(t) + .asInstanceOf[RateStreamContinuousPartitionReader] for (rowIndex <- 0 to 9) { r.next() data.append(r.get()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index 48e5cf75bf8bd..409156e5ebc70 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -21,7 +21,6 @@ import java.net.{InetSocketAddress, SocketException} import java.nio.ByteBuffer import java.nio.channels.ServerSocketChannel import java.sql.Timestamp -import java.util.Optional import java.util.concurrent.LinkedBlockingQueue import scala.collection.JavaConverters._ @@ -34,8 +33,8 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupportProvider} +import org.apache.spark.sql.sources.v2.reader.streaming.Offset import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -49,14 +48,9 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread.join() serverThread = null } - if (batchReader != null) { - batchReader.stop() - batchReader = null - } } private var serverThread: ServerThread = null - private var batchReader: MicroBatchReader = null case class AddSocketData(data: String*) extends AddData { override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { @@ -65,7 +59,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before "Cannot add data when there is no query for finding the active socket source") val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: TextSocketMicroBatchReader, _) => source + case StreamingExecutionRelation(source: TextSocketMicroBatchReadSupport, _) => source } if (sources.isEmpty) { throw new Exception( @@ -91,7 +85,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before test("backward compatibility with old path") { DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => + case ds: MicroBatchReadSupportProvider => assert(ds.isInstanceOf[TextSocketSourceProvider]) case _ => throw new IllegalStateException("Could not find socket source") @@ -181,16 +175,16 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before test("params not given") { val provider = new TextSocketSourceProvider intercept[AnalysisException] { - provider.createMicroBatchReader(Optional.empty(), "", - new DataSourceOptions(Map.empty[String, String].asJava)) + provider.createMicroBatchReadSupport( + "", new DataSourceOptions(Map.empty[String, String].asJava)) } intercept[AnalysisException] { - provider.createMicroBatchReader(Optional.empty(), "", - new DataSourceOptions(Map("host" -> "localhost").asJava)) + provider.createMicroBatchReadSupport( + "", new DataSourceOptions(Map("host" -> "localhost").asJava)) } intercept[AnalysisException] { - provider.createMicroBatchReader(Optional.empty(), "", - new DataSourceOptions(Map("port" -> "1234").asJava)) + provider.createMicroBatchReadSupport( + "", new DataSourceOptions(Map("port" -> "1234").asJava)) } } @@ -199,7 +193,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before val params = Map("host" -> "localhost", "port" -> "1234", "includeTimestamp" -> "fasle") intercept[AnalysisException] { val a = new DataSourceOptions(params.asJava) - provider.createMicroBatchReader(Optional.empty(), "", a) + provider.createMicroBatchReadSupport("", a) } } @@ -209,12 +203,12 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before StructField("name", StringType) :: StructField("area", StringType) :: Nil) val params = Map("host" -> "localhost", "port" -> "1234") - val exception = intercept[AnalysisException] { - provider.createMicroBatchReader( - Optional.of(userSpecifiedSchema), "", new DataSourceOptions(params.asJava)) + val exception = intercept[UnsupportedOperationException] { + provider.createMicroBatchReadSupport( + userSpecifiedSchema, "", new DataSourceOptions(params.asJava)) } assert(exception.getMessage.contains( - "socket source does not support a user-specified schema")) + "socket source does not support user-specified schema")) } test("input row metrics") { @@ -305,25 +299,27 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - val reader = new TextSocketContinuousReader( + val readSupport = new TextSocketContinuousReadSupport( new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", "port" -> serverThread.port.toString).asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.planInputPartitions() + + val scanConfig = readSupport.newScanConfigBuilder(readSupport.initialOffset()).build() + val tasks = readSupport.planInputPartitions(scanConfig) assert(tasks.size == 2) val numRecords = 10 val data = scala.collection.mutable.ListBuffer[Int]() val offsets = scala.collection.mutable.ListBuffer[Int]() + val readerFactory = readSupport.createContinuousReaderFactory(scanConfig) import org.scalatest.time.SpanSugar._ failAfter(5 seconds) { // inject rows, read and check the data and offsets for (i <- 0 until numRecords) { serverThread.enqueue(i.toString) } - tasks.asScala.foreach { + tasks.foreach { case t: TextSocketContinuousInputPartition => - val r = t.createPartitionReader().asInstanceOf[TextSocketContinuousInputPartitionReader] + val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader] for (i <- 0 until numRecords / 2) { r.next() offsets.append(r.getOffset().asInstanceOf[ContinuousRecordPartitionOffset].offset) @@ -339,16 +335,15 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before data.clear() case _ => throw new IllegalStateException("Unexpected task type") } - assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == List(3, 3)) - reader.commit(TextSocketOffset(List(5, 5))) - assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == List(5, 5)) + assert(readSupport.startOffset.offsets == List(3, 3)) + readSupport.commit(TextSocketOffset(List(5, 5))) + assert(readSupport.startOffset.offsets == List(5, 5)) } def commitOffset(partition: Int, offset: Int): Unit = { - val offsetsToCommit = reader.getStartOffset.asInstanceOf[TextSocketOffset] - .offsets.updated(partition, offset) - reader.commit(TextSocketOffset(offsetsToCommit)) - assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == offsetsToCommit) + val offsetsToCommit = readSupport.startOffset.offsets.updated(partition, offset) + readSupport.commit(TextSocketOffset(offsetsToCommit)) + assert(readSupport.startOffset.offsets == offsetsToCommit) } } @@ -356,14 +351,13 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - val reader = new TextSocketContinuousReader( + val readSupport = new TextSocketContinuousReadSupport( new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", "port" -> serverThread.port.toString).asJava)) - reader.setStartOffset(Optional.of(TextSocketOffset(List(5, 5)))) - // ok to commit same offset - reader.setStartOffset(Optional.of(TextSocketOffset(List(5, 5)))) + + readSupport.startOffset = TextSocketOffset(List(5, 5)) assertThrows[IllegalStateException] { - reader.commit(TextSocketOffset(List(6, 6))) + readSupport.commit(TextSocketOffset(List(6, 6))) } } @@ -371,12 +365,12 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - val reader = new TextSocketContinuousReader( + val readSupport = new TextSocketContinuousReadSupport( new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", "includeTimestamp" -> "true", "port" -> serverThread.port.toString).asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.planInputPartitions() + val scanConfig = readSupport.newScanConfigBuilder(readSupport.initialOffset()).build() + val tasks = readSupport.planInputPartitions(scanConfig) assert(tasks.size == 2) val numRecords = 4 @@ -384,9 +378,10 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before for (i <- 0 until numRecords) { serverThread.enqueue(i.toString) } - tasks.asScala.foreach { + val readerFactory = readSupport.createContinuousReaderFactory(scanConfig) + tasks.foreach { case t: TextSocketContinuousInputPartition => - val r = t.createPartitionReader().asInstanceOf[TextSocketContinuousInputPartitionReader] + val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader] for (i <- 0 until numRecords / 2) { r.next() assert(r.get().get(0, TextSocketReader.SCHEMA_TIMESTAMP) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index aa5f723365d5f..5edeff553eb16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.sources.v2 -import java.util.{ArrayList, List => JList} - import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException @@ -38,6 +36,21 @@ import org.apache.spark.sql.vectorized.ColumnarBatch class DataSourceV2Suite extends QueryTest with SharedSQLContext { import testImplicits._ + private def getScanConfig(query: DataFrame): AdvancedScanConfigBuilder = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => + d.scanConfig.asInstanceOf[AdvancedScanConfigBuilder] + }.head + } + + private def getJavaScanConfig( + query: DataFrame): JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => + d.scanConfig.asInstanceOf[JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder] + }.head + } + test("simplest implementation") { Seq(classOf[SimpleDataSourceV2], classOf[JavaSimpleDataSourceV2]).foreach { cls => withClue(cls.getName) { @@ -50,18 +63,6 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("advanced implementation") { - def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = { - query.queryExecution.executedPlan.collect { - case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader] - }.head - } - - def getJavaReader(query: DataFrame): JavaAdvancedDataSourceV2#Reader = { - query.queryExecution.executedPlan.collect { - case d: DataSourceV2ScanExec => d.reader.asInstanceOf[JavaAdvancedDataSourceV2#Reader] - }.head - } - Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() @@ -70,58 +71,58 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val q1 = df.select('j) checkAnswer(q1, (0 until 10).map(i => Row(-i))) if (cls == classOf[AdvancedDataSourceV2]) { - val reader = getReader(q1) - assert(reader.filters.isEmpty) - assert(reader.requiredSchema.fieldNames === Seq("j")) + val config = getScanConfig(q1) + assert(config.filters.isEmpty) + assert(config.requiredSchema.fieldNames === Seq("j")) } else { - val reader = getJavaReader(q1) - assert(reader.filters.isEmpty) - assert(reader.requiredSchema.fieldNames === Seq("j")) + val config = getJavaScanConfig(q1) + assert(config.filters.isEmpty) + assert(config.requiredSchema.fieldNames === Seq("j")) } val q2 = df.filter('i > 3) checkAnswer(q2, (4 until 10).map(i => Row(i, -i))) if (cls == classOf[AdvancedDataSourceV2]) { - val reader = getReader(q2) - assert(reader.filters.flatMap(_.references).toSet == Set("i")) - assert(reader.requiredSchema.fieldNames === Seq("i", "j")) + val config = getScanConfig(q2) + assert(config.filters.flatMap(_.references).toSet == Set("i")) + assert(config.requiredSchema.fieldNames === Seq("i", "j")) } else { - val reader = getJavaReader(q2) - assert(reader.filters.flatMap(_.references).toSet == Set("i")) - assert(reader.requiredSchema.fieldNames === Seq("i", "j")) + val config = getJavaScanConfig(q2) + assert(config.filters.flatMap(_.references).toSet == Set("i")) + assert(config.requiredSchema.fieldNames === Seq("i", "j")) } val q3 = df.select('i).filter('i > 6) checkAnswer(q3, (7 until 10).map(i => Row(i))) if (cls == classOf[AdvancedDataSourceV2]) { - val reader = getReader(q3) - assert(reader.filters.flatMap(_.references).toSet == Set("i")) - assert(reader.requiredSchema.fieldNames === Seq("i")) + val config = getScanConfig(q3) + assert(config.filters.flatMap(_.references).toSet == Set("i")) + assert(config.requiredSchema.fieldNames === Seq("i")) } else { - val reader = getJavaReader(q3) - assert(reader.filters.flatMap(_.references).toSet == Set("i")) - assert(reader.requiredSchema.fieldNames === Seq("i")) + val config = getJavaScanConfig(q3) + assert(config.filters.flatMap(_.references).toSet == Set("i")) + assert(config.requiredSchema.fieldNames === Seq("i")) } val q4 = df.select('j).filter('j < -10) checkAnswer(q4, Nil) if (cls == classOf[AdvancedDataSourceV2]) { - val reader = getReader(q4) + val config = getScanConfig(q4) // 'j < 10 is not supported by the testing data source. - assert(reader.filters.isEmpty) - assert(reader.requiredSchema.fieldNames === Seq("j")) + assert(config.filters.isEmpty) + assert(config.requiredSchema.fieldNames === Seq("j")) } else { - val reader = getJavaReader(q4) + val config = getJavaScanConfig(q4) // 'j < 10 is not supported by the testing data source. - assert(reader.filters.isEmpty) - assert(reader.requiredSchema.fieldNames === Seq("j")) + assert(config.filters.isEmpty) + assert(config.requiredSchema.fieldNames === Seq("j")) } } } } test("columnar batch scan implementation") { - Seq(classOf[BatchDataSourceV2], classOf[JavaBatchDataSourceV2]).foreach { cls => + Seq(classOf[ColumnarDataSourceV2], classOf[JavaColumnarDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() checkAnswer(df, (0 until 90).map(i => Row(i, -i))) @@ -153,25 +154,25 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val df = spark.read.format(cls.getName).load() checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2))) - val groupByColA = df.groupBy('a).agg(sum('b)) + val groupByColA = df.groupBy('i).agg(sum('j)) checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4))) assert(groupByColA.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e }.isEmpty) - val groupByColAB = df.groupBy('a, 'b).agg(count("*")) + val groupByColAB = df.groupBy('i, 'j).agg(count("*")) checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2))) assert(groupByColAB.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e }.isEmpty) - val groupByColB = df.groupBy('b).agg(sum('a)) + val groupByColB = df.groupBy('j).agg(sum('i)) checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5))) assert(groupByColB.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e }.isDefined) - val groupByAPlusB = df.groupBy('a + 'b).agg(count("*")) + val groupByAPlusB = df.groupBy('i + 'j).agg(count("*")) checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1))) assert(groupByAPlusB.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e @@ -272,36 +273,30 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("SPARK-23301: column pruning with arbitrary expressions") { - def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = { - query.queryExecution.executedPlan.collect { - case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader] - }.head - } - val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() val q1 = df.select('i + 1) checkAnswer(q1, (1 until 11).map(i => Row(i))) - val reader1 = getReader(q1) - assert(reader1.requiredSchema.fieldNames === Seq("i")) + val config1 = getScanConfig(q1) + assert(config1.requiredSchema.fieldNames === Seq("i")) val q2 = df.select(lit(1)) checkAnswer(q2, (0 until 10).map(i => Row(1))) - val reader2 = getReader(q2) - assert(reader2.requiredSchema.isEmpty) + val config2 = getScanConfig(q2) + assert(config2.requiredSchema.isEmpty) // 'j === 1 can't be pushed down, but we should still be able do column pruning val q3 = df.filter('j === -1).select('j * 2) checkAnswer(q3, Row(-2)) - val reader3 = getReader(q3) - assert(reader3.filters.isEmpty) - assert(reader3.requiredSchema.fieldNames === Seq("j")) + val config3 = getScanConfig(q3) + assert(config3.filters.isEmpty) + assert(config3.requiredSchema.fieldNames === Seq("j")) // column pruning should work with other operators. val q4 = df.sort('i).limit(1).select('i + 1) checkAnswer(q4, Row(1)) - val reader4 = getReader(q4) - assert(reader4.requiredSchema.fieldNames === Seq("i")) + val config4 = getScanConfig(q4) + assert(config4.requiredSchema.fieldNames === Seq("i")) } test("SPARK-23315: get output from canonicalized data source v2 related plans") { @@ -324,240 +319,290 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } -class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader { - override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") +case class RangeInputPartition(start: Int, end: Int) extends InputPartition - override def planInputPartitions(): JList[InputPartition[InternalRow]] = { - java.util.Arrays.asList(new SimpleInputPartition(0, 5)) - } - } - - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader +case class NoopScanConfigBuilder(readSchema: StructType) extends ScanConfigBuilder with ScanConfig { + override def build(): ScanConfig = this } -class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { +object SimpleReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val RangeInputPartition(start, end) = partition + new PartitionReader[InternalRow] { + private var current = start - 1 + + override def next(): Boolean = { + current += 1 + current < end + } - class Reader extends DataSourceReader { - override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + override def get(): InternalRow = InternalRow(current, -current) - override def planInputPartitions(): JList[InputPartition[InternalRow]] = { - java.util.Arrays.asList(new SimpleInputPartition(0, 5), new SimpleInputPartition(5, 10)) + override def close(): Unit = {} } } - - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class SimpleInputPartition(start: Int, end: Int) - extends InputPartition[InternalRow] - with InputPartitionReader[InternalRow] { - private var current = start - 1 - - override def createPartitionReader(): InputPartitionReader[InternalRow] = - new SimpleInputPartition(start, end) +abstract class SimpleReadSupport extends BatchReadSupport { + override def fullSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def next(): Boolean = { - current += 1 - current < end + override def newScanConfigBuilder(): ScanConfigBuilder = { + NoopScanConfigBuilder(fullSchema()) } - override def get(): InternalRow = InternalRow(current, -current) - - override def close(): Unit = {} + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + SimpleReaderFactory + } } +class SimpleSinglePartitionSource extends DataSourceV2 with BatchReadSupportProvider { -class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { + class ReadSupport extends SimpleReadSupport { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + Array(RangeInputPartition(0, 5)) + } + } + + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } +} - class Reader extends DataSourceReader - with SupportsPushDownRequiredColumns with SupportsPushDownFilters { - var requiredSchema = new StructType().add("i", "int").add("j", "int") - var filters = Array.empty[Filter] +class SimpleDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { - override def pruneColumns(requiredSchema: StructType): Unit = { - this.requiredSchema = requiredSchema + class ReadSupport extends SimpleReadSupport { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10)) } + } - override def pushFilters(filters: Array[Filter]): Array[Filter] = { - val (supported, unsupported) = filters.partition { - case GreaterThan("i", _: Int) => true - case _ => false - } - this.filters = supported - unsupported - } + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } +} - override def pushedFilters(): Array[Filter] = filters - override def readSchema(): StructType = { - requiredSchema - } +class AdvancedDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { + + class ReadSupport extends SimpleReadSupport { + override def newScanConfigBuilder(): ScanConfigBuilder = new AdvancedScanConfigBuilder() + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val filters = config.asInstanceOf[AdvancedScanConfigBuilder].filters - override def planInputPartitions(): JList[InputPartition[InternalRow]] = { val lowerBound = filters.collectFirst { case GreaterThan("i", v: Int) => v } - val res = new ArrayList[InputPartition[InternalRow]] + val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition] if (lowerBound.isEmpty) { - res.add(new AdvancedInputPartition(0, 5, requiredSchema)) - res.add(new AdvancedInputPartition(5, 10, requiredSchema)) + res.append(RangeInputPartition(0, 5)) + res.append(RangeInputPartition(5, 10)) } else if (lowerBound.get < 4) { - res.add(new AdvancedInputPartition(lowerBound.get + 1, 5, requiredSchema)) - res.add(new AdvancedInputPartition(5, 10, requiredSchema)) + res.append(RangeInputPartition(lowerBound.get + 1, 5)) + res.append(RangeInputPartition(5, 10)) } else if (lowerBound.get < 9) { - res.add(new AdvancedInputPartition(lowerBound.get + 1, 10, requiredSchema)) + res.append(RangeInputPartition(lowerBound.get + 1, 10)) } - res + res.toArray + } + + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + val requiredSchema = config.asInstanceOf[AdvancedScanConfigBuilder].requiredSchema + new AdvancedReaderFactory(requiredSchema) } } - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } } -class AdvancedInputPartition(start: Int, end: Int, requiredSchema: StructType) - extends InputPartition[InternalRow] with InputPartitionReader[InternalRow] { +class AdvancedScanConfigBuilder extends ScanConfigBuilder with ScanConfig + with SupportsPushDownRequiredColumns with SupportsPushDownFilters { - private var current = start - 1 + var requiredSchema = new StructType().add("i", "int").add("j", "int") + var filters = Array.empty[Filter] - override def createPartitionReader(): InputPartitionReader[InternalRow] = { - new AdvancedInputPartition(start, end, requiredSchema) + override def pruneColumns(requiredSchema: StructType): Unit = { + this.requiredSchema = requiredSchema } - override def close(): Unit = {} + override def readSchema(): StructType = requiredSchema - override def next(): Boolean = { - current += 1 - current < end + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + val (supported, unsupported) = filters.partition { + case GreaterThan("i", _: Int) => true + case _ => false + } + this.filters = supported + unsupported } - override def get(): InternalRow = { - val values = requiredSchema.map(_.name).map { - case "i" => current - case "j" => -current + override def pushedFilters(): Array[Filter] = filters + + override def build(): ScanConfig = this +} + +class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val RangeInputPartition(start, end) = partition + new PartitionReader[InternalRow] { + private var current = start - 1 + + override def next(): Boolean = { + current += 1 + current < end + } + + override def get(): InternalRow = { + val values = requiredSchema.map(_.name).map { + case "i" => current + case "j" => -current + } + InternalRow.fromSeq(values) + } + + override def close(): Unit = {} } - InternalRow.fromSeq(values) } } -class SchemaRequiredDataSource extends DataSourceV2 with ReadSupport { +class SchemaRequiredDataSource extends DataSourceV2 with BatchReadSupportProvider { - class Reader(val readSchema: StructType) extends DataSourceReader { - override def planInputPartitions(): JList[InputPartition[InternalRow]] = - java.util.Collections.emptyList() + class ReadSupport(val schema: StructType) extends SimpleReadSupport { + override def fullSchema(): StructType = schema + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = + Array.empty } - override def createReader(options: DataSourceOptions): DataSourceReader = { + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { throw new IllegalArgumentException("requires a user-supplied schema") } - override def createReader(schema: StructType, options: DataSourceOptions): DataSourceReader = { - new Reader(schema) + override def createBatchReadSupport( + schema: StructType, options: DataSourceOptions): BatchReadSupport = { + new ReadSupport(schema) } } -class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { +class ColumnarDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { - class Reader extends DataSourceReader with SupportsScanColumnarBatch { - override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + class ReadSupport extends SimpleReadSupport { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + Array(RangeInputPartition(0, 50), RangeInputPartition(50, 90)) + } - override def planBatchInputPartitions(): JList[InputPartition[ColumnarBatch]] = { - java.util.Arrays.asList( - new BatchInputPartitionReader(0, 50), new BatchInputPartitionReader(50, 90)) + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + ColumnarReaderFactory } } - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } } -class BatchInputPartitionReader(start: Int, end: Int) - extends InputPartition[ColumnarBatch] with InputPartitionReader[ColumnarBatch] { - +object ColumnarReaderFactory extends PartitionReaderFactory { private final val BATCH_SIZE = 20 - private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) - private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) - private lazy val batch = new ColumnarBatch(Array(i, j)) - private var current = start + override def supportColumnarReads(partition: InputPartition): Boolean = true - override def createPartitionReader(): InputPartitionReader[ColumnarBatch] = this + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + throw new UnsupportedOperationException + } - override def next(): Boolean = { - i.reset() - j.reset() + override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = { + val RangeInputPartition(start, end) = partition + new PartitionReader[ColumnarBatch] { + private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val batch = new ColumnarBatch(Array(i, j)) + + private var current = start + + override def next(): Boolean = { + i.reset() + j.reset() + + var count = 0 + while (current < end && count < BATCH_SIZE) { + i.putInt(count, current) + j.putInt(count, -current) + current += 1 + count += 1 + } - var count = 0 - while (current < end && count < BATCH_SIZE) { - i.putInt(count, current) - j.putInt(count, -current) - current += 1 - count += 1 - } + if (count == 0) { + false + } else { + batch.setNumRows(count) + true + } + } - if (count == 0) { - false - } else { - batch.setNumRows(count) - true - } - } + override def get(): ColumnarBatch = batch - override def get(): ColumnarBatch = { - batch + override def close(): Unit = batch.close() + } } - - override def close(): Unit = batch.close() } -class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader with SupportsReportPartitioning { - override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") +class PartitionAwareDataSource extends DataSourceV2 with BatchReadSupportProvider { - override def planInputPartitions(): JList[InputPartition[InternalRow]] = { + class ReadSupport extends SimpleReadSupport with SupportsReportPartitioning { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { // Note that we don't have same value of column `a` across partitions. - java.util.Arrays.asList( - new SpecificInputPartitionReader(Array(1, 1, 3), Array(4, 4, 6)), - new SpecificInputPartitionReader(Array(2, 4, 4), Array(6, 2, 2))) + Array( + SpecificInputPartition(Array(1, 1, 3), Array(4, 4, 6)), + SpecificInputPartition(Array(2, 4, 4), Array(6, 2, 2))) + } + + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + SpecificReaderFactory } - override def outputPartitioning(): Partitioning = new MyPartitioning + override def outputPartitioning(config: ScanConfig): Partitioning = new MyPartitioning } class MyPartitioning extends Partitioning { override def numPartitions(): Int = 2 override def satisfy(distribution: Distribution): Boolean = distribution match { - case c: ClusteredDistribution => c.clusteredColumns.contains("a") + case c: ClusteredDistribution => c.clusteredColumns.contains("i") case _ => false } } - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } } -class SpecificInputPartitionReader(i: Array[Int], j: Array[Int]) - extends InputPartition[InternalRow] - with InputPartitionReader[InternalRow] { - assert(i.length == j.length) - - private var current = -1 +case class SpecificInputPartition(i: Array[Int], j: Array[Int]) extends InputPartition - override def createPartitionReader(): InputPartitionReader[InternalRow] = this +object SpecificReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val p = partition.asInstanceOf[SpecificInputPartition] + new PartitionReader[InternalRow] { + private var current = -1 - override def next(): Boolean = { - current += 1 - current < i.length - } + override def next(): Boolean = { + current += 1 + current < p.i.length + } - override def get(): InternalRow = InternalRow(i(current), j(current)) + override def get(): InternalRow = InternalRow(p.i(current), p.j(current)) - override def close(): Unit = {} + override def close(): Unit = {} + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index e1b8e9c44d725..952241b0b6be5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -18,34 +18,36 @@ package org.apache.spark.sql.sources.v2 import java.io.{BufferedReader, InputStreamReader, IOException} -import java.util.{Collections, List => JList, Optional} +import java.util.Optional import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.SparkContext import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader} +import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration /** * A HDFS based transactional writable data source. - * Each task writes data to `target/_temporary/jobId/$jobId-$partitionId-$attemptNumber`. - * Each job moves files from `target/_temporary/jobId/` to `target`. + * Each task writes data to `target/_temporary/queryId/$jobId-$partitionId-$attemptNumber`. + * Each job moves files from `target/_temporary/queryId/` to `target`. */ -class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteSupport { +class SimpleWritableDataSource extends DataSourceV2 + with BatchReadSupportProvider with BatchWriteSupportProvider { private val schema = new StructType().add("i", "long").add("j", "long") - class Reader(path: String, conf: Configuration) extends DataSourceReader { - override def readSchema(): StructType = schema + class ReadSupport(path: String, conf: Configuration) extends SimpleReadSupport { - override def planInputPartitions(): JList[InputPartition[InternalRow]] = { + override def fullSchema(): StructType = schema + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { val dataPath = new Path(path) val fs = dataPath.getFileSystem(conf) if (fs.exists(dataPath)) { @@ -53,21 +55,23 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS val name = status.getPath.getName name.startsWith("_") || name.startsWith(".") }.map { f => - val serializableConf = new SerializableConfiguration(conf) - new SimpleCSVInputPartitionReader( - f.getPath.toUri.toString, - serializableConf): InputPartition[InternalRow] - }.toList.asJava + CSVInputPartitionReader(f.getPath.toUri.toString) + }.toArray } else { - Collections.emptyList() + Array.empty } } + + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + val serializableConf = new SerializableConfiguration(conf) + new CSVReaderFactory(serializableConf) + } } - class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter { - override def createWriterFactory(): DataWriterFactory[InternalRow] = { + class WritSupport(queryId: String, path: String, conf: Configuration) extends BatchWriteSupport { + override def createBatchWriterFactory(): DataWriterFactory = { SimpleCounter.resetCounter - new CSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) + new CSVDataWriterFactory(path, queryId, new SerializableConfiguration(conf)) } override def onDataWriterCommit(message: WriterCommitMessage): Unit = { @@ -76,7 +80,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS override def commit(messages: Array[WriterCommitMessage]): Unit = { val finalPath = new Path(path) - val jobPath = new Path(new Path(finalPath, "_temporary"), jobId) + val jobPath = new Path(new Path(finalPath, "_temporary"), queryId) val fs = jobPath.getFileSystem(conf) try { for (file <- fs.listStatus(jobPath).map(_.getPath)) { @@ -91,23 +95,23 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } override def abort(messages: Array[WriterCommitMessage]): Unit = { - val jobPath = new Path(new Path(path, "_temporary"), jobId) + val jobPath = new Path(new Path(path, "_temporary"), queryId) val fs = jobPath.getFileSystem(conf) fs.delete(jobPath, true) } } - override def createReader(options: DataSourceOptions): DataSourceReader = { + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { val path = new Path(options.get("path").get()) val conf = SparkContext.getActive.get.hadoopConfiguration - new Reader(path.toUri.toString, conf) + new ReadSupport(path.toUri.toString, conf) } - override def createWriter( - jobId: String, + override def createBatchWriteSupport( + queryId: String, schema: StructType, mode: SaveMode, - options: DataSourceOptions): Optional[DataSourceWriter] = { + options: DataSourceOptions): Optional[BatchWriteSupport] = { assert(DataType.equalsStructurally(schema.asNullable, this.schema.asNullable)) assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false)) @@ -130,39 +134,42 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } val pathStr = path.toUri.toString - Optional.of(new Writer(jobId, pathStr, conf)) + Optional.of(new WritSupport(queryId, pathStr, conf)) } } -class SimpleCSVInputPartitionReader(path: String, conf: SerializableConfiguration) - extends InputPartition[InternalRow] with InputPartitionReader[InternalRow] { +case class CSVInputPartitionReader(path: String) extends InputPartition - @transient private var lines: Iterator[String] = _ - @transient private var currentLine: String = _ - @transient private var inputStream: FSDataInputStream = _ +class CSVReaderFactory(conf: SerializableConfiguration) + extends PartitionReaderFactory { - override def createPartitionReader(): InputPartitionReader[InternalRow] = { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val path = partition.asInstanceOf[CSVInputPartitionReader].path val filePath = new Path(path) val fs = filePath.getFileSystem(conf.value) - inputStream = fs.open(filePath) - lines = new BufferedReader(new InputStreamReader(inputStream)) - .lines().iterator().asScala - this - } - override def next(): Boolean = { - if (lines.hasNext) { - currentLine = lines.next() - true - } else { - false - } - } + new PartitionReader[InternalRow] { + private val inputStream = fs.open(filePath) + private val lines = new BufferedReader(new InputStreamReader(inputStream)) + .lines().iterator().asScala - override def get(): InternalRow = InternalRow(currentLine.split(",").map(_.trim.toLong): _*) + private var currentLine: String = _ - override def close(): Unit = { - inputStream.close() + override def next(): Boolean = { + if (lines.hasNext) { + currentLine = lines.next() + true + } else { + false + } + } + + override def get(): InternalRow = InternalRow(currentLine.split(",").map(_.trim.toLong): _*) + + override def close(): Unit = { + inputStream.close() + } + } } } @@ -183,12 +190,11 @@ private[v2] object SimpleCounter { } class CSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) - extends DataWriterFactory[InternalRow] { + extends DataWriterFactory { - override def createDataWriter( + override def createWriter( partitionId: Int, - taskId: Long, - epochId: Long): DataWriter[InternalRow] = { + taskId: Long): DataWriter[InternalRow] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") val fs = filePath.getFileSystem(conf.value) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index df22bc1315b7d..b528006295179 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -686,7 +686,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be plan .collect { case r: StreamingExecutionRelation => r.source - case r: StreamingDataSourceV2Relation => r.reader + case r: StreamingDataSourceV2Relation => r.readSupport } .zipWithIndex .find(_._1 == source) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 0f15cd6e5a506..fe77a1b4469c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -299,9 +299,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { try { val input = new MemoryStream[Int](0, sqlContext) { @volatile var numTriggers = 0 - override def getEndOffset: OffsetV2 = { + override def latestOffset(): OffsetV2 = { numTriggers += 1 - super.getEndOffset + super.latestOffset() } } val clock = new StreamManualClock() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 268ed58315fdd..73592526fb0f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.streaming -import java.{util => ju} -import java.util.Optional import java.util.concurrent.CountDownLatch import scala.collection.mutable import org.apache.commons.lang3.RandomStringUtils import org.json4s.NoTypeHints -import org.json4s.jackson.JsonMethods._ import org.json4s.jackson.Serialization import org.scalactic.TolerantNumerics import org.scalatest.BeforeAndAfter @@ -35,13 +32,12 @@ import org.scalatest.mockito.MockitoSugar import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Literal, Rand, Randn, Shuffle, Uuid} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig} import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType @@ -218,25 +214,17 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi private def dataAdded: Boolean = currentOffset.offset != -1 - // setOffsetRange should take 50 ms the first time it is called after data is added - override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { - synchronized { - if (dataAdded) clock.waitTillTime(1050) - super.setOffsetRange(start, end) - } - } - - // getEndOffset should take 100 ms the first time it is called after data is added - override def getEndOffset(): OffsetV2 = synchronized { - if (dataAdded) clock.waitTillTime(1150) - super.getEndOffset() + // latestOffset should take 50 ms the first time it is called after data is added + override def latestOffset(): OffsetV2 = synchronized { + if (dataAdded) clock.waitTillTime(1050) + super.latestOffset() } // getBatch should take 100 ms the first time it is called - override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { synchronized { - clock.waitTillTime(1350) - super.planInputPartitions() + clock.waitTillTime(1150) + super.planInputPartitions(config) } } } @@ -277,34 +265,26 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.status.message === "Waiting for next trigger"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - // Test status and progress when setOffsetRange is being called + // Test status and progress when `latestOffset` is being called AddData(inputData, 1, 2), - AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on setOffsetRange + AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on `latestOffset` AssertStreamExecThreadIsWaitingForTime(1050), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message.startsWith("Getting offsets from")), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - AdvanceManualClock(50), // time = 1050 to unblock setOffsetRange + AdvanceManualClock(50), // time = 1050 to unblock `latestOffset` AssertClockTime(1050), - AssertStreamExecThreadIsWaitingForTime(1150), // will block on getEndOffset that needs 1150 - AssertOnQuery(_.status.isDataAvailable === false), - AssertOnQuery(_.status.isTriggerActive === true), - AssertOnQuery(_.status.message.startsWith("Getting offsets from")), - AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - - AdvanceManualClock(100), // time = 1150 to unblock getEndOffset - AssertClockTime(1150), - // will block on planInputPartitions that needs 1350 - AssertStreamExecThreadIsWaitingForTime(1350), + // will block on `planInputPartitions` that needs 1350 + AssertStreamExecThreadIsWaitingForTime(1150), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - AdvanceManualClock(200), // time = 1350 to unblock planInputPartitions - AssertClockTime(1350), + AdvanceManualClock(100), // time = 1150 to unblock `planInputPartitions` + AssertClockTime(1150), AssertStreamExecThreadIsWaitingForTime(1500), // will block on map task that needs 1500 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), @@ -312,7 +292,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch processing has completed - AdvanceManualClock(150), // time = 1500 to unblock map task + AdvanceManualClock(350), // time = 1500 to unblock map task AssertClockTime(1500), CheckAnswer(2), AssertStreamExecThreadIsWaitingForTime(2000), // will block until the next trigger @@ -332,11 +312,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.numInputRows === 2) assert(progress.processedRowsPerSecond === 4.0) - assert(progress.durationMs.get("setOffsetRange") === 50) - assert(progress.durationMs.get("getEndOffset") === 100) - assert(progress.durationMs.get("queryPlanning") === 200) + assert(progress.durationMs.get("latestOffset") === 50) + assert(progress.durationMs.get("queryPlanning") === 100) assert(progress.durationMs.get("walCommit") === 0) - assert(progress.durationMs.get("addBatch") === 150) + assert(progress.durationMs.get("addBatch") === 350) assert(progress.durationMs.get("triggerExecution") === 500) assert(progress.sources.length === 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index 4f198819b58d2..d6819eacd07ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -22,16 +22,15 @@ import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} import org.mockito.Mockito._ import org.scalatest.mockito.MockitoSugar -import org.apache.spark.{SparkEnv, SparkFunSuite, TaskContext} -import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv} +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.reader.InputPartition -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, ContinuousReadSupport, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.sql.types.{DataType, IntegerType, StructType} class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { case class LongPartitionOffset(offset: Long) extends PartitionOffset @@ -44,8 +43,8 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { override def beforeEach(): Unit = { super.beforeEach() epochEndpoint = EpochCoordinatorRef.create( - mock[StreamWriter], - mock[ContinuousReader], + mock[StreamingWriteSupport], + mock[ContinuousReadSupport], mock[ContinuousExecution], coordinatorId, startEpoch, @@ -73,26 +72,26 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { */ private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = { val queue = new ArrayBlockingQueue[UnsafeRow](1024) - val factory = new InputPartition[InternalRow] { - override def createPartitionReader() = new ContinuousInputPartitionReader[InternalRow] { - var index = -1 - var curr: UnsafeRow = _ - - override def next() = { - curr = queue.take() - index += 1 - true - } + val partitionReader = new ContinuousPartitionReader[InternalRow] { + var index = -1 + var curr: UnsafeRow = _ + + override def next() = { + curr = queue.take() + index += 1 + true + } - override def get = curr + override def get = curr - override def getOffset = LongPartitionOffset(index) + override def getOffset = LongPartitionOffset(index) - override def close() = {} - } + override def close() = {} } val reader = new ContinuousQueuedDataReader( - new ContinuousDataSourceRDDPartition(0, factory), + 0, + partitionReader, + new StructType().add("i", "int"), mockContext, dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize, epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 4980b0cd41f81..3d21bc63e0cc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -41,7 +41,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReader) => r + case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReadSupport, _) => r }.get val deltaMs = numTriggers * 1000 + 300 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala index 82836dced9df7..3c973d8ebc704 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -27,9 +27,9 @@ import org.apache.spark._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.LocalSparkSession import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.test.TestSparkSession class EpochCoordinatorSuite @@ -40,20 +40,20 @@ class EpochCoordinatorSuite private var epochCoordinator: RpcEndpointRef = _ - private var writer: StreamWriter = _ + private var writeSupport: StreamingWriteSupport = _ private var query: ContinuousExecution = _ private var orderVerifier: InOrder = _ override def beforeEach(): Unit = { - val reader = mock[ContinuousReader] - writer = mock[StreamWriter] + val reader = mock[ContinuousReadSupport] + writeSupport = mock[StreamingWriteSupport] query = mock[ContinuousExecution] - orderVerifier = inOrder(writer, query) + orderVerifier = inOrder(writeSupport, query) spark = new TestSparkSession() epochCoordinator - = EpochCoordinatorRef.create(writer, reader, query, "test", 1, spark, SparkEnv.get) + = EpochCoordinatorRef.create(writeSupport, reader, query, "test", 1, spark, SparkEnv.get) } test("single epoch") { @@ -209,12 +209,12 @@ class EpochCoordinatorSuite } private def verifyCommit(epoch: Long): Unit = { - orderVerifier.verify(writer).commit(eqTo(epoch), any()) + orderVerifier.verify(writeSupport).commit(eqTo(epoch), any()) orderVerifier.verify(query).commit(epoch) } private def verifyNoCommitFor(epoch: Long): Unit = { - verify(writer, never()).commit(eqTo(epoch), any()) + verify(writeSupport, never()).commit(eqTo(epoch), any()) verify(query, never()).commit(epoch) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 52b833a19c236..aeef4c8fe9332 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -17,73 +17,74 @@ package org.apache.spark.sql.streaming.sources -import java.util.Optional - import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, StreamingQueryWrapper} import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.InputPartition -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReaderFactory, ScanConfig, ScanConfigBuilder} +import org.apache.spark.sql.sources.v2.reader.streaming._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils -case class FakeReader() extends MicroBatchReader with ContinuousReader { - def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = {} - def getStartOffset: Offset = RateStreamOffset(Map()) - def getEndOffset: Offset = RateStreamOffset(Map()) - def deserializeOffset(json: String): Offset = RateStreamOffset(Map()) - def commit(end: Offset): Unit = {} - def readSchema(): StructType = StructType(Seq()) - def stop(): Unit = {} - def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) - def setStartOffset(start: Optional[Offset]): Unit = {} - - def planInputPartitions(): java.util.ArrayList[InputPartition[InternalRow]] = { +case class FakeReadSupport() extends MicroBatchReadSupport with ContinuousReadSupport { + override def deserializeOffset(json: String): Offset = RateStreamOffset(Map()) + override def commit(end: Offset): Unit = {} + override def stop(): Unit = {} + override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) + override def fullSchema(): StructType = StructType(Seq()) + override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = null + override def initialOffset(): Offset = RateStreamOffset(Map()) + override def latestOffset(): Offset = RateStreamOffset(Map()) + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = null + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + throw new IllegalStateException("fake source - cannot actually read") + } + override def createContinuousReaderFactory( + config: ScanConfig): ContinuousPartitionReaderFactory = { + throw new IllegalStateException("fake source - cannot actually read") + } + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { throw new IllegalStateException("fake source - cannot actually read") } } -trait FakeMicroBatchReadSupport extends MicroBatchReadSupport { - override def createMicroBatchReader( - schema: Optional[StructType], +trait FakeMicroBatchReadSupportProvider extends MicroBatchReadSupportProvider { + override def createMicroBatchReadSupport( checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = FakeReader() + options: DataSourceOptions): MicroBatchReadSupport = FakeReadSupport() } -trait FakeContinuousReadSupport extends ContinuousReadSupport { - override def createContinuousReader( - schema: Optional[StructType], +trait FakeContinuousReadSupportProvider extends ContinuousReadSupportProvider { + override def createContinuousReadSupport( checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = FakeReader() + options: DataSourceOptions): ContinuousReadSupport = FakeReadSupport() } -trait FakeStreamWriteSupport extends StreamWriteSupport { - override def createStreamWriter( +trait FakeStreamingWriteSupportProvider extends StreamingWriteSupportProvider { + override def createStreamingWriteSupport( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamWriter = { + options: DataSourceOptions): StreamingWriteSupport = { throw new IllegalStateException("fake sink - cannot actually write") } } -class FakeReadMicroBatchOnly extends DataSourceRegister with FakeMicroBatchReadSupport { +class FakeReadMicroBatchOnly extends DataSourceRegister with FakeMicroBatchReadSupportProvider { override def shortName(): String = "fake-read-microbatch-only" } -class FakeReadContinuousOnly extends DataSourceRegister with FakeContinuousReadSupport { +class FakeReadContinuousOnly extends DataSourceRegister with FakeContinuousReadSupportProvider { override def shortName(): String = "fake-read-continuous-only" } class FakeReadBothModes extends DataSourceRegister - with FakeMicroBatchReadSupport with FakeContinuousReadSupport { + with FakeMicroBatchReadSupportProvider with FakeContinuousReadSupportProvider { override def shortName(): String = "fake-read-microbatch-continuous" } @@ -91,7 +92,7 @@ class FakeReadNeitherMode extends DataSourceRegister { override def shortName(): String = "fake-read-neither-mode" } -class FakeWrite extends DataSourceRegister with FakeStreamWriteSupport { +class FakeWriteSupportProvider extends DataSourceRegister with FakeStreamingWriteSupportProvider { override def shortName(): String = "fake-write-microbatch-continuous" } @@ -106,8 +107,8 @@ class FakeSink extends Sink { override def addBatch(batchId: Long, data: DataFrame): Unit = {} } -class FakeWriteV1Fallback extends DataSourceRegister - with FakeStreamWriteSupport with StreamSinkProvider { +class FakeWriteSupportProviderV1Fallback extends DataSourceRegister + with FakeStreamingWriteSupportProvider with StreamSinkProvider { override def createSink( sqlContext: SQLContext, @@ -190,11 +191,11 @@ class StreamingDataSourceV2Suite extends StreamTest { val v2Query = testPositiveCase( "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) assert(v2Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink - .isInstanceOf[FakeWriteV1Fallback]) + .isInstanceOf[FakeWriteSupportProviderV1Fallback]) // Ensure we create a V1 sink with the config. Note the config is a comma separated // list, including other fake entries. - val fullSinkName = "org.apache.spark.sql.streaming.sources.FakeWriteV1Fallback" + val fullSinkName = classOf[FakeWriteSupportProviderV1Fallback].getName withSQLConf(SQLConf.DISABLED_V2_STREAMING_WRITERS.key -> s"a,b,c,test,$fullSinkName,d,e") { val v1Query = testPositiveCase( "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) @@ -218,35 +219,37 @@ class StreamingDataSourceV2Suite extends StreamTest { val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf).newInstance() (readSource, writeSource, trigger) match { // Valid microbatch queries. - case (_: MicroBatchReadSupport, _: StreamWriteSupport, t) + case (_: MicroBatchReadSupportProvider, _: StreamingWriteSupportProvider, t) if !t.isInstanceOf[ContinuousTrigger] => testPositiveCase(read, write, trigger) // Valid continuous queries. - case (_: ContinuousReadSupport, _: StreamWriteSupport, _: ContinuousTrigger) => + case (_: ContinuousReadSupportProvider, _: StreamingWriteSupportProvider, + _: ContinuousTrigger) => testPositiveCase(read, write, trigger) // Invalid - can't read at all case (r, _, _) - if !r.isInstanceOf[MicroBatchReadSupport] - && !r.isInstanceOf[ContinuousReadSupport] => + if !r.isInstanceOf[MicroBatchReadSupportProvider] + && !r.isInstanceOf[ContinuousReadSupportProvider] => testNegativeCase(read, write, trigger, s"Data source $read does not support streamed reading") // Invalid - can't write - case (_, w, _) if !w.isInstanceOf[StreamWriteSupport] => + case (_, w, _) if !w.isInstanceOf[StreamingWriteSupportProvider] => testNegativeCase(read, write, trigger, s"Data source $write does not support streamed writing") // Invalid - trigger is continuous but reader is not - case (r, _: StreamWriteSupport, _: ContinuousTrigger) - if !r.isInstanceOf[ContinuousReadSupport] => + case (r, _: StreamingWriteSupportProvider, _: ContinuousTrigger) + if !r.isInstanceOf[ContinuousReadSupportProvider] => testNegativeCase(read, write, trigger, s"Data source $read does not support continuous processing") // Invalid - trigger is microbatch but reader is not case (r, _, t) - if !r.isInstanceOf[MicroBatchReadSupport] && !t.isInstanceOf[ContinuousTrigger] => + if !r.isInstanceOf[MicroBatchReadSupportProvider] && + !t.isInstanceOf[ContinuousTrigger] => testPostCreationNegativeCase(read, write, trigger, s"Data source $read does not support microbatch processing") } From 71f38ac242157cbede684546159f2a27892ee09f Mon Sep 17 00:00:00 2001 From: cclauss Date: Wed, 22 Aug 2018 10:06:59 -0700 Subject: [PATCH 1451/2461] [SPARK-23698][PYTHON] Resolve undefined names in Python 3 ## What changes were proposed in this pull request? Fix issues arising from the fact that builtins __file__, __long__, __raw_input()__, __unicode__, __xrange()__, etc. were all removed from Python 3. __Undefined names__ have the potential to raise [NameError](https://docs.python.org/3/library/exceptions.html#NameError) at runtime. ## How was this patch tested? * $ __python2 -m flake8 . --count --select=E9,F82 --show-source --statistics__ * $ __python3 -m flake8 . --count --select=E9,F82 --show-source --statistics__ holdenk flake8 testing of https://github.com/apache/spark on Python 3.6.3 $ __python3 -m flake8 . --count --select=E901,E999,F821,F822,F823 --show-source --statistics__ ``` ./dev/merge_spark_pr.py:98:14: F821 undefined name 'raw_input' result = raw_input("\n%s (y/n): " % prompt) ^ ./dev/merge_spark_pr.py:136:22: F821 undefined name 'raw_input' primary_author = raw_input( ^ ./dev/merge_spark_pr.py:186:16: F821 undefined name 'raw_input' pick_ref = raw_input("Enter a branch name [%s]: " % default_branch) ^ ./dev/merge_spark_pr.py:233:15: F821 undefined name 'raw_input' jira_id = raw_input("Enter a JIRA id [%s]: " % default_jira_id) ^ ./dev/merge_spark_pr.py:278:20: F821 undefined name 'raw_input' fix_versions = raw_input("Enter comma-separated fix version(s) [%s]: " % default_fix_versions) ^ ./dev/merge_spark_pr.py:317:28: F821 undefined name 'raw_input' raw_assignee = raw_input( ^ ./dev/merge_spark_pr.py:430:14: F821 undefined name 'raw_input' pr_num = raw_input("Which pull request would you like to merge? (e.g. 34): ") ^ ./dev/merge_spark_pr.py:442:18: F821 undefined name 'raw_input' result = raw_input("Would you like to use the modified title? (y/n): ") ^ ./dev/merge_spark_pr.py:493:11: F821 undefined name 'raw_input' while raw_input("\n%s (y/n): " % pick_prompt).lower() == "y": ^ ./dev/create-release/releaseutils.py:58:16: F821 undefined name 'raw_input' response = raw_input("%s [y/n]: " % msg) ^ ./dev/create-release/releaseutils.py:152:38: F821 undefined name 'unicode' author = unidecode.unidecode(unicode(author, "UTF-8")).strip() ^ ./python/setup.py:37:11: F821 undefined name '__version__' VERSION = __version__ ^ ./python/pyspark/cloudpickle.py:275:18: F821 undefined name 'buffer' dispatch[buffer] = save_buffer ^ ./python/pyspark/cloudpickle.py:807:18: F821 undefined name 'file' dispatch[file] = save_file ^ ./python/pyspark/sql/conf.py:61:61: F821 undefined name 'unicode' if not isinstance(obj, str) and not isinstance(obj, unicode): ^ ./python/pyspark/sql/streaming.py:25:21: F821 undefined name 'long' intlike = (int, long) ^ ./python/pyspark/streaming/dstream.py:405:35: F821 undefined name 'long' return self._sc._jvm.Time(long(timestamp * 1000)) ^ ./sql/hive/src/test/resources/data/scripts/dumpdata_script.py:21:10: F821 undefined name 'xrange' for i in xrange(50): ^ ./sql/hive/src/test/resources/data/scripts/dumpdata_script.py:22:14: F821 undefined name 'xrange' for j in xrange(5): ^ ./sql/hive/src/test/resources/data/scripts/dumpdata_script.py:23:18: F821 undefined name 'xrange' for k in xrange(20022): ^ 20 F821 undefined name 'raw_input' 20 ``` Closes #20838 from cclauss/fix-undefined-names. Authored-by: cclauss Signed-off-by: Bryan Cutler --- dev/create-release/releaseutils.py | 8 +++-- dev/merge_spark_pr.py | 2 +- python/pyspark/sql/conf.py | 5 ++- python/pyspark/sql/streaming.py | 5 +-- python/pyspark/streaming/dstream.py | 2 ++ python/pyspark/streaming/tests.py | 34 ++++++++++++++++++- .../resources/data/scripts/dumpdata_script.py | 3 ++ 7 files changed, 50 insertions(+), 9 deletions(-) diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index ab812e1bb7c04..8cc990d871842 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -50,7 +50,7 @@ sys.exit(-1) if sys.version < '3': - input = raw_input + input = raw_input # noqa # Contributors list file name contributors_file_name = "contributors.txt" @@ -152,7 +152,11 @@ def get_commits(tag): if not is_valid_author(author): author = github_username # Guard against special characters - author = unidecode.unidecode(unicode(author, "UTF-8")).strip() + try: # Python 2 + author = unicode(author, "UTF-8") + except NameError: # Python 3 + author = str(author) + author = unidecode.unidecode(author).strip() commit = Commit(_hash, author, title, pr_number) commits.append(commit) return commits diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index fe05282efdd4d..28a6714856c10 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -40,7 +40,7 @@ JIRA_IMPORTED = False if sys.version < '3': - input = raw_input + input = raw_input # noqa # Location of your Spark git development area SPARK_HOME = os.environ.get("SPARK_HOME", os.getcwd()) diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index f80bf598c2211..71ea1631718f1 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -20,6 +20,9 @@ from pyspark import since, _NoValue from pyspark.rdd import ignore_unicode_prefix +if sys.version_info[0] >= 3: + basestring = str + class RuntimeConfig(object): """User-facing configuration API, accessible through `SparkSession.conf`. @@ -59,7 +62,7 @@ def unset(self, key): def _checkType(self, obj, identifier): """Assert that an object is of type str.""" - if not isinstance(obj, str) and not isinstance(obj, unicode): + if not isinstance(obj, basestring): raise TypeError("expected %s '%s' to be a string (was '%s')" % (identifier, obj, type(obj).__name__)) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 8c1fd4af674d7..ee13778a7dcd6 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -19,10 +19,7 @@ import json if sys.version >= '3': - intlike = int - basestring = unicode = str -else: - intlike = (int, long) + basestring = str from py4j.java_gateway import java_import diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 59977dcb435a8..ce42a857d0c06 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -23,6 +23,8 @@ if sys.version < "3": from itertools import imap as map, ifilter as filter +else: + long = int from py4j.protocol import Py4JJavaError diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 09af47a597bed..5cef621a28e6e 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -179,7 +179,7 @@ def func(dstream): self._test_func(input, func, expected) def test_flatMap(self): - """Basic operation test for DStream.faltMap.""" + """Basic operation test for DStream.flatMap.""" input = [range(1, 5), range(5, 9), range(9, 13)] def func(dstream): @@ -206,6 +206,38 @@ def func(dstream): expected = [[len(x)] for x in input] self._test_func(input, func, expected) + def test_slice(self): + """Basic operation test for DStream.slice.""" + import datetime as dt + self.ssc = StreamingContext(self.sc, 1.0) + self.ssc.remember(4.0) + input = [[1], [2], [3], [4]] + stream = self.ssc.queueStream([self.sc.parallelize(d, 1) for d in input]) + + time_vals = [] + + def get_times(t, rdd): + if rdd and len(time_vals) < len(input): + time_vals.append(t) + + stream.foreachRDD(get_times) + + self.ssc.start() + self.wait_for(time_vals, 4) + begin_time = time_vals[0] + + def get_sliced(begin_delta, end_delta): + begin = begin_time + dt.timedelta(seconds=begin_delta) + end = begin_time + dt.timedelta(seconds=end_delta) + rdds = stream.slice(begin, end) + result_list = [rdd.collect() for rdd in rdds] + return [r for result in result_list for r in result] + + self.assertEqual(set([1]), set(get_sliced(0, 0))) + self.assertEqual(set([2, 3]), set(get_sliced(1, 2))) + self.assertEqual(set([2, 3, 4]), set(get_sliced(1, 4))) + self.assertEqual(set([1, 2, 3, 4]), set(get_sliced(0, 4))) + def test_reduce(self): """Basic operation test for DStream.reduce.""" input = [range(1, 5), range(5, 9), range(9, 13)] diff --git a/sql/hive/src/test/resources/data/scripts/dumpdata_script.py b/sql/hive/src/test/resources/data/scripts/dumpdata_script.py index 341a1b40e07af..5b360208d36f6 100644 --- a/sql/hive/src/test/resources/data/scripts/dumpdata_script.py +++ b/sql/hive/src/test/resources/data/scripts/dumpdata_script.py @@ -18,6 +18,9 @@ # import sys +if sys.version_info[0] >= 3: + xrange = range + for i in xrange(50): for j in xrange(5): for k in xrange(20022): From 2381953ab5d9e86d87a9ef118f28bc3f67d6d805 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Wed, 22 Aug 2018 10:16:47 -0700 Subject: [PATCH 1452/2461] [SPARK-25105][PYSPARK][SQL] Include PandasUDFType in the import all of pyspark.sql.functions ## What changes were proposed in this pull request? Include PandasUDFType in the import all of pyspark.sql.functions ## How was this patch tested? Run the test case from the pyspark shell from the jira [spark-25105](https://jira.apache.org/jira/browse/SPARK-25105?jql=project%20%3D%20SPARK%20AND%20component%20in%20(ML%2C%20PySpark%2C%20SQL%2C%20%22Structured%20Streaming%22)) I manually test on pyspark-shell: before: ` >>> from pyspark.sql.functions import * >>> foo = pandas_udf(lambda x: x, 'v int', PandasUDFType.GROUPED_MAP) Traceback (most recent call last): File "", line 1, in NameError: name 'PandasUDFType' is not defined >>> ` after: ` >>> from pyspark.sql.functions import * >>> foo = pandas_udf(lambda x: x, 'v int', PandasUDFType.GROUPED_MAP) >>> ` Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22100 from kevinyu98/spark-25105. Authored-by: Kevin Yu Signed-off-by: Bryan Cutler --- python/pyspark/sql/functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f5833734103a4..d58d8d10e5cd3 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2931,6 +2931,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): blacklist = ['map', 'since', 'ignore_unicode_prefix'] __all__ = [k for k, v in globals().items() if not k.startswith('_') and k[0].islower() and callable(v) and k not in blacklist] +__all__ += ["PandasUDFType"] __all__.sort() From 68ec4d641b87d2ab6a8cafc5d10c08253ae09e3d Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Wed, 22 Aug 2018 10:36:20 -0700 Subject: [PATCH 1453/2461] [SPARK-25181][CORE] Limit Thread Pool size in BlockManager Master and Slave endpoints ## What changes were proposed in this pull request? Limit Thread Pool size in BlockManager Master and Slave endpoints. Currently, BlockManagerMasterEndpoint and BlockManagerSlaveEndpoint both have thread pools with nearly unbounded (Integer.MAX_VALUE) numbers of threads. In certain cases, this can lead to driver OOM errors. This change limits the thread pools to 100 threads; this should not break any existing behavior because any tasks beyond that number will get queued. ## How was this patch tested? Manual testing Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22176 from mukulmurthy/25181-threads. Authored-by: Mukul Murthy Signed-off-by: Shixiong Zhu --- .../org/apache/spark/storage/BlockManagerMasterEndpoint.scala | 3 ++- .../org/apache/spark/storage/BlockManagerSlaveEndpoint.scala | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 8e8f7d197c9ef..f984cf76e3463 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -54,7 +54,8 @@ class BlockManagerMasterEndpoint( // Mapping from block id to the set of block managers that have the block. private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] - private val askThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-ask-thread-pool") + private val askThreadPool = + ThreadUtils.newDaemonCachedThreadPool("block-manager-ask-thread-pool", 100) private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool) private val topologyMapper = { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index 742cf4fe393f9..67544b20408a6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -37,7 +37,7 @@ class BlockManagerSlaveEndpoint( extends ThreadSafeRpcEndpoint with Logging { private val asyncThreadPool = - ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool") + ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool", 100) private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool) // Operations that involve removing blocks may be slow and should be done asynchronously From 3106324986612800240bc8c945be90c4cb368d79 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 22 Aug 2018 12:22:53 -0700 Subject: [PATCH 1454/2461] [SPARK-25184][SS] Fixed race condition in StreamExecution that caused flaky test in FlatMapGroupsWithState ## What changes were proposed in this pull request? The race condition that caused test failure is between 2 threads. - The MicrobatchExecution thread that processes inputs to produce answers and then generates progress events. - The test thread that generates some input data, checked the answer and then verified the query generated progress event. The synchronization structure between these threads is as follows 1. MicrobatchExecution thread, in every batch, does the following in order. a. Processes batch input to generate answer. b. Signals `awaitProgressLockCondition` to wake up threads waiting for progress using `awaitOffset` c. Generates progress event 2. Test execution thread a. Calls `awaitOffset` to wait for progress, which waits on `awaitProgressLockCondition`. b. As soon as `awaitProgressLockCondition` is signaled, it would move on the in the test to check answer. c. Finally, it would verify the last generated progress event. What can happen is the following sequence of events: 2a -> 1a -> 1b -> 2b -> 2c -> 1c. In other words, the progress event may be generated after the test tries to verify it. The solution has two steps. 1. Signal the waiting thread after the progress event has been generated, that is, after `finishTrigger()`. 2. Increase the timeout of `awaitProgressLockCondition.await(100 ms)` to a large value. This latter is to ensure that test thread for keeps waiting on `awaitProgressLockCondition`until the MicroBatchExecution thread explicitly signals it. With the existing small timeout of 100ms the following sequence can occur. - MicroBatchExecution thread updates committed offsets - Test thread waiting on `awaitProgressLockCondition` accidentally times out after 100 ms, finds that the committed offsets have been updated, therefore returns from `awaitOffset` and moves on to the progress event tests. - MicroBatchExecution thread then generates progress event and signals. But the test thread has already attempted to verify the event and failed. By increasing the timeout to large (e.g., `streamingTimeoutMs = 60 seconds`, similar to `awaitInitialization`), this above type of race condition is also avoided. ## How was this patch tested? Ran locally many times. Closes #22182 from tdas/SPARK-25184. Authored-by: Tathagata Das Signed-off-by: Tathagata Das --- .../kafka010/KafkaMicroBatchSourceSuite.scala | 3 +- .../streaming/MicroBatchExecution.scala | 5 ++- .../execution/streaming/StreamExecution.scala | 4 +- .../sql/streaming/StateStoreMetricsTest.scala | 44 ++++++++++--------- .../spark/sql/streaming/StreamTest.scala | 2 +- 5 files changed, 33 insertions(+), 25 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 946b636710f0d..c9c52503dcd1f 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -970,7 +970,8 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest { makeSureGetOffsetCalled, Execute { q => // wait to reach the last offset in every partition - q.awaitOffset(0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L))) + q.awaitOffset( + 0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L)), streamingTimeout.toMillis) }, CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22), StopStream, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index cf83ba7436d17..b1cafd67820c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -200,6 +200,10 @@ class MicroBatchExecution( finishTrigger(currentBatchHasNewData) // Must be outside reportTimeTaken so it is recorded + // Signal waiting threads. Note this must be after finishTrigger() to ensure all + // activities (progress generation, etc.) have completed before signaling. + withProgressLocked { awaitProgressLockCondition.signalAll() } + // If the current batch has been executed, then increment the batch id and reset flag. // Otherwise, there was no data to execute the batch and sleep for some time if (isCurrentBatchConstructed) { @@ -538,7 +542,6 @@ class MicroBatchExecution( watermarkTracker.updateWatermark(lastExecution.executedPlan) commitLog.add(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark)) committedOffsets ++= availableOffsets - awaitProgressLockCondition.signalAll() } logDebug(s"Completed batch ${currentBatchId}") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 290de873c5cfb..a39bb715c9913 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -382,7 +382,7 @@ abstract class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. */ - private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = { + private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset, timeoutMs: Long): Unit = { assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets @@ -398,7 +398,7 @@ abstract class StreamExecution( while (notDone) { awaitProgressLock.lock() try { - awaitProgressLockCondition.await(100, TimeUnit.MILLISECONDS) + awaitProgressLockCondition.await(timeoutMs, TimeUnit.MILLISECONDS) if (streamDeathCause != null) { throw streamDeathCause } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala index e45f9d3e2e97b..fb5d13d09fb0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala @@ -31,33 +31,37 @@ trait StateStoreMetricsTest extends StreamTest { def assertNumStateRows(total: Seq[Long], updated: Seq[Long]): AssertOnQuery = AssertOnQuery(s"Check total state rows = $total, updated state rows = $updated") { q => - val recentProgress = q.recentProgress - require(recentProgress.nonEmpty, "No progress made, cannot check num state rows") - require(recentProgress.length < spark.sessionState.conf.streamingProgressRetention, - "This test assumes that all progresses are present in q.recentProgress but " + - "some may have been dropped due to retention limits") + // This assumes that the streaming query will not make any progress while the eventually + // is being executed. + eventually(timeout(streamingTimeout)) { + val recentProgress = q.recentProgress + require(recentProgress.nonEmpty, "No progress made, cannot check num state rows") + require(recentProgress.length < spark.sessionState.conf.streamingProgressRetention, + "This test assumes that all progresses are present in q.recentProgress but " + + "some may have been dropped due to retention limits") - if (q.ne(lastQuery)) lastCheckedRecentProgressIndex = -1 - lastQuery = q + if (q.ne(lastQuery)) lastCheckedRecentProgressIndex = -1 + lastQuery = q - val numStateOperators = recentProgress.last.stateOperators.length - val progressesSinceLastCheck = recentProgress - .slice(lastCheckedRecentProgressIndex + 1, recentProgress.length) - .filter(_.stateOperators.length == numStateOperators) + val numStateOperators = recentProgress.last.stateOperators.length + val progressesSinceLastCheck = recentProgress + .slice(lastCheckedRecentProgressIndex + 1, recentProgress.length) + .filter(_.stateOperators.length == numStateOperators) - val allNumUpdatedRowsSinceLastCheck = - progressesSinceLastCheck.map(_.stateOperators.map(_.numRowsUpdated)) + val allNumUpdatedRowsSinceLastCheck = + progressesSinceLastCheck.map(_.stateOperators.map(_.numRowsUpdated)) - lazy val debugString = "recent progresses:\n" + - progressesSinceLastCheck.map(_.prettyJson).mkString("\n\n") + lazy val debugString = "recent progresses:\n" + + progressesSinceLastCheck.map(_.prettyJson).mkString("\n\n") - val numTotalRows = recentProgress.last.stateOperators.map(_.numRowsTotal) - assert(numTotalRows === total, s"incorrect total rows, $debugString") + val numTotalRows = recentProgress.last.stateOperators.map(_.numRowsTotal) + assert(numTotalRows === total, s"incorrect total rows, $debugString") - val numUpdatedRows = arraySum(allNumUpdatedRowsSinceLastCheck, numStateOperators) - assert(numUpdatedRows === updated, s"incorrect updates rows, $debugString") + val numUpdatedRows = arraySum(allNumUpdatedRowsSinceLastCheck, numStateOperators) + assert(numUpdatedRows === updated, s"incorrect updates rows, $debugString") - lastCheckedRecentProgressIndex = recentProgress.length - 1 + lastCheckedRecentProgressIndex = recentProgress.length - 1 + } true } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index b528006295179..cd9b892eca1f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -467,7 +467,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be // Block until all data added has been processed for all the source awaiting.foreach { case (sourceIndex, offset) => failAfter(streamingTimeout) { - currentStream.awaitOffset(sourceIndex, offset) + currentStream.awaitOffset(sourceIndex, offset, streamingTimeout.toMillis) // Make sure all processing including no-data-batches have been executed if (!currentStream.triggerClock.isInstanceOf[StreamManualClock]) { currentStream.processAllAvailable() From 49a1993b168accb6f188c682546f12ea568173c4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 22 Aug 2018 14:17:05 -0700 Subject: [PATCH 1455/2461] [SPARK-25163][SQL] Fix flaky test: o.a.s.util.collection.ExternalAppendOnlyMapSuiteCheck ## What changes were proposed in this pull request? `ExternalAppendOnlyMapSuiteCheck` test is flaky. We use a `SparkListener` to collect spill metrics of completed stages. `withListener` runs the code that does spill. Spill status was checked after the code finishes but it was still in `withListener`. At that time it was possibly not all events to the listener bus are processed. We should check spill status after all events are processed. ## How was this patch tested? Locally ran unit tests. Closes #22181 from viirya/SPARK-25163. Authored-by: Liang-Chi Hsieh Signed-off-by: Shixiong Zhu --- core/src/main/scala/org/apache/spark/TestUtils.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 6cc8fe1173d2e..c2ebd388a2365 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -173,10 +173,11 @@ private[spark] object TestUtils { * Run some code involving jobs submitted to the given context and assert that the jobs spilled. */ def assertSpilled(sc: SparkContext, identifier: String)(body: => Unit): Unit = { - withListener(sc, new SpillListener) { listener => + val listener = new SpillListener + withListener(sc, listener) { _ => body - assert(listener.numSpilledStages > 0, s"expected $identifier to spill, but did not") } + assert(listener.numSpilledStages > 0, s"expected $identifier to spill, but did not") } /** @@ -184,10 +185,11 @@ private[spark] object TestUtils { * did not spill. */ def assertNotSpilled(sc: SparkContext, identifier: String)(body: => Unit): Unit = { - withListener(sc, new SpillListener) { listener => + val listener = new SpillListener + withListener(sc, listener) { _ => body - assert(listener.numSpilledStages == 0, s"expected $identifier to not spill, but did") } + assert(listener.numSpilledStages == 0, s"expected $identifier to not spill, but did") } /** From 2bc7b75537ec81184048738883b282e257cc58de Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 22 Aug 2018 23:14:56 +0000 Subject: [PATCH 1456/2461] [SPARK-24785][SHELL] Making sure REPL prints Spark UI info and then Welcome message ## What changes were proposed in this pull request? After https://github.com/apache/spark/pull/21495 the welcome message is printed first, and then Scala prompt will be shown before the Spark UI info is printed. Although it's a minor issue, but visually, it doesn't look as nice as the existing behavior. This PR intends to fix it by duplicating the Scala `process` code to arrange the printing order. However, one variable is private, so reflection has to be used which is not desirable. We can use this PR to brainstorm how to handle it properly and how Scala can change their APIs to fit our need. ## How was this patch tested? Existing test Closes #21749 from dbtsai/repl-followup. Authored-by: DB Tsai Signed-off-by: DB Tsai --- .../org/apache/spark/repl/SparkILoop.scala | 138 +++++++++++++++++- .../spark/repl/SparkILoopInterpreter.scala | 18 +-- 2 files changed, 138 insertions(+), 18 deletions(-) diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index a44051b351e19..94265267b1f97 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -22,8 +22,16 @@ import java.io.BufferedReader // scalastyle:off println import scala.Predef.{println => _, _} // scalastyle:on println +import scala.concurrent.Future +import scala.reflect.classTag +import scala.reflect.internal.util.ScalaClassLoader.savingContextLoader +import scala.reflect.io.File +import scala.tools.nsc.{GenericRunnerSettings, Properties} import scala.tools.nsc.Settings -import scala.tools.nsc.interpreter.{ILoop, JPrintWriter} +import scala.tools.nsc.interpreter.{isReplDebug, isReplPower, replProps} +import scala.tools.nsc.interpreter.{AbstractOrMissingHandler, ILoop, IMain, JPrintWriter} +import scala.tools.nsc.interpreter.{NamedParam, SimpleReader, SplashLoop, SplashReader} +import scala.tools.nsc.interpreter.StdReplTags.tagOfIMain import scala.tools.nsc.util.stringFromStream import scala.util.Properties.{javaVersion, javaVmName, versionString} @@ -36,7 +44,7 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def this() = this(None, new JPrintWriter(Console.out, true)) override def createInterpreter(): Unit = { - intp = new SparkILoopInterpreter(settings, out, initializeSpark) + intp = new SparkILoopInterpreter(settings, out) } val initializationCommands: Seq[String] = Seq( @@ -116,6 +124,132 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) super.replay() } + /** + * The following code is mostly a copy of `process` implementation in `ILoop.scala` in Scala + * + * In newer version of Scala, `printWelcome` is the first thing to be called. As a result, + * SparkUI URL information would be always shown after the welcome message. + * + * However, this is inconsistent compared with the existing version of Spark which will always + * show SparkUI URL first. + * + * The only way we can make it consistent will be duplicating the Scala code. + * + * We should remove this duplication once Scala provides a way to load our custom initialization + * code, and also customize the ordering of printing welcome message. + */ + override def process(settings: Settings): Boolean = savingContextLoader { + + def newReader = in0.fold(chooseReader(settings))(r => SimpleReader(r, out, interactive = true)) + + /** Reader to use before interpreter is online. */ + def preLoop = { + val sr = SplashReader(newReader) { r => + in = r + in.postInit() + } + in = sr + SplashLoop(sr, prompt) + } + + /* Actions to cram in parallel while collecting first user input at prompt. + * Run with output muted both from ILoop and from the intp reporter. + */ + def loopPostInit(): Unit = mumly { + // Bind intp somewhere out of the regular namespace where + // we can get at it in generated code. + intp.quietBind(NamedParam[IMain]("$intp", intp)(tagOfIMain, classTag[IMain])) + + // Auto-run code via some setting. + ( replProps.replAutorunCode.option + flatMap (f => File(f).safeSlurp()) + foreach (intp quietRun _) + ) + // power mode setup + if (isReplPower) enablePowerMode(true) + initializeSpark() + loadInitFiles() + // SI-7418 Now, and only now, can we enable TAB completion. + in.postInit() + } + def loadInitFiles(): Unit = settings match { + case settings: GenericRunnerSettings => + for (f <- settings.loadfiles.value) { + loadCommand(f) + addReplay(s":load $f") + } + for (f <- settings.pastefiles.value) { + pasteCommand(f) + addReplay(s":paste $f") + } + case _ => + } + // wait until after startup to enable noisy settings + def withSuppressedSettings[A](body: => A): A = { + val ss = this.settings + import ss._ + val noisy = List(Xprint, Ytyperdebug) + val noisesome = noisy.exists(!_.isDefault) + val current = (Xprint.value, Ytyperdebug.value) + if (isReplDebug || !noisesome) body + else { + this.settings.Xprint.value = List.empty + this.settings.Ytyperdebug.value = false + try body + finally { + Xprint.value = current._1 + Ytyperdebug.value = current._2 + intp.global.printTypings = current._2 + } + } + } + def startup(): String = withSuppressedSettings { + // let them start typing + val splash = preLoop + + // while we go fire up the REPL + try { + // don't allow ancient sbt to hijack the reader + savingReader { + createInterpreter() + } + intp.initializeSynchronous() + + val field = classOf[ILoop].getDeclaredFields.filter(_.getName.contains("globalFuture")).head + field.setAccessible(true) + field.set(this, Future successful true) + + if (intp.reporter.hasErrors) { + echo("Interpreter encountered errors during initialization!") + null + } else { + loopPostInit() + printWelcome() + splash.start() + + val line = splash.line // what they typed in while they were waiting + if (line == null) { // they ^D + try out print Properties.shellInterruptedString + finally closeInterpreter() + } + line + } + } finally splash.stop() + } + + this.settings = settings + startup() match { + case null => false + case line => + try loop(line) match { + case LineResults.EOF => out print Properties.shellInterruptedString + case _ => + } + catch AbstractOrMissingHandler() + finally closeInterpreter() + true + } + } } object SparkILoop { diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala index 4e63816402a10..e736607a9a6b9 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala @@ -21,22 +21,8 @@ import scala.collection.mutable import scala.tools.nsc.Settings import scala.tools.nsc.interpreter._ -class SparkILoopInterpreter(settings: Settings, out: JPrintWriter, initializeSpark: () => Unit) - extends IMain(settings, out) { self => - - /** - * We override `initializeSynchronous` to initialize Spark *after* `intp` is properly initialized - * and *before* the REPL sees any files in the private `loadInitFiles` functions, so that - * the Spark context is visible in those files. - * - * This is a bit of a hack, but there isn't another hook available to us at this point. - * - * See the discussion in Scala community https://github.com/scala/bug/issues/10913 for detail. - */ - override def initializeSynchronous(): Unit = { - super.initializeSynchronous() - initializeSpark() - } +class SparkILoopInterpreter(settings: Settings, out: JPrintWriter) extends IMain(settings, out) { + self => override lazy val memberHandlers = new { val intp: self.type = self From 0295ad40def41b9a8ccefaaa1a7658899fb632a4 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 23 Aug 2018 08:10:45 +0800 Subject: [PATCH 1457/2461] [SPARK-25127] DataSourceV2: Remove SupportsPushDownCatalystFilters ## What changes were proposed in this pull request? They depend on internal Expression APIs. Let's see how far we can get without it. ## How was this patch tested? Just some code removal. There's no existing tests as far as I can tell so it's easy to remove. Closes #22185 from rxin/SPARK-25127. Authored-by: Reynold Xin Signed-off-by: Wenchen Fan --- .../SupportsPushDownCatalystFilters.java | 57 ------------------- .../v2/reader/SupportsPushDownFilters.java | 4 -- .../datasources/v2/DataSourceV2Strategy.scala | 5 -- 3 files changed, 66 deletions(-) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java deleted file mode 100644 index 9d79a18d14bcf..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.catalyst.expressions.Expression; - -/** - * A mix-in interface for {@link ScanConfigBuilder}. Data source readers can implement this - * interface to push down arbitrary expressions as predicates to the data source. - * This is an experimental and unstable interface as {@link Expression} is not public and may get - * changed in the future Spark versions. - * - * Note that, if data source readers implement both this interface and - * {@link SupportsPushDownFilters}, Spark will ignore {@link SupportsPushDownFilters} and only - * process this interface. - */ -@InterfaceStability.Unstable -public interface SupportsPushDownCatalystFilters extends ScanConfigBuilder { - - /** - * Pushes down filters, and returns filters that need to be evaluated after scanning. - */ - Expression[] pushCatalystFilters(Expression[] filters); - - /** - * Returns the catalyst filters that are pushed to the data source via - * {@link #pushCatalystFilters(Expression[])}. - * - * There are 3 kinds of filters: - * 1. pushable filters which don't need to be evaluated again after scanning. - * 2. pushable filters which still need to be evaluated after scanning, e.g. parquet - * row group filter. - * 3. non-pushable filters. - * Both case 1 and 2 should be considered as pushed filters and should be returned by this method. - * - * It's possible that there is no filters in the query and - * {@link #pushCatalystFilters(Expression[])} is never called, empty array should be returned for - * this case. - */ - Expression[] pushedCatalystFilters(); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index 5d32a8ac60f78..5e7985f645a06 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -23,10 +23,6 @@ /** * A mix-in interface for {@link ScanConfigBuilder}. Data sources can implement this interface to * push down filters to the data source and reduce the size of the data to be read. - * - * Note that, if data source readers implement both this interface and - * {@link SupportsPushDownCatalystFilters}, Spark will ignore this interface and only process - * {@link SupportsPushDownCatalystFilters}. */ @InterfaceStability.Evolving public interface SupportsPushDownFilters extends ScanConfigBuilder { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index fe713ff6c7850..9a3109e7c199e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -40,11 +40,6 @@ object DataSourceV2Strategy extends Strategy { configBuilder: ScanConfigBuilder, filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { configBuilder match { - case r: SupportsPushDownCatalystFilters => - val postScanFilters = r.pushCatalystFilters(filters.toArray) - val pushedFilters = r.pushedCatalystFilters() - (pushedFilters, postScanFilters) - case r: SupportsPushDownFilters => // A map from translated data source filters to original catalyst filter expressions. val translatedFilterToExpr = mutable.HashMap.empty[sources.Filter, Expression] From 1747469a1ff0b0ab6c5545fe6de63ffe42660580 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 23 Aug 2018 10:56:17 +0800 Subject: [PATCH 1458/2461] [SPARK-25167][SPARKR][TEST][MINOR] Minor fixes for R sql tests ## What changes were proposed in this pull request? A few SQL tests for R were failing in my development environment. In this PR, i am attempting to address some of them. Below are the reasons for the failure. - The catalog api tests assumes catalog artifacts named "foo" to be non existent. I think name such as foo and bar are common and i use it frequently. I have changed it to a string that i hope is less likely to collide. - One test assumes that we only have one database in the system. I had more than one and it caused the test to fail. I have changed that check. - One more test which compares two timestamp values fail - i am debugging this now. I will send it as a followup - may be. ## How was this patch tested? Its a test fix. Closes #22161 from dilipbiswal/r-sql-test-fix1. Authored-by: Dilip Biswal Signed-off-by: hyukjinkwon --- R/pkg/tests/fulltests/test_sparkSQL.R | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index bff6e3512ee2f..e1f3cf339e83f 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -734,8 +734,8 @@ test_that("test cache, uncache and clearCache", { clearCache() expect_true(dropTempView("table1")) - expect_error(uncacheTable("foo"), - "Error in uncacheTable : analysis error - Table or view not found: foo") + expect_error(uncacheTable("zxwtyswklpf"), + "Error in uncacheTable : analysis error - Table or view not found: zxwtyswklpf") }) test_that("insertInto() on a registered table", { @@ -3632,11 +3632,11 @@ test_that("Collect on DataFrame when NAs exists at the top of a timestamp column test_that("catalog APIs, currentDatabase, setCurrentDatabase, listDatabases", { expect_equal(currentDatabase(), "default") expect_error(setCurrentDatabase("default"), NA) - expect_error(setCurrentDatabase("foo"), - "Error in setCurrentDatabase : analysis error - Database 'foo' does not exist") + expect_error(setCurrentDatabase("zxwtyswklpf"), + "Error in setCurrentDatabase : analysis error - Database 'zxwtyswklpf' does not exist") dbs <- collect(listDatabases()) expect_equal(names(dbs), c("name", "description", "locationUri")) - expect_equal(dbs[[1]], "default") + expect_equal(which(dbs[, 1] == "default"), 1) }) test_that("catalog APIs, listTables, listColumns, listFunctions", { @@ -3659,8 +3659,9 @@ test_that("catalog APIs, listTables, listColumns, listFunctions", { expect_equal(colnames(c), c("name", "description", "dataType", "nullable", "isPartition", "isBucket")) expect_equal(collect(c)[[1]][[1]], "speed") - expect_error(listColumns("foo", "default"), - "Error in listColumns : analysis error - Table 'foo' does not exist in database 'default'") + expect_error(listColumns("zxwtyswklpf", "default"), + paste("Error in listColumns : analysis error - Table", + "'zxwtyswklpf' does not exist in database 'default'")) f <- listFunctions() expect_true(nrow(f) >= 200) # 250 @@ -3668,8 +3669,9 @@ test_that("catalog APIs, listTables, listColumns, listFunctions", { c("name", "database", "description", "className", "isTemporary")) expect_equal(take(orderBy(f, "className"), 1)$className, "org.apache.spark.sql.catalyst.expressions.Abs") - expect_error(listFunctions("foo_db"), - "Error in listFunctions : analysis error - Database 'foo_db' does not exist") + expect_error(listFunctions("zxwtyswklpf_db"), + paste("Error in listFunctions : analysis error - Database", + "'zxwtyswklpf_db' does not exist")) # recoverPartitions does not work with tempory view expect_error(recoverPartitions("cars"), From 05974f9431e9718a5f331a9892b7d81aca8387a6 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 23 Aug 2018 13:45:49 +0800 Subject: [PATCH 1459/2461] [SPARK-25133][SQL][DOC] Avro data source guide ## What changes were proposed in this pull request? Create documentation for AVRO data source. The new page will be linked in https://spark.apache.org/docs/latest/sql-programming-guide.html For preview please unzip the following file: [AvroDoc.zip](https://github.com/apache/spark/files/2313011/AvroDoc.zip) Closes #22121 from gengliangwang/avroDoc. Authored-by: Gengliang Wang Signed-off-by: Wenchen Fan --- docs/avro-data-source-guide.md | 380 +++++++++++++++++++++++++++++++++ docs/sql-programming-guide.md | 3 + 2 files changed, 383 insertions(+) create mode 100644 docs/avro-data-source-guide.md diff --git a/docs/avro-data-source-guide.md b/docs/avro-data-source-guide.md new file mode 100644 index 0000000000000..d3b81f029d377 --- /dev/null +++ b/docs/avro-data-source-guide.md @@ -0,0 +1,380 @@ +--- +layout: global +title: Apache Avro Data Source Guide +--- + +* This will become a table of contents (this text will be scraped). +{:toc} + +Since Spark 2.4 release, [Spark SQL](https://spark.apache.org/docs/latest/sql-programming-guide.html) provides built-in support for reading and writing Apache Avro data. + +## Deploying +The `spark-avro` module is external and not included in `spark-submit` or `spark-shell` by default. + +As with any Spark applications, `spark-submit` is used to launch your application. `spark-avro_{{site.SCALA_BINARY_VERSION}}` +and its dependencies can be directly added to `spark-submit` using `--packages`, such as, + + ./bin/spark-submit --packages org.apache.spark:spark-avro_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ... + +For experimenting on `spark-shell`, you can also use `--packages` to add `org.apache.spark:spark-avro_{{site.SCALA_BINARY_VERSION}}` and its dependencies directly, + + ./bin/spark-shell --packages org.apache.spark:spark-avro_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ... + +See [Application Submission Guide](submitting-applications.html) for more details about submitting applications with external dependencies. + +## Load and Save Functions + +Since `spark-avro` module is external, there is no `.avro` API in +`DataFrameReader` or `DataFrameWriter`. + +To load/save data in Avro format, you need to specify the data source option `format` as `avro`(or `org.apache.spark.sql.avro`). +
      +
      +{% highlight scala %} + +val usersDF = spark.read.format("avro").load("examples/src/main/resources/users.avro") +usersDF.select("name", "favorite_color").write.format("avro").save("namesAndFavColors.avro") + +{% endhighlight %} +
      +
      +{% highlight java %} + +Dataset usersDF = spark.read().format("avro").load("examples/src/main/resources/users.avro"); +usersDF.select("name", "favorite_color").write().format("avro").save("namesAndFavColors.avro"); + +{% endhighlight %} +
      +
      +{% highlight python %} + +df = spark.read.format("avro").load("examples/src/main/resources/users.avro") +df.select("name", "favorite_color").write.format("avro").save("namesAndFavColors.avro") + +{% endhighlight %} +
      +
      +{% highlight r %} + +df <- read.df("examples/src/main/resources/users.avro", "avro") +write.df(select(df, "name", "favorite_color"), "namesAndFavColors.avro", "avro") + +{% endhighlight %} +
      +
      + +## to_avro() and from_avro() +The Avro package provides function `to_avro` to encode a column as binary in Avro +format, and `from_avro()` to decode Avro binary data into a column. Both functions transform one column to +another column, and the input/output SQL data type can be complex type or primitive type. + +Using Avro record as columns are useful when reading from or writing to a streaming source like Kafka. Each +Kafka key-value record will be augmented with some metadata, such as the ingestion timestamp into Kafka, the offset in Kafka, etc. +* If the "value" field that contains your data is in Avro, you could use `from_avro()` to extract your data, enrich it, clean it, and then push it downstream to Kafka again or write it out to a file. +* `to_avro()` can be used to turn structs into Avro records. This method is particularly useful when you would like to re-encode multiple columns into a single one when writing data out to Kafka. + +Both functions are currently only available in Scala and Java. + +
      +
      +{% highlight scala %} +import org.apache.spark.sql.avro._ + +// `from_avro` requires Avro schema in JSON string format. +val jsonFormatSchema = new String(Files.readAllBytes(Paths.get("./examples/src/main/resources/user.avsc"))) + +val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() + +// 1. Decode the Avro data into a struct; +// 2. Filter by column `favorite_color`; +// 3. Encode the column `name` in Avro format. +val output = df + .select(from_avro('value, jsonFormatSchema) as 'user) + .where("user.favorite_color == \"red\"") + .select(to_avro($"user.name") as 'value) + +val query = output + .writeStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic2") + .start() + +{% endhighlight %} +
      +
      +{% highlight java %} +import org.apache.spark.sql.avro.*; + +// `from_avro` requires Avro schema in JSON string format. +String jsonFormatSchema = new String(Files.readAllBytes(Paths.get("./examples/src/main/resources/user.avsc"))); + +Dataset df = spark + .readStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load(); + +// 1. Decode the Avro data into a struct; +// 2. Filter by column `favorite_color`; +// 3. Encode the column `name` in Avro format. +Dataset output = df + .select(from_avro(col("value"), jsonFormatSchema).as("user")) + .where("user.favorite_color == \"red\"") + .select(to_avro(col("user.name")).as("value")); + +StreamingQuery query = output + .writeStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic2") + .start(); + +{% endhighlight %} +
      +
      + +## Data Source Option + +Data source options of Avro can be set using the `.option` method on `DataFrameReader` or `DataFrameWriter`. +
      Property NameDefaultMeaning
      spark.kubernetes.pyspark.pythonversionspark.kubernetes.pyspark.pythonVersion "2" This sets the major Python version of the docker image used to run the driver and executor containers. Can either be 2 or 3. diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 968679df60367..4442333c573cc 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -205,7 +205,7 @@ private[spark] object Config extends Logging { .createWithDefault(0.1) val PYSPARK_MAJOR_PYTHON_VERSION = - ConfigBuilder("spark.kubernetes.pyspark.pythonversion") + ConfigBuilder("spark.kubernetes.pyspark.pythonVersion") .doc("This sets the major Python version. Either 2 or 3. (Python2 or Python3)") .stringConf .checkValue(pv => List("2", "3").contains(pv), diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala index 0254cc99de268..1ebb30094dcde 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala @@ -40,7 +40,7 @@ private[spark] trait PythonTestsSuite { k8sSuite: KubernetesSuite => test("Run PySpark with Python2 to test a pyfiles example", k8sTestTag) { sparkAppConf .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") - .set("spark.kubernetes.pyspark.pythonversion", "2") + .set("spark.kubernetes.pyspark.pythonVersion", "2") runSparkApplicationAndVerifyCompletion( appResource = PYSPARK_FILES, mainClass = "", @@ -58,7 +58,7 @@ private[spark] trait PythonTestsSuite { k8sSuite: KubernetesSuite => test("Run PySpark with Python3 to test a pyfiles example", k8sTestTag) { sparkAppConf .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") - .set("spark.kubernetes.pyspark.pythonversion", "3") + .set("spark.kubernetes.pyspark.pythonVersion", "3") runSparkApplicationAndVerifyCompletion( appResource = PYSPARK_FILES, mainClass = "", From 7822c3f8d1da9ecf9fe2fd422c4ee3f769976d43 Mon Sep 17 00:00:00 2001 From: Bo Meng Date: Thu, 16 Aug 2018 14:14:42 +0800 Subject: [PATCH 1406/2461] [SPARK-25082][SQL] improve the javadoc for expm1() ## What changes were proposed in this pull request? Correct the javadoc for expm1() function. ## How was this patch tested? None. It is a minor issue. Closes #22115 from bomeng/25082. Authored-by: Bo Meng Signed-off-by: hyukjinkwon --- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5a6ed5964a750..c9331883c4799 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1660,7 +1660,7 @@ object functions { def expm1(e: Column): Column = withExpr { Expm1(e.expr) } /** - * Computes the exponential of the given column. + * Computes the exponential of the given column minus one. * * @group math_funcs * @since 1.4.0 From 5b4a38d826807ea6733e4382c8f9b82a355a6eb4 Mon Sep 17 00:00:00 2001 From: codeatri Date: Thu, 16 Aug 2018 17:07:33 +0900 Subject: [PATCH 1407/2461] [SPARK-23939][SQL] Add transform_keys function ## What changes were proposed in this pull request? This pr adds transform_keys function which applies the function to each entry of the map and transforms the keys. ```javascript > SELECT transform_keys(map(array(1, 2, 3), array(1, 2, 3)), (k,v) -> k + 1); map(2->1, 3->2, 4->3) > SELECT transform_keys(map(array(1, 2, 3), array(1, 2, 3)), (k,v) -> k + v); map(2->1, 4->2, 6->3) ``` ## How was this patch tested? Added tests. Closes #22013 from codeatri/SPARK-23939. Authored-by: codeatri Signed-off-by: Takuya UESHIN --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/higherOrderFunctions.scala | 53 +++++++++++ .../HigherOrderFunctionsSuite.scala | 75 ++++++++++++++++ .../inputs/higher-order-functions.sql | 14 +++ .../results/higher-order-functions.sql.out | 39 ++++++++- .../spark/sql/DataFrameFunctionsSuite.scala | 87 +++++++++++++++++++ 6 files changed, 268 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index cc2b758faa43b..b993e1a9bad63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -446,6 +446,7 @@ object FunctionRegistry { expression[ArrayFilter]("filter"), expression[ArrayExists]("exists"), expression[ArrayAggregate]("aggregate"), + expression[TransformKeys]("transform_keys"), expression[MapZipWith]("map_zip_with"), CreateStruct.registryEntry, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 22210f692e755..a305a05add7a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -497,6 +497,59 @@ case class ArrayAggregate( override def prettyName: String = "aggregate" } +/** + * Transform Keys for every entry of the map by applying the transform_keys function. + * Returns map with transformed key entries + */ +@ExpressionDescription( + usage = "_FUNC_(expr, func) - Transforms elements in a map using the function.", + examples = """ + Examples: + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + 1); + map(array(2, 3, 4), array(1, 2, 3)) + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); + map(array(2, 4, 6), array(1, 2, 3)) + """, + since = "2.4.0") +case class TransformKeys( + argument: Expression, + function: Expression) + extends MapBasedSimpleHigherOrderFunction with CodegenFallback { + + override def nullable: Boolean = argument.nullable + + @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType + + override def dataType: DataType = MapType(function.dataType, valueType, valueContainsNull) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): TransformKeys = { + copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) + } + + @transient lazy val LambdaFunction( + _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function + + + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val map = argumentValue.asInstanceOf[MapData] + val resultKeys = new GenericArrayData(new Array[Any](map.numElements)) + var i = 0 + while (i < map.numElements) { + keyVar.value.set(map.keyArray().get(i, keyVar.dataType)) + valueVar.value.set(map.valueArray().get(i, valueVar.dataType)) + val result = functionForEval.eval(inputRow) + if (result == null) { + throw new RuntimeException("Cannot use null as map key!") + } + resultKeys.update(i, result) + i += 1 + } + new ArrayBasedMapData(resultKeys, map.valueArray()) + } + + override def prettyName: String = "transform_keys" +} + /** * Merges two given maps into a single map by applying function to the pair of values with * the same key. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 3137dc9bec49a..12ef01816835a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData import org.apache.spark.sql.types._ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -74,6 +75,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayFilter(expr, createLambda(at.elementType, at.containsNull, f)) } + def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val map = expr.dataType.asInstanceOf[MapType] + TransformKeys(expr, createLambda(map.keyType, false, map.valueType, map.valueContainsNull, f)) + } + def aggregate( expr: Expression, zero: Expression, @@ -283,6 +289,75 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper 15) } + test("TransformKeys") { + val ai0 = Literal.create( + Map(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val ai1 = Literal.create( + Map.empty[Int, Int], + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val ai2 = Literal.create( + Map(1 -> 1, 2 -> null, 3 -> 3), + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) + + val plusOne: (Expression, Expression) => Expression = (k, v) => k + 1 + val plusValue: (Expression, Expression) => Expression = (k, v) => k + v + val modKey: (Expression, Expression) => Expression = (k, v) => k % 3 + + checkEvaluation(transformKeys(ai0, plusOne), Map(2 -> 1, 3 -> 2, 4 -> 3, 5 -> 4)) + checkEvaluation(transformKeys(ai0, plusValue), Map(2 -> 1, 4 -> 2, 6 -> 3, 8 -> 4)) + checkEvaluation( + transformKeys(transformKeys(ai0, plusOne), plusValue), Map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4)) + checkEvaluation(transformKeys(ai0, modKey), + ArrayBasedMapData(Array(1, 2, 0, 1), Array(1, 2, 3, 4))) + checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int]) + checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int]) + checkEvaluation( + transformKeys(transformKeys(ai1, plusOne), plusValue), Map.empty[Int, Int]) + checkEvaluation(transformKeys(ai2, plusOne), Map(2 -> 1, 3 -> null, 4 -> 3)) + checkEvaluation( + transformKeys(transformKeys(ai2, plusOne), plusOne), Map(3 -> 1, 4 -> null, 5 -> 3)) + checkEvaluation(transformKeys(ai3, plusOne), null) + + val as0 = Literal.create( + Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), + MapType(StringType, StringType, valueContainsNull = false)) + val as1 = Literal.create( + Map("a" -> "xy", "bb" -> "yz", "ccc" -> null), + MapType(StringType, StringType, valueContainsNull = true)) + val as2 = Literal.create(null, + MapType(StringType, StringType, valueContainsNull = false)) + val as3 = Literal.create(Map.empty[StringType, StringType], + MapType(StringType, StringType, valueContainsNull = true)) + + val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v)) + val convertKeyToKeyLength: (Expression, Expression) => Expression = + (k, v) => Length(k) + 1 + + checkEvaluation( + transformKeys(as0, concatValue), Map("axy" -> "xy", "bbyz" -> "yz", "ccczx" -> "zx")) + checkEvaluation( + transformKeys(transformKeys(as0, concatValue), concatValue), + Map("axyxy" -> "xy", "bbyzyz" -> "yz", "ccczxzx" -> "zx")) + checkEvaluation(transformKeys(as3, concatValue), Map.empty[String, String]) + checkEvaluation( + transformKeys(transformKeys(as3, concatValue), convertKeyToKeyLength), + Map.empty[Int, String]) + checkEvaluation(transformKeys(as0, convertKeyToKeyLength), + Map(2 -> "xy", 3 -> "yz", 4 -> "zx")) + checkEvaluation(transformKeys(as1, convertKeyToKeyLength), + Map(2 -> "xy", 3 -> "yz", 4 -> null)) + checkEvaluation(transformKeys(as2, convertKeyToKeyLength), null) + checkEvaluation(transformKeys(as3, convertKeyToKeyLength), Map.empty[Int, String]) + + val ax0 = Literal.create( + Map(1 -> "x", 2 -> "y", 3 -> "z"), + MapType(IntegerType, StringType, valueContainsNull = false)) + + checkEvaluation(transformKeys(ax0, plusOne), Map(2 -> "x", 3 -> "y", 4 -> "z")) + } + test("MapZipWith") { def map_zip_with( left: Expression, diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql index ce1d0daa4d397..9a8454455ae74 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -51,3 +51,17 @@ select exists(ys, y -> y > 30) as v from nested; -- Check for element existence in a null array select exists(cast(null as array), y -> y > 30) as v; + +create or replace temporary view nested as values + (1, map(1, 1, 2, 2, 3, 3)), + (2, map(4, 4, 5, 5, 6, 6)) + as t(x, ys); + +-- Identity Transform Keys in a map +select transform_keys(ys, (k, v) -> k) as v from nested; + +-- Transform Keys in a map by adding constant +select transform_keys(ys, (k, v) -> k + 1) as v from nested; + +-- Transform Keys in a map using values +select transform_keys(ys, (k, v) -> k + v) as v from nested; diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out index e18abce3b617b..b77bda7bb2675 100644 --- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 15 +-- Number of queries: 20 -- !query 0 @@ -163,3 +163,40 @@ select exists(cast(null as array), y -> y > 30) as v struct -- !query 16 output NULL + + +-- !query 17 +create or replace temporary view nested as values + (1, map(1, 1, 2, 2, 3, 3)), + (2, map(4, 4, 5, 5, 6, 6)) + as t(x, ys) +-- !query 17 schema +struct<> +-- !query 17 output + + +-- !query 18 +select transform_keys(ys, (k, v) -> k) as v from nested +-- !query 18 schema +struct> +-- !query 18 output +{1:1,2:2,3:3} +{4:4,5:5,6:6} + + +-- !query 19 +select transform_keys(ys, (k, v) -> k + 1) as v from nested +-- !query 19 schema +struct> +-- !query 19 output +{2:1,3:2,4:3} +{5:4,6:5,7:6} + + +-- !query 20 +select transform_keys(ys, (k, v) -> k + v) as v from nested +-- !query 20 schema +struct> +-- !query 20 output +{10:5,12:6,8:4} +{2:1,4:2,6:3} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 8d7695b6ebbcb..22f191209f87b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2302,6 +2302,93 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex5.getMessage.contains("function map_zip_with does not support ordering on type map")) } + test("transform keys function - primitive data types") { + val dfExample1 = Seq( + Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) + ).toDF("i") + + val dfExample2 = Seq( + Map[Int, Double](1 -> 1.0, 2 -> 1.40, 3 -> 1.70) + ).toDF("j") + + val dfExample3 = Seq( + Map[Int, Boolean](25 -> true, 26 -> false) + ).toDF("x") + + val dfExample4 = Seq( + Map[Array[Int], Boolean](Array(1, 2) -> false) + ).toDF("y") + + + def testMapOfPrimitiveTypesCombination(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"), + Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7)))) + + checkAnswer(dfExample2.selectExpr("transform_keys(j, " + + "(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 'three'))[k])"), + Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7)))) + + checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> CAST(v * 2 AS BIGINT) + k)"), + Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7)))) + + checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"), + Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7)))) + + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"), + Seq(Row(Map(true -> true, true -> false)))) + + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"), + Seq(Row(Map(50 -> true, 78 -> false)))) + + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"), + Seq(Row(Map(50 -> true, 78 -> false)))) + + checkAnswer(dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k, 3) AND v)"), + Seq(Row(Map(false -> false)))) + } + // Test with local relation, the Project will be evaluated without codegen + testMapOfPrimitiveTypesCombination() + dfExample1.cache() + dfExample2.cache() + dfExample3.cache() + dfExample4.cache() + // Test with cached relation, the Project will be evaluated with codegen + testMapOfPrimitiveTypesCombination() + } + + test("transform keys function - Invalid lambda functions and exceptions") { + + val dfExample1 = Seq( + Map[String, String]("a" -> null) + ).toDF("i") + + val dfExample2 = Seq( + Seq(1, 2, 3, 4) + ).toDF("j") + + val ex1 = intercept[AnalysisException] { + dfExample1.selectExpr("transform_keys(i, k -> k)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match")) + + val ex2 = intercept[AnalysisException] { + dfExample1.selectExpr("transform_keys(i, (k, v, x) -> k + 1)") + } + assert(ex2.getMessage.contains( + "The number of lambda function arguments '3' does not match")) + + val ex3 = intercept[RuntimeException] { + dfExample1.selectExpr("transform_keys(i, (k, v) -> v)").show() + } + assert(ex3.getMessage.contains("Cannot use null as map key!")) + + val ex4 = intercept[AnalysisException] { + dfExample2.selectExpr("transform_keys(j, (k, v) -> k + 1)") + } + assert(ex4.getMessage.contains( + "data type mismatch: argument 1 requires map type")) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From ea63a7a1681c661844d31da6887e828760427f0c Mon Sep 17 00:00:00 2001 From: Sandeep Singh Date: Thu, 16 Aug 2018 23:02:45 +0900 Subject: [PATCH 1408/2461] [SPARK-23932][SQL] Higher order function zip_with ## What changes were proposed in this pull request? Merges the two given arrays, element-wise, into a single array using function. If one array is shorter, nulls are appended at the end to match the length of the longer array, before applying function: ``` SELECT zip_with(ARRAY[1, 3, 5], ARRAY['a', 'b', 'c'], (x, y) -> (y, x)); -- [ROW('a', 1), ROW('b', 3), ROW('c', 5)] SELECT zip_with(ARRAY[1, 2], ARRAY[3, 4], (x, y) -> x + y); -- [4, 6] SELECT zip_with(ARRAY['a', 'b', 'c'], ARRAY['d', 'e', 'f'], (x, y) -> concat(x, y)); -- ['ad', 'be', 'cf'] SELECT zip_with(ARRAY['a'], ARRAY['d', null, 'f'], (x, y) -> coalesce(x, y)); -- ['a', null, 'f'] ``` ## How was this patch tested? Added tests Closes #22031 from techaddict/SPARK-23932. Authored-by: Sandeep Singh Signed-off-by: Takuya UESHIN --- .../catalyst/analysis/FunctionRegistry.scala | 2 + .../expressions/higherOrderFunctions.scala | 76 +++++++++++++++++++ .../HigherOrderFunctionsSuite.scala | 48 ++++++++++++ .../inputs/higher-order-functions.sql | 11 ++- .../results/higher-order-functions.sql.out | 48 +++++++++--- .../spark/sql/DataFrameFunctionsSuite.scala | 63 +++++++++++++++ 6 files changed, 236 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index b993e1a9bad63..061336455189e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -448,6 +448,8 @@ object FunctionRegistry { expression[ArrayAggregate]("aggregate"), expression[TransformKeys]("transform_keys"), expression[MapZipWith]("map_zip_with"), + expression[ZipWith]("zip_with"), + CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index a305a05add7a8..9d603d79eedcf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -740,3 +740,79 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) override def prettyName: String = "map_zip_with" } + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(left, right, func) - Merges the two given arrays, element-wise, into a single array using function. If one array is shorter, nulls are appended at the end to match the length of the longer array, before applying function.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array('a', 'b', 'c'), (x, y) -> (y, x)); + array(('a', 1), ('b', 3), ('c', 5)) + > SELECT _FUNC_(array(1, 2), array(3, 4), (x, y) -> x + y)); + array(4, 6) + > SELECT _FUNC_(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y)); + array('ad', 'be', 'cf') + """, + since = "2.4.0") +// scalastyle:on line.size.limit +case class ZipWith(left: Expression, right: Expression, function: Expression) + extends HigherOrderFunction with CodegenFallback { + + def functionForEval: Expression = functionsForEval.head + + override def arguments: Seq[Expression] = left :: right :: Nil + + override def argumentTypes: Seq[AbstractDataType] = ArrayType :: ArrayType :: Nil + + override def functions: Seq[Expression] = List(function) + + override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil + + override def nullable: Boolean = left.nullable || right.nullable + + override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ZipWith = { + val ArrayType(leftElementType, _) = left.dataType + val ArrayType(rightElementType, _) = right.dataType + copy(function = f(function, + (leftElementType, true) :: (rightElementType, true) :: Nil)) + } + + @transient lazy val LambdaFunction(_, + Seq(leftElemVar: NamedLambdaVariable, rightElemVar: NamedLambdaVariable), _) = function + + override def eval(input: InternalRow): Any = { + val leftArr = left.eval(input).asInstanceOf[ArrayData] + if (leftArr == null) { + null + } else { + val rightArr = right.eval(input).asInstanceOf[ArrayData] + if (rightArr == null) { + null + } else { + val resultLength = math.max(leftArr.numElements(), rightArr.numElements()) + val f = functionForEval + val result = new GenericArrayData(new Array[Any](resultLength)) + var i = 0 + while (i < resultLength) { + if (i < leftArr.numElements()) { + leftElemVar.value.set(leftArr.get(i, leftElemVar.dataType)) + } else { + leftElemVar.value.set(null) + } + if (i < rightArr.numElements()) { + rightElemVar.value.set(rightArr.get(i, rightElemVar.dataType)) + } else { + rightElemVar.value.set(null) + } + result.update(i, f.eval(input)) + i += 1 + } + result + } + } + } + + override def prettyName: String = "zip_with" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 12ef01816835a..3a78f14c8b2cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -471,4 +471,52 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper map_zip_with(mbb0, mbbn, concat), null) } + + test("ZipWith") { + def zip_with( + left: Expression, + right: Expression, + f: (Expression, Expression) => Expression): Expression = { + val ArrayType(leftT, _) = left.dataType + val ArrayType(rightT, _) = right.dataType + ZipWith(left, right, createLambda(leftT, true, rightT, true, f)) + } + + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType, containsNull = false)) + val ai2 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) + val ai3 = Literal.create(Seq[Integer](1, null), ArrayType(IntegerType, containsNull = true)) + val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) + + val add: (Expression, Expression) => Expression = (x, y) => x + y + val plusOne: Expression => Expression = x => x + 1 + + checkEvaluation(zip_with(ai0, ai1, add), Seq(2, 4, 6, null)) + checkEvaluation(zip_with(ai3, ai2, add), Seq(2, null, null)) + checkEvaluation(zip_with(ai2, ai3, add), Seq(2, null, null)) + checkEvaluation(zip_with(ain, ain, add), null) + checkEvaluation(zip_with(ai1, ain, add), null) + checkEvaluation(zip_with(ain, ai1, add), null) + + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) + val as2 = Literal.create(Seq("a"), ArrayType(StringType, containsNull = true)) + val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) + + val concat: (Expression, Expression) => Expression = (x, y) => Concat(Seq(x, y)) + + checkEvaluation(zip_with(as0, as1, concat), Seq("aa", null, "cc")) + checkEvaluation(zip_with(as0, as2, concat), Seq("aa", null, null)) + + val aai1 = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)), + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + val aai2 = Literal.create(Seq(Seq(1, 2, 3)), + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + checkEvaluation( + zip_with(aai1, aai2, (a1, a2) => + Cast(zip_with(transform(a1, plusOne), transform(a2, plusOne), add), StringType)), + Seq("[4, 6, 8]", null, null)) + checkEvaluation(zip_with(aai1, aai1, (a1, a2) => Cast(transform(a1, plusOne), StringType)), + Seq("[2, 3, 4]", null, "[5, 6]")) + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql index 9a8454455ae74..05ec5effdf146 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -51,7 +51,16 @@ select exists(ys, y -> y > 30) as v from nested; -- Check for element existence in a null array select exists(cast(null as array), y -> y > 30) as v; - + +-- Zip with array +select zip_with(ys, zs, (a, b) -> a + size(b)) as v from nested; + +-- Zip with array with concat +select zip_with(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y)) as v; + +-- Zip with array coalesce +select zip_with(array('a'), array('d', null, 'f'), (x, y) -> coalesce(x, y)) as v; + create or replace temporary view nested as values (1, map(1, 1, 2, 2, 3, 3)), (2, map(4, 4, 5, 5, 6, 6)) diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out index b77bda7bb2675..5a39616191e81 100644 --- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -166,37 +166,63 @@ NULL -- !query 17 +select zip_with(ys, zs, (a, b) -> a + size(b)) as v from nested +-- !query 17 schema +struct> +-- !query 17 output +[13] +[34,99,null] +[80,-74] + + +-- !query 18 +select zip_with(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y)) as v +-- !query 18 schema +struct> +-- !query 18 output +["ad","be","cf"] + + +-- !query 19 +select zip_with(array('a'), array('d', null, 'f'), (x, y) -> coalesce(x, y)) as v +-- !query 19 schema +struct> +-- !query 19 output +["a",null,"f"] + + +-- !query 20 create or replace temporary view nested as values (1, map(1, 1, 2, 2, 3, 3)), (2, map(4, 4, 5, 5, 6, 6)) as t(x, ys) --- !query 17 schema +-- !query 20 schema struct<> --- !query 17 output +-- !query 20 output --- !query 18 +-- !query 21 select transform_keys(ys, (k, v) -> k) as v from nested --- !query 18 schema +-- !query 21 schema struct> --- !query 18 output +-- !query 21 output {1:1,2:2,3:3} {4:4,5:5,6:6} --- !query 19 +-- !query 22 select transform_keys(ys, (k, v) -> k + 1) as v from nested --- !query 19 schema +-- !query 22 schema struct> --- !query 19 output +-- !query 22 output {2:1,3:2,4:3} {5:4,6:5,7:6} --- !query 20 +-- !query 23 select transform_keys(ys, (k, v) -> k + v) as v from nested --- !query 20 schema +-- !query 23 schema struct> --- !query 20 output +-- !query 23 output {10:5,12:6,8:4} {2:1,4:2,6:3} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 22f191209f87b..9e2bfd3b7fba8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2389,6 +2389,69 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { "data type mismatch: argument 1 requires map type")) } + test("arrays zip_with function - for primitive types") { + val df1 = Seq[(Seq[Integer], Seq[Integer])]( + (Seq(9001, 9002, 9003), Seq(4, 5, 6)), + (Seq(1, 2), Seq(3, 4)), + (Seq.empty, Seq.empty), + (null, null) + ).toDF("val1", "val2") + val df2 = Seq[(Seq[Integer], Seq[Long])]( + (Seq(1, null, 3), Seq(1L, 2L)), + (Seq(1, 2, 3), Seq(4L, 11L)) + ).toDF("val1", "val2") + val expectedValue1 = Seq( + Row(Seq(9005, 9007, 9009)), + Row(Seq(4, 6)), + Row(Seq.empty), + Row(null)) + checkAnswer(df1.selectExpr("zip_with(val1, val2, (x, y) -> x + y)"), expectedValue1) + val expectedValue2 = Seq( + Row(Seq(Row(1L, 1), Row(2L, null), Row(null, 3))), + Row(Seq(Row(4L, 1), Row(11L, 2), Row(null, 3)))) + checkAnswer(df2.selectExpr("zip_with(val1, val2, (x, y) -> (y, x))"), expectedValue2) + } + + test("arrays zip_with function - for non-primitive types") { + val df = Seq( + (Seq("a"), Seq("x", "y", "z")), + (Seq("a", null), Seq("x", "y")), + (Seq.empty[String], Seq.empty[String]), + (Seq("a", "b", "c"), null) + ).toDF("val1", "val2") + val expectedValue1 = Seq( + Row(Seq(Row("x", "a"), Row("y", null), Row("z", null))), + Row(Seq(Row("x", "a"), Row("y", null))), + Row(Seq.empty), + Row(null)) + checkAnswer(df.selectExpr("zip_with(val1, val2, (x, y) -> (y, x))"), expectedValue1) + } + + test("arrays zip_with function - invalid") { + val df = Seq( + (Seq("c", "a", "b"), Seq("x", "y", "z"), 1), + (Seq("b", null, "c", null), Seq("x"), 2), + (Seq.empty, Seq("x", "z"), 3), + (null, Seq("x", "z"), 4) + ).toDF("a1", "a2", "i") + val ex1 = intercept[AnalysisException] { + df.selectExpr("zip_with(a1, a2, x -> x)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match")) + val ex2 = intercept[AnalysisException] { + df.selectExpr("zip_with(a1, a2, (acc, x) -> x, (acc, x) -> x)") + } + assert(ex2.getMessage.contains("Invalid number of arguments for function zip_with")) + val ex3 = intercept[AnalysisException] { + df.selectExpr("zip_with(i, a2, (acc, x) -> x)") + } + assert(ex3.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex4 = intercept[AnalysisException] { + df.selectExpr("zip_with(a1, a, (acc, x) -> x)") + } + assert(ex4.getMessage.contains("cannot resolve '`a`'")) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From b3e6fe7c46bad991e850d258887400db5f7d7736 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 16 Aug 2018 12:34:23 -0700 Subject: [PATCH 1409/2461] [SPARK-23654][BUILD] remove jets3t as a dependency of spark ## What changes were proposed in this pull request? Remove jets3t dependency, and bouncy castle which it brings in; update licenses and deps Note this just takes over https://github.com/apache/spark/pull/21146 ## How was this patch tested? Existing tests. Closes #22081 from srowen/SPARK-23654. Authored-by: Sean Owen Signed-off-by: Sean Owen --- LICENSE-binary | 2 -- NOTICE | 2 -- NOTICE-binary | 21 -------------- core/pom.xml | 4 +-- dev/deps/spark-deps-hadoop-2.6 | 4 --- dev/deps/spark-deps-hadoop-2.7 | 4 --- dev/deps/spark-deps-hadoop-3.1 | 4 --- external/kafka-0-10-assembly/pom.xml | 5 ---- external/kafka-0-8-assembly/pom.xml | 5 ---- external/kinesis-asl-assembly/pom.xml | 5 ---- .../LICENSE-bouncycastle-bcprov.txt | 7 ----- pom.xml | 28 ++++++++----------- 12 files changed, 13 insertions(+), 78 deletions(-) delete mode 100644 licenses-binary/LICENSE-bouncycastle-bcprov.txt diff --git a/LICENSE-binary b/LICENSE-binary index c033dd8ad2e6a..b94ea90de08be 100644 --- a/LICENSE-binary +++ b/LICENSE-binary @@ -228,7 +228,6 @@ org.apache.xbean:xbean-asm5-shaded com.squareup.okhttp3:logging-interceptor com.squareup.okhttp3:okhttp com.squareup.okio:okio -net.java.dev.jets3t:jets3t org.apache.spark:spark-catalyst_2.11 org.apache.spark:spark-kvstore_2.11 org.apache.spark:spark-launcher_2.11 @@ -447,7 +446,6 @@ org.slf4j:jul-to-slf4j org.slf4j:slf4j-api org.slf4j:slf4j-log4j12 com.github.scopt:scopt_2.11 -org.bouncycastle:bcprov-jdk15on core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js core/src/main/resources/org/apache/spark/ui/static/*dataTables* diff --git a/NOTICE b/NOTICE index 23cb53fe3f367..fefe08b38afc5 100644 --- a/NOTICE +++ b/NOTICE @@ -26,5 +26,3 @@ The following provides more details on the included cryptographic software: This software uses Apache Commons Crypto (https://commons.apache.org/proper/commons-crypto/) to support authentication, and encryption and decryption of data sent across the network between services. - -This software includes Bouncy Castle (http://bouncycastle.org/) to support the jets3t library. diff --git a/NOTICE-binary b/NOTICE-binary index ad256aaf9f968..b707c436983f7 100644 --- a/NOTICE-binary +++ b/NOTICE-binary @@ -27,8 +27,6 @@ This software uses Apache Commons Crypto (https://commons.apache.org/proper/comm support authentication, and encryption and decryption of data sent across the network between services. -This software includes Bouncy Castle (http://bouncycastle.org/) to support the jets3t library. - // ------------------------------------------------------------------ // NOTICE file corresponding to the section 4d of The Apache License, @@ -1162,25 +1160,6 @@ NonlinearMinimizer class in package breeze.optimize.proximal is distributed with 2015, Debasish Das (Verizon), all rights reserved. - ========================================================================= - == NOTICE file corresponding to section 4(d) of the Apache License, == - == Version 2.0, in this case for the distribution of jets3t. == - ========================================================================= - - This product includes software developed by: - - The Apache Software Foundation (http://www.apache.org/). - - The ExoLab Project (http://www.exolab.org/) - - Sun Microsystems (http://www.sun.com/) - - Codehaus (http://castor.codehaus.org) - - Tatu Saloranta (http://wiki.fasterxml.com/TatuSaloranta) - - - stream-lib Copyright 2016 AddThis diff --git a/core/pom.xml b/core/pom.xml index d0b869e6ef92c..5fa3a86de6b01 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -88,8 +88,8 @@ ${project.version} - net.java.dev.jets3t - jets3t + javax.activation + activation org.apache.curator diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index bdab79c24bbb6..30727e69203f8 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -21,8 +21,6 @@ automaton-1.11-8.jar avro-1.8.2.jar avro-ipc-1.8.2.jar avro-mapred-1.8.2-hadoop2.jar -base64-2.3.8.jar -bcprov-jdk15on-1.58.jar bonecp-0.8.0.RELEASE.jar breeze-macros_2.11-0.13.2.jar breeze_2.11-0.13.2.jar @@ -101,7 +99,6 @@ jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar janino-3.0.8.jar -java-xmlbuilder-1.1.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar javax.inject-1.jar @@ -119,7 +116,6 @@ jersey-container-servlet-core-2.22.2.jar jersey-guava-2.22.2.jar jersey-media-jaxb-2.22.2.jar jersey-server-2.22.2.jar -jets3t-0.9.4.jar jetty-6.1.26.jar jetty-util-6.1.26.jar jline-2.14.3.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index ddaf9bbfc3cd9..e58ae7a0e7668 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -21,8 +21,6 @@ automaton-1.11-8.jar avro-1.8.2.jar avro-ipc-1.8.2.jar avro-mapred-1.8.2-hadoop2.jar -base64-2.3.8.jar -bcprov-jdk15on-1.58.jar bonecp-0.8.0.RELEASE.jar breeze-macros_2.11-0.13.2.jar breeze_2.11-0.13.2.jar @@ -101,7 +99,6 @@ jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar janino-3.0.8.jar -java-xmlbuilder-1.1.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar javax.inject-1.jar @@ -119,7 +116,6 @@ jersey-container-servlet-core-2.22.2.jar jersey-guava-2.22.2.jar jersey-media-jaxb-2.22.2.jar jersey-server-2.22.2.jar -jets3t-0.9.4.jar jetty-6.1.26.jar jetty-sslengine-6.1.26.jar jetty-util-6.1.26.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index d25d7aa862c56..8a65860d3c011 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -19,8 +19,6 @@ automaton-1.11-8.jar avro-1.8.2.jar avro-ipc-1.8.2.jar avro-mapred-1.8.2-hadoop2.jar -base64-2.3.8.jar -bcprov-jdk15on-1.58.jar bonecp-0.8.0.RELEASE.jar breeze-macros_2.11-0.13.2.jar breeze_2.11-0.13.2.jar @@ -100,7 +98,6 @@ jackson-module-jaxb-annotations-2.6.7.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar janino-3.0.8.jar -java-xmlbuilder-1.1.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar javax.inject-1.jar @@ -119,7 +116,6 @@ jersey-container-servlet-core-2.22.2.jar jersey-guava-2.22.2.jar jersey-media-jaxb-2.22.2.jar jersey-server-2.22.2.jar -jets3t-0.9.4.jar jetty-webapp-9.3.24.v20180605.jar jetty-xml-9.3.24.v20180605.jar jline-2.14.3.jar diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index a742b8d6dbddb..f80f8e3a0183d 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -95,11 +95,6 @@ log4j provided - - net.java.dev.jets3t - jets3t - provided - org.scala-lang scala-library diff --git a/external/kafka-0-8-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml index 41bc8b3e3ee1f..6be17a81f3fed 100644 --- a/external/kafka-0-8-assembly/pom.xml +++ b/external/kafka-0-8-assembly/pom.xml @@ -95,11 +95,6 @@ log4j provided - - net.java.dev.jets3t - jets3t - provided - org.scala-lang scala-library diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index 37c7d1e604ec5..68fded515626b 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -89,11 +89,6 @@ log4j provided - - net.java.dev.jets3t - jets3t - provided - org.apache.hadoop hadoop-client diff --git a/licenses-binary/LICENSE-bouncycastle-bcprov.txt b/licenses-binary/LICENSE-bouncycastle-bcprov.txt deleted file mode 100644 index c445a93a06dd4..0000000000000 --- a/licenses-binary/LICENSE-bouncycastle-bcprov.txt +++ /dev/null @@ -1,7 +0,0 @@ -Copyright (c) 2000 - 2018 The Legion of the Bouncy Castle Inc. (https://www.bouncycastle.org) - -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/pom.xml b/pom.xml index 33c15f20ed404..ab2a272fb3ff7 100644 --- a/pom.xml +++ b/pom.xml @@ -142,7 +142,6 @@ 3.1.5 1.8.2 hadoop2 - 0.9.4 1.8.10 1.11.271 @@ -911,6 +910,10 @@ com.sun.jersey.contribs * + + net.java.dev.jets3t + jets3t + @@ -984,24 +987,15 @@ - + - net.java.dev.jets3t - jets3t - ${jets3t.version} + javax.activation + activation + 1.1.1 ${hadoop.deps.scope} - - - commons-logging - commons-logging - - - - - org.bouncycastle - bcprov-jdk15on - - 1.58 org.apache.hadoop From e50192494d1ae1bdaf845ddd388189998c1a2403 Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Thu, 16 Aug 2018 15:23:32 -0700 Subject: [PATCH 1410/2461] [SPARK-24555][ML] logNumExamples in KMeans/BiKM/GMM/AFT/NB ## What changes were proposed in this pull request? logNumExamples in KMeans/BiKM/GMM/AFT/NB ## How was this patch tested? existing tests Closes #21561 from zhengruifeng/alg_logNumExamples. Authored-by: zhengruifeng Signed-off-by: Sean Owen --- .../spark/ml/classification/LinearSVC.scala | 2 +- .../classification/LogisticRegression.scala | 2 +- .../spark/ml/classification/NaiveBayes.scala | 14 ++++++----- .../spark/ml/clustering/BisectingKMeans.scala | 3 ++- .../spark/ml/clustering/GaussianMixture.scala | 5 ++++ .../ml/regression/AFTSurvivalRegression.scala | 1 + .../mllib/clustering/BisectingKMeans.scala | 23 +++++++++++++------ .../spark/mllib/clustering/KMeans.scala | 10 ++++++-- 8 files changed, 42 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 20f9366862bb8..1b5c02fc9a576 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -189,7 +189,7 @@ class LinearSVC @Since("2.2.0") ( (new MultivariateOnlineSummarizer, new MultiClassSummarizer) )(seqOp, combOp, $(aggregationDepth)) } - instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count) + instr.logNumExamples(summarizer.count) instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString) instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 408d92ef180de..6f0804f0c8e4a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -519,7 +519,7 @@ class LogisticRegression @Since("1.2.0") ( (new MultivariateOnlineSummarizer, new MultiClassSummarizer) )(seqOp, combOp, $(aggregationDepth)) } - instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count) + instr.logNumExamples(summarizer.count) instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString) instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index f65d3979791a6..51495c1a74e69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -162,19 +162,21 @@ class NaiveBayes @Since("1.5.0") ( // TODO: similar to reduceByKeyLocally to save one stage. val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2))) - }.aggregateByKey[(Double, DenseVector)]((0.0, Vectors.zeros(numFeatures).toDense))( + }.aggregateByKey[(Double, DenseVector, Long)]((0.0, Vectors.zeros(numFeatures).toDense, 0L))( seqOp = { - case ((weightSum: Double, featureSum: DenseVector), (weight, features)) => + case ((weightSum, featureSum, count), (weight, features)) => requireValues(features) BLAS.axpy(weight, features, featureSum) - (weightSum + weight, featureSum) + (weightSum + weight, featureSum, count + 1) }, combOp = { - case ((weightSum1, featureSum1), (weightSum2, featureSum2)) => + case ((weightSum1, featureSum1, count1), (weightSum2, featureSum2, count2)) => BLAS.axpy(1.0, featureSum2, featureSum1) - (weightSum1 + weightSum2, featureSum1) + (weightSum1 + weightSum2, featureSum1, count1 + count2) }).collect().sortBy(_._1) + val numSamples = aggregated.map(_._2._3).sum + instr.logNumExamples(numSamples) val numLabels = aggregated.length instr.logNumClasses(numLabels) val numDocuments = aggregated.map(_._2._1).sum @@ -186,7 +188,7 @@ class NaiveBayes @Since("1.5.0") ( val lambda = $(smoothing) val piLogDenom = math.log(numDocuments + numLabels * lambda) var i = 0 - aggregated.foreach { case (label, (n, sumTermFreqs)) => + aggregated.foreach { case (label, (n, sumTermFreqs, _)) => labelArray(i) = label piArray(i) = math.log(n + lambda) - piLogDenom val thetaLogDenom = $(modelType) match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 48b8c52dbffd7..8904193cae94c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -273,11 +273,12 @@ class BisectingKMeans @Since("2.0.0") ( .setMinDivisibleClusterSize($(minDivisibleClusterSize)) .setSeed($(seed)) .setDistanceMeasure($(distanceMeasure)) - val parentModel = bkm.run(rdd) + val parentModel = bkm.run(rdd, Some(instr)) val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this)) val summary = new BisectingKMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k), $(maxIter)) instr.logNamedValue("clusterSizes", summary.clusterSizes) + instr.logNumFeatures(model.clusterCenters.head.size) model.setSummary(Some(summary)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 310b03b15822c..88abc1605d69f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -386,6 +386,11 @@ class GaussianMixture @Since("2.0.0") ( bcWeights.destroy(blocking = false) bcGaussians.destroy(blocking = false) + if (iter == 0) { + val numSamples = sums.count + instr.logNumExamples(numSamples) + } + /* Create new distributions based on the partial assignments (often referred to as the "M" step in literature) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 3cd07063ee32d..8d6e36697d2cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -236,6 +236,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S fitIntercept, maxIter, tol, aggregationDepth) instr.logNamedValue("quantileProbabilities.size", $(quantileProbabilities).length) instr.logNumFeatures(numFeatures) + instr.logNumExamples(featuresSummarizer.count) if (!$(fitIntercept) && (0 until numFeatures).exists { i => featuresStd(i) == 0.0 && featuresSummarizer.mean(i) != 0.0 }) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala index 98af487306dcc..80ab8eb9bc8b0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -25,6 +25,7 @@ import scala.collection.mutable import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging +import org.apache.spark.ml.util.Instrumentation import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD @@ -151,13 +152,10 @@ class BisectingKMeans private ( this } - /** - * Runs the bisecting k-means algorithm. - * @param input RDD of vectors - * @return model for the bisecting kmeans - */ - @Since("1.6.0") - def run(input: RDD[Vector]): BisectingKMeansModel = { + + private[spark] def run( + input: RDD[Vector], + instr: Option[Instrumentation]): BisectingKMeansModel = { if (input.getStorageLevel == StorageLevel.NONE) { logWarning(s"The input RDD ${input.id} is not directly cached, which may hurt performance if" + " its parent RDDs are also not cached.") @@ -171,6 +169,7 @@ class BisectingKMeans private ( val vectors = input.zip(norms).map { case (x, norm) => new VectorWithNorm(x, norm) } var assignments = vectors.map(v => (ROOT_INDEX, v)) var activeClusters = summarize(d, assignments, dMeasure) + instr.foreach(_.logNumExamples(activeClusters.values.map(_.size).sum)) val rootSummary = activeClusters(ROOT_INDEX) val n = rootSummary.size logInfo(s"Number of points: $n.") @@ -246,6 +245,16 @@ class BisectingKMeans private ( new BisectingKMeansModel(root, this.distanceMeasure) } + /** + * Runs the bisecting k-means algorithm. + * @param input RDD of vectors + * @return model for the bisecting kmeans + */ + @Since("1.6.0") + def run(input: RDD[Vector]): BisectingKMeansModel = { + run(input, None) + } + /** * Java-friendly version of `run()`. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 55df8a34fbfc7..d967c672c581f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -299,7 +299,7 @@ class KMeans private ( val bcCenters = sc.broadcast(centers) // Find the new centers - val newCenters = data.mapPartitions { points => + val collected = data.mapPartitions { points => val thisCenters = bcCenters.value val dims = thisCenters.head.vector.size @@ -317,7 +317,13 @@ class KMeans private ( }.reduceByKey { case ((sum1, count1), (sum2, count2)) => axpy(1.0, sum2, sum1) (sum1, count1 + count2) - }.collectAsMap().mapValues { case (sum, count) => + }.collectAsMap() + + if (iteration == 0) { + instr.foreach(_.logNumExamples(collected.values.map(_._2).sum)) + } + + val newCenters = collected.mapValues { case (sum, count) => distanceMeasureInstance.centroid(sum, count) } From e59dd8fa0c2ae234866bc86b03dfe8bea8c79d2f Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 16 Aug 2018 15:55:00 -0700 Subject: [PATCH 1411/2461] [SPARK-25092][SQL][FOLLOWUP] Add RewriteCorrelatedScalarSubquery in list of nonExcludableRules ## What changes were proposed in this pull request? Add RewriteCorrelatedScalarSubquery in the list of nonExcludableRules since its used to transform correlated scalar subqueries to joins. ## How was this patch tested? Added test in OptimizerRuleExclusionSuite Author: Dilip Biswal Closes #22108 from dilipbiswal/scalar_exclusion. --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 1 + .../sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala | 2 ++ 2 files changed, 3 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2ff67689c3492..63a62cd0cbfe6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -193,6 +193,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) RewriteIntersectAll.ruleName :: ReplaceDistinctWithAggregate.ruleName :: PullupCorrelatedPredicates.ruleName :: + RewriteCorrelatedScalarSubquery.ruleName :: RewritePredicateSubquery.ruleName :: Nil /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala index eee8dc3b76c34..4fa4a7aadc8f2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala @@ -86,6 +86,8 @@ class OptimizerRuleExclusionSuite extends PlanTest { Seq( ReplaceIntersectWithSemiJoin.ruleName, PullupCorrelatedPredicates.ruleName, + RewriteCorrelatedScalarSubquery.ruleName, + RewritePredicateSubquery.ruleName, RewriteExceptAll.ruleName, RewriteIntersectAll.ruleName)) } From 709f541dd0c41c2ae8c0871b2593be9100bfc4ee Mon Sep 17 00:00:00 2001 From: Joey Krabacher Date: Thu, 16 Aug 2018 16:47:52 -0700 Subject: [PATCH 1412/2461] [DOCS] Update configuration.md changed $SPARK_HOME/conf/spark-default.conf to $SPARK_HOME/conf/spark-defaults.conf no testing necessary as this was a change to documentation. Closes #22116 from KraFusion/patch-1. Authored-by: Joey Krabacher Signed-off-by: Sean Owen --- docs/configuration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/configuration.md b/docs/configuration.md index 9c4742a1c0c85..0270dc2cfaf45 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -2213,7 +2213,7 @@ Spark's classpath for each application. In a Spark cluster running on YARN, thes files are set cluster-wide, and cannot safely be changed by the application. The better choice is to use spark hadoop properties in the form of `spark.hadoop.*`. -They can be considered as same as normal spark properties which can be set in `$SPARK_HOME/conf/spark-default.conf` +They can be considered as same as normal spark properties which can be set in `$SPARK_HOME/conf/spark-defaults.conf` In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For instance, Spark allows you to simply create an empty conf and set spark/spark hadoop properties. From 30be71e91251971ad45c018538395cbebebc0c83 Mon Sep 17 00:00:00 2001 From: Joey Krabacher Date: Thu, 16 Aug 2018 16:48:51 -0700 Subject: [PATCH 1413/2461] [DOCS] Fix cloud-integration.md Typo Corrected typo; changed spark-default.conf to spark-defaults.conf Closes #22125 from KraFusion/patch-2. Authored-by: Joey Krabacher Signed-off-by: Sean Owen --- docs/cloud-integration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md index 18e8fe77bbdbe..36753f6373b55 100644 --- a/docs/cloud-integration.md +++ b/docs/cloud-integration.md @@ -104,7 +104,7 @@ Spark jobs must authenticate with the object stores to access data within them. and `AWS_SESSION_TOKEN` environment variables and sets the associated authentication options for the `s3n` and `s3a` connectors to Amazon S3. 1. In a Hadoop cluster, settings may be set in the `core-site.xml` file. -1. Authentication details may be manually added to the Spark configuration in `spark-default.conf` +1. Authentication details may be manually added to the Spark configuration in `spark-defaults.conf` 1. Alternatively, they can be programmatically set in the `SparkConf` instance used to configure the application's `SparkContext`. From 9251c61bd8003b079eb7cac4cc6408cd266413c7 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Fri, 17 Aug 2018 10:18:08 +0800 Subject: [PATCH 1414/2461] [SPARK-24665][PYSPARK][FOLLOWUP] Use SQLConf in PySpark to manage all sql configs ## What changes were proposed in this pull request? Follow up for SPARK-24665, find some others hard code during code review. ## How was this patch tested? Existing UT. Closes #22122 from xuanyuanking/SPARK-24665-follow. Authored-by: Yuanjian Li Signed-off-by: hyukjinkwon --- python/pyspark/sql/catalog.py | 3 +-- python/pyspark/sql/session.py | 11 ++++------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index b0d8357f4feec..974251f63b37a 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -177,8 +177,7 @@ def createTable(self, tableName, path=None, source=None, schema=None, **options) if path is not None: options["path"] = path if source is None: - source = self._sparkSession.conf.get( - "spark.sql.sources.default", "org.apache.spark.sql.parquet") + source = self._sparkSession._wrapped._conf.defaultDataSourceName() if schema is None: df = self._jcatalog.createTable(tableName, source, options) else: diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index f1ad6b1212ed9..19eea2fd2c775 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -678,9 +678,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr from pyspark.sql.utils import require_minimum_pandas_version require_minimum_pandas_version() - if self.conf.get("spark.sql.execution.pandas.respectSessionTimeZone").lower() \ - == "true": - timezone = self.conf.get("spark.sql.session.timeZone") + if self._wrapped._conf.pandasRespectSessionTimeZone(): + timezone = self._wrapped._conf.sessionLocalTimeZone() else: timezone = None @@ -690,15 +689,13 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr (x.encode('utf-8') if not isinstance(x, str) else x) for x in data.columns] - if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \ - and len(data) > 0: + if self._wrapped._conf.arrowEnabled() and len(data) > 0: try: return self._create_from_pandas_with_arrow(data, schema, timezone) except Exception as e: from pyspark.util import _exception_message - if self.conf.get("spark.sql.execution.arrow.fallback.enabled", "true") \ - .lower() == "true": + if self._wrapped._conf.arrowFallbackEnabled(): msg = ( "createDataFrame attempted Arrow optimization because " "'spark.sql.execution.arrow.enabled' is set to true; however, " From f16140975db758cf17e9687a465a9864bd1e0b50 Mon Sep 17 00:00:00 2001 From: codeatri Date: Fri, 17 Aug 2018 11:50:06 +0900 Subject: [PATCH 1415/2461] [SPARK-23940][SQL] Add transform_values SQL function ## What changes were proposed in this pull request? This pr adds `transform_values` function which applies the function to each entry of the map and transforms the values. ```javascript > SELECT transform_values(map(array(1, 2, 3), array(1, 2, 3)), (k,v) -> v + 1); map(1->2, 2->3, 3->4) > SELECT transform_values(map(array(1, 2, 3), array(1, 2, 3)), (k,v) -> k + v); map(1->2, 2->4, 3->6) ``` ## How was this patch tested? New Tests added to `DataFrameFunctionsSuite` `HigherOrderFunctionsSuite` `SQLQueryTestSuite` Closes #22045 from codeatri/SPARK-23940. Authored-by: codeatri Signed-off-by: Takuya UESHIN --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/higherOrderFunctions.scala | 50 ++++- .../HigherOrderFunctionsSuite.scala | 73 ++++++++ .../inputs/higher-order-functions.sql | 9 + .../results/higher-order-functions.sql.out | 29 ++- .../spark/sql/DataFrameFunctionsSuite.scala | 173 +++++++++++++++++- 6 files changed, 332 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 061336455189e..77860e1584f42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -446,6 +446,7 @@ object FunctionRegistry { expression[ArrayFilter]("filter"), expression[ArrayExists]("exists"), expression[ArrayAggregate]("aggregate"), + expression[TransformValues]("transform_values"), expression[TransformKeys]("transform_keys"), expression[MapZipWith]("map_zip_with"), expression[ZipWith]("zip_with"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 9d603d79eedcf..f667a64f7f8d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -527,7 +527,7 @@ case class TransformKeys( } @transient lazy val LambdaFunction( - _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function + _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { @@ -550,6 +550,54 @@ case class TransformKeys( override def prettyName: String = "transform_keys" } +/** + * Returns a map that applies the function to each value of the map. + */ +@ExpressionDescription( + usage = "_FUNC_(expr, func) - Transforms values in the map using the function.", + examples = """ + Examples: + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> v + 1); + map(array(1, 2, 3), array(2, 3, 4)) + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); + map(array(1, 2, 3), array(2, 4, 6)) + """, + since = "2.4.0") +case class TransformValues( + argument: Expression, + function: Expression) + extends MapBasedSimpleHigherOrderFunction with CodegenFallback { + + override def nullable: Boolean = argument.nullable + + @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType + + override def dataType: DataType = MapType(keyType, function.dataType, function.nullable) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction) + : TransformValues = { + copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) + } + + @transient lazy val LambdaFunction( + _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function + + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val map = argumentValue.asInstanceOf[MapData] + val resultValues = new GenericArrayData(new Array[Any](map.numElements)) + var i = 0 + while (i < map.numElements) { + keyVar.value.set(map.keyArray().get(i, keyVar.dataType)) + valueVar.value.set(map.valueArray().get(i, valueVar.dataType)) + resultValues.update(i, functionForEval.eval(inputRow)) + i += 1 + } + new ArrayBasedMapData(map.keyArray(), resultValues) + } + + override def prettyName: String = "transform_values" +} + /** * Merges two given maps into a single map by applying function to the pair of values with * the same key. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 3a78f14c8b2cb..9d992c52e5357 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -101,6 +101,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper aggregate(expr, zero, merge, identity) } + def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val map = expr.dataType.asInstanceOf[MapType] + TransformValues(expr, createLambda(map.keyType, false, map.valueType, map.valueContainsNull, f)) + } + test("ArrayTransform") { val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) @@ -358,6 +363,74 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(transformKeys(ax0, plusOne), Map(2 -> "x", 3 -> "y", 4 -> "z")) } + test("TransformValues") { + val ai0 = Literal.create( + Map(1 -> 1, 2 -> 2, 3 -> 3), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val ai1 = Literal.create( + Map(1 -> 1, 2 -> null, 3 -> 3), + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val ai2 = Literal.create( + Map.empty[Int, Int], + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) + + val plusOne: (Expression, Expression) => Expression = (k, v) => v + 1 + val valueUpdate: (Expression, Expression) => Expression = (k, v) => k * k + + checkEvaluation(transformValues(ai0, plusOne), Map(1 -> 2, 2 -> 3, 3 -> 4)) + checkEvaluation(transformValues(ai0, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + checkEvaluation( + transformValues(transformValues(ai0, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + checkEvaluation(transformValues(ai1, plusOne), Map(1 -> 2, 2 -> null, 3 -> 4)) + checkEvaluation(transformValues(ai1, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + checkEvaluation( + transformValues(transformValues(ai1, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + checkEvaluation(transformValues(ai2, plusOne), Map.empty[Int, Int]) + checkEvaluation(transformValues(ai3, plusOne), null) + + val as0 = Literal.create( + Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), + MapType(StringType, StringType, valueContainsNull = false)) + val as1 = Literal.create( + Map("a" -> "xy", "bb" -> null, "ccc" -> "zx"), + MapType(StringType, StringType, valueContainsNull = true)) + val as2 = Literal.create(Map.empty[StringType, StringType], + MapType(StringType, StringType, valueContainsNull = true)) + val as3 = Literal.create(null, MapType(StringType, StringType, valueContainsNull = true)) + + val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v)) + val valueTypeUpdate: (Expression, Expression) => Expression = + (k, v) => Length(v) + 1 + + checkEvaluation( + transformValues(as0, concatValue), Map("a" -> "axy", "bb" -> "bbyz", "ccc" -> "ccczx")) + checkEvaluation(transformValues(as0, valueTypeUpdate), + Map("a" -> 3, "bb" -> 3, "ccc" -> 3)) + checkEvaluation( + transformValues(transformValues(as0, concatValue), concatValue), + Map("a" -> "aaxy", "bb" -> "bbbbyz", "ccc" -> "cccccczx")) + checkEvaluation(transformValues(as1, concatValue), + Map("a" -> "axy", "bb" -> null, "ccc" -> "ccczx")) + checkEvaluation(transformValues(as1, valueTypeUpdate), + Map("a" -> 3, "bb" -> null, "ccc" -> 3)) + checkEvaluation( + transformValues(transformValues(as1, concatValue), concatValue), + Map("a" -> "aaxy", "bb" -> null, "ccc" -> "cccccczx")) + checkEvaluation(transformValues(as2, concatValue), Map.empty[String, String]) + checkEvaluation(transformValues(as2, valueTypeUpdate), Map.empty[String, Int]) + checkEvaluation( + transformValues(transformValues(as2, concatValue), valueTypeUpdate), + Map.empty[String, Int]) + checkEvaluation(transformValues(as3, concatValue), null) + + val ax0 = Literal.create( + Map(1 -> "x", 2 -> "y", 3 -> "z"), + MapType(IntegerType, StringType, valueContainsNull = false)) + + checkEvaluation(transformValues(ax0, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + } + test("MapZipWith") { def map_zip_with( left: Expression, diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql index 05ec5effdf146..02ad5e3538689 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -74,3 +74,12 @@ select transform_keys(ys, (k, v) -> k + 1) as v from nested; -- Transform Keys in a map using values select transform_keys(ys, (k, v) -> k + v) as v from nested; + +-- Identity Transform values in a map +select transform_values(ys, (k, v) -> v) as v from nested; + +-- Transform values in a map by adding constant +select transform_values(ys, (k, v) -> v + 1) as v from nested; + +-- Transform values in a map using values +select transform_values(ys, (k, v) -> k + v) as v from nested; diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out index 5a39616191e81..32d20d1b73415 100644 --- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 20 +-- Number of queries: 27 -- !query 0 @@ -226,3 +226,30 @@ struct> -- !query 23 output {10:5,12:6,8:4} {2:1,4:2,6:3} + + +-- !query 24 +select transform_values(ys, (k, v) -> v) as v from nested +-- !query 24 schema +struct> +-- !query 24 output +{1:1,2:2,3:3} +{4:4,5:5,6:6} + + +-- !query 25 +select transform_values(ys, (k, v) -> v + 1) as v from nested +-- !query 25 schema +struct> +-- !query 25 output +{1:2,2:3,3:4} +{4:5,5:6,6:7} + + +-- !query 26 +select transform_values(ys, (k, v) -> k + v) as v from nested +-- !query 26 schema +struct> +-- !query 26 output +{1:2,2:4,3:6} +{4:8,5:10,6:12} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 9e2bfd3b7fba8..156e54300e38b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2346,6 +2346,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k, 3) AND v)"), Seq(Row(Map(false -> false)))) } + // Test with local relation, the Project will be evaluated without codegen testMapOfPrimitiveTypesCombination() dfExample1.cache() @@ -2357,7 +2358,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("transform keys function - Invalid lambda functions and exceptions") { - val dfExample1 = Seq( Map[String, String]("a" -> null) ).toDF("i") @@ -2389,6 +2389,177 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { "data type mismatch: argument 1 requires map type")) } + test("transform values function - test primitive data types") { + val dfExample1 = Seq( + Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) + ).toDF("i") + + val dfExample2 = Seq( + Map[Boolean, String](false -> "abc", true -> "def") + ).toDF("x") + + val dfExample3 = Seq( + Map[String, Int]("a" -> 1, "b" -> 2, "c" -> 3) + ).toDF("y") + + val dfExample4 = Seq( + Map[Int, Double](1 -> 1.0, 2 -> 1.40, 3 -> 1.70) + ).toDF("z") + + val dfExample5 = Seq( + Map[Int, Array[Int]](1 -> Array(1, 2)) + ).toDF("c") + + def testMapOfPrimitiveTypesCombination(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> k + v)"), + Seq(Row(Map(1 -> 2, 9 -> 18, 8 -> 16, 7 -> 14)))) + + checkAnswer(dfExample2.selectExpr( + "transform_values(x, (k, v) -> if(k, v, CAST(k AS String)))"), + Seq(Row(Map(false -> "false", true -> "def")))) + + checkAnswer(dfExample2.selectExpr("transform_values(x, (k, v) -> NOT k AND v = 'abc')"), + Seq(Row(Map(false -> true, true -> false)))) + + checkAnswer(dfExample3.selectExpr("transform_values(y, (k, v) -> v * v)"), + Seq(Row(Map("a" -> 1, "b" -> 4, "c" -> 9)))) + + checkAnswer(dfExample3.selectExpr( + "transform_values(y, (k, v) -> k || ':' || CAST(v as String))"), + Seq(Row(Map("a" -> "a:1", "b" -> "b:2", "c" -> "c:3")))) + + checkAnswer( + dfExample3.selectExpr("transform_values(y, (k, v) -> concat(k, cast(v as String)))"), + Seq(Row(Map("a" -> "a1", "b" -> "b2", "c" -> "c3")))) + + checkAnswer( + dfExample4.selectExpr( + "transform_values(" + + "z,(k, v) -> map_from_arrays(ARRAY(1, 2, 3), " + + "ARRAY('one', 'two', 'three'))[k] || '_' || CAST(v AS String))"), + Seq(Row(Map(1 -> "one_1.0", 2 -> "two_1.4", 3 ->"three_1.7")))) + + checkAnswer( + dfExample4.selectExpr("transform_values(z, (k, v) -> k-v)"), + Seq(Row(Map(1 -> 0.0, 2 -> 0.6000000000000001, 3 -> 1.3)))) + + checkAnswer( + dfExample5.selectExpr("transform_values(c, (k, v) -> k + cardinality(v))"), + Seq(Row(Map(1 -> 3)))) + } + + // Test with local relation, the Project will be evaluated without codegen + testMapOfPrimitiveTypesCombination() + dfExample1.cache() + dfExample2.cache() + dfExample3.cache() + dfExample4.cache() + dfExample5.cache() + // Test with cached relation, the Project will be evaluated with codegen + testMapOfPrimitiveTypesCombination() + } + + test("transform values function - test empty") { + val dfExample1 = Seq( + Map.empty[Integer, Integer] + ).toDF("i") + + val dfExample2 = Seq( + Map.empty[BigInt, String] + ).toDF("j") + + def testEmpty(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> NULL)"), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> k)"), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> v)"), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> 0)"), + Seq(Row(Map.empty[Integer, Integer]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> 'value')"), + Seq(Row(Map.empty[Integer, String]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> true)"), + Seq(Row(Map.empty[Integer, Boolean]))) + + checkAnswer(dfExample2.selectExpr("transform_values(j, (k, v) -> k + cast(v as BIGINT))"), + Seq(Row(Map.empty[BigInt, BigInt]))) + } + + testEmpty() + dfExample1.cache() + dfExample2.cache() + testEmpty() + } + + test("transform values function - test null values") { + val dfExample1 = Seq( + Map[Int, Integer](1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4) + ).toDF("a") + + val dfExample2 = Seq( + Map[Int, String](1 -> "a", 2 -> "b", 3 -> null) + ).toDF("b") + + def testNullValue(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_values(a, (k, v) -> null)"), + Seq(Row(Map[Int, Integer](1 -> null, 2 -> null, 3 -> null, 4 -> null)))) + + checkAnswer(dfExample2.selectExpr( + "transform_values(b, (k, v) -> IF(v IS NULL, k + 1, k + 2))"), + Seq(Row(Map(1 -> 3, 2 -> 4, 3 -> 4)))) + } + + testNullValue() + dfExample1.cache() + dfExample2.cache() + testNullValue() + } + + test("transform values function - test invalid functions") { + val dfExample1 = Seq( + Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) + ).toDF("i") + + val dfExample2 = Seq( + Map[String, String]("a" -> "b") + ).toDF("j") + + val dfExample3 = Seq( + Seq(1, 2, 3, 4) + ).toDF("x") + + def testInvalidLambdaFunctions(): Unit = { + + val ex1 = intercept[AnalysisException] { + dfExample1.selectExpr("transform_values(i, k -> k)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match")) + + val ex2 = intercept[AnalysisException] { + dfExample2.selectExpr("transform_values(j, (k, v, x) -> k + 1)") + } + assert(ex2.getMessage.contains("The number of lambda function arguments '3' does not match")) + + val ex3 = intercept[AnalysisException] { + dfExample3.selectExpr("transform_values(x, (k, v) -> k + 1)") + } + assert(ex3.getMessage.contains( + "data type mismatch: argument 1 requires map type")) + } + + testInvalidLambdaFunctions() + dfExample1.cache() + dfExample2.cache() + dfExample3.cache() + testInvalidLambdaFunctions() + } + test("arrays zip_with function - for primitive types") { val df1 = Seq[(Seq[Integer], Seq[Integer])]( (Seq(9001, 9002, 9003), Seq(4, 5, 6)), From 8af61fba03e1d32ddee4e83717fc8137682ffae6 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Fri, 17 Aug 2018 11:52:16 +0800 Subject: [PATCH 1416/2461] [SPARK-25122][SQL] Deduplication of supports equals code ## What changes were proposed in this pull request? The method ```*supportEquals``` determining whether elements of a data type could be used as items in a hash set or as keys in a hash map is duplicated across multiple collection and higher-order functions. This PR suggests to deduplicate the method. ## How was this patch tested? Run tests in: - DataFrameFunctionsSuite - CollectionExpressionsSuite - HigherOrderExpressionsSuite Closes #22110 from mn-mikke/SPARK-25122. Authored-by: Marek Novotny Signed-off-by: Wenchen Fan --- .../expressions/collectionOperations.scala | 38 ++++++------------- .../expressions/higherOrderFunctions.scala | 8 +--- .../spark/sql/catalyst/util/TypeUtils.scala | 13 ++++++- 3 files changed, 25 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5e3449d5631b5..cf9796ef1948f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1505,13 +1505,7 @@ case class ArraysOverlap(left: Expression, right: Expression) @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(elementType) - @transient private lazy val elementTypeSupportEquals = elementType match { - case BinaryType => false - case _: AtomicType => true - case _ => false - } - - @transient private lazy val doEvaluation = if (elementTypeSupportEquals) { + @transient private lazy val doEvaluation = if (TypeUtils.typeWithProperEquals(elementType)) { fastEval _ } else { bruteForceEval _ @@ -1593,7 +1587,7 @@ case class ArraysOverlap(left: Expression, right: Expression) nullSafeCodeGen(ctx, ev, (a1, a2) => { val smaller = ctx.freshName("smallerArray") val bigger = ctx.freshName("biggerArray") - val comparisonCode = if (elementTypeSupportEquals) { + val comparisonCode = if (TypeUtils.typeWithProperEquals(elementType)) { fastCodegen(ctx, ev, smaller, bigger) } else { bruteForceCodegen(ctx, ev, smaller, bigger) @@ -3404,12 +3398,6 @@ case class ArrayDistinct(child: Expression) } } - @transient private lazy val elementTypeSupportEquals = elementType match { - case BinaryType => false - case _: AtomicType => true - case _ => false - } - @transient protected lazy val canUseSpecializedHashSet = elementType match { case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true case _ => false @@ -3434,9 +3422,13 @@ case class ArrayDistinct(child: Expression) override def nullSafeEval(array: Any): Any = { val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) - if (elementTypeSupportEquals) { - new GenericArrayData(data.distinct.asInstanceOf[Array[Any]]) - } else { + doEvaluation(data) + } + + @transient private lazy val doEvaluation = if (TypeUtils.typeWithProperEquals(elementType)) { + (data: Array[AnyRef]) => new GenericArrayData(data.distinct.asInstanceOf[Array[Any]]) + } else { + (data: Array[AnyRef]) => { var foundNullElement = false var pos = 0 for (i <- 0 until data.length) { @@ -3576,12 +3568,6 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { @transient protected lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(elementType) - @transient protected lazy val elementTypeSupportEquals = elementType match { - case BinaryType => false - case _: AtomicType => true - case _ => false - } - @transient protected lazy val canUseSpecializedHashSet = elementType match { case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true case _ => false @@ -3679,7 +3665,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike with ComplexTypeMergingExpression { @transient lazy val evalUnion: (ArrayData, ArrayData) => ArrayData = { - if (elementTypeSupportEquals) { + if (TypeUtils.typeWithProperEquals(elementType)) { (array1, array2) => val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] val hs = new OpenHashSet[Any] @@ -3896,7 +3882,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL } @transient lazy val evalIntersect: (ArrayData, ArrayData) => ArrayData = { - if (elementTypeSupportEquals) { + if (TypeUtils.typeWithProperEquals(elementType)) { (array1, array2) => if (array1.numElements() != 0 && array2.numElements() != 0) { val hs = new OpenHashSet[Any] @@ -4136,7 +4122,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike } @transient lazy val evalExcept: (ArrayData, ArrayData) => ArrayData = { - if (elementTypeSupportEquals) { + if (TypeUtils.typeWithProperEquals(elementType)) { (array1, array2) => val hs = new OpenHashSet[Any] var notFoundNullElement = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index f667a64f7f8d2..3e0621d3e20eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -683,12 +683,6 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) value2Var: NamedLambdaVariable), _) = function - private def keyTypeSupportsEquals = keyType match { - case BinaryType => false - case _: AtomicType => true - case _ => false - } - /** * The function accepts two key arrays and returns a collection of keys with indexes * to value arrays. Indexes are represented as an array of two items. This is a small @@ -696,7 +690,7 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) */ @transient private lazy val getKeysWithValueIndexes: (ArrayData, ArrayData) => mutable.Iterable[(Any, Array[Option[Int]])] = { - if (keyTypeSupportsEquals) { + if (TypeUtils.typeWithProperEquals(keyType)) { getKeysWithIndexesFast } else { getKeysWithIndexesBruteForce diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 5214cdce861d4..76218b459ef0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.RowOrdering import org.apache.spark.sql.types._ /** - * Helper functions to check for valid data types. + * Functions to help with checking for valid data types and value comparison of various types. */ object TypeUtils { def checkForNumericExpr(dt: DataType, caller: String): TypeCheckResult = { @@ -73,4 +73,15 @@ object TypeUtils { } x.length - y.length } + + /** + * Returns true if the equals method of the elements of the data type is implemented properly. + * This also means that they can be safely used in collections relying on the equals method, + * as sets or maps. + */ + def typeWithProperEquals(dataType: DataType): Boolean = dataType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } } From c1ffb3c10aa362662e872b8ce11fc8674d31f5f6 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 17 Aug 2018 14:13:37 +0900 Subject: [PATCH 1417/2461] [SPARK-23938][SQL][FOLLOW-UP][TEST] Nullabilities of value arguments should be true. ## What changes were proposed in this pull request? This is a follow-up pr of #22017 which added `map_zip_with` function. In the test, when creating a lambda function, we use the `valueContainsNull` values for the nullabilities of the value arguments, but we should've used `true` as the same as `bind` method because the values might be `null` if the keys don't match. ## How was this patch tested? Added small tests and existing tests. Closes #22126 from ueshin/issues/SPARK-23938/fix_tests. Authored-by: Takuya UESHIN Signed-off-by: Takuya UESHIN --- .../expressions/HigherOrderFunctionsSuite.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 9d992c52e5357..ea85c21f727c1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -436,9 +436,9 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper left: Expression, right: Expression, f: (Expression, Expression, Expression) => Expression): Expression = { - val MapType(kt, vt1, vcn1) = left.dataType.asInstanceOf[MapType] - val MapType(_, vt2, vcn2) = right.dataType.asInstanceOf[MapType] - MapZipWith(left, right, createLambda(kt, false, vt1, vcn1, vt2, vcn2, f)) + val MapType(kt, vt1, _) = left.dataType + val MapType(_, vt2, _) = right.dataType + MapZipWith(left, right, createLambda(kt, false, vt1, true, vt2, true, f)) } val mii0 = Literal.create(Map(1 -> 10, 2 -> 20, 3 -> 30), @@ -475,6 +475,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation( map_zip_with(mii0, miin, multiplyKeyWithValues), null) + assert(map_zip_with(mii0, mii1, multiplyKeyWithValues).dataType === + MapType(IntegerType, IntegerType, valueContainsNull = true)) val mss0 = Literal.create(Map("a" -> "x", "b" -> "y", "d" -> "z"), MapType(StringType, StringType, valueContainsNull = false)) @@ -510,6 +512,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation( map_zip_with(mss0, mssn, concat), null) + assert(map_zip_with(mss0, mss1, concat).dataType === + MapType(StringType, StringType, valueContainsNull = true)) def b(data: Byte*): Array[Byte] = Array[Byte](data: _*) From 162326c0ee8419083ebd1669796abd234773e9b6 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 17 Aug 2018 00:04:04 -0700 Subject: [PATCH 1418/2461] [SPARK-25117][R] Add EXEPT ALL and INTERSECT ALL support in R ## What changes were proposed in this pull request? [SPARK-21274](https://issues.apache.org/jira/browse/SPARK-21274) added support for EXCEPT ALL and INTERSECT ALL. This PR adds the support in R. ## How was this patch tested? Added test in test_sparkSQL.R Author: Dilip Biswal Closes #22107 from dilipbiswal/SPARK-25117. --- R/pkg/NAMESPACE | 2 + R/pkg/R/DataFrame.R | 59 ++++++++++++++++++++++++++- R/pkg/R/generics.R | 6 +++ R/pkg/tests/fulltests/test_sparkSQL.R | 19 +++++++++ 4 files changed, 85 insertions(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index adfd3871f3426..0fd08482c4413 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -117,6 +117,7 @@ exportMethods("arrange", "dropna", "dtypes", "except", + "exceptAll", "explain", "fillna", "filter", @@ -131,6 +132,7 @@ exportMethods("arrange", "hint", "insertInto", "intersect", + "intersectAll", "isLocal", "isStreaming", "join", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 471ada15d655e..4f2d4c7c002d4 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2848,6 +2848,35 @@ setMethod("intersect", dataFrame(intersected) }) +#' intersectAll +#' +#' Return a new SparkDataFrame containing rows in both this SparkDataFrame +#' and another SparkDataFrame while preserving the duplicates. +#' This is equivalent to \code{INTERSECT ALL} in SQL. Also as standard in +#' SQL, this function resolves columns by position (not by name). +#' +#' @param x a SparkDataFrame. +#' @param y a SparkDataFrame. +#' @return A SparkDataFrame containing the result of the intersect all operation. +#' @family SparkDataFrame functions +#' @aliases intersectAll,SparkDataFrame,SparkDataFrame-method +#' @rdname intersectAll +#' @name intersectAll +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df1 <- read.json(path) +#' df2 <- read.json(path2) +#' intersectAllDF <- intersectAll(df1, df2) +#' } +#' @note intersectAll since 2.4.0 +setMethod("intersectAll", + signature(x = "SparkDataFrame", y = "SparkDataFrame"), + function(x, y) { + intersected <- callJMethod(x@sdf, "intersectAll", y@sdf) + dataFrame(intersected) + }) + #' except #' #' Return a new SparkDataFrame containing rows in this SparkDataFrame @@ -2867,7 +2896,6 @@ setMethod("intersect", #' df2 <- read.json(path2) #' exceptDF <- except(df, df2) #' } -#' @rdname except #' @note except since 1.4.0 setMethod("except", signature(x = "SparkDataFrame", y = "SparkDataFrame"), @@ -2876,6 +2904,35 @@ setMethod("except", dataFrame(excepted) }) +#' exceptAll +#' +#' Return a new SparkDataFrame containing rows in this SparkDataFrame +#' but not in another SparkDataFrame while preserving the duplicates. +#' This is equivalent to \code{EXCEPT ALL} in SQL. Also as standard in +#' SQL, this function resolves columns by position (not by name). +#' +#' @param x a SparkDataFrame. +#' @param y a SparkDataFrame. +#' @return A SparkDataFrame containing the result of the except all operation. +#' @family SparkDataFrame functions +#' @aliases exceptAll,SparkDataFrame,SparkDataFrame-method +#' @rdname exceptAll +#' @name exceptAll +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df1 <- read.json(path) +#' df2 <- read.json(path2) +#' exceptAllDF <- exceptAll(df1, df2) +#' } +#' @note exceptAll since 2.4.0 +setMethod("exceptAll", + signature(x = "SparkDataFrame", y = "SparkDataFrame"), + function(x, y) { + excepted <- callJMethod(x@sdf, "exceptAll", y@sdf) + dataFrame(excepted) + }) + #' Save the contents of SparkDataFrame to a data source. #' #' The data source is specified by the \code{source} and a set of options (...). diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 4a7210bf1b902..f6f1849787a23 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -471,6 +471,9 @@ setGeneric("explain", function(x, ...) { standardGeneric("explain") }) #' @rdname except setGeneric("except", function(x, y) { standardGeneric("except") }) +#' @rdname exceptAll +setGeneric("exceptAll", function(x, y) { standardGeneric("exceptAll") }) + #' @rdname nafunctions setGeneric("fillna", function(x, value, cols = NULL) { standardGeneric("fillna") }) @@ -495,6 +498,9 @@ setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertIn #' @rdname intersect setGeneric("intersect", function(x, y) { standardGeneric("intersect") }) +#' @rdname intersectAll +setGeneric("intersectAll", function(x, y) { standardGeneric("intersectAll") }) + #' @rdname isLocal setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index adcbbff823a2d..bff6e3512ee2f 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -2482,6 +2482,25 @@ test_that("union(), unionByName(), rbind(), except(), and intersect() on a DataF unlink(jsonPath2) }) +test_that("intersectAll() and exceptAll()", { + df1 <- createDataFrame(list(list("a", 1), list("a", 1), list("a", 1), + list("a", 1), list("b", 3), list("c", 4)), + schema = c("a", "b")) + df2 <- createDataFrame(list(list("a", 1), list("a", 1), list("b", 3)), schema = c("a", "b")) + intersectAllExpected <- data.frame("a" = c("a", "a", "b"), "b" = c(1, 1, 3), + stringsAsFactors = FALSE) + exceptAllExpected <- data.frame("a" = c("a", "a", "c"), "b" = c(1, 1, 4), + stringsAsFactors = FALSE) + intersectAllDf <- arrange(intersectAll(df1, df2), df1$a) + expect_is(intersectAllDf, "SparkDataFrame") + exceptAllDf <- arrange(exceptAll(df1, df2), df1$a) + expect_is(exceptAllDf, "SparkDataFrame") + intersectAllActual <- collect(intersectAllDf) + expect_identical(intersectAllActual, intersectAllExpected) + exceptAllActual <- collect(exceptAllDf) + expect_identical(exceptAllActual, exceptAllExpected) +}) + test_that("withColumn() and withColumnRenamed()", { df <- read.json(jsonPath) newDF <- withColumn(df, "newAge", df$age + 2) From 8b0e94d89621befe52d2a53a8cf2f58f98887a61 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 17 Aug 2018 18:40:29 +0000 Subject: [PATCH 1419/2461] [SPARK-23042][ML] Use OneHotEncoderModel to encode labels in MultilayerPerceptronClassifier ## What changes were proposed in this pull request? In MultilayerPerceptronClassifier, we use RDD operation to encode labels for now. I think we should use ML's OneHotEncoderEstimator/Model to do the encoding. ## How was this patch tested? Existing tests. Closes #20232 from viirya/SPARK-23042. Authored-by: Liang-Chi Hsieh Signed-off-by: DB Tsai --- .../fulltests/test_mllib_classification.R | 4 +- R/pkg/vignettes/sparkr-vignettes.Rmd | 2 +- docs/sparkr.md | 4 ++ .../MultilayerPerceptronClassifier.scala | 50 ++++++------------- project/MimaExcludes.scala | 5 +- 5 files changed, 26 insertions(+), 39 deletions(-) diff --git a/R/pkg/tests/fulltests/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R index a46c47dccd02e..023686e75d50a 100644 --- a/R/pkg/tests/fulltests/test_mllib_classification.R +++ b/R/pkg/tests/fulltests/test_mllib_classification.R @@ -382,10 +382,10 @@ test_that("spark.mlp", { trainidxs <- base::sample(nrow(data), nrow(data) * 0.7) traindf <- as.DataFrame(data[trainidxs, ]) testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other"))) - model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3)) + model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 2)) predictions <- predict(model, testdf) expect_error(collect(predictions)) - model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3), handleInvalid = "skip") + model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 2), handleInvalid = "skip") predictions <- predict(model, testdf) expect_equal(class(collect(predictions)$clicked[1]), "list") diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 68a18ab57b28d..090363c5f8a3e 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -654,7 +654,7 @@ We use Titanic data set to show how to use `spark.mlp` in classification. t <- as.data.frame(Titanic) training <- createDataFrame(t) # fit a Multilayer Perceptron Classification Model -model <- spark.mlp(training, Survived ~ Age + Sex, blockSize = 128, layers = c(2, 3), solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, initialWeights = c( 0, 0, 0, 5, 5, 5, 9, 9, 9)) +model <- spark.mlp(training, Survived ~ Age + Sex, blockSize = 128, layers = c(2, 2), solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, initialWeights = c( 0, 0, 5, 5, 9, 9)) ``` To avoid lengthy display, we only present partial results of the model summary. You can check the full result from your sparkR shell. diff --git a/docs/sparkr.md b/docs/sparkr.md index 4faad2c4c1824..84e9b4ac6db7f 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -667,3 +667,7 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma ## Upgrading to SparkR 2.3.1 and above - In SparkR 2.3.0 and earlier, the `start` parameter of `substr` method was wrongly subtracted by one and considered as 0-based. This can lead to inconsistent substring results and also does not match with the behaviour with `substr` in R. In version 2.3.1 and later, it has been fixed so the `start` parameter of `substr` method is now 1-base. As an example, `substr(lit('abcdef'), 2, 4))` would result to `abc` in SparkR 2.3.0, and the result would be `bcd` in SparkR 2.3.1. + +## Upgrading to SparkR 2.4.0 + + - Previously, we don't check the validity of the size of the last layer in `spark.mlp`. For example, if the training data only has two labels, a `layers` param like `c(1, 3)` doesn't cause an error previously, now it does. diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 65e3b2d3beb14..4feddce1d9f2d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -23,13 +23,13 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Since import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer} -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.OneHotEncoderModel import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.ml.util.Instrumentation.instrumented -import org.apache.spark.sql.Dataset +import org.apache.spark.sql.{Dataset, Row} /** Params for Multilayer Perceptron. */ private[classification] trait MultilayerPerceptronParams extends ProbabilisticClassifierParams @@ -103,36 +103,6 @@ private[classification] trait MultilayerPerceptronParams extends ProbabilisticCl solver -> LBFGS, stepSize -> 0.03) } -/** Label to vector converter. */ -private object LabelConverter { - // TODO: Use OneHotEncoder instead - /** - * Encodes a label as a vector. - * Returns a vector of given length with zeroes at all positions - * and value 1.0 at the position that corresponds to the label. - * - * @param labeledPoint labeled point - * @param labelCount total number of labels - * @return pair of features and vector encoding of a label - */ - def encodeLabeledPoint(labeledPoint: LabeledPoint, labelCount: Int): (Vector, Vector) = { - val output = Array.fill(labelCount)(0.0) - output(labeledPoint.label.toInt) = 1.0 - (labeledPoint.features, Vectors.dense(output)) - } - - /** - * Converts a vector to a label. - * Returns the position of the maximal element of a vector. - * - * @param output label encoded with a vector - * @return label - */ - def decodeLabel(output: Vector): Double = { - output.argmax.toDouble - } -} - /** * Classifier trainer based on the Multilayer Perceptron. * Each layer has sigmoid activation function, output layer has softmax. @@ -243,8 +213,18 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( instr.logNumClasses(labels) instr.logNumFeatures(myLayers.head) - val lpData = extractLabeledPoints(dataset) - val data = lpData.map(lp => LabelConverter.encodeLabeledPoint(lp, labels)) + // One-hot encoding for labels using OneHotEncoderModel. + // As we already know the length of encoding, we skip fitting and directly create + // the model. + val encodedLabelCol = "_encoded" + $(labelCol) + val encodeModel = new OneHotEncoderModel(uid, Array(labels)) + .setInputCols(Array($(labelCol))) + .setOutputCols(Array(encodedLabelCol)) + .setDropLast(false) + val encodedDataset = encodeModel.transform(dataset) + val data = encodedDataset.select($(featuresCol), encodedLabelCol).rdd.map { + case Row(features: Vector, encodedLabel: Vector) => (features, encodedLabel) + } val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, softmaxOnTop = true) val trainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last) if (isDefined(initialWeights)) { @@ -323,7 +303,7 @@ class MultilayerPerceptronClassificationModel private[ml] ( * This internal method is used to implement `transform()` and output [[predictionCol]]. */ override def predict(features: Vector): Double = { - LabelConverter.decodeLabel(mlpModel.predict(features)) + mlpModel.predict(features).argmax.toDouble } @Since("1.5.0") diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 8f96bb0f33849..080cdd1c3de80 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -97,7 +97,10 @@ object MimaExcludes { ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol") + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), + + // [SPARK-23042] Use OneHotEncoderModel to encode labels in MultilayerPerceptronClassifier + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.classification.LabelConverter") ) // Exclude rules for 2.3.x From da2dc69291cda8c8e7bb6b4a15001f768a97f65e Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 17 Aug 2018 14:21:08 -0700 Subject: [PATCH 1420/2461] [SPARK-25116][TESTS] Fix the Kafka cluster leak and clean up cached producers ## What changes were proposed in this pull request? KafkaContinuousSinkSuite leaks a Kafka cluster because both KafkaSourceTest and KafkaContinuousSinkSuite create a Kafka cluster but `afterAll` only shuts down one cluster. This leaks a Kafka cluster and causes that some Kafka thread crash and kill JVM when SBT is trying to clean up tests. This PR fixes the leak and also adds a shut down hook to detect Kafka cluster leak. In additions, it also fixes `AdminClient` leak and cleans up cached producers (When a record is writtn using a producer, the producer will keep refreshing the topic and I don't find an API to clear it except closing the producer) to eliminate the following annoying logs: ``` 8/13 15:34:42.568 kafka-admin-client-thread | adminclient-4 WARN NetworkClient: [AdminClient clientId=adminclient-4] Connection to node 0 could not be established. Broker may not be available. 18/08/13 15:34:42.570 kafka-admin-client-thread | adminclient-6 WARN NetworkClient: [AdminClient clientId=adminclient-6] Connection to node 0 could not be established. Broker may not be available. 18/08/13 15:34:42.606 kafka-admin-client-thread | adminclient-8 WARN NetworkClient: [AdminClient clientId=adminclient-8] Connection to node -1 could not be established. Broker may not be available. 18/08/13 15:34:42.729 kafka-producer-network-thread | producer-797 WARN NetworkClient: [Producer clientId=producer-797] Connection to node -1 could not be established. Broker may not be available. 18/08/13 15:34:42.906 kafka-producer-network-thread | producer-1598 WARN NetworkClient: [Producer clientId=producer-1598] Connection to node 0 could not be established. Broker may not be available. ``` I also reverted https://github.com/apache/spark/pull/22097/commits/b5eb54244ed573c8046f5abf7bf087f5f08dba58 introduced by #22097 since it doesn't help. ## How was this patch tested? Jenkins Closes #22106 from zsxwing/SPARK-25116. Authored-by: Shixiong Zhu Signed-off-by: Shixiong Zhu --- .../sql/kafka010/CachedKafkaProducer.scala | 8 +- .../sql/kafka010/KafkaContinuousReader.scala | 2 +- .../kafka010/CachedKafkaProducerSuite.scala | 5 +- .../kafka010/KafkaContinuousSinkSuite.scala | 7 +- .../kafka010/KafkaMicroBatchSourceSuite.scala | 2 +- .../sql/kafka010/KafkaRelationSuite.scala | 3 +- .../spark/sql/kafka010/KafkaSinkSuite.scala | 2 +- .../apache/spark/sql/kafka010/KafkaTest.scala | 32 +++++++ .../spark/sql/kafka010/KafkaTestUtils.scala | 91 +++++++++---------- .../streaming/kafka010/KafkaTestUtils.scala | 89 +++++++++--------- 10 files changed, 132 insertions(+), 109 deletions(-) create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTest.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala index 571140b0afbc7..cd680adf44365 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala @@ -33,8 +33,12 @@ private[kafka010] object CachedKafkaProducer extends Logging { private type Producer = KafkaProducer[Array[Byte], Array[Byte]] + private val defaultCacheExpireTimeout = TimeUnit.MINUTES.toMillis(10) + private lazy val cacheExpireTimeout: Long = - SparkEnv.get.conf.getTimeAsMs("spark.kafka.producer.cache.timeout", "10m") + Option(SparkEnv.get).map(_.conf.getTimeAsMs( + "spark.kafka.producer.cache.timeout", + s"${defaultCacheExpireTimeout}ms")).getOrElse(defaultCacheExpireTimeout) private val cacheLoader = new CacheLoader[Seq[(String, Object)], Producer] { override def load(config: Seq[(String, Object)]): Producer = { @@ -102,7 +106,7 @@ private[kafka010] object CachedKafkaProducer extends Logging { } } - private def clear(): Unit = { + private[kafka010] def clear(): Unit = { logInfo("Cleaning up guava cache.") guavaCache.invalidateAll() } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 48b91dfe764e9..be7ce3b3ed757 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -216,7 +216,7 @@ class KafkaContinuousInputPartitionReader( } catch { // We didn't read within the timeout. We're supposed to block indefinitely for new data, so // swallow and ignore this. - case _: TimeoutException => + case _: TimeoutException | _: org.apache.kafka.common.errors.TimeoutException => // This is a failOnDataLoss exception. Retry if nextKafkaOffset is within the data range, // or if it's the endpoint of the data range (i.e. the "true" next offset). diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala index 789bffa9da126..0b3355426df10 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala @@ -26,14 +26,13 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.sql.test.SharedSQLContext -class CachedKafkaProducerSuite extends SharedSQLContext with PrivateMethodTester { +class CachedKafkaProducerSuite extends SharedSQLContext with PrivateMethodTester with KafkaTest { type KP = KafkaProducer[Array[Byte], Array[Byte]] protected override def beforeEach(): Unit = { super.beforeEach() - val clear = PrivateMethod[Unit]('clear) - CachedKafkaProducer.invokePrivate(clear()) + CachedKafkaProducer.clear() } test("Should return the cached instance on calling getOrCreate with same params.") { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala index 0e1492ac27449..3f6fcf6b2e52c 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -40,12 +40,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { override val streamingTimeout = 30.seconds - override def beforeAll(): Unit = { - super.beforeAll() - testUtils = new KafkaTestUtils( - withBrokerProps = Map("auto.create.topics.enable" -> "false")) - testUtils.setup() - } + override val brokerProps = Map("auto.create.topics.enable" -> "false") override def afterAll(): Unit = { if (testUtils != null) { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index aa898686c77ca..172c0ef251cac 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -49,7 +49,7 @@ import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} import org.apache.spark.sql.types.StructType -abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { +abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with KafkaTest { protected var testUtils: KafkaTestUtils = _ diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala index 91893df4ec32f..688e9c40fed22 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala @@ -21,13 +21,12 @@ import java.util.Locale import java.util.concurrent.atomic.AtomicInteger import org.apache.kafka.common.TopicPartition -import org.scalatest.BeforeAndAfter import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class KafkaRelationSuite extends QueryTest with BeforeAndAfter with SharedSQLContext { +class KafkaRelationSuite extends QueryTest with SharedSQLContext with KafkaTest { import testImplicits._ diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 70ffd7dee89d7..a2213e024bd98 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.streaming._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BinaryType, DataType} -class KafkaSinkSuite extends StreamTest with SharedSQLContext { +class KafkaSinkSuite extends StreamTest with SharedSQLContext with KafkaTest { import testImplicits._ protected var testUtils: KafkaTestUtils = _ diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTest.scala new file mode 100644 index 0000000000000..19acda95c707c --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTest.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkFunSuite + +/** A trait to clean cached Kafka producers in `afterAll` */ +trait KafkaTest extends BeforeAndAfterAll { + self: SparkFunSuite => + + override def afterAll(): Unit = { + super.afterAll() + CachedKafkaProducer.clear() + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index e58d18361966f..55d61ef20ca8a 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -39,14 +39,13 @@ import org.apache.kafka.clients.producer._ import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.network.ListenerName import org.apache.kafka.common.serialization.{StringDeserializer, StringSerializer} -import org.apache.kafka.common.utils.Exit import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * This is a helper class for Kafka test suites. This has the functionality to set up @@ -60,7 +59,7 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L private val zkHost = "127.0.0.1" private var zkPort: Int = 0 private val zkConnectionTimeout = 60000 - private val zkSessionTimeout = 6000 + private val zkSessionTimeout = 10000 private var zookeeper: EmbeddedZookeeper = _ @@ -81,6 +80,7 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L // Flag to test whether the system is correctly started private var zkReady = false private var brokerReady = false + private var leakDetector: AnyRef = null def zkAddress: String = { assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address") @@ -130,6 +130,13 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L /** setup the whole embedded servers, including Zookeeper and Kafka brokers */ def setup(): Unit = { + // Set up a KafkaTestUtils leak detector so that we can see where the leak KafkaTestUtils is + // created. + val exception = new SparkException("It was created at: ") + leakDetector = ShutdownHookManager.addShutdownHook { () => + logError("Found a leak KafkaTestUtils.", exception) + } + setupEmbeddedZookeeper() setupEmbeddedKafkaServer() eventually(timeout(60.seconds)) { @@ -139,55 +146,47 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L /** Teardown the whole servers, including Kafka broker and Zookeeper */ def teardown(): Unit = { - // There is a race condition that may kill JVM when terminating the Kafka cluster. We set - // a custom Procedure here during the termination in order to keep JVM running and not fail the - // tests. - val logExitEvent = new Exit.Procedure { - override def execute(statusCode: Int, message: String): Unit = { - logError(s"Prevent Kafka from killing JVM (statusCode: $statusCode message: $message)") - } + if (leakDetector != null) { + ShutdownHookManager.removeShutdownHook(leakDetector) } - Exit.setExitProcedure(logExitEvent) - Exit.setHaltProcedure(logExitEvent) - try { - brokerReady = false - zkReady = false + brokerReady = false + zkReady = false - if (producer != null) { - producer.close() - producer = null - } + if (producer != null) { + producer.close() + producer = null + } - if (server != null) { - server.shutdown() - server.awaitShutdown() - server = null - } + if (adminClient != null) { + adminClient.close() + } - // On Windows, `logDirs` is left open even after Kafka server above is completely shut down - // in some cases. It leads to test failures on Windows if the directory deletion failure - // throws an exception. - brokerConf.logDirs.foreach { f => - try { - Utils.deleteRecursively(new File(f)) - } catch { - case e: IOException if Utils.isWindows => - logWarning(e.getMessage) - } - } + if (server != null) { + server.shutdown() + server.awaitShutdown() + server = null + } - if (zkUtils != null) { - zkUtils.close() - zkUtils = null + // On Windows, `logDirs` is left open even after Kafka server above is completely shut down + // in some cases. It leads to test failures on Windows if the directory deletion failure + // throws an exception. + brokerConf.logDirs.foreach { f => + try { + Utils.deleteRecursively(new File(f)) + } catch { + case e: IOException if Utils.isWindows => + logWarning(e.getMessage) } + } - if (zookeeper != null) { - zookeeper.shutdown() - zookeeper = null - } - } finally { - Exit.resetExitProcedure() - Exit.resetHaltProcedure() + if (zkUtils != null) { + zkUtils.close() + zkUtils = null + } + + if (zookeeper != null) { + zookeeper.shutdown() + zookeeper = null } } diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala index bd3cf9abddb5b..efcd5d6a5cdd3 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala @@ -34,13 +34,12 @@ import kafka.utils.ZkUtils import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord} import org.apache.kafka.common.network.ListenerName import org.apache.kafka.common.serialization.StringSerializer -import org.apache.kafka.common.utils.Exit import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.streaming.Time -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * This is a helper class for Kafka test suites. This has the functionality to set up @@ -54,7 +53,7 @@ private[kafka010] class KafkaTestUtils extends Logging { private val zkHost = "127.0.0.1" private var zkPort: Int = 0 private val zkConnectionTimeout = 60000 - private val zkSessionTimeout = 6000 + private val zkSessionTimeout = 10000 private var zookeeper: EmbeddedZookeeper = _ @@ -74,6 +73,7 @@ private[kafka010] class KafkaTestUtils extends Logging { // Flag to test whether the system is correctly started private var zkReady = false private var brokerReady = false + private var leakDetector: AnyRef = null def zkAddress: String = { assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address") @@ -120,61 +120,56 @@ private[kafka010] class KafkaTestUtils extends Logging { /** setup the whole embedded servers, including Zookeeper and Kafka brokers */ def setup(): Unit = { + // Set up a KafkaTestUtils leak detector so that we can see where the leak KafkaTestUtils is + // created. + val exception = new SparkException("It was created at: ") + leakDetector = ShutdownHookManager.addShutdownHook { () => + logError("Found a leak KafkaTestUtils.", exception) + } + setupEmbeddedZookeeper() setupEmbeddedKafkaServer() } /** Teardown the whole servers, including Kafka broker and Zookeeper */ def teardown(): Unit = { - // There is a race condition that may kill JVM when terminating the Kafka cluster. We set - // a custom Procedure here during the termination in order to keep JVM running and not fail the - // tests. - val logExitEvent = new Exit.Procedure { - override def execute(statusCode: Int, message: String): Unit = { - logError(s"Prevent Kafka from killing JVM (statusCode: $statusCode message: $message)") - } + if (leakDetector != null) { + ShutdownHookManager.removeShutdownHook(leakDetector) } - Exit.setExitProcedure(logExitEvent) - Exit.setHaltProcedure(logExitEvent) - try { - brokerReady = false - zkReady = false - - if (producer != null) { - producer.close() - producer = null - } + brokerReady = false + zkReady = false - if (server != null) { - server.shutdown() - server.awaitShutdown() - server = null - } + if (producer != null) { + producer.close() + producer = null + } - // On Windows, `logDirs` is left open even after Kafka server above is completely shut down - // in some cases. It leads to test failures on Windows if the directory deletion failure - // throws an exception. - brokerConf.logDirs.foreach { f => - try { - Utils.deleteRecursively(new File(f)) - } catch { - case e: IOException if Utils.isWindows => - logWarning(e.getMessage) - } - } + if (server != null) { + server.shutdown() + server.awaitShutdown() + server = null + } - if (zkUtils != null) { - zkUtils.close() - zkUtils = null + // On Windows, `logDirs` is left open even after Kafka server above is completely shut down + // in some cases. It leads to test failures on Windows if the directory deletion failure + // throws an exception. + brokerConf.logDirs.foreach { f => + try { + Utils.deleteRecursively(new File(f)) + } catch { + case e: IOException if Utils.isWindows => + logWarning(e.getMessage) } + } - if (zookeeper != null) { - zookeeper.shutdown() - zookeeper = null - } - } finally { - Exit.resetExitProcedure() - Exit.resetHaltProcedure() + if (zkUtils != null) { + zkUtils.close() + zkUtils = null + } + + if (zookeeper != null) { + zookeeper.shutdown() + zookeeper = null } } From ba84bcb2c4f73baf63782ff6fad5a607008c7cd2 Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Fri, 17 Aug 2018 16:04:02 -0700 Subject: [PATCH 1421/2461] [SPARK-24433][K8S] Initial R Bindings for SparkR on K8s ## What changes were proposed in this pull request? Introducing R Bindings for Spark R on K8s - [x] Running SparkR Job ## How was this patch tested? This patch was tested with - [x] Unit Tests - [x] Integration Tests ## Example: Commands to run example spark job: 1. `dev/make-distribution.sh --pip --r --tgz -Psparkr -Phadoop-2.7 -Pkubernetes` 2. `bin/docker-image-tool.sh -m -t testing build` 3. ``` bin/spark-submit \ --master k8s://https://192.168.64.33:8443 \ --deploy-mode cluster \ --name spark-r \ --conf spark.executor.instances=1 \ --conf spark.kubernetes.container.image=spark-r:testing \ local:///opt/spark/examples/src/main/r/dataframe.R ``` This above spark-submit command works given the distribution. (Will include this integration test in PR once PRB is ready). Author: Ilan Filonenko Closes #21584 from ifilonenko/spark-r. --- bin/docker-image-tool.sh | 23 ++++--- .../org/apache/spark/deploy/SparkSubmit.scala | 8 ++- .../org/apache/spark/deploy/k8s/Config.scala | 13 ++++ .../apache/spark/deploy/k8s/Constants.scala | 2 + .../spark/deploy/k8s/KubernetesConf.scala | 8 ++- .../bindings/RDriverFeatureStep.scala | 59 +++++++++++++++++ .../submit/KubernetesClientApplication.scala | 2 + .../k8s/submit/KubernetesDriverBuilder.scala | 22 ++++--- .../deploy/k8s/submit/MainAppResource.scala | 3 + .../deploy/k8s/KubernetesConfSuite.scala | 22 +++++++ .../bindings/RDriverFeatureStepSuite.scala | 63 +++++++++++++++++++ .../submit/KubernetesDriverBuilderSuite.scala | 36 ++++++++++- .../dockerfiles/spark/bindings/R/Dockerfile | 29 +++++++++ .../src/main/dockerfiles/spark/entrypoint.sh | 14 ++++- .../ClientModeTestsSuite.scala | 2 +- .../k8s/integrationtest/KubernetesSuite.scala | 21 ++++++- .../k8s/integrationtest/RTestsSuite.scala | 44 +++++++++++++ 17 files changed, 344 insertions(+), 27 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStepSuite.scala create mode 100644 resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/RTestsSuite.scala diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index cd22e75402f56..d6371051ef7fb 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -71,6 +71,7 @@ function build { ) local BASEDOCKERFILE=${BASEDOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} local PYDOCKERFILE=${PYDOCKERFILE:-"$IMG_PATH/spark/bindings/python/Dockerfile"} + local RDOCKERFILE=${RDOCKERFILE:-"$IMG_PATH/spark/bindings/R/Dockerfile"} docker build $NOCACHEARG "${BUILD_ARGS[@]}" \ -t $(image_ref spark) \ @@ -79,11 +80,16 @@ function build { docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ -t $(image_ref spark-py) \ -f "$PYDOCKERFILE" . + + docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ + -t $(image_ref spark-r) \ + -f "$RDOCKERFILE" . } function push { docker push "$(image_ref spark)" docker push "$(image_ref spark-py)" + docker push "$(image_ref spark-r)" } function usage { @@ -97,12 +103,13 @@ Commands: push Push a pre-built image to a registry. Requires a repository address to be provided. Options: - -f file Dockerfile to build for JVM based Jobs. By default builds the Dockerfile shipped with Spark. - -p file Dockerfile with Python baked in. By default builds the Dockerfile shipped with Spark. - -r repo Repository address. - -t tag Tag to apply to the built image, or to identify the image to be pushed. - -m Use minikube's Docker daemon. - -n Build docker image with --no-cache + -f file Dockerfile to build for JVM based Jobs. By default builds the Dockerfile shipped with Spark. + -p file Dockerfile to build for PySpark Jobs. Builds Python dependencies and ships with Spark. + -R file Dockerfile to build for SparkR Jobs. Builds R dependencies and ships with Spark. + -r repo Repository address. + -t tag Tag to apply to the built image, or to identify the image to be pushed. + -m Use minikube's Docker daemon. + -n Build docker image with --no-cache -b arg Build arg to build or push the image. For multiple build args, this option needs to be used separately for each build arg. @@ -133,14 +140,16 @@ REPO= TAG= BASEDOCKERFILE= PYDOCKERFILE= +RDOCKERFILE= NOCACHEARG= BUILD_PARAMS= -while getopts f:p:mr:t:n:b: option +while getopts f:p:R:mr:t:n:b: option do case "${option}" in f) BASEDOCKERFILE=${OPTARG};; p) PYDOCKERFILE=${OPTARG};; + R) RDOCKERFILE=${OPTARG};; r) REPO=${OPTARG};; t) TAG=${OPTARG};; n) NOCACHEARG="--no-cache";; diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 6e70bcd7fc088..cf902db8709e7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -286,8 +286,6 @@ private[spark] class SparkSubmit extends Logging { case (STANDALONE, CLUSTER) if args.isR => error("Cluster deploy mode is currently not supported for R " + "applications on standalone clusters.") - case (KUBERNETES, _) if args.isR => - error("R applications are currently not supported for Kubernetes.") case (LOCAL, CLUSTER) => error("Cluster deploy mode is not compatible with master \"local\"") case (_, CLUSTER) if isShell(args.primaryResource) => @@ -700,7 +698,11 @@ private[spark] class SparkSubmit extends Logging { if (args.pyFiles != null) { childArgs ++= Array("--other-py-files", args.pyFiles) } - } else { + } else if (args.isR) { + childArgs ++= Array("--primary-r-file", args.primaryResource) + childArgs ++= Array("--main-class", "org.apache.spark.deploy.RRunner") + } + else { childArgs ++= Array("--primary-java-resource", args.primaryResource) childArgs ++= Array("--main-class", args.mainClass) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 4442333c573cc..1b582fe53624a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -139,6 +139,19 @@ private[spark] object Config extends Logging { .stringConf .createOptional + val KUBERNETES_R_MAIN_APP_RESOURCE = + ConfigBuilder("spark.kubernetes.r.mainAppResource") + .doc("The main app resource for SparkR jobs") + .internal() + .stringConf + .createOptional + + val KUBERNETES_R_APP_ARGS = + ConfigBuilder("spark.kubernetes.r.appArgs") + .doc("The app arguments for SparkR Jobs") + .internal() + .stringConf + .createOptional val KUBERNETES_ALLOCATION_BATCH_SIZE = ConfigBuilder("spark.kubernetes.allocation.batch.size") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index f82cd7fd02e12..8202d874a4626 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -71,6 +71,8 @@ private[spark] object Constants { val ENV_PYSPARK_FILES = "PYSPARK_FILES" val ENV_PYSPARK_ARGS = "PYSPARK_APP_ARGS" val ENV_PYSPARK_MAJOR_PYTHON_VERSION = "PYSPARK_MAJOR_PYTHON_VERSION" + val ENV_R_PRIMARY = "R_PRIMARY" + val ENV_R_ARGS = "R_APP_ARGS" // Miscellaneous val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index 866ba3cbaa9c3..3aa35d419073f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -78,6 +78,9 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( def pySparkPythonVersion(): String = sparkConf .get(PYSPARK_MAJOR_PYTHON_VERSION) + def sparkRMainResource(): Option[String] = sparkConf + .get(KUBERNETES_R_MAIN_APP_RESOURCE) + def imagePullPolicy(): String = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) def imagePullSecrets(): Seq[LocalObjectReference] = { @@ -125,7 +128,7 @@ private[spark] object KubernetesConf { sparkConfWithMainAppJar.setJars(previousJars ++ Seq(res)) } // The function of this outer match is to account for multiple nonJVM - // bindings that will all have increased MEMORY_OVERHEAD_FACTOR to 0.4 + // bindings that will all have increased default MEMORY_OVERHEAD_FACTOR to 0.4 case nonJVM: NonJVMResource => nonJVM match { case PythonMainAppResource(res) => @@ -133,6 +136,9 @@ private[spark] object KubernetesConf { maybePyFiles.foreach{maybePyFiles => additionalFiles.appendAll(maybePyFiles.split(","))} sparkConfWithMainAppJar.set(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE, res) + case RMainAppResource(res) => + additionalFiles += res + sparkConfWithMainAppJar.set(KUBERNETES_R_MAIN_APP_RESOURCE, res) } sparkConfWithMainAppJar.setIfMissing(MEMORY_OVERHEAD_FACTOR, 0.4) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStep.scala new file mode 100644 index 0000000000000..b33b86e02ea6f --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStep.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features.bindings + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, HasMetadata} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep + +private[spark] class RDriverFeatureStep( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val roleConf = kubernetesConf.roleSpecificConf + require(roleConf.mainAppResource.isDefined, "R Main Resource must be defined") + val maybeRArgs = Option(roleConf.appArgs).filter(_.nonEmpty).map( + rArgs => + new EnvVarBuilder() + .withName(ENV_R_ARGS) + .withValue(rArgs.mkString(",")) + .build()) + val envSeq = + Seq(new EnvVarBuilder() + .withName(ENV_R_PRIMARY) + .withValue(KubernetesUtils.resolveFileUri(kubernetesConf.sparkRMainResource().get)) + .build()) + val rEnvs = envSeq ++ + maybeRArgs.toSeq + + val withRPrimaryContainer = new ContainerBuilder(pod.container) + .addAllToEnv(rEnvs.asJava) + .addToArgs("driver-r") + .addToArgs("--properties-file", SPARK_CONF_PATH) + .addToArgs("--class", roleConf.mainClass) + .build() + + SparkPod(pod.pod, withRPrimaryContainer) + } + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index 9398faee2ea5c..986c950ab365a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -60,6 +60,8 @@ private[spark] object ClientArguments { mainAppResource = Some(JavaMainAppResource(primaryJavaResource)) case Array("--primary-py-file", primaryPythonResource: String) => mainAppResource = Some(PythonMainAppResource(primaryPythonResource)) + case Array("--primary-r-file", primaryRFile: String) => + mainAppResource = Some(RMainAppResource(primaryRFile)) case Array("--other-py-files", pyFiles: String) => maybePyFiles = Some(pyFiles) case Array("--main-class", clazz: String) => diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index 7208e3d377593..8f3f18ffadc3b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -17,8 +17,8 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} -import org.apache.spark.deploy.k8s.features._ -import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep, MountVolumesFeatureStep} +import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep, RDriverFeatureStep} private[spark] class KubernetesDriverBuilder( provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep = @@ -40,14 +40,18 @@ private[spark] class KubernetesDriverBuilder( provideVolumesStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => MountVolumesFeatureStep) = new MountVolumesFeatureStep(_), - provideJavaStep: ( - KubernetesConf[KubernetesDriverSpecificConf] - => JavaDriverFeatureStep) = - new JavaDriverFeatureStep(_), providePythonStep: ( KubernetesConf[KubernetesDriverSpecificConf] => PythonDriverFeatureStep) = - new PythonDriverFeatureStep(_)) { + new PythonDriverFeatureStep(_), + provideRStep: ( + KubernetesConf[KubernetesDriverSpecificConf] + => RDriverFeatureStep) = + new RDriverFeatureStep(_), + provideJavaStep: ( + KubernetesConf[KubernetesDriverSpecificConf] + => JavaDriverFeatureStep) = + new JavaDriverFeatureStep(_)) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]): KubernetesDriverSpec = { @@ -71,7 +75,9 @@ private[spark] class KubernetesDriverBuilder( case JavaMainAppResource(_) => provideJavaStep(kubernetesConf) case PythonMainAppResource(_) => - providePythonStep(kubernetesConf)} + providePythonStep(kubernetesConf) + case RMainAppResource(_) => + provideRStep(kubernetesConf)} .getOrElse(provideJavaStep(kubernetesConf)) val allFeatures = (baseFeatures :+ bindingsStep) ++ diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala index cbe081ae35683..dd5a4549743df 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala @@ -24,3 +24,6 @@ private[spark] case class JavaMainAppResource(primaryResource: String) extends M private[spark] case class PythonMainAppResource(primaryResource: String) extends MainAppResource with NonJVMResource + +private[spark] case class RMainAppResource(primaryResource: String) + extends MainAppResource with NonJVMResource diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala index ecdb71359c5bb..e3c19cdb81567 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala @@ -122,6 +122,28 @@ class KubernetesConfSuite extends SparkFunSuite { === Array("local:///opt/spark/example4.py", mainResourceFile) ++ inputPyFiles) } + test("Creating driver conf with a r primary file") { + val mainResourceFile = "local:///opt/spark/main.R" + val sparkConf = new SparkConf(false) + .setJars(Seq("local:///opt/spark/jar1.jar")) + .set("spark.files", "local:///opt/spark/example2.R") + val mainAppResource = Some(RMainAppResource(mainResourceFile)) + val kubernetesConfWithMainResource = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + mainAppResource, + MAIN_CLASS, + APP_ARGS, + maybePyFiles = None) + assert(kubernetesConfWithMainResource.sparkConf.get("spark.jars").split(",") + === Array("local:///opt/spark/jar1.jar")) + assert(kubernetesConfWithMainResource.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.4) + assert(kubernetesConfWithMainResource.sparkFiles + === Array("local:///opt/spark/example2.R", mainResourceFile)) + } + test("Testing explicit setting of memory overhead on non-JVM tasks") { val sparkConf = new SparkConf(false) .set(MEMORY_OVERHEAD_FACTOR, 0.3) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStepSuite.scala new file mode 100644 index 0000000000000..8fdf91ef638f2 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStepSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features.bindings + +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.RMainAppResource + +class RDriverFeatureStepSuite extends SparkFunSuite { + + test("R Step modifies container correctly") { + val expectedMainResource = "/main.R" + val mainResource = "local:///main.R" + val baseDriverPod = SparkPod.initialPod() + val sparkConf = new SparkConf(false) + .set(KUBERNETES_R_MAIN_APP_RESOURCE, mainResource) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + Some(RMainAppResource(mainResource)), + "test-app", + "r-runner", + Seq("5 7")), + appResourceNamePrefix = "", + appId = "", + roleLabels = Map.empty, + roleAnnotations = Map.empty, + roleSecretNamesToMountPaths = Map.empty, + roleSecretEnvNamesToKeyRefs = Map.empty, + roleEnvs = Map.empty, + roleVolumes = Seq.empty, + sparkFiles = Seq.empty[String]) + + val step = new RDriverFeatureStep(kubernetesConf) + val driverContainerwithR = step.configurePod(baseDriverPod).container + assert(driverContainerwithR.getEnv.size === 2) + val envs = driverContainerwithR + .getEnv + .asScala + .map(env => (env.getName, env.getValue)) + .toMap + assert(envs(ENV_R_PRIMARY) === expectedMainResource) + assert(envs(ENV_R_ARGS) === "5 7") + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index 046e578b94629..4117c5487a41e 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -20,7 +20,7 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.features._ import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} -import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep} +import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep, RDriverFeatureStep} class KubernetesDriverBuilderSuite extends SparkFunSuite { @@ -31,6 +31,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val SECRETS_STEP_TYPE = "mount-secrets" private val JAVA_STEP_TYPE = "java-bindings" private val PYSPARK_STEP_TYPE = "pyspark-bindings" + private val R_STEP_TYPE = "r-bindings" private val ENV_SECRETS_STEP_TYPE = "env-secrets" private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes" @@ -55,6 +56,9 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val pythonStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( PYSPARK_STEP_TYPE, classOf[PythonDriverFeatureStep]) + private val rStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + R_STEP_TYPE, classOf[RDriverFeatureStep]) + private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) @@ -70,8 +74,9 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { _ => envSecretsStep, _ => localDirsStep, _ => mountVolumesStep, - _ => javaStep, - _ => pythonStep) + _ => pythonStep, + _ => rStep, + _ => javaStep) test("Apply fundamental steps all the time.") { val conf = KubernetesConf( @@ -211,6 +216,31 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { JAVA_STEP_TYPE) } + test("Apply R step if main resource is R.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + Some(RMainAppResource("example.R")), + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + R_STEP_TYPE) + } private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) : Unit = { diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile new file mode 100644 index 0000000000000..e627883ba782e --- /dev/null +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +ARG base_img +FROM $base_img +WORKDIR / +RUN mkdir ${SPARK_HOME}/R +COPY R ${SPARK_HOME}/R + +RUN apk add --no-cache R R-dev + +ENV R_HOME /usr/lib/R + +WORKDIR /opt/spark/work-dir +ENTRYPOINT [ "/opt/entrypoint.sh" ] diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 8bdb0f7a10795..216e8fe31becb 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -38,7 +38,7 @@ fi SPARK_K8S_CMD="$1" case "$SPARK_K8S_CMD" in - driver | driver-py | executor) + driver | driver-py | driver-r | executor) shift 1 ;; "") @@ -66,6 +66,10 @@ if [ -n "$PYSPARK_APP_ARGS" ]; then PYSPARK_ARGS="$PYSPARK_APP_ARGS" fi +R_ARGS="" +if [ -n "$R_APP_ARGS" ]; then + R_ARGS="$R_APP_ARGS" +fi if [ "$PYSPARK_MAJOR_PYTHON_VERSION" == "2" ]; then pyv="$(python -V 2>&1)" @@ -96,6 +100,14 @@ case "$SPARK_K8S_CMD" in "$@" $PYSPARK_PRIMARY $PYSPARK_ARGS ) ;; + driver-r) + CMD=( + "$SPARK_HOME/bin/spark-submit" + --conf "spark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS" + --deploy-mode client + "$@" $R_PRIMARY $R_ARGS + ) + ;; executor) CMD=( ${JAVA_HOME}/bin/java diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala index 159cfd97ff403..c8bd584516ea5 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.deploy.k8s.integrationtest.KubernetesSuite.{k8sTestTag, INTERVAL, TIMEOUT} -trait ClientModeTestsSuite { k8sSuite: KubernetesSuite => +private[spark] trait ClientModeTestsSuite { k8sSuite: KubernetesSuite => test("Run in client mode.", k8sTestTag) { val labels = Map("spark-app-selector" -> driverPodName) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 13ce2efecbef7..896a83a5badbb 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -38,10 +38,12 @@ private[spark] class KubernetesSuite extends SparkFunSuite import KubernetesSuite._ - protected var testBackend: IntegrationTestBackend = _ - protected var sparkHomeDir: Path = _ + private var sparkHomeDir: Path = _ + private var pyImage: String = _ + private var rImage: String = _ + protected var image: String = _ - protected var pyImage: String = _ + protected var testBackend: IntegrationTestBackend = _ protected var driverPodName: String = _ protected var kubernetesTestComponents: KubernetesTestComponents = _ protected var sparkAppConf: SparkAppConf = _ @@ -67,6 +69,7 @@ private[spark] class KubernetesSuite extends SparkFunSuite val imageRepo = getTestImageRepo image = s"$imageRepo/spark:$imageTag" pyImage = s"$imageRepo/spark-py:$imageTag" + rImage = s"$imageRepo/spark-r:$imageTag" val sparkDistroExamplesJarFile: File = sparkHomeDir.resolve(Paths.get("examples", "jars")) .toFile @@ -239,6 +242,13 @@ private[spark] class KubernetesSuite extends SparkFunSuite assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") } + protected def doBasicDriverRPodCheck(driverPod: Pod): Unit = { + assert(driverPod.getMetadata.getName === driverPodName) + assert(driverPod.getSpec.getContainers.get(0).getImage === rImage) + assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") + } + + protected def doBasicExecutorPodCheck(executorPod: Pod): Unit = { assert(executorPod.getSpec.getContainers.get(0).getImage === image) assert(executorPod.getSpec.getContainers.get(0).getName === "executor") @@ -249,6 +259,11 @@ private[spark] class KubernetesSuite extends SparkFunSuite assert(executorPod.getSpec.getContainers.get(0).getName === "executor") } + protected def doBasicExecutorRPodCheck(executorPod: Pod): Unit = { + assert(executorPod.getSpec.getContainers.get(0).getImage === rImage) + assert(executorPod.getSpec.getContainers.get(0).getName === "executor") + } + protected def checkCustomSettings(pod: Pod): Unit = { assert(pod.getMetadata.getLabels.get("label1") === "label1-value") assert(pod.getMetadata.getLabels.get("label2") === "label2-value") diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/RTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/RTestsSuite.scala new file mode 100644 index 0000000000000..885a23cfb4864 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/RTestsSuite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import org.apache.spark.deploy.k8s.integrationtest.TestConfig.{getTestImageRepo, getTestImageTag} + +private[spark] trait RTestsSuite { k8sSuite: KubernetesSuite => + + import RTestsSuite._ + import KubernetesSuite.k8sTestTag + + test("Run SparkR on simple dataframe.R example", k8sTestTag) { + sparkAppConf + .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-r:${getTestImageTag}") + runSparkApplicationAndVerifyCompletion( + appResource = SPARK_R_DATAFRAME_TEST, + mainClass = "", + expectedLogOnCompletion = Seq("name: string (nullable = true)", "1 Justin"), + appArgs = Array.empty[String], + driverPodChecker = doBasicDriverRPodCheck, + executorPodChecker = doBasicExecutorRPodCheck, + appLocator = appLocator, + isJVM = false) + } +} + +private[spark] object RTestsSuite { + val CONTAINER_LOCAL_SPARKR: String = "local:///opt/spark/examples/src/main/r/" + val SPARK_R_DATAFRAME_TEST: String = CONTAINER_LOCAL_SPARKR + "dataframe.R" +} From 10f2b6fa05f3d977f3b6099fcd94c5c0cd97a0cb Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 17 Aug 2018 22:14:42 -0700 Subject: [PATCH 1422/2461] [SPARK-23555][PYTHON] Add BinaryType support for Arrow in Python ## What changes were proposed in this pull request? Adding `BinaryType` support for Arrow in pyspark, conditional on using pyarrow >= 0.10.0. Earlier versions will continue to raise a TypeError. ## How was this patch tested? Additional unit tests in pyspark for code paths that use Arrow for createDataFrame, toPandas, and scalar pandas_udfs. Closes #20725 from BryanCutler/arrow-binary-type-support-SPARK-23555. Authored-by: Bryan Cutler Signed-off-by: Bryan Cutler --- python/pyspark/sql/tests.py | 66 ++++++++++++++++++++++++++++++------- python/pyspark/sql/types.py | 15 +++++++++ 2 files changed, 70 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 91ed600afedd7..00d7e18320a51 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4050,6 +4050,8 @@ class ArrowTests(ReusedSQLTestCase): def setUpClass(cls): from datetime import date, datetime from decimal import Decimal + from distutils.version import LooseVersion + import pyarrow as pa ReusedSQLTestCase.setUpClass() # Synchronize default timezone between Python and Java @@ -4078,6 +4080,13 @@ def setUpClass(cls): (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"), date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + # TODO: remove version check once minimum pyarrow version is 0.10.0 + if LooseVersion("0.10.0") <= LooseVersion(pa.__version__): + cls.schema.add(StructField("9_binary_t", BinaryType(), True)) + cls.data[0] = cls.data[0] + (bytearray(b"a"),) + cls.data[1] = cls.data[1] + (bytearray(b"bb"),) + cls.data[2] = cls.data[2] + (bytearray(b"ccc"),) + @classmethod def tearDownClass(cls): del os.environ["TZ"] @@ -4115,12 +4124,23 @@ def test_toPandas_fallback_enabled(self): self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) def test_toPandas_fallback_disabled(self): + from distutils.version import LooseVersion + import pyarrow as pa + schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported type'): df.toPandas() + # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + schema = StructType([StructField("binary", BinaryType(), True)]) + df = self.spark.createDataFrame([(None,)], schema=schema) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported type.*BinaryType'): + df.toPandas() + def test_null_conversion(self): df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + self.data) @@ -4232,19 +4252,22 @@ def test_createDataFrame_with_schema(self): def test_createDataFrame_with_incorrect_schema(self): pdf = self.create_pandas_data_frame() - wrong_schema = StructType(list(reversed(self.schema))) + fields = list(self.schema) + fields[0], fields[7] = fields[7], fields[0] # swap str with timestamp + wrong_schema = StructType(fields) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, ".*No cast.*string.*timestamp.*"): self.spark.createDataFrame(pdf, schema=wrong_schema) def test_createDataFrame_with_names(self): pdf = self.create_pandas_data_frame() + new_names = list(map(str, range(len(self.schema.fieldNames())))) # Test that schema as a list of column names gets applied - df = self.spark.createDataFrame(pdf, schema=list('abcdefgh')) - self.assertEquals(df.schema.fieldNames(), list('abcdefgh')) + df = self.spark.createDataFrame(pdf, schema=list(new_names)) + self.assertEquals(df.schema.fieldNames(), new_names) # Test that schema as tuple of column names gets applied - df = self.spark.createDataFrame(pdf, schema=tuple('abcdefgh')) - self.assertEquals(df.schema.fieldNames(), list('abcdefgh')) + df = self.spark.createDataFrame(pdf, schema=tuple(new_names)) + self.assertEquals(df.schema.fieldNames(), new_names) def test_createDataFrame_column_name_encoding(self): import pandas as pd @@ -4331,13 +4354,22 @@ def test_createDataFrame_fallback_enabled(self): self.assertEqual(df.collect(), [Row(a={u'a': 1})]) def test_createDataFrame_fallback_disabled(self): + from distutils.version import LooseVersion import pandas as pd + import pyarrow as pa with QuietTest(self.sc): with self.assertRaisesRegexp(TypeError, 'Unsupported type'): self.spark.createDataFrame( pd.DataFrame([[{u'a': 1}]]), "a: map") + # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + with QuietTest(self.sc): + with self.assertRaisesRegexp(TypeError, 'Unsupported type.*BinaryType'): + self.spark.createDataFrame( + pd.DataFrame([[{'a': b'aaa'}]]), "a: binary") + # Regression test for SPARK-23314 def test_timestamp_dst(self): import pandas as pd @@ -4729,6 +4761,24 @@ def test_vectorized_udf_datatype_string(self): bool_f(col('bool'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_null_binary(self): + from distutils.version import LooseVersion + import pyarrow as pa + from pyspark.sql.functions import pandas_udf, col + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + with QuietTest(self.sc): + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*BinaryType'): + pandas_udf(lambda x: x, BinaryType()) + else: + data = [(bytearray(b"a"),), (None,), (bytearray(b"bb"),), (bytearray(b"ccc"),)] + schema = StructType().add("binary", BinaryType()) + df = self.spark.createDataFrame(data, schema) + str_f = pandas_udf(lambda x: x, BinaryType()) + res = df.select(str_f(col('binary'))) + self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_array_type(self): from pyspark.sql.functions import pandas_udf, col data = [([1, 2],), ([3, 4],)] @@ -4835,12 +4885,6 @@ def test_vectorized_udf_unsupported_types(self): 'Invalid returnType.*scalar Pandas UDF.*MapType'): pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) - with QuietTest(self.sc): - with self.assertRaisesRegexp( - NotImplementedError, - 'Invalid returnType.*scalar Pandas UDF.*BinaryType'): - pandas_udf(lambda x: x, BinaryType()) - def test_vectorized_udf_dates(self): from pyspark.sql.functions import pandas_udf, col from datetime import date diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 214d8fe6bbbb6..0b61707c8cc0a 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1578,6 +1578,7 @@ def convert(self, obj, gateway_client): def to_arrow_type(dt): """ Convert Spark data type to pyarrow type """ + from distutils.version import LooseVersion import pyarrow as pa if type(dt) == BooleanType: arrow_type = pa.bool_() @@ -1597,6 +1598,12 @@ def to_arrow_type(dt): arrow_type = pa.decimal128(dt.precision, dt.scale) elif type(dt) == StringType: arrow_type = pa.string() + elif type(dt) == BinaryType: + # TODO: remove version check once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + raise TypeError("Unsupported type in conversion to Arrow: " + str(dt) + + "\nPlease install pyarrow >= 0.10.0 for BinaryType support.") + arrow_type = pa.binary() elif type(dt) == DateType: arrow_type = pa.date32() elif type(dt) == TimestampType: @@ -1623,6 +1630,8 @@ def to_arrow_schema(schema): def from_arrow_type(at): """ Convert pyarrow type to Spark data type. """ + from distutils.version import LooseVersion + import pyarrow as pa import pyarrow.types as types if types.is_boolean(at): spark_type = BooleanType() @@ -1642,6 +1651,12 @@ def from_arrow_type(at): spark_type = DecimalType(precision=at.precision, scale=at.scale) elif types.is_string(at): spark_type = StringType() + elif types.is_binary(at): + # TODO: remove version check once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + raise TypeError("Unsupported type in conversion from Arrow: " + str(at) + + "\nPlease install pyarrow >= 0.10.0 for BinaryType support.") + spark_type = BinaryType() elif types.is_date32(at): spark_type = DateType() elif types.is_timestamp(at): From e3cf13d7bdb98e512b7d3b6c37aac4655ab141f3 Mon Sep 17 00:00:00 2001 From: Vinod KC Date: Sat, 18 Aug 2018 17:19:29 +0800 Subject: [PATCH 1423/2461] [SPARK-25137][SPARK SHELL] NumberFormatException` when starting spark-shell from Mac terminal ## What changes were proposed in this pull request? When starting spark-shell from Mac terminal (MacOS High Sirra Version 10.13.6), Getting exception [ERROR] Failed to construct terminal; falling back to unsupported java.lang.NumberFormatException: For input string: "0x100" at java.lang.NumberFormatException.forInputString(NumberFormatException.java:65) at java.lang.Integer.parseInt(Integer.java:580) at java.lang.Integer.valueOf(Integer.java:766) at jline.internal.InfoCmp.parseInfoCmp(InfoCmp.java:59) at jline.UnixTerminal.parseInfoCmp(UnixTerminal.java:242) at jline.UnixTerminal.(UnixTerminal.java:65) at jline.UnixTerminal.(UnixTerminal.java:50) at sun.reflect.NativeConstructorAccessorImpl.newInstance0(Native Method) at sun.reflect.NativeConstructorAccessorImpl.newInstance(NativeConstructorAccessorImpl.java:62) at sun.reflect.DelegatingConstructorAccessorImpl.newInstance(DelegatingConstructorAccessorImpl.java:45) at java.lang.reflect.Constructor.newInstance(Constructor.java:423) at java.lang.Class.newInstance(Class.java:442) at jline.TerminalFactory.getFlavor(TerminalFactory.java:211) This issue is due a jline defect : https://github.com/jline/jline2/issues/281, which is fixed in Jline 2.14.4, bumping up JLine version in spark to version >= Jline 2.14.4 will fix the issue ## How was this patch tested? No new UT/automation test added, after upgrade to latest Jline version 2.14.6, manually tested spark shell features Closes #22130 from vinodkc/br_UpgradeJLineVersion. Authored-by: Vinod KC Signed-off-by: hyukjinkwon --- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- dev/deps/spark-deps-hadoop-3.1 | 2 +- pom.xml | 2 +- project/SparkBuild.scala | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 30727e69203f8..aca5bbbe02eb6 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -118,7 +118,7 @@ jersey-media-jaxb-2.22.2.jar jersey-server-2.22.2.jar jetty-6.1.26.jar jetty-util-6.1.26.jar -jline-2.14.3.jar +jline-2.14.6.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index e58ae7a0e7668..0902af1dfdc65 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -119,7 +119,7 @@ jersey-server-2.22.2.jar jetty-6.1.26.jar jetty-sslengine-6.1.26.jar jetty-util-6.1.26.jar -jline-2.14.3.jar +jline-2.14.6.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 8a65860d3c011..35cf1cd85fd80 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -118,7 +118,7 @@ jersey-media-jaxb-2.22.2.jar jersey-server-2.22.2.jar jetty-webapp-9.3.24.v20180605.jar jetty-xml-9.3.24.v20180605.jar -jline-2.14.3.jar +jline-2.14.6.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar diff --git a/pom.xml b/pom.xml index ab2a272fb3ff7..a4184c3153336 100644 --- a/pom.xml +++ b/pom.xml @@ -746,7 +746,7 @@ jline jline - 2.14.3 + 2.14.6 org.scalatest diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 4b4cce6788a96..1f45a06084c0d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -465,7 +465,7 @@ object DockerIntegrationTests { object DependencyOverrides { lazy val settings = Seq( dependencyOverrides += "com.google.guava" % "guava" % "14.0.1", - dependencyOverrides += "jline" % "jline" % "2.14.3") + dependencyOverrides += "jline" % "jline" % "2.14.6") } /** From f454d5287f3f90696c8068c424e333a71e1e7b1b Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sat, 18 Aug 2018 17:20:34 +0800 Subject: [PATCH 1424/2461] [MINOR][DOC][SQL] use one line for annotation arg value ## What changes were proposed in this pull request? Put annotation args in one line, or API doc generation will fail. ~~~ [error] /Users/meng/src/spark/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala:1559: annotation argument needs to be a constant; found: "_FUNC_(expr) - Returns the character length of string data or number of bytes of ".+("binary data. The length of string data includes the trailing spaces. The length of binary ").+("data includes binary zeros.") [error] "binary data. The length of string data includes the trailing spaces. The length of binary " + [error] ^ [info] No documentation generated with unsuccessful compiler run [error] one error found [error] (catalyst/compile:doc) Scaladoc generation failed [error] Total time: 27 s, completed Aug 17, 2018 3:20:08 PM ~~~ ## How was this patch tested? sbt catalyst/compile:doc passed Closes #22137 from mengxr/minor-doc-fix. Authored-by: Xiangrui Meng Signed-off-by: hyukjinkwon --- .../spark/sql/catalyst/expressions/stringExpressions.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index e1549d3dee539..14faa62bde7d0 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1554,10 +1554,9 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run * A function that returns the char length of the given string expression or * number of bytes of the given binary expression. */ +// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the character length of string data or number of bytes of " + - "binary data. The length of string data includes the trailing spaces. The length of binary " + - "data includes binary zeros.", + usage = "_FUNC_(expr) - Returns the character length of string data or number of bytes of binary data. The length of string data includes the trailing spaces. The length of binary data includes binary zeros.", examples = """ Examples: > SELECT _FUNC_('Spark SQL '); @@ -1567,6 +1566,7 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run > SELECT CHARACTER_LENGTH('Spark SQL '); 10 """) +// scalastyle:on line.size.limit case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) From 4dd87d8ff513d86380cf424e961c3f31ac58eab5 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sat, 18 Aug 2018 17:24:06 +0800 Subject: [PATCH 1425/2461] [SPARK-25142][PYSPARK] Add error messages when Python worker could not open socket in `_load_from_socket`. ## What changes were proposed in this pull request? Sometimes Python worker can't open socket in `_load_from_socket` for some reason, but it's difficult to figure out the reason because the exception doesn't even contain the messages from `socket.error`s. We should at least add the error messages when raising the exception. ## How was this patch tested? Manually in my local environment. Closes #22132 from ueshin/issues/SPARK-25142/socket_error. Authored-by: Takuya UESHIN Signed-off-by: hyukjinkwon --- python/pyspark/rdd.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index ba39edbc93d7c..b061074a28ab4 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -53,7 +53,7 @@ from pyspark.shuffle import Aggregator, ExternalMerger, \ get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync -from pyspark.util import fail_on_stopiteration +from pyspark.util import fail_on_stopiteration, _exception_message __all__ = ["RDD"] @@ -143,6 +143,7 @@ def _parse_memory(s): def _load_from_socket(sock_info, serializer): port, auth_secret = sock_info sock = None + errors = [] # Support for both IPv4 and IPv6. # On most of IPv6-ready systems, IPv6 will take precedence. for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM): @@ -151,13 +152,15 @@ def _load_from_socket(sock_info, serializer): try: sock.settimeout(15) sock.connect(sa) - except socket.error: + except socket.error as e: + emsg = _exception_message(e) + errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg)) sock.close() sock = None continue break if not sock: - raise Exception("could not open socket") + raise Exception("could not open socket: %s" % errors) # The RDD materialization time is unpredicable, if we set a timeout for socket reading # operation, it will very possibly fail. See SPARK-18281. sock.settimeout(None) From 9047cc0f2c8a101d62d42f57da55227e87ab630d Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 18 Aug 2018 17:30:12 +0800 Subject: [PATCH 1426/2461] [SPARK-24886][INFRA] Fix the testing script to increase timeout for Jenkins build (from 340m to 400m) ## What changes were proposed in this pull request? This PR targets to increase the timeout from 340 to 400m. Please also see https://github.com/apache/spark/pull/21845#discussion_r209807634 ## How was this patch tested? N/A Closes #22098 from HyukjinKwon/SPARK-24886-1. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- dev/run-tests-jenkins.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 16af97c7fbeae..e6fe3b82ed202 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -181,8 +181,9 @@ def main(): short_commit_hash = ghprb_actual_commit[0:7] # format: http://linux.die.net/man/1/timeout - # must be less than the timeout configured on Jenkins (currently 400m) - tests_timeout = "340m" + # must be less than the timeout configured on Jenkins. Usually Jenkins's timeout is higher + # then this. Please consult with the build manager or a committer when it should be increased. + tests_timeout = "400m" # Array to capture all test names to run on the pull request. These tests are represented # by their file equivalents in the dev/tests/ directory. From 14d7c1c3e99e7523c757628d411525aa9d8e0709 Mon Sep 17 00:00:00 2001 From: Arun Mahadevan Date: Sat, 18 Aug 2018 17:31:52 +0800 Subject: [PATCH 1427/2461] [SPARK-24863][SS] Report Kafka offset lag as a custom metrics ## What changes were proposed in this pull request? This builds on top of SPARK-24748 to report 'offset lag' as a custom metrics for Kafka structured streaming source. This lag is the difference between the latest offsets in Kafka the time the metrics is reported (just after a micro-batch completes) and the latest offset Spark has processed. It can be 0 (or close to 0) if spark keeps up with the rate at which messages are ingested into Kafka topics in steady state. This measures how far behind the spark source has fallen behind (per partition) and can aid in tuning the application. ## How was this patch tested? Existing and new unit tests Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #21819 from arunmahadevan/SPARK-24863. Authored-by: Arun Mahadevan Signed-off-by: hyukjinkwon --- .../apache/spark/sql/kafka010/JsonUtils.scala | 33 ++++++++++++---- .../sql/kafka010/KafkaMicroBatchReader.scala | 26 +++++++++++-- .../kafka010/KafkaMicroBatchSourceSuite.scala | 38 ++++++++++++++++++- 3 files changed, 85 insertions(+), 12 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala index 868edb5dcdc0c..92b13f2b555d1 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala @@ -29,6 +29,11 @@ import org.json4s.jackson.Serialization */ private object JsonUtils { private implicit val formats = Serialization.formats(NoTypeHints) + implicit val ordering = new Ordering[TopicPartition] { + override def compare(x: TopicPartition, y: TopicPartition): Int = { + Ordering.Tuple2[String, Int].compare((x.topic, x.partition), (y.topic, y.partition)) + } + } /** * Read TopicPartitions from json string @@ -51,7 +56,7 @@ private object JsonUtils { * Write TopicPartitions as json string */ def partitions(partitions: Iterable[TopicPartition]): String = { - val result = new HashMap[String, List[Int]] + val result = HashMap.empty[String, List[Int]] partitions.foreach { tp => val parts: List[Int] = result.getOrElse(tp.topic, Nil) result += tp.topic -> (tp.partition::parts) @@ -80,19 +85,31 @@ private object JsonUtils { * Write per-TopicPartition offsets as json string */ def partitionOffsets(partitionOffsets: Map[TopicPartition, Long]): String = { - val result = new HashMap[String, HashMap[Int, Long]]() - implicit val ordering = new Ordering[TopicPartition] { - override def compare(x: TopicPartition, y: TopicPartition): Int = { - Ordering.Tuple2[String, Int].compare((x.topic, x.partition), (y.topic, y.partition)) - } - } + val result = HashMap.empty[String, HashMap[Int, Long]] val partitions = partitionOffsets.keySet.toSeq.sorted // sort for more determinism partitions.foreach { tp => val off = partitionOffsets(tp) - val parts = result.getOrElse(tp.topic, new HashMap[Int, Long]) + val parts = result.getOrElse(tp.topic, HashMap.empty[Int, Long]) parts += tp.partition -> off result += tp.topic -> parts } Serialization.write(result) } + + /** + * Write per-topic partition lag as json string + */ + def partitionLags( + latestOffsets: Map[TopicPartition, Long], + processedOffsets: Map[TopicPartition, Long]): String = { + val result = HashMap.empty[String, HashMap[Int, Long]] + val partitions = latestOffsets.keySet.toSeq.sorted + partitions.foreach { tp => + val lag = latestOffsets(tp) - processedOffsets.getOrElse(tp, 0L) + val parts = result.getOrElse(tp.topic, HashMap.empty[Int, Long]) + parts += tp.partition -> lag + result += tp.topic -> parts + } + Serialization.write(Map("lag" -> result)) + } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index 6c95b2b2560c4..900c9f4e7fbf3 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -24,6 +24,7 @@ import java.nio.charset.StandardCharsets import scala.collection.JavaConverters._ import org.apache.commons.io.IOUtils +import org.apache.kafka.common.TopicPartition import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging @@ -33,9 +34,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset} import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} -import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.{CustomMetrics, DataSourceOptions} import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset, SupportsCustomReaderMetrics} import org.apache.spark.sql.types.StructType import org.apache.spark.util.UninterruptibleThread @@ -62,7 +63,7 @@ private[kafka010] class KafkaMicroBatchReader( metadataPath: String, startingOffsets: KafkaOffsetRangeLimit, failOnDataLoss: Boolean) - extends MicroBatchReader with Logging { + extends MicroBatchReader with SupportsCustomReaderMetrics with Logging { private var startPartitionOffsets: PartitionOffsetMap = _ private var endPartitionOffsets: PartitionOffsetMap = _ @@ -158,6 +159,10 @@ private[kafka010] class KafkaMicroBatchReader( KafkaSourceOffset(endPartitionOffsets) } + override def getCustomMetrics: CustomMetrics = { + KafkaCustomMetrics(kafkaOffsetReader.fetchLatestOffsets(), endPartitionOffsets) + } + override def deserializeOffset(json: String): Offset = { KafkaSourceOffset(JsonUtils.partitionOffsets(json)) } @@ -380,3 +385,18 @@ private[kafka010] case class KafkaMicroBatchInputPartitionReader( } } } + +/** + * Currently reports per topic-partition lag. + * This is the difference between the offset of the latest available data + * in a topic-partition and the latest offset that has been processed. + */ +private[kafka010] case class KafkaCustomMetrics( + latestOffsets: Map[TopicPartition, Long], + processedOffsets: Map[TopicPartition, Long]) extends CustomMetrics { + override def json(): String = { + JsonUtils.partitionLags(latestOffsets, processedOffsets) + } + + override def toString: String = json() +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 172c0ef251cac..c7b74f305eed2 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -31,12 +31,13 @@ import scala.util.Random import org.apache.kafka.clients.producer.RecordMetadata import org.apache.kafka.common.TopicPartition +import org.json4s.DefaultFormats +import org.json4s.jackson.JsonMethods._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession} -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution @@ -701,6 +702,41 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { intercept[IllegalArgumentException] { test(minPartitions = "-1", 1, true) } } + test("custom lag metrics") { + import testImplicits._ + val topic = newTopic() + testUtils.createTopic(topic, partitions = 2) + testUtils.sendMessages(topic, (1 to 100).map(_.toString).toArray) + require(testUtils.getLatestOffsets(Set(topic)).size === 2) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("startingOffsets", s"earliest") + .option("maxOffsetsPerTrigger", 10) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + implicit val formats = DefaultFormats + + val mapped = kafka.map(kv => kv._2.toInt + 1) + testStream(mapped)( + StartStream(trigger = OneTimeTrigger), + AssertOnQuery { query => + query.awaitTermination() + val source = query.lastProgress.sources(0) + // masOffsetsPerTrigger is 10, and there are two partitions containing 50 events each + // so 5 events should be processed from each partition and a lag of 45 events + val custom = parse(source.customMetrics) + .extract[Map[String, Map[String, Map[String, Long]]]] + custom("lag")(topic)("0") == 45 && custom("lag")(topic)("1") == 45 + } + ) + } + } abstract class KafkaSourceSuiteBase extends KafkaSourceTest { From a8a1ac01c4732f8a738b973c8486514cd88bf99b Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 18 Aug 2018 10:34:49 -0700 Subject: [PATCH 1428/2461] [SPARK-24959][SQL] Speed up count() for JSON and CSV ## What changes were proposed in this pull request? In the PR, I propose to skip invoking of the CSV/JSON parser per each line in the case if the required schema is empty. Added benchmarks for `count()` shows performance improvement up to **3.5 times**. Before: ``` Count a dataset with 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) -------------------------------------------------------------------------------------- JSON count() 7676 / 7715 1.3 767.6 CSV count() 3309 / 3363 3.0 330.9 ``` After: ``` Count a dataset with 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) -------------------------------------------------------------------------------------- JSON count() 2104 / 2156 4.8 210.4 CSV count() 2332 / 2386 4.3 233.2 ``` ## How was this patch tested? It was tested by `CSVSuite` and `JSONSuite` as well as on added benchmarks. Author: Maxim Gekk Author: Maxim Gekk Closes #21909 from MaxGekk/empty-schema-optimization. --- .../sql/catalyst/json/JacksonParser.scala | 3 +- .../apache/spark/sql/DataFrameReader.scala | 6 ++- .../datasources/FailureSafeParser.scala | 12 ++++- .../datasources/csv/UnivocityParser.scala | 16 +++---- .../datasources/json/JsonDataSource.scala | 6 ++- .../datasources/csv/CSVBenchmarks.scala | 39 ++++++++++++++++ .../execution/datasources/csv/CSVSuite.scala | 26 +++++++++++ .../datasources/json/JsonBenchmarks.scala | 45 ++++++++++++++++++- .../datasources/json/JsonSuite.scala | 27 ++++++++++- 9 files changed, 159 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 6feea500b2aa0..984979ac5e9b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.json import java.io.{ByteArrayOutputStream, CharConversionException} +import java.nio.charset.MalformedInputException import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -402,7 +403,7 @@ class JacksonParser( } } } catch { - case e @ (_: RuntimeException | _: JsonProcessingException) => + case e @ (_: RuntimeException | _: JsonProcessingException | _: MalformedInputException) => // JSON parser currently doesn't support partial results for corrupted records. // For such records, all fields other than the field configured by // `columnNameOfCorruptRecord` are set to `null`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 9bd113419ae4c..1b3a9fc91d198 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -450,7 +450,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { input => rawParser.parse(input, createParser, UTF8String.fromString), parsedOptions.parseMode, schema, - parsedOptions.columnNameOfCorruptRecord) + parsedOptions.columnNameOfCorruptRecord, + parsedOptions.multiLine) iter.flatMap(parser.parse) } sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = jsonDataset.isStreaming) @@ -521,7 +522,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { input => Seq(rawParser.parse(input)), parsedOptions.parseMode, schema, - parsedOptions.columnNameOfCorruptRecord) + parsedOptions.columnNameOfCorruptRecord, + parsedOptions.multiLine) iter.flatMap(parser.parse) } sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = csvDataset.isStreaming) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala index 43591a9ff524a..90e81661bae7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String @@ -28,7 +29,8 @@ class FailureSafeParser[IN]( rawParser: IN => Seq[InternalRow], mode: ParseMode, schema: StructType, - columnNameOfCorruptRecord: String) { + columnNameOfCorruptRecord: String, + isMultiLine: Boolean) { private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord) private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord)) @@ -56,9 +58,15 @@ class FailureSafeParser[IN]( } } + private val skipParsing = !isMultiLine && mode == PermissiveMode && schema.isEmpty + def parse(input: IN): Iterator[InternalRow] = { try { - rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null)) + if (skipParsing) { + Iterator.single(InternalRow.empty) + } else { + rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null)) + } } catch { case e: BadRecordException => mode match { case PermissiveMode => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 79143cce4a380..e15af425b2649 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -203,19 +203,11 @@ class UnivocityParser( } } - private val doParse = if (requiredSchema.nonEmpty) { - (input: String) => convert(tokenizer.parseLine(input)) - } else { - // If `columnPruning` enabled and partition attributes scanned only, - // `schema` gets empty. - (_: String) => InternalRow.empty - } - /** * Parses a single CSV string and turns it into either one resulting row or no row (if the * the record is malformed). */ - def parse(input: String): InternalRow = doParse(input) + def parse(input: String): InternalRow = convert(tokenizer.parseLine(input)) private val getToken = if (options.columnPruning) { (tokens: Array[String], index: Int) => tokens(index) @@ -293,7 +285,8 @@ private[csv] object UnivocityParser { input => Seq(parser.convert(input)), parser.options.parseMode, schema, - parser.options.columnNameOfCorruptRecord) + parser.options.columnNameOfCorruptRecord, + parser.options.multiLine) convertStream(inputStream, shouldDropHeader, tokenizer, checkHeader) { tokens => safeParser.parse(tokens) }.flatten @@ -341,7 +334,8 @@ private[csv] object UnivocityParser { input => Seq(parser.parse(input)), parser.options.parseMode, schema, - parser.options.columnNameOfCorruptRecord) + parser.options.columnNameOfCorruptRecord, + parser.options.multiLine) filteredLines.flatMap(safeParser.parse) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index d6c588894d7f8..76f58371ae264 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -139,7 +139,8 @@ object TextInputJsonDataSource extends JsonDataSource { input => parser.parse(input, textParser, textToUTF8String), parser.options.parseMode, schema, - parser.options.columnNameOfCorruptRecord) + parser.options.columnNameOfCorruptRecord, + parser.options.multiLine) linesReader.flatMap(safeParser.parse) } @@ -223,7 +224,8 @@ object MultiLineJsonDataSource extends JsonDataSource { input => parser.parse[InputStream](input, streamParser, partitionedFileString), parser.options.parseMode, schema, - parser.options.columnNameOfCorruptRecord) + parser.options.columnNameOfCorruptRecord, + parser.options.multiLine) safeParser.parse( CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala index 1a3dacb8398e6..24f5f55d55485 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala @@ -119,8 +119,47 @@ object CSVBenchmarks { } } + def countBenchmark(rowsNum: Int): Unit = { + val colsNum = 10 + val benchmark = new Benchmark(s"Count a dataset with $colsNum columns", rowsNum) + + withTempPath { path => + val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) + val schema = StructType(fields) + + spark.range(rowsNum) + .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*) + .write + .csv(path.getAbsolutePath) + + val ds = spark.read.schema(schema).csv(path.getAbsolutePath) + + benchmark.addCase(s"Select $colsNum columns + count()", 3) { _ => + ds.select("*").filter((_: Row) => true).count() + } + benchmark.addCase(s"Select 1 column + count()", 3) { _ => + ds.select($"col1").filter((_: Row) => true).count() + } + benchmark.addCase(s"count()", 3) { _ => + ds.count() + } + + /* + Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz + + Count a dataset with 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + --------------------------------------------------------------------------------------------- + Select 10 columns + count() 12598 / 12740 0.8 1259.8 1.0X + Select 1 column + count() 7960 / 8175 1.3 796.0 1.6X + count() 2332 / 2386 4.3 233.2 5.4X + */ + benchmark.run() + } + } + def main(args: Array[String]): Unit = { quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3) multiColumnsBenchmark(rowsNum = 1000 * 1000) + countBenchmark(10 * 1000 * 1000) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 456b4535a0dcc..14840e59a1052 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1641,4 +1641,30 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } } } + + test("count() for malformed input") { + def countForMalformedCSV(expected: Long, input: Seq[String]): Unit = { + val schema = new StructType().add("a", IntegerType) + val strings = spark.createDataset(input) + val df = spark.read.schema(schema).option("header", false).csv(strings) + + assert(df.count() == expected) + } + def checkCount(expected: Long): Unit = { + val validRec = "1" + val inputs = Seq( + Seq("{-}", validRec), + Seq(validRec, "?"), + Seq("0xAC", validRec), + Seq(validRec, "0.314"), + Seq("\\\\\\", validRec) + ) + inputs.foreach { input => + countForMalformedCSV(expected, input) + } + } + + checkCount(2) + countForMalformedCSV(0, Seq("")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala index 85cf054e51f6b..a2b747eaab411 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.execution.datasources.json import java.io.File import org.apache.spark.SparkConf -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.types.{LongType, StringType, StructType} +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.types._ import org.apache.spark.util.{Benchmark, Utils} /** @@ -171,9 +172,49 @@ object JSONBenchmarks { } } + def countBenchmark(rowsNum: Int): Unit = { + val colsNum = 10 + val benchmark = new Benchmark(s"Count a dataset with $colsNum columns", rowsNum) + + withTempPath { path => + val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) + val schema = StructType(fields) + val columnNames = schema.fieldNames + + spark.range(rowsNum) + .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*) + .write + .json(path.getAbsolutePath) + + val ds = spark.read.schema(schema).json(path.getAbsolutePath) + + benchmark.addCase(s"Select $colsNum columns + count()", 3) { _ => + ds.select("*").filter((_: Row) => true).count() + } + benchmark.addCase(s"Select 1 column + count()", 3) { _ => + ds.select($"col1").filter((_: Row) => true).count() + } + benchmark.addCase(s"count()", 3) { _ => + ds.count() + } + + /* + Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz + + Count a dataset with 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + --------------------------------------------------------------------------------------------- + Select 10 columns + count() 9961 / 10006 1.0 996.1 1.0X + Select 1 column + count() 8355 / 8470 1.2 835.5 1.2X + count() 2104 / 2156 4.8 210.4 4.7X + */ + benchmark.run() + } + } + def main(args: Array[String]): Unit = { schemaInferring(100 * 1000 * 1000) perlineParsing(100 * 1000 * 1000) perlineParsingOfWideColumn(10 * 1000 * 1000) + countBenchmark(10 * 1000 * 1000) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 655f40ad549e6..3e4cc8f166279 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2223,7 +2223,6 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) } - test("SPARK-23723: specified encoding is not matched to actual encoding") { val fileName = "test-data/utf16LE.json" val schema = new StructType().add("firstName", StringType).add("lastName", StringType) @@ -2490,4 +2489,30 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(exception.getMessage.contains("encoding must not be included in the blacklist")) } } + + test("count() for malformed input") { + def countForMalformedJSON(expected: Long, input: Seq[String]): Unit = { + val schema = new StructType().add("a", StringType) + val strings = spark.createDataset(input) + val df = spark.read.schema(schema).json(strings) + + assert(df.count() == expected) + } + def checkCount(expected: Long): Unit = { + val validRec = """{"a":"b"}""" + val inputs = Seq( + Seq("{-}", validRec), + Seq(validRec, "?"), + Seq("}", validRec), + Seq(validRec, """{"a": [1, 2, 3]}"""), + Seq("""{"a": {"a": "b"}}""", validRec) + ) + inputs.foreach { input => + countForMalformedJSON(expected, input) + } + } + + checkCount(2) + countForMalformedJSON(0, Seq("")) + } } From 6b8fbbfb110601ffc3343b08113d13267baf27bf Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sun, 19 Aug 2018 09:18:47 +0900 Subject: [PATCH 1429/2461] [SPARK-25141][SQL][TEST] Modify tests for higher-order functions to check bind method. ## What changes were proposed in this pull request? We should also check `HigherOrderFunction.bind` method passes expected parameters. This pr modifies tests for higher-order functions to check `bind` method. ## How was this patch tested? Modified tests. Closes #22131 from ueshin/issues/SPARK-25141/bind_test. Authored-by: Takuya UESHIN Signed-off-by: Takuya UESHIN --- .../HigherOrderFunctionsSuite.scala | 49 ++++++++++++------- 1 file changed, 32 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index ea85c21f727c1..e13f4d98295be 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -60,24 +60,37 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper LambdaFunction(function, Seq(lv1, lv2, lv3)) } + private def validateBinding( + e: Expression, + argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match { + case f: LambdaFunction => + assert(f.arguments.size === argInfo.size) + f.arguments.zip(argInfo).foreach { + case (arg, (dataType, nullable)) => + assert(arg.dataType === dataType) + assert(arg.nullable === nullable) + } + f + } + def transform(expr: Expression, f: Expression => Expression): Expression = { - val at = expr.dataType.asInstanceOf[ArrayType] - ArrayTransform(expr, createLambda(at.elementType, at.containsNull, f)) + val ArrayType(et, cn) = expr.dataType + ArrayTransform(expr, createLambda(et, cn, f)).bind(validateBinding) } def transform(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val at = expr.dataType.asInstanceOf[ArrayType] - ArrayTransform(expr, createLambda(at.elementType, at.containsNull, IntegerType, false, f)) + val ArrayType(et, cn) = expr.dataType + ArrayTransform(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding) } def filter(expr: Expression, f: Expression => Expression): Expression = { - val at = expr.dataType.asInstanceOf[ArrayType] - ArrayFilter(expr, createLambda(at.elementType, at.containsNull, f)) + val ArrayType(et, cn) = expr.dataType + ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding) } def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val map = expr.dataType.asInstanceOf[MapType] - TransformKeys(expr, createLambda(map.keyType, false, map.valueType, map.valueContainsNull, f)) + val MapType(kt, vt, vcn) = expr.dataType + TransformKeys(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) } def aggregate( @@ -85,13 +98,14 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper zero: Expression, merge: (Expression, Expression) => Expression, finish: Expression => Expression): Expression = { - val at = expr.dataType.asInstanceOf[ArrayType] + val ArrayType(et, cn) = expr.dataType val zeroType = zero.dataType ArrayAggregate( expr, zero, - createLambda(zeroType, true, at.elementType, at.containsNull, merge), + createLambda(zeroType, true, et, cn, merge), createLambda(zeroType, true, finish)) + .bind(validateBinding) } def aggregate( @@ -102,8 +116,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper } def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val map = expr.dataType.asInstanceOf[MapType] - TransformValues(expr, createLambda(map.keyType, false, map.valueType, map.valueContainsNull, f)) + val MapType(kt, vt, vcn) = expr.dataType + TransformValues(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) } test("ArrayTransform") { @@ -149,8 +163,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper test("MapFilter") { def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val mt = expr.dataType.asInstanceOf[MapType] - MapFilter(expr, createLambda(mt.keyType, false, mt.valueType, mt.valueContainsNull, f)) + val MapType(kt, vt, vcn) = expr.dataType + MapFilter(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) } val mii0 = Literal.create(Map(1 -> 0, 2 -> 10, 3 -> -1), MapType(IntegerType, IntegerType, valueContainsNull = false)) @@ -230,8 +244,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper test("ArrayExists") { def exists(expr: Expression, f: Expression => Expression): Expression = { - val at = expr.dataType.asInstanceOf[ArrayType] - ArrayExists(expr, createLambda(at.elementType, at.containsNull, f)) + val ArrayType(et, cn) = expr.dataType + ArrayExists(expr, createLambda(et, cn, f)).bind(validateBinding) } val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) @@ -439,6 +453,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val MapType(kt, vt1, _) = left.dataType val MapType(_, vt2, _) = right.dataType MapZipWith(left, right, createLambda(kt, false, vt1, true, vt2, true, f)) + .bind(validateBinding) } val mii0 = Literal.create(Map(1 -> 10, 2 -> 20, 3 -> 30), @@ -556,7 +571,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper f: (Expression, Expression) => Expression): Expression = { val ArrayType(leftT, _) = left.dataType val ArrayType(rightT, _) = right.dataType - ZipWith(left, right, createLambda(leftT, true, rightT, true, f)) + ZipWith(left, right, createLambda(leftT, true, rightT, true, f)).bind(validateBinding) } val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) From 60af2501e1afc00192c779f2736a4e3de12428fa Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 20 Aug 2018 20:42:27 +0800 Subject: [PATCH 1430/2461] [SPARK-25160][SQL] Avro: remove sql configuration spark.sql.avro.outputTimestampType ## What changes were proposed in this pull request? In the PR for supporting logical timestamp types https://github.com/apache/spark/pull/21935, a SQL configuration spark.sql.avro.outputTimestampType is added, so that user can specify the output timestamp precision they want. With PR https://github.com/apache/spark/pull/21847, the output file can be written with user specified types. So there is no need to have such trivial configuration. Otherwise to make it consistent we need to add configuration for all the Catalyst types that can be converted into different Avro types. This PR also add a test case for user specified output schema with different timestamp types. ## How was this patch tested? Unit test Closes #22151 from gengliangwang/removeOutputTimestampType. Authored-by: Gengliang Wang Signed-off-by: hyukjinkwon --- .../apache/spark/sql/avro/AvroOptions.scala | 11 ------- .../spark/sql/avro/AvroSerializer.scala | 6 ++-- .../spark/sql/avro/SchemaConverters.scala | 22 ++++--------- .../spark/sql/avro/AvroLogicalTypeSuite.scala | 31 +++++++++++++------ .../apache/spark/sql/internal/SQLConf.scala | 18 ----------- 5 files changed, 30 insertions(+), 58 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala index 8c62d5db7ae24..67f56343b4524 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala @@ -22,7 +22,6 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.AvroOutputTimestampType /** * Options for Avro Reader and Writer stored in case insensitive manner. @@ -80,14 +79,4 @@ class AvroOptions( val compression: String = { parameters.get("compression").getOrElse(SQLConf.get.avroCompressionCodec) } - - /** - * Avro timestamp type used when Spark writes data to Avro files. - * Currently supported types are `TIMESTAMP_MICROS` and `TIMESTAMP_MILLIS`. - * TIMESTAMP_MICROS is a logical timestamp type in Avro, which stores number of microseconds - * from the Unix epoch. TIMESTAMP_MILLIS is also logical, but with millisecond precision, - * which means Spark has to truncate the microsecond portion of its timestamp value. - * The related configuration is set via SQLConf, and it is not exposed as an option. - */ - val outputTimestampType: AvroOutputTimestampType.Value = SQLConf.get.avroOutputTimestampType } diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala index f551c8360729d..e902b4c77eaad 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -201,13 +201,11 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: private def newStructConverter( catalystStruct: StructType, avroStruct: Schema): InternalRow => Record = { - if (avroStruct.getType != RECORD) { + if (avroStruct.getType != RECORD || avroStruct.getFields.size() != catalystStruct.length) { throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystStruct to " + s"Avro type $avroStruct.") } - val avroFields = avroStruct.getFields - assert(avroFields.size() == catalystStruct.length) - val fieldConverters = catalystStruct.zip(avroFields.asScala).map { + val fieldConverters = catalystStruct.zip(avroStruct.getFields.asScala).map { case (f1, f2) => newConverter(f1.dataType, resolveNullableType(f2.schema(), f1.nullable)) } val numFields = catalystStruct.length diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala index 7b33cf6e6e055..3a15e8d087fa4 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -20,14 +20,11 @@ package org.apache.spark.sql.avro import scala.collection.JavaConverters._ import scala.util.Random -import com.fasterxml.jackson.annotation.ObjectIdGenerators.UUIDGenerator -import org.apache.avro.{LogicalType, LogicalTypes, Schema, SchemaBuilder} +import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder} import org.apache.avro.LogicalTypes.{Date, Decimal, TimestampMicros, TimestampMillis} import org.apache.avro.Schema.Type._ -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator -import org.apache.spark.sql.internal.SQLConf.AvroOutputTimestampType import org.apache.spark.sql.types._ import org.apache.spark.sql.types.Decimal.{maxPrecisionForBytes, minBytesForPrecision} @@ -126,8 +123,7 @@ object SchemaConverters { catalystType: DataType, nullable: Boolean = false, recordName: String = "topLevelRecord", - prevNameSpace: String = "", - outputTimestampType: AvroOutputTimestampType.Value = AvroOutputTimestampType.TIMESTAMP_MICROS) + prevNameSpace: String = "") : Schema = { val builder = SchemaBuilder.builder() @@ -138,13 +134,7 @@ object SchemaConverters { case DateType => LogicalTypes.date().addToSchema(builder.intType()) case TimestampType => - val timestampType = outputTimestampType match { - case AvroOutputTimestampType.TIMESTAMP_MILLIS => LogicalTypes.timestampMillis() - case AvroOutputTimestampType.TIMESTAMP_MICROS => LogicalTypes.timestampMicros() - case other => - throw new IncompatibleSchemaException(s"Unexpected output timestamp type $other.") - } - timestampType.addToSchema(builder.longType()) + LogicalTypes.timestampMicros().addToSchema(builder.longType()) case FloatType => builder.floatType() case DoubleType => builder.doubleType() @@ -162,10 +152,10 @@ object SchemaConverters { case BinaryType => builder.bytesType() case ArrayType(et, containsNull) => builder.array() - .items(toAvroType(et, containsNull, recordName, prevNameSpace, outputTimestampType)) + .items(toAvroType(et, containsNull, recordName, prevNameSpace)) case MapType(StringType, vt, valueContainsNull) => builder.map() - .values(toAvroType(vt, valueContainsNull, recordName, prevNameSpace, outputTimestampType)) + .values(toAvroType(vt, valueContainsNull, recordName, prevNameSpace)) case st: StructType => val nameSpace = prevNameSpace match { case "" => recordName @@ -175,7 +165,7 @@ object SchemaConverters { val fieldsAssembler = builder.record(recordName).namespace(nameSpace).fields() st.foreach { f => val fieldAvroType = - toAvroType(f.dataType, f.nullable, f.name, nameSpace, outputTimestampType) + toAvroType(f.dataType, f.nullable, f.name, nameSpace) fieldsAssembler.name(f.name).`type`(fieldAvroType).noDefault() } fieldsAssembler.endRecord() diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala index ca7eef2a81cf8..79ba2871c2264 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala @@ -148,7 +148,7 @@ class AvroLogicalTypeSuite extends QueryTest with SharedSQLContext with SQLTestU } } - test("Logical type: specify different output timestamp types") { + test("Logical type: user specified output schema with different timestamp types") { withTempDir { dir => val timestampAvro = timestampFile(dir.getAbsolutePath) val df = @@ -156,13 +156,26 @@ class AvroLogicalTypeSuite extends QueryTest with SharedSQLContext with SQLTestU val expected = timestampInputData.map(t => Row(new Timestamp(t._1), new Timestamp(t._2))) - Seq("TIMESTAMP_MILLIS", "TIMESTAMP_MICROS").foreach { timestampType => - withSQLConf(SQLConf.AVRO_OUTPUT_TIMESTAMP_TYPE.key -> timestampType) { - withTempPath { path => - df.write.format("avro").save(path.toString) - checkAnswer(spark.read.format("avro").load(path.toString), expected) - } - } + val userSpecifiedTimestampSchema = s""" + { + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [ + {"name": "timestamp_millis", + "type": [{"type": "long","logicalType": "timestamp-micros"}, "null"]}, + {"name": "timestamp_micros", + "type": [{"type": "long","logicalType": "timestamp-millis"}, "null"]} + ] + } + """ + + withTempPath { path => + df.write + .format("avro") + .option("avroSchema", userSpecifiedTimestampSchema) + .save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) } } } @@ -179,7 +192,7 @@ class AvroLogicalTypeSuite extends QueryTest with SharedSQLContext with SQLTestU } } - test("Logical type: user specified schema") { + test("Logical type: user specified read schema") { withTempDir { dir => val timestampAvro = timestampFile(dir.getAbsolutePath) val expected = timestampInputData diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index dbb5bb43b4f1f..bffdddcf3fdb0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1444,21 +1444,6 @@ object SQLConf { .intConf .createWithDefault(20) - object AvroOutputTimestampType extends Enumeration { - val TIMESTAMP_MICROS, TIMESTAMP_MILLIS = Value - } - - val AVRO_OUTPUT_TIMESTAMP_TYPE = buildConf("spark.sql.avro.outputTimestampType") - .doc("Sets which Avro timestamp type to use when Spark writes data to Avro files. " + - "TIMESTAMP_MICROS is a logical timestamp type in Avro, which stores number of " + - "microseconds from the Unix epoch. TIMESTAMP_MILLIS is also logical, but with " + - "millisecond precision, which means Spark has to truncate the microsecond portion of its " + - "timestamp value.") - .stringConf - .transform(_.toUpperCase(Locale.ROOT)) - .checkValues(AvroOutputTimestampType.values.map(_.toString)) - .createWithDefault(AvroOutputTimestampType.TIMESTAMP_MICROS.toString) - val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec") .doc("Compression codec used in writing of AVRO files. Supported codecs: " + "uncompressed, deflate, snappy, bzip2 and xz. Default codec is snappy.") @@ -1882,9 +1867,6 @@ class SQLConf extends Serializable with Logging { def replEagerEvalTruncate: Int = getConf(SQLConf.REPL_EAGER_EVAL_TRUNCATE) - def avroOutputTimestampType: AvroOutputTimestampType.Value = - AvroOutputTimestampType.withName(getConf(SQLConf.AVRO_OUTPUT_TIMESTAMP_TYPE)) - def avroCompressionCodec: String = getConf(SQLConf.AVRO_COMPRESSION_CODEC) def avroDeflateLevel: Int = getConf(SQLConf.AVRO_DEFLATE_LEVEL) From 219ed7b487c2dfb5007247f77ebf1b3cc73cecb5 Mon Sep 17 00:00:00 2001 From: Zhang Le Date: Mon, 20 Aug 2018 14:59:03 -0500 Subject: [PATCH 1431/2461] [DOCS] Fixed NDCG formula issues When j is 0, log(j+1) will be 0, and this leads to division by 0 issue. ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22090 from yueguoguo/patch-1. Authored-by: Zhang Le Signed-off-by: Sean Owen --- docs/mllib-evaluation-metrics.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index d9dbbab4840a3..c65ecdcb67ee4 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -462,13 +462,13 @@ $$rel_D(r) = \begin{cases}1 & \text{if $r \in D$}, \\ 0 & \text{otherwise}.\end{ Normalized Discounted Cumulative Gain $NDCG(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{IDCG(D_i, k)}\sum_{j=0}^{n-1} - \frac{rel_{D_i}(R_i(j))}{\text{ln}(j+1)}} \\ + \frac{rel_{D_i}(R_i(j))}{\text{ln}(j+2)}} \\ \text{Where} \\ \hspace{5 mm} n = \text{min}\left(\text{max}\left(|R_i|,|D_i|\right),k\right) \\ - \hspace{5 mm} IDCG(D, k) = \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} \frac{1}{\text{ln}(j+1)}$ + \hspace{5 mm} IDCG(D, k) = \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} \frac{1}{\text{ln}(j+2)}$ - NDCG at k is a + NDCG at k is a measure of how many of the first k recommended documents are in the set of true relevant documents averaged across all users. In contrast to precision at k, this metric takes into account the order of the recommendations (documents are assumed to be in order of decreasing relevance). From 883f3aff67aac25c9d9a3bdf8d47fadefbf9645b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 21 Aug 2018 09:08:36 +0800 Subject: [PATCH 1432/2461] [SPARK-25144][SQL][TEST] Free aggregate map when task ends ## What changes were proposed in this pull request? [SPARK-25144](https://issues.apache.org/jira/browse/SPARK-25144) reports memory leaks on Apache Spark 2.0.2 ~ 2.3.2-RC5. The bug is already fixed via #21738 as a part of SPARK-21743. This PR only adds a test case to prevent any future regression. ```scala scala> case class Foo(bar: Option[String]) scala> val ds = List(Foo(Some("bar"))).toDS scala> val result = ds.flatMap(_.bar).distinct scala> result.rdd.isEmpty 18/08/19 23:01:54 WARN Executor: Managed memory leak detected; size = 8650752 bytes, TID = 125 res0: Boolean = false ``` ## How was this patch tested? Pass the Jenkins with a new added test case. Closes #22155 from dongjoon-hyun/SPARK-25144-2. Authored-by: Dongjoon Hyun Signed-off-by: hyukjinkwon --- .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 84efd2b7a1dc6..01dc28d70184e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2852,4 +2852,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { spark.sql(s"select * from spark_25084 distribute by ($distributeExprs)").count === count) } } + + test("SPARK-25144 'distinct' causes memory leak") { + val ds = List(Foo(Some("bar"))).toDS + val result = ds.flatMap(_.bar).distinct + result.rdd.isEmpty + } } + +case class Foo(bar: Option[String]) From b461acb2d90b734393c27fe7b359e2f2d297b8d4 Mon Sep 17 00:00:00 2001 From: Koert Kuipers Date: Tue, 21 Aug 2018 10:23:55 +0800 Subject: [PATCH 1433/2461] [SPARK-25134][SQL] Csv column pruning with checking of headers throws incorrect error ## What changes were proposed in this pull request? When column pruning is turned on the checking of headers in the csv should only be for the fields in the requiredSchema, not the dataSchema, because column pruning means only requiredSchema is read. ## How was this patch tested? Added 2 unit tests where column pruning is turned on/off and csv headers are checked againt schema Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22123 from koertkuipers/feat-csv-column-pruning-and-check-header. Authored-by: Koert Kuipers Signed-off-by: hyukjinkwon --- .../apache/spark/sql/DataFrameReader.scala | 7 ++-- .../datasources/csv/CSVDataSource.scala | 40 ++++++------------- .../datasources/csv/CSVFileFormat.scala | 4 +- .../execution/datasources/csv/CSVSuite.scala | 33 +++++++++++++++ 4 files changed, 53 insertions(+), 31 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 1b3a9fc91d198..5b3b5c2451aab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -506,10 +506,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine => - CSVDataSource.checkHeader( - firstLine, - new CsvParser(parsedOptions.asParserSettings), + val parser = new CsvParser(parsedOptions.asParserSettings) + val columnNames = parser.parseLine(firstLine) + CSVDataSource.checkHeaderColumnNames( actualSchema, + columnNames, csvDataset.getClass.getCanonicalName, parsedOptions.enforceSchema, sparkSession.sessionState.conf.caseSensitiveAnalysis) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index b7b46c7c86a29..2b86054c0ffcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -54,7 +54,8 @@ abstract class CSVDataSource extends Serializable { requiredSchema: StructType, // Actual schema of data in the csv file dataSchema: StructType, - caseSensitive: Boolean): Iterator[InternalRow] + caseSensitive: Boolean, + columnPruning: Boolean): Iterator[InternalRow] /** * Infers the schema from `inputPaths` files. @@ -181,25 +182,6 @@ object CSVDataSource extends Logging { } } } - - /** - * Checks that CSV header contains the same column names as fields names in the given schema - * by taking into account case sensitivity. - */ - def checkHeader( - header: String, - parser: CsvParser, - schema: StructType, - fileName: String, - enforceSchema: Boolean, - caseSensitive: Boolean): Unit = { - checkHeaderColumnNames( - schema, - parser.parseLine(header), - fileName, - enforceSchema, - caseSensitive) - } } object TextInputCSVDataSource extends CSVDataSource { @@ -211,7 +193,8 @@ object TextInputCSVDataSource extends CSVDataSource { parser: UnivocityParser, requiredSchema: StructType, dataSchema: StructType, - caseSensitive: Boolean): Iterator[InternalRow] = { + caseSensitive: Boolean, + columnPruning: Boolean): Iterator[InternalRow] = { val lines = { val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close())) @@ -227,10 +210,11 @@ object TextInputCSVDataSource extends CSVDataSource { // Note: if there are only comments in the first block, the header would probably // be not extracted. CSVUtils.extractHeader(lines, parser.options).foreach { header => - CSVDataSource.checkHeader( - header, - parser.tokenizer, - dataSchema, + val schema = if (columnPruning) requiredSchema else dataSchema + val columnNames = parser.tokenizer.parseLine(header) + CSVDataSource.checkHeaderColumnNames( + schema, + columnNames, file.filePath, parser.options.enforceSchema, caseSensitive) @@ -308,10 +292,12 @@ object MultiLineCSVDataSource extends CSVDataSource { parser: UnivocityParser, requiredSchema: StructType, dataSchema: StructType, - caseSensitive: Boolean): Iterator[InternalRow] = { + caseSensitive: Boolean, + columnPruning: Boolean): Iterator[InternalRow] = { def checkHeader(header: Array[String]): Unit = { + val schema = if (columnPruning) requiredSchema else dataSchema CSVDataSource.checkHeaderColumnNames( - dataSchema, + schema, header, file.filePath, parser.options.enforceSchema, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index d59b9820bdeef..9aad0bd55e736 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -131,6 +131,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { ) } val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val columnPruning = sparkSession.sessionState.conf.csvColumnPruning (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value @@ -144,7 +145,8 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { parser, requiredSchema, dataSchema, - caseSensitive) + caseSensitive, + columnPruning) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 14840e59a1052..5a1d6679ebbdb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1603,6 +1603,39 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te .exists(msg => msg.getRenderedMessage.contains("CSV header does not conform to the schema"))) } + test("SPARK-25134: check header on parsing of dataset with projection and column pruning") { + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "true") { + Seq(false, true).foreach { multiLine => + withTempPath { path => + val dir = path.getAbsolutePath + Seq(("a", "b")).toDF("columnA", "columnB").write + .format("csv") + .option("header", true) + .save(dir) + + // schema with one column + checkAnswer(spark.read + .format("csv") + .option("header", true) + .option("enforceSchema", false) + .option("multiLine", multiLine) + .load(dir) + .select("columnA"), + Row("a")) + + // empty schema + assert(spark.read + .format("csv") + .option("header", true) + .option("enforceSchema", false) + .option("multiLine", multiLine) + .load(dir) + .count() === 1L) + } + } + } + } + test("SPARK-24645 skip parsing when columnPruning enabled and partitions scanned only") { withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "true") { withTempPath { path => From f984ec75ed6162ee6f5881716a8311c883aca22a Mon Sep 17 00:00:00 2001 From: seancxmao Date: Tue, 21 Aug 2018 10:34:23 +0800 Subject: [PATCH 1434/2461] [SPARK-25132][SQL] Case-insensitive field resolution when reading from Parquet ## What changes were proposed in this pull request? Spark SQL returns NULL for a column whose Hive metastore schema and Parquet schema are in different letter cases, regardless of spark.sql.caseSensitive set to true or false. This PR aims to add case-insensitive field resolution for ParquetFileFormat. * Do case-insensitive resolution only if Spark is in case-insensitive mode. * Field resolution should fail if there is ambiguity, i.e. more than one field is matched. ## How was this patch tested? Unit tests added. Closes #22148 from seancxmao/SPARK-25132-Parquet. Authored-by: seancxmao Signed-off-by: hyukjinkwon --- .../parquet/ParquetFileFormat.scala | 3 + .../parquet/ParquetReadSupport.scala | 84 +++++++++++++------ .../spark/sql/FileBasedDataSourceSuite.scala | 43 ++++++++++ .../parquet/ParquetSchemaSuite.scala | 61 ++++++++++++-- 4 files changed, 161 insertions(+), 30 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index b2409f3470e73..d7eb14356b8b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -310,6 +310,9 @@ class ParquetFileFormat hadoopConf.set( SQLConf.SESSION_LOCAL_TIMEZONE.key, sparkSession.sessionState.conf.sessionLocalTimeZone) + hadoopConf.setBoolean( + SQLConf.CASE_SENSITIVE.key, + sparkSession.sessionState.conf.caseSensitiveAnalysis) ParquetWriteSupport.setSchema(requiredSchema, hadoopConf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index 40ce5d5e0564e..3319e73f2b313 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet -import java.util.{Map => JMap, TimeZone} +import java.util.{Locale, Map => JMap, TimeZone} import scala.collection.JavaConverters._ @@ -30,6 +30,7 @@ import org.apache.parquet.schema.Type.Repetition import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -71,8 +72,10 @@ private[parquet] class ParquetReadSupport(val convertTz: Option[TimeZone]) StructType.fromString(schemaString) } - val parquetRequestedSchema = - ParquetReadSupport.clipParquetSchema(context.getFileSchema, catalystRequestedSchema) + val caseSensitive = context.getConfiguration.getBoolean(SQLConf.CASE_SENSITIVE.key, + SQLConf.CASE_SENSITIVE.defaultValue.get) + val parquetRequestedSchema = ParquetReadSupport.clipParquetSchema( + context.getFileSchema, catalystRequestedSchema, caseSensitive) new ReadContext(parquetRequestedSchema, Map.empty[String, String].asJava) } @@ -117,8 +120,12 @@ private[parquet] object ParquetReadSupport { * Tailors `parquetSchema` according to `catalystSchema` by removing column paths don't exist * in `catalystSchema`, and adding those only exist in `catalystSchema`. */ - def clipParquetSchema(parquetSchema: MessageType, catalystSchema: StructType): MessageType = { - val clippedParquetFields = clipParquetGroupFields(parquetSchema.asGroupType(), catalystSchema) + def clipParquetSchema( + parquetSchema: MessageType, + catalystSchema: StructType, + caseSensitive: Boolean = true): MessageType = { + val clippedParquetFields = clipParquetGroupFields( + parquetSchema.asGroupType(), catalystSchema, caseSensitive) if (clippedParquetFields.isEmpty) { ParquetSchemaConverter.EMPTY_MESSAGE } else { @@ -129,20 +136,21 @@ private[parquet] object ParquetReadSupport { } } - private def clipParquetType(parquetType: Type, catalystType: DataType): Type = { + private def clipParquetType( + parquetType: Type, catalystType: DataType, caseSensitive: Boolean): Type = { catalystType match { case t: ArrayType if !isPrimitiveCatalystType(t.elementType) => // Only clips array types with nested type as element type. - clipParquetListType(parquetType.asGroupType(), t.elementType) + clipParquetListType(parquetType.asGroupType(), t.elementType, caseSensitive) case t: MapType if !isPrimitiveCatalystType(t.keyType) || !isPrimitiveCatalystType(t.valueType) => // Only clips map types with nested key type or value type - clipParquetMapType(parquetType.asGroupType(), t.keyType, t.valueType) + clipParquetMapType(parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive) case t: StructType => - clipParquetGroup(parquetType.asGroupType(), t) + clipParquetGroup(parquetType.asGroupType(), t, caseSensitive) case _ => // UDTs and primitive types are not clipped. For UDTs, a clipped version might not be able @@ -168,14 +176,15 @@ private[parquet] object ParquetReadSupport { * of the [[ArrayType]] should also be a nested type, namely an [[ArrayType]], a [[MapType]], or a * [[StructType]]. */ - private def clipParquetListType(parquetList: GroupType, elementType: DataType): Type = { + private def clipParquetListType( + parquetList: GroupType, elementType: DataType, caseSensitive: Boolean): Type = { // Precondition of this method, should only be called for lists with nested element types. assert(!isPrimitiveCatalystType(elementType)) // Unannotated repeated group should be interpreted as required list of required element, so // list element type is just the group itself. Clip it. if (parquetList.getOriginalType == null && parquetList.isRepetition(Repetition.REPEATED)) { - clipParquetType(parquetList, elementType) + clipParquetType(parquetList, elementType, caseSensitive) } else { assert( parquetList.getOriginalType == OriginalType.LIST, @@ -207,7 +216,7 @@ private[parquet] object ParquetReadSupport { Types .buildGroup(parquetList.getRepetition) .as(OriginalType.LIST) - .addField(clipParquetType(repeatedGroup, elementType)) + .addField(clipParquetType(repeatedGroup, elementType, caseSensitive)) .named(parquetList.getName) } else { // Otherwise, the repeated field's type is the element type with the repeated field's @@ -218,7 +227,7 @@ private[parquet] object ParquetReadSupport { .addField( Types .repeatedGroup() - .addField(clipParquetType(repeatedGroup.getType(0), elementType)) + .addField(clipParquetType(repeatedGroup.getType(0), elementType, caseSensitive)) .named(repeatedGroup.getName)) .named(parquetList.getName) } @@ -231,7 +240,10 @@ private[parquet] object ParquetReadSupport { * a [[StructType]]. */ private def clipParquetMapType( - parquetMap: GroupType, keyType: DataType, valueType: DataType): GroupType = { + parquetMap: GroupType, + keyType: DataType, + valueType: DataType, + caseSensitive: Boolean): GroupType = { // Precondition of this method, only handles maps with nested key types or value types. assert(!isPrimitiveCatalystType(keyType) || !isPrimitiveCatalystType(valueType)) @@ -243,8 +255,8 @@ private[parquet] object ParquetReadSupport { Types .repeatedGroup() .as(repeatedGroup.getOriginalType) - .addField(clipParquetType(parquetKeyType, keyType)) - .addField(clipParquetType(parquetValueType, valueType)) + .addField(clipParquetType(parquetKeyType, keyType, caseSensitive)) + .addField(clipParquetType(parquetValueType, valueType, caseSensitive)) .named(repeatedGroup.getName) Types @@ -262,8 +274,9 @@ private[parquet] object ParquetReadSupport { * [[MessageType]]. Because it's legal to construct an empty requested schema for column * pruning. */ - private def clipParquetGroup(parquetRecord: GroupType, structType: StructType): GroupType = { - val clippedParquetFields = clipParquetGroupFields(parquetRecord, structType) + private def clipParquetGroup( + parquetRecord: GroupType, structType: StructType, caseSensitive: Boolean): GroupType = { + val clippedParquetFields = clipParquetGroupFields(parquetRecord, structType, caseSensitive) Types .buildGroup(parquetRecord.getRepetition) .as(parquetRecord.getOriginalType) @@ -277,14 +290,35 @@ private[parquet] object ParquetReadSupport { * @return A list of clipped [[GroupType]] fields, which can be empty. */ private def clipParquetGroupFields( - parquetRecord: GroupType, structType: StructType): Seq[Type] = { - val parquetFieldMap = parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap + parquetRecord: GroupType, structType: StructType, caseSensitive: Boolean): Seq[Type] = { val toParquet = new SparkToParquetSchemaConverter(writeLegacyParquetFormat = false) - structType.map { f => - parquetFieldMap - .get(f.name) - .map(clipParquetType(_, f.dataType)) - .getOrElse(toParquet.convertField(f)) + if (caseSensitive) { + val caseSensitiveParquetFieldMap = + parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap + structType.map { f => + caseSensitiveParquetFieldMap + .get(f.name) + .map(clipParquetType(_, f.dataType, caseSensitive)) + .getOrElse(toParquet.convertField(f)) + } + } else { + // Do case-insensitive resolution only if in case-insensitive mode + val caseInsensitiveParquetFieldMap = + parquetRecord.getFields.asScala.groupBy(_.getName.toLowerCase(Locale.ROOT)) + structType.map { f => + caseInsensitiveParquetFieldMap + .get(f.name.toLowerCase(Locale.ROOT)) + .map { parquetTypes => + if (parquetTypes.size > 1) { + // Need to fail if there is ambiguity, i.e. more than one field is matched + val parquetTypesString = parquetTypes.map(_.getName).mkString("[", ", ", "]") + throw new RuntimeException(s"""Found duplicate field(s) "${f.name}": """ + + s"$parquetTypesString in case-insensitive mode") + } else { + clipParquetType(parquetTypes.head, f.dataType, caseSensitive) + } + }.getOrElse(toParquet.convertField(f)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 9f9af89570789..4aa6afd69620b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -430,6 +430,49 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } } } + + test(s"SPARK-25132: case-insensitive field resolution when reading from Parquet") { + withTempDir { dir => + val format = "parquet" + val tableDir = dir.getCanonicalPath + s"/$format" + val tableName = s"spark_25132_${format}" + withTable(tableName) { + val end = 5 + val data = spark.range(end).selectExpr("id as A", "id * 2 as b", "id * 3 as B") + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + data.write.format(format).mode("overwrite").save(tableDir) + } + sql(s"CREATE TABLE $tableName (a LONG, b LONG) USING $format LOCATION '$tableDir'") + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkAnswer(sql(s"select a from $tableName"), data.select("A")) + checkAnswer(sql(s"select A from $tableName"), data.select("A")) + + // RuntimeException is triggered at executor side, which is then wrapped as + // SparkException at driver side + val e1 = intercept[SparkException] { + sql(s"select b from $tableName").collect() + } + assert( + e1.getCause.isInstanceOf[RuntimeException] && + e1.getCause.getMessage.contains( + """Found duplicate field(s) "b": [b, B] in case-insensitive mode""")) + val e2 = intercept[SparkException] { + sql(s"select B from $tableName").collect() + } + assert( + e2.getCause.isInstanceOf[RuntimeException] && + e2.getCause.getMessage.contains( + """Found duplicate field(s) "b": [b, B] in case-insensitive mode""")) + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + checkAnswer(sql(s"select a from $tableName"), (0 until end).map(_ => Row(null))) + checkAnswer(sql(s"select b from $tableName"), data.select("b")) + } + } + } + } } object TestingUDT { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 368e52cfbda9c..7eefedb8ff5bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -1014,19 +1014,21 @@ class ParquetSchemaSuite extends ParquetSchemaTest { testName: String, parquetSchema: String, catalystSchema: StructType, - expectedSchema: String): Unit = { + expectedSchema: String, + caseSensitive: Boolean = true): Unit = { testSchemaClipping(testName, parquetSchema, catalystSchema, - MessageTypeParser.parseMessageType(expectedSchema)) + MessageTypeParser.parseMessageType(expectedSchema), caseSensitive) } private def testSchemaClipping( testName: String, parquetSchema: String, catalystSchema: StructType, - expectedSchema: MessageType): Unit = { + expectedSchema: MessageType, + caseSensitive: Boolean): Unit = { test(s"Clipping - $testName") { val actual = ParquetReadSupport.clipParquetSchema( - MessageTypeParser.parseMessageType(parquetSchema), catalystSchema) + MessageTypeParser.parseMessageType(parquetSchema), catalystSchema, caseSensitive) try { expectedSchema.checkContains(actual) @@ -1387,7 +1389,8 @@ class ParquetSchemaSuite extends ParquetSchemaTest { catalystSchema = new StructType(), - expectedSchema = ParquetSchemaConverter.EMPTY_MESSAGE) + expectedSchema = ParquetSchemaConverter.EMPTY_MESSAGE, + caseSensitive = true) testSchemaClipping( "disjoint field sets", @@ -1544,4 +1547,52 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin) + + testSchemaClipping( + "case-insensitive resolution: no ambiguity", + parquetSchema = + """message root { + | required group A { + | optional int32 B; + | } + | optional int32 c; + |} + """.stripMargin, + catalystSchema = { + val nestedType = new StructType().add("b", IntegerType, nullable = true) + new StructType() + .add("a", nestedType, nullable = true) + .add("c", IntegerType, nullable = true) + }, + expectedSchema = + """message root { + | required group A { + | optional int32 B; + | } + | optional int32 c; + |} + """.stripMargin, + caseSensitive = false) + + test("Clipping - case-insensitive resolution: more than one field is matched") { + val parquetSchema = + """message root { + | required group A { + | optional int32 B; + | } + | optional int32 c; + | optional int32 a; + |} + """.stripMargin + val catalystSchema = { + val nestedType = new StructType().add("b", IntegerType, nullable = true) + new StructType() + .add("a", nestedType, nullable = true) + .add("c", IntegerType, nullable = true) + } + assertThrows[RuntimeException] { + ParquetReadSupport.clipParquetSchema( + MessageTypeParser.parseMessageType(parquetSchema), catalystSchema, caseSensitive = false) + } + } } From 4fb96e5105cec4a3eb19a2b7997600b086bac32f Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Mon, 20 Aug 2018 23:13:31 -0700 Subject: [PATCH 1435/2461] [SPARK-25114][CORE] Fix RecordBinaryComparator when subtraction between two words is divisible by Integer.MAX_VALUE. ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/22079#discussion_r209705612 It is possible for two objects to be unequal and yet we consider them as equal with this code, if the long values are separated by Int.MaxValue. This PR fixes the issue. ## How was this patch tested? Add new test cases in `RecordBinaryComparatorSuite`. Closes #22101 from jiangxb1987/fix-rbc. Authored-by: Xingbo Jiang Signed-off-by: Xiao Li --- .../sql/execution/RecordBinaryComparator.java | 26 ++++---- .../sort/RecordBinaryComparatorSuite.java | 66 +++++++++++++++++++ 2 files changed, 81 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java index bb77b5bf6de2a..40c2cc806e87a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java @@ -22,12 +22,10 @@ public final class RecordBinaryComparator extends RecordComparator { - // TODO(jiangxb) Add test suite for this. @Override public int compare( Object leftObj, long leftOff, int leftLen, Object rightObj, long rightOff, int rightLen) { int i = 0; - int res = 0; // If the arrays have different length, the longer one is larger. if (leftLen != rightLen) { @@ -40,27 +38,33 @@ public int compare( // check if stars align and we can get both offsets to be aligned if ((leftOff % 8) == (rightOff % 8)) { while ((leftOff + i) % 8 != 0 && i < leftLen) { - res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - - (Platform.getByte(rightObj, rightOff + i) & 0xff); - if (res != 0) return res; + final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff; + final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff; + if (v1 != v2) { + return v1 > v2 ? 1 : -1; + } i += 1; } } // for architectures that support unaligned accesses, chew it up 8 bytes at a time if (Platform.unaligned() || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 8 == 0))) { while (i <= leftLen - 8) { - res = (int) ((Platform.getLong(leftObj, leftOff + i) - - Platform.getLong(rightObj, rightOff + i)) % Integer.MAX_VALUE); - if (res != 0) return res; + final long v1 = Platform.getLong(leftObj, leftOff + i); + final long v2 = Platform.getLong(rightObj, rightOff + i); + if (v1 != v2) { + return v1 > v2 ? 1 : -1; + } i += 8; } } // this will finish off the unaligned comparisons, or do the entire aligned comparison // whichever is needed. while (i < leftLen) { - res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - - (Platform.getByte(rightObj, rightOff + i) & 0xff); - if (res != 0) return res; + final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff; + final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff; + if (v1 != v2) { + return v1 > v2 ? 1 : -1; + } i += 1; } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java index a19ddbdbadba2..97f3dc588ecc5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java @@ -253,4 +253,70 @@ public void testBinaryComparatorForNullColumns() throws Exception { assert(compare(0, 0) == 0); assert(compare(0, 1) > 0); } + + @Test + public void testBinaryComparatorWhenSubtractionIsDivisibleByMaxIntValue() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setLong(0, 11); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setLong(0, 11L + Integer.MAX_VALUE); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorWhenSubtractionCanOverflowLongValue() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setLong(0, Long.MIN_VALUE); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setLong(0, 1); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorWhenOnlyTheLastColumnDiffers() throws Exception { + int numFields = 4; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setInt(0, 11); + row1.setDouble(1, 3.14); + row1.setInt(2, -1); + row1.setLong(3, 0); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setInt(0, 11); + row2.setDouble(1, 3.14); + row2.setInt(2, -1); + row2.setLong(3, 1); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 1) < 0); + } } From b8788b3e79d0d508e3a910fefd7e9cff4c6d6245 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 21 Aug 2018 08:18:21 -0500 Subject: [PATCH 1436/2461] [BUILD] Close stale PRs Closes #16411 Closes #21870 Closes #21794 Closes #21610 Closes #21961 Closes #21940 Closes #21870 Closes #22118 Closes #21624 Closes #19528 Closes #18424 Closes #22159 from srowen/Stale. Authored-by: Sean Owen Signed-off-by: Sean Owen From 5059255d91fc7a9810e013eba39e12d30291dd08 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 21 Aug 2018 08:25:02 -0700 Subject: [PATCH 1437/2461] [SPARK-25161][CORE] Fix several bugs in failure handling of barrier execution mode ## What changes were proposed in this pull request? Fix several bugs in failure handling of barrier execution mode: * Mark TaskSet for a barrier stage as zombie when a task attempt fails; * Multiple barrier task failures from a single barrier stage should not trigger multiple stage retries; * Barrier task failure from a previous failed stage attempt should not trigger stage retry; * Fail the job when a task from a barrier ResultStage failed; * RDD.isBarrier() should not rely on `ShuffleDependency`s. ## How was this patch tested? Added corresponding test cases in `DAGSchedulerSuite` and `TaskSchedulerImplSuite`. Closes #22158 from jiangxb1987/failure. Authored-by: Xingbo Jiang Signed-off-by: Xiangrui Meng --- .../main/scala/org/apache/spark/rdd/RDD.scala | 3 +- .../apache/spark/scheduler/DAGScheduler.scala | 125 ++++++++++-------- .../spark/scheduler/TaskSetManager.scala | 4 + .../spark/scheduler/DAGSchedulerSuite.scala | 106 +++++++++++++++ .../scheduler/TaskSchedulerImplSuite.scala | 18 +++ 5 files changed, 200 insertions(+), 56 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index cbc1143126d8e..374b846d2ea57 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1863,7 +1863,8 @@ abstract class RDD[T: ClassTag]( // From performance concern, cache the value to avoid repeatedly compute `isBarrier()` on a long // RDD chain. - @transient protected lazy val isBarrier_ : Boolean = dependencies.exists(_.rdd.isBarrier()) + @transient protected lazy val isBarrier_ : Boolean = + dependencies.filter(!_.isInstanceOf[ShuffleDependency[_, _, _]]).exists(_.rdd.isBarrier()) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 2b0ca13485eb5..6787250ddc3f4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1478,9 +1478,11 @@ private[spark] class DAGScheduler( mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId) case failedResultStage: ResultStage => - // Mark all the partitions of the result stage to be not finished, to ensure retry - // all the tasks on resubmitted stage attempt. - failedResultStage.activeJob.map(_.resetAllPartitions()) + // Abort the failed result stage since we may have committed output for some + // partitions. + val reason = "Could not recover from a failed barrier ResultStage. Most recent " + + s"failure reason: $failureMessage" + abortStage(failedResultStage, reason, None) } } @@ -1553,62 +1555,75 @@ private[spark] class DAGScheduler( // Always fail the current stage and retry all the tasks when a barrier task fail. val failedStage = stageIdToStage(task.stageId) - logInfo(s"Marking $failedStage (${failedStage.name}) as failed due to a barrier task " + - "failed.") - val message = s"Stage failed because barrier task $task finished unsuccessfully.\n" + - failure.toErrorString - try { - // killAllTaskAttempts will fail if a SchedulerBackend does not implement killTask. - val reason = s"Task $task from barrier stage $failedStage (${failedStage.name}) failed." - taskScheduler.killAllTaskAttempts(stageId, interruptThread = false, reason) - } catch { - case e: UnsupportedOperationException => - // Cannot continue with barrier stage if failed to cancel zombie barrier tasks. - // TODO SPARK-24877 leave the zombie tasks and ignore their completion events. - logWarning(s"Could not kill all tasks for stage $stageId", e) - abortStage(failedStage, "Could not kill zombie barrier tasks for stage " + - s"$failedStage (${failedStage.name})", Some(e)) - } - markStageAsFinished(failedStage, Some(message)) + if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) { + logInfo(s"Ignoring task failure from $task as it's from $failedStage attempt" + + s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + + s"(attempt ${failedStage.latestInfo.attemptNumber}) running") + } else { + logInfo(s"Marking $failedStage (${failedStage.name}) as failed due to a barrier task " + + "failed.") + val message = s"Stage failed because barrier task $task finished unsuccessfully.\n" + + failure.toErrorString + try { + // killAllTaskAttempts will fail if a SchedulerBackend does not implement killTask. + val reason = s"Task $task from barrier stage $failedStage (${failedStage.name}) " + + "failed." + taskScheduler.killAllTaskAttempts(stageId, interruptThread = false, reason) + } catch { + case e: UnsupportedOperationException => + // Cannot continue with barrier stage if failed to cancel zombie barrier tasks. + // TODO SPARK-24877 leave the zombie tasks and ignore their completion events. + logWarning(s"Could not kill all tasks for stage $stageId", e) + abortStage(failedStage, "Could not kill zombie barrier tasks for stage " + + s"$failedStage (${failedStage.name})", Some(e)) + } + markStageAsFinished(failedStage, Some(message)) - failedStage.failedAttemptIds.add(task.stageAttemptId) - // TODO Refactor the failure handling logic to combine similar code with that of - // FetchFailed. - val shouldAbortStage = - failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || - disallowStageRetryForTest + failedStage.failedAttemptIds.add(task.stageAttemptId) + // TODO Refactor the failure handling logic to combine similar code with that of + // FetchFailed. + val shouldAbortStage = + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || + disallowStageRetryForTest - if (shouldAbortStage) { - val abortMessage = if (disallowStageRetryForTest) { - "Barrier stage will not retry stage due to testing config. Most recent failure " + - s"reason: $message" + if (shouldAbortStage) { + val abortMessage = if (disallowStageRetryForTest) { + "Barrier stage will not retry stage due to testing config. Most recent failure " + + s"reason: $message" + } else { + s"""$failedStage (${failedStage.name}) + |has failed the maximum allowable number of + |times: $maxConsecutiveStageAttempts. + |Most recent failure reason: $message + """.stripMargin.replaceAll("\n", " ") + } + abortStage(failedStage, abortMessage, None) } else { - s"""$failedStage (${failedStage.name}) - |has failed the maximum allowable number of - |times: $maxConsecutiveStageAttempts. - |Most recent failure reason: $message""".stripMargin.replaceAll("\n", " ") - } - abortStage(failedStage, abortMessage, None) - } else { - failedStage match { - case failedMapStage: ShuffleMapStage => - // Mark all the map as broken in the map stage, to ensure retry all the tasks on - // resubmitted stage attempt. - mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId) - - case failedResultStage: ResultStage => - // Mark all the partitions of the result stage to be not finished, to ensure retry - // all the tasks on resubmitted stage attempt. - failedResultStage.activeJob.map(_.resetAllPartitions()) - } + failedStage match { + case failedMapStage: ShuffleMapStage => + // Mark all the map as broken in the map stage, to ensure retry all the tasks on + // resubmitted stage attempt. + mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId) - // update failedStages and make sure a ResubmitFailedStages event is enqueued - failedStages += failedStage - logInfo(s"Resubmitting $failedStage (${failedStage.name}) due to barrier stage " + - "failure.") - messageScheduler.schedule(new Runnable { - override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) - }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) + case failedResultStage: ResultStage => + // Abort the failed result stage since we may have committed output for some + // partitions. + val reason = "Could not recover from a failed barrier ResultStage. Most recent " + + s"failure reason: $message" + abortStage(failedResultStage, reason, None) + } + // In case multiple task failures triggered for a single stage attempt, ensure we only + // resubmit the failed stage once. + val noResubmitEnqueued = !failedStages.contains(failedStage) + failedStages += failedStage + if (noResubmitEnqueued) { + logInfo(s"Resubmitting $failedStage (${failedStage.name}) due to barrier stage " + + "failure.") + messageScheduler.schedule(new Runnable { + override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) + }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) + } + } } case Resubmitted => diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 8b77641e85b76..d5e85a11cb279 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -893,6 +893,10 @@ private[spark] class TaskSetManager( None } + if (tasks(index).isBarrier) { + isZombie = true + } + sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info) if (!isZombie && reason.countTowardsTaskFailures) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 6eeddbb763172..56ba23c38af7f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1119,6 +1119,33 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assertDataStructuresEmpty() } + test("Fail the job if a barrier ResultTask failed") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + .barrier() + .mapPartitions(iter => iter) + submit(reduceRdd, Array(0, 1)) + + // Complete the map stage. + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostA", 2)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty)) + + // The first ResultTask fails + runEvent(makeCompletionEvent( + taskSets(1).tasks(0), + TaskKilled("test"), + null)) + + // Assert the stage has been cancelled. + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(failure.getMessage.startsWith("Job aborted due to stage failure: Could not recover " + + "from a failed barrier ResultStage.")) + } + /** * This tests the case where another FetchFailed comes in while the map stage is getting * re-run. @@ -2521,6 +2548,85 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } } + test("Barrier task failures from the same stage attempt don't trigger multiple stage retries") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions(iter => iter) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + submit(reduceRdd, Array(0, 1)) + + val mapStageId = 0 + def countSubmittedMapStageAttempts(): Int = { + sparkListener.submittedStageInfos.count(_.stageId == mapStageId) + } + + // The map stage should have been submitted. + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(countSubmittedMapStageAttempts() === 1) + + // The first map task fails with TaskKilled. + runEvent(makeCompletionEvent( + taskSets(0).tasks(0), + TaskKilled("test"), + null)) + assert(sparkListener.failedStages === Seq(0)) + + // The second map task fails with TaskKilled. + runEvent(makeCompletionEvent( + taskSets(0).tasks(1), + TaskKilled("test"), + null)) + + // Trigger resubmission of the failed map stage. + runEvent(ResubmitFailedStages) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + + // Another attempt for the map stage should have been submitted, resulting in 2 total attempts. + assert(countSubmittedMapStageAttempts() === 2) + } + + test("Barrier task failures from a previous stage attempt don't trigger stage retry") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions(iter => iter) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + submit(reduceRdd, Array(0, 1)) + + val mapStageId = 0 + def countSubmittedMapStageAttempts(): Int = { + sparkListener.submittedStageInfos.count(_.stageId == mapStageId) + } + + // The map stage should have been submitted. + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(countSubmittedMapStageAttempts() === 1) + + // The first map task fails with TaskKilled. + runEvent(makeCompletionEvent( + taskSets(0).tasks(0), + TaskKilled("test"), + null)) + assert(sparkListener.failedStages === Seq(0)) + + // Trigger resubmission of the failed map stage. + runEvent(ResubmitFailedStages) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + + // Another attempt for the map stage should have been submitted, resulting in 2 total attempts. + assert(countSubmittedMapStageAttempts() === 2) + + // The second map task fails with TaskKilled. + runEvent(makeCompletionEvent( + taskSets(0).tasks(1), + TaskKilled("test"), + null)) + + // The second map task failure doesn't trigger stage retry. + runEvent(ResubmitFailedStages) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(countSubmittedMapStageAttempts() === 2) + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index ca9bf08cee654..7a457a0a72d90 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -1118,4 +1118,22 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(!tsm.isZombie) assert(taskScheduler.taskSetManagerForAttempt(0, 0).isDefined) } + + test("mark taskset for a barrier stage as zombie in case a task fails") { + val taskScheduler = setupScheduler() + + val attempt = FakeTask.createBarrierTaskSet(3) + taskScheduler.submitTasks(attempt) + + val tsm = taskScheduler.taskSetManagerForAttempt(0, 0).get + val offers = (0 until 3).map{ idx => + WorkerOffer(s"exec-$idx", s"host-$idx", 1, Some(s"192.168.0.101:4962$idx")) + } + taskScheduler.resourceOffers(offers) + assert(tsm.runningTasks === 3) + + // Fail a task from the stage attempt. + tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED, TaskKilled("test")) + assert(tsm.isZombie) + } } From d80063278debc5529653d184841f50fe98cdad97 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 22 Aug 2018 01:00:06 +0800 Subject: [PATCH 1438/2461] [MINOR] Add .crc files to .gitignore ## What changes were proposed in this pull request? Add .crc files to .gitignore so that we don't add .crc files in state checkpoint to git repo which could be added in test resources. This is based on comments in #21733, https://github.com/apache/spark/pull/21733#issuecomment-414578244. ## How was this patch tested? Add `.1.delta.crc` and `.2.delta.crc` in `/sql/core/src/test/resources`, and confirm git doesn't suggest the files to add to stage. Closes #22170 from HeartSaVioR/add-crc-files-to-gitignore. Authored-by: Jungtaek Lim Signed-off-by: hyukjinkwon --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index e4c44d0590d59..19db7ac277944 100644 --- a/.gitignore +++ b/.gitignore @@ -77,6 +77,7 @@ target/ unit-tests.log work/ docs/.jekyll-metadata +*.crc # For Hive TempStatsStore/ From 35f7f5ce83984d8afe0b7955942baa04f2bef74f Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 22 Aug 2018 01:02:17 +0800 Subject: [PATCH 1439/2461] [DOCS][MINOR] Fix a few broken links and typos, and, nit, use HTTPS more consistently ## What changes were proposed in this pull request? Fix a few broken links and typos, and, nit, use HTTPS more consistently esp. on scripts and Apache links ## How was this patch tested? Doc build Closes #22172 from srowen/DocTypo. Authored-by: Sean Owen Signed-off-by: hyukjinkwon --- docs/README.md | 4 ++-- docs/_layouts/404.html | 2 +- docs/_layouts/global.html | 6 +++--- docs/building-spark.md | 8 ++++---- docs/contributing-to-spark.md | 2 +- docs/index.md | 16 ++++++++-------- docs/ml-migration-guides.md | 2 +- docs/quick-start.md | 2 +- docs/rdd-programming-guide.md | 4 ++-- docs/running-on-mesos.md | 2 +- docs/running-on-yarn.md | 2 +- docs/security.md | 6 +++--- docs/sparkr.md | 2 +- docs/sql-programming-guide.md | 6 +++--- docs/streaming-kinesis-integration.md | 2 +- docs/streaming-programming-guide.md | 5 ++--- docs/structured-streaming-programming-guide.md | 2 +- 17 files changed, 36 insertions(+), 37 deletions(-) diff --git a/docs/README.md b/docs/README.md index dbea4d64c4298..7da543dd297ad 100644 --- a/docs/README.md +++ b/docs/README.md @@ -2,7 +2,7 @@ Welcome to the Spark documentation! This readme will walk you through navigating and building the Spark documentation, which is included here with the Spark source code. You can also find documentation specific to release versions of -Spark at http://spark.apache.org/documentation.html. +Spark at https://spark.apache.org/documentation.html. Read on to learn more about viewing documentation in plain text (i.e., markdown) or building the documentation yourself. Why build it yourself? So that you have the docs that correspond to @@ -79,7 +79,7 @@ jekyll plugin to run `build/sbt unidoc` before building the site so if you haven may take some time as it generates all of the scaladoc and javadoc using [Unidoc](https://github.com/sbt/sbt-unidoc). The jekyll plugin also generates the PySpark docs using [Sphinx](http://sphinx-doc.org/), SparkR docs using [roxygen2](https://cran.r-project.org/web/packages/roxygen2/index.html) and SQL docs -using [MkDocs](http://www.mkdocs.org/). +using [MkDocs](https://www.mkdocs.org/). NOTE: To skip the step of building and copying over the Scala, Java, Python, R and SQL API docs, run `SKIP_API=1 jekyll build`. In addition, `SKIP_SCALADOC=1`, `SKIP_PYTHONDOC=1`, `SKIP_RDOC=1` and `SKIP_SQLDOC=1` can be used diff --git a/docs/_layouts/404.html b/docs/_layouts/404.html index 044654413f9c2..78f98b9ede5a7 100755 --- a/docs/_layouts/404.html +++ b/docs/_layouts/404.html @@ -151,7 +151,7 @@

      Not found :(

      - + diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index e5af5ae4561c7..88d549c3f1010 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -50,7 +50,7 @@ @@ -114,8 +114,8 @@
    • Hardware Provisioning
    • Building Spark
    • -
    • Contributing to Spark
    • -
    • Third Party Projects
    • +
    • Contributing to Spark
    • +
    • Third Party Projects
    • diff --git a/docs/building-spark.md b/docs/building-spark.md index affd7df17b001..d3dfd4902a920 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -45,7 +45,7 @@ Other build examples can be found below. ## Building a Runnable Distribution To create a Spark distribution like those distributed by the -[Spark Downloads](http://spark.apache.org/downloads.html) page, and that is laid out so as +[Spark Downloads](https://spark.apache.org/downloads.html) page, and that is laid out so as to be runnable, use `./dev/make-distribution.sh` in the project root directory. It can be configured with Maven profile settings and so on like the direct Maven build. Example: @@ -164,7 +164,7 @@ prompt. Developers who compile Spark frequently may want to speed up compilation; e.g., by using Zinc (for developers who build with Maven) or by avoiding re-compilation of the assembly JAR (for developers who build with SBT). For more information about how to do this, refer to the -[Useful Developer Tools page](http://spark.apache.org/developer-tools.html#reducing-build-times). +[Useful Developer Tools page](https://spark.apache.org/developer-tools.html#reducing-build-times). ## Encrypted Filesystems @@ -182,7 +182,7 @@ to the `sharedSettings` val. See also [this PR](https://github.com/apache/spark/ ## IntelliJ IDEA or Eclipse For help in setting up IntelliJ IDEA or Eclipse for Spark development, and troubleshooting, refer to the -[Useful Developer Tools page](http://spark.apache.org/developer-tools.html). +[Useful Developer Tools page](https://spark.apache.org/developer-tools.html). # Running Tests @@ -203,7 +203,7 @@ The following is an example of a command to run the tests: ## Running Individual Tests For information about how to run individual tests, refer to the -[Useful Developer Tools page](http://spark.apache.org/developer-tools.html#running-individual-tests). +[Useful Developer Tools page](https://spark.apache.org/developer-tools.html#running-individual-tests). ## PySpark pip installable diff --git a/docs/contributing-to-spark.md b/docs/contributing-to-spark.md index 9252545e4a129..ede5584a0cf99 100644 --- a/docs/contributing-to-spark.md +++ b/docs/contributing-to-spark.md @@ -5,4 +5,4 @@ title: Contributing to Spark The Spark team welcomes all forms of contributions, including bug reports, documentation or patches. For the newest information on how to contribute to the project, please read the -[Contributing to Spark guide](http://spark.apache.org/contributing.html). +[Contributing to Spark guide](https://spark.apache.org/contributing.html). diff --git a/docs/index.md b/docs/index.md index 2f009417fafb0..40f628b794c01 100644 --- a/docs/index.md +++ b/docs/index.md @@ -12,7 +12,7 @@ It also supports a rich set of higher-level tools including [Spark SQL](sql-prog # Downloading -Get Spark from the [downloads page](http://spark.apache.org/downloads.html) of the project website. This documentation is for Spark version {{site.SPARK_VERSION}}. Spark uses Hadoop's client libraries for HDFS and YARN. Downloads are pre-packaged for a handful of popular Hadoop versions. +Get Spark from the [downloads page](https://spark.apache.org/downloads.html) of the project website. This documentation is for Spark version {{site.SPARK_VERSION}}. Spark uses Hadoop's client libraries for HDFS and YARN. Downloads are pre-packaged for a handful of popular Hadoop versions. Users can also download a "Hadoop free" binary and run Spark with any Hadoop version [by augmenting Spark's classpath](hadoop-provided.html). Scala and Java users can include Spark in their projects using its Maven coordinates and in the future Python users can also install Spark from PyPI. @@ -111,7 +111,7 @@ options for deployment: * [Amazon EC2](https://github.com/amplab/spark-ec2): scripts that let you launch a cluster on EC2 in about 5 minutes * [Standalone Deploy Mode](spark-standalone.html): launch a standalone cluster quickly without a third-party cluster manager * [Mesos](running-on-mesos.html): deploy a private cluster using - [Apache Mesos](http://mesos.apache.org) + [Apache Mesos](https://mesos.apache.org) * [YARN](running-on-yarn.html): deploy Spark on top of Hadoop NextGen (YARN) * [Kubernetes](running-on-kubernetes.html): deploy Spark on top of Kubernetes @@ -127,20 +127,20 @@ options for deployment: * [Cloud Infrastructures](cloud-integration.html) * [OpenStack Swift](storage-openstack-swift.html) * [Building Spark](building-spark.html): build Spark using the Maven system -* [Contributing to Spark](http://spark.apache.org/contributing.html) -* [Third Party Projects](http://spark.apache.org/third-party-projects.html): related third party Spark projects +* [Contributing to Spark](https://spark.apache.org/contributing.html) +* [Third Party Projects](https://spark.apache.org/third-party-projects.html): related third party Spark projects **External Resources:** -* [Spark Homepage](http://spark.apache.org) -* [Spark Community](http://spark.apache.org/community.html) resources, including local meetups +* [Spark Homepage](https://spark.apache.org) +* [Spark Community](https://spark.apache.org/community.html) resources, including local meetups * [StackOverflow tag `apache-spark`](http://stackoverflow.com/questions/tagged/apache-spark) -* [Mailing Lists](http://spark.apache.org/mailing-lists.html): ask questions about Spark here +* [Mailing Lists](https://spark.apache.org/mailing-lists.html): ask questions about Spark here * [AMP Camps](http://ampcamp.berkeley.edu/): a series of training camps at UC Berkeley that featured talks and exercises about Spark, Spark Streaming, Mesos, and more. [Videos](http://ampcamp.berkeley.edu/6/), [slides](http://ampcamp.berkeley.edu/6/) and [exercises](http://ampcamp.berkeley.edu/6/exercises/) are available online for free. -* [Code Examples](http://spark.apache.org/examples.html): more are also available in the `examples` subfolder of Spark ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), +* [Code Examples](https://spark.apache.org/examples.html): more are also available in the `examples` subfolder of Spark ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples), [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python), [R]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/r)) diff --git a/docs/ml-migration-guides.md b/docs/ml-migration-guides.md index e4736411fb5fe..2047065f71eb8 100644 --- a/docs/ml-migration-guides.md +++ b/docs/ml-migration-guides.md @@ -289,7 +289,7 @@ In the `spark.mllib` package, there were several breaking changes. The first ch In the `spark.ml` package, the main API changes are from Spark SQL. We list the most important changes here: -* The old [SchemaRDD](http://spark.apache.org/docs/1.2.1/api/scala/index.html#org.apache.spark.sql.SchemaRDD) has been replaced with [DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame) with a somewhat modified API. All algorithms in `spark.ml` which used to use SchemaRDD now use DataFrame. +* The old [SchemaRDD](https://spark.apache.org/docs/1.2.1/api/scala/index.html#org.apache.spark.sql.SchemaRDD) has been replaced with [DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame) with a somewhat modified API. All algorithms in `spark.ml` which used to use SchemaRDD now use DataFrame. * In Spark 1.2, we used implicit conversions from `RDD`s of `LabeledPoint` into `SchemaRDD`s by calling `import sqlContext._` where `sqlContext` was an instance of `SQLContext`. These implicits have been moved, so we now call `import sqlContext.implicits._`. * Java APIs for SQL have also changed accordingly. Please see the examples above and the [Spark SQL Programming Guide](sql-programming-guide.html) for details. diff --git a/docs/quick-start.md b/docs/quick-start.md index f1a2096cd4dbd..ef7af6c3f6cec 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -12,7 +12,7 @@ interactive shell (in Python or Scala), then show how to write applications in Java, Scala, and Python. To follow along with this guide, first, download a packaged release of Spark from the -[Spark website](http://spark.apache.org/downloads.html). Since we won't be using HDFS, +[Spark website](https://spark.apache.org/downloads.html). Since we won't be using HDFS, you can download a package for any version of Hadoop. Note that, before Spark 2.0, the main programming interface of Spark was the Resilient Distributed Dataset (RDD). After Spark 2.0, RDDs are replaced by Dataset, which is strongly-typed like an RDD, but with richer optimizations under the hood. The RDD interface is still supported, and you can get a more detailed reference at the [RDD programming guide](rdd-programming-guide.html). However, we highly recommend you to switch to use Dataset, which has better performance than RDD. See the [SQL programming guide](sql-programming-guide.html) to get more information about Dataset. diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index b6424090d2fea..d95b757f36859 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -106,7 +106,7 @@ You can also use `bin/pyspark` to launch an interactive Python shell. If you wish to access HDFS data, you need to use a build of PySpark linking to your version of HDFS. -[Prebuilt packages](http://spark.apache.org/downloads.html) are also available on the Spark homepage +[Prebuilt packages](https://spark.apache.org/downloads.html) are also available on the Spark homepage for common HDFS versions. Finally, you need to import some Spark classes into your program. Add the following line: @@ -1569,7 +1569,7 @@ as Spark does not support two contexts running concurrently in the same program. # Where to Go from Here -You can see some [example Spark programs](http://spark.apache.org/examples.html) on the Spark website. +You can see some [example Spark programs](https://spark.apache.org/examples.html) on the Spark website. In addition, Spark includes several samples in the `examples` directory ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples), diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 3e76d47608c74..b473e654563d6 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -672,7 +672,7 @@ See the [configuration page](configuration.html) for information on Spark config
      spark.mesos.dispatcher.historyServer.url (none) - Set the URL of the history + Set the URL of the history server. The dispatcher will then link each driver to its entry in the history server.
      + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
      Property NameDefaultMeaningScope
      avroSchemaNoneOptional Avro schema provided by an user in JSON format. The date type and naming of record fields + should match the input Avro data or Catalyst data, otherwise the read/write action will fail.read and write
      recordNametopLevelRecordTop level record name in write result, which is required in Avro spec.write
      recordNamespace""Record namespace in write result.write
      ignoreExtensiontrueThe option controls ignoring of files without .avro extensions in read.
      If the option is enabled, all files (with and without .avro extension) are loaded.
      read
      compressionsnappyThe compression option allows to specify a compression codec used in write.
      + Currently supported codecs are uncompressed, snappy, deflate, bzip2 and xz.
      If the option is not set, the configuration spark.sql.avro.compression.codec config is taken into account.
      write
      + +## Configuration +Configuration of Avro can be done using the `setConf` method on SparkSession or by running `SET key=value` commands using SQL. + + + + + + + + + + + + + + + + + +
      Property NameDefaultMeaning
      spark.sql.legacy.replaceDatabricksSparkAvro.enabledtrueIf it is set to true, the data source provider com.databricks.spark.avro is mapped to the built-in but external Avro data source module for backward compatibility.
      spark.sql.avro.compression.codecsnappyCompression codec used in writing of AVRO files. Supported codecs: uncompressed, deflate, snappy, bzip2 and xz. Default codec is snappy.
      spark.sql.avro.deflate.level-1Compression level for the deflate codec used in writing of AVRO files. Valid value must be in the range of from 1 to 9 inclusive or -1. The default value is -1 which corresponds to 6 level in the current implementation.
      + +## Compatibility with Databricks spark-avro +This Avro data source module is originally from and compatible with Databricks's open source repository +[spark-avro](https://github.com/databricks/spark-avro). + +By default with the SQL configuration `spark.sql.legacy.replaceDatabricksSparkAvro.enabled` enabled, the data source provider `com.databricks.spark.avro` is +mapped to this built-in Avro module. For the Spark tables created with `Provider` property as `com.databricks.spark.avro` in +catalog meta store, the mapping is essential to load these tables if you are using this built-in Avro module. + +Note in Databricks's [spark-avro](https://github.com/databricks/spark-avro), implicit classes +`AvroDataFrameWriter` and `AvroDataFrameReader` were created for shortcut function `.avro()`. In this +built-in but external module, both implicit classes are removed. Please use `.format("avro")` in +`DataFrameWriter` or `DataFrameReader` instead, which should be clean and good enough. + +If you prefer using your own build of `spark-avro` jar file, you can simply disable the configuration +`spark.sql.legacy.replaceDatabricksSparkAvro.enabled`, and use the option `--jars` on deploying your +applications. Read the [Advanced Dependency Management](https://spark.apache +.org/docs/latest/submitting-applications.html#advanced-dependency-management) section in Application +Submission Guide for more details. + +## Supported types for Avro -> Spark SQL conversion +Currently Spark supports reading all [primitive types](https://avro.apache.org/docs/1.8.2/spec.html#schema_primitive) and [complex types](https://avro.apache.org/docs/1.8.2/spec.html#schema_complex) under records of Avro. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
      Avro typeSpark SQL type
      booleanBooleanType
      intIntegerType
      longLongType
      floatFloatType
      doubleDoubleType
      stringStringType
      enumStringType
      fixedBinaryType
      bytesBinaryType
      recordStructType
      arrayArrayType
      mapMapType
      unionSee below
      + +In addition to the types listed above, it supports reading `union` types. The following three types are considered basic `union` types: + +1. `union(int, long)` will be mapped to LongType. +2. `union(float, double)` will be mapped to DoubleType. +3. `union(something, null)`, where something is any supported Avro type. This will be mapped to the same Spark SQL type as that of something, with nullable set to true. +All other union types are considered complex. They will be mapped to StructType where field names are member0, member1, etc., in accordance with members of the union. This is consistent with the behavior when converting between Avro and Parquet. + +It also supports reading the following Avro [logical types](https://avro.apache.org/docs/1.8.2/spec.html#Logical+Types): + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
      Avro logical typeAvro typeSpark SQL type
      dateintDateType
      timestamp-millislongTimestampType
      timestamp-microslongTimestampType
      decimalfixedDecimalType
      decimalbytesDecimalType
      +At the moment, it ignores docs, aliases and other properties present in the Avro file. + +## Supported types for Spark SQL -> Avro conversion +Spark supports writing of all Spark SQL types into Avro. For most types, the mapping from Spark types to Avro types is straightforward (e.g. IntegerType gets converted to int); however, there are a few special cases which are listed below: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
      Spark SQL typeAvro typeAvro logical type
      ByteTypeint
      ShortTypeint
      BinaryTypebytes
      DateTypeintdate
      TimestampTypelongtimestamp-micros
      DecimalTypefixeddecimal
      + +You can also specify the whole output Avro schema with the option `avroSchema`, so that Spark SQL types can be converted into other Avro types. The following conversions are not applied by default and require user specified Avro schema: + + + + + + + + + + + + + + + + + + + + + + + +
      Spark SQL typeAvro typeAvro logical type
      BinaryTypefixed
      StringTypeenum
      TimestampTypelongtimestamp-millis
      DecimalTypebytesdecimal
      diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 8e308d5aa05e0..3749094569271 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1482,6 +1482,9 @@ SELECT * FROM resultTable
      +## Avro Files +See the [Apache Avro Data Source Guide](avro-data-source-guide.html). + ## Troubleshooting * The JDBC driver class must be visible to the primordial class loader on the client session and on all executors. This is because Java's DriverManager class does a security check that results in it ignoring all drivers not visible to the primordial class loader when one goes to open a connection. One convenient way to do this is to modify compute_classpath.sh on all worker nodes to include your driver JARs. From 49720906c9b2f36ead366b06568ddfaddb5cd791 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 23 Aug 2018 14:17:29 +0800 Subject: [PATCH 1460/2461] [SPARK-23932][SQL][FOLLOW-UP] Fix an example of zip_with function. ## What changes were proposed in this pull request? This is a follow-up pr of #22031 which added `zip_with` function to fix an example. ## How was this patch tested? Existing tests. Closes #22194 from ueshin/issues/SPARK-23932/fix_examples. Authored-by: Takuya UESHIN Signed-off-by: hyukjinkwon --- .../spark/sql/catalyst/expressions/higherOrderFunctions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 3e0621d3e20eb..9f2e84a230060 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -789,7 +789,7 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), array('a', 'b', 'c'), (x, y) -> (y, x)); - array(('a', 1), ('b', 3), ('c', 5)) + array(('a', 1), ('b', 2), ('c', 3)) > SELECT _FUNC_(array(1, 2), array(3, 4), (x, y) -> x + y)); array(4, 6) > SELECT _FUNC_(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y)); From 2a0a8f753bbdc8c251f8e699c0808f35b94cfd20 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 23 Aug 2018 14:26:10 +0800 Subject: [PATCH 1461/2461] [SPARK-23034][SQL] Show RDD/relation names in RDD/Hive table scan nodes ## What changes were proposed in this pull request? This pr proposed to show RDD/relation names in RDD/Hive table scan nodes. This change made these names show up in the webUI and explain results. For example; ``` scala> sql("CREATE TABLE t(c1 int) USING hive") scala> sql("INSERT INTO t VALUES(1)") scala> spark.table("t").explain() == Physical Plan == Scan hive default.t [c1#8], HiveTableRelation `default`.`t`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [c1#8] ^^^^^^^^^^^ ``` spark-pr-hive Closes #20226 ## How was this patch tested? Added tests in `DataFrameSuite`, `DatasetSuite`, and `HiveExplainSuite` Closes #22153 from maropu/pr20226. Lead-authored-by: Takeshi Yamamuro Co-authored-by: Tejas Patil Signed-off-by: Wenchen Fan --- .../apache/spark/sql/kafka010/KafkaRelation.scala | 2 +- .../apache/spark/sql/kafka010/KafkaSource.scala | 4 ++-- .../scala/org/apache/spark/sql/SparkSession.scala | 6 +++--- .../apache/spark/sql/execution/ExistingRDD.scala | 14 +++++++++++--- .../sql/execution/arrow/ArrowConverters.scala | 2 +- .../org/apache/spark/sql/DataFrameSuite.scala | 10 ++++++++++ .../scala/org/apache/spark/sql/DatasetSuite.scala | 10 ++++++++++ .../streaming/FlatMapGroupsWithStateSuite.scala | 5 ++++- .../sql/hive/execution/HiveTableScanExec.scala | 2 ++ .../sql/hive/execution/HiveExplainSuite.scala | 12 ++++++++++++ 10 files changed, 56 insertions(+), 11 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala index c31e6ed3e0903..9d856c9494e10 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala @@ -117,7 +117,7 @@ private[kafka010] class KafkaRelation( DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(cr.timestamp)), cr.timestampType.id) } - sqlContext.internalCreateDataFrame(rdd, schema).rdd + sqlContext.internalCreateDataFrame(rdd.setName("kafka"), schema).rdd } private def getPartitionOffsets( diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 101e649727fcf..66ec7e0cd084a 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -215,7 +215,7 @@ private[kafka010] class KafkaSource( } if (start.isDefined && start.get == end) { return sqlContext.internalCreateDataFrame( - sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) + sqlContext.sparkContext.emptyRDD[InternalRow].setName("empty"), schema, isStreaming = true) } val fromPartitionOffsets = start match { case Some(prevBatchEndOffset) => @@ -299,7 +299,7 @@ private[kafka010] class KafkaSource( logInfo("GetBatch generating RDD of offset range: " + offsetRanges.sortBy(_.topicPartition.toString).mkString(", ")) - sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) + sqlContext.internalCreateDataFrame(rdd.setName("kafka"), schema, isStreaming = true) } /** Stop this source and free any resources it has allocated. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index d9278d8cd23d6..2b847fb6f9458 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -270,7 +270,7 @@ class SparkSession private( */ @transient lazy val emptyDataFrame: DataFrame = { - createDataFrame(sparkContext.emptyRDD[Row], StructType(Nil)) + createDataFrame(sparkContext.emptyRDD[Row].setName("empty"), StructType(Nil)) } /** @@ -395,7 +395,7 @@ class SparkSession private( // BeanInfo is not serializable so we must rediscover it remotely for each partition. SQLContext.beansToRows(iter, Utils.classForName(className), attributeSeq) } - Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRdd)(self)) + Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRdd.setName(rdd.name))(self)) } /** @@ -594,7 +594,7 @@ class SparkSession private( } else { rowRDD.map { r: Row => InternalRow.fromSeq(r.toSeq) } } - internalCreateDataFrame(catalystRows, schema) + internalCreateDataFrame(catalystRows.setName(rowRDD.name), schema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index be50a1571a2ff..2962becb64e88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -103,6 +103,10 @@ case class ExternalRDDScanExec[T]( override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + private def rddName: String = Option(rdd.name).map(n => s" $n").getOrElse("") + + override val nodeName: String = s"Scan$rddName" + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") val outputDataType = outputObjAttr.dataType @@ -116,7 +120,7 @@ case class ExternalRDDScanExec[T]( } override def simpleString: String = { - s"Scan $nodeName${output.mkString("[", ",", "]")}" + s"$nodeName${output.mkString("[", ",", "]")}" } } @@ -169,10 +173,14 @@ case class LogicalRDD( case class RDDScanExec( output: Seq[Attribute], rdd: RDD[InternalRow], - override val nodeName: String, + name: String, override val outputPartitioning: Partitioning = UnknownPartitioning(0), override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode { + private def rddName: String = Option(rdd.name).map(n => s" $n").getOrElse("") + + override val nodeName: String = s"Scan $name$rddName" + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -189,6 +197,6 @@ case class RDDScanExec( } override def simpleString: String = { - s"Scan $nodeName${Utils.truncatedString(output, "[", ",", "]")}" + s"$nodeName${Utils.truncatedString(output, "[", ",", "]")}" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 501520c0e085e..6a5ac2413d73c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -211,6 +211,6 @@ private[sql] object ArrowConverters { ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) } val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] - sqlContext.internalCreateDataFrame(rdd, schema) + sqlContext.internalCreateDataFrame(rdd.setName("arrow"), schema) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 7310087cc99f9..6f5c73074313c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2552,4 +2552,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(numJobs == 1) } } + + test("SPARK-23034 show rdd names in RDD scan nodes") { + val rddWithName = spark.sparkContext.parallelize(Row(1, "abc") :: Nil).setName("testRdd") + val df2 = spark.createDataFrame(rddWithName, StructType.fromDDL("c0 int, c1 string")) + val output2 = new java.io.ByteArrayOutputStream() + Console.withOut(output2) { + df2.explain(extended = false) + } + assert(output2.toString.contains("Scan ExistingRDD testRdd")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index cf24eba128012..6069f28d185e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1498,6 +1498,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext { df.where($"city".contains(new java.lang.Character('A'))), Seq(Row("Amsterdam"))) } + + test("SPARK-23034 show rdd names in RDD scan nodes") { + val rddWithName = spark.sparkContext.parallelize(SingleData(1) :: Nil).setName("testRdd") + val df = spark.createDataFrame(rddWithName) + val output = new java.io.ByteArrayOutputStream() + Console.withOut(output) { + df.explain(extended = false) + } + assert(output.toString.contains("Scan testRdd")) + } } case class TestDataUnion(x: Int, y: Int, z: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 76511ae2c8362..e77ba1ec9f1eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkException import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning @@ -1229,6 +1230,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest timeoutType: GroupStateTimeout = GroupStateTimeout.NoTimeout, batchTimestampMs: Long = NO_TIMESTAMP): FlatMapGroupsWithStateExec = { val stateFormatVersion = spark.conf.get(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) + val emptyRdd = spark.sparkContext.emptyRDD[InternalRow] MemoryStream[Int] .toDS .groupByKey(x => x) @@ -1237,7 +1239,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) => FlatMapGroupsWithStateExec( f, k, v, g, d, o, None, s, stateFormatVersion, m, t, - Some(currentBatchTimestamp), Some(currentBatchWatermark), RDDScanExec(g, null, "rdd")) + Some(currentBatchTimestamp), Some(currentBatchWatermark), + RDDScanExec(g, emptyRdd, "rdd")) }.get } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 6052486c47da2..b3795b4430404 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -62,6 +62,8 @@ case class HiveTableScanExec( override def conf: SQLConf = sparkSession.sessionState.conf + override def nodeName: String = s"Scan hive ${relation.tableMeta.qualifiedName}" + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index a1ce1ea936bbf..c349a327694bf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -170,4 +170,16 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto sql("EXPLAIN EXTENDED CODEGEN SELECT 1") } } + + test("SPARK-23034 show relation names in Hive table scan nodes") { + val tableName = "tab" + withTable(tableName) { + sql(s"CREATE TABLE $tableName(c1 int) USING hive") + val output = new java.io.ByteArrayOutputStream() + Console.withOut(output) { + spark.table(tableName).explain(extended = false) + } + assert(output.toString.contains(s"Scan hive default.$tableName")) + } + } } From 8cc591c91a0c63effeed73801299985ba8a4a99e Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Thu, 23 Aug 2018 14:52:23 +0800 Subject: [PATCH 1462/2461] [SPARK-25164][SQL] Avoid rebuilding column and path list for each column in parquet reader ## What changes were proposed in this pull request? VectorizedParquetRecordReader::initializeInternal rebuilds the column list and path list once for each column. Therefore, it indirectly iterates 2\*colCount\*colCount times for each parquet file. This inefficiency impacts jobs that read parquet-backed tables with many columns and many files. Jobs that read tables with few columns or few files are not impacted. This PR changes initializeInternal so that it builds each list only once. I ran benchmarks on my laptop with 1 worker thread, running this query:
      sql("select * from parquet_backed_table where id1 = 1").collect
      
      There are roughly one matching row for every 425 rows, and the matching rows are sprinkled pretty evenly throughout the table (that is, every page for column id1 has at least one matching row). 6000 columns, 1 million rows, 67 32M files: master | branch | improvement -------|---------|----------- 10.87 min | 6.09 min | 44% 6000 columns, 1 million rows, 23 98m files: master | branch | improvement -------|---------|----------- 7.39 min | 5.80 min | 21% 600 columns 10 million rows, 67 32M files: master | branch | improvement -------|---------|----------- 1.95 min | 1.96 min | -0.5% 60 columns, 100 million rows, 67 32M files: master | branch | improvement -------|---------|----------- 0.55 min | 0.55 min | 0% ## How was this patch tested? - sql unit tests - pyspark-sql tests Closes #22188 from bersprockets/SPARK-25164. Authored-by: Bruce Robbins Signed-off-by: Wenchen Fan --- .../parquet/VectorizedParquetRecordReader.java | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 5934a23db8af1..f02861355c404 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -270,21 +270,23 @@ public boolean nextBatch() throws IOException { private void initializeInternal() throws IOException, UnsupportedOperationException { // Check that the requested schema is supported. missingColumns = new boolean[requestedSchema.getFieldCount()]; + List columns = requestedSchema.getColumns(); + List paths = requestedSchema.getPaths(); for (int i = 0; i < requestedSchema.getFieldCount(); ++i) { Type t = requestedSchema.getFields().get(i); if (!t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED)) { throw new UnsupportedOperationException("Complex types not supported."); } - String[] colPath = requestedSchema.getPaths().get(i); + String[] colPath = paths.get(i); if (fileSchema.containsPath(colPath)) { ColumnDescriptor fd = fileSchema.getColumnDescription(colPath); - if (!fd.equals(requestedSchema.getColumns().get(i))) { + if (!fd.equals(columns.get(i))) { throw new UnsupportedOperationException("Schema evolution not supported."); } missingColumns[i] = false; } else { - if (requestedSchema.getColumns().get(i).getMaxDefinitionLevel() == 0) { + if (columns.get(i).getMaxDefinitionLevel() == 0) { // Column is missing in data but the required data is non-nullable. This file is invalid. throw new IOException("Required column is missing in data file. Col: " + Arrays.toString(colPath)); From e3b7bb4132884872a6913b9f452e910a2e4b8e40 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 23 Aug 2018 15:08:46 +0800 Subject: [PATCH 1463/2461] [SPARK-24811][FOLLOWUP][SQL] Revise package of AvroDataToCatalyst and CatalystDataToAvro ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/21838, the class `AvroDataToCatalyst` and `CatalystDataToAvro` were put in package `org.apache.spark.sql`. They should be moved to package `org.apache.spark.sql.avro`. Also optimize imports in Avro module. ## How was this patch tested? Unit test Closes #22196 from gengliangwang/avro_revise_package_name. Authored-by: Gengliang Wang Signed-off-by: hyukjinkwon --- .../scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala | 3 +-- .../scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala | 5 ++--- .../spark/sql/avro/AvroCatalystDataConversionSuite.scala | 5 +---- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala index 6671b3fb8705c..915769fa708b0 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala @@ -15,13 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.avro import org.apache.avro.Schema import org.apache.avro.generic.GenericDatumReader import org.apache.avro.io.{BinaryDecoder, DecoderFactory} -import org.apache.spark.sql.avro.{AvroDeserializer, SchemaConverters} import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala index a669388e88258..141ff3782adfb 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala @@ -15,16 +15,15 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.avro import java.io.ByteArrayOutputStream import org.apache.avro.generic.GenericDatumWriter import org.apache.avro.io.{BinaryEncoder, EncoderFactory} -import org.apache.spark.sql.avro.{AvroSerializer, SchemaConverters} import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types.{BinaryType, DataType} case class CatalystDataToAvro(child: Expression) extends UnaryExpression { diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala index 4b3bf0cd52957..8334cca6cd8f1 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.avro -import org.apache.avro.Schema - import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{AvroDataToCatalyst, CatalystDataToAvro, RandomDataGenerator} +import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, GenericInternalRow, Literal} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String class AvroCatalystDataConversionSuite extends SparkFunSuite with ExpressionEvalHelper { From 5d572fc7c35f76e27b2ab400674923eb8ba91745 Mon Sep 17 00:00:00 2001 From: Rao Fu Date: Thu, 23 Aug 2018 22:00:20 +0800 Subject: [PATCH 1464/2461] [SPARK-25126][SQL] Avoid creating Reader for all orc files ## What changes were proposed in this pull request? [SPARK-25126] (https://issues.apache.org/jira/browse/SPARK-25126) reports loading a large number of orc files consumes a lot of memory in both 2.0 and 2.3. The issue is caused by creating a Reader for every orc file in order to infer the schema. In OrFileOperator.ReadSchema, a Reader is created for every file although only the first valid one is used. This uses significant amount of memory when there `paths` have a lot of files. In 2.3 a different code path (OrcUtils.readSchema) is used for inferring schema for orc files. This commit changes both functions to create Reader lazily. ## How was this patch tested? Pass the Jenkins with a newly added test case by dongjoon-hyun Closes #22157 from raofu/SPARK-25126. Lead-authored-by: Rao Fu Co-authored-by: Dongjoon Hyun Co-authored-by: Rao Fu Signed-off-by: hyukjinkwon --- .../execution/datasources/orc/OrcUtils.scala | 7 ++-- .../datasources/orc/OrcQuerySuite.scala | 39 ++++++++++++++++++- .../spark/sql/hive/orc/OrcFileOperator.scala | 11 +++--- 3 files changed, 48 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index b404cfa61f41e..ac062fdc092ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -79,9 +79,10 @@ object OrcUtils extends Logging { val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles val conf = sparkSession.sessionState.newHadoopConf() // TODO: We need to support merge schema. Please see SPARK-11412. - files.map(_.getPath).flatMap(readSchema(_, conf, ignoreCorruptFiles)).headOption.map { schema => - logDebug(s"Reading schema from file $files, got Hive schema string: $schema") - CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType] + files.toIterator.map(file => readSchema(file.getPath, conf, ignoreCorruptFiles)).collectFirst { + case Some(schema) => + logDebug(s"Reading schema from file $files, got Hive schema string: $schema") + CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType] } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index f58c331f33ca8..e9dccbf2e261c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -562,20 +562,57 @@ abstract class OrcQueryTest extends OrcTest { } } + def testAllCorruptFiles(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.json(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.json(new Path(basePath, "second").toString) + val df = spark.read.orc( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString) + assert(df.count() == 0) + } + } + + def testAllCorruptFilesWithoutSchemaInfer(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.json(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.json(new Path(basePath, "second").toString) + val df = spark.read.schema("a long").orc( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString) + assert(df.count() == 0) + } + } + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "true") { testIgnoreCorruptFiles() testIgnoreCorruptFilesWithoutSchemaInfer() + val m1 = intercept[AnalysisException] { + testAllCorruptFiles() + }.getMessage + assert(m1.contains("Unable to infer schema for ORC")) + testAllCorruptFilesWithoutSchemaInfer() } withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { val m1 = intercept[SparkException] { testIgnoreCorruptFiles() }.getMessage - assert(m1.contains("Could not read footer for file")) + assert(m1.contains("Malformed ORC file")) val m2 = intercept[SparkException] { testIgnoreCorruptFilesWithoutSchemaInfer() }.getMessage assert(m2.contains("Malformed ORC file")) + val m3 = intercept[SparkException] { + testAllCorruptFiles() + }.getMessage + assert(m3.contains("Could not read footer for file")) + val m4 = intercept[SparkException] { + testAllCorruptFilesWithoutSchemaInfer() + }.getMessage + assert(m4.contains("Malformed ORC file")) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 80e44ca504356..713b70f252b6a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -92,11 +92,12 @@ private[hive] object OrcFileOperator extends Logging { : Option[StructType] = { // Take the first file where we can open a valid reader if we can find one. Otherwise just // return None to indicate we can't infer the schema. - paths.flatMap(getFileReader(_, conf, ignoreCorruptFiles)).headOption.map { reader => - val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] - val schema = readerInspector.getTypeName - logDebug(s"Reading schema from file $paths, got Hive schema string: $schema") - CatalystSqlParser.parseDataType(schema).asInstanceOf[StructType] + paths.toIterator.map(getFileReader(_, conf, ignoreCorruptFiles)).collectFirst { + case Some(reader) => + val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] + val schema = readerInspector.getTypeName + logDebug(s"Reading schema from file $paths, got Hive schema string: $schema") + CatalystSqlParser.parseDataType(schema).asInstanceOf[StructType] } } From a9aacdf1c2a5f202c52e6b539c868dd075eebc25 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 23 Aug 2018 22:48:26 +0800 Subject: [PATCH 1465/2461] [SPARK-25208][SQL] Loosen Cast.forceNullable for DecimalType. ## What changes were proposed in this pull request? Casting to `DecimalType` is not always needed to force nullable. If the decimal type to cast is wider than original type, or only truncating or precision loss, the casted value won't be `null`. ## How was this patch tested? Added and modified tests. Closes #22200 from ueshin/issues/SPARK-25208/cast_nullable_decimal. Authored-by: Takuya UESHIN Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/expressions/Cast.scala | 11 ++- .../apache/spark/sql/types/DecimalType.scala | 1 + .../catalyst/analysis/TypeCoercionSuite.scala | 22 +++++- .../sql/catalyst/expressions/CastSuite.scala | 45 ++++++++++-- .../inputs/typeCoercion/native/concat.sql | 2 + .../inputs/typeCoercion/native/mapZipWith.sql | 12 +++ .../inputs/typeCoercion/native/mapconcat.sql | 1 + .../typeCoercion/native/concat.sql.out | 6 +- .../typeCoercion/native/mapZipWith.sql.out | 73 ++++++++++++++----- .../typeCoercion/native/mapconcat.sql.out | 5 +- 10 files changed, 145 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 100b9cfd70f52..0053503501047 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -154,6 +154,15 @@ object Cast { fromPrecedence >= 0 && fromPrecedence < toPrecedence } + def canNullSafeCastToDecimal(from: DataType, to: DecimalType): Boolean = from match { + case from: BooleanType if to.isWiderThan(DecimalType.BooleanDecimal) => true + case from: NumericType if to.isWiderThan(from) => true + case from: DecimalType => + // truncating or precision lose + (to.precision - to.scale) > (from.precision - from.scale) + case _ => false // overflow + } + def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match { case (NullType, _) => true case (_, _) if from == to => false @@ -169,7 +178,7 @@ object Cast { case (DateType, _) => true case (_, CalendarIntervalType) => true - case (_, _: DecimalType) => true // overflow + case (_, to: DecimalType) if !canNullSafeCastToDecimal(from, to) => true case (_: FractionalType, _: IntegralType) => true // NaN, infinity case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index f780ffd46a876..15004e4b9667d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -121,6 +121,7 @@ object DecimalType extends AbstractDataType { val MINIMUM_ADJUSTED_SCALE = 6 // The decimal types compatible with other numeric types + private[sql] val BooleanDecimal = DecimalType(1, 0) private[sql] val ByteDecimal = DecimalType(3, 0) private[sql] val ShortDecimal = DecimalType(5, 0) private[sql] val IntDecimal = DecimalType(10, 0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 2c6cb3ae1274a..461eda4334bb9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -502,7 +502,11 @@ class TypeCoercionSuite extends AnalysisTest { widenTestWithStringPromotion( ArrayType(IntegerType, containsNull = false), ArrayType(DecimalType.IntDecimal, containsNull = false), - Some(ArrayType(DecimalType.IntDecimal, containsNull = true))) + Some(ArrayType(DecimalType.IntDecimal, containsNull = false))) + widenTestWithStringPromotion( + ArrayType(DecimalType(36, 0), containsNull = false), + ArrayType(DecimalType(36, 35), containsNull = false), + Some(ArrayType(DecimalType(38, 35), containsNull = true))) // MapType widenTestWithStringPromotion( @@ -524,10 +528,18 @@ class TypeCoercionSuite extends AnalysisTest { widenTestWithStringPromotion( MapType(StringType, IntegerType, valueContainsNull = false), MapType(StringType, DecimalType.IntDecimal, valueContainsNull = false), - Some(MapType(StringType, DecimalType.IntDecimal, valueContainsNull = true))) + Some(MapType(StringType, DecimalType.IntDecimal, valueContainsNull = false))) + widenTestWithStringPromotion( + MapType(StringType, DecimalType(36, 0), valueContainsNull = false), + MapType(StringType, DecimalType(36, 35), valueContainsNull = false), + Some(MapType(StringType, DecimalType(38, 35), valueContainsNull = true))) widenTestWithStringPromotion( MapType(IntegerType, StringType, valueContainsNull = false), MapType(DecimalType.IntDecimal, StringType, valueContainsNull = false), + Some(MapType(DecimalType.IntDecimal, StringType, valueContainsNull = false))) + widenTestWithStringPromotion( + MapType(DecimalType(36, 0), StringType, valueContainsNull = false), + MapType(DecimalType(36, 35), StringType, valueContainsNull = false), None) // StructType @@ -555,7 +567,11 @@ class TypeCoercionSuite extends AnalysisTest { widenTestWithStringPromotion( new StructType().add("num", IntegerType, nullable = false), new StructType().add("num", DecimalType.IntDecimal, nullable = false), - Some(new StructType().add("num", DecimalType.IntDecimal, nullable = true))) + Some(new StructType().add("num", DecimalType.IntDecimal, nullable = false))) + widenTestWithStringPromotion( + new StructType().add("num", DecimalType(36, 0), nullable = false), + new StructType().add("num", DecimalType(36, 35), nullable = false), + Some(new StructType().add("num", DecimalType(38, 35), nullable = true))) widenTestWithStringPromotion( new StructType().add("num", IntegerType), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 5b25bdf907c3a..d9f32c000a885 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -399,21 +399,35 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("casting to fixed-precision decimals") { - // Overflow and rounding for casting to fixed-precision decimals: - // - Values should round with HALF_UP mode by default when you lower scale - // - Values that would overflow the target precision should turn into null - // - Because of this, casts to fixed-precision decimals should be nullable - - assert(cast(123, DecimalType.USER_DEFAULT).nullable === true) + assert(cast(123, DecimalType.USER_DEFAULT).nullable === false) assert(cast(10.03f, DecimalType.SYSTEM_DEFAULT).nullable === true) assert(cast(10.03, DecimalType.SYSTEM_DEFAULT).nullable === true) - assert(cast(Decimal(10.03), DecimalType.SYSTEM_DEFAULT).nullable === true) + assert(cast(Decimal(10.03), DecimalType.SYSTEM_DEFAULT).nullable === false) assert(cast(123, DecimalType(2, 1)).nullable === true) assert(cast(10.03f, DecimalType(2, 1)).nullable === true) assert(cast(10.03, DecimalType(2, 1)).nullable === true) assert(cast(Decimal(10.03), DecimalType(2, 1)).nullable === true) + assert(cast(123, DecimalType.IntDecimal).nullable === false) + assert(cast(10.03f, DecimalType.FloatDecimal).nullable === true) + assert(cast(10.03, DecimalType.DoubleDecimal).nullable === true) + assert(cast(Decimal(10.03), DecimalType(4, 2)).nullable === false) + assert(cast(Decimal(10.03), DecimalType(5, 3)).nullable === false) + + assert(cast(Decimal(10.03), DecimalType(3, 1)).nullable === true) + assert(cast(Decimal(10.03), DecimalType(4, 1)).nullable === false) + assert(cast(Decimal(9.95), DecimalType(2, 1)).nullable === true) + assert(cast(Decimal(9.95), DecimalType(3, 1)).nullable === false) + + assert(cast(Decimal("1003"), DecimalType(3, -1)).nullable === true) + assert(cast(Decimal("1003"), DecimalType(4, -1)).nullable === false) + assert(cast(Decimal("995"), DecimalType(2, -1)).nullable === true) + assert(cast(Decimal("995"), DecimalType(3, -1)).nullable === false) + + assert(cast(true, DecimalType.SYSTEM_DEFAULT).nullable === false) + assert(cast(true, DecimalType(1, 1)).nullable === true) + checkEvaluation(cast(10.03, DecimalType.SYSTEM_DEFAULT), Decimal(10.03)) checkEvaluation(cast(10.03, DecimalType(4, 2)), Decimal(10.03)) @@ -451,6 +465,20 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(Decimal(-9.95), DecimalType(3, 1)), Decimal(-10.0)) checkEvaluation(cast(Decimal(-9.95), DecimalType(1, 0)), null) + checkEvaluation(cast(Decimal("1003"), DecimalType.SYSTEM_DEFAULT), Decimal(1003)) + checkEvaluation(cast(Decimal("1003"), DecimalType(4, 0)), Decimal(1003)) + checkEvaluation(cast(Decimal("1003"), DecimalType(3, -1)), Decimal(1000)) + checkEvaluation(cast(Decimal("1003"), DecimalType(2, -2)), Decimal(1000)) + checkEvaluation(cast(Decimal("1003"), DecimalType(1, -2)), null) + checkEvaluation(cast(Decimal("1003"), DecimalType(2, -1)), null) + checkEvaluation(cast(Decimal("1003"), DecimalType(3, 0)), null) + + checkEvaluation(cast(Decimal("995"), DecimalType(3, 0)), Decimal(995)) + checkEvaluation(cast(Decimal("995"), DecimalType(3, -1)), Decimal(1000)) + checkEvaluation(cast(Decimal("995"), DecimalType(2, -2)), Decimal(1000)) + checkEvaluation(cast(Decimal("995"), DecimalType(2, -1)), null) + checkEvaluation(cast(Decimal("995"), DecimalType(1, -2)), null) + checkEvaluation(cast(Double.NaN, DecimalType.SYSTEM_DEFAULT), null) checkEvaluation(cast(1.0 / 0.0, DecimalType.SYSTEM_DEFAULT), null) checkEvaluation(cast(Float.NaN, DecimalType.SYSTEM_DEFAULT), null) @@ -460,6 +488,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(1.0 / 0.0, DecimalType(2, 1)), null) checkEvaluation(cast(Float.NaN, DecimalType(2, 1)), null) checkEvaluation(cast(1.0f / 0.0f, DecimalType(2, 1)), null) + + checkEvaluation(cast(true, DecimalType(2, 1)), Decimal(1)) + checkEvaluation(cast(true, DecimalType(1, 1)), null) } test("cast from date") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql index db00a18f2e7e9..99f46dd19d0e2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql @@ -148,6 +148,8 @@ SELECT (tinyint_array1 || smallint_array2) ts_array, (smallint_array1 || int_array2) si_array, (int_array1 || bigint_array2) ib_array, + (bigint_array1 || decimal_array2) bd_array, + (decimal_array1 || double_array2) dd_array, (double_array1 || float_array2) df_array, (string_array1 || data_array2) std_array, (timestamp_array1 || string_array2) tst_array, diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql index 119f868cb48e6..1727ee725db2e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql @@ -47,6 +47,18 @@ FROM various_maps; SELECT map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> struct(k, v1, v2)) m FROM various_maps; +SELECT map_zip_with(decimal_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map1, double_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map2, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map2, double_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + SELECT map_zip_with(string_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m FROM various_maps; diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql index fc26397b881b5..69da67fc66fc0 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql @@ -61,6 +61,7 @@ SELECT map_concat(tinyint_map1, smallint_map2) ts_map, map_concat(smallint_map1, int_map2) si_map, map_concat(int_map1, bigint_map2) ib_map, + map_concat(bigint_map1, decimal_map2) bd_map, map_concat(decimal_map1, float_map2) df_map, map_concat(string_map1, date_map2) std_map, map_concat(timestamp_map1, string_map2) tst_map, diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out index be637b66abc86..6c6d3110d7d0d 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out @@ -306,12 +306,14 @@ SELECT (tinyint_array1 || smallint_array2) ts_array, (smallint_array1 || int_array2) si_array, (int_array1 || bigint_array2) ib_array, + (bigint_array1 || decimal_array2) bd_array, + (decimal_array1 || double_array2) dd_array, (double_array1 || float_array2) df_array, (string_array1 || data_array2) std_array, (timestamp_array1 || string_array2) tst_array, (string_array1 || int_array2) sti_array FROM various_arrays -- !query 13 schema -struct,si_array:array,ib_array:array,df_array:array,std_array:array,tst_array:array,sti_array:array> +struct,si_array:array,ib_array:array,bd_array:array,dd_array:array,df_array:array,std_array:array,tst_array:array,sti_array:array> -- !query 13 output -[2,1,3,4] [2,1,3,4] [2,1,3,4] [2.0,1.0,3.0,4.0] ["a","b","2016-03-12","2016-03-11"] ["2016-11-15 20:54:00","2016-11-12 20:54:00","c","d"] ["a","b","3","4"] +[2,1,3,4] [2,1,3,4] [2,1,3,4] [2,1,9223372036854775808,9223372036854775809] [9.223372036854776E18,9.223372036854776E18,3.0,4.0] [2.0,1.0,3.0,4.0] ["a","b","2016-03-12","2016-03-11"] ["2016-11-15 20:54:00","2016-11-12 20:54:00","c","d"] ["a","b","3","4"] diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out index 7f7e2f07b9e74..35740094ba53e 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 12 +-- Number of queries: 16 -- !query 0 @@ -89,54 +89,91 @@ cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_ -- !query 6 -SELECT map_zip_with(string_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +SELECT map_zip_with(decimal_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m FROM various_maps -- !query 6 schema -struct>> +struct>> -- !query 6 output -{"2":{"k":"2","v1":"1","v2":1},"true":{"k":"true","v1":"false","v2":null}} +{2:{"k":2,"v1":null,"v2":1},922337203685477897945456575809789456:{"k":922337203685477897945456575809789456,"v1":922337203685477897945456575809789456,"v2":null}} -- !query 7 -SELECT map_zip_with(string_map2, date_map, (k, v1, v2) -> struct(k, v1, v2)) m +SELECT map_zip_with(decimal_map1, double_map, (k, v1, v2) -> struct(k, v1, v2)) m FROM various_maps -- !query 7 schema -struct>> +struct>> -- !query 7 output -{"2016-03-14":{"k":"2016-03-14","v1":"2016-03-13","v2":2016-03-13}} +{2.0:{"k":2.0,"v1":null,"v2":1.0},9.223372036854779E35:{"k":9.223372036854779E35,"v1":922337203685477897945456575809789456,"v2":null}} -- !query 8 -SELECT map_zip_with(timestamp_map, string_map3, (k, v1, v2) -> struct(k, v1, v2)) m +SELECT map_zip_with(decimal_map2, int_map, (k, v1, v2) -> struct(k, v1, v2)) m FROM various_maps -- !query 8 schema -struct>> +struct<> -- !query 8 output -{"2016-11-15 20:54:00":{"k":"2016-11-15 20:54:00","v1":2016-11-12 20:54:00.0,"v2":null},"2016-11-15 20:54:00.000":{"k":"2016-11-15 20:54:00.000","v1":null,"v2":"2016-11-12 20:54:00.000"}} +org.apache.spark.sql.AnalysisException +cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7 -- !query 9 -SELECT map_zip_with(decimal_map1, string_map4, (k, v1, v2) -> struct(k, v1, v2)) m +SELECT map_zip_with(decimal_map2, double_map, (k, v1, v2) -> struct(k, v1, v2)) m FROM various_maps -- !query 9 schema -struct>> +struct>> -- !query 9 output -{"922337203685477897945456575809789456":{"k":"922337203685477897945456575809789456","v1":922337203685477897945456575809789456,"v2":"text"}} +{2.0:{"k":2.0,"v1":null,"v2":1.0},9.223372036854778:{"k":9.223372036854778,"v1":9.22337203685477897945456575809789456,"v2":null}} -- !query 10 -SELECT map_zip_with(array_map1, array_map2, (k, v1, v2) -> struct(k, v1, v2)) m +SELECT map_zip_with(string_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m FROM various_maps -- !query 10 schema -struct,struct,v1:array,v2:array>>> +struct>> -- !query 10 output -{[1,2]:{"k":[1,2],"v1":[1,2],"v2":[1,2]}} +{"2":{"k":"2","v1":"1","v2":1},"true":{"k":"true","v1":"false","v2":null}} -- !query 11 -SELECT map_zip_with(struct_map1, struct_map2, (k, v1, v2) -> struct(k, v1, v2)) m +SELECT map_zip_with(string_map2, date_map, (k, v1, v2) -> struct(k, v1, v2)) m FROM various_maps -- !query 11 schema -struct,struct,v1:struct,v2:struct>>> +struct>> -- !query 11 output +{"2016-03-14":{"k":"2016-03-14","v1":"2016-03-13","v2":2016-03-13}} + + +-- !query 12 +SELECT map_zip_with(timestamp_map, string_map3, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 12 schema +struct>> +-- !query 12 output +{"2016-11-15 20:54:00":{"k":"2016-11-15 20:54:00","v1":2016-11-12 20:54:00.0,"v2":null},"2016-11-15 20:54:00.000":{"k":"2016-11-15 20:54:00.000","v1":null,"v2":"2016-11-12 20:54:00.000"}} + + +-- !query 13 +SELECT map_zip_with(decimal_map1, string_map4, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 13 schema +struct>> +-- !query 13 output +{"922337203685477897945456575809789456":{"k":"922337203685477897945456575809789456","v1":922337203685477897945456575809789456,"v2":"text"}} + + +-- !query 14 +SELECT map_zip_with(array_map1, array_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 14 schema +struct,struct,v1:array,v2:array>>> +-- !query 14 output +{[1,2]:{"k":[1,2],"v1":[1,2],"v2":[1,2]}} + + +-- !query 15 +SELECT map_zip_with(struct_map1, struct_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 15 schema +struct,struct,v1:struct,v2:struct>>> +-- !query 15 output {{"col1":1,"col2":2}:{"k":{"col1":1,"col2":2},"v1":{"col1":1,"col2":2},"v2":{"col1":1,"col2":2}}} diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out index d352b7284ae87..efc88e47209a6 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out @@ -76,6 +76,7 @@ SELECT map_concat(tinyint_map1, smallint_map2) ts_map, map_concat(smallint_map1, int_map2) si_map, map_concat(int_map1, bigint_map2) ib_map, + map_concat(bigint_map1, decimal_map2) bd_map, map_concat(decimal_map1, float_map2) df_map, map_concat(string_map1, date_map2) std_map, map_concat(timestamp_map1, string_map2) tst_map, @@ -83,9 +84,9 @@ SELECT map_concat(int_string_map1, tinyint_map2) istt_map FROM various_maps -- !query 2 schema -struct,si_map:map,ib_map:map,df_map:map,std_map:map,tst_map:map,sti_map:map,istt_map:map> +struct,si_map:map,ib_map:map,bd_map:map,df_map:map,std_map:map,tst_map:map,sti_map:map,istt_map:map> -- !query 2 output -{1:2,3:4} {1:2,7:8} {4:6,8:9} {3.0:4.0,9.223372036854776E18:9.223372036854776E18} {"2016-03-12":"2016-03-11","a":"b"} {"2016-11-15 20:54:00":"2016-11-12 20:54:00","c":"d"} {"7":"8","a":"b"} {1:"a",3:"4"} +{1:2,3:4} {1:2,7:8} {4:6,8:9} {6:7,9223372036854775808:9223372036854775809} {3.0:4.0,9.223372036854776E18:9.223372036854776E18} {"2016-03-12":"2016-03-11","a":"b"} {"2016-11-15 20:54:00":"2016-11-12 20:54:00","c":"d"} {"7":"8","a":"b"} {1:"a",3:"4"} -- !query 3 From 8ed0449285507459bbd00752338ed3242427a14f Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 23 Aug 2018 12:14:27 -0700 Subject: [PATCH 1466/2461] [SPARK-25204][SS] Fix race in rate source test. ## What changes were proposed in this pull request? Fix a race in the rate source tests. We need a better way of testing restart behavior. ## How was this patch tested? unit test Closes #22191 from jose-torres/racetest. Authored-by: Jose Torres Signed-off-by: Tathagata Das --- .../sources/RateStreamProviderSuite.scala | 40 +++++++++++++++++-- .../spark/sql/streaming/StreamTest.scala | 5 ++- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index 9c1756d68ccc4..dd74af873c2e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -22,6 +22,7 @@ import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ @@ -81,12 +82,43 @@ class RateSourceSuite extends StreamTest { .load() testStream(input)( AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + + test("microbatch - restart") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .load() + .select('value) + + var streamDuration = 0 + + // Microbatch rate stream offsets contain the number of seconds since the beginning of + // the stream. + def updateStreamDurationFromOffset(s: StreamExecution, expectedMin: Int): Unit = { + streamDuration = s.lastProgress.sources(0).endOffset.toInt + assert(streamDuration >= expectedMin) + } + + // We have to use the lambda version of CheckAnswer because we don't know the right range + // until we see the last offset. + def expectedResultsFromDuration(rows: Seq[Row]): Unit = { + assert(rows.map(_.getLong(0)).sorted == (0 until (streamDuration * 10))) + } + + testStream(input)( + StartStream(), + Execute(_.awaitOffset(0, LongOffset(2), streamingTimeout.toMillis)), StopStream, + Execute(updateStreamDurationFromOffset(_, 2)), + CheckAnswer(expectedResultsFromDuration _), StartStream(), - // Advance 2 seconds because creating a new RateSource will also create a new ManualClock - AdvanceRateManualClock(seconds = 2), - CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + Execute(_.awaitOffset(0, LongOffset(4), streamingTimeout.toMillis)), + StopStream, + Execute(updateStreamDurationFromOffset(_, 4)), + CheckAnswer(expectedResultsFromDuration _) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index cd9b892eca1f6..491dc34afa143 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -735,7 +735,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) => - val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) + val sparkAnswer = currentStream match { + case null => fetchStreamAnswer(lastStream, lastOnly) + case s => fetchStreamAnswer(s, lastOnly) + } try { globalCheckFunction(sparkAnswer) } catch { From b5e11880871d6ef31efe3ec42b3caa0fc403e71b Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 23 Aug 2018 16:17:27 -0700 Subject: [PATCH 1467/2461] [SPARK-25124][ML] VectorSizeHint setSize and getSize don't return values ## What changes were proposed in this pull request? In feature.py, VectorSizeHint setSize and getSize don't return value. Add return. ## How was this patch tested? I tested the changes on my local. Closes #22136 from huaxingao/spark-25124. Authored-by: Huaxin Gao Signed-off-by: Joseph K. Bradley --- python/pyspark/ml/feature.py | 4 ++-- python/pyspark/ml/tests.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index ddba7389145e3..760aa82168f5a 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -3843,12 +3843,12 @@ def setParams(self, inputCol=None, size=None, handleInvalid="error"): @since("2.3.0") def getSize(self): """ Gets size param, the size of vectors in `inputCol`.""" - self.getOrDefault(self.size) + return self.getOrDefault(self.size) @since("2.3.0") def setSize(self, value): """ Sets size param, the size of vectors in `inputCol`.""" - self._set(size=value) + return self._set(size=value) if __name__ == "__main__": diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index a770bad32ecd2..5c87d1de4139b 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -844,6 +844,23 @@ def test_string_indexer_from_labels(self): .select(model_default.getOrDefault(model_default.outputCol)).collect() self.assertEqual(len(transformed_list), 5) + def test_vector_size_hint(self): + df = self.spark.createDataFrame( + [(0, Vectors.dense([0.0, 10.0, 0.5])), + (1, Vectors.dense([1.0, 11.0, 0.5, 0.6])), + (2, Vectors.dense([2.0, 12.0]))], + ["id", "vector"]) + + sizeHint = VectorSizeHint( + inputCol="vector", + handleInvalid="skip") + sizeHint.setSize(3) + self.assertEqual(sizeHint.getSize(), 3) + + output = sizeHint.transform(df).head().vector + expected = DenseVector([0.0, 10.0, 0.5]) + self.assertEqual(output, expected) + class HasInducedError(Params): From 0ce09ec54ec3cb03a44872edd546703d0e0b10f5 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 24 Aug 2018 09:31:06 +0800 Subject: [PATCH 1468/2461] [SPARK-25205][CORE] Fix typo in spark.network.crypto.keyFactoryIterations Closes #22195 from squito/SPARK-25205. Authored-by: Imran Rashid Signed-off-by: hyukjinkwon --- .../main/java/org/apache/spark/network/util/TransportConf.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 91497b9492219..34e4bb5912dcb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -209,7 +209,7 @@ public String keyFactoryAlgorithm() { * (128 bits by default), which is not generally the case with user passwords. */ public int keyFactoryIterations() { - return conf.getInt("spark.networy.crypto.keyFactoryIterations", 1024); + return conf.getInt("spark.network.crypto.keyFactoryIterations", 1024); } /** From b88ddb8a83f88a68ac2ab45fc0b1bd8e8951d700 Mon Sep 17 00:00:00 2001 From: s71955 Date: Fri, 24 Aug 2018 09:54:30 +0800 Subject: [PATCH 1469/2461] [SPARK-23425][SQL] Support wildcard in HDFS path for load table command ## What changes were proposed in this pull request? **Problem statement** load data command with hdfs file paths consists of wild card strings like * are not working eg: "load data inpath 'hdfs://hacluster/user/ext* into table t1" throws Analysis exception while executing this query ![wildcard_issue](https://user-images.githubusercontent.com/12999161/42673744-9f5c0c16-8621-11e8-8d28-cdc41bbe6efe.PNG) **Analysis -** Currently fs.exists() API which is used for path validation in load command API cannot resolve the path with wild card pattern, To mitigate this problem i am using globStatus() API another api which can resolve the paths with hdfs supported wildcards like *,? etc(inline with hive wildcard support). **Improvement identified as part of this issue -** Currently system wont support wildcard character to be used for folder level path in a local file system. This PR has handled this scenario, the same globStatus API will unify the validation logic of local and non local file systems, this will ensure the behavior consistency between the hdfs and local file path in load command. with this improvement user will be able to use a wildcard character in folder level path of a local file system in load command inline with hive behaviour, in older versions user can use wildcards only in file path of the local file system if they use in folder path system use to give an error by mentioning that not supported. eg: load data local inpath '/localfilesystem/folder* into table t1 ## How was this patch tested? a) Manually tested by executing test-cases in HDFS yarn cluster. Reports is been attached in below section. b) Existing test-case can verify the impact and functionality for local file path scenarios c) A test-case is been added for verifying the functionality when wild card is been used in folder level path of a local file system ## Test Results Note: all ip's were updated to localhost for security reasons. HDFS path details ``` vm1:/opt/ficlient # hadoop fs -ls /user/data/sujith1 Found 2 items -rw-r--r-- 3 shahid hadoop 4802 2018-03-26 15:45 /user/data/sujith1/typeddata60.txt -rw-r--r-- 3 shahid hadoop 4883 2018-03-26 15:45 /user/data/sujith1/typeddata61.txt vm1:/opt/ficlient # hadoop fs -ls /user/data/sujith2 Found 2 items -rw-r--r-- 3 shahid hadoop 4802 2018-03-26 15:45 /user/data/sujith2/typeddata60.txt -rw-r--r-- 3 shahid hadoop 4883 2018-03-26 15:45 /user/data/sujith2/typeddata61.txt ``` positive scenario by specifying complete file path to know about record size ``` 0: jdbc:hive2://localhost:22550/default> create table wild_spark (time timestamp, name string, isright boolean, datetoday date, num binary, height double, score float, decimaler decimal(10,0), id tinyint, age int, license bigint, length smallint) row format delimited fields terminated by ','; +---------+--+ | Result | +---------+--+ +---------+--+ No rows selected (1.217 seconds) 0: jdbc:hive2://localhost:22550/default> load data inpath '/user/data/sujith1/typeddata60.txt' into table wild_spark; +---------+--+ | Result | +---------+--+ +---------+--+ No rows selected (4.236 seconds) 0: jdbc:hive2://localhost:22550/default> load data inpath '/user/data/sujith1/typeddata61.txt' into table wild_spark; +---------+--+ | Result | +---------+--+ +---------+--+ No rows selected (0.602 seconds) 0: jdbc:hive2://localhost:22550/default> select count(*) from wild_spark; +-----------+--+ | count(1) | +-----------+--+ | 121 | +-----------+--+ 1 row selected (18.529 seconds) 0: jdbc:hive2://localhost:22550/default> ``` With wild card character in file path ``` 0: jdbc:hive2://localhost:22550/default> create table spark_withWildChar (time timestamp, name string, isright boolean, datetoday date, num binary, height double, score float, decimaler decimal(10,0), id tinyint, age int, license bigint, length smallint) row format delimited fields terminated by ','; +---------+--+ | Result | +---------+--+ +---------+--+ No rows selected (0.409 seconds) 0: jdbc:hive2://localhost:22550/default> load data inpath '/user/data/sujith1/type*' into table spark_withWildChar; +---------+--+ | Result | +---------+--+ +---------+--+ No rows selected (1.502 seconds) 0: jdbc:hive2://localhost:22550/default> select count(*) from spark_withWildChar; +-----------+--+ | count(1) | +-----------+--+ | 121 | +-----------+--+ ``` with ? wild card scenario ``` 0: jdbc:hive2://localhost:22550/default> create table spark_withWildChar_DiffChar (time timestamp, name string, isright boolean, datetoday date, num binary, height double, score float, decimaler decimal(10,0), id tinyint, age int, license bigint, length smallint) row format delimited fields terminated by ','; +---------+--+ | Result | +---------+--+ +---------+--+ No rows selected (0.489 seconds) 0: jdbc:hive2://localhost:22550/default> load data inpath '/user/data/sujith1/?ypeddata60.txt' into table spark_withWildChar_DiffChar; +---------+--+ | Result | +---------+--+ +---------+--+ No rows selected (1.152 seconds) 0: jdbc:hive2://localhost:22550/default> load data inpath '/user/data/sujith1/?ypeddata61.txt' into table spark_withWildChar_DiffChar; +---------+--+ | Result | +---------+--+ +---------+--+ No rows selected (0.644 seconds) 0: jdbc:hive2://localhost:22550/default> select count(*) from spark_withWildChar_DiffChar; +-----------+--+ | count(1) | +-----------+--+ | 121 | +-----------+--+ 1 row selected (16.078 seconds) ``` with folder level wild card scenario ``` 0: jdbc:hive2://localhost:22550/default> create table spark_withWildChar_folderlevel (time timestamp, name string, isright boolean, datetoday date, num binary, height double, score float, decimaler decimal(10,0), id tinyint, age int, license bigint, length smallint) row format delimited fields terminated by ','; +---------+--+ | Result | +---------+--+ +---------+--+ No rows selected (0.489 seconds) 0: jdbc:hive2://localhost:22550/default> load data inpath '/user/data/suji*/*' into table spark_withWildChar_folderlevel; +---------+--+ | Result | +---------+--+ +---------+--+ No rows selected (1.152 seconds) 0: jdbc:hive2://localhost:22550/default> select count(*) from spark_withWildChar_folderlevel; +-----------+--+ | count(1) | +-----------+--+ | 242 | +-----------+--+ 1 row selected (16.078 seconds) ``` Negative scenario invalid path ``` 0: jdbc:hive2://localhost:22550/default> load data inpath '/user/data/sujiinvalid*/*' into table spark_withWildChar_folder; Error: org.apache.spark.sql.AnalysisException: LOAD DATA input path does not exist: /user/data/sujiinvalid*/*; (state=,code=0) 0: jdbc:hive2://localhost:22550/default> ``` Hive Test results- file level ``` 0: jdbc:hive2://localhost:21066/> create table hive_withWildChar_files (time timestamp, name string, isright boolean, datetoday date, num binary, height double, score float, decimaler decimal(10,0), id tinyint, age int, license bigint, length smallint) stored as TEXTFILE; No rows affected (0.723 seconds) 0: jdbc:hive2://localhost:21066/> load data inpath '/user/data/sujith1/type*' into table hive_withWildChar_files; INFO : Loading data to table default.hive_withwildchar_files from hdfs://hacluster/user/sujith1/type* No rows affected (0.682 seconds) 0: jdbc:hive2://localhost:21066/> select count(*) from hive_withWildChar_files; +------+--+ | _c0 | +------+--+ | 121 | +------+--+ 1 row selected (50.832 seconds) ``` Hive Test results- folder level ``` 0: jdbc:hive2://localhost:21066/> create table hive_withWildChar_folder (time timestamp, name string, isright boolean, datetoday date, num binary, height double, score float, decimaler decimal(10,0), id tinyint, age int, license bigint, length smallint) stored as TEXTFILE; No rows affected (0.459 seconds) 0: jdbc:hive2://localhost:21066/> load data inpath '/user/data/suji*/*' into table hive_withWildChar_folder; INFO : Loading data to table default.hive_withwildchar_folder from hdfs://hacluster/user/data/suji*/* No rows affected (0.76 seconds) 0: jdbc:hive2://localhost:21066/> select count(*) from hive_withWildChar_folder; +------+--+ | _c0 | +------+--+ | 242 | +------+--+ 1 row selected (46.483 seconds) ``` Closes #20611 from sujith71955/master_wldcardsupport. Lead-authored-by: s71955 Co-authored-by: sujith71955 Signed-off-by: hyukjinkwon --- .../spark/sql/execution/command/tables.scala | 154 ++++++++---------- .../sql/hive/execution/SQLQuerySuite.scala | 55 ++++++- 2 files changed, 119 insertions(+), 90 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index f4dede9fcc899..2eca1c40a5b3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -18,14 +18,14 @@ package org.apache.spark.sql.execution.command import java.io.File -import java.net.URI +import java.net.{URI, URISyntaxException} import java.nio.file.FileSystems import scala.collection.mutable.ArrayBuffer import scala.util.Try import scala.util.control.NonFatal -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileContext, FsConstants, Path} import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier @@ -303,94 +303,44 @@ case class LoadDataCommand( s"partitioned, but a partition spec was provided.") } } - - val loadPath = + val loadPath = { if (isLocal) { - val uri = Utils.resolveURI(path) - val file = new File(uri.getPath) - val exists = if (file.getAbsolutePath.contains("*")) { - val fileSystem = FileSystems.getDefault - val dir = file.getParentFile.getAbsolutePath - if (dir.contains("*")) { - throw new AnalysisException( - s"LOAD DATA input path allows only filename wildcard: $path") - } - - // Note that special characters such as "*" on Windows are not allowed as a path. - // Calling `WindowsFileSystem.getPath` throws an exception if there are in the path. - val dirPath = fileSystem.getPath(dir) - val pathPattern = new File(dirPath.toAbsolutePath.toString, file.getName).toURI.getPath - val safePathPattern = if (Utils.isWindows) { - // On Windows, the pattern should not start with slashes for absolute file paths. - pathPattern.stripPrefix("/") - } else { - pathPattern - } - val files = new File(dir).listFiles() - if (files == null) { - false - } else { - val matcher = fileSystem.getPathMatcher("glob:" + safePathPattern) - files.exists(f => matcher.matches(fileSystem.getPath(f.getAbsolutePath))) - } - } else { - new File(file.getAbsolutePath).exists() - } - if (!exists) { - throw new AnalysisException(s"LOAD DATA input path does not exist: $path") - } - uri + val localFS = FileContext.getLocalFSFileContext() + makeQualified(FsConstants.LOCAL_FS_URI, localFS.getWorkingDirectory(), new Path(path)) } else { - val uri = new URI(path) - val hdfsUri = if (uri.getScheme() != null && uri.getAuthority() != null) { - uri - } else { - // Follow Hive's behavior: - // If no schema or authority is provided with non-local inpath, - // we will use hadoop configuration "fs.defaultFS". - val defaultFSConf = sparkSession.sessionState.newHadoopConf().get("fs.defaultFS") - val defaultFS = if (defaultFSConf == null) { - new URI("") - } else { - new URI(defaultFSConf) - } - - val scheme = if (uri.getScheme() != null) { - uri.getScheme() - } else { - defaultFS.getScheme() - } - val authority = if (uri.getAuthority() != null) { - uri.getAuthority() - } else { - defaultFS.getAuthority() - } - - if (scheme == null) { - throw new AnalysisException( - s"LOAD DATA: URI scheme is required for non-local input paths: '$path'") - } - - // Follow Hive's behavior: - // If LOCAL is not specified, and the path is relative, - // then the path is interpreted relative to "/user/" - val uriPath = uri.getPath() - val absolutePath = if (uriPath != null && uriPath.startsWith("/")) { - uriPath - } else { - s"/user/${System.getProperty("user.name")}/$uriPath" - } - new URI(scheme, authority, absolutePath, uri.getQuery(), uri.getFragment()) - } - val hadoopConf = sparkSession.sessionState.newHadoopConf() - val srcPath = new Path(hdfsUri) - val fs = srcPath.getFileSystem(hadoopConf) - if (!fs.exists(srcPath)) { - throw new AnalysisException(s"LOAD DATA input path does not exist: $path") - } - hdfsUri + val loadPath = new Path(path) + // Follow Hive's behavior: + // If no schema or authority is provided with non-local inpath, + // we will use hadoop configuration "fs.defaultFS". + val defaultFSConf = sparkSession.sessionState.newHadoopConf().get("fs.defaultFS") + val defaultFS = if (defaultFSConf == null) new URI("") else new URI(defaultFSConf) + // Follow Hive's behavior: + // If LOCAL is not specified, and the path is relative, + // then the path is interpreted relative to "/user/" + val uriPath = new Path(s"/user/${System.getProperty("user.name")}/") + // makeQualified() will ignore the query parameter part while creating a path, so the + // entire string will be considered while making a Path instance,this is mainly done + // by considering the wild card scenario in mind.as per old logic query param is + // been considered while creating URI instance and if path contains wild card char '?' + // the remaining charecters after '?' will be removed while forming URI instance + makeQualified(defaultFS, uriPath, loadPath) } - + } + val fs = loadPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) + // This handling is because while resolving the invalid URLs starting with file:/// + // system throws IllegalArgumentException from globStatus API,so in order to handle + // such scenarios this code is added in try catch block and after catching the + // runtime exception a generic error will be displayed to the user. + try { + val fileStatus = fs.globStatus(loadPath) + if (fileStatus == null || fileStatus.isEmpty) { + throw new AnalysisException(s"LOAD DATA input path does not exist: $path") + } + } catch { + case e: IllegalArgumentException => + log.warn(s"Exception while validating the load path $path ", e) + throw new AnalysisException(s"LOAD DATA input path does not exist: $path") + } if (partition.nonEmpty) { catalog.loadPartition( targetTable.identifier, @@ -413,6 +363,36 @@ case class LoadDataCommand( CommandUtils.updateTableStats(sparkSession, targetTable) Seq.empty[Row] } + + /** + * Returns a qualified path object. Method ported from org.apache.hadoop.fs.Path class. + * + * @param defaultUri default uri corresponding to the filesystem provided. + * @param workingDir the working directory for the particular child path wd-relative names. + * @param path Path instance based on the path string specified by the user. + * @return qualified path object + */ + private def makeQualified(defaultUri: URI, workingDir: Path, path: Path): Path = { + val pathUri = if (path.isAbsolute()) path.toUri() else new Path(workingDir, path).toUri() + if (pathUri.getScheme == null || pathUri.getAuthority == null && + defaultUri.getAuthority != null) { + val scheme = if (pathUri.getScheme == null) defaultUri.getScheme else pathUri.getScheme + val authority = if (pathUri.getAuthority == null) { + if (defaultUri.getAuthority == null) "" else defaultUri.getAuthority + } else { + pathUri.getAuthority + } + try { + val newUri = new URI(scheme, authority, pathUri.getPath, pathUri.getFragment) + new Path(newUri) + } catch { + case e: URISyntaxException => + throw new IllegalArgumentException(e) + } + } else { + path + } + } } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 13aa2b843667c..20c4c36c05091 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1912,11 +1912,60 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { sql("LOAD DATA LOCAL INPATH '/non-exist-folder/*part*' INTO TABLE load_t") }.getMessage assert(m.contains("LOAD DATA input path does not exist")) + } + } + } - val m2 = intercept[AnalysisException] { - sql(s"LOAD DATA LOCAL INPATH '$path*/*part*' INTO TABLE load_t") + test("Support wildcard character in folderlevel for LOAD DATA LOCAL INPATH") { + withTempDir { dir => + val path = dir.toURI.toString.stripSuffix("/") + val dirPath = dir.getAbsoluteFile + for (i <- 1 to 3) { + Files.write(s"$i", new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8) + } + withTable("load_t_folder_wildcard") { + sql("CREATE TABLE load_t (a STRING)") + sql(s"LOAD DATA LOCAL INPATH '${ + path.substring(0, path.length - 1) + .concat("*") + }/' INTO TABLE load_t") + checkAnswer(sql("SELECT * FROM load_t"), Seq(Row("1"), Row("2"), Row("3"))) + val m = intercept[AnalysisException] { + sql(s"LOAD DATA LOCAL INPATH '${ + path.substring(0, path.length - 1).concat("_invalid_dir") concat ("*") + }/' INTO TABLE load_t") }.getMessage - assert(m2.contains("LOAD DATA input path allows only filename wildcard")) + assert(m.contains("LOAD DATA input path does not exist")) + } + } + } + + test("SPARK-17796 Support wildcard '?'char in middle as part of local file path") { + withTempDir { dir => + val path = dir.toURI.toString.stripSuffix("/") + val dirPath = dir.getAbsoluteFile + for (i <- 1 to 3) { + Files.write(s"$i", new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8) + } + withTable("load_t1") { + sql("CREATE TABLE load_t1 (a STRING)") + sql(s"LOAD DATA LOCAL INPATH '$path/part-r-0000?' INTO TABLE load_t1") + checkAnswer(sql("SELECT * FROM load_t1"), Seq(Row("1"), Row("2"), Row("3"))) + } + } + } + + test("SPARK-17796 Support wildcard '?'char in start as part of local file path") { + withTempDir { dir => + val path = dir.toURI.toString.stripSuffix("/") + val dirPath = dir.getAbsoluteFile + for (i <- 1 to 3) { + Files.write(s"$i", new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8) + } + withTable("load_t2") { + sql("CREATE TABLE load_t2 (a STRING)") + sql(s"LOAD DATA LOCAL INPATH '$path/?art-r-00001' INTO TABLE load_t2") + checkAnswer(sql("SELECT * FROM load_t2"), Seq(Row("1"))) } } } From cd6dff78be2739fab60487bc3145118208f46b9e Mon Sep 17 00:00:00 2001 From: Bogdan Raducanu Date: Fri, 24 Aug 2018 04:13:07 +0200 Subject: [PATCH 1470/2461] [SPARK-25209][SQL] Avoid deserializer check in Dataset.apply when Dataset is actually DataFrame ## What changes were proposed in this pull request? Dataset.apply calls dataset.deserializer (to provide an early error) which ends up calling the full Analyzer on the deserializer. This can take tens of milliseconds, depending on how big the plan is. Since Dataset.apply is called for many Dataset operations such as Dataset.where it can be a significant overhead for short queries. According to a comment in the PR that introduced this check, we can at least remove this check for DataFrames: https://github.com/apache/spark/pull/20402#discussion_r164338267 ## How was this patch tested? Existing tests + manual benchmark Author: Bogdan Raducanu Closes #22201 from bogdanrdc/deserializer-fix. --- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f65948d39a1cc..367b98563e0bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -65,7 +65,12 @@ private[sql] object Dataset { val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]]) // Eagerly bind the encoder so we verify that the encoder matches the underlying // schema. The user will get an error if this is not the case. - dataset.deserializer + // optimization: it is guaranteed that [[InternalRow]] can be converted to [[Row]] so + // do not do this check in that case. this check can be expensive since it requires running + // the whole [[Analyzer]] to resolve the deserializer + if (dataset.exprEnc.clsTag.runtimeClass != classOf[Row]) { + dataset.deserializer + } dataset } From f2d35427eedeacceb6edb8a51974a7e8bbb94bc2 Mon Sep 17 00:00:00 2001 From: Michael Allman Date: Thu, 23 Aug 2018 21:31:10 -0700 Subject: [PATCH 1471/2461] [SPARK-4502][SQL] Parquet nested column pruning - foundation (Link to Jira: https://issues.apache.org/jira/browse/SPARK-4502) _N.B. This is a restart of PR #16578 which includes a subset of that code. Relevant review comments from that PR should be considered incorporated by reference. Please avoid duplication in review by reviewing that PR first. The summary below is an edited copy of the summary of the previous PR._ ## What changes were proposed in this pull request? One of the hallmarks of a column-oriented data storage format is the ability to read data from a subset of columns, efficiently skipping reads from other columns. Spark has long had support for pruning unneeded top-level schema fields from the scan of a parquet file. For example, consider a table, `contacts`, backed by parquet with the following Spark SQL schema: ``` root |-- name: struct | |-- first: string | |-- last: string |-- address: string ``` Parquet stores this table's data in three physical columns: `name.first`, `name.last` and `address`. To answer the query ```SQL select address from contacts ``` Spark will read only from the `address` column of parquet data. However, to answer the query ```SQL select name.first from contacts ``` Spark will read `name.first` and `name.last` from parquet. This PR modifies Spark SQL to support a finer-grain of schema pruning. With this patch, Spark reads only the `name.first` column to answer the previous query. ### Implementation There are two main components of this patch. First, there is a `ParquetSchemaPruning` optimizer rule for gathering the required schema fields of a `PhysicalOperation` over a parquet file, constructing a new schema based on those required fields and rewriting the plan in terms of that pruned schema. The pruned schema fields are pushed down to the parquet requested read schema. `ParquetSchemaPruning` uses a new `ProjectionOverSchema` extractor for rewriting a catalyst expression in terms of a pruned schema. Second, the `ParquetRowConverter` has been patched to ensure the ordinals of the parquet columns read are correct for the pruned schema. `ParquetReadSupport` has been patched to address a compatibility mismatch between Spark's built in vectorized reader and the parquet-mr library's reader. ### Limitation Among the complex Spark SQL data types, this patch supports parquet column pruning of nested sequences of struct fields only. ## How was this patch tested? Care has been taken to ensure correctness and prevent regressions. A more advanced version of this patch incorporating optimizations for rewriting queries involving aggregations and joins has been running on a production Spark cluster at VideoAmp for several years. In that time, one bug was found and fixed early on, and we added a regression test for that bug. We forward-ported this patch to Spark master in June 2016 and have been running this patch against Spark 2.x branches on ad-hoc clusters since then. Closes #21320 from mallman/spark-4502-parquet_column_pruning-foundation. Lead-authored-by: Michael Allman Co-authored-by: Adam Jacques Co-authored-by: Michael Allman Signed-off-by: Xiao Li --- .../apache/spark/sql/internal/SQLConf.scala | 14 +- .../sql/catalyst/SchemaPruningTest.scala | 45 ++ .../sql/execution/GetStructFieldObject.scala | 33 ++ .../sql/execution/ProjectionOverSchema.scala | 59 +++ .../spark/sql/execution/SelectedField.scala | 134 ++++++ .../spark/sql/execution/SparkOptimizer.scala | 4 +- .../parquet/ParquetSchemaPruning.scala | 257 ++++++++++ .../sql/execution/SelectedFieldSuite.scala | 455 ++++++++++++++++++ .../parquet/ParquetQuerySuite.scala | 6 +- .../parquet/ParquetSchemaPruningSuite.scala | 311 ++++++++++++ 10 files changed, 1313 insertions(+), 5 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SchemaPruningTest.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/GetStructFieldObject.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/SelectedFieldSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index df2caff902648..ef3ce98fd7add 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1419,8 +1419,18 @@ object SQLConf { "issues. Turn on this config to insert a local sort before actually doing repartition " + "to generate consistent repartition results. The performance of repartition() may go " + "down since we insert extra local sort before it.") + .booleanConf + .createWithDefault(true) + + val NESTED_SCHEMA_PRUNING_ENABLED = + buildConf("spark.sql.nestedSchemaPruning.enabled") + .internal() + .doc("Prune nested fields from a logical relation's output which are unnecessary in " + + "satisfying a query. This optimization allows columnar file format readers to avoid " + + "reading unnecessary nested column data. Currently Parquet is the only data source that " + + "implements this optimization.") .booleanConf - .createWithDefault(true) + .createWithDefault(false) val TOP_K_SORT_FALLBACK_THRESHOLD = buildConf("spark.sql.execution.topKSortFallbackThreshold") @@ -1895,6 +1905,8 @@ class SQLConf extends Serializable with Logging { def partitionOverwriteMode: PartitionOverwriteMode.Value = PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) + def nestedSchemaPruningEnabled: Boolean = getConf(NESTED_SCHEMA_PRUNING_ENABLED) + def csvColumnPruning: Boolean = getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING) def legacySizeOfNull: Boolean = getConf(SQLConf.LEGACY_SIZE_OF_NULL) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SchemaPruningTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SchemaPruningTest.scala new file mode 100644 index 0000000000000..68e76fc013c18 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SchemaPruningTest.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.internal.SQLConf.NESTED_SCHEMA_PRUNING_ENABLED + +/** + * A PlanTest that ensures that all tests in this suite are run with nested schema pruning enabled. + * Remove this trait once the default value of SQLConf.NESTED_SCHEMA_PRUNING_ENABLED is set to true. + */ +private[sql] trait SchemaPruningTest extends PlanTest with BeforeAndAfterAll { + private var originalConfSchemaPruningEnabled = false + + override protected def beforeAll(): Unit = { + originalConfSchemaPruningEnabled = conf.nestedSchemaPruningEnabled + conf.setConf(NESTED_SCHEMA_PRUNING_ENABLED, true) + super.beforeAll() + } + + override protected def afterAll(): Unit = { + try { + super.afterAll() + } finally { + conf.setConf(NESTED_SCHEMA_PRUNING_ENABLED, originalConfSchemaPruningEnabled) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GetStructFieldObject.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GetStructFieldObject.scala new file mode 100644 index 0000000000000..c88b2f8c034fc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GetStructFieldObject.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField} +import org.apache.spark.sql.types.StructField + +/** + * A Scala extractor that extracts the child expression and struct field from a [[GetStructField]]. + * This is in contrast to the [[GetStructField]] case class extractor which returns the field + * ordinal instead of the field itself. + */ +private[execution] object GetStructFieldObject { + def unapply(getStructField: GetStructField): Option[(Expression, StructField)] = + Some(( + getStructField.child, + getStructField.childSchema(getStructField.ordinal))) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala new file mode 100644 index 0000000000000..2236f18b0da12 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * A Scala extractor that projects an expression over a given schema. Data types, + * field indexes and field counts of complex type extractors and attributes + * are adjusted to fit the schema. All other expressions are left as-is. This + * class is motivated by columnar nested schema pruning. + */ +private[execution] case class ProjectionOverSchema(schema: StructType) { + private val fieldNames = schema.fieldNames.toSet + + def unapply(expr: Expression): Option[Expression] = getProjection(expr) + + private def getProjection(expr: Expression): Option[Expression] = + expr match { + case a: AttributeReference if fieldNames.contains(a.name) => + Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier)) + case GetArrayItem(child, arrayItemOrdinal) => + getProjection(child).map { projection => GetArrayItem(projection, arrayItemOrdinal) } + case a: GetArrayStructFields => + getProjection(a.child).map(p => (p, p.dataType)).map { + case (projection, ArrayType(projSchema @ StructType(_), _)) => + GetArrayStructFields(projection, + projSchema(a.field.name), + projSchema.fieldIndex(a.field.name), + projSchema.size, + a.containsNull) + } + case GetMapValue(child, key) => + getProjection(child).map { projection => GetMapValue(projection, key) } + case GetStructFieldObject(child, field: StructField) => + getProjection(child).map(p => (p, p.dataType)).map { + case (projection, projSchema: StructType) => + GetStructField(projection, projSchema.fieldIndex(field.name)) + } + case _ => + None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala new file mode 100644 index 0000000000000..0e7c593f9fb67 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * A Scala extractor that builds a [[org.apache.spark.sql.types.StructField]] from a Catalyst + * complex type extractor. For example, consider a relation with the following schema: + * + * {{{ + * root + * |-- name: struct (nullable = true) + * | |-- first: string (nullable = true) + * | |-- last: string (nullable = true) + * }}} + * + * Further, suppose we take the select expression `name.first`. This will parse into an + * `Alias(child, "first")`. Ignoring the alias, `child` matches the following pattern: + * + * {{{ + * GetStructFieldObject( + * AttributeReference("name", StructType(_), _, _), + * StructField("first", StringType, _, _)) + * }}} + * + * [[SelectedField]] converts that expression into + * + * {{{ + * StructField("name", StructType(Array(StructField("first", StringType)))) + * }}} + * + * by mapping each complex type extractor to a [[org.apache.spark.sql.types.StructField]] with the + * same name as its child (or "parent" going right to left in the select expression) and a data + * type appropriate to the complex type extractor. In our example, the name of the child expression + * is "name" and its data type is a [[org.apache.spark.sql.types.StructType]] with a single string + * field named "first". + * + * @param expr the top-level complex type extractor + */ +private[execution] object SelectedField { + def unapply(expr: Expression): Option[StructField] = { + // If this expression is an alias, work on its child instead + val unaliased = expr match { + case Alias(child, _) => child + case expr => expr + } + selectField(unaliased, None) + } + + private def selectField(expr: Expression, fieldOpt: Option[StructField]): Option[StructField] = { + expr match { + // No children. Returns a StructField with the attribute name or None if fieldOpt is None. + case AttributeReference(name, dataType, nullable, metadata) => + fieldOpt.map(field => + StructField(name, wrapStructType(dataType, field), nullable, metadata)) + // Handles case "expr0.field[n]", where "expr0" is of struct type and "expr0.field" is of + // array type. + case GetArrayItem(x @ GetStructFieldObject(child, field @ StructField(name, + dataType, nullable, metadata)), _) => + val childField = fieldOpt.map(field => StructField(name, + wrapStructType(dataType, field), nullable, metadata)).getOrElse(field) + selectField(child, Some(childField)) + // Handles case "expr0.field[n]", where "expr0.field" is of array type. + case GetArrayItem(child, _) => + selectField(child, fieldOpt) + // Handles case "expr0.field.subfield", where "expr0" and "expr0.field" are of array type. + case GetArrayStructFields(child: GetArrayStructFields, + field @ StructField(name, dataType, nullable, metadata), _, _, _) => + val childField = fieldOpt.map(field => StructField(name, + wrapStructType(dataType, field), + nullable, metadata)).orElse(Some(field)) + selectField(child, childField) + // Handles case "expr0.field", where "expr0" is of array type. + case GetArrayStructFields(child, + field @ StructField(name, dataType, nullable, metadata), _, _, _) => + val childField = + fieldOpt.map(field => StructField(name, + wrapStructType(dataType, field), + nullable, metadata)).orElse(Some(field)) + selectField(child, childField) + // Handles case "expr0.field[key]", where "expr0" is of struct type and "expr0.field" is of + // map type. + case GetMapValue(x @ GetStructFieldObject(child, field @ StructField(name, + dataType, + nullable, metadata)), _) => + val childField = fieldOpt.map(field => StructField(name, + wrapStructType(dataType, field), + nullable, metadata)).orElse(Some(field)) + selectField(child, childField) + // Handles case "expr0.field[key]", where "expr0.field" is of map type. + case GetMapValue(child, _) => + selectField(child, fieldOpt) + // Handles case "expr0.field", where expr0 is of struct type. + case GetStructFieldObject(child, + field @ StructField(name, dataType, nullable, metadata)) => + val childField = fieldOpt.map(field => StructField(name, + wrapStructType(dataType, field), + nullable, metadata)).orElse(Some(field)) + selectField(child, childField) + case _ => + None + } + } + + // Constructs a composition of complex types with a StructType(Array(field)) at its core. Returns + // a StructType for a StructType, an ArrayType for an ArrayType and a MapType for a MapType. + private def wrapStructType(dataType: DataType, field: StructField): DataType = { + dataType match { + case _: StructType => + StructType(Array(field)) + case ArrayType(elementType, containsNull) => + ArrayType(wrapStructType(elementType, field), containsNull) + case MapType(keyType, valueType, valueContainsNull) => + MapType(keyType, wrapStructType(valueType, field), valueContainsNull) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 64d3f2cdbfa82..969def7624058 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.ExperimentalMethods import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions +import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaPruning import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate class SparkOptimizer( @@ -31,7 +32,8 @@ class SparkOptimizer( override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ - Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++ + Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ + Batch("Parquet Schema Pruning", Once, ParquetSchemaPruning)) ++ postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala new file mode 100644 index 0000000000000..6a46b5f8edc54 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{ProjectionOverSchema, SelectedField} +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType} + +/** + * Prunes unnecessary Parquet columns given a [[PhysicalOperation]] over a + * [[ParquetRelation]]. By "Parquet column", we mean a column as defined in the + * Parquet format. In Spark SQL, a root-level Parquet column corresponds to a + * SQL column, and a nested Parquet column corresponds to a [[StructField]]. + */ +private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = + if (SQLConf.get.nestedSchemaPruningEnabled) { + apply0(plan) + } else { + plan + } + + private def apply0(plan: LogicalPlan): LogicalPlan = + plan transformDown { + case op @ PhysicalOperation(projects, filters, + l @ LogicalRelation(hadoopFsRelation: HadoopFsRelation, _, _, _)) + if canPruneRelation(hadoopFsRelation) => + val (normalizedProjects, normalizedFilters) = + normalizeAttributeRefNames(l, projects, filters) + val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters) + + // If requestedRootFields includes a nested field, continue. Otherwise, + // return op + if (requestedRootFields.exists { root: RootField => !root.derivedFromAtt }) { + val dataSchema = hadoopFsRelation.dataSchema + val prunedDataSchema = pruneDataSchema(dataSchema, requestedRootFields) + + // If the data schema is different from the pruned data schema, continue. Otherwise, + // return op. We effect this comparison by counting the number of "leaf" fields in + // each schemata, assuming the fields in prunedDataSchema are a subset of the fields + // in dataSchema. + if (countLeaves(dataSchema) > countLeaves(prunedDataSchema)) { + val prunedParquetRelation = + hadoopFsRelation.copy(dataSchema = prunedDataSchema)(hadoopFsRelation.sparkSession) + + val prunedRelation = buildPrunedRelation(l, prunedParquetRelation) + val projectionOverSchema = ProjectionOverSchema(prunedDataSchema) + + buildNewProjection(normalizedProjects, normalizedFilters, prunedRelation, + projectionOverSchema) + } else { + op + } + } else { + op + } + } + + /** + * Checks to see if the given relation is Parquet and can be pruned. + */ + private def canPruneRelation(fsRelation: HadoopFsRelation) = + fsRelation.fileFormat.isInstanceOf[ParquetFileFormat] + + /** + * Normalizes the names of the attribute references in the given projects and filters to reflect + * the names in the given logical relation. This makes it possible to compare attributes and + * fields by name. Returns a tuple with the normalized projects and filters, respectively. + */ + private def normalizeAttributeRefNames( + logicalRelation: LogicalRelation, + projects: Seq[NamedExpression], + filters: Seq[Expression]): (Seq[NamedExpression], Seq[Expression]) = { + val normalizedAttNameMap = logicalRelation.output.map(att => (att.exprId, att.name)).toMap + val normalizedProjects = projects.map(_.transform { + case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) => + att.withName(normalizedAttNameMap(att.exprId)) + }).map { case expr: NamedExpression => expr } + val normalizedFilters = filters.map(_.transform { + case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) => + att.withName(normalizedAttNameMap(att.exprId)) + }) + (normalizedProjects, normalizedFilters) + } + + /** + * Returns the set of fields from the Parquet file that the query plan needs. + */ + private def identifyRootFields(projects: Seq[NamedExpression], filters: Seq[Expression]) = { + val projectionRootFields = projects.flatMap(getRootFields) + val filterRootFields = filters.flatMap(getRootFields) + + (projectionRootFields ++ filterRootFields).distinct + } + + /** + * Builds the new output [[Project]] Spark SQL operator that has the pruned output relation. + */ + private def buildNewProjection( + projects: Seq[NamedExpression], filters: Seq[Expression], prunedRelation: LogicalRelation, + projectionOverSchema: ProjectionOverSchema) = { + // Construct a new target for our projection by rewriting and + // including the original filters where available + val projectionChild = + if (filters.nonEmpty) { + val projectedFilters = filters.map(_.transformDown { + case projectionOverSchema(expr) => expr + }) + val newFilterCondition = projectedFilters.reduce(And) + Filter(newFilterCondition, prunedRelation) + } else { + prunedRelation + } + + // Construct the new projections of our Project by + // rewriting the original projections + val newProjects = projects.map(_.transformDown { + case projectionOverSchema(expr) => expr + }).map { case expr: NamedExpression => expr } + + if (log.isDebugEnabled) { + logDebug(s"New projects:\n${newProjects.map(_.treeString).mkString("\n")}") + } + + Project(newProjects, projectionChild) + } + + /** + * Filters the schema from the given file by the requested fields. + * Schema field ordering from the file is preserved. + */ + private def pruneDataSchema( + fileDataSchema: StructType, + requestedRootFields: Seq[RootField]) = { + // Merge the requested root fields into a single schema. Note the ordering of the fields + // in the resulting schema may differ from their ordering in the logical relation's + // original schema + val mergedSchema = requestedRootFields + .map { case RootField(field, _) => StructType(Array(field)) } + .reduceLeft(_ merge _) + val dataSchemaFieldNames = fileDataSchema.fieldNames.toSet + val mergedDataSchema = + StructType(mergedSchema.filter(f => dataSchemaFieldNames.contains(f.name))) + // Sort the fields of mergedDataSchema according to their order in dataSchema, + // recursively. This makes mergedDataSchema a pruned schema of dataSchema + sortLeftFieldsByRight(mergedDataSchema, fileDataSchema).asInstanceOf[StructType] + } + + /** + * Builds a pruned logical relation from the output of the output relation and the schema of the + * pruned base relation. + */ + private def buildPrunedRelation( + outputRelation: LogicalRelation, + prunedBaseRelation: HadoopFsRelation) = { + // We need to replace the expression ids of the pruned relation output attributes + // with the expression ids of the original relation output attributes so that + // references to the original relation's output are not broken + val outputIdMap = outputRelation.output.map(att => (att.name, att.exprId)).toMap + val prunedRelationOutput = + prunedBaseRelation + .schema + .toAttributes + .map { + case att if outputIdMap.contains(att.name) => + att.withExprId(outputIdMap(att.name)) + case att => att + } + outputRelation.copy(relation = prunedBaseRelation, output = prunedRelationOutput) + } + + /** + * Gets the root (aka top-level, no-parent) [[StructField]]s for the given [[Expression]]. + * When expr is an [[Attribute]], construct a field around it and indicate that that + * field was derived from an attribute. + */ + private def getRootFields(expr: Expression): Seq[RootField] = { + expr match { + case att: Attribute => + RootField(StructField(att.name, att.dataType, att.nullable), derivedFromAtt = true) :: Nil + case SelectedField(field) => RootField(field, derivedFromAtt = false) :: Nil + case _ => + expr.children.flatMap(getRootFields) + } + } + + /** + * Counts the "leaf" fields of the given dataType. Informally, this is the + * number of fields of non-complex data type in the tree representation of + * [[DataType]]. + */ + private def countLeaves(dataType: DataType): Int = { + dataType match { + case array: ArrayType => countLeaves(array.elementType) + case map: MapType => countLeaves(map.keyType) + countLeaves(map.valueType) + case struct: StructType => + struct.map(field => countLeaves(field.dataType)).sum + case _ => 1 + } + } + + /** + * Sorts the fields and descendant fields of structs in left according to their order in + * right. This function assumes that the fields of left are a subset of the fields of + * right, recursively. That is, left is a "subschema" of right, ignoring order of + * fields. + */ + private def sortLeftFieldsByRight(left: DataType, right: DataType): DataType = + (left, right) match { + case (ArrayType(leftElementType, containsNull), ArrayType(rightElementType, _)) => + ArrayType( + sortLeftFieldsByRight(leftElementType, rightElementType), + containsNull) + case (MapType(leftKeyType, leftValueType, containsNull), + MapType(rightKeyType, rightValueType, _)) => + MapType( + sortLeftFieldsByRight(leftKeyType, rightKeyType), + sortLeftFieldsByRight(leftValueType, rightValueType), + containsNull) + case (leftStruct: StructType, rightStruct: StructType) => + val filteredRightFieldNames = rightStruct.fieldNames.filter(leftStruct.fieldNames.contains) + val sortedLeftFields = filteredRightFieldNames.map { fieldName => + val leftFieldType = leftStruct(fieldName).dataType + val rightFieldType = rightStruct(fieldName).dataType + val sortedLeftFieldType = sortLeftFieldsByRight(leftFieldType, rightFieldType) + StructField(fieldName, sortedLeftFieldType) + } + StructType(sortedLeftFields) + case _ => left + } + + /** + * A "root" schema field (aka top-level, no-parent) and whether it was derived from + * an attribute or had a proper child. + */ + private case class RootField(field: StructField, derivedFromAtt: Boolean) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SelectedFieldSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SelectedFieldSuite.scala new file mode 100644 index 0000000000000..05f7e3ce83880 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SelectedFieldSuite.scala @@ -0,0 +1,455 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.types._ + +class SelectedFieldSuite extends SparkFunSuite with BeforeAndAfterAll { + private val ignoredField = StructField("col1", StringType, nullable = false) + + // The test schema as a tree string, i.e. `schema.treeString` + // root + // |-- col1: string (nullable = false) + // |-- col2: struct (nullable = true) + // | |-- field1: integer (nullable = true) + // | |-- field6: struct (nullable = true) + // | | |-- subfield1: string (nullable = false) + // | | |-- subfield2: string (nullable = true) + // | |-- field7: struct (nullable = true) + // | | |-- subfield1: struct (nullable = true) + // | | | |-- subsubfield1: integer (nullable = true) + // | | | |-- subsubfield2: integer (nullable = true) + // | |-- field9: map (nullable = true) + // | | |-- key: string + // | | |-- value: integer (valueContainsNull = false) + private val nestedComplex = StructType(ignoredField :: + StructField("col2", StructType( + StructField("field1", IntegerType) :: + StructField("field6", StructType( + StructField("subfield1", StringType, nullable = false) :: + StructField("subfield2", StringType) :: Nil)) :: + StructField("field7", StructType( + StructField("subfield1", StructType( + StructField("subsubfield1", IntegerType) :: + StructField("subsubfield2", IntegerType) :: Nil)) :: Nil)) :: + StructField("field9", + MapType(StringType, IntegerType, valueContainsNull = false)) :: Nil)) :: Nil) + + test("SelectedField should not match an attribute reference") { + val testRelation = LocalRelation(nestedComplex.toAttributes) + assertResult(None)(unapplySelect("col1", testRelation)) + assertResult(None)(unapplySelect("col1 as foo", testRelation)) + assertResult(None)(unapplySelect("col2", testRelation)) + } + + // |-- col1: string (nullable = false) + // |-- col2: struct (nullable = true) + // | |-- field2: array (nullable = true) + // | | |-- element: integer (containsNull = false) + // | |-- field3: array (nullable = false) + // | | |-- element: struct (containsNull = true) + // | | | |-- subfield1: integer (nullable = true) + // | | | |-- subfield2: integer (nullable = true) + // | | | |-- subfield3: array (nullable = true) + // | | | | |-- element: integer (containsNull = true) + private val structOfArray = StructType(ignoredField :: + StructField("col2", StructType( + StructField("field2", ArrayType(IntegerType, containsNull = false)) :: + StructField("field3", ArrayType(StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", IntegerType) :: + StructField("subfield3", ArrayType(IntegerType)) :: Nil)), nullable = false) + :: Nil)) + :: Nil) + + testSelect(structOfArray, "col2.field2", "col2.field2[0] as foo") { + StructField("col2", StructType( + StructField("field2", ArrayType(IntegerType, containsNull = false)) :: Nil)) + } + + testSelect(nestedComplex, "col2.field9", "col2.field9['foo'] as foo") { + StructField("col2", StructType( + StructField("field9", MapType(StringType, IntegerType, valueContainsNull = false)) :: Nil)) + } + + testSelect(structOfArray, "col2.field3.subfield3", "col2.field3[0].subfield3 as foo", + "col2.field3.subfield3[0] as foo", "col2.field3[0].subfield3[0] as foo") { + StructField("col2", StructType( + StructField("field3", ArrayType(StructType( + StructField("subfield3", ArrayType(IntegerType)) :: Nil)), nullable = false) :: Nil)) + } + + testSelect(structOfArray, "col2.field3.subfield1") { + StructField("col2", StructType( + StructField("field3", ArrayType(StructType( + StructField("subfield1", IntegerType) :: Nil)), nullable = false) :: Nil)) + } + + // |-- col1: string (nullable = false) + // |-- col2: struct (nullable = true) + // | |-- field4: map (nullable = true) + // | | |-- key: string + // | | |-- value: struct (valueContainsNull = false) + // | | | |-- subfield1: integer (nullable = true) + // | | | |-- subfield2: array (nullable = true) + // | | | | |-- element: integer (containsNull = false) + // | |-- field8: map (nullable = true) + // | | |-- key: string + // | | |-- value: array (valueContainsNull = false) + // | | | |-- element: struct (containsNull = true) + // | | | | |-- subfield1: integer (nullable = true) + // | | | | |-- subfield2: array (nullable = true) + // | | | | | |-- element: integer (containsNull = false) + private val structWithMap = StructType( + ignoredField :: + StructField("col2", StructType( + StructField("field4", MapType(StringType, StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", ArrayType(IntegerType, containsNull = false)) :: Nil + ), valueContainsNull = false)) :: + StructField("field8", MapType(StringType, ArrayType(StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", ArrayType(IntegerType, containsNull = false)) :: Nil) + ), valueContainsNull = false)) :: Nil + )) :: Nil + ) + + testSelect(structWithMap, "col2.field4['foo'].subfield1 as foo") { + StructField("col2", StructType( + StructField("field4", MapType(StringType, StructType( + StructField("subfield1", IntegerType) :: Nil), valueContainsNull = false)) :: Nil)) + } + + testSelect(structWithMap, + "col2.field4['foo'].subfield2 as foo", "col2.field4['foo'].subfield2[0] as foo") { + StructField("col2", StructType( + StructField("field4", MapType(StringType, StructType( + StructField("subfield2", ArrayType(IntegerType, containsNull = false)) + :: Nil), valueContainsNull = false)) :: Nil)) + } + + // |-- col1: string (nullable = false) + // |-- col2: struct (nullable = true) + // | |-- field5: array (nullable = false) + // | | |-- element: struct (containsNull = true) + // | | | |-- subfield1: struct (nullable = false) + // | | | | |-- subsubfield1: integer (nullable = true) + // | | | | |-- subsubfield2: integer (nullable = true) + // | | | |-- subfield2: struct (nullable = true) + // | | | | |-- subsubfield1: struct (nullable = true) + // | | | | | |-- subsubsubfield1: string (nullable = true) + // | | | | |-- subsubfield2: integer (nullable = true) + private val structWithArray = StructType( + ignoredField :: + StructField("col2", StructType( + StructField("field5", ArrayType(StructType( + StructField("subfield1", StructType( + StructField("subsubfield1", IntegerType) :: + StructField("subsubfield2", IntegerType) :: Nil), nullable = false) :: + StructField("subfield2", StructType( + StructField("subsubfield1", StructType( + StructField("subsubsubfield1", StringType) :: Nil)) :: + StructField("subsubfield2", IntegerType) :: Nil)) :: Nil)), nullable = false) :: Nil) + ) :: Nil + ) + + testSelect(structWithArray, "col2.field5.subfield1") { + StructField("col2", StructType( + StructField("field5", ArrayType(StructType( + StructField("subfield1", StructType( + StructField("subsubfield1", IntegerType) :: + StructField("subsubfield2", IntegerType) :: Nil), nullable = false) + :: Nil)), nullable = false) :: Nil)) + } + + testSelect(structWithArray, "col2.field5.subfield1.subsubfield1") { + StructField("col2", StructType( + StructField("field5", ArrayType(StructType( + StructField("subfield1", StructType( + StructField("subsubfield1", IntegerType) :: Nil), nullable = false) + :: Nil)), nullable = false) :: Nil)) + } + + testSelect(structWithArray, "col2.field5.subfield2.subsubfield1.subsubsubfield1") { + StructField("col2", StructType( + StructField("field5", ArrayType(StructType( + StructField("subfield2", StructType( + StructField("subsubfield1", StructType( + StructField("subsubsubfield1", StringType) :: Nil)) :: Nil)) + :: Nil)), nullable = false) :: Nil)) + } + + testSelect(structWithMap, "col2.field8['foo'][0].subfield1 as foo") { + StructField("col2", StructType( + StructField("field8", MapType(StringType, ArrayType(StructType( + StructField("subfield1", IntegerType) :: Nil)), valueContainsNull = false)) :: Nil)) + } + + testSelect(nestedComplex, "col2.field1") { + StructField("col2", StructType( + StructField("field1", IntegerType) :: Nil)) + } + + testSelect(nestedComplex, "col2.field6") { + StructField("col2", StructType( + StructField("field6", StructType( + StructField("subfield1", StringType, nullable = false) :: + StructField("subfield2", StringType) :: Nil)) :: Nil)) + } + + testSelect(nestedComplex, "col2.field6.subfield1") { + StructField("col2", StructType( + StructField("field6", StructType( + StructField("subfield1", StringType, nullable = false) :: Nil)) :: Nil)) + } + + testSelect(nestedComplex, "col2.field7.subfield1") { + StructField("col2", StructType( + StructField("field7", StructType( + StructField("subfield1", StructType( + StructField("subsubfield1", IntegerType) :: + StructField("subsubfield2", IntegerType) :: Nil)) :: Nil)) :: Nil)) + } + + // |-- col1: string (nullable = false) + // |-- col3: array (nullable = false) + // | |-- element: struct (containsNull = false) + // | | |-- field1: struct (nullable = true) + // | | | |-- subfield1: integer (nullable = false) + // | | | |-- subfield2: integer (nullable = true) + // | | |-- field2: map (nullable = true) + // | | | |-- key: string + // | | | |-- value: integer (valueContainsNull = false) + private val arrayWithStructAndMap = StructType(Array( + StructField("col3", ArrayType(StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType, nullable = false) :: + StructField("subfield2", IntegerType) :: Nil)) :: + StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) + :: Nil), containsNull = false), nullable = false) + )) + + testSelect(arrayWithStructAndMap, "col3.field1.subfield1") { + StructField("col3", ArrayType(StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil)) + :: Nil), containsNull = false), nullable = false) + } + + testSelect(arrayWithStructAndMap, "col3.field2['foo'] as foo") { + StructField("col3", ArrayType(StructType( + StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) + :: Nil), containsNull = false), nullable = false) + } + + // |-- col1: string (nullable = false) + // |-- col4: map (nullable = false) + // | |-- key: string + // | |-- value: struct (valueContainsNull = false) + // | | |-- field1: struct (nullable = true) + // | | | |-- subfield1: integer (nullable = false) + // | | | |-- subfield2: integer (nullable = true) + // | | |-- field2: map (nullable = true) + // | | | |-- key: string + // | | | |-- value: integer (valueContainsNull = false) + private val col4 = StructType(Array(ignoredField, + StructField("col4", MapType(StringType, StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType, nullable = false) :: + StructField("subfield2", IntegerType) :: Nil)) :: + StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) + :: Nil), valueContainsNull = false), nullable = false) + )) + + testSelect(col4, "col4['foo'].field1.subfield1 as foo") { + StructField("col4", MapType(StringType, StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil)) + :: Nil), valueContainsNull = false), nullable = false) + } + + testSelect(col4, "col4['foo'].field2['bar'] as foo") { + StructField("col4", MapType(StringType, StructType( + StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) + :: Nil), valueContainsNull = false), nullable = false) + } + + // |-- col1: string (nullable = false) + // |-- col5: array (nullable = true) + // | |-- element: map (containsNull = true) + // | | |-- key: string + // | | |-- value: struct (valueContainsNull = false) + // | | | |-- field1: struct (nullable = true) + // | | | | |-- subfield1: integer (nullable = true) + // | | | | |-- subfield2: integer (nullable = true) + private val arrayOfStruct = StructType(Array(ignoredField, + StructField("col5", ArrayType(MapType(StringType, StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", IntegerType) :: Nil)) :: Nil), valueContainsNull = false))) + )) + + testSelect(arrayOfStruct, "col5[0]['foo'].field1.subfield1 as foo") { + StructField("col5", ArrayType(MapType(StringType, StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType) :: Nil)) :: Nil), valueContainsNull = false))) + } + + // |-- col1: string (nullable = false) + // |-- col6: map (nullable = true) + // | |-- key: string + // | |-- value: array (valueContainsNull = true) + // | | |-- element: struct (containsNull = false) + // | | | |-- field1: struct (nullable = true) + // | | | | |-- subfield1: integer (nullable = true) + // | | | | |-- subfield2: integer (nullable = true) + private val mapOfArray = StructType(Array(ignoredField, + StructField("col6", MapType(StringType, ArrayType(StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", IntegerType) :: Nil)) :: Nil), containsNull = false))))) + + testSelect(mapOfArray, "col6['foo'][0].field1.subfield1 as foo") { + StructField("col6", MapType(StringType, ArrayType(StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType) :: Nil)) :: Nil), containsNull = false))) + } + + // An array with a struct with a different fields + // |-- col1: string (nullable = false) + // |-- col7: array (nullable = true) + // | |-- element: struct (containsNull = true) + // | | |-- field1: integer (nullable = false) + // | | |-- field2: struct (nullable = true) + // | | | |-- subfield1: integer (nullable = false) + // | | |-- field3: array (nullable = true) + // | | | |-- element: struct (containsNull = true) + // | | | | |-- subfield1: integer (nullable = false) + private val arrayWithMultipleFields = StructType(Array(ignoredField, + StructField("col7", ArrayType(StructType( + StructField("field1", IntegerType, nullable = false) :: + StructField("field2", StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil)) :: + StructField("field3", ArrayType(StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil))) :: Nil))))) + + testSelect(arrayWithMultipleFields, + "col7.field1", "col7[0].field1 as foo", "col7.field1[0] as foo") { + StructField("col7", ArrayType(StructType( + StructField("field1", IntegerType, nullable = false) :: Nil))) + } + + testSelect(arrayWithMultipleFields, "col7.field2.subfield1") { + StructField("col7", ArrayType(StructType( + StructField("field2", StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil)) :: Nil))) + } + + testSelect(arrayWithMultipleFields, "col7.field3.subfield1") { + StructField("col7", ArrayType(StructType( + StructField("field3", ArrayType(StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil))) :: Nil))) + } + + // Array with a nested int array + // |-- col1: string (nullable = false) + // |-- col8: array (nullable = true) + // | |-- element: struct (containsNull = true) + // | | |-- field1: array (nullable = false) + // | | | |-- element: integer (containsNull = false) + private val arrayOfArray = StructType(Array(ignoredField, + StructField("col8", + ArrayType(StructType(Array(StructField("field1", + ArrayType(IntegerType, containsNull = false), nullable = false)))) + ))) + + testSelect(arrayOfArray, "col8.field1", + "col8[0].field1 as foo", + "col8.field1[0] as foo", + "col8[0].field1[0] as foo") { + StructField("col8", ArrayType(StructType( + StructField("field1", ArrayType(IntegerType, containsNull = false), nullable = false) + :: Nil))) + } + + def assertResult(expected: StructField)(actual: StructField)(selectExpr: String): Unit = { + try { + super.assertResult(expected)(actual) + } catch { + case ex: TestFailedException => + // Print some helpful diagnostics in the case of failure + alert("Expected SELECT \"" + selectExpr + "\" to select the schema\n" + + indent(StructType(expected :: Nil).treeString) + + indent("but it actually selected\n") + + indent(StructType(actual :: Nil).treeString) + + indent("Note that expected.dataType.sameType(actual.dataType) = " + + expected.dataType.sameType(actual.dataType))) + throw ex + } + } + + // Test that the given SELECT expressions prune the test schema to the single-column schema + // defined by the given field + private def testSelect(inputSchema: StructType, selectExprs: String*) + (expected: StructField) { + test(s"SELECT ${selectExprs.map(s => s""""$s"""").mkString(", ")} should select the schema\n" + + indent(StructType(expected :: Nil).treeString)) { + for (selectExpr <- selectExprs) { + assertSelect(selectExpr, expected, inputSchema) + } + } + } + + private def assertSelect(expr: String, expected: StructField, inputSchema: StructType): Unit = { + val relation = LocalRelation(inputSchema.toAttributes) + unapplySelect(expr, relation) match { + case Some(field) => + assertResult(expected)(field)(expr) + case None => + val failureMessage = + "Failed to select a field from " + expr + ". " + + "Expected:\n" + + StructType(expected :: Nil).treeString + fail(failureMessage) + } + } + + private def unapplySelect(expr: String, relation: LocalRelation) = { + val parsedExpr = parseAsCatalystExpression(Seq(expr)).head + val select = relation.select(parsedExpr) + val analyzed = select.analyze + SelectedField.unapply(analyzed.expressions.head) + } + + private def parseAsCatalystExpression(exprs: Seq[String]) = { + exprs.map(CatalystSqlParser.parseExpression(_) match { + case namedExpr: NamedExpression => namedExpr + }) + } + + // Indent every line in `string` by four spaces + private def indent(string: String) = string.replaceAll("(?m)^", " ") +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index dbf637783e6d2..54c77dddc3525 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -108,7 +108,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val queryOutput = selfJoin.queryExecution.analyzed.output assertResult(4, "Field count mismatches")(queryOutput.size) - assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { + assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") { queryOutput.filter(_.name == "_1").map(_.exprId).size } @@ -117,7 +117,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } test("nested data - struct with array field") { - val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) withParquetTable(data, "t") { checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { case Tuple1((_, Seq(string))) => Row(string) @@ -126,7 +126,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } test("nested data - array of struct") { - val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) + val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i"))) withParquetTable(data, "t") { checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { case Tuple1(Seq((_, string))) => Row(string) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala new file mode 100644 index 0000000000000..eb99654fa78f5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import org.scalactic.Equality + +import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.catalyst.SchemaPruningTest +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.execution.FileSourceScanExec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType + +class ParquetSchemaPruningSuite + extends QueryTest + with ParquetTest + with SchemaPruningTest + with SharedSQLContext { + case class FullName(first: String, middle: String, last: String) + case class Contact( + id: Int, + name: FullName, + address: String, + pets: Int, + friends: Array[FullName] = Array.empty, + relatives: Map[String, FullName] = Map.empty) + + val janeDoe = FullName("Jane", "X.", "Doe") + val johnDoe = FullName("John", "Y.", "Doe") + val susanSmith = FullName("Susan", "Z.", "Smith") + + private val contacts = + Contact(0, janeDoe, "123 Main Street", 1, friends = Array(susanSmith), + relatives = Map("brother" -> johnDoe)) :: + Contact(1, johnDoe, "321 Wall Street", 3, relatives = Map("sister" -> janeDoe)) :: Nil + + case class Name(first: String, last: String) + case class BriefContact(id: Int, name: Name, address: String) + + private val briefContacts = + BriefContact(2, Name("Janet", "Jones"), "567 Maple Drive") :: + BriefContact(3, Name("Jim", "Jones"), "6242 Ash Street") :: Nil + + case class ContactWithDataPartitionColumn( + id: Int, + name: FullName, + address: String, + pets: Int, + friends: Array[FullName] = Array(), + relatives: Map[String, FullName] = Map(), + p: Int) + + case class BriefContactWithDataPartitionColumn(id: Int, name: Name, address: String, p: Int) + + private val contactsWithDataPartitionColumn = + contacts.map { case Contact(id, name, address, pets, friends, relatives) => + ContactWithDataPartitionColumn(id, name, address, pets, friends, relatives, 1) } + private val briefContactsWithDataPartitionColumn = + briefContacts.map { case BriefContact(id, name, address) => + BriefContactWithDataPartitionColumn(id, name, address, 2) } + + testSchemaPruning("select a single complex field") { + val query = sql("select name.middle from contacts") + checkScan(query, "struct>") + checkAnswer(query.orderBy("id"), Row("X.") :: Row("Y.") :: Row(null) :: Row(null) :: Nil) + } + + testSchemaPruning("select a single complex field and its parent struct") { + val query = sql("select name.middle, name from contacts") + checkScan(query, "struct>") + checkAnswer(query.orderBy("id"), + Row("X.", Row("Jane", "X.", "Doe")) :: + Row("Y.", Row("John", "Y.", "Doe")) :: + Row(null, Row("Janet", null, "Jones")) :: + Row(null, Row("Jim", null, "Jones")) :: + Nil) + } + + testSchemaPruning("select a single complex field array and its parent struct array") { + val query = sql("select friends.middle, friends from contacts where p=1") + checkScan(query, + "struct>>") + checkAnswer(query.orderBy("id"), + Row(Array("Z."), Array(Row("Susan", "Z.", "Smith"))) :: + Row(Array.empty[String], Array.empty[Row]) :: + Nil) + } + + testSchemaPruning("select a single complex field from a map entry and its parent map entry") { + val query = + sql("select relatives[\"brother\"].middle, relatives[\"brother\"] from contacts where p=1") + checkScan(query, + "struct>>") + checkAnswer(query.orderBy("id"), + Row("Y.", Row("John", "Y.", "Doe")) :: + Row(null, null) :: + Nil) + } + + testSchemaPruning("select a single complex field and the partition column") { + val query = sql("select name.middle, p from contacts") + checkScan(query, "struct>") + checkAnswer(query.orderBy("id"), + Row("X.", 1) :: Row("Y.", 1) :: Row(null, 2) :: Row(null, 2) :: Nil) + } + + ignore("partial schema intersection - select missing subfield") { + val query = sql("select name.middle, address from contacts where p=2") + checkScan(query, "struct,address:string>") + checkAnswer(query.orderBy("id"), + Row(null, "567 Maple Drive") :: + Row(null, "6242 Ash Street") :: Nil) + } + + testSchemaPruning("no unnecessary schema pruning") { + val query = + sql("select id, name.last, name.middle, name.first, relatives[''].last, " + + "relatives[''].middle, relatives[''].first, friends[0].last, friends[0].middle, " + + "friends[0].first, pets, address from contacts where p=2") + // We've selected every field in the schema. Therefore, no schema pruning should be performed. + // We check this by asserting that the scanned schema of the query is identical to the schema + // of the contacts relation, even though the fields are selected in different orders. + checkScan(query, + "struct,address:string,pets:int," + + "friends:array>," + + "relatives:map>>") + checkAnswer(query.orderBy("id"), + Row(2, "Jones", null, "Janet", null, null, null, null, null, null, null, "567 Maple Drive") :: + Row(3, "Jones", null, "Jim", null, null, null, null, null, null, null, "6242 Ash Street") :: + Nil) + } + + testSchemaPruning("empty schema intersection") { + val query = sql("select name.middle from contacts where p=2") + checkScan(query, "struct>") + checkAnswer(query.orderBy("id"), + Row(null) :: Row(null) :: Nil) + } + + private def testSchemaPruning(testName: String)(testThunk: => Unit) { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") { + test(s"Spark vectorized reader - without partition data column - $testName") { + withContacts(testThunk) + } + test(s"Spark vectorized reader - with partition data column - $testName") { + withContactsWithDataPartitionColumn(testThunk) + } + } + + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + test(s"Parquet-mr reader - without partition data column - $testName") { + withContacts(testThunk) + } + test(s"Parquet-mr reader - with partition data column - $testName") { + withContactsWithDataPartitionColumn(testThunk) + } + } + } + + private def withContacts(testThunk: => Unit) { + withTempPath { dir => + val path = dir.getCanonicalPath + + makeParquetFile(contacts, new File(path + "/contacts/p=1")) + makeParquetFile(briefContacts, new File(path + "/contacts/p=2")) + + spark.read.parquet(path + "/contacts").createOrReplaceTempView("contacts") + + testThunk + } + } + + private def withContactsWithDataPartitionColumn(testThunk: => Unit) { + withTempPath { dir => + val path = dir.getCanonicalPath + + makeParquetFile(contactsWithDataPartitionColumn, new File(path + "/contacts/p=1")) + makeParquetFile(briefContactsWithDataPartitionColumn, new File(path + "/contacts/p=2")) + + spark.read.parquet(path + "/contacts").createOrReplaceTempView("contacts") + + testThunk + } + } + + case class MixedCaseColumn(a: String, B: Int) + case class MixedCase(id: Int, CoL1: String, coL2: MixedCaseColumn) + + private val mixedCaseData = + MixedCase(0, "r0c1", MixedCaseColumn("abc", 1)) :: + MixedCase(1, "r1c1", MixedCaseColumn("123", 2)) :: + Nil + + testMixedCasePruning("select with exact column names") { + val query = sql("select CoL1, coL2.B from mixedcase") + checkScan(query, "struct>") + checkAnswer(query.orderBy("id"), + Row("r0c1", 1) :: + Row("r1c1", 2) :: + Nil) + } + + testMixedCasePruning("select with lowercase column names") { + val query = sql("select col1, col2.b from mixedcase") + checkScan(query, "struct>") + checkAnswer(query.orderBy("id"), + Row("r0c1", 1) :: + Row("r1c1", 2) :: + Nil) + } + + testMixedCasePruning("select with different-case column names") { + val query = sql("select cOL1, cOl2.b from mixedcase") + checkScan(query, "struct>") + checkAnswer(query.orderBy("id"), + Row("r0c1", 1) :: + Row("r1c1", 2) :: + Nil) + } + + testMixedCasePruning("filter with different-case column names") { + val query = sql("select id from mixedcase where Col2.b = 2") + // Pruning with filters is currently unsupported. As-is, the file reader will read the id column + // and the entire coL2 struct. Once pruning with filters has been implemented we can uncomment + // this line + // checkScan(query, "struct>") + checkAnswer(query.orderBy("id"), Row(1) :: Nil) + } + + private def testMixedCasePruning(testName: String)(testThunk: => Unit) { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true", + SQLConf.CASE_SENSITIVE.key -> "true") { + test(s"Spark vectorized reader - case-sensitive parser - mixed-case schema - $testName") { + withMixedCaseData(testThunk) + } + } + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false", + SQLConf.CASE_SENSITIVE.key -> "false") { + test(s"Parquet-mr reader - case-insensitive parser - mixed-case schema - $testName") { + withMixedCaseData(testThunk) + } + } + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true", + SQLConf.CASE_SENSITIVE.key -> "false") { + test(s"Spark vectorized reader - case-insensitive parser - mixed-case schema - $testName") { + withMixedCaseData(testThunk) + } + } + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false", + SQLConf.CASE_SENSITIVE.key -> "true") { + test(s"Parquet-mr reader - case-sensitive parser - mixed-case schema - $testName") { + withMixedCaseData(testThunk) + } + } + } + + private def withMixedCaseData(testThunk: => Unit) { + withParquetTable(mixedCaseData, "mixedcase") { + testThunk + } + } + + private val schemaEquality = new Equality[StructType] { + override def areEqual(a: StructType, b: Any): Boolean = + b match { + case otherType: StructType => a.sameType(otherType) + case _ => false + } + } + + protected def checkScan(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = { + checkScanSchemata(df, expectedSchemaCatalogStrings: _*) + // We check here that we can execute the query without throwing an exception. The results + // themselves are irrelevant, and should be checked elsewhere as needed + df.collect() + } + + private def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = { + val fileSourceScanSchemata = + df.queryExecution.executedPlan.collect { + case scan: FileSourceScanExec => scan.requiredSchema + } + assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size, + s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " + + s"but expected $expectedSchemaCatalogStrings") + fileSourceScanSchemata.zip(expectedSchemaCatalogStrings).foreach { + case (scanSchema, expectedScanSchemaCatalogString) => + val expectedScanSchema = CatalystSqlParser.parseDataType(expectedScanSchemaCatalogString) + implicit val equality = schemaEquality + assert(scanSchema === expectedScanSchema) + } + } +} From 9b6baeb7b9f73e1a38581f481ea7232db712deb8 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 23 Aug 2018 21:36:53 -0700 Subject: [PATCH 1472/2461] [SPARK-25029][BUILD][CORE] Janino "Two non-abstract methods ..." errors ## What changes were proposed in this pull request? Update to janino 3.0.9 to address Java 8 + Scala 2.12 incompatibility. The error manifests as test failures like this in `ExpressionEncoderSuite`: ``` - encode/decode for seq of string: List(abc, xyz) *** FAILED *** java.lang.RuntimeException: Error while encoding: org.codehaus.janino.InternalCompilerException: failed to compile: org.codehaus.janino.InternalCompilerException: Compiling "GeneratedClass": Two non-abstract methods "public int scala.collection.TraversableOnce.size()" have the same parameter types, declaring type and return type ``` It comes up pretty immediately in any generated code that references Scala collections, and virtually always concerning the `size()` method. ## How was this patch tested? Existing tests Closes #22203 from srowen/SPARK-25029. Authored-by: Sean Owen Signed-off-by: Xiao Li --- dev/deps/spark-deps-hadoop-2.6 | 4 ++-- dev/deps/spark-deps-hadoop-2.7 | 4 ++-- dev/deps/spark-deps-hadoop-3.1 | 4 ++-- pom.xml | 2 +- .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 4 ++-- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index aca5bbbe02eb6..fc42af905c2fe 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -34,7 +34,7 @@ commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar -commons-compiler-3.0.8.jar +commons-compiler-3.0.9.jar commons-compress-1.8.1.jar commons-configuration-1.6.jar commons-crypto-1.0.0.jar @@ -98,7 +98,7 @@ jackson-module-jaxb-annotations-2.6.7.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar -janino-3.0.8.jar +janino-3.0.9.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar javax.inject-1.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 0902af1dfdc65..54e50556b4620 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -34,7 +34,7 @@ commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar -commons-compiler-3.0.8.jar +commons-compiler-3.0.9.jar commons-compress-1.8.1.jar commons-configuration-1.6.jar commons-crypto-1.0.0.jar @@ -98,7 +98,7 @@ jackson-module-jaxb-annotations-2.6.7.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar -janino-3.0.8.jar +janino-3.0.9.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar javax.inject-1.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 35cf1cd85fd80..ff5713b5b66b7 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -31,7 +31,7 @@ commons-beanutils-1.9.3.jar commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar -commons-compiler-3.0.8.jar +commons-compiler-3.0.9.jar commons-compress-1.8.1.jar commons-configuration2-2.1.1.jar commons-crypto-1.0.0.jar @@ -97,7 +97,7 @@ jackson-mapper-asl-1.9.13.jar jackson-module-jaxb-annotations-2.6.7.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar -janino-3.0.8.jar +janino-3.0.9.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar javax.inject-1.jar diff --git a/pom.xml b/pom.xml index a4184c3153336..6988c65348652 100644 --- a/pom.xml +++ b/pom.xml @@ -170,7 +170,7 @@ 3.5 3.2.10 - 3.0.8 + 3.0.9 2.22.2 2.9.3 3.5.2 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2c56456cd4dac..b8f09761f61ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1329,7 +1329,7 @@ object CodeGenerator extends Logging { evaluator.setParentClassLoader(parentClassLoader) // Cannot be under package codegen, or fail with java.lang.InstantiationException evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass") - evaluator.setDefaultImports(Array( + evaluator.setDefaultImports( classOf[Platform].getName, classOf[InternalRow].getName, classOf[UnsafeRow].getName, @@ -1344,7 +1344,7 @@ object CodeGenerator extends Logging { classOf[TaskContext].getName, classOf[TaskKilledException].getName, classOf[InputMetrics].getName - )) + ) evaluator.setExtendedClass(classOf[GeneratedClass]) logDebug({ From ab33028957443189efc4106afd9d65dddf8f9c98 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 24 Aug 2018 14:58:55 +0900 Subject: [PATCH 1473/2461] [SPARK-25178][SQL] Directly ship the StructType objects of the keySchema / valueSchema for xxxHashMapGenerator ## What changes were proposed in this pull request? This PR generates the code that to refer a `StructType` generated in the scala code instead of generating `StructType` in Java code. The original code has two issues. 1. Avoid to used the field name such as `key.name` 1. Support complicated schema (e.g. nested DataType) At first, [the JIRA entry](https://issues.apache.org/jira/browse/SPARK-25178) proposed to change the generated field name of the keySchema / valueSchema to a dummy name in `RowBasedHashMapGenerator` and `VectorizedHashMapGenerator.scala`. This proposal can addresse issue 1. Ueshin suggested an approach to refer to a `StructType` generated in the scala code using `ctx.addReferenceObj()`. This approach can address issues 1 and 2. Finally, this PR uses this approach. ## How was this patch tested? Existing UTs Closes #22187 from kiszk/SPARK-25178. Authored-by: Kazuaki Ishizaki Signed-off-by: Takuya UESHIN --- .../aggregate/RowBasedHashMapGenerator.scala | 33 ++--------------- .../VectorizedHashMapGenerator.scala | 37 +++---------------- 2 files changed, 10 insertions(+), 60 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index d5508275c48c5..ca59bb145f299 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -44,31 +44,8 @@ class RowBasedHashMapGenerator( groupingKeySchema, bufferSchema) { override protected def initializeAggregateHashMap(): String = { - val generatedKeySchema: String = - s"new org.apache.spark.sql.types.StructType()" + - groupingKeySchema.map { key => - val keyName = ctx.addReferenceObj("keyName", key.name) - key.dataType match { - case d: DecimalType => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( - |${d.precision}, ${d.scale}))""".stripMargin - case _ => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" - } - }.mkString("\n").concat(";") - - val generatedValueSchema: String = - s"new org.apache.spark.sql.types.StructType()" + - bufferSchema.map { key => - val keyName = ctx.addReferenceObj("keyName", key.name) - key.dataType match { - case d: DecimalType => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( - |${d.precision}, ${d.scale}))""".stripMargin - case _ => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" - } - }.mkString("\n").concat(";") + val keySchema = ctx.addReferenceObj("keySchemaTerm", groupingKeySchema) + val valueSchema = ctx.addReferenceObj("valueSchemaTerm", bufferSchema) s""" | private org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch; @@ -78,8 +55,6 @@ class RowBasedHashMapGenerator( | private int numBuckets = (int) (capacity / loadFactor); | private int maxSteps = 2; | private int numRows = 0; - | private org.apache.spark.sql.types.StructType keySchema = $generatedKeySchema - | private org.apache.spark.sql.types.StructType valueSchema = $generatedValueSchema | private Object emptyVBase; | private long emptyVOff; | private int emptyVLen; @@ -90,9 +65,9 @@ class RowBasedHashMapGenerator( | org.apache.spark.memory.TaskMemoryManager taskMemoryManager, | InternalRow emptyAggregationBuffer) { | batch = org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch - | .allocate(keySchema, valueSchema, taskMemoryManager, capacity); + | .allocate($keySchema, $valueSchema, taskMemoryManager, capacity); | - | final UnsafeProjection valueProjection = UnsafeProjection.create(valueSchema); + | final UnsafeProjection valueProjection = UnsafeProjection.create($valueSchema); | final byte[] emptyBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); | | emptyVBase = emptyBuffer; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 7b3580cecc60d..95ebefed08f67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -52,31 +52,9 @@ class VectorizedHashMapGenerator( groupingKeySchema, bufferSchema) { override protected def initializeAggregateHashMap(): String = { - val generatedSchema: String = - s"new org.apache.spark.sql.types.StructType()" + - (groupingKeySchema ++ bufferSchema).map { key => - val keyName = ctx.addReferenceObj("keyName", key.name) - key.dataType match { - case d: DecimalType => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( - |${d.precision}, ${d.scale}))""".stripMargin - case _ => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" - } - }.mkString("\n").concat(";") - - val generatedAggBufferSchema: String = - s"new org.apache.spark.sql.types.StructType()" + - bufferSchema.map { key => - val keyName = ctx.addReferenceObj("keyName", key.name) - key.dataType match { - case d: DecimalType => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( - |${d.precision}, ${d.scale}))""".stripMargin - case _ => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" - } - }.mkString("\n").concat(";") + val schemaStructType = new StructType((groupingKeySchema ++ bufferSchema).toArray) + val schema = ctx.addReferenceObj("schemaTerm", schemaStructType) + val aggBufferSchemaFieldsLength = bufferSchema.fields.length s""" | private ${classOf[OnHeapColumnVector].getName}[] vectors; @@ -88,18 +66,15 @@ class VectorizedHashMapGenerator( | private int numBuckets = (int) (capacity / loadFactor); | private int maxSteps = 2; | private int numRows = 0; - | private org.apache.spark.sql.types.StructType schema = $generatedSchema - | private org.apache.spark.sql.types.StructType aggregateBufferSchema = - | $generatedAggBufferSchema | | public $generatedClassName() { - | vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, schema); + | vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, $schema); | batch = new ${classOf[ColumnarBatch].getName}(vectors); | | // Generates a projection to return the aggregate buffer only. | ${classOf[OnHeapColumnVector].getName}[] aggBufferVectors = - | new ${classOf[OnHeapColumnVector].getName}[aggregateBufferSchema.fields().length]; - | for (int i = 0; i < aggregateBufferSchema.fields().length; i++) { + | new ${classOf[OnHeapColumnVector].getName}[$aggBufferSchemaFieldsLength]; + | for (int i = 0; i < $aggBufferSchemaFieldsLength; i++) { | aggBufferVectors[i] = vectors[i + ${groupingKeys.length}]; | } | aggBufferRow = new ${classOf[MutableColumnarRow].getName}(aggBufferVectors); From c20916a5dc4a7e771463838e797cb944569f6259 Mon Sep 17 00:00:00 2001 From: s71955 Date: Fri, 24 Aug 2018 08:58:19 -0500 Subject: [PATCH 1474/2461] [SPARK-25073][YARN] AM and Executor Memory validation message is not proper while submitting spark yarn application **## What changes were proposed in this pull request?** When the yarn.nodemanager.resource.memory-mb or yarn.scheduler.maximum-allocation-mb memory assignment is insufficient, Spark always reports an error request to adjust yarn.scheduler.maximum-allocation-mb even though in message it shows the memory value of yarn.nodemanager.resource.memory-mb parameter,As the error Message is bit misleading to the user we can modify the same, We can keep the error message same as executor memory validation message. Defintion of **yarn.nodemanager.resource.memory-mb:** Amount of physical memory, in MB, that can be allocated for containers. It means the amount of memory YARN can utilize on this node and therefore this property should be lower then the total memory of that machine. **yarn.scheduler.maximum-allocation-mb:** It defines the maximum memory allocation available for a container in MB it means RM can only allocate memory to containers in increments of "yarn.scheduler.minimum-allocation-mb" and not exceed "yarn.scheduler.maximum-allocation-mb" and It should not be more than total allocated memory of the Node. **## How was this patch tested?** Manually tested in hdfs-Yarn clustaer Closes #22199 from sujith71955/maste_am_log. Authored-by: s71955 Signed-off-by: Sean Owen --- .../src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 75614a41e0b62..698fc2ce8bf9d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -344,7 +344,8 @@ private[spark] class Client( if (amMem > maxMem) { throw new IllegalArgumentException(s"Required AM memory ($amMemory" + s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " + - "Please increase the value of 'yarn.scheduler.maximum-allocation-mb'.") + "Please check the values of 'yarn.scheduler.maximum-allocation-mb' and/or " + + "'yarn.nodemanager.resource.memory-mb'.") } logInfo("Will allocate AM container, with %d MB memory including %d MB overhead".format( amMem, From 8bb9414aaff4a147db2d921dccdbd04c8eb4e5db Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 24 Aug 2018 12:00:34 -0700 Subject: [PATCH 1475/2461] [SPARK-25214][SS] Fix the issue that Kafka v2 source may return duplicated records when `failOnDataLoss=false` ## What changes were proposed in this pull request? When there are missing offsets, Kafka v2 source may return duplicated records when `failOnDataLoss=false` because it doesn't skip missing offsets. This PR fixes the issue and also adds regression tests for all Kafka readers. ## How was this patch tested? New tests. Closes #22207 from zsxwing/SPARK-25214. Authored-by: Shixiong Zhu Signed-off-by: Shixiong Zhu --- .../kafka010/KafkaMicroBatchReadSupport.scala | 2 +- .../spark/sql/kafka010/KafkaSourceRDD.scala | 38 --- .../KafkaDontFailOnDataLossSuite.scala | 272 ++++++++++++++++++ .../kafka010/KafkaMicroBatchSourceSuite.scala | 139 +-------- 4 files changed, 276 insertions(+), 175 deletions(-) create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala index c31af60b8a1c2..70f37e32e78db 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala @@ -341,6 +341,7 @@ private[kafka010] case class KafkaMicroBatchPartitionReader( val record = consumer.get(nextOffset, rangeToRead.untilOffset, pollTimeoutMs, failOnDataLoss) if (record != null) { nextRow = converter.toUnsafeRow(record) + nextOffset = record.offset + 1 true } else { false @@ -352,7 +353,6 @@ private[kafka010] case class KafkaMicroBatchPartitionReader( override def get(): UnsafeRow = { assert(nextRow != null) - nextOffset += 1 nextRow } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala index 8b4494d2e9a25..f8b90056d2931 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -77,44 +77,6 @@ private[kafka010] class KafkaSourceRDD( offsetRanges.zipWithIndex.map { case (o, i) => new KafkaSourceRDDPartition(i, o) }.toArray } - override def count(): Long = offsetRanges.map(_.size).sum - - override def countApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] = { - val c = count - new PartialResult(new BoundedDouble(c, 1.0, c, c), true) - } - - override def isEmpty(): Boolean = count == 0L - - override def take(num: Int): Array[ConsumerRecord[Array[Byte], Array[Byte]]] = { - val nonEmptyPartitions = - this.partitions.map(_.asInstanceOf[KafkaSourceRDDPartition]).filter(_.offsetRange.size > 0) - - if (num < 1 || nonEmptyPartitions.isEmpty) { - return new Array[ConsumerRecord[Array[Byte], Array[Byte]]](0) - } - - // Determine in advance how many messages need to be taken from each partition - val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => - val remain = num - result.values.sum - if (remain > 0) { - val taken = Math.min(remain, part.offsetRange.size) - result + (part.index -> taken.toInt) - } else { - result - } - } - - val buf = new ArrayBuffer[ConsumerRecord[Array[Byte], Array[Byte]]] - val res = context.runJob( - this, - (tc: TaskContext, it: Iterator[ConsumerRecord[Array[Byte], Array[Byte]]]) => - it.take(parts(tc.partitionId)).toArray, parts.keys.toArray - ) - res.foreach(buf ++= _) - buf.toArray - } - override def getPreferredLocations(split: Partition): Seq[String] = { val part = split.asInstanceOf[KafkaSourceRDDPartition] part.offsetRange.preferredLoc.map(Seq(_)).getOrElse(Seq.empty) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala new file mode 100644 index 0000000000000..0ff341c1a3db7 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.Properties +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.mutable +import scala.util.Random + +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkContext +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter} +import org.apache.spark.sql.streaming.{StreamTest, Trigger} +import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} + +/** + * This is a basic test trait which will set up a Kafka cluster that keeps only several records in + * a topic and ages out records very quickly. This is a helper trait to test + * "failonDataLoss=false" case with missing offsets. + * + * Note: there is a hard-code 30 seconds delay (kafka.log.LogManager.InitialTaskDelayMs) to clean up + * records. Hence each class extending this trait needs to wait at least 30 seconds (or even longer + * when running on a slow Jenkins machine) before records start to be removed. To make sure a test + * does see missing offsets, you can check the earliest offset in `eventually` and make sure it's + * not 0 rather than sleeping a hard-code duration. + */ +trait KafkaMissingOffsetsTest extends SharedSQLContext { + + protected var testUtils: KafkaTestUtils = _ + + override def createSparkSession(): TestSparkSession = { + // Set maxRetries to 3 to handle NPE from `poll` when deleting a topic + new TestSparkSession(new SparkContext("local[2,3]", "test-sql-context", sparkConf)) + } + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils { + override def brokerConfiguration: Properties = { + val props = super.brokerConfiguration + // Try to make Kafka clean up messages as fast as possible. However, there is a hard-code + // 30 seconds delay (kafka.log.LogManager.InitialTaskDelayMs) so this test should run at + // least 30 seconds. + props.put("log.cleaner.backoff.ms", "100") + // The size of RecordBatch V2 increases to support transactional write. + props.put("log.segment.bytes", "70") + props.put("log.retention.bytes", "40") + props.put("log.retention.check.interval.ms", "100") + props.put("delete.retention.ms", "10") + props.put("log.flush.scheduler.interval.ms", "10") + props + } + } + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + super.afterAll() + } +} + +class KafkaDontFailOnDataLossSuite extends KafkaMissingOffsetsTest { + + import testImplicits._ + + private val topicId = new AtomicInteger(0) + + private def newTopic(): String = s"failOnDataLoss-${topicId.getAndIncrement()}" + + /** + * @param testStreamingQuery whether to test a streaming query or a batch query. + * @param writeToTable the function to write the specified [[DataFrame]] to the given table. + */ + private def verifyMissingOffsetsDontCauseDuplicatedRecords( + testStreamingQuery: Boolean)(writeToTable: (DataFrame, String) => Unit): Unit = { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 1) + testUtils.sendMessages(topic, (0 until 50).map(_.toString).toArray) + + eventually(timeout(60.seconds)) { + assert( + testUtils.getEarliestOffsets(Set(topic)).head._2 > 0, + "Kafka didn't delete records after 1 minute") + } + + val table = "DontFailOnDataLoss" + withTable(table) { + val kafkaOptions = Map( + "kafka.bootstrap.servers" -> testUtils.brokerAddress, + "kafka.metadata.max.age.ms" -> "1", + "subscribe" -> topic, + "startingOffsets" -> s"""{"$topic":{"0":0}}""", + "failOnDataLoss" -> "false", + "kafkaConsumer.pollTimeoutMs" -> "1000") + val df = + if (testStreamingQuery) { + val reader = spark.readStream.format("kafka") + kafkaOptions.foreach(kv => reader.option(kv._1, kv._2)) + reader.load() + } else { + val reader = spark.read.format("kafka") + kafkaOptions.foreach(kv => reader.option(kv._1, kv._2)) + reader.load() + } + writeToTable(df.selectExpr("CAST(value AS STRING)"), table) + val result = spark.table(table).as[String].collect().toList + assert(result.distinct.size === result.size, s"$result contains duplicated records") + // Make sure Kafka did remove some records so that this test is valid. + assert(result.size > 0 && result.size < 50) + } + } + + test("failOnDataLoss=false should not return duplicated records: v1") { + withSQLConf( + "spark.sql.streaming.disabledV2MicroBatchReaders" -> + classOf[KafkaSourceProvider].getCanonicalName) { + verifyMissingOffsetsDontCauseDuplicatedRecords(testStreamingQuery = true) { (df, table) => + val query = df.writeStream.format("memory").queryName(table).start() + try { + query.processAllAvailable() + } finally { + query.stop() + } + } + } + } + + test("failOnDataLoss=false should not return duplicated records: v2") { + verifyMissingOffsetsDontCauseDuplicatedRecords(testStreamingQuery = true) { (df, table) => + val query = df.writeStream.format("memory").queryName(table).start() + try { + query.processAllAvailable() + } finally { + query.stop() + } + } + } + + test("failOnDataLoss=false should not return duplicated records: continuous processing") { + verifyMissingOffsetsDontCauseDuplicatedRecords(testStreamingQuery = true) { (df, table) => + val query = df.writeStream + .format("memory") + .queryName(table) + .trigger(Trigger.Continuous(100)) + .start() + try { + query.processAllAvailable() + } finally { + query.stop() + } + } + } + + test("failOnDataLoss=false should not return duplicated records: batch") { + verifyMissingOffsetsDontCauseDuplicatedRecords(testStreamingQuery = false) { (df, table) => + df.write.saveAsTable(table) + } + } +} + +class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with KafkaMissingOffsetsTest { + + import testImplicits._ + + private val topicId = new AtomicInteger(0) + + private def newTopic(): String = s"failOnDataLoss-${topicId.getAndIncrement()}" + + protected def startStream(ds: Dataset[Int]) = { + ds.writeStream.foreach(new ForeachWriter[Int] { + + override def open(partitionId: Long, version: Long): Boolean = true + + override def process(value: Int): Unit = { + // Slow down the processing speed so that messages may be aged out. + Thread.sleep(Random.nextInt(500)) + } + + override def close(errorOrNull: Throwable): Unit = {} + }).start() + } + + test("stress test for failOnDataLoss=false") { + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("kafka.default.api.timeout.ms", "3000") + .option("subscribePattern", "failOnDataLoss.*") + .option("startingOffsets", "earliest") + .option("failOnDataLoss", "false") + .option("fetchOffset.retryIntervalMs", "3000") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val query = startStream(kafka.map(kv => kv._2.toInt)) + + val testTime = 1.minutes + val startTime = System.currentTimeMillis() + // Track the current existing topics + val topics = mutable.ArrayBuffer[String]() + // Track topics that have been deleted + val deletedTopics = mutable.Set[String]() + while (System.currentTimeMillis() - testTime.toMillis < startTime) { + Random.nextInt(10) match { + case 0 => // Create a new topic + val topic = newTopic() + topics += topic + // As pushing messages into Kafka updates Zookeeper asynchronously, there is a small + // chance that a topic will be recreated after deletion due to the asynchronous update. + // Hence, always overwrite to handle this race condition. + testUtils.createTopic(topic, partitions = 1, overwrite = true) + logInfo(s"Create topic $topic") + case 1 if topics.nonEmpty => // Delete an existing topic + val topic = topics.remove(Random.nextInt(topics.size)) + testUtils.deleteTopic(topic) + logInfo(s"Delete topic $topic") + deletedTopics += topic + case 2 if deletedTopics.nonEmpty => // Recreate a topic that was deleted. + val topic = deletedTopics.toSeq(Random.nextInt(deletedTopics.size)) + deletedTopics -= topic + topics += topic + // As pushing messages into Kafka updates Zookeeper asynchronously, there is a small + // chance that a topic will be recreated after deletion due to the asynchronous update. + // Hence, always overwrite to handle this race condition. + testUtils.createTopic(topic, partitions = 1, overwrite = true) + logInfo(s"Create topic $topic") + case 3 => + Thread.sleep(1000) + case _ => // Push random messages + for (topic <- topics) { + val size = Random.nextInt(10) + for (_ <- 0 until size) { + testUtils.sendMessages(topic, Array(Random.nextInt(10).toString)) + } + } + } + // `failOnDataLoss` is `false`, we should not fail the query + if (query.exception.nonEmpty) { + throw query.exception.get + } + } + + query.stop() + // `failOnDataLoss` is `false`, we should not fail the query + if (query.exception.nonEmpty) { + throw query.exception.get + } + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index c9c52503dcd1f..1d1455009251c 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -20,12 +20,11 @@ package org.apache.spark.sql.kafka010 import java.io._ import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.{Files, Paths} -import java.util.{Locale, Properties} +import java.util.Locale import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ -import scala.collection.mutable import scala.io.Source import scala.util.Random @@ -36,8 +35,7 @@ import org.json4s.jackson.JsonMethods._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ -import org.apache.spark.SparkContext -import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession} +import org.apache.spark.sql.{ForeachWriter, SparkSession} import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution @@ -46,7 +44,7 @@ import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock -import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} +import org.apache.spark.sql.test.SharedSQLContext abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with KafkaTest { @@ -1187,134 +1185,3 @@ class KafkaSourceStressSuite extends KafkaSourceTest { iterations = 50) } } - -class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with SharedSQLContext { - - import testImplicits._ - - private var testUtils: KafkaTestUtils = _ - - private val topicId = new AtomicInteger(0) - - private def newTopic(): String = s"failOnDataLoss-${topicId.getAndIncrement()}" - - override def createSparkSession(): TestSparkSession = { - // Set maxRetries to 3 to handle NPE from `poll` when deleting a topic - new TestSparkSession(new SparkContext("local[2,3]", "test-sql-context", sparkConf)) - } - - override def beforeAll(): Unit = { - super.beforeAll() - testUtils = new KafkaTestUtils { - override def brokerConfiguration: Properties = { - val props = super.brokerConfiguration - // Try to make Kafka clean up messages as fast as possible. However, there is a hard-code - // 30 seconds delay (kafka.log.LogManager.InitialTaskDelayMs) so this test should run at - // least 30 seconds. - props.put("log.cleaner.backoff.ms", "100") - // The size of RecordBatch V2 increases to support transactional write. - props.put("log.segment.bytes", "70") - props.put("log.retention.bytes", "40") - props.put("log.retention.check.interval.ms", "100") - props.put("delete.retention.ms", "10") - props.put("log.flush.scheduler.interval.ms", "10") - props - } - } - testUtils.setup() - } - - override def afterAll(): Unit = { - if (testUtils != null) { - testUtils.teardown() - testUtils = null - super.afterAll() - } - } - - protected def startStream(ds: Dataset[Int]) = { - ds.writeStream.foreach(new ForeachWriter[Int] { - - override def open(partitionId: Long, version: Long): Boolean = { - true - } - - override def process(value: Int): Unit = { - // Slow down the processing speed so that messages may be aged out. - Thread.sleep(Random.nextInt(500)) - } - - override def close(errorOrNull: Throwable): Unit = { - } - }).start() - } - - test("stress test for failOnDataLoss=false") { - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("kafka.default.api.timeout.ms", "3000") - .option("subscribePattern", "failOnDataLoss.*") - .option("startingOffsets", "earliest") - .option("failOnDataLoss", "false") - .option("fetchOffset.retryIntervalMs", "3000") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val query = startStream(kafka.map(kv => kv._2.toInt)) - - val testTime = 1.minutes - val startTime = System.currentTimeMillis() - // Track the current existing topics - val topics = mutable.ArrayBuffer[String]() - // Track topics that have been deleted - val deletedTopics = mutable.Set[String]() - while (System.currentTimeMillis() - testTime.toMillis < startTime) { - Random.nextInt(10) match { - case 0 => // Create a new topic - val topic = newTopic() - topics += topic - // As pushing messages into Kafka updates Zookeeper asynchronously, there is a small - // chance that a topic will be recreated after deletion due to the asynchronous update. - // Hence, always overwrite to handle this race condition. - testUtils.createTopic(topic, partitions = 1, overwrite = true) - logInfo(s"Create topic $topic") - case 1 if topics.nonEmpty => // Delete an existing topic - val topic = topics.remove(Random.nextInt(topics.size)) - testUtils.deleteTopic(topic) - logInfo(s"Delete topic $topic") - deletedTopics += topic - case 2 if deletedTopics.nonEmpty => // Recreate a topic that was deleted. - val topic = deletedTopics.toSeq(Random.nextInt(deletedTopics.size)) - deletedTopics -= topic - topics += topic - // As pushing messages into Kafka updates Zookeeper asynchronously, there is a small - // chance that a topic will be recreated after deletion due to the asynchronous update. - // Hence, always overwrite to handle this race condition. - testUtils.createTopic(topic, partitions = 1, overwrite = true) - logInfo(s"Create topic $topic") - case 3 => - Thread.sleep(1000) - case _ => // Push random messages - for (topic <- topics) { - val size = Random.nextInt(10) - for (_ <- 0 until size) { - testUtils.sendMessages(topic, Array(Random.nextInt(10).toString)) - } - } - } - // `failOnDataLoss` is `false`, we should not fail the query - if (query.exception.nonEmpty) { - throw query.exception.get - } - } - - query.stop() - // `failOnDataLoss` is `false`, we should not fail the query - if (query.exception.nonEmpty) { - throw query.exception.get - } - } -} From f8346d2fc01f1e881e4e3f9c4499bf5f9e3ceb3f Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 24 Aug 2018 13:44:19 -0700 Subject: [PATCH 1476/2461] [SPARK-25174][YARN] Limit the size of diagnostic message for am to unregister itself from rm ## What changes were proposed in this pull request? When using older versions of spark releases, a use case generated a huge code-gen file which hit the limitation `Constant pool has grown past JVM limit of 0xFFFF`. In this situation, it should fail immediately. But the diagnosis message sent to RM is too large, the ApplicationMaster suspended and RM's ZKStateStore was crashed. For 2.3 or later spark releases the limitation of code-gen has been removed, but maybe there are still some uncaught exceptions that contain oversized error message will cause such a problem. This PR is aim to cut down the diagnosis message size. ## How was this patch tested? Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22180 from yaooqinn/SPARK-25174. Authored-by: Kent Yao Signed-off-by: Marcelo Vanzin --- .../org/apache/spark/deploy/yarn/ApplicationMaster.scala | 5 +++-- .../main/scala/org/apache/spark/deploy/yarn/config.scala | 6 ++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 55ed114f8500f..8f94e3f731007 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} import java.lang.reflect.{InvocationTargetException, Modifier} -import java.net.{Socket, URI, URL} +import java.net.{URI, URL} import java.security.PrivilegedExceptionAction import java.util.concurrent.{TimeoutException, TimeUnit} @@ -28,6 +28,7 @@ import scala.concurrent.Promise import scala.concurrent.duration.Duration import scala.util.control.NonFatal +import org.apache.commons.lang3.{StringUtils => ComStrUtils} import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ @@ -368,7 +369,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } logInfo(s"Final app status: $finalStatus, exitCode: $exitCode" + Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) - finalMsg = msg + finalMsg = ComStrUtils.abbreviate(msg, sparkConf.get(AM_FINAL_MSG_LIMIT).toInt) finished = true if (!inShutdown && Thread.currentThread() != reporterThread && reporterThread != null) { logDebug("shutting down reporter thread") diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 1013fd2cc4a82..ab8273bd6321d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -192,6 +192,12 @@ package object config { .toSequence .createWithDefault(Nil) + private[spark] val AM_FINAL_MSG_LIMIT = ConfigBuilder("spark.yarn.am.finalMessageLimit") + .doc("The limit size of final diagnostic message for our ApplicationMaster to unregister from" + + " the ResourceManager.") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("1m") + /* Client-mode AM configuration. */ private[spark] val AM_CORES = ConfigBuilder("spark.yarn.am.cores") From 9714fa547325ed7b6a8066a88957537936b233dd Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 24 Aug 2018 15:03:00 -0700 Subject: [PATCH 1477/2461] [SPARK-25234][SPARKR] avoid integer overflow in parallelize ## What changes were proposed in this pull request? `parallelize` uses integer multiplication to determine the split indices. It might cause integer overflow. ## How was this patch tested? unit test Closes #22225 from mengxr/SPARK-25234. Authored-by: Xiangrui Meng Signed-off-by: Xiangrui Meng --- R/pkg/R/context.R | 9 ++++----- R/pkg/tests/fulltests/test_context.R | 7 +++++++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 7e77ea4e002d9..f168ca76b6007 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -138,11 +138,10 @@ parallelize <- function(sc, coll, numSlices = 1) { sizeLimit <- getMaxAllocationLimit(sc) objectSize <- object.size(coll) + len <- length(coll) # For large objects we make sure the size of each slice is also smaller than sizeLimit - numSerializedSlices <- max(numSlices, ceiling(objectSize / sizeLimit)) - if (numSerializedSlices > length(coll)) - numSerializedSlices <- length(coll) + numSerializedSlices <- min(len, max(numSlices, ceiling(objectSize / sizeLimit))) # Generate the slice ids to put each row # For instance, for numSerializedSlices of 22, length of 50 @@ -153,8 +152,8 @@ parallelize <- function(sc, coll, numSlices = 1) { splits <- if (numSerializedSlices > 0) { unlist(lapply(0: (numSerializedSlices - 1), function(x) { # nolint start - start <- trunc((x * length(coll)) / numSerializedSlices) - end <- trunc(((x + 1) * length(coll)) / numSerializedSlices) + start <- trunc((as.numeric(x) * len) / numSerializedSlices) + end <- trunc(((as.numeric(x) + 1) * len) / numSerializedSlices) # nolint end rep(start, end - start) })) diff --git a/R/pkg/tests/fulltests/test_context.R b/R/pkg/tests/fulltests/test_context.R index f0d0a5114f89f..288a2714a554e 100644 --- a/R/pkg/tests/fulltests/test_context.R +++ b/R/pkg/tests/fulltests/test_context.R @@ -240,3 +240,10 @@ test_that("add and get file to be downloaded with Spark job on every node", { unlink(path, recursive = TRUE) sparkR.session.stop() }) + +test_that("SPARK-25234: parallelize should not have integer overflow", { + sc <- sparkR.sparkContext(master = sparkRTestMaster) + # 47000 * 47000 exceeds integer range + parallelize(sc, 1:47000, 47000) + sparkR.session.stop() +}) From 8e6427871a40b82f4a7a28aaa6e197e4e01dc878 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?jaroslav=20chl=C3=A1dek?= Date: Sat, 25 Aug 2018 12:49:48 +0800 Subject: [PATCH 1478/2461] Correct missing punctuation in the documentation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22189 from movrsprbp/patch-1. Authored-by: jaroslav chládek Signed-off-by: hyukjinkwon --- docs/structured-streaming-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 355a6cc26973e..73de1892977ac 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1005,7 +1005,7 @@ Here is an illustration. As shown in the illustration, the maximum event time tracked by the engine is the *blue dashed line*, and the watermark set as `(max event time - '10 mins')` -at the beginning of every trigger is the red line For example, when the engine observes the data +at the beginning of every trigger is the red line. For example, when the engine observes the data `(12:14, dog)`, it sets the watermark for the next trigger as `12:04`. This watermark lets the engine maintain intermediate state for additional 10 minutes to allow late data to be counted. For example, the data `(12:09, cat)` is out of order and late, and it falls in From 3e4f1666a1253f9d5df05c19b1ce77fe18e9fde3 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Sat, 25 Aug 2018 13:48:46 +0800 Subject: [PATCH 1479/2461] [MINOR] Fix Scala 2.12 build ## What changes were proposed in this pull request? [SPARK-25095](https://github.com/apache/spark/commit/ad45299d047c10472fd3a86103930fe7c54a4cf1) introduced `ambiguous reference to overloaded definition` ``` [error] /Users/d_tsai/dev/apache-spark/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala:242: ambiguous reference to overloaded definition, [error] both method addTaskCompletionListener in class TaskContext of type [U](f: org.apache.spark.TaskContext => U)org.apache.spark.TaskContext [error] and method addTaskCompletionListener in class TaskContext of type (listener: org.apache.spark.util.TaskCompletionListener)org.apache.spark.TaskContext [error] match argument types (org.apache.spark.TaskContext => Unit) [error] context.addTaskCompletionListener(_ => server.close()) [error] ^ [error] one error found [error] Compile failed at Aug 24, 2018 1:56:06 PM [31.582s] ``` which fails the Scala 2.12 branch build. ## How was this patch tested? Existing tests Closes #22229 from dbtsai/fix-2.12-build. Authored-by: DB Tsai Signed-off-by: hyukjinkwon --- .../main/scala/org/apache/spark/api/python/PythonRunner.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index f8241915e4849..151c910bf1aee 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -239,7 +239,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( } // Close ServerSocket on task completion. serverSocket.foreach { server => - context.addTaskCompletionListener(_ => server.close()) + context.addTaskCompletionListener[Unit](_ => server.close()) } val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0) if (boundPort == -1) { From 6c66ab8b334c5358bc77995650f1886e4c43231d Mon Sep 17 00:00:00 2001 From: Huangweizhe Date: Sat, 25 Aug 2018 09:24:20 -0500 Subject: [PATCH 1480/2461] [SPARK-24688][EXAMPLES] Modify the comments about LabeledPoint ## What changes were proposed in this pull request? An RDD is created using LabeledPoint, but the comment is like #LabeledPoint(feature, label). Although in the method ChiSquareTest.test, the second parameter is feature and the third parameter is label, it it better to write label in front of feature here because if an RDD is created using LabeldPoint, what we get are actually (label, feature) pairs. Now it is changed as LabeledPoint(label, feature). The comments in Scala and Java example have the same typos. ## How was this patch tested? tested https://issues.apache.org/jira/browse/SPARK-24688 Author: Weizhe Huang 492816239qq.com Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #21665 from uzmijnlm/my_change. Authored-by: Huangweizhe Signed-off-by: Sean Owen --- .../spark/examples/mllib/JavaHypothesisTestingExample.java | 2 +- examples/src/main/python/mllib/hypothesis_testing_example.py | 2 +- .../spark/examples/mllib/HypothesisTestingExample.scala | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaHypothesisTestingExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaHypothesisTestingExample.java index b48b95ff1d2a3..273273652c955 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaHypothesisTestingExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaHypothesisTestingExample.java @@ -67,7 +67,7 @@ public static void main(String[] args) { ) ); - // The contingency table is constructed from the raw (feature, label) pairs and used to conduct + // The contingency table is constructed from the raw (label, feature) pairs and used to conduct // the independence test. Returns an array containing the ChiSquaredTestResult for every feature // against the label. ChiSqTestResult[] featureTestResults = Statistics.chiSqTest(obs.rdd()); diff --git a/examples/src/main/python/mllib/hypothesis_testing_example.py b/examples/src/main/python/mllib/hypothesis_testing_example.py index e566ead0d318d..21a5584fd6e06 100644 --- a/examples/src/main/python/mllib/hypothesis_testing_example.py +++ b/examples/src/main/python/mllib/hypothesis_testing_example.py @@ -51,7 +51,7 @@ [LabeledPoint(1.0, [1.0, 0.0, 3.0]), LabeledPoint(1.0, [1.0, 2.0, 0.0]), LabeledPoint(1.0, [-1.0, 0.0, -0.5])] - ) # LabeledPoint(feature, label) + ) # LabeledPoint(label, feature) # The contingency table is constructed from an RDD of LabeledPoint and used to conduct # the independence test. Returns an array containing the ChiSquaredTestResult for every feature diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala index add1719739539..9b3c3266ee30a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala @@ -61,9 +61,9 @@ object HypothesisTestingExample { LabeledPoint(-1.0, Vectors.dense(-1.0, 0.0, -0.5) ) ) - ) // (feature, label) pairs. + ) // (label, feature) pairs. - // The contingency table is constructed from the raw (feature, label) pairs and used to conduct + // The contingency table is constructed from the raw (label, feature) pairs and used to conduct // the independence test. Returns an array containing the ChiSquaredTestResult for every feature // against the label. val featureTestResults: Array[ChiSqTestResult] = Statistics.chiSqTest(obs) From c17a8ff52377871ab4ff96b648ebaf4112f0b5be Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Sat, 25 Aug 2018 09:17:40 -0700 Subject: [PATCH 1481/2461] [SPARK-25214][SS][FOLLOWUP] Fix the issue that Kafka v2 source may return duplicated records when `failOnDataLoss=false` ## What changes were proposed in this pull request? This is a follow up PR for #22207 to fix a potential flaky test. `processAllAvailable` doesn't work for continuous processing so we should not use it for a continuous query. ## How was this patch tested? Jenkins. Closes #22230 from zsxwing/SPARK-25214-2. Authored-by: Shixiong Zhu Signed-off-by: Shixiong Zhu --- .../spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala index 0ff341c1a3db7..39c4e3fda1a4b 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala @@ -80,7 +80,7 @@ trait KafkaMissingOffsetsTest extends SharedSQLContext { } } -class KafkaDontFailOnDataLossSuite extends KafkaMissingOffsetsTest { +class KafkaDontFailOnDataLossSuite extends StreamTest with KafkaMissingOffsetsTest { import testImplicits._ @@ -165,7 +165,11 @@ class KafkaDontFailOnDataLossSuite extends KafkaMissingOffsetsTest { .trigger(Trigger.Continuous(100)) .start() try { - query.processAllAvailable() + // `processAllAvailable` doesn't work for continuous processing, so just wait until the last + // record appears in the table. + eventually(timeout(streamingTimeout)) { + assert(spark.table(table).as[String].collect().contains("49")) + } } finally { query.stop() } From ad43e2c1e8c2142a66b135766ff0d7712ce965db Mon Sep 17 00:00:00 2001 From: Adam Bradbury Date: Sun, 26 Aug 2018 08:37:52 -0500 Subject: [PATCH 1482/2461] [SPARK-23792][DOCS] Documentation improvements for datetime functions ## What changes were proposed in this pull request? Improved the documentation for the datetime functions in `org.apache.spark.sql.functions` by adding details about the supported column input types, the column return type, behaviour on invalid input, supporting examples and clarifications. ## How was this patch tested? Manually testing each of the datetime functions with different input to ensure that the corresponding Javadoc/Scaladoc matches the behaviour of the function. Successfully ran the `unidoc` SBT process. Closes #20901 from abradbury/SPARK-23792. Authored-by: Adam Bradbury Signed-off-by: Sean Owen --- .../org/apache/spark/sql/functions.scala | 189 ++++++++++++++---- 1 file changed, 154 insertions(+), 35 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c9331883c4799..1d806e056d31c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2626,8 +2626,12 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Returns the date that is numMonths after startDate. + * Returns the date that is `numMonths` after `startDate`. * + * @param startDate A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param numMonths The number of months to add to `startDate`, can be negative to subtract months + * @return A date, or null if `startDate` was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2655,12 +2659,15 @@ object functions { * Converts a date/timestamp/string to a value of string in the format specified by the date * format given by the second argument. * - * A pattern `dd.MM.yyyy` would return a string like `18.03.1993`. - * All pattern letters of `java.text.SimpleDateFormat` can be used. + * See [[java.text.SimpleDateFormat]] for valid date and time format patterns * + * @param dateExpr A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param format A pattern `dd.MM.yyyy` would return a string like `18.03.1993` + * @return A string, or null if `dateExpr` was a string that could not be cast to a timestamp * @note Use specialized functions like [[year]] whenever possible as they benefit from a * specialized implementation. - * + * @throws IllegalArgumentException if the `format` pattern is invalid * @group datetime_funcs * @since 1.5.0 */ @@ -2670,6 +2677,11 @@ object functions { /** * Returns the date that is `days` days after `start` + * + * @param start A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param days The number of days to add to `start`, can be negative to subtract days + * @return A date, or null if `start` was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2677,6 +2689,11 @@ object functions { /** * Returns the date that is `days` days before `start` + * + * @param start A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param days The number of days to subtract from `start`, can be negative to add days + * @return A date, or null if `start` was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2684,6 +2701,19 @@ object functions { /** * Returns the number of days from `start` to `end`. + * + * Only considers the date part of the input. For example: + * {{{ + * dateddiff("2018-01-10 00:00:00", "2018-01-09 23:59:59") + * // returns 1 + * }}} + * + * @param end A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param start A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return An integer, or null if either `end` or `start` were strings that could not be cast to + * a date. Negative if `end` is before `start` * @group datetime_funcs * @since 1.5.0 */ @@ -2691,6 +2721,7 @@ object functions { /** * Extracts the year as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2698,6 +2729,7 @@ object functions { /** * Extracts the quarter as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2705,6 +2737,7 @@ object functions { /** * Extracts the month as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2712,6 +2745,8 @@ object functions { /** * Extracts the day of the week as an integer from a given date/timestamp/string. + * Ranges from 1 for a Sunday through to 7 for a Saturday + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 2.3.0 */ @@ -2719,6 +2754,7 @@ object functions { /** * Extracts the day of the month as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2726,6 +2762,7 @@ object functions { /** * Extracts the day of the year as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2733,16 +2770,20 @@ object functions { /** * Extracts the hours as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ def hour(e: Column): Column = withExpr { Hour(e.expr) } /** - * Given a date column, returns the last day of the month which the given date belongs to. + * Returns the last day of the month which the given date belongs to. * For example, input "2015-07-27" returns "2015-07-31" since July 31 is the last day of the * month in July 2015. * + * @param e A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return A date, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2750,46 +2791,60 @@ object functions { /** * Extracts the minutes as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ def minute(e: Column): Column = withExpr { Minute(e.expr) } /** - * Returns number of months between dates `date1` and `date2`. - * If `date1` is later than `date2`, then the result is positive. - * If `date1` and `date2` are on the same day of month, or both are the last day of month, - * time of day will be ignored. + * Returns number of months between dates `start` and `end`. + * + * A whole number is returned if both inputs have the same day of month or both are the last day + * of their respective months. Otherwise, the difference is calculated assuming 31 days per month. * - * Otherwise, the difference is calculated based on 31 days per month, and rounded to - * 8 digits. + * For example: + * {{{ + * months_between("2017-11-14", "2017-07-14") // returns 4.0 + * months_between("2017-01-01", "2017-01-10") // returns 0.29032258 + * months_between("2017-06-01", "2017-06-16 12:00:00") // returns -0.5 + * }}} + * + * @param end A date, timestamp or string. If a string, the data must be in a format that can + * be cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param start A date, timestamp or string. If a string, the data must be in a format that can + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return A double, or null if either `end` or `start` were strings that could not be cast to a + * timestamp. Negative if `end` is before `start` * @group datetime_funcs * @since 1.5.0 */ - def months_between(date1: Column, date2: Column): Column = withExpr { - new MonthsBetween(date1.expr, date2.expr) + def months_between(end: Column, start: Column): Column = withExpr { + new MonthsBetween(end.expr, start.expr) } /** - * Returns number of months between dates `date1` and `date2`. If `roundOff` is set to true, the + * Returns number of months between dates `end` and `start`. If `roundOff` is set to true, the * result is rounded off to 8 digits; it is not rounded otherwise. * @group datetime_funcs * @since 2.4.0 */ - def months_between(date1: Column, date2: Column, roundOff: Boolean): Column = withExpr { - MonthsBetween(date1.expr, date2.expr, lit(roundOff).expr) + def months_between(end: Column, start: Column, roundOff: Boolean): Column = withExpr { + MonthsBetween(end.expr, start.expr, lit(roundOff).expr) } /** - * Given a date column, returns the first date which is later than the value of the date column - * that is on the specified day of the week. + * Returns the first date which is later than the value of the `date` column that is on the + * specified day of the week. * * For example, `next_day('2015-07-27', "Sunday")` returns 2015-08-02 because that is the first * Sunday after 2015-07-27. * - * Day of the week parameter is case insensitive, and accepts: - * "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". - * + * @param date A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param dayOfWeek Case insensitive, and accepts: "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun" + * @return A date, or null if `date` was a string that could not be cast to a date or if + * `dayOfWeek` was an invalid value * @group datetime_funcs * @since 1.5.0 */ @@ -2799,6 +2854,7 @@ object functions { /** * Extracts the seconds as an integer from a given date/timestamp/string. + * @return An integer, or null if the input was a string that could not be cast to a timestamp * @group datetime_funcs * @since 1.5.0 */ @@ -2806,6 +2862,11 @@ object functions { /** * Extracts the week number as an integer from a given date/timestamp/string. + * + * A week is considered to start on a Monday and week 1 is the first week with more than 3 days, + * as defined by ISO 8601 + * + * @return An integer, or null if the input was a string that could not be cast to a date * @group datetime_funcs * @since 1.5.0 */ @@ -2813,8 +2874,12 @@ object functions { /** * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string - * representing the timestamp of that moment in the current system time zone in the given - * format. + * representing the timestamp of that moment in the current system time zone in the + * yyyy-MM-dd HH:mm:ss format. + * + * @param ut A number of a type that is castable to a long, such as string or integer. Can be + * negative for timestamps before the unix epoch + * @return A string, or null if the input was a string that could not be cast to a long * @group datetime_funcs * @since 1.5.0 */ @@ -2826,6 +2891,14 @@ object functions { * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string * representing the timestamp of that moment in the current system time zone in the given * format. + * + * See [[java.text.SimpleDateFormat]] for valid date and time format patterns + * + * @param ut A number of a type that is castable to a long, such as string or integer. Can be + * negative for timestamps before the unix epoch + * @param f A date time pattern that the input will be formatted to + * @return A string, or null if `ut` was a string that could not be cast to a long or `f` was + * an invalid date time pattern * @group datetime_funcs * @since 1.5.0 */ @@ -2834,7 +2907,7 @@ object functions { } /** - * Returns the current Unix timestamp (in seconds). + * Returns the current Unix timestamp (in seconds) as a long. * * @note All calls of `unix_timestamp` within the same query return the same value * (i.e. the current timestamp is calculated at the start of query evaluation). @@ -2849,8 +2922,10 @@ object functions { /** * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), * using the default timezone and the default locale. - * Returns `null` if fails. * + * @param s A date, timestamp or string. If a string, the data must be in the + * `yyyy-MM-dd HH:mm:ss` format + * @return A long, or null if the input was a string not of the correct format * @group datetime_funcs * @since 1.5.0 */ @@ -2860,17 +2935,25 @@ object functions { /** * Converts time string with given pattern to Unix timestamp (in seconds). - * Returns `null` if fails. * - * @see - * Customizing Formats + * See [[java.text.SimpleDateFormat]] for valid date and time format patterns + * + * @param s A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param p A date time pattern detailing the format of `s` when `s` is a string + * @return A long, or null if `s` was a string that could not be cast to a date or `p` was + * an invalid format * @group datetime_funcs * @since 1.5.0 */ def unix_timestamp(s: Column, p: String): Column = withExpr { UnixTimestamp(s.expr, Literal(p)) } /** - * Convert time string to a Unix timestamp (in seconds) by casting rules to `TimestampType`. + * Converts to a timestamp by casting rules to `TimestampType`. + * + * @param s A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return A timestamp, or null if the input was a string that could not be cast to a timestamp * @group datetime_funcs * @since 2.2.0 */ @@ -2879,9 +2962,15 @@ object functions { } /** - * Convert time string to a Unix timestamp (in seconds) with a specified format - * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) - * to Unix timestamp (in seconds), return null if fail. + * Converts time string with the given pattern to timestamp. + * + * See [[java.text.SimpleDateFormat]] for valid date and time format patterns + * + * @param s A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param fmt A date time pattern detailing the format of `s` when `s` is a string + * @return A timestamp, or null if `s` was a string that could not be cast to a timestamp or + * `fmt` was an invalid format * @group datetime_funcs * @since 2.2.0 */ @@ -2899,9 +2988,14 @@ object functions { /** * Converts the column into a `DateType` with a specified format - * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) - * return null if fail. * + * See [[java.text.SimpleDateFormat]] for valid date and time format patterns + * + * @param e A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param fmt A date time pattern detailing the format of `e` when `e`is a string + * @return A date, or null if `e` was a string that could not be cast to a date or `fmt` was an + * invalid format * @group datetime_funcs * @since 2.2.0 */ @@ -2912,9 +3006,15 @@ object functions { /** * Returns date truncated to the unit specified by the format. * + * For example, `trunc("2018-11-19 12:01:19", "year")` returns 2018-01-01 + * + * @param date A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` * @param format: 'year', 'yyyy', 'yy' for truncate by year, * or 'month', 'mon', 'mm' for truncate by month * + * @return A date, or null if `date` was a string that could not be cast to a date or `format` + * was an invalid value * @group datetime_funcs * @since 1.5.0 */ @@ -2925,11 +3025,16 @@ object functions { /** * Returns timestamp truncated to the unit specified by the format. * + * For example, `date_tunc("2018-11-19 12:01:19", "year")` returns 2018-01-01 00:00:00 + * * @param format: 'year', 'yyyy', 'yy' for truncate by year, * 'month', 'mon', 'mm' for truncate by month, * 'day', 'dd' for truncate by day, * Other options are: 'second', 'minute', 'hour', 'week', 'month', 'quarter' - * + * @param timestamp A date, timestamp or string. If a string, the data must be in a format that + * can be cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return A timestamp, or null if `timestamp` was a string that could not be cast to a timestamp + * or `format` was an invalid value * @group datetime_funcs * @since 2.3.0 */ @@ -2941,6 +3046,13 @@ object functions { * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield * '2017-07-14 03:40:00.0'. + * + * @param ts A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param tz A string detailing the time zone that the input should be adjusted to, such as + * `Europe/London`, `PST` or `GMT+5` + * @return A timestamp, or null if `ts` was a string that could not be cast to a timestamp or + * `tz` was an invalid value * @group datetime_funcs * @since 1.5.0 */ @@ -2963,6 +3075,13 @@ object functions { * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time * zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield * '2017-07-14 01:40:00.0'. + * + * @param ts A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param tz A string detailing the time zone that the input belongs to, such as `Europe/London`, + * `PST` or `GMT+5` + * @return A timestamp, or null if `ts` was a string that could not be cast to a timestamp or + * `tz` was an invalid value * @group datetime_funcs * @since 1.5.0 */ From 5cdb8a23df6f269d6be0bf3536e9af9e29c4a05f Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 27 Aug 2018 10:02:31 +0800 Subject: [PATCH 1483/2461] [SPARK-23698][PYTHON][FOLLOWUP] Resolve undefiend names in setup.py ## What changes were proposed in this pull request? `__version__` in `setup.py` is currently being dynamically read by `exec`; so the linter complains. Better just switch it off for this line for now. **Before:** ```bash $ python -m flake8 . --count --select=E9,F82 --show-source --statistics ./setup.py:37:11: F821 undefined name '__version__' VERSION = __version__ ^ 1 F821 undefined name '__version__' 1 ``` **After:** ```bash $ python -m flake8 . --count --select=E9,F82 --show-source --statistics 0 ``` ## How was this patch tested? Manually tested. Closes #22235 from HyukjinKwon/SPARK-23698. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- python/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/setup.py b/python/setup.py index 45eb74eb87ce7..c447f2d40343d 100644 --- a/python/setup.py +++ b/python/setup.py @@ -34,7 +34,7 @@ print("Failed to load PySpark version file for packaging. You must be in Spark's python dir.", file=sys.stderr) sys.exit(-1) -VERSION = __version__ +VERSION = __version__ # noqa # A temporary path so we can access above the Python project root and fetch scripts and jars we need TEMP_PATH = "deps" SPARK_HOME = os.path.abspath("../") From 5c27b0d4f8d378bd7889d26fb358f478479b9996 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 27 Aug 2018 14:02:50 +0800 Subject: [PATCH 1484/2461] [SPARK-19355][SQL][FOLLOWUP] Remove the child.outputOrdering check in global limit ## What changes were proposed in this pull request? This is based on the discussion https://github.com/apache/spark/pull/16677/files#r212805327. As SQL standard doesn't mandate that a nested order by followed by a limit has to respect that ordering clause, this patch removes the `child.outputOrdering` check. ## How was this patch tested? Unit tests. Closes #22239 from viirya/improve-global-limit-parallelism-followup. Authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- .../scala/org/apache/spark/sql/execution/limit.scala | 10 +++++----- .../sql/execution/TakeOrderedAndProjectSuite.scala | 10 ++++++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 392ca13724bc6..fb46970e38f3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -122,11 +122,11 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { Nil } - // During global limit, try to evenly distribute limited rows across data - // partitions. If disabled, scanning data partitions sequentially until reaching limit number. - // Besides, if child output has certain ordering, we can't evenly pick up rows from - // each parititon. - val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit && child.outputOrdering == Nil + // This is an optimization to evenly distribute limited rows across all partitions. + // When enabled, Spark goes to take rows at each partition repeatedly until reaching + // limit number. When disabled, Spark takes all rows at first partition, then rows + // at second partition ..., until reaching limit number. + val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit val shuffled = new ShuffledRowRDD(shuffleDependency) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala index 7e317a4d80265..0a1c94cc4ccf4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala @@ -22,6 +22,7 @@ import scala.util.Random import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -31,10 +32,19 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { private var rand: Random = _ private var seed: Long = 0 + private val originalLimitFlatGlobalLimit = SQLConf.get.getConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT) + protected override def beforeAll(): Unit = { super.beforeAll() seed = System.currentTimeMillis() rand = new Random(seed) + + // Disable the optimization to make Sort-Limit match `TakeOrderedAndProject` semantics. + SQLConf.get.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false) + } + + protected override def afterAll() = { + SQLConf.get.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit) } private def generateRandomInputData(): DataFrame = { From 6193a202aab0271b4532ee4b740318290f2c44a1 Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Mon, 27 Aug 2018 15:45:48 +0800 Subject: [PATCH 1485/2461] [SPARK-24978][SQL] Add spark.sql.fast.hash.aggregate.row.max.capacity to configure the capacity of fast aggregation. ## What changes were proposed in this pull request? this pr add a configuration parameter to configure the capacity of fast aggregation. Performance comparison: ``` Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Windows 7 6.1 Intel64 Family 6 Model 94 Stepping 3, GenuineIntel Aggregate w multiple keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ fasthash = default 5612 / 5882 3.7 267.6 1.0X fasthash = config 3586 / 3595 5.8 171.0 1.6X ``` ## How was this patch tested? the existed test cases. Closes #21931 from heary-cao/FastHashCapacity. Authored-by: caoxuewen Signed-off-by: Wenchen Fan --- .../org/apache/spark/sql/internal/SQLConf.scala | 13 +++++++++++++ .../sql/execution/aggregate/HashAggregateExec.scala | 5 +++-- .../aggregate/RowBasedHashMapGenerator.scala | 5 +++-- .../aggregate/VectorizedHashMapGenerator.scala | 5 +++-- 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ef3ce98fd7add..6336e89671937 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1485,6 +1485,17 @@ object SQLConf { .intConf .createWithDefault(20) + val FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT = + buildConf("spark.sql.codegen.aggregate.fastHashMap.capacityBit") + .internal() + .doc("Capacity for the max number of rows to be held in memory " + + "by the fast hash aggregate product operator. The bit is not for actual value, " + + "but the actual numBuckets is determined by loadFactor " + + "(e.g: default bit value 16 , the actual numBuckets is ((1 << 16) / 0.5).") + .intConf + .checkValue(bit => bit >= 10 && bit <= 30, "The bit value must be in [10, 30].") + .createWithDefault(16) + val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec") .doc("Compression codec used in writing of AVRO files. Supported codecs: " + "uncompressed, deflate, snappy, bzip2 and xz. Default codec is snappy.") @@ -1703,6 +1714,8 @@ class SQLConf extends Serializable with Logging { def topKSortFallbackThreshold: Int = getConf(TOP_K_SORT_FALLBACK_THRESHOLD) + def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 2cac0cfce28de..98adba50b2973 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -579,6 +579,7 @@ case class HashAggregateExec( case _ => } } + val bitMaxCapacity = sqlContext.conf.fastHashAggregateRowMaxCapacityBit val thisPlan = ctx.addReferenceObj("plan", this) @@ -588,7 +589,7 @@ case class HashAggregateExec( val fastHashMapClassName = ctx.freshName("FastHashMap") if (isVectorizedHashMapEnabled) { val generatedMap = new VectorizedHashMapGenerator(ctx, aggregateExpressions, - fastHashMapClassName, groupingKeySchema, bufferSchema).generate() + fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate() ctx.addInnerClass(generatedMap) // Inline mutable state since not many aggregation operations in a task @@ -598,7 +599,7 @@ case class HashAggregateExec( forceInline = true) } else { val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions, - fastHashMapClassName, groupingKeySchema, bufferSchema).generate() + fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate() ctx.addInnerClass(generatedMap) // Inline mutable state since not many aggregation operations in a task diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index ca59bb145f299..3d2443ca959a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -39,7 +39,8 @@ class RowBasedHashMapGenerator( aggregateExpressions: Seq[AggregateExpression], generatedClassName: String, groupingKeySchema: StructType, - bufferSchema: StructType) + bufferSchema: StructType, + bitMaxCapacity: Int) extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName, groupingKeySchema, bufferSchema) { @@ -50,7 +51,7 @@ class RowBasedHashMapGenerator( s""" | private org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch; | private int[] buckets; - | private int capacity = 1 << 16; + | private int capacity = 1 << $bitMaxCapacity; | private double loadFactor = 0.5; | private int numBuckets = (int) (capacity / loadFactor); | private int maxSteps = 2; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 95ebefed08f67..f9c4ecc14e6c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -47,7 +47,8 @@ class VectorizedHashMapGenerator( aggregateExpressions: Seq[AggregateExpression], generatedClassName: String, groupingKeySchema: StructType, - bufferSchema: StructType) + bufferSchema: StructType, + bitMaxCapacity: Int) extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName, groupingKeySchema, bufferSchema) { @@ -61,7 +62,7 @@ class VectorizedHashMapGenerator( | private ${classOf[ColumnarBatch].getName} batch; | private ${classOf[MutableColumnarRow].getName} aggBufferRow; | private int[] buckets; - | private int capacity = 1 << 16; + | private int capacity = 1 << $bitMaxCapacity; | private double loadFactor = 0.5; | private int numBuckets = (int) (capacity / loadFactor); | private int maxSteps = 2; From 381a967a76c9e7ea1e100a922cafedc50042b81e Mon Sep 17 00:00:00 2001 From: liuxian Date: Mon, 27 Aug 2018 12:05:33 -0500 Subject: [PATCH 1486/2461] [SPARK-25249][CORE][TEST] add a unit test for OpenHashMap ## What changes were proposed in this pull request? This PR adds a unit test for OpenHashMap , this can help developers to distinguish between the 0/0.0/0L and null ## How was this patch tested? Closes #22241 from 10110346/openhashmap. Authored-by: liuxian Signed-off-by: Sean Owen --- .../util/collection/OpenHashMapSuite.scala | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index 08a3200288981..151235dd0fb90 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -194,4 +194,50 @@ class OpenHashMapSuite extends SparkFunSuite with Matchers { val numInvalidValues = map.iterator.count(_._2 == 0) assertResult(0)(numInvalidValues) } + + test("distinguish between the 0/0.0/0L and null") { + val specializedMap1 = new OpenHashMap[String, Long] + specializedMap1("a") = null.asInstanceOf[Long] + specializedMap1("b") = 0L + assert(specializedMap1.contains("a")) + assert(!specializedMap1.contains("c")) + // null.asInstance[Long] will return 0L + assert(specializedMap1("a") === 0L) + assert(specializedMap1("b") === 0L) + // If the data type is in @specialized annotation, and + // the `key` is not be contained, the `map(key)` will return 0 + assert(specializedMap1("c") === 0L) + + val specializedMap2 = new OpenHashMap[String, Double] + specializedMap2("a") = null.asInstanceOf[Double] + specializedMap2("b") = 0.toDouble + assert(specializedMap2.contains("a")) + assert(!specializedMap2.contains("c")) + // null.asInstance[Double] will return 0.0 + assert(specializedMap2("a") === 0.0) + assert(specializedMap2("b") === 0.0) + assert(specializedMap2("c") === 0.0) + + val map1 = new OpenHashMap[String, Short] + map1("a") = null.asInstanceOf[Short] + map1("b") = 0.toShort + assert(map1.contains("a")) + assert(!map1.contains("c")) + // null.asInstance[Short] will return 0 + assert(map1("a") === 0) + assert(map1("b") === 0) + // If the data type is not in @specialized annotation, and + // the `key` is not be contained, the `map(key)` will return null + assert(map1("c") === null) + + val map2 = new OpenHashMap[String, Float] + map2("a") = null.asInstanceOf[Float] + map2("b") = 0.toFloat + assert(map2.contains("a")) + assert(!map2.contains("c")) + // null.asInstance[Float] will return 0.0 + assert(map2("a") === 0.0) + assert(map2("b") === 0.0) + assert(map2("c") === null) + } } From 810d59ce44e43f725d1b6d822166c2d97ff49929 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 27 Aug 2018 11:04:39 -0700 Subject: [PATCH 1487/2461] [SPARK-24882][FOLLOWUP] Fix flaky synchronization in Kafka tests. ## What changes were proposed in this pull request? Fix flaky synchronization in Kafka tests - we need to use the scan config that was persisted rather than reconstructing it to identify the stream's current configuration. We caught most instances of this in the original PR, but this one slipped through. ## How was this patch tested? n/a Closes #22245 from jose-torres/fixflake. Authored-by: Jose Torres Signed-off-by: Shixiong Zhu --- .../sql/kafka010/KafkaContinuousSourceSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index 321665042b8eb..5d68a14326c00 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.kafka010 import org.apache.spark.sql.Dataset -import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec import org.apache.spark.sql.streaming.Trigger // Run tests in KafkaSourceSuiteBase in continuous execution mode. @@ -60,10 +60,10 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { testUtils.createTopic(topic2, partitions = 5) eventually(timeout(streamingTimeout)) { assert( - query.lastExecution.logical.collectFirst { - case r: StreamingDataSourceV2Relation - if r.readSupport.isInstanceOf[KafkaContinuousReadSupport] => - r.scanConfigBuilder.build().asInstanceOf[KafkaContinuousScanConfig] + query.lastExecution.executedPlan.collectFirst { + case scan: DataSourceV2ScanExec + if scan.readSupport.isInstanceOf[KafkaContinuousReadSupport] => + scan.scanConfig.asInstanceOf[KafkaContinuousScanConfig] }.exists { config => // Ensure the new topic is present and the old topic is gone. config.knownPartitions.exists(_.topic == topic2) From c3f285c939ba046de5171ada9c4bbb1a2589635d Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 27 Aug 2018 13:26:55 -0700 Subject: [PATCH 1488/2461] [SPARK-24149][YARN][FOLLOW-UP] Only get the delegation tokens of the filesystem explicitly specified by the user ## What changes were proposed in this pull request? Our HDFS cluster configured 5 nameservices: `nameservices1`, `nameservices2`, `nameservices3`, `nameservices-dev1` and `nameservices4`, but `nameservices-dev1` unstable. So sometimes an error occurred and causing the entire job failed since [SPARK-24149](https://issues.apache.org/jira/browse/SPARK-24149): ![image](https://user-images.githubusercontent.com/5399861/42434779-f10c48fc-8386-11e8-98b0-4d9786014744.png) I think it's best to add a switch here. ## How was this patch tested? manual tests Closes #21734 from wangyum/SPARK-24149. Authored-by: Yuming Wang Signed-off-by: Marcelo Vanzin --- .../spark/deploy/yarn/YarnSparkHadoopUtil.scala | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 7250e58b6c49a..3a3272216294f 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -27,11 +27,8 @@ import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.records.{ApplicationAccessType, ContainerId, Priority} import org.apache.hadoop.yarn.util.ConverterUtils -import org.apache.spark.{SecurityManager, SparkConf, SparkException} -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager -import org.apache.spark.internal.config._ import org.apache.spark.launcher.YarnCommandBuilderUtils import org.apache.spark.util.Utils @@ -193,8 +190,7 @@ object YarnSparkHadoopUtil { sparkConf: SparkConf, hadoopConf: Configuration): Set[FileSystem] = { val filesystemsToAccess = sparkConf.get(FILESYSTEMS_TO_ACCESS) - .map(new Path(_).getFileSystem(hadoopConf)) - .toSet + val requestAllDelegationTokens = filesystemsToAccess.isEmpty val stagingFS = sparkConf.get(STAGING_DIR) .map(new Path(_).getFileSystem(hadoopConf)) @@ -203,8 +199,8 @@ object YarnSparkHadoopUtil { // Add the list of available namenodes for all namespaces in HDFS federation. // If ViewFS is enabled, this is skipped as ViewFS already handles delegation tokens for its // namespaces. - val hadoopFilesystems = if (stagingFS.getScheme == "viewfs") { - Set.empty + val hadoopFilesystems = if (!requestAllDelegationTokens || stagingFS.getScheme == "viewfs") { + filesystemsToAccess.map(new Path(_).getFileSystem(hadoopConf)).toSet } else { val nameservices = hadoopConf.getTrimmedStrings("dfs.nameservices") // Retrieving the filesystem for the nameservices where HA is not enabled @@ -222,7 +218,7 @@ object YarnSparkHadoopUtil { (filesystemsWithoutHA ++ filesystemsWithHA).toSet } - filesystemsToAccess ++ hadoopFilesystems + stagingFS + hadoopFilesystems + stagingFS } } From dac099d08251e73b9a658e506ed6802b294ac051 Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Mon, 27 Aug 2018 15:55:34 -0500 Subject: [PATCH 1489/2461] [SPARK-24090][K8S] Update running-on-kubernetes.md ## What changes were proposed in this pull request? Updated documentation for Spark on Kubernetes for the upcoming 2.4.0. Please review http://spark.apache.org/contributing.html before opening a pull request. mccheah erikerlandson Closes #22224 from liyinan926/master. Authored-by: Yinan Li Signed-off-by: Sean Owen --- docs/running-on-kubernetes.md | 40 ++++++++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 8f84ca044e163..c83dad6df1e7b 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -185,6 +185,36 @@ To use a secret through an environment variable use the following options to the --conf spark.kubernetes.executor.secretKeyRef.ENV_NAME=name:key ``` +## Using Kubernetes Volumes + +Starting with Spark 2.4.0, users can mount the following types of Kubernetes [volumes](https://kubernetes.io/docs/concepts/storage/volumes/) into the driver and executor pods: +* [hostPath](https://kubernetes.io/docs/concepts/storage/volumes/#hostpath): mounts a file or directory from the host node’s filesystem into a pod. +* [emptyDir](https://kubernetes.io/docs/concepts/storage/volumes/#emptydir): an initially empty volume created when a pod is assigned to a node. +* [persistentVolumeClaim](https://kubernetes.io/docs/concepts/storage/volumes/#persistentvolumeclaim): used to mount a `PersistentVolume` into a pod. + +To mount a volume of any of the types above into the driver pod, use the following configuration property: + +``` +--conf spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.path= +--conf spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.readOnly= +``` + +Specifically, `VolumeType` can be one of the following values: `hostPath`, `emptyDir`, and `persistentVolumeClaim`. `VolumeName` is the name you want to use for the volume under the `volumes` field in the pod specification. + +Each supported type of volumes may have some specific configuration options, which can be specified using configuration properties of the following form: + +``` +spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].options.[OptionName]= +``` + +For example, the claim name of a `persistentVolumeClaim` with volume name `checkpointpvc` can be specified using the following property: + +``` +spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.options.claimName=check-point-pvc-claim +``` + +The configuration properties for mounting volumes into the executor pods use prefix `spark.kubernetes.executor.` instead of `spark.kubernetes.driver.`. For a complete list of available options for each supported type of volumes, please refer to the [Spark Properties](#spark-properties) section below. + ## Introspection and Debugging These are the different ways in which you can investigate a running/completed Spark application, monitor progress, and @@ -299,21 +329,15 @@ RBAC authorization and how to configure Kubernetes service accounts for pods, pl ## Future Work -There are several Spark on Kubernetes features that are currently being incubated in a fork - -[apache-spark-on-k8s/spark](https://github.com/apache-spark-on-k8s/spark), which are expected to eventually make it into -future versions of the spark-kubernetes integration. +There are several Spark on Kubernetes features that are currently being worked on or planned to be worked on. Those features are expected to eventually make it into future versions of the spark-kubernetes integration. Some of these include: -* R -* Dynamic Executor Scaling +* Dynamic Resource Allocation and External Shuffle Service * Local File Dependency Management * Spark Application Management * Job Queues and Resource Management -You can refer to the [documentation](https://apache-spark-on-k8s.github.io/userdocs/) if you want to try these features -and provide feedback to the development team. - # Configuration See the [configuration page](configuration.html) for information on Spark configurations. The following configurations are From 8198ea50192cad615071beb5510c73aa9e9178f4 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 28 Aug 2018 10:57:13 +0800 Subject: [PATCH 1490/2461] [SPARK-24721][SQL] Exclude Python UDFs filters in FileSourceStrategy ## What changes were proposed in this pull request? The PR excludes Python UDFs filters in FileSourceStrategy so that they don't ExtractPythonUDF rule to throw exception. It doesn't make sense to pass Python UDF filters in FileSourceStrategy anyway because they cannot be used as push down filters. ## How was this patch tested? Add a new regression test Closes #22104 from icexelloss/SPARK-24721-udf-filter. Authored-by: Li Jin Signed-off-by: Wenchen Fan --- python/pyspark/sql/tests.py | 94 +++++++++++++++++++ python/pyspark/sql/utils.py | 19 ++++ .../spark/sql/execution/QueryExecution.scala | 1 - .../spark/sql/execution/SparkOptimizer.scala | 5 +- .../spark/sql/execution/SparkPlanner.scala | 1 + .../spark/sql/execution/SparkStrategies.scala | 15 +++ .../python/ArrowEvalPythonExec.scala | 9 +- .../python/BatchEvalPythonExec.scala | 7 ++ .../execution/python/ExtractPythonUDFs.scala | 27 +++--- .../spark/sql/sources/TableScanSuite.scala | 2 + .../sql/sources/v2/DataSourceV2Suite.scala | 3 +- 11 files changed, 164 insertions(+), 19 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 00d7e18320a51..81c0af0b3d81b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -68,8 +68,16 @@ # If Arrow version requirement is not satisfied, skip related tests. _pyarrow_requirement_message = _exception_message(e) +_test_not_compiled_message = None +try: + from pyspark.sql.utils import require_test_compiled + require_test_compiled() +except Exception as e: + _test_not_compiled_message = _exception_message(e) + _have_pandas = _pandas_requirement_message is None _have_pyarrow = _pyarrow_requirement_message is None +_test_compiled = _test_not_compiled_message is None from pyspark import SparkContext from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row @@ -3367,6 +3375,47 @@ def test_ignore_column_of_all_nulls(self): finally: shutil.rmtree(path) + # SPARK-24721 + @unittest.skipIf(not _test_compiled, _test_not_compiled_message) + def test_datasource_with_udf(self): + from pyspark.sql.functions import udf, lit, col + + path = tempfile.mkdtemp() + shutil.rmtree(path) + + try: + self.spark.range(1).write.mode("overwrite").format('csv').save(path) + filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i') + datasource_df = self.spark.read \ + .format("org.apache.spark.sql.sources.SimpleScanSource") \ + .option('from', 0).option('to', 1).load().toDF('i') + datasource_v2_df = self.spark.read \ + .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ + .load().toDF('i', 'j') + + c1 = udf(lambda x: x + 1, 'int')(lit(1)) + c2 = udf(lambda x: x + 1, 'int')(col('i')) + + f1 = udf(lambda x: False, 'boolean')(lit(1)) + f2 = udf(lambda x: False, 'boolean')(col('i')) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + result = df.withColumn('c', c1) + expected = df.withColumn('c', lit(2)) + self.assertEquals(expected.collect(), result.collect()) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + result = df.withColumn('c', c2) + expected = df.withColumn('c', col('i') + 1) + self.assertEquals(expected.collect(), result.collect()) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + for f in [f1, f2]: + result = df.filter(f) + self.assertEquals(0, result.count()) + finally: + shutil.rmtree(path) + def test_repr_behaviors(self): import re pattern = re.compile(r'^ *\|', re.MULTILINE) @@ -5269,6 +5318,51 @@ def f3(x): self.assertEquals(expected.collect(), df1.collect()) + # SPARK-24721 + @unittest.skipIf(not _test_compiled, _test_not_compiled_message) + def test_datasource_with_udf(self): + # Same as SQLTests.test_datasource_with_udf, but with Pandas UDF + # This needs to a separate test because Arrow dependency is optional + import pandas as pd + import numpy as np + from pyspark.sql.functions import pandas_udf, lit, col + + path = tempfile.mkdtemp() + shutil.rmtree(path) + + try: + self.spark.range(1).write.mode("overwrite").format('csv').save(path) + filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i') + datasource_df = self.spark.read \ + .format("org.apache.spark.sql.sources.SimpleScanSource") \ + .option('from', 0).option('to', 1).load().toDF('i') + datasource_v2_df = self.spark.read \ + .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ + .load().toDF('i', 'j') + + c1 = pandas_udf(lambda x: x + 1, 'int')(lit(1)) + c2 = pandas_udf(lambda x: x + 1, 'int')(col('i')) + + f1 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1)) + f2 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(col('i')) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + result = df.withColumn('c', c1) + expected = df.withColumn('c', lit(2)) + self.assertEquals(expected.collect(), result.collect()) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + result = df.withColumn('c', c2) + expected = df.withColumn('c', col('i') + 1) + self.assertEquals(expected.collect(), result.collect()) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + for f in [f1, f2]: + result = df.filter(f) + self.assertEquals(0, result.count()) + finally: + shutil.rmtree(path) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index bb9ce02c4b60f..bdb3a1467f1d8 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -152,6 +152,25 @@ def require_minimum_pyarrow_version(): "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__)) +def require_test_compiled(): + """ Raise Exception if test classes are not compiled + """ + import os + import glob + try: + spark_home = os.environ['SPARK_HOME'] + except KeyError: + raise RuntimeError('SPARK_HOME is not defined in environment') + + test_class_path = os.path.join( + spark_home, 'sql', 'core', 'target', '*', 'test-classes') + paths = glob.glob(test_class_path) + + if len(paths) == 0: + raise RuntimeError( + "%s doesn't exist. Spark sql test classes are not compiled." % test_class_path) + + class ForeachBatchFunction(object): """ This is the Python implementation of Java interface 'ForeachBatchFunction'. This wraps diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 3112b306c365e..64f49e2d0d4e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -89,7 +89,6 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { /** A sequence of rules that will be applied in order to the physical plan before execution. */ protected def preparations: Seq[Rule[SparkPlan]] = Seq( - python.ExtractPythonUDFs, PlanSubqueries(sparkSession), EnsureRequirements(sparkSession.sessionState.conf), CollapseCodegenStages(sparkSession.sessionState.conf), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 969def7624058..6c6d344240cea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaPruning -import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate +import org.apache.spark.sql.execution.python.{ExtractPythonUDFFromAggregate, ExtractPythonUDFs} class SparkOptimizer( catalog: SessionCatalog, @@ -31,7 +31,8 @@ class SparkOptimizer( override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ - Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ + Batch("Extract Python UDFs", Once, + Seq(ExtractPythonUDFFromAggregate, ExtractPythonUDFs): _*) :+ Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ Batch("Parquet Schema Pruning", Once, ParquetSchemaPruning)) ++ postHocOptimizationBatches :+ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 75f5ec0e253df..2a4a1c8ef3438 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -36,6 +36,7 @@ class SparkPlanner( override def strategies: Seq[Strategy] = experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ ( + PythonEvals :: DataSourceV2Strategy :: FileSourceStrategy :: DataSourceStrategy(conf) :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4c39990acb627..dbc6db62bd820 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableS import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.execution.python._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2 import org.apache.spark.sql.internal.SQLConf @@ -517,6 +518,20 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Strategy to convert EvalPython logical operator to physical operator. + */ + object PythonEvals extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ArrowEvalPython(udfs, output, child) => + ArrowEvalPythonExec(udfs, output, planLater(child)) :: Nil + case BatchEvalPython(udfs, output, child) => + BatchEvalPythonExec(udfs, output, planLater(child)) :: Nil + case _ => + Nil + } + } + object BasicOperators extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 0bc21c0986e69..6a03f860f8f95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType @@ -57,7 +58,13 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int) } /** - * A physical plan that evaluates a [[PythonUDF]], + * A logical plan that evaluates a [[PythonUDF]]. + */ +case class ArrowEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan) + extends UnaryNode + +/** + * A physical plan that evaluates a [[PythonUDF]]. */ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) extends EvalPythonExec(udfs, output, child) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index f4d83e8dc7c2b..2054c700957e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -25,9 +25,16 @@ import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.types.{StructField, StructType} +/** + * A logical plan that evaluates a [[PythonUDF]] + */ +case class BatchEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan) + extends UnaryNode + /** * A physical plan that evaluates a [[PythonUDF]] */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index cb75874be32ec..90b5325919e96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -24,9 +24,8 @@ import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} /** @@ -93,7 +92,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { * This has the limitation that the input to the Python UDF is not allowed include attributes from * multiple child operators. */ -object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { +object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper { private type EvalType = Int private type EvalTypeChecker = EvalType => Boolean @@ -132,14 +131,14 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { expressions.flatMap(collectEvaluableUDFs) } - def apply(plan: SparkPlan): SparkPlan = plan transformUp { - case plan: SparkPlan => extract(plan) + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case plan: LogicalPlan => extract(plan) } /** * Extract all the PythonUDFs from the current operator and evaluate them before the operator. */ - private def extract(plan: SparkPlan): SparkPlan = { + private def extract(plan: LogicalPlan): LogicalPlan = { val udfs = collectEvaluableUDFsFromExpressions(plan.expressions) // ignore the PythonUDF that come from second/third aggregate, which is not used .filter(udf => udf.references.subsetOf(plan.inputSet)) @@ -151,7 +150,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { val prunedChildren = plan.children.map { child => val allNeededOutput = inputsForPlan.intersect(child.outputSet).toSeq if (allNeededOutput.length != child.output.length) { - ProjectExec(allNeededOutput, child) + Project(allNeededOutput, child) } else { child } @@ -180,9 +179,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { _.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF ) match { case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty => - ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) + ArrowEvalPython(vectorizedUdfs, child.output ++ resultAttrs, child) case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty => - BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child) + BatchEvalPython(plainUdfs, child.output ++ resultAttrs, child) case _ => throw new AnalysisException( "Expected either Scalar Pandas UDFs or Batched UDFs but got both") @@ -209,7 +208,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { val newPlan = extract(rewritten) if (newPlan.output != plan.output) { // Trim away the new UDF value if it was only used for filtering or something. - ProjectExec(plan.output, newPlan) + Project(plan.output, newPlan) } else { newPlan } @@ -218,15 +217,15 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { // Split the original FilterExec to two FilterExecs. Only push down the first few predicates // that are all deterministic. - private def trySplitFilter(plan: SparkPlan): SparkPlan = { + private def trySplitFilter(plan: LogicalPlan): LogicalPlan = { plan match { - case filter: FilterExec => + case filter: Filter => val (candidates, nonDeterministic) = splitConjunctivePredicates(filter.condition).partition(_.deterministic) val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_)) if (pushDown.nonEmpty) { - val newChild = FilterExec(pushDown.reduceLeft(And), filter.child) - FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild) + val newChild = Filter(pushDown.reduceLeft(And), filter.child) + Filter((rest ++ nonDeterministic).reduceLeft(And), newChild) } else { filter } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 17690e3df9155..13a126ff963d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.types._ class DefaultSource extends SimpleScanSource +// This class is used by pyspark tests. If this class is modified/moved, make sure pyspark +// tests still pass. class SimpleScanSource extends RelationProvider { override def createRelation( sqlContext: SQLContext, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 5edeff553eb16..f6c3e0ce82e3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -370,7 +370,8 @@ class SimpleSinglePartitionSource extends DataSourceV2 with BatchReadSupportProv } } - +// This class is used by pyspark tests. If this class is modified/moved, make sure pyspark +// tests still pass. class SimpleDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { class ReadSupport extends SimpleReadSupport { From 592e3a42c20b72edd6e8b9dd07da367596f43da5 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 28 Aug 2018 08:36:06 -0700 Subject: [PATCH 1491/2461] [SPARK-25218][CORE] Fix potential resource leaks in TransportServer and SocketAuthHelper ## What changes were proposed in this pull request? Make sure TransportServer and SocketAuthHelper close the resources for all types of errors. ## How was this patch tested? Jenkins Closes #22210 from zsxwing/SPARK-25218. Authored-by: Shixiong Zhu Signed-off-by: Shixiong Zhu --- .../buffer/FileSegmentManagedBuffer.java | 32 ++++++------ .../spark/network/server/TransportServer.java | 9 ++-- .../spark/security/SocketAuthHelper.scala | 50 ++++++++++++------- 3 files changed, 54 insertions(+), 37 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index 8b8f9892847c3..45fee541a4f5d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -77,16 +77,16 @@ public ByteBuffer nioByteBuffer() throws IOException { return channel.map(FileChannel.MapMode.READ_ONLY, offset, length); } } catch (IOException e) { + String errorMessage = "Error in reading " + this; try { if (channel != null) { long size = channel.size(); - throw new IOException("Error in reading " + this + " (actual file length " + size + ")", - e); + errorMessage = "Error in reading " + this + " (actual file length " + size + ")"; } } catch (IOException ignored) { // ignore } - throw new IOException("Error in opening " + this, e); + throw new IOException(errorMessage, e); } finally { JavaUtils.closeQuietly(channel); } @@ -95,26 +95,24 @@ public ByteBuffer nioByteBuffer() throws IOException { @Override public InputStream createInputStream() throws IOException { FileInputStream is = null; + boolean shouldClose = true; try { is = new FileInputStream(file); ByteStreams.skipFully(is, offset); - return new LimitedInputStream(is, length); + InputStream r = new LimitedInputStream(is, length); + shouldClose = false; + return r; } catch (IOException e) { - try { - if (is != null) { - long size = file.length(); - throw new IOException("Error in reading " + this + " (actual file length " + size + ")", - e); - } - } catch (IOException ignored) { - // ignore - } finally { + String errorMessage = "Error in reading " + this; + if (is != null) { + long size = file.length(); + errorMessage = "Error in reading " + this + " (actual file length " + size + ")"; + } + throw new IOException(errorMessage, e); + } finally { + if (shouldClose) { JavaUtils.closeQuietly(is); } - throw new IOException("Error in opening " + this, e); - } catch (RuntimeException e) { - JavaUtils.closeQuietly(is); - throw e; } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java index d95ed22912507..9c85ab2f5f06f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -70,11 +70,14 @@ public TransportServer( this.appRpcHandler = appRpcHandler; this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps)); + boolean shouldClose = true; try { init(hostToBind, portToBind); - } catch (RuntimeException e) { - JavaUtils.closeQuietly(this); - throw e; + shouldClose = false; + } finally { + if (shouldClose) { + JavaUtils.closeQuietly(this); + } } } diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala index d15e7937b0523..ea38ccb289c30 100644 --- a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala +++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala @@ -42,43 +42,59 @@ private[spark] class SocketAuthHelper(conf: SparkConf) { * Read the auth secret from the socket and compare to the expected value. Write the reply back * to the socket. * - * If authentication fails, this method will close the socket. + * If authentication fails or error is thrown, this method will close the socket. * * @param s The client socket. * @throws IllegalArgumentException If authentication fails. */ def authClient(s: Socket): Unit = { - // Set the socket timeout while checking the auth secret. Reset it before returning. - val currentTimeout = s.getSoTimeout() + var shouldClose = true try { - s.setSoTimeout(10000) - val clientSecret = readUtf8(s) - if (secret == clientSecret) { - writeUtf8("ok", s) - } else { - writeUtf8("err", s) - JavaUtils.closeQuietly(s) + // Set the socket timeout while checking the auth secret. Reset it before returning. + val currentTimeout = s.getSoTimeout() + try { + s.setSoTimeout(10000) + val clientSecret = readUtf8(s) + if (secret == clientSecret) { + writeUtf8("ok", s) + shouldClose = false + } else { + writeUtf8("err", s) + throw new IllegalArgumentException("Authentication failed.") + } + } finally { + s.setSoTimeout(currentTimeout) } } finally { - s.setSoTimeout(currentTimeout) + if (shouldClose) { + JavaUtils.closeQuietly(s) + } } } /** * Authenticate with a server by writing the auth secret and checking the server's reply. * - * If authentication fails, this method will close the socket. + * If authentication fails or error is thrown, this method will close the socket. * * @param s The socket connected to the server. * @throws IllegalArgumentException If authentication fails. */ def authToServer(s: Socket): Unit = { - writeUtf8(secret, s) + var shouldClose = true + try { + writeUtf8(secret, s) - val reply = readUtf8(s) - if (reply != "ok") { - JavaUtils.closeQuietly(s) - throw new IllegalArgumentException("Authentication failed.") + val reply = readUtf8(s) + if (reply != "ok") { + throw new IllegalArgumentException("Authentication failed.") + } else { + shouldClose = false + } + } finally { + if (shouldClose) { + JavaUtils.closeQuietly(s) + } } } From 1149c4efbc5ebe5b412d8f9c61558fef59179a9e Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 28 Aug 2018 08:38:07 -0700 Subject: [PATCH 1492/2461] [SPARK-25005][SS] Support non-consecutive offsets for Kafka ## What changes were proposed in this pull request? As the user uses Kafka transactions to write data, the offsets in Kafka will be non-consecutive. It will contains some transaction (commit or abort) markers. In addition, if the consumer's `isolation.level` is `read_committed`, `poll` will not return aborted messages either. Hence, we will see non-consecutive offsets in the date returned by `poll`. However, as `seekToEnd` may move the offset point to these missing offsets, there are 4 possible corner cases we need to support: - The whole batch contains no data messages - The first offset in a batch is not a committed data message - The last offset in a batch is not a committed data message - There is a gap in the middle of a batch They are all covered by the new unit tests. ## How was this patch tested? The new unit tests. Closes #22042 from zsxwing/kafka-transaction-read. Authored-by: Shixiong Zhu Signed-off-by: Shixiong Zhu --- .../kafka010/KafkaContinuousReadSupport.scala | 2 +- .../sql/kafka010/KafkaDataConsumer.scala | 273 +++++++++++++----- .../kafka010/KafkaContinuousSourceSuite.scala | 149 +++++++++- .../kafka010/KafkaMicroBatchSourceSuite.scala | 255 +++++++++++++++- .../sql/kafka010/KafkaRelationSuite.scala | 93 ++++++ .../spark/sql/kafka010/KafkaTestUtils.scala | 22 +- 6 files changed, 720 insertions(+), 74 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala index 4a18839e6a77a..1753a28fba2fb 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala @@ -227,7 +227,7 @@ class KafkaContinuousPartitionReader( // This is a failOnDataLoss exception. Retry if nextKafkaOffset is within the data range, // or if it's the endpoint of the data range (i.e. the "true" next offset). - case e: IllegalStateException if e.getCause.isInstanceOf[OffsetOutOfRangeException] => + case e: IllegalStateException if e.getCause.isInstanceOf[OffsetOutOfRangeException] => val range = consumer.getAvailableOffsetRange() if (range.latest >= nextKafkaOffset && range.earliest <= nextKafkaOffset) { // retry diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala index 65046c175a7e5..ceb9e318b283b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala @@ -33,9 +33,19 @@ import org.apache.spark.util.UninterruptibleThread private[kafka010] sealed trait KafkaDataConsumer { /** - * Get the record for the given offset if available. Otherwise it will either throw error - * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset), - * or null. + * Get the record for the given offset if available. + * + * If the record is invisible (either a + * transaction message, or an aborted message when the consumer's `isolation.level` is + * `read_committed`), it will be skipped and this method will try to fetch next available record + * within [offset, untilOffset). + * + * This method also will try its best to detect data loss. If `failOnDataLoss` is `true`, it will + * throw an exception when we detect an unavailable offset. If `failOnDataLoss` is `false`, this + * method will try to fetch next available record within [offset, untilOffset). + * + * When this method tries to skip offsets due to either invisible messages or data loss and + * reaches `untilOffset`, it will return `null`. * * @param offset the offset to fetch. * @param untilOffset the max offset to fetch. Exclusive. @@ -80,6 +90,83 @@ private[kafka010] case class InternalKafkaConsumer( kafkaParams: ju.Map[String, Object]) extends Logging { import InternalKafkaConsumer._ + /** + * The internal object to store the fetched data from Kafka consumer and the next offset to poll. + * + * @param _records the pre-fetched Kafka records. + * @param _nextOffsetInFetchedData the next offset in `records`. We use this to verify if we + * should check if the pre-fetched data is still valid. + * @param _offsetAfterPoll the Kafka offset after calling `poll`. We will use this offset to + * poll when `records` is drained. + */ + private case class FetchedData( + private var _records: ju.ListIterator[ConsumerRecord[Array[Byte], Array[Byte]]], + private var _nextOffsetInFetchedData: Long, + private var _offsetAfterPoll: Long) { + + def withNewPoll( + records: ju.ListIterator[ConsumerRecord[Array[Byte], Array[Byte]]], + offsetAfterPoll: Long): FetchedData = { + this._records = records + this._nextOffsetInFetchedData = UNKNOWN_OFFSET + this._offsetAfterPoll = offsetAfterPoll + this + } + + /** Whether there are more elements */ + def hasNext: Boolean = _records.hasNext + + /** Move `records` forward and return the next record. */ + def next(): ConsumerRecord[Array[Byte], Array[Byte]] = { + val record = _records.next() + _nextOffsetInFetchedData = record.offset + 1 + record + } + + /** Move `records` backward and return the previous record. */ + def previous(): ConsumerRecord[Array[Byte], Array[Byte]] = { + assert(_records.hasPrevious, "fetchedData cannot move back") + val record = _records.previous() + _nextOffsetInFetchedData = record.offset + record + } + + /** Reset the internal pre-fetched data. */ + def reset(): Unit = { + _records = ju.Collections.emptyListIterator() + } + + /** + * Returns the next offset in `records`. We use this to verify if we should check if the + * pre-fetched data is still valid. + */ + def nextOffsetInFetchedData: Long = _nextOffsetInFetchedData + + /** + * Returns the next offset to poll after draining the pre-fetched records. + */ + def offsetAfterPoll: Long = _offsetAfterPoll + } + + /** + * The internal object returned by the `fetchRecord` method. If `record` is empty, it means it is + * invisible (either a transaction message, or an aborted message when the consumer's + * `isolation.level` is `read_committed`), and the caller should use `nextOffsetToFetch` to fetch + * instead. + */ + private case class FetchedRecord( + var record: ConsumerRecord[Array[Byte], Array[Byte]], + var nextOffsetToFetch: Long) { + + def withRecord( + record: ConsumerRecord[Array[Byte], Array[Byte]], + nextOffsetToFetch: Long): FetchedRecord = { + this.record = record + this.nextOffsetToFetch = nextOffsetToFetch + this + } + } + private val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] @volatile private var consumer = createConsumer @@ -90,10 +177,21 @@ private[kafka010] case class InternalKafkaConsumer( /** indicate whether this consumer is going to be stopped in the next release */ @volatile var markedForClose = false - /** Iterator to the already fetch data */ - @volatile private var fetchedData = - ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] - @volatile private var nextOffsetInFetchedData = UNKNOWN_OFFSET + /** + * The fetched data returned from Kafka consumer. This is a reusable private object to avoid + * memory allocation. + */ + private val fetchedData = FetchedData( + ju.Collections.emptyListIterator[ConsumerRecord[Array[Byte], Array[Byte]]], + UNKNOWN_OFFSET, + UNKNOWN_OFFSET) + + /** + * The fetched record returned from the `fetchRecord` method. This is a reusable private object to + * avoid memory allocation. + */ + private val fetchedRecord: FetchedRecord = FetchedRecord(null, UNKNOWN_OFFSET) + /** Create a KafkaConsumer to fetch records for `topicPartition` */ private def createConsumer: KafkaConsumer[Array[Byte], Array[Byte]] = { @@ -125,20 +223,7 @@ private[kafka010] case class InternalKafkaConsumer( AvailableOffsetRange(earliestOffset, latestOffset) } - /** - * Get the record for the given offset if available. Otherwise it will either throw error - * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset), - * or null. - * - * @param offset the offset to fetch. - * @param untilOffset the max offset to fetch. Exclusive. - * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. - * @param failOnDataLoss When `failOnDataLoss` is `true`, this method will either return record at - * offset if available, or throw exception.when `failOnDataLoss` is `false`, - * this method will either return record at offset if available, or return - * the next earliest available record less than untilOffset, or null. It - * will not throw any exception. - */ + /** @see [[KafkaDataConsumer.get]] */ def get( offset: Long, untilOffset: Long, @@ -147,21 +232,32 @@ private[kafka010] case class InternalKafkaConsumer( ConsumerRecord[Array[Byte], Array[Byte]] = runUninterruptiblyIfPossible { require(offset < untilOffset, s"offset must always be less than untilOffset [offset: $offset, untilOffset: $untilOffset]") - logDebug(s"Get $groupId $topicPartition nextOffset $nextOffsetInFetchedData requested $offset") + logDebug(s"Get $groupId $topicPartition nextOffset ${fetchedData.nextOffsetInFetchedData} " + + s"requested $offset") // The following loop is basically for `failOnDataLoss = false`. When `failOnDataLoss` is // `false`, first, we will try to fetch the record at `offset`. If no such record exists, then // we will move to the next available offset within `[offset, untilOffset)` and retry. // If `failOnDataLoss` is `true`, the loop body will be executed only once. var toFetchOffset = offset - var consumerRecord: ConsumerRecord[Array[Byte], Array[Byte]] = null + var fetchedRecord: FetchedRecord = null // We want to break out of the while loop on a successful fetch to avoid using "return" // which may cause a NonLocalReturnControl exception when this method is used as a function. var isFetchComplete = false while (toFetchOffset != UNKNOWN_OFFSET && !isFetchComplete) { try { - consumerRecord = fetchData(toFetchOffset, untilOffset, pollTimeoutMs, failOnDataLoss) - isFetchComplete = true + fetchedRecord = fetchRecord(toFetchOffset, untilOffset, pollTimeoutMs, failOnDataLoss) + if (fetchedRecord.record != null) { + isFetchComplete = true + } else { + toFetchOffset = fetchedRecord.nextOffsetToFetch + if (toFetchOffset >= untilOffset) { + fetchedData.reset() + toFetchOffset = UNKNOWN_OFFSET + } else { + logDebug(s"Skipped offsets [$offset, $toFetchOffset]") + } + } } catch { case e: OffsetOutOfRangeException => // When there is some error thrown, it's better to use a new consumer to drop all cached @@ -174,9 +270,9 @@ private[kafka010] case class InternalKafkaConsumer( } if (isFetchComplete) { - consumerRecord + fetchedRecord.record } else { - resetFetchedData() + fetchedData.reset() null } } @@ -239,57 +335,73 @@ private[kafka010] case class InternalKafkaConsumer( } /** - * Get the record for the given offset if available. Otherwise it will either throw error - * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset), - * or null. + * Get the fetched record for the given offset if available. + * + * If the record is invisible (either a transaction message, or an aborted message when the + * consumer's `isolation.level` is `read_committed`), it will return a `FetchedRecord` with the + * next offset to fetch. + * + * This method also will try the best to detect data loss. If `failOnDataLoss` is true`, it will + * throw an exception when we detect an unavailable offset. If `failOnDataLoss` is `false`, this + * method will return `null` if the next available record is within [offset, untilOffset). * * @throws OffsetOutOfRangeException if `offset` is out of range * @throws TimeoutException if cannot fetch the record in `pollTimeoutMs` milliseconds. */ - private def fetchData( + private def fetchRecord( offset: Long, untilOffset: Long, pollTimeoutMs: Long, - failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = { - if (offset != nextOffsetInFetchedData || !fetchedData.hasNext()) { - // This is the first fetch, or the last pre-fetched data has been drained. - // Seek to the offset because we may call seekToBeginning or seekToEnd before this. - seek(offset) - poll(pollTimeoutMs) - } - - if (!fetchedData.hasNext()) { - // We cannot fetch anything after `poll`. Two possible cases: - // - `offset` is out of range so that Kafka returns nothing. Just throw - // `OffsetOutOfRangeException` to let the caller handle it. - // - Cannot fetch any data before timeout. TimeoutException will be thrown. - val range = getAvailableOffsetRange() - if (offset < range.earliest || offset >= range.latest) { - throw new OffsetOutOfRangeException( - Map(topicPartition -> java.lang.Long.valueOf(offset)).asJava) + failOnDataLoss: Boolean): FetchedRecord = { + if (offset != fetchedData.nextOffsetInFetchedData) { + // This is the first fetch, or the fetched data has been reset. + // Fetch records from Kafka and update `fetchedData`. + fetchData(offset, pollTimeoutMs) + } else if (!fetchedData.hasNext) { // The last pre-fetched data has been drained. + if (offset < fetchedData.offsetAfterPoll) { + // Offsets in [offset, fetchedData.offsetAfterPoll) are invisible. Return a record to ask + // the next call to start from `fetchedData.offsetAfterPoll`. + fetchedData.reset() + return fetchedRecord.withRecord(null, fetchedData.offsetAfterPoll) } else { - throw new TimeoutException( - s"Cannot fetch record for offset $offset in $pollTimeoutMs milliseconds") + // Fetch records from Kafka and update `fetchedData`. + fetchData(offset, pollTimeoutMs) } + } + + if (!fetchedData.hasNext) { + // When we reach here, we have already tried to poll from Kafka. As `fetchedData` is still + // empty, all messages in [offset, fetchedData.offsetAfterPoll) are invisible. Return a + // record to ask the next call to start from `fetchedData.offsetAfterPoll`. + assert(offset <= fetchedData.offsetAfterPoll, + s"seek to $offset and poll but the offset was reset to ${fetchedData.offsetAfterPoll}") + fetchedRecord.withRecord(null, fetchedData.offsetAfterPoll) } else { val record = fetchedData.next() - nextOffsetInFetchedData = record.offset + 1 // In general, Kafka uses the specified offset as the start point, and tries to fetch the next // available offset. Hence we need to handle offset mismatch. if (record.offset > offset) { + val range = getAvailableOffsetRange() + if (range.earliest <= offset) { + // `offset` is still valid but the corresponding message is invisible. We should skip it + // and jump to `record.offset`. Here we move `fetchedData` back so that the next call of + // `fetchRecord` can just return `record` directly. + fetchedData.previous() + return fetchedRecord.withRecord(null, record.offset) + } // This may happen when some records aged out but their offsets already got verified if (failOnDataLoss) { reportDataLoss(true, s"Cannot fetch records in [$offset, ${record.offset})") // Never happen as "reportDataLoss" will throw an exception - null + throw new IllegalStateException( + "reportDataLoss didn't throw an exception when 'failOnDataLoss' is true") + } else if (record.offset >= untilOffset) { + reportDataLoss(false, s"Skip missing records in [$offset, $untilOffset)") + // Set `nextOffsetToFetch` to `untilOffset` to finish the current batch. + fetchedRecord.withRecord(null, untilOffset) } else { - if (record.offset >= untilOffset) { - reportDataLoss(false, s"Skip missing records in [$offset, $untilOffset)") - null - } else { - reportDataLoss(false, s"Skip missing records in [$offset, ${record.offset})") - record - } + reportDataLoss(false, s"Skip missing records in [$offset, ${record.offset})") + fetchedRecord.withRecord(record, fetchedData.nextOffsetInFetchedData) } } else if (record.offset < offset) { // This should not happen. If it does happen, then we probably misunderstand Kafka internal @@ -297,7 +409,7 @@ private[kafka010] case class InternalKafkaConsumer( throw new IllegalStateException( s"Tried to fetch $offset but the returned record offset was ${record.offset}") } else { - record + fetchedRecord.withRecord(record, fetchedData.nextOffsetInFetchedData) } } } @@ -306,13 +418,7 @@ private[kafka010] case class InternalKafkaConsumer( private def resetConsumer(): Unit = { consumer.close() consumer = createConsumer - resetFetchedData() - } - - /** Reset the internal pre-fetched data. */ - private def resetFetchedData(): Unit = { - nextOffsetInFetchedData = UNKNOWN_OFFSET - fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] + fetchedData.reset() } /** @@ -346,11 +452,40 @@ private[kafka010] case class InternalKafkaConsumer( consumer.seek(topicPartition, offset) } - private def poll(pollTimeoutMs: Long): Unit = { + /** + * Poll messages from Kafka starting from `offset` and update `fetchedData`. `fetchedData` may be + * empty if the Kafka consumer fetches some messages but all of them are not visible messages + * (either transaction messages, or aborted messages when `isolation.level` is `read_committed`). + * + * @throws OffsetOutOfRangeException if `offset` is out of range. + * @throws TimeoutException if the consumer position is not changed after polling. It means the + * consumer polls nothing before timeout. + */ + private def fetchData(offset: Long, pollTimeoutMs: Long): Unit = { + // Seek to the offset because we may call seekToBeginning or seekToEnd before this. + seek(offset) val p = consumer.poll(pollTimeoutMs) val r = p.records(topicPartition) logDebug(s"Polled $groupId ${p.partitions()} ${r.size}") - fetchedData = r.iterator + val offsetAfterPoll = consumer.position(topicPartition) + logDebug(s"Offset changed from $offset to $offsetAfterPoll after polling") + fetchedData.withNewPoll(r.listIterator, offsetAfterPoll) + if (!fetchedData.hasNext) { + // We cannot fetch anything after `poll`. Two possible cases: + // - `offset` is out of range so that Kafka returns nothing. `OffsetOutOfRangeException` will + // be thrown. + // - Cannot fetch any data before timeout. `TimeoutException` will be thrown. + // - Fetched something but all of them are not invisible. This is a valid case and let the + // caller handles this. + val range = getAvailableOffsetRange() + if (offset < range.earliest || offset >= range.latest) { + throw new OffsetOutOfRangeException( + Map(topicPartition -> java.lang.Long.valueOf(offset)).asJava) + } else if (offset == offsetAfterPoll) { + throw new TimeoutException( + s"Cannot fetch record for offset $offset in $pollTimeoutMs milliseconds") + } + } } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index 5d68a14326c00..af510219a6f6f 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -17,12 +17,159 @@ package org.apache.spark.sql.kafka010 +import org.apache.kafka.clients.producer.ProducerRecord + import org.apache.spark.sql.Dataset import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.streaming.Trigger // Run tests in KafkaSourceSuiteBase in continuous execution mode. -class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest +class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest { + import testImplicits._ + + test("read Kafka transactional messages: read_committed") { + val table = "kafka_continuous_source_test" + withTable(table) { + val topic = newTopic() + testUtils.createTopic(topic) + testUtils.withTranscationalProducer { producer => + val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.isolation.level", "read_committed") + .option("startingOffsets", "earliest") + .option("subscribe", topic) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + .map(kv => kv._2.toInt) + + val q = df + .writeStream + .format("memory") + .queryName(table) + .trigger(ContinuousTrigger(100)) + .start() + try { + producer.beginTransaction() + (1 to 5).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + + // Should not read any messages before they are committed + assert(spark.table(table).isEmpty) + + producer.commitTransaction() + + eventually(timeout(streamingTimeout)) { + // Should read all committed messages + checkAnswer(spark.table(table), (1 to 5).toDF) + } + + producer.beginTransaction() + (6 to 10).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.abortTransaction() + + // Should not read aborted messages + checkAnswer(spark.table(table), (1 to 5).toDF) + + producer.beginTransaction() + (11 to 15).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.commitTransaction() + + eventually(timeout(streamingTimeout)) { + // Should skip aborted messages and read new committed ones. + checkAnswer(spark.table(table), ((1 to 5) ++ (11 to 15)).toDF) + } + } finally { + q.stop() + } + } + } + } + + test("read Kafka transactional messages: read_uncommitted") { + val table = "kafka_continuous_source_test" + withTable(table) { + val topic = newTopic() + testUtils.createTopic(topic) + testUtils.withTranscationalProducer { producer => + val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.isolation.level", "read_uncommitted") + .option("startingOffsets", "earliest") + .option("subscribe", topic) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + .map(kv => kv._2.toInt) + + val q = df + .writeStream + .format("memory") + .queryName(table) + .trigger(ContinuousTrigger(100)) + .start() + try { + producer.beginTransaction() + (1 to 5).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + + eventually(timeout(streamingTimeout)) { + // Should read uncommitted messages + checkAnswer(spark.table(table), (1 to 5).toDF) + } + + producer.commitTransaction() + + eventually(timeout(streamingTimeout)) { + // Should read all committed messages + checkAnswer(spark.table(table), (1 to 5).toDF) + } + + producer.beginTransaction() + (6 to 10).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.abortTransaction() + + eventually(timeout(streamingTimeout)) { + // Should read aborted messages + checkAnswer(spark.table(table), (1 to 10).toDF) + } + + producer.beginTransaction() + (11 to 15).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + + eventually(timeout(streamingTimeout)) { + // Should read all messages including committed, aborted and uncommitted messages + checkAnswer(spark.table(table), (1 to 15).toDF) + } + + producer.commitTransaction() + + eventually(timeout(streamingTimeout)) { + // Should read all messages including committed and aborted messages + checkAnswer(spark.table(table), (1 to 15).toDF) + } + } finally { + q.stop() + } + } + } + } +} class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { import testImplicits._ diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 1d1455009251c..eb66ccac744a3 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -28,7 +28,7 @@ import scala.collection.JavaConverters._ import scala.io.Source import scala.util.Random -import org.apache.kafka.clients.producer.RecordMetadata +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord, RecordMetadata} import org.apache.kafka.common.TopicPartition import org.json4s.DefaultFormats import org.json4s.jackson.JsonMethods._ @@ -159,6 +159,19 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with Kaf s"AddKafkaData(topics = $topics, data = $data, message = $message)" } + object WithOffsetSync { + def apply(topic: String)(func: () => Unit): StreamAction = { + Execute("Run Kafka Producer")(_ => { + func() + // This is a hack for the race condition that the committed message may be not visible to + // consumer for a short time. + // Looks like after the following call returns, the consumer can always read the committed + // messages. + testUtils.getLatestOffsets(Set(topic)) + }) + } + } + private val topicId = new AtomicInteger(0) protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}" } @@ -596,6 +609,246 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { } ) } + + test("read Kafka transactional messages: read_committed") { + // This test will cover the following cases: + // 1. the whole batch contains no data messages + // 2. the first offset in a batch is not a committed data message + // 3. the last offset in a batch is not a committed data message + // 4. there is a gap in the middle of a batch + + val topic = newTopic() + testUtils.createTopic(topic, partitions = 1) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("kafka.isolation.level", "read_committed") + .option("maxOffsetsPerTrigger", 3) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + // Set a short timeout to make the test fast. When a batch doesn't contain any visible data + // messages, "poll" will wait until timeout. + .option("kafkaConsumer.pollTimeoutMs", 5000) + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + val clock = new StreamManualClock + + // Wait until the manual clock is waiting on further instructions to move forward. Then we can + // ensure all batches we are waiting for have been processed. + val waitUntilBatchProcessed = Execute { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + } + + // The message values are the same as their offsets to make the test easy to follow + testUtils.withTranscationalProducer { producer => + testStream(mapped)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + CheckAnswer(), + WithOffsetSync(topic) { () => + // Send 5 messages. They should be visible only after being committed. + producer.beginTransaction() + (0 to 4).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + }, + AdvanceManualClock(100), + waitUntilBatchProcessed, + // Should not see any uncommitted messages + CheckNewAnswer(), + WithOffsetSync(topic) { () => + producer.commitTransaction() + }, + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(0, 1, 2), // offset 0, 1, 2 + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(3, 4), // offset: 3, 4, 5* [* means it's not a committed data message] + WithOffsetSync(topic) { () => + // Send 5 messages and abort the transaction. They should not be read. + producer.beginTransaction() + (6 to 10).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.abortTransaction() + }, + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(), // offset: 6*, 7*, 8* + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(), // offset: 9*, 10*, 11* + WithOffsetSync(topic) { () => + // Send 5 messages again. The consumer should skip the above aborted messages and read + // them. + producer.beginTransaction() + (12 to 16).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.commitTransaction() + }, + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(12, 13, 14), // offset: 12, 13, 14 + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(15, 16), // offset: 15, 16, 17* + WithOffsetSync(topic) { () => + producer.beginTransaction() + producer.send(new ProducerRecord[String, String](topic, "18")).get() + producer.commitTransaction() + producer.beginTransaction() + producer.send(new ProducerRecord[String, String](topic, "20")).get() + producer.commitTransaction() + producer.beginTransaction() + producer.send(new ProducerRecord[String, String](topic, "22")).get() + producer.send(new ProducerRecord[String, String](topic, "23")).get() + producer.commitTransaction() + }, + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(18, 20), // offset: 18, 19*, 20 + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(22, 23), // offset: 21*, 22, 23 + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer() // offset: 24* + ) + } + } + + test("read Kafka transactional messages: read_uncommitted") { + // This test will cover the following cases: + // 1. the whole batch contains no data messages + // 2. the first offset in a batch is not a committed data message + // 3. the last offset in a batch is not a committed data message + // 4. there is a gap in the middle of a batch + + val topic = newTopic() + testUtils.createTopic(topic, partitions = 1) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("kafka.isolation.level", "read_uncommitted") + .option("maxOffsetsPerTrigger", 3) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + // Set a short timeout to make the test fast. When a batch doesn't contain any visible data + // messages, "poll" will wait until timeout. + .option("kafkaConsumer.pollTimeoutMs", 5000) + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + val clock = new StreamManualClock + + // Wait until the manual clock is waiting on further instructions to move forward. Then we can + // ensure all batches we are waiting for have been processed. + val waitUntilBatchProcessed = Execute { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + } + + // The message values are the same as their offsets to make the test easy to follow + testUtils.withTranscationalProducer { producer => + testStream(mapped)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + CheckNewAnswer(), + WithOffsetSync(topic) { () => + // Send 5 messages. They should be visible only after being committed. + producer.beginTransaction() + (0 to 4).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + }, + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(0, 1, 2), // offset 0, 1, 2 + WithOffsetSync(topic) { () => + producer.commitTransaction() + }, + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(3, 4), // offset: 3, 4, 5* [* means it's not a committed data message] + WithOffsetSync(topic) { () => + // Send 5 messages and abort the transaction. They should not be read. + producer.beginTransaction() + (6 to 10).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.abortTransaction() + }, + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(6, 7, 8), // offset: 6, 7, 8 + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(9, 10), // offset: 9, 10, 11* + WithOffsetSync(topic) { () => + // Send 5 messages again. The consumer should skip the above aborted messages and read + // them. + producer.beginTransaction() + (12 to 16).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.commitTransaction() + }, + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(12, 13, 14), // offset: 12, 13, 14 + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(15, 16), // offset: 15, 16, 17* + WithOffsetSync(topic) { () => + producer.beginTransaction() + producer.send(new ProducerRecord[String, String](topic, "18")).get() + producer.commitTransaction() + producer.beginTransaction() + producer.send(new ProducerRecord[String, String](topic, "20")).get() + producer.commitTransaction() + producer.beginTransaction() + producer.send(new ProducerRecord[String, String](topic, "22")).get() + producer.send(new ProducerRecord[String, String](topic, "23")).get() + producer.commitTransaction() + }, + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(18, 20), // offset: 18, 19*, 20 + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer(22, 23), // offset: 21*, 22, 23 + AdvanceManualClock(100), + waitUntilBatchProcessed, + CheckNewAnswer() // offset: 24* + ) + } + } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala index 688e9c40fed22..93dba18446280 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.kafka010 import java.util.Locale import java.util.concurrent.atomic.AtomicInteger +import org.apache.kafka.clients.producer.ProducerRecord import org.apache.kafka.common.TopicPartition import org.apache.spark.sql.QueryTest @@ -234,4 +235,96 @@ class KafkaRelationSuite extends QueryTest with SharedSQLContext with KafkaTest testBadOptions("subscribe" -> "")("no topics to subscribe") testBadOptions("subscribePattern" -> "")("pattern to subscribe is empty") } + + test("read Kafka transactional messages: read_committed") { + val topic = newTopic() + testUtils.createTopic(topic) + testUtils.withTranscationalProducer { producer => + val df = spark + .read + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.isolation.level", "read_committed") + .option("subscribe", topic) + .load() + .selectExpr("CAST(value AS STRING)") + + producer.beginTransaction() + (1 to 5).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + + // Should not read any messages before they are committed + assert(df.isEmpty) + + producer.commitTransaction() + + // Should read all committed messages + checkAnswer(df, (1 to 5).map(_.toString).toDF) + + producer.beginTransaction() + (6 to 10).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.abortTransaction() + + // Should not read aborted messages + checkAnswer(df, (1 to 5).map(_.toString).toDF) + + producer.beginTransaction() + (11 to 15).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.commitTransaction() + + // Should skip aborted messages and read new committed ones. + checkAnswer(df, ((1 to 5) ++ (11 to 15)).map(_.toString).toDF) + } + } + + test("read Kafka transactional messages: read_uncommitted") { + val topic = newTopic() + testUtils.createTopic(topic) + testUtils.withTranscationalProducer { producer => + val df = spark + .read + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.isolation.level", "read_uncommitted") + .option("subscribe", topic) + .load() + .selectExpr("CAST(value AS STRING)") + + producer.beginTransaction() + (1 to 5).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + + // "read_uncommitted" should see all messages including uncommitted ones + checkAnswer(df, (1 to 5).map(_.toString).toDF) + + producer.commitTransaction() + + // Should read all committed messages + checkAnswer(df, (1 to 5).map(_.toString).toDF) + + producer.beginTransaction() + (6 to 10).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.abortTransaction() + + // "read_uncommitted" should see all messages including uncommitted or aborted ones + checkAnswer(df, (1 to 10).map(_.toString).toDF) + + producer.beginTransaction() + (11 to 15).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.commitTransaction() + + // Should read all messages + checkAnswer(df, (1 to 15).map(_.toString).toDF) + } + } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index 55d61ef20ca8a..7b742a3ea6741 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010 import java.io.{File, IOException} import java.lang.{Integer => JInt} import java.net.InetSocketAddress -import java.util.{Map => JMap, Properties} +import java.util.{Map => JMap, Properties, UUID} import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ @@ -323,9 +323,14 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L props.put("log.flush.interval.messages", "1") props.put("replica.socket.timeout.ms", "1500") props.put("delete.topic.enable", "true") + props.put("group.initial.rebalance.delay.ms", "10") + + // Change the following settings as we have only 1 broker props.put("offsets.topic.num.partitions", "1") props.put("offsets.topic.replication.factor", "1") - props.put("group.initial.rebalance.delay.ms", "10") + props.put("transaction.state.log.replication.factor", "1") + props.put("transaction.state.log.min.isr", "1") + // Can not use properties.putAll(propsMap.asJava) in scala-2.12 // See https://github.com/scala/bug/issues/10418 withBrokerProps.foreach { case (k, v) => props.put(k, v) } @@ -342,6 +347,19 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L props } + /** Call `f` with a `KafkaProducer` that has initialized transactions. */ + def withTranscationalProducer(f: KafkaProducer[String, String] => Unit): Unit = { + val props = producerConfiguration + props.put("transactional.id", UUID.randomUUID().toString) + val producer = new KafkaProducer[String, String](props) + try { + producer.initTransactions() + f(producer) + } finally { + producer.close() + } + } + private def consumerConfiguration: Properties = { val props = new Properties() props.put("bootstrap.servers", brokerAddress) From de46df549acee7fda56bb0871f444d2f3b49e582 Mon Sep 17 00:00:00 2001 From: Fernando Pereira Date: Tue, 28 Aug 2018 10:31:47 -0700 Subject: [PATCH 1493/2461] [SPARK-23997][SQL] Configurable maximum number of buckets ## What changes were proposed in this pull request? This PR implements the possibility of the user to override the maximum number of buckets when saving to a table. Currently the limit is a hard-coded 100k, which might be insufficient for large workloads. A new configuration entry is proposed: `spark.sql.bucketing.maxBuckets`, which defaults to the previous 100k. ## How was this patch tested? Added unit tests in the following spark.sql test suites: - CreateTableAsSelectSuite - BucketedWriteSuite Author: Fernando Pereira Closes #21087 from ferdonline/enh/configurable_bucket_limit. --- .../sql/catalyst/catalog/interface.scala | 8 +++-- .../apache/spark/sql/internal/SQLConf.scala | 8 +++++ .../sql/sources/BucketedWriteSuite.scala | 33 ++++++++++++++--- .../sources/CreateTableAsSelectSuite.scala | 35 +++++++++++++++++-- 4 files changed, 76 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index a4ead538bb51a..3842d794ba5ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -173,9 +174,12 @@ case class BucketSpec( numBuckets: Int, bucketColumnNames: Seq[String], sortColumnNames: Seq[String]) { - if (numBuckets <= 0 || numBuckets >= 100000) { + def conf: SQLConf = SQLConf.get + + if (numBuckets <= 0 || numBuckets > conf.bucketingMaxBuckets) { throw new AnalysisException( - s"Number of buckets should be greater than 0 but less than 100000. Got `$numBuckets`") + s"Number of buckets should be greater than 0 but less than bucketing.maxBuckets " + + s"(`${conf.bucketingMaxBuckets}`). Got `$numBuckets`") } override def toString: String = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6336e89671937..738d8fee891d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -674,6 +674,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets") + .doc("The maximum number of buckets allowed. Defaults to 100000") + .intConf + .checkValue(_ > 0, "the value of spark.sql.sources.bucketing.maxBuckets must be larger than 0") + .createWithDefault(100000) + val CROSS_JOINS_ENABLED = buildConf("spark.sql.crossJoin.enabled") .doc("When false, we will throw an error if a query contains a cartesian product without " + "explicit CROSS JOIN syntax.") @@ -1803,6 +1809,8 @@ class SQLConf extends Serializable with Logging { def bucketingEnabled: Boolean = getConf(SQLConf.BUCKETING_ENABLED) + def bucketingMaxBuckets: Int = getConf(SQLConf.BUCKETING_MAX_BUCKETS) + def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 5ff1ea84d9a7b..fc61050dc7458 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.sources import java.io.File -import java.net.URI import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.datasources.BucketingUtils @@ -48,16 +49,40 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { intercept[AnalysisException](df.write.bucketBy(2, "k").saveAsTable("tt")) } - test("numBuckets be greater than 0 but less than 100000") { + test("numBuckets be greater than 0 but less/eq than default bucketing.maxBuckets (100000)") { val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") - Seq(-1, 0, 100000).foreach(numBuckets => { + Seq(-1, 0, 100001).foreach(numBuckets => { val e = intercept[AnalysisException](df.write.bucketBy(numBuckets, "i").saveAsTable("tt")) assert( - e.getMessage.contains("Number of buckets should be greater than 0 but less than 100000")) + e.getMessage.contains("Number of buckets should be greater than 0 but less than")) }) } + test("numBuckets be greater than 0 but less/eq than overridden bucketing.maxBuckets (200000)") { + val maxNrBuckets: Int = 200000 + val catalog = spark.sessionState.catalog + + withSQLConf("spark.sql.sources.bucketing.maxBuckets" -> maxNrBuckets.toString) { + // within the new limit + Seq(100001, maxNrBuckets).foreach(numBuckets => { + withTable("t") { + df.write.bucketBy(numBuckets, "i").saveAsTable("t") + val table = catalog.getTableMetadata(TableIdentifier("t")) + assert(table.bucketSpec == Option(BucketSpec(numBuckets, Seq("i"), Seq()))) + } + }) + + // over the new limit + withTable("t") { + val e = intercept[AnalysisException]( + df.write.bucketBy(maxNrBuckets + 1, "i").saveAsTable("t")) + assert( + e.getMessage.contains("Number of buckets should be greater than 0 but less than")) + } + } + } + test("specify sorting columns without bucketing columns") { val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") val e = intercept[AnalysisException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 916a01ee0ca8e..d46029e84433c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -225,7 +225,7 @@ class CreateTableAsSelectSuite test("create table using as select - with invalid number of buckets") { withTable("t") { - Seq(0, 100000).foreach(numBuckets => { + Seq(0, 100001).foreach(numBuckets => { val e = intercept[AnalysisException] { sql( s""" @@ -236,11 +236,42 @@ class CreateTableAsSelectSuite """.stripMargin ) }.getMessage - assert(e.contains("Number of buckets should be greater than 0 but less than 100000")) + assert(e.contains("Number of buckets should be greater than 0 but less than")) }) } } + test("create table using as select - with overriden max number of buckets") { + def createTableSql(numBuckets: Int): String = + s""" + |CREATE TABLE t USING PARQUET + |OPTIONS (PATH '${path.toURI}') + |CLUSTERED BY (a) SORTED BY (b) INTO $numBuckets BUCKETS + |AS SELECT 1 AS a, 2 AS b + """.stripMargin + + val maxNrBuckets: Int = 200000 + val catalog = spark.sessionState.catalog + withSQLConf("spark.sql.sources.bucketing.maxBuckets" -> maxNrBuckets.toString) { + + // Within the new limit + Seq(100001, maxNrBuckets).foreach(numBuckets => { + withTable("t") { + sql(createTableSql(numBuckets)) + val table = catalog.getTableMetadata(TableIdentifier("t")) + assert(table.bucketSpec == Option(BucketSpec(numBuckets, Seq("a"), Seq("b")))) + } + }) + + // Over the new limit + withTable("t") { + val e = intercept[AnalysisException](sql(createTableSql(maxNrBuckets + 1))) + assert( + e.getMessage.contains("Number of buckets should be greater than 0 but less than ")) + } + } + } + test("SPARK-17409: CTAS of decimal calculation") { withTable("tab2") { withTempView("tab1") { From 4e3f3cebe4cc6f47c264821a5ea92c32a4f1daa5 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 28 Aug 2018 10:33:39 -0700 Subject: [PATCH 1494/2461] [SPARK-23679][YARN] Setting RM_HA_URLS for AmIpFilter to avoid redirect failure in YARN mode ## What changes were proposed in this pull request? YARN `AmIpFilter` adds a new parameter "RM_HA_URLS" to support RM HA, but Spark on YARN doesn't provide a such parameter, so it will be failed to redirect when running on RM HA. The detailed exception can be checked from JIRA. So here fixing this issue by adding "RM_HA_URLS" parameter. ## How was this patch tested? Local verification. Closes #22164 from jerryshao/SPARK-23679. Authored-by: jerryshao Signed-off-by: Marcelo Vanzin --- .../spark/deploy/yarn/YarnRMClient.scala | 29 ++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index b59dcf158d87c..05a7b1e1310c4 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy.yarn import scala.collection.JavaConverters._ +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest @@ -112,7 +113,16 @@ private[spark] class YarnRMClient extends Logging { val proxies = WebAppUtils.getProxyHostsAndPortsForAmFilter(conf) val hosts = proxies.asScala.map(_.split(":").head) val uriBases = proxies.asScala.map { proxy => prefix + proxy + proxyBase } - Map("PROXY_HOSTS" -> hosts.mkString(","), "PROXY_URI_BASES" -> uriBases.mkString(",")) + val params = + Map("PROXY_HOSTS" -> hosts.mkString(","), "PROXY_URI_BASES" -> uriBases.mkString(",")) + + // Handles RM HA urls + val rmIds = conf.getStringCollection(YarnConfiguration.RM_HA_IDS).asScala + if (rmIds != null && rmIds.nonEmpty) { + params + ("RM_HA_URLS" -> rmIds.map(getUrlByRmId(conf, _)).mkString(",")) + } else { + params + } } /** Returns the maximum number of attempts to register the AM. */ @@ -126,4 +136,21 @@ private[spark] class YarnRMClient extends Logging { } } + private def getUrlByRmId(conf: Configuration, rmId: String): String = { + val addressPropertyPrefix = if (YarnConfiguration.useHttps(conf)) { + YarnConfiguration.RM_WEBAPP_HTTPS_ADDRESS + } else { + YarnConfiguration.RM_WEBAPP_ADDRESS + } + + val addressWithRmId = if (rmId == null || rmId.isEmpty) { + addressPropertyPrefix + } else if (rmId.startsWith(".")) { + throw new IllegalStateException(s"rmId $rmId should not already have '.' prepended.") + } else { + s"$addressPropertyPrefix.$rmId" + } + + conf.get(addressWithRmId) + } } From aff8f15c153f8031ceaffa237c60e040c6f8115f Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 28 Aug 2018 11:29:05 -0700 Subject: [PATCH 1495/2461] [SPARK-25240][SQL] Fix for a deadlock in RECOVER PARTITIONS ## What changes were proposed in this pull request? In the PR, I propose to not perform recursive parallel listening of files in the `scanPartitions` method because it can cause a deadlock. Instead of that I propose to do `scanPartitions` in parallel for top level partitions only. ## How was this patch tested? I extended an existing test to trigger the deadlock. Author: Maxim Gekk Closes #22233 from MaxGekk/fix-recover-partitions. --- .../spark/sql/execution/command/ddl.scala | 34 +++++------ .../sql/execution/command/DDLSuite.scala | 59 +++++++++++-------- .../sql/hive/execution/HiveDDLSuite.scala | 15 ++--- 3 files changed, 61 insertions(+), 47 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 7a6f5741862ca..e1faecedd20ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command import java.util.Locale import scala.collection.{GenMap, GenSeq} -import scala.concurrent.ExecutionContext +import scala.collection.parallel.ForkJoinTaskSupport import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration @@ -29,7 +29,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Resolver} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} @@ -40,7 +40,6 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.types._ import org.apache.spark.util.{SerializableConfiguration, ThreadUtils} -import org.apache.spark.util.ThreadUtils.parmap // Note: The definition of these commands are based on the ones described in // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL @@ -622,9 +621,8 @@ case class AlterTableRecoverPartitionsCommand( val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8) val partitionSpecsAndLocs: Seq[(TablePartitionSpec, Path)] = try { - implicit val ec = ExecutionContext.fromExecutor(evalPool) scanPartitions(spark, fs, pathFilter, root, Map(), table.partitionColumnNames, threshold, - spark.sessionState.conf.resolver) + spark.sessionState.conf.resolver, new ForkJoinTaskSupport(evalPool)).seq } finally { evalPool.shutdown() } @@ -656,13 +654,23 @@ case class AlterTableRecoverPartitionsCommand( spec: TablePartitionSpec, partitionNames: Seq[String], threshold: Int, - resolver: Resolver)(implicit ec: ExecutionContext): Seq[(TablePartitionSpec, Path)] = { + resolver: Resolver, + evalTaskSupport: ForkJoinTaskSupport): GenSeq[(TablePartitionSpec, Path)] = { if (partitionNames.isEmpty) { return Seq(spec -> path) } - val statuses = fs.listStatus(path, filter).toSeq - def handleStatus(st: FileStatus): Seq[(TablePartitionSpec, Path)] = { + val statuses = fs.listStatus(path, filter) + val statusPar: GenSeq[FileStatus] = + if (partitionNames.length > 1 && statuses.length > threshold || partitionNames.length > 2) { + // parallelize the list of partitions here, then we can have better parallelism later. + val parArray = statuses.par + parArray.tasksupport = evalTaskSupport + parArray + } else { + statuses + } + statusPar.flatMap { st => val name = st.getPath.getName if (st.isDirectory && name.contains("=")) { val ps = name.split("=", 2) @@ -671,7 +679,7 @@ case class AlterTableRecoverPartitionsCommand( val value = ExternalCatalogUtils.unescapePathName(ps(1)) if (resolver(columnName, partitionNames.head)) { scanPartitions(spark, fs, filter, st.getPath, spec ++ Map(partitionNames.head -> value), - partitionNames.drop(1), threshold, resolver) + partitionNames.drop(1), threshold, resolver, evalTaskSupport) } else { logWarning( s"expected partition column ${partitionNames.head}, but got ${ps(0)}, ignoring it") @@ -682,14 +690,6 @@ case class AlterTableRecoverPartitionsCommand( Seq.empty } } - val result = if (partitionNames.length > 1 && - statuses.length > threshold || partitionNames.length > 2) { - parmap(statuses)(handleStatus _) - } else { - statuses.map(handleStatus) - } - - result.flatten } private def gatherPartitionStats( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 78df1db93692b..f8d98dead2d42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -52,23 +52,24 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo protected override def generateTable( catalog: SessionCatalog, name: TableIdentifier, - isDataSource: Boolean = true): CatalogTable = { + isDataSource: Boolean = true, + partitionCols: Seq[String] = Seq("a", "b")): CatalogTable = { val storage = CatalogStorageFormat.empty.copy(locationUri = Some(catalog.defaultTablePath(name))) val metadata = new MetadataBuilder() .putString("key", "value") .build() + val schema = new StructType() + .add("col1", "int", nullable = true, metadata = metadata) + .add("col2", "string") CatalogTable( identifier = name, tableType = CatalogTableType.EXTERNAL, storage = storage, - schema = new StructType() - .add("col1", "int", nullable = true, metadata = metadata) - .add("col2", "string") - .add("a", "int") - .add("b", "int"), + schema = schema.copy( + fields = schema.fields ++ partitionCols.map(StructField(_, IntegerType))), provider = Some("parquet"), - partitionColumnNames = Seq("a", "b"), + partitionColumnNames = partitionCols, createTime = 0L, createVersion = org.apache.spark.SPARK_VERSION, tracksPartitionsInCatalog = true) @@ -176,7 +177,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { protected def generateTable( catalog: SessionCatalog, name: TableIdentifier, - isDataSource: Boolean = true): CatalogTable + isDataSource: Boolean = true, + partitionCols: Seq[String] = Seq("a", "b")): CatalogTable private val escapedIdentifier = "`(.+)`".r @@ -228,8 +230,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { private def createTable( catalog: SessionCatalog, name: TableIdentifier, - isDataSource: Boolean = true): Unit = { - catalog.createTable(generateTable(catalog, name, isDataSource), ignoreIfExists = false) + isDataSource: Boolean = true, + partitionCols: Seq[String] = Seq("a", "b")): Unit = { + catalog.createTable( + generateTable(catalog, name, isDataSource, partitionCols), ignoreIfExists = false) } private def createTablePartition( @@ -1131,7 +1135,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("alter table: recover partition (parallel)") { - withSQLConf("spark.rdd.parallelListingThreshold" -> "1") { + withSQLConf("spark.rdd.parallelListingThreshold" -> "0") { testRecoverPartitions() } } @@ -1144,23 +1148,32 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } val tableIdent = TableIdentifier("tab1") - createTable(catalog, tableIdent) - val part1 = Map("a" -> "1", "b" -> "5") + createTable(catalog, tableIdent, partitionCols = Seq("a", "b", "c")) + val part1 = Map("a" -> "1", "b" -> "5", "c" -> "19") createTablePartition(catalog, part1, tableIdent) assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) - val part2 = Map("a" -> "2", "b" -> "6") + val part2 = Map("a" -> "2", "b" -> "6", "c" -> "31") val root = new Path(catalog.getTableMetadata(tableIdent).location) val fs = root.getFileSystem(spark.sessionState.newHadoopConf()) // valid - fs.mkdirs(new Path(new Path(root, "a=1"), "b=5")) - fs.createNewFile(new Path(new Path(root, "a=1/b=5"), "a.csv")) // file - fs.createNewFile(new Path(new Path(root, "a=1/b=5"), "_SUCCESS")) // file - fs.mkdirs(new Path(new Path(root, "A=2"), "B=6")) - fs.createNewFile(new Path(new Path(root, "A=2/B=6"), "b.csv")) // file - fs.createNewFile(new Path(new Path(root, "A=2/B=6"), "c.csv")) // file - fs.createNewFile(new Path(new Path(root, "A=2/B=6"), ".hiddenFile")) // file - fs.mkdirs(new Path(new Path(root, "A=2/B=6"), "_temporary")) + fs.mkdirs(new Path(new Path(new Path(root, "a=1"), "b=5"), "c=19")) + fs.createNewFile(new Path(new Path(root, "a=1/b=5/c=19"), "a.csv")) // file + fs.createNewFile(new Path(new Path(root, "a=1/b=5/c=19"), "_SUCCESS")) // file + + fs.mkdirs(new Path(new Path(new Path(root, "A=2"), "B=6"), "C=31")) + fs.createNewFile(new Path(new Path(root, "A=2/B=6/C=31"), "b.csv")) // file + fs.createNewFile(new Path(new Path(root, "A=2/B=6/C=31"), "c.csv")) // file + fs.createNewFile(new Path(new Path(root, "A=2/B=6/C=31"), ".hiddenFile")) // file + fs.mkdirs(new Path(new Path(root, "A=2/B=6/C=31"), "_temporary")) + + val parts = (10 to 100).map { a => + val part = Map("a" -> a.toString, "b" -> "5", "c" -> "42") + fs.mkdirs(new Path(new Path(new Path(root, s"a=$a"), "b=5"), "c=42")) + fs.createNewFile(new Path(new Path(root, s"a=$a/b=5/c=42"), "a.csv")) // file + createTablePartition(catalog, part, tableIdent) + part + } // invalid fs.mkdirs(new Path(new Path(root, "a"), "b")) // bad name @@ -1174,7 +1187,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { try { sql("ALTER TABLE tab1 RECOVER PARTITIONS") assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == - Set(part1, part2)) + Set(part1, part2) ++ parts) if (!isUsingHiveMetastore) { assert(catalog.getPartition(tableIdent, part1).parameters("numFiles") == "1") assert(catalog.getPartition(tableIdent, part2).parameters("numFiles") == "2") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 728817729dcf7..6708a50a961fd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -60,7 +60,8 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA protected override def generateTable( catalog: SessionCatalog, name: TableIdentifier, - isDataSource: Boolean): CatalogTable = { + isDataSource: Boolean, + partitionCols: Seq[String] = Seq("a", "b")): CatalogTable = { val storage = if (isDataSource) { val serde = HiveSerDe.sourceToSerDe("parquet") @@ -84,17 +85,17 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA val metadata = new MetadataBuilder() .putString("key", "value") .build() + val schema = new StructType() + .add("col1", "int", nullable = true, metadata = metadata) + .add("col2", "string") CatalogTable( identifier = name, tableType = CatalogTableType.EXTERNAL, storage = storage, - schema = new StructType() - .add("col1", "int", nullable = true, metadata = metadata) - .add("col2", "string") - .add("a", "int") - .add("b", "int"), + schema = schema.copy( + fields = schema.fields ++ partitionCols.map(StructField(_, IntegerType))), provider = if (isDataSource) Some("parquet") else Some("hive"), - partitionColumnNames = Seq("a", "b"), + partitionColumnNames = partitionCols, createTime = 0L, createVersion = org.apache.spark.SPARK_VERSION, tracksPartitionsInCatalog = true) From 7ad18ee9f26e75dbe038c6034700f9cd4c0e2baa Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Tue, 28 Aug 2018 12:31:33 -0700 Subject: [PATCH 1496/2461] [SPARK-25004][CORE] Add spark.executor.pyspark.memory limit. ## What changes were proposed in this pull request? This adds `spark.executor.pyspark.memory` to configure Python's address space limit, [`resource.RLIMIT_AS`](https://docs.python.org/3/library/resource.html#resource.RLIMIT_AS). Limiting Python's address space allows Python to participate in memory management. In practice, we see fewer cases of Python taking too much memory because it doesn't know to run garbage collection. This results in YARN killing fewer containers. This also improves error messages so users know that Python is consuming too much memory: ``` File "build/bdist.linux-x86_64/egg/package/library.py", line 265, in fe_engineer fe_eval_rec.update(f(src_rec_prep, mat_rec_prep)) File "build/bdist.linux-x86_64/egg/package/library.py", line 163, in fe_comp comparisons = EvaluationUtils.leven_list_compare(src_rec_prep.get(item, []), mat_rec_prep.get(item, [])) File "build/bdist.linux-x86_64/egg/package/evaluationutils.py", line 25, in leven_list_compare permutations = sorted(permutations, reverse=True) MemoryError ``` The new pyspark memory setting is used to increase requested YARN container memory, instead of sharing overhead memory between python and off-heap JVM activity. ## How was this patch tested? Tested memory limits in our YARN cluster and verified that MemoryError is thrown. Author: Ryan Blue Closes #21977 from rdblue/SPARK-25004-add-python-memory-limit. --- .../apache/spark/api/python/PythonRDD.scala | 5 +--- .../spark/api/python/PythonRunner.scala | 27 ++++++++++++------- .../spark/internal/config/package.scala | 4 +++ docs/configuration.md | 12 +++++++++ python/pyspark/worker.py | 23 ++++++++++++++++ .../org/apache/spark/deploy/yarn/Client.scala | 17 ++++++++---- .../spark/deploy/yarn/YarnAllocator.scala | 9 ++++++- .../deploy/yarn/BaseYarnClusterSuite.scala | 27 +++++++++++++------ .../spark/deploy/yarn/YarnClusterSuite.scala | 6 +++-- .../python/AggregateInPandasExec.scala | 4 --- .../python/ArrowEvalPythonExec.scala | 4 --- .../execution/python/ArrowPythonRunner.scala | 4 +-- .../python/BatchEvalPythonExec.scala | 5 +--- .../sql/execution/python/EvalPythonExec.scala | 6 +---- .../python/FlatMapGroupsInPandasExec.scala | 4 --- .../python/PythonForeachWriter.scala | 5 +--- .../execution/python/PythonUDFRunner.scala | 4 +-- .../execution/python/WindowInPandasExec.scala | 4 --- 18 files changed, 105 insertions(+), 65 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index c3db60a23f987..197f4643e6134 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -49,9 +49,6 @@ private[spark] class PythonRDD( isFromBarrier: Boolean = false) extends RDD[Array[Byte]](parent) { - val bufferSize = conf.getInt("spark.buffer.size", 65536) - val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true) - override def getPartitions: Array[Partition] = firstParent.partitions override val partitioner: Option[Partitioner] = { @@ -61,7 +58,7 @@ private[spark] class PythonRDD( val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { - val runner = PythonRunner(func, bufferSize, reuseWorker) + val runner = PythonRunner(func) runner.compute(firstParent.iterator(split, context), split.index, context) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 151c910bf1aee..da6475cfa8549 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -27,6 +27,7 @@ import scala.collection.JavaConverters._ import org.apache.spark._ import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.PYSPARK_EXECUTOR_MEMORY import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util._ @@ -62,14 +63,20 @@ private[spark] object PythonEvalType { */ private[spark] abstract class BasePythonRunner[IN, OUT]( funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuseWorker: Boolean, evalType: Int, argOffsets: Array[Array[Int]]) extends Logging { require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") + private val conf = SparkEnv.get.conf + private val bufferSize = conf.getInt("spark.buffer.size", 65536) + private val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true) + // each python worker gets an equal part of the allocation. the worker pool will grow to the + // number of concurrent tasks, which is determined by the number of cores in this executor. + private val memoryMb = conf.get(PYSPARK_EXECUTOR_MEMORY) + .map(_ / conf.getInt("spark.executor.cores", 1)) + // All the Python functions should have the same exec, version and envvars. protected val envVars = funcs.head.funcs.head.envVars protected val pythonExec = funcs.head.funcs.head.pythonExec @@ -82,7 +89,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( private[spark] var serverSocket: Option[ServerSocket] = None // Authentication helper used when serving method calls via socket from Python side. - private lazy val authHelper = new SocketAuthHelper(SparkEnv.get.conf) + private lazy val authHelper = new SocketAuthHelper(conf) def compute( inputIterator: Iterator[IN], @@ -95,6 +102,9 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( if (reuseWorker) { envVars.put("SPARK_REUSE_WORKER", "1") } + if (memoryMb.isDefined) { + envVars.put("PYSPARK_EXECUTOR_MEMORY_MB", memoryMb.get.toString) + } val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) // Whether is the worker released into idle pool val released = new AtomicBoolean(false) @@ -485,20 +495,17 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( private[spark] object PythonRunner { - def apply(func: PythonFunction, bufferSize: Int, reuseWorker: Boolean): PythonRunner = { - new PythonRunner(Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuseWorker) + def apply(func: PythonFunction): PythonRunner = { + new PythonRunner(Seq(ChainedPythonFunctions(Seq(func)))) } } /** * A helper class to run Python mapPartition in Spark. */ -private[spark] class PythonRunner( - funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuseWorker: Boolean) +private[spark] class PythonRunner(funcs: Seq[ChainedPythonFunctions]) extends BasePythonRunner[Array[Byte], Array[Byte]]( - funcs, bufferSize, reuseWorker, PythonEvalType.NON_UDF, Array(Array(0))) { + funcs, PythonEvalType.NON_UDF, Array(Array(0))) { protected override def newWriterThread( env: SparkEnv, diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index daf3f070d72e9..7c2f601c9986a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -114,6 +114,10 @@ package object config { .checkValue(_ >= 0, "The off-heap memory size must not be negative") .createWithDefault(0) + private[spark] val PYSPARK_EXECUTOR_MEMORY = ConfigBuilder("spark.executor.pyspark.memory") + .bytesConf(ByteUnit.MiB) + .createOptional + private[spark] val IS_PYTHON_APP = ConfigBuilder("spark.yarn.isPython").internal() .booleanConf.createWithDefault(false) diff --git a/docs/configuration.md b/docs/configuration.md index 0270dc2cfaf45..9714b48d5e69b 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -179,6 +179,18 @@ of the most common options to set are: (e.g. 2g, 8g). + + spark.executor.pyspark.memory + Not set + + The amount of memory to be allocated to PySpark in each executor, in MiB + unless otherwise specified. If set, PySpark memory for an executor will be + limited to this amount. If not set, Spark will not limit Python's memory use + and it is up to the application to avoid exceeding the overhead memory space + shared with other non-JVM processes. When PySpark is run in YARN, this memory + is added to executor resource requests. + + spark.executor.memoryOverhead executorMemory * 0.10, with minimum of 384 diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index d54a5b8e396ea..228b3e07c647a 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -22,6 +22,7 @@ import os import sys import time +import resource import socket import traceback @@ -263,6 +264,28 @@ def main(infile, outfile): isBarrier = read_bool(infile) boundPort = read_int(infile) secret = UTF8Deserializer().loads(infile) + + # set up memory limits + memory_limit_mb = int(os.environ.get('PYSPARK_EXECUTOR_MEMORY_MB', "-1")) + total_memory = resource.RLIMIT_AS + try: + if memory_limit_mb > 0: + (soft_limit, hard_limit) = resource.getrlimit(total_memory) + msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit) + print(msg, file=sys.stderr) + + # convert to bytes + new_limit = memory_limit_mb * 1024 * 1024 + + if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit: + msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit) + print(msg, file=sys.stderr) + resource.setrlimit(total_memory, (new_limit, new_limit)) + + except (resource.error, OSError, ValueError) as e: + # not all systems support resource limits, so warn instead of failing + print("WARN: Failed to set memory limit: {0}\n".format(e), file=sys.stderr) + # initialize global state taskContext = None if isBarrier: diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 698fc2ce8bf9d..4a85898ef880b 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -91,6 +91,13 @@ private[spark] class Client( private val executorMemoryOverhead = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse( math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toLong, MEMORY_OVERHEAD_MIN)).toInt + private val isPython = sparkConf.get(IS_PYTHON_APP) + private val pysparkWorkerMemory: Int = if (isPython) { + sparkConf.get(PYSPARK_EXECUTOR_MEMORY).map(_.toInt).getOrElse(0) + } else { + 0 + } + private val distCacheMgr = new ClientDistributedCacheManager() private val principal = sparkConf.get(PRINCIPAL).orNull @@ -333,12 +340,12 @@ private[spark] class Client( val maxMem = newAppResponse.getMaximumResourceCapability().getMemory() logInfo("Verifying our application has not requested more than the maximum " + s"memory capability of the cluster ($maxMem MB per container)") - val executorMem = executorMemory + executorMemoryOverhead + val executorMem = executorMemory + executorMemoryOverhead + pysparkWorkerMemory if (executorMem > maxMem) { - throw new IllegalArgumentException(s"Required executor memory ($executorMemory" + - s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " + - "Please check the values of 'yarn.scheduler.maximum-allocation-mb' and/or " + - "'yarn.nodemanager.resource.memory-mb'.") + throw new IllegalArgumentException(s"Required executor memory ($executorMemory), overhead " + + s"($executorMemoryOverhead MB), and PySpark memory ($pysparkWorkerMemory MB) is above " + + s"the max threshold ($maxMem MB) of this cluster! Please check the values of " + + s"'yarn.scheduler.maximum-allocation-mb' and/or 'yarn.nodemanager.resource.memory-mb'.") } val amMem = amMemory + amMemoryOverhead if (amMem > maxMem) { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 40f1222fcd83f..8a7551de7c088 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -133,10 +133,17 @@ private[yarn] class YarnAllocator( // Additional memory overhead. protected val memoryOverhead: Int = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse( math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN)).toInt + protected val pysparkWorkerMemory: Int = if (sparkConf.get(IS_PYTHON_APP)) { + sparkConf.get(PYSPARK_EXECUTOR_MEMORY).map(_.toInt).getOrElse(0) + } else { + 0 + } // Number of cores per executor. protected val executorCores = sparkConf.get(EXECUTOR_CORES) // Resource capability requested for each executors - private[yarn] val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores) + private[yarn] val resource = Resource.newInstance( + executorMemory + memoryOverhead + pysparkWorkerMemory, + executorCores) private val launcherPool = ThreadUtils.newDaemonCachedThreadPool( "ContainerLauncher", sparkConf.get(CONTAINER_LAUNCH_MAX_THREADS)) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index b0abcc9149d08..3a7913122dd83 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -133,7 +133,8 @@ abstract class BaseYarnClusterSuite extraClassPath: Seq[String] = Nil, extraJars: Seq[String] = Nil, extraConf: Map[String, String] = Map(), - extraEnv: Map[String, String] = Map()): SparkAppHandle.State = { + extraEnv: Map[String, String] = Map(), + outFile: Option[File] = None): SparkAppHandle.State = { val deployMode = if (clientMode) "client" else "cluster" val propsFile = createConfFile(extraClassPath = extraClassPath, extraConf = extraConf) val env = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()) ++ extraEnv @@ -161,6 +162,11 @@ abstract class BaseYarnClusterSuite } extraJars.foreach(launcher.addJar) + if (outFile.isDefined) { + launcher.redirectOutput(outFile.get) + launcher.redirectError() + } + val handle = launcher.startApplication() try { eventually(timeout(2 minutes), interval(1 second)) { @@ -179,17 +185,22 @@ abstract class BaseYarnClusterSuite * the tests enforce that something is written to a file after everything is ok to indicate * that the job succeeded. */ - protected def checkResult(finalState: SparkAppHandle.State, result: File): Unit = { - checkResult(finalState, result, "success") - } - protected def checkResult( finalState: SparkAppHandle.State, result: File, - expected: String): Unit = { - finalState should be (SparkAppHandle.State.FINISHED) + expected: String = "success", + outFile: Option[File] = None): Unit = { + // the context message is passed to assert as Any instead of a function. to lazily load the + // output from the file, this passes an anonymous object that loads it in toString when building + // an error message + val output = new Object() { + override def toString: String = outFile + .map(Files.toString(_, StandardCharsets.UTF_8)) + .getOrElse("(stdout/stderr was not captured)") + } + assert(finalState === SparkAppHandle.State.FINISHED, output) val resultString = Files.toString(result, StandardCharsets.UTF_8) - resultString should be (expected) + assert(resultString === expected, output) } protected def mainClassName(klass: Class[_]): String = { diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index d67f5d2768e49..58d11e96942e1 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -282,13 +282,15 @@ class YarnClusterSuite extends BaseYarnClusterSuite { val mod2Archive = TestUtils.createJarWithFiles(Map("mod2.py" -> TEST_PYMODULE), moduleDir) val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",") val result = File.createTempFile("result", null, tempDir) + val outFile = Some(File.createTempFile("stdout", null, tempDir)) val finalState = runSpark(clientMode, primaryPyFile.getAbsolutePath(), sparkArgs = Seq("--py-files" -> pyFiles), appArgs = Seq(result.getAbsolutePath()), extraEnv = extraEnvVars, - extraConf = extraConf) - checkResult(finalState, result) + extraConf = extraConf, + outFile = outFile) + checkResult(finalState, result, outFile = outFile) } private def testUseClassPathFirst(clientMode: Boolean): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 88c9c026928e8..2ab7240556aaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -79,8 +79,6 @@ case class AggregateInPandasExec( override protected def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() - val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) - val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val sessionLocalTimeZone = conf.sessionLocalTimeZone val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) @@ -137,8 +135,6 @@ case class AggregateInPandasExec( val columnarBatchIter = new ArrowPythonRunner( pyFuncs, - bufferSize, - reuseWorker, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, argOffsets, aggInputSchema, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 6a03f860f8f95..2b87796dc6833 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -75,8 +75,6 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi protected override def evaluate( funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuseWorker: Boolean, argOffsets: Array[Array[Int]], iter: Iterator[InternalRow], schema: StructType, @@ -89,8 +87,6 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val columnarBatchIter = new ArrowPythonRunner( funcs, - bufferSize, - reuseWorker, PythonEvalType.SQL_SCALAR_PANDAS_UDF, argOffsets, schema, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 85b187159a3e6..18992d7a9f974 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -39,15 +39,13 @@ import org.apache.spark.util.Utils */ class ArrowPythonRunner( funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuseWorker: Boolean, evalType: Int, argOffsets: Array[Array[Int]], schema: StructType, timeZoneId: String, conf: Map[String, String]) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( - funcs, bufferSize, reuseWorker, evalType, argOffsets) { + funcs, evalType, argOffsets) { protected override def newWriterThread( env: SparkEnv, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 2054c700957e0..b08b7e60e130b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -43,8 +43,6 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi protected override def evaluate( funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuseWorker: Boolean, argOffsets: Array[Array[Int]], iter: Iterator[InternalRow], schema: StructType, @@ -75,8 +73,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi }.grouped(100).map(x => pickle.dumps(x.toArray)) // Output iterator for results from Python. - val outputIterator = new PythonUDFRunner( - funcs, bufferSize, reuseWorker, PythonEvalType.SQL_BATCHED_UDF, argOffsets) + val outputIterator = new PythonUDFRunner(funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets) .compute(inputIterator, context.partitionId(), context) val unpickle = new Unpickler diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala index 04c7dfdd4e204..942a6db57416e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala @@ -78,8 +78,6 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil protected def evaluate( funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuseWorker: Boolean, argOffsets: Array[Array[Int]], iter: Iterator[InternalRow], schema: StructType, @@ -87,8 +85,6 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil protected override def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute().map(_.copy()) - val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) - val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) inputRDD.mapPartitions { iter => val context = TaskContext.get() @@ -129,7 +125,7 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil } val outputRowIterator = evaluate( - pyFuncs, bufferSize, reuseWorker, argOffsets, projectedRowIter, schema, context) + pyFuncs, argOffsets, projectedRowIter, schema, context) val joined = new JoinedRow val resultProj = UnsafeProjection.create(output, output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index f5a563baf52df..e9cff1a5a2007 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -74,8 +74,6 @@ case class FlatMapGroupsInPandasExec( override protected def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() - val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) - val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) val sessionLocalTimeZone = conf.sessionLocalTimeZone val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) @@ -141,8 +139,6 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, - bufferSize, - reuseWorker, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala index f08f816cbcca9..a4e9b3305052f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala @@ -45,10 +45,7 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType) } private lazy val pythonRunner = { - val conf = SparkEnv.get.conf - val bufferSize = conf.getInt("spark.buffer.size", 65536) - val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true) - PythonRunner(func, bufferSize, reuseWorker) + PythonRunner(func) } private lazy val outputIterator = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index e28def1c4b423..cc61faa7e7051 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -29,12 +29,10 @@ import org.apache.spark.api.python._ */ class PythonUDFRunner( funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuseWorker: Boolean, evalType: Int, argOffsets: Array[Array[Int]]) extends BasePythonRunner[Array[Byte], Array[Byte]]( - funcs, bufferSize, reuseWorker, evalType, argOffsets) { + funcs, evalType, argOffsets) { protected override def newWriterThread( env: SparkEnv, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index 47bfbde56bb3e..27bed1137e5b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -95,8 +95,6 @@ case class WindowInPandasExec( protected override def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() - val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) - val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val sessionLocalTimeZone = conf.sessionLocalTimeZone val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) @@ -156,8 +154,6 @@ case class WindowInPandasExec( val windowFunctionResult = new ArrowPythonRunner( pyFuncs, - bufferSize, - reuseWorker, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, argOffsets, windowInputSchema, From 103854028e99846aabeb6f27eb6fd255ecc96381 Mon Sep 17 00:00:00 2001 From: Bogdan Raducanu Date: Tue, 28 Aug 2018 15:50:25 -0700 Subject: [PATCH 1497/2461] [SPARK-25212][SQL] Support Filter in ConvertToLocalRelation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Support Filter in ConvertToLocalRelation, similar to how Project works. Additionally, in Optimizer, run ConvertToLocalRelation earlier to simplify the plan. This is good for very short queries which often are queries on local relations. ## How was this patch tested? New test. Manual benchmark. Author: Bogdan Raducanu Author: Shixiong Zhu Author: Yinan Li Author: Li Jin Author: s71955 Author: DB Tsai Author: jaroslav chládek Author: Huangweizhe Author: Xiangrui Meng Author: hyukjinkwon Author: Kent Yao Author: caoxuewen Author: liuxian Author: Adam Bradbury Author: Jose Torres Author: Yuming Wang Author: Liang-Chi Hsieh Closes #22205 from bogdanrdc/local-relation-filter. --- .../sql/catalyst/optimizer/Optimizer.scala | 14 ++++++++++++++ .../ConvertToLocalRelationSuite.scala | 18 ++++++++++++++++++ .../apache/spark/sql/DataFrameJoinSuite.scala | 8 ++++---- 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 63a62cd0cbfe6..e4b4f1ecbe21f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -130,6 +130,14 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) // since the other rules might make two separate Unions operators adjacent. Batch("Union", Once, CombineUnions) :: + // run this once earlier. this might simplify the plan and reduce cost of optimizer. + // for example, a query such as Filter(LocalRelation) would go through all the heavy + // optimizer rules that are triggered when there is a filter + // (e.g. InferFiltersFromConstraints). if we run this batch earlier, the query becomes just + // LocalRelation and does not trigger many rules + Batch("LocalRelation early", fixedPoint, + ConvertToLocalRelation, + PropagateEmptyRelation) :: Batch("Pullup Correlated Expressions", Once, PullupCorrelatedPredicates) :: Batch("Subquery", Once, @@ -1349,6 +1357,12 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] { case Limit(IntegerLiteral(limit), LocalRelation(output, data, isStreaming)) => LocalRelation(output, data.take(limit), isStreaming) + + case Filter(condition, LocalRelation(output, data, isStreaming)) + if !hasUnevaluableExpr(condition) => + val predicate = InterpretedPredicate.create(condition, output) + predicate.initialize(0) + LocalRelation(output, data.filter(row => predicate.eval(row)), isStreaming) } private def hasUnevaluableExpr(expr: Expression): Boolean = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala index 049a19b86f7cd..0c015f88e1e84 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{LessThan, Literal} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -52,4 +53,21 @@ class ConvertToLocalRelationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("Filter on LocalRelation should be turned into a single LocalRelation") { + val testRelation = LocalRelation( + LocalRelation('a.int, 'b.int).output, + InternalRow(1, 2) :: InternalRow(4, 5) :: Nil) + + val correctAnswer = LocalRelation( + LocalRelation('a1.int, 'b1.int).output, + InternalRow(1, 3) :: Nil) + + val filterAndProjectOnLocal = testRelation + .select(UnresolvedAttribute("a").as("a1"), (UnresolvedAttribute("b") + 1).as("b1")) + .where(LessThan(UnresolvedAttribute("b1"), Literal.create(6))) + + val optimized = Optimize.execute(filterAndProjectOnLocal.analyze) + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 10d9a11d2ee79..e6b30f9956daf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -196,7 +196,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b") // outer -> left - val outerJoin2Left = df.join(df2, $"a.int" === $"b.int", "outer").where($"a.int" === 3) + val outerJoin2Left = df.join(df2, $"a.int" === $"b.int", "outer").where($"a.int" >= 3) assert(outerJoin2Left.queryExecution.optimizedPlan.collect { case j @ Join(_, _, LeftOuter, _) => j }.size === 1) checkAnswer( @@ -204,7 +204,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(3, 4, "3", null, null, null) :: Nil) // outer -> right - val outerJoin2Right = df.join(df2, $"a.int" === $"b.int", "outer").where($"b.int" === 5) + val outerJoin2Right = df.join(df2, $"a.int" === $"b.int", "outer").where($"b.int" >= 3) assert(outerJoin2Right.queryExecution.optimizedPlan.collect { case j @ Join(_, _, RightOuter, _) => j }.size === 1) checkAnswer( @@ -221,7 +221,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 2, "1", 1, 3, "1") :: Nil) // right -> inner - val rightJoin2Inner = df.join(df2, $"a.int" === $"b.int", "right").where($"a.int" === 1) + val rightJoin2Inner = df.join(df2, $"a.int" === $"b.int", "right").where($"a.int" > 0) assert(rightJoin2Inner.queryExecution.optimizedPlan.collect { case j @ Join(_, _, Inner, _) => j }.size === 1) checkAnswer( @@ -229,7 +229,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 2, "1", 1, 3, "1") :: Nil) // left -> inner - val leftJoin2Inner = df.join(df2, $"a.int" === $"b.int", "left").where($"b.int2" === 3) + val leftJoin2Inner = df.join(df2, $"a.int" === $"b.int", "left").where($"b.int2" > 0) assert(leftJoin2Inner.queryExecution.optimizedPlan.collect { case j @ Join(_, _, Inner, _) => j }.size === 1) checkAnswer( From bbbf8146916aa70d9774543643776eed9d9d9373 Mon Sep 17 00:00:00 2001 From: Bo Meng Date: Tue, 28 Aug 2018 19:39:13 -0500 Subject: [PATCH 1498/2461] [SPARK-22357][CORE] SparkContext.binaryFiles ignore minPartitions parameter ## What changes were proposed in this pull request? Fix the issue that minPartitions was not used in the method. This is a simple fix and I am not trying to make it complicated. The purpose is to still allow user to control the defaultParallelism through the value of minPartitions, while also via sc.defaultParallelism parameters. ## How was this patch tested? I have not provided the additional test since the fix is very straightforward. Closes #21638 from bomeng/22357. Lead-authored-by: Bo Meng Co-authored-by: Bo Meng Signed-off-by: Sean Owen --- .../main/scala/org/apache/spark/input/PortableDataStream.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index 17cdba4f1305b..ab020aaf6fa4f 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -47,7 +47,7 @@ private[spark] abstract class StreamFileInputFormat[T] def setMinPartitions(sc: SparkContext, context: JobContext, minPartitions: Int) { val defaultMaxSplitBytes = sc.getConf.get(config.FILES_MAX_PARTITION_BYTES) val openCostInBytes = sc.getConf.get(config.FILES_OPEN_COST_IN_BYTES) - val defaultParallelism = sc.defaultParallelism + val defaultParallelism = Math.max(sc.defaultParallelism, minPartitions) val files = listStatus(context).asScala val totalBytes = files.filterNot(_.isDirectory).map(_.getLen + openCostInBytes).sum val bytesPerCore = totalBytes / defaultParallelism From 32c8a3d7beac4b47a75f5ec3c69b13ebc57de0c7 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 29 Aug 2018 09:20:32 +0800 Subject: [PATCH 1499/2461] [MINOR] Avoid code duplication for nullable in Higher Order function ## What changes were proposed in this pull request? Most of `HigherOrderFunction`s have the same `nullable` definition, ie. they are nullable when one of their arguments is nullable. The PR refactors it in order to avoid code duplication. ## How was this patch tested? NA Closes #22243 from mgaido91/MINOR_nullable_hof. Authored-by: Marco Gaido Signed-off-by: hyukjinkwon --- .../expressions/higherOrderFunctions.scala | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 9f2e84a230060..2bb6b20b944d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -90,6 +90,8 @@ object LambdaFunction { */ trait HigherOrderFunction extends Expression with ExpectsInputTypes { + override def nullable: Boolean = arguments.exists(_.nullable) + override def children: Seq[Expression] = arguments ++ functions /** @@ -217,8 +219,6 @@ case class ArrayTransform( function: Expression) extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { - override def nullable: Boolean = argument.nullable - override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = { @@ -287,8 +287,6 @@ case class MapFilter( copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) } - override def nullable: Boolean = argument.nullable - override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { val m = argumentValue.asInstanceOf[MapData] val f = functionForEval @@ -328,8 +326,6 @@ case class ArrayFilter( function: Expression) extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { - override def nullable: Boolean = argument.nullable - override def dataType: DataType = argument.dataType override def functionType: AbstractDataType = BooleanType @@ -375,8 +371,6 @@ case class ArrayExists( function: Expression) extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { - override def nullable: Boolean = argument.nullable - override def dataType: DataType = BooleanType override def functionType: AbstractDataType = BooleanType @@ -516,8 +510,6 @@ case class TransformKeys( function: Expression) extends MapBasedSimpleHigherOrderFunction with CodegenFallback { - override def nullable: Boolean = argument.nullable - @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType override def dataType: DataType = MapType(function.dataType, valueType, valueContainsNull) @@ -568,8 +560,6 @@ case class TransformValues( function: Expression) extends MapBasedSimpleHigherOrderFunction with CodegenFallback { - override def nullable: Boolean = argument.nullable - @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType override def dataType: DataType = MapType(keyType, function.dataType, function.nullable) @@ -638,8 +628,6 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil - override def nullable: Boolean = left.nullable || right.nullable - override def dataType: DataType = MapType(keyType, function.dataType, function.nullable) override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapZipWith = { @@ -810,8 +798,6 @@ case class ZipWith(left: Expression, right: Expression, function: Expression) override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil - override def nullable: Boolean = left.nullable || right.nullable - override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ZipWith = { From 68ec207a320bd50ca61e820c9ff559f799c2ab0a Mon Sep 17 00:00:00 2001 From: Arun Mahadevan Date: Wed, 29 Aug 2018 09:25:49 +0800 Subject: [PATCH 1500/2461] [SPARK-25260][SQL] Fix namespace handling in SchemaConverters.toAvroType ## What changes were proposed in this pull request? `toAvroType` converts spark data type to avro schema. It always appends the record name to namespace so its impossible to have an Avro namespace independent of the record name. When invoked with a spark data type like, ```java val sparkSchema = StructType(Seq( StructField("name", StringType, nullable = false), StructField("address", StructType(Seq( StructField("city", StringType, nullable = false), StructField("state", StringType, nullable = false))), nullable = false))) // map it to an avro schema with record name "employee" and top level namespace "foo.bar", val avroSchema = SchemaConverters.toAvroType(sparkSchema, false, "employee", "foo.bar") // result is // avroSchema.getName = employee // avroSchema.getNamespace = foo.bar.employee // avroSchema.getFullname = foo.bar.employee.employee ``` The patch proposes to fix this so that the result is ``` avroSchema.getName = employee avroSchema.getNamespace = foo.bar avroSchema.getFullname = foo.bar.employee ``` ## How was this patch tested? New and existing unit tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22251 from arunmahadevan/avro-fix. Authored-by: Arun Mahadevan Signed-off-by: hyukjinkwon --- .../spark/sql/avro/SchemaConverters.scala | 18 ++++---- .../org/apache/spark/sql/avro/AvroSuite.scala | 42 ++++++++++++++++++- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala index 3a15e8d087fa4..bd1576587d7fa 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -123,7 +123,7 @@ object SchemaConverters { catalystType: DataType, nullable: Boolean = false, recordName: String = "topLevelRecord", - prevNameSpace: String = "") + nameSpace: String = "") : Schema = { val builder = SchemaBuilder.builder() @@ -143,29 +143,25 @@ object SchemaConverters { val avroType = LogicalTypes.decimal(d.precision, d.scale) val fixedSize = minBytesForPrecision(d.precision) // Need to avoid naming conflict for the fixed fields - val name = prevNameSpace match { + val name = nameSpace match { case "" => s"$recordName.fixed" - case _ => s"$prevNameSpace.$recordName.fixed" + case _ => s"$nameSpace.$recordName.fixed" } avroType.addToSchema(SchemaBuilder.fixed(name).size(fixedSize)) case BinaryType => builder.bytesType() case ArrayType(et, containsNull) => builder.array() - .items(toAvroType(et, containsNull, recordName, prevNameSpace)) + .items(toAvroType(et, containsNull, recordName, nameSpace)) case MapType(StringType, vt, valueContainsNull) => builder.map() - .values(toAvroType(vt, valueContainsNull, recordName, prevNameSpace)) + .values(toAvroType(vt, valueContainsNull, recordName, nameSpace)) case st: StructType => - val nameSpace = prevNameSpace match { - case "" => recordName - case _ => s"$prevNameSpace.$recordName" - } - + val childNameSpace = if (nameSpace != "") s"$nameSpace.$recordName" else recordName val fieldsAssembler = builder.record(recordName).namespace(nameSpace).fields() st.foreach { f => val fieldAvroType = - toAvroType(f.dataType, f.nullable, f.name, nameSpace) + toAvroType(f.dataType, f.nullable, f.name, childNameSpace) fieldsAssembler.name(f.name).`type`(fieldAvroType).noDefault() } fieldsAssembler.endRecord() diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 72bef9e3aed41..9ad4388414eaa 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -1082,7 +1082,6 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val schema = getAvroSchemaStringFromFiles(dir.toString) assert(schema.contains("\"namespace\":\"topLevelRecord\"")) assert(schema.contains("\"namespace\":\"topLevelRecord.data\"")) - assert(schema.contains("\"namespace\":\"topLevelRecord.data.data\"")) } } @@ -1099,6 +1098,47 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } } + test("check namespace - toAvroType") { + val sparkSchema = StructType(Seq( + StructField("name", StringType, nullable = false), + StructField("address", StructType(Seq( + StructField("city", StringType, nullable = false), + StructField("state", StringType, nullable = false))), + nullable = false))) + val employeeType = SchemaConverters.toAvroType(sparkSchema, + recordName = "employee", + nameSpace = "foo.bar") + + assert(employeeType.getFullName == "foo.bar.employee") + assert(employeeType.getName == "employee") + assert(employeeType.getNamespace == "foo.bar") + + val addressType = employeeType.getField("address").schema() + assert(addressType.getFullName == "foo.bar.employee.address") + assert(addressType.getName == "address") + assert(addressType.getNamespace == "foo.bar.employee") + } + + test("check empty namespace - toAvroType") { + val sparkSchema = StructType(Seq( + StructField("name", StringType, nullable = false), + StructField("address", StructType(Seq( + StructField("city", StringType, nullable = false), + StructField("state", StringType, nullable = false))), + nullable = false))) + val employeeType = SchemaConverters.toAvroType(sparkSchema, + recordName = "employee") + + assert(employeeType.getFullName == "employee") + assert(employeeType.getName == "employee") + assert(employeeType.getNamespace == null) + + val addressType = employeeType.getField("address").schema() + assert(addressType.getFullName == "employee.address") + assert(addressType.getName == "address") + assert(addressType.getNamespace == "employee") + } + case class NestedMiddleArray(id: Int, data: Array[NestedBottom]) case class NestedTopArray(id: Int, data: NestedMiddleArray) From 38391c9aa8a88fcebb337934f30298a32d91596b Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 29 Aug 2018 09:47:38 +0800 Subject: [PATCH 1501/2461] [SPARK-25253][PYSPARK] Refactor local connection & auth code This eliminates some duplication in the code to connect to a server on localhost to talk directly to the jvm. Also it gives consistent ipv6 and error handling. Two other incidental changes, that shouldn't matter: 1) python barrier tasks perform authentication immediately (rather than waiting for the BARRIER_FUNCTION indicator) 2) for `rdd._load_from_socket`, the timeout is only increased after authentication. Closes #22247 from squito/py_connection_refactor. Authored-by: Imran Rashid Signed-off-by: hyukjinkwon --- .../spark/api/python/PythonRunner.scala | 3 +- python/pyspark/java_gateway.py | 32 ++++++++++++++++++- python/pyspark/rdd.py | 27 ++-------------- python/pyspark/taskcontext.py | 32 +++---------------- python/pyspark/worker.py | 7 ++-- 5 files changed, 40 insertions(+), 61 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index da6475cfa8549..6c7e8630789bd 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -216,6 +216,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( sock = serverSocket.get.accept() // Wait for function call from python side. sock.setSoTimeout(10000) + authHelper.authClient(sock) val input = new DataInputStream(sock.getInputStream()) input.readInt() match { case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => @@ -334,8 +335,6 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( def barrierAndServe(sock: Socket): Unit = { require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.") - authHelper.authClient(sock) - val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) try { context.asInstanceOf[BarrierTaskContext].barrier() diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index fa2d5e8db716a..b06503b53be90 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -134,7 +134,7 @@ def killChild(): return gateway -def do_server_auth(conn, auth_secret): +def _do_server_auth(conn, auth_secret): """ Performs the authentication protocol defined by the SocketAuthHelper class on the given file-like object 'conn'. @@ -147,6 +147,36 @@ def do_server_auth(conn, auth_secret): raise Exception("Unexpected reply from iterator server.") +def local_connect_and_auth(port, auth_secret): + """ + Connect to local host, authenticate with it, and return a (sockfile,sock) for that connection. + Handles IPV4 & IPV6, does some error handling. + :param port + :param auth_secret + :return: a tuple with (sockfile, sock) + """ + sock = None + errors = [] + # Support for both IPv4 and IPv6. + # On most of IPv6-ready systems, IPv6 will take precedence. + for res in socket.getaddrinfo("127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM): + af, socktype, proto, _, sa = res + try: + sock = socket.socket(af, socktype, proto) + sock.settimeout(15) + sock.connect(sa) + sockfile = sock.makefile("rwb", 65536) + _do_server_auth(sockfile, auth_secret) + return (sockfile, sock) + except socket.error as e: + emsg = _exception_message(e) + errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg)) + sock.close() + sock = None + else: + raise Exception("could not open socket: %s" % errors) + + def ensure_callback_server_started(gw): """ Start callback server if not already started. The callback server is needed if the Java diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index b061074a28ab4..380475e706fbe 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -39,7 +39,7 @@ else: from itertools import imap as map, ifilter as filter -from pyspark.java_gateway import do_server_auth +from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \ @@ -141,33 +141,10 @@ def _parse_memory(s): def _load_from_socket(sock_info, serializer): - port, auth_secret = sock_info - sock = None - errors = [] - # Support for both IPv4 and IPv6. - # On most of IPv6-ready systems, IPv6 will take precedence. - for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM): - af, socktype, proto, canonname, sa = res - sock = socket.socket(af, socktype, proto) - try: - sock.settimeout(15) - sock.connect(sa) - except socket.error as e: - emsg = _exception_message(e) - errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg)) - sock.close() - sock = None - continue - break - if not sock: - raise Exception("could not open socket: %s" % errors) + (sockfile, sock) = local_connect_and_auth(*sock_info) # The RDD materialization time is unpredicable, if we set a timeout for socket reading # operation, it will very possibly fail. See SPARK-18281. sock.settimeout(None) - - sockfile = sock.makefile("rwb", 65536) - do_server_auth(sockfile, auth_secret) - # The socket will be automatically closed when garbage-collected. return serializer.load_stream(sockfile) diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index c0312e5265c6e..53fc2b29e066f 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -18,7 +18,7 @@ from __future__ import print_function import socket -from pyspark.java_gateway import do_server_auth +from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import write_int, UTF8Deserializer @@ -108,38 +108,14 @@ def _load_from_socket(port, auth_secret): """ Load data from a given socket, this is a blocking method thus only return when the socket connection has been closed. - - This is copied from context.py, while modified the message protocol. """ - sock = None - # Support for both IPv4 and IPv6. - # On most of IPv6-ready systems, IPv6 will take precedence. - for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM): - af, socktype, proto, canonname, sa = res - sock = socket.socket(af, socktype, proto) - try: - # Do not allow timeout for socket reading operation. - sock.settimeout(None) - sock.connect(sa) - except socket.error: - sock.close() - sock = None - continue - break - if not sock: - raise Exception("could not open socket") - - # We don't really need a socket file here, it's just for convenience that we can reuse the - # do_server_auth() function and data serialization methods. - sockfile = sock.makefile("rwb", 65536) - + (sockfile, sock) = local_connect_and_auth(port, auth_secret) + # The barrier() call may block forever, so no timeout + sock.settimeout(None) # Make a barrier() function call. write_int(BARRIER_FUNCTION, sockfile) sockfile.flush() - # Do server auth. - do_server_auth(sockfile, auth_secret) - # Collect result. res = UTF8Deserializer().loads(sockfile) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 228b3e07c647a..e934da4d2eb6e 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -28,7 +28,7 @@ from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry -from pyspark.java_gateway import do_server_auth +from pyspark.java_gateway import local_connect_and_auth from pyspark.taskcontext import BarrierTaskContext, TaskContext from pyspark.files import SparkFiles from pyspark.rdd import PythonEvalType @@ -387,8 +387,5 @@ def process(): # Read information about how to connect back to the JVM from the environment. java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.connect(("127.0.0.1", java_port)) - sock_file = sock.makefile("rwb", 65536) - do_server_auth(sock_file, auth_secret) + (sock_file, _) = local_connect_and_auth(java_port, auth_secret) main(sock_file, sock_file) From ff8dcc1d4c684e1b68e63d61b3f20284b9979cca Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 29 Aug 2018 04:30:31 +0000 Subject: [PATCH 1502/2461] [SPARK-25235][SHELL] Merge the REPL code in Scala 2.11 and 2.12 branches ## What changes were proposed in this pull request? Using some reflection tricks to merge Scala 2.11 and 2.12 codebase. ## How was this patch tested? Existing tests. Closes #22246 from dbtsai/repl. Lead-authored-by: DB Tsai Co-authored-by: Liang-Chi Hsieh Signed-off-by: DB Tsai --- repl/pom.xml | 10 -- .../org/apache/spark/repl/SparkILoop.scala | 143 ------------------ .../org/apache/spark/repl/SparkILoop.scala | 49 +++++- 3 files changed, 45 insertions(+), 157 deletions(-) delete mode 100644 repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala rename repl/{scala-2.11 => }/src/main/scala/org/apache/spark/repl/SparkILoop.scala (82%) diff --git a/repl/pom.xml b/repl/pom.xml index 861bbd7c49654..553d5eb79a256 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -167,14 +167,4 @@ - - - scala-2.12 - - scala-2.12/src/main/scala - scala-2.12/src/test/scala - - - - diff --git a/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala deleted file mode 100644 index ffb2e5f5db7e2..0000000000000 --- a/repl/scala-2.12/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.repl - -import java.io.BufferedReader - -import scala.tools.nsc.Settings -import scala.tools.nsc.interpreter.{ILoop, JPrintWriter} -import scala.tools.nsc.util.stringFromStream -import scala.util.Properties.{javaVersion, javaVmName, versionString} - -/** - * A Spark-specific interactive shell. - */ -class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) - extends ILoop(in0, out) { - def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) - def this() = this(None, new JPrintWriter(Console.out, true)) - - val initializationCommands: Seq[String] = Seq( - """ - @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) { - org.apache.spark.repl.Main.sparkSession - } else { - org.apache.spark.repl.Main.createSparkSession() - } - @transient val sc = { - val _sc = spark.sparkContext - if (_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) { - val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null) - if (proxyUrl != null) { - println( - s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}") - } else { - println(s"Spark Context Web UI is available at Spark Master Public URL") - } - } else { - _sc.uiWebUrl.foreach { - webUrl => println(s"Spark context Web UI available at ${webUrl}") - } - } - println("Spark context available as 'sc' " + - s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") - println("Spark session available as 'spark'.") - _sc - } - """, - "import org.apache.spark.SparkContext._", - "import spark.implicits._", - "import spark.sql", - "import org.apache.spark.sql.functions._" - ) - - def initializeSpark() { - intp.beQuietDuring { - savingReplayStack { // remove the commands from session history. - initializationCommands.foreach(command) - } - } - } - - /** Print a welcome message */ - override def printWelcome() { - import org.apache.spark.SPARK_VERSION - echo("""Welcome to - ____ __ - / __/__ ___ _____/ /__ - _\ \/ _ \/ _ `/ __/ '_/ - /___/ .__/\_,_/_/ /_/\_\ version %s - /_/ - """.format(SPARK_VERSION)) - val welcomeMsg = "Using Scala %s (%s, Java %s)".format( - versionString, javaVmName, javaVersion) - echo(welcomeMsg) - echo("Type in expressions to have them evaluated.") - echo("Type :help for more information.") - } - - /** Available commands */ - override def commands: List[LoopCommand] = standardCommands - - /** - * We override `createInterpreter` because we need to initialize Spark *before* the REPL - * sees any files, so that the Spark context is visible in those files. This is a bit of a - * hack, but there isn't another hook available to us at this point. - */ - override def createInterpreter(): Unit = { - super.createInterpreter() - initializeSpark() - } - - override def resetCommand(line: String): Unit = { - super.resetCommand(line) - initializeSpark() - echo("Note that after :reset, state of SparkSession and SparkContext is unchanged.") - } - - override def replay(): Unit = { - initializeSpark() - super.replay() - } - -} - -object SparkILoop { - - /** - * Creates an interpreter loop with default settings and feeds - * the given code to it as input. - */ - def run(code: String, sets: Settings = new Settings): String = { - import java.io.{ BufferedReader, StringReader, OutputStreamWriter } - - stringFromStream { ostream => - Console.withOut(ostream) { - val input = new BufferedReader(new StringReader(code)) - val output = new JPrintWriter(new OutputStreamWriter(ostream), true) - val repl = new SparkILoop(input, output) - - if (sets.classpath.isDefault) { - sets.classpath.value = sys.props("java.class.path") - } - repl process sets - } - } - } - def run(lines: List[String]): String = run(lines.map(_ + "\n").mkString) -} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala similarity index 82% rename from repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala rename to repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 94265267b1f97..aa9aa2793b8b3 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -24,7 +24,6 @@ import scala.Predef.{println => _, _} // scalastyle:on println import scala.concurrent.Future import scala.reflect.classTag -import scala.reflect.internal.util.ScalaClassLoader.savingContextLoader import scala.reflect.io.File import scala.tools.nsc.{GenericRunnerSettings, Properties} import scala.tools.nsc.Settings @@ -33,7 +32,7 @@ import scala.tools.nsc.interpreter.{AbstractOrMissingHandler, ILoop, IMain, JPri import scala.tools.nsc.interpreter.{NamedParam, SimpleReader, SplashLoop, SplashReader} import scala.tools.nsc.interpreter.StdReplTags.tagOfIMain import scala.tools.nsc.util.stringFromStream -import scala.util.Properties.{javaVersion, javaVmName, versionString} +import scala.util.Properties.{javaVersion, javaVmName, versionNumberString, versionString} /** * A Spark-specific interactive shell. @@ -43,10 +42,32 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) def this() = this(None, new JPrintWriter(Console.out, true)) + /** + * TODO: Remove the following `override` when the support of Scala 2.11 is ended + * Scala 2.11 has a bug of finding imported types in class constructors, extends clause + * which is fixed in Scala 2.12 but never be back-ported into Scala 2.11.x. + * As a result, we copied the fixes into `SparkILoopInterpreter`. See SPARK-22393 for detail. + */ override def createInterpreter(): Unit = { - intp = new SparkILoopInterpreter(settings, out) + if (isScala2_11) { + if (addedClasspath != "") { + settings.classpath append addedClasspath + } + // scalastyle:off classforname + // Have to use the default classloader to match the one used in + // `classOf[Settings]` and `classOf[JPrintWriter]`. + intp = Class.forName("org.apache.spark.repl.SparkILoopInterpreter") + .getDeclaredConstructor(Seq(classOf[Settings], classOf[JPrintWriter]): _*) + .newInstance(Seq(settings, out): _*) + .asInstanceOf[IMain] + // scalastyle:on classforname + } else { + super.createInterpreter() + } } + private val isScala2_11 = versionNumberString.startsWith("2.11") + val initializationCommands: Seq[String] = Seq( """ @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) { @@ -124,6 +145,26 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) super.replay() } + /** + * TODO: Remove `runClosure` when the support of Scala 2.11 is ended + */ + private def runClosure(body: => Boolean): Boolean = { + if (isScala2_11) { + // In Scala 2.11, there is a bug that interpret could set the current thread's + // context classloader, but fails to reset it to its previous state when returning + // from that method. This is fixed in SI-8521 https://github.com/scala/scala/pull/5657 + // which is never back-ported into Scala 2.11.x. The following is a workaround fix. + val original = Thread.currentThread().getContextClassLoader + try { + body + } finally { + Thread.currentThread().setContextClassLoader(original) + } + } else { + body + } + } + /** * The following code is mostly a copy of `process` implementation in `ILoop.scala` in Scala * @@ -138,7 +179,7 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) * We should remove this duplication once Scala provides a way to load our custom initialization * code, and also customize the ordering of printing welcome message. */ - override def process(settings: Settings): Boolean = savingContextLoader { + override def process(settings: Settings): Boolean = runClosure { def newReader = in0.fold(chooseReader(settings))(r => SimpleReader(r, out, interactive = true)) From 82c18c240a6913a917df3b55cc5e22649561c4dd Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 29 Aug 2018 15:01:12 +0800 Subject: [PATCH 1503/2461] [SPARK-23030][SQL][PYTHON] Use Arrow stream format for creating from and collecting Pandas DataFrames MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This changes the calls of `toPandas()` and `createDataFrame()` to use the Arrow stream format, when Arrow is enabled. Previously, Arrow data was written to byte arrays where each chunk is an output of the Arrow file format. This was mainly due to constraints at the time, and caused some overhead by writing the schema/footer on each chunk of data and then having to read multiple Arrow file inputs and concat them together. Using the Arrow stream format has improved these by increasing performance, lower memory overhead for the average case, and simplified the code. Here are the details of this change: **toPandas()** _Before:_ Spark internal rows are converted to Arrow file format, each group of records is a complete Arrow file which contains the schema and other metadata. Next a collect is done and an Array of Arrow files is the result. After that each Arrow file is sent to Python driver which then loads each file and concats them to a single Arrow DataFrame. _After:_ Spark internal rows are converted to ArrowRecordBatches directly, which is the simplest Arrow component for IPC data transfers. The driver JVM then immediately starts serving data to Python as an Arrow stream, sending the schema first. It then starts a Spark job with a custom handler that sends Arrow RecordBatches to Python. Partitions arriving in order are sent immediately, and out-of-order partitions are buffered until the ones that precede it come in. This improves performance, simplifies memory usage on executors, and improves the average memory usage on the JVM driver. Since the order of partitions must be preserved, the worst case is that the first partition will be the last to arrive all data must be buffered in memory until then. This case is no worse that before when doing a full collect. **createDataFrame()** _Before:_ A Pandas DataFrame is split into parts and each part is made into an Arrow file. Then each file is prefixed by the buffer size and written to a temp file. The temp file is read and each Arrow file is parallelized as a byte array. _After:_ A Pandas DataFrame is split into parts, then an Arrow stream is written to a temp file where each part is an ArrowRecordBatch. The temp file is read as a stream and the Arrow messages are examined. If the message is an ArrowRecordBatch, the data is saved as a byte array. After reading the file, each ArrowRecordBatch is parallelized as a byte array. This has slightly more processing than before because we must look each Arrow message to extract the record batches, but performance ends up a litle better. It is cleaner in the sense that IPC from Python to JVM is done over a single Arrow stream. ## How was this patch tested? Added new unit tests for the additions to ArrowConverters in Scala, existing tests for Python. ## Performance Tests - toPandas Tests run on a 4 node standalone cluster with 32 cores total, 14.04.1-Ubuntu and OpenJDK 8 measured wall clock time to execute `toPandas()` and took the average best time of 5 runs/5 loops each. Test code ```python df = spark.range(1 << 25, numPartitions=32).toDF("id").withColumn("x1", rand()).withColumn("x2", rand()).withColumn("x3", rand()).withColumn("x4", rand()) for i in range(5): start = time.time() _ = df.toPandas() elapsed = time.time() - start ``` Current Master | This PR ---------------------|------------ 5.803557 | 5.16207 5.409119 | 5.133671 5.493509 | 5.147513 5.433107 | 5.105243 5.488757 | 5.018685 Avg Master | Avg This PR ------------------|-------------- 5.5256098 | 5.1134364 Speedup of **1.08060595** ## Performance Tests - createDataFrame Tests run on a 4 node standalone cluster with 32 cores total, 14.04.1-Ubuntu and OpenJDK 8 measured wall clock time to execute `createDataFrame()` and get the first record. Took the average best time of 5 runs/5 loops each. Test code ```python def run(): pdf = pd.DataFrame(np.random.rand(10000000, 10)) spark.createDataFrame(pdf).first() for i in range(6): start = time.time() run() elapsed = time.time() - start gc.collect() print("Run %d: %f" % (i, elapsed)) ``` Current Master | This PR --------------------|---------- 6.234608 | 5.665641 6.32144 | 5.3475 6.527859 | 5.370803 6.95089 | 5.479151 6.235046 | 5.529167 Avg Master | Avg This PR ---------------|---------------- 6.4539686 | 5.4784524 Speedup of **1.178064192** ## Memory Improvements **toPandas()** The most significant improvement is reduction of the upper bound space complexity in the JVM driver. Before, the entire dataset was collected in the JVM first before sending it to Python. With this change, as soon as a partition is collected, the result handler immediately sends it to Python, so the upper bound is the size of the largest partition. Also, using the Arrow stream format is more efficient because the schema is written once per stream, followed by record batches. The schema is now only send from driver JVM to Python. Before, multiple Arrow file formats were used that each contained the schema. This duplicated schema was created in the executors, sent to the driver JVM, and then Python where all but the first one received are discarded. I verified the upper bound limit by running a test that would collect data that would exceed the amount of driver JVM memory available. Using these settings on a standalone cluster: ``` spark.driver.memory 1g spark.executor.memory 5g spark.sql.execution.arrow.enabled true spark.sql.execution.arrow.fallback.enabled false spark.sql.execution.arrow.maxRecordsPerBatch 0 spark.driver.maxResultSize 2g ``` Test code: ```python from pyspark.sql.functions import rand df = spark.range(1 << 25, numPartitions=32).toDF("id").withColumn("x1", rand()).withColumn("x2", rand()).withColumn("x3", rand()) df.toPandas() ``` This makes total data size of 33554432×8×4 = 1073741824 With the current master, it fails with OOM but passes using this PR. **createDataFrame()** No significant change in memory except that using the stream format instead of separate file formats avoids duplicated the schema, similar to toPandas above. The process of reading the stream and parallelizing the batches does cause the record batch message metadata to be copied, but it's size is insignificant. Closes #21546 from BryanCutler/arrow-toPandas-stream-SPARK-23030. Authored-by: Bryan Cutler Signed-off-by: hyukjinkwon --- .../apache/spark/api/python/PythonRDD.scala | 24 +- python/pyspark/context.py | 11 +- python/pyspark/serializers.py | 30 ++- python/pyspark/sql/dataframe.py | 15 +- python/pyspark/sql/session.py | 12 +- .../scala/org/apache/spark/sql/Dataset.scala | 56 ++++- .../spark/sql/api/python/PythonSQLUtils.scala | 21 +- .../sql/execution/arrow/ArrowConverters.scala | 227 ++++++++++++------ .../arrow/ArrowConvertersSuite.scala | 93 ++++--- 9 files changed, 326 insertions(+), 163 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 197f4643e6134..e639a842754bd 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -399,6 +399,26 @@ private[spark] object PythonRDD extends Logging { * data collected from this job, and the secret for authentication. */ def serveIterator(items: Iterator[_], threadName: String): Array[Any] = { + serveToStream(threadName) { out => + writeIteratorToStream(items, new DataOutputStream(out)) + } + } + + /** + * Create a socket server and background thread to execute the writeFunc + * with the given OutputStream. + * + * The socket server can only accept one connection, or close if no connection + * in 15 seconds. + * + * Once a connection comes in, it will execute the block of code and pass in + * the socket output stream. + * + * The thread will terminate after the block of code is executed or any + * exceptions happen. + */ + private[spark] def serveToStream( + threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = { val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) // Close the socket if no connection in 15 seconds serverSocket.setSoTimeout(15000) @@ -410,9 +430,9 @@ private[spark] object PythonRDD extends Logging { val sock = serverSocket.accept() authHelper.authClient(sock) - val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + val out = new BufferedOutputStream(sock.getOutputStream) Utils.tryWithSafeFinally { - writeIteratorToStream(items, out) + writeFunc(out) } { out.close() sock.close() diff --git a/python/pyspark/context.py b/python/pyspark/context.py index b77fa0ee2892b..4cabae4b2f50b 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -494,10 +494,14 @@ def f(split, iterator): c = list(c) # Make it a list so we can compute its length batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024)) serializer = BatchedSerializer(self._unbatched_serializer, batchSize) - jrdd = self._serialize_to_jvm(c, numSlices, serializer) + + def reader_func(temp_filename): + return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices) + + jrdd = self._serialize_to_jvm(c, serializer, reader_func) return RDD(jrdd, self, serializer) - def _serialize_to_jvm(self, data, parallelism, serializer): + def _serialize_to_jvm(self, data, serializer, reader_func): """ Calling the Java parallelize() method with an ArrayList is too slow, because it sends O(n) Py4J commands. As an alternative, serialized @@ -507,8 +511,7 @@ def _serialize_to_jvm(self, data, parallelism, serializer): try: serializer.dump_stream(data, tempFile) tempFile.close() - readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile - return readRDDFromFile(self._jsc, tempFile.name, parallelism) + return reader_func(tempFile.name) finally: # readRDDFromFile eagerily reads the file so we can delete right after. os.unlink(tempFile.name) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 10385589c4d3b..48006778e86f2 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -185,27 +185,31 @@ def loads(self, obj): raise NotImplementedError -class ArrowSerializer(FramedSerializer): +class ArrowStreamSerializer(Serializer): """ - Serializes bytes as Arrow data with the Arrow file format. + Serializes Arrow record batches as a stream. """ - def dumps(self, batch): + def dump_stream(self, iterator, stream): import pyarrow as pa - import io - sink = io.BytesIO() - writer = pa.RecordBatchFileWriter(sink, batch.schema) - writer.write_batch(batch) - writer.close() - return sink.getvalue() + writer = None + try: + for batch in iterator: + if writer is None: + writer = pa.RecordBatchStreamWriter(stream, batch.schema) + writer.write_batch(batch) + finally: + if writer is not None: + writer.close() - def loads(self, obj): + def load_stream(self, stream): import pyarrow as pa - reader = pa.RecordBatchFileReader(pa.BufferReader(obj)) - return reader.read_all() + reader = pa.open_stream(stream) + for batch in reader: + yield batch def __repr__(self): - return "ArrowSerializer" + return "ArrowStreamSerializer" def _create_batch(series, timezone): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 07fb260a77ea0..1affc9b4fcf6c 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -29,7 +29,7 @@ from pyspark import copy_func, since, _NoValue from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \ +from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, PickleSerializer, \ UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync @@ -2118,10 +2118,9 @@ def toPandas(self): from pyspark.sql.types import _check_dataframe_convert_date, \ _check_dataframe_localize_timestamps import pyarrow - - tables = self._collectAsArrow() - if tables: - table = pyarrow.concat_tables(tables) + batches = self._collectAsArrow() + if len(batches) > 0: + table = pyarrow.Table.from_batches(batches) pdf = table.to_pandas() pdf = _check_dataframe_convert_date(pdf, self.schema) return _check_dataframe_localize_timestamps(pdf, timezone) @@ -2170,14 +2169,14 @@ def toPandas(self): def _collectAsArrow(self): """ - Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed - and available. + Returns all records as a list of ArrowRecordBatches, pyarrow must be installed + and available on driver and worker Python environments. .. note:: Experimental. """ with SCCallSiteSync(self._sc) as css: sock_info = self._jdf.collectAsArrowToPython() - return list(_load_from_socket(sock_info, ArrowSerializer())) + return list(_load_from_socket(sock_info, ArrowStreamSerializer())) ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 19eea2fd2c775..87d8d6a59a6e9 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -501,7 +501,7 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the data types will be used to coerce the data in Pandas to Arrow conversion. """ - from pyspark.serializers import ArrowSerializer, _create_batch + from pyspark.serializers import ArrowStreamSerializer, _create_batch from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType from pyspark.sql.utils import require_minimum_pandas_version, \ require_minimum_pyarrow_version @@ -539,10 +539,12 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): struct.names[i] = name schema = struct - # Create the Spark DataFrame directly from the Arrow data and schema - jrdd = self._sc._serialize_to_jvm(batches, len(batches), ArrowSerializer()) - jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame( - jrdd, schema.json(), self._wrapped._jsqlContext) + def reader_func(temp_filename): + return self._jvm.PythonSQLUtils.arrowReadStreamFromFile( + self._wrapped._jsqlContext, temp_filename, schema.json()) + + # Create Spark DataFrame from Arrow stream file, using one batch per partition + jdf = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func) df = DataFrame(jdf, self._wrapped) df._schema = schema return df diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 367b98563e0bf..db439b1ee76f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload} +import org.apache.spark.sql.execution.arrow.{ArrowBatchStreamWriter, ArrowConverters} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.python.EvaluatePython @@ -3273,13 +3273,49 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. */ private[sql] def collectAsArrowToPython(): Array[Any] = { + val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = - toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { out => + val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) + val arrowBatchRdd = toArrowBatchRdd(plan) + val numPartitions = arrowBatchRdd.partitions.length + + // Store collection results for worst case of 1 to N-1 partitions + val results = new Array[Array[Array[Byte]]](numPartitions - 1) + var lastIndex = -1 // index of last partition written + + // Handler to eagerly write partitions to Python in order + def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + // If result is from next partition in order + if (index - 1 == lastIndex) { + batchWriter.writeBatches(arrowBatches.iterator) + lastIndex += 1 + // Write stored partitions that come next in order + while (lastIndex < results.length && results(lastIndex) != null) { + batchWriter.writeBatches(results(lastIndex).iterator) + results(lastIndex) = null + lastIndex += 1 + } + // After last batch, end the stream + if (lastIndex == results.length) { + batchWriter.end() + } + } else { + // Store partitions received out of order + results(index - 1) = arrowBatches + } + } + + sparkSession.sparkContext.runJob( + arrowBatchRdd, + (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray, + 0 until numPartitions, + handlePartitionBatches) + } } } @@ -3386,20 +3422,20 @@ class Dataset[T] private[sql]( } } - /** Convert to an RDD of ArrowPayload byte arrays */ - private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = { + /** Convert to an RDD of serialized ArrowRecordBatches. */ + private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = { val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone plan.execute().mapPartitionsInternal { iter => val context = TaskContext.get() - ArrowConverters.toPayloadIterator( + ArrowConverters.toBatchIterator( iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context) } } // This is only used in tests, for now. - private[sql] def toArrowPayload: RDD[ArrowPayload] = { - toArrowPayload(queryExecution.executedPlan) + private[sql] def toArrowBatchRdd: RDD[Array[Byte]] = { + toArrowBatchRdd(queryExecution.executedPlan) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index b33760b1edbc6..c0830e77b5a87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.api.python -import org.apache.spark.api.java.JavaRDD import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.ExpressionInfo @@ -34,17 +33,19 @@ private[sql] object PythonSQLUtils { } /** - * Python Callable function to convert ArrowPayloads into a [[DataFrame]]. + * Python callable function to read a file in Arrow stream format and create a [[DataFrame]] + * using each serialized ArrowRecordBatch as a partition. * - * @param payloadRDD A JavaRDD of ArrowPayloads. - * @param schemaString JSON Formatted Schema for ArrowPayloads. * @param sqlContext The active [[SQLContext]]. - * @return The converted [[DataFrame]]. + * @param filename File to read the Arrow stream from. + * @param schemaString JSON Formatted Spark schema for Arrow batches. + * @return A new [[DataFrame]]. */ - def arrowPayloadToDataFrame( - payloadRDD: JavaRDD[Array[Byte]], - schemaString: String, - sqlContext: SQLContext): DataFrame = { - ArrowConverters.toDataFrame(payloadRDD, schemaString, sqlContext) + def arrowReadStreamFromFile( + sqlContext: SQLContext, + filename: String, + schemaString: String): DataFrame = { + val jrdd = ArrowConverters.readArrowStreamFromFile(sqlContext, filename) + ArrowConverters.toDataFrame(jrdd, schemaString, sqlContext) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 6a5ac2413d73c..1a48bc8398a63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -17,73 +17,75 @@ package org.apache.spark.sql.execution.arrow -import java.io.ByteArrayOutputStream -import java.nio.channels.Channels +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, OutputStream} +import java.nio.channels.{Channels, SeekableByteChannel} import scala.collection.JavaConverters._ +import org.apache.arrow.flatbuf.MessageHeader import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector._ -import org.apache.arrow.vector.ipc.{ArrowFileReader, ArrowFileWriter} -import org.apache.arrow.vector.ipc.message.ArrowRecordBatch -import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel +import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD +import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ByteBufferOutputStream, Utils} /** - * Store Arrow data in a form that can be serialized by Spark and served to a Python process. + * Writes serialized ArrowRecordBatches to a DataOutputStream in the Arrow stream format. */ -private[sql] class ArrowPayload private[sql] (payload: Array[Byte]) extends Serializable { +private[sql] class ArrowBatchStreamWriter( + schema: StructType, + out: OutputStream, + timeZoneId: String) { - /** - * Convert the ArrowPayload to an ArrowRecordBatch. - */ - def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = { - ArrowConverters.byteArrayToBatch(payload, allocator) - } + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val writeChannel = new WriteChannel(Channels.newChannel(out)) + + // Write the Arrow schema first, before batches + MessageSerializer.serialize(writeChannel, arrowSchema) /** - * Get the ArrowPayload as a type that can be served to Python. + * Consume iterator to write each serialized ArrowRecordBatch to the stream. */ - def asPythonSerializable: Array[Byte] = payload -} - -/** - * Iterator interface to iterate over Arrow record batches and return rows - */ -private[sql] trait ArrowRowIterator extends Iterator[InternalRow] { + def writeBatches(arrowBatchIter: Iterator[Array[Byte]]): Unit = { + arrowBatchIter.foreach(writeChannel.write) + } /** - * Return the schema loaded from the Arrow record batch being iterated over + * End the Arrow stream, does not close output stream. */ - def schema: StructType + def end(): Unit = { + ArrowStreamWriter.writeEndOfStream(writeChannel) + } } private[sql] object ArrowConverters { /** - * Maps Iterator from InternalRow to ArrowPayload. Limit ArrowRecordBatch size in ArrowPayload - * by setting maxRecordsPerBatch or use 0 to fully consume rowIter. + * Maps Iterator from InternalRow to serialized ArrowRecordBatches. Limit ArrowRecordBatch size + * in a batch by setting maxRecordsPerBatch or use 0 to fully consume rowIter. */ - private[sql] def toPayloadIterator( + private[sql] def toBatchIterator( rowIter: Iterator[InternalRow], schema: StructType, maxRecordsPerBatch: Int, timeZoneId: String, - context: TaskContext): Iterator[ArrowPayload] = { + context: TaskContext): Iterator[Array[Byte]] = { val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val allocator = - ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue) + ArrowUtils.rootAllocator.newChildAllocator("toBatchIterator", 0, Long.MaxValue) val root = VectorSchemaRoot.create(arrowSchema, allocator) + val unloader = new VectorUnloader(root) val arrowWriter = ArrowWriter.create(root) context.addTaskCompletionListener[Unit] { _ => @@ -91,7 +93,7 @@ private[sql] object ArrowConverters { allocator.close() } - new Iterator[ArrowPayload] { + new Iterator[Array[Byte]] { override def hasNext: Boolean = rowIter.hasNext || { root.close() @@ -99,9 +101,9 @@ private[sql] object ArrowConverters { false } - override def next(): ArrowPayload = { + override def next(): Array[Byte] = { val out = new ByteArrayOutputStream() - val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) + val writeChannel = new WriteChannel(Channels.newChannel(out)) Utils.tryWithSafeFinally { var rowCount = 0 @@ -111,45 +113,46 @@ private[sql] object ArrowConverters { rowCount += 1 } arrowWriter.finish() - writer.writeBatch() + val batch = unloader.getRecordBatch() + MessageSerializer.serialize(writeChannel, batch) + batch.close() } { arrowWriter.reset() - writer.close() } - new ArrowPayload(out.toByteArray) + out.toByteArray } } } /** - * Maps Iterator from ArrowPayload to InternalRow. Returns a pair containing the row iterator - * and the schema from the first batch of Arrow data read. + * Maps iterator from serialized ArrowRecordBatches to InternalRows. */ - private[sql] def fromPayloadIterator( - payloadIter: Iterator[ArrowPayload], - context: TaskContext): ArrowRowIterator = { + private[sql] def fromBatchIterator( + arrowBatchIter: Iterator[Array[Byte]], + schema: StructType, + timeZoneId: String, + context: TaskContext): Iterator[InternalRow] = { val allocator = - ArrowUtils.rootAllocator.newChildAllocator("fromPayloadIterator", 0, Long.MaxValue) + ArrowUtils.rootAllocator.newChildAllocator("fromBatchIterator", 0, Long.MaxValue) + + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val root = VectorSchemaRoot.create(arrowSchema, allocator) - new ArrowRowIterator { - private var reader: ArrowFileReader = null - private var schemaRead = StructType(Seq.empty) - private var rowIter = if (payloadIter.hasNext) nextBatch() else Iterator.empty + new Iterator[InternalRow] { + private var rowIter = if (arrowBatchIter.hasNext) nextBatch() else Iterator.empty context.addTaskCompletionListener[Unit] { _ => - closeReader() + root.close() allocator.close() } - override def schema: StructType = schemaRead - override def hasNext: Boolean = rowIter.hasNext || { - closeReader() - if (payloadIter.hasNext) { + if (arrowBatchIter.hasNext) { rowIter = nextBatch() true } else { + root.close() allocator.close() false } @@ -157,19 +160,11 @@ private[sql] object ArrowConverters { override def next(): InternalRow = rowIter.next() - private def closeReader(): Unit = { - if (reader != null) { - reader.close() - reader = null - } - } - private def nextBatch(): Iterator[InternalRow] = { - val in = new ByteArrayReadableSeekableByteChannel(payloadIter.next().asPythonSerializable) - reader = new ArrowFileReader(in, allocator) - reader.loadNextBatch() // throws IOException - val root = reader.getVectorSchemaRoot // throws IOException - schemaRead = ArrowUtils.fromArrowSchema(root.getSchema) + val arrowRecordBatch = ArrowConverters.loadBatch(arrowBatchIter.next(), allocator) + val vectorLoader = new VectorLoader(root) + vectorLoader.load(arrowRecordBatch) + arrowRecordBatch.close() val columns = root.getFieldVectors.asScala.map { vector => new ArrowColumnVector(vector).asInstanceOf[ColumnVector] @@ -183,34 +178,106 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { - val in = new ByteArrayReadableSeekableByteChannel(batchBytes) - val reader = new ArrowFileReader(in, allocator) - - // Read a batch from a byte stream, ensure the reader is closed - Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch - } { - reader.close() - } + val in = new ByteArrayInputStream(batchBytes) + MessageSerializer.deserializeRecordBatch( + new ReadChannel(Channels.newChannel(in)), allocator) // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. + */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { - val rdd = payloadRDD.rdd.mapPartitions { iter => + val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone + val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } - val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd.setName("arrow"), schema) } + + /** + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { + Utils.tryWithResource(new FileInputStream(filename)) { fileStream => + // Create array to consume iterator so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) + } + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. + */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + + // Iterate over the serialized Arrow RecordBatch messages from a stream + new Iterator[Array[Byte]] { + var batch: Array[Byte] = readNextBatch() + + override def hasNext: Boolean = batch != null + + override def next(): Array[Byte] = { + val prevBatch = batch + batch = readNextBatch() + prevBatch + } + + // This gets the next serialized ArrowRecordBatch by reading message metadata to check if it + // is a RecordBatch message and then returning the complete serialized message which consists + // of a int32 length, serialized message metadata and a serialized RecordBatch message body + def readNextBatch(): Array[Byte] = { + val msgMetadata = MessageSerializer.readMessage(new ReadChannel(in)) + if (msgMetadata == null) { + return null + } + + // Get the length of the body, which has not been read at this point + val bodyLength = msgMetadata.getMessageBodyLength.toInt + + // Only care about RecordBatch messages, skip Schema and unsupported Dictionary messages + if (msgMetadata.getMessage.headerType() == MessageHeader.RecordBatch) { + + // Buffer backed output large enough to hold the complete serialized message + val bbout = new ByteBufferOutputStream(4 + msgMetadata.getMessageLength + bodyLength) + + // Write message metadata to ByteBuffer output stream + MessageSerializer.writeMessageBuffer( + new WriteChannel(Channels.newChannel(bbout)), + msgMetadata.getMessageLength, + msgMetadata.getMessageBuffer) + + // Get a zero-copy ByteBuffer with already contains message metadata, must close first + bbout.close() + val bb = bbout.toByteBuffer + bb.position(bbout.getCount()) + + // Read message body directly into the ByteBuffer to avoid copy, return backed byte array + bb.limit(bb.capacity()) + JavaUtils.readFully(in, bb) + bb.array() + } else { + if (bodyLength > 0) { + // Skip message body if not a RecordBatch + in.position(in.position() + bodyLength) + } + + // Proceed to next message + readNextBatch() + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 261df06100aef..c36872a6a5289 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.execution.arrow -import java.io.File +import java.io.{ByteArrayOutputStream, DataOutputStream, File} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat @@ -26,7 +26,7 @@ import com.google.common.io.Files import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} import org.apache.arrow.vector.ipc.JsonFileReader -import org.apache.arrow.vector.util.Validator +import org.apache.arrow.vector.util.{ByteArrayReadableSeekableByteChannel, Validator} import org.scalatest.BeforeAndAfterAll import org.apache.spark.{SparkException, TaskContext} @@ -51,11 +51,11 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { test("collect to arrow record batch") { val indexData = (1 to 6).toDF("i") - val arrowPayloads = indexData.toArrowPayload.collect() - assert(arrowPayloads.nonEmpty) - assert(arrowPayloads.length == indexData.rdd.getNumPartitions) + val arrowBatches = indexData.toArrowBatchRdd.collect() + assert(arrowBatches.nonEmpty) + assert(arrowBatches.length == indexData.rdd.getNumPartitions) val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator)) val rowCount = arrowRecordBatches.map(_.getLength).sum assert(rowCount === indexData.count()) arrowRecordBatches.foreach(batch => assert(batch.getNodes.size() > 0)) @@ -1153,9 +1153,9 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { |} """.stripMargin - val arrowPayloads = testData2.toArrowPayload.collect() - // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload - assert(arrowPayloads.length === 2) + val arrowBatches = testData2.toArrowBatchRdd.collect() + // NOTE: testData2 should have 2 partitions -> 2 arrow batches + assert(arrowBatches.length === 2) val schema = testData2.schema val tempFile1 = new File(tempDataPath, "testData2-ints-part1.json") @@ -1163,25 +1163,25 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { Files.write(json1, tempFile1, StandardCharsets.UTF_8) Files.write(json2, tempFile2, StandardCharsets.UTF_8) - validateConversion(schema, arrowPayloads(0), tempFile1) - validateConversion(schema, arrowPayloads(1), tempFile2) + validateConversion(schema, arrowBatches(0), tempFile1) + validateConversion(schema, arrowBatches(1), tempFile2) } test("empty frame collect") { - val arrowPayload = spark.emptyDataFrame.toArrowPayload.collect() - assert(arrowPayload.isEmpty) + val arrowBatches = spark.emptyDataFrame.toArrowBatchRdd.collect() + assert(arrowBatches.isEmpty) val filteredDF = List[Int](1, 2, 3, 4, 5, 6).toDF("i") - val filteredArrowPayload = filteredDF.filter("i < 0").toArrowPayload.collect() - assert(filteredArrowPayload.isEmpty) + val filteredArrowBatches = filteredDF.filter("i < 0").toArrowBatchRdd.collect() + assert(filteredArrowBatches.isEmpty) } test("empty partition collect") { val emptyPart = spark.sparkContext.parallelize(Seq(1), 2).toDF("i") - val arrowPayloads = emptyPart.toArrowPayload.collect() - assert(arrowPayloads.length === 1) + val arrowBatches = emptyPart.toArrowBatchRdd.collect() + assert(arrowBatches.length === 1) val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator)) assert(arrowRecordBatches.head.getLength == 1) arrowRecordBatches.foreach(_.close()) allocator.close() @@ -1192,10 +1192,10 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val maxRecordsPerBatch = 3 spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch) val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i") - val arrowPayloads = df.toArrowPayload.collect() - assert(arrowPayloads.length >= 4) + val arrowBatches = df.toArrowBatchRdd.collect() + assert(arrowBatches.length >= 4) val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator)) var recordCount = 0 arrowRecordBatches.foreach { batch => assert(batch.getLength > 0) @@ -1217,8 +1217,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) } - runUnsupported { mapData.toDF().toArrowPayload.collect() } - runUnsupported { complexData.toArrowPayload.collect() } + runUnsupported { mapData.toDF().toArrowBatchRdd.collect() } + runUnsupported { complexData.toArrowBatchRdd.collect() } } test("test Arrow Validator") { @@ -1318,7 +1318,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } } - test("roundtrip payloads") { + test("roundtrip arrow batches") { val inputRows = (0 until 9).map { i => InternalRow(i) } :+ InternalRow(null) @@ -1326,10 +1326,41 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) val ctx = TaskContext.empty() - val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, null, ctx) - val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx) + val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx) + val outputRowIter = ArrowConverters.fromBatchIterator(batchIter, schema, null, ctx) - assert(schema == outputRowIter.schema) + var count = 0 + outputRowIter.zipWithIndex.foreach { case (row, i) => + if (i != 9) { + assert(row.getInt(0) == i) + } else { + assert(row.isNullAt(0)) + } + count += 1 + } + + assert(count == inputRows.length) + } + + test("ArrowBatchStreamWriter roundtrip") { + val inputRows = (0 until 9).map(InternalRow(_)) :+ InternalRow(null) + + val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) + val ctx = TaskContext.empty() + val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx) + + // Write batches to Arrow stream format as a byte array + val out = new ByteArrayOutputStream() + Utils.tryWithResource(new DataOutputStream(out)) { dataOut => + val writer = new ArrowBatchStreamWriter(schema, dataOut, null) + writer.writeBatches(batchIter) + writer.end() + } + + // Read Arrow stream into batches, then convert back to rows + val in = new ByteArrayReadableSeekableByteChannel(out.toByteArray) + val readBatches = ArrowConverters.getBatchesFromStream(in) + val outputRowIter = ArrowConverters.fromBatchIterator(readBatches, schema, null, ctx) var count = 0 outputRowIter.zipWithIndex.foreach { case (row, i) => @@ -1348,15 +1379,15 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { private def collectAndValidate( df: DataFrame, json: String, file: String, timeZoneId: String = null): Unit = { // NOTE: coalesce to single partition because can only load 1 batch in validator - val arrowPayload = df.coalesce(1).toArrowPayload.collect().head + val batchBytes = df.coalesce(1).toArrowBatchRdd.collect().head val tempFile = new File(tempDataPath, file) Files.write(json, tempFile, StandardCharsets.UTF_8) - validateConversion(df.schema, arrowPayload, tempFile, timeZoneId) + validateConversion(df.schema, batchBytes, tempFile, timeZoneId) } private def validateConversion( sparkSchema: StructType, - arrowPayload: ArrowPayload, + batchBytes: Array[Byte], jsonFile: File, timeZoneId: String = null): Unit = { val allocator = new RootAllocator(Long.MaxValue) @@ -1368,7 +1399,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val arrowRoot = VectorSchemaRoot.create(arrowSchema, allocator) val vectorLoader = new VectorLoader(arrowRoot) - val arrowRecordBatch = arrowPayload.loadBatch(allocator) + val arrowRecordBatch = ArrowConverters.loadBatch(batchBytes, allocator) vectorLoader.load(arrowRecordBatch) val jsonRoot = jsonReader.read() Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) From 1fd59c129a7aa16f9960b109128b166952992f32 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 29 Aug 2018 15:23:16 +0800 Subject: [PATCH 1504/2461] [WIP][SPARK-25044][SQL] (take 2) Address translation of LMF closure primitive args to Object in Scala 2.12 ## What changes were proposed in this pull request? Alternative take on https://github.com/apache/spark/pull/22063 that does not introduce udfInternal. Resolve issue with inferring func types in 2.12 by instead using info captured when UDF is registered -- capturing which types are nullable (i.e. not primitive) ## How was this patch tested? Existing tests. Closes #22259 from srowen/SPARK-25044.2. Authored-by: Sean Owen Signed-off-by: Wenchen Fan --- project/MimaExcludes.scala | 5 + .../spark/sql/catalyst/ScalaReflection.scala | 9 -- .../sql/catalyst/analysis/Analyzer.scala | 50 +++---- .../sql/catalyst/expressions/ScalaUDF.scala | 7 +- .../sql/catalyst/ScalaReflectionSuite.scala | 17 --- .../sql/catalyst/analysis/AnalysisSuite.scala | 9 +- .../apache/spark/sql/UDFRegistration.scala | 122 +++++++++++------- .../sql/expressions/UserDefinedFunction.scala | 8 +- .../org/apache/spark/sql/functions.scala | 22 ++-- 9 files changed, 133 insertions(+), 116 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index cdc99a48e5b64..4f250c9943edb 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,11 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-25044] Address translation of LMF closure primitive args to Object in Scala 2.12 + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.expressions.UserDefinedFunction$"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy"), + // [SPARK-24296][CORE] Replicate large blocks as a stream. ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.this"), // [SPARK-23528] Add numIter to ClusteringSummary diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 191c3de965b34..0238d57de2446 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -932,15 +932,6 @@ trait ScalaReflection { tpe.dealias.erasure.typeSymbol.asClass.fullName } - /** - * Returns classes of input parameters of scala function object. - */ - def getParameterTypes(func: AnyRef): Seq[Class[_]] = { - val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge) - assert(methods.length == 1) - methods.head.getParameterTypes - } - /** * Returns the parameter names and types for the primary constructor of this type. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d00b82d35d7d4..580133dd971b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2149,28 +2149,34 @@ class Analyzer( case p => p transformExpressionsUp { - case udf @ ScalaUDF(func, _, inputs, _, _, _, _) => - val parameterTypes = ScalaReflection.getParameterTypes(func) - assert(parameterTypes.length == inputs.length) - - // TODO: skip null handling for not-nullable primitive inputs after we can completely - // trust the `nullable` information. - // (cls, expr) => cls.isPrimitive && expr.nullable - val needsNullCheck = (cls: Class[_], expr: Expression) => - cls.isPrimitive && !expr.isInstanceOf[KnownNotNull] - val inputsNullCheck = parameterTypes.zip(inputs) - .filter { case (cls, expr) => needsNullCheck(cls, expr) } - .map { case (_, expr) => IsNull(expr) } - .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2)) - // Once we add an `If` check above the udf, it is safe to mark those checked inputs - // as not nullable (i.e., wrap them with `KnownNotNull`), because the null-returning - // branch of `If` will be called if any of these checked inputs is null. Thus we can - // prevent this rule from being applied repeatedly. - val newInputs = parameterTypes.zip(inputs).map{ case (cls, expr) => - if (needsNullCheck(cls, expr)) KnownNotNull(expr) else expr } - inputsNullCheck - .map(If(_, Literal.create(null, udf.dataType), udf.copy(children = newInputs))) - .getOrElse(udf) + case udf@ScalaUDF(func, _, inputs, _, _, _, _, nullableTypes) => + if (nullableTypes.isEmpty) { + // If no nullability info is available, do nothing. No fields will be specially + // checked for null in the plan. If nullability info is incorrect, the results + // of the UDF could be wrong. + udf + } else { + // Otherwise, add special handling of null for fields that can't accept null. + // The result of operations like this, when passed null, is generally to return null. + assert(nullableTypes.length == inputs.length) + + // TODO: skip null handling for not-nullable primitive inputs after we can completely + // trust the `nullable` information. + val inputsNullCheck = nullableTypes.zip(inputs) + .filter { case (nullable, _) => !nullable } + .map { case (_, expr) => IsNull(expr) } + .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2)) + // Once we add an `If` check above the udf, it is safe to mark those checked inputs + // as not nullable (i.e., wrap them with `KnownNotNull`), because the null-returning + // branch of `If` will be called if any of these checked inputs is null. Thus we can + // prevent this rule from being applied repeatedly. + val newInputs = nullableTypes.zip(inputs).map { case (nullable, expr) => + if (nullable) expr else KnownNotNull(expr) + } + inputsNullCheck + .map(If(_, Literal.create(null, udf.dataType), udf.copy(children = newInputs))) + .getOrElse(udf) + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 4b09978e75081..8954fe8a58e6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.types.DataType * @param nullable True if the UDF can return null value. * @param udfDeterministic True if the UDF is deterministic. Deterministic UDF returns same result * each time it is invoked with a particular input. + * @param nullableTypes which of the inputTypes are nullable (i.e. not primitive) */ case class ScalaUDF( function: AnyRef, @@ -47,7 +48,8 @@ case class ScalaUDF( inputTypes: Seq[DataType] = Nil, udfName: Option[String] = None, nullable: Boolean = true, - udfDeterministic: Boolean = true) + udfDeterministic: Boolean = true, + nullableTypes: Seq[Boolean] = Nil) extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression { // The constructor for SPARK 2.1 and 2.2 @@ -58,7 +60,8 @@ case class ScalaUDF( inputTypes: Seq[DataType], udfName: Option[String]) = { this( - function, dataType, children, inputTypes, udfName, nullable = true, udfDeterministic = true) + function, dataType, children, inputTypes, udfName, nullable = true, + udfDeterministic = true, nullableTypes = Nil) } override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 353b8344658f2..f9ee948b97e0a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -261,23 +261,6 @@ class ScalaReflectionSuite extends SparkFunSuite { } } - test("get parameter type from a function object") { - val primitiveFunc = (i: Int, j: Long) => "x" - val primitiveTypes = getParameterTypes(primitiveFunc) - assert(primitiveTypes.forall(_.isPrimitive)) - assert(primitiveTypes === Seq(classOf[Int], classOf[Long])) - - val boxedFunc = (i: java.lang.Integer, j: java.lang.Long) => "x" - val boxedTypes = getParameterTypes(boxedFunc) - assert(boxedTypes.forall(!_.isPrimitive)) - assert(boxedTypes === Seq(classOf[java.lang.Integer], classOf[java.lang.Long])) - - val anyFunc = (i: Any, j: AnyRef) => "x" - val anyTypes = getParameterTypes(anyFunc) - assert(anyTypes.forall(!_.isPrimitive)) - assert(anyTypes === Seq(classOf[java.lang.Object], classOf[java.lang.Object])) - } - test("SPARK-15062: Get correct serializer for List[_]") { val list = List(1, 2, 3) val serializer = serializerFor[List[Int]](BoundReference( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 94f37f1aafa78..3b3edac0a314e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -317,13 +317,15 @@ class AnalysisSuite extends AnalysisTest with Matchers { checkUDF(udf1, expected1) // only primitive parameter needs special null handling - val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil) + val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil, + nullableTypes = true :: false :: Nil) val expected2 = If(IsNull(double), nullResult, udf2.copy(children = string :: KnownNotNull(double) :: Nil)) checkUDF(udf2, expected2) // special null handling should apply to all primitive parameters - val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil) + val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil, + nullableTypes = false :: false :: Nil) val expected3 = If( IsNull(short) || IsNull(double), nullResult, @@ -335,7 +337,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { val udf4 = ScalaUDF( (s: Short, d: Double) => "x", StringType, - short :: double.withNullability(false) :: Nil) + short :: double.withNullability(false) :: Nil, + nullableTypes = false :: false :: Nil) val expected4 = If( IsNull(short), nullResult, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index f94baef39dfad..24ee46d0e8147 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -113,7 +113,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends (0 to 22).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) - val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor[A$i].dataType :: $s"}) + val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor[A$i] :: $s"}) println(s""" |/** | * Registers a deterministic Scala closure of $x arguments as user-defined function (UDF). @@ -122,9 +122,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends | */ |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - | val inputTypes = Try($inputTypes).toOption + | val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try($inputTypes).toOption | def builder(e: Seq[Expression]) = if (e.length == $x) { - | ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + | ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + | udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) | } else { | throw new AnalysisException("Invalid number of arguments for function " + name + | ". Expected: $x; Found: " + e.length) @@ -167,9 +168,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 0) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 0; Found: " + e.length) @@ -186,9 +188,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 1) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 1; Found: " + e.length) @@ -205,9 +208,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 2) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 2; Found: " + e.length) @@ -224,9 +228,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 3) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 3; Found: " + e.length) @@ -243,9 +248,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 4) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 4; Found: " + e.length) @@ -262,9 +268,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 5) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 5; Found: " + e.length) @@ -281,9 +288,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 6) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 6; Found: " + e.length) @@ -300,9 +308,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 7) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 7; Found: " + e.length) @@ -319,9 +328,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 8) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 8; Found: " + e.length) @@ -338,9 +348,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 9) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 9; Found: " + e.length) @@ -357,9 +368,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 10) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 10; Found: " + e.length) @@ -376,9 +388,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 11) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 11; Found: " + e.length) @@ -395,9 +408,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 12) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 12; Found: " + e.length) @@ -414,9 +428,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 13) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 13; Found: " + e.length) @@ -433,9 +448,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 14) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 14; Found: " + e.length) @@ -452,9 +468,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 15) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 15; Found: " + e.length) @@ -471,9 +488,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 16) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 16; Found: " + e.length) @@ -490,9 +508,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 17) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 17; Found: " + e.length) @@ -509,9 +528,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 18) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 18; Found: " + e.length) @@ -528,9 +548,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 19) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 19; Found: " + e.length) @@ -547,9 +568,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: ScalaReflection.schemaFor[A20] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 20) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 20; Found: " + e.length) @@ -566,9 +588,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: ScalaReflection.schemaFor[A20] :: ScalaReflection.schemaFor[A21] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 21) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 21; Found: " + e.length) @@ -585,9 +608,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).toOption + val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: ScalaReflection.schemaFor[A20] :: ScalaReflection.schemaFor[A21] :: ScalaReflection.schemaFor[A22] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 22) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 22; Found: " + e.length) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index bdc4bb4422ae7..7bd20dbe8f6d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.ScalaUDF import org.apache.spark.sql.types.DataType @@ -40,7 +41,7 @@ import org.apache.spark.sql.types.DataType case class UserDefinedFunction protected[sql] ( f: AnyRef, dataType: DataType, - inputTypes: Option[Seq[DataType]]) { + inputTypes: Option[Seq[ScalaReflection.Schema]]) { private var _nameOption: Option[String] = None private var _nullable: Boolean = true @@ -72,10 +73,11 @@ case class UserDefinedFunction protected[sql] ( f, dataType, exprs.map(_.expr), - inputTypes.getOrElse(Nil), + inputTypes.map(_.map(_.dataType)).getOrElse(Nil), udfName = _nameOption, nullable = _nullable, - udfDeterministic = _deterministic)) + udfDeterministic = _deterministic, + nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil))) } private def copyAll(): UserDefinedFunction = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 1d806e056d31c..a261a7c1752d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3819,7 +3819,7 @@ object functions { (0 to 10).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) - val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]).dataType :: $s"}) + val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]) :: $s"}) println(s""" |/** | * Defines a Scala closure of $x arguments as user-defined function (UDF). @@ -3893,7 +3893,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: Nil).toOption + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: Nil).toOption val udf = UserDefinedFunction(f, dataType, inputTypes) if (nullable) udf else udf.asNonNullable() } @@ -3909,7 +3909,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: Nil).toOption + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: Nil).toOption val udf = UserDefinedFunction(f, dataType, inputTypes) if (nullable) udf else udf.asNonNullable() } @@ -3925,7 +3925,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: Nil).toOption + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: Nil).toOption val udf = UserDefinedFunction(f, dataType, inputTypes) if (nullable) udf else udf.asNonNullable() } @@ -3941,7 +3941,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: Nil).toOption + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: Nil).toOption val udf = UserDefinedFunction(f, dataType, inputTypes) if (nullable) udf else udf.asNonNullable() } @@ -3957,7 +3957,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: Nil).toOption + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: Nil).toOption val udf = UserDefinedFunction(f, dataType, inputTypes) if (nullable) udf else udf.asNonNullable() } @@ -3973,7 +3973,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: Nil).toOption + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: Nil).toOption val udf = UserDefinedFunction(f, dataType, inputTypes) if (nullable) udf else udf.asNonNullable() } @@ -3989,7 +3989,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: Nil).toOption + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: Nil).toOption val udf = UserDefinedFunction(f, dataType, inputTypes) if (nullable) udf else udf.asNonNullable() } @@ -4005,7 +4005,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: Nil).toOption + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: Nil).toOption val udf = UserDefinedFunction(f, dataType, inputTypes) if (nullable) udf else udf.asNonNullable() } @@ -4021,7 +4021,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: Nil).toOption + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: ScalaReflection.schemaFor(typeTag[A9]) :: Nil).toOption val udf = UserDefinedFunction(f, dataType, inputTypes) if (nullable) udf else udf.asNonNullable() } @@ -4037,7 +4037,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: ScalaReflection.schemaFor(typeTag[A10]).dataType :: Nil).toOption + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: ScalaReflection.schemaFor(typeTag[A9]) :: ScalaReflection.schemaFor(typeTag[A10]) :: Nil).toOption val udf = UserDefinedFunction(f, dataType, inputTypes) if (nullable) udf else udf.asNonNullable() } From 3864480e14a4961720cc1be43635c7c7dec08c09 Mon Sep 17 00:00:00 2001 From: sarutak Date: Wed, 29 Aug 2018 07:13:13 -0700 Subject: [PATCH 1505/2461] [SPARK-25266][CORE] Fix memory leak in Barrier Execution Mode ## What changes were proposed in this pull request? BarrierCoordinator uses Timer and TimerTask. `TimerTask#cancel()` is invoked in ContextBarrierState#cancelTimerTask but `Timer#purge()` is never invoked. Once a TimerTask is scheduled, the reference to it is not released until `Timer#purge()` is invoked even though `TimerTask#cancel()` is invoked. ## How was this patch tested? I checked the number of instances related to the TimerTask using jmap. Closes #22258 from sarutak/fix-barrierexec-oom. Authored-by: sarutak Signed-off-by: Xiangrui Meng --- core/src/main/scala/org/apache/spark/BarrierCoordinator.scala | 1 + core/src/main/scala/org/apache/spark/BarrierTaskContext.scala | 1 + 2 files changed, 2 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 5e546c694e8d9..6439ca5db06e9 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -123,6 +123,7 @@ private[spark] class BarrierCoordinator( private def cancelTimerTask(): Unit = { if (timerTask != null) { timerTask.cancel() + timer.purge() timerTask = null } } diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index de827987f28f9..3901f96326f75 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -140,6 +140,7 @@ class BarrierTaskContext( throw e } finally { timerTask.cancel() + timer.purge() } } From 20b7c684cc4a8136b9a9c56390a4948de04e7c34 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 29 Aug 2018 07:22:03 -0700 Subject: [PATCH 1506/2461] [SPARK-25248][.1][PYSPARK] update barrier Python API ## What changes were proposed in this pull request? I made one pass over the Python APIs for barrier mode and updated them to match the Scala doc in #22240 . Major changes: * export the public classes * expand the docs * add doc for BarrierTaskInfo.addresss cc: jiangxb1987 Closes #22261 from mengxr/SPARK-25248.1. Authored-by: Xiangrui Meng Signed-off-by: Xiangrui Meng --- python/pyspark/__init__.py | 12 +++++++++--- python/pyspark/rdd.py | 22 ++++++++++++++++++---- python/pyspark/taskcontext.py | 26 +++++++++++++++++--------- 3 files changed, 44 insertions(+), 16 deletions(-) diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 58218918693ca..ee153af18c88c 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -36,7 +36,12 @@ Finer-grained cache persistence levels. - :class:`TaskContext`: Information about the current running task, available on the workers and experimental. - + - :class:`RDDBarrier`: + Wraps an RDD under a barrier stage for barrier execution. + - :class:`BarrierTaskContext`: + A :class:`TaskContext` that provides extra info and tooling for barrier execution. + - :class:`BarrierTaskInfo`: + Information about a barrier task. """ from functools import wraps @@ -44,14 +49,14 @@ from pyspark.conf import SparkConf from pyspark.context import SparkContext -from pyspark.rdd import RDD +from pyspark.rdd import RDD, RDDBarrier from pyspark.files import SparkFiles from pyspark.storagelevel import StorageLevel from pyspark.accumulators import Accumulator, AccumulatorParam from pyspark.broadcast import Broadcast from pyspark.serializers import MarshalSerializer, PickleSerializer from pyspark.status import * -from pyspark.taskcontext import TaskContext +from pyspark.taskcontext import TaskContext, BarrierTaskContext, BarrierTaskInfo from pyspark.profiler import Profiler, BasicProfiler from pyspark.version import __version__ from pyspark._globals import _NoValue @@ -113,4 +118,5 @@ def wrapper(self, *args, **kwargs): "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast", "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer", "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler", "TaskContext", + "RDDBarrier", "BarrierTaskContext", "BarrierTaskInfo", ] diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 380475e706fbe..b317156885e51 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2390,7 +2390,18 @@ def barrier(self): """ .. note:: Experimental - Indicates that Spark must launch the tasks together for the current stage. + Marks the current stage as a barrier stage, where Spark must launch all tasks together. + In case of a task failure, instead of only restarting the failed task, Spark will abort the + entire stage and relaunch all tasks for this stage. + The barrier execution mode feature is experimental and it only handles limited scenarios. + Please read the linked SPIP and design docs to understand the limitations and future plans. + + :return: an :class:`RDDBarrier` instance that provides actions within a barrier stage. + + .. seealso:: :class:`BarrierTaskContext` + .. seealso:: `SPIP: Barrier Execution Mode \ + `_ + .. seealso:: `Design Doc `_ .. versionadded:: 2.4.0 """ @@ -2430,8 +2441,8 @@ class RDDBarrier(object): """ .. note:: Experimental - An RDDBarrier turns an RDD into a barrier RDD, which forces Spark to launch tasks of the stage - contains this RDD together. + Wraps an RDD in a barrier stage, which forces Spark to launch tasks of this stage together. + :class:`RDDBarrier` instances are created by :func:`RDD.barrier`. .. versionadded:: 2.4.0 """ @@ -2443,7 +2454,10 @@ def mapPartitions(self, f, preservesPartitioning=False): """ .. note:: Experimental - Return a new RDD by applying a function to each partition of this RDD. + Returns a new RDD by applying a function to each partition of the wrapped RDD, + where tasks are launched together in a barrier stage. + The interface is the same as :func:`RDD.mapPartitions`. + Please see the API doc there. .. versionadded:: 2.4.0 """ diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index 53fc2b29e066f..b61643eb0a16e 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -131,9 +131,8 @@ class BarrierTaskContext(TaskContext): """ .. note:: Experimental - A TaskContext with extra info and tooling for a barrier stage. To access the BarrierTaskContext - for a running task, use: - L{BarrierTaskContext.get()}. + A :class:`TaskContext` with extra contextual info and tooling for tasks in a barrier stage. + Use :func:`BarrierTaskContext.get` to obtain the barrier context for a running barrier task. .. versionadded:: 2.4.0 """ @@ -155,8 +154,11 @@ def _getOrCreate(cls): @classmethod def get(cls): """ - Return the currently active BarrierTaskContext. This can be called inside of user functions - to access contextual information about running tasks. + .. note:: Experimental + + Return the currently active :class:`BarrierTaskContext`. + This can be called inside of user functions to access contextual information about + running tasks. .. note:: Must be called on the worker, not the driver. Returns None if not initialized. """ @@ -176,7 +178,12 @@ def barrier(self): .. note:: Experimental Sets a global barrier and waits until all tasks in this stage hit this barrier. - Note this method is only allowed for a BarrierTaskContext. + Similar to `MPI_Barrier` function in MPI, this function blocks until all tasks + in the same stage have reached this routine. + + .. warning:: In a barrier stage, each task much have the same number of `barrier()` + calls, in all possible code branches. + Otherwise, you may get the job hanging or a SparkException after timeout. .. versionadded:: 2.4.0 """ @@ -190,9 +197,8 @@ def getTaskInfos(self): """ .. note:: Experimental - Returns the all task infos in this barrier stage, the task infos are ordered by - partitionId. - Note this method is only allowed for a BarrierTaskContext. + Returns :class:`BarrierTaskInfo` for all tasks in this barrier stage, + ordered by partition ID. .. versionadded:: 2.4.0 """ @@ -210,6 +216,8 @@ class BarrierTaskInfo(object): Carries all task infos of a barrier task. + :var address: The IPv4 address (host:port) of the executor that the barrier task is running on + .. versionadded:: 2.4.0 """ From 6b1b10ca85f17fc8cd8acc9d6705bd14115ba6b4 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Wed, 29 Aug 2018 12:55:44 -0700 Subject: [PATCH 1507/2461] [DOC] Fix comment on SparkPlanGraphEdge ## What changes were proposed in this pull request? `fromId` is the child, and `toId` is the parent, see line 127 in `buildSparkPlanGraphNode` above. The edges in Spark UI also go from child to parent. ## How was this patch tested? Comment change only. Inspected code above. Inspected how the edges in Spark UI look like. Closes #22268 from juliuszsompolski/sparkplangraphedgedoc. Authored-by: Juliusz Sompolski Signed-off-by: Xiao Li --- .../org/apache/spark/sql/execution/ui/SparkPlanGraph.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 884f945815e0f..e57d080dadf78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -202,7 +202,7 @@ private[ui] class SparkPlanGraphCluster( /** - * Represent an edge in the SparkPlan tree. `fromId` is the parent node id, and `toId` is the child + * Represent an edge in the SparkPlan tree. `fromId` is the child node id, and `toId` is the parent * node id. */ private[ui] case class SparkPlanGraphEdge(fromId: Long, toId: Long) { From ec3e9986385880adce1648eae30007eccff862ba Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Wed, 29 Aug 2018 16:32:02 -0700 Subject: [PATCH 1508/2461] [SPARK-24909][CORE] Always unregister pending partition on task completion. Spark scheduler can hang when fetch failures, executor lost, task running on lost executor, and multiple stage attempts. To fix this we change to always unregister the pending partition on task completion. ## What changes were proposed in this pull request? this PR is actually reverting the change in SPARK-19263, so that it always does shuffleStage.pendingPartitions -= task.partitionId. The change in SPARK-23433, should fix the issue originally from SPARK-19263. ## How was this patch tested? Unit tests. The condition happens on a race which I haven't reproduced on a real customer, just see it sometimes on customers jobs in a real cluster. I am also working on adding spark scheduler integration tests. Closes #21976 from tgravescs/SPARK-24909. Authored-by: Thomas Graves Signed-off-by: Marcelo Vanzin --- .../apache/spark/scheduler/DAGScheduler.scala | 17 +------------- .../spark/scheduler/DAGSchedulerSuite.scala | 22 ++++++++++++------- 2 files changed, 15 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 6787250ddc3f4..fec6558f412d0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1373,18 +1373,10 @@ private[spark] class DAGScheduler( case smt: ShuffleMapTask => val shuffleStage = stage.asInstanceOf[ShuffleMapStage] + shuffleStage.pendingPartitions -= task.partitionId val status = event.result.asInstanceOf[MapStatus] val execId = status.location.executorId logDebug("ShuffleMapTask finished on " + execId) - if (stageIdToStage(task.stageId).latestInfo.attemptNumber == task.stageAttemptId) { - // This task was for the currently running attempt of the stage. Since the task - // completed successfully from the perspective of the TaskSetManager, mark it as - // no longer pending (the TaskSetManager may consider the task complete even - // when the output needs to be ignored because the task's epoch is too small below. - // In this case, when pending partitions is empty, there will still be missing - // output locations, which will cause the DAGScheduler to resubmit the stage below.) - shuffleStage.pendingPartitions -= task.partitionId - } if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) { logInfo(s"Ignoring possibly bogus $smt completion from executor $execId") } else { @@ -1393,13 +1385,6 @@ private[spark] class DAGScheduler( // available. mapOutputTracker.registerMapOutput( shuffleStage.shuffleDep.shuffleId, smt.partitionId, status) - // Remove the task's partition from pending partitions. This may have already been - // done above, but will not have been done yet in cases where the task attempt was - // from an earlier attempt of the stage (i.e., not the attempt that's currently - // running). This allows the DAGScheduler to mark the stage as complete when one - // copy of each task has finished successfully, even if the currently active stage - // still has tasks running. - shuffleStage.pendingPartitions -= task.partitionId } if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 56ba23c38af7f..cd00051c56e8d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -2492,6 +2492,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi runEvent(makeCompletionEvent( taskSets(1).tasks(1), Success, makeMapStatus("hostA", 2))) + // task(stageId=1, stageAttemptId=1, partitionId=1) should be marked completed when + // task(stageId=1, stageAttemptId=0, partitionId=1) finished + // ideally we would verify that but no way to get into task scheduler to verify + // Both tasks in rddB should be resubmitted, because none of them has succeeded truly. // Complete the task(stageId=1, stageAttemptId=1, partitionId=0) successfully. // Task(stageId=1, stageAttemptId=1, partitionId=1) of this new active stage attempt @@ -2501,19 +2505,21 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi runEvent(makeCompletionEvent( taskSets(3).tasks(0), Success, makeMapStatus("hostB", 2))) - // There should be no new attempt of stage submitted, - // because task(stageId=1, stageAttempt=1, partitionId=1) is still running in - // the current attempt (and hasn't completed successfully in any earlier attempts). - assert(taskSets.size === 4) + // At this point there should be no active task set for stageId=1 and we need + // to resubmit because the output from (stageId=1, stageAttemptId=0, partitionId=1) + // was ignored due to executor failure + assert(taskSets.size === 5) + assert(taskSets(4).stageId === 1 && taskSets(4).stageAttemptId === 2 + && taskSets(4).tasks.size === 1) - // Complete task(stageId=1, stageAttempt=1, partitionId=1) successfully. + // Complete task(stageId=1, stageAttempt=2, partitionId=1) successfully. runEvent(makeCompletionEvent( - taskSets(3).tasks(1), Success, makeMapStatus("hostB", 2))) + taskSets(4).tasks(0), Success, makeMapStatus("hostB", 2))) // Now the ResultStage should be submitted, because all of the tasks of rddB have // completed successfully on alive executors. - assert(taskSets.size === 5 && taskSets(4).tasks(0).isInstanceOf[ResultTask[_, _]]) - complete(taskSets(4), Seq( + assert(taskSets.size === 6 && taskSets(5).tasks(0).isInstanceOf[ResultTask[_, _]]) + complete(taskSets(5), Seq( (Success, 1), (Success, 1))) } From 3a66a7fca9f1df3c9175ea9ac04c93d0e86f65c4 Mon Sep 17 00:00:00 2001 From: cclauss Date: Thu, 30 Aug 2018 08:13:11 +0800 Subject: [PATCH 1509/2461] [SPARK-25253][PYSPARK][FOLLOWUP] Undefined name: from pyspark.util import _exception_message HyukjinKwon ## What changes were proposed in this pull request? add __from pyspark.util import \_exception_message__ to python/pyspark/java_gateway.py ## How was this patch tested? [flake8](http://flake8.pycqa.org) testing of https://github.com/apache/spark on Python 3.7.0 $ __flake8 . --count --select=E901,E999,F821,F822,F823 --show-source --statistics__ ``` ./python/pyspark/java_gateway.py:172:20: F821 undefined name '_exception_message' emsg = _exception_message(e) ^ 1 F821 undefined name '_exception_message' 1 ``` Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22265 from cclauss/patch-2. Authored-by: cclauss Signed-off-by: hyukjinkwon --- python/pyspark/java_gateway.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index b06503b53be90..c8c5f801f89bb 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -34,6 +34,7 @@ from py4j.java_gateway import java_import, JavaGateway, JavaObject, GatewayParameters from pyspark.find_spark_home import _find_spark_home from pyspark.serializers import read_int, write_with_length, UTF8Deserializer +from pyspark.util import _exception_message def launch_gateway(conf=None): From 56bc70047edf30485906482015f6378fd5e837f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=8D=E5=86=AC?= Date: Thu, 30 Aug 2018 15:05:36 +0800 Subject: [PATCH 1510/2461] [SQL][MINOR] Fix compiling for scala 2.12 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Introduced by #21320 and #11744 ``` $ sbt > ++2.12.6 > project sql > compile ... [error] [warn] spark/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala:41: match may not be exhaustive. [error] It would fail on the following inputs: (_, ArrayType(_, _)), (_, _) [error] [warn] getProjection(a.child).map(p => (p, p.dataType)).map { [error] [warn] [error] [warn] spark/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala:52: match may not be exhaustive. [error] It would fail on the following input: (_, _) [error] [warn] getProjection(child).map(p => (p, p.dataType)).map { [error] [warn] ... ``` And ``` $ sbt > ++2.12.6 > project hive > testOnly *ParquetMetastoreSuite ... [error] /Users/rendong/wdi/spark/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala:22: object tools is not a member of package scala [error] import scala.tools.nsc.Properties [error] ^ [error] /Users/rendong/wdi/spark/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala:146: not found: value Properties [error] val version = Properties.versionNumberString match { [error] ^ [error] two errors found ... ``` ## How was this patch tested? Existing tests. Closes #22260 from sadhen/fix_exhaustive_match. Authored-by: 忍冬 Signed-off-by: hyukjinkwon --- .../apache/spark/sql/execution/ProjectionOverSchema.scala | 8 ++++++++ .../org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala index 2236f18b0da12..612a7b87b9832 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala @@ -45,6 +45,10 @@ private[execution] case class ProjectionOverSchema(schema: StructType) { projSchema.fieldIndex(a.field.name), projSchema.size, a.containsNull) + case (_, projSchema) => + throw new IllegalStateException( + s"unmatched child schema for GetArrayStructFields: ${projSchema.toString}" + ) } case GetMapValue(child, key) => getProjection(child).map { projection => GetMapValue(projection, key) } @@ -52,6 +56,10 @@ private[execution] case class ProjectionOverSchema(schema: StructType) { getProjection(child).map(p => (p, p.dataType)).map { case (projection, projSchema: StructType) => GetStructField(projection, projSchema.fieldIndex(field.name)) + case (_, projSchema) => + throw new IllegalStateException( + s"unmatched child schema for GetStructField: ${projSchema.toString}" + ) } case _ => None diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index aa5b531992613..a676cf6ce6925 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive import java.io.{BufferedWriter, File, FileWriter} -import scala.tools.nsc.Properties +import scala.util.Properties import org.apache.hadoop.fs.Path import org.scalatest.{BeforeAndAfterEach, Matchers} From e9fce2a4c124cdd709e93065ea691169b5a25f7d Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 30 Aug 2018 16:24:47 +0800 Subject: [PATCH 1511/2461] [SPARK-24716][TESTS][FOLLOW-UP] Test Hive metastore schema and parquet schema are in different letter cases ## What changes were proposed in this pull request? Since https://github.com/apache/spark/pull/21696. Spark uses Parquet schema instead of Hive metastore schema to do pushdown. That change can avoid wrong records returned when Hive metastore schema and parquet schema are in different letter cases. This pr add a test case for it. More details: https://issues.apache.org/jira/browse/SPARK-25206 ## How was this patch tested? unit tests Closes #22267 from wangyum/SPARK-24716-TESTS. Authored-by: Yuming Wang Signed-off-by: Wenchen Fan --- .../apache/spark/sql/hive/HiveParquetSuite.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index 09c15473b21c1..e5c9df05d5674 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf case class Cases(lower: String, UPPER: String) @@ -76,4 +77,19 @@ class HiveParquetSuite extends QueryTest with ParquetTest with TestHiveSingleton } } } + + test("SPARK-25206: wrong records are returned by filter pushdown " + + "when Hive metastore schema and parquet schema are in different letter cases") { + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> true.toString) { + withTempPath { path => + val data = spark.range(1, 10).toDF("id") + data.write.parquet(path.getCanonicalPath) + withTable("SPARK_25206") { + sql("CREATE TABLE SPARK_25206 (ID LONG) USING parquet LOCATION " + + s"'${path.getCanonicalPath}'") + checkAnswer(sql("select id from SPARK_25206 where id > 0"), data) + } + } + } + } } From 3c67cb0b52c14f1cee1a0aaf74d6d71f28cbb5f2 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 30 Aug 2018 20:25:26 +0800 Subject: [PATCH 1512/2461] [SPARK-25273][DOC] How to install testthat 1.0.2 ## What changes were proposed in this pull request? R tests require `testthat` v1.0.2. In the PR, I described how to install the version in the section http://spark.apache.org/docs/latest/building-spark.html#running-r-tests. Closes #22272 from MaxGekk/r-testthat-doc. Authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- docs/README.md | 3 ++- docs/building-spark.md | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/README.md b/docs/README.md index 7da543dd297ad..fb67c4b3586d6 100644 --- a/docs/README.md +++ b/docs/README.md @@ -22,8 +22,9 @@ $ sudo gem install jekyll jekyll-redirect-from pygments.rb $ sudo pip install Pygments # Following is needed only for generating API docs $ sudo pip install sphinx pypandoc mkdocs -$ sudo Rscript -e 'install.packages(c("knitr", "devtools", "testthat", "rmarkdown"), repos="http://cran.stat.ucla.edu/")' +$ sudo Rscript -e 'install.packages(c("knitr", "devtools", "rmarkdown"), repos="http://cran.stat.ucla.edu/")' $ sudo Rscript -e 'devtools::install_version("roxygen2", version = "5.0.1", repos="http://cran.stat.ucla.edu/")' +$ sudo Rscript -e 'devtools::install_version("testthat", version = "1.0.2", repos="http://cran.stat.ucla.edu/")' ``` Note: If you are on a system with both Ruby 1.9 and Ruby 2.0 you may need to replace gem with gem2.0. diff --git a/docs/building-spark.md b/docs/building-spark.md index d3dfd4902a920..0086aeaaa4701 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -236,7 +236,8 @@ The run-tests script also can be limited to a specific Python version or a speci To run the SparkR tests you will need to install the [knitr](https://cran.r-project.org/package=knitr), [rmarkdown](https://cran.r-project.org/package=rmarkdown), [testthat](https://cran.r-project.org/package=testthat), [e1071](https://cran.r-project.org/package=e1071) and [survival](https://cran.r-project.org/package=survival) packages first: - R -e "install.packages(c('knitr', 'rmarkdown', 'testthat', 'e1071', 'survival'), repos='http://cran.us.r-project.org')" + R -e "install.packages(c('knitr', 'rmarkdown', 'devtools', 'e1071', 'survival'), repos='http://cran.us.r-project.org')" + R -e "devtools::install_version('testthat', version = '1.0.2', repos='http://cran.us.r-project.org')" You can run just the SparkR tests using the command: From 9e0f9591afccc97cd54a133d8ed10512d14f4f91 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 30 Aug 2018 11:21:40 -0500 Subject: [PATCH 1513/2461] [SPARK-23997][SQL][FOLLOWUP] Update exception message ## What changes were proposed in this pull request? This PR is an follow-up PR of #21087 based on [a discussion thread](https://github.com/apache/spark/pull/21087#discussion_r211080067]. Since #21087 changed a condition of `if` statement, the message in an exception is not consistent of the current behavior. This PR updates the exception message. ## How was this patch tested? Existing UTs Closes #22269 from kiszk/SPARK-23997-followup. Authored-by: Kazuaki Ishizaki Signed-off-by: Sean Owen --- .../org/apache/spark/sql/catalyst/catalog/interface.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 3842d794ba5ff..30ded13410f7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -178,8 +178,8 @@ case class BucketSpec( if (numBuckets <= 0 || numBuckets > conf.bucketingMaxBuckets) { throw new AnalysisException( - s"Number of buckets should be greater than 0 but less than bucketing.maxBuckets " + - s"(`${conf.bucketingMaxBuckets}`). Got `$numBuckets`") + s"Number of buckets should be greater than 0 but less than or equal to " + + s"bucketing.maxBuckets (`${conf.bucketingMaxBuckets}`). Got `$numBuckets`") } override def toString: String = { From 135ff16a3510a4dfb3470904004dae9848005019 Mon Sep 17 00:00:00 2001 From: Reza Safi Date: Thu, 30 Aug 2018 13:26:03 -0500 Subject: [PATCH 1514/2461] [SPARK-25233][STREAMING] Give the user the option of specifying a minimum message per partition per batch when using kafka direct API with backpressure After SPARK-18371, it is guaranteed that there would be at least one message per partition per batch using direct kafka API when new messages exist in the topics. This change will give the user the option of setting the minimum instead of just a hard coded 1 limit The related unit test is updated and some internal tests verified that the topic partitions with new messages will be progressed by the specified minimum. Author: Reza Safi Closes #22223 from rezasafi/streaminglag. --- docs/configuration.md | 8 ++++++++ .../streaming/kafka010/DirectKafkaInputDStream.scala | 3 ++- .../spark/streaming/kafka010/PerPartitionConfig.scala | 3 +++ .../streaming/kafka010/DirectKafkaStreamSuite.scala | 10 +++++++--- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 9714b48d5e69b..b5ff426936e59 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1997,6 +1997,14 @@ showDF(properties, numRows = 200, truncate = FALSE) for more details. + + spark.streaming.kafka.minRatePerPartition + 1 + + Minimum rate (number of records per second) at which data will be read from each Kafka + partition when using the new Kafka direct stream API. + + spark.streaming.kafka.maxRetries 1 diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 0246006acf0bd..0acc9b8d2a0cf 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -154,7 +154,8 @@ private[spark] class DirectKafkaInputDStream[K, V]( if (effectiveRateLimitPerPartition.values.sum > 0) { val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 Some(effectiveRateLimitPerPartition.map { - case (tp, limit) => tp -> Math.max((secsPerBatch * limit).toLong, 1L) + case (tp, limit) => tp -> Math.max((secsPerBatch * limit).toLong, + ppc.minRatePerPartition(tp)) }) } else { None diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala index 4792f2a955110..4017fdbcaf95e 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala @@ -34,6 +34,7 @@ abstract class PerPartitionConfig extends Serializable { * from each Kafka partition. */ def maxRatePerPartition(topicPartition: TopicPartition): Long + def minRatePerPartition(topicPartition: TopicPartition): Long = 1 } /** @@ -42,6 +43,8 @@ abstract class PerPartitionConfig extends Serializable { private class DefaultPerPartitionConfig(conf: SparkConf) extends PerPartitionConfig { val maxRate = conf.getLong("spark.streaming.kafka.maxRatePerPartition", 0) + val minRate = conf.getLong("spark.streaming.kafka.minRatePerPartition", 1) def maxRatePerPartition(topicPartition: TopicPartition): Long = maxRate + override def minRatePerPartition(topicPartition: TopicPartition): Long = minRate } diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index 35e4678f2e3c8..661b67a8ab68a 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -664,7 +664,8 @@ class DirectKafkaStreamSuite kafkaStream.stop() } - test("maxMessagesPerPartition with zero offset and rate equal to one") { + test("maxMessagesPerPartition with zero offset and rate equal to the specified" + + " minimum with default 1") { val topic = "backpressure" val kafkaParams = getKafkaParams() val batchIntervalMilliseconds = 60000 @@ -674,6 +675,8 @@ class DirectKafkaStreamSuite .setMaster("local[1]") .setAppName(this.getClass.getSimpleName) .set("spark.streaming.kafka.maxRatePerPartition", "100") + .set("spark.streaming.kafka.minRatePerPartition", "5") + // Setup the streaming context ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds)) @@ -704,12 +707,13 @@ class DirectKafkaStreamSuite ) val result = kafkaStream.maxMessagesPerPartition(offsets) val expected = Map( - new TopicPartition(topic, 0) -> 1L, + new TopicPartition(topic, 0) -> 5L, new TopicPartition(topic, 1) -> 10L, new TopicPartition(topic, 2) -> 20L, new TopicPartition(topic, 3) -> 30L ) - assert(result.contains(expected), s"Number of messages per partition must be at least 1") + assert(result.contains(expected), s"Number of messages per partition must be at least equal" + + s" to the specified minimum") } /** Get the generated offset ranges from the DirectKafkaStream */ From c685b5f56a69abdf77e07e852b9bb2c6f2e715c9 Mon Sep 17 00:00:00 2001 From: aai95 Date: Thu, 30 Aug 2018 20:38:03 +0000 Subject: [PATCH 1515/2461] [SPARK-24411][SQL] Adding native Java tests for 'isInCollection' ## What changes were proposed in this pull request? `JavaColumnExpressionSuite.java` was added and `org.apache.spark.sql.ColumnExpressionSuite#test("isInCollection: Java Collection")` was removed. It provides native Java tests for the method `org.apache.spark.sql.Column#isInCollection`. Closes #22253 from aai95/isInCollectionJavaTest. Authored-by: aai95 Signed-off-by: DB Tsai --- .../spark/sql/JavaColumnExpressionSuite.java | 95 +++++++++++++++++++ .../spark/sql/ColumnExpressionSuite.scala | 21 ---- 2 files changed, 95 insertions(+), 21 deletions(-) create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java new file mode 100644 index 0000000000000..38d606c5e108e --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql; + +import org.apache.spark.api.java.function.FilterFunction; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.test.TestSparkSession; +import org.apache.spark.sql.types.StructType; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.*; + +import static org.apache.spark.sql.types.DataTypes.*; + +public class JavaColumnExpressionSuite { + private transient TestSparkSession spark; + + @Before + public void setUp() { + spark = new TestSparkSession(); + } + + @After + public void tearDown() { + spark.stop(); + spark = null; + } + + @Test + public void isInCollectionWorksCorrectlyOnJava() { + List rows = Arrays.asList( + RowFactory.create(1, "x"), + RowFactory.create(2, "y"), + RowFactory.create(3, "z")); + StructType schema = createStructType(Arrays.asList( + createStructField("a", IntegerType, false), + createStructField("b", StringType, false))); + Dataset df = spark.createDataFrame(rows, schema); + // Test with different types of collections + Assert.assertTrue(Arrays.equals( + (Row[]) df.filter(df.col("a").isInCollection(Arrays.asList(1, 2))).collect(), + (Row[]) df.filter((FilterFunction) r -> r.getInt(0) == 1 || r.getInt(0) == 2).collect() + )); + Assert.assertTrue(Arrays.equals( + (Row[]) df.filter(df.col("a").isInCollection(new HashSet<>(Arrays.asList(1, 2)))).collect(), + (Row[]) df.filter((FilterFunction) r -> r.getInt(0) == 1 || r.getInt(0) == 2).collect() + )); + Assert.assertTrue(Arrays.equals( + (Row[]) df.filter(df.col("a").isInCollection(new ArrayList<>(Arrays.asList(3, 1)))).collect(), + (Row[]) df.filter((FilterFunction) r -> r.getInt(0) == 3 || r.getInt(0) == 1).collect() + )); + } + + @Test + public void isInCollectionCheckExceptionMessage() { + List rows = Arrays.asList( + RowFactory.create(1, Arrays.asList(1)), + RowFactory.create(2, Arrays.asList(2)), + RowFactory.create(3, Arrays.asList(3))); + StructType schema = createStructType(Arrays.asList( + createStructField("a", IntegerType, false), + createStructField("b", createArrayType(IntegerType, false), false))); + Dataset df = spark.createDataFrame(rows, schema); + try { + df.filter(df.col("a").isInCollection(Arrays.asList(new Column("b")))); + Assert.fail("Expected org.apache.spark.sql.AnalysisException"); + } catch (Exception e) { + Arrays.asList("cannot resolve", + "due to data type mismatch: Arguments must be same type but were") + .forEach(s -> Assert.assertTrue( + e.getMessage().toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))); + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 2182bd7eadd63..2917c56dbeb56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -436,27 +436,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } } - test("isInCollection: Java Collection") { - val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") - // Test with different types of collections - checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).asJava)), - df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet.asJava)), - df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList.asJava)), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - - val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") - - val e = intercept[AnalysisException] { - df2.filter($"a".isInCollection(Seq($"b").asJava)) - } - Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") - .foreach { s => - assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) - } - } - test("&&") { checkAnswer( booleanData.filter($"a" && true), From a5fb5b62c3595ce2e59e7b8e95ef5bcc825f1577 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=8D=E5=86=AC?= Date: Thu, 30 Aug 2018 15:54:07 -0500 Subject: [PATCH 1516/2461] [SPARK-25235][BUILD][SHELL][FOLLOWUP] Fix repl compile for 2.12 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Error messages from https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test/job/spark-master-test-maven-hadoop-2.7-ubuntu-scala-2.12/183/ ``` [INFO] --- scala-maven-plugin:3.2.2:compile (scala-compile-first) spark-repl_2.12 --- [INFO] Using zinc server for incremental compilation [warn] Pruning sources from previous analysis, due to incompatible CompileSetup. [info] Compiling 6 Scala sources to /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7-ubuntu-scala-2.12/repl/target/scala-2.12/classes... [error] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7-ubuntu-scala-2.12/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala:80: overriding lazy value importableSymbolsWithRenames in class ImportHandler of type List[(this.intp.global.Symbol, this.intp.global.Name)]; [error] lazy value importableSymbolsWithRenames needs `override' modifier [error] lazy val importableSymbolsWithRenames: List[(Symbol, Name)] = { [error] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7-ubuntu-scala-2.12/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala:53: variable addedClasspath in class ILoop is deprecated (since 2.11.0): use reset, replay or require to update class path [warn] if (addedClasspath != "") { [warn] ^ [warn] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7-ubuntu-scala-2.12/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala:54: variable addedClasspath in class ILoop is deprecated (since 2.11.0): use reset, replay or require to update class path [warn] settings.classpath append addedClasspath [warn] ^ [warn] two warnings found [error] one error found [error] Compile failed at Aug 29, 2018 5:28:22 PM [0.679s] ``` Readd the profile for `scala-2.12`. Using `-Pscala-2.12` will overrides `extra.source.dir` and `extra.testsource.dir` with two non-exist directories. ## How was this patch tested? First, make sure it compiles. ``` dev/change-scala-version.sh 2.12 mvn -Pscala-2.12 -DskipTests compile install ``` Then, make a distribution to try the repl: `./dev/make-distribution.sh --name custom-spark --tgz -Phadoop-2.7 -Phive -Pyarn -Pscala-2.12` ``` 18/08/30 16:04:50 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel). Spark context Web UI available at http://172.16.131.140:4040 Spark context available as 'sc' (master = local[*], app id = local-1535616298812). Spark session available as 'spark'. Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /___/ .__/\_,_/_/ /_/\_\ version 2.4.0-SNAPSHOT /_/ Using Scala version 2.12.6 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_112) Type in expressions to have them evaluated. Type :help for more information. scala> spark.sql("select percentile(key, 1) from values (1, 1),(2, 1) T(key, value)").show +-------------------------------------+ |percentile(key, CAST(1 AS DOUBLE), 1)| +-------------------------------------+ | 2.0| +-------------------------------------+ ``` Closes #22280 from sadhen/SPARK_24785_FOLLOWUP. Authored-by: 忍冬 Signed-off-by: Sean Owen --- repl/pom.xml | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/repl/pom.xml b/repl/pom.xml index 553d5eb79a256..861bbd7c49654 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -167,4 +167,14 @@ + + + scala-2.12 + + scala-2.12/src/main/scala + scala-2.12/src/test/scala + + + + From d6d1224ffab7ce980b80da459f68502be13d72fc Mon Sep 17 00:00:00 2001 From: Erik Erlandson Date: Thu, 30 Aug 2018 14:07:04 -0700 Subject: [PATCH 1517/2461] [SPARK-25275][K8S] require memberhip in wheel to run 'su' in dockerfiles ## What changes were proposed in this pull request? Add a PAM configuration in k8s dockerfile to require authentication into wheel to run as `su` ## How was this patch tested? Verify against CI that PAM config succeeds & causes no regressions Closes #22285 from erikerlandson/spark-25275. Authored-by: Erik Erlandson Signed-off-by: Erik Erlandson --- .../kubernetes/docker/src/main/dockerfiles/spark/Dockerfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index 42a670174eae1..071aa2020dd85 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -29,12 +29,13 @@ ARG img_path=kubernetes/dockerfiles RUN set -ex && \ apk upgrade --no-cache && \ - apk add --no-cache bash tini libc6-compat && \ + apk add --no-cache bash tini libc6-compat linux-pam && \ mkdir -p /opt/spark && \ mkdir -p /opt/spark/work-dir && \ touch /opt/spark/RELEASE && \ rm /bin/sh && \ ln -sv /bin/bash /bin/sh && \ + echo "auth required pam_wheel.so use_uid" >> /etc/pam.d/su && \ chgrp root /etc/passwd && chmod ug+rw /etc/passwd COPY ${spark_jars} /opt/spark/jars From bb3e6ed9216f98f3a3b96c8c52f20042d65e2181 Mon Sep 17 00:00:00 2001 From: Erik Erlandson Date: Thu, 30 Aug 2018 15:08:12 -0700 Subject: [PATCH 1518/2461] [SPARK-25287][INFRA] Add up-front check for JIRA_USERNAME and JIRA_PASSWORD ## What changes were proposed in this pull request? Add an up-front check that `JIRA_USERNAME` and `JIRA_PASSWORD` have been set. If they haven't, ask user if they want to continue. This prevents the JIRA state update from failing at the very end of the process because user forgot to set these environment variables. ## How was this patch tested? I ran the script with environment vars set, and unset, to verify it works as specified. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22294 from erikerlandson/spark-25287. Authored-by: Erik Erlandson Signed-off-by: Erik Erlandson --- dev/merge_spark_pr.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 28a6714856c10..81daa909e019c 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -438,6 +438,10 @@ def main(): os.chdir(SPARK_HOME) original_head = get_current_ref() + # Check this up front to avoid failing the JIRA update at the very end + if not JIRA_USERNAME or not JIRA_PASSWORD: + continue_maybe("The env-vars JIRA_USERNAME and/or JIRA_PASSWORD are not set. Continue?") + branches = get_json("%s/branches" % GITHUB_API_BASE) branch_names = filter(lambda x: x.startswith("branch-"), [x['name'] for x in branches]) # Assumes branch names can be sorted lexicographically From f29c2b5287563c0d6f55f936bd5a75707d7b2b1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=8D=E5=86=AC?= Date: Thu, 30 Aug 2018 22:37:40 -0500 Subject: [PATCH 1519/2461] [SPARK-25256][SQL][TEST] Plan mismatch errors in Hive tests in Scala 2.12 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? ### For `SPARK-5775 read array from partitioned_parquet_with_key_and_complextypes`: scala2.12 ``` scala> (1 to 10).toString res4: String = Range 1 to 10 ``` scala2.11 ``` scala> (1 to 10).toString res2: String = Range(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) ``` And ``` def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = { val converted: Seq[Row] = answer.map(prepareRow) if (!isSorted) converted.sortBy(_.toString()) else converted } ``` sortBy `_.toString` is not a good idea. ### Other failures are caused by ``` Array(Int.box(1)).toSeq == Array(Double.box(1.0)).toSeq ``` It is false in 2.12.2 + and is true in 2.11.x , 2.12.0, 2.12.1 ## How was this patch tested? This is a patch on a specific unit test. Closes #22264 from sadhen/SPARK25256. Authored-by: 忍冬 Signed-off-by: Sean Owen --- .../test/scala/org/apache/spark/sql/QueryTest.scala | 10 ++++++++++ .../org/apache/spark/sql/hive/parquetSuites.scala | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 9fb8be423614b..baca9c1cfb9a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -290,6 +290,16 @@ object QueryTest { Row.fromSeq(row.toSeq.map { case null => null case d: java.math.BigDecimal => BigDecimal(d) + // Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+ + case seq: Seq[_] => seq.map { + case b: java.lang.Byte => b.byteValue + case s: java.lang.Short => s.shortValue + case i: java.lang.Integer => i.intValue + case l: java.lang.Long => l.longValue + case f: java.lang.Float => f.floatValue + case d: java.lang.Double => d.doubleValue + case x => x + } // Convert array to Seq for easy equality check. case b: Array[_] => b.toSeq case r: Row => prepareRow(r) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 2327d83a1b4f6..e82d457eee394 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -1068,7 +1068,7 @@ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with test(s"SPARK-5775 read array from $table") { checkAnswer( sql(s"SELECT arrayField, p FROM $table WHERE p = 1"), - (1 to 10).map(i => Row(1 to i, 1))) + (1 to 10).map(i => Row((1 to i).toArray, 1))) } } From aa70a0a1a434e8a4b1d4dde00e20b865bb70b8dd Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 30 Aug 2018 23:23:11 -0700 Subject: [PATCH 1520/2461] [SPARK-25288][TESTS] Fix flaky Kafka transaction tests ## What changes were proposed in this pull request? Here are the failures: http://spark-tests.appspot.com/test-details?suite_name=org.apache.spark.sql.kafka010.KafkaRelationSuite&test_name=read+Kafka+transactional+messages%3A+read_committed http://spark-tests.appspot.com/test-details?suite_name=org.apache.spark.sql.kafka010.KafkaMicroBatchV1SourceSuite&test_name=read+Kafka+transactional+messages%3A+read_committed http://spark-tests.appspot.com/test-details?suite_name=org.apache.spark.sql.kafka010.KafkaMicroBatchV2SourceSuite&test_name=read+Kafka+transactional+messages%3A+read_committed I found the Kafka consumer may not see the committed messages for a short time. This PR just adds a new method `waitUntilOffsetAppears` and uses it to make sure the consumer can see a specified offset before checking the result. ## How was this patch tested? Jenkins Closes #22293 from zsxwing/SPARK-25288. Authored-by: Shixiong Zhu Signed-off-by: Shixiong Zhu --- .../kafka010/KafkaMicroBatchSourceSuite.scala | 34 +++++++++++-------- .../sql/kafka010/KafkaRelationSuite.scala | 7 ++++ .../spark/sql/kafka010/KafkaTestUtils.scala | 10 ++++++ 3 files changed, 37 insertions(+), 14 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index eb66ccac744a3..78249f7a80fb5 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -160,14 +160,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with Kaf } object WithOffsetSync { - def apply(topic: String)(func: () => Unit): StreamAction = { + /** + * Run `func` to write some Kafka messages and wait until the latest offset of the given + * `TopicPartition` is not less than `expectedOffset`. + */ + def apply( + topicPartition: TopicPartition, + expectedOffset: Long)(func: () => Unit): StreamAction = { Execute("Run Kafka Producer")(_ => { func() // This is a hack for the race condition that the committed message may be not visible to // consumer for a short time. - // Looks like after the following call returns, the consumer can always read the committed - // messages. - testUtils.getLatestOffsets(Set(topic)) + testUtils.waitUntilOffsetAppears(topicPartition, expectedOffset) }) } } @@ -652,13 +656,14 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { } } + val topicPartition = new TopicPartition(topic, 0) // The message values are the same as their offsets to make the test easy to follow testUtils.withTranscationalProducer { producer => testStream(mapped)( StartStream(ProcessingTime(100), clock), waitUntilBatchProcessed, CheckAnswer(), - WithOffsetSync(topic) { () => + WithOffsetSync(topicPartition, expectedOffset = 5) { () => // Send 5 messages. They should be visible only after being committed. producer.beginTransaction() (0 to 4).foreach { i => @@ -669,7 +674,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { waitUntilBatchProcessed, // Should not see any uncommitted messages CheckNewAnswer(), - WithOffsetSync(topic) { () => + WithOffsetSync(topicPartition, expectedOffset = 6) { () => producer.commitTransaction() }, AdvanceManualClock(100), @@ -678,7 +683,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { AdvanceManualClock(100), waitUntilBatchProcessed, CheckNewAnswer(3, 4), // offset: 3, 4, 5* [* means it's not a committed data message] - WithOffsetSync(topic) { () => + WithOffsetSync(topicPartition, expectedOffset = 12) { () => // Send 5 messages and abort the transaction. They should not be read. producer.beginTransaction() (6 to 10).foreach { i => @@ -692,7 +697,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { AdvanceManualClock(100), waitUntilBatchProcessed, CheckNewAnswer(), // offset: 9*, 10*, 11* - WithOffsetSync(topic) { () => + WithOffsetSync(topicPartition, expectedOffset = 18) { () => // Send 5 messages again. The consumer should skip the above aborted messages and read // them. producer.beginTransaction() @@ -707,7 +712,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { AdvanceManualClock(100), waitUntilBatchProcessed, CheckNewAnswer(15, 16), // offset: 15, 16, 17* - WithOffsetSync(topic) { () => + WithOffsetSync(topicPartition, expectedOffset = 25) { () => producer.beginTransaction() producer.send(new ProducerRecord[String, String](topic, "18")).get() producer.commitTransaction() @@ -774,13 +779,14 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { } } + val topicPartition = new TopicPartition(topic, 0) // The message values are the same as their offsets to make the test easy to follow testUtils.withTranscationalProducer { producer => testStream(mapped)( StartStream(ProcessingTime(100), clock), waitUntilBatchProcessed, CheckNewAnswer(), - WithOffsetSync(topic) { () => + WithOffsetSync(topicPartition, expectedOffset = 5) { () => // Send 5 messages. They should be visible only after being committed. producer.beginTransaction() (0 to 4).foreach { i => @@ -790,13 +796,13 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { AdvanceManualClock(100), waitUntilBatchProcessed, CheckNewAnswer(0, 1, 2), // offset 0, 1, 2 - WithOffsetSync(topic) { () => + WithOffsetSync(topicPartition, expectedOffset = 6) { () => producer.commitTransaction() }, AdvanceManualClock(100), waitUntilBatchProcessed, CheckNewAnswer(3, 4), // offset: 3, 4, 5* [* means it's not a committed data message] - WithOffsetSync(topic) { () => + WithOffsetSync(topicPartition, expectedOffset = 12) { () => // Send 5 messages and abort the transaction. They should not be read. producer.beginTransaction() (6 to 10).foreach { i => @@ -810,7 +816,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { AdvanceManualClock(100), waitUntilBatchProcessed, CheckNewAnswer(9, 10), // offset: 9, 10, 11* - WithOffsetSync(topic) { () => + WithOffsetSync(topicPartition, expectedOffset = 18) { () => // Send 5 messages again. The consumer should skip the above aborted messages and read // them. producer.beginTransaction() @@ -825,7 +831,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { AdvanceManualClock(100), waitUntilBatchProcessed, CheckNewAnswer(15, 16), // offset: 15, 16, 17* - WithOffsetSync(topic) { () => + WithOffsetSync(topicPartition, expectedOffset = 25) { () => producer.beginTransaction() producer.send(new ProducerRecord[String, String](topic, "18")).get() producer.commitTransaction() diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala index 93dba18446280..eb186970fc25d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala @@ -260,6 +260,7 @@ class KafkaRelationSuite extends QueryTest with SharedSQLContext with KafkaTest producer.commitTransaction() // Should read all committed messages + testUtils.waitUntilOffsetAppears(new TopicPartition(topic, 0), 6) checkAnswer(df, (1 to 5).map(_.toString).toDF) producer.beginTransaction() @@ -269,6 +270,7 @@ class KafkaRelationSuite extends QueryTest with SharedSQLContext with KafkaTest producer.abortTransaction() // Should not read aborted messages + testUtils.waitUntilOffsetAppears(new TopicPartition(topic, 0), 12) checkAnswer(df, (1 to 5).map(_.toString).toDF) producer.beginTransaction() @@ -278,6 +280,7 @@ class KafkaRelationSuite extends QueryTest with SharedSQLContext with KafkaTest producer.commitTransaction() // Should skip aborted messages and read new committed ones. + testUtils.waitUntilOffsetAppears(new TopicPartition(topic, 0), 18) checkAnswer(df, ((1 to 5) ++ (11 to 15)).map(_.toString).toDF) } } @@ -301,11 +304,13 @@ class KafkaRelationSuite extends QueryTest with SharedSQLContext with KafkaTest } // "read_uncommitted" should see all messages including uncommitted ones + testUtils.waitUntilOffsetAppears(new TopicPartition(topic, 0), 5) checkAnswer(df, (1 to 5).map(_.toString).toDF) producer.commitTransaction() // Should read all committed messages + testUtils.waitUntilOffsetAppears(new TopicPartition(topic, 0), 6) checkAnswer(df, (1 to 5).map(_.toString).toDF) producer.beginTransaction() @@ -315,6 +320,7 @@ class KafkaRelationSuite extends QueryTest with SharedSQLContext with KafkaTest producer.abortTransaction() // "read_uncommitted" should see all messages including uncommitted or aborted ones + testUtils.waitUntilOffsetAppears(new TopicPartition(topic, 0), 12) checkAnswer(df, (1 to 10).map(_.toString).toDF) producer.beginTransaction() @@ -324,6 +330,7 @@ class KafkaRelationSuite extends QueryTest with SharedSQLContext with KafkaTest producer.commitTransaction() // Should read all messages + testUtils.waitUntilOffsetAppears(new TopicPartition(topic, 0), 18) checkAnswer(df, (1 to 15).map(_.toString).toDF) } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index 7b742a3ea6741..bf6934be52705 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -439,6 +439,16 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L } } + /** + * Wait until the latest offset of the given `TopicPartition` is not less than `offset`. + */ + def waitUntilOffsetAppears(topicPartition: TopicPartition, offset: Long): Unit = { + eventually(timeout(60.seconds)) { + val currentOffset = getLatestOffsets(Set(topicPartition.topic)).get(topicPartition) + assert(currentOffset.nonEmpty && currentOffset.get >= offset) + } + } + private class EmbeddedZookeeper(val zkConnect: String) { val snapshotDir = Utils.createTempDir() val logDir = Utils.createTempDir() From 515708d5f33d5acdb4206c626192d1838f8e691f Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Fri, 31 Aug 2018 14:45:29 +0800 Subject: [PATCH 1521/2461] [SPARK-25183][SQL] Spark HiveServer2 to use Spark ShutdownHookManager ## What changes were proposed in this pull request? Switch `org.apache.hive.service.server.HiveServer2` to register its shutdown callback with Spark's `ShutdownHookManager`, rather than direct with the Java Runtime callback. This avoids race conditions in shutdown where the filesystem is shutdown before the flush/write/rename of the event log is completed, particularly on object stores where the write and rename can be slow. ## How was this patch tested? There's no explicit unit for test this, which is consistent with every other shutdown hook in the codebase. * There's an implicit test when the scalatest process is halted. * More manual/integration testing is needed. HADOOP-15679 has added the ability to explicitly execute the hadoop shutdown hook sequence which spark uses; that could be stabilized for testing if desired, after which all the spark hooks could be tested. Until then: external system tests only. Author: Steve Loughran Closes #22186 from steveloughran/BUG/SPARK-25183-shutdown. --- .../hive/service/server/HiveServer2.java | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/HiveServer2.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/HiveServer2.java index 9bf96cff572e8..a30be2bc06b9e 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/HiveServer2.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/HiveServer2.java @@ -20,6 +20,9 @@ import java.util.Properties; +import scala.runtime.AbstractFunction0; +import scala.runtime.BoxedUnit; + import org.apache.commons.cli.GnuParser; import org.apache.commons.cli.HelpFormatter; import org.apache.commons.cli.Option; @@ -39,6 +42,8 @@ import org.apache.hive.service.cli.thrift.ThriftCLIService; import org.apache.hive.service.cli.thrift.ThriftHttpCLIService; +import org.apache.spark.util.ShutdownHookManager; + /** * HiveServer2. * @@ -67,13 +72,23 @@ public synchronized void init(HiveConf hiveConf) { super.init(hiveConf); // Add a shutdown hook for catching SIGTERM & SIGINT - final HiveServer2 hiveServer2 = this; - Runtime.getRuntime().addShutdownHook(new Thread() { - @Override - public void run() { - hiveServer2.stop(); - } - }); + // this must be higher than the Hadoop Filesystem priority of 10, + // which the default priority is. + // The signature of the callback must match that of a scala () -> Unit + // function + ShutdownHookManager.addShutdownHook( + new AbstractFunction0() { + public BoxedUnit apply() { + try { + LOG.info("Hive Server Shutdown hook invoked"); + stop(); + } catch (Throwable e) { + LOG.warn("Ignoring Exception while stopping Hive Server from shutdown hook", + e); + } + return BoxedUnit.UNIT; + } + }); } public static boolean isHTTPTransportMode(HiveConf hiveConf) { @@ -95,7 +110,6 @@ public synchronized void start() { @Override public synchronized void stop() { LOG.info("Shutting down HiveServer2"); - HiveConf hiveConf = this.getHiveConf(); super.stop(); } From 8d9495a8f1e64dbc42c3741f9bcbd4893ce3f0e9 Mon Sep 17 00:00:00 2001 From: yucai Date: Fri, 31 Aug 2018 19:24:09 +0800 Subject: [PATCH 1522/2461] [SPARK-25207][SQL] Case-insensitve field resolution for filter pushdown when reading Parquet ## What changes were proposed in this pull request? Currently, filter pushdown will not work if Parquet schema and Hive metastore schema are in different letter cases even spark.sql.caseSensitive is false. Like the below case: ```scala spark.sparkContext.hadoopConfiguration.setInt("parquet.block.size", 8 * 1024 * 1024) spark.range(1, 40 * 1024 * 1024, 1, 1).sortWithinPartitions("id").write.parquet("/tmp/t") sql("CREATE TABLE t (ID LONG) USING parquet LOCATION '/tmp/t'") sql("select * from t where id < 100L").write.csv("/tmp/id") ``` Although filter "ID < 100L" is generated by Spark, it fails to pushdown into parquet actually, Spark still does the full table scan when reading. This PR provides a case-insensitive field resolution to make it work. Before - "ID < 100L" fail to pushedown: screen shot 2018-08-23 at 10 08 26 pm After - "ID < 100L" pushedown sucessfully: screen shot 2018-08-23 at 10 08 40 pm ## How was this patch tested? Added UTs. Closes #22197 from yucai/SPARK-25207. Authored-by: yucai Signed-off-by: Wenchen Fan --- .../parquet/ParquetFileFormat.scala | 3 +- .../datasources/parquet/ParquetFilters.scala | 90 ++++++++++---- .../parquet/ParquetFilterSuite.scala | 115 +++++++++++++++++- 3 files changed, 179 insertions(+), 29 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index d7eb14356b8b1..ea4f1592a7c2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -347,6 +347,7 @@ class ParquetFileFormat val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold + val isCaseSensitive = sqlConf.caseSensitiveAnalysis (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) @@ -372,7 +373,7 @@ class ParquetFileFormat val pushed = if (enableParquetFilterPushDown) { val parquetSchema = footerFileMetaData.getSchema val parquetFilters = new ParquetFilters(pushDownDate, pushDownTimestamp, pushDownDecimal, - pushDownStringStartWith, pushDownInFilterThreshold) + pushDownStringStartWith, pushDownInFilterThreshold, isCaseSensitive) filters // Collects all converted Parquet filter predicates. Notice that not all predicates can be // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 58b4a769fcb62..0c286defb9406 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong} import java.math.{BigDecimal => JBigDecimal} import java.sql.{Date, Timestamp} +import java.util.Locale import scala.collection.JavaConverters.asScalaBufferConverter @@ -31,7 +32,7 @@ import org.apache.parquet.schema.OriginalType._ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate import org.apache.spark.sql.sources import org.apache.spark.unsafe.types.UTF8String @@ -44,7 +45,18 @@ private[parquet] class ParquetFilters( pushDownTimestamp: Boolean, pushDownDecimal: Boolean, pushDownStartWith: Boolean, - pushDownInFilterThreshold: Int) { + pushDownInFilterThreshold: Int, + caseSensitive: Boolean) { + + /** + * Holds a single field information stored in the underlying parquet file. + * + * @param fieldName field name in parquet file + * @param fieldType field type related info in parquet file + */ + private case class ParquetField( + fieldName: String, + fieldType: ParquetSchemaType) private case class ParquetSchemaType( originalType: OriginalType, @@ -350,25 +362,38 @@ private[parquet] class ParquetFilters( } /** - * Returns a map from name of the column to the data type, if predicate push down applies. + * Returns a map, which contains parquet field name and data type, if predicate push down applies. */ - private def getFieldMap(dataType: MessageType): Map[String, ParquetSchemaType] = dataType match { - case m: MessageType => - // Here we don't flatten the fields in the nested schema but just look up through - // root fields. Currently, accessing to nested fields does not push down filters - // and it does not support to create filters for them. - m.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f => - f.getName -> ParquetSchemaType( - f.getOriginalType, f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata) - }.toMap - case _ => Map.empty[String, ParquetSchemaType] + private def getFieldMap(dataType: MessageType): Map[String, ParquetField] = { + // Here we don't flatten the fields in the nested schema but just look up through + // root fields. Currently, accessing to nested fields does not push down filters + // and it does not support to create filters for them. + val primitiveFields = + dataType.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f => + f.getName -> ParquetField(f.getName, + ParquetSchemaType(f.getOriginalType, + f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata)) + } + if (caseSensitive) { + primitiveFields.toMap + } else { + // Don't consider ambiguity here, i.e. more than one field is matched in case insensitive + // mode, just skip pushdown for these fields, they will trigger Exception when reading, + // See: SPARK-25132. + val dedupPrimitiveFields = + primitiveFields + .groupBy(_._1.toLowerCase(Locale.ROOT)) + .filter(_._2.size == 1) + .mapValues(_.head._2) + CaseInsensitiveMap(dedupPrimitiveFields) + } } /** * Converts data sources filters to Parquet filter predicates. */ def createFilter(schema: MessageType, predicate: sources.Filter): Option[FilterPredicate] = { - val nameToType = getFieldMap(schema) + val nameToParquetField = getFieldMap(schema) // Decimal type must make sure that filter value's scale matched the file. // If doesn't matched, which would cause data corruption. @@ -381,7 +406,7 @@ private[parquet] class ParquetFilters( // Parquet's type in the given file should be matched to the value's type // in the pushed filter in order to push down the filter to Parquet. def valueCanMakeFilterOn(name: String, value: Any): Boolean = { - value == null || (nameToType(name) match { + value == null || (nameToParquetField(name).fieldType match { case ParquetBooleanType => value.isInstanceOf[JBoolean] case ParquetByteType | ParquetShortType | ParquetIntegerType => value.isInstanceOf[Number] case ParquetLongType => value.isInstanceOf[JLong] @@ -408,7 +433,7 @@ private[parquet] class ParquetFilters( // filters for the column having dots in the names. Thus, we do not push down such filters. // See SPARK-20364. def canMakeFilterOn(name: String, value: Any): Boolean = { - nameToType.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value) + nameToParquetField.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value) } // NOTE: @@ -428,29 +453,39 @@ private[parquet] class ParquetFilters( predicate match { case sources.IsNull(name) if canMakeFilterOn(name, null) => - makeEq.lift(nameToType(name)).map(_(name, null)) + makeEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, null)) case sources.IsNotNull(name) if canMakeFilterOn(name, null) => - makeNotEq.lift(nameToType(name)).map(_(name, null)) + makeNotEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, null)) case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => - makeEq.lift(nameToType(name)).map(_(name, value)) + makeEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => - makeNotEq.lift(nameToType(name)).map(_(name, value)) + makeNotEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => - makeEq.lift(nameToType(name)).map(_(name, value)) + makeEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => - makeNotEq.lift(nameToType(name)).map(_(name, value)) + makeNotEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) case sources.LessThan(name, value) if canMakeFilterOn(name, value) => - makeLt.lift(nameToType(name)).map(_(name, value)) + makeLt.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) => - makeLtEq.lift(nameToType(name)).map(_(name, value)) + makeLtEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) => - makeGt.lift(nameToType(name)).map(_(name, value)) + makeGt.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) => - makeGtEq.lift(nameToType(name)).map(_(name, value)) + makeGtEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) case sources.And(lhs, rhs) => // At here, it is not safe to just convert one side if we do not understand the @@ -477,7 +512,8 @@ private[parquet] class ParquetFilters( case sources.In(name, values) if canMakeFilterOn(name, values.head) && values.distinct.length <= pushDownInFilterThreshold => values.distinct.flatMap { v => - makeEq.lift(nameToType(name)).map(_(name, v)) + makeEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, v)) }.reduceLeftOption(FilterApi.or) case sources.StringStartsWith(name, prefix) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index be4f498c921ab..7ebb75009555a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -25,6 +25,7 @@ import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operato import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.filter2.predicate.Operators.{Column => _, _} +import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ @@ -60,7 +61,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex private lazy val parquetFilters = new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp, conf.parquetFilterPushDownDecimal, conf.parquetFilterPushDownStringStartWith, - conf.parquetFilterPushDownInFilterThreshold) + conf.parquetFilterPushDownInFilterThreshold, conf.caseSensitiveAnalysis) override def beforeEach(): Unit = { super.beforeEach() @@ -1021,6 +1022,118 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } } + + test("SPARK-25207: Case-insensitive field resolution for pushdown when reading parquet") { + def createParquetFilter(caseSensitive: Boolean): ParquetFilters = { + new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp, + conf.parquetFilterPushDownDecimal, conf.parquetFilterPushDownStringStartWith, + conf.parquetFilterPushDownInFilterThreshold, caseSensitive) + } + val caseSensitiveParquetFilters = createParquetFilter(caseSensitive = true) + val caseInsensitiveParquetFilters = createParquetFilter(caseSensitive = false) + + def testCaseInsensitiveResolution( + schema: StructType, + expected: FilterPredicate, + filter: sources.Filter): Unit = { + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + + assertResult(Some(expected)) { + caseInsensitiveParquetFilters.createFilter(parquetSchema, filter) + } + assertResult(None) { + caseSensitiveParquetFilters.createFilter(parquetSchema, filter) + } + } + + val schema = StructType(Seq(StructField("cint", IntegerType))) + + testCaseInsensitiveResolution( + schema, FilterApi.eq(intColumn("cint"), null.asInstanceOf[Integer]), sources.IsNull("CINT")) + + testCaseInsensitiveResolution( + schema, + FilterApi.notEq(intColumn("cint"), null.asInstanceOf[Integer]), + sources.IsNotNull("CINT")) + + testCaseInsensitiveResolution( + schema, FilterApi.eq(intColumn("cint"), 1000: Integer), sources.EqualTo("CINT", 1000)) + + testCaseInsensitiveResolution( + schema, + FilterApi.notEq(intColumn("cint"), 1000: Integer), + sources.Not(sources.EqualTo("CINT", 1000))) + + testCaseInsensitiveResolution( + schema, FilterApi.eq(intColumn("cint"), 1000: Integer), sources.EqualNullSafe("CINT", 1000)) + + testCaseInsensitiveResolution( + schema, + FilterApi.notEq(intColumn("cint"), 1000: Integer), + sources.Not(sources.EqualNullSafe("CINT", 1000))) + + testCaseInsensitiveResolution( + schema, + FilterApi.lt(intColumn("cint"), 1000: Integer), sources.LessThan("CINT", 1000)) + + testCaseInsensitiveResolution( + schema, + FilterApi.ltEq(intColumn("cint"), 1000: Integer), + sources.LessThanOrEqual("CINT", 1000)) + + testCaseInsensitiveResolution( + schema, FilterApi.gt(intColumn("cint"), 1000: Integer), sources.GreaterThan("CINT", 1000)) + + testCaseInsensitiveResolution( + schema, + FilterApi.gtEq(intColumn("cint"), 1000: Integer), + sources.GreaterThanOrEqual("CINT", 1000)) + + testCaseInsensitiveResolution( + schema, + FilterApi.or( + FilterApi.eq(intColumn("cint"), 10: Integer), + FilterApi.eq(intColumn("cint"), 20: Integer)), + sources.In("CINT", Array(10, 20))) + + val dupFieldSchema = StructType( + Seq(StructField("cint", IntegerType), StructField("cINT", IntegerType))) + val dupParquetSchema = new SparkToParquetSchemaConverter(conf).convert(dupFieldSchema) + assertResult(None) { + caseInsensitiveParquetFilters.createFilter( + dupParquetSchema, sources.EqualTo("CINT", 1000)) + } + } + + test("SPARK-25207: exception when duplicate fields in case-insensitive mode") { + withTempPath { dir => + val count = 10 + val tableName = "spark_25207" + val tableDir = dir.getAbsoluteFile + "/table" + withTable(tableName) { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + spark.range(count).selectExpr("id as A", "id as B", "id as b") + .write.mode("overwrite").parquet(tableDir) + } + sql( + s""" + |CREATE TABLE $tableName (A LONG, B LONG) USING PARQUET LOCATION '$tableDir' + """.stripMargin) + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val e = intercept[SparkException] { + sql(s"select a from $tableName where b > 0").collect() + } + assert(e.getCause.isInstanceOf[RuntimeException] && e.getCause.getMessage.contains( + """Found duplicate field(s) "B": [B, b] in case-insensitive mode""")) + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + checkAnswer(sql(s"select A from $tableName where B > 0"), (1 until count).map(Row(_))) + } + } + } + } } class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] { From 339859c4e4b27726ba8ce9a64451a981ef74de0c Mon Sep 17 00:00:00 2001 From: huangtengfei02 Date: Fri, 31 Aug 2018 09:06:38 -0500 Subject: [PATCH 1523/2461] [SPARK-25261][MINOR][DOC] update the description for spark.executor|driver.memory in configuration.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? As described in [SPARK-25261](https://issues.apache.org/jira/projects/SPARK/issues/SPARK-25261),the unit of spark.executor.memory and spark.driver.memory is parsed as bytes in some cases if no unit specified, while in https://spark.apache.org/docs/latest/configuration.html#application-properties, they are descibed as MiB, which may lead to some misunderstandings. ## How was this patch tested? N/A Closes #22252 from ivoson/branch-correct-configuration. Lead-authored-by: huangtengfei02 Co-authored-by: Huang Tengfei Signed-off-by: Sean Owen --- docs/configuration.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index b5ff426936e59..f344bcd20087d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -152,8 +152,9 @@ of the most common options to set are: spark.driver.memory 1g - Amount of memory to use for the driver process, i.e. where SparkContext is initialized, in MiB - unless otherwise specified (e.g. 1g, 2g). + Amount of memory to use for the driver process, i.e. where SparkContext is initialized, in the + same format as JVM memory strings with a size unit suffix ("k", "m", "g" or "t") + (e.g. 512m, 2g).
      Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. @@ -175,8 +176,8 @@ of the most common options to set are: spark.executor.memory 1g - Amount of memory to use per executor process, in MiB unless otherwise specified. - (e.g. 2g, 8g). + Amount of memory to use per executor process, in the same format as JVM memory strings with + a size unit suffix ("k", "m", "g" or "t") (e.g. 512m, 2g). From 7fc8881b0fbc3d85a524e0454fa89925e92c4fa4 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 31 Aug 2018 08:47:20 -0700 Subject: [PATCH 1524/2461] [SPARK-25296][SQL][TEST] Create ExplainSuite ## What changes were proposed in this pull request? Move the output verification of Explain test cases to a new suite ExplainSuite. ## How was this patch tested? N/A Closes #22300 from gatorsmile/test3200. Authored-by: Xiao Li Signed-off-by: Xiao Li --- .../org/apache/spark/sql/DataFrameSuite.scala | 9 --- .../apache/spark/sql/DatasetCacheSuite.scala | 11 ---- .../org/apache/spark/sql/DatasetSuite.scala | 10 ---- .../org/apache/spark/sql/ExplainSuite.scala | 58 +++++++++++++++++++ 4 files changed, 58 insertions(+), 30 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 6f5c73074313c..d43fcf3c6f5de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2553,13 +2553,4 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } - test("SPARK-23034 show rdd names in RDD scan nodes") { - val rddWithName = spark.sparkContext.parallelize(Row(1, "abc") :: Nil).setName("testRdd") - val df2 = spark.createDataFrame(rddWithName, StructType.fromDDL("c0 int, c1 string")) - val output2 = new java.io.ByteArrayOutputStream() - Console.withOut(output2) { - df2.explain(extended = false) - } - assert(output2.toString.contains("Scan ExistingRDD testRdd")) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 44177e36caa01..5c6a021d5b767 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -206,15 +206,4 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits // first time use, load cache checkDataset(df5, Row(10)) } - - test("SPARK-24850 InMemoryRelation string representation does not include cached plan") { - val df = Seq(1).toDF("a").cache() - val outputStream = new java.io.ByteArrayOutputStream() - Console.withOut(outputStream) { - df.explain(false) - } - assert(outputStream.toString.replaceAll("#\\d+", "#x").contains( - "InMemoryRelation [a#x], StorageLevel(disk, memory, deserialized, 1 replicas)" - )) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 6069f28d185e8..cf24eba128012 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1498,16 +1498,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { df.where($"city".contains(new java.lang.Character('A'))), Seq(Row("Amsterdam"))) } - - test("SPARK-23034 show rdd names in RDD scan nodes") { - val rddWithName = spark.sparkContext.parallelize(SingleData(1) :: Nil).setName("testRdd") - val df = spark.createDataFrame(rddWithName) - val output = new java.io.ByteArrayOutputStream() - Console.withOut(output) { - df.explain(extended = false) - } - assert(output.toString.contains("Scan testRdd")) - } } case class TestDataUnion(x: Int, y: Int, z: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala new file mode 100644 index 0000000000000..56d300e30a58e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType + +class ExplainSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + /** + * Runs the plan and makes sure the plans contains all of the keywords. + */ + private def checkKeywordsExistsInExplain(df: DataFrame, keywords: String*): Unit = { + val output = new java.io.ByteArrayOutputStream() + Console.withOut(output) { + df.explain(extended = false) + } + for (key <- keywords) { + assert(output.toString.contains(key)) + } + } + + test("SPARK-23034 show rdd names in RDD scan nodes (Dataset)") { + val rddWithName = spark.sparkContext.parallelize(Row(1, "abc") :: Nil).setName("testRdd") + val df = spark.createDataFrame(rddWithName, StructType.fromDDL("c0 int, c1 string")) + checkKeywordsExistsInExplain(df, keywords = "Scan ExistingRDD testRdd") + } + + test("SPARK-23034 show rdd names in RDD scan nodes (DataFrame)") { + val rddWithName = spark.sparkContext.parallelize(ExplainSingleData(1) :: Nil).setName("testRdd") + val df = spark.createDataFrame(rddWithName) + checkKeywordsExistsInExplain(df, keywords = "Scan testRdd") + } + + test("SPARK-24850 InMemoryRelation string representation does not include cached plan") { + val df = Seq(1).toDF("a").cache() + checkKeywordsExistsInExplain(df, + keywords = "InMemoryRelation", "StorageLevel(disk, memory, deserialized, 1 replicas)") + } +} + +case class ExplainSingleData(id: Int) From 32da87dfa451fff677ed9316f740be2abdbff6a4 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Fri, 31 Aug 2018 10:43:30 -0700 Subject: [PATCH 1525/2461] [SPARK-25286][CORE] Removing the dangerous parmap ## What changes were proposed in this pull request? I propose to remove one of `parmap` methods which accepts an execution context as a parameter. The method should be removed to eliminate any deadlocks that can occur if `parmap` is called recursively on thread pools restricted by size. Closes #22292 from MaxGekk/remove-overloaded-parmap. Authored-by: Maxim Gekk Signed-off-by: Xiao Li --- .../scala/org/apache/spark/rdd/UnionRDD.scala | 17 +++++----- .../org/apache/spark/util/ThreadUtils.scala | 32 +++---------------- .../util/FileBasedWriteAheadLog.scala | 5 +-- 3 files changed, 16 insertions(+), 38 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 4b6f73235a57a..60e383afadf1c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -20,13 +20,12 @@ package org.apache.spark.rdd import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext +import scala.collection.parallel.ForkJoinTaskSupport import scala.concurrent.forkjoin.ForkJoinPool import scala.reflect.ClassTag import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.ThreadUtils.parmap import org.apache.spark.util.Utils /** @@ -60,7 +59,8 @@ private[spark] class UnionPartition[T: ClassTag]( } object UnionRDD { - private[spark] lazy val threadPool = new ForkJoinPool(8) + private[spark] lazy val partitionEvalTaskSupport = + new ForkJoinTaskSupport(new ForkJoinPool(8)) } @DeveloperApi @@ -74,13 +74,14 @@ class UnionRDD[T: ClassTag]( rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10) override def getPartitions: Array[Partition] = { - val partitionLengths = if (isPartitionListingParallel) { - implicit val ec = ExecutionContext.fromExecutor(UnionRDD.threadPool) - parmap(rdds)(_.partitions.length) + val parRDDs = if (isPartitionListingParallel) { + val parArray = rdds.par + parArray.tasksupport = UnionRDD.partitionEvalTaskSupport + parArray } else { - rdds.map(_.partitions.length) + rdds } - val array = new Array[Partition](partitionLengths.sum) + val array = new Array[Partition](parRDDs.map(_.partitions.length).seq.sum) var pos = 0 for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) { array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index) diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index f0e5addbe5b56..cb0c20541d0d7 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -284,36 +284,12 @@ private[spark] object ThreadUtils { try { implicit val ec = ExecutionContext.fromExecutor(pool) - parmap(in)(f) + val futures = in.map(x => Future(f(x))) + val futureSeq = Future.sequence(futures) + + awaitResult(futureSeq, Duration.Inf) } finally { pool.shutdownNow() } } - - /** - * Transforms input collection by applying the given function to each element in parallel fashion. - * Comparing to the map() method of Scala parallel collections, this method can be interrupted - * at any time. This is useful on canceling of task execution, for example. - * - * @param in - the input collection which should be transformed in parallel. - * @param f - the lambda function will be applied to each element of `in`. - * @param ec - an execution context for parallel applying of the given function `f`. - * @tparam I - the type of elements in the input collection. - * @tparam O - the type of elements in resulted collection. - * @return new collection in which each element was given from the input collection `in` by - * applying the lambda function `f`. - */ - def parmap[I, O, Col[X] <: TraversableLike[X, Col[X]]] - (in: Col[I]) - (f: I => O) - (implicit - cbf: CanBuildFrom[Col[I], Future[O], Col[Future[O]]], // For in.map - cbf2: CanBuildFrom[Col[Future[O]], O, Col[O]], // for Future.sequence - ec: ExecutionContext - ): Col[O] = { - val futures = in.map(x => Future(f(x))) - val futureSeq = Future.sequence(futures) - - awaitResult(futureSeq, Duration.Inf) - } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index bba071e80c0e4..f0161e1465c29 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -312,10 +312,11 @@ private[streaming] object FileBasedWriteAheadLog { handler: I => Iterator[O]): Iterator[O] = { val taskSupport = new ExecutionContextTaskSupport(executionContext) val groupSize = taskSupport.parallelismLevel.max(8) - implicit val ec = executionContext source.grouped(groupSize).flatMap { group => - ThreadUtils.parmap(group)(handler) + val parallelCollection = group.par + parallelCollection.tasksupport = taskSupport + parallelCollection.map(handler) }.flatten } } From e1d72f2c07ecd6f1880299e9373daa21cb032017 Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Fri, 31 Aug 2018 15:46:45 -0700 Subject: [PATCH 1526/2461] [SPARK-25264][K8S] Fix comma-delineated arguments passed into PythonRunner and RRunner ## What changes were proposed in this pull request? Fixes the issue brought up in https://github.com/GoogleCloudPlatform/spark-on-k8s-operator/issues/273 where the arguments were being comma-delineated, which was incorrect wrt to the PythonRunner and RRunner. ## How was this patch tested? Modified unit test to test this change. Author: Ilan Filonenko Closes #22257 from ifilonenko/SPARK-25264. --- .../k8s/features/bindings/PythonDriverFeatureStep.scala | 3 ++- .../deploy/k8s/features/bindings/RDriverFeatureStep.scala | 3 ++- .../k8s/features/bindings/PythonDriverFeatureStepSuite.scala | 4 ++-- .../k8s/features/bindings/RDriverFeatureStepSuite.scala | 4 ++-- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala index c20bcac1f8987..406944a953382 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala @@ -30,11 +30,12 @@ private[spark] class PythonDriverFeatureStep( override def configurePod(pod: SparkPod): SparkPod = { val roleConf = kubernetesConf.roleSpecificConf require(roleConf.mainAppResource.isDefined, "PySpark Main Resource must be defined") + // Delineation is done by " " because that is input into PythonRunner val maybePythonArgs = Option(roleConf.appArgs).filter(_.nonEmpty).map( pyArgs => new EnvVarBuilder() .withName(ENV_PYSPARK_ARGS) - .withValue(pyArgs.mkString(",")) + .withValue(pyArgs.mkString(" ")) .build()) val maybePythonFiles = kubernetesConf.pyFiles().map( // Dilineation by ":" is to append the PySpark Files to the PYTHONPATH diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStep.scala index b33b86e02ea6f..11b09b399618b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStep.scala @@ -30,11 +30,12 @@ private[spark] class RDriverFeatureStep( override def configurePod(pod: SparkPod): SparkPod = { val roleConf = kubernetesConf.roleSpecificConf require(roleConf.mainAppResource.isDefined, "R Main Resource must be defined") + // Delineation is done by " " because that is input into RRunner val maybeRArgs = Option(roleConf.appArgs).filter(_.nonEmpty).map( rArgs => new EnvVarBuilder() .withName(ENV_R_ARGS) - .withValue(rArgs.mkString(",")) + .withValue(rArgs.mkString(" ")) .build()) val envSeq = Seq(new EnvVarBuilder() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala index a5dac6869327d..c14af1d3b0f01 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala @@ -44,7 +44,7 @@ class PythonDriverFeatureStepSuite extends SparkFunSuite { Some(PythonMainAppResource("local:///main.py")), "test-app", "python-runner", - Seq("5 7")), + Seq("5", "7", "9")), appResourceNamePrefix = "", appId = "", roleLabels = Map.empty, @@ -66,7 +66,7 @@ class PythonDriverFeatureStepSuite extends SparkFunSuite { .toMap assert(envs(ENV_PYSPARK_PRIMARY) === expectedMainResource) assert(envs(ENV_PYSPARK_FILES) === expectedPySparkFiles) - assert(envs(ENV_PYSPARK_ARGS) === "5 7") + assert(envs(ENV_PYSPARK_ARGS) === "5 7 9") assert(envs(ENV_PYSPARK_MAJOR_PYTHON_VERSION) === "2") } test("Python Step testing empty pyfiles") { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStepSuite.scala index 8fdf91ef638f2..ace0faa8629c3 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStepSuite.scala @@ -38,7 +38,7 @@ class RDriverFeatureStepSuite extends SparkFunSuite { Some(RMainAppResource(mainResource)), "test-app", "r-runner", - Seq("5 7")), + Seq("5", "7", "9")), appResourceNamePrefix = "", appId = "", roleLabels = Map.empty, @@ -58,6 +58,6 @@ class RDriverFeatureStepSuite extends SparkFunSuite { .map(env => (env.getName, env.getValue)) .toMap assert(envs(ENV_R_PRIMARY) === expectedMainResource) - assert(envs(ENV_R_ARGS) === "5 7") + assert(envs(ENV_R_ARGS) === "5 7 9") } } From c5583fdcd2289559ad98371475eb7288ced9b148 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 1 Sep 2018 12:19:19 +0900 Subject: [PATCH 1527/2461] [SPARK-23466][SQL] Remove redundant null checks in generated Java code by GenerateUnsafeProjection ## What changes were proposed in this pull request? This PR works for one of TODOs in `GenerateUnsafeProjection` "if the nullability of field is correct, we can use it to save null check" to simplify generated code. When `nullable=false` in `DataType`, `GenerateUnsafeProjection` removed code for null checks in the generated Java code. ## How was this patch tested? Added new test cases into `GenerateUnsafeProjectionSuite` Closes #20637 from kiszk/SPARK-23466. Authored-by: Kazuaki Ishizaki Signed-off-by: Takuya UESHIN --- .../codegen/GenerateUnsafeProjection.scala | 77 +++++++++++-------- .../expressions/JsonExpressionsSuite.scala | 2 +- .../GenerateUnsafeProjectionSuite.scala | 71 ++++++++++++++++- 3 files changed, 117 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 998a675eecc62..0ecd0de8d8203 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -32,6 +32,8 @@ import org.apache.spark.sql.types._ */ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] { + case class Schema(dataType: DataType, nullable: Boolean) + /** Returns true iff we support this data type. */ def canSupport(dataType: DataType): Boolean = UserDefinedType.sqlType(dataType) match { case NullType => true @@ -43,19 +45,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => false } - // TODO: if the nullability of field is correct, we can use it to save null check. private def writeStructToBuffer( ctx: CodegenContext, input: String, index: String, - fieldTypes: Seq[DataType], + schemas: Seq[Schema], rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") - val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - ExprCode( - JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)"), - JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt)) + val fieldEvals = schemas.zipWithIndex.map { case (Schema(dt, nullable), i) => + val isNull = if (nullable) { + JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)") + } else { + FalseLiteral + } + ExprCode(isNull, JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt)) } val rowWriterClass = classOf[UnsafeRowWriter].getName @@ -70,7 +74,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | // Remember the current cursor so that we can calculate how many bytes are | // written later. | final int $previousCursor = $rowWriter.cursor(); - | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)} + | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, schemas, structRowWriter)} | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); |} """.stripMargin @@ -80,7 +84,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, row: String, inputs: Seq[ExprCode], - inputTypes: Seq[DataType], + schemas: Seq[Schema], rowWriter: String, isTopLevel: Boolean = false): String = { val resetWriter = if (isTopLevel) { @@ -98,8 +102,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.resetRowWriter();" } - val writeFields = inputs.zip(inputTypes).zipWithIndex.map { - case ((input, dataType), index) => + val writeFields = inputs.zip(schemas).zipWithIndex.map { + case ((input, Schema(dataType, nullable)), index) => val dt = UserDefinedType.sqlType(dataType) val setNull = dt match { @@ -110,7 +114,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter) - if (input.isNull == FalseLiteral) { + if (!nullable) { s""" |${input.code} |${writeField.trim} @@ -143,11 +147,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro """.stripMargin } - // TODO: if the nullability of array element is correct, we can use it to save null check. private def writeArrayToBuffer( ctx: CodegenContext, input: String, elementType: DataType, + containsNull: Boolean, rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") @@ -170,6 +174,18 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val element = CodeGenerator.getValue(tmpInput, et, index) + val elementAssignment = if (containsNull) { + s""" + |if ($tmpInput.isNullAt($index)) { + | $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); + |} else { + | ${writeElement(ctx, element, index, et, arrayWriter)} + |} + """.stripMargin + } else { + writeElement(ctx, element, index, et, arrayWriter) + } + s""" |final ArrayData $tmpInput = $input; |if ($tmpInput instanceof UnsafeArrayData) { @@ -179,23 +195,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | $arrayWriter.initialize($numElements); | | for (int $index = 0; $index < $numElements; $index++) { - | if ($tmpInput.isNullAt($index)) { - | $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); - | } else { - | ${writeElement(ctx, element, index, et, arrayWriter)} - | } + | $elementAssignment | } |} """.stripMargin } - // TODO: if the nullability of value element is correct, we can use it to save null check. private def writeMapToBuffer( ctx: CodegenContext, input: String, index: String, keyType: DataType, valueType: DataType, + valueContainsNull: Boolean, rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") @@ -203,6 +215,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val previousCursor = ctx.freshName("previousCursor") // Writes out unsafe map according to the format described in `UnsafeMapData`. + val keyArray = writeArrayToBuffer( + ctx, s"$tmpInput.keyArray()", keyType, false, rowWriter) + val valueArray = writeArrayToBuffer( + ctx, s"$tmpInput.valueArray()", valueType, valueContainsNull, rowWriter) + s""" |final MapData $tmpInput = $input; |if ($tmpInput instanceof UnsafeMapData) { @@ -219,7 +236,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | // Remember the current cursor so that we can write numBytes of key array later. | final int $tmpCursor = $rowWriter.cursor(); | - | ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} + | $keyArray | | // Write the numBytes of key array into the first 8 bytes. | Platform.putLong( @@ -227,7 +244,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | $tmpCursor - 8, | $rowWriter.cursor() - $tmpCursor); | - | ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} + | $valueArray | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); |} """.stripMargin @@ -240,20 +257,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro dt: DataType, writer: String): String = dt match { case t: StructType => - writeStructToBuffer(ctx, input, index, t.map(_.dataType), writer) + writeStructToBuffer( + ctx, input, index, t.map(e => Schema(e.dataType, e.nullable)), writer) - case ArrayType(et, _) => + case ArrayType(et, en) => val previousCursor = ctx.freshName("previousCursor") s""" |// Remember the current cursor so that we can calculate how many bytes are |// written later. |final int $previousCursor = $writer.cursor(); - |${writeArrayToBuffer(ctx, input, et, writer)} + |${writeArrayToBuffer(ctx, input, et, en, writer)} |$writer.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """.stripMargin - case MapType(kt, vt, _) => - writeMapToBuffer(ctx, input, index, kt, vt, writer) + case MapType(kt, vt, vn) => + writeMapToBuffer(ctx, input, index, kt, vt, vn, writer) case DecimalType.Fixed(precision, scale) => s"$writer.write($index, $input, $precision, $scale);" @@ -268,12 +286,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro expressions: Seq[Expression], useSubexprElimination: Boolean = false): ExprCode = { val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) - val exprTypes = expressions.map(_.dataType) + val exprSchemas = expressions.map(e => Schema(e.dataType, e.nullable)) - val numVarLenFields = exprTypes.count { - case dt if UnsafeRow.isFixedLength(dt) => false + val numVarLenFields = exprSchemas.count { + case Schema(dt, _) => !UnsafeRow.isFixedLength(dt) // TODO: consider large decimal and interval type - case _ => true } val rowWriterClass = classOf[UnsafeRowWriter].getName @@ -284,7 +301,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val evalSubexpr = ctx.subexprFunctions.mkString("\n") val writeExpressions = writeExpressionsToBuffer( - ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true) + ctx, ctx.INPUT_ROW, exprEvals, exprSchemas, rowWriter, isTopLevel = true) val code = code""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 04f1c8ce0b83d..0e9c8abec33e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -694,7 +694,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with |""".stripMargin val jsonSchema = new StructType() .add("a", LongType, nullable = false) - .add("b", StringType, nullable = false) + .add("b", StringType, nullable = !forceJsonNullableSchema) .add("c", StringType, nullable = false) val output = InternalRow(1L, null, UTF8String.fromString("foo")) val expr = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala index e9d21f8a8ebcd..01aa3579aea98 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.BoundReference -import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} -import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, MapData} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class GenerateUnsafeProjectionSuite extends SparkFunSuite { @@ -33,6 +33,41 @@ class GenerateUnsafeProjectionSuite extends SparkFunSuite { assert(!result.isNullAt(0)) assert(result.getStruct(0, 1).isNullAt(0)) } + + test("Test unsafe projection for array/map/struct") { + val dataType1 = ArrayType(StringType, false) + val exprs1 = BoundReference(0, dataType1, nullable = false) :: Nil + val projection1 = GenerateUnsafeProjection.generate(exprs1) + val result1 = projection1.apply(AlwaysNonNull) + assert(!result1.isNullAt(0)) + assert(!result1.getArray(0).isNullAt(0)) + assert(!result1.getArray(0).isNullAt(1)) + assert(!result1.getArray(0).isNullAt(2)) + + val dataType2 = MapType(StringType, StringType, false) + val exprs2 = BoundReference(0, dataType2, nullable = false) :: Nil + val projection2 = GenerateUnsafeProjection.generate(exprs2) + val result2 = projection2.apply(AlwaysNonNull) + assert(!result2.isNullAt(0)) + assert(!result2.getMap(0).keyArray.isNullAt(0)) + assert(!result2.getMap(0).keyArray.isNullAt(1)) + assert(!result2.getMap(0).keyArray.isNullAt(2)) + assert(!result2.getMap(0).valueArray.isNullAt(0)) + assert(!result2.getMap(0).valueArray.isNullAt(1)) + assert(!result2.getMap(0).valueArray.isNullAt(2)) + + val dataType3 = (new StructType) + .add("a", StringType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + val exprs3 = BoundReference(0, dataType3, nullable = false) :: Nil + val projection3 = GenerateUnsafeProjection.generate(exprs3) + val result3 = projection3.apply(InternalRow(AlwaysNonNull)) + assert(!result3.isNullAt(0)) + assert(!result3.getStruct(0, 1).isNullAt(0)) + assert(!result3.getStruct(0, 2).isNullAt(0)) + assert(!result3.getStruct(0, 3).isNullAt(0)) + } } object AlwaysNull extends InternalRow { @@ -59,3 +94,35 @@ object AlwaysNull extends InternalRow { override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported private def notSupported: Nothing = throw new UnsupportedOperationException } + +object AlwaysNonNull extends InternalRow { + private def stringToUTF8Array(stringArray: Array[String]): ArrayData = { + val utf8Array = stringArray.map(s => UTF8String.fromString(s)).toArray + ArrayData.toArrayData(utf8Array) + } + override def numFields: Int = 1 + override def setNullAt(i: Int): Unit = {} + override def copy(): InternalRow = this + override def anyNull: Boolean = notSupported + override def isNullAt(ordinal: Int): Boolean = notSupported + override def update(i: Int, value: Any): Unit = notSupported + override def getBoolean(ordinal: Int): Boolean = notSupported + override def getByte(ordinal: Int): Byte = notSupported + override def getShort(ordinal: Int): Short = notSupported + override def getInt(ordinal: Int): Int = notSupported + override def getLong(ordinal: Int): Long = notSupported + override def getFloat(ordinal: Int): Float = notSupported + override def getDouble(ordinal: Int): Double = notSupported + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = notSupported + override def getUTF8String(ordinal: Int): UTF8String = UTF8String.fromString("test") + override def getBinary(ordinal: Int): Array[Byte] = notSupported + override def getInterval(ordinal: Int): CalendarInterval = notSupported + override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported + override def getArray(ordinal: Int): ArrayData = stringToUTF8Array(Array("1", "2", "3")) + val keyArray = stringToUTF8Array(Array("1", "2", "3")) + val valueArray = stringToUTF8Array(Array("a", "b", "c")) + override def getMap(ordinal: Int): MapData = new ArrayBasedMapData(keyArray, valueArray) + override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported + private def notSupported: Nothing = throw new UnsupportedOperationException + +} From 7c36ee46d974021474d8098f87f70440a10319ee Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 1 Sep 2018 16:25:29 +0800 Subject: [PATCH 1528/2461] [SPARK-25290][CORE][TEST] Reduce the size of acquired arrays to avoid OOM error ## What changes were proposed in this pull request? `BytesToBytesMapOnHeapSuite`.`randomizedStressTest` caused `OutOfMemoryError` on several test runs. Seems better to reduce memory usage in this test. ## How was this patch tested? Unit tests. Closes #22297 from viirya/SPARK-25290. Authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- .../apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 03cec8ed81b72..53a233f698c7a 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -379,7 +379,7 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { @Test public void randomizedStressTest() { - final int size = 65536; + final int size = 32768; // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays // into ByteBuffers in order to use them as keys here. final Map expected = new HashMap<>(); @@ -388,7 +388,7 @@ public void randomizedStressTest() { // Fill the map to 90% full so that we can trigger probing for (int i = 0; i < size * 0.9; i++) { final byte[] key = getRandomByteArray(rand.nextInt(256) + 1); - final byte[] value = getRandomByteArray(rand.nextInt(512) + 1); + final byte[] value = getRandomByteArray(rand.nextInt(256) + 1); if (!expected.containsKey(ByteBuffer.wrap(key))) { expected.put(ByteBuffer.wrap(key), value); final BytesToBytesMap.Location loc = map.lookup( From 6ad8d4c375772c0c907c25837de762b5b9266a8e Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 1 Sep 2018 08:41:07 -0500 Subject: [PATCH 1529/2461] [SPARK-25289][ML] Avoid exception in ChiSqSelector with FDR when no feature is selected ## What changes were proposed in this pull request? Currently, when FDR is used for `ChiSqSelector` and no feature is selected an exception is thrown because the max operation fails. The PR fixes the problem by handling this case and returning an empty array in that case, as sklearn (which was the reference for the initial implementation of FDR) does. ## How was this patch tested? added UT Closes #22303 from mgaido91/SPARK-25289. Authored-by: Marco Gaido Signed-off-by: Sean Owen --- .../apache/spark/mllib/feature/ChiSqSelector.scala | 12 ++++++++---- .../apache/spark/ml/feature/ChiSqSelectorSuite.scala | 11 +++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index f923be871f438..aa78e91b679ac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -28,6 +28,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.mllib.stat.test.ChiSqTestResult import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} @@ -272,13 +273,16 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { // https://en.wikipedia.org/wiki/False_discovery_rate#Benjamini.E2.80.93Hochberg_procedure val tempRes = chiSqTestResult .sortBy { case (res, _) => res.pValue } - val maxIndex = tempRes + val selected = tempRes .zipWithIndex .filter { case ((res, _), index) => res.pValue <= fdr * (index + 1) / chiSqTestResult.length } - .map { case (_, index) => index } - .max - tempRes.take(maxIndex + 1) + if (selected.isEmpty) { + Array.empty[(ChiSqTestResult, Int)] + } else { + val maxIndex = selected.map(_._2).max + tempRes.take(maxIndex + 1) + } case ChiSqSelector.FWE => chiSqTestResult .filter { case (res, _) => res.pValue < fwe / chiSqTestResult.length } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index c843df9f33e3e..80499e79e3bd6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -163,6 +163,17 @@ class ChiSqSelectorSuite extends MLTest with DefaultReadWriteTest { } } + test("SPARK-25289: ChiSqSelector should not fail when selecting no features with FDR") { + val labeledPoints = (0 to 1).map { n => + val v = Vectors.dense((1 to 3).map(_ => n * 1.0).toArray) + (n.toDouble, v) + } + val inputDF = spark.createDataFrame(labeledPoints).toDF("label", "features") + val selector = new ChiSqSelector().setSelectorType("fdr").setFdr(0.05) + val model = selector.fit(inputDF) + assert(model.selectedFeatures.isEmpty) + } + private def testSelector(selector: ChiSqSelector, data: Dataset[_]): ChiSqSelectorModel = { val selectorModel = selector.fit(data) testTransformer[(Double, Vector, Vector)](data.toDF(), selectorModel, From a3dccd24c2e932b90e647e678f351f5b5568305b Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 1 Sep 2018 18:07:58 -0500 Subject: [PATCH 1530/2461] [SPARK-10697][ML] Add lift to Association rules ## What changes were proposed in this pull request? The PR adds the lift measure to Association rules. ## How was this patch tested? existing and modified UTs Closes #22236 from mgaido91/SPARK-10697. Authored-by: Marco Gaido Signed-off-by: Sean Owen --- R/pkg/R/mllib_fpm.R | 5 +- R/pkg/tests/fulltests/test_mllib_fpm.R | 3 +- .../org/apache/spark/ml/fpm/FPGrowth.scala | 61 ++++++++++++++----- .../spark/mllib/fpm/AssociationRules.scala | 37 +++++++++-- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 25 +++++--- .../apache/spark/ml/fpm/FPGrowthSuite.scala | 6 +- project/MimaExcludes.scala | 4 ++ python/pyspark/ml/fpm.py | 3 +- python/pyspark/ml/tests.py | 4 +- 9 files changed, 108 insertions(+), 40 deletions(-) diff --git a/R/pkg/R/mllib_fpm.R b/R/pkg/R/mllib_fpm.R index e2394906d8012..4ad34fe82328f 100644 --- a/R/pkg/R/mllib_fpm.R +++ b/R/pkg/R/mllib_fpm.R @@ -116,10 +116,11 @@ setMethod("spark.freqItemsets", signature(object = "FPGrowthModel"), # Get association rules. #' @return A \code{SparkDataFrame} with association rules. -#' The \code{SparkDataFrame} contains three columns: +#' The \code{SparkDataFrame} contains four columns: #' \code{antecedent} (an array of the same type as the input column), #' \code{consequent} (an array of the same type as the input column), -#' and \code{condfidence} (confidence). +#' \code{condfidence} (confidence for the rule) +#' and \code{lift} (lift for the rule) #' @rdname spark.fpGrowth #' @aliases associationRules,FPGrowthModel-method #' @note spark.associationRules(FPGrowthModel) since 2.2.0 diff --git a/R/pkg/tests/fulltests/test_mllib_fpm.R b/R/pkg/tests/fulltests/test_mllib_fpm.R index 69dda52f0c279..d80f66a25de1c 100644 --- a/R/pkg/tests/fulltests/test_mllib_fpm.R +++ b/R/pkg/tests/fulltests/test_mllib_fpm.R @@ -44,7 +44,8 @@ test_that("spark.fpGrowth", { expected_association_rules <- data.frame( antecedent = I(list(list("2"), list("3"))), consequent = I(list(list("1"), list("1"))), - confidence = c(1, 1) + confidence = c(1, 1), + lift = c(1, 1) ) expect_equivalent(expected_association_rules, collect(spark.associationRules(model))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 85c483c387ad8..840a89b76d26b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -20,6 +20,8 @@ package org.apache.spark.ml.fpm import scala.reflect.ClassTag import org.apache.hadoop.fs.Path +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} @@ -34,6 +36,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.VersionUtils /** * Common params for FPGrowth and FPGrowthModel @@ -175,7 +178,8 @@ class FPGrowth @Since("2.2.0") ( if (handlePersistence) { items.persist(StorageLevel.MEMORY_AND_DISK) } - + val inputRowCount = items.count() + instr.logNumExamples(inputRowCount) val parentModel = mllibFP.run(items) val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq)) val schema = StructType(Seq( @@ -187,7 +191,8 @@ class FPGrowth @Since("2.2.0") ( items.unpersist() } - copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) + copyValues(new FPGrowthModel(uid, frequentItems, parentModel.itemSupport, inputRowCount)) + .setParent(this) } @Since("2.2.0") @@ -217,7 +222,9 @@ object FPGrowth extends DefaultParamsReadable[FPGrowth] { @Experimental class FPGrowthModel private[ml] ( @Since("2.2.0") override val uid: String, - @Since("2.2.0") @transient val freqItemsets: DataFrame) + @Since("2.2.0") @transient val freqItemsets: DataFrame, + private val itemSupport: scala.collection.Map[Any, Double], + private val numTrainingRecords: Long) extends Model[FPGrowthModel] with FPGrowthParams with MLWritable { /** @group setParam */ @@ -241,9 +248,9 @@ class FPGrowthModel private[ml] ( @transient private var _cachedRules: DataFrame = _ /** - * Get association rules fitted using the minConfidence. Returns a dataframe - * with three fields, "antecedent", "consequent" and "confidence", where "antecedent" and - * "consequent" are Array[T] and "confidence" is Double. + * Get association rules fitted using the minConfidence. Returns a dataframe with four fields, + * "antecedent", "consequent", "confidence" and "lift", where "antecedent" and "consequent" are + * Array[T], whereas "confidence" and "lift" are Double. */ @Since("2.2.0") @transient def associationRules: DataFrame = { @@ -251,7 +258,7 @@ class FPGrowthModel private[ml] ( _cachedRules } else { _cachedRules = AssociationRules - .getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence)) + .getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence), itemSupport) _cachedMinConf = $(minConfidence) _cachedRules } @@ -301,7 +308,7 @@ class FPGrowthModel private[ml] ( @Since("2.2.0") override def copy(extra: ParamMap): FPGrowthModel = { - val copied = new FPGrowthModel(uid, freqItemsets) + val copied = new FPGrowthModel(uid, freqItemsets, itemSupport, numTrainingRecords) copyValues(copied, extra).setParent(this.parent) } @@ -323,7 +330,8 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] { class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) + val extraMetadata: JObject = Map("numTrainingRecords" -> instance.numTrainingRecords) + DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata = Some(extraMetadata)) val dataPath = new Path(path, "data").toString instance.freqItemsets.write.parquet(dataPath) } @@ -335,10 +343,28 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] { private val className = classOf[FPGrowthModel].getName override def load(path: String): FPGrowthModel = { + implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion) + val numTrainingRecords = if (major.toInt < 2 || (major.toInt == 2 && minor.toInt < 4)) { + // 2.3 and before don't store the count + 0L + } else { + // 2.4+ + (metadata.metadata \ "numTrainingRecords").extract[Long] + } val dataPath = new Path(path, "data").toString val frequentItems = sparkSession.read.parquet(dataPath) - val model = new FPGrowthModel(metadata.uid, frequentItems) + val itemSupport = if (numTrainingRecords == 0L) { + Map.empty[Any, Double] + } else { + frequentItems.rdd.flatMap { + case Row(items: Seq[_], count: Long) if items.length == 1 => + Some(items.head -> count.toDouble / numTrainingRecords) + case _ => None + }.collectAsMap() + } + val model = new FPGrowthModel(metadata.uid, frequentItems, itemSupport, numTrainingRecords) metadata.getAndSetParams(model) model } @@ -354,27 +380,30 @@ private[fpm] object AssociationRules { * @param itemsCol column name for frequent itemsets * @param freqCol column name for appearance count of the frequent itemsets * @param minConfidence minimum confidence for generating the association rules - * @return a DataFrame("antecedent"[Array], "consequent"[Array], "confidence"[Double]) - * containing the association rules. + * @param itemSupport map containing an item and its support + * @return a DataFrame("antecedent"[Array], "consequent"[Array], "confidence"[Double], + * "lift" [Double]) containing the association rules. */ def getAssociationRulesFromFP[T: ClassTag]( dataset: Dataset[_], itemsCol: String, freqCol: String, - minConfidence: Double): DataFrame = { + minConfidence: Double, + itemSupport: scala.collection.Map[T, Double]): DataFrame = { val freqItemSetRdd = dataset.select(itemsCol, freqCol).rdd .map(row => new FreqItemset(row.getSeq[T](0).toArray, row.getLong(1))) val rows = new MLlibAssociationRules() .setMinConfidence(minConfidence) - .run(freqItemSetRdd) - .map(r => Row(r.antecedent, r.consequent, r.confidence)) + .run(freqItemSetRdd, itemSupport) + .map(r => Row(r.antecedent, r.consequent, r.confidence, r.lift.orNull)) val dt = dataset.schema(itemsCol).dataType val schema = StructType(Seq( StructField("antecedent", dt, nullable = false), StructField("consequent", dt, nullable = false), - StructField("confidence", DoubleType, nullable = false))) + StructField("confidence", DoubleType, nullable = false), + StructField("lift", DoubleType))) val rules = dataset.sparkSession.createDataFrame(rows, schema) rules } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala index acb83ac31affd..43d256bbc46c3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala @@ -56,11 +56,24 @@ class AssociationRules private[fpm] ( /** * Computes the association rules with confidence above `minConfidence`. * @param freqItemsets frequent itemset model obtained from [[FPGrowth]] - * @return a `Set[Rule[Item]]` containing the association rules. + * @return a `RDD[Rule[Item]]` containing the association rules. * */ @Since("1.5.0") def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]]): RDD[Rule[Item]] = { + run(freqItemsets, Map.empty[Item, Double]) + } + + /** + * Computes the association rules with confidence above `minConfidence`. + * @param freqItemsets frequent itemset model obtained from [[FPGrowth]] + * @param itemSupport map containing an item and its support + * @return a `RDD[Rule[Item]]` containing the association rules. The rules will be able to + * compute also the lift metric. + */ + @Since("2.4.0") + def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]], + itemSupport: scala.collection.Map[Item, Double]): RDD[Rule[Item]] = { // For candidate rule X => Y, generate (X, (Y, freq(X union Y))) val candidates = freqItemsets.flatMap { itemset => val items = itemset.items @@ -76,8 +89,13 @@ class AssociationRules private[fpm] ( // Join to get (X, ((Y, freq(X union Y)), freq(X))), generate rules, and filter by confidence candidates.join(freqItemsets.map(x => (x.items.toSeq, x.freq))) .map { case (antecendent, ((consequent, freqUnion), freqAntecedent)) => - new Rule(antecendent.toArray, consequent.toArray, freqUnion, freqAntecedent) - }.filter(_.confidence >= minConfidence) + new Rule(antecendent.toArray, + consequent.toArray, + freqUnion, + freqAntecedent, + // the consequent contains always only one element + itemSupport.get(consequent.head)) + }.filter(_.confidence >= minConfidence) } /** @@ -107,14 +125,21 @@ object AssociationRules { @Since("1.5.0") val antecedent: Array[Item], @Since("1.5.0") val consequent: Array[Item], freqUnion: Double, - freqAntecedent: Double) extends Serializable { + freqAntecedent: Double, + freqConsequent: Option[Double]) extends Serializable { /** * Returns the confidence of the rule. * */ @Since("1.5.0") - def confidence: Double = freqUnion.toDouble / freqAntecedent + def confidence: Double = freqUnion / freqAntecedent + + /** + * Returns the lift of the rule. + */ + @Since("2.4.0") + def lift: Option[Double] = freqConsequent.map(fCons => confidence / fCons) require(antecedent.toSet.intersect(consequent.toSet).isEmpty, { val sharedItems = antecedent.toSet.intersect(consequent.toSet) @@ -142,7 +167,7 @@ object AssociationRules { override def toString: String = { s"${antecedent.mkString("{", ",", "}")} => " + - s"${consequent.mkString("{", ",", "}")}: ${confidence}" + s"${consequent.mkString("{", ",", "}")}: (confidence: $confidence; lift: $lift)" } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 4f2b7e6f0764e..3a1bc35186dc3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -48,9 +48,14 @@ import org.apache.spark.storage.StorageLevel * @tparam Item item type */ @Since("1.3.0") -class FPGrowthModel[Item: ClassTag] @Since("1.3.0") ( - @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]]) +class FPGrowthModel[Item: ClassTag] @Since("2.4.0") ( + @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]], + @Since("2.4.0") val itemSupport: Map[Item, Double]) extends Saveable with Serializable { + + @Since("1.3.0") + def this(freqItemsets: RDD[FreqItemset[Item]]) = this(freqItemsets, Map.empty) + /** * Generates association rules for the `Item`s in [[freqItemsets]]. * @param confidence minimal confidence of the rules produced @@ -58,7 +63,7 @@ class FPGrowthModel[Item: ClassTag] @Since("1.3.0") ( @Since("1.5.0") def generateAssociationRules(confidence: Double): RDD[AssociationRules.Rule[Item]] = { val associationRules = new AssociationRules(confidence) - associationRules.run(freqItemsets) + associationRules.run(freqItemsets, itemSupport) } /** @@ -213,9 +218,12 @@ class FPGrowth private[spark] ( val minCount = math.ceil(minSupport * count).toLong val numParts = if (numPartitions > 0) numPartitions else data.partitions.length val partitioner = new HashPartitioner(numParts) - val freqItems = genFreqItems(data, minCount, partitioner) - val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner) - new FPGrowthModel(freqItemsets) + val freqItemsCount = genFreqItems(data, minCount, partitioner) + val freqItemsets = genFreqItemsets(data, minCount, freqItemsCount.map(_._1), partitioner) + val itemSupport = freqItemsCount.map { + case (item, cnt) => item -> cnt.toDouble / count + }.toMap + new FPGrowthModel(freqItemsets, itemSupport) } /** @@ -231,12 +239,12 @@ class FPGrowth private[spark] ( * Generates frequent items by filtering the input data using minimal support level. * @param minCount minimum count for frequent itemsets * @param partitioner partitioner used to distribute items - * @return array of frequent pattern ordered by their frequencies + * @return array of frequent patterns and their frequencies ordered by their frequencies */ private def genFreqItems[Item: ClassTag]( data: RDD[Array[Item]], minCount: Long, - partitioner: Partitioner): Array[Item] = { + partitioner: Partitioner): Array[(Item, Long)] = { data.flatMap { t => val uniq = t.toSet if (t.length != uniq.size) { @@ -248,7 +256,6 @@ class FPGrowth private[spark] ( .filter(_._2 >= minCount) .collect() .sortBy(-_._2) - .map(_._1) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 87f8b9034dde8..b75526a48371a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -39,9 +39,9 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val model = new FPGrowth().setMinSupport(0.5).fit(data) val generatedRules = model.setMinConfidence(0.5).associationRules val expectedRules = spark.createDataFrame(Seq( - (Array("2"), Array("1"), 1.0), - (Array("1"), Array("2"), 0.75) - )).toDF("antecedent", "consequent", "confidence") + (Array("2"), Array("1"), 1.0, 1.0), + (Array("1"), Array("2"), 0.75, 1.0) + )).toDF("antecedent", "consequent", "confidence", "lift") .withColumn("antecedent", col("antecedent").cast(ArrayType(dt))) .withColumn("consequent", col("consequent").cast(ArrayType(dt))) assert(expectedRules.sort("antecedent").rdd.collect().sameElements( diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 4f250c9943edb..62f8b1af50a6c 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,10 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-10697][ML] Add lift to Association rules + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.fpm.FPGrowthModel.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.fpm.AssociationRules#Rule.this"), + // [SPARK-25044] Address translation of LMF closure primitive args to Object in Scala 2.12 ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.expressions.UserDefinedFunction$"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"), diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index f9394421e0cc4..c2b29b73460ff 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -145,10 +145,11 @@ def freqItemsets(self): @since("2.2.0") def associationRules(self): """ - DataFrame with three columns: + DataFrame with four columns: * `antecedent` - Array of the same type as the input column. * `consequent` - Array of the same type as the input column. * `confidence` - Confidence for the rule (`DoubleType`). + * `lift` - Lift for the rule (`DoubleType`). """ return self._call_java("associationRules") diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 5c87d1de4139b..625d9927f7063 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -2158,8 +2158,8 @@ def test_association_rules(self): fpm = fp.fit(self.data) expected_association_rules = self.spark.createDataFrame( - [([3], [1], 1.0), ([2], [1], 1.0)], - ["antecedent", "consequent", "confidence"] + [([3], [1], 1.0, 1.0), ([2], [1], 1.0, 1.0)], + ["antecedent", "consequent", "confidence", "lift"] ) actual_association_rules = fpm.associationRules From a481794ca9a5edb87982679cd0e95146f668fe78 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 2 Sep 2018 00:06:19 -0700 Subject: [PATCH 1531/2461] [SPARK-25007][R] Add array_intersect/array_except/array_union/shuffle to SparkR ## What changes were proposed in this pull request? Add the R version of array_intersect/array_except/array_union/shuffle ## How was this patch tested? Add test in test_sparkSQL.R Author: Huaxin Gao Closes #22291 from huaxingao/spark-25007. --- R/pkg/NAMESPACE | 4 ++ R/pkg/R/functions.R | 59 ++++++++++++++++++++++++++- R/pkg/R/generics.R | 16 ++++++++ R/pkg/tests/fulltests/test_sparkSQL.R | 19 +++++++++ 4 files changed, 97 insertions(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 0fd08482c4413..96ff389faf4a0 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -204,6 +204,8 @@ exportMethods("%<=>%", "approxQuantile", "array_contains", "array_distinct", + "array_except", + "array_intersect", "array_join", "array_max", "array_min", @@ -212,6 +214,7 @@ exportMethods("%<=>%", "array_repeat", "array_sort", "arrays_overlap", + "array_union", "arrays_zip", "asc", "ascii", @@ -355,6 +358,7 @@ exportMethods("%<=>%", "shiftLeft", "shiftRight", "shiftRightUnsigned", + "shuffle", "sd", "sign", "signum", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 2929a00330c62..d157acc3ca47b 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -208,7 +208,7 @@ NULL #' # Dataframe used throughout this doc #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) -#' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) +#' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1), shuffle(tmp$v1))) #' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1), array_distinct(tmp$v1))) #' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1))) #' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1), array_remove(tmp$v1, 21))) @@ -223,6 +223,8 @@ NULL #' head(select(tmp3, element_at(tmp3$v3, "Valiant"))) #' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp)) #' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5))) +#' head(select(tmp4, array_except(tmp4$v4, tmp4$v5), array_intersect(tmp4$v4, tmp4$v5))) +#' head(select(tmp4, array_union(tmp4$v4, tmp4$v5))) #' head(select(tmp4, arrays_zip(tmp4$v4, tmp4$v5), map_from_arrays(tmp4$v4, tmp4$v5))) #' head(select(tmp, concat(df$mpg, df$cyl, df$hp))) #' tmp5 <- mutate(df, v6 = create_array(df$model, df$model)) @@ -3024,6 +3026,34 @@ setMethod("array_distinct", column(jc) }) +#' @details +#' \code{array_except}: Returns an array of the elements in the first array but not in the second +#' array, without duplicates. The order of elements in the result is not determined. +#' +#' @rdname column_collection_functions +#' @aliases array_except array_except,Column-method +#' @note array_except since 2.4.0 +setMethod("array_except", + signature(x = "Column", y = "Column"), + function(x, y) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_except", x@jc, y@jc) + column(jc) + }) + +#' @details +#' \code{array_intersect}: Returns an array of the elements in the intersection of the given two +#' arrays, without duplicates. +#' +#' @rdname column_collection_functions +#' @aliases array_intersect array_intersect,Column-method +#' @note array_intersect since 2.4.0 +setMethod("array_intersect", + signature(x = "Column", y = "Column"), + function(x, y) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_intersect", x@jc, y@jc) + column(jc) + }) + #' @details #' \code{array_join}: Concatenates the elements of column using the delimiter. #' Null values are replaced with nullReplacement if set, otherwise they are ignored. @@ -3149,6 +3179,20 @@ setMethod("arrays_overlap", column(jc) }) +#' @details +#' \code{array_union}: Returns an array of the elements in the union of the given two arrays, +#' without duplicates. +#' +#' @rdname column_collection_functions +#' @aliases array_union array_union,Column-method +#' @note array_union since 2.4.0 +setMethod("array_union", + signature(x = "Column", y = "Column"), + function(x, y) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_union", x@jc, y@jc) + column(jc) + }) + #' @details #' \code{arrays_zip}: Returns a merged array of structs in which the N-th struct contains all N-th #' values of input arrays. @@ -3167,6 +3211,19 @@ setMethod("arrays_zip", column(jc) }) +#' @details +#' \code{shuffle}: Returns a random permutation of the given array. +#' +#' @rdname column_collection_functions +#' @aliases shuffle shuffle,Column-method +#' @note shuffle since 2.4.0 +setMethod("shuffle", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "shuffle", x@jc) + column(jc) + }) + #' @details #' \code{flatten}: Creates a single array from an array of arrays. #' If a structure of nested arrays is deeper than two levels, only one level of nesting is removed. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index f6f1849787a23..27c1b312d645c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -767,6 +767,14 @@ setGeneric("array_contains", function(x, value) { standardGeneric("array_contain #' @name NULL setGeneric("array_distinct", function(x) { standardGeneric("array_distinct") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_except", function(x, y) { standardGeneric("array_except") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_intersect", function(x, y) { standardGeneric("array_intersect") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_join", function(x, delimiter, ...) { standardGeneric("array_join") }) @@ -799,6 +807,10 @@ setGeneric("array_sort", function(x) { standardGeneric("array_sort") }) #' @name NULL setGeneric("arrays_overlap", function(x, y) { standardGeneric("arrays_overlap") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_union", function(x, y) { standardGeneric("array_union") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("arrays_zip", function(x, ...) { standardGeneric("arrays_zip") }) @@ -1220,6 +1232,10 @@ setGeneric("shiftRight", function(y, x) { standardGeneric("shiftRight") }) #' @name NULL setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUnsigned") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("shuffle", function(x) { standardGeneric("shuffle") }) + #' @rdname column_math_functions #' @name NULL setGeneric("signum", function(x) { standardGeneric("signum") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index e1f3cf339e83f..17e4a970425af 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1598,6 +1598,25 @@ test_that("column functions", { result <- collect(select(df, element_at(df$map, "y")))[[1]] expect_equal(result, 2) + # Test array_except(), array_intersect() and array_union() + df <- createDataFrame(list(list(list(1L, 2L, 3L), list(3L, 1L)), + list(list(1L, 2L), list(3L, 4L)), + list(list(1L, 2L, 3L), list(3L, 4L)))) + result1 <- collect(select(df, array_except(df[[1]], df[[2]])))[[1]] + expect_equal(result1, list(list(2L), list(1L, 2L), list(1L, 2L))) + + result2 <- collect(select(df, array_intersect(df[[1]], df[[2]])))[[1]] + expect_equal(result2, list(list(1L, 3L), list(), list(3L))) + + result3 <- collect(select(df, array_union(df[[1]], df[[2]])))[[1]] + expect_equal(result3, list(list(1L, 2L, 3L), list(1L, 2L, 3L, 4L), list(1L, 2L, 3L, 4L))) + + # Test shuffle() + df <- createDataFrame(list(list(list(1L, 20L, 3L, 5L)), list(list(4L, 5L, 6L, 7L)))) + result <- collect(select(df, shuffle(df[[1]])))[[1]] + expect_true(setequal(result[[1]], c(1L, 20L, 3L, 5L))) + expect_true(setequal(result[[2]], c(4L, 5L, 6L, 7L))) + # Test that stats::lag is working expect_equal(length(lag(ldeaths, 12)), 72) From 64bbd134ea1db5ead72461a65e130ef486c24dbf Mon Sep 17 00:00:00 2001 From: Darcy Shen Date: Sun, 2 Sep 2018 21:57:06 -0500 Subject: [PATCH 1532/2461] [SPARK-25304][SPARK-8489][SQL][TEST] Fix HiveSparkSubmitSuite test for Scala 2.12 ## What changes were proposed in this pull request? remove test-2.10.jar and add test-2.12.jar. ## How was this patch tested? ``` $ sbt -Dscala-2.12 > ++ 2.12.6 > project hive > testOnly *HiveSparkSubmitSuite -- -z "8489" ``` Closes #22308 from sadhen/SPARK-8489-FOLLOWUP. Authored-by: Darcy Shen Signed-off-by: Sean Owen --- .../regression-test-SPARK-8489/test-2.10.jar | Bin 6865 -> 0 bytes .../regression-test-SPARK-8489/test-2.12.jar | Bin 0 -> 7179 bytes 2 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.10.jar create mode 100644 sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.12.jar diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.10.jar b/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.10.jar deleted file mode 100644 index 3f28d37b93150ebdeec4c6d803351f8c9e1f6cf2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6865 zcmaJ`1yCK^vOPcuaNq=YOK^f)g1fuB22XHzJ@~&<3UIAo8lm+M{WJMVs2LJ#VfV>nO z{L2jJe@&DBTQJ7+jQ@aT1Y{*dMU<53WkfGzh6bgj=;+5mQgrV}hDNIu8K+oQx9sR8 zWoRX2#vO}ZA*k=-cBkM{ncZTOQ|x?u&%T3=xrt344~YA6Zg*w}^$#aLwF>}+o*yCn zeEz)?z!xVB%&q7REe!1Kbrf{0(G*b9ej>*sB03--#20T}*qkiB>sTj0jIS>=x`PRBIewT#8XJ-nr~8$(n>p`X+}H$rJ*GezQFb^XEwiEgck|c3 z4(Cve1eUCr2dxZsr;103)gtpm>Z&Ya5Bn|67gxS2xi7q~<}4^7>WO<8DWl(?0gRGs z;3#~*nL5$ja0jB(#FY7ll2j=Ob8<`{8b;`4F1B&=ymxqY#!%X|lcw!8NiGi!-M z;w8n@Q6lp9Y7C;SvfqF;Ia| z(Nk@N9yVqIUHZ5qsrIuu!bkUQ=yg0qBK>$pvr>FcC#jYAjY^}H9(=*Bhp&03C%hBj zKq&lvn{UGfyHMh(X;348U7v$0I+%H?%y*8@n4J~`zK}MRPWT#)i5537RZG5(QMWDA$aoQ3JXP< zsvA}SDf!~MS|OM2aU+T$Y5950=5nW%8_%c0%T$|^e4d9zzM}k<6mVDcD`}3P7T1%d zt>%^1leU}D@7|9feJ(Hbf)f^Lnrs1h+C~cqekH^p*AS*F!FXU&o+gvVm;R%%5fW6X z_3GGO_Zs|nw6;%?{Fxu)nCK=KDCO}j`AEmc4&;Q!P^T39MSYuS17KB!8OTi1$as2d zCU_*bh*rD~V@bs}sm5us)KeP_ODVyxX+fV5I-#{*U&MDk?V`NA4&^4JJFL-zC+>?$#`!3|?zV!i|pLH?3&}ZuyJY&YKAo%i`lF7C~f0 zq35W^bZcQjDT&b~Yi3=Un_}64FKA zlDc996BjakiL~qjGP#XSHlm@;o0WN0PJ#=nPgQ)NJrh_uP$a&v*+vT6xh)K=p&jEbB{kXR;0 z?Cd~=5A}+Do_QN{H3X3(hAwSBVWCz=*C{yk9TviP39vJuJb8X z$C8n^=|4b(vuBX@rT{WF?p{bdAnJ*$sAh^;&;^QN--U>wu-LAHrfPaPUvY66rJbfF zPiJ%_;550ZzI?-x-X%dgy@0bzqS^jly>PPf=Pgds!F#+0ei<^oW%L2ry03#k2W9WK z1x%gQGcAn6MKG*IfjO{ujCWv5!#vtd$+@) z%rRFITewSPSrH5ZPFy_*-#?=Y`#eGO+0;RD?kBBsOn_M{zY*~9urOh}?N}92T};si zI|mWfc|x|K=R2dJIXf4C!I%}TKEHm|$z%tSIC>kLG|n$PAe#!hIsF6wy|Lh&t<#^8=T%dF;@KzJ#)(}GRR1l#_e*18RxSkI>lLm;v9 zgRh)vRzN5|x!Hh%|Mv;OdcI$ZSGNI%aJ^p+e`MI+AV#SAXI>oKer<9+_U<{u_Xs*^ zd@K$<`$u#}9UvJrh6MoDo&)gTM`yf$M`tnzR~utPJ7WU}Yda<~rhf+OaAhf_NqLM% zS#tzBfmfB>;B1s!9~1$Rce%lYX|{PiBEI3p3n*`;lO4|W^PVI>vNzdi|H>g4mEl4% z3pKJO-tpMsx@kE-xVa$)%qYBrOoNX;YtC6|T*_!)?LRL*FtpqO8uNqZqu3(nIXCxH@^p53;PADSxbIg+dwOGHBA} z%goWMK3yqSd#+rB_q>Yp5YFs(tT)C7hv?1#1L0w!zhr26>7xRB$eR4imgSis%^Zsz z@CFq#WiIJTw!$b#t3-%?^58mLMEH~*a06Py!I4RqJ#=aH;7(CPWQzl{top9y+ni=r zMMjBTY+sBjl}{KF=|uW&C`yU%AYff>&QXvcA9JGz_)a+epkIY4Ea>Y7>MSdFEytw? zTcJS%5=UvFNErq(YMxs{Y}5ov=cP<;R3b&Mc9jknE&6?E)tC5+R3SV$gz6{P20xWf z7vv~E(^t7+o7{^eNx*b;3U_GC7BVA-aq!yr2_c*&RfQ3g9 z*mY&|WwqFB`0*=oMhDj+G@X4r_R5@dP1Md>YMK{@8;~o}#v5?lH8SM9^aril4arK) z^1>PVN@Lyz?BWd@4fj>U`aaxj?Sl(1tyPS)4%#OWby?zr;_5rUfnQiF5zXhF=dhw* z2Hk0L46ze$Lkg{_P}V&oPC4pWTy*v`nULvBXjYlh84X{ftMf-I0MFHpG>9cjPE6Bk z<|A+iJT|u1RPJY76Xu5|)Ji-W9kQIjI@oFw6LTR7xGifu!)11JTayGCj&E#vWL$hf zE$e)Gp@c0rwBo!d&93ygmtha`3bOcGX6Mn{s2%a3PGp+}Q zzY=kupa@4D&ZA85C?|(5H*(YVyr+_C_8Mt${pFdZ^YENbmM$De@6zl?TOSUdvk*jp zbmz`1+P!MwMr~<~T1^>%;OY;SqW~=W&$&6L^BpWaYS=#H`lVYn*;rocbUXO5-KaLK z5-D!AnV5(Swf@=kKK%S``nH`>?-(2AG0OIXL|Ac`zv-m}+A7=ga@mDJm!J_g2IGg~sCXa>AFe4@ zp=AtHQsgc-41}eMO;HB=1?@FxD`2rZ@ki5qGrcCO-H@5*_)!AsV;*Fl(JHgL0h``y z_@BtiUfQnHk-*)@zCy48BAA-szoTy_3!eJ;R+OfTb|bnGUZY=jkH)h&G`Xg7=JM36 z!F$_0vVa+B5QJu{aHyb;y&$yM3u7NzlFC$MC*Gvqlf^VXKj2?Uy6~2gBAAvF`aSl# za#egl9N7a@$)V^_Q2`ncGZEj z7RSNt12+D|b!w=N5q^1ke{!1XC>|9rdyJm5Yd+V0=UZHc}j%O}gx0>gg?J|!nd`MT)&ObjpToVl7 z1{#FJONtwhD2APd?ti%aVz}SGi62R ze8j<)XQ^KN)3?ehTJphVjd5Om=JTeJ;%M8^Aw?@O$$6J;jn@ZyoU0Z|7(sV zq1#O&ka#pUGFIH!mP$aqd0jq`-TpWU>}o%?iCUR?tn@@{8ahB^cII&sPOcP1h{eMj zfL&_NBt$Z&!)+P?FLc0zhC8kv_fyq4gXAH!rwP#v+Hj@pMiVZ(w-zL#t0(Usn_i%; zAx|hN6A&MyXvOQw&v0FD|ck>5f+^Zhat*O3;;Jj~}K;kYB z!wm^E31^A^%WGv%@AGR?no-7jl}G|*3mxc$D?0SauAbGo8pkMro^>k56*tt5S?07z z0$#3mt316(g{>rZQ>h++3+i2KT}5ltE~UO!YhP*%|LY+5kL}dHm=-$h2Lfl>Ihu64 zt_W@^tRJB!LsW|mgY-PTu`Fha113=EsZy1%l63_?c+X`nrSN}ra}rVX<9+^H`o61C z^VTd_duKOJjB)w_Hjo_0Y(!L^RgIT}w1VAjdfk!AS4t`~@&=eI$RqmpUJ7(9j^d+_ z?S?HVTgWyr!7i?WepDiXm(~$lQ=Mnl@cQ6`pxp$w0`YpkIy>7WjQZ#3 zsn1yv;uSf*7|CX+*kSjTP+p^E`y{)=Ie0KcUUE`#=l6S>Ic-ij3?C)$=GV<-`LY~I z+j*cpOVJw@im)TDY+5{uWv`9;2! zZP#Y_C=jFX``XM;WPeA_Aov3nIy16vV9rIqcU5p*$QbxN<9H< zH;SL2<0#z{;xp}jCrGyaX%-NuwfPQD{WmUxVTMI;#hVoM`KxMlo~Mc>|H|=cDeUaV z40~}u%eg&4_vT;)SsFT3Jk3f_lIkP0dxGUq=TDAn;k9bWWtV;#b8&ufE)?rVSsLY{ zbDW-f65BjqXX=9L{fvw0{9MRxu(^e`Rh+6JlK&RixaBXeqna6Rig!<+k=0y@&RBrI zDVssnDf&{qSg?JW3(dCr)9lzMO|aAlYg(~Td$s`caAw;%!9kvJkHW)+5eD(i{M^|J z9Nf{$5B4!U5x0^|ZP52rexFBhdDbtwhBjv2{aeN{{*L7gluK<44_{h~irYLt1A;;- z><4s@)3|r1xNW%2NHmt+T-hVuLsBMuLVL zE1h%h=p7Wr#(pDIb=`>y4!DM*XT}RyI47%1TJU4yadn)QGI9 zt0VTDRO)i}k%YC14Czz?pQ>bH_6j%c>5xo$LQ~Gxt+Ad(3qwQ}N4=vwcP%Yg*E)bw5 z`dk2CO7Aneqk+la1L;^F3kZ?ygQAibHpsmUq^tr&JF-S>vcAJy5GCXVEj9)bxZO4k z3qfb38ZK;);DtO9fklO z-vyKr^!R$HBxJk8uyd9+;f_+gndq2U-yJ5Bdr^zs4TNmoMCEN4Z+-6S&$qD#dJJMM zj`cQGH9Lp(Ke>i~{S+8fChr3B+XE%eId3FbCOxea%rbDUO0qm@RXhS@p``g%myX3S zc3|~S(Vv2ux<0ojoRZcFXQerYd^v9C!M^<{PMqT!$XtNY?(LJC&8PUF#VS}2(lya4 zyq_{V037S@;@`eyzXsj+Qruhz>WdBIKd)zwnpxKp8zqP?%buX}@0DYQ1S_IQ+0+N&4_k4+=e5{F4XEnHcWQp;bkV zFoLDYXex)j$@VW3C-2{n*M#ld%S?SqU9kpc)H2iwlw^#7IS!2lW|E5~yYX)A8#Dn` z#_vVGv}R+|O^;!J-Sf;X>ShyJ3EQDnn!Z7st(r2Hq@`S><*+IW%_E|Q4X4EU@f4%w z7nzlRR|HlK)Yl2bT?i5wFjC|DHcNrSd4QOhscqSxowr|9JG2T4sK?cexmFd0+XQg$ zk`WKu44CdB2BZiu)y#@m3L*QC-WSBR`WVK_Pt# z-lw594ij$yLX??ZlMWmfo3tw5Q&1|2I7dSQ>Eke*xDx3?M&Kz@f9GY{Cx2^3BHZJY z5)351`P~4wFVcRLY0IDH&abq4SpS|pTq#nBIS>`&O4JC#IbQ=|5=$}HC()YQ5uvdq z4D0H)02mB0UFxCsoT1ChApv`o&0Lvm%Tm>$ulKU>PB~O*`-NLQb_YO1%4@6a1TM271S_Yf7EJfn*&pv)^^Xqn!OgK&Bel-?;+?;;RpkZoeJL0h+oXjbDAgYoduuNrkHfljT0BHsch23%i`@{ z>W%Z4hs}h^xLXRTZa%@hd}C5GxM^fC001T2%QprEjRE*?9pv94$BPc~-}WD2<3$zu zEA-FO#ES~@hw(fQ{{#J#Ci2%6{@kj5kud(S&gbDDg2rEQe{Q0@G_wD&%jdnk;QkdM z{tEnasru4m`NO_G%RT=9{+}I}zpnG=sP$6V{$a+?{`{{p_pk6jzmS*E_J?I-{ulf| V+kk!jG6(MYsC{njzs7yJ`XAm)KnMT; diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.12.jar b/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.12.jar new file mode 100644 index 0000000000000000000000000000000000000000..b0d3fd17a41cb698b9e1b2544a58c39b929cdb8e GIT binary patch literal 7179 zcmaKR1yCH@wl*@jyF-A%b+F*>1b27W!65_>5S-u+K>|Z?w*iJENRWZxPH>k536_`K z^S?Uh-M{Yrx~g|qb+27}cYSNEZ|$X}h72G^Ku1SMcvC~Kj_?l*>sQEN11P7FCnS#Si=}4cc1eP6c zG6TVWt>GnikH$2b$d2rArzDzUSBpt)Om);V?y%%vou<-GSCG!TMa_ex19-5o@C%iA z9iPvO$XaS>=ubcQ$KU?mffOPFg4TcB7s2l@)Pv+aJ>BHotbKgAZU5YB$qqEB44-b+ z%xLLy8at_ZN63kAk(24!7N}1a@P{yB@ROr~K$IO8KjQ_YGZxCBxMb`!s~E ztH#~Ly+Q(dta?vZ+p>CmmPA7sM^y?bxp?`#`|M_S{nz~$OT-O^hnR52S2?Vim~SF2 zYE-1rTcl>BrUy}aJIE}?#RdG;F@%xomg+3vgYz{NZiNv5q#{(O)vDbq*{3%}Y~+N> zfQ;ur4I=@-T{hK6eMUp+v`xS3wxV^8Lx*0MF+zO4*h+0iS!G~kgW9)fMsrnQ1E?=B zj9!duJt^m9nYaAx?N8UU3qzRW;sVT%f zk-R$%u6W0dSzxKpUe0fz;eN%miqu*i6%PZ?2+qt-@o(mb8j*>;C2|A=#j!Fv&HyL~ z^a$6IXXRk%>iCX*mdP5INLOsxcJ^syy-oO2b;n~F-!*@C~hRxDY3r z_UnL3)gVe13bDL?xBPh`KtCE>&3AVr4-kVNuEf#3( zGqz5A84_7wdQ5^F1%CIrb07^(VZFl4+lDiD*GZOOF2k4Zvo`iM@x8-w0a|*_DLH zaxUEqKl5YSv*Nn<(--l(Oc}witF_f3-!8Y?#~7dvYX?qTx!o@kL9^zK&WCI{l|yG%gu*qa`2 znMNMAvh#uZ;-#Zc47s?`A{ef>(!P6$(+-B>=u0ueHRiwCEw5I)Jw1Hw#YmQXu`Kv0 z)Dx1`ap5~$xPg__qoMn0XC=Rb;E@$XfnD`VTjAWWLDRu*pnfLd)z?mWw?^5JGwo58 zJpO@sPN-fYJH%({6^oXJABu!Yw$)2{dt(gN6cK$1R{x2)Ia{9*L_T1*pM~~Jvy}t^?g>eyql|xgCoRem^Ag=`oh#*q*H2mr&<}Z&z&uX&uerS>k|iQisVk9@4{F=Z zB>H_NZ~gT%UqMFwUJHFc0;_L{ntNvevD-AVjsp~gNTI@42Qn>Q-22wL|s~~4Apui9d!(D5oNT-Ie_x=0g#M|!{_S&vn z4ua;E)5m}#QQ8|g?i?Al)EEvyHan(2x3-KFA!T+3hs&=QiYboW%L>UwWLf*t1@8x8 z9V}(eh4IwY3y;lsXG4MIQ(bcTjM!^YUR*PDt9@;RhDcorMMOrHe!41n-a4`+{IME| z97?vZHcBWB2(|kTcj<1|o3W{9oM(#UbW2ZucG(bRebed=cQ>DK<4f@PN0pw3V2Aqr z*$^-u3q-GuQm`z}31sM~ODc2q&Vri^7vnZIdp&2mnQ{E0$3*CR1p}WC+5i@yz!MH& z+Ly*SnEf=qQ}HaK?Kuj-u@H{3%I7PaLPUJN01iNXAFR7NeZEXugv#eVrpodoqFCir zsdy~7)5;DbFxHYHtkYt}epe~=oVqHJH=Oiru4c9;%}=%gkqQWp5cfRrj>KBBE8$H9 z3({@72JRPXoq6CVij?*?=1Whn>QG1!?_GZC>6CtUXKTLdEel{iVfDi-`%;eDtf)~> zHf9yQDD+Y5tWaK^uZC0a2hb%x*ST~Tp6R`{4NYe8dg}j-8|+9_uK%4mxF08_TuYv) zLyz~m_S*rqMlZ#O&zU9K^aM!A^`L5-IboyP{F=_dYKyVCEL?w_ngXB^gI8tH(|m$M z=sNn48k*vXVsA$T4g==#y*Hx)A4a8D6s6iDV4gBRAlzD4_Yb3%gi|p7@*Rd(@1)Uqo*`-aY$hG6}f#5D0eTu>4XcUu(pfE1Dj85n%c%o1GLMuLm-O@r7L>BC(TFrlgp3E&F@`=r*euD>>NqeA zVMQcXW6gSXEK%q>9itSuH-Nm!VuP2p9lNi1n!neK?6wLZDi2>Hp|0s&VGyE*g`Yp< zM;9|-;^9BoM_ZdwRP)|BlK&iKxWx`D#Cet|i#}hx$!LPlJiq!j=)rgFge!OOy)D#> zxX-VA1MC-q<|`&3^n_y;!e!GE;Sz5pmZfBKV#HkCYTLP;%;rQ8f|PAS{(w}e>3IG^ zZFk`y$I+B;vN$ zV;psKHHg=9%w_lKq%D=Itv`6#eF#TAF5C7u3SIhIMBYA(VGwaPmp-{jCeMKD{TDiG0b; zV)!++$Ioaotzgu7V{ZycQX;CrDcIz#JXTnQec|SwtqwAe@hxq!-RAPE7jiok$`3vp zorXu*ERb&rWzu0&4De#p^R_cE2co9sx4;dik`5L^DiR*BTYNp@oClFtH(PJCN21|`5R#67faraWaTJ%n}SA!W;^aslhWdr1RK0*=?5nRF&Y1k#gE`>x~^H20LB zm4^Ca&i9{NC5K|6g+*^M*6*=mODO#woVKqDJhm^+vD7|b_W^`hE}T9MX%AlhBtbg& zNR%H_tik&5(`j79)xwx@iJjSX+SJ~?P-w2&I*HU()u_-3hponYbDpylss9(^A06x~ zh-)!|iGYwq{NHpi;s4daz`ulWp}xB&kqmxl8<&Jh-4X+|A2bx1_f0l- z@XyB&?)lBko&s+v$p>R~%08F^%H#O~y zKO&zqcmd4QMn@!va?duU&4 zx|&G7E!9&}aqf%y;TO{7=8@1prGv7vo9hg^>osms-XD>JPIw23%O>1u> zj0VbOwh12|V&|^!<@(-rMp#X5N}4S<`s_(a{I$-@y_BE#I^{g$dY?Q*1wo!w5DvcM zNQHSbPNRIGRaN@|Vek|(rN`SiH|$4hmzEnDmM{FGh40KmTV4T7gga|+F+Qq@$7(NC zs5ShWVOc^lD$&1rL9RtWrUDyfDAUFoR)d<9uFx;tUxT4}G!*XE`MsDNAxX&WN7jn71eMz}O^!NFb>hQsQ+E5MM;+0A1-`D?E&#+d6g$SI@l zc<2E72>D1YDYHq(y-GytY{7+vXq5xGAvgu8-I1H^5@`w&gT~77Obt>%tGEUhm6Ts| zbRTo%YKc`^ZF!CRm9&_CDPGcD|1xz2$hUZwLqku&^fR*4O$Cy^{#O zf*{F@%fW4@>KIT@YQ-E@40V;R( zKTx9Zo)E$^z!PmCy|$09gnj&X^F4RR#zgu;f8Anii1=fq>sv2%owk}P%2C36wqu~B z5$HZ1az+e1Nkm_q0)wXbzN00w2WU~EqGRmz&~BhNhO5XG;JaUD@n$}niI9;ocw*cO zSbi^eE+#cd)*=6hZ~ddVB>H$fuGPQ!)Tqv)E&9km)Ze5shIrG%vL%K9rd*F@-6_>A zIxlz$hIysLH{e?&CcAD3sN|d?PxC6kyOg;zoLWB!S#y|n9%u z@A9l+pe5patP&U_M18f~!&~l(zuRoR5^MgF{&f&1kT>=|+OE~ypPGYp`x&Y6W-yGO zDk4&AG{*ONv{+9pwaSnj$pHa9x0xTeEVn>6ToNQh%Ata7o5_1BGTwE`(^91&1Ge3)e&8EJ z4Ec*ZnlBR|!)NB;oCjrP-5IVeX|51M_Dc8T`|&RN?!3~xAQ!C3?4Fy1*!=d2-k@4d zx}!Zeo{YOmi;Im0t-E?e>Ke+mT%QLOowHi1+nHVf7iuVVmf7anOh1X(H*ZscX;TaNtW z-jO>rFBplKG8QFaJYaby{EUcm3#KQuKlINk1%$UFAs`H({5Sgl&S(Fk|6g=h(4O)j z)W$DdT>au5p@UCqmOfXPVWgaqO6;1DbF5QkgkFKJQYLP;AE{>b+?bSz@R9Y^`2dgvTx5rBOdJ?-S- zrwaoq(q=<1hbmOU9KUpC{T-CwN1@^RgdK^>6Y1sabFt{32}Kcv*#aJvXQSV>)3R}N ztIhkg)-q`6*qoSOH|V`NKxt`+VDtD9A3$=(CiNW1{0WzRrR9)3Tkvs*|Iw}B8`W3Y zPn7Ck2PYUwM`_4c+)BDe&=zku)0kZy@bWR!QOj>4et4uay!RFnbD^k$@X>vU8j$&9mgu?ZHz&VBuaVcmjDUCk9oDt-NrEM*@7lPvY9p~INikez+eD+);+Xlt4(}=w z{2RN?kD+HTr80|%lCN4~Cm*#XUY!CnRADxf+a707@w11=NxCq^g$PR&$jhV*in(Wi zLmu+%B&Q|b^r@#!EWS;bDq=W^Ony4=bWFc@is*5&qdhDu#&z*z(;r6-aOhx;DPZc) z=VF)iF(*;fco0hN$6ZkdZcR`dQZaV8mFCW*O>0B#>%T)85u8DHQlK5Zaf@H_7Em`G z$&070KxP@6v3C~Z`+(caw%o?rXJEDZuuXY*aD<)GAEOgP-z97ZEc61GmC9I+ulQ0L zz`uP{vD$(W=DwLlf*ufczP#_&{D=!NAHxboD;6U<$jJ4+!PHaC)5PRF_>qh#U|pDu z84ZB&0ZEN0M5yvrTvu9k9h&4C^(g$O35-m5o%^#Gy)#CigVGztPDPPA(DENU@CYKt zqeQ-z$E9+m`_psK8;W{3%{0kfL_9_fnFO3Hw04ijl2R5_2t8NlFp#%tc1=(OG?3Ja zqO^k+u*ec(?onWkuQ4PN+J89hokw8E{xEsm?p;IbyFN9 zZg!|HL6vhm)-oS(eGlOm`0jf24UG~l9$mrGSa)%SQ+l+5b|$Lt7vn}r+gPCV@d4k%;N91gf*}P3b?p{ zIF2YthHi{vtV%b_Z6Izq1)JQ0+2IT3JxITzW||m>lDLQH4LPl$Ur=)O-mBmh{ysC{ z=!iJyIQ_C#$vb)7ya?;UQdNA(FC*}0YVeVXTn~1}vawACMz&mRm=ddyR9bds2Ayq( zbmA~cL{cw_bMiP|==KZ7vh?)vLtBxc5xb5N5hBI2x8=rWI?o1~tMad^21u$3cTK={x9hFC#F%4ZC z24 zJ{=msxif}T$Ix<6C98aOY$COM;5zMZ%sA%ETlslhWl)0+fkG$Dh5h}D#q2NGYTbdl z>l~{fSG#_#%>G6cLvx^RpPny2uy= zq)>aTfXtQE1izl2S?^Kn4&SdTege+@+NFtOUr<&gM3-`Dhz@wp0;gVitK^F$cDX>Z zl0LMyF&N~#17=Mr>osM8Yrcg@fggN#+@Dak28af9(}jtOVP!!DleTu4Wkel0v90eR z4Q<<-+n`XaNj)n~KigSigEd$)AoGw-JdA4TK&2;jNG+&mZa!~l`F#e?89b&osHgr# zZGXVSx*D==+&AP9Wy{I0lPTvQ$U#|><~dt+9YfsYxyY;cjJL3asHN8zOofj{_c-Sy zzn;I#y1`laI5HFIx)6EvkkJ_-^dKGY5?WDG#E}1lwW(3Z9`}@`5;A)!2F-2c)Cy#E z(F-PSi!{E_mL>`P!CG@cG*4PbX-J2t?l-7&ISZ!SzPa&6nq#j^GMm@hoh+}}QR;P! z$>xf%%uxG<^e1j}+t5se{m$`^zeBT@8X^)g!rzUjf3|M^G@kx1{!<(3-)H$};rpk} z^EVcLKmLDBpMN+1rw;#99siAu-%I|7`Twhx|K0wdLgtTJ{~M`z|4qC9`#k?laDQaa y-$?nrtUootUpejH?f&_L|D?aa(f9i#;9s@F{~Ds8{&|V|`wIP?Ndx}1+W!GqLgq&R literal 0 HcmV?d00001 From 39d3d6cc965bd09b1719d245e672b013b8cee6f7 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 3 Sep 2018 00:38:08 -0700 Subject: [PATCH 1533/2461] [SPARK-25167][SPARKR][TEST][MINOR] Minor fixes for R sql tests (timestamp comparison) ## What changes were proposed in this pull request? The "date function on DataFrame" test fails consistently on my laptop. In this PR i am fixing it by changing the way we compare the two timestamp values. With this change i am able to run the tests clean. ## How was this patch tested? Fixed the failing test. Author: Dilip Biswal Closes #22274 from dilipbiswal/r-sql-test-fix2. --- R/pkg/tests/fulltests/test_sparkSQL.R | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 17e4a970425af..5c07a028f8b0e 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1870,9 +1870,9 @@ test_that("date functions on a DataFrame", { expect_equal(collect(select(df2, minute(df2$b)))[, 1], c(34, 24)) expect_equal(collect(select(df2, second(df2$b)))[, 1], c(0, 34)) expect_equal(collect(select(df2, from_utc_timestamp(df2$b, "JST")))[, 1], - c(as.POSIXlt("2012-12-13 21:34:00 UTC"), as.POSIXlt("2014-12-15 10:24:34 UTC"))) + c(as.POSIXct("2012-12-13 21:34:00 UTC"), as.POSIXct("2014-12-15 10:24:34 UTC"))) expect_equal(collect(select(df2, to_utc_timestamp(df2$b, "JST")))[, 1], - c(as.POSIXlt("2012-12-13 03:34:00 UTC"), as.POSIXlt("2014-12-14 16:24:34 UTC"))) + c(as.POSIXct("2012-12-13 03:34:00 UTC"), as.POSIXct("2014-12-14 16:24:34 UTC"))) expect_gt(collect(select(df2, unix_timestamp()))[1, 1], 0) expect_gt(collect(select(df2, unix_timestamp(df2$b)))[1, 1], 0) expect_gt(collect(select(df2, unix_timestamp(lit("2015-01-01"), "yyyy-MM-dd")))[1, 1], 0) @@ -3652,7 +3652,8 @@ test_that("catalog APIs, currentDatabase, setCurrentDatabase, listDatabases", { expect_equal(currentDatabase(), "default") expect_error(setCurrentDatabase("default"), NA) expect_error(setCurrentDatabase("zxwtyswklpf"), - "Error in setCurrentDatabase : analysis error - Database 'zxwtyswklpf' does not exist") + paste0("Error in setCurrentDatabase : analysis error - Database ", + "'zxwtyswklpf' does not exist")) dbs <- collect(listDatabases()) expect_equal(names(dbs), c("name", "description", "locationUri")) expect_equal(which(dbs[, 1] == "default"), 1) From 546683c21a23cd5e3827e69609ca91cf92bd9e02 Mon Sep 17 00:00:00 2001 From: Darcy Shen Date: Mon, 3 Sep 2018 07:36:04 -0500 Subject: [PATCH 1534/2461] [SPARK-25298][BUILD] Improve build definition for Scala 2.12 ## What changes were proposed in this pull request? Improve build for Scala 2.12. Current build for sbt fails on the subproject `repl`: ``` [info] Compiling 6 Scala sources to /Users/rendong/wdi/spark/repl/target/scala-2.12/classes... [error] /Users/rendong/wdi/spark/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala:80: overriding lazy value importableSymbolsWithRenames in class ImportHandler of type List[(this.intp.global.Symbol, this.intp.global.Name)]; [error] lazy value importableSymbolsWithRenames needs `override' modifier [error] lazy val importableSymbolsWithRenames: List[(Symbol, Name)] = { [error] ^ [warn] /Users/rendong/wdi/spark/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala:53: variable addedClasspath in class ILoop is deprecated (since 2.11.0): use reset, replay or require to update class path [warn] if (addedClasspath != "") { [warn] ^ [warn] /Users/rendong/wdi/spark/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala:54: variable addedClasspath in class ILoop is deprecated (since 2.11.0): use reset, replay or require to update class path [warn] settings.classpath append addedClasspath [warn] ^ [warn] two warnings found [error] one error found [error] (repl/compile:compileIncremental) Compilation failed [error] Total time: 93 s, completed 2018-9-3 10:07:26 ``` ## How was this patch tested? ``` ./dev/change-scala-version.sh 2.12 ## For Maven ./build/mvn -Pscala-2.12 [mvn commands] ## For SBT sbt -Dscala.version=2.12.6 ``` Closes #22310 from sadhen/SPARK-25298. Authored-by: Darcy Shen Signed-off-by: Sean Owen --- docs/building-spark.md | 16 ++++++++++++++++ project/SparkBuild.scala | 6 ++++++ repl/pom.xml | 14 ++------------ .../org/apache/spark/repl/SparkExprTyper.scala | 0 .../spark/repl/SparkILoopInterpreter.scala | 0 5 files changed, 24 insertions(+), 12 deletions(-) rename repl/{scala-2.11/src/main/scala => src/main/scala-2.11}/org/apache/spark/repl/SparkExprTyper.scala (100%) rename repl/{scala-2.11/src/main/scala => src/main/scala-2.11}/org/apache/spark/repl/SparkILoopInterpreter.scala (100%) diff --git a/docs/building-spark.md b/docs/building-spark.md index 0086aeaaa4701..1d3e0b1b7d396 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -256,3 +256,19 @@ On Linux, this can be done by `sudo service docker start`. or ./build/sbt docker-integration-tests/test + +## Change Scala Version + +To build Spark using another supported Scala version, please change the major Scala version using (e.g. 2.12): + + ./dev/change-scala-version.sh 2.12 + +For Maven, please enable the profile (e.g. 2.12): + + ./build/mvn -Pscala-2.12 compile + +For SBT, specify a complete scala version using (e.g. 2.12.6): + + ./build/sbt -Dscala.version=2.12.6 + +Otherwise, the sbt-pom-reader plugin will use the `scala.version` specified in the spark-parent pom. diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 1f45a06084c0d..a5ed9088eaa4d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -94,6 +94,12 @@ object SparkBuild extends PomBuild { case Some(v) => v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq } + + Option(System.getProperty("scala.version")) + .filter(_.startsWith("2.12")) + .foreach { versionString => + System.setProperty("scala-2.12", "true") + } if (System.getProperty("scala-2.12") == "") { // To activate scala-2.10 profile, replace empty property value to non-empty value // in the same way as Maven which handles -Dname as -Dname=true before executes build process. diff --git a/repl/pom.xml b/repl/pom.xml index 861bbd7c49654..e8464a688336b 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -32,8 +32,8 @@ repl - scala-2.11/src/main/scala - scala-2.11/src/test/scala + src/main/scala-${scala.binary.version} + src/test/scala-${scala.binary.version} @@ -167,14 +167,4 @@ - - - scala-2.12 - - scala-2.12/src/main/scala - scala-2.12/src/test/scala - - - - diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/src/main/scala-2.11/org/apache/spark/repl/SparkExprTyper.scala similarity index 100% rename from repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala rename to repl/src/main/scala-2.11/org/apache/spark/repl/SparkExprTyper.scala diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala b/repl/src/main/scala-2.11/org/apache/spark/repl/SparkILoopInterpreter.scala similarity index 100% rename from repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala rename to repl/src/main/scala-2.11/org/apache/spark/repl/SparkILoopInterpreter.scala From 8e2169696f4cd52e5e3f51626a512d25215cffa4 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 4 Sep 2018 13:28:36 +0900 Subject: [PATCH 1535/2461] [SPARK-25308][SQL] ArrayContains function may return a error in the code generation phase. ## What changes were proposed in this pull request? Invoking ArrayContains function with non nullable array type throws the following error in the code generation phase. Below is the error snippet. ```SQL Code generation of array_contains([1,2,3], 1) failed: java.util.concurrent.ExecutionException: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 40, Column 11: failed to compile: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 40, Column 11: Expression "isNull_0" is not an rvalue java.util.concurrent.ExecutionException: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 40, Column 11: failed to compile: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 40, Column 11: Expression "isNull_0" is not an rvalue at com.google.common.util.concurrent.AbstractFuture$Sync.getValue(AbstractFuture.java:306) at com.google.common.util.concurrent.AbstractFuture$Sync.get(AbstractFuture.java:293) at com.google.common.util.concurrent.AbstractFuture.get(AbstractFuture.java:116) at com.google.common.util.concurrent.Uninterruptibles.getUninterruptibly(Uninterruptibles.java:135) at com.google.common.cache.LocalCache$Segment.getAndRecordStats(LocalCache.java:2410) at com.google.common.cache.LocalCache$Segment.loadSync(LocalCache.java:2380) at com.google.common.cache.LocalCache$Segment.lockedGetOrLoad(LocalCache.java:2342) at com.google.common.cache.LocalCache$Segment.get(LocalCache.java:2257) at com.google.common.cache.LocalCache.get(LocalCache.java:4000) at com.google.common.cache.LocalCache.getOrLoad(LocalCache.java:4004) at com.google.common.cache.LocalCache$LocalLoadingCache.get(LocalCache.java:4874) at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$.compile(CodeGenerator.scala:1305) ``` ## How was this patch tested? Added test in CollectionExpressionSuite. Closes #22315 from dilipbiswal/SPARK-25308. Authored-by: Dilip Biswal Signed-off-by: Takuya UESHIN --- .../expressions/collectionOperations.scala | 32 +++++++++++++------ .../CollectionExpressionsSuite.scala | 3 ++ 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index cf9796ef1948f..17c683cc8ff57 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1464,17 +1464,29 @@ case class ArrayContains(left: Expression, right: Expression) nullSafeCodeGen(ctx, ev, (arr, value) => { val i = ctx.freshName("i") val getValue = CodeGenerator.getValue(arr, right.dataType, i) - s""" - for (int $i = 0; $i < $arr.numElements(); $i ++) { - if ($arr.isNullAt($i)) { - ${ev.isNull} = true; - } else if (${ctx.genEqual(right.dataType, value, getValue)}) { - ${ev.isNull} = false; - ${ev.value} = true; - break; - } + val loopBodyCode = if (nullable) { + s""" + |if ($arr.isNullAt($i)) { + | ${ev.isNull} = true; + |} else if (${ctx.genEqual(right.dataType, value, getValue)}) { + | ${ev.isNull} = false; + | ${ev.value} = true; + | break; + |} + """.stripMargin + } else { + s""" + |if (${ctx.genEqual(right.dataType, value, getValue)}) { + | ${ev.value} = true; + | break; + |} + """.stripMargin } - """ + s""" + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | $loopBodyCode + |} + """.stripMargin }) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 7b345aabd19c8..a9fc3e9c7b378 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -383,10 +383,13 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a3 = Literal.create(null, ArrayType(StringType)) val a4 = Literal.create(Seq(create_row(1)), ArrayType(StructType(Seq( StructField("a", IntegerType, true))))) + // Explicitly mark the array type not nullable (spark-25308) + val a5 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) checkEvaluation(ArrayContains(a0, Literal(1)), true) checkEvaluation(ArrayContains(a0, Literal(0)), false) checkEvaluation(ArrayContains(a0, Literal.create(null, IntegerType)), null) + checkEvaluation(ArrayContains(a5, Literal(1)), true) checkEvaluation(ArrayContains(a1, Literal("")), true) checkEvaluation(ArrayContains(a1, Literal("a")), null) From b60ee3a337e0484e1d3e978197685d7d5ab858e7 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 4 Sep 2018 13:39:29 +0900 Subject: [PATCH 1536/2461] [SPARK-25307][SQL] ArraySort function may return an error in the code generation phase ## What changes were proposed in this pull request? Sorting array of booleans (not nullable) returns a compilation error in the code generation phase. Below is the compilation error : ```SQL java.util.concurrent.ExecutionException: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 51, Column 23: failed to compile: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 51, Column 23: No applicable constructor/method found for actual parameters "boolean[]"; candidates are: "public static void java.util.Arrays.sort(long[])", "public static void java.util.Arrays.sort(long[], int, int)", "public static void java.util.Arrays.sort(byte[], int, int)", "public static void java.util.Arrays.sort(float[])", "public static void java.util.Arrays.sort(float[], int, int)", "public static void java.util.Arrays.sort(char[])", "public static void java.util.Arrays.sort(char[], int, int)", "public static void java.util.Arrays.sort(short[], int, int)", "public static void java.util.Arrays.sort(short[])", "public static void java.util.Arrays.sort(byte[])", "public static void java.util.Arrays.sort(java.lang.Object[], int, int, java.util.Comparator)", "public static void java.util.Arrays.sort(java.lang.Object[], java.util.Comparator)", "public static void java.util.Arrays.sort(int[])", "public static void java.util.Arrays.sort(java.lang.Object[], int, int)", "public static void java.util.Arrays.sort(java.lang.Object[])", "public static void java.util.Arrays.sort(double[])", "public static void java.util.Arrays.sort(double[], int, int)", "public static void java.util.Arrays.sort(int[], int, int)" at com.google.common.util.concurrent.AbstractFuture$Sync.getValue(AbstractFuture.java:306) at com.google.common.util.concurrent.AbstractFuture$Sync.get(AbstractFuture.java:293) at com.google.common.util.concurrent.AbstractFuture.get(AbstractFuture.java:116) at com.google.common.util.concurrent.Uninterruptibles.getUninterruptibly(Uninterruptibles.java:135) at com.google.common.cache.LocalCache$Segment.getAndRecordStats(LocalCache.java:2410) at com.google.common.cache.LocalCache$Segment.loadSync(LocalCache.java:2380) at com.google.common.cache.LocalCache$Segment.lockedGetOrLoad(LocalCache.java:2342) at com.google.common.cache.LocalCache$Segment.get(LocalCache.java:2257) at com.google.common.cache.LocalCache.get(LocalCache.java:4000) at com.google.common.cache.LocalCache.getOrLoad(LocalCache.java:4004) at com.google.common.cache.LocalCache$LocalLoadingCache.get(LocalCache.java:4874) at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$.compile(CodeGenerator.scala:1305) ``` ## How was this patch tested? Added test in collectionExpressionSuite Closes #22314 from dilipbiswal/SPARK-25307. Authored-by: Dilip Biswal Signed-off-by: Takuya UESHIN --- .../sql/catalyst/expressions/collectionOperations.scala | 5 +++-- .../catalyst/expressions/CollectionExpressionsSuite.scala | 7 +++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 17c683cc8ff57..a29828e7f0e65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1046,8 +1046,9 @@ trait ArraySortLike extends ExpectsInputTypes { } else { s"int $c = ${ctx.genComp(elementType, s"(($jt) $o1)", s"(($jt) $o2)")};" } - val nonNullPrimitiveAscendingSort = - if (CodeGenerator.isPrimitiveType(elementType) && !containsNull) { + val canPerformFastSort = + CodeGenerator.isPrimitiveType(elementType) && elementType != BooleanType && !containsNull + val nonNullPrimitiveAscendingSort = if (canPerformFastSort) { val javaType = CodeGenerator.javaType(elementType) val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType) s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index a9fc3e9c7b378..96ae7d1eba9fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -326,12 +326,19 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val d2 = new Decimal().set(100) val a4 = Literal.create(Seq(d2, d1), ArrayType(DecimalType(10, 0))) val a5 = Literal.create(Seq(null, null), ArrayType(NullType)) + val a6 = Literal.create(Seq(true, false, true, false), + ArrayType(BooleanType, containsNull = false)) + val a7 = Literal.create(Seq(true, false, true, false), ArrayType(BooleanType)) + val a8 = Literal.create(Seq(true, false, true, null, false), ArrayType(BooleanType)) checkEvaluation(new SortArray(a0), Seq(1, 2, 3)) checkEvaluation(new SortArray(a1), Seq[Integer]()) checkEvaluation(new SortArray(a2), Seq("a", "b")) checkEvaluation(new SortArray(a3), Seq(null, "a", "b")) checkEvaluation(new SortArray(a4), Seq(d1, d2)) + checkEvaluation(new SortArray(a6), Seq(false, false, true, true)) + checkEvaluation(new SortArray(a7), Seq(false, false, true, true)) + checkEvaluation(new SortArray(a8), Seq(null, false, false, true, true)) checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3)) checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]()) checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b")) From 4cb2ff9d8a58da5170744a634f672ba07b0a6a24 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 4 Sep 2018 14:00:00 +0900 Subject: [PATCH 1537/2461] [SPARK-25310][SQL] ArraysOverlap may throw a CompilationException ## What changes were proposed in this pull request? This PR fixes a problem that `ArraysOverlap` function throws a `CompilationException` with non-nullable array type. The following is the stack trace of the original problem: ``` Code generation of arrays_overlap([1,2,3], [4,5,3]) failed: java.util.concurrent.ExecutionException: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 56, Column 11: failed to compile: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 56, Column 11: Expression "isNull_0" is not an rvalue java.util.concurrent.ExecutionException: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 56, Column 11: failed to compile: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 56, Column 11: Expression "isNull_0" is not an rvalue at com.google.common.util.concurrent.AbstractFuture$Sync.getValue(AbstractFuture.java:306) at com.google.common.util.concurrent.AbstractFuture$Sync.get(AbstractFuture.java:293) at com.google.common.util.concurrent.AbstractFuture.get(AbstractFuture.java:116) at com.google.common.util.concurrent.Uninterruptibles.getUninterruptibly(Uninterruptibles.java:135) at com.google.common.cache.LocalCache$Segment.getAndRecordStats(LocalCache.java:2410) at com.google.common.cache.LocalCache$Segment.loadSync(LocalCache.java:2380) at com.google.common.cache.LocalCache$Segment.lockedGetOrLoad(LocalCache.java:2342) at com.google.common.cache.LocalCache$Segment.get(LocalCache.java:2257) at com.google.common.cache.LocalCache.get(LocalCache.java:4000) at com.google.common.cache.LocalCache.getOrLoad(LocalCache.java:4004) at com.google.common.cache.LocalCache$LocalLoadingCache.get(LocalCache.java:4874) at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$.compile(CodeGenerator.scala:1305) at org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection$.create(GenerateMutableProjection.scala:143) at org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection$.create(GenerateMutableProjection.scala:48) at org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection$.create(GenerateMutableProjection.scala:32) at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator.generate(CodeGenerator.scala:1260) ``` ## How was this patch tested? Added test in `CollectionExpressionSuite`. Closes #22317 from kiszk/SPARK-25310. Authored-by: Kazuaki Ishizaki Signed-off-by: Takuya UESHIN --- .../sql/catalyst/expressions/collectionOperations.scala | 6 ++++-- .../catalyst/expressions/CollectionExpressionsSuite.scala | 5 +++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index a29828e7f0e65..5e4f48ecfc47a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1636,12 +1636,13 @@ case class ArraysOverlap(left: Expression, right: Expression) val set = ctx.freshName("set") val addToSetFromSmallerCode = nullSafeElementCodegen( smaller, i, s"$set.add($getFromSmaller);", s"${ev.isNull} = true;") + val setIsNullCode = if (nullable) s"${ev.isNull} = false;" else "" val elementIsInSetCode = nullSafeElementCodegen( bigger, i, s""" |if ($set.contains($getFromBigger)) { - | ${ev.isNull} = false; + | $setIsNullCode | ${ev.value} = true; | break; |} @@ -1666,12 +1667,13 @@ case class ArraysOverlap(left: Expression, right: Expression) val j = ctx.freshName("j") val getFromSmaller = CodeGenerator.getValue(smaller, elementType, j) val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) + val setIsNullCode = if (nullable) s"${ev.isNull} = false;" else "" val compareValues = nullSafeElementCodegen( smaller, j, s""" |if (${ctx.genEqual(elementType, getFromSmaller, getFromBigger)}) { - | ${ev.isNull} = false; + | $setIsNullCode | ${ev.value} = true; |} """.stripMargin, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 96ae7d1eba9fc..c7db4ec9e16b1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -449,6 +449,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a4 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) val a5 = Literal.create(Seq[String]("", "abc"), ArrayType(StringType)) val a6 = Literal.create(Seq[String]("def", "ghi"), ArrayType(StringType)) + val a7 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) val emptyIntArray = Literal.create(Seq.empty[Int], ArrayType(IntegerType)) @@ -463,6 +464,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArraysOverlap(a4, a5), true) checkEvaluation(ArraysOverlap(a4, a6), null) checkEvaluation(ArraysOverlap(a5, a6), false) + checkEvaluation(ArraysOverlap(a7, a7), true) // null handling checkEvaluation(ArraysOverlap(emptyIntArray, a2), false) @@ -481,9 +483,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayType(BinaryType)) val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4)), + ArrayType(BinaryType, containsNull = false)) checkEvaluation(ArraysOverlap(b0, b1), true) checkEvaluation(ArraysOverlap(b0, b2), false) + checkEvaluation(ArraysOverlap(b3, b3), true) // arrays of complex data types val aa0 = Literal.create(Seq[Array[String]](Array[String]("a", "b"), Array[String]("c", "d")), From e319ac92e597d95eba5b787bb7a5d5499bb3f87c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 4 Sep 2018 15:26:34 +0800 Subject: [PATCH 1538/2461] [SPARK-24962][SQL] Refactor CodeGenerator.createUnsafeArray, ArraySetLike, and ArrayDistinct ## What changes were proposed in this pull request? This PR integrates handling of `UnsafeArrayData` and `GenericArrayData` into one. The current `CodeGenerator.createUnsafeArray` handles only allocation of `UnsafeArrayData`. This PR introduces a new method `createArrayData` that returns a code to allocate `UnsafeArrayData` or `GenericArrayData` and to assign a value into the allocated array. This PR also reduce the size of generated code by calling a runtime helper. This PR replaced `createArrayData` with `createUnsafeArray`. This PR also refactor `ArraySetLike` that can be used for `ArrayDistinct`, too. This PR also refactors`ArrayDistinct` to use `ArraryBuilder`. ## How was this patch tested? Existing tests Closes #21912 from kiszk/SPARK-24962. Lead-authored-by: Kazuaki Ishizaki Co-authored-by: Takuya UESHIN Signed-off-by: Wenchen Fan --- .../catalyst/expressions/UnsafeArrayData.java | 22 +- .../expressions/codegen/CodeGenerator.scala | 150 +-- .../expressions/collectionOperations.scala | 926 +++++++----------- .../spark/sql/catalyst/util/ArrayData.scala | 27 + 4 files changed, 464 insertions(+), 661 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index cf2a5ed2e27f9..9e7b15d339eeb 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -473,13 +473,27 @@ public static UnsafeArrayData fromPrimitiveArray( return result; } - public static UnsafeArrayData forPrimitiveArray(int offset, int length, int elementSize) { - return fromPrimitiveArray(null, offset, length, elementSize); + public static UnsafeArrayData createFreshArray(int length, int elementSize) { + final long headerInBytes = calculateHeaderPortionInBytes(length); + final long valueRegionInBytes = (long)elementSize * length; + final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8; + if (totalSizeInLongs > Integer.MAX_VALUE / 8) { + throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " + + "it's too big."); + } + + final long[] data = new long[(int)totalSizeInLongs]; + + Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, length); + + UnsafeArrayData result = new UnsafeArrayData(); + result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 8); + return result; } - public static boolean shouldUseGenericArrayData(int elementSize, int length) { + public static boolean shouldUseGenericArrayData(int elementSize, long length) { final long headerInBytes = calculateHeaderPortionInBytes(length); - final long valueRegionInBytes = (long)elementSize * length; + final long valueRegionInBytes = elementSize * length; final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8; return totalSizeInLongs > Integer.MAX_VALUE / 8; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index b8f09761f61ad..d5857e060a2c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -39,7 +39,7 @@ import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -746,73 +746,6 @@ class CodegenContext { """.stripMargin } - /** - * Generates code creating a [[UnsafeArrayData]]. - * - * @param arrayName name of the array to create - * @param numElements code representing the number of elements the array should contain - * @param elementType data type of the elements in the array - * @param additionalErrorMessage string to include in the error message - */ - def createUnsafeArray( - arrayName: String, - numElements: String, - elementType: DataType, - additionalErrorMessage: String): String = { - val arraySize = freshName("size") - val arrayBytes = freshName("arrayBytes") - - s""" - |long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( - | $numElements, - | ${elementType.defaultSize}); - |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw new RuntimeException("Unsuccessful try create array with " + $arraySize + - | " bytes of data due to exceeding the limit " + - | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} bytes for UnsafeArrayData." + - | "$additionalErrorMessage"); - |} - |byte[] $arrayBytes = new byte[(int)$arraySize]; - |UnsafeArrayData $arrayName = new UnsafeArrayData(); - |Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements); - |$arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize); - """.stripMargin - } - - /** - * Generates code creating a [[UnsafeArrayData]]. The generated code executes - * a provided fallback when the size of backing array would exceed the array size limit. - * @param arrayName a name of the array to create - * @param numElements a piece of code representing the number of elements the array should contain - * @param elementSize a size of an element in bytes - * @param bodyCode a function generating code that fills up the [[UnsafeArrayData]] - * and getting the backing array as a parameter - * @param fallbackCode a piece of code executed when the array size limit is exceeded - */ - def createUnsafeArrayWithFallback( - arrayName: String, - numElements: String, - elementSize: Int, - bodyCode: String => String, - fallbackCode: String): String = { - val arraySize = freshName("size") - val arrayBytes = freshName("arrayBytes") - s""" - |final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( - | $numElements, - | $elementSize); - |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | $fallbackCode - |} else { - | final byte[] $arrayBytes = new byte[(int)$arraySize]; - | UnsafeArrayData $arrayName = new UnsafeArrayData(); - | Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements); - | $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize); - | ${bodyCode(arrayBytes)} - |} - """.stripMargin - } - /** * Generates code to do null safe execution, i.e. only execute the code when the input is not * null by adding null check if necessary. @@ -1490,6 +1423,59 @@ object CodeGenerator extends Logging { } } + /** + * Generates code creating a [[UnsafeArrayData]] or [[GenericArrayData]] based on + * given parameters. + * + * @param arrayName name of the array to create + * @param elementType data type of the elements in source array + * @param numElements code representing the number of elements the array should contain + * @param additionalErrorMessage string to include in the error message + * + * @return code representing the allocation of [[ArrayData]] + */ + def createArrayData( + arrayName: String, + elementType: DataType, + numElements: String, + additionalErrorMessage: String): String = { + val elementSize = if (CodeGenerator.isPrimitiveType(elementType)) { + elementType.defaultSize + } else { + -1 + } + s""" + |ArrayData $arrayName = ArrayData.allocateArrayData( + | $elementSize, $numElements, "$additionalErrorMessage"); + """.stripMargin + } + + /** + * Generates assignment code for an [[ArrayData]] + * + * @param dstArray name of the array to be assigned + * @param elementType data type of the elements in destination and source arrays + * @param srcArray name of the array to be read + * @param needNullCheck value which shows whether a nullcheck is required for the returning + * assignment + * @param dstArrayIndex an index variable to access each element of destination array + * @param srcArrayIndex an index variable to access each element of source array + * + * @return code representing an assignment to each element of the [[ArrayData]], which requires + * a pair of destination and source loop index variables + */ + def createArrayAssignment( + dstArray: String, + elementType: DataType, + srcArray: String, + dstArrayIndex: String, + srcArrayIndex: String, + needNullCheck: Boolean): String = { + CodeGenerator.setArrayElement(dstArray, elementType, dstArrayIndex, + CodeGenerator.getValue(srcArray, elementType, srcArrayIndex), + if (needNullCheck) Some(s"$srcArray.isNullAt($srcArrayIndex)") else None) + } + /** * Returns the code to update a column in Row for a given DataType. */ @@ -1558,6 +1544,34 @@ object CodeGenerator extends Logging { } } + /** + * Generates code of setter for an [[ArrayData]]. + */ + def setArrayElement( + array: String, + elementType: DataType, + i: String, + value: String, + isNull: Option[String] = None): String = { + val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) + val setFunc = if (isPrimitiveType) { + s"set${CodeGenerator.primitiveTypeName(elementType)}" + } else { + "update" + } + if (isNull.isDefined && isPrimitiveType) { + s""" + |if (${isNull.get}) { + | $array.setNullAt($i); + |} else { + | $array.$setFunc($i, $value); + |} + """.stripMargin + } else { + s"$array.$setFunc($i, $value);" + } + } + /** * Returns the specialized code to set a given value in a column vector for a given `DataType` * that could potentially be nullable. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5e4f48ecfc47a..ea6fcccddfd49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -372,7 +372,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp val values = childMap.valueArray() val length = childMap.numElements() val resultData = new Array[AnyRef](length) - var i = 0; + var i = 0 while (i < length) { val key = keys.get(i, childDataType.keyType) val value = values.get(i, childDataType.valueType) @@ -385,107 +385,123 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => { + val arrayData = ctx.freshName("arrayData") val numElements = ctx.freshName("numElements") val keys = ctx.freshName("keys") val values = ctx.freshName("values") val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType) val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) - val code = if (isKeyPrimitive && isValuePrimitive) { - genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements) + + val wordSize = UnsafeRow.WORD_SIZE + val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2 + val (isPrimitive, elementSize) = if (isKeyPrimitive && isValuePrimitive) { + (true, structSize + wordSize) + } else { + (false, -1) + } + + val allocation = + s""" + |ArrayData $arrayData = ArrayData.allocateArrayData( + | $elementSize, $numElements, " $prettyName failed."); + """.stripMargin + + val code = if (isPrimitive) { + val genCodeForPrimitive = genCodeForPrimitiveElements( + ctx, arrayData, keys, values, ev.value, numElements, structSize) + s""" + |if ($arrayData instanceof UnsafeArrayData) { + | $genCodeForPrimitive + |} else { + | ${genCodeForAnyElements(ctx, arrayData, keys, values, ev.value, numElements)} + |} + """.stripMargin } else { - genCodeForAnyElements(ctx, keys, values, ev.value, numElements) + s"${genCodeForAnyElements(ctx, arrayData, keys, values, ev.value, numElements)}" } + s""" |final int $numElements = $c.numElements(); |final ArrayData $keys = $c.keyArray(); |final ArrayData $values = $c.valueArray(); + |$allocation |$code """.stripMargin }) } - private def getKey(varName: String) = CodeGenerator.getValue(varName, childDataType.keyType, "z") + private def getKey(varName: String, index: String) = + CodeGenerator.getValue(varName, childDataType.keyType, index) - private def getValue(varName: String) = { - CodeGenerator.getValue(varName, childDataType.valueType, "z") - } + private def getValue(varName: String, index: String) = + CodeGenerator.getValue(varName, childDataType.valueType, index) private def genCodeForPrimitiveElements( ctx: CodegenContext, + arrayData: String, keys: String, values: String, - arrayData: String, - numElements: String): String = { - val unsafeRow = ctx.freshName("unsafeRow") + resultArrayData: String, + numElements: String, + structSize: Int): String = { val unsafeArrayData = ctx.freshName("unsafeArrayData") + val baseObject = ctx.freshName("baseObject") + val unsafeRow = ctx.freshName("unsafeRow") val structsOffset = ctx.freshName("structsOffset") + val offset = ctx.freshName("offset") + val z = ctx.freshName("z") val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes" val baseOffset = Platform.BYTE_ARRAY_OFFSET val wordSize = UnsafeRow.WORD_SIZE - val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2 - val structSizeAsLong = structSize + "L" - val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) - val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.valueType) + val structSizeAsLong = s"${structSize}L" - val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});" - val valueAssignmentChecked = if (childDataType.valueContainsNull) { - s""" - |if ($values.isNullAt(z)) { - | $unsafeRow.setNullAt(1); - |} else { - | $valueAssignment - |} - """.stripMargin - } else { - valueAssignment - } + val setKey = CodeGenerator.setColumn(unsafeRow, childDataType.keyType, 0, getKey(keys, z)) - val assignmentLoop = (byteArray: String) => - s""" - |final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize; - |UnsafeRow $unsafeRow = new UnsafeRow(2); - |for (int z = 0; z < $numElements; z++) { - | long offset = $structsOffset + z * $structSizeAsLong; - | $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong); - | $unsafeRow.pointTo($byteArray, $baseOffset + offset, $structSize); - | $unsafeRow.set$keyTypeName(0, ${getKey(keys)}); - | $valueAssignmentChecked - |} - |$arrayData = $unsafeArrayData; - """.stripMargin + val valueAssignmentChecked = CodeGenerator.createArrayAssignment( + unsafeRow, childDataType.valueType, values, "1", z, childDataType.valueContainsNull) - ctx.createUnsafeArrayWithFallback( - unsafeArrayData, - numElements, - structSize + wordSize, - assignmentLoop, - genCodeForAnyElements(ctx, keys, values, arrayData, numElements)) + s""" + |UnsafeArrayData $unsafeArrayData = (UnsafeArrayData)$arrayData; + |Object $baseObject = $unsafeArrayData.getBaseObject(); + |final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize; + |UnsafeRow $unsafeRow = new UnsafeRow(2); + |for (int $z = 0; $z < $numElements; $z++) { + | long $offset = $structsOffset + $z * $structSizeAsLong; + | $unsafeArrayData.setLong($z, ($offset << 32) + $structSizeAsLong); + | $unsafeRow.pointTo($baseObject, $baseOffset + $offset, $structSize); + | $setKey; + | $valueAssignmentChecked + |} + |$resultArrayData = $arrayData; + """.stripMargin } private def genCodeForAnyElements( ctx: CodegenContext, + arrayData: String, keys: String, values: String, - arrayData: String, + resultArrayData: String, numElements: String): String = { - val genericArrayClass = classOf[GenericArrayData].getName - val rowClass = classOf[GenericInternalRow].getName - val data = ctx.freshName("internalRowArray") - + val z = ctx.freshName("z") val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) { - s"$values.isNullAt(z) ? null : (Object)${getValue(values)}" + s"$values.isNullAt($z) ? null : (Object)${getValue(values, z)}" } else { - getValue(values) + getValue(values, z) } + val rowClass = classOf[GenericInternalRow].getName + val genericArrayDataClass = classOf[GenericArrayData].getName + val genericArrayData = ctx.freshName("genericArrayData") + val rowObject = s"new $rowClass(new Object[]{${getKey(keys, z)}, $getValueWithCheck})" s""" - |final Object[] $data = new Object[$numElements]; - |for (int z = 0; z < $numElements; z++) { - | $data[z] = new $rowClass(new Object[]{${getKey(keys)}, $getValueWithCheck}); + |$genericArrayDataClass $genericArrayData = ($genericArrayDataClass)$arrayData; + |for (int $z = 0; $z < $numElements; $z++) { + | $genericArrayData.update($z, $rowObject); |} - |$arrayData = new $genericArrayClass($data); + |$resultArrayData = $arrayData; """.stripMargin } @@ -610,20 +626,14 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres val finKeysName = ctx.freshName("finalKeys") val finValsName = ctx.freshName("finalValues") - val keyConcat = if (CodeGenerator.isPrimitiveType(keyType)) { - genCodeForPrimitiveArrays(ctx, keyType, false) - } else { - genCodeForNonPrimitiveArrays(ctx, keyType) - } + val keyConcat = genCodeForArrays(ctx, keyType, false) val valueConcat = if (valueType.sameType(keyType) && !(CodeGenerator.isPrimitiveType(valueType) && dataType.valueContainsNull)) { keyConcat - } else if (CodeGenerator.isPrimitiveType(valueType)) { - genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull) } else { - genCodeForNonPrimitiveArrays(ctx, valueType) + genCodeForArrays(ctx, valueType, dataType.valueContainsNull) } val keyArgsName = ctx.freshName("keyArgs") @@ -662,7 +672,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres """.stripMargin) } - private def genCodeForPrimitiveArrays( + private def genCodeForArrays( ctx: CodegenContext, elementType: DataType, checkForNull: Boolean): String = { @@ -670,35 +680,23 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres val arrayData = ctx.freshName("arrayData") val argsName = ctx.freshName("args") val numElemName = ctx.freshName("numElements") - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - - val setterCode1 = - s""" - |$arrayData.set$primitiveValueTypeName( - | $counter, - | ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")} - |);""".stripMargin + val y = ctx.freshName("y") + val z = ctx.freshName("z") - val setterCode = if (checkForNull) { - s""" - |if ($argsName[y].isNullAt(z)) { - | $arrayData.setNullAt($counter); - |} else { - | $setterCode1 - |}""".stripMargin - } else { - setterCode1 - } + val allocation = CodeGenerator.createArrayData( + arrayData, elementType, numElemName, s" $prettyName failed.") + val assignment = CodeGenerator.createArrayAssignment( + arrayData, elementType, s"$argsName[$y]", counter, z, checkForNull) val concat = ctx.freshName("concat") val concatDef = s""" |private ArrayData $concat(ArrayData[] $argsName, int $numElemName) { - | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} + | $allocation | int $counter = 0; - | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < $argsName[y].numElements(); z++) { - | $setterCode + | for (int $y = 0; $y < ${children.length}; $y++) { + | for (int $z = 0; $z < $argsName[$y].numElements(); $z++) { + | $assignment | $counter++; | } | } @@ -709,32 +707,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres ctx.addNewFunction(concat, concatDef) } - private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { - val genericArrayClass = classOf[GenericArrayData].getName - val arrayData = ctx.freshName("arrayObjects") - val counter = ctx.freshName("counter") - val argsName = ctx.freshName("args") - val numElemName = ctx.freshName("numElements") - - val concat = ctx.freshName("concat") - val concatDef = - s""" - |private ArrayData $concat(ArrayData[] $argsName, int $numElemName) { - | Object[] $arrayData = new Object[$numElemName]; - | int $counter = 0; - | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < $argsName[y].numElements(); z++) { - | $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")}; - | $counter++; - | } - | } - | return new $genericArrayClass($arrayData); - |} - """.stripMargin - - ctx.addNewFunction(concat, concatDef) - } - override def prettyName: String = "map_concat" } @@ -867,25 +839,12 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { val valueSize = dataType.valueType.defaultSize val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $keySize)" val vByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $valueSize)" - val keyTypeName = CodeGenerator.primitiveTypeName(dataType.keyType) - val valueTypeName = CodeGenerator.primitiveTypeName(dataType.valueType) - val keyAssignment = (key: String, idx: String) => s"$keyArrayData.set$keyTypeName($idx, $key);" - val valueAssignment = (entry: String, idx: String) => { - val value = CodeGenerator.getValue(entry, dataType.valueType, "1") - val valueNullUnsafeAssignment = s"$valueArrayData.set$valueTypeName($idx, $value);" - if (dataType.valueContainsNull) { - s""" - |if ($entry.isNullAt(1)) { - | $valueArrayData.setNullAt($idx); - |} else { - | $valueNullUnsafeAssignment - |} - """.stripMargin - } else { - valueNullUnsafeAssignment - } - } + val keyAssignment = (key: String, idx: String) => + CodeGenerator.setArrayElement(keyArrayData, dataType.keyType, idx, key) + val valueAssignment = (entry: String, idx: String) => + CodeGenerator.createArrayAssignment( + valueArrayData, dataType.valueType, entry, idx, "1", dataType.valueContainsNull) val assignmentLoop = genCodeForAssignmentLoop( ctx, childVariable, @@ -1263,40 +1222,15 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None) ctx.addPartitionInitializationStatement( s"$rand = new $randomClass(${randomSeed.get}L + partitionIndex);") - val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) - val numElements = ctx.freshName("numElements") val arrayData = ctx.freshName("arrayData") - - val initialization = if (isPrimitiveType) { - ctx.createUnsafeArray(arrayData, numElements, elementType, s" $prettyName failed.") - } else { - val arrayDataClass = classOf[GenericArrayData].getName() - s"$arrayDataClass $arrayData = new $arrayDataClass(new Object[$numElements]);" - } - val indices = ctx.freshName("indices") val i = ctx.freshName("i") - val getValue = CodeGenerator.getValue(childName, elementType, s"$indices[$i]") - - val setFunc = if (isPrimitiveType) { - s"set${CodeGenerator.primitiveTypeName(elementType)}" - } else { - "update" - } - - val assignment = if (isPrimitiveType && dataType.asInstanceOf[ArrayType].containsNull) { - s""" - |if ($childName.isNullAt($indices[$i])) { - | $arrayData.setNullAt($i); - |} else { - | $arrayData.$setFunc($i, $getValue); - |} - """.stripMargin - } else { - s"$arrayData.$setFunc($i, $getValue);" - } + val initialization = CodeGenerator.createArrayData( + arrayData, elementType, numElements, s" $prettyName failed.") + val assignment = CodeGenerator.createArrayAssignment(arrayData, elementType, childName, + i, s"$indices[$i]", dataType.asInstanceOf[ArrayType].containsNull) s""" |int $numElements = $childName.numElements(); @@ -1354,40 +1288,16 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = { - val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) - val numElements = ctx.freshName("numElements") val arrayData = ctx.freshName("arrayData") - val initialization = if (isPrimitiveType) { - ctx.createUnsafeArray(arrayData, numElements, elementType, s" $prettyName failed.") - } else { - val arrayDataClass = classOf[GenericArrayData].getName - s"$arrayDataClass $arrayData = new $arrayDataClass(new Object[$numElements]);" - } - val i = ctx.freshName("i") val j = ctx.freshName("j") - val getValue = CodeGenerator.getValue(childName, elementType, i) - - val setFunc = if (isPrimitiveType) { - s"set${CodeGenerator.primitiveTypeName(elementType)}" - } else { - "update" - } - - val assignment = if (isPrimitiveType && dataType.asInstanceOf[ArrayType].containsNull) { - s""" - |if ($childName.isNullAt($i)) { - | $arrayData.setNullAt($j); - |} else { - | $arrayData.$setFunc($j, $getValue); - |} - """.stripMargin - } else { - s"$arrayData.$setFunc($j, $getValue);" - } + val initialization = CodeGenerator.createArrayData( + arrayData, elementType, numElements, s" $prettyName failed.") + val assignment = CodeGenerator.createArrayAssignment( + arrayData, elementType, childName, i, j, dataType.asInstanceOf[ArrayType].containsNull) s""" |final int $numElements = $childName.numElements(); @@ -1803,38 +1713,24 @@ case class Slice(x: Expression, start: Expression, length: Expression) resLength: String): String = { val values = ctx.freshName("values") val i = ctx.freshName("i") - val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx") - if (!CodeGenerator.isPrimitiveType(elementType)) { - val arrayClass = classOf[GenericArrayData].getName - s""" - |Object[] $values; - |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { - | $values = new Object[0]; - |} else { - | $values = new Object[$resLength]; - | for (int $i = 0; $i < $resLength; $i ++) { - | $values[$i] = $getValue; - | } - |} - |${ev.value} = new $arrayClass($values); - """.stripMargin - } else { - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - s""" - |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { - | $resLength = 0; - |} - |${ctx.createUnsafeArray(values, resLength, elementType, s" $prettyName failed.")} - |for (int $i = 0; $i < $resLength; $i ++) { - | if ($inputArray.isNullAt($i + $startIdx)) { - | $values.setNullAt($i); - | } else { - | $values.set$primitiveValueTypeName($i, $getValue); - | } - |} - |${ev.value} = $values; - """.stripMargin - } + val genericArrayData = classOf[GenericArrayData].getName + + val allocation = CodeGenerator.createArrayData( + values, elementType, resLength, s" $prettyName failed.") + val assignment = CodeGenerator.createArrayAssignment(values, elementType, inputArray, + i, s"$i + $startIdx", dataType.asInstanceOf[ArrayType].containsNull) + + s""" + |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { + | ${ev.value} = new $genericArrayData(new Object[0]); + |} else { + | $allocation + | for (int $i = 0; $i < $resLength; $i ++) { + | $assignment + | } + | ${ev.value} = $values; + |} + """.stripMargin } } @@ -2452,11 +2348,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio case StringType => ("UTF8String.concat", s"UTF8String[] $args = new UTF8String[${evals.length}];") case ArrayType(elementType, containsNull) => - val concat = if (CodeGenerator.isPrimitiveType(elementType)) { - genCodeForPrimitiveArrays(ctx, elementType, containsNull) - } else { - genCodeForNonPrimitiveArrays(ctx, elementType) - } + val concat = genCodeForArrays(ctx, elementType, containsNull) (concat, s"ArrayData[] $args = new ArrayData[${evals.length}];") } @@ -2475,62 +2367,44 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { val numElements = ctx.freshName("numElements") + val z = ctx.freshName("z") val code = s""" |long $numElements = 0L; - |for (int z = 0; z < ${children.length}; z++) { - | $numElements += args[z].numElements(); - |} - |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + - | " elements due to exceeding the array size limit" + - | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); + |for (int $z = 0; $z < ${children.length}; $z++) { + | $numElements += args[$z].numElements(); |} """.stripMargin (code, numElements) } - private def genCodeForPrimitiveArrays( + private def genCodeForArrays( ctx: CodegenContext, elementType: DataType, checkForNull: Boolean): String = { val counter = ctx.freshName("counter") val arrayData = ctx.freshName("arrayData") + val y = ctx.freshName("y") + val z = ctx.freshName("z") val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - - val setterCode = - s""" - |$arrayData.set$primitiveValueTypeName( - | $counter, - | ${CodeGenerator.getValue(s"args[y]", elementType, "z")} - |); - """.stripMargin - - val nullSafeSetterCode = if (checkForNull) { - s""" - |if (args[y].isNullAt(z)) { - | $arrayData.setNullAt($counter); - |} else { - | $setterCode - |} - """.stripMargin - } else { - setterCode - } + val initialization = CodeGenerator.createArrayData( + arrayData, elementType, numElemName, s" $prettyName failed.") + val assignment = CodeGenerator.createArrayAssignment( + arrayData, elementType, s"args[$y]", counter, z, + dataType.asInstanceOf[ArrayType].containsNull) val concat = ctx.freshName("concat") val concatDef = s""" |private ArrayData $concat(ArrayData[] args) { | $numElemCode - | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} + | $initialization | int $counter = 0; - | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < args[y].numElements(); z++) { - | $nullSafeSetterCode + | for (int $y = 0; $y < ${children.length}; $y++) { + | for (int $z = 0; $z < args[$y].numElements(); $z++) { + | $assignment | $counter++; | } | } @@ -2541,33 +2415,6 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio ctx.addNewFunction(concat, concatDef) } - private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { - val genericArrayClass = classOf[GenericArrayData].getName - val arrayData = ctx.freshName("arrayObjects") - val counter = ctx.freshName("counter") - - val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) - - val concat = ctx.freshName("concat") - val concatDef = - s""" - |private ArrayData $concat(ArrayData[] args) { - | $numElemCode - | Object[] $arrayData = new Object[(int)$numElemName]; - | int $counter = 0; - | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < args[y].numElements(); z++) { - | $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")}; - | $counter++; - | } - | } - | return new $genericArrayClass($arrayData); - |} - """.stripMargin - - ctx.addNewFunction(concat, concatDef) - } - override def toString: String = s"concat(${children.mkString(", ")})" override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" @@ -2630,11 +2477,7 @@ case class Flatten(child: Expression) extends UnaryExpression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => { - val code = if (CodeGenerator.isPrimitiveType(elementType)) { - genCodeForFlattenOfPrimitiveElements(ctx, c, ev.value) - } else { - genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value) - } + val code = genCodeForFlatten(ctx, c, ev.value) ctx.nullArrayElementsSaveExec(childDataType.containsNull, ev.isNull, c)(code) }) } @@ -2648,41 +2491,36 @@ case class Flatten(child: Expression) extends UnaryExpression { |for (int z = 0; z < $childVariableName.numElements(); z++) { | $variableName += $childVariableName.getArray(z).numElements(); |} - |if ($variableName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + - | $variableName + " elements due to exceeding the array size limit" + - | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); - |} """.stripMargin (code, variableName) } - private def genCodeForFlattenOfPrimitiveElements( + private def genCodeForFlatten( ctx: CodegenContext, childVariableName: String, arrayDataName: String): String = { val counter = ctx.freshName("counter") val tempArrayDataName = ctx.freshName("tempArrayData") + val k = ctx.freshName("k") + val l = ctx.freshName("l") + val arr = ctx.freshName("arr") val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val allocation = CodeGenerator.createArrayData( + tempArrayDataName, elementType, numElemName, s" $prettyName failed.") + val assignment = CodeGenerator.createArrayAssignment( + tempArrayDataName, elementType, arr, counter, l, + dataType.asInstanceOf[ArrayType].containsNull) s""" |$numElemCode - |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, s" $prettyName failed.")} + |$allocation |int $counter = 0; - |for (int k = 0; k < $childVariableName.numElements(); k++) { - | ArrayData arr = $childVariableName.getArray(k); - | for (int l = 0; l < arr.numElements(); l++) { - | if (arr.isNullAt(l)) { - | $tempArrayDataName.setNullAt($counter); - | } else { - | $tempArrayDataName.set$primitiveValueTypeName( - | $counter, - | ${CodeGenerator.getValue("arr", elementType, "l")} - | ); - | } + |for (int $k = 0; $k < $childVariableName.numElements(); $k++) { + | ArrayData $arr = $childVariableName.getArray($k); + | for (int $l = 0; $l < $arr.numElements(); $l++) { + | $assignment | $counter++; | } |} @@ -2690,30 +2528,6 @@ case class Flatten(child: Expression) extends UnaryExpression { """.stripMargin } - private def genCodeForFlattenOfNonPrimitiveElements( - ctx: CodegenContext, - childVariableName: String, - arrayDataName: String): String = { - val genericArrayClass = classOf[GenericArrayData].getName - val arrayName = ctx.freshName("arrayObject") - val counter = ctx.freshName("counter") - val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) - - s""" - |$numElemCode - |Object[] $arrayName = new Object[(int)$numElemName]; - |int $counter = 0; - |for (int k = 0; k < $childVariableName.numElements(); k++) { - | ArrayData arr = $childVariableName.getArray(k); - | for (int l = 0; l < arr.numElements(); l++) { - | $arrayName[$counter] = ${CodeGenerator.getValue("arr", elementType, "l")}; - | $counter++; - | } - |} - |$arrayDataName = new $genericArrayClass($arrayName); - """.stripMargin - } - override def prettyName: String = "flatten" } @@ -3155,11 +2969,7 @@ case class ArrayRepeat(left: Expression, right: Expression) val count = rightGen.value val et = dataType.elementType - val coreLogic = if (CodeGenerator.isPrimitiveType(et)) { - genCodeForPrimitiveElement(ctx, et, element, count, leftGen.isNull, ev.value) - } else { - genCodeForNonPrimitiveElement(ctx, element, count, leftGen.isNull, ev.value) - } + val coreLogic = genCodeForElement(ctx, et, element, count, leftGen.isNull, ev.value) val resultCode = nullElementsProtection(ev, rightGen.isNull, coreLogic) ev.copy(code = @@ -3198,17 +3008,12 @@ case class ArrayRepeat(left: Expression, right: Expression) |if ($count > 0) { | $numElements = $count; |} - |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw new RuntimeException("Unsuccessful try to create array with " + $numElements + - | " elements due to exceeding the array size limit" + - | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); - |} """.stripMargin (numElements, numElementsCode) } - private def genCodeForPrimitiveElement( + private def genCodeForElement( ctx: CodegenContext, elementType: DataType, element: String, @@ -3216,48 +3021,30 @@ case class ArrayRepeat(left: Expression, right: Expression) leftIsNull: String, arrayDataName: String): String = { val tempArrayDataName = ctx.freshName("tempArrayData") - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - val errorMessage = s" $prettyName failed." + val k = ctx.freshName("k") val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count) + val allocation = CodeGenerator.createArrayData( + tempArrayDataName, elementType, numElemName, s" $prettyName failed.") + val assignment = + CodeGenerator.setArrayElement(tempArrayDataName, elementType, k, element) + s""" |$numElemCode - |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, errorMessage)} + |$allocation |if (!$leftIsNull) { - | for (int k = 0; k < $tempArrayDataName.numElements(); k++) { - | $tempArrayDataName.set$primitiveValueTypeName(k, $element); + | for (int $k = 0; $k < $tempArrayDataName.numElements(); $k++) { + | $assignment | } |} else { - | for (int k = 0; k < $tempArrayDataName.numElements(); k++) { - | $tempArrayDataName.setNullAt(k); + | for (int $k = 0; $k < $tempArrayDataName.numElements(); $k++) { + | $tempArrayDataName.setNullAt($k); | } |} |$arrayDataName = $tempArrayDataName; """.stripMargin } - private def genCodeForNonPrimitiveElement( - ctx: CodegenContext, - element: String, - count: String, - leftIsNull: String, - arrayDataName: String): String = { - val genericArrayClass = classOf[GenericArrayData].getName - val arrayName = ctx.freshName("arrayObject") - val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count) - - s""" - |$numElemCode - |Object[] $arrayName = new Object[(int)$numElemName]; - |if (!$leftIsNull) { - | for (int k = 0; k < $numElemName; k++) { - | $arrayName[k] = $element; - | } - |} - |$arrayDataName = new $genericArrayClass($arrayName); - """.stripMargin - } - } /** @@ -3339,50 +3126,117 @@ case class ArrayRemove(left: Expression, right: Expression) val pos = ctx.freshName("pos") val getValue = CodeGenerator.getValue(inputArray, elementType, i) val isEqual = ctx.genEqual(elementType, value, getValue) - if (!CodeGenerator.isPrimitiveType(elementType)) { - val arrayClass = classOf[GenericArrayData].getName + + val allocation = CodeGenerator.createArrayData( + values, elementType, newArraySize, s" $prettyName failed.") + val assignment = CodeGenerator.createArrayAssignment( + values, elementType, inputArray, pos, i, false) + + s""" + |$allocation + |int $pos = 0; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $values.setNullAt($pos); + | $pos = $pos + 1; + | } + | else { + | if (!($isEqual)) { + | $assignment + | $pos = $pos + 1; + | } + | } + |} + |${ev.value} = $values; + """.stripMargin + } + + override def prettyName: String = "array_remove" +} + +/** + * Will become common base class for [[ArrayDistinct]], [[ArrayUnion]], [[ArrayIntersect]], + * and [[ArrayExcept]]. + */ +trait ArraySetLike { + protected def dt: DataType + protected def et: DataType + + @transient protected lazy val canUseSpecializedHashSet = et match { + case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true + case _ => false + } + + @transient protected lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(et) + + protected def genGetValue(array: String, i: String): String = + CodeGenerator.getValue(array, et, i) + + @transient protected lazy val (hsPostFix, hsTypeName) = { + val ptName = CodeGenerator.primitiveTypeName(et) + et match { + // we cast byte/short to int when writing to the hash set. + case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int") + case LongType => ("$mcJ$sp", ptName) + case FloatType => ("$mcF$sp", ptName) + case DoubleType => ("$mcD$sp", ptName) + } + } + + // we cast byte/short to int when writing to the hash set. + @transient protected lazy val hsValueCast = et match { + case ByteType | ShortType => "(int) " + case _ => "" + } + + // When hitting a null value, put a null holder in the ArrayBuilder. Finally we will + // convert ArrayBuilder to ArrayData and setNull on the slot with null holder. + @transient protected lazy val nullValueHolder = et match { + case ByteType => "(byte) 0" + case ShortType => "(short) 0" + case _ => "0" + } + + protected def withResultArrayNullCheck( + body: String, + value: String, + nullElementIndex: String): String = { + if (dt.asInstanceOf[ArrayType].containsNull) { s""" - |int $pos = 0; - |Object[] $values = new Object[$newArraySize]; - |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { - | if ($inputArray.isNullAt($i)) { - | $values[$pos] = null; - | $pos = $pos + 1; - | } - | else { - | if (!($isEqual)) { - | $values[$pos] = $getValue; - | $pos = $pos + 1; - | } - | } + |$body + |if ($nullElementIndex >= 0) { + | // result has null element + | $value.setNullAt($nullElementIndex); |} - |${ev.value} = new $arrayClass($values); """.stripMargin } else { - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - s""" - |${ctx.createUnsafeArray(values, newArraySize, elementType, s" $prettyName failed.")} - |int $pos = 0; - |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { - | if ($inputArray.isNullAt($i)) { - | $values.setNullAt($pos); - | $pos = $pos + 1; - | } - | else { - | if (!($isEqual)) { - | $values.set$primitiveValueTypeName($pos, $getValue); - | $pos = $pos + 1; - | } - | } - |} - |${ev.value} = $values; - """.stripMargin + body } } - override def prettyName: String = "array_remove" + def buildResultArray( + builder: String, + value : String, + size : String, + nullElementIndex : String): String = withResultArrayNullCheck( + s""" + |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Cannot create array with " + $size + + | " elements of data due to exceeding the limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for ArrayData."); + |} + | + |if (!UnsafeArrayData.shouldUseGenericArrayData(${et.defaultSize}, $size)) { + | $value = UnsafeArrayData.fromPrimitiveArray($builder.result()); + |} else { + | $value = new ${classOf[GenericArrayData].getName}($builder.result()); + |} + """.stripMargin, value, nullElementIndex) + } + /** * Removes duplicate values from the array. */ @@ -3394,7 +3248,7 @@ case class ArrayRemove(left: Expression, right: Expression) [1,2,3,null] """, since = "2.4.0") case class ArrayDistinct(child: Expression) - extends UnaryExpression with ExpectsInputTypes { + extends UnaryExpression with ArraySetLike with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) @@ -3402,8 +3256,8 @@ case class ArrayDistinct(child: Expression) @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType - @transient private lazy val ordering: Ordering[Any] = - TypeUtils.getInterpretedOrdering(elementType) + override protected def dt: DataType = dataType + override protected def et: DataType = elementType override def checkInputDataTypes(): TypeCheckResult = { super.checkInputDataTypes() match { @@ -3413,28 +3267,6 @@ case class ArrayDistinct(child: Expression) } } - @transient protected lazy val canUseSpecializedHashSet = elementType match { - case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true - case _ => false - } - - @transient protected lazy val (hsPostFix, hsTypeName) = { - val ptName = CodeGenerator.primitiveTypeName(elementType) - elementType match { - // we cast byte/short to int when writing to the hash set. - case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int") - case LongType => ("$mcJ$sp", ptName) - case FloatType => ("$mcF$sp", ptName) - case DoubleType => ("$mcD$sp", ptName) - } - } - - // we cast byte/short to int when writing to the hash set. - @transient protected lazy val hsValueCast = elementType match { - case ByteType | ShortType => "(int) " - case _ => "" - } - override def nullSafeEval(array: Any): Any = { val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) doEvaluation(data) @@ -3471,28 +3303,73 @@ case class ArrayDistinct(child: Expression) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val i = ctx.freshName("i") + val value = ctx.freshName("value") + val size = ctx.freshName("size") + if (canUseSpecializedHashSet) { + val jt = CodeGenerator.javaType(elementType) + val ptName = CodeGenerator.primitiveTypeName(jt) + nullSafeCodeGen(ctx, ev, (array) => { - val i = ctx.freshName("i") - val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray") val foundNullElement = ctx.freshName("foundNullElement") + val nullElementIndex = ctx.freshName("nullElementIndex") + val builder = ctx.freshName("builder") val openHashSet = classOf[OpenHashSet[_]].getName - val hs = ctx.freshName("hs") val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" - val getValue = CodeGenerator.getValue(array, elementType, i) + val hashSet = ctx.freshName("hashSet") + val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName + val arrayBuilderClass = s"$arrayBuilder$$of$ptName" + + // Only need to track null element index when array's element is nullable. + val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |boolean $foundNullElement = false; + |int $nullElementIndex = -1; + """.stripMargin + } else { + "" + } + + def withArrayNullAssignment(body: String) = + if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array.isNullAt($i)) { + | if (!$foundNullElement) { + | $nullElementIndex = $size; + | $foundNullElement = true; + | $size++; + | $builder.$$plus$$eq($nullValueHolder); + | } + |} else { + | $body + |} + """.stripMargin + } else { + body + } + + val processArray = withArrayNullAssignment( + s""" + |$jt $value = ${genGetValue(array, i)}; + |if (!$hashSet.contains($hsValueCast$value)) { + | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | break; + | } + | $hashSet.add$hsPostFix($hsValueCast$value); + | $builder.$$plus$$eq($value); + |} + """.stripMargin) + s""" - |int $sizeOfDistinctArray = 0; - |boolean $foundNullElement = false; - |$openHashSet $hs = new $openHashSet($classTag); - |for (int $i = 0; $i < $array.numElements(); $i ++) { - | if ($array.isNullAt($i)) { - | $foundNullElement = true; - | } else { - | $hs.add$hsPostFix($hsValueCast$getValue); - | } + |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag); + |$declareNullTrackVariables + |$arrayBuilderClass $builder = new $arrayBuilderClass(); + |int $size = 0; + |for (int $i = 0; $i < $array.numElements(); $i++) { + | $processArray |} - |$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 1 : 0); - |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} + |${buildResultArray(builder, ev.value, size, nullElementIndex)} """.stripMargin }) } else { @@ -3503,73 +3380,16 @@ case class ArrayDistinct(child: Expression) } } - private def setNull( - foundNullElement: String, - distinctArray: String, - pos: String): String = { - val setNullValue = s"$distinctArray.setNullAt($pos)" - s""" - |if (!($foundNullElement)) { - | $setNullValue; - | $pos = $pos + 1; - | $foundNullElement = true; - |} - """.stripMargin - } - - private def setValue( - hs: String, - distinctArray: String, - pos: String, - getValue1: String, - primitiveValueTypeName: String): String = { - s""" - |if (!($hs.contains$hsPostFix($hsValueCast$getValue1))) { - | $hs.add$hsPostFix($hsValueCast$getValue1); - | $distinctArray.set$primitiveValueTypeName($pos, $getValue1); - | $pos = $pos + 1; - |} - """.stripMargin - } - - def genCodeForResult( - ctx: CodegenContext, - ev: ExprCode, - inputArray: String, - size: String): String = { - val distinctArray = ctx.freshName("distinctArray") - val i = ctx.freshName("i") - val pos = ctx.freshName("pos") - val getValue1 = CodeGenerator.getValue(inputArray, elementType, i) - val foundNullElement = ctx.freshName("foundNullElement") - val hs = ctx.freshName("hs") - val openHashSet = classOf[OpenHashSet[_]].getName - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" - - s""" - |${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")} - |int $pos = 0; - |boolean $foundNullElement = false; - |$openHashSet $hs = new $openHashSet($classTag); - |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { - | if ($inputArray.isNullAt($i)) { - | ${setNull(foundNullElement, distinctArray, pos)} - | } else { - | ${setValue(hs, distinctArray, pos, getValue1, primitiveValueTypeName)} - | } - |} - |${ev.value} = $distinctArray; - """.stripMargin - } - override def prettyName: String = "array_distinct" } /** * Will become common base class for [[ArrayUnion]], [[ArrayIntersect]], and [[ArrayExcept]]. */ -abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { +trait ArrayBinaryLike extends BinaryArrayExpressionWithImplicitCast with ArraySetLike { + override protected def dt: DataType = dataType + override protected def et: DataType = elementType + override def checkInputDataTypes(): TypeCheckResult = { val typeCheckResult = super.checkInputDataTypes() if (typeCheckResult.isSuccess) { @@ -3579,81 +3399,9 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { typeCheckResult } } - - @transient protected lazy val ordering: Ordering[Any] = - TypeUtils.getInterpretedOrdering(elementType) - - @transient protected lazy val canUseSpecializedHashSet = elementType match { - case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true - case _ => false - } - - protected def genGetValue(array: String, i: String): String = - CodeGenerator.getValue(array, elementType, i) - - @transient protected lazy val (hsPostFix, hsTypeName) = { - val ptName = CodeGenerator.primitiveTypeName (elementType) - elementType match { - // we cast byte/short to int when writing to the hash set. - case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int") - case LongType => ("$mcJ$sp", ptName) - case FloatType => ("$mcF$sp", ptName) - case DoubleType => ("$mcD$sp", ptName) - } - } - - // we cast byte/short to int when writing to the hash set. - @transient protected lazy val hsValueCast = elementType match { - case ByteType | ShortType => "(int) " - case _ => "" - } - - // When hitting a null value, put a null holder in the ArrayBuilder. Finally we will - // convert ArrayBuilder to ArrayData and setNull on the slot with null holder. - @transient protected lazy val nullValueHolder = elementType match { - case ByteType => "(byte) 0" - case ShortType => "(short) 0" - case _ => "0" - } - - protected def withResultArrayNullCheck( - body: String, - value: String, - nullElementIndex: String): String = { - if (dataType.asInstanceOf[ArrayType].containsNull) { - s""" - |$body - |if ($nullElementIndex >= 0) { - | // result has null element - | $value.setNullAt($nullElementIndex); - |} - """.stripMargin - } else { - body - } - } - - def buildResultArray( - builder: String, - value : String, - size : String, - nullElementIndex : String): String = withResultArrayNullCheck( - s""" - |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw new RuntimeException("Cannot create array with " + $size + - | " bytes of data due to exceeding the limit " + - | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for ArrayData."); - |} - | - |if (!UnsafeArrayData.shouldUseGenericArrayData(${elementType.defaultSize}, $size)) { - | $value = UnsafeArrayData.fromPrimitiveArray($builder.result()); - |} else { - | $value = new ${classOf[GenericArrayData].getName}($builder.result()); - |} - """.stripMargin, value, nullElementIndex) } -object ArraySetLike { +object ArrayBinaryLike { def throwUnionLengthOverflowException(length: Int): Unit = { throw new RuntimeException(s"Unsuccessful try to union arrays with $length " + s"elements due to exceeding the array size limit " + @@ -3676,7 +3424,7 @@ object ArraySetLike { array(1, 2, 3, 5) """, since = "2.4.0") -case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike +case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLike with ComplexTypeMergingExpression { @transient lazy val evalUnion: (ArrayData, ArrayData) => ArrayData = { @@ -3697,7 +3445,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike val elem = array.get(i, elementType) if (!hs.contains(elem)) { if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size) + ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size) } arrayBuffer += elem hs.add(elem) @@ -3732,7 +3480,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike } if (!found) { if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length) + ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.length) } arrayBuffer += elem } @@ -3864,7 +3612,7 @@ object ArrayUnion { } if (!found) { if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length) + ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.length) } arrayBuffer += elem } @@ -3887,7 +3635,7 @@ object ArrayUnion { array(1, 3) """, since = "2.4.0") -case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetLike +case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBinaryLike with ComplexTypeMergingExpression { override def dataType: DataType = { dataTypeCheck @@ -4128,7 +3876,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL array(2) """, since = "2.4.0") -case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike +case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryLike with ComplexTypeMergingExpression { override def dataType: DataType = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala index 104b428614849..4da8ce05fe8a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala @@ -22,6 +22,8 @@ import scala.reflect.ClassTag import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, UnsafeArrayData} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.ByteArrayMethods object ArrayData { def toArrayData(input: Any): ArrayData = input match { @@ -34,6 +36,31 @@ object ArrayData { case a: Array[Double] => UnsafeArrayData.fromPrimitiveArray(a) case other => new GenericArrayData(other) } + + + /** + * Allocate [[UnsafeArrayData]] or [[GenericArrayData]] based on given parameters. + * + * @param elementSize a size of an element in bytes. If less than zero, the type of an element is + * non-primitive type + * @param numElements the number of elements the array should contain + * @param additionalErrorMessage string to include in the error message + */ + def allocateArrayData( + elementSize: Int, + numElements: Long, + additionalErrorMessage: String): ArrayData = { + if (elementSize >= 0 && !UnsafeArrayData.shouldUseGenericArrayData(elementSize, numElements)) { + UnsafeArrayData.createFreshArray(numElements.toInt, elementSize) + } else if (numElements <= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toLong) { + new GenericArrayData(new Array[Any](numElements.toInt)) + } else { + throw new RuntimeException(s"Cannot create array with $numElements " + + "elements of data due to exceeding the limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for ArrayData. " + + additionalErrorMessage) + } + } } abstract class ArrayData extends SpecializedGetters with Serializable { From 0b9b6b7d10c8aa3d5bfb013d824255f67791c97d Mon Sep 17 00:00:00 2001 From: blueszheng Date: Tue, 4 Sep 2018 04:39:55 -0700 Subject: [PATCH 1539/2461] [DOC] Update some outdated links ## What changes were proposed in this pull request? These links are outdated: - http://spark.apache.org/docs/latest/building-spark.html#specifying-the-hadoop-version - http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn Fix files which use these links. Closes #22321 from kisimple/docfix. Authored-by: blueszheng Signed-off-by: Dongjoon Hyun --- R/README.md | 2 +- R/WINDOWS.md | 2 +- README.md | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/R/README.md b/R/README.md index 1152b1e8e5f9f..d77a1ecffc99c 100644 --- a/R/README.md +++ b/R/README.md @@ -17,7 +17,7 @@ export R_HOME=/home/username/R #### Build Spark -Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-Psparkr` profile to build the R package. For example to use the default Hadoop versions you can run +Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#buildmvn) and include the `-Psparkr` profile to build the R package. For example to use the default Hadoop versions you can run ```bash build/mvn -DskipTests -Psparkr package diff --git a/R/WINDOWS.md b/R/WINDOWS.md index 124bc631be9cd..da668a69b8679 100644 --- a/R/WINDOWS.md +++ b/R/WINDOWS.md @@ -14,7 +14,7 @@ directory in Maven in `PATH`. 4. Set `MAVEN_OPTS` as described in [Building Spark](http://spark.apache.org/docs/latest/building-spark.html). -5. Open a command shell (`cmd`) in the Spark directory and build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-Psparkr` profile to build the R package. For example to use the default Hadoop versions you can run +5. Open a command shell (`cmd`) in the Spark directory and build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#buildmvn) and include the `-Psparkr` profile to build the R package. For example to use the default Hadoop versions you can run ```bash mvn.cmd -DskipTests -Psparkr package diff --git a/README.md b/README.md index 531d330234062..fd8c7f656968e 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,7 @@ storage systems. Because the protocols have changed in different versions of Hadoop, you must build Spark against the same version that your cluster runs. Please refer to the build documentation at -["Specifying the Hadoop Version"](http://spark.apache.org/docs/latest/building-spark.html#specifying-the-hadoop-version) +["Specifying the Hadoop Version and Enabling YARN"](http://spark.apache.org/docs/latest/building-spark.html#specifying-the-hadoop-version-and-enabling-yarn) for detailed guidance on building for a particular distribution of Hadoop, including building for particular Hive and Hive Thriftserver distributions. From 3aa60282cc84d471ea32ef240ec84e5b6e3e231b Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 4 Sep 2018 09:44:42 -0700 Subject: [PATCH 1540/2461] [SPARK-19355][SQL][FOLLOWUP][TEST] Properly recycle SparkSession on TakeOrderedAndProjectSuite finishes ## What changes were proposed in this pull request? Previously in `TakeOrderedAndProjectSuite` the SparkSession will not get recycled when the test suite finishes. ## How was this patch tested? N/A Closes #22330 from jiangxb1987/SPARK-19355. Authored-by: Xingbo Jiang Signed-off-by: Xiao Li --- .../apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala index 0a1c94cc4ccf4..f076959dfdf7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala @@ -45,6 +45,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { protected override def afterAll() = { SQLConf.get.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit) + super.afterAll() } private def generateRandomInputData(): DataFrame = { From 061bb01d9b99911353e66a90abc3164c467fcae1 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 4 Sep 2018 09:55:53 -0700 Subject: [PATCH 1541/2461] [SPARK-25248][CORE] Audit barrier Scala APIs for 2.4 ## What changes were proposed in this pull request? I made one pass over barrier APIs added to Spark 2.4 and updates some scopes and docs. I will update Python docs once Scala doc was reviewed. One major issue is that `BarrierTaskContext` implements `TaskContextImpl` that exposes some public methods. And internally there were several direct references to `TaskContextImpl` methods instead of `TaskContext`. This PR moved some methods from `TaskContextImpl` to `TaskContext`, remaining package private, and used delegate methods to avoid inheriting `TaskContextImp` and exposing unnecessary APIs. TODOs: - [x] scala doc - [x] python doc (#22261 ). Closes #22240 from mengxr/SPARK-25248. Authored-by: Xiangrui Meng Signed-off-by: Xiangrui Meng --- .../org/apache/spark/BarrierTaskContext.scala | 114 ++++++++++++++---- .../org/apache/spark/BarrierTaskInfo.scala | 2 +- .../scala/org/apache/spark/TaskContext.scala | 14 +++ .../org/apache/spark/TaskContextImpl.scala | 15 +-- .../spark/api/python/PythonRunner.scala | 2 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 10 +- .../org/apache/spark/rdd/RDDBarrier.scala | 22 ++-- .../org/apache/spark/scheduler/Task.scala | 35 +++--- .../scala/org/apache/spark/util/Utils.scala | 2 +- project/MimaExcludes.scala | 7 ++ .../spark/sql/internal/ReadOnlySQLConf.scala | 4 +- 11 files changed, 163 insertions(+), 64 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index 3901f96326f75..90a5c4130f799 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -24,25 +24,22 @@ import scala.language.postfixOps import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.Logging import org.apache.spark.memory.TaskMemoryManager -import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.metrics.source.Source import org.apache.spark.rpc.{RpcEndpointRef, RpcTimeout} -import org.apache.spark.util.{RpcUtils, Utils} - -/** A [[TaskContext]] with extra info and tooling for a barrier stage. */ -class BarrierTaskContext( - override val stageId: Int, - override val stageAttemptNumber: Int, - override val partitionId: Int, - override val taskAttemptId: Long, - override val attemptNumber: Int, - override val taskMemoryManager: TaskMemoryManager, - localProperties: Properties, - @transient private val metricsSystem: MetricsSystem, - // The default value is only used in tests. - override val taskMetrics: TaskMetrics = TaskMetrics.empty) - extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber, - taskMemoryManager, localProperties, metricsSystem, taskMetrics) { +import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.util._ + +/** + * :: Experimental :: + * A [[TaskContext]] with extra contextual info and tooling for tasks in a barrier stage. + * Use [[BarrierTaskContext#get]] to obtain the barrier context for a running barrier task. + */ +@Experimental +@Since("2.4.0") +class BarrierTaskContext private[spark] ( + taskContext: TaskContext) extends TaskContext with Logging { // Find the driver side RPCEndpointRef of the coordinator that handles all the barrier() calls. private val barrierCoordinator: RpcEndpointRef = { @@ -68,7 +65,7 @@ class BarrierTaskContext( * * CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all * possible code branches. Otherwise, you may get the job hanging or a SparkException after - * timeout. Some examples of misuses listed below: + * timeout. Some examples of '''misuses''' are listed below: * 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it * shall lead to timeout of the function call. * {{{ @@ -146,20 +143,95 @@ class BarrierTaskContext( /** * :: Experimental :: - * Returns the all task infos in this barrier stage, the task infos are ordered by partitionId. + * Returns [[BarrierTaskInfo]] for all tasks in this barrier stage, ordered by partition ID. */ @Experimental @Since("2.4.0") def getTaskInfos(): Array[BarrierTaskInfo] = { - val addressesStr = localProperties.getProperty("addresses", "") + val addressesStr = Option(taskContext.getLocalProperty("addresses")).getOrElse("") addressesStr.split(",").map(_.trim()).map(new BarrierTaskInfo(_)) } + + // delegate methods + + override def isCompleted(): Boolean = taskContext.isCompleted() + + override def isInterrupted(): Boolean = taskContext.isInterrupted() + + override def isRunningLocally(): Boolean = taskContext.isRunningLocally() + + override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { + taskContext.addTaskCompletionListener(listener) + this + } + + override def addTaskFailureListener(listener: TaskFailureListener): this.type = { + taskContext.addTaskFailureListener(listener) + this + } + + override def stageId(): Int = taskContext.stageId() + + override def stageAttemptNumber(): Int = taskContext.stageAttemptNumber() + + override def partitionId(): Int = taskContext.partitionId() + + override def attemptNumber(): Int = taskContext.attemptNumber() + + override def taskAttemptId(): Long = taskContext.taskAttemptId() + + override def getLocalProperty(key: String): String = taskContext.getLocalProperty(key) + + override def taskMetrics(): TaskMetrics = taskContext.taskMetrics() + + override def getMetricsSources(sourceName: String): Seq[Source] = { + taskContext.getMetricsSources(sourceName) + } + + override private[spark] def killTaskIfInterrupted(): Unit = taskContext.killTaskIfInterrupted() + + override private[spark] def getKillReason(): Option[String] = taskContext.getKillReason() + + override private[spark] def taskMemoryManager(): TaskMemoryManager = { + taskContext.taskMemoryManager() + } + + override private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit = { + taskContext.registerAccumulator(a) + } + + override private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit = { + taskContext.setFetchFailed(fetchFailed) + } + + override private[spark] def markInterrupted(reason: String): Unit = { + taskContext.markInterrupted(reason) + } + + override private[spark] def markTaskFailed(error: Throwable): Unit = { + taskContext.markTaskFailed(error) + } + + override private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = { + taskContext.markTaskCompleted(error) + } + + override private[spark] def fetchFailed: Option[FetchFailedException] = { + taskContext.fetchFailed + } + + override private[spark] def getLocalProperties: Properties = taskContext.getLocalProperties } +@Experimental +@Since("2.4.0") object BarrierTaskContext { /** - * Return the currently active BarrierTaskContext. This can be called inside of user functions to + * :: Experimental :: + * Returns the currently active BarrierTaskContext. This can be called inside of user functions to * access contextual information about running barrier tasks. */ + @Experimental + @Since("2.4.0") def get(): BarrierTaskContext = TaskContext.get().asInstanceOf[BarrierTaskContext] } diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala b/core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala index ce2653df2e845..347239b1d7db4 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala @@ -28,4 +28,4 @@ import org.apache.spark.annotation.{Experimental, Since} */ @Experimental @Since("2.4.0") -class BarrierTaskInfo(val address: String) +class BarrierTaskInfo private[spark] (val address: String) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index ceadf108c86cd..2b939dabb1105 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -221,4 +221,18 @@ abstract class TaskContext extends Serializable { */ private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit + /** Marks the task for interruption, i.e. cancellation. */ + private[spark] def markInterrupted(reason: String): Unit + + /** Marks the task as failed and triggers the failure listeners. */ + private[spark] def markTaskFailed(error: Throwable): Unit + + /** Marks the task as completed and triggers the completion listeners. */ + private[spark] def markTaskCompleted(error: Option[Throwable]): Unit + + /** Optionally returns the stored fetch failure in the task. */ + private[spark] def fetchFailed: Option[FetchFailedException] + + /** Gets local properties set upstream in the driver. */ + private[spark] def getLocalProperties: Properties } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 0791fe856ef15..89730424e5acf 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -30,6 +30,7 @@ import org.apache.spark.metrics.source.Source import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util._ + /** * A [[TaskContext]] implementation. * @@ -98,9 +99,8 @@ private[spark] class TaskContextImpl( this } - /** Marks the task as failed and triggers the failure listeners. */ @GuardedBy("this") - private[spark] def markTaskFailed(error: Throwable): Unit = synchronized { + private[spark] override def markTaskFailed(error: Throwable): Unit = synchronized { if (failed) return failed = true failure = error @@ -109,9 +109,8 @@ private[spark] class TaskContextImpl( } } - /** Marks the task as completed and triggers the completion listeners. */ @GuardedBy("this") - private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = synchronized { + private[spark] override def markTaskCompleted(error: Option[Throwable]): Unit = synchronized { if (completed) return completed = true invokeListeners(onCompleteCallbacks, "TaskCompletionListener", error) { @@ -140,8 +139,7 @@ private[spark] class TaskContextImpl( } } - /** Marks the task for interruption, i.e. cancellation. */ - private[spark] def markInterrupted(reason: String): Unit = { + private[spark] override def markInterrupted(reason: String): Unit = { reasonIfKilled = Some(reason) } @@ -176,8 +174,7 @@ private[spark] class TaskContextImpl( this._fetchFailedException = Option(fetchFailed) } - private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException + private[spark] override def fetchFailed: Option[FetchFailedException] = _fetchFailedException - // TODO: shall we publish it and define it in `TaskContext`? - private[spark] def getLocalProperties(): Properties = localProperties + private[spark] override def getLocalProperties(): Properties = localProperties } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 6c7e8630789bd..4c53bc269a104 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -270,7 +270,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( dataOut.writeInt(context.partitionId()) dataOut.writeInt(context.attemptNumber()) dataOut.writeLong(context.taskAttemptId()) - val localProps = context.asInstanceOf[TaskContextImpl].getLocalProperties.asScala + val localProps = context.getLocalProperties.asScala dataOut.writeInt(localProps.size) localProps.foreach { case (k, v) => PythonRDD.writeUTF(k, dataOut) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 374b846d2ea57..ea895bb3412e1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1649,7 +1649,15 @@ abstract class RDD[T: ClassTag]( /** * :: Experimental :: - * Indicates that Spark must launch the tasks together for the current stage. + * Marks the current stage as a barrier stage, where Spark must launch all tasks together. + * In case of a task failure, instead of only restarting the failed task, Spark will abort the + * entire stage and re-launch all tasks for this stage. + * The barrier execution mode feature is experimental and it only handles limited scenarios. + * Please read the linked SPIP and design docs to understand the limitations and future plans. + * @return an [[RDDBarrier]] instance that provides actions within a barrier stage + * @see [[org.apache.spark.BarrierTaskContext]] + * @see SPIP: Barrier Execution Mode + * @see Design Doc */ @Experimental @Since("2.4.0") diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala index b399bf9febae3..42802f7113a19 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala @@ -22,15 +22,23 @@ import scala.reflect.ClassTag import org.apache.spark.TaskContext import org.apache.spark.annotation.{Experimental, Since} -/** Represents an RDD barrier, which forces Spark to launch tasks of this stage together. */ -class RDDBarrier[T: ClassTag](rdd: RDD[T]) { +/** + * :: Experimental :: + * Wraps an RDD in a barrier stage, which forces Spark to launch tasks of this stage together. + * [[org.apache.spark.rdd.RDDBarrier]] instances are created by + * [[org.apache.spark.rdd.RDD#barrier]]. + */ +@Experimental +@Since("2.4.0") +class RDDBarrier[T: ClassTag] private[spark] (rdd: RDD[T]) { /** * :: Experimental :: - * Generate a new barrier RDD by applying a function to each partitions of the prev RDD. - * - * `preservesPartitioning` indicates whether the input function preserves the partitioner, which - * should be `false` unless `rdd` is a pair RDD and the input function doesn't modify the keys. + * Returns a new RDD by applying a function to each partition of the wrapped RDD, + * where tasks are launched together in a barrier stage. + * The interface is the same as [[org.apache.spark.rdd.RDD#mapPartitions]]. + * Please see the API doc there. + * @see [[org.apache.spark.BarrierTaskContext]] */ @Experimental @Since("2.4.0") @@ -46,5 +54,5 @@ class RDDBarrier[T: ClassTag](rdd: RDD[T]) { ) } - /** TODO extra conf(e.g. timeout) */ + // TODO: [SPARK-25247] add extra conf to RDDBarrier, e.g., timeout. } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 11f85fd91ba08..eb059f12be6d3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -82,28 +82,21 @@ private[spark] abstract class Task[T]( SparkEnv.get.blockManager.registerTask(taskAttemptId) // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether // the stage is barrier. + val taskContext = new TaskContextImpl( + stageId, + stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal + partitionId, + taskAttemptId, + attemptNumber, + taskMemoryManager, + localProperties, + metricsSystem, + metrics) + context = if (isBarrier) { - new BarrierTaskContext( - stageId, - stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal - partitionId, - taskAttemptId, - attemptNumber, - taskMemoryManager, - localProperties, - metricsSystem, - metrics) + new BarrierTaskContext(taskContext) } else { - new TaskContextImpl( - stageId, - stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal - partitionId, - taskAttemptId, - attemptNumber, - taskMemoryManager, - localProperties, - metricsSystem, - metrics) + taskContext } TaskContext.setTaskContext(context) @@ -180,7 +173,7 @@ private[spark] abstract class Task[T]( var epoch: Long = -1 // Task context, to be initialized in run(). - @transient var context: TaskContextImpl = _ + @transient var context: TaskContext = _ // The actual Thread on which the task is running, if any. Initialized in run(). @volatile @transient private var taskThread: Thread = _ diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index e6646bd073c6b..935bff92c466f 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1387,7 +1387,7 @@ private[spark] object Utils extends Logging { originalThrowable = cause try { logError("Aborting task", originalThrowable) - TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(originalThrowable) + TaskContext.get().markTaskFailed(originalThrowable) catchBlock } catch { case t: Throwable => diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 62f8b1af50a6c..45cc5ccf2ea92 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,13 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-25248] add package private methods to TaskContext + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.markTaskFailed"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.markInterrupted"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.fetchFailed"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.markTaskCompleted"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperties"), + // [SPARK-10697][ML] Add lift to Association rules ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.fpm.FPGrowthModel.this"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.fpm.AssociationRules#Rule.this"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala index 19f67236c8979..ef4b339730807 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.internal import java.util.{Map => JMap} -import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.TaskContext import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader} /** @@ -29,7 +29,7 @@ import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigRead class ReadOnlySQLConf(context: TaskContext) extends SQLConf { @transient override val settings: JMap[String, String] = { - context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]] + context.getLocalProperties.asInstanceOf[JMap[String, String]] } @transient override protected val reader: ConfigReader = { From 103f513231d31cfa475aa34ce4defc63acc3cab6 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 5 Sep 2018 10:24:13 +0800 Subject: [PATCH 1542/2461] [SPARK-25306][SQL] Avoid skewed filter trees to speed up `createFilter` in ORC ## What changes were proposed in this pull request? In both ORC data sources, `createFilter` function has exponential time complexity due to its skewed filter tree generation. This PR aims to improve it by using new `buildTree` function. **REPRODUCE** ```scala // Create and read 1 row table with 1000 columns sql("set spark.sql.orc.filterPushdown=true") val selectExpr = (1 to 1000).map(i => s"id c$i") spark.range(1).selectExpr(selectExpr: _*).write.mode("overwrite").orc("/tmp/orc") print(s"With 0 filters, ") spark.time(spark.read.orc("/tmp/orc").count) // Increase the number of filters (20 to 30).foreach { width => val whereExpr = (1 to width).map(i => s"c$i is not null").mkString(" and ") print(s"With $width filters, ") spark.time(spark.read.orc("/tmp/orc").where(whereExpr).count) } ``` **RESULT** ```scala With 0 filters, Time taken: 653 ms With 20 filters, Time taken: 962 ms With 21 filters, Time taken: 1282 ms With 22 filters, Time taken: 1982 ms With 23 filters, Time taken: 3855 ms With 24 filters, Time taken: 6719 ms With 25 filters, Time taken: 12669 ms With 26 filters, Time taken: 25032 ms With 27 filters, Time taken: 49585 ms With 28 filters, Time taken: 98980 ms // over 1 min 38 seconds With 29 filters, Time taken: 198368 ms // over 3 mins With 30 filters, Time taken: 393744 ms // over 6 mins ``` **AFTER THIS PR** ```scala With 0 filters, Time taken: 774 ms With 20 filters, Time taken: 601 ms With 21 filters, Time taken: 399 ms With 22 filters, Time taken: 679 ms With 23 filters, Time taken: 363 ms With 24 filters, Time taken: 342 ms With 25 filters, Time taken: 336 ms With 26 filters, Time taken: 352 ms With 27 filters, Time taken: 322 ms With 28 filters, Time taken: 302 ms With 29 filters, Time taken: 307 ms With 30 filters, Time taken: 301 ms ``` ## How was this patch tested? Pass the Jenkins with newly added test cases. Closes #22313 from dongjoon-hyun/SPARK-25306. Authored-by: Dongjoon Hyun Signed-off-by: Wenchen Fan --- .../FilterPushdownBenchmark-results.txt | 34 +++++++++++++++++++ .../datasources/orc/OrcFilters.scala | 25 +++++++++----- .../benchmark/FilterPushdownBenchmark.scala | 32 +++++++++++++---- .../spark/sql/hive/orc/OrcFilters.scala | 12 +++---- 4 files changed, 82 insertions(+), 21 deletions(-) diff --git a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt index 2215ed91e2018..a75a15c99328a 100644 --- a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt +++ b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt @@ -702,3 +702,37 @@ Parquet Vectorized (Pushdown) 11766 / 11927 1.3 7 Native ORC Vectorized 12101 / 12301 1.3 769.3 1.0X Native ORC Vectorized (Pushdown) 11983 / 12651 1.3 761.9 1.0X + +================================================================================================ +Pushdown benchmark with many filters +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.13.6 +Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + +Select 1 row with 1 filters: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 158 / 182 0.0 158442969.0 1.0X +Parquet Vectorized (Pushdown) 150 / 158 0.0 149718289.0 1.1X +Native ORC Vectorized 141 / 148 0.0 141259852.0 1.1X +Native ORC Vectorized (Pushdown) 142 / 147 0.0 142016472.0 1.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.13.6 +Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + +Select 1 row with 250 filters: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 1013 / 1026 0.0 1013194322.0 1.0X +Parquet Vectorized (Pushdown) 1326 / 1332 0.0 1326301956.0 0.8X +Native ORC Vectorized 1005 / 1010 0.0 1005266379.0 1.0X +Native ORC Vectorized (Pushdown) 1068 / 1071 0.0 1067964993.0 0.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.13.6 +Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + +Select 1 row with 500 filters: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 3598 / 3614 0.0 3598001202.0 1.0X +Parquet Vectorized (Pushdown) 4282 / 4333 0.0 4281849770.0 0.8X +Native ORC Vectorized 3594 / 3619 0.0 3593551548.0 1.0X +Native ORC Vectorized (Pushdown) 3834 / 3840 0.0 3834240570.0 0.9X diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index c4c3b3053a3b1..dbafc468c6c40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.execution.datasources.orc -import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument, SearchArgumentFactory} +import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument} import org.apache.orc.storage.ql.io.sarg.SearchArgument.Builder +import org.apache.orc.storage.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.orc.storage.serde2.io.HiveDecimalWritable -import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.sources.{And, Filter} import org.apache.spark.sql.types._ /** @@ -54,7 +55,17 @@ import org.apache.spark.sql.types._ * builder methods mentioned above can only be found in test code, where all tested filters are * known to be convertible. */ -private[orc] object OrcFilters { +private[sql] object OrcFilters { + private[sql] def buildTree(filters: Seq[Filter]): Option[Filter] = { + filters match { + case Seq() => None + case Seq(filter) => Some(filter) + case Seq(filter1, filter2) => Some(And(filter1, filter2)) + case _ => // length > 2 + val (left, right) = filters.splitAt(filters.length / 2) + Some(And(buildTree(left).get, buildTree(right).get)) + } + } /** * Create ORC filter as a SearchArgument instance. @@ -66,14 +77,14 @@ private[orc] object OrcFilters { // collect all convertible ones to build the final `SearchArgument`. val convertibleFilters = for { filter <- filters - _ <- buildSearchArgument(dataTypeMap, filter, SearchArgumentFactory.newBuilder()) + _ <- buildSearchArgument(dataTypeMap, filter, newBuilder) } yield filter for { // Combines all convertible filters using `And` to produce a single conjunction - conjunction <- convertibleFilters.reduceOption(org.apache.spark.sql.sources.And) + conjunction <- buildTree(convertibleFilters) // Then tries to build a single ORC `SearchArgument` for the conjunction predicate - builder <- buildSearchArgument(dataTypeMap, conjunction, SearchArgumentFactory.newBuilder()) + builder <- buildSearchArgument(dataTypeMap, conjunction, newBuilder) } yield builder.build() } @@ -127,8 +138,6 @@ private[orc] object OrcFilters { dataTypeMap: Map[String, DataType], expression: Filter, builder: Builder): Option[Builder] = { - def newBuilder = SearchArgumentFactory.newBuilder() - def getType(attribute: String): PredicateLeaf.Type = getPredicateLeafType(dataTypeMap(attribute)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala index bdb60b44750c7..41087f1a97174 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala @@ -242,7 +242,7 @@ class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfter ignore("Pushdown for many distinct value case") { withTempPath { dir => - withTempTable("orcTable", "patquetTable") { + withTempTable("orcTable", "parquetTable") { Seq(true, false).foreach { useStringForValue => prepareTable(dir, numRows, width, useStringForValue) if (useStringForValue) { @@ -259,7 +259,7 @@ class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfter withTempPath { dir => val numDistinctValues = 200 - withTempTable("orcTable", "patquetTable") { + withTempTable("orcTable", "parquetTable") { prepareStringDictTable(dir, numRows, numDistinctValues, width) runStringBenchmark(numRows, width, numDistinctValues / 2, "distinct string") } @@ -268,7 +268,7 @@ class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfter ignore("Pushdown benchmark for StringStartsWith") { withTempPath { dir => - withTempTable("orcTable", "patquetTable") { + withTempTable("orcTable", "parquetTable") { prepareTable(dir, numRows, width, true) Seq( "value like '10%'", @@ -296,7 +296,7 @@ class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfter monotonically_increasing_id() } val df = spark.range(numRows).selectExpr(columns: _*).withColumn("value", valueCol.cast(dt)) - withTempTable("orcTable", "patquetTable") { + withTempTable("orcTable", "parquetTable") { saveAsTable(df, dir) Seq(s"value = $mid").foreach { whereExpr => @@ -320,7 +320,7 @@ class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfter ignore("Pushdown benchmark for InSet -> InFilters") { withTempPath { dir => - withTempTable("orcTable", "patquetTable") { + withTempTable("orcTable", "parquetTable") { prepareTable(dir, numRows, width, false) Seq(5, 10, 50, 100).foreach { count => Seq(10, 50, 90).foreach { distribution => @@ -341,7 +341,7 @@ class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfter val df = spark.range(numRows).selectExpr(columns: _*) .withColumn("value", (monotonically_increasing_id() % Byte.MaxValue).cast(ByteType)) .orderBy("value") - withTempTable("orcTable", "patquetTable") { + withTempTable("orcTable", "parquetTable") { saveAsTable(df, dir) Seq(s"value = CAST(${Byte.MaxValue / 2} AS ${ByteType.simpleString})") @@ -373,7 +373,7 @@ class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfter val columns = (1 to width).map(i => s"CAST(id AS string) c$i") val df = spark.range(numRows).selectExpr(columns: _*) .withColumn("value", monotonically_increasing_id().cast(TimestampType)) - withTempTable("orcTable", "patquetTable") { + withTempTable("orcTable", "parquetTable") { saveAsTable(df, dir) Seq(s"value = CAST($mid AS timestamp)").foreach { whereExpr => @@ -398,6 +398,24 @@ class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfter } } } + + test(s"Pushdown benchmark with many filters") { + val numRows = 1 + val width = 500 + + withTempPath { dir => + val columns = (1 to width).map(i => s"id c$i") + val df = spark.range(1).selectExpr(columns: _*) + withTempTable("orcTable", "parquetTable") { + saveAsTable(df, dir) + Seq(1, 250, 500).foreach { numFilter => + val whereExpr = (1 to numFilter).map(i => s"c$i = 0").mkString(" and ") + // Note: InferFiltersFromConstraints will add more filters to this given filters + filterPushDownBenchmark(numRows, s"Select 1 row with $numFilter filters", whereExpr) + } + } + } + } } trait BenchmarkBeforeAndAfterEachTest extends BeforeAndAfterEachTestData { this: Suite => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index d9efd0cb457cd..aee9cb58a031e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.hive.orc -import org.apache.hadoop.hive.ql.io.sarg.{SearchArgument, SearchArgumentFactory} +import org.apache.hadoop.hive.ql.io.sarg.SearchArgument import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder +import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.datasources.orc.OrcFilters.buildTree import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -62,14 +64,14 @@ private[orc] object OrcFilters extends Logging { // collect all convertible ones to build the final `SearchArgument`. val convertibleFilters = for { filter <- filters - _ <- buildSearchArgument(dataTypeMap, filter, SearchArgumentFactory.newBuilder()) + _ <- buildSearchArgument(dataTypeMap, filter, newBuilder) } yield filter for { // Combines all convertible filters using `And` to produce a single conjunction - conjunction <- convertibleFilters.reduceOption(And) + conjunction <- buildTree(convertibleFilters) // Then tries to build a single ORC `SearchArgument` for the conjunction predicate - builder <- buildSearchArgument(dataTypeMap, conjunction, SearchArgumentFactory.newBuilder()) + builder <- buildSearchArgument(dataTypeMap, conjunction, newBuilder) } yield builder.build() } @@ -77,8 +79,6 @@ private[orc] object OrcFilters extends Logging { dataTypeMap: Map[String, DataType], expression: Filter, builder: Builder): Option[Builder] = { - def newBuilder = SearchArgumentFactory.newBuilder() - def isSearchableType(dataType: DataType): Boolean = dataType match { // Only the values in the Spark types below can be recognized by // the `SearchArgumentImpl.BuilderImpl.boxLiteral()` method. From ca861fea21adc4e6ec95eced7076cb27fc86ea18 Mon Sep 17 00:00:00 2001 From: liuxian Date: Wed, 5 Sep 2018 10:43:46 +0800 Subject: [PATCH 1543/2461] [SPARK-25300][CORE] Unified the configuration parameter `spark.shuffle.service.enabled` ## What changes were proposed in this pull request? The configuration parameter "spark.shuffle.service.enabled" has defined in `package.scala`, and it is also used in many place, so we can replace it with `SHUFFLE_SERVICE_ENABLED`. and unified this configuration parameter "spark.shuffle.service.port" together. ## How was this patch tested? N/A Closes #22306 from 10110346/unifiedserviceenable. Authored-by: liuxian Signed-off-by: Wenchen Fan --- .../org/apache/spark/ExecutorAllocationManager.scala | 4 ++-- .../org/apache/spark/deploy/ExternalShuffleService.scala | 8 ++++---- .../scala/org/apache/spark/deploy/LocalSparkCluster.scala | 4 ++-- .../scala/org/apache/spark/deploy/worker/Worker.scala | 4 ++-- .../scala/org/apache/spark/internal/config/package.scala | 3 +++ .../scala/org/apache/spark/storage/BlockManager.scala | 7 ++++--- core/src/main/scala/org/apache/spark/util/Utils.scala | 4 ++-- .../org/apache/spark/ExecutorAllocationManagerSuite.scala | 3 ++- .../org/apache/spark/ExternalShuffleServiceSuite.scala | 5 +++-- .../spark/deploy/StandaloneDynamicAllocationSuite.scala | 2 +- .../org/apache/spark/scheduler/DAGSchedulerSuite.scala | 5 +++-- .../org/apache/spark/storage/BlockManagerSuite.scala | 4 ++-- .../mesos/MesosCoarseGrainedSchedulerBackend.scala | 4 ++-- .../mesos/MesosCoarseGrainedSchedulerBackendSuite.scala | 2 +- .../spark/deploy/yarn/YarnShuffleIntegrationSuite.scala | 6 +++--- 15 files changed, 36 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 17b88631bcb4c..c3e5b96a55884 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -25,7 +25,7 @@ import scala.util.control.{ControlThrowable, NonFatal} import com.codahale.metrics.{Gauge, MetricRegistry} -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.internal.config._ import org.apache.spark.metrics.source.Source import org.apache.spark.scheduler._ @@ -212,7 +212,7 @@ private[spark] class ExecutorAllocationManager( } // Require external shuffle service for dynamic allocation // Otherwise, we may lose shuffle files when killing executors - if (!conf.getBoolean("spark.shuffle.service.enabled", false) && !testing) { + if (!conf.get(config.SHUFFLE_SERVICE_ENABLED) && !testing) { throw new SparkException("Dynamic allocation of executors requires the external " + "shuffle service. You may enable this through spark.shuffle.service.enabled.") } diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index b59a4fe66587c..f6b3c37f0fe72 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -22,7 +22,7 @@ import java.util.concurrent.CountDownLatch import scala.collection.JavaConverters._ import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.TransportContext import org.apache.spark.network.crypto.AuthServerBootstrap @@ -45,8 +45,8 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana protected val masterMetricsSystem = MetricsSystem.createMetricsSystem("shuffleService", sparkConf, securityManager) - private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false) - private val port = sparkConf.getInt("spark.shuffle.service.port", 7337) + private val enabled = sparkConf.get(config.SHUFFLE_SERVICE_ENABLED) + private val port = sparkConf.get(config.SHUFFLE_SERVICE_PORT) private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, "shuffle", numUsableCores = 0) @@ -131,7 +131,7 @@ object ExternalShuffleService extends Logging { // we override this value since this service is started from the command line // and we assume the user really wants it to be running - sparkConf.set("spark.shuffle.service.enabled", "true") + sparkConf.set(config.SHUFFLE_SERVICE_ENABLED.key, "true") server = newShuffleService(sparkConf, securityManager) server.start() diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 84aa8944fc1c7..be293f88a9d4a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkConf import org.apache.spark.deploy.master.Master import org.apache.spark.deploy.worker.Worker -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.Utils @@ -52,7 +52,7 @@ class LocalSparkCluster( // Disable REST server on Master in this mode unless otherwise specified val _conf = conf.clone() .setIfMissing("spark.master.rest.enabled", "false") - .set("spark.shuffle.service.enabled", "false") + .set(config.SHUFFLE_SERVICE_ENABLED.key, "false") /* Start the Master */ val (rpcEnv, webUiPort, _) = Master.startRpcEnvAndEndpoint(localHostname, 0, 0, _conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index cbd812a05a2c6..d5ea2523c628b 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -36,7 +36,7 @@ import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.rpc._ import org.apache.spark.util.{SparkUncaughtExceptionHandler, ThreadUtils, Utils} @@ -773,7 +773,7 @@ private[deploy] object Worker extends Logging { // bound, we may launch no more than one external shuffle service on each host. // When this happens, we should give explicit reason of failure instead of fail silently. For // more detail see SPARK-20989. - val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) + val externalShuffleServiceEnabled = conf.get(config.SHUFFLE_SERVICE_ENABLED) val sparkWorkerInstances = scala.sys.env.getOrElse("SPARK_WORKER_INSTANCES", "1").toInt require(externalShuffleServiceEnabled == false || sparkWorkerInstances <= 1, "Starting multiple workers on one host is failed because we may launch no more than one " + diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 7c2f601c9986a..319e664a19677 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -144,6 +144,9 @@ package object config { private[spark] val SHUFFLE_SERVICE_ENABLED = ConfigBuilder("spark.shuffle.service.enabled").booleanConf.createWithDefault(false) + private[spark] val SHUFFLE_SERVICE_PORT = + ConfigBuilder("spark.shuffle.service.port").intConf.createWithDefault(7337) + private[spark] val KEYTAB = ConfigBuilder("spark.yarn.keytab") .doc("Location of user's keytab.") .stringConf.createOptional diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e7cdfab99b34d..f5c69ad241e3a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -130,7 +130,7 @@ private[spark] class BlockManager( extends BlockDataManager with BlockEvictionHandler with Logging { private[spark] val externalShuffleServiceEnabled = - conf.getBoolean("spark.shuffle.service.enabled", false) + conf.get(config.SHUFFLE_SERVICE_ENABLED) private val chunkSize = conf.getSizeAsBytes("spark.storage.memoryMapLimitForTests", Int.MaxValue.toString).toInt private val remoteReadNioBufferConversion = @@ -165,12 +165,13 @@ private[spark] class BlockManager( // Port used by the external shuffle service. In Yarn mode, this may be already be // set through the Hadoop configuration as the server is launched in the Yarn NM. private val externalShuffleServicePort = { - val tmpPort = Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt + val tmpPort = Utils.getSparkOrYarnConfig(conf, config.SHUFFLE_SERVICE_PORT.key, + config.SHUFFLE_SERVICE_PORT.defaultValueString).toInt if (tmpPort == 0) { // for testing, we set "spark.shuffle.service.port" to 0 in the yarn config, so yarn finds // an open port. But we still need to tell our spark apps the right port to use. So // only if the yarn config has the port set to 0, we prefer the value in the spark config - conf.get("spark.shuffle.service.port").toInt + conf.get(config.SHUFFLE_SERVICE_PORT.key).toInt } else { tmpPort } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 935bff92c466f..15c958d3f511e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -60,7 +60,7 @@ import org.slf4j.Logger import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.internal.config._ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.JavaUtils @@ -822,7 +822,7 @@ private[spark] object Utils extends Logging { * logic of locating the local directories according to deployment mode. */ def getConfiguredLocalDirs(conf: SparkConf): Array[String] = { - val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) + val shuffleServiceEnabled = conf.get(config.SHUFFLE_SERVICE_ENABLED) if (isRunningInYarnContainer(conf)) { // If we are in yarn mode, systems can have different disk layouts so we must set it // to what Yarn on this system said was available. Note this assumes that Yarn has diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 659ebb60fef86..5c718cb654ce8 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -24,6 +24,7 @@ import org.mockito.Mockito.{mock, never, verify, when} import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.config import org.apache.spark.scheduler._ import org.apache.spark.scheduler.ExternalClusterManager import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -1092,7 +1093,7 @@ class ExecutorAllocationManagerSuite val maxExecutors = 2 val conf = new SparkConf() .set("spark.dynamicAllocation.enabled", "true") - .set("spark.shuffle.service.enabled", "true") + .set(config.SHUFFLE_SERVICE_ENABLED.key, "true") .set("spark.dynamicAllocation.minExecutors", minExecutors.toString) .set("spark.dynamicAllocation.maxExecutors", maxExecutors.toString) .set("spark.dynamicAllocation.initialExecutors", initialExecutors.toString) diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 472952addf353..462d5f5604ae3 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark import org.scalatest.BeforeAndAfterAll +import org.apache.spark.internal.config import org.apache.spark.network.TransportContext import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.server.TransportServer @@ -42,8 +43,8 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { server = transportContext.createServer() conf.set("spark.shuffle.manager", "sort") - conf.set("spark.shuffle.service.enabled", "true") - conf.set("spark.shuffle.service.port", server.getPort.toString) + conf.set(config.SHUFFLE_SERVICE_ENABLED.key, "true") + conf.set(config.SHUFFLE_SERVICE_PORT.key, server.getPort.toString) } override def afterAll() { diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 27cc47496c805..a1d2a1283db14 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -458,7 +458,7 @@ class StandaloneDynamicAllocationSuite val initialExecutorLimit = 1 val myConf = appConf .set("spark.dynamicAllocation.enabled", "true") - .set("spark.shuffle.service.enabled", "true") + .set(config.SHUFFLE_SERVICE_ENABLED.key, "true") .set("spark.dynamicAllocation.initialExecutors", initialExecutorLimit.toString) sc = new SparkContext(myConf) val appId = sc.applicationId diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index cd00051c56e8d..e0202fe703f82 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -30,6 +30,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager +import org.apache.spark.internal.config import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException} @@ -406,7 +407,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // reset the test context with the right shuffle service config afterEach() val conf = new SparkConf() - conf.set("spark.shuffle.service.enabled", "true") + conf.set(config.SHUFFLE_SERVICE_ENABLED.key, "true") conf.set("spark.files.fetchFailure.unRegisterOutputOnHost", "true") init(conf) runEvent(ExecutorAdded("exec-hostA1", "hostA")) @@ -728,7 +729,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // reset the test context with the right shuffle service config afterEach() val conf = new SparkConf() - conf.set("spark.shuffle.service.enabled", shuffleServiceOn.toString) + conf.set(config.SHUFFLE_SERVICE_ENABLED.key, shuffleServiceOn.toString) init(conf) assert(sc.env.blockManager.externalShuffleServiceEnabled == shuffleServiceOn) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 08172f0b07b75..dbee1f60d7af0 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1377,8 +1377,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val (server, shufflePort) = Utils.startServiceOnPort(candidatePort, newShuffleServer, conf, "ShuffleServer") - conf.set("spark.shuffle.service.enabled", "true") - conf.set("spark.shuffle.service.port", shufflePort.toString) + conf.set(SHUFFLE_SERVICE_ENABLED.key, "true") + conf.set(SHUFFLE_SERVICE_PORT.key, shufflePort.toString) conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "40") conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1") var e = intercept[SparkException] { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 1ce2f816dffb2..178de30f0f381 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -102,7 +102,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // If shuffle service is enabled, the Spark driver will register with the shuffle service. // This is for cleaning up shuffle files reliably. - private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) + private val shuffleServiceEnabled = conf.get(config.SHUFFLE_SERVICE_ENABLED) // Cores we have acquired with each Mesos task ID private val coresByTaskId = new mutable.HashMap[String, Int] @@ -624,7 +624,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( "External shuffle client was not instantiated even though shuffle service is enabled.") // TODO: Remove this and allow the MesosExternalShuffleService to detect // framework termination when new Mesos Framework HTTP API is available. - val externalShufflePort = conf.getInt("spark.shuffle.service.port", 7337) + val externalShufflePort = conf.get(config.SHUFFLE_SERVICE_PORT) logDebug(s"Connecting to shuffle service on slave $slaveId, " + s"host ${slave.hostname}, port $externalShufflePort for app ${conf.getAppId}") diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index b790c7cd27794..da33d85d8fb2e 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -262,7 +262,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite } test("mesos doesn't register twice with the same shuffle service") { - setBackend(Map("spark.shuffle.service.enabled" -> "true")) + setBackend(Map(SHUFFLE_SERVICE_ENABLED.key -> "true")) val (mem, cpu) = (backend.executorMemory(sc), 4) val offer1 = createOffer("o1", "s1", mem, cpu) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala index 01db796096f26..37bccaf0439b4 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -44,7 +44,7 @@ class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle") yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"), classOf[YarnShuffleService].getCanonicalName) - yarnConfig.set("spark.shuffle.service.port", "0") + yarnConfig.set(SHUFFLE_SERVICE_PORT.key, "0") yarnConfig } @@ -54,8 +54,8 @@ class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { logInfo("Shuffle service port = " + shuffleServicePort) Map( - "spark.shuffle.service.enabled" -> "true", - "spark.shuffle.service.port" -> shuffleServicePort.toString, + SHUFFLE_SERVICE_ENABLED.key -> "true", + SHUFFLE_SERVICE_PORT.key -> shuffleServicePort.toString, MAX_EXECUTOR_FAILURES.key -> "1" ) } From 2119e518d31331e65415e0f817a6f28ff18d2b42 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 5 Sep 2018 13:39:34 +0800 Subject: [PATCH 1544/2461] [SPARK-25336][SS]Revert SPARK-24863 and SPARK-24748 ## What changes were proposed in this pull request? Revert SPARK-24863 (#21819) and SPARK-24748 (#21721) as per discussion in #21721. We will revisit them when the data source v2 APIs are out. ## How was this patch tested? Jenkins Closes #22334 from zsxwing/revert-SPARK-24863-SPARK-24748. Authored-by: Shixiong Zhu Signed-off-by: Wenchen Fan --- .../apache/spark/sql/kafka010/JsonUtils.scala | 33 +++-------- .../kafka010/KafkaMicroBatchReadSupport.scala | 30 +--------- .../kafka010/KafkaMicroBatchSourceSuite.scala | 37 ------------ .../spark/sql/sources/v2/CustomMetrics.java | 33 ----------- .../SupportsCustomReaderMetrics.java | 47 --------------- .../SupportsCustomWriterMetrics.java | 47 --------------- .../streaming/ProgressReporter.scala | 58 ++----------------- .../streaming/sources/memoryV2.scala | 22 +------ .../apache/spark/sql/streaming/progress.scala | 46 ++------------- .../streaming/MemorySinkV2Suite.scala | 21 ------- .../sql/streaming/StreamingQuerySuite.scala | 27 --------- 11 files changed, 22 insertions(+), 379 deletions(-) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/CustomMetrics.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala index 92b13f2b555d1..868edb5dcdc0c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala @@ -29,11 +29,6 @@ import org.json4s.jackson.Serialization */ private object JsonUtils { private implicit val formats = Serialization.formats(NoTypeHints) - implicit val ordering = new Ordering[TopicPartition] { - override def compare(x: TopicPartition, y: TopicPartition): Int = { - Ordering.Tuple2[String, Int].compare((x.topic, x.partition), (y.topic, y.partition)) - } - } /** * Read TopicPartitions from json string @@ -56,7 +51,7 @@ private object JsonUtils { * Write TopicPartitions as json string */ def partitions(partitions: Iterable[TopicPartition]): String = { - val result = HashMap.empty[String, List[Int]] + val result = new HashMap[String, List[Int]] partitions.foreach { tp => val parts: List[Int] = result.getOrElse(tp.topic, Nil) result += tp.topic -> (tp.partition::parts) @@ -85,31 +80,19 @@ private object JsonUtils { * Write per-TopicPartition offsets as json string */ def partitionOffsets(partitionOffsets: Map[TopicPartition, Long]): String = { - val result = HashMap.empty[String, HashMap[Int, Long]] + val result = new HashMap[String, HashMap[Int, Long]]() + implicit val ordering = new Ordering[TopicPartition] { + override def compare(x: TopicPartition, y: TopicPartition): Int = { + Ordering.Tuple2[String, Int].compare((x.topic, x.partition), (y.topic, y.partition)) + } + } val partitions = partitionOffsets.keySet.toSeq.sorted // sort for more determinism partitions.foreach { tp => val off = partitionOffsets(tp) - val parts = result.getOrElse(tp.topic, HashMap.empty[Int, Long]) + val parts = result.getOrElse(tp.topic, new HashMap[Int, Long]) parts += tp.partition -> off result += tp.topic -> parts } Serialization.write(result) } - - /** - * Write per-topic partition lag as json string - */ - def partitionLags( - latestOffsets: Map[TopicPartition, Long], - processedOffsets: Map[TopicPartition, Long]): String = { - val result = HashMap.empty[String, HashMap[Int, Long]] - val partitions = latestOffsets.keySet.toSeq.sorted - partitions.foreach { tp => - val lag = latestOffsets(tp) - processedOffsets.getOrElse(tp, 0L) - val parts = result.getOrElse(tp.topic, HashMap.empty[Int, Long]) - parts += tp.partition -> lag - result += tp.topic -> parts - } - Serialization.write(Map("lag" -> result)) - } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala index 70f37e32e78db..bb4de674c3c72 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala @@ -22,7 +22,6 @@ import java.io._ import java.nio.charset.StandardCharsets import org.apache.commons.io.IOUtils -import org.apache.kafka.common.TopicPartition import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging @@ -33,9 +32,9 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder} import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchReadSupport import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} -import org.apache.spark.sql.sources.v2.{CustomMetrics, DataSourceOptions} +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset, SupportsCustomReaderMetrics} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.UninterruptibleThread @@ -61,8 +60,7 @@ private[kafka010] class KafkaMicroBatchReadSupport( options: DataSourceOptions, metadataPath: String, startingOffsets: KafkaOffsetRangeLimit, - failOnDataLoss: Boolean) - extends RateControlMicroBatchReadSupport with SupportsCustomReaderMetrics with Logging { + failOnDataLoss: Boolean) extends RateControlMicroBatchReadSupport with Logging { private val pollTimeoutMs = options.getLong( "kafkaConsumer.pollTimeoutMs", @@ -156,13 +154,6 @@ private[kafka010] class KafkaMicroBatchReadSupport( KafkaMicroBatchReaderFactory } - // TODO: figure out the life cycle of custom metrics, and make this method take `ScanConfig` as - // a parameter. - override def getCustomMetrics(): CustomMetrics = { - KafkaCustomMetrics( - kafkaOffsetReader.fetchLatestOffsets(), endPartitionOffsets.partitionToOffsets) - } - override def deserializeOffset(json: String): Offset = { KafkaSourceOffset(JsonUtils.partitionOffsets(json)) } @@ -384,18 +375,3 @@ private[kafka010] case class KafkaMicroBatchPartitionReader( } } } - -/** - * Currently reports per topic-partition lag. - * This is the difference between the offset of the latest available data - * in a topic-partition and the latest offset that has been processed. - */ -private[kafka010] case class KafkaCustomMetrics( - latestOffsets: Map[TopicPartition, Long], - processedOffsets: Map[TopicPartition, Long]) extends CustomMetrics { - override def json(): String = { - JsonUtils.partitionLags(latestOffsets, processedOffsets) - } - - override def toString: String = json() -} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 78249f7a80fb5..8e246dbbf5d70 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -30,8 +30,6 @@ import scala.util.Random import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord, RecordMetadata} import org.apache.kafka.common.TopicPartition -import org.json4s.DefaultFormats -import org.json4s.jackson.JsonMethods._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ @@ -958,41 +956,6 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { intercept[IllegalArgumentException] { test(minPartitions = "-1", 1, true) } } - test("custom lag metrics") { - import testImplicits._ - val topic = newTopic() - testUtils.createTopic(topic, partitions = 2) - testUtils.sendMessages(topic, (1 to 100).map(_.toString).toArray) - require(testUtils.getLatestOffsets(Set(topic)).size === 2) - - val kafka = spark - .readStream - .format("kafka") - .option("subscribe", topic) - .option("startingOffsets", s"earliest") - .option("maxOffsetsPerTrigger", 10) - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - - implicit val formats = DefaultFormats - - val mapped = kafka.map(kv => kv._2.toInt + 1) - testStream(mapped)( - StartStream(trigger = OneTimeTrigger), - AssertOnQuery { query => - query.awaitTermination() - val source = query.lastProgress.sources(0) - // masOffsetsPerTrigger is 10, and there are two partitions containing 50 events each - // so 5 events should be processed from each partition and a lag of 45 events - val custom = parse(source.customMetrics) - .extract[Map[String, Map[String, Map[String, Long]]]] - custom("lag")(topic)("0") == 45 && custom("lag")(topic)("1") == 45 - } - ) - } - } abstract class KafkaSourceSuiteBase extends KafkaSourceTest { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/CustomMetrics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/CustomMetrics.java deleted file mode 100644 index 7011a70e515e2..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/CustomMetrics.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.InterfaceStability; - -/** - * An interface for reporting custom metrics from streaming sources and sinks - */ -@InterfaceStability.Evolving -public interface CustomMetrics { - /** - * Returns a JSON serialized representation of custom metrics - * - * @return JSON serialized representation of custom metrics - */ - String json(); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java deleted file mode 100644 index 8693154cb7045..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.reader.streaming; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.CustomMetrics; - -/** - * A mix in interface for {@link StreamingReadSupport}. Data sources can implement this interface - * to report custom metrics that gets reported under the - * {@link org.apache.spark.sql.streaming.SourceProgress} - */ -@InterfaceStability.Evolving -public interface SupportsCustomReaderMetrics extends StreamingReadSupport { - - /** - * Returns custom metrics specific to this data source. - */ - CustomMetrics getCustomMetrics(); - - /** - * Invoked if the custom metrics returned by {@link #getCustomMetrics()} is invalid - * (e.g. Invalid data that cannot be parsed). Throwing an error here would ensure that - * your custom metrics work right and correct values are reported always. The default action - * on invalid metrics is to ignore it. - * - * @param ex the exception - */ - default void onInvalidMetrics(Exception ex) { - // default is to ignore invalid custom metrics - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java deleted file mode 100644 index 2b018c7d123bb..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.writer.streaming; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.CustomMetrics; - -/** - * A mix in interface for {@link StreamingWriteSupport}. Data sources can implement this interface - * to report custom metrics that gets reported under the - * {@link org.apache.spark.sql.streaming.SinkProgress} - */ -@InterfaceStability.Evolving -public interface SupportsCustomWriterMetrics extends StreamingWriteSupport { - - /** - * Returns custom metrics specific to this data source. - */ - CustomMetrics getCustomMetrics(); - - /** - * Invoked if the custom metrics returned by {@link #getCustomMetrics()} is invalid - * (e.g. Invalid data that cannot be parsed). Throwing an error here would ensure that - * your custom metrics work right and correct values are reported always. The default action - * on invalid metrics is to ignore it. - * - * @param ex the exception - */ - default void onInvalidMetrics(Exception ex) { - // default is to ignore invalid custom metrics - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 417b6b39366ae..d4b50655c7215 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -22,20 +22,14 @@ import java.util.{Date, UUID} import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.util.control.NonFatal - -import org.json4s.jackson.JsonMethods.parse import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} -import org.apache.spark.sql.execution.streaming.sources.MicroBatchWritSupport -import org.apache.spark.sql.sources.v2.CustomMetrics -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, SupportsCustomReaderMetrics} -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingWriteSupport, SupportsCustomWriterMetrics} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent import org.apache.spark.util.Clock @@ -162,31 +156,7 @@ trait ProgressReporter extends Logging { } logDebug(s"Execution stats: $executionStats") - // extracts and validates custom metrics from readers and writers - def extractMetrics( - getMetrics: () => Option[CustomMetrics], - onInvalidMetrics: (Exception) => Unit): Option[String] = { - try { - getMetrics().map(m => { - val json = m.json() - parse(json) - json - }) - } catch { - case ex: Exception if NonFatal(ex) => - onInvalidMetrics(ex) - None - } - } - val sourceProgress = sources.distinct.map { source => - val customReaderMetrics = source match { - case s: SupportsCustomReaderMetrics => - extractMetrics(() => Option(s.getCustomMetrics), s.onInvalidMetrics) - - case _ => None - } - val numRecords = executionStats.inputRows.getOrElse(source, 0L) new SourceProgress( description = source.toString, @@ -194,19 +164,11 @@ trait ProgressReporter extends Logging { endOffset = currentTriggerEndOffsets.get(source).orNull, numInputRows = numRecords, inputRowsPerSecond = numRecords / inputTimeSec, - processedRowsPerSecond = numRecords / processingTimeSec, - customReaderMetrics.orNull + processedRowsPerSecond = numRecords / processingTimeSec ) } - val customWriterMetrics = extractWriteSupport() match { - case Some(s: SupportsCustomWriterMetrics) => - extractMetrics(() => Option(s.getCustomMetrics), s.onInvalidMetrics) - - case _ => None - } - - val sinkProgress = new SinkProgress(sink.toString, customWriterMetrics.orNull) + val sinkProgress = new SinkProgress(sink.toString) val newProgress = new StreamingQueryProgress( id = id, @@ -235,18 +197,6 @@ trait ProgressReporter extends Logging { currentStatus = currentStatus.copy(isTriggerActive = false) } - /** Extract writer from the executed query plan. */ - private def extractWriteSupport(): Option[StreamingWriteSupport] = { - if (lastExecution == null) return None - lastExecution.executedPlan.collect { - case p if p.isInstanceOf[WriteToDataSourceV2Exec] => - p.asInstanceOf[WriteToDataSourceV2Exec].writeSupport - }.headOption match { - case Some(w: MicroBatchWritSupport) => Some(w.writeSupport) - case _ => None - } - } - /** Extract statistics about stateful operators from the executed query plan. */ private def extractStateOperatorMetrics(hasNewData: Boolean): Seq[StateOperatorProgress] = { if (lastExecution == null) return Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 2509450f0da9d..c50dc7bcb8da1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -23,9 +23,6 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal -import org.json4s.NoTypeHints -import org.json4s.jackson.Serialization - import org.apache.spark.internal.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow @@ -35,9 +32,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} -import org.apache.spark.sql.sources.v2.{CustomMetrics, DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport, SupportsCustomWriterMetrics} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -119,26 +116,15 @@ class MemorySinkV2 extends DataSourceV2 with StreamingWriteSupportProvider batches.clear() } - def numRows: Int = synchronized { - batches.foldLeft(0)(_ + _.data.length) - } - override def toString(): String = "MemorySinkV2" } case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} -class MemoryV2CustomMetrics(sink: MemorySinkV2) extends CustomMetrics { - private implicit val formats = Serialization.formats(NoTypeHints) - override def json(): String = Serialization.write(Map("numRows" -> sink.numRows)) -} - class MemoryStreamingWriteSupport( val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) - extends StreamingWriteSupport with SupportsCustomWriterMetrics { - - private val customMemoryV2Metrics = new MemoryV2CustomMetrics(sink) + extends StreamingWriteSupport { override def createStreamingWriterFactory: MemoryWriterFactory = { MemoryWriterFactory(outputMode, schema) @@ -154,8 +140,6 @@ class MemoryStreamingWriteSupport( override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { // Don't accept any of the new input. } - - override def getCustomMetrics: CustomMetrics = customMemoryV2Metrics } case class MemoryWriterFactory(outputMode: OutputMode, schema: StructType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala index cf9375d39b39d..f2173aa1e59c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala @@ -172,27 +172,7 @@ class SourceProgress protected[sql]( val endOffset: String, val numInputRows: Long, val inputRowsPerSecond: Double, - val processedRowsPerSecond: Double, - val customMetrics: String) extends Serializable { - - /** SourceProgress without custom metrics. */ - protected[sql] def this( - description: String, - startOffset: String, - endOffset: String, - numInputRows: Long, - inputRowsPerSecond: Double, - processedRowsPerSecond: Double) { - - this( - description, - startOffset, - endOffset, - numInputRows, - inputRowsPerSecond, - processedRowsPerSecond, - null) - } + val processedRowsPerSecond: Double) extends Serializable { /** The compact JSON representation of this progress. */ def json: String = compact(render(jsonValue)) @@ -207,18 +187,12 @@ class SourceProgress protected[sql]( if (value.isNaN || value.isInfinity) JNothing else JDouble(value) } - val jsonVal = ("description" -> JString(description)) ~ + ("description" -> JString(description)) ~ ("startOffset" -> tryParse(startOffset)) ~ ("endOffset" -> tryParse(endOffset)) ~ ("numInputRows" -> JInt(numInputRows)) ~ ("inputRowsPerSecond" -> safeDoubleToJValue(inputRowsPerSecond)) ~ ("processedRowsPerSecond" -> safeDoubleToJValue(processedRowsPerSecond)) - - if (customMetrics != null) { - jsonVal ~ ("customMetrics" -> parse(customMetrics)) - } else { - jsonVal - } } private def tryParse(json: String) = try { @@ -237,13 +211,7 @@ class SourceProgress protected[sql]( */ @InterfaceStability.Evolving class SinkProgress protected[sql]( - val description: String, - val customMetrics: String) extends Serializable { - - /** SinkProgress without custom metrics. */ - protected[sql] def this(description: String) { - this(description, null) - } + val description: String) extends Serializable { /** The compact JSON representation of this progress. */ def json: String = compact(render(jsonValue)) @@ -254,12 +222,6 @@ class SinkProgress protected[sql]( override def toString: String = prettyJson private[sql] def jsonValue: JValue = { - val jsonVal = ("description" -> JString(description)) - - if (customMetrics != null) { - jsonVal ~ ("customMetrics" -> parse(customMetrics)) - } else { - jsonVal - } + ("description" -> JString(description)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 50f13bee251ea..61857365ac989 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -63,25 +63,4 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) } - - test("writer metrics") { - val sink = new MemorySinkV2 - val schema = new StructType().add("i", "int") - val writeSupport = new MemoryStreamingWriteSupport( - sink, OutputMode.Append(), schema) - // batch 0 - writeSupport.commit(0, - Array( - MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), - MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), - MemoryWriterCommitMessage(2, Seq(Row(5), Row(6))) - )) - assert(writeSupport.getCustomMetrics.json() == "{\"numRows\":6}") - // batch 1 - writeSupport.commit(1, - Array( - MemoryWriterCommitMessage(0, Seq(Row(7), Row(8))) - )) - assert(writeSupport.getCustomMetrics.json() == "{\"numRows\":8}") - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 73592526fb0f7..1dd817545a969 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -22,8 +22,6 @@ import java.util.concurrent.CountDownLatch import scala.collection.mutable import org.apache.commons.lang3.RandomStringUtils -import org.json4s.NoTypeHints -import org.json4s.jackson.Serialization import org.scalactic.TolerantNumerics import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.PatienceConfiguration.Timeout @@ -457,31 +455,6 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } - test("Check if custom metrics are reported") { - val streamInput = MemoryStream[Int] - implicit val formats = Serialization.formats(NoTypeHints) - testStream(streamInput.toDF(), useV2Sink = true)( - AddData(streamInput, 1, 2, 3), - CheckAnswer(1, 2, 3), - AssertOnQuery { q => - val lastProgress = getLastProgressWithData(q) - assert(lastProgress.nonEmpty) - assert(lastProgress.get.numInputRows == 3) - assert(lastProgress.get.sink.customMetrics == "{\"numRows\":3}") - true - }, - AddData(streamInput, 4, 5, 6, 7), - CheckAnswer(1, 2, 3, 4, 5, 6, 7), - AssertOnQuery { q => - val lastProgress = getLastProgressWithData(q) - assert(lastProgress.nonEmpty) - assert(lastProgress.get.numInputRows == 4) - assert(lastProgress.get.sink.customMetrics == "{\"numRows\":7}") - true - } - ) - } - test("input row calculation with same V1 source used twice in self-join") { val streamingTriggerDF = spark.createDataset(1 to 10).toDF val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") From 341b55a58964b1966a1919ac0774c8be5d5e7251 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 5 Sep 2018 21:13:16 +0800 Subject: [PATCH 1545/2461] [SPARK-25044][SQL][FOLLOWUP] add back UserDefinedFunction.inputTypes ## What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/22259 . Scala case class has a wide surface: apply, unapply, accessors, copy, etc. In https://github.com/apache/spark/pull/22259 , we change the type of `UserDefinedFunction.inputTypes` from `Option[Seq[DataType]]` to `Option[Seq[Schema]]`. This breaks backward compatibility. This PR changes the type back, and use a `var` to keep the new nullable info. ## How was this patch tested? N/A Closes #22319 from cloud-fan/revert. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- project/MimaExcludes.scala | 5 - .../apache/spark/sql/UDFRegistration.scala | 196 +++++++++--------- .../sql/expressions/UserDefinedFunction.scala | 29 ++- .../org/apache/spark/sql/functions.scala | 78 +++---- 4 files changed, 163 insertions(+), 145 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 45cc5ccf2ea92..7ff783da130af 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -47,11 +47,6 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.fpm.FPGrowthModel.this"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.fpm.AssociationRules#Rule.this"), - // [SPARK-25044] Address translation of LMF closure primitive args to Object in Scala 2.12 - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.expressions.UserDefinedFunction$"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy"), - // [SPARK-24296][CORE] Replicate large blocks as a stream. ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.this"), // [SPARK-23528] Add numIter to ClusteringSummary diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 24ee46d0e8147..c37ba0c60c3d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.execution.python.UserDefinedPythonFunction -import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction} +import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedFunction} import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils @@ -113,7 +113,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends (0 to 22).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) - val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor[A$i] :: $s"}) + val inputSchemas = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor[A$i] :: $s"}) println(s""" |/** | * Registers a deterministic Scala closure of $x arguments as user-defined function (UDF). @@ -122,16 +122,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends | */ |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - | val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try($inputTypes).toOption + | val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try($inputSchemas).toOption | def builder(e: Seq[Expression]) = if (e.length == $x) { - | ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - | udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + | ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + | udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) | } else { | throw new AnalysisException("Invalid number of arguments for function " + name + | ". Expected: $x; Found: " + e.length) | } | functionRegistry.createOrReplaceTempFunction(name, builder) - | val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + | val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) | if (nullable) udf else udf.asNonNullable() |}""".stripMargin) } @@ -168,16 +168,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 0) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 0; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -188,16 +188,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 1) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 1; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -208,16 +208,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 2) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 2; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -228,16 +228,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 3) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 3; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -248,16 +248,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 4) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 4; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -268,16 +268,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 5) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 5; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -288,16 +288,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 6) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 6; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -308,16 +308,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 7) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 7; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -328,16 +328,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 8) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 8; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -348,16 +348,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 9) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 9; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -368,16 +368,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 10) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 10; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -388,16 +388,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 11) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 11; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -408,16 +408,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 12) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 12; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -428,16 +428,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 13) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 13; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -448,16 +448,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 14) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 14; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -468,16 +468,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 15) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 15; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -488,16 +488,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 16) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 16; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -508,16 +508,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 17) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 17; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -528,16 +528,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 18) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 18; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -548,16 +548,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 19) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 19; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -568,16 +568,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: ScalaReflection.schemaFor[A20] :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: ScalaReflection.schemaFor[A20] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 20) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 20; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -588,16 +588,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: ScalaReflection.schemaFor[A20] :: ScalaReflection.schemaFor[A21] :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: ScalaReflection.schemaFor[A20] :: ScalaReflection.schemaFor[A21] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 21) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 21; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } @@ -608,16 +608,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: ScalaReflection.schemaFor[A20] :: ScalaReflection.schemaFor[A21] :: ScalaReflection.schemaFor[A22] :: Nil).toOption + val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: ScalaReflection.schemaFor[A20] :: ScalaReflection.schemaFor[A21] :: ScalaReflection.schemaFor[A22] :: Nil).toOption def builder(e: Seq[Expression]) = if (e.length == 22) { - ScalaUDF(func, dataType, e, inputTypes.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, + udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 22; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) if (nullable) udf else udf.asNonNullable() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 7bd20dbe8f6d0..697757f8a73ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -41,12 +41,16 @@ import org.apache.spark.sql.types.DataType case class UserDefinedFunction protected[sql] ( f: AnyRef, dataType: DataType, - inputTypes: Option[Seq[ScalaReflection.Schema]]) { + inputTypes: Option[Seq[DataType]]) { private var _nameOption: Option[String] = None private var _nullable: Boolean = true private var _deterministic: Boolean = true + // This is a `var` instead of in the constructor for backward compatibility of this case class. + // TODO: revisit this case class in Spark 3.0, and narrow down the public surface. + private[sql] var nullableTypes: Option[Seq[Boolean]] = None + /** * Returns true when the UDF can return a nullable value. * @@ -69,15 +73,19 @@ case class UserDefinedFunction protected[sql] ( */ @scala.annotation.varargs def apply(exprs: Column*): Column = { + if (inputTypes.isDefined && nullableTypes.isDefined) { + require(inputTypes.get.length == nullableTypes.get.length) + } + Column(ScalaUDF( f, dataType, exprs.map(_.expr), - inputTypes.map(_.map(_.dataType)).getOrElse(Nil), + inputTypes.getOrElse(Nil), udfName = _nameOption, nullable = _nullable, udfDeterministic = _deterministic, - nullableTypes = inputTypes.map(_.map(_.nullable)).getOrElse(Nil))) + nullableTypes = nullableTypes.getOrElse(Nil))) } private def copyAll(): UserDefinedFunction = { @@ -85,6 +93,7 @@ case class UserDefinedFunction protected[sql] ( udf._nameOption = _nameOption udf._nullable = _nullable udf._deterministic = _deterministic + udf.nullableTypes = nullableTypes udf } @@ -129,3 +138,17 @@ case class UserDefinedFunction protected[sql] ( } } } + +// We have to use a name different than `UserDefinedFunction` here, to avoid breaking the binary +// compatibility of the auto-generate UserDefinedFunction object. +private[sql] object SparkUserDefinedFunction { + + def create( + f: AnyRef, + dataType: DataType, + inputSchemas: Option[Seq[ScalaReflection.Schema]]): UserDefinedFunction = { + val udf = new UserDefinedFunction(f, dataType, inputSchemas.map(_.map(_.dataType))) + udf.nullableTypes = inputSchemas.map(_.map(_.nullable)) + udf + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a261a7c1752d0..c120be469a268 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint} import org.apache.spark.sql.execution.SparkSqlParser -import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedFunction} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -3819,7 +3819,7 @@ object functions { (0 to 10).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) - val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]) :: $s"}) + val inputSchemas = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]) :: $s"}) println(s""" |/** | * Defines a Scala closure of $x arguments as user-defined function (UDF). @@ -3832,8 +3832,8 @@ object functions { | */ |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - | val inputTypes = Try($inputTypes).toOption - | val udf = UserDefinedFunction(f, dataType, inputTypes) + | val inputSchemas = Try($inputTypes).toOption + | val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) | if (nullable) udf else udf.asNonNullable() |}""".stripMargin) } @@ -3856,7 +3856,7 @@ object functions { | */ |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = { | val func = f$anyCast.call($anyParams) - | UserDefinedFunction($funcCall, returnType, inputTypes = None) + | SparkUserDefinedFunction.create($funcCall, returnType, inputSchemas = None) |}""".stripMargin) } @@ -3877,8 +3877,8 @@ object functions { */ def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(Nil).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) + val inputSchemas = Try(Nil).toOption + val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3893,8 +3893,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: Nil).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: Nil).toOption + val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3909,8 +3909,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: Nil).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: Nil).toOption + val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3925,8 +3925,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: Nil).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: Nil).toOption + val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3941,8 +3941,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: Nil).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: Nil).toOption + val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3957,8 +3957,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: Nil).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: Nil).toOption + val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3973,8 +3973,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: Nil).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: Nil).toOption + val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3989,8 +3989,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: Nil).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: Nil).toOption + val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -4005,8 +4005,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: Nil).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: Nil).toOption + val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -4021,8 +4021,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: ScalaReflection.schemaFor(typeTag[A9]) :: Nil).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: ScalaReflection.schemaFor(typeTag[A9]) :: Nil).toOption + val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -4037,8 +4037,8 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: ScalaReflection.schemaFor(typeTag[A9]) :: ScalaReflection.schemaFor(typeTag[A10]) :: Nil).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: ScalaReflection.schemaFor(typeTag[A9]) :: ScalaReflection.schemaFor(typeTag[A10]) :: Nil).toOption + val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -4057,7 +4057,7 @@ object functions { */ def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF0[Any]].call() - UserDefinedFunction(() => func, returnType, inputTypes = None) + SparkUserDefinedFunction.create(() => func, returnType, inputSchemas = None) } /** @@ -4071,7 +4071,7 @@ object functions { */ def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) - UserDefinedFunction(func, returnType, inputTypes = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) } /** @@ -4085,7 +4085,7 @@ object functions { */ def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) - UserDefinedFunction(func, returnType, inputTypes = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) } /** @@ -4099,7 +4099,7 @@ object functions { */ def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) - UserDefinedFunction(func, returnType, inputTypes = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) } /** @@ -4113,7 +4113,7 @@ object functions { */ def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) - UserDefinedFunction(func, returnType, inputTypes = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) } /** @@ -4127,7 +4127,7 @@ object functions { */ def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) - UserDefinedFunction(func, returnType, inputTypes = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) } /** @@ -4141,7 +4141,7 @@ object functions { */ def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - UserDefinedFunction(func, returnType, inputTypes = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) } /** @@ -4155,7 +4155,7 @@ object functions { */ def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - UserDefinedFunction(func, returnType, inputTypes = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) } /** @@ -4169,7 +4169,7 @@ object functions { */ def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - UserDefinedFunction(func, returnType, inputTypes = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) } /** @@ -4183,7 +4183,7 @@ object functions { */ def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - UserDefinedFunction(func, returnType, inputTypes = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) } /** @@ -4197,7 +4197,7 @@ object functions { */ def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - UserDefinedFunction(func, returnType, inputTypes = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) } // scalastyle:on parameter.number @@ -4216,7 +4216,7 @@ object functions { * @since 2.0.0 */ def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = { - UserDefinedFunction(f, dataType, None) + SparkUserDefinedFunction.create(f, dataType, inputSchemas = None) } /** From 8440e3072898e545f0aadbbf123df3450f222a67 Mon Sep 17 00:00:00 2001 From: LucaCanali Date: Wed, 5 Sep 2018 06:58:15 -0700 Subject: [PATCH 1546/2461] [SPARK-25228][CORE] Add executor CPU time metric. ## What changes were proposed in this pull request? Add a new metric to measure the executor's process (JVM) CPU time. ## How was this patch tested? Manually tested on a Spark cluster (see SPARK-25228 for an example screenshot). Closes #22218 from LucaCanali/AddExecutrCPUTimeMetric. Authored-by: LucaCanali Signed-off-by: Sean Owen --- .../spark/executor/ExecutorSource.scala | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala index 669ce63325d0e..a8264022a0aff 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala @@ -17,9 +17,12 @@ package org.apache.spark.executor +import java.lang.management.ManagementFactory import java.util.concurrent.ThreadPoolExecutor +import javax.management.{MBeanServer, ObjectName} import scala.collection.JavaConverters._ +import scala.util.control.NonFatal import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.hadoop.fs.FileSystem @@ -73,6 +76,24 @@ class ExecutorSource(threadPool: ThreadPoolExecutor, executorId: String) extends registerFileSystemStat(scheme, "write_ops", _.getWriteOps(), 0) } + // Dropwizard metrics gauge measuring the executor's process CPU time. + // This Gauge will try to get and return the JVM Process CPU time or return -1 otherwise. + // The CPU time value is returned in nanoseconds. + // It will use proprietary extensions such as com.sun.management.OperatingSystemMXBean or + // com.ibm.lang.management.OperatingSystemMXBean, if available. + metricRegistry.register(MetricRegistry.name("jvmCpuTime"), new Gauge[Long] { + val mBean: MBeanServer = ManagementFactory.getPlatformMBeanServer + val name = new ObjectName("java.lang", "type", "OperatingSystem") + override def getValue: Long = { + try { + // return JVM process CPU time if the ProcessCpuTime method is available + mBean.getAttribute(name, "ProcessCpuTime").asInstanceOf[Long] + } catch { + case NonFatal(_) => -1L + } + } + }) + // Expose executor task metrics using the Dropwizard metrics system. // The list is taken from TaskMetrics.scala val METRIC_CPU_TIME = metricRegistry.counter(MetricRegistry.name("cpuTime")) From 39a02d8f75def7191c66d388729ba1721c92188d Mon Sep 17 00:00:00 2001 From: ankurgupta Date: Wed, 5 Sep 2018 09:41:05 -0700 Subject: [PATCH 1547/2461] [SPARK-24415][CORE] Fixed the aggregated stage metrics by retaining stage objects in liveStages until all tasks are complete The problem occurs because stage object is removed from liveStages in AppStatusListener onStageCompletion. Because of this any onTaskEnd event received after onStageCompletion event do not update stage metrics. The fix is to retain stage objects in liveStages until all tasks are complete. 1. Fixed the reproducible example posted in the JIRA 2. Added unit test Closes #22209 from ankuriitg/ankurgupta/SPARK-24415. Authored-by: ankurgupta Signed-off-by: Marcelo Vanzin --- .../spark/status/AppStatusListener.scala | 61 ++++++++++++++----- .../spark/status/AppStatusListenerSuite.scala | 55 +++++++++++++++++ .../spark/streaming/UISeleniumSuite.scala | 9 ++- 3 files changed, 108 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 5ea161cd0d151..91b75e4852999 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -350,11 +350,20 @@ private[spark] class AppStatusListener( val e = it.next() if (job.stageIds.contains(e.getKey()._1)) { val stage = e.getValue() - stage.status = v1.StageStatus.SKIPPED - job.skippedStages += stage.info.stageId - job.skippedTasks += stage.info.numTasks - it.remove() - update(stage, now) + if (v1.StageStatus.PENDING.equals(stage.status)) { + stage.status = v1.StageStatus.SKIPPED + job.skippedStages += stage.info.stageId + job.skippedTasks += stage.info.numTasks + job.activeStages -= 1 + + pools.get(stage.schedulingPool).foreach { pool => + pool.stageIds = pool.stageIds - stage.info.stageId + update(pool, now) + } + + it.remove() + update(stage, now, last = true) + } } } @@ -506,7 +515,16 @@ private[spark] class AppStatusListener( if (killedDelta > 0) { stage.killedSummary = killedTasksSummary(event.reason, stage.killedSummary) } - maybeUpdate(stage, now) + // [SPARK-24415] Wait for all tasks to finish before removing stage from live list + val removeStage = + stage.activeTasks == 0 && + (v1.StageStatus.COMPLETE.equals(stage.status) || + v1.StageStatus.FAILED.equals(stage.status)) + if (removeStage) { + update(stage, now, last = true) + } else { + maybeUpdate(stage, now) + } // Store both stage ID and task index in a single long variable for tracking at job level. val taskIndex = (event.stageId.toLong << Integer.SIZE) | event.taskInfo.index @@ -521,7 +539,7 @@ private[spark] class AppStatusListener( if (killedDelta > 0) { job.killedSummary = killedTasksSummary(event.reason, job.killedSummary) } - maybeUpdate(job, now) + conditionalLiveUpdate(job, now, removeStage) } val esummary = stage.executorSummary(event.taskInfo.executorId) @@ -532,7 +550,7 @@ private[spark] class AppStatusListener( if (metricsDelta != null) { esummary.metrics = LiveEntityHelpers.addMetrics(esummary.metrics, metricsDelta) } - maybeUpdate(esummary, now) + conditionalLiveUpdate(esummary, now, removeStage) if (!stage.cleaning && stage.savedTasks.get() > maxTasksPerStage) { stage.cleaning = true @@ -540,6 +558,9 @@ private[spark] class AppStatusListener( cleanupTasks(stage) } } + if (removeStage) { + liveStages.remove((event.stageId, event.stageAttemptId)) + } } liveExecutors.get(event.taskInfo.executorId).foreach { exec => @@ -564,17 +585,13 @@ private[spark] class AppStatusListener( // Force an update on live applications when the number of active tasks reaches 0. This is // checked in some tests (e.g. SQLTestUtilsBase) so it needs to be reliably up to date. - if (exec.activeTasks == 0) { - liveUpdate(exec, now) - } else { - maybeUpdate(exec, now) - } + conditionalLiveUpdate(exec, now, exec.activeTasks == 0) } } override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { val maybeStage = - Option(liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptNumber))) + Option(liveStages.get((event.stageInfo.stageId, event.stageInfo.attemptNumber))) maybeStage.foreach { stage => val now = System.nanoTime() stage.info = event.stageInfo @@ -608,7 +625,6 @@ private[spark] class AppStatusListener( } stage.executorSummaries.values.foreach(update(_, now)) - update(stage, now, last = true) val executorIdsForStage = stage.blackListedExecutors executorIdsForStage.foreach { executorId => @@ -616,6 +632,13 @@ private[spark] class AppStatusListener( removeBlackListedStageFrom(exec, event.stageInfo.stageId, now) } } + + // Remove stage only if there are no active tasks remaining + val removeStage = stage.activeTasks == 0 + update(stage, now, last = removeStage) + if (removeStage) { + liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptNumber)) + } } appSummary = new AppSummary(appSummary.numCompletedJobs, appSummary.numCompletedStages + 1) @@ -882,6 +905,14 @@ private[spark] class AppStatusListener( } } + private def conditionalLiveUpdate(entity: LiveEntity, now: Long, condition: Boolean): Unit = { + if (condition) { + liveUpdate(entity, now) + } else { + maybeUpdate(entity, now) + } + } + private def cleanupExecutors(count: Long): Unit = { // Because the limit is on the number of *dead* executors, we need to calculate whether // there are actually enough dead executors to be deleted. diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 1b3639ad64a73..ea80fea905340 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -1190,6 +1190,61 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(appStore.asOption(appStore.lastStageAttempt(3)) === None) } + test("SPARK-24415: update metrics for tasks that finish late") { + val listener = new AppStatusListener(store, conf, true) + + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") + val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2") + + // Start job + listener.onJobStart(SparkListenerJobStart(1, time, Seq(stage1, stage2), null)) + + // Start 2 stages + listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties())) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage2, new Properties())) + + // Start 2 Tasks + val tasks = createTasks(2, Array("1")) + tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(stage1.stageId, stage1.attemptNumber, task)) + } + + // Task 1 Finished + time += 1 + tasks(0).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd( + SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(0), null)) + + // Stage 1 Completed + stage1.failureReason = Some("Failed") + listener.onStageCompleted(SparkListenerStageCompleted(stage1)) + + // Stop job 1 + time += 1 + listener.onJobEnd(SparkListenerJobEnd(1, time, JobSucceeded)) + + // Task 2 Killed + time += 1 + tasks(1).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd( + SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", + TaskKilled(reason = "Killed"), tasks(1), null)) + + // Ensure killed task metrics are updated + val allStages = store.view(classOf[StageDataWrapper]).reverse().asScala.map(_.info) + val failedStages = allStages.filter(_.status == v1.StageStatus.FAILED) + assert(failedStages.size == 1) + assert(failedStages.head.numKilledTasks == 1) + assert(failedStages.head.numCompleteTasks == 1) + + val allJobs = store.view(classOf[JobDataWrapper]).reverse().asScala.map(_.info) + assert(allJobs.size == 1) + assert(allJobs.head.numKilledTasks == 1) + assert(allJobs.head.numCompletedTasks == 1) + assert(allJobs.head.numActiveStages == 1) + assert(allJobs.head.numFailedStages == 1) + } + test("driver logs") { val listener = new AppStatusListener(store, conf, true) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index f2204a1870933..957feca2e552d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -77,7 +77,12 @@ class UISeleniumSuite inputStream.foreachRDD { rdd => rdd.foreach(_ => {}) try { - rdd.foreach(_ => throw new RuntimeException("Oops")) + rdd.foreach { _ => + // Failing the task with id 15 to ensure only one task fails + if (TaskContext.get.taskAttemptId() % 15 == 0) { + throw new RuntimeException("Oops") + } + } } catch { case e: SparkException if e.getMessage.contains("Oops") => } @@ -166,7 +171,7 @@ class UISeleniumSuite // Check job progress findAll(cssSelector(""".progress-cell""")).map(_.text).toList should be ( - List("4/4", "4/4", "4/4", "0/4 (1 failed)")) + List("4/4", "4/4", "4/4", "3/4 (1 failed)")) // Check stacktrace val errorCells = findAll(cssSelector(""".stacktrace-details""")).map(_.underlying).toSeq From c66eef84409b103accfcaa5073d50426e70b7870 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 5 Sep 2018 11:29:15 -0700 Subject: [PATCH 1548/2461] [SPARK-25306][SQL][FOLLOWUP] Change `test` to `ignore` in FilterPushdownBenchmark ## What changes were proposed in this pull request? This is a follow-up of #22313 and aim to ignore the micro benchmark test which takes over 2 minutes in Jenkins. - https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.6/4939/consoleFull ## How was this patch tested? The test case should be ignored in Jenkins. ``` [info] FilterPushdownBenchmark: ... [info] - Pushdown benchmark with many filters !!! IGNORED !!! ``` Closes #22336 from dongjoon-hyun/SPARK-25306-2. Authored-by: Dongjoon Hyun Signed-off-by: Xiao Li --- .../spark/sql/execution/benchmark/FilterPushdownBenchmark.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala index 41087f1a97174..8596abd1b4ff2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala @@ -399,7 +399,7 @@ class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfter } } - test(s"Pushdown benchmark with many filters") { + ignore(s"Pushdown benchmark with many filters") { val numRows = 1 val width = 500 From 925449283dcaef80e0f77e60aea6ef988bd697b4 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 5 Sep 2018 11:59:00 -0700 Subject: [PATCH 1549/2461] [SPARK-22666][ML][SQL] Spark datasource for image format ## What changes were proposed in this pull request? Implement an image schema datasource. This image datasource support: - partition discovery (loading partitioned images) - dropImageFailures (the same behavior with `ImageSchema.readImage`) - path wildcard matching (the same behavior with `ImageSchema.readImage`) - loading recursively from directory (different from `ImageSchema.readImage`, but use such path: `/path/to/dir/**`) This datasource **NOT** support: - specify `numPartitions` (it will be determined by datasource automatically) - sampling (you can use `df.sample` later but the sampling operator won't be pushdown to datasource) ## How was this patch tested? Unit tests. ## Benchmark I benchmark and compare the cost time between old `ImageSchema.read` API and my image datasource. **cluster**: 4 nodes, each with 64GB memory, 8 cores CPU **test dataset**: Flickr8k_Dataset (about 8091 images) **time cost**: - My image datasource time (automatically generate 258 partitions): 38.04s - `ImageSchema.read` time (set 16 partitions): 68.4s - `ImageSchema.read` time (set 258 partitions): 90.6s **time cost when increase image number by double (clone Flickr8k_Dataset and loads double number images)**: - My image datasource time (automatically generate 515 partitions): 95.4s - `ImageSchema.read` (set 32 partitions): 109s - `ImageSchema.read` (set 515 partitions): 105s So we can see that my image datasource implementation (this PR) bring some performance improvement compared against old`ImageSchema.read` API. Closes #22328 from WeichenXu123/image_datasource. Authored-by: WeichenXu Signed-off-by: Xiangrui Meng --- .../kittens/29.5.a_b_EGDP022204.jpg | Bin .../images/{ => origin}/kittens/54893.jpg | Bin .../images/{ => origin}/kittens/DP153539.jpg | Bin .../images/{ => origin}/kittens/DP802813.jpg | Bin .../images/{ => origin}/kittens/not-image.txt | 0 data/mllib/images/origin/license.txt | 13 ++ .../{ => origin}/multi-channel/BGRA.png | Bin .../multi-channel/BGRA_alpha_60.png | Bin .../multi-channel/chr30.4.184.jpg | Bin .../{ => origin}/multi-channel/grayscale.jpg | Bin .../date=2018-01/29.5.a_b_EGDP022204.jpg | Bin 0 -> 27295 bytes .../cls=kittens/date=2018-01/not-image.txt | 1 + .../cls=kittens/date=2018-02/54893.jpg | Bin 0 -> 35914 bytes .../cls=kittens/date=2018-02/DP153539.jpg | Bin 0 -> 26354 bytes .../cls=kittens/date=2018-02/DP802813.jpg | Bin 0 -> 30432 bytes .../cls=multichannel/date=2018-01/BGRA.png | Bin 0 -> 683 bytes .../date=2018-01/BGRA_alpha_60.png | Bin 0 -> 747 bytes .../date=2018-02/chr30.4.184.jpg | Bin 0 -> 59472 bytes .../date=2018-02/grayscale.jpg | Bin 0 -> 36728 bytes ...pache.spark.sql.sources.DataSourceRegister | 1 + .../ml/source/image/ImageDataSource.scala | 53 ++++++++ .../ml/source/image/ImageFileFormat.scala | 100 +++++++++++++++ .../spark/ml/source/image/ImageOptions.scala | 32 +++++ .../spark/ml/image/ImageSchemaSuite.scala | 2 +- .../source/image/ImageFileFormatSuite.scala | 119 ++++++++++++++++++ python/pyspark/ml/image.py | 2 +- python/pyspark/ml/tests.py | 4 +- 27 files changed, 323 insertions(+), 4 deletions(-) rename data/mllib/images/{ => origin}/kittens/29.5.a_b_EGDP022204.jpg (100%) rename data/mllib/images/{ => origin}/kittens/54893.jpg (100%) rename data/mllib/images/{ => origin}/kittens/DP153539.jpg (100%) rename data/mllib/images/{ => origin}/kittens/DP802813.jpg (100%) rename data/mllib/images/{ => origin}/kittens/not-image.txt (100%) create mode 100644 data/mllib/images/origin/license.txt rename data/mllib/images/{ => origin}/multi-channel/BGRA.png (100%) rename data/mllib/images/{ => origin}/multi-channel/BGRA_alpha_60.png (100%) rename data/mllib/images/{ => origin}/multi-channel/chr30.4.184.jpg (100%) rename data/mllib/images/{ => origin}/multi-channel/grayscale.jpg (100%) create mode 100644 data/mllib/images/partitioned/cls=kittens/date=2018-01/29.5.a_b_EGDP022204.jpg create mode 100644 data/mllib/images/partitioned/cls=kittens/date=2018-01/not-image.txt create mode 100644 data/mllib/images/partitioned/cls=kittens/date=2018-02/54893.jpg create mode 100644 data/mllib/images/partitioned/cls=kittens/date=2018-02/DP153539.jpg create mode 100644 data/mllib/images/partitioned/cls=kittens/date=2018-02/DP802813.jpg create mode 100644 data/mllib/images/partitioned/cls=multichannel/date=2018-01/BGRA.png create mode 100644 data/mllib/images/partitioned/cls=multichannel/date=2018-01/BGRA_alpha_60.png create mode 100644 data/mllib/images/partitioned/cls=multichannel/date=2018-02/chr30.4.184.jpg create mode 100644 data/mllib/images/partitioned/cls=multichannel/date=2018-02/grayscale.jpg create mode 100644 mllib/src/main/scala/org/apache/spark/ml/source/image/ImageDataSource.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/source/image/ImageOptions.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala diff --git a/data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg b/data/mllib/images/origin/kittens/29.5.a_b_EGDP022204.jpg similarity index 100% rename from data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg rename to data/mllib/images/origin/kittens/29.5.a_b_EGDP022204.jpg diff --git a/data/mllib/images/kittens/54893.jpg b/data/mllib/images/origin/kittens/54893.jpg similarity index 100% rename from data/mllib/images/kittens/54893.jpg rename to data/mllib/images/origin/kittens/54893.jpg diff --git a/data/mllib/images/kittens/DP153539.jpg b/data/mllib/images/origin/kittens/DP153539.jpg similarity index 100% rename from data/mllib/images/kittens/DP153539.jpg rename to data/mllib/images/origin/kittens/DP153539.jpg diff --git a/data/mllib/images/kittens/DP802813.jpg b/data/mllib/images/origin/kittens/DP802813.jpg similarity index 100% rename from data/mllib/images/kittens/DP802813.jpg rename to data/mllib/images/origin/kittens/DP802813.jpg diff --git a/data/mllib/images/kittens/not-image.txt b/data/mllib/images/origin/kittens/not-image.txt similarity index 100% rename from data/mllib/images/kittens/not-image.txt rename to data/mllib/images/origin/kittens/not-image.txt diff --git a/data/mllib/images/origin/license.txt b/data/mllib/images/origin/license.txt new file mode 100644 index 0000000000000..052f302c4670a --- /dev/null +++ b/data/mllib/images/origin/license.txt @@ -0,0 +1,13 @@ +The images in the folder "kittens" are under the creative commons CC0 license, or no rights reserved: +https://creativecommons.org/share-your-work/public-domain/cc0/ +The images are taken from: +https://ccsearch.creativecommons.org/image/detail/WZnbJSJ2-dzIDiuUUdto3Q== +https://ccsearch.creativecommons.org/image/detail/_TlKu_rm_QrWlR0zthQTXA== +https://ccsearch.creativecommons.org/image/detail/OPNnHJb6q37rSZ5o_L5JHQ== +https://ccsearch.creativecommons.org/image/detail/B2CVP_j5KjwZm7UAVJ3Hvw== + +The chr30.4.184.jpg and grayscale.jpg images are also under the CC0 license, taken from: +https://ccsearch.creativecommons.org/image/detail/8eO_qqotBfEm2UYxirLntw== + +The image under "multi-channel" directory is under the CC BY-SA 4.0 license cropped from: +https://en.wikipedia.org/wiki/Alpha_compositing#/media/File:Hue_alpha_falloff.png diff --git a/data/mllib/images/multi-channel/BGRA.png b/data/mllib/images/origin/multi-channel/BGRA.png similarity index 100% rename from data/mllib/images/multi-channel/BGRA.png rename to data/mllib/images/origin/multi-channel/BGRA.png diff --git a/data/mllib/images/multi-channel/BGRA_alpha_60.png b/data/mllib/images/origin/multi-channel/BGRA_alpha_60.png similarity index 100% rename from data/mllib/images/multi-channel/BGRA_alpha_60.png rename to data/mllib/images/origin/multi-channel/BGRA_alpha_60.png diff --git a/data/mllib/images/multi-channel/chr30.4.184.jpg b/data/mllib/images/origin/multi-channel/chr30.4.184.jpg similarity index 100% rename from data/mllib/images/multi-channel/chr30.4.184.jpg rename to data/mllib/images/origin/multi-channel/chr30.4.184.jpg diff --git a/data/mllib/images/multi-channel/grayscale.jpg b/data/mllib/images/origin/multi-channel/grayscale.jpg similarity index 100% rename from data/mllib/images/multi-channel/grayscale.jpg rename to data/mllib/images/origin/multi-channel/grayscale.jpg diff --git a/data/mllib/images/partitioned/cls=kittens/date=2018-01/29.5.a_b_EGDP022204.jpg b/data/mllib/images/partitioned/cls=kittens/date=2018-01/29.5.a_b_EGDP022204.jpg new file mode 100644 index 0000000000000000000000000000000000000000..435e7dfd6a459822d7f8feaeec2b1f3c5b2a3f76 GIT binary patch literal 27295 zcmeGCcUV-*vM>&>VaQRUAZY{yCFh)xBoZV?5rzSVFat9~&Zrp3ARwrqf*_JqG9p28 z5{Z(t1Q8?%f}s4?pxeFAzUSQcp8G!U_k4exhFPn+s=B(my1F{8qmiQ-;KT(jT`d3t zfdHxCA8<4aaB2D@odG~!AK(W7fE1+B0Z zpO6qwBv4`_0G|K>8-Wo(rR{^oIrySs8V;UbST+JPfJN8S!3BwOft^QSkuE5hi5CKn zbnrl8acl%w0COJcg2Xv^!0=oMd6=my0;Z3^VbES^4-k)n>HA<2KAteNGfWMG6NDKc z{9smSjJpEN6^HYZ7ZLIE^Aq+&;5?7ngwYrm5hThPjq!BAA<-xi90KEs6>>m13Bl1Q zCnO#RDn$g4ipvR0IS9fW*$7Yo$A4cb8-YK>TEhVc^F!iXVeUvA7KTQ_oE-29JHZ^$ z{xGmUm;(yw2{yyi$I~-F5GE}xITxNRAqrC$))1!BbP4dnAstYHu!{~j1k4D5L86^N zx)BcTf$(%d!hfNtG|*mN2vC-X2Rc9jrXvg!6NAA$VP_>_QX*o~FeFO&oC3^37$zqy zd5o6;8U7?(1oN%0yrc96S-GnwF*@iVe^Ni~tkB5HJ8>AnXNT0W^RDVG{rc z1OOg@4~T(*JQxrU10(=YXFwD$?N>EM7&O}11cCGMGIVrF9&>z#&n`B0&<_F}H@kmEKP~Dd71Dguut2cB1=>oeV)o z1AQE%pTu)&c;QfgG#xaJBj$I;*afTqJEMU?Y5bW%;r`5Abo9XfrWm;3e1B&&Jw49< z&VbhcQ?xqV-Q|}dj++Mn#@gx{pf>oB0vJ2NU`}WsM^*F;K6{+BG2npu8&2Khzw)VL zoJ`G7I4y2t54=bGPU;>`uz!X(#(Ln6@kah0s+fj9_{n`i&qQO+JK!AfEjU)v$i)cz zhcDq_JTLghs|{R4IR6y$w`xtnSoqJXHGzBl#2aDYs#bUj$LRzd3~Lsy2(WWOzXn6% zALgUR6O2Kf)X_K`+S3D#a`{z+@(&7Lg}>mb{=gd}U0nZ4p!kCT>h_x-AFH8jKXt%I z{NuR(>wO%@d5>w-zbXY^D&fGfzd^T-apd4Eah!f)38{{;_~_4ZT%P40^U&?%JlXLR zd_az6@!>rd^#5O@8$=)LwR4Qs0YsdgL5c_XBtFJKeqkP^XZ*&9gA@p04L<+ikvhiU z$HOnovCKU%58%h!Pp$v{gf8RBf5-mJ|KHG#Z9$5c{m1AePz6UT`1@ry4m<=L&zi@) zr|?-Qz70RM$7ccRO8~$abPVEK`(M*vaCkmE|39X`F#IO`87=^nj|V;$f|E8t?1=V3 zIbr`cTH!b`kanyl-o=0W5s3Vo9|G#X{q?6B#=qV1xCz8q4_xxY#B5}c>eZ2kwMU2LP%Z1;uQqu%4-S8L(%Aj)a!J)MgCe}{#{@zClNpb}Q*=FDVBqWjb^D<{7>{4m%&$(t8~-0M`dF7grlOxb#10-fQwNuS z;G%|u6R;`5AE%4eG1b5LN7y3$1N{%oWUgpTpqdA`+WHmIXnrx7aA+@Z62~I`8mGuT zz<~B&u#}F+0r0=zDZnmv{X6^d2s(?e2%K+0cUIXrhJVhJ4976ug^!~Q;pyYN89t9U zG7X>xIDxadBY*(#GrKDY&G7y<(aK>vT>fB_od`al{l5ICBnG}6`3us1R`)Y84E zd0Yjh>O~|9O#lu%PjGGoCqS5`l{JjyEjaARz|js4Fb6o+OW#BbzXSnvH8o&ZkO>d} znRbVNj{gB69U~a*PyGJ~qJ?{5z)2Ef1aeFG;k@whH4tWSbU%jg;$aM^6o@ax!!AGJ zhGU+e@E{&`@r75Wa`Y*4!V1$O)k?1aF=!3Em{$PdSXo)5kO zaDg0f>Ub=R3Ezi*Sq>u+ zaAACdf2aQ=!*9-i4g94(VSIhkVsg^bLgK<=;O^>pIK%Ls1}?JEJ{TC*%K?sn3H{x% z{?`TnqSjyZ5Htoi2N(njtjZkp0{l)-$IpLw5X6I70&uqq z!=p8fO<`~!j4vMkH7gQ-0J4A_ARj0OUILXsEzk(G0$sp6UAgwg{(lfAm0d}1e62}1ndO71R?~o1Zo6&1f~Rb1TF;L1c3zC2;vCt5{1Mu$gdxaF%d`@COkU5gU;Jkvx$u zktLA}ksnbMQ7Ta`Q8`fy(GbxGqFrKQVkTlfVi{r`Vk=@d;vnJ};s?Yf#0|uK#Ph^E zBqSs(BmyK#B!(moBsh{Nl5~Baa}zOI}RgLOwyhO+i7y zMIlFVi2_LxPH~swIYm3gEX6)09VI`d2Bj^f4`m!>9%TdNDCIU46%{X)DwQ=Ajw+5S zpQ?#!ifWIVo?4h%m)eOsjQSpR1$965It>}k85%X3%QS&BsWhcDy)>Urke%Q?p>e|D zMCgfoC#p}pKe0ngPb)@iNb5-(L;H-jgLavYgpQX^lg^3m8eKMB6Wu&Ll%AVjgC0(Q zl|F~Qh5iErF#{iiE`u9GG(!=?TZRorIz~xGb4Gv0dyKCcXPBT&yiB@G9!&8}FPVm! zzMW)0sdf@^^47`Xll>>ZGP5$PG9#F8GnX(AF@Ix$v1qcmvm~-qvP_;LJjH+N(kZ`F znWtJ#ePU%`RbYj)-e!HtI>tu8#?NNV7Q~j#*3GuV&cUwD?!}(Y-o*ZigNZ|x!;Rw( zM;*rsCq1V!CzA6HXFcaf7$Zy#<^j74Yldx{W;?BO+UNA6({E27aPe`OaYb^Ka7}WP zbIWo&bKl{9!@a@7$)nE`%u~QKc82VX>>1ZHX=mEbeC6fiwcx$MTgki3$IPe87syw@ zH*uEgtjbyMvyaaXo+Cadd(QpbgLCir3HYV>UHR|v_X-dQNDCkZ9tiXa5(&x)q68lc zz89htQWNqODiE3#W)i+A94TBWye`5cVl9#=(kAjlR7%uCG)Ht?j85!=Sfp5u*p@iI zI9&Xm_@D%pgqB2@M3uy*q<|zsGE;Iyie5@z>W0)CsqfM<(irJt=@l6snJY5)W!}p& z$QsGU$hON7$*IeQ$<@m3%S+4SU>*}KF81-`XuNv|iAsP+mq35;F-#-6V zlTOo2GedJqi%ZKzt3Yc*TS_}fyFrIYM^7hF=e;hwE?hTXcjJQ0g^&x)dX#!5dKr3i z7x^z@F4pK1=wHxJ)}JsqV}LTKGz1KF4U-Kgjrfebjb2?MzGQUi-laujG2>w4HWPXi zdy{7-yQb=<@uuTuXU%-f-k8&v+nVQ_f3-Mokz_GrDPkF7*=2Rg3TahkO=4|sool^g zqiK_3^TAfiHp+I`j?d2DuH*8l%O00s+tb)P*q2@*x?*wV*_A^FLx;x>JC3@J_Z>Ij z8t^pus*|eI9j6t95+Vt)N;;Q1B>iP+(fy_W|y6L!OxqWpvbkB1? z^04qI@g()U;#rNNL%E}x(Hv-hbf1@iSCrR`x4d_X_Xb7}lZ%C5?XXoi25_;{?Q_=W zy3eeyvhO|LJwJ25Qh!>1lz(>sf55GPr9iE~CqaZk@Sr!rr-LJcXG7FN9)&_e9Yf!Q zafe+C`w*@j{xpI-!abrVQZzCt^2=3=t2NiyuZ3Tmzpj0~Ac{H)6ZQUv!i}t(gg22l z-`6NnQ~3B!phiBFSglLC|G zll7A;?!fNczVju;KBXg7GBqoWJk2L<=I+J2Rp~tG3F(I!t{Fr3)bEwtXS;vv{?`Xi z5Bf7zGmEp>v!b)UJw!end8GZQ;xX^z)NJBx-|WR4^PILPvQP4JPvu7Ee$PYY%{(=J z+MF+wpZ|>QS$qLOflt9op>5&2qVq*n#X`l8N|;Kbp99Z*o_~CC<;Czzy_auFWlM|8 z&XnCRr!T)<0jUV6*r;@^oT;*^>Z`s`{ia5tru>!gtGwFNwf9~#zK*XWue(-vR3BKs z)8N(csnMlz{*C>c@g~cr!RAZNZ(Hf4ptYdhpSsyk&nE4!q-D!Qe* z%X_4H%HK-At>~5Mt$HW-uBK11udZLEzj5IFK;6 zqnl&CV+Z5m6GRiYCr?bKPO(iro<28SG9x|ndiMNm_ngVx_&j`mbpg9@_~F_j)ne)r z=TiQ%_;T%v*2=q&HXlE%daWLOidv&xyT5*Rz3j8<=k5*5jSric&7-ZDZI=^A#?RxGWe7(KLviEdfcE9bL#kZvczeAG4yWh`!ulb?(W8%p3=;$wNA2R%!T^I=m z*QvkLKg0bd{!f9U*%P+E){-c!u!9rY5g`mmdy4owc!`J!i;4h>D*oW=&K-e+IU>Lv zu@c|sSM_`_q>~b#g_ORizLy5V6{!=5L6`&@n8E|y;c`xVD#}#kivIHco?hU2xdY7K z(*uQ-_gCUNmM#y%c(e!??3e}TuEh28NM0U)H(&}LlYyts60%&N5HdxKle4_>d97bE zz?>4-Usk?;636H8;ObmXL{wZvTwDm`5W)taa1Q=LC@lA}jGsczgKK;Y^7zaWh8GH+ z$NAutxIiV2wbS*&!6iY3HP9Hu?_031 zll(7B8-sp>M*X)<`NeW9;_o&57hB>TgG^Bb-;00g)t}hEA^%xdJHdZ zMUlT(qX;7Kga4Pq{!t8Hx0(m;Z*}86Sl+_{<)XypFZAoIg-iMW>Kcjse|(Qb{{QA8 ziTuB}=l|VJ68T>u$Uk1_frm=}v#X?k+M&O@=D0)uOaI|J6^F!mApW0s=l?*vKNb1k z4p$lDw=D2)G5^QA{$s9x z%L4xv^Z(y@*PkZ^2o!ic;0K-q9F2pQmjn<9{{L5k5<&5an2-=kL_!Q6xPw1ZG72(M zQgTuf5^^eXatcZ?k&sc-P*GCjVZ0E0aeOK8pOTb>6wmSRrK1LbmK11*2tpxT00Auo zN((t^1~|c686xlk1WbP%m_P`igha%kO5_wEL-`3%eJGRwRF(+;5h5f6qyvPsM0BUc z)QIUXIgoI9Gl)kdJ|g8-f6>Tj+_%9a;fRSOBWF6v%yQ}sFW=d7{E|}AGO}{=8s{~& zv~_eZn3$TGTYy(ZaB#2Y;)-;`;(UDl`~w27Ub`N3Ml(z5c3%Bt#`H%-kgt!?cco&5uYL&NV!M#pC7<`+IJE-kNo+}zs!va|bjZ~q&< zT@V2Jvsu4e_P2J?g6$#z`+$%H-!2G&ANYgP5)z#jBc@ZkMB?C0&m|r~%AlV3=tUzL zw}kNqqa&t|oQX$r_RJ=}X~!-5pEWG=zV`9d$YjXlJOHi;`vh62i4op`N}UQ908@~sT1C}GtT;MUS0dF zcIFbb$~Dt`CX)Im);6LvlWvJZrK*hCmAzB>1fHn-W>441*k$vxU+xEbA3h~`rWL^I zpU(unUCCZd(^)&XN}?F&N^rAMIVyJypMxm3~H zvrw^za$^C9=i-S;3^$8ZUPvqPOTUC|sN#4v6kjYObSykm20-)vYX*nX^A!Zbe8|(f29! z{h~vaA{nbU{Gsy9_ zxsFP4o0^QKNh^|Ab;L*4S$OpK zgNfr3B8F-G;Q~({UvE9PO8$kyJgNDmX@MiKRAP-BA$OdR;7r@eH=d?2uWQS|j3osc z6O@dGrsqY66pYN3P?)Ru?L7I>+qUuU(tSr`fwaBs`Gc?rgu`SR~l;AZJ!>zj`Ujz)-F@ypct{i$~K!?n)EIPJZ^K3uG@pYPOn72z_Fl zyxKuX{-Q%dS29wz53c2V=3ZQMxO}PjQ=)X!HcF&7C04r|a~7feFy{75oisA9Ww$qP zY9;j~({^y5Q0IDv_`O})6a|{Yvi(q0cqDwjFPQsbMe?kj@URBAE*e2;%)*M5xCW%k zvhZBZBzcE>cHO`w_st#8x&;weo}3@O^H1E*eQUKnnBH+=1G3G(W|&eMgmwvril zyCTcV3vg>u<)+sFTq@OY@*_Ws$4JTKs!cpH`Bi0+`Y_SwlnLZQgsiY>4`i7h6}`OHqE&!$CNfrnow4u zV8AdZID5sEc_H39MM!3FM8;*^d`M0(s=lEjpDODwA)-nCp+vAN3zt)AmB&ToWto$ylK%|DG~D-G^$<6E^o0vgaakTda78G&W| z8BzJC?5D*z6?9<}Cz9WC8=s(V7Mpod-23UgZ}B&;ORQB?X%AKVRLNe=SzF+dFx9o#_7lbI;f3lB}iI-NQb4 zETuecNjs2)ye*H0@70svwxJGj&voa#w?0UR93R;Ty#smsI*#N*%*y$!S9Yw(YWr8; zOoj4hzEAh(s`{6CZEnT5d?}#SjPq9i5NLkQmtTJ9W}&D=+*znHmDE)y(5LgT*0d{! zB?ck<3)bG~Rg-RFr@&B~yjcr}fvD*k^OErq2HT-)r6}1IV`5w&CSuR@-N@OMZ@1s+ ziLnWBd&J2OzIt_nk->*26X3|Zy|AMaH^n&6`%P(RxQI_xP^kO%C<}Aqu2$2eu0@#8%gW`7(gMlfsD~C$s+7fU_f|B$q4$`l!3IAXmu*`g5{qHGuVY9YeFVgu z`rfZKjuISp?Ce4236L^V#_VUhT-Kgqm&ab>@6$?r!W;O~njmyJFC=FD?F6>JHcEuq zzqI+JxOSfZ1IN^pV>&OjxnY2oSlmwXJeno=LSI>Y+34z);3LIsCagQJz^3ua*c!&@Ci5pt37 z^Pw$Sm(Ij}J_5|-JTQ|2i+WjO3UiFbnWb%Wp2B;jDHI&Zt_7T(a*plH34Uylh=z&4 zyXz`$BDXYi4>*dlFOycvbUYcq*i3olx&9N34BN~SN`7E>lLpTHQO85xwBw4-vI_Zr zSCf6&*SKQI*4nC?jqeNI33(j@F#^&Z z=}tAl#x|W#oRR}Fv{k#I1lE$2kXbTL0jC3~zn}f)Zo#gdMzi((-R*N$LQaEcyt)Gm zH13P1lR>0UykhZXuTmn8J&<&j$>bbzc%3vE+RIRhsGAlP>pe}*!{^p@GQE7Bk5KyL z`j%L)pnB@fF@^y4(Y>KfJ-zpmB8=Zt=X8a+UwR12&~Bev3WqE5V(+`;f5K*k9RZR4 zhq6_Bm77^jffwH=>p$rgLl)7-Clb20x<@i>u+PQBY#DZv&5)1nt!(fL(a-rbR+h&W zjd^|4?Cc)kmNvV~qotPE0ro4K3s+sn_c+M{PKrasMzl7MS6Os`<7uSdh4hMK0#&$i zn=gaz*7}6l*1Y_){G3}U(k1ZK)&a`m?gqq5e+J>Cs-i?h8^dg#)7QBgY7-X8^2-Ii z=s>E-au4MPaf|Vvb|fLUafk6mTHpLNY>MEC{rA2<$Q)#dQTgD`V9g(pb`%kso5&Iyz{0MjxE*Ft~AMVMw>pVSc-&w=^9QeAb zFKs3tvHiL@X*__Pf8mt51-O8aT}E8S873aJ3BW$aq`Vnhvb68OchyM4Bj#L6jHuIdouyX4g`fvI=n;&#n;rg!cYg$J&aerXfchkt7 z{jr9vQ`*bU4OXsW64&+)dZr6h>A9*h8(%11TuY(`h@MNtTuBkow7X`K{E}spzGb(b zqLwkF#NPVjNptJKf{Up}`!BD%P`qN{Jlo#!Nv=e`?+8G?iZu?po6)OT>TmXhSl1$XYVfX9m^Ex3)})DOx&Im7hgP z$lbe9Ym#Rj2Wxeo>bx5n7c6aKE2yEQx$7WNdEbeSjD-12uB5xcl35IYJz}$e(>4Zy0U+W2Sc@-ZD?7~?Qr(l01pA-fn^K9aI-Ap7^C z^=oS?%PL=#m#eX*(b=Ed=H`iQqe!4`Off$#*3qil+9zrGr6huv*(oXL77YExGRJ>D z$z`(Ak6PK>gSgQlvcbtyMP@YE_TDVIagDoqz|F*FIdCDN0GS5staz*OxFpja*T}3H zQ!?c||FrPIhm92G;=Jb{GcvO{N+_`*`P?x9x&AYrFPk%LOs#iaN%C1eZ^E6&0iNm}Ubt3U5P}QS~IZ1N0Qm!~@dZsu_ z@69l>GP~%Px8&d7GOD6kEy;~Q><*%DtE^=6;9+;^SRTdBRczp=bhvfPnv?xndZD>V%&wf{E zoRS@K8C03&t#W0Wdf%JqkhI(mq-Tnfk-t2f(8#=CF(E#-r`?z}l3HUQQ+5QLuzlTq z1#Xri{HQFR7N#lg=}6wjojTjsBwz3w4Y-c*_Y1czkobkrs}fg zY7dW^cx-pi!u>&ZrfSK#0UtESo}>QPzU4sbTY;;dSgMxgfMU{p#pT&+h#bl{XM`4& zM`f6x`Iqd7lzN{kORR8kf-|N!)%I{-Y$*>B=~%vGyHIi2C>z!4t}5v z|6UucfwHtcl)Cw>=AnQ2?${`hB;g(tH}9u54YNR3=1t7)N+{nv;FE1wxp;8mO9^%V z5r8J%et%%t;Fl^^;OQG}rJEi$BScQq@cd99sObo(L`V97U$i6U*i~E|t=1l;ip7mu z`O!~bmt&WRBnr~;6BupZQFE#KK1P%gKU=^-v&oURNunBW+J}BTDCp#v>EYIC7KJG+ zgr1wvh9tJnT(#EE+y3e{oa?M$6;@hvqueOa+qz9(NOeG5DgFsmY-;-4K%Lsto|jbz z{^!X_uNA> z`0PP1<#J%!$iT&+;=|R78uPO~ti%=3+y`3)w2{=y$%Is+`z54}oM zUKPhK&)Rj1Ejm5Vvd@gP-Dmreic9C?QZXJ2YJsV=bCXfGFD*LBqg(wCLpCGzW%TQl zD58S|?WVO+yivnU&bJ;5TRugG4w_vcK0Azg#Hq6Wjb==v>8tgP2AfzBwU<@SRb^W< z!f-6xGYrGCibW5nwJ`%oJL?^pyY6GoSqChKtl&zZ>`~PlGdTzNy~bj;8hM97kHAKhCfKnK(kB(Ee$3OufTvE0ZB= zWcE%HWC9M%(77|WZk0+Vm!aZa_mZF3xz&@e^KA4?-3wv3`yf@D{~Zjc+O2+PgvF}P z_At6qj)#*AYzq_14+6_Qm6k5vY}Bxd^B{K;rler!<8d$JiC*1a~P z!D7k%TE1S~wRJu+?8`VA?i18YXaCWN3xkG2aekQT*vqh3(jomG?v_c1K7+)QpdH-dx=63svjA*B@ zuvJ^4gk+HI4BJ(<*v6#el)A<;a7uq8KMo0-%XAw>@hyfF zx;e`-e)Jt7G%1n^FZ|L+)hkrrRItaXjVFZNc)UN|M01@(Wbb+|4W<99H@hn2hbmId z9dmNEDzNI;!`wdSczoOUPv};OE64@sH_SXo%WX~{(o4)C;tHhlt{&P-T2{E2{J~*+F83UpcPOUZpKxMbKM@#KrJXYG?4c=pP-%w}X5yEzD62JClNNfpA zti0oCyAJUBZrGd&^&n;zN^;9udoj0ME3*Mt8Z3UB&T_(?4{ zBBJKPb>Fekw}rNImfAO?BRP07KDTo4PG--7J8MH^*lzn_Yiy!ULPTse^W~F(pOImS z|I%#QHDyu%z$Lv+#-iO458*ROD&jNaJ@oS79}Kw(VY1LllMaE@fko>Z;^pz=8-gmd zJQXiC`J(2NJ{!sy*$0TzNL+kfX|7Glx6SaN&eH0H=atTe8S#{FXZH^D_ljm%3d07z z*o4uJ?1RUTokGQ2y7XE6vnRPf@I|&vvM@+u8=>dlJ#AZZL6^HCtw%RC@(8LF)}uOb8YRhifha!&bEij*$ zc8<4&ab$lnjA$xYVAJ41&c|fY0aP*PClhV)!Zf#WhV1qacY&$P6p`M-Gd4UHm2O#N z<@O}V-Ha!((Y@}ECj!z4hSh4ss^OT|r_?F)Q&e*v90nH(aLf6k#Tr0$_+L^P9 z+wt1mD$JO}o_Ud&-Rea9;YG^D>kFTZ#3J2WaP?6toQmA@YAUxBEjD(?FPKLL-HQ(D zSgZ;UxR!HmEpC17ofb83;=AG#iPz2gX={C)^g?J%HeKRSY34QC<<;ES4`uFDO`40R zNxVX-A;x!=Xu2h0;3n#Ql3)4-OER|))X_rvwbw$|Os~W~=5^rpmu)rZi^(+1@*i)j z(T@Lerq2d_>l%qq@xwY>`V+cMj8_^1-d^fWZ8$NN_k1HY`kB>=vcVJakGg5CEjNY2 zJ2*1ii(#)#fNoLejR2u zFA^wCvte0hQFOu6=8K%6P-%dBkpwmak6Fli$dOfiNV9QB1l7!&lBJ2RP}0_8K!$F_9UjHI%T7pNpBb47}?dNEIoerGr!4y&0Y<;Uds z!s1B3mULK_h1jbO*mj`$b*ugKsMiX!nsr7b-VQX{0;9zBI`uPIp_0jS=uagNb-l%E z>Nmc;=isaf##oO=#NgOrNsxg(eVoaF#;LCS<)w?glWe%0A6-1)DrU)!b|C6=4U$tI zm??5|#MY0hEJl#Qlj)0wfR=s_H6IHg2{tX&-B_ws$)pr)^HQvy?atcS& zkfFn6Xp3jCB5l~vmx6>9@*AuI+3~vbiPM#CDY*LeALU5P?xtwHrS$`|o4D6)>!82_ zl#mROme|P9wE-Kyyq#QPw!lj(pZ5n>pB%nc`L}JC?OD_26QR>ee)6$9DURPXc8q!pe_=*1mlYs=qvW@59PT?s_Wf zMUmb2cG*wn??OY*oG9ZSO z)HZ62Le?Z+km)k{Vb3OwS$f&N&@Ay3&R!H&GbVjhTMHk+-Yc2@-`c`Z_i)f`xJnN_fV-ucfw4`g2GB6x#n={C! z9v(*f+(dD|nUg(iX(G5#tp$r>6C>PgW6Mejn{{$$dLQu^aRXpfBVQBr?qCg3$&*|O z6?P{}ZOt4d=HsGcX(Ztz{L04W-Fv~_!tVBN+u)e}TInXIW{qx9u$siatK;ZXEp`MS zziGBG%@og#4iXh*uk-I!Dh-xg=e%j^8ag zAg0UlAa1Rd+9H>90hiZ1Uu4kC z9hbGWcr3!9a(Y|GQ%G7%OFcGMkTTw`x}lXAn@zLUws}s_nsZO55dHxgRaWlXZufAB zWj!zM!Irl@U(@`-yllwu8t>;C2Z$F5TYm5Rjc=wFq4Gm0CqWb=u{@7!)smuj`>xy% z(-1qxHJV(Xu<^9H_%w>fxh@<3(J^V)FRQxDhtDX&p7A8Ga)iebYZ?uLTfCCn7UnB# zq33F2T6j0!wDyr*u08@RG8;TsV!jJB2-b%A$7a18>3O@z%PZKVL_G}6NPR!Yz7XA$ zlCA}d{rKI>K%xrwWp`bUDWAus7rPwG{ev93Nv}6p^qIq5q4jzM5>YXng2`x8H=IV6WtiZi~NRt^gL-f%_e%-w4E5cz4paV4r1DoW_Wy`=kt&t23} zDsHgytLBTz@xl3g&LxH(!`8sZvi_gBN0`H3j=m`xsM^1IsQV)#h`H`Mf2y^_lxLVH z&CRjY2i{InU8~Yn<}_i93(>gO5f;u35Px!i;6CG9&GaJx#d+(}BwyMlXNv{B;Iq$X zsy$^sy(Onse9Rths^rU-I?pEa7IUaQn6Cc4)O~3^+JBy9=^bvR#bwGr+9-9k``~=E zx9IqgSrBAAzG=Ex-(qNh#-rDJI1E|OY1`uM%^9s|TIna|5tP1@{@Gcaj%7oyc*ye^ zDft4fgKfx8X18W92Uk)m5?$WvL){BYJhi+;q?$OW|^5sF{3=-d<*Q%AN2Bc@hiEP8E-y7?A5f5`^gT3#;d8)vo~C%0~0D^ZGkF4wCRu-ZyxGKf=ZS}=2R zh-px9)baQc4fQX0NDKK@wZZjC=NSHyL#-96jSxWe-M` z4&q2TZxn@Ag*8@w-R#+zo~z4zZDBMZtV&0h6F?txUGeHG-vGJ!h^rYNbhG7@X!?jt z3BOObn_O;}aB(flt$=5xf#>V`(yBDCdWaQU1+3uD4hHle&ueUxx23|)g( z$L|{1toCQS8`Ic^GL{LMDPWE6nXx}*S=?Niy|?$moZgU{n5pba8Q~YDXA{1mch;Q| z#p&BwDSnXyy<&;w# zS-EuP%SFYDB%Q01WandZ65Gv;kbLm+v))kq6U19;>{8o3wMN z-K3bjceU_)>J`=7un4-zTva~_xfC;h)oz+8l?|i#==zN(^@z0qPhmP0TK~rrE2l1U zXmo$25KL_G{a65g*|X-9zI{K6BBp}jhw)zT4Bf}eGBvL@)`og-J*rB4>=%DQT0|t} zbfuL!ZDNhn&E9go=f>-1w8>#gw&&r(=I9H_NsHPKrrf@c)o4rS4CqQmm`lN{Pf#+N zT^4;92kfiP1JnUDCy^3v-aQ{~@g!+PrV}Q|Er8!nIP{_(WZlL^ukkMG+4xVn&O^+d z>0pF2YrN;xt&~^Z02f}cFiUlOwD#P&lhob~j}wE5hxy;<(^`FBuS5-5xs)>WPTy)Y z(`GJ&4r%vvz1xyw{(dDk;(n+c_dB1Jj4!e7>Xqe* z^#g^DuJ4{BdCjJD#(NzxYM%^rJ6aqkR(96}M8*oO91_#tbSIs2B)t@;Cd=bQGY8wW z#-8x59~mF?gnAK1e)O&AOA?po_^~vdQ8-NLO!}@uI>{26ue;i zFs`+#HIn3csr8uNGe`#}t|}OW+HxEo0kNokjG=^l`>WDHXaY4hTtyJK;Im>??~uyE zAR!sk_AYLzUY1>~ZR_!j_*rHL`1i4JVncYVHphNG%V2XuVNk2>LR^sd=l9bYGIh3R|dBZy^N&KE!w(Z&MXvhhV%UiH10XRCY7|w+t(HAu~)}8d(^)+-lJr> z19|^6S2wq}6jQ*T={uis<5QBB^1?k+KXxVl;_%nIeN;6szw9!)snC5YwRE01v5@dg zJVZ%}d?; zcJ8h~$7=>6~D4owtNLBs?3XXqa)z zlv{VZUYyOF7{WM6!Kz2`*{BJlQcC?XkapRy$YUDWuPWU^^O*%;VS}{o9Rd7HCGR^^ z=26u@IBry02Ngipo`3v?dhI~;;aTT!&`nCn_f`gkeP1{?Yy(5e!^k^m?QUl!S&%xW z-8qVC*qM*pSx|=VT~&!FXQ?QVdNKy_in!^-Et@FEDZfSR4@+zu>-sdN^t}IxW$nm9 z^TmU-@Am4_1=OsceOO53SlHf6ejmoEcemZGb!&3V(@3aVdQYMhvbrPjgX*bPM(%pqOY53=fzBY<#o&u&h^E~wes{+8mUd4U3(g6&+J zT3`9V@iwlU6Sa2(>X11?+&oui0kz1t-5-?%L%2LfBlZm$rk~Gnh)rjV1bG&G^)tCV zs&-Juw#YiVt?y1CkK?M)e@LPbCLg;dr(zAXxGQDnT`p^K#7UH8xguC`rwmrhB_&{k&*+Q)Uu zUNJzevashwq}CaoWO3OuU&D)7r&<$}M-C@8X_{HqJ=5shS$R^n$%Msv_gvV1uz8Oa zeMNG)W_uTUbmYLRCsB-ZB1k)tFtX^Fv*5+(Vz@ znymx1X;qtls%VBu?)KDtc3PvxCQ@bEdGgzCfppEFjZIh5zL|lg zU?F6IHZyZF=)D}{{N=-!{QL{S4{5{Yl&cTJE>yz>tg;=;# zSgeUP-A-#KS;ks;iZJc?w(+dh^Q#Krzm>4V7;=MOOch*;>Z}t>3iJV^hwUd9a*W@Vl-sTdfOtbIWukF-sewXp}+(uuY<@uiX^%tz`?> zW@VH7=lJ`cHt?Osc)!mirjhM&YnxquSi7}ghGbHGC0)ACdb4B9viK9_7a-8dbRW_= zS(g`vT$s3Qt4nc0o#CRdVxUir|v0T5}p4Dki__`UA_DwU}M|p-UXX` z6s|s0{Y9^OI1gW5Ir+iim9J=5)}8rkW(VDe{?dylU8<^IER&=zLL}pVL%;ZlM@qs zk~6M_#9WCJm~`3*mNBP0)B4DT{HZhRI`jnbyFj*xc0%fFkkRV_)A7pWA8(k%c`7qI zXIv&IOxHc>Z8sAq9R=~1BZ-u-=+2z=4b#lpI3ItDgMJU(VG%BDgua-yBRRh@_`_qD z<8;M<;9U8ELAJ*S#r?x_dwJ)GhrHTdbUp2C7Q@7Y(qo4&j)1H7g@1HMkl&Qm2FB^GMo|oEgdU^O)ps>FnQh5A~MabW&%A_ZC8Q+N?cN z_uH1f6xa;i;7pP8Y?2OH=IM1e#*sMYQ+D&%ukjP4>?Q4iDp_1iOjt4Y~ zgPqNn*eow+4mqCHQV5lV*iq1l8Y%{w-5kES^LYgf%BM<-blG4csC>}V2cW2@qn#Jw9!%8JqOm*$h@{=L?FX+J1Z z3di>1ciE&14ps}on2>(bA9L0$XeUn#zGe@uOmL{`AbDz}{y4tvwR#TBEwX~a9?amZV3sGzS%5sj(~>2$s}Wgo=?Xf9c?>PxYCy5MUzWI z587F_ERt*oA~g&~NLcdOY$@DF-5L8&ddSkpkLrFHhsZFZx7$3x2hMQHcPQF0JM|>w zoN?B<`|T!cJA$)AGRU~XpOk=oc4!-fqc040h9%qu%vUemM74YlgX{GeNj4fIQXr6>6SK(;n-&w z7(Y+1PTd7`9x9gBahn#&~+f-bsy)Bw#Qg)8UEF7ziLbv-5pPs>OllCQV*_2L(gswE!$10Ci5I< zxd5`M2mtGok)FVwymEP>ErEqKzmlJK zX*BPZN6x_VFnV^bt}hDQT{H@eyK%H`Vfk`+`hWH6)t`dqoVL*+2eN0M^X*FDau!-c zdBY=a>|vRJ1CiAFbM&ng@Z!%dBa^H1h3ng=wRJFjG`RECqgMHHhvnoBLFwzrttdPP zcGl%~vveR8+l=QQ{-fNE=9e%Wm7a?&w5uAVirc!TFsFs*raSlTis?Kjs9fk%t1R~x zpSuwmBY!A3BxOlcf--aSV+B_jV4!rmXTv6G7k2h+w(O?urBIW=9YUP-KcTM2!9NFM z)h{8M>NyfU<`SYccSk4XDgoR^OB1zv207zpOKQM#S~jDp8^7&Ux`tk!bkmi%Q~U~F ze(#v6PzxO7D8_5M`)$^xrCZwEHlCMD<~exTe6=8CgMh%Qk1>cw0hF^3_IewShLc+9 zP~WY_v24+*3n*6At|y;q&+d{kz_E6Sw;5Ilr2N1iU;Uo{0A%}%xUVk!O&+14+4-i| zPqK~P3#lPSarTK!at23y{Gd6HX=CK_uUO7e-o-}k*&ZpQ_>%tsTGtxPSy{(-BrbtmtY$nneNL{b6*y9bwi zpDcyQjpB#QY-A5DR1x(L!(Z8Bz&;%CsM2NA?sbS{4dz^GHz)0KI}9|wd?GnSmsnd^z|ZuG;FJ5%52T z^gUM218W8RcIIo?A}1`XaVdFMc3W``mfYDVb7r(9ZPmJuT=SZq8GD0`r>GU8xh?cP z*GTxgcdk5db#rlW$^xrPA#$V8WaJ9$v`-d!!FM+pub%unsJv4z%v)*B4SH9D^_C?U zE7zetD?(3W()CX___+F3)wR!B>h#CmwEjGr(GtS;t+>p&he6n|a5^GFf0 zQPhk?4XH8}NWb@&5qAlKeH- zuPs{M;@K=!78Rai;#3Yu%V(7;N6xs<&CsqZ&!yFFmtB)nxQ$ls;Q7)P2)j28+lFz$ zUO~?z+P>@f#o&7nh?-Q^vTR$+HWJ$0dKGQnTQCIqOhyhlU%WCwR{0;sJ_b4tuXSf} zV;7f}jN}1^8R{3DVYdUGcp$b-XR|AF4^U|QLmcUM8raDUd0>pZvm!E?2nTWKagn!n z2;I9JCb{s@Zef@oAj+?hfp(HP!6P_Kj&a+jPgCDoXicX#lJReafnbVM;FE*wJx+PY zZi2dvGvSrfS?_CgUoaLMu*^6({DwFlxgG0B?Vy?Stey*)FYckYnQ}M$zdB%JXl^mj zJY>~Np9M{IaU;BI=1NPk2nIdz(C5%_Yw0~h;Z>|wlgu|Xi~xoIu};htGjAkaEw5VmmKaz1X7h{7iKjU^*&c$hKplo z3mwvVs7Q}!!kpj}gZT4->}m^tg*U`pd2u$@11O*$mm~w8!=TPFlA|DzUr5}pl?IKz z);|o!>_pOC>$gP>q5-@}kgn#Fb~b^x2Xhgft+x+F{i1vyqv)#)+FW;6(Z~Y9Zxewb z0EdyCw#d94sGt=A*kw*lRBARpTOMoSZ-Tm=#e{aS>8*8rBxTU!cgP3c``dPm=W}D8 zdJa0z4t~n|mc8N&T?XDw3LRSF5MJK%T(!26QGmg~l0_gM2O#7P%eZnqQ^Fs$cY)2c zk;h}JXgY<**^pkq?QF2J5t)^v++$_O?!u|q-!Nt1C!%;?_L%r>ZL&>vZxY<;b4%yk zwVtDF*AhD`Jdfl@!zq(vGs)$aW+W+P5hi}pU528MIlun^f^7J%!^5_kwxOabX;Y%h z9n#9OiDNkcE*S3vZUaW8Nh6RC6HWUL-*|q`85++vvo=xG23+z&lDU3BY<#1qI6POi z>YwmXKZpJvw6T*%v(@h+N4SZj{qjo!s8otl7;RwK%BX$HtA=>-&J-W;QJp0$ByBUo zH(I2wB}px%y$GurVgfS=r3~jPesF+*_yc}Y9Ohps(_(O{{RqL3GfDGy$+WCVn(@eg-+f;`MMS0=cjywo-5uoZ`I-;sGFx|P?f)j=X zLCo_`+K-hPegx4QMZLJZM2Ov9uxVvfP*nM7geeLeY1j@4BRM!3IrJ?V?j>b!-ZnTS zde=$ePlB=}gUo{rG4h7pkH-T(nZU0>@c#h8w_C9>K+zoj_Z9A}{V){@(o z7$3@{OFhFF$I^f>Zglefx{Aisbar&@nm{HLE@)YtI(zwppIo;%_n49P08 zjoZ6@YtH;#@c#1U8~0?demd9G6Zko0`^lersXxQ0rs|9Mik?yP9|U;A_I0?PX(q9W zp^=VEY^P!W01Em40QQ0Xoplc zZ9m?Hes5K(8o$6zCtI?T-tNZU=H6Z&DPj*B{Z#Q&D0d$ST7p zs8zuG%t#3Czp1|){{Y~Q{{Rm@EyZ(lrT+kFcrRR@Hi&8(b0yL^UaQIhGmsFZLm2d8~$D8eL$W$qZWJ-Ys(x8{z;D#~+6i~VS zr**G~c2|)P6KOhg>b9>G(8XySvpP2AD#16F+z#MWAYH(#Dafz1d^7(52FzXPvF($@ zUK^25U8^$g$Eh)&*1bc)AMkFUhApla`)gab@I zUaP82;*DZksUf($Vi_Uxi6SsT0kQqz)N_D9KS3-EMl=J-lYYGudoH4DjzZ}*G-Zz2G<{jZfN$?As;M_ z+t$BtrvCtfDSR-K{ZyK)ZP1tUoS*!3R;Gpj00e0GTc^J3i1jP1cFIOSMWRt*?GMJE z2}aMi%uYJ=JlC>(HPbAekj}1i-vr{nSG2GAB}c+NH}2@SE=P8`KjB?=gZl$~Fwi#n zeiycVz2%YE{{T9lX2H})=O=?aH)Avl<-@yfI;JbpJQLxYR@Ea}H0ztB&H_Om;y(;m z=r)u4Ie0fiMO$AC=+W)}0HGe(PL5fSLi9rqoDpA{uKS17dt#MFM#l% zU)|#a(x+d6kgxo*{W%rwqfCq+yGiy0Bks}2h0f0!UxCo8f#ok9IOe90;GMWn^(*cO z>0Z9p2tR~#%_rFumi8TK@*(rH$y4CPvbQmQm2TJJ4ZH{aT6OfVPHRQZJ9|@7)k6K= z^m3tdvxl|#eGR|ekAYg2-WW(g!tKX;=?%L2ezeknj`;1~j!@<@qS971AIhamMNlg1 zEO0l6 zw>am}XBqxXTr>;)*B%M0h-d`BAlfHqk`@JZ%HNG%Bci zD4+#6RsJunDPVf`In5MM0s+M_kNZCPqKW`R*S#Tc{Q@ HV1NJFdVH0S literal 0 HcmV?d00001 diff --git a/data/mllib/images/partitioned/cls=kittens/date=2018-01/not-image.txt b/data/mllib/images/partitioned/cls=kittens/date=2018-01/not-image.txt new file mode 100644 index 0000000000000..283e5e936f231 --- /dev/null +++ b/data/mllib/images/partitioned/cls=kittens/date=2018-01/not-image.txt @@ -0,0 +1 @@ +not an image diff --git a/data/mllib/images/partitioned/cls=kittens/date=2018-02/54893.jpg b/data/mllib/images/partitioned/cls=kittens/date=2018-02/54893.jpg new file mode 100644 index 0000000000000000000000000000000000000000..825630cc40288dd7a64b1dd76681be305c959c5a GIT binary patch literal 35914 zcmeFY2|SeF_c;EH!Pp{25g`=HHe;KyjV1dUDoc?VjBSRoXHAGKS=yvfmQ+ag2u0RJ zh_dgJ>`Qk4XQuVNewOd|_j`T+-{0pkbKQIHx#ymHpL5SW_ntF5?{@})-I}TzssIE6 z0j`68z|IiBt>T5X0RSBx;3xn9RAAj601DB>{r(Bo)Ph(pNMq_pU9%UoE*74Fs=kF z&RLEdE-cE;g>kmRS!11T<+x88sR|*v(egC1hmPwh8=2^^BcZEs5REgO`e8Y)&b)!50dys4DX6{ zb`#dOw8g06T%9bztd2U)73+g@cC&Oezz}3bi1=W$Z_WR_y&Klb0pliY<%+R%!?}KM zLoG{i1IgR|zB>@)Wa$ic31WW3prQU-!dN$qlcKAurT0(Lmc$~+7})STYkxuG9}kX5 z1Z3S&UJ4;ALdp>7{w+sVj#!Md+Yfd?@8@2b7xs=$s$=<=+#TQ{X;`1;yrn;g0nuL;E2*|Jkqp8}_U36a2gN`M+h)`itTH zkfHymAL3MeH*fxF^+Rw!Z}ipb#an__xCPeP250frlKJO0uHTrYiq`h-1UHN|5omz1 z`e|7Hkix$=CVppl|BKe>-&tjUG`*>1zZ-?biTiFU67&D?Xz`cBB)0yADp(?IiK4=w zY5H?>{0CquG4VgcmXMPC4fv-`0`~gnOnIM#Tb}>S(En+n{?5|>MectQ{9mB|$}0s|!T%GY`}=YJh3pE!k-rhUZzBB{U@56T zmtE9v!T*t6(t-8sarcML{+DDIe0>_At_CkU^ofdF*V5_dN&BZZ`%}UHt``1DZ7Ym- z64u8N+zhQ;v3O9W|Gew?+ZFs@*SLQ-M545QW5n4~$PwcPUOfGVwZB(df4KMfLGAZ< zN`F20zwhDi)S{S>sFaW>+z2iuEsl~Fll<>ei+?z{|5rupq_gG!ogMvufQ;J<9`1gbT3km*+hPndEsqr1UmM<>2xE~UBxeK{kZY?yFUIn z%gyayFry==nr@#7ji5_E; zU-;X9k{|d8sg9nMk`i>_H*fGRQjzE<<{)9R{{X)vKgZ}J4eV`w|aYRMJBlZCLw#gK}e!tamfKUJ+d9_L0 z9v~K>XfXNO6R`;F#{0!!c88Cbc&!4ipKBhWGLh;z0D%cMseW8?N~S;-CO z^tJqz4Fg-dZ}p_XP!b%mRv5x5M{NSw9P*ofC`ff6cn|o6kqR2ef1wStb;Y^ke*>Zf zec`rPXN-8Epl+n2 zeIAE%Bod_h4*m@?wH?mYN6`^$`+cSuz78G?CHM*kTf(d{HkR&=ZeW($19aE?2|Dqr zngsqWI_-H|Wt=0<^_R(I{7PR*?MDUJ1klAf6X%-R4TslrcOzha6&xBzP>}urr8|#v z0|n|2a9Z#c+x;G$w3hgZ9RYOU7lNK6<-f*@Fp=^^5hkr;@a zg#Uc$eWyfP;;Mw-tTB)u*KbSu8x?}8>vw*L%kz7W+RDzq@J*Zp0~|@QeAR_uW752& zlBlpd*|=iC98u}~G7n!@#m>+!>u+&2&IDmgYux#-VKrWscoAac4FHDr;4O(UEnMd@ z;98gCn=Py4 zD@|INhTtq>vN3HVnWaY14$DN>3Q6cZB{10i4vLd200qNEf-N+n3CB-mC=3akZF5FboY z#1si)f+?{tF(il&ro>*vkRUZM1>1rtsVz!UP#i7}))G^YTQL0!ks{@h#1sYc38o|* zlmwWULJ;9n2oy0D138DIK$^sq*cOE(wnZU{I4C3$2ZbWxpx~rDks69b4TU21h$7KM z5jjDK!X*Vo1wjl!2|-CgDM6$lN)VhHI9w2p5QK{f!o>yQ5`rKSTuKm*6hwfy2tkCH zAVORaAt8v66hue~g3O7D%BUg~;o>SPqAC)S$5jD1@@2tjHhV z7~w&a5;dh;U`=}Ss4uEv%xxo6>4Bi8<(1;lanQ=0Fb|8se>_oI9G3w zcaeX3Ufi%Yf3a4HHQyK{W~Fs7)>up8_A1A%i?JoGBW^BI&G!ZOwbIDd(wTs_bOqzP zzSV?rL(oAtM#`!WiUR`5bJ|cm*bWa6PA*c6caMeqQi~VH55u<8p z1U@#g#gtdLQ)FUE#OR($b1cMl;g%afeuy?BItVsIPXLrLFy7n;&MQU{$LIa z%XGyMz6E!F!&h{+bi~s!h)E=UelHV|oWAJWvAldI_4Pw`)vDQT3 zSFK0lj0D!kINK88Uo{>HP6|^6w+vUTl?E85=L)j&HKvXvJ|q-FOD8ZMl%S7sC8^b) z6~@2L8j&XWh`@>IH{MMC5yH>qhMGYuECm-N_K+@-^d#m4_2u zt>9V%vtKi$O45W&|E?*C+;4O?LIPAF5}*!|0F{db3e+f|dXWN^3#elepnjA9wVi|* z9L$4oP-`MUttBQ3>RC`3iNR5V5};0$5Cb)?C{Y!G+DQt;l_HiwoeOF%qAo|sAS6{# zQewwNMODSblqDrpQBnxy<6_E4goL7!BwP}V0R1Cp|Gv5aEs8%^6yHq^aAUQ^ODlhM zDH(z0fE+hyR-MOikpzGgr>tx!#-(Fv1x6Wyam&BpV~D$=9QV(SkVFbpF<2Y#uT{TB zjEYMmB*Bjm=ualf0kKydP^3VQv*mYWI9x~+B_t|g1ecITfFCi_|B9vrd(hGO`@uUo z5e*_Dp0tKDc*X5(h0!=J$L;QpwU$0Ep)4kOTtZP&QB)ZYKYknqR~18viK?Pd#}Ud1 zRn)(r2%XM@+T*9aM-10=^~8B#oJq0NPDJH*9_##7AkWN{6K{h^it% zjU^=oUhS%g3WVO2wm;_k!8r?HsSzdDn~7H2e0m zupT&kgpXg~sJMir6cQz^d|X9UOE-Pcd?_R} z>~eTSTztZ{>o;!RO1yVJEj=UiLDs{k`2~eV#U-WBURKw9}#i~t(gqHt-u=c@Yd%PaNMF|XTsQTlSvTFKSOHw?p|lgf+jIWEg7ZZmd3 z4Oc7XMc+a9fC47WkIj@ZCean-{rc%E5vjO!?L4n>K!0<#?&kg{|q#T(e3*?OC z2w>DZ)#MYCF3ynK)s`F?U<&xmiN)7X8VI{|Y`e3_4yVP~nOCL*mZA-H{&%oZiLtc^ znMAAN{*{&y1;)Lr<5$nIPj60+o?FoE#s$0#+@lMd)NX+-WHMK+Gc;ah=QJD;2Owp= z*8XnuEq$Al>7vrvsM9==8}kY@Y|9f&bA*^TiD8LeNX~XSI{-#@H}k<%^fk`89l*lW zV|pN$2^e#{o8&&8l$#uICgPma#9X>z@LDT$A}YNoPQshp#vrwwExy-!xB3@ypNq)M zye3GZ!w0s|rj$!IRjm)Qpv5d3@GQSzv0nWlh`zxz;MOv29*nRXN(dZU&+sJLBTO;cR zc`bTFDdfUwWRhdGc20~1^UQQ5bBxE+1`0m z-q1i`B$}W=#u1>Kuuvj9pKO2J;>+o_r1Gr`t>jm}RGPYO4WuLF)hA{mPWo>7-?1ap zcU7yve!N|`1DM`7r(P$_8l>G;+X2i2;Z*gXz0)j1b^w}y-umTP9ocL^)oP?t&pe@P zsG?9*dl`N8!KlIhWMu`$;N~N(0b6c!2u9UXbQ;#)CCNS0RcLwmPps4;^m_SnN)kfT$I6jXjr>%ld@o@Jb;Q zdt!9?_3125Wu;tA{Q7Vp!fXe4r|qw1E^QQDIyI$HJhmyLPcGENxFwLXXu!;}GO6<- zh4NZ&!BG2MG!3lR`EG#~9+WC1?!+g7Mwd0)Y_ZDRDTAPR1qLXee6PVn=gqC7QqQy; zut#FcONI?*YPMP_b-XNU(A~9DFAwyZ>l6fF;;*Se1ycf5{al)X3;AkMm6~?sj8n#F zx^4bR3+7Gv;%>*PYcHD_H|9A5aLQ+?&JJ7yxSVc9Y7R8Gngaw+>&cL7iefUjthWiY@(Q!FTv*Qt&PXvOJ}vqfhur$k~v_@=%Hk8TOfPo8QVhi<`Pp&S&R}YiytZ!?Q_7&uch-^ zeoS?t=zJF_qN}ZBrPV&$OfLtdlmckXx04$aUj=R10cd#3P_Qaovzlkkz^R4P%O;SF z&|<#Kdj|4Kr6b_^wBd1T^r$|JVdJ=0lp8~6zFJA<%B2R!8oFySHwa5F2a{KxPYtz% zb27W_0QsCheN%4az}$Vndi&>Mo$kkU?J z5EBVZqstn}pn|3NH~SQRsb-B^?_4)2x=7PI&nAL0fQ-7L_)x%n+w&&vt@J{&9bj^w zkHdT!b*b7!$qv9{E;2lDg#+5|>s(2{0~FjZoKNeNkYmuc5b%>r-ENhZTRhsYCU->&Qa5XbHjuYmjmR1v z=)R!X`gnA^WsiL^4gI?lW2$xSgn$s=RPP}jGgHNtX=P&_(+h>6d>e!wQ&XigdGUz!qGAV|Ih>)A;Q0;h|HpA!BK6rCifUyTBLgMQECJ6AS#NbmCh43J{XR*O76p zTf@n1wMwv>kk*#hTb+ltWsjC#cUgKPInth$0`m@*U`1ux9^~&U?hmBDQ)!BI%8Um^ zGaVMwb~}<@rj;zHcp<&cUWF}eN>>-Tod68p_TIw2oy&Z$A5}M{$oZuU+>!v@fR~XI zNYE%!?MAg^%sP$-a&`nPKd4pT_D@I|-K@Pvi1sf4Lec#q4`50D0gV^aFHiSm%kFAk zO6M6%p}bU=JE3bdv7G5R5!-zy2L23e3s_Vi^f9_)FF}EBF~vKhYxK#lEVPI56son( z>l*CunNOkXLM=@!Nzg_;+PctDGa;W)dnOUM9N1i3p~L1qxz=&JIN(Biy0|c!gC97m zxxqm}AxA-OZS7w`FZU6Wy|4pF&j<7NW3ru%_4k8h4dr zP@R%~*98+6*4!3!4`*H5ltlq7eCEz6>-HH=o#=z0$@7@=nIQ6$CWo~De2lprv_-hw zLEL9)&%oLR-p0qyw+clPgZl`YA+Zy|l5^l4|euOD3-mGz%IIB6(fee9GCiCeraORr%@+fKvV{)@(=XSe8DvGt zr$^`HAYU>SN09EjPJQa^$MW9NILwt`{Qyd5h2czI?TXwqnVZRTtDKAVm0wuCBXBL_ zovg?q%E`-bte@CawGe1I)Yn|T=pCG!U5R?TEMzdH*C9cB6DPh_(d^9_`N};_8!tuG z3p=wL3Q2IZwmx^2;mIcF@Pk}8ufdq63KLH2D|`*SYQ?*fbypwu3%2o?$hWgQ6spd> z9rBb??n+wRoJt1mt*HB5Zj(ZoTjE>x9b`8cA3i9J3Ct-S3ga0=v9G+bzf)T%l>Yeo z7ySY$O}hN6nGL0^RhVAge(9jq3Dy*<&2uCAc|E+xc-h`rA6DoIr7+8&XIt|?<`~S@ z?9cL+d)(y0b%R@g)_G%M?3~94l*@g3fc?Wf^U|}-W;)0Dp1R6_*N?+kAvd^OmKx*` z4o55^!-oz^59Lg6wg@c*wboX;-PJo03hS_@0oJXysq2Pbq}un(Uh31!(y^316=&Yd zI9^|2Jo!wM%^8s2CK`lq!ts!dK>NbCSFM3yc*Y6(irypsE_;@yax$&8ij zZQmv@{o{9#TpAOBUWXOxSzh@GN`odaP-nQ-&O+D$I-510%ZwBW#F9*~4CgsIYo}=Ys8oDtO%*QyOEhKJ;CdLc^CgK8DQC{Cwg;v(6@vgPyR~ zM$w~Q9`i*9DXcf5KoiWjoOzT?#$Q~ZD)Zsz%?H4TTJ-Z~el~vg;oVoS(_!L31Is<%2W{*@Ayr7eTwZ$V9U#0aD{K=s&(v#K zdBDRCDx{g)kJ=I_7?-&n)6yB3yuKDh*69^9b*6kI}U94Gm;8b8SUtLR%(v z3}Y4(0Z$o+MZT3HkGZgE_{eSQG`oJM^JXE+$^B}0l^Uz~=sgOAWYaC=>H1B$1r<1* z7XB=H_q=XY%S-2}EV!$isZ|i#P{Ry@@AfunA2c0^?$)BvIV1!7I}c0c)p zk(nJJ|Bdd2t_-h7b*3UU8~fgg6o}CG+7BaBuAC|jGF{uxdE_i!y(fesysExVUtV_y zxHPVn5WuUH?ZB~iYy2c4XVe^mvWk;G*-)!1nih5E*zokL=l5%^Uuso5y)0_;W~aZu zAP!k}PM#;|z2zz>El-7DqH4lDpY$tMpS7O!rgS#Ljs9H5=s#k#jt4+DNBgu35 zRnO3&1()EHiMx_K_I+TnHPV+VegtoV&%D<2`Cty7QB3oFS%;_4mpS&dyEmh^WtXp5 zmxU1GSzO}6w$gdCmX10Yh16*R+EKxL(E$gjU$DESOw-;?&K5WmyVMMbwkv_hspYJJ zwmU_EEPI_~oJU$TPn{WoMkfeV`^!}%0C&w9`nD71!ivS;u?4(LGwKWe62Lj7KyP=8 zy00gAS}EhR1WeFVN^apotC~14Sjo@#T$V0cPT=O+=}l}ka)Q%B6FA8_t0KVW=br#A zV9?V9m6tH z){m^d6TLch_6=SPG?wM~p>xpMkK$+jH*`Ao4Zq8__R0vKAMYvQl}3(+n(U6MNLfDs zZnZ2^d+pcnyz0%+;ETPLs32Kzscavz7KW1bnXf%Io4YC`@_DeIe07>L(yd3XJ0SYd zeouSqy6&N)@q=#Up`iu=;hFwkxj~z%5nrOlYTq^>!sMzBt3fBL*2e=;)T1wFdV7rL z{8Y}i=U$4e#=Xgu^O~Gonmyh*wU%#f0YW6Qe;{wMBH|I$;@crR6H{K5m^l?2sQv|81Loq>$=QgU3W!b~a7Z{QO*{D+0 zlPi=j>c^1vAD@@Bw3{=u-`09s7}52nl=hfl!yaabLZLz{1AH|rYb{=$7l3ItI zKCNRiLcq351JudqiOm=6934f+nWOtPrwLNGDok=^!d+1%U>VSLD;I=qp#Af%~ zN{3u#U&$_aey7iAVW0D-lT>)GHdFBeW)$TwF2WkqhFqoM^ZWvazHBn(>bJXFv^&;ldbXb@a4^VzZUYPBhSv>mRnFrp)7vQe=K<}3(1{$de^FLIv_6Gdcmi?D|T`= zLZ_ne(3HY1C~gOMcz6G38VuI;QQ9yp8<4V<^)DFHxt9>>Y0-Fj+v*iSe(u7TbiU-w zNu7i6t%1%POhC$_R&W7*z%@}0Kkwmd-Uj@Tv0L?1jmzr+$*t`{wXV`0OBtcPVw#j{ z_Q2WRn8vhGBd@+~pY|73Qwppf8aAcdEygwjOkEtoV_@K5l0K9bAF%$8m30|5W$*-= z*`5v$Mdv;TFLY?sDDrwbtg`vpbfCdgldk~4mKD@WhS&0DC!5&_Z+zx8VZ_@QdCMK( z;xk{-fXkeD%l+v%>crWXiIE2ad1S7$QWwPogs8TfiL3B^dVRTdLZ0fo&*m&M){W<@ zu~Awripr?5ta;Lp$gJvu z7nhO;ymttSTRW)^6QpW7H^Swb5~y_D@N_{o{fKTXFBvPNv0v|#8+2s}A-!iPk7m-A z2~+Xz*3se9-@jtXm899mc%Yi+V5ZP^E);q|wJ z?OVU<))?wf#m?ma~LYmf=N;xO@b`k zl8;zb%Xt5>sI^bEwf?D?Sv7xnwV)b8^lZv*MLMm0G)i$|uPj#?hTNm$b9$!Zrx~j* zYIzCe-jqFjbx&m(P5Q~AIOdB(ifbW0!m;U&okHCwH`J1bPS^x0<$T`F92*cdXKa-t zZ{YO`@S(qEFY@_X(VD}a!92^-AR9V-+Gb5`t!{e~J%#z)N_jtL1uc7<8sHh#5tk;f z!54oK!WaEH;6UoV0?|(nbo`xpAsBq4hrpZG3wK@dYK{wV9h=*yTs*7`qecVZhy6+% zBXe|n2QQhcaarM6js(I%m)@eiWFqdeqiO%EQSLW8Knnw}y}6XYIM+EUp^Xed1pemW zOin;`;-woasuua?N>O|28yslSAC8PBP^kK!4dc3NB*6aYz$bYj`!0@yb63TTd$onr zPh%K(U36$_-;BK%wWgxRCCpNUvm6ca%HD5=IhyOVm=z&Y@-nmpiCnyb7)#R@XlS6T zPYpk(0CO0`MI5NRW<=ekzZ8fFvGp*MX@-tJHZ4GUWR>a>0=)b z?n-iB%4JwRwxMBC5ro@HcUzcrM0@(Ej@ZHPJhh6uA9R%-+0sbO(MQ)npd+YpYp_S~ zs;Tzoq*?D)Bi~JkeFdP9xY{(@*ln+jhd$sb*7eTIKynrfQ`e-W&3qkBkA%(sR4{Ku0zyy^<+v{(mdDH+S37kw%bqAkHyqQVDbC|LEybY3d}Zi2S8>D z8*j;7Xs71BwqG0=lPR6HxVi((okEAi`wDGk7fv;K_0!@Qndg~r4?e7Kf3!Bbn~+>k z$Nz$*wqJBVXRao|IeNM_V9r!8(9mmkF!aVqwtQb&`J8q$eM)_)izNzBzYxpRgYMf~ zmUJ{Pg`r_B*93FXik|05YtJF5 zEi_jpQtkEY4)!)VHwW;W*)EFl##9nm(q8C^5%!&3>;E95XgA6(R~!DoT<}8Co2CQx zsb3DphU_`Eo3QZkp5Is%1-(9n%ajG}472!UwGY!gaZ`TxHH9N@oVI^X0FST}@0kKA z$IsBE@UCRN^PLh0i`iq!dSeSUVSEepZL@FmjvivGwN;zh0?SLx#Rg zf6Mtuct(-J?U(9psxH%EAv=O(Q36gKj`p`&bSYGd`#6!Q8 zlU3#E4a@SU;ZnzPXnBqn?U!-oAt$2*B;}8InJ?D*)0CGNMry5~?v+Kq8j2lHr0Uj# zvzk4Eduu*uymtJk=F`~4vpYZyTeFN|%=A^ZV1}8Bw_@T~@Oj791ZgRSLfE2>DmwTe z3H#QQU#=H4$T)^1N)GHkepP<&o7(2rH20l&E}ES)fgNVO#DQ968>$Rni8B{BxUx~D zEyR>c(`M*i*#-~8J_sIe$8i;hcD0-!ca~s{E)!%N-qrj#rf^mE(t6lH<~fHgX1L3!9LfUfuD z+YeLF=})H^IdAYrw5fgqMzvVAg1gh z98|~4DGWK;7_BSDV-YGP(I@HS1SqTx5L7&oRxEY)@#Io{xp({_z9%h0a9dX%XmGFk zt9lkWiunbF_XRcJMA%B8(f6Ta=;O+WUS;;!`ll+~vucK=CVE`Ky^Nqgz|6rA)KWg+ z&F5+}3Ota1;)i@VSMYIrOhS4qK+`pMp&8&B4A|dlkeiO=%#$OI@Ao^ltN}DTq;2=M+D`+~>V3~C#O@${n zb4a%qjcyFwk8EkX`zWSGyb5$$N&7@qEP<-j&)=$!nIkpRY|0e#?%H(W^W~@|gxQ%M z-P^;Z;T`vDKY1N(e4XY=9^u9A)0WEv-tPpWnF0n=(NCG{K3|!?=fu}io?PwI__MLi-`Wpbj9E%^8-f2_hmFDmp>A(isdVpOK0#S7in0&-_GmzDi+%aqMlR5rH5_iP4ozf9=Wu;lP40Mg;Y0VWma~uE zYT1rgh8=hu_lZm(z#}_^g~xeDRDGilcA`DHgn8Mv`VrT@w&3y_`q_z^ zbH1aR%Fu<)(!=UUiiMoJ-)49A1sA;v)lwF8$=8T&rx za7w%+WPgu(?&PKLnB_PWO~`o;bpk6dx3-yw6*NBJa@PqJ9>IN2&xp*96-KC7&?Icg3-@+t%UADRu5> zp#~Z*v05~e>?zga3-OlzZAtvv+uBnWeC|m)&>-K zJ3!vNLOp$B`f}I4CpRZ-OZMuVS3P{*l_EK$9j9#Q;5=<6cZCz1^xPL_oVH&bQ#dZn6c+3G zilu5TkKI1ZvhSH)pqvWeB2T_*IaC?;?Ar6@(8T;t*Imoh+~cZL(;nsNd_rtgJM&zS zksb_WeZIH)z~?Uu4K55fnBG1xc=$&4b#Nsl(2#BK!V^B31iY0*%j!jlTvdS;Rb_Fc zEQ@L!Pslaq*l_(a4=pOZ^6uF~yskyEQi#``_FX5m<7#QR@Q)))&yzs{m5xfPkr#vk zmrQ}Fj82=-!QP|2>N$Z#f%(Ajp_A$g4fHh_`0OjzL)Ng>0&p3)wR8zvMYpzQ`4z=>42M( zR(NW$s(ZMpDh=!VAtT!?zChTKlmp8#D-yM7U7ZdY-8GB7=G4lHmC{4T@#XNp`sw7j2Jx;NU?Sc+jx_WkW6@y>XkmVysi(; zD=zjqaqS>%Lczk3(_VVTqBc;>ntf@iuJgW2fzM)nv-$}4&8FC?_%(fUOj&@b4Qjih zkcSTR{0H3y-Q|F@sC>e3l0Qd+bKS%%)|ZjPg~qiJUtyYw=!-Ux?vus zT&h8ik>KxU66}Ku)A?>Zk*`j|Tg>U76?qAIg)Kf~hF6#F;HsbXjR%7_HPjSA`5n_U zk{XVzX*>>hbxQ9jET04KIX5|?RkFXg;J+!mXmfloT-~t8zZ~9=3Bqth*<5 z!D%Q&e#%K+x1-uwU;jwUZ35LZZ(}a@+2}&*J#C$RHnjLIhS99ZoaOgVDriE_AjGvB z)22PTwYWviX%*(VKXJ&I@sw8T}mS1E`-#JXK$E0Y4uQJGv6E z$Ep*Tx}lm(m)I*sM!qnuwePxsnN^zRd!3%fWC2CGdY32D7H8&`-xWtMt#$V5Sc!LE zO7soCF!so#AWrk~!xypY25(c{0j`12mPQ4s$Fe5kK{nCD3sTX;SnEC<*JVb0P=hbi zF!@E%)NpFJtTSrXZ7REV@iW@eb|Iwi+3_BuE2u%Qr!|(o4OTr@)nB?f>v(4*JFN8i z2q_uo2<=0evlS!6g?8IUU=Hi`b=T_;vaeb10ea2~haQRU){`^wZ`t#Rr|gktP}ekj z!h~l$z6O2Dxl$5<=?yc`nl-1W;H+$=8?0C)T}MTq#Wb9-f0)U7;()^AOMr3l+1I9= zFZRcU`WXr&mv*-S@1( zGp>21{>_w^D}D3WyZpHDFV~*M0izVaUC3>*-XYw(@|YJnv}n z>y$&W8tj1Yt2F?|;o$xP;3%ETH^S`(<9zet)=@3p^3|f(k!Cc*dz2+vl_RrItmJ|4 zJI`@-2TndOzWXdYUe(JsVodg){L0|Q6Mua(*4nov){M!o8JIVZ&kuZ_%&s0f)h?1Y zXA>p<$ib5FI&${hqUp)#)2DfPJA?$gCM z_qI^v!?s9R1?O7{tu0|WbhYG-JtkCxab0F{TG{P+r)nOFHOXAH!W?C6ngvt`%uH7= z>iEb=Mt?$n2yhn_WD(`Hg>Jn6s0XkQl>sgR44YK_bb0N9)0PnfdTmbk^TKaHayt19 z4)M67T+>U}N+)1JOA{|~G5e)Gk{zrbzUr{<^brgyBjtv^GK3{d_}UJ2XXFt z_rlXh35^Ulj`rJlVl=NZv^5_)S=i-hFy3Ze610xs*m|)fUz>@k-}0KaxQotf*<^O8 zLLVp(*t&wtvF~T@aE5HKn?B}dnLXq8K1kK;R{cj#_9SQ-`Bw2<^2n~_nkj55hmlXn z6~?2g0iK8nT0y^+Cr9z#tY93(;k5wsr7_3EDBU4l74Yz-Po9_LV=xzxfxYY!lR2iR zTxtgu&p|Cquq|z8B=ZnjhIhg9z~HB;?uXC z{@^StufdCRFxQ?Vy6>w#!jo&_L6`C2`Ke(uhXKxV^e}5-$!E8scc#2guXRhnn{S2& zTuM%gOd3pDKmBkM*vfL(3-agNFZQVaO|A_!V}`7?e!z6y zSRg|;n#@=6?Ka=)sV>vjS<8y{nH#Rs97bPG9XuH{7+lYtFOOU`pQufS`D$bs-`DbU zIX~;ff)vwX+*8L9(CzguU{JM^8*MAnU0G{ETR15q{`|a?PzGLdkLj_CGE68~G-~tM z=avxk+=&>@)sHXp?EKl{kjCC!%}hZs*}!{xHzVEH=?F8B37@V|w-(M-!>vO*fbXMH zc9cQbbv9!T6zjpt1B(`no0KO*RJnHJf`C3D#|h-PkaeBQ3H{L0>jg9%H$&mrv5N&2wYwksOt&pR%~ zq+`b!OVXrbN@)il;uPqI;^T`3KS{TmgB_NI#2r@mirH=HdbJ@&lvwJ)F_}(YfHEY<- zb9`M-7toB`U@$`6CD{`f?uS2KChy>pH#l>&|7o>E!6=zkf_lZdtImro%=9n!)Zkw} zr`$MO{LYOp|Jh!XdY&l%;rcE{IGNCf$ytxG%Qtfn57`7*ZE-1u;Nj`Rr&yu%<;WW@ zk1@(jR$~IYs?f=RRea*&bHf#<=npA3l#{1Bhs;?W56oTOEJET0lm#uD$R3~BU0}}m*7?`UY#(;rkn4+_$A0EFKXB6{Y zaE_wvMbQ>=EhwHpEokzT3gWN$b~)?ugr2ZYX6$E07WB>fO+SZDd8XA`nUd==S&|K8 z$@!F5b8466G(H1Jc0#K@tSf#~`)!Un>!He1`AIS_TgeQ#Qb!gVv>D3tWtQuLrgE z1T<-v0y(NY4p>(fr)^QS|Z)Qhv>Q%IF zzm-h6veILsbmqf@MINstCa)7E6s}gM4n|B=s91#c!=62ff?l~iQmhhn?H+vk(9;n8 zbwahgyyXi)(=xN{V!AdJUMF|LMxpu_C-2*=2Hq$^MOlmnANk&^baG-Z+wCQvtUYyF z5Nbo5gv8iir_tJDmL*F^*+9;jAau~II`LZ4d~HOOH@o&+<9=HF2b2b+ig)JJ%V|Db z|7e$I`lqVCd^Quw#}N^cRTt^iH#?j;0)wu)XgxMRQ_${ri>t)z{J@IGA;m+{ZD!;n zp`QsC{XRlp%PJ2(ka#q3#Nkv&-0|1`2CAZZeQ!oSJ*wv#ak-`wrCd-%dNh;CI-VnQ} zac}@ zK77Ica&;qq^y{4tG7oSCQ`et9Z=cP^6g_LpVK($NsOiymW)7z8RTgITq}w}ES#{#>i%YxGi`$3G%C(LmZQQWV(i+DEy=5-uz<4;}OSv5)A8@st za}mDA?ZeSFC%irKK3fJh9AdmG>O(Kez;2B{HX1;-s;;Rl z+vc5DV;F#yO;&(>Bu8D)|K?U=0e1h$>Cox%2vuzE{X-g@hYZK}*5R2q7WTIV%CBok zCf6o0tjT>s_VBEC1Qq)mLAZ0%2*WuSWj6K5rj7NB&NgENPve*qpv&ToIN}{p1 z(SnXj!CR1}T*}oO^VcV}yyZ1jr^`dOd3xm5FA%R6)|Cf~tFt-l*6%wc7B~=WZRLCH zm>AjH&&Aw^38q|wsL>IU9vmG|wM!h`H|fN#pZ)N$W)n@$I|Y}8K|HeVAo}R)>|k`+UHIZGcqK?iRmCd%s#ZFYV0q5e&8P;?0^U<-QiHk}IIm1}mcd~Q3MFZ^+?4_MQc zpH+-=0Ye=j)6c@9(w^=BDxpZ<5|0Y&HvsQ$S8tZ7DKmej)ud-3iS@;&x+1FaiAM(5 zQzADrY(I}+CEw7ML_69Ja_iB*)0{{Rbg+(7{xs0Y;6mp1WSn}wB1$jlgm8Qctl03Aqd z0yxHdQ*L#2gyE%F<80xEM{Jyv--Fh?bHQI1?e88tdxu%BqarsEOS&}$dmaGEBdl3G~7gm z1X&$M)g&B!eSJBg4qL~557Yd8r!}qRw&>*DBCuelJg(42dSn(HWDtGp=1FfE|w|?B;u`2uny_3YA6^40y7c8PF;Asrcp~!duvH6X=G1{!zWNa%E zN|WFo9k>zeb}Z2O`)*?RXmWEHR-XLe`d>uP&_&lP$*B?aKK@3YH_L z(s6^h5H{c*bJwOni?!K&I+s&j3tuKkeB3&nvB?+?y@@04=NV#l;|CRm;Z1!t7@=)J zAp1;a8Rkr|X!k3(n9auJE3|?@1Lx#!004F3OIf_TdsVZzX;*F+i0%|d9N_Xcuwq6@ z<8kMV^7_w*qSkcd3|6pQZ*ErMK}nFGyTRa--;s>+d+0u)8cBC0y3HKOk+~j047|Gr zEO_LVBdJo@?UPk+e2qb`BUv8KIkM~{C7c`*;YK%g&Ozs%PZR<1kBa{Q;GJG1)xIf7 zCxvg~(DVz(Yk@WVau2hN=gUx_tc|$f45;anxdOJE{t1iYtycAR{6f?24XOrTJ4cyY z8+jlJSP{kv1pE3|QQ^;zUu5{5sCahc#NIgZH;B^ecx~-}vSN}r!k!Aj1YUH3fCqm3 z5KaID1MHf2#a%>civ{kRCY2j5{{SNJ+q}&A%AmL2CInz~jkh*1#AHwh$l90eZ}CgR z*No7=hjkL8XlO2Gj!6{rkgpCf2t0=DdFxs_N5DTGT%>QPUqhzvZdTMZuGtuDOX-Z{ zrZa*&^YtFD@!!LK9k*F-?xK=lV|f-gifo)JJGR66umwrS%%ZU$_JZ(j{p9m5t!Xry z+!80$vj#eHY{ z9(*;MRrsxU;z@*UYoXe;v~Q2Ub1$c3wjy$Rk1=cK{{RVVvEMC}I&5jGYg(aLwCyz{ z4J4B}$X41&P(O&0cm$K)zWDel@YsAN_-CbPZE!B`?c?)rZa~1gx^uo+2m2+-3OguK z+JHK+Hk`L#y`& z%-uGuE!k~SIHS0Nb#vGs3qC7Yk~?4 z=jTVYX4WQKT?%{nySaUmrJIJu9B!bM? zOrA~y5_8mWYtyCido*#!Bug#yAh5a-{^eJU?N%FeoM&q}KZzB{{A1Ff_<4UF{{Z&4 zk*Bn&1+kuIl1oMjxnf`#=R267+AyOT;1fpE!~*OzcN0RzS(YY_`M2BfN@1EIy!8yB z+#qGhIiL?$(R@>Je*(s$;bcNfN9O+Re(+XNxd|8;EUHF*F@xB`(rUV8ypJSz3krGX z$xY%H83e4IN#JvVoDgwdLEy&!0E8}o^sn?8u8O-&a!D6`H|&%EB?BC9VB2wzO7tx! z#~N>lw&}N1T3XEU22(ZA+jO|VDw~+P`ygy}_aGfzrm2LO%-(gz{{SgrAsfq#j;hMc z#O~mLy@@4w4Ey_^?b~&aY?kt2RE8m=Z2jTDLO|!1KD>3TYqYh~;__Q^Llh%)O>$Be z000O=oPWK5a92DKDAK$|7oNUIX;phIyF(i{sAd6Ja7F<^p*4;uJIV2K)Ev&{|k zEf|F+Nr_H{T{3aU8%{Da!5FWcPy7>m;pUO?a?ivG9<6I{8>P%zR5GM`Xt)HBO0qMm zXKQ(C;Q;_`3zV;?E%mtW?$Y8lMLt<|opX}S^75O*ag5=($4>P7pA5s|%`!!nIT%g4 zA-XNYp5z{*Ipd${Kp5W+{68L(B26FeKEbAxM3_;}tM(LU_f@B@ZY+xSv&No@4o&6D5o73co|6ZHwL zd?j$I(R`(htv&-trx?!OGmJ6$Ks$(0@{vFpe;O@U#d=hbrli)gLlUSyU_!b1#_n)) z#sC;2qh}(#x5FO}E&L^H+LLNc4yhEPTMOH!j%-K(7-)ln#~B1=K?JUA9zWVd%Xrp5 zW|5&sXuQqO84NMNRXnx<8^Lz#bsFw%0H1Y-YQ>Nr^)` zA}{ucF5t`+zU#b!BqZ%p_k*EapPICfiFcn9z9HM(Gqa93=W?*amI21 z81sz-;zXhan(iBixWh{u#6fGxRz>6G8&u^&m1YEdqY9XF-}o=$TsP49a>zGGqFKJo zRijKgtB`TZru77nIKu)8kN`YyI&kU1%)-XWge-I4B6Mfr(en;e`DgFh(a2iJ^&Kp!6bKmP!NcbB?r zje7cP#J4zxIh$-)$zV@VK~aoj8Ej`Y>XQEe!9V;1aLe{xK20A`+FssUc^Ia5{GrC} zthg)aPdMveNl4Z@#hfi0F58y#H$3ER1a1TtB!je)G2MG$Voh&EST5Q)e|c{X(6N^J zRP99=9Wl2Dlhk127yur#{{RH;{gFH;@K3-NULduZC>tVrK7 za6X(6aLdT|=zgy7 zkBhB*WuqnKt2Wcbk+H+X-`!#TdJ#Yy%H28V-kyi${{TEyXwFZ6Q<|Ah^`Hrq?LBc) ztL^^)^;8Ji=O0Ru4p4ObO#o%e*az1&U9x%U#aNk7ztmNEKJ5TkcPf6h8#d8XOP{Z` zK_?UeT5w43ROLg@U(T<);QpUVf>1x&{HOyi;c|QS$u&-Bom=j%+zi&F%Yj_)#BUX+ zg1j?#s0CT3jdt6{#{yYMJ6G2n5^>lX0P-K(16r1U3onOlR3{?} z?yZtY;w*9r2jpKHd?2^@mRC+TDuX3B)dr@KXTgYWsN~3QPD>~(0P08p0hY-pCqJl1z>Rwp(n8)`s#pLE z1#`|vs66$-;MNC>z6a^PJkun!@dk&a_=43VG3~yEXK3?<4vcc@cH;zcFfqsmfIeXO zZ}GWqboQFwPqH@9C7Rwr5PWBDTp(~mW9Cv66MzUOJc7! z0X@$JqI_?YRyOHzD;Z(Q%dyK}Bw#LUr zUX0nv$W`cgVB1Ej_%q{2#EB!iu<)*(q{g7-Zyi?QmmkI?Dp=zf%7KA_z}sGf`#ksq zL-2os?Nd+t6^^ZUYJ&P}csAUl$P}2PVg9XoIor+$Pze?9eh<61iqUQsJ+1Q}oz7%@ zFeh;tIOmRo^&lP@@NeM<#BU3(SWi320)G zWwn`{u6*Tw`$!#B0uLh@CxTB(-fxKdZk7y_NoyL0#HFNc=YTLa<%#qJ55|*J@pO8U zZCMsI8EgaPZt6i854m|hi^hFu0}sSLBE8V`(w7X<1q|he6z2mRa(K_ac+Gs*`$71I zX|Gx)+jAbFBFE(UrFbE*Mghi2L6O*;b`|cPFY&#ty_L+W(KAL4#tM!j=YszL!baig z&PQ%f%l`lz{xIqP01SLLCb@RWcdcDYVz9R2G`XDr010g1B8{x=kX2bh{t!Ss!{S$r zG;f7AvKgcjU2SOMjUx-Z<+4>t1h_8*7tlTAjq#SJx~Il0WFGIt3j790B;&mWSb4u4eMvY-;Kc z&G%c3iU9BY9VF6NDTXlncD{b@F`dP+w<;D%a7%Jnf=J4*wCLXwE)1|tk}aO6BNT5g zP(0n`TMjnoX~MDRBX;o{GZvG^d=H~)S_4K|rDX>&Ixlhas)oSE;=F^@9D7%xcyn8i zTRu};+S|@uRaXTF##C(J?Z}rIE5L3~0)Rb&&&76rBP}xBN9H7P2gb}=$GN<=RmaL1 z7`E`cUpP5tU^-6&>sH#1v2Pp8aU_iG4a6wRF@i@9?UqxS12zF8a8FF54bcIfS%|I3M;z`ojd9w)HU>PPZ31?A| z0c?{PJCu{3e$w>_^xZn)Z!GQPjs{d^n&~{oFOq+D!5MviV5A;#fw+Kqu94zPdreLC z3(qZV;fs8WDK3fvsb@I(1_m*=s0@1a=)5W7Pxx6LBz3pk*06luY^YQ}IW5yWk4`WS z4muj?Cip|(-G1&Shf1}M2F%CExhZbs{H2&FF%|%}2^n1RU5AJK7o_OcUuUtkYgp~~ zNb4Rpv1TN&Ip^jqNj&F0IiLzQ7LDacGOJCA4)$U51IJvCZoRQxx5BR&w}bpc3#^~G zlbh2s!>NR&IkNlIXTgApPUp##{s6Xw7Ly&JWA=pb1oQ z_2!f&N8bF+2)G0MY5Pd(^Z}M*@9R*Q@_!C7R^2mDu)@#S zlc)o9GX*Epp68(>rF&e_6@lV@+TkAi7dX3G3s~a1LE^d$IQ%|@RWAW}R>#5@GYdO*^Dd6&oFU54SuApZa_P6w@QXj-ASSsrFn<<3gB;#H934nDk}e$)Zf zT;DV?x0WMkn+S*-k@Iu|JoG&EsjjWO`&)$l-MXTAY-Er3Z^+3V_~7s|PZg7->mGDU ze6&}IRUiGVH#`&2jQ&;FYe^of2J2?uhZ~~CP`(KnJ^FX!r?miZ_Ti&~HS=L5)tQt6 zp!yI_oS(DRSOwRzu<0$NOsBh=K^Pt=2`K^ZYgCaV<;08N!&~+Vv0Ll2} z;J7?n@Xt`UxP@d{m4xx|W6s=lY-0c~Uw(R5uo;#1L zyaDl^ZBteQ?TFFt^5kYuF>RX}3P)972pu}?BPXZjpM`uyb>SUiJKJ3EQHETQ9A`M= z>5zL6E9tL-UmWJveQ#?d z!3xW8*hU}%0cFQ>IOjgxR}0{;+Rsi+Gs=?s*T_}<&r-8so@94XtkE_${pBMW`LY1cF~^{a6>Rt;XzEuzt$lun)&gd&D+Yy0)C&4bW|*5=(h` zacw+`yMun_DCA}@*9Cx3PEPD?pbwrrJFIB-R-Rk#n^d0A%9;lh|bAJmb0P zZvHOl?Hp-sVQXh3WUl!nGDdU!MMnhZBcK3hjzC}2&j+p|9U-W;5eg#s%4b9=}MsLldCV)R2XZYo!*hA#`Rw5V_EIXw| zbl~l72+0JT_QxI0V7C3Bbej-~Z+UMTFYgE!D$MPWa$Bx=9l8P6x9*pRzhxf*=#hbG z;NJ||?d=AYZx??>L0vAJ@H@knC48R`=`oc#R!dmagX%b-55e6_;)jSnBj`V7(R??m z>z5K0B3msoNg;_q+PEsX_3hYWy>Yi63%(}VZr|ZA6h|DLRdm@%KAi`!ujsiv1*gUT z04^|aJE*Gno)EQ-2#-&;ZKJvh0Q{N#hG4#j;aIcPEcHze{uj88=Fa-ec{7%6(NE>c z3%17d0yLj8hmY>BF27UgI^$f&V-z}!H(1*hs@v)^O=}W1La1M}7G1%P(4kNq<7gti z2ll)0&Z)2XA?@vaGSW|C-gTwTqnK_j-yGnDA(6P<8v?8VI1H*+;yp|D#D9h}{{UwA z`W<6NlkQ$#-s$sPeX+qOEg=phB<;X(6CrY>jL-+z5%`u%T`$Rw-erar$|P3bXO03} zBYK&VHN!C48$1!mP&aKghSXw2`#VE$$+}47S8O&iyMh4i*P%p^;jl2JtK_XW_U!m~ z3$4Y!jC?Z*wL4OG)NP|-xQ5^>1W~xIKp5SU*&{X6cz^cg@LsGVve5iN4~Qgo{{W8- z6LzO@F8)YWw^wtIPnm~q!hk-Mxwu*5UHn_45|QK>-M&%RC-Vij2H=gu5%@I8;@=uu zXli_uc_W-{%SRHh^B>T0(0lFr`Ok!OkBk2R8(c@>?}xg^kEYyUe{6WT9X2S3aF)k) z(~+F*XyV*+^N#-jQ~u8%vse5oI~KUqwGZu`a^XVHb|6?+9d<^eB~W{k#{(pq0QEZn zvD>Pug;fAAJOC=>Uo&z(ywf9uD5+I*%>X?(V1EjgcLDrgg+@Mi`qP#T!>IhI0-G)e z-TZ|@>w(wt#aLy?sAON8y#QB`4aYTJRN47{6=m0@MA3hrS^%bF=b-1(n$hFw=~jyU z?tb?jDmfFcufKW#r*RJ9T#v+G46lfE3yWD4Ngc#&j=h*3wc8cQs`nOaDs%j(1M+kB zvhb$4`z&~pOXQ8$L6TBd>6O|meDqm1gRywV>>jz!csyp2`*7%Xx@;%yk;@@<9!eEF zlgUyu?l4H}+lu{T_=)g;OY!EHHMRAuQ^PWqWOdxE4_w#6pR~995mUkbEtcZ{0K?Yp zpj^nsv1!C4ob^w-KL98K=Yf7IS?ao0mvFZc+(UnE3~Z@~#^KIzc?1wh#z%g&>Ru1l z-saljt>$};#IogvULlT0C%*?Ij!EN+_=imWojxUg$uPc?BG0XOax0HI3FBPNWs*0p z(~x6+qzn+pBimy_{j*!)^}@$^jVmiRnH#s4Bd!l0Fa=IK`udsx`pe=!?EB!~jWeyr zt!pCRhz+mWt>rMt5a%F{RQ3KPU@~!DU*n&EdS}D9w%F^M_0q#8_OzNv>?-fNm zi5s#>$<9H|dFjcN4%JCguI5 zeKF4m4}c|;7zB*%3IOVBdrAKQf;ISsG?@O?aQ7GNxt9@qr1jwaee;U_a_~>WZ41Df zBpQB~r`p+B+NeZ|7}XU(_5-i2be8%RpdaejGy(XF;(z!g7me(4_K>WKr#o0S8y&d8 z=i0c5{{Y~fpA$R_erC1O@7hSm-?t5%ejdK{`{r~DQa{y4kN0|-2Zro{{{W9@Pq|S* zAAJun1vPY&%He&!0k`(UAJqKm?71Mk&KOHqs5Wya$@VDY6 z#f9YXiD$R)-TcA;2-?IXOiHQj6#oEI`;iulZx7sp6+X#|u(s;B`A1%WPzTC?v7hYI z;!Q6`)U}ToS=?#YdZ~sxeLhuGXkv&z?_B};Ws#+9spKL?4jUnUhSUBYTH7j(szCmA z)d(@)sp(D;=cwXEnKB!yc)rE4sF`ufv?k5B7B9G$npF9+M)Eb;h9!tu_$iAIxdAM3a# ze;n6IV`7?huaj>YTf#a=8iY~!0gAJYqtczT6abOOvw(Oaq>eqLf1J`PkbT8Hct6g7 z6&C~Ao(322>+4EJ54|WK=h}cApEn#*&d1@qwcS6 z)iieM2MRuv0Y=wOXrmjqayy!^uwZmNRiK$4j%liZp5ICUhs(j~x}6$ip4Bej%cm5G zm>$1{00DNNS_E1CT+@*-Pu|a=siKX3Y~gt0j=t0Z7nh21r+(dO*yrx?oO4P>2T#g? z8RhLi;ZIoq0G|}#nf2>V3A9iIeqIJK_*1rXO3~-?r3=sb;(!rUZ9e^}N{Gp)WBc59 zqhl{a^q>S#2_E$IO!m(JPzTrRP9NvmfB-r4^rr!|Vuj>)qz4^Gxd8U=!nkb+I#|!lQsWbj@MHB!L_~Y=Q2Q*PY z4prQKl&#a|{x0-UKo3>VG^2*c`4mw=13usL)|F01-lshK(M13Se)fCPqG##rMHB$x z9uLx=w~l)redwZq0zPlAwE;(C*QFFt1DBDXY}1o~Pqh?K08UhYoEl#&W54G`6ad_X zIM3%wC01+$-yJBTfE*b~?fFr?N$W)v0U(nev!q``vZFwZ329y=P|DnziQ4th49L+56ege&%HSWEODYv5vkD zfQ5wxkjC5qCv$*nkAmHt0RRI7fD`}#IK#Zg3t(lQu`sv)*-j<_4*;iFSpM7pS7SZR z`d?#bV`DuHWCsHOJIo|Egp^Zy?5-|he1#oRg1 z0MGo_7Vu(DhOoZ?|&pI667IxVpLf`3D3B z1&4&j{`2---1`sl$tj;w)6z3OfB82z?`uA`ps=X2s=B7OuD+r1duP`V9KO4!cW8Lz z*XY>z?}@oTe^AAG7|qW&a<$c$jvbV$J~@@ISj) zP6aU!Rvxy~;&SY~+RuUZ(2LjPqt9HrpOjP4aTcWTVuSCs?@x})V8yxXoBx^iKP~(J zXIRYtm1Y0iu>aSt2>>T63v=>Vc>tP#;}mD|L2Cn5ze-{geOyA_#(3hd`-}LgEe`NbcBIFEc5&FA7 z0IUnz2b>pQyw_I`+>Cmpbd9H;Sif^gzSG$D6>&0#eWs;5h?Mw4Dt$M(~j zYEgbMluK#9MsEXuH` zazU^g7>J|>;BI*dlJ5J_3$FjY|3LnDfl2^A_LE|Wab+b~UcJ7k@y3cUv(=3m<`ep2 zGe$ooAssGqn#$dXUo=Wzz#p6dFx%5#94GO{)GuGG1d@XEz2y;6l@~PCOkNWnd4$OO z&+Ho+SteJ5US~(l;MG-z^zUc}d1);)tS`|{0O*_PSEK=g6OcOp+a19oyTUc|_mULT zwspmf>wF=0Tr6TqGi;Hs3(UZN4Ux;}(8Cdr-T@Uk`4;v9{L^cA2j$P|7 zR&E5--glEr!|jiPU>*0HJ&zQ_Lo>g(idJu|>`CWa%z)Da5nG_gVA5Hx+#)y>-x87C z7dswQ*4K6*{k0@rvo_56kvTrjf02`2uOq{6O^@E1TBGXcZUdPbABtmetZ%BU$#=I3 zBOV#IaG_+OXp=Q^_Usy7ZHHwL)?w2s`7~DcZ&Q%Nsg8R}fwKuV#un0mL+gzjzbSc@Uob_?^uy7Al@ih`dA!_1jz+>A#(m z(!bSn#=_hs=v7N5L>P!2{TKPtZD#R$?Pi(;0(LlMwD5do_)fCK_!Zj?Gw+Xb?ftiR zY^M?kRIFt;uShT1`m!8QJn?^z;*`-Y*YE_M7X^zK#*<-^k_ zoIl$Kgc5q^|KSW}C9$qzHuS$JZ>$O>o=X#WpHh?f%?DT~u)m9QD{>zzZ+Jab5}|Mc z;PyTEb$5uSG8*68nK;3hucJcdiSRM1z`JL_IxHD?AGkYu+CDJDajjnMo;?A~N?S9{r#a%cd}JK@xSq8TXDhWLRMFbErhz#D z#5IJw$MtXDExojAy*)?Oj!ilNJUsyr+^8pj)1C(-y(fT8jh{ijxhkhIeT*^Edh8hd zBF;y*e_u|Qul2GL>4FRET1wK`CjJCq%V$V_?xGsAV7t&E@r_#(DRjgUUR>US3cFK( z_}KeEnurJJ*}xurKLK3W*iG8qg>McXZW`JiNW<{@?{+sYe;A1Z#(B_04KStQc;DCN z^O{Yn84lpfued$t1CBLnxN=nQ^V6Go*(0&68r8no>|>N!tu1u-4qe@wjM{|$Nw{la zG=SuCJ5~+&*mmGgFVh!!q4B}AzGt3}R-y>de`-=X%w9O%gjfbWxV>)J;< zxt(L;P8qtZ4PL9$JyvH zWm0fn65FIHK3|Tm7I*?s|2tf$_VHi!cd!FVBrYZj_PuNK)Y`mf06#xLB`jz_mbsSh zw<@ID%5 zRYTztE%JIWUXAkv;M%RaJzYF-6cFy%kk$6*Z}7UeXJdb_7Uif!lj+*F0@4BG=tLES zYt$O8BDO9^7)EfjjvNS7*FOPx+QKjDG_~Fgzdr#~9=kzn44_rvBSY4SisLX3%tr%W z%z>rt1%aqpr3wWHjocv3ve0ei8*Jo*7z&b6Bdatwe2yUbGbY|jBqiwR72TxXLw^&5 zD^VD9_X5^Em>hx0uwW!s6*f7HQ^**;@ZW?0{Rr6fw;L9O6TqBZJDg`7HR4OK_QWx; z5~Zy3y7th`b=>~5?&RRUABIhVi2-iyAmrX*SeU8c%!gOk+CG2r410;`Ulw8L5MOm2 z@@ihCT+4iOJcXA4M6V)#f%I|{{px_i>f98@rAw&^)tg>9^7M&!IX5=nvCk1;o9lk2 z-gk?8KoMgvlO=mZmpGa08?sA3C^^XAvheBJC6A^P9U5}`&a}@g0}Zo`o)1L@-L|mp z5gvN%oT*xFHQ;1_0yqT(Z5NGQsS2@DqkfS6-Kp2*wckUB9h>++nykMtbqHq}UNTU& z(-oCGf8cilSkC`^Oa9ALuc_wH?Hmrdt=_&tI|u&x$kfA-69AiU)0LS@bWY2(#g(ox z?fF0w{#ay~RXtcJQYiH00{yS`DZbPeqe8wzPiALV@}BufD)X5j;f~2+@3rh_nwviv zA^(mdZ<6uHUqZt za_mTS)s1VN<~i4QSy_mCniP%f-ld>h;+E#T(--Lx(%T@$Ir(>YCgPr?xJgUA^YrZ! zovUx;Kh}Xw^D)B?-xI%uzV|vgoF5zdbY6OyCv!hVD72pdcFn-Gz9gyd)6ih+RnIKr zgC>NbkuCK{tbg4VgP7jX{2^KWZt8FjUAvgqK<2`ow8@vfBz^uBBLQ`8sM!iS&a}EA z;ofI&scx_tM?Dh#nVd8}%XO^0_wSwCHy_jM0Dh=Dk? zSkeFNu%W^n)Nk@$z97@vH*%ZR5iV+!vp4YOQfa}PLA3C*uJ#a^G*2gI42-Dw%2sbn zroW`pDvK0Id>R%&r;$s%60R2&hwjlHA8s#Ad~K?TmmN>_Jpm|b6AMoORKHP#iBNwe z=FQ&Q>3;CmufR3HJiUI8u6h)dOc-;NSIm4~$0=Q_A|e|RC$zn4;mO0XrMpnk&K2Yv zEL=e(FIkFuq8IteZ){4w31md45R8D?Vr2Dbs_`vcS~~CA!-Ox7yfFf2_=Q-g(MdZJ z`G#_%Uw3-bw!9oQe7ZSgfpZ6u*p2aD{(t;+K7aZ?==(yhyi4c7O<^B-{cgy5_kjJ3 zRpd%V9LuxIU{Pzwn8;XoWcZ+D@8900Cae1^{ekm#j^SO$n)wDdjp`Ip-@p01jQM`e z7l?%rjD%dUkG;o7DbFUte{3b%r@mZIR!RBW)EF8StmdjHVy#{K&J^F%cbOHiyE_|- z8r{^FMnU4sGLtS@c1@IaI`~7y#i~x@*bVjDDT@->9fzX}sn}Ez2(9~irIG!3N;D4E z&1afz<|JQ|{IEl3RyU8xy` z@Dinkra=dOansse8JMMC{-_5#FcvX&0wAjS?dHf^3Z1zqE%ypQoe@?q*m~&J7qqOr zlhY~sl>E&vWwPd60<|vCmjwvf$yU~BlERgy-9iAZcLr1xNqoNeF^xRf?&gnkBA4jn zOHo$KAz0md#n^=8Ieu^kUUkM?n_ztb(g|RF*YCf8#x_@k&n|JO}89pLb?kq zr!WvsUv70-$#Xo_D_w!!+6>rLd$XHSlbOQ7ccmYtzR*42zN1XMG|WElY?}^(sJUz& zh93TeZY79Gf<>lLlqCi%OLg(i76PlT`Z_~rPk7{;Tw{Jsr5o71n`V_v;yTUJEAFh8 z+zC>Q+pIv3f)(1vQNMZvWgh;@vVKkGlz>d;-uIN$M)$GMo;5QQ)K&uyUH^-u64n@ZF?d zaT;H%?RGq!>0@H11_E4;9tV4LiX}L1G%BnaBW`1D+9~|Pj0~c(D0Xc!Zb~%WlHUJc zSMyTJD&K5YuNr=%apkpc_vzJfuute3b)1!I`Jztg>#!jOSIZEuM-%r4@%J(^V{f|l zkQeZr%EAO?F+1mP_~IXhg*4}yEqsQhec~kqJG-Hdt_RB)pP@AWcy-`vo0lqQA9DR7 zaX1_ARTFaSvN=4``lC3TXojrAh}?Q~52a$YqqGHI(L|9X>DBSl5*1`5Po>}LRMIEI zVTUctA)us(nQ%e#g-tJD>#N+DReHu1g| znA^h-?oxmq7nt!rRI8|jm|1W5avO5{Kp-o@-NpDT&uJb4sfG){BBl#xG^2hF+^|V{ zq90@Rq|p9;ukvYgmMzSERWb&D>0l+d2BA0o(8&YNntW-VhY=?F_41)dB1!K@61^W# zFBYY~iL_7>MV6@&IFlv%Ri>Z!AKSiu0(k3eCJ5IvUpwl4c^#HcCbsxl27FtW z%Y=PSe0%1!i3=;GFp^U>Pn3dL>)BG3$vJ;jeg>Qa7cS&ES&W?TO6)8=+*zF& zs!bW3#jcOk%Y3}P+gS4Et&Z>gtenamqx=w$NR3^rpeHROVhF2t0(e~JU{{{&iA>X! zV}O!K49%}9mrZC0U9sb}ET->yypV>iFjOo0k(&B-3ZV>eq%czfQ+57aS;hsR%o*Us zzjYtt@12Gs;;{S_IwL^-otMhinY|y{IiT+8&$A+aPZLH8ooz}x*f3=+*V4?};tIhc zQ<}s}fAxd6Yq~O+cm+d=)#5lC@crAU1WWDi`isk3+o^%%cjxfN48=Z!(!R^b;{+%? zfJ%#ODeA)Zb6S|THMK7tyNQAHdz0vD`$Cb5;lBk0pFm?&4q--wJn~ePmzf5C>ND2;ijdnOt8m0; z0_@I^ca`I1wyJwtVS!~t&4XL#KN*MH9Kgh-Th*^F%7@mTCukpvX80Ws#}A3}nCxp@ zuSl>@R(iVT1Xw_X`5KM!>FiOjK+lC_jD}nxCT{5U7L#v{`^{`D4dvFC_b18sh=IL01X5qBN7NIIsbekXRn5uM zZX9oeeG;w~r>Z%kP!4GTDC*DK+Z`{yyfJ=oByn^@nX^LR^4*!=RZ%KV??q53mqo^& zZrZZhR%9PXr$}JoRG@fc!~xqVs?JgbRz2^XQL?R9SgS`*YM68-4CxB1w>?ezp$!eF3K|d% z6qVEX2d44sMpa?w-vy)&bJkw0en;ZjqZ_x~N*=2Udrx>=fX&s$V97!~QV=8RaKL{x}AlO|p zb{1@%udIfkr;?z{Tz(@XwVOB@zBD75ZhSUl8+9>uodnyhec}*PTemCI0g|)%dAvZ$ zem>6A(++}PB*A0honY9+6%7*usnq}%kF4UD(%3a0c5aBXy5xBBxE$g&ILYS>WU_wN zj&nM1sl~_6h0plgtaWl)<7SyI+;v&8~2j_8WK$o*wTfeG5@uLRzY&2nNl1 zscX~S&BuK6e-Zl-02+c!oY%!Z)^|EyTzN;Xr_?xxf(Ny?U_pS$461=`vf+{3xWf`d zGj8tbbUFPurwX8N{b_w0{Bq~%XXh^H&H9nJ@fyF-b$9On>Da1_PH0E8z16xQ;{5ZS zXN$@Og&s}H_L4>V^lcmJd9>rC*vr@0g&vBqA0k98!maUK&zno~1~SQt@z1kxL*{ip z5W=M(CJDBkn>XUsuU>84@Vp_P_R*uDU&Q)?O)#*HwOpWm8491DPo6VV^8SQA(=AGV z)H2;gMgKklC>L`+pVxfqfGN{xq27FqTjCp35`fMu>@%0OrfsQ`{LfXvK$ljH7NZtR zj_l(h{vq(=vr`T4d{sHgOU9Ae56jne1e*|gRH~s=cw19Kv*UK`1orkN(dH{DG0fag zGeKTJ(0MexP&M<`7D9EX-W)L(h_=et6@A|GC(D~v&iQ+Cu2@J!$owwx1n>=;4Dj6V z_HCFD+M$OOd~nUDAr{)f?+CzMf@ItMcQjQcPbd2k$5=#YsOdk_3uS4)SMEn|1>)z% zrlY+D-%X(b;2nt^;IOEzCH#Dn=pWK=>N;YijwRWPbgHa3k4sAQI{N;=u&czmRWlQu z^&G-S2Wb4nohszf<)aHmWNBKyw0js2G9{QSq?SWKaEF<)r3V{6?X@H_gkO}=yo;rE zU|gTtXRsiR+;;J*osM_CT}BJ+sU;N1Cim1414>|Qz|j+Dlh*^(0L$~b{H?K(u%EH_$KQ8 zx&oSk8E$c3$RXNrAv6X}#a(2y6bXP2fj<+Ql!e8z9zPc&an4Sf%c# zpiv@ngAuB~!ZHiY46PQeS0)JzX7S-UlQTI7Tht_Z$VVfx6i^P>Ijib`WN?+*R%P#y z9%3tU_be{*Rc_gMQ^J)HLcuP@WBDd^5^!vvyNQjAh~WI5I}@*D zgLL}uZzxM(DD8Bq07>M_5eL`bWSf%A$DXr2aED}dE+(-#R&*zFB1SG=`SK&R3+<;R zFos+an`nHg<(C^g8dH^W{3LdHD6!08Ii%#M>7uvS7$*qfI>th8UeJYT#S6i8Wc6iWtWB4r@%9ivO zq9)Spq0gE4v@3Eom(4*iLY_-g1TtG!P(sTK3tJS*SVMVr;L|z`6?R49;8%zJ(mths zf9_2D{Pv5!>3&&X7FM&1L{=7yQk_R_IU|-_5+2vC%L-jgEG3x6kfTu|&_RDT;{a(t zXj8P(^v9o>QE;IwW934}-iP>vh27du=pON-k$?D7KPZp+;PO6_XWYvVKM5(Jog;Tc8 zM1|rf?>jHuy!q^(i*Q6jGWRRx4t_Q?>_gVR$?P|Oj^(7{G@2fc`&G?(+}AgYmAJkp zI~vP*oI3Dsbm2jfw$O|`UMMUYoHdJS^a-Xce&1E@kz#lMGD)&k&>4_^6G z@NTXF$^r!ikX7y`qXHyKpLlxx)L1R98<1zg$xp?#^#W%jmKT}^8cZD~1oUsu+6C-W ziNn;vek}hIj}D(yxWka3m2%$18Mn&LGrsraQ-10RBh-P}kQVGCZ=Wr`55Ii1bUfHQ z*pln<`DZB$g-H7gDE^r!>N+_inqKPnIOn%v{KIAO^L5@c zvp#2x*TRB!3g@d37n`6jre0oeL}4g91#@+zUY$9Dp(c@P#`CH11W<)&#%!rmrN6}J z)YCrR@ev3nFyVA$_2{!fFIro6iO$KJ_}>P{gqtL6(qXOD`1N&boo0m7)V#&=h!t5#_nIfA0L`Q7!m9 zUg($`2yhWzPW`|tP^l?NYrrY|hlQ{{tPBVo#y>Ixgw}L3XC)L^#^X7EN#J>e_R|o)L7>S+F4Hk zWs^gns?Fp}H#E~1k=60VE(;>fXg8a^y1vI)JBHMoHHig*uZ7>a6J`wMFjQN#=7#lY zSGv{4@h{uO;WZOkVXFb*upI*MeuRbx<6XsRSj!9_jTvB1w??&AZLG%Q`sSx{&eP2u zH50q3B&2cH&B-k5MpX>@=&KulL~g$fJI5+sOCJb4>fF>Amp8ig9nc!MnqvavAwZ*m z-b5ip@KjEL_R+k|K&`B{*xY-uAiAdZfy(?VP+2mwggJOs$2y6w8(Y64w7rf1y1RYp zRru=wbJuH1{i2Nva-sRSAi&O zKd=1T9&7)8)rJ?EV8UPPQM>&syWIq*2G(hj_+e!=K z`kNC7CKoiK4q-{LFRQJP)!Ik6Nd-v9j#-wpOG~UmCwX zvRR>idl(g^bxPp-baP^Ntc+5<8A*8uB77*@F&kzUAtP1s+#NdZmHo4*8wc5olWKkA z`Ao$DS*DGFnr{)kDIwDqh`aXvP|rz(Z2fBYKETh+F>^rom=x(uq9Ku}&?w*%JS^0u z@6N;>kn{Ez{fuj>^%MN`mNSGIJRGgUF@*e_;5dpaEuGad|KkOn)*5FOiERx&yz05) zQx=yWBGBsbF5K&E`otOQC?k)+UA|X18usVjJG-9}8Kx-qnKAwjDJY6dj~YMNR1-7y z{jw|W?mMj_@@Z3(Ep|DS50HPJEACV>Qd%m&WcFXDA8*d%(|vkzl4KKeoYHr8fha1m zH#bl|0?=M+^~R@r{hVGAq4U%=u$el;9?pr#@_C~&b~P@ow96xUFLf%V9*%-0>UuY7 zfeJuR9*>Nwu2)rszmx5@U5`y8k5}7^Qy0>>CI=wqp@Q5+G>ngNUaVca1(Of*B5YQr{1|nxJVo2Qfsr3gNMXR&*=fKQC?R-#FjgClZ1@)_#-8c&rfG5C}k>vgWx}jTkz8>aP73st^o^ zF3q%R;_`gH{m7(+*Ar~9Ah+Wg`&&>A-atF{SFG|c) z5u-Lg0%K&2SP||J&$|9rXp)Jn^t;JqpOAS@X4&ZdTI#=V)7Q!1q{Lv+i~$v83n)Kf zs(oH52)0Y38m`J#f|@?$rtG^9RPWdY1rb6Kl&@wNx{2w+Cx9m87;@~;EMes1MHqJ< zic+`02xCIJmSu}{AQ+(lxBJ=uy}OXt^L1DFS7^^-zT5q4a^iugx8U%Zp&WhLQbrF! zT^aw_qElg&qwVHVPK2_=A%nHl^zdS@vh((-Rb(I2r3WuDOAGlu;?AlLwwZ~+0s{N_ zxRp$!wIkH(f9i##b;AbdN&2mN@}idNAybf$s`uyvAKEbDEKbfA2J7IXxT)RSym7t+ z^rS^z`|sH&<))>1g*%82T#l)X@)LPS0#STx#y){-EFFv#q$e-f@jOHp#l7 zF=FI!b%vGEhzXRW>8b6|Z#GS6C}Fp-H<`gqUGs2c^eLvXl~(ezBU^smb{qb(y=}wl zOlx1z#v-|IYHKD}^qE=;Jm8Q#+DZ>D&(hd?6wc&APIspFR2{rp(=fxBJRl*{1~&lA z**8l(_-PWP!FD}$s*WedCv+QWOZ<}3-&eIJy^fC4XgC}((ztcclT>0P0r!Dj3!$z* ztUP|lHzg>~_uAa{J?rUgN&t$b`BsSdT z!!AlCE*n7=9(nctv@VwGRk=ySYsZ}ZjtRHp^p&#+y8MK%B2!d++B{-L68T^Zij#{` zzsJeZ;8IxHzcTH2`=C$c3H$ZUSUk{MnJq$A7Fc)ma55L;D|dc-3l$q1i}GLHrIfxA zaTQ7h8HTCrdF)n4JACv=s9&Fuy@Z%+oZaGLAq3LL?cO|loM^-J*|`&Z8NP+}u>`H7 zJov8Y&1Gqg;fKbd1B-GFH*EcQM2rGKU{tNVp|YA?Wwp`Scl#?VlgLjwvC2-LW__`w zKy{<=O~b&@oT2ol9`z3kp6Q-7rb_w^`eF+{dH|8OFp+}+r{$kGao?XY0}T(~_9<&l zlcQtD;X9NTcNtHg2REJSx6Ds_ymNboY3AZ4T$0=Ks=0J|#bduoHcr6lpSK=i=oUO! zmWueK{q>rZ4L(9=T=$xLN=_@JW&7pW_YEb-fY5#$~m1aJB47d-uya1GgGbX^#y zuq`gEK#bJ!8^ear|NZBQw;fwxQKah`e3qnm><*Cs+%*d3`EdTPjT69!`;2ver!G5r z4^=ZQE@0t=Iy5?V__V=wKlGQyq2s5S+~VGsg_|7S3TR!glbvUY3M5}|5Y|&;iE;b| zLRc+%sG5sz{ zU@+i!kGP(gALb8WlMpr#p?Z7Arn=DTZ65p29D-%X%Z$&9^K`gknZwNd{F(u|OcoqDmwO6tBE8>G+1Xo^5wcQw1mDoC7_h_lxVv zw>7Z;UyfSyP#(O)etgo)TRZ;ZU(bvDJStB{bJS^RK_I*8JwkAn$MBb3{a@p~_5SU9 zvt`YfprAA;yVl7C>_Z_xF0^19{h&N~E*O`ekf`)?;%sdJi znMxr1>)!~w^_tXwWzdCxiQaPpxK^T+Kg#^bO8SJ}Dbtz|5ur#gqrt_lLN6ANxy&p$ zjF~dLXvR#O4)U5P9o+C^s5j(l3E%ylBoDq`j`!j=V$_Mw55wPndD`=Ko=%D#n^bp4 z4IXFJDOs3%;?`d791iEDI@S3$2QIDZ-rD2xn%^LR14JBQH|HZAYOawKQhQhJH0-Nm zAY?JLD(vlZgiW$MVoFvIuX!B$`6t%1f8~q-aFS};`@sO=LRzZ%R4#NB#?-Q%0IV@? zlDS-hK)^+!rH6yRZvVvOdo@I!2xXIDdahF-sg-(~EFh>8YHI#yIQW5l4(T^T#BZ#Q zx^ByS8y!4032j~9axALn66ythbZ8Pcx_9|DIQ2hQs!+t6cFJ1YW=kvz{`Q4V&(PPzW^{GGgDASL) z@KW>x7qTld;%;TbLj0C*@VXAx)$7);q)d;7SAZ2D2p)R4#ZoBNbme)~qnb||!_oWt z`@4rfg%a*o7RqGb)}|}iTye|#qUhL)Dg@buP;cv#7W|UC$O~c=2p-@dymfQRZFJk9 zCe3skNf3y17(R#2iYv1xK4^j7_kDv94Y=%8{_nPnh zWU1V+<0Q{_uEMnAy2A;;q&Z?x`*fRzZp6Uzi1<0fKiOR?dI`}^@1@XwD0CDc&xOGb zXj9yOJ`u4u^~++!OJLq7bTfKbqV#%xzC)#RtHmdZkGEvP5~5zPo&fD8jEZp`^L#1^ z*@AnS)r>frAatg@a_CE@nZcmKB?76Czhh~Pe;^~kZ1@)W4ONHhr(T?lY$d>O=Dttf zF_+&&QRcyIkFTCH(x(=LBrRmy0$}1}!05l{ANe~=i~gQ;mHVTmwe=klSh>Z|rS*c_ zi<{2?o2$+L_-5<7ttXWSw%vw)Oxu{l9397HZ8E)$EBz(=DGW@eP!omqX0aOml;~OIGrx5d5n8eDdV1&Jp zuZ+~B`Pr!V4c0rLq_^tI9c&0?0rVCDGKIW}Tir@|hgYvO-&q!HHL~n~bGNcFjuBG( zH|w(!4&>Cy|4YAzb#ay@l>i|LL$uf8)1-=iXZ92lN+F?Gnmv z7vHB^LEZR3$z<;g{6X!iS`i$jmCLd ziyS|$EQkk_+agtdg_!0dWtvi>mxEXJEKlLf*G$YP>MD^C><8J7F)#Me04AJ%9Dm)> z?JUhS&H+syPUgKQYn=uZ-WM zZe*Ao7!4{)ctvIzux-F{O831Gt}tiol3 zI+x-x(`4DI)$Ri9{>WM|OP`6Nb5G$HY%{%ue5B`$4sO@fXU-XQDt%uFm@NMma>nwE zmc%3qV{Nv>xa$jw#EiyKbW&6w?sH7L9L}<9gEu~6@WMMkj&KiB5vBZm$V1`wFv?r!nC{UF7N4qkqU7+@^|7xmQwGrM%@vgL{yqH2)Pd1HW^MnGVjZ{X`~|h=vEW3kY2RT91rHhCQ7e z=55+e!l?V9E+#b9gYC|bGd51IzPf+@yIx16E)7pz*hTi7A$9z1$G;XHp1SK+6jftV@$v7k%M}z0CR@Q?OqZ5B5xkNs@k zw?l-r=IuBx&$#|0<54M3I^S_0b6c@}vLOoJ??=owCKy1b_dnFhC|^vsr9Q7J5>|0t zbe&K5pu-l8d0O5xOTJ!(@I#p%-+Pkyy;83ocX83_bw&AXR{Np#B+f2cHa^Z8on^ea4Ip>R}#i?SJV~BA5pMQ-*b3-BiD}$3y z%-Xg;PXw6R@tEnrqZ&F-$?B}Ngz{K^F>O{_<5l8GzyqMsV_5Rxub>Yp5@- zpj+taw&95!g+<>`M@)LqbL8VoVTgr)Tdg}RAkQ@ap{PqKf8yp)EqPfcfcIb`1N1r1^4f*xI2Ep7O$vRJCJD2j{Y6HL13ocpRu^Rb5nm2VcJp9|Ky9MNWl7P{EMaO_l za&ZLZnW5flHP7gK(o^UIFQX%C2ceWNSg?M9&JAwK891LG(u0${i65J|W&GSVeP7@; zln(SRpyc~8{Ni(5iWD#Nes602G55eljqJ51tE@ryv%+Q8{QUIwXL6VGyScf#WD?E6 zFC7>nkrWMgD1#2QZZlExbsL3n?ahHEr4=_GJwrL-#A_kWL(JgIjtl&r*FNi}FWw}~ zr8%qW_8a}tef4Yq4*kNny(h%?XtnK6YmBahv6xG-0iqJMK3DQ}jd5Cc%`~XayrJgD z;zwj`Ipv#Mvgoa%`z5nU?K{FtAZz5DSLm1d3PNn4asCQC;|^sY_l98m&8%3hjTkAQv|6}@^% zlend+P@_$;))%t>`esO;_@qkdA0`J8h?_WW8|((ku%?RmgF- z^z4l-&c6EN_M?g@U6AO(`W9W`_BWmd67^aUdTCpvJ!yoceiuKz;`3$@UiLM90S@Xn zI2T0)N%2oHoQZ0lkR1xrc(UxS*YsZ^MOs@%Qw}7T%rBw}19k7Lo8G$Ze7?pJCJWFV zc!w9DK&S*J0zwMZ)O=EEJhDteTK44zX*@;MJ{3?~}V3ZL z@Bt=s;amz{jf=#j@3N@qyi&p3N;!jM+XG*|=uwhP5NCLM;al-~y442_57-%Q*~Ui-Y!XLt1rWTt9KN8`^Wo&QuC&zgLedj=FYx{XMYJRCaL_T80(~$yh>sO)wE(D9=m?? zcE0t`M{1K#6f5HFw(8ZQ@)8xsef_BDFaepl_secFKTNS&tBNieIHgr5$}p=4y6c6t z)g$4LQiadtc;7+*7YQ2iXXv_LxxIRP3x%L-cE4}003ZD27H$Cw-`rkuw5@o#U;D6D z<&8DikkY4oXP#ZqUIGF@Fw3R7&s0YLdC+GFn5sWzvfoj1PfcDTr4jID58Wws>dx?FkEf& z)u+FMX@)NDtY2VkrUkqS2?3;$N)$$fLPfU++$Y?h_+863bKN5cO|!-l%UIDq>f08* zDTe(+e7V?{8f3jS3g+;`#6)YYt5ttdUNU;UGqKrCk5EM&L`B< ziNY$}W7k|h^@2SoAIJ7^X#Iu4Lq~moj{kIj@1renJfls7IT!@%r?}!4$q=DE@(xq9 zT$q<|KfK=A7!K!jkr<}`i59Yt+=T60@h{B`{d29#)Amp2O|~Q|y1Lcn_pd8saF^t> z1JYBB{RR8uwT2GjFnuw5i$=>74Cdci)$+Thw@KN3_!_u@P>II3)mbJRe2ZeNehSct z$~)@UM(V^pms-tEjZ~FO4=y~1T)Rt!uK=C{`x_?FR~GS8Og5c~a@9;c z&!(B0NfmW|i}V;t+~^zwfgl7;?gi7o*OZMiIpHOf?U3bjyhgCj$x?@p;MSBr@wctN z%wEenPBKD|p#+ZQeu1AMbD=`JxKPt`A4hAQKHAeFwU{J1!?eWeugmp_&KJXi{_pra zx5Z{s_D$@d_{h$rB}C2M82!Q@(h^1>JWOQKW?EF2L!W$nn`Q7=i^=AZ!6}<+`yRJ7 z1E*Mk{3I?Tza@;}^zI&v&U0hMCUUU#wuU2|I2Y85V)?OvRb$zqyBFQ-ak%+NG#|Rv z_R6S6O1BzzIqc7uuAPp*3O~f*DP!JVZ{UlRuhUENqMLp zb};oIO7rTG+XsJIWX+~1{R)>lh9z87qW^2SRwpZVfR#%}nemyJbG4Iv{{tTjlkj@k=X|FWMwRbpuBNnb^R^aLD*P!|b?AwNQ1(aXpjDUyu57cDSX z+ArBq(-xO+)#r-&d|As%DAdma?cr-xglvg!()7oau|cB=A!{loD{5cQExh$uhz7&` zS%v*ZFx#n;nK*1x;JL5fSptz6YD{M9tw*p~xb@Sc$mcN}-@;&r#n;AZ9@*02?dP){ zDLDX_HZpOH?{}~^d1zI`v(^7&Q#2`QeV1OGR;6#ggbhp9b@B`o>=AR#_c9bJz+3MY zEshFGr$Vlqsm*TiFx6$VPz-VS#&*ZmgzOTI_BBfM;qD=9w^+l5S1?OQAMEzc;dzaN z6rT>euQpJEDgU$n==ZONGc4h5u@ditFwTtR8*;rrk>GovG(s>L&q=@z<81OJB>` z15+i`BQ(D^O43z?>c5%%VJ@Lg!Nc^Z+;?x$@Qx?7LCBY^ltF3(Kt0G8P^_H5gcDuG2Tca!Fk5>~7XE~nLb$m{F zEq(iy>Bt+ayTjd<`bw4aF+q{6>Z&d%)-^poLTZq338LTeiO(8#UG6!2-jkIHQCcC? z4^AUawe_xKWN5OFMK?tZ5sP9S)95_)uX*^MbDBECWm;TUBdd{h(63#XmecJoyL+{1 zocQH8*;K*XS^V=?k`l^aV3kgrO+M@9ne)ai<&ACwGR>mTUJD3LWqZr2{sCVm=6!$p zU}+F~fPj@zOjDjF1)Sd6__^MF`U9v{S%^NqCd(HAkS+3e@h`BMIzLo%*1k1G#7@Vy z1-xhwmsE%1&QO9d{w|>d$(VhKb8xhldhn9P(9**1(coXXQyHD#d=%DHdGNKWG42zD zX{(#IO{o19>9f?rW@AgPq(IWQOZ564QV5y8{Wx1s=@VsbD@1LOuX*)1%M(YQ4jj;W z7cWtV^%xS^I`GT|JD|g!^ThqOu%f{CAw-Gj>16#&G5>0GqDCqV%dQr>unJ6QuK{WM z`3joTs)rkTY5~mJM)kF^sHfvu(I&L z$xt8rCfZe0d~AW%hz^GE$pHIrrP};n{vQ=wbzD!=@sP?rtPTPr8Q0fC)(FDBt`3yU*S3?z!hY@r?(36Bq6qmma+m z3Lx%{wC^w9kGHwdXtG7eF1l|5-9%yP)Bo1^6a0E*pm>^Cz`kQ?_!a)hpE+(QK`Hsmw@xVxtA$Q#t2wTC?&m5d?Y4NE)SfczsfFHl<&mVDFa#&fIA zUvHY594mW>bql^7oCLrin&I*uvl2Zl33ZxFuu#KQOKS_i6I`3J&HW|^hDh6LV&uxv z4vzHfnEB$axUI*z)@dOv{QZ(2{{piwc5GqXr>J0;rGFBvOR)+&f#^x7NI_(6a%q%c z)1rj^yuzdQoh{2KzR#j_GB_QRZU7V4uR4CayMp$?EuZ+6~VvTx%^lYY?^Bs(q>)LBJB7nZiEDfy{Vd>>l3^>auhtTUHMeIL?u+B5i7qSaeK0=pJ*m(6p;Z^MxlOc#uu!Ti z?}JhCSnF)`$gaxTr50DLxPnEXnu@iV~^zo)56gqLR$S)^}dEotKj*a z*ohbAi8jm7X(GUJA=H=tyo5S?VyP}1u3ni?hR-O37)hkQ^3WA_f9q^q>1;T}W0w9% zEG1?@zR?MQ{p8PrUF~#lGu$F-^3U?JLKn;bk%hqxWH$AGr&BE}wsh(LY;cVGf|lTp zYp{ktAVn8=*h7o$oH-rc+?CwYy{G$FasSOwphA6yg>cEgi#Io%=YK0@+j;RkAFaM_ zapOP&+SksB8Ep`_@Oa!9XUq~*>mD||;8w05OB^E^U(6(M~8<8~nWO$vFTK2|bJvAG}kh2nnYyZeTx~eW0N#X~`IQpu;en{cay>=-| z-bZ=J&us!Ov+kqK>iKOF_m3|+%CcMnkk)>B3-e!pi6{iM2 zbvq=)xbh!oS#b2M3-g(ha9KNeuO8-jTFrrXH%*Uv!lN6f7f}z!U7q>`hO5xrWAiR% zG__A^lHr+KST~!^dSqob=RP{}kifaK4!AkYHVTSd|B*QwV+Yz`oIk?$5sGw{#}_X0 zG2V@Tb7!;z?$ic*+i;v6Y+ec_W|^m8bj3ACE4RIllltiEvKVa-q#1#kEL%Jq8uk>w z5@Lyr+^A(kf%I_JRN9KV+le<=H)A8)IiwJ$-CMWRp(?NTTcKkNufN*8z5eMdGC!f| zj$YtRzfXsBe_ew|@D4_yp8YJbg!&T#P;ps36U-GcMBj)qKvp{hE6uQmIL`AMrVNN} z4hNnQ!NVoOjlHqV{rHkB0*LaOT-eUp#4X39KC|oYnVR~;fvB8kytJ#WRvn=|nisAT zP0-sc4^7j8`ls2Y#)iE%0j+Q~;*sA1yNvBegBE2l5t*y6is!+%`FzB&d&*y-<6b=F zi_z1s-@@C?RAp_=1Mv$omFJZ5w?LCO4!m{Uo#O|~d8QW{9*9_)bUI)}ek00XlCvH& z8a&M3ihXx-$Tdb9wXsompw(tX#)~1)Q?-v*TRYynThR+?*1Won$+yKVT?L*SZXr{o zdYjC2z3QbPV935}_K3*rnwcGaFL$PJqE@eZi2A&&n^!nUIjOPkJ0Fp7n(CcM$I^H8 zu~2{2*S<$f>rdKGbNtg?5zp%^Hlv@p0?u28AzYAwjIanvGl!e=`(j8y2oKFzEp6?!V^sS6*8a{0-mXTa(55uW(a|CJ>orgLL~)lmWc1%7d9U(;riaAspF~Yn@xYXWBKxihFpc-FT1+~sX3Cizy`!|lPA~2E=sUE{KR=hj`@~$Zn0~EOTk6=OnaK|$r2MM8{5q^V% z%B(;^1TG2wzh(N32mf)-)R&7^r;kpvWD4*f=IANb3!xQ@4+%Q(<|B4^ds|N{U=0bE z(iyp$OdVLQdNOZ}!d?%kRp_r?X42XD z6xL?fVuvJ=DN*kb^wI%+zHAE|p%Q_MxB6g#ZI8sYrF3hv}<|2gB_93#q}) z+(2^ZW!Zp@O}6wjxfh?q1HVnX8{f|26cr`oAb~%x&@{Mz>FKK$M>WWTE`nXMN{;n% zob|)Ch;0MN5cJCEv5B7p%)<5=MEp;6qFPzQXR%PJqdtWQT-(^D&IokgoaM?$EuLMJ zD9aOiW2XUPjoH+E^z{xSG;LTc0a25agNKennqLG@N5hZJ=bcehNtPuC!T=mmf{V|yP zc7QBVVbZNi@-aV^oncpWsd#h5_|YQ8FM35*+TWUMlt#S7mKn5!goU1=wT8w@lS^};BTMm@$_j~!jUjLy zOYtY3meS&G_3G0Wt37w_1S{<+Htqs;sahz91X|LLZ=TVn%TV#BR+WGmZ2-k<%hy>1 z6g3P*C|_%`hPzk+hM#Wn=LOF_8?OFPuI!DWa%x@JG99zGk_|68m6M|S6o|aUuQL1# z6zn%g%j>4KwO6Zl86xMZ?P=a82ArIc30_WPUk+Z9a8@xHe3dhiiK@3BQK;8H-98IL z$jF<9s~1qj-kc4$8qB%Ox^^84F~_N09&gRqqym7Rp!2(BOr*l-{Pa2nBTf|^-HNvXEeNL~0D{{5@{ zqR44W92?{9^Ip8|6weyF=Z%YhWVsFRDFAC1dx zV=EM9H=qh$^MhlxlrbIcVCj%H%PY86PT>ZZbFxvE#8X!6e4+PgCBt~>6GClTp)|eV z)wW7RyW%Cdr*M978w-OM+h%c|9$&Sh|wxq>=#DjuySd@Wl4Mg1vk>IM5ni$5v4s-&-nvFiZgNtwfADLVJ9gQx{A2#Ic z39R6;qY}(+T1)y_F|B%teBdvRsqJh10XzCzU1sxVzwY_%37lZf*JMEqH)X&Oa}wz4 zmb1s&RSy$}OLR{OjZ!DW0wHsfYh`Lm_8ztIT{=)xRJ*IMPcoT54@ zWVkj;b!E=%7`;OF>%7MB7ik~+@*3RLt`S+!LTvO6@a9YD=BlLf>;XDmx#Jy^Wn=l^ z8y&JsRmQ^2)p3hW0qLyVMem53jDr_>kjqdWkE?(7YA4Vr4c{tZauku*No!%;GelLW4(Mx zJy_Z+VtV(#wNWN^>03=}4>Yj=36xQTu^J+C`l$K9`37iJA+dgQY_1H0q1F3^%gj_! zHig@%OQ$v@CsUi$8{#0_g<)8hk}|T{;w}HVI)>}@SUhFJ;-Zrb>(MNo<&-RM99L9x z{fgtvyHu-^{Z%f`>0W@TG9TGZ!8}U9*g#gR6G35jnZpOQ5iOn)^}C+psxL~%9_ft@ z`W9QySv9M^8EBZ!J_QFvVU!W`z!80$vTnKcx(~s|U$!1SFwR$xG;GXRYHrB0d(`eH zxJ{}J4_r(H(}YSyd(5WM1U30CLalQY%fWP;X8jzxh0lZ;(knQM`N*&C$k$4Y2g5Jl z(M$$a$4gDEYE>~A_Q=t|wy!JD=@t@8}z_qxn^4k{-vxN*E=4=N#^HW7reM1Uu$cw~=cE|ij8O?$wQPTWkoyQ$d>v^4> z?jV5#C=KQeT&t=49?x}fI|1^8`%z4FoUz#qs?vK*!?fW0eyNSVPll4Q)|X;><{C}q zgEI0iuHojykK3}Gz_c#Xg`jEAV4nZUJ4>e$@%O5<-N_a*9CBU|T5Bb=7!IW1rS0jp z7Ss{)rQJr)|$h%3wj4LxZ6Bi zV{K;sz4b6BDS3d7oKmdPH)>UXU7LX}x{G(GV0GKmrE2o<04Rb&+uebY0Afb@quHXV zJpXGA0t1<_4E@^9mOg`{&>p%`9ZWV#DCR!V~8F z%06JeAhVz19wHF_w7=q=@T-SUytb&kz1k9WAskJJhV_h`+xQ_N$xIXPpa%FD0v#7< zjcu@=onuqCm`yk&Mdd076h4u1h>;icYM}OGC$R3Il4JQ3v|yiRn%FDaSv!;_Y<0ZV zN_br(<@ZkN{5a~5mW!r~@W^7Rmya-_zQ%Fipt&9MC;|uQxj5{vs!@hReJy?IRW zsluH=JC2HBknLQFS$)F|1;^(L;$*&w#a(yQGzb*vFS~yBYdH;{UuO=Vzd)|pN1pe$ zfH3EZxIK1atTRjL2hF+$#_sA}D|H$K$}-L14v|Y><ocS0z{ z9sl0(<0EaHh{OD~JT~raw=CecaIRc+iR$v8erYEMj%rABSUH?){MlcO6k zKvA)^1&q@caO=C-U)cYZ;9X{*uA7D<49gpbY>lv2tDsDPS^?nE6Z@H*7hDC^Fs@LQ zMr&!n;1f@+Nj%K$|J~W(j1gBo5MiduM1MIW99|Tu{3LgOn!#Xv(@TKtwV3KVuBe?V z+}A=S8`Oh$nL#>cwFUh$>!{F=3WyJ6*LwnrfZR1wrm(s{HyE488Xmh1DwXY_V~kBn zk^Exu0wt2d@=Vh{Nk5U zf{f1%52V#9`?!<5Ydxo@B{rlMU;BbHX>L|***x{9s^g)0+(q>yoz>uZ8AYBr;Gw5C zu5TTRYBpgy3F>)(bhQ_vy^amWBXUp|=lb*{tjF)0x0n>RaZ`I+zq1u9yLpmrS>6nK zMFKOF7Ri6n&R6NrF@C0-_Yl1XVmVwVjnr?~axkFmH+$ddDJh3HP_zaIJS%miP(?Yw zSPqOpU;+8$Pn&Zb0e@V%<0?{SGj(Oi11`Wez>qF$XF}8q1XF@>Kpw9eubu`=AJzag zFc%x0F7ngE!%#`PHdfSlem`~_OSO9BrZGPw+rj!tWh+bAxD&>9vVbHQS~@@IuX6_u?n$n0Q{p*k_1Iku2II^v^P8${ z&dufP2~NBX?5lOJ$-{OKg4*x@{75kJt+MmlZ2!MKMy`xAK8ALF;F5q9!h*&Qu4=Gi5);^%WutD@=Idf`M{R;S6W(*J5?aFu${@hmf%P&@0 z`-H6y8(szxYJcotQkI<Ru>K(T_ z$m5J1s|cpUjzyE+SoyhY3vTJhZhJh3Hz81$-NEueU0MLtr6apLdnwi9$`5pzFKNAS zT)3$}&aWnEI`2jnsh<{3HJmlj%xrZmt{A3Q&cpONA1(!#2Xgh~(qm5GDwcO;OvB0< zZZAl@e(!|}245N-&|NS+91;QB#_2!9@f@pp4nYEKiO(;#cwt_Z-BD{5#`AoIs$I?wf@F3`!nqEGV5o4#E02XSH3Pc z9dXHa^2nW71cZ2060qc4CoVDyYbPLw(;1sv9r09n_LbOkKvlD;IsQu#k*7J5mpqfY z+Cxr>$g;QNDZnkQhrcX}V?McYpAO>y|1S5}sHENa>S|Tz^|{pq?_0F<7G+CYMoc5&QofPv;n>v` zE^E}nW#AvDkqom1_rLL;z6ku`8SU@-Lhl*ogiQV&UsfYx*)m`X%g)i4U^Vdjl~-$F z**bkGD$nkO0Cyg-Mo*5IUkSv>tKY2r{^k5b!Z2(dt9$ACvU-iTib}R`1t`A(VN;1# zBMaQ(j- z)U)YMtD~WsTL*MP=4hAxV<8whw^cCCu>*Ce0JZ=Ck}^6Fo>Ob8lv`>a8P4ZSWH_5cfy*9NmsHP)5NzT@u@T6SHsV-~b=}A-?1{-nCra$% z?3G;Pxk#QvUPhgcvCM+^O1zphtnbPOlV9Z2NvVaF9zcAxQ&c|GDUq5Ft3RP!N!T>)hkPr3Q|lxsITwZY9V(pe534>RBmAPnMvqsIPT0mu`rF0E z2YSM>K<8otE$J>S8yx&W`-8LX((UN4fIXX8xmUGMFQJaaWp0qvpszBfWz=po|M8Ih zwP$c)08}p+CuWK;i!z*wj&m1f*#@MzILcR*^wW4Oe}iKErrujIv+(WNAt2_h8)~vu zH;`)2xc2hx;&OwVLf*8t$e&$051ki8Bd{R;5;926gjr*2+`$P&vNX?6M=DVLg`0y10nZEhnval))r4w64!erH=##!R^j19+>9M z3bu5UCuA&?hg_nvgavlLT(%rz1P&L zLNFD#8-flJ7%kt@?|b?80Jm==q|oqUmzTk> z_h!v3nh$!JR>@~$`g*+m39MUIVy6io=z8bhI2`z?OS~AGawe2nz8{zJJ1-UOR~tD9 zted(xDV_9~!7e{Yl(Hc^ZSqwH;D{=Rb`I}XNBm{(U(Vq(cFR7M66(9eZo}Tfg#+4m zZh`b-f8bm$hrx3$*uk-Bf>Z0kn9EWKw3!o*W4zw-mKW)=E&qv>iGbJ`A_or2lw*#i zP8)cIS1+FJVJsu>TgIA(P+;Kuf6%xDrY|ssRK{AC>n)=Uzf#YraJv>yT4^kCTAB; zH#EhSY{9h)LB|;Y@O+F3f$5E`(-?=>@O#8Und18UTmgxJEnbVybwZ+ZeF|O! zfAUq3uO_ooq)#ruQ+7#F$Q)e~k2XnTYgXNZ$gD9RgTm1MISOyFj^oYV4#tOrtf_42 zpVk*HGS2wGi(k+jC`e|foM2otvCEj>Z}q|PrHeNFl$VI*g$^yv-*csn>XPF3C{<6t z^8c|)-;ioZeEIxLCc^y{`##9{&kdYc2WR@Q+=MJ+E_CH#PbcU}NJLrLL0JLw)eXMA zg81!GNj^7k>K1$0Fx^G<2528+4)5T(k{FN>TEF5#zNHGA{5!|0=pTPRO8f*9Ar8b* z`iY4QbyCKz?8*}Ny7}_ZHK9wN6*YiUM;vV}kGi@J%(V9U6|&PQ3LSQ@d482GC3awKEz zHnNDko=7nT^uCW;^J;*#lnXL^EWL1QeWwLHc8eZP)m}mcd@32F70>aTch@c0|3HMU zoHzhpF<3)-!FX0usbon3?VP`VW(qhOv0q@EaV+5e*0%ru)RQ1J2e?FcE2W+q`J=zE z>BSfR03)?@+_tL<{zT7o5ahKeohHqsl-i%ysfAWwc@T-XTUUsY;h{WI zgyT35E7jMMh9_*}Au_}8pRMCIrw?u7%-9C^PQr~ehF8K5>0t{eA!uTYDKHi@H9AUo zkK(7D^ZED5tuN@Z6gt@SA!jqKKCMrXnq4BRRPxUVgv1p*S4O)T2#EG%>XuXZ7+}@( zCNm~ppzl`0CDCaj=~CUCSYL>Hp;9$aUW7k#p}6$ua%Jr8xu@lOQN?{_LohqnNFGK; z*#Xt*g5LtD4FIe@^Eqr*JFS7{fPsT2K;jI4T%7PFX;~pwV*ktNwIrOUy<*kxY$xDe zx|L63ke}~|vcK-kRnwIjhVfhOE%0#2GN}Dv(RzC-=1$X*=}jyE9bQ#Sl8~z!gcW)+ z8UCz9MaWTE8_?m0^SxDid!WR7ii8O)zTIU6aW$ecNJw250Cc4^+jSEcP*T`aq4lG{ zSR5qH5f|SbT|PekwH?TCAbmZbw=o2?);^Nf(mHXAfo}ZBVSZtiWqpj7ZnE_w@CZ&? zzd~j?Z@WFLOZNYegm{@%)PB?fMiXrGE#V#aFYY7HargZ{rzS+$UdPH_k4k6eo9k-aNeLqNv{M`IT9RDCLH3S z`*jo*55GilVe77_qD$P8ELTm)O^WUs-(^82*Ru!Zlz5L$ST3i0NG+*7vq>-KeXS<5 z0Z*-w_N6nxaAvDA7?0wA%HO+{vc)4kp7A1eQ2#)zhZfe&peJsc>3s*x>GBpjP{tmy zrOK%IU&7&9j~T>9%PqwX99h4EfT=QaKfmWp;W57z?VPlYLVxs9Wz+ys!(G+mcgVBtU3YNL8(P#m(#W*Cy{04R)4u7(v z%3w(WqZq4`9CxH>1FnXArKL;w^^Yu0esDU6r}RIT@6nbWqsu#h7YovWDU{%948DFg zqAE#~V$pl|kRLf)i_aV(Pcv{X*9U)Fz9ijux@gC5gZ&gldSS5H7j|2oNAxaEAqgI|L7q;Oml2eoeprD`tAkPi(v;koF=;vSy z0H~@0*Z=?k=JOmO0QGr<^4$KXJS_oa0WVNc{(JtHprWJxCoj>^P|-17VqpB&F|n|* zFfm_YVqm<&ef0_(=UFhYUgP27y#84UXY%w5K!AxN0vJL? zVF0`!KtUxydFltyKI@A9jOag5|4UF_prWC_d{*fd_VWTL{ULEv2 z4?rV8C#2_y>*(s~8yFf{S=-p!**iFTefIY8^#l2bM?^+N$Hc~g)4rxdGBUHW z3yX?NO3QwfS2Q#>L7Q7z+uD2k`UeJwhDS!Hre|j7<`)*>8=G6(JG*=P2Z!eum;bJ= zZxFY4|KUObp#DFw{x`D!4=#ddTrZx_01e|mTqrMmpF1i68ah4iOG0T)3=0n;2EI_t zH!`W;>wB;m`L)i7Ej=e+kpKlYn9l!$_CJyR{{t5Ge}(LS1N(n*!2vj^D9?w7N&t`o zJV>517uA>};ExJ|)kdp>tEYbtY%C{hgzB`mTmC>$F)$KZFl6PXz0oL3&sk!C*1i^x z;0i2P-h|HpcL{ZoeVhC_;d9CF_ahXtAQ8GJz)Ne_XlhZbZXnS-NS$03h@ z(gU>G*@_WeKYs{nWNfg8@V64Vn;U!Go;I=2m#j9+#T~d&?a1oXR12Z1B>9bk2JTu| zJk6aQuBt6;B?@tL><$n!()e|?g=RGkbuwNWtgG1>f~6u4FZ!!0Z+z}?^y2Ye7vIZy z>1Df?P=DoZhE}RJsQ%s6s9FSW(S6=S6!dB`KUH z@#fNmyAWZBAMd|AQ1u(#t0Y>UzUE=c9VIwCnQY2#I@A_}u>0wnmzk7{e@ZHfx&lm9 zI1@gcR&4PKNDEI;9L0 zttgTD)K|HfeN(D14)3{@K9okX_4_k496$beFk4aBL9^82px_s9cb( z7ynz_PW0it!O>_?2W8@qdkBX1-Cd9)**~XYH}TWsYR8Xy%M3nWeo%TOU&%6xC+rn` zVwV$X>EKdaTBCZEM!t7GLDP%(t!**7)GtI2YU_YhtTYQgt6L#%VB0H1*5icc^C_-d z=X-Hvn|&u#r1GCDE|kMDgfP3KXuSXzyrowVr|ndAYCb*nh1Fh_Zh3yXEKn+_>|r;k zwO?vWyT<-YEIL=h{1S!Lwe?nMj~#yx(N#3fkoUU4*ENPxX9l=LICKx==8&o53Rds* z#5zu=dM(|##^|5)?prG@mryhRTy`1F4=yV~=r-D|Kjd4tXdC+MWmvFzmO6$++dFfm zFtLCibP7)ZsGy>ck~JT>eqgkZd$5TMwa}NBV!jlArK>UMD4WS#06 zs>4Hw+>WwE9Y*oEw61oG-bs#uCV%CSRIP&rE!FTNKL0y9{b)yQt&(4c2LvK=G(P#% zx20J7TW^(xX#cCOR5<^GGc0{%cuy`a|Kp`t+zRiz<6IsX5d1tlcC$F`SAwJD67RPuKm&I%uR&GQ zX3DgJd8l5T!OnZ##>*hKL;0PuCxBT~!h4q&+M z0-j{*3Opew-IVY@Khok0>}cCUUaEcUSvtL)ZYZZS?eoEQ0R~OBu+S`0-6tD~d>q&3 z`AaPvED%M-<#GE2D7eht6W2U%?4!Y&?dmIcBH{VnfLSWm5^uFLzbFIEt^`rd=S>ftbEM0miFZ|%cxr6cbIFUldS(=~(FK?9D z*yEe4MiF5Gluc3S>BCm*!x|V9iSG?wefHq~NGb(gMfSqFYHlnDkY#W`65Gjt~U!QiWb{NY;dcRiD4d0GCW#mEA+$*Yo@`RH|1M=M9k4k>Bj^A`&8epO1$8*~~FTxnl&abzLTV?0aIxU(BPD9Uxg) zy2AP)Bykx8iu;i@$>*F&vOiW~+P~I`_F&7gYfx30=Z-n8Uhl zk-YOjqowfaW#7G@pGn*+A{Tf>_XH>9Ev;3Z@Irp3O%wAQ#7a~8Q+-iaFS|WfO^$jz z$M}A#7ZytkbY&lwd;)CGt3@P6OXPkBYETr4Rcb86#vJn9G$=-MQKU9lFSyndp^5^y^U3K; zaUd?QCuRhslHp%T6Zwr7k#zG{rIFvnRHoDhENo18Jn8C8i&C(b@kc^?kfw+Sp`Sr# zCvz~ZBnj>TIJDDOqdh8{-2oEtHX-l}Aj|c-ksdy`{|#ME?{;ah0Dl)c>((jl>;1+z zJ0r5ww1S9r=JO_dFXb}?Ac zCR5waWZ<&%kvDGP-HJLe=;27a&GkIv8;)2)KOAG*No3i53||+4BS-to>ZtS>Mr_ zrVE}u-bniQ0dhk)u9Ia&uQv4}b0OsoDmypX@HutB=8V5g0s#HSdjiobs2gURj^e_ zE-qCHyROaAcw1KwgJot*gJrhZqSIQyD41T27i+)FiI|PCu`CixoT_CtU;hg03GN;C zM0wbZT582R*~4(OJE}!RG;eaFmoCk=-gCzWRz+T-$=@>&rU(UJN@a>Lz6)9)*Z5pw zIIYK9{!y_&!y))>(+oM8I%x(Ey`nffTO$ul#YJoJHxD0QO~m0lf?TWl@Z%xGeB{2Q z=w%ZkWL}=iRi6O=LLO-4nVO``DuIbZD(*~qxKcK7U;EmT!015e%*5fw$YtUl*vNa2 zQ!bH%u>fvhmdXupx8lurZijOwTPVIgsRw*GM9qpgrAEO#5W3 zCl8Z(4-=t&DMtz5sfn4@;!PJ})QysP0(_?0to)sDaM8?v{J{46EIblklw1at+kOjM zIhu!>G}F84U@8lF6=-H#r<{3Rk(Hw`AK9C4U_7i2=AS=`sh1h;1yGb~pG}7hC|7Wc zbxk;hNyWTuJNP#OoAIC2^7BDACv==*4?eH1M$stRF}*PuqQ2eB&+{W=Pf{B%N^s&v zR9@}!9}-t zTU#)oYmENdX#wb*|M7)MJ@~Ie`iG%Z&Lu(mbJcP_cduJ3?ml+~A5RYFD4+?wRBySc zG$zCQn&hvqqq5yA-Av+&fJs#TeiZN<`K-Xjc|zaT79|+AiDDb29811qPeC?`;*&nx zI~X^zn?)mWkXs-#TSEdbaOK>- z=a*Q%5d7I0WkSorTqg0=J~bEmb0$NC6_p~pu{lyZ@k*TZ`uw4SV<^e_-y&C4peoah zrZVuFtJ;x;XcTRl0nj3&tlNc!<4j^a@DsCR7e5&r;dXx9l)UJ3_|BVCy0m%9E$^fz ztJGM~u%>VHvJnU-PMyVhG)Wt6Sv9c%1ede>m3^hHCc0yS^3h)frvVz7QAgM(dO+S~LB4EmCGZV=K?@;Y8J7j*d#s z3}S@*t+~-6*%WkcfmDwk3=}o+ITZW<&e$^CIcJIDDr{|SFX)9!&y-D#T%J^)aAZ4d zG$%$=iGn1JaO8JM;n>5_Uye@zB5p28+Wk7kvK1LfHy)I!9nODU>7ePnYCVEpq)@$` z;p)fq`tOP5it=Ex!&ye(-EI-E*+0ED>J+>WM@5s_)FK_-As4Eo?orq<`1=-s^~-Ry zi}I```g3dcGV$4NdFh06jWw7(Dp9!%(n-osqb`Cqgo!Ka=QpO2=Sz*JJRQtAREg+9 zs5-Y092ygj1jE~$M*i)4k?e|SQ;js8SSl-pJEr(L83($C?XC_g3xePN^Giws=6lDQ zVZ>@i35Q;A2%U!AYAN(tiy@-2DxgsKLxls0y{kB4R|Rx^;kpXD1PI%_w|tLY%9&8l z!bYuj5hQt`{%>&lgxhRC)P1r_PEO_^sRH-->&Q`z?3W~Nz-UzBY)a0tv-GkQTtiQP_gk=6>? zBk0h#+9cbUxe`lMo2V&*%(`e663`74$C6NyY_ST-mJdzvJshuf3~DGYObOEb{a5fB zsRV7ECUMmj7sLRJPNBjgK4LmY}?xtt6j4n`&EgnXiPkf0F8=iGtuaB z6nT57do6_3FaC!#lXK?e!p`uAj{^^h-~HxoiP5HH8m%)@1Dkl* z#qIV$)(|yVPj^#LZR{~L&q_Y)q{^G~m4u0!OG1Hlvv9U?@fV@q!|Mixd(+NepD+i4 zKhhG~uTep{@emN@AW5hA%C*<~b8~l4IHvPvhukA9cK{6>OU4>y;1Y3>ICdSg>Y&`j z5rlp1fcRPN=Vc0BKJb^!DS3#Z76h-_fSSf=n-*evn*s>M>{m891{H~2z^a4H)=PxR z?5(V@vQHj07jv@fKT0I2)qH3>n@6Ovv87X*Nllu zKeIP=hTreCWJwx-X8!1TfH#e+4Y1WGu$@So4HP2^os8A%mYSXbt`{m6&6^H{Jo1tj zxNW3_$o{Zce%@aE^W)NHJUbFt4U)b(#Nj=!2h|w6nB2bH{GjkLAU{E!e9N7>P7tUu zzQ%SDkTkpb1i%#6^l~Y3aO*zCu>S*uj8MZ1s7Y!L}c5?iZp6j`q>*e|6La$psfS7uEA|z=!I69*J#UZrtHeRDs!8ZU^nd&ht$Zk@B<; z{s(-nyXpg~juNOnr1M^!urxFmfr(e-6z_`kz*cph+cIEcO|Z#ZT3#dM(nGx@?;#-~ z`w1Z3cn>5^GdEdf>QDvJyOxA~8IwdGkU~&-_nv(kdqn{a@IqX}(GxNwZ-gI18?2hTw4U)_uo;UnkWnsPMH0_G} z=WB+{s(liJ0=%47SJc7v!=b4?Xs=HlA$+6XdKQjHjQm)qd|Q27wb>B=(P>`*;j$f* z**TaezLRcg!`&dHfsB%RdnD?R6ZB1X*V#1fj{cV318;gD9NE=5h zy~6@t(kIB#j=`tsdP)4oTtZgx$o-NZ2{Nl$Qk?7gM9tJ^jJ`$5>q};(5=e`ac$;?~ zX(Gn@?n2`Kbp(-WG}O&ixiKeERywt6Z3=3TL*(Spgn+z1fL45Xj;?xt#(c|>>j$Lfpk|*1K^Y$ITcD`d9JJ z%%eW%94j9G<#PSol_E~=zt?lPCKsLz&_a0$CV}7}ee=4k9fMXTu`Juyx?@Qb*i_`NG>Dz7gjaBe-wX2Bwx*wDtZ90Ds^< z&Vg70G&;@{0c?Xq^idMZ{X_Ee;^x%XSvFkBR)G=H4xATjBV?Agkt&5`o<+NKndx!u zJDGJX`Lngo|GF_&l)arB411eXOypeEM=;Yr#MGB-$O7w!cEW_&3BcdxscgK;GCpV{ zOa+8QZNs`uyyR9HL$lI$3>X!48Ug*te9I9Qu_WqD7cPb2GwvAR*<*>XP%U$Y8%sjr zw76163m;Gj0<-hrD6p}hM=z#KC3urgph%go#8=j{5ICmFV$ouMEU5j3neO*@XHu~S zRN$%?PK^|Zfr?NlClT`;BZl&0OKT_oVa`U_OcXQL_bj`F*RhGy0&!s~;rwo&UaoV( zZImt5u@M*_fpPoyg+WAh~=Atv&xOM+?UqPf z`Yh>NEOUXxY_IW>?6^%fsIRvZH+11=2mT_6MNzqVP|vjBrHKI|O(w!6WeE}M%5{Z# zf@~EpaqzLQG3Veo5a8hpy`L0p;_G{Xo5(RzqdDXS&Dm;+3lvqH-U)b3(&DTpQUGNu_oLp$l-M%SG6U0p44~V7taR7h$m*L z&Q2ydB}@(=xXnSfyo#rj$WIJ-_DfYIodN$@lEn?m(RBRHh3qQ??kn`2AR!jnQ*a!<5ev0s$*J6(~ajJHcv zDl!ajGU<8HLdd?nl(`H&7};9NMmz!5IU0?fR+Yt_Gv;qdh3YGt*8G=mov}UdW0rn) zS^f!+93j4%k2taJJ{$Pj2nu$tf3bwcrnsFBL~Xi|W_&jYn($ty+20%BD!FwJ^7!{7 zruV}5E*r(hvucgoDm!r#P3Z3xOWJh7fvHxInwHdSL25;*@voJM>L<4k?8k)s?5ku3dh6Q`y6&1I2X#dGHcyGWz1;ZhABLna{ zifo!CbOq1kF_VHfq{6g}>UEHlu+~>R(!*EACslq!WstJk8ujhqbRR5V+UsbE;Icct zuQtnBBz>n6FilN(uimAv^h`CvGl`0zOm}Z zE(-*l6E2NW3C9<&kR@bci%&`mPW7d;ZM;}Yqti1To}V4l_8!ks`dW^#^Jje{RA zHp%pJV7VUBGUGGUwue2Q#C%Dj5=y98Ve-|SQ>S~i`^+F~b?MsV+|x*&D&wv>kvX2? zbfklDI0F9lkhlnR)1I={)uP@8%tY=ybl5;U&m!kaUVr!S@F-1g91$4DWP{txsNs?E zjSzQ&yQWj-2ED?FsFvn#^XeXPVOVF+{%fblRzMJHp9YK&QNEidZ>l-&Ioz1#T)p9* zV6|N7vFl3bEXPq5x0G{rv{`#>y5WXkF&|uu4#9FxYU(05Ta6KD-TL5AD=y7>HP8+n zZC(L8YKn*0zr7E<>S{N*Q9GAq*;em$VH_asCi*-f)eBcW`M&9bTg`dQxltF=WZe0Z z8)Dj)B*nJ-cD9(wWP}cPmxQ-qDjC6w_r>{>{+8N~X6Yp3!fxOe#*NZmMR)GG{i|D>?1nYc^;7{R!3`qX-)vb*8vd0E@ zB|>zlz#HU1D8L7jAslmO2<>c(=fGxMk`_r9j@f8CZ`r)*D}ZBf(DuQw1LdZXrt2U{ z>H}N&kn-fe{@JD=y*E&0FFtfliIRSIxY=N>#|wFipvG82kzNG!riZq1)!{CEATlxW zIc6>^M{JuZ&@Y{8>>LhRTB(fIn(TtWh#-X0!YrITHtT;d=hT`SS}JQ@7?r1(nPEG{Q-;#%HZKEpf?cw8Sy?!-Bb=)| zj7VdLJuTW5H|A6isgyf)0^#mOyi!1A=Q^wY<40a>2|BBNxk?^(dy zjmvnsgY|>^A1Z*D-#%0`6pdTfe)ohx?C{P)^cxigvPdY$VvD*z$F07NIF*9T)iv-e z6XO?)-SE(H*s=9oXIhbdD-XRzzR?tN6 zAfxKp`!TwJ1DVPOQCCE$pT8RV=oB!hEcU}NPmnqC1J7x*!Fc0G0uV5cmuMuvgpRxU z?GvDpGrjGQot@Es^+uvnZRVMXTV2deo>b?HbisJC@yD8kUj8ZMtoI5k`D0W5?fR(wTTyV`wiRiUN%cCsp^bH3l!{to5Kn1 zkEeyesHWxbcVxemFcAXd7T;E0D8*U%9gn&RE;7z?kza$nmU8@_0KTpSeT+r)u<@l_ zNZ^Gsn$1Vuy&VVA4#~WS>}tEVxz7b`4kYANi`zJV$9Vmwq@%6!wQRLhIzm&{-4upa zXvo>FCmJ6Nd)`y}(r)`c0gm+dT6*1{6Nz?*nzBCivN)Ejn|cmQY+7Wf&H7$g^^7L; z^oxLbzF(bH)}^uQ=adz9R$vq{+3_JElK^oA`M;WK5|f*CYv=1b zLY5C{roqh%>+Vt$LJF*mc$!1-rlFeJVQtv$)xRXw`{s$K@%YJUX`cW+7UCig-RtTv zmzr8l-TEnlQ{7zan+mUYmGgd3>r_(LjrshjxMkSgh@xOJ40+Bjt$#B^Z=L_=c6=1D zZ4X_M8&kvY?B?vRWoT5UQ~e>*ldpjzK9-e9V})5D|E!(3?oHGVia`PyiT<@we-PCXC# zJ^k-<=j^{)JV&2vA3X;SL75d{-Zm!SXtr7Hk2^qqCrPMkx5eqEnIdj2W$hYUriRb!CXQPMh=Fz(} z4DGG$9nhuEHVt;o9fp&mCzmwr?3tO=9g~x^D7SmTWJvzWU=}3ZAMc~YI~5OvZC^jd zKi*rg;4-S3(D|rlG8<%Bs2BSSH7u)Faw38Yq&BSOzDEYd!>7gWe zk-EuW6y|fpHKPgCZ*C;(9B(jP8WA7VYfb4E`=(m5Y;5M3wz%MMM_t`NPsCAgJb~LZ z+AA2tG+o6@%SDn}y`T^_#T;i7;M1aPt}y$zK;xV{`d?0^li<97v8vGiv0t#^o3Zjk zq4V`hH;2+NgD9nm>oI+oWOlh+spvrNVnm2EcuT21sqgR5=%nlC4tBfNmV7ZMLn~<9 z0b_eAjk>690X8wI9-dZla0u!ll%tdfF->!D5gu+IUMVw~FM_3b8b z0T)+0MS*@xt59o=3rC?APdt`!(C1TO^cqWhR-{E+dk7&2jyus$VJ(`t?JtT(h?PB1 zbzJ=+4wV~nwy>0WQa zR!n4tBW_NprS2~YB`So=1vQQOwVkwkL+s6?uOGz_4pu@LFALnN{S;cR%}Rn?txRPG zm%KnO;+~)ot1KXAl*~XV%P2mF+of>$sj0X974}#x^US~I)`rK4HgjfU^3LXn0U!%5 zM@oKC<69i1l<1=%Y?XomGqKrh$}vvSHMN;5R?eo(?KJg7i+kN|j|dP>y*SwLAzO8V z6l{?aMP^C+%lV%VqWru=J&Ib0)7aoVvMVR#bB>U@Hwa|q#yR$8P-<&uJM#qiMdJk~ znvTuMg3nEoo$BOqmG%v!v;e0AW zO1}}>3{@=9nby;^^Q3_NCd+!+56>x~uUldO%jBD>`aO1+o0=X~JMA7q7(}+^!b|xfc(9uNRGyw37Oe zXGQ$!u`!0*qVVA2*;Jnw=J0n)94f3k_Y;S0`7SYgmU39Se1Dn&)eg)5AnzP&g{J(G z{qFn+gCQAvT0HdfYRD<#CxB%7zQW~b7uju~ba`NBQ2xS&AB*>|gIi6tKLpzScV%}I z6ApH{&F^V!R;6gV%YG8~-cXfTh0kad#2SSMj^^udsQJM%zVISwe?z}xcZgs9mSAT$ zg}m!xGJX5|fWu`m>79r}^o1F;+*QH#x#&ji>%oamGfN++pjW?a<1r-NG$JskJG$lH zj~8y_z4x<`F5xt}^D4f4sh6Zk;zY*yYlF@iaxR=~>Xj0ts$1i44({saAxa*48y~Fg z^MO%gC55C_D|3FoSGz|a5m!Y{KDD@2K0@>6i>G#2_rR76(WErXFC=5Zhhhi2Nqj!d zGoC&$Htv>?9!KawQ^>;oKuVR?t=EAH7+C`GquE6N%Kk#Aq2@y_4}p+rk;=(gNNfw- zEuxDb0h!mCqko*5t~u^hwkTDG{G8baPa7w(JH2(chA1a01?Z*i;AQI3lX-{n!#EvR zp3OMWd-{>MBlWR2nB@r&5r2IFIoTX^0{w#l5L&wqzQKQi;qNqOijOOlph%I?4whc^jKEE8VCk1Q(( ztHkwZ|KOb-hXe>&HU4*Del5^Lhj!B%lRDG7EEP;+MCahT74hbsPp&w|Sb-ZjcoA&cvY5gStVTB_?xhc&1=gxEkj?nNOb)#>J^W3`L9i}2aH^<{^fi>XKwQ^ zu#VD;3sMt;XwcJH6d4%7$KBjth#6Tq=%bF&H>UN`u$;MU!%M34%?Z7+ZR&qU>DY#$ z8*t|P=+PrvH4OW8-hS3e#FV9&BmB7}7*$Rqc3$w_pi_4NhE8>e+K_KyxNu8e$9V;g6f1h##rVa9x?lR@WXH^^4BsIex6p>!W3cseClGb#TG) zlKpZ4U&-;?d0SuEzfYa5BTo`(w*Nr<0Fgv{U(jG1TejjjPsknj*|5gQMf7lUg_C~P z8vLjezK5oWy{eJ|i~OvcjH2*qjSG?7OGA#}zgP2H`gNThZU9#-bAXp-l{uV!=(fu` z<*Ih8{rXqx00*)){s~}j6UtL85>ni42}`E1i;=RGdU>c%dacoCWk9*$5w+IFnp#uF1N}PBXg#zw*+a@6 zyu$KUnowNtKEaaoGrXq$&nN_ZlU(rl3Q1+2?7x=8SYY0Ewi>=}D#^ZLtT}zmb^01I z_s0m^;El{=d~Azsh-Zz;pilzuN4{^gM!|g{XyUZ%1!1D|@l8w0MVs7YCM#n>RVMI4 zbJHH|cTnE`U-6#<6h38H4xX;EdXA&K30eP2aTy5Y$hT*34sRpznFHB-V5wUCB@RM+9dx)mQ%}K!L04r85@Bp`%lV)~} ze1;P36r~|o$8@qpXs1XL6(O}br(lG|;hGORA)1KjV@yybyzSq>I|0jU4s$lW?HSR# zc>@DuJ()708vs6z(?uyUH;uuoS$oB<{`X; z^NF1I-}5vEg`<7+v0)H`^cLK`PoCdnsP1<2HiAfS)#Wb#4O=(Qx_bZE^=UGsu-^(Z zXJryouzbUX%?=R%L;7xfuF;=T`zKU_zpp(c`rkd7o4Ca0fMKTuZiFKB*~PIX!D15+ zFW+&0zUSzhjN2V<@=u1v=CcJ07uOQ9>!Zsu)l_Z)3}o5*>Brj!oAR-FUpO?sCa%2Z zKQ)gw(rbH#q(7f4_qIYG*q)m^$+(_mQ=)H;R)y%UN(Q2XbX>I5UbZHLgGstP4!jZCrH?`}qcalYRTym43+cem*=k9f;LJTXpX3za5fXC>r!aEZx5C!3$8OZPdFkZM$jwF( zy=fO4UGc5Y8SDtVF(l}^I`LgtNygQ)$)C;HoGRMn?Zdrr5bAR{jXmW}S{(Df*-OyN ztHET<7rlO07?VGc%s=WJaP_h*5dmNFxAsM{6?cWZiF-YUz=+@NfL4C<$L$-Zx!@&- ztk*y?GP+QvUDiYXNysT!w6A$zn=>K*Ku)SWZM8W{?DXjv_pBb)v-Iv9Z ztM)2B6Ryl|OWlWY$d-1*sV(`#RO2mZH9Bxu8I%O`xK2jCi9D`iMyPX+^NTu}2~Hs; ze-ji(F}HR#r(##+Em)w{^OcRC!91Fx6&8b^sd-ravSLx7;B-kCy~lN#(1n7PoV7fd zqyGZsqLKr&+}(}_3e2mP;_8~^mYSq)f?u06+#rdPf9b7ZJYxAxn?+H^j{p&-oA108 z1dgRXwHg}=emk*AXuc&cY|-J1SrAhIuuX>F37(7fH#j`9&KpT+k7QFGm9&ICV)MIF z16}G}36;g%hA4HFz-_v(e2^XYj7SQnx@lR?0z`Mi9%duOhG6-!w z3fFaDi)dKst94%(Gp&HRG1BHwE!2sdR6*e0@CsLE>WVWP#IXrVPG8~rl1&Vb}J69VGQd7+7Kn`6XZEg9=;gt1g#C-N13Tvz!{nlY>)(>sm zZiR@r%940$rPJOg0Kt_xaXfn@ABHmZ%hj{>w2SSP{==;uMa$%Hup?eL(W)Qi;DtRY z5w`hX33-JUzJ*&*Q%A>)71-d9i2ZMxyXQbD<|Qh1ZR8&0IK+G|>U@ed-{YKLZMsIB z*X@U&tm5=ld~sdU?$jzw<8cgP&LHwBlIA0pPXp(-^z~d7Gtw_~`h;)DrSQ7wj*R5P z9w!Bt{qsvo2W$)f43s$X?VktL>e+%y2>afsoPwN_7?LnFt`DuFUVn^SSTD%~%6o1V5w(7U0ILG+M3jD!p zN3!(s*_+4qeWTCILG$7Db&6c-@(N?7nK17dTRM6~xD-QT$l}^oX9&&}H;H{0E94!L zee)=@Hch{NU~0b|ni;ksxhUXeLL9plB?dBccrieoSE?Lq=NH_1w$?9D@V@mK_b}vf zaCJ!;9FX#AMfuGiRE7qecMu7^WqOs3c)U)`T39k6qDNm7Ja5v{G%JYrS_M8)1^>BI zvXQMYC8vAuG~~)OgUSH}+3dmVyTAV4uB{A)nPVeW*SQngX~TDHJ6fi=zMX$;(gyM~ zbu9gyt!sE){1uv^F(Y6c@j;{hA?Z-=Ee)}Y6^l0^3$d%!LZ|W>fz?B6Q;=0JYtY7R zNt7ExiRrUk^Gwz;ml>}+X#uZITl1&m5nfRe_^9=1m$;vh6pr8YulDqvyxo%0g(J!< zda7K^;+D2B(u?sdH?B=5-!*9S59d$r5*UNQ1g&$`TTw&(q!6KQ>o^o2Qlbxubb9H` z3B}#^Sq5w=_p%IiGDEpXv=P~~@klWcdheQKG8MTj zm>wEI(F8@*svIEm@kza?>S7vQ4MTL`w%1O?iOTw(tT|OGNM3PSZ_H zauXxy31IAChSPr5Js3i0ub&Z*oBq2VcktUK#ywYci6c>t@ejs_EiCfR_%Z^%lbEW+ zzVE4z7Oe=o>Sr?%8ufd~`!rSnMmXSe<8xN22r4&TT$H1M3yd{XWe3*`vMB0E<}u=K zFebHaFRk6xPd?X0(Z-&K^r&hpQUp$~nJo2G+_MPlB}49RLAy*VA0G>gdhFJn=8tVq#Oct5VCO^dhy_hl#ZG^S+<*MgbrJo$NH*NHX8~Y zslr<8Lsm`J15lE}6TtppwFFwp>C__ivH71+Vu zOskc9^`%`i`*$M70V$_rA;Z&PntCYsaHP}t)6EAF0;7_~oO*#F^r8p>x~9o=zeT!~ z41;UhP^tCmM{0e9u|>I}o4@TOl+tOEWF~BHoXBWN)z#8c!-TL`PVNQ68^Uq=Q2G7t zV`d6C9=0h+TLB*2f)+EvR?){13nb{MOsS3f`X%XuP2d(iX2<*`>=B#T&1>@~*Ib|S z+I?jryGqjPyQ1t(5!c#u`k1C1Q!{R+w9lcYbN1qc`XGD@x7C5$yfhO(@Twu)x4t|kDCH1xddUGEKY#y_lpWVdOgVtxmHJbCb+M7D->$4XW zUqo8pt8RP9l`&cuDr6bF6d97xrEHECck#-8C>7T)7aqVPcgl|7IS___&6wvgGCtJg z1{?R@eGumOy+9~`zeV=70>fa|hjpV35Bo@TFIFIgLYZkNH-pFzrV{u2Fd#jQhy#CZ zMS;n3NegGlbs%@WLV;)D#dGPDeGs}g_baFD#M!BdYGZ`l;nA@hn`usMl65;aMIW`GFkh)rQeMpkQ#x75tFaTC9p!$cdv3V@6)|l8$@ByS-r=1_}XQ> z8XK>X!y=qC7N2)>=veh!$^gN)N;<8t_6R%w`8yrtI^hLusdtoBZ~L>IzyMXs>&&gIexM?2RIfg!?&z-9&%i_tUVp>NGc}A|c5tzKwQ?*00_>HG0!vE^EyK zdf;faG372dG!y6Yn}M`=#{yCfOYlRAlI^;hoU-4i@ma2>F2`1*VY=~3LX9|>pE?77 zM8+IipK})Us-rth9+6gi;qaE!ckvFmQ8m)Uk+P*!1}>p2QLwpibV!SIk)8$RL7+jK z_wj}l7nIy}3-q3wbG-|&?aZDHGn>?HAV6lczMysByxYN?c0U+^00jxhbp5#tP-krT zkmcY~m`HfYvqEeplttdkPSw!LUe|ub?QOC*r!KqGy=1j@m3I?U;~n%;;h(l6ydYZq zJ$R|kMUA-`%4;hV`rUrLyCqwgts3oA)D1AUhq(DwkZ49Y!-z=T4ARz91lA`>xZ4vk z!q?cP%N;W=1CvSajpGe%W=e3WHVw9td_4bT>pJ6Tk0p^HT|iZQ@C#4V}pT&VVp5I#T9MKEnK+w}7dd{x>mgsRe` zRO}AWag6dDmSh#sHa-E) zRFdmW9{W}?y6Ri*$>tc^W|+UEn30?9z8_s(;)|gFj^atuuqa)O`zx*X)SF zA#n`_ytCwc*Sr2I0pysfrC<6}-Oa%`!0YB`Iaq}%go`zq#?tL1EqCiukhqVJ#Ofya zC$5OnAgVZdAj|nlo7;|n<=o)F?2;=J9@83Km>Nn#a^`DUZ>mYPGo3+o5+_EdV`DKM zvS%7YXquMf5!&ic6T}<)ben55ik4^vrx8QgFYpONuF&x!#*+bHK~okmRMqvZ6nH) z%_NsqgvtXc+ajkt_JZ+*x^IO%8{wO)ZBoX0()-2<3thz=@wLsh>smoFjH&Y3H#sM03<%uHYvtQ7 zjQWRyyi=-LTUds;x(JY@=r>;9c{*el!AX;QcmzHv0PBEA;Ul^l?bGwg%??=ai!%8*U8Ecw>yR z$QhPQDkEd${{VxUwu|C_+C#whHyRzrm*Gzo_=4Fry*EaJICRVCWrpD`Ls_JS& zqOys%5T-e2Tq<86Kd9PPmErG)I&3y}kZAfHt<=CpZyLvJmojZs!l1_XDu|!}P7y~y z03!}(%)+F(S0XnixQ*>u_xW)WrHWy8jE)CxdRM`J@KOH&jsF1fk@#E0nhm|we{R+^ z+rJOnMkX%}rr^TrHEo)yyXygQ^_Y5G0ot+{Px zU0U3~_A@*wJaD{7NO>9=71;q|7HCidG|?mcW&1LCdi(a1{h+)}u3cYWYI15;O=F{_ z=^Ol+Z}&|z3@l(xaRSKVH8IB=5yXr{4TV}-`UM#2bLmYl;a!)CygeU-?CtH>PSNzc zIbO;nAHJGP*rB!OuglaguAC6xcJ_@hFYuI&*X19aGy4u{&r-@T^_EXxx#^N%? z!m$0?NE;iZX7(P zD^l>k#D5QHv%{&|>RN4tYa0N`G?95xBr`x&nmIQtPO&nbt+h{QCS$6*Hs(ujBMad@ z&)7AS8r&;F#U?Jw6(kY7tGI?a&LkLP$`Ei#I6uCh@J@MSxc!PgAIh=$es-K>lPR-% zG_H4n{ni*GIUz_HIIrjb0Knh5)NP|pvNFOgztNlbW>}fO)$}_E$Xs9?X9JQypwHM2 z!}}-pCiq<*{{T2m;k!G$G$0v@p$XCqs6He3FiN?hSTi_HP zyS`FLJ(+!RMIH-z6UQ3j6%88c+Ixvha>UJk3gaDusn31>+?9+{wxQJH{AT$0oi@_J zzp|{*@%`EF4;+8CKPr_k%1AaZOq>JBA5LpK#mvv~dgp4YmDBB_&Usix_9&ZuQg!*07bb(W=w$KNa%KxgU8f< zc7MS{w6mu8^Wn{`Hy2(V)h5=-vxd=D?ipiyJJtUHT70&e(%v~*Reazr-ODnp##E2= zds>2fs}_#XqTDddCBbO?#rc#tP)^1pY1#{6jyb^}oL}%$e~9*P`$qgC`&Fhikag`| z@uE{5cZ9Gw^X%_bP zGovP+bi!tnDV9 zrdZ9r?VXjnEQLZyiw)z)lYz1`o&yny@jLd5n)~A1(fD@9`(oC{OK%$6M+9PD?SZRV zzM*Y+L`cD1Z)Hhjj%}WENeMfN1Z^no{eN2^UlTna_M!NvrTC9f(Pp@s;?GZlJxf$u zhKHrr_{{Ut$fST9DpN}@$HHNF zG{)36ji-H=2(m1)$t~8%(%d@lHuDjHvl_kQZvps&;^%=aw0k{U!oCr=w|!5@)9x-L zH#&X1x-Hg_ikABs)uD#sDS{W=ix4iaARv3k?0*KY;7=G`=&P++_*+oau5WeUAKu$T zx#7JwPQJ!G8*r8+Gu}2b(!;0oS4dS1IJ_GS2-YD^Oeh;?1 zT}|beR@S^ZBQEgKvjw<#q*!iKEr@)?m^^t@5w?NNDzGT+}vDF@*K)ML&aV^_=|3IeLq~fdAu*;ts2kD*Dr5o zd%F+xTj~72HI`94Z#BCz#c+~EBs0Uce7OvfMwhrKyIA^L;;-#f`!wp2YF7UM4g6E# zo6i|)c0$TKEelba!a3oF<|#^vVF&DY5j1{$S1N#`z>fRkU8l}`ZTnk%aqz$F%VXlL zYgq8@``_@FH&FeO^%eBni}ol^gHWZDEcpk#UzRG|O{qG~zdZ-brt^StN~&@Aqbd;N$V*#J0Llzo=en z)@^I3Jd+}(bqRd9Y+eyA*wMU`e|Qn&iXh*;w@)SHD%Dm^J1(J4nzh;I{uKCeuXyXk z`h|~(*ym_VEW+N*>K9gw?H}&c2NG>mj^TiLUS|m!c{y*tKW9&bns0#``sr31 zNLTH#=r=|Y945>yw3sC45^qbjkbnT%?a{YM=AR0`X6S^zCjQL4w_Rgg(iSQALbg)d zT!^EF(8h}2=*6Q9p*67G%R18m)jdgINf#@Vq zanB>T@{?_wN;y3g{`AmqgfKtXoak?KGJZ)IbvyW{1%9Vg<@N!@MnuO zt#89}T}LJ$wU2eoGGV5%To)0sT<%5rwmRh)E)=#hX?WYlx?hE_^c^|ly1bLb8n5kp@0;)5W{yQbbY0+kRpX=6A*)0L|h{0pQEaIM!V~R@1`zh2-PT`%TISqcSl& zRx60?meNtaLlm>yJdHHYg}+PuSK&QV;)lik2Tr>d5=(0x?XIORS-!<~whSPgD9Ws; z(_Jp$eq^^$Fvv*w@BRuCrbjQ2ykU7~1XXe7|5_O5ShS=i%zwN3zb}NrpRg&OE5cMfNH|Twr5vFa>@bABa~zGWeCOYm<5Q zw=+C5S$^+-*!M^C6x=Y3gJT3Fg=T(saNoDT@J_8EY(Hdg2gogBwOOp%MYoCfMQ&Po zq%ok{r8XBSwUjcluqD3xD$;JM^eIIuN3=@q6ekK>CuzqWNC%In@}?88s)h%GN%!OY zdsKw9Tff6O{{T3u-e2#=0K)X^_?N8qhZW%!4!S>c0Q>2?=6=WZpQ^*u41z<30nM*@nJz~uZ-1ZU!{!+Bsc zX%={JIXRy2b?L_6^O~!0%W@<}W%XKJ^|!`e*!}G!3%M%2@8%(Bm2Z06f-FMJZ@x=vNY1-HSJ7k#?CwNO=PwZT?4( zUe(t|b{;wboCDUfY_#a2nVczEAFkIt{YUxl_-I+jwlxQF?azPn(z z&eqaJ_7~H1`(Kfm+FGpeBuQ@$_7SW`<&|UtMTS+J%kvNHr=r-{e%jh(nxw-`@Xf+% zy8fK8k#};|cM;2FvUS3^iEM859mUGa99Hj=fns<1%R%th{3hDc=(?ttadm4Yw3hm% zl=muhzqQ=kM|bC3U1m0d%-CARGE7nM@^JtCYXvtJ>Eu&yrO1t12g*JaJv2SS}?kBHKkAtRr_QOTx0l zy6V^BHKVl9CDKUOQEK)#a$eej4iT0JA-9QURw@x@P*eo3b@+SxPF`q#w-1T_CTd)W~Ih)Fe-%ucHf+~nOU6PAKkyp<%D4>2P;II}p* zx%s>C2wP3@3gXwn?F^bc8nynRXF3*+;&r&axQ5y@a*8E8tVO8{yKX0ylz#EUdxygB z+JnXVwy|fS+()D8b7~g$*EzUW65kY^q-o)b(agmnH#3A=R$b8ChCWPgOp|yY_8GbO zTRPrrnr!#pBh$=lCXH%kj`AD%W`^DyRymGg`{zIrAag92UPvlZ37hmY;qSxm1b@Lf z{{Up&b5i(;Wu)7^_zuErt5v>R*;%dh+X!x@AG>RNK&mb`M5^&$U9yoQ$TV_{mn!Hh z+f$$TANy!{mru|d$KsM+9TL?ewXpGgsS-zNb}g;d8t(cWA%c5}WGd6!Lh*@_1I;fZ z+&&(D+Mf%4EqLcz(0niBF9*%2LuDQLSqn)G#k4HWWQy%)3nIn?l1i@K>$@dnQ}Vxq z{yna#;av*XTJWEZJV_>skyvS;5WFd-U0h9Y01LU}Xguc#!ISNOxR~#AA{P(l_SU}; zJO}ZY;d$1)L2-5B4RiZ4?)&}`?Ki^e2&M9dr|Q>?6;@*rJMUt zP*k|_gm)KEo!X_%^!F1^t_l5LXKZAX%ihuuo3|(|NG!wT8;%dk>`8YfzohTk ziW`Ofp8O8iO}gSu0?DMvzy+aJJAc{vzFva_U=D)nMX9e(@)LK4hh1{JLar6~9<^3M zgS3FW=O(nS*K}*od}lwHtodTZ>__mZDsniiB&;Q{?+Du?c=q~lEquYJ#UGi@=3h2a z0X;AXBzGH|rg125iRlnivPnq*62 z)ImsH&9*gcA7P%Nz8C$id>f?rgX4ais_Xg+>ROfeh4m|0ts=TH$nuN31bFR+@Cn~@ zJTo`SVo2UcD)(Sl+=5N9$T{ToJuBtU+KwF#%j2bnz5f6S1)YzHJOiWL-~FEI(&3`8 zx704<`%V0QTyAiheWs!$bNjIKmLxMOrj?^~6lG?9b9_42E`AOCVfcTjYcgBwQfeL} zn)6pxfAo9(KUIfPS1K58SlZlAD6}9tBz)nW0sHSy_-kt}mEsL9);n9z_(!$EIHvyqt-sA0RNlw3-`{CJ3Vs53x8fWg zAMoC(s%awev@3rUUC(x!c!D%}<)6-HmU&hft~Dl+N7}BVG7ZK_3s=tn0JA^FKNoo7 zPls^po-nu8w0{)cSn5!VZ92x<(*AbwU51Nxmk~*B>5*%D9I553Z{{>>zC|FlFWbB0 zqkK92hxH#0*ruUxr|aMFl4^Q;OgBiDx=)8~F72Xs>-){oMowFXWh9=SCf+Ll0ES=j ze~0`#B!^0~_*>$b^vy%<>>f*$k6-ZKp#;Pz*fR!>_c;tz#nJ^_0g~@RV;v8yz7BrT zx<8Gy3v2V?e+B7&EAc0X^myzweJaW*n@hdaEc_npaT+mw@wMNJnSjVpmSR7f&XBa; z9QeQB{{W9)7-iI-!JZet(L7NekoRe;*uiyhQ&QKjboK?4Y>a-v9kfArY&dp4Vab1p zehz$ce-G-b95>H>rfE73nX}7ub&xQ4fyi+H`R{qeF#1_6M@D7vWfpx0+ zmgZeX(&1y&URdIjXOiM9nA$c2?Kc21Q5r?b+-ja4lj7CpgC?P^Sol@7Ej{9UdF1;f z7FP1K|seokHBC&mnnaR7lg$KWSKXZ;W3Gb1nr#cZJuN@ zjj#K}2X-(@klf^4!4CrXTJgIawn4}li|%i#z=?*$!M)(MobO>R#LorSLYAx zf#Y`Z4d$ERyQI~B;UduD(zN|l+*$0omRYyQ48SU+lPvc3?=c==Xw`zbKm(7|4+vSc zhlMn|c@|}o3piqNG^R-d2*4rS{pQb8+~n4D>eQUywCFm~gkZS~Q;fSGyN<@H+()#G ze;01Q=dD=1kLAeYcXL%{xL8PG{(4nP*VM*42ZPyH#@e6H8Dz1y1D~0Hc|LMF5i-p4k5YKD8f;ShvL4w=)GAHJ0D*Zr(t&1eBzi%y-$51%y`0o8HKHY&Q+p*)`uO!Qj{m0Yan(3ypBO((Ug-9E^g;7qaY4-BU(Cd5kpq?#z& zAdMVsyOe-NMpTAVgT_GsgV>Dqu6XTo{EmTcepPl~zwHLxoAI1Lw0Q4a##^bj!UPS_#?&8qX|08h+4PFN!rU6>8twnqIkm4}_qPUR^fuNhI(= zq(ce4hDf7R8?i~Cw0DtHbc)gJ%q)ZVr&aKEp0nZDw4?T|KTKHG0S23ME}p9rs8Ja5 z83>Fo$`B9$IL0f0@&5qqEAaQ?mCTcCi{lMK;(KXqgN>#z$S^jSTp9Pl+01&8u16*x%dudhSyT%=S>*c)Q{Pgp80p*$}~E=u@iv zyyb;{q1}JMD}D=KL#W4fqF;|2l`xdtZOnz1?Ud4EK#cm4Y(qMzyM3B zf59xi3u{*Q*7q7$jU>0!t_|(vGJHbu42g9)p5oxj`>&2)&bqpncU|`qwsP!FoX)|` z8~*@b=6+@PAK}&i0K&hH9vjncgGsJV4Oitgg#-rM3OoYzxBb1T8|AH+eVmKXCPkhj}OERuZcN%>d;05xAn{{VtW z{1nmF-aB1C$55r5CfW_2Ki4Dj;x`L(<&Bp>?GvKN0pK!%Dn0Fn>O}mx*8UjWcq8Ga zrKZ~7UrT#u;Y~u`*5(UX);qh^@aKt5v|d(Oh&O%@vy~JA#zxzK$_d=~qwsd)!`~bJ zB_@xkU4LflJ~Y)d8(Vqpnj@sFUO4c?K^a$Jc-6c?6E56**6a@nn*Bl3zu=eu00?Ya zHd{X)D?uYnjy@uIg$JR!2c7 zUCrn3AS}|Z3lWDSk;N`qw{>FGNLtr_hxwn2f3$Cg^=pq5-a@*3_IfRxmvY%jb86)! zj`v*B^+@f&AQg@)PZz{Qgp$jVgwH;;;4g-LF#Vc572tht*TVXQarlEo(OXiwx|UrQ zJG%0}cX!%HiM5N!c7>l%@lK&M z%eENp#w7DY?hwNadx}j9wSLkgoM$tt8Cu z*eICX#{m_7=l~U5-@1|%CkX4U{zvC${1ZpuCyFnBY3ttxZua{7DQM{16@ z+pXdfE@O#V07Dy}m4cyl3oU+}vokpB^W~0DJX97M9hQk`w$?UwmKOPDSZ$$5V^(5W zh*barS&2MyNj>Ud%^3dj=L7!$uTdDas})L|+YFsnRBhb{p{%Q+zEjRf zTm0g)S|S`W?jJ!@F^=$blSANz8gZSHi`XrFKIH$^Gl4{Tz1i1*z^9vyWb5p8~Q zth!y;Sp9_7ptnzOImb?iouzWhR(1oB@ne?wwQ!+6RnV-m{CBw202%h{_|=$G&27p) z?_k|deLbq5j8a4Kftz>k{{XP;gA9y3zdU5{%*Zk^e`6|j6a`gkQXa}Kg z2Fy$H_QB$%j$w_xzdUs{TsnpdI}YEiIUSi9cm=;&tJqd2k;eVsr}_HSo7zAd_(1%H zS5U)@7CGnGasD*#n6LvZ06prItS(YnSm)*Yy-hkPHw9mqbmF6DXOpK+I`B<87}?tx zz|Vdt6J>^DB4-J2hk3h+LJ{{UK))-sEg5j8aySXO=n$J^CR7E-5m{sq~H)^f{c|VE1Zh0IW z5AmrlCHY6omSK_!>>(w{A0&TlwEKUFe?vGR2vy&*3!r_O+OH$0OjVkbc;u*hx zXOp)W>^TRF_s<10!%Oy$h>SQJ*3Hxb~vC?c7$m%J{Mcm*WhL zy>glbkY#1ysk+oL{J+j~T2LrR*#;F>V`Qoey!UD%WgV+4==}tESXyiqde-nCre_pwv}#sk7||~)uk*?I0mUfm3_sR`@Z0M zRj6Jx9rN4iO-rEcOtN9e&A0G0=;r+rjQjIc1MC}A<(vF!cqH48?zeC;j(Dq#gO{;8LId65!!`La`oC;a+;pM_;2gz!VizY!#ps@&r0pmQEk92 zh0h|L0(pCT=RE-DHCbdgL@J=b#|OPxf%{>;o}AT2Rsr}e-;Td~txF7|B?r*;^yyKT z+SvJV+b6gA{&d+7nI`S6o`Cw*a{+Z`P^XQ)l=Xf-U%lHuN~3QPL|%C9&NI@b5-d;O z+nUizk97$lKSgBb~#J{qfi9QQZ8=<+cc8oF0a;l(j7&FqXD&m-mf=z#qHD zKlsX=vO5w{gm8)03`&=$Zf0Af{-38lP05+skz^5(bGBzTdKeDTXP zJNBG^0fI2Yx0sxB!BY)^?s*5sFNiSsZ}y67X{0#+0EAk`eVI7I-D-^4IKtsmKRXPA zg$L$fw%$Fr)F#^w*x8sO+G?HMr@|12QhTM1@lgRFWQ_p_I#ySwo<83D1 zO|oC<`G~LXjl!&o#GHJ&?s|SbYi=(d=+}VxI;Dd(ylNc%p-4s}cG7df0OuZrR$6#B zPnLbD9FT2^BUge=vBsgez|Y~+H~?{)t3HXT_;56uY}Pu2(&KraF$>De!sUQ${JTKT za#Z#pEps)vTIR6$i=d|4hSs#NBW_kfa+_UvISrnD@Nx9*avvRNvZDV0r^3poKO!;D zIN;+Qk55BcY2(JQx!tLy&4VMTTZy9{Tek!dK_x~=KQ=btl5hzdR$9a|+s7!lwVFGC z959h~xFBtAI)lI-xgDsCh09`V4;^UoTmJ7+hG5@t42I7<73w+@$g284Jo;A~aua@v}{9qlUoM4`Vk~5y1_U5l! z>Y%G*KQfmptU2k^*SAibb5hT!2OC+L268|go}E2t&@5126j|QcsgB|~BVnERZKNK4 zQU}ePbQv`C@kOJ_wPq~kg03QV3{OsW_sBluK9sAfNfqMT{g&oPe4{YB1}`NJV2aQqs&pB22W3K^XXGU+aHJI@<=@@*1BAi zCH!q-5nUHiWpGpxV?SQJotPNG>UbicZ-`b;DM1#jSyv<`op5&YK<8#TBOZkG&l%e{ zQX+iYTbz)&<0sRvKHSu+xNScyXTDD&w2q=OWASaceV=F4T1YazySkGc*zdvz4{ z@r*GIqTf<13CI?5^TvAMXF12NcWh}ZEoUl*`ekr2!+G>H4C&jI#Lvu{a);YK3~L4 zZNqE}M3t11Syr$eoA# zEV3>}d6;L6W0$eXjdbGA`8X|Ckj`A z*N{i8C){e8I(~oFtLp2xl~;4kv>%0<^|iW3cW>d}F+m=6z3l6|k&KTgImSkEaB_Lh z1+nn=#F5;2k*<=;;|#BVAzm0@f2!z4@1e)u=N^F{@YI99`ShdMrJ@N{cQ|c3!u}>l zXzq0V9zYOgXz$xHI0N`hTyvk|?gcjA!rmdDOxre@XqO5@3}5b6An}o~d;b9Suq(e_ zzks1X)SsvIs&u@;RA-+J55!oJf3fYt@-VlV-PrTAdJdWXS)L6$#8KJhbSOm4)lQ+n z`GCgnD3Su__YY%G@0SKZgB!sGDv_chep{{Y9jf4V=NIe*rl z_eb)fMcZ)ZIk@~UstF4<-Ir0db~Y4xe?QKZJTa=fFp}#}^72aUg|~D#&rJ39>0J`_ z{{TMJ#Qra>S1*{kj#A^nn&r%_My;pDxCIyNuHcY!&RlSD_*9GF?OQtzWO|y=+@HdF0xtr}8rjDu?_b(N!jM&_A8B34h z!Rl~$#~Cs~;9Xq?*>$ToUoQu3fkrLQ&7Uy(j>o6Hb?Nv203Y$FNBYL2m!NYta?$uN zSXonC1he7zDG~Wx`ycAN+sx@TCZlATK!_Y6oepVoPzXB*LU70&2W4CbLZltCyTnKd z19Bil(jncWpyUn>Bj-Q}qpVdn#}$U7H3%W8>ye~H4us$kp}f(RKVDoHLU4#s-ss97 zFRlwAI7E;PA#xyOltZ$sn~a7W1R)jr#NKls7jS2a)^sV$Qv>wK?q4* zkPM;Z4k4qG`e|eoI0!;;$UrjKAA&=Ae{*-00XYzIghMWNY^V}hND73Ea>%BR3RXxy zzXu2z;SlmhnmZxH+98mUP7;JzIs`IOj)VIl?hb*BGBR+8l|x*}sIG$V145)7qC!Sd z$5s&;%=WrFq>Bud&KDvJLM)Sd78wWU4MtfblDg{u-VR=i3?UO75;!7lW&%Q-9ReA7 zWZ)2Ihm0XZt#F9BLu| zM=0)ngb*CkK!(w-Ukby&R2xQ<89J1u^qlDqwVhFKv2xM&Qd`4NO^$8NZ#m^%E RqoDu*002ovPDHLkV1nTs5oQ1Y literal 0 HcmV?d00001 diff --git a/data/mllib/images/partitioned/cls=multichannel/date=2018-01/BGRA_alpha_60.png b/data/mllib/images/partitioned/cls=multichannel/date=2018-01/BGRA_alpha_60.png new file mode 100644 index 0000000000000000000000000000000000000000..913637cd2828ab4e2ff4b2bbd92c4cf362f871c4 GIT binary patch literal 747 zcmV zL3V>M390}vf5ExS_g|iB()io}T zdzc-a9(QkC83JQkyn5Lt1Z5CrRi#xHpD{9oI-@&0jtqe@HUD3>kZs0McGO5TM~1+d z7FE_NXiaZ3z?maMU@%u%R=mvsm?J}AJhD?K_8)UCLtrp7B&y$7129;Iz!*D2yjN8< z0y9X4z_>+*xc3{0=Ex8j*D|Cx*=8j4K{5o!c7|A?%)le4CMiSsY+o_Vo>AV}VFh50 z41v-2iea`H1DYg5U<}czr}rCyox2Qyfqu6aXGTB<$q*RG3>nT0#|)AoFmf{_nrt%+ z=R=0T=#wE~ZFN$PgH=pIpR#r$}~v0vQ6s<%(gm8QC*8vEQiGG6cq@D~4&`dke3xtTJT?jHV1R zn*p1-WHaVkhQK(LAu?mT_I%GyhQKgo$ZgD^p$y@(n<2xRQG=Ep8^{nCn;A09ds8(A zT2-xU83JRGAy_kN4BT(jY8e8f?Uz2S`+1phqfY#&mLV|C{nDp(Kbg^7%Mci;cTmZU z&sv7SNV$V*STh2UAuvMkprV>#Myssnk&_`_<2W4$=?*U$0wXp)$=dj2RgM;~p7O*^V>AfDC~V@@+tF<5118 dqE*&-`~khx#S9AHa*F@}002ovPDHLkV1lO_Oc($F literal 0 HcmV?d00001 diff --git a/data/mllib/images/partitioned/cls=multichannel/date=2018-02/chr30.4.184.jpg b/data/mllib/images/partitioned/cls=multichannel/date=2018-02/chr30.4.184.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7068b97deb344269811fd6000dc254d5005c1444 GIT binary patch literal 59472 zcmeFZ2UrwK(=a;AlB0kinU$#IoP$Ufi4r6WOIpH`vw)(4B0;i}a}JU-7?2VL(Sv65i+aSt!UItzJQFhKBfPjBLuNU>0F3RvP&KD=i2f&Xu@h-CGEz?|3lK(Xi z-M+|UU7W!ak8_La^uJG z(~2}8762~#?}C=%JPTC>Pe1fI&jJv2@E8ICY`_h026zE3fCJzI4{ksRyk77 z@W8mpD9ZeaH-@E?y{p4tQ0Vqf2ulRg9O-HY*1?s%bIx7$0te>tOk7>;<;{`iP9`qq zW*`F4^mH))NsRr2_#AUyqG)Fc126w?>koBxvelHms|hys)41ox|0_n>+0y(k7<3a` z7fln(zu@rU=3qT@4;Mve1x@AKruO!>=L|7_q5ndRWo7T=C1s1S{ADSEAIdBILW4Ci z%*-uJTy0%I5X;Tn$>raWYyU$3l^n;^Qrh0u-sw-qCH%qvru^>$unM4Jk36?EmW#cE zs;i5$`JWvJ+ZJ@B|ANIewRZs>>c8M|z%I7>SMrNdb@{xZ3s;3)WSK%c9f(n76prW8*pktyzae2W)LMZB0B5pJ(brWJb$A>)Vm%~#_ zZqhT5XuLJ$<#Y1yk;Y(*xU)I61$X8rO)GuGBonD=Ci^bkMc_)W*0lMzvd8cHTxYiB zn4>bDlt(xB&+cSad>EM9)wXa8h{>vK8Jyn}l-IFz4~)&OY8_hGCj?MXKxt7gWWqp4 zKV+2s} zUAu}!J`V;`Q!R=#3y5dFcdwaH)O2$Zswgf##VZ{0Ol+LlaVt3kl$TT6Y!5?EMqhlm zrz!WQmS>TJy$AyX?)U8)`|H?C$|~yugP;AD_$}slAB>Dk`zVoTXV8$ZGrlc2RkHcA zVrX*)On>%*_-b>m8LU3neG>XkDY+u;-GuADZH^yL>uEq(_X!1xSrR4lH^mR?(=fGgeRS4{To4t+?y8byRkW#>??09-M2eq({xLP}*5Hvu&l!MqC^|mPg~_KT4KfS>r)6L@Llxx6CLUH>L?whl6p!XBBV`5xi*f!$nbqRu|>xF z`}+IX$kqs*Z%YQ;fWXASVdbN>0|Fc0=OLkq_eWAhsxzo0zFKGJ8!=X)> z>SSi+@fGE9O_SNUcU95KE<1B&)Stt&)nx5Kw_?|0I#n!^Yp3?=i< zfV0mlULDCtT9XJsh{$b`qId7<$U;Tm#@9p1RK2tgX}5|9y|>k68a};-oUf~PLe7B4 zhrT&F1`)9G%ADh+y)%GsZ#i#P4@$+hHxp9i{gB=5tFe9l)h_#ru+ygR{E74{ z$+s8>T4@K~w_GLhfagzWQ9iMEmJNMP?tpp*Jh?Y|sC4~gBc<{%GrDGAXY3y3^Ji84 zHhk!B1n;O!jHNbZ`5g;p6O51O#=-HQPYXQY)JF~+9j z^<+p7!uTm<&tR>~)AE>p;!yRwN$~R6?38j9@YVB=euxsE9^~IoJ_=qz zmNtE7R1?W^DF|}nPPWhrwQ*dxN&1u`(4;@oqN{y!nzcW1yfDvZyJBmzZo%bxa}zHf z_kjyaR`t<|c)YFlmPxS5TjHiAkA1{!J3R-Tz;W{3tdGy-vnZ=d>}jmcyYcWa6!uzY zpDZM6`j#ZvNVG;RwoVm2wGoZ2RW0*7sX7CY9EbTMdq@N+e z8lKxPt61=!R@~DZl~hn}N_RFILMU&ik6e+hqPJ|kKTRm294lEkS@QVswtz^l_i5;$ z_Qdka@^vm}DqWb6O+Wwi=&AMCLH3#%dPs{Q{~3@0j+36H#uAN*QO=qIPUpL=gj`rO zgJ;0JevE8^d9bt5K9}y=Ya&5ARFs1=KwiJ$4A>!E7{q?IT*mf*yd5gd;7(Bm8l&;q-(MF88A8M8Ahu6z?DaIXL;%7eW6tPLK$%m zyNbRiy_;i7KJGrLqjK?WJ0*g4r*>f_-CvGtqjh>a^WvRx+|}VDpFVLQ=Y<9jH!?^M zqmjiE#wRE6Ni&PxypF@@B2RV;f#nXpGvHQP)6$V&;>4lx;bTg#A@z65ZeGz3AJz{U zF)yG46N;;4V+*>kdEIj@mtM4o&Z`KYmeQu4+)d-qacjNUUqZ1Z{x3G&S!aClk zlX}RZh%v1n-;cBC5ai#jrP`;)KpYSBoM;{Pb1~DhZz+JUKB!p7pC&#YY9YaURPpJR zxXMYn;a4aO?Suyt2mY}tEfy*_%pX7O-0MsHIDcEcJ-7v~@}wB+BQZL&POs#GFYU+^ ze(j?U^p%9RfT~lxjEcdUx+cNXLIcY|->CT*ugtQR`S9X*f_4U4?hk)`IZ)@JW_hm2nz0B!|Xt9D`Ewrz*a^}6_vj4K? zHCW^2SG)&UqzWZfk+w zDsoC+m+gy@8)IjHzfssjkB&EX!xbUTWH~P()9&08qiW7nHSf?4tIIKvv?o}R`1UTz zJ%u*=FHi2Zm4+t6m)trdE6RneWFLQfdoZe7vu+)uZNs|6-kvpByBf4W>2av~aOc%Q za?`q21Mk3VS~(?IVO;N5^(Z*DD3ms!-#CkhX)Ik92@uJ@OKCC9MhCr3U=x>`^VFtb zMs(^l_Y9~p7c6-Qqv1LmF`RBV0+EHKXbzt<4P9avj^Or1YcuMuAKF*NK|O6 z7$R->Ng^BNVTt)qI}b~IE%o?K^|h4So%L0UUVIAqy3C*2Y2Q$#yAvyY%5ylcqtxD< zO_S9y^t!I!k)g__@(d8%k(AsX@C~-JS#4YisqA~Y(yD|EH=zPOI?RBl5LI+A9N zzmjw}ez-B)-r7p!5oZ6P+yr}yzL#pq<lcu?ozqXO?I<<)T$ zP^(@6uiL|4i;p+D?{p`$5}7LUx%SCj`_@PQ6q3`^f|9LcV~KEpBDm2eS5nq)pU<7fEdA)@%C{*6_;4d&-74?8^=YUL4={avU3Rc+SHX zw>oAp|A?`;B3NR-q7HPNKK%);l7voPM@qEE=+Btt6gG~tuX&c@3)|H6!~%5iH`c3T zpA1G0oIOv^fN_lv85mdA@WxC&QQ%$vDv`36$V7lR84y3xyCItXntIS7gqqLp%NHkL z{44V_Nhe7qWLM_Xz5-F{Et>(J6)b1D$K=Ujt~S}??{lXyFeNg`PV=K)2CiH^`F;lU z!rQ;89+e&^pY`>h0TezQtLcGJZ6sP+pBl#A`^g>P(hl69`}o?c;6`ZyRNvq|ib8I- z{Yj19a$^Ng_P~vr)Ky1ss$M4P_;vfeSl;-cW8WtOMhnPMB_$(1v_c=d(NjDAiaTwJ zufn7(*YMVjsv#TdXFyV0BdZI3Pw|$nrpst(^-6@KHZkvBW%6@uIR5xF@jpXI(`-{UQ)-EkuM{cCcXCQRl1PC4@#AZvF^ zTzRSZ+q9kdI?%8HC+juZP66DVgJw!?x|R*~Mr{kn%M|iyYD8Hrf<;cKqYoz#KHfSv z$Da(!7SJSaIB^NSXcKd~vdS-A|BE=ilBDvQS zR;N`x+WMlHueecFGW9G6x?X>i-e@+Vdbm6y_04LBq|!&b9aZ?`hW`|_xAh+F&;$r1(bE z=J&TY*_Vy3D|dW+6+b-%wR(_nzxf_Ny3B4>r2Cy_JO-mye1S(I2Ww;LM{-5oK3(W^ zS>m(F=mFdvRpz$lCed76c;YF1exEUAHo~-}6$T#mN+nNVEN1z%;mzutthxjqZ>)OW z+1p32O<-$%z);<*f09^f>%^(rbUo1qB{*2As*MuqJjOSVlWW}^wT67?az*9uK?`r_tg`Pt8R*)@59@PuSUvLIE^c{18{DezoJ`xGB4i0Z z=cOAG7#`tHpA`FUs2z%m8)sG-f=>y@pPKgRcUV~^_bhY6R(#6L&VUg=E{}XzcfhACUd{{4Ke<9(Ca>JeIf7?j*`)rVXKnmQ54=|!TB{iw znw9}osy#=PCvl@Y+PtS7_&a_`)dl&Hrp3b`_{~1aG~Ooxjx%e*apEtTVwGFvSs;hrGCVo$?Qy^@Bm(b4c;!Tov2^J zA*ok^>A;eY4o<<0ZglXdxcGWuW6H<^e~NE(;h5p#mt}8BAIH@*ApQ&}Z&CtZBdZQW zcJ_P%EpsK~_x$?(X2exLA0~Hkls4VzE0RRmJ}k2z-fNv-+wqB)oIZI5d-%0sNwcRS z#EPmFTPh`1IP=pORKWiRVbQ$r${7%oQ|Ffat?0|3K=T>E+P8nLZdE@l|40MgG(smS zB2Ms0dHeWIw#Ed%SkQ_5m4Lc&*!yyJniTN)s>?lhO`d@H!F^y}o43i#PV@{wI2KJ* zc1vzb%u!03NqyZrscRT*9BoD%H#~d)3{}47h&rogJKeyS_)-Fn;g^w>oX;G*yT+qh z-D|X8yj4eh2llGJV)PsloaBwxEp6!e?a@k{0os1CdlN^8iaYy~mZy(*nBr+}T>XR^ zB9|QP^cAbiPh)pXx}fb-W$W9|C=1-l^QGnbP4gVB_B%~gD;{USJ6|V5-cx&3#hoho z9ElH-=7)Y6yG`-?$X&nWFNb!gvBxaWVN`0<{5eI9Sq3ij5e0{?pXJg|dC$IF%zIwF znDz5KpRGKf0VD&Y05iZIFa^v37@!7Nfv^jBIA1^xAWjyr1wE*_B(+S)rv~0DuA@A>TNZ}yD&9AkOue}EBkr9m<@*g5ucN_HMcWI zx;Vplz(5_q!^I^i!p$YZ%@5<|7U2~V;R53)SN-FDNsIE>Kk+vl{FLseRG^MenSa7y zxZ(vkDG#=b%EiOO^#b}!3zVoAsK4QyaL5lD2nrMkNBPaOJ^~etAf0yufYKa+_8Z?B zf$@_+Xu}B1pD;)mf%Ovx!$J|*zv=rU0xtMJ0Qroi{E~;8iz@~4lOAj#1?86|Awwz9 z->`c64|>RX*QNiE6Z8Q9$oN(F97g#=h79N*JTp*#$pAITK>LYT0_oBJME^?{kZH!B zbjkRo3+M?M*ngtO{R0M@K2PU1IzMF2uNS?>#V^7w4C4|8`%Z+1|3^>$3l5}31&?!8 z&Q&|Ff9`4L*TkP!@B#q89qRuqaM3n+0zC*EcfV{3j>JDk4z&D_kpr#&^T+|?A3t@2 zHlzIRS4xJTOn=0Se~tk#ZZZl;1DNRO;1oS31_lNe7A7|CWjtIQ99(iDVuH)m6tpzd z6jW3&1`cKz9UDCr)iuHEY+T%Ye0;Rb!s0?aVjR4DJm*XxSYQ|h2bT;FkBsLk)m5Iq zoqohZJb)A^yD}73A*el)%0)}@mD^0Ef#Y-GHI7` z%*1bIaFo6()&QAN$7*-=s$pII5=B4lZMZo1 zi4D1!>-6=H55+5|9lwppKV0f}&nB-d`XGYM(HUn zY(vXI@ce|ReYHnNMQW;0qbc4V>?x{fg6g_h8RQ;&9Xr@K<=a?mP)k0*GTdERsM`MIoHa6#G%>K`$8z){+hD*!b6LlOo~pCiPh{ov z$|tc#_7|VnA}o4uBTfCDi4690t`W9|rmF}c_-d(7Gd}VQD{;+Gd1)ayra}T1jJ!At ztxuUeUZ_SBC{=eU;?+;CO65bVKgKwR>(iF2dtVST*7V#2H(yAltGlF?6Kf52-X}jiC_s*{)RO0q3 z7pugIbzYbD@xMiw@bWz1$$ai<={QMygY_nkF%JW_g2yPn!ZP^qv@#f#$&4mglW)UwOA@3O%tAq=Pd0_rzra^_2icZS7E zR$HgaDe|eFGIjCXc+-|m+E>e@rgeqw`AsG$l;mjch)zk7XVwSGON<3|)H~CXxE{Xj5}%rS=g{ck z?X=_0J8SE*&PTl;KYdYMPYJH0S%5Nct>{kACPayneQsjQI!y8!jE#FNs3Vn_Q8r?( zy*5wE_R;Y(P1A<<_%jp!?Sjzw6zzlP_N!q?(IG>(+_%H?>7E~3VlpQUY1JpLRvw8J zshCxXQYIZccpeT(I&O|Hk21C1WpjLI&|2*`A)VY?eKjPCmTh2Y+KXESf79(a$Wo!P zOwqJfCc=z?pe#c3Malyn-Y!|QPabv(M|O!{Q^kA_nbf}%vF6J0em8Dg%nqrHl##UU z%l7FRmyi4sDw(KE8(Gcuirnvu$zU?u^w*xj{Jyw-GXtW82MRCrZ9M^|m7A%Su4t8y9zeR)oFuCGE2fkvJ*Fu#=Vi9C$cHDdwRrgu7~g^wvxb(< zWv=hSs%u4}N5@~dRgYd0Pp6iYXbN!{*$>j{KI4}RU+nLd1Y20Z0|a;;Se+8p&t3*ux$CJsj_b)~2o9@8#pGH+@>5_!hZn?kmD=`yWxR8TL; ziuBZr*|hPT?$|dYVlThZtbg7kRPj1I{!*|q5xAZ6DW=Myu`jLBk1%L4J1qIQKFD+~ z;=5T77cW^qXqe{9!_4;b_Z=_cCap1%pJ<*ike@^&@W_LlvZa&%4+-W4FD)D7&jPl#3C|6Pp?j@|xmWXje zsp=b-d3?HVTh{jqd*{MKM}(W5uu`b5MEb5)-C5^89K5xxBM`ghR7Uh=c_Cp2H5eYZ zI~}R4Foo=`6_0=COCP3LqRGQNLO<1Tat6>4%!@h?J_{I^ujc>uJ?It*BdNEhc{gf1 z)g?ho@lvmTwEkQTba`ZSgr2o^IaL}@O5#Uvvpyrc^bLRS$!!_^hJfQ{g`OcQ=<%{3 zx}fO8g}w$2ZM{f$CwiWe1AsMnrTr#OO%dAp}ok)_EpQ1V)1E_uP_`y;tkfgaCdXJnYEDtWcx5o`1PJX`8k{v)rew+d4-sw^^t1&wao zc{}u$<(LY+5n6j_c;yNFTg3nsUQfvzl1vQR0$EvIgRVU5;NoGw=R>E=pGj|OzU_eS zJE%(L6sb2kxXaSUwtRMc(;~l1@9P$!e9>G82)X2sXp$T3WrYuyw=bD_e= z9(klyI(J?_K`lCbOas65fJdg4F#|6tp_sN4PmFk_r*I(v?^< zF$(?mn7&KmP04fG7NG<&Cb*_*%}8uPss|#cDD}?-Rd}-YEp6>Y6=}4lcU=3|%|jX$ zIdpZogH7mBPn=hl49uHtt7uGY@2AX=&%*h4$Sl3oH{fjrh0M$9FHsq50~QAa5_85! zIx0t`w`pq?p*yp_B*SYg_Y2+{@qAvksIHG+J$$mw+FE}#jQvSBYH+fV>R@GReS?Ct zozs;3wE}5QYGM9`flMsyEh=)}I@AVi_ZrfNN64P3jxDw!NWQ zQNxkhi^NMomp6RG;KGb|NU@#Jzsz)$*zWYM?d@h|CLdS3-6b7GES#EB(^hH>qo6jyhi`t2?hL!UkIdZ%`adCp`9R4GQKt zSrxxke0-f?E0RH{r>%Fj^Hz4kJyPM*oApN@;u`Ap*uUo1p}143!C#OChk*}+8+P?F zX`x=oVHeMV0^I#vGQSxzO<1rup-~vFjbpY!1`%KxlHMNdA}+;%6c>%FH}vdtv_WR{ zWMq@RdS9SJ!M^G=D$4dInzNtEO(OkR)-}%f8GYTyaTb~HAEOuptM_s~sB6KysaLhI zgW~7BX%N%G5dy^B48! z!q8-sC4D1WHD(i*Tm#}z9&gjCTytW~`|{eF-B7vIo2hJ?JIL!U!&(Ke;GKR#ERSbX zX6PjihA7=`FI8p>7WmYf#S&dB3)x&DpK+wVOR)JtIZfIT=Ff%L z2%N`mttpreOC$0;2R~?;wVACv8Vc*JilSxLNqkS$8|l1X>u*p?jE+M#?iM@?AFQaC znEW1rYj!_a_>f}o$w|5R?ff$U-{?s(B`dGqT9~+${gRRKv?0FXO-jpQw$uT$r%~fEq1Br=^ z6}676?%pYI97vcmFKMJnrmN$&*!&b+8yws7Ok=X--e|j;D#jc$*-UTTq0H9t=5u$l zt1nO+$#i2Ty%sxa-fBnB7v_?%RJ$L+b?4jU;Z`gED|@s3NQE z@O4x@r_k1=a$NUC0^aaZ1vb4N{cFGQhbAu^GLl{$VlPC0{5+rTzi#KIXjeS-FtIYE zImhucis*brghFT`Zlv6sJhXQX=}_`J6Up&ZuKDrWrF4S&i4zZz^rsyW{cW^n+4iM= zSjUdZEw!)S;FEZVilg0Mr;}`m^yxN@+$S`;W;o0i_~ak}4YArRHtF_323g6Wsvi=5 zfAZ0T6n|FL_w^qBfHyjD2E0+W4?Jb#9D7%_Ik(6=jjThYj_UKbzu?O|y(`n@7M`3v zPO>maQ9sg&aH!n+V85`Y&7tT7h;z3Z4CU%Yk7aZjuq)GFb{Ts8vWRcg9pF>a66cHq-dE zQoW4QC~RJ)%lM*pv!~)uQ zubCSu1=~qZ>0KGzDC zo!mE9+h1&~Ui*ZIkCHvjtwYW`a#4Z(0DIzE(6S;tE^bM7L>%S0o#C!<7{&6X(Io=Y zr+VHWuG0_zxMs<|9`0^celVe(p0E* ze)6U8MJzg7tX6uI0yHV48oqoJQ+=+Ryr;(E4Lb&1$Q&6}_?uS#X6r^mV#L-mX<;lk zs}i|b2{c0dB)kg|kjZ9w(pT)RDf>j+%Xc4#9lZ2l^Z`vX$yW1bm~H?M!8d5)G->Cb2Q_69UlygplO!Oh1@=V_*os&Z(?g?@X+f#!5Owl?>^P)hYZ2o#?v$nD@yO@{HB;oN8Ftb=ucV-Uw7}9?D7;J3wvET zWmw!?9#O7-Nj%pPeK>gr95zvGC(@B9*=Q0z&CtypzdRF-#OOnPnB3LC@@c!ei9Q}Z zmd!Y~_&!O!etxZhIxeff4?XR)AB$N}!CQ^nB*J9$JH#eQ~u!0T!@_N`9|S!G$uxhv=%nw~xKu}|Je z7z23=LSw#2-WH5X_wvH)E=QJzZ^}#gf}6cIG3G*9vx)V%GI9z#t@@C@196ynl+5Id zTp|Nebh6h{9t9Jx8Q^Ha3)kRLTLSs)c?O8B8-?_FPM^Y}t1PMDDZp%4BP-9IvUZ1* z2hBH5&W@3rZJBM8Xf3?SMxOHx1rjvagi#fBHS+_x8Y;KMaqY`|0tSh`8NGUGuVrSK zUV@8*FM%9?6ts5C?R_Nytz3}U;FiB4H|YZV1mH%cHJ!>5Uat5)6YHqe;8}iG;oRbW zr&rMqChxATW;h~lXL(|a^;;c!j0-vK4aFj|$?X2MXfO3;(Icqt8yc;=t#)4GnXF6=$ z8*QFEQ?!q?XT%~W7Z0(=<-7io-JLsze&pt@ulTJWhEG4@ssN4#V%jYGn_tnzUNd4% z5;^vLJI-6lSbjC7ueom4xtzp4IAcIuOCP**GTi-*FUunxwVOo_!`(qL(X*mxS&_cS z|9;R(c&rC(fZ++O{@KE_QWB??Zdxu>$B7~ljix6RuntQ7@%DLccdnZ#ObSz~$iNzj z=z&13XbCOinDVRLp0`?_G(B-$9L}K7RBt9b?){qEegC*dSYj7k&V=ck=hV9_@VbS7 zwvJ6+)j`}k#%Odto^{1#Cc~Be?WDY2){Jhg{oxdaoMfc-V_%G3xLqk_&8xsJvl`#T z#wC(ITt=Z4I2Js!=dMCMZE9B9bM!$5Hb_>N0_KPklv+0^{q_(82ND z8792fIoR)`nS&i}6g`(?l@8Ui6zFokySbCZoz&%0f2qSX%-$j5&6G@nrImPCk*buA z3oRg>w(MY(mB|nxrwA(&M9F{`t2lqxa;>f@FBt|HV3yWGgs*+HpU zUJ@!=qwLYi7W_gT_G~h4;g(fOlq@~hdmdz~MD+XCsX}!N3z69ZB_tY*kmjpHOL-y( z;_K`8gCRn?`QaZX?O2o!7V2x`_9oZ1A9X2qOVy#$yjsyG^zc(C`8LumTY(_2v5<nG;Nu!s;yMJ~lBC?b#BGuPnbN>66Kp=XbKVac)PG$PCD zXzx|h?(L|a_GNGu#^9ZR(^qqi!v_)*%QU#Ly1=L|7PuD>3>BNeogI|#$Z4KO-4tb|Va_1Y&j{Sl z%iiEc{1ND5S5Sk&{#y1w<_O^qPA*`uOAVytH8XdHgYaVzwsm)LILB{-@MTk*a~S13 z8hF_W6cB{T&SA?RFxv&q5BSD8Y-Wcv18L4>b}+LuJBOP=_^F%g`PMt=Z4eG{LzugR za6bq$*t*&wK=^#O=4Cr`6L7%+>iJ$R1{ZU<)p^hxgmIlTHKajU3|uIHYxyVG^iQx0 zxCaiT1*GjAJTG>I!LGrt!?=ZogR;HE%VCz!K?3EUjU{;y8_A6NX7Sbvg(O~c&6 z+{qjXHl+=E83Jhub~n-t;exP7!Vt*+u804_YJZa99Q@(eAV8dP3S2hk04{Z32B2%5 z04gyK05wPgOCZ1d?KYM+xBv$LbZ8fT_&o@NPV0ptQjKq*iK)B|sU7N7(83=9CHz$7pStN@$99&iN4moXrC5Ml^9 zgciaGVTJHOgdh?SS;%dO21FNf4`K;HLR=yDA&(%>ArX)`NGc>3QUa-gyoIzwdLUmQ zQ;;RdH^?Ch6a^RMG72>c6ACAa5Q-Fv5{f2@A&Mo6BZ?QwW0X*omndl{1t^s$Z&5l> z22dtZR#0}qr3*MvQs`AEJ5&fN4ONBeK`o$8P+w>eGzOXmErQlUTcLf>N$4u{02LjT z7!`)fjw*_(fU1QGM|DK?MGZ!cL(M_0L~TaxL!Cn1L_I~rL!&}tMH4|&M7xV-iRO+L zfEI(6g;s&~0c`+n4s8z|9i0@N8C?)v0bLi}8r>T`7(Efa82t@;5Bd!HF1W&h9D@}@ z977eu1j7X*0OJ)#0Y(!>FUB0kcT5~iT1*~HIZQoFB&I)R3}zl?BW5q=0_G9cB`iiP z5iC_KGb~T6Fsux$TC8rYIjkdWLTqMiacoU&YwU;EFR_cTKVXkxZ{y(L(BTNxvtSn~nPh_Y3Ye9v&VO-VHoGJQuuByj;BZc;k5A z@rm&{@D=gR@cr=P@hkCr@z)8k2p9>Z2#g542x17z2)YSYFJWC`x+HzcL97ZlnZbKeUUPbNHUWe{a4

      oD6$_OHl{Zy7RVURhH4U{qwFC7l>i5*EG^8{)Xsl_XXc}pjXo+bh zXc4qgv`w_jFcO#)%og?%)(qPOmy;-5b-kK;wd?8;9W$L4oj+YM-B)^SdLeoX`Y8H$ z^qUN{49X1m8S)uM7_k_I7_As%8Cx0mnV6Y$m;#xqm=>AIm=&45nDd#(uHjviyykc< z?OOkJ^y@;`ZLTL?|IC8I!p~yG63^1jioz6J|rQrLhgM&IKoyT!-KXU> zS<25fvI()4YX+fKJ@l~I&$ zDMu=gsW7S_RjO10RYlcE)vs#IYL04k>geif>hbFH8r&N9HCpcw-!ZJt zsMa+tSFLy21lk7Lx!U_WGCC1DlXp4q`rPf%rO-v_R_kHtY3pU_?dr?wN9)fS@EZge z3>q>UdKk7FQ5e}8H5d~Zn;4hfL%XMQFYn&5iKMydq{Il_M*oU{OI) zJJH6`?J;~Y$+6h6uCbFZZ@;X5#qjFItK&GUxPkba@x=*L3C|Mt6U`EPlWrsxB~vFq zPd-dRqztFZr&gvhrNyM9r@N)kXJ}`%W(sBIW>I84&pOFQW>4m5)>yeNg())*{#PzE!HVsZG4CzFnlf=A+QZsty5g?=ydA#V7tx6dmED5f z)t`kw*Y=3@H1Un~zzbUC3SJS*%}@Tl&1L zzdX0%uyV8-utvC+uzqd5eB;JO=jPqbxhsZ zX3UZjc-Uf+7rfX(KI|3TZNe8BX4H>zgEWw;O-Jc#}-FZ3eoh&)Qhk9X7E*?%E9(Is|-PsfAV&cJ$bfyPeybuL( z_jk3w*aUI+!Xy_>{b)-DRLKe9aPf3#0)F$vMV#&z6Nf)F{ueG_0mz1z^9LI&v8(Lt zFc2pq?PPA^V($b)IKxbA5hl*&W-xa4bFHxcpt<8}YHbd8`8PCToaeg#NuGCvHvrXK(F}s!u;~cUv>lYzig%n`i8`R zyO!DlG^!2)X=d;4%mxF$N#cC|EDtWQ7CU$Hzp(hj*ykmGw%h-M+9@mlul2LD`|TQ_ z6w1mXGWKxS^Ht%BG7>*Vh?zYcVJ7msb3}yrxOq)@O!(LZO?kQ4xw*~Y?80XJX6&XG zyyoU`Zb4yQK8p)w|H|xK+yAVO%)oK{b6h!qU!u7fRc3Hba9_^F>UGdtxPR+l1{X2= zW0d^x=O2Sd%GTv)_D9|paX$V;Y)z1s;&dMDKR+b%v%7zZYK!>A;=uDDi3FYIXWi6shIcniSR3+EQ^0b%~)0MU5lBRE6PgZpLopw0>Rl0@AK!IEu=XllJ~LdlD|{F{Mj zbbL&brLN_4aiHMiG!Koh4BOZwB2J-l-ZNA)#a;~kr_`;k6v>z)8&Q*lrh&<=70#6s zlmQ>-<|39d!UM1n=zU{uIpv9%Cz)n8xfHi+u$y>H8hoQ$b$E(h-A~*FPT#+4&`cih z%AFzh_V*K*rtVDA-e-2ooqfHJ(V*#!s}nm$)V^swHZ;_^GV!4i zFRsV8byA#2ZDIB+(ftY4*tn$b4@T6w8mwWcRt}>xr@Ey{J}8s;^EKk=<0CD2%e`Mu z62fWjiV8NVuqJM8Jm=)o&yh{6iC1$w1JrtFaE50eEu*vEdiil@i2q8yzUM@h)8>H$ ze#RJbXWKGYd72jW8GY!?)M-*?6^SSN)Jh9?cmhQQBLBYR3s%2(S27DUP;^5c8*>UR zgtB(;7E^NsP6_WRwu;qS={jG`e7$RQ$$QNei zXRVtgdSpoVsjw@$`~_i#lVLl(uo0Jtb1TX1A=<(X*WQU)_sMF#P@c4G#gYAi_QE=j zrtr07YQM>HKOA4IWAiPHN2`ql@AJm2oKf{Ff=K*VAu(??O{*uJL>}EGo!B%k4)`|O zN_ihALbQifhK`9*Xmmf|DlB)mt+?YmzoGsIBaZsPjDgj4&e_BWlCQ~Pk=p_aXr@x&ykd?V5$Bm@+W^UjW2FJHJUXy;u{BbH+U@(}pT`Q*-IIxAgx2hCG_N zEeO@ZDZVZ)Na5;bp9nM)OQA@gpEp z(Nsn=!5rjuuax{pYh$U&b!#P~Tf=>9vxv(4s#ZTSU;g<3WBJ#J+}qtYy$!&TKFbVE z8McnyLW7_6g|nRSG0r;I%wlsKbvmh4ZAZST?X}VRrvz|bQIEpU>lC@FMXhe^)xB4# z`Vlw$7Moqa-S!`ae+TSX1a&Q>`H{%D%YP6%v7QD*EgHGKT4tD-a!Zlc82xJK<`vvj zl%H)|H;)2jaf|-|Wv74eqv^dz{tBP*cFa7m>fRZ=KQKj`!uoJXn3KUCUufjDzIF)9akM|{X};T1xQ$7Y6!Rk` zfY09GX9Lr%Y+ZP3PnE6z0JXKYSo5_YfbU>A>&I`?6<=JUoTWOCPM3W@Q}bGUGl{O_ z3RE7tUful9b@0FZ71!gwk!xpvsV{@=^!s%aNd*2KwMb=*j4CiiR@_g_GEOl`;{O2n zDUZdE5$b5~BN7CdH_e40g8u-Ar+n9mm}Avx{{XXJPd)uVQ}%8O z@cm$~h|H+MNu_I~?Y;L;$oGv)_Wk&^ukDiO$J)VrbZ;%85JIXCJeW{9JP(+Yj1NkW zNBzD2E$L8sQT$5r7nd6eho2;O4hi(!fMeL_^sfQDnnj4RaRhM!!P`4??d?l%s7Tmm zk~40qHx(H?`ubOoPZd?hQmE|it!?ev=jamQtffUKhllMm2Z7Tx=zrj%KN{WkuaC9m4uRQLI6h<79Ii%vE5k+GNiG$a$=X0E z*bX}PsWkb*q>#?OWINTNcWur{&&$R*?^?Gh+D>mxZL0qOf%lUx8_KP3tHS>P;ZO8E z_8<5u55_??ox^zhR@lIAg7iY`CQfh;&CACa&mF6t@o(+N@l(dPlYOVe+KVY&%?V(* zT%465M^3}*?T;aQL`POnEo6jb621Qbh;{UwnNH~?0?h1Qz)0*LQ z0p>I^MQsE(4fn#y)O9@m72JN-2JY5fKI$3m z(bhq-W>JEH&j9g^_O5@z)9YR`^5oR6AfHcZ_6Z}C&GJD0B6?u*0qwT4ol11$%LjMp zdeFdBl2=_dJ(ot)%SRn{ug*k4NZJCeSNB5&nSZ0 z)=?>y7=V(NJ+O0uk_B@@ym>jNxgXiJ9Mw2~UZ-E<--{66+(o)MBZfnXrbt=*ul9-K z0CyzUn`qy(M~1v*tLi#shOwt#Txsr2%t)bH2-hU%*y9{}*9-exe%N=P0QAYcF)fV# zCe{2+r<;wpGAv_F-lzm>fOg%J+~c-u=0AtO6#Nh1pNDa2_Ih@--aPSRi)~{52?0>d z`0~+5&9?x1jGE@DoooB3*{zp%{ztck;qPfp&b|3w&$936eTDGH_Qmkui}g0K@akP^ zH`ez}1d%{mKPDGa7I0gth8;Tl8v1L%zXx>PPFC?YuM*vPapfVolWZgoa!00roq6Bv zTl*O4z60>?m34om+JA2gF$}j6Mk8sW3|cjCKZs|O>MNW70D_AC&{~$2eWiRO@IA!V zy56i3O?u@RV@lIVA%;dsk;%&A1OcAHy-sSFTpY1jhpy>%eUF`~hmJQMuow+H%gJv2 z4kz}6{h((3q5dZ?hWrzBx^IL$FLWc+uWi_ae{~zK=+}@6j1ivv^{A2Aj5&wxB7rhGK;rIwXzAC(@S zsUMK%cI+q~VLjNMa(T(GB>lF&HR?;O_Mglr+cIAvkG`T9tJdHA9f1K^=R%j`3}6WO(E(tW|KM_;Jv& z9Gd-Jlod#OM)lhMXTZwySC++y`~j&)uG`z(642dV`AsXwl}PeQ8By4Q(zEdI z$N@rw&wBAM+n>aqH277m-Zq!;x5dWhJAo#tsE3N;)uSt&$gmb{0Nm~igN`ejo?$<2 z3R9BneNt(3dcNmfI&Jd3ZMD+B&5Psz01v~bK{dt3sMhnFR{J!wOXZ37s*S3Dyu*M8 zbBgk7;*D*xN#+jFuq9gr9P~e0^KE4$k&N2x0 zuA4#lkD_>{HE6COzJ--s1ds-Ay^aa>2Dow9JXL6^PBM+PcYF2Ir~CuTLO6!gNsGTG`#T#EQ_zB)It#3_5^0>CP%UhD(K; z%#vnBEV3z#7Va<)(z%&#S#B=3Ni5!G@ZpZ_`%lZyQPQ#B;$_X$D;leLK!YMLBsbLe z{S4Y&Yy z3&9+B>sPeM65qq}NpSv5toYcZZeR}pe_GSgwFo?qI!(7wupokPBVpW%;l)*&ai;fn zZLh-j2+@r=@>AQ$%xftvEthP|iH_xdpEs*|bp0v1ov+#S_ua7;R@z)SW(|Rn{_hQn zn&MTG-ZoWbvtxoAu0PN8rs^6bmgyn4k~xHjks#a@U>uKrMRDQla@10a(%S2#-~2J% z7|K-nB4e9>wEppBMo3p=m?6vOpcv`t)YYpEC9bbh7>At91_&4fudf-Z5opsvxt3wM zj_2$dKFJEtFt)X_)gav)`T-jfsK>b4DD`x|*1~lOJgE=Z__2Wcgp8yib4rwQW*`dDBvMZFEg&j4<^*=NYc1uZT3b?ci6A zF|?9h20G)eNc8MJl-nD*e7IWTRY*_B7Cb2`dJmWJ{HvGoFN$s=pX{xCvSSi1PS6fH z_0M|R(JZZJf3!;{nH&Ra@W#MbIRYrln|8c3J^673z7spWyZi-YZ7 z(JCUc^Ixd&l{E|`lIrgLPvw6cEc03tU&;Rfw==^30Bp-ow&)gd$7}Yf+11e>{ZkZe z*lPK=z`qe^@rH@wn>}vVPt)}WpX~6ekRzRQg#h$yj@hrEzu=$}Ua!O37FNEWYMM>il_QbNR?yjPVQkXx$55Cd!~7yK}P6J2>8 z9+aQAo9buA8MH81Sm7|%_m#a7&-{A$*QEZ+-vl*36KWB?-lV_i+HJC*-d$XK1xNuO5S@YEZ}{mqYR)G07C&_PxxRSIHPbz|S9l`0wIIi|{kU zw;I=qE|%Xw(&G}abN9{(+73GO!0%s|`oF{t4#(mA-|(D8bsJqH-($)|nWwsu=EQ&z zxJJ*nle-ID`~em1<&1FAdy(&U)$Q^;jxVY1s&JQ^@E*@yRtAmcg5 zwR`^n>}&Buz&{^>F-f+?8c}fu`2vw133VKRbBxzF`!ReX zUk&_cn@jOzntamUX^(%S-r7p*46<81LhcesA>;`>Po6LVufEPB^>r$fT0>iRK0_;; zs|eG+t=sWE5B5~}J*;?2{%eUWEVR?8%4T5)%-?rlVnzTgxFio@UqP>fV$`*Xqib}! z0SL($^RxUQcjKIURqYJe%WoS-vB~EkSll*tfyl}4!4HYwH*vdkM8 z*}gIk(}A9A=f8=6@NbvHy;fNtMDZ7hJRfqZS~%dB?Qp5V+6ul73Gcs;di@5xxtb|f zN#RMPjB-HagTTkX{{UKbzT9%t~@UhW_WowF_48SB-QZ5OHI3^`H14--C6Ha$h6H{yMa>h!uuCw(>UONAVB>qWc0) z2TJBOfA}HKiLD9$0EBzvMxk#sE+dcoM#+{kdl1o*Lg%;{$4cI&ZNvD58Z@zPHt!eu zXnhS1EzDxMr-tTRT}ggS^O4~Q?k}*j{h<{|fznkQvB~4wouYVN2=xT}WOo_mRFV!r zZ2p-1D}M9-30?a}XsXMnc;~}pe`-w|Rgo-w zF{WzJr~xLlzk*$xzV_U@5((oN&3aU8VQMej;^6JxO?P zFvo8*G-qO$eaxerl1CqprD$pTn%WkQMg7vlhIa%ko`7TU8LvFlJTvj<;ka0J&x5`w zV+?XUuMVFy{V{?JVSrrcBtdy9a*JE(D8JO+hZ$w@c9fi7W|{5t_>S7kM@4A{ghg-z zW78dZ?s=?#vT>>ve{$^a({-WhVfdpfl+>Mg-`3xm*m#e`jcKDvJTgNkost*LV zbsL#&;F8pNeF3{!74d?9h~uqu)XSYYMqbIQbhnZ8m`rsUMSD7tY4&UU&eOwEHof8a zR}yZE?;L_bCjgxBnw!Ko-)5Q%SmlvbpUQB+3@{n(&|;F}&eQu2cto>9kc*Z64yT45 z{p&M8zte4PCDZRL$&zH|=5{Q~6P?_0N$JxS9?>b;^xx2zMYzgM+}iM+<*X1G+Q#DP zTHyJT&4IOX&OJqR+GJK*GPF>_(Zt7l$UbF2bB-&Kwb!8*wyhjL@$x=YjFJ^QZYQVs zR`r(KMDrtsZlmb?nC212Hj#|}Ij(xqrBhhDw|DX~<#AVgf06Bfv5nMPf9*@~r6!St zlK9s~K+urG{y$>lC!Bs2{SvBg=Z>UwujdE$GL|L%slFEOkzfl>6qi0cWs;l@{_rbAos7TKj5lYGDZ76Myq@F zA2#|SxXK|C<1OWu?5%XcHe!A;5ss^m883~x6(XZ-XHM|?y(F~UT8PAvPPHc=+AR7;{|h)yNMijA4|MIwJDtEkVhjJ^fk$a zr^)WOw@*ZUU1@Uruvc#1p}_vpdKCIri>Je-&v&U`THm#@>0Gw`&4hPsbF|=N(APKn zLS1N@=8*@2^yn@l@fM${*`?mgTtZ<1e8YfF=3-kvO5W4_P2(Q}>)Q8=u49tiXm_b* zVV)9h?u_b;SYVJB44`q?)gOj3>Ru=Kf#a_pc`0Xoq3U++ZrAl1yNcCOMw|`{qD%Kqa@$}hk7fJhD$7mslAxP7M8=M2zo@%%4 z@B2+j;lBe|_#eXeQ2zkJS>ncPaka44b|V6L$9{(ZW0DU}E3Ehv;N|#3@MrCwCb(GT zjv2LWLRG>kc2gq`-*hS9@@wNyh8{8h0EPMSJ6ia$sarLr$B2;HXx7rkH(A*(?Ysg& zEt1E%+y`M^&VQLx%CLUPd#+k}-2CS)!NW1lDq`wuPMc53zKHDfj}F=RXG*ivZFe%m zle=I-j51bNfzwak2fL{310?N>q~O@ftldPZJ1ayp(*zZ#f9KQ<)p(Ui{!6 zUYZdlvZ?*#sR1Tr9mi5fOm+HK=1=YS;;A(++SA7n_=-sFZTv5z>2TRZ@{*roxQ0n* znNiLdnN`72*J#gd{0|Fc)bmWw*i=_`a(DIpPkV@S97Yy|t5ekB{w;pdUNHEJ<1Ie( z#fyD;;K)tH_7QpQZElj`WfC}Wo_UE`QCXCaL)4Ew_^;xv3&%RUzN6v~3+fs)&Aoi9 z^4ygtc?LeYAShPnIKZ!|zh^J_C%2BjXb**&XT%1#@ZZC|CscK~(^Fi~r`0v4(d7lL zZdK%vDOpAUO{6Yc@cFaH9|Zh;;Sbtp#GVuJ1($`iT_?k;@j>_TYTj&8`8ILhn9<@mVv5eu zz!jSd!RNJnZFS-OQ^EfL88xfTa>8q0?Axf|X(NqeiO@%Y%kJcE!RIINtp5PoFZP<$ zwMgv$0JH3GZ1lZ0@$KP~WC}GX5C+>S=MpNG2WbZ%FSU9YtiD*)?4@TJrkmgT)Zn9v zmKL`>m7HDI$o<0jVe$7(@n!YSlQqVgH7gLSd||ptqqlh(%utCJhIRm+K3-2X=^AbK z+9Ottm&#W$tNH=R$K#s(D*cTA0BcBo6wi08*y`4LRn*eQ4xyyYxuCl9CPOv+0ALAX zkL6f+Rt1<~^Ixq$vhTJ2^I{<5#zF@3l!7=B{EtfdM4^dyw z+*!dD@>=eEs#KoMFYCGc9!JD<@s!e>z3#`+8fDaC>30?p;IGO*h?T(~l}C7Ex4dbi zM)~=dAKvvflcQ^qTV8#VStJZ_xhH{;GwoZJK2^r}yx7`Ka$i4y;=U4{YKz{bA9Y{v zK91`A$fNc@v?raPHU-RNr%~FVpTW?}rs^71-IO}Tjr2>p3BX{-JAotgtrfX0 zvnI`iVOjl0(wlo1+qELaq=rafQ1~Qg9Zf&3N~iB-qp>b2rx(=8(;~FF(bnG8W3@|g z$>qZq#fWs-f!KEUtlO(~(`|%As}1VJM;?Z#~=P7%EC*1pfdo zJ?hP$hCjBX-xPDHjf!$X$2eYd>5A3wpj`}si5*r0^z$d@uM6h(H9%pGQ z*+0zZi>V1G1-`FS&(5R$lO&5PMf;f0ESbY^4Z!~Zcbw*=)-+8*Pm#jcmATg(eQ zcW@-ENm0NIq!k$%`qs3%Tr=ud4I4BO%t0y!Cpj%6Q5V*RtFqU7C#!t*L4q1Ei&Qw>@;W3%68A;H@REpPWH?_Om<(zQU z-Ks@DRFBM$AAi9N{3kTy`y<7^BD0P*{{UAsQB3iI*uh|{y9d4nX4?M%!31|)z?Od$ zH7FyvKQNa>4Im!569JE|EA<}ZM$_$Pf@vlS*HZaEy$UxH0So9kJv!o{O)4>QAU7UN zG4A96fsWblUY#C1;uz7l)-NaBC;A+fGJLa{a#FMMZGNZbeyjfg1Qqz1;p=(sd_VD0 z8_6L}z)b|U%gG=Ntd0l1RFAD>!T$gRxBaKQCv6(oYyJ_|A&o&lX}=Kt-Y~vod0;(w z`d8@>hdf(8lYMxzPHwF3BvS)EGad;ARDF5&&2~_D!aJthAd6~^z!8;0)3t5?01+cM zhr3cSZ@()24vaoq8WQHEDMk57pOT&~{hWVlZwB5+4BiE^)VK@gSkG?pGWN;Jo=-K0 zdwucG!FMp*>t6|cLK0)NJ*Jy9yN`S{YQX*mzh19rU4GtHjtJa_!CAR74E^l(=eMP0 z_-{?cvA9^a+(s}E6Y`Aq^y^yF@fxQs>){~!H>y1;Vliq~jv{|OkIsMCo8!NUyhr<7 zd?dBM_<5_^y~mGqC6T_%Dy^b6p%5>UfNfq0B=lf;uj&Yv2aI(!`3I}~Rgc5}03QAn z=x|$!ueGlh>BizXoz%k@+Tb}P9PI=rUOx)`elhZYsIS_%8#0zN1uCwRq_3jA?z^9d zc$bc^UmHu?Po2dz_@68P0Kr;59@}{P_Ia~ew->V2Ba$GlS1x0986U(xVs`si=XJ~$ z)}9v9;k#R?WRh~opca-5ndGi|5uDfP-~1I$+H>#%&f?vjWVO{JU=^?qw z9!qwS5UhDPz?6)Z$5K9(*E8DuIMuY>Ev#%-lHh^((urQ%AnIxxctKtX??Ta5kvy6mI|yyN{r*ctuW%=6b8^-ECvpt%z$8DaCuP z%HOHQ{8RBhxo6wyzz#$J&ott;}0bAMid5-Ist*S=eQj^SG9h?{{RZLZ70Te z-VD~QXSd&VW3SxY-MNvnfH+(peqzhVVk^Ra8Ee;H6~AiD3t7@`8q#YLVO741$&D?N zCX8iR^!ZhQJlEb|uw9z#T33lR$*&{TFST7u?5lY3?3kGU0CE#2w&dUrFh|zAGF%*I zS`}(1?&|$I92`ZDcw7YOMJq{O{cczP0D_!+TYVGZC6|CS`JHe4eKCbv2gH%t#>%ot z4^g?e{7rbqk@i0c=rGu;4<gJT)-g26^7n$v z=}=!3z3Fdzy_57I@yEtXPluiZn^x36b8BZ4lLUl$*9_SU9Al$1WE>BA{L=W>;=OC) zAI6^)_@l*^z8`CU7+c#Vvxs7rYkLg&am0>B)pTsQJRFnLx7B|hHCeymr*uyb>$bX* zrmJ&lZQ(oEt*5xsH0z0xqgk&WGyAqTK3a{#7;juxTk(7LEV#GuZ}yjn^!*v`H5lT) zyO&LD+uO?ue*KYO?n9I~PNzI|&3#`D=`+5~!hI*q2Sd#qWHvqsxdZP}VQ(byphxK#w; za%<-8H&WF6X{t@AUwDG@7x8t2#dRjCtefs7+C0M-_wu<1_;+z%W?guFJb$A@r|JF) z)ZfD@>!@4n4|@PuZA?z!{hZtg#D5Kq#l`i$qpj$I z2(^}38aWvw#8K~4%QJ9tItub%6u^ReJ5{%Z#jGM-k-IwO(mlf@^v^$sdio<=_-}RM z&k^5vfLO;gEo{PgWkPosk83+G_mx-w03FSFXNiL^+vxW8vD#fp{gk&&7Ug18 zDPVGVUIu#{SI}T_k&M2so~a&HS~7pVo~!XZr%=>k(e*txNLJ(RT6jYfFU+#AeBk5` zql){7_9XqV;qZ5jWzsJ#f8t-_q=xs(k9>@}_9KWy(y#heh=e`;AgW<`ccU(1R>qzRL0 z1!e)52I@DTTGKVk;hA9F8mn-@vH)JFPs)B7)5nh0wa- z7gBw8mgA|(9YL=~wz;~JC7BQI=8&{xA?Oz*pVGe${{UuR+ULYtPmHZ}T?baxd|RR2 zNG>GO?xaZWZ6J;qqT3Ue%%x&iSmVg&2dS^ruh`4>qOkb0rpB%fw|5oxcp zYl|Gy#$j*-auwhMP-6O*qO=?AdzXx88j$uH=$dcl-?q)g`u)H;3hmDza|_1n>#t z-!-46>;6rgmbV+@nF&;KcOf9-Irqg$9m?x=(perxe*V$aI?)nmw$MgS;um8{hmZoV)QHF;BUfWGp}f%loAC=f7I>t3gzD>QrK) zw%5BaU21kLtkW2-7BvO-ow?3<{3{bt)O77qC63Kywq|uMN_@O`-Pe;=^+;{?SR_?T zjj}iH$xWniGmmPMNAQey_UN%Nh2uV8FT*W&(ZIMp%MP4Zt}g)Iv0bvrLQ zIhTAuGk^{ooO@SOV|#0H6k=&NNW2W352Z81z7|%59g6#MrB!$U9`%0X!tt!CvMNTo z$aEhtJ$U!3bYtxOv=_bGV+cvx-`v6RCx>L7A#racM{AIFt}r^0`1H@Eb1_}d3=bu| zjU~;A%kMZv?Tmgk+iH4)3rU%QL;MQdhQ}W;A5Yf2x8gRhsc3#5)wL^4GHc@uW^c1V zcF5sKEC;E^c&=$wttqNejlEgDIHmi^Y+StYt;dQq`K_AXIoVlV%C0tGgP*DC#YcB) z*BVGMX|e_w1jful{YQGm_&4$I!haET=`K7!GhJWhOYkrhxNZxqK!z&DMM6pu;>>Wz=qZvS#9H40UMMI z40Fam&MT+Wd@FUQ*tBbEt0Sj7&g@{-Zy5Yi9s<+kzSQKkvD0l;ou{4FG87OA&V9$d zeC6>Q{t9hz1ghQ_wUb)=LL*0UBw`q31dzx9QKY1IZa9;|1AAREA+QUHcHT9%0Tz|qzWYI+gzGQIQMLd31 z1_shpWc^2_e$X)9e)cQ!x8Y5TcxU35hhU1|mDKGKLt$OF%@_pwfye3TU#c57j=A?e z>-R4Uan&(6YVfGO86=mLkI4K-!@`axmnT%)`_C2qy)7(X;1%@O=?szET)`wlLGp*) z!5-qjIdr?#@ZG=sBn4kp)BHY?-Cr1^4J620{aE!Qr>;$Yi2ne=MdgOy_FvcJxLwi0 zV3nAS+ZAR)#17|ZJuBglhnjAW@gKtyC9S=_oo2Tt+6a}tOMSN@vi?!dEU-$>n_W@GnD3=-Xm`q!T-kfVol zr+5DV57W8zF^h7&td@@VwuIAoX4~S7-|&({rL~Rn_=R+d~g|yGLL-M!GE(S#63Ua&&6vW4CwdP zHkSG>uRf^gleCR8HM+!r7RK+Go=>l8_uu#|hs3Y=R9~|K@sAQ_&};x*iU7Ibw4P7)H~A)!SV!uHIKXbM}7G zAy0+6SZkQBFSQ9C;`FKW43jy`h)?&F@m+_9=Xdc$mUc2U(A$#io3}DAp&ds(E9d)9 zh?)n$4~||g@go_^zBmDe2Wdm-LlGnO%edFx2pifLFT)Eg!fvl zx9v6JPZU^QS?XR2@x_Ws;oGQR_kue&5x(J{zQLMKnA&Ue{Og8s^e9xu)L!FVwe)AL z9xTG*@ot=SmHYJQciuSoQ)lA62K!CcQd`Sg=)&Ab5ZfUt+m+Ri-C)=t=Yjzo*NuEm z@Xy69AK`wXrTBxynzxO-A95}H%e_9&%ha@(#nOlSNf`S`WaIZ?P6!8~HS~tFuG{ME z_Kz}6+X6bF;1%H2Zx8CW`c%ZB*y`DDtVYi2HI}0jj zo)?bf*M|Ia{gyTF*+0iR9ruZKD~}cUcU;s>o{gyeo9z0OI-H4!VhXtrziW?_b|i33 zexi8eYhCGjtbb=!dr0l)ML#g>l1JClzG43Wf{tn)Z^0i0-$`?03uwO)v@J;|y9H5R z8RP|tUzjSVnVcNsIOe@9KNnY*62k_P_A_aDy)}LBe?EtgiQ)QG@>xk=+DBis9Cz>H@Dt$6Q`JaI4lA@`bn(OVz1NYGk{ zhY=Fs@sGNCA8}hB71mFNe-*XQ56N$Ne-)kG#ktkszGjbbnWA=G!#-Z_12^91n)6?X zR!^yTF77*NZlb=??%F*-u4LYa0yho7EJ?xJlZySz3xSHK6rIzv=IiKwb5dT;4r@ni ze^)tqE@koU&6$!kRkF8xK!{K83~cSc86}ywa(}|P&llMZJ6C43TbOkIAXb@hS=B(3 zr)|WM>MNkpbj?#uw~qZ}fn*c;K}RwiZ3HnKjO1~ia&uW;D$thF@`Um($`&MZFIrw*`PonA?d){6ju-#g0F-Zi87&9mgcPgTs?re9f zUl4z79|(LS@PvBCyvYr!TL!n09ZJj8DvtYz8Oa#rSK@DiziYn|YCb!-)O8IPPw^I~ zc^OeWe+tiSr2hbEG?3dy ziUA-w849X6>)RFhO>YV4LbH`xzT$iA{Flq2`d%N7D$0v=VIzNiFTR*bLXv(d|VwqC=Qy5@Zz$Lh9=@v}hHbnp2K1_I&i zq$~dbET>S#m9juMQIm|0YMr-^;TBf&M>c`f`)&q0&Kc?QuW9sPdGtR@ zZ1onK?5i62!f+ZXQyW6}&%gLrNgC;}#AA|W^DW#y!Jm}lJ!|1Ff?v0Ph4qPSbWJ-} zwY=1`y=`rlNlVR*&kg|E9eaghKMLu-9e=?=d_D22#2zTp*H!-jgnv!&;=rkWE3eyi zN1Qnk!vgMh3K5GPykmo0H28NPTbWdYi;nH+t@vEhrx0N%{_#mky^?9a!(-|<_@{2Z zG#VaArFnUF zAM06#96Pn;E#GCS>Shv+95ki!)6C?*;Fdlgw3k-#7QKC_TUq}AYupQ4D<*%QSIazw zRAVEC9_Nbt(KIm(Vi_7T#y>YPrrZJ0AI$z$`ET%_#dd!VzAS0F-Hqq=ot}aqxce%z zTYx}v{{T1#Zccf}8R$r_x%@@@P-gpvD0La&(h$t zj(317yh>G2mtoT+u1^@|v_EH$g3a+$U5~_?g^V__z0TydLs@N@=G`qsw*~ydFaaC6#6A4ri%<3XPIkl5Be(Zc);WjVN+Qu z-&)1B<8LJSh=6>+oOKmqTT`fLJCZ^g;Wsmd#_{QzucYa5rl8YB8xJ9OyJv=60seDe z9V#%DC(onqkD!e8_H7gB@<(m)Q$?EIG(`J#up7qaTuQhg*Xg`t1CL(y`I~QGx1KEU z^wFavtEj^MS^La9#jn#OIbZRw+B_lmaJ843{Ey0fOxCpzNBN`1zi&&6dwn5J9QW3AjP`+>U!i_)Vt8;4NvS)GXTaJ7VcLjE96i3 zDM!TnKZ0KjX7Ej{cN1&gHbEuqYbTi{y^==7os&3i&H&>W3twGIx_^fJLT@iM3s|GI zwJ|ECz}>N>GP^O(NC)y5^smmZ_$YP#zr`Qg_Ty6U{-H03uNzH*%4>B}#s-gg5urSQ z7aZfC!mlQcD%I&mO8d2Kv`=@j>u1@N>EY)}i};$qbgX!9!M}-KAowrh3w?Ia#DCcL zmV{drf%7-WazQ6K&PU~6QfMEx?~XreFN^;G5ByoH__S$9;ey+(VdA!Qdr~r%ky=)=O=4)9yZJBErg8xW;!$ z-Cp{w52CKWVoh`Qiuk4AJzK<X#N9zo(MqB^#peSDU3Ea!6DZ*P|w zcwq&)B>E3Nfiz_>O%8;wAKx zLwR_%*1E*ZrZ@|3I~iA(1gHRjG4Ee=!Qz-TO*+={*{x-@xQS&)ft8iifN{|BM;Onw zeo%Z<)1&adu?4-Pc5_FiphZZOJpN$rdvG#(*QIBzk9yH1i#%r*Cy1l(#oKGrm3Dh{qZ2UoU^bW;AIv z8*c$!NZ(Bf4g%g*q`QZ*ioEx#e8l-mApQGMvftBpPsrqH0`nOX4uLvE1ao2 z^LE_%3-+A-qI@snuZViKyt7^E3#{q|qg$(Ghi$sax-zJpxC885`@p}H%%~PxJ z+vatbc4JTQ)RSCXUR&JWMd!gBv-yhu060>4JRE`Z4!Ep4e-hZ<_;Tdm$uzBS_>hRg z$c?n)ACF$Om*v~(kWVGxKWZ_o&=CIsb+USn;Na)dvF@)V)@5lNHr^ev(Mjrm6U}?n zFtMpVS#{gx=5XRMGK9U}q&K>Q-D51nZSC7~ zBgW?r4%PYs%kG;HP>r-ey7#Oicz$CSspNhU-&tw>o=BmHoB*Wb&Iup zxGJ`QA!cO*l14c7uXy->;(b5G{uch;)9qp~WhOFiJBqpvcH;x_?_7VxOZ$(38oq~h zapAxCNwl>{*&lx4Yz|gdbQlfSBl=f)`#yLjbRUK758Ii=)Q;vb0wGs&$Z?#4szArJ zXNi3DJ*r;suB^HdsX}t4P29!!lW%e2EfKW~9UeQ_Y~zw^kqkS$wBvRTexIFky03_S z5Bw?chLPeA8$seNV^~X~Ewn4RHrl#ISQED+%>eED*-&%ABDQ`lcph&IXc5h(T}NXY zOD<*3(NTlW{m^wi`CH0DkHNyMe(t zCb}|pPP|%+w|C{(;LK%ND}Cp4@f-dM4dRV^T(>?QSo|ZW*~{kJ3pkU`jO{7Ie+Xsa zfFIJnvi|^rWPaG+4`RBLLAvnAhkP@vNwnOg7Ra$&Hs%6Y8)<)%GyBgmiO)_e;13S` z64g9)c9UuGo9kFelJaX@y4fc%_1h~h#v>RF-xa44WVjq=23Sg(UuPZsu9oZ6tdA9m!eXOd3h8ft$L{Ba zB(>BnZOmpCid+cUc~?{4ckuu`LBqau{3^f1uN>*04D=g~b5@4o?3Pg&Hc@$v8#3j0 z0)A1)ulQGjf5AQeHva&^-1uh-&8JVJ&vMhxErXMr78eeW9$4pxR^a2ceKn~3HqyLv zq0bC*ERA-dU9dMJC?N63`d9Lwe7e?}Q(= zZ-@MQ>Cj%P3)$Hj{>>p&WpG#T0#%jr2>FG5Icaj% z%e%`}`XUuJv>@akwD6M{o$V`Da<3f;>T zQ+%FoILPb|C)1kP`1SCENcgkim$$sPwSw)HT`!bvf0K}*y#o%O-t~#_Ye2p5^_BEj zQ%JLmzb_IvZ#j01k&ZF=3iR`=|(dVY?R z=#Y@hR}8~7^8UgwQKuz!rS0aAdR64-&l983=V|;eCfD-T#3%vB%AlTdDgGK{ zxwyY!HeqP|zc2oJgN(5D6<1E4>w3o0tZ>{Zl0}oTJY@S-%?9qu!&KbTEON)pi;b)P z?mKkP<6WN8DvYHKY`XS22-`j;Txd#GOC-ertPKMv>(!VY* zwRP0|cc$F!eb(1j{!=6E&SgXWKRW$KMZ%A})7HOY@Q%Gn(8NKl6qn?FSK?G{T7~;R z%^W}N6XE3XN5Xw-&BNrvB#bMSJJEwLBis%;9+mlLjrDJ*FeHh87 zTEi56T*rkwnU3xW&IUo}(!V-x?e!}w^zkS5#nW7B3e$Ot<|(&@7&%eXoO*FzuiiKK zZvOxi{{UuBA9xn(Z%s5hvplaAjmOG( zbnLeNy$_MJX!Q>ic(E;&n)2c{kL?dN&zQkUIV6+M%uZXsHT7rg(W1fe1H)b<@XRt> z%8_h%t{V!mOown(Bxju9{{Z!?-Y$ROoxTVCpMPWD5qwD0F11~9O<8TA(e(SY4A8u! zKILFXIAh07E2i*I!A}h6dJl$V({xKc9`er4C7s|VGji;|F>&389Gc6GxS9~9^=dKq zx_);IgCwaby7GSNZEL>Dqyx(=8*^HHctyV++aUtn-1nIsP1a`q!drI(#}k#A_d&G>UMNFB?Z+*QFZu z_u2G&e>i!KaVd`lcs!HW@vp~fW;K>3IFyoUr)0GGTVw93;c3omdfXW$?clgV5I-;&0iT zz`iH(JeroH;j1|{i`GGNa2Uqd3K3KmE03E1bKA9kiT?l*I*k7SbBLj3-Y)n1&yB6Y zHCD80b6>uVe9`mIgT4#tTAsNA>J6xA_V$-{&`LFnMVix7Wj{BNgc2nYEA3;?TvwES z+FuBCmH0hx;sYddX?KywsKsd&%+ahdOC+fh2qbLe$AWg_fHFH**n0l}?2BtHx7xfr z;?Ed(cH$?*Q^BcyrUg4#Di|gXHv6d!j+o}WSN4|h4~VaPAE;@66ZAImH-#h8&73;! zhYn}7x>*E}+}&HBnJ@!K`>c#P&r157yNL4qbyz4=<(k_1Ka>9ef_OQmL65^yZld>m zuKx6Xee0ePkH-H12(4^1`)kQAH1=eT+!a{ElEHf8j(<9)bbSX*@O8b!w)d>ip)sQ( zt+mafY2fll`g|*TQ}zxBDHnoce9a1=2XnO?2B20kM(e+Dsmf+k- zRJv{-d2?P(@Q2~aNxqMD)wFw~mPX3TRg`3~#tum7=xg1yU0%+A5nM%fu*EcR zlCYHA5O(LUK9%Cury6tThh5F-#?iEySK{A;uCKftnxt{I^gm2%SJAADs|!NmP$Wvuw%eEGKU0pKtC>C$(!2}e zdmD`+#^OyoMYR#mAA1SfLAjjr2JOW37#w@okIzxklw`f#_Bv`bq^fOx#UGD0pYV|Q zS(f5>Z6b?LxVX4kCRI{QE>Ac=G7>Np;PQPdCcM6#5&)8xYem}5$f`RXq-WFt z-n>8IhwS6v-G5YTjcZoaY_!O+EUOf;$0TtnJO2RI4p=D+7aZ&&y^O9{dU$uv-JZH9 zb@(1`JKM&(l|8)w0Lh<3e#*WB&^#HU>s})8zlt@6cdt6f=j+w#>#b7~l=c4;cjFx&Hu#KN7Wl8^`un8WLIfZ^V{z ztm`UjF3Bp1EBg>;+`SKa>UcGZ;2%!5>)) zvAxpR$FS2h-Fkb8^$W|W(*$`a;E3OG7(vtPp5)bUhCUfOZoeh5nro=$n5<1RZgg&{ zD#`Fm>37$YLMAPdDap#@_5)XM23L(5=as7G%&?H6F_t~Jugx!w9y5!tCi{7iW7fhZ^LlzGFdSEUKV=;;366rGGE~0JKN^6<6YajRu<(&7x@R&c;Wu`R4M( zukthck$_G(z^|as^17JJ!t~>NB)`h%lft>BJSHBZslMmvr26HprNzuMYtY4O6h{78 z3k76FW1Kcda0vN(mBo2BnpYhEDLBGt78((GcL0J%GEWb@b)AY=1mCxAH^9`*Jw?DybZH^y4jdRCun91~pX zYilG?zSZ83iC5f!wey*_NeUIIrz_7*Ed3ABvRuk_D9)nUtvmk!Rz27D`grZXW?zNg z502d}rd(RAw#8y5y1UYvTB-v?}H>}Orcl8V^CVbFa| zPbG`p3$|Y_IMjtjE%Po&9X}e*doL?bktU8g`7DYTEDD@g$W@}K^Ej;$(HSP_;c83X zlp2fc-1VtegkLQ=Cn8Nl#J7t-*=(%*!;A?GJ5(^}+3U`GRXaT@=U0hg)>aI%7^jXI zOGhkcB!l0R&IzrUd`$xB%C9lphhpv-!SBaz-<437@n^nm%(6!Ab}mY`PZ{acnk7z} z#KtZSdEamNAv%tx@21Cs{8s&fz7}hK2-aW2x(%m>d_f9JeQ$VfUO3)J_b?7O9vhX$ z4`E-L5o*@@w}|6rfnX~0c@xWRwoxLs(yDR7o}#}_e{KH&inrede`j5LT+{90xw=UO zf*G6$;PZnalknr8Tvz3fgYC?)SVgO8^RmYrmiF<)%&OA!A(XaA!S(1X`VYhWO-xQ6 zrCRkIl4(Z!>AL&2KON&LbSdIv8gaGOy_Nc&o#H(o#CqnXqQtTJvO}~LZdHCcQK>faK)4Qn5UbiI36@r|oo9X1ODdDaIkf>x1$A!7-@4A>dRJ*!*xd>7s& z)@*z;Yjq5I?zm-&z{ePPCXNnR&cyEdPC9UXYwT}_pR=}`bEMndOxlIMp?hSoO?zRe z-dHoMWbN9H2UCoCA8Pb-o+`rSn3&bCtF_Wwzdu8dE68eQG#|3tzxC*O&%w|5CnP@( zyft8R3@gj$`QiL76?f80beJ}c@UB$*# z60%xZ{%7Yu!V9fotmcMoo>D}N(kz9QTL5xDOncW|@eARewW)Y=+f>rEXyd+yhK6ole1HNn7ek%UK*Pao)hV#JpK5R3%lIF^N!)pZtYPNbD6Zq8~L*QSFd@JC+ z_mIr8d2H?@Y~*7+W0G@T=3m8F3{@FUjYQIGdv)Dw{{RelRkA!q3&x~d^gea?OQ?9C z$3Gn|+gjA&iq}rMXss-xoT^Ir3U?stLN*81zLwSPe#xXtwtr*PwB0>-PlxheG(AhR z9;>wD>0ch}`X`9K5PV#>@fM+@-|2RDu?d7x+PsZ*A!H9cz12df%He-Ltb2cAuRMLaxlsuP7zbwys<9t}#>XeSv*UYG25 z9y-)C9~~Jkq_w)!V7gnIYYXUZqm#{b$VUpN?u?$ijl#V5;-BpE;zY8wit60nY7tAe zOPL)ZhstzF;v8omEM$Y%kV(ns%r4y9>#X)b*63-Mu25}#~X=OU7)UefJJ)}!^;m9Do;hfnmB3N zwCz>uQTSz|c*_3(#M+E^kjrg*sMy;joO4@wim_Nh<*H#u(0sWG1-T=U(A1wEJ|5oq zPUpiHO>mbQPNi=Z+%ww7(X?^r3>ddOH$m-RS9qi2&X1s(JVB$`!J|zbm{+(7wmX#A zRPO7^2N|!J{73OWUihD`={_;FHgkAp&h-VYn3zts@mk2!-6Vr<23kZ=NWncbUq2jG zDzbGYrN6JgvFq2R8c9=|z5es1_(S1~FCFW;6{|-jL{{RrucZW-RI8qCNBv7OY zybN5vG4nAv0E~1$O8p1XG)XM)BV~zWRZY)A6huIL)2!m$!GAazP(9eJh}8<~D((wVTYz3hqYG za5+Ezs>#u> zmKOQQ+E*AHepJ4%9(gAhrokC<{AQ$gYh!;TibF>fLMGgKEz}Y0E1z4a<EwBwwEO;N8D#TeN{gb>uMb?T^nD z`jr~uBpZp^Hk0d)EAorPTIbvEVe(}1VGv5O!c*klK`co6fJf55R^pMGC2l@o4Su=d zOko;0hP1qvhvq&aPnx9!yG!#&KkQ7qz59b>^ zWpid7zqA^EW7wvg*Sf zLF->N{{X=%67#?w3AoX9Z8qZ0JC=DZUM1ZF@Li0L{5zC01?!b3ka1sJe#F{u{3gE; zEOiFA5yu70w$QhkpDi~B&OVGeuaV7W<(fuvpH-(#HT@6Mvafe3s`^GhrJ2|M+S-WL z{{U`F`+J!e$+gm~3En%;GD{AZ;j38QIAs0OgaJcirzQH1qZQm}+N=^n(mL)|0F_>c9dqx<>0gp@ z1Zk|!Rg>#?+g5#M1;!b53U8zr`I74p{{Ur09L&+PTo4L@{;M6&}aBkcDMEn4Am6rf-+E1%Hfzdrsmd>p;; zw~yw!_=l|BYLn@x*7mx5ku))CO2n6RAOW*r0|B#;Xv*^Xc&tYzD#|Hzz22vpm}4tp z<9czseO{J2@7cq|8c*%j@xNBp{8^yh%VVX*CcU8Cncr;i2lA2`A|sHe_kd$@$;rU4 zt$aUyFT@)!?MUvLNLgAysDP~4BnHnXp1AfI!LKOz3FE&DjU_Z+4OrW0I>w(Ib}eNb zvU$;4va!gKj!K*{$83E~ddGz>ZM0idUEgK-7X=&T-~c~NRdEmRW}(T+2<*K-%-1E4 zEK|QUp6~Se31Q*M46c&0N4Mt%xFr2Ln!l>)7TSE0J?zY^6e9ti3F(j1rBv|^?3bFn zQZa8QAgsL<83RM@)CkdQa^O@h)$TKeV>1d2uEFt);c}>!#dH zj$?v1z&m4Zgn&8pt$jDbHh&4cGkvG&dc@7*tzz!l4LeU*KzZf5bwzN*W93rJNX`P+ z^j`>Nk1fjMO02(mE6Q40*IRwtpNw%9FY9>EXti1-zGuJx0Kq-J8?=5R*L+=X8(v!4 z%AOz6CWz!lk?%Kdr1A4P+Di4?-oC`}%XMfvNj<|KG30+%XJ*OAO6()mjj~JUBO}~rByp41r%L=&h_eXe zvnsWw^mkDI01CCAk@~(zgLpg~DD~=gekaq$qPD(L!wiic94QQ?h~(qBtzQFpV%En? zTSmQkogDy}4Po?Fn7isC#cs7Wo%cN0Y-#;LR@<+<;VY}Z|Fbq4=`{OwGkvsql6fpz zRE!*|f!MJGeg>-n7}8(7dM9H!$-8o^o*(;Ud_OwJ#D9l(ntawV*k5WAXt$QKLS9RY zX`^F~GNkovwgzx|;=Vimss8|B9ZTSEiLWoTcyG0T3EV)lT+a68t}TMU1(kVhb|aqE z^{4Fx@iRuX_}}o~!&jP?pQP$iJeq{?T}pR_3=ts0;}SPG=c1l7Uq%f^-&fP&hwOXq zf)J~nqA*Fscx;o&<0SBF?XVnCjme>ha*g98^mq4{Gq=xk=W$ss4>9MAw3Jl0F~l?Gvkj`Ca!Q_f{}pzr4q<*{5~{{WF+hhlNnF*PMpqSpI) zey8c!dNHF%*{zbXX75TDcZ{avF@P1eHqb~NeZLyXz3@e>dW=_4-3ZK0wFU{``r@P1 zFK#rb?j(h6=7h^1AU5U$InO@-0HsH3uiV?g=S<#oIQzSP@Wp9Ts&P~Df1xUy=xK$G zY_C4gBr&46h4L`_`UU{=g19;SDY5ETT6740+>=If8*hG`ex|tnN5r>Jt0KbQWHT-# zbxiF8v8nAfg|xNZ9Klt|42qli_0MDNS^M#)yIX75;l za|S-6pkJj{*1SH=s)-@;x5ERVT^{P`rS5!>oGttE<#6_ZJ|_x}LF1sNygPidyZbdj>MtkJPm`M3uias6uS zkw%e-W5Y_P=L`Hz-2OGf>Dsl;wyd%*${G^BZ}3Lo_x8sXxZX56bW3qHyGJtspt}ST zA7hX6nwk=dpEUJu;ZiO>>`|ZNZ7NHBC(8k?rivZq*x`;lXZqLcZMGKPKlesB_OHk< z7i#0gk`Zx@Iok=DA8=(S1mn}1{dm(R^K{#G+&*ZO=huN>vv^rjQm0-mbz7@T!J< zzS&fu?LX5?oV*Pxe%5sAY}ev^P5T=7rrkedPl;Nu+2fN|U0eG??kPsp7%qznDEp|p zus_+rucSZVpWhF(iw}+dAKeZ5L*jb|ltyLS9m^*F0GU6#oREEfUrO;SKZVvmvmLjO zv>2`~uk3tO1=KLio^;PMzU_dF02d(V*b4d!_9wBrx4p2u7g5~D;gT%uospwqk-fJ1 z08TOWuMZ1~ql(O_;vlaV6#X>if0_CfI&q_u;vp^e`8~fA#C|sEaq9m7@KfK47aDwX zeV;+Ii$`~f_QcR4!G&f1@R-O+_a8ysM|&i8O9{fuF9@!FUEKk%EYW;Hb$9;&1r_+& zB<#xRtY1ScE(!Aj&mJ9DXbOXX0U+Rxwd%T=`!(vdzzZ8LFMuT`tQQ+KgM<|4y6Ql!Z1F} z%&f9R{kFRD2d!Sz^#z&*2*Nvf&)zGJN$=jQSnCUHB3p>$k*&xj)M6*>G2gv--&(1M zyyqo!+uvJsDPh##ElpE7+aHAbPk_GE);dgjWS7kx&fCiqCej-$$EHv9tGY;<8#a#M z8GOcM4CE@Daly~w#d04W{{U+3Kj40=rAMl1I@Y&iu1vO;Nu`j+N!hlfXCRUDU}pz8 z^)=+b8a_Js)590Ke~B+7)_fPG!z__UXKeN^g;?7^9Ps9+0$O~(#O#kdQ{ibM{f+7v{GbH#{*zHivA&neKyyZ{?wi= z(fmgkO7dBtjy_2PbFdu%%|~hRR@2A73?T4N!W$c34e2-X82nk{ z`$r*f#0ljhMI(*#PRPR&@r4_K1e*1-z7v)Lr`GccD<^cew^grx%d0psd__EDeTEJW z&v$+NkC?8$1$dLdo+I(~wug1P4+H8l#dE1k51DKcKX}SAdV~zE&Q$P4eH-vA_Gh{A zH^#3F`1)z=buA{tL5}-NHtmlp;%%UB3^EGD7;(TPXSI6w!cPWiKM%YEZ=~qfxwVgc zGlpQ&-EUBhEXDJ_N5)Ax&jj$ZJ}FvDz&CrFYL?!A@Cb&BJoC7>W@#7qpdL5i05W@0J+!f1$8S3AdxqVN z{nP4wF;o8l!bK$VL-umBsw-{-_(wdSU&g9k>Xz}^wAa`6YRuVTJB46=qd2e5>CvxN zYH?cro=4cmH@i2RPX5JJhHHrCSqB+r$UTo7XPnm7pRCy1mXCLtCkg-@UoV-%it}TkFeA zymS4cFseju4(x;X0OK|HTra||wjol?VAVFXYpdA#ULvoH!{J=;no9cmA2RsQ<8H0{ zQhaIgPNm`f9_rp}>+LsI(FOdG%_f{Ji7SQxjf}!f;NXHYUbsFfd|dFCfMfHt--uIq zk5RY+`h7FR7O8D^9Bm#*NcYAgW^4h=af9japR&jN5_{t{hl=j?Z3n_W1N$Atv=;Ec zi06B0pfR%_mK-rCHxbvbTKZvs;F6ykyj^!=Hl?h58_;g7(P56}-$|Bs^22T<@$(Q# z;~eI{R;`1>)0H|HdQHZfTdlfj`k$Xrrz+GZJVqf$ncsUoyCd-?`{A9h#&3ylrKgJF zv9`WZa@RWCZ)+L4kr2(iGnuX!6$`okY@7kNr>kl|vv-30H=)kD$HfgF!g`;Z9;c>U z-GbJ$PXY5HZHhr)N}%b$_Z9jd`!W9j!MM#Vw_YIeXT49Hq{?DHjH28I$e${U@_;igQReK9V zYcP^2z=>p820NVfuY{NOezov_impGlE;PICN;t%EMwT&3@cbb;fTI{rIwz}Jza!|~Pvd@v09Cx!8Mh~!7*Ih&oM#xSn!oKU;XMb# z@=0~8$EjUH8_R+-iIImL``6f7H~bQp;3Qu={6B+KX&+>g>Q|R{4exrB$&wu@U%dDdc za(C}#q510{?RDYH8@pdImP2xh`-PG}F3pK$Ekqr`Ug`0LhQlkUpInPXw zQD3Q+nl;3*w6WaCBpC^kL1I0-^H}lTWb4r;B_!Ap03fN!F)Emg78?I~!onJu)~Q zvst%4@JfG+SCFLBc++@=`$=>66xoIRuRQVEt?N%SI^r zJ)bZVg`6)be;nGL?4)7Vf$pZ>{`2(|muT+3SsG zWd-}(8REN|Ex>)!7hyO*G3Pyb^%dWK&7KyM!@9SLEF6hi-u@|LWrxYPmwb6`owB)6 z^*mS3?fY%`A5PS-uB|nzi>){1mdRqXXI1&YXq-v6W=00%(>eC8qr<=OR1XPw8%)%- zn=c0I`mB=qV%pjpq>V^c^BZp9!;IkPp4|5NJg?b8q-RkjFZc)Q8IDtqtAzGcVJl0c z)6c2lzAL$gKlm!=#+6wil08c4?x9vC3$XHIeB-*eT>96hc!K`s{xthUJ5fL@w}w&Z z2wlyO{DlA!m+eTv2O2{;z+sYaELxh{A zg{2%#RT)MrM%B|^mpd&l%o4BKm&~(|0zHRujGylw{{W3#@s5dbvBz|=w03YDNMwDO z^yBiY9vc4uf>nHO@icC}Ecn@>=yvx~Vnooic9v*{GEQYG3FtTj8O?e(f`8zR-VL^t z-FU0w2gFYnUCe`YbLwAaww1aJM;1!-=uhQY;k-V_)%T~%ua{2mL(HwjnOzI}N~!76 z@AN)(_`l$-2gY6zhx=Y@J4Lg&7goB0OcH1&S4T;s$rxqWeMdby5yn}5!XNG`eY}kH^0r_;cZXw~oKEZ2U8> zNHjQB8?7GAZX~g{FgV_mTQA>bZ{T3Vzk2mg7X?oR?-Pv5& zX~N0`iMUeGxKAsN&bb{3&3?9N_Lov_jiH7!aWRuVQVAq;l1mZEJohw7phGtz72}UKd}&V){_BTOe|h;wrT)o(wWhg0nGeFP z6T-I2-e`SeU$li>9taB?Gj1J5K(AJx{t0jK3&umk{{U{+V$g30S+#;LS5aHBjjFM1 zL@_8Qc3GI3{TI`;=ydVBJ6}DxV_2bRZnYbT-g!q61M;XK^PhZm zuToqWl}1r?=_cFTdt14jxbqz=IP%5%bUp^rf8dHX_K^LTTKK8r&1qXQGTB>4Z{;1w z_hFfsu==ilr&Hkn0Qe#o!QCfNisMrM0EFknw~G@7hW`LuXb_GE-Ha}O6WsA%VqbVN z^IpEVYss!H^p&_DWfpf&42R|9w%#+l2M4cuo5Ow^pHG=aq_+2(To(CBrk58M?xc^L z^yGg&mDL=h2Sz^aIIog8@f=y0)r*X4JHFZ;3-G)C361bC;rhC1e+#@fG^|6rUc9%2 zu<69Ez#clTIp)1{M))(}&xZO<!V;1RQ856kSRB@i9b+1CdTU&OOZS8i+ zk|T)}e)H3xUY^|6E~BXGo+I&f(T>U};nCB|wsyE5=`-@I#0~-{uRN2UmEHANMv{wG z{{Vt{weyT@qaJ#V>8iUc_+P`VV_>Z(lMTj9ud#;Qka7ZycEHEAdB5#dp?Iss{ua^v zHQ{|?{{X|9w~DSB7S$}8684ZPZk9(9hyER^r9e4h>6+=hWvXl1rLWkmG}){-LK;WU5mC;TTgZGlM z+UYaRt2)(VerA#AnoozJ(JkY(^5lQBN3b~nNgy5E_ayRazO&-(JHuB7Jx=1v&6(CH zR$H`JWsVy-}V|@dv)^yu8xyQsi7u36Os8X)50+jaXz5L0%1gG2!R%$L&k|I^IIIQ)=ED@h+V3EH^E& zr32+tj-c*d4`My4=fI60!#@jr38#2-!+I3jHLi~shFfcEVH9gG@w1RH&c~hK#7=wj z7+gm`EGN@ogngU$Pn&)p{{X=YJO!d8(|D&>)I3jp%0Viqw%ZvUaJ^1*$PORljjSIv#g#&O3=NZO0>&<%yi2Q%2ctcK` zQn<6cv$fOVXVvEN!rrQJj%i^>OO;K0?0W{x#4pq32p#$+e@B zMF+|wC3xf!P};=|w(`2j%_PJp?>h0%;Z9hhF)z%%b1?`*mLOxf&mWCvDl>Cd-*#g! zk3rU+AYuE(RJT_`!!cy>)c!nEuWu!}Qe@i{j)x%u$UQmq>-DNDdbaZIWsy9fR52fX zeNWc3Bk>))s+Ge_SuMNv(s4K&#+)Ubxl7fZd zXv&~m3~)L#iCiyghWl1;LxBLpbt(?5?D;y>_FPmebL0JHDxXXB3%TltXc7g~F1Hbw=^ zkX&5mC16KQp|@w+y;o52CY#}ZEhd>ETgevQ2;_VjTO;ILo?F;+UlM=8Z@gJHtD*cc z@Ei`&wzI6I)z_5YD@MEp_a87ng?SmyA{1!4>*D+mV})=p6PVMXKZ5%D9%uVAcy`~y z{{RCt*p@#sX0|g+8yv|aoE+oafr|Z#@V@5o-kx)e+eQBXpbGr7@Ew(|ou$O~yIFMf zREhQv-m-C%>N=mOuhgFlNhJOl(`0P!dsvyV>QED2H-|sGo+hl7oSJ8&;ymN+DX8tK za{eIEup>Nvb@AWrC;K+|N8;DU_187u4(XmEy0Nv6)e_{85lL(TkL6!tYAfbLCeze# zeJkbPAACjdj)mj9oj+3Xe};5@Y8PBXd8ow=)c*k8D;Z$K5^=yiYwU9jJu0}HSK8fA zkjd*+r#sQVH*~+xh0w{uBH>(eJL{R+8Z~Xl9V^B)9J*R$zGtJ#s#n z>AXMtG3ZwMznKP$q3P`$F~t?Vt&D1j#gFec0fUjo20HrzUW=x9)503evqxv*4Ifav zh!>S$xK{Fr87qb7J*$uSouf_Ti!`x|#rN7xqN!`iG>M>&J zx=D7rTIhP;T*lRIl%(Ztb#COEPltRJXRZx5O_N2`FXMOfEuvW3Si&5g$_C{FfPKy@ zaV~$eXI8bhw}D}1!l2=I3=lnduaEx#V$axX!oLrHYS}zV;eUz#Hu1Ihhiv7wzin4+ zQHHp0nlCgG(y$6LpzIjN2GNhWkAd3aWRXSJojGEOmEBhn74JK4*8Mtsj%cQ(d24?sr|4X;*CW&%<}(~@e(_9Tq=3ISKD9s1sU#}7--l^S8Qxf&3)%C9Dua1=#&ZZ4ve=27MC$KmbkPfa>IYoiyFad#Bal8?CUEz_rO>s?f9 zKaN)W_Ar!TDRay4K3M&;{{Uw{iQlxIp=BR~{wGJ^9TvVwx_?L04e$F4WKDFW#t6g}89|-s=NLNml<;Bjec^N`(1pK*>Mx+39l5@^0+J}Vg z;)8k#A(lkQ#mCBb9)uJ4{{Wm;XNw|#2KYl%y0?xitE+uB>T7#hk!BY%h?^0WEAs|8 z9r2pa9Vt_%%(t`HsM+15eqR2~U+`ERJI5au{4L?#R!dzegU1;B`&ZyU#eaq#De!NM zwQK9I4Qd`5)Gy*X&AMOM$*0JWtfXQWWd~_0#{;jw74@g=TmJwBMe%pT$-FP&-ySmS zz7*5tNoQ+|FFyF&9EhNe$q1yCQVw#z?Bv!eN>ur6bgk9z`JLXz$xg@W2gDzaUK{*t zX0p9fH(a1379b3%EF0xKiu{V!zha+=pSFMPM+x%-a*N2_?7`(n>i)HFR%$K^>hM0qkMN##ydo^hNC z+lNt-@#M9J*6g%zhVvI_BiQsOIO3Rf2omS_mhs0oB&5udl?#A-{XaUG`!9*exY2U= zQYidI@GAJT;r6ki>iXo%ruc#>Y%ku{F$AAxb!g;J>M{$4``F^NybWy+g*-QBrrX-L z_I{CV6}7aP&QTN~`j7yk!^Ji?&~1}ZwDQ@uvMlkj!(?EL=aQ#^lf_hn?EW3yWvAb1 z_p7+b(93u?!nytx&wiPrg?YI(&Az^db#N`MFMmUQ{7D_n(X}wD%m^yvWmI>o7gq~o z3#4qnFee|c70Zti#bXrC_nL3nbB1=FEXGL5T<56IPPEtYj+1G9Zx!aXrTKGn6D{;A z?me=@rFv14wfU9P@2g&iEgU^XJ4z4kW|hx~8qVt1&bFE1yZLu7A{=)=pS4=K))v~& z%NS_blNtN0*(35i8s;y2cca0%hPj}Qemd}UaP1@an%%Oh=RF4e z_2=qpxc#!e5LsWix%j8x3ssG_W?gCl=Ht+A?eFRI{{UwFIU^ApN!pk6zM9T&k!5QygoPW1`dWVjG;G_CBlIc9!e}wF<<+7P# zI09HhG<)rRsR45g#DZ}Iip*Di8^kpJd%h+dUdVE#HdOC09Zgzx<8j7 zbgzuAziE9R_O1T_f`jW{@P)3a;%N2DBld8XhTVsdBTq1_m?(fY+&h-bqJQtSZaMfyJXHBT6?^DLlledAMKj=vc{X`c`L0-EZ>$C2M?x`mP%ZfvDbK17JC?R-R7 zHdP7Q2|S#EU$8$8uVcITf1yl_Jh`;%c@Um53`$2@_INK3?43(Rr0u^`=Q#OZ6kFe4 z(A&9P*7)_!e5dBaB5~Uw^rk})&z6)%@kL;m_bVr zuhAbJm*Of?ZPRu%BKU8j*jd}d4UAI0QxrSf)TY$jPFL&NOM%_?JZf`yQPP%M!0OEkjKI0D@Kg zCDpto{f{#EwtX@=3Pl9E!?8Hv5=iW$0Cvx)6^=jPjQ;=!tmpexo#&5iZta7J)iu4* zx{-h!6gz-k6iLvweq@JJ3 zzsSMX{{Y~cU$h2^cjjq7wEqCaTR5bN)P18?mggh^&jc|0?a6Jx?#2c?44+zqN&SI; zY7I8ZIi&rmJWR;NgqHAIW+@boFj>PU+>8O9fE@Eh6iWpOD>nY~O8iZVNjC~_POy*p z7jJ)MZ`!U+Rm6X^HQcs#+{GK&M$@hb8FmD99Y6hbQrGre{j)AS!E5_OhkJQ^(z>mL z5=Y3!#Sz9%IpZg;IULbNcVaNGr%Lf}qTI?{MP6`=bm{z{jtAn0{1Wr_sQsdJ`&}2t z{{S6)DdHy0x3Ut%FN&eKt zbMoC=(SMPZVgCRG?X~k(TmJxyKOA)>cil2w__j$2e!)l|E=D=Q$G1w=)PLZb{t?#< zV^Hyj#!0-txflE=`iyE6bXM}*1|53w&uS>Eqec>!yoD^&6M0LREx+KE9}37iyQP*S-g~$O|5O~jg^{m)` z;E`Vfbp2vWsQwvgcM>}1f7|X+iR;=?$G1)>qKcGZOkuMPG zmj)d+;D8lllHGDP@!J4mzEb_0z8>FrBlf)i0EH;u41O2*>rn8`@7Xn!*7RFxWVb>V zMk4Kw`3M_!<2<%W;}lU&loOvL*xsfhtv3Cfn!4!O`fa8E0KqrD3)sZ3C&90VcJoOJ z`H>A485S1%-zg&*2iqK0&Y$}Rd?kj{?CJ2Q;oQi|h=1apvQ{di9G@?o<1|rE4+U45 zOQD>z3`>VIoUf-><>qf#{>8r!VQ9^-!=C}$=aNQ;XfQF1V`dMfK0nz*z|peB4}d%# zsH7ED((Tbx9)VO*MRY1Ka&fknG5*xWb4%ZUd;b73zP0-}cp5qFZDsI}hVG=1$yl3A zh;TAaTwvgPSB-woKd|S-?}nPDsdE>J{2!*rC6(p1j5c~Vn6|S`b0o3M;g}f;#$rrv z#v|`V6^%SCKW(D1xAw*^oFx@D)7>`h_ZPfT`wDzo)qWiKQ%kY&w}YYA^w^+_Th(qn zITXiRDJnb2>miVUR61=|0}_*rfNSkf59yOw=oU8f7~X4pc;a?N3arYY5(i<&1GN-a Wt$B#_`XxVyW%y9FEEU4s)KxH|-QcMa|k+)3~NK?4K`@+R3y_CEKV zci+4JytiO}Q>*){udA!8dU`?4)6&x?0Gf=1v;+VG0sQ5OKEN4TY{b?_NzyLsj+aSO{00>+F%ug8rsDZ%yBey_M z{lb9RAeevZB7<@7Ay|IOm|%Gv^7-GGHB;O2jM-Tnx+{02U@@P9A0^9%c?=W;PxsZXR|n0DvSg{kLt9 zPJvl}WUnyNUp9be6h`?=HU$7;0RT8KJ2Vpu3sWrAZ<`?lW1;`Z$#D>WVIUx(0CA9i zILqRp0f>KD2U!;n^QXNn9_|-E*kkeVzhp2l9^sda29^>3;CtgkfAaeS0Fa}(zwKdW zV#>t2BP4?ObUGI;18*AE~4c!~XZXJKOE|s0O!x&Spt>&ca}hv{4+}+A%4#css{4Us^X&jV*HnQfBI)O{}=K8kT46wK36t^BXz`N35a_K4-~2G!BUw%BV#po3PeDB z@e&&cmz;ugMj@=@k?l z68a`AJR&|JF)2AEH7z|azo4+FxTLhKuD+qMskx=Kt+%g#U~p)7WOQbBZhm2LX?bOH zYkOyRZ~ybb;rYea%d6{eH@A0BD**T(^AQAGdoSlkXuhiJALl0b-^+a9mIp3*v&4u`X$HzIna4DK_uXSj0VD%;!OR**HX$*kpd;Fpr|Am=$XEm81 z(zw*>HR+sCxRQdmFYOfRO;$zA=*mHX6UVJ!rCc{!V{zN-c`!#mZBJ@an9XGOBUS}Sn_M=g848{?nAt#)8X020?0Io)JXIJYcfV#Cv8|D;L~=bI`??6G_6z+8MCo< zR#M$=I1VK{9Vl9=r&NWx9S)01?bD7w&otvY8IyrW*)+4Re}AkH@JK^!29}=CVJO;5 zNk;<5ZT+ew)9?yE0v&-1F2`=3E%_89u;1{4c>-94#vu)xT?I}ILBgcV!7z6V9?z?$ zm?wZR)|)MAT)gpuh{#bU{gw&P7`M1ZK`hXBp5)f-y6ml62C~d<6~curATn|6Fw~{l2-<8$0198l3 zd_Nj#p*T~#^Q9CvOmYg7LaXy5QX$8{=oa#ng0a%Anu&a5clGmJ;AQGmHQ}rsRHGg>8*FGEB6o) zXZ3D&*IYR>D=6^`H{!iXkk;a?Fc0tW*dDDd1Dv9q*TS11KCa_Z_W2$1{5KD>Oq(&z zx`gDNoXtVhD*$2$&brC>Do@ zfQ^PnxTdUw&FHJeY1p>!8IVx*qFXB5&djpPpj)#}Oq`hFC2nrL9kmyUP#H}U9-iwN zzIpT#jTZ!}lbzXz89xCuJwaRqshAYyvYV&Rg@hi*L(Q(F{kNNSDz3dcO9U;Z?H2Jd z{J9=z3eqzjb}1vaI8F#jnbmVN>uJBYI;^+NIqHd$KuyYJD~{9L#yJGMax#g{Pmkv> zllZXiL8jh7N$do1Xnl}_ihpA-CwO`E@-dJ$^*q1Kim8faicn(n!k=f6$If6&e>#x) z0{BeSafBVS605S+(%}{Wkz)`9wNl?(;_WS#WE$eg7OA!?Wd}Fg=&pLmVzEIOttMng zrW6MxlqEVl(sC3t4u-Woj-T8w=X5r3t+uMSe#s&XRi_(j8l{;)g0l}r78G z5^cE&4o)FRBj#Bv;M0y-?a4iJp}`4#u`V-zzv1h=`ZxYQf^mD6!`}B5eKGH61vd>m z>>LQ;_k)(nxrwr^I;Z;D#d@0O8S{~rTI>(hp8%V%9U`aBAbOzkLWl2miJ$r`oQ1hF~EN1vYdvs**U5lqiKNQH)oaIrFDulx!Ynd0J`_j_zWExmq3RmaIa_ zA}bRlV%ztI(Vi8@qORkx8_fPg(3iZ{3of9GOIgf=CTOBqtf8C7`Rv05b>nUl1A)eI z5xy1^sU5sl#wnB4BemiPOL~|x%t#JY8)iNjqRF-xl(vHv@xgr3kZ3NzYUaZS?+0s) zIjI`Y1muoL)7U3K!&-WEi4O1DfXllLu0gUe2YtU~{use%T3XHQi5*L$ua%Y4PCHfj z93Fb;IR@(kPF}$(jpx$zdjUd>*s+G0?gTM2wT#A1rdsOjNXR+WjfM)VMX#tF;MPsP zDK6e$74?I06VF#^^bU4E)bPGBfKzB3 zUY7)N(`+UU%ATb7?g>>D1KdHr@G^6495%47G-Pih=U;%xd}Hu z%GxvGEI^M$-xy!IN}0I|ua|A?_#p17RwUgFt5OW(9Q*ps!6SQe*eEt@H8(QvW!WT$ z9rCm7Nh?`?iLH^P-Zid1>K!7P!^$gpUG@4G-_w~#j6q~I{4qFC?Dn!P%&Ymmrh!|6 zaid@3Y1IzzlDcO1PgMO)mssz+Ct(a~&(-pT+eE2FNAt4YyIl`$>o#@Kz0{Pp#$<^# zVbE9Q^@S^X-IY{r-o?J2hRM8ks7u~Ca;gbxW1(E<1`_aSB5kFe=xXR6o)5qiiUwrX z``ty^FIg>e;u^0E5G6Tdgs_Z!g*sJUK7hsBc4U%tBYz7KnXu7DHTd0cCefdmHwM~u zO6Jm*_zBRTJ==Yfa;wR|O>WOKDF<0n#|h;71aW_qpd^GiEg@Qjnin((}MHlLfD9Mq> zcVZ*NEkT^5KXomYSIaNT?anXkm>mZL>nmx}5Xm5{x4-j*k;mQj-Y6~5&J$L2Hd=__ z?uva+CpbmPAT?WQZF}cnrfiJa#ZX8&TY4u!9h6i+oG=@p$(9WtNgHpuYWCl=xIK&N^1b4;RH3{G#haFxsxKyQxn?Ux$Q*wvCj zNdKuRCm8un?*~PNPD%gfe7ASZwB0@1YWg zw+nd`lV@=Nmsp>@xF@5`{0`!u05pB)b*MFxBu*`4_=LHOY1XvFGj@gAMUX2mVzF#Y zFLAD`*(iL+^E47BZo`Ima-IPCE^QQ9EN>#DvJkMpBE~{RJD|8}oC_`|@=~Q>+ja^f zZdJ3%>&Sj>+gLu17Y&RS3lSvbD*TSDL@>A|6dd#wig4SRAA^p!CW(O?^G3+}#h=n6+*$;3jamiB!zm8Z_4x&qv2OABM)S(tSAWzpgplcIwG)eW~j7*4cj(OOz zC#G$O31Q_J3I{|Q8mAexX{~5-v&dY$HmgihdVg}1(_0>_q;R4rH;WU;+Hw9si>ES8 zJV%@f?v$(V58n$WPnVIK``S=h$)2c~fwJxlaJZ z=&?Tc^w4}qOsQ3@$GeNwOBDBYg1u2!LF zwFDQf7D_O`KnfLP;E6c6XGnQ&WzUpqVx|`tFZ>=#aKF%t4K&V0=5k+#8OCC3CsodI zW-Gd6_sMJS0u4wwbXoFsHR!~9ngd=hI_a|xg+jPFIBPD&>=r~0#iWuqp#M{A-%n{MPtWqoNS-Uxj zXVod0_qmX$atkUqUQ#~2DZD8+!uX4{B|3+=a8j0XRW91 zWL{`C@9MvfUF3)zWX#j<4^JJV;v!PMQFWz0v7K}K6&i3$h_EHVS za|H4a2&|8UK;aunh?`qC*RS;I%&=Rc;+XiqsTG)yde;|T#Zc5?X>X4i?{#Wk<2X1YpQS?eYAoo3S7nqu`~WvL#UMYlugt-9JXz|_8eOz5$AS30>V z$sX@x5E0dVca}Zb{qcM5qXC0x`7%`?v;&~*(J(Th=bm?gx8tSG-4Mf|eLdQCO#V^T*bG6?9j+nTm$SU0R5$&qZpk?JyUalA5S7* zG;^bZxCG@gwR|_#u`Jy|O?36CLCofM4(hVsBFRjt$K)j5=`%e^9(&=?0gtQ3*4gSS zT2}>_+zV+v2h>Ar;~h~{|3)&zmbiRGCq-llze+kQ6>|OM;UlT(@=;Oz9T)yX5+h`7 zropluKlptt?M6Iy}t*aluF$Dge{c5s$3AFO$G+IaNPr$Ty?Oh#~!62=Uj z99Z91aZKA^*(z+^7Dp_1Ocopu6Z>>DIleztYYZ#uoO>6zrO`q^a-ml%HnJV>IJ~bc&Qv2|1M})*IMCa&^+ia}4@BiNwzwP*qnt zMf>1^XOr2fgQQSc6kY8N-T|w&o>>FAonf|4g%5D6mhtGp-J*fRC~Siiq)cugR9p*J~5MDuKQoB8jywYG}3^%aISbg(J8i z55nOW_=+Y-(~26v{sP2_yQz1DrW=scpWcxT(GdA2MX5Vk^!<2%pAfYjAKbf@RlX;H zN&KN9Q=sk}ZnA9xkGqAgoK0Vf!FB0bTb~d4o>X|5UfGUv-LADQ4V`tm;`Nf$Uf+cf z2wCueMW|n^6t&OPAG^0?J2gi4R%CVkv-LvyAUOgCS!c2QBySuYi>x@lnJ7zm6le;z z^99VgR^2LqEY{(nux<9g7Dkaqg{$*jI5~^89+u}+YTmAn#C^qs_)4|&{IiuI&_2Tg z+HvNcw-nwU4Hv(p*Orem9jizBl%j%ogPHr3{KWeQ7%$W}Hy?3sC z+Fsn9&vex0pCh<39iWQ^c+fi(yarZTrY9$D={w8#vUiOjKXcj@rX=1EQqlInILq^G z*KIWr_(GzFIt1IG+S7*8F@iPMtjWSOtbKyL$FR;XvKahS>9i{M2pK4mX{N6rtY-ovdJp4=~tCbQe`+hcW&W2&1Sz<<-Ces$_ zo8ubpbg24NB6=OHI}`q+T8k%wBW12~b4^*dF(b_g-MEy^Rqy_RDhm!v6ot~2>V&Yp zS&cM*+iGueIUKEXcMGapB4SRx5A2r+y; zGPJ2#2g`L|8gjN~T=3N$nV^#fKh(I039$40!-scFQUKhjx9{|81F`RyzVKU5%)Dy8 z$x!dwOdKaweH@LRKlDPN&(q0#0-#>8J?N;tQB0P`Co=@4MFX&=gk6oC0=C)$g(TDl zm*LJUY~AjTMwl&E8k6!vhiE%vY5?bxyc!M9-+vIRGgxJNCuAeV&9&dDrO{~IKLy2& zO|EB-#U~dctf>rs=*Hq$i-wH5%I%j)!nq4jhDVGSK9~zIp!8|(dU5Wvc%59|N0hmt zG8T)2B%aLX708UH7Y1FTS52=gy z7KGFvBvH$kV_t0pDcHa8aFpcfY{mYJAULeSrw>+8uwmm4Sej}Y8YvM}JqtJNx`ei!zJtv3`)Wm|p35lw@w2*XIzA{_^f&JMUlFL}ES1f{cS`D$7 z;!TmF_jRd_p{3%eR_^^{*X91%M%YgBOeO20fYyY0&!#-=L3n~}6VXW(&4)9Ez;Vf? zVZyhx;FF~38ASk{4eo-_%Yg6<;Q53QA5xweXWn|~)cu3>BRlq)S*rfa$vId>Ww+Z$ zm*UOqDB9#?eG!o~qpss6^UqfE-GOXWMJ5CtW6?~_EK~yfC?k7~FMC%+sVEQ>ZScaG zp8)Svh>$Jbqh7*K5%o(ynq|BUZBYB(X4}p0lvcWZb`qK6F57@}8RFMlj4P0xwrYuZ zy$L8LAX4Bm0oh5N7^^&yf{t2gM*}1 zC2Hsw3@oKZJ|Kh7_2(hVZUb6p#fIp?UObDpsoVNx`fu^Q6BfFkmY+5ePAyzq9C#QR z?VK5mOzn+93?}xrj2=b~jLZy7i~v4C4+kR?Ymf`EG05D~j-T?hy_=HQ(v+VP{BXo1 z?;r}Yu$1z00;zf_sF`?Kn{b;_3JM_ddGL7HI@p3-jEFsKZS0(RJoqVplJkJ&AI*%E zV33ok8IOvX#4iZ=j-T?EE$;5_4DPH9_D<%E%-r1Ej7%(yEG+b31iiDTor{qNy`3`< z%z#MzlS2&TY~p0;;9_ZS_lwKnA6)-3FK`EVUaUWcC+1`PVZc9>|2skpdl!3W3wwt@ zCjKk!@9v*Y`=647UG<-7f3Lz1LSjBJj7QW7WaMJ+q-Jk#BOv@^wTP8|4^1p?1F{9# zxi}M>**g(yTH2Y~yF1enyMdgXe>jGPfr*dtzf%0qoB-CAm;diAwYB}}+MmYC%kzlY zo4Ee4Nm`8fhp(ML&i1ZOCZONJU}|q-Y0Bg9dj>pQY|N}iEJkefoW`t7^vukrCiL8< z?56a_W~?BP2{R`*E1MZ5G2{Pm?MzMn*zDlyWb-qsOidU;zuX00GUi`Ko0{;L{vAs{ zoc>pU3EQ~*y!q?S#_~s$@z@yIne$V6(1ZWeXyj_+LMiY^{420*EdL_-8D+m1_!$3f zi2q3xg1`kAf9CQp0r;cmPwWx%{kY|kwly*b{Uf}8I__7Wypc6X;Ad?FALCE(&xRjS z_Sc=gDac0P@2ZL4pr6fuP>N2L;I%dSk#_=W_73zmATyU=$bWYI!E332*WbqVC!Uk# z7xJHte~{Y$3Hi^)Kghp<>UNec0@8LS7Qcx9OV1w-ReLiRcOxgz|2}_yrpix0{K5Xw z7Bg}Isae{B1b!A6{3sP+CuU~mVddgsV*WMoAIKjiZBb)UCnbAl%O6v>v@;hFwWkN) zfJ}dB{9FGYO;LMW2Pcs8j|>)&Q4*IV25-*48-6MO6Y@vzpI-g{=W1~e2YY8%r@tJ= zOzaFYVfbbAzx4dET0+6q))?d@V8+P7!1W9HyX_C?x2lqpy}6T-t$?rt$jJpf2(g2c zrM;7-i|2nC;y3OOfwZckssjjQssfH&J9Fp1(_9tg_%CTL0Irw$54rpc_G_5`!1h-T ziGy=cju0N|Bn5CLjIWdJAA+?1pdG^viU30{)#0VW1kCSq}o{}Ld7 z_y19qH}d!$R)1--FmV6SB&ITBWMSZ>{>AWbu)hgFE+8jxbjgEvOaUb$7x10kZ>4{B z{ZWv!G`Db3{rTxApl)YvXYc+S|IeQP{CWMCtHCEJ@R^G7&&dmX^XKdavI8I3oPM0l zepdQGLO}fZ{dGaXK>fI2p}}=HaIoO=AMgtv0SN&f9uZs&gouKOh=dHj;1Hgnq98x} zAwTv0sPtKcp!-$n1FoX_tJ3FpoeuyS66!~t4+H?=cfkw*;`l~#>FYgQf z;S#1XP?~8EA6odi{gC`)&>Za`{Gr}kH>O9wUQmHz)a0mV$_9Cg6g4x7&qbHp=^eYY zzw7IE-Aa_=T~7mgKdrM%zXBrqx&C`&La8Zf=>3J$85#i(F|Rn)u{#Ntdi8yjXRqTD z@JMKqIIxekkJ&5i$Hl4^(dkBcl>j z<5N&)Jy}WuuE9TW06)maww?qd*y=tlFCix@m2v94n=4>m* zNA1@Rd%lcW=CO@3mr19kjIXI#m^Y{vm&>Az%(?326XqMDZES0sZ8&VpeC?vNi`4R$ zmc2m3`pED0Oic>&X>QHFhIBX5Pv+(s)_h{l1+ltwXr?(`*-q@-LB~5o~-1?g6qq#dT!>OzD-}MwZ*EE-uJ%k-?o;P~Q ze@0?d1n)B$8APX|hb5CHfyD=L!4a&+%~E@%eV6GCYz-VKiHl!oGY-m`V5K?Q>x8 zdX9U$btwmWW^K>C4!?#n=*p(u;FzrzyIS&#zjLd{;*0j1rhi>KvUgGsJogY)b_S{?8B&~@uo*d;-sihDM?n};z&bBW1vL#1e9{r%uzy?8E z4xMR-xT=+mG1Ae-SEkBeSo?V*C@7Z|or{cU!#-9tX|K9~|L^C~R|hHfmRUd6O}e#_ zNtKS}(~_>f$|L)$1g>j!{{0-Y_*6-G3JGB#SIf6>Y9pSeg`(Cx_tIBm2BBlkl>EL^ z{AhktSr!WEP1V!-GHF+I5u=!x29zXA`#YyI0?+V!VWVZk{J#q4=)lnS5)g>YUSy}U zJpq(ufVSZBK6*3&1T+LRBm~UQ-2uEeKtVu)%l`my=ors2v0jjn0-0IhUsA9Nle7KW zB_LoR({S+!t1Bc)tTQ9LI$89sFN>-qNj_OLHRijwBj_wUwrVrE^l4UAa~4%4?{Ypi zmz@?!lDMa1e`x<`fG+5VexiTIdntdET%rFZr;jsvsQdev(Y$-$od4@DdaS}5x&iPR zt)}M+>#~j!9H$%M_02LH!qP1%a39>CeUf+ES)(c1HDE|^u4+xmUYU3e;m%-Z7hVxp zw%AZ2jZa$s1hCY`!QOO_~^mb8pC7~ib=T(TvlCUr66mc1DpjJ+tCxX`G-?Y!2(epIzqvKU3{ z^5WJtfc4RQvTja&&U9bgt5ris{DVZWXCViTVv`6LH<@K<`FOq@9imZ+Hlau>L6u^o zdOZS47svF_ryN4DIQDRo1w&z`oUN%5Y<7)&KXxaIxCxW2z6P^=ee~fOuIPlK+R2$( zof+10uH$zx^~-8iG4mB<0ap_I9z6(y7CGbX^)9QyboHa2Ay|89g^gE5Qf(kTxo#FW z0h;exOh+}DR1Uxs2{B=eHsW&9y+zZDK}kfBl5u<2>V#~mG!1QTUP=qr0ZQq<2BM{f zvSyjf*O{kdJ8L^6>MM8+WQP#woq8Sbpeq+cRE1;DyV&D4dB;{&K-`$fR!1=Q_XC1y z^}K@tru9%1mO$(-kN#ko7+m>|mW#@hRROnIBf_=v?>)weuFCL>Nb&8!#4-HP$&f2{ zhi-*hoggt%QIX_RyCW8Q?I%ES+g3$ZW@i<-zFbM$u&uC=^11}a8&Ynz&o=5;=u^;U zx0n-ZMm(0LY(_4FWh3#GNT!U@oTo7)4;sLgC=4d*;8wUK%C8Pd2fB2D%ug<2F$rzMD_8`}b;dm`+Tt197XC7BqLBxM6r z@1#SFkBfKE0pc&@bma@*H5oNw;Km@y4v8b}N zfuY~hh*OU`lk`|w$l|F#U%)86()G&-wO4W;8Uy?%v}35<L5#gKZKta8>79GViePNW2MvTr=;UJYCemZX2;!J zUx(XznXtK2hSAqDd|~`j%T6ar_U%W+1tz&DXD+wygW1_uqG^{;dCiNI3GXjwWjekS zD2qjV=U>itjUQIe`g2r=qnnH?^HTsDcBN6G8T*q<3BT*M8!X>!$5p&l5gx5LyJKuj zbQ4)UjE}JUtmi=kF@x{fj8YxoPB(jEar!!=%t)GTnIOPaQ&w?gNBAYeI0pX8n7?)a z);)*hsRjxAwMN)I>&+9uRphhKw zJT9kNzfVWKZq59iOQBEmYxaUlu>)kT+*IQ1 zJC^Iy$PMp)oiGt6h3PJ`v1W{=PBl{)GrGmS{CZ_?&#J`{W5D*DfCOD8iUyXH8{%c6 zR!M})BL@Ys$?V7eTI{LaxpqY{H@x({3OrhfJKnb@{$l0P9itg)%R+TA7BsEQMc|EpYp8#3S%zI_GQAvqQu`!X5Jz|TqyJ7J zf|b)dg{qqI&AI;kwVQeh@=shSAK=+Cv#=Fh@1|c&^ABhym3<_vDvU?5TaYw~Iy$fH zVV$KMDt)cAi^wK=zk(&PO8M1s9xq~%jYC;2D3SYxH9q!05Pjvh?5Qp0xFyNl_EME* z3hotsNm{Xy?%kIDC%~4jv$idNww58fM10@$K2WNHHS|W?B~T%AM021xgK(xyux7CC z&Lj$oG|^ba+%c+b{e6rxDYWAYI&DidCkoPfrxV#{SaL<0x>BPEzJNrPVXF^hQV%}< zXbHm$?x{KBw>oRa_(nn&R+8u6V@yY7+N@b;rJgx*Qn%<4-x3sU?nLL-^?t*Sj9r#e z9!=glayA1Pk6d8aebKdxM;|{QvjRd6h5{=7c^+T*s*#g9(x#aiShHqa8YHy$|iFwcV85mACxn z%AFMkhAT(v5|VK(_O{u2!xKhl9MM4QO>rEjaeWove126`Qw@j4NK^TSsE*u|kyw93 zsy>aQU|jxD^?7G_;35`K1n(hq34M-y=-mx1eQt%JLPhdbl#KRdZkY+C?bYI#MtqBE zkfSC=@5Ho;}^=qp9*kdiaO-kuj2JaqkrgG^$&b z+T6lm8y|8XzC83O1!X9fdo_$J~NrGo>^G(6pBHJvT;thl!TG=1{clgRFSLY zO9{h>=vvtO@L<&UcukqnQy9eMZ^a6u>7&NU4A>;6P?=Np4O%Y=&X7c!>$G$MbG=a{ zN00?_1Iw4kPfQ>6^A~F3vZWJf9~E)Z{2`y~H^(}D;Xt7?gxD5pIX!#%}dgg7LX@m>575i>KOcI1_Iv}mlx z2h^7`8Kj|(MSzs~7g6cGv|g;=iu1QTAhTPfay6ARy0-#8ShONAu}Yv8_5JoquhT z)H!^ES8z9r!7sZr*-7ga;oQTk(Q+9KGNcfY{{q+VDrct1?jjr{FSZ#zePT{U(%?A# zK(x*m(WZEpbZ`0m*rwCkiGHp=y@j0C35V+veN147R@iAM4{m5xW)e=ontB_$GV{fV zLbM;_L?w}IbiY0Q(Dz~X*L<7^OP6Nt+smXVRj@bFzW6OQ(KuKyzPVqR#f|lx%W{+2 zWNbpI(8ZYfV47zb0?J=4dxeFmkb3IQmPZ>%u|Cc;cLnb^n&J&V0q&dCr8werr18`i zI$Jz4k2=<}ZH`AbLhQroPDIs^orLN$YWQuixyDAsrjMKpqT6AMBdg`RtG2|~5qYlHndS*-^6u zLS0^AnzPaK&hMreuk4v}$hs#y1JW2zYu?kFCsgI3d$WIjDZL$^5@7nm$ac|%>O$u; zlhc{0a8($Bitp&>qLJe*Nei?SiNqL(hUI5g7^)qWoEo*AD<8^rSXv9UyV#QLL1G<| z$0y_6DEPGZJi4mLtHY24$4ehp^I{&HBsx*GxFbPIxu2yiVNJjCf zuRy4)m&KglxUkf=(Q-RhJ@U-hPCI2JHZ+AB$SkZacjLXm*76VV<_P>?q4El%jX@cM z&t=@Cc6XSs5+zOyeY}PtNIPMPz@Z|0oN?vtyV_MnyUPqwQmnJro0K=Z4%5q!>a5sy z$eh3)lF2ov*9;`w)<}cZb+t1u_AYER#}l=cF_!C(Xh;ydknM8q>&lDMO6EG>;8cy3 z_p0s@wLFoe=7_#9Ij_F_ROPn2nLAkbv3T{F#;USd!$?p=BEmZ(%=wHA3x1(R2cwZt)Z55sWUTF&d0bwbk^3crox}#j6t3OapnivfPtzm* zc0NtbQe`WG1)?aKYs>^9HQPJsxWl)1xcr(1;Hufk?pjYq9YWbnXBx#ndW>)W72D!# z*hi4Frfi>X>&xSznsFT-i9B!y@6kN+z7c@&J*rOy(%H_215_xMm&_a-JL-B&Y{kE1 z2`_utYGv@oNcy@_zfxhY9V?6(t+AM(aaip2{D{jEB|VR|Lzm_(=fpt=#o=6lQx1w6 zOssvdDJIB6wrXLagF|z8cFSW-TD%(lpiV~mxznW*C+m|*!!L+9-DDQ zS*>1b&!4^B(iKnmEF{uLiyXX2^FZ6sJwAnBKZ0OdG{jtC z?;**v4h@Cqx?^JnL+?*tPSCfe(5nf&ERbyUm)dfHrptgASoeV~&>ADhnsS#Yz2hD{ z9Fo};>5nOLlgIMrduJLl6`RmZ6=N{~v;I!-WEGFfFq^t&yvbO*viB^k@%B~mvM!9O zxK&X)Wf|3?0jKp~*r|HSvh+cJQ$-S@9afvl(cF@Lw0YtdyNctX!1yLFuM}sq0lC*Z zUhCnoJsBKytlEb7xG3#BlaGA#i<3QwyyVsLzSasuRh!!rmYk}%W`*_?958MQ_{|O{ zP9B#?I+woQ-Dxk#^^NS30`OQmn}QOkuJdxI1(&VdT2{sNRN&0KXgTuyXq(&uJgAtdJO}y0b)=k$@)3N_#l;T<8=THo43>Yb6Zjz-G(cf z@Ww{Dnd3rjgA8F2mbbPvLVnkKxkKH{W$iA5NBT&NFIH4);r!=-aKV}s3F!Xm9IQ~8 z&d~McbQOq8^V9wLG2o|Daq!dWukW5gA;N;ccLo4|`3xNZ4gT)gb5dbt7#1T;5yyac zXh4;?T2|v8SS&I&QB|ku&9fKmCeDF%d3--ZdzYZ(=b%22 z$A`~TgqjXb^X3TvE3yYvyeP=Iapn7X^<1zJ9bA;q1gt2=r4fK1|H2UgK?t{LNhUq^ zYEv&8YZLzHJ2h;9&W()ZLxiCFW{*JfR=2sC--GD~n3!izfZO1O@H&y=^(J{KuVp_x zv&%b;rY`I~Nk8)-Ecz7>FZEMDNiii?dJki^Z#J>g1(j1KvX*l}cB!Ow*Zxy%Tca3& zQq3^_u?C?B8o|sHV(SW;?rN^;dvU*UqWchVfiV~FnVz*IZLt8>CXSwf@d5E@bhokk zOhIDWfhbgk8o82Up|n~~l!AnX!m4sSuco?5BeE)kTl%o?(IyUTLMP_(O%RS8Jw1@{ z@3FLVMb2iXtA2qa|= zNE=~0@Db$no9`aNj_*E@5Yuk!y`!HVKiVkf&MHBfKM)mxSR0(%mhx;Y?PmHyi{z83 zqYjNi>|PE28#EM%fOfe;y|2DejjQ4NApz}5jm)UNN7KN3S6||y1vdW-J^=Fx7 zC>c;QS;B*!=~K)2qfw$p(S~fgd&0<8SriP7CJcCE@K9-Nnry$ShwS9llsUB2Nnj`7 zS%_zuNvOoAM3LSjDn#v(#6bR5`T1x}%&L~hOXJLYfqP14I1RFc*71Qp=V53mOhNKM zQ_7dVX_G5@Fdfvul-@pA0ysYzy zN{7S&MsWGbj-B^x3IlTHMN@W}ri|C#rn9G;lLNRkmmhp^=qsfaWeCM{z@`%lM17Q~ zIK@C2Xbxvj7yf2?EMKD}f$Okm91(%`5p6P2g4++bcWXYcvnNtDhgx93 zYhV~1L3`+ddgo(9jD_%kv_fwxVosNY?y+hr`bOOZQ5KeUb)|iZt#u}iKLU>$LH^gu z5~nOPU#d7*Un^o%9lQf!;_UeymDWazTzG<`-hxAW?bUmeFG99H?3IQv5_7{k5Tafs zyKQaaaEX1d7dR2rN{MI2Y%~1(@i`Lo3}uZ)rr0^^Bw?AAvMr0Wo27NF_8oPL(nVAQ zMO+9YL6{1!X@a1c2YIE z>5|*-jn24`GsoT>QrVbWeK9KtC(RwnMWNnFOeS)9iP(gQDcdwIFt%YpDzeE=I`QH} zGAki?`t9mYt_`=mpX7N+K@=Q1t}pG$scrU<)|}!@^oz zhDBWY2PCXw)HI?C$dm@8qi4l|1cF``;iPY^C5kuLQVH$mn?x_Gkro7t>Yo5D-*Jb! z9!lrP03VoImk2ML{b`LXfr9?h8f3s?}vF5l-UJE^)`oUg^nw~3RTEse|@5{X4rHWvPtPR&R>y8uj;<~;b znHw?oyO{8n*s3p2fJ!Xgn>-Ck2i1EE!-xFb*W^QoR8Yn#{7R~r7~g`|tZ#HFFRS5Y z?z!i7U-w^Rzv>_|&q2^SdVU}+1Q+^lQk?Z|#EVa69@(b2h}6;|AO+#0nVfbdpnnH( zG)Vr{t3f2A{_Wy)Yu+SbK-=cs>;~PIBPV`EK@2HvE0kB)OK5VAiUu^wH6^l~$&7Pk z+YbI>&y4LFeeXb@vbn4B4B23vy7{hvFD3#j?hDzvpQ$%X5FlaaqH$k@W*89X9K%(6 zsSdlMP{ra~d@1Fr@Cuko*V3tNQILN;yKz~mC7Y8<2YMaCW8o1gG;pLdsrZFRdM&UW z>A#>F^k3{?Kkh_PkUH3?Qvq&DlifyIOKL%;P?Av{9Ow*SL+vlrL8WSjFg>UBI_~iQ z%tB;Lx#og?t17GwsW`6Qr?3N68ErOE%_5f19mDH(hohLcp zu&<0pLky}-Hu-H7Cp=y{{6P(yzCpOq(n+}3-k<4PHHDW}tjFs?8aF#V$oJsCKCy)P3B8LcKL_$GCJ>(gTTcqBz_u6n6eDU^DIbnh}DG@{z(UB2R z^UxM|*d`>>A$vF^HOzuIJuj&1t6HzgkF%7(j0NDj&l-u#w|RJCi$;7Oyb1$;`7l&S zv{5ZuD0QI!sOGbBtYeA$kp4bjp(S+mlLF9dbS-2!K?0g0Vx3yl^3bLFb15UJF3jIj zd}a_4k{J`T(l1KaKj--QT}&;=K|~Ec{W@l)*`7uZMX`{?N+!U;397=aG85&TT+*`I z66I63#U|OaX-!VP(ian0K6k1(Okb^X>8s!YrgS{nXsZ7xKVpZk$~X$sziPlQ50_CnCq=S)vUuO5&b z`4kCIR2$1B{h_)u;F2vKQ3KUV&YBl@?lAomQ~Ysns2=kNcYYYmkfz4RlR2{-59AhV z8r%Urfo6&*Iir;`an@^yEVKoAh!Y4vE8qs_wc%--* z+ZaqOy%loM@mj4yCU9AedvR0PiNa)|IqrPv12JamLY;IVwF1{D1HM5)w74w+&DOR* zi&kP|e{7sIB6bcixbd7%G?6Ia}H2gg-c@PoeDrYsFq zfp@WtC`wyW>y0Xd= z$HmKjQM3x52Jv`NPD;HP7Rkbw_&NqE36bSq2r^7VG9M>YVCU^5?uCn|zO;?mSk1nK z)K`MB!mVwOA&j>h><{ACVyqI%xS+$=7CYK;(Ay5$U#u>M*mnOWp9emm#y! zCNZR~fku@u$VBeOKD2J)EeAsAC|0sIydVly?G`N=y+`48O+h7BVjPc>WLsRPhYX<< z6{olYK`%g_=e~53HQh9IH;bBD!2wlWNFAe)o<=JcAzc$KA^Fbw+BiFX#?-zQ%4%Rz zk8Sv}I60Vn)jjD6u+VZ`An{dCjO7Iyf*ML``P*rSuggv~@5YLcq^~_1Y$;Uz=QJTS z7Otm=*eJ!lVv@f~hXplKe-o)m-Fc4LwD{IARftkuLoh^}QaYktBtmzVeFq|r$6n}D z@Jm=RsD(GSpC1!lBdNU(&cGaSoHn&d`u`BqnFxyGm!!ZLxdV z2Ty%WKy@y%s*{A@XXPs1PnHFt*+ENINFwuv?VW=pR%11kz%g_IacH|eTe+LQXsL@L zkK2s~YV7=~Lkh+_9NL)*Oiv8fi|<1+O-fSRRq82QPXI@;=LD+fd769SSc1lCX%IS! zULfMf%TrlvwHO&AnWpRH25@C^r^IWMM#z;w=p1HVC9hBA7;;p%R=PL^R&_bf+Zq$4 zR}_1@-1S6n^Ip4<*Nv#f$Mj4zq9O0f@%qz;+)iz>H=;EZ(4NoqhG*s*G(cNX%Ph^w z!C09!2XYq8->9YDKMutmkS11Sj^Krtk2Q8}y|fu`iiO6#&St4IsIRT?IYdt^v)Wt2 zOkLKGg6;}jeCDuVs@l`=h7c-NPPz!+yC>VS0TMw3B)Jf~Tle9rO<_cnMSwKaM`9+4 zjF1Jy;KOg6;7Ht5gO7kn3PHgpkeJAkl*HLwNR|lxn;Bw~G=hYc)H3c|IO4a26@MA^ z__i4(wStopN@sEI*a8$v9jP&(o|52s!PG))m=sE>Ou~V|(CewjrB;m0d?Mu|MLk`8 zJ{w$Vx6Ke2#rCTRqmTV)Kb2ooA*T}kgropLcm?kKqBvJ1=eD{QMP zd%MeZ4Ty6g`5JM*6KC3OLs$pc8l_VqN>9^DC}S-SxSi;6)RZ9!3P_#BWP^edok8nK z!6+@Ou+>yDZA(hP1t7#8^~GBbVz#k4nD`oLW0@mDdeVa65H_StN2j|;Q%F&5l)?3y zbS#p!oi*~-Pdefo8kisuK6H35xpKwAM-jCu0qUKp1%Qidlo_AJOx(6Z1(wTi&bX@} zX(~w=Cz3bnCZD0G@f6-joaA zBn-)`rIJa={Hwz3!f_6`G591>d#02Izhn(TthCcg#R#8RsCqjP4un(Q&cbc31g$H| z1d19)@Je->WJHlk7R^5O;&S1J^^WuDNnPnmRvYIkR0A8*eVOhiE(EUul0c61LD&uE zlIu=@u&sG^rE#4BEvgcLccX?5Hs1Xrii%X&R!Ks_ftpBCb8sD0bP=%qDZrpow-B5G zOz43hp{r$OlTIWk6o`RLCJN2Muu+=Dq~hCV2g17wj`Vs1K@&vu0x44vXykOE8l>0m z&)E6i5KT4@1)lhC6Hh+iblK zGC5FkH`WM>-n)u@_RLO%_7+<0r%u0280gul^!BoS|F_sEfhZO zDPTw=3XmiGR5Xs!TBOZuq*U$& zNd$@my&Vp;BcimdPnKvwfd&WIc=)KPy@cEltzi_dJ2_5CFum~)|QnW64E2K z)PLBZbmsa}opt0OCtq0WPQ7-*${Gds6+4ZfkF9Y+q;K)9bfcN8)VQ=2T*FPLsWJ@w zwXBff+Y{VV_lX&$C)wH~qs}nT$^sVQWoaW-B#K(u z3wF^$RB`4_E0HTyWP+6ehpk(sTr|3n3M25+S!r@f0F@^!MLMs$&GUj<2%nfA#+y;* zIC7ExRkT-|N(u^yrQ$&8UU?K1k@-}JKN=myK;EOuK%q+mc&oO+Cl^=K)Y6%CwIC~0 z#@{+Z@qiRkHQKMai49vMKsr>`Dmc!@io8MPI$PySWI%GbHpoj)HUcmN$WO+<4Fl7? zDb<`%-m^}umAckvhNG|_3R`WweW~RnHr#GiOxw3OvNi5iywRQ2K|~m&>AC#tnh)@z z+I7!U8^w5SvfOQ_47fs%)8jy6&H8#!Y%DWsb+~Q@Id&bW@LfPqKxrCzJ5j#6xDGOo zRjx4t4?)(EezAgf0F6^nRTLA%-)d19ks_q(C`O>~S{jg|CbF8JP5V|wFn?Ot?*25N zg+)Jb!V;x2qbN|2>E66a+(CDX#85w_Tj;pCM8xQ-@kKTcEFd$c(Wmj6j;4|cpM^t! zFYxyitQrlK!THo3W`aTG-k!8*rPmRU7a3O#C;0kR`dC($P7tREl=rJ^@=)VQU@feA zDE|PZ9CzM;KJ`IRGf}ysqs#!LF~K)iOKDH7v?wPA)7DL0IKxT>Yi-tZAeQm}0A_@@ z;s)WxbC&KCioP0Y6~gwQgs+A^MkJ{0p{MytUC>S;wr(%nk;RQF$%a^2gWO@8?5s*`(bJx3LN>2{Ydad<`)TatsO z)SrbKy3$^Q9Zj{1t`+PfrB$PF~?zc9Pc#32n2Q)3L~Q zrMquwVS5FT$6)|!OP}&JhC~IHx@Z|1XT9rRqJRb+P>pyjo`qpe% z+cBQMvH*(H9K8ijR2fss4wTAu^9onqst(%Holg&l@2-T_t)_aEQgY%gCsH@7`=tZH zRl&qfBm5)mT8P+E9>D>5NKqAXjp1O8Aty@ltw9beA@mi@E1Ogey(lYW4d#a#NKOjL zE>58L75CRSH`ab^T>=&Jis2&F^S}_MQISGpx@nW)-U#*EPa`BL4d5f;(xiA(HmTq5 zpGf^G-#5bx-R^ECUk=ra-Ya(ZUPOe)r3E?d*DH;d{C<@UcwM>=FB(frm3`Ws?tWC?iQC>= zD%Hk<7Pgm`L{8sICh6BMm(clit9yo%ah0;OPEPfRrpPHNOyrtL*q)U>UzKa2t73;V zZAXYtXr!rW$fwG#wFV&8um?($tzO(FLyB*Nh}r2K_ud_5N$YM0ihKUxU3v0;=Tbbf{B` z5_(dzDoa^%;HWkCn603l$>vQQQ}7+>T!032PdAv>2yCNGS4HPtD^@tq;T02*K9%iH zv=p}ajCqI}(#8ebYnI{_)|B zJ$u%k?CAyIO7h)5;ipR z4?cKz0HJt@IP7UHg#Ft|2}!JDTGEp#Ft4RYPL+@6-i`?aEyYNIKyMYON`Tf4GSQOV zq7qPJxZ0Pr$4Y4TK>a?&Hs$5BFAj@v1s|0Vi-ufS-W9b33$?VLe6&DU){VuliQe9> z-32S+=sCJiTApb>m92RRGgGFO0PG~xFC_!%Q3_Zhx`Katol-zjlG=(=b;!ol>eX!J zvx^*L=7QOY=~mYF{{Vk#*^`nbc=Q5_UTMZybg5z2X94R!I<>QSG81aw3nUW@gS}`K z@tROcqZ*SynVKmHR-NlqfPHkWTugjLVA|VXSS#9(J9zv?l#sWYgSR-Z+p&kMg0;M~ z*GS_gq+eNq>hf!w#cUHD-Wt}vG2!E{?L`6n^JD)2+PQp5w@EP#rAj{bdLpXE)@szP)%r}BwK17U!gIR5}ByrzEj0FR)gvwP#Vmk1x;Gp%Irh%nW{ zIU5Z$PdBWZ33K^bd|O6=E!#E(cGjoj@}F@UQvj@`EVWE`Xn< zO8i^L9PgdkG&q-@GGnK=xIaoxNxM?}s_n&3){ZuV9n~u*TIaN6#=C7rL&F zIxH>5BOlI{#PI89TTi&O0x3(fXZgLs(X~Tr^eM0AU*^vXzH-@kJ;r4ur>~o>YvA4( zl18I~3|2qQUL1OVS4a6*FNau<{{RYV2Z!0m{j}0jZ7q~fP8g?DE^ONvEhq?PaV129 z>!f~FYT7Ss!|i|`bynceE?%8QF>L@K9NGx;fev_FUkAi80uYS`+P^oU2lLs8(|G&tXk8NWSQEMxX`^=H)k z)@xRdBh*Itu86LY&Z|f!?ZUtM!8V329y5+&G2y5J_unmcyTltlRru6+GjA7e# zd<0^Fve3)<2$~&}KNd^$6>aMtF$61c>SRRgGX|zb(+^V?CcGlchC^6|$hBLfQk3rr zjGHVOGc*xvT};!{Ct=*MkCe5U-ktbBq_h&(w80jX66_o+FT{1QhW>AvMS(v9nMX=( zw*LU-%j$8Q)4%ZsTzpt6hw4&7^#y#ULC9eXaXVuuvWgY%V}@(iYN{HG9RmLV@*r9l zRWQGLMV{l({$~xErBukgwCZ5x95*bD3vaY+Nr0rH zgKKka#(Z3DCZ$O%3zXh0c#F;bbauhp*O^$3iuUe&rV;TkP2MaIgk?A_`xUtNuJA}; z^2#-~y))K{4ol9P!~&Fj^$Y=5RhBgeh%+`MST>hdhi0HVd3E1ej8anY5@ZgOkKz!= znECTAbPBHRzw>e85Z5b}*{ms4Daad{dih{ro(CjCfvsVWi26l`cZ#6Ha~;NoX69eH zf^6ODraxFr8O6j?YW;`)>Jk*!Hn{C@Po~s5w8JJq$`~IpkXuHz(A`dMnhUHxWf7fE zZ_9~804>eq^#h6Ni$IO!3~J*+oD1DMxN;p}(T=^qV})$o--S+Fa3-4zyz^0Oo`?rh zzMBUbJBKl1obTola#Oay$SpJnRJ9bV4yvova@>{PZnY~w^c}|9aElmHt4KL>F9ptx zvspLl3#k`y&#;vXuI^tXsy^+?DE?Y6xL;VhMjtZgJRC*11@{9I78T4jN(H(g?M>!a zL@B3_aZR>_uQm~B*Nzc4a^Sy(e$KxPcV4*>g@OyFB-3c-eJs3JHy z%W!}f$-+2&U-2;v>%yuxlG;53xEWO9af_2{FMO>;b-%`Ata|5#dLBw&yYBpiSw%vV%vahx`02Ml2bo;M&2Q!!@P{+G`yeUG$DuR%29|)}-F4d- zXrV>)^>NH1BFSOwKbYx$HoI`Ih(=1UU*x%o1-}hLl(1}Ya~{;wl~u|>1KP&;F}4VO zW?C3ob%_B;tx>Zt8~&8eQoty6B)OwH5O$$jR}&L%S7LCszfy@}jTu9C!yV#L3D4>k z@pwSxYtMYa!Zt9E2pY~u!8ez2#wE)QtNe(-D_Cmaa)Djc!aNG)cl=DKMH<(WQHJW8 zcDu?DzZ!2CyQuth?KyWv${ye?Wb}>EpNsf_#@`FAlZH2BkH{*^8@~ArepnSja%)HT z1tFm2Ue;YPoi+8>tDFLaw=|~_g`l?E3YG3oAC3CSF*R2T(D!T_a5;j~-&HTT+kjnt z@8&g61Bh7(g%cg%r|L;tS#{vlj*o%iu?% zvAsle@peiY;aO?_01??JtprsvraOB-$w60&&xmp^htWN}H6`7GBw+ZnH$BMq|eARi+r zcN2w(v)*@3yB0eT3yDLa9K+AcEEEYrYX%77hY>E1lOGo|HgHI69IO(`VI?wJ7P!_j zE0`mhgl-oAXoA{|wia@ITyW+Xc#T5$gO~-m4kl5K1*Ulh2V@1XR!8_ojodX_hh0IM z!~hm0;~#SiZzs$gxJ_gp2tJLQ(Gf`zXw+ZX97-!3h_nD4=5ly8R4cSyGgAEc`XWqD zC9rKU_9SENg6Oi_7!F{WZ@q2?_awA9J!kGUEO}HGkD-A{CaKl?CY2=O%HPW!PPz?O z`zdAw9HF*g&vNOAZU-=Rj-fn?BAd^aa1z4`JXOkhc!h@5ZQT_+1M(7;W3b|{;vc_R znRK?>^9==*GXq-IZhXqWfK+I&K#+OHB2#{W@zg0z-IT$#y{)oP>1H_UFpBmf{YkZV zFLqn$1%iu(L1Moq$X3wSC?E<4X5tRxq37J{+tpU$8X#k7X1*e|SQ~~xs=0(Wg%6O# z#he%4JBC9&hbX#m%fj4)9l}3|Mcf<|8DQL6zD4xDrt9g>3vc64g3Kd9z4GCPkVk_Q zuV&I@F(}#7*%=!#F3^Q#8v_Q!cy4hv!Z!VfgaY=P6^pvC_?k|T9jM*YsbGa;agv@l z6td^M{FVTwso1yWZW32couU1qvEr#oY_J(dO2F8oc8@8e(P7R&Ib*UNfh(czV)tl- zbX>-?+FEtjbC+_XhU|VPe3lM*C(N+bK&Uodbpv4ghUB+KIh&w9ub9@fQ=TdSOCv+W z_cJaic{ zrqxJzn}3lZ@5gB>2(NeH0 zmX;r0x^WVx0koJeifra#!tTU8^bn*NQCee~RJV9n+#KJEx5Oq+Y=rQX;uXVywV{6$AG6(G8HA_eijI@Y8ot6DVs+IxAs0bc~0a~2f z{7xNY_{^Rkg*R3?b#pi@ulx+) z&HyGKHVmn;Vz@?&y|DO}5{^H6#}d?MgD$FynH&5i)S-2TWt>48N&#}ilL)ssoihVU zH)zeX=L!~IQprW0Tet_u@|M8E!2t6ipG;J*x5S}kQ4+Wei5E_93Wlw7tnoLouvYxk zprj$Y=#!W)b92C@v*?9gWWrgjZ!!8ZmBN?7xF@kp9#Yo{=mWy|MLy<#Q@rdsc^EGd zrQ}InQY{w%jyI%z$I6E}F5`sYhYmQua1FW3tw0+V%Qq_rb_}wLfl(Pq&}j7viMv6r z5MiiKUs0@$kdoD7)BVo_Z5ZWq4p|nbEqH~(qRW;40KCD-RZIh#BWDn>R^w57j`}yA zqU|fyIUM3wiW{dLq2C|m(lMpx>)>gaFhL>$@5F0>iNS2tK!o6c0=Ll$Roh+8TP7Gnfo1YRXt2k|U(YGmse}zSQXlKj}Xw) zBBv$TOX3)pL&_24G290TwKZ5rhcE(}GwvNI%=ypkd;FPiYJR}@gbjP<%Kd_@MB4JP z`G8~xxS78&AX3oca|?kmiEb&^nCSq?VL_VaUJ_~CG-t=G#?P!z^!2yk6rx40?@o(+#ygaW4n<>rfaxiH7vG~^;Ohi}EQD9*aHGc=hPioRd zte%;#8uPgFy;QDWv%#j0KiP)RAiN)h2Qa9GF^qd7j4blxyYmp>n&ch6V4NTw#(N=} znpgIMHv?V7#zu~y37ZSzqhtd*g7T#@&sQvPNyG*og67m{>cZ zcv|0FCAu8hf ziY8+u5{hzpxWol^0bG5@D9ha(TR2L@Yzsz|6vmO=;C7ir23}!ZMg5f)FOIpHWiqI~ zJmv%(w$wa^;n-yeD3_N0GYutW^78kX1+;9b%AV#VumlRx>C5RC5}{*Ta?>bjS)%q! zqrlTp`8hy~XxTX#PM#)85)^a8ms zMq;S9%Wf>Tyh}dBcrO0{xq%{BO2z)ih_T>;S@Gg zUtSXcB;I~yP_|arN;;`iHh|xP--tjM%WY)5s?2nx6bK5gVrWeRxUDyw(QH)?<70vI z8;fM>x-sSbP4})1eUS9U49K*2XXZHYFuGqB>zF&aVZh&EzW3gs0{*3xFVI5rb^id8;uI>8F;3***Ffs!SThd61lM*%PYEsVs@>_Fu^{A2 za7XZ5FRj<_feYe$H-uqk1K5d2_1@JIj1a+-Dv_ z3NCEH#JC#pthPl4G#5E9F|wmMe*vjPmP$~+9ZF|rFzvo8#9wv)08B#JbS(K}UlXiT z4e(*8Qu1BJJT=S}70|j4Sd2+jn)-(Os&fILG}E8iisSw5y7-$I5f$4}HAGv}Fjffb za7CQT;CmG?$Z|G-$EB66A85b;22;X<_y_YFKXn1iLG~=a+>aH|Ot}0rFazD~+EZ@S z7*N$!BvWufzYz)6=Jdvh{%w8)EaVsNjIF}N^_a4_PxyJh(Yl-HEdsg`YW1(YlV85CRyv>OMN{w7CX;RU?rB|@={ zxUx4_%H3kU#aC<@%&XQQ5TW(y98vm!>q!`V zBEn^jS>;>91A4wz^-;tSd?pG2OM$dQh|&u<7U);Z5SFrItxO^vZQR6=aN-OAdY3$^ zydn6zg{wB3Az%72DA_LNT)TAC!!?2-oUAL{!cGP-+OFlqZuPKml%>^TTN)G}v6KZc z2%m{4Ec0xB68;8(r=d|ayF^~`daa&iqU$A|Vzmwil{|Z+v=cP=gmx3P)y;W8S``%&ka0Cp1Jro&Oh)`ojE6QT@4Jsn51*zbElLtyb%Ht6#jMvd$v zkJ!6Z0IOm;^0V~`P@w*t48xuDGTXqi;K2+H@Rwbtf}^u!rKxTXE&{4b>E2hO-!qL< z%RV-L*@)_oH4!TN7`VKwX=((Z5sKzAI-!xrh>@sui+s(= z-fs1mAJm~@rpJCWc!!O+dx0_4~Gc)^+$?uENWfG(T1Yd#Bcjb#vE~T*yv{Hs{07eJNeXc6qmrX9GnvdcO zDk3XAN2^@UAOSv@g%_3_`GxA|OBQ(+HLs|#8+2BOh`D^<)Ncau79Wc97h9kFH47!5>r^_oH zu*e}6TsK|+0M`gAo)CV)`ZZtvv{p9gT*1>dBbvhI8c6zTJb}==RvHK{o z>+HH|;z*I_l!_{_{{a8Q03{Fs0RRF50s;a90RaI3000015da}EK~Z6GfsvuH!SK=H z@&DQY2mt{A0Y4BLIZmG6j5*k?4_CwPWh+%L2UKpacu*i3Ce}sQU1JN8qS>r35?LEzz>bc+F=3nO_knrv9*DMw9WSV03YE0wcKG&Cz^hn`3;9gVfZz z>o1g`BATFa(evs8t!XTf?QY)&3U1#xNFW}RERXY_QX-kp9~#a8>`B_U&a&@|*L)s( z!rBo+yE?^eLpUEVjBr?B;7)yfXLzI9HjiW8ZTlq;bB$%%=&b0}e1ylWagAHf4R3ZB zMvaKF46dcRLyhZm@uJs}29G>BYpjIRkXbjW;|WeR6dhS^nsFdhA~4W#HFJ_x2O<9eTmexKl_)j;0Ba;n&8Qw|4Ry6YQZEIf0AD+=;A&bs|fTi}YE8{+H7d9JW7 z1S9qF-NQX9i*4)HKtL2LzdkX8v;)*N&v-(4qbYh#cZnnlFK(Rp!nFuG9j{^TVYPFH zv9M{T*!ajp<0KQXA&iCA_}+A{9z&OZSOG&1b(}t20@VN@9ydFV01=YtJ>NdPV{q>| z_xSc@qzUp&>IZ7$I;wRd-%OhGj1`Na0`dcneK@hoS_5@FU*`sjk`FQfh?l+dk7~UU zM51g&*v+g11=5Z;9z10Ql?kA4YI5a*+S;W`Izx;YVs1e;cqcxwBP|+{PX>VPnN-3S z2;Lu@9l<#WJXejGq=KCXk>SW{fcPIdsX8b_{$!yCjB*C}+`s_=0QMmKVn7OTX;+W- zvc^J2lcU}aOc+EazwS>a&Kd#HqHDfzP=e_Eyg2=4`dteMDfmWkn$kyn1mnjpMX(VF z9QoDG%z*%29wpC#kVRdm-VG{L=QEm~IzBO|tzebG`+hKzL}ZN{J=cxrHdMsSgK!B= zn>xZt0I?DhKjpwAh)0Y&x{bnKIjz9!sMu67p>h!v z3D(bS>fvt*WJ4D~dn=GRfk2>6mv55;#S%skQlhpZ1?bJ|lK25nY29~RMAe{3PX;J? zGz**31T6%ho3PwF#K;j49vmp!kgX(#TL1yyc&6b6!m7YFh-R?im`-U;ogxjCPgrs| zAsr1bk-}9-DdbSvUyKy=8sVsWc=^c@8vSE?_lEC%4}+V4nk24QZcZ^s1s$aMJ^uiC zq!n5_5p~}3&;wL9h5j&fir9eYKfHJVkrn~kVFd(EQ~mv7fD2%O2A=j|U@74M-{YJ@ zszgx|S+5rc0IVb}*NAUeUF&;Sli&Vkp$SF?Z_V#DK6bEKr}u{0C|eGf)2XbSiXoGzWsXx|~zQ&_~9AZ#!fyeHivtFqfwA<=bR2;k}7 z4$V=#WGKLz9zrtN>l0+*>M0;FQ1oG1hE$t$JJ#o{8PXp}o5N7M!H-o0j4@*O$KEEC z4lG1c!OTcD`=xazfsdcOh|_|i)OIu1)w~2-1^8)zZu_)gm#XEJLz*$ z=5lOCfDMM(jv++A*}M?(g$A7x!JvnHS%RvXBX?2TNv?ISa;YG9N`(*&rT+l9+N?#% zA!3C(+I(RGFoYlRAGa<5pAHf{v4bcW zNU@O$TL#O3D7(wDh*$Li$}$8s2FmxO%21FWaQ<-nX3*L?C%^o}FxDLCG03nYqek?7 zf3=;{dm)C{FmS-TPI3ffx60^y?;NwT;BUif@x;yrU9~!C?^Cx)Vp|6aKS)%n-gU=LCQjJ}IydI{3n)!@&}@ZQdC>Ma#GY^yDvI z6D_3RRd^2ZqpfD8D@0AFNF(axv>@oVkEOm$yB%c;fCYs$W6In_ETIo_?bw?n@Zlm> zQq2gV0Bffi0TAxtQgC!{6vRO@1Fj2B+^erW;DJGHb5A_1IfYm;0Q+nqrh?5_3tm3Ya0L@P%!s}#UELxHe%_k z$3>*0O+*G{0b#c_VQx)VzK7t*bn}3%Z$YO92m-gNWnM1+yUC!BP_8sNHHiDB&=Q0M zc<>NhiegfE3qEU3TAYqD6-Yh`SM`rt5KRcE{Q1Rvi+Dfuz~rL@i=lPy=G@x0sLOrK zopuzr#9lEcsYE>t=HB*ZB6R$w%k1ff(+Vv&LRM;J}}ZX9T4R{_`Wf46};+` z?-7z5HAeirYY(JIDi4v~1SS#$CiOqlH&qSk91QFMjVAJ<9DE3|{AXAcymW;kzbA|x zD)Ewn0D#6{F-PWM*FlY;x4w=rTuK)5-YSolHxy=&A;<8ba{ljt{qxTO^4*rf=(3do5=lzXy zaVzc8n_A&$3AVq{0<(T{L?H&c(7iJ^l?qUask%fanU(;MG=iZ34n_IEA?k~pr&Jq{ zE(mj{g^g^8HGp_-1=H*!tBTr$>DqowFAUXeozM4Jfr%WDJp0~ChGj|Z{{S-eWCO9| zee;~8K-ke|tP^U7Tmy2Ypab3mRM7N!*SrE3?h&mIkKQ#AD}@L=vg;-=GoIVVk-Rqs zpb4h;#x=q&0R^|&?-2wT0H`Ci3A0%}Vxkaof&=-uq_={!HvNBjvGC}M3**;$QPBV{NzLKfak ziX-G=NDvF6-Wqaj%8(Tm(cTFuqjG`W1*tY*Ru{jLCiGlR#tAQUQ?&<|BM>We%7hpz zDZJ|lHxIGFU&b*42_8FNfAbi~GrLR928;nkjdao~rPi{MK})Y?@*CNKAcS~-^uJjk z;QO5ikVd<>U?SzcK`Ei_)Q=?@YEZ!9hNrD)?2a&u>$rW4&K}#K(fRapYNJMP3)!lA6pl(pA(si%- zlh6aOwMp}L;}xS_XH=($e&~d7WOizo0!OgV*{{UQ}7A_1Qn~ZY02s3(~@DxCeB|6frJtO)t=s_FK^F#%wHeI&eF9#SM zAYGbv>7l1HG)NE>5`FW|M_2$00s+Ot5ZaBSnSm(*DGi0Xrv|y!E=rM%EiB~m&IJ?{ zp%@0#n)8-pZAu>Whu?Ta5PA-c`26Bo0er!BBz%FoQvg^#vbE^Y1NzJTH2MY2?n8Li zYJts#cl3Q>@@RK~1=C)C=3X#4^q>yR9;PVBsw^FykbPlsk8YWbX z8_4AYe<@*xUCPGy*_sdl4u#z?m0Jf^QRu`)1e9RU=y8X9myEe6XApw;vpZyB(MzN$ zYvUI%gto3)gZL+mc$~pW71$BEzpNTUCWZ|*;Dr^Qao^z0S~@E2-&s&4r23<8CNtC{ z;LEUh0?aibEjMpkE6xg)BXy_KHB!}C(TyBjv42q}uUQIZ{{W(77I89C69qjf0+GN7 z({HNZC+{AGgzh|@jBkt*V^ZNfzi#tq^{ipuVY(ypy*{&CX%c9CFs(owb)C6x!tAEE zR?OnSCl-r`Dt}nSDqtOJEffjsDA4tqC}<4)8q$FF!G z4x!Lgnx94u@^AyYjy{mE3x-1nK)SE2wAfJm4!FO0=PK5({eE+z;CF@%pcPH9_8Q^i z8&r6pJ6-Z|Gd`qgQNe5|@5XPCL6k<0H#}=JUJ*pKRo>0*ywnM}mK84SfAb3y8OjuB zUuBrQ2ZU!Yho3mkAjEBJYP4=7=4CJu@ZV)&bkI|* zBp_M_j#FFw;>3!nh#ExQj?C4CqMpl6mtxtPYAmqRw2c$t!^jKG_!ZZ2*BZ$Y01&3w zfP!CHEJ4(qJ@WI45ft5-M-C&%Cs&&Hm&tqTg72$!^k6Eu0IKdgQVKhKWLN+#ATr?H zgTg#GRK+9#6&wS7b((e^?FA$9L~%Ec+?0kAaEBz0@rcrdu+W#l zmB*40@mh-1^>9H0R3bxxC7%7|s8Mo`&4Cd? zQ6p224k1J-Bi8xfj87pRa*vjA)ejWl3@It*HeU^qT9Z$T!?)O}&J1S=0|=bu?H zDjjV-2e$MC29%N5q|Qq zD=lHEee;}}mPif{mnYXCG~WlTZ89E7!+&!G(2!6LE1ll5g{tK$qu$-c$2l080M@pM z^k)kfdqF0bCTqnG&O+v^YlyIgb|eTI;90(~){{msooW&P0GP-Ea}do$7uz?nus{u; zsqYW?UNG#Man*)zZP`Fn?vL5q=xJ5s?40xFu5k;|suNX_R6)EY>*a2LGIQz~? z6jT*4utgE54V5YY7kSHU9S2G!>r~#dm{(HAy>-{)D8ctMb$H`=5H#XMzfEt<$JtT{ zDbh#J1}Z>!Zz-Nh#-BN0PhCewa>?QKmCONCXu@(qd&ok0ApjZ>kSV@$EUj?>Y;eW< z#Z!Z&CDhv7Jd+H7_rNxAxDt?abRMI0JKLHFuLB)78b`!@=OzWg}K*l zqb@sOUDECrTFnl6Rbb9CO7<@(J;rO`*o1GL(dC zctCQg*vw%~o;0p}4d+yrL_r*F=5kCfKo^0W-+MU7ktoO} zzf8x1P)JvRYY0^w6?6q5M5>B)K3wt0o}dTnVvVqPRO|iYA`&IA4aYn2?-boo;*UW< z-;5Pu;5w+A5iBsI0@^9L-R~&#U8aHq{{_!m$$M7AJn?5m+ zBd{5ubfG?+VSzS4o+in@7OqC7Wne2Jmu{-*#<@qt$8lCEHnbUC+-a);Su}%Vd%`?I zP?k1G4XWZdksKNb!vUr`HF+dH$0d3#E)^JjM37U`vSUPZ~5}dEi z>Gg(EQd6}Z@$g_tr#h&0@@q7}PYrwZHHuIGNUu!)09=h^*Ag*jjpm3cTh5GOJeS5c z8p>uWY*oBn=7R+BiI;>v1|<&F`;XK6*MV0ulI=xMMRp@ezk~UWE3T@urG`>bwLewae%O@@E0^h{+(jHhhnWJG*Jj@ z2^2(YLYJ%x(h=6673(--BJ%S1m;eCPp!DlkcRI;H8>eIH?7n^x0exe=dHz>`Nj^N<8FGK;N+bvJR7w&#^F0Gbbj3pqht zn1KzNYoW$)0IvX0(IfHT0=kH2O_`^R7GhNyr7J_lJavpwS_$VJBsHPYF$4!4En4CB z!QiG231ChG3N+H_qt|=KmBC;%U0ZeR#sf>SMywW)y zAm9VZ=cfn4$^|%{D8yE4SJx*ytajb!eoTDF&w;!)hs&&S zP=d>5cgcL=+>r1GfjRS+8WCDgmw$LVD;A%PelYkT6{m?Q@^6k1B}%y)AZXB8Fa!V( z>roE-!oZjH=_fH{+k2638mp7$%hWF=&lcDPPEbv1oxt2-LqAOT$m##4||{ zZ^ir8Sx{1PJ>e`6Mfg zePnX!$Z!DFjo(}Ajed|gTnZu(d%>v4FtCm5WWkESBDGu%km~O zL8cGE-Qx)P6 zruf5Au!=7=_|A4B2*XvV{{ZGsi1Mb%$nZTUtP`jo+|Y{Ok-hE+2;M~z#0A_Y2uQ~B z4(o3D#ux{iTl&C_fMY>BZoczQ4G81%{NPc}ciCi#2*XI|U4udcg_{Z_E#V&w$H=2Av z>k|fCBx$HleFhqqw6KEbr_-5x;oVS^^E8yD9YW*r3q1Mi^U@rA5|LR~`e{9t~h zBTm)k{CALmD*7Nt_SkX ztvvXo3c{;^)Igz`ArokkOxzVb%x1p@CX8p^5F9KjC!ZM#8-Vk5-zM^@V09Yb$Na#9 zXbumrUUA)s-CQ{9`@%;`O_jng-}+z>fFa2(3LVRfHSL6uMwEv>^_DnZXGc+$2H%m# zU=PedX>e&2L6%5}D^Q4t+Vi6;D zZCMsj6x`Yy#XiBw7lVFx^58D$ORs~&-;)Le!C)6#-itC1y@q(;-_^ii@S|Qq#r$kI z$AnmLnxL7Y%+Uc+pl|O=#VZQ!p_?fa9i}r?=HWn7&^m%|=H*ejR2o)QtIi`(TG4>b zSTdE!y(;O3l}?NfuaL)zF<3|iaBzQ|rGm?hNZOt|GAsI~5aVDv=KbWAM1X}-#`>Hd z?$8vyolK%gkgpikFecrz2S`qe`OaXp7Euw2AYK8x$Cz0129%SNkN)J%3fqK|GCHp_ zX7QB-sRc235Fr{cM7*fb98G{5z%Y{lQh)+TT3+ThNm3Ll)CYQo7pq{YqL!*UO%E&< zni?WUgzK&TbAB&`>LBDz#(LuzTDW*|Bpmat^_)PBu$))(@s~?CPBp5wl(B`lSS6?( z00`>$dBGw5yO&2^xYiM#SlHR_o41#(W}^^#8V-I6^>F@lRUl9f9YC4C;=dvWuy9?{ z>SUsb`2a~r)Yd36t}vT df = spark.read().format("image") + * .option("dropInvalid", true) + * .load("data/mllib/images/partitioned"); + * }}} + * + * Image data source supports the following options: + * - "dropInvalid": Whether to drop the files that are not valid images from the result. + * + * @note This IMAGE data source does not support saving images to files. + * + * @note This class is public for documentation purpose. Please don't use this class directly. + * Rather, use the data source API as illustrated above. + */ +class ImageDataSource private() {} diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala b/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala new file mode 100644 index 0000000000000..c3321447e3c96 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.source.image + +import com.google.common.io.{ByteStreams, Closeables} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.mapreduce.Job + +import org.apache.spark.ml.image.ImageSchema +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeRow} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution.datasources.{DataSource, FileFormat, OutputWriterFactory, PartitionedFile} +import org.apache.spark.sql.sources.{DataSourceRegister, Filter} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +private[image] class ImageFileFormat extends FileFormat with DataSourceRegister { + + override def inferSchema( + sparkSession: SparkSession, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = Some(ImageSchema.imageSchema) + + override def prepareWrite( + sparkSession: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + throw new UnsupportedOperationException("Write is not supported for image data source") + } + + override def shortName(): String = "image" + + override protected def buildReader( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + assert( + requiredSchema.length <= 1, + "Image data source only produces a single data column named \"image\".") + + val broadcastedHadoopConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + + val imageSourceOptions = new ImageOptions(options) + + (file: PartitionedFile) => { + val emptyUnsafeRow = new UnsafeRow(0) + if (!imageSourceOptions.dropInvalid && requiredSchema.isEmpty) { + Iterator(emptyUnsafeRow) + } else { + val origin = file.filePath + val path = new Path(origin) + val fs = path.getFileSystem(broadcastedHadoopConf.value.value) + val stream = fs.open(path) + val bytes = try { + ByteStreams.toByteArray(stream) + } finally { + Closeables.close(stream, true) + } + val resultOpt = ImageSchema.decode(origin, bytes) + val filteredResult = if (imageSourceOptions.dropInvalid) { + resultOpt.toIterator + } else { + Iterator(resultOpt.getOrElse(ImageSchema.invalidImageRow(origin))) + } + + if (requiredSchema.isEmpty) { + filteredResult.map(_ => emptyUnsafeRow) + } else { + val converter = RowEncoder(requiredSchema) + filteredResult.map(row => converter.toRow(row)) + } + } + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageOptions.scala b/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageOptions.scala new file mode 100644 index 0000000000000..7ff196907717e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageOptions.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.source.image + +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap + +private[image] class ImageOptions( + @transient private val parameters: CaseInsensitiveMap[String]) extends Serializable { + + def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) + + /** + * Whether to drop invalid images. If true, invalid images will be removed, otherwise + * invalid images will be returned with empty data and all other field filled with `-1`. + */ + val dropInvalid = parameters.getOrElse("dropInvalid", "false").toBoolean +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala index 527b3f8955968..e16ec906c90b1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types._ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { // Single column of images named "image" - private lazy val imagePath = "../data/mllib/images" + private lazy val imagePath = "../data/mllib/images/origin" test("Smoke test: create basic ImageSchema dataframe") { val origin = "path" diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala new file mode 100644 index 0000000000000..1a6a8d67d8d66 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.source.image + +import java.nio.file.Paths + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.image.ImageSchema._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.{col, substring_index} + +class ImageFileFormatSuite extends SparkFunSuite with MLlibTestSparkContext { + + // Single column of images named "image" + private lazy val imagePath = "../data/mllib/images/partitioned" + + test("image datasource count test") { + val df1 = spark.read.format("image").load(imagePath) + assert(df1.count === 9) + + val df2 = spark.read.format("image").option("dropInvalid", true).load(imagePath) + assert(df2.count === 8) + } + + test("image datasource test: read jpg image") { + val df = spark.read.format("image").load(imagePath + "/cls=kittens/date=2018-02/DP153539.jpg") + assert(df.count() === 1) + } + + test("image datasource test: read png image") { + val df = spark.read.format("image").load(imagePath + "/cls=multichannel/date=2018-01/BGRA.png") + assert(df.count() === 1) + } + + test("image datasource test: read non image") { + val filePath = imagePath + "/cls=kittens/date=2018-01/not-image.txt" + val df = spark.read.format("image").option("dropInvalid", true) + .load(filePath) + assert(df.count() === 0) + + val df2 = spark.read.format("image").option("dropInvalid", false) + .load(filePath) + assert(df2.count() === 1) + val result = df2.head() + assert(result === invalidImageRow( + Paths.get(filePath).toAbsolutePath().normalize().toUri().toString)) + } + + test("image datasource partition test") { + val result = spark.read.format("image") + .option("dropInvalid", true).load(imagePath) + .select(substring_index(col("image.origin"), "/", -1).as("origin"), col("cls"), col("date")) + .collect() + + assert(Set(result: _*) === Set( + Row("29.5.a_b_EGDP022204.jpg", "kittens", "2018-01"), + Row("54893.jpg", "kittens", "2018-02"), + Row("DP153539.jpg", "kittens", "2018-02"), + Row("DP802813.jpg", "kittens", "2018-02"), + Row("BGRA.png", "multichannel", "2018-01"), + Row("BGRA_alpha_60.png", "multichannel", "2018-01"), + Row("chr30.4.184.jpg", "multichannel", "2018-02"), + Row("grayscale.jpg", "multichannel", "2018-02") + )) + } + + // Images with the different number of channels + test("readImages pixel values test") { + val images = spark.read.format("image").option("dropInvalid", true) + .load(imagePath + "/cls=multichannel/").collect() + + val firstBytes20Set = images.map { rrow => + val row = rrow.getAs[Row]("image") + val filename = Paths.get(getOrigin(row)).getFileName().toString() + val mode = getMode(row) + val bytes20 = getData(row).slice(0, 20).toList + filename -> Tuple2(mode, bytes20) // Cannot remove `Tuple2`, otherwise `->` operator + // will match 2 arguments + }.toSet + + assert(firstBytes20Set === expectedFirstBytes20Set) + } + + // number of channels and first 20 bytes of OpenCV representation + // - default representation for 3-channel RGB images is BGR row-wise: + // (B00, G00, R00, B10, G10, R10, ...) + // - default representation for 4-channel RGB images is BGRA row-wise: + // (B00, G00, R00, A00, B10, G10, R10, A10, ...) + private val expectedFirstBytes20Set = Set( + "grayscale.jpg" -> + ((0, List[Byte](-2, -33, -61, -60, -59, -59, -64, -59, -66, -67, -73, -73, -62, + -57, -60, -63, -53, -49, -55, -69))), + "chr30.4.184.jpg" -> ((16, + List[Byte](-9, -3, -1, -43, -32, -28, -75, -60, -57, -78, -59, -56, -74, -59, -57, + -71, -58, -56, -73, -64))), + "BGRA.png" -> ((24, + List[Byte](-128, -128, -8, -1, -128, -128, -8, -1, -128, + -128, -8, -1, 127, 127, -9, -1, 127, 127, -9, -1))), + "BGRA_alpha_60.png" -> ((24, + List[Byte](-128, -128, -8, 60, -128, -128, -8, 60, -128, + -128, -8, 60, 127, 127, -9, 60, 127, 127, -9, 60))) + ) +} diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 5f0c57ee3cc67..ef6785b4a8ed4 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -216,7 +216,7 @@ def readImages(self, path, recursive=False, numPartitions=-1, :return: a :class:`DataFrame` with a single column of "images", see ImageSchema for details. - >>> df = ImageSchema.readImages('data/mllib/images/kittens', recursive=True) + >>> df = ImageSchema.readImages('data/mllib/images/origin/kittens', recursive=True) >>> df.count() 5 diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 625d9927f7063..821e037af0271 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -2186,7 +2186,7 @@ def tearDown(self): class ImageReaderTest(SparkSessionTestCase): def test_read_images(self): - data_path = 'data/mllib/images/kittens' + data_path = 'data/mllib/images/origin/kittens' df = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) self.assertEqual(df.count(), 4) first_row = df.take(1)[0][0] @@ -2253,7 +2253,7 @@ def tearDownClass(cls): def test_read_images_multiple_times(self): # This test case is to check if `ImageSchema.readImages` tries to # initiate Hive client multiple times. See SPARK-22651. - data_path = 'data/mllib/images/kittens' + data_path = 'data/mllib/images/origin/kittens' ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) From 559b899aceb160fcec3a57109c0b60a0ae40daeb Mon Sep 17 00:00:00 2001 From: pgandhi Date: Wed, 5 Sep 2018 16:10:49 -0500 Subject: [PATCH 1550/2461] [SPARK-25231] Fix synchronization of executor heartbeat receiver in TaskSchedulerImpl Running a large Spark job with speculation turned on was causing executor heartbeats to time out on the driver end after sometime and eventually, after hitting the max number of executor failures, the job would fail. ## What changes were proposed in this pull request? The main reason for the heartbeat timeouts was that the heartbeat-receiver-event-loop-thread was blocked waiting on the TaskSchedulerImpl object which was being held by one of the dispatcher-event-loop threads executing the method dequeueSpeculativeTasks() in TaskSetManager.scala. On further analysis of the heartbeat receiver method executorHeartbeatReceived() in TaskSchedulerImpl class, we found out that instead of waiting to acquire the lock on the TaskSchedulerImpl object, we can remove that lock and make the operations to the global variables inside the code block to be atomic. The block of code in that method only uses one global HashMap taskIdToTaskSetManager. Making that map a ConcurrentHashMap, we are ensuring atomicity of operations and speeding up the heartbeat receiver thread operation. ## How was this patch tested? Screenshots of the thread dump have been attached below: **heartbeat-receiver-event-loop-thread:** screen shot 2018-08-24 at 9 19 57 am **dispatcher-event-loop-thread:** screen shot 2018-08-24 at 9 21 56 am Closes #22221 from pgandhi999/SPARK-25231. Authored-by: pgandhi Signed-off-by: Thomas Graves --- .../apache/spark/scheduler/TaskSchedulerImpl.scala | 12 ++++++------ .../cluster/CoarseGrainedSchedulerBackend.scala | 2 +- .../spark/scheduler/SchedulerIntegrationSuite.scala | 3 ++- .../spark/scheduler/TaskSchedulerImplSuite.scala | 6 +++--- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 8992d7e2284a4..8b71170668639 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer import java.util.{Locale, Timer, TimerTask} -import java.util.concurrent.TimeUnit +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import java.util.concurrent.atomic.AtomicLong import scala.collection.Set @@ -91,7 +91,7 @@ private[spark] class TaskSchedulerImpl( private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]] // Protected by `this` - private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager] + private[scheduler] val taskIdToTaskSetManager = new ConcurrentHashMap[Long, TaskSetManager] val taskIdToExecutorId = new HashMap[Long, String] @volatile private var hasReceivedTask = false @@ -315,7 +315,7 @@ private[spark] class TaskSchedulerImpl( for (task <- taskSet.resourceOffer(execId, host, maxLocality)) { tasks(i) += task val tid = task.taskId - taskIdToTaskSetManager(tid) = taskSet + taskIdToTaskSetManager.put(tid, taskSet) taskIdToExecutorId(tid) = execId executorIdToRunningTaskIds(execId).add(tid) availableCpus(i) -= CPUS_PER_TASK @@ -465,7 +465,7 @@ private[spark] class TaskSchedulerImpl( var reason: Option[ExecutorLossReason] = None synchronized { try { - taskIdToTaskSetManager.get(tid) match { + Option(taskIdToTaskSetManager.get(tid)) match { case Some(taskSet) => if (state == TaskState.LOST) { // TaskState.LOST is only used by the deprecated Mesos fine-grained scheduling mode, @@ -517,10 +517,10 @@ private[spark] class TaskSchedulerImpl( accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], blockManagerId: BlockManagerId): Boolean = { // (taskId, stageId, stageAttemptId, accumUpdates) - val accumUpdatesWithTaskIds: Array[(Long, Int, Int, Seq[AccumulableInfo])] = synchronized { + val accumUpdatesWithTaskIds: Array[(Long, Int, Int, Seq[AccumulableInfo])] = { accumUpdates.flatMap { case (id, updates) => val accInfos = updates.map(acc => acc.toInfo(Some(acc.value), None)) - taskIdToTaskSetManager.get(id).map { taskSetMgr => + Option(taskIdToTaskSetManager.get(id)).map { taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, accInfos) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 747e8c7dc0fa5..de7c0d813ae65 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -290,7 +290,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp for (task <- tasks.flatten) { val serializedTask = TaskDescription.encode(task) if (serializedTask.limit() >= maxRpcMessageSize) { - scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr => + Option(scheduler.taskIdToTaskSetManager.get(task.taskId)).foreach { taskSetMgr => try { var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " + "spark.rpc.message.maxSize (%d bytes). Consider increasing " + diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index cea7f173c8f2f..2d409d94ca1b3 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -400,7 +400,8 @@ private[spark] abstract class MockBackend( // get the task now, since that requires a lock on TaskSchedulerImpl, to prevent individual // tests from introducing a race if they need it. val newTasks = newTaskDescriptions.map { taskDescription => - val taskSet = taskScheduler.taskIdToTaskSetManager(taskDescription.taskId).taskSet + val taskSet = + Option(taskScheduler.taskIdToTaskSetManager.get(taskDescription.taskId).taskSet).get val task = taskSet.tasks(taskDescription.index) (taskDescription, task) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 7a457a0a72d90..9e1d13e369ad9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -248,7 +248,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B taskScheduler.submitTasks(attempt2) val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten assert(1 === taskDescriptions3.length) - val mgr = taskScheduler.taskIdToTaskSetManager.get(taskDescriptions3(0).taskId).get + val mgr = Option(taskScheduler.taskIdToTaskSetManager.get(taskDescriptions3(0).taskId)).get assert(mgr.taskSet.stageAttemptId === 1) assert(!failedTaskSet) } @@ -286,7 +286,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(10 === taskDescriptions3.length) taskDescriptions3.foreach { task => - val mgr = taskScheduler.taskIdToTaskSetManager.get(task.taskId).get + val mgr = Option(taskScheduler.taskIdToTaskSetManager.get(task.taskId)).get assert(mgr.taskSet.stageAttemptId === 1) } assert(!failedTaskSet) @@ -724,7 +724,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B // only schedule one task because of locality assert(taskDescs.size === 1) - val mgr = taskScheduler.taskIdToTaskSetManager.get(taskDescs(0).taskId).get + val mgr = Option(taskScheduler.taskIdToTaskSetManager.get(taskDescs(0).taskId)).get assert(mgr.myLocalityLevels.toSet === Set(TaskLocality.NODE_LOCAL, TaskLocality.ANY)) // we should know about both executors, even though we only scheduled tasks on one of them assert(taskScheduler.getExecutorsAliveOnHost("host0") === Some(Set("executor0"))) From 71bd7965177cc4d3f1f65fa28fdc8cdd797ad738 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 5 Sep 2018 15:36:34 -0700 Subject: [PATCH 1551/2461] [SPARK-23243][CORE] Fix RDD.repartition() data correctness issue ## What changes were proposed in this pull request? An alternative fix for https://github.com/apache/spark/pull/21698 When Spark rerun tasks for an RDD, there are 3 different behaviors: 1. determinate. Always return the same result with same order when rerun. 2. unordered. Returns same data set in random order when rerun. 3. indeterminate. Returns different result when rerun. Normally Spark doesn't need to care about it. Spark runs stages one by one, when a task is failed, just rerun it. Although the rerun task may return a different result, users will not be surprised. However, Spark may rerun a finished stage when seeing fetch failures. When this happens, Spark needs to rerun all the tasks of all the succeeding stages if the RDD output is indeterminate, because the input of the succeeding stages has been changed. If the RDD output is determinate, we only need to rerun the failed tasks of the succeeding stages, because the input doesn't change. If the RDD output is unordered, it's same as determinate, because shuffle partitioner is always deterministic(round-robin partitioner is not a shuffle partitioner that extends `org.apache.spark.Partitioner`), so the reducers will still get the same input data set. This PR fixed the failure handling for `repartition`, to avoid correctness issues. For `repartition`, it applies a stateful map function to generate a round-robin id, which is order sensitive and makes the RDD's output indeterminate. When the stage contains `repartition` reruns, we must also rerun all the tasks of all the succeeding stages. **future improvement:** 1. Currently we can't rollback and rerun a shuffle map stage, and just fail. We should fix it later. https://issues.apache.org/jira/browse/SPARK-25341 2. Currently we can't rollback and rerun a result stage, and just fail. We should fix it later. https://issues.apache.org/jira/browse/SPARK-25342 3. We should provide public API to allow users to tag the random level of the RDD's computing function. ## How is this pull request tested? a new test case Closes #22112 from cloud-fan/repartition. Lead-authored-by: Wenchen Fan Co-authored-by: Xingbo Jiang Signed-off-by: Xiao Li --- .../scala/org/apache/spark/Partitioner.scala | 3 + .../apache/spark/rdd/MapPartitionsRDD.scala | 14 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 100 ++++++++++- .../apache/spark/scheduler/DAGScheduler.scala | 59 +++++- .../spark/scheduler/DAGSchedulerSuite.scala | 169 +++++++++++++++++- .../exchange/ShuffleExchangeExec.scala | 17 +- 6 files changed, 345 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index c940cb25d478b..515237558fd87 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -33,6 +33,9 @@ import org.apache.spark.util.random.SamplingUtils /** * An object that defines how the elements in a key-value pair RDD are partitioned by key. * Maps each key to a partition ID, from 0 to `numPartitions - 1`. + * + * Note that, partitioner must be deterministic, i.e. it must return the same partition id given + * the same partition key. */ abstract class Partitioner extends Serializable { def numPartitions: Int diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index 904d9c025629f..aa61997122cf4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -32,12 +32,16 @@ import org.apache.spark.{Partition, TaskContext} * doesn't modify the keys. * @param isFromBarrier Indicates whether this RDD is transformed from an RDDBarrier, a stage * containing at least one RDDBarrier shall be turned into a barrier stage. + * @param isOrderSensitive whether or not the function is order-sensitive. If it's order + * sensitive, it may return totally different result when the input order + * is changed. Mostly stateful functions are order-sensitive. */ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( var prev: RDD[T], f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator) preservesPartitioning: Boolean = false, - isFromBarrier: Boolean = false) + isFromBarrier: Boolean = false, + isOrderSensitive: Boolean = false) extends RDD[U](prev) { override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None @@ -54,4 +58,12 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( @transient protected lazy override val isBarrier_ : Boolean = isFromBarrier || dependencies.exists(_.rdd.isBarrier()) + + override protected def getOutputDeterministicLevel = { + if (isOrderSensitive && prev.outputDeterministicLevel == DeterministicLevel.UNORDERED) { + DeterministicLevel.INDETERMINATE + } else { + super.getOutputDeterministicLevel + } + } } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index ea895bb3412e1..61ad6dfdb2215 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -462,8 +462,9 @@ abstract class RDD[T: ClassTag]( // include a shuffle step so that our upstream tasks are still distributed new CoalescedRDD( - new ShuffledRDD[Int, T, T](mapPartitionsWithIndex(distributePartition), - new HashPartitioner(numPartitions)), + new ShuffledRDD[Int, T, T]( + mapPartitionsWithIndexInternal(distributePartition, isOrderSensitive = true), + new HashPartitioner(numPartitions)), numPartitions, partitionCoalescer).values } else { @@ -807,16 +808,21 @@ abstract class RDD[T: ClassTag]( * serializable and don't require closure cleaning. * * @param preservesPartitioning indicates whether the input function preserves the partitioner, - * which should be `false` unless this is a pair RDD and the input function doesn't modify - * the keys. + * which should be `false` unless this is a pair RDD and the input + * function doesn't modify the keys. + * @param isOrderSensitive whether or not the function is order-sensitive. If it's order + * sensitive, it may return totally different result when the input order + * is changed. Mostly stateful functions are order-sensitive. */ private[spark] def mapPartitionsWithIndexInternal[U: ClassTag]( f: (Int, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = withScope { + preservesPartitioning: Boolean = false, + isOrderSensitive: Boolean = false): RDD[U] = withScope { new MapPartitionsRDD( this, (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter), - preservesPartitioning) + preservesPartitioning = preservesPartitioning, + isOrderSensitive = isOrderSensitive) } /** @@ -1636,6 +1642,16 @@ abstract class RDD[T: ClassTag]( } } + /** + * Return whether this RDD is reliably checkpointed and materialized. + */ + private[rdd] def isReliablyCheckpointed: Boolean = { + checkpointData match { + case Some(reliable: ReliableRDDCheckpointData[_]) if reliable.isCheckpointed => true + case _ => false + } + } + /** * Gets the name of the directory to which this RDD was checkpointed. * This is not defined if the RDD is checkpointed locally. @@ -1873,6 +1889,63 @@ abstract class RDD[T: ClassTag]( // RDD chain. @transient protected lazy val isBarrier_ : Boolean = dependencies.filter(!_.isInstanceOf[ShuffleDependency[_, _, _]]).exists(_.rdd.isBarrier()) + + /** + * Returns the deterministic level of this RDD's output. Please refer to [[DeterministicLevel]] + * for the definition. + * + * By default, an reliably checkpointed RDD, or RDD without parents(root RDD) is DETERMINATE. For + * RDDs with parents, we will generate a deterministic level candidate per parent according to + * the dependency. The deterministic level of the current RDD is the deterministic level + * candidate that is deterministic least. Please override [[getOutputDeterministicLevel]] to + * provide custom logic of calculating output deterministic level. + */ + // TODO: make it public so users can set deterministic level to their custom RDDs. + // TODO: this can be per-partition. e.g. UnionRDD can have different deterministic level for + // different partitions. + private[spark] final lazy val outputDeterministicLevel: DeterministicLevel.Value = { + if (isReliablyCheckpointed) { + DeterministicLevel.DETERMINATE + } else { + getOutputDeterministicLevel + } + } + + @DeveloperApi + protected def getOutputDeterministicLevel: DeterministicLevel.Value = { + val deterministicLevelCandidates = dependencies.map { + // The shuffle is not really happening, treat it like narrow dependency and assume the output + // deterministic level of current RDD is same as parent. + case dep: ShuffleDependency[_, _, _] if dep.rdd.partitioner.exists(_ == dep.partitioner) => + dep.rdd.outputDeterministicLevel + + case dep: ShuffleDependency[_, _, _] => + if (dep.rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE) { + // If map output was indeterminate, shuffle output will be indeterminate as well + DeterministicLevel.INDETERMINATE + } else if (dep.keyOrdering.isDefined && dep.aggregator.isDefined) { + // if aggregator specified (and so unique keys) and key ordering specified - then + // consistent ordering. + DeterministicLevel.DETERMINATE + } else { + // In Spark, the reducer fetches multiple remote shuffle blocks at the same time, and + // the arrival order of these shuffle blocks are totally random. Even if the parent map + // RDD is DETERMINATE, the reduce RDD is always UNORDERED. + DeterministicLevel.UNORDERED + } + + // For narrow dependency, assume the output deterministic level of current RDD is same as + // parent. + case dep => dep.rdd.outputDeterministicLevel + } + + if (deterministicLevelCandidates.isEmpty) { + // By default we assume the root RDD is determinate. + DeterministicLevel.DETERMINATE + } else { + deterministicLevelCandidates.maxBy(_.id) + } + } } @@ -1926,3 +1999,18 @@ object RDD { new DoubleRDDFunctions(rdd.map(x => num.toDouble(x))) } } + +/** + * The deterministic level of RDD's output (i.e. what `RDD#compute` returns). This explains how + * the output will diff when Spark reruns the tasks for the RDD. There are 3 deterministic levels: + * 1. DETERMINATE: The RDD output is always the same data set in the same order after a rerun. + * 2. UNORDERED: The RDD output is always the same data set but the order can be different + * after a rerun. + * 3. INDETERMINATE. The RDD output can be different after a rerun. + * + * Note that, the output of an RDD usually relies on the parent RDDs. When the parent RDD's output + * is INDETERMINATE, it's very likely the RDD's output is also INDETERMINATE. + */ +private[spark] object DeterministicLevel extends Enumeration { + val DETERMINATE, UNORDERED, INDETERMINATE = Value +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index fec6558f412d0..50c91da8b13d1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -40,7 +40,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} -import org.apache.spark.rdd.{RDD, RDDCheckpointData} +import org.apache.spark.rdd.{DeterministicLevel, RDD, RDDCheckpointData} import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -1487,6 +1487,63 @@ private[spark] class DAGScheduler( failedStages += failedStage failedStages += mapStage if (noResubmitEnqueued) { + // If the map stage is INDETERMINATE, which means the map tasks may return + // different result when re-try, we need to re-try all the tasks of the failed + // stage and its succeeding stages, because the input data will be changed after the + // map tasks are re-tried. + // Note that, if map stage is UNORDERED, we are fine. The shuffle partitioner is + // guaranteed to be determinate, so the input data of the reducers will not change + // even if the map tasks are re-tried. + if (mapStage.rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE) { + // It's a little tricky to find all the succeeding stages of `failedStage`, because + // each stage only know its parents not children. Here we traverse the stages from + // the leaf nodes (the result stages of active jobs), and rollback all the stages + // in the stage chains that connect to the `failedStage`. To speed up the stage + // traversing, we collect the stages to rollback first. If a stage needs to + // rollback, all its succeeding stages need to rollback to. + val stagesToRollback = scala.collection.mutable.HashSet(failedStage) + + def collectStagesToRollback(stageChain: List[Stage]): Unit = { + if (stagesToRollback.contains(stageChain.head)) { + stageChain.drop(1).foreach(s => stagesToRollback += s) + } else { + stageChain.head.parents.foreach { s => + collectStagesToRollback(s :: stageChain) + } + } + } + + def generateErrorMessage(stage: Stage): String = { + "A shuffle map stage with indeterminate output was failed and retried. " + + s"However, Spark cannot rollback the $stage to re-process the input data, " + + "and has to fail this job. Please eliminate the indeterminacy by " + + "checkpointing the RDD before repartition and try again." + } + + activeJobs.foreach(job => collectStagesToRollback(job.finalStage :: Nil)) + + stagesToRollback.foreach { + case mapStage: ShuffleMapStage => + val numMissingPartitions = mapStage.findMissingPartitions().length + if (numMissingPartitions < mapStage.numTasks) { + // TODO: support to rollback shuffle files. + // Currently the shuffle writing is "first write wins", so we can't re-run a + // shuffle map stage and overwrite existing shuffle files. We have to finish + // SPARK-8029 first. + abortStage(mapStage, generateErrorMessage(mapStage), None) + } + + case resultStage: ResultStage if resultStage.activeJob.isDefined => + val numMissingPartitions = resultStage.findMissingPartitions().length + if (numMissingPartitions < resultStage.numTasks) { + // TODO: support to rollback result tasks. + abortStage(resultStage, generateErrorMessage(resultStage), None) + } + + case _ => + } + } + // We expect one executor failure to trigger many FetchFailures in rapid succession, // but all of those task failures can typically be handled by a single resubmission of // the failed stage. We avoid flooding the scheduler's event queue with resubmit diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index e0202fe703f82..4e87deb136df6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -31,7 +31,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.internal.config -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{DeterministicLevel, RDD} import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException} import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} @@ -57,6 +57,20 @@ class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler) } +class MyCheckpointRDD( + sc: SparkContext, + numPartitions: Int, + dependencies: List[Dependency[_]], + locations: Seq[Seq[String]] = Nil, + @(transient @param) tracker: MapOutputTrackerMaster = null, + indeterminate: Boolean = false) + extends MyRDD(sc, numPartitions, dependencies, locations, tracker, indeterminate) { + + // Allow doCheckpoint() on this RDD. + override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = + Iterator.empty +} + /** * An RDD for passing to DAGScheduler. These RDDs will use the dependencies and * preferredLocations (if any) that are passed to them. They are deliberately not executable @@ -71,7 +85,8 @@ class MyRDD( numPartitions: Int, dependencies: List[Dependency[_]], locations: Seq[Seq[String]] = Nil, - @(transient @param) tracker: MapOutputTrackerMaster = null) + @(transient @param) tracker: MapOutputTrackerMaster = null, + indeterminate: Boolean = false) extends RDD[(Int, Int)](sc, dependencies) with Serializable { override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = @@ -81,6 +96,10 @@ class MyRDD( override def index: Int = i }).toArray + override protected def getOutputDeterministicLevel = { + if (indeterminate) DeterministicLevel.INDETERMINATE else super.getOutputDeterministicLevel + } + override def getPreferredLocations(partition: Partition): Seq[String] = { if (locations.isDefinedAt(partition.index)) { locations(partition.index) @@ -2634,6 +2653,152 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(countSubmittedMapStageAttempts() === 2) } + test("SPARK-23207: retry all the succeeding stages when the map stage is indeterminate") { + val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true) + + val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2)) + val shuffleId1 = shuffleDep1.shuffleId + val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker) + + val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(2)) + val shuffleId2 = shuffleDep2.shuffleId + val finalRdd = new MyRDD(sc, 2, List(shuffleDep2), tracker = mapOutputTracker) + + submit(finalRdd, Array(0, 1)) + + // Finish the first shuffle map stage. + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty)) + + // Finish the second shuffle map stage. + complete(taskSets(1), Seq( + (Success, makeMapStatus("hostC", 2)), + (Success, makeMapStatus("hostD", 2)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty)) + + // The first task of the final stage failed with fetch failure + runEvent(makeCompletionEvent( + taskSets(2).tasks(0), + FetchFailed(makeBlockManagerId("hostC"), shuffleId2, 0, 0, "ignored"), + null)) + + val failedStages = scheduler.failedStages.toSeq + assert(failedStages.length == 2) + // Shuffle blocks of "hostC" is lost, so first task of the `shuffleMapRdd2` needs to retry. + assert(failedStages.collect { + case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId2 => stage + }.head.findMissingPartitions() == Seq(0)) + // The result stage is still waiting for its 2 tasks to complete + assert(failedStages.collect { + case stage: ResultStage => stage + }.head.findMissingPartitions() == Seq(0, 1)) + + scheduler.resubmitFailedStages() + + // The first task of the `shuffleMapRdd2` failed with fetch failure + runEvent(makeCompletionEvent( + taskSets(3).tasks(0), + FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0, 0, "ignored"), + null)) + + // The job should fail because Spark can't rollback the shuffle map stage. + assert(failure != null && failure.getMessage.contains("Spark cannot rollback")) + } + + private def assertResultStageFailToRollback(mapRdd: MyRDD): Unit = { + val shuffleDep = new ShuffleDependency(mapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + + submit(finalRdd, Array(0, 1)) + + completeShuffleMapStageSuccessfully(taskSets.length - 1, 0, numShufflePartitions = 2) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty)) + + // Finish the first task of the result stage + runEvent(makeCompletionEvent( + taskSets.last.tasks(0), Success, 42, + Seq.empty, createFakeTaskInfoWithId(0))) + + // Fail the second task with FetchFailed. + runEvent(makeCompletionEvent( + taskSets.last.tasks(1), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + null)) + + // The job should fail because Spark can't rollback the result stage. + assert(failure != null && failure.getMessage.contains("Spark cannot rollback")) + } + + test("SPARK-23207: cannot rollback a result stage") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil, indeterminate = true) + assertResultStageFailToRollback(shuffleMapRdd) + } + + test("SPARK-23207: local checkpoint fail to rollback (checkpointed before)") { + val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true) + shuffleMapRdd.localCheckpoint() + shuffleMapRdd.doCheckpoint() + assertResultStageFailToRollback(shuffleMapRdd) + } + + test("SPARK-23207: local checkpoint fail to rollback (checkpointing now)") { + val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true) + shuffleMapRdd.localCheckpoint() + assertResultStageFailToRollback(shuffleMapRdd) + } + + private def assertResultStageNotRollbacked(mapRdd: MyRDD): Unit = { + val shuffleDep = new ShuffleDependency(mapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + + submit(finalRdd, Array(0, 1)) + + completeShuffleMapStageSuccessfully(taskSets.length - 1, 0, numShufflePartitions = 2) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty)) + + // Finish the first task of the result stage + runEvent(makeCompletionEvent( + taskSets.last.tasks(0), Success, 42, + Seq.empty, createFakeTaskInfoWithId(0))) + + // Fail the second task with FetchFailed. + runEvent(makeCompletionEvent( + taskSets.last.tasks(1), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + null)) + + assert(failure == null, "job should not fail") + val failedStages = scheduler.failedStages.toSeq + assert(failedStages.length == 2) + // Shuffle blocks of "hostA" is lost, so first task of the `shuffleMapRdd2` needs to retry. + assert(failedStages.collect { + case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId => stage + }.head.findMissingPartitions() == Seq(0)) + // The first task of result stage remains completed. + assert(failedStages.collect { + case stage: ResultStage => stage + }.head.findMissingPartitions() == Seq(1)) + } + + test("SPARK-23207: reliable checkpoint can avoid rollback (checkpointed before)") { + sc.setCheckpointDir(Utils.createTempDir().getCanonicalPath) + val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true) + shuffleMapRdd.checkpoint() + shuffleMapRdd.doCheckpoint() + assertResultStageNotRollbacked(shuffleMapRdd) + } + + test("SPARK-23207: reliable checkpoint fail to rollback (checkpointing now)") { + sc.setCheckpointDir(Utils.createTempDir().getCanonicalPath) + val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true) + shuffleMapRdd.checkpoint() + assertResultStageFailToRollback(shuffleMapRdd) + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 50f10c31427d0..9576605b1a214 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -258,6 +258,9 @@ object ShuffleExchangeExec { case _ => sys.error(s"Exchange not implemented for $newPartitioning") } + val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] && + newPartitioning.numPartitions > 1 + val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = { // [SPARK-23207] Have to make sure the generated RoundRobinPartitioning is deterministic, // otherwise a retry task may output different rows and thus lead to data loss. @@ -267,9 +270,7 @@ object ShuffleExchangeExec { // // Note that we don't perform local sort if the new partitioning has only 1 partition, under // that case all output rows go to the same partition. - val newRdd = if (SQLConf.get.sortBeforeRepartition && - newPartitioning.numPartitions > 1 && - newPartitioning.isInstanceOf[RoundRobinPartitioning]) { + val newRdd = if (isRoundRobin && SQLConf.get.sortBeforeRepartition) { rdd.mapPartitionsInternal { iter => val recordComparatorSupplier = new Supplier[RecordComparator] { override def get: RecordComparator = new RecordBinaryComparator() @@ -305,17 +306,19 @@ object ShuffleExchangeExec { rdd } + // round-robin function is order sensitive if we don't sort the input. + val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition if (needToCopyObjectsBeforeShuffle(part)) { - newRdd.mapPartitionsInternal { iter => + newRdd.mapPartitionsWithIndexInternal((_, iter) => { val getPartitionKey = getPartitionKeyExtractor() iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } - } + }, isOrderSensitive = isOrderSensitive) } else { - newRdd.mapPartitionsInternal { iter => + newRdd.mapPartitionsWithIndexInternal((_, iter) => { val getPartitionKey = getPartitionKeyExtractor() val mutablePair = new MutablePair[Int, InternalRow]() iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) } - } + }, isOrderSensitive = isOrderSensitive) } } From 458468ad5163211b9275faa49cd817a974a6dc21 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 5 Sep 2018 15:41:45 -0700 Subject: [PATCH 1552/2461] [SPARK-25335][BUILD] Skip Zinc downloading if it's installed in the system ## What changes were proposed in this pull request? Zinc is 23.5MB (tgz). ``` $ curl -LO https://downloads.lightbend.com/zinc/0.3.15/zinc-0.3.15.tgz % Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 100 23.5M 100 23.5M 0 0 35.4M 0 --:--:-- --:--:-- --:--:-- 35.3M ``` Currently, Spark downloads Zinc once. However, it occurs too many times in build systems. This PR aims to skip Zinc downloading when the system already has it. ``` $ build/mvn clean exec: curl --progress-bar -L https://downloads.lightbend.com/zinc/0.3.15/zinc-0.3.15.tgz ######################################################################## 100.0% ``` This will reduce many resources(CPU/Networks/DISK) at least in Mac and Docker-based build system. ## How was this patch tested? Pass the Jenkins. Closes #22333 from dongjoon-hyun/SPARK-25335. Authored-by: Dongjoon Hyun Signed-off-by: Sean Owen --- build/mvn | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/build/mvn b/build/mvn index ae4276dbc7e32..2487b81abb4ea 100755 --- a/build/mvn +++ b/build/mvn @@ -67,6 +67,9 @@ install_app() { fi } +# See simple version normalization: http://stackoverflow.com/questions/16989598/bash-comparing-version-numbers +function version { echo "$@" | awk -F. '{ printf("%03d%03d%03d\n", $1,$2,$3); }'; } + # Determine the Maven version from the root pom.xml file and # install maven under the build/ folder if needed. install_mvn() { @@ -75,8 +78,6 @@ install_mvn() { if [ "$MVN_BIN" ]; then local MVN_DETECTED_VERSION="$(mvn --version | head -n1 | awk '{print $3}')" fi - # See simple version normalization: http://stackoverflow.com/questions/16989598/bash-comparing-version-numbers - function version { echo "$@" | awk -F. '{ printf("%03d%03d%03d\n", $1,$2,$3); }'; } if [ $(version $MVN_DETECTED_VERSION) -lt $(version $MVN_VERSION) ]; then local APACHE_MIRROR=${APACHE_MIRROR:-'https://www.apache.org/dyn/closer.lua?action=download&filename='} @@ -91,15 +92,23 @@ install_mvn() { # Install zinc under the build/ folder install_zinc() { - local zinc_path="zinc-0.3.15/bin/zinc" - [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 - local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.lightbend.com} + local ZINC_VERSION=0.3.15 + ZINC_BIN="$(command -v zinc)" + if [ "$ZINC_BIN" ]; then + local ZINC_DETECTED_VERSION="$(zinc -version | head -n1 | awk '{print $5}')" + fi - install_app \ - "${TYPESAFE_MIRROR}/zinc/0.3.15" \ - "zinc-0.3.15.tgz" \ - "${zinc_path}" - ZINC_BIN="${_DIR}/${zinc_path}" + if [ $(version $ZINC_DETECTED_VERSION) -lt $(version $ZINC_VERSION) ]; then + local zinc_path="zinc-${ZINC_VERSION}/bin/zinc" + [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 + local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.lightbend.com} + + install_app \ + "${TYPESAFE_MIRROR}/zinc/${ZINC_VERSION}" \ + "zinc-${ZINC_VERSION}.tgz" \ + "${zinc_path}" + ZINC_BIN="${_DIR}/${zinc_path}" + fi } # Determine the Scala version from the root pom.xml file, set the Scala URL, From 3e033035a3c0b7d46c2ae18d0d322d4af3808711 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 5 Sep 2018 15:48:41 -0700 Subject: [PATCH 1553/2461] [SPARK-25258][SPARK-23131][SPARK-25176][BUILD] Upgrade Kryo to 4.0.2 ## What changes were proposed in this pull request? Upgrade chill to 0.9.3, Kryo to 4.0.2, to get bug fixes and improvements. The resolved tickets includes: - SPARK-25258 Upgrade kryo package to version 4.0.2 - SPARK-23131 Kryo raises StackOverflow during serializing GLR model - SPARK-25176 Kryo fails to serialize a parametrised type hierarchy More details: https://github.com/twitter/chill/releases/tag/v0.9.3 https://github.com/twitter/chill/commit/cc3910d501a844f3c882249fef8fc2560b95b6dd ## How was this patch tested? Existing tests. Closes #22179 from wangyum/SPARK-23131. Lead-authored-by: Yuming Wang Co-authored-by: Dongjoon Hyun Signed-off-by: Sean Owen --- .../serializer/KryoSerializerSuite.scala | 20 +++++++++++++++++++ dev/deps/spark-deps-hadoop-2.6 | 8 ++++---- dev/deps/spark-deps-hadoop-2.7 | 8 ++++---- dev/deps/spark-deps-hadoop-3.1 | 8 ++++---- docs/tuning.md | 2 +- .../GeneralizedLinearRegressionSuite.scala | 11 +++++++++- pom.xml | 6 +++++- 7 files changed, 48 insertions(+), 15 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 240f8cf800fe8..36912441c03bd 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -412,6 +412,26 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { assert(!ser2.getAutoReset) } + test("SPARK-25176 ClassCastException when writing a Map after previously " + + "reading a Map with different generic type") { + // This test uses the example in https://github.com/EsotericSoftware/kryo/issues/384 + import java.util._ + val ser = new KryoSerializer(new SparkConf).newInstance().asInstanceOf[KryoSerializerInstance] + + class MapHolder { + private val mapOne = new HashMap[Int, String] + private val mapTwo = this.mapOne + } + + val serializedMapHolder = ser.serialize(new MapHolder) + ser.deserialize[MapHolder](serializedMapHolder) + + val stringMap = new HashMap[Int, List[String]] + stringMap.put(1, new ArrayList[String]) + val serializedMap = ser.serialize[Map[Int, List[String]]](stringMap) + ser.deserialize[HashMap[Int, List[String]]](serializedMap) + } + private def testSerializerInstanceReuse(autoReset: Boolean, referenceTracking: Boolean): Unit = { val conf = new SparkConf(loadDefaults = false) .set("spark.kryo.referenceTracking", referenceTracking.toString) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index fc42af905c2fe..62ae04dbc255f 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -27,8 +27,8 @@ breeze_2.11-0.13.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar -chill-java-0.8.4.jar -chill_2.11-0.8.4.jar +chill-java-0.9.3.jar +chill_2.11-0.9.3.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -130,7 +130,7 @@ jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar -kryo-shaded-3.0.3.jar +kryo-shaded-4.0.2.jar kubernetes-client-3.0.0.jar kubernetes-model-2.0.0.jar leveldbjni-all-1.8.jar @@ -149,7 +149,7 @@ metrics-jvm-3.1.5.jar minlog-1.3.0.jar netty-3.9.9.Final.jar netty-all-4.1.17.Final.jar -objenesis-2.1.jar +objenesis-2.5.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 54e50556b4620..5e12ca053af51 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -27,8 +27,8 @@ breeze_2.11-0.13.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar -chill-java-0.8.4.jar -chill_2.11-0.8.4.jar +chill-java-0.9.3.jar +chill_2.11-0.9.3.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -132,7 +132,7 @@ jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar -kryo-shaded-3.0.3.jar +kryo-shaded-4.0.2.jar kubernetes-client-3.0.0.jar kubernetes-model-2.0.0.jar leveldbjni-all-1.8.jar @@ -151,7 +151,7 @@ metrics-jvm-3.1.5.jar minlog-1.3.0.jar netty-3.9.9.Final.jar netty-all-4.1.17.Final.jar -objenesis-2.1.jar +objenesis-2.5.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index ff5713b5b66b7..641b4a15ad7cd 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -25,8 +25,8 @@ breeze_2.11-0.13.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar -chill-java-0.8.4.jar -chill_2.11-0.8.4.jar +chill-java-0.9.3.jar +chill_2.11-0.9.3.jar commons-beanutils-1.9.3.jar commons-cli-1.2.jar commons-codec-1.10.jar @@ -146,7 +146,7 @@ kerby-config-1.0.1.jar kerby-pkix-1.0.1.jar kerby-util-1.0.1.jar kerby-xdr-1.0.1.jar -kryo-shaded-3.0.3.jar +kryo-shaded-4.0.2.jar kubernetes-client-3.0.0.jar kubernetes-model-2.0.0.jar leveldbjni-all-1.8.jar @@ -167,7 +167,7 @@ mssql-jdbc-6.2.1.jre7.jar netty-3.9.9.Final.jar netty-all-4.1.17.Final.jar nimbus-jose-jwt-4.41.1.jar -objenesis-2.1.jar +objenesis-2.5.1.jar okhttp-2.7.5.jar okhttp-3.8.1.jar okio-1.13.0.jar diff --git a/docs/tuning.md b/docs/tuning.md index 1c3bd0e8758ff..f60971aa2e0af 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -35,7 +35,7 @@ in your operations) and performance. It provides two serialization libraries: Java serialization is flexible but often quite slow, and leads to large serialized formats for many classes. * [Kryo serialization](https://github.com/EsotericSoftware/kryo): Spark can also use - the Kryo library (version 2) to serialize objects more quickly. Kryo is significantly + the Kryo library (version 4) to serialize objects more quickly. Kryo is significantly faster and more compact than Java serialization (often as much as 10x), but does not support all `Serializable` types and requires you to *register* the classes you'll use in the program in advance for best performance. diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 997c50157dcda..600a43242751f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.regression import scala.util.Random -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.feature.{LabeledPoint, RFormula} @@ -29,6 +29,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.random._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.FloatType @@ -1687,6 +1688,14 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest assert(evalSummary.deviance === summary.deviance) assert(evalSummary.aic === summary.aic) } + + test("SPARK-23131 Kryo raises StackOverflow during serializing GLR model") { + val conf = new SparkConf(false) + val ser = new KryoSerializer(conf).newInstance() + val trainer = new GeneralizedLinearRegression() + val model = trainer.fit(Seq(Instance(1.0, 1.0, Vectors.dense(1.0, 7.0))).toDF) + ser.serialize[GeneralizedLinearRegressionModel](model) + } } object GeneralizedLinearRegressionSuite { diff --git a/pom.xml b/pom.xml index 6988c65348652..da526a1709e65 100644 --- a/pom.xml +++ b/pom.xml @@ -136,7 +136,7 @@ 1.6.0 9.3.24.v20180605 3.1.0 - 0.8.4 + 0.9.3 2.4.0 2.0.8 3.1.5 @@ -1770,6 +1770,10 @@ org.apache.hive hive-storage-api + + com.esotericsoftware + kryo-shaded + From 3d6b68b030ee85a0f639dd8e9b68aedf5f27b46f Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 6 Sep 2018 10:37:52 +0800 Subject: [PATCH 1554/2461] [SPARK-25313][SQL] Fix regression in FileFormatWriter output names ## What changes were proposed in this pull request? Let's see the follow example: ``` val location = "/tmp/t" val df = spark.range(10).toDF("id") df.write.format("parquet").saveAsTable("tbl") spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") spark.sql(s"CREATE TABLE tbl2(ID long) USING parquet location $location") spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1") println(spark.read.parquet(location).schema) spark.table("tbl2").show() ``` The output column name in schema will be `id` instead of `ID`, thus the last query shows nothing from `tbl2`. By enabling the debug message we can see that the output naming is changed from `ID` to `id`, and then the `outputColumns` in `InsertIntoHadoopFsRelationCommand` is changed in `RemoveRedundantAliases`. ![wechatimg5](https://user-images.githubusercontent.com/1097932/44947871-6299f200-ae46-11e8-9c96-d45fe368206c.jpeg) ![wechatimg4](https://user-images.githubusercontent.com/1097932/44947866-56ae3000-ae46-11e8-8923-8b3bbe060075.jpeg) **To guarantee correctness**, we should change the output columns from `Seq[Attribute]` to `Seq[String]` to avoid its names being replaced by optimizer. I will fix project elimination related rules in https://github.com/apache/spark/pull/22311 after this one. ## How was this patch tested? Unit test. Closes #22320 from gengliangwang/fixOutputSchema. Authored-by: Gengliang Wang Signed-off-by: Wenchen Fan --- .../command/DataWritingCommand.scala | 43 ++++++++++- .../command/createDataSourceTables.scala | 4 +- .../execution/datasources/DataSource.scala | 16 ++-- .../datasources/DataSourceStrategy.scala | 4 +- .../InsertIntoHadoopFsRelationCommand.scala | 6 +- .../sql/test/DataFrameReaderWriterSuite.scala | 74 +++++++++++++++++++ .../spark/sql/hive/HiveStrategies.scala | 6 +- .../CreateHiveTableAsSelectCommand.scala | 9 ++- .../execution/InsertIntoHiveDirCommand.scala | 2 +- .../hive/execution/InsertIntoHiveTable.scala | 2 +- .../sql/hive/execution/HiveDDLSuite.scala | 48 ++++++++++++ 11 files changed, 189 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala index e11dbd201004d..0a185b8472060 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker import org.apache.spark.sql.execution.datasources.FileFormatWriter -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration /** @@ -41,8 +42,12 @@ trait DataWritingCommand extends Command { override final def children: Seq[LogicalPlan] = query :: Nil - // Output columns of the analyzed input query plan - def outputColumns: Seq[Attribute] + // Output column names of the analyzed input query plan. + def outputColumnNames: Seq[String] + + // Output columns of the analyzed input query plan. + def outputColumns: Seq[Attribute] = + DataWritingCommand.logicalPlanOutputWithNames(query, outputColumnNames) lazy val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics @@ -53,3 +58,35 @@ trait DataWritingCommand extends Command { def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] } + +object DataWritingCommand { + /** + * Returns output attributes with provided names. + * The length of provided names should be the same of the length of [[LogicalPlan.output]]. + */ + def logicalPlanOutputWithNames( + query: LogicalPlan, + names: Seq[String]): Seq[Attribute] = { + // Save the output attributes to a variable to avoid duplicated function calls. + val outputAttributes = query.output + assert(outputAttributes.length == names.length, + "The length of provided names doesn't match the length of output attributes.") + outputAttributes.zip(names).map { case (attr, outputName) => + attr.withName(outputName) + } + } + + /** + * Returns schema of logical plan with provided names. + * The length of provided names should be the same of the length of [[LogicalPlan.schema]]. + */ + def logicalPlanSchemaWithNames( + query: LogicalPlan, + names: Seq[String]): StructType = { + assert(query.schema.length == names.length, + "The length of provided names doesn't match the length of query schema.") + StructType(query.schema.zip(names).map { case (structField, outputName) => + structField.copy(name = outputName) + }) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index f6ef433f2ce15..b2e1f530b5328 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -139,7 +139,7 @@ case class CreateDataSourceTableAsSelectCommand( table: CatalogTable, mode: SaveMode, query: LogicalPlan, - outputColumns: Seq[Attribute]) + outputColumnNames: Seq[String]) extends DataWritingCommand { override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { @@ -214,7 +214,7 @@ case class CreateDataSourceTableAsSelectCommand( catalogTable = if (tableExists) Some(table) else None) try { - dataSource.writeAndRead(mode, query, outputColumns, physicalPlan) + dataSource.writeAndRead(mode, query, outputColumnNames, physicalPlan) } catch { case ex: AnalysisException => logError(s"Failed to write to table ${table.identifier.unquotedString}", ex) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 1dcf9f3185de9..ce3bc3dd48327 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.command.DataWritingCommand import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider import org.apache.spark.sql.execution.datasources.json.JsonFileFormat @@ -450,7 +451,7 @@ case class DataSource( mode = mode, catalogTable = catalogTable, fileIndex = fileIndex, - outputColumns = data.output) + outputColumnNames = data.output.map(_.name)) } /** @@ -460,9 +461,9 @@ case class DataSource( * @param mode The save mode for this writing. * @param data The input query plan that produces the data to be written. Note that this plan * is analyzed and optimized. - * @param outputColumns The original output columns of the input query plan. The optimizer may not - * preserve the output column's names' case, so we need this parameter - * instead of `data.output`. + * @param outputColumnNames The original output column names of the input query plan. The + * optimizer may not preserve the output column's names' case, so we need + * this parameter instead of `data.output`. * @param physicalPlan The physical plan of the input query plan. We should run the writing * command with this physical plan instead of creating a new physical plan, * so that the metrics can be correctly linked to the given physical plan and @@ -471,8 +472,9 @@ case class DataSource( def writeAndRead( mode: SaveMode, data: LogicalPlan, - outputColumns: Seq[Attribute], + outputColumnNames: Seq[String], physicalPlan: SparkPlan): BaseRelation = { + val outputColumns = DataWritingCommand.logicalPlanOutputWithNames(data, outputColumnNames) if (outputColumns.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { throw new AnalysisException("Cannot save interval data type into external storage.") } @@ -495,7 +497,9 @@ case class DataSource( s"Unable to resolve $name given [${data.output.map(_.name).mkString(", ")}]") } } - val resolved = cmd.copy(partitionColumns = resolvedPartCols, outputColumns = outputColumns) + val resolved = cmd.copy( + partitionColumns = resolvedPartCols, + outputColumnNames = outputColumnNames) resolved.run(sparkSession, physicalPlan) // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 6b61e749e3063..c6000442fae76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -139,7 +139,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast case CreateTable(tableDesc, mode, Some(query)) if query.resolved && DDLUtils.isDatasourceTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc.copy(schema = query.schema)) - CreateDataSourceTableAsSelectCommand(tableDesc, mode, query, query.output) + CreateDataSourceTableAsSelectCommand(tableDesc, mode, query, query.output.map(_.name)) case InsertIntoTable(l @ LogicalRelation(_: InsertableRelation, _, _, _), parts, query, overwrite, false) if parts.isEmpty => @@ -209,7 +209,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast mode, table, Some(t.location), - actualQuery.output) + actualQuery.output.map(_.name)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 2ae21b7df9823..484942d35c857 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -56,14 +56,14 @@ case class InsertIntoHadoopFsRelationCommand( mode: SaveMode, catalogTable: Option[CatalogTable], fileIndex: Option[FileIndex], - outputColumns: Seq[Attribute]) + outputColumnNames: Seq[String]) extends DataWritingCommand { import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { // Most formats don't do well with duplicate columns, so lets not allow that - SchemaUtils.checkSchemaColumnNameDuplication( - query.schema, + SchemaUtils.checkColumnNameDuplication( + outputColumnNames, s"when inserting into $outputPath", sparkSession.sessionState.conf.caseSensitiveAnalysis) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index b65058fffd339..237872585e11d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -805,6 +805,80 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be } } + test("Insert overwrite table command should output correct schema: basic") { + withTable("tbl", "tbl2") { + withView("view1") { + val df = spark.range(10).toDF("id") + df.write.format("parquet").saveAsTable("tbl") + spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") + spark.sql("CREATE TABLE tbl2(ID long) USING parquet") + spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1") + val identifier = TableIdentifier("tbl2") + val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString + val expectedSchema = StructType(Seq(StructField("ID", LongType, true))) + assert(spark.read.parquet(location).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), df) + } + } + } + + test("Insert overwrite table command should output correct schema: complex") { + withTable("tbl", "tbl2") { + withView("view1") { + val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3") + df.write.format("parquet").saveAsTable("tbl") + spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl") + spark.sql("CREATE TABLE tbl2(COL1 long, COL2 int, COL3 int) USING parquet PARTITIONED " + + "BY (COL2) CLUSTERED BY (COL3) INTO 3 BUCKETS") + spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT COL1, COL2, COL3 FROM view1") + val identifier = TableIdentifier("tbl2") + val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString + val expectedSchema = StructType(Seq( + StructField("COL1", LongType, true), + StructField("COL3", IntegerType, true), + StructField("COL2", IntegerType, true))) + assert(spark.read.parquet(location).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), df) + } + } + } + + test("Create table as select command should output correct schema: basic") { + withTable("tbl", "tbl2") { + withView("view1") { + val df = spark.range(10).toDF("id") + df.write.format("parquet").saveAsTable("tbl") + spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") + spark.sql("CREATE TABLE tbl2 USING parquet AS SELECT ID FROM view1") + val identifier = TableIdentifier("tbl2") + val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString + val expectedSchema = StructType(Seq(StructField("ID", LongType, true))) + assert(spark.read.parquet(location).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), df) + } + } + } + + test("Create table as select command should output correct schema: complex") { + withTable("tbl", "tbl2") { + withView("view1") { + val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3") + df.write.format("parquet").saveAsTable("tbl") + spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl") + spark.sql("CREATE TABLE tbl2 USING parquet PARTITIONED BY (COL2) " + + "CLUSTERED BY (COL3) INTO 3 BUCKETS AS SELECT COL1, COL2, COL3 FROM view1") + val identifier = TableIdentifier("tbl2") + val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString + val expectedSchema = StructType(Seq( + StructField("COL1", LongType, true), + StructField("COL3", IntegerType, true), + StructField("COL2", IntegerType, true))) + assert(spark.read.parquet(location).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), df) + } + } + } + test("use Spark jobs to list files") { withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "1") { withTempDir { dir => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 9fe83bb332a9a..07ee105404311 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -149,7 +149,7 @@ object HiveAnalysis extends Rule[LogicalPlan] { case InsertIntoTable(r: HiveTableRelation, partSpec, query, overwrite, ifPartitionNotExists) if DDLUtils.isHiveTable(r.tableMeta) => InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, - ifPartitionNotExists, query.output) + ifPartitionNotExists, query.output.map(_.name)) case CreateTable(tableDesc, mode, None) if DDLUtils.isHiveTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc) @@ -157,14 +157,14 @@ object HiveAnalysis extends Rule[LogicalPlan] { case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isHiveTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc) - CreateHiveTableAsSelectCommand(tableDesc, query, query.output, mode) + CreateHiveTableAsSelectCommand(tableDesc, query, query.output.map(_.name), mode) case InsertIntoDir(isLocal, storage, provider, child, overwrite) if DDLUtils.isHiveTable(provider) => val outputPath = new Path(storage.locationUri.get) if (overwrite) DDLUtils.verifyNotReadPath(child, outputPath) - InsertIntoHiveDirCommand(isLocal, storage, child, overwrite, child.output) + InsertIntoHiveDirCommand(isLocal, storage, child, overwrite, child.output.map(_.name)) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 27d807cc35627..0eb2f0de0acd9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.command.DataWritingCommand case class CreateHiveTableAsSelectCommand( tableDesc: CatalogTable, query: LogicalPlan, - outputColumns: Seq[Attribute], + outputColumnNames: Seq[String], mode: SaveMode) extends DataWritingCommand { @@ -63,13 +63,14 @@ case class CreateHiveTableAsSelectCommand( query, overwrite = false, ifPartitionNotExists = false, - outputColumns = outputColumns).run(sparkSession, child) + outputColumnNames = outputColumnNames).run(sparkSession, child) } else { // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in case of failure occurs while data // processing. assert(tableDesc.schema.isEmpty) - catalog.createTable(tableDesc.copy(schema = query.schema), ignoreIfExists = false) + val schema = DataWritingCommand.logicalPlanSchemaWithNames(query, outputColumnNames) + catalog.createTable(tableDesc.copy(schema = schema), ignoreIfExists = false) try { // Read back the metadata of the table which was created just now. @@ -82,7 +83,7 @@ case class CreateHiveTableAsSelectCommand( query, overwrite = true, ifPartitionNotExists = false, - outputColumns = outputColumns).run(sparkSession, child) + outputColumnNames = outputColumnNames).run(sparkSession, child) } catch { case NonFatal(e) => // drop the created table. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala index cebeca0ce9444..0a73aaa94bc75 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala @@ -57,7 +57,7 @@ case class InsertIntoHiveDirCommand( storage: CatalogStorageFormat, query: LogicalPlan, overwrite: Boolean, - outputColumns: Seq[Attribute]) extends SaveAsHiveFile { + outputColumnNames: Seq[String]) extends SaveAsHiveFile { override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { assert(storage.locationUri.nonEmpty) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 02a60f16b3b3a..75a0563e72c91 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -69,7 +69,7 @@ case class InsertIntoHiveTable( query: LogicalPlan, overwrite: Boolean, ifPartitionNotExists: Boolean, - outputColumns: Seq[Attribute]) extends SaveAsHiveFile { + outputColumnNames: Seq[String]) extends SaveAsHiveFile { /** * Inserts all the rows in the table into Hive. Row objects are properly serialized with the diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 6708a50a961fd..9acd5e1c248eb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -755,6 +755,54 @@ class HiveDDLSuite } } + test("Insert overwrite Hive table should output correct schema") { + withSQLConf(CONVERT_METASTORE_PARQUET.key -> "false") { + withTable("tbl", "tbl2") { + withView("view1") { + spark.sql("CREATE TABLE tbl(id long)") + spark.sql("INSERT OVERWRITE TABLE tbl VALUES 4") + spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") + withTempPath { path => + sql( + s""" + |CREATE TABLE tbl2(ID long) USING hive + |OPTIONS(fileFormat 'parquet') + |LOCATION '${path.toURI}' + """.stripMargin) + spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1") + val expectedSchema = StructType(Seq(StructField("ID", LongType, true))) + assert(spark.read.parquet(path.toString).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), Seq(Row(4))) + } + } + } + } + } + + test("Create Hive table as select should output correct schema") { + withSQLConf(CONVERT_METASTORE_PARQUET.key -> "false") { + withTable("tbl", "tbl2") { + withView("view1") { + spark.sql("CREATE TABLE tbl(id long)") + spark.sql("INSERT OVERWRITE TABLE tbl VALUES 4") + spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") + withTempPath { path => + sql( + s""" + |CREATE TABLE tbl2 USING hive + |OPTIONS(fileFormat 'parquet') + |LOCATION '${path.toURI}' + |AS SELECT ID FROM view1 + """.stripMargin) + val expectedSchema = StructType(Seq(StructField("ID", LongType, true))) + assert(spark.read.parquet(path.toString).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), Seq(Row(4))) + } + } + } + } + } + test("alter table partition - storage information") { sql("CREATE TABLE boxes (height INT, length INT) PARTITIONED BY (width INT)") sql("INSERT OVERWRITE TABLE boxes PARTITION (width=4) SELECT 4, 4") From 0a5a49a51c85d2c81c38104d3fcc8e0fa330ccc5 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 5 Sep 2018 21:10:51 -0700 Subject: [PATCH 1555/2461] [SPARK-25337][SQL][TEST] runSparkSubmit` should provide non-testing mode ## What changes were proposed in this pull request? `HiveExternalCatalogVersionsSuite` Scala-2.12 test has been failing due to class path issue. It is marked as `ABORTED` because it fails at `beforeAll` during data population stage. - https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-maven-hadoop-2.7-ubuntu-scala-2.12/ ``` org.apache.spark.sql.hive.HiveExternalCatalogVersionsSuite *** ABORTED *** Exception encountered when invoking run on a nested suite - spark-submit returned with exit code 1. ``` The root cause of the failure is that `runSparkSubmit` mixes 2.4.0-SNAPSHOT classes and old Spark (2.1.3/2.2.2/2.3.1) together during `spark-submit`. This PR aims to provide `non-test` mode execution mode to `runSparkSubmit` by removing the followings. - SPARK_TESTING - SPARK_SQL_TESTING - SPARK_PREPEND_CLASSES - SPARK_DIST_CLASSPATH Previously, in the class path, new Spark classes are behind the old Spark classes. So, new ones are unseen. However, Spark 2.4.0 reveals this bug due to the recent data source class changes. ## How was this patch tested? Manual test. After merging, it will be tested via Jenkins. ```scala $ dev/change-scala-version.sh 2.12 $ build/mvn -DskipTests -Phive -Pscala-2.12 clean package $ build/mvn -Phive -Pscala-2.12 -Dtest=none -DwildcardSuites=org.apache.spark.sql.hive.HiveExternalCatalogVersionsSuite test ... HiveExternalCatalogVersionsSuite: - backward compatibility ... Tests: succeeded 1, failed 0, canceled 0, ignored 0, pending 0 All tests passed. ``` Closes #22340 from dongjoon-hyun/SPARK-25337. Authored-by: Dongjoon Hyun Signed-off-by: Sean Owen --- .../hive/HiveExternalCatalogVersionsSuite.scala | 2 +- .../spark/sql/hive/SparkSubmitTestUtils.scala | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 5103aa8a207db..25df3339e62f3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -181,7 +181,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { "--conf", s"spark.sql.test.version.index=$index", "--driver-java-options", s"-Dderby.system.home=${wareHousePath.getCanonicalPath}", tempPyFile.getCanonicalPath) - runSparkSubmit(args, Some(sparkHome.getCanonicalPath)) + runSparkSubmit(args, Some(sparkHome.getCanonicalPath), false) } tempPyFile.delete() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala index 68ed97d6d1f5a..889f81b056397 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala @@ -38,7 +38,10 @@ trait SparkSubmitTestUtils extends SparkFunSuite with TimeLimits { // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. // This is copied from org.apache.spark.deploy.SparkSubmitSuite - protected def runSparkSubmit(args: Seq[String], sparkHomeOpt: Option[String] = None): Unit = { + protected def runSparkSubmit( + args: Seq[String], + sparkHomeOpt: Option[String] = None, + isSparkTesting: Boolean = true): Unit = { val sparkHome = sparkHomeOpt.getOrElse( sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))) val history = ArrayBuffer.empty[String] @@ -53,7 +56,14 @@ trait SparkSubmitTestUtils extends SparkFunSuite with TimeLimits { val builder = new ProcessBuilder(commands: _*).directory(new File(sparkHome)) val env = builder.environment() - env.put("SPARK_TESTING", "1") + if (isSparkTesting) { + env.put("SPARK_TESTING", "1") + } else { + env.remove("SPARK_TESTING") + env.remove("SPARK_SQL_TESTING") + env.remove("SPARK_PREPEND_CLASSES") + env.remove("SPARK_DIST_CLASSPATH") + } env.put("SPARK_HOME", sparkHome) def captureOutput(source: String)(line: String): Unit = { From d749d034a80f528932f613ac97f13cfb99acd207 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 6 Sep 2018 12:35:59 +0800 Subject: [PATCH 1556/2461] [SPARK-25252][SQL] Support arrays of any types by to_json ## What changes were proposed in this pull request? In the PR, I propose to extended `to_json` and support any types as element types of input arrays. It should allow converting arrays of primitive types and arrays of arrays. For example: ``` select to_json(array('1','2','3')) > ["1","2","3"] select to_json(array(array(1,2,3),array(4))) > [[1,2,3],[4]] ``` ## How was this patch tested? Added a couple sql tests for arrays of primitive type and of arrays. Also I added round trip test `from_json` -> `to_json`. Closes #22226 from MaxGekk/to_json-array. Authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- R/pkg/R/functions.R | 4 +- R/pkg/tests/fulltests/test_sparkSQL.R | 9 ++++ python/pyspark/sql/functions.py | 12 +++-- .../expressions/jsonExpressions.scala | 40 +++++++-------- .../sql/catalyst/json/JacksonGenerator.scala | 29 +++++------ .../sql/catalyst/json/JacksonUtils.scala | 12 +++-- .../org/apache/spark/sql/functions.scala | 18 +++---- .../sql-tests/inputs/json-functions.sql | 5 ++ .../sql-tests/results/json-functions.sql.out | 22 +++++++-- .../apache/spark/sql/JsonFunctionsSuite.scala | 49 +++++++++++++++++++ 10 files changed, 139 insertions(+), 61 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index d157acc3ca47b..572dee50127b8 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1699,8 +1699,8 @@ setMethod("to_date", }) #' @details -#' \code{to_json}: Converts a column containing a \code{structType}, array of \code{structType}, -#' a \code{mapType} or array of \code{mapType} into a Column of JSON string. +#' \code{to_json}: Converts a column containing a \code{structType}, a \code{mapType} +#' or an \code{arrayType} into a Column of JSON string. #' Resolving the Column can fail if an unsupported type is encountered. #' #' @rdname column_collection_functions diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 5c07a028f8b0e..0c4bdb31b027b 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1686,6 +1686,15 @@ test_that("column functions", { expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 }))) } + # Test to_json() supports arrays of primitive types and arrays + df <- sql("SELECT array(19, 42, 70) as age") + j <- collect(select(df, alias(to_json(df$age), "json"))) + expect_equal(j[order(j$json), ][1], "[19,42,70]") + + df <- sql("SELECT array(array(1, 2), array(3, 4)) as matrix") + j <- collect(select(df, alias(to_json(df$matrix), "json"))) + expect_equal(j[order(j$json), ][1], "[[1,2],[3,4]]") + # passing option df <- as.DataFrame(list(list("col" = "{\"date\":\"21/10/2014\"}"))) schema2 <- structType(structField("date", "date")) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d58d8d10e5cd3..864780e0be9bd 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2289,13 +2289,11 @@ def from_json(col, schema, options={}): @since(2.1) def to_json(col, options={}): """ - Converts a column containing a :class:`StructType`, :class:`ArrayType` of - :class:`StructType`\\s, a :class:`MapType` or :class:`ArrayType` of :class:`MapType`\\s + Converts a column containing a :class:`StructType`, :class:`ArrayType` or a :class:`MapType` into a JSON string. Throws an exception, in the case of an unsupported type. - :param col: name of column containing the struct, array of the structs, the map or - array of the maps. - :param options: options to control converting. accepts the same options as the json datasource + :param col: name of column containing a struct, an array or a map. + :param options: options to control converting. accepts the same options as the JSON datasource >>> from pyspark.sql import Row >>> from pyspark.sql.types import * @@ -2315,6 +2313,10 @@ def to_json(col, options={}): >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(to_json(df.value).alias("json")).collect() [Row(json=u'[{"name":"Alice"},{"name":"Bob"}]')] + >>> data = [(1, ["Alice", "Bob"])] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(to_json(df.value).alias("json")).collect() + [Row(json=u'["Alice","Bob"]')] """ sc = SparkContext._active_spark_context diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 11cc88735a9a3..bd9090a07471b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -613,12 +613,11 @@ case class JsonToStructs( } /** - * Converts a [[StructType]], [[ArrayType]] of [[StructType]]s, [[MapType]] - * or [[ArrayType]] of [[MapType]]s to a json output string. + * Converts a [[StructType]], [[ArrayType]] or [[MapType]] to a JSON output string. */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr[, options]) - Returns a json string with a given struct value", + usage = "_FUNC_(expr[, options]) - Returns a JSON string with a given struct value", examples = """ Examples: > SELECT _FUNC_(named_struct('a', 1, 'b', 2)); @@ -660,15 +659,10 @@ case class StructsToJson( @transient lazy val gen = new JacksonGenerator( - rowSchema, writer, new JSONOptions(options, timeZoneId.get)) + inputSchema, writer, new JSONOptions(options, timeZoneId.get)) @transient - lazy val rowSchema = child.dataType match { - case st: StructType => st - case ArrayType(st: StructType, _) => st - case mt: MapType => mt - case ArrayType(mt: MapType, _) => mt - } + lazy val inputSchema = child.dataType // This converts rows to the JSON output according to the given schema. @transient @@ -680,12 +674,12 @@ case class StructsToJson( UTF8String.fromString(json) } - child.dataType match { + inputSchema match { case _: StructType => (row: Any) => gen.write(row.asInstanceOf[InternalRow]) getAndReset() - case ArrayType(_: StructType, _) => + case _: ArrayType => (arr: Any) => gen.write(arr.asInstanceOf[ArrayData]) getAndReset() @@ -693,34 +687,38 @@ case class StructsToJson( (map: Any) => gen.write(map.asInstanceOf[MapData]) getAndReset() - case ArrayType(_: MapType, _) => - (arr: Any) => - gen.write(arr.asInstanceOf[ArrayData]) - getAndReset() } } override def dataType: DataType = StringType - override def checkInputDataTypes(): TypeCheckResult = child.dataType match { - case _: StructType | ArrayType(_: StructType, _) => + override def checkInputDataTypes(): TypeCheckResult = inputSchema match { + case struct: StructType => try { - JacksonUtils.verifySchema(rowSchema.asInstanceOf[StructType]) + JacksonUtils.verifySchema(struct) TypeCheckResult.TypeCheckSuccess } catch { case e: UnsupportedOperationException => TypeCheckResult.TypeCheckFailure(e.getMessage) } - case _: MapType | ArrayType(_: MapType, _) => + case map: MapType => // TODO: let `JacksonUtils.verifySchema` verify a `MapType` try { - val st = StructType(StructField("a", rowSchema.asInstanceOf[MapType]) :: Nil) + val st = StructType(StructField("a", map) :: Nil) JacksonUtils.verifySchema(st) TypeCheckResult.TypeCheckSuccess } catch { case e: UnsupportedOperationException => TypeCheckResult.TypeCheckFailure(e.getMessage) } + case array: ArrayType => + try { + JacksonUtils.verifyType(prettyName, array) + TypeCheckResult.TypeCheckSuccess + } catch { + case e: UnsupportedOperationException => + TypeCheckResult.TypeCheckFailure(e.getMessage) + } case _ => TypeCheckResult.TypeCheckFailure( s"Input type ${child.dataType.catalogString} must be a struct, array of structs or " + "a map or array of map.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index 738947766adda..9b86d865622dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.json import java.io.Writer -import java.nio.charset.StandardCharsets import com.fasterxml.jackson.core._ @@ -28,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, MapData} import org.apache.spark.sql.types._ /** - * `JackGenerator` can only be initialized with a `StructType` or a `MapType`. + * `JackGenerator` can only be initialized with a `StructType`, a `MapType` or an `ArrayType`. * Once it is initialized with `StructType`, it can be used to write out a struct or an array of * struct. Once it is initialized with `MapType`, it can be used to write out a map or an array * of map. An exception will be thrown if trying to write out a struct if it is initialized with @@ -43,34 +42,32 @@ private[sql] class JacksonGenerator( // we can directly access data in `ArrayData` without the help of `SpecificMutableRow`. private type ValueWriter = (SpecializedGetters, Int) => Unit - // `JackGenerator` can only be initialized with a `StructType` or a `MapType`. - require(dataType.isInstanceOf[StructType] || dataType.isInstanceOf[MapType], - s"JacksonGenerator only supports to be initialized with a ${StructType.simpleString} " + - s"or ${MapType.simpleString} but got ${dataType.catalogString}") + // `JackGenerator` can only be initialized with a `StructType`, a `MapType` or a `ArrayType`. + require(dataType.isInstanceOf[StructType] || dataType.isInstanceOf[MapType] + || dataType.isInstanceOf[ArrayType], + s"JacksonGenerator only supports to be initialized with a ${StructType.simpleString}, " + + s"${MapType.simpleString} or ${ArrayType.simpleString} but got ${dataType.catalogString}") // `ValueWriter`s for all fields of the schema private lazy val rootFieldWriters: Array[ValueWriter] = dataType match { case st: StructType => st.map(_.dataType).map(makeWriter).toArray case _ => throw new UnsupportedOperationException( - s"Initial type ${dataType.catalogString} must be a struct") + s"Initial type ${dataType.catalogString} must be a ${StructType.simpleString}") } // `ValueWriter` for array data storing rows of the schema. private lazy val arrElementWriter: ValueWriter = dataType match { - case st: StructType => - (arr: SpecializedGetters, i: Int) => { - writeObject(writeFields(arr.getStruct(i, st.length), st, rootFieldWriters)) - } - case mt: MapType => - (arr: SpecializedGetters, i: Int) => { - writeObject(writeMapData(arr.getMap(i), mt, mapElementWriter)) - } + case at: ArrayType => makeWriter(at.elementType) + case _: StructType | _: MapType => makeWriter(dataType) + case _ => throw new UnsupportedOperationException( + s"Initial type ${dataType.catalogString} must be " + + s"an ${ArrayType.simpleString}, a ${StructType.simpleString} or a ${MapType.simpleString}") } private lazy val mapElementWriter: ValueWriter = dataType match { case mt: MapType => makeWriter(mt.valueType) case _ => throw new UnsupportedOperationException( - s"Initial type ${dataType.catalogString} must be a map") + s"Initial type ${dataType.catalogString} must be a ${MapType.simpleString}") } private val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala index f26b194e7a7ce..2d89c7066d080 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala @@ -32,11 +32,8 @@ object JacksonUtils { } } - /** - * Verify if the schema is supported in JSON parsing. - */ - def verifySchema(schema: StructType): Unit = { - def verifyType(name: String, dataType: DataType): Unit = dataType match { + def verifyType(name: String, dataType: DataType): Unit = { + dataType match { case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | StringType | TimestampType | DateType | BinaryType | _: DecimalType => @@ -54,7 +51,12 @@ object JacksonUtils { throw new UnsupportedOperationException( s"Unable to convert column $name of type ${dataType.catalogString} to JSON.") } + } + /** + * Verify if the schema is supported in JSON parsing. + */ + def verifySchema(schema: StructType): Unit = { schema.foreach(field => verifyType(field.name, field.dataType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c120be469a268..10b67d7a1ca54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3612,11 +3612,11 @@ object functions { def schema_of_json(e: Column): Column = withExpr(new SchemaOfJson(e.expr)) /** - * (Scala-specific) Converts a column containing a `StructType`, `ArrayType` of `StructType`s, - * a `MapType` or `ArrayType` of `MapType`s into a JSON string with the specified schema. + * (Scala-specific) Converts a column containing a `StructType`, `ArrayType` or + * a `MapType` into a JSON string with the specified schema. * Throws an exception, in the case of an unsupported type. * - * @param e a column containing a struct or array of the structs. + * @param e a column containing a struct, an array or a map. * @param options options to control how the struct column is converted into a json string. * accepts the same options and the json data source. * @@ -3628,11 +3628,11 @@ object functions { } /** - * (Java-specific) Converts a column containing a `StructType`, `ArrayType` of `StructType`s, - * a `MapType` or `ArrayType` of `MapType`s into a JSON string with the specified schema. + * (Java-specific) Converts a column containing a `StructType`, `ArrayType` or + * a `MapType` into a JSON string with the specified schema. * Throws an exception, in the case of an unsupported type. * - * @param e a column containing a struct or array of the structs. + * @param e a column containing a struct, an array or a map. * @param options options to control how the struct column is converted into a json string. * accepts the same options and the json data source. * @@ -3643,11 +3643,11 @@ object functions { to_json(e, options.asScala.toMap) /** - * Converts a column containing a `StructType`, `ArrayType` of `StructType`s, - * a `MapType` or `ArrayType` of `MapType`s into a JSON string with the specified schema. + * Converts a column containing a `StructType`, `ArrayType` or + * a `MapType` into a JSON string with the specified schema. * Throws an exception, in the case of an unsupported type. * - * @param e a column containing a struct or array of the structs. + * @param e a column containing a struct, an array or a map. * * @group collection_funcs * @since 2.1.0 diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql index 0cf370c13e8c0..0f22c0eeed581 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -51,3 +51,8 @@ select from_json('[null, {"a":2}]', 'array>'); select from_json('[{"a": 1}, {"b":2}]', 'array>'); select from_json('[{"a": 1}, 2]', 'array>'); + +-- to_json - array type +select to_json(array('1', '2', '3')); +select to_json(array(array(1, 2, 3), array(4))); + diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 7444cdbef96e4..e550b43e08c28 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 38 +-- Number of queries: 40 -- !query 0 @@ -9,7 +9,7 @@ struct -- !query 0 output Class: org.apache.spark.sql.catalyst.expressions.StructsToJson Function: to_json -Usage: to_json(expr[, options]) - Returns a json string with a given struct value +Usage: to_json(expr[, options]) - Returns a JSON string with a given struct value -- !query 1 @@ -38,7 +38,7 @@ Extended Usage: Since: 2.2.0 Function: to_json -Usage: to_json(expr[, options]) - Returns a json string with a given struct value +Usage: to_json(expr[, options]) - Returns a JSON string with a given struct value -- !query 2 @@ -354,3 +354,19 @@ select from_json('[{"a": 1}, 2]', 'array>') struct>> -- !query 37 output NULL + + +-- !query 38 +select to_json(array('1', '2', '3')) +-- !query 38 schema +struct +-- !query 38 output +["1","2","3"] + + +-- !query 39 +select to_json(array(array(1, 2, 3), array(4))) +-- !query 39 schema +struct +-- !query 39 output +[[1,2,3],[4]] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index f321ab86e9b7f..fe4bf15fa3921 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -469,4 +469,53 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(sql("""select json[0] from jsonTable"""), Seq(Row(null))) } + + test("to_json - array of primitive types") { + val df = Seq(Array(1, 2, 3)).toDF("a") + checkAnswer(df.select(to_json($"a")), Seq(Row("[1,2,3]"))) + } + + test("roundtrip to_json -> from_json - array of primitive types") { + val arr = Array(1, 2, 3) + val df = Seq(arr).toDF("a") + checkAnswer(df.select(from_json(to_json($"a"), ArrayType(IntegerType))), Row(arr)) + } + + test("roundtrip from_json -> to_json - array of primitive types") { + val json = "[1,2,3]" + val df = Seq(json).toDF("a") + val schema = new ArrayType(IntegerType, false) + + checkAnswer(df.select(to_json(from_json($"a", schema))), Seq(Row(json))) + } + + test("roundtrip from_json -> to_json - array of arrays") { + val json = "[[1],[2,3],[4,5,6]]" + val jsonDF = Seq(json).toDF("a") + val schema = new ArrayType(ArrayType(IntegerType, false), false) + + checkAnswer( + jsonDF.select(to_json(from_json($"a", schema))), + Seq(Row(json))) + } + + test("roundtrip from_json -> to_json - array of maps") { + val json = """[{"a":1},{"b":2}]""" + val jsonDF = Seq(json).toDF("a") + val schema = new ArrayType(MapType(StringType, IntegerType, false), false) + + checkAnswer( + jsonDF.select(to_json(from_json($"a", schema))), + Seq(Row(json))) + } + + test("roundtrip from_json -> to_json - array of structs") { + val json = """[{"a":1},{"a":2},{"a":3}]""" + val jsonDF = Seq(json).toDF("a") + val schema = new ArrayType(new StructType().add("a", IntegerType), false) + + checkAnswer( + jsonDF.select(to_json(from_json($"a", schema))), + Seq(Row(json))) + } } From 64c314e22fecca1ca3fe32378fc9374d8485deec Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 6 Sep 2018 15:27:59 +0800 Subject: [PATCH 1557/2461] [SPARK-25317][CORE] Avoid perf regression in Murmur3 Hash on UTF8String ## What changes were proposed in this pull request? SPARK-10399 introduced a performance regression on the hash computation for UTF8String. The regression can be evaluated with the code attached in the JIRA. That code runs in about 120 us per method on my laptop (MacBook Pro 2.5 GHz Intel Core i7, RAM 16 GB 1600 MHz DDR3) while the code from branch 2.3 takes on the same machine about 45 us for me. After the PR, the code takes about 45 us on the master branch too. ## How was this patch tested? running the perf test from the JIRA Closes #22338 from mgaido91/SPARK-25317. Authored-by: Marco Gaido Signed-off-by: Wenchen Fan --- .../spark/unsafe/hash/Murmur3_x86_32.java | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index aff6e93d647fe..566f116154302 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -19,6 +19,7 @@ import com.google.common.primitives.Ints; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.UTF8String; @@ -59,7 +60,7 @@ public static int hashUnsafeWordsBlock(MemoryBlock base, int seed) { // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. int lengthInBytes = Ints.checkedCast(base.size()); assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; - int h1 = hashBytesByIntBlock(base, seed); + int h1 = hashBytesByIntBlock(base, lengthInBytes, seed); return fmix(h1, lengthInBytes); } @@ -69,14 +70,19 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, i } public static int hashUnsafeBytesBlock(MemoryBlock base, int seed) { + return hashUnsafeBytesBlock(base, Ints.checkedCast(base.size()), seed); + } + + private static int hashUnsafeBytesBlock(MemoryBlock base, int lengthInBytes, int seed) { // This is not compatible with original and another implementations. // But remain it for backward compatibility for the components existing before 2.3. - int lengthInBytes = Ints.checkedCast(base.size()); assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; - int h1 = hashBytesByIntBlock(base.subBlock(0, lengthAligned), seed); + int h1 = hashBytesByIntBlock(base, lengthAligned, seed); + long offset = base.getBaseOffset(); + Object o = base.getBaseObject(); for (int i = lengthAligned; i < lengthInBytes; i++) { - int halfWord = base.getByte(i); + int halfWord = Platform.getByte(o, offset + i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } @@ -84,7 +90,7 @@ public static int hashUnsafeBytesBlock(MemoryBlock base, int seed) { } public static int hashUTF8String(UTF8String str, int seed) { - return hashUnsafeBytesBlock(str.getMemoryBlock(), seed); + return hashUnsafeBytesBlock(str.getMemoryBlock(), str.numBytes(), seed); } public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { @@ -101,7 +107,7 @@ public static int hashUnsafeBytes2Block(MemoryBlock base, int seed) { int lengthInBytes = Ints.checkedCast(base.size()); assert (lengthInBytes >= 0) : "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; - int h1 = hashBytesByIntBlock(base.subBlock(0, lengthAligned), seed); + int h1 = hashBytesByIntBlock(base, lengthAligned, seed); int k1 = 0; for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) { k1 ^= (base.getByte(i) & 0xFF) << shift; @@ -110,11 +116,10 @@ public static int hashUnsafeBytes2Block(MemoryBlock base, int seed) { return fmix(h1, lengthInBytes); } - private static int hashBytesByIntBlock(MemoryBlock base, int seed) { - long lengthInBytes = base.size(); + private static int hashBytesByIntBlock(MemoryBlock base, int lengthInBytes, int seed) { assert (lengthInBytes % 4 == 0); int h1 = seed; - for (long i = 0; i < lengthInBytes; i += 4) { + for (int i = 0; i < lengthInBytes; i += 4) { int halfWord = base.getInt(i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); From f5817d8bb33b733eeca0154d1ed207c8d1e8513f Mon Sep 17 00:00:00 2001 From: xuejianbest <384329882@qq.com> Date: Thu, 6 Sep 2018 07:17:37 -0700 Subject: [PATCH 1558/2461] [SPARK-25108][SQL] Fix the show method to display the wide character alignment problem MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is not a perfect solution. It is designed to minimize complexity on the basis of solving problems. It is effective for English, Chinese characters, Japanese, Korean and so on. ```scala before: +---+---------------------------+-------------+ |id |中国 |s2 | +---+---------------------------+-------------+ |1 |ab |[a] | |2 |null |[中国, abc] | |3 |ab1 |[hello world]| |4 |か行 きゃ(kya) きゅ(kyu) きょ(kyo) |[“中国] | |5 |中国(你好)a |[“中(国), 312] | |6 |中国山(东)服务区 |[“中(国)] | |7 |中国山东服务区 |[中(国)] | |8 | |[中国] | +---+---------------------------+-------------+ after: +---+-----------------------------------+----------------+ |id |中国 |s2 | +---+-----------------------------------+----------------+ |1 |ab |[a] | |2 |null |[中国, abc] | |3 |ab1 |[hello world] | |4 |か行 きゃ(kya) きゅ(kyu) きょ(kyo) |[“中国] | |5 |中国(你好)a |[“中(国), 312]| |6 |中国山(东)服务区 |[“中(国)] | |7 |中国山东服务区 |[中(国)] | |8 | |[中国] | +---+-----------------------------------+----------------+ ``` ## What changes were proposed in this pull request? When there are wide characters such as Chinese characters or Japanese characters in the data, the show method has a alignment problem. Try to fix this problem. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) ![image](https://user-images.githubusercontent.com/13044869/44250564-69f6b400-a227-11e8-88b2-6cf6960377ff.png) Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22048 from xuejianbest/master. Authored-by: xuejianbest <384329882@qq.com> Signed-off-by: Sean Owen --- .../scala/org/apache/spark/util/Utils.scala | 30 ++++++++++++ .../org/apache/spark/util/UtilsSuite.scala | 21 ++++++++ .../scala/org/apache/spark/sql/Dataset.scala | 18 +++---- .../org/apache/spark/sql/DatasetSuite.scala | 49 +++++++++++++++++++ 4 files changed, 109 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 15c958d3f511e..4593b057fc634 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2795,6 +2795,36 @@ private[spark] object Utils extends Logging { } } } + + /** + * Regular expression matching full width characters. + * + * Looked at all the 0x0000-0xFFFF characters (unicode) and showed them under Xshell. + * Found all the full width characters, then get the regular expression. + */ + private val fullWidthRegex = ("""[""" + + // scalastyle:off nonascii + """\u1100-\u115F""" + + """\u2E80-\uA4CF""" + + """\uAC00-\uD7A3""" + + """\uF900-\uFAFF""" + + """\uFE10-\uFE19""" + + """\uFE30-\uFE6F""" + + """\uFF00-\uFF60""" + + """\uFFE0-\uFFE6""" + + // scalastyle:on nonascii + """]""").r + + /** + * Return the number of half widths in a given string. Note that a full width character + * occupies two half widths. + * + * For a string consisting of 1 million characters, the execution of this method requires + * about 50ms. + */ + def stringHalfWidth(str: String): Int = { + if (str == null) 0 else str.length + fullWidthRegex.findAllIn(str).size + } } private[util] object CallerContext extends Logging { diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 418d2f9b88500..943b53522d64e 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -1184,6 +1184,27 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(Utils.getSimpleName(classOf[MalformedClassObject.MalformedClass]) === "UtilsSuite$MalformedClassObject$MalformedClass") } + + test("stringHalfWidth") { + // scalastyle:off nonascii + assert(Utils.stringHalfWidth(null) == 0) + assert(Utils.stringHalfWidth("") == 0) + assert(Utils.stringHalfWidth("ab c") == 4) + assert(Utils.stringHalfWidth("1098") == 4) + assert(Utils.stringHalfWidth("mø") == 2) + assert(Utils.stringHalfWidth("γύρ") == 3) + assert(Utils.stringHalfWidth("pê") == 2) + assert(Utils.stringHalfWidth("ー") == 2) + assert(Utils.stringHalfWidth("测") == 2) + assert(Utils.stringHalfWidth("か") == 2) + assert(Utils.stringHalfWidth("걸") == 2) + assert(Utils.stringHalfWidth("à") == 1) + assert(Utils.stringHalfWidth("焼") == 2) + assert(Utils.stringHalfWidth("羍む") == 4) + assert(Utils.stringHalfWidth("뺭ᾘ") == 3) + assert(Utils.stringHalfWidth("\u0967\u0968\u0969") == 3) + // scalastyle:on nonascii + } } private class SimpleExtension diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index db439b1ee76f1..fa14aa14ee968 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -306,16 +306,16 @@ class Dataset[T] private[sql]( // Compute the width of each column for (row <- rows) { for ((cell, i) <- row.zipWithIndex) { - colWidths(i) = math.max(colWidths(i), cell.length) + colWidths(i) = math.max(colWidths(i), Utils.stringHalfWidth(cell)) } } val paddedRows = rows.map { row => row.zipWithIndex.map { case (cell, i) => if (truncate > 0) { - StringUtils.leftPad(cell, colWidths(i)) + StringUtils.leftPad(cell, colWidths(i) - Utils.stringHalfWidth(cell) + cell.length) } else { - StringUtils.rightPad(cell, colWidths(i)) + StringUtils.rightPad(cell, colWidths(i) - Utils.stringHalfWidth(cell) + cell.length) } } } @@ -337,12 +337,10 @@ class Dataset[T] private[sql]( // Compute the width of field name and data columns val fieldNameColWidth = fieldNames.foldLeft(minimumColWidth) { case (curMax, fieldName) => - math.max(curMax, fieldName.length) + math.max(curMax, Utils.stringHalfWidth(fieldName)) } val dataColWidth = dataRows.foldLeft(minimumColWidth) { case (curMax, row) => - math.max(curMax, row.map(_.length).reduceLeftOption[Int] { case (cellMax, cell) => - math.max(cellMax, cell) - }.getOrElse(0)) + math.max(curMax, row.map(cell => Utils.stringHalfWidth(cell)).max) } dataRows.zipWithIndex.foreach { case (row, i) => @@ -351,8 +349,10 @@ class Dataset[T] private[sql]( s"-RECORD $i", fieldNameColWidth + dataColWidth + 5, "-") sb.append(rowHeader).append("\n") row.zipWithIndex.map { case (cell, j) => - val fieldName = StringUtils.rightPad(fieldNames(j), fieldNameColWidth) - val data = StringUtils.rightPad(cell, dataColWidth) + val fieldName = StringUtils.rightPad(fieldNames(j), + fieldNameColWidth - Utils.stringHalfWidth(fieldNames(j)) + fieldNames(j).length) + val data = StringUtils.rightPad(cell, + dataColWidth - Utils.stringHalfWidth(cell) + cell.length) s" $fieldName | $data " }.addString(sb, "", "\n", "\n") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index cf24eba128012..ca8fbc991a3a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -969,6 +969,55 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkShowString(ds, expected) } + test("SPARK-25108 Fix the show method to display the full width character alignment problem") { + // scalastyle:off nonascii + val df = Seq( + (0, null, 1), + (0, "", 1), + (0, "ab c", 1), + (0, "1098", 1), + (0, "mø", 1), + (0, "γύρ", 1), + (0, "pê", 1), + (0, "ー", 1), + (0, "测", 1), + (0, "か", 1), + (0, "걸", 1), + (0, "à", 1), + (0, "焼", 1), + (0, "羍む", 1), + (0, "뺭ᾘ", 1), + (0, "\u0967\u0968\u0969", 1) + ).toDF("b", "a", "c") + // scalastyle:on nonascii + val ds = df.as[ClassData] + val expected = + // scalastyle:off nonascii + """+---+----+---+ + || b| a| c| + |+---+----+---+ + || 0|null| 1| + || 0| | 1| + || 0|ab c| 1| + || 0|1098| 1| + || 0| mø| 1| + || 0| γύρ| 1| + || 0| pê| 1| + || 0| ー| 1| + || 0| 测| 1| + || 0| か| 1| + || 0| 걸| 1| + || 0| à| 1| + || 0| 焼| 1| + || 0|羍む| 1| + || 0| 뺭ᾘ| 1| + || 0| १२३| 1| + |+---+----+---+ + |""".stripMargin + // scalastyle:on nonascii + checkShowString(ds, expected) + } + test( "SPARK-15112: EmbedDeserializerInFilter should not optimize plan fragment that changes schema" ) { From 7ef6d1daf858cc9a2c390074f92aaf56c219518a Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 6 Sep 2018 08:18:49 -0700 Subject: [PATCH 1559/2461] [SPARK-25328][PYTHON] Add an example for having two columns as the grouping key in group aggregate pandas UDF ## What changes were proposed in this pull request? This PR proposes to add another example for multiple grouping key in group aggregate pandas UDF since this feature could make users still confused. ## How was this patch tested? Manually tested and documentation built. Closes #22329 from HyukjinKwon/SPARK-25328. Authored-by: hyukjinkwon Signed-off-by: Bryan Cutler --- python/pyspark/sql/functions.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 864780e0be9bd..9396b16b7ada8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2783,14 +2783,14 @@ def pandas_udf(f=None, returnType=None, functionType=None): +---+-------------------+ Alternatively, the user can define a function that takes two arguments. - In this case, the grouping key will be passed as the first argument and the data will - be passed as the second argument. The grouping key will be passed as a tuple of numpy + In this case, the grouping key(s) will be passed as the first argument and the data will + be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in as a `pandas.DataFrame` containing all columns from the original Spark DataFrame. - This is useful when the user does not want to hardcode grouping key in the function. + This is useful when the user does not want to hardcode grouping key(s) in the function. - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> import pandas as pd # doctest: +SKIP + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) # doctest: +SKIP @@ -2806,6 +2806,22 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 1|1.5| | 2|6.0| +---+---+ + >>> @pandas_udf( + ... "id long, `ceil(v / 2)` long, v double", + ... PandasUDFType.GROUPED_MAP) # doctest: +SKIP + >>> def sum_udf(key, pdf): + ... # key is a tuple of two numpy.int64s, which is the values + ... # of 'id' and 'ceil(df.v / 2)' for the current group + ... return pd.DataFrame([key + (pdf.v.sum(),)]) + >>> df.groupby(df.id, ceil(df.v / 2)).apply(sum_udf).show() # doctest: +SKIP + +---+-----------+----+ + | id|ceil(v / 2)| v| + +---+-----------+----+ + | 2| 5|10.0| + | 1| 1| 3.0| + | 2| 3| 5.0| + | 2| 2| 3.0| + +---+-----------+----+ .. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is recommended to explicitly index the columns by name to ensure the positions are correct, From 3b6591b0b064b13a411e5b8f8ee4883a69c39e2d Mon Sep 17 00:00:00 2001 From: Shahid Date: Thu, 6 Sep 2018 09:52:58 -0700 Subject: [PATCH 1560/2461] [SPARK-25268][GRAPHX] run Parallel Personalized PageRank throws serialization Exception ## What changes were proposed in this pull request? mapValues in scala is currently not serializable. To avoid the serialization issue while running pageRank, we need to use map instead of mapValues. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22271 from shahidki31/master_latest. Authored-by: Shahid Signed-off-by: Joseph K. Bradley --- .../main/scala/org/apache/spark/graphx/lib/PageRank.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 96b635f9a144e..1305c059b89ce 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -198,9 +198,11 @@ object PageRank extends Logging { val zero = Vectors.sparse(sources.size, List()).asBreeze // map of vid -> vector where for each vid, the _position of vid in source_ is set to 1.0 - val sourcesInitMap = sources.zipWithIndex.toMap.mapValues { i => - Vectors.sparse(sources.size, Array(i), Array(1.0)).asBreeze - } + val sourcesInitMap = sources.zipWithIndex.map { case (vid, i) => + val v = Vectors.sparse(sources.size, Array(i), Array(1.0)).asBreeze + (vid, v) + }.toMap + val sc = graph.vertices.sparkContext val sourcesInitMapBC = sc.broadcast(sourcesInitMap) // Initialize the PageRank graph with each edge attribute having From c84bc40d7f33c71eca1c08f122cd60517f34c1f8 Mon Sep 17 00:00:00 2001 From: liyuanjian Date: Thu, 6 Sep 2018 10:17:29 -0700 Subject: [PATCH 1561/2461] [SPARK-25072][PYSPARK] Forbid extra value for custom Row ## What changes were proposed in this pull request? Add value length check in `_create_row`, forbid extra value for custom Row in PySpark. ## How was this patch tested? New UT in pyspark-sql Closes #22140 from xuanyuanking/SPARK-25072. Lead-authored-by: liyuanjian Co-authored-by: Yuanjian Li Signed-off-by: Bryan Cutler --- python/pyspark/sql/tests.py | 4 ++++ python/pyspark/sql/types.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 81c0af0b3d81b..6d9d636b23a3a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -277,6 +277,10 @@ def test_struct_field_type_name(self): struct_field = StructField("a", IntegerType()) self.assertRaises(TypeError, struct_field.typeName) + def test_invalid_create_row(self): + row_class = Row("c1", "c2") + self.assertRaises(ValueError, lambda: row_class(1, 2, 3)) + class SQLTests(ReusedSQLTestCase): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 0b61707c8cc0a..ce1d004c6c8ff 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1500,6 +1500,9 @@ def __contains__(self, item): # let object acts like class def __call__(self, *args): """create new Row object""" + if len(args) > len(self): + raise ValueError("Can not create Row with fields %s, expected %d values " + "but got %s" % (self, len(self), args)) return _create_row(self, args) def __getitem__(self, item): From 27d3b0a51cfd1caf05c242b45db9a78ef5868685 Mon Sep 17 00:00:00 2001 From: Rob Vesse Date: Thu, 6 Sep 2018 16:15:11 -0700 Subject: [PATCH 1562/2461] [SPARK-25222][K8S] Improve container status logging ## What changes were proposed in this pull request? Currently when running Spark on Kubernetes a logger is run by the client that watches the K8S API for events related to the Driver pod and logs them. However for the container status aspect of the logging this simply dumps the raw object which is not human readable e.g. ![screen shot 2018-08-24 at 10 37 46](https://user-images.githubusercontent.com/2104864/44577799-e0486880-a789-11e8-9ae9-fdeddacbbea8.png) ![screen shot 2018-08-24 at 10 38 14](https://user-images.githubusercontent.com/2104864/44577800-e0e0ff00-a789-11e8-81f5-3bb315dbbdb1.png) This is despite the fact that the logging class in question actually has methods to pretty print this information but only invokes these at the end of a job. This PR improves the logging to always use the pretty printing methods, additionally modifying them to include further useful information provided by the K8S API. A similar issue also exists when tasks are lost that will be addressed by further commits to this PR - [x] Improved `LoggingPodStatusWatcher` - [x] Improved container status on task failure ## How was this patch tested? Built and launched jobs with the updated Spark client and observed the new human readable output: ![screen shot 2018-08-24 at 11 09 32](https://user-images.githubusercontent.com/2104864/44579429-5353de00-a78e-11e8-9228-c750af8e6311.png) ![screen shot 2018-08-24 at 11 09 42](https://user-images.githubusercontent.com/2104864/44579430-5353de00-a78e-11e8-8fce-d5bb2a3ae65f.png) ![screen shot 2018-08-24 at 11 10 13](https://user-images.githubusercontent.com/2104864/44579431-53ec7480-a78e-11e8-9fa2-aeabc5b28ec4.png) ![screen shot 2018-08-24 at 17 47 44](https://user-images.githubusercontent.com/2104864/44596922-db090f00-a7c5-11e8-910c-bc2339f5a196.png) Suggested reviewers: liyinan926 mccheah Author: Rob Vesse Closes #22215 from rvesse/SPARK-25222. --- .../spark/deploy/k8s/KubernetesUtils.scala | 83 ++++++++++++++++++- .../k8s/submit/LoggingPodStatusWatcher.scala | 73 +--------------- .../k8s/ExecutorPodsLifecycleManager.scala | 9 +- .../ExecutorPodsLifecycleManagerSuite.scala | 9 +- 4 files changed, 95 insertions(+), 79 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 588cd9d40f9a0..f5fae7cc8c470 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -16,7 +16,11 @@ */ package org.apache.spark.deploy.k8s -import org.apache.spark.SparkConf +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerStateRunning, ContainerStateTerminated, ContainerStateWaiting, ContainerStatus, Pod, Time} + +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.util.Utils private[spark] object KubernetesUtils { @@ -60,4 +64,81 @@ private[spark] object KubernetesUtils { } def parseMasterUrl(url: String): String = url.substring("k8s://".length) + + def formatPairsBundle(pairs: Seq[(String, String)], indent: Int = 1) : String = { + // Use more loggable format if value is null or empty + val indentStr = "\t" * indent + pairs.map { + case (k, v) => s"\n$indentStr $k: ${Option(v).filter(_.nonEmpty).getOrElse("N/A")}" + }.mkString("") + } + + /** + * Given a pod, output a human readable representation of its state + * + * @param pod Pod + * @return Human readable pod state + */ + def formatPodState(pod: Pod): String = { + val details = Seq[(String, String)]( + // pod metadata + ("pod name", pod.getMetadata.getName), + ("namespace", pod.getMetadata.getNamespace), + ("labels", pod.getMetadata.getLabels.asScala.mkString(", ")), + ("pod uid", pod.getMetadata.getUid), + ("creation time", formatTime(pod.getMetadata.getCreationTimestamp)), + + // spec details + ("service account name", pod.getSpec.getServiceAccountName), + ("volumes", pod.getSpec.getVolumes.asScala.map(_.getName).mkString(", ")), + ("node name", pod.getSpec.getNodeName), + + // status + ("start time", formatTime(pod.getStatus.getStartTime)), + ("phase", pod.getStatus.getPhase), + ("container status", containersDescription(pod, 2)) + ) + + formatPairsBundle(details) + } + + def containersDescription(p: Pod, indent: Int = 1): String = { + p.getStatus.getContainerStatuses.asScala.map { status => + Seq( + ("container name", status.getName), + ("container image", status.getImage)) ++ + containerStatusDescription(status) + }.map(p => formatPairsBundle(p, indent)).mkString("\n\n") + } + + def containerStatusDescription(containerStatus: ContainerStatus) + : Seq[(String, String)] = { + val state = containerStatus.getState + Option(state.getRunning) + .orElse(Option(state.getTerminated)) + .orElse(Option(state.getWaiting)) + .map { + case running: ContainerStateRunning => + Seq( + ("container state", "running"), + ("container started at", formatTime(running.getStartedAt))) + case waiting: ContainerStateWaiting => + Seq( + ("container state", "waiting"), + ("pending reason", waiting.getReason)) + case terminated: ContainerStateTerminated => + Seq( + ("container state", "terminated"), + ("container started at", formatTime(terminated.getStartedAt)), + ("container finished at", formatTime(terminated.getFinishedAt)), + ("exit code", terminated.getExitCode.toString), + ("termination reason", terminated.getReason)) + case unknown => + throw new SparkException(s"Unexpected container status type ${unknown.getClass}.") + }.getOrElse(Seq(("container state", "N/A"))) + } + + def formatTime(time: Time): String = { + if (time != null) time.getTime else "N/A" + } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala index 173ac541626a7..1889fe5eb3e9b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala @@ -25,6 +25,7 @@ import io.fabric8.kubernetes.client.{KubernetesClientException, Watcher} import io.fabric8.kubernetes.client.Watcher.Action import org.apache.spark.SparkException +import org.apache.spark.deploy.k8s.KubernetesUtils._ import org.apache.spark.internal.Logging import org.apache.spark.util.ThreadUtils @@ -99,82 +100,10 @@ private[k8s] class LoggingPodStatusWatcherImpl( scheduler.shutdown() } - private def formatPodState(pod: Pod): String = { - val details = Seq[(String, String)]( - // pod metadata - ("pod name", pod.getMetadata.getName), - ("namespace", pod.getMetadata.getNamespace), - ("labels", pod.getMetadata.getLabels.asScala.mkString(", ")), - ("pod uid", pod.getMetadata.getUid), - ("creation time", formatTime(pod.getMetadata.getCreationTimestamp)), - - // spec details - ("service account name", pod.getSpec.getServiceAccountName), - ("volumes", pod.getSpec.getVolumes.asScala.map(_.getName).mkString(", ")), - ("node name", pod.getSpec.getNodeName), - - // status - ("start time", formatTime(pod.getStatus.getStartTime)), - ("container images", - pod.getStatus.getContainerStatuses - .asScala - .map(_.getImage) - .mkString(", ")), - ("phase", pod.getStatus.getPhase), - ("status", pod.getStatus.getContainerStatuses.toString) - ) - - formatPairsBundle(details) - } - - private def formatPairsBundle(pairs: Seq[(String, String)]) = { - // Use more loggable format if value is null or empty - pairs.map { - case (k, v) => s"\n\t $k: ${Option(v).filter(_.nonEmpty).getOrElse("N/A")}" - }.mkString("") - } - override def awaitCompletion(): Unit = { podCompletedFuture.await() logInfo(pod.map { p => s"Container final statuses:\n\n${containersDescription(p)}" }.getOrElse("No containers were found in the driver pod.")) } - - private def containersDescription(p: Pod): String = { - p.getStatus.getContainerStatuses.asScala.map { status => - Seq( - ("Container name", status.getName), - ("Container image", status.getImage)) ++ - containerStatusDescription(status) - }.map(formatPairsBundle).mkString("\n\n") - } - - private def containerStatusDescription( - containerStatus: ContainerStatus): Seq[(String, String)] = { - val state = containerStatus.getState - Option(state.getRunning) - .orElse(Option(state.getTerminated)) - .orElse(Option(state.getWaiting)) - .map { - case running: ContainerStateRunning => - Seq( - ("Container state", "Running"), - ("Container started at", formatTime(running.getStartedAt))) - case waiting: ContainerStateWaiting => - Seq( - ("Container state", "Waiting"), - ("Pending reason", waiting.getReason)) - case terminated: ContainerStateTerminated => - Seq( - ("Container state", "Terminated"), - ("Exit code", terminated.getExitCode.toString)) - case unknown => - throw new SparkException(s"Unexpected container status type ${unknown.getClass}.") - }.getOrElse(Seq(("Container state", "N/A"))) - } - - private def formatTime(time: Time): String = { - if (time != null) time.getTime else "N/A" - } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala index b28d93990313e..e2800cff7b720 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.KubernetesUtils._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler.ExecutorExited import org.apache.spark.util.Utils @@ -151,13 +152,15 @@ private[spark] class ExecutorPodsLifecycleManager( private def exitReasonMessage(podState: FinalPodState, execId: Long, exitCode: Int) = { val pod = podState.pod + val reason = Option(pod.getStatus.getReason) + val message = Option(pod.getStatus.getMessage) s""" |The executor with id $execId exited with exit code $exitCode. - |The API gave the following brief reason: ${pod.getStatus.getReason} - |The API gave the following message: ${pod.getStatus.getMessage} + |The API gave the following brief reason: ${reason.getOrElse("N/A")} + |The API gave the following message: ${message.getOrElse("N/A")} |The API gave the following container statuses: | - |${pod.getStatus.getContainerStatuses.asScala.map(_.toString).mkString("\n===\n")} + |${containersDescription(pod)} """.stripMargin } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala index 562ace9f49d4d..d8409383b4a1c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala @@ -31,6 +31,7 @@ import scala.collection.mutable import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.deploy.k8s.KubernetesUtils._ import org.apache.spark.scheduler.ExecutorExited import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ @@ -104,13 +105,15 @@ class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfte } private def exitReasonMessage(failedExecutorId: Int, failedPod: Pod): String = { + val reason = Option(failedPod.getStatus.getReason) + val message = Option(failedPod.getStatus.getMessage) s""" |The executor with id $failedExecutorId exited with exit code 1. - |The API gave the following brief reason: ${failedPod.getStatus.getReason} - |The API gave the following message: ${failedPod.getStatus.getMessage} + |The API gave the following brief reason: ${reason.getOrElse("N/A")} + |The API gave the following message: ${message.getOrElse("N/A")} |The API gave the following container statuses: | - |${failedPod.getStatus.getContainerStatuses.asScala.map(_.toString).mkString("\n===\n")} + |${containersDescription(failedPod)} """.stripMargin } From da6fa3828bb824b65f50122a8a0a0d4741551257 Mon Sep 17 00:00:00 2001 From: Rob Vesse Date: Thu, 6 Sep 2018 16:18:59 -0700 Subject: [PATCH 1563/2461] [SPARK-25262][K8S] Allow SPARK_LOCAL_DIRS to be tmpfs backed on K8S ## What changes were proposed in this pull request? The default behaviour of Spark on K8S currently is to create `emptyDir` volumes to back `SPARK_LOCAL_DIRS`. In some environments e.g. diskless compute nodes this may actually hurt performance because these are backed by the Kubelet's node storage which on a diskless node will typically be some remote network storage. Even if this is enterprise grade storage connected via a high speed interconnect the way Spark uses these directories as scratch space (lots of relatively small short lived files) has been observed to cause serious performance degradation. Therefore we would like to provide the option to use K8S's ability to instead back these `emptyDir` volumes with `tmpfs`. Therefore this PR adds a configuration option that enables `SPARK_LOCAL_DIRS` to be backed by Memory backed `emptyDir` volumes rather than the default. Documentation is added to describe both the default behaviour plus this new option and its implications. One of which is that scratch space then counts towards your pods memory limits and therefore users will need to adjust their memory requests accordingly. *NB* - This is an alternative version of PR #22256 reduced to just the `tmpfs` piece ## How was this patch tested? Ran with this option in our diskless compute environments to verify functionality Author: Rob Vesse Closes #22323 from rvesse/SPARK-25262-tmpfs. --- docs/running-on-kubernetes.md | 21 +++++++++++++ .../org/apache/spark/deploy/k8s/Config.scala | 9 ++++++ .../k8s/features/LocalDirsFeatureStep.scala | 3 ++ .../features/LocalDirsFeatureStepSuite.scala | 30 +++++++++++++++++++ 4 files changed, 63 insertions(+) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index c83dad6df1e7b..4ae7acaae2314 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -215,6 +215,19 @@ spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.options.clai The configuration properties for mounting volumes into the executor pods use prefix `spark.kubernetes.executor.` instead of `spark.kubernetes.driver.`. For a complete list of available options for each supported type of volumes, please refer to the [Spark Properties](#spark-properties) section below. +## Local Storage + +Spark uses temporary scratch space to spill data to disk during shuffles and other operations. When using Kubernetes as the resource manager the pods will be created with an [emptyDir](https://kubernetes.io/docs/concepts/storage/volumes/#emptydir) volume mounted for each directory listed in `SPARK_LOCAL_DIRS`. If no directories are explicitly specified then a default directory is created and configured appropriately. + +`emptyDir` volumes use the ephemeral storage feature of Kubernetes and do not persist beyond the life of the pod. + +### Using RAM for local storage + +`emptyDir` volumes use the nodes backing storage for ephemeral storage by default, this behaviour may not be appropriate for some compute environments. For example if you have diskless nodes with remote storage mounted over a network, having lots of executors doing IO to this remote storage may actually degrade performance. + +In this case it may be desirable to set `spark.kubernetes.local.dirs.tmpfs=true` in your configuration which will cause the `emptyDir` volumes to be configured as `tmpfs` i.e. RAM backed volumes. When configured like this Sparks local storage usage will count towards your pods memory usage therefore you may wish to increase your memory requests by increasing the value of `spark.kubernetes.memoryOverheadFactor` as appropriate. + + ## Introspection and Debugging These are the different ways in which you can investigate a running/completed Spark application, monitor progress, and @@ -784,6 +797,14 @@ specific to Spark on Kubernetes. spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.options.claimName=spark-pvc-claim. + + spark.kubernetes.local.dirs.tmpfs + false + + Configure the emptyDir volumes used to back SPARK_LOCAL_DIRS within the Spark driver and executor pods to use tmpfs backing i.e. RAM. See Local Storage earlier on this page + for more discussion of this. + + spark.kubernetes.memoryOverheadFactor 0.1 diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 1b582fe53624a..c5f4d6c53b7f9 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -225,6 +225,15 @@ private[spark] object Config extends Logging { "Ensure that major Python version is either Python2 or Python3") .createWithDefault("2") + val KUBERNETES_LOCAL_DIRS_TMPFS = + ConfigBuilder("spark.kubernetes.local.dirs.tmpfs") + .doc("If set to true then emptyDir volumes created to back SPARK_LOCAL_DIRS will have " + + "their medium set to Memory so that they will be created as tmpfs (i.e. RAM) backed " + + "volumes. This may improve performance but scratch space usage will count towards " + + "your pods memory limit so you may wish to request more memory.") + .booleanConf + .createWithDefault(false) + val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = "spark.kubernetes.authenticate.submission" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala index 70b307303d149..be386e119d465 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala @@ -22,6 +22,7 @@ import java.util.UUID import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, VolumeBuilder, VolumeMountBuilder} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ private[spark] class LocalDirsFeatureStep( conf: KubernetesConf[_ <: KubernetesRoleSpecificConf], @@ -37,6 +38,7 @@ private[spark] class LocalDirsFeatureStep( .orElse(conf.getOption("spark.local.dir")) .getOrElse(defaultLocalDir) .split(",") + private val useLocalDirTmpFs = conf.get(KUBERNETES_LOCAL_DIRS_TMPFS) override def configurePod(pod: SparkPod): SparkPod = { val localDirVolumes = resolvedLocalDirs @@ -45,6 +47,7 @@ private[spark] class LocalDirsFeatureStep( new VolumeBuilder() .withName(s"spark-local-dir-${index + 1}") .withNewEmptyDir() + .withMedium(if (useLocalDirTmpFs) "Memory" else null) .endEmptyDir() .build() } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala index a339827b819a9..acdd07bc594b2 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala @@ -18,10 +18,12 @@ package org.apache.spark.deploy.k8s.features import io.fabric8.kubernetes.api.model.{EnvVarBuilder, VolumeBuilder, VolumeMountBuilder} import org.mockito.Mockito +import org.scalatest._ import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { private val defaultLocalDir = "/var/data/default-local-dir" @@ -111,4 +113,32 @@ class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { .withValue("/var/data/my-local-dir-1,/var/data/my-local-dir-2") .build()) } + + test("Use tmpfs to back default local dir") { + Mockito.doReturn(null).when(sparkConf).get("spark.local.dir") + Mockito.doReturn(null).when(sparkConf).getenv("SPARK_LOCAL_DIRS") + Mockito.doReturn(true).when(sparkConf).get(KUBERNETES_LOCAL_DIRS_TMPFS) + val stepUnderTest = new LocalDirsFeatureStep(kubernetesConf, defaultLocalDir) + val configuredPod = stepUnderTest.configurePod(SparkPod.initialPod()) + assert(configuredPod.pod.getSpec.getVolumes.size === 1) + assert(configuredPod.pod.getSpec.getVolumes.get(0) === + new VolumeBuilder() + .withName(s"spark-local-dir-1") + .withNewEmptyDir() + .withMedium("Memory") + .endEmptyDir() + .build()) + assert(configuredPod.container.getVolumeMounts.size === 1) + assert(configuredPod.container.getVolumeMounts.get(0) === + new VolumeMountBuilder() + .withName(s"spark-local-dir-1") + .withMountPath(defaultLocalDir) + .build()) + assert(configuredPod.container.getEnv.size === 1) + assert(configuredPod.container.getEnv.get(0) === + new EnvVarBuilder() + .withName("SPARK_LOCAL_DIRS") + .withValue(defaultLocalDir) + .build()) + } } From 1b1711e0532b1a1521054ef3b5980cdb3d70cdeb Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 7 Sep 2018 10:12:20 +0800 Subject: [PATCH 1564/2461] [SPARK-25208][SQL][FOLLOW-UP] Reduce code size. ## What changes were proposed in this pull request? This is a follow-up pr of #22200. When casting to decimal type, if `Cast.canNullSafeCastToDecimal()`, overflow won't happen, so we don't need to check the result of `Decimal.changePrecision()`. ## How was this patch tested? Existing tests. Closes #22352 from ueshin/issues/SPARK-25208/reduce_code_size. Authored-by: Takuya UESHIN Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/expressions/Cast.scala | 37 ++++++++++++------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 0053503501047..8f777997bf615 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -924,27 +924,36 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } private[this] def changePrecision(d: ExprValue, decimalType: DecimalType, - evPrim: ExprValue, evNull: ExprValue): Block = - code""" - if ($d.changePrecision(${decimalType.precision}, ${decimalType.scale})) { - $evPrim = $d; - } else { - $evNull = true; - } - """ + evPrim: ExprValue, evNull: ExprValue, canNullSafeCast: Boolean): Block = { + if (canNullSafeCast) { + code""" + |$d.changePrecision(${decimalType.precision}, ${decimalType.scale}); + |$evPrim = $d; + """.stripMargin + } else { + code""" + |if ($d.changePrecision(${decimalType.precision}, ${decimalType.scale})) { + | $evPrim = $d; + |} else { + | $evNull = true; + |} + """.stripMargin + } + } private[this] def castToDecimalCode( from: DataType, target: DecimalType, ctx: CodegenContext): CastFunction = { val tmp = ctx.freshVariable("tmpDecimal", classOf[Decimal]) + val canNullSafeCast = Cast.canNullSafeCastToDecimal(from, target) from match { case StringType => (c, evPrim, evNull) => code""" try { Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString())); - ${changePrecision(tmp, target, evPrim, evNull)} + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} } catch (java.lang.NumberFormatException e) { $evNull = true; } @@ -953,7 +962,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => code""" Decimal $tmp = $c ? Decimal.apply(1) : Decimal.apply(0); - ${changePrecision(tmp, target, evPrim, evNull)} + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} """ case DateType => // date can't cast to decimal in Hive @@ -964,19 +973,19 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String code""" Decimal $tmp = Decimal.apply( scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); - ${changePrecision(tmp, target, evPrim, evNull)} + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} """ case DecimalType() => (c, evPrim, evNull) => code""" Decimal $tmp = $c.clone(); - ${changePrecision(tmp, target, evPrim, evNull)} + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} """ case x: IntegralType => (c, evPrim, evNull) => code""" Decimal $tmp = Decimal.apply((long) $c); - ${changePrecision(tmp, target, evPrim, evNull)} + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} """ case x: FractionalType => // All other numeric types can be represented precisely as Doubles @@ -984,7 +993,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String code""" try { Decimal $tmp = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c)); - ${changePrecision(tmp, target, evPrim, evNull)} + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} } catch (java.lang.NumberFormatException e) { $evNull = true; } From b0ada7dce02d101b6a04323d8185394e997caca4 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 6 Sep 2018 21:41:13 -0700 Subject: [PATCH 1565/2461] [SPARK-25330][BUILD][BRANCH-2.3] Revert Hadoop 2.7 to 2.7.3 ## What changes were proposed in this pull request? How to reproduce permission issue: ```sh # build spark ./dev/make-distribution.sh --name SPARK-25330 --tgz -Phadoop-2.7 -Phive -Phive-thriftserver -Pyarn tar -zxf spark-2.4.0-SNAPSHOT-bin-SPARK-25330.tar && cd spark-2.4.0-SNAPSHOT-bin-SPARK-25330 export HADOOP_PROXY_USER=user_a bin/spark-sql export HADOOP_PROXY_USER=user_b bin/spark-sql ``` ```java Exception in thread "main" java.lang.RuntimeException: org.apache.hadoop.security.AccessControlException: Permission denied: user=user_b, access=EXECUTE, inode="/tmp/hive-$%7Buser.name%7D/user_b/668748f2-f6c5-4325-a797-fd0a7ee7f4d4":user_b:hadoop:drwx------ at org.apache.hadoop.hdfs.server.namenode.FSPermissionChecker.check(FSPermissionChecker.java:319) at org.apache.hadoop.hdfs.server.namenode.FSPermissionChecker.checkTraverse(FSPermissionChecker.java:259) at org.apache.hadoop.hdfs.server.namenode.FSPermissionChecker.checkPermission(FSPermissionChecker.java:205) at org.apache.hadoop.hdfs.server.namenode.FSPermissionChecker.checkPermission(FSPermissionChecker.java:190) ``` The issue occurred in this commit: https://github.com/apache/hadoop/commit/feb886f2093ea5da0cd09c69bd1360a335335c86. This pr revert Hadoop 2.7 to 2.7.3 to avoid this issue. ## How was this patch tested? unit tests and manual tests. Closes #22327 from wangyum/SPARK-25330. Authored-by: Yuming Wang Signed-off-by: Sean Owen --- assembly/README | 2 +- dev/deps/spark-deps-hadoop-2.7 | 31 +++++++++++++++---------------- docs/building-spark.md | 2 +- pom.xml | 2 +- 4 files changed, 18 insertions(+), 19 deletions(-) diff --git a/assembly/README b/assembly/README index affd281a1385c..d5dafab477410 100644 --- a/assembly/README +++ b/assembly/README @@ -9,4 +9,4 @@ This module is off by default. To activate it specify the profile in the command If you need to build an assembly for a different version of Hadoop the hadoop-version system property needs to be set as in this example: - -Dhadoop.version=2.7.7 + -Dhadoop.version=2.7.3 diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 5e12ca053af51..dcb5d63aeff4d 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -64,21 +64,21 @@ gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar guice-servlet-3.0.jar -hadoop-annotations-2.7.7.jar -hadoop-auth-2.7.7.jar -hadoop-client-2.7.7.jar -hadoop-common-2.7.7.jar -hadoop-hdfs-2.7.7.jar -hadoop-mapreduce-client-app-2.7.7.jar -hadoop-mapreduce-client-common-2.7.7.jar -hadoop-mapreduce-client-core-2.7.7.jar -hadoop-mapreduce-client-jobclient-2.7.7.jar -hadoop-mapreduce-client-shuffle-2.7.7.jar -hadoop-yarn-api-2.7.7.jar -hadoop-yarn-client-2.7.7.jar -hadoop-yarn-common-2.7.7.jar -hadoop-yarn-server-common-2.7.7.jar -hadoop-yarn-server-web-proxy-2.7.7.jar +hadoop-annotations-2.7.3.jar +hadoop-auth-2.7.3.jar +hadoop-client-2.7.3.jar +hadoop-common-2.7.3.jar +hadoop-hdfs-2.7.3.jar +hadoop-mapreduce-client-app-2.7.3.jar +hadoop-mapreduce-client-common-2.7.3.jar +hadoop-mapreduce-client-core-2.7.3.jar +hadoop-mapreduce-client-jobclient-2.7.3.jar +hadoop-mapreduce-client-shuffle-2.7.3.jar +hadoop-yarn-api-2.7.3.jar +hadoop-yarn-client-2.7.3.jar +hadoop-yarn-common-2.7.3.jar +hadoop-yarn-server-common-2.7.3.jar +hadoop-yarn-server-web-proxy-2.7.3.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar @@ -117,7 +117,6 @@ jersey-guava-2.22.2.jar jersey-media-jaxb-2.22.2.jar jersey-server-2.22.2.jar jetty-6.1.26.jar -jetty-sslengine-6.1.26.jar jetty-util-6.1.26.jar jline-2.14.6.jar joda-time-2.9.3.jar diff --git a/docs/building-spark.md b/docs/building-spark.md index 1d3e0b1b7d396..1501f0bb84544 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -67,7 +67,7 @@ Examples: ./build/mvn -Pyarn -DskipTests clean package # Apache Hadoop 2.7.X and later - ./build/mvn -Pyarn -Phadoop-2.7 -Dhadoop.version=2.7.7 -DskipTests clean package + ./build/mvn -Pyarn -Phadoop-2.7 -Dhadoop.version=2.7.3 -DskipTests clean package ## Building With Hive and JDBC Support diff --git a/pom.xml b/pom.xml index da526a1709e65..05e3b05613efd 100644 --- a/pom.xml +++ b/pom.xml @@ -2683,7 +2683,7 @@ hadoop-2.7 - 2.7.7 + 2.7.3 2.7.1 From 4e3365b577fbc9021fa237ea4e8792f5aea5d80c Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 6 Sep 2018 21:43:14 -0700 Subject: [PATCH 1566/2461] [SPARK-22357][CORE][FOLLOWUP] SparkContext.binaryFiles ignore minPartitions parameter ## What changes were proposed in this pull request? This adds a test following https://github.com/apache/spark/pull/21638 ## How was this patch tested? Existing tests and new test. Closes #22356 from srowen/SPARK-22357.2. Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../scala/org/apache/spark/FileSuite.scala | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index a441b9c8ab97a..81b18c71f30ee 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -19,10 +19,12 @@ package org.apache.spark import java.io._ import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets import java.util.zip.GZIPOutputStream import scala.io.Source +import com.google.common.io.Files import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io._ @@ -299,6 +301,25 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } } + test("SPARK-22357 test binaryFiles minPartitions") { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local") + .set("spark.files.openCostInBytes", "0") + .set("spark.default.parallelism", "1")) + + val tempDir = Utils.createTempDir() + val tempDirPath = tempDir.getAbsolutePath + + for (i <- 0 until 8) { + val tempFile = new File(tempDir, s"part-0000$i") + Files.write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1", tempFile, + StandardCharsets.UTF_8) + } + + for (p <- Seq(1, 2, 8)) { + assert(sc.binaryFiles(tempDirPath, minPartitions = p).getNumPartitions === p) + } + } + test("fixed record length binary file as byte array") { sc = new SparkContext("local", "test") val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) From ed249db9c464062fbab7c6f68ad24caaa95cec82 Mon Sep 17 00:00:00 2001 From: dujunling Date: Thu, 6 Sep 2018 21:44:46 -0700 Subject: [PATCH 1567/2461] [SPARK-25237][SQL] Remove updateBytesReadWithFileSize in FileScanRDD ## What changes were proposed in this pull request? This pr removed the method `updateBytesReadWithFileSize` in `FileScanRDD` because it computes input metrics by file size supported in Hadoop 2.5 and earlier. The current Spark does not support the versions, so it causes wrong input metric numbers. This is rework from #22232. Closes #22232 ## How was this patch tested? Added tests in `FileBasedDataSourceSuite`. Closes #22324 from maropu/pr22232-2. Lead-authored-by: dujunling Co-authored-by: Takeshi Yamamuro Signed-off-by: Sean Owen --- .../execution/datasources/FileScanRDD.scala | 10 -------- .../spark/sql/FileBasedDataSourceSuite.scala | 24 +++++++++++++++++++ 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 99fc78ff3e49b..345c9d82ca0e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -89,14 +89,6 @@ class FileScanRDD( inputMetrics.setBytesRead(existingBytesRead + getBytesReadCallback()) } - // If we can't get the bytes read from the FS stats, fall back to the file size, - // which may be inaccurate. - private def updateBytesReadWithFileSize(): Unit = { - if (currentFile != null) { - inputMetrics.incBytesRead(currentFile.length) - } - } - private[this] val files = split.asInstanceOf[FilePartition].files.toIterator private[this] var currentFile: PartitionedFile = null private[this] var currentIterator: Iterator[Object] = null @@ -139,7 +131,6 @@ class FileScanRDD( /** Advances to the next file. Returns true if a new non-empty iterator is available. */ private def nextIterator(): Boolean = { - updateBytesReadWithFileSize() if (files.hasNext) { currentFile = files.next() logInfo(s"Reading File $currentFile") @@ -208,7 +199,6 @@ class FileScanRDD( override def close(): Unit = { updateBytesRead() - updateBytesReadWithFileSize() InputFileBlockHolder.unset() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 4aa6afd69620b..304ede9c5a612 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -20,10 +20,13 @@ package org.apache.spark.sql import java.io.{File, FileNotFoundException} import java.util.Locale +import scala.collection.mutable + import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.sql.TestingUDT.{IntervalData, IntervalUDT, NullData, NullUDT} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -473,6 +476,27 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } } } + + test("SPARK-25237 compute correct input metrics in FileScanRDD") { + withTempPath { p => + val path = p.getAbsolutePath + spark.range(1000).repartition(1).write.csv(path) + val bytesReads = new mutable.ArrayBuffer[Long]() + val bytesReadListener = new SparkListener() { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + bytesReads += taskEnd.taskMetrics.inputMetrics.bytesRead + } + } + sparkContext.addSparkListener(bytesReadListener) + try { + spark.read.csv(path).limit(1).collect() + sparkContext.listenerBus.waitUntilEmpty(1000L) + assert(bytesReads.sum === 7860) + } finally { + sparkContext.removeSparkListener(bytesReadListener) + } + } + } } object TestingUDT { From 6d7bc5af454341f6d9bfc1e903148ad7ba8de6f9 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 6 Sep 2018 23:35:02 -0700 Subject: [PATCH 1568/2461] [SPARK-25267][SQL][TEST] Disable ConvertToLocalRelation in the test cases of sql/core and sql/hive ## What changes were proposed in this pull request? In SharedSparkSession and TestHive, we need to disable the rule ConvertToLocalRelation for better test case coverage. ## How was this patch tested? Identify the failures after excluding "ConvertToLocalRelation" rule. Closes #22270 from dilipbiswal/SPARK-25267-final. Authored-by: Dilip Biswal Signed-off-by: gatorsmile --- .../scala/org/apache/spark/ml/util/MLTest.scala | 10 +++++++++- .../sql-tests/inputs/group-by-ordinal.sql | 4 +++- .../sql-tests/results/group-by-ordinal.sql.out | 4 +++- .../apache/spark/sql/DataFrameAggregateSuite.scala | 2 +- .../apache/spark/sql/DataFrameFunctionsSuite.scala | 14 ++++++++------ .../org/apache/spark/sql/DataFrameSuite.scala | 5 ++--- .../apache/spark/sql/test/SharedSparkSession.scala | 6 ++++++ .../org/apache/spark/sql/hive/test/TestHive.scala | 8 +++++++- 8 files changed, 39 insertions(+), 14 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala index 76d41f9b23715..acac171346a85 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala @@ -21,12 +21,13 @@ import java.io.File import org.scalatest.Suite -import org.apache.spark.SparkContext +import org.apache.spark.{DebugFilesystem, SparkConf, SparkContext} import org.apache.spark.ml.{PredictionModel, Transformer} import org.apache.spark.ml.linalg.Vector import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions.col +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.test.TestSparkSession import org.apache.spark.util.Utils @@ -36,6 +37,13 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => @transient var sc: SparkContext = _ @transient var checkpointDir: String = _ + protected override def sparkConf = { + new SparkConf() + .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + .set("spark.unsafe.exceptionOnMemoryLeak", "true") + .set(SQLConf.CODEGEN_FALLBACK.key, "false") + } + protected override def createSparkSession: TestSparkSession = { new TestSparkSession(new SparkContext("local[2]", "MLlibUnitTest", sparkConf)) } diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql index 928f766b4add2..3144833b608be 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql @@ -38,7 +38,9 @@ select a, b, sum(b) from data group by 3; select a, b, sum(b) + 2 from data group by 3; -- negative case: nondeterministic expression -select a, rand(0), sum(b) from data group by a, 2; +select a, rand(0), sum(b) +from +(select /*+ REPARTITION(1) */ a, b from data) group by a, 2; -- negative case: star select * from data group by a, b, 1; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out index 9ecbe19078dd6..cf5add6a71af2 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out @@ -135,7 +135,9 @@ aggregate functions are not allowed in GROUP BY, but found (sum(CAST(data.`b` AS -- !query 13 -select a, rand(0), sum(b) from data group by a, 2 +select a, rand(0), sum(b) +from +(select /*+ REPARTITION(1) */ a, b from data) group by a, 2 -- !query 13 schema struct -- !query 13 output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 85b3ca11383f7..ed110f751645d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -558,7 +558,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("SPARK-18004 limit + aggregates") { withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") { - val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value") + val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value").repartition(1) val limit2Df = df.limit(2) checkAnswer( limit2Df.groupBy("id").count().select($"id"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 156e54300e38b..4b83e51fa8992 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -85,14 +85,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } val df5 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v") - intercept[RuntimeException] { + val msg1 = intercept[Exception] { df5.select(map_from_arrays($"k", $"v")).collect - } + }.getMessage + assert(msg1.contains("Cannot use null as map key!")) val df6 = Seq((Seq(1, 2), Seq("a"))).toDF("k", "v") - intercept[RuntimeException] { + val msg2 = intercept[Exception] { df6.select(map_from_arrays($"k", $"v")).collect - } + }.getMessage + assert(msg2.contains("The given two arrays should have the same length")) } test("struct with column name") { @@ -2377,7 +2379,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex2.getMessage.contains( "The number of lambda function arguments '3' does not match")) - val ex3 = intercept[RuntimeException] { + val ex3 = intercept[Exception] { dfExample1.selectExpr("transform_keys(i, (k, v) -> v)").show() } assert(ex3.getMessage.contains("Cannot use null as map key!")) @@ -2697,7 +2699,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("SPARK-24734: Fix containsNull of Concat for array type") { val df = Seq((Seq(1), Seq[Integer](null), Seq("a", "b"))).toDF("k1", "k2", "v") - val ex = intercept[RuntimeException] { + val ex = intercept[Exception] { df.select(map_from_arrays(concat($"k1", $"k2"), $"v")).show() } assert(ex.getMessage.contains("Cannot use null as map key")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d43fcf3c6f5de..45b17b3d4958f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContex import org.apache.spark.sql.test.SQLTestData.{NullInts, NullStrings, TestData2} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +import org.apache.spark.util.random.XORShiftRandom class DataFrameSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -1729,10 +1730,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-9083: sort with non-deterministic expressions") { - import org.apache.spark.util.random.XORShiftRandom - val seed = 33 - val df = (1 to 100).map(Tuple1.apply).toDF("i") + val df = (1 to 100).map(Tuple1.apply).toDF("i").repartition(1) val random = new XORShiftRandom(seed) val expected = (1 to 100).map(_ -> random.nextDouble()).sortBy(_._2).map(_._1) val actual = df.sort(rand(seed)).collect().map(_.getInt(0)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index 8968dbf36d507..e7e0ce64963a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -24,6 +24,7 @@ import org.scalatest.concurrent.Eventually import org.apache.spark.{DebugFilesystem, SparkConf} import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.internal.SQLConf /** @@ -39,6 +40,11 @@ trait SharedSparkSession .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) .set("spark.unsafe.exceptionOnMemoryLeak", "true") .set(SQLConf.CODEGEN_FALLBACK.key, "false") + // Disable ConvertToLocalRelation for better test coverage. Test cases built on + // LocalRelation will exercise the optimization rules better by disabling it as + // this rule may potentially block testing of other optimization rules such as + // ConstantPropagation etc. + .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName) } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index ee3f99ab7e9bb..71f15a45d162a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -36,6 +36,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, ExternalCatalogWithListener} +import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.CacheTableCommand @@ -59,7 +60,12 @@ object TestHive .set("spark.sql.warehouse.dir", TestHiveContext.makeWarehouseDir().toURI.getPath) // SPARK-8910 .set("spark.ui.enabled", "false") - .set("spark.unsafe.exceptionOnMemoryLeak", "true"))) + .set("spark.unsafe.exceptionOnMemoryLeak", "true") + // Disable ConvertToLocalRelation for better test coverage. Test cases built on + // LocalRelation will exercise the optimization rules better by disabling it as + // this rule may potentially block testing of other optimization rules such as + // ConstantPropagation etc. + .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName))) case class TestHiveVersion(hiveClient: HiveClient) From f96a8bf8ffe9472a839ca482f64c7cdf7540c243 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Thu, 6 Sep 2018 23:36:30 -0700 Subject: [PATCH 1569/2461] [SPARK-12321][SQL][FOLLOW-UP] Add tests for fromString ## What changes were proposed in this pull request? Add test cases for fromString ## How was this patch tested? N/A Closes #22345 from gatorsmile/addTest. Authored-by: Xiao Li Signed-off-by: gatorsmile --- .../sql/catalyst/expressions/literals.scala | 46 +++++++++++-------- .../expressions/LiteralExpressionSuite.scala | 21 +++++++++ 2 files changed, 47 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 0efd1224f1bca..2bcbb92f1a469 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -128,30 +128,36 @@ object Literal { val dataType = DataType.parseDataType(json \ "dataType") json \ "value" match { case JNull => Literal.create(null, dataType) - case JString(str) => - val value = dataType match { - case BooleanType => str.toBoolean - case ByteType => str.toByte - case ShortType => str.toShort - case IntegerType => str.toInt - case LongType => str.toLong - case FloatType => str.toFloat - case DoubleType => str.toDouble - case StringType => UTF8String.fromString(str) - case DateType => java.sql.Date.valueOf(str) - case TimestampType => java.sql.Timestamp.valueOf(str) - case CalendarIntervalType => CalendarInterval.fromString(str) - case t: DecimalType => - val d = Decimal(str) - assert(d.changePrecision(t.precision, t.scale)) - d - case _ => null - } - Literal.create(value, dataType) + case JString(str) => fromString(str, dataType) case other => sys.error(s"$other is not a valid Literal json value") } } + /** + * Constructs a Literal from a String + */ + def fromString(str: String, dataType: DataType): Literal = { + val value = dataType match { + case BooleanType => str.toBoolean + case ByteType => str.toByte + case ShortType => str.toShort + case IntegerType => str.toInt + case LongType => str.toLong + case FloatType => str.toFloat + case DoubleType => str.toDouble + case StringType => UTF8String.fromString(str) + case DateType => java.sql.Date.valueOf(str) + case TimestampType => java.sql.Timestamp.valueOf(str) + case CalendarIntervalType => CalendarInterval.fromString(str) + case t: DecimalType => + val d = Decimal(str) + assert(d.changePrecision(t.precision, t.scale)) + d + case _ => null + } + Literal.create(value, dataType) + } + def create(v: Any, dataType: DataType): Literal = { Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index 86f80fe66d28b..3ea6bfac9ddca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -226,4 +226,25 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal('\u0000'), "\u0000") checkEvaluation(Literal.create('\n'), "\n") } + + test("fromString converts String/DataType input correctly") { + checkEvaluation(Literal.fromString(false.toString, BooleanType), false) + checkEvaluation(Literal.fromString(null, NullType), null) + checkEvaluation(Literal.fromString(Int.MaxValue.toByte.toString, ByteType), Int.MaxValue.toByte) + checkEvaluation(Literal.fromString(Short.MaxValue.toShort.toString, ShortType), Short.MaxValue + .toShort) + checkEvaluation(Literal.fromString(Int.MaxValue.toString, IntegerType), Int.MaxValue) + checkEvaluation(Literal.fromString(Long.MaxValue.toString, LongType), Long.MaxValue) + checkEvaluation(Literal.fromString(Float.MaxValue.toString, FloatType), Float.MaxValue) + checkEvaluation(Literal.fromString(Double.MaxValue.toString, DoubleType), Double.MaxValue) + checkEvaluation(Literal.fromString("1.23456", DecimalType(10, 5)), Decimal(1.23456)) + checkEvaluation(Literal.fromString("Databricks", StringType), "Databricks") + val dateString = "1970-01-01" + checkEvaluation(Literal.fromString(dateString, DateType), java.sql.Date.valueOf(dateString)) + val timestampString = "0000-01-01 00:00:00" + checkEvaluation(Literal.fromString(timestampString, TimestampType), + java.sql.Timestamp.valueOf(timestampString)) + val calInterval = new CalendarInterval(1, 1) + checkEvaluation(Literal.fromString(calInterval.toString, CalendarIntervalType), calInterval) + } } From 473f2fb3bfd0e51c40a87e475392f2e2c8f912dd Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Fri, 7 Sep 2018 09:28:33 -0700 Subject: [PATCH 1570/2461] [SPARK-21786][SQL][FOLLOWUP] Add compressionCodec test for CTAS ## What changes were proposed in this pull request? Before Apache Spark 2.3, table properties were ignored when writing data to a hive table(created with STORED AS PARQUET/ORC syntax), because the compression configurations were not passed to the FileFormatWriter in hadoopConf. Then it was fixed in #20087. But actually for CTAS with USING PARQUET/ORC syntax, table properties were ignored too when convertMastore, so the test case for CTAS not supported. Now it has been fixed in #20522 , the test case should be enabled too. ## How was this patch tested? This only re-enables the test cases of previous PR. Closes #22302 from fjh100456/compressionCodec. Authored-by: fjh100456 Signed-off-by: Dongjoon Hyun --- .../apache/spark/sql/hive/CompressionCodecSuite.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala index 4550d350f6db2..30204d1223846 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -122,7 +122,7 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo """.stripMargin) } - private def writeDateToTableUsingCTAS( + private def writeDataToTableUsingCTAS( rootDir: File, tableName: String, partitionValue: Option[String], @@ -152,7 +152,7 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo usingCTAS: Boolean): String = { val partitionValue = if (isPartitioned) Some("test") else None if (usingCTAS) { - writeDateToTableUsingCTAS(tmpDir, tableName, partitionValue, format, compressionCodec) + writeDataToTableUsingCTAS(tmpDir, tableName, partitionValue, format, compressionCodec) } else { createTable(tmpDir, tableName, isPartitioned, format, compressionCodec) writeDataToTable(tableName, partitionValue) @@ -258,8 +258,7 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo def checkForTableWithCompressProp(format: String, compressCodecs: List[String]): Unit = { Seq(true, false).foreach { isPartitioned => Seq(true, false).foreach { convertMetastore => - // TODO: Also verify CTAS(usingCTAS=true) cases when the bug(SPARK-22926) is fixed. - Seq(false).foreach { usingCTAS => + Seq(true, false).foreach { usingCTAS => checkTableCompressionCodecForCodecs( format, isPartitioned, @@ -281,8 +280,7 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo def checkForTableWithoutCompressProp(format: String, compressCodecs: List[String]): Unit = { Seq(true, false).foreach { isPartitioned => Seq(true, false).foreach { convertMetastore => - // TODO: Also verify CTAS(usingCTAS=true) cases when the bug(SPARK-22926) is fixed. - Seq(false).foreach { usingCTAS => + Seq(true, false).foreach { usingCTAS => checkTableCompressionCodecForCodecs( format, isPartitioned, From 22a46ca195eaa5ab8bcb5436f948b87d25c7cc29 Mon Sep 17 00:00:00 2001 From: cclauss Date: Fri, 7 Sep 2018 09:35:25 -0700 Subject: [PATCH 1571/2461] [SPARK-25270] lint-python: Add flake8 to find syntax errors and undefined names ## What changes were proposed in this pull request? Add [flake8](http://flake8.pycqa.org) tests to find Python syntax errors and undefined names. __E901,E999,F821,F822,F823__ are the "_showstopper_" flake8 issues that can halt the runtime with a SyntaxError, NameError, etc. Most other flake8 issues are merely "style violations" -- useful for readability but they do not effect runtime safety. * F821: undefined name `name` * F822: undefined name `name` in `__all__` * F823: local variable name referenced before assignment * E901: SyntaxError or IndentationError * E999: SyntaxError -- failed to compile a file into an Abstract Syntax Tree ## How was this patch tested? $ __flake8 . --count --select=E901,E999,F821,F822,F823 --show-source --statistics__ $ __flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics__ Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22266 from cclauss/patch-3. Authored-by: cclauss Signed-off-by: Holden Karau --- dev/lint-python | 17 +++++++++++++++++ dev/requirements.txt | 1 + 2 files changed, 18 insertions(+) diff --git a/dev/lint-python b/dev/lint-python index f738af9c49763..a98a251af9e6c 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -82,6 +82,23 @@ else rm "$PYCODESTYLE_REPORT_PATH" fi +# stop the build if there are Python syntax errors or undefined names +flake8 . --count --select=E901,E999,F821,F822,F823 --max-line-length=100 --show-source --statistics +flake8_status="${PIPESTATUS[0]}" + +if [ "$flake8_status" -eq 0 ]; then + lint_status=0 +else + lint_status=1 +fi + +if [ "$lint_status" -ne 0 ]; then + echo "flake8 checks failed." + exit "$lint_status" +else + echo "flake8 checks passed." +fi + # Check that the documentation builds acceptably, skip check if sphinx is not installed. if hash "$SPHINXBUILD" 2> /dev/null; then cd python/docs diff --git a/dev/requirements.txt b/dev/requirements.txt index fa833ab96b8e7..3fdd3425ffcc2 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -1,3 +1,4 @@ +flake8==3.5.0 jira==1.0.3 PyGithub==1.26.0 Unidecode==0.04.19 From 458f5011bd52851632c3592ac35f1573bc904d50 Mon Sep 17 00:00:00 2001 From: Lee Dongjin Date: Fri, 7 Sep 2018 10:36:15 -0700 Subject: [PATCH 1572/2461] [MINOR][SS] Fix kafka-0-10-sql trivials ## What changes were proposed in this pull request? Fix unused imports & outdated comments on `kafka-0-10-sql` module. (Found while I was working on [SPARK-23539](https://github.com/apache/spark/pull/22282)) ## How was this patch tested? Existing unit tests. Closes #22342 from dongjinleekr/feature/fix-kafka-sql-trivials. Authored-by: Lee Dongjin Signed-off-by: Sean Owen --- .../spark/sql/kafka010/KafkaOffsetRangeCalculator.scala | 1 - .../scala/org/apache/spark/sql/kafka010/KafkaRelation.scala | 1 - .../spark/sql/kafka010/KafkaStreamingWriteSupport.scala | 4 ++-- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala index 6631ae84167c8..fb209c724afba 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.sources.v2.DataSourceOptions private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int]) { require(minPartitions.isEmpty || minPartitions.get > 0) - import KafkaOffsetRangeCalculator._ /** * Calculate the offset ranges that we are going to process this batch. If `minPartitions` * is not set or is set less than or equal the number of `topicPartitions` that we're going to diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala index 9d856c9494e10..e6f9d1259e43e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.kafka010 -import java.{util => ju} import java.util.UUID import org.apache.kafka.common.TopicPartition diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala index dc19312f79a22..927c56d9ce829 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala @@ -54,8 +54,8 @@ class KafkaStreamingWriteSupport( } /** - * A [[DataWriterFactory]] for Kafka writing. Will be serialized and sent to executors to generate - * the per-task data writers. + * A [[StreamingDataWriterFactory]] for Kafka writing. Will be serialized and sent to executors to + * generate the per-task data writers. * @param topic The topic that should be written to. If None, topic will be inferred from * a `topic` field in the incoming data. * @param producerParams Parameters for Kafka producers in each task. From 9241e1e7e66574cfafa68791771959dfc39c9684 Mon Sep 17 00:00:00 2001 From: Edwina Lu Date: Fri, 7 Sep 2018 10:42:46 -0700 Subject: [PATCH 1573/2461] [SPARK-23429][CORE] Add executor memory metrics to heartbeat and expose in executors REST API Add new executor level memory metrics (JVM used memory, on/off heap execution memory, on/off heap storage memory, on/off heap unified memory, direct memory, and mapped memory), and expose via the executors REST API. This information will help provide insight into how executor and driver JVM memory is used, and for the different memory regions. It can be used to help determine good values for spark.executor.memory, spark.driver.memory, spark.memory.fraction, and spark.memory.storageFraction. ## What changes were proposed in this pull request? An ExecutorMetrics class is added, with jvmUsedHeapMemory, jvmUsedNonHeapMemory, onHeapExecutionMemory, offHeapExecutionMemory, onHeapStorageMemory, and offHeapStorageMemory, onHeapUnifiedMemory, offHeapUnifiedMemory, directMemory and mappedMemory. The new ExecutorMetrics is sent by executors to the driver as part of the Heartbeat. A heartbeat is added for the driver as well, to collect these metrics for the driver. The EventLoggingListener store information about the peak values for each metric, per active stage and executor. When a StageCompleted event is seen, a StageExecutorsMetrics event will be logged for each executor, with peak values for the stage. The AppStatusListener records the peak values for each memory metric. The new memory metrics are added to the executors REST API. ## How was this patch tested? New unit tests have been added. This was also tested on our cluster. Author: Edwina Lu Author: Imran Rashid Author: edwinalu Closes #21221 from edwinalu/SPARK-23429.2. --- .../apache/spark/SparkFirehoseListener.java | 6 + .../org/apache/spark/HeartbeatReceiver.scala | 8 +- .../scala/org/apache/spark/Heartbeater.scala | 71 ++++ .../scala/org/apache/spark/SparkContext.scala | 20 ++ .../org/apache/spark/executor/Executor.scala | 36 +- .../spark/executor/ExecutorMetrics.scala | 81 +++++ .../spark/internal/config/package.scala | 5 + .../apache/spark/memory/MemoryManager.scala | 28 ++ .../spark/metrics/ExecutorMetricType.scala | 104 ++++++ .../apache/spark/scheduler/DAGScheduler.scala | 9 +- .../scheduler/EventLoggingListener.scala | 54 ++- .../spark/scheduler/SparkListener.scala | 32 +- .../spark/scheduler/SparkListenerBus.scala | 2 + .../spark/scheduler/TaskScheduler.scala | 10 +- .../spark/scheduler/TaskSchedulerImpl.scala | 13 +- .../spark/status/AppStatusListener.scala | 44 ++- .../org/apache/spark/status/LiveEntity.scala | 9 +- .../org/apache/spark/status/api/v1/api.scala | 39 ++- .../org/apache/spark/util/JsonProtocol.scala | 47 ++- .../application_list_json_expectation.json | 15 + .../completed_app_list_json_expectation.json | 15 + ...ith_executor_metrics_json_expectation.json | 314 ++++++++++++++++++ .../limit_app_list_json_expectation.json | 30 +- .../minDate_app_list_json_expectation.json | 15 + .../minEndDate_app_list_json_expectation.json | 17 +- .../application_1506645932520_24630151 | 63 ++++ .../apache/spark/HeartbeatReceiverSuite.scala | 11 +- .../deploy/history/HistoryServerSuite.scala | 3 + .../spark/scheduler/DAGSchedulerSuite.scala | 7 +- .../scheduler/EventLoggingListenerSuite.scala | 221 +++++++++++- .../ExternalClusterManagerSuite.scala | 4 +- .../spark/scheduler/ReplayListenerSuite.scala | 5 +- .../spark/status/AppStatusListenerSuite.scala | 162 ++++++++- .../apache/spark/util/JsonProtocolSuite.scala | 107 +++++- dev/.rat-excludes | 1 + project/MimaExcludes.scala | 6 + 36 files changed, 1531 insertions(+), 83 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/Heartbeater.scala create mode 100644 core/src/main/scala/org/apache/spark/executor/ExecutorMetrics.scala create mode 100644 core/src/main/scala/org/apache/spark/metrics/ExecutorMetricType.scala create mode 100644 core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_metrics_json_expectation.json create mode 100644 core/src/test/resources/spark-events/application_1506645932520_24630151 diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index 94c5c11b61a50..731f6fc767dfd 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -103,6 +103,12 @@ public final void onExecutorMetricsUpdate( onEvent(executorMetricsUpdate); } + @Override + public final void onStageExecutorMetrics( + SparkListenerStageExecutorMetrics executorMetrics) { + onEvent(executorMetrics); + } + @Override public final void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { onEvent(executorAdded); diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index bcbc8df0d5865..ab0ae55ed357d 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -22,6 +22,7 @@ import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable import scala.concurrent.Future +import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler._ @@ -37,7 +38,8 @@ import org.apache.spark.util._ private[spark] case class Heartbeat( executorId: String, accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], // taskId -> accumulator updates - blockManagerId: BlockManagerId) + blockManagerId: BlockManagerId, + executorUpdates: ExecutorMetrics) // executor level updates /** * An event that SparkContext uses to notify HeartbeatReceiver that SparkContext.taskScheduler is @@ -119,14 +121,14 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) context.reply(true) // Messages received from executors - case heartbeat @ Heartbeat(executorId, accumUpdates, blockManagerId) => + case heartbeat @ Heartbeat(executorId, accumUpdates, blockManagerId, executorMetrics) => if (scheduler != null) { if (executorLastSeen.contains(executorId)) { executorLastSeen(executorId) = clock.getTimeMillis() eventLoopThread.submit(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { val unknownExecutor = !scheduler.executorHeartbeatReceived( - executorId, accumUpdates, blockManagerId) + executorId, accumUpdates, blockManagerId, executorMetrics) val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) context.reply(response) } diff --git a/core/src/main/scala/org/apache/spark/Heartbeater.scala b/core/src/main/scala/org/apache/spark/Heartbeater.scala new file mode 100644 index 0000000000000..5ba1b9b2d828e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/Heartbeater.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.concurrent.TimeUnit + +import org.apache.spark.executor.ExecutorMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.memory.MemoryManager +import org.apache.spark.metrics.ExecutorMetricType +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * Creates a heartbeat thread which will call the specified reportHeartbeat function at + * intervals of intervalMs. + * + * @param memoryManager the memory manager for execution and storage memory. + * @param reportHeartbeat the heartbeat reporting function to call. + * @param name the thread name for the heartbeater. + * @param intervalMs the interval between heartbeats. + */ +private[spark] class Heartbeater( + memoryManager: MemoryManager, + reportHeartbeat: () => Unit, + name: String, + intervalMs: Long) extends Logging { + // Executor for the heartbeat task + private val heartbeater = ThreadUtils.newDaemonSingleThreadScheduledExecutor(name) + + /** Schedules a task to report a heartbeat. */ + def start(): Unit = { + // Wait a random interval so the heartbeats don't end up in sync + val initialDelay = intervalMs + (math.random * intervalMs).asInstanceOf[Int] + + val heartbeatTask = new Runnable() { + override def run(): Unit = Utils.logUncaughtExceptions(reportHeartbeat()) + } + heartbeater.scheduleAtFixedRate(heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS) + } + + /** Stops the heartbeat thread. */ + def stop(): Unit = { + heartbeater.shutdown() + heartbeater.awaitTermination(10, TimeUnit.SECONDS) + } + + /** + * Get the current executor level metrics. These are returned as an array, with the index + * determined by MetricGetter.values + */ + def getCurrentMetrics(): ExecutorMetrics = { + val metrics = ExecutorMetricType.values.map(_.getMetricValue(memoryManager)).toArray + new ExecutorMetrics(metrics) + } +} + diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e5b1e0ecd1586..d943087ab6b80 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -213,6 +213,7 @@ class SparkContext(config: SparkConf) extends Logging { private var _files: Seq[String] = _ private var _shutdownHookRef: AnyRef = _ private var _statusStore: AppStatusStore = _ + private var _heartbeater: Heartbeater = _ /* ------------------------------------------------------------------------------------- * | Accessors and public fields. These provide access to the internal state of the | @@ -496,6 +497,11 @@ class SparkContext(config: SparkConf) extends Logging { _dagScheduler = new DAGScheduler(this) _heartbeatReceiver.ask[Boolean](TaskSchedulerIsSet) + // create and start the heartbeater for collecting memory metrics + _heartbeater = new Heartbeater(env.memoryManager, reportHeartBeat, "driver-heartbeater", + conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s")) + _heartbeater.start() + // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's // constructor _taskScheduler.start() @@ -1959,6 +1965,12 @@ class SparkContext(config: SparkConf) extends Logging { Utils.tryLogNonFatalError { _eventLogger.foreach(_.stop()) } + if (_heartbeater != null) { + Utils.tryLogNonFatalError { + _heartbeater.stop() + } + _heartbeater = null + } if (env != null && _heartbeatReceiver != null) { Utils.tryLogNonFatalError { env.rpcEnv.stop(_heartbeatReceiver) @@ -2429,6 +2441,14 @@ class SparkContext(config: SparkConf) extends Logging { } } + /** Reports heartbeat metrics for the driver. */ + private def reportHeartBeat(): Unit = { + val driverUpdates = _heartbeater.getCurrentMetrics() + val accumUpdates = new Array[(Long, Int, Int, Seq[AccumulableInfo])](0) + listenerBus.post(SparkListenerExecutorMetricsUpdate("driver", accumUpdates, + Some(driverUpdates))) + } + // In order to prevent multiple SparkContexts from being active at the same time, mark this // context as having finished construction. // NOTE: this must be placed at the end of the SparkContext constructor. diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 86b19578037df..072277cb78dc1 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -38,7 +38,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager} import org.apache.spark.rpc.RpcTimeout -import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task, TaskDescription} +import org.apache.spark.scheduler._ import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} import org.apache.spark.util._ @@ -148,7 +148,8 @@ private[spark] class Executor( private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] // Executor for the heartbeat task. - private val heartbeater = ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-heartbeater") + private val heartbeater = new Heartbeater(env.memoryManager, reportHeartBeat, + "executor-heartbeater", conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s")) // must be initialized before running startDriverHeartbeat() private val heartbeatReceiverRef = @@ -167,7 +168,7 @@ private[spark] class Executor( */ private var heartbeatFailures = 0 - startDriverHeartbeater() + heartbeater.start() private[executor] def numRunningTasks: Int = runningTasks.size() @@ -216,8 +217,12 @@ private[spark] class Executor( def stop(): Unit = { env.metricsSystem.report() - heartbeater.shutdown() - heartbeater.awaitTermination(10, TimeUnit.SECONDS) + try { + heartbeater.stop() + } catch { + case NonFatal(e) => + logWarning("Unable to stop heartbeater", e) + } threadPool.shutdown() if (!isLocal) { env.stop() @@ -787,6 +792,9 @@ private[spark] class Executor( val accumUpdates = new ArrayBuffer[(Long, Seq[AccumulatorV2[_, _]])]() val curGCTime = computeTotalGcTime() + // get executor level memory metrics + val executorUpdates = heartbeater.getCurrentMetrics() + for (taskRunner <- runningTasks.values().asScala) { if (taskRunner.task != null) { taskRunner.task.metrics.mergeShuffleReadMetrics() @@ -795,7 +803,8 @@ private[spark] class Executor( } } - val message = Heartbeat(executorId, accumUpdates.toArray, env.blockManager.blockManagerId) + val message = Heartbeat(executorId, accumUpdates.toArray, env.blockManager.blockManagerId, + executorUpdates) try { val response = heartbeatReceiverRef.askSync[HeartbeatResponse]( message, RpcTimeout(conf, "spark.executor.heartbeatInterval", "10s")) @@ -815,21 +824,6 @@ private[spark] class Executor( } } } - - /** - * Schedules a task to report heartbeat and partial metrics for active tasks to driver. - */ - private def startDriverHeartbeater(): Unit = { - val intervalMs = conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s") - - // Wait a random interval so the heartbeats don't end up in sync - val initialDelay = intervalMs + (math.random * intervalMs).asInstanceOf[Int] - - val heartbeatTask = new Runnable() { - override def run(): Unit = Utils.logUncaughtExceptions(reportHeartBeat()) - } - heartbeater.scheduleAtFixedRate(heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS) - } } private[spark] object Executor { diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorMetrics.scala new file mode 100644 index 0000000000000..2933f3ba6d3b5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorMetrics.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.executor + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.metrics.ExecutorMetricType + +/** + * :: DeveloperApi :: + * Metrics tracked for executors and the driver. + * + * Executor-level metrics are sent from each executor to the driver as part of the Heartbeat. + */ +@DeveloperApi +class ExecutorMetrics private[spark] extends Serializable { + + // Metrics are indexed by MetricGetter.values + private val metrics = new Array[Long](ExecutorMetricType.values.length) + + // the first element is initialized to -1, indicating that the values for the array + // haven't been set yet. + metrics(0) = -1 + + /** Returns the value for the specified metricType. */ + def getMetricValue(metricType: ExecutorMetricType): Long = { + metrics(ExecutorMetricType.metricIdxMap(metricType)) + } + + /** Returns true if the values for the metrics have been set, false otherwise. */ + def isSet(): Boolean = metrics(0) > -1 + + private[spark] def this(metrics: Array[Long]) { + this() + Array.copy(metrics, 0, this.metrics, 0, Math.min(metrics.size, this.metrics.size)) + } + + /** + * Constructor: create the ExecutorMetrics with the values specified. + * + * @param executorMetrics map of executor metric name to value + */ + private[spark] def this(executorMetrics: Map[String, Long]) { + this() + (0 until ExecutorMetricType.values.length).foreach { idx => + metrics(idx) = executorMetrics.getOrElse(ExecutorMetricType.values(idx).name, 0L) + } + } + + /** + * Compare the specified executor metrics values with the current executor metric values, + * and update the value for any metrics where the new value for the metric is larger. + * + * @param executorMetrics the executor metrics to compare + * @return if there is a new peak value for any metric + */ + private[spark] def compareAndUpdatePeakValues(executorMetrics: ExecutorMetrics): Boolean = { + var updated = false + + (0 until ExecutorMetricType.values.length).foreach { idx => + if (executorMetrics.metrics(idx) > metrics(idx)) { + updated = true + metrics(idx) = executorMetrics.metrics(idx) + } + } + updated + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 319e664a19677..ee41bd1a79ae3 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -69,6 +69,11 @@ package object config { .bytesConf(ByteUnit.KiB) .createWithDefaultString("100k") + private[spark] val EVENT_LOG_STAGE_EXECUTOR_METRICS = + ConfigBuilder("spark.eventLog.logStageExecutorMetrics.enabled") + .booleanConf + .createWithDefault(false) + private[spark] val EVENT_LOG_OVERWRITE = ConfigBuilder("spark.eventLog.overwrite").booleanConf.createWithDefault(false) diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index 0641adc2ab699..4fde2d0beaa71 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -180,6 +180,34 @@ private[spark] abstract class MemoryManager( onHeapStorageMemoryPool.memoryUsed + offHeapStorageMemoryPool.memoryUsed } + /** + * On heap execution memory currently in use, in bytes. + */ + final def onHeapExecutionMemoryUsed: Long = synchronized { + onHeapExecutionMemoryPool.memoryUsed + } + + /** + * Off heap execution memory currently in use, in bytes. + */ + final def offHeapExecutionMemoryUsed: Long = synchronized { + offHeapExecutionMemoryPool.memoryUsed + } + + /** + * On heap storage memory currently in use, in bytes. + */ + final def onHeapStorageMemoryUsed: Long = synchronized { + onHeapStorageMemoryPool.memoryUsed + } + + /** + * Off heap storage memory currently in use, in bytes. + */ + final def offHeapStorageMemoryUsed: Long = synchronized { + offHeapStorageMemoryPool.memoryUsed + } + /** * Returns the execution memory consumption, in bytes, for the given task. */ diff --git a/core/src/main/scala/org/apache/spark/metrics/ExecutorMetricType.scala b/core/src/main/scala/org/apache/spark/metrics/ExecutorMetricType.scala new file mode 100644 index 0000000000000..cd10dad25e87b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/ExecutorMetricType.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.metrics + +import java.lang.management.{BufferPoolMXBean, ManagementFactory} +import javax.management.ObjectName + +import org.apache.spark.memory.MemoryManager + +/** + * Executor metric types for executor-level metrics stored in ExecutorMetrics. + */ +sealed trait ExecutorMetricType { + private[spark] def getMetricValue(memoryManager: MemoryManager): Long + private[spark] val name = getClass().getName().stripSuffix("$").split("""\.""").last +} + +private[spark] abstract class MemoryManagerExecutorMetricType( + f: MemoryManager => Long) extends ExecutorMetricType { + override private[spark] def getMetricValue(memoryManager: MemoryManager): Long = { + f(memoryManager) + } +} + +private[spark] abstract class MBeanExecutorMetricType(mBeanName: String) + extends ExecutorMetricType { + private val bean = ManagementFactory.newPlatformMXBeanProxy( + ManagementFactory.getPlatformMBeanServer, + new ObjectName(mBeanName).toString, classOf[BufferPoolMXBean]) + + override private[spark] def getMetricValue(memoryManager: MemoryManager): Long = { + bean.getMemoryUsed + } +} + +case object JVMHeapMemory extends ExecutorMetricType { + override private[spark] def getMetricValue(memoryManager: MemoryManager): Long = { + ManagementFactory.getMemoryMXBean.getHeapMemoryUsage().getUsed() + } +} + +case object JVMOffHeapMemory extends ExecutorMetricType { + override private[spark] def getMetricValue(memoryManager: MemoryManager): Long = { + ManagementFactory.getMemoryMXBean.getNonHeapMemoryUsage().getUsed() + } +} + +case object OnHeapExecutionMemory extends MemoryManagerExecutorMetricType( + _.onHeapExecutionMemoryUsed) + +case object OffHeapExecutionMemory extends MemoryManagerExecutorMetricType( + _.offHeapExecutionMemoryUsed) + +case object OnHeapStorageMemory extends MemoryManagerExecutorMetricType( + _.onHeapStorageMemoryUsed) + +case object OffHeapStorageMemory extends MemoryManagerExecutorMetricType( + _.offHeapStorageMemoryUsed) + +case object OnHeapUnifiedMemory extends MemoryManagerExecutorMetricType( + (m => m.onHeapExecutionMemoryUsed + m.onHeapStorageMemoryUsed)) + +case object OffHeapUnifiedMemory extends MemoryManagerExecutorMetricType( + (m => m.offHeapExecutionMemoryUsed + m.offHeapStorageMemoryUsed)) + +case object DirectPoolMemory extends MBeanExecutorMetricType( + "java.nio:type=BufferPool,name=direct") + +case object MappedPoolMemory extends MBeanExecutorMetricType( + "java.nio:type=BufferPool,name=mapped") + +private[spark] object ExecutorMetricType { + // List of all executor metric types + val values = IndexedSeq( + JVMHeapMemory, + JVMOffHeapMemory, + OnHeapExecutionMemory, + OffHeapExecutionMemory, + OnHeapStorageMemory, + OffHeapStorageMemory, + OnHeapUnifiedMemory, + OffHeapUnifiedMemory, + DirectPoolMemory, + MappedPoolMemory + ) + + // Map of executor metric type to its index in values. + val metricIdxMap = + Map[ExecutorMetricType, Int](ExecutorMetricType.values.zipWithIndex: _*) +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 50c91da8b13d1..47108353583a8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -35,7 +35,7 @@ import org.apache.commons.lang3.SerializationUtils import org.apache.spark._ import org.apache.spark.broadcast.Broadcast -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.internal.Logging import org.apache.spark.internal.config import org.apache.spark.network.util.JavaUtils @@ -264,8 +264,11 @@ private[spark] class DAGScheduler( execId: String, // (taskId, stageId, stageAttemptId, accumUpdates) accumUpdates: Array[(Long, Int, Int, Seq[AccumulableInfo])], - blockManagerId: BlockManagerId): Boolean = { - listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates)) + blockManagerId: BlockManagerId, + // executor metrics indexed by MetricGetter.values + executorUpdates: ExecutorMetrics): Boolean = { + listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates, + Some(executorUpdates))) blockManagerMaster.driverEndpoint.askSync[Boolean]( BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat")) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 69bc51c1ecf90..1629e1797977f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -23,8 +23,7 @@ import java.nio.charset.StandardCharsets import java.util.EnumSet import java.util.Locale -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, Map} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FSDataOutputStream, Path} @@ -36,6 +35,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SPARK_VERSION, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec @@ -51,6 +51,7 @@ import org.apache.spark.util.{JsonProtocol, Utils} * spark.eventLog.overwrite - Whether to overwrite any existing files. * spark.eventLog.dir - Path to the directory in which events are logged. * spark.eventLog.buffer.kb - Buffer size to use when writing to output streams + * spark.eventLog.logStageExecutorMetrics.enabled - Whether to log stage executor metrics */ private[spark] class EventLoggingListener( appId: String, @@ -69,6 +70,7 @@ private[spark] class EventLoggingListener( private val shouldCompress = sparkConf.get(EVENT_LOG_COMPRESS) private val shouldOverwrite = sparkConf.get(EVENT_LOG_OVERWRITE) private val shouldLogBlockUpdates = sparkConf.get(EVENT_LOG_BLOCK_UPDATES) + private val shouldLogStageExecutorMetrics = sparkConf.get(EVENT_LOG_STAGE_EXECUTOR_METRICS) private val testing = sparkConf.get(EVENT_LOG_TESTING) private val outputBufferSize = sparkConf.get(EVENT_LOG_OUTPUT_BUFFER_SIZE).toInt private val fileSystem = Utils.getHadoopFileSystem(logBaseDir, hadoopConf) @@ -93,6 +95,9 @@ private[spark] class EventLoggingListener( // Visible for tests only. private[scheduler] val logPath = getLogPath(logBaseDir, appId, appAttemptId, compressionCodecName) + // map of (stageId, stageAttempt), to peak executor metrics for the stage + private val liveStageExecutorMetrics = Map.empty[(Int, Int), Map[String, ExecutorMetrics]] + /** * Creates the log file in the configured log directory. */ @@ -155,7 +160,14 @@ private[spark] class EventLoggingListener( } // Events that do not trigger a flush - override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = logEvent(event) + override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = { + logEvent(event) + if (shouldLogStageExecutorMetrics) { + // record the peak metrics for the new stage + liveStageExecutorMetrics.put((event.stageInfo.stageId, event.stageInfo.attemptNumber()), + Map.empty[String, ExecutorMetrics]) + } + } override def onTaskStart(event: SparkListenerTaskStart): Unit = logEvent(event) @@ -169,6 +181,26 @@ private[spark] class EventLoggingListener( // Events that trigger a flush override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { + if (shouldLogStageExecutorMetrics) { + // clear out any previous attempts, that did not have a stage completed event + val prevAttemptId = event.stageInfo.attemptNumber() - 1 + for (attemptId <- 0 to prevAttemptId) { + liveStageExecutorMetrics.remove((event.stageInfo.stageId, attemptId)) + } + + // log the peak executor metrics for the stage, for each live executor, + // whether or not the executor is running tasks for the stage + val executorOpt = liveStageExecutorMetrics.remove( + (event.stageInfo.stageId, event.stageInfo.attemptNumber())) + executorOpt.foreach { execMap => + execMap.foreach { case (executorId, peakExecutorMetrics) => + logEvent(new SparkListenerStageExecutorMetrics(executorId, event.stageInfo.stageId, + event.stageInfo.attemptNumber(), peakExecutorMetrics)) + } + } + } + + // log stage completed event logEvent(event, flushLogger = true) } @@ -234,8 +266,18 @@ private[spark] class EventLoggingListener( } } - // No-op because logging every update would be overkill - override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } + override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { + if (shouldLogStageExecutorMetrics) { + // For the active stages, record any new peak values for the memory metrics for the executor + event.executorUpdates.foreach { executorUpdates => + liveStageExecutorMetrics.values.foreach { peakExecutorMetrics => + val peakMetrics = peakExecutorMetrics.getOrElseUpdate( + event.execId, new ExecutorMetrics()) + peakMetrics.compareAndUpdatePeakValues(executorUpdates) + } + } + } + } override def onOtherEvent(event: SparkListenerEvent): Unit = { if (event.logEvent) { @@ -296,7 +338,7 @@ private[spark] object EventLoggingListener extends Logging { private val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort) // A cache for compression codecs to avoid creating the same codec many times - private val codecMap = new mutable.HashMap[String, CompressionCodec] + private val codecMap = Map.empty[String, CompressionCodec] /** * Write metadata about an event log to the given stream. diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 8a112f6a37b96..293e8369677f0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -26,7 +26,7 @@ import com.fasterxml.jackson.annotation.JsonTypeInfo import org.apache.spark.{SparkConf, TaskEndReason} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} import org.apache.spark.ui.SparkUI @@ -160,11 +160,29 @@ case class SparkListenerBlockUpdated(blockUpdatedInfo: BlockUpdatedInfo) extends * Periodic updates from executors. * @param execId executor id * @param accumUpdates sequence of (taskId, stageId, stageAttemptId, accumUpdates) + * @param executorUpdates executor level metrics updates */ @DeveloperApi case class SparkListenerExecutorMetricsUpdate( execId: String, - accumUpdates: Seq[(Long, Int, Int, Seq[AccumulableInfo])]) + accumUpdates: Seq[(Long, Int, Int, Seq[AccumulableInfo])], + executorUpdates: Option[ExecutorMetrics] = None) + extends SparkListenerEvent + +/** + * Peak metric values for the executor for the stage, written to the history log at stage + * completion. + * @param execId executor id + * @param stageId stage id + * @param stageAttemptId stage attempt + * @param executorMetrics executor level metrics, indexed by MetricGetter.values + */ +@DeveloperApi +case class SparkListenerStageExecutorMetrics( + execId: String, + stageId: Int, + stageAttemptId: Int, + executorMetrics: ExecutorMetrics) extends SparkListenerEvent @DeveloperApi @@ -264,6 +282,13 @@ private[spark] trait SparkListenerInterface { */ def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit + /** + * Called with the peak memory metrics for a given (executor, stage) combination. Note that this + * is only present when reading from the event log (as in the history server), and is never + * called in a live application. + */ + def onStageExecutorMetrics(executorMetrics: SparkListenerStageExecutorMetrics): Unit + /** * Called when the driver registers a new executor. */ @@ -361,6 +386,9 @@ abstract class SparkListener extends SparkListenerInterface { override def onExecutorMetricsUpdate( executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = { } + override def onStageExecutorMetrics( + executorMetrics: SparkListenerStageExecutorMetrics): Unit = { } + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { } override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index ff19cc65552e0..8f6b7ad309602 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -57,6 +57,8 @@ private[spark] trait SparkListenerBus listener.onApplicationEnd(applicationEnd) case metricsUpdate: SparkListenerExecutorMetricsUpdate => listener.onExecutorMetricsUpdate(metricsUpdate) + case stageExecutorMetrics: SparkListenerStageExecutorMetrics => + listener.onStageExecutorMetrics(stageExecutorMetrics) case executorAdded: SparkListenerExecutorAdded => listener.onExecutorAdded(executorAdded) case executorRemoved: SparkListenerExecutorRemoved => diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 95f7ae4fd39a2..94221eb0d5515 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -17,6 +17,7 @@ package org.apache.spark.scheduler +import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.AccumulatorV2 @@ -74,14 +75,15 @@ private[spark] trait TaskScheduler { def defaultParallelism(): Int /** - * Update metrics for in-progress tasks and let the master know that the BlockManager is still - * alive. Return true if the driver knows about the given block manager. Otherwise, return false, - * indicating that the block manager should re-register. + * Update metrics for in-progress tasks and executor metrics, and let the master know that the + * BlockManager is still alive. Return true if the driver knows about the given block manager. + * Otherwise, return false, indicating that the block manager should re-register. */ def executorHeartbeatReceived( execId: String, accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], - blockManagerId: BlockManagerId): Boolean + blockManagerId: BlockManagerId, + executorUpdates: ExecutorMetrics): Boolean /** * Get an application ID associated with the job. diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 8b71170668639..4f870e85ad38d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -28,6 +28,7 @@ import scala.util.Random import org.apache.spark._ import org.apache.spark.TaskState.TaskState +import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.internal.Logging import org.apache.spark.internal.config import org.apache.spark.rpc.RpcEndpoint @@ -508,14 +509,15 @@ private[spark] class TaskSchedulerImpl( } /** - * Update metrics for in-progress tasks and let the master know that the BlockManager is still - * alive. Return true if the driver knows about the given block manager. Otherwise, return false, - * indicating that the block manager should re-register. + * Update metrics for in-progress tasks and executor metrics, and let the master know that the + * BlockManager is still alive. Return true if the driver knows about the given block manager. + * Otherwise, return false, indicating that the block manager should re-register. */ override def executorHeartbeatReceived( execId: String, accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], - blockManagerId: BlockManagerId): Boolean = { + blockManagerId: BlockManagerId, + executorMetrics: ExecutorMetrics): Boolean = { // (taskId, stageId, stageAttemptId, accumUpdates) val accumUpdatesWithTaskIds: Array[(Long, Int, Int, Seq[AccumulableInfo])] = { accumUpdates.flatMap { case (id, updates) => @@ -525,7 +527,8 @@ private[spark] class TaskSchedulerImpl( } } } - dagScheduler.executorHeartbeatReceived(execId, accumUpdatesWithTaskIds, blockManagerId) + dagScheduler.executorHeartbeatReceived(execId, accumUpdatesWithTaskIds, blockManagerId, + executorMetrics) } def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long): Unit = synchronized { diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 91b75e4852999..304d0922a37d2 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.HashMap import org.apache.spark._ -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ import org.apache.spark.status.api.v1 @@ -66,6 +66,7 @@ private[spark] class AppStatusListener( private val liveStages = new ConcurrentHashMap[(Int, Int), LiveStage]() private val liveJobs = new HashMap[Int, LiveJob]() private val liveExecutors = new HashMap[String, LiveExecutor]() + private val deadExecutors = new HashMap[String, LiveExecutor]() private val liveTasks = new HashMap[Long, LiveTask]() private val liveRDDs = new HashMap[Int, LiveRDD]() private val pools = new HashMap[String, SchedulerPool]() @@ -204,6 +205,19 @@ private[spark] class AppStatusListener( update(rdd, now) } } + if (isExecutorActiveForLiveStages(exec)) { + // the executor was running for a currently active stage, so save it for now in + // deadExecutors, and remove when there are no active stages overlapping with the + // executor. + deadExecutors.put(event.executorId, exec) + } + } + } + + /** Was the specified executor active for any currently live stages? */ + private def isExecutorActiveForLiveStages(exec: LiveExecutor): Boolean = { + liveStages.values.asScala.exists { stage => + stage.info.submissionTime.getOrElse(0L) < exec.removeTime.getTime } } @@ -641,6 +655,9 @@ private[spark] class AppStatusListener( } } + // remove any dead executors that were not running for any currently active stages + deadExecutors.retain((execId, exec) => isExecutorActiveForLiveStages(exec)) + appSummary = new AppSummary(appSummary.numCompletedJobs, appSummary.numCompletedStages + 1) kvstore.write(appSummary) } @@ -692,6 +709,31 @@ private[spark] class AppStatusListener( } } } + + // check if there is a new peak value for any of the executor level memory metrics + // for the live UI. SparkListenerExecutorMetricsUpdate events are only processed + // for the live UI. + event.executorUpdates.foreach { updates => + liveExecutors.get(event.execId).foreach { exec => + if (exec.peakExecutorMetrics.compareAndUpdatePeakValues(updates)) { + maybeUpdate(exec, now) + } + } + } + } + + override def onStageExecutorMetrics(executorMetrics: SparkListenerStageExecutorMetrics): Unit = { + val now = System.nanoTime() + + // check if there is a new peak value for any of the executor level memory metrics, + // while reading from the log. SparkListenerStageExecutorMetrics are only processed + // when reading logs. + liveExecutors.get(executorMetrics.execId) + .orElse(deadExecutors.get(executorMetrics.execId)).map { exec => + if (exec.peakExecutorMetrics.compareAndUpdatePeakValues(executorMetrics.executorMetrics)) { + update(exec, now) + } + } } override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = { diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 79e3f13b826ce..a0b2458549fbb 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -26,7 +26,7 @@ import scala.collection.mutable.HashMap import com.google.common.collect.Interners import org.apache.spark.JobExecutionStatus -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.scheduler.{AccumulableInfo, StageInfo, TaskInfo} import org.apache.spark.status.api.v1 import org.apache.spark.storage.RDDInfo @@ -268,6 +268,9 @@ private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveE def hasMemoryInfo: Boolean = totalOnHeap >= 0L + // peak values for executor level metrics + val peakExecutorMetrics = new ExecutorMetrics() + def hostname: String = if (host != null) host else hostPort.split(":")(0) override protected def doUpdate(): Any = { @@ -302,10 +305,10 @@ private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveE Option(removeReason), executorLogs, memoryMetrics, - blacklistedInStages) + blacklistedInStages, + Some(peakExecutorMetrics).filter(_.isSet)) new ExecutorSummaryWrapper(info) } - } private class LiveExecutorStageSummary( diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 971d7e90fa7b8..77466b62ff6ed 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -22,9 +22,14 @@ import java.util.Date import scala.xml.{NodeSeq, Text} import com.fasterxml.jackson.annotation.JsonIgnoreProperties -import com.fasterxml.jackson.databind.annotation.JsonDeserialize +import com.fasterxml.jackson.core.{JsonGenerator, JsonParser} +import com.fasterxml.jackson.core.`type`.TypeReference +import com.fasterxml.jackson.databind.{DeserializationContext, JsonDeserializer, JsonSerializer, SerializerProvider} +import com.fasterxml.jackson.databind.annotation.{JsonDeserialize, JsonSerialize} import org.apache.spark.JobExecutionStatus +import org.apache.spark.executor.ExecutorMetrics +import org.apache.spark.metrics.ExecutorMetricType case class ApplicationInfo private[spark]( id: String, @@ -98,7 +103,10 @@ class ExecutorSummary private[spark]( val removeReason: Option[String], val executorLogs: Map[String, String], val memoryMetrics: Option[MemoryMetrics], - val blacklistedInStages: Set[Int]) + val blacklistedInStages: Set[Int], + @JsonSerialize(using = classOf[ExecutorMetricsJsonSerializer]) + @JsonDeserialize(using = classOf[ExecutorMetricsJsonDeserializer]) + val peakMemoryMetrics: Option[ExecutorMetrics]) class MemoryMetrics private[spark]( val usedOnHeapStorageMemory: Long, @@ -106,6 +114,33 @@ class MemoryMetrics private[spark]( val totalOnHeapStorageMemory: Long, val totalOffHeapStorageMemory: Long) +/** deserializer for peakMemoryMetrics: convert map to ExecutorMetrics */ +private[spark] class ExecutorMetricsJsonDeserializer + extends JsonDeserializer[Option[ExecutorMetrics]] { + override def deserialize( + jsonParser: JsonParser, + deserializationContext: DeserializationContext): Option[ExecutorMetrics] = { + val metricsMap = jsonParser.readValueAs[Option[Map[String, Long]]]( + new TypeReference[Option[Map[String, java.lang.Long]]] {}) + metricsMap.map(metrics => new ExecutorMetrics(metrics)) + } +} +/** serializer for peakMemoryMetrics: convert ExecutorMetrics to map with metric name as key */ +private[spark] class ExecutorMetricsJsonSerializer + extends JsonSerializer[Option[ExecutorMetrics]] { + override def serialize( + metrics: Option[ExecutorMetrics], + jsonGenerator: JsonGenerator, + serializerProvider: SerializerProvider): Unit = { + metrics.foreach { m: ExecutorMetrics => + val metricsMap = ExecutorMetricType.values.map { metricType => + metricType.name -> m.getMetricValue(metricType) + }.toMap + jsonGenerator.writeObject(metricsMap) + } + } +} + class JobData private[spark]( val jobId: Int, val name: String, diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 50c6461373dee..0cd8612b8fd1c 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -31,6 +31,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark._ import org.apache.spark.executor._ +import org.apache.spark.metrics.ExecutorMetricType import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -98,6 +99,8 @@ private[spark] object JsonProtocol { logStartToJson(logStart) case metricsUpdate: SparkListenerExecutorMetricsUpdate => executorMetricsUpdateToJson(metricsUpdate) + case stageExecutorMetrics: SparkListenerStageExecutorMetrics => + stageExecutorMetricsToJson(stageExecutorMetrics) case blockUpdate: SparkListenerBlockUpdated => blockUpdateToJson(blockUpdate) case _ => parse(mapper.writeValueAsString(event)) @@ -236,6 +239,7 @@ private[spark] object JsonProtocol { def executorMetricsUpdateToJson(metricsUpdate: SparkListenerExecutorMetricsUpdate): JValue = { val execId = metricsUpdate.execId val accumUpdates = metricsUpdate.accumUpdates + val executorMetrics = metricsUpdate.executorUpdates.map(executorMetricsToJson(_)) ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.metricsUpdate) ~ ("Executor ID" -> execId) ~ ("Metrics Updated" -> accumUpdates.map { case (taskId, stageId, stageAttemptId, updates) => @@ -243,7 +247,16 @@ private[spark] object JsonProtocol { ("Stage ID" -> stageId) ~ ("Stage Attempt ID" -> stageAttemptId) ~ ("Accumulator Updates" -> JArray(updates.map(accumulableInfoToJson).toList)) - }) + }) ~ + ("Executor Metrics Updated" -> executorMetrics) + } + + def stageExecutorMetricsToJson(metrics: SparkListenerStageExecutorMetrics): JValue = { + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.stageExecutorMetrics) ~ + ("Executor ID" -> metrics.execId) ~ + ("Stage ID" -> metrics.stageId) ~ + ("Stage Attempt ID" -> metrics.stageAttemptId) ~ + ("Executor Metrics" -> executorMetricsToJson(metrics.executorMetrics)) } def blockUpdateToJson(blockUpdate: SparkListenerBlockUpdated): JValue = { @@ -379,6 +392,14 @@ private[spark] object JsonProtocol { ("Updated Blocks" -> updatedBlocks) } + /** Convert executor metrics to JSON. */ + def executorMetricsToJson(executorMetrics: ExecutorMetrics): JValue = { + val metrics = ExecutorMetricType.values.map{ metricType => + JField(metricType.name, executorMetrics.getMetricValue(metricType)) + } + JObject(metrics: _*) + } + def taskEndReasonToJson(taskEndReason: TaskEndReason): JValue = { val reason = Utils.getFormattedClassName(taskEndReason) val json: JObject = taskEndReason match { @@ -531,6 +552,7 @@ private[spark] object JsonProtocol { val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved) val logStart = Utils.getFormattedClassName(SparkListenerLogStart) val metricsUpdate = Utils.getFormattedClassName(SparkListenerExecutorMetricsUpdate) + val stageExecutorMetrics = Utils.getFormattedClassName(SparkListenerStageExecutorMetrics) val blockUpdate = Utils.getFormattedClassName(SparkListenerBlockUpdated) } @@ -555,6 +577,7 @@ private[spark] object JsonProtocol { case `executorRemoved` => executorRemovedFromJson(json) case `logStart` => logStartFromJson(json) case `metricsUpdate` => executorMetricsUpdateFromJson(json) + case `stageExecutorMetrics` => stageExecutorMetricsFromJson(json) case `blockUpdate` => blockUpdateFromJson(json) case other => mapper.readValue(compact(render(json)), Utils.classForName(other)) .asInstanceOf[SparkListenerEvent] @@ -585,6 +608,15 @@ private[spark] object JsonProtocol { SparkListenerTaskGettingResult(taskInfo) } + /** Extract the executor metrics from JSON. */ + def executorMetricsFromJson(json: JValue): ExecutorMetrics = { + val metrics = + ExecutorMetricType.values.map { metric => + metric.name -> jsonOption(json \ metric.name).map(_.extract[Long]).getOrElse(0L) + }.toMap + new ExecutorMetrics(metrics) + } + def taskEndFromJson(json: JValue): SparkListenerTaskEnd = { val stageId = (json \ "Stage ID").extract[Int] val stageAttemptId = @@ -691,7 +723,18 @@ private[spark] object JsonProtocol { (json \ "Accumulator Updates").extract[List[JValue]].map(accumulableInfoFromJson) (taskId, stageId, stageAttemptId, updates) } - SparkListenerExecutorMetricsUpdate(execInfo, accumUpdates) + val executorUpdates = jsonOption(json \ "Executor Metrics Updated").map { + executorUpdate => executorMetricsFromJson(executorUpdate) + } + SparkListenerExecutorMetricsUpdate(execInfo, accumUpdates, executorUpdates) + } + + def stageExecutorMetricsFromJson(json: JValue): SparkListenerStageExecutorMetrics = { + val execId = (json \ "Executor ID").extract[String] + val stageId = (json \ "Stage ID").extract[Int] + val stageAttemptId = (json \ "Stage Attempt ID").extract[Int] + val executorMetrics = executorMetricsFromJson(json \ "Executor Metrics") + SparkListenerStageExecutorMetrics(execId, stageId, stageAttemptId, executorMetrics) } def blockUpdateFromJson(json: JValue): SparkListenerBlockUpdated = { diff --git a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json index 4fecf84db65a2..eea6f595efd2a 100644 --- a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json @@ -1,4 +1,19 @@ [ { + "id" : "application_1506645932520_24630151", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-04-19T23:54:42.734GMT", + "endTime" : "2018-04-19T23:56:29.134GMT", + "lastUpdated" : "", + "duration" : 106400, + "sparkUser" : "edlu", + "completed" : true, + "appSparkVersion" : "2.4.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1524182082734, + "endTimeEpoch" : 1524182189134 + } ] +}, { "id" : "application_1516285256255_0012", "name" : "Spark shell", "attempts" : [ { diff --git a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json index 4fecf84db65a2..7bc7f31be097b 100644 --- a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json @@ -1,4 +1,19 @@ [ { + "id" : "application_1506645932520_24630151", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-04-19T23:54:42.734GMT", + "endTime" : "2018-04-19T23:56:29.134GMT", + "lastUpdated" : "", + "duration" : 106400, + "sparkUser" : "edlu", + "completed" : true, + "appSparkVersion" : "2.4.0-SNAPSHOT", + "startTimeEpoch" : 1524182082734, + "endTimeEpoch" : 1524182189134, + "lastUpdatedEpoch" : 0 + } ] +}, { "id" : "application_1516285256255_0012", "name" : "Spark shell", "attempts" : [ { diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_metrics_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_metrics_json_expectation.json new file mode 100644 index 0000000000000..9bf2086cc8e72 --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_metrics_json_expectation.json @@ -0,0 +1,314 @@ +[ { + "id" : "driver", + "hostPort" : "node0033.grid.company.com:60749", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 0, + "maxTasks" : 0, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 0, + "totalTasks" : 0, + "totalDuration" : 0, + "totalGCTime" : 0, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : false, + "maxMemory" : 1043437977, + "addTime" : "2018-04-19T23:55:05.107GMT", + "executorLogs" : { }, + "memoryMetrics" : { + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 1043437977, + "totalOffHeapStorageMemory" : 0 + }, + "blacklistedInStages" : [ ], + "peakMemoryMetrics" : { + "OnHeapStorageMemory" : 905801, + "JVMOffHeapMemory" : 205304696, + "OffHeapExecutionMemory" : 0, + "OnHeapUnifiedMemory" : 905801, + "OnHeapExecutionMemory" : 0, + "OffHeapUnifiedMemory" : 0, + "DirectPoolMemory" : 397602, + "MappedPoolMemory" : 0, + "JVMHeapMemory" : 629553808, + "OffHeapStorageMemory" : 0 + } +}, { + "id" : "7", + "hostPort" : "node6340.grid.company.com:5933", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 1, + "maxTasks" : 1, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 0, + "totalTasks" : 0, + "totalDuration" : 0, + "totalGCTime" : 0, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : false, + "maxMemory" : 956615884, + "addTime" : "2018-04-19T23:55:49.826GMT", + "executorLogs" : { + "stdout" : "http://node6340.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000009/edlu/stdout?start=-4096", + "stderr" : "http://node6340.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000009/edlu/stderr?start=-4096" + }, + "memoryMetrics" : { + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 956615884, + "totalOffHeapStorageMemory" : 0 + }, + "blacklistedInStages" : [ ] +}, { + "id" : "6", + "hostPort" : "node6644.grid.company.com:8445", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 1, + "maxTasks" : 1, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 0, + "totalTasks" : 0, + "totalDuration" : 0, + "totalGCTime" : 0, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : false, + "maxMemory" : 956615884, + "addTime" : "2018-04-19T23:55:47.549GMT", + "executorLogs" : { + "stdout" : "http://node6644.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000008/edlu/stdout?start=-4096", + "stderr" : "http://node6644.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000008/edlu/stderr?start=-4096" + }, + "memoryMetrics" : { + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 956615884, + "totalOffHeapStorageMemory" : 0 + }, + "blacklistedInStages" : [ ] +}, { + "id" : "5", + "hostPort" : "node2477.grid.company.com:20123", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 1, + "maxTasks" : 1, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 1, + "totalTasks" : 1, + "totalDuration" : 9252, + "totalGCTime" : 920, + "totalInputBytes" : 36838295, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 355051, + "isBlacklisted" : false, + "maxMemory" : 956615884, + "addTime" : "2018-04-19T23:55:43.160GMT", + "executorLogs" : { + "stdout" : "http://node2477.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000007/edlu/stdout?start=-4096", + "stderr" : "http://node2477.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000007/edlu/stderr?start=-4096" + }, + "memoryMetrics" : { + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 956615884, + "totalOffHeapStorageMemory" : 0 + }, + "blacklistedInStages" : [ ] +}, { + "id" : "4", + "hostPort" : "node4243.grid.company.com:16084", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 1, + "maxTasks" : 1, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 3, + "totalTasks" : 3, + "totalDuration" : 15645, + "totalGCTime" : 405, + "totalInputBytes" : 87272855, + "totalShuffleRead" : 438675, + "totalShuffleWrite" : 26773039, + "isBlacklisted" : false, + "maxMemory" : 956615884, + "addTime" : "2018-04-19T23:55:12.278GMT", + "executorLogs" : { + "stdout" : "http://node4243.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000006/edlu/stdout?start=-4096", + "stderr" : "http://node4243.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000006/edlu/stderr?start=-4096" + }, + "memoryMetrics" : { + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 956615884, + "totalOffHeapStorageMemory" : 0 + }, + "blacklistedInStages" : [ ], + "peakMemoryMetrics" : { + "OnHeapStorageMemory" : 63104457, + "JVMOffHeapMemory" : 95657456, + "OffHeapExecutionMemory" : 0, + "OnHeapUnifiedMemory" : 100853193, + "OnHeapExecutionMemory" : 37748736, + "OffHeapUnifiedMemory" : 0, + "DirectPoolMemory" : 126261, + "MappedPoolMemory" : 0, + "JVMHeapMemory" : 518613056, + "OffHeapStorageMemory" : 0 + } +}, { + "id" : "3", + "hostPort" : "node0998.grid.company.com:45265", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 1, + "maxTasks" : 1, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 1, + "totalTasks" : 1, + "totalDuration" : 14491, + "totalGCTime" : 342, + "totalInputBytes" : 50409514, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 31362123, + "isBlacklisted" : false, + "maxMemory" : 956615884, + "addTime" : "2018-04-19T23:55:12.088GMT", + "executorLogs" : { + "stdout" : "http://node0998.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000005/edlu/stdout?start=-4096", + "stderr" : "http://node0998.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000005/edlu/stderr?start=-4096" + }, + "memoryMetrics" : { + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 956615884, + "totalOffHeapStorageMemory" : 0 + }, + "blacklistedInStages" : [ ], + "peakMemoryMetrics" : { + "OnHeapStorageMemory" : 69535048, + "JVMOffHeapMemory" : 90709624, + "OffHeapExecutionMemory" : 0, + "OnHeapUnifiedMemory" : 69535048, + "OnHeapExecutionMemory" : 0, + "OffHeapUnifiedMemory" : 0, + "DirectPoolMemory" : 87796, + "MappedPoolMemory" : 0, + "JVMHeapMemory" : 726805712, + "OffHeapStorageMemory" : 0 + } +}, { + "id" : "2", + "hostPort" : "node4045.grid.company.com:29262", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 1, + "maxTasks" : 1, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 1, + "totalTasks" : 1, + "totalDuration" : 14113, + "totalGCTime" : 326, + "totalInputBytes" : 50423423, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 22950296, + "isBlacklisted" : false, + "maxMemory" : 956615884, + "addTime" : "2018-04-19T23:55:12.471GMT", + "executorLogs" : { + "stdout" : "http://node4045.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000004/edlu/stdout?start=-4096", + "stderr" : "http://node4045.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000004/edlu/stderr?start=-4096" + }, + "memoryMetrics" : { + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 956615884, + "totalOffHeapStorageMemory" : 0 + }, + "blacklistedInStages" : [ ], + "peakMemoryMetrics" : { + "OnHeapStorageMemory" : 58468944, + "JVMOffHeapMemory" : 91208368, + "OffHeapExecutionMemory" : 0, + "OnHeapUnifiedMemory" : 58468944, + "OnHeapExecutionMemory" : 0, + "OffHeapUnifiedMemory" : 0, + "DirectPoolMemory" : 87796, + "MappedPoolMemory" : 0, + "JVMHeapMemory" : 595946552, + "OffHeapStorageMemory" : 0 + } +}, { + "id" : "1", + "hostPort" : "node1404.grid.company.com:34043", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 1, + "maxTasks" : 1, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 3, + "totalTasks" : 3, + "totalDuration" : 15665, + "totalGCTime" : 471, + "totalInputBytes" : 98905018, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 20594744, + "isBlacklisted" : false, + "maxMemory" : 956615884, + "addTime" : "2018-04-19T23:55:11.695GMT", + "executorLogs" : { + "stdout" : "http://node1404.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000002/edlu/stdout?start=-4096", + "stderr" : "http://node1404.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000002/edlu/stderr?start=-4096" + }, + "memoryMetrics" : { + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 956615884, + "totalOffHeapStorageMemory" : 0 + }, + "blacklistedInStages" : [ ], + "peakMemoryMetrics" : { + "OnHeapStorageMemory" : 47962185, + "JVMOffHeapMemory" : 100519936, + "OffHeapExecutionMemory" : 0, + "OnHeapUnifiedMemory" : 47962185, + "OnHeapExecutionMemory" : 0, + "OffHeapUnifiedMemory" : 0, + "DirectPoolMemory" : 98230, + "MappedPoolMemory" : 0, + "JVMHeapMemory" : 755008624, + "OffHeapStorageMemory" : 0 + } +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json index 79950b0dc6486..9e1e65a358815 100644 --- a/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json @@ -1,4 +1,19 @@ [ { + "id" : "application_1506645932520_24630151", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-04-19T23:54:42.734GMT", + "endTime" : "2018-04-19T23:56:29.134GMT", + "lastUpdated" : "", + "duration" : 106400, + "sparkUser" : "edlu", + "completed" : true, + "appSparkVersion" : "2.4.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1524182082734, + "endTimeEpoch" : 1524182189134 + } ] +}, { "id" : "application_1516285256255_0012", "name" : "Spark shell", "attempts" : [ { @@ -28,19 +43,4 @@ "startTimeEpoch" : 1515492942372, "endTimeEpoch" : 1515493477606 } ] -}, { - "id" : "app-20161116163331-0000", - "name" : "Spark shell", - "attempts" : [ { - "startTime" : "2016-11-16T22:33:29.916GMT", - "endTime" : "2016-11-16T22:33:40.587GMT", - "lastUpdated" : "", - "duration" : 10671, - "sparkUser" : "jose", - "completed" : true, - "appSparkVersion" : "2.1.0-SNAPSHOT", - "lastUpdatedEpoch" : 0, - "startTimeEpoch" : 1479335609916, - "endTimeEpoch" : 1479335620587 - } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json index 7d60977dcd4fe..28c6bf1b3e01e 100644 --- a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json @@ -1,4 +1,19 @@ [ { + "id" : "application_1506645932520_24630151", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-04-19T23:54:42.734GMT", + "endTime" : "2018-04-19T23:56:29.134GMT", + "lastUpdated" : "", + "duration" : 106400, + "sparkUser" : "edlu", + "completed" : true, + "appSparkVersion" : "2.4.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1524182082734, + "endTimeEpoch" : 1524182189134 + } ] +}, { "id" : "application_1516285256255_0012", "name" : "Spark shell", "attempts" : [ { diff --git a/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json index dfbfd8aedcc23..f547b79f47e1a 100644 --- a/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json @@ -1,4 +1,19 @@ [ { + "id" : "application_1506645932520_24630151", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-04-19T23:54:42.734GMT", + "endTime" : "2018-04-19T23:56:29.134GMT", + "lastUpdated" : "", + "duration" : 106400, + "sparkUser" : "edlu", + "completed" : true, + "appSparkVersion" : "2.4.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1524182082734, + "endTimeEpoch" : 1524182189134 + } ] +}, { "id" : "application_1516285256255_0012", "name" : "Spark shell", "attempts" : [ { @@ -101,4 +116,4 @@ "startTimeEpoch" : 1430917380880, "endTimeEpoch" : 1430917380890 } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/spark-events/application_1506645932520_24630151 b/core/src/test/resources/spark-events/application_1506645932520_24630151 new file mode 100644 index 0000000000000..c48ed741c56e0 --- /dev/null +++ b/core/src/test/resources/spark-events/application_1506645932520_24630151 @@ -0,0 +1,63 @@ +{"Event":"SparkListenerLogStart","Spark Version":"2.4.0-SNAPSHOT"} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"driver","Host":"node0033.grid.company.com","Port":60749},"Maximum Memory":1043437977,"Timestamp":1524182105107,"Maximum Onheap Memory":1043437977,"Maximum Offheap Memory":0} +{"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/usr/java/jdk1.8.0_31/jre","Java Version":"1.8.0_31 (Oracle Corporation)","Scala Version":"version 2.11.8"},"Spark Properties":{"spark.jars.ivySettings":"/export/apps/spark/commonconf/ivysettings.xml","spark.serializer":"org.apache.spark.serializer.KryoSerializer","spark.driver.host":"node0033.grid.company.com","spark.dynamicAllocation.sustainedSchedulerBacklogTimeout":"5","spark.eventLog.enabled":"true","spark.ui.port":"0","spark.driver.port":"57705","spark.shuffle.service.enabled":"true","spark.ui.acls.enable":"true","spark.reducer.maxSizeInFlight":"48m","spark.yarn.queue":"spark_default","spark.repl.class.uri":"spark://node0033.grid.company.com:57705/classes","spark.jars":"","spark.yarn.historyServer.address":"clustersh01.grid.company.com:18080","spark.memoryOverhead.multiplier.percent":"10","spark.repl.class.outputDir":"/grid/a/mapred/tmp/spark-21b68b4b-c1db-460e-a228-b87545d870f1/repl-58778a76-04c1-434d-bfb7-9a9b83afe718","spark.dynamicAllocation.cachedExecutorIdleTimeout":"1200","spark.yarn.access.namenodes":"hdfs://clusternn02.grid.company.com:9000","spark.app.name":"Spark shell","spark.dynamicAllocation.schedulerBacklogTimeout":"5","spark.yarn.security.credentials.hive.enabled":"false","spark.yarn.am.cores":"1","spark.memoryOverhead.min":"384","spark.scheduler.mode":"FIFO","spark.driver.memory":"2G","spark.executor.instances":"4","spark.isolated.classloader.additional.classes.prefix":"com_company_","spark.logConf":"true","spark.ui.showConsoleProgress":"true","spark.user.priority.jars":"*********(redacted)","spark.isolated.classloader":"true","spark.sql.sources.schemaStringLengthThreshold":"40000","spark.yarn.secondary.jars":"spark-avro_2.11-3.2.0.21.jar,grid-topology-1.0.jar","spark.reducer.maxBlocksInFlightPerAddress":"100","spark.dynamicAllocation.maxExecutors":"900","spark.yarn.appMasterEnv.LD_LIBRARY_PATH":"/export/apps/hadoop/latest/lib/native","spark.executor.id":"driver","spark.yarn.am.memory":"2G","spark.driver.cores":"1","spark.search.packages":"com.company.dali:dali-data-spark,com.company.spark-common:spark-common","spark.min.mem.vore.ratio":"5","spark.sql.sources.partitionOverwriteMode":"DYNAMIC","spark.submit.deployMode":"client","spark.yarn.maxAppAttempts":"1","spark.master":"yarn","spark.default.packages":"com.company.dali:dali-data-spark:8.+?classifier=all,com.company.spark-common:spark-common_2.10:0.+?","spark.isolated.classloader.default.jar":"*dali-data-spark*","spark.authenticate":"true","spark.eventLog.usexattr":"true","spark.ui.filters":"org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter","spark.executor.memory":"2G","spark.home":"/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51","spark.reducer.maxReqsInFlight":"10","spark.eventLog.dir":"hdfs://clusternn02.grid.company.com:9000/system/spark-history","spark.dynamicAllocation.enabled":"true","spark.sql.catalogImplementation":"hive","spark.isolated.classes":"org.apache.hadoop.hive.ql.io.CombineHiveInputFormat$CombineHiveInputSplit","spark.eventLog.compress":"true","spark.executor.cores":"1","spark.version":"2.1.0","spark.driver.appUIAddress":"http://node0033.grid.company.com:8364","spark.repl.local.jars":"file:///export/home/edlu/spark-avro_2.11-3.2.0.21.jar,file:///export/apps/hadoop/site/lib/grid-topology-1.0.jar","spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.param.PROXY_HOSTS":"clusterwp01.grid.company.com","spark.min.memory-gb.size":"10","spark.dynamicAllocation.minExecutors":"1","spark.dynamicAllocation.initialExecutors":"3","spark.expressionencoder.org.apache.avro.specific.SpecificRecord":"com.databricks.spark.avro.AvroEncoder$","spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.param.PROXY_URI_BASES":"http://clusterwp01.grid.company.com:8080/proxy/application_1506645932520_24630151","spark.executorEnv.LD_LIBRARY_PATH":"/export/apps/hadoop/latest/lib/native","spark.dynamicAllocation.executorIdleTimeout":"150","spark.shell.auto.node.labeling":"true","spark.yarn.dist.jars":"file:///export/home/edlu/spark-avro_2.11-3.2.0.21.jar,file:///export/apps/hadoop/site/lib/grid-topology-1.0.jar","spark.app.id":"application_1506645932520_24630151","spark.ui.view.acls":"*"},"System Properties":{"java.io.tmpdir":"/tmp","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle Corporation","java.vm.specification.version":"1.8","user.home":"*********(redacted)","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","sun.arch.data.model":"64","sun.boot.library.path":"/usr/java/jdk1.8.0_31/jre/lib/amd64","user.dir":"*********(redacted)","java.library.path":"/usr/java/packages/lib/amd64:/usr/lib64:/lib64:/lib:/usr/lib","sun.cpu.isalist":"","os.arch":"amd64","java.vm.version":"25.31-b07","java.endorsed.dirs":"/usr/java/jdk1.8.0_31/jre/lib/endorsed","java.runtime.version":"1.8.0_31-b13","java.vm.info":"mixed mode","java.ext.dirs":"/usr/java/jdk1.8.0_31/jre/lib/ext:/usr/java/packages/lib/ext","java.runtime.name":"Java(TM) SE Runtime Environment","file.separator":"/","java.class.version":"52.0","scala.usejavacp":"true","java.specification.name":"Java Platform API Specification","sun.boot.class.path":"/usr/java/jdk1.8.0_31/jre/lib/resources.jar:/usr/java/jdk1.8.0_31/jre/lib/rt.jar:/usr/java/jdk1.8.0_31/jre/lib/sunrsasign.jar:/usr/java/jdk1.8.0_31/jre/lib/jsse.jar:/usr/java/jdk1.8.0_31/jre/lib/jce.jar:/usr/java/jdk1.8.0_31/jre/lib/charsets.jar:/usr/java/jdk1.8.0_31/jre/lib/jfr.jar:/usr/java/jdk1.8.0_31/jre/classes","file.encoding":"UTF-8","user.timezone":"*********(redacted)","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"2.6.32-504.16.2.el6.x86_64","sun.os.patch.level":"unknown","java.vm.specification.vendor":"Oracle Corporation","user.country":"*********(redacted)","sun.jnu.encoding":"UTF-8","user.language":"*********(redacted)","java.vendor.url":"*********(redacted)","java.awt.printerjob":"sun.print.PSPrinterJob","java.awt.graphicsenv":"sun.awt.X11GraphicsEnvironment","awt.toolkit":"sun.awt.X11.XToolkit","os.name":"Linux","java.vm.vendor":"Oracle Corporation","java.vendor.url.bug":"*********(redacted)","user.name":"*********(redacted)","java.vm.name":"Java HotSpot(TM) 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --master yarn --deploy-mode client --class org.apache.spark.repl.Main --name Spark shell --jars /export/home/edlu/spark-avro_2.11-3.2.0.21.jar,/export/apps/hadoop/site/lib/grid-topology-1.0.jar --num-executors 4 spark-shell","java.home":"/usr/java/jdk1.8.0_31/jre","java.version":"1.8.0_31","sun.io.unicode.encoding":"UnicodeLittle"},"Classpath Entries":{"/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/guice-servlet-3.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jackson-mapper-asl-1.9.13.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/derby-10.12.1.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/htrace-core-3.0.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/api-asn1-api-1.0.0-M20.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/scala-reflect-2.11.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/datanucleus-rdbms-3.2.9.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-graphx_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/api-util-1.0.0-M20.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-yarn-client-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/base64-2.3.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-auth-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/validation-api-1.1.0.Final.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hk2-utils-2.4.0-b34.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/zstd-jni-1.3.2-2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-yarn-api-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/objenesis-2.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/conf/":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/httpclient-4.5.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/kryo-shaded-3.0.3.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/scala-library-2.11.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-net-3.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/xz-1.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/json4s-jackson_2.11-3.5.3.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/javax.servlet-api-3.1.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jersey-server-1.9.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jackson-annotations-2.6.7.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/parquet-hadoop-1.8.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/activation-1.1.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spire_2.11-0.13.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/arpack_combined_all-0.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/libthrift-0.9.3.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/aircompressor-0.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/parquet-jackson-1.8.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hk2-api-2.4.0-b34.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/asm-3.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/apacheds-kerberos-codec-2.0.0-M15.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-hive_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/ivy-2.4.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/javax.inject-2.4.0-b34.jar":"System Classpath","/export/apps/hadoop/site/etc/hadoop/":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/snappy-java-1.1.7.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/arrow-format-0.8.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/netty-all-4.1.17.Final.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/avro-ipc-1.7.7.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/xmlenc-0.52.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jdo-api-3.0.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/curator-client-2.7.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/antlr-runtime-3.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/pyrolite-4.13.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/scala-xml_2.11-1.0.5.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-catalyst_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-collections-3.2.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/slf4j-api-1.7.16.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/stream-2.7.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/parquet-format-2.3.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/arrow-vector-0.8.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-yarn-server-web-proxy-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/htrace-core-3.1.0-incubating.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-sketch_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jersey-common-2.22.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hppc-0.7.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jackson-core-asl-1.9.13.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-sql_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/univocity-parsers-2.5.9.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-math3-3.4.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-compiler-3.0.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-beanutils-1.7.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/java-xmlbuilder-1.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/javax.inject-1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-annotations-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/netty-3.9.9.Final.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/aopalliance-repackaged-2.4.0-b34.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/zookeeper-3.4.6.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/guice-3.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/scala-compiler-2.11.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/eigenbase-properties-1.1.5.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/aopalliance-1.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-yarn_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/JavaEWAH-0.3.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jsr305-1.3.9.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/libfb303-0.9.3.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/javax.annotation-api-1.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-yarn-server-common-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-digester-1.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/metrics-jvm-3.1.5.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/curator-framework-2.7.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/javax.ws.rs-api-2.0.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/paranamer-2.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/janino-3.0.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-mapreduce-client-core-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jersey-server-2.22.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/orc-core-1.4.3-nohive.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jsch-0.1.42.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/calcite-linq4j-1.2.0-incubating.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-unsafe_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-codec-1.10.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jtransforms-2.4.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/lz4-java-1.4.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/datanucleus-core-3.2.10.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/flatbuffers-1.2.0-3f79e055.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hive-exec-1.2.1.spark2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/avro-mapred-1.7.7-hadoop2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/stax-api-1.0.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/core-1.1.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/leveldbjni-all-1.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/datanucleus-api-jdo-3.2.6.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jackson-databind-2.6.7.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-dbcp-1.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jackson-module-scala_2.11-2.6.7.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-lang3-3.5.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spire-macros_2.11-0.13.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jackson-module-paranamer-2.7.9.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/slf4j-log4j12-1.7.16.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/chill-java-0.8.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jodd-core-3.5.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-pool-1.5.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/osgi-resource-locator-1.0.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/minlog-1.3.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-mapreduce-client-common-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/gson-2.2.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/py4j-0.10.6.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-streaming_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jackson-core-2.6.7.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/calcite-avatica-1.2.0-incubating.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/machinist_2.11-0.6.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/avro-1.7.7.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-beanutils-core-1.8.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/apacheds-i18n-2.0.0-M15.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jersey-media-jaxb-2.22.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/snappy-0.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-mapreduce-client-app-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/parquet-hadoop-bundle-1.6.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jul-to-slf4j-1.7.16.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/metrics-graphite-3.1.5.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jcl-over-slf4j-1.7.16.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/metrics-core-3.1.5.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-mllib-local_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/arrow-memory-0.8.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/breeze_2.11-0.13.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jersey-guava-2.22.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-client-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/xercesImpl-2.9.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-tags_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/javolution-5.5.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jetty-6.1.26.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/joda-time-2.9.3.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/antlr-2.7.7.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-mapreduce-client-jobclient-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-lang-2.6.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/compress-lzf-1.0.3.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-crypto-1.0.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jersey-core-1.9.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/curator-recipes-2.7.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hk2-locator-2.4.0-b34.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/guava-14.0.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jackson-jaxrs-1.9.13.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-core_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jetty-sslengine-6.1.26.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-network-common_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-launcher_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/json4s-ast_2.11-3.5.3.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/antlr4-runtime-4.7.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jetty-util-6.1.26.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jaxb-api-2.2.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-io-2.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/parquet-encoding-1.8.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/httpcore-4.4.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/macro-compat_2.11-1.1.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jackson-xc-1.9.13.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/xbean-asm5-shaded-4.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/breeze-macros_2.11-0.13.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/protobuf-java-2.5.0.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/json4s-scalap_2.11-3.5.3.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-mllib_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-configuration-1.6.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-compress-1.4.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/json4s-core_2.11-3.5.3.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/orc-mapreduce-1.4.3-nohive.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/ST4-4.0.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/calcite-core-1.2.0-incubating.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-mapreduce-client-shuffle-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-common-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-repl_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jersey-container-servlet-2.22.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/opencsv-2.3.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-logging-1.1.3.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/shapeless_2.11-2.3.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-cli-1.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jersey-client-2.22.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-yarn-common-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hadoop-hdfs-2.7.4.51.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/log4j-1.2.17.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/parquet-column-1.8.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/hive-metastore-1.2.1.spark2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/RoaringBitmap-0.5.11.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/chill_2.11-0.8.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jersey-container-servlet-core-2.22.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/stringtemplate-3.2.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/parquet-common-1.8.2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-network-shuffle_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/spark-kvstore_2.11-2.4.0-SNAPSHOT.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/stax-api-1.0-2.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jta-1.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/javassist-3.18.1-GA.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/commons-httpclient-3.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jets3t-0.9.4.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/apache-log4j-extras-1.2.17.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/metrics-json-3.1.5.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/bcprov-jdk15on-1.58.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/oro-2.0.8.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/bonecp-0.8.0.RELEASE.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/jsp-api-2.1.jar":"System Classpath","/export/home/edlu/spark-2.4.0-SNAPSHOT-bin-2.7.4.51/jars/scala-parser-combinators_2.11-1.0.4.jar":"System Classpath"}} +{"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"application_1506645932520_24630151","Timestamp":1524182082734,"User":"edlu"} +{"Event":"SparkListenerExecutorAdded","Timestamp":1524182111695,"Executor ID":"1","Executor Info":{"Host":"node1404.grid.company.com","Total Cores":1,"Log Urls":{"stdout":"http://node1404.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000002/edlu/stdout?start=-4096","stderr":"http://node1404.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000002/edlu/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"1","Host":"node1404.grid.company.com","Port":34043},"Maximum Memory":956615884,"Timestamp":1524182111795,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1524182112088,"Executor ID":"3","Executor Info":{"Host":"node0998.grid.company.com","Total Cores":1,"Log Urls":{"stdout":"http://node0998.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000005/edlu/stdout?start=-4096","stderr":"http://node0998.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000005/edlu/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"3","Host":"node0998.grid.company.com","Port":45265},"Maximum Memory":956615884,"Timestamp":1524182112208,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1524182112278,"Executor ID":"4","Executor Info":{"Host":"node4243.grid.company.com","Total Cores":1,"Log Urls":{"stdout":"http://node4243.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000006/edlu/stdout?start=-4096","stderr":"http://node4243.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000006/edlu/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"4","Host":"node4243.grid.company.com","Port":16084},"Maximum Memory":956615884,"Timestamp":1524182112408,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1524182112471,"Executor ID":"2","Executor Info":{"Host":"node4045.grid.company.com","Total Cores":1,"Log Urls":{"stdout":"http://node4045.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000004/edlu/stdout?start=-4096","stderr":"http://node4045.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000004/edlu/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"2","Host":"node4045.grid.company.com","Port":29262},"Maximum Memory":956615884,"Timestamp":1524182112578,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart","executionId":0,"description":"createOrReplaceTempView at :40","details":"org.apache.spark.sql.Dataset.createOrReplaceTempView(Dataset.scala:3033)\n$line44.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:40)\n$line44.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:45)\n$line44.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:47)\n$line44.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:49)\n$line44.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:51)\n$line44.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:53)\n$line44.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:55)\n$line44.$read$$iw$$iw$$iw$$iw$$iw.(:57)\n$line44.$read$$iw$$iw$$iw$$iw.(:59)\n$line44.$read$$iw$$iw$$iw.(:61)\n$line44.$read$$iw$$iw.(:63)\n$line44.$read$$iw.(:65)\n$line44.$read.(:67)\n$line44.$read$.(:71)\n$line44.$read$.()\n$line44.$eval$.$print$lzycompute(:7)\n$line44.$eval$.$print(:6)\n$line44.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","physicalPlanDescription":"== Parsed Logical Plan ==\nCreateViewCommand `apps`, false, true, LocalTempView\n +- AnalysisBarrier\n +- Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, cast(endTime#6 as date) AS endDate#28]\n +- Relation[appId#0,attemptId#1,name#2,mode#3,completed#4,duration#5L,endTime#6,endTimeEpoch#7L,lastUpdated#8,lastUpdatedEpoch#9L,sparkUser#10,startTime#11,startTimeEpoch#12L,appSparkVersion#13] avro\n\n== Analyzed Logical Plan ==\nCreateViewCommand `apps`, false, true, LocalTempView\n +- AnalysisBarrier\n +- Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, cast(endTime#6 as date) AS endDate#28]\n +- Relation[appId#0,attemptId#1,name#2,mode#3,completed#4,duration#5L,endTime#6,endTimeEpoch#7L,lastUpdated#8,lastUpdatedEpoch#9L,sparkUser#10,startTime#11,startTimeEpoch#12L,appSparkVersion#13] avro\n\n== Optimized Logical Plan ==\nCreateViewCommand `apps`, false, true, LocalTempView\n +- AnalysisBarrier\n +- Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, cast(endTime#6 as date) AS endDate#28]\n +- Relation[appId#0,attemptId#1,name#2,mode#3,completed#4,duration#5L,endTime#6,endTimeEpoch#7L,lastUpdated#8,lastUpdatedEpoch#9L,sparkUser#10,startTime#11,startTimeEpoch#12L,appSparkVersion#13] avro\n\n== Physical Plan ==\nExecute CreateViewCommand\n +- CreateViewCommand `apps`, false, true, LocalTempView\n +- AnalysisBarrier\n +- Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, cast(endTime#6 as date) AS endDate#28]\n +- Relation[appId#0,attemptId#1,name#2,mode#3,completed#4,duration#5L,endTime#6,endTimeEpoch#7L,lastUpdated#8,lastUpdatedEpoch#9L,sparkUser#10,startTime#11,startTimeEpoch#12L,appSparkVersion#13] avro","sparkPlanInfo":{"nodeName":"Execute CreateViewCommand","simpleString":"Execute CreateViewCommand","children":[],"metrics":[]},"time":1524182125829} +{"Event":"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd","executionId":0,"time":1524182125832} +{"Event":"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart","executionId":1,"description":"createOrReplaceTempView at :40","details":"org.apache.spark.sql.Dataset.createOrReplaceTempView(Dataset.scala:3033)\n$line48.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:40)\n$line48.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:45)\n$line48.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:47)\n$line48.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:49)\n$line48.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:51)\n$line48.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:53)\n$line48.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:55)\n$line48.$read$$iw$$iw$$iw$$iw$$iw.(:57)\n$line48.$read$$iw$$iw$$iw$$iw.(:59)\n$line48.$read$$iw$$iw$$iw.(:61)\n$line48.$read$$iw$$iw.(:63)\n$line48.$read$$iw.(:65)\n$line48.$read.(:67)\n$line48.$read$.(:71)\n$line48.$read$.()\n$line48.$eval$.$print$lzycompute(:7)\n$line48.$eval$.$print(:6)\n$line48.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","physicalPlanDescription":"== Parsed Logical Plan ==\nCreateViewCommand `sys_props`, false, true, LocalTempView\n +- AnalysisBarrier\n +- Aggregate [appId#137], [appId#137, first(if ((key#148 <=> azkaban.link.workflow.url)) value#149 else cast(null as string), true) AS azkaban.link.workflow.url#159, first(if ((key#148 <=> azkaban.link.execution.url)) value#149 else cast(null as string), true) AS azkaban.link.execution.url#161, first(if ((key#148 <=> azkaban.link.job.url)) value#149 else cast(null as string), true) AS azkaban.link.job.url#163, first(if ((key#148 <=> user.name)) value#149 else cast(null as string), true) AS user.name#165]\n +- Project [appId#137, col#145.key AS key#148, col#145.value AS value#149]\n +- Project [appId#137, col#145]\n +- Generate explode(systemProperties#135), false, [col#145]\n +- Relation[runtime#133,sparkProperties#134,systemProperties#135,classpathEntries#136,appId#137,attemptId#138] avro\n\n== Analyzed Logical Plan ==\nCreateViewCommand `sys_props`, false, true, LocalTempView\n +- AnalysisBarrier\n +- Aggregate [appId#137], [appId#137, first(if ((key#148 <=> azkaban.link.workflow.url)) value#149 else cast(null as string), true) AS azkaban.link.workflow.url#159, first(if ((key#148 <=> azkaban.link.execution.url)) value#149 else cast(null as string), true) AS azkaban.link.execution.url#161, first(if ((key#148 <=> azkaban.link.job.url)) value#149 else cast(null as string), true) AS azkaban.link.job.url#163, first(if ((key#148 <=> user.name)) value#149 else cast(null as string), true) AS user.name#165]\n +- Project [appId#137, col#145.key AS key#148, col#145.value AS value#149]\n +- Project [appId#137, col#145]\n +- Generate explode(systemProperties#135), false, [col#145]\n +- Relation[runtime#133,sparkProperties#134,systemProperties#135,classpathEntries#136,appId#137,attemptId#138] avro\n\n== Optimized Logical Plan ==\nCreateViewCommand `sys_props`, false, true, LocalTempView\n +- AnalysisBarrier\n +- Aggregate [appId#137], [appId#137, first(if ((key#148 <=> azkaban.link.workflow.url)) value#149 else cast(null as string), true) AS azkaban.link.workflow.url#159, first(if ((key#148 <=> azkaban.link.execution.url)) value#149 else cast(null as string), true) AS azkaban.link.execution.url#161, first(if ((key#148 <=> azkaban.link.job.url)) value#149 else cast(null as string), true) AS azkaban.link.job.url#163, first(if ((key#148 <=> user.name)) value#149 else cast(null as string), true) AS user.name#165]\n +- Project [appId#137, col#145.key AS key#148, col#145.value AS value#149]\n +- Project [appId#137, col#145]\n +- Generate explode(systemProperties#135), false, [col#145]\n +- Relation[runtime#133,sparkProperties#134,systemProperties#135,classpathEntries#136,appId#137,attemptId#138] avro\n\n== Physical Plan ==\nExecute CreateViewCommand\n +- CreateViewCommand `sys_props`, false, true, LocalTempView\n +- AnalysisBarrier\n +- Aggregate [appId#137], [appId#137, first(if ((key#148 <=> azkaban.link.workflow.url)) value#149 else cast(null as string), true) AS azkaban.link.workflow.url#159, first(if ((key#148 <=> azkaban.link.execution.url)) value#149 else cast(null as string), true) AS azkaban.link.execution.url#161, first(if ((key#148 <=> azkaban.link.job.url)) value#149 else cast(null as string), true) AS azkaban.link.job.url#163, first(if ((key#148 <=> user.name)) value#149 else cast(null as string), true) AS user.name#165]\n +- Project [appId#137, col#145.key AS key#148, col#145.value AS value#149]\n +- Project [appId#137, col#145]\n +- Generate explode(systemProperties#135), false, [col#145]\n +- Relation[runtime#133,sparkProperties#134,systemProperties#135,classpathEntries#136,appId#137,attemptId#138] avro","sparkPlanInfo":{"nodeName":"Execute CreateViewCommand","simpleString":"Execute CreateViewCommand","children":[],"metrics":[]},"time":1524182128463} +{"Event":"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd","executionId":1,"time":1524182128463} +{"Event":"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart","executionId":2,"description":"show at :40","details":"org.apache.spark.sql.Dataset.show(Dataset.scala:691)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:40)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:45)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:47)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:49)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:51)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:53)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:55)\n$line50.$read$$iw$$iw$$iw$$iw$$iw.(:57)\n$line50.$read$$iw$$iw$$iw$$iw.(:59)\n$line50.$read$$iw$$iw$$iw.(:61)\n$line50.$read$$iw$$iw.(:63)\n$line50.$read$$iw.(:65)\n$line50.$read.(:67)\n$line50.$read$.(:71)\n$line50.$read$.()\n$line50.$eval$.$print$lzycompute(:7)\n$line50.$eval$.$print(:6)\n$line50.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","physicalPlanDescription":"== Parsed Logical Plan ==\nGlobalLimit 21\n+- LocalLimit 21\n +- AnalysisBarrier\n +- Project [cast(appId#0 as string) AS appId#397, cast(attemptId#1 as string) AS attemptId#398, cast(name#2 as string) AS name#399, cast(mode#3 as string) AS mode#400, cast(completed#4 as string) AS completed#401, cast(duration#5L as string) AS duration#402, cast(endTime#6 as string) AS endTime#403, cast(endTimeEpoch#7L as string) AS endTimeEpoch#404, cast(lastUpdated#8 as string) AS lastUpdated#405, cast(lastUpdatedEpoch#9L as string) AS lastUpdatedEpoch#406, cast(sparkUser#10 as string) AS sparkUser#407, cast(startTime#11 as string) AS startTime#408, cast(startTimeEpoch#12L as string) AS startTimeEpoch#409, cast(appSparkVersion#13 as string) AS appSparkVersion#410, cast(endDate#28 as string) AS endDate#411, cast(azkaban.link.workflow.url#159 as string) AS azkaban.link.workflow.url#412, cast(azkaban.link.execution.url#161 as string) AS azkaban.link.execution.url#413, cast(azkaban.link.job.url#163 as string) AS azkaban.link.job.url#414, cast(user.name#165 as string) AS user.name#415]\n +- Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165]\n +- Join LeftOuter, (appId#0 = appId#137)\n :- Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, cast(endTime#6 as date) AS endDate#28]\n : +- Relation[appId#0,attemptId#1,name#2,mode#3,completed#4,duration#5L,endTime#6,endTimeEpoch#7L,lastUpdated#8,lastUpdatedEpoch#9L,sparkUser#10,startTime#11,startTimeEpoch#12L,appSparkVersion#13] avro\n +- Aggregate [appId#137], [appId#137, first(if ((key#148 <=> azkaban.link.workflow.url)) value#149 else cast(null as string), true) AS azkaban.link.workflow.url#159, first(if ((key#148 <=> azkaban.link.execution.url)) value#149 else cast(null as string), true) AS azkaban.link.execution.url#161, first(if ((key#148 <=> azkaban.link.job.url)) value#149 else cast(null as string), true) AS azkaban.link.job.url#163, first(if ((key#148 <=> user.name)) value#149 else cast(null as string), true) AS user.name#165]\n +- Project [appId#137, col#145.key AS key#148, col#145.value AS value#149]\n +- Project [appId#137, col#145]\n +- Generate explode(systemProperties#135), false, [col#145]\n +- Relation[runtime#133,sparkProperties#134,systemProperties#135,classpathEntries#136,appId#137,attemptId#138] avro\n\n== Analyzed Logical Plan ==\nappId: string, attemptId: string, name: string, mode: string, completed: string, duration: string, endTime: string, endTimeEpoch: string, lastUpdated: string, lastUpdatedEpoch: string, sparkUser: string, startTime: string, startTimeEpoch: string, appSparkVersion: string, endDate: string, azkaban.link.workflow.url: string, azkaban.link.execution.url: string, azkaban.link.job.url: string, user.name: string\nGlobalLimit 21\n+- LocalLimit 21\n +- Project [cast(appId#0 as string) AS appId#397, cast(attemptId#1 as string) AS attemptId#398, cast(name#2 as string) AS name#399, cast(mode#3 as string) AS mode#400, cast(completed#4 as string) AS completed#401, cast(duration#5L as string) AS duration#402, cast(endTime#6 as string) AS endTime#403, cast(endTimeEpoch#7L as string) AS endTimeEpoch#404, cast(lastUpdated#8 as string) AS lastUpdated#405, cast(lastUpdatedEpoch#9L as string) AS lastUpdatedEpoch#406, cast(sparkUser#10 as string) AS sparkUser#407, cast(startTime#11 as string) AS startTime#408, cast(startTimeEpoch#12L as string) AS startTimeEpoch#409, cast(appSparkVersion#13 as string) AS appSparkVersion#410, cast(endDate#28 as string) AS endDate#411, cast(azkaban.link.workflow.url#159 as string) AS azkaban.link.workflow.url#412, cast(azkaban.link.execution.url#161 as string) AS azkaban.link.execution.url#413, cast(azkaban.link.job.url#163 as string) AS azkaban.link.job.url#414, cast(user.name#165 as string) AS user.name#415]\n +- Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165]\n +- Join LeftOuter, (appId#0 = appId#137)\n :- Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, cast(endTime#6 as date) AS endDate#28]\n : +- Relation[appId#0,attemptId#1,name#2,mode#3,completed#4,duration#5L,endTime#6,endTimeEpoch#7L,lastUpdated#8,lastUpdatedEpoch#9L,sparkUser#10,startTime#11,startTimeEpoch#12L,appSparkVersion#13] avro\n +- Aggregate [appId#137], [appId#137, first(if ((key#148 <=> azkaban.link.workflow.url)) value#149 else cast(null as string), true) AS azkaban.link.workflow.url#159, first(if ((key#148 <=> azkaban.link.execution.url)) value#149 else cast(null as string), true) AS azkaban.link.execution.url#161, first(if ((key#148 <=> azkaban.link.job.url)) value#149 else cast(null as string), true) AS azkaban.link.job.url#163, first(if ((key#148 <=> user.name)) value#149 else cast(null as string), true) AS user.name#165]\n +- Project [appId#137, col#145.key AS key#148, col#145.value AS value#149]\n +- Project [appId#137, col#145]\n +- Generate explode(systemProperties#135), false, [col#145]\n +- Relation[runtime#133,sparkProperties#134,systemProperties#135,classpathEntries#136,appId#137,attemptId#138] avro\n\n== Optimized Logical Plan ==\nGlobalLimit 21\n+- LocalLimit 21\n +- Project [appId#0, attemptId#1, name#2, mode#3, cast(completed#4 as string) AS completed#401, cast(duration#5L as string) AS duration#402, endTime#6, cast(endTimeEpoch#7L as string) AS endTimeEpoch#404, lastUpdated#8, cast(lastUpdatedEpoch#9L as string) AS lastUpdatedEpoch#406, sparkUser#10, startTime#11, cast(startTimeEpoch#12L as string) AS startTimeEpoch#409, appSparkVersion#13, cast(endDate#28 as string) AS endDate#411, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165]\n +- InMemoryRelation [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)\n +- *(5) Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165]\n +- SortMergeJoin [appId#0], [appId#137], LeftOuter\n :- *(1) Sort [appId#0 ASC NULLS FIRST], false, 0\n : +- Exchange hashpartitioning(appId#0, 200)\n : +- InMemoryTableScan [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28]\n : +- InMemoryRelation [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)\n : +- *(1) Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, cast(endTime#6 as date) AS endDate#28]\n : +- *(1) FileScan avro [appId#0,attemptId#1,name#2,mode#3,completed#4,duration#5L,endTime#6,endTimeEpoch#7L,lastUpdated#8,lastUpdatedEpoch#9L,sparkUser#10,startTime#11,startTimeEpoch#12L,appSparkVersion#13] Batched: false, Format: com.databricks.spark.avro.DefaultSource@7006b304, Location: InMemoryFileIndex[hdfs://clusternn01.grid.company.com:9000/data/hadoopdev/sparkmetrics/ltx1-..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct azkaban.link.workflow.url)) value#149 else null, true), first(if ((key#148 <=> azkaban.link.execution.url)) value#149 else null, true), first(if ((key#148 <=> azkaban.link.job.url)) value#149 else null, true), first(if ((key#148 <=> user.name)) value#149 else null, true)], output=[appId#137, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165])\n +- *(4) Sort [appId#137 ASC NULLS FIRST], false, 0\n +- Exchange hashpartitioning(appId#137, 200)\n +- SortAggregate(key=[appId#137], functions=[partial_first(if ((key#148 <=> azkaban.link.workflow.url)) value#149 else null, true), partial_first(if ((key#148 <=> azkaban.link.execution.url)) value#149 else null, true), partial_first(if ((key#148 <=> azkaban.link.job.url)) value#149 else null, true), partial_first(if ((key#148 <=> user.name)) value#149 else null, true)], output=[appId#137, first#273, valueSet#274, first#275, valueSet#276, first#277, valueSet#278, first#279, valueSet#280])\n +- *(3) Sort [appId#137 ASC NULLS FIRST], false, 0\n +- *(3) Project [appId#137, col#145.key AS key#148, col#145.value AS value#149]\n +- Generate explode(systemProperties#135), [appId#137], false, [col#145]\n +- *(2) FileScan avro [systemProperties#135,appId#137] Batched: false, Format: com.databricks.spark.avro.DefaultSource@485d3d1, Location: InMemoryFileIndex[hdfs://clusternn01.grid.company.com:9000/data/hadoopdev/sparkmetrics/ltx1-..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct>,appId:string>\n\n== Physical Plan ==\nCollectLimit 21\n+- *(1) LocalLimit 21\n +- *(1) Project [appId#0, attemptId#1, name#2, mode#3, cast(completed#4 as string) AS completed#401, cast(duration#5L as string) AS duration#402, endTime#6, cast(endTimeEpoch#7L as string) AS endTimeEpoch#404, lastUpdated#8, cast(lastUpdatedEpoch#9L as string) AS lastUpdatedEpoch#406, sparkUser#10, startTime#11, cast(startTimeEpoch#12L as string) AS startTimeEpoch#409, appSparkVersion#13, cast(endDate#28 as string) AS endDate#411, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165]\n +- InMemoryTableScan [appId#0, appSparkVersion#13, attemptId#1, azkaban.link.execution.url#161, azkaban.link.job.url#163, azkaban.link.workflow.url#159, completed#4, duration#5L, endDate#28, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, mode#3, name#2, sparkUser#10, startTime#11, startTimeEpoch#12L, user.name#165]\n +- InMemoryRelation [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)\n +- *(5) Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165]\n +- SortMergeJoin [appId#0], [appId#137], LeftOuter\n :- *(1) Sort [appId#0 ASC NULLS FIRST], false, 0\n : +- Exchange hashpartitioning(appId#0, 200)\n : +- InMemoryTableScan [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28]\n : +- InMemoryRelation [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)\n : +- *(1) Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, cast(endTime#6 as date) AS endDate#28]\n : +- *(1) FileScan avro [appId#0,attemptId#1,name#2,mode#3,completed#4,duration#5L,endTime#6,endTimeEpoch#7L,lastUpdated#8,lastUpdatedEpoch#9L,sparkUser#10,startTime#11,startTimeEpoch#12L,appSparkVersion#13] Batched: false, Format: com.databricks.spark.avro.DefaultSource@7006b304, Location: InMemoryFileIndex[hdfs://clusternn01.grid.company.com:9000/data/hadoopdev/sparkmetrics/ltx1-..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct azkaban.link.workflow.url)) value#149 else null, true), first(if ((key#148 <=> azkaban.link.execution.url)) value#149 else null, true), first(if ((key#148 <=> azkaban.link.job.url)) value#149 else null, true), first(if ((key#148 <=> user.name)) value#149 else null, true)], output=[appId#137, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165])\n +- *(4) Sort [appId#137 ASC NULLS FIRST], false, 0\n +- Exchange hashpartitioning(appId#137, 200)\n +- SortAggregate(key=[appId#137], functions=[partial_first(if ((key#148 <=> azkaban.link.workflow.url)) value#149 else null, true), partial_first(if ((key#148 <=> azkaban.link.execution.url)) value#149 else null, true), partial_first(if ((key#148 <=> azkaban.link.job.url)) value#149 else null, true), partial_first(if ((key#148 <=> user.name)) value#149 else null, true)], output=[appId#137, first#273, valueSet#274, first#275, valueSet#276, first#277, valueSet#278, first#279, valueSet#280])\n +- *(3) Sort [appId#137 ASC NULLS FIRST], false, 0\n +- *(3) Project [appId#137, col#145.key AS key#148, col#145.value AS value#149]\n +- Generate explode(systemProperties#135), [appId#137], false, [col#145]\n +- *(2) FileScan avro [systemProperties#135,appId#137] Batched: false, Format: com.databricks.spark.avro.DefaultSource@485d3d1, Location: InMemoryFileIndex[hdfs://clusternn01.grid.company.com:9000/data/hadoopdev/sparkmetrics/ltx1-..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct>,appId:string>","sparkPlanInfo":{"nodeName":"CollectLimit","simpleString":"CollectLimit 21","children":[{"nodeName":"WholeStageCodegen","simpleString":"WholeStageCodegen","children":[{"nodeName":"LocalLimit","simpleString":"LocalLimit 21","children":[{"nodeName":"Project","simpleString":"Project [appId#0, attemptId#1, name#2, mode#3, cast(completed#4 as string) AS completed#401, cast(duration#5L as string) AS duration#402, endTime#6, cast(endTimeEpoch#7L as string) AS endTimeEpoch#404, lastUpdated#8, cast(lastUpdatedEpoch#9L as string) AS lastUpdatedEpoch#406, sparkUser#10, startTime#11, cast(startTimeEpoch#12L as string) AS startTimeEpoch#409, appSparkVersion#13, cast(endDate#28 as string) AS endDate#411, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165]","children":[{"nodeName":"InputAdapter","simpleString":"InputAdapter","children":[{"nodeName":"InMemoryTableScan","simpleString":"InMemoryTableScan [appId#0, appSparkVersion#13, attemptId#1, azkaban.link.execution.url#161, azkaban.link.job.url#163, azkaban.link.workflow.url#159, completed#4, duration#5L, endDate#28, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, mode#3, name#2, sparkUser#10, startTime#11, startTimeEpoch#12L, user.name#165]","children":[],"metrics":[{"name":"number of output rows","accumulatorId":35,"metricType":"sum"},{"name":"scan time total (min, med, max)","accumulatorId":36,"metricType":"timing"}]}],"metrics":[]}],"metrics":[]}],"metrics":[]}],"metrics":[{"name":"duration total (min, med, max)","accumulatorId":34,"metricType":"timing"}]}],"metrics":[]},"time":1524182129952} +{"Event":"SparkListenerJobStart","Job ID":0,"Submission Time":1524182130194,"Stage Infos":[{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"cache at :41","Number of Tasks":4,"RDD Info":[{"RDD ID":6,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"Exchange\"}","Callsite":"cache at :41","Parent IDs":[5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"FileScanRDD","Scope":"{\"id\":\"0\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :39","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"*(1) Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, cast(endTime#6 as date) AS endDate#28]\n+- *(1) FileScan avro [appId#0,attemptId#1,name#2,mode#3,completed#4,duration#5L,endTime#6,endTimeEpoch#7L,lastUpdated#8,lastUpdatedEpoch#9L,sparkUser#10,startTime#11,startTimeEpoch#12L,appSparkVersion#13] Batched: false, Format: com.databricks.spark.avro.DefaultSource@7006b304, Location: InMemoryFileIndex[hdfs://clusternn01.grid.company.com:9000/data/hadoopdev/sparkmetrics/ltx1-..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct:39","Parent IDs":[1],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"0\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :39","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"InMemoryTableScan\"}","Callsite":"cache at :41","Parent IDs":[4],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":4,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"InMemoryTableScan\"}","Callsite":"cache at :41","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.sql.Dataset.cache(Dataset.scala:2912)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:41)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:46)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:48)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:50)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:52)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:54)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:56)\n$line49.$read$$iw$$iw$$iw$$iw$$iw.(:58)\n$line49.$read$$iw$$iw$$iw$$iw.(:60)\n$line49.$read$$iw$$iw$$iw.(:62)\n$line49.$read$$iw$$iw.(:64)\n$line49.$read$$iw.(:66)\n$line49.$read.(:68)\n$line49.$read$.(:72)\n$line49.$read$.()\n$line49.$eval$.$print$lzycompute(:7)\n$line49.$eval$.$print(:6)\n$line49.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","Accumulables":[]},{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"cache at :41","Number of Tasks":4,"RDD Info":[{"RDD ID":14,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"17\",\"name\":\"Exchange\"}","Callsite":"cache at :41","Parent IDs":[13],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":10,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"24\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[9],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":9,"Name":"FileScanRDD","Scope":"{\"id\":\"24\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":12,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"19\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[11],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":11,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"23\",\"name\":\"Generate\"}","Callsite":"cache at :41","Parent IDs":[10],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":13,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"18\",\"name\":\"SortAggregate\"}","Callsite":"cache at :41","Parent IDs":[12],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.sql.Dataset.cache(Dataset.scala:2912)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:41)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:46)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:48)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:50)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:52)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:54)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:56)\n$line49.$read$$iw$$iw$$iw$$iw$$iw.(:58)\n$line49.$read$$iw$$iw$$iw$$iw.(:60)\n$line49.$read$$iw$$iw$$iw.(:62)\n$line49.$read$$iw$$iw.(:64)\n$line49.$read$$iw.(:66)\n$line49.$read.(:68)\n$line49.$read$.(:72)\n$line49.$read$.()\n$line49.$eval$.$print$lzycompute(:7)\n$line49.$eval$.$print(:6)\n$line49.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","Accumulables":[]},{"Stage ID":2,"Stage Attempt ID":0,"Stage Name":"show at :40","Number of Tasks":1,"RDD Info":[{"RDD ID":26,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"33\",\"name\":\"map\"}","Callsite":"show at :40","Parent IDs":[25],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":25,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"32\",\"name\":\"mapPartitionsInternal\"}","Callsite":"show at :40","Parent IDs":[24],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"8\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":24,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"27\",\"name\":\"WholeStageCodegen\"}","Callsite":"show at :40","Parent IDs":[23],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":22,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"31\",\"name\":\"InMemoryTableScan\"}","Callsite":"show at :40","Parent IDs":[20],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":20,"Name":"*(5) Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165]\n+- SortMergeJoin [appId#0], [appId#137], LeftOuter\n :- *(1) Sort [appId#0 ASC NULLS FIRST], false, 0\n : +- Exchange hashpartitioning(appId#0, 200)\n : +- InMemoryTableScan [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28]\n : +- InMemoryRelation [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28], true, 10000, StorageLevel(disk, memory, deserialized, 1 rep...","Scope":"{\"id\":\"26\",\"name\":\"mapPartitionsInternal\"}","Callsite":"cache at :41","Parent IDs":[19],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":23,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"31\",\"name\":\"InMemoryTableScan\"}","Callsite":"show at :40","Parent IDs":[22],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":18,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"7\",\"name\":\"SortMergeJoin\"}","Callsite":"cache at :41","Parent IDs":[8,17],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":17,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"13\",\"name\":\"SortAggregate\"}","Callsite":"cache at :41","Parent IDs":[16],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"ShuffledRowRDD","Scope":"{\"id\":\"11\",\"name\":\"Exchange\"}","Callsite":"cache at :41","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":16,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"14\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[15],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":15,"Name":"ShuffledRowRDD","Scope":"{\"id\":\"17\",\"name\":\"Exchange\"}","Callsite":"cache at :41","Parent IDs":[14],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":19,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"4\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[18],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0,1],"Details":"org.apache.spark.sql.Dataset.show(Dataset.scala:691)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:40)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:45)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:47)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:49)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:51)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:53)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:55)\n$line50.$read$$iw$$iw$$iw$$iw$$iw.(:57)\n$line50.$read$$iw$$iw$$iw$$iw.(:59)\n$line50.$read$$iw$$iw$$iw.(:61)\n$line50.$read$$iw$$iw.(:63)\n$line50.$read$$iw.(:65)\n$line50.$read.(:67)\n$line50.$read$.(:71)\n$line50.$read$.()\n$line50.$eval$.$print$lzycompute(:7)\n$line50.$eval$.$print(:6)\n$line50.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","Accumulables":[]}],"Stage IDs":[0,1,2],"Properties":{"spark.sql.execution.id":"2"}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"cache at :41","Number of Tasks":4,"RDD Info":[{"RDD ID":6,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"Exchange\"}","Callsite":"cache at :41","Parent IDs":[5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"FileScanRDD","Scope":"{\"id\":\"0\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :39","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"*(1) Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, cast(endTime#6 as date) AS endDate#28]\n+- *(1) FileScan avro [appId#0,attemptId#1,name#2,mode#3,completed#4,duration#5L,endTime#6,endTimeEpoch#7L,lastUpdated#8,lastUpdatedEpoch#9L,sparkUser#10,startTime#11,startTimeEpoch#12L,appSparkVersion#13] Batched: false, Format: com.databricks.spark.avro.DefaultSource@7006b304, Location: InMemoryFileIndex[hdfs://clusternn01.grid.company.com:9000/data/hadoopdev/sparkmetrics/ltx1-..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct:39","Parent IDs":[1],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"0\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :39","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"InMemoryTableScan\"}","Callsite":"cache at :41","Parent IDs":[4],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":4,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"InMemoryTableScan\"}","Callsite":"cache at :41","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.sql.Dataset.cache(Dataset.scala:2912)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:41)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:46)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:48)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:50)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:52)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:54)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:56)\n$line49.$read$$iw$$iw$$iw$$iw$$iw.(:58)\n$line49.$read$$iw$$iw$$iw$$iw.(:60)\n$line49.$read$$iw$$iw$$iw.(:62)\n$line49.$read$$iw$$iw.(:64)\n$line49.$read$$iw.(:66)\n$line49.$read.(:68)\n$line49.$read$.(:72)\n$line49.$read$.()\n$line49.$eval$.$print$lzycompute(:7)\n$line49.$eval$.$print(:6)\n$line49.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","Submission Time":1524182130229,"Accumulables":[]},"Properties":{"spark.sql.execution.id":"2"}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"cache at :41","Number of Tasks":4,"RDD Info":[{"RDD ID":14,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"17\",\"name\":\"Exchange\"}","Callsite":"cache at :41","Parent IDs":[13],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":10,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"24\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[9],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":9,"Name":"FileScanRDD","Scope":"{\"id\":\"24\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":12,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"19\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[11],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":11,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"23\",\"name\":\"Generate\"}","Callsite":"cache at :41","Parent IDs":[10],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":13,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"18\",\"name\":\"SortAggregate\"}","Callsite":"cache at :41","Parent IDs":[12],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.sql.Dataset.cache(Dataset.scala:2912)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:41)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:46)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:48)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:50)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:52)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:54)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:56)\n$line49.$read$$iw$$iw$$iw$$iw$$iw.(:58)\n$line49.$read$$iw$$iw$$iw$$iw.(:60)\n$line49.$read$$iw$$iw$$iw.(:62)\n$line49.$read$$iw$$iw.(:64)\n$line49.$read$$iw.(:66)\n$line49.$read.(:68)\n$line49.$read$.(:72)\n$line49.$read$.()\n$line49.$eval$.$print$lzycompute(:7)\n$line49.$eval$.$print(:6)\n$line49.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","Submission Time":1524182130328,"Accumulables":[]},"Properties":{"spark.sql.execution.id":"2"}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1524182130331,"Executor ID":"2","Host":"node4045.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1524182130349,"Executor ID":"3","Host":"node0998.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1524182130350,"Executor ID":"4","Host":"node4243.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":3,"Index":3,"Attempt":0,"Launch Time":1524182130350,"Executor ID":"1","Host":"node1404.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":4,"Index":0,"Attempt":0,"Launch Time":1524182142251,"Executor ID":"1","Host":"node1404.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":3,"Index":3,"Attempt":0,"Launch Time":1524182130350,"Executor ID":"1","Host":"node1404.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":1524182142286,"Failed":false,"Killed":false,"Accumulables":[{"ID":7,"Name":"data size total (min, med, max)","Update":"154334487","Value":"154334486","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":16,"Name":"number of output rows","Update":"466636","Value":"466636","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":1,"Name":"number of output rows","Update":"466636","Value":"466636","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":5,"Name":"duration total (min, med, max)","Update":"19666","Value":"19665","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":59,"Name":"internal.metrics.input.recordsRead","Update":466636,"Value":466636,"Internal":true,"Count Failed Values":true},{"ID":58,"Name":"internal.metrics.input.bytesRead","Update":37809697,"Value":37809697,"Internal":true,"Count Failed Values":true},{"ID":57,"Name":"internal.metrics.shuffle.write.writeTime","Update":91545212,"Value":91545212,"Internal":true,"Count Failed Values":true},{"ID":56,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":466636,"Value":466636,"Internal":true,"Count Failed Values":true},{"ID":55,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":20002743,"Value":20002743,"Internal":true,"Count Failed Values":true},{"ID":43,"Name":"internal.metrics.resultSerializationTime","Update":2,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":42,"Name":"internal.metrics.jvmGCTime","Update":407,"Value":407,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.resultSize","Update":1856,"Value":1856,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.executorCpuTime","Update":9020410971,"Value":9020410971,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.executorRunTime","Update":11146,"Value":11146,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.executorDeserializeCpuTime","Update":574344183,"Value":574344183,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.executorDeserializeTime","Update":714,"Value":714,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":714,"Executor Deserialize CPU Time":574344183,"Executor Run Time":11146,"Executor CPU Time":9020410971,"Result Size":1856,"JVM GC Time":407,"Result Serialization Time":2,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":20002743,"Shuffle Write Time":91545212,"Shuffle Records Written":466636},"Input Metrics":{"Bytes Read":37809697,"Records Read":466636},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":5,"Index":1,"Attempt":0,"Launch Time":1524182142997,"Executor ID":"4","Host":"node4243.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1524182130350,"Executor ID":"4","Host":"node4243.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":1524182143009,"Failed":false,"Killed":false,"Accumulables":[{"ID":7,"Name":"data size total (min, med, max)","Update":"206421303","Value":"360755789","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":16,"Name":"number of output rows","Update":"624246","Value":"1090882","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":1,"Name":"number of output rows","Update":"624246","Value":"1090882","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":5,"Name":"duration total (min, med, max)","Update":"20604","Value":"40269","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":59,"Name":"internal.metrics.input.recordsRead","Update":624246,"Value":1090882,"Internal":true,"Count Failed Values":true},{"ID":58,"Name":"internal.metrics.input.bytesRead","Update":50423609,"Value":88233306,"Internal":true,"Count Failed Values":true},{"ID":57,"Name":"internal.metrics.shuffle.write.writeTime","Update":104125550,"Value":195670762,"Internal":true,"Count Failed Values":true},{"ID":56,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":624246,"Value":1090882,"Internal":true,"Count Failed Values":true},{"ID":55,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":26424033,"Value":46426776,"Internal":true,"Count Failed Values":true},{"ID":43,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":42,"Name":"internal.metrics.jvmGCTime","Update":374,"Value":781,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.resultSize","Update":1856,"Value":3712,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.executorCpuTime","Update":11039226628,"Value":20059637599,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.executorRunTime","Update":11978,"Value":23124,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.executorDeserializeCpuTime","Update":526915936,"Value":1101260119,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.executorDeserializeTime","Update":622,"Value":1336,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":622,"Executor Deserialize CPU Time":526915936,"Executor Run Time":11978,"Executor CPU Time":11039226628,"Result Size":1856,"JVM GC Time":374,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":26424033,"Shuffle Write Time":104125550,"Shuffle Records Written":624246},"Input Metrics":{"Bytes Read":50423609,"Records Read":624246},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerExecutorAdded","Timestamp":1524182143160,"Executor ID":"5","Executor Info":{"Host":"node2477.grid.company.com","Total Cores":1,"Log Urls":{"stdout":"http://node2477.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000007/edlu/stdout?start=-4096","stderr":"http://node2477.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000007/edlu/stderr?start=-4096"}}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":6,"Index":2,"Attempt":0,"Launch Time":1524182143166,"Executor ID":"5","Host":"node2477.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"5","Host":"node2477.grid.company.com","Port":20123},"Maximum Memory":956615884,"Timestamp":1524182143406,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":7,"Index":3,"Attempt":0,"Launch Time":1524182144237,"Executor ID":"1","Host":"node1404.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":4,"Index":0,"Attempt":0,"Launch Time":1524182142251,"Executor ID":"1","Host":"node1404.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":1524182144246,"Failed":false,"Killed":false,"Accumulables":[{"ID":8,"Name":"data size total (min, med, max)","Update":"1920975","Value":"1920974","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":23,"Name":"number of output rows","Update":"3562","Value":"3562","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":25,"Name":"peak memory total (min, med, max)","Update":"41943039","Value":"41943038","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":24,"Name":"sort time total (min, med, max)","Update":"38","Value":"37","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":27,"Name":"duration total (min, med, max)","Update":"1813","Value":"1812","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":28,"Name":"number of output rows","Update":"195602","Value":"195602","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":29,"Name":"number of output rows","Update":"3563","Value":"3563","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":33,"Name":"duration total (min, med, max)","Update":"1558","Value":"1557","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":84,"Name":"internal.metrics.input.recordsRead","Update":3563,"Value":3563,"Internal":true,"Count Failed Values":true},{"ID":83,"Name":"internal.metrics.input.bytesRead","Update":36845111,"Value":36845111,"Internal":true,"Count Failed Values":true},{"ID":82,"Name":"internal.metrics.shuffle.write.writeTime","Update":27318908,"Value":27318908,"Internal":true,"Count Failed Values":true},{"ID":81,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3562,"Value":3562,"Internal":true,"Count Failed Values":true},{"ID":80,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":349287,"Value":349287,"Internal":true,"Count Failed Values":true},{"ID":71,"Name":"internal.metrics.peakExecutionMemory","Update":41943040,"Value":41943040,"Internal":true,"Count Failed Values":true},{"ID":67,"Name":"internal.metrics.jvmGCTime","Update":33,"Value":33,"Internal":true,"Count Failed Values":true},{"ID":66,"Name":"internal.metrics.resultSize","Update":2394,"Value":2394,"Internal":true,"Count Failed Values":true},{"ID":65,"Name":"internal.metrics.executorCpuTime","Update":1498974375,"Value":1498974375,"Internal":true,"Count Failed Values":true},{"ID":64,"Name":"internal.metrics.executorRunTime","Update":1922,"Value":1922,"Internal":true,"Count Failed Values":true},{"ID":63,"Name":"internal.metrics.executorDeserializeCpuTime","Update":49547405,"Value":49547405,"Internal":true,"Count Failed Values":true},{"ID":62,"Name":"internal.metrics.executorDeserializeTime","Update":56,"Value":56,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":56,"Executor Deserialize CPU Time":49547405,"Executor Run Time":1922,"Executor CPU Time":1498974375,"Result Size":2394,"JVM GC Time":33,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":349287,"Shuffle Write Time":27318908,"Shuffle Records Written":3562},"Input Metrics":{"Bytes Read":36845111,"Records Read":3563},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1524182130331,"Executor ID":"2","Host":"node4045.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":1524182144444,"Failed":false,"Killed":false,"Accumulables":[{"ID":7,"Name":"data size total (min, med, max)","Update":"204058975","Value":"564814764","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":16,"Name":"number of output rows","Update":"616897","Value":"1707779","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":1,"Name":"number of output rows","Update":"616897","Value":"1707779","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":5,"Name":"duration total (min, med, max)","Update":"23365","Value":"63634","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":59,"Name":"internal.metrics.input.recordsRead","Update":616897,"Value":1707779,"Internal":true,"Count Failed Values":true},{"ID":58,"Name":"internal.metrics.input.bytesRead","Update":50423423,"Value":138656729,"Internal":true,"Count Failed Values":true},{"ID":57,"Name":"internal.metrics.shuffle.write.writeTime","Update":105575962,"Value":301246724,"Internal":true,"Count Failed Values":true},{"ID":56,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":616897,"Value":1707779,"Internal":true,"Count Failed Values":true},{"ID":55,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":22950296,"Value":69377072,"Internal":true,"Count Failed Values":true},{"ID":43,"Name":"internal.metrics.resultSerializationTime","Update":2,"Value":5,"Internal":true,"Count Failed Values":true},{"ID":42,"Name":"internal.metrics.jvmGCTime","Update":326,"Value":1107,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.resultSize","Update":1856,"Value":5568,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.executorCpuTime","Update":11931694025,"Value":31991331624,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.executorRunTime","Update":13454,"Value":36578,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.executorDeserializeCpuTime","Update":531799977,"Value":1633060096,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.executorDeserializeTime","Update":594,"Value":1930,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":594,"Executor Deserialize CPU Time":531799977,"Executor Run Time":13454,"Executor CPU Time":11931694025,"Result Size":1856,"JVM GC Time":326,"Result Serialization Time":2,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":22950296,"Shuffle Write Time":105575962,"Shuffle Records Written":616897},"Input Metrics":{"Bytes Read":50423423,"Records Read":616897},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1524182130349,"Executor ID":"3","Host":"node0998.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":1524182144840,"Failed":false,"Killed":false,"Accumulables":[{"ID":7,"Name":"data size total (min, med, max)","Update":"207338935","Value":"772153699","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":16,"Name":"number of output rows","Update":"626277","Value":"2334056","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":1,"Name":"number of output rows","Update":"626277","Value":"2334056","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":5,"Name":"duration total (min, med, max)","Update":"24254","Value":"87888","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":59,"Name":"internal.metrics.input.recordsRead","Update":626277,"Value":2334056,"Internal":true,"Count Failed Values":true},{"ID":58,"Name":"internal.metrics.input.bytesRead","Update":50409514,"Value":189066243,"Internal":true,"Count Failed Values":true},{"ID":57,"Name":"internal.metrics.shuffle.write.writeTime","Update":106963069,"Value":408209793,"Internal":true,"Count Failed Values":true},{"ID":56,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":626277,"Value":2334056,"Internal":true,"Count Failed Values":true},{"ID":55,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":31362123,"Value":100739195,"Internal":true,"Count Failed Values":true},{"ID":43,"Name":"internal.metrics.resultSerializationTime","Update":2,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":42,"Name":"internal.metrics.jvmGCTime","Update":342,"Value":1449,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.resultSize","Update":1856,"Value":7424,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.executorCpuTime","Update":12267596062,"Value":44258927686,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.executorRunTime","Update":13858,"Value":50436,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.executorDeserializeCpuTime","Update":519573839,"Value":2152633935,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.executorDeserializeTime","Update":573,"Value":2503,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":573,"Executor Deserialize CPU Time":519573839,"Executor Run Time":13858,"Executor CPU Time":12267596062,"Result Size":1856,"JVM GC Time":342,"Result Serialization Time":2,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":31362123,"Shuffle Write Time":106963069,"Shuffle Records Written":626277},"Input Metrics":{"Bytes Read":50409514,"Records Read":626277},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"driver","Stage ID":0,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":592412824,"JVMOffHeapMemory":202907152,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":905801,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":905801,"OffHeapUnifiedMemory":0,"DirectPoolMemory":355389,"MappedPoolMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"2","Stage ID":0,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":523121272,"JVMOffHeapMemory":88280720,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":52050147,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":52050147,"OffHeapUnifiedMemory":0,"DirectPoolMemory":87796,"MappedPoolMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"1","Stage ID":0,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":214174608,"JVMOffHeapMemory":91548704,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":47399168,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":47399168,"OffHeapUnifiedMemory":0,"DirectPoolMemory":87796,"MappedPoolMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"4","Stage ID":0,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":518613056,"JVMOffHeapMemory":95657456,"OnHeapExecutionMemory":37748736,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":63104457,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":100853193,"OffHeapUnifiedMemory":0,"DirectPoolMemory":126261,"MappedPoolMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"3","Stage ID":0,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":726805712,"JVMOffHeapMemory":90709624,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":69535048,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":69535048,"OffHeapUnifiedMemory":0,"DirectPoolMemory":87796,"MappedPoolMemory":0}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"cache at :41","Number of Tasks":4,"RDD Info":[{"RDD ID":6,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"Exchange\"}","Callsite":"cache at :41","Parent IDs":[5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"FileScanRDD","Scope":"{\"id\":\"0\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :39","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"*(1) Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, cast(endTime#6 as date) AS endDate#28]\n+- *(1) FileScan avro [appId#0,attemptId#1,name#2,mode#3,completed#4,duration#5L,endTime#6,endTimeEpoch#7L,lastUpdated#8,lastUpdatedEpoch#9L,sparkUser#10,startTime#11,startTimeEpoch#12L,appSparkVersion#13] Batched: false, Format: com.databricks.spark.avro.DefaultSource@7006b304, Location: InMemoryFileIndex[hdfs://clusternn01.grid.company.com:9000/data/hadoopdev/sparkmetrics/ltx1-..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct:39","Parent IDs":[1],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"0\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :39","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"InMemoryTableScan\"}","Callsite":"cache at :41","Parent IDs":[4],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":4,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"InMemoryTableScan\"}","Callsite":"cache at :41","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.sql.Dataset.cache(Dataset.scala:2912)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:41)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:46)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:48)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:50)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:52)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:54)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:56)\n$line49.$read$$iw$$iw$$iw$$iw$$iw.(:58)\n$line49.$read$$iw$$iw$$iw$$iw.(:60)\n$line49.$read$$iw$$iw$$iw.(:62)\n$line49.$read$$iw$$iw.(:64)\n$line49.$read$$iw.(:66)\n$line49.$read.(:68)\n$line49.$read$.(:72)\n$line49.$read$.()\n$line49.$eval$.$print$lzycompute(:7)\n$line49.$eval$.$print(:6)\n$line49.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","Submission Time":1524182130229,"Completion Time":1524182144852,"Accumulables":[{"ID":41,"Name":"internal.metrics.resultSize","Value":7424,"Internal":true,"Count Failed Values":true},{"ID":59,"Name":"internal.metrics.input.recordsRead","Value":2334056,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.executorDeserializeCpuTime","Value":2152633935,"Internal":true,"Count Failed Values":true},{"ID":56,"Name":"internal.metrics.shuffle.write.recordsWritten","Value":2334056,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"duration total (min, med, max)","Value":"87888","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":55,"Name":"internal.metrics.shuffle.write.bytesWritten","Value":100739195,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.executorCpuTime","Value":44258927686,"Internal":true,"Count Failed Values":true},{"ID":58,"Name":"internal.metrics.input.bytesRead","Value":189066243,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"data size total (min, med, max)","Value":"772153699","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":16,"Name":"number of output rows","Value":"2334056","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":43,"Name":"internal.metrics.resultSerializationTime","Value":7,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"number of output rows","Value":"2334056","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":37,"Name":"internal.metrics.executorDeserializeTime","Value":2503,"Internal":true,"Count Failed Values":true},{"ID":57,"Name":"internal.metrics.shuffle.write.writeTime","Value":408209793,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.executorRunTime","Value":50436,"Internal":true,"Count Failed Values":true},{"ID":42,"Name":"internal.metrics.jvmGCTime","Value":1449,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":5,"Index":1,"Attempt":0,"Launch Time":1524182142997,"Executor ID":"4","Host":"node4243.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":1524182145327,"Failed":false,"Killed":false,"Accumulables":[{"ID":8,"Name":"data size total (min, med, max)","Update":"1953295","Value":"3874269","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":23,"Name":"number of output rows","Update":"3575","Value":"7137","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":25,"Name":"peak memory total (min, med, max)","Update":"41943039","Value":"83886077","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":24,"Name":"sort time total (min, med, max)","Update":"49","Value":"86","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":27,"Name":"duration total (min, med, max)","Update":"2002","Value":"3814","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":28,"Name":"number of output rows","Update":"196587","Value":"392189","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":29,"Name":"number of output rows","Update":"3575","Value":"7138","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":33,"Name":"duration total (min, med, max)","Update":"1755","Value":"3312","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":84,"Name":"internal.metrics.input.recordsRead","Update":3575,"Value":7138,"Internal":true,"Count Failed Values":true},{"ID":83,"Name":"internal.metrics.input.bytesRead","Update":36849246,"Value":73694357,"Internal":true,"Count Failed Values":true},{"ID":82,"Name":"internal.metrics.shuffle.write.writeTime","Update":32035583,"Value":59354491,"Internal":true,"Count Failed Values":true},{"ID":81,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3575,"Value":7137,"Internal":true,"Count Failed Values":true},{"ID":80,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":349006,"Value":698293,"Internal":true,"Count Failed Values":true},{"ID":71,"Name":"internal.metrics.peakExecutionMemory","Update":41943040,"Value":83886080,"Internal":true,"Count Failed Values":true},{"ID":67,"Name":"internal.metrics.jvmGCTime","Update":31,"Value":64,"Internal":true,"Count Failed Values":true},{"ID":66,"Name":"internal.metrics.resultSize","Update":2394,"Value":4788,"Internal":true,"Count Failed Values":true},{"ID":65,"Name":"internal.metrics.executorCpuTime","Update":1785119941,"Value":3284094316,"Internal":true,"Count Failed Values":true},{"ID":64,"Name":"internal.metrics.executorRunTime","Update":2182,"Value":4104,"Internal":true,"Count Failed Values":true},{"ID":63,"Name":"internal.metrics.executorDeserializeCpuTime","Update":71500541,"Value":121047946,"Internal":true,"Count Failed Values":true},{"ID":62,"Name":"internal.metrics.executorDeserializeTime","Update":136,"Value":192,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":136,"Executor Deserialize CPU Time":71500541,"Executor Run Time":2182,"Executor CPU Time":1785119941,"Result Size":2394,"JVM GC Time":31,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":349006,"Shuffle Write Time":32035583,"Shuffle Records Written":3575},"Input Metrics":{"Bytes Read":36849246,"Records Read":3575},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":7,"Index":3,"Attempt":0,"Launch Time":1524182144237,"Executor ID":"1","Host":"node1404.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":1524182145971,"Failed":false,"Killed":false,"Accumulables":[{"ID":8,"Name":"data size total (min, med, max)","Update":"1337999","Value":"5212268","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":23,"Name":"number of output rows","Update":"2435","Value":"9572","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":25,"Name":"peak memory total (min, med, max)","Update":"37748735","Value":"121634812","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":24,"Name":"sort time total (min, med, max)","Update":"9","Value":"95","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":27,"Name":"duration total (min, med, max)","Update":"1703","Value":"5517","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":28,"Name":"number of output rows","Update":"133759","Value":"525948","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":29,"Name":"number of output rows","Update":"2435","Value":"9573","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":33,"Name":"duration total (min, med, max)","Update":"1609","Value":"4921","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":84,"Name":"internal.metrics.input.recordsRead","Update":2435,"Value":9573,"Internal":true,"Count Failed Values":true},{"ID":83,"Name":"internal.metrics.input.bytesRead","Update":24250210,"Value":97944567,"Internal":true,"Count Failed Values":true},{"ID":82,"Name":"internal.metrics.shuffle.write.writeTime","Update":20055909,"Value":79410400,"Internal":true,"Count Failed Values":true},{"ID":81,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":2435,"Value":9572,"Internal":true,"Count Failed Values":true},{"ID":80,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":242714,"Value":941007,"Internal":true,"Count Failed Values":true},{"ID":71,"Name":"internal.metrics.peakExecutionMemory","Update":37748736,"Value":121634816,"Internal":true,"Count Failed Values":true},{"ID":67,"Name":"internal.metrics.jvmGCTime","Update":31,"Value":95,"Internal":true,"Count Failed Values":true},{"ID":66,"Name":"internal.metrics.resultSize","Update":2394,"Value":7182,"Internal":true,"Count Failed Values":true},{"ID":65,"Name":"internal.metrics.executorCpuTime","Update":896878991,"Value":4180973307,"Internal":true,"Count Failed Values":true},{"ID":64,"Name":"internal.metrics.executorRunTime","Update":1722,"Value":5826,"Internal":true,"Count Failed Values":true},{"ID":63,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2787355,"Value":123835301,"Internal":true,"Count Failed Values":true},{"ID":62,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":195,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":2787355,"Executor Run Time":1722,"Executor CPU Time":896878991,"Result Size":2394,"JVM GC Time":31,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":242714,"Shuffle Write Time":20055909,"Shuffle Records Written":2435},"Input Metrics":{"Bytes Read":24250210,"Records Read":2435},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerExecutorAdded","Timestamp":1524182147549,"Executor ID":"6","Executor Info":{"Host":"node6644.grid.company.com","Total Cores":1,"Log Urls":{"stdout":"http://node6644.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000008/edlu/stdout?start=-4096","stderr":"http://node6644.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000008/edlu/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"6","Host":"node6644.grid.company.com","Port":8445},"Maximum Memory":956615884,"Timestamp":1524182147706,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1524182149826,"Executor ID":"7","Executor Info":{"Host":"node6340.grid.company.com","Total Cores":1,"Log Urls":{"stdout":"http://node6340.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000009/edlu/stdout?start=-4096","stderr":"http://node6340.grid.company.com:8042/node/containerlogs/container_e05_1523494505172_1552404_01_000009/edlu/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"7","Host":"node6340.grid.company.com","Port":5933},"Maximum Memory":956615884,"Timestamp":1524182149983,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":6,"Index":2,"Attempt":0,"Launch Time":1524182143166,"Executor ID":"5","Host":"node2477.grid.company.com","Locality":"ANY","Speculative":false,"Getting Result Time":0,"Finish Time":1524182152418,"Failed":false,"Killed":false,"Accumulables":[{"ID":8,"Name":"data size total (min, med, max)","Update":"1910103","Value":"7122371","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":23,"Name":"number of output rows","Update":"3541","Value":"13113","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":25,"Name":"peak memory total (min, med, max)","Update":"41943039","Value":"163577851","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":24,"Name":"sort time total (min, med, max)","Update":"48","Value":"143","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":27,"Name":"duration total (min, med, max)","Update":"6093","Value":"11610","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":28,"Name":"number of output rows","Update":"194553","Value":"720501","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":29,"Name":"number of output rows","Update":"3541","Value":"13114","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":33,"Name":"duration total (min, med, max)","Update":"5951","Value":"10872","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":84,"Name":"internal.metrics.input.recordsRead","Update":3541,"Value":13114,"Internal":true,"Count Failed Values":true},{"ID":83,"Name":"internal.metrics.input.bytesRead","Update":36838295,"Value":134782862,"Internal":true,"Count Failed Values":true},{"ID":82,"Name":"internal.metrics.shuffle.write.writeTime","Update":49790497,"Value":129200897,"Internal":true,"Count Failed Values":true},{"ID":81,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3541,"Value":13113,"Internal":true,"Count Failed Values":true},{"ID":80,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":355051,"Value":1296058,"Internal":true,"Count Failed Values":true},{"ID":71,"Name":"internal.metrics.peakExecutionMemory","Update":41943040,"Value":163577856,"Internal":true,"Count Failed Values":true},{"ID":68,"Name":"internal.metrics.resultSerializationTime","Update":2,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":67,"Name":"internal.metrics.jvmGCTime","Update":920,"Value":1015,"Internal":true,"Count Failed Values":true},{"ID":66,"Name":"internal.metrics.resultSize","Update":2437,"Value":9619,"Internal":true,"Count Failed Values":true},{"ID":65,"Name":"internal.metrics.executorCpuTime","Update":5299274511,"Value":9480247818,"Internal":true,"Count Failed Values":true},{"ID":64,"Name":"internal.metrics.executorRunTime","Update":7847,"Value":13673,"Internal":true,"Count Failed Values":true},{"ID":63,"Name":"internal.metrics.executorDeserializeCpuTime","Update":687811857,"Value":811647158,"Internal":true,"Count Failed Values":true},{"ID":62,"Name":"internal.metrics.executorDeserializeTime","Update":1037,"Value":1232,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1037,"Executor Deserialize CPU Time":687811857,"Executor Run Time":7847,"Executor CPU Time":5299274511,"Result Size":2437,"JVM GC Time":920,"Result Serialization Time":2,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":355051,"Shuffle Write Time":49790497,"Shuffle Records Written":3541},"Input Metrics":{"Bytes Read":36838295,"Records Read":3541},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"driver","Stage ID":1,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":629553808,"JVMOffHeapMemory":205304696,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":905801,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":905801,"OffHeapUnifiedMemory":0,"DirectPoolMemory":397602,"MappedPoolMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"2","Stage ID":1,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":595946552,"JVMOffHeapMemory":91208368,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":58468944,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":58468944,"OffHeapUnifiedMemory":0,"DirectPoolMemory":87796,"MappedPoolMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"1","Stage ID":1,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":755008624,"JVMOffHeapMemory":100519936,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":47962185,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":47962185,"OffHeapUnifiedMemory":0,"DirectPoolMemory":98230,"MappedPoolMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"4","Stage ID":1,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":518613056,"JVMOffHeapMemory":95657456,"OnHeapExecutionMemory":37748736,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":63104457,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":100853193,"OffHeapUnifiedMemory":0,"DirectPoolMemory":126261,"MappedPoolMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"3","Stage ID":1,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":726805712,"JVMOffHeapMemory":90709624,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":69535048,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":69535048,"OffHeapUnifiedMemory":0,"DirectPoolMemory":87796,"MappedPoolMemory":0}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"cache at :41","Number of Tasks":4,"RDD Info":[{"RDD ID":14,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"17\",\"name\":\"Exchange\"}","Callsite":"cache at :41","Parent IDs":[13],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":10,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"24\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[9],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":9,"Name":"FileScanRDD","Scope":"{\"id\":\"24\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":12,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"19\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[11],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":11,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"23\",\"name\":\"Generate\"}","Callsite":"cache at :41","Parent IDs":[10],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":13,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"18\",\"name\":\"SortAggregate\"}","Callsite":"cache at :41","Parent IDs":[12],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":4,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.sql.Dataset.cache(Dataset.scala:2912)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:41)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:46)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:48)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:50)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:52)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:54)\n$line49.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:56)\n$line49.$read$$iw$$iw$$iw$$iw$$iw.(:58)\n$line49.$read$$iw$$iw$$iw$$iw.(:60)\n$line49.$read$$iw$$iw$$iw.(:62)\n$line49.$read$$iw$$iw.(:64)\n$line49.$read$$iw.(:66)\n$line49.$read.(:68)\n$line49.$read$.(:72)\n$line49.$read$.()\n$line49.$eval$.$print$lzycompute(:7)\n$line49.$eval$.$print(:6)\n$line49.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","Submission Time":1524182130328,"Completion Time":1524182152419,"Accumulables":[{"ID":83,"Name":"internal.metrics.input.bytesRead","Value":134782862,"Internal":true,"Count Failed Values":true},{"ID":23,"Name":"number of output rows","Value":"13113","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":68,"Name":"internal.metrics.resultSerializationTime","Value":2,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"data size total (min, med, max)","Value":"7122371","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":62,"Name":"internal.metrics.executorDeserializeTime","Value":1232,"Internal":true,"Count Failed Values":true},{"ID":80,"Name":"internal.metrics.shuffle.write.bytesWritten","Value":1296058,"Internal":true,"Count Failed Values":true},{"ID":71,"Name":"internal.metrics.peakExecutionMemory","Value":163577856,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"number of output rows","Value":"13114","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":65,"Name":"internal.metrics.executorCpuTime","Value":9480247818,"Internal":true,"Count Failed Values":true},{"ID":64,"Name":"internal.metrics.executorRunTime","Value":13673,"Internal":true,"Count Failed Values":true},{"ID":82,"Name":"internal.metrics.shuffle.write.writeTime","Value":129200897,"Internal":true,"Count Failed Values":true},{"ID":67,"Name":"internal.metrics.jvmGCTime","Value":1015,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"peak memory total (min, med, max)","Value":"163577851","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":28,"Name":"number of output rows","Value":"720501","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":63,"Name":"internal.metrics.executorDeserializeCpuTime","Value":811647158,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"duration total (min, med, max)","Value":"11610","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":81,"Name":"internal.metrics.shuffle.write.recordsWritten","Value":13113,"Internal":true,"Count Failed Values":true},{"ID":84,"Name":"internal.metrics.input.recordsRead","Value":13114,"Internal":true,"Count Failed Values":true},{"ID":66,"Name":"internal.metrics.resultSize","Value":9619,"Internal":true,"Count Failed Values":true},{"ID":24,"Name":"sort time total (min, med, max)","Value":"143","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":33,"Name":"duration total (min, med, max)","Value":"10872","Internal":true,"Count Failed Values":true,"Metadata":"sql"}]}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":2,"Stage Attempt ID":0,"Stage Name":"show at :40","Number of Tasks":1,"RDD Info":[{"RDD ID":26,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"33\",\"name\":\"map\"}","Callsite":"show at :40","Parent IDs":[25],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":25,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"32\",\"name\":\"mapPartitionsInternal\"}","Callsite":"show at :40","Parent IDs":[24],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"8\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":24,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"27\",\"name\":\"WholeStageCodegen\"}","Callsite":"show at :40","Parent IDs":[23],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":22,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"31\",\"name\":\"InMemoryTableScan\"}","Callsite":"show at :40","Parent IDs":[20],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":20,"Name":"*(5) Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165]\n+- SortMergeJoin [appId#0], [appId#137], LeftOuter\n :- *(1) Sort [appId#0 ASC NULLS FIRST], false, 0\n : +- Exchange hashpartitioning(appId#0, 200)\n : +- InMemoryTableScan [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28]\n : +- InMemoryRelation [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28], true, 10000, StorageLevel(disk, memory, deserialized, 1 rep...","Scope":"{\"id\":\"26\",\"name\":\"mapPartitionsInternal\"}","Callsite":"cache at :41","Parent IDs":[19],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":23,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"31\",\"name\":\"InMemoryTableScan\"}","Callsite":"show at :40","Parent IDs":[22],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":18,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"7\",\"name\":\"SortMergeJoin\"}","Callsite":"cache at :41","Parent IDs":[8,17],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":17,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"13\",\"name\":\"SortAggregate\"}","Callsite":"cache at :41","Parent IDs":[16],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"ShuffledRowRDD","Scope":"{\"id\":\"11\",\"name\":\"Exchange\"}","Callsite":"cache at :41","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":16,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"14\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[15],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":15,"Name":"ShuffledRowRDD","Scope":"{\"id\":\"17\",\"name\":\"Exchange\"}","Callsite":"cache at :41","Parent IDs":[14],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":19,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"4\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[18],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0,1],"Details":"org.apache.spark.sql.Dataset.show(Dataset.scala:691)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:40)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:45)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:47)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:49)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:51)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:53)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:55)\n$line50.$read$$iw$$iw$$iw$$iw$$iw.(:57)\n$line50.$read$$iw$$iw$$iw$$iw.(:59)\n$line50.$read$$iw$$iw$$iw.(:61)\n$line50.$read$$iw$$iw.(:63)\n$line50.$read$$iw.(:65)\n$line50.$read.(:67)\n$line50.$read$.(:71)\n$line50.$read$.()\n$line50.$eval$.$print$lzycompute(:7)\n$line50.$eval$.$print(:6)\n$line50.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","Submission Time":1524182152430,"Accumulables":[]},"Properties":{"spark.sql.execution.id":"2"}} +{"Event":"SparkListenerTaskStart","Stage ID":2,"Stage Attempt ID":0,"Task Info":{"Task ID":8,"Index":0,"Attempt":0,"Launch Time":1524182152447,"Executor ID":"4","Host":"node4243.grid.company.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":2,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":8,"Index":0,"Attempt":0,"Launch Time":1524182152447,"Executor ID":"4","Host":"node4243.grid.company.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1524182153103,"Failed":false,"Killed":false,"Accumulables":[{"ID":34,"Name":"duration total (min, med, max)","Update":"1","Value":"0","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":35,"Name":"number of output rows","Update":"6928","Value":"6928","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":10,"Name":"duration total (min, med, max)","Update":"452","Value":"451","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":11,"Name":"number of output rows","Update":"10945","Value":"10945","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":18,"Name":"number of output rows","Update":"62","Value":"62","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":20,"Name":"peak memory total (min, med, max)","Update":"33619967","Value":"33619966","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":22,"Name":"duration total (min, med, max)","Update":"323","Value":"322","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":13,"Name":"peak memory total (min, med, max)","Update":"34078719","Value":"34078718","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":12,"Name":"sort time total (min, med, max)","Update":"10","Value":"9","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":15,"Name":"duration total (min, med, max)","Update":"367","Value":"366","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":104,"Name":"internal.metrics.shuffle.read.recordsRead","Update":11007,"Value":11007,"Internal":true,"Count Failed Values":true},{"ID":103,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":102,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":124513,"Value":124513,"Internal":true,"Count Failed Values":true},{"ID":101,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":100,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":314162,"Value":314162,"Internal":true,"Count Failed Values":true},{"ID":99,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":2,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":98,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":6,"Value":6,"Internal":true,"Count Failed Values":true},{"ID":96,"Name":"internal.metrics.peakExecutionMemory","Update":67698688,"Value":67698688,"Internal":true,"Count Failed Values":true},{"ID":91,"Name":"internal.metrics.resultSize","Update":4642,"Value":4642,"Internal":true,"Count Failed Values":true},{"ID":90,"Name":"internal.metrics.executorCpuTime","Update":517655714,"Value":517655714,"Internal":true,"Count Failed Values":true},{"ID":89,"Name":"internal.metrics.executorRunTime","Update":589,"Value":589,"Internal":true,"Count Failed Values":true},{"ID":88,"Name":"internal.metrics.executorDeserializeCpuTime","Update":45797784,"Value":45797784,"Internal":true,"Count Failed Values":true},{"ID":87,"Name":"internal.metrics.executorDeserializeTime","Update":50,"Value":50,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":50,"Executor Deserialize CPU Time":45797784,"Executor Run Time":589,"Executor CPU Time":517655714,"Result Size":4642,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":6,"Local Blocks Fetched":2,"Fetch Wait Time":0,"Remote Bytes Read":314162,"Remote Bytes Read To Disk":0,"Local Bytes Read":124513,"Total Records Read":11007},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":2,"Stage Attempt ID":0,"Stage Name":"show at :40","Number of Tasks":1,"RDD Info":[{"RDD ID":26,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"33\",\"name\":\"map\"}","Callsite":"show at :40","Parent IDs":[25],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":25,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"32\",\"name\":\"mapPartitionsInternal\"}","Callsite":"show at :40","Parent IDs":[24],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"8\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":24,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"27\",\"name\":\"WholeStageCodegen\"}","Callsite":"show at :40","Parent IDs":[23],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":22,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"31\",\"name\":\"InMemoryTableScan\"}","Callsite":"show at :40","Parent IDs":[20],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":20,"Name":"*(5) Project [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28, azkaban.link.workflow.url#159, azkaban.link.execution.url#161, azkaban.link.job.url#163, user.name#165]\n+- SortMergeJoin [appId#0], [appId#137], LeftOuter\n :- *(1) Sort [appId#0 ASC NULLS FIRST], false, 0\n : +- Exchange hashpartitioning(appId#0, 200)\n : +- InMemoryTableScan [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28]\n : +- InMemoryRelation [appId#0, attemptId#1, name#2, mode#3, completed#4, duration#5L, endTime#6, endTimeEpoch#7L, lastUpdated#8, lastUpdatedEpoch#9L, sparkUser#10, startTime#11, startTimeEpoch#12L, appSparkVersion#13, endDate#28], true, 10000, StorageLevel(disk, memory, deserialized, 1 rep...","Scope":"{\"id\":\"26\",\"name\":\"mapPartitionsInternal\"}","Callsite":"cache at :41","Parent IDs":[19],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":23,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"31\",\"name\":\"InMemoryTableScan\"}","Callsite":"show at :40","Parent IDs":[22],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":18,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"7\",\"name\":\"SortMergeJoin\"}","Callsite":"cache at :41","Parent IDs":[8,17],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":17,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"13\",\"name\":\"SortAggregate\"}","Callsite":"cache at :41","Parent IDs":[16],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"ShuffledRowRDD","Scope":"{\"id\":\"11\",\"name\":\"Exchange\"}","Callsite":"cache at :41","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":16,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"14\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[15],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":15,"Name":"ShuffledRowRDD","Scope":"{\"id\":\"17\",\"name\":\"Exchange\"}","Callsite":"cache at :41","Parent IDs":[14],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":19,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"4\",\"name\":\"WholeStageCodegen\"}","Callsite":"cache at :41","Parent IDs":[18],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":200,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0,1],"Details":"org.apache.spark.sql.Dataset.show(Dataset.scala:691)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:40)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:45)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:47)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:49)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:51)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:53)\n$line50.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:55)\n$line50.$read$$iw$$iw$$iw$$iw$$iw.(:57)\n$line50.$read$$iw$$iw$$iw$$iw.(:59)\n$line50.$read$$iw$$iw$$iw.(:61)\n$line50.$read$$iw$$iw.(:63)\n$line50.$read$$iw.(:65)\n$line50.$read.(:67)\n$line50.$read$.(:71)\n$line50.$read$.()\n$line50.$eval$.$print$lzycompute(:7)\n$line50.$eval$.$print(:6)\n$line50.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)","Submission Time":1524182152430,"Completion Time":1524182153104,"Accumulables":[{"ID":101,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Value":0,"Internal":true,"Count Failed Values":true},{"ID":104,"Name":"internal.metrics.shuffle.read.recordsRead","Value":11007,"Internal":true,"Count Failed Values":true},{"ID":35,"Name":"number of output rows","Value":"6928","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":89,"Name":"internal.metrics.executorRunTime","Value":589,"Internal":true,"Count Failed Values":true},{"ID":98,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Value":6,"Internal":true,"Count Failed Values":true},{"ID":11,"Name":"number of output rows","Value":"10945","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":20,"Name":"peak memory total (min, med, max)","Value":"33619966","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":91,"Name":"internal.metrics.resultSize","Value":4642,"Internal":true,"Count Failed Values":true},{"ID":100,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Value":314162,"Internal":true,"Count Failed Values":true},{"ID":13,"Name":"peak memory total (min, med, max)","Value":"34078718","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":22,"Name":"duration total (min, med, max)","Value":"322","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":103,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Value":0,"Internal":true,"Count Failed Values":true},{"ID":88,"Name":"internal.metrics.executorDeserializeCpuTime","Value":45797784,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"duration total (min, med, max)","Value":"0","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":10,"Name":"duration total (min, med, max)","Value":"451","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":87,"Name":"internal.metrics.executorDeserializeTime","Value":50,"Internal":true,"Count Failed Values":true},{"ID":96,"Name":"internal.metrics.peakExecutionMemory","Value":67698688,"Internal":true,"Count Failed Values":true},{"ID":90,"Name":"internal.metrics.executorCpuTime","Value":517655714,"Internal":true,"Count Failed Values":true},{"ID":99,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Value":2,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"number of output rows","Value":"62","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":12,"Name":"sort time total (min, med, max)","Value":"9","Internal":true,"Count Failed Values":true,"Metadata":"sql"},{"ID":102,"Name":"internal.metrics.shuffle.read.localBytesRead","Value":124513,"Internal":true,"Count Failed Values":true},{"ID":15,"Name":"duration total (min, med, max)","Value":"366","Internal":true,"Count Failed Values":true,"Metadata":"sql"}]}} +{"Event":"SparkListenerJobEnd","Job ID":0,"Completion Time":1524182153112,"Job Result":{"Result":"JobSucceeded"}} +{"Event":"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd","executionId":2,"time":1524182153139} +{"Event":"SparkListenerUnpersistRDD","RDD ID":2} +{"Event":"SparkListenerUnpersistRDD","RDD ID":20} +{"Event":"SparkListenerApplicationEnd","Timestamp":1524182189134} diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index b705556e54b14..de479db5fbc0f 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -28,7 +28,7 @@ import org.mockito.Matchers._ import org.mockito.Mockito.{mock, spy, verify, when} import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ @@ -77,7 +77,7 @@ class HeartbeatReceiverSuite heartbeatReceiverClock = new ManualClock heartbeatReceiver = new HeartbeatReceiver(sc, heartbeatReceiverClock) heartbeatReceiverRef = sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver) - when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true) + when(scheduler.executorHeartbeatReceived(any(), any(), any(), any())).thenReturn(true) } /** @@ -213,8 +213,10 @@ class HeartbeatReceiverSuite executorShouldReregister: Boolean): Unit = { val metrics = TaskMetrics.empty val blockManagerId = BlockManagerId(executorId, "localhost", 12345) + val executorUpdates = new ExecutorMetrics(Array(123456L, 543L, 12345L, 1234L, 123L, + 12L, 432L, 321L, 654L, 765L)) val response = heartbeatReceiverRef.askSync[HeartbeatResponse]( - Heartbeat(executorId, Array(1L -> metrics.accumulators()), blockManagerId)) + Heartbeat(executorId, Array(1L -> metrics.accumulators()), blockManagerId, executorUpdates)) if (executorShouldReregister) { assert(response.reregisterBlockManager) } else { @@ -223,7 +225,8 @@ class HeartbeatReceiverSuite verify(scheduler).executorHeartbeatReceived( Matchers.eq(executorId), Matchers.eq(Array(1L -> metrics.accumulators())), - Matchers.eq(blockManagerId)) + Matchers.eq(blockManagerId), + Matchers.eq(executorUpdates)) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 11b29121739a4..11a2db81f7c6d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -82,6 +82,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers .set("spark.history.fs.update.interval", "0") .set("spark.testing", "true") .set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) + .set("spark.eventLog.logStageExecutorMetrics.enabled", "true") conf.setAll(extraConf) provider = new FsHistoryProvider(conf) provider.checkForLogs() @@ -128,6 +129,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers "succeeded&failed job list json" -> "applications/local-1422981780767/jobs?status=succeeded&status=failed", "executor list json" -> "applications/local-1422981780767/executors", + "executor list with executor metrics json" -> + "applications/application_1506645932520_24630151/executors", "stage list json" -> "applications/local-1422981780767/stages", "complete stage list json" -> "applications/local-1422981780767/stages?status=complete", "failed stage list json" -> "applications/local-1422981780767/stages?status=failed", diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 4e87deb136df6..365eab0668ab2 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -30,6 +30,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager +import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.internal.config import org.apache.spark.rdd.{DeterministicLevel, RDD} import org.apache.spark.scheduler.SchedulingMode.SchedulingMode @@ -140,7 +141,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi override def executorHeartbeatReceived( execId: String, accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], - blockManagerId: BlockManagerId): Boolean = true + blockManagerId: BlockManagerId, + executorUpdates: ExecutorMetrics): Boolean = true override def submitTasks(taskSet: TaskSet) = { // normally done by TaskSetManager taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) @@ -660,7 +662,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi override def executorHeartbeatReceived( execId: String, accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], - blockManagerId: BlockManagerId): Boolean = true + blockManagerId: BlockManagerId, + executorMetrics: ExecutorMetrics): Boolean = true override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} override def workerRemoved(workerId: String, host: String, message: String): Unit = {} override def applicationAttemptId(): Option[String] = None diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index a9e92fa07b9dd..cecd6996df7bd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.scheduler import java.io.{File, FileOutputStream, InputStream, IOException} +import scala.collection.immutable.Map import scala.collection.mutable +import scala.collection.mutable.Set import scala.io.Source import org.apache.hadoop.fs.Path @@ -29,11 +31,14 @@ import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.internal.Logging import org.apache.spark.io._ -import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.metrics.{ExecutorMetricType, MetricsSystem} +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.{JsonProtocol, Utils} + /** * Test whether EventLoggingListener logs events properly. * @@ -43,6 +48,7 @@ import org.apache.spark.util.{JsonProtocol, Utils} */ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfter with Logging { + import EventLoggingListenerSuite._ private val fileSystem = Utils.getHadoopFileSystem("/", @@ -137,6 +143,10 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit "a fine:mind$dollar{bills}.1", None, Some("lz4"))) } + test("Executor metrics update") { + testStageExecutorMetricsEventLogging() + } + /* ----------------- * * Actual test logic * * ----------------- */ @@ -251,6 +261,214 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit } } + /** + * Test stage executor metrics logging functionality. This checks that peak + * values from SparkListenerExecutorMetricsUpdate events during a stage are + * logged in a StageExecutorMetrics event for each executor at stage completion. + */ + private def testStageExecutorMetricsEventLogging() { + val conf = getLoggingConf(testDirPath, None) + val logName = "stageExecutorMetrics-test" + val eventLogger = new EventLoggingListener(logName, None, testDirPath.toUri(), conf) + val listenerBus = new LiveListenerBus(conf) + + // Events to post. + val events = Array( + SparkListenerApplicationStart("executionMetrics", None, + 1L, "update", None), + createExecutorAddedEvent(1), + createExecutorAddedEvent(2), + createStageSubmittedEvent(0), + // receive 3 metric updates from each executor with just stage 0 running, + // with different peak updates for each executor + createExecutorMetricsUpdateEvent(1, + new ExecutorMetrics(Array(4000L, 50L, 20L, 0L, 40L, 0L, 60L, 0L, 70L, 20L))), + createExecutorMetricsUpdateEvent(2, + new ExecutorMetrics(Array(1500L, 50L, 20L, 0L, 0L, 0L, 20L, 0L, 70L, 0L))), + // exec 1: new stage 0 peaks for metrics at indexes: 2, 4, 6 + createExecutorMetricsUpdateEvent(1, + new ExecutorMetrics(Array(4000L, 50L, 50L, 0L, 50L, 0L, 100L, 0L, 70L, 20L))), + // exec 2: new stage 0 peaks for metrics at indexes: 0, 4, 6 + createExecutorMetricsUpdateEvent(2, + new ExecutorMetrics(Array(2000L, 50L, 10L, 0L, 10L, 0L, 30L, 0L, 70L, 0L))), + // exec 1: new stage 0 peaks for metrics at indexes: 5, 7 + createExecutorMetricsUpdateEvent(1, + new ExecutorMetrics(Array(2000L, 40L, 50L, 0L, 40L, 10L, 90L, 10L, 50L, 0L))), + // exec 2: new stage 0 peaks for metrics at indexes: 0, 5, 6, 7, 8 + createExecutorMetricsUpdateEvent(2, + new ExecutorMetrics(Array(3500L, 50L, 15L, 0L, 10L, 10L, 35L, 10L, 80L, 0L))), + // now start stage 1, one more metric update for each executor, and new + // peaks for some stage 1 metrics (as listed), initialize stage 1 peaks + createStageSubmittedEvent(1), + // exec 1: new stage 0 peaks for metrics at indexes: 0, 3, 7; initialize stage 1 peaks + createExecutorMetricsUpdateEvent(1, + new ExecutorMetrics(Array(5000L, 30L, 50L, 20L, 30L, 10L, 80L, 30L, 50L, 0L))), + // exec 2: new stage 0 peaks for metrics at indexes: 0, 1, 2, 3, 6, 7, 9; + // initialize stage 1 peaks + createExecutorMetricsUpdateEvent(2, + new ExecutorMetrics(Array(7000L, 70L, 50L, 20L, 0L, 10L, 50L, 30L, 10L, 40L))), + // complete stage 0, and 3 more updates for each executor with just + // stage 1 running + createStageCompletedEvent(0), + // exec 1: new stage 1 peaks for metrics at indexes: 0, 1, 3 + createExecutorMetricsUpdateEvent(1, + new ExecutorMetrics(Array(6000L, 70L, 20L, 30L, 10L, 0L, 30L, 30L, 30L, 0L))), + // enew ExecutorMetrics(xec 2: new stage 1 peaks for metrics at indexes: 3, 4, 7, 8 + createExecutorMetricsUpdateEvent(2, + new ExecutorMetrics(Array(5500L, 30L, 20L, 40L, 10L, 0L, 30L, 40L, 40L, 20L))), + // exec 1: new stage 1 peaks for metrics at indexes: 0, 4, 5, 7 + createExecutorMetricsUpdateEvent(1, + new ExecutorMetrics(Array(7000L, 70L, 5L, 25L, 60L, 30L, 65L, 55L, 30L, 0L))), + // exec 2: new stage 1 peak for metrics at index: 7 + createExecutorMetricsUpdateEvent(2, + new ExecutorMetrics(Array(5500L, 40L, 25L, 30L, 10L, 30L, 35L, 60L, 0L, 20L))), + // exec 1: no new stage 1 peaks + createExecutorMetricsUpdateEvent(1, + new ExecutorMetrics(Array(5500L, 70L, 15L, 20L, 55L, 20L, 70L, 40L, 20L, 0L))), + createExecutorRemovedEvent(1), + // exec 2: new stage 1 peak for metrics at index: 6 + createExecutorMetricsUpdateEvent(2, + new ExecutorMetrics(Array(4000L, 20L, 25L, 30L, 10L, 30L, 35L, 60L, 0L, 0L))), + createStageCompletedEvent(1), + SparkListenerApplicationEnd(1000L)) + + // play the events for the event logger + eventLogger.start() + listenerBus.start(Mockito.mock(classOf[SparkContext]), Mockito.mock(classOf[MetricsSystem])) + listenerBus.addToEventLogQueue(eventLogger) + events.foreach(event => listenerBus.post(event)) + listenerBus.stop() + eventLogger.stop() + + // expected StageExecutorMetrics, for the given stage id and executor id + val expectedMetricsEvents: Map[(Int, String), SparkListenerStageExecutorMetrics] = + Map( + ((0, "1"), + new SparkListenerStageExecutorMetrics("1", 0, 0, + new ExecutorMetrics(Array(5000L, 50L, 50L, 20L, 50L, 10L, 100L, 30L, 70L, 20L)))), + ((0, "2"), + new SparkListenerStageExecutorMetrics("2", 0, 0, + new ExecutorMetrics(Array(7000L, 70L, 50L, 20L, 10L, 10L, 50L, 30L, 80L, 40L)))), + ((1, "1"), + new SparkListenerStageExecutorMetrics("1", 1, 0, + new ExecutorMetrics(Array(7000L, 70L, 50L, 30L, 60L, 30L, 80L, 55L, 50L, 0L)))), + ((1, "2"), + new SparkListenerStageExecutorMetrics("2", 1, 0, + new ExecutorMetrics(Array(7000L, 70L, 50L, 40L, 10L, 30L, 50L, 60L, 40L, 40L))))) + + // Verify the log file contains the expected events. + // Posted events should be logged, except for ExecutorMetricsUpdate events -- these + // are consolidated, and the peak values for each stage are logged at stage end. + val logData = EventLoggingListener.openEventLog(new Path(eventLogger.logPath), fileSystem) + try { + val lines = readLines(logData) + val logStart = SparkListenerLogStart(SPARK_VERSION) + assert(lines.size === 14) + assert(lines(0).contains("SparkListenerLogStart")) + assert(lines(1).contains("SparkListenerApplicationStart")) + assert(JsonProtocol.sparkEventFromJson(parse(lines(0))) === logStart) + var logIdx = 1 + events.foreach {event => + event match { + case metricsUpdate: SparkListenerExecutorMetricsUpdate => + case stageCompleted: SparkListenerStageCompleted => + val execIds = Set[String]() + (1 to 2).foreach { _ => + val execId = checkStageExecutorMetrics(lines(logIdx), + stageCompleted.stageInfo.stageId, expectedMetricsEvents) + execIds += execId + logIdx += 1 + } + assert(execIds.size == 2) // check that each executor was logged + checkEvent(lines(logIdx), event) + logIdx += 1 + case _ => + checkEvent(lines(logIdx), event) + logIdx += 1 + } + } + } finally { + logData.close() + } + } + + private def createStageSubmittedEvent(stageId: Int) = { + SparkListenerStageSubmitted(new StageInfo(stageId, 0, stageId.toString, 0, + Seq.empty, Seq.empty, "details")) + } + + private def createStageCompletedEvent(stageId: Int) = { + SparkListenerStageCompleted(new StageInfo(stageId, 0, stageId.toString, 0, + Seq.empty, Seq.empty, "details")) + } + + private def createExecutorAddedEvent(executorId: Int) = { + SparkListenerExecutorAdded(0L, executorId.toString, new ExecutorInfo("host1", 1, Map.empty)) + } + + private def createExecutorRemovedEvent(executorId: Int) = { + SparkListenerExecutorRemoved(0L, executorId.toString, "test") + } + + private def createExecutorMetricsUpdateEvent( + executorId: Int, + executorMetrics: ExecutorMetrics): SparkListenerExecutorMetricsUpdate = { + val taskMetrics = TaskMetrics.empty + taskMetrics.incDiskBytesSpilled(111) + taskMetrics.incMemoryBytesSpilled(222) + val accum = Array((333L, 1, 1, taskMetrics.accumulators().map(AccumulatorSuite.makeInfo))) + SparkListenerExecutorMetricsUpdate(executorId.toString, accum, Some(executorMetrics)) + } + + /** Check that the Spark history log line matches the expected event. */ + private def checkEvent(line: String, event: SparkListenerEvent): Unit = { + assert(line.contains(event.getClass.toString.split("\\.").last)) + val parsed = JsonProtocol.sparkEventFromJson(parse(line)) + assert(parsed.getClass === event.getClass) + (event, parsed) match { + case (expected: SparkListenerStageSubmitted, actual: SparkListenerStageSubmitted) => + // accumulables can be different, so only check the stage Id + assert(expected.stageInfo.stageId == actual.stageInfo.stageId) + case (expected: SparkListenerStageCompleted, actual: SparkListenerStageCompleted) => + // accumulables can be different, so only check the stage Id + assert(expected.stageInfo.stageId == actual.stageInfo.stageId) + case (expected: SparkListenerEvent, actual: SparkListenerEvent) => + assert(expected === actual) + } + } + + /** + * Check that the Spark history log line is an StageExecutorMetrics event, and matches the + * expected value for the stage and executor. + * + * @param line the Spark history log line + * @param stageId the stage ID the ExecutorMetricsUpdate is associated with + * @param expectedEvents map of expected ExecutorMetricsUpdate events, for (stageId, executorId) + */ + private def checkStageExecutorMetrics( + line: String, + stageId: Int, + expectedEvents: Map[(Int, String), SparkListenerStageExecutorMetrics]): String = { + JsonProtocol.sparkEventFromJson(parse(line)) match { + case executorMetrics: SparkListenerStageExecutorMetrics => + expectedEvents.get((stageId, executorMetrics.execId)) match { + case Some(expectedMetrics) => + assert(executorMetrics.execId === expectedMetrics.execId) + assert(executorMetrics.stageId === expectedMetrics.stageId) + assert(executorMetrics.stageAttemptId === expectedMetrics.stageAttemptId) + ExecutorMetricType.values.foreach { metricType => + assert(executorMetrics.executorMetrics.getMetricValue(metricType) === + expectedMetrics.executorMetrics.getMetricValue(metricType)) + } + case None => + assert(false) + } + executorMetrics.execId + case _ => + fail("expecting SparkListenerStageExecutorMetrics") + } + } + private def readLines(in: InputStream): Seq[String] = { Source.fromInputStream(in).getLines().toSeq } @@ -299,6 +517,7 @@ object EventLoggingListenerSuite { conf.set("spark.eventLog.compress", "true") conf.set("spark.io.compression.codec", codec) } + conf.set("spark.eventLog.logStageExecutorMetrics.enabled", "true") conf } diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index b4705914b999b..0621c98d41184 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.AccumulatorV2 @@ -92,5 +93,6 @@ private class DummyTaskScheduler extends TaskScheduler { def executorHeartbeatReceived( execId: String, accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], - blockManagerId: BlockManagerId): Boolean = true + blockManagerId: BlockManagerId, + executorMetrics: ExecutorMetrics): Boolean = true } diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index e24d550a62665..d1113c7e0b103 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -22,6 +22,7 @@ import java.net.URI import java.util.concurrent.atomic.AtomicInteger import org.apache.hadoop.fs.Path +import org.json4s.JsonAST.JValue import org.json4s.jackson.JsonMethods._ import org.scalatest.BeforeAndAfter @@ -217,7 +218,9 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp // Verify the same events are replayed in the same order assert(sc.eventLogger.isDefined) - val originalEvents = sc.eventLogger.get.loggedEvents + val originalEvents = sc.eventLogger.get.loggedEvents.filter { e => + !JsonProtocol.sparkEventFromJson(e).isInstanceOf[SparkListenerStageExecutorMetrics] + } val replayedEvents = eventMonster.loggedEvents originalEvents.zip(replayedEvents).foreach { case (e1, e2) => // Don't compare the JSON here because accumulators in StageInfo may be out of order diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index ea80fea905340..d0c2dc4ad1337 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -22,18 +22,19 @@ import java.lang.{Integer => JInteger, Long => JLong} import java.util.{Arrays, Date, Properties} import scala.collection.JavaConverters._ +import scala.collection.immutable.Map import scala.reflect.{classTag, ClassTag} import org.scalatest.BeforeAndAfter import org.apache.spark._ -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} +import org.apache.spark.metrics.ExecutorMetricType import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster._ import org.apache.spark.status.api.v1 import org.apache.spark.storage._ import org.apache.spark.util.Utils -import org.apache.spark.util.kvstore._ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { @@ -1263,6 +1264,130 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } + test("executor metrics updates") { + val listener = new AppStatusListener(store, conf, true) + + val driver = BlockManagerId(SparkContext.DRIVER_IDENTIFIER, "localhost", 42) + + listener.onExecutorAdded(createExecutorAddedEvent(1)) + listener.onExecutorAdded(createExecutorAddedEvent(2)) + listener.onStageSubmitted(createStageSubmittedEvent(0)) + // receive 3 metric updates from each executor with just stage 0 running, + // with different peak updates for each executor + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(1, + Array(4000L, 50L, 20L, 0L, 40L, 0L, 60L, 0L, 70L, 20L))) + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(2, + Array(1500L, 50L, 20L, 0L, 0L, 0L, 20L, 0L, 70L, 0L))) + // exec 1: new stage 0 peaks for metrics at indexes: 2, 4, 6 + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(1, + Array(4000L, 50L, 50L, 0L, 50L, 0L, 100L, 0L, 70L, 20L))) + // exec 2: new stage 0 peaks for metrics at indexes: 0, 4, 6 + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(2, + Array(2000L, 50L, 10L, 0L, 10L, 0L, 30L, 0L, 70L, 0L))) + // exec 1: new stage 0 peaks for metrics at indexes: 5, 7 + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(1, + Array(2000L, 40L, 50L, 0L, 40L, 10L, 90L, 10L, 50L, 0L))) + // exec 2: new stage 0 peaks for metrics at indexes: 0, 5, 6, 7, 8 + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(2, + Array(3500L, 50L, 15L, 0L, 10L, 10L, 35L, 10L, 80L, 0L))) + // now start stage 1, one more metric update for each executor, and new + // peaks for some stage 1 metrics (as listed), initialize stage 1 peaks + listener.onStageSubmitted(createStageSubmittedEvent(1)) + // exec 1: new stage 0 peaks for metrics at indexes: 0, 3, 7 + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(1, + Array(5000L, 30L, 50L, 20L, 30L, 10L, 80L, 30L, 50L, 0L))) + // exec 2: new stage 0 peaks for metrics at indexes: 0, 1, 2, 3, 6, 7, 9 + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(2, + Array(7000L, 80L, 50L, 20L, 0L, 10L, 50L, 30L, 10L, 40L))) + // complete stage 0, and 3 more updates for each executor with just + // stage 1 running + listener.onStageCompleted(createStageCompletedEvent(0)) + // exec 1: new stage 1 peaks for metrics at indexes: 0, 1, 3 + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(1, + Array(6000L, 70L, 20L, 30L, 10L, 0L, 30L, 30L, 30L, 0L))) + // exec 2: new stage 1 peaks for metrics at indexes: 3, 4, 7, 8 + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(2, + Array(5500L, 30L, 20L, 40L, 10L, 0L, 30L, 40L, 40L, 20L))) + // exec 1: new stage 1 peaks for metrics at indexes: 0, 4, 5, 7 + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(1, + Array(7000L, 70L, 5L, 25L, 60L, 30L, 65L, 55L, 30L, 0L))) + // exec 2: new stage 1 peak for metrics at index: 7 + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(2, + Array(5500L, 40L, 25L, 30L, 10L, 30L, 35L, 60L, 0L, 20L))) + // exec 1: no new stage 1 peaks + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(1, + Array(5500L, 70L, 15L, 20L, 55L, 20L, 70L, 40L, 20L, 0L))) + listener.onExecutorRemoved(createExecutorRemovedEvent(1)) + // exec 2: new stage 1 peak for metrics at index: 6 + listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(2, + Array(4000L, 20L, 25L, 30L, 10L, 30L, 35L, 60L, 0L, 0L))) + listener.onStageCompleted(createStageCompletedEvent(1)) + + // expected peak values for each executor + val expectedValues = Map( + "1" -> new ExecutorMetrics(Array(7000L, 70L, 50L, 30L, 60L, 30L, 100L, 55L, 70L, 20L)), + "2" -> new ExecutorMetrics(Array(7000L, 80L, 50L, 40L, 10L, 30L, 50L, 60L, 80L, 40L))) + + // check that the stored peak values match the expected values + expectedValues.foreach { case (id, metrics) => + check[ExecutorSummaryWrapper](id) { exec => + assert(exec.info.id === id) + exec.info.peakMemoryMetrics match { + case Some(actual) => + ExecutorMetricType.values.foreach { metricType => + assert(actual.getMetricValue(metricType) === metrics.getMetricValue(metricType)) + } + case _ => + assert(false) + } + } + } + } + + test("stage executor metrics") { + // simulate reading in StageExecutorMetrics events from the history log + val listener = new AppStatusListener(store, conf, false) + val driver = BlockManagerId(SparkContext.DRIVER_IDENTIFIER, "localhost", 42) + + listener.onExecutorAdded(createExecutorAddedEvent(1)) + listener.onExecutorAdded(createExecutorAddedEvent(2)) + listener.onStageSubmitted(createStageSubmittedEvent(0)) + listener.onStageSubmitted(createStageSubmittedEvent(1)) + listener.onStageExecutorMetrics(SparkListenerStageExecutorMetrics("1", 0, 0, + new ExecutorMetrics(Array(5000L, 50L, 50L, 20L, 50L, 10L, 100L, 30L, 70L, 20L)))) + listener.onStageExecutorMetrics(SparkListenerStageExecutorMetrics("2", 0, 0, + new ExecutorMetrics(Array(7000L, 70L, 50L, 20L, 10L, 10L, 50L, 30L, 80L, 40L)))) + listener.onStageCompleted(createStageCompletedEvent(0)) + // executor 1 is removed before stage 1 has finished, the stage executor metrics + // are logged afterwards and should still be used to update the executor metrics. + listener.onExecutorRemoved(createExecutorRemovedEvent(1)) + listener.onStageExecutorMetrics(SparkListenerStageExecutorMetrics("1", 1, 0, + new ExecutorMetrics(Array(7000L, 70L, 50L, 30L, 60L, 30L, 80L, 55L, 50L, 0L)))) + listener.onStageExecutorMetrics(SparkListenerStageExecutorMetrics("2", 1, 0, + new ExecutorMetrics(Array(7000L, 80L, 50L, 40L, 10L, 30L, 50L, 60L, 40L, 40L)))) + listener.onStageCompleted(createStageCompletedEvent(1)) + + // expected peak values for each executor + val expectedValues = Map( + "1" -> new ExecutorMetrics(Array(7000L, 70L, 50L, 30L, 60L, 30L, 100L, 55L, 70L, 20L)), + "2" -> new ExecutorMetrics(Array(7000L, 80L, 50L, 40L, 10L, 30L, 50L, 60L, 80L, 40L))) + + // check that the stored peak values match the expected values + for ((id, metrics) <- expectedValues) { + check[ExecutorSummaryWrapper](id) { exec => + assert(exec.info.id === id) + exec.info.peakMemoryMetrics match { + case Some(actual) => + ExecutorMetricType.values.foreach { metricType => + assert(actual.getMetricValue(metricType) === metrics.getMetricValue(metricType)) + } + case _ => + assert(false) + } + } + } + } + private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptNumber) private def check[T: ClassTag](key: Any)(fn: T => Unit): Unit = { @@ -1300,4 +1425,37 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } + /** Create a stage submitted event for the specified stage Id. */ + private def createStageSubmittedEvent(stageId: Int) = { + SparkListenerStageSubmitted(new StageInfo(stageId, 0, stageId.toString, 0, + Seq.empty, Seq.empty, "details")) + } + + /** Create a stage completed event for the specified stage Id. */ + private def createStageCompletedEvent(stageId: Int) = { + SparkListenerStageCompleted(new StageInfo(stageId, 0, stageId.toString, 0, + Seq.empty, Seq.empty, "details")) + } + + /** Create an executor added event for the specified executor Id. */ + private def createExecutorAddedEvent(executorId: Int) = { + SparkListenerExecutorAdded(0L, executorId.toString, new ExecutorInfo("host1", 1, Map.empty)) + } + + /** Create an executor added event for the specified executor Id. */ + private def createExecutorRemovedEvent(executorId: Int) = { + SparkListenerExecutorRemoved(10L, executorId.toString, "test") + } + + /** Create an executor metrics update event, with the specified executor metrics values. */ + private def createExecutorMetricsUpdateEvent( + executorId: Int, + executorMetrics: Array[Long]): SparkListenerExecutorMetricsUpdate = { + val taskMetrics = TaskMetrics.empty + taskMetrics.incDiskBytesSpilled(111) + taskMetrics.incMemoryBytesSpilled(222) + val accum = Array((333L, 1, 1, taskMetrics.accumulators().map(AccumulatorSuite.makeInfo))) + SparkListenerExecutorMetricsUpdate(executorId.toString, accum, + Some(new ExecutorMetrics(executorMetrics))) + } } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 74b72d940eeef..1e0d2af9a4711 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -30,6 +30,7 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark._ import org.apache.spark.executor._ +import org.apache.spark.metrics.ExecutorMetricType import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -94,11 +95,17 @@ class JsonProtocolSuite extends SparkFunSuite { makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, hasHadoopInput = true, hasOutput = true) .accumulators().map(AccumulatorSuite.makeInfo) .zipWithIndex.map { case (a, i) => a.copy(id = i) } - SparkListenerExecutorMetricsUpdate("exec3", Seq((1L, 2, 3, accumUpdates))) + val executorUpdates = new ExecutorMetrics( + Array(543L, 123456L, 12345L, 1234L, 123L, 12L, 432L, 321L, 654L, 765L)) + SparkListenerExecutorMetricsUpdate("exec3", Seq((1L, 2, 3, accumUpdates)), + Some(executorUpdates)) } val blockUpdated = SparkListenerBlockUpdated(BlockUpdatedInfo(BlockManagerId("Stars", "In your multitude...", 300), RDDBlockId(0, 0), StorageLevel.MEMORY_ONLY, 100L, 0L)) + val stageExecutorMetrics = + SparkListenerStageExecutorMetrics("1", 2, 3, + new ExecutorMetrics(Array(543L, 123456L, 12345L, 1234L, 123L, 12L, 432L, 321L, 654L, 765L))) testEvent(stageSubmitted, stageSubmittedJsonString) testEvent(stageCompleted, stageCompletedJsonString) @@ -124,6 +131,7 @@ class JsonProtocolSuite extends SparkFunSuite { testEvent(nodeUnblacklisted, nodeUnblacklistedJsonString) testEvent(executorMetricsUpdate, executorMetricsUpdateJsonString) testEvent(blockUpdated, blockUpdatedJsonString) + testEvent(stageExecutorMetrics, stageExecutorMetricsJsonString) } test("Dependent Classes") { @@ -419,6 +427,30 @@ class JsonProtocolSuite extends SparkFunSuite { exceptionFailure.accumUpdates, oldExceptionFailure.accumUpdates, (x, y) => x == y) } + test("ExecutorMetricsUpdate backward compatibility: executor metrics update") { + // executorMetricsUpdate was added in 2.4.0. + val executorMetricsUpdate = makeExecutorMetricsUpdate("1", true, true) + val oldExecutorMetricsUpdateJson = + JsonProtocol.executorMetricsUpdateToJson(executorMetricsUpdate) + .removeField( _._1 == "Executor Metrics Updated") + val exepectedExecutorMetricsUpdate = makeExecutorMetricsUpdate("1", true, false) + assertEquals(exepectedExecutorMetricsUpdate, + JsonProtocol.executorMetricsUpdateFromJson(oldExecutorMetricsUpdateJson)) + } + + test("executorMetricsFromJson backward compatibility: handle missing metrics") { + // any missing metrics should be set to 0 + val executorMetrics = new ExecutorMetrics( + Array(12L, 23L, 45L, 67L, 78L, 89L, 90L, 123L, 456L, 789L)) + val oldExecutorMetricsJson = + JsonProtocol.executorMetricsToJson(executorMetrics) + .removeField( _._1 == "MappedPoolMemory") + val expectedExecutorMetrics = new ExecutorMetrics( + Array(12L, 23L, 45L, 67L, 78L, 89L, 90L, 123L, 456L, 0L)) + assertEquals(expectedExecutorMetrics, + JsonProtocol.executorMetricsFromJson(oldExecutorMetricsJson)) + } + test("AccumulableInfo value de/serialization") { import InternalAccumulator._ val blocks = Seq[(BlockId, BlockStatus)]( @@ -435,7 +467,6 @@ class JsonProtocolSuite extends SparkFunSuite { testAccumValue(Some("anything"), blocks, JString(blocks.toString)) testAccumValue(Some("anything"), 123, JString("123")) } - } @@ -565,6 +596,13 @@ private[spark] object JsonProtocolSuite extends Assertions { assert(stageAttemptId1 === stageAttemptId2) assertSeqEquals[AccumulableInfo](updates1, updates2, (a, b) => a.equals(b)) }) + assertOptionEquals(e1.executorUpdates, e2.executorUpdates, + (e1: ExecutorMetrics, e2: ExecutorMetrics) => assertEquals(e1, e2)) + case (e1: SparkListenerStageExecutorMetrics, e2: SparkListenerStageExecutorMetrics) => + assert(e1.execId === e2.execId) + assert(e1.stageId === e2.stageId) + assert(e1.stageAttemptId === e2.stageAttemptId) + assertEquals(e1.executorMetrics, e2.executorMetrics) case (e1, e2) => assert(e1 === e2) case _ => fail("Events don't match in types!") @@ -715,6 +753,12 @@ private[spark] object JsonProtocolSuite extends Assertions { assertStackTraceElementEquals) } + private def assertEquals(metrics1: ExecutorMetrics, metrics2: ExecutorMetrics) { + ExecutorMetricType.values.foreach { metricType => + assert(metrics1.getMetricValue(metricType) === metrics2.getMetricValue(metricType)) + } + } + private def assertJsonStringEquals(expected: String, actual: String, metadata: String) { val expectedJson = pretty(parse(expected)) val actualJson = pretty(parse(actual)) @@ -765,7 +809,6 @@ private[spark] object JsonProtocolSuite extends Assertions { assert(ste1 === ste2) } - /** ----------------------------------- * | Util methods for constructing events | * ------------------------------------ */ @@ -820,6 +863,27 @@ private[spark] object JsonProtocolSuite extends Assertions { new AccumulableInfo(id, Some(s"Accumulable$id"), Some(s"delta$id"), Some(s"val$id"), internal, countFailedValues, metadata) + /** Creates an SparkListenerExecutorMetricsUpdate event */ + private def makeExecutorMetricsUpdate( + execId: String, + includeTaskMetrics: Boolean, + includeExecutorMetrics: Boolean): SparkListenerExecutorMetricsUpdate = { + val taskMetrics = + if (includeTaskMetrics) { + Seq((1L, 1, 1, Seq(makeAccumulableInfo(1, false, false, None), + makeAccumulableInfo(2, false, false, None)))) + } else { + Seq() + } + val executorMetricsUpdate = + if (includeExecutorMetrics) { + Some(new ExecutorMetrics(Array(123456L, 543L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L))) + } else { + None + } + SparkListenerExecutorMetricsUpdate(execId, taskMetrics, executorMetricsUpdate) + } + /** * Creates a TaskMetrics object describing a task that read data from Hadoop (if hasHadoopInput is * set to true) or read data from a shuffle otherwise. @@ -2007,7 +2071,42 @@ private[spark] object JsonProtocolSuite extends Assertions { | } | ] | } - | ] + | ], + | "Executor Metrics Updated" : { + | "JVMHeapMemory" : 543, + | "JVMOffHeapMemory" : 123456, + | "OnHeapExecutionMemory" : 12345, + | "OffHeapExecutionMemory" : 1234, + | "OnHeapStorageMemory" : 123, + | "OffHeapStorageMemory" : 12, + | "OnHeapUnifiedMemory" : 432, + | "OffHeapUnifiedMemory" : 321, + | "DirectPoolMemory" : 654, + | "MappedPoolMemory" : 765 + | } + | + |} + """.stripMargin + + private val stageExecutorMetricsJsonString = + """ + |{ + | "Event": "SparkListenerStageExecutorMetrics", + | "Executor ID": "1", + | "Stage ID": 2, + | "Stage Attempt ID": 3, + | "Executor Metrics" : { + | "JVMHeapMemory" : 543, + | "JVMOffHeapMemory" : 123456, + | "OnHeapExecutionMemory" : 12345, + | "OffHeapExecutionMemory" : 1234, + | "OnHeapStorageMemory" : 123, + | "OffHeapStorageMemory" : 12, + | "OnHeapUnifiedMemory" : 432, + | "OffHeapUnifiedMemory" : 321, + | "DirectPoolMemory" : 654, + | "MappedPoolMemory" : 765 + | } |} """.stripMargin diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 466135e72233a..777950016801d 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -81,6 +81,7 @@ app-20180109111548-0000 app-20161115172038-0000 app-20161116163331-0000 application_1516285256255_0012 +application_1506645932520_24630151 local-1422981759269 local-1422981780767 local-1425081759269 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7ff783da130af..55dc2b81cfe2f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,12 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-23429][CORE] Add executor memory metrics to heartbeat and expose in executors REST API + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate.this"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate$"), + // [SPARK-25248] add package private methods to TaskContext ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.markTaskFailed"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.markInterrupted"), From 01c3dfab158d40653f8ce5d96f57220297545d5b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 8 Sep 2018 12:55:44 +0800 Subject: [PATCH 1574/2461] [MINOR][SQL] Add a debug log when a SQL text is used for a view ## What changes were proposed in this pull request? This took me a while to debug and find out. Looks we better at least leave a debug log that SQL text for a view will be used. Here's how I got there: **Hive:** ``` CREATE TABLE emp AS SELECT 'user' AS name, 'address' as address; CREATE DATABASE d100; CREATE FUNCTION d100.udf100 AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDFUpper'; CREATE VIEW testview AS SELECT d100.udf100(name) FROM default.emp; ``` **Spark:** ``` sql("SELECT * FROM testview").show() ``` ``` scala> sql("SELECT * FROM testview").show() org.apache.spark.sql.AnalysisException: Undefined function: 'd100.udf100'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 7 ``` Under the hood, it actually makes sense since the view is defined as `SELECT d100.udf100(name) FROM default.emp;` and Hive API: ``` org.apache.hadoop.hive.ql.metadata.Table.getViewExpandedText() ``` This returns a wrongly qualified SQL string for the view as below: ``` SELECT `d100.udf100`(`emp`.`name`) FROM `default`.`emp` ``` which works fine in Hive but not in Spark. ## How was this patch tested? Manually: ``` 18/09/06 19:32:48 DEBUG HiveSessionCatalog: 'SELECT `d100.udf100`(`emp`.`name`) FROM `default`.`emp`' will be used for the view(testview). ``` Closes #22351 from HyukjinKwon/minor-debug. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- .../org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index afb0f009db05c..c11b444212946 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -701,6 +701,7 @@ class SessionCatalog( val metadata = externalCatalog.getTable(db, table) if (metadata.tableType == CatalogTableType.VIEW) { val viewText = metadata.viewText.getOrElse(sys.error("Invalid view without text.")) + logDebug(s"'$viewText' will be used for the view($table).") // The relation is a view, so we wrap the relation by: // 1. Add a [[View]] operator over the relation to keep track of the view desc; // 2. Wrap the logical plan in a [[SubqueryAlias]] which tracks the name of the view. From 08c02e637ac601df2fe890b8b5a7a049bdb4541b Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Sat, 8 Sep 2018 09:09:14 -0700 Subject: [PATCH 1575/2461] [SPARK-25345][ML] Deprecate public APIs from ImageSchema ## What changes were proposed in this pull request? Deprecate public APIs from ImageSchema. ## How was this patch tested? N/A Closes #22349 from WeichenXu123/image_api_deprecate. Authored-by: WeichenXu Signed-off-by: Xiangrui Meng --- .../scala/org/apache/spark/ml/image/ImageSchema.scala | 4 ++++ python/pyspark/ml/image.py | 8 +++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala index dcc40b6668c7a..0b13eefdf3f5f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala @@ -198,6 +198,8 @@ object ImageSchema { * @return DataFrame with a single column "image" of images; * see ImageSchema for the details */ + @deprecated("use `spark.read.format(\"image\").load(path)` and this `readImages` will be " + + "removed in 3.0.0.", "2.4.0") def readImages(path: String): DataFrame = readImages(path, null, false, -1, false, 1.0, 0) /** @@ -218,6 +220,8 @@ object ImageSchema { * @return DataFrame with a single column "image" of images; * see ImageSchema for the details */ + @deprecated("use `spark.read.format(\"image\").load(path)` and this `readImages` will be " + + "removed in 3.0.0.", "2.4.0") def readImages( path: String, sparkSession: SparkSession, diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index ef6785b4a8ed4..edb90a3578546 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -25,8 +25,10 @@ """ import sys +import warnings import numpy as np + from pyspark import SparkContext from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string from pyspark.sql import DataFrame, SparkSession @@ -207,6 +209,9 @@ def readImages(self, path, recursive=False, numPartitions=-1, .. note:: If sample ratio is less than 1, sampling uses a PathFilter that is efficient but potentially non-deterministic. + .. note:: Deprecated in 2.4.0. Use `spark.read.format("image").load(path)` instead and + this `readImages` will be removed in 3.0.0. + :param str path: Path to the image directory. :param bool recursive: Recursive search flag. :param int numPartitions: Number of DataFrame partitions. @@ -222,7 +227,8 @@ def readImages(self, path, recursive=False, numPartitions=-1, .. versionadded:: 2.3.0 """ - + warnings.warn("`ImageSchema.readImage` is deprecated. " + + "Use `spark.read.format(\"image\").load(path)` instead.", DeprecationWarning) spark = SparkSession.builder.getOrCreate() image_schema = spark._jvm.org.apache.spark.ml.image.ImageSchema jsession = spark._jsparkSession From 26f74b7cb16869079aa7b60577ac05707101ee68 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 8 Sep 2018 10:21:55 -0700 Subject: [PATCH 1576/2461] [SPARK-25375][SQL][TEST] Reenable qualified perm. function checks in UDFSuite ## What changes were proposed in this pull request? At Spark 2.0.0, SPARK-14335 adds some [commented-out test coverages](https://github.com/apache/spark/pull/12117/files#diff-dd4b39a56fac28b1ced6184453a47358R177 ). This PR enables them because it's supported since 2.0.0. ## How was this patch tested? Pass the Jenkins with re-enabled test coverage. Closes #22363 from dongjoon-hyun/SPARK-25375. Authored-by: Dongjoon Hyun Signed-off-by: gatorsmile --- .../org/apache/spark/sql/hive/UDFSuite.scala | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 88cc42efd0fe3..a56c6f73989a7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -141,11 +141,10 @@ class UDFSuite withTempDatabase { dbName => withUserDefinedFunction(functionName -> false) { sql(s"CREATE FUNCTION $dbName.$functionName AS '$functionClass'") - // TODO: Re-enable it after can distinguish qualified and unqualified function name - // checkAnswer( - // sql(s"SELECT $dbName.myuPPer(value) from $testTableName"), - // expectedDF - // ) + checkAnswer( + sql(s"SELECT $dbName.$functionName(value) from $testTableName"), + expectedDF + ) checkAnswer( sql(s"SHOW FUNCTIONS like $dbName.$functionNameUpper"), @@ -174,11 +173,10 @@ class UDFSuite // For this block, drop function command uses default.functionName as the function name. withUserDefinedFunction(s"$dbName.$functionNameUpper" -> false) { sql(s"CREATE FUNCTION $dbName.$functionName AS '$functionClass'") - // TODO: Re-enable it after can distinguish qualified and unqualified function name - // checkAnswer( - // sql(s"SELECT $dbName.myupper(value) from $testTableName"), - // expectedDF - // ) + checkAnswer( + sql(s"SELECT $dbName.$functionName(value) from $testTableName"), + expectedDF + ) sql(s"USE $dbName") From 78981efc2cf321ce93e176d30d49bb1a8bd59eb2 Mon Sep 17 00:00:00 2001 From: ptkool Date: Sat, 8 Sep 2018 11:36:55 -0700 Subject: [PATCH 1577/2461] [SPARK-20636] Add new optimization rule to transpose adjacent Window expressions. ## What changes were proposed in this pull request? Add new optimization rule to eliminate unnecessary shuffling by flipping adjacent Window expressions. ## How was this patch tested? Tested with unit tests, integration tests, and manual tests. Closes #17899 from ptkool/adjacent_window_optimization. Authored-by: ptkool Signed-off-by: gatorsmile --- .../sql/catalyst/optimizer/Optimizer.scala | 22 ++++ .../optimizer/TransposeWindowSuite.scala | 114 ++++++++++++++++++ .../sql/DataFrameWindowFunctionsSuite.scala | 45 +++++-- 3 files changed, 170 insertions(+), 11 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e4b4f1ecbe21f..b432ce24e1ef7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -734,6 +734,28 @@ object CollapseWindow extends Rule[LogicalPlan] { } } +/** + * Transpose Adjacent Window Expressions. + * - If the partition spec of the parent Window expression is compatible with the partition spec + * of the child window expression, transpose them. + */ +object TransposeWindow extends Rule[LogicalPlan] { + private def compatibleParititions(ps1 : Seq[Expression], ps2: Seq[Expression]): Boolean = { + ps1.length < ps2.length && ps2.take(ps1.length).permutations.exists(ps1.zip(_).forall { + case (l, r) => l.semanticEquals(r) + }) + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild)) + if w1.references.intersect(w2.windowOutputSet).isEmpty && + w1.expressions.forall(_.deterministic) && + w2.expressions.forall(_.deterministic) && + compatibleParititions(ps1, ps2) => + Project(w1.output, Window(we2, ps2, os2, Window(we1, ps1, os1, grandChild))) + } +} + /** * Generate a list of additional filters from an operator's existing constraint but remove those * that are either already part of the operator's condition or are part of the operator's child diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala new file mode 100644 index 0000000000000..58b3d1c98f3cd --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Rand +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class TransposeWindowSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("CollapseProject", FixedPoint(100), CollapseProject, RemoveRedundantProject) :: + Batch("FlipWindow", Once, CollapseWindow, TransposeWindow) :: Nil + } + + val testRelation = LocalRelation('a.string, 'b.string, 'c.int, 'd.string) + + val a = testRelation.output(0) + val b = testRelation.output(1) + val c = testRelation.output(2) + val d = testRelation.output(3) + + val partitionSpec1 = Seq(a) + val partitionSpec2 = Seq(a, b) + val partitionSpec3 = Seq(d) + val partitionSpec4 = Seq(b, a, d) + + val orderSpec1 = Seq(d.asc) + val orderSpec2 = Seq(d.desc) + + test("transpose two adjacent windows with compatible partitions") { + val query = testRelation + .window(Seq(sum(c).as('sum_a_2)), partitionSpec2, orderSpec2) + .window(Seq(sum(c).as('sum_a_1)), partitionSpec1, orderSpec1) + + val analyzed = query.analyze + val optimized = Optimize.execute(analyzed) + + val correctAnswer = testRelation + .window(Seq(sum(c).as('sum_a_1)), partitionSpec1, orderSpec1) + .window(Seq(sum(c).as('sum_a_2)), partitionSpec2, orderSpec2) + .select('a, 'b, 'c, 'd, 'sum_a_2, 'sum_a_1) + + comparePlans(optimized, correctAnswer.analyze) + } + + test("transpose two adjacent windows with differently ordered compatible partitions") { + val query = testRelation + .window(Seq(sum(c).as('sum_a_2)), partitionSpec4, Seq.empty) + .window(Seq(sum(c).as('sum_a_1)), partitionSpec2, Seq.empty) + + val analyzed = query.analyze + val optimized = Optimize.execute(analyzed) + + val correctAnswer = testRelation + .window(Seq(sum(c).as('sum_a_1)), partitionSpec2, Seq.empty) + .window(Seq(sum(c).as('sum_a_2)), partitionSpec4, Seq.empty) + .select('a, 'b, 'c, 'd, 'sum_a_2, 'sum_a_1) + + comparePlans(optimized, correctAnswer.analyze) + } + + test("don't transpose two adjacent windows with incompatible partitions") { + val query = testRelation + .window(Seq(sum(c).as('sum_a_2)), partitionSpec3, Seq.empty) + .window(Seq(sum(c).as('sum_a_1)), partitionSpec1, Seq.empty) + + val analyzed = query.analyze + val optimized = Optimize.execute(analyzed) + + comparePlans(optimized, analyzed) + } + + test("don't transpose two adjacent windows with intersection of partition and output set") { + val query = testRelation + .window(Seq(('a + 'b).as('e), sum(c).as('sum_a_2)), partitionSpec3, Seq.empty) + .window(Seq(sum(c).as('sum_a_1)), Seq(a, 'e), Seq.empty) + + val analyzed = query.analyze + val optimized = Optimize.execute(analyzed) + + comparePlans(optimized, analyzed) + } + + test("don't transpose two adjacent windows with non-deterministic expressions") { + val query = testRelation + .window(Seq(Rand(0).as('e), sum(c).as('sum_a_2)), partitionSpec3, Seq.empty) + .window(Seq(sum(c).as('sum_a_1)), partitionSpec1, Seq.empty) + + val analyzed = query.analyze + val optimized = Optimize.execute(analyzed) + + comparePlans(optimized, analyzed) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 97a843978f0bd..78277d7dcf757 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.types._ * Window function testing for DataFrame API. */ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("reuse window partitionBy") { @@ -72,9 +73,9 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { cume_dist().over(Window.partitionBy("value").orderBy("key")), percent_rank().over(Window.partitionBy("value").orderBy("key"))), Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d, 0.0d) :: - Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d) :: - Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d) :: - Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil) + Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d) :: + Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d) :: + Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil) } test("window function should fail if order by clause is not specified") { @@ -162,12 +163,12 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Seq( Row("a", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), Row("b", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), - Row("c", 0.0, 0.0, 0.0, 0.0, 0.0 ), - Row("d", 0.0, 0.0, 0.0, 0.0, 0.0 ), - Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), - Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), - Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), - Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("c", 0.0, 0.0, 0.0, 0.0, 0.0), + Row("d", 0.0, 0.0, 0.0, 0.0, 0.0), + Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), + Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), + Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), + Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), Row("i", Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN))) } @@ -326,7 +327,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { var_samp($"value").over(window), approx_count_distinct($"value").over(window)), Seq.fill(4)(Row("a", 1.0d / 4.0d, 1.0d / 3.0d, 2)) - ++ Seq.fill(3)(Row("b", 2.0d / 3.0d, 1.0d, 3))) + ++ Seq.fill(3)(Row("b", 2.0d / 3.0d, 1.0d, 3))) } test("window function with aggregates") { @@ -624,7 +625,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { test("SPARK-24575: Window functions inside WHERE and HAVING clauses") { def checkAnalysisError(df: => DataFrame): Unit = { - val thrownException = the [AnalysisException] thrownBy { + val thrownException = the[AnalysisException] thrownBy { df.queryExecution.analyzed } assert(thrownException.message.contains("window functions inside WHERE and HAVING clauses")) @@ -658,4 +659,26 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { |GROUP BY a |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin)) } + + test("window functions in multiple selects") { + val df = Seq( + ("S1", "P1", 100), + ("S1", "P1", 700), + ("S2", "P1", 200), + ("S2", "P2", 300) + ).toDF("sno", "pno", "qty") + + val w1 = Window.partitionBy("sno") + val w2 = Window.partitionBy("sno", "pno") + + checkAnswer( + df.select($"sno", $"pno", $"qty", sum($"qty").over(w2).alias("sum_qty_2")) + .select($"sno", $"pno", $"qty", col("sum_qty_2"), sum("qty").over(w1).alias("sum_qty_1")), + Seq( + Row("S1", "P1", 100, 800, 800), + Row("S1", "P1", 700, 800, 800), + Row("S2", "P1", 200, 200, 500), + Row("S2", "P2", 300, 300, 500))) + + } } From 1cfda448255d5b4a0df88148e0f6acd88aa6e318 Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Sat, 8 Sep 2018 22:18:06 -0700 Subject: [PATCH 1578/2461] [SPARK-25021][K8S] Add spark.executor.pyspark.memory limit for K8S ## What changes were proposed in this pull request? Add spark.executor.pyspark.memory limit for K8S ## How was this patch tested? Unit and Integration tests Closes #22298 from ifilonenko/SPARK-25021. Authored-by: Ilan Filonenko Signed-off-by: Holden Karau --- dev/make-distribution.sh | 1 + docs/configuration.md | 2 +- .../org/apache/spark/deploy/k8s/Config.scala | 7 +++ .../features/BasicExecutorFeatureStep.scala | 14 +++++- .../bindings/JavaDriverFeatureStep.scala | 4 +- .../bindings/PythonDriverFeatureStep.scala | 4 +- .../bindings/RDriverFeatureStep.scala | 4 +- .../BasicDriverFeatureStepSuite.scala | 1 - .../BasicExecutorFeatureStepSuite.scala | 24 ++++++++++ .../bindings/JavaDriverFeatureStepSuite.scala | 1 - .../src/main/dockerfiles/spark/Dockerfile | 1 + .../dockerfiles/spark/bindings/R/Dockerfile | 2 +- .../spark/bindings/python/Dockerfile | 2 +- .../k8s/integrationtest/KubernetesSuite.scala | 33 +++++++++++++ .../integrationtest/PythonTestsSuite.scala | 34 +++++++++++--- .../integrationtest/SecretsTestsSuite.scala | 1 + .../tests}/py_container_checks.py | 0 .../integration-tests/tests}/pyfiles.py | 0 .../tests/worker_memory_check.py | 47 +++++++++++++++++++ 19 files changed, 166 insertions(+), 16 deletions(-) rename {examples/src/main/python => resource-managers/kubernetes/integration-tests/tests}/py_container_checks.py (100%) rename {examples/src/main/python => resource-managers/kubernetes/integration-tests/tests}/pyfiles.py (100%) create mode 100644 resource-managers/kubernetes/integration-tests/tests/worker_memory_check.py diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index ad99ce55806af..778d376c12b56 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -192,6 +192,7 @@ fi if [ -d "$SPARK_HOME"/resource-managers/kubernetes/core/target/ ]; then mkdir -p "$DISTDIR/kubernetes/" cp -a "$SPARK_HOME"/resource-managers/kubernetes/docker/src/main/dockerfiles "$DISTDIR/kubernetes/" + cp -a "$SPARK_HOME"/resource-managers/kubernetes/integration-tests/tests "$DISTDIR/kubernetes/" fi # Copy examples and dependencies diff --git a/docs/configuration.md b/docs/configuration.md index f344bcd20087d..3a8d56776e9e8 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -188,7 +188,7 @@ of the most common options to set are: unless otherwise specified. If set, PySpark memory for an executor will be limited to this amount. If not set, Spark will not limit Python's memory use and it is up to the application to avoid exceeding the overhead memory space - shared with other non-JVM processes. When PySpark is run in YARN, this memory + shared with other non-JVM processes. When PySpark is run in YARN or Kubernetes, this memory is added to executor resource requests. diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index c5f4d6c53b7f9..71e4d321a0e3a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -225,6 +225,13 @@ private[spark] object Config extends Logging { "Ensure that major Python version is either Python2 or Python3") .createWithDefault("2") + val APP_RESOURCE_TYPE = + ConfigBuilder("spark.kubernetes.resource.type") + .doc("This sets the resource type internally") + .internal() + .stringConf + .createOptional + val KUBERNETES_LOCAL_DIRS_TMPFS = ConfigBuilder("spark.kubernetes.local.dirs.tmpfs") .doc("If set to true then emptyDir volumes created to back SPARK_LOCAL_DIRS will have " + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index c37f713c56de1..d89995ba5e4f4 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -24,7 +24,7 @@ import org.apache.spark.SparkException import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} +import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD, PYSPARK_EXECUTOR_MEMORY} import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils @@ -58,6 +58,16 @@ private[spark] class BasicExecutorFeatureStep( (kubernetesConf.get(MEMORY_OVERHEAD_FACTOR) * executorMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB + private val executorMemoryTotal = kubernetesConf.sparkConf + .getOption(APP_RESOURCE_TYPE.key).map{ res => + val additionalPySparkMemory = res match { + case "python" => + kubernetesConf.sparkConf + .get(PYSPARK_EXECUTOR_MEMORY).map(_.toInt).getOrElse(0) + case _ => 0 + } + executorMemoryWithOverhead + additionalPySparkMemory + }.getOrElse(executorMemoryWithOverhead) private val executorCores = kubernetesConf.sparkConf.getInt("spark.executor.cores", 1) private val executorCoresRequest = @@ -76,7 +86,7 @@ private[spark] class BasicExecutorFeatureStep( // executorId val hostname = name.substring(Math.max(0, name.length - 63)) val executorMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${executorMemoryWithOverhead}Mi") + .withAmount(s"${executorMemoryTotal}Mi") .build() val executorCpuQuantity = new QuantityBuilder(false) .withAmount(executorCoresRequest) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala index f52ec9fdc677e..6f063b253cd73 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy.k8s.features.bindings import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config.APP_RESOURCE_TYPE import org.apache.spark.deploy.k8s.Constants.SPARK_CONF_PATH import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep import org.apache.spark.launcher.SparkLauncher @@ -38,7 +39,8 @@ private[spark] class JavaDriverFeatureStep( .build() SparkPod(pod.pod, withDriverArgs) } - override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + override def getAdditionalPodSystemProperties(): Map[String, String] = + Map(APP_RESOURCE_TYPE.key -> "java") override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala index 406944a953382..cf0c03b22bd7e 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala @@ -21,6 +21,7 @@ import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, HasMetadata} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s.Config.APP_RESOURCE_TYPE import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep @@ -68,7 +69,8 @@ private[spark] class PythonDriverFeatureStep( SparkPod(pod.pod, withPythonPrimaryContainer) } - override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + override def getAdditionalPodSystemProperties(): Map[String, String] = + Map(APP_RESOURCE_TYPE.key -> "python") override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStep.scala index 11b09b399618b..1a7ef52fefe70 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStep.scala @@ -21,6 +21,7 @@ import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, HasMetadata} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s.Config.APP_RESOURCE_TYPE import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep @@ -54,7 +55,8 @@ private[spark] class RDriverFeatureStep( SparkPod(pod.pod, withRPrimaryContainer) } - override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + override def getAdditionalPodSystemProperties(): Map[String, String] = + Map(APP_RESOURCE_TYPE.key -> "r") override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala index d98e113554648..0968cce971c31 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -57,7 +57,6 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { MAIN_CLASS, APP_ARGS) - test("Check the pod respects all configurations from the user.") { val sparkConf = new SparkConf() .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index 95d373f791649..63b237b9dfe46 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -75,6 +75,7 @@ class BasicExecutorFeatureStepSuite .set("spark.driver.host", DRIVER_HOSTNAME) .set("spark.driver.port", DRIVER_PORT.toString) .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) + .set("spark.kubernetes.resource.type", "java") } test("basic executor pod has reasonable defaults") { @@ -161,6 +162,29 @@ class BasicExecutorFeatureStepSuite checkOwnerReferences(executor.pod, DRIVER_POD_UID) } + test("test executor pyspark memory") { + val conf = baseConf.clone() + conf.set("spark.kubernetes.resource.type", "python") + conf.set(org.apache.spark.internal.config.PYSPARK_EXECUTOR_MEMORY, 42L) + + val step = new BasicExecutorFeatureStep( + KubernetesConf( + conf, + KubernetesExecutorSpecificConf("1", Some(DRIVER_POD)), + RESOURCE_NAME_PREFIX, + APP_ID, + LABELS, + ANNOTATIONS, + Map.empty, + Map.empty, + Map.empty, + Nil, + Seq.empty[String])) + val executor = step.configurePod(SparkPod.initialPod()) + // This is checking that basic executor + executorMemory = 1408 + 42 = 1450 + assert(executor.container.getResources.getRequests.get("memory").getAmount === "1450Mi") + } + // There is always exactly one controller reference, and it points to the driver pod. private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { assert(executor.getMetadata.getOwnerReferences.size() === 1) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala index 18874afe6e53a..bf552aeb8b901 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala @@ -56,6 +56,5 @@ class JavaDriverFeatureStepSuite extends SparkFunSuite { "--properties-file", SPARK_CONF_PATH, "--class", "test-class", "spark-internal", "5 7")) - } } diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index 071aa2020dd85..7ae57bf6e42d0 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -43,6 +43,7 @@ COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin COPY ${img_path}/spark/entrypoint.sh /opt/ COPY examples /opt/spark/examples +COPY kubernetes/tests /opt/spark/tests COPY data /opt/spark/data ENV SPARK_HOME /opt/spark diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile index e627883ba782e..9f67422efeb3c 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile @@ -19,10 +19,10 @@ ARG base_img FROM $base_img WORKDIR / RUN mkdir ${SPARK_HOME}/R -COPY R ${SPARK_HOME}/R RUN apk add --no-cache R R-dev +COPY R ${SPARK_HOME}/R ENV R_HOME /usr/lib/R WORKDIR /opt/spark/work-dir diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile index 72bb9620b45de..69b6efa6149a0 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile @@ -19,7 +19,6 @@ ARG base_img FROM $base_img WORKDIR / RUN mkdir ${SPARK_HOME}/python -COPY python/lib ${SPARK_HOME}/python/lib # TODO: Investigate running both pip and pip3 via virtualenvs RUN apk add --no-cache python && \ apk add --no-cache python3 && \ @@ -33,6 +32,7 @@ RUN apk add --no-cache python && \ # Removed the .cache to save space rm -r /root/.cache +COPY python/lib ${SPARK_HOME}/python/lib ENV PYTHONPATH ${SPARK_HOME}/python/lib/pyspark.zip:${SPARK_HOME}/python/lib/py4j-*.zip WORKDIR /opt/spark/work-dir diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 896a83a5badbb..82e6efa2707d9 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -50,6 +50,17 @@ private[spark] class KubernetesSuite extends SparkFunSuite protected var containerLocalSparkDistroExamplesJar: String = _ protected var appLocator: String = _ + // Default memory limit is 1024M + 384M (minimum overhead constant) + private val baseMemory = s"${1024 + 384}Mi" + protected val memOverheadConstant = 0.8 + private val standardNonJVMMemory = s"${(1024 + 0.4*1024).toInt}Mi" + protected val additionalMemory = 200 + // 209715200 is 200Mi + protected val additionalMemoryInBytes = 209715200 + private val extraDriverTotalMemory = s"${(1024 + memOverheadConstant*1024).toInt}Mi" + private val extraExecTotalMemory = + s"${(1024 + memOverheadConstant*1024 + additionalMemory).toInt}Mi" + override def beforeAll(): Unit = { // The scalatest-maven-plugin gives system properties that are referenced but not set null // values. We need to remove the null-value properties before initializing the test backend. @@ -233,6 +244,8 @@ private[spark] class KubernetesSuite extends SparkFunSuite assert(driverPod.getMetadata.getName === driverPodName) assert(driverPod.getSpec.getContainers.get(0).getImage === image) assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") + assert(driverPod.getSpec.getContainers.get(0).getResources.getRequests.get("memory").getAmount + === baseMemory) } @@ -240,28 +253,48 @@ private[spark] class KubernetesSuite extends SparkFunSuite assert(driverPod.getMetadata.getName === driverPodName) assert(driverPod.getSpec.getContainers.get(0).getImage === pyImage) assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") + assert(driverPod.getSpec.getContainers.get(0).getResources.getRequests.get("memory").getAmount + === standardNonJVMMemory) } protected def doBasicDriverRPodCheck(driverPod: Pod): Unit = { assert(driverPod.getMetadata.getName === driverPodName) assert(driverPod.getSpec.getContainers.get(0).getImage === rImage) assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") + assert(driverPod.getSpec.getContainers.get(0).getResources.getRequests.get("memory").getAmount + === standardNonJVMMemory) } protected def doBasicExecutorPodCheck(executorPod: Pod): Unit = { assert(executorPod.getSpec.getContainers.get(0).getImage === image) assert(executorPod.getSpec.getContainers.get(0).getName === "executor") + assert(executorPod.getSpec.getContainers.get(0).getResources.getRequests.get("memory").getAmount + === baseMemory) } protected def doBasicExecutorPyPodCheck(executorPod: Pod): Unit = { assert(executorPod.getSpec.getContainers.get(0).getImage === pyImage) assert(executorPod.getSpec.getContainers.get(0).getName === "executor") + assert(executorPod.getSpec.getContainers.get(0).getResources.getRequests.get("memory").getAmount + === standardNonJVMMemory) } protected def doBasicExecutorRPodCheck(executorPod: Pod): Unit = { assert(executorPod.getSpec.getContainers.get(0).getImage === rImage) assert(executorPod.getSpec.getContainers.get(0).getName === "executor") + assert(executorPod.getSpec.getContainers.get(0).getResources.getRequests.get("memory").getAmount + === standardNonJVMMemory) + } + + protected def doDriverMemoryCheck(driverPod: Pod): Unit = { + assert(driverPod.getSpec.getContainers.get(0).getResources.getRequests.get("memory").getAmount + === extraDriverTotalMemory) + } + + protected def doExecutorMemoryCheck(executorPod: Pod): Unit = { + assert(executorPod.getSpec.getContainers.get(0).getResources.getRequests.get("memory").getAmount + === extraExecTotalMemory) } protected def checkCustomSettings(pod: Pod): Unit = { diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala index 1ebb30094dcde..06b73107ec236 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala @@ -23,9 +23,11 @@ private[spark] trait PythonTestsSuite { k8sSuite: KubernetesSuite => import PythonTestsSuite._ import KubernetesSuite.k8sTestTag + private val pySparkDockerImage = + s"${getTestImageRepo}/spark-py:${getTestImageTag}" test("Run PySpark on simple pi.py example", k8sTestTag) { sparkAppConf - .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") + .set("spark.kubernetes.container.image", pySparkDockerImage) runSparkApplicationAndVerifyCompletion( appResource = PYSPARK_PI, mainClass = "", @@ -39,7 +41,7 @@ private[spark] trait PythonTestsSuite { k8sSuite: KubernetesSuite => test("Run PySpark with Python2 to test a pyfiles example", k8sTestTag) { sparkAppConf - .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") + .set("spark.kubernetes.container.image", pySparkDockerImage) .set("spark.kubernetes.pyspark.pythonVersion", "2") runSparkApplicationAndVerifyCompletion( appResource = PYSPARK_FILES, @@ -57,7 +59,7 @@ private[spark] trait PythonTestsSuite { k8sSuite: KubernetesSuite => test("Run PySpark with Python3 to test a pyfiles example", k8sTestTag) { sparkAppConf - .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") + .set("spark.kubernetes.container.image", pySparkDockerImage) .set("spark.kubernetes.pyspark.pythonVersion", "3") runSparkApplicationAndVerifyCompletion( appResource = PYSPARK_FILES, @@ -72,12 +74,32 @@ private[spark] trait PythonTestsSuite { k8sSuite: KubernetesSuite => isJVM = false, pyFiles = Some(PYSPARK_CONTAINER_TESTS)) } + + test("Run PySpark with memory customization", k8sTestTag) { + sparkAppConf + .set("spark.kubernetes.container.image", pySparkDockerImage) + .set("spark.kubernetes.pyspark.pythonVersion", "3") + .set("spark.kubernetes.memoryOverheadFactor", s"$memOverheadConstant") + .set("spark.executor.pyspark.memory", s"${additionalMemory}m") + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_MEMORY_CHECK, + mainClass = "", + expectedLogOnCompletion = Seq( + "PySpark Worker Memory Check is: True"), + appArgs = Array(s"$additionalMemoryInBytes"), + driverPodChecker = doDriverMemoryCheck, + executorPodChecker = doExecutorMemoryCheck, + appLocator = appLocator, + isJVM = false, + pyFiles = Some(PYSPARK_CONTAINER_TESTS)) + } } private[spark] object PythonTestsSuite { val CONTAINER_LOCAL_PYSPARK: String = "local:///opt/spark/examples/src/main/python/" val PYSPARK_PI: String = CONTAINER_LOCAL_PYSPARK + "pi.py" - val PYSPARK_FILES: String = CONTAINER_LOCAL_PYSPARK + "pyfiles.py" - val PYSPARK_CONTAINER_TESTS: String = CONTAINER_LOCAL_PYSPARK + "py_container_checks.py" + val TEST_LOCAL_PYSPARK: String = "local:///opt/spark/tests/" + val PYSPARK_FILES: String = TEST_LOCAL_PYSPARK + "pyfiles.py" + val PYSPARK_CONTAINER_TESTS: String = TEST_LOCAL_PYSPARK + "py_container_checks.py" + val PYSPARK_MEMORY_CHECK: String = TEST_LOCAL_PYSPARK + "worker_memory_check.py" } - diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SecretsTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SecretsTestsSuite.scala index 7b05c1355ca24..9b039bb98dd9a 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SecretsTestsSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SecretsTestsSuite.scala @@ -53,6 +53,7 @@ private[spark] trait SecretsTestsSuite { k8sSuite: KubernetesSuite => .delete() } + // TODO: [SPARK-25291] This test is flaky with regards to memory of executors test("Run SparkPi with env and mount secrets.", k8sTestTag) { createTestSecret() sparkAppConf diff --git a/examples/src/main/python/py_container_checks.py b/resource-managers/kubernetes/integration-tests/tests/py_container_checks.py similarity index 100% rename from examples/src/main/python/py_container_checks.py rename to resource-managers/kubernetes/integration-tests/tests/py_container_checks.py diff --git a/examples/src/main/python/pyfiles.py b/resource-managers/kubernetes/integration-tests/tests/pyfiles.py similarity index 100% rename from examples/src/main/python/pyfiles.py rename to resource-managers/kubernetes/integration-tests/tests/pyfiles.py diff --git a/resource-managers/kubernetes/integration-tests/tests/worker_memory_check.py b/resource-managers/kubernetes/integration-tests/tests/worker_memory_check.py new file mode 100644 index 0000000000000..d312a29f388e4 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/tests/worker_memory_check.py @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import resource +import sys + +from pyspark.sql import SparkSession + + +if __name__ == "__main__": + """ + Usage: worker_memory_check [Memory_in_Mi] + """ + spark = SparkSession \ + .builder \ + .appName("PyMemoryTest") \ + .getOrCreate() + sc = spark.sparkContext + if len(sys.argv) < 2: + print("Usage: worker_memory_check [Memory_in_Mi]", file=sys.stderr) + sys.exit(-1) + + def f(x): + rLimit = resource.getrlimit(resource.RLIMIT_AS) + print("RLimit is " + str(rLimit)) + return rLimit + resourceValue = sc.parallelize([1]).map(f).collect()[0][0] + print("Resource Value is " + str(resourceValue)) + truthCheck = (resourceValue == int(sys.argv[1])) + print("PySpark Worker Memory Check is: " + str(truthCheck)) + spark.stop() From 0b9ccd55c2986957863dcad3b44ce80403eecfa1 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 9 Sep 2018 21:25:19 +0800 Subject: [PATCH 1579/2461] Revert [SPARK-10399] [SPARK-23879] [SPARK-23762] [SPARK-25317] ## What changes were proposed in this pull request? When running TPC-DS benchmarks on 2.4 release, npoggi and winglungngai saw more than 10% performance regression on the following queries: q67, q24a and q24b. After we applying the PR https://github.com/apache/spark/pull/22338, the performance regression still exists. If we revert the changes in https://github.com/apache/spark/pull/19222, npoggi and winglungngai found the performance regression was resolved. Thus, this PR is to revert the related changes for unblocking the 2.4 release. In the future release, we still can continue the investigation and find out the root cause of the regression. ## How was this patch tested? The existing test cases Closes #22361 from gatorsmile/revertMemoryBlock. Authored-by: gatorsmile Signed-off-by: Wenchen Fan --- .../sql/catalyst/expressions/HiveHasher.java | 18 +- .../org/apache/spark/unsafe/Platform.java | 2 +- .../spark/unsafe/array/ByteArrayMethods.java | 15 +- .../apache/spark/unsafe/array/LongArray.java | 17 +- .../spark/unsafe/hash/Murmur3_x86_32.java | 53 ++---- .../unsafe/memory/ByteArrayMemoryBlock.java | 128 ------------- .../unsafe/memory/HeapMemoryAllocator.java | 21 +- .../spark/unsafe/memory/MemoryAllocator.java | 4 +- .../spark/unsafe/memory/MemoryBlock.java | 157 ++------------- .../spark/unsafe/memory/MemoryLocation.java | 96 ++++++---- .../unsafe/memory/OffHeapMemoryBlock.java | 105 ---------- .../unsafe/memory/OnHeapMemoryBlock.java | 132 ------------- .../unsafe/memory/UnsafeMemoryAllocator.java | 21 +- .../apache/spark/unsafe/types/UTF8String.java | 147 +++++++------- .../spark/unsafe/PlatformUtilSuite.java | 4 +- .../spark/unsafe/array/LongArraySuite.java | 5 +- .../unsafe/hash/Murmur3_x86_32Suite.java | 18 -- .../spark/unsafe/memory/MemoryBlockSuite.java | 179 ------------------ .../spark/unsafe/types/UTF8StringSuite.java | 41 ++-- .../spark/memory/TaskMemoryManager.java | 22 +-- .../shuffle/sort/ShuffleInMemorySorter.java | 14 +- .../shuffle/sort/ShuffleSortDataFormat.java | 11 +- .../unsafe/sort/UnsafeExternalSorter.java | 2 +- .../unsafe/sort/UnsafeInMemorySorter.java | 13 +- .../spark/memory/TaskMemoryManagerSuite.java | 2 +- .../util/collection/ExternalSorterSuite.scala | 7 +- .../unsafe/sort/RadixSortSuite.scala | 10 +- .../spark/ml/feature/FeatureHasher.scala | 5 +- .../spark/mllib/feature/HashingTF.scala | 2 +- .../catalyst/expressions/UnsafeArrayData.java | 4 +- .../sql/catalyst/expressions/UnsafeRow.java | 4 +- .../spark/sql/catalyst/expressions/XXH64.java | 47 ++--- .../codegen/UTF8StringBuilder.java | 35 ++-- .../spark/sql/catalyst/expressions/hash.scala | 37 ++-- .../catalyst/expressions/HiveHasherSuite.java | 21 +- .../sql/catalyst/expressions/XXH64Suite.java | 18 +- .../vectorized/OffHeapColumnVector.java | 3 +- .../sql/vectorized/ArrowColumnVector.java | 6 +- .../execution/benchmark/SortBenchmark.scala | 16 +- .../sql/execution/python/RowQueueSuite.scala | 4 +- 40 files changed, 376 insertions(+), 1070 deletions(-) delete mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala => common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java (51%) delete mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java delete mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java delete mode 100644 common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java index 62b75ae8aa01d..73577437ac506 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.Platform; /** * Simulates Hive's hashing function from Hive v1.2.1 @@ -39,21 +38,12 @@ public static int hashLong(long input) { return (int) ((input >>> 32) ^ input); } - public static int hashUnsafeBytesBlock(MemoryBlock mb) { - long lengthInBytes = mb.size(); + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) { assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int result = 0; - for (long i = 0; i < lengthInBytes; i++) { - result = (result * 31) + (int) mb.getByte(i); + for (int i = 0; i < lengthInBytes; i++) { + result = (result * 31) + (int) Platform.getByte(base, offset + i); } return result; } - - public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) { - return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes)); - } - - public static int hashUTF8String(UTF8String str) { - return hashUnsafeBytesBlock(str.getMemoryBlock()); - } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 54dcadf3a7754..aca6fca00c48b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -187,7 +187,7 @@ public static void setMemory(long address, byte value, long size) { } public static void copyMemory( - Object src, long srcOffset, Object dst, long dstOffset, long length) { + Object src, long srcOffset, Object dst, long dstOffset, long length) { // Check if dstOffset is before or after srcOffset to determine if we should copy // forward or backwards. This is necessary in case src and dst overlap. if (dstOffset < srcOffset) { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index ef0f78d95d1ee..cec8c30887e2f 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -18,7 +18,6 @@ package org.apache.spark.unsafe.array; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.memory.MemoryBlock; public class ByteArrayMethods { @@ -53,25 +52,15 @@ public static long roundNumberOfBytesToNearestWord(long numBytes) { public static int MAX_ROUNDED_ARRAY_LENGTH = Integer.MAX_VALUE - 15; private static final boolean unaligned = Platform.unaligned(); - /** - * MemoryBlock equality check for MemoryBlocks. - * @return true if the arrays are equal, false otherwise - */ - public static boolean arrayEqualsBlock( - MemoryBlock leftBase, long leftOffset, MemoryBlock rightBase, long rightOffset, long length) { - return arrayEquals(leftBase.getBaseObject(), leftBase.getBaseOffset() + leftOffset, - rightBase.getBaseObject(), rightBase.getBaseOffset() + rightOffset, length); - } - /** * Optimized byte array equality check for byte arrays. * @return true if the arrays are equal, false otherwise */ public static boolean arrayEquals( - Object leftBase, long leftOffset, Object rightBase, long rightOffset, long length) { + Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) { int i = 0; - // check if starts align and we can get both offsets to be aligned + // check if stars align and we can get both offsets to be aligned if ((leftOffset % 8) == (rightOffset % 8)) { while ((leftOffset + i) % 8 != 0 && i < length) { if (Platform.getByte(leftBase, leftOffset + i) != diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java index b74d2de0691d5..2cd39bd60c2ac 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -17,6 +17,7 @@ package org.apache.spark.unsafe.array; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; /** @@ -32,12 +33,16 @@ public final class LongArray { private static final long WIDTH = 8; private final MemoryBlock memory; + private final Object baseObj; + private final long baseOffset; private final long length; public LongArray(MemoryBlock memory) { assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size >= Integer.MAX_VALUE elements"; this.memory = memory; + this.baseObj = memory.getBaseObject(); + this.baseOffset = memory.getBaseOffset(); this.length = memory.size() / WIDTH; } @@ -46,11 +51,11 @@ public MemoryBlock memoryBlock() { } public Object getBaseObject() { - return memory.getBaseObject(); + return baseObj; } public long getBaseOffset() { - return memory.getBaseOffset(); + return baseOffset; } /** @@ -64,8 +69,8 @@ public long size() { * Fill this all with 0L. */ public void zeroOut() { - for (long off = 0; off < length * WIDTH; off += WIDTH) { - memory.putLong(off, 0); + for (long off = baseOffset; off < baseOffset + length * WIDTH; off += WIDTH) { + Platform.putLong(baseObj, off, 0); } } @@ -75,7 +80,7 @@ public void zeroOut() { public void set(int index, long value) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - memory.putLong(index * WIDTH, value); + Platform.putLong(baseObj, baseOffset + index * WIDTH, value); } /** @@ -84,6 +89,6 @@ public void set(int index, long value) { public long get(int index) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - return memory.getLong(index * WIDTH); + return Platform.getLong(baseObj, baseOffset + index * WIDTH); } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index 566f116154302..d239de6083ad0 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -17,11 +17,7 @@ package org.apache.spark.unsafe.hash; -import com.google.common.primitives.Ints; - import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.types.UTF8String; /** * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. @@ -53,74 +49,49 @@ public static int hashInt(int input, int seed) { } public int hashUnsafeWords(Object base, long offset, int lengthInBytes) { - return hashUnsafeWordsBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); + return hashUnsafeWords(base, offset, lengthInBytes, seed); } - public static int hashUnsafeWordsBlock(MemoryBlock base, int seed) { + public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. - int lengthInBytes = Ints.checkedCast(base.size()); assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; - int h1 = hashBytesByIntBlock(base, lengthInBytes, seed); + int h1 = hashBytesByInt(base, offset, lengthInBytes, seed); return fmix(h1, lengthInBytes); } - public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { - // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. - return hashUnsafeWordsBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); - } - - public static int hashUnsafeBytesBlock(MemoryBlock base, int seed) { - return hashUnsafeBytesBlock(base, Ints.checkedCast(base.size()), seed); - } - - private static int hashUnsafeBytesBlock(MemoryBlock base, int lengthInBytes, int seed) { + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { // This is not compatible with original and another implementations. // But remain it for backward compatibility for the components existing before 2.3. assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; - int h1 = hashBytesByIntBlock(base, lengthAligned, seed); - long offset = base.getBaseOffset(); - Object o = base.getBaseObject(); + int h1 = hashBytesByInt(base, offset, lengthAligned, seed); for (int i = lengthAligned; i < lengthInBytes; i++) { - int halfWord = Platform.getByte(o, offset + i); + int halfWord = Platform.getByte(base, offset + i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } return fmix(h1, lengthInBytes); } - public static int hashUTF8String(UTF8String str, int seed) { - return hashUnsafeBytesBlock(str.getMemoryBlock(), str.numBytes(), seed); - } - - public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { - return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); - } - public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) { - return hashUnsafeBytes2Block(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); - } - - public static int hashUnsafeBytes2Block(MemoryBlock base, int seed) { - // This is compatible with original and other implementations. + // This is compatible with original and another implementations. // Use this method for new components after Spark 2.3. - int lengthInBytes = Ints.checkedCast(base.size()); - assert (lengthInBytes >= 0) : "lengthInBytes cannot be negative"; + assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; - int h1 = hashBytesByIntBlock(base, lengthAligned, seed); + int h1 = hashBytesByInt(base, offset, lengthAligned, seed); int k1 = 0; for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) { - k1 ^= (base.getByte(i) & 0xFF) << shift; + k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift; } h1 ^= mixK1(k1); return fmix(h1, lengthInBytes); } - private static int hashBytesByIntBlock(MemoryBlock base, int lengthInBytes, int seed) { + private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { assert (lengthInBytes % 4 == 0); int h1 = seed; for (int i = 0; i < lengthInBytes; i += 4) { - int halfWord = base.getInt(i); + int halfWord = Platform.getInt(base, offset + i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java deleted file mode 100644 index 9f238632bc87a..0000000000000 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.memory; - -import com.google.common.primitives.Ints; - -import org.apache.spark.unsafe.Platform; - -/** - * A consecutive block of memory with a byte array on Java heap. - */ -public final class ByteArrayMemoryBlock extends MemoryBlock { - - private final byte[] array; - - public ByteArrayMemoryBlock(byte[] obj, long offset, long size) { - super(obj, offset, size); - this.array = obj; - assert(offset + size <= Platform.BYTE_ARRAY_OFFSET + obj.length) : - "The sum of size " + size + " and offset " + offset + " should not be larger than " + - "the size of the given memory space " + (obj.length + Platform.BYTE_ARRAY_OFFSET); - } - - public ByteArrayMemoryBlock(long length) { - this(new byte[Ints.checkedCast(length)], Platform.BYTE_ARRAY_OFFSET, length); - } - - @Override - public MemoryBlock subBlock(long offset, long size) { - checkSubBlockRange(offset, size); - if (offset == 0 && size == this.size()) return this; - return new ByteArrayMemoryBlock(array, this.offset + offset, size); - } - - public byte[] getByteArray() { return array; } - - /** - * Creates a memory block pointing to the memory used by the byte array. - */ - public static ByteArrayMemoryBlock fromArray(final byte[] array) { - return new ByteArrayMemoryBlock(array, Platform.BYTE_ARRAY_OFFSET, array.length); - } - - @Override - public int getInt(long offset) { - return Platform.getInt(array, this.offset + offset); - } - - @Override - public void putInt(long offset, int value) { - Platform.putInt(array, this.offset + offset, value); - } - - @Override - public boolean getBoolean(long offset) { - return Platform.getBoolean(array, this.offset + offset); - } - - @Override - public void putBoolean(long offset, boolean value) { - Platform.putBoolean(array, this.offset + offset, value); - } - - @Override - public byte getByte(long offset) { - return array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)]; - } - - @Override - public void putByte(long offset, byte value) { - array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)] = value; - } - - @Override - public short getShort(long offset) { - return Platform.getShort(array, this.offset + offset); - } - - @Override - public void putShort(long offset, short value) { - Platform.putShort(array, this.offset + offset, value); - } - - @Override - public long getLong(long offset) { - return Platform.getLong(array, this.offset + offset); - } - - @Override - public void putLong(long offset, long value) { - Platform.putLong(array, this.offset + offset, value); - } - - @Override - public float getFloat(long offset) { - return Platform.getFloat(array, this.offset + offset); - } - - @Override - public void putFloat(long offset, float value) { - Platform.putFloat(array, this.offset + offset, value); - } - - @Override - public double getDouble(long offset) { - return Platform.getDouble(array, this.offset + offset); - } - - @Override - public void putDouble(long offset, double value) { - Platform.putDouble(array, this.offset + offset, value); - } -} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index 36caf80888cda..2733760dd19ef 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -23,6 +23,8 @@ import java.util.LinkedList; import java.util.Map; +import org.apache.spark.unsafe.Platform; + /** * A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array. */ @@ -56,7 +58,7 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { final long[] array = arrayReference.get(); if (array != null) { assert (array.length * 8L >= size); - MemoryBlock memory = OnHeapMemoryBlock.fromArray(array, size); + MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -68,7 +70,7 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { } } long[] array = new long[numWords]; - MemoryBlock memory = OnHeapMemoryBlock.fromArray(array, size); + MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -77,13 +79,12 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { @Override public void free(MemoryBlock memory) { - assert(memory instanceof OnHeapMemoryBlock); - assert (memory.getBaseObject() != null) : + assert (memory.obj != null) : "baseObject was null; are you trying to use the on-heap allocator to free off-heap memory?"; - assert (memory.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : "page has already been freed"; - assert ((memory.getPageNumber() == MemoryBlock.NO_PAGE_NUMBER) - || (memory.getPageNumber() == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) + || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator " + "free()"; @@ -93,12 +94,12 @@ public void free(MemoryBlock memory) { } // Mark the page as freed (so we can detect double-frees). - memory.setPageNumber(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER); + memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; // As an additional layer of defense against use-after-free bugs, we mutate the // MemoryBlock to null out its reference to the long[] array. - long[] array = ((OnHeapMemoryBlock)memory).getLongArray(); - memory.resetObjAndOffset(); + long[] array = (long[]) memory.obj; + memory.setObjAndOffset(null, 0); long alignedSize = ((size + 7) / 8) * 8; if (shouldPool(alignedSize)) { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java index 38315fb97b46a..7b588681d9790 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java @@ -38,7 +38,7 @@ public interface MemoryAllocator { void free(MemoryBlock memory); - UnsafeMemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); + MemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); - HeapMemoryAllocator HEAP = new HeapMemoryAllocator(); + MemoryAllocator HEAP = new HeapMemoryAllocator(); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index ca7213bbf92da..c333857358d30 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -22,10 +22,10 @@ import org.apache.spark.unsafe.Platform; /** - * A representation of a consecutive memory block in Spark. It defines the common interfaces - * for memory accessing and mutating. + * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. */ -public abstract class MemoryBlock { +public class MemoryBlock extends MemoryLocation { + /** Special `pageNumber` value for pages which were not allocated by TaskMemoryManagers */ public static final int NO_PAGE_NUMBER = -1; @@ -45,163 +45,38 @@ public abstract class MemoryBlock { */ public static final int FREED_IN_ALLOCATOR_PAGE_NUMBER = -3; - @Nullable - protected Object obj; - - protected long offset; - - protected long length; + private final long length; /** * Optional page number; used when this MemoryBlock represents a page allocated by a - * TaskMemoryManager. This field can be updated using setPageNumber method so that - * this can be modified by the TaskMemoryManager, which lives in a different package. + * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager, + * which lives in a different package. */ - private int pageNumber = NO_PAGE_NUMBER; + public int pageNumber = NO_PAGE_NUMBER; - protected MemoryBlock(@Nullable Object obj, long offset, long length) { - if (offset < 0 || length < 0) { - throw new IllegalArgumentException( - "Length " + length + " and offset " + offset + "must be non-negative"); - } - this.obj = obj; - this.offset = offset; + public MemoryBlock(@Nullable Object obj, long offset, long length) { + super(obj, offset); this.length = length; } - protected MemoryBlock() { - this(null, 0, 0); - } - - public final Object getBaseObject() { - return obj; - } - - public final long getBaseOffset() { - return offset; - } - - public void resetObjAndOffset() { - this.obj = null; - this.offset = 0; - } - /** * Returns the size of the memory block. */ - public final long size() { + public long size() { return length; } - public final void setPageNumber(int pageNum) { - pageNumber = pageNum; - } - - public final int getPageNumber() { - return pageNumber; - } - - /** - * Fills the memory block with the specified byte value. - */ - public final void fill(byte value) { - Platform.setMemory(obj, offset, length, value); - } - - /** - * Instantiate MemoryBlock for given object type with new offset - */ - public static final MemoryBlock allocateFromObject(Object obj, long offset, long length) { - MemoryBlock mb = null; - if (obj instanceof byte[]) { - byte[] array = (byte[])obj; - mb = new ByteArrayMemoryBlock(array, offset, length); - } else if (obj instanceof long[]) { - long[] array = (long[])obj; - mb = new OnHeapMemoryBlock(array, offset, length); - } else if (obj == null) { - // we assume that to pass null pointer means off-heap - mb = new OffHeapMemoryBlock(offset, length); - } else { - throw new UnsupportedOperationException( - "Instantiate MemoryBlock for type " + obj.getClass() + " is not supported now"); - } - return mb; - } - /** - * Just instantiate the sub-block with the same type of MemoryBlock with the new size and relative - * offset from the original offset. The data is not copied. - * If parameters are invalid, an exception is thrown. + * Creates a memory block pointing to the memory used by the long array. */ - public abstract MemoryBlock subBlock(long offset, long size); - - protected void checkSubBlockRange(long offset, long size) { - if (offset < 0 || size < 0) { - throw new ArrayIndexOutOfBoundsException( - "Size " + size + " and offset " + offset + " must be non-negative"); - } - if (offset + size > length) { - throw new ArrayIndexOutOfBoundsException("The sum of size " + size + " and offset " + - offset + " should not be larger than the length " + length + " in the MemoryBlock"); - } + public static MemoryBlock fromLongArray(final long[] array) { + return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8L); } /** - * getXXX/putXXX does not ensure guarantee behavior if the offset is invalid. e.g cause illegal - * memory access, throw an exception, or etc. - * getXXX/putXXX uses an index based on this.offset that includes the size of metadata such as - * JVM object header. The offset is 0-based and is expected as an logical offset in the memory - * block. + * Fills the memory block with the specified byte value. */ - public abstract int getInt(long offset); - - public abstract void putInt(long offset, int value); - - public abstract boolean getBoolean(long offset); - - public abstract void putBoolean(long offset, boolean value); - - public abstract byte getByte(long offset); - - public abstract void putByte(long offset, byte value); - - public abstract short getShort(long offset); - - public abstract void putShort(long offset, short value); - - public abstract long getLong(long offset); - - public abstract void putLong(long offset, long value); - - public abstract float getFloat(long offset); - - public abstract void putFloat(long offset, float value); - - public abstract double getDouble(long offset); - - public abstract void putDouble(long offset, double value); - - public static final void copyMemory( - MemoryBlock src, long srcOffset, MemoryBlock dst, long dstOffset, long length) { - assert(srcOffset + length <= src.length && dstOffset + length <= dst.length); - Platform.copyMemory(src.getBaseObject(), src.getBaseOffset() + srcOffset, - dst.getBaseObject(), dst.getBaseOffset() + dstOffset, length); - } - - public static final void copyMemory(MemoryBlock src, MemoryBlock dst, long length) { - assert(length <= src.length && length <= dst.length); - Platform.copyMemory(src.getBaseObject(), src.getBaseOffset(), - dst.getBaseObject(), dst.getBaseOffset(), length); - } - - public final void copyFrom(Object src, long srcOffset, long dstOffset, long length) { - assert(length <= this.length - srcOffset); - Platform.copyMemory(src, srcOffset, obj, offset + dstOffset, length); - } - - public final void writeTo(long srcOffset, Object dst, long dstOffset, long length) { - assert(length <= this.length - srcOffset); - Platform.copyMemory(obj, offset + srcOffset, dst, dstOffset, length); + public void fill(byte value) { + Platform.setMemory(obj, offset, length, value); } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java similarity index 51% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java index 1b25a4b191f86..74ebc87dc978c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java @@ -1,42 +1,54 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.codegen - -import org.apache.spark.SparkFunSuite -import org.apache.spark.unsafe.types.UTF8String - -class UTF8StringBuilderSuite extends SparkFunSuite { - - test("basic test") { - val sb = new UTF8StringBuilder() - assert(sb.build() === UTF8String.EMPTY_UTF8) - - sb.append("") - assert(sb.build() === UTF8String.EMPTY_UTF8) - - sb.append("abcd") - assert(sb.build() === UTF8String.fromString("abcd")) - - sb.append(UTF8String.fromString("1234")) - assert(sb.build() === UTF8String.fromString("abcd1234")) - - // expect to grow an internal buffer - sb.append(UTF8String.fromString("efgijk567890")) - assert(sb.build() === UTF8String.fromString("abcd1234efgijk567890")) - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import javax.annotation.Nullable; + +/** + * A memory location. Tracked either by a memory address (with off-heap allocation), + * or by an offset from a JVM object (in-heap allocation). + */ +public class MemoryLocation { + + @Nullable + Object obj; + + long offset; + + public MemoryLocation(@Nullable Object obj, long offset) { + this.obj = obj; + this.offset = offset; + } + + public MemoryLocation() { + this(null, 0); + } + + public void setObjAndOffset(Object newObj, long newOffset) { + this.obj = newObj; + this.offset = newOffset; + } + + public final Object getBaseObject() { + return obj; + } + + public final long getBaseOffset() { + return offset; + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java deleted file mode 100644 index 3431b08980eb8..0000000000000 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.memory; - -import org.apache.spark.unsafe.Platform; - -public class OffHeapMemoryBlock extends MemoryBlock { - public static final OffHeapMemoryBlock NULL = new OffHeapMemoryBlock(0, 0); - - public OffHeapMemoryBlock(long address, long size) { - super(null, address, size); - } - - @Override - public MemoryBlock subBlock(long offset, long size) { - checkSubBlockRange(offset, size); - if (offset == 0 && size == this.size()) return this; - return new OffHeapMemoryBlock(this.offset + offset, size); - } - - @Override - public final int getInt(long offset) { - return Platform.getInt(null, this.offset + offset); - } - - @Override - public final void putInt(long offset, int value) { - Platform.putInt(null, this.offset + offset, value); - } - - @Override - public final boolean getBoolean(long offset) { - return Platform.getBoolean(null, this.offset + offset); - } - - @Override - public final void putBoolean(long offset, boolean value) { - Platform.putBoolean(null, this.offset + offset, value); - } - - @Override - public final byte getByte(long offset) { - return Platform.getByte(null, this.offset + offset); - } - - @Override - public final void putByte(long offset, byte value) { - Platform.putByte(null, this.offset + offset, value); - } - - @Override - public final short getShort(long offset) { - return Platform.getShort(null, this.offset + offset); - } - - @Override - public final void putShort(long offset, short value) { - Platform.putShort(null, this.offset + offset, value); - } - - @Override - public final long getLong(long offset) { - return Platform.getLong(null, this.offset + offset); - } - - @Override - public final void putLong(long offset, long value) { - Platform.putLong(null, this.offset + offset, value); - } - - @Override - public final float getFloat(long offset) { - return Platform.getFloat(null, this.offset + offset); - } - - @Override - public final void putFloat(long offset, float value) { - Platform.putFloat(null, this.offset + offset, value); - } - - @Override - public final double getDouble(long offset) { - return Platform.getDouble(null, this.offset + offset); - } - - @Override - public final void putDouble(long offset, double value) { - Platform.putDouble(null, this.offset + offset, value); - } -} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java deleted file mode 100644 index ee42bc27c9c5f..0000000000000 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.memory; - -import com.google.common.primitives.Ints; - -import org.apache.spark.unsafe.Platform; - -/** - * A consecutive block of memory with a long array on Java heap. - */ -public final class OnHeapMemoryBlock extends MemoryBlock { - - private final long[] array; - - public OnHeapMemoryBlock(long[] obj, long offset, long size) { - super(obj, offset, size); - this.array = obj; - assert(offset + size <= obj.length * 8L + Platform.LONG_ARRAY_OFFSET) : - "The sum of size " + size + " and offset " + offset + " should not be larger than " + - "the size of the given memory space " + (obj.length * 8L + Platform.LONG_ARRAY_OFFSET); - } - - public OnHeapMemoryBlock(long size) { - this(new long[Ints.checkedCast((size + 7) / 8)], Platform.LONG_ARRAY_OFFSET, size); - } - - @Override - public MemoryBlock subBlock(long offset, long size) { - checkSubBlockRange(offset, size); - if (offset == 0 && size == this.size()) return this; - return new OnHeapMemoryBlock(array, this.offset + offset, size); - } - - public long[] getLongArray() { return array; } - - /** - * Creates a memory block pointing to the memory used by the long array. - */ - public static OnHeapMemoryBlock fromArray(final long[] array) { - return new OnHeapMemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8L); - } - - public static OnHeapMemoryBlock fromArray(final long[] array, long size) { - return new OnHeapMemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); - } - - @Override - public int getInt(long offset) { - return Platform.getInt(array, this.offset + offset); - } - - @Override - public void putInt(long offset, int value) { - Platform.putInt(array, this.offset + offset, value); - } - - @Override - public boolean getBoolean(long offset) { - return Platform.getBoolean(array, this.offset + offset); - } - - @Override - public void putBoolean(long offset, boolean value) { - Platform.putBoolean(array, this.offset + offset, value); - } - - @Override - public byte getByte(long offset) { - return Platform.getByte(array, this.offset + offset); - } - - @Override - public void putByte(long offset, byte value) { - Platform.putByte(array, this.offset + offset, value); - } - - @Override - public short getShort(long offset) { - return Platform.getShort(array, this.offset + offset); - } - - @Override - public void putShort(long offset, short value) { - Platform.putShort(array, this.offset + offset, value); - } - - @Override - public long getLong(long offset) { - return Platform.getLong(array, this.offset + offset); - } - - @Override - public void putLong(long offset, long value) { - Platform.putLong(array, this.offset + offset, value); - } - - @Override - public float getFloat(long offset) { - return Platform.getFloat(array, this.offset + offset); - } - - @Override - public void putFloat(long offset, float value) { - Platform.putFloat(array, this.offset + offset, value); - } - - @Override - public double getDouble(long offset) { - return Platform.getDouble(array, this.offset + offset); - } - - @Override - public void putDouble(long offset, double value) { - Platform.putDouble(array, this.offset + offset, value); - } -} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java index 5310bdf2779a9..4368fb615ba1e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -25,9 +25,9 @@ public class UnsafeMemoryAllocator implements MemoryAllocator { @Override - public OffHeapMemoryBlock allocate(long size) throws OutOfMemoryError { + public MemoryBlock allocate(long size) throws OutOfMemoryError { long address = Platform.allocateMemory(size); - OffHeapMemoryBlock memory = new OffHeapMemoryBlock(address, size); + MemoryBlock memory = new MemoryBlock(null, address, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -36,25 +36,22 @@ public OffHeapMemoryBlock allocate(long size) throws OutOfMemoryError { @Override public void free(MemoryBlock memory) { - assert(memory instanceof OffHeapMemoryBlock) : - "UnsafeMemoryAllocator can only free OffHeapMemoryBlock."; - if (memory == OffHeapMemoryBlock.NULL) return; - assert (memory.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + assert (memory.obj == null) : + "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; + assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : "page has already been freed"; - assert ((memory.getPageNumber() == MemoryBlock.NO_PAGE_NUMBER) - || (memory.getPageNumber() == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) + || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : "TMM-allocated pages must be freed via TMM.freePage(), not directly in allocator free()"; if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); } - Platform.freeMemory(memory.offset); - // As an additional layer of defense against use-after-free bugs, we mutate the // MemoryBlock to reset its pointer. - memory.resetObjAndOffset(); + memory.offset = 0; // Mark the page as freed (so we can detect double-frees). - memory.setPageNumber(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER); + memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index e91fc4391425c..dff4a73f3e9da 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -34,8 +34,6 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; -import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; -import org.apache.spark.unsafe.memory.MemoryBlock; import static org.apache.spark.unsafe.Platform.*; @@ -53,13 +51,12 @@ public final class UTF8String implements Comparable, Externalizable, // These are only updated by readExternal() or read() @Nonnull - private MemoryBlock base; - // While numBytes has the same value as base.size(), to keep as int avoids cast from long to int + private Object base; + private long offset; private int numBytes; - public MemoryBlock getMemoryBlock() { return base; } - public Object getBaseObject() { return base.getBaseObject(); } - public long getBaseOffset() { return base.getBaseOffset(); } + public Object getBaseObject() { return base; } + public long getBaseOffset() { return offset; } /** * A char in UTF-8 encoding can take 1-4 bytes depending on the first byte which @@ -112,8 +109,7 @@ public final class UTF8String implements Comparable, Externalizable, */ public static UTF8String fromBytes(byte[] bytes) { if (bytes != null) { - return new UTF8String( - new ByteArrayMemoryBlock(bytes, BYTE_ARRAY_OFFSET, bytes.length)); + return new UTF8String(bytes, BYTE_ARRAY_OFFSET, bytes.length); } else { return null; } @@ -126,13 +122,19 @@ public static UTF8String fromBytes(byte[] bytes) { */ public static UTF8String fromBytes(byte[] bytes, int offset, int numBytes) { if (bytes != null) { - return new UTF8String( - new ByteArrayMemoryBlock(bytes, BYTE_ARRAY_OFFSET + offset, numBytes)); + return new UTF8String(bytes, BYTE_ARRAY_OFFSET + offset, numBytes); } else { return null; } } + /** + * Creates an UTF8String from given address (base and offset) and length. + */ + public static UTF8String fromAddress(Object base, long offset, int numBytes) { + return new UTF8String(base, offset, numBytes); + } + /** * Creates an UTF8String from String. */ @@ -149,13 +151,16 @@ public static UTF8String blankString(int length) { return fromBytes(spaces); } - public UTF8String(MemoryBlock base) { + protected UTF8String(Object base, long offset, int numBytes) { this.base = base; - this.numBytes = Ints.checkedCast(base.size()); + this.offset = offset; + this.numBytes = numBytes; } // for serialization - public UTF8String() {} + public UTF8String() { + this(null, 0, 0); + } /** * Writes the content of this string into a memory address, identified by an object and an offset. @@ -163,7 +168,7 @@ public UTF8String() {} * bytes in this string. */ public void writeToMemory(Object target, long targetOffset) { - base.writeTo(0, target, targetOffset, numBytes); + Platform.copyMemory(base, offset, target, targetOffset, numBytes); } public void writeTo(ByteBuffer buffer) { @@ -183,9 +188,8 @@ public void writeTo(ByteBuffer buffer) { */ @Nonnull public ByteBuffer getByteBuffer() { - long offset = base.getBaseOffset(); - if (base instanceof ByteArrayMemoryBlock && offset >= BYTE_ARRAY_OFFSET) { - final byte[] bytes = ((ByteArrayMemoryBlock) base).getByteArray(); + if (base instanceof byte[] && offset >= BYTE_ARRAY_OFFSET) { + final byte[] bytes = (byte[]) base; // the offset includes an object header... this is only needed for unsafe copies final long arrayOffset = offset - BYTE_ARRAY_OFFSET; @@ -252,12 +256,12 @@ public long getPrefix() { long mask = 0; if (IS_LITTLE_ENDIAN) { if (numBytes >= 8) { - p = base.getLong(0); + p = Platform.getLong(base, offset); } else if (numBytes > 4) { - p = base.getLong(0); + p = Platform.getLong(base, offset); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = (long) base.getInt(0); + p = (long) Platform.getInt(base, offset); mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -266,12 +270,12 @@ public long getPrefix() { } else { // byteOrder == ByteOrder.BIG_ENDIAN if (numBytes >= 8) { - p = base.getLong(0); + p = Platform.getLong(base, offset); } else if (numBytes > 4) { - p = base.getLong(0); + p = Platform.getLong(base, offset); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = ((long) base.getInt(0)) << 32; + p = ((long) Platform.getInt(base, offset)) << 32; mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -286,13 +290,12 @@ public long getPrefix() { */ public byte[] getBytes() { // avoid copy if `base` is `byte[]` - long offset = base.getBaseOffset(); - if (offset == BYTE_ARRAY_OFFSET && base instanceof ByteArrayMemoryBlock - && (((ByteArrayMemoryBlock) base).getByteArray()).length == numBytes) { - return ((ByteArrayMemoryBlock) base).getByteArray(); + if (offset == BYTE_ARRAY_OFFSET && base instanceof byte[] + && ((byte[]) base).length == numBytes) { + return (byte[]) base; } else { byte[] bytes = new byte[numBytes]; - base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, numBytes); + copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); return bytes; } } @@ -322,7 +325,7 @@ public UTF8String substring(final int start, final int until) { if (i > j) { byte[] bytes = new byte[i - j]; - base.writeTo(j, bytes, BYTE_ARRAY_OFFSET, i - j); + copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j); return fromBytes(bytes); } else { return EMPTY_UTF8; @@ -363,14 +366,14 @@ public boolean contains(final UTF8String substring) { * Returns the byte at position `i`. */ private byte getByte(int i) { - return base.getByte(i); + return Platform.getByte(base, offset + i); } private boolean matchAt(final UTF8String s, int pos) { if (s.numBytes + pos > numBytes || pos < 0) { return false; } - return ByteArrayMethods.arrayEqualsBlock(base, pos, s.base, 0, s.numBytes); + return ByteArrayMethods.arrayEquals(base, offset + pos, s.base, s.offset, s.numBytes); } public boolean startsWith(final UTF8String prefix) { @@ -497,7 +500,8 @@ public int findInSet(UTF8String match) { for (int i = 0; i < numBytes; i++) { if (getByte(i) == (byte) ',') { if (i - (lastComma + 1) == match.numBytes && - ByteArrayMethods.arrayEqualsBlock(base, lastComma + 1, match.base, 0, match.numBytes)) { + ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, + match.numBytes)) { return n; } lastComma = i; @@ -505,7 +509,8 @@ public int findInSet(UTF8String match) { } } if (numBytes - (lastComma + 1) == match.numBytes && - ByteArrayMethods.arrayEqualsBlock(base, lastComma + 1, match.base, 0, match.numBytes)) { + ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, + match.numBytes)) { return n; } return 0; @@ -520,7 +525,7 @@ public int findInSet(UTF8String match) { private UTF8String copyUTF8String(int start, int end) { int len = end - start + 1; byte[] newBytes = new byte[len]; - base.writeTo(start, newBytes, BYTE_ARRAY_OFFSET, len); + copyMemory(base, offset + start, newBytes, BYTE_ARRAY_OFFSET, len); return UTF8String.fromBytes(newBytes); } @@ -667,7 +672,8 @@ public UTF8String reverse() { int i = 0; // position in byte while (i < numBytes) { int len = numBytesForFirstByte(getByte(i)); - base.writeTo(i, result, BYTE_ARRAY_OFFSET + result.length - i - len, len); + copyMemory(this.base, this.offset + i, result, + BYTE_ARRAY_OFFSET + result.length - i - len, len); i += len; } @@ -681,7 +687,7 @@ public UTF8String repeat(int times) { } byte[] newBytes = new byte[numBytes * times]; - base.writeTo(0, newBytes, BYTE_ARRAY_OFFSET, numBytes); + copyMemory(this.base, this.offset, newBytes, BYTE_ARRAY_OFFSET, numBytes); int copied = 1; while (copied < times) { @@ -718,7 +724,7 @@ public int indexOf(UTF8String v, int start) { if (i + v.numBytes > numBytes) { return -1; } - if (ByteArrayMethods.arrayEqualsBlock(base, i, v.base, 0, v.numBytes)) { + if (ByteArrayMethods.arrayEquals(base, offset + i, v.base, v.offset, v.numBytes)) { return c; } i += numBytesForFirstByte(getByte(i)); @@ -734,7 +740,7 @@ public int indexOf(UTF8String v, int start) { private int find(UTF8String str, int start) { assert (str.numBytes > 0); while (start <= numBytes - str.numBytes) { - if (ByteArrayMethods.arrayEqualsBlock(base, start, str.base, 0, str.numBytes)) { + if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { return start; } start += 1; @@ -748,7 +754,7 @@ private int find(UTF8String str, int start) { private int rfind(UTF8String str, int start) { assert (str.numBytes > 0); while (start >= 0) { - if (ByteArrayMethods.arrayEqualsBlock(base, start, str.base, 0, str.numBytes)) { + if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { return start; } start -= 1; @@ -781,7 +787,7 @@ public UTF8String subStringIndex(UTF8String delim, int count) { return EMPTY_UTF8; } byte[] bytes = new byte[idx]; - base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, idx); + copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, idx); return fromBytes(bytes); } else { @@ -801,7 +807,7 @@ public UTF8String subStringIndex(UTF8String delim, int count) { } int size = numBytes - delim.numBytes - idx; byte[] bytes = new byte[size]; - base.writeTo(idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size); + copyMemory(base, offset + idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size); return fromBytes(bytes); } } @@ -824,15 +830,15 @@ public UTF8String rpad(int len, UTF8String pad) { UTF8String remain = pad.substring(0, spaces - padChars * count); byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; - base.writeTo(0, data, BYTE_ARRAY_OFFSET, this.numBytes); + copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET, this.numBytes); int offset = this.numBytes; int idx = 0; while (idx < count) { - pad.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); + copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++ idx; offset += pad.numBytes; } - remain.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); + copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); return UTF8String.fromBytes(data); } @@ -860,13 +866,13 @@ public UTF8String lpad(int len, UTF8String pad) { int offset = 0; int idx = 0; while (idx < count) { - pad.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); + copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++ idx; offset += pad.numBytes; } - remain.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); + copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); offset += remain.numBytes; - base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, numBytes()); + copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET + offset, numBytes()); return UTF8String.fromBytes(data); } @@ -891,8 +897,8 @@ public static UTF8String concat(UTF8String... inputs) { int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].numBytes; - inputs[i].base.writeTo( - 0, + copyMemory( + inputs[i].base, inputs[i].offset, result, BYTE_ARRAY_OFFSET + offset, len); offset += len; @@ -931,8 +937,8 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { for (int i = 0, j = 0; i < inputs.length; i++) { if (inputs[i] != null) { int len = inputs[i].numBytes; - inputs[i].base.writeTo( - 0, + copyMemory( + inputs[i].base, inputs[i].offset, result, BYTE_ARRAY_OFFSET + offset, len); offset += len; @@ -940,8 +946,8 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { j++; // Add separator if this is not the last input. if (j < numInputs) { - separator.base.writeTo( - 0, + copyMemory( + separator.base, separator.offset, result, BYTE_ARRAY_OFFSET + offset, separator.numBytes); offset += separator.numBytes; @@ -1215,7 +1221,7 @@ public UTF8String clone() { public UTF8String copy() { byte[] bytes = new byte[numBytes]; - base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, numBytes); + copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); return fromBytes(bytes); } @@ -1223,10 +1229,11 @@ public UTF8String copy() { public int compareTo(@Nonnull final UTF8String other) { int len = Math.min(numBytes, other.numBytes); int wordMax = (len / 8) * 8; - MemoryBlock rbase = other.base; + long roffset = other.offset; + Object rbase = other.base; for (int i = 0; i < wordMax; i += 8) { - long left = base.getLong(i); - long right = rbase.getLong(i); + long left = getLong(base, offset + i); + long right = getLong(rbase, roffset + i); if (left != right) { if (IS_LITTLE_ENDIAN) { return Long.compareUnsigned(Long.reverseBytes(left), Long.reverseBytes(right)); @@ -1237,7 +1244,7 @@ public int compareTo(@Nonnull final UTF8String other) { } for (int i = wordMax; i < len; i++) { // In UTF-8, the byte should be unsigned, so we should compare them as unsigned int. - int res = (getByte(i) & 0xFF) - (rbase.getByte(i) & 0xFF); + int res = (getByte(i) & 0xFF) - (Platform.getByte(rbase, roffset + i) & 0xFF); if (res != 0) { return res; } @@ -1256,7 +1263,7 @@ public boolean equals(final Object other) { if (numBytes != o.numBytes) { return false; } - return ByteArrayMethods.arrayEqualsBlock(base, 0, o.base, 0, numBytes); + return ByteArrayMethods.arrayEquals(base, offset, o.base, o.offset, numBytes); } else { return false; } @@ -1312,8 +1319,8 @@ public int levenshteinDistance(UTF8String other) { num_bytes_j != numBytesForFirstByte(s.getByte(i_bytes))) { cost = 1; } else { - cost = (ByteArrayMethods.arrayEqualsBlock(t.base, j_bytes, s.base, - i_bytes, num_bytes_j)) ? 0 : 1; + cost = (ByteArrayMethods.arrayEquals(t.base, t.offset + j_bytes, s.base, + s.offset + i_bytes, num_bytes_j)) ? 0 : 1; } d[i + 1] = Math.min(Math.min(d[i] + 1, p[i + 1] + 1), p[i] + cost); } @@ -1328,7 +1335,7 @@ public int levenshteinDistance(UTF8String other) { @Override public int hashCode() { - return Murmur3_x86_32.hashUnsafeBytesBlock(base,42); + return Murmur3_x86_32.hashUnsafeBytes(base, offset, numBytes, 42); } /** @@ -1391,10 +1398,10 @@ public void writeExternal(ObjectOutput out) throws IOException { } public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + offset = BYTE_ARRAY_OFFSET; numBytes = in.readInt(); - byte[] bytes = new byte[numBytes]; - in.readFully(bytes); - base = ByteArrayMemoryBlock.fromArray(bytes); + base = new byte[numBytes]; + in.readFully((byte[]) base); } @Override @@ -1406,10 +1413,10 @@ public void write(Kryo kryo, Output out) { @Override public void read(Kryo kryo, Input in) { - numBytes = in.readInt(); - byte[] bytes = new byte[numBytes]; - in.read(bytes); - base = ByteArrayMemoryBlock.fromArray(bytes); + this.offset = BYTE_ARRAY_OFFSET; + this.numBytes = in.readInt(); + this.base = new byte[numBytes]; + in.read((byte[]) base); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 583a148b3845d..3ad9ac7b4de9c 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -81,7 +81,7 @@ public void freeingOnHeapMemoryBlockResetsBaseObjectAndOffset() { MemoryAllocator.HEAP.free(block); Assert.assertNull(block.getBaseObject()); Assert.assertEquals(0, block.getBaseOffset()); - Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.getPageNumber()); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber); } @Test @@ -92,7 +92,7 @@ public void freeingOffHeapMemoryBlockResetsOffset() { MemoryAllocator.UNSAFE.free(block); Assert.assertNull(block.getBaseObject()); Assert.assertEquals(0, block.getBaseOffset()); - Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.getPageNumber()); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber); } @Test(expected = AssertionError.class) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java index 8c2e98c2bfc54..fb8e53b3348f3 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java @@ -20,13 +20,14 @@ import org.junit.Assert; import org.junit.Test; -import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; public class LongArraySuite { @Test public void basicTest() { - LongArray arr = new LongArray(new OnHeapMemoryBlock(16)); + long[] bytes = new long[2]; + LongArray arr = new LongArray(MemoryBlock.fromLongArray(bytes)); arr.set(0, 1L); arr.set(1, 2L); arr.set(1, 3L); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java index d9898771720ae..6348a73bf3895 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -70,24 +70,6 @@ public void testKnownBytesInputs() { Murmur3_x86_32.hashUnsafeBytes2(tes, Platform.BYTE_ARRAY_OFFSET, tes.length, 0)); } - @Test - public void testKnownWordsInputs() { - byte[] bytes = new byte[16]; - long offset = Platform.BYTE_ARRAY_OFFSET; - for (int i = 0; i < 16; i++) { - bytes[i] = 0; - } - Assert.assertEquals(-300363099, Murmur3_x86_32.hashUnsafeWords(bytes, offset, 16, 42)); - for (int i = 0; i < 16; i++) { - bytes[i] = -1; - } - Assert.assertEquals(-1210324667, Murmur3_x86_32.hashUnsafeWords(bytes, offset, 16, 42)); - for (int i = 0; i < 16; i++) { - bytes[i] = (byte)i; - } - Assert.assertEquals(-634919701, Murmur3_x86_32.hashUnsafeWords(bytes, offset, 16, 42)); - } - @Test public void randomizedStressTest() { int size = 65536; diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java deleted file mode 100644 index ef5ff8ee70ec0..0000000000000 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.memory; - -import org.apache.spark.unsafe.Platform; -import org.junit.Assert; -import org.junit.Test; - -import java.nio.ByteOrder; - -import static org.hamcrest.core.StringContains.containsString; - -public class MemoryBlockSuite { - private static final boolean bigEndianPlatform = - ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); - - private void check(MemoryBlock memory, Object obj, long offset, int length) { - memory.setPageNumber(1); - memory.fill((byte)-1); - memory.putBoolean(0, true); - memory.putByte(1, (byte)127); - memory.putShort(2, (short)257); - memory.putInt(4, 0x20000002); - memory.putLong(8, 0x1234567089ABCDEFL); - memory.putFloat(16, 1.0F); - memory.putLong(20, 0x1234567089ABCDEFL); - memory.putDouble(28, 2.0); - MemoryBlock.copyMemory(memory, 0L, memory, 36, 4); - int[] a = new int[2]; - a[0] = 0x12345678; - a[1] = 0x13579BDF; - memory.copyFrom(a, Platform.INT_ARRAY_OFFSET, 40, 8); - byte[] b = new byte[8]; - memory.writeTo(40, b, Platform.BYTE_ARRAY_OFFSET, 8); - - Assert.assertEquals(obj, memory.getBaseObject()); - Assert.assertEquals(offset, memory.getBaseOffset()); - Assert.assertEquals(length, memory.size()); - Assert.assertEquals(1, memory.getPageNumber()); - Assert.assertEquals(true, memory.getBoolean(0)); - Assert.assertEquals((byte)127, memory.getByte(1 )); - Assert.assertEquals((short)257, memory.getShort(2)); - Assert.assertEquals(0x20000002, memory.getInt(4)); - Assert.assertEquals(0x1234567089ABCDEFL, memory.getLong(8)); - Assert.assertEquals(1.0F, memory.getFloat(16), 0); - Assert.assertEquals(0x1234567089ABCDEFL, memory.getLong(20)); - Assert.assertEquals(2.0, memory.getDouble(28), 0); - Assert.assertEquals(true, memory.getBoolean(36)); - Assert.assertEquals((byte)127, memory.getByte(37 )); - Assert.assertEquals((short)257, memory.getShort(38)); - Assert.assertEquals(a[0], memory.getInt(40)); - Assert.assertEquals(a[1], memory.getInt(44)); - if (bigEndianPlatform) { - Assert.assertEquals(a[0], - ((int)b[0] & 0xff) << 24 | ((int)b[1] & 0xff) << 16 | - ((int)b[2] & 0xff) << 8 | ((int)b[3] & 0xff)); - Assert.assertEquals(a[1], - ((int)b[4] & 0xff) << 24 | ((int)b[5] & 0xff) << 16 | - ((int)b[6] & 0xff) << 8 | ((int)b[7] & 0xff)); - } else { - Assert.assertEquals(a[0], - ((int)b[3] & 0xff) << 24 | ((int)b[2] & 0xff) << 16 | - ((int)b[1] & 0xff) << 8 | ((int)b[0] & 0xff)); - Assert.assertEquals(a[1], - ((int)b[7] & 0xff) << 24 | ((int)b[6] & 0xff) << 16 | - ((int)b[5] & 0xff) << 8 | ((int)b[4] & 0xff)); - } - for (int i = 48; i < memory.size(); i++) { - Assert.assertEquals((byte) -1, memory.getByte(i)); - } - - assert(memory.subBlock(0, memory.size()) == memory); - - try { - memory.subBlock(-8, 8); - Assert.fail(); - } catch (Exception expected) { - Assert.assertThat(expected.getMessage(), containsString("non-negative")); - } - - try { - memory.subBlock(0, -8); - Assert.fail(); - } catch (Exception expected) { - Assert.assertThat(expected.getMessage(), containsString("non-negative")); - } - - try { - memory.subBlock(0, length + 8); - Assert.fail(); - } catch (Exception expected) { - Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); - } - - try { - memory.subBlock(8, length - 4); - Assert.fail(); - } catch (Exception expected) { - Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); - } - - try { - memory.subBlock(length + 8, 4); - Assert.fail(); - } catch (Exception expected) { - Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); - } - - memory.setPageNumber(MemoryBlock.NO_PAGE_NUMBER); - } - - @Test - public void testByteArrayMemoryBlock() { - byte[] obj = new byte[56]; - long offset = Platform.BYTE_ARRAY_OFFSET; - int length = obj.length; - - MemoryBlock memory = new ByteArrayMemoryBlock(obj, offset, length); - check(memory, obj, offset, length); - - memory = ByteArrayMemoryBlock.fromArray(obj); - check(memory, obj, offset, length); - - obj = new byte[112]; - memory = new ByteArrayMemoryBlock(obj, offset, length); - check(memory, obj, offset, length); - } - - @Test - public void testOnHeapMemoryBlock() { - long[] obj = new long[7]; - long offset = Platform.LONG_ARRAY_OFFSET; - int length = obj.length * 8; - - MemoryBlock memory = new OnHeapMemoryBlock(obj, offset, length); - check(memory, obj, offset, length); - - memory = OnHeapMemoryBlock.fromArray(obj); - check(memory, obj, offset, length); - - obj = new long[14]; - memory = new OnHeapMemoryBlock(obj, offset, length); - check(memory, obj, offset, length); - } - - @Test - public void testOffHeapArrayMemoryBlock() { - MemoryAllocator memoryAllocator = new UnsafeMemoryAllocator(); - MemoryBlock memory = memoryAllocator.allocate(56); - Object obj = memory.getBaseObject(); - long offset = memory.getBaseOffset(); - int length = 56; - - check(memory, obj, offset, length); - memoryAllocator.free(memory); - - long address = Platform.allocateMemory(112); - memory = new OffHeapMemoryBlock(address, length); - obj = memory.getBaseObject(); - offset = memory.getBaseOffset(); - check(memory, obj, offset, length); - Platform.freeMemory(address); - } -} diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 42dda30480702..dae13f03b02ff 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -25,8 +25,7 @@ import java.util.*; import com.google.common.collect.ImmutableMap; -import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; -import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; +import org.apache.spark.unsafe.Platform; import org.junit.Test; import static org.junit.Assert.*; @@ -513,6 +512,21 @@ public void soundex() { assertEquals(fromString("世界千世").soundex(), fromString("世界千世")); } + @Test + public void writeToOutputStreamUnderflow() throws IOException { + // offset underflow is apparently supported? + final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8); + + for (int i = 1; i <= Platform.BYTE_ARRAY_OFFSET; ++i) { + UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i) + .writeTo(outputStream); + final ByteBuffer buffer = ByteBuffer.wrap(outputStream.toByteArray(), i, test.length); + assertEquals("01234567", StandardCharsets.UTF_8.decode(buffer).toString()); + outputStream.reset(); + } + } + @Test public void writeToOutputStreamSlice() throws IOException { final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); @@ -520,7 +534,7 @@ public void writeToOutputStreamSlice() throws IOException { for (int i = 0; i < test.length; ++i) { for (int j = 0; j < test.length - i; ++j) { - new UTF8String(ByteArrayMemoryBlock.fromArray(test).subBlock(i, j)) + UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET + i, j) .writeTo(outputStream); assertArrayEquals(Arrays.copyOfRange(test, i, i + j), outputStream.toByteArray()); @@ -551,7 +565,7 @@ public void writeToOutputStreamOverflow() throws IOException { for (final long offset : offsets) { try { - new UTF8String(ByteArrayMemoryBlock.fromArray(test).subBlock(offset, test.length)) + fromAddress(test, BYTE_ARRAY_OFFSET + offset, test.length) .writeTo(outputStream); throw new IllegalStateException(Long.toString(offset)); @@ -578,25 +592,26 @@ public void writeToOutputStream() throws IOException { } @Test - public void writeToOutputStreamLongArray() throws IOException { + public void writeToOutputStreamIntArray() throws IOException { // verify that writes work on objects that are not byte arrays - final ByteBuffer buffer = StandardCharsets.UTF_8.encode("3千大千世界"); + final ByteBuffer buffer = StandardCharsets.UTF_8.encode("大千世界"); buffer.position(0); buffer.order(ByteOrder.nativeOrder()); final int length = buffer.limit(); - assertEquals(16, length); + assertEquals(12, length); - final int longs = length / 8; - final long[] array = new long[longs]; + final int ints = length / 4; + final int[] array = new int[ints]; - for (int i = 0; i < longs; ++i) { - array[i] = buffer.getLong(); + for (int i = 0; i < ints; ++i) { + array[i] = buffer.getInt(); } final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - new UTF8String(OnHeapMemoryBlock.fromArray(array)).writeTo(outputStream); - assertEquals("3千大千世界", outputStream.toString("UTF-8")); + fromAddress(array, Platform.INT_ARRAY_OFFSET, length) + .writeTo(outputStream); + assertEquals("大千世界", outputStream.toString("UTF-8")); } @Test diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 8651a639c07f7..d07faf1da1248 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -311,7 +311,7 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { // this could trigger spilling to free some pages. return allocatePage(size, consumer); } - page.setPageNumber(pageNumber); + page.pageNumber = pageNumber; pageTable[pageNumber] = page; if (logger.isTraceEnabled()) { logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired); @@ -323,25 +323,25 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}. */ public void freePage(MemoryBlock page, MemoryConsumer consumer) { - assert (page.getPageNumber() != MemoryBlock.NO_PAGE_NUMBER) : + assert (page.pageNumber != MemoryBlock.NO_PAGE_NUMBER) : "Called freePage() on memory that wasn't allocated with allocatePage()"; - assert (page.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + assert (page.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : "Called freePage() on a memory block that has already been freed"; - assert (page.getPageNumber() != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) : + assert (page.pageNumber != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) : "Called freePage() on a memory block that has already been freed"; - assert(allocatedPages.get(page.getPageNumber())); - pageTable[page.getPageNumber()] = null; + assert(allocatedPages.get(page.pageNumber)); + pageTable[page.pageNumber] = null; synchronized (this) { - allocatedPages.clear(page.getPageNumber()); + allocatedPages.clear(page.pageNumber); } if (logger.isTraceEnabled()) { - logger.trace("Freed page number {} ({} bytes)", page.getPageNumber(), page.size()); + logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); } long pageSize = page.size(); // Clear the page number before passing the block to the MemoryAllocator's free(). // Doing this allows the MemoryAllocator to detect when a TaskMemoryManager-managed // page has been inappropriately directly freed without calling TMM.freePage(). - page.setPageNumber(MemoryBlock.FREED_IN_TMM_PAGE_NUMBER); + page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; memoryManager.tungstenMemoryAllocator().free(page); releaseExecutionMemory(pageSize, consumer); } @@ -363,7 +363,7 @@ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { // relative to the page's base offset; this relative offset will fit in 51 bits. offsetInPage -= page.getBaseOffset(); } - return encodePageNumberAndOffset(page.getPageNumber(), offsetInPage); + return encodePageNumberAndOffset(page.pageNumber, offsetInPage); } @VisibleForTesting @@ -434,7 +434,7 @@ public long cleanUpAllAllocatedMemory() { for (MemoryBlock page : pageTable) { if (page != null) { logger.debug("unreleased page: " + page + " in task " + taskAttemptId); - page.setPageNumber(MemoryBlock.FREED_IN_TMM_PAGE_NUMBER); + page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; memoryManager.tungstenMemoryAllocator().free(page); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index 4b48599ad311e..0d069125dc60e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -20,6 +20,7 @@ import java.util.Comparator; import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.Sorter; @@ -112,7 +113,13 @@ public void reset() { public void expandPointerArray(LongArray newArray) { assert(newArray.size() > array.size()); - MemoryBlock.copyMemory(array.memoryBlock(), newArray.memoryBlock(), pos * 8L); + Platform.copyMemory( + array.getBaseObject(), + array.getBaseOffset(), + newArray.getBaseObject(), + newArray.getBaseOffset(), + pos * 8L + ); consumer.freeArray(array); array = newArray; usableCapacity = getUsableCapacity(); @@ -181,7 +188,10 @@ public ShuffleSorterIterator getSortedIterator() { PackedRecordPointer.PARTITION_ID_START_BYTE_INDEX, PackedRecordPointer.PARTITION_ID_END_BYTE_INDEX, false, false); } else { - MemoryBlock unused = array.memoryBlock().subBlock(pos * 8L, (array.size() - pos) * 8L); + MemoryBlock unused = new MemoryBlock( + array.getBaseObject(), + array.getBaseOffset() + pos * 8L, + (array.size() - pos) * 8L); LongArray buffer = new LongArray(unused); Sorter sorter = new Sorter<>(new ShuffleSortDataFormat(buffer)); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java index 254449e95443e..717bdd79d47ef 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java @@ -17,8 +17,8 @@ package org.apache.spark.shuffle.sort; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; -import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.SortDataFormat; final class ShuffleSortDataFormat extends SortDataFormat { @@ -60,8 +60,13 @@ public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) { @Override public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) { - MemoryBlock.copyMemory(src.memoryBlock(), srcPos * 8L, - dst.memoryBlock(),dstPos * 8L,length * 8L); + Platform.copyMemory( + src.getBaseObject(), + src.getBaseOffset() + srcPos * 8L, + dst.getBaseObject(), + dst.getBaseOffset() + dstPos * 8L, + length * 8L + ); } @Override diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 399251b80e649..5056652a2420b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -544,7 +544,7 @@ public long spill() throws IOException { // is accessing the current record. We free this page in that caller's next loadNext() // call. for (MemoryBlock page : allocatedPages) { - if (!loaded || page.getPageNumber() != + if (!loaded || page.pageNumber != ((UnsafeInMemorySorter.SortedIterator)upstream).getCurrentPageNumber()) { released += page.size(); freePage(page); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 717823ebbd320..75690ae264838 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -26,6 +26,7 @@ import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.UnsafeAlignedOffset; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -215,7 +216,12 @@ public void expandPointerArray(LongArray newArray) { if (newArray.size() < array.size()) { throw new SparkOutOfMemoryError("Not enough memory to grow pointer array"); } - MemoryBlock.copyMemory(array.memoryBlock(), newArray.memoryBlock(), pos * 8L); + Platform.copyMemory( + array.getBaseObject(), + array.getBaseOffset(), + newArray.getBaseObject(), + newArray.getBaseOffset(), + pos * 8L); consumer.freeArray(array); array = newArray; usableCapacity = getUsableCapacity(); @@ -342,7 +348,10 @@ public UnsafeSorterIterator getSortedIterator() { array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned()); } else { - MemoryBlock unused = array.memoryBlock().subBlock(pos * 8L, (array.size() - pos) * 8L); + MemoryBlock unused = new MemoryBlock( + array.getBaseObject(), + array.getBaseOffset() + pos * 8L, + (array.size() - pos) * 8L); LongArray buffer = new LongArray(unused); Sorter sorter = new Sorter<>(new UnsafeSortDataFormat(buffer)); diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index d7d2d0b012bd3..a0664b30d6cc2 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -76,7 +76,7 @@ public void freeingPageSetsPageNumberToSpecialConstant() { final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); final MemoryBlock dataPage = manager.allocatePage(256, c); c.freePage(dataPage); - Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.getPageNumber()); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.pageNumber); } @Test(expected = AssertionError.class) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 3e56db5ea116a..47173b89e91e2 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark._ import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.unsafe.array.LongArray -import org.apache.spark.unsafe.memory.OnHeapMemoryBlock +import org.apache.spark.unsafe.memory.MemoryBlock import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordPointerAndKeyPrefix, UnsafeSortDataFormat} class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { @@ -105,8 +105,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { // the form [150000000, 150000001, 150000002, ...., 300000000, 0, 1, 2, ..., 149999999] // that can trigger copyRange() in TimSort.mergeLo() or TimSort.mergeHi() val ref = Array.tabulate[Long](size) { i => if (i < size / 2) size / 2 + i else i } - val buf = new LongArray(OnHeapMemoryBlock.fromArray(ref)) - val tmpBuf = new LongArray(new OnHeapMemoryBlock((size/2) * 8L)) + val buf = new LongArray(MemoryBlock.fromLongArray(ref)) + val tmp = new Array[Long](size/2) + val tmpBuf = new LongArray(MemoryBlock.fromLongArray(tmp)) new Sorter(new UnsafeSortDataFormat(tmpBuf)).sort( buf, 0, size, new Comparator[RecordPointerAndKeyPrefix] { diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala index ddf3740e76a7a..d5956ea32096a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala @@ -27,7 +27,7 @@ import com.google.common.primitives.Ints import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.unsafe.array.LongArray -import org.apache.spark.unsafe.memory.OnHeapMemoryBlock +import org.apache.spark.unsafe.memory.MemoryBlock import org.apache.spark.util.collection.Sorter import org.apache.spark.util.random.XORShiftRandom @@ -78,14 +78,14 @@ class RadixSortSuite extends SparkFunSuite with Logging { private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = { val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand } val extended = ref ++ Array.fill[Long](Ints.checkedCast(size))(0) - (ref.map(i => new JLong(i)), new LongArray(OnHeapMemoryBlock.fromArray(extended))) + (ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended))) } private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = { val ref = Array.tabulate[Long](Ints.checkedCast(size * 2)) { i => rand } val extended = ref ++ Array.fill[Long](Ints.checkedCast(size * 2))(0) - (new LongArray(OnHeapMemoryBlock.fromArray(ref)), - new LongArray(OnHeapMemoryBlock.fromArray(extended))) + (new LongArray(MemoryBlock.fromLongArray(ref)), + new LongArray(MemoryBlock.fromLongArray(extended))) } private def collectToArray(array: LongArray, offset: Int, length: Long): Array[Long] = { @@ -110,7 +110,7 @@ class RadixSortSuite extends SparkFunSuite with Logging { } private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) { - val sortBuffer = new LongArray(new OnHeapMemoryBlock(buf.size() * 8L)) + val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort( buf, Ints.checkedCast(lo), Ints.checkedCast(hi), new Comparator[RecordPointerAndKeyPrefix] { override def compare( diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index dc38ee326e5e9..dc18e1d34880a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -29,7 +29,7 @@ import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF} import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2Block} +import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashMap @@ -244,7 +244,8 @@ object FeatureHasher extends DefaultParamsReadable[FeatureHasher] { case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed) case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) case s: String => - hashUnsafeBytes2Block(UTF8String.fromString(s).getMemoryBlock, seed) + val utf8 = UTF8String.fromString(s) + hashUnsafeBytes2(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed) case _ => throw new SparkException("FeatureHasher with murmur3 algorithm does not " + s"support type ${term.getClass.getCanonicalName} of input data.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala index 7b73b286fb91c..8935c8496cdbb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -160,7 +160,7 @@ object HashingTF { case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) case s: String => val utf8 = UTF8String.fromString(s) - hashUnsafeBytesBlock(utf8.getMemoryBlock(), seed) + hashUnsafeBytes(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed) case _ => throw new SparkException("HashingTF with murmur3 algorithm does not " + s"support type ${term.getClass.getCanonicalName} of input data.") } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 9e7b15d339eeb..9002abdcfd474 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -27,7 +27,6 @@ import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; -import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -241,8 +240,7 @@ public UTF8String getUTF8String(int ordinal) { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) offsetAndSize; - MemoryBlock mb = MemoryBlock.allocateFromObject(baseObject, baseOffset + offset, size); - return new UTF8String(mb); + return UTF8String.fromAddress(baseObject, baseOffset + offset, size); } @Override diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 469b0e60cc9a2..a76e6ef8c91c1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -37,7 +37,6 @@ import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; -import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -417,8 +416,7 @@ public UTF8String getUTF8String(int ordinal) { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) offsetAndSize; - MemoryBlock mb = MemoryBlock.allocateFromObject(baseObject, baseOffset + offset, size); - return new UTF8String(mb); + return UTF8String.fromAddress(baseObject, baseOffset + offset, size); } @Override diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java index 8e9c0a2e9dc81..eb5051b284073 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.UTF8String; // scalastyle: off @@ -72,13 +72,13 @@ public static long hashLong(long input, long seed) { return fmix(hash); } - public long hashUnsafeWordsBlock(MemoryBlock mb) { - return hashUnsafeWordsBlock(mb, seed); + public long hashUnsafeWords(Object base, long offset, int length) { + return hashUnsafeWords(base, offset, length, seed); } - public static long hashUnsafeWordsBlock(MemoryBlock mb, long seed) { - assert (mb.size() % 8 == 0) : "lengthInBytes must be a multiple of 8 (word-aligned)"; - long hash = hashBytesByWordsBlock(mb, seed); + public static long hashUnsafeWords(Object base, long offset, int length, long seed) { + assert (length % 8 == 0) : "lengthInBytes must be a multiple of 8 (word-aligned)"; + long hash = hashBytesByWords(base, offset, length, seed); return fmix(hash); } @@ -86,22 +86,20 @@ public long hashUnsafeBytes(Object base, long offset, int length) { return hashUnsafeBytes(base, offset, length, seed); } - public static long hashUnsafeBytesBlock(MemoryBlock mb, long seed) { - long offset = 0; - long length = mb.size(); + public static long hashUnsafeBytes(Object base, long offset, int length, long seed) { assert (length >= 0) : "lengthInBytes cannot be negative"; - long hash = hashBytesByWordsBlock(mb, seed); + long hash = hashBytesByWords(base, offset, length, seed); long end = offset + length; offset += length & -8; if (offset + 4L <= end) { - hash ^= (mb.getInt(offset) & 0xFFFFFFFFL) * PRIME64_1; + hash ^= (Platform.getInt(base, offset) & 0xFFFFFFFFL) * PRIME64_1; hash = Long.rotateLeft(hash, 23) * PRIME64_2 + PRIME64_3; offset += 4L; } while (offset < end) { - hash ^= (mb.getByte(offset) & 0xFFL) * PRIME64_5; + hash ^= (Platform.getByte(base, offset) & 0xFFL) * PRIME64_5; hash = Long.rotateLeft(hash, 11) * PRIME64_1; offset++; } @@ -109,11 +107,7 @@ public static long hashUnsafeBytesBlock(MemoryBlock mb, long seed) { } public static long hashUTF8String(UTF8String str, long seed) { - return hashUnsafeBytesBlock(str.getMemoryBlock(), seed); - } - - public static long hashUnsafeBytes(Object base, long offset, int length, long seed) { - return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, length), seed); + return hashUnsafeBytes(str.getBaseObject(), str.getBaseOffset(), str.numBytes(), seed); } private static long fmix(long hash) { @@ -125,31 +119,30 @@ private static long fmix(long hash) { return hash; } - private static long hashBytesByWordsBlock(MemoryBlock mb, long seed) { - long offset = 0; - long length = mb.size(); + private static long hashBytesByWords(Object base, long offset, int length, long seed) { + long end = offset + length; long hash; if (length >= 32) { - long limit = length - 32; + long limit = end - 32; long v1 = seed + PRIME64_1 + PRIME64_2; long v2 = seed + PRIME64_2; long v3 = seed; long v4 = seed - PRIME64_1; do { - v1 += mb.getLong(offset) * PRIME64_2; + v1 += Platform.getLong(base, offset) * PRIME64_2; v1 = Long.rotateLeft(v1, 31); v1 *= PRIME64_1; - v2 += mb.getLong(offset + 8) * PRIME64_2; + v2 += Platform.getLong(base, offset + 8) * PRIME64_2; v2 = Long.rotateLeft(v2, 31); v2 *= PRIME64_1; - v3 += mb.getLong(offset + 16) * PRIME64_2; + v3 += Platform.getLong(base, offset + 16) * PRIME64_2; v3 = Long.rotateLeft(v3, 31); v3 *= PRIME64_1; - v4 += mb.getLong(offset + 24) * PRIME64_2; + v4 += Platform.getLong(base, offset + 24) * PRIME64_2; v4 = Long.rotateLeft(v4, 31); v4 *= PRIME64_1; @@ -190,9 +183,9 @@ private static long hashBytesByWordsBlock(MemoryBlock mb, long seed) { hash += length; - long limit = length - 8; + long limit = end - 8; while (offset <= limit) { - long k1 = mb.getLong(offset); + long k1 = Platform.getLong(base, offset); hash ^= Long.rotateLeft(k1 * PRIME64_2, 31) * PRIME64_1; hash = Long.rotateLeft(hash, 27) * PRIME64_1 + PRIME64_4; offset += 8L; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java index f8000d78cd1b6..f0f66bae245fd 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java @@ -19,8 +19,6 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; -import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; -import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.UTF8String; /** @@ -31,34 +29,43 @@ public class UTF8StringBuilder { private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; - private ByteArrayMemoryBlock buffer; - private int length = 0; + private byte[] buffer; + private int cursor = Platform.BYTE_ARRAY_OFFSET; public UTF8StringBuilder() { // Since initial buffer size is 16 in `StringBuilder`, we set the same size here - this.buffer = new ByteArrayMemoryBlock(16); + this.buffer = new byte[16]; } // Grows the buffer by at least `neededSize` private void grow(int neededSize) { - if (neededSize > ARRAY_MAX - length) { + if (neededSize > ARRAY_MAX - totalSize()) { throw new UnsupportedOperationException( "Cannot grow internal buffer by size " + neededSize + " because the size after growing " + "exceeds size limitation " + ARRAY_MAX); } - final int requestedSize = length + neededSize; - if (buffer.size() < requestedSize) { - int newLength = requestedSize < ARRAY_MAX / 2 ? requestedSize * 2 : ARRAY_MAX; - final ByteArrayMemoryBlock tmp = new ByteArrayMemoryBlock(newLength); - MemoryBlock.copyMemory(buffer, tmp, length); + final int length = totalSize() + neededSize; + if (buffer.length < length) { + int newLength = length < ARRAY_MAX / 2 ? length * 2 : ARRAY_MAX; + final byte[] tmp = new byte[newLength]; + Platform.copyMemory( + buffer, + Platform.BYTE_ARRAY_OFFSET, + tmp, + Platform.BYTE_ARRAY_OFFSET, + totalSize()); buffer = tmp; } } + private int totalSize() { + return cursor - Platform.BYTE_ARRAY_OFFSET; + } + public void append(UTF8String value) { grow(value.numBytes()); - value.writeToMemory(buffer.getByteArray(), length + Platform.BYTE_ARRAY_OFFSET); - length += value.numBytes(); + value.writeToMemory(buffer, cursor); + cursor += value.numBytes(); } public void append(String value) { @@ -66,6 +73,6 @@ public void append(String value) { } public UTF8String build() { - return UTF8String.fromBytes(buffer.getByteArray(), 0, length); + return UTF8String.fromBytes(buffer, 0, totalSize()); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index a754e87a17968..742a4f87a9c04 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 -import org.apache.spark.unsafe.memory.MemoryBlock import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -362,7 +361,10 @@ abstract class HashExpression[E] extends Expression { } protected def genHashString(input: String, result: String): String = { - s"$result = $hasherClassName.hashUTF8String($input, $result);" + val baseObject = s"$input.getBaseObject()" + val baseOffset = s"$input.getBaseOffset()" + val numBytes = s"$input.numBytes()" + s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" } protected def genHashForMap( @@ -469,8 +471,6 @@ abstract class InterpretedHashFunction { protected def hashUnsafeBytes(base: AnyRef, offset: Long, length: Int, seed: Long): Long - protected def hashUnsafeBytesBlock(base: MemoryBlock, seed: Long): Long - /** * Computes hash of a given `value` of type `dataType`. The caller needs to check the validity * of input `value`. @@ -496,7 +496,8 @@ abstract class InterpretedHashFunction { case c: CalendarInterval => hashInt(c.months, hashLong(c.microseconds, seed)) case a: Array[Byte] => hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed) - case s: UTF8String => hashUnsafeBytesBlock(s.getMemoryBlock(), seed) + case s: UTF8String => + hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed) case array: ArrayData => val elementType = dataType match { @@ -583,15 +584,9 @@ object Murmur3HashFunction extends InterpretedHashFunction { Murmur3_x86_32.hashLong(l, seed.toInt) } - override protected def hashUnsafeBytes( - base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { Murmur3_x86_32.hashUnsafeBytes(base, offset, len, seed.toInt) } - - override protected def hashUnsafeBytesBlock( - base: MemoryBlock, seed: Long): Long = { - Murmur3_x86_32.hashUnsafeBytesBlock(base, seed.toInt) - } } /** @@ -616,14 +611,9 @@ object XxHash64Function extends InterpretedHashFunction { override protected def hashLong(l: Long, seed: Long): Long = XXH64.hashLong(l, seed) - override protected def hashUnsafeBytes( - base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { XXH64.hashUnsafeBytes(base, offset, len, seed) } - - override protected def hashUnsafeBytesBlock(base: MemoryBlock, seed: Long): Long = { - XXH64.hashUnsafeBytesBlock(base, seed) - } } /** @@ -730,7 +720,10 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """ override protected def genHashString(input: String, result: String): String = { - s"$result = $hasherClassName.hashUTF8String($input);" + val baseObject = s"$input.getBaseObject()" + val baseOffset = s"$input.getBaseOffset()" + val numBytes = s"$input.numBytes()" + s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes);" } override protected def genHashForArray( @@ -824,14 +817,10 @@ object HiveHashFunction extends InterpretedHashFunction { HiveHasher.hashLong(l) } - override protected def hashUnsafeBytes( - base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { HiveHasher.hashUnsafeBytes(base, offset, len) } - override protected def hashUnsafeBytesBlock( - base: MemoryBlock, seed: Long): Long = HiveHasher.hashUnsafeBytesBlock(base) - private val HIVE_DECIMAL_MAX_PRECISION = 38 private val HIVE_DECIMAL_MAX_SCALE = 38 diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java index 76930f9368514..b67c6f3e6e85e 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; -import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.UTF8String; import org.junit.Assert; import org.junit.Test; @@ -54,7 +53,7 @@ public void testKnownStringAndIntInputs() { for (int i = 0; i < inputs.length; i++) { UTF8String s = UTF8String.fromString("val_" + inputs[i]); - int hash = HiveHasher.hashUnsafeBytesBlock(s.getMemoryBlock()); + int hash = HiveHasher.hashUnsafeBytes(s.getBaseObject(), s.getBaseOffset(), s.numBytes()); Assert.assertEquals(expected[i], ((31 * inputs[i]) + hash)); } } @@ -90,13 +89,13 @@ public void randomizedStressTestBytes() { int byteArrSize = rand.nextInt(100) * 8; byte[] bytes = new byte[byteArrSize]; rand.nextBytes(bytes); - MemoryBlock mb = ByteArrayMemoryBlock.fromArray(bytes); Assert.assertEquals( - HiveHasher.hashUnsafeBytesBlock(mb), - HiveHasher.hashUnsafeBytesBlock(mb)); + HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); - hashcodes.add(HiveHasher.hashUnsafeBytesBlock(mb)); + hashcodes.add(HiveHasher.hashUnsafeBytes( + bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); } // A very loose bound. @@ -113,13 +112,13 @@ public void randomizedStressTestPaddedStrings() { byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8); byte[] paddedBytes = new byte[byteArrSize]; System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); - MemoryBlock mb = ByteArrayMemoryBlock.fromArray(paddedBytes); Assert.assertEquals( - HiveHasher.hashUnsafeBytesBlock(mb), - HiveHasher.hashUnsafeBytesBlock(mb)); + HiveHasher.hashUnsafeBytes(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + HiveHasher.hashUnsafeBytes(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); - hashcodes.add(HiveHasher.hashUnsafeBytesBlock(mb)); + hashcodes.add(HiveHasher.hashUnsafeBytes( + paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); } // A very loose bound. diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java index cd8bce623c5df..1baee91b3439c 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java @@ -24,8 +24,6 @@ import java.util.Set; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; -import org.apache.spark.unsafe.memory.MemoryBlock; import org.junit.Assert; import org.junit.Test; @@ -144,13 +142,13 @@ public void randomizedStressTestBytes() { int byteArrSize = rand.nextInt(100) * 8; byte[] bytes = new byte[byteArrSize]; rand.nextBytes(bytes); - MemoryBlock mb = ByteArrayMemoryBlock.fromArray(bytes); Assert.assertEquals( - hasher.hashUnsafeWordsBlock(mb), - hasher.hashUnsafeWordsBlock(mb)); + hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); - hashcodes.add(hasher.hashUnsafeWordsBlock(mb)); + hashcodes.add(hasher.hashUnsafeWords( + bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); } // A very loose bound. @@ -167,13 +165,13 @@ public void randomizedStressTestPaddedStrings() { byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8); byte[] paddedBytes = new byte[byteArrSize]; System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); - MemoryBlock mb = ByteArrayMemoryBlock.fromArray(paddedBytes); Assert.assertEquals( - hasher.hashUnsafeWordsBlock(mb), - hasher.hashUnsafeWordsBlock(mb)); + hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); - hashcodes.add(hasher.hashUnsafeWordsBlock(mb)); + hashcodes.add(hasher.hashUnsafeWords( + paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); } // A very loose bound. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 6fdadde628551..5e0cf7d370dd1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -23,7 +23,6 @@ import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.memory.OffHeapMemoryBlock; import org.apache.spark.unsafe.types.UTF8String; /** @@ -207,7 +206,7 @@ public byte[] getBytes(int rowId, int count) { @Override protected UTF8String getBytesAsUTF8String(int rowId, int count) { - return new UTF8String(new OffHeapMemoryBlock(data + rowId, count)); + return UTF8String.fromAddress(null, data + rowId, count); } // diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 1c9beda404356..5f58b031f6aef 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -25,7 +25,6 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.arrow.ArrowUtils; import org.apache.spark.sql.types.*; -import org.apache.spark.unsafe.memory.OffHeapMemoryBlock; import org.apache.spark.unsafe.types.UTF8String; /** @@ -378,10 +377,9 @@ final UTF8String getUTF8String(int rowId) { if (stringResult.isSet == 0) { return null; } else { - return new UTF8String(new OffHeapMemoryBlock( + return UTF8String.fromAddress(null, stringResult.buffer.memoryAddress() + stringResult.start, - stringResult.end - stringResult.start - )); + stringResult.end - stringResult.start); } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala index 470b93efd1974..50ae26a3ff9d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.benchmark import java.util.{Arrays, Comparator} import org.apache.spark.unsafe.array.LongArray -import org.apache.spark.unsafe.memory.OnHeapMemoryBlock +import org.apache.spark.unsafe.memory.MemoryBlock import org.apache.spark.util.Benchmark import org.apache.spark.util.collection.Sorter import org.apache.spark.util.collection.unsafe.sort._ @@ -36,7 +36,7 @@ import org.apache.spark.util.random.XORShiftRandom class SortBenchmark extends BenchmarkBase { private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) { - val sortBuffer = new LongArray(new OnHeapMemoryBlock(buf.size() * 8L)) + val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort( buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] { override def compare( @@ -50,8 +50,8 @@ class SortBenchmark extends BenchmarkBase { private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = { val ref = Array.tabulate[Long](size * 2) { i => rand } val extended = ref ++ Array.fill[Long](size * 2)(0) - (new LongArray(OnHeapMemoryBlock.fromArray(ref)), - new LongArray(OnHeapMemoryBlock.fromArray(extended))) + (new LongArray(MemoryBlock.fromLongArray(ref)), + new LongArray(MemoryBlock.fromLongArray(extended))) } ignore("sort") { @@ -60,7 +60,7 @@ class SortBenchmark extends BenchmarkBase { val benchmark = new Benchmark("radix sort " + size, size) benchmark.addTimerCase("reference TimSort key prefix array") { timer => val array = Array.tabulate[Long](size * 2) { i => rand.nextLong } - val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) + val buf = new LongArray(MemoryBlock.fromLongArray(array)) timer.startTiming() referenceKeyPrefixSort(buf, 0, size, PrefixComparators.BINARY) timer.stopTiming() @@ -78,7 +78,7 @@ class SortBenchmark extends BenchmarkBase { array(i) = rand.nextLong & 0xff i += 1 } - val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) + val buf = new LongArray(MemoryBlock.fromLongArray(array)) timer.startTiming() RadixSort.sort(buf, size, 0, 7, false, false) timer.stopTiming() @@ -90,7 +90,7 @@ class SortBenchmark extends BenchmarkBase { array(i) = rand.nextLong & 0xffff i += 1 } - val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) + val buf = new LongArray(MemoryBlock.fromLongArray(array)) timer.startTiming() RadixSort.sort(buf, size, 0, 7, false, false) timer.stopTiming() @@ -102,7 +102,7 @@ class SortBenchmark extends BenchmarkBase { array(i) = rand.nextLong i += 1 } - val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) + val buf = new LongArray(MemoryBlock.fromLongArray(array)) timer.startTiming() RadixSort.sort(buf, size, 0, 7, false, false) timer.stopTiming() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala index 25ee95daa034c..ffda33cf906c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala @@ -22,13 +22,13 @@ import java.io.File import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.unsafe.memory.OnHeapMemoryBlock +import org.apache.spark.unsafe.memory.MemoryBlock import org.apache.spark.util.Utils class RowQueueSuite extends SparkFunSuite { test("in-memory queue") { - val page = new OnHeapMemoryBlock((1<<10) * 8L) + val page = MemoryBlock.fromLongArray(new Array[Long](1<<10)) val queue = new InMemoryRowQueue(page, 1) { override def close() {} } From 88a930dfab56c15df02c7bb944444745c2921fa5 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Sun, 9 Sep 2018 09:49:13 -0500 Subject: [PATCH 1580/2461] [MINOR][ML] Remove `BisectingKMeansModel.setDistanceMeasure` method ## What changes were proposed in this pull request? Remove `BisectingKMeansModel.setDistanceMeasure` method. In `BisectingKMeansModel` set this param is meaningless. ## How was this patch tested? N/A Closes #22360 from WeichenXu123/bkmeans_update. Authored-by: WeichenXu Signed-off-by: Sean Owen --- .../org/apache/spark/ml/clustering/BisectingKMeans.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 8904193cae94c..5cb16cc765887 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -104,10 +104,6 @@ class BisectingKMeansModel private[ml] ( @Since("2.1.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) - /** @group expertSetParam */ - @Since("2.4.0") - def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) - @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) From 77c996403d5c761f0dfea64c5b1cb7480ba1d3ac Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 9 Sep 2018 09:07:31 -0700 Subject: [PATCH 1581/2461] [SPARK-25368][SQL] Incorrect predicate pushdown returns wrong result ## What changes were proposed in this pull request? How to reproduce: ```scala val df1 = spark.createDataFrame(Seq( (1, 1) )).toDF("a", "b").withColumn("c", lit(null).cast("int")) val df2 = df1.union(df1).withColumn("d", spark_partition_id).filter($"c".isNotNull) df2.show +---+---+----+---+ | a| b| c| d| +---+---+----+---+ | 1| 1|null| 0| | 1| 1|null| 1| +---+---+----+---+ ``` `filter($"c".isNotNull)` was transformed to `(null <=> c#10)` before https://github.com/apache/spark/pull/19201, but it is transformed to `(c#10 = null)` since https://github.com/apache/spark/pull/20155. This pr revert it to `(null <=> c#10)` to fix this issue. ## How was this patch tested? unit tests Closes #22368 from wangyum/SPARK-25368. Authored-by: Yuming Wang Signed-off-by: gatorsmile --- .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../InferFiltersFromConstraintsSuite.scala | 2 +- .../org/apache/spark/sql/DataFrameSuite.scala | 17 +++++++++++++++++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 0e4456ac0e6a9..5f136629eb15b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -159,7 +159,7 @@ abstract class UnaryNode extends LogicalPlan { var allConstraints = child.constraints.asInstanceOf[Set[Expression]] projectList.foreach { case a @ Alias(l: Literal, _) => - allConstraints += EqualTo(a.toAttribute, l) + allConstraints += EqualNullSafe(a.toAttribute, l) case a @ Alias(e, _) => // For every alias in `projectList`, replace the reference in constraints by its attribute. allConstraints ++= allConstraints.map(_ transform { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index e4671f0d1cce6..a40ba2dc38b70 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -196,7 +196,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("constraints should be inferred from aliased literals") { val originalLeft = testRelation.subquery('left).as("left") - val optimizedLeft = testRelation.subquery('left).where(IsNotNull('a) && 'a === 2).as("left") + val optimizedLeft = testRelation.subquery('left).where(IsNotNull('a) && 'a <=> 2).as("left") val right = Project(Seq(Literal(2).as("two")), testRelation.subquery('right)).as("right") val condition = Some("left.a".attr === "right.two".attr) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 45b17b3d4958f..435b887cb3c78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2552,4 +2552,21 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } + test("SPARK-25368 Incorrect predicate pushdown returns wrong result") { + def check(newCol: Column, filter: Column, result: Seq[Row]): Unit = { + val df1 = spark.createDataFrame(Seq( + (1, 1) + )).toDF("a", "b").withColumn("c", newCol) + + val df2 = df1.union(df1).withColumn("d", spark_partition_id).filter(filter) + checkAnswer(df2, result) + } + + check(lit(null).cast("int"), $"c".isNull, Seq(Row(1, 1, null, 0), Row(1, 1, null, 1))) + check(lit(null).cast("int"), $"c".isNotNull, Seq()) + check(lit(2).cast("int"), $"c".isNull, Seq()) + check(lit(2).cast("int"), $"c".isNotNull, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) + check(lit(2).cast("int"), $"c" === 2, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) + check(lit(2).cast("int"), $"c" =!= 2, Seq()) + } } From a0aed475c54079665a8e5c5cd53a2e990a4f47b4 Mon Sep 17 00:00:00 2001 From: seancxmao Date: Sun, 9 Sep 2018 19:22:47 -0700 Subject: [PATCH 1582/2461] [SPARK-25175][SQL] Field resolution should fail if there is ambiguity for ORC native data source table persisted in metastore ## What changes were proposed in this pull request? Apache Spark doesn't create Hive table with duplicated fields in both case-sensitive and case-insensitive mode. However, if Spark creates ORC files in case-sensitive mode first and create Hive table on that location, where it's created. In this situation, field resolution should fail in case-insensitive mode. Otherwise, we don't know which columns will be returned or filtered. Previously, SPARK-25132 fixed the same issue in Parquet. Here is a simple example: ``` val data = spark.range(5).selectExpr("id as a", "id * 2 as A") spark.conf.set("spark.sql.caseSensitive", true) data.write.format("orc").mode("overwrite").save("/user/hive/warehouse/orc_data") sql("CREATE TABLE orc_data_source (A LONG) USING orc LOCATION '/user/hive/warehouse/orc_data'") spark.conf.set("spark.sql.caseSensitive", false) sql("select A from orc_data_source").show +---+ | A| +---+ | 3| | 2| | 4| | 1| | 0| +---+ ``` See #22148 for more details about parquet data source reader. ## How was this patch tested? Unit tests added. Closes #22262 from seancxmao/SPARK-25175. Authored-by: seancxmao Signed-off-by: Dongjoon Hyun --- .../execution/datasources/orc/OrcUtils.scala | 29 +++++++- .../spark/sql/FileBasedDataSourceSuite.scala | 71 ++++++++++--------- 2 files changed, 62 insertions(+), 38 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index ac062fdc092ee..95fb25bf5addb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.orc +import java.util.Locale + import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration @@ -27,7 +29,7 @@ import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution} +import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.types._ @@ -116,8 +118,29 @@ object OrcUtils extends Logging { } }) } else { - val resolver = if (isCaseSensitive) caseSensitiveResolution else caseInsensitiveResolution - Some(requiredSchema.fieldNames.map { name => orcFieldNames.indexWhere(resolver(_, name)) }) + if (isCaseSensitive) { + Some(requiredSchema.fieldNames.map { name => + orcFieldNames.indexWhere(caseSensitiveResolution(_, name)) + }) + } else { + // Do case-insensitive resolution only if in case-insensitive mode + val caseInsensitiveOrcFieldMap = + orcFieldNames.zipWithIndex.groupBy(_._1.toLowerCase(Locale.ROOT)) + Some(requiredSchema.fieldNames.map { requiredFieldName => + caseInsensitiveOrcFieldMap + .get(requiredFieldName.toLowerCase(Locale.ROOT)) + .map { matchedOrcFields => + if (matchedOrcFields.size > 1) { + // Need to fail if there is ambiguity, i.e. more than one field is matched. + val matchedOrcFieldsString = matchedOrcFields.map(_._1).mkString("[", ", ", "]") + throw new RuntimeException(s"""Found duplicate field(s) "$requiredFieldName": """ + + s"$matchedOrcFieldsString in case-insensitive mode") + } else { + matchedOrcFields.head._2 + } + }.getOrElse(-1) + }) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 304ede9c5a612..94f163708832c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -434,44 +434,45 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } } - test(s"SPARK-25132: case-insensitive field resolution when reading from Parquet") { - withTempDir { dir => - val format = "parquet" - val tableDir = dir.getCanonicalPath + s"/$format" - val tableName = s"spark_25132_${format}" - withTable(tableName) { - val end = 5 - val data = spark.range(end).selectExpr("id as A", "id * 2 as b", "id * 3 as B") - withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { - data.write.format(format).mode("overwrite").save(tableDir) - } - sql(s"CREATE TABLE $tableName (a LONG, b LONG) USING $format LOCATION '$tableDir'") - - withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { - checkAnswer(sql(s"select a from $tableName"), data.select("A")) - checkAnswer(sql(s"select A from $tableName"), data.select("A")) - - // RuntimeException is triggered at executor side, which is then wrapped as - // SparkException at driver side - val e1 = intercept[SparkException] { - sql(s"select b from $tableName").collect() + Seq("parquet", "orc").foreach { format => + test(s"Spark native readers should respect spark.sql.caseSensitive - ${format}") { + withTempDir { dir => + val tableName = s"spark_25132_${format}_native" + val tableDir = dir.getCanonicalPath + s"/$tableName" + withTable(tableName) { + val end = 5 + val data = spark.range(end).selectExpr("id as A", "id * 2 as b", "id * 3 as B") + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + data.write.format(format).mode("overwrite").save(tableDir) } - assert( - e1.getCause.isInstanceOf[RuntimeException] && - e1.getCause.getMessage.contains( - """Found duplicate field(s) "b": [b, B] in case-insensitive mode""")) - val e2 = intercept[SparkException] { - sql(s"select B from $tableName").collect() + sql(s"CREATE TABLE $tableName (a LONG, b LONG) USING $format LOCATION '$tableDir'") + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkAnswer(sql(s"select a from $tableName"), data.select("A")) + checkAnswer(sql(s"select A from $tableName"), data.select("A")) + + // RuntimeException is triggered at executor side, which is then wrapped as + // SparkException at driver side + val e1 = intercept[SparkException] { + sql(s"select b from $tableName").collect() + } + assert( + e1.getCause.isInstanceOf[RuntimeException] && + e1.getCause.getMessage.contains( + """Found duplicate field(s) "b": [b, B] in case-insensitive mode""")) + val e2 = intercept[SparkException] { + sql(s"select B from $tableName").collect() + } + assert( + e2.getCause.isInstanceOf[RuntimeException] && + e2.getCause.getMessage.contains( + """Found duplicate field(s) "b": [b, B] in case-insensitive mode""")) } - assert( - e2.getCause.isInstanceOf[RuntimeException] && - e2.getCause.getMessage.contains( - """Found duplicate field(s) "b": [b, B] in case-insensitive mode""")) - } - withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { - checkAnswer(sql(s"select a from $tableName"), (0 until end).map(_ => Row(null))) - checkAnswer(sql(s"select b from $tableName"), data.select("b")) + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + checkAnswer(sql(s"select a from $tableName"), (0 until end).map(_ => Row(null))) + checkAnswer(sql(s"select b from $tableName"), data.select("b")) + } } } } From f8b4d5aafd1923d9524415601469f8749b3d0811 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 10 Sep 2018 13:47:19 +0800 Subject: [PATCH 1583/2461] [SPARK-25313][SQL][FOLLOW-UP] Fix InsertIntoHiveDirCommand output schema in Parquet issue ## What changes were proposed in this pull request? How to reproduce: ```scala spark.sql("CREATE TABLE tbl(id long)") spark.sql("INSERT OVERWRITE TABLE tbl VALUES 4") spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") spark.sql(s"INSERT OVERWRITE LOCAL DIRECTORY '/tmp/spark/parquet' " + "STORED AS PARQUET SELECT ID FROM view1") spark.read.parquet("/tmp/spark/parquet").schema scala> spark.read.parquet("/tmp/spark/parquet").schema res10: org.apache.spark.sql.types.StructType = StructType(StructField(id,LongType,true)) ``` The schema should be `StructType(StructField(ID,LongType,true))` as we `SELECT ID FROM view1`. This pr fix this issue. ## How was this patch tested? unit tests Closes #22359 from wangyum/SPARK-25313-FOLLOW-UP. Authored-by: Yuming Wang Signed-off-by: Wenchen Fan --- .../command/DataWritingCommand.scala | 15 --------------- .../CreateHiveTableAsSelectCommand.scala | 4 ++-- .../execution/InsertIntoHiveDirCommand.scala | 5 ++--- .../hive/execution/InsertIntoHiveTable.scala | 1 - .../sql/hive/execution/SaveAsHiveFile.scala | 3 +-- .../sql/hive/execution/HiveDDLSuite.scala | 19 +++++++++++++++++++ 6 files changed, 24 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala index 0a185b8472060..a1bb5af1ab723 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker import org.apache.spark.sql.execution.datasources.FileFormatWriter import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration /** @@ -75,18 +74,4 @@ object DataWritingCommand { attr.withName(outputName) } } - - /** - * Returns schema of logical plan with provided names. - * The length of provided names should be the same of the length of [[LogicalPlan.schema]]. - */ - def logicalPlanSchemaWithNames( - query: LogicalPlan, - names: Seq[String]): StructType = { - assert(query.schema.length == names.length, - "The length of provided names doesn't match the length of query schema.") - StructType(query.schema.zip(names).map { case (structField, outputName) => - structField.copy(name = outputName) - }) - } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 0eb2f0de0acd9..aa573b54a2b62 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -69,8 +69,8 @@ case class CreateHiveTableAsSelectCommand( // add the relation into catalog, just in case of failure occurs while data // processing. assert(tableDesc.schema.isEmpty) - val schema = DataWritingCommand.logicalPlanSchemaWithNames(query, outputColumnNames) - catalog.createTable(tableDesc.copy(schema = schema), ignoreIfExists = false) + catalog.createTable( + tableDesc.copy(schema = outputColumns.toStructType), ignoreIfExists = false) try { // Read back the metadata of the table which was created just now. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala index 0a73aaa94bc75..a24e902074c2d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala @@ -66,7 +66,7 @@ case class InsertIntoHiveDirCommand( identifier = TableIdentifier(storage.locationUri.get.toString, Some("default")), tableType = org.apache.spark.sql.catalyst.catalog.CatalogTableType.VIEW, storage = storage, - schema = query.schema + schema = outputColumns.toStructType )) hiveTable.getMetadata.put(serdeConstants.SERIALIZATION_LIB, storage.serde.getOrElse(classOf[LazySimpleSerDe].getName)) @@ -104,8 +104,7 @@ case class InsertIntoHiveDirCommand( plan = child, hadoopConf = hadoopConf, fileSinkConf = fileSinkConf, - outputLocation = tmpPath.toString, - allColumns = outputColumns) + outputLocation = tmpPath.toString) val fs = writeToPath.getFileSystem(hadoopConf) if (overwrite && fs.exists(writeToPath)) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 75a0563e72c91..0ed464dad91b1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -198,7 +198,6 @@ case class InsertIntoHiveTable( hadoopConf = hadoopConf, fileSinkConf = fileSinkConf, outputLocation = tmpLocation.toString, - allColumns = outputColumns, partitionAttributes = partitionAttributes) if (partition.nonEmpty) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index e0f7375387d24..078968ed0145f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -51,7 +51,6 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { hadoopConf: Configuration, fileSinkConf: FileSinkDesc, outputLocation: String, - allColumns: Seq[Attribute], customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty, partitionAttributes: Seq[Attribute] = Nil): Set[String] = { @@ -90,7 +89,7 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { fileFormat = new HiveFileFormat(fileSinkConf), committer = committer, outputSpec = - FileFormatWriter.OutputSpec(outputLocation, customPartitionLocations, allColumns), + FileFormatWriter.OutputSpec(outputLocation, customPartitionLocations, outputColumns), hadoopConf = hadoopConf, partitionColumns = partitionAttributes, bucketSpec = None, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 9acd5e1c248eb..69ee2bbf06651 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -803,6 +803,25 @@ class HiveDDLSuite } } + test("SPARK-25313 Insert overwrite directory should output correct schema") { + withSQLConf(CONVERT_METASTORE_PARQUET.key -> "false") { + withTable("tbl") { + withView("view1") { + spark.sql("CREATE TABLE tbl(id long)") + spark.sql("INSERT OVERWRITE TABLE tbl VALUES 4") + spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") + withTempPath { path => + spark.sql(s"INSERT OVERWRITE LOCAL DIRECTORY '${path.getCanonicalPath}' " + + "STORED AS PARQUET SELECT ID FROM view1") + val expectedSchema = StructType(Seq(StructField("ID", LongType, true))) + assert(spark.read.parquet(path.toString).schema == expectedSchema) + checkAnswer(spark.read.parquet(path.toString), Seq(Row(4))) + } + } + } + } + } + test("alter table partition - storage information") { sql("CREATE TABLE boxes (height INT, length INT) PARTITIONED BY (width INT)") sql("INSERT OVERWRITE TABLE boxes PARTITION (width=4) SELECT 4, 4") From e7853dc103bf3fd541aa2b498f5f3a223067f812 Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Mon, 10 Sep 2018 15:11:14 +0800 Subject: [PATCH 1584/2461] [SPARK-24999][SQL] Reduce unnecessary 'new' memory operations ## What changes were proposed in this pull request? This PR is to solve the CodeGen code generated by fast hash, and there is no need to apply for a block of memory for every new entry, because unsafeRow's memory can be reused. ## How was this patch tested? the existed test cases. Closes #21968 from heary-cao/updateNewMemory. Authored-by: caoxuewen Signed-off-by: Wenchen Fan --- .../aggregate/RowBasedHashMapGenerator.scala | 30 +++++++++++-------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index 3d2443ca959a4..56cf78d8b7fc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -48,6 +48,12 @@ class RowBasedHashMapGenerator( val keySchema = ctx.addReferenceObj("keySchemaTerm", groupingKeySchema) val valueSchema = ctx.addReferenceObj("valueSchemaTerm", bufferSchema) + val numVarLenFields = groupingKeys.map(_.dataType).count { + case dt if UnsafeRow.isFixedLength(dt) => false + // TODO: consider large decimal and interval type + case _ => true + } + s""" | private org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch; | private int[] buckets; @@ -60,6 +66,7 @@ class RowBasedHashMapGenerator( | private long emptyVOff; | private int emptyVLen; | private boolean isBatchFull = false; + | private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter; | | | public $generatedClassName( @@ -75,6 +82,9 @@ class RowBasedHashMapGenerator( | emptyVOff = Platform.BYTE_ARRAY_OFFSET; | emptyVLen = emptyBuffer.length; | + | agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter( + | ${groupingKeySchema.length}, ${numVarLenFields * 32}); + | | buckets = new int[numBuckets]; | java.util.Arrays.fill(buckets, -1); | } @@ -112,12 +122,6 @@ class RowBasedHashMapGenerator( * */ protected def generateFindOrInsert(): String = { - val numVarLenFields = groupingKeys.map(_.dataType).count { - case dt if UnsafeRow.isFixedLength(dt) => false - // TODO: consider large decimal and interval type - case _ => true - } - val createUnsafeRowForKey = groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => key.dataType match { case t: DecimalType => @@ -130,6 +134,12 @@ class RowBasedHashMapGenerator( } }.mkString(";\n") + val resetNullBits = if (groupingKeySchema.map(_.nullable).forall(_ == false)) { + "" + } else { + "agg_rowWriter.zeroOutNullBytes();" + } + s""" |public org.apache.spark.sql.catalyst.expressions.UnsafeRow findOrInsert(${ groupingKeySignature}) { @@ -140,12 +150,8 @@ class RowBasedHashMapGenerator( | // Return bucket index if it's either an empty slot or already contains the key | if (buckets[idx] == -1) { | if (numRows < capacity && !isBatchFull) { - | // creating the unsafe for new entry - | org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter - | = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter( - | ${groupingKeySchema.length}, ${numVarLenFields * 32}); - | agg_rowWriter.reset(); //TODO: investigate if reset or zeroout are actually needed - | agg_rowWriter.zeroOutNullBytes(); + | agg_rowWriter.reset(); + | $resetNullBits | ${createUnsafeRowForKey}; | org.apache.spark.sql.catalyst.expressions.UnsafeRow agg_result | = agg_rowWriter.getRow(); From 6f6517837ba9934a280b11aba9d9be58bc131f25 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 10 Sep 2018 19:18:00 +0800 Subject: [PATCH 1585/2461] [SPARK-24849][SPARK-24911][SQL][FOLLOW-UP] Converting a value of StructType to a DDL string ## What changes were proposed in this pull request? Add the version number for the new APIs. ## How was this patch tested? N/A Closes #22377 from gatorsmile/followup24849. Authored-by: gatorsmile Signed-off-by: Wenchen Fan --- .../main/scala/org/apache/spark/sql/types/StructField.scala | 2 ++ .../main/scala/org/apache/spark/sql/types/StructType.scala | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala index 902cae9150ede..35f9970a0aaec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala @@ -79,6 +79,8 @@ case class StructField( /** * Returns a string containing a schema in DDL format. For example, the following value: * `StructField("eventId", IntegerType)` will be converted to `eventId` INT. + * + * @since 2.4.0 */ def toDDL: String = { val comment = getComment() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index c5ca169c955dc..06289b1483203 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -365,6 +365,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru * `StructType(Seq(StructField("eventId", IntegerType), StructField("s", StringType)))` * will be converted to `eventId` INT, `s` STRING. * The returned DDL schema can be used in a table creation. + * + * @since 2.4.0 */ def toDDL: String = fields.map(_.toDDL).mkString(",") @@ -441,6 +443,8 @@ object StructType extends AbstractDataType { /** * Creates StructType for a given DDL-formatted string, which is a comma separated list of field * definitions, e.g., a INT, b STRING. + * + * @since 2.2.0 */ def fromDDL(ddl: String): StructType = CatalystSqlParser.parseTableSchema(ddl) From 12e3e9f17dca11a2cddf0fb99d72b4b97517fb56 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 10 Sep 2018 19:41:51 +0800 Subject: [PATCH 1586/2461] [SPARK-25278][SQL] Avoid duplicated Exec nodes when the same logical plan appears in the query ## What changes were proposed in this pull request? In the Planner, we collect the placeholder which need to be substituted in the query execution plan and once we plan them, we substitute the placeholder with the effective plan. In this second phase, we rely on the `==` comparison, ie. the `equals` method. This means that if two placeholder plans - which are different instances - have the same attributes (so that they are equal, according to the equal method) they are both substituted with their corresponding new physical plans. So, in such a situation, the first time we substitute both them with the first of the 2 new generated plan and the second time we substitute nothing. This is usually of no harm for the execution of the query itself, as the 2 plans are identical. But since they are the same instance, now, the local variables are shared (which is unexpected). This causes issues for the metrics collected, as the same node is executed 2 times, so the metrics are accumulated 2 times, wrongly. The PR proposes to use the `eq` method in checking which placeholder needs to be substituted,; thus in the previous situation, actually both the two different physical nodes which are created (one for each time the logical plan appears in the query plan) are used and the metrics are collected properly for each of them. ## How was this patch tested? added UT Closes #22284 from mgaido91/SPARK-25278. Authored-by: Marco Gaido Signed-off-by: Wenchen Fan --- .../sql/catalyst/planning/QueryPlanner.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 17 +++++++++++++++++ .../sql/execution/metric/SQLMetricsSuite.scala | 13 +++++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index bc41dd0465e34..6fa5203a06f7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -81,7 +81,7 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] { childPlans.map { childPlan => // Replace the placeholder by the child plan candidateWithPlaceholders.transformUp { - case p if p == placeholder => childPlan + case p if p.eq(placeholder) => childPlan } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 3db89ecfad9fc..b10da6c70be16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -704,6 +704,23 @@ class PlannerSuite extends SharedSQLContext { df.queryExecution.executedPlan.execute() } + test("SPARK-25278: physical nodes should be different instances for same logical nodes") { + val range = Range(1, 1, 1, 1) + val df = Union(range, range) + val ranges = df.queryExecution.optimizedPlan.collect { + case r: Range => r + } + assert(ranges.length == 2) + // Ensure the two Range instances are equal according to their equal method + assert(ranges.head == ranges.last) + val execRanges = df.queryExecution.sparkPlan.collect { + case r: RangeExec => r + } + assert(execRanges.length == 2) + // Ensure the two RangeExec instances are different instances + assert(!execRanges.head.eq(execRanges.last)) + } + test("SPARK-24556: always rewrite output partitioning in ReusedExchangeExec " + "and InMemoryTableScanExec") { def checkOutputPartitioningRewrite( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index a3a3f3851e21c..d45eb0c27a6b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -497,6 +497,19 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared } } + test("SPARK-25278: output metrics are wrong for plans repeated in the query") { + val name = "demo_view" + withView(name) { + sql(s"CREATE OR REPLACE VIEW $name AS VALUES 1,2") + val view = spark.table(name) + val union = view.union(view) + testSparkPlanMetrics(union, 1, Map( + 0L -> ("Union" -> Map()), + 1L -> ("LocalTableScan" -> Map("number of output rows" -> 2L)), + 2L -> ("LocalTableScan" -> Map("number of output rows" -> 2L)))) + } + } + test("writing data out metrics: parquet") { testMetricsNonDynamicPartition("parquet", "t1") } From da5685b5bb9ee7daaeb4e8f99c488ebd50c7aac3 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 10 Sep 2018 11:01:51 -0700 Subject: [PATCH 1587/2461] [SPARK-23672][PYTHON] Document support for nested return types in scalar with arrow udfs ## What changes were proposed in this pull request? Clarify docstring for Scalar functions ## How was this patch tested? Adds a unit test showing use similar to wordcount, there's existing unit test for array of floats as well. Closes #20908 from holdenk/SPARK-23672-document-support-for-nested-return-types-in-scalar-with-arrow-udfs. Authored-by: Holden Karau Signed-off-by: Bryan Cutler --- python/pyspark/sql/functions.py | 3 ++- python/pyspark/sql/tests.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9396b16b7ada8..81f35f54aa54d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2720,9 +2720,10 @@ def pandas_udf(f=None, returnType=None, functionType=None): 1. SCALAR A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`. - The returnType should be a primitive data type, e.g., :class:`DoubleType`. The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. + :class:`MapType`, :class:`StructType` are currently not supported as output types. + Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and :meth:`pyspark.sql.DataFrame.select`. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6d9d636b23a3a..8e5bc6729dfa4 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4443,6 +4443,7 @@ def test_timestamp_dst(self): not _have_pandas or not _have_pyarrow, _pandas_requirement_message or _pyarrow_requirement_message) class PandasUDFTests(ReusedSQLTestCase): + def test_pandas_udf_basic(self): from pyspark.rdd import PythonEvalType from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4658,6 +4659,24 @@ def random_udf(v): random_udf = random_udf.asNondeterministic() return random_udf + def test_pandas_udf_tokenize(self): + from pyspark.sql.functions import pandas_udf + tokenize = pandas_udf(lambda s: s.apply(lambda str: str.split(' ')), + ArrayType(StringType())) + self.assertEqual(tokenize.returnType, ArrayType(StringType())) + df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"]) + result = df.select(tokenize("vals").alias("hi")) + self.assertEqual([Row(hi=[u'hi', u'boo']), Row(hi=[u'bye', u'boo'])], result.collect()) + + def test_pandas_udf_nested_arrays(self): + from pyspark.sql.functions import pandas_udf + tokenize = pandas_udf(lambda s: s.apply(lambda str: [str.split(' ')]), + ArrayType(ArrayType(StringType()))) + self.assertEqual(tokenize.returnType, ArrayType(ArrayType(StringType()))) + df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"]) + result = df.select(tokenize("vals").alias("hi")) + self.assertEqual([Row(hi=[[u'hi', u'boo']]), Row(hi=[[u'bye', u'boo']])], result.collect()) + def test_vectorized_udf_basic(self): from pyspark.sql.functions import pandas_udf, col, array df = self.spark.range(10).select( From 0736e72a66735664b191fc363f54e3c522697dba Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 11 Sep 2018 14:16:56 +0800 Subject: [PATCH 1588/2461] [SPARK-25371][SQL] struct() should allow being called with 0 args ## What changes were proposed in this pull request? SPARK-21281 introduced a check for the inputs of `CreateStructLike` to be non-empty. This means that `struct()`, which was previously considered valid, now throws an Exception. This behavior change was introduced in 2.3.0. The change may break users' application on upgrade and it causes `VectorAssembler` to fail when an empty `inputCols` is defined. The PR removes the added check making `struct()` valid again. ## How was this patch tested? added UT Closes #22373 from mgaido91/SPARK-25371. Authored-by: Marco Gaido Signed-off-by: Wenchen Fan --- .../org/apache/spark/ml/feature/VectorAssemblerSuite.scala | 5 +++++ .../spark/sql/catalyst/expressions/complexTypeCreator.scala | 5 +---- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 2 -- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index ed15a1d88a269..a4d388fd321db 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -256,4 +256,9 @@ class VectorAssemblerSuite assert(runWithMetadata("keep", additional_filter = "id1 > 2").count() == 4) } + test("SPARK-25371: VectorAssembler with empty inputCols") { + val vectorAssembler = new VectorAssembler().setInputCols(Array()).setOutputCol("a") + val output = vectorAssembler.transform(dfWithNullsAndNaNs) + assert(output.select("a").limit(1).collect().head == Row(Vectors.sparse(0, Seq.empty))) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 077a6dc93bd17..aba9c6c8ad6fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -379,10 +379,7 @@ trait CreateNamedStructLike extends Expression { } override def checkInputDataTypes(): TypeCheckResult = { - if (children.length < 1) { - TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName requires at least one argument") - } else if (children.size % 2 != 0) { + if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") } else { val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 4b83e51fa8992..121db442c77f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2677,8 +2677,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { val funcsMustHaveAtLeastOneArg = ("coalesce", (df: DataFrame) => df.select(coalesce())) :: ("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) :: - ("named_struct", (df: DataFrame) => df.select(struct())) :: - ("named_struct", (df: DataFrame) => df.selectExpr("named_struct()")) :: ("hash", (df: DataFrame) => df.select(hash())) :: ("hash", (df: DataFrame) => df.selectExpr("hash()")) :: Nil funcsMustHaveAtLeastOneArg.foreach { case (name, func) => From 0e680dcf1e20c5632b9451adce4079bf57107dbc Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 11 Sep 2018 19:38:45 +0800 Subject: [PATCH 1589/2461] [SPARK-25278][SQL][FOLLOWUP] remove the hack in ProgressReporter ## What changes were proposed in this pull request? It turns out it's a bug that a `DataSourceV2ScanExec` instance may be referred to in the execution plan multiple times. This bug is fixed by https://github.com/apache/spark/pull/22284 and now we have corrected SQL metrics for batch queries. Thus we don't need the hack in `ProgressReporter` anymore, which fixes the same metrics problem for streaming queries. ## How was this patch tested? existing tests Closes #22380 from cloud-fan/followup. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../streaming/ProgressReporter.scala | 36 +++---------------- 1 file changed, 4 insertions(+), 32 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index d4b50655c7215..73b180468d367 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -255,40 +255,12 @@ trait ProgressReporter extends Logging { } if (onlyDataSourceV2Sources) { - // DataSourceV2ScanExec is the execution plan leaf that is responsible for reading data - // from a V2 source and has a direct reference to the V2 source that generated it. Each - // DataSourceV2ScanExec records the number of rows it has read using SQLMetrics. However, - // just collecting all DataSourceV2ScanExec nodes and getting the metric is not correct as - // a DataSourceV2ScanExec instance may be referred to in the execution plan from two (or - // even multiple times) points and considering it twice will lead to double counting. We - // can't dedup them using their hashcode either because two different instances of - // DataSourceV2ScanExec can have the same hashcode but account for separate sets of - // records read, and deduping them to consider only one of them would be undercounting the - // records read. Therefore the right way to do this is to consider the unique instances of - // DataSourceV2ScanExec (using their identity hash codes) and get metrics from them. - // Hence we calculate in the following way. - // - // 1. Collect all the unique DataSourceV2ScanExec instances using IdentityHashMap. - // - // 2. Extract the source and the number of rows read from the DataSourceV2ScanExec instanes. - // - // 3. Multiple DataSourceV2ScanExec instance may refer to the same source (can happen with - // self-unions or self-joins). Add up the number of rows for each unique source. - val uniqueStreamingExecLeavesMap = - new IdentityHashMap[DataSourceV2ScanExec, DataSourceV2ScanExec]() - - lastExecution.executedPlan.collectLeaves().foreach { + val sourceToInputRowsTuples = lastExecution.executedPlan.collect { case s: DataSourceV2ScanExec if s.readSupport.isInstanceOf[BaseStreamingSource] => - uniqueStreamingExecLeavesMap.put(s, s) - case _ => - } - - val sourceToInputRowsTuples = - uniqueStreamingExecLeavesMap.values.asScala.map { execLeaf => - val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) - val source = execLeaf.readSupport.asInstanceOf[BaseStreamingSource] + val numRows = s.metrics.get("numOutputRows").map(_.value).getOrElse(0L) + val source = s.readSupport.asInstanceOf[BaseStreamingSource] source -> numRows - }.toSeq + } logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t")) sumRows(sourceToInputRowsTuples) } else { From c9cb393dc414ae98093c1541d09fa3c8663ce276 Mon Sep 17 00:00:00 2001 From: Mario Molina Date: Tue, 11 Sep 2018 20:47:14 +0800 Subject: [PATCH 1590/2461] [SPARK-17916][SPARK-25241][SQL][FOLLOW-UP] Fix empty string being parsed as null when nullValue is set. ## What changes were proposed in this pull request? In the PR, I propose new CSV option `emptyValue` and an update in the SQL Migration Guide which describes how to revert previous behavior when empty strings were not written at all. Since Spark 2.4, empty strings are saved as `""` to distinguish them from saved `null`s. Closes #22234 Closes #22367 ## How was this patch tested? It was tested by `CSVSuite` and new tests added in the PR #22234 Closes #22389 from MaxGekk/csv-empty-value-master. Lead-authored-by: Mario Molina Co-authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- docs/sql-programming-guide.md | 1 + python/pyspark/sql/readwriter.py | 12 +- python/pyspark/sql/streaming.py | 7 +- .../apache/spark/sql/DataFrameReader.scala | 1 + .../apache/spark/sql/DataFrameWriter.scala | 1 + .../datasources/csv/CSVOptions.scala | 19 ++- .../sql/streaming/DataStreamReader.scala | 1 + .../resources/test-data/cars-empty-value.csv | 4 + .../execution/datasources/csv/CSVSuite.scala | 111 ++++++++++++++++++ 9 files changed, 149 insertions(+), 8 deletions(-) create mode 100644 sql/core/src/test/resources/test-data/cars-empty-value.csv diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3749094569271..9da7d64322eb6 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1897,6 +1897,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - In version 2.3 and earlier, CSV rows are considered as malformed if at least one column value in the row is malformed. CSV parser dropped such rows in the DROPMALFORMED mode or outputs an error in the FAILFAST mode. Since Spark 2.4, CSV row is considered as malformed only when it contains malformed column values requested from CSV datasource, other values can be ignored. As an example, CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selection of the id column consists of a row with one column value 1234 but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. - Since Spark 2.4, File listing for compute statistics is done in parallel by default. This can be disabled by setting `spark.sql.parallelFileListingInStatsComputation.enabled` to `False`. - Since Spark 2.4, Metadata files (e.g. Parquet summary files) and temporary files are not counted as data files when calculating table size during Statistics computation. + - Since Spark 2.4, empty strings are saved as quoted empty strings `""`. In version 2.3 and earlier, empty strings are equal to `null` values and do not reflect to any characters in saved CSV files. For example, the row of `"a", null, "", 1` was writted as `a,,,1`. Since Spark 2.4, the same row is saved as `a,,"",1`. To restore the previous behavior, set the CSV option `emptyValue` to empty (not quoted) string. ## Upgrading From Spark SQL 2.3.0 to 2.3.1 and above diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 49f4e6b2ede1b..3ca5d548ae7d6 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -349,7 +349,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, - samplingRatio=None, enforceSchema=None): + samplingRatio=None, enforceSchema=None, emptyValue=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -444,6 +444,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non different, ``\0`` otherwise. :param samplingRatio: defines fraction of rows used for schema inferring. If None is set, it uses the default value, ``1.0``. + :param emptyValue: sets the string representation of an empty value. If None is set, it uses + the default value, empty string. >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes @@ -463,7 +465,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio, - enforceSchema=enforceSchema) + enforceSchema=enforceSchema, emptyValue=emptyValue) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -859,7 +861,7 @@ def text(self, path, compression=None, lineSep=None): def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None, header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None, timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, - charToEscapeQuoteEscaping=None, encoding=None): + charToEscapeQuoteEscaping=None, encoding=None, emptyValue=None): """Saves the content of the :class:`DataFrame` in CSV format at the specified path. :param path: the path in any Hadoop supported file system @@ -911,6 +913,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No different, ``\0`` otherwise.. :param encoding: sets the encoding (charset) of saved csv files. If None is set, the default UTF-8 charset will be used. + :param emptyValue: sets the string representation of an empty value. If None is set, it uses + the default value, ``""``. >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) """ @@ -921,7 +925,7 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace, ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, - encoding=encoding) + encoding=encoding, emptyValue=emptyValue) self._jwrite.csv(path) @since(1.5) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index ee13778a7dcd6..522900bf6684c 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -564,7 +564,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, - enforceSchema=None): + enforceSchema=None, emptyValue=None): """Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -658,6 +658,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non the quote character. If None is set, the default value is escape character when escape and quote characters are different, ``\0`` otherwise.. + :param emptyValue: sets the string representation of an empty value. If None is set, it uses + the default value, empty string. >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema) >>> csv_sdf.isStreaming @@ -674,7 +676,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, - charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema) + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema, + emptyValue=emptyValue) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 0cfcc45fb3d31..e6c2cba79841a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -571,6 +571,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * whitespaces from values being read should be skipped. *

    • `nullValue` (default empty string): sets the string representation of a null value. Since * 2.0.1, this applies to all supported types including the string type.
    • + *
    • `emptyValue` (default empty string): sets the string representation of an empty value.
    • *
    • `nanValue` (default `NaN`): sets the string representation of a non-number" value.
    • *
    • `positiveInf` (default `Inf`): sets the string representation of a positive infinity * value.
    • diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index eca2d5b971905..dfb8c4718550f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -635,6 +635,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * enclosed in quotes. Default is to only escape values containing a quote character. *
    • `header` (default `false`): writes the names of columns as the first line.
    • *
    • `nullValue` (default empty string): sets the string representation of a null value.
    • + *
    • `emptyValue` (default `""`): sets the string representation of an empty value.
    • *
    • `encoding` (by default it is not set): specifies encoding (charset) of saved csv * files. If it is not set, the UTF-8 charset will be used.
    • *
    • `compression` (default `null`): compression codec to use when saving to file. This can be diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index fab8d62da0c1d..492a21be6df3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -162,6 +162,21 @@ class CSVOptions( */ val enforceSchema = getBool("enforceSchema", default = true) + + /** + * String representation of an empty value in read and in write. + */ + val emptyValue = parameters.get("emptyValue") + /** + * The string is returned when CSV reader doesn't have any characters for input value, + * or an empty quoted string `""`. Default value is empty string. + */ + val emptyValueInRead = emptyValue.getOrElse("") + /** + * The value is used instead of an empty string in write. Default value is `""` + */ + val emptyValueInWrite = emptyValue.getOrElse("\"\"") + def asWriterSettings: CsvWriterSettings = { val writerSettings = new CsvWriterSettings() val format = writerSettings.getFormat @@ -173,7 +188,7 @@ class CSVOptions( writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite) writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite) writerSettings.setNullValue(nullValue) - writerSettings.setEmptyValue("\"\"") + writerSettings.setEmptyValue(emptyValueInWrite) writerSettings.setSkipEmptyLines(true) writerSettings.setQuoteAllFields(quoteAll) writerSettings.setQuoteEscapingEnabled(escapeQuotes) @@ -194,7 +209,7 @@ class CSVOptions( settings.setInputBufferSize(inputBufferSize) settings.setMaxColumns(maxColumns) settings.setNullValue(nullValue) - settings.setEmptyValue("") + settings.setEmptyValue(emptyValueInRead) settings.setMaxCharsPerColumn(maxCharsPerColumn) settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER) settings diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 39e9e1ad426be..2a4db4afbe005 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -327,6 +327,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * whitespaces from values being read should be skipped.
    • *
    • `nullValue` (default empty string): sets the string representation of a null value. Since * 2.0.1, this applies to all supported types including the string type.
    • + *
    • `emptyValue` (default empty string): sets the string representation of an empty value.
    • *
    • `nanValue` (default `NaN`): sets the string representation of a non-number" value.
    • *
    • `positiveInf` (default `Inf`): sets the string representation of a positive infinity * value.
    • diff --git a/sql/core/src/test/resources/test-data/cars-empty-value.csv b/sql/core/src/test/resources/test-data/cars-empty-value.csv new file mode 100644 index 0000000000000..0f20a2f23ac06 --- /dev/null +++ b/sql/core/src/test/resources/test-data/cars-empty-value.csv @@ -0,0 +1,4 @@ +year,make,model,comment,blank +"2012","Tesla","S","","" +1997,Ford,E350,"Go get one now they are going fast", +2015,Chevy,Volt,,"" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 5a1d6679ebbdb..2b39a0b1f52ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -50,6 +50,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te private val carsAltFile = "test-data/cars-alternative.csv" private val carsUnbalancedQuotesFile = "test-data/cars-unbalanced-quotes.csv" private val carsNullFile = "test-data/cars-null.csv" + private val carsEmptyValueFile = "test-data/cars-empty-value.csv" private val carsBlankColName = "test-data/cars-blank-column-name.csv" private val emptyFile = "test-data/empty.csv" private val commentsFile = "test-data/comments.csv" @@ -668,6 +669,70 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te assert(results(2).toSeq === Array(null, "Chevy", "Volt", null, null)) } + test("empty fields with user defined empty values") { + + // year,make,model,comment,blank + val dataSchema = StructType(List( + StructField("year", IntegerType, nullable = true), + StructField("make", StringType, nullable = false), + StructField("model", StringType, nullable = false), + StructField("comment", StringType, nullable = true), + StructField("blank", StringType, nullable = true))) + val cars = spark.read + .format("csv") + .schema(dataSchema) + .option("header", "true") + .option("emptyValue", "empty") + .load(testFile(carsEmptyValueFile)) + + verifyCars(cars, withHeader = true, checkValues = false) + val results = cars.collect() + assert(results(0).toSeq === Array(2012, "Tesla", "S", "empty", "empty")) + assert(results(1).toSeq === + Array(1997, "Ford", "E350", "Go get one now they are going fast", null)) + assert(results(2).toSeq === Array(2015, "Chevy", "Volt", null, "empty")) + } + + test("save csv with empty fields with user defined empty values") { + withTempDir { dir => + val csvDir = new File(dir, "csv").getCanonicalPath + + // year,make,model,comment,blank + val dataSchema = StructType(List( + StructField("year", IntegerType, nullable = true), + StructField("make", StringType, nullable = false), + StructField("model", StringType, nullable = false), + StructField("comment", StringType, nullable = true), + StructField("blank", StringType, nullable = true))) + val cars = spark.read + .format("csv") + .schema(dataSchema) + .option("header", "true") + .option("nullValue", "NULL") + .load(testFile(carsEmptyValueFile)) + + cars.coalesce(1).write + .format("csv") + .option("header", "true") + .option("emptyValue", "empty") + .option("nullValue", null) + .save(csvDir) + + val carsCopy = spark.read + .format("csv") + .schema(dataSchema) + .option("header", "true") + .load(csvDir) + + verifyCars(carsCopy, withHeader = true, checkValues = false) + val results = carsCopy.collect() + assert(results(0).toSeq === Array(2012, "Tesla", "S", "empty", "empty")) + assert(results(1).toSeq === + Array(1997, "Ford", "E350", "Go get one now they are going fast", null)) + assert(results(2).toSeq === Array(2015, "Chevy", "Volt", null, "empty")) + } + } + test("save csv with compression codec option") { withTempDir { dir => val csvDir = new File(dir, "csv").getCanonicalPath @@ -1375,6 +1440,52 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } } + test("SPARK-25241: An empty string should not be coerced to null when emptyValue is passed.") { + val litNull: String = null + val df = Seq( + (1, "John Doe"), + (2, ""), + (3, "-"), + (4, litNull) + ).toDF("id", "name") + + // Checks for new behavior where a null is not coerced to an empty string when `emptyValue` is + // set to anything but an empty string literal. + withTempPath { path => + df.write + .option("emptyValue", "-") + .csv(path.getAbsolutePath) + val computed = spark.read + .option("emptyValue", "-") + .schema(df.schema) + .csv(path.getAbsolutePath) + val expected = Seq( + (1, "John Doe"), + (2, "-"), + (3, "-"), + (4, "-") + ).toDF("id", "name") + + checkAnswer(computed, expected) + } + // Keeps the old behavior where empty string us coerced to emptyValue is not passed. + withTempPath { path => + df.write + .csv(path.getAbsolutePath) + val computed = spark.read + .schema(df.schema) + .csv(path.getAbsolutePath) + val expected = Seq( + (1, "John Doe"), + (2, litNull), + (3, "-"), + (4, litNull) + ).toDF("id", "name") + + checkAnswer(computed, expected) + } + } + test("SPARK-24329: skip lines with comments, and one or multiple whitespaces") { val schema = new StructType().add("colA", StringType) val ds = spark From 77579aa8c35b0d98bbeac3c828bf68a1d190d13e Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 11 Sep 2018 08:57:42 -0700 Subject: [PATCH 1591/2461] [SPARK-25389][SQL] INSERT OVERWRITE DIRECTORY STORED AS should prevent duplicate fields ## What changes were proposed in this pull request? Like `INSERT OVERWRITE DIRECTORY USING` syntax, `INSERT OVERWRITE DIRECTORY STORED AS` should not generate files with duplicate fields because Spark cannot read those files back. **INSERT OVERWRITE DIRECTORY USING** ```scala scala> sql("INSERT OVERWRITE DIRECTORY 'file:///tmp/parquet' USING parquet SELECT 'id', 'id2' id") ... ERROR InsertIntoDataSourceDirCommand: Failed to write to directory ... org.apache.spark.sql.AnalysisException: Found duplicate column(s) when inserting into file:/tmp/parquet: `id`; ``` **INSERT OVERWRITE DIRECTORY STORED AS** ```scala scala> sql("INSERT OVERWRITE DIRECTORY 'file:///tmp/parquet' STORED AS parquet SELECT 'id', 'id2' id") // It generates corrupted files scala> spark.read.parquet("/tmp/parquet").show 18/09/09 22:09:57 WARN DataSource: Found duplicate column(s) in the data schema and the partition schema: `id`; ``` ## How was this patch tested? Pass the Jenkins with newly added test cases. Closes #22378 from dongjoon-hyun/SPARK-25389. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../execution/InsertIntoHiveDirCommand.scala | 5 ++++ .../apache/spark/sql/hive/InsertSuite.scala | 24 +++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala index a24e902074c2d..0c694910b06d4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.hive.client.HiveClientImpl +import org.apache.spark.sql.util.SchemaUtils /** * Command for writing the results of `query` to file system. @@ -61,6 +62,10 @@ case class InsertIntoHiveDirCommand( override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { assert(storage.locationUri.nonEmpty) + SchemaUtils.checkColumnNameDuplication( + outputColumnNames, + s"when inserting into ${storage.locationUri.get}", + sparkSession.sessionState.conf.caseSensitiveAnalysis) val hiveTable = HiveClientImpl.toHiveTable(CatalogTable( identifier = TableIdentifier(storage.locationUri.get.toString, Some("default")), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala index ab91727049ff5..5879748d05b2b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -750,4 +751,27 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter } } } + + Seq("LOCAL", "").foreach { local => + Seq(true, false).foreach { caseSensitivity => + Seq("orc", "parquet").foreach { format => + test(s"SPARK-25389 INSERT OVERWRITE $local DIRECTORY ... STORED AS with duplicated names" + + s"(caseSensitivity=$caseSensitivity, format=$format)") { + withTempDir { dir => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> s"$caseSensitivity") { + val m = intercept[AnalysisException] { + sql( + s""" + |INSERT OVERWRITE $local DIRECTORY '${dir.toURI}' + |STORED AS $format + |SELECT 'id', 'id2' ${if (caseSensitivity) "id" else "ID"} + """.stripMargin) + }.getMessage + assert(m.contains("Found duplicate column(s) when inserting into")) + } + } + } + } + } + } } From bcb9a8c83f4e6835af5dc51f1be7f964b8fa49a3 Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Tue, 11 Sep 2018 09:28:32 -0700 Subject: [PATCH 1592/2461] [SPARK-25221][DEPLOY] Consistent trailing whitespace treatment of conf values ## What changes were proposed in this pull request? Stop trimming values of properties loaded from a file ## How was this patch tested? Added unit test demonstrating the issue hit in production. Closes #22213 from gerashegalov/gera/SPARK-25221. Authored-by: Gera Shegalov Signed-off-by: Marcelo Vanzin --- .../scala/org/apache/spark/util/Utils.scala | 31 ++++++++++-- .../spark/deploy/SparkSubmitSuite.scala | 47 +++++++++++++++++++ .../org/apache/spark/util/UtilsSuite.scala | 28 +++++++++++ 3 files changed, 103 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 4593b057fc634..14f68cd6f3509 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -19,7 +19,6 @@ package org.apache.spark.util import java.io._ import java.lang.{Byte => JByte} -import java.lang.InternalError import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo} import java.lang.reflect.InvocationTargetException import java.math.{MathContext, RoundingMode} @@ -2052,6 +2051,30 @@ private[spark] object Utils extends Logging { } } + /** + * Implements the same logic as JDK `java.lang.String#trim` by removing leading and trailing + * non-printable characters less or equal to '\u0020' (SPACE) but preserves natural line + * delimiters according to [[java.util.Properties]] load method. The natural line delimiters are + * removed by JDK during load. Therefore any remaining ones have been specifically provided and + * escaped by the user, and must not be ignored + * + * @param str + * @return the trimmed value of str + */ + private[util] def trimExceptCRLF(str: String): String = { + val nonSpaceOrNaturalLineDelimiter: Char => Boolean = { ch => + ch > ' ' || ch == '\r' || ch == '\n' + } + + val firstPos = str.indexWhere(nonSpaceOrNaturalLineDelimiter) + val lastPos = str.lastIndexWhere(nonSpaceOrNaturalLineDelimiter) + if (firstPos >= 0 && lastPos >= 0) { + str.substring(firstPos, lastPos + 1) + } else { + "" + } + } + /** Load properties present in the given file. */ def getPropertiesFromFile(filename: String): Map[String, String] = { val file = new File(filename) @@ -2062,8 +2085,10 @@ private[spark] object Utils extends Logging { try { val properties = new Properties() properties.load(inReader) - properties.stringPropertyNames().asScala.map( - k => (k, properties.getProperty(k).trim)).toMap + properties.stringPropertyNames().asScala + .map { k => (k, trimExceptCRLF(properties.getProperty(k))) } + .toMap + } catch { case e: IOException => throw new SparkException(s"Failed when loading Spark properties from $filename", e) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index f829fecc30840..9eae3605d0738 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -1144,6 +1144,53 @@ class SparkSubmitSuite conf1.get(PY_FILES.key) should be (s"s3a://${pyFile.getAbsolutePath}") conf1.get("spark.submit.pyFiles") should (startWith("/")) } + + test("handles natural line delimiters in --properties-file and --conf uniformly") { + val delimKey = "spark.my.delimiter." + val LF = "\n" + val CR = "\r" + + val lineFeedFromCommandLine = s"${delimKey}lineFeedFromCommandLine" -> LF + val leadingDelimKeyFromFile = s"${delimKey}leadingDelimKeyFromFile" -> s"${LF}blah" + val trailingDelimKeyFromFile = s"${delimKey}trailingDelimKeyFromFile" -> s"blah${CR}" + val infixDelimFromFile = s"${delimKey}infixDelimFromFile" -> s"${CR}blah${LF}" + val nonDelimSpaceFromFile = s"${delimKey}nonDelimSpaceFromFile" -> " blah\f" + + val testProps = Seq(leadingDelimKeyFromFile, trailingDelimKeyFromFile, infixDelimFromFile, + nonDelimSpaceFromFile) + + val props = new java.util.Properties() + val propsFile = File.createTempFile("test-spark-conf", ".properties", + Utils.createTempDir()) + val propsOutputStream = new FileOutputStream(propsFile) + try { + testProps.foreach { case (k, v) => props.put(k, v) } + props.store(propsOutputStream, "test whitespace") + } finally { + propsOutputStream.close() + } + + val clArgs = Seq( + "--class", "org.SomeClass", + "--conf", s"${lineFeedFromCommandLine._1}=${lineFeedFromCommandLine._2}", + "--conf", "spark.master=yarn", + "--properties-file", propsFile.getPath, + "thejar.jar") + + val appArgs = new SparkSubmitArguments(clArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) + + Seq( + lineFeedFromCommandLine, + leadingDelimKeyFromFile, + trailingDelimKeyFromFile, + infixDelimFromFile + ).foreach { case (k, v) => + conf.get(k) should be (v) + } + + conf.get(nonDelimSpaceFromFile._1) should be ("blah") + } } object SparkSubmitSuite extends SparkFunSuite with TimeLimits { diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 943b53522d64e..39f4fba78583f 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -1205,6 +1205,34 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(Utils.stringHalfWidth("\u0967\u0968\u0969") == 3) // scalastyle:on nonascii } + + test("trimExceptCRLF standalone") { + val crlfSet = Set("\r", "\n") + val nonPrintableButCRLF = (0 to 32).map(_.toChar.toString).toSet -- crlfSet + + // identity for CRLF + crlfSet.foreach { s => Utils.trimExceptCRLF(s) === s } + + // empty for other non-printables + nonPrintableButCRLF.foreach { s => assert(Utils.trimExceptCRLF(s) === "") } + + // identity for a printable string + assert(Utils.trimExceptCRLF("a") === "a") + + // identity for strings with CRLF + crlfSet.foreach { s => + assert(Utils.trimExceptCRLF(s"${s}a") === s"${s}a") + assert(Utils.trimExceptCRLF(s"a${s}") === s"a${s}") + assert(Utils.trimExceptCRLF(s"b${s}b") === s"b${s}b") + } + + // trim nonPrintableButCRLF except when inside a string + nonPrintableButCRLF.foreach { s => + assert(Utils.trimExceptCRLF(s"${s}a") === "a") + assert(Utils.trimExceptCRLF(s"a${s}") === "a") + assert(Utils.trimExceptCRLF(s"b${s}b") === s"b${s}b") + } + } } private class SimpleExtension From 14f3ad20932535fe952428bf255e7eddd8fa1b58 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 11 Sep 2018 10:31:06 -0700 Subject: [PATCH 1593/2461] [SPARK-24889][CORE] Update block info when unpersist rdds ## What changes were proposed in this pull request? We will update block info coming from executors, at the timing like caching a RDD. However, when removing RDDs with unpersisting, we don't ask to update block info. So the block info is not updated. We can fix this with few options: 1. Ask to update block info when unpersisting This is simplest but changes driver-executor communication a bit. 2. Update block info when processing the event of unpersisting RDD We send a `SparkListenerUnpersistRDD` event when unpersisting RDD. When processing this event, we can update block info of the RDD. This only changes event processing code so the risk seems to be lower. Currently this patch takes option 2 for lower risk. If we agree first option has no risk, we can change to it. ## How was this patch tested? Unit tests. Closes #22341 from viirya/SPARK-24889. Authored-by: Liang-Chi Hsieh Signed-off-by: Marcelo Vanzin --- .../spark/status/AppStatusListener.scala | 64 ++++++++++++++----- .../org/apache/spark/status/LiveEntity.scala | 4 ++ .../org/apache/spark/storage/RDDInfo.scala | 2 +- .../spark/status/AppStatusListenerSuite.scala | 29 +++++++++ 4 files changed, 82 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 304d0922a37d2..f21eee1965761 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -686,7 +686,37 @@ private[spark] class AppStatusListener( } override def onUnpersistRDD(event: SparkListenerUnpersistRDD): Unit = { - liveRDDs.remove(event.rddId) + liveRDDs.remove(event.rddId).foreach { liveRDD => + val storageLevel = liveRDD.info.storageLevel + + // Use RDD partition info to update executor block info. + liveRDD.getPartitions().foreach { case (_, part) => + part.executors.foreach { executorId => + liveExecutors.get(executorId).foreach { exec => + exec.rddBlocks = exec.rddBlocks - 1 + } + } + } + + val now = System.nanoTime() + + // Use RDD distribution to update executor memory and disk usage info. + liveRDD.getDistributions().foreach { case (executorId, rddDist) => + liveExecutors.get(executorId).foreach { exec => + if (exec.hasMemoryInfo) { + if (storageLevel.useOffHeap) { + exec.usedOffHeap = addDeltaToValue(exec.usedOffHeap, -rddDist.offHeapUsed) + } else { + exec.usedOnHeap = addDeltaToValue(exec.usedOnHeap, -rddDist.onHeapUsed) + } + } + exec.memoryUsed = addDeltaToValue(exec.memoryUsed, -rddDist.memoryUsed) + exec.diskUsed = addDeltaToValue(exec.diskUsed, -rddDist.diskUsed) + maybeUpdate(exec, now) + } + } + } + kvstore.delete(classOf[RDDStorageInfoWrapper], event.rddId) } @@ -770,6 +800,11 @@ private[spark] class AppStatusListener( .sortBy(_.stageId) } + /** + * Apply a delta to a value, but ensure that it doesn't go negative. + */ + private def addDeltaToValue(old: Long, delta: Long): Long = math.max(0, old + delta) + private def updateRDDBlock(event: SparkListenerBlockUpdated, block: RDDBlockId): Unit = { val now = System.nanoTime() val executorId = event.blockUpdatedInfo.blockManagerId.executorId @@ -779,9 +814,6 @@ private[spark] class AppStatusListener( val diskDelta = event.blockUpdatedInfo.diskSize * (if (storageLevel.useDisk) 1 else -1) val memoryDelta = event.blockUpdatedInfo.memSize * (if (storageLevel.useMemory) 1 else -1) - // Function to apply a delta to a value, but ensure that it doesn't go negative. - def newValue(old: Long, delta: Long): Long = math.max(0, old + delta) - val updatedStorageLevel = if (storageLevel.isValid) { Some(storageLevel.description) } else { @@ -798,13 +830,13 @@ private[spark] class AppStatusListener( maybeExec.foreach { exec => if (exec.hasMemoryInfo) { if (storageLevel.useOffHeap) { - exec.usedOffHeap = newValue(exec.usedOffHeap, memoryDelta) + exec.usedOffHeap = addDeltaToValue(exec.usedOffHeap, memoryDelta) } else { - exec.usedOnHeap = newValue(exec.usedOnHeap, memoryDelta) + exec.usedOnHeap = addDeltaToValue(exec.usedOnHeap, memoryDelta) } } - exec.memoryUsed = newValue(exec.memoryUsed, memoryDelta) - exec.diskUsed = newValue(exec.diskUsed, diskDelta) + exec.memoryUsed = addDeltaToValue(exec.memoryUsed, memoryDelta) + exec.diskUsed = addDeltaToValue(exec.diskUsed, diskDelta) } // Update the block entry in the RDD info, keeping track of the deltas above so that we @@ -832,8 +864,8 @@ private[spark] class AppStatusListener( // Only update the partition if it's still stored in some executor, otherwise get rid of it. if (executors.nonEmpty) { partition.update(executors, rdd.storageLevel, - newValue(partition.memoryUsed, memoryDelta), - newValue(partition.diskUsed, diskDelta)) + addDeltaToValue(partition.memoryUsed, memoryDelta), + addDeltaToValue(partition.diskUsed, diskDelta)) } else { rdd.removePartition(block.name) } @@ -841,14 +873,14 @@ private[spark] class AppStatusListener( maybeExec.foreach { exec => if (exec.rddBlocks + rddBlocksDelta > 0) { val dist = rdd.distribution(exec) - dist.memoryUsed = newValue(dist.memoryUsed, memoryDelta) - dist.diskUsed = newValue(dist.diskUsed, diskDelta) + dist.memoryUsed = addDeltaToValue(dist.memoryUsed, memoryDelta) + dist.diskUsed = addDeltaToValue(dist.diskUsed, diskDelta) if (exec.hasMemoryInfo) { if (storageLevel.useOffHeap) { - dist.offHeapUsed = newValue(dist.offHeapUsed, memoryDelta) + dist.offHeapUsed = addDeltaToValue(dist.offHeapUsed, memoryDelta) } else { - dist.onHeapUsed = newValue(dist.onHeapUsed, memoryDelta) + dist.onHeapUsed = addDeltaToValue(dist.onHeapUsed, memoryDelta) } } dist.lastUpdate = null @@ -867,8 +899,8 @@ private[spark] class AppStatusListener( } } - rdd.memoryUsed = newValue(rdd.memoryUsed, memoryDelta) - rdd.diskUsed = newValue(rdd.diskUsed, diskDelta) + rdd.memoryUsed = addDeltaToValue(rdd.memoryUsed, memoryDelta) + rdd.diskUsed = addDeltaToValue(rdd.diskUsed, diskDelta) update(rdd, now) } diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index a0b2458549fbb..762aed4133517 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -541,6 +541,10 @@ private class LiveRDD(val info: RDDInfo) extends LiveEntity { distributions.get(exec.executorId) } + def getPartitions(): scala.collection.Map[String, LiveRDDPartition] = partitions + + def getDistributions(): scala.collection.Map[String, LiveRDDDistribution] = distributions + override protected def doUpdate(): Any = { val dists = if (distributions.nonEmpty) { Some(distributions.values.map(_.toApi()).toSeq) diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 9ccc8f9cc585b..64e5c8b1c4bbf 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -55,7 +55,7 @@ class RDDInfo( } private[spark] object RDDInfo { - private val callsiteForm = SparkEnv.get.conf.get(EVENT_LOG_CALLSITE_FORM) + private lazy val callsiteForm = SparkEnv.get.conf.get(EVENT_LOG_CALLSITE_FORM) def fromRdd(rdd: RDD[_]): RDDInfo = { val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd)) diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index d0c2dc4ad1337..0b2bbd2fa8a78 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -882,12 +882,41 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(dist.memoryRemaining === maxMemory - rdd2b1.memSize - rdd1b2.memSize ) } + // Add block1 of rdd1 back to bm 1. + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo(bm1, rdd1b1.blockId, level, rdd1b1.memSize, rdd1b1.diskSize))) + + check[ExecutorSummaryWrapper](bm1.executorId) { exec => + assert(exec.info.rddBlocks === 3L) + assert(exec.info.memoryUsed === rdd1b1.memSize + rdd1b2.memSize + rdd2b1.memSize) + assert(exec.info.diskUsed === rdd1b1.diskSize + rdd1b2.diskSize + rdd2b1.diskSize) + } + // Unpersist RDD1. listener.onUnpersistRDD(SparkListenerUnpersistRDD(rdd1b1.rddId)) intercept[NoSuchElementException] { check[RDDStorageInfoWrapper](rdd1b1.rddId) { _ => () } } + // executor1 now only contains block1 from rdd2. + check[ExecutorSummaryWrapper](bm1.executorId) { exec => + assert(exec.info.rddBlocks === 1L) + assert(exec.info.memoryUsed === rdd2b1.memSize) + assert(exec.info.diskUsed === rdd2b1.diskSize) + } + + // Unpersist RDD2. + listener.onUnpersistRDD(SparkListenerUnpersistRDD(rdd2b1.rddId)) + intercept[NoSuchElementException] { + check[RDDStorageInfoWrapper](rdd2b1.rddId) { _ => () } + } + + check[ExecutorSummaryWrapper](bm1.executorId) { exec => + assert(exec.info.rddBlocks === 0L) + assert(exec.info.memoryUsed === 0) + assert(exec.info.diskUsed === 0) + } + // Update a StreamBlock. val stream1 = StreamBlockId(1, 1L) listener.onBlockUpdated(SparkListenerBlockUpdated( From 9d9601ac8ad2da96343a0181897fdb415dd1b575 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 11 Sep 2018 10:53:28 -0700 Subject: [PATCH 1594/2461] [INFRA] Close stale PRs. Closes #22242 From cfbdd6a1f5906b848c520d3365cc4034992215d9 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 11 Sep 2018 14:46:03 -0500 Subject: [PATCH 1595/2461] [SPARK-25398] Minor bugs from comparing unrelated types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Correct some comparisons between unrelated types to what they seem to… have been trying to do ## How was this patch tested? Existing tests. Closes #22384 from srowen/SPARK-25398. Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../org/apache/spark/status/LiveEntity.scala | 4 +--- .../apache/spark/util/ClosureCleaner.scala | 2 +- .../ExternalAppendOnlyMapSuite.scala | 4 +--- .../cluster/mesos/MesosClusterScheduler.scala | 20 +++++++++---------- .../mesos/MesosClusterSchedulerSuite.scala | 14 ++++++------- ...esosFineGrainedSchedulerBackendSuite.scala | 2 +- .../spark/deploy/yarn/ClientSuite.scala | 2 +- .../PropagateEmptyRelationSuite.scala | 2 +- .../sql/catalyst/util/UnsafeArraySuite.scala | 4 ++-- .../org/apache/spark/sql/DatasetSuite.scala | 2 +- .../parquet/ParquetSchemaSuite.scala | 2 +- 11 files changed, 27 insertions(+), 31 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 762aed4133517..8708e64db3c17 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -33,7 +33,6 @@ import org.apache.spark.storage.RDDInfo import org.apache.spark.ui.SparkUI import org.apache.spark.util.AccumulatorContext import org.apache.spark.util.collection.OpenHashSet -import org.apache.spark.util.kvstore.KVStore /** * A mutable representation of a live entity in Spark (jobs, stages, tasks, et al). Every live @@ -588,8 +587,7 @@ private object LiveEntityHelpers { .filter { acc => // We don't need to store internal or SQL accumulables as their values will be shown in // other places, so drop them to reduce the memory usage. - !acc.internal && (!acc.metadata.isDefined || - acc.metadata.get != Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER)) + !acc.internal && acc.metadata != Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER) } .map { acc => new v1.AccumulableInfo( diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index b6c300c4778b1..43d62561e8eba 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -175,7 +175,7 @@ private[spark] object ClosureCleaner extends Logging { closure.getClass.isSynthetic && closure .getClass - .getInterfaces.exists(_.getName.equals("scala.Serializable")) + .getInterfaces.exists(_.getName == "scala.Serializable") if (isClosureCandidate) { try { diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index d542ba0b6640d..8a2f2ffe0acf1 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.util.collection -import java.util.Objects - import scala.collection.mutable.ArrayBuffer import scala.ref.WeakReference @@ -509,7 +507,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite .sorted assert(it.isEmpty) - assert(keys == (0 until 100)) + assert(keys == (0 until 100).toList) assert(map.numSpills == 0) // these asserts try to show that we're no longer holding references to the underlying map. diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 7d80eedcc43ce..cb1bcba651be6 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -202,7 +202,7 @@ private[spark] class MesosClusterScheduler( } else if (removeFromPendingRetryDrivers(submissionId)) { k.success = true k.message = "Removed driver while it's being retried" - } else if (finishedDrivers.exists(_.driverDescription.submissionId.equals(submissionId))) { + } else if (finishedDrivers.exists(_.driverDescription.submissionId == submissionId)) { k.success = false k.message = "Driver already terminated" } else { @@ -222,21 +222,21 @@ private[spark] class MesosClusterScheduler( } s.submissionId = submissionId stateLock.synchronized { - if (queuedDrivers.exists(_.submissionId.equals(submissionId))) { + if (queuedDrivers.exists(_.submissionId == submissionId)) { s.success = true s.driverState = "QUEUED" } else if (launchedDrivers.contains(submissionId)) { s.success = true s.driverState = "RUNNING" launchedDrivers(submissionId).mesosTaskStatus.foreach(state => s.message = state.toString) - } else if (finishedDrivers.exists(_.driverDescription.submissionId.equals(submissionId))) { + } else if (finishedDrivers.exists(_.driverDescription.submissionId == submissionId)) { s.success = true s.driverState = "FINISHED" finishedDrivers .find(d => d.driverDescription.submissionId.equals(submissionId)).get.mesosTaskStatus .foreach(state => s.message = state.toString) - } else if (pendingRetryDrivers.exists(_.submissionId.equals(submissionId))) { - val status = pendingRetryDrivers.find(_.submissionId.equals(submissionId)) + } else if (pendingRetryDrivers.exists(_.submissionId == submissionId)) { + val status = pendingRetryDrivers.find(_.submissionId == submissionId) .get.retryState.get.lastFailureStatus s.success = true s.driverState = "RETRYING" @@ -254,13 +254,13 @@ private[spark] class MesosClusterScheduler( */ def getDriverState(submissionId: String): Option[MesosDriverState] = { stateLock.synchronized { - queuedDrivers.find(_.submissionId.equals(submissionId)) + queuedDrivers.find(_.submissionId == submissionId) .map(d => new MesosDriverState("QUEUED", d)) .orElse(launchedDrivers.get(submissionId) .map(d => new MesosDriverState("RUNNING", d.driverDescription, Some(d)))) - .orElse(finishedDrivers.find(_.driverDescription.submissionId.equals(submissionId)) + .orElse(finishedDrivers.find(_.driverDescription.submissionId == submissionId) .map(d => new MesosDriverState("FINISHED", d.driverDescription, Some(d)))) - .orElse(pendingRetryDrivers.find(_.submissionId.equals(submissionId)) + .orElse(pendingRetryDrivers.find(_.submissionId == submissionId) .map(d => new MesosDriverState("RETRYING", d))) } } @@ -814,7 +814,7 @@ private[spark] class MesosClusterScheduler( status: Int): Unit = {} private def removeFromQueuedDrivers(subId: String): Boolean = { - val index = queuedDrivers.indexWhere(_.submissionId.equals(subId)) + val index = queuedDrivers.indexWhere(_.submissionId == subId) if (index != -1) { queuedDrivers.remove(index) queuedDriversState.expunge(subId) @@ -834,7 +834,7 @@ private[spark] class MesosClusterScheduler( } private def removeFromPendingRetryDrivers(subId: String): Boolean = { - val index = pendingRetryDrivers.indexWhere(_.submissionId.equals(subId)) + val index = pendingRetryDrivers.indexWhere(_.submissionId == subId) if (index != -1) { pendingRetryDrivers.remove(index) pendingRetryDriversState.expunge(subId) diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala index e534b9d7e3ed9..082d4bcfdf83a 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -21,7 +21,7 @@ import java.util.{Collection, Collections, Date} import scala.collection.JavaConverters._ -import org.apache.mesos.Protos.{Environment, Secret, TaskState => MesosTaskState, _} +import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.Protos.Value.{Scalar, Type} import org.apache.mesos.SchedulerDriver import org.mockito.{ArgumentCaptor, Matchers} @@ -146,14 +146,14 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi assert(scheduler.getResource(resources, "cpus") == 1.5) assert(scheduler.getResource(resources, "mem") == 1200) val resourcesSeq: Seq[Resource] = resources.asScala - val cpus = resourcesSeq.filter(_.getName.equals("cpus")).toList + val cpus = resourcesSeq.filter(_.getName == "cpus").toList assert(cpus.size == 2) - assert(cpus.exists(_.getRole().equals("role2"))) - assert(cpus.exists(_.getRole().equals("*"))) - val mem = resourcesSeq.filter(_.getName.equals("mem")).toList + assert(cpus.exists(_.getRole() == "role2")) + assert(cpus.exists(_.getRole() == "*")) + val mem = resourcesSeq.filter(_.getName == "mem").toList assert(mem.size == 2) - assert(mem.exists(_.getRole().equals("role2"))) - assert(mem.exists(_.getRole().equals("*"))) + assert(mem.exists(_.getRole() == "role2")) + assert(mem.exists(_.getRole() == "*")) verify(driver, times(1)).launchTasks( Matchers.eq(Collections.singleton(offer.getId)), diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala index 31f84310485a0..1ead4b1ed7c7e 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala @@ -106,7 +106,7 @@ class MesosFineGrainedSchedulerBackendSuite // uri is null. val (executorInfo, _) = mesosSchedulerBackend.createExecutorInfo(resources, "test-id") val executorResources = executorInfo.getResourcesList - val cpus = executorResources.asScala.find(_.getName.equals("cpus")).get.getScalar.getValue + val cpus = executorResources.asScala.find(_.getName == "cpus").get.getScalar.getValue assert(cpus === mesosExecutorCores) } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 7fa597167f3f0..26013a109c42b 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -191,7 +191,7 @@ class ClientSuite extends SparkFunSuite with Matchers { appContext.getQueue should be ("staging-queue") appContext.getAMContainerSpec should be (containerLaunchContext) appContext.getApplicationType should be ("SPARK") - appContext.getClass.getMethods.filter(_.getName.equals("getApplicationTags")).foreach{ method => + appContext.getClass.getMethods.filter(_.getName == "getApplicationTags").foreach { method => val tags = method.invoke(appContext).asInstanceOf[java.util.Set[String]] tags should contain allOf ("tag1", "dup", "tag2", "multi word") tags.asScala.count(_.nonEmpty) should be (4) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index f1ce7543ffdc1..d395bba105a7b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -147,7 +147,7 @@ class PropagateEmptyRelationSuite extends PlanTest { .where(false) .select('a) .where('a > 1) - .where('a != 200) + .where('a =!= 200) .orderBy('a.asc) val optimized = Optimize.execute(query.analyze) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala index 8f75c14192c9b..755c8897cada2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala @@ -114,7 +114,7 @@ class UnsafeArraySuite extends SparkFunSuite { assert(unsafeDate.isInstanceOf[UnsafeArrayData]) assert(unsafeDate.numElements == dateArray.length) dateArray.zipWithIndex.map { case (e, i) => - assert(unsafeDate.get(i, DateType) == e) + assert(unsafeDate.get(i, DateType).asInstanceOf[Int] == e) } val unsafeTimestamp = ExpressionEncoder[Array[Long]].resolveAndBind(). @@ -122,7 +122,7 @@ class UnsafeArraySuite extends SparkFunSuite { assert(unsafeTimestamp.isInstanceOf[UnsafeArrayData]) assert(unsafeTimestamp.numElements == timestampArray.length) timestampArray.zipWithIndex.map { case (e, i) => - assert(unsafeTimestamp.get(i, TimestampType) == e) + assert(unsafeTimestamp.get(i, TimestampType).asInstanceOf[Long] == e) } Seq(decimalArray4_1, decimalArray20_20).map { decimalArray => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index ca8fbc991a3a5..4e593ff046a53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -611,7 +611,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ).toDF("id", "stringData") val sampleDF = df.sample(false, 0.7, 50) // After sampling, sampleDF doesn't contain id=1. - assert(!sampleDF.select("id").collect.contains(1)) + assert(!sampleDF.select("id").as[Int].collect.contains(1)) // simpleUdf should not encounter id=1. checkAnswer(sampleDF.select(simpleUdf($"id")), List.fill(sampleDF.count.toInt)(Row(1))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 7eefedb8ff5bb..528a4d0ca8004 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -427,7 +427,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { assert(errMsg.startsWith("Parquet column cannot be converted in file")) val file = errMsg.substring("Parquet column cannot be converted in file ".length, errMsg.indexOf(". ")) - val col = spark.read.parquet(file).schema.fields.filter(_.name.equals("a")) + val col = spark.read.parquet(file).schema.fields.filter(_.name == "a") assert(col.length == 1) if (col(0).dataType == StringType) { assert(errMsg.contains("Column: [a], Expected: int, Found: BINARY")) From 97d4afaa13aaa771220adb3625f1783ce1b3a8df Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 11 Sep 2018 14:52:58 -0500 Subject: [PATCH 1596/2461] Revert "[SPARK-23820][CORE] Enable use of long form of callsite in logs" This reverts commit e58dadb77ed6cac3e1b2a037a6449e5a6e7f2cec. --- .../org/apache/spark/internal/config/package.scala | 3 --- .../main/scala/org/apache/spark/storage/RDDInfo.scala | 10 +--------- 2 files changed, 1 insertion(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index ee41bd1a79ae3..bf0391cc9185b 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -77,9 +77,6 @@ package object config { private[spark] val EVENT_LOG_OVERWRITE = ConfigBuilder("spark.eventLog.overwrite").booleanConf.createWithDefault(false) - private[spark] val EVENT_LOG_CALLSITE_FORM = - ConfigBuilder("spark.eventLog.callsite").stringConf.createWithDefault("short") - private[spark] val EXECUTOR_CLASS_PATH = ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_CLASSPATH).stringConf.createOptional diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 64e5c8b1c4bbf..e5abbf745cc41 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -17,9 +17,7 @@ package org.apache.spark.storage -import org.apache.spark.SparkEnv import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.internal.config._ import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.util.Utils @@ -55,16 +53,10 @@ class RDDInfo( } private[spark] object RDDInfo { - private lazy val callsiteForm = SparkEnv.get.conf.get(EVENT_LOG_CALLSITE_FORM) - def fromRdd(rdd: RDD[_]): RDDInfo = { val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd)) val parentIds = rdd.dependencies.map(_.rdd.id) - val callSite = callsiteForm match { - case "short" => rdd.creationSite.shortForm - case "long" => rdd.creationSite.longForm - } new RDDInfo(rdd.id, rddName, rdd.partitions.length, - rdd.getStorageLevel, parentIds, callSite, rdd.scope) + rdd.getStorageLevel, parentIds, rdd.creationSite.shortForm, rdd.scope) } } From 9f5c5b4cca7d4eaa30a3f8adb4cb1eebe3f77c7a Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Tue, 11 Sep 2018 15:53:15 -0700 Subject: [PATCH 1597/2461] [SPARK-25399][SS] Continuous processing state should not affect microbatch execution jobs ## What changes were proposed in this pull request? The leftover state from running a continuous processing streaming job should not affect later microbatch execution jobs. If a continuous processing job runs and the same thread gets reused for a microbatch execution job in the same environment, the microbatch job could get wrong answers because it can attempt to load the wrong version of the state. ## How was this patch tested? New and existing unit tests Closes #22386 from mukulmurthy/25399-streamthread. Authored-by: Mukul Murthy Signed-off-by: Tathagata Das --- .../streaming/MicroBatchExecution.scala | 2 ++ .../execution/streaming/StreamExecution.scala | 1 + .../continuous/ContinuousExecution.scala | 2 ++ .../streaming/state/StateStoreRDD.scala | 12 +++++-- .../spark/sql/streaming/StreamSuite.scala | 33 +++++++++++++++++-- 5 files changed, 45 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index b1cafd67820c2..2cac86599ef19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -511,6 +511,8 @@ class MicroBatchExecution( sparkSessionToRunBatch.sparkContext.setLocalProperty( MicroBatchExecution.BATCH_ID_KEY, currentBatchId.toString) + sparkSessionToRunBatch.sparkContext.setLocalProperty( + StreamExecution.IS_CONTINUOUS_PROCESSING, false.toString) reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index a39bb715c9913..f6c60c1c92124 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -529,6 +529,7 @@ abstract class StreamExecution( object StreamExecution { val QUERY_ID_KEY = "sql.streaming.queryId" + val IS_CONTINUOUS_PROCESSING = "__is_continuous_processing" def isInterruptionException(e: Throwable): Boolean = e match { // InterruptedIOException - thrown when an I/O operation is interrupted diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 4ddebb33b79d1..ccca72667a217 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -209,6 +209,8 @@ class ContinuousExecution( scan.readSupport.asInstanceOf[ContinuousReadSupport] -> scan.scanConfig }.head + sparkSessionForQuery.sparkContext.setLocalProperty( + StreamExecution.IS_CONTINUOUS_PROCESSING, true.toString) sparkSessionForQuery.sparkContext.setLocalProperty( ContinuousExecution.START_EPOCH_KEY, currentBatchId.toString) // Add another random ID on top of the run ID, to distinguish epoch coordinators across diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 3f11b8f79943c..4a69a48fed75f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -23,6 +23,7 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.EpochTracker import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType @@ -74,9 +75,14 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( // If we're in continuous processing mode, we should get the store version for the current // epoch rather than the one at planning time. - val currentVersion = EpochTracker.getCurrentEpoch match { - case None => storeVersion - case Some(value) => value + val isContinuous = Option(ctxt.getLocalProperty(StreamExecution.IS_CONTINUOUS_PROCESSING)) + .map(_.toBoolean).getOrElse(false) + val currentVersion = if (isContinuous) { + val epoch = EpochTracker.getCurrentEpoch + assert(epoch.isDefined, "Current epoch must be defined for continuous processing streams.") + epoch.get + } else { + storeVersion } store = StateStore.get( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index bf509b1976ed8..f55ddb5419d20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -29,13 +29,14 @@ import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, TaskContext} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.Range import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider} import org.apache.spark.sql.functions._ @@ -788,7 +789,7 @@ class StreamSuite extends StreamTest { val query = input .toDS() .map { i => - while (!org.apache.spark.TaskContext.get().isInterrupted()) { + while (!TaskContext.get().isInterrupted()) { // keep looping till interrupted by query.stop() Thread.sleep(100) } @@ -1029,6 +1030,34 @@ class StreamSuite extends StreamTest { false)) } + test("is_continuous_processing property should be false for microbatch processing") { + val input = MemoryStream[Int] + val df = input.toDS() + .map(i => TaskContext.get().getLocalProperty(StreamExecution.IS_CONTINUOUS_PROCESSING)) + testStream(df) ( + AddData(input, 1), + CheckAnswer("false") + ) + } + + test("is_continuous_processing property should be true for continuous processing") { + val input = ContinuousMemoryStream[Int] + val stream = input.toDS() + .map(i => TaskContext.get().getLocalProperty(StreamExecution.IS_CONTINUOUS_PROCESSING)) + .writeStream.format("memory") + .queryName("output") + .trigger(Trigger.Continuous("1 seconds")) + .start() + try { + input.addData(1) + stream.processAllAvailable() + } finally { + stream.stop() + } + + checkAnswer(spark.sql("select * from output"), Row("true")) + } + for (e <- Seq( new InterruptedException, new InterruptedIOException, From 79cc59718fdf7785bdc37a26bb8df4c6151114a6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 12 Sep 2018 21:11:22 +0800 Subject: [PATCH 1598/2461] [SPARK-25402][SQL] Null handling in BooleanSimplification ## What changes were proposed in this pull request? This PR is to fix the null handling in BooleanSimplification. In the rule BooleanSimplification, there are two cases that do not properly handle null values. The optimization is not right if either side is null. This PR is to fix them. ## How was this patch tested? Added test cases Closes #22390 from gatorsmile/fixBooleanSimplification. Authored-by: gatorsmile Signed-off-by: Wenchen Fan --- .../sql/catalyst/optimizer/expressions.scala | 13 ++++-- .../BooleanSimplificationSuite.scala | 45 +++++++++++++++++-- .../org/apache/spark/sql/DataFrameSuite.scala | 10 +++++ 3 files changed, 60 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 5629b72894225..f8037588fa71e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -263,10 +263,15 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { case TrueLiteral Or _ => TrueLiteral case _ Or TrueLiteral => TrueLiteral - case a And b if Not(a).semanticEquals(b) => FalseLiteral - case a Or b if Not(a).semanticEquals(b) => TrueLiteral - case a And b if a.semanticEquals(Not(b)) => FalseLiteral - case a Or b if a.semanticEquals(Not(b)) => TrueLiteral + case a And b if Not(a).semanticEquals(b) => + If(IsNull(a), Literal.create(null, a.dataType), FalseLiteral) + case a And b if a.semanticEquals(Not(b)) => + If(IsNull(b), Literal.create(null, b.dataType), FalseLiteral) + + case a Or b if Not(a).semanticEquals(b) => + If(IsNull(a), Literal.create(null, a.dataType), TrueLiteral) + case a Or b if a.semanticEquals(Not(b)) => + If(IsNull(b), Literal.create(null, b.dataType), TrueLiteral) case a And b if a.semanticEquals(b) => a case a Or b if a.semanticEquals(b) => a diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 653c07f1835ca..6cd1108eef333 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.BooleanType class BooleanSimplificationSuite extends PlanTest with PredicateHelper { @@ -37,6 +38,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { Batch("Constant Folding", FixedPoint(50), NullPropagation, ConstantFolding, + SimplifyConditionals, BooleanSimplification, PruneFilters) :: Nil } @@ -48,6 +50,14 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { testRelation.output, Seq(Row(1, 2, 3, "abc")) ) + val testNotNullableRelation = LocalRelation('a.int.notNull, 'b.int.notNull, 'c.int.notNull, + 'd.string.notNull, 'e.boolean.notNull, 'f.boolean.notNull, 'g.boolean.notNull, + 'h.boolean.notNull) + + val testNotNullableRelationWithData = LocalRelation.fromExternalRows( + testNotNullableRelation.output, Seq(Row(1, 2, 3, "abc")) + ) + private def checkCondition(input: Expression, expected: LogicalPlan): Unit = { val plan = testRelationWithData.where(input).analyze val actual = Optimize.execute(plan) @@ -61,6 +71,13 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { comparePlans(actual, correctAnswer) } + private def checkConditionInNotNullableRelation( + input: Expression, expected: LogicalPlan): Unit = { + val plan = testNotNullableRelationWithData.where(input).analyze + val actual = Optimize.execute(plan) + comparePlans(actual, expected) + } + test("a && a => a") { checkCondition(Literal(1) < 'a && Literal(1) < 'a, Literal(1) < 'a) checkCondition(Literal(1) < 'a && Literal(1) < 'a && Literal(1) < 'a, Literal(1) < 'a) @@ -174,10 +191,30 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { } test("Complementation Laws") { - checkCondition('a && !'a, testRelation) - checkCondition(!'a && 'a, testRelation) + checkConditionInNotNullableRelation('e && !'e, testNotNullableRelation) + checkConditionInNotNullableRelation(!'e && 'e, testNotNullableRelation) + + checkConditionInNotNullableRelation('e || !'e, testNotNullableRelationWithData) + checkConditionInNotNullableRelation(!'e || 'e, testNotNullableRelationWithData) + } + + test("Complementation Laws - null handling") { + checkCondition('e && !'e, + testRelationWithData.where(If('e.isNull, Literal.create(null, BooleanType), false)).analyze) + checkCondition(!'e && 'e, + testRelationWithData.where(If('e.isNull, Literal.create(null, BooleanType), false)).analyze) + + checkCondition('e || !'e, + testRelationWithData.where(If('e.isNull, Literal.create(null, BooleanType), true)).analyze) + checkCondition(!'e || 'e, + testRelationWithData.where(If('e.isNull, Literal.create(null, BooleanType), true)).analyze) + } + + test("Complementation Laws - negative case") { + checkCondition('e && !'f, testRelationWithData.where('e && !'f).analyze) + checkCondition(!'f && 'e, testRelationWithData.where(!'f && 'e).analyze) - checkCondition('a || !'a, testRelationWithData) - checkCondition(!'a || 'a, testRelationWithData) + checkCondition('e || !'f, testRelationWithData.where('e || !'f).analyze) + checkCondition(!'f || 'e, testRelationWithData.where(!'f || 'e).analyze) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 435b887cb3c78..279b7b8d49f52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2569,4 +2569,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { check(lit(2).cast("int"), $"c" === 2, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) check(lit(2).cast("int"), $"c" =!= 2, Seq()) } + + test("SPARK-25402 Null handling in BooleanSimplification") { + val schema = StructType.fromDDL("a boolean, b int") + val rows = Seq(Row(null, 1)) + + val rdd = sparkContext.parallelize(rows) + val df = spark.createDataFrame(rdd, schema) + + checkAnswer(df.where("(NOT a) OR a"), Seq.empty) + } } From 2f422398b524eacc89ab58e423bb134ae3ca3941 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 12 Sep 2018 22:54:05 +0800 Subject: [PATCH 1599/2461] [SPARK-25352][SQL] Perform ordered global limit when limit number is bigger than topKSortFallbackThreshold ## What changes were proposed in this pull request? We have optimization on global limit to evenly distribute limit rows across all partitions. This optimization doesn't work for ordered results. For a query ending with sort + limit, in most cases it is performed by `TakeOrderedAndProjectExec`. But if limit number is bigger than `SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD`, global limit will be used. At this moment, we need to do ordered global limit. ## How was this patch tested? Unit tests. Closes #22344 from viirya/SPARK-25352. Authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- .../spark/sql/execution/SparkStrategies.scala | 44 ++++++--- .../apache/spark/sql/execution/limit.scala | 7 +- .../org/apache/spark/sql/DataFrameSuite.scala | 22 ++++- .../spark/sql/execution/LimitSuite.scala | 81 ++++++++++++++++ .../TakeOrderedAndProjectSuite.scala | 94 +++++++++++-------- 5 files changed, 192 insertions(+), 56 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index dbc6db62bd820..7c8ce316f9647 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -68,22 +68,42 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object SpecialLimits extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ReturnAnswer(rootPlan) => rootPlan match { - case Limit(IntegerLiteral(limit), Sort(order, true, child)) - if limit < conf.topKSortFallbackThreshold => - TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil - case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) - if limit < conf.topKSortFallbackThreshold => - TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil + case Limit(IntegerLiteral(limit), s@Sort(order, true, child)) => + if (limit < conf.topKSortFallbackThreshold) { + TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil + } else { + GlobalLimitExec(limit, + LocalLimitExec(limit, planLater(s)), + orderedLimit = true) :: Nil + } + case Limit(IntegerLiteral(limit), p@Project(projectList, Sort(order, true, child))) => + if (limit < conf.topKSortFallbackThreshold) { + TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil + } else { + GlobalLimitExec(limit, + LocalLimitExec(limit, planLater(p)), + orderedLimit = true) :: Nil + } case Limit(IntegerLiteral(limit), child) => CollectLimitExec(limit, planLater(child)) :: Nil case other => planLater(other) :: Nil } - case Limit(IntegerLiteral(limit), Sort(order, true, child)) - if limit < conf.topKSortFallbackThreshold => - TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil - case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) - if limit < conf.topKSortFallbackThreshold => - TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil + case Limit(IntegerLiteral(limit), s@Sort(order, true, child)) => + if (limit < conf.topKSortFallbackThreshold) { + TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil + } else { + GlobalLimitExec(limit, + LocalLimitExec(limit, planLater(s)), + orderedLimit = true) :: Nil + } + case Limit(IntegerLiteral(limit), p@Project(projectList, Sort(order, true, child))) => + if (limit < conf.topKSortFallbackThreshold) { + TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil + } else { + GlobalLimitExec(limit, + LocalLimitExec(limit, planLater(p)), + orderedLimit = true) :: Nil + } case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index fb46970e38f3c..1a09632f93ca1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -98,7 +98,8 @@ case class LocalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode wi /** * Take the `limit` elements of the child output. */ -case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { +case class GlobalLimitExec(limit: Int, child: SparkPlan, + orderedLimit: Boolean = false) extends UnaryExecNode { override def output: Seq[Attribute] = child.output @@ -126,7 +127,9 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode { // When enabled, Spark goes to take rows at each partition repeatedly until reaching // limit number. When disabled, Spark takes all rows at first partition, then rows // at second partition ..., until reaching limit number. - val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit + // The optimization is disabled when it is needed to keep the original order of rows + // before global sort, e.g., select * from table order by col limit 10. + val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit && !orderedLimit val shuffled = new ShuffledRowRDD(shuffleDependency) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 279b7b8d49f52..f001b138f4b8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Uuid import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} -import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec} +import org.apache.spark.sql.execution.{FilterExec, QueryExecution, TakeOrderedAndProjectExec, WholeStageCodegenExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.functions._ @@ -2552,6 +2552,26 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } + test("SPARK-25352: Ordered global limit when more than topKSortFallbackThreshold ") { + withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") { + val baseDf = spark.range(1000).toDF.repartition(3).sort("id") + + withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "100") { + val expected = baseDf.limit(99) + val takeOrderedNode1 = expected.queryExecution.executedPlan + .find(_.isInstanceOf[TakeOrderedAndProjectExec]) + assert(takeOrderedNode1.isDefined) + + val result = baseDf.limit(100) + val takeOrderedNode2 = result.queryExecution.executedPlan + .find(_.isInstanceOf[TakeOrderedAndProjectExec]) + assert(takeOrderedNode2.isEmpty) + + checkAnswer(expected, result.collect().take(99)) + } + } + } + test("SPARK-25368 Incorrect predicate pushdown returns wrong result") { def check(newCol: Column, filter: Column, result: Seq[Row]): Unit = { val df1 = spark.createDataFrame(Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala new file mode 100644 index 0000000000000..a7840a5fcfae0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.util.Random + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + + +class LimitSuite extends SparkPlanTest with SharedSQLContext { + + private var rand: Random = _ + private var seed: Long = 0 + + protected override def beforeAll(): Unit = { + super.beforeAll() + seed = System.currentTimeMillis() + rand = new Random(seed) + } + + test("Produce ordered global limit if more than topKSortFallbackThreshold") { + withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "100") { + val df = LimitTest.generateRandomInputData(spark, rand).sort("a") + + val globalLimit = df.limit(99).queryExecution.executedPlan.collect { + case g: GlobalLimitExec => g + } + assert(globalLimit.size == 0) + + val topKSort = df.limit(99).queryExecution.executedPlan.collect { + case t: TakeOrderedAndProjectExec => t + } + assert(topKSort.size == 1) + + val orderedGlobalLimit = df.limit(100).queryExecution.executedPlan.collect { + case g: GlobalLimitExec => g + } + assert(orderedGlobalLimit.size == 1 && orderedGlobalLimit(0).orderedLimit == true) + } + } + + test("Ordered global limit") { + val baseDf = LimitTest.generateRandomInputData(spark, rand) + .select("a").repartition(3).sort("a") + + withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") { + val orderedGlobalLimit = GlobalLimitExec(3, baseDf.queryExecution.sparkPlan, + orderedLimit = true) + val orderedGlobalLimitResult = SparkPlanTest.executePlan(orderedGlobalLimit, spark.sqlContext) + .map(_.getInt(0)) + + val globalLimit = GlobalLimitExec(3, baseDf.queryExecution.sparkPlan, orderedLimit = false) + val globalLimitResult = SparkPlanTest.executePlan(globalLimit, spark.sqlContext) + .map(_.getInt(0)) + + // Global limit without order takes values at each partition sequentially. + // After global sort, the values in second partition must be larger than the values + // in first partition. + assert(orderedGlobalLimitResult(0) == globalLimitResult(0)) + assert(orderedGlobalLimitResult(1) < globalLimitResult(1)) + assert(orderedGlobalLimitResult(2) < globalLimitResult(2)) + } + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala index f076959dfdf7b..9322204063af3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import scala.util.Random -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.internal.SQLConf @@ -32,28 +32,10 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { private var rand: Random = _ private var seed: Long = 0 - private val originalLimitFlatGlobalLimit = SQLConf.get.getConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT) - protected override def beforeAll(): Unit = { super.beforeAll() seed = System.currentTimeMillis() rand = new Random(seed) - - // Disable the optimization to make Sort-Limit match `TakeOrderedAndProject` semantics. - SQLConf.get.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false) - } - - protected override def afterAll() = { - SQLConf.get.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit) - super.afterAll() - } - - private def generateRandomInputData(): DataFrame = { - val schema = new StructType() - .add("a", IntegerType, nullable = false) - .add("b", IntegerType, nullable = false) - val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt())) - spark.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema) } /** @@ -66,32 +48,62 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { val sortOrder = 'a.desc :: 'b.desc :: Nil test("TakeOrderedAndProject.doExecute without project") { - withClue(s"seed = $seed") { - checkThatPlansAgree( - generateRandomInputData(), - input => - noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), - input => - GlobalLimitExec(limit, - LocalLimitExec(limit, - SortExec(sortOrder, true, input))), - sortAnswers = false) + withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "false") { + withClue(s"seed = $seed") { + checkThatPlansAgree( + LimitTest.generateRandomInputData(spark, rand), + input => + noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), + input => + GlobalLimitExec(limit, + LocalLimitExec(limit, + SortExec(sortOrder, true, input))), + sortAnswers = false) + } } } test("TakeOrderedAndProject.doExecute with project") { - withClue(s"seed = $seed") { - checkThatPlansAgree( - generateRandomInputData(), - input => - noOpFilter( - TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)), - input => - GlobalLimitExec(limit, - LocalLimitExec(limit, - ProjectExec(Seq(input.output.last), - SortExec(sortOrder, true, input)))), - sortAnswers = false) + withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "false") { + withClue(s"seed = $seed") { + checkThatPlansAgree( + LimitTest.generateRandomInputData(spark, rand), + input => + noOpFilter( + TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)), + input => + GlobalLimitExec(limit, + LocalLimitExec(limit, + ProjectExec(Seq(input.output.last), + SortExec(sortOrder, true, input)))), + sortAnswers = false) + } } } + + test("TakeOrderedAndProject.doExecute equals to ordered global limit") { + withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") { + withClue(s"seed = $seed") { + checkThatPlansAgree( + LimitTest.generateRandomInputData(spark, rand), + input => + noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), + input => + GlobalLimitExec(limit, + LocalLimitExec(limit, + SortExec(sortOrder, true, input)), orderedLimit = true), + sortAnswers = false) + } + } + } +} + +object LimitTest { + def generateRandomInputData(spark: SparkSession, rand: Random): DataFrame = { + val schema = new StructType() + .add("a", IntegerType, nullable = false) + .add("b", IntegerType, nullable = false) + val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt())) + spark.createDataFrame(spark.sparkContext.parallelize(Random.shuffle(inputData), 10), schema) + } } From 3030b82c89d3e45a2e361c469fbc667a1e43b854 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 12 Sep 2018 17:43:40 +0000 Subject: [PATCH 1600/2461] [SPARK-25363][SQL] Fix schema pruning in where clause by ignoring unnecessary root fields ## What changes were proposed in this pull request? Schema pruning doesn't work if nested column is used in where clause. For example, ``` sql("select name.first from contacts where name.first = 'David'") == Physical Plan == *(1) Project [name#19.first AS first#40] +- *(1) Filter (isnotnull(name#19) && (name#19.first = David)) +- *(1) FileScan parquet [name#19] Batched: false, Format: Parquet, PartitionFilters: [], PushedFilters: [IsNotNull(name)], ReadSchema: struct> ``` In above query plan, the scan node reads the entire schema of `name` column. This issue is reported by: https://github.com/apache/spark/pull/21320#issuecomment-419290197 The cause is that we infer a root field from expression `IsNotNull(name)`. However, for such expression, we don't really use the nested fields of this root field, so we can ignore the unnecessary nested fields. ## How was this patch tested? Unit tests. Closes #22357 from viirya/SPARK-25363. Authored-by: Liang-Chi Hsieh Signed-off-by: DB Tsai --- .../parquet/ParquetSchemaPruning.scala | 34 ++++++-- .../parquet/ParquetSchemaPruningSuite.scala | 77 ++++++++++++++++--- 2 files changed, 96 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala index 6a46b5f8edc54..91080b15727d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule @@ -110,7 +110,17 @@ private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] { val projectionRootFields = projects.flatMap(getRootFields) val filterRootFields = filters.flatMap(getRootFields) - (projectionRootFields ++ filterRootFields).distinct + // Kind of expressions don't need to access any fields of a root fields, e.g., `IsNotNull`. + // For them, if there are any nested fields accessed in the query, we don't need to add root + // field access of above expressions. + // For example, for a query `SELECT name.first FROM contacts WHERE name IS NOT NULL`, + // we don't need to read nested fields of `name` struct other than `first` field. + val (rootFields, optRootFields) = (projectionRootFields ++ filterRootFields) + .distinct.partition(_.contentAccessed) + + optRootFields.filter { opt => + !rootFields.exists(_.field.name == opt.field.name) + } ++ rootFields } /** @@ -156,7 +166,7 @@ private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] { // in the resulting schema may differ from their ordering in the logical relation's // original schema val mergedSchema = requestedRootFields - .map { case RootField(field, _) => StructType(Array(field)) } + .map { case root: RootField => StructType(Array(root.field)) } .reduceLeft(_ merge _) val dataSchemaFieldNames = fileDataSchema.fieldNames.toSet val mergedDataSchema = @@ -199,6 +209,15 @@ private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] { case att: Attribute => RootField(StructField(att.name, att.dataType, att.nullable), derivedFromAtt = true) :: Nil case SelectedField(field) => RootField(field, derivedFromAtt = false) :: Nil + // Root field accesses by `IsNotNull` and `IsNull` are special cases as the expressions + // don't actually use any nested fields. These root field accesses might be excluded later + // if there are any nested fields accesses in the query plan. + case IsNotNull(SelectedField(field)) => + RootField(field, derivedFromAtt = false, contentAccessed = false) :: Nil + case IsNull(SelectedField(field)) => + RootField(field, derivedFromAtt = false, contentAccessed = false) :: Nil + case IsNotNull(_: Attribute) | IsNull(_: Attribute) => + expr.children.flatMap(getRootFields).map(_.copy(contentAccessed = false)) case _ => expr.children.flatMap(getRootFields) } @@ -250,8 +269,11 @@ private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] { } /** - * A "root" schema field (aka top-level, no-parent) and whether it was derived from - * an attribute or had a proper child. + * This represents a "root" schema field (aka top-level, no-parent). `field` is the + * `StructField` for field name and datatype. `derivedFromAtt` indicates whether it + * was derived from an attribute or had a proper child. `contentAccessed` means whether + * it was accessed with its content by the expressions refer it. */ - private case class RootField(field: StructField, derivedFromAtt: Boolean) + private case class RootField(field: StructField, derivedFromAtt: Boolean, + contentAccessed: Boolean = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala index eb99654fa78f5..7b132af4f6911 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala @@ -35,22 +35,29 @@ class ParquetSchemaPruningSuite with SchemaPruningTest with SharedSQLContext { case class FullName(first: String, middle: String, last: String) + case class Company(name: String, address: String) + case class Employer(id: Int, company: Company) case class Contact( id: Int, name: FullName, address: String, pets: Int, friends: Array[FullName] = Array.empty, - relatives: Map[String, FullName] = Map.empty) + relatives: Map[String, FullName] = Map.empty, + employer: Employer = null) val janeDoe = FullName("Jane", "X.", "Doe") val johnDoe = FullName("John", "Y.", "Doe") val susanSmith = FullName("Susan", "Z.", "Smith") + val employer = Employer(0, Company("abc", "123 Business Street")) + val employerWithNullCompany = Employer(1, null) + private val contacts = Contact(0, janeDoe, "123 Main Street", 1, friends = Array(susanSmith), - relatives = Map("brother" -> johnDoe)) :: - Contact(1, johnDoe, "321 Wall Street", 3, relatives = Map("sister" -> janeDoe)) :: Nil + relatives = Map("brother" -> johnDoe), employer = employer) :: + Contact(1, johnDoe, "321 Wall Street", 3, relatives = Map("sister" -> janeDoe), + employer = employerWithNullCompany) :: Nil case class Name(first: String, last: String) case class BriefContact(id: Int, name: Name, address: String) @@ -66,13 +73,14 @@ class ParquetSchemaPruningSuite pets: Int, friends: Array[FullName] = Array(), relatives: Map[String, FullName] = Map(), + employer: Employer = null, p: Int) case class BriefContactWithDataPartitionColumn(id: Int, name: Name, address: String, p: Int) private val contactsWithDataPartitionColumn = - contacts.map { case Contact(id, name, address, pets, friends, relatives) => - ContactWithDataPartitionColumn(id, name, address, pets, friends, relatives, 1) } + contacts.map { case Contact(id, name, address, pets, friends, relatives, employer) => + ContactWithDataPartitionColumn(id, name, address, pets, friends, relatives, employer, 1) } private val briefContactsWithDataPartitionColumn = briefContacts.map { case BriefContact(id, name, address) => BriefContactWithDataPartitionColumn(id, name, address, 2) } @@ -155,6 +163,60 @@ class ParquetSchemaPruningSuite Row(null) :: Row(null) :: Nil) } + testSchemaPruning("select a single complex field and in where clause") { + val query1 = sql("select name.first from contacts where name.first = 'Jane'") + checkScan(query1, "struct>") + checkAnswer(query1, Row("Jane") :: Nil) + + val query2 = sql("select name.first, name.last from contacts where name.first = 'Jane'") + checkScan(query2, "struct>") + checkAnswer(query2, Row("Jane", "Doe") :: Nil) + + val query3 = sql("select name.first from contacts " + + "where employer.company.name = 'abc' and p = 1") + checkScan(query3, "struct," + + "employer:struct>>") + checkAnswer(query3, Row("Jane") :: Nil) + + val query4 = sql("select name.first, employer.company.name from contacts " + + "where employer.company is not null and p = 1") + checkScan(query4, "struct," + + "employer:struct>>") + checkAnswer(query4, Row("Jane", "abc") :: Nil) + } + + testSchemaPruning("select nullable complex field and having is not null predicate") { + val query = sql("select employer.company from contacts " + + "where employer is not null and p = 1") + checkScan(query, "struct>>") + checkAnswer(query, Row(Row("abc", "123 Business Street")) :: Row(null) :: Nil) + } + + testSchemaPruning("select a single complex field and is null expression in project") { + val query = sql("select name.first, address is not null from contacts") + checkScan(query, "struct,address:string>") + checkAnswer(query.orderBy("id"), + Row("Jane", true) :: Row("John", true) :: Row("Janet", true) :: Row("Jim", true) :: Nil) + } + + testSchemaPruning("select a single complex field array and in clause") { + val query = sql("select friends.middle from contacts where friends.first[0] = 'Susan'") + checkScan(query, + "struct>>") + checkAnswer(query.orderBy("id"), + Row(Array("Z.")) :: Nil) + } + + testSchemaPruning("select a single complex field from a map entry and in clause") { + val query = + sql("select relatives[\"brother\"].middle from contacts " + + "where relatives[\"brother\"].first = 'John'") + checkScan(query, + "struct>>") + checkAnswer(query.orderBy("id"), + Row("Y.") :: Nil) + } + private def testSchemaPruning(testName: String)(testThunk: => Unit) { withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") { test(s"Spark vectorized reader - without partition data column - $testName") { @@ -238,10 +300,7 @@ class ParquetSchemaPruningSuite testMixedCasePruning("filter with different-case column names") { val query = sql("select id from mixedcase where Col2.b = 2") - // Pruning with filters is currently unsupported. As-is, the file reader will read the id column - // and the entire coL2 struct. Once pruning with filters has been implemented we can uncomment - // this line - // checkScan(query, "struct>") + checkScan(query, "struct>") checkAnswer(query.orderBy("id"), Row(1) :: Nil) } From ab25c967905ca0973fc2f30b8523246bb9244206 Mon Sep 17 00:00:00 2001 From: Michael Mior Date: Thu, 13 Sep 2018 09:45:25 +0800 Subject: [PATCH 1601/2461] [SPARK-23820][CORE] Enable use of long form of callsite in logs This is a rework of #21433 to address some concerns there. Closes #22398 from michaelmior/long-callsite2. Authored-by: Michael Mior Signed-off-by: Wenchen Fan --- .../org/apache/spark/internal/config/package.scala | 3 +++ .../main/scala/org/apache/spark/storage/RDDInfo.scala | 11 ++++++++++- docs/configuration.md | 7 +++++++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index bf0391cc9185b..8d827189ebb57 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -77,6 +77,9 @@ package object config { private[spark] val EVENT_LOG_OVERWRITE = ConfigBuilder("spark.eventLog.overwrite").booleanConf.createWithDefault(false) + private[spark] val EVENT_LOG_CALLSITE_LONG_FORM = + ConfigBuilder("spark.eventLog.longForm.enabled").booleanConf.createWithDefault(false) + private[spark] val EXECUTOR_CLASS_PATH = ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_CLASSPATH).stringConf.createOptional diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index e5abbf745cc41..19f86569c1e3c 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -17,7 +17,9 @@ package org.apache.spark.storage +import org.apache.spark.SparkEnv import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.config._ import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.util.Utils @@ -53,10 +55,17 @@ class RDDInfo( } private[spark] object RDDInfo { + private val callsiteLongForm = SparkEnv.get.conf.get(EVENT_LOG_CALLSITE_LONG_FORM) + def fromRdd(rdd: RDD[_]): RDDInfo = { val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd)) val parentIds = rdd.dependencies.map(_.rdd.id) + val callSite = if (callsiteLongForm) { + rdd.creationSite.longForm + } else { + rdd.creationSite.shortForm + } new RDDInfo(rdd.id, rddName, rdd.partitions.length, - rdd.getStorageLevel, parentIds, rdd.creationSite.shortForm, rdd.scope) + rdd.getStorageLevel, parentIds, callSite, rdd.scope) } } diff --git a/docs/configuration.md b/docs/configuration.md index 3a8d56776e9e8..782ccff667076 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -746,6 +746,13 @@ Apart from these, the following properties are also available, and may be useful *Warning*: This will increase the size of the event log considerably. + + spark.eventLog.longForm.enabled + false + + If true, use the long form of call sites in the event log. Otherwise use the short form. + + spark.eventLog.compress false From 083c9447671719e0bd67312e3d572f6160c06a4a Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 13 Sep 2018 09:51:49 +0800 Subject: [PATCH 1602/2461] [SPARK-25387][SQL] Fix for NPE caused by bad CSV input ## What changes were proposed in this pull request? The PR fixes NPE in `UnivocityParser` caused by malformed CSV input. In some cases, `uniVocity` parser can return `null` for bad input. In the PR, I propose to check result of parsing and not propagate NPE to upper layers. ## How was this patch tested? I added a test which reproduce the issue and tested by `CSVSuite`. Closes #22374 from MaxGekk/npe-on-bad-csv. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: Wenchen Fan --- .../datasources/csv/CSVDataSource.scala | 36 ++++++++++--------- .../datasources/csv/UnivocityParser.scala | 7 +++- .../execution/datasources/csv/CSVSuite.scala | 11 +++++- 3 files changed, 35 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 2b86054c0ffcb..e840ff1682502 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -240,23 +240,25 @@ object TextInputCSVDataSource extends CSVDataSource { sparkSession: SparkSession, csv: Dataset[String], maybeFirstLine: Option[String], - parsedOptions: CSVOptions): StructType = maybeFirstLine match { - case Some(firstLine) => - val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) - val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) - val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions) - val tokenRDD = sampled.rdd.mapPartitions { iter => - val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) - val linesWithoutHeader = - CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) - val parser = new CsvParser(parsedOptions.asParserSettings) - linesWithoutHeader.map(parser.parseLine) - } - CSVInferSchema.infer(tokenRDD, header, parsedOptions) - case None => - // If the first line could not be read, just return the empty schema. - StructType(Nil) + parsedOptions: CSVOptions): StructType = { + val csvParser = new CsvParser(parsedOptions.asParserSettings) + maybeFirstLine.map(csvParser.parseLine(_)) match { + case Some(firstRow) if firstRow != null => + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions) + val tokenRDD = sampled.rdd.mapPartitions { iter => + val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) + val linesWithoutHeader = + CSVUtils.filterHeaderLine(filteredLines, maybeFirstLine.get, parsedOptions) + val parser = new CsvParser(parsedOptions.asParserSettings) + linesWithoutHeader.map(parser.parseLine) + } + CSVInferSchema.infer(tokenRDD, header, parsedOptions) + case _ => + // If the first line could not be read, just return the empty schema. + StructType(Nil) + } } private def createBaseDataset( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index e15af425b2649..9088d43905e28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -216,7 +216,12 @@ class UnivocityParser( } private def convert(tokens: Array[String]): InternalRow = { - if (tokens.length != parsedSchema.length) { + if (tokens == null) { + throw BadRecordException( + () => getCurrentInput, + () => None, + new RuntimeException("Malformed CSV record")) + } else if (tokens.length != parsedSchema.length) { // If the number of tokens doesn't match the schema, we should treat it as a malformed record. // However, we still have chance to parse some of the tokens, by adding extra null tokens in // the tail if the number is smaller, or by dropping extra tokens if the number is larger. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 2b39a0b1f52ee..f70df0bcecde7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -34,7 +34,7 @@ import org.apache.log4j.{AppenderSkeleton, LogManager} import org.apache.log4j.spi.LoggingEvent import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} @@ -1811,4 +1811,13 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te checkCount(2) countForMalformedCSV(0, Seq("")) } + + test("SPARK-25387: bad input should not cause NPE") { + val schema = StructType(StructField("a", IntegerType) :: Nil) + val input = spark.createDataset(Seq("\u0000\u0000\u0001234")) + + checkAnswer(spark.read.schema(schema).csv(input), Row(null)) + checkAnswer(spark.read.option("multiLine", true).schema(schema).csv(input), Row(null)) + assert(spark.read.csv(input).collect().toSet == Set(Row())) + } } From 6dc5921e66d56885b95c07e56e687f9f6c1eaca7 Mon Sep 17 00:00:00 2001 From: LantaoJin Date: Thu, 13 Sep 2018 09:57:34 +0800 Subject: [PATCH 1603/2461] [SPARK-25357][SQL] Add metadata to SparkPlanInfo to dump more information like file path to event log ## What changes were proposed in this pull request? Field metadata removed from SparkPlanInfo in #18600 . Corresponding, many meta data was also removed from event SparkListenerSQLExecutionStart in Spark event log. If we want to analyze event log to get all input paths, we couldn't get them. Instead, simpleString of SparkPlanInfo JSON only display 100 characters, it won't help. Before 2.3, the fragment of SparkListenerSQLExecutionStart in event log looks like below (It contains the metadata field which has the intact information): >{"Event":"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart", Location: InMemoryFileIndex[hdfs://cluster1/sys/edw/test1/test2/test3/test4..., "metadata": {"Location": "InMemoryFileIndex[hdfs://cluster1/sys/edw/test1/test2/test3/test4/test5/snapshot/dt=20180904]","ReadSchema":"struct"} After #18600, metadata field was removed. >{"Event":"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart", Location: InMemoryFileIndex[hdfs://cluster1/sys/edw/test1/test2/test3/test4..., So I add this field back to SparkPlanInfo class. Then it will log out the meta data to event log. Intact information in event log is very useful for offline job analysis. ## How was this patch tested? Unit test Closes #22353 from LantaoJin/SPARK-25357. Authored-by: LantaoJin Signed-off-by: Wenchen Fan --- .../apache/spark/sql/execution/SparkPlanInfo.scala | 12 ++++++++---- .../spark/sql/execution/SQLJsonProtocolSuite.scala | 2 +- .../apache/spark/sql/execution/SparkPlanSuite.scala | 8 ++++++++ 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index 2a2315896831c..59ffd16381116 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution -import com.fasterxml.jackson.annotation.JsonIgnoreProperties - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.metric.SQLMetricInfo @@ -28,11 +26,11 @@ import org.apache.spark.sql.execution.metric.SQLMetricInfo * Stores information about a SQL SparkPlan. */ @DeveloperApi -@JsonIgnoreProperties(Array("metadata")) // The metadata field was removed in Spark 2.3. class SparkPlanInfo( val nodeName: String, val simpleString: String, val children: Seq[SparkPlanInfo], + val metadata: Map[String, String], val metrics: Seq[SQLMetricInfo]) { override def hashCode(): Int = { @@ -59,6 +57,12 @@ private[execution] object SparkPlanInfo { new SQLMetricInfo(metric.name.getOrElse(key), metric.id, metric.metricType) } - new SparkPlanInfo(plan.nodeName, plan.simpleString, children.map(fromSparkPlan), metrics) + // dump the file scan metadata (e.g file path) to event log + val metadata = plan match { + case fileScan: FileSourceScanExec => fileScan.metadata + case _ => Map[String, String]() + } + new SparkPlanInfo(plan.nodeName, plan.simpleString, children.map(fromSparkPlan), + metadata, metrics) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala index c2e62b987e0cc..08e40e28d3d57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala @@ -46,7 +46,7 @@ class SQLJsonProtocolSuite extends SparkFunSuite { """.stripMargin val reconstructedEvent = JsonProtocol.sparkEventFromJson(parse(SQLExecutionStartJsonString)) val expectedEvent = SparkListenerSQLExecutionStart(0, "test desc", "test detail", "test plan", - new SparkPlanInfo("TestNode", "test string", Nil, Nil), 0) + new SparkPlanInfo("TestNode", "test string", Nil, Map(), Nil), 0) assert(reconstructedEvent == expectedEvent) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala index 34dc6f37c0e4d..47ff372992b91 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala @@ -50,4 +50,12 @@ class SparkPlanSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-25357 SparkPlanInfo of FileScan contains nonEmpty metadata") { + withTempPath { path => + spark.range(5).write.parquet(path.getAbsolutePath) + val f = spark.read.parquet(path.getAbsolutePath) + assert(SparkPlanInfo.fromSparkPlan(f.queryExecution.sparkPlan).metadata.nonEmpty) + } + } } From 08c76b5d39127ae207d9d1fff99c2551e6ce2581 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 13 Sep 2018 11:19:43 +0800 Subject: [PATCH 1604/2461] [SPARK-25238][PYTHON] lint-python: Fix W605 warnings for pycodestyle 2.4 (This change is a subset of the changes needed for the JIRA; see https://github.com/apache/spark/pull/22231) ## What changes were proposed in this pull request? Use raw strings and simpler regex syntax consistently in Python, which also avoids warnings from pycodestyle about accidentally relying Python's non-escaping of non-reserved chars in normal strings. Also, fix a few long lines. ## How was this patch tested? Existing tests, and some manual double-checking of the behavior of regexes in Python 2/3 to be sure. Closes #22400 from srowen/SPARK-25238.2. Authored-by: Sean Owen Signed-off-by: hyukjinkwon --- dev/create-release/generate-contributors.py | 10 +++++----- dev/create-release/releaseutils.py | 2 +- dev/merge_spark_pr.py | 4 ++-- dev/run-tests-jenkins.py | 3 ++- dev/run-tests.py | 2 +- python/pyspark/ml/classification.py | 4 ++-- python/pyspark/ml/clustering.py | 16 ++++++++-------- python/pyspark/ml/feature.py | 16 ++++++++-------- python/pyspark/ml/fpm.py | 2 +- python/pyspark/ml/regression.py | 20 ++++++++++---------- python/pyspark/mllib/clustering.py | 2 +- python/pyspark/mllib/evaluation.py | 4 ++-- python/pyspark/mllib/feature.py | 2 +- python/pyspark/rdd.py | 2 +- python/pyspark/shell.py | 2 +- python/pyspark/sql/functions.py | 14 ++++++++------ python/pyspark/sql/readwriter.py | 12 ++++++------ python/pyspark/sql/streaming.py | 2 +- python/pyspark/sql/types.py | 2 +- python/pyspark/storagelevel.py | 4 ++-- python/pyspark/util.py | 2 +- python/run-tests.py | 2 +- 22 files changed, 66 insertions(+), 63 deletions(-) diff --git a/dev/create-release/generate-contributors.py b/dev/create-release/generate-contributors.py index 131d81c8a75cf..d9135173419ae 100755 --- a/dev/create-release/generate-contributors.py +++ b/dev/create-release/generate-contributors.py @@ -67,7 +67,7 @@ print("Release tag: %s" % RELEASE_TAG) print("Previous release tag: %s" % PREVIOUS_RELEASE_TAG) print("Number of commits in this range: %s" % len(new_commits)) -print +print("") def print_indented(_list): @@ -88,10 +88,10 @@ def print_indented(_list): def is_release(commit_title): - return re.findall("\[release\]", commit_title.lower()) or \ - "preparing spark release" in commit_title.lower() or \ - "preparing development version" in commit_title.lower() or \ - "CHANGES.txt" in commit_title + return ("[release]" in commit_title.lower() or + "preparing spark release" in commit_title.lower() or + "preparing development version" in commit_title.lower() or + "CHANGES.txt" in commit_title) def is_maintenance(commit_title): diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index 8cc990d871842..f273b337fdb4e 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -235,7 +235,7 @@ def translate_component(component, commit_hash, warnings): # Parse components in the commit message # The returned components are already filtered and translated def find_components(commit, commit_hash): - components = re.findall("\[\w*\]", commit.lower()) + components = re.findall(r"\[\w*\]", commit.lower()) components = [translate_component(c, commit_hash) for c in components if c in known_components] return components diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 81daa909e019c..cca6f405e89ac 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -274,7 +274,7 @@ def resolve_jira_issue(merge_branches, comment, default_jira_id=""): versions = sorted(versions, key=lambda x: x.name, reverse=True) versions = filter(lambda x: x.raw['released'] is False, versions) # Consider only x.y.z versions - versions = filter(lambda x: re.match('\d+\.\d+\.\d+', x.name), versions) + versions = filter(lambda x: re.match(r'\d+\.\d+\.\d+', x.name), versions) default_fix_versions = map(lambda x: fix_version_from_branch(x, versions).name, merge_branches) for v in default_fix_versions: @@ -403,7 +403,7 @@ def standardize_jira_ref(text): # Extract spark component(s): # Look for alphanumeric chars, spaces, dashes, periods, and/or commas - pattern = re.compile(r'(\[[\w\s,-\.]+\])', re.IGNORECASE) + pattern = re.compile(r'(\[[\w\s,.-]+\])', re.IGNORECASE) for component in pattern.findall(text): components.append(component.upper()) text = text.replace(component, '') diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index e6fe3b82ed202..6e943898ffed9 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -115,7 +115,8 @@ def run_tests(tests_timeout): os.path.join(SPARK_HOME, 'dev', 'run-tests')]).wait() failure_note_by_errcode = { - 1: 'executing the `dev/run-tests` script', # error to denote run-tests script failures + # error to denote run-tests script failures: + 1: 'executing the `dev/run-tests` script', # noqa: W605 ERROR_CODES["BLOCK_GENERAL"]: 'some tests', ERROR_CODES["BLOCK_RAT"]: 'RAT tests', ERROR_CODES["BLOCK_SCALA_STYLE"]: 'Scala style tests', diff --git a/dev/run-tests.py b/dev/run-tests.py index d9d3789ac1255..f534637b80d6b 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -169,7 +169,7 @@ def determine_java_version(java_exe): # find raw version string, eg 'java version "1.8.0_25"' raw_version_str = next(x for x in raw_output_lines if " version " in x) - match = re.search('(\d+)\.(\d+)\.(\d+)', raw_version_str) + match = re.search(r'(\d+)\.(\d+)\.(\d+)', raw_version_str) major = int(match.group(1)) minor = int(match.group(2)) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index d5963f4f7042c..ce028512357f2 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -773,8 +773,8 @@ def roc(self): which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. - .. seealso:: `Wikipedia reference \ - `_ + .. seealso:: `Wikipedia reference + `_ .. note:: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. This will change in later Spark diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index ab449bc3f8f51..5ef4e765ea4e1 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -1202,21 +1202,21 @@ class PowerIterationClustering(HasMaxIter, HasWeightCol, JavaParams, JavaMLReada .. note:: Experimental Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by - Lin and Cohen. From the abstract: + `Lin and Cohen `_. From the abstract: PIC finds a very low-dimensional embedding of a dataset using truncated power iteration on a normalized pair-wise similarity matrix of the data. This class is not yet an Estimator/Transformer, use :py:func:`assignClusters` method to run the PowerIterationClustering algorithm. - .. seealso:: `Wikipedia on Spectral clustering \ - `_ + .. seealso:: `Wikipedia on Spectral clustering + `_ - >>> data = [(1, 0, 0.5), \ - (2, 0, 0.5), (2, 1, 0.7), \ - (3, 0, 0.5), (3, 1, 0.7), (3, 2, 0.9), \ - (4, 0, 0.5), (4, 1, 0.7), (4, 2, 0.9), (4, 3, 1.1), \ - (5, 0, 0.5), (5, 1, 0.7), (5, 2, 0.9), (5, 3, 1.1), (5, 4, 1.3)] + >>> data = [(1, 0, 0.5), + ... (2, 0, 0.5), (2, 1, 0.7), + ... (3, 0, 0.5), (3, 1, 0.7), (3, 2, 0.9), + ... (4, 0, 0.5), (4, 1, 0.7), (4, 2, 0.9), (4, 3, 1.1), + ... (5, 0, 0.5), (5, 1, 0.7), (5, 2, 0.9), (5, 3, 1.1), (5, 4, 1.3)] >>> df = spark.createDataFrame(data).toDF("src", "dst", "weight") >>> pic = PowerIterationClustering(k=2, maxIter=40, weightCol="weight") >>> assignments = pic.assignClusters(df) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 760aa82168f5a..eccb7acae5b98 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -207,8 +207,8 @@ class BucketedRandomProjectionLSH(JavaEstimator, LSHParams, HasInputCol, HasOutp distance space. The output will be vectors of configurable dimension. Hash values in the same dimension are calculated by the same hash function. - .. seealso:: `Stable Distributions \ - `_ + .. seealso:: `Stable Distributions + `_ .. seealso:: `Hashing for Similarity Search: A Survey `_ >>> from pyspark.ml.linalg import Vectors @@ -303,7 +303,7 @@ def _create_model(self, java_model): class BucketedRandomProjectionLSHModel(LSHModel, JavaMLReadable, JavaMLWritable): - """ + r""" .. note:: Experimental Model fitted by :py:class:`BucketedRandomProjectionLSH`, where multiple random vectors are @@ -653,8 +653,8 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWrit The return vector is scaled such that the transform matrix is unitary (aka scaled DCT-II). - .. seealso:: `More information on Wikipedia \ - `_. + .. seealso:: `More information on Wikipedia + `_. >>> from pyspark.ml.linalg import Vectors >>> df1 = spark.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], ["vec"]) @@ -1353,7 +1353,7 @@ def _create_model(self, java_model): class MinHashLSHModel(LSHModel, JavaMLReadable, JavaMLWritable): - """ + r""" .. note:: Experimental Model produced by :py:class:`MinHashLSH`, where where multiple hash functions are stored. Each @@ -1362,8 +1362,8 @@ class MinHashLSHModel(LSHModel, JavaMLReadable, JavaMLWritable): :math:`h_i(x) = ((x \cdot a_i + b_i) \mod prime)` This hash family is approximately min-wise independent according to the reference. - .. seealso:: Tom Bohman, Colin Cooper, and Alan Frieze. "Min-wise independent linear \ - permutations." Electronic Journal of Combinatorics 7 (2000): R26. + .. seealso:: Tom Bohman, Colin Cooper, and Alan Frieze. "Min-wise independent linear + permutations." Electronic Journal of Combinatorics 7 (2000): R26. .. versionadded:: 2.2.0 """ diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index c2b29b73460ff..886ad8409ca66 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -158,7 +158,7 @@ class FPGrowth(JavaEstimator, HasItemsCol, HasPredictionCol, HasMinSupport, HasNumPartitions, HasMinConfidence, JavaMLWritable, JavaMLReadable): - """ + r""" .. note:: Experimental A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 513ca5a9df85e..98f4361351847 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -188,8 +188,8 @@ def intercept(self): @property @since("2.3.0") def scale(self): - """ - The value by which \|y - X'w\| is scaled down when loss is "huber", otherwise 1.0. + r""" + The value by which :math:`\|y - X'w\|` is scaled down when loss is "huber", otherwise 1.0. """ return self._call_java("scale") @@ -279,12 +279,12 @@ def featuresCol(self): @property @since("2.0.0") def explainedVariance(self): - """ + r""" Returns the explained variance regression score. - explainedVariance = 1 - variance(y - \hat{y}) / variance(y) + explainedVariance = :math:`1 - \frac{variance(y - \hat{y})}{variance(y)}` - .. seealso:: `Wikipedia explain variation \ - `_ + .. seealso:: `Wikipedia explain variation + `_ .. note:: This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`. This will change in later Spark @@ -339,8 +339,8 @@ def r2(self): """ Returns R^2, the coefficient of determination. - .. seealso:: `Wikipedia coefficient of determination \ - `_ + .. seealso:: `Wikipedia coefficient of determination + `_ .. note:: This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`. This will change in later Spark @@ -354,8 +354,8 @@ def r2adj(self): """ Returns Adjusted R^2, the adjusted coefficient of determination. - .. seealso:: `Wikipedia coefficient of determination, Adjusted R^2 \ - `_ + .. seealso:: `Wikipedia coefficient of determination, Adjusted R^2 + `_ .. note:: This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`. This will change in later Spark versions. diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index b09469b9f5c2d..b1a8af6bcc094 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -647,7 +647,7 @@ class PowerIterationClustering(object): @classmethod @since('1.5.0') def train(cls, rdd, k, maxIterations=100, initMode="random"): - """ + r""" :param rdd: An RDD of (i, j, s\ :sub:`ij`\) tuples representing the affinity matrix, which is the matrix A in the PIC paper. The diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 6c65da58e4e2b..0bb0ca37c1ab6 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -117,9 +117,9 @@ def __init__(self, predictionAndObservations): @property @since('1.4.0') def explainedVariance(self): - """ + r""" Returns the explained variance regression score. - explainedVariance = 1 - variance(y - \hat{y}) / variance(y) + explainedVariance = :math:`1 - \frac{variance(y - \hat{y})}{variance(y)}` """ return self.call("explainedVariance") diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 40ecd2e0ff4be..6d7d4d61db043 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -59,7 +59,7 @@ def transform(self, vector): class Normalizer(VectorTransformer): - """ + r""" Normalizes samples individually to unit L\ :sup:`p`\ norm For any 1 <= `p` < float('inf'), normalizes samples using diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index b317156885e51..ccf39e1ffbe96 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2399,7 +2399,7 @@ def barrier(self): :return: an :class:`RDDBarrier` instance that provides actions within a barrier stage. .. seealso:: :class:`BarrierTaskContext` - .. seealso:: `SPIP: Barrier Execution Mode \ + .. seealso:: `SPIP: Barrier Execution Mode `_ .. seealso:: `Design Doc `_ diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 472c3cd4452f0..65e3bdbc05ce8 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -54,7 +54,7 @@ sqlContext = spark._wrapped sqlCtx = sqlContext -print("""Welcome to +print(r"""Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 81f35f54aa54d..e288ec818b404 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -283,7 +283,8 @@ def approxCountDistinct(col, rsd=None): @since(2.1) def approx_count_distinct(col, rsd=None): - """Aggregate function: returns a new :class:`Column` for approximate distinct count of column `col`. + """Aggregate function: returns a new :class:`Column` for approximate distinct count of + column `col`. :param rsd: maximum estimation error allowed (default = 0.05). For rsd < 0.01, it is more efficient to use :func:`countDistinct` @@ -346,7 +347,8 @@ def coalesce(*cols): @since(1.6) def corr(col1, col2): - """Returns a new :class:`Column` for the Pearson Correlation Coefficient for ``col1`` and ``col2``. + """Returns a new :class:`Column` for the Pearson Correlation Coefficient for ``col1`` + and ``col2``. >>> a = range(20) >>> b = [2 * x for x in range(20)] @@ -1688,14 +1690,14 @@ def split(str, pattern): @ignore_unicode_prefix @since(1.5) def regexp_extract(str, pattern, idx): - """Extract a specific group matched by a Java regex, from the specified string column. + r"""Extract a specific group matched by a Java regex, from the specified string column. If the regex did not match, or the specified group did not match, an empty string is returned. >>> df = spark.createDataFrame([('100-200',)], ['str']) - >>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect() + >>> df.select(regexp_extract('str', r'(\d+)-(\d+)', 1).alias('d')).collect() [Row(d=u'100')] >>> df = spark.createDataFrame([('foo',)], ['str']) - >>> df.select(regexp_extract('str', '(\d+)', 1).alias('d')).collect() + >>> df.select(regexp_extract('str', r'(\d+)', 1).alias('d')).collect() [Row(d=u'')] >>> df = spark.createDataFrame([('aaaac',)], ['str']) >>> df.select(regexp_extract('str', '(a+)(b)?(c)', 2).alias('d')).collect() @@ -1712,7 +1714,7 @@ def regexp_replace(str, pattern, replacement): """Replace all substrings of the specified string value that match regexp with rep. >>> df = spark.createDataFrame([('100-200',)], ['str']) - >>> df.select(regexp_replace('str', '(\\d+)', '--').alias('d')).collect() + >>> df.select(regexp_replace('str', r'(\d+)', '--').alias('d')).collect() [Row(d=u'-----')] """ sc = SparkContext._active_spark_context diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 3ca5d548ae7d6..690b13072244b 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -350,7 +350,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, samplingRatio=None, enforceSchema=None, emptyValue=None): - """Loads a CSV file and returns the result as a :class:`DataFrame`. + r"""Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if ``inferSchema`` is enabled. To avoid going through the entire data once, disable @@ -519,8 +519,8 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPar If both ``column`` and ``predicates`` are specified, ``column`` will be used. - .. note:: Don't create too many partitions in parallel on a large cluster; \ - otherwise Spark might crash your external database systems. + .. note:: Don't create too many partitions in parallel on a large cluster; + otherwise Spark might crash your external database systems. :param url: a JDBC URL of the form ``jdbc:subprotocol:subname`` :param table: the name of the table @@ -862,7 +862,7 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None, timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, charToEscapeQuoteEscaping=None, encoding=None, emptyValue=None): - """Saves the content of the :class:`DataFrame` in CSV format at the specified path. + r"""Saves the content of the :class:`DataFrame` in CSV format at the specified path. :param path: the path in any Hadoop supported file system :param mode: specifies the behavior of the save operation when data already exists. @@ -962,8 +962,8 @@ def orc(self, path, mode=None, partitionBy=None, compression=None): def jdbc(self, url, table, mode=None, properties=None): """Saves the content of the :class:`DataFrame` to an external database table via JDBC. - .. note:: Don't create too many partitions in parallel on a large cluster; \ - otherwise Spark might crash your external database systems. + .. note:: Don't create too many partitions in parallel on a large cluster; + otherwise Spark might crash your external database systems. :param url: a JDBC URL of the form ``jdbc:subprotocol:subname`` :param table: Name of the table in the external database. diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 522900bf6684c..b18453b2a4f96 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -565,7 +565,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, enforceSchema=None, emptyValue=None): - """Loads a CSV file stream and returns the result as a :class:`DataFrame`. + r"""Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if ``inferSchema`` is enabled. To avoid going through the entire data once, disable diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index ce1d004c6c8ff..1d24c40e5858e 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -752,7 +752,7 @@ def __eq__(self, other): for v in [ArrayType, MapType, StructType]) -_FIXED_DECIMAL = re.compile("decimal\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*\\)") +_FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)") def _parse_datatype_string(s): diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py index ef012d27cb22f..7f29646c07432 100644 --- a/python/pyspark/storagelevel.py +++ b/python/pyspark/storagelevel.py @@ -58,8 +58,8 @@ def __str__(self): StorageLevel.OFF_HEAP = StorageLevel(True, True, True, False, 1) """ -.. note:: The following four storage level constants are deprecated in 2.0, since the records \ -will always be serialized in Python. +.. note:: The following four storage level constants are deprecated in 2.0, since the records + will always be serialized in Python. """ StorageLevel.MEMORY_ONLY_SER = StorageLevel.MEMORY_ONLY """.. note:: Deprecated in 2.0, use ``StorageLevel.MEMORY_ONLY`` instead.""" diff --git a/python/pyspark/util.py b/python/pyspark/util.py index f015542c8799d..f906f49595438 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -80,7 +80,7 @@ def majorMinorVersion(sparkVersion): (2, 3) """ - m = re.search('^(\d+)\.(\d+)(\..*)?$', sparkVersion) + m = re.search(r'^(\d+)\.(\d+)(\..*)?$', sparkVersion) if m is not None: return (int(m.group(1)), int(m.group(2))) else: diff --git a/python/run-tests.py b/python/run-tests.py index 4c90926cfa350..ccbdfac3f3850 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -138,7 +138,7 @@ def run_individual_python_test(target_dir, test_name, pyspark_python): # 2 (or --verbose option is enabled). decoded_lines = map(lambda line: line.decode(), iter(per_test_output)) skipped_tests = list(filter( - lambda line: re.search('test_.* \(pyspark\..*\) ... skipped ', line), + lambda line: re.search(r'test_.* \(pyspark\..*\) ... skipped ', line), decoded_lines)) skipped_counts = len(skipped_tests) if skipped_counts > 0: From 8b702e1e0aba1d3e4b0aa582f20cf99f80a44a09 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Wed, 12 Sep 2018 21:56:09 -0700 Subject: [PATCH 1605/2461] [SPARK-25415][SQL] Make plan change log in RuleExecutor configurable by SQLConf ## What changes were proposed in this pull request? In RuleExecutor, after applying a rule, if the plan has changed, the before and after plan will be logged using level "trace". At times, however, such information can be very helpful for debugging. Hence, making the log level configurable in SQLConf would allow users to turn on the plan change log independently and save the trouble of tweaking log4j settings. Meanwhile, filtering plan change log for specific rules can also be very useful. So this PR adds two SQL configurations: 1. spark.sql.optimizer.planChangeLog.level - set a specific log level for logging plan changes after a rule is applied. 2. spark.sql.optimizer.planChangeLog.rules - enable plan change logging only for a set of specified rules, separated by commas. ## How was this patch tested? Added UT. Closes #22406 from maryannxue/spark-25415. Authored-by: maryannxue Signed-off-by: gatorsmile --- .../sql/catalyst/rules/RuleExecutor.scala | 33 +++- .../apache/spark/sql/internal/SQLConf.scala | 24 +++ .../optimizer/OptimizerLoggingSuite.scala | 148 ++++++++++++++++++ 3 files changed, 200 insertions(+), 5 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerLoggingSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index dccb44ddebfa4..183be5a027ec5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils object RuleExecutor { @@ -72,6 +73,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { def execute(plan: TreeType): TreeType = { var curPlan = plan val queryExecutionMetrics = RuleExecutor.queryExecutionMeter + val planChangeLogger = new PlanChangeLogger() batches.foreach { batch => val batchStartPlan = curPlan @@ -90,11 +92,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { if (!result.fastEquals(plan)) { queryExecutionMetrics.incNumEffectiveExecution(rule.ruleName) queryExecutionMetrics.incTimeEffectiveExecutionBy(rule.ruleName, runTime) - logTrace( - s""" - |=== Applying Rule ${rule.ruleName} === - |${sideBySide(plan.treeString, result.treeString).mkString("\n")} - """.stripMargin) + planChangeLogger.log(rule.ruleName, plan, result) } queryExecutionMetrics.incExecutionTimeBy(rule.ruleName, runTime) queryExecutionMetrics.incNumExecution(rule.ruleName) @@ -143,4 +141,29 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { curPlan } + + private class PlanChangeLogger { + + private val logLevel = SQLConf.get.optimizerPlanChangeLogLevel.toUpperCase + + private val logRules = SQLConf.get.optimizerPlanChangeRules.map(Utils.stringToSeq) + + def log(ruleName: String, oldPlan: TreeType, newPlan: TreeType): Unit = { + if (logRules.isEmpty || logRules.get.contains(ruleName)) { + lazy val message = + s""" + |=== Applying Rule ${ruleName} === + |${sideBySide(oldPlan.treeString, newPlan.treeString).mkString("\n")} + """.stripMargin + logLevel match { + case "TRACE" => logTrace(message) + case "DEBUG" => logDebug(message) + case "INFO" => logInfo(message) + case "WARN" => logWarning(message) + case "ERROR" => logError(message) + case _ => logTrace(message) + } + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 738d8fee891d1..4928560eacb1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -171,6 +171,26 @@ object SQLConf { .intConf .createWithDefault(10) + val OPTIMIZER_PLAN_CHANGE_LOG_LEVEL = buildConf("spark.sql.optimizer.planChangeLog.level") + .internal() + .doc("Configures the log level for logging the change from the original plan to the new " + + "plan after a rule is applied. The value can be 'trace', 'debug', 'info', 'warn', or " + + "'error'. The default log level is 'trace'.") + .stringConf + .checkValue( + str => Set("TRACE", "DEBUG", "INFO", "WARN", "ERROR").contains(str.toUpperCase), + "Invalid value for 'spark.sql.optimizer.planChangeLog.level'. Valid values are " + + "'trace', 'debug', 'info', 'warn' and 'error'.") + .createWithDefault("trace") + + val OPTIMIZER_PLAN_CHANGE_LOG_RULES = buildConf("spark.sql.optimizer.planChangeLog.rules") + .internal() + .doc("If this configuration is set, the optimizer will only log plan changes caused by " + + "applying the rules specified in this configuration. The value can be a list of rule " + + "names separated by comma.") + .stringConf + .createOptional + val COMPRESS_CACHED = buildConf("spark.sql.inMemoryColumnarStorage.compressed") .doc("When set to true Spark SQL will automatically select a compression codec for each " + "column based on statistics of the data.") @@ -1570,6 +1590,10 @@ class SQLConf extends Serializable with Logging { def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD) + def optimizerPlanChangeLogLevel: String = getConf(OPTIMIZER_PLAN_CHANGE_LOG_LEVEL) + + def optimizerPlanChangeRules: Option[String] = getConf(OPTIMIZER_PLAN_CHANGE_LOG_RULES) + def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS) def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerLoggingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerLoggingSuite.scala new file mode 100644 index 0000000000000..915f408089fe9 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerLoggingSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.collection.mutable.ArrayBuffer + +import org.apache.log4j.{Appender, AppenderSkeleton, Level, Logger} +import org.apache.log4j.spi.LoggingEvent + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf + +class OptimizerLoggingSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("Optimizer Batch", FixedPoint(100), + PushDownPredicate, + ColumnPruning, + CollapseProject) :: Nil + } + + class MockAppender extends AppenderSkeleton { + val loggingEvents = new ArrayBuffer[LoggingEvent]() + + override def append(loggingEvent: LoggingEvent): Unit = { + if (loggingEvent.getRenderedMessage().contains("Applying Rule")) { + loggingEvents.append(loggingEvent) + } + } + + override def close(): Unit = {} + override def requiresLayout(): Boolean = false + } + + private def withLogLevelAndAppender(level: Level, appender: Appender)(f: => Unit): Unit = { + val logger = Logger.getLogger(Optimize.getClass.getName.dropRight(1)) + val restoreLevel = logger.getLevel + logger.setLevel(level) + logger.addAppender(appender) + try f finally { + logger.setLevel(restoreLevel) + logger.removeAppender(appender) + } + } + + private def verifyLog(expectedLevel: Level, expectedRules: Seq[String]): Unit = { + val logAppender = new MockAppender() + withLogLevelAndAppender(Level.TRACE, logAppender) { + val input = LocalRelation('a.int, 'b.string, 'c.double) + val query = input.select('a, 'b).select('a).where('a > 1).analyze + val expected = input.where('a > 1).select('a).analyze + comparePlans(Optimize.execute(query), expected) + } + val logMessages = logAppender.loggingEvents.map(_.getRenderedMessage) + assert(expectedRules.forall(rule => logMessages.exists(_.contains(rule)))) + assert(logAppender.loggingEvents.forall(_.getLevel == expectedLevel)) + } + + test("test log level") { + val levels = Seq( + "TRACE" -> Level.TRACE, + "trace" -> Level.TRACE, + "DEBUG" -> Level.DEBUG, + "debug" -> Level.DEBUG, + "INFO" -> Level.INFO, + "info" -> Level.INFO, + "WARN" -> Level.WARN, + "warn" -> Level.WARN, + "ERROR" -> Level.ERROR, + "error" -> Level.ERROR, + "deBUG" -> Level.DEBUG) + + levels.foreach { level => + withSQLConf(SQLConf.OPTIMIZER_PLAN_CHANGE_LOG_LEVEL.key -> level._1) { + verifyLog( + level._2, + Seq( + PushDownPredicate.ruleName, + ColumnPruning.ruleName, + CollapseProject.ruleName)) + } + } + } + + test("test invalid log level conf") { + val levels = Seq( + "", + "*d_", + "infoo") + + levels.foreach { level => + val error = intercept[IllegalArgumentException] { + withSQLConf(SQLConf.OPTIMIZER_PLAN_CHANGE_LOG_LEVEL.key -> level) {} + } + assert(error.getMessage.contains( + "Invalid value for 'spark.sql.optimizer.planChangeLog.level'.")) + } + } + + test("test log rules") { + val rulesSeq = Seq( + Seq(PushDownPredicate.ruleName, + ColumnPruning.ruleName, + CollapseProject.ruleName).reduce(_ + "," + _) -> + Seq(PushDownPredicate.ruleName, + ColumnPruning.ruleName, + CollapseProject.ruleName), + Seq(PushDownPredicate.ruleName, + ColumnPruning.ruleName).reduce(_ + "," + _) -> + Seq(PushDownPredicate.ruleName, + ColumnPruning.ruleName), + CollapseProject.ruleName -> + Seq(CollapseProject.ruleName), + Seq(ColumnPruning.ruleName, + "DummyRule").reduce(_ + "," + _) -> + Seq(ColumnPruning.ruleName), + "DummyRule" -> Seq(), + "" -> Seq() + ) + + rulesSeq.foreach { case (rulesConf, expectedRules) => + withSQLConf( + SQLConf.OPTIMIZER_PLAN_CHANGE_LOG_RULES.key -> rulesConf, + SQLConf.OPTIMIZER_PLAN_CHANGE_LOG_LEVEL.key -> "INFO") { + verifyLog(Level.INFO, expectedRules) + } + } + } +} From 3e75a9fa24f8629d068b5fbbc7356ce2603fa58d Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Wed, 12 Sep 2018 22:02:59 -0700 Subject: [PATCH 1606/2461] [SPARK-25295][K8S] Fix executor names collision ## What changes were proposed in this pull request? Fixes the collision issue with spark executor names in client mode, see SPARK-25295 for the details. It follows the cluster name convention as app-name will be used as the prefix and if that is not defined we use "spark" as the default prefix. Eg. `spark-pi-1536781360723-exec-1` where spark-pi is the name of the app passed at the config side or transformed if it contains illegal characters. Also fixes the issue with spark app name having spaces in cluster mode. If you run the Spark Pi test in client mode it passes. The tricky part is the user may set the app name: https://github.com/apache/spark/blob/3030b82c89d3e45a2e361c469fbc667a1e43b854/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala#L30 If i do: ``` ./bin/spark-submit ... --deploy-mode cluster --name "spark pi" ... ``` it will fail as the app name is used for the prefix of driver's pod name and it cannot have spaces (according to k8s conventions). ## How was this patch tested? Manually by running spark job in client mode. To reproduce do: ``` kubectl create -f service.yaml kubectl create -f pod.yaml ``` service.yaml : ``` kind: Service apiVersion: v1 metadata: name: spark-test-app-1-svc spec: clusterIP: None selector: spark-app-selector: spark-test-app-1 ports: - protocol: TCP name: driver-port port: 7077 targetPort: 7077 - protocol: TCP name: block-manager port: 10000 targetPort: 10000 ``` pod.yaml: ``` apiVersion: v1 kind: Pod metadata: name: spark-test-app-1 labels: spark-app-selector: spark-test-app-1 spec: containers: - name: spark-test image: skonto/spark:k8s-client-fix imagePullPolicy: Always command: - 'sh' - '-c' - "/opt/spark/bin/spark-submit --verbose --master k8s://https://kubernetes.default.svc --deploy-mode client --class org.apache.spark.examples.SparkPi --conf spark.app.name=spark --conf spark.executor.instances=1 --conf spark.kubernetes.container.image=skonto/spark:k8s-client-fix --conf spark.kubernetes.container.image.pullPolicy=Always --conf spark.kubernetes.authenticate.oauthTokenFile=/var/run/secrets/kubernetes.io/serviceaccount/token --conf spark.kubernetes.authenticate.caCertFile=/var/run/secrets/kubernetes.io/serviceaccount/ca.crt --conf spark.executor.memory=500m --conf spark.executor.cores=1 --conf spark.executor.instances=1 --conf spark.driver.host=spark-test-app-1-svc.default.svc --conf spark.driver.port=7077 --conf spark.driver.blockManager.port=10000 local:///opt/spark/examples/jars/spark-examples_2.11-2.4.0-SNAPSHOT.jar 1000000" ``` Closes #22405 from skonto/fix-k8s-client-mode-executor-names. Authored-by: Stavros Kontopoulos Signed-off-by: Yinan Li --- .../spark/deploy/k8s/KubernetesConf.scala | 13 +++++++++++- .../submit/KubernetesClientApplication.scala | 21 +++++++++++++++---- .../k8s/ExecutorPodsAllocatorSuite.scala | 16 +++++++++++--- 3 files changed, 42 insertions(+), 8 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index 3aa35d419073f..cae6e7d5ad518 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -24,6 +24,7 @@ import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit._ +import org.apache.spark.deploy.k8s.submit.KubernetesClientApplication._ import org.apache.spark.internal.config.ConfigEntry @@ -220,10 +221,20 @@ private[spark] object KubernetesConf { val executorVolumes = KubernetesVolumeUtils.parseVolumesWithPrefix( sparkConf, KUBERNETES_EXECUTOR_VOLUMES_PREFIX).map(_.get) + // If no prefix is defined then we are in pure client mode + // (not the one used by cluster mode inside the container) + val appResourceNamePrefix = { + if (sparkConf.getOption(KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key).isEmpty) { + getResourceNamePrefix(getAppName(sparkConf)) + } else { + sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX) + } + } + KubernetesConf( sparkConf.clone(), KubernetesExecutorSpecificConf(executorId, driverPod), - sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX), + appResourceNamePrefix, appId, executorLabels, executorAnnotations, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index 986c950ab365a..edeaa380194ac 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -211,11 +211,8 @@ private[spark] class KubernetesClientApplication extends SparkApplication { // considerably restrictive, e.g. must be no longer than 63 characters in length. So we generate // a unique app ID (captured by spark.app.id) in the format below. val kubernetesAppId = s"spark-${UUID.randomUUID().toString.replaceAll("-", "")}" - val launchTime = System.currentTimeMillis() val waitForAppCompletion = sparkConf.get(WAIT_FOR_APP_COMPLETION) - val kubernetesResourceNamePrefix = { - s"$appName-$launchTime".toLowerCase.replaceAll("\\.", "-") - } + val kubernetesResourceNamePrefix = KubernetesClientApplication.getResourceNamePrefix(appName) sparkConf.set(KUBERNETES_PYSPARK_PY_FILES, clientArguments.maybePyFiles.getOrElse("")) val kubernetesConf = KubernetesConf.createDriverConf( sparkConf, @@ -254,3 +251,19 @@ private[spark] class KubernetesClientApplication extends SparkApplication { } } } + +private[spark] object KubernetesClientApplication { + + def getAppName(conf: SparkConf): String = conf.getOption("spark.app.name").getOrElse("spark") + + def getResourceNamePrefix(appName: String): String = { + val launchTime = System.currentTimeMillis() + s"$appName-$launchTime" + .trim + .toLowerCase + .replaceAll("\\s+", "-") + .replaceAll("\\.", "-") + .replaceAll("[^a-z0-9\\-]", "") + .replaceAll("-+", "-") + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala index e847f8590d353..0e617b0021019 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala @@ -167,13 +167,23 @@ class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter { executorSpecificConf.executorId, TEST_SPARK_APP_ID, Some(driverPod)) - k8sConf.sparkConf.getAll.toMap == conf.getAll.toMap && + + // Set prefixes to a common string since KUBERNETES_EXECUTOR_POD_NAME_PREFIX + // has not be set for the tests and thus KubernetesConf will use a random + // string for the prefix, based on the app name, and this comparison here will fail. + val k8sConfCopy = k8sConf + .copy(appResourceNamePrefix = "") + .copy(sparkConf = conf) + val expectedK8sConfCopy = expectedK8sConf + .copy(appResourceNamePrefix = "") + .copy(sparkConf = conf) + + k8sConf.sparkConf.getAll.toMap == conf.getAll.toMap && // Since KubernetesConf.createExecutorConf clones the SparkConf object, force // deep equality comparison for the SparkConf object and use object equality // comparison on all other fields. - k8sConf.copy(sparkConf = conf) == expectedK8sConf.copy(sparkConf = conf) + k8sConfCopy == expectedK8sConfCopy } } }) - } From 5b761c537a600115450b53817bee0679d5c2bb97 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 13 Sep 2018 14:21:00 +0200 Subject: [PATCH 1607/2461] [SPARK-25352][SQL][FOLLOWUP] Add helper method and address style issue ## What changes were proposed in this pull request? This follow-up patch addresses [the review comment](https://github.com/apache/spark/pull/22344/files#r217070658) by adding a helper method to simplify code and fixing style issue. ## How was this patch tested? Existing unit tests. Author: Liang-Chi Hsieh Closes #22409 from viirya/SPARK-25352-followup. --- .../spark/sql/execution/SparkStrategies.scala | 55 ++++++++----------- 1 file changed, 23 insertions(+), 32 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7c8ce316f9647..89442a70283f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -66,44 +66,35 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Plans special cases of limit operators. */ object SpecialLimits extends Strategy { - override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case ReturnAnswer(rootPlan) => rootPlan match { - case Limit(IntegerLiteral(limit), s@Sort(order, true, child)) => - if (limit < conf.topKSortFallbackThreshold) { + private def decideTopRankNode(limit: Int, child: LogicalPlan): Seq[SparkPlan] = { + if (limit < conf.topKSortFallbackThreshold) { + child match { + case Sort(order, true, child) => TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil - } else { - GlobalLimitExec(limit, - LocalLimitExec(limit, planLater(s)), - orderedLimit = true) :: Nil - } - case Limit(IntegerLiteral(limit), p@Project(projectList, Sort(order, true, child))) => - if (limit < conf.topKSortFallbackThreshold) { + case Project(projectList, Sort(order, true, child)) => TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil - } else { - GlobalLimitExec(limit, - LocalLimitExec(limit, planLater(p)), - orderedLimit = true) :: Nil - } + } + } else { + GlobalLimitExec(limit, + LocalLimitExec(limit, planLater(child)), + orderedLimit = true) :: Nil + } + } + + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ReturnAnswer(rootPlan) => rootPlan match { + case Limit(IntegerLiteral(limit), s @ Sort(order, true, child)) => + decideTopRankNode(limit, s) + case Limit(IntegerLiteral(limit), p @ Project(projectList, Sort(order, true, child))) => + decideTopRankNode(limit, p) case Limit(IntegerLiteral(limit), child) => CollectLimitExec(limit, planLater(child)) :: Nil case other => planLater(other) :: Nil } - case Limit(IntegerLiteral(limit), s@Sort(order, true, child)) => - if (limit < conf.topKSortFallbackThreshold) { - TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil - } else { - GlobalLimitExec(limit, - LocalLimitExec(limit, planLater(s)), - orderedLimit = true) :: Nil - } - case Limit(IntegerLiteral(limit), p@Project(projectList, Sort(order, true, child))) => - if (limit < conf.topKSortFallbackThreshold) { - TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil - } else { - GlobalLimitExec(limit, - LocalLimitExec(limit, planLater(p)), - orderedLimit = true) :: Nil - } + case Limit(IntegerLiteral(limit), s @ Sort(order, true, child)) => + decideTopRankNode(limit, s) + case Limit(IntegerLiteral(limit), p @ Project(projectList, Sort(order, true, child))) => + decideTopRankNode(limit, p) case _ => Nil } } From 45c4ebc8171d75fc0d169bb8071a4c43263d283e Mon Sep 17 00:00:00 2001 From: LucaCanali Date: Thu, 13 Sep 2018 10:19:21 -0500 Subject: [PATCH 1608/2461] [SPARK-25170][DOC] Add list and short description of Spark Executor Task Metrics to the documentation. ## What changes were proposed in this pull request? Add description of Executor Task Metrics to the documentation. Closes #22397 from LucaCanali/docMonitoringTaskMetrics. Authored-by: LucaCanali Signed-off-by: Sean Owen --- docs/monitoring.md | 152 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 152 insertions(+) diff --git a/docs/monitoring.md b/docs/monitoring.md index 2717dd091c751..f6d52ef4597e9 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -388,6 +388,158 @@ value triggering garbage collection on jobs, and `spark.ui.retainedStages` that Note that the garbage collection takes place on playback: it is possible to retrieve more entries by increasing these values and restarting the history server. +### Executor Task Metrics + +The REST API exposes the values of the Task Metrics collected by Spark executors with the granularity +of task execution. The metrics can be used for performance troubleshooting and workload characterization. +A list of the available metrics, with a short description: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
      Spark Executor Task Metric nameShort description
      executorRunTimeElapsed time the executor spent running this task. This includes time fetching shuffle data. + The value is expressed in milliseconds.
      executorCpuTimeCPU time the executor spent running this task. This includes time fetching shuffle data. + The value is expressed in nanoseconds.
      executorDeserializeTimeElapsed time spent to deserialize this task. The value is expressed in milliseconds.
      executorDeserializeCpuTimeCPU time taken on the executor to deserialize this task. The value is expressed + in nanoseconds.
      resultSizeThe number of bytes this task transmitted back to the driver as the TaskResult.
      jvmGCTimeElapsed time the JVM spent in garbage collection while executing this task. + The value is expressed in milliseconds.
      resultSerializationTimeElapsed time spent serializing the task result. The value is expressed in milliseconds.
      memoryBytesSpilledThe number of in-memory bytes spilled by this task.
      diskBytesSpilledThe number of on-disk bytes spilled by this task.
      peakExecutionMemoryPeak memory used by internal data structures created during shuffles, aggregations and + joins. The value of this accumulator should be approximately the sum of the peak sizes + across all such data structures created in this task. For SQL jobs, this only tracks all + unsafe operators and ExternalSort.
      inputMetrics.*Metrics related to reading data from [[org.apache.spark.rdd.HadoopRDD]] + or from persisted data.
          .bytesReadTotal number of bytes read.
          .recordsReadTotal number of records read.
      outputMetrics.*Metrics related to writing data externally (e.g. to a distributed filesystem), + defined only in tasks with output.
          .bytesWrittenTotal number of bytes written
          .recordsWrittenTotal number of records written
      shuffleReadMetrics.*Metrics related to shuffle read operations.
          .recordsReadNumber of records read in shuffle operations
          .remoteBlocksFetchedNumber of remote blocks fetched in shuffle operations
          .localBlocksFetchedNumber of local (as opposed to read from a remote executor) blocks fetched + in shuffle operations
          .totalBlocksFetchedNumber of blocks fetched in shuffle operations (both local and remote)
          .remoteBytesReadNumber of remote bytes read in shuffle operations
          .localBytesReadNumber of bytes read in shuffle operations from local disk (as opposed to + read from a remote executor)
          .totalBytesReadNumber of bytes read in shuffle operations (both local and remote)
          .remoteBytesReadToDiskNumber of remote bytes read to disk in shuffle operations. + Large blocks are fetched to disk in shuffle read operations, as opposed to + being read into memory, which is the default behavior.
          .fetchWaitTimeTime the task spent waiting for remote shuffle blocks. + This only includes the time blocking on shuffle input data. + For instance if block B is being fetched while the task is still not finished + processing block A, it is not considered to be blocking on block B. + The value is expressed in milliseconds.
      shuffleWriteMetrics.*Metrics related to operations writing shuffle data.
          .bytesWrittenNumber of bytes written in shuffle operations
          .recordsWrittenNumber of records written in shuffle operations
          .writeTimeTime spent blocking on writes to disk or buffer cache. The value is expressed + in nanoseconds.
      + + + ### API Versioning Policy These endpoints have been strongly versioned to make it easier to develop applications on top. From a7e5aa6cd430d0a49bb6dac92c007fab189db3a3 Mon Sep 17 00:00:00 2001 From: Michael Allman Date: Thu, 13 Sep 2018 17:08:45 +0000 Subject: [PATCH 1609/2461] [SPARK-25406][SQL] For ParquetSchemaPruningSuite.scala, move calls to `withSQLConf` inside calls to `test` (Link to Jira: https://issues.apache.org/jira/browse/SPARK-25406) ## What changes were proposed in this pull request? The current use of `withSQLConf` in `ParquetSchemaPruningSuite.scala` is incorrect. The desired configuration settings are not being set when running the test cases. This PR fixes that defective usage and addresses the test failures that were previously masked by that defect. ## How was this patch tested? I added code to relevant test cases to print the expected SQL configuration settings and found that the settings were not being set as expected. When I changed the order of calls to `test` and `withSQLConf` I found that the configuration settings were being set as expected. Closes #22394 from mallman/spark-25406-fix_broken_schema_pruning_tests. Authored-by: Michael Allman Signed-off-by: DB Tsai --- .../parquet/ParquetSchemaPruningSuite.scala | 63 +++++++++++-------- 1 file changed, 38 insertions(+), 25 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala index 7b132af4f6911..434c4414edeba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala @@ -218,20 +218,24 @@ class ParquetSchemaPruningSuite } private def testSchemaPruning(testName: String)(testThunk: => Unit) { - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") { - test(s"Spark vectorized reader - without partition data column - $testName") { + test(s"Spark vectorized reader - without partition data column - $testName") { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") { withContacts(testThunk) } - test(s"Spark vectorized reader - with partition data column - $testName") { + } + test(s"Spark vectorized reader - with partition data column - $testName") { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") { withContactsWithDataPartitionColumn(testThunk) } } - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - test(s"Parquet-mr reader - without partition data column - $testName") { + test(s"Parquet-mr reader - without partition data column - $testName") { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { withContacts(testThunk) } - test(s"Parquet-mr reader - with partition data column - $testName") { + } + test(s"Parquet-mr reader - with partition data column - $testName") { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { withContactsWithDataPartitionColumn(testThunk) } } @@ -271,7 +275,7 @@ class ParquetSchemaPruningSuite MixedCase(1, "r1c1", MixedCaseColumn("123", 2)) :: Nil - testMixedCasePruning("select with exact column names") { + testExactCaseQueryPruning("select with exact column names") { val query = sql("select CoL1, coL2.B from mixedcase") checkScan(query, "struct>") checkAnswer(query.orderBy("id"), @@ -280,7 +284,7 @@ class ParquetSchemaPruningSuite Nil) } - testMixedCasePruning("select with lowercase column names") { + testMixedCaseQueryPruning("select with lowercase column names") { val query = sql("select col1, col2.b from mixedcase") checkScan(query, "struct>") checkAnswer(query.orderBy("id"), @@ -289,7 +293,7 @@ class ParquetSchemaPruningSuite Nil) } - testMixedCasePruning("select with different-case column names") { + testMixedCaseQueryPruning("select with different-case column names") { val query = sql("select cOL1, cOl2.b from mixedcase") checkScan(query, "struct>") checkAnswer(query.orderBy("id"), @@ -298,34 +302,43 @@ class ParquetSchemaPruningSuite Nil) } - testMixedCasePruning("filter with different-case column names") { + testMixedCaseQueryPruning("filter with different-case column names") { val query = sql("select id from mixedcase where Col2.b = 2") checkScan(query, "struct>") checkAnswer(query.orderBy("id"), Row(1) :: Nil) } - private def testMixedCasePruning(testName: String)(testThunk: => Unit) { - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true", - SQLConf.CASE_SENSITIVE.key -> "true") { - test(s"Spark vectorized reader - case-sensitive parser - mixed-case schema - $testName") { - withMixedCaseData(testThunk) + // Tests schema pruning for a query whose column and field names are exactly the same as the table + // schema's column and field names. N.B. this implies that `testThunk` should pass using either a + // case-sensitive or case-insensitive query parser + private def testExactCaseQueryPruning(testName: String)(testThunk: => Unit) { + test(s"Spark vectorized reader - case-sensitive parser - mixed-case schema - $testName") { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true", + SQLConf.CASE_SENSITIVE.key -> "true") { + withMixedCaseData(testThunk) } } - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false", - SQLConf.CASE_SENSITIVE.key -> "false") { - test(s"Parquet-mr reader - case-insensitive parser - mixed-case schema - $testName") { + test(s"Parquet-mr reader - case-sensitive parser - mixed-case schema - $testName") { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false", + SQLConf.CASE_SENSITIVE.key -> "true") { withMixedCaseData(testThunk) } } - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true", - SQLConf.CASE_SENSITIVE.key -> "false") { - test(s"Spark vectorized reader - case-insensitive parser - mixed-case schema - $testName") { - withMixedCaseData(testThunk) + testMixedCaseQueryPruning(testName)(testThunk) + } + + // Tests schema pruning for a query whose column and field names may differ in case from the table + // schema's column and field names + private def testMixedCaseQueryPruning(testName: String)(testThunk: => Unit) { + test(s"Spark vectorized reader - case-insensitive parser - mixed-case schema - $testName") { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true", + SQLConf.CASE_SENSITIVE.key -> "false") { + withMixedCaseData(testThunk) } } - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false", - SQLConf.CASE_SENSITIVE.key -> "true") { - test(s"Parquet-mr reader - case-sensitive parser - mixed-case schema - $testName") { + test(s"Parquet-mr reader - case-insensitive parser - mixed-case schema - $testName") { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false", + SQLConf.CASE_SENSITIVE.key -> "false") { withMixedCaseData(testThunk) } } From f60cd7cc3ce663bb1517e059f5fd79c0098ebbcd Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 13 Sep 2018 11:34:22 -0700 Subject: [PATCH 1610/2461] [SPARK-25338][TEST] Ensure to call super.beforeAll() and super.afterAll() in test cases ## What changes were proposed in this pull request? This PR ensures to call `super.afterAll()` in `override afterAll()` method for test suites. * Some suites did not call `super.afterAll()` * Some suites may call `super.afterAll()` only under certain condition * Others never call `super.afterAll()`. This PR also ensures to call `super.beforeAll()` in `override beforeAll()` for test suites. ## How was this patch tested? Existing UTs Closes #22337 from kiszk/SPARK-25338. Authored-by: Kazuaki Ishizaki Signed-off-by: Dongjoon Hyun --- .../deploy/master/ui/MasterWebUISuite.scala | 7 +++-- .../flume/FlumePollingStreamSuite.scala | 11 +++++--- .../sql/kafka010/KafkaRelationSuite.scala | 9 ++++--- .../spark/sql/kafka010/KafkaSinkSuite.scala | 9 ++++--- .../kafka010/DirectKafkaStreamSuite.scala | 11 +++++--- .../streaming/kafka010/KafkaRDDSuite.scala | 23 ++++++++++------ .../kafka/DirectKafkaStreamSuite.scala | 11 +++++--- .../streaming/kafka/KafkaClusterSuite.scala | 11 +++++--- .../spark/streaming/kafka/KafkaRDDSuite.scala | 23 ++++++++++------ .../streaming/kafka/KafkaStreamSuite.scala | 23 ++++++++++------ .../kafka/ReliableKafkaStreamSuite.scala | 13 +++++++--- .../KinesisInputDStreamBuilderSuite.scala | 6 ++++- .../kinesis/KinesisStreamSuite.scala | 26 +++++++++++-------- .../k8s/integrationtest/KubernetesSuite.scala | 7 ++++- .../apache/spark/sql/SessionStateSuite.scala | 15 ++++++----- .../execution/ExchangeCoordinatorSuite.scala | 11 +++++--- ...xternalAppendOnlyUnsafeRowArraySuite.scala | 6 ++++- .../SortBasedAggregationStoreSuite.scala | 6 ++++- .../benchmark/WideSchemaBenchmark.scala | 7 +++-- .../BasicWriteTaskStatsTrackerSuite.scala | 6 ++++- .../execution/joins/BroadcastJoinSuite.scala | 8 ++++-- .../python/BatchEvalPythonExecSuite.scala | 7 +++-- .../streaming/state/StateStoreRDDSuite.scala | 7 +++-- .../internal/ExecutorSideSQLConfSuite.scala | 8 ++++-- .../FlatMapGroupsWithStateSuite.scala | 8 +----- .../spark/sql/streaming/StreamTest.scala | 7 +++-- .../streaming/StreamingAggregationSuite.scala | 8 +----- .../StreamingDeduplicationSuite.scala | 7 +---- .../spark/sql/test/SharedSQLContext.scala | 7 +++-- .../hive/thriftserver/UISeleniumSuite.scala | 9 ++++--- .../HiveExternalCatalogVersionsSuite.scala | 11 +++++--- .../sql/hive/execution/HiveUDAFSuite.scala | 1 + .../execution/ObjectHashAggregateSuite.scala | 1 + 33 files changed, 216 insertions(+), 114 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala index 69a460fbc7dba..f4558aa3eb893 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala @@ -53,8 +53,11 @@ class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll { } override def afterAll() { - masterWebUI.stop() - super.afterAll() + try { + masterWebUI.stop() + } finally { + super.afterAll() + } } test("kill application") { diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 4324cc6d0f804..9241b13c100f1 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -50,13 +50,18 @@ class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfterAll with val utils = new PollingFlumeTestUtils override def beforeAll(): Unit = { + super.beforeAll() _sc = new SparkContext(conf) } override def afterAll(): Unit = { - if (_sc != null) { - _sc.stop() - _sc = null + try { + if (_sc != null) { + _sc.stop() + _sc = null + } + } finally { + super.afterAll() } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala index eb186970fc25d..8cfca56433f5d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala @@ -48,9 +48,12 @@ class KafkaRelationSuite extends QueryTest with SharedSQLContext with KafkaTest } override def afterAll(): Unit = { - if (testUtils != null) { - testUtils.teardown() - testUtils = null + try { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + } finally { super.afterAll() } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index a2213e024bd98..81832fbdcd7ec 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -48,9 +48,12 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext with KafkaTest { } override def afterAll(): Unit = { - if (testUtils != null) { - testUtils.teardown() - testUtils = null + try { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + } finally { super.afterAll() } } diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index 661b67a8ab68a..1974bb1e12e15 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -59,14 +59,19 @@ class DirectKafkaStreamSuite private var kafkaTestUtils: KafkaTestUtils = _ override def beforeAll { + super.beforeAll() kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() } override def afterAll { - if (kafkaTestUtils != null) { - kafkaTestUtils.teardown() - kafkaTestUtils = null + try { + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } finally { + super.afterAll() } } diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala index 3ac6509b04707..561bca5f55370 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala @@ -44,20 +44,27 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { private var sc: SparkContext = _ override def beforeAll { + super.beforeAll() sc = new SparkContext(sparkConf) kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() } override def afterAll { - if (sc != null) { - sc.stop - sc = null - } - - if (kafkaTestUtils != null) { - kafkaTestUtils.teardown() - kafkaTestUtils = null + try { + try { + if (sc != null) { + sc.stop + sc = null + } + } finally { + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } + } finally { + super.afterAll() } } diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index ecca38784e777..3fd37f4c8ac90 100644 --- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -57,14 +57,19 @@ class DirectKafkaStreamSuite private var kafkaTestUtils: KafkaTestUtils = _ override def beforeAll { + super.beforeAll() kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() } override def afterAll { - if (kafkaTestUtils != null) { - kafkaTestUtils.teardown() - kafkaTestUtils = null + try { + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } finally { + super.afterAll() } } diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala index d66830cbacdee..73d528518d486 100644 --- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala @@ -32,6 +32,7 @@ class KafkaClusterSuite extends SparkFunSuite with BeforeAndAfterAll { private var kafkaTestUtils: KafkaTestUtils = _ override def beforeAll() { + super.beforeAll() kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() @@ -41,9 +42,13 @@ class KafkaClusterSuite extends SparkFunSuite with BeforeAndAfterAll { } override def afterAll() { - if (kafkaTestUtils != null) { - kafkaTestUtils.teardown() - kafkaTestUtils = null + try { + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } finally { + super.afterAll() } } diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala index 809699a739962..72f954149fefe 100644 --- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -35,20 +35,27 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { private var sc: SparkContext = _ override def beforeAll { + super.beforeAll() sc = new SparkContext(sparkConf) kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() } override def afterAll { - if (sc != null) { - sc.stop - sc = null - } - - if (kafkaTestUtils != null) { - kafkaTestUtils.teardown() - kafkaTestUtils = null + try { + try { + if (sc != null) { + sc.stop + sc = null + } + } finally { + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } + } finally { + super.afterAll() } } diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index 426cd83b4ddf8..ed130f5990955 100644 --- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -35,19 +35,26 @@ class KafkaStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter private var kafkaTestUtils: KafkaTestUtils = _ override def beforeAll(): Unit = { + super.beforeAll() kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() } override def afterAll(): Unit = { - if (ssc != null) { - ssc.stop() - ssc = null - } - - if (kafkaTestUtils != null) { - kafkaTestUtils.teardown() - kafkaTestUtils = null + try { + try { + if (ssc != null) { + ssc.stop() + ssc = null + } + } finally { + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } + } finally { + super.afterAll() } } diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala index 57f89cc7dbc65..5da5ea49d77ed 100644 --- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala @@ -51,6 +51,7 @@ class ReliableKafkaStreamSuite extends SparkFunSuite private var tempDirectory: File = null override def beforeAll(): Unit = { + super.beforeAll() kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() @@ -65,11 +66,15 @@ class ReliableKafkaStreamSuite extends SparkFunSuite } override def afterAll(): Unit = { - Utils.deleteRecursively(tempDirectory) + try { + Utils.deleteRecursively(tempDirectory) - if (kafkaTestUtils != null) { - kafkaTestUtils.teardown() - kafkaTestUtils = null + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } finally { + super.afterAll() } } diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala index e0e26847aa0ec..361520e292266 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala @@ -40,7 +40,11 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE .checkpointAppName(checkpointAppName) override def afterAll(): Unit = { - ssc.stop() + try { + ssc.stop() + } finally { + super.afterAll() + } } test("should raise an exception if the StreamingContext is missing") { diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index a7a68eba910bf..6d27445c5b606 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -71,17 +71,21 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun } override def afterAll(): Unit = { - if (ssc != null) { - ssc.stop() - } - if (sc != null) { - sc.stop() - } - if (testUtils != null) { - // Delete the Kinesis stream as well as the DynamoDB table generated by - // Kinesis Client Library when consuming the stream - testUtils.deleteStream() - testUtils.deleteDynamoDBTable(appName) + try { + if (ssc != null) { + ssc.stop() + } + if (sc != null) { + sc.stop() + } + if (testUtils != null) { + // Delete the Kinesis stream as well as the DynamoDB table generated by + // Kinesis Client Library when consuming the stream + testUtils.deleteStream() + testUtils.deleteDynamoDBTable(appName) + } + } finally { + super.afterAll() } } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 82e6efa2707d9..18541baf05813 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -62,6 +62,7 @@ private[spark] class KubernetesSuite extends SparkFunSuite s"${(1024 + memOverheadConstant*1024 + additionalMemory).toInt}Mi" override def beforeAll(): Unit = { + super.beforeAll() // The scalatest-maven-plugin gives system properties that are referenced but not set null // values. We need to remove the null-value properties before initializing the test backend. val nullValueProperties = System.getProperties.asScala @@ -93,7 +94,11 @@ private[spark] class KubernetesSuite extends SparkFunSuite } override def afterAll(): Unit = { - testBackend.cleanUp() + try { + testBackend.cleanUp() + } finally { + super.afterAll() + } } before { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index 7d1366092d1e6..e1b5eba53f06a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -41,13 +41,16 @@ class SessionStateSuite extends SparkFunSuite { } override def afterAll(): Unit = { - if (activeSession != null) { - activeSession.stop() - activeSession = null - SparkSession.clearActiveSession() - SparkSession.clearDefaultSession() + try { + if (activeSession != null) { + activeSession.stop() + activeSession = null + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + } finally { + super.afterAll() } - super.afterAll() } test("fork new session and inherit RuntimeConfig options") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 41de731d41f82..c627c51655c8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -31,6 +31,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { private var originalInstantiatedSparkSession: Option[SparkSession] = _ override protected def beforeAll(): Unit = { + super.beforeAll() originalActiveSparkSession = SparkSession.getActiveSession originalInstantiatedSparkSession = SparkSession.getDefaultSession @@ -39,9 +40,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } override protected def afterAll(): Unit = { - // Set these states back. - originalActiveSparkSession.foreach(ctx => SparkSession.setActiveSession(ctx)) - originalInstantiatedSparkSession.foreach(ctx => SparkSession.setDefaultSession(ctx)) + try { + // Set these states back. + originalActiveSparkSession.foreach(ctx => SparkSession.setActiveSession(ctx)) + originalInstantiatedSparkSession.foreach(ctx => SparkSession.setDefaultSession(ctx)) + } finally { + super.afterAll() + } } private def checkEstimation( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala index ecc7264d79442..b29de9c4adbaa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala @@ -29,7 +29,11 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar private val random = new java.util.Random() private var taskContext: TaskContext = _ - override def afterAll(): Unit = TaskContext.unset() + override def afterAll(): Unit = try { + TaskContext.unset() + } finally { + super.afterAll() + } private def withExternalArray(inMemoryThreshold: Int, spillThreshold: Int) (f: ExternalAppendOnlyUnsafeRowArray => Unit): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala index 3fad7dfddadcc..dc67446460877 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala @@ -39,7 +39,11 @@ class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkConte new TaskContextImpl(0, 0, 0, 0, 0, taskManager, new Properties, null)) } - override def afterAll(): Unit = TaskContext.unset() + override def afterAll(): Unit = try { + TaskContext.unset() + } finally { + super.afterAll() + } private val rand = new java.util.Random() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala index a42891e55a18a..c368f17a84364 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala @@ -54,8 +54,11 @@ class WideSchemaBenchmark extends SparkFunSuite with BeforeAndAfterEach { } override def afterAll() { - super.afterAll() - out.close() + try { + out.close() + } finally { + super.afterAll() + } } override def afterEach() { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala index bf3c8ede9a980..32941d8d2cd11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala @@ -49,7 +49,11 @@ class BasicWriteTaskStatsTrackerSuite extends SparkFunSuite { * In teardown delete the temp dir. */ protected override def afterAll(): Unit = { - Utils.deleteRecursively(tempDir) + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterAll() + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index bcdee792f4c70..b4ad1db20a9ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -54,8 +54,12 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } override def afterAll(): Unit = { - spark.stop() - spark = null + try { + spark.stop() + spark = null + } finally { + super.afterAll() + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 2cc55ff88b983..289cc667a1c66 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -37,8 +37,11 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { } override def afterAll(): Unit = { - spark.sessionState.functionRegistry.dropFunction(FunctionIdentifier("dummyPythonUDF")) - super.afterAll() + try { + spark.sessionState.functionRegistry.dropFunction(FunctionIdentifier("dummyPythonUDF")) + } finally { + super.afterAll() + } } test("Python UDF: push down deterministic FilterExec predicates") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 579a364ebc3e5..015415a534ff5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -49,8 +49,11 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } override def afterAll(): Unit = { - super.afterAll() - Utils.deleteRecursively(new File(tempDir)) + try { + super.afterAll() + } finally { + Utils.deleteRecursively(new File(tempDir)) + } } test("versioning and immutability") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala index 5b4736ef4f7f3..d885348f3774a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -38,8 +38,12 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { } override def afterAll(): Unit = { - spark.stop() - spark = null + try { + spark.stop() + spark = null + } finally { + super.afterAll() + } } override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index e77ba1ec9f1eb..43463a84093ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -45,19 +45,13 @@ case class RunningCount(count: Long) case class Result(key: Long, count: Int) -class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest - with BeforeAndAfterAll { +class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { import testImplicits._ import GroupStateImpl._ import GroupStateTimeout._ import FlatMapGroupsWithStateSuite._ - override def afterAll(): Unit = { - super.afterAll() - StateStore.stop() - } - test("GroupState - get, exists, update, remove") { var state: GroupStateImpl[String] = null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 491dc34afa143..d878c345c2988 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -79,8 +79,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be implicit val defaultSignaler: Signaler = ThreadSignaler override def afterAll(): Unit = { - super.afterAll() - StateStore.stop() // stop the state store maintenance thread and unload store providers + try { + super.afterAll() + } finally { + StateStore.stop() // stop the state store maintenance thread and unload store providers + } } protected val defaultTrigger = Trigger.ProcessingTime(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 1ae6ff3a90989..97dbb9b0360ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -46,13 +46,7 @@ object FailureSingleton { var firstTime = true } -class StreamingAggregationSuite extends StateStoreMetricsTest - with BeforeAndAfterAll with Assertions { - - override def afterAll(): Unit = { - super.afterAll() - StateStore.stop() - } +class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala index 42ffd472eb843..cfd7204ea2931 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala @@ -26,15 +26,10 @@ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -class StreamingDeduplicationSuite extends StateStoreMetricsTest with BeforeAndAfterAll { +class StreamingDeduplicationSuite extends StateStoreMetricsTest { import testImplicits._ - override def afterAll(): Unit = { - super.afterAll() - StateStore.stop() - } - test("deduplicate with all columns") { val inputData = MemoryStream[String] val result = inputData.toDS().dropDuplicates() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index e6c7648c986ae..0dd24d2d56b82 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -35,7 +35,10 @@ trait SharedSQLContext extends SQLTestUtils with SharedSparkSession { } protected override def afterAll(): Unit = { - super.afterAll() - doThreadPostAudit() + try { + super.afterAll() + } finally { + doThreadPostAudit() + } } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala index 4c53dd8f4616c..fef18f147b057 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala @@ -46,10 +46,13 @@ class UISeleniumSuite } override def afterAll(): Unit = { - if (webDriver != null) { - webDriver.quit() + try { + if (webDriver != null) { + webDriver.quit() + } + } finally { + super.afterAll() } - super.afterAll() } override protected def serverStartCommand(port: Int) = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 25df3339e62f3..a7d6972fa71f7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -49,10 +49,13 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private val unusedJar = TestUtils.createJarWithClasses(Seq.empty) override def afterAll(): Unit = { - Utils.deleteRecursively(wareHousePath) - Utils.deleteRecursively(tmpDataDir) - Utils.deleteRecursively(sparkTestingDir) - super.afterAll() + try { + Utils.deleteRecursively(wareHousePath) + Utils.deleteRecursively(tmpDataDir) + Utils.deleteRecursively(sparkTestingDir) + } finally { + super.afterAll() + } } private def tryDownloadSpark(version: String, path: String): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala index 7402c9626873c..fe3deceb08067 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -37,6 +37,7 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { import testImplicits._ protected override def beforeAll(): Unit = { + super.beforeAll() sql(s"CREATE TEMPORARY FUNCTION mock AS '${classOf[MockUDAF].getName}'") sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala index 8dbcd24cd78de..0ef630bbd3670 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala @@ -43,6 +43,7 @@ class ObjectHashAggregateSuite import testImplicits._ protected override def beforeAll(): Unit = { + super.beforeAll() sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'") } From 9deddbb13edebfefb3fd03f063679ed12e73c575 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 13 Sep 2018 14:11:55 -0500 Subject: [PATCH 1611/2461] [SPARK-25400][CORE][TEST] Increase test timeouts We've seen some flakiness in jenkins in SchedulerIntegrationSuite which looks like it just needs a longer timeout. Closes #22385 from squito/SPARK-25400. Authored-by: Imran Rashid Signed-off-by: Sean Owen --- .../apache/spark/scheduler/BlacklistIntegrationSuite.scala | 1 - .../apache/spark/scheduler/SchedulerIntegrationSuite.scala | 7 +++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala index d3bbfd11d406d..fe22d70850c7d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala @@ -24,7 +24,6 @@ import org.apache.spark.internal.config class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorMockBackend]{ val badHost = "host-0" - val duration = Duration(10, SECONDS) /** * This backend just always fails if the task is executed on a bad host, but otherwise succeeds diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 2d409d94ca1b3..ff0f99b5c94d0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -51,6 +51,9 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa var taskScheduler: TestTaskScheduler = null var scheduler: DAGScheduler = null var backend: T = _ + // Even though the tests aren't doing much, occassionally we see flakiness from pauses over + // a second (probably from GC?) so we leave a long timeout in here + val duration = Duration(10, SECONDS) override def beforeEach(): Unit = { if (taskScheduler != null) { @@ -539,7 +542,6 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor } withBackend(runBackend _) { val jobFuture = submit(new MockRDD(sc, 10, Nil), (0 until 10).toArray) - val duration = Duration(1, SECONDS) awaitJobTermination(jobFuture, duration) } assert(results === (0 until 10).map { _ -> 42 }.toMap) @@ -592,7 +594,6 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor } withBackend(runBackend _) { val jobFuture = submit(d, (0 until 30).toArray) - val duration = Duration(1, SECONDS) awaitJobTermination(jobFuture, duration) } assert(results === (0 until 30).map { idx => idx -> (4321 + idx) }.toMap) @@ -634,7 +635,6 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor } withBackend(runBackend _) { val jobFuture = submit(shuffledRdd, (0 until 10).toArray) - val duration = Duration(1, SECONDS) awaitJobTermination(jobFuture, duration) } assertDataStructuresEmpty() @@ -649,7 +649,6 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor } withBackend(runBackend _) { val jobFuture = submit(new MockRDD(sc, 10, Nil), (0 until 10).toArray) - val duration = Duration(1, SECONDS) awaitJobTermination(jobFuture, duration) assert(failure.getMessage.contains("test task failure")) } From a81ef9e1f9bea79aab4a72a5efff69193ee386de Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 13 Sep 2018 22:22:00 -0700 Subject: [PATCH 1612/2461] [SPARK-25418][SQL] The metadata of DataSource table should not include Hive-generated storage properties. ## What changes were proposed in this pull request? When Hive support enabled, Hive catalog puts extra storage properties into table metadata even for DataSource tables, but we should not have them. ## How was this patch tested? Modified a test. Closes #22410 from ueshin/issues/SPARK-25418/hive_metadata. Authored-by: Takuya UESHIN Signed-off-by: gatorsmile --- .../org/apache/spark/sql/hive/HiveExternalCatalog.scala | 7 ++++++- .../org/apache/spark/sql/hive/execution/HiveDDLSuite.scala | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 5cc1047fc067b..505124ae9e7c8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -28,6 +28,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.ql.metadata.HiveException +import org.apache.hadoop.hive.serde.serdeConstants.SERIALIZATION_FORMAT import org.apache.thrift.TException import org.apache.spark.{SparkConf, SparkException} @@ -806,6 +807,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat updateLocationInStorageProps(table, newPath = None).copy( locationUri = tableLocation.map(CatalogUtils.stringToURI(_))) } + val storageWithoutHiveGeneratedProperties = storageWithLocation.copy( + properties = storageWithLocation.properties.filterKeys(!HIVE_GENERATED_STORAGE_PROPERTIES(_))) val partitionProvider = table.properties.get(TABLE_PARTITION_PROVIDER) val schemaFromTableProps = getSchemaFromTableProperties(table) @@ -814,7 +817,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat table.copy( provider = Some(provider), - storage = storageWithLocation, + storage = storageWithoutHiveGeneratedProperties, schema = reorderedSchema, partitionColumnNames = partColumnNames, bucketSpec = getBucketSpecFromTableProperties(table), @@ -1309,6 +1312,8 @@ object HiveExternalCatalog { val CREATED_SPARK_VERSION = SPARK_SQL_PREFIX + "create.version" + val HIVE_GENERATED_STORAGE_PROPERTIES = Set(SERIALIZATION_FORMAT) + // When storing data source tables in hive metastore, we need to set data schema to empty if the // schema is hive-incompatible. However we need a hack to preserve existing behavior. Before // Spark 2.0, we do not set a default serde here (this was done in Hive), and so if the user diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 69ee2bbf06651..be1aa83d682b2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -72,7 +72,7 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA outputFormat = serde.get.outputFormat, serde = serde.get.serde, compressed = false, - properties = Map("serialization.format" -> "1")) + properties = Map.empty) } else { CatalogStorageFormat( locationUri = Some(catalog.defaultTablePath(name)), From 9c25d7f735ed8c49c795babea3fda3cab226e7cb Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 14 Sep 2018 09:25:27 -0700 Subject: [PATCH 1613/2461] [SPARK-25431][SQL][EXAMPLES] Fix function examples and unify the format of the example results. ## What changes were proposed in this pull request? There are some mistakes in examples of newly added functions. Also the format of the example results are not unified. We should fix and unify them. ## How was this patch tested? Manually executed the examples. Closes #22421 from ueshin/issues/SPARK-25431/fix_examples. Authored-by: Takuya UESHIN Signed-off-by: gatorsmile --- .../expressions/collectionOperations.scala | 49 ++++++++++--------- .../expressions/complexTypeCreator.scala | 4 +- .../expressions/higherOrderFunctions.scala | 32 ++++++------ 3 files changed, 43 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index ea6fcccddfd49..3ad21ec5e51f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -131,7 +131,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType examples = """ Examples: > SELECT _FUNC_(map(1, 'a', 2, 'b')); - [1,2] + [1, 2] """) case class MapKeys(child: Expression) extends UnaryExpression with ExpectsInputTypes { @@ -320,7 +320,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI examples = """ Examples: > SELECT _FUNC_(map(1, 'a', 2, 'b')); - ["a","b"] + ["a", "b"] """) case class MapValues(child: Expression) extends UnaryExpression with ExpectsInputTypes { @@ -348,7 +348,7 @@ case class MapValues(child: Expression) examples = """ Examples: > SELECT _FUNC_(map(1, 'a', 2, 'b')); - [(1,"a"),(2,"b")] + [[1, "a"], [2, "b"]] """, since = "2.4.0") case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes { @@ -516,7 +516,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp examples = """ Examples: > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd')); - [[1 -> "a"], [2 -> "b"], [2 -> "c"], [3 -> "d"]] + [1 -> "a", 2 -> "b", 2 -> "c", 3 -> "d"] """, since = "2.4.0") case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpression { @@ -718,7 +718,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres examples = """ Examples: > SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b'))); - {1:"a",2:"b"} + [1 -> "a", 2 -> "b"] """, since = "2.4.0") case class MapFromEntries(child: Expression) extends UnaryExpression { @@ -1071,7 +1071,7 @@ object ArraySortLike { examples = """ Examples: > SELECT _FUNC_(array('b', 'd', null, 'c', 'a'), true); - [null,"a","b","c","d"] + [null, "a", "b", "c", "d"] """) // scalastyle:on line.size.limit case class SortArray(base: Expression, ascendingOrder: Expression) @@ -1129,7 +1129,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) examples = """ Examples: > SELECT _FUNC_(array('b', 'd', null, 'c', 'a')); - ["a","b","c","d",null] + ["a", "b", "c", "d", null] """, since = "2.4.0") // scalastyle:on line.size.limit @@ -1254,7 +1254,7 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None) examples = """ Examples: > SELECT _FUNC_('Spark SQL'); - LQS krapS + "LQS krapS" > SELECT _FUNC_(array(2, 1, 4, 3)); [3, 4, 1, 2] """, @@ -1634,9 +1634,9 @@ case class ArraysOverlap(left: Expression, right: Expression) examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); - [2,3] + [2, 3] > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); - [3,4] + [3, 4] """, since = "2.4.0") // scalastyle:on line.size.limit case class Slice(x: Expression, start: Expression, length: Expression) @@ -1745,11 +1745,11 @@ case class Slice(x: Expression, start: Expression, length: Expression) examples = """ Examples: > SELECT _FUNC_(array('hello', 'world'), ' '); - hello world + "hello world" > SELECT _FUNC_(array('hello', null ,'world'), ' '); - hello world + "hello world" > SELECT _FUNC_(array('hello', null ,'world'), ' ', ','); - hello , world + "hello , world" """, since = "2.4.0") case class ArrayJoin( array: Expression, @@ -2236,10 +2236,11 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti examples = """ Examples: > SELECT _FUNC_('Spark', 'SQL'); - SparkSQL + "SparkSQL" > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); - | [1,2,3,4,5,6] - """) + [1, 2, 3, 4, 5, 6] + """, + note = "Concat logic for arrays is available since 2.4.0.") case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression { private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType) @@ -2427,8 +2428,8 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio usage = "_FUNC_(arrayOfArrays) - Transforms an array of arrays into a single array.", examples = """ Examples: - > SELECT _FUNC_(array(array(1, 2), array(3, 4)); - [1,2,3,4] + > SELECT _FUNC_(array(array(1, 2), array(3, 4))); + [1, 2, 3, 4] """, since = "2.4.0") case class Flatten(child: Expression) extends UnaryExpression { @@ -2934,7 +2935,7 @@ object Sequence { examples = """ Examples: > SELECT _FUNC_('123', 2); - ['123', '123'] + ["123", "123"] """, since = "2.4.0") case class ArrayRepeat(left: Expression, right: Expression) @@ -3055,7 +3056,7 @@ case class ArrayRepeat(left: Expression, right: Expression) examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3, null, 3), 3); - [1,2,null] + [1, 2, null] """, since = "2.4.0") case class ArrayRemove(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -3245,7 +3246,7 @@ trait ArraySetLike { examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3, null, 3)); - [1,2,3,null] + [1, 2, 3, null] """, since = "2.4.0") case class ArrayDistinct(child: Expression) extends UnaryExpression with ArraySetLike with ExpectsInputTypes { @@ -3421,7 +3422,7 @@ object ArrayBinaryLike { examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); - array(1, 2, 3, 5) + [1, 2, 3, 5] """, since = "2.4.0") case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLike @@ -3632,7 +3633,7 @@ object ArrayUnion { examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); - array(1, 3) + [1, 3] """, since = "2.4.0") case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBinaryLike @@ -3873,7 +3874,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); - array(2) + [2] """, since = "2.4.0") case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryLike diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index aba9c6c8ad6fd..117fa3e9aa519 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -248,8 +248,8 @@ case class CreateMap(children: Seq[Expression]) extends Expression { in keys should not be null""", examples = """ Examples: - > SELECT _FUNC_([1.0, 3.0], ['2', '4']); - {1.0:"2",3.0:"4"} + > SELECT _FUNC_(array(1.0, 3.0), array('2', '4')); + [1.0 -> "2", 3.0 -> "4"] """, since = "2.4.0") case class MapFromArrays(left: Expression, right: Expression) extends BinaryExpression with ExpectsInputTypes { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 2bb6b20b944d4..3ef2ec03099e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -209,9 +209,9 @@ trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), x -> x + 1); - array(2, 3, 4) + [2, 3, 4] > SELECT _FUNC_(array(1, 2, 3), (x, i) -> x + i); - array(1, 3, 5) + [1, 3, 5] """, since = "2.4.0") case class ArrayTransform( @@ -318,7 +318,7 @@ case class MapFilter( examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 1); - array(1, 3) + [1, 3] """, since = "2.4.0") case class ArrayFilter( @@ -499,10 +499,10 @@ case class ArrayAggregate( usage = "_FUNC_(expr, func) - Transforms elements in a map using the function.", examples = """ Examples: - > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + 1); - map(array(2, 3, 4), array(1, 2, 3)) - > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); - map(array(2, 4, 6), array(1, 2, 3)) + > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + 1); + [2 -> 1, 3 -> 2, 4 -> 3] + > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); + [2 -> 1, 4 -> 2, 6 -> 3] """, since = "2.4.0") case class TransformKeys( @@ -549,10 +549,10 @@ case class TransformKeys( usage = "_FUNC_(expr, func) - Transforms values in the map using the function.", examples = """ Examples: - > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> v + 1); - map(array(1, 2, 3), array(2, 3, 4)) - > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); - map(array(1, 2, 3), array(2, 4, 6)) + > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> v + 1); + [1 -> 2, 2 -> 3, 3 -> 4] + > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); + [1 -> 2, 2 -> 4, 3 -> 6] """, since = "2.4.0") case class TransformValues( @@ -603,7 +603,7 @@ case class TransformValues( examples = """ Examples: > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)); - {1:"ax",2:"by"} + [1 -> "ax", 2 -> "by"] """, since = "2.4.0") case class MapZipWith(left: Expression, right: Expression, function: Expression) @@ -777,11 +777,11 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), array('a', 'b', 'c'), (x, y) -> (y, x)); - array(('a', 1), ('b', 2), ('c', 3)) - > SELECT _FUNC_(array(1, 2), array(3, 4), (x, y) -> x + y)); - array(4, 6) + [["a", 1], ["b", 2], ["c", 3]] + > SELECT _FUNC_(array(1, 2), array(3, 4), (x, y) -> x + y); + [4, 6] > SELECT _FUNC_(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y)); - array('ad', 'be', 'cf') + ["ad", "be", "cf"] """, since = "2.4.0") // scalastyle:on line.size.limit From 9bb798f2e6eefd9edb7b6d9980a894557c107bd3 Mon Sep 17 00:00:00 2001 From: cclauss Date: Fri, 14 Sep 2018 20:13:07 -0500 Subject: [PATCH 1614/2461] [SPARK-25238][PYTHON] lint-python: Upgrade pycodestyle to v2.4.0 See https://pycodestyle.readthedocs.io/en/latest/developer.html#changes for changes made in this release. ## What changes were proposed in this pull request? Upgrade pycodestyle to v2.4.0 ## How was this patch tested? __pycodestyle__ Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22231 from cclauss/patch-1. Authored-by: cclauss Signed-off-by: Sean Owen --- dev/lint-python | 2 +- dev/run-tests-jenkins.py | 4 ++-- dev/tox.ini | 2 +- python/pyspark/sql/functions.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dev/lint-python b/dev/lint-python index a98a251af9e6c..e26bd4bd4517c 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -36,7 +36,7 @@ compile_status="${PIPESTATUS[0]}" # Get pycodestyle at runtime so that we don't rely on it being installed on the build server. # See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 # Updated to the latest official version of pep8. pep8 is formally renamed to pycodestyle. -PYCODESTYLE_VERSION="2.3.1" +PYCODESTYLE_VERSION="2.4.0" PYCODESTYLE_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-$PYCODESTYLE_VERSION.py" PYCODESTYLE_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/PyCQA/pycodestyle/$PYCODESTYLE_VERSION/pycodestyle.py" diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 6e943898ffed9..eca88f2391bf8 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -116,7 +116,7 @@ def run_tests(tests_timeout): failure_note_by_errcode = { # error to denote run-tests script failures: - 1: 'executing the `dev/run-tests` script', # noqa: W605 + 1: 'executing the `dev/run-tests` script', ERROR_CODES["BLOCK_GENERAL"]: 'some tests', ERROR_CODES["BLOCK_RAT"]: 'RAT tests', ERROR_CODES["BLOCK_SCALA_STYLE"]: 'Scala style tests', @@ -131,7 +131,7 @@ def run_tests(tests_timeout): ERROR_CODES["BLOCK_PYSPARK_UNIT_TESTS"]: 'PySpark unit tests', ERROR_CODES["BLOCK_PYSPARK_PIP_TESTS"]: 'PySpark pip packaging tests', ERROR_CODES["BLOCK_SPARKR_UNIT_TESTS"]: 'SparkR unit tests', - ERROR_CODES["BLOCK_TIMEOUT"]: 'from timeout after a configured wait of \`%s\`' % ( + ERROR_CODES["BLOCK_TIMEOUT"]: 'from timeout after a configured wait of `%s`' % ( tests_timeout) } diff --git a/dev/tox.ini b/dev/tox.ini index 28dad8f3b5c7c..6ec223b743b4e 100644 --- a/dev/tox.ini +++ b/dev/tox.ini @@ -14,6 +14,6 @@ # limitations under the License. [pycodestyle] -ignore=E402,E731,E241,W503,E226,E722,E741,E305 +ignore=E226,E241,E305,E402,E722,E731,E741,W503,W504 max-line-length=100 exclude=cloudpickle.py,heapq3.py,shared.py,python/docs/conf.py,work/*/*.py,python/.eggs/*,dist/* diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e288ec818b404..6da5237d18de4 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1711,7 +1711,7 @@ def regexp_extract(str, pattern, idx): @ignore_unicode_prefix @since(1.5) def regexp_replace(str, pattern, replacement): - """Replace all substrings of the specified string value that match regexp with rep. + r"""Replace all substrings of the specified string value that match regexp with rep. >>> df = spark.createDataFrame([('100-200',)], ['str']) >>> df.select(regexp_replace('str', r'(\d+)', '--').alias('d')).collect() From be454a7cef1cb5c76fb22589fc3a55c1bf519cf4 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sat, 15 Sep 2018 12:50:46 +0900 Subject: [PATCH 1615/2461] Revert "[SPARK-25431][SQL][EXAMPLES] Fix function examples and unify the format of the example results." This reverts commit 9c25d7f735ed8c49c795babea3fda3cab226e7cb. --- .../expressions/collectionOperations.scala | 49 +++++++++---------- .../expressions/complexTypeCreator.scala | 4 +- .../expressions/higherOrderFunctions.scala | 32 ++++++------ 3 files changed, 42 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 3ad21ec5e51f4..ea6fcccddfd49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -131,7 +131,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType examples = """ Examples: > SELECT _FUNC_(map(1, 'a', 2, 'b')); - [1, 2] + [1,2] """) case class MapKeys(child: Expression) extends UnaryExpression with ExpectsInputTypes { @@ -320,7 +320,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI examples = """ Examples: > SELECT _FUNC_(map(1, 'a', 2, 'b')); - ["a", "b"] + ["a","b"] """) case class MapValues(child: Expression) extends UnaryExpression with ExpectsInputTypes { @@ -348,7 +348,7 @@ case class MapValues(child: Expression) examples = """ Examples: > SELECT _FUNC_(map(1, 'a', 2, 'b')); - [[1, "a"], [2, "b"]] + [(1,"a"),(2,"b")] """, since = "2.4.0") case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes { @@ -516,7 +516,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp examples = """ Examples: > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd')); - [1 -> "a", 2 -> "b", 2 -> "c", 3 -> "d"] + [[1 -> "a"], [2 -> "b"], [2 -> "c"], [3 -> "d"]] """, since = "2.4.0") case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpression { @@ -718,7 +718,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres examples = """ Examples: > SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b'))); - [1 -> "a", 2 -> "b"] + {1:"a",2:"b"} """, since = "2.4.0") case class MapFromEntries(child: Expression) extends UnaryExpression { @@ -1071,7 +1071,7 @@ object ArraySortLike { examples = """ Examples: > SELECT _FUNC_(array('b', 'd', null, 'c', 'a'), true); - [null, "a", "b", "c", "d"] + [null,"a","b","c","d"] """) // scalastyle:on line.size.limit case class SortArray(base: Expression, ascendingOrder: Expression) @@ -1129,7 +1129,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) examples = """ Examples: > SELECT _FUNC_(array('b', 'd', null, 'c', 'a')); - ["a", "b", "c", "d", null] + ["a","b","c","d",null] """, since = "2.4.0") // scalastyle:on line.size.limit @@ -1254,7 +1254,7 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None) examples = """ Examples: > SELECT _FUNC_('Spark SQL'); - "LQS krapS" + LQS krapS > SELECT _FUNC_(array(2, 1, 4, 3)); [3, 4, 1, 2] """, @@ -1634,9 +1634,9 @@ case class ArraysOverlap(left: Expression, right: Expression) examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); - [2, 3] + [2,3] > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); - [3, 4] + [3,4] """, since = "2.4.0") // scalastyle:on line.size.limit case class Slice(x: Expression, start: Expression, length: Expression) @@ -1745,11 +1745,11 @@ case class Slice(x: Expression, start: Expression, length: Expression) examples = """ Examples: > SELECT _FUNC_(array('hello', 'world'), ' '); - "hello world" + hello world > SELECT _FUNC_(array('hello', null ,'world'), ' '); - "hello world" + hello world > SELECT _FUNC_(array('hello', null ,'world'), ' ', ','); - "hello , world" + hello , world """, since = "2.4.0") case class ArrayJoin( array: Expression, @@ -2236,11 +2236,10 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti examples = """ Examples: > SELECT _FUNC_('Spark', 'SQL'); - "SparkSQL" + SparkSQL > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); - [1, 2, 3, 4, 5, 6] - """, - note = "Concat logic for arrays is available since 2.4.0.") + | [1,2,3,4,5,6] + """) case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression { private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType) @@ -2428,8 +2427,8 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio usage = "_FUNC_(arrayOfArrays) - Transforms an array of arrays into a single array.", examples = """ Examples: - > SELECT _FUNC_(array(array(1, 2), array(3, 4))); - [1, 2, 3, 4] + > SELECT _FUNC_(array(array(1, 2), array(3, 4)); + [1,2,3,4] """, since = "2.4.0") case class Flatten(child: Expression) extends UnaryExpression { @@ -2935,7 +2934,7 @@ object Sequence { examples = """ Examples: > SELECT _FUNC_('123', 2); - ["123", "123"] + ['123', '123'] """, since = "2.4.0") case class ArrayRepeat(left: Expression, right: Expression) @@ -3056,7 +3055,7 @@ case class ArrayRepeat(left: Expression, right: Expression) examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3, null, 3), 3); - [1, 2, null] + [1,2,null] """, since = "2.4.0") case class ArrayRemove(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -3246,7 +3245,7 @@ trait ArraySetLike { examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3, null, 3)); - [1, 2, 3, null] + [1,2,3,null] """, since = "2.4.0") case class ArrayDistinct(child: Expression) extends UnaryExpression with ArraySetLike with ExpectsInputTypes { @@ -3422,7 +3421,7 @@ object ArrayBinaryLike { examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); - [1, 2, 3, 5] + array(1, 2, 3, 5) """, since = "2.4.0") case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLike @@ -3633,7 +3632,7 @@ object ArrayUnion { examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); - [1, 3] + array(1, 3) """, since = "2.4.0") case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBinaryLike @@ -3874,7 +3873,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); - [2] + array(2) """, since = "2.4.0") case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryLike diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 117fa3e9aa519..aba9c6c8ad6fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -248,8 +248,8 @@ case class CreateMap(children: Seq[Expression]) extends Expression { in keys should not be null""", examples = """ Examples: - > SELECT _FUNC_(array(1.0, 3.0), array('2', '4')); - [1.0 -> "2", 3.0 -> "4"] + > SELECT _FUNC_([1.0, 3.0], ['2', '4']); + {1.0:"2",3.0:"4"} """, since = "2.4.0") case class MapFromArrays(left: Expression, right: Expression) extends BinaryExpression with ExpectsInputTypes { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 3ef2ec03099e4..2bb6b20b944d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -209,9 +209,9 @@ trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), x -> x + 1); - [2, 3, 4] + array(2, 3, 4) > SELECT _FUNC_(array(1, 2, 3), (x, i) -> x + i); - [1, 3, 5] + array(1, 3, 5) """, since = "2.4.0") case class ArrayTransform( @@ -318,7 +318,7 @@ case class MapFilter( examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 1); - [1, 3] + array(1, 3) """, since = "2.4.0") case class ArrayFilter( @@ -499,10 +499,10 @@ case class ArrayAggregate( usage = "_FUNC_(expr, func) - Transforms elements in a map using the function.", examples = """ Examples: - > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + 1); - [2 -> 1, 3 -> 2, 4 -> 3] - > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); - [2 -> 1, 4 -> 2, 6 -> 3] + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + 1); + map(array(2, 3, 4), array(1, 2, 3)) + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); + map(array(2, 4, 6), array(1, 2, 3)) """, since = "2.4.0") case class TransformKeys( @@ -549,10 +549,10 @@ case class TransformKeys( usage = "_FUNC_(expr, func) - Transforms values in the map using the function.", examples = """ Examples: - > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> v + 1); - [1 -> 2, 2 -> 3, 3 -> 4] - > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); - [1 -> 2, 2 -> 4, 3 -> 6] + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> v + 1); + map(array(1, 2, 3), array(2, 3, 4)) + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); + map(array(1, 2, 3), array(2, 4, 6)) """, since = "2.4.0") case class TransformValues( @@ -603,7 +603,7 @@ case class TransformValues( examples = """ Examples: > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)); - [1 -> "ax", 2 -> "by"] + {1:"ax",2:"by"} """, since = "2.4.0") case class MapZipWith(left: Expression, right: Expression, function: Expression) @@ -777,11 +777,11 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), array('a', 'b', 'c'), (x, y) -> (y, x)); - [["a", 1], ["b", 2], ["c", 3]] - > SELECT _FUNC_(array(1, 2), array(3, 4), (x, y) -> x + y); - [4, 6] + array(('a', 1), ('b', 2), ('c', 3)) + > SELECT _FUNC_(array(1, 2), array(3, 4), (x, y) -> x + y)); + array(4, 6) > SELECT _FUNC_(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y)); - ["ad", "be", "cf"] + array('ad', 'be', 'cf') """, since = "2.4.0") // scalastyle:on line.size.limit From 5ebef33c85a66cdc29db2eff2343600602bbe94e Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 15 Sep 2018 16:20:45 -0700 Subject: [PATCH 1616/2461] [SPARK-25426][SQL] Remove the duplicate fallback logic in UnsafeProjection ## What changes were proposed in this pull request? This pr removed the duplicate fallback logic in `UnsafeProjection`. This pr comes from #22355. ## How was this patch tested? Added tests in `CodeGeneratorWithInterpretedFallbackSuite`. Closes #22417 from maropu/SPARK-25426. Authored-by: Takeshi Yamamuro Signed-off-by: gatorsmile --- .../sql/catalyst/expressions/Projection.scala | 25 ++----------------- .../execution/basicPhysicalOperators.scala | 3 +-- 2 files changed, 3 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 226a4ddcffaa8..5f24170398715 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.catalyst.expressions -import scala.util.control.NonFatal - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} /** @@ -117,7 +116,7 @@ object UnsafeProjection extends CodeGeneratorWithInterpretedFallback[Seq[Expression], UnsafeProjection] { override protected def createCodeGeneratedObject(in: Seq[Expression]): UnsafeProjection = { - GenerateUnsafeProjection.generate(in) + GenerateUnsafeProjection.generate(in, SQLConf.get.subexpressionEliminationEnabled) } override protected def createInterpretedObject(in: Seq[Expression]): UnsafeProjection = { @@ -168,26 +167,6 @@ object UnsafeProjection def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { create(toBoundExprs(exprs, inputSchema)) } - - /** - * Same as other create()'s but allowing enabling/disabling subexpression elimination. - * The param `subexpressionEliminationEnabled` doesn't guarantee to work. For example, - * when fallbacking to interpreted execution, it is not supported. - */ - def create( - exprs: Seq[Expression], - inputSchema: Seq[Attribute], - subexpressionEliminationEnabled: Boolean): UnsafeProjection = { - val unsafeExprs = toUnsafeExprs(toBoundExprs(exprs, inputSchema)) - try { - GenerateUnsafeProjection.generate(unsafeExprs, subexpressionEliminationEnabled) - } catch { - case NonFatal(_) => - // We should have already seen the error message in `CodeGenerator` - logWarning("Expr codegen error and falling back to interpreter mode") - InterpretedUnsafeProjection.createProjection(unsafeExprs) - } - } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 9434ceb7cd16c..222a1b8bc7301 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -68,8 +68,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) protected override def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsWithIndexInternal { (index, iter) => - val project = UnsafeProjection.create(projectList, child.output, - subexpressionEliminationEnabled) + val project = UnsafeProjection.create(projectList, child.output) project.initialize(index) iter.map(project) } From bb2f069cf2c1e5b05362c7bbe8e0994a3e36a626 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 15 Sep 2018 16:24:02 -0700 Subject: [PATCH 1617/2461] [SPARK-25436] Bump master branch version to 2.5.0-SNAPSHOT ## What changes were proposed in this pull request? In the dev list, we can still discuss whether the next version is 2.5.0 or 3.0.0. Let us first bump the master branch version to `2.5.0-SNAPSHOT`. ## How was this patch tested? N/A Closes #22426 from gatorsmile/bumpVersionMaster. Authored-by: gatorsmile Signed-off-by: gatorsmile --- R/pkg/DESCRIPTION | 2 +- assembly/pom.xml | 2 +- common/kvstore/pom.xml | 2 +- common/network-common/pom.xml | 2 +- common/network-shuffle/pom.xml | 2 +- common/network-yarn/pom.xml | 2 +- common/sketch/pom.xml | 2 +- common/tags/pom.xml | 2 +- common/unsafe/pom.xml | 2 +- core/pom.xml | 2 +- docs/_config.yml | 4 ++-- examples/pom.xml | 2 +- external/avro/pom.xml | 2 +- external/docker-integration-tests/pom.xml | 2 +- external/flume-assembly/pom.xml | 2 +- external/flume-sink/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka-0-10-assembly/pom.xml | 2 +- external/kafka-0-10-sql/pom.xml | 2 +- external/kafka-0-10/pom.xml | 2 +- external/kafka-0-8-assembly/pom.xml | 2 +- external/kafka-0-8/pom.xml | 2 +- external/kinesis-asl-assembly/pom.xml | 2 +- external/kinesis-asl/pom.xml | 2 +- external/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- hadoop-cloud/pom.xml | 2 +- launcher/pom.xml | 2 +- mllib-local/pom.xml | 2 +- mllib/pom.xml | 2 +- pom.xml | 2 +- project/MimaExcludes.scala | 5 +++++ python/pyspark/version.py | 2 +- repl/pom.xml | 2 +- resource-managers/kubernetes/core/pom.xml | 2 +- resource-managers/kubernetes/integration-tests/pom.xml | 2 +- resource-managers/mesos/pom.xml | 2 +- resource-managers/yarn/pom.xml | 2 +- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- sql/hive-thriftserver/pom.xml | 2 +- sql/hive/pom.xml | 2 +- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- 44 files changed, 49 insertions(+), 44 deletions(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index f52d785e05cdd..96090bed6899b 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,6 +1,6 @@ Package: SparkR Type: Package -Version: 2.4.0 +Version: 2.5.0 Title: R Frontend for Apache Spark Description: Provides an R Frontend for Apache Spark. Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), diff --git a/assembly/pom.xml b/assembly/pom.xml index 9608c96fd5369..d431d3f8caf28 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../pom.xml diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml index 8c148359c3029..cdb4359d17a87 100644 --- a/common/kvstore/pom.xml +++ b/common/kvstore/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 8ca7733507f1b..5a8e0eb46cf91 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 05335df61a664..24c8675cbd7d0 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index 564e6583c909e..40b9267a335e3 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 2f04abe8c7e88..c2adbf04563b5 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/common/tags/pom.xml b/common/tags/pom.xml index ba127408e1c59..d0a3c7a61d1cf 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index 1527854730394..b0c52d9468226 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 5fa3a86de6b01..d881dc1f140ec 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../pom.xml diff --git a/docs/_config.yml b/docs/_config.yml index 095fadb93fe5d..75e54c6ecffcf 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 2.4.0-SNAPSHOT -SPARK_VERSION_SHORT: 2.4.0 +SPARK_VERSION: 2.5.0-SNAPSHOT +SPARK_VERSION_SHORT: 2.5.0 SCALA_BINARY_VERSION: "2.11" SCALA_VERSION: "2.11.8" MESOS_VERSION: 1.0.0 diff --git a/examples/pom.xml b/examples/pom.xml index 868110b8e35ef..6a736d599fabe 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../pom.xml diff --git a/external/avro/pom.xml b/external/avro/pom.xml index 8f118ba48201b..ee31741524cf8 100644 --- a/external/avro/pom.xml +++ b/external/avro/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 431339d412194..bd3e4adf99186 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index 7cd1ec4c9c09a..36bc02a742cd2 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index f810aa80e8780..68059f0e121cc 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 498e88f665eb5..4335732034737 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index f80f8e3a0183d..668fbcd1103cf 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 8588e8be052eb..5b8baedb71d1c 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index a97fd35bfbb73..5aabef63a891c 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml index 6be17a81f3fed..4a202861bf380 100644 --- a/external/kafka-0-8-assembly/pom.xml +++ b/external/kafka-0-8-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8/pom.xml b/external/kafka-0-8/pom.xml index 6d1c4789f382d..0d5e609ded3e8 100644 --- a/external/kafka-0-8/pom.xml +++ b/external/kafka-0-8/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index 68fded515626b..81eea530fa2bc 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index 4915893965595..735e78192f5aa 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index 027157e53d511..07bc7fa722fb0 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 0f5dc548600b2..c0f373d71053d 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../pom.xml diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index 2c39a7df0146e..d20bdd0d68e93 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../pom.xml diff --git a/launcher/pom.xml b/launcher/pom.xml index 912eb6b6d2a08..7c32aade17c71 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../pom.xml diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index 53286fe93478d..e19c09f287656 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index f07d7f24fd312..e4ac94aba462d 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../pom.xml diff --git a/pom.xml b/pom.xml index 05e3b05613efd..71d5f944ee60e 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 55dc2b81cfe2f..f4c34a140e9ca 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,6 +34,10 @@ import com.typesafe.tools.mima.core.ProblemFilters._ */ object MimaExcludes { + // Exclude rules for 2.5.x + lazy val v25excludes = v24excludes ++ Seq( + ) + // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( // [SPARK-23429][CORE] Add executor memory metrics to heartbeat and expose in executors REST API @@ -1202,6 +1206,7 @@ object MimaExcludes { } def excludes(version: String) = version match { + case v if v.startsWith("2.5") => v25excludes case v if v.startsWith("2.4") => v24excludes case v if v.startsWith("2.3") => v23excludes case v if v.startsWith("2.2") => v22excludes diff --git a/python/pyspark/version.py b/python/pyspark/version.py index b9c2c4ced71d5..6bd00c59d506a 100644 --- a/python/pyspark/version.py +++ b/python/pyspark/version.py @@ -16,4 +16,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.4.0.dev0" +__version__ = "2.5.0.dev0" diff --git a/repl/pom.xml b/repl/pom.xml index e8464a688336b..17121216a021d 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../pom.xml diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index 920f0f6ebf2c8..06e522fa93f04 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../../pom.xml diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 614705c1ed668..43b7857a79400 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../../pom.xml diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index 3995d0afeb5f4..5674036ee7935 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index 37e25ceecb883..8d5d6ab5a3f5a 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 7d23637e28342..224c70ce24d62 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index ba17f5f33f2b6..f78126ef53077 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 9f247f9224c75..6c8e52d0afa6f 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index c55ba32fa458c..0bb6026910fbd 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index 4497e53b65984..90c8b974c376d 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 242219e29f50f..14ccae6f8f187 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.0-SNAPSHOT + 2.5.0-SNAPSHOT ../pom.xml From e06da95cd9423f55cdb154a2778b0bddf7be984c Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 15 Sep 2018 17:24:11 -0700 Subject: [PATCH 1618/2461] [SPARK-25425][SQL] Extra options should override session options in DataSource V2 ## What changes were proposed in this pull request? In the PR, I propose overriding session options by extra options in DataSource V2. Extra options are more specific and set via `.option()`, and should overwrite more generic session options. Entries from seconds map overwrites entries with the same key from the first map, for example: ```Scala scala> Map("option" -> false) ++ Map("option" -> true) res0: scala.collection.immutable.Map[String,Boolean] = Map(option -> true) ``` ## How was this patch tested? Added a test for checking which option is propagated to a data source in `load()`. Closes #22413 from MaxGekk/session-options. Lead-authored-by: Maxim Gekk Co-authored-by: Dongjoon Hyun Co-authored-by: Maxim Gekk Signed-off-by: Dongjoon Hyun --- .../apache/spark/sql/DataFrameReader.scala | 2 +- .../apache/spark/sql/DataFrameWriter.scala | 8 +++-- .../sql/sources/v2/DataSourceV2Suite.scala | 35 ++++++++++++++++++- .../sources/v2/SimpleWritableDataSource.scala | 6 +++- 4 files changed, 45 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index e6c2cba79841a..fe69f252d43e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -202,7 +202,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray) } Dataset.ofRows(sparkSession, DataSourceV2Relation.create( - ds, extraOptions.toMap ++ sessionOptions + pathsOption, + ds, sessionOptions ++ extraOptions.toMap + pathsOption, userSpecifiedSchema = userSpecifiedSchema)) } else { loadV1Source(paths: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index dfb8c4718550f..188fce72efac5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -241,10 +241,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val source = cls.newInstance().asInstanceOf[DataSourceV2] source match { case provider: BatchWriteSupportProvider => - val options = extraOptions ++ - DataSourceV2Utils.extractSessionConfigs(source, df.sparkSession.sessionState.conf) + val sessionOptions = DataSourceV2Utils.extractSessionConfigs( + source, + df.sparkSession.sessionState.conf) + val options = sessionOptions ++ extraOptions - val relation = DataSourceV2Relation.create(source, options.toMap) + val relation = DataSourceV2Relation.create(source, options) if (mode == SaveMode.Append) { runCommand(df.sparkSession, "save") { AppendData.byName(relation, df.logicalPlan) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index f6c3e0ce82e3f..7cc8abc9f0428 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources.v2 +import java.io.File + import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException @@ -317,6 +319,38 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { checkCanonicalizedOutput(df, 2, 2) checkCanonicalizedOutput(df.select('i), 2, 1) } + + test("SPARK-25425: extra options should override sessions options during reading") { + val prefix = "spark.datasource.userDefinedDataSource." + val optionName = "optionA" + withSQLConf(prefix + optionName -> "true") { + val df = spark + .read + .option(optionName, false) + .format(classOf[DataSourceV2WithSessionConfig].getName).load() + val options = df.queryExecution.optimizedPlan.collectFirst { + case d: DataSourceV2Relation => d.options + } + assert(options.get.get(optionName) == Some("false")) + } + } + + test("SPARK-25425: extra options should override sessions options during writing") { + withTempPath { path => + val sessionPath = path.getCanonicalPath + withSQLConf("spark.datasource.simpleWritableDataSource.path" -> sessionPath) { + withTempPath { file => + val optionPath = file.getCanonicalPath + val format = classOf[SimpleWritableDataSource].getName + + val df = Seq((1L, 2L)).toDF("i", "j") + df.write.format(format).option("path", optionPath).save() + assert(!new File(sessionPath).exists) + checkAnswer(spark.read.format(format).option("path", optionPath).load(), df) + } + } + } + } } @@ -385,7 +419,6 @@ class SimpleDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { } } - class AdvancedDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { class ReadSupport extends SimpleReadSupport { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 952241b0b6be5..a0f4404f46140 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -39,10 +39,14 @@ import org.apache.spark.util.SerializableConfiguration * Each job moves files from `target/_temporary/queryId/` to `target`. */ class SimpleWritableDataSource extends DataSourceV2 - with BatchReadSupportProvider with BatchWriteSupportProvider { + with BatchReadSupportProvider + with BatchWriteSupportProvider + with SessionConfigSupport { private val schema = new StructType().add("i", "long").add("j", "long") + override def keyPrefix: String = "simpleWritableDataSource" + class ReadSupport(path: String, conf: Configuration) extends SimpleReadSupport { override def fullSchema(): StructType = schema From fefaa3c30df2c56046370081cb51bfe68d26976b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 15 Sep 2018 17:48:39 -0700 Subject: [PATCH 1619/2461] [SPARK-25438][SQL][TEST] Fix FilterPushdownBenchmark to use the same memory assumption ## What changes were proposed in this pull request? This PR aims to fix three things in `FilterPushdownBenchmark`. **1. Use the same memory assumption.** The following configurations are used in ORC and Parquet. - Memory buffer for writing - parquet.block.size (default: 128MB) - orc.stripe.size (default: 64MB) - Compression chunk size - parquet.page.size (default: 1MB) - orc.compress.size (default: 256KB) SPARK-24692 used 1MB, the default value of `parquet.page.size`, for `parquet.block.size` and `orc.stripe.size`. But, it missed to match `orc.compress.size`. So, the current benchmark shows the result from ORC with 256KB memory for compression and Parquet with 1MB. To compare correctly, we need to be consistent. **2. Dictionary encoding should not be enforced for all cases.** SPARK-24206 enforced dictionary encoding for all test cases. This PR recovers the default behavior in general and enforces dictionary encoding only in case of `prepareStringDictTable`. **3. Generate test result on AWS r3.xlarge** SPARK-24206 generated the result on AWS in order to reproduce and compare easily. This PR also aims to update the result on the same machine again in the same reason. Specifically, AWS r3.xlarge with Instance Store is used. ## How was this patch tested? Manual. Enable the test cases and run `FilterPushdownBenchmark` on `AWS r3.xlarge`. It takes about 4 hours 15 minutes. Closes #22427 from dongjoon-hyun/SPARK-25438. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../FilterPushdownBenchmark-results.txt | 912 ++++++++---------- .../benchmark/FilterPushdownBenchmark.scala | 11 +- 2 files changed, 428 insertions(+), 495 deletions(-) diff --git a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt index a75a15c99328a..e680ddff53dd1 100644 --- a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt +++ b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt @@ -2,737 +2,669 @@ Pushdown for many distinct value case ================================================================================================ -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz - +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 0 string row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 8970 / 9122 1.8 570.3 1.0X -Parquet Vectorized (Pushdown) 471 / 491 33.4 30.0 19.0X -Native ORC Vectorized 7661 / 7853 2.1 487.0 1.2X -Native ORC Vectorized (Pushdown) 1134 / 1161 13.9 72.1 7.9X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 11405 / 11485 1.4 725.1 1.0X +Parquet Vectorized (Pushdown) 675 / 690 23.3 42.9 16.9X +Native ORC Vectorized 7127 / 7170 2.2 453.1 1.6X +Native ORC Vectorized (Pushdown) 519 / 541 30.3 33.0 22.0X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 0 string row ('7864320' < value < '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 9246 / 9297 1.7 587.8 1.0X -Parquet Vectorized (Pushdown) 480 / 488 32.8 30.5 19.3X -Native ORC Vectorized 7838 / 7850 2.0 498.3 1.2X -Native ORC Vectorized (Pushdown) 1054 / 1118 14.9 67.0 8.8X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 11457 / 11473 1.4 728.4 1.0X +Parquet Vectorized (Pushdown) 656 / 686 24.0 41.7 17.5X +Native ORC Vectorized 7328 / 7342 2.1 465.9 1.6X +Native ORC Vectorized (Pushdown) 539 / 565 29.2 34.2 21.3X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 1 string row (value = '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 8989 / 9100 1.7 571.5 1.0X -Parquet Vectorized (Pushdown) 448 / 467 35.1 28.5 20.1X -Native ORC Vectorized 7680 / 7768 2.0 488.3 1.2X -Native ORC Vectorized (Pushdown) 1067 / 1118 14.7 67.8 8.4X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 11878 / 11888 1.3 755.2 1.0X +Parquet Vectorized (Pushdown) 630 / 654 25.0 40.1 18.9X +Native ORC Vectorized 7342 / 7362 2.1 466.8 1.6X +Native ORC Vectorized (Pushdown) 519 / 537 30.3 33.0 22.9X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 1 string row (value <=> '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 9115 / 9266 1.7 579.5 1.0X -Parquet Vectorized (Pushdown) 466 / 492 33.7 29.7 19.5X -Native ORC Vectorized 7800 / 7914 2.0 495.9 1.2X -Native ORC Vectorized (Pushdown) 1075 / 1102 14.6 68.4 8.5X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 11423 / 11440 1.4 726.2 1.0X +Parquet Vectorized (Pushdown) 625 / 643 25.2 39.7 18.3X +Native ORC Vectorized 7315 / 7335 2.2 465.1 1.6X +Native ORC Vectorized (Pushdown) 507 / 520 31.0 32.2 22.5X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 1 string row ('7864320' <= value <= '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 9099 / 9237 1.7 578.5 1.0X -Parquet Vectorized (Pushdown) 462 / 475 34.1 29.3 19.7X -Native ORC Vectorized 7847 / 7925 2.0 498.9 1.2X -Native ORC Vectorized (Pushdown) 1078 / 1114 14.6 68.5 8.4X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 11440 / 11478 1.4 727.3 1.0X +Parquet Vectorized (Pushdown) 634 / 652 24.8 40.3 18.0X +Native ORC Vectorized 7311 / 7324 2.2 464.8 1.6X +Native ORC Vectorized (Pushdown) 517 / 548 30.4 32.8 22.1X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select all string rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 19303 / 19547 0.8 1227.3 1.0X -Parquet Vectorized (Pushdown) 19924 / 20089 0.8 1266.7 1.0X -Native ORC Vectorized 18725 / 19079 0.8 1190.5 1.0X -Native ORC Vectorized (Pushdown) 19310 / 19492 0.8 1227.7 1.0X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 20750 / 20872 0.8 1319.3 1.0X +Parquet Vectorized (Pushdown) 21002 / 21032 0.7 1335.3 1.0X +Native ORC Vectorized 16714 / 16742 0.9 1062.6 1.2X +Native ORC Vectorized (Pushdown) 16926 / 16965 0.9 1076.1 1.2X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 0 int row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 8117 / 8323 1.9 516.1 1.0X -Parquet Vectorized (Pushdown) 484 / 494 32.5 30.8 16.8X -Native ORC Vectorized 6811 / 7036 2.3 433.0 1.2X -Native ORC Vectorized (Pushdown) 1061 / 1082 14.8 67.5 7.6X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10510 / 10532 1.5 668.2 1.0X +Parquet Vectorized (Pushdown) 642 / 665 24.5 40.8 16.4X +Native ORC Vectorized 6609 / 6618 2.4 420.2 1.6X +Native ORC Vectorized (Pushdown) 502 / 512 31.4 31.9 21.0X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 0 int row (7864320 < value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 8105 / 8140 1.9 515.3 1.0X -Parquet Vectorized (Pushdown) 478 / 505 32.9 30.4 17.0X -Native ORC Vectorized 6914 / 7211 2.3 439.6 1.2X -Native ORC Vectorized (Pushdown) 1044 / 1064 15.1 66.4 7.8X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10505 / 10514 1.5 667.9 1.0X +Parquet Vectorized (Pushdown) 659 / 673 23.9 41.9 15.9X +Native ORC Vectorized 6634 / 6641 2.4 421.8 1.6X +Native ORC Vectorized (Pushdown) 513 / 526 30.7 32.6 20.5X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 1 int row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7983 / 8116 2.0 507.6 1.0X -Parquet Vectorized (Pushdown) 464 / 487 33.9 29.5 17.2X -Native ORC Vectorized 6703 / 6774 2.3 426.1 1.2X -Native ORC Vectorized (Pushdown) 1017 / 1058 15.5 64.6 7.9X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10555 / 10570 1.5 671.1 1.0X +Parquet Vectorized (Pushdown) 651 / 668 24.2 41.4 16.2X +Native ORC Vectorized 6721 / 6728 2.3 427.3 1.6X +Native ORC Vectorized (Pushdown) 508 / 519 31.0 32.3 20.8X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 1 int row (value <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7942 / 7983 2.0 504.9 1.0X -Parquet Vectorized (Pushdown) 468 / 479 33.6 29.7 17.0X -Native ORC Vectorized 6677 / 6779 2.4 424.5 1.2X -Native ORC Vectorized (Pushdown) 1021 / 1068 15.4 64.9 7.8X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10556 / 10566 1.5 671.1 1.0X +Parquet Vectorized (Pushdown) 647 / 654 24.3 41.1 16.3X +Native ORC Vectorized 6716 / 6728 2.3 427.0 1.6X +Native ORC Vectorized (Pushdown) 510 / 521 30.9 32.4 20.7X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 1 int row (7864320 <= value <= 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7909 / 7958 2.0 502.8 1.0X -Parquet Vectorized (Pushdown) 485 / 494 32.4 30.8 16.3X -Native ORC Vectorized 6751 / 6846 2.3 429.2 1.2X -Native ORC Vectorized (Pushdown) 1043 / 1077 15.1 66.3 7.6X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10556 / 10565 1.5 671.1 1.0X +Parquet Vectorized (Pushdown) 649 / 654 24.2 41.3 16.3X +Native ORC Vectorized 6700 / 6712 2.3 426.0 1.6X +Native ORC Vectorized (Pushdown) 509 / 520 30.9 32.3 20.8X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 1 int row (7864319 < value < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 8010 / 8033 2.0 509.2 1.0X -Parquet Vectorized (Pushdown) 472 / 489 33.3 30.0 17.0X -Native ORC Vectorized 6655 / 6808 2.4 423.1 1.2X -Native ORC Vectorized (Pushdown) 1015 / 1067 15.5 64.5 7.9X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10547 / 10566 1.5 670.5 1.0X +Parquet Vectorized (Pushdown) 649 / 653 24.2 41.3 16.3X +Native ORC Vectorized 6703 / 6713 2.3 426.2 1.6X +Native ORC Vectorized (Pushdown) 510 / 520 30.8 32.5 20.7X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 10% int rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 8983 / 9035 1.8 571.1 1.0X -Parquet Vectorized (Pushdown) 2204 / 2231 7.1 140.1 4.1X -Native ORC Vectorized 7864 / 8011 2.0 500.0 1.1X -Native ORC Vectorized (Pushdown) 2674 / 2789 5.9 170.0 3.4X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 11478 / 11525 1.4 729.7 1.0X +Parquet Vectorized (Pushdown) 2576 / 2587 6.1 163.8 4.5X +Native ORC Vectorized 7633 / 7657 2.1 485.3 1.5X +Native ORC Vectorized (Pushdown) 2076 / 2096 7.6 132.0 5.5X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 50% int rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 12723 / 12903 1.2 808.9 1.0X -Parquet Vectorized (Pushdown) 9112 / 9282 1.7 579.3 1.4X -Native ORC Vectorized 12090 / 12230 1.3 768.7 1.1X -Native ORC Vectorized (Pushdown) 9242 / 9372 1.7 587.6 1.4X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 14785 / 14802 1.1 940.0 1.0X +Parquet Vectorized (Pushdown) 9971 / 9977 1.6 633.9 1.5X +Native ORC Vectorized 11082 / 11107 1.4 704.6 1.3X +Native ORC Vectorized (Pushdown) 8061 / 8073 2.0 512.5 1.8X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 90% int rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 16453 / 16678 1.0 1046.1 1.0X -Parquet Vectorized (Pushdown) 15997 / 16262 1.0 1017.0 1.0X -Native ORC Vectorized 16652 / 17070 0.9 1058.7 1.0X -Native ORC Vectorized (Pushdown) 15843 / 16112 1.0 1007.2 1.0X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 18174 / 18214 0.9 1155.5 1.0X +Parquet Vectorized (Pushdown) 17387 / 17403 0.9 1105.5 1.0X +Native ORC Vectorized 14465 / 14492 1.1 919.7 1.3X +Native ORC Vectorized (Pushdown) 14024 / 14041 1.1 891.6 1.3X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select all int rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 17098 / 17254 0.9 1087.1 1.0X -Parquet Vectorized (Pushdown) 17302 / 17529 0.9 1100.1 1.0X -Native ORC Vectorized 16790 / 17098 0.9 1067.5 1.0X -Native ORC Vectorized (Pushdown) 17329 / 17914 0.9 1101.7 1.0X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 19004 / 19014 0.8 1208.2 1.0X +Parquet Vectorized (Pushdown) 19219 / 19232 0.8 1221.9 1.0X +Native ORC Vectorized 15266 / 15290 1.0 970.6 1.2X +Native ORC Vectorized (Pushdown) 15469 / 15482 1.0 983.5 1.2X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select all int rows (value > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 17088 / 17392 0.9 1086.4 1.0X -Parquet Vectorized (Pushdown) 17609 / 17863 0.9 1119.5 1.0X -Native ORC Vectorized 18334 / 69831 0.9 1165.7 0.9X -Native ORC Vectorized (Pushdown) 17465 / 17629 0.9 1110.4 1.0X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 19036 / 19052 0.8 1210.3 1.0X +Parquet Vectorized (Pushdown) 19287 / 19306 0.8 1226.2 1.0X +Native ORC Vectorized 15311 / 15371 1.0 973.5 1.2X +Native ORC Vectorized (Pushdown) 15517 / 15590 1.0 986.5 1.2X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select all int rows (value != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 16903 / 17233 0.9 1074.6 1.0X -Parquet Vectorized (Pushdown) 16945 / 17032 0.9 1077.3 1.0X -Native ORC Vectorized 16377 / 16762 1.0 1041.2 1.0X -Native ORC Vectorized (Pushdown) 16950 / 17212 0.9 1077.7 1.0X +Parquet Vectorized 19072 / 19102 0.8 1212.6 1.0X +Parquet Vectorized (Pushdown) 19288 / 19318 0.8 1226.3 1.0X +Native ORC Vectorized 15277 / 15293 1.0 971.3 1.2X +Native ORC Vectorized (Pushdown) 15479 / 15499 1.0 984.1 1.2X ================================================================================================ Pushdown for few distinct value case (use dictionary encoding) ================================================================================================ -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz - +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 0 distinct string row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7245 / 7322 2.2 460.7 1.0X -Parquet Vectorized (Pushdown) 378 / 389 41.6 24.0 19.2X -Native ORC Vectorized 6720 / 6778 2.3 427.2 1.1X -Native ORC Vectorized (Pushdown) 1009 / 1032 15.6 64.2 7.2X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10250 / 10274 1.5 651.7 1.0X +Parquet Vectorized (Pushdown) 571 / 576 27.5 36.3 17.9X +Native ORC Vectorized 8651 / 8660 1.8 550.0 1.2X +Native ORC Vectorized (Pushdown) 909 / 933 17.3 57.8 11.3X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 0 distinct string row ('100' < value < '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7627 / 7795 2.1 484.9 1.0X -Parquet Vectorized (Pushdown) 384 / 406 41.0 24.4 19.9X -Native ORC Vectorized 6724 / 7824 2.3 427.5 1.1X -Native ORC Vectorized (Pushdown) 968 / 986 16.3 61.5 7.9X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10420 / 10426 1.5 662.5 1.0X +Parquet Vectorized (Pushdown) 574 / 579 27.4 36.5 18.2X +Native ORC Vectorized 8973 / 8982 1.8 570.5 1.2X +Native ORC Vectorized (Pushdown) 916 / 955 17.2 58.2 11.4X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 1 distinct string row (value = '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7157 / 7534 2.2 455.0 1.0X -Parquet Vectorized (Pushdown) 542 / 565 29.0 34.5 13.2X -Native ORC Vectorized 6716 / 7214 2.3 427.0 1.1X -Native ORC Vectorized (Pushdown) 1212 / 1288 13.0 77.0 5.9X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10428 / 10441 1.5 663.0 1.0X +Parquet Vectorized (Pushdown) 789 / 809 19.9 50.2 13.2X +Native ORC Vectorized 9042 / 9055 1.7 574.9 1.2X +Native ORC Vectorized (Pushdown) 1130 / 1145 13.9 71.8 9.2X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 1 distinct string row (value <=> '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7368 / 7552 2.1 468.4 1.0X -Parquet Vectorized (Pushdown) 544 / 556 28.9 34.6 13.5X -Native ORC Vectorized 6740 / 6867 2.3 428.5 1.1X -Native ORC Vectorized (Pushdown) 1230 / 1426 12.8 78.2 6.0X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10402 / 10416 1.5 661.3 1.0X +Parquet Vectorized (Pushdown) 791 / 806 19.9 50.3 13.2X +Native ORC Vectorized 9042 / 9055 1.7 574.9 1.2X +Native ORC Vectorized (Pushdown) 1112 / 1145 14.1 70.7 9.4X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 1 distinct string row ('100' <= value <= '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7427 / 7734 2.1 472.2 1.0X -Parquet Vectorized (Pushdown) 556 / 568 28.3 35.4 13.3X -Native ORC Vectorized 6847 / 7059 2.3 435.3 1.1X -Native ORC Vectorized (Pushdown) 1226 / 1230 12.8 77.9 6.1X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10548 / 10563 1.5 670.6 1.0X +Parquet Vectorized (Pushdown) 790 / 796 19.9 50.2 13.4X +Native ORC Vectorized 9144 / 9153 1.7 581.3 1.2X +Native ORC Vectorized (Pushdown) 1117 / 1148 14.1 71.0 9.4X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select all distinct string rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 16998 / 17311 0.9 1080.7 1.0X -Parquet Vectorized (Pushdown) 16977 / 17250 0.9 1079.4 1.0X -Native ORC Vectorized 18447 / 19852 0.9 1172.8 0.9X -Native ORC Vectorized (Pushdown) 16614 / 17102 0.9 1056.3 1.0X +Parquet Vectorized 20445 / 20469 0.8 1299.8 1.0X +Parquet Vectorized (Pushdown) 20686 / 20699 0.8 1315.2 1.0X +Native ORC Vectorized 18851 / 18953 0.8 1198.5 1.1X +Native ORC Vectorized (Pushdown) 19255 / 19268 0.8 1224.2 1.1X ================================================================================================ Pushdown benchmark for StringStartsWith ================================================================================================ -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz - +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz StringStartsWith filter: (value like '10%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 9705 / 10814 1.6 617.0 1.0X -Parquet Vectorized (Pushdown) 3086 / 3574 5.1 196.2 3.1X -Native ORC Vectorized 10094 / 10695 1.6 641.8 1.0X -Native ORC Vectorized (Pushdown) 9611 / 9999 1.6 611.0 1.0X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 14265 / 15213 1.1 907.0 1.0X +Parquet Vectorized (Pushdown) 4228 / 4870 3.7 268.8 3.4X +Native ORC Vectorized 10116 / 10977 1.6 643.2 1.4X +Native ORC Vectorized (Pushdown) 10653 / 11376 1.5 677.3 1.3X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz StringStartsWith filter: (value like '1000%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 8016 / 8183 2.0 509.7 1.0X -Parquet Vectorized (Pushdown) 444 / 457 35.4 28.2 18.0X -Native ORC Vectorized 6970 / 7169 2.3 443.2 1.2X -Native ORC Vectorized (Pushdown) 7447 / 7503 2.1 473.5 1.1X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 11499 / 11539 1.4 731.1 1.0X +Parquet Vectorized (Pushdown) 669 / 672 23.5 42.5 17.2X +Native ORC Vectorized 7343 / 7363 2.1 466.8 1.6X +Native ORC Vectorized (Pushdown) 7559 / 7568 2.1 480.6 1.5X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz StringStartsWith filter: (value like '786432%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7908 / 8046 2.0 502.8 1.0X -Parquet Vectorized (Pushdown) 408 / 429 38.6 25.9 19.4X -Native ORC Vectorized 7021 / 7100 2.2 446.4 1.1X -Native ORC Vectorized (Pushdown) 7310 / 7490 2.2 464.8 1.1X +Parquet Vectorized 11463 / 11468 1.4 728.8 1.0X +Parquet Vectorized (Pushdown) 647 / 651 24.3 41.1 17.7X +Native ORC Vectorized 7322 / 7338 2.1 465.5 1.6X +Native ORC Vectorized (Pushdown) 7533 / 7544 2.1 478.9 1.5X ================================================================================================ Pushdown benchmark for decimal ================================================================================================ -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz - +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 1 decimal(9, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 4546 / 4743 3.5 289.0 1.0X -Parquet Vectorized (Pushdown) 161 / 175 98.0 10.2 28.3X -Native ORC Vectorized 5721 / 5842 2.7 363.7 0.8X -Native ORC Vectorized (Pushdown) 1019 / 1070 15.4 64.8 4.5X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 5543 / 5564 2.8 352.4 1.0X +Parquet Vectorized (Pushdown) 168 / 174 93.7 10.7 33.0X +Native ORC Vectorized 4992 / 5052 3.2 317.4 1.1X +Native ORC Vectorized (Pushdown) 840 / 850 18.7 53.4 6.6X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 10% decimal(9, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 6340 / 7236 2.5 403.1 1.0X -Parquet Vectorized (Pushdown) 3052 / 3164 5.2 194.1 2.1X -Native ORC Vectorized 8370 / 9214 1.9 532.1 0.8X -Native ORC Vectorized (Pushdown) 4137 / 4242 3.8 263.0 1.5X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 7312 / 7358 2.2 464.9 1.0X +Parquet Vectorized (Pushdown) 3008 / 3078 5.2 191.2 2.4X +Native ORC Vectorized 6775 / 6798 2.3 430.7 1.1X +Native ORC Vectorized (Pushdown) 6819 / 6832 2.3 433.5 1.1X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 50% decimal(9, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 12976 / 13249 1.2 825.0 1.0X -Parquet Vectorized (Pushdown) 12655 / 13570 1.2 804.6 1.0X -Native ORC Vectorized 15562 / 15950 1.0 989.4 0.8X -Native ORC Vectorized (Pushdown) 15042 / 15668 1.0 956.3 0.9X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 13232 / 13241 1.2 841.3 1.0X +Parquet Vectorized (Pushdown) 12555 / 12569 1.3 798.2 1.1X +Native ORC Vectorized 12597 / 12627 1.2 800.9 1.1X +Native ORC Vectorized (Pushdown) 12677 / 12711 1.2 806.0 1.0X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 90% decimal(9, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 14303 / 14616 1.1 909.3 1.0X -Parquet Vectorized (Pushdown) 14380 / 14649 1.1 914.3 1.0X -Native ORC Vectorized 16964 / 17358 0.9 1078.5 0.8X -Native ORC Vectorized (Pushdown) 17255 / 17874 0.9 1097.0 0.8X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 14725 / 14729 1.1 936.2 1.0X +Parquet Vectorized (Pushdown) 14781 / 14800 1.1 939.7 1.0X +Native ORC Vectorized 15360 / 15453 1.0 976.5 1.0X +Native ORC Vectorized (Pushdown) 15444 / 15466 1.0 981.9 1.0X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 1 decimal(18, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 4701 / 6416 3.3 298.9 1.0X -Parquet Vectorized (Pushdown) 128 / 164 122.8 8.1 36.7X -Native ORC Vectorized 5698 / 7904 2.8 362.3 0.8X -Native ORC Vectorized (Pushdown) 913 / 942 17.2 58.0 5.2X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 5746 / 5763 2.7 365.3 1.0X +Parquet Vectorized (Pushdown) 166 / 169 94.8 10.6 34.6X +Native ORC Vectorized 5007 / 5023 3.1 318.3 1.1X +Native ORC Vectorized (Pushdown) 2629 / 2640 6.0 167.1 2.2X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 10% decimal(18, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 5376 / 5461 2.9 341.8 1.0X -Parquet Vectorized (Pushdown) 1479 / 1543 10.6 94.0 3.6X -Native ORC Vectorized 6640 / 6748 2.4 422.2 0.8X -Native ORC Vectorized (Pushdown) 2438 / 2479 6.5 155.0 2.2X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 6827 / 6864 2.3 434.0 1.0X +Parquet Vectorized (Pushdown) 1809 / 1827 8.7 115.0 3.8X +Native ORC Vectorized 6287 / 6296 2.5 399.7 1.1X +Native ORC Vectorized (Pushdown) 6364 / 6377 2.5 404.6 1.1X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 50% decimal(18, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 9224 / 9356 1.7 586.5 1.0X -Parquet Vectorized (Pushdown) 7172 / 7415 2.2 456.0 1.3X -Native ORC Vectorized 11017 / 11408 1.4 700.4 0.8X -Native ORC Vectorized (Pushdown) 8771 / 10218 1.8 557.7 1.1X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 11315 / 11342 1.4 719.4 1.0X +Parquet Vectorized (Pushdown) 8431 / 8450 1.9 536.0 1.3X +Native ORC Vectorized 11591 / 11611 1.4 736.9 1.0X +Native ORC Vectorized (Pushdown) 11424 / 11475 1.4 726.3 1.0X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 90% decimal(18, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 13933 / 15990 1.1 885.8 1.0X -Parquet Vectorized (Pushdown) 12683 / 12942 1.2 806.4 1.1X -Native ORC Vectorized 16344 / 20196 1.0 1039.1 0.9X -Native ORC Vectorized (Pushdown) 15162 / 16627 1.0 964.0 0.9X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 15703 / 15712 1.0 998.4 1.0X +Parquet Vectorized (Pushdown) 14982 / 15009 1.0 952.5 1.0X +Native ORC Vectorized 16887 / 16955 0.9 1073.7 0.9X +Native ORC Vectorized (Pushdown) 16518 / 16530 1.0 1050.2 1.0X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 1 decimal(38, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7102 / 8282 2.2 451.5 1.0X -Parquet Vectorized (Pushdown) 124 / 150 126.4 7.9 57.1X -Native ORC Vectorized 5811 / 6883 2.7 369.5 1.2X -Native ORC Vectorized (Pushdown) 1121 / 1502 14.0 71.3 6.3X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 8101 / 8130 1.9 515.1 1.0X +Parquet Vectorized (Pushdown) 184 / 187 85.6 11.7 44.1X +Native ORC Vectorized 4998 / 5027 3.1 317.8 1.6X +Native ORC Vectorized (Pushdown) 165 / 168 95.6 10.5 49.2X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 10% decimal(38, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 6894 / 7562 2.3 438.3 1.0X -Parquet Vectorized (Pushdown) 1863 / 1980 8.4 118.4 3.7X -Native ORC Vectorized 6812 / 6848 2.3 433.1 1.0X -Native ORC Vectorized (Pushdown) 2511 / 2598 6.3 159.7 2.7X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 9405 / 9447 1.7 597.9 1.0X +Parquet Vectorized (Pushdown) 2269 / 2275 6.9 144.2 4.1X +Native ORC Vectorized 6167 / 6203 2.6 392.1 1.5X +Native ORC Vectorized (Pushdown) 1783 / 1787 8.8 113.3 5.3X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 50% decimal(38, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 11732 / 12183 1.3 745.9 1.0X -Parquet Vectorized (Pushdown) 8912 / 9945 1.8 566.6 1.3X -Native ORC Vectorized 11499 / 12387 1.4 731.1 1.0X -Native ORC Vectorized (Pushdown) 9328 / 9382 1.7 593.1 1.3X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 14700 / 14707 1.1 934.6 1.0X +Parquet Vectorized (Pushdown) 10699 / 10712 1.5 680.2 1.4X +Native ORC Vectorized 10687 / 10703 1.5 679.5 1.4X +Native ORC Vectorized (Pushdown) 8364 / 8415 1.9 531.8 1.8X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 90% decimal(38, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 16272 / 16328 1.0 1034.6 1.0X -Parquet Vectorized (Pushdown) 15714 / 18100 1.0 999.1 1.0X -Native ORC Vectorized 16539 / 18897 1.0 1051.5 1.0X -Native ORC Vectorized (Pushdown) 16328 / 17306 1.0 1038.1 1.0X +Parquet Vectorized 19780 / 19894 0.8 1257.6 1.0X +Parquet Vectorized (Pushdown) 19003 / 19025 0.8 1208.1 1.0X +Native ORC Vectorized 15385 / 15404 1.0 978.2 1.3X +Native ORC Vectorized (Pushdown) 15032 / 15060 1.0 955.7 1.3X ================================================================================================ Pushdown benchmark for InSet -> InFilters ================================================================================================ -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz - +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz InSet -> InFilters (values count: 5, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7993 / 8104 2.0 508.2 1.0X -Parquet Vectorized (Pushdown) 507 / 532 31.0 32.2 15.8X -Native ORC Vectorized 6922 / 7163 2.3 440.1 1.2X -Native ORC Vectorized (Pushdown) 1017 / 1058 15.5 64.6 7.9X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10521 / 10534 1.5 668.9 1.0X +Parquet Vectorized (Pushdown) 677 / 691 23.2 43.1 15.5X +Native ORC Vectorized 6768 / 6776 2.3 430.3 1.6X +Native ORC Vectorized (Pushdown) 501 / 512 31.4 31.8 21.0X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz InSet -> InFilters (values count: 5, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7855 / 7963 2.0 499.4 1.0X -Parquet Vectorized (Pushdown) 503 / 516 31.3 32.0 15.6X -Native ORC Vectorized 6825 / 6954 2.3 433.9 1.2X -Native ORC Vectorized (Pushdown) 1019 / 1044 15.4 64.8 7.7X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10531 / 10538 1.5 669.5 1.0X +Parquet Vectorized (Pushdown) 677 / 718 23.2 43.0 15.6X +Native ORC Vectorized 6765 / 6773 2.3 430.1 1.6X +Native ORC Vectorized (Pushdown) 499 / 507 31.5 31.7 21.1X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz InSet -> InFilters (values count: 5, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7858 / 7928 2.0 499.6 1.0X -Parquet Vectorized (Pushdown) 490 / 519 32.1 31.1 16.0X -Native ORC Vectorized 7079 / 7966 2.2 450.1 1.1X -Native ORC Vectorized (Pushdown) 1276 / 1673 12.3 81.1 6.2X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10540 / 10553 1.5 670.1 1.0X +Parquet Vectorized (Pushdown) 678 / 710 23.2 43.1 15.5X +Native ORC Vectorized 6787 / 6794 2.3 431.5 1.6X +Native ORC Vectorized (Pushdown) 501 / 509 31.4 31.9 21.0X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz InSet -> InFilters (values count: 10, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 8007 / 11155 2.0 509.0 1.0X -Parquet Vectorized (Pushdown) 519 / 540 30.3 33.0 15.4X -Native ORC Vectorized 6848 / 7072 2.3 435.4 1.2X -Native ORC Vectorized (Pushdown) 1026 / 1050 15.3 65.2 7.8X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10551 / 10559 1.5 670.8 1.0X +Parquet Vectorized (Pushdown) 703 / 708 22.4 44.7 15.0X +Native ORC Vectorized 6791 / 6802 2.3 431.7 1.6X +Native ORC Vectorized (Pushdown) 519 / 526 30.3 33.0 20.3X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz InSet -> InFilters (values count: 10, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7876 / 7956 2.0 500.7 1.0X -Parquet Vectorized (Pushdown) 521 / 535 30.2 33.1 15.1X -Native ORC Vectorized 7051 / 7368 2.2 448.3 1.1X -Native ORC Vectorized (Pushdown) 1014 / 1035 15.5 64.5 7.8X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10561 / 10565 1.5 671.4 1.0X +Parquet Vectorized (Pushdown) 711 / 716 22.1 45.2 14.9X +Native ORC Vectorized 6791 / 6806 2.3 431.8 1.6X +Native ORC Vectorized (Pushdown) 529 / 537 29.8 33.6 20.0X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz InSet -> InFilters (values count: 10, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7897 / 8229 2.0 502.1 1.0X -Parquet Vectorized (Pushdown) 513 / 530 30.7 32.6 15.4X -Native ORC Vectorized 6730 / 6990 2.3 427.9 1.2X -Native ORC Vectorized (Pushdown) 1003 / 1036 15.7 63.8 7.9X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10572 / 10590 1.5 672.1 1.0X +Parquet Vectorized (Pushdown) 713 / 716 22.1 45.3 14.8X +Native ORC Vectorized 6808 / 6815 2.3 432.9 1.6X +Native ORC Vectorized (Pushdown) 530 / 541 29.7 33.7 19.9X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz InSet -> InFilters (values count: 50, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7967 / 8175 2.0 506.5 1.0X -Parquet Vectorized (Pushdown) 8155 / 8434 1.9 518.5 1.0X -Native ORC Vectorized 7002 / 7107 2.2 445.2 1.1X -Native ORC Vectorized (Pushdown) 1092 / 1139 14.4 69.4 7.3X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10871 / 10882 1.4 691.2 1.0X +Parquet Vectorized (Pushdown) 11104 / 11110 1.4 706.0 1.0X +Native ORC Vectorized 7088 / 7104 2.2 450.7 1.5X +Native ORC Vectorized (Pushdown) 665 / 677 23.6 42.3 16.3X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz InSet -> InFilters (values count: 50, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 8032 / 8122 2.0 510.7 1.0X -Parquet Vectorized (Pushdown) 8141 / 8908 1.9 517.6 1.0X -Native ORC Vectorized 7140 / 7387 2.2 454.0 1.1X -Native ORC Vectorized (Pushdown) 1156 / 1220 13.6 73.5 6.9X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10861 / 10867 1.4 690.5 1.0X +Parquet Vectorized (Pushdown) 11094 / 11099 1.4 705.3 1.0X +Native ORC Vectorized 7075 / 7092 2.2 449.8 1.5X +Native ORC Vectorized (Pushdown) 718 / 733 21.9 45.6 15.1X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz InSet -> InFilters (values count: 50, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 8088 / 8350 1.9 514.2 1.0X -Parquet Vectorized (Pushdown) 8629 / 8702 1.8 548.6 0.9X -Native ORC Vectorized 7480 / 7886 2.1 475.6 1.1X -Native ORC Vectorized (Pushdown) 1106 / 1145 14.2 70.3 7.3X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10868 / 10887 1.4 691.0 1.0X +Parquet Vectorized (Pushdown) 11100 / 11106 1.4 705.7 1.0X +Native ORC Vectorized 7087 / 7093 2.2 450.6 1.5X +Native ORC Vectorized (Pushdown) 712 / 731 22.1 45.3 15.3X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz InSet -> InFilters (values count: 100, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 8028 / 8165 2.0 510.4 1.0X -Parquet Vectorized (Pushdown) 8349 / 8674 1.9 530.8 1.0X -Native ORC Vectorized 7107 / 7354 2.2 451.8 1.1X -Native ORC Vectorized (Pushdown) 1175 / 1207 13.4 74.7 6.8X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10850 / 10888 1.4 689.8 1.0X +Parquet Vectorized (Pushdown) 11086 / 11105 1.4 704.9 1.0X +Native ORC Vectorized 7090 / 7101 2.2 450.8 1.5X +Native ORC Vectorized (Pushdown) 867 / 882 18.1 55.1 12.5X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz InSet -> InFilters (values count: 100, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 8041 / 8195 2.0 511.2 1.0X -Parquet Vectorized (Pushdown) 8466 / 8604 1.9 538.2 0.9X -Native ORC Vectorized 7116 / 7286 2.2 452.4 1.1X -Native ORC Vectorized (Pushdown) 1197 / 1214 13.1 76.1 6.7X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10816 / 10819 1.5 687.7 1.0X +Parquet Vectorized (Pushdown) 11052 / 11059 1.4 702.7 1.0X +Native ORC Vectorized 7037 / 7044 2.2 447.4 1.5X +Native ORC Vectorized (Pushdown) 919 / 931 17.1 58.4 11.8X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz InSet -> InFilters (values count: 100, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7998 / 8311 2.0 508.5 1.0X -Parquet Vectorized (Pushdown) 9366 / 11257 1.7 595.5 0.9X -Native ORC Vectorized 7856 / 9273 2.0 499.5 1.0X -Native ORC Vectorized (Pushdown) 1350 / 1747 11.7 85.8 5.9X +Parquet Vectorized 10807 / 10815 1.5 687.1 1.0X +Parquet Vectorized (Pushdown) 11047 / 11054 1.4 702.4 1.0X +Native ORC Vectorized 7042 / 7047 2.2 447.7 1.5X +Native ORC Vectorized (Pushdown) 950 / 961 16.6 60.4 11.4X ================================================================================================ Pushdown benchmark for tinyint ================================================================================================ -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz - +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 1 tinyint row (value = CAST(63 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 3461 / 3997 4.5 220.1 1.0X -Parquet Vectorized (Pushdown) 270 / 315 58.4 17.1 12.8X -Native ORC Vectorized 4107 / 5372 3.8 261.1 0.8X -Native ORC Vectorized (Pushdown) 778 / 1553 20.2 49.5 4.4X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 6034 / 6048 2.6 383.6 1.0X +Parquet Vectorized (Pushdown) 333 / 344 47.2 21.2 18.1X +Native ORC Vectorized 3240 / 3307 4.9 206.0 1.9X +Native ORC Vectorized (Pushdown) 330 / 341 47.6 21.0 18.3X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 10% tinyint rows (value < CAST(12 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 4771 / 6655 3.3 303.3 1.0X -Parquet Vectorized (Pushdown) 1322 / 1606 11.9 84.0 3.6X -Native ORC Vectorized 4437 / 4572 3.5 282.1 1.1X -Native ORC Vectorized (Pushdown) 1781 / 1976 8.8 113.2 2.7X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 6759 / 6800 2.3 429.7 1.0X +Parquet Vectorized (Pushdown) 1533 / 1537 10.3 97.5 4.4X +Native ORC Vectorized 3863 / 3874 4.1 245.6 1.7X +Native ORC Vectorized (Pushdown) 1235 / 1248 12.7 78.5 5.5X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 50% tinyint rows (value < CAST(63 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7433 / 7752 2.1 472.6 1.0X -Parquet Vectorized (Pushdown) 5863 / 5913 2.7 372.8 1.3X -Native ORC Vectorized 7986 / 8084 2.0 507.7 0.9X -Native ORC Vectorized (Pushdown) 6522 / 6608 2.4 414.6 1.1X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10247 / 10289 1.5 651.5 1.0X +Parquet Vectorized (Pushdown) 7430 / 7453 2.1 472.4 1.4X +Native ORC Vectorized 6995 / 7009 2.2 444.7 1.5X +Native ORC Vectorized (Pushdown) 5561 / 5571 2.8 353.6 1.8X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 90% tinyint rows (value < CAST(114 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 11190 / 11519 1.4 711.4 1.0X -Parquet Vectorized (Pushdown) 10861 / 11206 1.4 690.5 1.0X -Native ORC Vectorized 11622 / 12196 1.4 738.9 1.0X -Native ORC Vectorized (Pushdown) 11377 / 11654 1.4 723.3 1.0X +Parquet Vectorized 13949 / 13991 1.1 886.9 1.0X +Parquet Vectorized (Pushdown) 13486 / 13511 1.2 857.4 1.0X +Native ORC Vectorized 10149 / 10186 1.5 645.3 1.4X +Native ORC Vectorized (Pushdown) 9889 / 9905 1.6 628.7 1.4X ================================================================================================ Pushdown benchmark for Timestamp ================================================================================================ -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz - +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 1 timestamp stored as INT96 row (value = CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 4784 / 4956 3.3 304.2 1.0X -Parquet Vectorized (Pushdown) 4838 / 4917 3.3 307.6 1.0X -Native ORC Vectorized 3923 / 4173 4.0 249.4 1.2X -Native ORC Vectorized (Pushdown) 894 / 943 17.6 56.8 5.4X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 6307 / 6310 2.5 401.0 1.0X +Parquet Vectorized (Pushdown) 6360 / 6397 2.5 404.3 1.0X +Native ORC Vectorized 2912 / 2917 5.4 185.1 2.2X +Native ORC Vectorized (Pushdown) 138 / 141 114.4 8.7 45.9X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 10% timestamp stored as INT96 rows (value < CAST(1572864 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 5686 / 5901 2.8 361.5 1.0X -Parquet Vectorized (Pushdown) 5555 / 5895 2.8 353.2 1.0X -Native ORC Vectorized 4844 / 4957 3.2 308.0 1.2X -Native ORC Vectorized (Pushdown) 2141 / 2230 7.3 136.1 2.7X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 7225 / 7233 2.2 459.4 1.0X +Parquet Vectorized (Pushdown) 7250 / 7255 2.2 461.0 1.0X +Native ORC Vectorized 3772 / 3783 4.2 239.8 1.9X +Native ORC Vectorized (Pushdown) 1277 / 1282 12.3 81.2 5.7X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 50% timestamp stored as INT96 rows (value < CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 9100 / 9421 1.7 578.6 1.0X -Parquet Vectorized (Pushdown) 9122 / 9496 1.7 580.0 1.0X -Native ORC Vectorized 8365 / 8874 1.9 531.9 1.1X -Native ORC Vectorized (Pushdown) 7128 / 7376 2.2 453.2 1.3X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10952 / 10965 1.4 696.3 1.0X +Parquet Vectorized (Pushdown) 10985 / 10998 1.4 698.4 1.0X +Native ORC Vectorized 7178 / 7227 2.2 456.3 1.5X +Native ORC Vectorized (Pushdown) 5825 / 5830 2.7 370.3 1.9X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 90% timestamp stored as INT96 rows (value < CAST(14155776 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 12764 / 13120 1.2 811.5 1.0X -Parquet Vectorized (Pushdown) 12656 / 13003 1.2 804.7 1.0X -Native ORC Vectorized 13096 / 13233 1.2 832.6 1.0X -Native ORC Vectorized (Pushdown) 12710 / 15611 1.2 808.1 1.0X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 14560 / 14583 1.1 925.7 1.0X +Parquet Vectorized (Pushdown) 14608 / 14620 1.1 928.7 1.0X +Native ORC Vectorized 10601 / 10640 1.5 674.0 1.4X +Native ORC Vectorized (Pushdown) 10392 / 10406 1.5 660.7 1.4X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 1 timestamp stored as TIMESTAMP_MICROS row (value = CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 4381 / 4796 3.6 278.5 1.0X -Parquet Vectorized (Pushdown) 122 / 137 129.3 7.7 36.0X -Native ORC Vectorized 3913 / 3988 4.0 248.8 1.1X -Native ORC Vectorized (Pushdown) 905 / 945 17.4 57.6 4.8X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 5653 / 5658 2.8 359.4 1.0X +Parquet Vectorized (Pushdown) 165 / 169 95.1 10.5 34.2X +Native ORC Vectorized 2918 / 2921 5.4 185.5 1.9X +Native ORC Vectorized (Pushdown) 137 / 145 114.9 8.7 41.3X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 10% timestamp stored as TIMESTAMP_MICROS rows (value < CAST(1572864 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 5145 / 5184 3.1 327.1 1.0X -Parquet Vectorized (Pushdown) 1426 / 1519 11.0 90.7 3.6X -Native ORC Vectorized 4827 / 4901 3.3 306.9 1.1X -Native ORC Vectorized (Pushdown) 2133 / 2210 7.4 135.6 2.4X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 6540 / 6552 2.4 415.8 1.0X +Parquet Vectorized (Pushdown) 1610 / 1614 9.8 102.3 4.1X +Native ORC Vectorized 3775 / 3788 4.2 240.0 1.7X +Native ORC Vectorized (Pushdown) 1274 / 1277 12.3 81.0 5.1X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 50% timestamp stored as TIMESTAMP_MICROS rows (value < CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 9234 / 9516 1.7 587.1 1.0X -Parquet Vectorized (Pushdown) 6752 / 7046 2.3 429.3 1.4X -Native ORC Vectorized 8418 / 8998 1.9 535.2 1.1X -Native ORC Vectorized (Pushdown) 7199 / 7314 2.2 457.7 1.3X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10259 / 10278 1.5 652.3 1.0X +Parquet Vectorized (Pushdown) 7591 / 7601 2.1 482.6 1.4X +Native ORC Vectorized 7185 / 7194 2.2 456.8 1.4X +Native ORC Vectorized (Pushdown) 5828 / 5843 2.7 370.6 1.8X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 90% timestamp stored as TIMESTAMP_MICROS rows (value < CAST(14155776 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 12414 / 12458 1.3 789.2 1.0X -Parquet Vectorized (Pushdown) 12094 / 12249 1.3 768.9 1.0X -Native ORC Vectorized 12198 / 13755 1.3 775.5 1.0X -Native ORC Vectorized (Pushdown) 12205 / 12431 1.3 776.0 1.0X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 13850 / 13868 1.1 880.5 1.0X +Parquet Vectorized (Pushdown) 13433 / 13450 1.2 854.0 1.0X +Native ORC Vectorized 10635 / 10669 1.5 676.1 1.3X +Native ORC Vectorized (Pushdown) 10437 / 10448 1.5 663.6 1.3X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 1 timestamp stored as TIMESTAMP_MILLIS row (value = CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 4369 / 4515 3.6 277.8 1.0X -Parquet Vectorized (Pushdown) 116 / 125 136.2 7.3 37.8X -Native ORC Vectorized 3965 / 4703 4.0 252.1 1.1X -Native ORC Vectorized (Pushdown) 892 / 1162 17.6 56.7 4.9X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 5884 / 5888 2.7 374.1 1.0X +Parquet Vectorized (Pushdown) 166 / 170 94.7 10.6 35.4X +Native ORC Vectorized 2913 / 2916 5.4 185.2 2.0X +Native ORC Vectorized (Pushdown) 136 / 144 115.4 8.7 43.2X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 10% timestamp stored as TIMESTAMP_MILLIS rows (value < CAST(1572864 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 5211 / 5409 3.0 331.3 1.0X -Parquet Vectorized (Pushdown) 1427 / 1438 11.0 90.7 3.7X -Native ORC Vectorized 4719 / 4883 3.3 300.1 1.1X -Native ORC Vectorized (Pushdown) 2191 / 2228 7.2 139.3 2.4X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 6763 / 6776 2.3 430.0 1.0X +Parquet Vectorized (Pushdown) 1634 / 1638 9.6 103.9 4.1X +Native ORC Vectorized 3777 / 3785 4.2 240.1 1.8X +Native ORC Vectorized (Pushdown) 1276 / 1279 12.3 81.2 5.3X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 50% timestamp stored as TIMESTAMP_MILLIS rows (value < CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 8716 / 8953 1.8 554.2 1.0X -Parquet Vectorized (Pushdown) 6632 / 6968 2.4 421.7 1.3X -Native ORC Vectorized 8376 / 9118 1.9 532.5 1.0X -Native ORC Vectorized (Pushdown) 7218 / 7609 2.2 458.9 1.2X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Parquet Vectorized 10460 / 10469 1.5 665.0 1.0X +Parquet Vectorized (Pushdown) 7689 / 7698 2.0 488.9 1.4X +Native ORC Vectorized 7190 / 7197 2.2 457.1 1.5X +Native ORC Vectorized (Pushdown) 5820 / 5834 2.7 370.0 1.8X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 90% timestamp stored as TIMESTAMP_MILLIS rows (value < CAST(14155776 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 12264 / 12452 1.3 779.7 1.0X -Parquet Vectorized (Pushdown) 11766 / 11927 1.3 748.0 1.0X -Native ORC Vectorized 12101 / 12301 1.3 769.3 1.0X -Native ORC Vectorized (Pushdown) 11983 / 12651 1.3 761.9 1.0X +Parquet Vectorized 14033 / 14039 1.1 892.2 1.0X +Parquet Vectorized (Pushdown) 13608 / 13636 1.2 865.2 1.0X +Native ORC Vectorized 10635 / 10686 1.5 676.2 1.3X +Native ORC Vectorized (Pushdown) 10420 / 10442 1.5 662.5 1.3X ================================================================================================ Pushdown benchmark with many filters ================================================================================================ -Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.13.6 -Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz - +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 1 row with 1 filters: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 158 / 182 0.0 158442969.0 1.0X -Parquet Vectorized (Pushdown) 150 / 158 0.0 149718289.0 1.1X -Native ORC Vectorized 141 / 148 0.0 141259852.0 1.1X -Native ORC Vectorized (Pushdown) 142 / 147 0.0 142016472.0 1.1X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.13.6 -Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz +Parquet Vectorized 319 / 323 0.0 318789986.0 1.0X +Parquet Vectorized (Pushdown) 323 / 347 0.0 322755287.0 1.0X +Native ORC Vectorized 316 / 336 0.0 315670745.0 1.0X +Native ORC Vectorized (Pushdown) 317 / 320 0.0 317392594.0 1.0X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 1 row with 250 filters: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 1013 / 1026 0.0 1013194322.0 1.0X -Parquet Vectorized (Pushdown) 1326 / 1332 0.0 1326301956.0 0.8X -Native ORC Vectorized 1005 / 1010 0.0 1005266379.0 1.0X -Native ORC Vectorized (Pushdown) 1068 / 1071 0.0 1067964993.0 0.9X - -Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.13.6 -Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz +Parquet Vectorized 2192 / 2218 0.0 2191883823.0 1.0X +Parquet Vectorized (Pushdown) 2675 / 2687 0.0 2675439029.0 0.8X +Native ORC Vectorized 2158 / 2162 0.0 2157646071.0 1.0X +Native ORC Vectorized (Pushdown) 2309 / 2326 0.0 2309096612.0 0.9X +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Select 1 row with 500 filters: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 3598 / 3614 0.0 3598001202.0 1.0X -Parquet Vectorized (Pushdown) 4282 / 4333 0.0 4281849770.0 0.8X -Native ORC Vectorized 3594 / 3619 0.0 3593551548.0 1.0X -Native ORC Vectorized (Pushdown) 3834 / 3840 0.0 3834240570.0 0.9X +Parquet Vectorized 6219 / 6248 0.0 6218727737.0 1.0X +Parquet Vectorized (Pushdown) 7376 / 7436 0.0 7375977710.0 0.8X +Native ORC Vectorized 6252 / 6279 0.0 6252473320.0 1.0X +Native ORC Vectorized (Pushdown) 6858 / 6876 0.0 6857854486.0 0.9X + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala index 8596abd1b4ff2..d6dfdec45a0e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala @@ -53,7 +53,8 @@ class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfter private val numRows = 1024 * 1024 * 15 private val width = 5 private val mid = numRows / 2 - private val blockSize = 1048576 + // For Parquet/ORC, we will use the same value for block size and compression size + private val blockSize = org.apache.parquet.hadoop.ParquetWriter.DEFAULT_PAGE_SIZE private val spark = SparkSession.builder().config(conf).getOrCreate() @@ -130,16 +131,16 @@ class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfter } val df = spark.range(numRows).selectExpr(selectExpr: _*).sort("value") - saveAsTable(df, dir) + saveAsTable(df, dir, true) } - private def saveAsTable(df: DataFrame, dir: File): Unit = { + private def saveAsTable(df: DataFrame, dir: File, useDictionary: Boolean = false): Unit = { val orcPath = dir.getCanonicalPath + "/orc" val parquetPath = dir.getCanonicalPath + "/parquet" - // To always turn on dictionary encoding, we set 1.0 at the threshold (the default is 0.8) df.write.mode("overwrite") - .option("orc.dictionary.key.threshold", 1.0) + .option("orc.dictionary.key.threshold", if (useDictionary) 1.0 else 0.8) + .option("orc.compress.size", blockSize) .option("orc.stripe.size", blockSize).orc(orcPath) spark.read.orc(orcPath).createOrReplaceTempView("orcTable") From 02c2963f895b9d78d7f6d9972cacec4ef55fa278 Mon Sep 17 00:00:00 2001 From: npoggi Date: Sat, 15 Sep 2018 20:06:08 -0700 Subject: [PATCH 1620/2461] [SPARK-25439][TESTS][SQL] Fixes TPCHQuerySuite datatype of customer.c_nationkey to BIGINT according to spec ## What changes were proposed in this pull request? Fixes TPCH DDL datatype of `customer.c_nationkey` from `STRING` to `BIGINT` according to spec and `nation.nationkey` in `TPCHQuerySuite.scala`. The rest of the keys are OK. Note, this will lead to **non-comparable previous results** to new runs involving the customer table. ## How was this patch tested? Manual tests Author: npoggi Closes #22430 from npoggi/SPARK-25439_Fix-TPCH-customer-c_nationkey. --- .../src/test/scala/org/apache/spark/sql/TPCHQuerySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCHQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCHQuerySuite.scala index e3e700529bba7..b32d95d0b286c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TPCHQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCHQuerySuite.scala @@ -69,7 +69,7 @@ class TPCHQuerySuite extends BenchmarkQueryTest { sql( """ |CREATE TABLE `customer` (`c_custkey` BIGINT, `c_name` STRING, `c_address` STRING, - |`c_nationkey` STRING, `c_phone` STRING, `c_acctbal` DECIMAL(10,0), + |`c_nationkey` BIGINT, `c_phone` STRING, `c_acctbal` DECIMAL(10,0), |`c_mktsegment` STRING, `c_comment` STRING) |USING parquet """.stripMargin) From bfcf7426057a964b3cee90089aab6c003addc4fb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 16 Sep 2018 04:14:19 +0000 Subject: [PATCH 1621/2461] [SPARK-24418][FOLLOWUP][DOC] Update docs to show Scala 2.11.12 ## What changes were proposed in this pull request? SPARK-24418 upgrades Scala to 2.11.12. This PR updates Scala version in docs. - https://spark.apache.org/docs/latest/quick-start.html#self-contained-applications (screenshot) ![screen1](https://user-images.githubusercontent.com/9700541/45590509-9c5f0400-b8ee-11e8-9293-e48d297db894.png) - https://spark.apache.org/docs/latest/rdd-programming-guide.html#working-with-key-value-pairs (Scala, Java) (These are hyperlink updates) - https://spark.apache.org/docs/latest/streaming-flume-integration.html#configuring-flume-1 (screenshot) ![screen2](https://user-images.githubusercontent.com/9700541/45590511-a123b800-b8ee-11e8-97a5-b7f2288229c2.png) ## How was this patch tested? Manual. ```bash $ cd docs $ SKIP_API=1 jekyll build ``` Closes #22431 from dongjoon-hyun/SPARK-24418. Authored-by: Dongjoon Hyun Signed-off-by: DB Tsai --- docs/_config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_config.yml b/docs/_config.yml index 75e54c6ecffcf..dfc1a73f4ac1c 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -17,7 +17,7 @@ include: SPARK_VERSION: 2.5.0-SNAPSHOT SPARK_VERSION_SHORT: 2.5.0 SCALA_BINARY_VERSION: "2.11" -SCALA_VERSION: "2.11.8" +SCALA_VERSION: "2.11.12" MESOS_VERSION: 1.0.0 SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK SPARK_GITHUB_URL: https://github.com/apache/spark From a1dd78255a3ae023820b2f245cd39f0c57a32fb1 Mon Sep 17 00:00:00 2001 From: Michael Chirico Date: Sun, 16 Sep 2018 12:57:44 -0700 Subject: [PATCH 1622/2461] [MINOR][DOCS] Axe deprecated doc refs Continuation of #22370. Summary of discussion there: There is some inconsistency in the R manual w.r.t. supercedent functions linking back to deprecated functions. - `createOrReplaceTempView` and `createTable` both link back to functions which are deprecated (`registerTempTable` and `createExternalTable`, respectively) - `sparkR.session` and `dropTempView` do _not_ link back to deprecated functions This PR takes the view that it is preferable _not_ to link back to deprecated functions, and removes these references from `?createOrReplaceTempView` and `?createTable`. As `registerTempTable` was included in the `SparkDataFrame functions` `family` of functions, other documentation pages which included a link to `?registerTempTable` will similarly be altered. Author: Michael Chirico Author: Michael Chirico Closes #22393 from MichaelChirico/axe_deprecated_doc_refs. --- R/pkg/R/DataFrame.R | 1 - R/pkg/R/catalog.R | 1 - 2 files changed, 2 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 4f2d4c7c002d4..458decaf4766f 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -503,7 +503,6 @@ setMethod("createOrReplaceTempView", #' @param x A SparkDataFrame #' @param tableName A character vector containing the name of the table #' -#' @family SparkDataFrame functions #' @seealso \link{createOrReplaceTempView} #' @rdname registerTempTable-deprecated #' @name registerTempTable diff --git a/R/pkg/R/catalog.R b/R/pkg/R/catalog.R index baf4d861fcf86..c2d0fc38786be 100644 --- a/R/pkg/R/catalog.R +++ b/R/pkg/R/catalog.R @@ -69,7 +69,6 @@ createExternalTable <- function(x, ...) { #' @param ... additional named parameters as options for the data source. #' @return A SparkDataFrame. #' @rdname createTable -#' @seealso \link{createExternalTable} #' @examples #'\dontrun{ #' sparkR.session() From 538e0478783160d8fab2dc76fd8fc7b469cb4e19 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 17 Sep 2018 11:07:51 +0800 Subject: [PATCH 1623/2461] [SPARK-22713][CORE][TEST][FOLLOWUP] Fix flaky ExternalAppendOnlyMapSuite due to timeout ## What changes were proposed in this pull request? SPARK-22713 uses [`eventually` with the default timeout `150ms`](https://github.com/apache/spark/pull/21369/files#diff-5bbb6a931b7e4d6a31e4938f51935682R462). It causes flakiness because it's executed once when GC is slow. ```scala eventually { System.gc() ... } ``` **Failures** ```scala org.scalatest.exceptions.TestFailedDueToTimeoutException: The code passed to eventually never returned normally. Attempted 1 times over 501.22261 milliseconds. Last failure message: tmpIsNull was false. ``` - master-test-sbt-hadoop-2.7 [4916](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.7/4916) [4907](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.7/4907) [4906](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.7/4906) - spark-master-test-sbt-hadoop-2.6 [4979](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.6/4979) [4974](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.6/4974) [4967](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.6/4967) [4966](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.6/4966) ## How was this patch tested? Pass the Jenkins. Closes #22432 from dongjoon-hyun/SPARK-22713. Authored-by: Dongjoon Hyun Signed-off-by: Wenchen Fan --- .../spark/util/collection/ExternalAppendOnlyMapSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 8a2f2ffe0acf1..cd25265784136 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.util.collection import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ import scala.ref.WeakReference import org.scalatest.Matchers @@ -457,7 +458,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite // https://github.com/scala/scala/blob/2.13.x/test/junit/scala/tools/testing/AssertUtil.scala // (lines 69-89) // assert(map.currentMap == null) - eventually { + eventually(timeout(5 seconds), interval(200 milliseconds)) { System.gc() // direct asserts introduced some macro generated code that held a reference to the map val tmpIsNull = null == underlyingMapRef.get.orNull From b66e14dc96011a83f5ea0df8708ecb02a154ed1d Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 17 Sep 2018 15:21:18 +0800 Subject: [PATCH 1624/2461] [SPARK-24685][BUILD][FOLLOWUP] Fix the nonexist profile name in release script ## What changes were proposed in this pull request? `without-hadoop` profile doesn't exist in Maven, instead the name should be `hadoop-provided`, this is a regression introduced by SPARK-24685. So here fix it. ## How was this patch tested? Local test. Closes #22434 from jerryshao/SPARK-24685-followup. Authored-by: jerryshao Signed-off-by: Wenchen Fan --- dev/create-release/release-build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 73610a3335910..ca066bed133d2 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -271,7 +271,7 @@ if [[ "$1" == "package" ]]; then BINARY_PKGS_ARGS["hadoop2.7"]="-Phadoop-2.7 $HIVE_PROFILES" if ! is_dry_run; then BINARY_PKGS_ARGS["hadoop2.6"]="-Phadoop-2.6 $HIVE_PROFILES" - BINARY_PKGS_ARGS["without-hadoop"]="-Pwithout-hadoop" + BINARY_PKGS_ARGS["without-hadoop"]="-Phadoop-provided" if [[ $SPARK_VERSION < "2.2." ]]; then BINARY_PKGS_ARGS["hadoop2.4"]="-Phadoop-2.4 $HIVE_PROFILES" BINARY_PKGS_ARGS["hadoop2.3"]="-Phadoop-2.3 $HIVE_PROFILES" From 619c949019feccd3fc2c9e58a841c655d05216f3 Mon Sep 17 00:00:00 2001 From: s71955 Date: Mon, 17 Sep 2018 19:22:27 +0800 Subject: [PATCH 1625/2461] [SPARK-23425][SQL][FOLLOWUP] Support wildcards in HDFS path for loadtable command. What changes were proposed in this pull request Updated the Migration guide for the behavior changes done in the JIRA issue SPARK-23425. How was this patch tested? Manually verified. Closes #22396 from sujith71955/master_newtest. Authored-by: s71955 Signed-off-by: Wenchen Fan --- docs/sql-programming-guide.md | 1 + .../spark/sql/hive/execution/SQLQuerySuite.scala | 15 +++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 9da7d64322eb6..e262987ab23de 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1898,6 +1898,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, File listing for compute statistics is done in parallel by default. This can be disabled by setting `spark.sql.parallelFileListingInStatsComputation.enabled` to `False`. - Since Spark 2.4, Metadata files (e.g. Parquet summary files) and temporary files are not counted as data files when calculating table size during Statistics computation. - Since Spark 2.4, empty strings are saved as quoted empty strings `""`. In version 2.3 and earlier, empty strings are equal to `null` values and do not reflect to any characters in saved CSV files. For example, the row of `"a", null, "", 1` was writted as `a,,,1`. Since Spark 2.4, the same row is saved as `a,,"",1`. To restore the previous behavior, set the CSV option `emptyValue` to empty (not quoted) string. + - Since Spark 2.4, The LOAD DATA command supports wildcard `?` and `*`, which match any one character, and zero or more characters, respectively. Example: `LOAD DATA INPATH '/tmp/folder*/'` or `LOAD DATA INPATH '/tmp/part-?'`. Special Characters like `space` also now work in paths. Example: `LOAD DATA INPATH '/tmp/folder name/'`. ## Upgrading From Spark SQL 2.3.0 to 2.3.1 and above diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 20c4c36c05091..e49aea267026e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1916,6 +1916,21 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("SPARK-23425 Test LOAD DATA LOCAL INPATH with space in file name") { + withTempDir { dir => + val path = dir.toURI.toString.stripSuffix("/") + val dirPath = dir.getAbsoluteFile + for (i <- 1 to 3) { + Files.write(s"$i", new File(dirPath, s"part-r-0000 $i"), StandardCharsets.UTF_8) + } + withTable("load_t") { + sql("CREATE TABLE load_t (a STRING)") + sql(s"LOAD DATA LOCAL INPATH '$path/part-r-0000 1' INTO TABLE load_t") + checkAnswer(sql("SELECT * FROM load_t"), Seq(Row("1"))) + } + } + } + test("Support wildcard character in folderlevel for LOAD DATA LOCAL INPATH") { withTempDir { dir => val path = dir.toURI.toString.stripSuffix("/") From 0dd61ec47df7078fd4f77d8c58ecf26c630c700e Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 17 Sep 2018 19:33:51 +0800 Subject: [PATCH 1626/2461] [SPARK-25427][SQL][TEST] Add BloomFilter creation test cases ## What changes were proposed in this pull request? Spark supports BloomFilter creation for ORC files. This PR aims to add test coverages to prevent accidental regressions like [SPARK-12417](https://issues.apache.org/jira/browse/SPARK-12417). ## How was this patch tested? Pass the Jenkins with newly added test cases. Closes #22418 from dongjoon-hyun/SPARK-25427. Authored-by: Dongjoon Hyun Signed-off-by: Wenchen Fan --- .../datasources/orc/OrcSourceSuite.scala | 69 +++++++++++++++++++ .../sql/hive/orc/HiveOrcSourceSuite.scala | 9 +++ 2 files changed, 78 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 02bfb7197ffc0..b6bb1d7ba4ce3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -21,7 +21,12 @@ import java.io.File import java.sql.Timestamp import java.util.Locale +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.orc.OrcConf.COMPRESS +import org.apache.orc.OrcFile +import org.apache.orc.OrcProto.Stream.Kind +import org.apache.orc.impl.RecordReaderImpl import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.Row @@ -50,6 +55,66 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { .createOrReplaceTempView("orc_temp_table") } + protected def testBloomFilterCreation(bloomFilterKind: Kind) { + val tableName = "bloomFilter" + + withTempDir { dir => + withTable(tableName) { + val sqlStatement = orcImp match { + case "native" => + s""" + |CREATE TABLE $tableName (a INT, b STRING) + |USING ORC + |OPTIONS ( + | path '${dir.toURI}', + | orc.bloom.filter.columns '*', + | orc.bloom.filter.fpp 0.1 + |) + """.stripMargin + case "hive" => + s""" + |CREATE TABLE $tableName (a INT, b STRING) + |STORED AS ORC + |LOCATION '${dir.toURI}' + |TBLPROPERTIES ( + | orc.bloom.filter.columns='*', + | orc.bloom.filter.fpp=0.1 + |) + """.stripMargin + case impl => + throw new UnsupportedOperationException(s"Unknown ORC implementation: $impl") + } + + sql(sqlStatement) + sql(s"INSERT INTO $tableName VALUES (1, 'str')") + + val partFiles = dir.listFiles() + .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_")) + assert(partFiles.length === 1) + + val orcFilePath = new Path(partFiles.head.getAbsolutePath) + val readerOptions = OrcFile.readerOptions(new Configuration()) + val reader = OrcFile.createReader(orcFilePath, readerOptions) + var recordReader: RecordReaderImpl = null + try { + recordReader = reader.rows.asInstanceOf[RecordReaderImpl] + + // BloomFilter array is created for all types; `struct`, int (`a`), string (`b`) + val sargColumns = Array(true, true, true) + val orcIndex = recordReader.readRowIndex(0, null, sargColumns) + + // Check the types and counts of bloom filters + assert(orcIndex.getBloomFilterKinds.forall(_ === bloomFilterKind)) + assert(orcIndex.getBloomFilterIndex.forall(_.getBloomFilterCount > 0)) + } finally { + if (recordReader != null) { + recordReader.close() + } + } + } + } + } + test("create temporary orc table") { checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_source"), Row(10)) @@ -215,4 +280,8 @@ class OrcSourceSuite extends OrcSuite with SharedSQLContext { |) """.stripMargin) } + + test("Check BloomFilter creation") { + testBloomFilterCreation(Kind.BLOOM_FILTER_UTF8) // After ORC-101 + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala index d84f9a3828207..c1ae2f6861cb8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.TestingUDT.{IntervalData, IntervalUDT} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.orc.OrcSuite +import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.types._ @@ -173,4 +174,12 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { assert(msg.contains("ORC data source does not support calendarinterval data type.")) } } + + test("Check BloomFilter creation") { + Seq(true, false).foreach { convertMetastore => + withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> s"$convertMetastore") { + testBloomFilterCreation(org.apache.orc.OrcProto.Stream.Kind.BLOOM_FILTER) // Before ORC-101 + } + } + } } From 8cf6fd1c2342949916fedb5a7f712177b22585fa Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 17 Sep 2018 20:40:42 +0800 Subject: [PATCH 1627/2461] [SPARK-25431][SQL][EXAMPLES] Fix function examples and the example results. ## What changes were proposed in this pull request? There are some mistakes in examples of newly added functions. Also the format of the example results are not unified. We should fix them. ## How was this patch tested? Manually executed the examples. Closes #22437 from ueshin/issues/SPARK-25431/fix_examples_2. Authored-by: Takuya UESHIN Signed-off-by: hyukjinkwon --- .../expressions/collectionOperations.scala | 37 ++++++++++--------- .../expressions/complexTypeCreator.scala | 2 +- .../expressions/higherOrderFunctions.scala | 32 ++++++++-------- 3 files changed, 36 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index ea6fcccddfd49..cc9edcfd41d02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -159,9 +159,9 @@ case class MapKeys(child: Expression) examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), array(2, 3, 4)); - [[1, 2], [2, 3], [3, 4]] + [{"0":1,"1":2},{"0":2,"1":3},{"0":3,"1":4}] > SELECT _FUNC_(array(1, 2), array(2, 3), array(3, 4)); - [[1, 2, 3], [2, 3, 4]] + [{"0":1,"1":2,"2":3},{"0":2,"1":3,"2":4}] """, since = "2.4.0") case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsInputTypes { @@ -348,7 +348,7 @@ case class MapValues(child: Expression) examples = """ Examples: > SELECT _FUNC_(map(1, 'a', 2, 'b')); - [(1,"a"),(2,"b")] + [{"key":1,"value":"a"},{"key":2,"value":"b"}] """, since = "2.4.0") case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes { @@ -516,7 +516,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp examples = """ Examples: > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd')); - [[1 -> "a"], [2 -> "b"], [2 -> "c"], [3 -> "d"]] + {1:"a",2:"c",3:"d"} """, since = "2.4.0") case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpression { @@ -1171,9 +1171,9 @@ case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLi examples = """ Examples: > SELECT _FUNC_(array(1, 20, 3, 5)); - [3, 1, 5, 20] + [3,1,5,20] > SELECT _FUNC_(array(1, 20, null, 3)); - [20, null, 3, 1] + [20,null,3,1] """, note = "The function is non-deterministic.", since = "2.4.0") @@ -1256,7 +1256,7 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None) > SELECT _FUNC_('Spark SQL'); LQS krapS > SELECT _FUNC_(array(2, 1, 4, 3)); - [3, 4, 1, 2] + [3,4,1,2] """, since = "1.5.0", note = "Reverse logic for arrays is available since 2.4.0." @@ -2123,7 +2123,7 @@ case class ArrayPosition(left: Expression, right: Expression) > SELECT _FUNC_(array(1, 2, 3), 2); 2 > SELECT _FUNC_(map(1, 'a', 2, 'b'), 2); - "b" + b """, since = "2.4.0") case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { @@ -2238,8 +2238,9 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti > SELECT _FUNC_('Spark', 'SQL'); SparkSQL > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); - | [1,2,3,4,5,6] - """) + [1,2,3,4,5,6] + """, + note = "Concat logic for arrays is available since 2.4.0.") case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression { private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType) @@ -2427,7 +2428,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio usage = "_FUNC_(arrayOfArrays) - Transforms an array of arrays into a single array.", examples = """ Examples: - > SELECT _FUNC_(array(array(1, 2), array(3, 4)); + > SELECT _FUNC_(array(array(1, 2), array(3, 4))); [1,2,3,4] """, since = "2.4.0") @@ -2556,11 +2557,11 @@ case class Flatten(child: Expression) extends UnaryExpression { examples = """ Examples: > SELECT _FUNC_(1, 5); - [1, 2, 3, 4, 5] + [1,2,3,4,5] > SELECT _FUNC_(5, 1); - [5, 4, 3, 2, 1] + [5,4,3,2,1] > SELECT _FUNC_(to_date('2018-01-01'), to_date('2018-03-01'), interval 1 month); - [2018-01-01, 2018-02-01, 2018-03-01] + [2018-01-01,2018-02-01,2018-03-01] """, since = "2.4.0" ) @@ -2934,7 +2935,7 @@ object Sequence { examples = """ Examples: > SELECT _FUNC_('123', 2); - ['123', '123'] + ["123","123"] """, since = "2.4.0") case class ArrayRepeat(left: Expression, right: Expression) @@ -3421,7 +3422,7 @@ object ArrayBinaryLike { examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); - array(1, 2, 3, 5) + [1,2,3,5] """, since = "2.4.0") case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLike @@ -3632,7 +3633,7 @@ object ArrayUnion { examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); - array(1, 3) + [1,3] """, since = "2.4.0") case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBinaryLike @@ -3873,7 +3874,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); - array(2) + [2] """, since = "2.4.0") case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryLike diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index aba9c6c8ad6fd..fd8b5e94fe48d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -248,7 +248,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression { in keys should not be null""", examples = """ Examples: - > SELECT _FUNC_([1.0, 3.0], ['2', '4']); + > SELECT _FUNC_(array(1.0, 3.0), array('2', '4')); {1.0:"2",3.0:"4"} """, since = "2.4.0") case class MapFromArrays(left: Expression, right: Expression) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 2bb6b20b944d4..b07d9466ba0d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -209,9 +209,9 @@ trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), x -> x + 1); - array(2, 3, 4) + [2,3,4] > SELECT _FUNC_(array(1, 2, 3), (x, i) -> x + i); - array(1, 3, 5) + [1,3,5] """, since = "2.4.0") case class ArrayTransform( @@ -268,7 +268,7 @@ usage = "_FUNC_(expr, func) - Filters entries in a map using the function.", examples = """ Examples: > SELECT _FUNC_(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v); - [1 -> 0, 3 -> -1] + {1:0,3:-1} """, since = "2.4.0") case class MapFilter( @@ -318,7 +318,7 @@ case class MapFilter( examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 1); - array(1, 3) + [1,3] """, since = "2.4.0") case class ArrayFilter( @@ -499,10 +499,10 @@ case class ArrayAggregate( usage = "_FUNC_(expr, func) - Transforms elements in a map using the function.", examples = """ Examples: - > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + 1); - map(array(2, 3, 4), array(1, 2, 3)) - > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); - map(array(2, 4, 6), array(1, 2, 3)) + > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + 1); + {2:1,3:2,4:3} + > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); + {2:1,4:2,6:3} """, since = "2.4.0") case class TransformKeys( @@ -549,10 +549,10 @@ case class TransformKeys( usage = "_FUNC_(expr, func) - Transforms values in the map using the function.", examples = """ Examples: - > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> v + 1); - map(array(1, 2, 3), array(2, 3, 4)) - > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); - map(array(1, 2, 3), array(2, 4, 6)) + > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> v + 1); + {1:2,2:3,3:4} + > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); + {1:2,2:4,3:6} """, since = "2.4.0") case class TransformValues( @@ -777,11 +777,11 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), array('a', 'b', 'c'), (x, y) -> (y, x)); - array(('a', 1), ('b', 2), ('c', 3)) - > SELECT _FUNC_(array(1, 2), array(3, 4), (x, y) -> x + y)); - array(4, 6) + [{"y":"a","x":1},{"y":"b","x":2},{"y":"c","x":3}] + > SELECT _FUNC_(array(1, 2), array(3, 4), (x, y) -> x + y); + [4,6] > SELECT _FUNC_(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y)); - array('ad', 'be', 'cf') + ["ad","be","cf"] """, since = "2.4.0") // scalastyle:on line.size.limit From 30aa37fca45ec0ad4f30076bc855d1a201cfc097 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 17 Sep 2018 08:54:44 -0500 Subject: [PATCH 1628/2461] [SPARK-24654][BUILD][FOLLOWUP] Update, fix LICENSE and NOTICE, and specialize for source vs binary ## What changes were proposed in this pull request? Fix location of licenses-binary in binary release, and remove binary items from source release ## How was this patch tested? N/A Closes #22436 from srowen/SPARK-24654.2. Authored-by: Sean Owen Signed-off-by: Sean Owen --- dev/create-release/release-build.sh | 4 ++++ dev/make-distribution.sh | 1 - 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index ca066bed133d2..098aa5745e34d 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -171,6 +171,10 @@ if [[ "$1" == "package" ]]; then # Source and binary tarballs echo "Packaging release source tarballs" cp -r spark spark-$SPARK_VERSION + # For source release, exclude copy of binary license/notice + rm spark-$SPARK_VERSION/LICENSE-binary + rm spark-$SPARK_VERSION/NOTICE-binary + rm -r spark-$SPARK_VERSION/licenses-binary tar cvzf spark-$SPARK_VERSION.tgz spark-$SPARK_VERSION echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour --output spark-$SPARK_VERSION.tgz.asc \ --detach-sig spark-$SPARK_VERSION.tgz diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 778d376c12b56..668682fbb913d 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -213,7 +213,6 @@ cp -r "$SPARK_HOME/examples/src/main" "$DISTDIR/examples/src/" # Copy license and ASF files cp "$SPARK_HOME/LICENSE-binary" "$DISTDIR/LICENSE" -mkdir -p "$DISTDIR/licenses" cp -r "$SPARK_HOME/licenses-binary" "$DISTDIR/licenses" cp "$SPARK_HOME/NOTICE-binary" "$DISTDIR/NOTICE" From 4b9542e3a3d0c493a05061be5a9f8d278c0ac980 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 17 Sep 2018 11:26:08 -0700 Subject: [PATCH 1629/2461] [SPARK-25423][SQL] Output "dataFilters" in DataSourceScanExec.metadata ## What changes were proposed in this pull request? Output `dataFilters` in `DataSourceScanExec.metadata`. ## How was this patch tested? unit tests Closes #22435 from wangyum/SPARK-25423. Authored-by: Yuming Wang Signed-off-by: Dongjoon Hyun --- .../spark/sql/execution/DataSourceScanExec.scala | 1 + .../DataSourceScanExecRedactionSuite.scala | 16 ++++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 36ed016773b67..738c0666bc3fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -284,6 +284,7 @@ case class FileSourceScanExec( "Batched" -> supportsBatch.toString, "PartitionFilters" -> seqToString(partitionFilters), "PushedFilters" -> seqToString(pushedDownFilters), + "DataFilters" -> seqToString(dataFilters), "Location" -> locationDesc) val withOptPartitionCount = relation.partitionSchemaOption.map { _ => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala index c8d045a32d73c..11a1c9a1f9b9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala @@ -83,4 +83,20 @@ class DataSourceScanExecRedactionSuite extends QueryTest with SharedSQLContext { } } + test("FileSourceScanExec metadata") { + withTempPath { path => + val dir = path.getCanonicalPath + spark.range(0, 10).write.parquet(dir) + val df = spark.read.parquet(dir) + + assert(isIncluded(df.queryExecution, "Format")) + assert(isIncluded(df.queryExecution, "ReadSchema")) + assert(isIncluded(df.queryExecution, "Batched")) + assert(isIncluded(df.queryExecution, "PartitionFilters")) + assert(isIncluded(df.queryExecution, "PushedFilters")) + assert(isIncluded(df.queryExecution, "DataFilters")) + assert(isIncluded(df.queryExecution, "Location")) + } + } + } From 553af22f2c8ecdc039c8d06431564b1432e60d2d Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 17 Sep 2018 11:33:50 -0700 Subject: [PATCH 1630/2461] [SPARK-16323][SQL] Add IntegralDivide expression ## What changes were proposed in this pull request? The PR takes over #14036 and it introduces a new expression `IntegralDivide` in order to avoid the several unneded cast added previously. In order to prove the performance gain, the following benchmark has been run: ``` test("Benchmark IntegralDivide") { val r = new scala.util.Random(91) val nData = 1000000 val testDataInt = (1 to nData).map(_ => (r.nextInt(), r.nextInt())) val testDataLong = (1 to nData).map(_ => (r.nextLong(), r.nextLong())) val testDataShort = (1 to nData).map(_ => (r.nextInt().toShort, r.nextInt().toShort)) // old code val oldExprsInt = testDataInt.map(x => Cast(Divide(Cast(Literal(x._1), DoubleType), Cast(Literal(x._2), DoubleType)), LongType)) val oldExprsLong = testDataLong.map(x => Cast(Divide(Cast(Literal(x._1), DoubleType), Cast(Literal(x._2), DoubleType)), LongType)) val oldExprsShort = testDataShort.map(x => Cast(Divide(Cast(Literal(x._1), DoubleType), Cast(Literal(x._2), DoubleType)), LongType)) // new code val newExprsInt = testDataInt.map(x => IntegralDivide(x._1, x._2)) val newExprsLong = testDataLong.map(x => IntegralDivide(x._1, x._2)) val newExprsShort = testDataShort.map(x => IntegralDivide(x._1, x._2)) Seq(("Long", "old", oldExprsLong), ("Long", "new", newExprsLong), ("Int", "old", oldExprsInt), ("Int", "new", newExprsShort), ("Short", "old", oldExprsShort), ("Short", "new", oldExprsShort)).foreach { case (dt, t, ds) => val start = System.nanoTime() ds.foreach(e => e.eval(EmptyRow)) val endNoCodegen = System.nanoTime() println(s"Running $nData op with $t code on $dt (no-codegen): ${(endNoCodegen - start) / 1000000} ms") } } ``` The results on my laptop are: ``` Running 1000000 op with old code on Long (no-codegen): 600 ms Running 1000000 op with new code on Long (no-codegen): 112 ms Running 1000000 op with old code on Int (no-codegen): 560 ms Running 1000000 op with new code on Int (no-codegen): 135 ms Running 1000000 op with old code on Short (no-codegen): 317 ms Running 1000000 op with new code on Short (no-codegen): 153 ms ``` Showing a 2-5X improvement. The benchmark doesn't include code generation as it is pretty hard to test the performance there as for such simple operations the most of the time is spent in the code generation/compilation process. ## How was this patch tested? added UTs Closes #22395 from mgaido91/SPARK-16323. Authored-by: Marco Gaido Signed-off-by: Dongjoon Hyun --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/dsl/package.scala | 1 + .../sql/catalyst/expressions/arithmetic.scala | 28 +++++++++++++++++++ .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../ArithmeticExpressionSuite.scala | 18 ++++++------ .../parser/ExpressionParserSuite.scala | 4 +-- .../sql-tests/results/operators.sql.out | 8 +++--- 7 files changed, 45 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 77860e1584f42..8b69a47036962 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -267,6 +267,7 @@ object FunctionRegistry { expression[Subtract]("-"), expression[Multiply]("*"), expression[Divide]("/"), + expression[IntegralDivide]("div"), expression[Remainder]("%"), // aggregate functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index d3ccd18d0245e..176ea823b1fcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -72,6 +72,7 @@ package object dsl { def - (other: Expression): Expression = Subtract(expr, other) def * (other: Expression): Expression = Multiply(expr, other) def / (other: Expression): Expression = Divide(expr, other) + def div (other: Expression): Expression = IntegralDivide(expr, other) def % (other: Expression): Expression = Remainder(expr, other) def & (other: Expression): Expression = BitwiseAnd(expr, other) def | (other: Expression): Expression = BitwiseOr(expr, other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index c827226d58420..1b1808f8366d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -314,6 +314,34 @@ case class Divide(left: Expression, right: Expression) extends DivModLike { override def evalOperation(left: Any, right: Any): Any = div(left, right) } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "expr1 _FUNC_ expr2 - Divide `expr1` by `expr2` rounded to the long integer. It returns NULL if an operand is NULL or `expr2` is 0.", + examples = """ + Examples: + > SELECT 3 _FUNC_ 2; + 1 + """, + since = "2.5.0") +// scalastyle:on line.size.limit +case class IntegralDivide(left: Expression, right: Expression) extends DivModLike { + + override def inputType: AbstractDataType = IntegralType + override def dataType: DataType = LongType + + override def symbol: String = "/" + override def sqlOperator: String = "div" + + private lazy val div: (Any, Any) => Long = left.dataType match { + case i: IntegralType => + val divide = i.integral.asInstanceOf[Integral[Any]].quot _ + val toLong = i.integral.asInstanceOf[Integral[Any]].toLong _ + (x, y) => toLong(divide(x, y)) + } + + override def evalOperation(left: Any, right: Any): Any = div(left, right) +} + @ExpressionDescription( usage = "expr1 _FUNC_ expr2 - Returns the remainder after `expr1`/`expr2`.", examples = """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 7bc1f63e30540..5cfb5dc871041 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1157,7 +1157,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case SqlBaseParser.PERCENT => Remainder(left, right) case SqlBaseParser.DIV => - Cast(Divide(left, right), LongType) + IntegralDivide(left, right) case SqlBaseParser.PLUS => Add(left, right) case SqlBaseParser.MINUS => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 9a752af523ffc..c3c4d9ee6b702 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -143,16 +143,14 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } } - // By fixing SPARK-15776, Divide's inputType is required to be DoubleType of DecimalType. - // TODO: in future release, we should add a IntegerDivide to support integral types. - ignore("/ (Divide) for integral type") { - checkEvaluation(Divide(Literal(1.toByte), Literal(2.toByte)), 0.toByte) - checkEvaluation(Divide(Literal(1.toShort), Literal(2.toShort)), 0.toShort) - checkEvaluation(Divide(Literal(1), Literal(2)), 0) - checkEvaluation(Divide(Literal(1.toLong), Literal(2.toLong)), 0.toLong) - checkEvaluation(Divide(positiveShortLit, negativeShortLit), 0.toShort) - checkEvaluation(Divide(positiveIntLit, negativeIntLit), 0) - checkEvaluation(Divide(positiveLongLit, negativeLongLit), 0L) + test("/ (Divide) for integral type") { + checkEvaluation(IntegralDivide(Literal(1.toByte), Literal(2.toByte)), 0L) + checkEvaluation(IntegralDivide(Literal(1.toShort), Literal(2.toShort)), 0L) + checkEvaluation(IntegralDivide(Literal(1), Literal(2)), 0L) + checkEvaluation(IntegralDivide(Literal(1.toLong), Literal(2.toLong)), 0L) + checkEvaluation(IntegralDivide(positiveShortLit, negativeShortLit), 0L) + checkEvaluation(IntegralDivide(positiveIntLit, negativeIntLit), 0L) + checkEvaluation(IntegralDivide(positiveLongLit, negativeLongLit), 0L) } test("% (Remainder)") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 781fc1e957ae0..b4df22c5b29fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -203,7 +203,7 @@ class ExpressionParserSuite extends PlanTest { // Simple operations assertEqual("a * b", 'a * 'b) assertEqual("a / b", 'a / 'b) - assertEqual("a DIV b", ('a / 'b).cast(LongType)) + assertEqual("a DIV b", 'a div 'b) assertEqual("a % b", 'a % 'b) assertEqual("a + b", 'a + 'b) assertEqual("a - b", 'a - 'b) @@ -214,7 +214,7 @@ class ExpressionParserSuite extends PlanTest { // Check precedences assertEqual( "a * t | b ^ c & d - e + f % g DIV h / i * k", - 'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g / 'h).cast(LongType) / 'i * 'k))))) + 'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g div 'h) / 'i * 'k))))) } test("unary arithmetic expressions") { diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 840655b7a6447..2555734756fc4 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -157,7 +157,7 @@ NULL -- !query 19 select 5 div 2 -- !query 19 schema -struct +struct<(5 div 2):bigint> -- !query 19 output 2 @@ -165,7 +165,7 @@ struct -- !query 20 select 5 div 0 -- !query 20 schema -struct +struct<(5 div 0):bigint> -- !query 20 output NULL @@ -173,7 +173,7 @@ NULL -- !query 21 select 5 div null -- !query 21 schema -struct +struct<(5 div CAST(NULL AS INT)):bigint> -- !query 21 output NULL @@ -181,7 +181,7 @@ NULL -- !query 22 select null div 5 -- !query 22 schema -struct +struct<(CAST(NULL AS INT) div 5):bigint> -- !query 22 output NULL From 58419b92673c46911c25bc6c6b13397f880c6424 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 13 Aug 2018 21:35:34 -0500 Subject: [PATCH 1631/2461] [PYSPARK] Updates to pyspark broadcast --- .../apache/spark/api/python/PythonRDD.scala | 299 +++++++++++++++--- .../spark/api/python/PythonRunner.scala | 52 ++- .../spark/api/python/PythonRDDSuite.scala | 23 +- dev/sparktestsupport/modules.py | 2 + python/pyspark/broadcast.py | 58 +++- python/pyspark/context.py | 64 +++- python/pyspark/serializers.py | 51 +++ python/pyspark/sql/session.py | 12 +- python/pyspark/sql/tests.py | 45 ++- python/pyspark/test_broadcast.py | 126 ++++++++ python/pyspark/test_serializers.py | 90 ++++++ python/pyspark/tests.py | 9 +- python/pyspark/worker.py | 22 +- .../spark/sql/api/python/PythonSQLUtils.scala | 47 ++- .../sql/execution/arrow/ArrowConverters.scala | 9 +- 15 files changed, 789 insertions(+), 120 deletions(-) create mode 100644 python/pyspark/test_broadcast.py create mode 100644 python/pyspark/test_serializers.py diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index e639a842754bd..8b5a7a9aefea5 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -24,8 +24,10 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.concurrent.Promise +import scala.concurrent.duration.Duration import scala.language.existentials -import scala.util.control.NonFatal +import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.compress.CompressionCodec @@ -37,6 +39,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.input.PortableDataStream import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util._ @@ -169,27 +172,34 @@ private[spark] object PythonRDD extends Logging { def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): JavaRDD[Array[Byte]] = { - val file = new DataInputStream(new FileInputStream(filename)) + readRDDFromInputStream(sc.sc, new FileInputStream(filename), parallelism) + } + + def readRDDFromInputStream( + sc: SparkContext, + in: InputStream, + parallelism: Int): JavaRDD[Array[Byte]] = { + val din = new DataInputStream(in) try { val objs = new mutable.ArrayBuffer[Array[Byte]] try { while (true) { - val length = file.readInt() + val length = din.readInt() val obj = new Array[Byte](length) - file.readFully(obj) + din.readFully(obj) objs += obj } } catch { case eof: EOFException => // No-op } - JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) + JavaRDD.fromRDD(sc.parallelize(objs, parallelism)) } finally { - file.close() + din.close() } } - def readBroadcastFromFile(sc: JavaSparkContext, path: String): Broadcast[PythonBroadcast] = { - sc.broadcast(new PythonBroadcast(path)) + def setupBroadcast(path: String): PythonBroadcast = { + new PythonBroadcast(path) } def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) { @@ -419,34 +429,15 @@ private[spark] object PythonRDD extends Logging { */ private[spark] def serveToStream( threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = { - val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) - // Close the socket if no connection in 15 seconds - serverSocket.setSoTimeout(15000) - - new Thread(threadName) { - setDaemon(true) - override def run() { - try { - val sock = serverSocket.accept() - authHelper.authClient(sock) - - val out = new BufferedOutputStream(sock.getOutputStream) - Utils.tryWithSafeFinally { - writeFunc(out) - } { - out.close() - sock.close() - } - } catch { - case NonFatal(e) => - logError(s"Error while sending iterator", e) - } finally { - serverSocket.close() - } + val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { s => + val out = new BufferedOutputStream(s.getOutputStream()) + Utils.tryWithSafeFinally { + writeFunc(out) + } { + out.close() } - }.start() - - Array(serverSocket.getLocalPort, authHelper.secret) + } + Array(port, secret) } private def getMergedConf(confAsMap: java.util.HashMap[String, String], @@ -664,13 +655,11 @@ private[spark] class PythonAccumulatorV2( } } -/** - * A Wrapper for Python Broadcast, which is written into disk by Python. It also will - * write the data into disk after deserialization, then Python can read it from disks. - */ // scalastyle:off no.finalize private[spark] class PythonBroadcast(@transient var path: String) extends Serializable - with Logging { + with Logging { + + private var encryptionServer: PythonServer[Unit] = null /** * Read data from disks, then copy it to `out` @@ -713,5 +702,235 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial } super.finalize() } + + def setupEncryptionServer(): Array[Any] = { + encryptionServer = new PythonServer[Unit]("broadcast-encrypt-server") { + override def handleConnection(sock: Socket): Unit = { + val env = SparkEnv.get + val in = sock.getInputStream() + val dir = new File(Utils.getLocalDir(env.conf)) + val file = File.createTempFile("broadcast", "", dir) + path = file.getAbsolutePath + val out = env.serializerManager.wrapForEncryption(new FileOutputStream(path)) + DechunkedInputStream.dechunkAndCopyToOutput(in, out) + } + } + Array(encryptionServer.port, encryptionServer.secret) + } + + def waitTillDataReceived(): Unit = encryptionServer.getResult() } // scalastyle:on no.finalize + +/** + * The inverse of pyspark's ChunkedStream for sending data of unknown size. + * + * We might be serializing a really large object from python -- we don't want + * python to buffer the whole thing in memory, nor can it write to a file, + * so we don't know the length in advance. So python writes it in chunks, each chunk + * preceeded by a length, till we get a "length" of -1 which serves as EOF. + * + * Tested from python tests. + */ +private[spark] class DechunkedInputStream(wrapped: InputStream) extends InputStream with Logging { + private val din = new DataInputStream(wrapped) + private var remainingInChunk = din.readInt() + + override def read(): Int = { + val into = new Array[Byte](1) + val n = read(into, 0, 1) + if (n == -1) { + -1 + } else { + // if you just cast a byte to an int, then anything > 127 is negative, which is interpreted + // as an EOF + val b = into(0) + if (b < 0) { + 256 + b + } else { + b + } + } + } + + override def read(dest: Array[Byte], off: Int, len: Int): Int = { + if (remainingInChunk == -1) { + return -1 + } + var destSpace = len + var destPos = off + while (destSpace > 0 && remainingInChunk != -1) { + val toCopy = math.min(remainingInChunk, destSpace) + val read = din.read(dest, destPos, toCopy) + destPos += read + destSpace -= read + remainingInChunk -= read + if (remainingInChunk == 0) { + remainingInChunk = din.readInt() + } + } + assert(destSpace == 0 || remainingInChunk == -1) + return destPos - off + } + + override def close(): Unit = wrapped.close() +} + +private[spark] object DechunkedInputStream { + + /** + * Dechunks the input, copies to output, and closes both input and the output safely. + */ + def dechunkAndCopyToOutput(chunked: InputStream, out: OutputStream): Unit = { + val dechunked = new DechunkedInputStream(chunked) + Utils.tryWithSafeFinally { + Utils.copyStream(dechunked, out) + } { + JavaUtils.closeQuietly(out) + JavaUtils.closeQuietly(dechunked) + } + } +} + +/** + * Creates a server in the jvm to communicate with python for handling one batch of data, with + * authentication and error handling. + */ +private[spark] abstract class PythonServer[T]( + authHelper: SocketAuthHelper, + threadName: String) { + + def this(env: SparkEnv, threadName: String) = this(new SocketAuthHelper(env.conf), threadName) + def this(threadName: String) = this(SparkEnv.get, threadName) + + val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { sock => + promise.complete(Try(handleConnection(sock))) + } + + /** + * Handle a connection which has already been authenticated. Any error from this function + * will clean up this connection and the entire server, and get propogated to [[getResult]]. + */ + def handleConnection(sock: Socket): T + + val promise = Promise[T]() + + /** + * Blocks indefinitely for [[handleConnection]] to finish, and returns that result. If + * handleConnection throws an exception, this will throw an exception which includes the original + * exception as a cause. + */ + def getResult(): T = { + getResult(Duration.Inf) + } + + def getResult(wait: Duration): T = { + ThreadUtils.awaitResult(promise.future, wait) + } + +} + +private[spark] object PythonServer { + + /** + * Create a socket server and run user function on the socket in a background thread. + * + * The socket server can only accept one connection, or close if no connection + * in 15 seconds. + * + * The thread will terminate after the supplied user function, or if there are any exceptions. + * + * If you need to get a result of the supplied function, create a subclass of [[PythonServer]] + * + * @return The port number of a local socket and the secret for authentication. + */ + def setupOneConnectionServer( + authHelper: SocketAuthHelper, + threadName: String) + (func: Socket => Unit): (Int, String) = { + val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) + // Close the socket if no connection in 15 seconds + serverSocket.setSoTimeout(15000) + + new Thread(threadName) { + setDaemon(true) + override def run(): Unit = { + var sock: Socket = null + try { + sock = serverSocket.accept() + authHelper.authClient(sock) + func(sock) + } finally { + JavaUtils.closeQuietly(serverSocket) + JavaUtils.closeQuietly(sock) + } + } + }.start() + (serverSocket.getLocalPort, authHelper.secret) + } +} + +/** + * Sends decrypted broadcast data to python worker. See [[PythonRunner]] for entire protocol. + */ +private[spark] class EncryptedPythonBroadcastServer( + val env: SparkEnv, + val idsAndFiles: Seq[(Long, String)]) + extends PythonServer[Unit]("broadcast-decrypt-server") with Logging { + + override def handleConnection(socket: Socket): Unit = { + val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream())) + var socketIn: InputStream = null + // send the broadcast id, then the decrypted data. We don't need to send the length, the + // the python pickle module just needs a stream. + Utils.tryWithSafeFinally { + (idsAndFiles).foreach { case (id, path) => + out.writeLong(id) + val in = env.serializerManager.wrapForEncryption(new FileInputStream(path)) + Utils.tryWithSafeFinally { + Utils.copyStream(in, out, false) + } { + in.close() + } + } + logTrace("waiting for python to accept broadcast data over socket") + out.flush() + socketIn = socket.getInputStream() + socketIn.read() + logTrace("done serving broadcast data") + } { + JavaUtils.closeQuietly(socketIn) + JavaUtils.closeQuietly(out) + } + } + + def waitTillBroadcastDataSent(): Unit = { + getResult() + } +} + +/** + * Helper for making RDD[Array[Byte]] from some python data, by reading the data from python + * over a socket. This is used in preference to writing data to a file when encryption is enabled. + */ +private[spark] abstract class PythonRDDServer + extends PythonServer[JavaRDD[Array[Byte]]]("pyspark-parallelize-server") { + + def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = { + val in = sock.getInputStream() + val dechunkedInput: InputStream = new DechunkedInputStream(in) + streamToRDD(dechunkedInput) + } + + protected def streamToRDD(input: InputStream): RDD[Array[Byte]] + +} + +private[spark] class PythonParallelizeServer(sc: SparkContext, parallelism: Int) + extends PythonRDDServer { + + override protected def streamToRDD(input: InputStream): RDD[Array[Byte]] = { + PythonRDD.readRDDFromInputStream(sc, input, parallelism) + } +} + diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 4c53bc269a104..6e53a044e9a8c 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -289,19 +289,51 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( val newBids = broadcastVars.map(_.id).toSet // number of different broadcasts val toRemove = oldBids.diff(newBids) - val cnt = toRemove.size + newBids.diff(oldBids).size + val addedBids = newBids.diff(oldBids) + val cnt = toRemove.size + addedBids.size + val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty + dataOut.writeBoolean(needsDecryptionServer) dataOut.writeInt(cnt) - for (bid <- toRemove) { - // remove the broadcast from worker - dataOut.writeLong(- bid - 1) // bid >= 0 - oldBids.remove(bid) + def sendBidsToRemove(): Unit = { + for (bid <- toRemove) { + // remove the broadcast from worker + dataOut.writeLong(-bid - 1) // bid >= 0 + oldBids.remove(bid) + } } - for (broadcast <- broadcastVars) { - if (!oldBids.contains(broadcast.id)) { + if (needsDecryptionServer) { + // if there is encryption, we setup a server which reads the encrypted files, and sends + // the decrypted data to python + val idsAndFiles = broadcastVars.flatMap { broadcast => + if (!oldBids.contains(broadcast.id)) { + Some((broadcast.id, broadcast.value.path)) + } else { + None + } + } + val server = new EncryptedPythonBroadcastServer(env, idsAndFiles) + dataOut.writeInt(server.port) + logTrace(s"broadcast decryption server setup on ${server.port}") + PythonRDD.writeUTF(server.secret, dataOut) + sendBidsToRemove() + idsAndFiles.foreach { case (id, _) => // send new broadcast - dataOut.writeLong(broadcast.id) - PythonRDD.writeUTF(broadcast.value.path, dataOut) - oldBids.add(broadcast.id) + dataOut.writeLong(id) + oldBids.add(id) + } + dataOut.flush() + logTrace("waiting for python to read decrypted broadcast data from server") + server.waitTillBroadcastDataSent() + logTrace("done sending decrypted data to python") + } else { + sendBidsToRemove() + for (broadcast <- broadcastVars) { + if (!oldBids.contains(broadcast.id)) { + // send new broadcast + dataOut.writeLong(broadcast.id) + PythonRDD.writeUTF(broadcast.value.path, dataOut) + oldBids.add(broadcast.id) + } } } dataOut.flush() diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala index 05b4e67412f2e..6f9b583898c38 100644 --- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala @@ -18,9 +18,13 @@ package org.apache.spark.api.python import java.io.{ByteArrayOutputStream, DataOutputStream} +import java.net.{InetAddress, Socket} import java.nio.charset.StandardCharsets -import org.apache.spark.SparkFunSuite +import scala.concurrent.duration.Duration + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.security.SocketAuthHelper class PythonRDDSuite extends SparkFunSuite { @@ -44,4 +48,21 @@ class PythonRDDSuite extends SparkFunSuite { ("a".getBytes(StandardCharsets.UTF_8), null), (null, "b".getBytes(StandardCharsets.UTF_8))), buffer) } + + test("python server error handling") { + val authHelper = new SocketAuthHelper(new SparkConf()) + val errorServer = new ExceptionPythonServer(authHelper) + val client = new Socket(InetAddress.getLoopbackAddress(), errorServer.port) + authHelper.authToServer(client) + val ex = intercept[Exception] { errorServer.getResult(Duration(1, "second")) } + assert(ex.getCause().getMessage().contains("exception within handleConnection")) + } + + class ExceptionPythonServer(authHelper: SocketAuthHelper) + extends PythonServer[Unit](authHelper, "error-server") { + + override def handleConnection(sock: Socket): Unit = { + throw new Exception("exception within handleConnection") + } + } } diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 2aa355504bf29..e267fbfa623b5 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -387,6 +387,8 @@ def __hash__(self): "pyspark.profiler", "pyspark.shuffle", "pyspark.tests", + "pyspark.test_broadcast", + "pyspark.test_serializers", "pyspark.util", ] ) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index b3dfc99962a35..1c7f2a7418df0 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -15,13 +15,16 @@ # limitations under the License. # +import gc import os +import socket import sys -import gc from tempfile import NamedTemporaryFile import threading from pyspark.cloudpickle import print_exec +from pyspark.java_gateway import local_connect_and_auth +from pyspark.serializers import ChunkedStream from pyspark.util import _exception_message if sys.version < '3': @@ -64,19 +67,43 @@ class Broadcast(object): >>> large_broadcast = sc.broadcast(range(10000)) """ - def __init__(self, sc=None, value=None, pickle_registry=None, path=None): + def __init__(self, sc=None, value=None, pickle_registry=None, path=None, + sock_file=None): """ Should not be called directly by users -- use L{SparkContext.broadcast()} instead. """ if sc is not None: + # we're on the driver. We want the pickled data to end up in a file (maybe encrypted) f = NamedTemporaryFile(delete=False, dir=sc._temp_dir) - self._path = self.dump(value, f) - self._jbroadcast = sc._jvm.PythonRDD.readBroadcastFromFile(sc._jsc, self._path) + self._path = f.name + python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path) + if sc._encryption_enabled: + # with encryption, we ask the jvm to do the encryption for us, we send it data + # over a socket + port, auth_secret = python_broadcast.setupEncryptionServer() + (encryption_sock_file, _) = local_connect_and_auth(port, auth_secret) + broadcast_out = ChunkedStream(encryption_sock_file, 8192) + else: + # no encryption, we can just write pickled data directly to the file from python + broadcast_out = f + self.dump(value, broadcast_out) + if sc._encryption_enabled: + python_broadcast.waitTillDataReceived() + self._jbroadcast = sc._jsc.broadcast(python_broadcast) self._pickle_registry = pickle_registry else: + # we're on an executor self._jbroadcast = None - self._path = path + if sock_file is not None: + # the jvm is doing decryption for us. Read the value + # immediately from the sock_file + self._value = self.load(sock_file) + else: + # the jvm just dumps the pickled data in path -- we'll unpickle lazily when + # the value is requested + assert(path is not None) + self._path = path def dump(self, value, f): try: @@ -89,24 +116,25 @@ def dump(self, value, f): print_exec(sys.stderr) raise pickle.PicklingError(msg) f.close() - return f.name - def load(self, path): + def load_from_path(self, path): with open(path, 'rb', 1 << 20) as f: - # pickle.load() may create lots of objects, disable GC - # temporary for better performance - gc.disable() - try: - return pickle.load(f) - finally: - gc.enable() + return self.load(f) + + def load(self, file): + # "file" could also be a socket + gc.disable() + try: + return pickle.load(file) + finally: + gc.enable() @property def value(self): """ Return the broadcasted value """ if not hasattr(self, "_value") and self._path is not None: - self._value = self.load(self._path) + self._value = self.load_from_path(self._path) return self._value def unpersist(self, blocking=False): diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 4cabae4b2f50b..2c92c29a1cc1b 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -33,9 +33,9 @@ from pyspark.broadcast import Broadcast, BroadcastPickleRegistry from pyspark.conf import SparkConf from pyspark.files import SparkFiles -from pyspark.java_gateway import launch_gateway +from pyspark.java_gateway import launch_gateway, local_connect_and_auth from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ - PairDeserializer, AutoBatchedSerializer, NoOpSerializer + PairDeserializer, AutoBatchedSerializer, NoOpSerializer, ChunkedStream from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.traceback_utils import CallSite, first_spark_call @@ -189,6 +189,13 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token) self._jsc.sc().register(self._javaAccumulator) + # If encryption is enabled, we need to setup a server in the jvm to read broadcast + # data via a socket. + # scala's mangled names w/ $ in them require special treatment. + encryption_conf = self._jvm.org.apache.spark.internal.config.__getattr__("package$")\ + .__getattr__("MODULE$").IO_ENCRYPTION_ENABLED() + self._encryption_enabled = self._jsc.sc().conf().get(encryption_conf) + self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') self.pythonVer = "%d.%d" % sys.version_info[:2] @@ -498,23 +505,46 @@ def f(split, iterator): def reader_func(temp_filename): return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices) - jrdd = self._serialize_to_jvm(c, serializer, reader_func) + def createRDDServer(): + return self._jvm.PythonParallelizeServer(self._jsc.sc(), numSlices) + + jrdd = self._serialize_to_jvm(c, serializer, reader_func, createRDDServer) return RDD(jrdd, self, serializer) - def _serialize_to_jvm(self, data, serializer, reader_func): - """ - Calling the Java parallelize() method with an ArrayList is too slow, - because it sends O(n) Py4J commands. As an alternative, serialized - objects are written to a file and loaded through textFile(). - """ - tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir) - try: - serializer.dump_stream(data, tempFile) - tempFile.close() - return reader_func(tempFile.name) - finally: - # readRDDFromFile eagerily reads the file so we can delete right after. - os.unlink(tempFile.name) + def _serialize_to_jvm(self, data, serializer, reader_func, createRDDServer): + """ + Using py4j to send a large dataset to the jvm is really slow, so we use either a file + or a socket if we have encryption enabled. + :param data: + :param serializer: + :param reader_func: A function which takes a filename and reads in the data in the jvm and + returns a JavaRDD. Only used when encryption is disabled. + :param createRDDServer: A function which creates a PythonRDDServer in the jvm to + accept the serialized data, for use when encryption is enabled. + :return: + """ + if self._encryption_enabled: + # with encryption, we open a server in java and send the data directly + server = createRDDServer() + (sock_file, _) = local_connect_and_auth(server.port(), server.secret()) + chunked_out = ChunkedStream(sock_file, 8192) + serializer.dump_stream(data, chunked_out) + chunked_out.close() + # this call will block until the server has read all the data and processed it (or + # throws an exception) + r = server.getResult() + return r + else: + # without encryption, we serialize to a file, and we read the file in java and + # parallelize from there. + tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir) + try: + serializer.dump_stream(data, tempFile) + tempFile.close() + return reader_func(tempFile.name) + finally: + # we eagerily reads the file so we can delete right after. + os.unlink(tempFile.name) def pickleFile(self, name, minPartitions=None): """ diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 48006778e86f2..ff9a612b77f61 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -731,6 +731,57 @@ def write_with_length(obj, stream): stream.write(obj) +class ChunkedStream(object): + + """ + This is a file-like object takes a stream of data, of unknown length, and breaks it into fixed + length frames. The intended use case is serializing large data and sending it immediately over + a socket -- we do not want to buffer the entire data before sending it, but the receiving end + needs to know whether or not there is more data coming. + + It works by buffering the incoming data in some fixed-size chunks. If the buffer is full, it + first sends the buffer size, then the data. This repeats as long as there is more data to send. + When this is closed, it sends the length of whatever data is in the buffer, then that data, and + finally a "length" of -1 to indicate the stream has completed. + """ + + def __init__(self, wrapped, buffer_size): + self.buffer_size = buffer_size + self.buffer = bytearray(buffer_size) + self.current_pos = 0 + self.wrapped = wrapped + + def write(self, bytes): + byte_pos = 0 + byte_remaining = len(bytes) + while byte_remaining > 0: + new_pos = byte_remaining + self.current_pos + if new_pos < self.buffer_size: + # just put it in our buffer + self.buffer[self.current_pos:new_pos] = bytes[byte_pos:] + self.current_pos = new_pos + byte_remaining = 0 + else: + # fill the buffer, send the length then the contents, and start filling again + space_left = self.buffer_size - self.current_pos + new_byte_pos = byte_pos + space_left + self.buffer[self.current_pos:self.buffer_size] = bytes[byte_pos:new_byte_pos] + write_int(self.buffer_size, self.wrapped) + self.wrapped.write(self.buffer) + byte_remaining -= space_left + byte_pos = new_byte_pos + self.current_pos = 0 + + def close(self): + # if there is anything left in the buffer, write it out first + if self.current_pos > 0: + write_int(self.current_pos, self.wrapped) + self.wrapped.write(self.buffer[:self.current_pos]) + # -1 length indicates to the receiving end that we're done. + write_int(-1, self.wrapped) + self.wrapped.close() + + if __name__ == '__main__': import doctest (failure_count, test_count) = doctest.testmod() diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 87d8d6a59a6e9..51a38ebfd19ff 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -539,12 +539,18 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): struct.names[i] = name schema = struct + jsqlContext = self._wrapped._jsqlContext + def reader_func(temp_filename): - return self._jvm.PythonSQLUtils.arrowReadStreamFromFile( - self._wrapped._jsqlContext, temp_filename, schema.json()) + return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename) + + def create_RDD_server(): + return self._jvm.ArrowRDDServer(jsqlContext) # Create Spark DataFrame from Arrow stream file, using one batch per partition - jdf = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func) + jrdd = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func, + create_RDD_server) + jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext) df = DataFrame(jdf, self._wrapped) df._schema = schema return df diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 8e5bc6729dfa4..08d7cfadc084c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -26,6 +26,7 @@ import pydoc import shutil import tempfile +import threading import pickle import functools import time @@ -228,12 +229,12 @@ def sql_conf(self, pairs): class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils): @classmethod def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() + super(ReusedSQLTestCase, cls).setUpClass() cls.spark = SparkSession(cls.sc) @classmethod def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() + super(ReusedSQLTestCase, cls).tearDownClass() cls.spark.stop() def assertPandasEqual(self, expected, result): @@ -4105,7 +4106,8 @@ def setUpClass(cls): from decimal import Decimal from distutils.version import LooseVersion import pyarrow as pa - ReusedSQLTestCase.setUpClass() + super(ArrowTests, cls).setUpClass() + cls.warnings_lock = threading.Lock() # Synchronize default timezone between Python and Java cls.tz_prev = os.environ.get("TZ", None) # save current tz if set @@ -4146,7 +4148,7 @@ def tearDownClass(cls): if cls.tz_prev is not None: os.environ["TZ"] = cls.tz_prev time.tzset() - ReusedSQLTestCase.tearDownClass() + super(ArrowTests, cls).tearDownClass() def create_pandas_data_frame(self): import pandas as pd @@ -4166,15 +4168,18 @@ def test_toPandas_fallback_enabled(self): schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([({u'a': 1},)], schema=schema) with QuietTest(self.sc): - with warnings.catch_warnings(record=True) as warns: - pdf = df.toPandas() - # Catch and check the last UserWarning. - user_warns = [ - warn.message for warn in warns if isinstance(warn.message, UserWarning)] - self.assertTrue(len(user_warns) > 0) - self.assertTrue( - "Attempting non-optimization" in _exception_message(user_warns[-1])) - self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) + with self.warnings_lock: + with warnings.catch_warnings(record=True) as warns: + # we want the warnings to appear even if this test is run from a subclass + warnings.simplefilter("always") + pdf = df.toPandas() + # Catch and check the last UserWarning. + user_warns = [ + warn.message for warn in warns if isinstance(warn.message, UserWarning)] + self.assertTrue(len(user_warns) > 0) + self.assertTrue( + "Attempting non-optimization" in _exception_message(user_warns[-1])) + self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) def test_toPandas_fallback_disabled(self): from distutils.version import LooseVersion @@ -4183,8 +4188,9 @@ def test_toPandas_fallback_disabled(self): schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported type'): - df.toPandas() + with self.warnings_lock: + with self.assertRaisesRegexp(Exception, 'Unsupported type'): + df.toPandas() # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0 if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): @@ -4396,6 +4402,8 @@ def test_createDataFrame_fallback_enabled(self): with QuietTest(self.sc): with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}): with warnings.catch_warnings(record=True) as warns: + # we want the warnings to appear even if this test is run from a subclass + warnings.simplefilter("always") df = self.spark.createDataFrame( pd.DataFrame([[{u'a': 1}]]), "a: map") # Catch and check the last UserWarning. @@ -4439,6 +4447,13 @@ def test_timestamp_dst(self): self.assertPandasEqual(pdf, df_from_pandas.toPandas()) +class EncryptionArrowTests(ArrowTests): + + @classmethod + def conf(cls): + return super(EncryptionArrowTests, cls).conf().set("spark.io.encryption.enabled", "true") + + @unittest.skipIf( not _have_pandas or not _have_pyarrow, _pandas_requirement_message or _pyarrow_requirement_message) diff --git a/python/pyspark/test_broadcast.py b/python/pyspark/test_broadcast.py new file mode 100644 index 0000000000000..a00329c18ad8f --- /dev/null +++ b/python/pyspark/test_broadcast.py @@ -0,0 +1,126 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import random +import tempfile +import unittest + +try: + import xmlrunner +except ImportError: + xmlrunner = None + +from pyspark.broadcast import Broadcast +from pyspark.conf import SparkConf +from pyspark.context import SparkContext +from pyspark.java_gateway import launch_gateway +from pyspark.serializers import ChunkedStream + + +class BroadcastTest(unittest.TestCase): + + def tearDown(self): + if getattr(self, "sc", None) is not None: + self.sc.stop() + self.sc = None + + def _test_encryption_helper(self, vs): + """ + Creates a broadcast variables for each value in vs, and runs a simple job to make sure the + value is the same when it's read in the executors. Also makes sure there are no task + failures. + """ + bs = [self.sc.broadcast(value=v) for v in vs] + exec_values = self.sc.parallelize(range(2)).map(lambda x: [b.value for b in bs]).collect() + for ev in exec_values: + self.assertEqual(ev, vs) + # make sure there are no task failures + status = self.sc.statusTracker() + for jid in status.getJobIdsForGroup(): + for sid in status.getJobInfo(jid).stageIds: + stage_info = status.getStageInfo(sid) + self.assertEqual(0, stage_info.numFailedTasks) + + def _test_multiple_broadcasts(self, *extra_confs): + """ + Test broadcast variables make it OK to the executors. Tests multiple broadcast variables, + and also multiple jobs. + """ + conf = SparkConf() + for key, value in extra_confs: + conf.set(key, value) + conf.setMaster("local-cluster[2,1,1024]") + self.sc = SparkContext(conf=conf) + self._test_encryption_helper([5]) + self._test_encryption_helper([5, 10, 20]) + + def test_broadcast_with_encryption(self): + self._test_multiple_broadcasts(("spark.io.encryption.enabled", "true")) + + def test_broadcast_no_encryption(self): + self._test_multiple_broadcasts() + + +class BroadcastFrameProtocolTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + gateway = launch_gateway(SparkConf()) + cls._jvm = gateway.jvm + cls.longMessage = True + random.seed(42) + + def _test_chunked_stream(self, data, py_buf_size): + # write data using the chunked protocol from python. + chunked_file = tempfile.NamedTemporaryFile(delete=False) + dechunked_file = tempfile.NamedTemporaryFile(delete=False) + dechunked_file.close() + try: + out = ChunkedStream(chunked_file, py_buf_size) + out.write(data) + out.close() + # now try to read it in java + jin = self._jvm.java.io.FileInputStream(chunked_file.name) + jout = self._jvm.java.io.FileOutputStream(dechunked_file.name) + self._jvm.DechunkedInputStream.dechunkAndCopyToOutput(jin, jout) + # java should have decoded it back to the original data + self.assertEqual(len(data), os.stat(dechunked_file.name).st_size) + with open(dechunked_file.name, "rb") as f: + byte = f.read(1) + idx = 0 + while byte: + self.assertEqual(data[idx], bytearray(byte)[0], msg="idx = " + str(idx)) + byte = f.read(1) + idx += 1 + finally: + os.unlink(chunked_file.name) + os.unlink(dechunked_file.name) + + def test_chunked_stream(self): + def random_bytes(n): + return bytearray(random.getrandbits(8) for _ in range(n)) + for data_length in [1, 10, 100, 10000]: + for buffer_length in [1, 2, 5, 8192]: + self._test_chunked_stream(random_bytes(data_length), buffer_length) + +if __name__ == '__main__': + from pyspark.test_broadcast import * + if xmlrunner: + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + else: + unittest.main(verbosity=2) diff --git a/python/pyspark/test_serializers.py b/python/pyspark/test_serializers.py new file mode 100644 index 0000000000000..5b43729f9ebb1 --- /dev/null +++ b/python/pyspark/test_serializers.py @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import io +import math +import struct +import sys +import unittest + +try: + import xmlrunner +except ImportError: + xmlrunner = None + +from pyspark import serializers + + +def read_int(b): + return struct.unpack("!i", b)[0] + + +def write_int(i): + return struct.pack("!i", i) + + +class SerializersTest(unittest.TestCase): + + def test_chunked_stream(self): + original_bytes = bytearray(range(100)) + for data_length in [1, 10, 100]: + for buffer_length in [1, 2, 3, 5, 20, 99, 100, 101, 500]: + dest = ByteArrayOutput() + stream_out = serializers.ChunkedStream(dest, buffer_length) + stream_out.write(original_bytes[:data_length]) + stream_out.close() + num_chunks = int(math.ceil(float(data_length) / buffer_length)) + # length for each chunk, and a final -1 at the very end + exp_size = (num_chunks + 1) * 4 + data_length + self.assertEqual(len(dest.buffer), exp_size) + dest_pos = 0 + data_pos = 0 + for chunk_idx in range(num_chunks): + chunk_length = read_int(dest.buffer[dest_pos:(dest_pos + 4)]) + if chunk_idx == num_chunks - 1: + exp_length = data_length % buffer_length + if exp_length == 0: + exp_length = buffer_length + else: + exp_length = buffer_length + self.assertEqual(chunk_length, exp_length) + dest_pos += 4 + dest_chunk = dest.buffer[dest_pos:dest_pos + chunk_length] + orig_chunk = original_bytes[data_pos:data_pos + chunk_length] + self.assertEqual(dest_chunk, orig_chunk) + dest_pos += chunk_length + data_pos += chunk_length + # ends with a -1 + self.assertEqual(dest.buffer[-4:], write_int(-1)) + + +class ByteArrayOutput(object): + def __init__(self): + self.buffer = bytearray() + + def write(self, b): + self.buffer += b + + def close(self): + pass + +if __name__ == '__main__': + from pyspark.test_serializers import * + if xmlrunner: + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + else: + unittest.main(verbosity=2) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 8ac1df52fc597..050c2dd018360 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -372,9 +372,16 @@ def tearDown(self): class ReusedPySparkTestCase(unittest.TestCase): + @classmethod + def conf(cls): + """ + Override this in subclasses to supply a more specific conf + """ + return SparkConf() + @classmethod def setUpClass(cls): - cls.sc = SparkContext('local[4]', cls.__name__) + cls.sc = SparkContext('local[4]', cls.__name__, conf=cls.conf()) @classmethod def tearDownClass(cls): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index e934da4d2eb6e..974344f01d923 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -324,16 +324,34 @@ def main(infile, outfile): importlib.invalidate_caches() # fetch names and values of broadcast variables + needs_broadcast_decryption_server = read_bool(infile) num_broadcast_variables = read_int(infile) + if needs_broadcast_decryption_server: + # read the decrypted data from a server in the jvm + port = read_int(infile) + auth_secret = utf8_deserializer.loads(infile) + (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret) + for _ in range(num_broadcast_variables): bid = read_long(infile) if bid >= 0: - path = utf8_deserializer.loads(infile) - _broadcastRegistry[bid] = Broadcast(path=path) + if needs_broadcast_decryption_server: + read_bid = read_long(broadcast_sock_file) + assert(read_bid == bid) + _broadcastRegistry[bid] = \ + Broadcast(sock_file=broadcast_sock_file) + else: + path = utf8_deserializer.loads(infile) + _broadcastRegistry[bid] = Broadcast(path=path) + else: bid = - bid - 1 _broadcastRegistry.pop(bid) + if needs_broadcast_decryption_server: + broadcast_sock_file.write(b'1') + broadcast_sock_file.close() + _accumulatorRegistry.clear() eval_type = read_int(infile) if eval_type == PythonEvalType.NON_UDF: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index c0830e77b5a87..482e2bfeb7098 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -17,6 +17,12 @@ package org.apache.spark.sql.api.python +import java.io.InputStream +import java.nio.channels.Channels + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.python.PythonRDDServer +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.ExpressionInfo @@ -33,19 +39,36 @@ private[sql] object PythonSQLUtils { } /** - * Python callable function to read a file in Arrow stream format and create a [[DataFrame]] + * Python callable function to read a file in Arrow stream format and create a [[RDD]] * using each serialized ArrowRecordBatch as a partition. - * - * @param sqlContext The active [[SQLContext]]. - * @param filename File to read the Arrow stream from. - * @param schemaString JSON Formatted Spark schema for Arrow batches. - * @return A new [[DataFrame]]. */ - def arrowReadStreamFromFile( - sqlContext: SQLContext, - filename: String, - schemaString: String): DataFrame = { - val jrdd = ArrowConverters.readArrowStreamFromFile(sqlContext, filename) - ArrowConverters.toDataFrame(jrdd, schemaString, sqlContext) + def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): JavaRDD[Array[Byte]] = { + ArrowConverters.readArrowStreamFromFile(sqlContext, filename) + } + + /** + * Python callable function to read a file in Arrow stream format and create a [[DataFrame]] + * from an RDD. + */ + def toDataFrame( + arrowBatchRDD: JavaRDD[Array[Byte]], + schemaString: String, + sqlContext: SQLContext): DataFrame = { + ArrowConverters.toDataFrame(arrowBatchRDD, schemaString, sqlContext) } } + +/** + * Helper for making a dataframe from arrow data from data sent from python over a socket. This is + * used when encryption is enabled, and we don't want to write data to a file. + */ +private[sql] class ArrowRDDServer(sqlContext: SQLContext) extends PythonRDDServer { + + override protected def streamToRDD(input: InputStream): RDD[Array[Byte]] = { + // Create array to consume iterator so that we can safely close the inputStream + val batches = ArrowConverters.getBatchesFromStream(Channels.newChannel(input)).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 1a48bc8398a63..2bf6a58b55658 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.arrow import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, OutputStream} -import java.nio.channels.{Channels, SeekableByteChannel} +import java.nio.channels.{Channels, ReadableByteChannel} import scala.collection.JavaConverters._ @@ -31,6 +31,7 @@ import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD import org.apache.spark.network.util.JavaUtils +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ @@ -189,7 +190,7 @@ private[sql] object ArrowConverters { } /** - * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. + * Create a DataFrame from an RDD of serialized ArrowRecordBatches. */ private[sql] def toDataFrame( arrowBatchRDD: JavaRDD[Array[Byte]], @@ -221,7 +222,7 @@ private[sql] object ArrowConverters { /** * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. */ - private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + private[sql] def getBatchesFromStream(in: ReadableByteChannel): Iterator[Array[Byte]] = { // Iterate over the serialized Arrow RecordBatch messages from a stream new Iterator[Array[Byte]] { @@ -271,7 +272,7 @@ private[sql] object ArrowConverters { } else { if (bodyLength > 0) { // Skip message body if not a RecordBatch - in.position(in.position() + bodyLength) + Channels.newInputStream(in).skip(bodyLength) } // Proceed to next message From 8f5a5a9e5b9f273443b2721f80c99dc7397ef4c0 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 6 Sep 2018 12:11:47 -0500 Subject: [PATCH 1632/2461] [PYSPARK][SQL] Updates to RowQueue Tested with updates to RowQueueSuite --- .../spark/sql/execution/python/RowQueue.scala | 27 +++++++++++++----- .../sql/execution/python/RowQueueSuite.scala | 28 ++++++++++++++----- 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala index e2fa6e7f504ba..d2820ff335ecf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala @@ -21,9 +21,10 @@ import java.io._ import com.google.common.io.Closeables -import org.apache.spark.SparkException +import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.io.NioBufferedFileInputStream import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager} +import org.apache.spark.serializer.SerializerManager import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.memory.MemoryBlock @@ -108,9 +109,13 @@ private[python] abstract class InMemoryRowQueue(val page: MemoryBlock, numFields * A RowQueue that is backed by a file on disk. This queue will stop accepting new rows once any * reader has begun reading from the queue. */ -private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueue { - private var out = new DataOutputStream( - new BufferedOutputStream(new FileOutputStream(file.toString))) +private[python] case class DiskRowQueue( + file: File, + fields: Int, + serMgr: SerializerManager) extends RowQueue { + + private var out = new DataOutputStream(serMgr.wrapForEncryption( + new BufferedOutputStream(new FileOutputStream(file.toString)))) private var unreadBytes = 0L private var in: DataInputStream = _ @@ -131,7 +136,8 @@ private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueu if (out != null) { out.close() out = null - in = new DataInputStream(new NioBufferedFileInputStream(file)) + in = new DataInputStream(serMgr.wrapForEncryption( + new NioBufferedFileInputStream(file))) } if (unreadBytes > 0) { @@ -166,7 +172,8 @@ private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueu private[python] case class HybridRowQueue( memManager: TaskMemoryManager, tempDir: File, - numFields: Int) + numFields: Int, + serMgr: SerializerManager) extends MemoryConsumer(memManager) with RowQueue { // Each buffer should have at least one row @@ -212,7 +219,7 @@ private[python] case class HybridRowQueue( } private def createDiskQueue(): RowQueue = { - DiskRowQueue(File.createTempFile("buffer", "", tempDir), numFields) + DiskRowQueue(File.createTempFile("buffer", "", tempDir), numFields, serMgr) } private def createNewQueue(required: Long): RowQueue = { @@ -279,3 +286,9 @@ private[python] case class HybridRowQueue( } } } + +private[python] object HybridRowQueue { + def apply(taskMemoryMgr: TaskMemoryManager, file: File, fields: Int): HybridRowQueue = { + HybridRowQueue(taskMemoryMgr, file, fields, SparkEnv.get.serializerManager) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala index ffda33cf906c5..1ec9986328429 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala @@ -20,12 +20,15 @@ package org.apache.spark.sql.execution.python import java.io.File import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager} +import org.apache.spark.internal.config._ +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} +import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} +import org.apache.spark.serializer.{JavaSerializer, SerializerManager} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.unsafe.memory.MemoryBlock import org.apache.spark.util.Utils -class RowQueueSuite extends SparkFunSuite { +class RowQueueSuite extends SparkFunSuite with EncryptionFunSuite { test("in-memory queue") { val page = MemoryBlock.fromLongArray(new Array[Long](1<<10)) @@ -53,10 +56,20 @@ class RowQueueSuite extends SparkFunSuite { queue.close() } - test("disk queue") { + private def createSerializerManager(conf: SparkConf): SerializerManager = { + val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) { + Some(CryptoStreamUtils.createKey(conf)) + } else { + None + } + new SerializerManager(new JavaSerializer(conf), conf, ioEncryptionKey) + } + + encryptionTest("disk queue") { conf => + val serManager = createSerializerManager(conf) val dir = Utils.createTempDir().getCanonicalFile dir.mkdirs() - val queue = DiskRowQueue(new File(dir, "buffer"), 1) + val queue = DiskRowQueue(new File(dir, "buffer"), 1, serManager) val row = new UnsafeRow(1) row.pointTo(new Array[Byte](16), 16) val n = 1000 @@ -81,11 +94,12 @@ class RowQueueSuite extends SparkFunSuite { queue.close() } - test("hybrid queue") { - val mem = new TestMemoryManager(new SparkConf()) + encryptionTest("hybrid queue") { conf => + val serManager = createSerializerManager(conf) + val mem = new TestMemoryManager(conf) mem.limit(4<<10) val taskM = new TaskMemoryManager(mem, 0) - val queue = HybridRowQueue(taskM, Utils.createTempDir().getCanonicalFile, 1) + val queue = HybridRowQueue(taskM, Utils.createTempDir().getCanonicalFile, 1, serManager) val row = new UnsafeRow(1) row.pointTo(new Array[Byte](16), 16) val n = (4<<10) / 16 * 3 From a97001d21757ae214c86371141bd78a376200f66 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 22 Aug 2018 16:38:28 -0500 Subject: [PATCH 1633/2461] [CORE] Updates to remote cache reads Covered by tests in DistributedSuite --- .../spark/network/buffer/ManagedBuffer.java | 5 +- .../spark/network/shuffle/DownloadFile.java | 47 ++++++++++ ...eManager.java => DownloadFileManager.java} | 12 +-- .../shuffle/DownloadFileWritableChannel.java | 30 ++++++ .../shuffle/ExternalShuffleClient.java | 4 +- .../shuffle/OneForOneBlockFetcher.java | 28 ++---- .../spark/network/shuffle/ShuffleClient.java | 4 +- .../network/shuffle/SimpleDownloadFile.java | 91 +++++++++++++++++++ .../spark/network/BlockTransferService.scala | 6 +- .../netty/NettyBlockTransferService.scala | 4 +- .../apache/spark/storage/BlockManager.scala | 78 +++++++++++++--- .../org/apache/spark/storage/DiskStore.scala | 16 ++++ .../storage/ShuffleBlockFetcherIterator.scala | 21 +++-- .../spark/storage/BlockManagerSuite.scala | 8 +- .../ShuffleBlockFetcherIteratorSuite.scala | 6 +- 15 files changed, 298 insertions(+), 62 deletions(-) create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFile.java rename common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/{TempFileManager.java => DownloadFileManager.java} (75%) create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileWritableChannel.java create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/SimpleDownloadFile.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java index 1861f8d7fd8f3..2d573f512437e 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java @@ -36,7 +36,10 @@ */ public abstract class ManagedBuffer { - /** Number of bytes of the data. */ + /** + * Number of bytes of the data. If this buffer will decrypt for all of the views into the data, + * this is the size of the decrypted data. + */ public abstract long size(); /** diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFile.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFile.java new file mode 100644 index 0000000000000..633622b35175b --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFile.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; + +/** + * A handle on the file used when fetching remote data to disk. Used to ensure the lifecycle of + * writing the data, reading it back, and then cleaning it up is followed. Specific implementations + * may also handle encryption. The data can be read only via DownloadFileWritableChannel, + * which ensures data is not read until after the writer is closed. + */ +public interface DownloadFile { + /** + * Delete the file. + * + * @return true if and only if the file or directory is + * successfully deleted; false otherwise + */ + boolean delete(); + + /** + * A channel for writing data to the file. This special channel allows access to the data for + * reading, after the channel is closed, via {@link DownloadFileWritableChannel#closeAndRead()}. + */ + DownloadFileWritableChannel openForWriting() throws IOException; + + /** + * The path of the file, intended only for debug purposes. + */ + String path(); +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileManager.java similarity index 75% rename from common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileManager.java index 552364d274f19..c335a17ae1fe0 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileManager.java @@ -17,20 +17,20 @@ package org.apache.spark.network.shuffle; -import java.io.File; +import org.apache.spark.network.util.TransportConf; /** - * A manager to create temp block files to reduce the memory usage and also clean temp - * files when they won't be used any more. + * A manager to create temp block files used when fetching remote data to reduce the memory usage. + * It will clean files when they won't be used any more. */ -public interface TempFileManager { +public interface DownloadFileManager { /** Create a temp block file. */ - File createTempFile(); + DownloadFile createTempFile(TransportConf transportConf); /** * Register a temp file to clean up when it won't be used any more. Return whether the * file is registered successfully. If `false`, the caller should clean up the file by itself. */ - boolean registerTempFileToClean(File file); + boolean registerTempFileToClean(DownloadFile file); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileWritableChannel.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileWritableChannel.java new file mode 100644 index 0000000000000..dbbbac43eb741 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileWritableChannel.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import org.apache.spark.network.buffer.ManagedBuffer; + +import java.nio.channels.WritableByteChannel; + +/** + * A channel for writing data which is fetched to disk, which allows access to the written data only + * after the writer has been closed. Used with DownloadFile and DownloadFileManager. + */ +public interface DownloadFileWritableChannel extends WritableByteChannel { + ManagedBuffer closeAndRead(); +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 7ed0b6e93a7a8..9a2cf0f953481 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -91,7 +91,7 @@ public void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener, - TempFileManager tempFileManager) { + DownloadFileManager downloadFileManager) { checkInit(); logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { @@ -99,7 +99,7 @@ public void fetchBlocks( (blockIds1, listener1) -> { TransportClient client = clientFactory.createClient(host, port); new OneForOneBlockFetcher(client, appId, execId, - blockIds1, listener1, conf, tempFileManager).start(); + blockIds1, listener1, conf, downloadFileManager).start(); }; int maxRetries = conf.maxIORetries(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 0bc571874f07c..30587023877c1 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -17,18 +17,13 @@ package org.apache.spark.network.shuffle; -import java.io.File; -import java.io.FileOutputStream; import java.io.IOException; import java.nio.ByteBuffer; -import java.nio.channels.Channels; -import java.nio.channels.WritableByteChannel; import java.util.Arrays; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; @@ -58,7 +53,7 @@ public class OneForOneBlockFetcher { private final BlockFetchingListener listener; private final ChunkReceivedCallback chunkCallback; private final TransportConf transportConf; - private final TempFileManager tempFileManager; + private final DownloadFileManager downloadFileManager; private StreamHandle streamHandle = null; @@ -79,14 +74,14 @@ public OneForOneBlockFetcher( String[] blockIds, BlockFetchingListener listener, TransportConf transportConf, - TempFileManager tempFileManager) { + DownloadFileManager downloadFileManager) { this.client = client; this.openMessage = new OpenBlocks(appId, execId, blockIds); this.blockIds = blockIds; this.listener = listener; this.chunkCallback = new ChunkCallback(); this.transportConf = transportConf; - this.tempFileManager = tempFileManager; + this.downloadFileManager = downloadFileManager; } /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */ @@ -125,7 +120,7 @@ public void onSuccess(ByteBuffer response) { // Immediately request all chunks -- we expect that the total size of the request is // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. for (int i = 0; i < streamHandle.numChunks; i++) { - if (tempFileManager != null) { + if (downloadFileManager != null) { client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i), new DownloadCallback(i)); } else { @@ -159,13 +154,13 @@ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) { private class DownloadCallback implements StreamCallback { - private WritableByteChannel channel = null; - private File targetFile = null; + private DownloadFileWritableChannel channel = null; + private DownloadFile targetFile = null; private int chunkIndex; DownloadCallback(int chunkIndex) throws IOException { - this.targetFile = tempFileManager.createTempFile(); - this.channel = Channels.newChannel(new FileOutputStream(targetFile)); + this.targetFile = downloadFileManager.createTempFile(transportConf); + this.channel = targetFile.openForWriting(); this.chunkIndex = chunkIndex; } @@ -178,11 +173,8 @@ public void onData(String streamId, ByteBuffer buf) throws IOException { @Override public void onComplete(String streamId) throws IOException { - channel.close(); - ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0, - targetFile.length()); - listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer); - if (!tempFileManager.registerTempFileToClean(targetFile)) { + listener.onBlockFetchSuccess(blockIds[chunkIndex], channel.closeAndRead()); + if (!downloadFileManager.registerTempFileToClean(targetFile)) { targetFile.delete(); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index 18b04fedcac5b..62b99c40f61f9 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -43,7 +43,7 @@ public void init(String appId) { } * @param execId the executor id. * @param blockIds block ids to fetch. * @param listener the listener to receive block fetching status. - * @param tempFileManager TempFileManager to create and clean temp files. + * @param downloadFileManager DownloadFileManager to create and clean temp files. * If it's not null, the remote blocks will be streamed * into temp shuffle files to reduce the memory usage, otherwise, * they will be kept in memory. @@ -54,7 +54,7 @@ public abstract void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener, - TempFileManager tempFileManager); + DownloadFileManager downloadFileManager); /** * Get the shuffle MetricsSet from ShuffleClient, this will be used in MetricsSystem to diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/SimpleDownloadFile.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/SimpleDownloadFile.java new file mode 100644 index 0000000000000..670612fd6f66a --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/SimpleDownloadFile.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.shuffle; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.util.TransportConf; + +/** + * A DownloadFile that does not take any encryption settings into account for reading and + * writing data. + * + * This does *not* mean the data in the file is un-encrypted -- it could be that the data is + * already encrypted when its written, and subsequent layer is responsible for decrypting. + */ +public class SimpleDownloadFile implements DownloadFile { + + private final File file; + private final TransportConf transportConf; + + public SimpleDownloadFile(File file, TransportConf transportConf) { + this.file = file; + this.transportConf = transportConf; + } + + @Override + public boolean delete() { + return file.delete(); + } + + @Override + public DownloadFileWritableChannel openForWriting() throws IOException { + return new SimpleDownloadWritableChannel(); + } + + @Override + public String path() { + return file.getAbsolutePath(); + } + + private class SimpleDownloadWritableChannel implements DownloadFileWritableChannel { + + private final WritableByteChannel channel; + + SimpleDownloadWritableChannel() throws FileNotFoundException { + channel = Channels.newChannel(new FileOutputStream(file)); + } + + @Override + public ManagedBuffer closeAndRead() { + return new FileSegmentManagedBuffer(transportConf, file, 0, file.length()); + } + + @Override + public int write(ByteBuffer src) throws IOException { + return channel.write(src); + } + + @Override + public boolean isOpen() { + return channel.isOpen(); + } + + @Override + public void close() throws IOException { + channel.close(); + } + } +} diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 1d8a266d0079c..eef8c31e05ab1 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -26,7 +26,7 @@ import scala.reflect.ClassTag import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ShuffleClient} import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.ThreadUtils @@ -68,7 +68,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo execId: String, blockIds: Array[String], listener: BlockFetchingListener, - tempFileManager: TempFileManager): Unit + tempFileManager: DownloadFileManager): Unit /** * Upload a single block to a remote node, available only after [[init]] is invoked. @@ -92,7 +92,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo port: Int, execId: String, blockId: String, - tempFileManager: TempFileManager): ManagedBuffer = { + tempFileManager: DownloadFileManager): ManagedBuffer = { // A monitor for the thread to wait on. val result = Promise[ManagedBuffer]() fetchBlocks(host, port, execId, Array(blockId), diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 1905632a936d3..dc55685b1e7bd 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -33,7 +33,7 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory} import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap} import org.apache.spark.network.server._ -import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, OneForOneBlockFetcher, RetryingBlockFetcher} import org.apache.spark.network.shuffle.protocol.{UploadBlock, UploadBlockStream} import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer @@ -106,7 +106,7 @@ private[spark] class NettyBlockTransferService( execId: String, blockIds: Array[String], listener: BlockFetchingListener, - tempFileManager: TempFileManager): Unit = { + tempFileManager: DownloadFileManager): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index f5c69ad241e3a..22341467add5c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -33,6 +33,7 @@ import scala.util.Random import scala.util.control.NonFatal import com.codahale.metrics.{MetricRegistry, MetricSet} +import com.google.common.io.CountingOutputStream import org.apache.spark._ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} @@ -43,8 +44,9 @@ import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.StreamCallbackWithID import org.apache.spark.network.netty.SparkTransportConf -import org.apache.spark.network.shuffle.{ExternalShuffleClient, TempFileManager} +import org.apache.spark.network.shuffle._ import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo +import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.serializer.{SerializerInstance, SerializerManager} @@ -213,11 +215,11 @@ private[spark] class BlockManager( private var blockReplicationPolicy: BlockReplicationPolicy = _ - // A TempFileManager used to track all the files of remote blocks which above the + // A DownloadFileManager used to track all the files of remote blocks which are above the // specified memory threshold. Files will be deleted automatically based on weak reference. // Exposed for test private[storage] val remoteBlockTempFileManager = - new BlockManager.RemoteBlockTempFileManager(this) + new BlockManager.RemoteBlockDownloadFileManager(this) private val maxRemoteBlockToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) /** @@ -1664,23 +1666,28 @@ private[spark] object BlockManager { metricRegistry.registerAll(metricSet) } - class RemoteBlockTempFileManager(blockManager: BlockManager) - extends TempFileManager with Logging { + class RemoteBlockDownloadFileManager(blockManager: BlockManager) + extends DownloadFileManager with Logging { + // lazy because SparkEnv is set after this + lazy val encryptionKey = SparkEnv.get.securityManager.getIOEncryptionKey() - private class ReferenceWithCleanup(file: File, referenceQueue: JReferenceQueue[File]) - extends WeakReference[File](file, referenceQueue) { - private val filePath = file.getAbsolutePath + private class ReferenceWithCleanup( + file: DownloadFile, + referenceQueue: JReferenceQueue[DownloadFile] + ) extends WeakReference[DownloadFile](file, referenceQueue) { + + val filePath = file.path() def cleanUp(): Unit = { logDebug(s"Clean up file $filePath") - if (!new File(filePath).delete()) { + if (!file.delete()) { logDebug(s"Fail to delete file $filePath") } } } - private val referenceQueue = new JReferenceQueue[File] + private val referenceQueue = new JReferenceQueue[DownloadFile] private val referenceBuffer = Collections.newSetFromMap[ReferenceWithCleanup]( new ConcurrentHashMap) @@ -1692,11 +1699,21 @@ private[spark] object BlockManager { cleaningThread.setName("RemoteBlock-temp-file-clean-thread") cleaningThread.start() - override def createTempFile(): File = { - blockManager.diskBlockManager.createTempLocalBlock()._2 + override def createTempFile(transportConf: TransportConf): DownloadFile = { + val file = blockManager.diskBlockManager.createTempLocalBlock()._2 + encryptionKey match { + case Some(key) => + // encryption is enabled, so when we read the decrypted data off the network, we need to + // encrypt it when writing to disk. Note that the data may have been encrypted when it + // was cached on disk on the remote side, but it was already decrypted by now (see + // EncryptedBlockData). + new EncryptedDownloadFile(file, key) + case None => + new SimpleDownloadFile(file, transportConf) + } } - override def registerTempFileToClean(file: File): Boolean = { + override def registerTempFileToClean(file: DownloadFile): Boolean = { referenceBuffer.add(new ReferenceWithCleanup(file, referenceQueue)) } @@ -1724,4 +1741,39 @@ private[spark] object BlockManager { } } } + + /** + * A DownloadFile that encrypts data when it is written, and decrypts when it's read. + */ + private class EncryptedDownloadFile( + file: File, + key: Array[Byte]) extends DownloadFile { + + private val env = SparkEnv.get + + override def delete(): Boolean = file.delete() + + override def openForWriting(): DownloadFileWritableChannel = { + new EncryptedDownloadWritableChannel() + } + + override def path(): String = file.getAbsolutePath + + private class EncryptedDownloadWritableChannel extends DownloadFileWritableChannel { + private val countingOutput: CountingWritableChannel = new CountingWritableChannel( + Channels.newChannel(env.serializerManager.wrapForEncryption(new FileOutputStream(file)))) + + override def closeAndRead(): ManagedBuffer = { + countingOutput.close() + val size = countingOutput.getCount + new EncryptedManagedBuffer(new EncryptedBlockData(file, size, env.conf, key)) + } + + override def write(src: ByteBuffer): Int = countingOutput.write(src) + + override def isOpen: Boolean = countingOutput.isOpen() + + override def close(): Unit = countingOutput.close() + } + } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index a820bc70b33b2..d88bd710d1ead 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -30,6 +30,7 @@ import io.netty.channel.DefaultFileRegion import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.{config, Logging} +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.util.{AbstractFileRegion, JavaUtils} import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.util.Utils @@ -260,7 +261,22 @@ private class EncryptedBlockData( throw e } } +} + +private class EncryptedManagedBuffer(val blockData: EncryptedBlockData) extends ManagedBuffer { + + // This is the size of the decrypted data + override def size(): Long = blockData.size + + override def nioByteBuffer(): ByteBuffer = blockData.toByteBuffer() + + override def convertToNetty(): AnyRef = blockData.toNetty() + + override def createInputStream(): InputStream = blockData.toInputStream() + + override def retain(): ManagedBuffer = this + override def release(): ManagedBuffer = this } private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: Long) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 00d01dd28afb5..e534c746433f2 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.{File, InputStream, IOException} +import java.io.{InputStream, IOException} import java.nio.ByteBuffer import java.util.concurrent.LinkedBlockingQueue import javax.annotation.concurrent.GuardedBy @@ -28,7 +28,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager} +import org.apache.spark.network.shuffle._ +import org.apache.spark.network.util.TransportConf import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.Utils import org.apache.spark.util.io.ChunkedByteBufferOutputStream @@ -71,7 +72,7 @@ final class ShuffleBlockFetcherIterator( maxBlocksInFlightPerAddress: Int, maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean) - extends Iterator[(BlockId, InputStream)] with TempFileManager with Logging { + extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { import ShuffleBlockFetcherIterator._ @@ -150,7 +151,7 @@ final class ShuffleBlockFetcherIterator( * deleted when cleanup. This is a layer of defensiveness against disk file leaks. */ @GuardedBy("this") - private[this] val shuffleFilesSet = mutable.HashSet[File]() + private[this] val shuffleFilesSet = mutable.HashSet[DownloadFile]() initialize() @@ -164,11 +165,15 @@ final class ShuffleBlockFetcherIterator( currentResult = null } - override def createTempFile(): File = { - blockManager.diskBlockManager.createTempLocalBlock()._2 + override def createTempFile(transportConf: TransportConf): DownloadFile = { + // we never need to do any encryption or decryption here, regardless of configs, because that + // is handled at another layer in the code. When encryption is enabled, shuffle data is written + // to disk encrypted in the first place, and sent over the network still encrypted. + new SimpleDownloadFile( + blockManager.diskBlockManager.createTempLocalBlock()._2, transportConf) } - override def registerTempFileToClean(file: File): Boolean = synchronized { + override def registerTempFileToClean(file: DownloadFile): Boolean = synchronized { if (isZombie) { false } else { @@ -204,7 +209,7 @@ final class ShuffleBlockFetcherIterator( } shuffleFilesSet.foreach { file => if (!file.delete()) { - logWarning("Failed to cleanup shuffle fetch temp file " + file.getAbsolutePath()) + logWarning("Failed to cleanup shuffle fetch temp file " + file.path()) } } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index dbee1f60d7af0..32d6e8b94e1a2 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -43,7 +43,7 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap} -import org.apache.spark.network.shuffle.{BlockFetchingListener, TempFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor} import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.RpcEnv @@ -1437,7 +1437,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { var numCalls = 0 - var tempFileManager: TempFileManager = null + var tempFileManager: DownloadFileManager = null override def init(blockDataManager: BlockDataManager): Unit = {} @@ -1447,7 +1447,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE execId: String, blockIds: Array[String], listener: BlockFetchingListener, - tempFileManager: TempFileManager): Unit = { + tempFileManager: DownloadFileManager): Unit = { listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) } @@ -1474,7 +1474,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE port: Int, execId: String, blockId: String, - tempFileManager: TempFileManager): ManagedBuffer = { + tempFileManager: DownloadFileManager): ManagedBuffer = { numCalls += 1 this.tempFileManager = tempFileManager if (numCalls <= maxFailures) { diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index a2997dbd1b1ac..b268195e09a5b 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -33,7 +33,7 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.{SparkFunSuite, TaskContext} import org.apache.spark.network._ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.shuffle.{BlockFetchingListener, TempFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager} import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.Utils @@ -478,12 +478,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val remoteBlocks = Map[BlockId, ManagedBuffer]( ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) val transfer = mock(classOf[BlockTransferService]) - var tempFileManager: TempFileManager = null + var tempFileManager: DownloadFileManager = null when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - tempFileManager = invocation.getArguments()(5).asInstanceOf[TempFileManager] + tempFileManager = invocation.getArguments()(5).asInstanceOf[DownloadFileManager] Future { listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) From 0f1413e320bbf9804dac1b00d56f30bc20dc36a6 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 18 Sep 2018 10:10:20 +0800 Subject: [PATCH 1634/2461] [SPARK-25443][BUILD] fix issues when building docs with release scripts in docker ## What changes were proposed in this pull request? These 2 changes are required to build the docs for Spark 2.4.0 RC1: 1. install `mkdocs` in the docker image 2. set locale to C.UTF-8. Otherwise jekyll fails to build the doc. ## How was this patch tested? tested manually when doing the 2.4.0 RC1 Closes #22438 from cloud-fan/infra. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- dev/create-release/release-build.sh | 5 ++--- dev/create-release/spark-rm/Dockerfile | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 098aa5745e34d..4753c29b03874 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -76,9 +76,8 @@ for env in ASF_USERNAME GPG_PASSPHRASE GPG_KEY; do fi done -# Explicitly set locale in order to make `sort` output consistent across machines. -# See https://stackoverflow.com/questions/28881 for more details. -export LC_ALL=C +export LC_ALL=C.UTF-8 +export LANG=C.UTF-8 # Commit ref to checkout when building GIT_REF=${GIT_REF:-master} diff --git a/dev/create-release/spark-rm/Dockerfile b/dev/create-release/spark-rm/Dockerfile index 07ce320177f5a..15f831cf06a66 100644 --- a/dev/create-release/spark-rm/Dockerfile +++ b/dev/create-release/spark-rm/Dockerfile @@ -73,7 +73,7 @@ RUN echo 'deb http://cran.cnr.Berkeley.edu/bin/linux/ubuntu xenial/' >> /etc/apt Rscript -e "install.packages(c('curl', 'xml2', 'httr', 'devtools', 'testthat', 'knitr', 'rmarkdown', 'roxygen2', 'e1071', 'survival'), repos='http://cran.us.r-project.org/')" && \ Rscript -e "devtools::install_github('jimhester/lintr')" && \ # Install tools needed to build the documentation. - $APT_INSTALL ruby2.3 ruby2.3-dev && \ + $APT_INSTALL ruby2.3 ruby2.3-dev mkdocs && \ gem install jekyll --no-rdoc --no-ri && \ gem install jekyll-redirect-from && \ gem install pygments.rb From acc6452579aca99aae9f9787ddbe5c4aeb170e58 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 18 Sep 2018 12:44:54 +0900 Subject: [PATCH 1635/2461] [SPARK-25444][SQL] Refactor GenArrayData.genCodeToCreateArrayData method ## What changes were proposed in this pull request? This PR makes `GenArrayData.genCodeToCreateArrayData` method simple by using `ArrayData.createArrayData` method. Before this PR, `genCodeToCreateArrayData` method was complicated * Generated a temporary Java array to create `ArrayData` * Had separate code generation path to assign values for `GenericArrayData` and `UnsafeArrayData` After this PR, the method * Directly generates `GenericArrayData` or `UnsafeArrayData` without a temporary array * Has only code generation path to assign values ## How was this patch tested? Existing UTs Closes #22439 from kiszk/SPARK-25444. Authored-by: Kazuaki Ishizaki Signed-off-by: Takuya UESHIN --- .../expressions/complexTypeCreator.scala | 122 +++++++----------- 1 file changed, 45 insertions(+), 77 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index fd8b5e94fe48d..0361372b6b732 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -61,11 +61,10 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val et = dataType.elementType - val evals = children.map(e => e.genCode(ctx)) - val (preprocess, assigns, postprocess, arrayData) = - GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) + val (allocation, assigns, arrayData) = + GenArrayData.genCodeToCreateArrayData(ctx, et, children, false, "createArray") ev.copy( - code = code"${preprocess}${assigns}${postprocess}", + code = code"${allocation}${assigns}", value = JavaCode.variable(arrayData, dataType), isNull = FalseLiteral) } @@ -75,87 +74,60 @@ case class CreateArray(children: Seq[Expression]) extends Expression { private [sql] object GenArrayData { /** - * Return Java code pieces based on DataType and isPrimitive to allocate ArrayData class + * Return Java code pieces based on DataType and array size to allocate ArrayData class * * @param ctx a [[CodegenContext]] * @param elementType data type of underlying array elements - * @param elementsCode concatenated set of [[ExprCode]] for each element of an underlying array + * @param elementsExpr concatenated set of [[Expression]] for each element of an underlying array * @param isMapKey if true, throw an exception when the element is null - * @return (code pre-assignments, concatenated assignments to each array elements, - * code post-assignments, arrayData name) + * @param functionName string to include in the error message + * @return (array allocation, concatenated assignments to each array elements, arrayData name) */ def genCodeToCreateArrayData( ctx: CodegenContext, elementType: DataType, - elementsCode: Seq[ExprCode], - isMapKey: Boolean): (String, String, String, String) = { + elementsExpr: Seq[Expression], + isMapKey: Boolean, + functionName: String): (String, String, String) = { val arrayDataName = ctx.freshName("arrayData") - val numElements = elementsCode.length + val numElements = s"${elementsExpr.length}L" - if (!CodeGenerator.isPrimitiveType(elementType)) { - val arrayName = ctx.freshName("arrayObject") - val genericArrayClass = classOf[GenericArrayData].getName + val initialization = CodeGenerator.createArrayData( + arrayDataName, elementType, numElements, s" $functionName failed.") - val assignments = elementsCode.zipWithIndex.map { case (eval, i) => - val isNullAssignment = if (!isMapKey) { - s"$arrayName[$i] = null;" - } else { - "throw new RuntimeException(\"Cannot use null as map key!\");" - } - eval.code + s""" - if (${eval.isNull}) { - $isNullAssignment - } else { - $arrayName[$i] = ${eval.value}; - } - """ - } - val assignmentString = ctx.splitExpressionsWithCurrentInputs( - expressions = assignments, - funcName = "apply", - extraArguments = ("Object[]", arrayName) :: Nil) - - (s"Object[] $arrayName = new Object[$numElements];", - assignmentString, - s"final ArrayData $arrayDataName = new $genericArrayClass($arrayName);", - arrayDataName) - } else { - val arrayName = ctx.freshName("array") - val unsafeArraySizeInBytes = - UnsafeArrayData.calculateHeaderPortionInBytes(numElements) + - ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) - val baseOffset = Platform.BYTE_ARRAY_OFFSET - - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - val assignments = elementsCode.zipWithIndex.map { case (eval, i) => + val assignments = elementsExpr.zipWithIndex.map { case (expr, i) => + val eval = expr.genCode(ctx) + val setArrayElement = CodeGenerator.setArrayElement( + arrayDataName, elementType, i.toString, eval.value) + + val assignment = if (!expr.nullable) { + setArrayElement + } else { val isNullAssignment = if (!isMapKey) { s"$arrayDataName.setNullAt($i);" } else { "throw new RuntimeException(\"Cannot use null as map key!\");" } - eval.code + s""" - if (${eval.isNull}) { - $isNullAssignment - } else { - $arrayDataName.set$primitiveValueTypeName($i, ${eval.value}); - } - """ + + s""" + |if (${eval.isNull}) { + | $isNullAssignment + |} else { + | $setArrayElement + |} + """.stripMargin } - val assignmentString = ctx.splitExpressionsWithCurrentInputs( - expressions = assignments, - funcName = "apply", - extraArguments = ("UnsafeArrayData", arrayDataName) :: Nil) - - (s""" - byte[] $arrayName = new byte[$unsafeArraySizeInBytes]; - UnsafeArrayData $arrayDataName = new UnsafeArrayData(); - Platform.putLong($arrayName, $baseOffset, $numElements); - $arrayDataName.pointTo($arrayName, $baseOffset, $unsafeArraySizeInBytes); - """, - assignmentString, - "", - arrayDataName) + s""" + |${eval.code} + |$assignment + """.stripMargin } + val assignmentString = ctx.splitExpressionsWithCurrentInputs( + expressions = assignments, + funcName = "apply", + extraArguments = ("ArrayData", arrayDataName) :: Nil) + + (initialization, assignmentString, arrayDataName) } } @@ -216,21 +188,17 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val mapClass = classOf[ArrayBasedMapData].getName val MapType(keyDt, valueDt, _) = dataType - val evalKeys = keys.map(e => e.genCode(ctx)) - val evalValues = values.map(e => e.genCode(ctx)) - val (preprocessKeyData, assignKeys, postprocessKeyData, keyArrayData) = - GenArrayData.genCodeToCreateArrayData(ctx, keyDt, evalKeys, true) - val (preprocessValueData, assignValues, postprocessValueData, valueArrayData) = - GenArrayData.genCodeToCreateArrayData(ctx, valueDt, evalValues, false) + val (allocationKeyData, assignKeys, keyArrayData) = + GenArrayData.genCodeToCreateArrayData(ctx, keyDt, keys, true, "createMap") + val (allocationValueData, assignValues, valueArrayData) = + GenArrayData.genCodeToCreateArrayData(ctx, valueDt, values, false, "createMap") val code = code""" final boolean ${ev.isNull} = false; - $preprocessKeyData + $allocationKeyData $assignKeys - $postprocessKeyData - $preprocessValueData + $allocationValueData $assignValues - $postprocessValueData final MapData ${ev.value} = new $mapClass($keyArrayData, $valueArrayData); """ ev.copy(code = code) From ba838fee001d553e30d2205337c1fa5ccbd57caf Mon Sep 17 00:00:00 2001 From: James Thompson Date: Mon, 17 Sep 2018 23:19:04 -0700 Subject: [PATCH 1636/2461] [SPARK-24151][SQL] Case insensitive resolution of CURRENT_DATE and CURRENT_TIMESTAMP ## What changes were proposed in this pull request? SPARK-22333 introduced a regression in the resolution of `CURRENT_DATE` and `CURRENT_TIMESTAMP`. Before that ticket, these 2 functions were resolved in a case insensitive way. After, this depends on the value of `spark.sql.caseSensitive`. The PR restores the previous behavior and makes their resolution case insensitive anyhow. The PR takes over #21217, therefore it closes #21217 and credit for this patch should be given to jamesthomp. ## How was this patch tested? added UT Closes #22440 from mgaido91/SPARK-24151. Lead-authored-by: James Thompson Co-authored-by: Marco Gaido Signed-off-by: gatorsmile --- docs/sql-programming-guide.md | 1 + .../spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 18 ++++++++++++++++++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index e262987ab23de..f25415c0bc748 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1879,6 +1879,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.3 to 2.4 + - In versions 2.2.1+ and 2.3, if `spark.sql.caseSensitive` is set to true, then the `CURRENT_DATE` and `CURRENT_TIMESTAMP` functions incorrectly became case-sensitive and would resolve to columns (unless typed in lower case). In Spark 2.4 this has been fixed and the functions are no longer case-sensitive. - Since Spark 2.4, Spark will evaluate the set operations referenced in a query by following a precedence rule as per the SQL standard. If the order is not specified by parentheses, set operations are performed from left to right with the exception that all INTERSECT operations are performed before any UNION, EXCEPT or MINUS operations. The old behaviour of giving equal precedence to all the set operations are preserved under a newly added configuration `spark.sql.legacy.setopsPrecedence.enabled` with a default value of `false`. When this property is set to `true`, spark will evaluate the set operators from left to right as they appear in the query given no explicit ordering is enforced by usage of parenthesis. - Since Spark 2.4, Spark will display table description column Last Access value as UNKNOWN when the value was Jan 01 1970. - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 580133dd971b1..e3b17121bf350 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1045,7 +1045,7 @@ class Analyzer( // support CURRENT_DATE and CURRENT_TIMESTAMP val literalFunctions = Seq(CurrentDate(), CurrentTimestamp()) val name = nameParts.head - val func = literalFunctions.find(e => resolver(e.prettyName, name)) + val func = literalFunctions.find(e => caseInsensitiveResolution(e.prettyName, name)) func.map(wrapper) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 3b3edac0a314e..f9facbb71a4e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -32,6 +32,8 @@ import org.apache.spark.sql.catalyst.plans.{Cross, Inner} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning} +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -586,4 +588,20 @@ class AnalysisSuite extends AnalysisTest with Matchers { listRelation.select(MultiAlias(MultiAlias( PosExplode('list), Seq("first_pos", "first_val")), Seq("second_pos", "second_val")))) } + + test("SPARK-24151: CURRENT_DATE, CURRENT_TIMESTAMP should be case insensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val input = Project(Seq( + UnresolvedAttribute("current_date"), + UnresolvedAttribute("CURRENT_DATE"), + UnresolvedAttribute("CURRENT_TIMESTAMP"), + UnresolvedAttribute("current_timestamp")), testRelation) + val expected = Project(Seq( + Alias(CurrentDate(), toPrettySQL(CurrentDate()))(), + Alias(CurrentDate(), toPrettySQL(CurrentDate()))(), + Alias(CurrentTimestamp(), toPrettySQL(CurrentTimestamp()))(), + Alias(CurrentTimestamp(), toPrettySQL(CurrentTimestamp()))()), testRelation).analyze + checkAnalysis(input, expected) + } + } } From 1c0423b28705eb96237c0cb4e90f49305c64a997 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 18 Sep 2018 22:29:00 +0800 Subject: [PATCH 1637/2461] [SPARK-25445][BUILD] the release script should be able to publish a scala-2.12 build ## What changes were proposed in this pull request? update the package and publish steps, to support scala 2.12 ## How was this patch tested? manual test Closes #22441 from cloud-fan/scala. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- dev/create-release/release-build.sh | 49 +++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 4753c29b03874..4c90a772104fc 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -111,13 +111,17 @@ fi # different versions of Scala are supported. BASE_PROFILES="-Pmesos -Pyarn" PUBLISH_SCALA_2_10=0 +PUBLISH_SCALA_2_12=0 SCALA_2_10_PROFILES="-Pscala-2.10" SCALA_2_11_PROFILES= -SCALA_2_12_PROFILES="-Pscala-2.12" +SCALA_2_12_PROFILES="-Pscala-2.12 -Pkafka-0-8" if [[ $SPARK_VERSION > "2.3" ]]; then BASE_PROFILES="$BASE_PROFILES -Pkubernetes -Pflume" SCALA_2_11_PROFILES="-Pkafka-0-8" + if [[ $SPARK_VERSION > "2.4" ]]; then + PUBLISH_SCALA_2_12=1 + fi else PUBLISH_SCALA_2_10=1 fi @@ -186,8 +190,17 @@ if [[ "$1" == "package" ]]; then # Updated for each binary build make_binary_release() { NAME=$1 - FLAGS="$MVN_EXTRA_OPTS -B $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES $2" - BUILD_PACKAGE=$3 + SCALA_VERSION=$2 + SCALA_PROFILES= + if [[ SCALA_VERSION == "2.10" ]]; then + SCALA_PROFILES="$SCALA_2_10_PROFILES" + elif [[ SCALA_VERSION == "2.12" ]]; then + SCALA_PROFILES="$SCALA_2_12_PROFILES" + else + SCALA_PROFILES="$SCALA_2_11_PROFILES" + fi + FLAGS="$MVN_EXTRA_OPTS -B $SCALA_PROFILES $BASE_RELEASE_PROFILES $3" + BUILD_PACKAGE=$4 # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds # share the same Zinc server. @@ -197,10 +210,11 @@ if [[ "$1" == "package" ]]; then cp -r spark spark-$SPARK_VERSION-bin-$NAME cd spark-$SPARK_VERSION-bin-$NAME - # TODO There should probably be a flag to make-distribution to allow 2.12 support - #if [[ $FLAGS == *scala-2.12* ]]; then - # ./dev/change-scala-version.sh 2.12 - #fi + if [[ SCALA_VERSION == "2.10" ]]; then + ./dev/change-scala-version.sh 2.10 + elif [[ SCALA_VERSION == "2.12" ]]; then + ./dev/change-scala-version.sh 2.12 + fi export ZINC_PORT=$ZINC_PORT echo "Creating distribution: $NAME ($FLAGS)" @@ -291,11 +305,20 @@ if [[ "$1" == "package" ]]; then for key in ${!BINARY_PKGS_ARGS[@]}; do args=${BINARY_PKGS_ARGS[$key]} extra=${BINARY_PKGS_EXTRA[$key]} - if ! make_binary_release "$key" "$args" "$extra"; then + if ! make_binary_release "$key" "2.11" "$args" "$extra"; then error "Failed to build $key package. Check logs for details." fi done + if [[ $PUBLISH_SCALA_2_12 = 1 ]]; then + key="without-hadoop-scala-2.12" + args="-Phadoop-provided" + extra="" + if ! make_binary_release "$key" "2.12" "$args" "$extra"; then + error "Failed to build $key package. Check logs for details." + fi + fi + rm -rf spark-$SPARK_VERSION-bin-*/ if ! is_dry_run; then @@ -414,15 +437,15 @@ if [[ "$1" == "publish-release" ]]; then -DskipTests $PUBLISH_PROFILES $SCALA_2_10_PROFILES clean install fi - #./dev/change-scala-version.sh 2.12 - #$MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo \ - # -DskipTests $SCALA_2_12_PROFILES §$PUBLISH_PROFILES clean install + if ! is_dry_run && [[ $PUBLISH_SCALA_2_12 = 1 ]]; then + ./dev/change-scala-version.sh 2.12 + $MVN -DzincPort=$((ZINC_PORT + 2)) -Dmaven.repo.local=$tmp_repo -Dscala-2.12 \ + -DskipTests $PUBLISH_PROFILES $SCALA_2_12_PROFILES clean install + fi # Clean-up Zinc nailgun process $LSOF -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill - #./dev/change-scala-version.sh 2.11 - pushd $tmp_repo/org/apache/spark # Remove any extra files generated during install From 182da81e9e75ac1658a39014beb90e60495bf544 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 18 Sep 2018 10:38:55 -0500 Subject: [PATCH 1638/2461] [SPARK-19550][DOC][FOLLOW-UP] Update tuning.md to use JDK8 ## What changes were proposed in this pull request? Update `tuning.md` and `rdd-programming-guide.md` to use JDK8. ## How was this patch tested? manual tests Closes #22446 from wangyum/java8. Authored-by: Yuming Wang Signed-off-by: Sean Owen --- docs/rdd-programming-guide.md | 4 ++-- docs/tuning.md | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index d95b757f36859..005425754c646 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -859,7 +859,7 @@ We could also use `counts.sortByKey()`, for example, to sort the pairs alphabeti **Note:** when using custom objects as the key in key-value pair operations, you must be sure that a custom `equals()` method is accompanied with a matching `hashCode()` method. For full details, see the contract outlined in the [Object.hashCode() -documentation](http://docs.oracle.com/javase/7/docs/api/java/lang/Object.html#hashCode()). +documentation](https://docs.oracle.com/javase/8/docs/api/java/lang/Object.html#hashCode--). @@ -896,7 +896,7 @@ We could also use `counts.sortByKey()`, for example, to sort the pairs alphabeti **Note:** when using custom objects as the key in key-value pair operations, you must be sure that a custom `equals()` method is accompanied with a matching `hashCode()` method. For full details, see the contract outlined in the [Object.hashCode() -documentation](http://docs.oracle.com/javase/7/docs/api/java/lang/Object.html#hashCode()). +documentation](https://docs.oracle.com/javase/8/docs/api/java/lang/Object.html#hashCode--). diff --git a/docs/tuning.md b/docs/tuning.md index f60971aa2e0af..cd0f9cd081369 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -26,12 +26,12 @@ Often, this will be the first thing you should tune to optimize a Spark applicat Spark aims to strike a balance between convenience (allowing you to work with any Java type in your operations) and performance. It provides two serialization libraries: -* [Java serialization](http://docs.oracle.com/javase/6/docs/api/java/io/Serializable.html): +* [Java serialization](https://docs.oracle.com/javase/8/docs/api/java/io/Serializable.html): By default, Spark serializes objects using Java's `ObjectOutputStream` framework, and can work with any class you create that implements - [`java.io.Serializable`](http://docs.oracle.com/javase/6/docs/api/java/io/Serializable.html). + [`java.io.Serializable`](https://docs.oracle.com/javase/8/docs/api/java/io/Serializable.html). You can also control the performance of your serialization more closely by extending - [`java.io.Externalizable`](http://docs.oracle.com/javase/6/docs/api/java/io/Externalizable.html). + [`java.io.Externalizable`](https://docs.oracle.com/javase/8/docs/api/java/io/Externalizable.html). Java serialization is flexible but often quite slow, and leads to large serialized formats for many classes. * [Kryo serialization](https://github.com/EsotericSoftware/kryo): Spark can also use @@ -230,7 +230,7 @@ temporary objects created during task execution. Some steps which may be useful * Monitor how the frequency and time taken by garbage collection changes with the new settings. Our experience suggests that the effect of GC tuning depends on your application and the amount of memory available. -There are [many more tuning options](http://www.oracle.com/technetwork/java/javase/gc-tuning-6-140523.html) described online, +There are [many more tuning options](https://docs.oracle.com/javase/8/docs/technotes/guides/vm/gctuning/index.html) described online, but at a high level, managing how frequently full GC takes place can help in reducing the overhead. GC tuning flags for executors can be specified by setting `spark.executor.extraJavaOptions` in From 123f0041d534f28e14343aafb4e5cec19dde14ad Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Tue, 18 Sep 2018 11:43:35 -0700 Subject: [PATCH 1639/2461] [SPARK-25291][K8S] Fixing Flakiness of Executor Pod tests ## What changes were proposed in this pull request? Added fix to flakiness that was present in PySpark tests w.r.t Executors not being tested. Important fix to executorConf which was failing tests when executors *were* tested ## How was this patch tested? Unit and Integration tests Closes #22415 from ifilonenko/SPARK-25291. Authored-by: Ilan Filonenko Signed-off-by: Yinan Li --- .../k8s/integrationtest/KubernetesSuite.scala | 35 +++++++++++++------ .../KubernetesTestComponents.scala | 1 - .../integrationtest/SecretsTestsSuite.scala | 3 +- 3 files changed, 26 insertions(+), 13 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 18541baf05813..c99a907f98d0a 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -23,7 +23,10 @@ import java.util.regex.Pattern import com.google.common.io.PatternFilenameFilter import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.{KubernetesClientException, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, Tag} +import org.scalatest.Matchers import org.scalatest.concurrent.{Eventually, PatienceConfiguration} import org.scalatest.time.{Minutes, Seconds, Span} import scala.collection.JavaConverters._ @@ -31,10 +34,12 @@ import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite import org.apache.spark.deploy.k8s.integrationtest.TestConfig._ import org.apache.spark.deploy.k8s.integrationtest.backend.{IntegrationTestBackend, IntegrationTestBackendFactory} +import org.apache.spark.internal.Logging private[spark] class KubernetesSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter with BasicTestsSuite with SecretsTestsSuite - with PythonTestsSuite with ClientModeTestsSuite { + with PythonTestsSuite with ClientModeTestsSuite + with Logging with Eventually with Matchers { import KubernetesSuite._ @@ -223,17 +228,28 @@ private[spark] class KubernetesSuite extends SparkFunSuite .getItems .get(0) driverPodChecker(driverPod) - - val executorPods = kubernetesTestComponents.kubernetesClient + val execPods = scala.collection.mutable.Map[String, Pod]() + val execWatcher = kubernetesTestComponents.kubernetesClient .pods() .withLabel("spark-app-locator", appLocator) .withLabel("spark-role", "executor") - .list() - .getItems - executorPods.asScala.foreach { pod => - executorPodChecker(pod) - } - + .watch(new Watcher[Pod] { + logInfo("Beginning watch of executors") + override def onClose(cause: KubernetesClientException): Unit = + logInfo("Ending watch of executors") + override def eventReceived(action: Watcher.Action, resource: Pod): Unit = { + val name = resource.getMetadata.getName + action match { + case Action.ADDED | Action.MODIFIED => + execPods(name) = resource + case Action.DELETED | Action.ERROR => + execPods.remove(name) + } + } + }) + Eventually.eventually(TIMEOUT, INTERVAL) { execPods.values.nonEmpty should be (true) } + execWatcher.close() + execPods.values.foreach(executorPodChecker(_)) Eventually.eventually(TIMEOUT, INTERVAL) { expectedLogOnCompletion.foreach { e => assert(kubernetesTestComponents.kubernetesClient @@ -244,7 +260,6 @@ private[spark] class KubernetesSuite extends SparkFunSuite } } } - protected def doBasicDriverPodCheck(driverPod: Pod): Unit = { assert(driverPod.getMetadata.getName === driverPodName) assert(driverPod.getSpec.getContainers.get(0).getImage === image) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala index b602fdf39731f..5615d6173eebd 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala @@ -62,7 +62,6 @@ private[spark] class KubernetesTestComponents(defaultClient: DefaultKubernetesCl new SparkAppConf() .set("spark.master", s"k8s://${kubernetesClient.getMasterUrl}") .set("spark.kubernetes.namespace", namespace) - .set("spark.executor.memory", "500m") .set("spark.executor.cores", "1") .set("spark.executors.instances", "1") .set("spark.app.name", "spark-test-app") diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SecretsTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SecretsTestsSuite.scala index 9b039bb98dd9a..b18a6aebda497 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SecretsTestsSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SecretsTestsSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.k8s.integrationtest import scala.collection.JavaConverters._ -import io.fabric8.kubernetes.api.model.{Pod, Secret, SecretBuilder} +import io.fabric8.kubernetes.api.model.{Pod, SecretBuilder} import org.apache.commons.codec.binary.Base64 import org.apache.commons.io.output.ByteArrayOutputStream import org.scalatest.concurrent.Eventually @@ -53,7 +53,6 @@ private[spark] trait SecretsTestsSuite { k8sSuite: KubernetesSuite => .delete() } - // TODO: [SPARK-25291] This test is flaky with regards to memory of executors test("Run SparkPi with env and mount secrets.", k8sTestTag) { createTestSecret() sparkAppConf From a6f37b0742d87d5c8ee3e134999d665e5719e822 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 18 Sep 2018 16:33:37 -0500 Subject: [PATCH 1640/2461] [SPARK-25456][SQL][TEST] Fix PythonForeachWriterSuite PythonForeachWriterSuite was failing because RowQueue now needs to have a handle on a SparkEnv with a SerializerManager, so added a mock env with a serializer manager. Also fixed a typo in the `finally` that was hiding the real exception. Tested PythonForeachWriterSuite locally, full tests via jenkins. Closes #22452 from squito/SPARK-25456. Authored-by: Imran Rashid Signed-off-by: Imran Rashid --- .../python/PythonForeachWriterSuite.scala | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala index 07e6034770127..d02014c0dee54 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala @@ -19,17 +19,20 @@ package org.apache.spark.sql.execution.python import scala.collection.mutable.ArrayBuffer +import org.mockito.Mockito.when import org.scalatest.concurrent.Eventually +import org.scalatest.mockito.MockitoSugar import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} +import org.apache.spark.serializer.{JavaSerializer, SerializerManager} import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.execution.python.PythonForeachWriter.UnsafeRowBuffer import org.apache.spark.sql.types.{DataType, IntegerType} import org.apache.spark.util.Utils -class PythonForeachWriterSuite extends SparkFunSuite with Eventually { +class PythonForeachWriterSuite extends SparkFunSuite with Eventually with MockitoSugar { testWithBuffer("UnsafeRowBuffer: iterator blocks when no data is available") { b => b.assertIteratorBlocked() @@ -75,7 +78,7 @@ class PythonForeachWriterSuite extends SparkFunSuite with Eventually { tester = new BufferTester(memBytes, sleepPerRowReadMs) f(tester) } finally { - if (tester == null) tester.close() + if (tester != null) tester.close() } } } @@ -83,7 +86,12 @@ class PythonForeachWriterSuite extends SparkFunSuite with Eventually { class BufferTester(memBytes: Long, sleepPerRowReadMs: Int) { private val buffer = { - val mem = new TestMemoryManager(new SparkConf()) + val mockEnv = mock[SparkEnv] + val conf = new SparkConf() + val serializerManager = new SerializerManager(new JavaSerializer(conf), conf, None) + when(mockEnv.serializerManager).thenReturn(serializerManager) + SparkEnv.set(mockEnv) + val mem = new TestMemoryManager(conf) mem.limit(memBytes) val taskM = new TaskMemoryManager(mem, 0) new UnsafeRowBuffer(taskM, Utils.createTempDir(), 1) From 497f00f62b3ddd1f40507fdfe10f30cd9effb6cf Mon Sep 17 00:00:00 2001 From: Santiago Saavedra Date: Tue, 18 Sep 2018 22:08:50 -0700 Subject: [PATCH 1641/2461] [SPARK-23200] Reset Kubernetes-specific config on Checkpoint restore Several configuration parameters related to Kubernetes need to be reset, as they are changed with each invokation of spark-submit and thus prevents recovery of Spark Streaming tasks. ## What changes were proposed in this pull request? When using the Kubernetes cluster-manager and spawning a Streaming workload, it is important to reset many spark.kubernetes.* properties that are generated by spark-submit but which would get rewritten when restoring a Checkpoint. This is so, because the spark-submit codepath creates Kubernetes resources, such as a ConfigMap, a Secret and other variables, which have an autogenerated name and the previous one will not resolve anymore. In short, this change enables checkpoint restoration for streaming workloads, and thus enables Spark Streaming workloads in Kubernetes, which were not possible to restore from a checkpoint before if the workload went down. ## How was this patch tested? This patch needs would benefit from testing in different k8s clusters. This is similar to the YARN related code for resetting a Spark Streaming workload, but for the Kubernetes scheduler. This PR removes the initcontainers properties that existed before because they are now removed in master. For a previous discussion, see the non-rebased work at: apache-spark-on-k8s#516 Closes #22392 from ssaavedra/fix-checkpointing-master. Authored-by: Santiago Saavedra Signed-off-by: Yinan Li --- .../main/scala/org/apache/spark/streaming/Checkpoint.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 3703a87cdb9ab..a882558551e37 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -54,6 +54,8 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.driver.bindAddress", "spark.driver.port", "spark.master", + "spark.kubernetes.driver.pod.name", + "spark.kubernetes.executor.podNamePrefix", "spark.yarn.jars", "spark.yarn.keytab", "spark.yarn.principal", @@ -64,6 +66,8 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) .remove("spark.driver.host") .remove("spark.driver.bindAddress") .remove("spark.driver.port") + .remove("spark.kubernetes.driver.pod.name") + .remove("spark.kubernetes.executor.podNamePrefix") val newReloadConf = new SparkConf(loadDefaults = true) propertiesToReload.foreach { prop => newReloadConf.getOption(prop).foreach { value => From 6c7db7fd1ced1d143b1389d09990a620fc16be46 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 18 Sep 2018 22:39:29 -0700 Subject: [PATCH 1642/2461] [SPARK-23173][SQL] rename spark.sql.fromJsonForceNullableSchema ## What changes were proposed in this pull request? `spark.sql.fromJsonForceNullableSchema` -> `spark.sql.function.fromJson.forceNullable` ## How was this patch tested? Made sure there are no more references to `spark.sql.fromJsonForceNullableSchema`. Closes #22459 from rxin/SPARK-23173. Authored-by: Reynold Xin Signed-off-by: gatorsmile --- .../catalyst/expressions/jsonExpressions.scala | 4 ++-- .../org/apache/spark/sql/internal/SQLConf.scala | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index bd9090a07471b..ade10ab044ae2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -517,12 +517,12 @@ case class JsonToStructs( timeZoneId: Option[String] = None) extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { - val forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA) + val forceNullableSchema: Boolean = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA) // The JSON input data might be missing certain fields. We force the nullability // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder // can generate incorrect files if values are missing in columns declared as non-nullable. - val nullableSchema = if (forceNullableSchema) schema.asNullable else schema + val nullableSchema: DataType = if (forceNullableSchema) schema.asNullable else schema override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4928560eacb1c..bdc4007ba2866 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -624,14 +624,6 @@ object SQLConf { .stringConf .createWithDefault("_corrupt_record") - val FROM_JSON_FORCE_NULLABLE_SCHEMA = buildConf("spark.sql.fromJsonForceNullableSchema") - .internal() - .doc("When true, force the output schema of the from_json() function to be nullable " + - "(including all the fields). Otherwise, the schema might not be compatible with" + - "actual data, which leads to curruptions.") - .booleanConf - .createWithDefault(true) - val BROADCAST_TIMEOUT = buildConf("spark.sql.broadcastTimeout") .doc("Timeout in seconds for the broadcast wait time in broadcast joins.") .timeConf(TimeUnit.SECONDS) @@ -1354,6 +1346,14 @@ object SQLConf { "When this conf is not set, the value from `spark.redaction.string.regex` is used.") .fallbackConf(org.apache.spark.internal.config.STRING_REDACTION_PATTERN) + val FROM_JSON_FORCE_NULLABLE_SCHEMA = buildConf("spark.sql.function.fromJson.forceNullable") + .internal() + .doc("When true, force the output schema of the from_json() function to be nullable " + + "(including all the fields). Otherwise, the schema might not be compatible with" + + "actual data, which leads to corruptions.") + .booleanConf + .createWithDefault(true) + val CONCAT_BINARY_AS_STRING = buildConf("spark.sql.function.concatBinaryAsString") .doc("When this option is set to false and all inputs are binary, `functions.concat` returns " + "an output as binary. Otherwise, it returns as a string. ") From 4193c7623b92765adaee539e723328ddc9048c09 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 18 Sep 2018 22:41:27 -0700 Subject: [PATCH 1643/2461] [SPARK-24626] Add statistics prefix to parallelFileListingInStatsComputation ## What changes were proposed in this pull request? To be more consistent with other statistics based configs. ## How was this patch tested? N/A - straightforward rename of config option. Used `git grep` to make sure there are no mention of it. Closes #22457 from rxin/SPARK-24626. Authored-by: Reynold Xin Signed-off-by: gatorsmile --- docs/sql-programming-guide.md | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 63 +++++++++---------- 2 files changed, 32 insertions(+), 33 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index f25415c0bc748..2fa29a00e74c7 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1896,7 +1896,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. - In version 2.3 and earlier, CSV rows are considered as malformed if at least one column value in the row is malformed. CSV parser dropped such rows in the DROPMALFORMED mode or outputs an error in the FAILFAST mode. Since Spark 2.4, CSV row is considered as malformed only when it contains malformed column values requested from CSV datasource, other values can be ignored. As an example, CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selection of the id column consists of a row with one column value 1234 but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. - - Since Spark 2.4, File listing for compute statistics is done in parallel by default. This can be disabled by setting `spark.sql.parallelFileListingInStatsComputation.enabled` to `False`. + - Since Spark 2.4, File listing for compute statistics is done in parallel by default. This can be disabled by setting `spark.sql.statistics.parallelFileListingInStatsComputation.enabled` to `False`. - Since Spark 2.4, Metadata files (e.g. Parquet summary files) and temporary files are not counted as data files when calculating table size during Statistics computation. - Since Spark 2.4, empty strings are saved as quoted empty strings `""`. In version 2.3 and earlier, empty strings are equal to `null` values and do not reflect to any characters in saved CSV files. For example, the row of `"a", null, "", 1` was writted as `a,,,1`. Since Spark 2.4, the same row is saved as `a,,"",1`. To restore the previous behavior, set the CSV option `emptyValue` to empty (not quoted) string. - Since Spark 2.4, The LOAD DATA command supports wildcard `?` and `*`, which match any one character, and zero or more characters, respectively. Example: `LOAD DATA INPATH '/tmp/folder*/'` or `LOAD DATA INPATH '/tmp/part-?'`. Special Characters like `space` also now work in paths. Example: `LOAD DATA INPATH '/tmp/folder name/'`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index bdc4007ba2866..4499a35310a6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -269,22 +269,6 @@ object SQLConf { .booleanConf .createWithDefault(true) - val ENABLE_FALL_BACK_TO_HDFS_FOR_STATS = - buildConf("spark.sql.statistics.fallBackToHdfs") - .doc("If the table statistics are not available from table metadata enable fall back to hdfs." + - " This is useful in determining if a table is small enough to use auto broadcast joins.") - .booleanConf - .createWithDefault(false) - - val DEFAULT_SIZE_IN_BYTES = buildConf("spark.sql.defaultSizeInBytes") - .internal() - .doc("The default table size used in query planning. By default, it is set to Long.MaxValue " + - "which is larger than `spark.sql.autoBroadcastJoinThreshold` to be more conservative. " + - "That is to say by default the optimizer will not choose to broadcast a table unless it " + - "knows for sure its size is small enough.") - .longConf - .createWithDefault(Long.MaxValue) - val SHUFFLE_PARTITIONS = buildConf("spark.sql.shuffle.partitions") .doc("The default number of partitions to use when shuffling data for joins or aggregations.") .intConf @@ -1110,6 +1094,30 @@ object SQLConf { .internal() .stringConf + val PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION = + buildConf("spark.sql.statistics.parallelFileListingInStatsComputation.enabled") + .internal() + .doc("When true, SQL commands use parallel file listing, " + + "as opposed to single thread listing." + + "This usually speeds up commands that need to list many directories.") + .booleanConf + .createWithDefault(true) + + val ENABLE_FALL_BACK_TO_HDFS_FOR_STATS = buildConf("spark.sql.statistics.fallBackToHdfs") + .doc("If the table statistics are not available from table metadata enable fall back to hdfs." + + " This is useful in determining if a table is small enough to use auto broadcast joins.") + .booleanConf + .createWithDefault(false) + + val DEFAULT_SIZE_IN_BYTES = buildConf("spark.sql.defaultSizeInBytes") + .internal() + .doc("The default table size used in query planning. By default, it is set to Long.MaxValue " + + "which is larger than `spark.sql.autoBroadcastJoinThreshold` to be more conservative. " + + "That is to say by default the optimizer will not choose to broadcast a table unless it " + + "knows for sure its size is small enough.") + .longConf + .createWithDefault(Long.MaxValue) + val NDV_MAX_ERROR = buildConf("spark.sql.statistics.ndv.maxError") .internal() @@ -1553,15 +1561,6 @@ object SQLConf { "are performed before any UNION, EXCEPT and MINUS operations.") .booleanConf .createWithDefault(false) - - val PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION = - buildConf("spark.sql.parallelFileListingInStatsComputation.enabled") - .internal() - .doc("When true, SQL commands use parallel file listing, " + - "as opposed to single thread listing." + - "This usually speeds up commands that need to list many directories.") - .booleanConf - .createWithDefault(true) } /** @@ -1770,14 +1769,10 @@ class SQLConf extends Serializable with Logging { def advancedPartitionPredicatePushdownEnabled: Boolean = getConf(ADVANCED_PARTITION_PREDICATE_PUSHDOWN) - def fallBackToHdfsForStatsEnabled: Boolean = getConf(ENABLE_FALL_BACK_TO_HDFS_FOR_STATS) - def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN) def enableRadixSort: Boolean = getConf(RADIX_SORT_ENABLED) - def defaultSizeInBytes: Long = getConf(DEFAULT_SIZE_IN_BYTES) - def isParquetSchemaMergingEnabled: Boolean = getConf(PARQUET_SCHEMA_MERGING_ENABLED) def isParquetSchemaRespectSummaries: Boolean = getConf(PARQUET_SCHEMA_RESPECT_SUMMARIES) @@ -1869,6 +1864,13 @@ class SQLConf extends Serializable with Logging { def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE) + def parallelFileListingInStatsComputation: Boolean = + getConf(SQLConf.PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION) + + def fallBackToHdfsForStatsEnabled: Boolean = getConf(ENABLE_FALL_BACK_TO_HDFS_FOR_STATS) + + def defaultSizeInBytes: Long = getConf(DEFAULT_SIZE_IN_BYTES) + def ndvMaxError: Double = getConf(NDV_MAX_ERROR) def histogramEnabled: Boolean = getConf(HISTOGRAM_ENABLED) @@ -1971,9 +1973,6 @@ class SQLConf extends Serializable with Logging { def setOpsPrecedenceEnforced: Boolean = getConf(SQLConf.LEGACY_SETOPS_PRECEDENCE_ENABLED) - def parallelFileListingInStatsComputation: Boolean = - getConf(SQLConf.PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION) - /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ From 5534a3a58e4025624fbad527dd129acb8025f25a Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 19 Sep 2018 18:30:46 +0800 Subject: [PATCH 1644/2461] [SPARK-25445][BUILD][FOLLOWUP] Resolve issues in release-build.sh for publishing scala-2.12 build ## What changes were proposed in this pull request? This is a follow up for #22441. 1. Remove flag "-Pkafka-0-8" for Scala 2.12 build. 2. Clean up the script, simpler logic. 3. Switch to Scala version to 2.11 before script exit. ## How was this patch tested? Manual test. Closes #22454 from gengliangwang/revise_release_build. Authored-by: Gengliang Wang Signed-off-by: Wenchen Fan --- dev/create-release/release-build.sh | 38 ++++++++++++----------------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 4c90a772104fc..cce5f8b6975ca 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -111,21 +111,21 @@ fi # different versions of Scala are supported. BASE_PROFILES="-Pmesos -Pyarn" PUBLISH_SCALA_2_10=0 -PUBLISH_SCALA_2_12=0 SCALA_2_10_PROFILES="-Pscala-2.10" SCALA_2_11_PROFILES= -SCALA_2_12_PROFILES="-Pscala-2.12 -Pkafka-0-8" - if [[ $SPARK_VERSION > "2.3" ]]; then BASE_PROFILES="$BASE_PROFILES -Pkubernetes -Pflume" SCALA_2_11_PROFILES="-Pkafka-0-8" - if [[ $SPARK_VERSION > "2.4" ]]; then - PUBLISH_SCALA_2_12=1 - fi else PUBLISH_SCALA_2_10=1 fi +PUBLISH_SCALA_2_12=0 +SCALA_2_12_PROFILES="-Pscala-2.12" +if [[ $SPARK_VERSION > "2.4" ]]; then + PUBLISH_SCALA_2_12=1 +fi + # Hive-specific profiles for some builds HIVE_PROFILES="-Phive -Phive-thriftserver" # Profiles for publishing snapshots and release to Maven Central @@ -190,17 +190,9 @@ if [[ "$1" == "package" ]]; then # Updated for each binary build make_binary_release() { NAME=$1 - SCALA_VERSION=$2 - SCALA_PROFILES= - if [[ SCALA_VERSION == "2.10" ]]; then - SCALA_PROFILES="$SCALA_2_10_PROFILES" - elif [[ SCALA_VERSION == "2.12" ]]; then - SCALA_PROFILES="$SCALA_2_12_PROFILES" - else - SCALA_PROFILES="$SCALA_2_11_PROFILES" - fi - FLAGS="$MVN_EXTRA_OPTS -B $SCALA_PROFILES $BASE_RELEASE_PROFILES $3" - BUILD_PACKAGE=$4 + FLAGS="$MVN_EXTRA_OPTS -B $BASE_RELEASE_PROFILES $2" + BUILD_PACKAGE=$3 + SCALA_VERSION=$4 # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds # share the same Zinc server. @@ -210,10 +202,8 @@ if [[ "$1" == "package" ]]; then cp -r spark spark-$SPARK_VERSION-bin-$NAME cd spark-$SPARK_VERSION-bin-$NAME - if [[ SCALA_VERSION == "2.10" ]]; then - ./dev/change-scala-version.sh 2.10 - elif [[ SCALA_VERSION == "2.12" ]]; then - ./dev/change-scala-version.sh 2.12 + if [[ "$SCALA_VERSION" != "2.11" ]]; then + ./dev/change-scala-version.sh $SCALA_VERSION fi export ZINC_PORT=$ZINC_PORT @@ -305,7 +295,7 @@ if [[ "$1" == "package" ]]; then for key in ${!BINARY_PKGS_ARGS[@]}; do args=${BINARY_PKGS_ARGS[$key]} extra=${BINARY_PKGS_EXTRA[$key]} - if ! make_binary_release "$key" "2.11" "$args" "$extra"; then + if ! make_binary_release "$key" "$SCALA_2_11_PROFILES $args" "$extra" "2.11"; then error "Failed to build $key package. Check logs for details." fi done @@ -314,7 +304,7 @@ if [[ "$1" == "package" ]]; then key="without-hadoop-scala-2.12" args="-Phadoop-provided" extra="" - if ! make_binary_release "$key" "2.12" "$args" "$extra"; then + if ! make_binary_release "$key" "$SCALA_2_12_PROFILES $args" "$extra" "2.12"; then error "Failed to build $key package. Check logs for details." fi fi @@ -446,6 +436,8 @@ if [[ "$1" == "publish-release" ]]; then # Clean-up Zinc nailgun process $LSOF -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill + ./dev/change-scala-version.sh 2.11 + pushd $tmp_repo/org/apache/spark # Remove any extra files generated during install From 12b1e91e6b5135f6ed3e59a49abfc2e5a855263a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 19 Sep 2018 19:54:49 +0800 Subject: [PATCH 1645/2461] [SPARK-25358][SQL] MutableProjection supports fallback to an interpreted mode ## What changes were proposed in this pull request? In SPARK-23711, `UnsafeProjection` supports fallback to an interpreted mode. Therefore, this pr fixed code to support the same fallback mode in `MutableProjection` based on `CodeGeneratorWithInterpretedFallback`. ## How was this patch tested? Added tests in `CodeGeneratorWithInterpretedFallbackSuite`. Closes #22355 from maropu/SPARK-25358. Authored-by: Takeshi Yamamuro Signed-off-by: Wenchen Fan --- .../InterpretedMutableProjection.scala | 89 +++++++++++++++++++ .../sql/catalyst/expressions/Projection.scala | 75 ++++++++-------- .../codegen/GenerateMutableProjection.scala | 4 + .../sql/catalyst/expressions/package.scala | 18 +--- ...eneratorWithInterpretedFallbackSuite.scala | 38 +++++++- .../CollectionExpressionsSuite.scala | 8 +- .../expressions/ExpressionEvalHelper.scala | 34 ++++--- .../expressions/MiscExpressionsSuite.scala | 10 +-- .../expressions/ObjectExpressionsSuite.scala | 8 +- .../spark/sql/execution/SparkPlan.scala | 2 +- .../spark/sql/execution/aggregate/udaf.scala | 2 +- 11 files changed, 201 insertions(+), 87 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala new file mode 100644 index 0000000000000..0654108cea281 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp + + +/** + * A [[MutableProjection]] that is calculated by calling `eval` on each of the specified + * expressions. + * + * @param expressions a sequence of expressions that determine the value of each column of the + * output row. + */ +class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection { + def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = + this(toBoundExprs(expressions, inputSchema)) + + private[this] val buffer = new Array[Any](expressions.size) + + override def initialize(partitionIndex: Int): Unit = { + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initialize(partitionIndex) + case _ => + }) + } + + private[this] val validExprs = expressions.zipWithIndex.filter { + case (NoOp, _) => false + case _ => true + } + private[this] var mutableRow: InternalRow = new GenericInternalRow(expressions.size) + def currentValue: InternalRow = mutableRow + + override def target(row: InternalRow): MutableProjection = { + mutableRow = row + this + } + + override def apply(input: InternalRow): InternalRow = { + var i = 0 + while (i < validExprs.length) { + val (expr, ordinal) = validExprs(i) + // Store the result into buffer first, to make the projection atomic (needed by aggregation) + buffer(ordinal) = expr.eval(input) + i += 1 + } + i = 0 + while (i < validExprs.length) { + val (_, ordinal) = validExprs(i) + mutableRow(ordinal) = buffer(ordinal) + i += 1 + } + mutableRow + } +} + +/** + * Helper functions for creating an [[InterpretedMutableProjection]]. + */ +object InterpretedMutableProjection { + + /** + * Returns a [[MutableProjection]] for given sequence of bound Expressions. + */ + def createProjection(exprs: Seq[Expression]): MutableProjection = { + // We need to make sure that we do not reuse stateful expressions. + val cleanedExpressions = exprs.map(_.transform { + case s: Stateful => s.freshCopy() + }) + new InterpretedMutableProjection(cleanedExpressions) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 5f24170398715..792646cf9f10c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} @@ -56,47 +56,50 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { } /** - * A [[MutableProjection]] that is calculated by calling `eval` on each of the specified - * expressions. + * Converts a [[InternalRow]] to another Row given a sequence of expression that define each + * column of the new row. If the schema of the input row is specified, then the given expression + * will be bound to that schema. * - * @param expressions a sequence of expressions that determine the value of each column of the - * output row. + * In contrast to a normal projection, a MutableProjection reuses the same underlying row object + * each time an input row is added. This significantly reduces the cost of calculating the + * projection, but means that it is not safe to hold on to a reference to a [[InternalRow]] after + * `next()` has been called on the [[Iterator]] that produced it. Instead, the user must call + * `InternalRow.copy()` and hold on to the returned [[InternalRow]] before calling `next()`. */ -case class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection { - def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = - this(expressions.map(BindReferences.bindReference(_, inputSchema))) +abstract class MutableProjection extends Projection { + def currentValue: InternalRow - private[this] val buffer = new Array[Any](expressions.size) + /** Uses the given row to store the output of the projection. */ + def target(row: InternalRow): MutableProjection +} - override def initialize(partitionIndex: Int): Unit = { - expressions.foreach(_.foreach { - case n: Nondeterministic => n.initialize(partitionIndex) - case _ => - }) +/** + * The factory object for `MutableProjection`. + */ +object MutableProjection + extends CodeGeneratorWithInterpretedFallback[Seq[Expression], MutableProjection] { + + override protected def createCodeGeneratedObject(in: Seq[Expression]): MutableProjection = { + GenerateMutableProjection.generate(in, SQLConf.get.subexpressionEliminationEnabled) } - private[this] val exprArray = expressions.toArray - private[this] var mutableRow: InternalRow = new GenericInternalRow(exprArray.length) - def currentValue: InternalRow = mutableRow + override protected def createInterpretedObject(in: Seq[Expression]): MutableProjection = { + InterpretedMutableProjection.createProjection(in) + } - override def target(row: InternalRow): MutableProjection = { - mutableRow = row - this + /** + * Returns an MutableProjection for given sequence of bound Expressions. + */ + def create(exprs: Seq[Expression]): MutableProjection = { + createObject(exprs) } - override def apply(input: InternalRow): InternalRow = { - var i = 0 - while (i < exprArray.length) { - // Store the result into buffer first, to make the projection atomic (needed by aggregation) - buffer(i) = exprArray(i).eval(input) - i += 1 - } - i = 0 - while (i < exprArray.length) { - mutableRow(i) = buffer(i) - i += 1 - } - mutableRow + /** + * Returns an MutableProjection for given sequence of Expressions, which will be bound to + * `inputSchema`. + */ + def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): MutableProjection = { + create(toBoundExprs(exprs, inputSchema)) } } @@ -123,12 +126,6 @@ object UnsafeProjection InterpretedUnsafeProjection.createProjection(in) } - protected def toBoundExprs( - exprs: Seq[Expression], - inputSchema: Seq[Attribute]): Seq[Expression] = { - exprs.map(BindReferences.bindReference(_, inputSchema)) - } - protected def toUnsafeExprs(exprs: Seq[Expression]): Seq[Expression] = { exprs.map(_ transform { case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 33d14329ec95c..d588e7f081303 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -44,6 +44,10 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP create(canonicalize(bind(expressions, inputSchema)), useSubexprElimination) } + def generate(expressions: Seq[Expression], useSubexprElimination: Boolean): MutableProjection = { + create(canonicalize(expressions), useSubexprElimination) + } + protected def create(expressions: Seq[Expression]): MutableProjection = { create(expressions, false) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 11dcc3ebf798c..0083ee64653e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -86,24 +86,12 @@ package object expressions { } /** - * Converts a [[InternalRow]] to another Row given a sequence of expression that define each - * column of the new row. If the schema of the input row is specified, then the given expression - * will be bound to that schema. - * - * In contrast to a normal projection, a MutableProjection reuses the same underlying row object - * each time an input row is added. This significantly reduces the cost of calculating the - * projection, but means that it is not safe to hold on to a reference to a [[InternalRow]] after - * `next()` has been called on the [[Iterator]] that produced it. Instead, the user must call - * `InternalRow.copy()` and hold on to the returned [[InternalRow]] before calling `next()`. + * A helper function to bind given expressions to an input schema. */ - abstract class MutableProjection extends Projection { - def currentValue: InternalRow - - /** Uses the given row to store the output of the projection. */ - def target(row: InternalRow): MutableProjection + def toBoundExprs(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = { + exprs.map(BindReferences.bindReference(_, inputSchema)) } - /** * Helper functions for working with `Seq[Attribute]`. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala index 28edd85ab6e87..6ea3b05ff9c1e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala @@ -20,13 +20,18 @@ package org.apache.spark.sql.catalyst.expressions import java.util.concurrent.ExecutionException import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{IntegerType, StructType} class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanTestBase { + val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString + val noCodegen = CodegenObjectFactoryMode.NO_CODEGEN.toString + object FailedCodegenProjection extends CodeGeneratorWithInterpretedFallback[Seq[Expression], UnsafeProjection] { @@ -44,19 +49,30 @@ class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanT test("UnsafeProjection with codegen factory mode") { val input = Seq(BoundReference(0, IntegerType, nullable = true)) - val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) { val obj = UnsafeProjection.createObject(input) assert(obj.getClass.getName.contains("GeneratedClass$SpecificUnsafeProjection")) } - val noCodegen = CodegenObjectFactoryMode.NO_CODEGEN.toString withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> noCodegen) { val obj = UnsafeProjection.createObject(input) assert(obj.isInstanceOf[InterpretedUnsafeProjection]) } } + test("MutableProjection with codegen factory mode") { + val input = Seq(BoundReference(0, IntegerType, nullable = true)) + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) { + val obj = MutableProjection.createObject(input) + assert(obj.getClass.getName.contains("GeneratedClass$SpecificMutableProjection")) + } + + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> noCodegen) { + val obj = MutableProjection.createObject(input) + assert(obj.isInstanceOf[InterpretedMutableProjection]) + } + } + test("fallback to the interpreter mode") { val input = Seq(BoundReference(0, IntegerType, nullable = true)) val fallback = CodegenObjectFactoryMode.FALLBACK.toString @@ -69,11 +85,25 @@ class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanT test("codegen failures in the CODEGEN_ONLY mode") { val errMsg = intercept[ExecutionException] { val input = Seq(BoundReference(0, IntegerType, nullable = true)) - val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) { FailedCodegenProjection.createObject(input) } }.getMessage assert(errMsg.contains("failed to compile: org.codehaus.commons.compiler.CompileException:")) } + + test("SPARK-25358 Correctly handles NoOp in MutableProjection") { + val exprs = Seq(Add(BoundReference(0, IntegerType, nullable = true), Literal.create(1)), NoOp) + val input = InternalRow.fromSeq(1 :: 1 :: Nil) + val expected = 2 :: null :: Nil + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) { + val proj = MutableProjection.createObject(exprs) + assert(proj(input).toSeq(StructType.fromDDL("c0 int, c1 int")) === expected) + } + + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> noCodegen) { + val proj = MutableProjection.createObject(exprs) + assert(proj(input).toSeq(StructType.fromDDL("c0 int, c1 int")) === expected) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index c7db4ec9e16b1..2e0adbb465008 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1510,16 +1510,16 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val seed1 = Some(r.nextLong()) assert(evaluateWithoutCodegen(Shuffle(ai0, seed1)) === evaluateWithoutCodegen(Shuffle(ai0, seed1))) - assert(evaluateWithGeneratedMutableProjection(Shuffle(ai0, seed1)) === - evaluateWithGeneratedMutableProjection(Shuffle(ai0, seed1))) + assert(evaluateWithMutableProjection(Shuffle(ai0, seed1)) === + evaluateWithMutableProjection(Shuffle(ai0, seed1))) assert(evaluateWithUnsafeProjection(Shuffle(ai0, seed1)) === evaluateWithUnsafeProjection(Shuffle(ai0, seed1))) val seed2 = Some(r.nextLong()) assert(evaluateWithoutCodegen(Shuffle(ai0, seed1)) !== evaluateWithoutCodegen(Shuffle(ai0, seed2))) - assert(evaluateWithGeneratedMutableProjection(Shuffle(ai0, seed1)) !== - evaluateWithGeneratedMutableProjection(Shuffle(ai0, seed2))) + assert(evaluateWithMutableProjection(Shuffle(ai0, seed1)) !== + evaluateWithMutableProjection(Shuffle(ai0, seed2))) assert(evaluateWithUnsafeProjection(Shuffle(ai0, seed1)) !== evaluateWithUnsafeProjection(Shuffle(ai0, seed2))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 6684e5ce18d4c..b5986aac65552 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -60,7 +60,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa def expr = prepareEvaluation(expression) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) - checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) + checkEvaluationWithMutableProjection(expr, catalystValue, inputRow) if (GenerateUnsafeProjection.canSupport(expr.dataType)) { checkEvaluationWithUnsafeProjection(expr, catalystValue, inputRow) } @@ -136,7 +136,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa // Make it as method to obtain fresh expression everytime. def expr = prepareEvaluation(expression) checkException(evaluateWithoutCodegen(expr, inputRow), "non-codegen mode") - checkException(evaluateWithGeneratedMutableProjection(expr, inputRow), "codegen mode") + checkException(evaluateWithMutableProjection(expr, inputRow), "codegen mode") if (GenerateUnsafeProjection.canSupport(expr.dataType)) { checkException(evaluateWithUnsafeProjection(expr, inputRow), "unsafe mode") } @@ -183,22 +183,28 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa } } - protected def checkEvaluationWithGeneratedMutableProjection( - expression: Expression, + protected def checkEvaluationWithMutableProjection( + expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val actual = evaluateWithGeneratedMutableProjection(expression, inputRow) - if (!checkResult(actual, expected, expression.dataType)) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input") + val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN) + for (fallbackMode <- modes) { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) { + val actual = evaluateWithMutableProjection(expression, inputRow) + if (!checkResult(actual, expected, expression.dataType)) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect evaluation (fallback mode = $fallbackMode): $expression, " + + s"actual: $actual, expected: $expected$input") + } + } } } - protected def evaluateWithGeneratedMutableProjection( - expression: Expression, + protected def evaluateWithMutableProjection( + expression: => Expression, inputRow: InternalRow = EmptyRow): Any = { val plan = generateProject( - GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), + MutableProjection.create(Alias(expression, s"Optimized($expression)")() :: Nil), expression) plan.initialize(0) @@ -218,7 +224,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa if (expected == null) { if (!unsafeRow.isNullAt(0)) { val expectedRow = InternalRow(expected, expected) - fail("Incorrect evaluation in unsafe mode: " + + fail(s"Incorrect evaluation in unsafe mode (fallback mode = $fallbackMode): " + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") } } else { @@ -226,7 +232,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa val expectedRow = UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit) if (unsafeRow != expectedRow) { - fail("Incorrect evaluation in unsafe mode: " + + fail(s"Incorrect evaluation in unsafe mode (fallback mode = $fallbackMode): " + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") } } @@ -266,7 +272,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa expected: Spread[Double], inputRow: InternalRow = EmptyRow): Unit = { checkEvaluationWithoutCodegen(expression, expected) - checkEvaluationWithGeneratedMutableProjection(expression, expected) + checkEvaluationWithMutableProjection(expression, expected) checkEvaluationWithOptimization(expression, expected) var plan = generateProject( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index b6c269348b002..4b2d153a28cc8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -48,15 +48,15 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val r = new Random() val seed1 = Some(r.nextLong()) assert(evaluateWithoutCodegen(Uuid(seed1)) === evaluateWithoutCodegen(Uuid(seed1))) - assert(evaluateWithGeneratedMutableProjection(Uuid(seed1)) === - evaluateWithGeneratedMutableProjection(Uuid(seed1))) + assert(evaluateWithMutableProjection(Uuid(seed1)) === + evaluateWithMutableProjection(Uuid(seed1))) assert(evaluateWithUnsafeProjection(Uuid(seed1)) === evaluateWithUnsafeProjection(Uuid(seed1))) val seed2 = Some(r.nextLong()) assert(evaluateWithoutCodegen(Uuid(seed1)) !== evaluateWithoutCodegen(Uuid(seed2))) - assert(evaluateWithGeneratedMutableProjection(Uuid(seed1)) !== - evaluateWithGeneratedMutableProjection(Uuid(seed2))) + assert(evaluateWithMutableProjection(Uuid(seed1)) !== + evaluateWithMutableProjection(Uuid(seed2))) assert(evaluateWithUnsafeProjection(Uuid(seed1)) !== evaluateWithUnsafeProjection(Uuid(seed2))) @@ -79,7 +79,7 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val outputEval = errorStream.toString errorStream.reset() // check with codegen - checkEvaluationWithGeneratedMutableProjection(PrintToStderr(inputExpr), 1) + checkEvaluationWithMutableProjection(PrintToStderr(inputExpr), 1) val outputCodegen = errorStream.toString (outputEval, outputCodegen) } finally { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index b0af9e07d1d1d..d145fd0aaba47 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -72,7 +72,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val cls = classOf[Tuple2[Boolean, java.lang.Integer]] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) val invoke = Invoke(inputObject, "_2", IntegerType) - checkEvaluationWithGeneratedMutableProjection(invoke, null, inputRow) + checkEvaluationWithMutableProjection(invoke, null, inputRow) } test("MapObjects should make copies of unsafe-backed data") { @@ -233,13 +233,13 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal.fromObject(new TestBean), Map("setNonPrimitive" -> Literal(null))) evaluateWithoutCodegen(initializeBean, InternalRow.fromSeq(Seq())) - evaluateWithGeneratedMutableProjection(initializeBean, InternalRow.fromSeq(Seq())) + evaluateWithMutableProjection(initializeBean, InternalRow.fromSeq(Seq())) val initializeBean2 = InitializeJavaBean( Literal.fromObject(new TestBean), Map("setNonPrimitive" -> Literal("string"))) evaluateWithoutCodegen(initializeBean2, InternalRow.fromSeq(Seq())) - evaluateWithGeneratedMutableProjection(initializeBean2, InternalRow.fromSeq(Seq())) + evaluateWithMutableProjection(initializeBean2, InternalRow.fromSeq(Seq())) } test("SPARK-23585: UnwrapOption should support interpreted execution") { @@ -273,7 +273,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val resolver = ResolveTimeZone(new SQLConf) val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) checkEvaluationWithoutCodegen(expr, expected, inputRow) - checkEvaluationWithGeneratedMutableProjection(expr, expected, inputRow) + checkEvaluationWithMutableProjection(expr, expected, inputRow) if (GenerateUnsafeProjection.canSupport(expr.dataType)) { checkEvaluationWithUnsafeProjection( expr, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 1f97993e20458..ab6031c436e9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -380,7 +380,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ inputSchema: Seq[Attribute], useSubexprElimination: Boolean = false): MutableProjection = { log.debug(s"Creating MutableProj: $expressions, inputSchema: $inputSchema") - GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination) + MutableProjection.create(expressions, inputSchema) } private def genInterpretedPredicate( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 72aa4adff4e64..100486fa9850f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -365,7 +365,7 @@ case class ScalaUDAF( val inputAttributes = childrenSchema.toAttributes log.debug( s"Creating MutableProj: $children, inputSchema: $inputAttributes.") - GenerateMutableProjection.generate(children, inputAttributes) + MutableProjection.create(children, inputAttributes) } private[this] lazy val inputToScalaConverters: Any => Any = From a71f6a1750fd0a29ecae6b98673ee15840da1c62 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 20 Sep 2018 00:29:48 +0800 Subject: [PATCH 1646/2461] [SPARK-25414][SS][TEST] make it clear that the numRows metrics should be counted for each scan of the source ## What changes were proposed in this pull request? For self-join/self-union, Spark will produce a physical plan which has multiple `DataSourceV2ScanExec` instances referring to the same `ReadSupport` instance. In this case, the streaming source is indeed scanned multiple times, and the `numInputRows` metrics should be counted for each scan. Actually we already have 2 test cases to verify the behavior: 1. `StreamingQuerySuite.input row calculation with same V2 source used twice in self-join` 2. `KafkaMicroBatchSourceSuiteBase.ensure stream-stream self-join generates only one offset in log and correct metrics`. However, in these 2 tests, the expected result is different, which is super confusing. It turns out that, the first test doesn't trigger exchange reuse, so the source is scanned twice. The second test triggers exchange reuse, and the source is scanned only once. This PR proposes to improve these 2 tests, to test with/without exchange reuse. ## How was this patch tested? test only change Closes #22402 from cloud-fan/bug. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../kafka010/KafkaMicroBatchSourceSuite.scala | 39 ++++++++++++---- .../streaming/ProgressReporter.scala | 6 +-- .../sql/streaming/StreamingQuerySuite.scala | 46 ++++++++++++++----- 3 files changed, 68 insertions(+), 23 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 8e246dbbf5d70..e5f008804ee5b 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -35,9 +35,11 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.sql.{ForeachWriter, SparkSession} import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.functions.{count, window} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} @@ -598,18 +600,37 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { val join = values.join(values, "key") - testStream(join)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2), - CheckAnswer((1, 1, 1), (2, 2, 2)), - AddKafkaData(Set(topic), 6, 3), - CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 6, 6)), - AssertOnQuery { q => + def checkQuery(check: AssertOnQuery): Unit = { + testStream(join)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2), + CheckAnswer((1, 1, 1), (2, 2, 2)), + AddKafkaData(Set(topic), 6, 3), + CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 6, 6)), + check + ) + } + + withSQLConf(SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { + checkQuery(AssertOnQuery { q => assert(q.availableOffsets.iterator.size == 1) + // The kafka source is scanned twice because of self-join + assert(q.recentProgress.map(_.numInputRows).sum == 8) + true + }) + } + + withSQLConf(SQLConf.EXCHANGE_REUSE_ENABLED.key -> "true") { + checkQuery(AssertOnQuery { q => + assert(q.availableOffsets.iterator.size == 1) + assert(q.lastExecution.executedPlan.collect { + case r: ReusedExchangeExec => r + }.length == 1) + // The kafka source is scanned only once because of exchange reuse. assert(q.recentProgress.map(_.numInputRows).sum == 4) true - } - ) + }) + } } test("read Kafka transactional messages: read_committed") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 73b180468d367..392229bcb5f55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -240,9 +240,6 @@ trait ProgressReporter extends Logging { /** Extract number of input sources for each streaming source in plan */ private def extractSourceToNumInputRows(): Map[BaseStreamingSource, Long] = { - import java.util.IdentityHashMap - import scala.collection.JavaConverters._ - def sumRows(tuples: Seq[(BaseStreamingSource, Long)]): Map[BaseStreamingSource, Long] = { tuples.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source } @@ -255,6 +252,9 @@ trait ProgressReporter extends Logging { } if (onlyDataSourceV2Sources) { + // It's possible that multiple DataSourceV2ScanExec instances may refer to the same source + // (can happen with self-unions or self-joins). This means the source is scanned multiple + // times in the query, we should count the numRows for each scan. val sourceToInputRowsTuples = lastExecution.executedPlan.collect { case s: DataSourceV2ScanExec if s.readSupport.isInstanceOf[BaseStreamingSource] => val numRows = s.metrics.get("numOutputRows").map(_.value).getOrElse(0L) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 1dd817545a969..c170641372d61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.catalyst.expressions.{Literal, Rand, Randn, Shuffle, Uuid} +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ @@ -500,29 +501,52 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery { q => val lastProgress = getLastProgressWithData(q) assert(lastProgress.nonEmpty) - assert(lastProgress.get.numInputRows == 6) assert(lastProgress.get.sources.length == 1) - assert(lastProgress.get.sources(0).numInputRows == 6) + // The source is scanned twice because of self-union + assert(lastProgress.get.numInputRows == 6) true } ) } test("input row calculation with same V2 source used twice in self-join") { - val streamInput = MemoryStream[Int] - val df = streamInput.toDF() - testStream(df.join(df, "value"), useV2Sink = true)( - AddData(streamInput, 1, 2, 3), - CheckAnswer(1, 2, 3), - AssertOnQuery { q => + def checkQuery(check: AssertOnQuery): Unit = { + val memoryStream = MemoryStream[Int] + // TODO: currently the streaming framework always add a dummy Project above streaming source + // relation, which breaks exchange reuse, as the optimizer will remove Project from one side. + // Here we manually add a useful Project, to trigger exchange reuse. + val streamDF = memoryStream.toDF().select('value + 0 as "v") + testStream(streamDF.join(streamDF, "v"), useV2Sink = true)( + AddData(memoryStream, 1, 2, 3), + CheckAnswer(1, 2, 3), + check + ) + } + + withSQLConf(SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { + checkQuery(AssertOnQuery { q => val lastProgress = getLastProgressWithData(q) assert(lastProgress.nonEmpty) + assert(lastProgress.get.sources.length == 1) + // The source is scanned twice because of self-join assert(lastProgress.get.numInputRows == 6) + true + }) + } + + withSQLConf(SQLConf.EXCHANGE_REUSE_ENABLED.key -> "true") { + checkQuery(AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) assert(lastProgress.get.sources.length == 1) - assert(lastProgress.get.sources(0).numInputRows == 6) + assert(q.lastExecution.executedPlan.collect { + case r: ReusedExchangeExec => r + }.length == 1) + // The source is scanned only once because of exchange reuse + assert(lastProgress.get.numInputRows == 3) true - } - ) + }) + } } test("input row calculation with trigger having data for only one of two V2 sources") { From cb1b55cf771018f1560f6b173cdd7c6ca8061bc7 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 19 Sep 2018 14:33:40 -0700 Subject: [PATCH 1647/2461] Revert "[SPARK-23173][SQL] rename spark.sql.fromJsonForceNullableSchema" This reverts commit 6c7db7fd1ced1d143b1389d09990a620fc16be46. --- .../catalyst/expressions/jsonExpressions.scala | 4 ++-- .../org/apache/spark/sql/internal/SQLConf.scala | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index ade10ab044ae2..bd9090a07471b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -517,12 +517,12 @@ case class JsonToStructs( timeZoneId: Option[String] = None) extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { - val forceNullableSchema: Boolean = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA) + val forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA) // The JSON input data might be missing certain fields. We force the nullability // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder // can generate incorrect files if values are missing in columns declared as non-nullable. - val nullableSchema: DataType = if (forceNullableSchema) schema.asNullable else schema + val nullableSchema = if (forceNullableSchema) schema.asNullable else schema override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4499a35310a6e..b1e9b17e049e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -608,6 +608,14 @@ object SQLConf { .stringConf .createWithDefault("_corrupt_record") + val FROM_JSON_FORCE_NULLABLE_SCHEMA = buildConf("spark.sql.fromJsonForceNullableSchema") + .internal() + .doc("When true, force the output schema of the from_json() function to be nullable " + + "(including all the fields). Otherwise, the schema might not be compatible with" + + "actual data, which leads to curruptions.") + .booleanConf + .createWithDefault(true) + val BROADCAST_TIMEOUT = buildConf("spark.sql.broadcastTimeout") .doc("Timeout in seconds for the broadcast wait time in broadcast joins.") .timeConf(TimeUnit.SECONDS) @@ -1354,14 +1362,6 @@ object SQLConf { "When this conf is not set, the value from `spark.redaction.string.regex` is used.") .fallbackConf(org.apache.spark.internal.config.STRING_REDACTION_PATTERN) - val FROM_JSON_FORCE_NULLABLE_SCHEMA = buildConf("spark.sql.function.fromJson.forceNullable") - .internal() - .doc("When true, force the output schema of the from_json() function to be nullable " + - "(including all the fields). Otherwise, the schema might not be compatible with" + - "actual data, which leads to corruptions.") - .booleanConf - .createWithDefault(true) - val CONCAT_BINARY_AS_STRING = buildConf("spark.sql.function.concatBinaryAsString") .doc("When this option is set to false and all inputs are binary, `functions.concat` returns " + "an output as binary. Otherwise, it returns as a string. ") From 6f681d42964884d19bf22deb614550d712223117 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 19 Sep 2018 15:16:20 -0700 Subject: [PATCH 1648/2461] [SPARK-22666][ML][FOLLOW-UP] Improve testcase to tolerate different schema representation ## What changes were proposed in this pull request? Improve testcase "image datasource test: read non image" to tolerate different schema representation. Because file:/path and file:///path are both valid URI-ifications so in some environment the testcase will fail. ## How was this patch tested? Manual. Closes #22449 from WeichenXu123/image_url. Authored-by: WeichenXu Signed-off-by: Xiangrui Meng --- .../spark/ml/source/image/ImageFileFormatSuite.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala index 1a6a8d67d8d66..38e25131df867 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.source.image +import java.net.URI import java.nio.file.Paths import org.apache.spark.SparkFunSuite @@ -58,8 +59,14 @@ class ImageFileFormatSuite extends SparkFunSuite with MLlibTestSparkContext { .load(filePath) assert(df2.count() === 1) val result = df2.head() - assert(result === invalidImageRow( - Paths.get(filePath).toAbsolutePath().normalize().toUri().toString)) + + val resultOrigin = result.getStruct(0).getString(0) + // covert `origin` to `java.net.URI` object and then compare. + // because `file:/path` and `file:///path` are both valid URI-ifications + assert(new URI(resultOrigin) === Paths.get(filePath).toAbsolutePath().normalize().toUri()) + + // Compare other columns in the row to be the same with the `invalidImageRow` + assert(result === invalidImageRow(resultOrigin)) } test("image datasource partition test") { From 90e3955f384ca07bdf24faa6cdb60ded944cf0d8 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 20 Sep 2018 09:29:29 +0800 Subject: [PATCH 1649/2461] [SPARK-25471][PYTHON][TEST] Fix pyspark-sql test error when using Python 3.6 and Pandas 0.23 ## What changes were proposed in this pull request? Fix test that constructs a Pandas DataFrame by specifying the column order. Previously this test assumed the columns would be sorted alphabetically, however when using Python 3.6 with Pandas 0.23 or higher, the original column order is maintained. This causes the columns to get mixed up and the test errors. Manually tested with `python/run-tests` using Python 3.6.6 and Pandas 0.23.4 Closes #22477 from BryanCutler/pyspark-tests-py36-pd23-SPARK-25471. Authored-by: Bryan Cutler Signed-off-by: hyukjinkwon --- python/pyspark/sql/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 08d7cfadc084c..603f994dc9597 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3266,7 +3266,7 @@ def test_create_dataframe_from_pandas_with_timestamp(self): import pandas as pd from datetime import datetime pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], - "d": [pd.Timestamp.now().date()]}) + "d": [pd.Timestamp.now().date()]}, columns=["d", "ts"]) # test types are inferred correctly without specifying schema df = self.spark.createDataFrame(pdf) self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType)) From 936c920347e196381b48bc3656ca81a06f2ff46d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 19 Sep 2018 18:51:20 -0700 Subject: [PATCH 1650/2461] [SPARK-24157][SS][FOLLOWUP] Rename to spark.sql.streaming.noDataMicroBatches.enabled ## What changes were proposed in this pull request? This patch changes the config option `spark.sql.streaming.noDataMicroBatchesEnabled` to `spark.sql.streaming.noDataMicroBatches.enabled` to be more consistent with rest of the configs. Unfortunately there is one streaming config called `spark.sql.streaming.metricsEnabled`. For that one we should just use a fallback config and change it in a separate patch. ## How was this patch tested? Made sure no other references to this config are in the code base: ``` > git grep "noDataMicro" sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala: buildConf("spark.sql.streaming.noDataMicroBatches.enabled") ``` Closes #22476 from rxin/SPARK-24157. Authored-by: Reynold Xin Signed-off-by: Reynold Xin --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b1e9b17e049e7..c3328a6936ae7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1076,7 +1076,7 @@ object SQLConf { .createWithDefault(10000L) val STREAMING_NO_DATA_MICRO_BATCHES_ENABLED = - buildConf("spark.sql.streaming.noDataMicroBatchesEnabled") + buildConf("spark.sql.streaming.noDataMicroBatches.enabled") .doc( "Whether streaming micro-batch engine will execute batches without data " + "for eager state management for stateful streaming queries.") From 8aae49afc7997aa1da61029409ef6d8ce0ab256a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 20 Sep 2018 10:10:20 +0800 Subject: [PATCH 1651/2461] [SPARK-24341][FOLLOWUP][DOCS] Add migration note for IN subqueries behavior ## What changes were proposed in this pull request? The PR updates the migration guide in order to explain the changes introduced in the behavior of the IN operator with subqueries, in particular, the improved handling of struct attributes in these situations. ## How was this patch tested? NA Closes #22469 from mgaido91/SPARK-24341_followup. Authored-by: Marco Gaido Signed-off-by: Wenchen Fan --- docs/sql-programming-guide.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 2fa29a00e74c7..c76f2e30e6771 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1879,6 +1879,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.3 to 2.4 + - Since Spark 2.4, when there is a struct field in front of the IN operator before a subquery, the inner query must contain a struct field as well. In previous versions, instead, the fields of the struct were compared to the output of the inner query. Eg. if `a` is a `struct(a string, b int)`, in Spark 2.4 `a in (select (1 as a, 'a' as b) from range(1))` is a valid query, while `a in (select 1, 'a' from range(1))` is not. In previous version it was the opposite. - In versions 2.2.1+ and 2.3, if `spark.sql.caseSensitive` is set to true, then the `CURRENT_DATE` and `CURRENT_TIMESTAMP` functions incorrectly became case-sensitive and would resolve to columns (unless typed in lower case). In Spark 2.4 this has been fixed and the functions are no longer case-sensitive. - Since Spark 2.4, Spark will evaluate the set operations referenced in a query by following a precedence rule as per the SQL standard. If the order is not specified by parentheses, set operations are performed from left to right with the exception that all INTERSECT operations are performed before any UNION, EXCEPT or MINUS operations. The old behaviour of giving equal precedence to all the set operations are preserved under a newly added configuration `spark.sql.legacy.setopsPrecedence.enabled` with a default value of `false`. When this property is set to `true`, spark will evaluate the set operators from left to right as they appear in the query given no explicit ordering is enforced by usage of parenthesis. - Since Spark 2.4, Spark will display table description column Last Access value as UNKNOWN when the value was Jan 01 1970. From 47d6e80a2e64823fabb596503fb6a6cc6f51f713 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 20 Sep 2018 10:23:37 +0800 Subject: [PATCH 1652/2461] [SPARK-25457][SQL] IntegralDivide returns data type of the operands ## What changes were proposed in this pull request? The PR proposes to return the data type of the operands as a result for the `div` operator. Before the PR, `bigint` is always returned. It introduces also a `spark.sql.legacy.integralDivide.returnBigint` config in order to let the users restore the legacy behavior. ## How was this patch tested? added UTs Closes #22465 from mgaido91/SPARK-25457. Authored-by: Marco Gaido Signed-off-by: Wenchen Fan --- .../sql/catalyst/expressions/arithmetic.scala | 17 +- .../apache/spark/sql/internal/SQLConf.scala | 9 + .../ArithmeticExpressionSuite.scala | 26 +- .../sql-tests/inputs/operator-div.sql | 14 + .../resources/sql-tests/inputs/operators.sql | 6 +- .../sql-tests/results/operator-div.sql.out | 82 ++++++ .../sql-tests/results/operators.sql.out | 248 ++++++++---------- 7 files changed, 246 insertions(+), 156 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/operator-div.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/operator-div.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 1b1808f8366d2..f59b2a2ec510f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -327,16 +328,24 @@ case class Divide(left: Expression, right: Expression) extends DivModLike { case class IntegralDivide(left: Expression, right: Expression) extends DivModLike { override def inputType: AbstractDataType = IntegralType - override def dataType: DataType = LongType + override def dataType: DataType = if (SQLConf.get.integralDivideReturnLong) { + LongType + } else { + left.dataType + } override def symbol: String = "/" override def sqlOperator: String = "div" - private lazy val div: (Any, Any) => Long = left.dataType match { + private lazy val div: (Any, Any) => Any = left.dataType match { case i: IntegralType => val divide = i.integral.asInstanceOf[Integral[Any]].quot _ - val toLong = i.integral.asInstanceOf[Integral[Any]].toLong _ - (x, y) => toLong(divide(x, y)) + if (SQLConf.get.integralDivideReturnLong) { + val toLong = i.integral.asInstanceOf[Integral[Any]].toLong _ + (x, y) => toLong(divide(x, y)) + } else { + divide + } } override def evalOperation(left: Any, right: Any): Any = div(left, right) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c3328a6936ae7..907221c073471 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1561,6 +1561,13 @@ object SQLConf { "are performed before any UNION, EXCEPT and MINUS operations.") .booleanConf .createWithDefault(false) + + val LEGACY_INTEGRALDIVIDE_RETURN_LONG = buildConf("spark.sql.legacy.integralDivide.returnBigint") + .doc("If it is set to true, the div operator returns always a bigint. This behavior was " + + "inherited from Hive. Otherwise, the return type is the data type of the operands.") + .internal() + .booleanConf + .createWithDefault(false) } /** @@ -1973,6 +1980,8 @@ class SQLConf extends Serializable with Logging { def setOpsPrecedenceEnforced: Boolean = getConf(SQLConf.LEGACY_SETOPS_PRECEDENCE_ENABLED) + def integralDivideReturnLong: Boolean = getConf(SQLConf.LEGACY_INTEGRALDIVIDE_RETURN_LONG) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index c3c4d9ee6b702..1318ab1859839 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -144,13 +145,24 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } test("/ (Divide) for integral type") { - checkEvaluation(IntegralDivide(Literal(1.toByte), Literal(2.toByte)), 0L) - checkEvaluation(IntegralDivide(Literal(1.toShort), Literal(2.toShort)), 0L) - checkEvaluation(IntegralDivide(Literal(1), Literal(2)), 0L) - checkEvaluation(IntegralDivide(Literal(1.toLong), Literal(2.toLong)), 0L) - checkEvaluation(IntegralDivide(positiveShortLit, negativeShortLit), 0L) - checkEvaluation(IntegralDivide(positiveIntLit, negativeIntLit), 0L) - checkEvaluation(IntegralDivide(positiveLongLit, negativeLongLit), 0L) + withSQLConf(SQLConf.LEGACY_INTEGRALDIVIDE_RETURN_LONG.key -> "false") { + checkEvaluation(IntegralDivide(Literal(1.toByte), Literal(2.toByte)), 0.toByte) + checkEvaluation(IntegralDivide(Literal(1.toShort), Literal(2.toShort)), 0.toShort) + checkEvaluation(IntegralDivide(Literal(1), Literal(2)), 0) + checkEvaluation(IntegralDivide(Literal(1.toLong), Literal(2.toLong)), 0.toLong) + checkEvaluation(IntegralDivide(positiveShortLit, negativeShortLit), 0.toShort) + checkEvaluation(IntegralDivide(positiveIntLit, negativeIntLit), 0) + checkEvaluation(IntegralDivide(positiveLongLit, negativeLongLit), 0L) + } + withSQLConf(SQLConf.LEGACY_INTEGRALDIVIDE_RETURN_LONG.key -> "true") { + checkEvaluation(IntegralDivide(Literal(1.toByte), Literal(2.toByte)), 0L) + checkEvaluation(IntegralDivide(Literal(1.toShort), Literal(2.toShort)), 0L) + checkEvaluation(IntegralDivide(Literal(1), Literal(2)), 0L) + checkEvaluation(IntegralDivide(Literal(1.toLong), Literal(2.toLong)), 0L) + checkEvaluation(IntegralDivide(positiveShortLit, negativeShortLit), 0L) + checkEvaluation(IntegralDivide(positiveIntLit, negativeIntLit), 0L) + checkEvaluation(IntegralDivide(positiveLongLit, negativeLongLit), 0L) + } } test("% (Remainder)") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/operator-div.sql b/sql/core/src/test/resources/sql-tests/inputs/operator-div.sql new file mode 100644 index 0000000000000..6e1c1bded9043 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/operator-div.sql @@ -0,0 +1,14 @@ +set spark.sql.legacy.integralDivide.returnBigint=true; + +select 5 div 2; +select 5 div 0; +select 5 div null; +select null div 5; + +set spark.sql.legacy.integralDivide.returnBigint=false; + +select 5 div 2; +select 5 div 0; +select 5 div null; +select null div 5; + diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index 15d981985c55b..37f9cd44da7f2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -16,15 +16,11 @@ select + + 100; select - - max(key) from testdata; select + - key from testdata where key = 33; --- div +-- division select 5 / 2; select 5 / 0; select 5 / null; select null / 5; -select 5 div 2; -select 5 div 0; -select 5 div null; -select null div 5; -- other arithmetics select 1 + 2; diff --git a/sql/core/src/test/resources/sql-tests/results/operator-div.sql.out b/sql/core/src/test/resources/sql-tests/results/operator-div.sql.out new file mode 100644 index 0000000000000..088b4d1c231fa --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/operator-div.sql.out @@ -0,0 +1,82 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 10 + + +-- !query 0 +set spark.sql.legacy.integralDivide.returnBigint=true +-- !query 0 schema +struct +-- !query 0 output +spark.sql.legacy.integralDivide.returnBigint true + + +-- !query 1 +select 5 div 2 +-- !query 1 schema +struct<(5 div 2):bigint> +-- !query 1 output +2 + + +-- !query 2 +select 5 div 0 +-- !query 2 schema +struct<(5 div 0):bigint> +-- !query 2 output +NULL + + +-- !query 3 +select 5 div null +-- !query 3 schema +struct<(5 div CAST(NULL AS INT)):bigint> +-- !query 3 output +NULL + + +-- !query 4 +select null div 5 +-- !query 4 schema +struct<(CAST(NULL AS INT) div 5):bigint> +-- !query 4 output +NULL + + +-- !query 5 +set spark.sql.legacy.integralDivide.returnBigint=false +-- !query 5 schema +struct +-- !query 5 output +spark.sql.legacy.integralDivide.returnBigint false + + +-- !query 6 +select 5 div 2 +-- !query 6 schema +struct<(5 div 2):int> +-- !query 6 output +2 + + +-- !query 7 +select 5 div 0 +-- !query 7 schema +struct<(5 div 0):int> +-- !query 7 output +NULL + + +-- !query 8 +select 5 div null +-- !query 8 schema +struct<(5 div CAST(NULL AS INT)):int> +-- !query 8 output +NULL + + +-- !query 9 +select null div 5 +-- !query 9 schema +struct<(CAST(NULL AS INT) div 5):int> +-- !query 9 output +NULL diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 2555734756fc4..fd1d0db9e3f78 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 59 +-- Number of queries: 55 -- !query 0 @@ -155,332 +155,300 @@ NULL -- !query 19 -select 5 div 2 --- !query 19 schema -struct<(5 div 2):bigint> --- !query 19 output -2 - - --- !query 20 -select 5 div 0 --- !query 20 schema -struct<(5 div 0):bigint> --- !query 20 output -NULL - - --- !query 21 -select 5 div null --- !query 21 schema -struct<(5 div CAST(NULL AS INT)):bigint> --- !query 21 output -NULL - - --- !query 22 -select null div 5 --- !query 22 schema -struct<(CAST(NULL AS INT) div 5):bigint> --- !query 22 output -NULL - - --- !query 23 select 1 + 2 --- !query 23 schema +-- !query 19 schema struct<(1 + 2):int> --- !query 23 output +-- !query 19 output 3 --- !query 24 +-- !query 20 select 1 - 2 --- !query 24 schema +-- !query 20 schema struct<(1 - 2):int> --- !query 24 output +-- !query 20 output -1 --- !query 25 +-- !query 21 select 2 * 5 --- !query 25 schema +-- !query 21 schema struct<(2 * 5):int> --- !query 25 output +-- !query 21 output 10 --- !query 26 +-- !query 22 select 5 % 3 --- !query 26 schema +-- !query 22 schema struct<(5 % 3):int> --- !query 26 output +-- !query 22 output 2 --- !query 27 +-- !query 23 select pmod(-7, 3) --- !query 27 schema +-- !query 23 schema struct --- !query 27 output +-- !query 23 output 2 --- !query 28 +-- !query 24 explain select 'a' || 1 + 2 --- !query 28 schema +-- !query 24 schema struct --- !query 28 output +-- !query 24 output == Physical Plan == *Project [null AS (CAST(concat(a, CAST(1 AS STRING)) AS DOUBLE) + CAST(2 AS DOUBLE))#x] +- Scan OneRowRelation[] --- !query 29 +-- !query 25 explain select 1 - 2 || 'b' --- !query 29 schema +-- !query 25 schema struct --- !query 29 output +-- !query 25 output == Physical Plan == *Project [-1b AS concat(CAST((1 - 2) AS STRING), b)#x] +- Scan OneRowRelation[] --- !query 30 +-- !query 26 explain select 2 * 4 + 3 || 'b' --- !query 30 schema +-- !query 26 schema struct --- !query 30 output +-- !query 26 output == Physical Plan == *Project [11b AS concat(CAST(((2 * 4) + 3) AS STRING), b)#x] +- Scan OneRowRelation[] --- !query 31 +-- !query 27 explain select 3 + 1 || 'a' || 4 / 2 --- !query 31 schema +-- !query 27 schema struct --- !query 31 output +-- !query 27 output == Physical Plan == *Project [4a2.0 AS concat(concat(CAST((3 + 1) AS STRING), a), CAST((CAST(4 AS DOUBLE) / CAST(2 AS DOUBLE)) AS STRING))#x] +- Scan OneRowRelation[] --- !query 32 +-- !query 28 explain select 1 == 1 OR 'a' || 'b' == 'ab' --- !query 32 schema +-- !query 28 schema struct --- !query 32 output +-- !query 28 output == Physical Plan == *Project [true AS ((1 = 1) OR (concat(a, b) = ab))#x] +- Scan OneRowRelation[] --- !query 33 +-- !query 29 explain select 'a' || 'c' == 'ac' AND 2 == 3 --- !query 33 schema +-- !query 29 schema struct --- !query 33 output +-- !query 29 output == Physical Plan == *Project [false AS ((concat(a, c) = ac) AND (2 = 3))#x] +- Scan OneRowRelation[] --- !query 34 +-- !query 30 select cot(1) --- !query 34 schema +-- !query 30 schema struct --- !query 34 output +-- !query 30 output 0.6420926159343306 --- !query 35 +-- !query 31 select cot(null) --- !query 35 schema +-- !query 31 schema struct --- !query 35 output +-- !query 31 output NULL --- !query 36 +-- !query 32 select cot(0) --- !query 36 schema +-- !query 32 schema struct --- !query 36 output +-- !query 32 output Infinity --- !query 37 +-- !query 33 select cot(-1) --- !query 37 schema +-- !query 33 schema struct --- !query 37 output +-- !query 33 output -0.6420926159343306 --- !query 38 +-- !query 34 select ceiling(0) --- !query 38 schema +-- !query 34 schema struct --- !query 38 output +-- !query 34 output 0 --- !query 39 +-- !query 35 select ceiling(1) --- !query 39 schema +-- !query 35 schema struct --- !query 39 output +-- !query 35 output 1 --- !query 40 +-- !query 36 select ceil(1234567890123456) --- !query 40 schema +-- !query 36 schema struct --- !query 40 output +-- !query 36 output 1234567890123456 --- !query 41 +-- !query 37 select ceiling(1234567890123456) --- !query 41 schema +-- !query 37 schema struct --- !query 41 output +-- !query 37 output 1234567890123456 --- !query 42 +-- !query 38 select ceil(0.01) --- !query 42 schema +-- !query 38 schema struct --- !query 42 output +-- !query 38 output 1 --- !query 43 +-- !query 39 select ceiling(-0.10) --- !query 43 schema +-- !query 39 schema struct --- !query 43 output +-- !query 39 output 0 --- !query 44 +-- !query 40 select floor(0) --- !query 44 schema +-- !query 40 schema struct --- !query 44 output +-- !query 40 output 0 --- !query 45 +-- !query 41 select floor(1) --- !query 45 schema +-- !query 41 schema struct --- !query 45 output +-- !query 41 output 1 --- !query 46 +-- !query 42 select floor(1234567890123456) --- !query 46 schema +-- !query 42 schema struct --- !query 46 output +-- !query 42 output 1234567890123456 --- !query 47 +-- !query 43 select floor(0.01) --- !query 47 schema +-- !query 43 schema struct --- !query 47 output +-- !query 43 output 0 --- !query 48 +-- !query 44 select floor(-0.10) --- !query 48 schema +-- !query 44 schema struct --- !query 48 output +-- !query 44 output -1 --- !query 49 +-- !query 45 select 1 > 0.00001 --- !query 49 schema +-- !query 45 schema struct<(CAST(1 AS BIGINT) > 0):boolean> --- !query 49 output +-- !query 45 output true --- !query 50 +-- !query 46 select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, null) --- !query 50 schema +-- !query 46 schema struct<(7 % 2):int,(7 % 0):int,(0 % 2):int,(7 % CAST(NULL AS INT)):int,(CAST(NULL AS INT) % 2):int,(CAST(NULL AS DOUBLE) % CAST(NULL AS DOUBLE)):double> --- !query 50 output +-- !query 46 output 1 NULL 0 NULL NULL NULL --- !query 51 +-- !query 47 select BIT_LENGTH('abc') --- !query 51 schema +-- !query 47 schema struct --- !query 51 output +-- !query 47 output 24 --- !query 52 +-- !query 48 select CHAR_LENGTH('abc') --- !query 52 schema +-- !query 48 schema struct --- !query 52 output +-- !query 48 output 3 --- !query 53 +-- !query 49 select CHARACTER_LENGTH('abc') --- !query 53 schema +-- !query 49 schema struct --- !query 53 output +-- !query 49 output 3 --- !query 54 +-- !query 50 select OCTET_LENGTH('abc') --- !query 54 schema +-- !query 50 schema struct --- !query 54 output +-- !query 50 output 3 --- !query 55 +-- !query 51 select abs(-3.13), abs('-2.19') --- !query 55 schema +-- !query 51 schema struct --- !query 55 output +-- !query 51 output 3.13 2.19 --- !query 56 +-- !query 52 select positive('-1.11'), positive(-1.11), negative('-1.11'), negative(-1.11) --- !query 56 schema +-- !query 52 schema struct<(+ CAST(-1.11 AS DOUBLE)):double,(+ -1.11):decimal(3,2),(- CAST(-1.11 AS DOUBLE)):double,(- -1.11):decimal(3,2)> --- !query 56 output +-- !query 52 output -1.11 -1.11 1.11 1.11 --- !query 57 +-- !query 53 select pmod(-7, 2), pmod(0, 2), pmod(7, 0), pmod(7, null), pmod(null, 2), pmod(null, null) --- !query 57 schema +-- !query 53 schema struct --- !query 57 output +-- !query 53 output 1 0 NULL NULL NULL NULL --- !query 58 +-- !query 54 select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint), cast(0 as smallint)) --- !query 58 schema +-- !query 54 schema struct --- !query 58 output +-- !query 54 output NULL NULL From 76399d75e23f2c7d6c2a1fb77a4387c5e15c809b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 19 Sep 2018 21:23:35 -0700 Subject: [PATCH 1653/2461] [SPARK-4502][SQL] Rename to spark.sql.optimizer.nestedSchemaPruning.enabled ## What changes were proposed in this pull request? This patch adds an "optimizer" prefix to nested schema pruning. ## How was this patch tested? Should be covered by existing tests. Closes #22475 from rxin/SPARK-4502. Authored-by: Reynold Xin Signed-off-by: gatorsmile --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 907221c073471..a01e87c8d1dd3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1457,7 +1457,7 @@ object SQLConf { .createWithDefault(true) val NESTED_SCHEMA_PRUNING_ENABLED = - buildConf("spark.sql.nestedSchemaPruning.enabled") + buildConf("spark.sql.optimizer.nestedSchemaPruning.enabled") .internal() .doc("Prune nested fields from a logical relation's output which are unnecessary in " + "satisfying a query. This optimization allows columnar file format readers to avoid " + From 95b177c8f0862c6965a7c3cd76b3935c975adee9 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 19 Sep 2018 21:27:30 -0700 Subject: [PATCH 1654/2461] [SPARK-23648][R][SQL] Adds more types for hint in SparkR ## What changes were proposed in this pull request? Addition of numeric and list hints for SparkR. ## How was this patch tested? Add test in test_sparkSQL.R Author: Huaxin Gao Closes #21649 from huaxingao/spark-23648. --- R/pkg/R/DataFrame.R | 12 +++++++++++- R/pkg/tests/fulltests/test_sparkSQL.R | 9 +++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 458decaf4766f..a1cb4781f4d0a 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3985,7 +3985,17 @@ setMethod("hint", signature(x = "SparkDataFrame", name = "character"), function(x, name, ...) { parameters <- list(...) - stopifnot(all(sapply(parameters, is.character))) + if (!all(sapply(parameters, function(y) { + if (is.character(y) || is.numeric(y)) { + TRUE + } else if (is.list(y)) { + all(sapply(y, function(z) { is.character(z) || is.numeric(z) })) + } else { + FALSE + } + }))) { + stop("sql hint should be character, numeric, or list with character or numeric.") + } jdf <- callJMethod(x@sdf, "hint", name, parameters) dataFrame(jdf) }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 0c4bdb31b027b..40d8f8084f2f4 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -2419,6 +2419,15 @@ test_that("join(), crossJoin() and merge() on a DataFrame", { expect_true(any(grepl("BroadcastHashJoin", execution_plan_broadcast))) }) +test_that("test hint", { + df <- sql("SELECT * FROM range(10e10)") + hintList <- list("hint2", "hint3", "hint4") + execution_plan_hint <- capture.output( + explain(hint(df, "hint1", 1.23456, "aaaaaaaaaa", hintList), TRUE) + ) + expect_true(any(grepl("1.23456, aaaaaaaaaa", execution_plan_hint))) +}) + test_that("toJSON() on DataFrame", { df <- as.DataFrame(cars) df_json <- toJSON(df) From 0e31a6f25e0263b144255a6630e1d381fe2d27a7 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 20 Sep 2018 12:34:39 +0800 Subject: [PATCH 1655/2461] [SPARK-25339][TEST] Refactor FilterPushdownBenchmark ## What changes were proposed in this pull request? Refactor `FilterPushdownBenchmark` use `main` method. we can use 3 ways to run this test now: 1. bin/spark-submit --class org.apache.spark.sql.execution.benchmark.FilterPushdownBenchmark spark-sql_2.11-2.5.0-SNAPSHOT-tests.jar 2. build/sbt "sql/test:runMain org.apache.spark.sql.execution.benchmark.FilterPushdownBenchmark" 3. SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain org.apache.spark.sql.execution.benchmark.FilterPushdownBenchmark" The method 2 and the method 3 do not need to compile the `spark-sql_*-tests.jar` package. So these two methods are mainly for developers to quickly do benchmark. ## How was this patch tested? manual tests Closes #22443 from wangyum/SPARK-25339. Authored-by: Yuming Wang Signed-off-by: Wenchen Fan --- .../org/apache/spark/util/BenchmarkBase.scala | 57 +++ .../benchmark/FilterPushdownBenchmark.scala | 333 ++++++++---------- 2 files changed, 206 insertions(+), 184 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/BenchmarkBase.scala diff --git a/core/src/main/scala/org/apache/spark/util/BenchmarkBase.scala b/core/src/main/scala/org/apache/spark/util/BenchmarkBase.scala new file mode 100644 index 0000000000000..c84032b8726db --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/BenchmarkBase.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.{File, FileOutputStream, OutputStream} + +/** + * A base class for generate benchmark results to a file. + */ +abstract class BenchmarkBase { + var output: Option[OutputStream] = None + + def benchmark(): Unit + + final def runBenchmark(benchmarkName: String)(func: => Any): Unit = { + val separator = "=" * 96 + val testHeader = (separator + '\n' + benchmarkName + '\n' + separator + '\n' + '\n').getBytes + output.foreach(_.write(testHeader)) + func + output.foreach(_.write('\n')) + } + + def main(args: Array[String]): Unit = { + val regenerateBenchmarkFiles: Boolean = System.getenv("SPARK_GENERATE_BENCHMARK_FILES") == "1" + if (regenerateBenchmarkFiles) { + val resultFileName = s"${this.getClass.getSimpleName.replace("$", "")}-results.txt" + val file = new File(s"benchmarks/$resultFileName") + if (!file.exists()) { + file.createNewFile() + } + output = Some(new FileOutputStream(file)) + } + + benchmark() + + output.foreach { o => + if (o != null) { + o.close() + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala index d6dfdec45a0e8..9ecea99f12895 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala @@ -17,29 +17,28 @@ package org.apache.spark.sql.execution.benchmark -import java.io.{File, FileOutputStream, OutputStream} +import java.io.File import scala.util.{Random, Try} -import org.scalatest.{BeforeAndAfterEachTestData, Suite, TestData} - import org.apache.spark.SparkConf -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.functions.monotonically_increasing_id import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, TimestampType} -import org.apache.spark.util.{Benchmark, Utils} +import org.apache.spark.util.{Benchmark, BenchmarkBase => FileBenchmarkBase, Utils} /** * Benchmark to measure read performance with Filter pushdown. - * To run this: - * build/sbt "sql/test-only *FilterPushdownBenchmark" - * - * Results will be written to "benchmarks/FilterPushdownBenchmark-results.txt". + * To run this benchmark: + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/FilterPushdownBenchmark-results.txt". */ -class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfterEachTest { +object FilterPushdownBenchmark extends FileBenchmarkBase { + private val conf = new SparkConf() .setAppName(this.getClass.getSimpleName) // Since `spark.master` always exists, overrides this value @@ -58,33 +57,6 @@ class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfter private val spark = SparkSession.builder().config(conf).getOrCreate() - private var out: OutputStream = _ - - override def beforeAll() { - super.beforeAll() - out = new FileOutputStream(new File("benchmarks/FilterPushdownBenchmark-results.txt")) - } - - override def beforeEach(td: TestData) { - super.beforeEach(td) - val separator = "=" * 96 - val testHeader = (separator + '\n' + td.name + '\n' + separator + '\n' + '\n').getBytes - out.write(testHeader) - } - - override def afterEach(td: TestData) { - out.write('\n') - super.afterEach(td) - } - - override def afterAll() { - try { - out.close() - } finally { - super.afterAll() - } - } - def withTempPath(f: File => Unit): Unit = { val path = Utils.createTempDir() path.delete() @@ -154,7 +126,7 @@ class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfter title: String, whereExpr: String, selectExpr: String = "*"): Unit = { - val benchmark = new Benchmark(title, values, minNumIters = 5, output = Some(out)) + val benchmark = new Benchmark(title, values, minNumIters = 5, output = output) Seq(false, true).foreach { pushDownEnabled => val name = s"Parquet Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" @@ -241,191 +213,184 @@ class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfter } } - ignore("Pushdown for many distinct value case") { - withTempPath { dir => - withTempTable("orcTable", "parquetTable") { - Seq(true, false).foreach { useStringForValue => - prepareTable(dir, numRows, width, useStringForValue) - if (useStringForValue) { - runStringBenchmark(numRows, width, mid, "string") - } else { - runIntBenchmark(numRows, width, mid) + override def benchmark(): Unit = { + runBenchmark("Pushdown for many distinct value case") { + withTempPath { dir => + withTempTable("orcTable", "parquetTable") { + Seq(true, false).foreach { useStringForValue => + prepareTable(dir, numRows, width, useStringForValue) + if (useStringForValue) { + runStringBenchmark(numRows, width, mid, "string") + } else { + runIntBenchmark(numRows, width, mid) + } } } } } - } - ignore("Pushdown for few distinct value case (use dictionary encoding)") { - withTempPath { dir => - val numDistinctValues = 200 + runBenchmark("Pushdown for few distinct value case (use dictionary encoding)") { + withTempPath { dir => + val numDistinctValues = 200 - withTempTable("orcTable", "parquetTable") { - prepareStringDictTable(dir, numRows, numDistinctValues, width) - runStringBenchmark(numRows, width, numDistinctValues / 2, "distinct string") + withTempTable("orcTable", "parquetTable") { + prepareStringDictTable(dir, numRows, numDistinctValues, width) + runStringBenchmark(numRows, width, numDistinctValues / 2, "distinct string") + } } } - } - ignore("Pushdown benchmark for StringStartsWith") { - withTempPath { dir => - withTempTable("orcTable", "parquetTable") { - prepareTable(dir, numRows, width, true) - Seq( - "value like '10%'", - "value like '1000%'", - s"value like '${mid.toString.substring(0, mid.toString.length - 1)}%'" - ).foreach { whereExpr => - val title = s"StringStartsWith filter: ($whereExpr)" - filterPushDownBenchmark(numRows, title, whereExpr) + runBenchmark("Pushdown benchmark for StringStartsWith") { + withTempPath { dir => + withTempTable("orcTable", "parquetTable") { + prepareTable(dir, numRows, width, true) + Seq( + "value like '10%'", + "value like '1000%'", + s"value like '${mid.toString.substring(0, mid.toString.length - 1)}%'" + ).foreach { whereExpr => + val title = s"StringStartsWith filter: ($whereExpr)" + filterPushDownBenchmark(numRows, title, whereExpr) + } } } } - } - - ignore(s"Pushdown benchmark for ${DecimalType.simpleString}") { - withTempPath { dir => - Seq( - s"decimal(${Decimal.MAX_INT_DIGITS}, 2)", - s"decimal(${Decimal.MAX_LONG_DIGITS}, 2)", - s"decimal(${DecimalType.MAX_PRECISION}, 2)" - ).foreach { dt => - val columns = (1 to width).map(i => s"CAST(id AS string) c$i") - val valueCol = if (dt.equalsIgnoreCase(s"decimal(${Decimal.MAX_INT_DIGITS}, 2)")) { - monotonically_increasing_id() % 9999999 - } else { - monotonically_increasing_id() - } - val df = spark.range(numRows).selectExpr(columns: _*).withColumn("value", valueCol.cast(dt)) - withTempTable("orcTable", "parquetTable") { - saveAsTable(df, dir) - Seq(s"value = $mid").foreach { whereExpr => - val title = s"Select 1 $dt row ($whereExpr)".replace("value AND value", "value") - filterPushDownBenchmark(numRows, title, whereExpr) + runBenchmark(s"Pushdown benchmark for ${DecimalType.simpleString}") { + withTempPath { dir => + Seq( + s"decimal(${Decimal.MAX_INT_DIGITS}, 2)", + s"decimal(${Decimal.MAX_LONG_DIGITS}, 2)", + s"decimal(${DecimalType.MAX_PRECISION}, 2)" + ).foreach { dt => + val columns = (1 to width).map(i => s"CAST(id AS string) c$i") + val valueCol = if (dt.equalsIgnoreCase(s"decimal(${Decimal.MAX_INT_DIGITS}, 2)")) { + monotonically_increasing_id() % 9999999 + } else { + monotonically_increasing_id() } + val df = spark.range(numRows) + .selectExpr(columns: _*).withColumn("value", valueCol.cast(dt)) + withTempTable("orcTable", "parquetTable") { + saveAsTable(df, dir) + + Seq(s"value = $mid").foreach { whereExpr => + val title = s"Select 1 $dt row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } - val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") - Seq(10, 50, 90).foreach { percent => - filterPushDownBenchmark( - numRows, - s"Select $percent% $dt rows (value < ${numRows * percent / 100})", - s"value < ${numRows * percent / 100}", - selectExpr - ) + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% $dt rows (value < ${numRows * percent / 100})", + s"value < ${numRows * percent / 100}", + selectExpr + ) + } } } } } - } - ignore("Pushdown benchmark for InSet -> InFilters") { - withTempPath { dir => - withTempTable("orcTable", "parquetTable") { - prepareTable(dir, numRows, width, false) - Seq(5, 10, 50, 100).foreach { count => - Seq(10, 50, 90).foreach { distribution => - val filter = - Range(0, count).map(r => scala.util.Random.nextInt(numRows * distribution / 100)) - val whereExpr = s"value in(${filter.mkString(",")})" - val title = s"InSet -> InFilters (values count: $count, distribution: $distribution)" - filterPushDownBenchmark(numRows, title, whereExpr) + runBenchmark("Pushdown benchmark for InSet -> InFilters") { + withTempPath { dir => + withTempTable("orcTable", "parquetTable") { + prepareTable(dir, numRows, width, false) + Seq(5, 10, 50, 100).foreach { count => + Seq(10, 50, 90).foreach { distribution => + val filter = + Range(0, count).map(r => scala.util.Random.nextInt(numRows * distribution / 100)) + val whereExpr = s"value in(${filter.mkString(",")})" + val title = s"InSet -> InFilters (values count: $count, distribution: $distribution)" + filterPushDownBenchmark(numRows, title, whereExpr) + } } } } } - } - ignore(s"Pushdown benchmark for ${ByteType.simpleString}") { - withTempPath { dir => - val columns = (1 to width).map(i => s"CAST(id AS string) c$i") - val df = spark.range(numRows).selectExpr(columns: _*) - .withColumn("value", (monotonically_increasing_id() % Byte.MaxValue).cast(ByteType)) - .orderBy("value") - withTempTable("orcTable", "parquetTable") { - saveAsTable(df, dir) - - Seq(s"value = CAST(${Byte.MaxValue / 2} AS ${ByteType.simpleString})") - .foreach { whereExpr => - val title = s"Select 1 ${ByteType.simpleString} row ($whereExpr)" - .replace("value AND value", "value") - filterPushDownBenchmark(numRows, title, whereExpr) - } + runBenchmark(s"Pushdown benchmark for ${ByteType.simpleString}") { + withTempPath { dir => + val columns = (1 to width).map(i => s"CAST(id AS string) c$i") + val df = spark.range(numRows).selectExpr(columns: _*) + .withColumn("value", (monotonically_increasing_id() % Byte.MaxValue).cast(ByteType)) + .orderBy("value") + withTempTable("orcTable", "parquetTable") { + saveAsTable(df, dir) - val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") - Seq(10, 50, 90).foreach { percent => - filterPushDownBenchmark( - numRows, - s"Select $percent% ${ByteType.simpleString} rows " + - s"(value < CAST(${Byte.MaxValue * percent / 100} AS ${ByteType.simpleString}))", - s"value < CAST(${Byte.MaxValue * percent / 100} AS ${ByteType.simpleString})", - selectExpr - ) + Seq(s"value = CAST(${Byte.MaxValue / 2} AS ${ByteType.simpleString})") + .foreach { whereExpr => + val title = s"Select 1 ${ByteType.simpleString} row ($whereExpr)" + .replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% ${ByteType.simpleString} rows " + + s"(value < CAST(${Byte.MaxValue * percent / 100} AS ${ByteType.simpleString}))", + s"value < CAST(${Byte.MaxValue * percent / 100} AS ${ByteType.simpleString})", + selectExpr + ) + } } } } - } - ignore(s"Pushdown benchmark for Timestamp") { - withTempPath { dir => - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED.key -> true.toString) { - ParquetOutputTimestampType.values.toSeq.map(_.toString).foreach { fileType => - withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> fileType) { - val columns = (1 to width).map(i => s"CAST(id AS string) c$i") - val df = spark.range(numRows).selectExpr(columns: _*) - .withColumn("value", monotonically_increasing_id().cast(TimestampType)) - withTempTable("orcTable", "parquetTable") { - saveAsTable(df, dir) - - Seq(s"value = CAST($mid AS timestamp)").foreach { whereExpr => - val title = s"Select 1 timestamp stored as $fileType row ($whereExpr)" - .replace("value AND value", "value") - filterPushDownBenchmark(numRows, title, whereExpr) - } - - val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") - Seq(10, 50, 90).foreach { percent => - filterPushDownBenchmark( - numRows, - s"Select $percent% timestamp stored as $fileType rows " + - s"(value < CAST(${numRows * percent / 100} AS timestamp))", - s"value < CAST(${numRows * percent / 100} as timestamp)", - selectExpr - ) + runBenchmark(s"Pushdown benchmark for Timestamp") { + withTempPath { dir => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED.key -> true.toString) { + ParquetOutputTimestampType.values.toSeq.map(_.toString).foreach { fileType => + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> fileType) { + val columns = (1 to width).map(i => s"CAST(id AS string) c$i") + val df = spark.range(numRows).selectExpr(columns: _*) + .withColumn("value", monotonically_increasing_id().cast(TimestampType)) + withTempTable("orcTable", "parquetTable") { + saveAsTable(df, dir) + + Seq(s"value = CAST($mid AS timestamp)").foreach { whereExpr => + val title = s"Select 1 timestamp stored as $fileType row ($whereExpr)" + .replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width) + .map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% timestamp stored as $fileType rows " + + s"(value < CAST(${numRows * percent / 100} AS timestamp))", + s"value < CAST(${numRows * percent / 100} as timestamp)", + selectExpr + ) + } } } } } } } - } - ignore(s"Pushdown benchmark with many filters") { - val numRows = 1 - val width = 500 - - withTempPath { dir => - val columns = (1 to width).map(i => s"id c$i") - val df = spark.range(1).selectExpr(columns: _*) - withTempTable("orcTable", "parquetTable") { - saveAsTable(df, dir) - Seq(1, 250, 500).foreach { numFilter => - val whereExpr = (1 to numFilter).map(i => s"c$i = 0").mkString(" and ") - // Note: InferFiltersFromConstraints will add more filters to this given filters - filterPushDownBenchmark(numRows, s"Select 1 row with $numFilter filters", whereExpr) + runBenchmark(s"Pushdown benchmark with many filters") { + val numRows = 1 + val width = 500 + + withTempPath { dir => + val columns = (1 to width).map(i => s"id c$i") + val df = spark.range(1).selectExpr(columns: _*) + withTempTable("orcTable", "parquetTable") { + saveAsTable(df, dir) + Seq(1, 250, 500).foreach { numFilter => + val whereExpr = (1 to numFilter).map(i => s"c$i = 0").mkString(" and ") + // Note: InferFiltersFromConstraints will add more filters to this given filters + filterPushDownBenchmark(numRows, s"Select 1 row with $numFilter filters", whereExpr) + } } } } } } - -trait BenchmarkBeforeAndAfterEachTest extends BeforeAndAfterEachTestData { this: Suite => - - override def beforeEach(td: TestData) { - super.beforeEach(td) - } - - override def afterEach(td: TestData) { - super.afterEach(td) - } -} From 7ff5386ed934190344b2cda1069bde4bc68a3e63 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 20 Sep 2018 15:03:16 +0800 Subject: [PATCH 1656/2461] [MINOR][PYTHON][TEST] Use collect() instead of show() to make the output silent ## What changes were proposed in this pull request? This PR replace an effective `show()` to `collect()` to make the output silent. **Before:** ``` test_simple_udt_in_df (pyspark.sql.tests.SQLTests) ... +---+----------+ |key| val| +---+----------+ | 0|[0.0, 0.0]| | 1|[1.0, 1.0]| | 2|[2.0, 2.0]| | 0|[3.0, 3.0]| | 1|[4.0, 4.0]| | 2|[5.0, 5.0]| | 0|[6.0, 6.0]| | 1|[7.0, 7.0]| | 2|[8.0, 8.0]| | 0|[9.0, 9.0]| +---+----------+ ``` **After:** ``` test_simple_udt_in_df (pyspark.sql.tests.SQLTests) ... ok ``` ## How was this patch tested? Manually tested. Closes #22479 from HyukjinKwon/minor-udf-test. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- python/pyspark/sql/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 603f994dc9597..8724bbc6ca7c5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1168,7 +1168,7 @@ def test_simple_udt_in_df(self): df = self.spark.createDataFrame( [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], schema=schema) - df.show() + df.collect() def test_nested_udt_in_df(self): schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())) From 89671a27e783d77d4bfaec3d422cc8dd468ef04c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 20 Sep 2018 20:18:31 +0800 Subject: [PATCH 1657/2461] Revert [SPARK-19355][SPARK-25352] ## What changes were proposed in this pull request? This goes to revert sequential PRs based on some discussion and comments at https://github.com/apache/spark/pull/16677#issuecomment-422650759. #22344 #22330 #22239 #16677 ## How was this patch tested? Existing tests. Closes #22481 from viirya/revert-SPARK-19355-1. Authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- .../sort/BypassMergeSortShuffleWriter.java | 5 +- .../shuffle/sort/UnsafeShuffleWriter.java | 3 +- .../apache/spark/MapOutputStatistics.scala | 6 +- .../org/apache/spark/MapOutputTracker.scala | 10 +- .../apache/spark/scheduler/MapStatus.scala | 43 ++------ .../shuffle/sort/SortShuffleWriter.scala | 3 +- .../sort/UnsafeShuffleWriterSuite.java | 2 - .../apache/spark/MapOutputTrackerSuite.scala | 28 ++--- .../scala/org/apache/spark/ShuffleSuite.scala | 1 - .../spark/scheduler/DAGSchedulerSuite.scala | 10 +- .../spark/scheduler/MapStatusSuite.scala | 16 +-- .../serializer/KryoSerializerSuite.scala | 3 +- .../plans/physical/partitioning.scala | 14 --- .../apache/spark/sql/internal/SQLConf.scala | 9 -- .../spark/sql/execution/SparkStrategies.scala | 35 ++---- .../exchange/ShuffleExchangeExec.scala | 8 -- .../apache/spark/sql/execution/limit.scala | 104 +++--------------- .../test/resources/sql-tests/inputs/limit.sql | 2 - .../inputs/subquery/in-subquery/in-limit.sql | 5 +- .../resources/sql-tests/results/limit.sql.out | 92 +++++++--------- .../subquery/in-subquery/in-limit.sql.out | 56 ++++------ .../spark/sql/DataFrameAggregateSuite.scala | 12 +- .../org/apache/spark/sql/DataFrameSuite.scala | 22 +--- .../org/apache/spark/sql/SQLQuerySuite.scala | 11 +- .../execution/ExchangeCoordinatorSuite.scala | 6 +- .../spark/sql/execution/LimitSuite.scala | 81 -------------- .../spark/sql/execution/PlannerSuite.scala | 4 +- .../TakeOrderedAndProjectSuite.scala | 85 ++++++-------- .../execution/HiveCompatibilitySuite.scala | 4 - .../sql/hive/execution/PruningSuite.scala | 8 -- 30 files changed, 184 insertions(+), 504 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index e3bd5496cf5ba..323a5d3c52831 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -125,7 +125,7 @@ public void write(Iterator> records) throws IOException { if (!records.hasNext()) { partitionLengths = new long[numPartitions]; shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, 0); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -167,8 +167,7 @@ public void write(Iterator> records) throws IOException { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply( - blockManager.shuffleServerId(), partitionLengths, writeMetrics.recordsWritten()); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } @VisibleForTesting diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 069e6d5f224d7..4839d04522f10 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -248,8 +248,7 @@ void closeAndWriteOutput() throws IOException { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply( - blockManager.shuffleServerId(), partitionLengths, writeMetrics.recordsWritten()); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } @VisibleForTesting diff --git a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala index ff85e11409e35..f8a6f1d0d8cbb 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala @@ -23,9 +23,5 @@ package org.apache.spark * @param shuffleId ID of the shuffle * @param bytesByPartitionId approximate number of output bytes for each map output partition * (may be inexact due to use of compressed map statuses) - * @param recordsByPartitionId number of output records for each map output partition */ -private[spark] class MapOutputStatistics( - val shuffleId: Int, - val bytesByPartitionId: Array[Long], - val recordsByPartitionId: Array[Long]) +private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long]) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 41575ce4e6e3d..1c4fa4bc6541f 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -522,19 +522,16 @@ private[spark] class MapOutputTrackerMaster( def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { shuffleStatuses(dep.shuffleId).withMapStatuses { statuses => val totalSizes = new Array[Long](dep.partitioner.numPartitions) - val recordsByMapTask = new Array[Long](statuses.length) - val parallelAggThreshold = conf.get( SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD) val parallelism = math.min( Runtime.getRuntime.availableProcessors(), statuses.length.toLong * totalSizes.length / parallelAggThreshold + 1).toInt if (parallelism <= 1) { - statuses.zipWithIndex.foreach { case (s, index) => + for (s <- statuses) { for (i <- 0 until totalSizes.length) { totalSizes(i) += s.getSizeForBlock(i) } - recordsByMapTask(index) = s.numberOfOutput } } else { val threadPool = ThreadUtils.newDaemonFixedThreadPool(parallelism, "map-output-aggregate") @@ -551,11 +548,8 @@ private[spark] class MapOutputTrackerMaster( } finally { threadPool.shutdown() } - statuses.zipWithIndex.foreach { case (s, index) => - recordsByMapTask(index) = s.numberOfOutput - } } - new MapOutputStatistics(dep.shuffleId, totalSizes, recordsByMapTask) + new MapOutputStatistics(dep.shuffleId, totalSizes) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 7e1d75fe723d6..659694dd189ad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -31,8 +31,7 @@ import org.apache.spark.util.Utils /** * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the - * task ran on, the sizes of outputs for each reducer, and the number of outputs of the map task, - * for passing on to the reduce tasks. + * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks. */ private[spark] sealed trait MapStatus { /** Location where this task was run. */ @@ -45,23 +44,18 @@ private[spark] sealed trait MapStatus { * necessary for correctness, since block fetchers are allowed to skip zero-size blocks. */ def getSizeForBlock(reduceId: Int): Long - - /** - * The number of outputs for the map task. - */ - def numberOfOutput: Long } private[spark] object MapStatus { - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], numOutput: Long): MapStatus = { + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { if (uncompressedSizes.length > Option(SparkEnv.get) .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)) { - HighlyCompressedMapStatus(loc, uncompressedSizes, numOutput) + HighlyCompressedMapStatus(loc, uncompressedSizes) } else { - new CompressedMapStatus(loc, uncompressedSizes, numOutput) + new CompressedMapStatus(loc, uncompressedSizes) } } @@ -104,34 +98,29 @@ private[spark] object MapStatus { */ private[spark] class CompressedMapStatus( private[this] var loc: BlockManagerId, - private[this] var compressedSizes: Array[Byte], - private[this] var numOutput: Long) + private[this] var compressedSizes: Array[Byte]) extends MapStatus with Externalizable { - protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1) // For deserialization only + protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only - def this(loc: BlockManagerId, uncompressedSizes: Array[Long], numOutput: Long) { - this(loc, uncompressedSizes.map(MapStatus.compressSize), numOutput) + def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { + this(loc, uncompressedSizes.map(MapStatus.compressSize)) } override def location: BlockManagerId = loc - override def numberOfOutput: Long = numOutput - override def getSizeForBlock(reduceId: Int): Long = { MapStatus.decompressSize(compressedSizes(reduceId)) } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) - out.writeLong(numOutput) out.writeInt(compressedSizes.length) out.write(compressedSizes) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = BlockManagerId(in) - numOutput = in.readLong() val len = in.readInt() compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) @@ -154,20 +143,17 @@ private[spark] class HighlyCompressedMapStatus private ( private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, - private var hugeBlockSizes: Map[Int, Byte], - private[this] var numOutput: Long) + private var hugeBlockSizes: Map[Int, Byte]) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1, null, -1) // For deserialization only + protected def this() = this(null, -1, null, -1, null) // For deserialization only override def location: BlockManagerId = loc - override def numberOfOutput: Long = numOutput - override def getSizeForBlock(reduceId: Int): Long = { assert(hugeBlockSizes != null) if (emptyBlocks.contains(reduceId)) { @@ -182,7 +168,6 @@ private[spark] class HighlyCompressedMapStatus private ( override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) - out.writeLong(numOutput) emptyBlocks.writeExternal(out) out.writeLong(avgSize) out.writeInt(hugeBlockSizes.size) @@ -194,7 +179,6 @@ private[spark] class HighlyCompressedMapStatus private ( override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = BlockManagerId(in) - numOutput = in.readLong() emptyBlocks = new RoaringBitmap() emptyBlocks.readExternal(in) avgSize = in.readLong() @@ -210,10 +194,7 @@ private[spark] class HighlyCompressedMapStatus private ( } private[spark] object HighlyCompressedMapStatus { - def apply( - loc: BlockManagerId, - uncompressedSizes: Array[Long], - numOutput: Long): HighlyCompressedMapStatus = { + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { // We must keep track of which blocks are empty so that we don't report a zero-sized // block as being non-empty (or vice-versa) when using the average block size. var i = 0 @@ -254,6 +235,6 @@ private[spark] object HighlyCompressedMapStatus { emptyBlocks.trim() emptyBlocks.runOptimize() new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, - hugeBlockSizesArray.toMap, numOutput) + hugeBlockSizesArray.toMap) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 91fc26762e533..274399b9cc1f3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -70,8 +70,7 @@ private[spark] class SortShuffleWriter[K, V, C]( val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) val partitionLengths = sorter.writePartitionedFile(blockId, tmp) shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, - writeMetrics.recordsWritten) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) } finally { if (tmp.exists() && !tmp.delete()) { logError(s"Error while deleting temp file ${tmp.getAbsolutePath}") diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index faa70f23b0ac6..0d5c5ea7903e9 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -233,7 +233,6 @@ public void writeEmptyIterator() throws Exception { writer.write(Iterators.emptyIterator()); final Option mapStatus = writer.stop(true); assertTrue(mapStatus.isDefined()); - assertEquals(0, mapStatus.get().numberOfOutput()); assertTrue(mergedOutputFile.exists()); assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile); assertEquals(0, taskMetrics.shuffleWriteMetrics().recordsWritten()); @@ -253,7 +252,6 @@ public void writeWithoutSpilling() throws Exception { writer.write(dataToWrite.iterator()); final Option mapStatus = writer.stop(true); assertTrue(mapStatus.isDefined()); - assertEquals(NUM_PARTITITONS, mapStatus.get().numberOfOutput()); assertTrue(mergedOutputFile.exists()); long sumOfPartitionSizes = 0; diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index e79739692fe13..21f481d477242 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -62,9 +62,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(1000L, 10000L), 10)) + Array(1000L, 10000L))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(10000L, 1000L), 10)) + Array(10000L, 1000L))) val statuses = tracker.getMapSizesByExecutorId(10, 0) assert(statuses.toSet === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), @@ -84,9 +84,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize10000), 10)) + Array(compressedSize1000, compressedSize10000))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000), 10)) + Array(compressedSize10000, compressedSize1000))) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) assert(0 == tracker.getNumCachedSerializedBroadcast) @@ -107,9 +107,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize1000, compressedSize1000), 10)) + Array(compressedSize1000, compressedSize1000, compressedSize1000))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000, compressedSize1000), 10)) + Array(compressedSize10000, compressedSize1000, compressedSize1000))) assert(0 == tracker.getNumCachedSerializedBroadcast) // As if we had two simultaneous fetch failures @@ -145,7 +145,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("a", "hostA", 1000), Array(1000L), 10)) + BlockManagerId("a", "hostA", 1000), Array(1000L))) slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) @@ -182,7 +182,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // Message size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0), 0)) + BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) when(rpcCallContext.senderAddress).thenReturn(senderAddress) @@ -216,11 +216,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { // on hostB with output size 3 tracker.registerShuffle(10, 3) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L), 1)) + Array(2L))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L), 1)) + Array(2L))) tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(3L), 1)) + Array(3L))) // When the threshold is 50%, only host A should be returned as a preferred location // as it has 4 out of 7 bytes of output. @@ -260,7 +260,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(20, 100) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 0)) + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) } val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) @@ -309,9 +309,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(size0, size1000, size0, size10000), 1)) + Array(size0, size1000, size0, size10000))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(size10000, size0, size1000, size0), 1)) + Array(size10000, size0, size1000, size0))) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === Seq( diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 456f97b535ef6..b917469e48747 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -391,7 +391,6 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(mapOutput2.isDefined) assert(mapOutput1.get.location === mapOutput2.get.location) assert(mapOutput1.get.getSizeForBlock(0) === mapOutput1.get.getSizeForBlock(0)) - assert(mapOutput1.get.numberOfOutput === mapOutput2.get.numberOfOutput) // register one of the map outputs -- doesn't matter which one mapOutput1.foreach { case mapStatus => diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 365eab0668ab2..d6c9ae6ab5191 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -445,17 +445,17 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // map stage1 completes successfully, with one task on each executor complete(taskSets(0), Seq( (Success, - MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2), 1)), + MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))), (Success, - MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2), 1)), + MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))), (Success, makeMapStatus("hostB", 1)) )) // map stage2 completes successfully, with one task on each executor complete(taskSets(1), Seq( (Success, - MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2), 1)), + MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))), (Success, - MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2), 1)), + MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))), (Success, makeMapStatus("hostB", 1)) )) // make sure our test setup is correct @@ -2857,7 +2857,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi object DAGSchedulerSuite { def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2): MapStatus = - MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), 1) + MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes)) def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 555e48bd28aa0..354e6386fa60e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -60,7 +60,7 @@ class MapStatusSuite extends SparkFunSuite { stddev <- Seq(0.0, 0.01, 0.5, 1.0) ) { val sizes = Array.fill[Long](numSizes)(abs(round(Random.nextGaussian() * stddev)) + mean) - val status = MapStatus(BlockManagerId("a", "b", 10), sizes, 1) + val status = MapStatus(BlockManagerId("a", "b", 10), sizes) val status1 = compressAndDecompressMapStatus(status) for (i <- 0 until numSizes) { if (sizes(i) != 0) { @@ -74,7 +74,7 @@ class MapStatusSuite extends SparkFunSuite { test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) { val sizes = Array.fill[Long](2001)(150L) - val status = MapStatus(null, sizes, 1) + val status = MapStatus(null, sizes) assert(status.isInstanceOf[HighlyCompressedMapStatus]) assert(status.getSizeForBlock(10) === 150L) assert(status.getSizeForBlock(50) === 150L) @@ -86,7 +86,7 @@ class MapStatusSuite extends SparkFunSuite { val sizes = Array.tabulate[Long](3000) { i => i.toLong } val avg = sizes.sum / sizes.count(_ != 0) val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes, 1) + val status = MapStatus(loc, sizes) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) assert(status1.location == loc) @@ -108,7 +108,7 @@ class MapStatusSuite extends SparkFunSuite { val smallBlockSizes = sizes.filter(n => n > 0 && n < threshold) val avg = smallBlockSizes.sum / smallBlockSizes.length val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes, 1) + val status = MapStatus(loc, sizes) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) assert(status1.location == loc) @@ -164,7 +164,7 @@ class MapStatusSuite extends SparkFunSuite { SparkEnv.set(env) // Value of element in sizes is equal to the corresponding index. val sizes = (0L to 2000L).toArray - val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes, 1) + val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes) val arrayStream = new ByteArrayOutputStream(102400) val objectOutputStream = new ObjectOutputStream(arrayStream) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) @@ -196,19 +196,19 @@ class MapStatusSuite extends SparkFunSuite { SparkEnv.set(env) val sizes = Array.fill[Long](500)(150L) // Test default value - val status = MapStatus(null, sizes, 1) + val status = MapStatus(null, sizes) assert(status.isInstanceOf[CompressedMapStatus]) // Test Non-positive values for (s <- -1 to 0) { assertThrows[IllegalArgumentException] { conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) - val status = MapStatus(null, sizes, 1) + val status = MapStatus(null, sizes) } } // Test positive values Seq(1, 100, 499, 500, 501).foreach { s => conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) - val status = MapStatus(null, sizes, 1) + val status = MapStatus(null, sizes) if(sizes.length > s) { assert(status.isInstanceOf[HighlyCompressedMapStatus]) } else { diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 36912441c03bd..ac25bcef54349 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -345,8 +345,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val denseBlockSizes = new Array[Long](5000) val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L) Seq(denseBlockSizes, sparseBlockSizes).foreach { blockSizes => - ser.serialize( - HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes, 1)) + ser.serialize(HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index cd28c733f3613..cc1a5e835d9cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.plans.physical -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{DataType, IntegerType} @@ -208,18 +206,6 @@ case object SinglePartition extends Partitioning { } } -/** - * Represents a partitioning where rows are only serialized/deserialized locally. The number - * of partitions are not changed and also the distribution of rows. This is mainly used to - * obtain some statistics of map tasks such as number of outputs. - */ -case class LocalPartitioning(childRDD: RDD[InternalRow]) extends Partitioning { - val numPartitions = childRDD.getNumPartitions - - // We will perform this partitioning no matter what the data distribution is. - override def satisfies0(required: Distribution): Boolean = false -} - /** * Represents a partitioning where rows are split up across partitions based on the hash * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a01e87c8d1dd3..da492198af5ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -255,13 +255,6 @@ object SQLConf { .intConf .createWithDefault(4) - val LIMIT_FLAT_GLOBAL_LIMIT = buildConf("spark.sql.limit.flatGlobalLimit") - .internal() - .doc("During global limit, try to evenly distribute limited rows across data " + - "partitions. If disabled, scanning data partitions sequentially until reaching limit number.") - .booleanConf - .createWithDefault(true) - val ADVANCED_PARTITION_PREDICATE_PUSHDOWN = buildConf("spark.sql.hive.advancedPartitionPredicatePushdown.enabled") .internal() @@ -1771,8 +1764,6 @@ class SQLConf extends Serializable with Logging { def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR) - def limitFlatGlobalLimit: Boolean = getConf(LIMIT_FLAT_GLOBAL_LIMIT) - def advancedPartitionPredicatePushdownEnabled: Boolean = getConf(ADVANCED_PARTITION_PREDICATE_PUSHDOWN) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 89442a70283f5..dbc6db62bd820 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -66,35 +66,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Plans special cases of limit operators. */ object SpecialLimits extends Strategy { - private def decideTopRankNode(limit: Int, child: LogicalPlan): Seq[SparkPlan] = { - if (limit < conf.topKSortFallbackThreshold) { - child match { - case Sort(order, true, child) => - TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil - case Project(projectList, Sort(order, true, child)) => - TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil - } - } else { - GlobalLimitExec(limit, - LocalLimitExec(limit, planLater(child)), - orderedLimit = true) :: Nil - } - } - override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ReturnAnswer(rootPlan) => rootPlan match { - case Limit(IntegerLiteral(limit), s @ Sort(order, true, child)) => - decideTopRankNode(limit, s) - case Limit(IntegerLiteral(limit), p @ Project(projectList, Sort(order, true, child))) => - decideTopRankNode(limit, p) + case Limit(IntegerLiteral(limit), Sort(order, true, child)) + if limit < conf.topKSortFallbackThreshold => + TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil + case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) + if limit < conf.topKSortFallbackThreshold => + TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil case Limit(IntegerLiteral(limit), child) => CollectLimitExec(limit, planLater(child)) :: Nil case other => planLater(other) :: Nil } - case Limit(IntegerLiteral(limit), s @ Sort(order, true, child)) => - decideTopRankNode(limit, s) - case Limit(IntegerLiteral(limit), p @ Project(projectList, Sort(order, true, child))) => - decideTopRankNode(limit, p) + case Limit(IntegerLiteral(limit), Sort(order, true, child)) + if limit < conf.topKSortFallbackThreshold => + TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil + case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) + if limit < conf.topKSortFallbackThreshold => + TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 9576605b1a214..aba94885f941c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -231,11 +231,6 @@ object ShuffleExchangeExec { override def numPartitions: Int = 1 override def getPartition(key: Any): Int = 0 } - case l: LocalPartitioning => - new Partitioner { - override def numPartitions: Int = l.numPartitions - override def getPartition(key: Any): Int = key.asInstanceOf[Int] - } case _ => sys.error(s"Exchange not implemented for $newPartitioning") // TODO: Handle BroadcastPartitioning. } @@ -252,9 +247,6 @@ object ShuffleExchangeExec { val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) row => projection(row).getInt(0) case RangePartitioning(_, _) | SinglePartition => identity - case _: LocalPartitioning => - val partitionId = TaskContext.get().partitionId() - _ => partitionId case _ => sys.error(s"Exchange not implemented for $newPartitioning") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 1a09632f93ca1..66bcda8913738 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -47,16 +47,13 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode } /** - * Take the first `limit` elements of each child partition, but do not collect or shuffle them. + * Helper trait which defines methods that are shared by both + * [[LocalLimitExec]] and [[GlobalLimitExec]]. */ -case class LocalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode with CodegenSupport { - +trait BaseLimitExec extends UnaryExecNode with CodegenSupport { + val limit: Int override def output: Seq[Attribute] = child.output - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - - override def outputPartitioning: Partitioning = child.outputPartitioning - protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.take(limit) } @@ -96,96 +93,25 @@ case class LocalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode wi } /** - * Take the `limit` elements of the child output. + * Take the first `limit` elements of each child partition, but do not collect or shuffle them. */ -case class GlobalLimitExec(limit: Int, child: SparkPlan, - orderedLimit: Boolean = false) extends UnaryExecNode { +case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { - override def output: Seq[Attribute] = child.output + override def outputOrdering: Seq[SortOrder] = child.outputOrdering override def outputPartitioning: Partitioning = child.outputPartitioning +} - override def outputOrdering: Seq[SortOrder] = child.outputOrdering +/** + * Take the first `limit` elements of the child's single output partition. + */ +case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { - private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) + override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil - protected override def doExecute(): RDD[InternalRow] = { - val childRDD = child.execute() - val partitioner = LocalPartitioning(childRDD) - val shuffleDependency = ShuffleExchangeExec.prepareShuffleDependency( - childRDD, child.output, partitioner, serializer) - val numberOfOutput: Seq[Long] = if (shuffleDependency.rdd.getNumPartitions != 0) { - // submitMapStage does not accept RDD with 0 partition. - // So, we will not submit this dependency. - val submittedStageFuture = sparkContext.submitMapStage(shuffleDependency) - submittedStageFuture.get().recordsByPartitionId.toSeq - } else { - Nil - } + override def outputPartitioning: Partitioning = child.outputPartitioning - // This is an optimization to evenly distribute limited rows across all partitions. - // When enabled, Spark goes to take rows at each partition repeatedly until reaching - // limit number. When disabled, Spark takes all rows at first partition, then rows - // at second partition ..., until reaching limit number. - // The optimization is disabled when it is needed to keep the original order of rows - // before global sort, e.g., select * from table order by col limit 10. - val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit && !orderedLimit - - val shuffled = new ShuffledRowRDD(shuffleDependency) - - val sumOfOutput = numberOfOutput.sum - if (sumOfOutput <= limit) { - shuffled - } else if (!flatGlobalLimit) { - var numRowTaken = 0 - val takeAmounts = numberOfOutput.map { num => - if (numRowTaken + num < limit) { - numRowTaken += num.toInt - num.toInt - } else { - val toTake = limit - numRowTaken - numRowTaken += toTake - toTake - } - } - val broadMap = sparkContext.broadcast(takeAmounts) - shuffled.mapPartitionsWithIndexInternal { case (index, iter) => - iter.take(broadMap.value(index).toInt) - } - } else { - // We try to evenly require the asked limit number of rows across all child rdd's partitions. - var rowsNeedToTake: Long = limit - val takeAmountByPartition: Array[Long] = Array.fill[Long](numberOfOutput.length)(0L) - val remainingRowsByPartition: Array[Long] = Array(numberOfOutput: _*) - - while (rowsNeedToTake > 0) { - val nonEmptyParts = remainingRowsByPartition.count(_ > 0) - // If the rows needed to take are less the number of non-empty partitions, take one row from - // each non-empty partitions until we reach `limit` rows. - // Otherwise, evenly divide the needed rows to each non-empty partitions. - val takePerPart = math.max(1, rowsNeedToTake / nonEmptyParts) - remainingRowsByPartition.zipWithIndex.foreach { case (num, index) => - // In case `rowsNeedToTake` < `nonEmptyParts`, we may run out of `rowsNeedToTake` during - // the traversal, so we need to add this check. - if (rowsNeedToTake > 0 && num > 0) { - if (num >= takePerPart) { - rowsNeedToTake -= takePerPart - takeAmountByPartition(index) += takePerPart - remainingRowsByPartition(index) -= takePerPart - } else { - rowsNeedToTake -= num - takeAmountByPartition(index) += num - remainingRowsByPartition(index) -= num - } - } - } - } - val broadMap = sparkContext.broadcast(takeAmountByPartition) - shuffled.mapPartitionsWithIndexInternal { case (index, iter) => - iter.take(broadMap.value(index).toInt) - } - } - } + override def outputOrdering: Seq[SortOrder] = child.outputOrdering } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/limit.sql b/sql/core/src/test/resources/sql-tests/inputs/limit.sql index e33cd819f281f..b4c73cf33e53a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/limit.sql @@ -1,5 +1,3 @@ --- Disable global limit parallel -set spark.sql.limit.flatGlobalLimit=false; -- limit on various data types SELECT * FROM testdata LIMIT 2; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql index a862e0985b20c..a40ee082ba3b9 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql @@ -1,9 +1,6 @@ -- A test suite for IN LIMIT in parent side, subquery, and both predicate subquery -- It includes correlated cases. --- Disable global limit optimization -set spark.sql.limit.flatGlobalLimit=false; - create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), ("val1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -100,4 +97,4 @@ WHERE t1d NOT IN (SELECT t2d LIMIT 1) GROUP BY t1b ORDER BY t1b NULLS last -LIMIT 1; +LIMIT 1; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/limit.sql.out b/sql/core/src/test/resources/sql-tests/results/limit.sql.out index 187f3bd6858fe..02fe1de84f753 100644 --- a/sql/core/src/test/resources/sql-tests/results/limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/limit.sql.out @@ -1,134 +1,126 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 15 +-- Number of queries: 14 -- !query 0 -set spark.sql.limit.flatGlobalLimit=false --- !query 0 schema -struct --- !query 0 output -spark.sql.limit.flatGlobalLimit false - - --- !query 1 SELECT * FROM testdata LIMIT 2 --- !query 1 schema +-- !query 0 schema struct --- !query 1 output +-- !query 0 output 1 1 2 2 --- !query 2 +-- !query 1 SELECT * FROM arraydata LIMIT 2 --- !query 2 schema +-- !query 1 schema struct,nestedarraycol:array>> --- !query 2 output +-- !query 1 output [1,2,3] [[1,2,3]] [2,3,4] [[2,3,4]] --- !query 3 +-- !query 2 SELECT * FROM mapdata LIMIT 2 --- !query 3 schema +-- !query 2 schema struct> --- !query 3 output +-- !query 2 output {1:"a1",2:"b1",3:"c1",4:"d1",5:"e1"} {1:"a2",2:"b2",3:"c2",4:"d2"} --- !query 4 +-- !query 3 SELECT * FROM testdata LIMIT 2 + 1 --- !query 4 schema +-- !query 3 schema struct --- !query 4 output +-- !query 3 output 1 1 2 2 3 3 --- !query 5 +-- !query 4 SELECT * FROM testdata LIMIT CAST(1 AS int) --- !query 5 schema +-- !query 4 schema struct --- !query 5 output +-- !query 4 output 1 1 --- !query 6 +-- !query 5 SELECT * FROM testdata LIMIT -1 --- !query 6 schema +-- !query 5 schema struct<> --- !query 6 output +-- !query 5 output org.apache.spark.sql.AnalysisException The limit expression must be equal to or greater than 0, but got -1; --- !query 7 +-- !query 6 SELECT * FROM testData TABLESAMPLE (-1 ROWS) --- !query 7 schema +-- !query 6 schema struct<> --- !query 7 output +-- !query 6 output org.apache.spark.sql.AnalysisException The limit expression must be equal to or greater than 0, but got -1; --- !query 8 +-- !query 7 SELECT * FROM testdata LIMIT CAST(1 AS INT) --- !query 8 schema +-- !query 7 schema struct --- !query 8 output +-- !query 7 output 1 1 --- !query 9 +-- !query 8 SELECT * FROM testdata LIMIT CAST(NULL AS INT) --- !query 9 schema +-- !query 8 schema struct<> --- !query 9 output +-- !query 8 output org.apache.spark.sql.AnalysisException The evaluated limit expression must not be null, but got CAST(NULL AS INT); --- !query 10 +-- !query 9 SELECT * FROM testdata LIMIT key > 3 --- !query 10 schema +-- !query 9 schema struct<> --- !query 10 output +-- !query 9 output org.apache.spark.sql.AnalysisException The limit expression must evaluate to a constant value, but got (testdata.`key` > 3); --- !query 11 +-- !query 10 SELECT * FROM testdata LIMIT true --- !query 11 schema +-- !query 10 schema struct<> --- !query 11 output +-- !query 10 output org.apache.spark.sql.AnalysisException The limit expression must be integer type, but got boolean; --- !query 12 +-- !query 11 SELECT * FROM testdata LIMIT 'a' --- !query 12 schema +-- !query 11 schema struct<> --- !query 12 output +-- !query 11 output org.apache.spark.sql.AnalysisException The limit expression must be integer type, but got string; --- !query 13 +-- !query 12 SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3 --- !query 13 schema +-- !query 12 schema struct --- !query 13 output +-- !query 12 output 4 --- !query 14 +-- !query 13 SELECT * FROM testdata WHERE key < 3 LIMIT ALL --- !query 14 schema +-- !query 13 schema struct --- !query 14 output +-- !query 13 output 1 1 2 2 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out index 9eb5b3383e734..71ca1f8649475 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out @@ -1,16 +1,8 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 9 +-- Number of queries: 8 -- !query 0 -set spark.sql.limit.flatGlobalLimit=false --- !query 0 schema -struct --- !query 0 output -spark.sql.limit.flatGlobalLimit false - - --- !query 1 create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), ("val1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -25,13 +17,13 @@ create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:02:00.001', date '2014-04-04'), ("val1e", 10S, null, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04') as t1(t1a, t1b, t1c, t1d, t1e, t1f, t1g, t1h, t1i) --- !query 1 schema +-- !query 0 schema struct<> --- !query 1 output +-- !query 0 output --- !query 2 +-- !query 1 create temporary view t2 as select * from values ("val2a", 6S, 12, 14L, float(15), 20D, 20E2, timestamp '2014-04-04 01:01:00.000', date '2014-04-04'), ("val1b", 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), @@ -47,13 +39,13 @@ create temporary view t2 as select * from values ("val1f", 19S, null, 19L, float(17), 25D, 26E2, timestamp '2014-10-04 01:01:00.000', date '2014-10-04'), ("val1b", null, 16, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', null) as t2(t2a, t2b, t2c, t2d, t2e, t2f, t2g, t2h, t2i) --- !query 2 schema +-- !query 1 schema struct<> --- !query 2 output +-- !query 1 output --- !query 3 +-- !query 2 create temporary view t3 as select * from values ("val3a", 6S, 12, 110L, float(15), 20D, 20E2, timestamp '2014-04-04 01:02:00.000', date '2014-04-04'), ("val3a", 6S, 12, 10L, float(15), 20D, 20E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), @@ -68,27 +60,27 @@ create temporary view t3 as select * from values ("val3b", 8S, null, 719L, float(17), 25D, 26E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), ("val3b", 8S, null, 19L, float(17), 25D, 26E2, timestamp '2015-05-04 01:02:00.000', date '2015-05-04') as t3(t3a, t3b, t3c, t3d, t3e, t3f, t3g, t3h, t3i) --- !query 3 schema +-- !query 2 schema struct<> --- !query 3 output +-- !query 2 output --- !query 4 +-- !query 3 SELECT * FROM t1 WHERE t1a IN (SELECT t2a FROM t2 WHERE t1d = t2d) LIMIT 2 --- !query 4 schema +-- !query 3 schema struct --- !query 4 output +-- !query 3 output val1b 8 16 19 17.0 25.0 2600 2014-05-04 01:01:00 2014-05-04 val1c 8 16 19 17.0 25.0 2600 2014-05-04 01:02:00.001 2014-05-05 --- !query 5 +-- !query 4 SELECT * FROM t1 WHERE t1c IN (SELECT t2c @@ -96,16 +88,16 @@ WHERE t1c IN (SELECT t2c WHERE t2b >= 8 LIMIT 2) LIMIT 4 --- !query 5 schema +-- !query 4 schema struct --- !query 5 output +-- !query 4 output val1a 16 12 10 15.0 20.0 2000 2014-07-04 01:01:00 2014-07-04 val1a 16 12 21 15.0 20.0 2000 2014-06-04 01:02:00.001 2014-06-04 val1b 8 16 19 17.0 25.0 2600 2014-05-04 01:01:00 2014-05-04 val1c 8 16 19 17.0 25.0 2600 2014-05-04 01:02:00.001 2014-05-05 --- !query 6 +-- !query 5 SELECT Count(DISTINCT( t1a )), t1b FROM t1 @@ -116,29 +108,29 @@ WHERE t1d IN (SELECT t2d GROUP BY t1b ORDER BY t1b DESC NULLS FIRST LIMIT 1 --- !query 6 schema +-- !query 5 schema struct --- !query 6 output +-- !query 5 output 1 NULL --- !query 7 +-- !query 6 SELECT * FROM t1 WHERE t1b NOT IN (SELECT t2b FROM t2 WHERE t2b > 6 LIMIT 2) --- !query 7 schema +-- !query 6 schema struct --- !query 7 output +-- !query 6 output val1a 16 12 10 15.0 20.0 2000 2014-07-04 01:01:00 2014-07-04 val1a 16 12 21 15.0 20.0 2000 2014-06-04 01:02:00.001 2014-06-04 val1a 6 8 10 15.0 20.0 2000 2014-04-04 01:00:00 2014-04-04 val1a 6 8 10 15.0 20.0 2000 2014-04-04 01:02:00.001 2014-04-04 --- !query 8 +-- !query 7 SELECT Count(DISTINCT( t1a )), t1b FROM t1 @@ -149,7 +141,7 @@ WHERE t1d NOT IN (SELECT t2d GROUP BY t1b ORDER BY t1b NULLS last LIMIT 1 --- !query 8 schema +-- !query 7 schema struct --- !query 8 output +-- !query 7 output 1 6 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index ed110f751645d..d0106c44b7db2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -557,13 +557,11 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("SPARK-18004 limit + aggregates") { - withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") { - val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value").repartition(1) - val limit2Df = df.limit(2) - checkAnswer( - limit2Df.groupBy("id").count().select($"id"), - limit2Df.select($"id")) - } + val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value") + val limit2Df = df.limit(2) + checkAnswer( + limit2Df.groupBy("id").count().select($"id"), + limit2Df.select($"id")) } test("SPARK-17237 remove backticks in a pivot result schema") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f001b138f4b8e..279b7b8d49f52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Uuid import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} -import org.apache.spark.sql.execution.{FilterExec, QueryExecution, TakeOrderedAndProjectExec, WholeStageCodegenExec} +import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.functions._ @@ -2552,26 +2552,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } - test("SPARK-25352: Ordered global limit when more than topKSortFallbackThreshold ") { - withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") { - val baseDf = spark.range(1000).toDF.repartition(3).sort("id") - - withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "100") { - val expected = baseDf.limit(99) - val takeOrderedNode1 = expected.queryExecution.executedPlan - .find(_.isInstanceOf[TakeOrderedAndProjectExec]) - assert(takeOrderedNode1.isDefined) - - val result = baseDf.limit(100) - val takeOrderedNode2 = result.queryExecution.executedPlan - .find(_.isInstanceOf[TakeOrderedAndProjectExec]) - assert(takeOrderedNode2.isEmpty) - - checkAnswer(expected, result.collect().take(99)) - } - } - } - test("SPARK-25368 Incorrect predicate pushdown returns wrong result") { def check(newCol: Column, filter: Column, result: Seq[Row]): Unit = { val df1 = spark.createDataFrame(Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 01dc28d70184e..8fcebb35a0543 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -524,15 +524,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sortTest() } - test("limit for skew dataframe") { - // Create a skew dataframe. - val df = testData.repartition(100).union(testData).limit(50) - // Because `rdd` of dataframe will add a `DeserializeToObject` on top of `GlobalLimit`, - // the `GlobalLimit` will not be replaced with `CollectLimit`. So we can test if `GlobalLimit` - // work on skew partitions. - assert(df.rdd.count() == 50L) - } - test("CTE feature") { checkAnswer( sql("with q1 as (select * from testData limit 10) select * from q1"), @@ -1944,7 +1935,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { // TODO: support subexpression elimination in whole stage codegen withSQLConf("spark.sql.codegen.wholeStage" -> "false") { // select from a table to prevent constant folding. - val df = sql("SELECT a, b from testData2 order by a, b limit 1") + val df = sql("SELECT a, b from testData2 limit 1") checkAnswer(df, Row(1, 1)) checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index c627c51655c8d..6ad025f37e440 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -55,7 +55,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { expectedPartitionStartIndices: Array[Int]): Unit = { val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map { case (bytesByPartitionId, index) => - new MapOutputStatistics(index, bytesByPartitionId, Array[Long](1)) + new MapOutputStatistics(index, bytesByPartitionId) } val estimatedPartitionStartIndices = coordinator.estimatePartitionStartIndices(mapOutputStatistics) @@ -119,8 +119,8 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0, 0) val mapOutputStatistics = Array( - new MapOutputStatistics(0, bytesByPartitionId1, Array[Long](0)), - new MapOutputStatistics(1, bytesByPartitionId2, Array[Long](0))) + new MapOutputStatistics(0, bytesByPartitionId1), + new MapOutputStatistics(1, bytesByPartitionId2)) intercept[AssertionError](coordinator.estimatePartitionStartIndices(mapOutputStatistics)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala deleted file mode 100644 index a7840a5fcfae0..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/LimitSuite.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import scala.util.Random - -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSQLContext - - -class LimitSuite extends SparkPlanTest with SharedSQLContext { - - private var rand: Random = _ - private var seed: Long = 0 - - protected override def beforeAll(): Unit = { - super.beforeAll() - seed = System.currentTimeMillis() - rand = new Random(seed) - } - - test("Produce ordered global limit if more than topKSortFallbackThreshold") { - withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "100") { - val df = LimitTest.generateRandomInputData(spark, rand).sort("a") - - val globalLimit = df.limit(99).queryExecution.executedPlan.collect { - case g: GlobalLimitExec => g - } - assert(globalLimit.size == 0) - - val topKSort = df.limit(99).queryExecution.executedPlan.collect { - case t: TakeOrderedAndProjectExec => t - } - assert(topKSort.size == 1) - - val orderedGlobalLimit = df.limit(100).queryExecution.executedPlan.collect { - case g: GlobalLimitExec => g - } - assert(orderedGlobalLimit.size == 1 && orderedGlobalLimit(0).orderedLimit == true) - } - } - - test("Ordered global limit") { - val baseDf = LimitTest.generateRandomInputData(spark, rand) - .select("a").repartition(3).sort("a") - - withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") { - val orderedGlobalLimit = GlobalLimitExec(3, baseDf.queryExecution.sparkPlan, - orderedLimit = true) - val orderedGlobalLimitResult = SparkPlanTest.executePlan(orderedGlobalLimit, spark.sqlContext) - .map(_.getInt(0)) - - val globalLimit = GlobalLimitExec(3, baseDf.queryExecution.sparkPlan, orderedLimit = false) - val globalLimitResult = SparkPlanTest.executePlan(globalLimit, spark.sqlContext) - .map(_.getInt(0)) - - // Global limit without order takes values at each partition sequentially. - // After global sort, the values in second partition must be larger than the values - // in first partition. - assert(orderedGlobalLimitResult(0) == globalLimitResult(0)) - assert(orderedGlobalLimitResult(1) < globalLimitResult(1)) - assert(orderedGlobalLimitResult(2) < globalLimitResult(2)) - } - } -} - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index b10da6c70be16..e4e224df7607f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -262,7 +262,7 @@ class PlannerSuite extends SharedSQLContext { ).queryExecution.executedPlan.collect { case exchange: ShuffleExchangeExec => exchange }.length - assert(numExchanges === 3) + assert(numExchanges === 5) } { @@ -277,7 +277,7 @@ class PlannerSuite extends SharedSQLContext { ).queryExecution.executedPlan.collect { case exchange: ShuffleExchangeExec => exchange }.length - assert(numExchanges === 3) + assert(numExchanges === 5) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala index 9322204063af3..7e317a4d80265 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala @@ -19,10 +19,9 @@ package org.apache.spark.sql.execution import scala.util.Random -import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -38,6 +37,14 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { rand = new Random(seed) } + private def generateRandomInputData(): DataFrame = { + val schema = new StructType() + .add("a", IntegerType, nullable = false) + .add("b", IntegerType, nullable = false) + val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt())) + spark.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema) + } + /** * Adds a no-op filter to the child plan in order to prevent executeCollect() from being * called directly on the child plan. @@ -48,62 +55,32 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { val sortOrder = 'a.desc :: 'b.desc :: Nil test("TakeOrderedAndProject.doExecute without project") { - withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "false") { - withClue(s"seed = $seed") { - checkThatPlansAgree( - LimitTest.generateRandomInputData(spark, rand), - input => - noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), - input => - GlobalLimitExec(limit, - LocalLimitExec(limit, - SortExec(sortOrder, true, input))), - sortAnswers = false) - } + withClue(s"seed = $seed") { + checkThatPlansAgree( + generateRandomInputData(), + input => + noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), + input => + GlobalLimitExec(limit, + LocalLimitExec(limit, + SortExec(sortOrder, true, input))), + sortAnswers = false) } } test("TakeOrderedAndProject.doExecute with project") { - withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "false") { - withClue(s"seed = $seed") { - checkThatPlansAgree( - LimitTest.generateRandomInputData(spark, rand), - input => - noOpFilter( - TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)), - input => - GlobalLimitExec(limit, - LocalLimitExec(limit, - ProjectExec(Seq(input.output.last), - SortExec(sortOrder, true, input)))), - sortAnswers = false) - } - } - } - - test("TakeOrderedAndProject.doExecute equals to ordered global limit") { - withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") { - withClue(s"seed = $seed") { - checkThatPlansAgree( - LimitTest.generateRandomInputData(spark, rand), - input => - noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), - input => - GlobalLimitExec(limit, - LocalLimitExec(limit, - SortExec(sortOrder, true, input)), orderedLimit = true), - sortAnswers = false) - } + withClue(s"seed = $seed") { + checkThatPlansAgree( + generateRandomInputData(), + input => + noOpFilter( + TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)), + input => + GlobalLimitExec(limit, + LocalLimitExec(limit, + ProjectExec(Seq(input.output.last), + SortExec(sortOrder, true, input)))), + sortAnswers = false) } } } - -object LimitTest { - def generateRandomInputData(spark: SparkSession, rand: Random): DataFrame = { - val schema = new StructType() - .add("a", IntegerType, nullable = false) - .add("b", IntegerType, nullable = false) - val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt())) - spark.createDataFrame(spark.sparkContext.parallelize(Random.shuffle(inputData), 10), schema) - } -} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index b9b2b7dbf38e8..cebaad5b4ad9b 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -40,7 +40,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled - private val originalLimitFlatGlobalLimit = TestHive.conf.limitFlatGlobalLimit private val originalSessionLocalTimeZone = TestHive.conf.sessionLocalTimeZone def testCases: Seq[(String, File)] = { @@ -60,8 +59,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Ensures that cross joins are enabled so that we can test them TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true) - // Ensure that limit operation returns rows in the same order as Hive - TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false) // Fix session local timezone to America/Los_Angeles for those timezone sensitive tests // (timestamp_*) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, "America/Los_Angeles") @@ -76,7 +73,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled) - TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, originalSessionLocalTimeZone) // For debugging dump some statistics about how much time was spent in various optimizer rules diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 16541295eb453..cc592cf6ca629 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -22,29 +22,21 @@ import scala.collection.JavaConverters._ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution} -import org.apache.spark.sql.internal.SQLConf /** * A set of test cases that validate partition and column pruning. */ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { - private val originalLimitFlatGlobalLimit = TestHive.conf.limitFlatGlobalLimit - override def beforeAll(): Unit = { super.beforeAll() TestHive.setCacheTables(false) - TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false) // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, // need to reset the environment to ensure all referenced tables in this suites are // not cached in-memory. Refer to https://issues.apache.org/jira/browse/SPARK-2283 // for details. TestHive.reset() } - override def afterAll() { - TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit) - super.afterAll() - } // Column pruning tests From edf5cc64e4bfc34643952f3a9582beca20c4bddc Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 20 Sep 2018 20:22:55 +0800 Subject: [PATCH 1658/2461] [SPARK-25460][SS] DataSourceV2: SS sources do not respect SessionConfigSupport ## What changes were proposed in this pull request? This PR proposes to respect `SessionConfigSupport` in SS datasources as well. Currently these are only respected in batch sources: https://github.com/apache/spark/blob/e06da95cd9423f55cdb154a2778b0bddf7be984c/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala#L198-L203 https://github.com/apache/spark/blob/e06da95cd9423f55cdb154a2778b0bddf7be984c/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala#L244-L249 If a developer makes a datasource V2 that supports both structured streaming and batch jobs, batch jobs respect a specific configuration, let's say, URL to connect and fetch data (which end users might not be aware of); however, structured streaming ends up with not supporting this (and should explicitly be set into options). ## How was this patch tested? Unit tests were added. Closes #22462 from HyukjinKwon/SPARK-25460. Authored-by: hyukjinkwon Signed-off-by: Wenchen Fan --- .../sql/streaming/DataStreamReader.scala | 24 ++-- .../sql/streaming/DataStreamWriter.scala | 16 ++- .../sources/StreamingDataSourceV2Suite.scala | 116 +++++++++++++++--- 3 files changed, 128 insertions(+), 28 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 2a4db4afbe005..4c7dcedafeeae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider} @@ -158,7 +159,6 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf).newInstance() - val options = new DataSourceOptions(extraOptions.asJava) // We need to generate the V1 data source so we can pass it to the V2 relation as a shim. // We can't be sure at this point whether we'll actually want to use V2, since we don't know the // writer or whether the query is continuous. @@ -173,13 +173,18 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } ds match { case s: MicroBatchReadSupportProvider => + val sessionOptions = DataSourceV2Utils.extractSessionConfigs( + ds = s, conf = sparkSession.sessionState.conf) + val options = sessionOptions ++ extraOptions + val dataSourceOptions = new DataSourceOptions(options.asJava) var tempReadSupport: MicroBatchReadSupport = null val schema = try { val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath tempReadSupport = if (userSpecifiedSchema.isDefined) { - s.createMicroBatchReadSupport(userSpecifiedSchema.get, tmpCheckpointPath, options) + s.createMicroBatchReadSupport( + userSpecifiedSchema.get, tmpCheckpointPath, dataSourceOptions) } else { - s.createMicroBatchReadSupport(tmpCheckpointPath, options) + s.createMicroBatchReadSupport(tmpCheckpointPath, dataSourceOptions) } tempReadSupport.fullSchema() } finally { @@ -192,16 +197,21 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo Dataset.ofRows( sparkSession, StreamingRelationV2( - s, source, extraOptions.toMap, + s, source, options, schema.toAttributes, v1Relation)(sparkSession)) case s: ContinuousReadSupportProvider => + val sessionOptions = DataSourceV2Utils.extractSessionConfigs( + ds = s, conf = sparkSession.sessionState.conf) + val options = sessionOptions ++ extraOptions + val dataSourceOptions = new DataSourceOptions(options.asJava) var tempReadSupport: ContinuousReadSupport = null val schema = try { val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath tempReadSupport = if (userSpecifiedSchema.isDefined) { - s.createContinuousReadSupport(userSpecifiedSchema.get, tmpCheckpointPath, options) + s.createContinuousReadSupport( + userSpecifiedSchema.get, tmpCheckpointPath, dataSourceOptions) } else { - s.createContinuousReadSupport(tmpCheckpointPath, options) + s.createContinuousReadSupport(tmpCheckpointPath, dataSourceOptions) } tempReadSupport.fullSchema() } finally { @@ -214,7 +224,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo Dataset.ofRows( sparkSession, StreamingRelationV2( - s, source, extraOptions.toMap, + s, source, options, schema.toAttributes, v1Relation)(sparkSession)) case _ => // Code path for data source v1. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 7866e4f70f14b..e9a15214d952f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources._ @@ -298,23 +299,28 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } else { val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",") + var options = extraOptions.toMap val sink = ds.newInstance() match { case w: StreamingWriteSupportProvider - if !disabledSources.contains(w.getClass.getCanonicalName) => w + if !disabledSources.contains(w.getClass.getCanonicalName) => + val sessionOptions = DataSourceV2Utils.extractSessionConfigs( + w, df.sparkSession.sessionState.conf) + options = sessionOptions ++ extraOptions + w case _ => val ds = DataSource( df.sparkSession, className = source, - options = extraOptions.toMap, + options = options, partitionColumns = normalizedParCols.getOrElse(Nil)) ds.createSink(outputMode) } df.sparkSession.sessionState.streamingQueryManager.startQuery( - extraOptions.get("queryName"), - extraOptions.get("checkpointLocation"), + options.get("queryName"), + options.get("checkpointLocation"), df, - extraOptions.toMap, + options, sink, outputMode, useTempCheckpointLocation = source == "console", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index aeef4c8fe9332..3a0e780a73915 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReaderFactory, ScanConfig, ScanConfigBuilder} import org.apache.spark.sql.sources.v2.reader.streaming._ import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport -import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} +import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, StreamTest, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -56,13 +56,19 @@ case class FakeReadSupport() extends MicroBatchReadSupport with ContinuousReadSu trait FakeMicroBatchReadSupportProvider extends MicroBatchReadSupportProvider { override def createMicroBatchReadSupport( checkpointLocation: String, - options: DataSourceOptions): MicroBatchReadSupport = FakeReadSupport() + options: DataSourceOptions): MicroBatchReadSupport = { + LastReadOptions.options = options + FakeReadSupport() + } } trait FakeContinuousReadSupportProvider extends ContinuousReadSupportProvider { override def createContinuousReadSupport( checkpointLocation: String, - options: DataSourceOptions): ContinuousReadSupport = FakeReadSupport() + options: DataSourceOptions): ContinuousReadSupport = { + LastReadOptions.options = options + FakeReadSupport() + } } trait FakeStreamingWriteSupportProvider extends StreamingWriteSupportProvider { @@ -71,16 +77,27 @@ trait FakeStreamingWriteSupportProvider extends StreamingWriteSupportProvider { schema: StructType, mode: OutputMode, options: DataSourceOptions): StreamingWriteSupport = { + LastWriteOptions.options = options throw new IllegalStateException("fake sink - cannot actually write") } } -class FakeReadMicroBatchOnly extends DataSourceRegister with FakeMicroBatchReadSupportProvider { +class FakeReadMicroBatchOnly + extends DataSourceRegister + with FakeMicroBatchReadSupportProvider + with SessionConfigSupport { override def shortName(): String = "fake-read-microbatch-only" + + override def keyPrefix: String = shortName() } -class FakeReadContinuousOnly extends DataSourceRegister with FakeContinuousReadSupportProvider { +class FakeReadContinuousOnly + extends DataSourceRegister + with FakeContinuousReadSupportProvider + with SessionConfigSupport { override def shortName(): String = "fake-read-continuous-only" + + override def keyPrefix: String = shortName() } class FakeReadBothModes extends DataSourceRegister @@ -92,8 +109,13 @@ class FakeReadNeitherMode extends DataSourceRegister { override def shortName(): String = "fake-read-neither-mode" } -class FakeWriteSupportProvider extends DataSourceRegister with FakeStreamingWriteSupportProvider { +class FakeWriteSupportProvider + extends DataSourceRegister + with FakeStreamingWriteSupportProvider + with SessionConfigSupport { override def shortName(): String = "fake-write-microbatch-continuous" + + override def keyPrefix: String = shortName() } class FakeNoWrite extends DataSourceRegister { @@ -121,6 +143,21 @@ class FakeWriteSupportProviderV1Fallback extends DataSourceRegister override def shortName(): String = "fake-write-v1-fallback" } +object LastReadOptions { + var options: DataSourceOptions = _ + + def clear(): Unit = { + options = null + } +} + +object LastWriteOptions { + var options: DataSourceOptions = _ + + def clear(): Unit = { + options = null + } +} class StreamingDataSourceV2Suite extends StreamTest { @@ -130,6 +167,11 @@ class StreamingDataSourceV2Suite extends StreamTest { spark.conf.set("spark.sql.streaming.checkpointLocation", fakeCheckpoint.getCanonicalPath) } + override def afterEach(): Unit = { + LastReadOptions.clear() + LastWriteOptions.clear() + } + val readFormats = Seq( "fake-read-microbatch-only", "fake-read-continuous-only", @@ -143,7 +185,14 @@ class StreamingDataSourceV2Suite extends StreamTest { Trigger.ProcessingTime(1000), Trigger.Continuous(1000)) - private def testPositiveCase(readFormat: String, writeFormat: String, trigger: Trigger) = { + private def testPositiveCase(readFormat: String, writeFormat: String, trigger: Trigger): Unit = { + testPositiveCaseWithQuery(readFormat, writeFormat, trigger)(() => _) + } + + private def testPositiveCaseWithQuery( + readFormat: String, + writeFormat: String, + trigger: Trigger)(check: StreamingQuery => Unit): Unit = { val query = spark.readStream .format(readFormat) .load() @@ -151,8 +200,8 @@ class StreamingDataSourceV2Suite extends StreamTest { .format(writeFormat) .trigger(trigger) .start() + check(query) query.stop() - query } private def testNegativeCase( @@ -188,19 +237,54 @@ class StreamingDataSourceV2Suite extends StreamTest { test("disabled v2 write") { // Ensure the V2 path works normally and generates a V2 sink.. - val v2Query = testPositiveCase( - "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) - assert(v2Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink - .isInstanceOf[FakeWriteSupportProviderV1Fallback]) + testPositiveCaseWithQuery( + "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) { v2Query => + assert(v2Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink + .isInstanceOf[FakeWriteSupportProviderV1Fallback]) + } // Ensure we create a V1 sink with the config. Note the config is a comma separated // list, including other fake entries. val fullSinkName = classOf[FakeWriteSupportProviderV1Fallback].getName withSQLConf(SQLConf.DISABLED_V2_STREAMING_WRITERS.key -> s"a,b,c,test,$fullSinkName,d,e") { - val v1Query = testPositiveCase( - "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) - assert(v1Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink - .isInstanceOf[FakeSink]) + testPositiveCaseWithQuery( + "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) { v1Query => + assert(v1Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink + .isInstanceOf[FakeSink]) + } + } + } + + Seq( + Tuple2(classOf[FakeReadMicroBatchOnly], Trigger.Once()), + Tuple2(classOf[FakeReadContinuousOnly], Trigger.Continuous(1000)) + ).foreach { case (source, trigger) => + test(s"SPARK-25460: session options are respected in structured streaming sources - $source") { + // `keyPrefix` and `shortName` are the same in this test case + val readSource = source.newInstance().shortName() + val writeSource = "fake-write-microbatch-continuous" + + val readOptionName = "optionA" + withSQLConf(s"spark.datasource.$readSource.$readOptionName" -> "true") { + testPositiveCaseWithQuery(readSource, writeSource, trigger) { _ => + eventually(timeout(streamingTimeout)) { + // Write options should not be set. + assert(LastWriteOptions.options.getBoolean(readOptionName, false) == false) + assert(LastReadOptions.options.getBoolean(readOptionName, false) == true) + } + } + } + + val writeOptionName = "optionB" + withSQLConf(s"spark.datasource.$writeSource.$writeOptionName" -> "true") { + testPositiveCaseWithQuery(readSource, writeSource, trigger) { _ => + eventually(timeout(streamingTimeout)) { + // Read options should not be set. + assert(LastReadOptions.options.getBoolean(writeOptionName, false) == false) + assert(LastWriteOptions.options.getBoolean(writeOptionName, false) == true) + } + } + } } } From 67f2c6a55425d0f38e26caaf7e0b665d978d0a68 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 20 Sep 2018 20:33:44 +0800 Subject: [PATCH 1659/2461] [SPARK-25417][SQL] ArrayContains function may return incorrect result when right expression is implicitly down casted ## What changes were proposed in this pull request? In ArrayContains, we currently cast the right hand side expression to match the element type of the left hand side Array. This may result in down casting and may return wrong result or questionable result. Example : ```SQL spark-sql> select array_contains(array(1), 1.34); true ``` ```SQL spark-sql> select array_contains(array(1), 'foo'); null ``` We should safely coerce both left and right hand side expressions. ## How was this patch tested? Added tests in DataFrameFunctionsSuite Closes #22408 from dilipbiswal/SPARK-25417. Authored-by: Dilip Biswal Signed-off-by: Wenchen Fan --- docs/sql-programming-guide.md | 61 ++++++++++++++++++- python/pyspark/sql/tests.py | 3 +- .../expressions/collectionOperations.scala | 28 +++++---- .../spark/sql/DataFrameFunctionsSuite.scala | 51 ++++++++++++++++ 4 files changed, 128 insertions(+), 15 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index c76f2e30e6771..d2e3ee3e77818 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1879,6 +1879,66 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.3 to 2.4 + - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below. + + + + + + + + + + + + + + + + + + + + + + + + + +
      + Query + + Result Spark 2.3 or Prior + + Result Spark 2.4 + + Remarks +
      + SELECT
      array_contains(array(1), 1.34D);
      +
      + true + + false + + In Spark 2.4, left and right parameters are promoted to array(double) and double type respectively. +
      + SELECT
      array_contains(array(1), '1');
      +
      + true + + AnalysisException is thrown since integer type can not be promoted to string type in a loss-less manner. + + Users can use explict cast +
      + SELECT
      array_contains(array(1), 'anystring');
      +
      + null + + AnalysisException is thrown since integer type can not be promoted to string type in a loss-less manner. + + Users can use explict cast +
      + - Since Spark 2.4, when there is a struct field in front of the IN operator before a subquery, the inner query must contain a struct field as well. In previous versions, instead, the fields of the struct were compared to the output of the inner query. Eg. if `a` is a `struct(a string, b int)`, in Spark 2.4 `a in (select (1 as a, 'a' as b) from range(1))` is a valid query, while `a in (select 1, 'a' from range(1))` is not. In previous version it was the opposite. - In versions 2.2.1+ and 2.3, if `spark.sql.caseSensitive` is set to true, then the `CURRENT_DATE` and `CURRENT_TIMESTAMP` functions incorrectly became case-sensitive and would resolve to columns (unless typed in lower case). In Spark 2.4 this has been fixed and the functions are no longer case-sensitive. - Since Spark 2.4, Spark will evaluate the set operations referenced in a query by following a precedence rule as per the SQL standard. If the order is not specified by parentheses, set operations are performed from left to right with the exception that all INTERSECT operations are performed before any UNION, EXCEPT or MINUS operations. The old behaviour of giving equal precedence to all the set operations are preserved under a newly added configuration `spark.sql.legacy.setopsPrecedence.enabled` with a default value of `false`. When this property is set to `true`, spark will evaluate the set operators from left to right as they appear in the query given no explicit ordering is enforced by usage of parenthesis. @@ -1912,7 +1972,6 @@ working with timestamps in `pandas_udf`s to get the best performance, see - The `percentile_approx` function previously accepted numeric type input and output double type results. Now it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles. - Since Spark 2.3, the Join/Filter's deterministic predicates that are after the first non-deterministic predicates are also pushed down/through the child operators, if possible. In prior Spark versions, these filters are not eligible for predicate pushdown. - Partition column inference previously found incorrect common type for different inferred types, for example, previously it ended up with double type as the common type for double type and date type. Now it finds the correct common type for such conflicts. The conflict resolution follows the table below: - From 77e52448e7f94aadfa852cc67084415de6ecfa7c Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 20 Sep 2018 15:46:33 -0700 Subject: [PATCH 1665/2461] [SPARK-25472][SS] Don't have legitimate stops of streams cause stream exceptions ## What changes were proposed in this pull request? Legitimate stops of streams may actually cause an exception to be captured by stream execution, because the job throws a SparkException regarding job cancellation during a stop. This PR makes the stop more graceful by swallowing this cancellation error. ## How was this patch tested? This is pretty hard to test. The existing tests should make sure that we're not swallowing other specific SparkExceptions. I've also run the `KafkaSourceStressForDontFailOnDataLossSuite`100 times, and it didn't fail, whereas it used to be flaky. Closes #22478 from brkyvz/SPARK-25472. Authored-by: Burak Yavuz Signed-off-by: Burak Yavuz --- .../execution/streaming/StreamExecution.scala | 22 ++++++++++++++----- .../continuous/ContinuousExecution.scala | 4 ++-- .../WriteToContinuousDataSourceExec.scala | 2 +- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index f6c60c1c92124..631a6eb649ffb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -30,6 +30,7 @@ import scala.util.control.NonFatal import com.google.common.util.concurrent.UncheckedExecutionException import org.apache.hadoop.fs.Path +import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -282,7 +283,7 @@ abstract class StreamExecution( // `stop()` is already called. Let `finally` finish the cleanup. } } catch { - case e if isInterruptedByStop(e) => + case e if isInterruptedByStop(e, sparkSession.sparkContext) => // interrupted by stop() updateStatusMessage("Stopped") case e: IOException if e.getMessage != null @@ -354,9 +355,9 @@ abstract class StreamExecution( } } - private def isInterruptedByStop(e: Throwable): Boolean = { + private def isInterruptedByStop(e: Throwable, sc: SparkContext): Boolean = { if (state.get == TERMINATED) { - StreamExecution.isInterruptionException(e) + StreamExecution.isInterruptionException(e, sc) } else { false } @@ -531,7 +532,7 @@ object StreamExecution { val QUERY_ID_KEY = "sql.streaming.queryId" val IS_CONTINUOUS_PROCESSING = "__is_continuous_processing" - def isInterruptionException(e: Throwable): Boolean = e match { + def isInterruptionException(e: Throwable, sc: SparkContext): Boolean = e match { // InterruptedIOException - thrown when an I/O operation is interrupted // ClosedByInterruptException - thrown when an I/O operation upon a channel is interrupted case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException => @@ -546,7 +547,18 @@ object StreamExecution { // ExecutionException, such as BiFunction.apply case e2 @ (_: UncheckedIOException | _: ExecutionException | _: UncheckedExecutionException) if e2.getCause != null => - isInterruptionException(e2.getCause) + isInterruptionException(e2.getCause, sc) + case se: SparkException => + val jobGroup = sc.getLocalProperty("spark.jobGroup.id") + if (jobGroup == null) return false + val errorMsg = se.getMessage + if (errorMsg.contains("cancelled") && errorMsg.contains(jobGroup) && se.getCause == null) { + true + } else if (se.getCause != null) { + isInterruptionException(se.getCause, sc) + } else { + false + } case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index ccca72667a217..f009c52449adc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -265,8 +265,8 @@ class ContinuousExecution( sparkSessionForQuery, lastExecution)(lastExecution.toRdd) } } catch { - case t: Throwable - if StreamExecution.isInterruptionException(t) && state.get() == RECONFIGURING => + case t: Throwable if StreamExecution.isInterruptionException(t, sparkSession.sparkContext) && + state.get() == RECONFIGURING => logInfo(s"Query $id ignoring exception from reconfiguring: $t") // interrupted by reconfiguration - swallow exception so we can restart the query } finally { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala index c216b61383856..a797ac1879f41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -57,7 +57,7 @@ case class WriteToContinuousDataSourceExec(writeSupport: StreamingWriteSupport, case cause: Throwable => cause match { // Do not wrap interruption exceptions that will be handled by streaming specially. - case _ if StreamExecution.isInterruptionException(cause) => throw cause + case _ if StreamExecution.isInterruptionException(cause, sparkContext) => throw cause // Only wrap non fatal exceptions. case NonFatal(e) => throw new SparkException("Writing job aborted.", e) case _ => throw cause From 950ab79957fc0cdc2dafac94765787e87ece9e74 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 20 Sep 2018 17:41:24 -0700 Subject: [PATCH 1666/2461] [SPARK-24777][SQL] Add write benchmark for AVRO ## What changes were proposed in this pull request? Refactor `DataSourceWriteBenchmark` and add write benchmark for AVRO. ## How was this patch tested? Build and run the benchmark. Closes #22451 from gengliangwang/avroWriteBenchmark. Authored-by: Gengliang Wang Signed-off-by: gatorsmile --- .../benchmark/AvroWriteBenchmark.scala | 40 ++++++++++ .../BuiltInDataSourceWriteBenchmark.scala | 79 +++++++++++++++++++ .../benchmark/DataSourceWriteBenchmark.scala | 75 +++--------------- 3 files changed, 131 insertions(+), 63 deletions(-) create mode 100644 external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroWriteBenchmark.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala diff --git a/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroWriteBenchmark.scala b/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroWriteBenchmark.scala new file mode 100644 index 0000000000000..df13b4a1c2d3a --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroWriteBenchmark.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +/** + * Benchmark to measure Avro data sources write performance. + * Usage: + * 1. with spark-submit: bin/spark-submit --class + * 2. with sbt: build/sbt "avro/test:runMain " + */ +object AvroWriteBenchmark extends DataSourceWriteBenchmark { + def main(args: Array[String]): Unit = { + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + Avro writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 2481 / 2499 6.3 157.8 1.0X + Output Single Double Column 2705 / 2710 5.8 172.0 0.9X + Output Int and String Column 5539 / 5639 2.8 352.2 0.4X + Output Partitions 4613 / 5004 3.4 293.3 0.5X + Output Buckets 5554 / 5561 2.8 353.1 0.4X + */ + runBenchmark("Avro") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala new file mode 100644 index 0000000000000..2de516c19da9e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.benchmark + +/** + * Benchmark to measure built-in data sources write performance. + * By default it measures 4 data source format: Parquet, ORC, JSON, CSV. Run it with spark-submit: + * spark-submit --class + * Or with sbt: + * build/sbt "sql/test:runMain " + * + * To measure specified formats, run it with arguments: + * spark-submit --class format1 [format2] [...] + * Or with sbt: + * build/sbt "sql/test:runMain format1 [format2] [...]" + */ +object BuiltInDataSourceWriteBenchmark extends DataSourceWriteBenchmark { + def main(args: Array[String]): Unit = { + val formats: Seq[String] = if (args.isEmpty) { + Seq("Parquet", "ORC", "JSON", "CSV") + } else { + args + } + + spark.conf.set("spark.sql.parquet.compression.codec", "snappy") + spark.conf.set("spark.sql.orc.compression.codec", "snappy") + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + Parquet writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 1815 / 1932 8.7 115.4 1.0X + Output Single Double Column 1877 / 1878 8.4 119.3 1.0X + Output Int and String Column 6265 / 6543 2.5 398.3 0.3X + Output Partitions 4067 / 4457 3.9 258.6 0.4X + Output Buckets 5608 / 5820 2.8 356.6 0.3X + + ORC writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 1201 / 1239 13.1 76.3 1.0X + Output Single Double Column 1542 / 1600 10.2 98.0 0.8X + Output Int and String Column 6495 / 6580 2.4 412.9 0.2X + Output Partitions 3648 / 3842 4.3 231.9 0.3X + Output Buckets 5022 / 5145 3.1 319.3 0.2X + + JSON writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 1988 / 2093 7.9 126.4 1.0X + Output Single Double Column 2854 / 2911 5.5 181.4 0.7X + Output Int and String Column 6467 / 6653 2.4 411.1 0.3X + Output Partitions 4548 / 5055 3.5 289.1 0.4X + Output Buckets 5664 / 5765 2.8 360.1 0.4X + + CSV writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 3025 / 3190 5.2 192.3 1.0X + Output Single Double Column 3575 / 3634 4.4 227.3 0.8X + Output Int and String Column 7313 / 7399 2.2 464.9 0.4X + Output Partitions 5105 / 5190 3.1 324.6 0.6X + Output Buckets 6986 / 6992 2.3 444.1 0.4X + */ + formats.foreach { format => + runBenchmark(format) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala index 2d2cdebd067c1..e3463d9e28acc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala @@ -21,25 +21,14 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Benchmark -/** - * Benchmark to measure data source write performance. - * By default it measures 4 data source format: Parquet, ORC, JSON, CSV: - * spark-submit --class - * To measure specified formats, run it with arguments: - * spark-submit --class format1 [format2] [...] - */ -object DataSourceWriteBenchmark { +trait DataSourceWriteBenchmark { val conf = new SparkConf() .setAppName("DataSourceWriteBenchmark") .setIfMissing("spark.master", "local[1]") - .set("spark.sql.parquet.compression.codec", "snappy") - .set("spark.sql.orc.compression.codec", "snappy") + .set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") val spark = SparkSession.builder.config(conf).getOrCreate() - // Set default configs. Individual cases will change them if necessary. - spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") - val tempTable = "temp" val numRows = 1024 * 1024 * 15 @@ -86,64 +75,24 @@ object DataSourceWriteBenchmark { } } - def main(args: Array[String]): Unit = { + def runBenchmark(format: String): Unit = { val tableInt = "tableInt" val tableDouble = "tableDouble" val tableIntString = "tableIntString" val tablePartition = "tablePartition" val tableBucket = "tableBucket" - val formats: Seq[String] = if (args.isEmpty) { - Seq("Parquet", "ORC", "JSON", "CSV") - } else { - args - } - /* - Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz - Parquet writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Output Single Int Column 1815 / 1932 8.7 115.4 1.0X - Output Single Double Column 1877 / 1878 8.4 119.3 1.0X - Output Int and String Column 6265 / 6543 2.5 398.3 0.3X - Output Partitions 4067 / 4457 3.9 258.6 0.4X - Output Buckets 5608 / 5820 2.8 356.6 0.3X - - ORC writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Output Single Int Column 1201 / 1239 13.1 76.3 1.0X - Output Single Double Column 1542 / 1600 10.2 98.0 0.8X - Output Int and String Column 6495 / 6580 2.4 412.9 0.2X - Output Partitions 3648 / 3842 4.3 231.9 0.3X - Output Buckets 5022 / 5145 3.1 319.3 0.2X - - JSON writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Output Single Int Column 1988 / 2093 7.9 126.4 1.0X - Output Single Double Column 2854 / 2911 5.5 181.4 0.7X - Output Int and String Column 6467 / 6653 2.4 411.1 0.3X - Output Partitions 4548 / 5055 3.5 289.1 0.4X - Output Buckets 5664 / 5765 2.8 360.1 0.4X - - CSV writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Output Single Int Column 3025 / 3190 5.2 192.3 1.0X - Output Single Double Column 3575 / 3634 4.4 227.3 0.8X - Output Int and String Column 7313 / 7399 2.2 464.9 0.4X - Output Partitions 5105 / 5190 3.1 324.6 0.6X - Output Buckets 6986 / 6992 2.3 444.1 0.4X - */ withTempTable(tempTable) { spark.range(numRows).createOrReplaceTempView(tempTable) - formats.foreach { format => - withTable(tableInt, tableDouble, tableIntString, tablePartition, tableBucket) { - val benchmark = new Benchmark(s"$format writer benchmark", numRows) - writeNumeric(tableInt, format, benchmark, "Int") - writeNumeric(tableDouble, format, benchmark, "Double") - writeIntString(tableIntString, format, benchmark) - writePartition(tablePartition, format, benchmark) - writeBucket(tableBucket, format, benchmark) - benchmark.run() - } + withTable(tableInt, tableDouble, tableIntString, tablePartition, tableBucket) { + val benchmark = new Benchmark(s"$format writer benchmark", numRows) + writeNumeric(tableInt, format, benchmark, "Int") + writeNumeric(tableDouble, format, benchmark, "Double") + writeIntString(tableIntString, format, benchmark) + writePartition(tablePartition, format, benchmark) + writeBucket(tableBucket, format, benchmark) + benchmark.run() } } } } + From 5d25e154408f71d24c4829165a16014fdacdd209 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 21 Sep 2018 10:39:45 +0800 Subject: [PATCH 1667/2461] Revert "[SPARK-23715][SQL] the input of to/from_utc_timestamp can not have timezone ## What changes were proposed in this pull request? This reverts commit 417ad92502e714da71552f64d0e1257d2fd5d3d0. We decided to keep the current behaviors unchanged and will consider whether we will deprecate the these functions in 3.0. For more details, see the discussion in https://issues.apache.org/jira/browse/SPARK-23715 ## How was this patch tested? The existing tests. Closes #22505 from gatorsmile/revertSpark-23715. Authored-by: gatorsmile Signed-off-by: Wenchen Fan --- docs/sql-programming-guide.md | 1 - .../sql/catalyst/analysis/TypeCoercion.scala | 30 +--- .../expressions/datetimeExpressions.scala | 42 ------ .../sql/catalyst/util/DateTimeUtils.scala | 22 +-- .../apache/spark/sql/internal/SQLConf.scala | 7 - .../catalyst/analysis/TypeCoercionSuite.scala | 12 +- .../resources/sql-tests/inputs/datetime.sql | 33 ----- .../sql-tests/results/datetime.sql.out | 135 +----------------- .../apache/spark/sql/DateFunctionsSuite.scala | 9 -- 9 files changed, 13 insertions(+), 278 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 8ec4865d58162..ca9d153155441 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1955,7 +1955,6 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, renaming a managed table to existing location is not allowed. An exception is thrown when attempting to rename a managed table to existing location. - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. - Since Spark 2.4, Spark has enabled non-cascading SQL cache invalidation in addition to the traditional cache invalidation mechanism. The non-cascading cache invalidation mechanism allows users to remove a cache without impacting its dependent caches. This new cache invalidation mechanism is used in scenarios where the data of the cache to be removed is still valid, e.g., calling unpersist() on a Dataset, or dropping a temporary view. This allows users to free up memory and keep the desired caches valid at the same time. - - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behavior to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. - In version 2.3 and earlier, CSV rows are considered as malformed if at least one column value in the row is malformed. CSV parser dropped such rows in the DROPMALFORMED mode or outputs an error in the FAILFAST mode. Since Spark 2.4, CSV row is considered as malformed only when it contains malformed column values requested from CSV datasource, other values can be ignored. As an example, CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selection of the id column consists of a row with one column value 1234 but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 288b6358fbff1..49d286f6cf125 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -60,7 +60,7 @@ object TypeCoercion { IfCoercion :: StackCoercion :: Division :: - new ImplicitTypeCasts(conf) :: + ImplicitTypeCasts :: DateTimeOperations :: WindowFrameCoercion :: Nil @@ -841,33 +841,12 @@ object TypeCoercion { /** * Casts types according to the expected input types for [[Expression]]s. */ - class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { - - private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING) - + object ImplicitTypeCasts extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - // Special rules for `from/to_utc_timestamp`. These 2 functions assume the input timestamp - // string is in a specific timezone, so the string itself should not contain timezone. - // TODO: We should move the type coercion logic to expressions instead of a central - // place to put all the rules. - case e: FromUTCTimestamp if e.left.dataType == StringType => - if (rejectTzInString) { - e.copy(left = StringToTimestampWithoutTimezone(e.left)) - } else { - e.copy(left = Cast(e.left, TimestampType)) - } - - case e: ToUTCTimestamp if e.left.dataType == StringType => - if (rejectTzInString) { - e.copy(left = StringToTimestampWithoutTimezone(e.left)) - } else { - e.copy(left = Cast(e.left, TimestampType)) - } - case b @ BinaryOperator(left, right) if left.dataType != right.dataType => findTightestCommonType(left.dataType, right.dataType).map { commonType => if (b.inputType.acceptsType(commonType)) { @@ -884,7 +863,7 @@ object TypeCoercion { case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => // If we cannot do the implicit cast, just use the original input. - ImplicitTypeCasts.implicitCast(in, expected).getOrElse(in) + implicitCast(in, expected).getOrElse(in) } e.withNewChildren(children) @@ -900,9 +879,6 @@ object TypeCoercion { } e.withNewChildren(children) } - } - - object ImplicitTypeCasts { /** * Given an expected data type, try to cast the expression and return the cast expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index f95798d64db19..eb78e394f9850 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1017,48 +1017,6 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S } } -/** - * A special expression used to convert the string input of `to/from_utc_timestamp` to timestamp, - * which requires the timestamp string to not have timezone information, otherwise null is returned. - */ -case class StringToTimestampWithoutTimezone(child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with ExpectsInputTypes { - - override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = - copy(timeZoneId = Option(timeZoneId)) - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - override def dataType: DataType = TimestampType - override def nullable: Boolean = true - override def toString: String = child.toString - override def sql: String = child.sql - - override def nullSafeEval(input: Any): Any = { - DateTimeUtils.stringToTimestamp( - input.asInstanceOf[UTF8String], timeZone, rejectTzInString = true).orNull - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - val tz = ctx.addReferenceObj("timeZone", timeZone) - val longOpt = ctx.freshName("longOpt") - val eval = child.genCode(ctx) - val code = code""" - |${eval.code} - |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = true; - |${CodeGenerator.JAVA_LONG} ${ev.value} = ${CodeGenerator.defaultValue(TimestampType)}; - |if (!${eval.isNull}) { - | scala.Option $longOpt = $dtu.stringToTimestamp(${eval.value}, $tz, true); - | if ($longOpt.isDefined()) { - | ${ev.value} = ((Long) $longOpt.get()).longValue(); - | ${ev.isNull} = false; - | } - |} - """.stripMargin - ev.copy(code = code) - } -} - /** * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 02813d3939796..81d7274607ac8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -300,28 +300,10 @@ object DateTimeUtils { * `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]+[h]h:[m]m` */ def stringToTimestamp(s: UTF8String): Option[SQLTimestamp] = { - stringToTimestamp(s, defaultTimeZone(), rejectTzInString = false) + stringToTimestamp(s, defaultTimeZone()) } def stringToTimestamp(s: UTF8String, timeZone: TimeZone): Option[SQLTimestamp] = { - stringToTimestamp(s, timeZone, rejectTzInString = false) - } - - /** - * Converts a timestamp string to microseconds from the unix epoch, w.r.t. the given timezone. - * Returns None if the input string is not a valid timestamp format. - * - * @param s the input timestamp string. - * @param timeZone the timezone of the timestamp string, will be ignored if the timestamp string - * already contains timezone information and `forceTimezone` is false. - * @param rejectTzInString if true, rejects timezone in the input string, i.e., if the - * timestamp string contains timezone, like `2000-10-10 00:00:00+00:00`, - * return None. - */ - def stringToTimestamp( - s: UTF8String, - timeZone: TimeZone, - rejectTzInString: Boolean): Option[SQLTimestamp] = { if (s == null) { return None } @@ -439,8 +421,6 @@ object DateTimeUtils { return None } - if (tz.isDefined && rejectTzInString) return None - val c = if (tz.isEmpty) { Calendar.getInstance(timeZone) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index da492198af5ed..083f493fc06a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1408,13 +1408,6 @@ object SQLConf { .stringConf .createWithDefault("") - val REJECT_TIMEZONE_IN_STRING = buildConf("spark.sql.function.rejectTimezoneInString") - .internal() - .doc("If true, `to_utc_timestamp` and `from_utc_timestamp` return null if the input string " + - "contains a timezone part, e.g. `2000-10-10 00:00:00+00:00`.") - .booleanConf - .createWithDefault(true) - object PartitionOverwriteMode extends Enumeration { val STATIC, DYNAMIC = Value } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 461eda4334bb9..1602f4d046118 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -680,11 +680,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for expressions that implement ExpectsInputTypes") { import TypeCoercionSuite._ - ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(TypeCoercion.ImplicitTypeCasts, AnyTypeUnaryExpression(Literal.create(null, NullType)), AnyTypeUnaryExpression(Literal.create(null, NullType))) - ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(TypeCoercion.ImplicitTypeCasts, NumericTypeUnaryExpression(Literal.create(null, NullType)), NumericTypeUnaryExpression(Literal.create(null, DoubleType))) } @@ -692,11 +692,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for binary operators") { import TypeCoercionSuite._ - ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(TypeCoercion.ImplicitTypeCasts, AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) - ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(TypeCoercion.ImplicitTypeCasts, NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType))) } @@ -976,7 +976,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("type coercion for CaseKeyWhen") { - ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(TypeCoercion.ImplicitTypeCasts, CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) @@ -1436,7 +1436,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("SPARK-17117 null type coercion in divide") { - val rules = Seq(FunctionArgumentConversion, Division, new ImplicitTypeCasts(conf)) + val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts) val nullLit = Literal.create(null, NullType) ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index 4950a4b7a4e5a..547c2bef02b24 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -27,36 +27,3 @@ select current_date = current_date(), current_timestamp = current_timestamp(), a select a, b from ttf2 order by a, current_date; select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15'); - -select from_utc_timestamp('2015-07-24 00:00:00', 'PST'); - -select from_utc_timestamp('2015-01-24 00:00:00', 'PST'); - -select from_utc_timestamp(null, 'PST'); - -select from_utc_timestamp('2015-07-24 00:00:00', null); - -select from_utc_timestamp(null, null); - -select from_utc_timestamp(cast(0 as timestamp), 'PST'); - -select from_utc_timestamp(cast('2015-01-24' as date), 'PST'); - -select to_utc_timestamp('2015-07-24 00:00:00', 'PST'); - -select to_utc_timestamp('2015-01-24 00:00:00', 'PST'); - -select to_utc_timestamp(null, 'PST'); - -select to_utc_timestamp('2015-07-24 00:00:00', null); - -select to_utc_timestamp(null, null); - -select to_utc_timestamp(cast(0 as timestamp), 'PST'); - -select to_utc_timestamp(cast('2015-01-24' as date), 'PST'); - --- SPARK-23715: the input of to/from_utc_timestamp can not have timezone -select from_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST'); - -select to_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST'); diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index 9eede305dbdcc..4e1cfa6e48c1c 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 10 -- !query 0 @@ -82,138 +82,9 @@ struct 1 2 2 3 - -- !query 9 select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15') --- !query 9 schema +-- !query 3 schema struct --- !query 9 output +-- !query 3 output 5 3 5 NULL 4 - - --- !query 10 -select from_utc_timestamp('2015-07-24 00:00:00', 'PST') --- !query 10 schema -struct --- !query 10 output -2015-07-23 17:00:00 - - --- !query 11 -select from_utc_timestamp('2015-01-24 00:00:00', 'PST') --- !query 11 schema -struct --- !query 11 output -2015-01-23 16:00:00 - - --- !query 12 -select from_utc_timestamp(null, 'PST') --- !query 12 schema -struct --- !query 12 output -NULL - - --- !query 13 -select from_utc_timestamp('2015-07-24 00:00:00', null) --- !query 13 schema -struct --- !query 13 output -NULL - - --- !query 14 -select from_utc_timestamp(null, null) --- !query 14 schema -struct --- !query 14 output -NULL - - --- !query 15 -select from_utc_timestamp(cast(0 as timestamp), 'PST') --- !query 15 schema -struct --- !query 15 output -1969-12-31 08:00:00 - - --- !query 16 -select from_utc_timestamp(cast('2015-01-24' as date), 'PST') --- !query 16 schema -struct --- !query 16 output -2015-01-23 16:00:00 - - --- !query 17 -select to_utc_timestamp('2015-07-24 00:00:00', 'PST') --- !query 17 schema -struct --- !query 17 output -2015-07-24 07:00:00 - - --- !query 18 -select to_utc_timestamp('2015-01-24 00:00:00', 'PST') --- !query 18 schema -struct --- !query 18 output -2015-01-24 08:00:00 - - --- !query 19 -select to_utc_timestamp(null, 'PST') --- !query 19 schema -struct --- !query 19 output -NULL - - --- !query 20 -select to_utc_timestamp('2015-07-24 00:00:00', null) --- !query 20 schema -struct --- !query 20 output -NULL - - --- !query 21 -select to_utc_timestamp(null, null) --- !query 21 schema -struct --- !query 21 output -NULL - - --- !query 22 -select to_utc_timestamp(cast(0 as timestamp), 'PST') --- !query 22 schema -struct --- !query 22 output -1970-01-01 00:00:00 - - --- !query 23 -select to_utc_timestamp(cast('2015-01-24' as date), 'PST') --- !query 23 schema -struct --- !query 23 output -2015-01-24 08:00:00 - - --- !query 24 -select from_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST') --- !query 24 schema -struct --- !query 24 output -NULL - - --- !query 25 -select to_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST') --- !query 25 schema -struct --- !query 25 output -NULL diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 3af80b36ec42c..c4ec7150c4075 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -23,7 +23,6 @@ import java.util.Locale import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.unsafe.types.CalendarInterval @@ -730,12 +729,4 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(Timestamp.valueOf("2015-07-24 07:00:00")), Row(Timestamp.valueOf("2015-07-24 22:00:00")))) } - - test("SPARK-23715: to/from_utc_timestamp can retain the previous behavior") { - withSQLConf(SQLConf.REJECT_TIMEZONE_IN_STRING.key -> "false") { - checkAnswer( - sql("SELECT from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')"), - Row(Timestamp.valueOf("2000-10-09 18:00:00"))) - } - } } From 596af211a5b5a7468a0e9b840561c9ae2353a29c Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Thu, 20 Sep 2018 22:15:52 -0700 Subject: [PATCH 1668/2461] [SPARK-25494][SQL] Upgrade Spark's use of Janino to 3.0.10 ## What changes were proposed in this pull request? This PR upgrades Spark's use of Janino from 3.0.9 to 3.0.10. Note that 3.0.10 is a out-of-band release specifically for fixing an integer overflow issue in Janino's `ClassFile` reader. It is otherwise exactly the same as 3.0.9, so it's a low risk and compatible upgrade. The integer overflow issue affects Spark SQL's codegen stats collection: when a generated Class file is huge, especially when the constant pool size is above `Short.MAX_VALUE`, Janino's `ClassFile reader` will throw an exception when Spark wants to parse the generated Class file to collect stats. So we'll miss the stats of some huge Class files. The related Janino issue is: https://github.com/janino-compiler/janino/issues/58 ## How was this patch tested? Existing codegen tests. Closes #22506 from rednaxelafx/upgrade-janino. Authored-by: Kris Mok Signed-off-by: gatorsmile --- dev/deps/spark-deps-hadoop-2.6 | 4 ++-- dev/deps/spark-deps-hadoop-2.7 | 4 ++-- dev/deps/spark-deps-hadoop-3.1 | 4 ++-- pom.xml | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 62ae04dbc255f..969df4d92946b 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -34,7 +34,7 @@ commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar -commons-compiler-3.0.9.jar +commons-compiler-3.0.10.jar commons-compress-1.8.1.jar commons-configuration-1.6.jar commons-crypto-1.0.0.jar @@ -98,7 +98,7 @@ jackson-module-jaxb-annotations-2.6.7.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar -janino-3.0.9.jar +janino-3.0.10.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar javax.inject-1.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index dcb5d63aeff4d..e827dc6036f85 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -34,7 +34,7 @@ commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar -commons-compiler-3.0.9.jar +commons-compiler-3.0.10.jar commons-compress-1.8.1.jar commons-configuration-1.6.jar commons-crypto-1.0.0.jar @@ -98,7 +98,7 @@ jackson-module-jaxb-annotations-2.6.7.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar -janino-3.0.9.jar +janino-3.0.10.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar javax.inject-1.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 641b4a15ad7cd..2b12c35d18e27 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -31,7 +31,7 @@ commons-beanutils-1.9.3.jar commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar -commons-compiler-3.0.9.jar +commons-compiler-3.0.10.jar commons-compress-1.8.1.jar commons-configuration2-2.1.1.jar commons-crypto-1.0.0.jar @@ -97,7 +97,7 @@ jackson-mapper-asl-1.9.13.jar jackson-module-jaxb-annotations-2.6.7.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar -janino-3.0.9.jar +janino-3.0.10.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar javax.inject-1.jar diff --git a/pom.xml b/pom.xml index 71d5f944ee60e..099a08185d2a5 100644 --- a/pom.xml +++ b/pom.xml @@ -170,7 +170,7 @@ 3.5 3.2.10 - 3.0.9 + 3.0.10 2.22.2 2.9.3 3.5.2 From 1f4ca6f5c52560585ea977bddc69243a29bf67f2 Mon Sep 17 00:00:00 2001 From: seancxmao Date: Fri, 21 Sep 2018 15:04:47 +0900 Subject: [PATCH 1669/2461] [SPARK-25487][SQL][TEST] Refactor PrimitiveArrayBenchmark ## What changes were proposed in this pull request? Refactor PrimitiveArrayBenchmark to use main method and print the output as a separate file. Run blow command to generate benchmark results: ``` SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain org.apache.spark.sql.execution.benchmark.PrimitiveArrayBenchmark" ``` ## How was this patch tested? Manual tests. Closes #22497 from seancxmao/SPARK-25487. Authored-by: seancxmao Signed-off-by: Kazuaki Ishizaki --- .../PrimitiveArrayBenchmark-results.txt | 13 +++++ .../benchmark/PrimitiveArrayBenchmark.scala | 47 +++++++++---------- 2 files changed, 35 insertions(+), 25 deletions(-) create mode 100644 sql/core/benchmarks/PrimitiveArrayBenchmark-results.txt diff --git a/sql/core/benchmarks/PrimitiveArrayBenchmark-results.txt b/sql/core/benchmarks/PrimitiveArrayBenchmark-results.txt new file mode 100644 index 0000000000000..b06b5c092b61a --- /dev/null +++ b/sql/core/benchmarks/PrimitiveArrayBenchmark-results.txt @@ -0,0 +1,13 @@ +================================================================================================ +Write primitive arrays in dataset +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_131-b11 on Mac OS X 10.13.6 +Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz + +Write an array in Dataset: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Int 437 / 529 19.2 52.1 1.0X +Double 638 / 670 13.1 76.1 0.7X + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala index e7c8f2717fd74..7f467d161081a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala @@ -17,21 +17,30 @@ package org.apache.spark.sql.execution.benchmark -import scala.concurrent.duration._ - -import org.apache.spark.SparkConf -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.util.Benchmark +import org.apache.spark.sql.SparkSession +import org.apache.spark.util.{Benchmark, BenchmarkBase => FileBenchmarkBase} /** - * Benchmark [[PrimitiveArray]] for DataFrame and Dataset program using primitive array - * To run this: - * 1. replace ignore(...) with test(...) - * 2. build/sbt "sql/test-only *benchmark.PrimitiveArrayBenchmark" - * - * Benchmarks in this file are skipped in normal builds. + * Benchmark primitive arrays via DataFrame and Dataset program using primitive arrays + * To run this benchmark: + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/PrimitiveArrayBenchmark-results.txt". */ -class PrimitiveArrayBenchmark extends BenchmarkBase { +object PrimitiveArrayBenchmark extends FileBenchmarkBase { + lazy val sparkSession = SparkSession.builder + .master("local[1]") + .appName("microbenchmark") + .config("spark.sql.shuffle.partitions", 1) + .config("spark.sql.autoBroadcastJoinThreshold", 1) + .getOrCreate() + + override def benchmark(): Unit = { + runBenchmark("Write primitive arrays in dataset") { + writeDatasetArray(4) + } + } def writeDatasetArray(iters: Int): Unit = { import sparkSession.implicits._ @@ -62,21 +71,9 @@ class PrimitiveArrayBenchmark extends BenchmarkBase { } } - val benchmark = new Benchmark("Write an array in Dataset", count * iters) + val benchmark = new Benchmark("Write an array in Dataset", count * iters, output = output) benchmark.addCase("Int ")(intArray) benchmark.addCase("Double")(doubleArray) benchmark.run - /* - OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 - Intel Xeon E3-12xx v2 (Ivy Bridge) - Write an array in Dataset: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Int 352 / 401 23.8 42.0 1.0X - Double 821 / 885 10.2 97.9 0.4X - */ - } - - ignore("Write an array in Dataset") { - writeDatasetArray(4) } } From fb3276a54a2b7339e5e0fb62fb01cbefcc330c8b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 21 Sep 2018 14:17:34 +0800 Subject: [PATCH 1670/2461] [SPARK-25384][SQL] Clarify fromJsonForceNullableSchema will be removed in Spark 3.0 See above. This should go into the 2.4 release. Closes #22509 from rxin/SPARK-25384. Authored-by: Reynold Xin Signed-off-by: Wenchen Fan --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 083f493fc06a2..d973ba029af99 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -605,7 +605,7 @@ object SQLConf { .internal() .doc("When true, force the output schema of the from_json() function to be nullable " + "(including all the fields). Otherwise, the schema might not be compatible with" + - "actual data, which leads to curruptions.") + "actual data, which leads to corruptions. This config will be removed in Spark 3.0.") .booleanConf .createWithDefault(true) From 411ecc365ea62aef7a29d8764e783e6a58dbb1d5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 21 Sep 2018 14:27:14 +0800 Subject: [PATCH 1671/2461] [SPARK-23549][SQL] Rename config spark.sql.legacy.compareDateTimestampInTimestamp ## What changes were proposed in this pull request? See title. Makes our legacy backward compatibility configs more consistent. ## How was this patch tested? Make sure all references have been updated: ``` > git grep compareDateTimestampInTimestamp docs/sql-programming-guide.md: - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.legacy.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala: // if conf.compareDateTimestampInTimestamp is true sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala: => if (conf.compareDateTimestampInTimestamp) Some(TimestampType) else Some(StringType) sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala: => if (conf.compareDateTimestampInTimestamp) Some(TimestampType) else Some(StringType) sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala: buildConf("spark.sql.legacy.compareDateTimestampInTimestamp") sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala: def compareDateTimestampInTimestamp : Boolean = getConf(COMPARE_DATE_TIMESTAMP_IN_TIMESTAMP) sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala: "spark.sql.legacy.compareDateTimestampInTimestamp" -> convertToTS.toString) { ``` Closes #22508 from rxin/SPARK-23549. Authored-by: Reynold Xin Signed-off-by: Wenchen Fan --- docs/sql-programming-guide.md | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 35 +++++++++---------- .../catalyst/analysis/TypeCoercionSuite.scala | 2 +- 3 files changed, 19 insertions(+), 20 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index ca9d153155441..0cc6a67c66029 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1950,7 +1950,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. - - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.legacy.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. - Since Spark 2.4, renaming a managed table to existing location is not allowed. An exception is thrown when attempting to rename a managed table to existing location. - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d973ba029af99..e31c536a81d2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -577,16 +577,6 @@ object SQLConf { .checkValues(HiveCaseSensitiveInferenceMode.values.map(_.toString)) .createWithDefault(HiveCaseSensitiveInferenceMode.INFER_AND_SAVE.toString) - val TYPECOERCION_COMPARE_DATE_TIMESTAMP_IN_TIMESTAMP = - buildConf("spark.sql.typeCoercion.compareDateTimestampInTimestamp") - .internal() - .doc("When true (default), compare Date with Timestamp after converting both sides to " + - "Timestamp. This behavior is compatible with Hive 2.2 or later. See HIVE-15236. " + - "When false, restore the behavior prior to Spark 2.4. Compare Date with Timestamp after " + - "converting both sides to string. This config will be removed in spark 3.0") - .booleanConf - .createWithDefault(true) - val OPTIMIZER_METADATA_ONLY = buildConf("spark.sql.optimizer.metadataOnly") .doc("When true, enable the metadata-only query optimization that use the table's metadata " + "to produce the partition columns instead of table scans. It applies when all the columns " + @@ -1476,12 +1466,6 @@ object SQLConf { .booleanConf .createWithDefault(true) - val LEGACY_SIZE_OF_NULL = buildConf("spark.sql.legacy.sizeOfNull") - .doc("If it is set to true, size of null returns -1. This behavior was inherited from Hive. " + - "The size function returns null for null input if the flag is disabled.") - .booleanConf - .createWithDefault(true) - val REPL_EAGER_EVAL_ENABLED = buildConf("spark.sql.repl.eagerEval.enabled") .doc("Enables eager evaluation or not. When true, the top K rows of Dataset will be " + "displayed if and only if the REPL supports the eager evaluation. Currently, the " + @@ -1531,6 +1515,22 @@ object SQLConf { .checkValues((1 to 9).toSet + Deflater.DEFAULT_COMPRESSION) .createWithDefault(Deflater.DEFAULT_COMPRESSION) + val COMPARE_DATE_TIMESTAMP_IN_TIMESTAMP = + buildConf("spark.sql.legacy.compareDateTimestampInTimestamp") + .internal() + .doc("When true (default), compare Date with Timestamp after converting both sides to " + + "Timestamp. This behavior is compatible with Hive 2.2 or later. See HIVE-15236. " + + "When false, restore the behavior prior to Spark 2.4. Compare Date with Timestamp after " + + "converting both sides to string. This config will be removed in Spark 3.0.") + .booleanConf + .createWithDefault(true) + + val LEGACY_SIZE_OF_NULL = buildConf("spark.sql.legacy.sizeOfNull") + .doc("If it is set to true, size of null returns -1. This behavior was inherited from Hive. " + + "The size function returns null for null input if the flag is disabled.") + .booleanConf + .createWithDefault(true) + val LEGACY_REPLACE_DATABRICKS_SPARK_AVRO_ENABLED = buildConf("spark.sql.legacy.replaceDatabricksSparkAvro.enabled") .doc("If it is set to true, the data source provider com.databricks.spark.avro is mapped " + @@ -1691,8 +1691,7 @@ class SQLConf extends Serializable with Logging { def caseSensitiveInferenceMode: HiveCaseSensitiveInferenceMode.Value = HiveCaseSensitiveInferenceMode.withName(getConf(HIVE_CASE_SENSITIVE_INFERENCE)) - def compareDateTimestampInTimestamp : Boolean = - getConf(TYPECOERCION_COMPARE_DATE_TIMESTAMP_IN_TIMESTAMP) + def compareDateTimestampInTimestamp : Boolean = getConf(COMPARE_DATE_TIMESTAMP_IN_TIMESTAMP) def gatherFastStats: Boolean = getConf(GATHER_FASTSTAT) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 1602f4d046118..0594673ecc926 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1459,7 +1459,7 @@ class TypeCoercionSuite extends AnalysisTest { DoubleType))) Seq(true, false).foreach { convertToTS => withSQLConf( - "spark.sql.typeCoercion.compareDateTimestampInTimestamp" -> convertToTS.toString) { + "spark.sql.legacy.compareDateTimestampInTimestamp" -> convertToTS.toString) { val date0301 = Literal(java.sql.Date.valueOf("2017-03-01")) val timestamp0301000000 = Literal(Timestamp.valueOf("2017-03-01 00:00:00")) val timestamp0301000001 = Literal(Timestamp.valueOf("2017-03-01 00:00:01")) From 2c9d8f56c71093faf152ca7136c5fcc4a7b2a95f Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Fri, 21 Sep 2018 18:16:54 +0900 Subject: [PATCH 1672/2461] [SPARK-25469][SQL] Eval methods of Concat, Reverse and ElementAt should use pattern matching only once ## What changes were proposed in this pull request? The PR proposes to avoid usage of pattern matching for each call of ```eval``` method within: - ```Concat``` - ```Reverse``` - ```ElementAt``` ## How was this patch tested? Run the existing tests for ```Concat```, ```Reverse``` and ```ElementAt``` expression classes. Closes #22471 from mn-mikke/SPARK-25470. Authored-by: Marek Novotny Signed-off-by: Takeshi Yamamuro --- .../expressions/collectionOperations.scala | 81 +++++++++++-------- 1 file changed, 48 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e23ebef9643ff..161adc9cc5bac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1268,11 +1268,15 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI override def dataType: DataType = child.dataType - @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + override def nullSafeEval(input: Any): Any = doReverse(input) - override def nullSafeEval(input: Any): Any = input match { - case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse) - case s: UTF8String => s.reverse() + @transient private lazy val doReverse: Any => Any = dataType match { + case ArrayType(elementType, _) => + input => { + val arrayData = input.asInstanceOf[ArrayData] + new GenericArrayData(arrayData.toObjectArray(elementType).reverse) + } + case StringType => _.asInstanceOf[UTF8String].reverse() } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -1294,6 +1298,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI val i = ctx.freshName("i") val j = ctx.freshName("j") + val elementType = dataType.asInstanceOf[ArrayType].elementType val initialization = CodeGenerator.createArrayData( arrayData, elementType, numElements, s" $prettyName failed.") val assignment = CodeGenerator.createArrayAssignment( @@ -2164,9 +2169,11 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti override def nullable: Boolean = true - override def nullSafeEval(value: Any, ordinal: Any): Any = { - left.dataType match { - case _: ArrayType => + override def nullSafeEval(value: Any, ordinal: Any): Any = doElementAt(value, ordinal) + + @transient private lazy val doElementAt: (Any, Any) => Any = left.dataType match { + case _: ArrayType => + (value, ordinal) => { val array = value.asInstanceOf[ArrayData] val index = ordinal.asInstanceOf[Int] if (array.numElements() < math.abs(index)) { @@ -2185,9 +2192,9 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti array.get(idx, dataType) } } - case _: MapType => - getValueEval(value, ordinal, mapKeyType, ordering) - } + } + case _: MapType => + (value, ordinal) => getValueEval(value, ordinal, mapKeyType, ordering) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -2278,33 +2285,41 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio override def foldable: Boolean = children.forall(_.foldable) - override def eval(input: InternalRow): Any = dataType match { + override def eval(input: InternalRow): Any = doConcat(input) + + @transient private lazy val doConcat: InternalRow => Any = dataType match { case BinaryType => - val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) - ByteArray.concat(inputs: _*) + input => { + val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) + ByteArray.concat(inputs: _*) + } case StringType => - val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) - UTF8String.concat(inputs : _*) + input => { + val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) + UTF8String.concat(inputs: _*) + } case ArrayType(elementType, _) => - val inputs = children.toStream.map(_.eval(input)) - if (inputs.contains(null)) { - null - } else { - val arrayData = inputs.map(_.asInstanceOf[ArrayData]) - val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) - if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" + - " elements due to exceeding the array size limit " + - ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".") - } - val finalData = new Array[AnyRef](numberOfElements.toInt) - var position = 0 - for(ad <- arrayData) { - val arr = ad.toObjectArray(elementType) - Array.copy(arr, 0, finalData, position, arr.length) - position += arr.length + input => { + val inputs = children.toStream.map(_.eval(input)) + if (inputs.contains(null)) { + null + } else { + val arrayData = inputs.map(_.asInstanceOf[ArrayData]) + val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) + if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" + + " elements due to exceeding the array size limit " + + ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".") + } + val finalData = new Array[AnyRef](numberOfElements.toInt) + var position = 0 + for (ad <- arrayData) { + val arr = ad.toObjectArray(elementType) + Array.copy(arr, 0, finalData, position, arr.length) + position += arr.length + } + new GenericArrayData(finalData) } - new GenericArrayData(finalData) } } From ff601cf71d226082e156c4ff9a8f5593aa7a2085 Mon Sep 17 00:00:00 2001 From: Sanket Chintapalli Date: Fri, 21 Sep 2018 09:05:56 -0500 Subject: [PATCH 1673/2461] [SPARK-24355] Spark external shuffle server improvement to better handle block fetch requests. ## What changes were proposed in this pull request? Description: Right now, the default server side netty handler threads is 2 * # cores, and can be further configured with parameter spark.shuffle.io.serverThreads. In order to process a client request, it would require one available server netty handler thread. However, when the server netty handler threads start to process ChunkFetchRequests, they will be blocked on disk I/O, mostly due to disk contentions from the random read operations initiated by all the ChunkFetchRequests received from clients. As a result, when the shuffle server is serving many concurrent ChunkFetchRequests, the server side netty handler threads could all be blocked on reading shuffle files, thus leaving no handler thread available to process other types of requests which should all be very quick to process. This issue could potentially be fixed by limiting the number of netty handler threads that could get blocked when processing ChunkFetchRequest. We have a patch to do this by using a separate EventLoopGroup with a dedicated ChannelHandler to process ChunkFetchRequest. This enables shuffle server to reserve netty handler threads for non-ChunkFetchRequest, thus enabling consistent processing time for these requests which are fast to process. After deploying the patch in our infrastructure, we no longer see timeout issues with either executor registration with local shuffle server or shuffle client establishing connection with remote shuffle server. (Please fill in changes proposed in this fix) For Original PR please refer here https://github.com/apache/spark/pull/21402 ## How was this patch tested? Unit tests and stress testing. (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22173 from redsanket/SPARK-24335. Authored-by: Sanket Chintapalli Signed-off-by: Thomas Graves --- .../spark/network/TransportContext.java | 66 ++++++++- .../server/ChunkFetchRequestHandler.java | 135 ++++++++++++++++++ .../server/TransportChannelHandler.java | 21 ++- .../server/TransportRequestHandler.java | 35 +---- .../spark/network/util/TransportConf.java | 28 ++++ .../ChunkFetchRequestHandlerSuite.java | 102 +++++++++++++ .../spark/network/ExtendedChannelPromise.java | 69 +++++++++ .../network/TransportRequestHandlerSuite.java | 55 +------ .../shuffle/ExternalShuffleClient.java | 2 +- 9 files changed, 425 insertions(+), 88 deletions(-) create mode 100644 common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java create mode 100644 common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java create mode 100644 common/network-common/src/test/java/org/apache/spark/network/ExtendedChannelPromise.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java index ae91bc9cfdd08..480b52652de53 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java @@ -21,6 +21,8 @@ import java.util.List; import io.netty.channel.Channel; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.handler.timeout.IdleStateHandler; import org.slf4j.Logger; @@ -32,11 +34,13 @@ import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.MessageDecoder; import org.apache.spark.network.protocol.MessageEncoder; +import org.apache.spark.network.server.ChunkFetchRequestHandler; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportChannelHandler; import org.apache.spark.network.server.TransportRequestHandler; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.TransportServerBootstrap; +import org.apache.spark.network.util.IOMode; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.TransportConf; import org.apache.spark.network.util.TransportFrameDecoder; @@ -61,6 +65,7 @@ public class TransportContext { private final TransportConf conf; private final RpcHandler rpcHandler; private final boolean closeIdleConnections; + private final boolean isClientOnly; /** * Force to create MessageEncoder and MessageDecoder so that we can make sure they will be created @@ -77,17 +82,54 @@ public class TransportContext { private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE; private static final MessageDecoder DECODER = MessageDecoder.INSTANCE; + // Separate thread pool for handling ChunkFetchRequest. This helps to enable throttling + // max number of TransportServer worker threads that are blocked on writing response + // of ChunkFetchRequest message back to the client via the underlying channel. + private static EventLoopGroup chunkFetchWorkers; + public TransportContext(TransportConf conf, RpcHandler rpcHandler) { - this(conf, rpcHandler, false); + this(conf, rpcHandler, false, false); } public TransportContext( TransportConf conf, RpcHandler rpcHandler, boolean closeIdleConnections) { + this(conf, rpcHandler, closeIdleConnections, false); + } + + /** + * Enables TransportContext initialization for underlying client and server. + * + * @param conf TransportConf + * @param rpcHandler RpcHandler responsible for handling requests and responses. + * @param closeIdleConnections Close idle connections if it is set to true. + * @param isClientOnly This config indicates the TransportContext is only used by a client. + * This config is more important when external shuffle is enabled. + * It stops creating extra event loop and subsequent thread pool + * for shuffle clients to handle chunked fetch requests. + */ + public TransportContext( + TransportConf conf, + RpcHandler rpcHandler, + boolean closeIdleConnections, + boolean isClientOnly) { this.conf = conf; this.rpcHandler = rpcHandler; this.closeIdleConnections = closeIdleConnections; + this.isClientOnly = isClientOnly; + + synchronized(TransportContext.class) { + if (chunkFetchWorkers == null && + conf.getModuleName() != null && + conf.getModuleName().equalsIgnoreCase("shuffle") && + !isClientOnly) { + chunkFetchWorkers = NettyUtils.createEventLoop( + IOMode.valueOf(conf.ioMode()), + conf.chunkFetchHandlerThreads(), + "shuffle-chunk-fetch-handler"); + } + } } /** @@ -144,14 +186,23 @@ public TransportChannelHandler initializePipeline( RpcHandler channelRpcHandler) { try { TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler); - channel.pipeline() + ChunkFetchRequestHandler chunkFetchHandler = + createChunkFetchHandler(channelHandler, channelRpcHandler); + ChannelPipeline pipeline = channel.pipeline() .addLast("encoder", ENCODER) .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder()) .addLast("decoder", DECODER) - .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000)) + .addLast("idleStateHandler", + new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000)) // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this // would require more logic to guarantee if this were not part of the same event loop. .addLast("handler", channelHandler); + // Use a separate EventLoopGroup to handle ChunkFetchRequest messages for shuffle rpcs. + if (conf.getModuleName() != null && + conf.getModuleName().equalsIgnoreCase("shuffle") + && !isClientOnly) { + pipeline.addLast(chunkFetchWorkers, "chunkFetchHandler", chunkFetchHandler); + } return channelHandler; } catch (RuntimeException e) { logger.error("Error while initializing Netty pipeline", e); @@ -173,5 +224,14 @@ private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler conf.connectionTimeoutMs(), closeIdleConnections); } + /** + * Creates the dedicated ChannelHandler for ChunkFetchRequest messages. + */ + private ChunkFetchRequestHandler createChunkFetchHandler(TransportChannelHandler channelHandler, + RpcHandler rpcHandler) { + return new ChunkFetchRequestHandler(channelHandler.getClient(), + rpcHandler.getStreamManager(), conf.maxChunksBeingTransferred()); + } + public TransportConf getConf() { return conf; } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java new file mode 100644 index 0000000000000..f08d8b0f984cf --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import java.net.SocketAddress; + +import com.google.common.base.Throwables; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.Encodable; + +import static org.apache.spark.network.util.NettyUtils.*; + +/** + * A dedicated ChannelHandler for processing ChunkFetchRequest messages. When sending response + * of ChunkFetchRequest messages to the clients, the thread performing the I/O on the underlying + * channel could potentially be blocked due to disk contentions. If several hundreds of clients + * send ChunkFetchRequest to the server at the same time, it could potentially occupying all + * threads from TransportServer's default EventLoopGroup for waiting for disk reads before it + * can send the block data back to the client as part of the ChunkFetchSuccess messages. As a + * result, it would leave no threads left to process other RPC messages, which takes much less + * time to process, and could lead to client timing out on either performing SASL authentication, + * registering executors, or waiting for response for an OpenBlocks messages. + */ +public class ChunkFetchRequestHandler extends SimpleChannelInboundHandler { + private static final Logger logger = LoggerFactory.getLogger(ChunkFetchRequestHandler.class); + + private final TransportClient client; + private final StreamManager streamManager; + /** The max number of chunks being transferred and not finished yet. */ + private final long maxChunksBeingTransferred; + + public ChunkFetchRequestHandler( + TransportClient client, + StreamManager streamManager, + Long maxChunksBeingTransferred) { + this.client = client; + this.streamManager = streamManager; + this.maxChunksBeingTransferred = maxChunksBeingTransferred; + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + logger.warn("Exception in connection from " + getRemoteAddress(ctx.channel()), cause); + ctx.close(); + } + + @Override + protected void channelRead0( + ChannelHandlerContext ctx, + final ChunkFetchRequest msg) throws Exception { + Channel channel = ctx.channel(); + if (logger.isTraceEnabled()) { + logger.trace("Received req from {} to fetch block {}", getRemoteAddress(channel), + msg.streamChunkId); + } + long chunksBeingTransferred = streamManager.chunksBeingTransferred(); + if (chunksBeingTransferred >= maxChunksBeingTransferred) { + logger.warn("The number of chunks being transferred {} is above {}, close the connection.", + chunksBeingTransferred, maxChunksBeingTransferred); + channel.close(); + return; + } + ManagedBuffer buf; + try { + streamManager.checkAuthorization(client, msg.streamChunkId.streamId); + streamManager.registerChannel(channel, msg.streamChunkId.streamId); + buf = streamManager.getChunk(msg.streamChunkId.streamId, msg.streamChunkId.chunkIndex); + } catch (Exception e) { + logger.error(String.format("Error opening block %s for request from %s", + msg.streamChunkId, getRemoteAddress(channel)), e); + respond(channel, new ChunkFetchFailure(msg.streamChunkId, + Throwables.getStackTraceAsString(e))); + return; + } + + streamManager.chunkBeingSent(msg.streamChunkId.streamId); + respond(channel, new ChunkFetchSuccess(msg.streamChunkId, buf)).addListener( + (ChannelFutureListener) future -> streamManager.chunkSent(msg.streamChunkId.streamId)); + } + + /** + * The invocation to channel.writeAndFlush is async, and the actual I/O on the + * channel will be handled by the EventLoop the channel is registered to. So even + * though we are processing the ChunkFetchRequest in a separate thread pool, the actual I/O, + * which is the potentially blocking call that could deplete server handler threads, is still + * being processed by TransportServer's default EventLoopGroup. In order to throttle the max + * number of threads that channel I/O for sending response to ChunkFetchRequest, the thread + * calling channel.writeAndFlush will wait for the completion of sending response back to + * client by invoking await(). This will throttle the rate at which threads from + * ChunkFetchRequest dedicated EventLoopGroup submit channel I/O requests to TransportServer's + * default EventLoopGroup, thus making sure that we can reserve some threads in + * TransportServer's default EventLoopGroup for handling other RPC messages. + */ + private ChannelFuture respond( + final Channel channel, + final Encodable result) throws InterruptedException { + final SocketAddress remoteAddress = channel.remoteAddress(); + return channel.writeAndFlush(result).await().addListener((ChannelFutureListener) future -> { + if (future.isSuccess()) { + logger.trace("Sent result {} to client {}", result, remoteAddress); + } else { + logger.error(String.format("Error sending result %s to %s; closing connection", + result, remoteAddress), future.cause()); + channel.close(); + } + }); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 56782a8327876..c824a7b0d4740 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -18,7 +18,7 @@ package org.apache.spark.network.server; import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.timeout.IdleState; import io.netty.handler.timeout.IdleStateEvent; import org.slf4j.Logger; @@ -26,6 +26,8 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportResponseHandler; +import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.RequestMessage; import org.apache.spark.network.protocol.ResponseMessage; import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; @@ -47,7 +49,7 @@ * on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not * timeout if the client is continuously sending but getting no responses, for simplicity. */ -public class TransportChannelHandler extends ChannelInboundHandlerAdapter { +public class TransportChannelHandler extends SimpleChannelInboundHandler { private static final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class); private final TransportClient client; @@ -112,8 +114,21 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception { super.channelInactive(ctx); } + /** + * Overwrite acceptInboundMessage to properly delegate ChunkFetchRequest messages + * to ChunkFetchRequestHandler. + */ @Override - public void channelRead(ChannelHandlerContext ctx, Object request) throws Exception { + public boolean acceptInboundMessage(Object msg) throws Exception { + if (msg instanceof ChunkFetchRequest) { + return false; + } else { + return super.acceptInboundMessage(msg); + } + } + + @Override + public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception { if (request instanceof RequestMessage) { requestHandler.handle((RequestMessage) request); } else if (request instanceof ResponseMessage) { diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 9fac96dbe450d..3e089b4cae273 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -24,6 +24,7 @@ import com.google.common.base.Throwables; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -97,9 +98,7 @@ public void channelInactive() { @Override public void handle(RequestMessage request) { - if (request instanceof ChunkFetchRequest) { - processFetchRequest((ChunkFetchRequest) request); - } else if (request instanceof RpcRequest) { + if (request instanceof RpcRequest) { processRpcRequest((RpcRequest) request); } else if (request instanceof OneWayMessage) { processOneWayMessage((OneWayMessage) request); @@ -112,36 +111,6 @@ public void handle(RequestMessage request) { } } - private void processFetchRequest(final ChunkFetchRequest req) { - if (logger.isTraceEnabled()) { - logger.trace("Received req from {} to fetch block {}", getRemoteAddress(channel), - req.streamChunkId); - } - long chunksBeingTransferred = streamManager.chunksBeingTransferred(); - if (chunksBeingTransferred >= maxChunksBeingTransferred) { - logger.warn("The number of chunks being transferred {} is above {}, close the connection.", - chunksBeingTransferred, maxChunksBeingTransferred); - channel.close(); - return; - } - ManagedBuffer buf; - try { - streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId); - streamManager.registerChannel(channel, req.streamChunkId.streamId); - buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex); - } catch (Exception e) { - logger.error(String.format("Error opening block %s for request from %s", - req.streamChunkId, getRemoteAddress(channel)), e); - respond(new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e))); - return; - } - - streamManager.chunkBeingSent(req.streamChunkId.streamId); - respond(new ChunkFetchSuccess(req.streamChunkId, buf)).addListener(future -> { - streamManager.chunkSent(req.streamChunkId.streamId); - }); - } - private void processStreamRequest(final StreamRequest req) { if (logger.isTraceEnabled()) { logger.trace("Received req from {} to fetch stream {}", getRemoteAddress(channel), diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 34e4bb5912dcb..6d5cccd20b333 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -21,6 +21,7 @@ import java.util.Properties; import com.google.common.primitives.Ints; +import io.netty.util.NettyRuntime; /** * A central location that tracks all the settings we expose to users. @@ -281,4 +282,31 @@ public Properties cryptoConf() { public long maxChunksBeingTransferred() { return conf.getLong("spark.shuffle.maxChunksBeingTransferred", Long.MAX_VALUE); } + + /** + * Percentage of io.serverThreads used by netty to process ChunkFetchRequest. + * Shuffle server will use a separate EventLoopGroup to process ChunkFetchRequest messages. + * Although when calling the async writeAndFlush on the underlying channel to send + * response back to client, the I/O on the channel is still being handled by + * {@link org.apache.spark.network.server.TransportServer}'s default EventLoopGroup + * that's registered with the Channel, by waiting inside the ChunkFetchRequest handler + * threads for the completion of sending back responses, we are able to put a limit on + * the max number of threads from TransportServer's default EventLoopGroup that are + * going to be consumed by writing response to ChunkFetchRequest, which are I/O intensive + * and could take long time to process due to disk contentions. By configuring a slightly + * higher number of shuffler server threads, we are able to reserve some threads for + * handling other RPC messages, thus making the Client less likely to experience timeout + * when sending RPC messages to the shuffle server. Default to 0, which is 2*#cores + * or io.serverThreads. 90 would mean 90% of 2*#cores or 90% of io.serverThreads + * which equals 0.9 * 2*#cores or 0.9 * io.serverThreads. + */ + public int chunkFetchHandlerThreads() { + if (!this.getModuleName().equalsIgnoreCase("shuffle")) { + return 0; + } + int chunkFetchHandlerThreadsPercent = + conf.getInt("spark.shuffle.server.chunkFetchHandlerThreadsPercent", 0); + return this.serverThreads() > 0 ? (this.serverThreads() * chunkFetchHandlerThreadsPercent)/100: + (2 * NettyRuntime.availableProcessors() * chunkFetchHandlerThreadsPercent)/100; + } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java new file mode 100644 index 0000000000000..2c72c53a33ae8 --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import io.netty.channel.ChannelHandlerContext; +import java.util.ArrayList; +import java.util.List; + +import io.netty.channel.Channel; +import org.apache.spark.network.server.ChunkFetchRequestHandler; +import org.junit.Test; + +import static org.mockito.Mockito.*; + +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.protocol.*; +import org.apache.spark.network.server.NoOpRpcHandler; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; + +public class ChunkFetchRequestHandlerSuite { + + @Test + public void handleChunkFetchRequest() throws Exception { + RpcHandler rpcHandler = new NoOpRpcHandler(); + OneForOneStreamManager streamManager = (OneForOneStreamManager) (rpcHandler.getStreamManager()); + Channel channel = mock(Channel.class); + ChannelHandlerContext context = mock(ChannelHandlerContext.class); + when(context.channel()) + .thenAnswer(invocationOnMock0 -> { + return channel; + }); + List> responseAndPromisePairs = + new ArrayList<>(); + when(channel.writeAndFlush(any())) + .thenAnswer(invocationOnMock0 -> { + Object response = invocationOnMock0.getArguments()[0]; + ExtendedChannelPromise channelFuture = new ExtendedChannelPromise(channel); + responseAndPromisePairs.add(ImmutablePair.of(response, channelFuture)); + return channelFuture; + }); + + // Prepare the stream. + List managedBuffers = new ArrayList<>(); + managedBuffers.add(new TestManagedBuffer(10)); + managedBuffers.add(new TestManagedBuffer(20)); + managedBuffers.add(new TestManagedBuffer(30)); + managedBuffers.add(new TestManagedBuffer(40)); + long streamId = streamManager.registerStream("test-app", managedBuffers.iterator()); + streamManager.registerChannel(channel, streamId); + TransportClient reverseClient = mock(TransportClient.class); + ChunkFetchRequestHandler requestHandler = new ChunkFetchRequestHandler(reverseClient, + rpcHandler.getStreamManager(), 2L); + + RequestMessage request0 = new ChunkFetchRequest(new StreamChunkId(streamId, 0)); + requestHandler.channelRead(context, request0); + assert responseAndPromisePairs.size() == 1; + assert responseAndPromisePairs.get(0).getLeft() instanceof ChunkFetchSuccess; + assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(0).getLeft())).body() == + managedBuffers.get(0); + + RequestMessage request1 = new ChunkFetchRequest(new StreamChunkId(streamId, 1)); + requestHandler.channelRead(context, request1); + assert responseAndPromisePairs.size() == 2; + assert responseAndPromisePairs.get(1).getLeft() instanceof ChunkFetchSuccess; + assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(1).getLeft())).body() == + managedBuffers.get(1); + + // Finish flushing the response for request0. + responseAndPromisePairs.get(0).getRight().finish(true); + + RequestMessage request2 = new ChunkFetchRequest(new StreamChunkId(streamId, 2)); + requestHandler.channelRead(context, request2); + assert responseAndPromisePairs.size() == 3; + assert responseAndPromisePairs.get(2).getLeft() instanceof ChunkFetchSuccess; + assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(2).getLeft())).body() == + managedBuffers.get(2); + + RequestMessage request3 = new ChunkFetchRequest(new StreamChunkId(streamId, 3)); + requestHandler.channelRead(context, request3); + verify(channel, times(1)).close(); + assert responseAndPromisePairs.size() == 3; + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/ExtendedChannelPromise.java b/common/network-common/src/test/java/org/apache/spark/network/ExtendedChannelPromise.java new file mode 100644 index 0000000000000..573ffd627a2e7 --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/ExtendedChannelPromise.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.util.ArrayList; +import java.util.List; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelPromise; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; + +class ExtendedChannelPromise extends DefaultChannelPromise { + + private List>> listeners = new ArrayList<>(); + private boolean success; + + ExtendedChannelPromise(Channel channel) { + super(channel); + success = false; + } + + @Override + public ChannelPromise addListener( + GenericFutureListener> listener) { + @SuppressWarnings("unchecked") + GenericFutureListener> gfListener = + (GenericFutureListener>) listener; + listeners.add(gfListener); + return super.addListener(listener); + } + + @Override + public boolean isSuccess() { + return success; + } + + @Override + public ChannelPromise await() throws InterruptedException { + return this; + } + + public void finish(boolean success) { + this.success = success; + listeners.forEach(listener -> { + try { + listener.operationComplete(this); + } catch (Exception e) { + // do nothing + } + }); + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java index 2656cbee95a20..ad640415a8e6d 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java @@ -21,10 +21,6 @@ import java.util.List; import io.netty.channel.Channel; -import io.netty.channel.ChannelPromise; -import io.netty.channel.DefaultChannelPromise; -import io.netty.util.concurrent.Future; -import io.netty.util.concurrent.GenericFutureListener; import org.junit.Test; import static org.mockito.Mockito.*; @@ -42,7 +38,7 @@ public class TransportRequestHandlerSuite { @Test - public void handleFetchRequestAndStreamRequest() throws Exception { + public void handleStreamRequest() throws Exception { RpcHandler rpcHandler = new NoOpRpcHandler(); OneForOneStreamManager streamManager = (OneForOneStreamManager) (rpcHandler.getStreamManager()); Channel channel = mock(Channel.class); @@ -68,18 +64,18 @@ public void handleFetchRequestAndStreamRequest() throws Exception { TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient, rpcHandler, 2L); - RequestMessage request0 = new ChunkFetchRequest(new StreamChunkId(streamId, 0)); + RequestMessage request0 = new StreamRequest(String.format("%d_%d", streamId, 0)); requestHandler.handle(request0); assert responseAndPromisePairs.size() == 1; - assert responseAndPromisePairs.get(0).getLeft() instanceof ChunkFetchSuccess; - assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(0).getLeft())).body() == + assert responseAndPromisePairs.get(0).getLeft() instanceof StreamResponse; + assert ((StreamResponse) (responseAndPromisePairs.get(0).getLeft())).body() == managedBuffers.get(0); - RequestMessage request1 = new ChunkFetchRequest(new StreamChunkId(streamId, 1)); + RequestMessage request1 = new StreamRequest(String.format("%d_%d", streamId, 1)); requestHandler.handle(request1); assert responseAndPromisePairs.size() == 2; - assert responseAndPromisePairs.get(1).getLeft() instanceof ChunkFetchSuccess; - assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(1).getLeft())).body() == + assert responseAndPromisePairs.get(1).getLeft() instanceof StreamResponse; + assert ((StreamResponse) (responseAndPromisePairs.get(1).getLeft())).body() == managedBuffers.get(1); // Finish flushing the response for request0. @@ -99,41 +95,4 @@ public void handleFetchRequestAndStreamRequest() throws Exception { verify(channel, times(1)).close(); assert responseAndPromisePairs.size() == 3; } - - private class ExtendedChannelPromise extends DefaultChannelPromise { - - private List>> listeners = new ArrayList<>(); - private boolean success; - - ExtendedChannelPromise(Channel channel) { - super(channel); - success = false; - } - - @Override - public ChannelPromise addListener( - GenericFutureListener> listener) { - @SuppressWarnings("unchecked") - GenericFutureListener> gfListener = - (GenericFutureListener>) listener; - listeners.add(gfListener); - return super.addListener(listener); - } - - @Override - public boolean isSuccess() { - return success; - } - - public void finish(boolean success) { - this.success = success; - listeners.forEach(listener -> { - try { - listener.operationComplete(this); - } catch (Exception e) { - // do nothing - } - }); - } - } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 9a2cf0f953481..e49e27ab5aa79 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -76,7 +76,7 @@ protected void checkInit() { @Override public void init(String appId) { this.appId = appId; - TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true); + TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true, true); List bootstraps = Lists.newArrayList(); if (authEnabled) { bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder)); From d25f425c9652a3611dd5fea8a37df4abb13e126e Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 21 Sep 2018 22:20:55 +0800 Subject: [PATCH 1674/2461] [SPARK-25499][TEST] Refactor BenchmarkBase and Benchmark ## What changes were proposed in this pull request? Currently there are two classes with the same naming BenchmarkBase: 1. `org.apache.spark.util.BenchmarkBase` 2. `org.apache.spark.sql.execution.benchmark.BenchmarkBase` This is very confusing. And the benchmark object `org.apache.spark.sql.execution.benchmark.FilterPushdownBenchmark` is using the one in `org.apache.spark.util.BenchmarkBase`, while there is another class `BenchmarkBase` in the same package of it... Here I propose: 1. the package `org.apache.spark.util.BenchmarkBase` should be in test package of core module. Move it to package `org.apache.spark.benchmark` . 2. Move `org.apache.spark.util.Benchmark` to test package of core module. Move it to package `org.apache.spark.benchmark` . 3. Rename the class `org.apache.spark.sql.execution.benchmark.BenchmarkBase` as `BenchmarkWithCodegen` ## How was this patch tested? Unit test Closes #22513 from gengliangwang/refactorBenchmarkBase. Authored-by: Gengliang Wang Signed-off-by: Wenchen Fan --- .../org/apache/spark/benchmark}/Benchmark.scala | 4 +++- .../apache/spark/benchmark}/BenchmarkBase.scala | 2 +- .../apache/spark/serializer/KryoBenchmark.scala | 2 +- .../mllib/linalg/UDTSerializationBenchmark.scala | 2 +- .../org/apache/spark/sql/HashBenchmark.scala | 2 +- .../apache/spark/sql/HashByteArrayBenchmark.scala | 2 +- .../spark/sql/UnsafeProjectionBenchmark.scala | 2 +- .../org/apache/spark/sql/DatasetBenchmark.scala | 2 +- ...xternalAppendOnlyUnsafeRowArrayBenchmark.scala | 2 +- .../execution/benchmark/AggregateBenchmark.scala | 4 ++-- .../execution/benchmark/BenchmarkWideTable.scala | 5 ++--- ...hmarkBase.scala => BenchmarkWithCodegen.scala} | 4 ++-- .../benchmark/DataSourceReadBenchmark.scala | 3 ++- .../benchmark/DataSourceWriteBenchmark.scala | 2 +- .../benchmark/FilterPushdownBenchmark.scala | 15 +++++++++------ .../sql/execution/benchmark/JoinBenchmark.scala | 2 +- .../sql/execution/benchmark/MiscBenchmark.scala | 4 ++-- .../benchmark/PrimitiveArrayBenchmark.scala | 4 ++-- .../sql/execution/benchmark/SortBenchmark.scala | 4 ++-- .../execution/benchmark/TPCDSQueryBenchmark.scala | 2 +- .../benchmark/UnsafeArrayDataBenchmark.scala | 4 ++-- .../execution/benchmark/WideSchemaBenchmark.scala | 3 ++- .../compression/CompressionSchemeBenchmark.scala | 2 +- .../execution/datasources/csv/CSVBenchmarks.scala | 3 ++- .../datasources/json/JsonBenchmarks.scala | 3 ++- .../vectorized/ColumnarBatchBenchmark.scala | 2 +- .../ObjectHashAggregateExecBenchmark.scala | 4 ++-- .../spark/sql/hive/orc/OrcReadBenchmark.scala | 4 ++-- 28 files changed, 51 insertions(+), 43 deletions(-) rename core/src/{main/scala/org/apache/spark/util => test/scala/org/apache/spark/benchmark}/Benchmark.scala (99%) rename core/src/{main/scala/org/apache/spark/util => test/scala/org/apache/spark/benchmark}/BenchmarkBase.scala (98%) rename sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/{BenchmarkBase.scala => BenchmarkWithCodegen.scala} (94%) diff --git a/core/src/main/scala/org/apache/spark/util/Benchmark.scala b/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala similarity index 99% rename from core/src/main/scala/org/apache/spark/util/Benchmark.scala rename to core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala index 7def44bd2a2b1..7a36b5f02dc4c 100644 --- a/core/src/main/scala/org/apache/spark/util/Benchmark.scala +++ b/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.util +package org.apache.spark.benchmark import java.io.{OutputStream, PrintStream} @@ -27,6 +27,8 @@ import scala.util.Try import org.apache.commons.io.output.TeeOutputStream import org.apache.commons.lang3.SystemUtils +import org.apache.spark.util.Utils + /** * Utility class to benchmark components. An example of how to use this is: * val benchmark = new Benchmark("My Benchmark", valuesPerIteration) diff --git a/core/src/main/scala/org/apache/spark/util/BenchmarkBase.scala b/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala similarity index 98% rename from core/src/main/scala/org/apache/spark/util/BenchmarkBase.scala rename to core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala index c84032b8726db..9a37e0221b27b 100644 --- a/core/src/main/scala/org/apache/spark/util/BenchmarkBase.scala +++ b/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.util +package org.apache.spark.benchmark import java.io.{File, FileOutputStream, OutputStream} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoBenchmark.scala b/core/src/test/scala/org/apache/spark/serializer/KryoBenchmark.scala index a1cf3570a7a6d..f4fc0080f3108 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoBenchmark.scala @@ -21,8 +21,8 @@ import scala.reflect.ClassTag import scala.util.Random import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.benchmark.Benchmark import org.apache.spark.serializer.KryoTest._ -import org.apache.spark.util.Benchmark class KryoBenchmark extends SparkFunSuite { val benchmark = new Benchmark("Benchmark Kryo Unsafe vs safe Serialization", 1024 * 1024 * 15, 10) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala index 5973479dfb5ed..e2976e1ab022b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala @@ -17,8 +17,8 @@ package org.apache.spark.mllib.linalg +import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.util.Benchmark /** * Serialization benchmark for VectorUDT. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala index 9a89e6290e695..7a2a66c9b1d33 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql +import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection import org.apache.spark.sql.types._ -import org.apache.spark.util.Benchmark /** * Benchmark for the previous interpreted hash function(InternalRow.hashCode) vs codegened diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala index f6c8111f5bc57..a60eb20d9edef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql import java.util.Random +import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.catalyst.expressions.{HiveHasher, XXH64} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 -import org.apache.spark.util.Benchmark /** * Synthetic benchmark for MurMurHash 3 and xxHash64. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala index 6c63769945312..faff681e13955 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql +import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.types._ -import org.apache.spark.util.Benchmark /** * Benchmark `UnsafeProjection` for fixed-length/primitive-type fields. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala index 1a0672b8876da..fa2f0b6ba61d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.StringType -import org.apache.spark.util.Benchmark /** * Benchmark for Dataset typed operations comparing with DataFrame and RDD versions. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala index 59397dbcb1cab..611b2fc037f3d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala @@ -20,10 +20,10 @@ package org.apache.spark.sql.execution import scala.collection.mutable.ArrayBuffer import org.apache.spark.{SparkConf, SparkContext, SparkEnv, TaskContext} +import org.apache.spark.benchmark.Benchmark import org.apache.spark.internal.config import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.util.Benchmark import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter object ExternalAppendOnlyUnsafeRowArrayBenchmark { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index 8f4ee8533e599..57a6fdb800ea4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.benchmark import java.util.HashMap import org.apache.spark.SparkConf +import org.apache.spark.benchmark.Benchmark import org.apache.spark.internal.config._ import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -30,7 +31,6 @@ import org.apache.spark.sql.types.{LongType, StructType} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 import org.apache.spark.unsafe.map.BytesToBytesMap -import org.apache.spark.util.Benchmark /** * Benchmark to measure performance for aggregate primitives. @@ -39,7 +39,7 @@ import org.apache.spark.util.Benchmark * * Benchmarks in this file are skipped in normal builds. */ -class AggregateBenchmark extends BenchmarkBase { +class AggregateBenchmark extends BenchmarkWithCodegen { ignore("aggregate without grouping") { val N = 500L << 22 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWideTable.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWideTable.scala index 9dcaca0ca93ee..76367cbbe5342 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWideTable.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWideTable.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution.benchmark -import org.apache.spark.util.Benchmark - +import org.apache.spark.benchmark.Benchmark /** * Benchmark to measure performance for wide table. @@ -27,7 +26,7 @@ import org.apache.spark.util.Benchmark * * Benchmarks in this file are skipped in normal builds. */ -class BenchmarkWideTable extends BenchmarkBase { +class BenchmarkWideTable extends BenchmarkWithCodegen { ignore("project on wide table") { val N = 1 << 20 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWithCodegen.scala similarity index 94% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWithCodegen.scala index c99a5aec1cd6e..51331500479a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWithCodegen.scala @@ -18,14 +18,14 @@ package org.apache.spark.sql.execution.benchmark import org.apache.spark.SparkFunSuite +import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.SparkSession -import org.apache.spark.util.Benchmark /** * Common base trait for micro benchmarks that are supposed to run standalone (i.e. not together * with other test suites). */ -private[benchmark] trait BenchmarkBase extends SparkFunSuite { +private[benchmark] trait BenchmarkWithCodegen extends SparkFunSuite { lazy val sparkSession = SparkSession.builder .master("local[1]") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala index 8711f5a8fa1ce..cf9bda2fb1ff1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala @@ -22,13 +22,14 @@ import scala.collection.JavaConverters._ import scala.util.{Random, Try} import org.apache.spark.SparkConf +import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.parquet.{SpecificParquetRecordReaderBase, VectorizedParquetRecordReader} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnVector -import org.apache.spark.util.{Benchmark, Utils} +import org.apache.spark.util.Utils /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala index e3463d9e28acc..994d6b5b7d334 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.benchmark import org.apache.spark.SparkConf +import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.SparkSession import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.Benchmark trait DataSourceWriteBenchmark { val conf = new SparkConf() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala index 9ecea99f12895..3b7f10783b64c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala @@ -22,22 +22,25 @@ import java.io.File import scala.util.{Random, Try} import org.apache.spark.SparkConf +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.functions.monotonically_increasing_id import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, TimestampType} -import org.apache.spark.util.{Benchmark, BenchmarkBase => FileBenchmarkBase, Utils} +import org.apache.spark.util.Utils /** * Benchmark to measure read performance with Filter pushdown. * To run this benchmark: - * 1. without sbt: bin/spark-submit --class - * 2. build/sbt "sql/test:runMain " - * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " - * Results will be written to "benchmarks/FilterPushdownBenchmark-results.txt". + * {{{ + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/FilterPushdownBenchmark-results.txt". + * }}} */ -object FilterPushdownBenchmark extends FileBenchmarkBase { +object FilterPushdownBenchmark extends BenchmarkBase { private val conf = new SparkConf() .setAppName(this.getClass.getSimpleName) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala index 5a25d72308370..37744dccc06f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.IntegerType * * Benchmarks in this file are skipped in normal builds. */ -class JoinBenchmark extends BenchmarkBase { +class JoinBenchmark extends BenchmarkWithCodegen { ignore("broadcast hash join, long key") { val N = 20 << 20 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala index f039aeaad442c..f44da242e62b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.benchmark -import org.apache.spark.util.Benchmark +import org.apache.spark.benchmark.Benchmark /** * Benchmark to measure whole stage codegen performance. @@ -26,7 +26,7 @@ import org.apache.spark.util.Benchmark * * Benchmarks in this file are skipped in normal builds. */ -class MiscBenchmark extends BenchmarkBase { +class MiscBenchmark extends BenchmarkWithCodegen { ignore("filter & aggregate without group") { val N = 500L << 22 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala index 7f467d161081a..8b275188f06d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.execution.benchmark +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.sql.SparkSession -import org.apache.spark.util.{Benchmark, BenchmarkBase => FileBenchmarkBase} /** * Benchmark primitive arrays via DataFrame and Dataset program using primitive arrays @@ -28,7 +28,7 @@ import org.apache.spark.util.{Benchmark, BenchmarkBase => FileBenchmarkBase} * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " * Results will be written to "benchmarks/PrimitiveArrayBenchmark-results.txt". */ -object PrimitiveArrayBenchmark extends FileBenchmarkBase { +object PrimitiveArrayBenchmark extends BenchmarkBase { lazy val sparkSession = SparkSession.builder .master("local[1]") .appName("microbenchmark") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala index 50ae26a3ff9d9..17619ec5fadc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.benchmark import java.util.{Arrays, Comparator} +import org.apache.spark.benchmark.Benchmark import org.apache.spark.unsafe.array.LongArray import org.apache.spark.unsafe.memory.MemoryBlock -import org.apache.spark.util.Benchmark import org.apache.spark.util.collection.Sorter import org.apache.spark.util.collection.unsafe.sort._ import org.apache.spark.util.random.XORShiftRandom @@ -33,7 +33,7 @@ import org.apache.spark.util.random.XORShiftRandom * * Benchmarks in this file are skipped in normal builds. */ -class SortBenchmark extends BenchmarkBase { +class SortBenchmark extends BenchmarkWithCodegen { private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) { val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala index fccee97820e75..2d72b1c14af7d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.execution.benchmark import org.apache.spark.SparkConf +import org.apache.spark.benchmark.Benchmark import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.util.Benchmark /** * Benchmark to measure TPCDS query performance. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala index 6c7779b5790d0..51ab0e13a98a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.execution.benchmark import scala.util.Random +import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeArrayWriter} -import org.apache.spark.util.Benchmark /** * Benchmark [[UnsafeArrayDataBenchmark]] for UnsafeArrayData @@ -32,7 +32,7 @@ import org.apache.spark.util.Benchmark * * Benchmarks in this file are skipped in normal builds. */ -class UnsafeArrayDataBenchmark extends BenchmarkBase { +class UnsafeArrayDataBenchmark extends BenchmarkWithCodegen { def calculateHeaderPortionInBytes(count: Int) : Int = { /* 4 + 4 * count // Use this expression for SPARK-15962 */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala index c368f17a84364..81017a6d244f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala @@ -22,8 +22,9 @@ import java.io.{File, FileOutputStream, OutputStream} import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite +import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.functions._ -import org.apache.spark.util.{Benchmark, Utils} +import org.apache.spark.util.Utils /** * Benchmark for performance with very wide and nested DataFrames. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala index 619b76fabdd5e..9c26d67b62ccc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala @@ -23,10 +23,10 @@ import java.nio.charset.StandardCharsets import org.apache.commons.lang3.RandomStringUtils import org.apache.commons.math3.distribution.LogNormalDistribution +import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar.{BOOLEAN, INT, LONG, NativeColumnType, SHORT, STRING} import org.apache.spark.sql.types.AtomicType -import org.apache.spark.util.Benchmark import org.apache.spark.util.Utils._ /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala index 24f5f55d55485..6d319eb723d93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.execution.datasources.csv import java.io.File import org.apache.spark.SparkConf +import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.{Column, Row, SparkSession} import org.apache.spark.sql.functions.lit import org.apache.spark.sql.types._ -import org.apache.spark.util.{Benchmark, Utils} +import org.apache.spark.util.Utils /** * Benchmark to measure CSV read/write performance. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala index a2b747eaab411..e40cb9b50148b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.execution.datasources.json import java.io.File import org.apache.spark.SparkConf +import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.functions.lit import org.apache.spark.sql.types._ -import org.apache.spark.util.{Benchmark, Utils} +import org.apache.spark.util.Utils /** * The benchmarks aims to measure performance of JSON parsing when encoding is set and isn't. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 8aeb06d428951..d69cf1126868e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -21,11 +21,11 @@ import java.nio.charset.StandardCharsets import scala.util.Random +import org.apache.spark.benchmark.Benchmark import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType} import org.apache.spark.unsafe.Platform -import org.apache.spark.util.Benchmark import org.apache.spark.util.collection.BitSet /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala index e599d1ab1d486..3b33785cdfbb2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala @@ -21,6 +21,7 @@ import scala.concurrent.duration._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFPercentileApprox +import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogFunction @@ -31,9 +32,8 @@ import org.apache.spark.sql.hive.execution.TestingTypedCount import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.LongType -import org.apache.spark.util.Benchmark -class ObjectHashAggregateExecBenchmark extends BenchmarkBase with TestHiveSingleton { +class ObjectHashAggregateExecBenchmark extends BenchmarkWithCodegen with TestHiveSingleton { ignore("Hive UDAF vs Spark AF") { val N = 2 << 15 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala index bf6efa7c4c08c..0eab7d1ea8e80 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala @@ -22,11 +22,11 @@ import java.io.File import scala.util.{Random, Try} import org.apache.spark.SparkConf +import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.util.{Benchmark, Utils} - +import org.apache.spark.util.Utils /** * Benchmark to measure ORC read performance. From 4a11209539130c6a075119bf87c5ad854d42978e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 21 Sep 2018 09:45:41 -0700 Subject: [PATCH 1675/2461] [SPARK-19724][SQL] allowCreatingManagedTableUsingNonemptyLocation should have legacy prefix One more legacy config to go ... Closes #22515 from rxin/allowCreatingManagedTableUsingNonemptyLocation. Authored-by: Reynold Xin Signed-off-by: gatorsmile --- docs/sql-programming-guide.md | 2 +- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 0cc6a67c66029..c72fa3d75d67f 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1951,7 +1951,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.legacy.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. - - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.legacy.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. - Since Spark 2.4, renaming a managed table to existing location is not allowed. An exception is thrown when attempting to rename a managed table to existing location. - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. - Since Spark 2.4, Spark has enabled non-cascading SQL cache invalidation in addition to the traditional cache invalidation mechanism. The non-cascading cache invalidation mechanism allows users to remove a cache without impacting its dependent caches. This new cache invalidation mechanism is used in scenarios where the data of the cache to be removed is still valid, e.g., calling unpersist() on a Dataset, or dropping a temporary view. This allows users to free up memory and keep the desired caches valid at the same time. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e31c536a81d2d..ddf17fa88c76b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1358,7 +1358,7 @@ object SQLConf { .createWithDefault(false) val ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION = - buildConf("spark.sql.allowCreatingManagedTableUsingNonemptyLocation") + buildConf("spark.sql.legacy.allowCreatingManagedTableUsingNonemptyLocation") .internal() .doc("When this option is set to true, creating managed tables with nonempty location " + "is allowed. Otherwise, an analysis exception is thrown. ") From 40edab209bdefe793b59b650099cea026c244484 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 21 Sep 2018 13:08:01 -0700 Subject: [PATCH 1676/2461] [SPARK-25321][ML] Fix local LDA model constructor ## What changes were proposed in this pull request? change back the constructor to: ``` class LocalLDAModel private[ml] ( uid: String, vocabSize: Int, private[clustering] val oldLocalModel : OldLocalLDAModel, sparkSession: SparkSession) ``` Although it is marked `private[ml]`, it is used in `mleap` and the master change breaks `mleap` building. See mleap code [here](https://github.com/combust/mleap/blob/c7860af328d519cf56441b4a7cd8e6ec9d9fee59/mleap-spark/src/main/scala/org/apache/spark/ml/bundle/ops/clustering/LDAModelOp.scala#L57) ## How was this patch tested? Manual. Closes #22510 from WeichenXu123/LDA_fix. Authored-by: WeichenXu Signed-off-by: Xiangrui Meng --- .../src/main/scala/org/apache/spark/ml/clustering/LDA.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 50867f776c522..84e73dc19a392 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -570,13 +570,11 @@ abstract class LDAModel private[ml] ( class LocalLDAModel private[ml] ( uid: String, vocabSize: Int, - private[clustering] val oldLocalModel_ : OldLocalLDAModel, + private[clustering] val oldLocalModel : OldLocalLDAModel, sparkSession: SparkSession) extends LDAModel(uid, vocabSize, sparkSession) { - override private[clustering] def oldLocalModel: OldLocalLDAModel = { - oldLocalModel_.setSeed(getSeed) - } + oldLocalModel.setSeed(getSeed) @Since("1.6.0") override def copy(extra: ParamMap): LocalLDAModel = { From 6ca87eb2e0c60baa5faec91a12240ac50a248e72 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sat, 22 Sep 2018 09:44:46 -0700 Subject: [PATCH 1677/2461] [SPARK-25465][TEST] Refactor Parquet test suites in project Hive ## What changes were proposed in this pull request? Current the file [parquetSuites.scala](https://github.com/apache/spark/blob/f29c2b5287563c0d6f55f936bd5a75707d7b2b1f/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala) is not recognizable. When I tried to find test suites for built-in Parquet conversions for Hive serde, I can only find [HiveParquetSuite](https://github.com/apache/spark/blob/f29c2b5287563c0d6f55f936bd5a75707d7b2b1f/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala) in the first few minutes. This PR is to: 1. Rename `ParquetMetastoreSuite` to `HiveParquetMetastoreSuite`, and create a single file for it. 2. Rename `ParquetSourceSuite` to `HiveParquetSourceSuite`, and create a single file for it. 3. Create a single file for `ParquetPartitioningTest`. 4. Delete `parquetSuites.scala` . ## How was this patch tested? Unit test Closes #22467 from gengliangwang/refactor_parquet_suites. Authored-by: Gengliang Wang Signed-off-by: gatorsmile --- ....scala => HiveParquetMetastoreSuite.scala} | 586 +++--------------- .../sql/hive/HiveParquetSourceSuite.scala | 225 +++++++ .../sql/hive/ParquetPartitioningTest.scala | 252 ++++++++ 3 files changed, 559 insertions(+), 504 deletions(-) rename sql/hive/src/test/scala/org/apache/spark/sql/hive/{parquetSuites.scala => HiveParquetMetastoreSuite.scala} (55%) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetPartitioningTest.scala diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetMetastoreSuite.scala similarity index 55% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetMetastoreSuite.scala index e82d457eee394..0d4f040156084 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetMetastoreSuite.scala @@ -19,44 +19,18 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.spark.sql._ +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.DataSourceScanExec -import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InsertIntoHadoopFsRelationCommand, LogicalRelation} import org.apache.spark.sql.hive.execution.HiveTableScanExec -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils - -// The data where the partitioning key exists only in the directory structure. -case class ParquetData(intField: Int, stringField: String) -// The data that also includes the partitioning key -case class ParquetDataWithKey(p: Int, intField: Int, stringField: String) - -case class StructContainer(intStructField: Int, stringStructField: String) - -case class ParquetDataWithComplexTypes( - intField: Int, - stringField: String, - structField: StructContainer, - arrayField: Seq[Int]) - -case class ParquetDataWithKeyAndComplexTypes( - p: Int, - intField: Int, - stringField: String, - structField: StructContainer, - arrayField: Seq[Int]) /** * A suite to test the automatic conversion of metastore tables with parquet data to use the * built in parquet support. */ -class ParquetMetastoreSuite extends ParquetPartitioningTest { +class HiveParquetMetastoreSuite extends ParquetPartitioningTest { import hiveContext._ import spark.implicits._ @@ -70,78 +44,83 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { "jt", "jt_array", "test_parquet") - sql(s""" - create external table partitioned_parquet - ( - intField INT, - stringField STRING - ) - PARTITIONED BY (p int) - ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - STORED AS - INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - location '${partitionedTableDir.toURI}' - """) - - sql(s""" - create external table partitioned_parquet_with_key - ( - intField INT, - stringField STRING - ) - PARTITIONED BY (p int) - ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - STORED AS - INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - location '${partitionedTableDirWithKey.toURI}' - """) - - sql(s""" - create external table normal_parquet - ( - intField INT, - stringField STRING - ) - ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - STORED AS - INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - location '${new File(normalTableDir, "normal").toURI}' - """) - - sql(s""" - CREATE EXTERNAL TABLE partitioned_parquet_with_complextypes - ( - intField INT, - stringField STRING, - structField STRUCT, - arrayField ARRAY - ) - PARTITIONED BY (p int) - ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - STORED AS - INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - LOCATION '${partitionedTableDirWithComplexTypes.toURI}' - """) - - sql(s""" - CREATE EXTERNAL TABLE partitioned_parquet_with_key_and_complextypes - ( - intField INT, - stringField STRING, - structField STRUCT, - arrayField ARRAY - ) - PARTITIONED BY (p int) - ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - STORED AS - INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - LOCATION '${partitionedTableDirWithKeyAndComplexTypes.toURI}' - """) + sql( + s""" + |create external table partitioned_parquet + |( + | intField INT, + | stringField STRING + |) + |PARTITIONED BY (p int) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + | STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + |location '${partitionedTableDir.toURI}' + """.stripMargin) + + sql( + s""" + |create external table partitioned_parquet_with_key + |( + | intField INT, + | stringField STRING + |) + |PARTITIONED BY (p int) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + | STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + |location '${partitionedTableDirWithKey.toURI}' + """.stripMargin) + + sql( + s""" + |create external table normal_parquet + |( + | intField INT, + | stringField STRING + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + | STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + |location '${new File(normalTableDir, "normal").toURI}' + """.stripMargin) + + sql( + s""" + |CREATE EXTERNAL TABLE partitioned_parquet_with_complextypes + |( + | intField INT, + | stringField STRING, + | structField STRUCT, + | arrayField ARRAY + |) + |PARTITIONED BY (p int) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + | STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + |LOCATION '${partitionedTableDirWithComplexTypes.toURI}' + """.stripMargin) + + sql( + s""" + |CREATE EXTERNAL TABLE partitioned_parquet_with_key_and_complextypes + |( + | intField INT, + | stringField STRING, + | structField STRUCT, + | arrayField ARRAY + |) + |PARTITIONED BY (p int) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + | STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + |LOCATION '${partitionedTableDirWithKeyAndComplexTypes.toURI}' + """.stripMargin) sql( """ @@ -291,7 +270,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { case LogicalRelation(_: HadoopFsRelation, _, _, _) => // OK case _ => fail( "test_parquet_ctas should be converted to " + - s"${classOf[HadoopFsRelation ].getCanonicalName }") + s"${classOf[HadoopFsRelation ].getCanonicalName }") } } } @@ -430,7 +409,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } test("SPARK-15968: nonempty partitioned metastore Parquet table lookup should use cached " + - "relation") { + "relation") { withTable("partitioned") { sql( """ @@ -678,404 +657,3 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { sql("SELECT * FROM normal_parquet x CROSS JOIN normal_parquet y")) } } - -/** - * A suite of tests for the Parquet support through the data sources API. - */ -class ParquetSourceSuite extends ParquetPartitioningTest { - import testImplicits._ - import spark._ - - override def beforeAll(): Unit = { - super.beforeAll() - dropTables("partitioned_parquet", - "partitioned_parquet_with_key", - "partitioned_parquet_with_complextypes", - "partitioned_parquet_with_key_and_complextypes", - "normal_parquet") - - sql( s""" - CREATE TEMPORARY VIEW partitioned_parquet - USING org.apache.spark.sql.parquet - OPTIONS ( - path '${partitionedTableDir.toURI}' - ) - """) - - sql( s""" - CREATE TEMPORARY VIEW partitioned_parquet_with_key - USING org.apache.spark.sql.parquet - OPTIONS ( - path '${partitionedTableDirWithKey.toURI}' - ) - """) - - sql( s""" - CREATE TEMPORARY VIEW normal_parquet - USING org.apache.spark.sql.parquet - OPTIONS ( - path '${new File(partitionedTableDir, "p=1").toURI}' - ) - """) - - sql( s""" - CREATE TEMPORARY VIEW partitioned_parquet_with_key_and_complextypes - USING org.apache.spark.sql.parquet - OPTIONS ( - path '${partitionedTableDirWithKeyAndComplexTypes.toURI}' - ) - """) - - sql( s""" - CREATE TEMPORARY VIEW partitioned_parquet_with_complextypes - USING org.apache.spark.sql.parquet - OPTIONS ( - path '${partitionedTableDirWithComplexTypes.toURI}' - ) - """) - } - - test("SPARK-6016 make sure to use the latest footers") { - sql("drop table if exists spark_6016_fix") - - // Create a DataFrame with two partitions. So, the created table will have two parquet files. - val df1 = (1 to 10).map(Tuple1(_)).toDF("a").coalesce(2) - df1.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable("spark_6016_fix") - checkAnswer( - sql("select * from spark_6016_fix"), - (1 to 10).map(i => Row(i)) - ) - - // Create a DataFrame with four partitions. So, the created table will have four parquet files. - val df2 = (1 to 10).map(Tuple1(_)).toDF("b").coalesce(4) - df2.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable("spark_6016_fix") - // For the bug of SPARK-6016, we are caching two outdated footers for df1. Then, - // since the new table has four parquet files, we are trying to read new footers from two files - // and then merge metadata in footers of these four (two outdated ones and two latest one), - // which will cause an error. - checkAnswer( - sql("select * from spark_6016_fix"), - (1 to 10).map(i => Row(i)) - ) - - sql("drop table spark_6016_fix") - } - - test("SPARK-8811: compatibility with array of struct in Hive") { - withTempPath { dir => - withTable("array_of_struct") { - val conf = Seq( - HiveUtils.CONVERT_METASTORE_PARQUET.key -> "false", - SQLConf.PARQUET_BINARY_AS_STRING.key -> "true", - SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "false") - - withSQLConf(conf: _*) { - sql( - s"""CREATE TABLE array_of_struct - |STORED AS PARQUET LOCATION '${dir.toURI}' - |AS SELECT - | '1st' AS a, - | '2nd' AS b, - | ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b')) AS c - """.stripMargin) - - checkAnswer( - spark.read.parquet(dir.getCanonicalPath), - Row("1st", "2nd", Seq(Row("val_a", "val_b")))) - } - } - } - } - - test("Verify the PARQUET conversion parameter: CONVERT_METASTORE_PARQUET") { - withTempView("single") { - val singleRowDF = Seq((0, "foo")).toDF("key", "value") - singleRowDF.createOrReplaceTempView("single") - - Seq("true", "false").foreach { parquetConversion => - withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> parquetConversion) { - val tableName = "test_parquet_ctas" - withTable(tableName) { - sql( - s""" - |CREATE TABLE $tableName STORED AS PARQUET - |AS SELECT tmp.key, tmp.value FROM single tmp - """.stripMargin) - - val df = spark.sql(s"SELECT * FROM $tableName WHERE key=0") - checkAnswer(df, singleRowDF) - - val queryExecution = df.queryExecution - if (parquetConversion == "true") { - queryExecution.analyzed.collectFirst { - case _: LogicalRelation => - }.getOrElse { - fail(s"Expecting the query plan to convert parquet to data sources, " + - s"but got:\n$queryExecution") - } - } else { - queryExecution.analyzed.collectFirst { - case _: HiveTableRelation => - }.getOrElse { - fail(s"Expecting no conversion from parquet to data sources, " + - s"but got:\n$queryExecution") - } - } - } - } - } - } - } - - test("values in arrays and maps stored in parquet are always nullable") { - val df = createDataFrame(Tuple2(Map(2 -> 3), Seq(4, 5, 6)) :: Nil).toDF("m", "a") - val mapType1 = MapType(IntegerType, IntegerType, valueContainsNull = false) - val arrayType1 = ArrayType(IntegerType, containsNull = false) - val expectedSchema1 = - StructType( - StructField("m", mapType1, nullable = true) :: - StructField("a", arrayType1, nullable = true) :: Nil) - assert(df.schema === expectedSchema1) - - withTable("alwaysNullable") { - df.write.format("parquet").saveAsTable("alwaysNullable") - - val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = true) - val arrayType2 = ArrayType(IntegerType, containsNull = true) - val expectedSchema2 = - StructType( - StructField("m", mapType2, nullable = true) :: - StructField("a", arrayType2, nullable = true) :: Nil) - - assert(table("alwaysNullable").schema === expectedSchema2) - - checkAnswer( - sql("SELECT m, a FROM alwaysNullable"), - Row(Map(2 -> 3), Seq(4, 5, 6))) - } - } - - test("Aggregation attribute names can't contain special chars \" ,;{}()\\n\\t=\"") { - val tempDir = Utils.createTempDir() - val filePath = new File(tempDir, "testParquet").getCanonicalPath - val filePath2 = new File(tempDir, "testParquet2").getCanonicalPath - - val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") - val df2 = df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").max("y.int") - intercept[Throwable](df2.write.parquet(filePath)) - - val df3 = df2.toDF("str", "max_int") - df3.write.parquet(filePath2) - val df4 = read.parquet(filePath2) - checkAnswer(df4, Row("1", 1) :: Row("2", 2) :: Row("3", 3) :: Nil) - assert(df4.columns === Array("str", "max_int")) - } -} - -/** - * A collection of tests for parquet data with various forms of partitioning. - */ -abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with TestHiveSingleton { - import testImplicits._ - - var partitionedTableDir: File = null - var normalTableDir: File = null - var partitionedTableDirWithKey: File = null - var partitionedTableDirWithComplexTypes: File = null - var partitionedTableDirWithKeyAndComplexTypes: File = null - - override def beforeAll(): Unit = { - super.beforeAll() - partitionedTableDir = Utils.createTempDir() - normalTableDir = Utils.createTempDir() - - (1 to 10).foreach { p => - val partDir = new File(partitionedTableDir, s"p=$p") - sparkContext.makeRDD(1 to 10) - .map(i => ParquetData(i, s"part-$p")) - .toDF() - .write.parquet(partDir.getCanonicalPath) - } - - sparkContext - .makeRDD(1 to 10) - .map(i => ParquetData(i, s"part-1")) - .toDF() - .write.parquet(new File(normalTableDir, "normal").getCanonicalPath) - - partitionedTableDirWithKey = Utils.createTempDir() - - (1 to 10).foreach { p => - val partDir = new File(partitionedTableDirWithKey, s"p=$p") - sparkContext.makeRDD(1 to 10) - .map(i => ParquetDataWithKey(p, i, s"part-$p")) - .toDF() - .write.parquet(partDir.getCanonicalPath) - } - - partitionedTableDirWithKeyAndComplexTypes = Utils.createTempDir() - - (1 to 10).foreach { p => - val partDir = new File(partitionedTableDirWithKeyAndComplexTypes, s"p=$p") - sparkContext.makeRDD(1 to 10).map { i => - ParquetDataWithKeyAndComplexTypes( - p, i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i) - }.toDF().write.parquet(partDir.getCanonicalPath) - } - - partitionedTableDirWithComplexTypes = Utils.createTempDir() - - (1 to 10).foreach { p => - val partDir = new File(partitionedTableDirWithComplexTypes, s"p=$p") - sparkContext.makeRDD(1 to 10).map { i => - ParquetDataWithComplexTypes(i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i) - }.toDF().write.parquet(partDir.getCanonicalPath) - } - } - - override protected def afterAll(): Unit = { - try { - partitionedTableDir.delete() - normalTableDir.delete() - partitionedTableDirWithKey.delete() - partitionedTableDirWithComplexTypes.delete() - partitionedTableDirWithKeyAndComplexTypes.delete() - } finally { - super.afterAll() - } - } - - /** - * Drop named tables if they exist - * - * @param tableNames tables to drop - */ - def dropTables(tableNames: String*): Unit = { - tableNames.foreach { name => - sql(s"DROP TABLE IF EXISTS $name") - } - } - - Seq( - "partitioned_parquet", - "partitioned_parquet_with_key", - "partitioned_parquet_with_complextypes", - "partitioned_parquet_with_key_and_complextypes").foreach { table => - - test(s"ordering of the partitioning columns $table") { - checkAnswer( - sql(s"SELECT p, stringField FROM $table WHERE p = 1"), - Seq.fill(10)(Row(1, "part-1")) - ) - - checkAnswer( - sql(s"SELECT stringField, p FROM $table WHERE p = 1"), - Seq.fill(10)(Row("part-1", 1)) - ) - } - - test(s"project the partitioning column $table") { - checkAnswer( - sql(s"SELECT p, count(*) FROM $table group by p"), - Row(1, 10) :: - Row(2, 10) :: - Row(3, 10) :: - Row(4, 10) :: - Row(5, 10) :: - Row(6, 10) :: - Row(7, 10) :: - Row(8, 10) :: - Row(9, 10) :: - Row(10, 10) :: Nil - ) - } - - test(s"project partitioning and non-partitioning columns $table") { - checkAnswer( - sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"), - Row("part-1", 1, 10) :: - Row("part-2", 2, 10) :: - Row("part-3", 3, 10) :: - Row("part-4", 4, 10) :: - Row("part-5", 5, 10) :: - Row("part-6", 6, 10) :: - Row("part-7", 7, 10) :: - Row("part-8", 8, 10) :: - Row("part-9", 9, 10) :: - Row("part-10", 10, 10) :: Nil - ) - } - - test(s"simple count $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table"), - Row(100)) - } - - test(s"pruned count $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"), - Row(10)) - } - - test(s"non-existent partition $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"), - Row(0)) - } - - test(s"multi-partition pruned count $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"), - Row(30)) - } - - test(s"non-partition predicates $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"), - Row(30)) - } - - test(s"sum $table") { - checkAnswer( - sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"), - Row(1 + 2 + 3)) - } - - test(s"hive udfs $table") { - checkAnswer( - sql(s"SELECT concat(stringField, stringField) FROM $table"), - sql(s"SELECT stringField FROM $table").rdd.map { - case Row(s: String) => Row(s + s) - }.collect().toSeq) - } - } - - Seq( - "partitioned_parquet_with_key_and_complextypes", - "partitioned_parquet_with_complextypes").foreach { table => - - test(s"SPARK-5775 read struct from $table") { - checkAnswer( - sql( - s""" - |SELECT p, structField.intStructField, structField.stringStructField - |FROM $table WHERE p = 1 - """.stripMargin), - (1 to 10).map(i => Row(1, i, f"${i}_string"))) - } - - test(s"SPARK-5775 read array from $table") { - checkAnswer( - sql(s"SELECT arrayField, p FROM $table WHERE p = 1"), - (1 to 10).map(i => Row((1 to i).toArray, 1))) - } - } - - - test("non-part select(*)") { - checkAnswer( - sql("SELECT COUNT(*) FROM normal_parquet"), - Row(10)) - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala new file mode 100644 index 0000000000000..de588768cfdee --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File + +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +/** + * A suite of tests for the Parquet support through the data sources API. + */ +class HiveParquetSourceSuite extends ParquetPartitioningTest { + import testImplicits._ + import spark._ + + override def beforeAll(): Unit = { + super.beforeAll() + dropTables("partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes", + "normal_parquet") + + sql( + s""" + |CREATE TEMPORARY VIEW partitioned_parquet + |USING org.apache.spark.sql.parquet + |OPTIONS ( + | path '${partitionedTableDir.toURI}' + |) + """.stripMargin) + + sql( + s""" + |CREATE TEMPORARY VIEW partitioned_parquet_with_key + |USING org.apache.spark.sql.parquet + |OPTIONS ( + | path '${partitionedTableDirWithKey.toURI}' + |) + """.stripMargin) + + sql( + s""" + |CREATE TEMPORARY VIEW normal_parquet + |USING org.apache.spark.sql.parquet + |OPTIONS ( + | path '${new File(partitionedTableDir, "p=1").toURI}' + |) + """.stripMargin) + + sql( + s""" + |CREATE TEMPORARY VIEW partitioned_parquet_with_key_and_complextypes + |USING org.apache.spark.sql.parquet + |OPTIONS ( + | path '${partitionedTableDirWithKeyAndComplexTypes.toURI}' + |) + """.stripMargin) + + sql( + s""" + |CREATE TEMPORARY VIEW partitioned_parquet_with_complextypes + |USING org.apache.spark.sql.parquet + |OPTIONS ( + | path '${partitionedTableDirWithComplexTypes.toURI}' + |) + """.stripMargin) + } + + test("SPARK-6016 make sure to use the latest footers") { + val tableName = "spark_6016_fix" + withTable(tableName) { + // Create a DataFrame with two partitions. So, the created table will have two parquet files. + val df1 = (1 to 10).map(Tuple1(_)).toDF("a").coalesce(2) + df1.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable(tableName) + checkAnswer( + sql(s"select * from $tableName"), + (1 to 10).map(i => Row(i)) + ) + + // Create a DataFrame with four partitions. So the created table will have four parquet files. + val df2 = (1 to 10).map(Tuple1(_)).toDF("b").coalesce(4) + df2.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable(tableName) + // For the bug of SPARK-6016, we are caching two outdated footers for df1. Then, + // since the new table has four parquet files, we are trying to read new footers from two + // files and then merge metadata in footers of these four + // (two outdated ones and two latest one), which will cause an error. + checkAnswer( + sql(s"select * from $tableName"), + (1 to 10).map(i => Row(i)) + ) + } + } + + test("SPARK-8811: compatibility with array of struct in Hive") { + withTempPath { dir => + withTable("array_of_struct") { + val conf = Seq( + HiveUtils.CONVERT_METASTORE_PARQUET.key -> "false", + SQLConf.PARQUET_BINARY_AS_STRING.key -> "true", + SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "false") + + withSQLConf(conf: _*) { + sql( + s"""CREATE TABLE array_of_struct + |STORED AS PARQUET LOCATION '${dir.toURI}' + |AS SELECT + | '1st' AS a, + | '2nd' AS b, + | ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b')) AS c + """.stripMargin) + + checkAnswer( + spark.read.parquet(dir.getCanonicalPath), + Row("1st", "2nd", Seq(Row("val_a", "val_b")))) + } + } + } + } + + test("Verify the PARQUET conversion parameter: CONVERT_METASTORE_PARQUET") { + withTempView("single") { + val singleRowDF = Seq((0, "foo")).toDF("key", "value") + singleRowDF.createOrReplaceTempView("single") + + Seq("true", "false").foreach { parquetConversion => + withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> parquetConversion) { + val tableName = "test_parquet_ctas" + withTable(tableName) { + sql( + s""" + |CREATE TABLE $tableName STORED AS PARQUET + |AS SELECT tmp.key, tmp.value FROM single tmp + """.stripMargin) + + val df = spark.sql(s"SELECT * FROM $tableName WHERE key=0") + checkAnswer(df, singleRowDF) + + val queryExecution = df.queryExecution + if (parquetConversion == "true") { + queryExecution.analyzed.collectFirst { + case _: LogicalRelation => + }.getOrElse { + fail(s"Expecting the query plan to convert parquet to data sources, " + + s"but got:\n$queryExecution") + } + } else { + queryExecution.analyzed.collectFirst { + case _: HiveTableRelation => + }.getOrElse { + fail(s"Expecting no conversion from parquet to data sources, " + + s"but got:\n$queryExecution") + } + } + } + } + } + } + } + + test("values in arrays and maps stored in parquet are always nullable") { + val df = createDataFrame(Tuple2(Map(2 -> 3), Seq(4, 5, 6)) :: Nil).toDF("m", "a") + val mapType1 = MapType(IntegerType, IntegerType, valueContainsNull = false) + val arrayType1 = ArrayType(IntegerType, containsNull = false) + val expectedSchema1 = + StructType( + StructField("m", mapType1, nullable = true) :: + StructField("a", arrayType1, nullable = true) :: Nil) + assert(df.schema === expectedSchema1) + + withTable("alwaysNullable") { + df.write.format("parquet").saveAsTable("alwaysNullable") + + val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = true) + val arrayType2 = ArrayType(IntegerType, containsNull = true) + val expectedSchema2 = + StructType( + StructField("m", mapType2, nullable = true) :: + StructField("a", arrayType2, nullable = true) :: Nil) + + assert(table("alwaysNullable").schema === expectedSchema2) + + checkAnswer( + sql("SELECT m, a FROM alwaysNullable"), + Row(Map(2 -> 3), Seq(4, 5, 6))) + } + } + + test("Aggregation attribute names can't contain special chars \" ,;{}()\\n\\t=\"") { + withTempDir { tempDir => + val filePath = new File(tempDir, "testParquet").getCanonicalPath + val filePath2 = new File(tempDir, "testParquet2").getCanonicalPath + + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") + val df2 = df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").max("y.int") + intercept[Throwable](df2.write.parquet(filePath)) + + val df3 = df2.toDF("str", "max_int") + df3.write.parquet(filePath2) + val df4 = read.parquet(filePath2) + checkAnswer(df4, Row("1", 1) :: Row("2", 2) :: Row("3", 3) :: Nil) + assert(df4.columns === Array("str", "max_int")) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetPartitioningTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetPartitioningTest.scala new file mode 100644 index 0000000000000..2ae3cf4b38f04 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetPartitioningTest.scala @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File + +import org.apache.spark.sql._ +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.Utils + +// The data where the partitioning key exists only in the directory structure. +case class ParquetData(intField: Int, stringField: String) +// The data that also includes the partitioning key +case class ParquetDataWithKey(p: Int, intField: Int, stringField: String) + +case class StructContainer(intStructField: Int, stringStructField: String) + +case class ParquetDataWithComplexTypes( + intField: Int, + stringField: String, + structField: StructContainer, + arrayField: Seq[Int]) + +case class ParquetDataWithKeyAndComplexTypes( + p: Int, + intField: Int, + stringField: String, + structField: StructContainer, + arrayField: Seq[Int]) + +/** + * A collection of tests for parquet data with various forms of partitioning. + */ +abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with TestHiveSingleton { + import testImplicits._ + + var partitionedTableDir: File = null + var normalTableDir: File = null + var partitionedTableDirWithKey: File = null + var partitionedTableDirWithComplexTypes: File = null + var partitionedTableDirWithKeyAndComplexTypes: File = null + + override def beforeAll(): Unit = { + super.beforeAll() + partitionedTableDir = Utils.createTempDir() + normalTableDir = Utils.createTempDir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDir, s"p=$p") + sparkContext.makeRDD(1 to 10) + .map(i => ParquetData(i, s"part-$p")) + .toDF() + .write.parquet(partDir.getCanonicalPath) + } + + sparkContext + .makeRDD(1 to 10) + .map(i => ParquetData(i, s"part-1")) + .toDF() + .write.parquet(new File(normalTableDir, "normal").getCanonicalPath) + + partitionedTableDirWithKey = Utils.createTempDir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDirWithKey, s"p=$p") + sparkContext.makeRDD(1 to 10) + .map(i => ParquetDataWithKey(p, i, s"part-$p")) + .toDF() + .write.parquet(partDir.getCanonicalPath) + } + + partitionedTableDirWithKeyAndComplexTypes = Utils.createTempDir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDirWithKeyAndComplexTypes, s"p=$p") + sparkContext.makeRDD(1 to 10).map { i => + ParquetDataWithKeyAndComplexTypes( + p, i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i) + }.toDF().write.parquet(partDir.getCanonicalPath) + } + + partitionedTableDirWithComplexTypes = Utils.createTempDir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDirWithComplexTypes, s"p=$p") + sparkContext.makeRDD(1 to 10).map { i => + ParquetDataWithComplexTypes(i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i) + }.toDF().write.parquet(partDir.getCanonicalPath) + } + } + + override protected def afterAll(): Unit = { + try { + partitionedTableDir.delete() + normalTableDir.delete() + partitionedTableDirWithKey.delete() + partitionedTableDirWithComplexTypes.delete() + partitionedTableDirWithKeyAndComplexTypes.delete() + } finally { + super.afterAll() + } + } + + /** + * Drop named tables if they exist + * + * @param tableNames tables to drop + */ + def dropTables(tableNames: String*): Unit = { + tableNames.foreach { name => + sql(s"DROP TABLE IF EXISTS $name") + } + } + + Seq( + "partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes").foreach { table => + + test(s"ordering of the partitioning columns $table") { + checkAnswer( + sql(s"SELECT p, stringField FROM $table WHERE p = 1"), + Seq.fill(10)(Row(1, "part-1")) + ) + + checkAnswer( + sql(s"SELECT stringField, p FROM $table WHERE p = 1"), + Seq.fill(10)(Row("part-1", 1)) + ) + } + + test(s"project the partitioning column $table") { + checkAnswer( + sql(s"SELECT p, count(*) FROM $table group by p"), + Row(1, 10) :: + Row(2, 10) :: + Row(3, 10) :: + Row(4, 10) :: + Row(5, 10) :: + Row(6, 10) :: + Row(7, 10) :: + Row(8, 10) :: + Row(9, 10) :: + Row(10, 10) :: Nil + ) + } + + test(s"project partitioning and non-partitioning columns $table") { + checkAnswer( + sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"), + Row("part-1", 1, 10) :: + Row("part-2", 2, 10) :: + Row("part-3", 3, 10) :: + Row("part-4", 4, 10) :: + Row("part-5", 5, 10) :: + Row("part-6", 6, 10) :: + Row("part-7", 7, 10) :: + Row("part-8", 8, 10) :: + Row("part-9", 9, 10) :: + Row("part-10", 10, 10) :: Nil + ) + } + + test(s"simple count $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table"), + Row(100)) + } + + test(s"pruned count $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"), + Row(10)) + } + + test(s"non-existent partition $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"), + Row(0)) + } + + test(s"multi-partition pruned count $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"), + Row(30)) + } + + test(s"non-partition predicates $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"), + Row(30)) + } + + test(s"sum $table") { + checkAnswer( + sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"), + Row(1 + 2 + 3)) + } + + test(s"hive udfs $table") { + checkAnswer( + sql(s"SELECT concat(stringField, stringField) FROM $table"), + sql(s"SELECT stringField FROM $table").rdd.map { + case Row(s: String) => Row(s + s) + }.collect().toSeq) + } + } + + Seq( + "partitioned_parquet_with_key_and_complextypes", + "partitioned_parquet_with_complextypes").foreach { table => + + test(s"SPARK-5775 read struct from $table") { + checkAnswer( + sql( + s""" + |SELECT p, structField.intStructField, structField.stringStructField + |FROM $table WHERE p = 1 + """.stripMargin), + (1 to 10).map(i => Row(1, i, f"${i}_string"))) + } + + test(s"SPARK-5775 read array from $table") { + checkAnswer( + sql(s"SELECT arrayField, p FROM $table WHERE p = 1"), + (1 to 10).map(i => Row((1 to i).toArray, 1))) + } + } + + test("non-part select(*)") { + checkAnswer( + sql("SELECT COUNT(*) FROM normal_parquet"), + Row(10)) + } +} From 0fbba76faa00a18eef5d8c2ef2e673744d0d490b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 23 Sep 2018 10:16:33 +0800 Subject: [PATCH 1678/2461] [MINOR][PYSPARK] Always Close the tempFile in _serialize_to_jvm ## What changes were proposed in this pull request? Always close the tempFile after `serializer.dump_stream(data, tempFile)` in _serialize_to_jvm ## How was this patch tested? N/A Closes #22523 from gatorsmile/fixMinor. Authored-by: gatorsmile Signed-off-by: hyukjinkwon --- python/pyspark/context.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 87255c40e330e..0924d3d95f044 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -537,8 +537,10 @@ def _serialize_to_jvm(self, data, serializer, reader_func, createRDDServer): # parallelize from there. tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir) try: - serializer.dump_stream(data, tempFile) - tempFile.close() + try: + serializer.dump_stream(data, tempFile) + finally: + tempFile.close() return reader_func(tempFile.name) finally: # we eagerily reads the file so we can delete right after. From a72d118cd96cd44d37cb8f8b6c444953a99aab3f Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 23 Sep 2018 11:14:27 +0800 Subject: [PATCH 1679/2461] [SPARK-25473][PYTHON][SS][TEST] ForeachWriter tests failed on Python 3.6 and macOS High Sierra ## What changes were proposed in this pull request? This PR does not fix the problem itself but just target to add few comments to run PySpark tests on Python 3.6 and macOS High Serria since it actually blocks to run tests on this enviornment. it does not target to fix the problem yet. The problem here looks because we fork python workers and the forked workers somehow call Objective-C libraries in some codes at CPython's implementation. After debugging a while, I suspect `pickle` in Python 3.6 has some changes: https://github.com/apache/spark/blob/58419b92673c46911c25bc6c6b13397f880c6424/python/pyspark/serializers.py#L577 in particular, it looks also related to which objects are serialized or not as well. This link (http://sealiesoftware.com/blog/archive/2017/6/5/Objective-C_and_fork_in_macOS_1013.html) and this link (https://blog.phusion.nl/2017/10/13/why-ruby-app-servers-break-on-macos-high-sierra-and-what-can-be-done-about-it/) were helpful for me to understand this. I am still debugging this but my guts say it's difficult to fix or workaround within Spark side. ## How was this patch tested? Manually tested: Before `OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES`: ``` /usr/local/Cellar/python/3.6.5/Frameworks/Python.framework/Versions/3.6/lib/python3.6/subprocess.py:766: ResourceWarning: subprocess 27563 is still running ResourceWarning, source=self) [Stage 0:> (0 + 1) / 1]objc[27586]: +[__NSPlaceholderDictionary initialize] may have been in progress in another thread when fork() was called. objc[27586]: +[__NSPlaceholderDictionary initialize] may have been in progress in another thread when fork() was called. We cannot safely call it or ignore it in the fork() child process. Crashing instead. Set a breakpoint on objc_initializeAfterForkError to debug. ERROR ====================================================================== ERROR: test_streaming_foreach_with_simple_function (pyspark.sql.tests.SQLTests) ---------------------------------------------------------------------- Traceback (most recent call last): File "/.../spark/python/pyspark/sql/utils.py", line 63, in deco return f(*a, **kw) File "/.../spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py", line 328, in get_return_value format(target_id, ".", name), value) py4j.protocol.Py4JJavaError: An error occurred while calling o54.processAllAvailable. : org.apache.spark.sql.streaming.StreamingQueryException: Writing job aborted. === Streaming Query === Identifier: [id = f508d634-407c-4232-806b-70e54b055c42, runId = 08d1435b-5358-4fb6-b167-811584a3163e] Current Committed Offsets: {} Current Available Offsets: {FileStreamSource[file:/var/folders/71/484zt4z10ks1vydt03bhp6hr0000gp/T/tmpolebys1s]: {"logOffset":0}} Current State: ACTIVE Thread State: RUNNABLE Logical Plan: FileStreamSource[file:/var/folders/71/484zt4z10ks1vydt03bhp6hr0000gp/T/tmpolebys1s] at org.apache.spark.sql.execution.streaming.StreamExecution.org$apache$spark$sql$execution$streaming$StreamExecution$$runStream(StreamExecution.scala:295) at org.apache.spark.sql.execution.streaming.StreamExecution$$anon$1.run(StreamExecution.scala:189) Caused by: org.apache.spark.SparkException: Writing job aborted. at org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2Exec.doExecute(WriteToDataSourceV2Exec.scala:91) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127) at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) ``` After `OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES`: ``` test_streaming_foreach_with_simple_function (pyspark.sql.tests.SQLTests) ... ok ``` Closes #22480 from HyukjinKwon/SPARK-25473. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- python/pyspark/sql/tests.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 9fa1577681f03..b829baeca4775 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1961,6 +1961,9 @@ def __getstate__(self): def __setstate__(self, state): self.open_events_dir, self.process_events_dir, self.close_events_dir = state + # Those foreach tests are failed in Python 3.6 and macOS High Sierra by defined rules + # at http://sealiesoftware.com/blog/archive/2017/6/5/Objective-C_and_fork_in_macOS_1013.html + # To work around this, OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES. def test_streaming_foreach_with_simple_function(self): tester = self.ForeachWriterTester(self.spark) From 9bf04d8543d70ba8e55c970f2a8e2df872cf74f6 Mon Sep 17 00:00:00 2001 From: seancxmao Date: Sun, 23 Sep 2018 13:34:06 -0700 Subject: [PATCH 1680/2461] [SPARK-25489][ML][TEST] Refactor UDTSerializationBenchmark ## What changes were proposed in this pull request? Refactor `UDTSerializationBenchmark` to use main method and print the output as a separate file. Run blow command to generate benchmark results: ``` SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "mllib/test:runMain org.apache.spark.mllib.linalg.UDTSerializationBenchmark" ``` ## How was this patch tested? Manual tests. Closes #22499 from seancxmao/SPARK-25489. Authored-by: seancxmao Signed-off-by: Dongjoon Hyun --- .../UDTSerializationBenchmark-results.txt | 13 ++++ .../linalg/UDTSerializationBenchmark.scala | 70 ++++++++++--------- 2 files changed, 49 insertions(+), 34 deletions(-) create mode 100644 mllib/benchmarks/UDTSerializationBenchmark-results.txt diff --git a/mllib/benchmarks/UDTSerializationBenchmark-results.txt b/mllib/benchmarks/UDTSerializationBenchmark-results.txt new file mode 100644 index 0000000000000..169f4c60c748e --- /dev/null +++ b/mllib/benchmarks/UDTSerializationBenchmark-results.txt @@ -0,0 +1,13 @@ +================================================================================================ +VectorUDT de/serialization +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_131-b11 on Mac OS X 10.13.6 +Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz + +VectorUDT de/serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +serialize 144 / 206 0.0 143979.7 1.0X +deserialize 114 / 135 0.0 113802.6 1.3X + + diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala index e2976e1ab022b..1a2216ea070c4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala @@ -17,53 +17,55 @@ package org.apache.spark.mllib.linalg -import org.apache.spark.benchmark.Benchmark +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder /** * Serialization benchmark for VectorUDT. + * To run this benchmark: + * {{{ + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "mllib/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "mllib/test:runMain " + * Results will be written to "benchmarks/UDTSerializationBenchmark-results.txt". + * }}} */ -object UDTSerializationBenchmark { +object UDTSerializationBenchmark extends BenchmarkBase { - def main(args: Array[String]): Unit = { - val iters = 1e2.toInt - val numRows = 1e3.toInt + override def benchmark(): Unit = { - val encoder = ExpressionEncoder[Vector].resolveAndBind() + runBenchmark("VectorUDT de/serialization") { + val iters = 1e2.toInt + val numRows = 1e3.toInt - val vectors = (1 to numRows).map { i => - Vectors.dense(Array.fill(1e5.toInt)(1.0 * i)) - }.toArray - val rows = vectors.map(encoder.toRow) + val encoder = ExpressionEncoder[Vector].resolveAndBind() - val benchmark = new Benchmark("VectorUDT de/serialization", numRows, iters) + val vectors = (1 to numRows).map { i => + Vectors.dense(Array.fill(1e5.toInt)(1.0 * i)) + }.toArray + val rows = vectors.map(encoder.toRow) - benchmark.addCase("serialize") { _ => - var sum = 0 - var i = 0 - while (i < numRows) { - sum += encoder.toRow(vectors(i)).numFields - i += 1 + val benchmark = new Benchmark("VectorUDT de/serialization", numRows, iters, output = output) + + benchmark.addCase("serialize") { _ => + var sum = 0 + var i = 0 + while (i < numRows) { + sum += encoder.toRow(vectors(i)).numFields + i += 1 + } } - } - benchmark.addCase("deserialize") { _ => - var sum = 0 - var i = 0 - while (i < numRows) { - sum += encoder.fromRow(rows(i)).numActives - i += 1 + benchmark.addCase("deserialize") { _ => + var sum = 0 + var i = 0 + while (i < numRows) { + sum += encoder.fromRow(rows(i)).numActives + i += 1 + } } - } - /* - OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 - Intel Xeon E3-12xx v2 (Ivy Bridge) - VectorUDT de/serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - serialize 265 / 318 0.0 265138.5 1.0X - deserialize 155 / 197 0.0 154611.4 1.7X - */ - benchmark.run() + benchmark.run() + } } } From d522a563ad5ab157993a19f406a3cc6f443ccb9e Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 24 Sep 2018 09:30:07 +0800 Subject: [PATCH 1681/2461] [SPARK-25415][SQL][FOLLOW-UP] Add Locale.ROOT when toUpperCase ## What changes were proposed in this pull request? Add `Locale.ROOT` when `toUpperCase`. ## How was this patch tested? manual tests Closes #22531 from wangyum/SPARK-25415. Authored-by: Yuming Wang Signed-off-by: hyukjinkwon --- .../org/apache/spark/sql/catalyst/rules/RuleExecutor.scala | 2 +- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 183be5a027ec5..e991a2dc7462f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -144,7 +144,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { private class PlanChangeLogger { - private val logLevel = SQLConf.get.optimizerPlanChangeLogLevel.toUpperCase + private val logLevel = SQLConf.get.optimizerPlanChangeLogLevel private val logRules = SQLConf.get.optimizerPlanChangeRules.map(Utils.stringToSeq) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ddf17fa88c76b..0e0a01def357e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -177,8 +177,8 @@ object SQLConf { "plan after a rule is applied. The value can be 'trace', 'debug', 'info', 'warn', or " + "'error'. The default log level is 'trace'.") .stringConf - .checkValue( - str => Set("TRACE", "DEBUG", "INFO", "WARN", "ERROR").contains(str.toUpperCase), + .transform(_.toUpperCase(Locale.ROOT)) + .checkValue(logLevel => Set("TRACE", "DEBUG", "INFO", "WARN", "ERROR").contains(logLevel), "Invalid value for 'spark.sql.optimizer.planChangeLog.level'. Valid values are " + "'trace', 'debug', 'info', 'warn' and 'error'.") .createWithDefault("trace") From c79072aafa2f406c342e393e0c61bb5cb3e89a7f Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 23 Sep 2018 20:46:40 -0700 Subject: [PATCH 1682/2461] [SPARK-25478][SQL][TEST] Refactor CompressionSchemeBenchmark to use main method ## What changes were proposed in this pull request? Refactor `CompressionSchemeBenchmark` to use main method. Generate benchmark result: ```sh SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain org.apache.spark.sql.execution.columnar.compression.CompressionSchemeBenchmark" ``` ## How was this patch tested? manual tests Closes #22486 from wangyum/SPARK-25478. Lead-authored-by: Yuming Wang Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../CompressionSchemeBenchmark-results.txt | 137 +++++++++++++++++ .../CompressionSchemeBenchmark.scala | 138 +++--------------- 2 files changed, 156 insertions(+), 119 deletions(-) create mode 100644 sql/core/benchmarks/CompressionSchemeBenchmark-results.txt diff --git a/sql/core/benchmarks/CompressionSchemeBenchmark-results.txt b/sql/core/benchmarks/CompressionSchemeBenchmark-results.txt new file mode 100644 index 0000000000000..caa9378301f5d --- /dev/null +++ b/sql/core/benchmarks/CompressionSchemeBenchmark-results.txt @@ -0,0 +1,137 @@ +================================================================================================ +Compression Scheme Benchmark +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +BOOLEAN Encode: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +PassThrough(1.000) 4 / 4 17998.9 0.1 1.0X +RunLengthEncoding(2.501) 680 / 680 98.7 10.1 0.0X +BooleanBitSet(0.125) 365 / 365 183.9 5.4 0.0X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +BOOLEAN Decode: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +PassThrough 144 / 144 466.5 2.1 1.0X +RunLengthEncoding 679 / 679 98.9 10.1 0.2X +BooleanBitSet 1425 / 1431 47.1 21.2 0.1X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SHORT Encode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +PassThrough(1.000) 7 / 7 10115.0 0.1 1.0X +RunLengthEncoding(1.494) 1671 / 1672 40.2 24.9 0.0X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SHORT Decode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +PassThrough 1128 / 1128 59.5 16.8 1.0X +RunLengthEncoding 1630 / 1633 41.2 24.3 0.7X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SHORT Encode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +PassThrough(1.000) 7 / 7 10164.2 0.1 1.0X +RunLengthEncoding(1.989) 1562 / 1563 43.0 23.3 0.0X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SHORT Decode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +PassThrough 1127 / 1127 59.6 16.8 1.0X +RunLengthEncoding 1629 / 1631 41.2 24.3 0.7X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +INT Encode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +PassThrough(1.000) 22 / 23 2983.2 0.3 1.0X +RunLengthEncoding(1.003) 2426 / 2427 27.7 36.1 0.0X +DictionaryEncoding(0.500) 958 / 958 70.1 14.3 0.0X +IntDelta(0.250) 286 / 286 235.0 4.3 0.1X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +INT Decode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +PassThrough 1268 / 1269 52.9 18.9 1.0X +RunLengthEncoding 1906 / 1911 35.2 28.4 0.7X +DictionaryEncoding 981 / 982 68.4 14.6 1.3X +IntDelta 812 / 817 82.6 12.1 1.6X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +INT Encode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +PassThrough(1.000) 23 / 23 2926.9 0.3 1.0X +RunLengthEncoding(1.326) 2614 / 2614 25.7 38.9 0.0X +DictionaryEncoding(0.501) 1024 / 1024 65.5 15.3 0.0X +IntDelta(0.250) 286 / 286 234.7 4.3 0.1X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +INT Decode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +PassThrough 1433 / 1433 46.8 21.4 1.0X +RunLengthEncoding 1923 / 1926 34.9 28.6 0.7X +DictionaryEncoding 1285 / 1285 52.2 19.2 1.1X +IntDelta 1129 / 1137 59.4 16.8 1.3X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +LONG Encode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +PassThrough(1.000) 45 / 45 1495.6 0.7 1.0X +RunLengthEncoding(0.738) 2662 / 2663 25.2 39.7 0.0X +DictionaryEncoding(0.250) 1269 / 1269 52.9 18.9 0.0X +LongDelta(0.125) 450 / 450 149.1 6.7 0.1X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +LONG Decode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +PassThrough 1483 / 1483 45.3 22.1 1.0X +RunLengthEncoding 1875 / 1875 35.8 27.9 0.8X +DictionaryEncoding 1213 / 1214 55.3 18.1 1.2X +LongDelta 816 / 817 82.2 12.2 1.8X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +LONG Encode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +PassThrough(1.000) 45 / 45 1489.3 0.7 1.0X +RunLengthEncoding(1.003) 2906 / 2906 23.1 43.3 0.0X +DictionaryEncoding(0.251) 1610 / 1610 41.7 24.0 0.0X +LongDelta(0.125) 451 / 451 148.7 6.7 0.1X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +LONG Decode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +PassThrough 1485 / 1485 45.2 22.1 1.0X +RunLengthEncoding 1889 / 1890 35.5 28.2 0.8X +DictionaryEncoding 1215 / 1216 55.2 18.1 1.2X +LongDelta 1107 / 1110 60.6 16.5 1.3X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +STRING Encode: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +PassThrough(1.000) 67 / 68 994.5 1.0 1.0X +RunLengthEncoding(0.894) 5877 / 5882 11.4 87.6 0.0X +DictionaryEncoding(0.167) 3597 / 3602 18.7 53.6 0.0X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +STRING Decode: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +PassThrough 3243 / 3244 20.7 48.3 1.0X +RunLengthEncoding 3598 / 3601 18.7 53.6 0.9X +DictionaryEncoding 3182 / 3182 21.1 47.4 1.0X + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala index 9c26d67b62ccc..ff0e4acd31279 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala @@ -23,7 +23,7 @@ import java.nio.charset.StandardCharsets import org.apache.commons.lang3.RandomStringUtils import org.apache.commons.math3.distribution.LogNormalDistribution -import org.apache.spark.benchmark.Benchmark +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar.{BOOLEAN, INT, LONG, NativeColumnType, SHORT, STRING} import org.apache.spark.sql.types.AtomicType @@ -31,8 +31,15 @@ import org.apache.spark.util.Utils._ /** * Benchmark to decoders using various compression schemes. + * To run this benchmark: + * {{{ + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/CompressionSchemeBenchmark-results.txt". + * }}} */ -object CompressionSchemeBenchmark extends AllCompressionSchemes { +object CompressionSchemeBenchmark extends BenchmarkBase with AllCompressionSchemes { private[this] def allocateLocal(size: Int): ByteBuffer = { ByteBuffer.allocate(size).order(ByteOrder.nativeOrder) @@ -77,7 +84,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { count: Int, tpe: NativeColumnType[T], input: ByteBuffer): Unit = { - val benchmark = new Benchmark(name, iters * count.toLong) + val benchmark = new Benchmark(name, iters * count.toLong, output = output) schemes.filter(_.supports(tpe)).foreach { scheme => val (compressFunc, compressionRatio, buf) = prepareEncodeInternal(count, tpe, scheme, input) @@ -101,7 +108,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { count: Int, tpe: NativeColumnType[T], input: ByteBuffer): Unit = { - val benchmark = new Benchmark(name, iters * count.toLong) + val benchmark = new Benchmark(name, iters * count.toLong, output = output) schemes.filter(_.supports(tpe)).foreach { scheme => val (compressFunc, _, buf) = prepareEncodeInternal(count, tpe, scheme, input) @@ -138,21 +145,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { testData.put(i * BOOLEAN.defaultSize, g()) } - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // BOOLEAN Encode: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - // ------------------------------------------------------------------------------------------- - // PassThrough(1.000) 3 / 4 19300.2 0.1 1.0X - // RunLengthEncoding(2.491) 923 / 939 72.7 13.8 0.0X - // BooleanBitSet(0.125) 359 / 363 187.1 5.3 0.0X runEncodeBenchmark("BOOLEAN Encode", iters, count, BOOLEAN, testData) - - - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // BOOLEAN Decode: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - // ------------------------------------------------------------------------------------------- - // PassThrough 129 / 136 519.8 1.9 1.0X - // RunLengthEncoding 613 / 623 109.4 9.1 0.2X - // BooleanBitSet 1196 / 1222 56.1 17.8 0.1X runDecodeBenchmark("BOOLEAN Decode", iters, count, BOOLEAN, testData) } @@ -165,18 +158,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { testData.putShort(i * SHORT.defaultSize, g1().toShort) } - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // SHORT Encode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - // ------------------------------------------------------------------------------------------- - // PassThrough(1.000) 6 / 7 10971.4 0.1 1.0X - // RunLengthEncoding(1.510) 1526 / 1542 44.0 22.7 0.0X runEncodeBenchmark("SHORT Encode (Lower Skew)", iters, count, SHORT, testData) - - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // SHORT Decode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - // ------------------------------------------------------------------------------------------- - // PassThrough 811 / 837 82.8 12.1 1.0X - // RunLengthEncoding 1219 / 1266 55.1 18.2 0.7X runDecodeBenchmark("SHORT Decode (Lower Skew)", iters, count, SHORT, testData) val g2 = genHigherSkewData() @@ -184,18 +166,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { testData.putShort(i * SHORT.defaultSize, g2().toShort) } - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // SHORT Encode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - // ------------------------------------------------------------------------------------------- - // PassThrough(1.000) 7 / 7 10112.4 0.1 1.0X - // RunLengthEncoding(2.009) 1623 / 1661 41.4 24.2 0.0X runEncodeBenchmark("SHORT Encode (Higher Skew)", iters, count, SHORT, testData) - - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // SHORT Decode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - // ------------------------------------------------------------------------------------------- - // PassThrough 818 / 827 82.0 12.2 1.0X - // RunLengthEncoding 1202 / 1237 55.8 17.9 0.7X runDecodeBenchmark("SHORT Decode (Higher Skew)", iters, count, SHORT, testData) } @@ -208,22 +179,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { testData.putInt(i * INT.defaultSize, g1().toInt) } - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // INT Encode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - // ------------------------------------------------------------------------------------------- - // PassThrough(1.000) 18 / 19 3716.4 0.3 1.0X - // RunLengthEncoding(1.001) 1992 / 2056 33.7 29.7 0.0X - // DictionaryEncoding(0.500) 723 / 739 92.8 10.8 0.0X - // IntDelta(0.250) 368 / 377 182.2 5.5 0.0X runEncodeBenchmark("INT Encode (Lower Skew)", iters, count, INT, testData) - - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // INT Decode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - // ------------------------------------------------------------------------------------------- - // PassThrough 821 / 845 81.8 12.2 1.0X - // RunLengthEncoding 1246 / 1256 53.9 18.6 0.7X - // DictionaryEncoding 757 / 766 88.6 11.3 1.1X - // IntDelta 680 / 689 98.7 10.1 1.2X runDecodeBenchmark("INT Decode (Lower Skew)", iters, count, INT, testData) val g2 = genHigherSkewData() @@ -231,22 +187,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { testData.putInt(i * INT.defaultSize, g2().toInt) } - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // INT Encode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - // ------------------------------------------------------------------------------------------- - // PassThrough(1.000) 17 / 19 3888.4 0.3 1.0X - // RunLengthEncoding(1.339) 2127 / 2148 31.5 31.7 0.0X - // DictionaryEncoding(0.501) 960 / 972 69.9 14.3 0.0X - // IntDelta(0.250) 362 / 366 185.5 5.4 0.0X runEncodeBenchmark("INT Encode (Higher Skew)", iters, count, INT, testData) - - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // INT Decode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - // ------------------------------------------------------------------------------------------- - // PassThrough 838 / 884 80.1 12.5 1.0X - // RunLengthEncoding 1287 / 1311 52.1 19.2 0.7X - // DictionaryEncoding 844 / 859 79.5 12.6 1.0X - // IntDelta 764 / 784 87.8 11.4 1.1X runDecodeBenchmark("INT Decode (Higher Skew)", iters, count, INT, testData) } @@ -259,22 +200,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { testData.putLong(i * LONG.defaultSize, g1().toLong) } - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // LONG Encode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - // ------------------------------------------------------------------------------------------- - // PassThrough(1.000) 37 / 38 1804.8 0.6 1.0X - // RunLengthEncoding(0.748) 2065 / 2094 32.5 30.8 0.0X - // DictionaryEncoding(0.250) 950 / 962 70.6 14.2 0.0X - // LongDelta(0.125) 475 / 482 141.2 7.1 0.1X runEncodeBenchmark("LONG Encode (Lower Skew)", iters, count, LONG, testData) - - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // LONG Decode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - // ------------------------------------------------------------------------------------------- - // PassThrough 888 / 894 75.5 13.2 1.0X - // RunLengthEncoding 1301 / 1311 51.6 19.4 0.7X - // DictionaryEncoding 887 / 904 75.7 13.2 1.0X - // LongDelta 693 / 735 96.8 10.3 1.3X runDecodeBenchmark("LONG Decode (Lower Skew)", iters, count, LONG, testData) val g2 = genHigherSkewData() @@ -282,22 +208,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { testData.putLong(i * LONG.defaultSize, g2().toLong) } - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // LONG Encode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - // ------------------------------------------------------------------------------------------- - // PassThrough(1.000) 34 / 35 1963.9 0.5 1.0X - // RunLengthEncoding(0.999) 2260 / 3021 29.7 33.7 0.0X - // DictionaryEncoding(0.251) 1270 / 1438 52.8 18.9 0.0X - // LongDelta(0.125) 496 / 509 135.3 7.4 0.1X runEncodeBenchmark("LONG Encode (Higher Skew)", iters, count, LONG, testData) - - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // LONG Decode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - // ------------------------------------------------------------------------------------------- - // PassThrough 965 / 1494 69.5 14.4 1.0X - // RunLengthEncoding 1350 / 1378 49.7 20.1 0.7X - // DictionaryEncoding 892 / 924 75.2 13.3 1.1X - // LongDelta 817 / 847 82.2 12.2 1.2X runDecodeBenchmark("LONG Decode (Higher Skew)", iters, count, LONG, testData) } @@ -318,28 +229,17 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { } testData.rewind() - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // STRING Encode: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - // ------------------------------------------------------------------------------------------- - // PassThrough(1.000) 56 / 57 1197.9 0.8 1.0X - // RunLengthEncoding(0.893) 4892 / 4937 13.7 72.9 0.0X - // DictionaryEncoding(0.167) 2968 / 2992 22.6 44.2 0.0X runEncodeBenchmark("STRING Encode", iters, count, STRING, testData) - - // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz - // STRING Decode: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - // ------------------------------------------------------------------------------------------- - // PassThrough 2422 / 2449 27.7 36.1 1.0X - // RunLengthEncoding 2885 / 3018 23.3 43.0 0.8X - // DictionaryEncoding 2716 / 2752 24.7 40.5 0.9X runDecodeBenchmark("STRING Decode", iters, count, STRING, testData) } - def main(args: Array[String]): Unit = { - bitEncodingBenchmark(1024) - shortEncodingBenchmark(1024) - intEncodingBenchmark(1024) - longEncodingBenchmark(1024) - stringEncodingBenchmark(1024) + override def benchmark(): Unit = { + runBenchmark("Compression Scheme Benchmark") { + bitEncodingBenchmark(1024) + shortEncodingBenchmark(1024) + intEncodingBenchmark(1024) + longEncodingBenchmark(1024) + stringEncodingBenchmark(1024) + } } } From c3b4a94a91d66c172cf332321d3a78dba29ef8f0 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 24 Sep 2018 19:25:02 +0800 Subject: [PATCH 1683/2461] [SPARKR] Match pyspark features in SparkR communication protocol --- R/pkg/R/context.R | 43 +++++++++++++------ R/pkg/tests/fulltests/test_Serde.R | 32 ++++++++++++++ R/pkg/tests/fulltests/test_sparkSQL.R | 12 ------ .../scala/org/apache/spark/api/r/RRDD.scala | 33 +++++++++++++- .../scala/org/apache/spark/api/r/RUtils.scala | 4 ++ 5 files changed, 98 insertions(+), 26 deletions(-) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index f168ca76b6007..e99136723f65b 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -167,18 +167,30 @@ parallelize <- function(sc, coll, numSlices = 1) { # 2-tuples of raws serializedSlices <- lapply(slices, serialize, connection = NULL) - # The PRC backend cannot handle arguments larger than 2GB (INT_MAX) + # The RPC backend cannot handle arguments larger than 2GB (INT_MAX) # If serialized data is safely less than that threshold we send it over the PRC channel. # Otherwise, we write it to a file and send the file name if (objectSize < sizeLimit) { jrdd <- callJStatic("org.apache.spark.api.r.RRDD", "createRDDFromArray", sc, serializedSlices) } else { - fileName <- writeToTempFile(serializedSlices) - jrdd <- tryCatch(callJStatic( - "org.apache.spark.api.r.RRDD", "createRDDFromFile", sc, fileName, as.integer(numSlices)), - finally = { - file.remove(fileName) - }) + if (callJStatic("org.apache.spark.api.r.RUtils", "getEncryptionEnabled", sc)) { + # the length of slices here is the parallelism to use in the jvm's sc.parallelize() + parallelism <- as.integer(numSlices) + jserver <- newJObject("org.apache.spark.api.r.RParallelizeServer", sc, parallelism) + authSecret <- callJMethod(jserver, "secret") + port <- callJMethod(jserver, "port") + conn <- socketConnection(port = port, blocking = TRUE, open = "wb", timeout = 1500) + doServerAuth(conn, authSecret) + writeToConnection(serializedSlices, conn) + jrdd <- callJMethod(jserver, "getResult") + } else { + fileName <- writeToTempFile(serializedSlices) + jrdd <- tryCatch(callJStatic( + "org.apache.spark.api.r.RRDD", "createRDDFromFile", sc, fileName, as.integer(numSlices)), + finally = { + file.remove(fileName) + }) + } } RDD(jrdd, "byte") @@ -194,14 +206,21 @@ getMaxAllocationLimit <- function(sc) { )) } +writeToConnection <- function(serializedSlices, conn) { + tryCatch({ + for (slice in serializedSlices) { + writeBin(as.integer(length(slice)), conn, endian = "big") + writeBin(slice, conn, endian = "big") + } + }, finally = { + close(conn) + }) +} + writeToTempFile <- function(serializedSlices) { fileName <- tempfile() conn <- file(fileName, "wb") - for (slice in serializedSlices) { - writeBin(as.integer(length(slice)), conn, endian = "big") - writeBin(slice, conn, endian = "big") - } - close(conn) + writeToConnection(serializedSlices, conn) fileName } diff --git a/R/pkg/tests/fulltests/test_Serde.R b/R/pkg/tests/fulltests/test_Serde.R index 3577929323b8b..1525bdb2f5c8b 100644 --- a/R/pkg/tests/fulltests/test_Serde.R +++ b/R/pkg/tests/fulltests/test_Serde.R @@ -124,3 +124,35 @@ test_that("SerDe of list of lists", { }) sparkR.session.stop() + +# Note that this test should be at the end of tests since the configruations used here are not +# specific to sessions, and the Spark context is restarted. +test_that("createDataFrame large objects", { + for (encryptionEnabled in list("true", "false")) { + # To simulate a large object scenario, we set spark.r.maxAllocationLimit to a smaller value + conf <- list(spark.r.maxAllocationLimit = "100", + spark.io.encryption.enabled = encryptionEnabled) + + suppressWarnings(sparkR.session(master = sparkRTestMaster, + sparkConfig = conf, + enableHiveSupport = FALSE)) + + sc <- getSparkContext() + actual <- callJStatic("org.apache.spark.api.r.RUtils", "getEncryptionEnabled", sc) + expected <- as.logical(encryptionEnabled) + expect_equal(actual, expected) + + tryCatch({ + # suppress warnings from dot in the field names. See also SPARK-21536. + df <- suppressWarnings(createDataFrame(iris, numPartitions = 3)) + expect_equal(getNumPartitions(df), 3) + expect_equal(dim(df), dim(iris)) + + df <- createDataFrame(cars, numPartitions = 3) + expect_equal(collect(df), cars) + }, + finally = { + sparkR.stop() + }) + } +}) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 40d8f8084f2f4..a874bfbb58dc7 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -316,18 +316,6 @@ test_that("create DataFrame from RDD", { unsetHiveContext() }) -test_that("createDataFrame uses files for large objects", { - # To simulate a large file scenario, we set spark.r.maxAllocationLimit to a smaller value - conf <- callJMethod(sparkSession, "conf") - callJMethod(conf, "set", "spark.r.maxAllocationLimit", "100") - df <- suppressWarnings(createDataFrame(iris, numPartitions = 3)) - expect_equal(getNumPartitions(df), 3) - - # Resetting the conf back to default value - callJMethod(conf, "set", "spark.r.maxAllocationLimit", toString(.Machine$integer.max / 10)) - expect_equal(dim(df), dim(iris)) -}) - test_that("read/write csv as DataFrame", { if (windows_with_hadoop()) { csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 295355c7bf018..1dc61c7eef33c 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -17,7 +17,9 @@ package org.apache.spark.api.r -import java.io.File +import java.io.{DataInputStream, File} +import java.net.Socket +import java.nio.charset.StandardCharsets.UTF_8 import java.util.{Map => JMap} import scala.collection.JavaConverters._ @@ -25,10 +27,11 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} -import org.apache.spark.api.python.PythonRDD +import org.apache.spark.api.python.{PythonRDD, PythonServer} import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.security.SocketAuthHelper private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( parent: RDD[T], @@ -163,3 +166,29 @@ private[r] object RRDD { PythonRDD.readRDDFromFile(jsc, fileName, parallelism) } } + +/** + * Helper for making RDD[Array[Byte]] from some R data, by reading the data from R + * over a socket. This is used in preference to writing data to a file when encryption is enabled. + */ +private[spark] class RParallelizeServer(sc: JavaSparkContext, parallelism: Int) + extends PythonServer[JavaRDD[Array[Byte]]]( + new RSocketAuthHelper(), "sparkr-parallelize-server") { + + override def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = { + val in = sock.getInputStream() + PythonRDD.readRDDFromInputStream(sc.sc, in, parallelism) + } +} + +private[spark] class RSocketAuthHelper extends SocketAuthHelper(SparkEnv.get.conf) { + override protected def readUtf8(s: Socket): String = { + val din = new DataInputStream(s.getInputStream()) + val len = din.readInt() + val bytes = new Array[Byte](len) + din.readFully(bytes) + // The R code adds a null terminator to serialized strings, so ignore it here. + assert(bytes(bytes.length - 1) == 0) // sanity check. + new String(bytes, 0, bytes.length - 1, UTF_8) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala index fdd8cf62f0e5f..9bf35af1da925 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -21,6 +21,8 @@ import java.io.File import java.util.Arrays import org.apache.spark.{SparkEnv, SparkException} +import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.api.python.PythonUtils private[spark] object RUtils { // Local path where R binary packages built from R source code contained in the spark @@ -104,4 +106,6 @@ private[spark] object RUtils { case e: Exception => false } } + + def getEncryptionEnabled(sc: JavaSparkContext): Boolean = PythonUtils.getEncryptionEnabled(sc) } From 804515f821086ea685815d3c8eff42d76b7d9e4e Mon Sep 17 00:00:00 2001 From: Stan Zhai Date: Mon, 24 Sep 2018 21:33:12 +0800 Subject: [PATCH 1684/2461] [SPARK-21318][SQL] Improve exception message thrown by `lookupFunction` ## What changes were proposed in this pull request? The function actually exists in current selected database, and it's failed to init during `lookupFunciton`, but the exception message is: ``` This function is neither a registered temporary function nor a permanent function registered in the database 'default'. ``` This is not conducive to positioning problems. This PR fix the problem. ## How was this patch tested? new test case + manual tests Closes #18544 from stanzhai/fix-udf-error-message. Authored-by: Stan Zhai Signed-off-by: Wenchen Fan --- .../catalog/SessionCatalogSuite.scala | 3 ++ .../spark/sql/hive/HiveSessionCatalog.scala | 10 +++--- .../spark/sql/hive/execution/UDAFEmpty.java | 32 +++++++++++++++++++ .../org/apache/spark/sql/hive/UDFSuite.scala | 16 ++++++++++ 4 files changed, 56 insertions(+), 5 deletions(-) create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDAFEmpty.java diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 89fabd4774065..19e8c0334689c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -1427,6 +1427,7 @@ abstract class SessionCatalogSuite extends AnalysisTest { Seq(true, false) foreach { caseSensitive => val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive) val catalog = new SessionCatalog(newBasicCatalog(), new SimpleFunctionRegistry, conf) + catalog.setCurrentDatabase("db1") try { val analyzer = new Analyzer(catalog, conf) @@ -1440,6 +1441,8 @@ abstract class SessionCatalogSuite extends AnalysisTest { } assert(cause.getMessage.contains("Undefined function: 'undefined_fn'")) + // SPARK-21318: the error message should contains the current database name + assert(cause.getMessage.contains("db1")) } finally { catalog.reset() } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index de41bb418181d..405c0c8bfe660 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -131,14 +131,14 @@ private[sql] class HiveSessionCatalog( Try(super.lookupFunction(funcName, children)) match { case Success(expr) => expr case Failure(error) => - if (functionRegistry.functionExists(funcName)) { - // If the function actually exists in functionRegistry, it means that there is an - // error when we create the Expression using the given children. + if (super.functionExists(name)) { + // If the function exists (either in functionRegistry or externalCatalog), + // it means that there is an error when we create the Expression using the given children. // We need to throw the original exception. throw error } else { - // This function is not in functionRegistry, let's try to load it as a Hive's - // built-in function. + // This function does not exist (neither in functionRegistry or externalCatalog), + // let's try to load it as a Hive's built-in function. // Hive is case insensitive. val functionName = funcName.unquotedString.toLowerCase(Locale.ROOT) if (!hiveFunctions.contains(functionName)) { diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDAFEmpty.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDAFEmpty.java new file mode 100644 index 0000000000000..badc396688f5f --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDAFEmpty.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; + +/** + * An empty UDAF that throws a semantic exception + */ +public class UDAFEmpty extends AbstractGenericUDAFResolver { + @Override + public GenericUDAFEvaluator getEvaluator(TypeInfo[] info) throws SemanticException { + throw new SemanticException("Can not get an evaluator of the empty UDAF"); + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index a56c6f73989a7..d567128e1a322 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -193,4 +193,20 @@ class UDFSuite } } } + + test("SPARK-21318: The correct exception message should be thrown " + + "if a UDF/UDAF has already been registered") { + val functionName = "empty" + val functionClass = classOf[org.apache.spark.sql.hive.execution.UDAFEmpty].getCanonicalName + + withUserDefinedFunction(functionName -> false) { + sql(s"CREATE FUNCTION $functionName AS '$functionClass'") + + val e = intercept[AnalysisException] { + sql(s"SELECT $functionName(value) from $testTableName") + } + + assert(e.getMessage.contains("Can not get an evaluator of the empty UDAF")) + } + } } From bb49661e192eed78a8a306deffd83c73bd4a9eff Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 24 Sep 2018 21:37:51 +0800 Subject: [PATCH 1685/2461] [SPARK-25416][SQL] ArrayPosition function may return incorrect result when right expression is implicitly down casted ## What changes were proposed in this pull request? In ArrayPosition, we currently cast the right hand side expression to match the element type of the left hand side Array. This may result in down casting and may return wrong result or questionable result. Example : ```SQL spark-sql> select array_position(array(1), 1.34); 1 ``` ```SQL spark-sql> select array_position(array(1), 'foo'); null ``` We should safely coerce both left and right hand side expressions. ## How was this patch tested? Added tests in DataFrameFunctionsSuite Closes #22407 from dilipbiswal/SPARK-25416. Authored-by: Dilip Biswal Signed-off-by: Wenchen Fan --- .../expressions/collectionOperations.scala | 21 ++++--- .../spark/sql/DataFrameFunctionsSuite.scala | 57 +++++++++++++++++-- 2 files changed, 64 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 161adc9cc5bac..85bc1cdb43051 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2071,18 +2071,23 @@ case class ArrayPosition(left: Expression, right: Expression) override def dataType: DataType = LongType override def inputTypes: Seq[AbstractDataType] = { - val elementType = left.dataType match { - case t: ArrayType => t.elementType - case _ => AnyDataType + (left.dataType, right.dataType) match { + case (ArrayType(e1, hasNull), e2) => + TypeCoercion.findTightestCommonType(e1, e2) match { + case Some(dt) => Seq(ArrayType(dt, hasNull), dt) + case _ => Seq.empty + } + case _ => Seq.empty } - Seq(ArrayType, elementType) } override def checkInputDataTypes(): TypeCheckResult = { - super.checkInputDataTypes() match { - case f: TypeCheckResult.TypeCheckFailure => f - case TypeCheckResult.TypeCheckSuccess => - TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") + (left.dataType, right.dataType) match { + case (ArrayType(e1, _), e2) if e1.sameType(e2) => + TypeUtils.checkForOrderingExpr(e2, s"function $prettyName") + case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + + s"been ${ArrayType.simpleString} followed by a value with same element type, but it's " + + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index ad52fd01248e3..fd71f24935611 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1097,18 +1097,63 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) checkAnswer( - df.selectExpr("array_position(array(array(1), null)[0], 1)"), - Seq(Row(1L), Row(1L)) + OneRowRelation().selectExpr("array_position(array(1), 1.23D)"), + Seq(Row(0L)) ) + checkAnswer( - df.selectExpr("array_position(array(1, null), array(1, null)[0])"), - Seq(Row(1L), Row(1L)) + OneRowRelation().selectExpr("array_position(array(1), 1.0D)"), + Seq(Row(1L)) ) - val e = intercept[AnalysisException] { + checkAnswer( + OneRowRelation().selectExpr("array_position(array(1.D), 1)"), + Seq(Row(1L)) + ) + + checkAnswer( + OneRowRelation().selectExpr("array_position(array(1.23D), 1)"), + Seq(Row(0L)) + ) + + checkAnswer( + OneRowRelation().selectExpr("array_position(array(array(1)), array(1.0D))"), + Seq(Row(1L)) + ) + + checkAnswer( + OneRowRelation().selectExpr("array_position(array(array(1)), array(1.23D))"), + Seq(Row(0L)) + ) + + checkAnswer( + OneRowRelation().selectExpr("array_position(array(array(1), null)[0], 1)"), + Seq(Row(1L)) + ) + checkAnswer( + OneRowRelation().selectExpr("array_position(array(1, null), array(1, null)[0])"), + Seq(Row(1L)) + ) + + val e1 = intercept[AnalysisException] { Seq(("a string element", "a")).toDF().selectExpr("array_position(_1, _2)") } - assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type")) + val errorMsg1 = + s""" + |Input to function array_position should have been array followed by a + |value with same element type, but it's [string, string]. + """.stripMargin.replace("\n", " ").trim() + assert(e1.message.contains(errorMsg1)) + + val e2 = intercept[AnalysisException] { + OneRowRelation().selectExpr("array_position(array(1), '1')") + } + val errorMsg2 = + s""" + |Input to function array_position should have been array followed by a + |value with same element type, but it's [array, string]. + """.stripMargin.replace("\n", " ").trim() + assert(e2.message.contains(errorMsg2)) } test("element_at function") { From 3ce2e008ec1bf70adc5a4b356e09a469e94af803 Mon Sep 17 00:00:00 2001 From: Shahid Date: Mon, 24 Sep 2018 14:17:42 -0700 Subject: [PATCH 1686/2461] [SPARK-25502][CORE][WEBUI] Empty Page when page number exceeds the reatinedTask size. ## What changes were proposed in this pull request? Test steps : 1) bin/spark-shell --conf spark.ui.retainedTasks=200 ``` val rdd = sc.parallelize(1 to 1000, 1000) rdd.count ``` Stage tab in the UI will display 10 pages with 100 tasks per page. But number of retained tasks is only 200. So, from the 3rd page onwards will display nothing. We have to calculate total pages based on the number of tasks need display in the UI. **Before fix:** ![empty_4](https://user-images.githubusercontent.com/23054875/45918251-b1650580-bea1-11e8-90d3-7e0d491981a2.jpg) **After fix:** ![empty_3](https://user-images.githubusercontent.com/23054875/45918257-c2ae1200-bea1-11e8-960f-dfbdb4a90ae7.jpg) ## How was this patch tested? Manually tested Closes #22526 from shahidki31/SPARK-25502. Authored-by: Shahid Signed-off-by: Marcelo Vanzin --- .../scala/org/apache/spark/ui/jobs/StagePage.scala | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 55eb989962668..fd6a298e577d6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -117,7 +117,8 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val localitySummary = store.localitySummary(stageData.stageId, stageData.attemptId) - val totalTasks = taskCount(stageData) + val totalTasks = stageData.numActiveTasks + stageData.numCompleteTasks + + stageData.numFailedTasks + stageData.numKilledTasks if (totalTasks == 0) { val content =
      @@ -685,7 +686,7 @@ private[ui] class TaskDataSource( private var _tasksToShow: Seq[TaskData] = null - override def dataSize: Int = taskCount(stage) + override def dataSize: Int = store.taskCount(stage.stageId, stage.attemptId).toInt override def sliceData(from: Int, to: Int): Seq[TaskData] = { if (_tasksToShow == null) { @@ -1051,9 +1052,4 @@ private[ui] object ApiHelper { (stage.map(_.name).getOrElse(""), stage.flatMap(_.description).getOrElse(job.name)) } - def taskCount(stageData: StageData): Int = { - stageData.numActiveTasks + stageData.numCompleteTasks + stageData.numFailedTasks + - stageData.numKilledTasks - } - } From 2c9ffda1b5484fe01cc13a38c9fb52a861c2371b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 25 Sep 2018 07:38:40 +0800 Subject: [PATCH 1687/2461] [BUILD] Closes stale PR Closes #22517 From 615792da42b3ee3c5f623c869fada17a3aa92884 Mon Sep 17 00:00:00 2001 From: Shahid Date: Mon, 24 Sep 2018 20:03:52 -0700 Subject: [PATCH 1688/2461] [SPARK-25503][CORE][WEBUI] Total task message in stage page is ambiguous ## What changes were proposed in this pull request? Test steps : 1) bin/spark-shell --conf spark.ui.retainedTasks=10 2) val rdd = sc.parallelize(1 to 1000, 1000) 3) rdd.count Stage page tab in the UI will display 10 tasks, but display message is wrong. It should reverse. **Before fix :** ![webui_1](https://user-images.githubusercontent.com/23054875/45917921-8926d800-be9c-11e8-8da5-3998d07e3ccc.jpg) **After fix** ![spark_web_ui2](https://user-images.githubusercontent.com/23054875/45917935-b4112c00-be9c-11e8-9d10-4fcc8e88568f.jpg) ## How was this patch tested? Manually tested Closes #22525 from shahidki31/SparkUI. Authored-by: Shahid Signed-off-by: Dongjoon Hyun --- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index fd6a298e577d6..7428bbe6c5592 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -133,7 +133,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val totalTasksNumStr = if (totalTasks == storedTasks) { s"$totalTasks" } else { - s"$storedTasks, showing ${totalTasks}" + s"$totalTasks, showing $storedTasks" } val summary = From 7d8f5b62c57c9e2903edd305e8b9c5400652fdb0 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 25 Sep 2018 12:05:04 +0800 Subject: [PATCH 1689/2461] [SPARK-25519][SQL] ArrayRemove function may return incorrect result when right expression is implicitly downcasted. ## What changes were proposed in this pull request? In ArrayRemove, we currently cast the right hand side expression to match the element type of the left hand side Array. This may result in down casting and may return wrong result or questionable result. Example : ```SQL spark-sql> select array_remove(array(1,2,3), 1.23D); [2,3] ``` ```SQL spark-sql> select array_remove(array(1,2,3), 'foo'); NULL ``` We should safely coerce both left and right hand side expressions. ## How was this patch tested? Added tests in DataFrameFunctionsSuite Closes #22542 from dilipbiswal/SPARK-25519. Authored-by: Dilip Biswal Signed-off-by: Wenchen Fan --- .../expressions/collectionOperations.scala | 29 ++++++----- .../spark/sql/DataFrameFunctionsSuite.scala | 48 ++++++++++++++++++- 2 files changed, 63 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 85bc1cdb43051..9cc7dbadd923a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3088,11 +3088,24 @@ case class ArrayRemove(left: Expression, right: Expression) override def dataType: DataType = left.dataType override def inputTypes: Seq[AbstractDataType] = { - val elementType = left.dataType match { - case t: ArrayType => t.elementType - case _ => AnyDataType + (left.dataType, right.dataType) match { + case (ArrayType(e1, hasNull), e2) => + TypeCoercion.findTightestCommonType(e1, e2) match { + case Some(dt) => Seq(ArrayType(dt, hasNull), dt) + case _ => Seq.empty + } + case _ => Seq.empty + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (ArrayType(e1, _), e2) if e1.sameType(e2) => + TypeUtils.checkForOrderingExpr(e2, s"function $prettyName") + case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + + s"been ${ArrayType.simpleString} followed by a value with same element type, but it's " + + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") } - Seq(ArrayType, elementType) } private def elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType @@ -3100,14 +3113,6 @@ case class ArrayRemove(left: Expression, right: Expression) @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(right.dataType) - override def checkInputDataTypes(): TypeCheckResult = { - super.checkInputDataTypes() match { - case f: TypeCheckResult.TypeCheckFailure => f - case TypeCheckResult.TypeCheckSuccess => - TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") - } - } - override def nullSafeEval(arr: Any, value: Any): Any = { val newArray = new Array[Any](arr.asInstanceOf[ArrayData].numElements()) var pos = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index fd71f24935611..88dbae8c21350 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1574,6 +1574,34 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null)) ) + checkAnswer( + OneRowRelation().selectExpr("array_remove(array(1, 2), 1.23D)"), + Seq( + Row(Seq(1.0, 2.0)) + ) + ) + + checkAnswer( + OneRowRelation().selectExpr("array_remove(array(1, 2), 1.0D)"), + Seq( + Row(Seq(2.0)) + ) + ) + + checkAnswer( + OneRowRelation().selectExpr("array_remove(array(1.0D, 2.0D), 2)"), + Seq( + Row(Seq(1.0)) + ) + ) + + checkAnswer( + OneRowRelation().selectExpr("array_remove(array(1.1D, 1.2D), 1)"), + Seq( + Row(Seq(1.1, 1.2)) + ) + ) + checkAnswer( df.selectExpr("array_remove(a, 2)", "array_remove(b, \"a\")", "array_remove(c, \"\")"), @@ -1583,10 +1611,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null, null, null)) ) - val e = intercept[AnalysisException] { + val e1 = intercept[AnalysisException] { Seq(("a string element", "a")).toDF().selectExpr("array_remove(_1, _2)") } - assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type")) + val errorMsg1 = + s""" + |Input to function array_remove should have been array followed by a + |value with same element type, but it's [string, string]. + """.stripMargin.replace("\n", " ").trim() + assert(e1.message.contains(errorMsg1)) + + val e2 = intercept[AnalysisException] { + OneRowRelation().selectExpr("array_remove(array(1, 2), '1')") + } + + val errorMsg2 = + s""" + |Input to function array_remove should have been array followed by a + |value with same element type, but it's [array, string]. + """.stripMargin.replace("\n", " ").trim() + assert(e2.message.contains(errorMsg2)) } test("array_distinct functions") { From 9cbd001e2476cd06aa0bcfcc77a21a9077d5797a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 25 Sep 2018 20:13:07 +0800 Subject: [PATCH 1690/2461] [SPARK-23907][SQL] Revert regr_* functions entirely ## What changes were proposed in this pull request? This patch reverts entirely all the regr_* functions added in SPARK-23907. These were added by mgaido91 (and proposed by gatorsmile) to improve compatibility with other database systems, without any actual use cases. However, they are very rarely used, and in Spark there are much better ways to compute these functions, due to Spark's flexibility in exposing real programming APIs. I'm going through all the APIs added in Spark 2.4 and I think we should revert these. If there are strong enough demands and more use cases, we can add them back in the future pretty easily. ## How was this patch tested? Reverted test cases also. Closes #22541 from rxin/SPARK-23907. Authored-by: Reynold Xin Signed-off-by: hyukjinkwon --- .../catalyst/analysis/FunctionRegistry.scala | 9 - .../expressions/aggregate/regression.scala | 190 ------------------ .../sql-tests/inputs/udaf-regrfunctions.sql | 56 ------ .../results/udaf-regrfunctions.sql.out | 93 --------- 4 files changed, 348 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala delete mode 100644 sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql delete mode 100644 sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 8b69a47036962..7dafebff79874 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -300,15 +300,6 @@ object FunctionRegistry { expression[CollectList]("collect_list"), expression[CollectSet]("collect_set"), expression[CountMinSketchAgg]("count_min_sketch"), - expression[RegrCount]("regr_count"), - expression[RegrSXX]("regr_sxx"), - expression[RegrSYY]("regr_syy"), - expression[RegrAvgX]("regr_avgx"), - expression[RegrAvgY]("regr_avgy"), - expression[RegrSXY]("regr_sxy"), - expression[RegrSlope]("regr_slope"), - expression[RegrR2]("regr_r2"), - expression[RegrIntercept]("regr_intercept"), // string functions expression[Ascii]("ascii"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala deleted file mode 100644 index d8f4505588ff2..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{AbstractDataType, DoubleType} - -/** - * Base trait for all regression functions. - */ -trait RegrLike extends AggregateFunction with ImplicitCastInputTypes { - def y: Expression - def x: Expression - - override def children: Seq[Expression] = Seq(y, x) - override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) - - protected def updateIfNotNull(exprs: Seq[Expression]): Seq[Expression] = { - assert(aggBufferAttributes.length == exprs.length) - val nullableChildren = children.filter(_.nullable) - if (nullableChildren.isEmpty) { - exprs - } else { - exprs.zip(aggBufferAttributes).map { case (e, a) => - If(nullableChildren.map(IsNull).reduce(Or), a, e) - } - } - } -} - - -@ExpressionDescription( - usage = "_FUNC_(y, x) - Returns the number of non-null pairs.", - since = "2.4.0") -case class RegrCount(y: Expression, x: Expression) - extends CountLike with RegrLike { - - override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(Seq(count + 1L)) - - override def prettyName: String = "regr_count" -} - - -@ExpressionDescription( - usage = "_FUNC_(y, x) - Returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored.", - since = "2.4.0") -case class RegrSXX(y: Expression, x: Expression) - extends CentralMomentAgg(x) with RegrLike { - - override protected def momentOrder = 2 - - override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) - - override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), m2) - } - - override def prettyName: String = "regr_sxx" -} - - -@ExpressionDescription( - usage = "_FUNC_(y, x) - Returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored.", - since = "2.4.0") -case class RegrSYY(y: Expression, x: Expression) - extends CentralMomentAgg(y) with RegrLike { - - override protected def momentOrder = 2 - - override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) - - override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), m2) - } - - override def prettyName: String = "regr_syy" -} - - -@ExpressionDescription( - usage = "_FUNC_(y, x) - Returns the average of x. Any pair with a NULL is ignored.", - since = "2.4.0") -case class RegrAvgX(y: Expression, x: Expression) - extends AverageLike(x) with RegrLike { - - override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) - - override def prettyName: String = "regr_avgx" -} - - -@ExpressionDescription( - usage = "_FUNC_(y, x) - Returns the average of y. Any pair with a NULL is ignored.", - since = "2.4.0") -case class RegrAvgY(y: Expression, x: Expression) - extends AverageLike(y) with RegrLike { - - override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) - - override def prettyName: String = "regr_avgy" -} - -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(y, x) - Returns the covariance of y and x multiplied for the number of items in the dataset. Any pair with a NULL is ignored.", - since = "2.4.0") -// scalastyle:on line.size.limit -case class RegrSXY(y: Expression, x: Expression) - extends Covariance(y, x) with RegrLike { - - override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) - - override val evaluateExpression: Expression = { - If(n === Literal(0.0), Literal.create(null, DoubleType), ck) - } - - override def prettyName: String = "regr_sxy" -} - - -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(y, x) - Returns the slope of the linear regression line. Any pair with a NULL is ignored.", - since = "2.4.0") -// scalastyle:on line.size.limit -case class RegrSlope(y: Expression, x: Expression) - extends PearsonCorrelation(y, x) with RegrLike { - - override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) - - override val evaluateExpression: Expression = { - If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), ck / yMk) - } - - override def prettyName: String = "regr_slope" -} - - -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(y, x) - Returns the coefficient of determination (also called R-squared or goodness of fit) for the regression line. Any pair with a NULL is ignored.", - since = "2.4.0") -// scalastyle:on line.size.limit -case class RegrR2(y: Expression, x: Expression) - extends PearsonCorrelation(y, x) with RegrLike { - - override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) - - override val evaluateExpression: Expression = { - If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), - If(xMk === Literal(0.0), Literal(1.0), ck * ck / yMk / xMk)) - } - - override def prettyName: String = "regr_r2" -} - - -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(y, x) - Returns the y-intercept of the linear regression line. Any pair with a NULL is ignored.", - since = "2.4.0") -// scalastyle:on line.size.limit -case class RegrIntercept(y: Expression, x: Expression) - extends PearsonCorrelation(y, x) with RegrLike { - - override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) - - override val evaluateExpression: Expression = { - If(n === Literal(0.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), - xAvg - (ck / yMk) * yAvg) - } - - override def prettyName: String = "regr_intercept" -} diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql deleted file mode 100644 index 92c7e26e3add2..0000000000000 --- a/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql +++ /dev/null @@ -1,56 +0,0 @@ --- --- Licensed to the Apache Software Foundation (ASF) under one or more --- contributor license agreements. See the NOTICE file distributed with --- this work for additional information regarding copyright ownership. --- The ASF licenses this file to You under the Apache License, Version 2.0 --- (the "License"); you may not use this file except in compliance with --- the License. You may obtain a copy of the License at --- --- http://www.apache.org/licenses/LICENSE-2.0 --- --- Unless required by applicable law or agreed to in writing, software --- distributed under the License is distributed on an "AS IS" BASIS, --- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --- See the License for the specific language governing permissions and --- limitations under the License. --- - -CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES - (101, 1, 1, 1), - (201, 2, 1, 1), - (301, 3, 1, 1), - (401, 4, 1, 11), - (501, 5, 1, null), - (601, 6, null, 1), - (701, 6, null, null), - (102, 1, 2, 2), - (202, 2, 1, 2), - (302, 3, 2, 1), - (402, 4, 2, 12), - (502, 5, 2, null), - (602, 6, null, 2), - (702, 6, null, null), - (103, 1, 3, 3), - (203, 2, 1, 3), - (303, 3, 3, 1), - (403, 4, 3, 13), - (503, 5, 3, null), - (603, 6, null, 3), - (703, 6, null, null), - (104, 1, 4, 4), - (204, 2, 1, 4), - (304, 3, 4, 1), - (404, 4, 4, 14), - (504, 5, 4, null), - (604, 6, null, 4), - (704, 6, null, null), - (800, 7, 1, 1) -as t1(id, px, y, x); - -select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x), - regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x), - regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x) -from t1 group by px order by px; - - -select id, regr_count(y,x) over (partition by px) from t1 order by id; diff --git a/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out deleted file mode 100644 index d7d009a64bf84..0000000000000 --- a/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out +++ /dev/null @@ -1,93 +0,0 @@ --- Automatically generated by SQLQueryTestSuite --- Number of queries: 3 - - --- !query 0 -CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES - (101, 1, 1, 1), - (201, 2, 1, 1), - (301, 3, 1, 1), - (401, 4, 1, 11), - (501, 5, 1, null), - (601, 6, null, 1), - (701, 6, null, null), - (102, 1, 2, 2), - (202, 2, 1, 2), - (302, 3, 2, 1), - (402, 4, 2, 12), - (502, 5, 2, null), - (602, 6, null, 2), - (702, 6, null, null), - (103, 1, 3, 3), - (203, 2, 1, 3), - (303, 3, 3, 1), - (403, 4, 3, 13), - (503, 5, 3, null), - (603, 6, null, 3), - (703, 6, null, null), - (104, 1, 4, 4), - (204, 2, 1, 4), - (304, 3, 4, 1), - (404, 4, 4, 14), - (504, 5, 4, null), - (604, 6, null, 4), - (704, 6, null, null), - (800, 7, 1, 1) -as t1(id, px, y, x) --- !query 0 schema -struct<> --- !query 0 output - - - --- !query 1 -select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x), - regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x), - regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x) -from t1 group by px order by px --- !query 1 schema -struct --- !query 1 output -1 1.25 1.25 1.0 1.6666666666666667 1.25 4 1.0 0.0 1.0 5.0 5.0 5.0 2.5 2.5 4 -2 1.25 0.0 NULL 0.0 0.0 4 0.0 1.0 1.0 5.0 0.0 0.0 2.5 1.0 4 -3 0.0 1.25 NULL 0.0 0.0 4 NULL NULL NULL 0.0 5.0 0.0 1.0 2.5 4 -4 1.25 1.25 1.0 1.6666666666666667 1.25 4 1.0 -10.0 1.0 5.0 5.0 5.0 12.5 2.5 4 -5 NULL 1.25 NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL 0 -6 1.25 NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL 0 -7 0.0 0.0 NaN NaN 0.0 1 NULL NULL NULL 0.0 0.0 0.0 1.0 1.0 1 - - --- !query 2 -select id, regr_count(y,x) over (partition by px) from t1 order by id --- !query 2 schema -struct --- !query 2 output -101 4 -102 4 -103 4 -104 4 -201 4 -202 4 -203 4 -204 4 -301 4 -302 4 -303 4 -304 4 -401 4 -402 4 -403 4 -404 4 -501 0 -502 0 -503 0 -504 0 -601 0 -602 0 -603 0 -604 0 -701 0 -702 0 -703 0 -704 0 -800 1 From 04db035378012907c93f6e5b4faa6ec11f1fc67b Mon Sep 17 00:00:00 2001 From: yucai Date: Tue, 25 Sep 2018 11:13:05 -0700 Subject: [PATCH 1691/2461] [SPARK-25486][TEST] Refactor SortBenchmark to use main method ## What changes were proposed in this pull request? Refactor SortBenchmark to use main method. Generate benchmark result: ``` SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain org.apache.spark.sql.execution.benchmark.SortBenchmark" ``` ## How was this patch tested? manual tests Closes #22495 from yucai/SPARK-25486. Authored-by: yucai Signed-off-by: Dongjoon Hyun --- sql/core/benchmarks/SortBenchmark-results.txt | 17 +++++++++ .../execution/benchmark/SortBenchmark.scala | 38 ++++++++----------- 2 files changed, 33 insertions(+), 22 deletions(-) create mode 100644 sql/core/benchmarks/SortBenchmark-results.txt diff --git a/sql/core/benchmarks/SortBenchmark-results.txt b/sql/core/benchmarks/SortBenchmark-results.txt new file mode 100644 index 0000000000000..0d00a0c89d02d --- /dev/null +++ b/sql/core/benchmarks/SortBenchmark-results.txt @@ -0,0 +1,17 @@ +================================================================================================ +radix sort +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_162-b12 on Mac OS X 10.13.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +radix sort 25000000: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +reference TimSort key prefix array 11770 / 11960 2.1 470.8 1.0X +reference Arrays.sort 2106 / 2128 11.9 84.3 5.6X +radix sort one byte 93 / 100 269.7 3.7 126.9X +radix sort two bytes 171 / 179 146.0 6.9 68.7X +radix sort eight bytes 659 / 664 37.9 26.4 17.9X +radix sort key prefix array 1024 / 1053 24.4 41.0 11.5X + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala index 17619ec5fadc1..958a064402149 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.benchmark import java.util.{Arrays, Comparator} -import org.apache.spark.benchmark.Benchmark +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.unsafe.array.LongArray import org.apache.spark.unsafe.memory.MemoryBlock import org.apache.spark.util.collection.Sorter @@ -28,12 +28,15 @@ import org.apache.spark.util.random.XORShiftRandom /** * Benchmark to measure performance for aggregate primitives. - * To run this: - * build/sbt "sql/test-only *benchmark.SortBenchmark" - * - * Benchmarks in this file are skipped in normal builds. + * {{{ + * To run this benchmark: + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/-results.txt". + * }}} */ -class SortBenchmark extends BenchmarkWithCodegen { +object SortBenchmark extends BenchmarkBase { private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) { val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) @@ -54,10 +57,10 @@ class SortBenchmark extends BenchmarkWithCodegen { new LongArray(MemoryBlock.fromLongArray(extended))) } - ignore("sort") { + def sortBenchmark(): Unit = { val size = 25000000 val rand = new XORShiftRandom(123) - val benchmark = new Benchmark("radix sort " + size, size) + val benchmark = new Benchmark("radix sort " + size, size, output = output) benchmark.addTimerCase("reference TimSort key prefix array") { timer => val array = Array.tabulate[Long](size * 2) { i => rand.nextLong } val buf = new LongArray(MemoryBlock.fromLongArray(array)) @@ -114,20 +117,11 @@ class SortBenchmark extends BenchmarkWithCodegen { timer.stopTiming() } benchmark.run() + } - /* - Running benchmark: radix sort 25000000 - Java HotSpot(TM) 64-Bit Server VM 1.8.0_66-b17 on Linux 3.13.0-44-generic - Intel(R) Core(TM) i7-4600U CPU @ 2.10GHz - - radix sort 25000000: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - reference TimSort key prefix array 15546 / 15859 1.6 621.9 1.0X - reference Arrays.sort 2416 / 2446 10.3 96.6 6.4X - radix sort one byte 133 / 137 188.4 5.3 117.2X - radix sort two bytes 255 / 258 98.2 10.2 61.1X - radix sort eight bytes 991 / 997 25.2 39.6 15.7X - radix sort key prefix array 1540 / 1563 16.2 61.6 10.1X - */ + override def benchmark(): Unit = { + runBenchmark("radix sort") { + sortBenchmark() + } } } From 66d29870c09e6050dd846336e596faaa8b0d14ad Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 25 Sep 2018 11:42:27 -0700 Subject: [PATCH 1692/2461] [SPARK-25495][SS] FetchedData.reset should reset all fields ## What changes were proposed in this pull request? `FetchedData.reset` should reset `_nextOffsetInFetchedData` and `_offsetAfterPoll`. Otherwise it will cause inconsistent cached data and may make Kafka connector return wrong results. ## How was this patch tested? The new unit test. Closes #22507 from zsxwing/fix-kafka-reset. Lead-authored-by: Shixiong Zhu Co-authored-by: Shixiong Zhu Signed-off-by: Shixiong Zhu --- .../sql/kafka010/KafkaDataConsumer.scala | 5 +- .../kafka010/KafkaMicroBatchSourceSuite.scala | 52 +++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala index ceb9e318b283b..7b1314bc8c3c0 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala @@ -134,6 +134,8 @@ private[kafka010] case class InternalKafkaConsumer( /** Reset the internal pre-fetched data. */ def reset(): Unit = { _records = ju.Collections.emptyListIterator() + _nextOffsetInFetchedData = UNKNOWN_OFFSET + _offsetAfterPoll = UNKNOWN_OFFSET } /** @@ -361,8 +363,9 @@ private[kafka010] case class InternalKafkaConsumer( if (offset < fetchedData.offsetAfterPoll) { // Offsets in [offset, fetchedData.offsetAfterPoll) are invisible. Return a record to ask // the next call to start from `fetchedData.offsetAfterPoll`. + val nextOffsetToFetch = fetchedData.offsetAfterPoll fetchedData.reset() - return fetchedRecord.withRecord(null, fetchedData.offsetAfterPoll) + return fetchedRecord.withRecord(null, nextOffsetToFetch) } else { // Fetch records from Kafka and update `fetchedData`. fetchData(offset, pollTimeoutMs) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index e5f008804ee5b..39c2cde7de40d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -874,6 +874,58 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { ) } } + + test("SPARK-25495: FetchedData.reset should reset all fields") { + val topic = newTopic() + val topicPartition = new TopicPartition(topic, 0) + testUtils.createTopic(topic, partitions = 1) + + val ds = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("kafka.isolation.level", "read_committed") + .option("subscribe", topic) + .option("startingOffsets", "earliest") + .load() + .select($"value".as[String]) + + testUtils.withTranscationalProducer { producer => + producer.beginTransaction() + (0 to 3).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.commitTransaction() + } + testUtils.waitUntilOffsetAppears(topicPartition, 5) + + val q = ds.writeStream.foreachBatch { (ds, epochId) => + if (epochId == 0) { + // Send more message before the tasks of the current batch start reading the current batch + // data, so that the executors will prefetch messages in the next batch and drop them. In + // this case, if we forget to reset `FetchedData._nextOffsetInFetchedData` or + // `FetchedData._offsetAfterPoll` (See SPARK-25495), the next batch will see incorrect + // values and return wrong results hence fail the test. + testUtils.withTranscationalProducer { producer => + producer.beginTransaction() + (4 to 7).foreach { i => + producer.send(new ProducerRecord[String, String](topic, i.toString)).get() + } + producer.commitTransaction() + } + testUtils.waitUntilOffsetAppears(topicPartition, 10) + checkDatasetUnorderly(ds, (0 to 3).map(_.toString): _*) + } else { + checkDatasetUnorderly(ds, (4 to 7).map(_.toString): _*) + } + }.start() + try { + q.processAllAvailable() + } finally { + q.stop() + } + } } From 9bb3a0c67bd851b09ff4701ef1d280e2a77d791b Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 26 Sep 2018 08:45:27 +0800 Subject: [PATCH 1693/2461] [SPARK-25422][CORE] Don't memory map blocks streamed to disk. After data has been streamed to disk, the buffers are inserted into the memory store in some cases (eg., with broadcast blocks). But broadcast code also disposes of those buffers when the data has been read, to ensure that we don't leave mapped buffers using up memory, which then leads to garbage data in the memory store. ## How was this patch tested? Ran the old failing test in a loop. Full tests on jenkins Closes #22546 from squito/SPARK-25422-master. Authored-by: Imran Rashid Signed-off-by: Wenchen Fan --- .../apache/spark/storage/BlockManager.scala | 13 ++--- .../spark/util/io/ChunkedByteBuffer.scala | 47 ++++++++++--------- 2 files changed, 31 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 22341467add5c..0fe82ac0cedc5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -438,10 +438,8 @@ private[spark] class BlockManager( // stream. channel.close() // TODO SPARK-25035 Even if we're only going to write the data to disk after this, we end up - // using a lot of memory here. With encryption, we'll read the whole file into a regular - // byte buffer and OOM. Without encryption, we'll memory map the file and won't get a jvm - // OOM, but might get killed by the OS / cluster manager. We could at least read the tmp - // file as a stream in both cases. + // using a lot of memory here. We'll read the whole file into a regular + // byte buffer and OOM. We could at least read the tmp file as a stream. val buffer = securityManager.getIOEncryptionKey() match { case Some(key) => // we need to pass in the size of the unencrypted block @@ -453,7 +451,7 @@ private[spark] class BlockManager( new EncryptedBlockData(tmpFile, blockSize, conf, key).toChunkedByteBuffer(allocator) case None => - ChunkedByteBuffer.map(tmpFile, conf.get(config.MEMORY_MAP_LIMIT_FOR_TESTS).toInt) + ChunkedByteBuffer.fromFile(tmpFile, conf.get(config.MEMORY_MAP_LIMIT_FOR_TESTS).toInt) } putBytes(blockId, buffer, level)(classTag) tmpFile.delete() @@ -726,10 +724,9 @@ private[spark] class BlockManager( */ def getRemoteBytes(blockId: BlockId): Option[ChunkedByteBuffer] = { // TODO if we change this method to return the ManagedBuffer, then getRemoteValues - // could just use the inputStream on the temp file, rather than memory-mapping the file. + // could just use the inputStream on the temp file, rather than reading the file into memory. // Until then, replication can cause the process to use too much memory and get killed - // by the OS / cluster manager (not a java OOM, since it's a memory-mapped file) even though - // we've read the data to disk. + // even though we've read the data to disk. logDebug(s"Getting remote block $blockId") require(blockId != null, "BlockId is null") var runningFailureCount = 0 diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 39f050f6ca5ad..4aa8d45ec7404 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -19,17 +19,16 @@ package org.apache.spark.util.io import java.io.{File, FileInputStream, InputStream} import java.nio.ByteBuffer -import java.nio.channels.{FileChannel, WritableByteChannel} -import java.nio.file.StandardOpenOption - -import scala.collection.mutable.ListBuffer +import java.nio.channels.WritableByteChannel +import com.google.common.io.ByteStreams import com.google.common.primitives.UnsignedBytes +import org.apache.commons.io.IOUtils import org.apache.spark.SparkEnv import org.apache.spark.internal.config import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.util.ByteArrayWritableChannel +import org.apache.spark.network.util.{ByteArrayWritableChannel, LimitedInputStream} import org.apache.spark.storage.StorageUtils import org.apache.spark.util.Utils @@ -175,30 +174,36 @@ object ChunkedByteBuffer { def fromManagedBuffer(data: ManagedBuffer, maxChunkSize: Int): ChunkedByteBuffer = { data match { case f: FileSegmentManagedBuffer => - map(f.getFile, maxChunkSize, f.getOffset, f.getLength) + fromFile(f.getFile, maxChunkSize, f.getOffset, f.getLength) case other => new ChunkedByteBuffer(other.nioByteBuffer()) } } - def map(file: File, maxChunkSize: Int): ChunkedByteBuffer = { - map(file, maxChunkSize, 0, file.length()) + def fromFile(file: File, maxChunkSize: Int): ChunkedByteBuffer = { + fromFile(file, maxChunkSize, 0, file.length()) } - def map(file: File, maxChunkSize: Int, offset: Long, length: Long): ChunkedByteBuffer = { - Utils.tryWithResource(FileChannel.open(file.toPath, StandardOpenOption.READ)) { channel => - var remaining = length - var pos = offset - val chunks = new ListBuffer[ByteBuffer]() - while (remaining > 0) { - val chunkSize = math.min(remaining, maxChunkSize) - val chunk = channel.map(FileChannel.MapMode.READ_ONLY, pos, chunkSize) - pos += chunkSize - remaining -= chunkSize - chunks += chunk - } - new ChunkedByteBuffer(chunks.toArray) + private def fromFile( + file: File, + maxChunkSize: Int, + offset: Long, + length: Long): ChunkedByteBuffer = { + // We do *not* memory map the file, because we may end up putting this into the memory store, + // and spark currently is not expecting memory-mapped buffers in the memory store, it conflicts + // with other parts that manage the lifecyle of buffers and dispose them. See SPARK-25422. + val is = new FileInputStream(file) + ByteStreams.skipFully(is, offset) + val in = new LimitedInputStream(is, length) + val chunkSize = math.min(maxChunkSize, length).toInt + val out = new ChunkedByteBufferOutputStream(chunkSize, ByteBuffer.allocate _) + Utils.tryWithSafeFinally { + IOUtils.copy(in, out) + } { + in.close() + out.close() } + out.toChunkedByteBuffer } } From 8c2edf46d0f89e5ec54968218d89f30a3f8190bc Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 26 Sep 2018 09:32:51 +0800 Subject: [PATCH 1694/2461] [SPARK-24324][PYTHON][FOLLOW-UP] Rename the Conf to spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName ## What changes were proposed in this pull request? Add the legacy prefix for spark.sql.execution.pandas.groupedMap.assignColumnsByPosition and rename it to spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName ## How was this patch tested? The existing tests. Closes #22540 from gatorsmile/renameAssignColumnsByPosition. Lead-authored-by: gatorsmile Co-authored-by: Hyukjin Kwon Signed-off-by: hyukjinkwon --- python/pyspark/sql/tests.py | 3 ++- python/pyspark/worker.py | 7 ++++--- .../apache/spark/sql/internal/SQLConf.scala | 18 +++++++++--------- .../spark/sql/execution/arrow/ArrowUtils.scala | 9 +++------ 4 files changed, 18 insertions(+), 19 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b829baeca4775..74642d46d1cd1 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -5802,7 +5802,8 @@ def test_positional_assignment_conf(self): import pandas as pd from pyspark.sql.functions import pandas_udf, PandasUDFType - with self.sql_conf({"spark.sql.execution.pandas.groupedMap.assignColumnsByPosition": True}): + with self.sql_conf({ + "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}): @pandas_udf("a string, b float", PandasUDFType.GROUPED_MAP) def foo(_): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 974344f01d923..8c59f1f999f18 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -97,8 +97,9 @@ def verify_result_length(*a): def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf): - assign_cols_by_pos = runner_conf.get( - "spark.sql.execution.pandas.groupedMap.assignColumnsByPosition", False) + assign_cols_by_name = runner_conf.get( + "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true") + assign_cols_by_name = assign_cols_by_name.lower() == "true" def wrapped(key_series, value_series): import pandas as pd @@ -119,7 +120,7 @@ def wrapped(key_series, value_series): "Expected: {} Actual: {}".format(len(return_type), len(result.columns))) # Assign result columns by schema name if user labeled with strings, else use position - if not assign_cols_by_pos and any(isinstance(name, basestring) for name in result.columns): + if assign_cols_by_name and any(isinstance(name, basestring) for name in result.columns): return [(result[field.name], to_arrow_type(field.dataType)) for field in return_type] else: return [(result[result.columns[i]], to_arrow_type(field.dataType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0e0a01def357e..e7c9a83798907 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1295,15 +1295,15 @@ object SQLConf { .booleanConf .createWithDefault(true) - val PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION = - buildConf("spark.sql.execution.pandas.groupedMap.assignColumnsByPosition") + val PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME = + buildConf("spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName") .internal() - .doc("When true, a grouped map Pandas UDF will assign columns from the returned " + - "Pandas DataFrame based on position, regardless of column label type. When false, " + - "columns will be looked up by name if labeled with a string and fallback to use " + - "position if not. This configuration will be deprecated in future releases.") + .doc("When true, columns will be looked up by name if labeled with a string and fallback " + + "to use position if not. When false, a grouped map Pandas UDF will assign columns from " + + "the returned Pandas DataFrame based on position, regardless of column label type. " + + "This configuration will be deprecated in future releases.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter") .internal() @@ -1915,8 +1915,8 @@ class SQLConf extends Serializable with Logging { def pandasRespectSessionTimeZone: Boolean = getConf(PANDAS_RESPECT_SESSION_LOCAL_TIMEZONE) - def pandasGroupedMapAssignColumnssByPosition: Boolean = - getConf(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION) + def pandasGroupedMapAssignColumnsByName: Boolean = + getConf(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME) def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index 533097ac399e9..b1e8fb39ac9de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -131,11 +131,8 @@ object ArrowUtils { } else { Nil } - val pandasColsByPosition = if (conf.pandasGroupedMapAssignColumnssByPosition) { - Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION.key -> "true") - } else { - Nil - } - Map(timeZoneConf ++ pandasColsByPosition: _*) + val pandasColsByName = Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key -> + conf.pandasGroupedMapAssignColumnsByName.toString) + Map(timeZoneConf ++ pandasColsByName: _*) } } From cb77a6689137916e64bc5692b0c942e86ca1a0ea Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 26 Sep 2018 09:37:44 +0800 Subject: [PATCH 1695/2461] [SPARK-21291][R] add R partitionBy API in DataFrame ## What changes were proposed in this pull request? add R partitionBy API in write.df I didn't add bucketBy in write.df. The last line of write.df is ``` write <- handledCallJMethod(write, "save") ``` save doesn't support bucketBy right now. ``` assertNotBucketed("save") ``` ## How was this patch tested? Add unit test in test_sparkSQL.R Closes #22537 from huaxingao/spark-21291. Authored-by: Huaxin Gao Signed-off-by: hyukjinkwon --- R/pkg/R/DataFrame.R | 17 +++++++++++++++-- R/pkg/tests/fulltests/test_sparkSQL.R | 8 ++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index a1cb4781f4d0a..34691883bc5a9 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2954,6 +2954,9 @@ setMethod("exceptAll", #' @param source a name for external data source. #' @param mode one of 'append', 'overwrite', 'error', 'errorifexists', 'ignore' #' save mode (it is 'error' by default) +#' @param partitionBy a name or a list of names of columns to partition the output by on the file +#' system. If specified, the output is laid out on the file system similar +#' to Hive's partitioning scheme. #' @param ... additional argument(s) passed to the method. #' #' @family SparkDataFrame functions @@ -2965,13 +2968,13 @@ setMethod("exceptAll", #' sparkR.session() #' path <- "path/to/file.json" #' df <- read.json(path) -#' write.df(df, "myfile", "parquet", "overwrite") +#' write.df(df, "myfile", "parquet", "overwrite", partitionBy = c("col1", "col2")) #' saveDF(df, parquetPath2, "parquet", mode = "append", mergeSchema = TRUE) #' } #' @note write.df since 1.4.0 setMethod("write.df", signature(df = "SparkDataFrame"), - function(df, path = NULL, source = NULL, mode = "error", ...) { + function(df, path = NULL, source = NULL, mode = "error", partitionBy = NULL, ...) { if (!is.null(path) && !is.character(path)) { stop("path should be character, NULL or omitted.") } @@ -2985,8 +2988,18 @@ setMethod("write.df", if (is.null(source)) { source <- getDefaultSqlSource() } + cols <- NULL + if (!is.null(partitionBy)) { + if (!all(sapply(partitionBy, function(c) is.character(c)))) { + stop("All partitionBy column names should be characters.") + } + cols <- as.list(partitionBy) + } write <- callJMethod(df@sdf, "write") write <- callJMethod(write, "format", source) + if (!is.null(cols)) { + write <- callJMethod(write, "partitionBy", cols) + } write <- setWriteOptions(write, path = path, mode = mode, ...) write <- handledCallJMethod(write, "save") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index a874bfbb58dc7..50eff3755edf8 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -2701,8 +2701,16 @@ test_that("read/write text files", { expect_equal(colnames(df2), c("value")) expect_equal(count(df2), count(df) * 2) + df3 <- createDataFrame(list(list(1L, "1"), list(2L, "2"), list(1L, "1"), list(2L, "2")), + schema = c("key", "value")) + textPath3 <- tempfile(pattern = "textPath3", fileext = ".txt") + write.df(df3, textPath3, "text", mode = "overwrite", partitionBy = "key") + df4 <- read.df(textPath3, "text") + expect_equal(count(df3), count(df4)) + unlink(textPath) unlink(textPath2) + unlink(textPath3) }) test_that("read/write text files - compression option", { From 473d0d862de54ec1c7a8f0354fa5e06f3d66e455 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 26 Sep 2018 09:52:15 +0800 Subject: [PATCH 1696/2461] [SPARK-25514][SQL] Generating pretty JSON by to_json ## What changes were proposed in this pull request? The PR introduces new JSON option `pretty` which allows to turn on `DefaultPrettyPrinter` of `Jackson`'s Json generator. New option is useful in exploring of deep nested columns and in converting of JSON columns in more readable representation (look at the added test). ## How was this patch tested? Added rount trip test which convert an JSON string to pretty representation via `from_json()` and `to_json()`. Closes #22534 from MaxGekk/pretty-json. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- R/pkg/R/functions.R | 5 +++-- python/pyspark/sql/functions.py | 4 +++- .../spark/sql/catalyst/json/JSONOptions.scala | 5 +++++ .../sql/catalyst/json/JacksonGenerator.scala | 5 ++++- .../org/apache/spark/sql/functions.scala | 4 ++++ .../apache/spark/sql/JsonFunctionsSuite.scala | 21 +++++++++++++++++++ 6 files changed, 40 insertions(+), 4 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 572dee50127b8..6425c9d26bef3 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -198,8 +198,9 @@ NULL #' } #' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains #' additional named properties to control how it is converted, accepts the same -#' options as the JSON data source. In \code{arrays_zip}, this contains additional -#' Columns of arrays to be merged. +#' options as the JSON data source. Additionally \code{to_json} supports the "pretty" +#' option which enables pretty JSON generation. In \code{arrays_zip}, this contains +#' additional Columns of arrays to be merged. #' @name column_collection_functions #' @rdname column_collection_functions #' @family collection functions diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 6da5237d18de4..1c3d9725b285b 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2295,7 +2295,9 @@ def to_json(col, options={}): into a JSON string. Throws an exception, in the case of an unsupported type. :param col: name of column containing a struct, an array or a map. - :param options: options to control converting. accepts the same options as the JSON datasource + :param options: options to control converting. accepts the same options as the JSON datasource. + Additionally the function supports the `pretty` option which enables + pretty JSON generation. >>> from pyspark.sql import Row >>> from pyspark.sql.types import * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 47eeb70e00427..64152e04928d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -113,6 +113,11 @@ private[sql] class JSONOptions( } val lineSeparatorInWrite: String = lineSeparator.getOrElse("\n") + /** + * Generating JSON strings in pretty representation if the parameter is enabled. + */ + val pretty: Boolean = parameters.get("pretty").map(_.toBoolean).getOrElse(false) + /** Sets config options on a Jackson [[JsonFactory]]. */ def setJacksonOptions(factory: JsonFactory): Unit = { factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index 9b86d865622dc..d02a2be8ddad6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -70,7 +70,10 @@ private[sql] class JacksonGenerator( s"Initial type ${dataType.catalogString} must be a ${MapType.simpleString}") } - private val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) + private val gen = { + val generator = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) + if (options.pretty) generator.useDefaultPrettyPrinter() else generator + } private val lineSeparator: String = options.lineSeparatorInWrite diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 10b67d7a1ca54..4c58e77df485e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3619,6 +3619,8 @@ object functions { * @param e a column containing a struct, an array or a map. * @param options options to control how the struct column is converted into a json string. * accepts the same options and the json data source. + * Additionally the function supports the `pretty` option which enables + * pretty JSON generation. * * @group collection_funcs * @since 2.1.0 @@ -3635,6 +3637,8 @@ object functions { * @param e a column containing a struct, an array or a map. * @param options options to control how the struct column is converted into a json string. * accepts the same options and the json data source. + * Additionally the function supports the `pretty` option which enables + * pretty JSON generation. * * @group collection_funcs * @since 2.1.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index fe4bf15fa3921..853bc182f2f4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -518,4 +518,25 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { jsonDF.select(to_json(from_json($"a", schema))), Seq(Row(json))) } + + test("pretty print - roundtrip from_json -> to_json") { + val json = """[{"book":{"publisher":[{"country":"NL","year":[1981,1986,1999]}]}}]""" + val jsonDF = Seq(json).toDF("root") + val expected = + """[ { + | "book" : { + | "publisher" : [ { + | "country" : "NL", + | "year" : [ 1981, 1986, 1999 ] + | } ] + | } + |} ]""".stripMargin + + checkAnswer( + jsonDF.select( + to_json( + from_json($"root", schema_of_json(lit(json))), + Map("pretty" -> "true"))), + Seq(Row(expected))) + } } From 81cbcca60099fd267492769b465d01e90d7deeac Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 25 Sep 2018 23:03:54 -0700 Subject: [PATCH 1697/2461] [SPARK-25534][SQL] Make `SQLHelper` trait ## What changes were proposed in this pull request? Currently, Spark has 7 `withTempPath` and 6 `withSQLConf` functions. This PR aims to remove duplicated and inconsistent code and reduce them to the following meaningful implementations. **withTempPath** - `SQLHelper.withTempPath`: The one which was used in `SQLTestUtils`. **withSQLConf** - `SQLHelper.withSQLConf`: The one which was used in `PlanTest`. - `ExecutorSideSQLConfSuite.withSQLConf`: The one which doesn't throw `AnalysisException` on StaticConf changes. - `SQLTestUtils.withSQLConf`: The one which overrides intentionally to change the active session. ```scala protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { SparkSession.setActiveSession(spark) super.withSQLConf(pairs: _*)(f) } ``` ## How was this patch tested? Pass the Jenkins with the existing tests. Closes #22548 from dongjoon-hyun/SPARK-25534. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../spark/sql/catalyst/plans/PlanTest.scala | 31 +-------- .../spark/sql/catalyst/plans/SQLHelper.scala | 64 +++++++++++++++++++ .../benchmark/DataSourceReadBenchmark.scala | 23 +------ .../benchmark/FilterPushdownBenchmark.scala | 24 +------ .../datasources/csv/CSVBenchmarks.scala | 12 +--- .../datasources/json/JsonBenchmarks.scala | 11 +--- .../CheckpointFileManagerSuite.scala | 10 +-- .../apache/spark/sql/test/SQLTestUtils.scala | 13 ---- .../spark/sql/hive/orc/OrcReadBenchmark.scala | 25 ++------ 9 files changed, 81 insertions(+), 132 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SQLHelper.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 67740c3166471..3081ff935f043 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -22,7 +22,6 @@ import org.scalatest.Suite import org.scalatest.Tag import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode @@ -57,7 +56,7 @@ trait CodegenInterpretedPlanTest extends PlanTest { * Provides helper methods for comparing plans, but without the overhead of * mandating a FunSuite. */ -trait PlanTestBase extends PredicateHelper { self: Suite => +trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite => // TODO(gatorsmile): remove this from PlanTest and all the analyzer rules protected def conf = SQLConf.get @@ -174,32 +173,4 @@ trait PlanTestBase extends PredicateHelper { self: Suite => plan1 == plan2 } } - - /** - * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL - * configurations. - */ - protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val conf = SQLConf.get - val (keys, values) = pairs.unzip - val currentValues = keys.map { key => - if (conf.contains(key)) { - Some(conf.getConfString(key)) - } else { - None - } - } - (keys, values).zipped.foreach { (k, v) => - if (SQLConf.staticConfKeys.contains(k)) { - throw new AnalysisException(s"Cannot modify the value of a static config: $k") - } - conf.setConfString(k, v) - } - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => conf.setConfString(key, value) - case (key, None) => conf.unsetConf(key) - } - } - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SQLHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SQLHelper.scala new file mode 100644 index 0000000000000..4d869d79ad594 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SQLHelper.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.plans + +import java.io.File + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + +trait SQLHelper { + + /** + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL + * configurations. + */ + protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val conf = SQLConf.get + val (keys, values) = pairs.unzip + val currentValues = keys.map { key => + if (conf.contains(key)) { + Some(conf.getConfString(key)) + } else { + None + } + } + (keys, values).zipped.foreach { (k, v) => + if (SQLConf.staticConfKeys.contains(k)) { + throw new AnalysisException(s"Cannot modify the value of a static config: $k") + } + conf.setConfString(k, v) + } + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => conf.setConfString(key, value) + case (key, None) => conf.unsetConf(key) + } + } + } + + /** + * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If + * a file/directory is created there by `f`, it will be delete after `f` returns. + */ + protected def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala index cf9bda2fb1ff1..51a7f9f1ef096 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala @@ -19,17 +19,17 @@ package org.apache.spark.sql.execution.benchmark import java.io.File import scala.collection.JavaConverters._ -import scala.util.{Random, Try} +import scala.util.Random import org.apache.spark.SparkConf import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.execution.datasources.parquet.{SpecificParquetRecordReaderBase, VectorizedParquetRecordReader} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnVector -import org.apache.spark.util.Utils /** @@ -37,7 +37,7 @@ import org.apache.spark.util.Utils * To run this: * spark-submit --class */ -object DataSourceReadBenchmark { +object DataSourceReadBenchmark extends SQLHelper { val conf = new SparkConf() .setAppName("DataSourceReadBenchmark") // Since `spark.master` always exists, overrides this value @@ -54,27 +54,10 @@ object DataSourceReadBenchmark { spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") - def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - def withTempTable(tableNames: String*)(f: => Unit): Unit = { try f finally tableNames.foreach(spark.catalog.dropTempView) } - def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) - (keys, values).zipped.foreach(spark.conf.set) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => spark.conf.set(key, value) - case (key, None) => spark.conf.unset(key) - } - } - } private def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = { val testDf = if (partition.isDefined) { df.write.partitionBy(partition.get) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala index 3b7f10783b64c..7cdf653e38697 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala @@ -19,16 +19,16 @@ package org.apache.spark.sql.execution.benchmark import java.io.File -import scala.util.{Random, Try} +import scala.util.Random import org.apache.spark.SparkConf import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.functions.monotonically_increasing_id import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, TimestampType} -import org.apache.spark.util.Utils /** * Benchmark to measure read performance with Filter pushdown. @@ -40,7 +40,7 @@ import org.apache.spark.util.Utils * Results will be written to "benchmarks/FilterPushdownBenchmark-results.txt". * }}} */ -object FilterPushdownBenchmark extends BenchmarkBase { +object FilterPushdownBenchmark extends BenchmarkBase with SQLHelper { private val conf = new SparkConf() .setAppName(this.getClass.getSimpleName) @@ -60,28 +60,10 @@ object FilterPushdownBenchmark extends BenchmarkBase { private val spark = SparkSession.builder().config(conf).getOrCreate() - def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - def withTempTable(tableNames: String*)(f: => Unit): Unit = { try f finally tableNames.foreach(spark.catalog.dropTempView) } - def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) - (keys, values).zipped.foreach(spark.conf.set) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => spark.conf.set(key, value) - case (key, None) => spark.conf.unset(key) - } - } - } - private def prepareTable( dir: File, numRows: Int, width: Int, useStringForValue: Boolean): Unit = { import spark.implicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala index 6d319eb723d93..5d1a874999c09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala @@ -16,21 +16,19 @@ */ package org.apache.spark.sql.execution.datasources.csv -import java.io.File - import org.apache.spark.SparkConf import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.{Column, Row, SparkSession} +import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.functions.lit import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils /** * Benchmark to measure CSV read/write performance. * To run this: * spark-submit --class --jars */ -object CSVBenchmarks { +object CSVBenchmarks extends SQLHelper { val conf = new SparkConf() val spark = SparkSession.builder @@ -40,12 +38,6 @@ object CSVBenchmarks { .getOrCreate() import spark.implicits._ - def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - def quotedValuesBenchmark(rowsNum: Int, numIters: Int): Unit = { val benchmark = new Benchmark(s"Parsing quoted values", rowsNum) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala index e40cb9b50148b..368318ab38cb9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala @@ -21,16 +21,16 @@ import java.io.File import org.apache.spark.SparkConf import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.functions.lit import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils /** * The benchmarks aims to measure performance of JSON parsing when encoding is set and isn't. * To run this: * spark-submit --class --jars */ -object JSONBenchmarks { +object JSONBenchmarks extends SQLHelper { val conf = new SparkConf() val spark = SparkSession.builder @@ -40,13 +40,6 @@ object JSONBenchmarks { .getOrCreate() import spark.implicits._ - def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - - def schemaInferring(rowsNum: Int): Unit = { val benchmark = new Benchmark("JSON schema inferring", rowsNum) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala index fe59cb25d5005..cbac1c13cdd33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala @@ -25,12 +25,12 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.util.Utils -abstract class CheckpointFileManagerTests extends SparkFunSuite { +abstract class CheckpointFileManagerTests extends SparkFunSuite with SQLHelper { def createManager(path: Path): CheckpointFileManager @@ -88,12 +88,6 @@ abstract class CheckpointFileManagerTests extends SparkFunSuite { fm.delete(path) // should not throw exception } } - - protected def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } } class CheckpointFileManagerSuite extends SparkFunSuite with SharedSparkSession { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 2fb8f70a20791..6b03d1e5b7662 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -40,7 +40,6 @@ import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.FilterExec -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.UninterruptibleThread import org.apache.spark.util.Utils @@ -167,18 +166,6 @@ private[sql] trait SQLTestUtilsBase super.withSQLConf(pairs: _*)(f) } - /** - * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If - * a file/directory is created there by `f`, it will be delete after `f` returns. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - /** * Copy file in jar's resource to a temp file, then pass it to `f`. * This function is used to make `f` can use the path of temp file(e.g. file:/), instead of diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala index 0eab7d1ea8e80..49de007df3828 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala @@ -19,14 +19,15 @@ package org.apache.spark.sql.hive.orc import java.io.File -import scala.util.{Random, Try} +import scala.util.Random import org.apache.spark.SparkConf import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils + /** * Benchmark to measure ORC read performance. @@ -34,7 +35,7 @@ import org.apache.spark.util.Utils * This is in `sql/hive` module in order to compare `sql/core` and `sql/hive` ORC data sources. */ // scalastyle:off line.size.limit -object OrcReadBenchmark { +object OrcReadBenchmark extends SQLHelper { val conf = new SparkConf() conf.set("orc.compression", "snappy") @@ -47,28 +48,10 @@ object OrcReadBenchmark { // Set default configs. Individual cases will change them if necessary. spark.conf.set(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key, "true") - def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - def withTempTable(tableNames: String*)(f: => Unit): Unit = { try f finally tableNames.foreach(spark.catalog.dropTempView) } - def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) - (keys, values).zipped.foreach(spark.conf.set) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => spark.conf.set(key, value) - case (key, None) => spark.conf.unset(key) - } - } - } - private val NATIVE_ORC_FORMAT = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName private val HIVE_ORC_FORMAT = classOf[org.apache.spark.sql.hive.orc.OrcFileFormat].getCanonicalName From b39e228ce8f771c6e6198b9ccd8665a68a25b857 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 26 Sep 2018 19:41:45 +0800 Subject: [PATCH 1698/2461] [SPARK-25541][SQL] CaseInsensitiveMap should be serializable after '-' or 'filterKeys' ## What changes were proposed in this pull request? `CaseInsensitiveMap` is declared as Serializable. However, it is no serializable after `-` operator or `filterKeys` method. This PR fix the issue by overriding the operator `-` and method `filterKeys`. So the we can avoid potential `NotSerializableException` on using `CaseInsensitiveMap`. ## How was this patch tested? New test suite. Closes #22553 from gengliangwang/fixCaseInsensitiveMap. Authored-by: Gengliang Wang Signed-off-by: Wenchen Fan --- .../catalyst/util/CaseInsensitiveMap.scala | 6 ++- .../util/CaseInsensitiveMapSuite.scala | 53 +++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMapSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala index bb2c5926ae9bb..288a4f34a447e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala @@ -42,7 +42,11 @@ class CaseInsensitiveMap[T] private (val originalMap: Map[String, T]) extends Ma override def iterator: Iterator[(String, T)] = keyLowerCasedMap.iterator override def -(key: String): Map[String, T] = { - new CaseInsensitiveMap(originalMap.filterKeys(!_.equalsIgnoreCase(key))) + new CaseInsensitiveMap(originalMap.filter(!_._1.equalsIgnoreCase(key))) + } + + override def filterKeys(p: (String) => Boolean): Map[String, T] = { + new CaseInsensitiveMap(originalMap.filter(kv => p(kv._1))) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMapSuite.scala new file mode 100644 index 0000000000000..03eed4aaa750b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMapSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.JavaSerializer + +class CaseInsensitiveMapSuite extends SparkFunSuite { + private def shouldBeSerializable(m: Map[String, String]): Unit = { + new JavaSerializer(new SparkConf()).newInstance().serialize(m) + } + + test("Keys are case insensitive") { + val m = CaseInsensitiveMap(Map("a" -> "b", "foO" -> "bar")) + assert(m("FOO") == "bar") + assert(m("fOo") == "bar") + assert(m("A") == "b") + shouldBeSerializable(m) + } + + test("CaseInsensitiveMap should be serializable after '-' operator") { + val m = CaseInsensitiveMap(Map("a" -> "b", "foo" -> "bar")) - "a" + assert(m == Map("foo" -> "bar")) + shouldBeSerializable(m) + } + + test("CaseInsensitiveMap should be serializable after '+' operator") { + val m = CaseInsensitiveMap(Map("a" -> "b", "foo" -> "bar")) + ("x" -> "y") + assert(m == Map("a" -> "b", "foo" -> "bar", "x" -> "y")) + shouldBeSerializable(m) + } + + test("CaseInsensitiveMap should be serializable after 'filterKeys' method") { + val m = CaseInsensitiveMap(Map("a" -> "b", "foo" -> "bar")).filterKeys(_ == "foo") + assert(m == Map("foo" -> "bar")) + shouldBeSerializable(m) + } +} From 44a71741d510484b787855986cec970ac0cb5da8 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 26 Sep 2018 21:34:18 +0800 Subject: [PATCH 1699/2461] [SPARK-25379][SQL] Improve AttributeSet and ColumnPruning performance ## What changes were proposed in this pull request? This PR contains 3 optimizations: 1) it improves significantly the operation `--` on `AttributeSet`. As a benchmark for the `--` operation, the following code has been run ``` test("AttributeSet -- benchmark") { val attrSetA = AttributeSet((1 to 100).map { i => AttributeReference(s"c$i", IntegerType)() }) val attrSetB = AttributeSet(attrSetA.take(80).toSeq) val attrSetC = AttributeSet((1 to 100).map { i => AttributeReference(s"c2_$i", IntegerType)() }) val attrSetD = AttributeSet((attrSetA.take(50) ++ attrSetC.take(50)).toSeq) val attrSetE = AttributeSet((attrSetC.take(50) ++ attrSetA.take(50)).toSeq) val n_iter = 1000000 val t0 = System.nanoTime() (1 to n_iter) foreach { _ => val r1 = attrSetA -- attrSetB val r2 = attrSetA -- attrSetC val r3 = attrSetA -- attrSetD val r4 = attrSetA -- attrSetE } val t1 = System.nanoTime() val totalTime = t1 - t0 println(s"Average time: ${totalTime / n_iter} us") } ``` The results are: ``` Before PR - Average time: 67674 us (100 %) After PR - Average time: 28827 us (42.6 %) ``` 2) In `ColumnPruning`, it replaces the occurrences of `(attributeSet1 -- attributeSet2).nonEmpty` with `attributeSet1.subsetOf(attributeSet2)` which is order of magnitudes more efficient (especially where there are many attributes). Running the previous benchmark replacing `--` with `subsetOf` returns: ``` Average time: 67 us (0.1 %) ``` 3) Provides a more efficient way of building `AttributeSet`s, which can greatly improve the performance of the methods `references` and `outputSet` of `Expression` and `QueryPlan`. This basically avoids unneeded operations (eg. creating many `AttributeEqual` wrapper classes which could be avoided) The overall effect of those optimizations has been tested on `ColumnPruning` with the following benchmark: ``` test("ColumnPruning benchmark") { val attrSetA = (1 to 100).map { i => AttributeReference(s"c$i", IntegerType)() } val attrSetB = attrSetA.take(80) val attrSetC = attrSetA.take(20).map(a => Alias(Add(a, Literal(1)), s"${a.name}_1")()) val input = LocalRelation(attrSetA) val query1 = Project(attrSetB, Project(attrSetA, input)).analyze val query2 = Project(attrSetC, Project(attrSetA, input)).analyze val query3 = Project(attrSetA, Project(attrSetA, input)).analyze val nIter = 100000 val t0 = System.nanoTime() (1 to nIter).foreach { _ => ColumnPruning(query1) ColumnPruning(query2) ColumnPruning(query3) } val t1 = System.nanoTime() val totalTime = t1 - t0 println(s"Average time: ${totalTime / nIter} us") } ``` The output of the test is: ``` Before PR - Average time: 733471 us (100 %) After PR - Average time: 362455 us (49.4 %) ``` The performance improvement has been evaluated also on the `SQLQueryTestSuite`'s queries: ``` (before) org.apache.spark.sql.catalyst.optimizer.ColumnPruning 518413198 / 1377707172 2756 / 15717 (after) org.apache.spark.sql.catalyst.optimizer.ColumnPruning 415432579 / 1121147950 2756 / 15717 % Running time 80.1% / 81.3% ``` Also other rules benefit especially from (3), despite the impact is lower, eg: ``` (before) org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveReferences 307341442 / 623436806 2154 / 16480 (after) org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveReferences 290511312 / 560962495 2154 / 16480 % Running time 94.5% / 90.0% ``` The reason why the impact on the `SQLQueryTestSuite`'s queries is lower compared to the other benchmark is that the optimizations are more significant when the number of attributes involved is higher. Since in the tests we often have very few attributes, the effect there is lower. ## How was this patch tested? run benchmarks + existing UTs Closes #22364 from mgaido91/SPARK-25379. Authored-by: Marco Gaido Signed-off-by: Wenchen Fan --- .../catalyst/expressions/AttributeSet.scala | 23 +++++++++++++----- .../sql/catalyst/expressions/Expression.scala | 2 +- .../sql/catalyst/optimizer/Optimizer.scala | 24 +++++++++---------- .../spark/sql/catalyst/plans/QueryPlan.scala | 2 +- 4 files changed, 31 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index 7420b6b57d8e1..a7e09eee617e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.mutable + protected class AttributeEquals(val a: Attribute) { override def hashCode(): Int = a match { @@ -39,10 +41,13 @@ object AttributeSet { /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */ def apply(baseSet: Iterable[Expression]): AttributeSet = { - new AttributeSet( - baseSet - .flatMap(_.references) - .map(new AttributeEquals(_)).toSet) + fromAttributeSets(baseSet.map(_.references)) + } + + /** Constructs a new [[AttributeSet]] given a sequence of [[AttributeSet]]s. */ + def fromAttributeSets(sets: Iterable[AttributeSet]): AttributeSet = { + val baseSet = sets.foldLeft(new mutable.LinkedHashSet[AttributeEquals]())( _ ++= _.baseSet) + new AttributeSet(baseSet.toSet) } } @@ -94,8 +99,14 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) * Returns a new [[AttributeSet]] that does not contain any of the [[Attribute Attributes]] found * in `other`. */ - def --(other: Traversable[NamedExpression]): AttributeSet = - new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute))) + def --(other: Traversable[NamedExpression]): AttributeSet = { + other match { + case otherSet: AttributeSet => + new AttributeSet(baseSet -- otherSet.baseSet) + case _ => + new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute))) + } + } /** * Returns a new [[AttributeSet]] that contains all of the [[Attribute Attributes]] found diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 773aefc0ac1f9..c215735ab1c98 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -85,7 +85,7 @@ abstract class Expression extends TreeNode[Expression] { def nullable: Boolean - def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator)) + def references: AttributeSet = AttributeSet.fromAttributeSets(children.map(_.references)) /** Returns the result of evaluating this expression on a given input Row */ def eval(input: InternalRow = null): Any diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 7c461895c5e52..07a653f3b5d48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -532,12 +532,12 @@ object ColumnPruning extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = removeProjectBeforeFilter(plan transform { // Prunes the unused columns from project list of Project/Aggregate/Expand - case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty => + case p @ Project(_, p2: Project) if !p2.outputSet.subsetOf(p.references) => p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains))) - case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty => + case p @ Project(_, a: Aggregate) if !a.outputSet.subsetOf(p.references) => p.copy( child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains))) - case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty => + case a @ Project(_, e @ Expand(_, _, grandChild)) if !e.outputSet.subsetOf(a.references) => val newOutput = e.output.filter(a.references.contains(_)) val newProjects = e.projections.map { proj => proj.zip(e.output).filter { case (_, a) => @@ -547,18 +547,18 @@ object ColumnPruning extends Rule[LogicalPlan] { a.copy(child = Expand(newProjects, newOutput, grandChild)) // Prunes the unused columns from child of `DeserializeToObject` - case d @ DeserializeToObject(_, _, child) if (child.outputSet -- d.references).nonEmpty => + case d @ DeserializeToObject(_, _, child) if !child.outputSet.subsetOf(d.references) => d.copy(child = prunedChild(child, d.references)) // Prunes the unused columns from child of Aggregate/Expand/Generate/ScriptTransformation - case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => + case a @ Aggregate(_, _, child) if !child.outputSet.subsetOf(a.references) => a.copy(child = prunedChild(child, a.references)) - case f @ FlatMapGroupsInPandas(_, _, _, child) if (child.outputSet -- f.references).nonEmpty => + case f @ FlatMapGroupsInPandas(_, _, _, child) if !child.outputSet.subsetOf(f.references) => f.copy(child = prunedChild(child, f.references)) - case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => + case e @ Expand(_, _, child) if !child.outputSet.subsetOf(e.references) => e.copy(child = prunedChild(child, e.references)) case s @ ScriptTransformation(_, _, _, child, _) - if (child.outputSet -- s.references).nonEmpty => + if !child.outputSet.subsetOf(s.references) => s.copy(child = prunedChild(child, s.references)) // prune unrequired references @@ -579,7 +579,7 @@ object ColumnPruning extends Rule[LogicalPlan] { case p @ Project(_, _: Distinct) => p // Eliminate unneeded attributes from children of Union. case p @ Project(_, u: Union) => - if ((u.outputSet -- p.references).nonEmpty) { + if (!u.outputSet.subsetOf(p.references)) { val firstChild = u.children.head val newOutput = prunedChild(firstChild, p.references).output // pruning the columns of all children based on the pruned first child. @@ -595,7 +595,7 @@ object ColumnPruning extends Rule[LogicalPlan] { } // Prune unnecessary window expressions - case p @ Project(_, w: Window) if (w.windowOutputSet -- p.references).nonEmpty => + case p @ Project(_, w: Window) if !w.windowOutputSet.subsetOf(p.references) => p.copy(child = w.copy( windowExpressions = w.windowExpressions.filter(p.references.contains))) @@ -611,7 +611,7 @@ object ColumnPruning extends Rule[LogicalPlan] { // for all other logical plans that inherits the output from it's children case p @ Project(_, child) => val required = child.references ++ p.references - if ((child.inputSet -- required).nonEmpty) { + if (!child.inputSet.subsetOf(required)) { val newChildren = child.children.map(c => prunedChild(c, required)) p.copy(child = child.withNewChildren(newChildren)) } else { @@ -621,7 +621,7 @@ object ColumnPruning extends Rule[LogicalPlan] { /** Applies a projection only when the child is producing unnecessary attributes */ private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = - if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) { + if (!c.outputSet.subsetOf(allReferences)) { Project(c.output.filter(allReferences.contains), c) } else { c diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index b1ffdca091461..ca0cea6ba7de3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -42,7 +42,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT * All Attributes that appear in expressions from this operator. Note that this set does not * include attributes that are implicitly referenced by being passed through to the output tuple. */ - def references: AttributeSet = AttributeSet(expressions.flatMap(_.references)) + def references: AttributeSet = AttributeSet.fromAttributeSets(expressions.map(_.references)) /** * The set of all attributes that are input to this operator by its children. From cf5c9c4b550c3a8ed59d7ef9404f2689ea763fa9 Mon Sep 17 00:00:00 2001 From: seancxmao Date: Wed, 26 Sep 2018 22:14:14 +0800 Subject: [PATCH 1700/2461] [SPARK-20937][DOCS] Describe spark.sql.parquet.writeLegacyFormat property in Spark SQL, DataFrames and Datasets Guide ## What changes were proposed in this pull request? Describe spark.sql.parquet.writeLegacyFormat property in Spark SQL, DataFrames and Datasets Guide. ## How was this patch tested? N/A Closes #22453 from seancxmao/SPARK-20937. Authored-by: seancxmao Signed-off-by: hyukjinkwon --- docs/sql-programming-guide.md | 11 +++++++++++ .../scala/org/apache/spark/sql/internal/SQLConf.scala | 7 +++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index c72fa3d75d67f..6de9de90c62c3 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1004,6 +1004,17 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession

      + + + + +
      diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 8724bbc6ca7c5..9fa1577681f03 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1498,8 +1498,7 @@ def test_array_contains_function(self): from pyspark.sql.functions import array_contains df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ['data']) - actual = df.select(array_contains(df.data, 1).alias('b')).collect() - # The value argument can be implicitly castable to the element's type of the array. + actual = df.select(array_contains(df.data, "1").alias('b')).collect() self.assertEqual([Row(b=True), Row(b=False)], actual) def test_between_function(self): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index cc9edcfd41d02..e23ebef9643ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1331,23 +1331,27 @@ case class ArrayContains(left: Expression, right: Expression) @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(right.dataType) - override def inputTypes: Seq[AbstractDataType] = right.dataType match { - case NullType => Seq.empty - case _ => left.dataType match { - case n @ ArrayType(element, _) => Seq(n, element) + override def inputTypes: Seq[AbstractDataType] = { + (left.dataType, right.dataType) match { + case (_, NullType) => Seq.empty + case (ArrayType(e1, hasNull), e2) => + TypeCoercion.findTightestCommonType(e1, e2) match { + case Some(dt) => Seq(ArrayType(dt, hasNull), dt) + case _ => Seq.empty + } case _ => Seq.empty } } override def checkInputDataTypes(): TypeCheckResult = { - if (right.dataType == NullType) { - TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments") - } else if (!left.dataType.isInstanceOf[ArrayType] - || !left.dataType.asInstanceOf[ArrayType].elementType.sameType(right.dataType)) { - TypeCheckResult.TypeCheckFailure( - "Arguments must be an array followed by a value of same type as the array members") - } else { - TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") + (left.dataType, right.dataType) match { + case (_, NullType) => + TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments") + case (ArrayType(e1, _), e2) if e1.sameType(e2) => + TypeUtils.checkForOrderingExpr(e2, s"function $prettyName") + case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + + s"been ${ArrayType.simpleString} followed by a value with same element type, but it's " + + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 121db442c77f7..ad52fd01248e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -26,6 +26,7 @@ import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation import org.apache.spark.sql.catalyst.util.DateTimeTestUtils import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -735,6 +736,56 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_contains(array(1, null), array(1, null)[0])"), Seq(Row(true), Row(true)) ) + + checkAnswer( + OneRowRelation().selectExpr("array_contains(array(1), 1.23D)"), + Seq(Row(false)) + ) + + checkAnswer( + OneRowRelation().selectExpr("array_contains(array(1), 1.0D)"), + Seq(Row(true)) + ) + + checkAnswer( + OneRowRelation().selectExpr("array_contains(array(1.0D), 1)"), + Seq(Row(true)) + ) + + checkAnswer( + OneRowRelation().selectExpr("array_contains(array(1.23D), 1)"), + Seq(Row(false)) + ) + + checkAnswer( + OneRowRelation().selectExpr("array_contains(array(array(1)), array(1.0D))"), + Seq(Row(true)) + ) + + checkAnswer( + OneRowRelation().selectExpr("array_contains(array(array(1)), array(1.23D))"), + Seq(Row(false)) + ) + + val e1 = intercept[AnalysisException] { + OneRowRelation().selectExpr("array_contains(array(1), .01234567890123456790123456780)") + } + val errorMsg1 = + s""" + |Input to function array_contains should have been array followed by a + |value with same element type, but it's [array, decimal(29,29)]. + """.stripMargin.replace("\n", " ").trim() + assert(e1.message.contains(errorMsg1)) + + val e2 = intercept[AnalysisException] { + OneRowRelation().selectExpr("array_contains(array(1), 'foo')") + } + val errorMsg2 = + s""" + |Input to function array_contains should have been array followed by a + |value with same element type, but it's [array, string]. + """.stripMargin.replace("\n", " ").trim() + assert(e2.message.contains(errorMsg2)) } test("arrays_overlap function") { From 88e7e87bd5c052e10f52d4bb97a9d78f5b524128 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 21 Sep 2018 00:41:42 +0800 Subject: [PATCH 1660/2461] [MINOR][PYTHON] Use a helper in `PythonUtils` instead of direct accessing Scala package ## What changes were proposed in this pull request? This PR proposes to use add a helper in `PythonUtils` instead of direct accessing Scala package. ## How was this patch tested? Jenkins tests. Closes #22483 from HyukjinKwon/minor-refactoring. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- .../main/scala/org/apache/spark/api/python/PythonUtils.scala | 4 ++++ python/pyspark/context.py | 4 +--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 27a5e19f96a14..cdce371dfcbfa 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -74,4 +74,8 @@ private[spark] object PythonUtils { def toScalaMap[K, V](jm: java.util.Map[K, V]): Map[K, V] = { jm.asScala.toMap } + + def getEncryptionEnabled(sc: JavaSparkContext): Boolean = { + sc.conf.get(org.apache.spark.internal.config.IO_ENCRYPTION_ENABLED) + } } diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 2c92c29a1cc1b..87255c40e330e 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -192,9 +192,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, # If encryption is enabled, we need to setup a server in the jvm to read broadcast # data via a socket. # scala's mangled names w/ $ in them require special treatment. - encryption_conf = self._jvm.org.apache.spark.internal.config.__getattr__("package$")\ - .__getattr__("MODULE$").IO_ENCRYPTION_ENABLED() - self._encryption_enabled = self._jsc.sc().conf().get(encryption_conf) + self._encryption_enabled = self._jvm.PythonUtils.getEncryptionEnabled(self._jsc) self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') self.pythonVer = "%d.%d" % sys.version_info[:2] From 88446b6ad19371f15d06ef67052f6c1a8072c04a Mon Sep 17 00:00:00 2001 From: maryannxue Date: Thu, 20 Sep 2018 10:00:28 -0700 Subject: [PATCH 1661/2461] [SPARK-25450][SQL] PushProjectThroughUnion rule uses the same exprId for project expressions in each Union child, causing mistakes in constant propagation ## What changes were proposed in this pull request? The problem was cause by the PushProjectThroughUnion rule, which, when creating new Project for each child of Union, uses the same exprId for expressions of the same position. This is wrong because, for each child of Union, the expressions are all independent, and it can lead to a wrong result if other rules like FoldablePropagation kicks in, taking two different expressions as the same. This fix is to create new expressions in the new Project for each child of Union. ## How was this patch tested? Added UT. Closes #22447 from maryannxue/push-project-thru-union-bug. Authored-by: maryannxue Signed-off-by: gatorsmile --- .../sql/catalyst/optimizer/Optimizer.scala | 4 ++ .../PushProjectThroughUnionSuite.scala | 54 +++++++++++++++++++ 2 files changed, 58 insertions(+) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectThroughUnionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b432ce24e1ef7..7c461895c5e52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -486,6 +486,10 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = { val result = e transform { case a: Attribute => rewrites(a) + } match { + // Make sure exprId is unique in each child of Union. + case Alias(child, alias) => Alias(child, alias)() + case other => other } // We must promise the compiler that we did not discard the names in the case of project diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectThroughUnionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectThroughUnionSuite.scala new file mode 100644 index 0000000000000..294d29842b045 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectThroughUnionSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class PushProjectThroughUnionSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("Optimizer Batch", FixedPoint(100), + PushProjectionThroughUnion, + FoldablePropagation) :: Nil + } + + test("SPARK-25450 PushProjectThroughUnion rule uses the same exprId for project expressions " + + "in each Union child, causing mistakes in constant propagation") { + val testRelation1 = LocalRelation('a.string, 'b.int, 'c.string) + val testRelation2 = LocalRelation('d.string, 'e.int, 'f.string) + val query = testRelation1 + .union(testRelation2.select("bar".as("d"), 'e, 'f)) + .select('a.as("n")) + .select('n, "dummy").analyze + val optimized = Optimize.execute(query) + + val expected = testRelation1 + .select('a.as("n")) + .select('n, "dummy") + .union(testRelation2 + .select("bar".as("d"), 'e, 'f) + .select("bar".as("n")) + .select("bar".as("n"), "dummy")).analyze + + comparePlans(optimized, expected) + } +} From a86f84102e10a6ca6325c604bc76d81b0f53eba3 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Fri, 21 Sep 2018 01:11:40 +0800 Subject: [PATCH 1662/2461] [SPARK-25381][SQL] Stratified sampling by Column argument ## What changes were proposed in this pull request? In the PR, I propose to add an overloaded method for `sampleBy` which accepts the first argument of the `Column` type. This will allow to sample by any complex columns as well as sampling by multiple columns. For example: ```Scala spark.createDataFrame(Seq(("Bob", 17), ("Alice", 10), ("Nico", 8), ("Bob", 17), ("Alice", 10))).toDF("name", "age") .stat .sampleBy(struct($"name", $"age"), Map(Row("Alice", 10) -> 0.3, Row("Nico", 8) -> 1.0), 36L) .show() +-----+---+ | name|age| +-----+---+ | Nico| 8| |Alice| 10| +-----+---+ ``` ## How was this patch tested? Added new test for sampling by multiple columns for Scala and test for Java, Python to check that `sampleBy` is able to sample by `Column` type argument. Closes #22365 from MaxGekk/sample-by-column. Authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- python/pyspark/sql/dataframe.py | 11 +++- .../spark/sql/DataFrameStatFunctions.scala | 57 +++++++++++++++++-- .../apache/spark/sql/JavaDataFrameSuite.java | 11 ++++ .../apache/spark/sql/DataFrameStatSuite.scala | 20 ++++++- 4 files changed, 91 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1affc9b4fcf6c..21bc69b8236fd 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -880,16 +880,23 @@ def sampleBy(self, col, fractions, seed=None): | 0| 5| | 1| 9| +---+-----+ + >>> dataset.sampleBy(col("key"), fractions={2: 1.0}, seed=0).count() + 33 + .. versionchanged:: 2.5 + Added sampling by a column of :class:`Column` """ - if not isinstance(col, basestring): - raise ValueError("col must be a string, but got %r" % type(col)) + if isinstance(col, basestring): + col = Column(col) + elif not isinstance(col, Column): + raise ValueError("col must be a string or a column, but got %r" % type(col)) if not isinstance(fractions, dict): raise ValueError("fractions must be a dict but got %r" % type(fractions)) for k, v in fractions.items(): if not isinstance(k, (float, int, long, basestring)): raise ValueError("key must be float, int, long, or string, but got %r" % type(k)) fractions[k] = float(v) + col = col._jc seed = seed if seed is not None else random.randint(0, sys.maxsize) return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index a41753098966e..75b84773bd0b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -370,19 +370,66 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @since 1.5.0 */ def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = { + sampleBy(Column(col), fractions, seed) + } + + /** + * Returns a stratified sample without replacement based on the fraction given on each stratum. + * @param col column that defines strata + * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat + * its fraction as zero. + * @param seed random seed + * @tparam T stratum type + * @return a new `DataFrame` that represents the stratified sample + * + * @since 1.5.0 + */ + def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { + sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) + } + + /** + * Returns a stratified sample without replacement based on the fraction given on each stratum. + * @param col column that defines strata + * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat + * its fraction as zero. + * @param seed random seed + * @tparam T stratum type + * @return a new `DataFrame` that represents the stratified sample + * + * The stratified sample can be performed over multiple columns: + * {{{ + * import org.apache.spark.sql.Row + * import org.apache.spark.sql.functions.struct + * + * val df = spark.createDataFrame(Seq(("Bob", 17), ("Alice", 10), ("Nico", 8), ("Bob", 17), + * ("Alice", 10))).toDF("name", "age") + * val fractions = Map(Row("Alice", 10) -> 0.3, Row("Nico", 8) -> 1.0) + * df.stat.sampleBy(struct($"name", $"age"), fractions, 36L).show() + * +-----+---+ + * | name|age| + * +-----+---+ + * | Nico| 8| + * |Alice| 10| + * +-----+---+ + * }}} + * + * @since 2.5.0 + */ + def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = { require(fractions.values.forall(p => p >= 0.0 && p <= 1.0), s"Fractions must be in [0, 1], but got $fractions.") import org.apache.spark.sql.functions.{rand, udf} - val c = Column(col) val r = rand(seed) val f = udf { (stratum: Any, x: Double) => x < fractions.getOrElse(stratum.asInstanceOf[T], 0.0) } - df.filter(f(c, r)) + df.filter(f(col, r)) } /** - * Returns a stratified sample without replacement based on the fraction given on each stratum. + * (Java-specific) Returns a stratified sample without replacement based on the fraction given + * on each stratum. * @param col column that defines strata * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat * its fraction as zero. @@ -390,9 +437,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @tparam T stratum type * @return a new `DataFrame` that represents the stratified sample * - * @since 1.5.0 + * @since 2.5.0 */ - def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { + def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 69a2904f5f3fe..3f37e5814ccaa 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -290,6 +290,17 @@ public void testSampleBy() { Assert.assertTrue(2 <= actual.get(1).getLong(1) && actual.get(1).getLong(1) <= 13); } + @Test + public void testSampleByColumn() { + Dataset df = spark.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); + Dataset sampled = df.stat().sampleBy(col("key"), ImmutableMap.of(0, 0.1, 1, 0.2), 0L); + List actual = sampled.groupBy("key").count().orderBy("key").collectAsList(); + Assert.assertEquals(0, actual.get(0).getLong(0)); + Assert.assertTrue(0 <= actual.get(0).getLong(1) && actual.get(0).getLong(1) <= 8); + Assert.assertEquals(1, actual.get(1).getLong(0)); + Assert.assertTrue(2 <= actual.get(1).getLong(1) && actual.get(1).getLong(1) <= 13); + } + @Test public void pivot() { Dataset df = spark.table("courseSales"); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 8eae35325faea..589873b9c3ea4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -23,7 +23,7 @@ import org.scalatest.Matchers._ import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.stat.StatFunctions -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{col, lit, struct} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{DoubleType, StructField, StructType} @@ -374,6 +374,24 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { Seq(Row(0, 6), Row(1, 11))) } + test("sampleBy one column") { + val df = spark.range(0, 100).select((col("id") % 3).as("key")) + val sampled = df.stat.sampleBy($"key", Map(0 -> 0.1, 1 -> 0.2), 0L) + checkAnswer( + sampled.groupBy("key").count().orderBy("key"), + Seq(Row(0, 6), Row(1, 11))) + } + + test("sampleBy multiple columns") { + val df = spark.range(0, 100) + .select(lit("Foo").as("name"), (col("id") % 3).as("key")) + val sampled = df.stat.sampleBy( + struct($"name", $"key"), Map(Row("Foo", 0) -> 0.1, Row("Foo", 1) -> 0.2), 0L) + checkAnswer( + sampled.groupBy("key").count().orderBy("key"), + Seq(Row(0, 6), Row(1, 11))) + } + // This test case only verifies that `DataFrame.countMinSketch()` methods do return // `CountMinSketch`es that meet required specs. Test cases for `CountMinSketch` can be found in // `CountMinSketchSuite` in project spark-sketch. From 2f51e72356babac703cc20a531b4dcc7712f34af Mon Sep 17 00:00:00 2001 From: Nihar Sheth Date: Thu, 20 Sep 2018 11:52:20 -0700 Subject: [PATCH 1663/2461] [SPARK-24918][CORE] Executor Plugin API ## What changes were proposed in this pull request? A continuation of squito's executor plugin task. By his request I took his code and added testing and moved the plugin initialization to a separate thread. Executor plugins now run on one separate thread, so the executor does not wait on them. Added testing. ## How was this patch tested? Added test cases that test using a sample plugin. Closes #22192 from NiharS/executorPlugin. Lead-authored-by: Nihar Sheth Co-authored-by: NiharS Signed-off-by: Marcelo Vanzin --- .../java/org/apache/spark/ExecutorPlugin.java | 57 +++++++ .../org/apache/spark/executor/Executor.scala | 35 +++++ .../spark/internal/config/package.scala | 10 ++ .../scala/org/apache/spark/util/Utils.scala | 13 ++ .../org/apache/spark/ExecutorPluginSuite.java | 139 ++++++++++++++++++ 5 files changed, 254 insertions(+) create mode 100644 core/src/main/java/org/apache/spark/ExecutorPlugin.java create mode 100644 core/src/test/java/org/apache/spark/ExecutorPluginSuite.java diff --git a/core/src/main/java/org/apache/spark/ExecutorPlugin.java b/core/src/main/java/org/apache/spark/ExecutorPlugin.java new file mode 100644 index 0000000000000..ec0b57f1a2819 --- /dev/null +++ b/core/src/main/java/org/apache/spark/ExecutorPlugin.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import org.apache.spark.annotation.DeveloperApi; + +/** + * A plugin which can be automaticaly instantiated within each Spark executor. Users can specify + * plugins which should be created with the "spark.executor.plugins" configuration. An instance + * of each plugin will be created for every executor, including those created by dynamic allocation, + * before the executor starts running any tasks. + * + * The specific api exposed to the end users still considered to be very unstable. We will + * hopefully be able to keep compatability by providing default implementations for any methods + * added, but make no guarantees this will always be possible across all Spark releases. + * + * Spark does nothing to verify the plugin is doing legitimate things, or to manage the resources + * it uses. A plugin acquires the same privileges as the user running the task. A bad plugin + * could also intefere with task execution and make the executor fail in unexpected ways. + */ +@DeveloperApi +public interface ExecutorPlugin { + + /** + * Initialize the executor plugin. + * + *

      Each executor will, during its initialization, invoke this method on each + * plugin provided in the spark.executor.plugins configuration.

      + * + *

      Plugins should create threads in their implementation of this method for + * any polling, blocking, or intensive computation.

      + */ + default void init() {} + + /** + * Clean up and terminate this plugin. + * + *

      This function is called during the executor shutdown phase. The executor + * will wait for the plugin to terminate before continuing its own shutdown.

      + */ + default void shutdown() {} +} diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 072277cb78dc1..6d7d65626ea12 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -136,6 +136,29 @@ private[spark] class Executor( // for fetching remote cached RDD blocks, so need to make sure it uses the right classloader too. env.serializerManager.setDefaultClassLoader(replClassLoader) + private val executorPlugins: Seq[ExecutorPlugin] = { + val pluginNames = conf.get(EXECUTOR_PLUGINS) + if (pluginNames.nonEmpty) { + logDebug(s"Initializing the following plugins: ${pluginNames.mkString(", ")}") + + // Plugins need to load using a class loader that includes the executor's user classpath + val pluginList: Seq[ExecutorPlugin] = + Utils.withContextClassLoader(replClassLoader) { + val plugins = Utils.loadExtensions(classOf[ExecutorPlugin], pluginNames, conf) + plugins.foreach { plugin => + plugin.init() + logDebug(s"Successfully loaded plugin " + plugin.getClass().getCanonicalName()) + } + plugins + } + + logDebug("Finished initializing plugins") + pluginList + } else { + Nil + } + } + // Max size of direct result. If task result is bigger than this, we use the block manager // to send the result back. private val maxDirectResultSize = Math.min( @@ -224,6 +247,18 @@ private[spark] class Executor( logWarning("Unable to stop heartbeater", e) } threadPool.shutdown() + + // Notify plugins that executor is shutting down so they can terminate cleanly + Utils.withContextClassLoader(replClassLoader) { + executorPlugins.foreach { plugin => + try { + plugin.shutdown() + } catch { + case e: Exception => + logWarning("Plugin " + plugin.getClass().getCanonicalName() + " shutdown failed", e) + } + } + } if (!isLocal) { env.stop() } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 8d827189ebb57..9891b6a2196de 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -623,4 +623,14 @@ package object config { .intConf .checkValue(v => v > 0, "The max failures should be a positive value.") .createWithDefault(40) + + private[spark] val EXECUTOR_PLUGINS = + ConfigBuilder("spark.executor.plugins") + .doc("Comma-separated list of class names for \"plugins\" implementing " + + "org.apache.spark.ExecutorPlugin. Plugins have the same privileges as any task " + + "in a Spark executor. They can also interfere with task execution and fail in " + + "unexpected ways. So be sure to only use this for trusted plugins.") + .stringConf + .toSequence + .createWithDefault(Nil) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 14f68cd6f3509..c8b148be84536 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -239,6 +239,19 @@ private[spark] object Utils extends Logging { // scalastyle:on classforname } + /** + * Run a segment of code using a different context class loader in the current thread + */ + def withContextClassLoader[T](ctxClassLoader: ClassLoader)(fn: => T): T = { + val oldClassLoader = Thread.currentThread().getContextClassLoader() + try { + Thread.currentThread().setContextClassLoader(ctxClassLoader) + fn + } finally { + Thread.currentThread().setContextClassLoader(oldClassLoader) + } + } + /** * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]] */ diff --git a/core/src/test/java/org/apache/spark/ExecutorPluginSuite.java b/core/src/test/java/org/apache/spark/ExecutorPluginSuite.java new file mode 100644 index 0000000000000..686eb28010c6a --- /dev/null +++ b/core/src/test/java/org/apache/spark/ExecutorPluginSuite.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import org.apache.spark.api.java.JavaSparkContext; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.*; + +public class ExecutorPluginSuite { + private static final String EXECUTOR_PLUGIN_CONF_NAME = "spark.executor.plugins"; + private static final String testBadPluginName = TestBadShutdownPlugin.class.getName(); + private static final String testPluginName = TestExecutorPlugin.class.getName(); + private static final String testSecondPluginName = TestSecondPlugin.class.getName(); + + // Static value modified by testing plugins to ensure plugins loaded correctly. + public static int numSuccessfulPlugins = 0; + + // Static value modified by testing plugins to verify plugins shut down properly. + public static int numSuccessfulTerminations = 0; + + private JavaSparkContext sc; + + @Before + public void setUp() { + sc = null; + numSuccessfulPlugins = 0; + numSuccessfulTerminations = 0; + } + + @After + public void tearDown() { + if (sc != null) { + sc.stop(); + sc = null; + } + } + + private SparkConf initializeSparkConf(String pluginNames) { + return new SparkConf() + .setMaster("local") + .setAppName("test") + .set(EXECUTOR_PLUGIN_CONF_NAME, pluginNames); + } + + @Test + public void testPluginClassDoesNotExist() { + SparkConf conf = initializeSparkConf("nonexistant.plugin"); + try { + sc = new JavaSparkContext(conf); + fail("No exception thrown for nonexistant plugin"); + } catch (Exception e) { + // We cannot catch ClassNotFoundException directly because Java doesn't think it'll be thrown + assertTrue(e.toString().startsWith("java.lang.ClassNotFoundException")); + } + } + + @Test + public void testAddPlugin() throws InterruptedException { + // Load the sample TestExecutorPlugin, which will change the value of numSuccessfulPlugins + SparkConf conf = initializeSparkConf(testPluginName); + sc = new JavaSparkContext(conf); + assertEquals(1, numSuccessfulPlugins); + sc.stop(); + sc = null; + assertEquals(1, numSuccessfulTerminations); + } + + @Test + public void testAddMultiplePlugins() throws InterruptedException { + // Load two plugins and verify they both execute. + SparkConf conf = initializeSparkConf(testPluginName + "," + testSecondPluginName); + sc = new JavaSparkContext(conf); + assertEquals(2, numSuccessfulPlugins); + sc.stop(); + sc = null; + assertEquals(2, numSuccessfulTerminations); + } + + @Test + public void testPluginShutdownWithException() { + // Verify an exception in one plugin shutdown does not affect the others + String pluginNames = testPluginName + "," + testBadPluginName + "," + testPluginName; + SparkConf conf = initializeSparkConf(pluginNames); + sc = new JavaSparkContext(conf); + assertEquals(3, numSuccessfulPlugins); + sc.stop(); + sc = null; + assertEquals(2, numSuccessfulTerminations); + } + + public static class TestExecutorPlugin implements ExecutorPlugin { + public void init() { + ExecutorPluginSuite.numSuccessfulPlugins++; + } + + public void shutdown() { + ExecutorPluginSuite.numSuccessfulTerminations++; + } + } + + public static class TestSecondPlugin implements ExecutorPlugin { + public void init() { + ExecutorPluginSuite.numSuccessfulPlugins++; + } + + public void shutdown() { + ExecutorPluginSuite.numSuccessfulTerminations++; + } + } + + public static class TestBadShutdownPlugin implements ExecutorPlugin { + public void init() { + ExecutorPluginSuite.numSuccessfulPlugins++; + } + + public void shutdown() { + throw new RuntimeException("This plugin will fail to cleanly shut down"); + } + } +} From 4d114fc9a2cb0be7256560bc8b2e4ce72adb7a7f Mon Sep 17 00:00:00 2001 From: liuxian Date: Thu, 20 Sep 2018 16:53:48 -0500 Subject: [PATCH 1664/2461] [SPARK-25366][SQL] Zstd and brotli CompressionCodec are not supported for parquet files ## What changes were proposed in this pull request? Hadoop2.6 and hadoop2.7 do not contain zstd and brotli compressioncodec ,hadoop 3.1 also contains only zstd compressioncodec . So I think we should remove zstd and brotil for the time being. **set `spark.sql.parquet.compression.codec=brotli`:** Caused by: org.apache.parquet.hadoop.BadConfigurationException: Class org.apache.hadoop.io.compress.BrotliCodec was not found at org.apache.parquet.hadoop.CodecFactory.getCodec(CodecFactory.java:235) at org.apache.parquet.hadoop.CodecFactory$HeapBytesCompressor.(CodecFactory.java:142) at org.apache.parquet.hadoop.CodecFactory.createCompressor(CodecFactory.java:206) at org.apache.parquet.hadoop.CodecFactory.getCompressor(CodecFactory.java:189) at org.apache.parquet.hadoop.ParquetRecordWriter.(ParquetRecordWriter.java:153) at org.apache.parquet.hadoop.ParquetOutputFormat.getRecordWriter(ParquetOutputFormat.java:411) at org.apache.parquet.hadoop.ParquetOutputFormat.getRecordWriter(ParquetOutputFormat.java:349) at org.apache.spark.sql.execution.datasources.parquet.ParquetOutputWriter.(ParquetOutputWriter.scala:37) at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$$anon$1.newInstance(ParquetFileFormat.scala:161) **set `spark.sql.parquet.compression.codec=zstd`:** Caused by: org.apache.parquet.hadoop.BadConfigurationException: Class org.apache.hadoop.io.compress.ZStandardCodec was not found at org.apache.parquet.hadoop.CodecFactory.getCodec(CodecFactory.java:235) at org.apache.parquet.hadoop.CodecFactory$HeapBytesCompressor.(CodecFactory.java:142) at org.apache.parquet.hadoop.CodecFactory.createCompressor(CodecFactory.java:206) at org.apache.parquet.hadoop.CodecFactory.getCompressor(CodecFactory.java:189) at org.apache.parquet.hadoop.ParquetRecordWriter.(ParquetRecordWriter.java:153) at org.apache.parquet.hadoop.ParquetOutputFormat.getRecordWriter(ParquetOutputFormat.java:411) at org.apache.parquet.hadoop.ParquetOutputFormat.getRecordWriter(ParquetOutputFormat.java:349) at org.apache.spark.sql.execution.datasources.parquet.ParquetOutputWriter.(ParquetOutputWriter.scala:37) at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$$anon$1.newInstance(ParquetFileFormat.scala:161) ## How was this patch tested? Exist unit test Closes #22358 from 10110346/notsupportzstdandbrotil. Authored-by: liuxian Signed-off-by: Sean Owen --- docs/sql-programming-guide.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d2e3ee3e77818..8ec4865d58162 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -965,6 +965,8 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession `parquet.compression` is specified in the table-specific options/properties, the precedence would be `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`. Acceptable values include: none, uncompressed, snappy, gzip, lzo, brotli, lz4, zstd. + Note that `zstd` requires `ZStandardCodec` to be installed before Hadoop 2.9.0, `brotli` requires + `BrotliCodec` to be installed.
      spark.sql.parquet.writeLegacyFormatfalse + If true, data will be written in a way of Spark 1.4 and earlier. For example, decimal values + will be written in Apache Parquet's fixed-length byte array format, which other systems such as + Apache Hive and Apache Impala use. If false, the newer format in Parquet will be used. For + example, decimals will be written in int-based format. If Parquet output is intended for use + with systems that do not support this newer format, set to true. +
      ## ORC Files diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e7c9a83798907..2f4d660437ada 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -451,8 +451,11 @@ object SQLConf { .createWithDefault(10) val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") - .doc("Whether to be compatible with the legacy Parquet format adopted by Spark 1.4 and prior " + - "versions, when converting Parquet schema to Spark SQL schema and vice versa.") + .doc("If true, data will be written in a way of Spark 1.4 and earlier. For example, decimal " + + "values will be written in Apache Parquet's fixed-length byte array format, which other " + + "systems such as Apache Hive and Apache Impala use. If false, the newer format in Parquet " + + "will be used. For example, decimals will be written in int-based format. If Parquet " + + "output is intended for use with systems that do not support this newer format, set to true.") .booleanConf .createWithDefault(false) From a2ac5a72ccd2b14c8492d4a6da9e8b30f0f3c9b4 Mon Sep 17 00:00:00 2001 From: Rong Tang Date: Wed, 26 Sep 2018 10:37:17 -0500 Subject: [PATCH 1701/2461] [SPARK-25509][CORE] Windows doesn't support POSIX permissions ## What changes were proposed in this pull request? SHS V2 cannot enabled in Windows, because windows doesn't support POSIX permission. ## How was this patch tested? test case fails in windows without this fix. org.apache.spark.deploy.history.HistoryServerDiskManagerSuite test("leasing space") SHS V2 cannot run successfully in Windows without this fix. java.lang.UnsupportedOperationException: 'posix:permissions' not supported as initial attribute at sun.nio.fs.WindowsSecurityDescriptor.fromAttribute(WindowsSecurityDescriptor.java:358) Closes #22520 from jianjianjiao/FixWindowsPermssionsIssue. Authored-by: Rong Tang Signed-off-by: Sean Owen --- .../apache/spark/deploy/history/FsHistoryProvider.scala | 6 ++---- .../spark/deploy/history/HistoryServerDiskManager.scala | 7 ++----- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 44d23908146c7..c23a659e76df1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -19,7 +19,6 @@ package org.apache.spark.deploy.history import java.io.{File, FileNotFoundException, IOException} import java.nio.file.Files -import java.nio.file.attribute.PosixFilePermissions import java.util.{Date, ServiceLoader} import java.util.concurrent.{ConcurrentHashMap, ExecutorService, Future, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} @@ -133,9 +132,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // Visible for testing. private[history] val listing: KVStore = storePath.map { path => - val perms = PosixFilePermissions.fromString("rwx------") - val dbPath = Files.createDirectories(new File(path, "listing.ldb").toPath(), - PosixFilePermissions.asFileAttribute(perms)).toFile() + val dbPath = Files.createDirectories(new File(path, "listing.ldb").toPath()).toFile() + Utils.chmod700(dbPath) val metadata = new FsHistoryProviderMetadata(CURRENT_LISTING_VERSION, AppStatusStore.CURRENT_VERSION, logDir.toString()) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala index c03a360b91ef8..ad0dd23cb59c8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala @@ -18,8 +18,6 @@ package org.apache.spark.deploy.history import java.io.File -import java.nio.file.Files -import java.nio.file.attribute.PosixFilePermissions import java.util.concurrent.atomic.AtomicLong import scala.collection.JavaConverters._ @@ -107,9 +105,8 @@ private class HistoryServerDiskManager( val needed = approximateSize(eventLogSize, isCompressed) makeRoom(needed) - val perms = PosixFilePermissions.fromString("rwx------") - val tmp = Files.createTempDirectory(tmpStoreDir.toPath(), "appstore", - PosixFilePermissions.asFileAttribute(perms)).toFile() + val tmp = Utils.createTempDir(tmpStoreDir.getPath(), "appstore") + Utils.chmod700(tmp) updateUsage(needed) val current = currentUsage.get() From bd2ae857d1c5f251056de38a7a40540986756b94 Mon Sep 17 00:00:00 2001 From: Reza Safi Date: Wed, 26 Sep 2018 09:29:58 -0700 Subject: [PATCH 1702/2461] [SPARK-25318] Add exception handling when wrapping the input stream during the the fetch or stage retry in response to a corrupted block SPARK-4105 provided a solution to block corruption issue by retrying the fetch or the stage. In that solution there is a step that wraps the input stream with compression and/or encryption. This step is prone to exceptions, but in the current code there is no exception handling for this step and this has caused confusion for the user. The confusion was that after SPARK-4105 the user expects to see either a fetchFailed exception or a warning about a corrupted block. However an exception during wrapping can fail the job without any of those. This change adds exception handling for the wrapping step and also adds a fetch retry if we experience a corruption during the wrapping step. The reason for adding the retry is that usually user won't experience the same failure after rerunning the job and so it seems reasonable try to fetch and wrap one more time instead of failing. Closes #22325 from rezasafi/localcorruption. Authored-by: Reza Safi Signed-off-by: Marcelo Vanzin --- .../storage/ShuffleBlockFetcherIterator.scala | 50 +++++++++---------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index e534c746433f2..aecc2284a9588 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -448,35 +448,35 @@ final class ShuffleBlockFetcherIterator( buf.release() throwFetchFailedException(blockId, address, e) } - - input = streamWrapper(blockId, in) - // Only copy the stream if it's wrapped by compression or encryption, also the size of - // block is small (the decompressed block is smaller than maxBytesInFlight) - if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) { - val originalInput = input - val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) - try { + var isStreamCopied: Boolean = false + try { + input = streamWrapper(blockId, in) + // Only copy the stream if it's wrapped by compression or encryption, also the size of + // block is small (the decompressed block is smaller than maxBytesInFlight) + if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) { + isStreamCopied = true + val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) // Decompress the whole block at once to detect any corruption, which could increase // the memory usage tne potential increase the chance of OOM. // TODO: manage the memory used here, and spill it into disk in case of OOM. - Utils.copyStream(input, out) - out.close() + Utils.copyStream(input, out, closeStreams = true) input = out.toChunkedByteBuffer.toInputStream(dispose = true) - } catch { - case e: IOException => - buf.release() - if (buf.isInstanceOf[FileSegmentManagedBuffer] - || corruptedBlocks.contains(blockId)) { - throwFetchFailedException(blockId, address, e) - } else { - logWarning(s"got an corrupted block $blockId from $address, fetch again", e) - corruptedBlocks += blockId - fetchRequests += FetchRequest(address, Array((blockId, size))) - result = null - } - } finally { - // TODO: release the buf here to free memory earlier - originalInput.close() + } + } catch { + case e: IOException => + buf.release() + if (buf.isInstanceOf[FileSegmentManagedBuffer] + || corruptedBlocks.contains(blockId)) { + throwFetchFailedException(blockId, address, e) + } else { + logWarning(s"got an corrupted block $blockId from $address, fetch again", e) + corruptedBlocks += blockId + fetchRequests += FetchRequest(address, Array((blockId, size))) + result = null + } + } finally { + // TODO: release the buf here to free memory earlier + if (isStreamCopied) { in.close() } } From e702fb1d5218d062fcb8e618b92dad7958eb4062 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 26 Sep 2018 10:15:16 -0700 Subject: [PATCH 1703/2461] [SPARK-24519][CORE] Compute SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS only once ## What changes were proposed in this pull request? Previously SPARK-24519 created a modifiable config SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS. However, the config is being parsed for every creation of MapStatus, which could be very expensive. Another problem with the previous approach is that it created the illusion that this can be changed dynamically at runtime, which was not true. This PR changes it so the config is computed only once. ## How was this patch tested? Removed a test case that's no longer valid. Closes #22521 from rxin/SPARK-24519. Authored-by: Reynold Xin Signed-off-by: Dongjoon Hyun --- .../apache/spark/scheduler/MapStatus.scala | 12 ++++++-- .../spark/scheduler/MapStatusSuite.scala | 28 ------------------- 2 files changed, 9 insertions(+), 31 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 659694dd189ad..0e221edf3965a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -49,10 +49,16 @@ private[spark] sealed trait MapStatus { private[spark] object MapStatus { + /** + * Min partition number to use [[HighlyCompressedMapStatus]]. A bit ugly here because in test + * code we can't assume SparkEnv.get exists. + */ + private lazy val minPartitionsToUseHighlyCompressMapStatus = Option(SparkEnv.get) + .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) + .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get) + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { - if (uncompressedSizes.length > Option(SparkEnv.get) - .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) - .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)) { + if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) { HighlyCompressedMapStatus(loc, uncompressedSizes) } else { new CompressedMapStatus(loc, uncompressedSizes) diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 354e6386fa60e..2155a0f2b6c21 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -188,32 +188,4 @@ class MapStatusSuite extends SparkFunSuite { assert(count === 3000) } } - - test("SPARK-24519: HighlyCompressedMapStatus has configurable threshold") { - val conf = new SparkConf() - val env = mock(classOf[SparkEnv]) - doReturn(conf).when(env).conf - SparkEnv.set(env) - val sizes = Array.fill[Long](500)(150L) - // Test default value - val status = MapStatus(null, sizes) - assert(status.isInstanceOf[CompressedMapStatus]) - // Test Non-positive values - for (s <- -1 to 0) { - assertThrows[IllegalArgumentException] { - conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) - val status = MapStatus(null, sizes) - } - } - // Test positive values - Seq(1, 100, 499, 500, 501).foreach { s => - conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) - val status = MapStatus(null, sizes) - if(sizes.length > s) { - assert(status.isInstanceOf[HighlyCompressedMapStatus]) - } else { - assert(status.isInstanceOf[CompressedMapStatus]) - } - } - } } From 5ee21661834e837d414bc20591982a092c0aece3 Mon Sep 17 00:00:00 2001 From: Shahid Date: Wed, 26 Sep 2018 10:47:49 -0700 Subject: [PATCH 1704/2461] [SPARK-25533][CORE][WEBUI] AppSummary should hold the information about succeeded Jobs and completed stages only ## What changes were proposed in this pull request? Currently, In the spark UI, when there are failed jobs or failed stages, display message for the completed jobs and completed stages are not consistent with the previous versions of spark. Reason is because, AppSummary holds the information about all the jobs and stages. But, In the below code, it checks against the completedJobs and completedStages. So, AppSummary should hold only successful jobs and stages. https://github.com/apache/spark/blob/66d29870c09e6050dd846336e596faaa8b0d14ad/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala#L306 https://github.com/apache/spark/blob/66d29870c09e6050dd846336e596faaa8b0d14ad/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala#L119 So, we should keep only completed jobs and stage information in the AppSummary, to make it consistent with Spark2.2 ## How was this patch tested? Test steps: bin/spark-shell ``` sc.parallelize(1 to 5, 5).collect() sc.parallelize(1 to 5, 2).map{ x => throw new RuntimeException("Fail")}.collect() ``` **Before fix:** ![screenshot from 2018-09-26 03-24-53](https://user-images.githubusercontent.com/23054875/46045669-f60bcd80-c13b-11e8-9aa6-a2e5a2038dba.png) ![screenshot from 2018-09-26 03-25-08](https://user-images.githubusercontent.com/23054875/46045699-0ae86100-c13c-11e8-94e5-ad35944c7615.png) **After fix:** ![screenshot from 2018-09-26 03-16-14](https://user-images.githubusercontent.com/23054875/46045636-d83e6880-c13b-11e8-98df-f49d15c18958.png) ![screenshot from 2018-09-26 03-16-28](https://user-images.githubusercontent.com/23054875/46045645-e1c7d080-c13b-11e8-8c9c-d32e1f663356.png) Closes #22549 from shahidki31/SPARK-25533. Authored-by: Shahid Signed-off-by: Marcelo Vanzin --- .../apache/spark/status/AppStatusListener.scala | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index f21eee1965761..36aaf67b57298 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -388,10 +388,11 @@ private[spark] class AppStatusListener( job.completionTime = if (event.time > 0) Some(new Date(event.time)) else None update(job, now, last = true) + if (job.status == JobExecutionStatus.SUCCEEDED) { + appSummary = new AppSummary(appSummary.numCompletedJobs + 1, appSummary.numCompletedStages) + kvstore.write(appSummary) + } } - - appSummary = new AppSummary(appSummary.numCompletedJobs + 1, appSummary.numCompletedStages) - kvstore.write(appSummary) } override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = { @@ -653,13 +654,14 @@ private[spark] class AppStatusListener( if (removeStage) { liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptNumber)) } + if (stage.status == v1.StageStatus.COMPLETE) { + appSummary = new AppSummary(appSummary.numCompletedJobs, appSummary.numCompletedStages + 1) + kvstore.write(appSummary) + } } // remove any dead executors that were not running for any currently active stages deadExecutors.retain((execId, exec) => isExecutorActiveForLiveStages(exec)) - - appSummary = new AppSummary(appSummary.numCompletedJobs, appSummary.numCompletedStages + 1) - kvstore.write(appSummary) } private def removeBlackListedStageFrom(exec: LiveExecutor, stageId: Int, now: Long) = { From 51540c2fa677658be954c820bc18ba748e4c8583 Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Wed, 26 Sep 2018 17:24:52 -0700 Subject: [PATCH 1705/2461] [SPARK-25372][YARN][K8S] Deprecate and generalize keytab / principal config ## What changes were proposed in this pull request? SparkSubmit already logs in the user if a keytab is provided, the only issue is that it uses the existing configs which have "yarn" in their name. As such, the configs were changed to: `spark.kerberos.keytab` and `spark.kerberos.principal`. ## How was this patch tested? Will be tested with K8S tests, but needs to be tested with Yarn - [x] K8S Secure HDFS tests - [x] Yarn Secure HDFS tests vanzin Closes #22362 from ifilonenko/SPARK-25372. Authored-by: Ilan Filonenko Signed-off-by: Marcelo Vanzin --- R/pkg/R/sparkR.R | 2 ++ R/pkg/vignettes/sparkr-vignettes.Rmd | 4 ++-- core/src/main/scala/org/apache/spark/SparkConf.scala | 6 +++++- .../scala/org/apache/spark/deploy/SparkSubmit.scala | 6 ++++-- .../org/apache/spark/deploy/SparkSubmitArguments.scala | 10 ++++++++-- .../org/apache/spark/internal/config/package.scala | 4 ++-- docs/running-on-yarn.md | 4 ++-- docs/sparkr.md | 4 ++-- .../scala/org/apache/spark/streaming/Checkpoint.scala | 2 ++ 9 files changed, 29 insertions(+), 13 deletions(-) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index d3a9cbae7d808..038fefadaaeff 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -626,6 +626,8 @@ sparkConfToSubmitOps[["spark.driver.extraLibraryPath"]] <- "--driver-library-pat sparkConfToSubmitOps[["spark.master"]] <- "--master" sparkConfToSubmitOps[["spark.yarn.keytab"]] <- "--keytab" sparkConfToSubmitOps[["spark.yarn.principal"]] <- "--principal" +sparkConfToSubmitOps[["spark.kerberos.keytab"]] <- "--keytab" +sparkConfToSubmitOps[["spark.kerberos.principal"]] <- "--principal" # Utility function that returns Spark Submit arguments as a string diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 090363c5f8a3e..ad934947437bc 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -157,8 +157,8 @@ Property Name | Property group | spark-submit equivalent `spark.driver.extraClassPath` | Runtime Environment | `--driver-class-path` `spark.driver.extraJavaOptions` | Runtime Environment | `--driver-java-options` `spark.driver.extraLibraryPath` | Runtime Environment | `--driver-library-path` -`spark.yarn.keytab` | Application Properties | `--keytab` -`spark.yarn.principal` | Application Properties | `--principal` +`spark.kerberos.keytab` | Application Properties | `--keytab` +`spark.kerberos.principal` | Application Properties | `--principal` **For Windows users**: Due to different file prefixes across operating systems, to avoid the issue of potential wrong prefix, a current workaround is to specify `spark.sql.warehouse.dir` when starting the `SparkSession`. diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 6c4c5c94cfa28..e0f98f1aca071 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -726,7 +726,11 @@ private[spark] object SparkConf extends Logging { DRIVER_MEMORY_OVERHEAD.key -> Seq( AlternateConfig("spark.yarn.driver.memoryOverhead", "2.3")), EXECUTOR_MEMORY_OVERHEAD.key -> Seq( - AlternateConfig("spark.yarn.executor.memoryOverhead", "2.3")) + AlternateConfig("spark.yarn.executor.memoryOverhead", "2.3")), + KEYTAB.key -> Seq( + AlternateConfig("spark.yarn.keytab", "2.5")), + PRINCIPAL.key -> Seq( + AlternateConfig("spark.yarn.principal", "2.5")) ) /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index cf902db8709e7..d5f2865f87281 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -520,6 +520,10 @@ private[spark] class SparkSubmit extends Logging { confKey = "spark.driver.extraJavaOptions"), OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.driver.extraLibraryPath"), + OptionAssigner(args.principal, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, + confKey = PRINCIPAL.key), + OptionAssigner(args.keytab, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, + confKey = KEYTAB.key), // Propagate attributes for dependency resolution at the driver side OptionAssigner(args.packages, STANDALONE | MESOS, CLUSTER, confKey = "spark.jars.packages"), @@ -537,8 +541,6 @@ private[spark] class SparkSubmit extends Logging { OptionAssigner(args.jars, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.jars"), OptionAssigner(args.files, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.files"), OptionAssigner(args.archives, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.archives"), - OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.principal"), - OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.keytab"), // Other options OptionAssigner(args.executorCores, STANDALONE | YARN | KUBERNETES, ALL_DEPLOY_MODES, diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 0998757715457..4cf08a7980f55 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -199,8 +199,14 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) queue = Option(queue).orElse(sparkProperties.get("spark.yarn.queue")).orNull - keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull - principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull + keytab = Option(keytab) + .orElse(sparkProperties.get("spark.kerberos.keytab")) + .orElse(sparkProperties.get("spark.yarn.keytab")) + .orNull + principal = Option(principal) + .orElse(sparkProperties.get("spark.kerberos.principal")) + .orElse(sparkProperties.get("spark.yarn.principal")) + .orNull dynamicAllocationEnabled = sparkProperties.get("spark.dynamicAllocation.enabled").exists("true".equalsIgnoreCase) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 9891b6a2196de..7f6342208350a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -152,11 +152,11 @@ package object config { private[spark] val SHUFFLE_SERVICE_PORT = ConfigBuilder("spark.shuffle.service.port").intConf.createWithDefault(7337) - private[spark] val KEYTAB = ConfigBuilder("spark.yarn.keytab") + private[spark] val KEYTAB = ConfigBuilder("spark.kerberos.keytab") .doc("Location of user's keytab.") .stringConf.createOptional - private[spark] val PRINCIPAL = ConfigBuilder("spark.yarn.principal") + private[spark] val PRINCIPAL = ConfigBuilder("spark.kerberos.principal") .doc("Name of the Kerberos principal.") .stringConf.createOptional diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index e3d67c34d53eb..687f9e46c3285 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -465,7 +465,7 @@ providers can be disabled individually by setting `spark.security.credentials.{s - + - + - + - + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index a882558551e37..135430f1ef621 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -59,6 +59,8 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.yarn.jars", "spark.yarn.keytab", "spark.yarn.principal", + "spark.kerberos.keytab", + "spark.kerberos.principal", "spark.ui.filters", "spark.mesos.driver.frameworkId") From d0990e3dfee752a6460a6360e1a773138364d774 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 26 Sep 2018 17:47:05 -0700 Subject: [PATCH 1706/2461] [SPARK-25454][SQL] add a new config for picking minimum precision for integral literals ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/20023 proposed to allow precision lose during decimal operations, to reduce the possibilities of overflow. This is a behavior change and is protected by the DECIMAL_OPERATIONS_ALLOW_PREC_LOSS config. However, that PR introduced another behavior change: pick a minimum precision for integral literals, which is not protected by a config. This PR add a new config for it: `spark.sql.literal.pickMinimumPrecision`. This can allow users to work around issue in SPARK-25454, which is caused by a long-standing bug of negative scale. ## How was this patch tested? a new test Closes #22494 from cloud-fan/decimal. Authored-by: Wenchen Fan Signed-off-by: gatorsmile --- .../sql/catalyst/analysis/DecimalPrecision.scala | 10 ++++++---- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 11 +++++++++++ .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 7 +++++++ 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index e511f8064e28a..82692334544e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -290,11 +290,13 @@ object DecimalPrecision extends TypeCoercionRule { // potentially loosing 11 digits of the fractional part. Using only the precision needed // by the Literal, instead, the result would be DECIMAL(38 + 1 + 1, 18), which would // become DECIMAL(38, 16), safely having a much lower precision loss. - case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] - && l.dataType.isInstanceOf[IntegralType] => + case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] && + l.dataType.isInstanceOf[IntegralType] && + SQLConf.get.literalPickMinimumPrecision => b.makeCopy(Array(Cast(l, DecimalType.fromLiteral(l)), r)) - case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] - && r.dataType.isInstanceOf[IntegralType] => + case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] && + r.dataType.isInstanceOf[IntegralType] && + SQLConf.get.literalPickMinimumPrecision => b.makeCopy(Array(l, Cast(r, DecimalType.fromLiteral(r)))) // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2f4d660437ada..f6c98805bfb15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1331,6 +1331,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val LITERAL_PICK_MINIMUM_PRECISION = + buildConf("spark.sql.legacy.literal.pickMinimumPrecision") + .internal() + .doc("When integral literal is used in decimal operations, pick a minimum precision " + + "required by the literal if this config is true, to make the resulting precision and/or " + + "scale smaller. This can reduce the possibility of precision lose and/or overflow.") + .booleanConf + .createWithDefault(true) + val SQL_OPTIONS_REDACTION_PATTERN = buildConf("spark.sql.redaction.options.regex") .doc("Regex to decide which keys in a Spark SQL command's options map contain sensitive " + @@ -1925,6 +1934,8 @@ class SQLConf extends Serializable with Logging { def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) + def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION) + def continuousStreamingExecutorQueueSize: Int = getConf(CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) def continuousStreamingExecutorPollIntervalMs: Long = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 8fcebb35a0543..631ab1b7ece7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2849,6 +2849,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val result = ds.flatMap(_.bar).distinct result.rdd.isEmpty } + + test("SPARK-25454: decimal division with negative scale") { + // TODO: completely fix this issue even when LITERAL_PRECISE_PRECISION is true. + withSQLConf(SQLConf.LITERAL_PICK_MINIMUM_PRECISION.key -> "false") { + checkAnswer(sql("select 26393499451 / (1e6 * 1000)"), Row(BigDecimal("26.3934994510000"))) + } + } } case class Foo(bar: Option[String]) From c3c45cbd76d91d591d98cf8411fcfd30079f5969 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 27 Sep 2018 09:51:20 +0800 Subject: [PATCH 1707/2461] [SPARK-25540][SQL][PYSPARK] Make HiveContext in PySpark behave as the same as Scala. ## What changes were proposed in this pull request? In Scala, `HiveContext` sets a config `spark.sql.catalogImplementation` of the given `SparkContext` and then passes to `SparkSession.builder`. The `HiveContext` in PySpark should behave as the same as Scala. ## How was this patch tested? Existing tests. Closes #22552 from ueshin/issues/SPARK-25540/hive_context. Authored-by: Takuya UESHIN Signed-off-by: Wenchen Fan --- python/pyspark/sql/context.py | 3 ++- python/pyspark/sql/session.py | 19 ++++++++++++++----- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 9c094dd9a9033..1938965a7e175 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -485,7 +485,8 @@ def __init__(self, sparkContext, jhiveContext=None): "SparkSession.builder.enableHiveSupport().getOrCreate() instead.", DeprecationWarning) if jhiveContext is None: - sparkSession = SparkSession.builder.enableHiveSupport().getOrCreate() + sparkContext._conf.set("spark.sql.catalogImplementation", "hive") + sparkSession = SparkSession.builder._sparkContext(sparkContext).getOrCreate() else: sparkSession = SparkSession(sparkContext, jhiveContext.sparkSession()) SQLContext.__init__(self, sparkContext, sparkSession, jhiveContext) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 51a38ebfd19ff..a5e2872577312 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -83,6 +83,7 @@ class Builder(object): _lock = RLock() _options = {} + _sc = None @since(2.0) def config(self, key=None, value=None, conf=None): @@ -139,6 +140,11 @@ def enableHiveSupport(self): """ return self.config("spark.sql.catalogImplementation", "hive") + def _sparkContext(self, sc): + with self._lock: + self._sc = sc + return self + @since(2.0) def getOrCreate(self): """Gets an existing :class:`SparkSession` or, if there is no existing one, creates a @@ -167,11 +173,14 @@ def getOrCreate(self): from pyspark.conf import SparkConf session = SparkSession._instantiatedSession if session is None or session._sc._jsc is None: - sparkConf = SparkConf() - for key, value in self._options.items(): - sparkConf.set(key, value) - sc = SparkContext.getOrCreate(sparkConf) - # This SparkContext may be an existing one. + if self._sc is not None: + sc = self._sc + else: + sparkConf = SparkConf() + for key, value in self._options.items(): + sparkConf.set(key, value) + sc = SparkContext.getOrCreate(sparkConf) + # This SparkContext may be an existing one. for key, value in self._options.items(): # we need to propagate the confs # before we create the SparkSession. Otherwise, confs like From 9063b17f3d0f22b8e4142200259190a20f832a29 Mon Sep 17 00:00:00 2001 From: yucai Date: Wed, 26 Sep 2018 20:40:10 -0700 Subject: [PATCH 1708/2461] [SPARK-25481][SQL][TEST] Refactor ColumnarBatchBenchmark to use main method ## What changes were proposed in this pull request? Refactor `ColumnarBatchBenchmark` to use main method. Generate benchmark result: ``` SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain org.apache.spark.sql.execution.vectorized.ColumnarBatchBenchmark" ``` ## How was this patch tested? manual tests Closes #22490 from yucai/SPARK-25481. Lead-authored-by: yucai Co-authored-by: Yucai Yu Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../ColumnarBatchBenchmark-results.txt | 59 +++++++++++++ .../vectorized/ColumnarBatchBenchmark.scala | 84 ++++++------------- 2 files changed, 85 insertions(+), 58 deletions(-) create mode 100644 sql/core/benchmarks/ColumnarBatchBenchmark-results.txt diff --git a/sql/core/benchmarks/ColumnarBatchBenchmark-results.txt b/sql/core/benchmarks/ColumnarBatchBenchmark-results.txt new file mode 100644 index 0000000000000..59637162f0a1d --- /dev/null +++ b/sql/core/benchmarks/ColumnarBatchBenchmark-results.txt @@ -0,0 +1,59 @@ +================================================================================================ +Int Read/Write +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Int Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Java Array 244 / 244 1342.3 0.7 1.0X +ByteBuffer Unsafe 445 / 445 736.5 1.4 0.5X +ByteBuffer API 2124 / 2125 154.3 6.5 0.1X +DirectByteBuffer 750 / 750 437.2 2.3 0.3X +Unsafe Buffer 234 / 236 1401.3 0.7 1.0X +Column(on heap) 245 / 245 1335.6 0.7 1.0X +Column(off heap) 489 / 489 670.3 1.5 0.5X +Column(off heap direct) 236 / 236 1388.1 0.7 1.0X +UnsafeRow (on heap) 532 / 534 616.0 1.6 0.5X +UnsafeRow (off heap) 564 / 565 580.7 1.7 0.4X +Column On Heap Append 489 / 489 670.6 1.5 0.5X + + +================================================================================================ +Boolean Read/Write +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Boolean Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Bitset 879 / 879 381.9 2.6 1.0X +Byte Array 794 / 794 422.6 2.4 1.1X + + +================================================================================================ +String Read/Write +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +String Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +On Heap 449 / 449 36.5 27.4 1.0X +Off Heap 679 / 679 24.1 41.4 0.7X + + +================================================================================================ +Array Vector Read +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Array Vector Read: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +On Heap Read Size Only 713 / 713 229.8 4.4 1.0X +Off Heap Read Size Only 757 / 757 216.5 4.6 0.9X +On Heap Read Elements 3648 / 3650 44.9 22.3 0.2X +Off Heap Read Elements 5263 / 5265 31.1 32.1 0.1X + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index d69cf1126868e..df6ab14e661c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -21,7 +21,7 @@ import java.nio.charset.StandardCharsets import scala.util.Random -import org.apache.spark.benchmark.Benchmark +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType} @@ -30,8 +30,15 @@ import org.apache.spark.util.collection.BitSet /** * Benchmark to low level memory access using different ways to manage buffers. + * To run this benchmark: + * {{{ + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/ColumnarBatchBenchmark-results.txt". + * }}} */ -object ColumnarBatchBenchmark { +object ColumnarBatchBenchmark extends BenchmarkBase { // This benchmark reads and writes an array of ints. // TODO: there is a big (2x) penalty for a random access API for off heap. // Note: carefully if modifying this code. It's hard to reason about the JIT. @@ -260,25 +267,7 @@ object ColumnarBatchBenchmark { col.close } - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - Int Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Java Array 177 / 183 1851.1 0.5 1.0X - ByteBuffer Unsafe 314 / 330 1043.7 1.0 0.6X - ByteBuffer API 1298 / 1307 252.4 4.0 0.1X - DirectByteBuffer 465 / 483 704.2 1.4 0.4X - Unsafe Buffer 179 / 183 1835.5 0.5 1.0X - Column(on heap) 181 / 186 1815.2 0.6 1.0X - Column(off heap) 344 / 349 951.7 1.1 0.5X - Column(off heap direct) 178 / 186 1838.6 0.5 1.0X - UnsafeRow (on heap) 388 / 394 844.8 1.2 0.5X - UnsafeRow (off heap) 400 / 403 819.4 1.2 0.4X - Column On Heap Append 315 / 325 1041.8 1.0 0.6X - */ - val benchmark = new Benchmark("Int Read/Write", count * iters) + val benchmark = new Benchmark("Int Read/Write", count * iters, output = output) benchmark.addCase("Java Array")(javaArray) benchmark.addCase("ByteBuffer Unsafe")(byteBufferUnsafe) benchmark.addCase("ByteBuffer API")(byteBufferApi) @@ -295,7 +284,7 @@ object ColumnarBatchBenchmark { def booleanAccess(iters: Int): Unit = { val count = 8 * 1024 - val benchmark = new Benchmark("Boolean Read/Write", iters * count.toLong) + val benchmark = new Benchmark("Boolean Read/Write", iters * count.toLong, output = output) benchmark.addCase("Bitset") { i: Int => { val b = new BitSet(count) var sum = 0L @@ -329,15 +318,6 @@ object ColumnarBatchBenchmark { } } }} - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - Boolean Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Bitset 741 / 747 452.6 2.2 1.0X - Byte Array 531 / 542 631.6 1.6 1.4X - */ benchmark.run() } @@ -386,16 +366,7 @@ object ColumnarBatchBenchmark { } } - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - String Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - On Heap 351 / 362 46.6 21.4 1.0X - Off Heap 456 / 466 35.9 27.8 0.8X - */ - val benchmark = new Benchmark("String Read/Write", count * iters) + val benchmark = new Benchmark("String Read/Write", count * iters, output = output) benchmark.addCase("On Heap")(column(MemoryMode.ON_HEAP)) benchmark.addCase("Off Heap")(column(MemoryMode.OFF_HEAP)) benchmark.run @@ -463,30 +434,27 @@ object ColumnarBatchBenchmark { } } - val benchmark = new Benchmark("Array Vector Read", count * iters) + val benchmark = new Benchmark("Array Vector Read", count * iters, output = output) benchmark.addCase("On Heap Read Size Only") { _ => readArrays(true) } benchmark.addCase("Off Heap Read Size Only") { _ => readArrays(false) } benchmark.addCase("On Heap Read Elements") { _ => readArrayElements(true) } benchmark.addCase("Off Heap Read Elements") { _ => readArrayElements(false) } - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - Array Vector Read: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - On Heap Read Size Only 426 / 437 384.9 2.6 1.0X - Off Heap Read Size Only 406 / 421 404.0 2.5 1.0X - On Heap Read Elements 2636 / 2642 62.2 16.1 0.2X - Off Heap Read Elements 3770 / 3774 43.5 23.0 0.1X - */ benchmark.run } - def main(args: Array[String]): Unit = { - intAccess(1024 * 40) - booleanAccess(1024 * 40) - stringAccess(1024 * 4) - arrayAccess(1024 * 40) + override def benchmark(): Unit = { + runBenchmark("Int Read/Write") { + intAccess(1024 * 40) + } + runBenchmark("Boolean Read/Write") { + booleanAccess(1024 * 40) + } + runBenchmark("String Read/Write") { + stringAccess(1024 * 4) + } + runBenchmark("Array Vector Read") { + arrayAccess(1024 * 40) + } } } From 5def10e61e49dba85f4d8b39c92bda15137990a2 Mon Sep 17 00:00:00 2001 From: Shahid Date: Wed, 26 Sep 2018 21:10:39 -0700 Subject: [PATCH 1709/2461] [SPARK-25536][CORE] metric value for METRIC_OUTPUT_RECORDS_WRITTEN is incorrect ## What changes were proposed in this pull request? changed metric value of METRIC_OUTPUT_RECORDS_WRITTEN from 'task.metrics.inputMetrics.recordsRead' to 'task.metrics.outputMetrics.recordsWritten'. This bug was introduced in SPARK-22190. https://github.com/apache/spark/pull/19426 ## How was this patch tested? Existing tests Closes #22555 from shahidki31/SPARK-25536. Authored-by: Shahid Signed-off-by: Dongjoon Hyun --- core/src/main/scala/org/apache/spark/executor/Executor.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 6d7d65626ea12..eba708da7798e 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -504,7 +504,7 @@ private[spark] class Executor( executorSource.METRIC_OUTPUT_BYTES_WRITTEN .inc(task.metrics.outputMetrics.bytesWritten) executorSource.METRIC_OUTPUT_RECORDS_WRITTEN - .inc(task.metrics.inputMetrics.recordsRead) + .inc(task.metrics.outputMetrics.recordsWritten) executorSource.METRIC_RESULT_SIZE.inc(task.metrics.resultSize) executorSource.METRIC_DISK_BYTES_SPILLED.inc(task.metrics.diskBytesSpilled) executorSource.METRIC_MEMORY_BYTES_SPILLED.inc(task.metrics.memoryBytesSpilled) From ee214ef3a0ec36c4aae5040778d41c376df3da19 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 27 Sep 2018 12:37:03 +0800 Subject: [PATCH 1710/2461] [SPARK-25525][SQL][PYSPARK] Do not update conf for existing SparkContext in SparkSession.getOrCreate. ## What changes were proposed in this pull request? In [SPARK-20946](https://issues.apache.org/jira/browse/SPARK-20946), we modified `SparkSession.getOrCreate` to not update conf for existing `SparkContext` because `SparkContext` is shared by all sessions. We should not update it in PySpark side as well. ## How was this patch tested? Added tests. Closes #22545 from ueshin/issues/SPARK-25525/not_update_existing_conf. Authored-by: Takuya UESHIN Signed-off-by: hyukjinkwon --- python/pyspark/sql/session.py | 14 +++-------- python/pyspark/sql/tests.py | 46 ++++++++++++++++++++++++++++++++++- 2 files changed, 49 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index a5e2872577312..079af8c05705d 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -156,7 +156,7 @@ def getOrCreate(self): default. >>> s1 = SparkSession.builder.config("k1", "v1").getOrCreate() - >>> s1.conf.get("k1") == s1.sparkContext.getConf().get("k1") == "v1" + >>> s1.conf.get("k1") == "v1" True In case an existing SparkSession is returned, the config options specified @@ -179,19 +179,13 @@ def getOrCreate(self): sparkConf = SparkConf() for key, value in self._options.items(): sparkConf.set(key, value) - sc = SparkContext.getOrCreate(sparkConf) # This SparkContext may be an existing one. - for key, value in self._options.items(): - # we need to propagate the confs - # before we create the SparkSession. Otherwise, confs like - # warehouse path and metastore url will not be set correctly ( - # these confs cannot be changed once the SparkSession is created). - sc._conf.set(key, value) + sc = SparkContext.getOrCreate(sparkConf) + # Do not update `SparkConf` for existing `SparkContext`, as it's shared + # by all sessions. session = SparkSession(sc) for key, value in self._options.items(): session._jsparkSession.sessionState().conf().setConfString(key, value) - for key, value in self._options.items(): - session.sparkContext._conf.set(key, value) return session builder = Builder() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 74642d46d1cd1..64a7ceb3fea96 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -80,7 +80,7 @@ _have_pyarrow = _pyarrow_requirement_message is None _test_compiled = _test_not_compiled_message is None -from pyspark import SparkContext +from pyspark import SparkConf, SparkContext from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row from pyspark.sql.types import * from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier @@ -283,6 +283,50 @@ def test_invalid_create_row(self): self.assertRaises(ValueError, lambda: row_class(1, 2, 3)) +class SparkSessionBuilderTests(unittest.TestCase): + + def test_create_spark_context_first_then_spark_session(self): + sc = None + session = None + try: + conf = SparkConf().set("key1", "value1") + sc = SparkContext('local[4]', "SessionBuilderTests", conf=conf) + session = SparkSession.builder.config("key2", "value2").getOrCreate() + + self.assertEqual(session.conf.get("key1"), "value1") + self.assertEqual(session.conf.get("key2"), "value2") + self.assertEqual(session.sparkContext, sc) + + self.assertFalse(sc.getConf().contains("key2")) + self.assertEqual(sc.getConf().get("key1"), "value1") + finally: + if session is not None: + session.stop() + if sc is not None: + sc.stop() + + def test_another_spark_session(self): + session1 = None + session2 = None + try: + session1 = SparkSession.builder.config("key1", "value1").getOrCreate() + session2 = SparkSession.builder.config("key2", "value2").getOrCreate() + + self.assertEqual(session1.conf.get("key1"), "value1") + self.assertEqual(session2.conf.get("key1"), "value1") + self.assertEqual(session1.conf.get("key2"), "value2") + self.assertEqual(session2.conf.get("key2"), "value2") + self.assertEqual(session1.sparkContext, session2.sparkContext) + + self.assertEqual(session1.sparkContext.getConf().get("key1"), "value1") + self.assertFalse(session1.sparkContext.getConf().contains("key2")) + finally: + if session1 is not None: + session1.stop() + if session2 is not None: + session2.stop() + + class SQLTests(ReusedSQLTestCase): @classmethod From 8b727994edd27104d49c6d690f93c6858fb9e1fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=B0=8F=E5=88=9A?= Date: Thu, 27 Sep 2018 00:02:05 -0500 Subject: [PATCH 1711/2461] [SPARK-25468][WEBUI] Highlight current page index in the spark UI MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR is highlight current page index in the spark UI and history server UI, https://issues.apache.org/jira/browse/SPARK-25468 I have add the following code in webui.css ``` .paginate_button.active>a { color: #999999; text-decoration: underline; } ``` ## How was this patch tested? Manual tests for Chrome, Firefox and Safari Before modifying: ![image](https://user-images.githubusercontent.com/10048468/45914897-01ca6c00-be7e-11e8-8e31-47d45db0c3bf.png) After modifying: ![image](https://user-images.githubusercontent.com/10048468/45913987-7e564e00-be70-11e8-9c16-de17e2c63308.png) Closes #22516 from Adamyuanyuan/spark-adam-25468. Lead-authored-by: 王小刚 Co-authored-by: Adam Wang Signed-off-by: Sean Owen --- core/src/main/resources/org/apache/spark/ui/static/webui.css | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 935d9b1aec615..4b060b0f4e53e 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -251,4 +251,9 @@ a.expandbutton { .table-cell-width-limited td { max-width: 600px; +} + +.paginate_button.active > a { + color: #999999; + text-decoration: underline; } \ No newline at end of file From f309b28bd9271719ca36fcf334f016ed6165a79b Mon Sep 17 00:00:00 2001 From: yucai Date: Wed, 26 Sep 2018 23:27:45 -0700 Subject: [PATCH 1712/2461] [SPARK-25485][SQL][TEST] Refactor UnsafeProjectionBenchmark to use main method ## What changes were proposed in this pull request? Refactor `UnsafeProjectionBenchmark` to use main method. Generate benchmark result: ``` SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "catalyst/test:runMain org.apache.spark.sql.UnsafeProjectionBenchmark" ``` ## How was this patch tested? manual test Closes #22493 from yucai/SPARK-25485. Lead-authored-by: yucai Co-authored-by: Yucai Yu Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../UnsafeProjectionBenchmark-results.txt | 14 ++ .../spark/sql/UnsafeProjectionBenchmark.scala | 172 +++++++++--------- 2 files changed, 98 insertions(+), 88 deletions(-) create mode 100644 sql/catalyst/benchmarks/UnsafeProjectionBenchmark-results.txt diff --git a/sql/catalyst/benchmarks/UnsafeProjectionBenchmark-results.txt b/sql/catalyst/benchmarks/UnsafeProjectionBenchmark-results.txt new file mode 100644 index 0000000000000..43156dc6fc67f --- /dev/null +++ b/sql/catalyst/benchmarks/UnsafeProjectionBenchmark-results.txt @@ -0,0 +1,14 @@ +================================================================================================ +unsafe projection +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +unsafe projection: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +single long 2867 / 2868 93.6 10.7 1.0X +single nullable long 3915 / 3949 68.6 14.6 0.7X +7 primitive types 8166 / 8167 32.9 30.4 0.4X +7 nullable primitive types 12767 / 12767 21.0 47.6 0.2X + + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala index faff681e13955..cbe723fd11c6a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.benchmark.Benchmark +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeProjection @@ -25,8 +25,15 @@ import org.apache.spark.sql.types._ /** * Benchmark `UnsafeProjection` for fixed-length/primitive-type fields. + * {{{ + * To run this benchmark: + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/UnsafeProjectionBenchmark-results.txt". + * }}} */ -object UnsafeProjectionBenchmark { +object UnsafeProjectionBenchmark extends BenchmarkBase { def generateRows(schema: StructType, numRows: Int): Array[InternalRow] = { val generator = RandomDataGenerator.forType(schema, nullable = false).get @@ -34,103 +41,92 @@ object UnsafeProjectionBenchmark { (1 to numRows).map(_ => encoder.toRow(generator().asInstanceOf[Row]).copy()).toArray } - def main(args: Array[String]) { - val iters = 1024 * 16 - val numRows = 1024 * 16 - - val benchmark = new Benchmark("unsafe projection", iters * numRows.toLong) - - - val schema1 = new StructType().add("l", LongType, false) - val attrs1 = schema1.toAttributes - val rows1 = generateRows(schema1, numRows) - val projection1 = UnsafeProjection.create(attrs1, attrs1) - - benchmark.addCase("single long") { _ => - for (_ <- 1 to iters) { - var sum = 0L - var i = 0 - while (i < numRows) { - sum += projection1(rows1(i)).getLong(0) - i += 1 + override def benchmark(): Unit = { + runBenchmark("unsafe projection") { + val iters = 1024 * 16 + val numRows = 1024 * 16 + + val benchmark = new Benchmark("unsafe projection", iters * numRows.toLong, output = output) + + val schema1 = new StructType().add("l", LongType, false) + val attrs1 = schema1.toAttributes + val rows1 = generateRows(schema1, numRows) + val projection1 = UnsafeProjection.create(attrs1, attrs1) + + benchmark.addCase("single long") { _ => + for (_ <- 1 to iters) { + var sum = 0L + var i = 0 + while (i < numRows) { + sum += projection1(rows1(i)).getLong(0) + i += 1 + } } } - } - - val schema2 = new StructType().add("l", LongType, true) - val attrs2 = schema2.toAttributes - val rows2 = generateRows(schema2, numRows) - val projection2 = UnsafeProjection.create(attrs2, attrs2) - benchmark.addCase("single nullable long") { _ => - for (_ <- 1 to iters) { - var sum = 0L - var i = 0 - while (i < numRows) { - sum += projection2(rows2(i)).getLong(0) - i += 1 + val schema2 = new StructType().add("l", LongType, true) + val attrs2 = schema2.toAttributes + val rows2 = generateRows(schema2, numRows) + val projection2 = UnsafeProjection.create(attrs2, attrs2) + + benchmark.addCase("single nullable long") { _ => + for (_ <- 1 to iters) { + var sum = 0L + var i = 0 + while (i < numRows) { + sum += projection2(rows2(i)).getLong(0) + i += 1 + } } } - } - - val schema3 = new StructType() - .add("boolean", BooleanType, false) - .add("byte", ByteType, false) - .add("short", ShortType, false) - .add("int", IntegerType, false) - .add("long", LongType, false) - .add("float", FloatType, false) - .add("double", DoubleType, false) - val attrs3 = schema3.toAttributes - val rows3 = generateRows(schema3, numRows) - val projection3 = UnsafeProjection.create(attrs3, attrs3) - - benchmark.addCase("7 primitive types") { _ => - for (_ <- 1 to iters) { - var sum = 0L - var i = 0 - while (i < numRows) { - sum += projection3(rows3(i)).getLong(0) - i += 1 + val schema3 = new StructType() + .add("boolean", BooleanType, false) + .add("byte", ByteType, false) + .add("short", ShortType, false) + .add("int", IntegerType, false) + .add("long", LongType, false) + .add("float", FloatType, false) + .add("double", DoubleType, false) + val attrs3 = schema3.toAttributes + val rows3 = generateRows(schema3, numRows) + val projection3 = UnsafeProjection.create(attrs3, attrs3) + + benchmark.addCase("7 primitive types") { _ => + for (_ <- 1 to iters) { + var sum = 0L + var i = 0 + while (i < numRows) { + sum += projection3(rows3(i)).getLong(0) + i += 1 + } } } - } - - - val schema4 = new StructType() - .add("boolean", BooleanType, true) - .add("byte", ByteType, true) - .add("short", ShortType, true) - .add("int", IntegerType, true) - .add("long", LongType, true) - .add("float", FloatType, true) - .add("double", DoubleType, true) - val attrs4 = schema4.toAttributes - val rows4 = generateRows(schema4, numRows) - val projection4 = UnsafeProjection.create(attrs4, attrs4) - benchmark.addCase("7 nullable primitive types") { _ => - for (_ <- 1 to iters) { - var sum = 0L - var i = 0 - while (i < numRows) { - sum += projection4(rows4(i)).getLong(0) - i += 1 + val schema4 = new StructType() + .add("boolean", BooleanType, true) + .add("byte", ByteType, true) + .add("short", ShortType, true) + .add("int", IntegerType, true) + .add("long", LongType, true) + .add("float", FloatType, true) + .add("double", DoubleType, true) + val attrs4 = schema4.toAttributes + val rows4 = generateRows(schema4, numRows) + val projection4 = UnsafeProjection.create(attrs4, attrs4) + + benchmark.addCase("7 nullable primitive types") { _ => + for (_ <- 1 to iters) { + var sum = 0L + var i = 0 + while (i < numRows) { + sum += projection4(rows4(i)).getLong(0) + i += 1 + } } } - } - - /* - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - unsafe projection: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------- - single long 1533.34 175.07 1.00 X - single nullable long 2306.73 116.37 0.66 X - primitive types 8403.93 31.94 0.18 X - nullable primitive types 12448.39 21.56 0.12 X - */ - benchmark.run() + benchmark.run() + } } } From ff876137faba1802b66ecd483ba15f6ccd83ffc5 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 27 Sep 2018 15:02:20 +0800 Subject: [PATCH 1713/2461] [SPARK-23715][SQL][DOC] improve document for from/to_utc_timestamp ## What changes were proposed in this pull request? We have an agreement that the behavior of `from/to_utc_timestamp` is corrected, although the function itself doesn't make much sense in Spark: https://issues.apache.org/jira/browse/SPARK-23715 This PR improves the document. ## How was this patch tested? N/A Closes #22543 from cloud-fan/doc. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- R/pkg/R/functions.R | 26 ++++++++++++---- python/pyspark/sql/functions.py | 30 +++++++++++++++---- .../expressions/datetimeExpressions.scala | 30 +++++++++++++++---- 3 files changed, 68 insertions(+), 18 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 6425c9d26bef3..2cb4cb8d531e1 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2204,9 +2204,16 @@ setMethod("from_json", signature(x = "Column", schema = "characterOrstructType") }) #' @details -#' \code{from_utc_timestamp}: Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a -#' time in UTC, and renders that time as a timestamp in the given time zone. For example, 'GMT+1' -#' would yield '2017-07-14 03:40:00.0'. +#' \code{from_utc_timestamp}: This is a common function for databases supporting TIMESTAMP WITHOUT +#' TIMEZONE. This function takes a timestamp which is timezone-agnostic, and interprets it as a +#' timestamp in UTC, and renders that timestamp as a timestamp in the given time zone. +#' However, timestamp in Spark represents number of microseconds from the Unix epoch, which is not +#' timezone-agnostic. So in Spark this function just shift the timestamp value from UTC timezone to +#' the given timezone. +#' This function may return confusing result if the input is a string with timezone, e.g. +#' (\code{2018-03-13T06:18:23+00:00}). The reason is that, Spark firstly cast the string to +#' timestamp according to the timezone in the string, and finally display the result by converting +#' the timestamp to string according to the session local timezone. #' #' @rdname column_datetime_diff_functions #' @@ -2262,9 +2269,16 @@ setMethod("next_day", signature(y = "Column", x = "character"), }) #' @details -#' \code{to_utc_timestamp}: Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a -#' time in the given time zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' -#' would yield '2017-07-14 01:40:00.0'. +#' \code{to_utc_timestamp}: This is a common function for databases supporting TIMESTAMP WITHOUT +#' TIMEZONE. This function takes a timestamp which is timezone-agnostic, and interprets it as a +#' timestamp in the given timezone, and renders that timestamp as a timestamp in UTC. +#' However, timestamp in Spark represents number of microseconds from the Unix epoch, which is not +#' timezone-agnostic. So in Spark this function just shift the timestamp value from the given +#' timezone to UTC timezone. +#' This function may return confusing result if the input is a string with timezone, e.g. +#' (\code{2018-03-13T06:18:23+00:00}). The reason is that, Spark firstly cast the string to +#' timestamp according to the timezone in the string, and finally display the result by converting +#' the timestamp to string according to the session local timezone. #' #' @rdname column_datetime_diff_functions #' @aliases to_utc_timestamp to_utc_timestamp,Column,character-method diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1c3d9725b285b..e5bc1eaaad21a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1283,9 +1283,18 @@ def unix_timestamp(timestamp=None, format='yyyy-MM-dd HH:mm:ss'): @since(1.5) def from_utc_timestamp(timestamp, tz): """ - Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders - that time as a timestamp in the given time zone. For example, 'GMT+1' would yield - '2017-07-14 03:40:00.0'. + This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function + takes a timestamp which is timezone-agnostic, and interprets it as a timestamp in UTC, and + renders that timestamp as a timestamp in the given time zone. + + However, timestamp in Spark represents number of microseconds from the Unix epoch, which is not + timezone-agnostic. So in Spark this function just shift the timestamp value from UTC timezone to + the given timezone. + + This function may return confusing result if the input is a string with timezone, e.g. + '2018-03-13T06:18:23+00:00'. The reason is that, Spark firstly cast the string to timestamp + according to the timezone in the string, and finally display the result by converting the + timestamp to string according to the session local timezone. :param timestamp: the column that contains timestamps :param tz: a string that has the ID of timezone, e.g. "GMT", "America/Los_Angeles", etc @@ -1308,9 +1317,18 @@ def from_utc_timestamp(timestamp, tz): @since(1.5) def to_utc_timestamp(timestamp, tz): """ - Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time - zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield - '2017-07-14 01:40:00.0'. + This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function + takes a timestamp which is timezone-agnostic, and interprets it as a timestamp in the given + timezone, and renders that timestamp as a timestamp in UTC. + + However, timestamp in Spark represents number of microseconds from the Unix epoch, which is not + timezone-agnostic. So in Spark this function just shift the timestamp value from the given + timezone to UTC timezone. + + This function may return confusing result if the input is a string with timezone, e.g. + '2018-03-13T06:18:23+00:00'. The reason is that, Spark firstly cast the string to timestamp + according to the timezone in the string, and finally display the result by converting the + timestamp to string according to the session local timezone. :param timestamp: the column that contains timestamps :param tz: a string that has the ID of timezone, e.g. "GMT", "America/Los_Angeles", etc diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index eb78e394f9850..45e17ae235a94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1018,9 +1018,18 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S } /** - * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders - * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield - * '2017-07-14 03:40:00.0'. + * This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function + * takes a timestamp which is timezone-agnostic, and interprets it as a timestamp in UTC, and + * renders that timestamp as a timestamp in the given time zone. + * + * However, timestamp in Spark represents number of microseconds from the Unix epoch, which is not + * timezone-agnostic. So in Spark this function just shift the timestamp value from UTC timezone to + * the given timezone. + * + * This function may return confusing result if the input is a string with timezone, e.g. + * '2018-03-13T06:18:23+00:00'. The reason is that, Spark firstly cast the string to timestamp + * according to the timezone in the string, and finally display the result by converting the + * timestamp to string according to the session local timezone. */ // scalastyle:off line.size.limit @ExpressionDescription( @@ -1215,9 +1224,18 @@ case class MonthsBetween( } /** - * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time zone, - * and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield - * '2017-07-14 01:40:00.0'. + * This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function + * takes a timestamp which is timezone-agnostic, and interprets it as a timestamp in the given + * timezone, and renders that timestamp as a timestamp in UTC. + * + * However, timestamp in Spark represents number of microseconds from the Unix epoch, which is not + * timezone-agnostic. So in Spark this function just shift the timestamp value from the given + * timezone to UTC timezone. + * + * This function may return confusing result if the input is a string with timezone, e.g. + * '2018-03-13T06:18:23+00:00'. The reason is that, Spark firstly cast the string to timestamp + * according to the timezone in the string, and finally display the result by converting the + * timestamp to string according to the session local timezone. */ // scalastyle:off line.size.limit @ExpressionDescription( From d03e0af80d7659f12821cc2442efaeaee94d3985 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 27 Sep 2018 15:04:59 +0800 Subject: [PATCH 1714/2461] [SPARK-25522][SQL] Improve type promotion for input arguments of elementAt function ## What changes were proposed in this pull request? In ElementAt, when first argument is MapType, we should coerce the key type and the second argument based on findTightestCommonType. This is not happening currently. We may produce wrong output as we will incorrectly downcast the right hand side double expression to int. ```SQL spark-sql> select element_at(map(1,"one", 2, "two"), 2.2); two ``` Also, when the first argument is ArrayType, the second argument should be an integer type or a smaller integral type that can be safely casted to an integer type. Currently we may do an unsafe cast. In the following case, we should fail with an error as 2.2 is not a integer index. But instead we down cast it to int currently and return a result instead. ```SQL spark-sql> select element_at(array(1,2), 1.24D); 1 ``` This PR also supports implicit cast between two MapTypes. I have followed similar logic that exists today to do implicit casts between two array types. ## How was this patch tested? Added new tests in DataFrameFunctionSuite, TypeCoercionSuite. Closes #22544 from dilipbiswal/SPARK-25522. Authored-by: Dilip Biswal Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/TypeCoercion.scala | 19 +++++ .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../expressions/collectionOperations.scala | 37 ++++++--- .../catalyst/analysis/TypeCoercionSuite.scala | 43 +++++++++-- .../spark/sql/DataFrameFunctionsSuite.scala | 75 ++++++++++++++++++- 5 files changed, 154 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 49d286f6cf125..72ac80e0a0a18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -950,6 +950,25 @@ object TypeCoercion { if !Cast.forceNullable(fromType, toType) => implicitCast(fromType, toType).map(ArrayType(_, false)).orNull + // Implicit cast between Map types. + // Follows the same semantics of implicit casting between two array types. + // Refer to documentation above. Make sure that both key and values + // can not be null after the implicit cast operation by calling forceNullable + // method. + case (MapType(fromKeyType, fromValueType, fn), MapType(toKeyType, toValueType, tn)) + if !Cast.forceNullable(fromKeyType, toKeyType) && Cast.resolvableNullability(fn, tn) => + if (Cast.forceNullable(fromValueType, toValueType) && !tn) { + null + } else { + val newKeyType = implicitCast(fromKeyType, toKeyType).orNull + val newValueType = implicitCast(fromValueType, toValueType).orNull + if (newKeyType != null && newValueType != null) { + MapType(newKeyType, newValueType, tn) + } else { + null + } + } + case _ => null } Option(ret) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 8f777997bf615..ee463bf5eb6ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -183,7 +183,7 @@ object Cast { case _ => false } - private def resolvableNullability(from: Boolean, to: Boolean) = !from || to + def resolvableNullability(from: Boolean, to: Boolean): Boolean = !from || to } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 9cc7dbadd923a..b24d7486f3454 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2154,21 +2154,34 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti } override def inputTypes: Seq[AbstractDataType] = { - Seq(TypeCollection(ArrayType, MapType), - left.dataType match { - case _: ArrayType => IntegerType - case _: MapType => mapKeyType - case _ => AnyDataType // no match for a wrong 'left' expression type - } - ) + (left.dataType, right.dataType) match { + case (arr: ArrayType, e2: IntegralType) if (e2 != LongType) => + Seq(arr, IntegerType) + case (MapType(keyType, valueType, hasNull), e2) => + TypeCoercion.findTightestCommonType(keyType, e2) match { + case Some(dt) => Seq(MapType(dt, valueType, hasNull), dt) + case _ => Seq.empty + } + case (l, r) => Seq.empty + + } } override def checkInputDataTypes(): TypeCheckResult = { - super.checkInputDataTypes() match { - case f: TypeCheckResult.TypeCheckFailure => f - case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] => - TypeUtils.checkForOrderingExpr(mapKeyType, s"function $prettyName") - case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess + (left.dataType, right.dataType) match { + case (_: ArrayType, e2) if e2 != IntegerType => + TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + + s"been ${ArrayType.simpleString} followed by a ${IntegerType.simpleString}, but it's " + + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") + case (MapType(e1, _, _), e2) if (!e2.sameType(e1)) => + TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + + s"been ${MapType.simpleString} followed by a value of same key type, but it's " + + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") + case (e1, _) if (!e1.isInstanceOf[MapType] && !e1.isInstanceOf[ArrayType]) => + TypeCheckResult.TypeCheckFailure(s"The first argument to function $prettyName should " + + s"have been ${ArrayType.simpleString} or ${MapType.simpleString} type, but its " + + s"${left.dataType.catalogString} type.") + case _ => TypeCheckResult.TypeCheckSuccess } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 0594673ecc926..0eba1c537d67d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -257,12 +257,43 @@ class TypeCoercionSuite extends AnalysisTest { shouldNotCast(checkedType, IntegralType) } - test("implicit type cast - MapType(StringType, StringType)") { - val checkedType = MapType(StringType, StringType) - checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) - shouldNotCast(checkedType, DecimalType) - shouldNotCast(checkedType, NumericType) - shouldNotCast(checkedType, IntegralType) + test("implicit type cast between two Map types") { + val sourceType = MapType(IntegerType, IntegerType, true) + val castableTypes = numericTypes ++ Seq(StringType).filter(!Cast.forceNullable(IntegerType, _)) + val targetTypes = numericTypes.filter(!Cast.forceNullable(IntegerType, _)).map { t => + MapType(t, sourceType.valueType, valueContainsNull = true) + } + val nonCastableTargetTypes = allTypes.filterNot(castableTypes.contains(_)).map {t => + MapType(t, sourceType.valueType, valueContainsNull = true) + } + + // Tests that its possible to setup implicit casts between two map types when + // source map's key type is integer and the target map's key type are either Byte, Short, + // Long, Double, Float, Decimal(38, 18) or String. + targetTypes.foreach { targetType => + shouldCast(sourceType, targetType, targetType) + } + + // Tests that its not possible to setup implicit casts between two map types when + // source map's key type is integer and the target map's key type are either Binary, + // Boolean, Date, Timestamp, Array, Struct, CaleandarIntervalType or NullType + nonCastableTargetTypes.foreach { targetType => + shouldNotCast(sourceType, targetType) + } + + // Tests that its not possible to cast from nullable map type to not nullable map type. + val targetNotNullableTypes = allTypes.filterNot(_ == IntegerType).map { t => + MapType(t, sourceType.valueType, valueContainsNull = false) + } + val sourceMapExprWithValueNull = + CreateMap(Seq(Literal.default(sourceType.keyType), + Literal.create(null, sourceType.valueType))) + targetNotNullableTypes.foreach { targetType => + val castDefault = + TypeCoercion.ImplicitTypeCasts.implicitCast(sourceMapExprWithValueNull, targetType) + assert(castDefault.isEmpty, + s"Should not be able to cast $sourceType to $targetType, but got $castDefault") + } } test("implicit type cast - StructType().add(\"a1\", StringType)") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 88dbae8c21350..60ebc5e6cc09b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1211,11 +1211,80 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq(Row("3"), Row(""), Row(null)) ) - val e = intercept[AnalysisException] { + val e1 = intercept[AnalysisException] { Seq(("a string element", 1)).toDF().selectExpr("element_at(_1, _2)") } - assert(e.message.contains( - "argument 1 requires (array or map) type, however, '`_1`' is of string type")) + val errorMsg1 = + s""" + |The first argument to function element_at should have been array or map type, but + |its string type. + """.stripMargin.replace("\n", " ").trim() + assert(e1.message.contains(errorMsg1)) + + checkAnswer( + OneRowRelation().selectExpr("element_at(array(2, 1), 2S)"), + Seq(Row(1)) + ) + + checkAnswer( + OneRowRelation().selectExpr("element_at(array('a', 'b'), 1Y)"), + Seq(Row("a")) + ) + + checkAnswer( + OneRowRelation().selectExpr("element_at(array(1, 2, 3), 3)"), + Seq(Row(3)) + ) + + val e2 = intercept[AnalysisException] { + OneRowRelation().selectExpr("element_at(array('a', 'b'), 1L)") + } + val errorMsg2 = + s""" + |Input to function element_at should have been array followed by a int, but it's + |[array, bigint]. + """.stripMargin.replace("\n", " ").trim() + assert(e2.message.contains(errorMsg2)) + + checkAnswer( + OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 2Y)"), + Seq(Row("b")) + ) + + checkAnswer( + OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1S)"), + Seq(Row("a")) + ) + + checkAnswer( + OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 2)"), + Seq(Row("b")) + ) + + checkAnswer( + OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 2L)"), + Seq(Row("b")) + ) + + checkAnswer( + OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1.0D)"), + Seq(Row("a")) + ) + + checkAnswer( + OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1.23D)"), + Seq(Row(null)) + ) + + val e3 = intercept[AnalysisException] { + OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), '1')") + } + val errorMsg3 = + s""" + |Input to function element_at should have been map followed by a value of same + |key type, but it's [map, string]. + """.stripMargin.replace("\n", " ").trim() + assert(e3.message.contains(errorMsg3)) } test("array_union functions") { From 2a8cbfddba2a59d144b32910c68c22d0199093fe Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Thu, 27 Sep 2018 15:13:18 +0800 Subject: [PATCH 1715/2461] [SPARK-25314][SQL] Fix Python UDF accessing attributes from both side of join in join conditions ## What changes were proposed in this pull request? Thanks for bahchis reporting this. It is more like a follow up work for #16581, this PR fix the scenario of Python UDF accessing attributes from both side of join in join condition. ## How was this patch tested? Add regression tests in PySpark and `BatchEvalPythonExecSuite`. Closes #22326 from xuanyuanking/SPARK-25314. Authored-by: Yuanjian Li Signed-off-by: Wenchen Fan --- python/pyspark/sql/tests.py | 64 +++++++++++++++++++ .../sql/catalyst/optimizer/Optimizer.scala | 8 ++- .../spark/sql/catalyst/optimizer/joins.scala | 49 ++++++++++++++ 3 files changed, 119 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 64a7ceb3fea96..b88a6551f8ae5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -596,6 +596,70 @@ def test_udf_in_filter_on_top_of_join(self): df = left.crossJoin(right).filter(f("a", "b")) self.assertEqual(df.collect(), [Row(a=1, b=1)]) + def test_udf_in_join_condition(self): + # regression test for SPARK-25314 + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1)]) + right = self.spark.createDataFrame([Row(b=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + df = left.join(right, f("a", "b")) + with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'): + df.collect() + with self.sql_conf({"spark.sql.crossJoin.enabled": True}): + self.assertEqual(df.collect(), [Row(a=1, b=1)]) + + def test_udf_in_left_semi_join_condition(self): + # regression test for SPARK-25314 + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) + right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + df = left.join(right, f("a", "b"), "leftsemi") + with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'): + df.collect() + with self.sql_conf({"spark.sql.crossJoin.enabled": True}): + self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)]) + + def test_udf_and_common_filter_in_join_condition(self): + # regression test for SPARK-25314 + # test the complex scenario with both udf and common filter + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) + right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + df = left.join(right, [f("a", "b"), left.a1 == right.b1]) + # do not need spark.sql.crossJoin.enabled=true for udf is not the only join condition. + self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)]) + + def test_udf_and_common_filter_in_left_semi_join_condition(self): + # regression test for SPARK-25314 + # test the complex scenario with both udf and common filter + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) + right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + df = left.join(right, [f("a", "b"), left.a1 == right.b1], "left_semi") + # do not need spark.sql.crossJoin.enabled=true for udf is not the only join condition. + self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)]) + + def test_udf_not_supported_in_join_condition(self): + # regression test for SPARK-25314 + # test python udf is not supported in join type besides left_semi and inner join. + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) + right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + + def runWithJoinType(join_type, type_string): + with self.assertRaisesRegexp( + AnalysisException, + 'Using PythonUDF.*%s is not supported.' % type_string): + left.join(right, [f("a", "b"), left.a1 == right.b1], join_type).collect() + runWithJoinType("full", "FullOuter") + runWithJoinType("left", "LeftOuter") + runWithJoinType("right", "RightOuter") + runWithJoinType("leftanti", "LeftAnti") + def test_udf_without_arguments(self): self.spark.catalog.registerFunction("foo", lambda: "bar") [row] = self.spark.sql("SELECT foo()").collect() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 07a653f3b5d48..da8009d50b5ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -165,7 +165,10 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) Batch("LocalRelation", fixedPoint, ConvertToLocalRelation, PropagateEmptyRelation) :+ - // The following batch should be executed after batch "Join Reorder" and "LocalRelation". + Batch("Extract PythonUDF From JoinCondition", Once, + PullOutPythonUDFInJoinCondition) :+ + // The following batch should be executed after batch "Join Reorder" "LocalRelation" and + // "Extract PythonUDF From JoinCondition". Batch("Check Cartesian Products", Once, CheckCartesianProducts) :+ Batch("RewriteSubquery", Once, @@ -202,7 +205,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) ReplaceDistinctWithAggregate.ruleName :: PullupCorrelatedPredicates.ruleName :: RewriteCorrelatedScalarSubquery.ruleName :: - RewritePredicateSubquery.ruleName :: Nil + RewritePredicateSubquery.ruleName :: + PullOutPythonUDFInJoinCondition.ruleName :: Nil /** * Optimize all the subqueries inside expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index edbeaf273fd6f..7149edee0173e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins import org.apache.spark.sql.catalyst.plans._ @@ -152,3 +153,51 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType)) } } + +/** + * PythonUDF in join condition can not be evaluated, this rule will detect the PythonUDF + * and pull them out from join condition. For python udf accessing attributes from only one side, + * they are pushed down by operation push down rules. If not (e.g. user disables filter push + * down rules), we need to pull them out in this rule too. + */ +object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateHelper { + def hasPythonUDF(expression: Expression): Boolean = { + expression.collectFirst { case udf: PythonUDF => udf }.isDefined + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case j @ Join(_, _, joinType, condition) + if condition.isDefined && hasPythonUDF(condition.get) => + if (!joinType.isInstanceOf[InnerLike] && joinType != LeftSemi) { + // The current strategy only support InnerLike and LeftSemi join because for other type, + // it breaks SQL semantic if we run the join condition as a filter after join. If we pass + // the plan here, it'll still get a an invalid PythonUDF RuntimeException with message + // `requires attributes from more than one child`, we throw firstly here for better + // readable information. + throw new AnalysisException("Using PythonUDF in join condition of join type" + + s" $joinType is not supported.") + } + // If condition expression contains python udf, it will be moved out from + // the new join conditions. + val (udf, rest) = + splitConjunctivePredicates(condition.get).partition(hasPythonUDF) + val newCondition = if (rest.isEmpty) { + logWarning(s"The join condition:$condition of the join plan contains PythonUDF only," + + s" it will be moved out and the join plan will be turned to cross join.") + None + } else { + Some(rest.reduceLeft(And)) + } + val newJoin = j.copy(condition = newCondition) + joinType match { + case _: InnerLike => Filter(udf.reduceLeft(And), newJoin) + case LeftSemi => + Project( + j.left.output.map(_.toAttribute), + Filter(udf.reduceLeft(And), newJoin.copy(joinType = Inner))) + case _ => + throw new AnalysisException("Using PythonUDF in join condition of join type" + + s" $joinType is not supported.") + } + } +} From 86a2450e09cbd3affbd66139ce4ed2b807e7b3b3 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 27 Sep 2018 19:34:05 +0800 Subject: [PATCH 1716/2461] [SPARK-25551][SQL] Remove unused InSubquery expression ## What changes were proposed in this pull request? The PR removes the `InSubquery` expression which was introduced a long time ago and its only usage was removed in https://github.com/apache/spark/commit/4ce970d71488c7de6025ef925f75b8b92a5a6a79. Hence it is not used anymore. ## How was this patch tested? existing UTs Closes #22556 from mgaido91/minor_insubq. Authored-by: Marco Gaido Signed-off-by: Wenchen Fan --- .../apache/spark/sql/execution/subquery.scala | 43 ------------------- 1 file changed, 43 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index d11045fb6ac8c..310ebcdf67686 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -90,49 +90,6 @@ case class ScalarSubquery( } } -/** - * A subquery that will check the value of `child` whether is in the result of a query or not. - */ -case class InSubquery( - child: Expression, - plan: SubqueryExec, - exprId: ExprId, - private var result: Array[Any] = null, - private var updated: Boolean = false) extends ExecSubqueryExpression { - - override def dataType: DataType = BooleanType - override def children: Seq[Expression] = child :: Nil - override def nullable: Boolean = child.nullable - override def toString: String = s"$child IN ${plan.name}" - override def withNewPlan(plan: SubqueryExec): InSubquery = copy(plan = plan) - - override def semanticEquals(other: Expression): Boolean = other match { - case in: InSubquery => child.semanticEquals(in.child) && plan.sameResult(in.plan) - case _ => false - } - - def updateResult(): Unit = { - val rows = plan.executeCollect() - result = rows.map(_.get(0, child.dataType)).asInstanceOf[Array[Any]] - updated = true - } - - override def eval(input: InternalRow): Any = { - require(updated, s"$this has not finished") - val v = child.eval(input) - if (v == null) { - null - } else { - result.contains(v) - } - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - require(updated, s"$this has not finished") - InSet(child, result.toSet).doGenCode(ctx, ev) - } -} - /** * Plans scalar subqueries from that are present in the given [[SparkPlan]]. */ From dd8f6b1ce8ae7b2b75efda863fea40b29d52f657 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 27 Sep 2018 19:53:13 +0800 Subject: [PATCH 1717/2461] [SPARK-25541][SQL][FOLLOWUP] Remove overriding filterKeys in CaseInsensitiveMap ## What changes were proposed in this pull request? As per the discussion in https://github.com/apache/spark/pull/22553#pullrequestreview-159192221, override `filterKeys` violates the documented semantics. This PR is to remove it and add documentation. Also fix one potential non-serializable map in `FileStreamOptions`. The only one call of `CaseInsensitiveMap`'s `filterKeys` left is https://github.com/apache/spark/blob/c3c45cbd76d91d591d98cf8411fcfd30079f5969/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala#L88-L90 But this one is OK. ## How was this patch tested? Existing unit tests. Closes #22562 from gengliangwang/SPARK-25541-FOLLOWUP. Authored-by: Gengliang Wang Signed-off-by: Wenchen Fan --- .../apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala | 6 ++---- .../spark/sql/catalyst/util/CaseInsensitiveMapSuite.scala | 6 ------ .../spark/sql/execution/streaming/FileStreamOptions.scala | 3 +-- 3 files changed, 3 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala index 288a4f34a447e..06f95989f2e3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala @@ -24,6 +24,8 @@ import java.util.Locale * case-sensitive information is required. The primary constructor is marked private to avoid * nested case-insensitive map creation, otherwise the keys in the original map will become * case-insensitive in this scenario. + * Note: CaseInsensitiveMap is serializable. However, after transformation, e.g. `filterKeys()`, + * it may become not serializable. */ class CaseInsensitiveMap[T] private (val originalMap: Map[String, T]) extends Map[String, T] with Serializable { @@ -44,10 +46,6 @@ class CaseInsensitiveMap[T] private (val originalMap: Map[String, T]) extends Ma override def -(key: String): Map[String, T] = { new CaseInsensitiveMap(originalMap.filter(!_._1.equalsIgnoreCase(key))) } - - override def filterKeys(p: (String) => Boolean): Map[String, T] = { - new CaseInsensitiveMap(originalMap.filter(kv => p(kv._1))) - } } object CaseInsensitiveMap { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMapSuite.scala index 03eed4aaa750b..a8bb1d0afdb87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMapSuite.scala @@ -44,10 +44,4 @@ class CaseInsensitiveMapSuite extends SparkFunSuite { assert(m == Map("a" -> "b", "foo" -> "bar", "x" -> "y")) shouldBeSerializable(m) } - - test("CaseInsensitiveMap should be serializable after 'filterKeys' method") { - val m = CaseInsensitiveMap(Map("a" -> "b", "foo" -> "bar")).filterKeys(_ == "foo") - assert(m == Map("foo" -> "bar")) - shouldBeSerializable(m) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala index d54ed44b43bf1..1d57cb084df9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala @@ -54,8 +54,7 @@ class FileStreamOptions(parameters: CaseInsensitiveMap[String]) extends Logging Utils.timeStringAsMs(parameters.getOrElse("maxFileAge", "7d")) /** Options as specified by the user, in a case-insensitive map, without "path" set. */ - val optionMapWithoutPath: Map[String, String] = - parameters.filterKeys(_ != "path") + val optionMapWithoutPath: Map[String, String] = parameters - "path" /** * Whether to scan latest files first. If it's true, when the source finds unprocessed files in a From f856fe4839757e3a1036df3fc3dec459fa439aef Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 27 Sep 2018 20:57:56 +0800 Subject: [PATCH 1718/2461] [SPARK-21436][CORE] Take advantage of known partitioner for distinct on RDDs to avoid a shuffle ## What changes were proposed in this pull request? Special case the situation where we know the partioner and the number of requested partions output is the same as the current partioner to avoid a shuffle and instead compute distinct inside of each partion. ## How was this patch tested? New unit test that verifies partitioner does not change if the partitioner is known and distinct is called with the same target # of partition. Closes #22010 from holdenk/SPARK-21436-take-advantage-of-known-partioner-for-distinct-on-rdds. Authored-by: Holden Karau Signed-off-by: Wenchen Fan --- .../main/scala/org/apache/spark/rdd/RDD.scala | 18 ++++++++++++++++-- .../scala/org/apache/spark/rdd/RDDSuite.scala | 12 ++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 61ad6dfdb2215..743e3441eea55 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -42,7 +42,8 @@ import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.{BoundedPriorityQueue, Utils} -import org.apache.spark.util.collection.{OpenHashMap, Utils => collectionUtils} +import org.apache.spark.util.collection.{ExternalAppendOnlyMap, OpenHashMap, + Utils => collectionUtils} import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler, SamplingUtils} @@ -396,7 +397,20 @@ abstract class RDD[T: ClassTag]( * Return a new RDD containing the distinct elements in this RDD. */ def distinct(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] = withScope { - map(x => (x, null)).reduceByKey((x, y) => x, numPartitions).map(_._1) + def removeDuplicatesInPartition(partition: Iterator[T]): Iterator[T] = { + // Create an instance of external append only map which ignores values. + val map = new ExternalAppendOnlyMap[T, Null, Null]( + createCombiner = value => null, + mergeValue = (a, b) => a, + mergeCombiners = (a, b) => a) + map.insertAll(partition.map(_ -> null)) + map.iterator.map(_._1) + } + partitioner match { + case Some(p) if numPartitions == partitions.length => + mapPartitions(removeDuplicatesInPartition, preservesPartitioning = true) + case _ => map(x => (x, null)).reduceByKey((x, y) => x, numPartitions).map(_._1) + } } /** diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index b143a468a1baf..2227698cf1ad2 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -95,6 +95,18 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(!deserial.toString().isEmpty()) } + test("distinct with known partitioner preserves partitioning") { + val rdd = sc.parallelize(1.to(100), 10).map(x => (x % 10, x % 10)).sortByKey() + val initialPartitioner = rdd.partitioner + val distinctRdd = rdd.distinct() + val resultingPartitioner = distinctRdd.partitioner + assert(initialPartitioner === resultingPartitioner) + val distinctRddDifferent = rdd.distinct(5) + val distinctRddDifferentPartitioner = distinctRddDifferent.partitioner + assert(initialPartitioner != distinctRddDifferentPartitioner) + assert(distinctRdd.collect().sorted === distinctRddDifferent.collect().sorted) + } + test("countApproxDistinct") { def error(est: Long, size: Long): Double = math.abs(est - size) / size.toDouble From a1adde54086469b45950946d9143d17daab01f18 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 27 Sep 2018 21:19:25 +0800 Subject: [PATCH 1719/2461] [SPARK-24341][SQL][FOLLOWUP] remove duplicated error checking ## What changes were proposed in this pull request? There are 2 places we check for problematic `InSubquery`: the rule `ResolveSubquery` and `InSubquery.checkInputDataTypes`. We should unify them. ## How was this patch tested? existing tests Closes #22563 from cloud-fan/followup. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/Analyzer.scala | 16 +---- .../sql/catalyst/expressions/predicates.scala | 60 +++++++++---------- .../sql-tests/results/datetime.sql.out | 5 +- .../results/higher-order-functions.sql.out | 1 + .../subquery/in-subquery/in-basic.sql.out | 10 ++-- .../subq-input-typecheck.sql.out | 20 +++---- 6 files changed, 49 insertions(+), 63 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e3b17121bf350..7034dfdafad33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1436,21 +1436,7 @@ class Analyzer( val expr = resolveSubQuery(l, plans)((plan, exprs) => { ListQuery(plan, exprs, exprId, plan.output) }) - val subqueryOutput = expr.plan.output - val resolvedIn = InSubquery(values, expr.asInstanceOf[ListQuery]) - if (values.length != subqueryOutput.length) { - throw new AnalysisException( - s"""Cannot analyze ${resolvedIn.sql}. - |The number of columns in the left hand side of an IN subquery does not match the - |number of columns in the output of subquery. - |#columns in left hand side: ${values.length} - |#columns in right hand side: ${subqueryOutput.length} - |Left side columns: - |[${values.map(_.sql).mkString(", ")}] - |Right side columns: - |[${subqueryOutput.map(_.sql).mkString(", ")}]""".stripMargin) - } - resolvedIn + InSubquery(values, expr.asInstanceOf[ListQuery]) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 149bd79278a54..2125340f38ee8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -144,7 +144,7 @@ case class Not(child: Expression) case class InSubquery(values: Seq[Expression], query: ListQuery) extends Predicate with Unevaluable { - @transient lazy val value: Expression = if (values.length > 1) { + @transient private lazy val value: Expression = if (values.length > 1) { CreateNamedStruct(values.zipWithIndex.flatMap { case (v: NamedExpression, _) => Seq(Literal(v.name), v) case (v, idx) => Seq(Literal(s"_$idx"), v) @@ -155,37 +155,35 @@ case class InSubquery(values: Seq[Expression], query: ListQuery) override def checkInputDataTypes(): TypeCheckResult = { - val mismatchOpt = !DataType.equalsStructurally(query.dataType, value.dataType, - ignoreNullability = true) - if (mismatchOpt) { - if (values.length != query.childOutputs.length) { - TypeCheckResult.TypeCheckFailure( - s""" - |The number of columns in the left hand side of an IN subquery does not match the - |number of columns in the output of subquery. - |#columns in left hand side: ${values.length}. - |#columns in right hand side: ${query.childOutputs.length}. - |Left side columns: - |[${values.map(_.sql).mkString(", ")}]. - |Right side columns: - |[${query.childOutputs.map(_.sql).mkString(", ")}].""".stripMargin) - } else { - val mismatchedColumns = values.zip(query.childOutputs).flatMap { - case (l, r) if l.dataType != r.dataType => - Seq(s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})") - case _ => None - } - TypeCheckResult.TypeCheckFailure( - s""" - |The data type of one or more elements in the left hand side of an IN subquery - |is not compatible with the data type of the output of the subquery - |Mismatched columns: - |[${mismatchedColumns.mkString(", ")}] - |Left side: - |[${values.map(_.dataType.catalogString).mkString(", ")}]. - |Right side: - |[${query.childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin) + if (values.length != query.childOutputs.length) { + TypeCheckResult.TypeCheckFailure( + s""" + |The number of columns in the left hand side of an IN subquery does not match the + |number of columns in the output of subquery. + |#columns in left hand side: ${values.length}. + |#columns in right hand side: ${query.childOutputs.length}. + |Left side columns: + |[${values.map(_.sql).mkString(", ")}]. + |Right side columns: + |[${query.childOutputs.map(_.sql).mkString(", ")}].""".stripMargin) + } else if (!DataType.equalsStructurally( + query.dataType, value.dataType, ignoreNullability = true)) { + + val mismatchedColumns = values.zip(query.childOutputs).flatMap { + case (l, r) if l.dataType != r.dataType => + Seq(s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})") + case _ => None } + TypeCheckResult.TypeCheckFailure( + s""" + |The data type of one or more elements in the left hand side of an IN subquery + |is not compatible with the data type of the output of the subquery + |Mismatched columns: + |[${mismatchedColumns.mkString(", ")}] + |Left side: + |[${values.map(_.dataType.catalogString).mkString(", ")}]. + |Right side: + |[${query.childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin) } else { TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") } diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index 4e1cfa6e48c1c..63aa00426ea32 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -82,9 +82,10 @@ struct 1 2 2 3 + -- !query 9 select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15') --- !query 3 schema +-- !query 9 schema struct --- !query 3 output +-- !query 9 output 5 3 5 NULL 4 diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out index 32d20d1b73415..1b7c6f4f76250 100644 --- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -201,6 +201,7 @@ struct<> -- !query 20 output + -- !query 21 select transform_keys(ys, (k, v) -> k) as v from nested -- !query 21 schema diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out index 088db55d66406..686fe4975379b 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out @@ -41,15 +41,15 @@ select 1 from tab_a where (a1, b1) not in (select (a2, b2) from tab_b) struct<> -- !query 4 output org.apache.spark.sql.AnalysisException -Cannot analyze (named_struct('a1', tab_a.`a1`, 'b1', tab_a.`b1`) IN (listquery())). +cannot resolve '(named_struct('a1', tab_a.`a1`, 'b1', tab_a.`b1`) IN (listquery()))' due to data type mismatch: The number of columns in the left hand side of an IN subquery does not match the number of columns in the output of subquery. -#columns in left hand side: 2 -#columns in right hand side: 1 +#columns in left hand side: 2. +#columns in right hand side: 1. Left side columns: -[tab_a.`a1`, tab_a.`b1`] +[tab_a.`a1`, tab_a.`b1`]. Right side columns: -[`named_struct(a2, a2, b2, b2)`]; +[`named_struct(a2, a2, b2, b2)`].; -- !query 5 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out index c52e5706deeee..dcd30055bca19 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out @@ -92,15 +92,15 @@ t1a IN (SELECT t2a, t2b struct<> -- !query 7 output org.apache.spark.sql.AnalysisException -Cannot analyze (t1.`t1a` IN (listquery(t1.`t1a`))). +cannot resolve '(t1.`t1a` IN (listquery(t1.`t1a`)))' due to data type mismatch: The number of columns in the left hand side of an IN subquery does not match the number of columns in the output of subquery. -#columns in left hand side: 1 -#columns in right hand side: 2 +#columns in left hand side: 1. +#columns in right hand side: 2. Left side columns: -[t1.`t1a`] +[t1.`t1a`]. Right side columns: -[t2.`t2a`, t2.`t2b`]; +[t2.`t2a`, t2.`t2b`].; -- !query 8 @@ -113,15 +113,15 @@ WHERE struct<> -- !query 8 output org.apache.spark.sql.AnalysisException -Cannot analyze (named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`))). +cannot resolve '(named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`)))' due to data type mismatch: The number of columns in the left hand side of an IN subquery does not match the number of columns in the output of subquery. -#columns in left hand side: 2 -#columns in right hand side: 1 +#columns in left hand side: 2. +#columns in right hand side: 1. Left side columns: -[t1.`t1a`, t1.`t1b`] +[t1.`t1a`, t1.`t1b`]. Right side columns: -[t2.`t2a`]; +[t2.`t2a`].; -- !query 9 From 5fd22d05363dd8c0e1b10f3822ccb71eb42f6db9 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 27 Sep 2018 09:26:50 -0700 Subject: [PATCH 1720/2461] [SPARK-25546][CORE] Don't cache value of EVENT_LOG_CALLSITE_LONG_FORM. Caching the value of that config means different instances of SparkEnv will always use whatever was the first value to be read. It also breaks tests that use RDDInfo outside of the scope of a SparkContext. Since this is not a performance sensitive area, there's no advantage in caching the config value. Closes #22558 from vanzin/SPARK-25546. Authored-by: Marcelo Vanzin Signed-off-by: Dongjoon Hyun --- core/src/main/scala/org/apache/spark/storage/RDDInfo.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 19f86569c1e3c..917cfab1c699a 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -55,11 +55,13 @@ class RDDInfo( } private[spark] object RDDInfo { - private val callsiteLongForm = SparkEnv.get.conf.get(EVENT_LOG_CALLSITE_LONG_FORM) - def fromRdd(rdd: RDD[_]): RDDInfo = { val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd)) val parentIds = rdd.dependencies.map(_.rdd.id) + val callsiteLongForm = Option(SparkEnv.get) + .map(_.conf.get(EVENT_LOG_CALLSITE_LONG_FORM)) + .getOrElse(false) + val callSite = if (callsiteLongForm) { rdd.creationSite.longForm } else { From 3b7395fe025a4c9a591835e53ac6ca05be6868f1 Mon Sep 17 00:00:00 2001 From: Chris Zhao Date: Thu, 27 Sep 2018 17:55:08 -0700 Subject: [PATCH 1721/2461] [SPARK-25459][SQL] Add viewOriginalText back to CatalogTable ## What changes were proposed in this pull request? The `show create table` will show a lot of generated attributes for views that created by older Spark version. This PR will basically revert https://issues.apache.org/jira/browse/SPARK-19272 back, so when you `DESC [FORMATTED|EXTENDED] view` will show the original view DDL text. ## How was this patch tested? Unit test. Closes #22458 from zheyuan28/testbranch. Lead-authored-by: Chris Zhao Co-authored-by: Christopher Zhao Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../sql/catalyst/catalog/interface.scala | 4 +++- .../spark/sql/execution/command/views.scala | 2 ++ .../sql-tests/results/describe.sql.out | 2 ++ .../sql/hive/client/HiveClientImpl.scala | 9 +++++--- .../sql/hive/execution/HiveDDLSuite.scala | 22 +++++++++++++++++++ 5 files changed, 35 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 30ded13410f7c..817abebd72ac0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -244,7 +244,8 @@ case class CatalogTable( unsupportedFeatures: Seq[String] = Seq.empty, tracksPartitionsInCatalog: Boolean = false, schemaPreservesCase: Boolean = true, - ignoredProperties: Map[String, String] = Map.empty) { + ignoredProperties: Map[String, String] = Map.empty, + viewOriginalText: Option[String] = None) { import CatalogTable._ @@ -331,6 +332,7 @@ case class CatalogTable( comment.foreach(map.put("Comment", _)) if (tableType == CatalogTableType.VIEW) { viewText.foreach(map.put("View Text", _)) + viewOriginalText.foreach(map.put("View Original Text", _)) viewDefaultDatabase.foreach(map.put("View Default Database", _)) if (viewQueryColumnNames.nonEmpty) { map.put("View Query Output Columns", viewQueryColumnNames.mkString("[", ", ", "]")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 5172f32ec7b9c..cd34dfafd1320 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -242,6 +242,7 @@ case class CreateViewCommand( storage = CatalogStorageFormat.empty, schema = aliasPlan(session, analyzedPlan).schema, properties = newProperties, + viewOriginalText = originalText, viewText = originalText, comment = comment ) @@ -299,6 +300,7 @@ case class AlterViewAsCommand( val updatedViewMeta = viewMeta.copy( schema = analyzedPlan.schema, properties = newProperties, + viewOriginalText = Some(originalText), viewText = Some(originalText)) session.sessionState.catalog.alterTable(updatedViewMeta) diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index 79390cb424444..9c4b70d1b1ab7 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -474,6 +474,7 @@ Last Access [not included in comparison] Created By [not included in comparison] Type VIEW View Text SELECT * FROM t +View Original Text SELECT * FROM t View Default Database default View Query Output Columns [a, b, c, d] Table Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] @@ -497,6 +498,7 @@ Last Access [not included in comparison] Created By [not included in comparison] Type VIEW View Text SELECT * FROM t +View Original Text SELECT * FROM t View Default Database default View Query Output Columns [a, b, c, d] Table Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 02c1ed93eb2f8..5e9b324a168e0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -467,9 +467,12 @@ private[hive] class HiveClientImpl( properties = filteredProperties, stats = readHiveStats(properties), comment = comment, - // In older versions of Spark(before 2.2.0), we expand the view original text and store - // that into `viewExpandedText`, and that should be used in view resolution. So we get - // `viewExpandedText` instead of `viewOriginalText` for viewText here. + // In older versions of Spark(before 2.2.0), we expand the view original text and + // store that into `viewExpandedText`, that should be used in view resolution. + // We get `viewExpandedText` as viewText, and also get `viewOriginalText` in order to + // display the original view text in `DESC [EXTENDED|FORMATTED] table` command for views + // that created by older versions of Spark. + viewOriginalText = Option(h.getViewOriginalText), viewText = Option(h.getViewExpandedText), unsupportedFeatures = unsupportedFeatures, ignoredProperties = ignoredProperties.toMap) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index be1aa83d682b2..fd38944a5dd2e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -2348,4 +2348,26 @@ class HiveDDLSuite } } } + + test("desc formatted table should also show viewOriginalText for views") { + withView("v1", "v2") { + sql("CREATE VIEW v1 AS SELECT 1 AS value") + assert(sql("DESC FORMATTED v1").collect().containsSlice( + Seq( + Row("Type", "VIEW", ""), + Row("View Text", "SELECT 1 AS value", ""), + Row("View Original Text", "SELECT 1 AS value", "") + ) + )) + + hiveClient.runSqlHive("CREATE VIEW v2 AS SELECT * FROM (SELECT 1) T") + assert(sql("DESC FORMATTED v2").collect().containsSlice( + Seq( + Row("Type", "VIEW", ""), + Row("View Text", "SELECT `t`.`_c0` FROM (SELECT 1) `T`", ""), + Row("View Original Text", "SELECT * FROM (SELECT 1) T", "") + ) + )) + } + } } From e120a38c0cdfb569c9151bef4d53e98175da2b25 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Fri, 28 Sep 2018 00:09:06 -0700 Subject: [PATCH 1722/2461] [SPARK-25505][SQL] The output order of grouping columns in Pivot is different from the input order ## What changes were proposed in this pull request? The grouping columns from a Pivot query are inferred as "input columns - pivot columns - pivot aggregate columns", where input columns are the output of the child relation of Pivot. The grouping columns will be the leading columns in the pivot output and they should preserve the same order as specified by the input. For example, ``` SELECT * FROM ( SELECT course, earnings, "a" as a, "z" as z, "b" as b, "y" as y, "c" as c, "x" as x, "d" as d, "w" as w FROM courseSales ) PIVOT ( sum(earnings) FOR course IN ('dotNET', 'Java') ) ``` The output columns should be "a, z, b, y, c, x, d, w, ..." but now it is "a, b, c, d, w, x, y, z, ..." The fix is to use the child plan's `output` instead of `outputSet` so that the order can be preserved. ## How was this patch tested? Added UT. Closes #22519 from maryannxue/spark-25505. Authored-by: maryannxue Signed-off-by: gatorsmile --- .../spark/sql/catalyst/analysis/Analyzer.scala | 7 +++++-- .../test/resources/sql-tests/inputs/pivot.sql | 10 ++++++++++ .../resources/sql-tests/results/pivot.sql.out | 17 ++++++++++++++++- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7034dfdafad33..c0a73083c2d1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -554,8 +554,11 @@ class Analyzer( Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) } // Group-by expressions coming from SQL are implicit and need to be deduced. - val groupByExprs = groupByExprsOpt.getOrElse( - (child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq) + val groupByExprs = groupByExprsOpt.getOrElse { + val pivotColAndAggRefs = + (pivotColumn.references ++ aggregates.flatMap(_.references)).toSet + child.output.filterNot(pivotColAndAggRefs.contains) + } val singleAgg = aggregates.size == 1 def outputName(value: Expression, aggregate: Expression): String = { val stringValue = value match { diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql index 1f607b334dc18..81547ab46ce09 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql @@ -287,3 +287,13 @@ PIVOT ( sum(earnings) FOR (course, m) IN (('dotNET', map('1', 1)), ('Java', map('2', 2))) ); + +-- grouping columns output in the same order as input +SELECT * FROM ( + SELECT course, earnings, "a" as a, "z" as z, "b" as b, "y" as y, "c" as c, "x" as x, "d" as d, "w" as w + FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +); diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out index 2dd92930f92aa..487883a7f3847 100644 --- a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 31 +-- Number of queries: 32 -- !query 0 @@ -476,3 +476,18 @@ struct<> -- !query 30 output org.apache.spark.sql.AnalysisException Invalid pivot column 'named_struct(course, course#x, m, m#x)'. Pivot columns must be comparable.; + + +-- !query 31 +SELECT * FROM ( + SELECT course, earnings, "a" as a, "z" as z, "b" as b, "y" as y, "c" as c, "x" as x, "d" as d, "w" as w + FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +) +-- !query 31 schema +struct +-- !query 31 output +a z b y c x d w 63000 50000 From 0b33f08683a41f6f3a6ec02c327010c0722cc1d1 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 28 Sep 2018 14:10:24 -0700 Subject: [PATCH 1723/2461] [SPARK-23285][DOC][FOLLOWUP] Fix missing markup tag ## What changes were proposed in this pull request? This adds a missing markup tag. This should go to `master/branch-2.4`. ## How was this patch tested? Manual via `SKIP_API=1 jekyll build`. Closes #22585 from dongjoon-hyun/SPARK-23285. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- docs/running-on-kubernetes.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 4ae7acaae2314..840e306fc1040 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -691,6 +691,7 @@ specific to Spark on Kubernetes. Example values include 0.1, 500m, 1.5, 5, etc., with the definition of cpu units documented in [CPU units](https://kubernetes.io/docs/tasks/configure-pod-container/assign-cpu-resource/#cpu-units). This is distinct from spark.executor.cores: it is only used and takes precedence over spark.executor.cores for specifying the executor pod cpu request if set. Task parallelism, e.g., number of tasks an executor can run concurrently is not affected by this. + From b7d80349b0e367d78cab238e62c2ec353f0f12b3 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 28 Sep 2018 14:29:56 -0700 Subject: [PATCH 1724/2461] [SPARK-25542][CORE][TEST] Move flaky test in OpenHashMapSuite to OpenHashSetSuite and make it against OpenHashSet ## What changes were proposed in this pull request? The specified test in OpenHashMapSuite to test large items is somehow flaky to throw OOM. By considering the original work #6763 that added this test, the test can be against OpenHashSetSuite. And by doing this should be to save memory because OpenHashMap allocates two more arrays when growing the map/set. ## How was this patch tested? Existing tests. Closes #22569 from viirya/SPARK-25542. Authored-by: Liang-Chi Hsieh Signed-off-by: Dongjoon Hyun --- .../spark/util/collection/OpenHashMapSuite.scala | 10 ---------- .../spark/util/collection/OpenHashSetSuite.scala | 13 +++++++++++++ 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index 151235dd0fb90..68bcc5e5a5092 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -185,16 +185,6 @@ class OpenHashMapSuite extends SparkFunSuite with Matchers { assert(map.contains(null)) } - test("support for more than 12M items") { - val cnt = 12000000 // 12M - val map = new OpenHashMap[Int, Int](cnt) - for (i <- 0 until cnt) { - map(i) = 1 - } - val numInvalidValues = map.iterator.count(_._2 == 0) - assertResult(0)(numInvalidValues) - } - test("distinguish between the 0/0.0/0L and null") { val specializedMap1 = new OpenHashMap[String, Long] specializedMap1("a") = null.asInstanceOf[Long] diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index b887f937a9da9..44d2118d77945 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -255,4 +255,17 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { val set = new OpenHashSet[Long](0) assert(set.size === 0) } + + test("support for more than 12M items") { + val cnt = 12000000 // 12M + val set = new OpenHashSet[Int](cnt) + for (i <- 0 until cnt) { + set.add(i) + assert(set.contains(i)) + + val pos1 = set.getPos(i) + val pos2 = set.addWithoutResize(i) & OpenHashSet.POSITION_MASK + assert(pos1 == pos2) + } + } } From 7deef7a49b95c5de5af10419ece8c6a36d96ac61 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 28 Sep 2018 15:03:06 -0700 Subject: [PATCH 1725/2461] [SPARK-25458][SQL] Support FOR ALL COLUMNS in ANALYZE TABLE ## What changes were proposed in this pull request? **Description from the JIRA :** Currently, to collect the statistics of all the columns, users need to specify the names of all the columns when calling the command "ANALYZE TABLE ... FOR COLUMNS...". This is not user friendly. Instead, we can introduce the following SQL command to achieve it without specifying the column names. ``` ANALYZE TABLE [db_name.]tablename COMPUTE STATISTICS FOR ALL COLUMNS; ``` ## How was this patch tested? Added new tests in SparkSqlParserSuite and StatisticsSuite Closes #22566 from dilipbiswal/SPARK-25458. Authored-by: Dilip Biswal Signed-off-by: gatorsmile --- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../spark/sql/execution/SparkSqlParser.scala | 26 +++++--- .../command/AnalyzeColumnCommand.scala | 61 ++++++++++++------- .../sql/execution/SparkSqlParserSuite.scala | 14 ++++- .../spark/sql/hive/StatisticsSuite.scala | 47 +++++++++++++- 5 files changed, 115 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 94283f59011a8..16665eb0d7374 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -99,7 +99,7 @@ statement | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier LIKE source=tableIdentifier locationSpec? #createTableLike | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS - (identifier | FOR COLUMNS identifierSeq)? #analyze + (identifier | FOR COLUMNS identifierSeq | FOR ALL COLUMNS)? #analyze | ALTER TABLE tableIdentifier ADD COLUMNS '(' columns=colTypeList ')' #addTableColumns | ALTER (TABLE | VIEW) from=tableIdentifier diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 89cb63784c0f4..4ed14d3e077f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -102,15 +102,29 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * {{{ * ANALYZE TABLE [db_name.]tablename COMPUTE STATISTICS FOR COLUMNS column1, column2; * }}} + * + * Example SQL for analyzing all columns of a table: + * {{{ + * ANALYZE TABLE [db_name.]tablename COMPUTE STATISTICS FOR ALL COLUMNS; + * }}} */ override def visitAnalyze(ctx: AnalyzeContext): LogicalPlan = withOrigin(ctx) { + def checkPartitionSpec(): Unit = { + if (ctx.partitionSpec != null) { + logWarning("Partition specification is ignored when collecting column statistics: " + + ctx.partitionSpec.getText) + } + } if (ctx.identifier != null && ctx.identifier.getText.toLowerCase(Locale.ROOT) != "noscan") { throw new ParseException(s"Expected `NOSCAN` instead of `${ctx.identifier.getText}`", ctx) } val table = visitTableIdentifier(ctx.tableIdentifier) - if (ctx.identifierSeq() == null) { + if (ctx.ALL() != null) { + checkPartitionSpec() + AnalyzeColumnCommand(table, None, allColumns = true) + } else if (ctx.identifierSeq() == null) { if (ctx.partitionSpec != null) { AnalyzePartitionCommand(table, visitPartitionSpec(ctx.partitionSpec), noscan = ctx.identifier != null) @@ -118,13 +132,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { AnalyzeTableCommand(table, noscan = ctx.identifier != null) } } else { - if (ctx.partitionSpec != null) { - logWarning("Partition specification is ignored when collecting column statistics: " + - ctx.partitionSpec.getText) - } - AnalyzeColumnCommand( - table, - visitIdentifierSeq(ctx.identifierSeq())) + checkPartitionSpec() + AnalyzeColumnCommand(table, + Option(visitIdentifierSeq(ctx.identifierSeq())), allColumns = false) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 3fea6d7c7fbfe..93447a52097ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -33,13 +33,17 @@ import org.apache.spark.sql.types._ /** * Analyzes the given columns of the given table to generate statistics, which will be used in - * query optimizations. + * query optimizations. Parameter `allColumns` may be specified to generate statistics of all the + * columns of a given table. */ case class AnalyzeColumnCommand( tableIdent: TableIdentifier, - columnNames: Seq[String]) extends RunnableCommand { + columnNames: Option[Seq[String]], + allColumns: Boolean) extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { + require((columnNames.isDefined ^ allColumns), "Parameter `columnNames` or `allColumns` are " + + "mutually exclusive. Only one of them should be specified.") val sessionState = sparkSession.sessionState val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) @@ -48,9 +52,12 @@ case class AnalyzeColumnCommand( throw new AnalysisException("ANALYZE TABLE is not supported on views.") } val sizeInBytes = CommandUtils.calculateTotalSize(sparkSession, tableMeta) + val relation = sparkSession.table(tableIdent).logicalPlan + val columnsToAnalyze = getColumnsToAnalyze(tableIdent, relation, columnNames, allColumns) - // Compute stats for each column - val (rowCount, newColStats) = computeColumnStats(sparkSession, tableIdentWithDB, columnNames) + // Compute stats for the computed list of columns. + val (rowCount, newColStats) = + computeColumnStats(sparkSession, relation, columnsToAnalyze) // We also update table-level stats in order to keep them consistent with column-level stats. val statistics = CatalogStatistics( @@ -64,31 +71,39 @@ case class AnalyzeColumnCommand( Seq.empty[Row] } - /** - * Compute stats for the given columns. - * @return (row count, map from column name to CatalogColumnStats) - */ - private def computeColumnStats( - sparkSession: SparkSession, + private def getColumnsToAnalyze( tableIdent: TableIdentifier, - columnNames: Seq[String]): (Long, Map[String, CatalogColumnStat]) = { - - val conf = sparkSession.sessionState.conf - val relation = sparkSession.table(tableIdent).logicalPlan - // Resolve the column names and dedup using AttributeSet - val attributesToAnalyze = columnNames.map { col => - val exprOption = relation.output.find(attr => conf.resolver(attr.name, col)) - exprOption.getOrElse(throw new AnalysisException(s"Column $col does not exist.")) + relation: LogicalPlan, + columnNames: Option[Seq[String]], + allColumns: Boolean = false): Seq[Attribute] = { + val columnsToAnalyze = if (allColumns) { + relation.output + } else { + columnNames.get.map { col => + val exprOption = relation.output.find(attr => conf.resolver(attr.name, col)) + exprOption.getOrElse(throw new AnalysisException(s"Column $col does not exist.")) + } } - // Make sure the column types are supported for stats gathering. - attributesToAnalyze.foreach { attr => + columnsToAnalyze.foreach { attr => if (!supportsType(attr.dataType)) { throw new AnalysisException( s"Column ${attr.name} in table $tableIdent is of type ${attr.dataType}, " + "and Spark does not support statistics collection on this column type.") } } + columnsToAnalyze + } + + /** + * Compute stats for the given columns. + * @return (row count, map from column name to CatalogColumnStats) + */ + private def computeColumnStats( + sparkSession: SparkSession, + relation: LogicalPlan, + columns: Seq[Attribute]): (Long, Map[String, CatalogColumnStat]) = { + val conf = sparkSession.sessionState.conf // Collect statistics per column. // If no histogram is required, we run a job to compute basic column stats such as @@ -99,20 +114,20 @@ case class AnalyzeColumnCommand( // 2. use the percentiles as value intervals of bins, e.g. [p(0), p(1/n)], // [p(1/n), p(2/n)], ..., [p((n-1)/n), p(1)], and then count ndv in each bin. // Basic column stats will be computed together in the second job. - val attributePercentiles = computePercentiles(attributesToAnalyze, sparkSession, relation) + val attributePercentiles = computePercentiles(columns, sparkSession, relation) // The first element in the result will be the overall row count, the following elements // will be structs containing all column stats. // The layout of each struct follows the layout of the ColumnStats. val expressions = Count(Literal(1)).toAggregateExpression() +: - attributesToAnalyze.map(statExprs(_, conf, attributePercentiles)) + columns.map(statExprs(_, conf, attributePercentiles)) val namedExpressions = expressions.map(e => Alias(e, e.toString)()) val statsRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExpressions, relation)) .executedPlan.executeTake(1).head val rowCount = statsRow.getLong(0) - val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) => + val columnStats = columns.zipWithIndex.map { case (attr, i) => // according to `statExprs`, the stats struct always have 7 fields. (attr.name, rowToColumnStat(statsRow.getStruct(i + 1, 7), attr, rowCount, attributePercentiles.get(attr)).toCatalogColumnStat(attr.name, attr.dataType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 28a060aff47b5..31b9bcdafbab8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -323,12 +323,22 @@ class SparkSqlParserSuite extends AnalysisTest { intercept("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS", "") assertEqual("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS key, value", - AnalyzeColumnCommand(TableIdentifier("t"), Seq("key", "value"))) + AnalyzeColumnCommand(TableIdentifier("t"), Option(Seq("key", "value")), allColumns = false)) // Partition specified - should be ignored assertEqual("ANALYZE TABLE t PARTITION(ds='2017-06-10') " + "COMPUTE STATISTICS FOR COLUMNS key, value", - AnalyzeColumnCommand(TableIdentifier("t"), Seq("key", "value"))) + AnalyzeColumnCommand(TableIdentifier("t"), Option(Seq("key", "value")), allColumns = false)) + + // Partition specified should be ignored in case of COMPUTE STATISTICS FOR ALL COLUMNS + assertEqual("ANALYZE TABLE t PARTITION(ds='2017-06-10') " + + "COMPUTE STATISTICS FOR ALL COLUMNS", + AnalyzeColumnCommand(TableIdentifier("t"), None, allColumns = true)) + + intercept("ANALYZE TABLE t COMPUTE STATISTICS FOR ALL COLUMNS key, value", + "mismatched input 'key' expecting ") + intercept("ANALYZE TABLE t COMPUTE STATISTICS FOR ALL", + "missing 'COLUMNS' at ''") } test("query organization") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index d8ffb29a59317..57f1c243a70de 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, HiveTableRelation} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, HistogramBin, HistogramSerializer} import org.apache.spark.sql.catalyst.util.{DateTimeUtils, StringUtils} -import org.apache.spark.sql.execution.command.{CommandUtils, DDLUtils} +import org.apache.spark.sql.execution.command.{AnalyzeColumnCommand, CommandUtils, DDLUtils} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.HiveExternalCatalog._ @@ -653,6 +653,51 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } + test("collecting statistics for all columns") { + val table = "update_col_stats_table" + withTable(table) { + sql(s"CREATE TABLE $table (c1 INT, c2 STRING, c3 DOUBLE)") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR ALL COLUMNS") + val fetchedStats0 = + checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetchedStats0.get.colStats == Map( + "c1" -> CatalogColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + "c3" -> CatalogColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + "c2" -> CatalogColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(0), avgLen = Some(20), maxLen = Some(20)))) + + // Insert new data and analyze: have the latest column stats. + sql(s"INSERT INTO TABLE $table SELECT 1, 'a', 10.0") + sql(s"INSERT INTO TABLE $table SELECT 1, 'b', null") + + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR ALL COLUMNS") + val fetchedStats1 = + checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(2)) + assert(fetchedStats1.get.colStats == Map( + "c1" -> CatalogColumnStat(distinctCount = Some(1), min = Some("1"), max = Some("1"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + "c3" -> CatalogColumnStat(distinctCount = Some(1), min = Some("10.0"), max = Some("10.0"), + nullCount = Some(1), avgLen = Some(8), maxLen = Some(8)), + "c2" -> CatalogColumnStat(distinctCount = Some(2), min = None, max = None, + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)))) + } + } + + test("analyze column command paramaters validation") { + val e1 = intercept[IllegalArgumentException] { + AnalyzeColumnCommand(TableIdentifier("test"), Option(Seq("c1")), true).run(spark) + } + assert(e1.getMessage.contains("Parameter `columnNames` or `allColumns` are" + + " mutually exclusive")) + val e2 = intercept[IllegalArgumentException] { + AnalyzeColumnCommand(TableIdentifier("test"), None, false).run(spark) + } + assert(e1.getMessage.contains("Parameter `columnNames` or `allColumns` are" + + " mutually exclusive")) + } + private def createNonPartitionedTable( tabName: String, analyzedBySpark: Boolean = true, From a281465686e8099bb2c0fa4f2ef4822b6e634269 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 28 Sep 2018 15:08:15 -0700 Subject: [PATCH 1726/2461] [SPARK-25429][SQL] Use Set instead of Array to improve lookup performance ## What changes were proposed in this pull request? Use `Set` instead of `Array` to improve `accumulatorIds.contains(acc.id)` performance. This PR close https://github.com/apache/spark/pull/22420 ## How was this patch tested? manual tests. Benchmark code: ```scala def benchmark(func: () => Unit): Long = { val start = System.currentTimeMillis() func() val end = System.currentTimeMillis() end - start } val range = Range(1, 1000000) val set = range.toSet val array = range.toArray for (i <- 0 until 5) { val setExecutionTime = benchmark(() => for (i <- 0 until 500) { set.contains(scala.util.Random.nextInt()) }) val arrayExecutionTime = benchmark(() => for (i <- 0 until 500) { array.contains(scala.util.Random.nextInt()) }) println(s"set execution time: $setExecutionTime, array execution time: $arrayExecutionTime") } ``` Benchmark result: ``` set execution time: 4, array execution time: 2760 set execution time: 1, array execution time: 1911 set execution time: 3, array execution time: 2043 set execution time: 12, array execution time: 2214 set execution time: 6, array execution time: 1770 ``` Closes #22579 from wangyum/SPARK-25429. Authored-by: Yuming Wang Signed-off-by: gatorsmile --- .../spark/sql/execution/ui/SQLAppStatusListener.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index d254af400a7cf..1199eeca959d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -81,9 +81,9 @@ class SQLAppStatusListener( // Record the accumulator IDs for the stages of this job, so that the code that keeps // track of the metrics knows which accumulators to look at. - val accumIds = exec.metrics.map(_.accumulatorId).sorted.toList + val accumIds = exec.metrics.map(_.accumulatorId).toSet event.stageIds.foreach { id => - stageMetrics.put(id, new LiveStageMetrics(id, 0, accumIds.toArray, new ConcurrentHashMap())) + stageMetrics.put(id, new LiveStageMetrics(id, 0, accumIds, new ConcurrentHashMap())) } exec.jobs = exec.jobs + (jobId -> JobExecutionStatus.RUNNING) @@ -382,7 +382,7 @@ private class LiveExecutionData(val executionId: Long) extends LiveEntity { private class LiveStageMetrics( val stageId: Int, var attemptId: Int, - val accumulatorIds: Array[Long], + val accumulatorIds: Set[Long], val taskMetrics: ConcurrentHashMap[Long, LiveTaskMetrics]) private class LiveTaskMetrics( From 9362c5cc273fdd09f9b3b512e2f6b64bcefc25ab Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Fri, 28 Sep 2018 16:34:17 -0700 Subject: [PATCH 1727/2461] [SPARK-25449][CORE] Heartbeat shouldn't include accumulators for zero metrics ## What changes were proposed in this pull request? Heartbeat shouldn't include accumulators for zero metrics. Heartbeats sent from executors to the driver every 10 seconds contain metrics and are generally on the order of a few KBs. However, for large jobs with lots of tasks, heartbeats can be on the order of tens of MBs, causing tasks to die with heartbeat failures. We can mitigate this by not sending zero metrics to the driver. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22473 from mukulmurthy/25449-heartbeat. Authored-by: Mukul Murthy Signed-off-by: Shixiong Zhu --- .../scala/org/apache/spark/SparkConf.scala | 11 +- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../org/apache/spark/executor/Executor.scala | 40 +++++-- .../spark/internal/config/package.scala | 14 +++ .../apache/spark/executor/ExecutorSuite.scala | 111 ++++++++++++++++-- .../MesosCoarseGrainedSchedulerBackend.scala | 3 +- 6 files changed, 154 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index e0f98f1aca071..81aa31d79ba82 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -609,13 +609,14 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria require(!encryptionEnabled || get(NETWORK_AUTH_ENABLED), s"${NETWORK_AUTH_ENABLED.key} must be enabled when enabling encryption.") - val executorTimeoutThreshold = getTimeAsSeconds("spark.network.timeout", "120s") - val executorHeartbeatInterval = getTimeAsSeconds("spark.executor.heartbeatInterval", "10s") + val executorTimeoutThresholdMs = + getTimeAsSeconds("spark.network.timeout", "120s") * 1000 + val executorHeartbeatIntervalMs = get(EXECUTOR_HEARTBEAT_INTERVAL) // If spark.executor.heartbeatInterval bigger than spark.network.timeout, // it will almost always cause ExecutorLostFailure. See SPARK-22754. - require(executorTimeoutThreshold > executorHeartbeatInterval, "The value of " + - s"spark.network.timeout=${executorTimeoutThreshold}s must be no less than the value of " + - s"spark.executor.heartbeatInterval=${executorHeartbeatInterval}s.") + require(executorTimeoutThresholdMs > executorHeartbeatIntervalMs, "The value of " + + s"spark.network.timeout=${executorTimeoutThresholdMs}ms must be no less than the value of " + + s"spark.executor.heartbeatInterval=${executorHeartbeatIntervalMs}ms.") } /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d943087ab6b80..0a66dae94dbd0 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -499,7 +499,7 @@ class SparkContext(config: SparkConf) extends Logging { // create and start the heartbeater for collecting memory metrics _heartbeater = new Heartbeater(env.memoryManager, reportHeartBeat, "driver-heartbeater", - conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s")) + conf.get(EXECUTOR_HEARTBEAT_INTERVAL)) _heartbeater.start() // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index eba708da7798e..61deb543d8747 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -28,6 +28,7 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, Map} +import scala.concurrent.duration._ import scala.util.control.NonFatal import com.google.common.util.concurrent.ThreadFactoryBuilder @@ -120,7 +121,7 @@ private[spark] class Executor( } // Whether to load classes in user jars before those in Spark jars - private val userClassPathFirst = conf.getBoolean("spark.executor.userClassPathFirst", false) + private val userClassPathFirst = conf.get(EXECUTOR_USER_CLASS_PATH_FIRST) // Whether to monitor killed / interrupted tasks private val taskReaperEnabled = conf.getBoolean("spark.task.reaper.enabled", false) @@ -170,21 +171,32 @@ private[spark] class Executor( // Maintains the list of running tasks. private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] + /** + * When an executor is unable to send heartbeats to the driver more than `HEARTBEAT_MAX_FAILURES` + * times, it should kill itself. The default value is 60. It means we will retry to send + * heartbeats about 10 minutes because the heartbeat interval is 10s. + */ + private val HEARTBEAT_MAX_FAILURES = conf.get(EXECUTOR_HEARTBEAT_MAX_FAILURES) + + /** + * Whether to drop empty accumulators from heartbeats sent to the driver. Including the empty + * accumulators (that satisfy isZero) can make the size of the heartbeat message very large. + */ + private val HEARTBEAT_DROP_ZEROES = conf.get(EXECUTOR_HEARTBEAT_DROP_ZERO_ACCUMULATOR_UPDATES) + + /** + * Interval to send heartbeats, in milliseconds + */ + private val HEARTBEAT_INTERVAL_MS = conf.get(EXECUTOR_HEARTBEAT_INTERVAL) + // Executor for the heartbeat task. private val heartbeater = new Heartbeater(env.memoryManager, reportHeartBeat, - "executor-heartbeater", conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s")) + "executor-heartbeater", HEARTBEAT_INTERVAL_MS) // must be initialized before running startDriverHeartbeat() private val heartbeatReceiverRef = RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) - /** - * When an executor is unable to send heartbeats to the driver more than `HEARTBEAT_MAX_FAILURES` - * times, it should kill itself. The default value is 60. It means we will retry to send - * heartbeats about 10 minutes because the heartbeat interval is 10s. - */ - private val HEARTBEAT_MAX_FAILURES = conf.getInt("spark.executor.heartbeat.maxFailures", 60) - /** * Count the failure times of heartbeat. It should only be accessed in the heartbeat thread. Each * successful heartbeat will reset it to 0. @@ -834,7 +846,13 @@ private[spark] class Executor( if (taskRunner.task != null) { taskRunner.task.metrics.mergeShuffleReadMetrics() taskRunner.task.metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime) - accumUpdates += ((taskRunner.taskId, taskRunner.task.metrics.accumulators())) + val accumulatorsToReport = + if (HEARTBEAT_DROP_ZEROES) { + taskRunner.task.metrics.accumulators().filterNot(_.isZero) + } else { + taskRunner.task.metrics.accumulators() + } + accumUpdates += ((taskRunner.taskId, accumulatorsToReport)) } } @@ -842,7 +860,7 @@ private[spark] class Executor( executorUpdates) try { val response = heartbeatReceiverRef.askSync[HeartbeatResponse]( - message, RpcTimeout(conf, "spark.executor.heartbeatInterval", "10s")) + message, new RpcTimeout(HEARTBEAT_INTERVAL_MS.millis, EXECUTOR_HEARTBEAT_INTERVAL.key)) if (response.reregisterBlockManager) { logInfo("Told to re-register on heartbeat") env.blockManager.reregister() diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 7f6342208350a..e8b1d8859cc44 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -83,6 +83,20 @@ package object config { private[spark] val EXECUTOR_CLASS_PATH = ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_CLASSPATH).stringConf.createOptional + private[spark] val EXECUTOR_HEARTBEAT_DROP_ZERO_ACCUMULATOR_UPDATES = + ConfigBuilder("spark.executor.heartbeat.dropZeroAccumulatorUpdates") + .internal() + .booleanConf + .createWithDefault(true) + + private[spark] val EXECUTOR_HEARTBEAT_INTERVAL = + ConfigBuilder("spark.executor.heartbeatInterval") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("10s") + + private[spark] val EXECUTOR_HEARTBEAT_MAX_FAILURES = + ConfigBuilder("spark.executor.heartbeat.maxFailures").internal().intConf.createWithDefault(60) + private[spark] val EXECUTOR_JAVA_OPTIONS = ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_JAVA_OPTIONS).stringConf.createOptional diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 77a7668d3a1d1..1f8a65707b2f7 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -21,9 +21,10 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.lang.Thread.UncaughtExceptionHandler import java.nio.ByteBuffer import java.util.Properties -import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicBoolean +import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.Map import scala.concurrent.duration._ import scala.language.postfixOps @@ -33,22 +34,25 @@ import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.{inOrder, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer +import org.scalatest.PrivateMethodTester import org.scalatest.concurrent.Eventually import org.scalatest.mockito.MockitoSugar import org.apache.spark._ import org.apache.spark.TaskState.TaskState -import org.apache.spark.memory.MemoryManager +import org.apache.spark.internal.config._ +import org.apache.spark.memory.TestMemoryManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.rdd.RDD -import org.apache.spark.rpc.RpcEnv -import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription} +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcTimeout} +import org.apache.spark.scheduler.{FakeTask, ResultTask, Task, TaskDescription} import org.apache.spark.serializer.{JavaSerializer, SerializerManager} import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.UninterruptibleThread +import org.apache.spark.storage.{BlockManager, BlockManagerId} +import org.apache.spark.util.{LongAccumulator, UninterruptibleThread} -class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually { +class ExecutorSuite extends SparkFunSuite + with LocalSparkContext with MockitoSugar with Eventually with PrivateMethodTester { test("SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner") { // mock some objects to make Executor.launchTask() happy @@ -252,18 +256,107 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug } } + test("Heartbeat should drop zero accumulator updates") { + heartbeatZeroAccumulatorUpdateTest(true) + } + + test("Heartbeat should not drop zero accumulator updates when the conf is disabled") { + heartbeatZeroAccumulatorUpdateTest(false) + } + + private def withHeartbeatExecutor(confs: (String, String)*) + (f: (Executor, ArrayBuffer[Heartbeat]) => Unit): Unit = { + val conf = new SparkConf + confs.foreach { case (k, v) => conf.set(k, v) } + val serializer = new JavaSerializer(conf) + val env = createMockEnv(conf, serializer) + val executor = + new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true) + val executorClass = classOf[Executor] + + // Save all heartbeats sent into an ArrayBuffer for verification + val heartbeats = ArrayBuffer[Heartbeat]() + val mockReceiver = mock[RpcEndpointRef] + when(mockReceiver.askSync(any[Heartbeat], any[RpcTimeout])(any)) + .thenAnswer(new Answer[HeartbeatResponse] { + override def answer(invocation: InvocationOnMock): HeartbeatResponse = { + val args = invocation.getArguments() + val mock = invocation.getMock + heartbeats += args(0).asInstanceOf[Heartbeat] + HeartbeatResponse(false) + } + }) + val receiverRef = executorClass.getDeclaredField("heartbeatReceiverRef") + receiverRef.setAccessible(true) + receiverRef.set(executor, mockReceiver) + + f(executor, heartbeats) + } + + private def heartbeatZeroAccumulatorUpdateTest(dropZeroMetrics: Boolean): Unit = { + val c = EXECUTOR_HEARTBEAT_DROP_ZERO_ACCUMULATOR_UPDATES.key -> dropZeroMetrics.toString + withHeartbeatExecutor(c) { (executor, heartbeats) => + val reportHeartbeat = PrivateMethod[Unit]('reportHeartBeat) + + // When no tasks are running, there should be no accumulators sent in heartbeat + executor.invokePrivate(reportHeartbeat()) + // invokeReportHeartbeat(executor) + assert(heartbeats.length == 1) + assert(heartbeats(0).accumUpdates.length == 0, + "No updates should be sent when no tasks are running") + + // When we start a task with a nonzero accumulator, that should end up in the heartbeat + val metrics = new TaskMetrics() + val nonZeroAccumulator = new LongAccumulator() + nonZeroAccumulator.add(1) + metrics.registerAccumulator(nonZeroAccumulator) + + val executorClass = classOf[Executor] + val tasksMap = { + val field = + executorClass.getDeclaredField("org$apache$spark$executor$Executor$$runningTasks") + field.setAccessible(true) + field.get(executor).asInstanceOf[ConcurrentHashMap[Long, executor.TaskRunner]] + } + val mockTaskRunner = mock[executor.TaskRunner] + val mockTask = mock[Task[Any]] + when(mockTask.metrics).thenReturn(metrics) + when(mockTaskRunner.taskId).thenReturn(6) + when(mockTaskRunner.task).thenReturn(mockTask) + when(mockTaskRunner.startGCTime).thenReturn(1) + tasksMap.put(6, mockTaskRunner) + + executor.invokePrivate(reportHeartbeat()) + assert(heartbeats.length == 2) + val updates = heartbeats(1).accumUpdates + assert(updates.length == 1 && updates(0)._1 == 6, + "Heartbeat should only send update for the one task running") + val accumsSent = updates(0)._2.length + assert(accumsSent > 0, "The nonzero accumulator we added should be sent") + if (dropZeroMetrics) { + assert(accumsSent == metrics.accumulators().count(!_.isZero), + "The number of accumulators sent should match the number of nonzero accumulators") + } else { + assert(accumsSent == metrics.accumulators().length, + "The number of accumulators sent should match the number of total accumulators") + } + } + } + private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = { val mockEnv = mock[SparkEnv] val mockRpcEnv = mock[RpcEnv] val mockMetricsSystem = mock[MetricsSystem] - val mockMemoryManager = mock[MemoryManager] + val mockBlockManager = mock[BlockManager] when(mockEnv.conf).thenReturn(conf) when(mockEnv.serializer).thenReturn(serializer) when(mockEnv.serializerManager).thenReturn(mock[SerializerManager]) when(mockEnv.rpcEnv).thenReturn(mockRpcEnv) when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem) - when(mockEnv.memoryManager).thenReturn(mockMemoryManager) + when(mockEnv.memoryManager).thenReturn(new TestMemoryManager(conf)) when(mockEnv.closureSerializer).thenReturn(serializer) + when(mockBlockManager.blockManagerId).thenReturn(BlockManagerId("1", "hostA", 1234)) + when(mockEnv.blockManager).thenReturn(mockBlockManager) SparkEnv.set(mockEnv) mockEnv } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 178de30f0f381..bac0246b7ddc5 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -33,6 +33,7 @@ import org.apache.mesos.SchedulerDriver import org.apache.spark.{SecurityManager, SparkConf, SparkContext, SparkException, TaskState} import org.apache.spark.deploy.mesos.config._ import org.apache.spark.internal.config +import org.apache.spark.internal.config.EXECUTOR_HEARTBEAT_INTERVAL import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient @@ -635,7 +636,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( externalShufflePort, sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", s"${sc.conf.getTimeAsSeconds("spark.network.timeout", "120s")}s"), - sc.conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s")) + sc.conf.get(EXECUTOR_HEARTBEAT_INTERVAL)) slave.shuffleRegistered = true } From 5d726b865948f993911fd5b9730b25cfa94e16c7 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Fri, 28 Sep 2018 17:46:11 -0700 Subject: [PATCH 1728/2461] [SPARK-25559][SQL] Remove the unsupported predicates in Parquet when possible ## What changes were proposed in this pull request? Currently, in `ParquetFilters`, if one of the children predicates is not supported by Parquet, the entire predicates will be thrown away. In fact, if the unsupported predicate is in the top level `And` condition or in the child before hitting `Not` or `Or` condition, it can be safely removed. ## How was this patch tested? Tests are added. Closes #22574 from dbtsai/removeUnsupportedPredicatesInParquet. Lead-authored-by: DB Tsai Co-authored-by: Dongjoon Hyun Co-authored-by: DB Tsai Signed-off-by: Dongjoon Hyun --- .../datasources/parquet/ParquetFilters.scala | 38 +++-- .../parquet/ParquetFilterSuite.scala | 147 +++++++++++++++++- 2 files changed, 172 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 0c286defb9406..44a0d209e6e69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -394,7 +394,13 @@ private[parquet] class ParquetFilters( */ def createFilter(schema: MessageType, predicate: sources.Filter): Option[FilterPredicate] = { val nameToParquetField = getFieldMap(schema) + createFilterHelper(nameToParquetField, predicate, canRemoveOneSideInAnd = true) + } + private def createFilterHelper( + nameToParquetField: Map[String, ParquetField], + predicate: sources.Filter, + canRemoveOneSideInAnd: Boolean): Option[FilterPredicate] = { // Decimal type must make sure that filter value's scale matched the file. // If doesn't matched, which would cause data corruption. def isDecimalMatched(value: Any, decimalMeta: DecimalMetadata): Boolean = value match { @@ -488,26 +494,36 @@ private[parquet] class ParquetFilters( .map(_(nameToParquetField(name).fieldName, value)) case sources.And(lhs, rhs) => - // At here, it is not safe to just convert one side if we do not understand the - // other side. Here is an example used to explain the reason. + // At here, it is not safe to just convert one side and remove the other side + // if we do not understand what the parent filters are. + // + // Here is an example used to explain the reason. // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to // convert b in ('1'). If we only convert a = 2, we will end up with a filter // NOT(a = 2), which will generate wrong results. - // Pushing one side of AND down is only safe to do at the top level. - // You can see ParquetRelation's initializeLocalJobFunc method as an example. - for { - lhsFilter <- createFilter(schema, lhs) - rhsFilter <- createFilter(schema, rhs) - } yield FilterApi.and(lhsFilter, rhsFilter) + // + // Pushing one side of AND down is only safe to do at the top level or in the child + // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate + // can be safely removed. + val lhsFilterOption = createFilterHelper(nameToParquetField, lhs, canRemoveOneSideInAnd) + val rhsFilterOption = createFilterHelper(nameToParquetField, rhs, canRemoveOneSideInAnd) + + (lhsFilterOption, rhsFilterOption) match { + case (Some(lhsFilter), Some(rhsFilter)) => Some(FilterApi.and(lhsFilter, rhsFilter)) + case (Some(lhsFilter), None) if canRemoveOneSideInAnd => Some(lhsFilter) + case (None, Some(rhsFilter)) if canRemoveOneSideInAnd => Some(rhsFilter) + case _ => None + } case sources.Or(lhs, rhs) => for { - lhsFilter <- createFilter(schema, lhs) - rhsFilter <- createFilter(schema, rhs) + lhsFilter <- createFilterHelper(nameToParquetField, lhs, canRemoveOneSideInAnd = false) + rhsFilter <- createFilterHelper(nameToParquetField, rhs, canRemoveOneSideInAnd = false) } yield FilterApi.or(lhsFilter, rhsFilter) case sources.Not(pred) => - createFilter(schema, pred).map(FilterApi.not) + createFilterHelper(nameToParquetField, pred, canRemoveOneSideInAnd = false) + .map(FilterApi.not) case sources.In(name, values) if canMakeFilterOn(name, values.head) && values.distinct.length <= pushDownInFilterThreshold => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 7ebb75009555a..01e41b3c5df36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -750,7 +750,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } - test("SPARK-12218 Converting conjunctions into Parquet filter predicates") { + test("SPARK-12218 and SPARK-25559 Converting conjunctions into Parquet filter predicates") { val schema = StructType(Seq( StructField("a", IntegerType, nullable = false), StructField("b", StringType, nullable = true), @@ -770,7 +770,11 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex sources.GreaterThan("c", 1.5D))) } - assertResult(None) { + // Testing when `canRemoveOneSideInAnd == true` + // case sources.And(lhs, rhs) => + // ... + // case (Some(lhsFilter), None) if canRemoveOneSideInAnd => Some(lhsFilter) + assertResult(Some(lt(intColumn("a"), 10: Integer))) { parquetFilters.createFilter( parquetSchema, sources.And( @@ -778,6 +782,83 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex sources.StringContains("b", "prefix"))) } + // Testing when `canRemoveOneSideInAnd == true` + // case sources.And(lhs, rhs) => + // ... + // case (None, Some(rhsFilter)) if canRemoveOneSideInAnd => Some(rhsFilter) + assertResult(Some(lt(intColumn("a"), 10: Integer))) { + parquetFilters.createFilter( + parquetSchema, + sources.And( + sources.StringContains("b", "prefix"), + sources.LessThan("a", 10))) + } + + // Testing complex And conditions + assertResult(Some( + FilterApi.and(lt(intColumn("a"), 10: Integer), gt(intColumn("a"), 5: Integer)))) { + parquetFilters.createFilter( + parquetSchema, + sources.And( + sources.And( + sources.LessThan("a", 10), + sources.StringContains("b", "prefix") + ), + sources.GreaterThan("a", 5))) + } + + // Testing complex And conditions + assertResult(Some( + FilterApi.and(gt(intColumn("a"), 5: Integer), lt(intColumn("a"), 10: Integer)))) { + parquetFilters.createFilter( + parquetSchema, + sources.And( + sources.GreaterThan("a", 5), + sources.And( + sources.StringContains("b", "prefix"), + sources.LessThan("a", 10) + ))) + } + + // Testing + // case sources.Or(lhs, rhs) => + // ... + // lhsFilter <- createFilterHelper(nameToParquetField, lhs, canRemoveOneSideInAnd = false) + assertResult(None) { + parquetFilters.createFilter( + parquetSchema, + sources.Or( + sources.And( + sources.GreaterThan("a", 1), + sources.StringContains("b", "prefix")), + sources.GreaterThan("a", 2))) + } + + // Testing + // case sources.Or(lhs, rhs) => + // ... + // rhsFilter <- createFilterHelper(nameToParquetField, rhs, canRemoveOneSideInAnd = false) + assertResult(None) { + parquetFilters.createFilter( + parquetSchema, + sources.Or( + sources.GreaterThan("a", 2), + sources.And( + sources.GreaterThan("a", 1), + sources.StringContains("b", "prefix")))) + } + + // Testing + // case sources.Not(pred) => + // createFilterHelper(nameToParquetField, pred, canRemoveOneSideInAnd = false) + // .map(FilterApi.not) + // + // and + // + // Testing when `canRemoveOneSideInAnd == false` + // case sources.And(lhs, rhs) => + // ... + // case (Some(lhsFilter), None) if canRemoveOneSideInAnd => Some(lhsFilter) assertResult(None) { parquetFilters.createFilter( parquetSchema, @@ -786,6 +867,68 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex sources.GreaterThan("a", 1), sources.StringContains("b", "prefix")))) } + + // Testing + // case sources.Not(pred) => + // createFilterHelper(nameToParquetField, pred, canRemoveOneSideInAnd = false) + // .map(FilterApi.not) + // + // and + // + // Testing when `canRemoveOneSideInAnd == false` + // case sources.And(lhs, rhs) => + // ... + // case (None, Some(rhsFilter)) if canRemoveOneSideInAnd => Some(rhsFilter) + assertResult(None) { + parquetFilters.createFilter( + parquetSchema, + sources.Not( + sources.And( + sources.StringContains("b", "prefix"), + sources.GreaterThan("a", 1)))) + } + + // Testing + // case sources.Not(pred) => + // createFilterHelper(nameToParquetField, pred, canRemoveOneSideInAnd = false) + // .map(FilterApi.not) + // + // and + // + // Testing passing `canRemoveOneSideInAnd = false` into + // case sources.And(lhs, rhs) => + // val lhsFilterOption = createFilterHelper(nameToParquetField, lhs, canRemoveOneSideInAnd) + assertResult(None) { + parquetFilters.createFilter( + parquetSchema, + sources.Not( + sources.And( + sources.And( + sources.GreaterThan("a", 1), + sources.StringContains("b", "prefix")), + sources.GreaterThan("a", 2)))) + } + + // Testing + // case sources.Not(pred) => + // createFilterHelper(nameToParquetField, pred, canRemoveOneSideInAnd = false) + // .map(FilterApi.not) + // + // and + // + // Testing passing `canRemoveOneSideInAnd = false` into + // case sources.And(lhs, rhs) => + // val rhsFilterOption = createFilterHelper(nameToParquetField, rhs, canRemoveOneSideInAnd) + assertResult(None) { + parquetFilters.createFilter( + parquetSchema, + sources.Not( + sources.And( + sources.GreaterThan("a", 2), + sources.And( + sources.GreaterThan("a", 1), + sources.StringContains("b", "prefix"))))) + } } test("SPARK-16371 Do not push down filters when inner name and outer name are the same") { From e99ba8d7c8ec4b4cdd63fd1621f54be993bb0404 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 29 Sep 2018 11:23:37 +0800 Subject: [PATCH 1729/2461] [SPARK-25262][DOC][FOLLOWUP] Fix missing markup tag ## What changes were proposed in this pull request? This adds a missing end markup tag. This should go `master` branch only. ## How was this patch tested? This is a doc-only change. Manual via `SKIP_API=1 jekyll build`. Closes #22584 from dongjoon-hyun/SPARK-25262. Authored-by: Dongjoon Hyun Signed-off-by: hyukjinkwon --- docs/running-on-kubernetes.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 840e306fc1040..c7aea2709605d 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -800,7 +800,7 @@ specific to Spark on Kubernetes. - @@ -697,7 +697,7 @@ specific to Spark on Kubernetes. From 623c2ec4ef3776bc5e2cac2c66300ddc6264db54 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 29 Sep 2018 21:50:35 +0800 Subject: [PATCH 1733/2461] [SPARK-25048][SQL] Pivoting by multiple columns in Scala/Java ## What changes were proposed in this pull request? In the PR, I propose to extend implementation of existing method: ``` def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset ``` to support values of the struct type. This allows pivoting by multiple columns combined by `struct`: ``` trainingSales .groupBy($"sales.year") .pivot( pivotColumn = struct(lower($"sales.course"), $"training"), values = Seq( struct(lit("dotnet"), lit("Experts")), struct(lit("java"), lit("Dummies"))) ).agg(sum($"sales.earnings")) ``` ## How was this patch tested? Added a test for values specified via `struct` in Java and Scala. Closes #22316 from MaxGekk/pivoting-by-multiple-columns2. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- .../spark/sql/RelationalGroupedDataset.scala | 17 ++++++++++++-- .../apache/spark/sql/JavaDataFrameSuite.java | 16 +++++++++++++ .../spark/sql/DataFramePivotSuite.scala | 23 +++++++++++++++++++ 3 files changed, 54 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index d700fb83b9b70..dbacdbff7383a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -330,6 +330,15 @@ class RelationalGroupedDataset protected[sql]( * df.groupBy("year").pivot("course").sum("earnings") * }}} * + * From Spark 2.5.0, values can be literal columns, for instance, struct. For pivoting by + * multiple columns, use the `struct` function to combine the columns and values: + * + * {{{ + * df.groupBy("year") + * .pivot("trainingCourse", Seq(struct(lit("java"), lit("Experts")))) + * .agg(sum($"earnings")) + * }}} + * * @param pivotColumn Name of the column to pivot. * @param values List of values that will be translated to columns in the output DataFrame. * @since 1.6.0 @@ -413,10 +422,14 @@ class RelationalGroupedDataset protected[sql]( def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = { groupType match { case RelationalGroupedDataset.GroupByType => + val valueExprs = values.map(_ match { + case c: Column => c.expr + case v => Literal.apply(v) + }) new RelationalGroupedDataset( df, groupingExprs, - RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply))) + RelationalGroupedDataset.PivotType(pivotColumn.expr, valueExprs)) case _: RelationalGroupedDataset.PivotType => throw new UnsupportedOperationException("repeated pivots are not supported") case _ => @@ -561,5 +574,5 @@ private[sql] object RelationalGroupedDataset { /** * To indicate it's the PIVOT */ - private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType + private[sql] case class PivotType(pivotCol: Expression, values: Seq[Expression]) extends GroupType } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 3f37e5814ccaa..00f41d6484afb 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -317,6 +317,22 @@ public void pivot() { Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01); } + @Test + public void pivotColumnValues() { + Dataset df = spark.table("courseSales"); + List actual = df.groupBy("year") + .pivot(col("course"), Arrays.asList(lit("dotNET"), lit("Java"))) + .agg(sum("earnings")).orderBy("year").collectAsList(); + + Assert.assertEquals(2012, actual.get(0).getInt(0)); + Assert.assertEquals(15000.0, actual.get(0).getDouble(1), 0.01); + Assert.assertEquals(20000.0, actual.get(0).getDouble(2), 0.01); + + Assert.assertEquals(2013, actual.get(1).getInt(0)); + Assert.assertEquals(48000.0, actual.get(1).getDouble(1), 0.01); + Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01); + } + private String getResource(String resource) { try { // The following "getResource" has different behaviors in SBT and Maven. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index b972b9ef93e5e..02ab19754b0c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -308,4 +308,27 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { assert(exception.getMessage.contains("aggregate functions are not allowed")) } + + test("pivoting column list with values") { + val expected = Row(2012, 10000.0, null) :: Row(2013, 48000.0, 30000.0) :: Nil + val df = trainingSales + .groupBy($"sales.year") + .pivot(struct(lower($"sales.course"), $"training"), Seq( + struct(lit("dotnet"), lit("Experts")), + struct(lit("java"), lit("Dummies"))) + ).agg(sum($"sales.earnings")) + + checkAnswer(df, expected) + } + + test("pivoting column list") { + val exception = intercept[RuntimeException] { + trainingSales + .groupBy($"sales.year") + .pivot(struct(lower($"sales.course"), $"training")) + .agg(sum($"sales.earnings")) + .collect() + } + assert(exception.getMessage.contains("Unsupported literal type")) + } } From f246813afba16fee4d703f09e6302011b11806f3 Mon Sep 17 00:00:00 2001 From: yucai Date: Sat, 29 Sep 2018 09:48:03 -0700 Subject: [PATCH 1734/2461] [SPARK-25508][SQL][TEST] Refactor OrcReadBenchmark to use main method ## What changes were proposed in this pull request? Refactor OrcReadBenchmark to use main method. Generate benchmark result: ``` SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "hive/test:runMain org.apache.spark.sql.hive.orc.OrcReadBenchmark" ``` ## How was this patch tested? manual tests Closes #22580 from yucai/SPARK-25508. Lead-authored-by: yucai Co-authored-by: Yucai Yu Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../benchmarks/OrcReadBenchmark-results.txt | 173 ++++++++++++++++ .../spark/sql/hive/orc/OrcReadBenchmark.scala | 196 ++++-------------- 2 files changed, 212 insertions(+), 157 deletions(-) create mode 100644 sql/hive/benchmarks/OrcReadBenchmark-results.txt diff --git a/sql/hive/benchmarks/OrcReadBenchmark-results.txt b/sql/hive/benchmarks/OrcReadBenchmark-results.txt new file mode 100644 index 0000000000000..c77f966723d71 --- /dev/null +++ b/sql/hive/benchmarks/OrcReadBenchmark-results.txt @@ -0,0 +1,173 @@ +================================================================================================ +SQL Single Numeric Column Scan +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Native ORC MR 1630 / 1639 9.7 103.6 1.0X +Native ORC Vectorized 253 / 288 62.2 16.1 6.4X +Native ORC Vectorized with copy 227 / 244 69.2 14.5 7.2X +Hive built-in ORC 1980 / 1991 7.9 125.9 0.8X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Native ORC MR 1587 / 1589 9.9 100.9 1.0X +Native ORC Vectorized 227 / 242 69.2 14.5 7.0X +Native ORC Vectorized with copy 228 / 238 69.0 14.5 7.0X +Hive built-in ORC 2323 / 2332 6.8 147.7 0.7X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Native ORC MR 1726 / 1771 9.1 109.7 1.0X +Native ORC Vectorized 309 / 333 50.9 19.7 5.6X +Native ORC Vectorized with copy 313 / 321 50.2 19.9 5.5X +Hive built-in ORC 2668 / 2672 5.9 169.6 0.6X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Native ORC MR 1722 / 1747 9.1 109.5 1.0X +Native ORC Vectorized 395 / 403 39.8 25.1 4.4X +Native ORC Vectorized with copy 399 / 405 39.4 25.4 4.3X +Hive built-in ORC 2767 / 2777 5.7 175.9 0.6X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Native ORC MR 1797 / 1824 8.8 114.2 1.0X +Native ORC Vectorized 434 / 441 36.2 27.6 4.1X +Native ORC Vectorized with copy 437 / 447 36.0 27.8 4.1X +Hive built-in ORC 2701 / 2710 5.8 171.7 0.7X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Native ORC MR 1931 / 2028 8.1 122.8 1.0X +Native ORC Vectorized 542 / 557 29.0 34.5 3.6X +Native ORC Vectorized with copy 550 / 564 28.6 35.0 3.5X +Hive built-in ORC 2816 / 3206 5.6 179.1 0.7X + + +================================================================================================ +Int and String Scan +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Native ORC MR 4012 / 4068 2.6 382.6 1.0X +Native ORC Vectorized 2337 / 2339 4.5 222.9 1.7X +Native ORC Vectorized with copy 2520 / 2540 4.2 240.3 1.6X +Hive built-in ORC 5503 / 5575 1.9 524.8 0.7X + + +================================================================================================ +Partitioned Table Scan +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Data column - Native ORC MR 2020 / 2025 7.8 128.4 1.0X +Data column - Native ORC Vectorized 398 / 409 39.5 25.3 5.1X +Data column - Native ORC Vectorized with copy 406 / 411 38.8 25.8 5.0X +Data column - Hive built-in ORC 2967 / 2969 5.3 188.6 0.7X +Partition column - Native ORC MR 1494 / 1505 10.5 95.0 1.4X +Partition column - Native ORC Vectorized 73 / 82 216.3 4.6 27.8X +Partition column - Native ORC Vectorized with copy 71 / 80 221.4 4.5 28.4X +Partition column - Hive built-in ORC 1932 / 1937 8.1 122.8 1.0X +Both columns - Native ORC MR 2057 / 2071 7.6 130.8 1.0X +Both columns - Native ORC Vectorized 445 / 448 35.4 28.3 4.5X +Both column - Native ORC Vectorized with copy 534 / 539 29.4 34.0 3.8X +Both columns - Hive built-in ORC 2994 / 2994 5.3 190.3 0.7X + + +================================================================================================ +Repeated String Scan +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Native ORC MR 1771 / 1785 5.9 168.9 1.0X +Native ORC Vectorized 372 / 375 28.2 35.5 4.8X +Native ORC Vectorized with copy 543 / 576 19.3 51.8 3.3X +Hive built-in ORC 2671 / 2671 3.9 254.7 0.7X + + +================================================================================================ +String with Nulls Scan +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Native ORC MR 3276 / 3302 3.2 312.5 1.0X +Native ORC Vectorized 1057 / 1080 9.9 100.8 3.1X +Native ORC Vectorized with copy 1420 / 1431 7.4 135.4 2.3X +Hive built-in ORC 5377 / 5407 2.0 512.8 0.6X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +String with Nulls Scan (0.5%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Native ORC MR 3147 / 3147 3.3 300.1 1.0X +Native ORC Vectorized 1305 / 1319 8.0 124.4 2.4X +Native ORC Vectorized with copy 1685 / 1686 6.2 160.7 1.9X +Hive built-in ORC 4077 / 4085 2.6 388.8 0.8X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +String with Nulls Scan (0.95%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Native ORC MR 1739 / 1744 6.0 165.8 1.0X +Native ORC Vectorized 500 / 501 21.0 47.7 3.5X +Native ORC Vectorized with copy 618 / 631 17.0 58.9 2.8X +Hive built-in ORC 2411 / 2427 4.3 229.9 0.7X + + +================================================================================================ +Single Column Scan From Wide Columns +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Native ORC MR 1348 / 1366 0.8 1285.3 1.0X +Native ORC Vectorized 119 / 134 8.8 113.5 11.3X +Native ORC Vectorized with copy 119 / 148 8.8 113.9 11.3X +Hive built-in ORC 487 / 507 2.2 464.8 2.8X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Native ORC MR 2667 / 2837 0.4 2543.6 1.0X +Native ORC Vectorized 203 / 222 5.2 193.4 13.2X +Native ORC Vectorized with copy 217 / 255 4.8 207.0 12.3X +Hive built-in ORC 737 / 741 1.4 702.4 3.6X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Native ORC MR 3954 / 3956 0.3 3770.4 1.0X +Native ORC Vectorized 348 / 360 3.0 331.7 11.4X +Native ORC Vectorized with copy 349 / 359 3.0 333.2 11.3X +Hive built-in ORC 1057 / 1067 1.0 1008.0 3.7X + + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala index 49de007df3828..0bb5e8c141595 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala @@ -22,20 +22,26 @@ import java.io.File import scala.util.Random import org.apache.spark.SparkConf -import org.apache.spark.benchmark.Benchmark +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ - /** * Benchmark to measure ORC read performance. + * {{{ + * To run this benchmark: + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/OrcReadBenchmark-results.txt". + * }}} * * This is in `sql/hive` module in order to compare `sql/core` and `sql/hive` ORC data sources. */ // scalastyle:off line.size.limit -object OrcReadBenchmark extends SQLHelper { +object OrcReadBenchmark extends BenchmarkBase with SQLHelper { val conf = new SparkConf() conf.set("orc.compression", "snappy") @@ -69,7 +75,7 @@ object OrcReadBenchmark extends SQLHelper { } def numericScanBenchmark(values: Int, dataType: DataType): Unit = { - val benchmark = new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values) + val benchmark = new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values, output = output) withTempPath { dir => withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { @@ -98,59 +104,13 @@ object OrcReadBenchmark extends SQLHelper { spark.sql("SELECT sum(id) FROM hiveOrcTable").collect() } - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Native ORC MR 1135 / 1171 13.9 72.2 1.0X - Native ORC Vectorized 152 / 163 103.4 9.7 7.5X - Native ORC Vectorized with copy 149 / 162 105.4 9.5 7.6X - Hive built-in ORC 1380 / 1384 11.4 87.7 0.8X - - SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Native ORC MR 1182 / 1244 13.3 75.2 1.0X - Native ORC Vectorized 145 / 156 108.7 9.2 8.2X - Native ORC Vectorized with copy 148 / 158 106.4 9.4 8.0X - Hive built-in ORC 1591 / 1636 9.9 101.2 0.7X - - SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Native ORC MR 1271 / 1271 12.4 80.8 1.0X - Native ORC Vectorized 206 / 212 76.3 13.1 6.2X - Native ORC Vectorized with copy 200 / 213 78.8 12.7 6.4X - Hive built-in ORC 1776 / 1787 8.9 112.9 0.7X - - SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Native ORC MR 1344 / 1355 11.7 85.4 1.0X - Native ORC Vectorized 258 / 268 61.0 16.4 5.2X - Native ORC Vectorized with copy 252 / 257 62.4 16.0 5.3X - Hive built-in ORC 1818 / 1823 8.7 115.6 0.7X - - SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Native ORC MR 1333 / 1352 11.8 84.8 1.0X - Native ORC Vectorized 310 / 324 50.7 19.7 4.3X - Native ORC Vectorized with copy 312 / 320 50.4 19.9 4.3X - Hive built-in ORC 1904 / 1918 8.3 121.0 0.7X - - SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Native ORC MR 1408 / 1585 11.2 89.5 1.0X - Native ORC Vectorized 359 / 368 43.8 22.8 3.9X - Native ORC Vectorized with copy 364 / 371 43.2 23.2 3.9X - Hive built-in ORC 1881 / 1954 8.4 119.6 0.7X - */ benchmark.run() } } } def intStringScanBenchmark(values: Int): Unit = { - val benchmark = new Benchmark("Int and String Scan", values) + val benchmark = new Benchmark("Int and String Scan", values, output = output) withTempPath { dir => withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { @@ -181,24 +141,13 @@ object OrcReadBenchmark extends SQLHelper { spark.sql("SELECT sum(c1), sum(length(c2)) FROM hiveOrcTable").collect() } - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Native ORC MR 2566 / 2592 4.1 244.7 1.0X - Native ORC Vectorized 1098 / 1113 9.6 104.7 2.3X - Native ORC Vectorized with copy 1527 / 1593 6.9 145.6 1.7X - Hive built-in ORC 3561 / 3705 2.9 339.6 0.7X - */ benchmark.run() } } } def partitionTableScanBenchmark(values: Int): Unit = { - val benchmark = new Benchmark("Partitioned Table", values) + val benchmark = new Benchmark("Partitioned Table", values, output = output) withTempPath { dir => withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { @@ -267,32 +216,13 @@ object OrcReadBenchmark extends SQLHelper { spark.sql("SELECT sum(p), sum(id) FROM hiveOrcTable").collect() } - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Data only - Native ORC MR 1447 / 1457 10.9 92.0 1.0X - Data only - Native ORC Vectorized 256 / 266 61.4 16.3 5.6X - Data only - Native ORC Vectorized with copy 263 / 273 59.8 16.7 5.5X - Data only - Hive built-in ORC 1960 / 1988 8.0 124.6 0.7X - Partition only - Native ORC MR 1039 / 1043 15.1 66.0 1.4X - Partition only - Native ORC Vectorized 48 / 53 326.6 3.1 30.1X - Partition only - Native ORC Vectorized with copy 48 / 53 328.4 3.0 30.2X - Partition only - Hive built-in ORC 1234 / 1242 12.7 78.4 1.2X - Both columns - Native ORC MR 1465 / 1475 10.7 93.1 1.0X - Both columns - Native ORC Vectorized 292 / 301 53.9 18.6 5.0X - Both column - Native ORC Vectorized with copy 348 / 354 45.1 22.2 4.2X - Both columns - Hive built-in ORC 2051 / 2060 7.7 130.4 0.7X - */ benchmark.run() } } } def repeatedStringScanBenchmark(values: Int): Unit = { - val benchmark = new Benchmark("Repeated String", values) + val benchmark = new Benchmark("Repeated String", values, output = output) withTempPath { dir => withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { @@ -320,17 +250,6 @@ object OrcReadBenchmark extends SQLHelper { spark.sql("SELECT sum(length(c1)) FROM hiveOrcTable").collect() } - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Native ORC MR 1271 / 1278 8.3 121.2 1.0X - Native ORC Vectorized 200 / 212 52.4 19.1 6.4X - Native ORC Vectorized with copy 342 / 347 30.7 32.6 3.7X - Hive built-in ORC 1874 / 2105 5.6 178.7 0.7X - */ benchmark.run() } } @@ -347,7 +266,8 @@ object OrcReadBenchmark extends SQLHelper { s"SELECT IF(RAND(1) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c1, " + s"IF(RAND(2) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c2 FROM t1")) - val benchmark = new Benchmark(s"String with Nulls Scan ($fractionOfNulls%)", values) + val benchmark = + new Benchmark(s"String with Nulls Scan ($fractionOfNulls%)", values, output = output) benchmark.addCase("Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { @@ -373,38 +293,13 @@ object OrcReadBenchmark extends SQLHelper { "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() } - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Native ORC MR 2394 / 2886 4.4 228.3 1.0X - Native ORC Vectorized 699 / 729 15.0 66.7 3.4X - Native ORC Vectorized with copy 959 / 1025 10.9 91.5 2.5X - Hive built-in ORC 3899 / 3901 2.7 371.9 0.6X - - String with Nulls Scan (0.5%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Native ORC MR 2234 / 2255 4.7 213.1 1.0X - Native ORC Vectorized 854 / 869 12.3 81.4 2.6X - Native ORC Vectorized with copy 1099 / 1128 9.5 104.8 2.0X - Hive built-in ORC 2767 / 2793 3.8 263.9 0.8X - - String with Nulls Scan (0.95%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Native ORC MR 1166 / 1202 9.0 111.2 1.0X - Native ORC Vectorized 338 / 345 31.1 32.2 3.5X - Native ORC Vectorized with copy 418 / 428 25.1 39.9 2.8X - Hive built-in ORC 1730 / 1761 6.1 164.9 0.7X - */ benchmark.run() } } } def columnsBenchmark(values: Int, width: Int): Unit = { - val benchmark = new Benchmark(s"Single Column Scan from $width columns", values) + val benchmark = new Benchmark(s"Single Column Scan from $width columns", values, output = output) withTempPath { dir => withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { @@ -436,49 +331,36 @@ object OrcReadBenchmark extends SQLHelper { spark.sql(s"SELECT sum(c$middle) FROM hiveOrcTable").collect() } - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Native ORC MR 1050 / 1053 1.0 1001.1 1.0X - Native ORC Vectorized 95 / 101 11.0 90.9 11.0X - Native ORC Vectorized with copy 95 / 102 11.0 90.9 11.0X - Hive built-in ORC 348 / 358 3.0 331.8 3.0X - - Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Native ORC MR 2099 / 2108 0.5 2002.1 1.0X - Native ORC Vectorized 179 / 187 5.8 171.1 11.7X - Native ORC Vectorized with copy 176 / 188 6.0 167.6 11.9X - Hive built-in ORC 562 / 581 1.9 535.9 3.7X - - Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Native ORC MR 3221 / 3246 0.3 3071.4 1.0X - Native ORC Vectorized 312 / 322 3.4 298.0 10.3X - Native ORC Vectorized with copy 306 / 320 3.4 291.6 10.5X - Hive built-in ORC 815 / 824 1.3 777.3 4.0X - */ benchmark.run() } } } - def main(args: Array[String]): Unit = { - Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { dataType => - numericScanBenchmark(1024 * 1024 * 15, dataType) + override def benchmark(): Unit = { + runBenchmark("SQL Single Numeric Column Scan") { + Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { dataType => + numericScanBenchmark(1024 * 1024 * 15, dataType) + } + } + runBenchmark("Int and String Scan") { + intStringScanBenchmark(1024 * 1024 * 10) + } + runBenchmark("Partitioned Table Scan") { + partitionTableScanBenchmark(1024 * 1024 * 15) + } + runBenchmark("Repeated String Scan") { + repeatedStringScanBenchmark(1024 * 1024 * 10) + } + runBenchmark("String with Nulls Scan") { + for (fractionOfNulls <- List(0.0, 0.50, 0.95)) { + stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls) + } } - intStringScanBenchmark(1024 * 1024 * 10) - partitionTableScanBenchmark(1024 * 1024 * 15) - repeatedStringScanBenchmark(1024 * 1024 * 10) - for (fractionOfNulls <- List(0.0, 0.50, 0.95)) { - stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls) + runBenchmark("Single Column Scan From Wide Columns") { + columnsBenchmark(1024 * 1024 * 1, 100) + columnsBenchmark(1024 * 1024 * 1, 200) + columnsBenchmark(1024 * 1024 * 1, 300) } - columnsBenchmark(1024 * 1024 * 1, 100) - columnsBenchmark(1024 * 1024 * 1, 200) - columnsBenchmark(1024 * 1024 * 1, 300) } } // scalastyle:on line.size.limit From f4b138082ff91be74b0f5bbe19cdb90dd9e5f131 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sat, 29 Sep 2018 14:48:32 -0700 Subject: [PATCH 1735/2461] [SPARK-25572][SPARKR] test only if not cran ## What changes were proposed in this pull request? CRAN doesn't seem to respect the system requirements as running tests - we have seen cases where SparkR is run on Java 10, which unfortunately Spark does not start on. For 2.4, lets attempt skipping all tests ## How was this patch tested? manual, jenkins, appveyor Author: Felix Cheung Closes #22589 from felixcheung/ralltests. --- R/pkg/tests/run-all.R | 83 +++++++++++++++++++++++-------------------- 1 file changed, 44 insertions(+), 39 deletions(-) diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index 94d75188fb948..1e96418558883 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -18,50 +18,55 @@ library(testthat) library(SparkR) -# Turn all warnings into errors -options("warn" = 2) +# SPARK-25572 +if (identical(Sys.getenv("NOT_CRAN"), "true")) { -if (.Platform$OS.type == "windows") { - Sys.setenv(TZ = "GMT") -} + # Turn all warnings into errors + options("warn" = 2) -# Setup global test environment -# Install Spark first to set SPARK_HOME + if (.Platform$OS.type == "windows") { + Sys.setenv(TZ = "GMT") + } -# NOTE(shivaram): We set overwrite to handle any old tar.gz files or directories left behind on -# CRAN machines. For Jenkins we should already have SPARK_HOME set. -install.spark(overwrite = TRUE) + # Setup global test environment + # Install Spark first to set SPARK_HOME -sparkRDir <- file.path(Sys.getenv("SPARK_HOME"), "R") -sparkRWhitelistSQLDirs <- c("spark-warehouse", "metastore_db") -invisible(lapply(sparkRWhitelistSQLDirs, - function(x) { unlink(file.path(sparkRDir, x), recursive = TRUE, force = TRUE)})) -sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE) + # NOTE(shivaram): We set overwrite to handle any old tar.gz files or directories left behind on + # CRAN machines. For Jenkins we should already have SPARK_HOME set. + install.spark(overwrite = TRUE) -sparkRTestMaster <- "local[1]" -sparkRTestConfig <- list() -if (identical(Sys.getenv("NOT_CRAN"), "true")) { - sparkRTestMaster <- "" -} else { - # Disable hsperfdata on CRAN - old_java_opt <- Sys.getenv("_JAVA_OPTIONS") - Sys.setenv("_JAVA_OPTIONS" = paste("-XX:-UsePerfData", old_java_opt)) - tmpDir <- tempdir() - tmpArg <- paste0("-Djava.io.tmpdir=", tmpDir) - sparkRTestConfig <- list(spark.driver.extraJavaOptions = tmpArg, - spark.executor.extraJavaOptions = tmpArg) -} + sparkRDir <- file.path(Sys.getenv("SPARK_HOME"), "R") + sparkRWhitelistSQLDirs <- c("spark-warehouse", "metastore_db") + invisible(lapply(sparkRWhitelistSQLDirs, + function(x) { unlink(file.path(sparkRDir, x), recursive = TRUE, force = TRUE)})) + sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE) -test_package("SparkR") + sparkRTestMaster <- "local[1]" + sparkRTestConfig <- list() + if (identical(Sys.getenv("NOT_CRAN"), "true")) { + sparkRTestMaster <- "" + } else { + # Disable hsperfdata on CRAN + old_java_opt <- Sys.getenv("_JAVA_OPTIONS") + Sys.setenv("_JAVA_OPTIONS" = paste("-XX:-UsePerfData", old_java_opt)) + tmpDir <- tempdir() + tmpArg <- paste0("-Djava.io.tmpdir=", tmpDir) + sparkRTestConfig <- list(spark.driver.extraJavaOptions = tmpArg, + spark.executor.extraJavaOptions = tmpArg) + } -if (identical(Sys.getenv("NOT_CRAN"), "true")) { - # set random seed for predictable results. mostly for base's sample() in tree and classification - set.seed(42) - # for testthat 1.0.2 later, change reporter from "summary" to default_reporter() - testthat:::run_tests("SparkR", - file.path(sparkRDir, "pkg", "tests", "fulltests"), - NULL, - "summary") -} + test_package("SparkR") + + if (identical(Sys.getenv("NOT_CRAN"), "true")) { + # set random seed for predictable results. mostly for base's sample() in tree and classification + set.seed(42) + # for testthat 1.0.2 later, change reporter from "summary" to default_reporter() + testthat:::run_tests("SparkR", + file.path(sparkRDir, "pkg", "tests", "fulltests"), + NULL, + "summary") + } -SparkR:::uninstallDownloadedSpark() + SparkR:::uninstallDownloadedSpark() + +} From b6b8a6632e2b6e5482aaf4bfa093700752a9df80 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Sat, 29 Sep 2018 18:10:04 -0700 Subject: [PATCH 1736/2461] [SPARK-25568][CORE] Continue to update the remaining accumulators when failing to update one accumulator ## What changes were proposed in this pull request? Since we don't fail a job when `AccumulatorV2.merge` fails, we should try to update the remaining accumulators so that they can still report correct values. ## How was this patch tested? The new unit test. Closes #22586 from zsxwing/SPARK-25568. Authored-by: Shixiong Zhu Signed-off-by: gatorsmile --- .../apache/spark/scheduler/DAGScheduler.scala | 20 +++++++++++++------ .../spark/scheduler/DAGSchedulerSuite.scala | 20 +++++++++++++++++++ docs/rdd-programming-guide.md | 4 ++++ 3 files changed, 38 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 47108353583a8..f93d8a8d5de55 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1245,9 +1245,10 @@ private[spark] class DAGScheduler( private def updateAccumulators(event: CompletionEvent): Unit = { val task = event.task val stage = stageIdToStage(task.stageId) - try { - event.accumUpdates.foreach { updates => - val id = updates.id + + event.accumUpdates.foreach { updates => + val id = updates.id + try { // Find the corresponding accumulator on the driver and update it val acc: AccumulatorV2[Any, Any] = AccumulatorContext.get(id) match { case Some(accum) => accum.asInstanceOf[AccumulatorV2[Any, Any]] @@ -1261,10 +1262,17 @@ private[spark] class DAGScheduler( event.taskInfo.setAccumulables( acc.toInfo(Some(updates.value), Some(acc.value)) +: event.taskInfo.accumulables) } + } catch { + case NonFatal(e) => + // Log the class name to make it easy to find the bad implementation + val accumClassName = AccumulatorContext.get(id) match { + case Some(accum) => accum.getClass.getName + case None => "Unknown class" + } + logError( + s"Failed to update accumulator $id ($accumClassName) for task ${task.partitionId}", + e) } - } catch { - case NonFatal(e) => - logError(s"Failed to update accumulators for task ${task.partitionId}", e) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index d6c9ae6ab5191..b41d2acab7152 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1880,6 +1880,26 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(sc.parallelize(1 to 10, 2).count() === 10) } + test("misbehaved accumulator should not impact other accumulators") { + val bad = new LongAccumulator { + override def merge(other: AccumulatorV2[java.lang.Long, java.lang.Long]): Unit = { + throw new DAGSchedulerSuiteDummyException + } + } + sc.register(bad, "bad") + val good = sc.longAccumulator("good") + + sc.parallelize(1 to 10, 2).foreach { item => + bad.add(1) + good.add(1) + } + + // This is to ensure the `bad` accumulator did fail to update its value + assert(bad.value == 0L) + // Should be able to update the "good" accumulator + assert(good.value == 10L) + } + /** * The job will be failed on first task throwing a DAGSchedulerSuiteDummyException. * Any subsequent task WILL throw a legitimate java.lang.UnsupportedOperationException. diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index 005425754c646..9a07d6ca24b65 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -1465,6 +1465,10 @@ jsc.sc().register(myVectorAcc, "MyVectorAcc1"); Note that, when programmers define their own type of AccumulatorV2, the resulting type can be different than that of the elements added. +*Warning*: When a Spark task finishes, Spark will try to merge the accumulated updates in this task to an accumulator. +If it fails, Spark will ignore the failure and still mark the task successful and continue to run other tasks. Hence, +a buggy accumulator will not impact a Spark job, but it may not get updated correctly although a Spark job is successful. +
      From a2f502cf53b6b00af7cb80b6f38e64cf46367595 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 30 Sep 2018 14:31:04 +0800 Subject: [PATCH 1737/2461] [SPARK-25565][BUILD] Add scalastyle rule to check add Locale.ROOT to .toLowerCase and .toUpperCase for internal calls ## What changes were proposed in this pull request? This PR adds a rule to force `.toLowerCase(Locale.ROOT)` or `toUpperCase(Locale.ROOT)`. It produces an error as below: ``` [error] Are you sure that you want to use toUpperCase or toLowerCase without the root locale? In most cases, you [error] should use toUpperCase(Locale.ROOT) or toLowerCase(Locale.ROOT) instead. [error] If you must use toUpperCase or toLowerCase without the root locale, wrap the code block with [error] // scalastyle:off caselocale [error] .toUpperCase [error] .toLowerCase [error] // scalastyle:on caselocale ``` This PR excludes the cases above for SQL code path for external calls like table name, column name and etc. For test suites, or when it's clear there's no locale problem like Turkish locale problem, it uses `Locale.ROOT`. One minor problem is, `UTF8String` has both methods, `toLowerCase` and `toUpperCase`, and the new rule detects them as well. They are ignored. ## How was this patch tested? Manually tested, and Jenkins tests. Closes #22581 from HyukjinKwon/SPARK-25565. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- .../types/UTF8StringPropertyCheckSuite.scala | 2 ++ .../spark/metrics/sink/StatsdSink.scala | 5 ++-- .../spark/rdd/OrderedRDDFunctions.scala | 3 +- .../scala/org/apache/spark/util/Utils.scala | 2 +- .../history/FsHistoryProviderSuite.scala | 4 +-- .../spark/ml/feature/StopWordsRemover.scala | 2 ++ .../apache/spark/ml/feature/Tokenizer.scala | 4 +++ .../submit/KubernetesClientApplication.scala | 4 +-- .../cluster/k8s/ExecutorPodsSnapshot.scala | 4 ++- .../deploy/mesos/MesosClusterDispatcher.scala | 3 +- scalastyle-config.xml | 13 +++++++++ .../analysis/higherOrderFunctions.scala | 2 ++ .../expressions/stringExpressions.scala | 6 ++++ .../sql/catalyst/parser/AstBuilder.scala | 2 ++ .../spark/sql/catalyst/util/StringUtils.scala | 2 ++ .../apache/spark/sql/internal/SQLConf.scala | 3 +- .../apache/spark/sql/util/SchemaUtils.scala | 2 ++ .../spark/sql/util/SchemaUtilsSuite.scala | 4 ++- .../InsertIntoHadoopFsRelationCommand.scala | 2 ++ .../datasources/csv/CSVDataSource.scala | 6 ++++ .../streaming/WatermarkTracker.scala | 4 ++- .../state/SymmetricHashJoinStateManager.scala | 4 ++- .../spark/sql/ColumnExpressionSuite.scala | 4 +-- .../spark/sql/DataFramePivotSuite.scala | 4 ++- .../org/apache/spark/sql/JoinSuite.scala | 4 ++- .../streaming/EventTimeWatermarkSuite.scala | 4 +-- .../spark/sql/hive/HiveExternalCatalog.scala | 16 ++++++++++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 4 +++ .../sql/hive/CompressionCodecSuite.scala | 29 ++++++++++++------- .../sql/hive/HiveSchemaInferenceSuite.scala | 9 +++--- .../spark/sql/hive/StatisticsSuite.scala | 15 ++++++---- 31 files changed, 132 insertions(+), 40 deletions(-) diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala index 7d3331f44f015..9656951810daf 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala @@ -63,6 +63,7 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty } } + // scalastyle:off caselocale test("toUpperCase") { forAll { (s: String) => assert(toUTF8(s).toUpperCase === toUTF8(s.toUpperCase)) @@ -74,6 +75,7 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty assert(toUTF8(s).toLowerCase === toUTF8(s.toLowerCase)) } } + // scalastyle:on caselocale test("compare") { forAll { (s1: String, s2: String) => diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/StatsdSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/StatsdSink.scala index 859a2f6bcd456..61e74e05169cc 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/StatsdSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/StatsdSink.scala @@ -17,7 +17,7 @@ package org.apache.spark.metrics.sink -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.TimeUnit import com.codahale.metrics.MetricRegistry @@ -52,7 +52,8 @@ private[spark] class StatsdSink( val pollPeriod = property.getProperty(STATSD_KEY_PERIOD, STATSD_DEFAULT_PERIOD).toInt val pollUnit = - TimeUnit.valueOf(property.getProperty(STATSD_KEY_UNIT, STATSD_DEFAULT_UNIT).toUpperCase) + TimeUnit.valueOf( + property.getProperty(STATSD_KEY_UNIT, STATSD_DEFAULT_UNIT).toUpperCase(Locale.ROOT)) val prefix = property.getProperty(STATSD_KEY_PREFIX, STATSD_DEFAULT_PREFIX) diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index a5992022d0832..5b1c024257529 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -35,7 +35,8 @@ import org.apache.spark.internal.Logging * * val rdd: RDD[(String, Int)] = ... * implicit val caseInsensitiveOrdering = new Ordering[String] { - * override def compare(a: String, b: String) = a.toLowerCase.compare(b.toLowerCase) + * override def compare(a: String, b: String) = + * a.toLowerCase(Locale.ROOT).compare(b.toLowerCase(Locale.ROOT)) * } * * // Sort by key, using the above case insensitive ordering. diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c8b148be84536..93b5826f8a74b 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2736,7 +2736,7 @@ private[spark] object Utils extends Logging { } val masterScheme = new URI(masterWithoutK8sPrefix).getScheme - val resolvedURL = masterScheme.toLowerCase match { + val resolvedURL = masterScheme.toLowerCase(Locale.ROOT) match { case "https" => masterWithoutK8sPrefix case "http" => diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index b4eba755eccbf..444e8d6e11f88 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.history import java.io._ import java.nio.charset.StandardCharsets -import java.util.Date +import java.util.{Date, Locale} import java.util.concurrent.TimeUnit import java.util.zip.{ZipInputStream, ZipOutputStream} @@ -834,7 +834,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc doThrow(new AccessControlException("Cannot read accessDenied file")).when(mockedFs).open( argThat(new ArgumentMatcher[Path]() { override def matches(path: Any): Boolean = { - path.asInstanceOf[Path].getName.toLowerCase == "accessdenied" + path.asInstanceOf[Path].getName.toLowerCase(Locale.ROOT) == "accessdenied" } })) val mockedProvider = spy(provider) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 94640a5cbe310..6669d402cd996 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -118,7 +118,9 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String } } else { val lc = new Locale($(locale)) + // scalastyle:off caselocale val toLower = (s: String) => if (s != null) s.toLowerCase(lc) else s + // scalastyle:on caselocale val lowerStopWords = $(stopWords).map(toLower(_)).toSet udf { terms: Seq[String] => terms.filter(s => !lowerStopWords.contains(toLower(s))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index aede1f812a552..748c869af4117 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -36,7 +36,9 @@ class Tokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) def this() = this(Identifiable.randomUID("tok")) override protected def createTransformFunc: String => Seq[String] = { + // scalastyle:off caselocale _.toLowerCase.split("\\s") + // scalastyle:on caselocale } override protected def validateInputType(inputType: DataType): Unit = { @@ -140,7 +142,9 @@ class RegexTokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) override protected def createTransformFunc: String => Seq[String] = { originStr => val re = $(pattern).r + // scalastyle:off caselocale val str = if ($(toLowercase)) originStr.toLowerCase() else originStr + // scalastyle:on caselocale val tokens = if ($(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq val minLength = $(minTokenLength) tokens.filter(_.length >= minLength) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index edeaa380194ac..af3903ac5da56 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.k8s.submit import java.io.StringWriter -import java.util.{Collections, UUID} +import java.util.{Collections, Locale, UUID} import java.util.Properties import io.fabric8.kubernetes.api.model._ @@ -260,7 +260,7 @@ private[spark] object KubernetesClientApplication { val launchTime = System.currentTimeMillis() s"$appName-$launchTime" .trim - .toLowerCase + .toLowerCase(Locale.ROOT) .replaceAll("\\s+", "-") .replaceAll("\\.", "-") .replaceAll("[^a-z0-9\\-]", "") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala index 26be918043412..435a5f1461c92 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.scheduler.cluster.k8s +import java.util.Locale + import io.fabric8.kubernetes.api.model.Pod import org.apache.spark.deploy.k8s.Constants._ @@ -52,7 +54,7 @@ object ExecutorPodsSnapshot extends Logging { if (isDeleted(pod)) { PodDeleted(pod) } else { - val phase = pod.getStatus.getPhase.toLowerCase + val phase = pod.getStatus.getPhase.toLowerCase(Locale.ROOT) phase match { case "pending" => PodPending(pod) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index 64698b55c6bb6..32ac4f37c5f99 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy.mesos +import java.util.Locale import java.util.concurrent.CountDownLatch import org.apache.spark.{SecurityManager, SparkConf, SparkException} @@ -60,7 +61,7 @@ private[mesos] class MesosClusterDispatcher( } private val publicAddress = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(args.host) - private val recoveryMode = conf.get(RECOVERY_MODE).toUpperCase() + private val recoveryMode = conf.get(RECOVERY_MODE).toUpperCase(Locale.ROOT) logInfo("Recovery mode in Mesos dispatcher set to: " + recoveryMode) private val engineFactory = recoveryMode match { diff --git a/scalastyle-config.xml b/scalastyle-config.xml index da5c3f29c32dc..36a73e3362218 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -227,6 +227,19 @@ This file is divided into 3 sections: ]]> + + (\.toUpperCase|\.toLowerCase)(?!(\(|\(Locale.ROOT\))) + + + JavaConversions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala index dd08190e1e8a3..a8a7bbd9f9cd0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala @@ -73,7 +73,9 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] { private val canonicalizer = { if (!conf.caseSensitiveAnalysis) { + // scalastyle:off caselocale s: String => s.toLowerCase + // scalastyle:on caselocale } else { s: String => s } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 14faa62bde7d0..cd824ee87ca53 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -330,7 +330,9 @@ trait String2StringExpression extends ImplicitCastInputTypes { case class Upper(child: Expression) extends UnaryExpression with String2StringExpression { + // scalastyle:off caselocale override def convert(v: UTF8String): UTF8String = v.toUpperCase + // scalastyle:on caselocale override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") @@ -349,7 +351,9 @@ case class Upper(child: Expression) """) case class Lower(child: Expression) extends UnaryExpression with String2StringExpression { + // scalastyle:off caselocale override def convert(v: UTF8String): UTF8String = v.toLowerCase + // scalastyle:on caselocale override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") @@ -1389,7 +1393,9 @@ case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastI override def dataType: DataType = StringType override def nullSafeEval(string: Any): Any = { + // scalastyle:off caselocale string.asInstanceOf[UTF8String].toLowerCase.toTitleCase + // scalastyle:on caselocale } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, str => s"$str.toLowerCase().toTitleCase()") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 5cfb5dc871041..da12a6519bd28 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -663,7 +663,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging UnresolvedGenerator(visitFunctionName(ctx.qualifiedName), expressions), unrequiredChildIndex = Nil, outer = ctx.OUTER != null, + // scalastyle:off caselocale Some(ctx.tblName.getText.toLowerCase), + // scalastyle:on caselocale ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply), query) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index ca22ea24207e1..bc861a805ce61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -62,8 +62,10 @@ object StringUtils { private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString) private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString) + // scalastyle:off caselocale def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase) def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase) + // scalastyle:on caselocale /** * This utility can be used for filtering pattern in the "Like" of "Show Tables / Functions" DDL diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f6c98805bfb15..b699707d85235 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -974,8 +974,9 @@ object SQLConf { "Note: This configuration cannot be changed between query restarts from the same " + "checkpoint location.") .stringConf + .transform(_.toLowerCase(Locale.ROOT)) .checkValue( - str => Set("min", "max").contains(str.toLowerCase), + str => Set("min", "max").contains(str), "Invalid value for 'spark.sql.streaming.multipleWatermarkPolicy'. " + "Valid values are 'min' and 'max'") .createWithDefault("min") // must be same as MultipleWatermarkPolicy.DEFAULT_POLICY_NAME diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala index 41ca270095ffb..052014ab86744 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala @@ -77,7 +77,9 @@ private[spark] object SchemaUtils { */ def checkColumnNameDuplication( columnNames: Seq[String], colType: String, caseSensitiveAnalysis: Boolean): Unit = { + // scalastyle:off caselocale val names = if (caseSensitiveAnalysis) columnNames else columnNames.map(_.toLowerCase) + // scalastyle:on caselocale if (names.distinct.length != names.length) { val duplicateColumns = names.groupBy(identity).collect { case (x, ys) if ys.length > 1 => s"`$x`" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala index a25be2fe61dbd..2f576a4031e92 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.util +import java.util.Locale + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ @@ -39,7 +41,7 @@ class SchemaUtilsSuite extends SparkFunSuite { test(s"Check column name duplication in $testType cases") { def checkExceptionCases(schemaStr: String, duplicatedColumns: Seq[String]): Unit = { val expectedErrorMsg = "Found duplicate column(s) in SchemaUtilsSuite: " + - duplicatedColumns.map(c => s"`${c.toLowerCase}`").mkString(", ") + duplicatedColumns.map(c => s"`${c.toLowerCase(Locale.ROOT)}`").mkString(", ") val schema = StructType.fromDDL(schemaStr) var msg = intercept[AnalysisException] { SchemaUtils.checkSchemaColumnNameDuplication( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 484942d35c857..d43fa3893df1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -95,7 +95,9 @@ case class InsertIntoHadoopFsRelationCommand( val parameters = CaseInsensitiveMap(options) val partitionOverwriteMode = parameters.get("partitionOverwriteMode") + // scalastyle:off caselocale .map(mode => PartitionOverwriteMode.withName(mode.toUpperCase)) + // scalastyle:on caselocale .getOrElse(sparkSession.sessionState.conf.partitionOverwriteMode) val enableDynamicOverwrite = partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC // This config only makes sense when we are overwriting a partitioned dataset with dynamic diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index e840ff1682502..b93f418bcb5be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -86,7 +86,9 @@ abstract class CSVDataSource extends Serializable { if (options.headerFlag) { val duplicates = { val headerNames = row.filter(_ != null) + // scalastyle:off caselocale .map(name => if (caseSensitive) name else name.toLowerCase) + // scalastyle:on caselocale headerNames.diff(headerNames.distinct).distinct } @@ -95,7 +97,9 @@ abstract class CSVDataSource extends Serializable { // When there are empty strings or the values set in `nullValue`, put the // index as the suffix. s"_c$index" + // scalastyle:off caselocale } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) { + // scalastyle:on caselocale // When there are case-insensitive duplicates, put the index as the suffix. s"$value$index" } else if (duplicates.contains(value)) { @@ -153,8 +157,10 @@ object CSVDataSource extends Logging { while (errorMessage.isEmpty && i < headerLen) { var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i)) if (!caseSensitive) { + // scalastyle:off caselocale nameInSchema = nameInSchema.toLowerCase nameInHeader = nameInHeader.toLowerCase + // scalastyle:on caselocale } if (nameInHeader != nameInSchema) { errorMessage = Some( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala index 7b30db44a2090..76ab1284633b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming +import java.util.Locale + import scala.collection.mutable import org.apache.spark.internal.Logging @@ -36,7 +38,7 @@ object MultipleWatermarkPolicy { val DEFAULT_POLICY_NAME = "min" def apply(policyName: String): MultipleWatermarkPolicy = { - policyName.toLowerCase match { + policyName.toLowerCase(Locale.ROOT) match { case DEFAULT_POLICY_NAME => MinWatermark case "max" => MaxWatermark case _ => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 352b3d3616fba..43f22803e7685 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming.state +import java.util.Locale + import org.apache.hadoop.conf.Configuration import org.apache.spark.TaskContext @@ -263,7 +265,7 @@ class SymmetricHashJoinStateManager( def metrics: StateStoreMetrics = { val keyToNumValuesMetrics = keyToNumValues.metrics val keyWithIndexToValueMetrics = keyWithIndexToValue.metrics - def newDesc(desc: String): String = s"${joinSide.toString.toUpperCase}: $desc" + def newDesc(desc: String): String = s"${joinSide.toString.toUpperCase(Locale.ROOT)}: $desc" StateStoreMetrics( keyWithIndexToValueMetrics.numKeys, // represent each buffered row only once diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 2917c56dbeb56..f984a1b722e36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -505,7 +505,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { test("upper") { checkAnswer( lowerCaseData.select(upper('l)), - ('a' to 'd').map(c => Row(c.toString.toUpperCase)) + ('a' to 'd').map(c => Row(c.toString.toUpperCase(Locale.ROOT))) ) checkAnswer( @@ -526,7 +526,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { test("lower") { checkAnswer( upperCaseData.select(lower('L)), - ('A' to 'F').map(c => Row(c.toString.toLowerCase)) + ('A' to 'F').map(c => Row(c.toString.toLowerCase(Locale.ROOT))) ) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 02ab19754b0c7..b52ca58c07d27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.util.Locale + import org.apache.spark.sql.catalyst.expressions.aggregate.PivotFirst import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -272,7 +274,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil val df = trainingSales .groupBy($"sales.year") - .pivot(lower($"sales.course"), Seq("dotNet", "Java").map(_.toLowerCase)) + .pivot(lower($"sales.course"), Seq("dotNet", "Java").map(_.toLowerCase(Locale.ROOT))) .agg(sum($"sales.earnings")) checkAnswer(df, expected) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 44767dfc92497..aa2162c9d2cda 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.util.Locale + import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import scala.language.existentials @@ -831,7 +833,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { case _ => } val joinPairs = physicalJoins.zip(executedJoins) - val numOfJoins = sqlString.split(" ").count(_.toUpperCase == "JOIN") + val numOfJoins = sqlString.split(" ").count(_.toUpperCase(Locale.ROOT) == "JOIN") assert(joinPairs.size == numOfJoins) joinPairs.foreach { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index 026af17c7b23f..c696204cecc2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming import java.{util => ju} import java.io.File import java.text.SimpleDateFormat -import java.util.{Calendar, Date} +import java.util.{Calendar, Date, Locale} import org.apache.commons.io.FileUtils import org.scalatest.{BeforeAndAfter, Matchers} @@ -698,7 +698,7 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val e = intercept[IllegalArgumentException] { spark.conf.set(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key, value) } - assert(e.getMessage.toLowerCase.contains("valid values are 'min' and 'max'")) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains("valid values are 'min' and 'max'")) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 505124ae9e7c8..445161d5de1c2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -868,7 +868,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // and Hive will validate the column names in partition spec to make sure they are partition // columns. Here we Lowercase the column names before passing the partition spec to Hive // client, to satisfy Hive. + // scalastyle:off caselocale orderedPartitionSpec.put(colName.toLowerCase, partition(colName)) + // scalastyle:on caselocale } client.loadPartition( @@ -896,7 +898,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // and Hive will validate the column names in partition spec to make sure they are partition // columns. Here we Lowercase the column names before passing the partition spec to Hive // client, to satisfy Hive. + // scalastyle:off caselocale orderedPartitionSpec.put(colName.toLowerCase, partition(colName)) + // scalastyle:on caselocale } client.loadDynamicPartitions( @@ -916,13 +920,17 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // to lower case the column names in partition specification before calling partition related Hive // APIs, to match this behaviour. private def lowerCasePartitionSpec(spec: TablePartitionSpec): TablePartitionSpec = { + // scalastyle:off caselocale spec.map { case (k, v) => k.toLowerCase -> v } + // scalastyle:on caselocale } // Build a map from lower-cased partition column names to exact column names for a given table private def buildLowerCasePartColNameMap(table: CatalogTable): Map[String, String] = { val actualPartColNames = table.partitionColumnNames + // scalastyle:off caselocale actualPartColNames.map(colName => (colName.toLowerCase, colName)).toMap + // scalastyle:on caselocale } // Hive metastore is not case preserving and the column names of the partition specification we @@ -931,7 +939,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat private def restorePartitionSpec( spec: TablePartitionSpec, partColMap: Map[String, String]): TablePartitionSpec = { + // scalastyle:off caselocale spec.map { case (k, v) => partColMap(k.toLowerCase) -> v } + // scalastyle:on caselocale } private def restorePartitionSpec( @@ -990,7 +1000,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // When Hive rename partition for managed tables, it will create the partition location with // a default path generate by the new spec with lower cased partition column names. This is // unexpected and we need to rename them manually and alter the partition location. + // scalastyle:off caselocale val hasUpperCasePartitionColumn = partitionColumnNames.exists(col => col.toLowerCase != col) + // scalastyle:on caselocale if (tableMeta.tableType == MANAGED && hasUpperCasePartitionColumn) { val tablePath = new Path(tableMeta.location) val fs = tablePath.getFileSystem(hadoopConf) @@ -1031,7 +1043,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // another partition to `A=1/B=3`, then we will have `A=1/B=2` and `a=1/b=3`, and we should // just move `a=1/b=3` into `A=1` with new name `B=3`. } else { + // scalastyle:off caselocale val actualPartitionString = getPartitionPathString(col.toLowerCase, partValue) + // scalastyle:on caselocale val actualPartitionPath = new Path(currentFullPath, actualPartitionString) try { fs.rename(actualPartitionPath, expectedPartitionPath) @@ -1182,7 +1196,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat clientPartitionNames.map { partitionPath => val partSpec = PartitioningUtils.parsePathFragmentAsSeq(partitionPath) partSpec.map { case (partName, partValue) => + // scalastyle:off caselocale partColNameMap(partName.toLowerCase) + "=" + escapePathName(partValue) + // scalastyle:on caselocale }.mkString("/") } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 8adfda07d29d5..d047953327958 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -59,8 +59,10 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log // For testing only private[hive] def getCachedDataSourceTable(table: TableIdentifier): LogicalPlan = { val key = QualifiedTableName( + // scalastyle:off caselocale table.database.getOrElse(sessionState.catalog.getCurrentDatabase).toLowerCase, table.table.toLowerCase) + // scalastyle:on caselocale catalogProxy.getCachedTable(key) } @@ -273,6 +275,7 @@ private[hive] object HiveMetastoreCatalog { def mergeWithMetastoreSchema( metastoreSchema: StructType, inferredSchema: StructType): StructType = try { + // scalastyle:off caselocale // Find any nullable fields in mestastore schema that are missing from the inferred schema. val metastoreFields = metastoreSchema.map(f => f.name.toLowerCase -> f).toMap val missingNullables = metastoreFields @@ -282,6 +285,7 @@ private[hive] object HiveMetastoreCatalog { // Merge missing nullable fields to inferred schema and build a case-insensitive field map. val inferredFields = StructType(inferredSchema ++ missingNullables) .map(f => f.name.toLowerCase -> f).toMap + // scalastyle:on caselocale StructType(metastoreSchema.map(f => f.copy(name = inferredFields(f.name).name))) } catch { case NonFatal(_) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala index 30204d1223846..1bd7e52c88ecf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import java.io.File +import java.util.Locale import scala.collection.JavaConverters._ @@ -50,23 +51,29 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo private val maxRecordNum = 50 - private def getConvertMetastoreConfName(format: String): String = format.toLowerCase match { - case "parquet" => HiveUtils.CONVERT_METASTORE_PARQUET.key - case "orc" => HiveUtils.CONVERT_METASTORE_ORC.key + private def getConvertMetastoreConfName(format: String): String = { + format.toLowerCase(Locale.ROOT) match { + case "parquet" => HiveUtils.CONVERT_METASTORE_PARQUET.key + case "orc" => HiveUtils.CONVERT_METASTORE_ORC.key + } } - private def getSparkCompressionConfName(format: String): String = format.toLowerCase match { - case "parquet" => SQLConf.PARQUET_COMPRESSION.key - case "orc" => SQLConf.ORC_COMPRESSION.key + private def getSparkCompressionConfName(format: String): String = { + format.toLowerCase(Locale.ROOT) match { + case "parquet" => SQLConf.PARQUET_COMPRESSION.key + case "orc" => SQLConf.ORC_COMPRESSION.key + } } - private def getHiveCompressPropName(format: String): String = format.toLowerCase match { - case "parquet" => ParquetOutputFormat.COMPRESSION - case "orc" => COMPRESS.getAttribute + private def getHiveCompressPropName(format: String): String = { + format.toLowerCase(Locale.ROOT) match { + case "parquet" => ParquetOutputFormat.COMPRESSION + case "orc" => COMPRESS.getAttribute + } } private def normalizeCodecName(format: String, name: String): String = { - format.toLowerCase match { + format.toLowerCase(Locale.ROOT) match { case "parquet" => ParquetOptions.getParquetCompressionCodecName(name) case "orc" => OrcOptions.getORCCompressionCodecName(name) } @@ -74,7 +81,7 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo private def getTableCompressionCodec(path: String, format: String): Seq[String] = { val hadoopConf = spark.sessionState.newHadoopConf() - val codecs = format.toLowerCase match { + val codecs = format.toLowerCase(Locale.ROOT) match { case "parquet" => for { footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConf) block <- footer.getParquetMetadata.getBlocks.asScala diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala index 51a48a20daaa2..aa4fc13333c48 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import java.io.File +import java.util.Locale import scala.util.Random @@ -56,7 +57,7 @@ class HiveSchemaInferenceSuite // Return a copy of the given schema with all field names converted to lower case. private def lowerCaseSchema(schema: StructType): StructType = { - StructType(schema.map(f => f.copy(name = f.name.toLowerCase))) + StructType(schema.map(f => f.copy(name = f.name.toLowerCase(Locale.ROOT)))) } // Create a Hive external test table containing the given field and partition column names. @@ -78,7 +79,7 @@ class HiveSchemaInferenceSuite val partitionStructFields = partitionCols.map { field => StructField( // Partition column case isn't preserved - name = field.toLowerCase, + name = field.toLowerCase(Locale.ROOT), dataType = IntegerType, nullable = true, metadata = Metadata.empty) @@ -113,7 +114,7 @@ class HiveSchemaInferenceSuite properties = Map("serialization.format" -> "1")), schema = schema, provider = Option("hive"), - partitionColumnNames = partitionCols.map(_.toLowerCase), + partitionColumnNames = partitionCols.map(_.toLowerCase(Locale.ROOT)), properties = Map.empty), true) @@ -180,7 +181,7 @@ class HiveSchemaInferenceSuite val catalogTable = externalCatalog.getTable(DATABASE, TEST_TABLE_NAME) assert(catalogTable.schemaPreservesCase) assert(catalogTable.schema == schema) - assert(catalogTable.partitionColumnNames == partCols.map(_.toLowerCase)) + assert(catalogTable.partitionColumnNames == partCols.map(_.toLowerCase(Locale.ROOT))) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 57f1c243a70de..db2024e8b5d16 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive import java.io.{File, PrintWriter} import java.sql.Timestamp +import java.util.Locale import scala.reflect.ClassTag import scala.util.matching.Regex @@ -489,7 +490,8 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto sql(s"ANALYZE TABLE $tableName PARTITION (DS='2010-01-01') COMPUTE STATISTICS") }.getMessage assert(message.contains( - s"DS is not a valid partition column in table `default`.`${tableName.toLowerCase}`")) + "DS is not a valid partition column in table " + + s"`default`.`${tableName.toLowerCase(Locale.ROOT)}`")) } } } @@ -503,8 +505,9 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto sql(s"ANALYZE TABLE $tableName $partitionSpec COMPUTE STATISTICS") }.getMessage assert(message.contains("The list of partition columns with values " + - s"in partition specification for table '${tableName.toLowerCase}' in database 'default' " + - "is not a prefix of the list of partition columns defined in the table schema")) + s"in partition specification for table '${tableName.toLowerCase(Locale.ROOT)}' in " + + "database 'default' is not a prefix of the list of partition columns defined in " + + "the table schema")) } withTable(tableName) { @@ -550,12 +553,14 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto assertAnalysisException( s"ANALYZE TABLE $tableName PARTITION (hour=20) COMPUTE STATISTICS", - s"hour is not a valid partition column in table `default`.`${tableName.toLowerCase}`" + "hour is not a valid partition column in table " + + s"`default`.`${tableName.toLowerCase(Locale.ROOT)}`" ) assertAnalysisException( s"ANALYZE TABLE $tableName PARTITION (hour) COMPUTE STATISTICS", - s"hour is not a valid partition column in table `default`.`${tableName.toLowerCase}`" + "hour is not a valid partition column in table " + + s"`default`.`${tableName.toLowerCase(Locale.ROOT)}`" ) intercept[NoSuchPartitionException] { From 40e6ed89405828ff312eca0abd43cfba4b9185b2 Mon Sep 17 00:00:00 2001 From: Darcy Shen Date: Sun, 30 Sep 2018 09:00:23 -0500 Subject: [PATCH 1738/2461] [CORE][MINOR] Fix obvious error and compiling for Scala 2.12.7 ## What changes were proposed in this pull request? Fix an obvious error. ## How was this patch tested? Existing tests. Closes #22577 from sadhen/minor_fix. Authored-by: Darcy Shen Signed-off-by: Sean Owen --- .../org/apache/spark/status/api/v1/OneApplicationResource.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala index 32100c5704538..1f4082cac8f75 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala @@ -175,7 +175,7 @@ private[v1] class OneApplicationAttemptResource extends AbstractApplicationResou def getAttempt(): ApplicationAttemptInfo = { uiRoot.getApplicationInfo(appId) .flatMap { app => - app.attempts.filter(_.attemptId == attemptId).headOption + app.attempts.find(_.attemptId.contains(attemptId)) } .getOrElse { throw new NotFoundException(s"unknown app $appId, attempt $attemptId") From 4da541a5d23b039eb549dd849cf121bdc8676e59 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Sun, 30 Sep 2018 14:28:20 -0700 Subject: [PATCH 1739/2461] [SPARK-25543][K8S] Print debug message iff execIdsRemovedInThisRound is not empty. ## What changes were proposed in this pull request? Spurious logs like /sec. 2018-09-26 09:33:57 DEBUG ExecutorPodsLifecycleManager:58 - Removed executors with ids from Spark that were either found to be deleted or non-existent in the cluster. 2018-09-26 09:33:58 DEBUG ExecutorPodsLifecycleManager:58 - Removed executors with ids from Spark that were either found to be deleted or non-existent in the cluster. 2018-09-26 09:33:59 DEBUG ExecutorPodsLifecycleManager:58 - Removed executors with ids from Spark that were either found to be deleted or non-existent in the cluster. 2018-09-26 09:34:00 DEBUG ExecutorPodsLifecycleManager:58 - Removed executors with ids from Spark that were either found to be deleted or non-existent in the cluster. The fix is easy, first check if there are any removed executors, before producing the log message. ## How was this patch tested? Tested by manually deploying to a minikube cluster. Closes #22565 from ScrapCodes/spark-25543/k8s/debug-log-spurious-warning. Authored-by: Prashant Sharma Signed-off-by: Dongjoon Hyun --- .../cluster/k8s/ExecutorPodsLifecycleManager.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala index e2800cff7b720..cc254b896249a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala @@ -100,8 +100,11 @@ private[spark] class ExecutorPodsLifecycleManager( } } } - logDebug(s"Removed executors with ids ${execIdsRemovedInThisRound.mkString(",")}" + - s" from Spark that were either found to be deleted or non-existent in the cluster.") + + if (execIdsRemovedInThisRound.nonEmpty) { + logDebug(s"Removed executors with ids ${execIdsRemovedInThisRound.mkString(",")}" + + s" from Spark that were either found to be deleted or non-existent in the cluster.") + } } private def onFinalNonDeletedState( From fb8f4c05657595e089b6812d97dbfee246fce06f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 30 Sep 2018 22:08:04 -0700 Subject: [PATCH 1740/2461] [SPARK-25505][SQL][FOLLOWUP] Fix for attributes cosmetically different in Pivot clause ## What changes were proposed in this pull request? #22519 introduced a bug when the attributes in the pivot clause are cosmetically different from the output ones (eg. different case). In particular, the problem is that the PR used a `Set[Attribute]` instead of an `AttributeSet`. ## How was this patch tested? added UT Closes #22582 from mgaido91/SPARK-25505_followup. Authored-by: Marco Gaido Signed-off-by: gatorsmile --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 3 +-- sql/core/src/test/resources/sql-tests/inputs/pivot.sql | 5 +++-- sql/core/src/test/resources/sql-tests/results/pivot.sql.out | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c0a73083c2d1f..d72e512e0df56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -555,8 +555,7 @@ class Analyzer( } // Group-by expressions coming from SQL are implicit and need to be deduced. val groupByExprs = groupByExprsOpt.getOrElse { - val pivotColAndAggRefs = - (pivotColumn.references ++ aggregates.flatMap(_.references)).toSet + val pivotColAndAggRefs = pivotColumn.references ++ AttributeSet(aggregates) child.output.filterNot(pivotColAndAggRefs.contains) } val singleAgg = aggregates.size == 1 diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql index 81547ab46ce09..c2ecd97e2b02f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql @@ -289,11 +289,12 @@ PIVOT ( ); -- grouping columns output in the same order as input +-- correctly handle pivot columns with different cases SELECT * FROM ( SELECT course, earnings, "a" as a, "z" as z, "b" as b, "y" as y, "c" as c, "x" as x, "d" as d, "w" as w FROM courseSales ) PIVOT ( - sum(earnings) - FOR course IN ('dotNET', 'Java') + sum(Earnings) + FOR Course IN ('dotNET', 'Java') ); diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out index 487883a7f3847..595ce1f8efcd2 100644 --- a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out @@ -484,8 +484,8 @@ SELECT * FROM ( FROM courseSales ) PIVOT ( - sum(earnings) - FOR course IN ('dotNET', 'Java') + sum(Earnings) + FOR Course IN ('dotNET', 'Java') ) -- !query 31 schema struct From 21f0b73dbcd94f9eea8cbc06a024b0e899edaf4c Mon Sep 17 00:00:00 2001 From: seancxmao Date: Sun, 30 Sep 2018 22:49:14 -0700 Subject: [PATCH 1741/2461] [SPARK-25453][SQL][TEST][.FFFFFFFFF] OracleIntegrationSuite IllegalArgumentException: Timestamp format must be yyyy-mm-dd hh:mm:ss ## What changes were proposed in this pull request? This PR aims to fix the failed test of `OracleIntegrationSuite`. ## How was this patch tested? Existing integration tests. Closes #22461 from seancxmao/SPARK-25453. Authored-by: seancxmao Signed-off-by: gatorsmile --- docs/sql-programming-guide.md | 2 +- .../apache/spark/sql/jdbc/OracleIntegrationSuite.scala | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 6de9de90c62c3..a1d7b1108bf73 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1502,7 +1502,7 @@ See the [Apache Avro Data Source Guide](avro-data-source-guide.html). * The JDBC driver class must be visible to the primordial class loader on the client session and on all executors. This is because Java's DriverManager class does a security check that results in it ignoring all drivers not visible to the primordial class loader when one goes to open a connection. One convenient way to do this is to modify compute_classpath.sh on all worker nodes to include your driver JARs. * Some databases, such as H2, convert all names to upper case. You'll need to use upper case to refer to those names in Spark SQL. - + * Users can specify vendor-specific JDBC connection properties in the data source options to do special treatment. For example, `spark.read.format("jdbc").option("url", oracleJdbcUrl).option("oracle.jdbc.mapDateToTimestamp", "false")`. `oracle.jdbc.mapDateToTimestamp` defaults to true, users often need to disable this flag to avoid Oracle date being resolved as timestamp. # Performance Tuning diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 09a2cd83aed6b..70d294d0ca650 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -442,6 +442,12 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo .option("lowerBound", "2018-07-06") .option("upperBound", "2018-07-20") .option("numPartitions", 3) + // oracle.jdbc.mapDateToTimestamp defaults to true. If this flag is not disabled, column d + // (Oracle DATE) will be resolved as Catalyst Timestamp, which will fail bound evaluation of + // the partition column. E.g. 2018-07-06 cannot be evaluated as Timestamp, and the error + // message says: Timestamp format must be yyyy-mm-dd hh:mm:ss[.fffffffff]. + .option("oracle.jdbc.mapDateToTimestamp", "false") + .option("sessionInitStatement", "ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD'") .load() df1.logicalPlan match { @@ -462,6 +468,9 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo .option("lowerBound", "2018-07-04 03:30:00.0") .option("upperBound", "2018-07-27 14:11:05.0") .option("numPartitions", 2) + .option("oracle.jdbc.mapDateToTimestamp", "false") + .option("sessionInitStatement", + "ALTER SESSION SET NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'") .load() df2.logicalPlan match { From 30f5d0f2ddfe56266ea81e4255f9b4f373dab237 Mon Sep 17 00:00:00 2001 From: Aleksandr Koriagin Date: Mon, 1 Oct 2018 17:18:45 +0800 Subject: [PATCH 1742/2461] [SPARK-23401][PYTHON][TESTS] Add more data types for PandasUDFTests ## What changes were proposed in this pull request? Add more data types for Pandas UDF Tests for PySpark SQL ## How was this patch tested? manual tests Closes #22568 from AlexanderKoryagin/new_types_for_pandas_udf_tests. Lead-authored-by: Aleksandr Koriagin Co-authored-by: hyukjinkwon Co-authored-by: Alexander Koryagin Signed-off-by: hyukjinkwon --- python/pyspark/sql/tests.py | 107 ++++++++++++++++++++++++++---------- 1 file changed, 79 insertions(+), 28 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b88a6551f8ae5..815772d23ceea 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -5525,32 +5525,81 @@ def data(self): .withColumn("v", explode(col('vs'))).drop('vs') def test_supported_types(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col - df = self.data.withColumn("arr", array(col("id"))) + from decimal import Decimal + from distutils.version import LooseVersion + import pyarrow as pa + from pyspark.sql.functions import pandas_udf, PandasUDFType - # Different forms of group map pandas UDF, results of these are the same + values = [ + 1, 2, 3, + 4, 5, 1.1, + 2.2, Decimal(1.123), + [1, 2, 2], True, 'hello' + ] + output_fields = [ + ('id', IntegerType()), ('byte', ByteType()), ('short', ShortType()), + ('int', IntegerType()), ('long', LongType()), ('float', FloatType()), + ('double', DoubleType()), ('decim', DecimalType(10, 3)), + ('array', ArrayType(IntegerType())), ('bool', BooleanType()), ('str', StringType()) + ] - output_schema = StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('arr', ArrayType(LongType())), - StructField('v1', DoubleType()), - StructField('v2', LongType())]) + # TODO: Add BinaryType to variables above once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) >= LooseVersion("0.10.0"): + values.append(bytearray([0x01, 0x02])) + output_fields.append(('bin', BinaryType())) + output_schema = StructType([StructField(*x) for x in output_fields]) + df = self.spark.createDataFrame([values], schema=output_schema) + + # Different forms of group map pandas UDF, results of these are the same udf1 = pandas_udf( - lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + lambda pdf: pdf.assign( + byte=pdf.byte * 2, + short=pdf.short * 2, + int=pdf.int * 2, + long=pdf.long * 2, + float=pdf.float * 2, + double=pdf.double * 2, + decim=pdf.decim * 2, + bool=False if pdf.bool else True, + str=pdf.str + 'there', + array=pdf.array, + ), output_schema, PandasUDFType.GROUPED_MAP ) udf2 = pandas_udf( - lambda _, pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + lambda _, pdf: pdf.assign( + byte=pdf.byte * 2, + short=pdf.short * 2, + int=pdf.int * 2, + long=pdf.long * 2, + float=pdf.float * 2, + double=pdf.double * 2, + decim=pdf.decim * 2, + bool=False if pdf.bool else True, + str=pdf.str + 'there', + array=pdf.array, + ), output_schema, PandasUDFType.GROUPED_MAP ) udf3 = pandas_udf( - lambda key, pdf: pdf.assign(id=key[0], v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + lambda key, pdf: pdf.assign( + id=key[0], + byte=pdf.byte * 2, + short=pdf.short * 2, + int=pdf.int * 2, + long=pdf.long * 2, + float=pdf.float * 2, + double=pdf.double * 2, + decim=pdf.decim * 2, + bool=False if pdf.bool else True, + str=pdf.str + 'there', + array=pdf.array, + ), output_schema, PandasUDFType.GROUPED_MAP ) @@ -5714,24 +5763,26 @@ def test_wrong_args(self): pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR)) def test_unsupported_types(self): + from distutils.version import LooseVersion + import pyarrow as pa from pyspark.sql.functions import pandas_udf, PandasUDFType - schema = StructType( - [StructField("id", LongType(), True), - StructField("map", MapType(StringType(), IntegerType()), True)]) - with QuietTest(self.sc): - with self.assertRaisesRegexp( - NotImplementedError, - 'Invalid returnType.*grouped map Pandas UDF.*MapType'): - pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP) - schema = StructType( - [StructField("id", LongType(), True), - StructField("arr_ts", ArrayType(TimestampType()), True)]) - with QuietTest(self.sc): - with self.assertRaisesRegexp( - NotImplementedError, - 'Invalid returnType.*grouped map Pandas UDF.*ArrayType.*TimestampType'): - pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP) + common_err_msg = 'Invalid returnType.*grouped map Pandas UDF.*' + unsupported_types = [ + StructField('map', MapType(StringType(), IntegerType())), + StructField('arr_ts', ArrayType(TimestampType())), + StructField('null', NullType()), + ] + + # TODO: Remove this if-statement once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + unsupported_types.append(StructField('bin', BinaryType())) + + for unsupported_type in unsupported_types: + schema = StructType([StructField('id', LongType(), True), unsupported_type]) + with QuietTest(self.sc): + with self.assertRaisesRegexp(NotImplementedError, common_err_msg): + pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP) # Regression test for SPARK-23314 def test_timestamp_dst(self): From b96fd44f0e91751c1ce3a617cb083bdf880701a1 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 1 Oct 2018 07:32:40 -0700 Subject: [PATCH 1743/2461] [SPARK-25476][SPARK-25510][TEST] Refactor AggregateBenchmark and add a new trait to better support Dataset and DataFrame API ## What changes were proposed in this pull request? This PR does 2 things: 1. Add a new trait(`SqlBasedBenchmark`) to better support Dataset and DataFrame API. 2. Refactor `AggregateBenchmark` to use main method. Generate benchmark result: ``` SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain org.apache.spark.sql.execution.benchmark.AggregateBenchmark" ``` ## How was this patch tested? manual tests Closes #22484 from wangyum/SPARK-25476. Lead-authored-by: Yuming Wang Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../benchmarks/AggregateBenchmark-results.txt | 143 +++ .../benchmark/AggregateBenchmark.scala | 943 ++++++++---------- .../benchmark/SqlBasedBenchmark.scala | 60 ++ 3 files changed, 633 insertions(+), 513 deletions(-) create mode 100644 sql/core/benchmarks/AggregateBenchmark-results.txt create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala diff --git a/sql/core/benchmarks/AggregateBenchmark-results.txt b/sql/core/benchmarks/AggregateBenchmark-results.txt new file mode 100644 index 0000000000000..19e524777692e --- /dev/null +++ b/sql/core/benchmarks/AggregateBenchmark-results.txt @@ -0,0 +1,143 @@ +================================================================================================ +aggregate without grouping +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +agg w/o group: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +agg w/o group wholestage off 65374 / 70665 32.1 31.2 1.0X +agg w/o group wholestage on 1178 / 1209 1779.8 0.6 55.5X + + +================================================================================================ +stat functions +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +stddev: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +stddev wholestage off 8667 / 8851 12.1 82.7 1.0X +stddev wholestage on 1266 / 1273 82.8 12.1 6.8X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +kurtosis: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +kurtosis wholestage off 41218 / 41231 2.5 393.1 1.0X +kurtosis wholestage on 1347 / 1357 77.8 12.8 30.6X + + +================================================================================================ +aggregate with linear keys +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +codegen = F 9309 / 9389 9.0 111.0 1.0X +codegen = T hashmap = F 4417 / 4435 19.0 52.7 2.1X +codegen = T hashmap = T 1289 / 1298 65.1 15.4 7.2X + + +================================================================================================ +aggregate with randomized keys +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +codegen = F 11424 / 11426 7.3 136.2 1.0X +codegen = T hashmap = F 6441 / 6496 13.0 76.8 1.8X +codegen = T hashmap = T 2333 / 2344 36.0 27.8 4.9X + + +================================================================================================ +aggregate with string key +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Aggregate w string key: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +codegen = F 4751 / 4890 4.4 226.5 1.0X +codegen = T hashmap = F 3146 / 3182 6.7 150.0 1.5X +codegen = T hashmap = T 2211 / 2261 9.5 105.4 2.1X + + +================================================================================================ +aggregate with decimal key +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Aggregate w decimal key: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +codegen = F 3029 / 3062 6.9 144.4 1.0X +codegen = T hashmap = F 1534 / 1569 13.7 73.2 2.0X +codegen = T hashmap = T 575 / 578 36.5 27.4 5.3X + + +================================================================================================ +aggregate with multiple key types +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Aggregate w multiple keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +codegen = F 7506 / 7521 2.8 357.9 1.0X +codegen = T hashmap = F 4791 / 4808 4.4 228.5 1.6X +codegen = T hashmap = T 3553 / 3585 5.9 169.4 2.1X + + +================================================================================================ +max function bytecode size of wholestagecodegen +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +max function bytecode size: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +codegen = F 608 / 656 1.1 927.1 1.0X +codegen = T hugeMethodLimit = 10000 402 / 419 1.6 613.5 1.5X +codegen = T hugeMethodLimit = 1500 616 / 619 1.1 939.9 1.0X + + +================================================================================================ +cube +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +cube: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +cube wholestage off 3229 / 3237 1.6 615.9 1.0X +cube wholestage on 1285 / 1306 4.1 245.2 2.5X + + +================================================================================================ +hash and BytesToBytesMap +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +UnsafeRowhash 328 / 330 64.0 15.6 1.0X +murmur3 hash 167 / 167 125.4 8.0 2.0X +fast hash 84 / 85 249.0 4.0 3.9X +arrayEqual 192 / 192 109.3 9.1 1.7X +Java HashMap (Long) 144 / 147 145.9 6.9 2.3X +Java HashMap (two ints) 147 / 153 142.3 7.0 2.2X +Java HashMap (UnsafeRow) 785 / 788 26.7 37.4 0.4X +LongToUnsafeRowMap (opt=false) 456 / 457 46.0 21.8 0.7X +LongToUnsafeRowMap (opt=true) 125 / 125 168.3 5.9 2.6X +BytesToBytesMap (off Heap) 885 / 885 23.7 42.2 0.4X +BytesToBytesMap (on Heap) 860 / 864 24.4 41.0 0.4X +Aggregate HashMap 56 / 56 373.9 2.7 5.8X + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index 57a6fdb800ea4..296ae104a94a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -34,621 +34,538 @@ import org.apache.spark.unsafe.map.BytesToBytesMap /** * Benchmark to measure performance for aggregate primitives. - * To run this: - * build/sbt "sql/test-only *benchmark.AggregateBenchmark" - * - * Benchmarks in this file are skipped in normal builds. + * To run this benchmark: + * {{{ + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/AggregateBenchmark-results.txt". + * }}} */ -class AggregateBenchmark extends BenchmarkWithCodegen { +object AggregateBenchmark extends SqlBasedBenchmark { - ignore("aggregate without grouping") { - val N = 500L << 22 - val benchmark = new Benchmark("agg without grouping", N) - runBenchmark("agg w/o group", N) { - sparkSession.range(N).selectExpr("sum(id)").collect() + override def benchmark(): Unit = { + runBenchmark("aggregate without grouping") { + val N = 500L << 22 + codegenBenchmark("agg w/o group", N) { + spark.range(N).selectExpr("sum(id)").collect() + } } - /* - agg w/o group: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - agg w/o group wholestage off 30136 / 31885 69.6 14.4 1.0X - agg w/o group wholestage on 1851 / 1860 1132.9 0.9 16.3X - */ - } - ignore("stat functions") { - val N = 100L << 20 + runBenchmark("stat functions") { + val N = 100L << 20 - runBenchmark("stddev", N) { - sparkSession.range(N).groupBy().agg("id" -> "stddev").collect() - } + codegenBenchmark("stddev", N) { + spark.range(N).groupBy().agg("id" -> "stddev").collect() + } - runBenchmark("kurtosis", N) { - sparkSession.range(N).groupBy().agg("id" -> "kurtosis").collect() + codegenBenchmark("kurtosis", N) { + spark.range(N).groupBy().agg("id" -> "kurtosis").collect() + } } - /* - Using ImperativeAggregate (as implemented in Spark 1.6): - - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - stddev: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------- - stddev w/o codegen 2019.04 10.39 1.00 X - stddev w codegen 2097.29 10.00 0.96 X - kurtosis w/o codegen 2108.99 9.94 0.96 X - kurtosis w codegen 2090.69 10.03 0.97 X - - Using DeclarativeAggregate: - - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - stddev: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - stddev codegen=false 5630 / 5776 18.0 55.6 1.0X - stddev codegen=true 1259 / 1314 83.0 12.0 4.5X - - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - kurtosis: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - kurtosis codegen=false 14847 / 15084 7.0 142.9 1.0X - kurtosis codegen=true 1652 / 2124 63.0 15.9 9.0X - */ - } - - ignore("aggregate with linear keys") { - val N = 20 << 22 + runBenchmark("aggregate with linear keys") { + val N = 20 << 22 - val benchmark = new Benchmark("Aggregate w keys", N) - def f(): Unit = { - sparkSession.range(N).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() - } + val benchmark = new Benchmark("Aggregate w keys", N, output = output) - benchmark.addCase(s"codegen = F", numIters = 2) { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "false") - f() - } + def f(): Unit = { + spark.range(N).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() + } - benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") - f() - } + benchmark.addCase("codegen = F", numIters = 2) { _ => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + f() + } + } - benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") - f() - } + benchmark.addCase("codegen = T hashmap = F", numIters = 3) { _ => + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", + SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "false", + "spark.sql.codegen.aggregate.map.vectorized.enable" -> "false") { + f() + } + } - benchmark.run() + benchmark.addCase("codegen = T hashmap = T", numIters = 5) { _ => + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", + SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true", + "spark.sql.codegen.aggregate.map.vectorized.enable" -> "true") { + f() + } + } - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + benchmark.run() + } - Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - codegen = F 6619 / 6780 12.7 78.9 1.0X - codegen = T hashmap = F 3935 / 4059 21.3 46.9 1.7X - codegen = T hashmap = T 897 / 971 93.5 10.7 7.4X - */ - } + runBenchmark("aggregate with randomized keys") { + val N = 20 << 22 - ignore("aggregate with randomized keys") { - val N = 20 << 22 + val benchmark = new Benchmark("Aggregate w keys", N, output = output) + spark.range(N).selectExpr("id", "floor(rand() * 10000) as k") + .createOrReplaceTempView("test") - val benchmark = new Benchmark("Aggregate w keys", N) - sparkSession.range(N).selectExpr("id", "floor(rand() * 10000) as k") - .createOrReplaceTempView("test") + def f(): Unit = spark.sql("select k, k, sum(id) from test group by k, k").collect() - def f(): Unit = sparkSession.sql("select k, k, sum(id) from test group by k, k").collect() + benchmark.addCase("codegen = F", numIters = 2) { _ => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + f() + } + } - benchmark.addCase(s"codegen = F", numIters = 2) { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", value = false) - f() - } + benchmark.addCase("codegen = T hashmap = F", numIters = 3) { _ => + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", + SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "false", + "spark.sql.codegen.aggregate.map.vectorized.enable" -> "false") { + f() + } + } - benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") - f() - } + benchmark.addCase("codegen = T hashmap = T", numIters = 5) { _ => + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", + SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true", + "spark.sql.codegen.aggregate.map.vectorized.enable" -> "true") { + f() + } + } - benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") - f() + benchmark.run() } - benchmark.run() - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + runBenchmark("aggregate with string key") { + val N = 20 << 20 - Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - codegen = F 7445 / 7517 11.3 88.7 1.0X - codegen = T hashmap = F 4672 / 4703 18.0 55.7 1.6X - codegen = T hashmap = T 1764 / 1958 47.6 21.0 4.2X - */ - } + val benchmark = new Benchmark("Aggregate w string key", N, output = output) - ignore("aggregate with string key") { - val N = 20 << 20 + def f(): Unit = spark.range(N).selectExpr("id", "cast(id & 1023 as string) as k") + .groupBy("k").count().collect() - val benchmark = new Benchmark("Aggregate w string key", N) - def f(): Unit = sparkSession.range(N).selectExpr("id", "cast(id & 1023 as string) as k") - .groupBy("k").count().collect() + benchmark.addCase("codegen = F", numIters = 2) { _ => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + f() + } + } - benchmark.addCase(s"codegen = F", numIters = 2) { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "false") - f() - } + benchmark.addCase("codegen = T hashmap = F", numIters = 3) { _ => + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", + SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "false", + "spark.sql.codegen.aggregate.map.vectorized.enable" -> "false") { + f() + } + } - benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") - f() - } + benchmark.addCase("codegen = T hashmap = T", numIters = 5) { _ => + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", + SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true", + "spark.sql.codegen.aggregate.map.vectorized.enable" -> "true") { + f() + } + } - benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") - f() + benchmark.run() } - benchmark.run() - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - Aggregate w string key: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - codegen = F 3307 / 3376 6.3 157.7 1.0X - codegen = T hashmap = F 2364 / 2471 8.9 112.7 1.4X - codegen = T hashmap = T 1740 / 1841 12.0 83.0 1.9X - */ - } - - ignore("aggregate with decimal key") { - val N = 20 << 20 - - val benchmark = new Benchmark("Aggregate w decimal key", N) - def f(): Unit = sparkSession.range(N).selectExpr("id", "cast(id & 65535 as decimal) as k") - .groupBy("k").count().collect() - - benchmark.addCase(s"codegen = F") { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "false") - f() - } + runBenchmark("aggregate with decimal key") { + val N = 20 << 20 - benchmark.addCase(s"codegen = T hashmap = F") { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") - f() - } + val benchmark = new Benchmark("Aggregate w decimal key", N, output = output) - benchmark.addCase(s"codegen = T hashmap = T") { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") - f() - } + def f(): Unit = spark.range(N).selectExpr("id", "cast(id & 65535 as decimal) as k") + .groupBy("k").count().collect() - benchmark.run() - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - Aggregate w decimal key: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - codegen = F 2756 / 2817 7.6 131.4 1.0X - codegen = T hashmap = F 1580 / 1647 13.3 75.4 1.7X - codegen = T hashmap = T 641 / 662 32.7 30.6 4.3X - */ - } + benchmark.addCase("codegen = F") { _ => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + f() + } + } - ignore("aggregate with multiple key types") { - val N = 20 << 20 - - val benchmark = new Benchmark("Aggregate w multiple keys", N) - def f(): Unit = sparkSession.range(N) - .selectExpr( - "id", - "(id & 1023) as k1", - "cast(id & 1023 as string) as k2", - "cast(id & 1023 as int) as k3", - "cast(id & 1023 as double) as k4", - "cast(id & 1023 as float) as k5", - "id > 1023 as k6") - .groupBy("k1", "k2", "k3", "k4", "k5", "k6") - .sum() - .collect() - - benchmark.addCase(s"codegen = F") { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "false") - f() - } + benchmark.addCase("codegen = T hashmap = F") { _ => + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", + SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "false", + "spark.sql.codegen.aggregate.map.vectorized.enable" -> "false") { + f() + } + } - benchmark.addCase(s"codegen = T hashmap = F") { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") - f() - } + benchmark.addCase("codegen = T hashmap = T") { _ => + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", + SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true", + "spark.sql.codegen.aggregate.map.vectorized.enable" -> "true") { + f() + } + } - benchmark.addCase(s"codegen = T hashmap = T") { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") - f() + benchmark.run() } - benchmark.run() - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - Aggregate w decimal key: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - codegen = F 5885 / 6091 3.6 280.6 1.0X - codegen = T hashmap = F 3625 / 4009 5.8 172.8 1.6X - codegen = T hashmap = T 3204 / 3271 6.5 152.8 1.8X - */ - } + runBenchmark("aggregate with multiple key types") { + val N = 20 << 20 - ignore("max function bytecode size of wholestagecodegen") { - val N = 20 << 15 - - val benchmark = new Benchmark("max function bytecode size", N) - def f(): Unit = sparkSession.range(N) - .selectExpr( - "id", - "(id & 1023) as k1", - "cast(id & 1023 as double) as k2", - "cast(id & 1023 as int) as k3", - "case when id > 100 and id <= 200 then 1 else 0 end as v1", - "case when id > 200 and id <= 300 then 1 else 0 end as v2", - "case when id > 300 and id <= 400 then 1 else 0 end as v3", - "case when id > 400 and id <= 500 then 1 else 0 end as v4", - "case when id > 500 and id <= 600 then 1 else 0 end as v5", - "case when id > 600 and id <= 700 then 1 else 0 end as v6", - "case when id > 700 and id <= 800 then 1 else 0 end as v7", - "case when id > 800 and id <= 900 then 1 else 0 end as v8", - "case when id > 900 and id <= 1000 then 1 else 0 end as v9", - "case when id > 1000 and id <= 1100 then 1 else 0 end as v10", - "case when id > 1100 and id <= 1200 then 1 else 0 end as v11", - "case when id > 1200 and id <= 1300 then 1 else 0 end as v12", - "case when id > 1300 and id <= 1400 then 1 else 0 end as v13", - "case when id > 1400 and id <= 1500 then 1 else 0 end as v14", - "case when id > 1500 and id <= 1600 then 1 else 0 end as v15", - "case when id > 1600 and id <= 1700 then 1 else 0 end as v16", - "case when id > 1700 and id <= 1800 then 1 else 0 end as v17", - "case when id > 1800 and id <= 1900 then 1 else 0 end as v18") - .groupBy("k1", "k2", "k3") - .sum() - .collect() - - benchmark.addCase("codegen = F") { iter => - sparkSession.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "false") - f() - } + val benchmark = new Benchmark("Aggregate w multiple keys", N, output = output) - benchmark.addCase("codegen = T hugeMethodLimit = 10000") { iter => - sparkSession.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") - sparkSession.conf.set(SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key, "10000") - f() - } + def f(): Unit = spark.range(N) + .selectExpr( + "id", + "(id & 1023) as k1", + "cast(id & 1023 as string) as k2", + "cast(id & 1023 as int) as k3", + "cast(id & 1023 as double) as k4", + "cast(id & 1023 as float) as k5", + "id > 1023 as k6") + .groupBy("k1", "k2", "k3", "k4", "k5", "k6") + .sum() + .collect() - benchmark.addCase("codegen = T hugeMethodLimit = 1500") { iter => - sparkSession.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") - sparkSession.conf.set(SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key, "1500") - f() - } + benchmark.addCase("codegen = F") { _ => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + f() + } + } - benchmark.run() + benchmark.addCase("codegen = T hashmap = F") { _ => + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", + SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "false", + "spark.sql.codegen.aggregate.map.vectorized.enable" -> "false") { + f() + } + } - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_31-b13 on Mac OS X 10.10.2 - Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + benchmark.addCase("codegen = T hashmap = T") { _ => + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", + SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true", + "spark.sql.codegen.aggregate.map.vectorized.enable" -> "true") { + f() + } + } - max function bytecode size: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - codegen = F 709 / 803 0.9 1082.1 1.0X - codegen = T hugeMethodLimit = 10000 3485 / 3548 0.2 5317.7 0.2X - codegen = T hugeMethodLimit = 1500 636 / 701 1.0 969.9 1.1X - */ - } + benchmark.run() + } + + runBenchmark("max function bytecode size of wholestagecodegen") { + val N = 20 << 15 + + val benchmark = new Benchmark("max function bytecode size", N, output = output) + + def f(): Unit = spark.range(N) + .selectExpr( + "id", + "(id & 1023) as k1", + "cast(id & 1023 as double) as k2", + "cast(id & 1023 as int) as k3", + "case when id > 100 and id <= 200 then 1 else 0 end as v1", + "case when id > 200 and id <= 300 then 1 else 0 end as v2", + "case when id > 300 and id <= 400 then 1 else 0 end as v3", + "case when id > 400 and id <= 500 then 1 else 0 end as v4", + "case when id > 500 and id <= 600 then 1 else 0 end as v5", + "case when id > 600 and id <= 700 then 1 else 0 end as v6", + "case when id > 700 and id <= 800 then 1 else 0 end as v7", + "case when id > 800 and id <= 900 then 1 else 0 end as v8", + "case when id > 900 and id <= 1000 then 1 else 0 end as v9", + "case when id > 1000 and id <= 1100 then 1 else 0 end as v10", + "case when id > 1100 and id <= 1200 then 1 else 0 end as v11", + "case when id > 1200 and id <= 1300 then 1 else 0 end as v12", + "case when id > 1300 and id <= 1400 then 1 else 0 end as v13", + "case when id > 1400 and id <= 1500 then 1 else 0 end as v14", + "case when id > 1500 and id <= 1600 then 1 else 0 end as v15", + "case when id > 1600 and id <= 1700 then 1 else 0 end as v16", + "case when id > 1700 and id <= 1800 then 1 else 0 end as v17", + "case when id > 1800 and id <= 1900 then 1 else 0 end as v18") + .groupBy("k1", "k2", "k3") + .sum() + .collect() + + benchmark.addCase("codegen = F") { _ => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + f() + } + } + benchmark.addCase("codegen = T hugeMethodLimit = 10000") { _ => + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", + SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key -> "10000") { + f() + } + } - ignore("cube") { - val N = 5 << 20 + benchmark.addCase("codegen = T hugeMethodLimit = 1500") { _ => + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", + SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key -> "1500") { + f() + } + } - runBenchmark("cube", N) { - sparkSession.range(N).selectExpr("id", "id % 1000 as k1", "id & 256 as k2") - .cube("k1", "k2").sum("id").collect() + benchmark.run() } - /** - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - cube: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - cube codegen=false 3188 / 3392 1.6 608.2 1.0X - cube codegen=true 1239 / 1394 4.2 236.3 2.6X - */ - } - - ignore("hash and BytesToBytesMap") { - val N = 20 << 20 - val benchmark = new Benchmark("BytesToBytesMap", N) + runBenchmark("cube") { + val N = 5 << 20 - benchmark.addCase("UnsafeRowhash") { iter => - var i = 0 - val keyBytes = new Array[Byte](16) - val key = new UnsafeRow(1) - key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) - var s = 0 - while (i < N) { - key.setInt(0, i % 1000) - val h = Murmur3_x86_32.hashUnsafeWords( - key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 42) - s += h - i += 1 + codegenBenchmark("cube", N) { + spark.range(N).selectExpr("id", "id % 1000 as k1", "id & 256 as k2") + .cube("k1", "k2").sum("id").collect() } } - benchmark.addCase("murmur3 hash") { iter => - var i = 0 - val keyBytes = new Array[Byte](16) - val key = new UnsafeRow(1) - key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) - var p = 524283 - var s = 0 - while (i < N) { - var h = Murmur3_x86_32.hashLong(i, 42) - key.setInt(0, h) - s += h - i += 1 - } - } + runBenchmark("hash and BytesToBytesMap") { + val N = 20 << 20 - benchmark.addCase("fast hash") { iter => - var i = 0 - val keyBytes = new Array[Byte](16) - val key = new UnsafeRow(1) - key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) - var p = 524283 - var s = 0 - while (i < N) { - var h = i % p - if (h < 0) { - h += p - } - key.setInt(0, h) - s += h - i += 1 - } - } + val benchmark = new Benchmark("BytesToBytesMap", N, output = output) - benchmark.addCase("arrayEqual") { iter => - var i = 0 - val keyBytes = new Array[Byte](16) - val valueBytes = new Array[Byte](16) - val key = new UnsafeRow(1) - key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) - val value = new UnsafeRow(1) - value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) - value.setInt(0, 555) - var s = 0 - while (i < N) { - key.setInt(0, i % 1000) - if (key.equals(value)) { - s += 1 - } - i += 1 + benchmark.addCase("UnsafeRowhash") { _ => + var i = 0 + val keyBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var s = 0 + while (i < N) { + key.setInt(0, i % 1000) + val h = Murmur3_x86_32.hashUnsafeWords( + key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 42) + s += h + i += 1 + } } - } - benchmark.addCase("Java HashMap (Long)") { iter => - var i = 0 - val keyBytes = new Array[Byte](16) - val valueBytes = new Array[Byte](16) - val value = new UnsafeRow(1) - value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) - value.setInt(0, 555) - val map = new HashMap[Long, UnsafeRow]() - while (i < 65536) { - value.setInt(0, i) - map.put(i.toLong, value) - i += 1 - } - var s = 0 - i = 0 - while (i < N) { - if (map.get(i % 100000) != null) { - s += 1 - } - i += 1 + benchmark.addCase("murmur3 hash") { _ => + var i = 0 + val keyBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var p = 524283 + var s = 0 + while (i < N) { + var h = Murmur3_x86_32.hashLong(i, 42) + key.setInt(0, h) + s += h + i += 1 + } } - } - benchmark.addCase("Java HashMap (two ints) ") { iter => - var i = 0 - val valueBytes = new Array[Byte](16) - val value = new UnsafeRow(1) - value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) - value.setInt(0, 555) - val map = new HashMap[Long, UnsafeRow]() - while (i < 65536) { - value.setInt(0, i) - val key = (i.toLong << 32) + Integer.rotateRight(i, 15) - map.put(key, value) - i += 1 - } - var s = 0 - i = 0 - while (i < N) { - val key = ((i & 100000).toLong << 32) + Integer.rotateRight(i & 100000, 15) - if (map.get(key) != null) { - s += 1 - } - i += 1 + benchmark.addCase("fast hash") { _ => + var i = 0 + val keyBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var p = 524283 + var s = 0 + while (i < N) { + var h = i % p + if (h < 0) { + h += p + } + key.setInt(0, h) + s += h + i += 1 + } } - } - benchmark.addCase("Java HashMap (UnsafeRow)") { iter => - var i = 0 - val keyBytes = new Array[Byte](16) - val valueBytes = new Array[Byte](16) - val key = new UnsafeRow(1) - key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) - val value = new UnsafeRow(1) - value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) - value.setInt(0, 555) - val map = new HashMap[UnsafeRow, UnsafeRow]() - while (i < 65536) { - key.setInt(0, i) - value.setInt(0, i) - map.put(key, value.copy()) - i += 1 - } - var s = 0 - i = 0 - while (i < N) { - key.setInt(0, i % 100000) - if (map.get(key) != null) { - s += 1 - } - i += 1 + benchmark.addCase("arrayEqual") { _ => + var i = 0 + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + var s = 0 + while (i < N) { + key.setInt(0, i % 1000) + if (key.equals(value)) { + s += 1 + } + i += 1 + } } - } - Seq(false, true).foreach { optimized => - benchmark.addCase(s"LongToUnsafeRowMap (opt=$optimized)") { iter => + benchmark.addCase("Java HashMap (Long)") { _ => var i = 0 + val keyBytes = new Array[Byte](16) val valueBytes = new Array[Byte](16) val value = new UnsafeRow(1) value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) value.setInt(0, 555) - val taskMemoryManager = new TaskMemoryManager( - new StaticMemoryManager( - new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), - Long.MaxValue, - Long.MaxValue, - 1), - 0) - val map = new LongToUnsafeRowMap(taskMemoryManager, 64) + val map = new HashMap[Long, UnsafeRow]() while (i < 65536) { value.setInt(0, i) - val key = i % 100000 - map.append(key, value) + map.put(i.toLong, value) i += 1 } - if (optimized) { - map.optimize() + var s = 0 + i = 0 + while (i < N) { + if (map.get(i % 100000) != null) { + s += 1 + } + i += 1 + } + } + + benchmark.addCase("Java HashMap (two ints) ") { _ => + var i = 0 + val valueBytes = new Array[Byte](16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val map = new HashMap[Long, UnsafeRow]() + while (i < 65536) { + value.setInt(0, i) + val key = (i.toLong << 32) + Integer.rotateRight(i, 15) + map.put(key, value) + i += 1 } var s = 0 i = 0 while (i < N) { - val key = i % 100000 - if (map.getValue(key, value) != null) { + val key = ((i & 100000).toLong << 32) + Integer.rotateRight(i & 100000, 15) + if (map.get(key) != null) { s += 1 } i += 1 } } - } - Seq("off", "on").foreach { heap => - benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter => - val taskMemoryManager = new TaskMemoryManager( - new StaticMemoryManager( - new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, s"${heap == "off"}") - .set(MEMORY_OFFHEAP_SIZE.key, "102400000"), - Long.MaxValue, - Long.MaxValue, - 1), - 0) - val map = new BytesToBytesMap(taskMemoryManager, 1024, 64L<<20) + benchmark.addCase("Java HashMap (UnsafeRow)") { _ => + var i = 0 val keyBytes = new Array[Byte](16) val valueBytes = new Array[Byte](16) val key = new UnsafeRow(1) key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) val value = new UnsafeRow(1) value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) - var i = 0 - val numKeys = 65536 - while (i < numKeys) { - key.setInt(0, i % 65536) - val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, - Murmur3_x86_32.hashLong(i % 65536, 42)) - if (!loc.isDefined) { - loc.append(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, - value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) - } + value.setInt(0, 555) + val map = new HashMap[UnsafeRow, UnsafeRow]() + while (i < 65536) { + key.setInt(0, i) + value.setInt(0, i) + map.put(key, value.copy()) i += 1 } - i = 0 var s = 0 + i = 0 while (i < N) { key.setInt(0, i % 100000) - val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, - Murmur3_x86_32.hashLong(i % 100000, 42)) - if (loc.isDefined) { + if (map.get(key) != null) { s += 1 } i += 1 } } - } - benchmark.addCase("Aggregate HashMap") { iter => - var i = 0 - val numKeys = 65536 - val schema = new StructType() - .add("key", LongType) - .add("value", LongType) - val map = new AggregateHashMap(schema) - while (i < numKeys) { - val row = map.findOrInsert(i.toLong) - row.setLong(1, row.getLong(1) + 1) - i += 1 - } - var s = 0 - i = 0 - while (i < N) { - if (map.find(i % 100000) != -1) { - s += 1 - } - i += 1 + Seq(false, true).foreach { optimized => + benchmark.addCase(s"LongToUnsafeRowMap (opt=$optimized)") { _ => + var i = 0 + val valueBytes = new Array[Byte](16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val map = new LongToUnsafeRowMap(taskMemoryManager, 64) + while (i < 65536) { + value.setInt(0, i) + val key = i % 100000 + map.append(key, value) + i += 1 + } + if (optimized) { + map.optimize() + } + var s = 0 + i = 0 + while (i < N) { + val key = i % 100000 + if (map.getValue(key, value) != null) { + s += 1 + } + i += 1 + } + } + } + + Seq("off", "on").foreach { heap => + benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { _ => + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, s"${heap == "off"}") + .set(MEMORY_OFFHEAP_SIZE.key, "102400000"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val map = new BytesToBytesMap(taskMemoryManager, 1024, 64L << 20) + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var i = 0 + val numKeys = 65536 + while (i < numKeys) { + key.setInt(0, i % 65536) + val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + Murmur3_x86_32.hashLong(i % 65536, 42)) + if (!loc.isDefined) { + loc.append(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) + } + i += 1 + } + i = 0 + var s = 0 + while (i < N) { + key.setInt(0, i % 100000) + val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + Murmur3_x86_32.hashLong(i % 100000, 42)) + if (loc.isDefined) { + s += 1 + } + i += 1 + } + } } - } - /* - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - UnsafeRow hash 267 / 284 78.4 12.8 1.0X - murmur3 hash 102 / 129 205.5 4.9 2.6X - fast hash 79 / 96 263.8 3.8 3.4X - arrayEqual 164 / 172 128.2 7.8 1.6X - Java HashMap (Long) 321 / 399 65.4 15.3 0.8X - Java HashMap (two ints) 328 / 363 63.9 15.7 0.8X - Java HashMap (UnsafeRow) 1140 / 1200 18.4 54.3 0.2X - LongToUnsafeRowMap (opt=false) 378 / 400 55.5 18.0 0.7X - LongToUnsafeRowMap (opt=true) 144 / 152 145.2 6.9 1.9X - BytesToBytesMap (off Heap) 1300 / 1616 16.1 62.0 0.2X - BytesToBytesMap (on Heap) 1165 / 1202 18.0 55.5 0.2X - Aggregate HashMap 121 / 131 173.3 5.8 2.2X - */ - benchmark.run() + benchmark.addCase("Aggregate HashMap") { _ => + var i = 0 + val numKeys = 65536 + val schema = new StructType() + .add("key", LongType) + .add("value", LongType) + val map = new AggregateHashMap(schema) + while (i < numKeys) { + val row = map.findOrInsert(i.toLong) + row.setLong(1, row.getLong(1) + 1) + i += 1 + } + var s = 0 + i = 0 + while (i < N) { + if (map.find(i % 100000) != -1) { + s += 1 + } + i += 1 + } + } + benchmark.run() + } } - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala new file mode 100644 index 0000000000000..e95e5a960246b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.internal.SQLConf + +/** + * Common base trait to run benchmark with the Dataset and DataFrame API. + */ +trait SqlBasedBenchmark extends BenchmarkBase with SQLHelper { + + protected val spark: SparkSession = getSparkSession + + /** Subclass can override this function to build their own SparkSession */ + def getSparkSession: SparkSession = { + SparkSession.builder() + .master("local[1]") + .appName(this.getClass.getCanonicalName) + .config(SQLConf.SHUFFLE_PARTITIONS.key, 1) + .config(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, 1) + .getOrCreate() + } + + /** Runs function `f` with whole stage codegen on and off. */ + final def codegenBenchmark(name: String, cardinality: Long)(f: => Unit): Unit = { + val benchmark = new Benchmark(name, cardinality, output = output) + + benchmark.addCase(s"$name wholestage off", numIters = 2) { _ => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + f + } + } + + benchmark.addCase(s"$name wholestage on", numIters = 5) { _ => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") { + f + } + } + + benchmark.run() + } +} From a802c69b130b69a35b372ffe1b01289577f6fafb Mon Sep 17 00:00:00 2001 From: "marek.simunek" Date: Mon, 1 Oct 2018 11:04:37 -0500 Subject: [PATCH 1744/2461] [SPARK-18364][YARN] Expose metrics for YarnShuffleService ## What changes were proposed in this pull request? This PR is follow-up of closed https://github.com/apache/spark/pull/17401 which only ended due to of inactivity, but its still nice feature to have. Given review by jerryshao taken in consideration and edited: - VisibleForTesting deleted because of dependency conflicts - removed unnecessary reflection for `MetricsSystemImpl` - added more available types for gauge ## How was this patch tested? Manual deploy of new yarn-shuffle jar into a Node Manager and verifying that the metrics appear in the Node Manager-standard location. This is JMX with an query endpoint running on `hostname:port` Resulting metrics look like this: ``` curl -sk -XGET hostname:port | grep -v '#' | grep 'shuffleService' hadoop_nodemanager_openblockrequestlatencymillis_rate15{name="shuffleService",} 0.31428910657834713 hadoop_nodemanager_blocktransferratebytes_rate15{name="shuffleService",} 566144.9983653595 hadoop_nodemanager_blocktransferratebytes_ratemean{name="shuffleService",} 2464409.9678099006 hadoop_nodemanager_openblockrequestlatencymillis_rate1{name="shuffleService",} 1.2893844732240272 hadoop_nodemanager_registeredexecutorssize{name="shuffleService",} 2.0 hadoop_nodemanager_openblockrequestlatencymillis_ratemean{name="shuffleService",} 1.255574678369966 hadoop_nodemanager_openblockrequestlatencymillis_count{name="shuffleService",} 315.0 hadoop_nodemanager_openblockrequestlatencymillis_rate5{name="shuffleService",} 0.7661929192569739 hadoop_nodemanager_registerexecutorrequestlatencymillis_ratemean{name="shuffleService",} 0.0 hadoop_nodemanager_registerexecutorrequestlatencymillis_count{name="shuffleService",} 0.0 hadoop_nodemanager_registerexecutorrequestlatencymillis_rate1{name="shuffleService",} 0.0 hadoop_nodemanager_registerexecutorrequestlatencymillis_rate5{name="shuffleService",} 0.0 hadoop_nodemanager_blocktransferratebytes_count{name="shuffleService",} 6.18271213E8 hadoop_nodemanager_registerexecutorrequestlatencymillis_rate15{name="shuffleService",} 0.0 hadoop_nodemanager_blocktransferratebytes_rate5{name="shuffleService",} 1154114.4881816586 hadoop_nodemanager_blocktransferratebytes_rate1{name="shuffleService",} 574745.0749848988 ``` Closes #22485 from mareksimunek/SPARK-18364. Lead-authored-by: marek.simunek Co-authored-by: Andrew Ash Signed-off-by: Thomas Graves --- .../network/yarn/YarnShuffleService.java | 11 ++ .../yarn/YarnShuffleServiceMetrics.java | 137 ++++++++++++++++++ .../yarn/YarnShuffleServiceMetricsSuite.scala | 73 ++++++++++ 3 files changed, 221 insertions(+) create mode 100644 common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleServiceMetrics.java create mode 100644 resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index d8b2ed6b5dc7b..72ae1a1295236 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -35,6 +35,8 @@ import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.fs.permission.FsPermission; +import org.apache.hadoop.metrics2.impl.MetricsSystemImpl; +import org.apache.hadoop.metrics2.lib.DefaultMetricsSystem; import org.apache.hadoop.yarn.api.records.ContainerId; import org.apache.hadoop.yarn.server.api.*; import org.apache.spark.network.util.LevelDBProvider; @@ -168,6 +170,15 @@ protected void serviceInit(Configuration conf) throws Exception { TransportConf transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf)); blockHandler = new ExternalShuffleBlockHandler(transportConf, registeredExecutorFile); + // register metrics on the block handler into the Node Manager's metrics system. + YarnShuffleServiceMetrics serviceMetrics = + new YarnShuffleServiceMetrics(blockHandler.getAllMetrics()); + + MetricsSystemImpl metricsSystem = (MetricsSystemImpl) DefaultMetricsSystem.instance(); + metricsSystem.register( + "sparkShuffleService", "Metrics on the Spark Shuffle Service", serviceMetrics); + logger.info("Registered metrics with Hadoop's DefaultMetricsSystem"); + // If authentication is enabled, set up the shuffle server to use a // special RPC handler that filters out unauthenticated fetch requests List bootstraps = Lists.newArrayList(); diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleServiceMetrics.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleServiceMetrics.java new file mode 100644 index 0000000000000..3e4d479b862b3 --- /dev/null +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleServiceMetrics.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.yarn; + +import java.util.Map; + +import com.codahale.metrics.*; +import org.apache.hadoop.metrics2.MetricsCollector; +import org.apache.hadoop.metrics2.MetricsInfo; +import org.apache.hadoop.metrics2.MetricsRecordBuilder; +import org.apache.hadoop.metrics2.MetricsSource; + +/** + * Forward {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler.ShuffleMetrics} + * to hadoop metrics system. + * NodeManager by default exposes JMX endpoint where can be collected. + */ +class YarnShuffleServiceMetrics implements MetricsSource { + + private final MetricSet metricSet; + + YarnShuffleServiceMetrics(MetricSet metricSet) { + this.metricSet = metricSet; + } + + /** + * Get metrics from the source + * + * @param collector to contain the resulting metrics snapshot + * @param all if true, return all metrics even if unchanged. + */ + @Override + public void getMetrics(MetricsCollector collector, boolean all) { + MetricsRecordBuilder metricsRecordBuilder = collector.addRecord("sparkShuffleService"); + + for (Map.Entry entry : metricSet.getMetrics().entrySet()) { + collectMetric(metricsRecordBuilder, entry.getKey(), entry.getValue()); + } + } + + /** + * The metric types used in + * {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler.ShuffleMetrics}. + * Visible for testing. + */ + public static void collectMetric( + MetricsRecordBuilder metricsRecordBuilder, String name, Metric metric) { + + if (metric instanceof Timer) { + Timer t = (Timer) metric; + metricsRecordBuilder + .addCounter(new ShuffleServiceMetricsInfo(name + "_count", "Count of timer " + name), + t.getCount()) + .addGauge( + new ShuffleServiceMetricsInfo(name + "_rate15", "15 minute rate of timer " + name), + t.getFifteenMinuteRate()) + .addGauge( + new ShuffleServiceMetricsInfo(name + "_rate5", "5 minute rate of timer " + name), + t.getFiveMinuteRate()) + .addGauge( + new ShuffleServiceMetricsInfo(name + "_rate1", "1 minute rate of timer " + name), + t.getOneMinuteRate()) + .addGauge(new ShuffleServiceMetricsInfo(name + "_rateMean", "Mean rate of timer " + name), + t.getMeanRate()); + } else if (metric instanceof Meter) { + Meter m = (Meter) metric; + metricsRecordBuilder + .addCounter(new ShuffleServiceMetricsInfo(name + "_count", "Count of meter " + name), + m.getCount()) + .addGauge( + new ShuffleServiceMetricsInfo(name + "_rate15", "15 minute rate of meter " + name), + m.getFifteenMinuteRate()) + .addGauge( + new ShuffleServiceMetricsInfo(name + "_rate5", "5 minute rate of meter " + name), + m.getFiveMinuteRate()) + .addGauge( + new ShuffleServiceMetricsInfo(name + "_rate1", "1 minute rate of meter " + name), + m.getOneMinuteRate()) + .addGauge(new ShuffleServiceMetricsInfo(name + "_rateMean", "Mean rate of meter " + name), + m.getMeanRate()); + } else if (metric instanceof Gauge) { + final Object gaugeValue = ((Gauge) metric).getValue(); + if (gaugeValue instanceof Integer) { + metricsRecordBuilder.addGauge(getShuffleServiceMetricsInfo(name), (Integer) gaugeValue); + } else if (gaugeValue instanceof Long) { + metricsRecordBuilder.addGauge(getShuffleServiceMetricsInfo(name), (Long) gaugeValue); + } else if (gaugeValue instanceof Float) { + metricsRecordBuilder.addGauge(getShuffleServiceMetricsInfo(name), (Float) gaugeValue); + } else if (gaugeValue instanceof Double) { + metricsRecordBuilder.addGauge(getShuffleServiceMetricsInfo(name), (Double) gaugeValue); + } else { + throw new IllegalStateException( + "Not supported class type of metric[" + name + "] for value " + gaugeValue); + } + } + } + + private static MetricsInfo getShuffleServiceMetricsInfo(String name) { + return new ShuffleServiceMetricsInfo(name, "Value of gauge " + name); + } + + private static class ShuffleServiceMetricsInfo implements MetricsInfo { + + private final String name; + private final String description; + + ShuffleServiceMetricsInfo(String name, String description) { + this.name = name; + this.description = description; + } + + @Override + public String name() { + return name; + } + + @Override + public String description() { + return description; + } + } +} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala new file mode 100644 index 0000000000000..40b92282a3b8f --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.yarn + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.metrics2.MetricsRecordBuilder +import org.mockito.Matchers._ +import org.mockito.Mockito.{mock, times, verify, when} +import org.scalatest.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.network.server.OneForOneStreamManager +import org.apache.spark.network.shuffle.{ExternalShuffleBlockHandler, ExternalShuffleBlockResolver} + +class YarnShuffleServiceMetricsSuite extends SparkFunSuite with Matchers { + + val streamManager = mock(classOf[OneForOneStreamManager]) + val blockResolver = mock(classOf[ExternalShuffleBlockResolver]) + when(blockResolver.getRegisteredExecutorsSize).thenReturn(42) + + val metrics = new ExternalShuffleBlockHandler(streamManager, blockResolver).getAllMetrics + + test("metrics named as expected") { + val allMetrics = Set( + "openBlockRequestLatencyMillis", "registerExecutorRequestLatencyMillis", + "blockTransferRateBytes", "registeredExecutorsSize") + + metrics.getMetrics.keySet().asScala should be (allMetrics) + } + + // these three metrics have the same effect on the collector + for (testname <- Seq("openBlockRequestLatencyMillis", + "registerExecutorRequestLatencyMillis", + "blockTransferRateBytes")) { + test(s"$testname - collector receives correct types") { + val builder = mock(classOf[MetricsRecordBuilder]) + when(builder.addCounter(any(), anyLong())).thenReturn(builder) + when(builder.addGauge(any(), anyDouble())).thenReturn(builder) + + YarnShuffleServiceMetrics.collectMetric(builder, testname, + metrics.getMetrics.get(testname)) + + verify(builder).addCounter(anyObject(), anyLong()) + verify(builder, times(4)).addGauge(anyObject(), anyDouble()) + } + } + + // this metric writes only one gauge to the collector + test("registeredExecutorsSize - collector receives correct types") { + val builder = mock(classOf[MetricsRecordBuilder]) + + YarnShuffleServiceMetrics.collectMetric(builder, "registeredExecutorsSize", + metrics.getMetrics.get("registeredExecutorsSize")) + + // only one + verify(builder).addGauge(anyObject(), anyInt()) + } +} From 3422fc0b6cfffb5834ce94024167458a67f0a01f Mon Sep 17 00:00:00 2001 From: Shahid Date: Mon, 1 Oct 2018 17:45:12 -0500 Subject: [PATCH 1745/2461] [SPARK-25575][WEBUI][SQL] SQL tab in the spark UI support hide tables, to make it consistent with other tabs. ## What changes were proposed in this pull request? Currently, SQL tab in the WEBUI doesn't support hiding table. Other tabs in the web ui like, Jobs, stages etc supports hiding table (refer SPARK-23024 https://github.com/apache/spark/pull/20216). In this PR, added the support for hide table in the sql tab also. ## How was this patch tested? bin/spark-shell ``` sql("create table a (id int)") for(i <- 1 to 100) sql(s"insert into a values ($i)") ``` Open SQL tab in the web UI **Before fix:** ![image](https://user-images.githubusercontent.com/23054875/46249137-f5c44880-c441-11e8-953a-a811e33ac24d.png) **After fix:** Consistent with the other tabs. ![screenshot from 2018-09-30 00-11-28](https://user-images.githubusercontent.com/23054875/46249354-75074b80-c445-11e8-9417-28751fd8628a.png) (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22592 from shahidki31/SPARK-25575. Authored-by: Shahid Signed-off-by: Sean Owen --- .../org/apache/spark/ui/static/webui.js | 3 + .../sql/execution/ui/AllExecutionsPage.scala | 65 +++++++++++++------ 2 files changed, 47 insertions(+), 21 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index f01c567ba58ad..12c056af9a51a 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -83,4 +83,7 @@ $(function() { collapseTablePageLoad('collapse-aggregated-rdds','aggregated-rdds'); collapseTablePageLoad('collapse-aggregated-activeBatches','aggregated-activeBatches'); collapseTablePageLoad('collapse-aggregated-completedBatches','aggregated-completedBatches'); + collapseTablePageLoad('collapse-aggregated-runningExecutions','runningExecutions'); + collapseTablePageLoad('collapse-aggregated-completedExecutions','completedExecutions'); + collapseTablePageLoad('collapse-aggregated-failedExecutions','failedExecutions'); }); \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index a7a24ac3641b5..1b2d8a821b364 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -55,24 +55,57 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L val _content = mutable.ListBuffer[Node]() if (running.nonEmpty) { + val runningPageTable = new RunningExecutionTable( + parent, currentTime, running.sortBy(_.submissionTime).reverse).toNodeSeq(request) + _content ++= - new RunningExecutionTable( - parent, s"Running Queries (${running.size})", currentTime, - running.sortBy(_.submissionTime).reverse).toNodeSeq(request) + +

      + + Running Queries ({running.size}) +

      +
      ++ +
      + {runningPageTable} +
      } if (completed.nonEmpty) { + val completedPageTable = new CompletedExecutionTable( + parent, currentTime, completed.sortBy(_.submissionTime).reverse).toNodeSeq(request) + _content ++= - new CompletedExecutionTable( - parent, s"Completed Queries (${completed.size})", currentTime, - completed.sortBy(_.submissionTime).reverse).toNodeSeq(request) + +

      + + Completed Queries ({completed.size}) +

      +
      ++ +
      + {completedPageTable} +
      } if (failed.nonEmpty) { + val failedPageTable = new FailedExecutionTable( + parent, currentTime, failed.sortBy(_.submissionTime).reverse).toNodeSeq(request) + _content ++= - new FailedExecutionTable( - parent, s"Failed Queries (${failed.size})", currentTime, - failed.sortBy(_.submissionTime).reverse).toNodeSeq(request) + +

      + + Failed Queries ({failed.size}) +

      +
      ++ +
      + {failedPageTable} +
      } _content } @@ -118,7 +151,6 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L private[ui] abstract class ExecutionTable( parent: SQLTab, tableId: String, - tableName: String, currentTime: Long, executionUIDatas: Seq[SQLExecutionUIData], showRunningJobs: Boolean, @@ -206,11 +238,8 @@ private[ui] abstract class ExecutionTable( } def toNodeSeq(request: HttpServletRequest): Seq[Node] = { -
      -

      {tableName}

      - {UIUtils.listingTable[SQLExecutionUIData]( - header, row(request, currentTime, _), executionUIDatas, id = Some(tableId))} -
      + UIUtils.listingTable[SQLExecutionUIData]( + header, row(request, currentTime, _), executionUIDatas, id = Some(tableId)) } private def jobURL(request: HttpServletRequest, jobId: Long): String = @@ -223,13 +252,11 @@ private[ui] abstract class ExecutionTable( private[ui] class RunningExecutionTable( parent: SQLTab, - tableName: String, currentTime: Long, executionUIDatas: Seq[SQLExecutionUIData]) extends ExecutionTable( parent, "running-execution-table", - tableName, currentTime, executionUIDatas, showRunningJobs = true, @@ -242,13 +269,11 @@ private[ui] class RunningExecutionTable( private[ui] class CompletedExecutionTable( parent: SQLTab, - tableName: String, currentTime: Long, executionUIDatas: Seq[SQLExecutionUIData]) extends ExecutionTable( parent, "completed-execution-table", - tableName, currentTime, executionUIDatas, showRunningJobs = false, @@ -260,13 +285,11 @@ private[ui] class CompletedExecutionTable( private[ui] class FailedExecutionTable( parent: SQLTab, - tableName: String, currentTime: Long, executionUIDatas: Seq[SQLExecutionUIData]) extends ExecutionTable( parent, "failed-execution-table", - tableName, currentTime, executionUIDatas, showRunningJobs = false, From 5114db5781967c1e8046296905d97560187479fb Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 1 Oct 2018 21:35:12 -0500 Subject: [PATCH 1746/2461] [SPARK-25578][BUILD] Update to Scala 2.12.7 ## What changes were proposed in this pull request? Update to Scala 2.12.7. See https://issues.apache.org/jira/browse/SPARK-25578 for why. ## How was this patch tested? Existing tests. Closes #22600 from srowen/SPARK-25578. Authored-by: Sean Owen Signed-off-by: Sean Owen --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 099a08185d2a5..172a6ce87eac6 100644 --- a/pom.xml +++ b/pom.xml @@ -2760,7 +2760,7 @@ scala-2.12 - 2.12.6 + 2.12.7 2.12 From 71876633f3af706408355b5fb561b58dbc593360 Mon Sep 17 00:00:00 2001 From: Shahid Date: Tue, 2 Oct 2018 08:05:09 -0700 Subject: [PATCH 1747/2461] [SPARK-25583][DOC] Add history-server related configuration in the documentation. ## What changes were proposed in this pull request? Add history-server related configuration in the documentation. Some of the history server related configurations were missing in the documentation.Like, 'spark.history.store.maxDiskUsage', 'spark.ui.liveUpdate.period' etc. ## How was this patch tested? ![screenshot from 2018-10-01 20-58-26](https://user-images.githubusercontent.com/23054875/46298568-04833a80-c5bd-11e8-95b8-54c9d6582fd2.png) ![screenshot from 2018-10-01 20-59-31](https://user-images.githubusercontent.com/23054875/46298591-11a02980-c5bd-11e8-93d0-892afdfd4f9a.png) ![screenshot from 2018-10-01 20-59-45](https://user-images.githubusercontent.com/23054875/46298601-1533b080-c5bd-11e8-9689-e9b39882a7b5.png) Closes #22601 from shahidki31/historyConf. Authored-by: Shahid Signed-off-by: Dongjoon Hyun --- docs/configuration.md | 16 ++++++++++++++++ docs/monitoring.md | 25 +++++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index 782ccff667076..55773937d4d71 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -793,6 +793,13 @@ Apart from these, the following properties are also available, and may be useful Buffer size to use when writing to output streams, in KiB unless otherwise specified.
      + + + + + @@ -807,6 +814,15 @@ Apart from these, the following properties are also available, and may be useful Allows jobs and stages to be killed from the web UI. + + + + + diff --git a/docs/monitoring.md b/docs/monitoring.md index f6d52ef4597e9..69bf3082f0f27 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -185,6 +185,23 @@ Security options for the Spark History Server are covered more detail in the Job history files older than this will be deleted when the filesystem history cleaner runs. + + + + + + + + + + @@ -192,6 +209,14 @@ Security options for the Spark History Server are covered more detail in the Number of threads that will be used by history server to process event logs. + + + + + From 9bf397c0e45cb161f3f12f09bd2bf14ff96dc823 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 2 Oct 2018 08:48:24 -0700 Subject: [PATCH 1748/2461] [SPARK-25592] Setting version to 3.0.0-SNAPSHOT ## What changes were proposed in this pull request? This patch is to bump the master branch version to 3.0.0-SNAPSHOT. ## How was this patch tested? N/A Closes #22606 from gatorsmile/bump3.0. Authored-by: gatorsmile Signed-off-by: gatorsmile --- R/pkg/DESCRIPTION | 2 +- assembly/pom.xml | 2 +- common/kvstore/pom.xml | 2 +- common/network-common/pom.xml | 2 +- common/network-shuffle/pom.xml | 2 +- common/network-yarn/pom.xml | 2 +- common/sketch/pom.xml | 2 +- common/tags/pom.xml | 2 +- common/unsafe/pom.xml | 2 +- core/pom.xml | 2 +- docs/_config.yml | 4 ++-- examples/pom.xml | 2 +- external/avro/pom.xml | 2 +- external/docker-integration-tests/pom.xml | 2 +- external/flume-assembly/pom.xml | 2 +- external/flume-sink/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka-0-10-assembly/pom.xml | 2 +- external/kafka-0-10-sql/pom.xml | 2 +- external/kafka-0-10/pom.xml | 2 +- external/kafka-0-8-assembly/pom.xml | 2 +- external/kafka-0-8/pom.xml | 2 +- external/kinesis-asl-assembly/pom.xml | 2 +- external/kinesis-asl/pom.xml | 2 +- external/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- hadoop-cloud/pom.xml | 2 +- launcher/pom.xml | 2 +- mllib-local/pom.xml | 2 +- mllib/pom.xml | 2 +- pom.xml | 2 +- project/MimaExcludes.scala | 6 +++--- python/pyspark/sql/dataframe.py | 2 +- python/pyspark/sql/functions.py | 2 +- python/pyspark/version.py | 2 +- repl/pom.xml | 2 +- resource-managers/kubernetes/core/pom.xml | 2 +- resource-managers/kubernetes/integration-tests/pom.xml | 2 +- resource-managers/mesos/pom.xml | 2 +- resource-managers/yarn/pom.xml | 2 +- sql/catalyst/pom.xml | 2 +- .../apache/spark/sql/catalyst/expressions/arithmetic.scala | 2 +- sql/core/pom.xml | 2 +- .../scala/org/apache/spark/sql/DataFrameStatFunctions.scala | 4 ++-- .../org/apache/spark/sql/RelationalGroupedDataset.scala | 2 +- .../src/main/scala/org/apache/spark/sql/functions.scala | 2 +- sql/hive-thriftserver/pom.xml | 2 +- sql/hive/pom.xml | 2 +- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- 50 files changed, 54 insertions(+), 54 deletions(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 96090bed6899b..cdaaa6104e6a9 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,6 +1,6 @@ Package: SparkR Type: Package -Version: 2.5.0 +Version: 3.0.0 Title: R Frontend for Apache Spark Description: Provides an R Frontend for Apache Spark. Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), diff --git a/assembly/pom.xml b/assembly/pom.xml index d431d3f8caf28..b0337e58cca71 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../pom.xml diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml index cdb4359d17a87..23a0f49206909 100644 --- a/common/kvstore/pom.xml +++ b/common/kvstore/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 5a8e0eb46cf91..41fcbf0589499 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 24c8675cbd7d0..ff717057bb25d 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index 40b9267a335e3..a1cf761d12d8b 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index c2adbf04563b5..adbbcb1cb3040 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/common/tags/pom.xml b/common/tags/pom.xml index d0a3c7a61d1cf..f6627beabe84b 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index b0c52d9468226..62c493a5e1ed8 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index d881dc1f140ec..eff3aa1d19423 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../pom.xml diff --git a/docs/_config.yml b/docs/_config.yml index dfc1a73f4ac1c..c3ef98575fa62 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 2.5.0-SNAPSHOT -SPARK_VERSION_SHORT: 2.5.0 +SPARK_VERSION: 3.0.0-SNAPSHOT +SPARK_VERSION_SHORT: 3.0.0 SCALA_BINARY_VERSION: "2.11" SCALA_VERSION: "2.11.12" MESOS_VERSION: 1.0.0 diff --git a/examples/pom.xml b/examples/pom.xml index 6a736d599fabe..756c475b4748d 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../pom.xml diff --git a/external/avro/pom.xml b/external/avro/pom.xml index ee31741524cf8..9d8f319cc9396 100644 --- a/external/avro/pom.xml +++ b/external/avro/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index bd3e4adf99186..f24254b698080 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index 36bc02a742cd2..002bd6fb7f294 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 68059f0e121cc..168d9d3b2ae0a 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 4335732034737..1410ef7f4702d 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index 668fbcd1103cf..4f9c3163b2408 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 5b8baedb71d1c..efd0862fb58ee 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index 5aabef63a891c..f59f07265a0f4 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml index 4a202861bf380..83edb11f296ab 100644 --- a/external/kafka-0-8-assembly/pom.xml +++ b/external/kafka-0-8-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8/pom.xml b/external/kafka-0-8/pom.xml index 0d5e609ded3e8..4545877a9d83f 100644 --- a/external/kafka-0-8/pom.xml +++ b/external/kafka-0-8/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index 81eea530fa2bc..0bf4c265939e7 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index 735e78192f5aa..032aca9077e20 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index 07bc7fa722fb0..35a55b70baf33 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index c0f373d71053d..d65a8ceb62b9b 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../pom.xml diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index d20bdd0d68e93..d48162007e675 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../pom.xml diff --git a/launcher/pom.xml b/launcher/pom.xml index 7c32aade17c71..b1b6126ea5934 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../pom.xml diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index e19c09f287656..ec5f9b0e92c8f 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index e4ac94aba462d..17ddb87c4d86a 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../pom.xml diff --git a/pom.xml b/pom.xml index 172a6ce87eac6..cc20c5cbf8887 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index f4c34a140e9ca..a931738032467 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,8 +34,8 @@ import com.typesafe.tools.mima.core.ProblemFilters._ */ object MimaExcludes { - // Exclude rules for 2.5.x - lazy val v25excludes = v24excludes ++ Seq( + // Exclude rules for 3.0.x + lazy val v30excludes = v24excludes ++ Seq( ) // Exclude rules for 2.4.x @@ -1206,7 +1206,7 @@ object MimaExcludes { } def excludes(version: String) = version match { - case v if v.startsWith("2.5") => v25excludes + case v if v.startsWith("3.0") => v30excludes case v if v.startsWith("2.4") => v24excludes case v if v.startsWith("2.3") => v23excludes case v if v.startsWith("2.2") => v22excludes diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 21bc69b8236fd..bf6b990487617 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -883,7 +883,7 @@ def sampleBy(self, col, fractions, seed=None): >>> dataset.sampleBy(col("key"), fractions={2: 1.0}, seed=0).count() 33 - .. versionchanged:: 2.5 + .. versionchanged:: 3.0 Added sampling by a column of :class:`Column` """ if isinstance(col, basestring): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 74f0685b3cca6..3128d5792eead 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2355,7 +2355,7 @@ def schema_of_json(col, options={}): :param col: string column in json format :param options: options to control parsing. accepts the same options as the JSON datasource - .. versionchanged:: 2.5 + .. versionchanged:: 3.0 It accepts `options` parameter to control schema inferring. >>> from pyspark.sql.types import * diff --git a/python/pyspark/version.py b/python/pyspark/version.py index 6bd00c59d506a..ba2a40cec01e6 100644 --- a/python/pyspark/version.py +++ b/python/pyspark/version.py @@ -16,4 +16,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.5.0.dev0" +__version__ = "3.0.0.dev0" diff --git a/repl/pom.xml b/repl/pom.xml index 17121216a021d..d2a89b2744018 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../pom.xml diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index 06e522fa93f04..90bac19cba019 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../../pom.xml diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 43b7857a79400..23453c8957b28 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../../pom.xml diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index 5674036ee7935..9585bdfafdcf4 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index 8d5d6ab5a3f5a..e55b814be8465 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 224c70ce24d62..2e7df4fd14042 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index f59b2a2ec510f..22b29c3000c16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -323,7 +323,7 @@ case class Divide(left: Expression, right: Expression) extends DivModLike { > SELECT 3 _FUNC_ 2; 1 """, - since = "2.5.0") + since = "3.0.0") // scalastyle:on line.size.limit case class IntegralDivide(left: Expression, right: Expression) extends DivModLike { diff --git a/sql/core/pom.xml b/sql/core/pom.xml index f78126ef53077..2f72ff6cfdbfb 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 75b84773bd0b7..7c12432d33c33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -414,7 +414,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * +-----+---+ * }}} * - * @since 2.5.0 + * @since 3.0.0 */ def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = { require(fractions.values.forall(p => p >= 0.0 && p <= 1.0), @@ -437,7 +437,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @tparam T stratum type * @return a new `DataFrame` that represents the stratified sample * - * @since 2.5.0 + * @since 3.0.0 */ def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index dbacdbff7383a..d4e75b5ebd405 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -330,7 +330,7 @@ class RelationalGroupedDataset protected[sql]( * df.groupBy("year").pivot("course").sum("earnings") * }}} * - * From Spark 2.5.0, values can be literal columns, for instance, struct. For pivoting by + * From Spark 3.0.0, values can be literal columns, for instance, struct. For pivoting by * multiple columns, use the `struct` function to combine the columns and values: * * {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 59a1fcb5ba367..367ac66dd77f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3620,7 +3620,7 @@ object functions { * @return a column with string literal containing schema in DDL format. * * @group collection_funcs - * @since 2.5.0 + * @since 3.0.0 */ def schema_of_json(e: Column, options: java.util.Map[String, String]): Column = { withExpr(SchemaOfJson(e.expr, options.asScala.toMap)) diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 6c8e52d0afa6f..55e051c3ed1be 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 0bb6026910fbd..ef22e2abfb53e 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index 90c8b974c376d..f9a5029a8e818 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 14ccae6f8f187..247f5a6df4b08 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.5.0-SNAPSHOT + 3.0.0-SNAPSHOT ../pom.xml From 7b4e94f16096cd35835450d63620583496e4f978 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 2 Oct 2018 10:04:47 -0700 Subject: [PATCH 1749/2461] [SPARK-25581][SQL] Rename method `benchmark` as `runBenchmarkSuite` in `BenchmarkBase` ## What changes were proposed in this pull request? Rename method `benchmark` in `BenchmarkBase` as `runBenchmarkSuite `. Also add comments. Currently the method name `benchmark` is a bit confusing. Also the name is the same as instances of `Benchmark`: https://github.com/apache/spark/blob/f246813afba16fee4d703f09e6302011b11806f3/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala#L330-L339 ## How was this patch tested? Unit test. Closes #22599 from gengliangwang/renameBenchmarkSuite. Authored-by: Gengliang Wang Signed-off-by: Dongjoon Hyun --- .../scala/org/apache/spark/benchmark/BenchmarkBase.scala | 9 +++++++-- .../spark/mllib/linalg/UDTSerializationBenchmark.scala | 2 +- .../org/apache/spark/sql/UnsafeProjectionBenchmark.scala | 2 +- .../sql/execution/benchmark/AggregateBenchmark.scala | 2 +- .../execution/benchmark/FilterPushdownBenchmark.scala | 2 +- .../execution/benchmark/PrimitiveArrayBenchmark.scala | 2 +- .../spark/sql/execution/benchmark/SortBenchmark.scala | 2 +- .../compression/CompressionSchemeBenchmark.scala | 2 +- .../execution/vectorized/ColumnarBatchBenchmark.scala | 2 +- .../org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala | 2 +- 10 files changed, 16 insertions(+), 11 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala b/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala index 9a37e0221b27b..89e927e5784d2 100644 --- a/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala +++ b/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala @@ -25,7 +25,12 @@ import java.io.{File, FileOutputStream, OutputStream} abstract class BenchmarkBase { var output: Option[OutputStream] = None - def benchmark(): Unit + /** + * Main process of the whole benchmark. + * Implementations of this method are supposed to use the wrapper method `runBenchmark` + * for each benchmark scenario. + */ + def runBenchmarkSuite(): Unit final def runBenchmark(benchmarkName: String)(func: => Any): Unit = { val separator = "=" * 96 @@ -46,7 +51,7 @@ abstract class BenchmarkBase { output = Some(new FileOutputStream(file)) } - benchmark() + runBenchmarkSuite() output.foreach { o => if (o != null) { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala index 1a2216ea070c4..6c1d58089867a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder */ object UDTSerializationBenchmark extends BenchmarkBase { - override def benchmark(): Unit = { + override def runBenchmarkSuite(): Unit = { runBenchmark("VectorUDT de/serialization") { val iters = 1e2.toInt diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala index cbe723fd11c6a..e7a99485cdf04 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala @@ -41,7 +41,7 @@ object UnsafeProjectionBenchmark extends BenchmarkBase { (1 to numRows).map(_ => encoder.toRow(generator().asInstanceOf[Row]).copy()).toArray } - override def benchmark(): Unit = { + override def runBenchmarkSuite(): Unit = { runBenchmark("unsafe projection") { val iters = 1024 * 16 val numRows = 1024 * 16 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index 296ae104a94a3..86e0df2fea350 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -44,7 +44,7 @@ import org.apache.spark.unsafe.map.BytesToBytesMap */ object AggregateBenchmark extends SqlBasedBenchmark { - override def benchmark(): Unit = { + override def runBenchmarkSuite(): Unit = { runBenchmark("aggregate without grouping") { val N = 500L << 22 codegenBenchmark("agg w/o group", N) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala index 7cdf653e38697..cf05ca3361711 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala @@ -198,7 +198,7 @@ object FilterPushdownBenchmark extends BenchmarkBase with SQLHelper { } } - override def benchmark(): Unit = { + override def runBenchmarkSuite(): Unit = { runBenchmark("Pushdown for many distinct value case") { withTempPath { dir => withTempTable("orcTable", "parquetTable") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala index 8b275188f06d6..83edf73abfae5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala @@ -36,7 +36,7 @@ object PrimitiveArrayBenchmark extends BenchmarkBase { .config("spark.sql.autoBroadcastJoinThreshold", 1) .getOrCreate() - override def benchmark(): Unit = { + override def runBenchmarkSuite(): Unit = { runBenchmark("Write primitive arrays in dataset") { writeDatasetArray(4) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala index 958a064402149..9a54e2320b80f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala @@ -119,7 +119,7 @@ object SortBenchmark extends BenchmarkBase { benchmark.run() } - override def benchmark(): Unit = { + override def runBenchmarkSuite(): Unit = { runBenchmark("radix sort") { sortBenchmark() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala index ff0e4acd31279..0f9079744a220 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala @@ -233,7 +233,7 @@ object CompressionSchemeBenchmark extends BenchmarkBase with AllCompressionSchem runDecodeBenchmark("STRING Decode", iters, count, STRING, testData) } - override def benchmark(): Unit = { + override def runBenchmarkSuite(): Unit = { runBenchmark("Compression Scheme Benchmark") { bitEncodingBenchmark(1024) shortEncodingBenchmark(1024) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index df6ab14e661c4..f311465e582ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -443,7 +443,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { benchmark.run } - override def benchmark(): Unit = { + override def runBenchmarkSuite(): Unit = { runBenchmark("Int Read/Write") { intAccess(1024 * 40) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala index 0bb5e8c141595..870ad4818eb28 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala @@ -336,7 +336,7 @@ object OrcReadBenchmark extends BenchmarkBase with SQLHelper { } } - override def benchmark(): Unit = { + override def runBenchmarkSuite(): Unit = { runBenchmark("SQL Single Numeric Column Scan") { Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { dataType => numericScanBenchmark(1024 * 1024 * 15, dataType) From d6be46eb9c68e48f6c0ed1e461649ee9575b2426 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 2 Oct 2018 10:10:22 -0700 Subject: [PATCH 1750/2461] [SPARK-24530][FOLLOWUP] run Sphinx with python 3 in docker ## What changes were proposed in this pull request? SPARK-24530 discovered a problem of generation python doc, and provided a fix: setting SPHINXPYTHON to python 3. This PR makes this fix automatic in the release script using docker. ## How was this patch tested? verified by the 2.4.0 rc2 Closes #22607 from cloud-fan/python. Authored-by: Wenchen Fan Signed-off-by: Marcelo Vanzin --- dev/create-release/do-release-docker.sh | 3 +++ dev/create-release/spark-rm/Dockerfile | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/dev/create-release/do-release-docker.sh b/dev/create-release/do-release-docker.sh index fa7b73cdb40ec..c1a122ebfb12e 100755 --- a/dev/create-release/do-release-docker.sh +++ b/dev/create-release/do-release-docker.sh @@ -135,6 +135,9 @@ if [ -n "$JAVA" ]; then JAVA_VOL="--volume $JAVA:/opt/spark-java" fi +# SPARK-24530: Sphinx must work with python 3 to generate doc correctly. +echo "SPHINXPYTHON=/opt/p35/bin/python" >> $ENVFILE + echo "Building $RELEASE_TAG; output will be at $WORKDIR/output" docker run -ti \ --env-file "$ENVFILE" \ diff --git a/dev/create-release/spark-rm/Dockerfile b/dev/create-release/spark-rm/Dockerfile index 15f831cf06a66..42315446016cf 100644 --- a/dev/create-release/spark-rm/Dockerfile +++ b/dev/create-release/spark-rm/Dockerfile @@ -62,8 +62,8 @@ RUN echo 'deb http://cran.cnr.Berkeley.edu/bin/linux/ubuntu xenial/' >> /etc/apt pip install $BASE_PIP_PKGS && \ pip install $PIP_PKGS && \ cd && \ - virtualenv -p python3 p35 && \ - . p35/bin/activate && \ + virtualenv -p python3 /opt/p35 && \ + . /opt/p35/bin/activate && \ pip install $BASE_PIP_PKGS && \ pip install $PIP_PKGS && \ # Install R packages and dependencies used when building. From 928d0739c45d0fbb1d3bfc09c0ed7a213f09f3e5 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 3 Oct 2018 17:08:55 +0800 Subject: [PATCH 1751/2461] [SPARK-25595] Ignore corrupt Avro files if flag IGNORE_CORRUPT_FILES enabled ## What changes were proposed in this pull request? With flag `IGNORE_CORRUPT_FILES` enabled, schema inference should ignore corrupt Avro files, which is consistent with Parquet and Orc data source. ## How was this patch tested? Unit test Closes #22611 from gengliangwang/ignoreCorruptAvro. Authored-by: Gengliang Wang Signed-off-by: hyukjinkwon --- .../spark/sql/avro/AvroFileFormat.scala | 78 ++++++++++++------- .../org/apache/spark/sql/avro/AvroSuite.scala | 43 ++++++++++ 2 files changed, 93 insertions(+), 28 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index 6df23c93e4c54..e60fa88cbeba9 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -32,14 +32,14 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.Job -import org.apache.spark.TaskContext +import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile} import org.apache.spark.sql.sources.{DataSourceRegister, Filter} import org.apache.spark.sql.types.StructType -import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.{SerializableConfiguration, Utils} private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister with Logging with Serializable { @@ -59,36 +59,13 @@ private[avro] class AvroFileFormat extends FileFormat val conf = spark.sessionState.newHadoopConf() val parsedOptions = new AvroOptions(options, conf) - // Schema evolution is not supported yet. Here we only pick a single random sample file to - // figure out the schema of the whole dataset. - val sampleFile = - if (parsedOptions.ignoreExtension) { - files.headOption.getOrElse { - throw new FileNotFoundException("Files for schema inferring have been not found.") - } - } else { - files.find(_.getPath.getName.endsWith(".avro")).getOrElse { - throw new FileNotFoundException( - "No Avro files found. If files don't have .avro extension, set ignoreExtension to true") - } - } - // User can specify an optional avro json schema. val avroSchema = parsedOptions.schema .map(new Schema.Parser().parse) .getOrElse { - val in = new FsInput(sampleFile.getPath, conf) - try { - val reader = DataFileReader.openReader(in, new GenericDatumReader[GenericRecord]()) - try { - reader.getSchema - } finally { - reader.close() - } - } finally { - in.close() - } - } + inferAvroSchemaFromFiles(files, conf, parsedOptions.ignoreExtension, + spark.sessionState.conf.ignoreCorruptFiles) + } SchemaConverters.toSqlType(avroSchema).dataType match { case t: StructType => Some(t) @@ -100,6 +77,51 @@ private[avro] class AvroFileFormat extends FileFormat } } + private def inferAvroSchemaFromFiles( + files: Seq[FileStatus], + conf: Configuration, + ignoreExtension: Boolean, + ignoreCorruptFiles: Boolean): Schema = { + // Schema evolution is not supported yet. Here we only pick first random readable sample file to + // figure out the schema of the whole dataset. + val avroReader = files.iterator.map { f => + val path = f.getPath + if (!ignoreExtension && !path.getName.endsWith(".avro")) { + None + } else { + Utils.tryWithResource { + new FsInput(path, conf) + } { in => + try { + Some(DataFileReader.openReader(in, new GenericDatumReader[GenericRecord]())) + } catch { + case e: IOException => + if (ignoreCorruptFiles) { + logWarning(s"Skipped the footer in the corrupted file: $path", e) + None + } else { + throw new SparkException(s"Could not read file: $path", e) + } + } + } + } + }.collectFirst { + case Some(reader) => reader + } + + avroReader match { + case Some(reader) => + try { + reader.getSchema + } finally { + reader.close() + } + case None => + throw new FileNotFoundException( + "No Avro files found. If files don't have .avro extension, set ignoreExtension to true") + } + } + override def shortName(): String = "avro" override def isSplitable( diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 9ad4388414eaa..1e08f7b50b115 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { import testImplicits._ @@ -342,6 +343,48 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } } + private def createDummyCorruptFile(dir: File): Unit = { + Utils.tryWithResource { + FileUtils.forceMkdir(dir) + val corruptFile = new File(dir, "corrupt.avro") + new BufferedWriter(new FileWriter(corruptFile)) + } { writer => + writer.write("corrupt") + } + } + + test("Ignore corrupt Avro file if flag IGNORE_CORRUPT_FILES enabled") { + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "true") { + withTempPath { dir => + createDummyCorruptFile(dir) + val message = intercept[FileNotFoundException] { + spark.read.format("avro").load(dir.getAbsolutePath).schema + }.getMessage + assert(message.contains("No Avro files found.")) + + val srcFile = new File("src/test/resources/episodes.avro") + val destFile = new File(dir, "episodes.avro") + FileUtils.copyFile(srcFile, destFile) + + val result = spark.read.format("avro").load(srcFile.getAbsolutePath).collect() + checkAnswer(spark.read.format("avro").load(dir.getAbsolutePath), result) + } + } + } + + test("Throws IOException on reading corrupt Avro file if flag IGNORE_CORRUPT_FILES disabled") { + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { + withTempPath { dir => + createDummyCorruptFile(dir) + val message = intercept[org.apache.spark.SparkException] { + spark.read.format("avro").load(dir.getAbsolutePath) + }.getMessage + + assert(message.contains("Could not read file")) + } + } + } + test("Date field type") { withTempPath { dir => val schema = StructType(Seq( From 1a5d83bed8a6df62ef643b08453c7dd8feebf93a Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 3 Oct 2018 04:14:07 -0700 Subject: [PATCH 1752/2461] [SPARK-25589][SQL][TEST] Add BloomFilterBenchmark ## What changes were proposed in this pull request? This PR aims to add `BloomFilterBenchmark`. For ORC data source, Apache Spark has been supporting for a long time. For Parquet data source, it's expected to be added with next Parquet release update. ## How was this patch tested? Manual. ```scala SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain org.apache.spark.sql.execution.benchmark.BloomFilterBenchmark" ``` Closes #22605 from dongjoon-hyun/SPARK-25589. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../BloomFilterBenchmark-results.txt | 24 +++++ .../benchmark/BloomFilterBenchmark.scala | 87 +++++++++++++++++++ 2 files changed, 111 insertions(+) create mode 100644 sql/core/benchmarks/BloomFilterBenchmark-results.txt create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BloomFilterBenchmark.scala diff --git a/sql/core/benchmarks/BloomFilterBenchmark-results.txt b/sql/core/benchmarks/BloomFilterBenchmark-results.txt new file mode 100644 index 0000000000000..2eeb26c899b42 --- /dev/null +++ b/sql/core/benchmarks/BloomFilterBenchmark-results.txt @@ -0,0 +1,24 @@ +================================================================================================ +ORC Write +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Write 100M rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Without bloom filter 16765 / 17587 6.0 167.7 1.0X +With bloom filter 20060 / 20626 5.0 200.6 0.8X + + +================================================================================================ +ORC Read +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Read a row from 100M rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Without bloom filter 1857 / 1904 53.9 18.6 1.0X +With bloom filter 1399 / 1437 71.5 14.0 1.3X + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BloomFilterBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BloomFilterBenchmark.scala new file mode 100644 index 0000000000000..2f3caca849cdf --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BloomFilterBenchmark.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import scala.util.Random + +import org.apache.spark.benchmark.Benchmark + +/** + * Benchmark to measure read performance with Bloom filters. + * + * Currently, only ORC supports bloom filters, we will add Parquet BM as soon as it becomes + * available. + * + * To run this benchmark: + * {{{ + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/BloomFilterBenchmark-results.txt". + * }}} + */ +object BloomFilterBenchmark extends SqlBasedBenchmark { + import spark.implicits._ + + private val scaleFactor = 100 + private val N = scaleFactor * 1000 * 1000 + private val df = spark.range(N).map(_ => Random.nextInt) + + private def writeBenchmark(): Unit = { + withTempPath { dir => + val path = dir.getCanonicalPath + + runBenchmark(s"ORC Write") { + val benchmark = new Benchmark(s"Write ${scaleFactor}M rows", N, output = output) + benchmark.addCase("Without bloom filter") { _ => + df.write.mode("overwrite").orc(path + "/withoutBF") + } + benchmark.addCase("With bloom filter") { _ => + df.write.mode("overwrite") + .option("orc.bloom.filter.columns", "value").orc(path + "/withBF") + } + benchmark.run() + } + } + } + + private def readBenchmark(): Unit = { + withTempPath { dir => + val path = dir.getCanonicalPath + + df.write.orc(path + "/withoutBF") + df.write.option("orc.bloom.filter.columns", "value").orc(path + "/withBF") + + runBenchmark(s"ORC Read") { + val benchmark = new Benchmark(s"Read a row from ${scaleFactor}M rows", N, output = output) + benchmark.addCase("Without bloom filter") { _ => + spark.read.orc(path + "/withoutBF").where("value = 0").count + } + benchmark.addCase("With bloom filter") { _ => + spark.read.orc(path + "/withBF").where("value = 0").count + } + benchmark.run() + } + } + } + + override def runBenchmarkSuite(): Unit = { + writeBenchmark() + readBenchmark() + } +} From 56741c342dce87a75b39e52db6de92d7d7bef371 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 3 Oct 2018 04:20:02 -0700 Subject: [PATCH 1753/2461] [SPARK-25483][TEST] Refactor UnsafeArrayDataBenchmark to use main method ## What changes were proposed in this pull request? Refactor `UnsafeArrayDataBenchmark` to use main method. Generate benchmark result: ```sh SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain org.apache.spark.sql.execution.benchmark.UnsafeArrayDataBenchmark" ``` ## How was this patch tested? manual tests Closes #22491 from wangyum/SPARK-25483. Lead-authored-by: Yuming Wang Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../UnsafeArrayDataBenchmark-results.txt | 33 +++++++++ .../benchmark/UnsafeArrayDataBenchmark.scala | 73 ++++++------------- 2 files changed, 56 insertions(+), 50 deletions(-) create mode 100644 sql/core/benchmarks/UnsafeArrayDataBenchmark-results.txt diff --git a/sql/core/benchmarks/UnsafeArrayDataBenchmark-results.txt b/sql/core/benchmarks/UnsafeArrayDataBenchmark-results.txt new file mode 100644 index 0000000000000..4ecc1f1fad4b9 --- /dev/null +++ b/sql/core/benchmarks/UnsafeArrayDataBenchmark-results.txt @@ -0,0 +1,33 @@ +================================================================================================ +Benchmark UnsafeArrayData +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Read UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Int 233 / 234 718.6 1.4 1.0X +Double 244 / 244 687.0 1.5 1.0X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Write UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Int 32 / 33 658.6 1.5 1.0X +Double 73 / 75 287.0 3.5 0.4X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Get primitive array from UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Int 70 / 72 895.0 1.1 1.0X +Double 141 / 143 446.9 2.2 0.5X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Create UnsafeArrayData from primitive array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Int 72 / 73 874.7 1.1 1.0X +Double 145 / 146 433.7 2.3 0.5X + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala index 51ab0e13a98a8..79eaeab9c399f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala @@ -19,20 +19,21 @@ package org.apache.spark.sql.execution.benchmark import scala.util.Random -import org.apache.spark.benchmark.Benchmark +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeArrayWriter} +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData /** * Benchmark [[UnsafeArrayDataBenchmark]] for UnsafeArrayData - * To run this: - * 1. replace ignore(...) with test(...) - * 2. build/sbt "sql/test-only *benchmark.UnsafeArrayDataBenchmark" - * - * Benchmarks in this file are skipped in normal builds. + * To run this benchmark: + * {{{ + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/UnsafeArrayDataBenchmark-results.txt". + * }}} */ -class UnsafeArrayDataBenchmark extends BenchmarkWithCodegen { +object UnsafeArrayDataBenchmark extends BenchmarkBase { def calculateHeaderPortionInBytes(count: Int) : Int = { /* 4 + 4 * count // Use this expression for SPARK-15962 */ @@ -77,18 +78,10 @@ class UnsafeArrayDataBenchmark extends BenchmarkWithCodegen { } } - val benchmark = new Benchmark("Read UnsafeArrayData", count * iters) + val benchmark = new Benchmark("Read UnsafeArrayData", count * iters, output = output) benchmark.addCase("Int")(readIntArray) benchmark.addCase("Double")(readDoubleArray) benchmark.run - /* - OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 - Intel Xeon E3-12xx v2 (Ivy Bridge) - Read UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Int 252 / 260 666.1 1.5 1.0X - Double 281 / 292 597.7 1.7 0.9X - */ } def writeUnsafeArray(iters: Int): Unit = { @@ -121,18 +114,10 @@ class UnsafeArrayDataBenchmark extends BenchmarkWithCodegen { doubleTotalLength = len } - val benchmark = new Benchmark("Write UnsafeArrayData", count * iters) + val benchmark = new Benchmark("Write UnsafeArrayData", count * iters, output = output) benchmark.addCase("Int")(writeIntArray) benchmark.addCase("Double")(writeDoubleArray) benchmark.run - /* - OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 - Intel Xeon E3-12xx v2 (Ivy Bridge) - Write UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Int 196 / 249 107.0 9.3 1.0X - Double 227 / 367 92.3 10.8 0.9X - */ } def getPrimitiveArray(iters: Int): Unit = { @@ -167,18 +152,11 @@ class UnsafeArrayDataBenchmark extends BenchmarkWithCodegen { doubleTotalLength = len } - val benchmark = new Benchmark("Get primitive array from UnsafeArrayData", count * iters) + val benchmark = + new Benchmark("Get primitive array from UnsafeArrayData", count * iters, output = output) benchmark.addCase("Int")(readIntArray) benchmark.addCase("Double")(readDoubleArray) benchmark.run - /* - OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 - Intel Xeon E3-12xx v2 (Ivy Bridge) - Get primitive array from UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Int 151 / 198 415.8 2.4 1.0X - Double 214 / 394 293.6 3.4 0.7X - */ } def putPrimitiveArray(iters: Int): Unit = { @@ -209,24 +187,19 @@ class UnsafeArrayDataBenchmark extends BenchmarkWithCodegen { doubleTotalLen = len } - val benchmark = new Benchmark("Create UnsafeArrayData from primitive array", count * iters) + val benchmark = + new Benchmark("Create UnsafeArrayData from primitive array", count * iters, output = output) benchmark.addCase("Int")(createIntArray) benchmark.addCase("Double")(createDoubleArray) benchmark.run - /* - OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 - Intel Xeon E3-12xx v2 (Ivy Bridge) - Create UnsafeArrayData from primitive array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Int 206 / 211 306.0 3.3 1.0X - Double 232 / 406 271.6 3.7 0.9X - */ } - ignore("Benchmark UnsafeArrayData") { - readUnsafeArray(10) - writeUnsafeArray(10) - getPrimitiveArray(5) - putPrimitiveArray(5) + override def runBenchmarkSuite(): Unit = { + runBenchmark("Benchmark UnsafeArrayData") { + readUnsafeArray(10) + writeUnsafeArray(10) + getPrimitiveArray(5) + putPrimitiveArray(5) + } } } From d7ae36a810bfcbedfe7360eb2cdbbc3ca970e4d0 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 3 Oct 2018 07:28:34 -0700 Subject: [PATCH 1754/2461] [SPARK-25538][SQL] Zero-out all bytes when writing decimal ## What changes were proposed in this pull request? In #20850 when writing non-null decimals, instead of zero-ing all the 16 allocated bytes, we zero-out only the padding bytes. Since we always allocate 16 bytes, if the number of bytes needed for a decimal is lower than 9, then this means that the bytes between 8 and 16 are not zero-ed. I see 2 solutions here: - we can zero-out all the bytes in advance as it was done before #20850 (safer solution IMHO); - we can allocate only the needed bytes (may be a bit more efficient in terms of memory used, but I have not investigated the feasibility of this option). Hence I propose here the first solution in order to fix the correctness issue. We can eventually switch to the second if we think is more efficient later. ## How was this patch tested? Running the test attached in the JIRA + added UT Closes #22602 from mgaido91/SPARK-25582. Authored-by: Marco Gaido Signed-off-by: Dongjoon Hyun --- .../expressions/codegen/UnsafeRowWriter.java | 10 ++-- .../codegen/UnsafeRowWriterSuite.scala | 53 +++++++++++++++++++ 2 files changed, 57 insertions(+), 6 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 71c49d8ed0177..3960d6d520476 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -185,13 +185,13 @@ public void write(int ordinal, Decimal input, int precision, int scale) { // grow the global buffer before writing data. holder.grow(16); + // always zero-out the 16-byte buffer + Platform.putLong(getBuffer(), cursor(), 0L); + Platform.putLong(getBuffer(), cursor() + 8, 0L); + // Make sure Decimal object has the same scale as DecimalType. // Note that we may pass in null Decimal object to set null for it. if (input == null || !input.changePrecision(precision, scale)) { - // zero-out the bytes - Platform.putLong(getBuffer(), cursor(), 0L); - Platform.putLong(getBuffer(), cursor() + 8, 0L); - BitSetMethods.set(getBuffer(), startingOffset, ordinal); // keep the offset for future update setOffsetAndSize(ordinal, 0); @@ -200,8 +200,6 @@ public void write(int ordinal, Decimal input, int precision, int scale) { final int numBytes = bytes.length; assert numBytes <= 16; - zeroOutPaddingBytes(numBytes); - // Write the bytes to the variable length portion. Platform.copyMemory( bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala new file mode 100644 index 0000000000000..fb651b76fc16d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.Decimal + +class UnsafeRowWriterSuite extends SparkFunSuite { + + def checkDecimalSizeInBytes(decimal: Decimal, numBytes: Int): Unit = { + assert(decimal.toJavaBigDecimal.unscaledValue().toByteArray.length == numBytes) + } + + test("SPARK-25538: zero-out all bits for decimals") { + val decimal1 = Decimal(0.431) + decimal1.changePrecision(38, 18) + checkDecimalSizeInBytes(decimal1, 8) + + val decimal2 = Decimal(123456789.1232456789) + decimal2.changePrecision(38, 18) + checkDecimalSizeInBytes(decimal2, 11) + // On an UnsafeRowWriter we write decimal2 first and then decimal1 + val unsafeRowWriter1 = new UnsafeRowWriter(1) + unsafeRowWriter1.resetRowWriter() + unsafeRowWriter1.write(0, decimal2, decimal2.precision, decimal2.scale) + unsafeRowWriter1.reset() + unsafeRowWriter1.write(0, decimal1, decimal1.precision, decimal1.scale) + val res1 = unsafeRowWriter1.getRow + // On a second UnsafeRowWriter we write directly decimal1 + val unsafeRowWriter2 = new UnsafeRowWriter(1) + unsafeRowWriter2.resetRowWriter() + unsafeRowWriter2.write(0, decimal1, decimal1.precision, decimal1.scale) + val res2 = unsafeRowWriter2.getRow + // The two rows should be the equal + assert(res1 == res2) + } + +} From 075dd620e32872b5d90a2fa7d09b43b15502182b Mon Sep 17 00:00:00 2001 From: ankurgupta Date: Wed, 3 Oct 2018 16:18:36 -0700 Subject: [PATCH 1755/2461] [SPARK-25586][CORE] Remove outer objects from logdebug statements in ClosureCleaner ## What changes were proposed in this pull request? Cause: Recently test_glr_summary failed for PR of SPARK-25118, which enables spark-shell to run with default log level. It failed because this logdebug was called for GeneralizedLinearRegressionTrainingSummary which invoked its toString method, which started a Spark Job and ended up running into an infinite loop. Fix: Remove logDebug statement for outer objects as closures aren't implemented with outerclasses in Scala 2.12 and this debug statement looses its purpose ## How was this patch tested? Ran python pyspark-ml tests on top of PR for SPARK-25118 and ClosureCleaner unit tests Closes #22616 from ankuriitg/ankur/SPARK-25586. Authored-by: ankurgupta Signed-off-by: Marcelo Vanzin --- .../org/apache/spark/util/ClosureCleaner.scala | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 43d62561e8eba..6c4740c002103 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -285,8 +285,6 @@ private[spark] object ClosureCleaner extends Logging { innerClasses.foreach { c => logDebug(s" ${c.getName}") } logDebug(s" + outer classes: ${outerClasses.size}" ) outerClasses.foreach { c => logDebug(s" ${c.getName}") } - logDebug(s" + outer objects: ${outerObjects.size}") - outerObjects.foreach { o => logDebug(s" $o") } } // Fail fast if we detect return statements in closures @@ -318,19 +316,20 @@ private[spark] object ClosureCleaner extends Logging { if (outerPairs.nonEmpty) { val (outermostClass, outermostObject) = outerPairs.head if (isClosure(outermostClass)) { - logDebug(s" + outermost object is a closure, so we clone it: ${outerPairs.head}") + logDebug(s" + outermost object is a closure, so we clone it: ${outermostClass}") } else if (outermostClass.getName.startsWith("$line")) { // SPARK-14558: if the outermost object is a REPL line object, we should clone // and clean it as it may carray a lot of unnecessary information, // e.g. hadoop conf, spark conf, etc. - logDebug(s" + outermost object is a REPL line object, so we clone it: ${outerPairs.head}") + logDebug(s" + outermost object is a REPL line object, so we clone it:" + + s" ${outermostClass}") } else { // The closure is ultimately nested inside a class; keep the object of that // class without cloning it since we don't want to clone the user's objects. // Note that we still need to keep around the outermost object itself because // we need it to clone its child closure later (see below). - logDebug(" + outermost object is not a closure or REPL line object," + - "so do not clone it: " + outerPairs.head) + logDebug(s" + outermost object is not a closure or REPL line object," + + s" so do not clone it: ${outermostClass}") parent = outermostObject // e.g. SparkContext outerPairs = outerPairs.tail } @@ -341,7 +340,7 @@ private[spark] object ClosureCleaner extends Logging { // Clone the closure objects themselves, nulling out any fields that are not // used in the closure we're working on or any of its inner closures. for ((cls, obj) <- outerPairs) { - logDebug(s" + cloning the object $obj of class ${cls.getName}") + logDebug(s" + cloning instance of class ${cls.getName}") // We null out these unused references by cloning each object and then filling in all // required fields from the original object. We need the parent here because the Java // language specification requires the first constructor parameter of any closure to be @@ -351,7 +350,7 @@ private[spark] object ClosureCleaner extends Logging { // If transitive cleaning is enabled, we recursively clean any enclosing closure using // the already populated accessed fields map of the starting closure if (cleanTransitively && isClosure(clone.getClass)) { - logDebug(s" + cleaning cloned closure $clone recursively (${cls.getName})") + logDebug(s" + cleaning cloned closure recursively (${cls.getName})") // No need to check serializable here for the outer closures because we're // only interested in the serializability of the starting closure clean(clone, checkSerializable = false, cleanTransitively, accessedFields) From 79dd4c96484c9be7ad9250b64f3fd8e088707641 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 4 Oct 2018 09:36:23 +0800 Subject: [PATCH 1756/2461] [SPARK-25601][PYTHON] Register Grouped aggregate UDF Vectorized UDFs for SQL Statement ## What changes were proposed in this pull request? This PR proposes to register Grouped aggregate UDF Vectorized UDFs for SQL Statement, for instance: ```python from pyspark.sql.functions import pandas_udf, PandasUDFType pandas_udf("integer", PandasUDFType.GROUPED_AGG) def sum_udf(v): return v.sum() spark.udf.register("sum_udf", sum_udf) q = "SELECT v2, sum_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2" spark.sql(q).show() ``` ``` +---+-----------+ | v2|sum_udf(v1)| +---+-----------+ | 1| 1| | 0| 5| +---+-----------+ ``` ## How was this patch tested? Manual test and unit test. Closes #22620 from HyukjinKwon/SPARK-25601. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- python/pyspark/sql/tests.py | 20 ++++++++++++++++++-- python/pyspark/sql/udf.py | 15 +++++++++++++-- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 815772d23ceea..d3c29d061fc32 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -5642,8 +5642,9 @@ def test_register_grouped_map_udf(self): foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP) with QuietTest(self.sc): - with self.assertRaisesRegexp(ValueError, 'f must be either SQL_BATCHED_UDF or ' - 'SQL_SCALAR_PANDAS_UDF'): + with self.assertRaisesRegexp( + ValueError, + 'f.*SQL_BATCHED_UDF.*SQL_SCALAR_PANDAS_UDF.*SQL_GROUPED_AGG_PANDAS_UDF.*'): self.spark.catalog.registerFunction("foo_udf", foo_udf) def test_decorator(self): @@ -6459,6 +6460,21 @@ def test_invalid_args(self): 'mixture.*aggregate function.*group aggregate pandas UDF'): df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() + def test_register_vectorized_udf_basic(self): + from pyspark.sql.functions import pandas_udf + from pyspark.rdd import PythonEvalType + + sum_pandas_udf = pandas_udf( + lambda v: v.sum(), "integer", PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) + + self.assertEqual(sum_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) + group_agg_pandas_udf = self.spark.udf.register("sum_pandas_udf", sum_pandas_udf) + self.assertEqual(group_agg_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) + q = "SELECT sum_pandas_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2" + actual = sorted(map(lambda r: r[0], self.spark.sql(q).collect())) + expected = [1, 5] + self.assertEqual(actual, expected) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 9dbe49b831cef..58f4e0dff5ee5 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -298,6 +298,15 @@ def register(self, name, f, returnType=None): >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] + >>> @pandas_udf("integer", PandasUDFType.GROUPED_AGG) # doctest: +SKIP + ... def sum_udf(v): + ... return v.sum() + ... + >>> _ = spark.udf.register("sum_udf", sum_udf) # doctest: +SKIP + >>> q = "SELECT sum_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2" + >>> spark.sql(q).collect() # doctest: +SKIP + [Row(sum_udf(v1)=1), Row(sum_udf(v1)=5)] + .. note:: Registration for a user-defined function (case 2.) was added from Spark 2.3.0. """ @@ -310,9 +319,11 @@ def register(self, name, f, returnType=None): "Invalid returnType: data type can not be specified when f is" "a user-defined function, but got %s." % returnType) if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, - PythonEvalType.SQL_SCALAR_PANDAS_UDF]: + PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF]: raise ValueError( - "Invalid f: f must be either SQL_BATCHED_UDF or SQL_SCALAR_PANDAS_UDF") + "Invalid f: f must be SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF or " + "SQL_GROUPED_AGG_PANDAS_UDF") register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, evalType=f.evalType, deterministic=f.deterministic) From 927e527934a882fab89ca661c4eb31f84c45d830 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 4 Oct 2018 09:38:06 +0800 Subject: [PATCH 1757/2461] [SPARK-25601][PYTHON] Register Grouped aggregate UDF Vectorized UDFs for SQL Statement ## What changes were proposed in this pull request? This PR proposes to register Grouped aggregate UDF Vectorized UDFs for SQL Statement, for instance: ```python from pyspark.sql.functions import pandas_udf, PandasUDFType pandas_udf("integer", PandasUDFType.GROUPED_AGG) def sum_udf(v): return v.sum() spark.udf.register("sum_udf", sum_udf) q = "SELECT v2, sum_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2" spark.sql(q).show() ``` ``` +---+-----------+ | v2|sum_udf(v1)| +---+-----------+ | 1| 1| | 0| 5| +---+-----------+ ``` ## How was this patch tested? Manual test and unit test. Closes #22620 from HyukjinKwon/SPARK-25601. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon From 71c24aad36ae6b3f50447a019bf893490dcf1cf4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 4 Oct 2018 20:15:21 +0800 Subject: [PATCH 1758/2461] [SPARK-25602][SQL] SparkPlan.getByteArrayRdd should not consume the input when not necessary ## What changes were proposed in this pull request? In `SparkPlan.getByteArrayRdd`, we should only call `it.hasNext` when the limit is not hit, as `iter.hasNext` may produce one row and buffer it, and cause wrong metrics. ## How was this patch tested? new tests Closes #22621 from cloud-fan/range. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../spark/sql/execution/SparkPlan.scala | 4 +- .../execution/metric/SQLMetricsSuite.scala | 55 ++++++++++++++++++- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index ab6031c436e9d..9d9b020309d9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -250,7 +250,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ val codec = CompressionCodec.createCodec(SparkEnv.get.conf) val bos = new ByteArrayOutputStream() val out = new DataOutputStream(codec.compressedOutputStream(bos)) - while (iter.hasNext && (n < 0 || count < n)) { + // `iter.hasNext` may produce one row and buffer it, we should only call it when the limit is + // not hit. + while ((n < 0 || count < n) && iter.hasNext) { val row = iter.next().asInstanceOf[UnsafeRow] out.writeInt(row.getSizeInBytes) row.writeToStream(out, buffer) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index d45eb0c27a6b1..085a445488480 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -24,7 +24,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.execution.ui.SQLAppStatusStore +import org.apache.spark.sql.execution.{FilterExec, RangeExec, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -517,4 +517,57 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared test("writing data out metrics with dynamic partition: parquet") { testMetricsDynamicPartition("parquet", "parquet", "t1") } + + test("SPARK-25602: SparkPlan.getByteArrayRdd should not consume the input when not necessary") { + def checkFilterAndRangeMetrics( + df: DataFrame, + filterNumOutputs: Int, + rangeNumOutputs: Int): Unit = { + var filter: FilterExec = null + var range: RangeExec = null + val collectFilterAndRange: SparkPlan => Unit = { + case f: FilterExec => + assert(filter == null, "the query should only have one Filter") + filter = f + case r: RangeExec => + assert(range == null, "the query should only have one Range") + range = r + case _ => + } + if (SQLConf.get.wholeStageEnabled) { + df.queryExecution.executedPlan.foreach { + case w: WholeStageCodegenExec => + w.child.foreach(collectFilterAndRange) + case _ => + } + } else { + df.queryExecution.executedPlan.foreach(collectFilterAndRange) + } + + assert(filter != null && range != null, "the query doesn't have Filter and Range") + assert(filter.metrics("numOutputRows").value == filterNumOutputs) + assert(range.metrics("numOutputRows").value == rangeNumOutputs) + } + + val df = spark.range(0, 3000, 1, 2).toDF().filter('id % 3 === 0) + val df2 = df.limit(2) + Seq(true, false).foreach { wholeStageEnabled => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholeStageEnabled.toString) { + df.collect() + checkFilterAndRangeMetrics(df, filterNumOutputs = 1000, rangeNumOutputs = 3000) + + df.queryExecution.executedPlan.foreach(_.resetMetrics()) + // For each partition, we get 2 rows. Then the Filter should produce 2 rows per-partition, + // and Range should produce 1000 rows (one batch) per-partition. Totally Filter produces + // 4 rows, and Range produces 2000 rows. + df.queryExecution.toRdd.mapPartitions(_.take(2)).collect() + checkFilterAndRangeMetrics(df, filterNumOutputs = 4, rangeNumOutputs = 2000) + + // Top-most limit will call `CollectLimitExec.executeCollect`, which will only run the first + // task, so totally the Filter produces 2 rows, and Range produces 1000 rows (one batch). + df2.collect() + checkFilterAndRangeMetrics(df2, filterNumOutputs = 2, rangeNumOutputs = 1000) + } + } + } } From 95ae2094618fbbe07008c190105053dc2b85da1a Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 4 Oct 2018 11:58:16 -0700 Subject: [PATCH 1759/2461] [SPARK-25479][TEST] Refactor DatasetBenchmark to use main method ## What changes were proposed in this pull request? Refactor `DatasetBenchmark` to use main method. Generate benchmark result: ```sh SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain org.apache.spark.sql.DatasetBenchmark" ``` ## How was this patch tested? manual tests Closes #22488 from wangyum/SPARK-25479. Lead-authored-by: Yuming Wang Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../benchmarks/DatasetBenchmark-results.txt | 46 +++++++++ .../apache/spark/sql/DatasetBenchmark.scala | 96 +++++-------------- 2 files changed, 71 insertions(+), 71 deletions(-) create mode 100644 sql/core/benchmarks/DatasetBenchmark-results.txt diff --git a/sql/core/benchmarks/DatasetBenchmark-results.txt b/sql/core/benchmarks/DatasetBenchmark-results.txt new file mode 100644 index 0000000000000..dcc190eb45c03 --- /dev/null +++ b/sql/core/benchmarks/DatasetBenchmark-results.txt @@ -0,0 +1,46 @@ +================================================================================================ +Dataset Benchmark +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +back-to-back map long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +RDD 11800 / 12042 8.5 118.0 1.0X +DataFrame 1927 / 2189 51.9 19.3 6.1X +Dataset 2483 / 2605 40.3 24.8 4.8X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +RDD 16286 / 16301 6.1 162.9 1.0X +DataFrame 8101 / 8104 12.3 81.0 2.0X +Dataset 17445 / 17811 5.7 174.4 0.9X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +back-to-back filter Long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +RDD 2971 / 3184 33.7 29.7 1.0X +DataFrame 1243 / 1296 80.5 12.4 2.4X +Dataset 3062 / 3091 32.7 30.6 1.0X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +back-to-back filter: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +RDD 5253 / 5269 19.0 52.5 1.0X +DataFrame 211 / 234 473.4 2.1 24.9X +Dataset 9550 / 9552 10.5 95.5 0.6X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +aggregate: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +RDD sum 5086 / 5108 19.7 50.9 1.0X +DataFrame sum 65 / 73 1548.9 0.6 78.8X +Dataset sum using Aggregator 9024 / 9320 11.1 90.2 0.6X +Dataset complex Aggregator 15079 / 15171 6.6 150.8 0.3X + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala index fa2f0b6ba61d4..e3df449b41f0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ @@ -26,8 +26,15 @@ import org.apache.spark.sql.types.StringType /** * Benchmark for Dataset typed operations comparing with DataFrame and RDD versions. + * To run this benchmark: + * {{{ + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/DatasetBenchmark-results.txt". + * }}} */ -object DatasetBenchmark { +object DatasetBenchmark extends SqlBasedBenchmark { case class Data(l: Long, s: String) @@ -39,7 +46,7 @@ object DatasetBenchmark { val df = ds.toDF("l") val func = (l: Long) => l + 1 - val benchmark = new Benchmark("back-to-back map long", numRows) + val benchmark = new Benchmark("back-to-back map long", numRows, output = output) benchmark.addCase("RDD") { iter => var res = rdd @@ -78,7 +85,7 @@ object DatasetBenchmark { import spark.implicits._ val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) - val benchmark = new Benchmark("back-to-back map", numRows) + val benchmark = new Benchmark("back-to-back map", numRows, output = output) val func = (d: Data) => Data(d.l + 1, d.s) val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) @@ -123,7 +130,7 @@ object DatasetBenchmark { val df = ds.toDF("l") val func = (l: Long) => l % 2L == 0L - val benchmark = new Benchmark("back-to-back filter Long", numRows) + val benchmark = new Benchmark("back-to-back filter Long", numRows, output = output) benchmark.addCase("RDD") { iter => var res = rdd @@ -162,7 +169,7 @@ object DatasetBenchmark { import spark.implicits._ val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) - val benchmark = new Benchmark("back-to-back filter", numRows) + val benchmark = new Benchmark("back-to-back filter", numRows, output = output) val func = (d: Data, i: Int) => d.l % (100L + i) == 0L val funcs = 0.until(numChains).map { i => (d: Data) => func(d, i) @@ -220,7 +227,7 @@ object DatasetBenchmark { import spark.implicits._ val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) - val benchmark = new Benchmark("aggregate", numRows) + val benchmark = new Benchmark("aggregate", numRows, output = output) val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) benchmark.addCase("RDD sum") { iter => @@ -242,75 +249,22 @@ object DatasetBenchmark { benchmark } - def main(args: Array[String]): Unit = { - val spark = SparkSession.builder + override def getSparkSession: SparkSession = { + SparkSession.builder .master("local[*]") .appName("Dataset benchmark") .getOrCreate() + } + override def runBenchmarkSuite(): Unit = { val numRows = 100000000 val numChains = 10 - - val benchmark0 = backToBackMapLong(spark, numRows, numChains) - val benchmark1 = backToBackMap(spark, numRows, numChains) - val benchmark2 = backToBackFilterLong(spark, numRows, numChains) - val benchmark3 = backToBackFilter(spark, numRows, numChains) - val benchmark4 = aggregate(spark, numRows) - - /* - OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic - Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz - back-to-back map long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - RDD 1883 / 1892 53.1 18.8 1.0X - DataFrame 502 / 642 199.1 5.0 3.7X - Dataset 657 / 784 152.2 6.6 2.9X - */ - benchmark0.run() - - /* - OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64 - Intel Xeon E3-12xx v2 (Ivy Bridge) - back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - RDD 3448 / 3646 29.0 34.5 1.0X - DataFrame 2647 / 3116 37.8 26.5 1.3X - Dataset 4781 / 5155 20.9 47.8 0.7X - */ - benchmark1.run() - - /* - OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-47-generic - Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz - back-to-back filter Long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - RDD 846 / 1120 118.1 8.5 1.0X - DataFrame 270 / 329 370.9 2.7 3.1X - Dataset 545 / 789 183.5 5.4 1.6X - */ - benchmark2.run() - - /* - OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64 - Intel Xeon E3-12xx v2 (Ivy Bridge) - back-to-back filter: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - RDD 1346 / 1618 74.3 13.5 1.0X - DataFrame 59 / 72 1695.4 0.6 22.8X - Dataset 2777 / 2805 36.0 27.8 0.5X - */ - benchmark3.run() - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.12.1 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - aggregate: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - RDD sum 1913 / 1942 52.3 19.1 1.0X - DataFrame sum 46 / 61 2157.7 0.5 41.3X - Dataset sum using Aggregator 4656 / 4758 21.5 46.6 0.4X - Dataset complex Aggregator 6636 / 7039 15.1 66.4 0.3X - */ - benchmark4.run() + runBenchmark("Dataset Benchmark") { + backToBackMapLong(spark, numRows, numChains).run() + backToBackMap(spark, numRows, numChains).run() + backToBackFilterLong(spark, numRows, numChains).run() + backToBackFilter(spark, numRows, numChains).run() + aggregate(spark, numRows).run() + } } } From 3ae4f07de06e267f0363a53264876ea99dd731df Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Fri, 5 Oct 2018 02:22:06 +0100 Subject: [PATCH 1760/2461] [SPARK-17159][STREAM] Significant speed up for running spark streaming against Object store. ## What changes were proposed in this pull request? Original work by Steve Loughran. Based on #17745. This is a minimal patch of changes to FileInputDStream to reduce File status requests when querying files. Each call to file status is 3+ http calls to object store. This patch eliminates the need for it, by using FileStatus objects. This is a minor optimisation when working with filesystems, but significant when working with object stores. ## How was this patch tested? Tests included. Existing tests pass. Closes #22339 from ScrapCodes/PR_17745. Lead-authored-by: Prashant Sharma Co-authored-by: Steve Loughran Signed-off-by: Sean Owen --- .../streaming/dstream/FileInputDStream.scala | 57 +++++------ .../spark/streaming/InputStreamsSuite.scala | 98 +++++++++++++++---- .../spark/streaming/TestSuiteBase.scala | 14 ++- 3 files changed, 118 insertions(+), 51 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index b8a5a96faf15c..438847caf0c3a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -17,19 +17,19 @@ package org.apache.spark.streaming.dstream -import java.io.{IOException, ObjectInputStream} +import java.io.{FileNotFoundException, IOException, ObjectInputStream} import scala.collection.mutable import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.streaming._ import org.apache.spark.streaming.scheduler.StreamInputInfo -import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Utils} +import org.apache.spark.util.{SerializableConfiguration, Utils} /** * This class represents an input stream that monitors a Hadoop-compatible filesystem for new @@ -122,9 +122,6 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( // Set of files that were selected in the remembered batches @transient private var recentlySelectedFiles = new mutable.HashSet[String]() - // Read-through cache of file mod times, used to speed up mod time lookups - @transient private var fileToModTime = new TimeStampedHashMap[String, Long](true) - // Timestamp of the last round of finding files @transient private var lastNewFileFindingTime = 0L @@ -140,7 +137,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( * a union RDD out of them. Note that this maintains the list of files that were processed * in the latest modification time in the previous call to this method. This is because the * modification time returned by the FileStatus API seems to return times only at the - * granularity of seconds. And new files may have the same modification time as the + * granularity of seconds in HDFS. And new files may have the same modification time as the * latest modification time in the previous call to this method yet was not reported in * the previous call. */ @@ -174,8 +171,6 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( logDebug("Cleared files are:\n" + oldFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n")) } - // Delete file mod times that weren't accessed in the last round of getting new files - fileToModTime.clearOldValues(lastNewFileFindingTime - 1) } /** @@ -197,29 +192,29 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( logDebug(s"Getting new files for time $currentTime, " + s"ignoring files older than $modTimeIgnoreThreshold") - val newFileFilter = new PathFilter { - def accept(path: Path): Boolean = isNewFile(path, currentTime, modTimeIgnoreThreshold) - } - val directoryFilter = new PathFilter { - override def accept(path: Path): Boolean = fs.getFileStatus(path).isDirectory - } - val directories = fs.globStatus(directoryPath, directoryFilter).map(_.getPath) + val directories = Option(fs.globStatus(directoryPath)).getOrElse(Array.empty[FileStatus]) + .filter(_.isDirectory) + .map(_.getPath) val newFiles = directories.flatMap(dir => - fs.listStatus(dir, newFileFilter).map(_.getPath.toString)) + fs.listStatus(dir) + .filter(isNewFile(_, currentTime, modTimeIgnoreThreshold)) + .map(_.getPath.toString)) val timeTaken = clock.getTimeMillis() - lastNewFileFindingTime - logInfo("Finding new files took " + timeTaken + " ms") - logDebug("# cached file times = " + fileToModTime.size) + logDebug(s"Finding new files took $timeTaken ms") if (timeTaken > slideDuration.milliseconds) { logWarning( - "Time taken to find new files exceeds the batch size. " + + s"Time taken to find new files $timeTaken exceeds the batch size. " + "Consider increasing the batch size or reducing the number of " + - "files in the monitored directory." + "files in the monitored directories." ) } newFiles } catch { + case e: FileNotFoundException => + logWarning(s"No directory to scan: $directoryPath: $e") + Array.empty case e: Exception => - logWarning("Error finding new files", e) + logWarning(s"Error finding new files under $directoryPath", e) reset() Array.empty } @@ -242,8 +237,16 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( * The files with mod time T+5 are not remembered and cannot be ignored (since, t+5 > t+1). * Hence they can get selected as new files again. To prevent this, files whose mod time is more * than current batch time are not considered. + * @param fileStatus file status + * @param currentTime time of the batch + * @param modTimeIgnoreThreshold the ignore threshold + * @return true if the file has been modified within the batch window */ - private def isNewFile(path: Path, currentTime: Long, modTimeIgnoreThreshold: Long): Boolean = { + private def isNewFile( + fileStatus: FileStatus, + currentTime: Long, + modTimeIgnoreThreshold: Long): Boolean = { + val path = fileStatus.getPath val pathStr = path.toString // Reject file if it does not satisfy filter if (!filter(path)) { @@ -251,7 +254,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( return false } // Reject file if it was created before the ignore time - val modTime = getFileModTime(path) + val modTime = fileStatus.getModificationTime() if (modTime <= modTimeIgnoreThreshold) { // Use <= instead of < to avoid SPARK-4518 logDebug(s"$pathStr ignored as mod time $modTime <= ignore time $modTimeIgnoreThreshold") @@ -293,11 +296,6 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( new UnionRDD(context.sparkContext, fileRDDs) } - /** Get file mod time from cache or fetch it from the file system */ - private def getFileModTime(path: Path) = { - fileToModTime.getOrElseUpdate(path.toString, fs.getFileStatus(path).getModificationTime()) - } - private def directoryPath: Path = { if (_path == null) _path = new Path(directory) _path @@ -319,7 +317,6 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( generatedRDDs = new mutable.HashMap[Time, RDD[(K, V)]]() batchTimeToSelectedFiles = new mutable.HashMap[Time, Array[String]] recentlySelectedFiles = new mutable.HashSet[String]() - fileToModTime = new TimeStampedHashMap[String, Long](true) } /** diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index b5d36a36513ab..1cf21e8a28033 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -27,7 +27,8 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import com.google.common.io.Files -import org.apache.hadoop.fs.Path +import org.apache.commons.io.IOUtils +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.scalatest.BeforeAndAfter @@ -130,10 +131,8 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } test("binary records stream") { - var testDir: File = null - try { + withTempDir { testDir => val batchDuration = Seconds(2) - testDir = Utils.createTempDir() // Create a file that exists before the StreamingContext is created: val existingFile = new File(testDir, "0") Files.write("0\n", existingFile, StandardCharsets.UTF_8) @@ -176,8 +175,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { assert(obtainedOutput(i) === input.map(b => (b + i).toByte)) } } - } finally { - if (testDir != null) Utils.deleteRecursively(testDir) } } @@ -190,10 +187,8 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } test("file input stream - wildcard") { - var testDir: File = null - try { + withTempDir { testDir => val batchDuration = Seconds(2) - testDir = Utils.createTempDir() val testSubDir1 = Utils.createDirectory(testDir.toString, "tmp1") val testSubDir2 = Utils.createDirectory(testDir.toString, "tmp2") @@ -221,12 +216,12 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // not enough to trigger a batch clock.advance(batchDuration.milliseconds / 2) - def createFileAndAdvenceTime(data: Int, dir: File): Unit = { + def createFileAndAdvanceTime(data: Int, dir: File): Unit = { val file = new File(testSubDir1, data.toString) Files.write(data + "\n", file, StandardCharsets.UTF_8) assert(file.setLastModified(clock.getTimeMillis())) assert(file.lastModified === clock.getTimeMillis()) - logInfo("Created file " + file) + logInfo(s"Created file $file") // Advance the clock after creating the file to avoid a race when // setting its modification time clock.advance(batchDuration.milliseconds) @@ -236,18 +231,85 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } // Over time, create files in the temp directory 1 val input1 = Seq(1, 2, 3, 4, 5) - input1.foreach(i => createFileAndAdvenceTime(i, testSubDir1)) + input1.foreach(i => createFileAndAdvanceTime(i, testSubDir1)) // Over time, create files in the temp directory 1 val input2 = Seq(6, 7, 8, 9, 10) - input2.foreach(i => createFileAndAdvenceTime(i, testSubDir2)) + input2.foreach(i => createFileAndAdvanceTime(i, testSubDir2)) // Verify that all the files have been read val expectedOutput = (input1 ++ input2).map(_.toString).toSet assert(outputQueue.asScala.flatten.toSet === expectedOutput) } - } finally { - if (testDir != null) Utils.deleteRecursively(testDir) + } + } + + test("Modified files are correctly detected.") { + withTempDir { testDir => + val batchDuration = Seconds(2) + val durationMs = batchDuration.milliseconds + val testPath = new Path(testDir.toURI) + val streamDir = new Path(testPath, "streaming") + val streamGlobPath = new Path(streamDir, "sub*") + val generatedDir = new Path(testPath, "generated") + val generatedSubDir = new Path(generatedDir, "subdir") + val renamedSubDir = new Path(streamDir, "subdir") + + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val sparkContext = ssc.sparkContext + val hc = sparkContext.hadoopConfiguration + val fs = FileSystem.get(testPath.toUri, hc) + + fs.delete(testPath, true) + fs.mkdirs(testPath) + fs.mkdirs(streamDir) + fs.mkdirs(generatedSubDir) + + def write(path: Path, text: String): Unit = { + val out = fs.create(path, true) + IOUtils.write(text, out) + out.close() + } + + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val existingFile = new Path(generatedSubDir, "existing") + write(existingFile, "existing\n") + val status = fs.getFileStatus(existingFile) + clock.setTime(status.getModificationTime + durationMs) + val batchCounter = new BatchCounter(ssc) + val fileStream = ssc.textFileStream(streamGlobPath.toUri.toString) + val outputQueue = new ConcurrentLinkedQueue[Seq[String]] + val outputStream = new TestOutputStream(fileStream, outputQueue) + outputStream.register() + + ssc.start() + clock.advance(durationMs) + eventually(eventuallyTimeout) { + assert(1 === batchCounter.getNumCompletedBatches) + } + // create and rename the file + // put a file into the generated directory + val textPath = new Path(generatedSubDir, "renamed.txt") + write(textPath, "renamed\n") + val now = clock.getTimeMillis() + val modTime = now + durationMs / 2 + fs.setTimes(textPath, modTime, modTime) + val textFilestatus = fs.getFileStatus(existingFile) + assert(textFilestatus.getModificationTime < now + durationMs) + + // rename the directory under the path being scanned + fs.rename(generatedSubDir, renamedSubDir) + + // move forward one window + clock.advance(durationMs) + // await the next scan completing + eventually(eventuallyTimeout) { + assert(2 === batchCounter.getNumCompletedBatches) + } + // verify that the "renamed" file is found, but not the "existing" one which is out of + // the window + assert(Set("renamed") === outputQueue.asScala.flatten.toSet) + } } } @@ -416,10 +478,8 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } def testFileStream(newFilesOnly: Boolean) { - var testDir: File = null - try { + withTempDir { testDir => val batchDuration = Seconds(2) - testDir = Utils.createTempDir() // Create a file that exists before the StreamingContext is created: val existingFile = new File(testDir, "0") Files.write("0\n", existingFile, StandardCharsets.UTF_8) @@ -466,8 +526,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } assert(outputQueue.asScala.flatten.toSet === expectedOutput) } - } finally { - if (testDir != null) Utils.deleteRecursively(testDir) } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index dbab70886102d..ada494eb897f3 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming -import java.io.{IOException, ObjectInputStream} +import java.io.{File, IOException, ObjectInputStream} import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.JavaConverters._ @@ -557,4 +557,16 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { verifyOutput[W](output.toSeq, expectedOutput, useSet) } } + + /** + * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` + * returns. + * (originally from `SqlTestUtils`.) + * @todo Probably this method should be moved to a more general place + */ + protected def withTempDir(f: File => Unit): Unit = { + val dir = Utils.createTempDir().getCanonicalFile + try f(dir) finally Utils.deleteRecursively(dir) + } + } From 85a93595d505ff40971f3c797b43e3de6e5a7760 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 4 Oct 2018 18:46:16 -0700 Subject: [PATCH 1761/2461] [SPARK-25609][TESTS] Reduce time of test for SPARK-22226 ## What changes were proposed in this pull request? The PR changes the test introduced for SPARK-22226, so that we don't run analysis and optimization on the plan. The scope of the test is code generation and running the above mentioned operation is expensive and useless for the test. The UT was also moved to the `CodeGenerationSuite` which is a better place given the scope of the test. ## How was this patch tested? running the UT before SPARK-22226 fails, after it passes. The execution time is about 50% the original one. On my laptop this means that the test now runs in about 23 seconds (instead of 50 seconds). Closes #22629 from mgaido91/SPARK-25609. Authored-by: Marco Gaido Signed-off-by: gatorsmile --- .../catalyst/expressions/CodeGenerationSuite.scala | 10 ++++++++++ .../scala/org/apache/spark/sql/DataFrameSuite.scala | 12 ------------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index c383eec3d56b4..5e8113ac8658e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -346,6 +346,16 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { projection(row) } + test("SPARK-22226: splitExpressions should not generate codes beyond 64KB") { + val colNumber = 10000 + val attrs = (1 to colNumber).map(colIndex => AttributeReference(s"_$colIndex", IntegerType)()) + val lit = Literal(1000) + val exprs = attrs.flatMap { a => + Seq(If(lit < a, lit, a), sqrt(a)) + } + UnsafeProjection.create(exprs, attrs) + } + test("SPARK-22543: split large predicates into blocks due to JVM code size limit") { val length = 600 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 279b7b8d49f52..c0b277f76ae68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2408,18 +2408,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2))) } - test("SPARK-22226: splitExpressions should not generate codes beyond 64KB") { - val colNumber = 10000 - val input = spark.range(2).rdd.map(_ => Row(1 to colNumber: _*)) - val df = sqlContext.createDataFrame(input, StructType( - (1 to colNumber).map(colIndex => StructField(s"_$colIndex", IntegerType, false)))) - val newCols = (1 to colNumber).flatMap { colIndex => - Seq(expr(s"if(1000 < _$colIndex, 1000, _$colIndex)"), - expr(s"sqrt(_$colIndex)")) - } - df.select(newCols: _*).collect() - } - test("SPARK-22271: mean overflows and returns null for some decimal variables") { val d = 0.034567890 val df = Seq(d, d, d, d, d, d, d, d, d, d).toDF("DecimalCol") From f27d96b9f35799bf7ecc850effbfdb0bf7b237ab Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 4 Oct 2018 18:52:28 -0700 Subject: [PATCH 1762/2461] [SPARK-25606][TEST] Reduce DateExpressionsSuite test time costs in Jenkins ## What changes were proposed in this pull request? Reduce `DateExpressionsSuite.Hour` test time costs in Jenkins by reduce iteration times. ## How was this patch tested? Manual tests on my local machine. before: ``` - Hour (34 seconds, 54 milliseconds) ``` after: ``` - Hour (2 seconds, 697 milliseconds) ``` Closes #22632 from wangyum/SPARK-25606. Authored-by: Yuming Wang Signed-off-by: gatorsmile --- .../sql/catalyst/expressions/DateExpressionsSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 63b24fb9eb13a..c9d733726ff2c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -273,9 +273,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { val timeZoneId = Option(tz.getID) c.setTimeZone(tz) - (0 to 24).foreach { h => - (0 to 60 by 15).foreach { m => - (0 to 60 by 15).foreach { s => + (0 to 24 by 6).foreach { h => + (0 to 60 by 30).foreach { m => + (0 to 60 by 30).foreach { s => c.set(2015, 18, 3, h, m, s) checkEvaluation( Hour(Literal(new Timestamp(c.getTimeInMillis)), timeZoneId), From 8113b9c96601d8af5b1cbc453630c648a5d45550 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 4 Oct 2018 18:54:46 -0700 Subject: [PATCH 1763/2461] [SPARK-25605][TESTS] Run cast string to timestamp tests for a subset of timezones ## What changes were proposed in this pull request? The test `cast string to timestamp` used to run for all time zones. So it run for more than 600 times. Running the tests for a significant subset of time zones is probably good enough and doing this in a randomized manner enforces anyway that we are going to test all time zones in different runs. ## How was this patch tested? the test time reduces to 11 seconds from more than 2 minutes Closes #22631 from mgaido91/SPARK-25605. Authored-by: Marco Gaido Signed-off-by: gatorsmile --- .../org/apache/spark/sql/catalyst/expressions/CastSuite.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index d9f32c000a885..90c0bf7d8b3d7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.util.{Calendar, Locale, TimeZone} +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow @@ -110,7 +112,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cast string to timestamp") { - for (tz <- ALL_TIMEZONES) { + for (tz <- Random.shuffle(ALL_TIMEZONES).take(50)) { def checkCastStringToTimestamp(str: String, expected: Timestamp): Unit = { checkEvaluation(cast(Literal(str), TimestampType, Option(tz.getID)), expected) } From 44c1e1ab1c26560371831b1593f96f30344c4363 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Fri, 5 Oct 2018 02:58:25 +0100 Subject: [PATCH 1764/2461] [SPARK-25408] Move to mode ideomatic Java8 While working on another PR, I noticed that there is quite some legacy Java in there that can be beautified. For example the use og features from Java8, such as: - Collection libraries - Try-with-resource blocks No code has been changed What are your thoughts on this? This makes code easier to read, and using try-with-resource makes is less likely to forget to close something. ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22399 from Fokko/SPARK-25408. Authored-by: Fokko Driesprong Signed-off-by: Sean Owen --- .../spark/util/kvstore/KVStoreSerializer.java | 10 +- .../spark/util/kvstore/LevelDBSuite.java | 2 +- .../network/ChunkFetchIntegrationSuite.java | 53 ++++----- .../shuffle/ShuffleIndexInformation.java | 8 +- .../ExternalShuffleBlockResolverSuite.java | 22 ++-- .../ExternalShuffleIntegrationSuite.java | 51 ++++---- .../shuffle/ExternalShuffleSecuritySuite.java | 15 ++- .../spark/util/sketch/CountMinSketch.java | 7 +- .../spark/util/sketch/CountMinSketchImpl.java | 8 +- .../apache/spark/io/ReadAheadInputStream.java | 102 ++++++++-------- .../sort/BypassMergeSortShuffleWriter.java | 6 +- .../shuffle/sort/ShuffleExternalSorter.java | 63 +++++----- .../org/apache/spark/JavaJdbcRDDSuite.java | 28 ++--- .../sort/UnsafeShuffleWriterSuite.java | 14 +-- .../test/org/apache/spark/JavaAPISuite.java | 27 +++-- .../expressions/RowBasedKeyValueBatch.java | 3 +- .../RowBasedKeyValueBatchSuite.java | 110 +++++++----------- .../apache/hive/service/cli/CLIService.java | 9 +- .../cli/operation/OperationManager.java | 8 +- 19 files changed, 243 insertions(+), 303 deletions(-) diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreSerializer.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreSerializer.java index bd8d9486acde5..771a9541bb349 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreSerializer.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreSerializer.java @@ -54,11 +54,8 @@ public final byte[] serialize(Object o) throws Exception { return ((String) o).getBytes(UTF_8); } else { ByteArrayOutputStream bytes = new ByteArrayOutputStream(); - GZIPOutputStream out = new GZIPOutputStream(bytes); - try { + try (GZIPOutputStream out = new GZIPOutputStream(bytes)) { mapper.writeValue(out, o); - } finally { - out.close(); } return bytes.toByteArray(); } @@ -69,11 +66,8 @@ public final T deserialize(byte[] data, Class klass) throws Exception { if (klass.equals(String.class)) { return (T) new String(data, UTF_8); } else { - GZIPInputStream in = new GZIPInputStream(new ByteArrayInputStream(data)); - try { + try (GZIPInputStream in = new GZIPInputStream(new ByteArrayInputStream(data))) { return mapper.readValue(in, klass); - } finally { - in.close(); } } } diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java index 205f7df87c5bc..39a952f2b0df9 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java @@ -217,7 +217,7 @@ public void testSkip() throws Exception { public void testNegativeIndexValues() throws Exception { List expected = Arrays.asList(-100, -50, 0, 50, 100); - expected.stream().forEach(i -> { + expected.forEach(i -> { try { db.write(createCustomType1(i)); } catch (Exception e) { diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 824482af08dd4..9656a9aba6291 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -143,37 +143,38 @@ public void releaseBuffers() { } private FetchResult fetchChunks(List chunkIndices) throws Exception { - TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - final Semaphore sem = new Semaphore(0); - final FetchResult res = new FetchResult(); - res.successChunks = Collections.synchronizedSet(new HashSet()); - res.failedChunks = Collections.synchronizedSet(new HashSet()); - res.buffers = Collections.synchronizedList(new LinkedList()); - ChunkReceivedCallback callback = new ChunkReceivedCallback() { - @Override - public void onSuccess(int chunkIndex, ManagedBuffer buffer) { - buffer.retain(); - res.successChunks.add(chunkIndex); - res.buffers.add(buffer); - sem.release(); - } + try (TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort())) { + final Semaphore sem = new Semaphore(0); - @Override - public void onFailure(int chunkIndex, Throwable e) { - res.failedChunks.add(chunkIndex); - sem.release(); - } - }; + res.successChunks = Collections.synchronizedSet(new HashSet()); + res.failedChunks = Collections.synchronizedSet(new HashSet()); + res.buffers = Collections.synchronizedList(new LinkedList()); - for (int chunkIndex : chunkIndices) { - client.fetchChunk(STREAM_ID, chunkIndex, callback); - } - if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) { - fail("Timeout getting response from the server"); + ChunkReceivedCallback callback = new ChunkReceivedCallback() { + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + buffer.retain(); + res.successChunks.add(chunkIndex); + res.buffers.add(buffer); + sem.release(); + } + + @Override + public void onFailure(int chunkIndex, Throwable e) { + res.failedChunks.add(chunkIndex); + sem.release(); + } + }; + + for (int chunkIndex : chunkIndices) { + client.fetchChunk(STREAM_ID, chunkIndex, callback); + } + if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } } - client.close(); return res; } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java index 386738ece51a6..371149bef3974 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java @@ -37,14 +37,8 @@ public ShuffleIndexInformation(File indexFile) throws IOException { size = (int)indexFile.length(); ByteBuffer buffer = ByteBuffer.allocate(size); offsets = buffer.asLongBuffer(); - DataInputStream dis = null; - try { - dis = new DataInputStream(Files.newInputStream(indexFile.toPath())); + try (DataInputStream dis = new DataInputStream(Files.newInputStream(indexFile.toPath()))) { dis.readFully(buffer.array()); - } finally { - if (dis != null) { - dis.close(); - } } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index d2072a54fa415..44bc25a86b363 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -98,19 +98,15 @@ public void testSortShuffleBlocks() throws IOException { resolver.registerExecutor("app0", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); - InputStream block0Stream = - resolver.getBlockData("app0", "exec0", 0, 0, 0).createInputStream(); - String block0 = CharStreams.toString( - new InputStreamReader(block0Stream, StandardCharsets.UTF_8)); - block0Stream.close(); - assertEquals(sortBlock0, block0); - - InputStream block1Stream = - resolver.getBlockData("app0", "exec0", 0, 0, 1).createInputStream(); - String block1 = CharStreams.toString( - new InputStreamReader(block1Stream, StandardCharsets.UTF_8)); - block1Stream.close(); - assertEquals(sortBlock1, block1); + try (InputStream block0Stream = resolver.getBlockData("app0", "exec0", 0, 0, 0).createInputStream()) { + String block0 = CharStreams.toString(new InputStreamReader(block0Stream, StandardCharsets.UTF_8)); + assertEquals(sortBlock0, block0); + } + + try (InputStream block1Stream = resolver.getBlockData("app0", "exec0", 0, 0, 1).createInputStream()) { + String block1 = CharStreams.toString(new InputStreamReader(block1Stream, StandardCharsets.UTF_8)); + assertEquals(sortBlock1, block1); + } } @Test diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index a6a1b8d0ac3f1..41bee401e2919 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -133,37 +133,38 @@ private FetchResult fetchBlocks( final Semaphore requestsRemaining = new Semaphore(0); - ExternalShuffleClient client = new ExternalShuffleClient(clientConf, null, false, 5000); - client.init(APP_ID); - client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, - new BlockFetchingListener() { - @Override - public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { - synchronized (this) { - if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { - data.retain(); - res.successBlocks.add(blockId); - res.buffers.add(data); - requestsRemaining.release(); + try (ExternalShuffleClient client = new ExternalShuffleClient(clientConf, null, false, 5000)) { + client.init(APP_ID); + client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, + new BlockFetchingListener() { + @Override + public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { + synchronized (this) { + if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { + data.retain(); + res.successBlocks.add(blockId); + res.buffers.add(data); + requestsRemaining.release(); + } } } - } - - @Override - public void onBlockFetchFailure(String blockId, Throwable exception) { - synchronized (this) { - if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { - res.failedBlocks.add(blockId); - requestsRemaining.release(); + + @Override + public void onBlockFetchFailure(String blockId, Throwable exception) { + synchronized (this) { + if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { + res.failedBlocks.add(blockId); + requestsRemaining.release(); + } } } - } - }, null); + }, null + ); - if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { - fail("Timeout getting response from the server"); + if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } } - client.close(); return res; } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index 16bad9f1b319d..dafefaaa7d38f 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -96,14 +96,13 @@ private void validate(String appId, String secretKey, boolean encrypt) ImmutableMap.of("spark.authenticate.enableSaslEncryption", "true"))); } - ExternalShuffleClient client = - new ExternalShuffleClient(testConf, new TestSecretKeyHolder(appId, secretKey), true, 5000); - client.init(appId); - // Registration either succeeds or throws an exception. - client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0", - new ExecutorShuffleInfo(new String[0], 0, - "org.apache.spark.shuffle.sort.SortShuffleManager")); - client.close(); + try (ExternalShuffleClient client = + new ExternalShuffleClient(testConf, new TestSecretKeyHolder(appId, secretKey), true, 5000)) { + client.init(appId); + // Registration either succeeds or throws an exception. + client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0", + new ExecutorShuffleInfo(new String[0], 0, "org.apache.spark.shuffle.sort.SortShuffleManager")); + } } /** Provides a secret key holder which always returns the given secret key, for a single appId. */ diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index f7c22dddb8cc0..06a248c9a27c2 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -191,10 +191,9 @@ public static CountMinSketch readFrom(InputStream in) throws IOException { * Reads in a {@link CountMinSketch} from a byte array. */ public static CountMinSketch readFrom(byte[] bytes) throws IOException { - InputStream in = new ByteArrayInputStream(bytes); - CountMinSketch cms = readFrom(in); - in.close(); - return cms; + try (InputStream in = new ByteArrayInputStream(bytes)) { + return readFrom(in); + } } /** diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index fd1906d2e5ae9..b78c1677a1213 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -322,10 +322,10 @@ public void writeTo(OutputStream out) throws IOException { @Override public byte[] toByteArray() throws IOException { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - writeTo(out); - out.close(); - return out.toByteArray(); + try (ByteArrayOutputStream out = new ByteArrayOutputStream()) { + writeTo(out); + return out.toByteArray(); + } } public static CountMinSketchImpl readFrom(InputStream in) throws IOException { diff --git a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java index 0cced9e222952..2e18715b600e0 100644 --- a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java +++ b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java @@ -135,62 +135,58 @@ private void readAsync() throws IOException { } finally { stateChangeLock.unlock(); } - executorService.execute(new Runnable() { - - @Override - public void run() { - stateChangeLock.lock(); - try { - if (isClosed) { - readInProgress = false; - return; - } - // Flip this so that the close method will not close the underlying input stream when we - // are reading. - isReading = true; - } finally { - stateChangeLock.unlock(); + executorService.execute(() -> { + stateChangeLock.lock(); + try { + if (isClosed) { + readInProgress = false; + return; } + // Flip this so that the close method will not close the underlying input stream when we + // are reading. + isReading = true; + } finally { + stateChangeLock.unlock(); + } - // Please note that it is safe to release the lock and read into the read ahead buffer - // because either of following two conditions will hold - 1. The active buffer has - // data available to read so the reader will not read from the read ahead buffer. - // 2. This is the first time read is called or the active buffer is exhausted, - // in that case the reader waits for this async read to complete. - // So there is no race condition in both the situations. - int read = 0; - int off = 0, len = arr.length; - Throwable exception = null; - try { - // try to fill the read ahead buffer. - // if a reader is waiting, possibly return early. - do { - read = underlyingInputStream.read(arr, off, len); - if (read <= 0) break; - off += read; - len -= read; - } while (len > 0 && !isWaiting.get()); - } catch (Throwable ex) { - exception = ex; - if (ex instanceof Error) { - // `readException` may not be reported to the user. Rethrow Error to make sure at least - // The user can see Error in UncaughtExceptionHandler. - throw (Error) ex; - } - } finally { - stateChangeLock.lock(); - readAheadBuffer.limit(off); - if (read < 0 || (exception instanceof EOFException)) { - endOfStream = true; - } else if (exception != null) { - readAborted = true; - readException = exception; - } - readInProgress = false; - signalAsyncReadComplete(); - stateChangeLock.unlock(); - closeUnderlyingInputStreamIfNecessary(); + // Please note that it is safe to release the lock and read into the read ahead buffer + // because either of following two conditions will hold - 1. The active buffer has + // data available to read so the reader will not read from the read ahead buffer. + // 2. This is the first time read is called or the active buffer is exhausted, + // in that case the reader waits for this async read to complete. + // So there is no race condition in both the situations. + int read = 0; + int off = 0, len = arr.length; + Throwable exception = null; + try { + // try to fill the read ahead buffer. + // if a reader is waiting, possibly return early. + do { + read = underlyingInputStream.read(arr, off, len); + if (read <= 0) break; + off += read; + len -= read; + } while (len > 0 && !isWaiting.get()); + } catch (Throwable ex) { + exception = ex; + if (ex instanceof Error) { + // `readException` may not be reported to the user. Rethrow Error to make sure at least + // The user can see Error in UncaughtExceptionHandler. + throw (Error) ex; } + } finally { + stateChangeLock.lock(); + readAheadBuffer.limit(off); + if (read < 0 || (exception instanceof EOFException)) { + endOfStream = true; + } else if (exception != null) { + readAborted = true; + readException = exception; + } + readInProgress = false; + signalAsyncReadComplete(); + stateChangeLock.unlock(); + closeUnderlyingInputStreamIfNecessary(); } }); } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 323a5d3c52831..abe027f79d7e6 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -152,9 +152,9 @@ public void write(Iterator> records) throws IOException { } for (int i = 0; i < numPartitions; i++) { - final DiskBlockObjectWriter writer = partitionWriters[i]; - partitionWriterSegments[i] = writer.commitAndGet(); - writer.close(); + try (final DiskBlockObjectWriter writer = partitionWriters[i]) { + partitionWriterSegments[i] = writer.commitAndGet(); + } } File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index c7d2db4217d96..ad660741dcbac 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -181,42 +181,43 @@ private void writeSortedFile(boolean isLastFile) { // around this, we pass a dummy no-op serializer. final SerializerInstance ser = DummySerializerInstance.INSTANCE; - final DiskBlockObjectWriter writer = - blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse); - int currentPartition = -1; - final int uaoSize = UnsafeAlignedOffset.getUaoSize(); - while (sortedRecords.hasNext()) { - sortedRecords.loadNext(); - final int partition = sortedRecords.packedRecordPointer.getPartitionId(); - assert (partition >= currentPartition); - if (partition != currentPartition) { - // Switch to the new partition - if (currentPartition != -1) { - final FileSegment fileSegment = writer.commitAndGet(); - spillInfo.partitionLengths[currentPartition] = fileSegment.length(); + final FileSegment committedSegment; + try (final DiskBlockObjectWriter writer = + blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse)) { + + final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + while (sortedRecords.hasNext()) { + sortedRecords.loadNext(); + final int partition = sortedRecords.packedRecordPointer.getPartitionId(); + assert (partition >= currentPartition); + if (partition != currentPartition) { + // Switch to the new partition + if (currentPartition != -1) { + final FileSegment fileSegment = writer.commitAndGet(); + spillInfo.partitionLengths[currentPartition] = fileSegment.length(); + } + currentPartition = partition; } - currentPartition = partition; - } - final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); - final Object recordPage = taskMemoryManager.getPage(recordPointer); - final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer); - int dataRemaining = UnsafeAlignedOffset.getSize(recordPage, recordOffsetInPage); - long recordReadPosition = recordOffsetInPage + uaoSize; // skip over record length - while (dataRemaining > 0) { - final int toTransfer = Math.min(diskWriteBufferSize, dataRemaining); - Platform.copyMemory( - recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer); - writer.write(writeBuffer, 0, toTransfer); - recordReadPosition += toTransfer; - dataRemaining -= toTransfer; + final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); + final Object recordPage = taskMemoryManager.getPage(recordPointer); + final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer); + int dataRemaining = UnsafeAlignedOffset.getSize(recordPage, recordOffsetInPage); + long recordReadPosition = recordOffsetInPage + uaoSize; // skip over record length + while (dataRemaining > 0) { + final int toTransfer = Math.min(diskWriteBufferSize, dataRemaining); + Platform.copyMemory( + recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer); + writer.write(writeBuffer, 0, toTransfer); + recordReadPosition += toTransfer; + dataRemaining -= toTransfer; + } + writer.recordWritten(); } - writer.recordWritten(); - } - final FileSegment committedSegment = writer.commitAndGet(); - writer.close(); + committedSegment = writer.commitAndGet(); + } // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted, // then the file might be empty. Note that it might be better to avoid calling // writeSortedFile() in that case. diff --git a/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java b/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java index a6589d2898144..c35661eed9751 100644 --- a/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java +++ b/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java @@ -39,30 +39,26 @@ public void setUp() throws ClassNotFoundException, SQLException { sc = new JavaSparkContext("local", "JavaAPISuite"); Class.forName("org.apache.derby.jdbc.EmbeddedDriver"); - Connection connection = - DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb;create=true"); - try { - Statement create = connection.createStatement(); - create.execute( - "CREATE TABLE FOO(" + - "ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1)," + - "DATA INTEGER)"); - create.close(); + try (Connection connection = DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb;create=true")) { + + try (Statement create = connection.createStatement()) { + create.execute( + "CREATE TABLE FOO(ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1)," + + "DATA INTEGER)"); + } - PreparedStatement insert = connection.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)"); - for (int i = 1; i <= 100; i++) { - insert.setInt(1, i * 2); - insert.executeUpdate(); + try (PreparedStatement insert = connection.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)")) { + for (int i = 1; i <= 100; i++) { + insert.setInt(1, i * 2); + insert.executeUpdate(); + } } - insert.close(); } catch (SQLException e) { // If table doesn't exist... if (e.getSQLState().compareTo("X0Y32") != 0) { throw e; } - } finally { - connection.close(); } } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 0d5c5ea7903e9..a07d0e84ea854 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -186,14 +186,14 @@ private List> readRecordsFromFile() throws IOException { if (conf.getBoolean("spark.shuffle.compress", true)) { in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in); } - DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in); - Iterator> records = recordsStream.asKeyValueIterator(); - while (records.hasNext()) { - Tuple2 record = records.next(); - assertEquals(i, hashPartitioner.getPartition(record._1())); - recordsList.add(record); + try (DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in)) { + Iterator> records = recordsStream.asKeyValueIterator(); + while (records.hasNext()) { + Tuple2 record = records.next(); + assertEquals(i, hashPartitioner.getPartition(record._1())); + recordsList.add(record); + } } - recordsStream.close(); startOffset += partitionSize; } } diff --git a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java index 01b5fb7b46684..3992ab7049bdd 100644 --- a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java @@ -997,10 +997,10 @@ public void binaryFiles() throws Exception { FileOutputStream fos1 = new FileOutputStream(file1); - FileChannel channel1 = fos1.getChannel(); - ByteBuffer bbuf = ByteBuffer.wrap(content1); - channel1.write(bbuf); - channel1.close(); + try (FileChannel channel1 = fos1.getChannel()) { + ByteBuffer bbuf = ByteBuffer.wrap(content1); + channel1.write(bbuf); + } JavaPairRDD readRDD = sc.binaryFiles(tempDirName, 3); List> result = readRDD.collect(); for (Tuple2 res : result) { @@ -1018,10 +1018,10 @@ public void binaryFilesCaching() throws Exception { FileOutputStream fos1 = new FileOutputStream(file1); - FileChannel channel1 = fos1.getChannel(); - ByteBuffer bbuf = ByteBuffer.wrap(content1); - channel1.write(bbuf); - channel1.close(); + try (FileChannel channel1 = fos1.getChannel()) { + ByteBuffer bbuf = ByteBuffer.wrap(content1); + channel1.write(bbuf); + } JavaPairRDD readRDD = sc.binaryFiles(tempDirName).cache(); readRDD.foreach(pair -> pair._2().toArray()); // force the file to read @@ -1042,13 +1042,12 @@ public void binaryRecords() throws Exception { FileOutputStream fos1 = new FileOutputStream(file1); - FileChannel channel1 = fos1.getChannel(); - - for (int i = 0; i < numOfCopies; i++) { - ByteBuffer bbuf = ByteBuffer.wrap(content1); - channel1.write(bbuf); + try (FileChannel channel1 = fos1.getChannel()) { + for (int i = 0; i < numOfCopies; i++) { + ByteBuffer bbuf = ByteBuffer.wrap(content1); + channel1.write(bbuf); + } } - channel1.close(); JavaRDD readRDD = sc.binaryRecords(tempDirName, content1.length); assertEquals(numOfCopies,readRDD.count()); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java index 551443a11298b..460513816dfd9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.catalyst.expressions; +import java.io.Closeable; import java.io.IOException; import org.apache.spark.memory.MemoryConsumer; @@ -45,7 +46,7 @@ * page requires an average size for key value pairs to be larger than 1024 bytes. * */ -public abstract class RowBasedKeyValueBatch extends MemoryConsumer { +public abstract class RowBasedKeyValueBatch extends MemoryConsumer implements Closeable { protected final Logger logger = LoggerFactory.getLogger(RowBasedKeyValueBatch.class); private static final int DEFAULT_CAPACITY = 1 << 16; diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java index 2da87113c6229..ef02f0ae72686 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java @@ -123,9 +123,8 @@ public void tearDown() { @Test public void emptyBatch() throws Exception { - RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY); - try { + try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { Assert.assertEquals(0, batch.numRows()); try { batch.getKeyRow(-1); @@ -152,31 +151,24 @@ public void emptyBatch() throws Exception { // Expected exception; do nothing. } Assert.assertFalse(batch.rowIterator().next()); - } finally { - batch.close(); } } @Test - public void batchType() throws Exception { - RowBasedKeyValueBatch batch1 = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY); - RowBasedKeyValueBatch batch2 = RowBasedKeyValueBatch.allocate(fixedKeySchema, + public void batchType() { + try (RowBasedKeyValueBatch batch1 = RowBasedKeyValueBatch.allocate(keySchema, valueSchema, taskMemoryManager, DEFAULT_CAPACITY); - try { + RowBasedKeyValueBatch batch2 = RowBasedKeyValueBatch.allocate(fixedKeySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { Assert.assertEquals(batch1.getClass(), VariableLengthRowBasedKeyValueBatch.class); Assert.assertEquals(batch2.getClass(), FixedLengthRowBasedKeyValueBatch.class); - } finally { - batch1.close(); - batch2.close(); } } @Test public void setAndRetrieve() { - RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY); - try { + try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { UnsafeRow ret1 = appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1)); Assert.assertTrue(checkValue(ret1, 1, 1)); UnsafeRow ret2 = appendRow(batch, makeKeyRow(2, "B"), makeValueRow(2, 2)); @@ -204,33 +196,27 @@ public void setAndRetrieve() { } catch (AssertionError e) { // Expected exception; do nothing. } - } finally { - batch.close(); } } @Test public void setUpdateAndRetrieve() { - RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY); - try { + try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1)); Assert.assertEquals(1, batch.numRows()); UnsafeRow retrievedValue = batch.getValueRow(0); updateValueRow(retrievedValue, 2, 2); UnsafeRow retrievedValue2 = batch.getValueRow(0); Assert.assertTrue(checkValue(retrievedValue2, 2, 2)); - } finally { - batch.close(); } } @Test public void iteratorTest() throws Exception { - RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY); - try { + try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1)); appendRow(batch, makeKeyRow(2, "B"), makeValueRow(2, 2)); appendRow(batch, makeKeyRow(3, "C"), makeValueRow(3, 3)); @@ -253,16 +239,13 @@ public void iteratorTest() throws Exception { Assert.assertTrue(checkKey(key3, 3, "C")); Assert.assertTrue(checkValue(value3, 3, 3)); Assert.assertFalse(iterator.next()); - } finally { - batch.close(); } } @Test public void fixedLengthTest() throws Exception { - RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(fixedKeySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY); - try { + try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(fixedKeySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { appendRow(batch, makeKeyRow(11, 11), makeValueRow(1, 1)); appendRow(batch, makeKeyRow(22, 22), makeValueRow(2, 2)); appendRow(batch, makeKeyRow(33, 33), makeValueRow(3, 3)); @@ -293,16 +276,13 @@ public void fixedLengthTest() throws Exception { Assert.assertTrue(checkKey(key3, 33, 33)); Assert.assertTrue(checkValue(value3, 3, 3)); Assert.assertFalse(iterator.next()); - } finally { - batch.close(); } } @Test public void appendRowUntilExceedingCapacity() throws Exception { - RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, 10); - try { + try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, 10)) { UnsafeRow key = makeKeyRow(1, "A"); UnsafeRow value = makeValueRow(1, 1); for (int i = 0; i < 10; i++) { @@ -321,8 +301,6 @@ public void appendRowUntilExceedingCapacity() throws Exception { Assert.assertTrue(checkValue(value1, 1, 1)); } Assert.assertFalse(iterator.next()); - } finally { - batch.close(); } } @@ -330,9 +308,8 @@ public void appendRowUntilExceedingCapacity() throws Exception { public void appendRowUntilExceedingPageSize() throws Exception { // Use default size or spark.buffer.pageSize if specified int pageSizeToUse = (int) memoryManager.pageSizeBytes(); - RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, pageSizeToUse); //enough capacity - try { + try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, pageSizeToUse)) { UnsafeRow key = makeKeyRow(1, "A"); UnsafeRow value = makeValueRow(1, 1); int recordLength = 8 + key.getSizeInBytes() + value.getSizeInBytes() + 8; @@ -356,49 +333,44 @@ public void appendRowUntilExceedingPageSize() throws Exception { Assert.assertTrue(checkValue(value1, 1, 1)); } Assert.assertFalse(iterator.next()); - } finally { - batch.close(); } } @Test public void failureToAllocateFirstPage() throws Exception { memoryManager.limit(1024); - RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY); - try { + try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { UnsafeRow key = makeKeyRow(1, "A"); UnsafeRow value = makeValueRow(11, 11); UnsafeRow ret = appendRow(batch, key, value); Assert.assertNull(ret); Assert.assertFalse(batch.rowIterator().next()); - } finally { - batch.close(); } } @Test public void randomizedTest() { - RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY); - int numEntry = 100; - long[] expectedK1 = new long[numEntry]; - String[] expectedK2 = new String[numEntry]; - long[] expectedV1 = new long[numEntry]; - long[] expectedV2 = new long[numEntry]; - - for (int i = 0; i < numEntry; i++) { - long k1 = rand.nextLong(); - String k2 = getRandomString(rand.nextInt(256)); - long v1 = rand.nextLong(); - long v2 = rand.nextLong(); - appendRow(batch, makeKeyRow(k1, k2), makeValueRow(v1, v2)); - expectedK1[i] = k1; - expectedK2[i] = k2; - expectedV1[i] = v1; - expectedV2[i] = v2; - } - try { + try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { + int numEntry = 100; + long[] expectedK1 = new long[numEntry]; + String[] expectedK2 = new String[numEntry]; + long[] expectedV1 = new long[numEntry]; + long[] expectedV2 = new long[numEntry]; + + for (int i = 0; i < numEntry; i++) { + long k1 = rand.nextLong(); + String k2 = getRandomString(rand.nextInt(256)); + long v1 = rand.nextLong(); + long v2 = rand.nextLong(); + appendRow(batch, makeKeyRow(k1, k2), makeValueRow(v1, v2)); + expectedK1[i] = k1; + expectedK2[i] = k2; + expectedV1[i] = v1; + expectedV2[i] = v2; + } + for (int j = 0; j < 10000; j++) { int rowId = rand.nextInt(numEntry); if (rand.nextBoolean()) { @@ -410,8 +382,6 @@ public void randomizedTest() { Assert.assertTrue(checkValue(value, expectedV1[rowId], expectedV2[rowId])); } } - } finally { - batch.close(); } } } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIService.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIService.java index 791ddcbd2c5b6..3cbc2c42761e5 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIService.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIService.java @@ -146,18 +146,11 @@ public UserGroupInformation getHttpUGI() { public synchronized void start() { super.start(); // Initialize and test a connection to the metastore - IMetaStoreClient metastoreClient = null; - try { - metastoreClient = new HiveMetaStoreClient(hiveConf); + try (IMetaStoreClient metastoreClient = new HiveMetaStoreClient(hiveConf)) { metastoreClient.getDatabases("default"); } catch (Exception e) { throw new ServiceException("Unable to connect to MetaStore!", e); } - finally { - if (metastoreClient != null) { - metastoreClient.close(); - } - } } @Override diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java index 92c340a29c107..4a8779e07834d 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java @@ -197,11 +197,11 @@ public void cancelOperation(OperationHandle opHandle) throws HiveSQLException { } public void closeOperation(OperationHandle opHandle) throws HiveSQLException { - Operation operation = removeOperation(opHandle); - if (operation == null) { - throw new HiveSQLException("Operation does not exist!"); + try (Operation operation = removeOperation(opHandle)) { + if (operation == null) { + throw new HiveSQLException("Operation does not exist!"); + } } - operation.close(); } public TableSchema getOperationResultSetSchema(OperationHandle opHandle) From 5ae20cf1a96a33f5de4435fcfb55914d64466525 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 5 Oct 2018 11:03:41 +0800 Subject: [PATCH 1765/2461] Revert "[SPARK-25408] Move to mode ideomatic Java8" This reverts commit 44c1e1ab1c26560371831b1593f96f30344c4363. --- .../spark/util/kvstore/KVStoreSerializer.java | 10 +- .../spark/util/kvstore/LevelDBSuite.java | 2 +- .../network/ChunkFetchIntegrationSuite.java | 53 +++++---- .../shuffle/ShuffleIndexInformation.java | 8 +- .../ExternalShuffleBlockResolverSuite.java | 22 ++-- .../ExternalShuffleIntegrationSuite.java | 51 ++++---- .../shuffle/ExternalShuffleSecuritySuite.java | 15 +-- .../spark/util/sketch/CountMinSketch.java | 7 +- .../spark/util/sketch/CountMinSketchImpl.java | 8 +- .../apache/spark/io/ReadAheadInputStream.java | 102 ++++++++-------- .../sort/BypassMergeSortShuffleWriter.java | 6 +- .../shuffle/sort/ShuffleExternalSorter.java | 63 +++++----- .../org/apache/spark/JavaJdbcRDDSuite.java | 28 +++-- .../sort/UnsafeShuffleWriterSuite.java | 14 +-- .../test/org/apache/spark/JavaAPISuite.java | 27 ++--- .../expressions/RowBasedKeyValueBatch.java | 3 +- .../RowBasedKeyValueBatchSuite.java | 110 +++++++++++------- .../apache/hive/service/cli/CLIService.java | 9 +- .../cli/operation/OperationManager.java | 8 +- 19 files changed, 303 insertions(+), 243 deletions(-) diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreSerializer.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreSerializer.java index 771a9541bb349..bd8d9486acde5 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreSerializer.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreSerializer.java @@ -54,8 +54,11 @@ public final byte[] serialize(Object o) throws Exception { return ((String) o).getBytes(UTF_8); } else { ByteArrayOutputStream bytes = new ByteArrayOutputStream(); - try (GZIPOutputStream out = new GZIPOutputStream(bytes)) { + GZIPOutputStream out = new GZIPOutputStream(bytes); + try { mapper.writeValue(out, o); + } finally { + out.close(); } return bytes.toByteArray(); } @@ -66,8 +69,11 @@ public final T deserialize(byte[] data, Class klass) throws Exception { if (klass.equals(String.class)) { return (T) new String(data, UTF_8); } else { - try (GZIPInputStream in = new GZIPInputStream(new ByteArrayInputStream(data))) { + GZIPInputStream in = new GZIPInputStream(new ByteArrayInputStream(data)); + try { return mapper.readValue(in, klass); + } finally { + in.close(); } } } diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java index 39a952f2b0df9..205f7df87c5bc 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java @@ -217,7 +217,7 @@ public void testSkip() throws Exception { public void testNegativeIndexValues() throws Exception { List expected = Arrays.asList(-100, -50, 0, 50, 100); - expected.forEach(i -> { + expected.stream().forEach(i -> { try { db.write(createCustomType1(i)); } catch (Exception e) { diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 9656a9aba6291..824482af08dd4 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -143,38 +143,37 @@ public void releaseBuffers() { } private FetchResult fetchChunks(List chunkIndices) throws Exception { - final FetchResult res = new FetchResult(); - - try (TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort())) { - final Semaphore sem = new Semaphore(0); + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + final Semaphore sem = new Semaphore(0); - res.successChunks = Collections.synchronizedSet(new HashSet()); - res.failedChunks = Collections.synchronizedSet(new HashSet()); - res.buffers = Collections.synchronizedList(new LinkedList()); - - ChunkReceivedCallback callback = new ChunkReceivedCallback() { - @Override - public void onSuccess(int chunkIndex, ManagedBuffer buffer) { - buffer.retain(); - res.successChunks.add(chunkIndex); - res.buffers.add(buffer); - sem.release(); - } - - @Override - public void onFailure(int chunkIndex, Throwable e) { - res.failedChunks.add(chunkIndex); - sem.release(); - } - }; + final FetchResult res = new FetchResult(); + res.successChunks = Collections.synchronizedSet(new HashSet()); + res.failedChunks = Collections.synchronizedSet(new HashSet()); + res.buffers = Collections.synchronizedList(new LinkedList()); - for (int chunkIndex : chunkIndices) { - client.fetchChunk(STREAM_ID, chunkIndex, callback); + ChunkReceivedCallback callback = new ChunkReceivedCallback() { + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + buffer.retain(); + res.successChunks.add(chunkIndex); + res.buffers.add(buffer); + sem.release(); } - if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) { - fail("Timeout getting response from the server"); + + @Override + public void onFailure(int chunkIndex, Throwable e) { + res.failedChunks.add(chunkIndex); + sem.release(); } + }; + + for (int chunkIndex : chunkIndices) { + client.fetchChunk(STREAM_ID, chunkIndex, callback); + } + if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); } + client.close(); return res; } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java index 371149bef3974..386738ece51a6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java @@ -37,8 +37,14 @@ public ShuffleIndexInformation(File indexFile) throws IOException { size = (int)indexFile.length(); ByteBuffer buffer = ByteBuffer.allocate(size); offsets = buffer.asLongBuffer(); - try (DataInputStream dis = new DataInputStream(Files.newInputStream(indexFile.toPath()))) { + DataInputStream dis = null; + try { + dis = new DataInputStream(Files.newInputStream(indexFile.toPath())); dis.readFully(buffer.array()); + } finally { + if (dis != null) { + dis.close(); + } } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 44bc25a86b363..d2072a54fa415 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -98,15 +98,19 @@ public void testSortShuffleBlocks() throws IOException { resolver.registerExecutor("app0", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); - try (InputStream block0Stream = resolver.getBlockData("app0", "exec0", 0, 0, 0).createInputStream()) { - String block0 = CharStreams.toString(new InputStreamReader(block0Stream, StandardCharsets.UTF_8)); - assertEquals(sortBlock0, block0); - } - - try (InputStream block1Stream = resolver.getBlockData("app0", "exec0", 0, 0, 1).createInputStream()) { - String block1 = CharStreams.toString(new InputStreamReader(block1Stream, StandardCharsets.UTF_8)); - assertEquals(sortBlock1, block1); - } + InputStream block0Stream = + resolver.getBlockData("app0", "exec0", 0, 0, 0).createInputStream(); + String block0 = CharStreams.toString( + new InputStreamReader(block0Stream, StandardCharsets.UTF_8)); + block0Stream.close(); + assertEquals(sortBlock0, block0); + + InputStream block1Stream = + resolver.getBlockData("app0", "exec0", 0, 0, 1).createInputStream(); + String block1 = CharStreams.toString( + new InputStreamReader(block1Stream, StandardCharsets.UTF_8)); + block1Stream.close(); + assertEquals(sortBlock1, block1); } @Test diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 41bee401e2919..a6a1b8d0ac3f1 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -133,38 +133,37 @@ private FetchResult fetchBlocks( final Semaphore requestsRemaining = new Semaphore(0); - try (ExternalShuffleClient client = new ExternalShuffleClient(clientConf, null, false, 5000)) { - client.init(APP_ID); - client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, - new BlockFetchingListener() { - @Override - public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { - synchronized (this) { - if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { - data.retain(); - res.successBlocks.add(blockId); - res.buffers.add(data); - requestsRemaining.release(); - } + ExternalShuffleClient client = new ExternalShuffleClient(clientConf, null, false, 5000); + client.init(APP_ID); + client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, + new BlockFetchingListener() { + @Override + public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { + synchronized (this) { + if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { + data.retain(); + res.successBlocks.add(blockId); + res.buffers.add(data); + requestsRemaining.release(); } } - - @Override - public void onBlockFetchFailure(String blockId, Throwable exception) { - synchronized (this) { - if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { - res.failedBlocks.add(blockId); - requestsRemaining.release(); - } + } + + @Override + public void onBlockFetchFailure(String blockId, Throwable exception) { + synchronized (this) { + if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { + res.failedBlocks.add(blockId); + requestsRemaining.release(); } } - }, null - ); + } + }, null); - if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { - fail("Timeout getting response from the server"); - } + if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); } + client.close(); return res; } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index dafefaaa7d38f..16bad9f1b319d 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -96,13 +96,14 @@ private void validate(String appId, String secretKey, boolean encrypt) ImmutableMap.of("spark.authenticate.enableSaslEncryption", "true"))); } - try (ExternalShuffleClient client = - new ExternalShuffleClient(testConf, new TestSecretKeyHolder(appId, secretKey), true, 5000)) { - client.init(appId); - // Registration either succeeds or throws an exception. - client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0", - new ExecutorShuffleInfo(new String[0], 0, "org.apache.spark.shuffle.sort.SortShuffleManager")); - } + ExternalShuffleClient client = + new ExternalShuffleClient(testConf, new TestSecretKeyHolder(appId, secretKey), true, 5000); + client.init(appId); + // Registration either succeeds or throws an exception. + client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0", + new ExecutorShuffleInfo(new String[0], 0, + "org.apache.spark.shuffle.sort.SortShuffleManager")); + client.close(); } /** Provides a secret key holder which always returns the given secret key, for a single appId. */ diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 06a248c9a27c2..f7c22dddb8cc0 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -191,9 +191,10 @@ public static CountMinSketch readFrom(InputStream in) throws IOException { * Reads in a {@link CountMinSketch} from a byte array. */ public static CountMinSketch readFrom(byte[] bytes) throws IOException { - try (InputStream in = new ByteArrayInputStream(bytes)) { - return readFrom(in); - } + InputStream in = new ByteArrayInputStream(bytes); + CountMinSketch cms = readFrom(in); + in.close(); + return cms; } /** diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index b78c1677a1213..fd1906d2e5ae9 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -322,10 +322,10 @@ public void writeTo(OutputStream out) throws IOException { @Override public byte[] toByteArray() throws IOException { - try (ByteArrayOutputStream out = new ByteArrayOutputStream()) { - writeTo(out); - return out.toByteArray(); - } + ByteArrayOutputStream out = new ByteArrayOutputStream(); + writeTo(out); + out.close(); + return out.toByteArray(); } public static CountMinSketchImpl readFrom(InputStream in) throws IOException { diff --git a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java index 2e18715b600e0..0cced9e222952 100644 --- a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java +++ b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java @@ -135,58 +135,62 @@ private void readAsync() throws IOException { } finally { stateChangeLock.unlock(); } - executorService.execute(() -> { - stateChangeLock.lock(); - try { - if (isClosed) { - readInProgress = false; - return; - } - // Flip this so that the close method will not close the underlying input stream when we - // are reading. - isReading = true; - } finally { - stateChangeLock.unlock(); - } + executorService.execute(new Runnable() { - // Please note that it is safe to release the lock and read into the read ahead buffer - // because either of following two conditions will hold - 1. The active buffer has - // data available to read so the reader will not read from the read ahead buffer. - // 2. This is the first time read is called or the active buffer is exhausted, - // in that case the reader waits for this async read to complete. - // So there is no race condition in both the situations. - int read = 0; - int off = 0, len = arr.length; - Throwable exception = null; - try { - // try to fill the read ahead buffer. - // if a reader is waiting, possibly return early. - do { - read = underlyingInputStream.read(arr, off, len); - if (read <= 0) break; - off += read; - len -= read; - } while (len > 0 && !isWaiting.get()); - } catch (Throwable ex) { - exception = ex; - if (ex instanceof Error) { - // `readException` may not be reported to the user. Rethrow Error to make sure at least - // The user can see Error in UncaughtExceptionHandler. - throw (Error) ex; - } - } finally { + @Override + public void run() { stateChangeLock.lock(); - readAheadBuffer.limit(off); - if (read < 0 || (exception instanceof EOFException)) { - endOfStream = true; - } else if (exception != null) { - readAborted = true; - readException = exception; + try { + if (isClosed) { + readInProgress = false; + return; + } + // Flip this so that the close method will not close the underlying input stream when we + // are reading. + isReading = true; + } finally { + stateChangeLock.unlock(); + } + + // Please note that it is safe to release the lock and read into the read ahead buffer + // because either of following two conditions will hold - 1. The active buffer has + // data available to read so the reader will not read from the read ahead buffer. + // 2. This is the first time read is called or the active buffer is exhausted, + // in that case the reader waits for this async read to complete. + // So there is no race condition in both the situations. + int read = 0; + int off = 0, len = arr.length; + Throwable exception = null; + try { + // try to fill the read ahead buffer. + // if a reader is waiting, possibly return early. + do { + read = underlyingInputStream.read(arr, off, len); + if (read <= 0) break; + off += read; + len -= read; + } while (len > 0 && !isWaiting.get()); + } catch (Throwable ex) { + exception = ex; + if (ex instanceof Error) { + // `readException` may not be reported to the user. Rethrow Error to make sure at least + // The user can see Error in UncaughtExceptionHandler. + throw (Error) ex; + } + } finally { + stateChangeLock.lock(); + readAheadBuffer.limit(off); + if (read < 0 || (exception instanceof EOFException)) { + endOfStream = true; + } else if (exception != null) { + readAborted = true; + readException = exception; + } + readInProgress = false; + signalAsyncReadComplete(); + stateChangeLock.unlock(); + closeUnderlyingInputStreamIfNecessary(); } - readInProgress = false; - signalAsyncReadComplete(); - stateChangeLock.unlock(); - closeUnderlyingInputStreamIfNecessary(); } }); } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index abe027f79d7e6..323a5d3c52831 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -152,9 +152,9 @@ public void write(Iterator> records) throws IOException { } for (int i = 0; i < numPartitions; i++) { - try (final DiskBlockObjectWriter writer = partitionWriters[i]) { - partitionWriterSegments[i] = writer.commitAndGet(); - } + final DiskBlockObjectWriter writer = partitionWriters[i]; + partitionWriterSegments[i] = writer.commitAndGet(); + writer.close(); } File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index ad660741dcbac..c7d2db4217d96 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -181,43 +181,42 @@ private void writeSortedFile(boolean isLastFile) { // around this, we pass a dummy no-op serializer. final SerializerInstance ser = DummySerializerInstance.INSTANCE; - int currentPartition = -1; - final FileSegment committedSegment; - try (final DiskBlockObjectWriter writer = - blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse)) { - - final int uaoSize = UnsafeAlignedOffset.getUaoSize(); - while (sortedRecords.hasNext()) { - sortedRecords.loadNext(); - final int partition = sortedRecords.packedRecordPointer.getPartitionId(); - assert (partition >= currentPartition); - if (partition != currentPartition) { - // Switch to the new partition - if (currentPartition != -1) { - final FileSegment fileSegment = writer.commitAndGet(); - spillInfo.partitionLengths[currentPartition] = fileSegment.length(); - } - currentPartition = partition; - } + final DiskBlockObjectWriter writer = + blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse); - final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); - final Object recordPage = taskMemoryManager.getPage(recordPointer); - final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer); - int dataRemaining = UnsafeAlignedOffset.getSize(recordPage, recordOffsetInPage); - long recordReadPosition = recordOffsetInPage + uaoSize; // skip over record length - while (dataRemaining > 0) { - final int toTransfer = Math.min(diskWriteBufferSize, dataRemaining); - Platform.copyMemory( - recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer); - writer.write(writeBuffer, 0, toTransfer); - recordReadPosition += toTransfer; - dataRemaining -= toTransfer; + int currentPartition = -1; + final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + while (sortedRecords.hasNext()) { + sortedRecords.loadNext(); + final int partition = sortedRecords.packedRecordPointer.getPartitionId(); + assert (partition >= currentPartition); + if (partition != currentPartition) { + // Switch to the new partition + if (currentPartition != -1) { + final FileSegment fileSegment = writer.commitAndGet(); + spillInfo.partitionLengths[currentPartition] = fileSegment.length(); } - writer.recordWritten(); + currentPartition = partition; } - committedSegment = writer.commitAndGet(); + final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); + final Object recordPage = taskMemoryManager.getPage(recordPointer); + final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer); + int dataRemaining = UnsafeAlignedOffset.getSize(recordPage, recordOffsetInPage); + long recordReadPosition = recordOffsetInPage + uaoSize; // skip over record length + while (dataRemaining > 0) { + final int toTransfer = Math.min(diskWriteBufferSize, dataRemaining); + Platform.copyMemory( + recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer); + writer.write(writeBuffer, 0, toTransfer); + recordReadPosition += toTransfer; + dataRemaining -= toTransfer; + } + writer.recordWritten(); } + + final FileSegment committedSegment = writer.commitAndGet(); + writer.close(); // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted, // then the file might be empty. Note that it might be better to avoid calling // writeSortedFile() in that case. diff --git a/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java b/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java index c35661eed9751..a6589d2898144 100644 --- a/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java +++ b/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java @@ -39,26 +39,30 @@ public void setUp() throws ClassNotFoundException, SQLException { sc = new JavaSparkContext("local", "JavaAPISuite"); Class.forName("org.apache.derby.jdbc.EmbeddedDriver"); + Connection connection = + DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb;create=true"); - try (Connection connection = DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb;create=true")) { - - try (Statement create = connection.createStatement()) { - create.execute( - "CREATE TABLE FOO(ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1)," + - "DATA INTEGER)"); - } + try { + Statement create = connection.createStatement(); + create.execute( + "CREATE TABLE FOO(" + + "ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1)," + + "DATA INTEGER)"); + create.close(); - try (PreparedStatement insert = connection.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)")) { - for (int i = 1; i <= 100; i++) { - insert.setInt(1, i * 2); - insert.executeUpdate(); - } + PreparedStatement insert = connection.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)"); + for (int i = 1; i <= 100; i++) { + insert.setInt(1, i * 2); + insert.executeUpdate(); } + insert.close(); } catch (SQLException e) { // If table doesn't exist... if (e.getSQLState().compareTo("X0Y32") != 0) { throw e; } + } finally { + connection.close(); } } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index a07d0e84ea854..0d5c5ea7903e9 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -186,14 +186,14 @@ private List> readRecordsFromFile() throws IOException { if (conf.getBoolean("spark.shuffle.compress", true)) { in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in); } - try (DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in)) { - Iterator> records = recordsStream.asKeyValueIterator(); - while (records.hasNext()) { - Tuple2 record = records.next(); - assertEquals(i, hashPartitioner.getPartition(record._1())); - recordsList.add(record); - } + DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in); + Iterator> records = recordsStream.asKeyValueIterator(); + while (records.hasNext()) { + Tuple2 record = records.next(); + assertEquals(i, hashPartitioner.getPartition(record._1())); + recordsList.add(record); } + recordsStream.close(); startOffset += partitionSize; } } diff --git a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java index 3992ab7049bdd..01b5fb7b46684 100644 --- a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java @@ -997,10 +997,10 @@ public void binaryFiles() throws Exception { FileOutputStream fos1 = new FileOutputStream(file1); - try (FileChannel channel1 = fos1.getChannel()) { - ByteBuffer bbuf = ByteBuffer.wrap(content1); - channel1.write(bbuf); - } + FileChannel channel1 = fos1.getChannel(); + ByteBuffer bbuf = ByteBuffer.wrap(content1); + channel1.write(bbuf); + channel1.close(); JavaPairRDD readRDD = sc.binaryFiles(tempDirName, 3); List> result = readRDD.collect(); for (Tuple2 res : result) { @@ -1018,10 +1018,10 @@ public void binaryFilesCaching() throws Exception { FileOutputStream fos1 = new FileOutputStream(file1); - try (FileChannel channel1 = fos1.getChannel()) { - ByteBuffer bbuf = ByteBuffer.wrap(content1); - channel1.write(bbuf); - } + FileChannel channel1 = fos1.getChannel(); + ByteBuffer bbuf = ByteBuffer.wrap(content1); + channel1.write(bbuf); + channel1.close(); JavaPairRDD readRDD = sc.binaryFiles(tempDirName).cache(); readRDD.foreach(pair -> pair._2().toArray()); // force the file to read @@ -1042,12 +1042,13 @@ public void binaryRecords() throws Exception { FileOutputStream fos1 = new FileOutputStream(file1); - try (FileChannel channel1 = fos1.getChannel()) { - for (int i = 0; i < numOfCopies; i++) { - ByteBuffer bbuf = ByteBuffer.wrap(content1); - channel1.write(bbuf); - } + FileChannel channel1 = fos1.getChannel(); + + for (int i = 0; i < numOfCopies; i++) { + ByteBuffer bbuf = ByteBuffer.wrap(content1); + channel1.write(bbuf); } + channel1.close(); JavaRDD readRDD = sc.binaryRecords(tempDirName, content1.length); assertEquals(numOfCopies,readRDD.count()); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java index 460513816dfd9..551443a11298b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java @@ -16,7 +16,6 @@ */ package org.apache.spark.sql.catalyst.expressions; -import java.io.Closeable; import java.io.IOException; import org.apache.spark.memory.MemoryConsumer; @@ -46,7 +45,7 @@ * page requires an average size for key value pairs to be larger than 1024 bytes. * */ -public abstract class RowBasedKeyValueBatch extends MemoryConsumer implements Closeable { +public abstract class RowBasedKeyValueBatch extends MemoryConsumer { protected final Logger logger = LoggerFactory.getLogger(RowBasedKeyValueBatch.class); private static final int DEFAULT_CAPACITY = 1 << 16; diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java index ef02f0ae72686..2da87113c6229 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java @@ -123,8 +123,9 @@ public void tearDown() { @Test public void emptyBatch() throws Exception { - try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { + RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY); + try { Assert.assertEquals(0, batch.numRows()); try { batch.getKeyRow(-1); @@ -151,24 +152,31 @@ public void emptyBatch() throws Exception { // Expected exception; do nothing. } Assert.assertFalse(batch.rowIterator().next()); + } finally { + batch.close(); } } @Test - public void batchType() { - try (RowBasedKeyValueBatch batch1 = RowBasedKeyValueBatch.allocate(keySchema, + public void batchType() throws Exception { + RowBasedKeyValueBatch batch1 = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY); + RowBasedKeyValueBatch batch2 = RowBasedKeyValueBatch.allocate(fixedKeySchema, valueSchema, taskMemoryManager, DEFAULT_CAPACITY); - RowBasedKeyValueBatch batch2 = RowBasedKeyValueBatch.allocate(fixedKeySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { + try { Assert.assertEquals(batch1.getClass(), VariableLengthRowBasedKeyValueBatch.class); Assert.assertEquals(batch2.getClass(), FixedLengthRowBasedKeyValueBatch.class); + } finally { + batch1.close(); + batch2.close(); } } @Test public void setAndRetrieve() { - try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { + RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY); + try { UnsafeRow ret1 = appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1)); Assert.assertTrue(checkValue(ret1, 1, 1)); UnsafeRow ret2 = appendRow(batch, makeKeyRow(2, "B"), makeValueRow(2, 2)); @@ -196,27 +204,33 @@ public void setAndRetrieve() { } catch (AssertionError e) { // Expected exception; do nothing. } + } finally { + batch.close(); } } @Test public void setUpdateAndRetrieve() { - try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { + RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY); + try { appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1)); Assert.assertEquals(1, batch.numRows()); UnsafeRow retrievedValue = batch.getValueRow(0); updateValueRow(retrievedValue, 2, 2); UnsafeRow retrievedValue2 = batch.getValueRow(0); Assert.assertTrue(checkValue(retrievedValue2, 2, 2)); + } finally { + batch.close(); } } @Test public void iteratorTest() throws Exception { - try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { + RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY); + try { appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1)); appendRow(batch, makeKeyRow(2, "B"), makeValueRow(2, 2)); appendRow(batch, makeKeyRow(3, "C"), makeValueRow(3, 3)); @@ -239,13 +253,16 @@ public void iteratorTest() throws Exception { Assert.assertTrue(checkKey(key3, 3, "C")); Assert.assertTrue(checkValue(value3, 3, 3)); Assert.assertFalse(iterator.next()); + } finally { + batch.close(); } } @Test public void fixedLengthTest() throws Exception { - try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(fixedKeySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { + RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(fixedKeySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY); + try { appendRow(batch, makeKeyRow(11, 11), makeValueRow(1, 1)); appendRow(batch, makeKeyRow(22, 22), makeValueRow(2, 2)); appendRow(batch, makeKeyRow(33, 33), makeValueRow(3, 3)); @@ -276,13 +293,16 @@ public void fixedLengthTest() throws Exception { Assert.assertTrue(checkKey(key3, 33, 33)); Assert.assertTrue(checkValue(value3, 3, 3)); Assert.assertFalse(iterator.next()); + } finally { + batch.close(); } } @Test public void appendRowUntilExceedingCapacity() throws Exception { - try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, 10)) { + RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, 10); + try { UnsafeRow key = makeKeyRow(1, "A"); UnsafeRow value = makeValueRow(1, 1); for (int i = 0; i < 10; i++) { @@ -301,6 +321,8 @@ public void appendRowUntilExceedingCapacity() throws Exception { Assert.assertTrue(checkValue(value1, 1, 1)); } Assert.assertFalse(iterator.next()); + } finally { + batch.close(); } } @@ -308,8 +330,9 @@ public void appendRowUntilExceedingCapacity() throws Exception { public void appendRowUntilExceedingPageSize() throws Exception { // Use default size or spark.buffer.pageSize if specified int pageSizeToUse = (int) memoryManager.pageSizeBytes(); - try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, pageSizeToUse)) { + RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, pageSizeToUse); //enough capacity + try { UnsafeRow key = makeKeyRow(1, "A"); UnsafeRow value = makeValueRow(1, 1); int recordLength = 8 + key.getSizeInBytes() + value.getSizeInBytes() + 8; @@ -333,44 +356,49 @@ public void appendRowUntilExceedingPageSize() throws Exception { Assert.assertTrue(checkValue(value1, 1, 1)); } Assert.assertFalse(iterator.next()); + } finally { + batch.close(); } } @Test public void failureToAllocateFirstPage() throws Exception { memoryManager.limit(1024); - try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { + RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY); + try { UnsafeRow key = makeKeyRow(1, "A"); UnsafeRow value = makeValueRow(11, 11); UnsafeRow ret = appendRow(batch, key, value); Assert.assertNull(ret); Assert.assertFalse(batch.rowIterator().next()); + } finally { + batch.close(); } } @Test public void randomizedTest() { - try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { - int numEntry = 100; - long[] expectedK1 = new long[numEntry]; - String[] expectedK2 = new String[numEntry]; - long[] expectedV1 = new long[numEntry]; - long[] expectedV2 = new long[numEntry]; - - for (int i = 0; i < numEntry; i++) { - long k1 = rand.nextLong(); - String k2 = getRandomString(rand.nextInt(256)); - long v1 = rand.nextLong(); - long v2 = rand.nextLong(); - appendRow(batch, makeKeyRow(k1, k2), makeValueRow(v1, v2)); - expectedK1[i] = k1; - expectedK2[i] = k2; - expectedV1[i] = v1; - expectedV2[i] = v2; - } - + RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY); + int numEntry = 100; + long[] expectedK1 = new long[numEntry]; + String[] expectedK2 = new String[numEntry]; + long[] expectedV1 = new long[numEntry]; + long[] expectedV2 = new long[numEntry]; + + for (int i = 0; i < numEntry; i++) { + long k1 = rand.nextLong(); + String k2 = getRandomString(rand.nextInt(256)); + long v1 = rand.nextLong(); + long v2 = rand.nextLong(); + appendRow(batch, makeKeyRow(k1, k2), makeValueRow(v1, v2)); + expectedK1[i] = k1; + expectedK2[i] = k2; + expectedV1[i] = v1; + expectedV2[i] = v2; + } + try { for (int j = 0; j < 10000; j++) { int rowId = rand.nextInt(numEntry); if (rand.nextBoolean()) { @@ -382,6 +410,8 @@ public void randomizedTest() { Assert.assertTrue(checkValue(value, expectedV1[rowId], expectedV2[rowId])); } } + } finally { + batch.close(); } } } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIService.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIService.java index 3cbc2c42761e5..791ddcbd2c5b6 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIService.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIService.java @@ -146,11 +146,18 @@ public UserGroupInformation getHttpUGI() { public synchronized void start() { super.start(); // Initialize and test a connection to the metastore - try (IMetaStoreClient metastoreClient = new HiveMetaStoreClient(hiveConf)) { + IMetaStoreClient metastoreClient = null; + try { + metastoreClient = new HiveMetaStoreClient(hiveConf); metastoreClient.getDatabases("default"); } catch (Exception e) { throw new ServiceException("Unable to connect to MetaStore!", e); } + finally { + if (metastoreClient != null) { + metastoreClient.close(); + } + } } @Override diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java index 4a8779e07834d..92c340a29c107 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java @@ -197,11 +197,11 @@ public void cancelOperation(OperationHandle opHandle) throws HiveSQLException { } public void closeOperation(OperationHandle opHandle) throws HiveSQLException { - try (Operation operation = removeOperation(opHandle)) { - if (operation == null) { - throw new HiveSQLException("Operation does not exist!"); - } + Operation operation = removeOperation(opHandle); + if (operation == null) { + throw new HiveSQLException("Operation does not exist!"); } + operation.close(); } public TableSchema getOperationResultSetSchema(OperationHandle opHandle) From 459700727fadf3f35a211eab2ffc8d68a4a1c39a Mon Sep 17 00:00:00 2001 From: s71955 Date: Fri, 5 Oct 2018 13:09:16 +0800 Subject: [PATCH 1766/2461] [SPARK-25521][SQL] Job id showing null in the logs when insert into command Job is finished. ## What changes were proposed in this pull request? ``As part of insert command in FileFormatWriter, a job context is created for handling the write operation , While initializing the job context using setupJob() API in HadoopMapReduceCommitProtocol , we set the jobid in the Jobcontext configuration.In FileFormatWriter since we are directly getting the jobId from the map reduce JobContext the job id will come as null while adding the log. As a solution we shall get the jobID from the configuration of the map reduce Jobcontext.`` ## How was this patch tested? Manually, verified the logs after the changes. ![spark-25521 1](https://user-images.githubusercontent.com/12999161/46164933-e95ab700-c2ac-11e8-88e9-49fa5100b872.PNG) Closes #22572 from sujith71955/master_log_issue. Authored-by: s71955 Signed-off-by: Wenchen Fan --- .../spark/sql/execution/datasources/FileFormatWriter.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 7c6ab4bc922fe..774fe38f5c2e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -183,15 +183,15 @@ object FileFormatWriter extends Logging { val commitMsgs = ret.map(_.commitMsg) committer.commitJob(job, commitMsgs) - logInfo(s"Job ${job.getJobID} committed.") + logInfo(s"Write Job ${description.uuid} committed.") processStats(description.statsTrackers, ret.map(_.summary.stats)) - logInfo(s"Finished processing stats for job ${job.getJobID}.") + logInfo(s"Finished processing stats for write job ${description.uuid}.") // return a set of all the partition paths that were updated during this job ret.map(_.summary.updatedPartitions).reduceOption(_ ++ _).getOrElse(Set.empty) } catch { case cause: Throwable => - logError(s"Aborting job ${job.getJobID}.", cause) + logError(s"Aborting job ${description.uuid}.", cause) committer.abortJob(job) throw new SparkException("Job aborted.", cause) } From ab1650d2938db4901b8c28df945d6a0691a19d31 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Fri, 5 Oct 2018 16:40:08 +0800 Subject: [PATCH 1767/2461] [SPARK-24601] Update Jackson to 2.9.6 Hi all, Jackson is incompatible with upstream versions, therefore bump the Jackson version to a more recent one. I bumped into some issues with Azure CosmosDB that is using a more recent version of Jackson. This can be fixed by adding exclusions and then it works without any issues. So no breaking changes in the API's. I would also consider bumping the version of Jackson in Spark. I would suggest to keep up to date with the dependencies, since in the future this issue will pop up more frequently. ## What changes were proposed in this pull request? Bump Jackson to 2.9.6 ## How was this patch tested? Compiled and tested it locally to see if anything broke. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #21596 from Fokko/fd-bump-jackson. Authored-by: Fokko Driesprong Signed-off-by: hyukjinkwon --- .../rest/SubmitRestProtocolMessage.scala | 2 +- .../apache/spark/rdd/RDDOperationScope.scala | 2 +- .../org/apache/spark/status/KVUtils.scala | 2 +- .../status/api/v1/JacksonMessageWriter.scala | 2 +- .../org/apache/spark/status/api/v1/api.scala | 3 ++ dev/deps/spark-deps-hadoop-2.6 | 16 ++++----- dev/deps/spark-deps-hadoop-2.7 | 16 ++++----- dev/deps/spark-deps-hadoop-3.1 | 16 ++++----- pom.xml | 7 ++-- .../expressions/JsonExpressionsSuite.scala | 7 ++++ .../datasources/json/JsonBenchmarks.scala | 33 ++++++++++--------- 11 files changed, 59 insertions(+), 47 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index ef5a7e35ad562..97b689cdadd5f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -36,7 +36,7 @@ import org.apache.spark.util.Utils * (2) the Spark version of the client / server * (3) an optional message */ -@JsonInclude(Include.NON_NULL) +@JsonInclude(Include.NON_ABSENT) @JsonAutoDetect(getterVisibility = Visibility.ANY, setterVisibility = Visibility.ANY) @JsonPropertyOrder(alphabetic = true) private[rest] abstract class SubmitRestProtocolMessage { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala index 53d69ba26811f..3abb2d8a11f35 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala @@ -41,7 +41,7 @@ import org.apache.spark.internal.Logging * There is no particular relationship between an operation scope and a stage or a job. * A scope may live inside one stage (e.g. map) or span across multiple jobs (e.g. take). */ -@JsonInclude(Include.NON_NULL) +@JsonInclude(Include.NON_ABSENT) @JsonPropertyOrder(Array("id", "name", "parent")) private[spark] class RDDOperationScope( val name: String, diff --git a/core/src/main/scala/org/apache/spark/status/KVUtils.scala b/core/src/main/scala/org/apache/spark/status/KVUtils.scala index 99b1843d8e1c0..45348be5c98b9 100644 --- a/core/src/main/scala/org/apache/spark/status/KVUtils.scala +++ b/core/src/main/scala/org/apache/spark/status/KVUtils.scala @@ -42,7 +42,7 @@ private[spark] object KVUtils extends Logging { private[spark] class KVStoreScalaSerializer extends KVStoreSerializer { mapper.registerModule(DefaultScalaModule) - mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL) + mapper.setSerializationInclusion(JsonInclude.Include.NON_ABSENT) } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala index 4560d300cb0c8..50a286d0d3b0f 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala @@ -49,7 +49,7 @@ private[v1] class JacksonMessageWriter extends MessageBodyWriter[Object]{ } mapper.registerModule(com.fasterxml.jackson.module.scala.DefaultScalaModule) mapper.enable(SerializationFeature.INDENT_OUTPUT) - mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL) + mapper.setSerializationInclusion(JsonInclude.Include.NON_ABSENT) mapper.setDateFormat(JacksonMessageWriter.makeISODateFormat) override def isWriteable( diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 77466b62ff6ed..30afd8b769720 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -139,6 +139,9 @@ private[spark] class ExecutorMetricsJsonSerializer jsonGenerator.writeObject(metricsMap) } } + + override def isEmpty(provider: SerializerProvider, value: Option[ExecutorMetrics]): Boolean = + value.isEmpty } class JobData private[spark]( diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 969df4d92946b..2dcab8533b018 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -87,16 +87,16 @@ htrace-core-3.0.4.jar httpclient-4.5.6.jar httpcore-4.4.10.jar ivy-2.4.0.jar -jackson-annotations-2.6.7.jar -jackson-core-2.6.7.jar +jackson-annotations-2.9.6.jar +jackson-core-2.9.6.jar jackson-core-asl-1.9.13.jar -jackson-databind-2.6.7.1.jar -jackson-dataformat-yaml-2.6.7.jar +jackson-databind-2.9.6.jar +jackson-dataformat-yaml-2.9.6.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-jaxb-annotations-2.6.7.jar -jackson-module-paranamer-2.7.9.jar -jackson-module-scala_2.11-2.6.7.1.jar +jackson-module-jaxb-annotations-2.9.6.jar +jackson-module-paranamer-2.9.6.jar +jackson-module-scala_2.11-2.9.6.jar jackson-xc-1.9.13.jar janino-3.0.10.jar javassist-3.18.1-GA.jar @@ -177,7 +177,7 @@ scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar -snakeyaml-1.15.jar +snakeyaml-1.18.jar snappy-0.2.jar snappy-java-1.1.7.1.jar spire-macros_2.11-0.13.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index e827dc6036f85..d1d695c47c0bf 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -87,16 +87,16 @@ htrace-core-3.1.0-incubating.jar httpclient-4.5.6.jar httpcore-4.4.10.jar ivy-2.4.0.jar -jackson-annotations-2.6.7.jar -jackson-core-2.6.7.jar +jackson-annotations-2.9.6.jar +jackson-core-2.9.6.jar jackson-core-asl-1.9.13.jar -jackson-databind-2.6.7.1.jar -jackson-dataformat-yaml-2.6.7.jar +jackson-databind-2.9.6.jar +jackson-dataformat-yaml-2.9.6.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-jaxb-annotations-2.6.7.jar -jackson-module-paranamer-2.7.9.jar -jackson-module-scala_2.11-2.6.7.1.jar +jackson-module-jaxb-annotations-2.9.6.jar +jackson-module-paranamer-2.9.6.jar +jackson-module-scala_2.11-2.9.6.jar jackson-xc-1.9.13.jar janino-3.0.10.jar javassist-3.18.1-GA.jar @@ -178,7 +178,7 @@ scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar -snakeyaml-1.15.jar +snakeyaml-1.18.jar snappy-0.2.jar snappy-java-1.1.7.1.jar spire-macros_2.11-0.13.0.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 2b12c35d18e27..e9691eb02aba2 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -86,17 +86,17 @@ htrace-core4-4.1.0-incubating.jar httpclient-4.5.6.jar httpcore-4.4.10.jar ivy-2.4.0.jar -jackson-annotations-2.6.7.jar -jackson-core-2.6.7.jar +jackson-annotations-2.9.6.jar +jackson-core-2.9.6.jar jackson-core-asl-1.9.13.jar -jackson-databind-2.6.7.1.jar -jackson-dataformat-yaml-2.6.7.jar +jackson-databind-2.9.6.jar +jackson-dataformat-yaml-2.9.6.jar jackson-jaxrs-base-2.7.8.jar jackson-jaxrs-json-provider-2.7.8.jar jackson-mapper-asl-1.9.13.jar -jackson-module-jaxb-annotations-2.6.7.jar -jackson-module-paranamer-2.7.9.jar -jackson-module-scala_2.11-2.6.7.1.jar +jackson-module-jaxb-annotations-2.9.6.jar +jackson-module-paranamer-2.9.6.jar +jackson-module-scala_2.11-2.9.6.jar janino-3.0.10.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar @@ -197,7 +197,7 @@ scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar -snakeyaml-1.15.jar +snakeyaml-1.18.jar snappy-0.2.jar snappy-java-1.1.7.1.jar spire-macros_2.11-0.13.0.jar diff --git a/pom.xml b/pom.xml index cc20c5cbf8887..c5eea6a8ec7d5 100644 --- a/pom.xml +++ b/pom.xml @@ -158,8 +158,7 @@ 2.11.12 2.11 1.9.13 - 2.6.7 - 2.6.7.1 + 2.9.6 1.1.7.1 1.1.2 1.2.0-incubating @@ -629,7 +628,7 @@ com.fasterxml.jackson.core jackson-databind - ${fasterxml.jackson.databind.version} + ${fasterxml.jackson.version} com.fasterxml.jackson.core @@ -641,7 +640,7 @@ com.fasterxml.jackson.module jackson-module-scala_${scala.binary.version} - ${fasterxml.jackson.databind.version} + ${fasterxml.jackson.version} com.google.guava diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 34fdd0cc834f0..81ab7d690396a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -244,6 +244,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with "1234") } + test("some big value") { + val value = "x" * 3000 + checkEvaluation( + GetJsonObject(NonFoldableLiteral((s"""{"big": "$value"}""")), + NonFoldableLiteral("$.big")), value) + } + val jsonTupleQuery = Literal("f1") :: Literal("f2") :: Literal("f3") :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala index 368318ab38cb9..3c4a5ab32724b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala @@ -66,12 +66,13 @@ object JSONBenchmarks extends SQLHelper { } /* - Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_172-b11 on Mac OS X 10.13.5 + Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz - JSON schema inferring: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - No encoding 38902 / 39282 2.6 389.0 1.0X - UTF-8 is set 56959 / 57261 1.8 569.6 0.7X + JSON schema inferring: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + --------------------------------------------------------------------------------------------- + No encoding 45908 / 46480 2.2 459.1 1.0X + UTF-8 is set 68469 / 69762 1.5 684.7 0.7X */ benchmark.run() } @@ -107,12 +108,13 @@ object JSONBenchmarks extends SQLHelper { } /* - Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_172-b11 on Mac OS X 10.13.5 + Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz - JSON per-line parsing: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - No encoding 25947 / 26188 3.9 259.5 1.0X - UTF-8 is set 46319 / 46417 2.2 463.2 0.6X + JSON per-line parsing: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + --------------------------------------------------------------------------------------------- + No encoding 9982 / 10237 10.0 99.8 1.0X + UTF-8 is set 16373 / 16806 6.1 163.7 0.6X */ benchmark.run() } @@ -155,12 +157,13 @@ object JSONBenchmarks extends SQLHelper { } /* - Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_172-b11 on Mac OS X 10.13.5 + Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz - JSON parsing of wide lines: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - No encoding 45543 / 45660 0.2 4554.3 1.0X - UTF-8 is set 65737 / 65957 0.2 6573.7 0.7X + JSON parsing of wide lines: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + --------------------------------------------------------------------------------------------- + No encoding 26038 / 26386 0.4 2603.8 1.0X + UTF-8 is set 28343 / 28557 0.4 2834.3 0.9X */ benchmark.run() } From 434ada12a06d1d2d3cb19c4eac5a52f330bb236c Mon Sep 17 00:00:00 2001 From: Michal Senkyr Date: Fri, 5 Oct 2018 17:48:52 +0900 Subject: [PATCH 1768/2461] [SPARK-17952][SQL] Nested Java beans support in createDataFrame ## What changes were proposed in this pull request? When constructing a DataFrame from a Java bean, using nested beans throws an error despite [documentation](http://spark.apache.org/docs/latest/sql-programming-guide.html#inferring-the-schema-using-reflection) stating otherwise. This PR aims to add that support. This PR does not yet add nested beans support in array or List fields. This can be added later or in another PR. ## How was this patch tested? Nested bean was added to the appropriate unit test. Also manually tested in Spark shell on code emulating the referenced JIRA: ``` scala> import scala.beans.BeanProperty import scala.beans.BeanProperty scala> class SubCategory(BeanProperty var id: String, BeanProperty var name: String) extends Serializable defined class SubCategory scala> class Category(BeanProperty var id: String, BeanProperty var subCategory: SubCategory) extends Serializable defined class Category scala> import scala.collection.JavaConverters._ import scala.collection.JavaConverters._ scala> spark.createDataFrame(Seq(new Category("s-111", new SubCategory("sc-111", "Sub-1"))).asJava, classOf[Category]) java.lang.IllegalArgumentException: The value (SubCategory65130cf2) of the type (SubCategory) cannot be converted to struct at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:262) at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:238) at org.apache.spark.sql.catalyst.CatalystTypeConverters$CatalystTypeConverter.toCatalyst(CatalystTypeConverters.scala:103) at org.apache.spark.sql.catalyst.CatalystTypeConverters$$anonfun$createToCatalystConverter$2.apply(CatalystTypeConverters.scala:396) at org.apache.spark.sql.SQLContext$$anonfun$beansToRows$1$$anonfun$apply$1.apply(SQLContext.scala:1108) at org.apache.spark.sql.SQLContext$$anonfun$beansToRows$1$$anonfun$apply$1.apply(SQLContext.scala:1108) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33) at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186) at scala.collection.TraversableLike$class.map(TraversableLike.scala:234) at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:186) at org.apache.spark.sql.SQLContext$$anonfun$beansToRows$1.apply(SQLContext.scala:1108) at org.apache.spark.sql.SQLContext$$anonfun$beansToRows$1.apply(SQLContext.scala:1106) at scala.collection.Iterator$$anon$11.next(Iterator.scala:410) at scala.collection.Iterator$class.toStream(Iterator.scala:1320) at scala.collection.AbstractIterator.toStream(Iterator.scala:1334) at scala.collection.TraversableOnce$class.toSeq(TraversableOnce.scala:298) at scala.collection.AbstractIterator.toSeq(Iterator.scala:1334) at org.apache.spark.sql.SparkSession.createDataFrame(SparkSession.scala:423) ... 51 elided ``` New behavior: ``` scala> spark.createDataFrame(Seq(new Category("s-111", new SubCategory("sc-111", "Sub-1"))).asJava, classOf[Category]) res0: org.apache.spark.sql.DataFrame = [id: string, subCategory: struct] scala> res0.show() +-----+---------------+ | id| subCategory| +-----+---------------+ |s-111|[sc-111, Sub-1]| +-----+---------------+ ``` Closes #22527 from michalsenkyr/SPARK-17952. Authored-by: Michal Senkyr Signed-off-by: Takuya UESHIN --- .../org/apache/spark/sql/SQLContext.scala | 29 +++++++++++++----- .../apache/spark/sql/JavaDataFrameSuite.java | 30 ++++++++++++++++++- 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index af6018472cb03..dfb12f272eb2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1098,16 +1098,29 @@ object SQLContext { data: Iterator[_], beanClass: Class[_], attrs: Seq[AttributeReference]): Iterator[InternalRow] = { - val extractors = - JavaTypeInference.getJavaBeanReadableProperties(beanClass).map(_.getReadMethod) - val methodsToConverts = extractors.zip(attrs).map { case (e, attr) => - (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType)) + def createStructConverter(cls: Class[_], fieldTypes: Seq[DataType]): Any => InternalRow = { + val methodConverters = + JavaTypeInference.getJavaBeanReadableProperties(cls).zip(fieldTypes) + .map { case (property, fieldType) => + val method = property.getReadMethod + method -> createConverter(method.getReturnType, fieldType) + } + value => + if (value == null) { + null + } else { + new GenericInternalRow( + methodConverters.map { case (method, converter) => + converter(method.invoke(value)) + }) + } } - data.map { element => - new GenericInternalRow( - methodsToConverts.map { case (e, convert) => convert(e.invoke(element)) } - ): InternalRow + def createConverter(cls: Class[_], dataType: DataType): Any => Any = dataType match { + case struct: StructType => createStructConverter(cls, struct.map(_.dataType)) + case _ => CatalystTypeConverters.createToCatalystConverter(dataType) } + val dataConverter = createStructConverter(beanClass, attrs.map(_.dataType)) + data.map(dataConverter) } /** diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 00f41d6484afb..a05afa4f6ba30 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -134,6 +134,8 @@ public static class Bean implements Serializable { private Map c = ImmutableMap.of("hello", new int[] { 1, 2 }); private List d = Arrays.asList("floppy", "disk"); private BigInteger e = new BigInteger("1234567"); + private NestedBean f = new NestedBean(); + private NestedBean g = null; public double getA() { return a; @@ -152,6 +154,22 @@ public List getD() { } public BigInteger getE() { return e; } + + public NestedBean getF() { + return f; + } + + public NestedBean getG() { + return g; + } + + public static class NestedBean implements Serializable { + private int a = 1; + + public int getA() { + return a; + } + } } void validateDataFrameWithBeans(Bean bean, Dataset df) { @@ -171,7 +189,14 @@ void validateDataFrameWithBeans(Bean bean, Dataset df) { schema.apply("d")); Assert.assertEquals(new StructField("e", DataTypes.createDecimalType(38,0), true, Metadata.empty()), schema.apply("e")); - Row first = df.select("a", "b", "c", "d", "e").first(); + StructType nestedBeanType = + DataTypes.createStructType(Collections.singletonList(new StructField( + "a", IntegerType$.MODULE$, false, Metadata.empty()))); + Assert.assertEquals(new StructField("f", nestedBeanType, true, Metadata.empty()), + schema.apply("f")); + Assert.assertEquals(new StructField("g", nestedBeanType, true, Metadata.empty()), + schema.apply("g")); + Row first = df.select("a", "b", "c", "d", "e", "f", "g").first(); Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); // Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below, // verify that it has the expected length, and contains expected elements. @@ -192,6 +217,9 @@ void validateDataFrameWithBeans(Bean bean, Dataset df) { } // Java.math.BigInteger is equivalent to Spark Decimal(38,0) Assert.assertEquals(new BigDecimal(bean.getE()), first.getDecimal(4)); + Row nested = first.getStruct(5); + Assert.assertEquals(bean.getF().getA(), nested.getInt(0)); + Assert.assertTrue(first.isNullAt(6)); } @Test From 7dcc90fbb8dc75077819a5d8c42652f0c84424b5 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 5 Oct 2018 10:45:15 -0700 Subject: [PATCH 1769/2461] [SPARK-25644][SS] Fix java foreachBatch in DataStreamWriter ## What changes were proposed in this pull request? The java `foreachBatch` API in `DataStreamWriter` should accept `java.lang.Long` rather `scala.Long`. ## How was this patch tested? New java test. Closes #22633 from zsxwing/fix-java-foreachbatch. Authored-by: Shixiong Zhu Signed-off-by: Shixiong Zhu --- .../sql/streaming/DataStreamWriter.scala | 2 +- .../JavaDataStreamReaderWriterSuite.java | 89 +++++++++++++++++++ 2 files changed, 90 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/streaming/JavaDataStreamReaderWriterSuite.java diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index e9a15214d952f..b23e86a786459 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -380,7 +380,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * @since 2.4.0 */ @InterfaceStability.Evolving - def foreachBatch(function: VoidFunction2[Dataset[T], Long]): DataStreamWriter[T] = { + def foreachBatch(function: VoidFunction2[Dataset[T], java.lang.Long]): DataStreamWriter[T] = { foreachBatch((batchDs: Dataset[T], batchId: Long) => function.call(batchDs, batchId)) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/streaming/JavaDataStreamReaderWriterSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/streaming/JavaDataStreamReaderWriterSuite.java new file mode 100644 index 0000000000000..48cdb2642d830 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/streaming/JavaDataStreamReaderWriterSuite.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.streaming; + +import java.io.File; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.function.VoidFunction2; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.ForeachWriter; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.streaming.StreamingQuery; +import org.apache.spark.sql.test.TestSparkSession; +import org.apache.spark.util.Utils; + +public class JavaDataStreamReaderWriterSuite { + private SparkSession spark; + private String input; + + @Before + public void setUp() { + spark = new TestSparkSession(); + input = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "input").toString(); + } + + @After + public void tearDown() { + try { + Utils.deleteRecursively(new File(input)); + } finally { + spark.stop(); + spark = null; + } + } + + @Test + public void testForeachBatchAPI() { + StreamingQuery query = spark + .readStream() + .textFile(input) + .writeStream() + .foreachBatch(new VoidFunction2, Long>() { + @Override + public void call(Dataset v1, Long v2) throws Exception {} + }) + .start(); + query.stop(); + } + + @Test + public void testForeachAPI() { + StreamingQuery query = spark + .readStream() + .textFile(input) + .writeStream() + .foreach(new ForeachWriter() { + @Override + public boolean open(long partitionId, long epochId) { + return true; + } + + @Override + public void process(String value) {} + + @Override + public void close(Throwable errorOrNull) {} + }) + .start(); + query.stop(); + } +} From a433fbcee66904d1b7fa98ab053e2bdf81e5e4f2 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 5 Oct 2018 14:39:30 -0700 Subject: [PATCH 1770/2461] [SPARK-25626][SQL][TEST] Improve the test execution time of HiveClientSuites ## What changes were proposed in this pull request? Improve the runtime by reducing the number of partitions created in the test. The number of partitions are reduced from 280 to 60. Here are the test times for the `getPartitionsByFilter returns all partitions` test on my laptop. ``` [info] - 0.13: getPartitionsByFilter returns all partitions when hive.metastore.try.direct.sql=false (4 seconds, 230 milliseconds) [info] - 0.14: getPartitionsByFilter returns all partitions when hive.metastore.try.direct.sql=false (3 seconds, 576 milliseconds) [info] - 1.0: getPartitionsByFilter returns all partitions when hive.metastore.try.direct.sql=false (3 seconds, 495 milliseconds) [info] - 1.1: getPartitionsByFilter returns all partitions when hive.metastore.try.direct.sql=false (6 seconds, 728 milliseconds) [info] - 1.2: getPartitionsByFilter returns all partitions when hive.metastore.try.direct.sql=false (7 seconds, 260 milliseconds) [info] - 2.0: getPartitionsByFilter returns all partitions when hive.metastore.try.direct.sql=false (8 seconds, 270 milliseconds) [info] - 2.1: getPartitionsByFilter returns all partitions when hive.metastore.try.direct.sql=false (6 seconds, 856 milliseconds) [info] - 2.2: getPartitionsByFilter returns all partitions when hive.metastore.try.direct.sql=false (7 seconds, 587 milliseconds) [info] - 2.3: getPartitionsByFilter returns all partitions when hive.metastore.try.direct.sql=false (7 seconds, 230 milliseconds) ## How was this patch tested? Test only. Closes #22644 from dilipbiswal/SPARK-25626. Authored-by: Dilip Biswal Signed-off-by: gatorsmile --- .../sql/hive/client/HiveClientSuite.scala | 70 +++++++++---------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index fa9f753795f65..7a325bf26b4cf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -32,7 +32,7 @@ class HiveClientSuite(version: String) private val tryDirectSqlKey = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname - private val testPartitionCount = 3 * 24 * 4 + private val testPartitionCount = 3 * 5 * 4 private def init(tryDirectSql: Boolean): HiveClient = { val storageFormat = CatalogStorageFormat( @@ -51,7 +51,7 @@ class HiveClientSuite(version: String) val partitions = for { ds <- 20170101 to 20170103 - h <- 0 to 23 + h <- 0 to 4 chunk <- Seq("aa", "ab", "ba", "bb") } yield CatalogTablePartition(Map( "ds" -> ds.toString, @@ -92,7 +92,7 @@ class HiveClientSuite(version: String) testMetastorePartitionFiltering( attr("ds") <=> 20170101, 20170101 to 20170103, - 0 to 23, + 0 to 4, "aa" :: "ab" :: "ba" :: "bb" :: Nil) } @@ -100,7 +100,7 @@ class HiveClientSuite(version: String) testMetastorePartitionFiltering( attr("ds") === 20170101, 20170101 to 20170101, - 0 to 23, + 0 to 4, "aa" :: "ab" :: "ba" :: "bb" :: Nil) } @@ -118,7 +118,7 @@ class HiveClientSuite(version: String) testMetastorePartitionFiltering( attr("chunk") === "aa", 20170101 to 20170103, - 0 to 23, + 0 to 4, "aa" :: Nil) } @@ -126,7 +126,7 @@ class HiveClientSuite(version: String) testMetastorePartitionFiltering( attr("chunk").cast(IntegerType) === 1, 20170101 to 20170103, - 0 to 23, + 0 to 4, "aa" :: "ab" :: "ba" :: "bb" :: Nil) } @@ -134,7 +134,7 @@ class HiveClientSuite(version: String) testMetastorePartitionFiltering( attr("chunk").cast(BooleanType) === true, 20170101 to 20170103, - 0 to 23, + 0 to 4, "aa" :: "ab" :: "ba" :: "bb" :: Nil) } @@ -142,23 +142,23 @@ class HiveClientSuite(version: String) testMetastorePartitionFiltering( Literal(20170101) === attr("ds"), 20170101 to 20170101, - 0 to 23, + 0 to 4, "aa" :: "ab" :: "ba" :: "bb" :: Nil) } - test("getPartitionsByFilter: ds=20170101 and h=10") { + test("getPartitionsByFilter: ds=20170101 and h=2") { testMetastorePartitionFiltering( - attr("ds") === 20170101 && attr("h") === 10, + attr("ds") === 20170101 && attr("h") === 2, 20170101 to 20170101, - 10 to 10, + 2 to 2, "aa" :: "ab" :: "ba" :: "bb" :: Nil) } - test("getPartitionsByFilter: cast(ds as long)=20170101L and h=10") { + test("getPartitionsByFilter: cast(ds as long)=20170101L and h=2") { testMetastorePartitionFiltering( - attr("ds").cast(LongType) === 20170101L && attr("h") === 10, + attr("ds").cast(LongType) === 20170101L && attr("h") === 2, 20170101 to 20170101, - 10 to 10, + 2 to 2, "aa" :: "ab" :: "ba" :: "bb" :: Nil) } @@ -166,7 +166,7 @@ class HiveClientSuite(version: String) testMetastorePartitionFiltering( attr("ds") === 20170101 || attr("ds") === 20170102, 20170101 to 20170102, - 0 to 23, + 0 to 4, "aa" :: "ab" :: "ba" :: "bb" :: Nil) } @@ -174,7 +174,7 @@ class HiveClientSuite(version: String) testMetastorePartitionFiltering( attr("ds").in(20170102, 20170103), 20170102 to 20170103, - 0 to 23, + 0 to 4, "aa" :: "ab" :: "ba" :: "bb" :: Nil) } @@ -182,7 +182,7 @@ class HiveClientSuite(version: String) testMetastorePartitionFiltering( attr("ds").cast(LongType).in(20170102L, 20170103L), 20170102 to 20170103, - 0 to 23, + 0 to 4, "aa" :: "ab" :: "ba" :: "bb" :: Nil) } @@ -190,7 +190,7 @@ class HiveClientSuite(version: String) testMetastorePartitionFiltering( attr("ds").in(20170102, 20170103), 20170102 to 20170103, - 0 to 23, + 0 to 4, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { case expr @ In(v, list) if expr.inSetConvertible => InSet(v, list.map(_.eval(EmptyRow)).toSet) @@ -202,7 +202,7 @@ class HiveClientSuite(version: String) testMetastorePartitionFiltering( attr("ds").cast(LongType).in(20170102L, 20170103L), 20170102 to 20170103, - 0 to 23, + 0 to 4, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { case expr @ In(v, list) if expr.inSetConvertible => InSet(v, list.map(_.eval(EmptyRow)).toSet) @@ -213,7 +213,7 @@ class HiveClientSuite(version: String) testMetastorePartitionFiltering( attr("chunk").in("ab", "ba"), 20170101 to 20170103, - 0 to 23, + 0 to 4, "ab" :: "ba" :: Nil) } @@ -221,34 +221,34 @@ class HiveClientSuite(version: String) testMetastorePartitionFiltering( attr("chunk").in("ab", "ba"), 20170101 to 20170103, - 0 to 23, + 0 to 4, "ab" :: "ba" :: Nil, { case expr @ In(v, list) if expr.inSetConvertible => InSet(v, list.map(_.eval(EmptyRow)).toSet) }) } - test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<8)") { - val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb")) - val day2 = (20170102 to 20170102, 0 to 7, Seq("aa", "ab", "ba", "bb")) - testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 8) || - (attr("ds") === 20170102 && attr("h") < 8), day1 :: day2 :: Nil) + test("getPartitionsByFilter: (ds=20170101 and h>=2) or (ds=20170102 and h<2)") { + val day1 = (20170101 to 20170101, 2 to 4, Seq("aa", "ab", "ba", "bb")) + val day2 = (20170102 to 20170102, 0 to 1, Seq("aa", "ab", "ba", "bb")) + testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 2) || + (attr("ds") === 20170102 && attr("h") < 2), day1 :: day2 :: Nil) } - test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))") { - val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb")) + test("getPartitionsByFilter: (ds=20170101 and h>=2) or (ds=20170102 and h<(1+1))") { + val day1 = (20170101 to 20170101, 2 to 4, Seq("aa", "ab", "ba", "bb")) // Day 2 should include all hours because we can't build a filter for h<(7+1) - val day2 = (20170102 to 20170102, 0 to 23, Seq("aa", "ab", "ba", "bb")) - testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 8) || - (attr("ds") === 20170102 && attr("h") < (Literal(7) + 1)), day1 :: day2 :: Nil) + val day2 = (20170102 to 20170102, 0 to 4, Seq("aa", "ab", "ba", "bb")) + testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 2) || + (attr("ds") === 20170102 && attr("h") < (Literal(1) + 1)), day1 :: day2 :: Nil) } test("getPartitionsByFilter: " + - "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))") { - val day1 = (20170101 to 20170101, 8 to 23, Seq("ab", "ba")) - val day2 = (20170102 to 20170102, 0 to 7, Seq("ab", "ba")) + "chunk in ('ab', 'ba') and ((ds=20170101 and h>=2) or (ds=20170102 and h<2))") { + val day1 = (20170101 to 20170101, 2 to 4, Seq("ab", "ba")) + val day2 = (20170102 to 20170102, 0 to 1, Seq("ab", "ba")) testMetastorePartitionFiltering(attr("chunk").in("ab", "ba") && - ((attr("ds") === 20170101 && attr("h") >= 8) || (attr("ds") === 20170102 && attr("h") < 8)), + ((attr("ds") === 20170101 && attr("h") >= 2) || (attr("ds") === 20170102 && attr("h") < 2)), day1 :: day2 :: Nil) } From 1c9486c1acceb73e5cc6f1fa684b6d992e187a9a Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 5 Oct 2018 16:42:06 -0700 Subject: [PATCH 1771/2461] [SPARK-25635][SQL][BUILD] Support selective direct encoding in native ORC write ## What changes were proposed in this pull request? Before ORC 1.5.3, `orc.dictionary.key.threshold` and `hive.exec.orc.dictionary.key.size.threshold` are applied for all columns. This has been a big huddle to enable dictionary encoding. From ORC 1.5.3, `orc.column.encoding.direct` is added to enforce direct encoding selectively in a column-wise manner. This PR aims to add that feature by upgrading ORC from 1.5.2 to 1.5.3. The followings are the patches in ORC 1.5.3 and this feature is the only one related to Spark directly. ``` ORC-406: ORC: Char(n) and Varchar(n) writers truncate to n bytes & corrupts multi-byte data (gopalv) ORC-403: [C++] Add checks to avoid invalid offsets in InputStream ORC-405: Remove calcite as a dependency from the benchmarks. ORC-375: Fix libhdfs on gcc7 by adding #include two places. ORC-383: Parallel builds fails with ConcurrentModificationException ORC-382: Apache rat exclusions + add rat check to travis ORC-401: Fix incorrect quoting in specification. ORC-385: Change RecordReader to extend Closeable. ORC-384: [C++] fix memory leak when loading non-ORC files ORC-391: [c++] parseType does not accept underscore in the field name ORC-397: Allow selective disabling of dictionary encoding. Original patch was by Mithun Radhakrishnan. ORC-389: Add ability to not decode Acid metadata columns ``` ## How was this patch tested? Pass the Jenkins with newly added test cases. Closes #22622 from dongjoon-hyun/SPARK-25635. Authored-by: Dongjoon Hyun Signed-off-by: gatorsmile --- dev/deps/spark-deps-hadoop-2.6 | 6 +- dev/deps/spark-deps-hadoop-2.7 | 6 +- dev/deps/spark-deps-hadoop-3.1 | 6 +- pom.xml | 2 +- .../datasources/orc/OrcSourceSuite.scala | 75 +++++++++++++++++++ .../sql/hive/orc/HiveOrcSourceSuite.scala | 8 ++ 6 files changed, 93 insertions(+), 10 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 2dcab8533b018..22e86ef6c43b3 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -153,9 +153,9 @@ objenesis-2.5.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.5.2-nohive.jar -orc-mapreduce-1.5.2-nohive.jar -orc-shims-1.5.2.jar +orc-core-1.5.3-nohive.jar +orc-mapreduce-1.5.3-nohive.jar +orc-shims-1.5.3.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index d1d695c47c0bf..19dd786c63e48 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -154,9 +154,9 @@ objenesis-2.5.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.5.2-nohive.jar -orc-mapreduce-1.5.2-nohive.jar -orc-shims-1.5.2.jar +orc-core-1.5.3-nohive.jar +orc-mapreduce-1.5.3-nohive.jar +orc-shims-1.5.3.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index e9691eb02aba2..ea0f487a193eb 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -172,9 +172,9 @@ okhttp-2.7.5.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.5.2-nohive.jar -orc-mapreduce-1.5.2-nohive.jar -orc-shims-1.5.2.jar +orc-core-1.5.3-nohive.jar +orc-mapreduce-1.5.3-nohive.jar +orc-shims-1.5.3.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/pom.xml b/pom.xml index c5eea6a8ec7d5..79af5d69c66c5 100644 --- a/pom.xml +++ b/pom.xml @@ -131,7 +131,7 @@ 1.2.1 10.12.1.1 1.10.0 - 1.5.2 + 1.5.3 nohive 1.6.0 9.3.24.v20180605 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index b6bb1d7ba4ce3..dc81c0585bf18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.orc.OrcConf.COMPRESS import org.apache.orc.OrcFile +import org.apache.orc.OrcProto.ColumnEncoding.Kind.{DICTIONARY_V2, DIRECT, DIRECT_V2} import org.apache.orc.OrcProto.Stream.Kind import org.apache.orc.impl.RecordReaderImpl import org.scalatest.BeforeAndAfterAll @@ -115,6 +116,76 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { } } + protected def testSelectiveDictionaryEncoding(isSelective: Boolean) { + val tableName = "orcTable" + + withTempDir { dir => + withTable(tableName) { + val sqlStatement = orcImp match { + case "native" => + s""" + |CREATE TABLE $tableName (zipcode STRING, uniqColumn STRING, value DOUBLE) + |USING ORC + |OPTIONS ( + | path '${dir.toURI}', + | orc.dictionary.key.threshold '1.0', + | orc.column.encoding.direct 'uniqColumn' + |) + """.stripMargin + case "hive" => + s""" + |CREATE TABLE $tableName (zipcode STRING, uniqColumn STRING, value DOUBLE) + |STORED AS ORC + |LOCATION '${dir.toURI}' + |TBLPROPERTIES ( + | orc.dictionary.key.threshold '1.0', + | hive.exec.orc.dictionary.key.size.threshold '1.0', + | orc.column.encoding.direct 'uniqColumn' + |) + """.stripMargin + case impl => + throw new UnsupportedOperationException(s"Unknown ORC implementation: $impl") + } + + sql(sqlStatement) + sql(s"INSERT INTO $tableName VALUES ('94086', 'random-uuid-string', 0.0)") + + val partFiles = dir.listFiles() + .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_")) + assert(partFiles.length === 1) + + val orcFilePath = new Path(partFiles.head.getAbsolutePath) + val readerOptions = OrcFile.readerOptions(new Configuration()) + val reader = OrcFile.createReader(orcFilePath, readerOptions) + var recordReader: RecordReaderImpl = null + try { + recordReader = reader.rows.asInstanceOf[RecordReaderImpl] + + // Check the kind + val stripe = recordReader.readStripeFooter(reader.getStripes.get(0)) + + // The encodings are divided into direct or dictionary-based categories and + // further refined as to whether they use RLE v1 or v2. RLE v1 is used by + // Hive 0.11 and RLE v2 is introduced in Hive 0.12 ORC with more improvements. + // For more details, see https://orc.apache.org/specification/ + assert(stripe.getColumns(1).getKind === DICTIONARY_V2) + if (isSelective) { + assert(stripe.getColumns(2).getKind === DIRECT_V2) + } else { + assert(stripe.getColumns(2).getKind === DICTIONARY_V2) + } + // Floating point types are stored with DIRECT encoding in IEEE 754 floating + // point bit layout. + assert(stripe.getColumns(3).getKind === DIRECT) + } finally { + if (recordReader != null) { + recordReader.close() + } + } + } + } + } + test("create temporary orc table") { checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_source"), Row(10)) @@ -284,4 +355,8 @@ class OrcSourceSuite extends OrcSuite with SharedSQLContext { test("Check BloomFilter creation") { testBloomFilterCreation(Kind.BLOOM_FILTER_UTF8) // After ORC-101 } + + test("Enforce direct encoding column-wise selectively") { + testSelectiveDictionaryEncoding(isSelective = true) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala index c1ae2f6861cb8..7fefaf53939bd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -182,4 +182,12 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { } } } + + test("Enforce direct encoding column-wise selectively") { + Seq(true, false).foreach { convertMetastore => + withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> s"$convertMetastore") { + testSelectiveDictionaryEncoding(isSelective = false) + } + } + } } From bbd038d2436c17ff519c08630a016f3ec796a282 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 5 Oct 2018 17:03:24 -0700 Subject: [PATCH 1772/2461] [SPARK-25653][TEST] Add tag ExtendedHiveTest for HiveSparkSubmitSuite ## What changes were proposed in this pull request? The total run time of `HiveSparkSubmitSuite` is about 10 minutes. While the related code is stable, add tag `ExtendedHiveTest` for it. ## How was this patch tested? Unit test. Closes #22642 from gengliangwang/addTagForHiveSparkSubmitSuite. Authored-by: Gengliang Wang Signed-off-by: gatorsmile --- .../org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index a676cf6ce6925..f839e8979d355 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -33,11 +33,13 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} import org.apache.spark.sql.types.{DecimalType, StructType} +import org.apache.spark.tags.ExtendedHiveTest import org.apache.spark.util.{ResetSystemProperties, Utils} /** * This suite tests spark-submit with applications using HiveContext. */ +@ExtendedHiveTest class HiveSparkSubmitSuite extends SparkSubmitTestUtils with Matchers @@ -46,8 +48,6 @@ class HiveSparkSubmitSuite override protected val enableAutoThreadAudit = false - // TODO: rewrite these or mark them as slow tests to be run sparingly - override def beforeEach() { super.beforeEach() } From 2c6f4d61bbf7f0267a7309b4a236047f830bd6ee Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 5 Oct 2018 17:25:28 -0700 Subject: [PATCH 1773/2461] [SPARK-25610][SQL][TEST] Improve execution time of DatasetCacheSuite: cache UDF result correctly ## What changes were proposed in this pull request? In this test case, we are verifying that the result of an UDF is cached when the underlying data frame is cached and that the udf is not evaluated again when the cached data frame is used. To reduce the runtime we do : 1) Use a single partition dataframe, so the total execution time of UDF is more deterministic. 2) Cut down the size of the dataframe from 10 to 2. 3) Reduce the sleep time in the UDF from 5secs to 2secs. 4) Reduce the failafter condition from 3 to 2. With the above change, it takes about 4 secs to cache the first dataframe. And subsequent check takes a few hundred milliseconds. The new runtime for 5 consecutive runs of this test is as follows : ``` [info] - cache UDF result correctly (4 seconds, 906 milliseconds) [info] - cache UDF result correctly (4 seconds, 281 milliseconds) [info] - cache UDF result correctly (4 seconds, 288 milliseconds) [info] - cache UDF result correctly (4 seconds, 355 milliseconds) [info] - cache UDF result correctly (4 seconds, 280 milliseconds) ``` ## How was this patch tested? This is s test fix. Closes #22638 from dilipbiswal/SPARK-25610. Authored-by: Dilip Biswal Signed-off-by: gatorsmile --- .../test/scala/org/apache/spark/sql/DatasetCacheSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 5c6a021d5b767..fef6ddd0b93c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -127,8 +127,8 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits } test("cache UDF result correctly") { - val expensiveUDF = udf({x: Int => Thread.sleep(5000); x}) - val df = spark.range(0, 10).toDF("a").withColumn("b", expensiveUDF($"a")) + val expensiveUDF = udf({x: Int => Thread.sleep(2000); x}) + val df = spark.range(0, 2).toDF("a").repartition(1).withColumn("b", expensiveUDF($"a")) val df2 = df.agg(sum(df("b"))) df.cache() @@ -136,7 +136,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits assertCached(df2) // udf has been evaluated during caching, and thus should not be re-evaluated here - failAfter(3 seconds) { + failAfter(2 seconds) { df2.collect() } From 58287a39864db463eeef17d1152d664be021d9ef Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 5 Oct 2018 21:15:16 -0700 Subject: [PATCH 1774/2461] [SPARK-25646][K8S] Fix docker-image-tool.sh on dev build. The docker file was referencing a path that only existed in the distribution tarball; it needs to be parameterized so that the right path can be used in a dev build. Tested on local dev build. Closes #22634 from vanzin/SPARK-25646. Authored-by: Marcelo Vanzin Signed-off-by: Dongjoon Hyun --- bin/docker-image-tool.sh | 2 ++ .../kubernetes/docker/src/main/dockerfiles/spark/Dockerfile | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index d6371051ef7fb..228494de6d5a1 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -54,6 +54,8 @@ function build { img_path=$IMG_PATH --build-arg spark_jars=assembly/target/scala-$SPARK_SCALA_VERSION/jars + --build-arg + k8s_tests=resource-managers/kubernetes/integration-tests/tests ) else # Not passed as an argument to docker, but used to validate the Spark directory. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index 7ae57bf6e42d0..1c4dcd5476872 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -19,6 +19,7 @@ FROM openjdk:8-alpine ARG spark_jars=jars ARG img_path=kubernetes/dockerfiles +ARG k8s_tests=kubernetes/tests # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. @@ -43,7 +44,7 @@ COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin COPY ${img_path}/spark/entrypoint.sh /opt/ COPY examples /opt/spark/examples -COPY kubernetes/tests /opt/spark/tests +COPY ${k8s_tests} /opt/spark/tests COPY data /opt/spark/data ENV SPARK_HOME /opt/spark From 44cf800c831588b1f7940dd8eef7ecb6cde28f23 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 6 Oct 2018 14:25:48 +0800 Subject: [PATCH 1775/2461] [SPARK-25655][BUILD] Add -Pspark-ganglia-lgpl to the scala style check. ## What changes were proposed in this pull request? Our lint failed due to the following errors: ``` [INFO] --- scalastyle-maven-plugin:1.0.0:check (default) spark-ganglia-lgpl_2.11 --- error file=/home/jenkins/workspace/spark-master-maven-snapshots/spark/external/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala message= Are you sure that you want to use toUpperCase or toLowerCase without the root locale? In most cases, you should use toUpperCase(Locale.ROOT) or toLowerCase(Locale.ROOT) instead. If you must use toUpperCase or toLowerCase without the root locale, wrap the code block with // scalastyle:off caselocale .toUpperCase .toLowerCase // scalastyle:on caselocale line=67 column=49 error file=/home/jenkins/workspace/spark-master-maven-snapshots/spark/external/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala message= Are you sure that you want to use toUpperCase or toLowerCase without the root locale? In most cases, you should use toUpperCase(Locale.ROOT) or toLowerCase(Locale.ROOT) instead. If you must use toUpperCase or toLowerCase without the root locale, wrap the code block with // scalastyle:off caselocale .toUpperCase .toLowerCase // scalastyle:on caselocale line=71 column=32 Saving to outputFile=/home/jenkins/workspace/spark-master-maven-snapshots/spark/external/spark-ganglia-lgpl/target/scalastyle-output.xml ``` See https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Compile/job/spark-master-lint/8890/ ## How was this patch tested? N/A Closes #22647 from gatorsmile/fixLint. Authored-by: gatorsmile Signed-off-by: hyukjinkwon --- dev/scalastyle | 1 + .../scala/org/apache/spark/metrics/sink/GangliaSink.scala | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/dev/scalastyle b/dev/scalastyle index b8053df05fa2b..b0ad02523826c 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -29,6 +29,7 @@ ERRORS=$(echo -e "q\n" \ -Pflume \ -Phive \ -Phive-thriftserver \ + -Pspark-ganglia-lgpl \ scalastyle test:scalastyle \ | awk '{if($1~/error/)print}' \ ) diff --git a/external/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala b/external/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala index 0cd795f638870..93db4773372cd 100644 --- a/external/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala +++ b/external/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala @@ -17,7 +17,7 @@ package org.apache.spark.metrics.sink -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.TimeUnit import com.codahale.metrics.MetricRegistry @@ -64,11 +64,12 @@ class GangliaSink(val property: Properties, val registry: MetricRegistry, val ttl = propertyToOption(GANGLIA_KEY_TTL).map(_.toInt).getOrElse(GANGLIA_DEFAULT_TTL) val dmax = propertyToOption(GANGLIA_KEY_DMAX).map(_.toInt).getOrElse(GANGLIA_DEFAULT_DMAX) val mode: UDPAddressingMode = propertyToOption(GANGLIA_KEY_MODE) - .map(u => GMetric.UDPAddressingMode.valueOf(u.toUpperCase)).getOrElse(GANGLIA_DEFAULT_MODE) + .map(u => GMetric.UDPAddressingMode.valueOf(u.toUpperCase(Locale.Root))) + .getOrElse(GANGLIA_DEFAULT_MODE) val pollPeriod = propertyToOption(GANGLIA_KEY_PERIOD).map(_.toInt) .getOrElse(GANGLIA_DEFAULT_PERIOD) val pollUnit: TimeUnit = propertyToOption(GANGLIA_KEY_UNIT) - .map(u => TimeUnit.valueOf(u.toUpperCase)) + .map(u => TimeUnit.valueOf(u.toUpperCase(Locale.Root))) .getOrElse(GANGLIA_DEFAULT_UNIT) MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) From 17781d75308c328b11cab3658ca4f358539414f2 Mon Sep 17 00:00:00 2001 From: Parker Hegstrom Date: Sat, 6 Oct 2018 14:30:43 +0800 Subject: [PATCH 1776/2461] [SPARK-25202][SQL] Implements split with limit sql function ## What changes were proposed in this pull request? Adds support for the setting limit in the sql split function ## How was this patch tested? 1. Updated unit tests 2. Tested using Scala spark shell Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22227 from phegstrom/master. Authored-by: Parker Hegstrom Signed-off-by: hyukjinkwon --- R/pkg/R/functions.R | 15 +++++-- R/pkg/R/generics.R | 2 +- R/pkg/tests/fulltests/test_sparkSQL.R | 8 ++++ .../apache/spark/unsafe/types/UTF8String.java | 6 +++ .../spark/unsafe/types/UTF8StringSuite.java | 14 +++--- python/pyspark/sql/functions.py | 28 +++++++++--- .../expressions/regexpExpressions.scala | 44 ++++++++++++++----- .../expressions/RegexpExpressionsSuite.scala | 15 +++++-- .../org/apache/spark/sql/functions.scala | 32 ++++++++++++-- .../sql-tests/inputs/string-functions.sql | 6 ++- .../results/string-functions.sql.out | 18 +++++++- .../spark/sql/StringFunctionsSuite.scala | 44 +++++++++++++++++-- 12 files changed, 189 insertions(+), 43 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 2cb4cb8d531e1..6a8fef5aa7b22 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3473,13 +3473,21 @@ setMethod("collect_set", #' @details #' \code{split_string}: Splits string on regular expression. -#' Equivalent to \code{split} SQL function. +#' Equivalent to \code{split} SQL function. Optionally a +#' \code{limit} can be specified #' #' @rdname column_string_functions +#' @param limit determines the length of the returned array. +#' \itemize{ +#' \item \code{limit > 0}: length of the array will be at most \code{limit} +#' \item \code{limit <= 0}: the returned array can have any length +#' } +#' #' @aliases split_string split_string,Column-method #' @examples #' #' \dontrun{ +#' head(select(df, split_string(df$Class, "\\d", 2))) #' head(select(df, split_string(df$Sex, "a"))) #' head(select(df, split_string(df$Class, "\\d"))) #' # This is equivalent to the following SQL expression @@ -3487,8 +3495,9 @@ setMethod("collect_set", #' @note split_string 2.3.0 setMethod("split_string", signature(x = "Column", pattern = "character"), - function(x, pattern) { - jc <- callJStatic("org.apache.spark.sql.functions", "split", x@jc, pattern) + function(x, pattern, limit = -1) { + jc <- callJStatic("org.apache.spark.sql.functions", + "split", x@jc, pattern, as.integer(limit)) column(jc) }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 27c1b312d645c..697d124095a75 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1258,7 +1258,7 @@ setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") #' @rdname column_string_functions #' @name NULL -setGeneric("split_string", function(x, pattern) { standardGeneric("split_string") }) +setGeneric("split_string", function(x, pattern, ...) { standardGeneric("split_string") }) #' @rdname column_string_functions #' @name NULL diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 50eff3755edf8..5cc75aa3f3673 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1819,6 +1819,14 @@ test_that("string operators", { collect(select(df4, split_string(df4$a, "\\\\")))[1, 1], list(list("a.b@c.d 1", "b")) ) + expect_equal( + collect(select(df4, split_string(df4$a, "\\.", 2)))[1, 1], + list(list("a", "b@c.d 1\\b")) + ) + expect_equal( + collect(select(df4, split_string(df4$a, "b", 0)))[1, 1], + list(list("a.", "@c.d 1\\", "")) + ) l5 <- list(list(a = "abc")) df5 <- createDataFrame(l5) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index dff4a73f3e9da..3a3bfc4a94bb3 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -958,6 +958,12 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { } public UTF8String[] split(UTF8String pattern, int limit) { + // Java String's split method supports "ignore empty string" behavior when the limit is 0 + // whereas other languages do not. To avoid this java specific behavior, we fall back to + // -1 when the limit is 0. + if (limit == 0) { + limit = -1; + } String[] splits = toString().split(pattern.toString(), limit); UTF8String[] res = new UTF8String[splits.length]; for (int i = 0; i < res.length; i++) { diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index dae13f03b02ff..cf9cc6b1800a9 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -393,12 +393,14 @@ public void substringSQL() { @Test public void split() { - assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), -1), - new UTF8String[]{fromString("ab"), fromString("def"), fromString("ghi")})); - assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2), - new UTF8String[]{fromString("ab"), fromString("def,ghi")})); - assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2), - new UTF8String[]{fromString("ab"), fromString("def,ghi")})); + UTF8String[] negativeAndZeroLimitCase = + new UTF8String[]{fromString("ab"), fromString("def"), fromString("ghi"), fromString("")}; + assertTrue(Arrays.equals(fromString("ab,def,ghi,").split(fromString(","), 0), + negativeAndZeroLimitCase)); + assertTrue(Arrays.equals(fromString("ab,def,ghi,").split(fromString(","), -1), + negativeAndZeroLimitCase)); + assertTrue(Arrays.equals(fromString("ab,def,ghi,").split(fromString(","), 2), + new UTF8String[]{fromString("ab"), fromString("def,ghi,")})); } @Test diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 3128d5792eead..7685264b2d4d1 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1691,18 +1691,32 @@ def repeat(col, n): @since(1.5) @ignore_unicode_prefix -def split(str, pattern): +def split(str, pattern, limit=-1): """ - Splits str around pattern (pattern is a regular expression). + Splits str around matches of the given pattern. - .. note:: pattern is a string represent the regular expression. + :param str: a string expression to split + :param pattern: a string representing a regular expression. The regex string should be + a Java regular expression. + :param limit: an integer which controls the number of times `pattern` is applied. - >>> df = spark.createDataFrame([('ab12cd',)], ['s',]) - >>> df.select(split(df.s, '[0-9]+').alias('s')).collect() - [Row(s=[u'ab', u'cd'])] + * ``limit > 0``: The resulting array's length will not be more than `limit`, and the + resulting array's last entry will contain all input beyond the last + matched pattern. + * ``limit <= 0``: `pattern` will be applied as many times as possible, and the resulting + array can be of any size. + + .. versionchanged:: 3.0 + `split` now takes an optional `limit` field. If not provided, default limit value is -1. + + >>> df = spark.createDataFrame([('oneAtwoBthreeC',)], ['s',]) + >>> df.select(split(df.s, '[ABC]', 2).alias('s')).collect() + [Row(s=[u'one', u'twoBthreeC'])] + >>> df.select(split(df.s, '[ABC]', -1).alias('s')).collect() + [Row(s=[u'one', u'two', u'three', u''])] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.split(_to_java_column(str), pattern)) + return Column(sc._jvm.functions.split(_to_java_column(str), pattern, limit)) @ignore_unicode_prefix diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index bf0c35fe61018..4f5ea1e95f833 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -157,7 +157,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi arguments = """ Arguments: * str - a string expression - * regexp - a string expression. The pattern string should be a Java regular expression. + * regexp - a string expression. The regex string should be a Java regular expression. Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL parser. For example, to match "\abc", a regular expression for `regexp` can be @@ -229,33 +229,53 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress /** - * Splits str around pat (pattern is a regular expression). + * Splits str around matches of the given regex. */ @ExpressionDescription( - usage = "_FUNC_(str, regex) - Splits `str` around occurrences that match `regex`.", + usage = "_FUNC_(str, regex, limit) - Splits `str` around occurrences that match `regex`" + + " and returns an array with a length of at most `limit`", + arguments = """ + Arguments: + * str - a string expression to split. + * regex - a string representing a regular expression. The regex string should be a + Java regular expression. + * limit - an integer expression which controls the number of times the regex is applied. + * limit > 0: The resulting array's length will not be more than `limit`, + and the resulting array's last entry will contain all input + beyond the last matched regex. + * limit <= 0: `regex` will be applied as many times as possible, and + the resulting array can be of any size. + """, examples = """ Examples: > SELECT _FUNC_('oneAtwoBthreeC', '[ABC]'); ["one","two","three",""] + > SELECT _FUNC_('oneAtwoBthreeC', '[ABC]', -1); + ["one","two","three",""] + > SELECT _FUNC_('oneAtwoBthreeC', '[ABC]', 2); + ["one","twoBthreeC"] """) -case class StringSplit(str: Expression, pattern: Expression) - extends BinaryExpression with ImplicitCastInputTypes { +case class StringSplit(str: Expression, regex: Expression, limit: Expression) + extends TernaryExpression with ImplicitCastInputTypes { - override def left: Expression = str - override def right: Expression = pattern override def dataType: DataType = ArrayType(StringType) - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) + override def children: Seq[Expression] = str :: regex :: limit :: Nil - override def nullSafeEval(string: Any, regex: Any): Any = { - val strings = string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1) + def this(exp: Expression, regex: Expression) = this(exp, regex, Literal(-1)); + + override def nullSafeEval(string: Any, regex: Any, limit: Any): Any = { + val strings = string.asInstanceOf[UTF8String].split( + regex.asInstanceOf[UTF8String], limit.asInstanceOf[Int]) new GenericArrayData(strings.asInstanceOf[Array[Any]]) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val arrayClass = classOf[GenericArrayData].getName - nullSafeCodeGen(ctx, ev, (str, pattern) => + nullSafeCodeGen(ctx, ev, (str, regex, limit) => { // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. - s"""${ev.value} = new $arrayClass($str.split($pattern, -1));""") + s"""${ev.value} = new $arrayClass($str.split($regex,$limit));""".stripMargin + }) } override def prettyName: String = "split" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index d532dc4f77198..06fb73ad83923 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -225,11 +225,18 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val row3 = create_row("aa2bb3cc", null) checkEvaluation( - StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+")), Seq("aa", "bb", "cc"), row1) + StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+"), -1), Seq("aa", "bb", "cc"), row1) checkEvaluation( - StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1) - checkEvaluation(StringSplit(s1, s2), null, row2) - checkEvaluation(StringSplit(s1, s2), null, row3) + StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+"), 2), Seq("aa", "bb3cc"), row1) + // limit = 0 should behave just like limit = -1 + checkEvaluation( + StringSplit(Literal("aacbbcddc"), Literal("c"), 0), Seq("aa", "bb", "dd", ""), row1) + checkEvaluation( + StringSplit(Literal("aacbbcddc"), Literal("c"), -1), Seq("aa", "bb", "dd", ""), row1) + checkEvaluation( + StringSplit(s1, s2, -1), Seq("aa", "bb", "cc"), row1) + checkEvaluation(StringSplit(s1, s2, -1), null, row2) + checkEvaluation(StringSplit(s1, s2, -1), null, row3) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 367ac66dd77f5..4247d3110f1e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2546,15 +2546,39 @@ object functions { def soundex(e: Column): Column = withExpr { SoundEx(e.expr) } /** - * Splits str around pattern (pattern is a regular expression). + * Splits str around matches of the given regex. * - * @note Pattern is a string representation of the regular expression. + * @param str a string expression to split + * @param regex a string representing a regular expression. The regex string should be + * a Java regular expression. * * @group string_funcs * @since 1.5.0 */ - def split(str: Column, pattern: String): Column = withExpr { - StringSplit(str.expr, lit(pattern).expr) + def split(str: Column, regex: String): Column = withExpr { + StringSplit(str.expr, Literal(regex), Literal(-1)) + } + + /** + * Splits str around matches of the given regex. + * + * @param str a string expression to split + * @param regex a string representing a regular expression. The regex string should be + * a Java regular expression. + * @param limit an integer expression which controls the number of times the regex is applied. + *
        + *
      • limit greater than 0: The resulting array's length will not be more than limit, + * and the resulting array's last entry will contain all input beyond the last + * matched regex.
      • + *
      • limit less than or equal to 0: `regex` will be applied as many times as + * possible, and the resulting array can be of any size.
      • + *
      + * + * @group string_funcs + * @since 3.0.0 + */ + def split(str: Column, regex: String, limit: Int): Column = withExpr { + StringSplit(str.expr, Literal(regex), Literal(limit)) } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index 4113734e1707e..2effb43183d75 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -46,4 +46,8 @@ FROM ( encode(string(id + 2), 'utf-8') col3, encode(string(id + 3), 'utf-8') col4 FROM range(10) -) +); + +-- split function +SELECT split('aa1cc2ee3', '[1-9]+'); +SELECT split('aa1cc2ee3', '[1-9]+', 2); diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 7b3dc84388889..e8f2e0a81455a 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 15 +-- Number of queries: 17 -- !query 0 @@ -161,3 +161,19 @@ struct == Physical Plan == *Project [concat(cast(id#xL as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x] +- *Range (0, 10, step=1, splits=2) + + +-- !query 15 +SELECT split('aa1cc2ee3', '[1-9]+') +-- !query 15 schema +struct> +-- !query 15 output +["aa","cc","ee",""] + + +-- !query 16 +SELECT split('aa1cc2ee3', '[1-9]+', 2) +-- !query 16 schema +struct> +-- !query 16 output +["aa","cc2ee3"] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 3d76b9ac33e57..bb19fde2b2b5f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -329,16 +329,52 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { Row(" ")) } - test("string split function") { - val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b") + test("string split function with no limit") { + val df = Seq(("aa2bb3cc4", "[1-9]+")).toDF("a", "b") checkAnswer( df.select(split($"a", "[1-9]+")), - Row(Seq("aa", "bb", "cc"))) + Row(Seq("aa", "bb", "cc", ""))) checkAnswer( df.selectExpr("split(a, '[1-9]+')"), - Row(Seq("aa", "bb", "cc"))) + Row(Seq("aa", "bb", "cc", ""))) + } + + test("string split function with limit explicitly set to 0") { + val df = Seq(("aa2bb3cc4", "[1-9]+")).toDF("a", "b") + + checkAnswer( + df.select(split($"a", "[1-9]+", 0)), + Row(Seq("aa", "bb", "cc", ""))) + + checkAnswer( + df.selectExpr("split(a, '[1-9]+', 0)"), + Row(Seq("aa", "bb", "cc", ""))) + } + + test("string split function with positive limit") { + val df = Seq(("aa2bb3cc4", "[1-9]+")).toDF("a", "b") + + checkAnswer( + df.select(split($"a", "[1-9]+", 2)), + Row(Seq("aa", "bb3cc4"))) + + checkAnswer( + df.selectExpr("split(a, '[1-9]+', 2)"), + Row(Seq("aa", "bb3cc4"))) + } + + test("string split function with negative limit") { + val df = Seq(("aa2bb3cc4", "[1-9]+")).toDF("a", "b") + + checkAnswer( + df.select(split($"a", "[1-9]+", -2)), + Row(Seq("aa", "bb", "cc", ""))) + + checkAnswer( + df.selectExpr("split(a, '[1-9]+', -2)"), + Row(Seq("aa", "bb", "cc", ""))) } test("string / binary length function") { From f2f4e7afe730badaf443f459b27fe40879947d51 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Sat, 6 Oct 2018 14:49:51 +0800 Subject: [PATCH 1777/2461] [SPARK-25600][SQL][MINOR] Make use of TypeCoercion.findTightestCommonType while inferring CSV schema. ## What changes were proposed in this pull request? Current the CSV's infer schema code inlines `TypeCoercion.findTightestCommonType`. This is a minor refactor to make use of the common type coercion code when applicable. This way we can take advantage of any improvement to the base method. Thanks to MaxGekk for finding this while reviewing another PR. ## How was this patch tested? This is a minor refactor. Existing tests are used to verify the change. Closes #22619 from dilipbiswal/csv_minor. Authored-by: Dilip Biswal Signed-off-by: hyukjinkwon --- .../datasources/csv/CSVInferSchema.scala | 37 +++++++------------ 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index a585cbed2551b..3596ff105fd7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -70,7 +70,7 @@ private[csv] object CSVInferSchema { def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = { first.zipAll(second, NullType, NullType).map { case (a, b) => - findTightestCommonType(a, b).getOrElse(NullType) + compatibleType(a, b).getOrElse(NullType) } } @@ -88,7 +88,7 @@ private[csv] object CSVInferSchema { case LongType => tryParseLong(field, options) case _: DecimalType => // DecimalTypes have different precisions and scales, so we try to find the common type. - findTightestCommonType(typeSoFar, tryParseDecimal(field, options)).getOrElse(StringType) + compatibleType(typeSoFar, tryParseDecimal(field, options)).getOrElse(StringType) case DoubleType => tryParseDouble(field, options) case TimestampType => tryParseTimestamp(field, options) case BooleanType => tryParseBoolean(field, options) @@ -172,35 +172,27 @@ private[csv] object CSVInferSchema { StringType } - private val numericPrecedence: IndexedSeq[DataType] = TypeCoercion.numericPrecedence + /** + * Returns the common data type given two input data types so that the return type + * is compatible with both input data types. + */ + private def compatibleType(t1: DataType, t2: DataType): Option[DataType] = { + TypeCoercion.findTightestCommonType(t1, t2).orElse(findCompatibleTypeForCSV(t1, t2)) + } /** - * Copied from internal Spark api - * [[org.apache.spark.sql.catalyst.analysis.TypeCoercion]] + * The following pattern matching represents additional type promotion rules that + * are CSV specific. */ - val findTightestCommonType: (DataType, DataType) => Option[DataType] = { - case (t1, t2) if t1 == t2 => Some(t1) - case (NullType, t1) => Some(t1) - case (t1, NullType) => Some(t1) + private val findCompatibleTypeForCSV: (DataType, DataType) => Option[DataType] = { case (StringType, t2) => Some(StringType) case (t1, StringType) => Some(StringType) - // Promote numeric types to the highest of the two and all numeric types to unlimited decimal - case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) => - val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2) - Some(numericPrecedence(index)) - - // These two cases below deal with when `DecimalType` is larger than `IntegralType`. - case (t1: IntegralType, t2: DecimalType) if t2.isWiderThan(t1) => - Some(t2) - case (t1: DecimalType, t2: IntegralType) if t1.isWiderThan(t2) => - Some(t1) - // These two cases below deal with when `IntegralType` is larger than `DecimalType`. case (t1: IntegralType, t2: DecimalType) => - findTightestCommonType(DecimalType.forType(t1), t2) + compatibleType(DecimalType.forType(t1), t2) case (t1: DecimalType, t2: IntegralType) => - findTightestCommonType(t1, DecimalType.forType(t2)) + compatibleType(t1, DecimalType.forType(t2)) // Double support larger range than fixed decimal, DecimalType.Maximum should be enough // in most case, also have better precision. @@ -216,7 +208,6 @@ private[csv] object CSVInferSchema { } else { Some(DecimalType(range + scale, scale)) } - case _ => None } } From 1ee472eec15e104c4cd087179a9491dc542e15d7 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sat, 6 Oct 2018 14:54:04 +0800 Subject: [PATCH 1778/2461] [SPARK-25621][SPARK-25622][TEST] Reduce test time of BucketedReadWithHiveSupportSuite ## What changes were proposed in this pull request? By replacing loops with random possible value. - `read partitioning bucketed tables with bucket pruning filters` reduce from 55s to 7s - `read partitioning bucketed tables having composite filters` reduce from 54s to 8s - total time: reduce from 288s to 192s ## How was this patch tested? Unit test Closes #22640 from gengliangwang/fastenBucketedReadSuite. Authored-by: Gengliang Wang Signed-off-by: hyukjinkwon --- .../spark/sql/sources/BucketedReadSuite.scala | 181 +++++++++--------- 1 file changed, 91 insertions(+), 90 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index a9414200e70f8..a2bc651bb2bd5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.sources import java.io.File import java.net.URI +import scala.util.Random + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions @@ -47,11 +49,13 @@ class BucketedReadWithoutHiveSupportSuite extends BucketedReadSuite with SharedS abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { import testImplicits._ - private lazy val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") + private val maxI = 5 + private val maxJ = 13 + private lazy val df = (0 until 50).map(i => (i % maxI, i % maxJ, i.toString)).toDF("i", "j", "k") private lazy val nullDF = (for { i <- 0 to 50 s <- Seq(null, "a", "b", "c", "d", "e", "f", null, "g") - } yield (i % 5, s, i % 13)).toDF("i", "j", "k") + } yield (i % maxI, s, i % maxJ)).toDF("i", "j", "k") // number of buckets that doesn't yield empty buckets when bucketing on column j on df/nullDF // empty buckets before filtering might hide bugs in pruning logic @@ -66,23 +70,22 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { .bucketBy(8, "j", "k") .saveAsTable("bucketed_table") - for (i <- 0 until 5) { - val table = spark.table("bucketed_table").filter($"i" === i) - val query = table.queryExecution - val output = query.analyzed.output - val rdd = query.toRdd - - assert(rdd.partitions.length == 8) - - val attrs = table.select("j", "k").queryExecution.analyzed.output - val checkBucketId = rdd.mapPartitionsWithIndex((index, rows) => { - val getBucketId = UnsafeProjection.create( - HashPartitioning(attrs, 8).partitionIdExpression :: Nil, - output) - rows.map(row => getBucketId(row).getInt(0) -> index) - }) - checkBucketId.collect().foreach(r => assert(r._1 == r._2)) - } + val bucketValue = Random.nextInt(maxI) + val table = spark.table("bucketed_table").filter($"i" === bucketValue) + val query = table.queryExecution + val output = query.analyzed.output + val rdd = query.toRdd + + assert(rdd.partitions.length == 8) + + val attrs = table.select("j", "k").queryExecution.analyzed.output + val checkBucketId = rdd.mapPartitionsWithIndex((index, rows) => { + val getBucketId = UnsafeProjection.create( + HashPartitioning(attrs, 8).partitionIdExpression :: Nil, + output) + rows.map(row => getBucketId(row).getInt(0) -> index) + }) + checkBucketId.collect().foreach(r => assert(r._1 == r._2)) } } @@ -145,36 +148,36 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { .bucketBy(numBuckets, "j") .saveAsTable("bucketed_table") - for (j <- 0 until 13) { - // Case 1: EqualTo - checkPrunedAnswers( - bucketSpec, - bucketValues = j :: Nil, - filterCondition = $"j" === j, - df) - - // Case 2: EqualNullSafe - checkPrunedAnswers( - bucketSpec, - bucketValues = j :: Nil, - filterCondition = $"j" <=> j, - df) - - // Case 3: In - checkPrunedAnswers( - bucketSpec, - bucketValues = Seq(j, j + 1, j + 2, j + 3), - filterCondition = $"j".isin(j, j + 1, j + 2, j + 3), - df) - - // Case 4: InSet - val inSetExpr = expressions.InSet($"j".expr, Set(j, j + 1, j + 2, j + 3).map(lit(_).expr)) - checkPrunedAnswers( - bucketSpec, - bucketValues = Seq(j, j + 1, j + 2, j + 3), - filterCondition = Column(inSetExpr), - df) - } + val bucketValue = Random.nextInt(maxJ) + // Case 1: EqualTo + checkPrunedAnswers( + bucketSpec, + bucketValues = bucketValue :: Nil, + filterCondition = $"j" === bucketValue, + df) + + // Case 2: EqualNullSafe + checkPrunedAnswers( + bucketSpec, + bucketValues = bucketValue :: Nil, + filterCondition = $"j" <=> bucketValue, + df) + + // Case 3: In + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(bucketValue, bucketValue + 1, bucketValue + 2, bucketValue + 3), + filterCondition = $"j".isin(bucketValue, bucketValue + 1, bucketValue + 2, bucketValue + 3), + df) + + // Case 4: InSet + val inSetExpr = expressions.InSet($"j".expr, + Set(bucketValue, bucketValue + 1, bucketValue + 2, bucketValue + 3).map(lit(_).expr)) + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(bucketValue, bucketValue + 1, bucketValue + 2, bucketValue + 3), + filterCondition = Column(inSetExpr), + df) } } @@ -188,13 +191,12 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { .bucketBy(numBuckets, "j") .saveAsTable("bucketed_table") - for (j <- 0 until 13) { - checkPrunedAnswers( - bucketSpec, - bucketValues = j :: Nil, - filterCondition = $"j" === j, - df) - } + val bucketValue = Random.nextInt(maxJ) + checkPrunedAnswers( + bucketSpec, + bucketValues = bucketValue :: Nil, + filterCondition = $"j" === bucketValue, + df) } } @@ -236,40 +238,39 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { .bucketBy(numBuckets, "j") .saveAsTable("bucketed_table") - for (j <- 0 until 13) { - checkPrunedAnswers( - bucketSpec, - bucketValues = j :: Nil, - filterCondition = $"j" === j && $"k" > $"j", - df) - - checkPrunedAnswers( - bucketSpec, - bucketValues = j :: Nil, - filterCondition = $"j" === j && $"i" > j % 5, - df) - - // check multiple bucket values OR condition - checkPrunedAnswers( - bucketSpec, - bucketValues = Seq(j, j + 1), - filterCondition = $"j" === j || $"j" === (j + 1), - df) - - // check bucket value and none bucket value OR condition - checkPrunedAnswers( - bucketSpec, - bucketValues = Nil, - filterCondition = $"j" === j || $"i" === 0, - df) - - // check AND condition in complex expression - checkPrunedAnswers( - bucketSpec, - bucketValues = Seq(j), - filterCondition = ($"i" === 0 || $"k" > $"j") && $"j" === j, - df) - } + val bucketValue = Random.nextInt(maxJ) + checkPrunedAnswers( + bucketSpec, + bucketValues = bucketValue :: Nil, + filterCondition = $"j" === bucketValue && $"k" > $"j", + df) + + checkPrunedAnswers( + bucketSpec, + bucketValues = bucketValue :: Nil, + filterCondition = $"j" === bucketValue && $"i" > bucketValue % 5, + df) + + // check multiple bucket values OR condition + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(bucketValue, bucketValue + 1), + filterCondition = $"j" === bucketValue || $"j" === (bucketValue + 1), + df) + + // check bucket value and none bucket value OR condition + checkPrunedAnswers( + bucketSpec, + bucketValues = Nil, + filterCondition = $"j" === bucketValue || $"i" === 0, + df) + + // check AND condition in complex expression + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(bucketValue), + filterCondition = ($"i" === 0 || $"k" > $"j") && $"j" === bucketValue, + df) } } From edf42866118c8522dedea3fab848b04a7c50e44c Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 6 Oct 2018 08:47:43 -0700 Subject: [PATCH 1779/2461] [SPARK-25488][SQL][TEST] Refactor MiscBenchmark to use main method ## What changes were proposed in this pull request? Refactor `MiscBenchmark ` to use main method. Generate benchmark result: ```sh SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain org.apache.spark.sql.execution.benchmark.MiscBenchmark" ``` ## How was this patch tested? manual tests Closes #22500 from wangyum/SPARK-25488. Lead-authored-by: Yuming Wang Co-authored-by: Yuming Wang Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- sql/core/benchmarks/MiscBenchmark-results.txt | 120 +++++++ .../execution/benchmark/MiscBenchmark.scala | 331 ++++++------------ 2 files changed, 232 insertions(+), 219 deletions(-) create mode 100644 sql/core/benchmarks/MiscBenchmark-results.txt diff --git a/sql/core/benchmarks/MiscBenchmark-results.txt b/sql/core/benchmarks/MiscBenchmark-results.txt new file mode 100644 index 0000000000000..85acd57893655 --- /dev/null +++ b/sql/core/benchmarks/MiscBenchmark-results.txt @@ -0,0 +1,120 @@ +================================================================================================ +filter & aggregate without group +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +range/filter/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +range/filter/sum wholestage off 47752 / 48952 43.9 22.8 1.0X +range/filter/sum wholestage on 3123 / 3558 671.5 1.5 15.3X + + +================================================================================================ +range/limit/sum +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +range/limit/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +range/limit/sum wholestage off 229 / 236 2288.9 0.4 1.0X +range/limit/sum wholestage on 257 / 267 2041.0 0.5 0.9X + + +================================================================================================ +sample +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +sample with replacement: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +sample with replacement wholestage off 12908 / 13076 10.2 98.5 1.0X +sample with replacement wholestage on 7334 / 7346 17.9 56.0 1.8X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +sample without replacement: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +sample without replacement wholestage off 3082 / 3095 42.5 23.5 1.0X +sample without replacement wholestage on 1125 / 1211 116.5 8.6 2.7X + + +================================================================================================ +collect +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +collect 1 million 291 / 311 3.6 277.3 1.0X +collect 2 millions 552 / 564 1.9 526.6 0.5X +collect 4 millions 1104 / 1108 0.9 1053.0 0.3X + + +================================================================================================ +collect limit +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +collect limit: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +collect limit 1 million 311 / 340 3.4 296.2 1.0X +collect limit 2 millions 581 / 614 1.8 554.4 0.5X + + +================================================================================================ +generate explode +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +generate explode array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +generate explode array wholestage off 15211 / 15368 1.1 906.6 1.0X +generate explode array wholestage on 10761 / 10776 1.6 641.4 1.4X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +generate explode map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +generate explode map wholestage off 22128 / 22578 0.8 1318.9 1.0X +generate explode map wholestage on 16421 / 16520 1.0 978.8 1.3X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +generate posexplode array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +generate posexplode array wholestage off 17108 / 18019 1.0 1019.7 1.0X +generate posexplode array wholestage on 11715 / 11804 1.4 698.3 1.5X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +generate inline array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +generate inline array wholestage off 16358 / 16418 1.0 975.0 1.0X +generate inline array wholestage on 11152 / 11472 1.5 664.7 1.5X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +generate big struct array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +generate big struct array wholestage off 708 / 776 0.1 11803.5 1.0X +generate big struct array wholestage on 535 / 589 0.1 8913.9 1.3X + + +================================================================================================ +generate regular generator +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +generate stack: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +generate stack wholestage off 29082 / 29393 0.6 1733.4 1.0X +generate stack wholestage on 21066 / 21128 0.8 1255.6 1.4X + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala index f44da242e62b9..43380869fefe4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala @@ -21,247 +21,140 @@ import org.apache.spark.benchmark.Benchmark /** * Benchmark to measure whole stage codegen performance. - * To run this: - * build/sbt "sql/test-only *benchmark.MiscBenchmark" - * - * Benchmarks in this file are skipped in normal builds. + * To run this benchmark: + * {{{ + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/MiscBenchmark-results.txt". + * }}} */ -class MiscBenchmark extends BenchmarkWithCodegen { +object MiscBenchmark extends SqlBasedBenchmark { - ignore("filter & aggregate without group") { - val N = 500L << 22 - runBenchmark("range/filter/sum", N) { - sparkSession.range(N).filter("(id & 1) = 1").groupBy().sum().collect() + def filterAndAggregateWithoutGroup(numRows: Long): Unit = { + runBenchmark("filter & aggregate without group") { + codegenBenchmark("range/filter/sum", numRows) { + spark.range(numRows).filter("(id & 1) = 1").groupBy().sum().collect() + } } - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - range/filter/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - range/filter/sum codegen=false 30663 / 31216 68.4 14.6 1.0X - range/filter/sum codegen=true 2399 / 2409 874.1 1.1 12.8X - */ } - ignore("range/limit/sum") { - val N = 500L << 20 - runBenchmark("range/limit/sum", N) { - sparkSession.range(N).limit(1000000).groupBy().sum().collect() + def limitAndAggregateWithoutGroup(numRows: Long): Unit = { + runBenchmark("range/limit/sum") { + codegenBenchmark("range/limit/sum", numRows) { + spark.range(numRows).limit(1000000).groupBy().sum().collect() + } } - /* - Westmere E56xx/L56xx/X56xx (Nehalem-C) - range/limit/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - range/limit/sum codegen=false 609 / 672 861.6 1.2 1.0X - range/limit/sum codegen=true 561 / 621 935.3 1.1 1.1X - */ } - ignore("sample") { - val N = 500 << 18 - runBenchmark("sample with replacement", N) { - sparkSession.range(N).sample(withReplacement = true, 0.01).groupBy().sum().collect() - } - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + def sample(numRows: Int): Unit = { + runBenchmark("sample") { + codegenBenchmark("sample with replacement", numRows) { + spark.range(numRows).sample(withReplacement = true, 0.01).groupBy().sum().collect() + } - sample with replacement: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - sample with replacement codegen=false 7073 / 7227 18.5 54.0 1.0X - sample with replacement codegen=true 5199 / 5203 25.2 39.7 1.4X - */ - - runBenchmark("sample without replacement", N) { - sparkSession.range(N).sample(withReplacement = false, 0.01).groupBy().sum().collect() + codegenBenchmark("sample without replacement", numRows) { + spark.range(numRows).sample(withReplacement = false, 0.01).groupBy().sum().collect() + } } - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - sample without replacement: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - sample without replacement codegen=false 1508 / 1529 86.9 11.5 1.0X - sample without replacement codegen=true 644 / 662 203.5 4.9 2.3X - */ } - ignore("collect") { - val N = 1 << 20 - - val benchmark = new Benchmark("collect", N) - benchmark.addCase("collect 1 million") { iter => - sparkSession.range(N).collect() - } - benchmark.addCase("collect 2 millions") { iter => - sparkSession.range(N * 2).collect() - } - benchmark.addCase("collect 4 millions") { iter => - sparkSession.range(N * 4).collect() + def collect(numRows: Int): Unit = { + runBenchmark("collect") { + val benchmark = new Benchmark("collect", numRows, output = output) + benchmark.addCase("collect 1 million") { iter => + spark.range(numRows).collect() + } + benchmark.addCase("collect 2 millions") { iter => + spark.range(numRows * 2).collect() + } + benchmark.addCase("collect 4 millions") { iter => + spark.range(numRows * 4).collect() + } + benchmark.run() } - benchmark.run() - - /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - collect 1 million 439 / 654 2.4 418.7 1.0X - collect 2 millions 961 / 1907 1.1 916.4 0.5X - collect 4 millions 3193 / 3895 0.3 3044.7 0.1X - */ } - ignore("collect limit") { - val N = 1 << 20 - - val benchmark = new Benchmark("collect limit", N) - benchmark.addCase("collect limit 1 million") { iter => - sparkSession.range(N * 4).limit(N).collect() + def collectLimit(numRows: Int): Unit = { + runBenchmark("collect limit") { + val benchmark = new Benchmark("collect limit", numRows, output = output) + benchmark.addCase("collect limit 1 million") { iter => + spark.range(numRows * 4).limit(numRows).collect() + } + benchmark.addCase("collect limit 2 millions") { iter => + spark.range(numRows * 4).limit(numRows * 2).collect() + } + benchmark.run() } - benchmark.addCase("collect limit 2 millions") { iter => - sparkSession.range(N * 4).limit(N * 2).collect() - } - benchmark.run() - - /* - model name : Westmere E56xx/L56xx/X56xx (Nehalem-C) - collect limit: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - collect limit 1 million 833 / 1284 1.3 794.4 1.0X - collect limit 2 millions 3348 / 4005 0.3 3193.3 0.2X - */ } - ignore("generate explode") { - val N = 1 << 24 - runBenchmark("generate explode array", N) { - val df = sparkSession.range(N).selectExpr( - "id as key", - "array(rand(), rand(), rand(), rand(), rand()) as values") - df.selectExpr("key", "explode(values) value").count() - } - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 - Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz - - generate explode array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - generate explode array wholestage off 6920 / 7129 2.4 412.5 1.0X - generate explode array wholestage on 623 / 646 26.9 37.1 11.1X - */ - - runBenchmark("generate explode map", N) { - val df = sparkSession.range(N).selectExpr( - "id as key", - "map('a', rand(), 'b', rand(), 'c', rand(), 'd', rand(), 'e', rand()) pairs") - df.selectExpr("key", "explode(pairs) as (k, v)").count() - } - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 - Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz - - generate explode map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - generate explode map wholestage off 11978 / 11993 1.4 714.0 1.0X - generate explode map wholestage on 866 / 919 19.4 51.6 13.8X - */ - - runBenchmark("generate posexplode array", N) { - val df = sparkSession.range(N).selectExpr( - "id as key", - "array(rand(), rand(), rand(), rand(), rand()) as values") - df.selectExpr("key", "posexplode(values) as (idx, value)").count() + def explode(numRows: Int): Unit = { + runBenchmark("generate explode") { + codegenBenchmark("generate explode array", numRows) { + val df = spark.range(numRows).selectExpr( + "id as key", + "array(rand(), rand(), rand(), rand(), rand()) as values") + df.selectExpr("key", "explode(values) value").count() + } + + codegenBenchmark("generate explode map", numRows) { + val df = spark.range(numRows).selectExpr( + "id as key", + "map('a', rand(), 'b', rand(), 'c', rand(), 'd', rand(), 'e', rand()) pairs") + df.selectExpr("key", "explode(pairs) as (k, v)").count() + } + + codegenBenchmark("generate posexplode array", numRows) { + val df = spark.range(numRows).selectExpr( + "id as key", + "array(rand(), rand(), rand(), rand(), rand()) as values") + df.selectExpr("key", "posexplode(values) as (idx, value)").count() + } + + codegenBenchmark("generate inline array", numRows) { + val df = spark.range(numRows).selectExpr( + "id as key", + "array((rand(), rand()), (rand(), rand()), (rand(), 0.0d)) as values") + df.selectExpr("key", "inline(values) as (r1, r2)").count() + } + + val M = 60000 + codegenBenchmark("generate big struct array", M) { + import spark.implicits._ + val df = spark.sparkContext.parallelize(Seq(("1", + Array.fill(M)({ + val i = math.random + (i.toString, (i + 1).toString, (i + 2).toString, (i + 3).toString) + })))).toDF("col", "arr") + + df.selectExpr("*", "explode(arr) as arr_col") + .select("col", "arr_col.*").count + } } - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 - Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz - - generate posexplode array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - generate posexplode array wholestage off 7502 / 7513 2.2 447.1 1.0X - generate posexplode array wholestage on 617 / 623 27.2 36.8 12.2X - */ - - runBenchmark("generate inline array", N) { - val df = sparkSession.range(N).selectExpr( - "id as key", - "array((rand(), rand()), (rand(), rand()), (rand(), 0.0d)) as values") - df.selectExpr("key", "inline(values) as (r1, r2)").count() - } - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 - Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz - - generate inline array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - generate inline array wholestage off 6901 / 6928 2.4 411.3 1.0X - generate inline array wholestage on 1001 / 1010 16.8 59.7 6.9X - */ - - val M = 60000 - runBenchmark("generate big struct array", M) { - import sparkSession.implicits._ - val df = sparkSession.sparkContext.parallelize(Seq(("1", - Array.fill(M)({ - val i = math.random - (i.toString, (i + 1).toString, (i + 2).toString, (i + 3).toString) - })))).toDF("col", "arr") - - df.selectExpr("*", "expode(arr) as arr_col") - .select("col", "arr_col.*").count - } - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 - Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz - - test the impact of adding the optimization of Generate.unrequiredChildIndex, - we can see enormous improvement of x250 in this case! and it grows O(n^2). - - with Optimization ON: - - generate big struct array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - generate big struct array wholestage off 331 / 378 0.2 5524.9 1.0X - generate big struct array wholestage on 205 / 232 0.3 3413.1 1.6X - - with Optimization OFF: - - generate big struct array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - generate big struct array wholestage off 49697 / 51496 0.0 828277.7 1.0X - generate big struct array wholestage on 50558 / 51434 0.0 842641.6 1.0X - */ - } - ignore("generate regular generator") { - val N = 1 << 24 - runBenchmark("generate stack", N) { - val df = sparkSession.range(N).selectExpr( - "id as key", - "id % 2 as t1", - "id % 3 as t2", - "id % 5 as t3", - "id % 7 as t4", - "id % 13 as t5") - df.selectExpr("key", "stack(4, t1, t2, t3, t4, t5)").count() + def stack(numRows: Int): Unit = { + runBenchmark("generate regular generator") { + codegenBenchmark("generate stack", numRows) { + val df = spark.range(numRows).selectExpr( + "id as key", + "id % 2 as t1", + "id % 3 as t2", + "id % 5 as t3", + "id % 7 as t4", + "id % 13 as t5") + df.selectExpr("key", "stack(4, t1, t2, t3, t4, t5)").count() + } } - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 - Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz - - generate stack: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - generate stack wholestage off 12953 / 13070 1.3 772.1 1.0X - generate stack wholestage on 836 / 847 20.1 49.8 15.5X - */ } + override def runBenchmarkSuite(): Unit = { + filterAndAggregateWithoutGroup(500L << 22) + limitAndAggregateWithoutGroup(500L << 20) + sample(500 << 18) + collect(1 << 20) + collectLimit(1 << 20) + explode(1 << 24) + stack(1 << 24) + } } From 7ef65c0537ae3ac0961617f427584cc2e3d2a057 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 6 Oct 2018 08:50:50 -0700 Subject: [PATCH 1780/2461] [HOT-FIX] Fix compilation errors. --- .../scala/org/apache/spark/metrics/sink/GangliaSink.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/external/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala b/external/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala index 93db4773372cd..4fb9f2f849085 100644 --- a/external/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala +++ b/external/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala @@ -64,12 +64,12 @@ class GangliaSink(val property: Properties, val registry: MetricRegistry, val ttl = propertyToOption(GANGLIA_KEY_TTL).map(_.toInt).getOrElse(GANGLIA_DEFAULT_TTL) val dmax = propertyToOption(GANGLIA_KEY_DMAX).map(_.toInt).getOrElse(GANGLIA_DEFAULT_DMAX) val mode: UDPAddressingMode = propertyToOption(GANGLIA_KEY_MODE) - .map(u => GMetric.UDPAddressingMode.valueOf(u.toUpperCase(Locale.Root))) + .map(u => GMetric.UDPAddressingMode.valueOf(u.toUpperCase(Locale.ROOT))) .getOrElse(GANGLIA_DEFAULT_MODE) val pollPeriod = propertyToOption(GANGLIA_KEY_PERIOD).map(_.toInt) .getOrElse(GANGLIA_DEFAULT_PERIOD) val pollUnit: TimeUnit = propertyToOption(GANGLIA_KEY_UNIT) - .map(u => TimeUnit.valueOf(u.toUpperCase(Locale.Root))) + .map(u => TimeUnit.valueOf(u.toUpperCase(Locale.ROOT))) .getOrElse(GANGLIA_DEFAULT_UNIT) MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) From 5a617ec4eac26c60facbace15f6f4222b86de6d4 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 6 Oct 2018 09:15:44 -0700 Subject: [PATCH 1781/2461] [MINOR] Clean up the joinCriteria in SQL parser ## What changes were proposed in this pull request? Clean up the joinCriteria parsing in the parser by directly using identifierList ## How was this patch tested? N/A Closes #22648 from gatorsmile/cleanupJoinCriteria. Authored-by: gatorsmile Signed-off-by: gatorsmile --- .../main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 16665eb0d7374..056998630b09f 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -468,7 +468,7 @@ joinType joinCriteria : ON booleanExpression - | USING '(' identifier (',' identifier)* ')' + | USING identifierList ; sample diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index da12a6519bd28..ba0b72e747fc9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -701,7 +701,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging // Resolve the join type and join condition val (joinType, condition) = Option(join.joinCriteria) match { case Some(c) if c.USING != null => - (UsingJoin(baseJoinType, c.identifier.asScala.map(_.getText)), None) + (UsingJoin(baseJoinType, visitIdentifierList(c.identifierList)), None) case Some(c) if c.booleanExpression != null => (baseJoinType, Option(expression(c.booleanExpression))) case None if join.NATURAL != null => From 9cbf105ab1256d65f027115ba5505842ce8fffe3 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 6 Oct 2018 09:40:42 -0700 Subject: [PATCH 1782/2461] [SPARK-25644][SS][FOLLOWUP][BUILD] Fix Scala 2.12 build error due to foreachBatch ## What changes were proposed in this pull request? This PR fixes the Scala-2.12 build error due to ambiguity in `foreachBatch` test cases. - https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-maven-hadoop-2.7-ubuntu-scala-2.12/428/console ```scala [error] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7-ubuntu-scala-2.12/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala:102: ambiguous reference to overloaded definition, [error] both method foreachBatch in class DataStreamWriter of type (function: org.apache.spark.api.java.function.VoidFunction2[org.apache.spark.sql.Dataset[Int],Long])org.apache.spark.sql.streaming.DataStreamWriter[Int] [error] and method foreachBatch in class DataStreamWriter of type (function: (org.apache.spark.sql.Dataset[Int], Long) => Unit)org.apache.spark.sql.streaming.DataStreamWriter[Int] [error] match argument types ((org.apache.spark.sql.Dataset[Int], Any) => Unit) [error] ds.writeStream.foreachBatch((_, _) => {}).trigger(Trigger.Continuous("1 second")).start() [error] ^ [error] /home/jenkins/workspace/spark-master-test-maven-hadoop-2.7-ubuntu-scala-2.12/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala:106: ambiguous reference to overloaded definition, [error] both method foreachBatch in class DataStreamWriter of type (function: org.apache.spark.api.java.function.VoidFunction2[org.apache.spark.sql.Dataset[Int],Long])org.apache.spark.sql.streaming.DataStreamWriter[Int] [error] and method foreachBatch in class DataStreamWriter of type (function: (org.apache.spark.sql.Dataset[Int], Long) => Unit)org.apache.spark.sql.streaming.DataStreamWriter[Int] [error] match argument types ((org.apache.spark.sql.Dataset[Int], Any) => Unit) [error] ds.writeStream.foreachBatch((_, _) => {}).partitionBy("value").start() [error] ^ ``` ## How was this patch tested? Manual. Since this failure occurs in Scala-2.12 profile and test cases, Jenkins will not test this. We need to build with Scala-2.12 and run the tests. Closes #22649 from dongjoon-hyun/SPARK-SCALA212. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala | 4 ++-- .../execution/streaming/sources/ForeachBatchSinkSuite.scala | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 39c2cde7de40d..5ee76990b54f4 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -33,7 +33,7 @@ import org.apache.kafka.common.TopicPartition import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ -import org.apache.spark.sql.{ForeachWriter, SparkSession} +import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession} import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.streaming._ @@ -900,7 +900,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { } testUtils.waitUntilOffsetAppears(topicPartition, 5) - val q = ds.writeStream.foreachBatch { (ds, epochId) => + val q = ds.writeStream.foreachBatch { (ds: Dataset[String], epochId: Long) => if (epochId == 0) { // Send more message before the tasks of the current batch start reading the current batch // data, so that the executors will prefetch messages in the next batch and drop them. In diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala index 71dff443e8836..3e9ccb0f705df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala @@ -99,11 +99,12 @@ class ForeachBatchSinkSuite extends StreamTest { } assert(ex1.getMessage.contains("foreachBatch function cannot be null")) val ex2 = intercept[AnalysisException] { - ds.writeStream.foreachBatch((_, _) => {}).trigger(Trigger.Continuous("1 second")).start() + ds.writeStream.foreachBatch((_: Dataset[Int], _: Long) => {}) + .trigger(Trigger.Continuous("1 second")).start() } assert(ex2.getMessage.contains("'foreachBatch' is not supported with continuous trigger")) val ex3 = intercept[AnalysisException] { - ds.writeStream.foreachBatch((_, _) => {}).partitionBy("value").start() + ds.writeStream.foreachBatch((_: Dataset[Int], _: Long) => {}).partitionBy("value").start() } assert(ex3.getMessage.contains("'foreachBatch' does not support partitioning")) } From b0cee9605e7c71cfd020aa917319478f9ac61bdb Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sat, 6 Oct 2018 14:50:03 -0700 Subject: [PATCH 1783/2461] [SPARK-25062][SQL] Clean up BlockLocations in InMemoryFileIndex ## What changes were proposed in this pull request? `InMemoryFileIndex` contains a cache of `LocatedFileStatus` objects. Each `LocatedFileStatus` object can contain several `BlockLocation`s or some subclass of it. Filling up this cache by listing files happens recursively either on the driver or on the executors, depending on the parallel discovery threshold (`spark.sql.sources.parallelPartitionDiscovery.threshold`). If the listing happens on the executors block location objects are converted to simple `BlockLocation` objects to ensure serialization requirements. If it happens on the driver then there is no conversion and depending on the file system a `BlockLocation` object can be a subclass like `HdfsBlockLocation` and consume more memory. This PR adds the conversion to the latter case and decreases memory consumption. ## How was this patch tested? Added unit test. Closes #22603 from peter-toth/SPARK-25062. Authored-by: Peter Toth Signed-off-by: Dongjoon Hyun --- .../datasources/InMemoryFileIndex.scala | 9 ++++- .../datasources/FileIndexSuite.scala | 39 ++++++++++++++++++- 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index dc5c2ff927e4a..fe418e610da8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -315,7 +315,14 @@ object InMemoryFileIndex extends Logging { // which is very slow on some file system (RawLocalFileSystem, which is launch a // subprocess and parse the stdout). try { - val locations = fs.getFileBlockLocations(f, 0, f.getLen) + val locations = fs.getFileBlockLocations(f, 0, f.getLen).map { loc => + // Store BlockLocation objects to consume less memory + if (loc.getClass == classOf[BlockLocation]) { + loc + } else { + new BlockLocation(loc.getNames, loc.getHosts, loc.getOffset, loc.getLength) + } + } val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, f.getModificationTime, 0, null, null, null, null, f.getPath, locations) if (f.isSymlink) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index 18bb4bfe661ce..49e7af4a9896b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -23,7 +23,7 @@ import java.net.URI import scala.collection.mutable import scala.language.reflectiveCalls -import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} +import org.apache.hadoop.fs.{BlockLocation, FileStatus, LocatedFileStatus, Path, RawLocalFileSystem} import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.catalyst.util._ @@ -248,6 +248,26 @@ class FileIndexSuite extends SharedSQLContext { assert(spark.read.parquet(path.getAbsolutePath).schema.exists(_.name == colToUnescape)) } } + + test("SPARK-25062 - InMemoryFileIndex stores BlockLocation objects no matter what subclass " + + "the FS returns") { + withSQLConf("fs.file.impl" -> classOf[SpecialBlockLocationFileSystem].getName) { + withTempDir { dir => + val file = new File(dir, "text.txt") + stringToFile(file, "text") + + val inMemoryFileIndex = new InMemoryFileIndex( + spark, Seq(new Path(file.getCanonicalPath)), Map.empty, None) { + def leafFileStatuses = leafFiles.values + } + val blockLocations = inMemoryFileIndex.leafFileStatuses.flatMap( + _.asInstanceOf[LocatedFileStatus].getBlockLocations) + + assert(blockLocations.forall(_.getClass == classOf[BlockLocation])) + } + } + } + } class FakeParentPathFileSystem extends RawLocalFileSystem { @@ -257,3 +277,20 @@ class FakeParentPathFileSystem extends RawLocalFileSystem { URI.create("mockFs://some-bucket") } } + +class SpecialBlockLocationFileSystem extends RawLocalFileSystem { + + class SpecialBlockLocation( + names: Array[String], + hosts: Array[String], + offset: Long, + length: Long) + extends BlockLocation(names, hosts, offset, length) + + override def getFileBlockLocations( + file: FileStatus, + start: Long, + len: Long): Array[BlockLocation] = { + Array(new SpecialBlockLocation(Array("dummy"), Array("dummy"), 0L, file.getLen)) + } +} From 756a3ab18c2032f3e84e5591725ec713e3e22726 Mon Sep 17 00:00:00 2001 From: Shahid Date: Sat, 6 Oct 2018 17:12:41 -0500 Subject: [PATCH 1784/2461] [SPARK-25575][WEBUI][FOLLOWUP] SQL tab in the spark UI support hide tables ## What changes were proposed in this pull request? After the PR, https://github.com/apache/spark/pull/22592, SQL tab supports collapsing table. However, after refreshing the page, it doesn't store it previous state. This was due to a typo in the argument list in the collapseTablePageLoadCommand(). ## How was this patch tested? bin/spark-shell ``` sql("create table a (id int)") for(i <- 1 to 100) sql(s"insert into a values ($i)") ``` ![screenshot from 2018-10-06 10-19-30](https://user-images.githubusercontent.com/23054875/46567490-59bea380-c951-11e8-9484-9aa2ee84b816.png) Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22650 from shahidki31/SPARK-25575-followUp. Authored-by: Shahid Signed-off-by: Sean Owen --- core/src/main/resources/org/apache/spark/ui/static/webui.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index 12c056af9a51a..b1254e08fa504 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -83,7 +83,7 @@ $(function() { collapseTablePageLoad('collapse-aggregated-rdds','aggregated-rdds'); collapseTablePageLoad('collapse-aggregated-activeBatches','aggregated-activeBatches'); collapseTablePageLoad('collapse-aggregated-completedBatches','aggregated-completedBatches'); - collapseTablePageLoad('collapse-aggregated-runningExecutions','runningExecutions'); - collapseTablePageLoad('collapse-aggregated-completedExecutions','completedExecutions'); - collapseTablePageLoad('collapse-aggregated-failedExecutions','failedExecutions'); + collapseTablePageLoad('collapse-aggregated-runningExecutions','aggregated-runningExecutions'); + collapseTablePageLoad('collapse-aggregated-completedExecutions','aggregated-completedExecutions'); + collapseTablePageLoad('collapse-aggregated-failedExecutions','aggregated-failedExecutions'); }); \ No newline at end of file From 8bb242902760535d12c6c40c5d8481a98fdc11e0 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 6 Oct 2018 15:49:41 -0700 Subject: [PATCH 1785/2461] [SPARK-25671] Build external/spark-ganglia-lgpl in Jenkins Test ## What changes were proposed in this pull request? Currently, we do not build external/spark-ganglia-lgpl in Jenkins tests when the code is changed. ## How was this patch tested? N/A Closes #22658 from gatorsmile/buildGanglia. Authored-by: gatorsmile Signed-off-by: gatorsmile --- dev/sparktestsupport/modules.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index e267fbfa623b5..e7ac063e234e3 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -557,6 +557,16 @@ def __hash__(self): sbt_test_goals=["kubernetes/test"] ) + +spark_ganglia_lgpl = Module( + name="spark-ganglia-lgpl", + dependencies=[], + build_profile_flags=["-Pspark-ganglia-lgpl"], + source_file_regexes=[ + "external/spark-ganglia-lgpl", + ] +) + # The root module is a dummy module which is used to run all of the tests. # No other modules should directly depend on this module. root = Module( From fba722e319e356113a69c54f59e23150017634ae Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 7 Oct 2018 09:51:33 -0500 Subject: [PATCH 1786/2461] [SPARK-25539][BUILD] Upgrade lz4-java to 1.5.0 get speed improvement ## What changes were proposed in this pull request? This PR upgrade `lz4-java` to 1.5.0 get speed improvement. **General speed improvements** LZ4 decompression speed has always been a strong point. In v1.8.2, this gets even better, as it improves decompression speed by about 10%, thanks in a large part to suggestion from svpv . For example, on a Mac OS-X laptop with an Intel Core i7-5557U CPU 3.10GHz, running lz4 -bsilesia.tar compiled with default compiler llvm v9.1.0: Version | v1.8.1 | v1.8.2 | Improvement -- | -- | -- | -- Decompression speed | 2490 MB/s | 2770 MB/s | +11% Compression speeds also receive a welcomed boost, though improvement is not evenly distributed, with higher levels benefiting quite a lot more. Version | v1.8.1 | v1.8.2 | Improvement -- | -- | -- | -- lz4 -1 | 504 MB/s | 516 MB/s | +2% lz4 -9 | 23.2 MB/s | 25.6 MB/s | +10% lz4 -12 | 3.5 Mb/s | 9.5 MB/s | +170% More details: https://github.com/lz4/lz4/releases/tag/v1.8.3 **Below is my benchmark result** set `spark.sql.parquet.compression.codec` to `lz4` and disable orc benchmark, then run `FilterPushdownBenchmark`. lz4-java 1.5.0: ``` [success] Total time: 5585 s, completed Sep 26, 2018 5:22:16 PM ``` lz4-java 1.4.0: ``` [success] Total time: 5591 s, completed Sep 26, 2018 5:22:24 PM ``` Some benchmark result: ``` lz4-java 1.5.0 Select 1 row with 500 filters: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Parquet Vectorized 1953 / 1980 0.0 1952502908.0 1.0X Parquet Vectorized (Pushdown) 2541 / 2585 0.0 2541019869.0 0.8X lz4-java 1.4.0 Select 1 row with 500 filters: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Parquet Vectorized 1979 / 2103 0.0 1979328144.0 1.0X Parquet Vectorized (Pushdown) 2596 / 2909 0.0 2596222118.0 0.8X ``` Complete benchmark result: https://issues.apache.org/jira/secure/attachment/12941360/FilterPushdownBenchmark-lz4-java-140-results.txt https://issues.apache.org/jira/secure/attachment/12941361/FilterPushdownBenchmark-lz4-java-150-results.txt ## How was this patch tested? manual tests Closes #22551 from wangyum/SPARK-25539. Authored-by: Yuming Wang Signed-off-by: Sean Owen --- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- dev/deps/spark-deps-hadoop-3.1 | 2 +- pom.xml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 22e86ef6c43b3..e0e3e0a82e730 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -138,7 +138,7 @@ libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar logging-interceptor-3.8.1.jar -lz4-java-1.4.0.jar +lz4-java-1.5.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar mesos-1.4.0-shaded-protobuf.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 19dd786c63e48..3b17f88a82c14 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -139,7 +139,7 @@ libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar logging-interceptor-3.8.1.jar -lz4-java-1.4.0.jar +lz4-java-1.5.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar mesos-1.4.0-shaded-protobuf.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index ea0f487a193eb..c818b2c39f748 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -154,7 +154,7 @@ libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar logging-interceptor-3.8.1.jar -lz4-java-1.4.0.jar +lz4-java-1.5.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar mesos-1.4.0-shaded-protobuf.jar diff --git a/pom.xml b/pom.xml index 79af5d69c66c5..98da38f045536 100644 --- a/pom.xml +++ b/pom.xml @@ -540,7 +540,7 @@ org.lz4 lz4-java - 1.4.0 + 1.5.0 com.github.luben From 3eb842969906d6e81a137af6dc4339881df0a315 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 7 Oct 2018 23:18:46 +0800 Subject: [PATCH 1787/2461] [SPARK-25461][PYSPARK][SQL] Add document for mismatch between return type of Pandas.Series and return type of pandas udf ## What changes were proposed in this pull request? For Pandas UDFs, we get arrow type from defined Catalyst return data type of UDFs. We use this arrow type to do serialization of data. If the defined return data type doesn't match with actual return type of Pandas.Series returned by Pandas UDFs, it has a risk to return incorrect data from Python side. Currently we don't have reliable approach to check if the data conversion is safe or not. We leave some document to notify this to users for now. When there is next upgrade of PyArrow available we can use to check it, we should add the option to check it. ## How was this patch tested? Only document change. Closes #22610 from viirya/SPARK-25461. Authored-by: Liang-Chi Hsieh Signed-off-by: hyukjinkwon --- python/pyspark/sql/functions.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 7685264b2d4d1..be089eea0b280 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2948,6 +2948,12 @@ def pandas_udf(f=None, returnType=None, functionType=None): can fail on special rows, the workaround is to incorporate the condition into the functions. .. note:: The user-defined functions do not take keyword arguments on the calling side. + + .. note:: The data type of returned `pandas.Series` from the user-defined functions should be + matched with defined returnType (see :meth:`types.to_arrow_type` and + :meth:`types.from_arrow_type`). When there is mismatch between them, Spark might do + conversion on returned data. The conversion is not guaranteed to be correct and results + should be checked for accuracy by users. """ # decorator @pandas_udf(returnType, functionType) is_decorator = f is None or isinstance(f, (str, DataType)) From b1328cc58ebb73bc191de5546735cffe0c68255e Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 7 Oct 2018 09:44:01 -0700 Subject: [PATCH 1788/2461] [SPARK-25658][SQL][TEST] Refactor HashByteArrayBenchmark to use main method ## What changes were proposed in this pull request? Refactor `HashByteArrayBenchmark` to use main method. 1. use `spark-submit`: ```console bin/spark-submit --class org.apache.spark.sql.HashByteArrayBenchmark --jars ./core/target/spark-core_2.11-3.0.0-SNAPSHOT-tests.jar ./sql/catalyst/target/spark-catalyst_2.11-3.0.0-SNAPSHOT-tests.jar ``` 2. Generate benchmark result: ```console SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "catalyst/test:runMain org.apache.spark.sql.HashByteArrayBenchmark" ``` ## How was this patch tested? manual tests Closes #22652 from wangyum/SPARK-25658. Lead-authored-by: Yuming Wang Co-authored-by: Yuming Wang Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../HashByteArrayBenchmark-results.txt | 77 +++++++++++ .../spark/sql/HashByteArrayBenchmark.scala | 120 ++++-------------- 2 files changed, 102 insertions(+), 95 deletions(-) create mode 100644 sql/catalyst/benchmarks/HashByteArrayBenchmark-results.txt diff --git a/sql/catalyst/benchmarks/HashByteArrayBenchmark-results.txt b/sql/catalyst/benchmarks/HashByteArrayBenchmark-results.txt new file mode 100644 index 0000000000000..a4304ee3b5f60 --- /dev/null +++ b/sql/catalyst/benchmarks/HashByteArrayBenchmark-results.txt @@ -0,0 +1,77 @@ +================================================================================================ +Benchmark for MurMurHash 3 and xxHash64 +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Hash byte arrays with length 8: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Murmur3_x86_32 16 / 16 127.7 7.8 1.0X +xxHash 64-bit 23 / 23 90.7 11.0 0.7X +HiveHasher 16 / 16 134.8 7.4 1.1X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Hash byte arrays with length 16: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Murmur3_x86_32 26 / 26 79.5 12.6 1.0X +xxHash 64-bit 26 / 27 79.3 12.6 1.0X +HiveHasher 30 / 30 70.1 14.3 0.9X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Hash byte arrays with length 24: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Murmur3_x86_32 36 / 36 58.1 17.2 1.0X +xxHash 64-bit 30 / 30 70.2 14.2 1.2X +HiveHasher 45 / 45 46.4 21.5 0.8X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Hash byte arrays with length 31: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Murmur3_x86_32 50 / 50 41.8 23.9 1.0X +xxHash 64-bit 43 / 43 49.3 20.3 1.2X +HiveHasher 58 / 58 35.9 27.8 0.9X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Hash byte arrays with length 95: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Murmur3_x86_32 132 / 132 15.9 62.7 1.0X +xxHash 64-bit 79 / 79 26.7 37.5 1.7X +HiveHasher 198 / 199 10.6 94.6 0.7X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Hash byte arrays with length 287: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Murmur3_x86_32 334 / 334 6.3 159.3 1.0X +xxHash 64-bit 126 / 126 16.7 59.9 2.7X +HiveHasher 633 / 634 3.3 302.0 0.5X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Hash byte arrays with length 1055: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Murmur3_x86_32 1149 / 1149 1.8 547.9 1.0X +xxHash 64-bit 327 / 327 6.4 155.9 3.5X +HiveHasher 2338 / 2346 0.9 1114.6 0.5X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Hash byte arrays with length 2079: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Murmur3_x86_32 2215 / 2216 0.9 1056.1 1.0X +xxHash 64-bit 554 / 554 3.8 264.0 4.0X +HiveHasher 4609 / 4609 0.5 2197.5 0.5X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Hash byte arrays with length 8223: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Murmur3_x86_32 8633 / 8643 0.2 4116.3 1.0X +xxHash 64-bit 1891 / 1892 1.1 901.6 4.6X +HiveHasher 18206 / 18206 0.1 8681.3 0.5X + + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala index a60eb20d9edef..7dc865d85af04 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala @@ -19,15 +19,24 @@ package org.apache.spark.sql import java.util.Random -import org.apache.spark.benchmark.Benchmark +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.sql.catalyst.expressions.{HiveHasher, XXH64} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 /** * Synthetic benchmark for MurMurHash 3 and xxHash64. + * To run this benchmark: + * {{{ + * 1. without sbt: + * bin/spark-submit --class --jars + * 2. build/sbt "catalyst/test:runMain " + * 3. generate result: + * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "catalyst/test:runMain " + * Results will be written to "benchmarks/HashByteArrayBenchmark-results.txt". + * }}} */ -object HashByteArrayBenchmark { +object HashByteArrayBenchmark extends BenchmarkBase { def test(length: Int, seed: Long, numArrays: Int, iters: Int): Unit = { val random = new Random(seed) val arrays = Array.fill[Array[Byte]](numArrays) { @@ -36,8 +45,8 @@ object HashByteArrayBenchmark { bytes } - val benchmark = - new Benchmark("Hash byte arrays with length " + length, iters * numArrays.toLong) + val benchmark = new Benchmark( + "Hash byte arrays with length " + length, iters * numArrays.toLong, output = output) benchmark.addCase("Murmur3_x86_32") { _: Int => var sum = 0L for (_ <- 0L until iters) { @@ -74,96 +83,17 @@ object HashByteArrayBenchmark { benchmark.run() } - def main(args: Array[String]): Unit = { - /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Hash byte arrays with length 8: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Murmur3_x86_32 12 / 16 174.3 5.7 1.0X - xxHash 64-bit 17 / 22 120.0 8.3 0.7X - HiveHasher 13 / 15 162.1 6.2 0.9X - */ - test(8, 42L, 1 << 10, 1 << 11) - - /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Hash byte arrays with length 16: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Murmur3_x86_32 19 / 22 107.6 9.3 1.0X - xxHash 64-bit 20 / 24 104.6 9.6 1.0X - HiveHasher 24 / 28 87.0 11.5 0.8X - */ - test(16, 42L, 1 << 10, 1 << 11) - - /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Hash byte arrays with length 24: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Murmur3_x86_32 28 / 32 74.8 13.4 1.0X - xxHash 64-bit 24 / 29 87.3 11.5 1.2X - HiveHasher 36 / 41 57.7 17.3 0.8X - */ - test(24, 42L, 1 << 10, 1 << 11) - - // Add 31 to all arrays to create worse case alignment for xxHash. - /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Hash byte arrays with length 31: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Murmur3_x86_32 41 / 45 51.1 19.6 1.0X - xxHash 64-bit 36 / 44 58.8 17.0 1.2X - HiveHasher 49 / 54 42.6 23.5 0.8X - */ - test(31, 42L, 1 << 10, 1 << 11) - - /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Hash byte arrays with length 95: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Murmur3_x86_32 100 / 110 21.0 47.7 1.0X - xxHash 64-bit 74 / 78 28.2 35.5 1.3X - HiveHasher 189 / 196 11.1 90.3 0.5X - */ - test(64 + 31, 42L, 1 << 10, 1 << 11) - - /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Hash byte arrays with length 287: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Murmur3_x86_32 299 / 311 7.0 142.4 1.0X - xxHash 64-bit 113 / 122 18.5 54.1 2.6X - HiveHasher 620 / 624 3.4 295.5 0.5X - */ - test(256 + 31, 42L, 1 << 10, 1 << 11) - - /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Hash byte arrays with length 1055: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Murmur3_x86_32 1068 / 1070 2.0 509.1 1.0X - xxHash 64-bit 306 / 315 6.9 145.9 3.5X - HiveHasher 2316 / 2369 0.9 1104.3 0.5X - */ - test(1024 + 31, 42L, 1 << 10, 1 << 11) - - /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Hash byte arrays with length 2079: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Murmur3_x86_32 2252 / 2274 0.9 1074.1 1.0X - xxHash 64-bit 534 / 580 3.9 254.6 4.2X - HiveHasher 4739 / 4786 0.4 2259.8 0.5X - */ - test(2048 + 31, 42L, 1 << 10, 1 << 11) - - /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Hash byte arrays with length 8223: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Murmur3_x86_32 9249 / 9586 0.2 4410.5 1.0X - xxHash 64-bit 2897 / 3241 0.7 1381.6 3.2X - HiveHasher 19392 / 20211 0.1 9246.6 0.5X - */ - test(8192 + 31, 42L, 1 << 10, 1 << 11) + override def runBenchmarkSuite(): Unit = { + runBenchmark("Benchmark for MurMurHash 3 and xxHash64") { + test(8, 42L, 1 << 10, 1 << 11) + test(16, 42L, 1 << 10, 1 << 11) + test(24, 42L, 1 << 10, 1 << 11) + test(31, 42L, 1 << 10, 1 << 11) + test(64 + 31, 42L, 1 << 10, 1 << 11) + test(256 + 31, 42L, 1 << 10, 1 << 11) + test(1024 + 31, 42L, 1 << 10, 1 << 11) + test(2048 + 31, 42L, 1 << 10, 1 << 11) + test(8192 + 31, 42L, 1 << 10, 1 << 11) + } } } From 669ade3a8eed0016b5ece57d776cea0616417088 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 7 Oct 2018 09:49:37 -0700 Subject: [PATCH 1789/2461] [SPARK-25657][SQL][TEST] Refactor HashBenchmark to use main method ## What changes were proposed in this pull request? Refactor `HashBenchmark` to use main method. 1. use `spark-submit`: ```console bin/spark-submit --class org.apache.spark.sql.HashBenchmark --jars ./core/target/spark-core_2.11-3.0.0-SNAPSHOT-tests.jar ./sql/catalyst/target/spark-catalyst_2.11-3.0.0-SNAPSHOT-tests.jar ``` 2. Generate benchmark result: ```console SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "catalyst/test:runMain org.apache.spark.sql.HashBenchmark" ``` ## How was this patch tested? manual tests Closes #22651 from wangyum/SPARK-25657. Lead-authored-by: Yuming Wang Co-authored-by: Yuming Wang Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../benchmarks/HashBenchmark-results.txt | 70 ++++++++ .../org/apache/spark/sql/HashBenchmark.scala | 152 +++++++----------- 2 files changed, 129 insertions(+), 93 deletions(-) create mode 100644 sql/catalyst/benchmarks/HashBenchmark-results.txt diff --git a/sql/catalyst/benchmarks/HashBenchmark-results.txt b/sql/catalyst/benchmarks/HashBenchmark-results.txt new file mode 100644 index 0000000000000..2459b35c75bb5 --- /dev/null +++ b/sql/catalyst/benchmarks/HashBenchmark-results.txt @@ -0,0 +1,70 @@ +================================================================================================ +single ints +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Hash For single ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +interpreted version 5615 / 5616 95.6 10.5 1.0X +codegen version 8400 / 8407 63.9 15.6 0.7X +codegen version 64-bit 8139 / 8145 66.0 15.2 0.7X +codegen HiveHash version 7213 / 7348 74.4 13.4 0.8X + + +================================================================================================ +single longs +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Hash For single longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +interpreted version 6053 / 6054 88.7 11.3 1.0X +codegen version 9367 / 9369 57.3 17.4 0.6X +codegen version 64-bit 8041 / 8051 66.8 15.0 0.8X +codegen HiveHash version 7546 / 7575 71.1 14.1 0.8X + + +================================================================================================ +normal +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Hash For normal: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +interpreted version 3181 / 3182 0.7 1517.0 1.0X +codegen version 2403 / 2403 0.9 1145.7 1.3X +codegen version 64-bit 915 / 916 2.3 436.2 3.5X +codegen HiveHash version 4505 / 4527 0.5 2148.3 0.7X + + +================================================================================================ +array +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Hash For array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +interpreted version 1828 / 1844 0.1 13946.1 1.0X +codegen version 3678 / 3804 0.0 28058.2 0.5X +codegen version 64-bit 2925 / 2931 0.0 22317.8 0.6X +codegen HiveHash version 1216 / 1217 0.1 9280.0 1.5X + + +================================================================================================ +map +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Hash For map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +interpreted version 0 / 0 44.3 22.6 1.0X +codegen version 176 / 176 0.0 42978.8 0.0X +codegen version 64-bit 173 / 175 0.0 42214.3 0.0X +codegen HiveHash version 44 / 44 0.1 10659.9 0.0X + + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala index 7a2a66c9b1d33..4226ab3773fe7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.benchmark.Benchmark +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection @@ -26,94 +26,87 @@ import org.apache.spark.sql.types._ /** * Benchmark for the previous interpreted hash function(InternalRow.hashCode) vs codegened * hash expressions (Murmur3Hash/xxHash64). + * To run this benchmark: + * {{{ + * 1. without sbt: + * bin/spark-submit --class --jars + * 2. build/sbt "catalyst/test:runMain " + * 3. generate result: + * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "catalyst/test:runMain " + * Results will be written to "benchmarks/HashBenchmark-results.txt". + * }}} */ -object HashBenchmark { +object HashBenchmark extends BenchmarkBase { def test(name: String, schema: StructType, numRows: Int, iters: Int): Unit = { - val generator = RandomDataGenerator.forType(schema, nullable = false).get - val encoder = RowEncoder(schema) - val attrs = schema.toAttributes - val safeProjection = GenerateSafeProjection.generate(attrs, attrs) + runBenchmark(name) { + val generator = RandomDataGenerator.forType(schema, nullable = false).get + val encoder = RowEncoder(schema) + val attrs = schema.toAttributes + val safeProjection = GenerateSafeProjection.generate(attrs, attrs) - val rows = (1 to numRows).map(_ => - // The output of encoder is UnsafeRow, use safeProjection to turn in into safe format. - safeProjection(encoder.toRow(generator().asInstanceOf[Row])).copy() - ).toArray + val rows = (1 to numRows).map(_ => + // The output of encoder is UnsafeRow, use safeProjection to turn in into safe format. + safeProjection(encoder.toRow(generator().asInstanceOf[Row])).copy() + ).toArray - val benchmark = new Benchmark("Hash For " + name, iters * numRows.toLong) - benchmark.addCase("interpreted version") { _: Int => - var sum = 0 - for (_ <- 0L until iters) { - var i = 0 - while (i < numRows) { - sum += rows(i).hashCode() - i += 1 + val benchmark = new Benchmark("Hash For " + name, iters * numRows.toLong, output = output) + benchmark.addCase("interpreted version") { _: Int => + var sum = 0 + for (_ <- 0L until iters) { + var i = 0 + while (i < numRows) { + sum += rows(i).hashCode() + i += 1 + } } } - } - val getHashCode = UnsafeProjection.create(new Murmur3Hash(attrs) :: Nil, attrs) - benchmark.addCase("codegen version") { _: Int => - var sum = 0 - for (_ <- 0L until iters) { - var i = 0 - while (i < numRows) { - sum += getHashCode(rows(i)).getInt(0) - i += 1 + val getHashCode = UnsafeProjection.create(new Murmur3Hash(attrs) :: Nil, attrs) + benchmark.addCase("codegen version") { _: Int => + var sum = 0 + for (_ <- 0L until iters) { + var i = 0 + while (i < numRows) { + sum += getHashCode(rows(i)).getInt(0) + i += 1 + } } } - } - val getHashCode64b = UnsafeProjection.create(new XxHash64(attrs) :: Nil, attrs) - benchmark.addCase("codegen version 64-bit") { _: Int => - var sum = 0 - for (_ <- 0L until iters) { - var i = 0 - while (i < numRows) { - sum += getHashCode64b(rows(i)).getInt(0) - i += 1 + val getHashCode64b = UnsafeProjection.create(new XxHash64(attrs) :: Nil, attrs) + benchmark.addCase("codegen version 64-bit") { _: Int => + var sum = 0 + for (_ <- 0L until iters) { + var i = 0 + while (i < numRows) { + sum += getHashCode64b(rows(i)).getInt(0) + i += 1 + } } } - } - val getHiveHashCode = UnsafeProjection.create(new HiveHash(attrs) :: Nil, attrs) - benchmark.addCase("codegen HiveHash version") { _: Int => - var sum = 0 - for (_ <- 0L until iters) { - var i = 0 - while (i < numRows) { - sum += getHiveHashCode(rows(i)).getInt(0) - i += 1 + val getHiveHashCode = UnsafeProjection.create(new HiveHash(attrs) :: Nil, attrs) + benchmark.addCase("codegen HiveHash version") { _: Int => + var sum = 0 + for (_ <- 0L until iters) { + var i = 0 + while (i < numRows) { + sum += getHiveHashCode(rows(i)).getInt(0) + i += 1 + } } } - } - benchmark.run() + benchmark.run() + } } - def main(args: Array[String]): Unit = { + override def runBenchmarkSuite(): Unit = { val singleInt = new StructType().add("i", IntegerType) - /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Hash For single ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - interpreted version 3262 / 3267 164.6 6.1 1.0X - codegen version 6448 / 6718 83.3 12.0 0.5X - codegen version 64-bit 6088 / 6154 88.2 11.3 0.5X - codegen HiveHash version 4732 / 4745 113.5 8.8 0.7X - */ test("single ints", singleInt, 1 << 15, 1 << 14) val singleLong = new StructType().add("i", LongType) - /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Hash For single longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - interpreted version 3716 / 3726 144.5 6.9 1.0X - codegen version 7706 / 7732 69.7 14.4 0.5X - codegen version 64-bit 6370 / 6399 84.3 11.9 0.6X - codegen HiveHash version 4924 / 5026 109.0 9.2 0.8X - */ test("single longs", singleLong, 1 << 15, 1 << 14) val normal = new StructType() @@ -131,45 +124,18 @@ object HashBenchmark { .add("binary", BinaryType) .add("date", DateType) .add("timestamp", TimestampType) - /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Hash For normal: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - interpreted version 2985 / 3013 0.7 1423.4 1.0X - codegen version 2422 / 2434 0.9 1155.1 1.2X - codegen version 64-bit 856 / 920 2.5 408.0 3.5X - codegen HiveHash version 4501 / 4979 0.5 2146.4 0.7X - */ test("normal", normal, 1 << 10, 1 << 11) val arrayOfInt = ArrayType(IntegerType) val array = new StructType() .add("array", arrayOfInt) .add("arrayOfArray", ArrayType(arrayOfInt)) - /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Hash For array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - interpreted version 3100 / 3555 0.0 23651.8 1.0X - codegen version 5779 / 5865 0.0 44088.4 0.5X - codegen version 64-bit 4738 / 4821 0.0 36151.7 0.7X - codegen HiveHash version 2200 / 2246 0.1 16785.9 1.4X - */ test("array", array, 1 << 8, 1 << 9) val mapOfInt = MapType(IntegerType, IntegerType) val map = new StructType() .add("map", mapOfInt) .add("mapOfMap", MapType(IntegerType, mapOfInt)) - /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Hash For map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - interpreted version 0 / 0 48.1 20.8 1.0X - codegen version 257 / 275 0.0 62768.7 0.0X - codegen version 64-bit 226 / 240 0.0 55224.5 0.0X - codegen HiveHash version 89 / 96 0.0 21708.8 0.0X - */ test("map", map, 1 << 6, 1 << 6) } } From ebd899b8a865395e6f1137163cb508086696879b Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Sun, 7 Oct 2018 10:06:44 -0700 Subject: [PATCH 1790/2461] [SPARK-25321][ML] Revert SPARK-14681 to avoid API breaking change ## What changes were proposed in this pull request? This is the same as #22492 but for master branch. Revert SPARK-14681 to avoid API breaking changes. cc: WeichenXu123 ## How was this patch tested? Existing unit tests. Closes #22618 from mengxr/SPARK-25321.master. Authored-by: WeichenXu Signed-off-by: Dongjoon Hyun --- .../DecisionTreeClassifier.scala | 14 +- .../ml/classification/GBTClassifier.scala | 6 +- .../RandomForestClassifier.scala | 6 +- .../ml/regression/DecisionTreeRegressor.scala | 13 +- .../spark/ml/regression/GBTRegressor.scala | 6 +- .../ml/regression/RandomForestRegressor.scala | 6 +- .../scala/org/apache/spark/ml/tree/Node.scala | 247 ++++-------------- .../spark/ml/tree/impl/RandomForest.scala | 10 +- .../org/apache/spark/ml/tree/treeModels.scala | 36 +-- .../DecisionTreeClassifierSuite.scala | 31 +-- .../classification/GBTClassifierSuite.scala | 4 +- .../RandomForestClassifierSuite.scala | 5 +- .../DecisionTreeRegressorSuite.scala | 14 - .../ml/tree/impl/RandomForestSuite.scala | 22 +- .../apache/spark/ml/tree/impl/TreeTests.scala | 12 +- project/MimaExcludes.scala | 7 - 16 files changed, 107 insertions(+), 332 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 8a57bfc029d14..6648e78d8eafa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -168,7 +168,7 @@ object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifi @Since("1.4.0") class DecisionTreeClassificationModel private[ml] ( @Since("1.4.0")override val uid: String, - @Since("1.4.0")override val rootNode: ClassificationNode, + @Since("1.4.0")override val rootNode: Node, @Since("1.6.0")override val numFeatures: Int, @Since("1.5.0")override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel] @@ -181,7 +181,7 @@ class DecisionTreeClassificationModel private[ml] ( * Construct a decision tree classification model. * @param rootNode Root node of tree, with other nodes attached. */ - private[ml] def this(rootNode: ClassificationNode, numFeatures: Int, numClasses: Int) = + private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) = this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses) override def predict(features: Vector): Double = { @@ -279,9 +279,8 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] - val root = loadTreeNodes(path, metadata, sparkSession, isClassification = true) - val model = new DecisionTreeClassificationModel(metadata.uid, - root.asInstanceOf[ClassificationNode], numFeatures, numClasses) + val root = loadTreeNodes(path, metadata, sparkSession) + val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses) metadata.getAndSetParams(model) model } @@ -296,10 +295,9 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica require(oldModel.algo == OldAlgo.Classification, s"Cannot convert non-classification DecisionTreeModel (old API) to" + s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}") - val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = true) + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc") // Can't infer number of features from old model, so default to -1 - new DecisionTreeClassificationModel(uid, - rootNode.asInstanceOf[ClassificationNode], numFeatures, -1) + new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 33acd9914073f..62cfa39746ff0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -412,14 +412,14 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { override def load(path: String): GBTClassificationModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int] val numTrees = (metadata.metadata \ numTreesKey).extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => - val tree = new DecisionTreeRegressionModel(treeMetadata.uid, - root.asInstanceOf[RegressionNode], numFeatures) + val tree = + new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) treeMetadata.getAndSetParams(tree) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 94887ac346fec..57132381b6474 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -313,15 +313,15 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica override def load(path: String): RandomForestClassificationModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, true) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] val trees: Array[DecisionTreeClassificationModel] = treesData.map { case (treeMetadata, root) => - val tree = new DecisionTreeClassificationModel(treeMetadata.uid, - root.asInstanceOf[ClassificationNode], numFeatures, numClasses) + val tree = + new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses) treeMetadata.getAndSetParams(tree) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 018290f81842f..6fa656275c1fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -160,7 +160,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor @Since("1.4.0") class DecisionTreeRegressionModel private[ml] ( override val uid: String, - override val rootNode: RegressionNode, + override val rootNode: Node, override val numFeatures: Int) extends PredictionModel[Vector, DecisionTreeRegressionModel] with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with Serializable { @@ -175,7 +175,7 @@ class DecisionTreeRegressionModel private[ml] ( * Construct a decision tree regression model. * @param rootNode Root node of tree, with other nodes attached. */ - private[ml] def this(rootNode: RegressionNode, numFeatures: Int) = + private[ml] def this(rootNode: Node, numFeatures: Int) = this(Identifiable.randomUID("dtr"), rootNode, numFeatures) override def predict(features: Vector): Double = { @@ -279,9 +279,8 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] - val root = loadTreeNodes(path, metadata, sparkSession, isClassification = false) - val model = new DecisionTreeRegressionModel(metadata.uid, - root.asInstanceOf[RegressionNode], numFeatures) + val root = loadTreeNodes(path, metadata, sparkSession) + val model = new DecisionTreeRegressionModel(metadata.uid, root, numFeatures) metadata.getAndSetParams(model) model } @@ -296,8 +295,8 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode require(oldModel.algo == OldAlgo.Regression, s"Cannot convert non-regression DecisionTreeModel (old API) to" + s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}") - val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = false) + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr") - new DecisionTreeRegressionModel(uid, rootNode.asInstanceOf[RegressionNode], numFeatures) + new DecisionTreeRegressionModel(uid, rootNode, numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 3305881b0ccc6..07f88d8d5f84d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -338,15 +338,15 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { override def load(path: String): GBTRegressionModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => - val tree = new DecisionTreeRegressionModel(treeMetadata.uid, - root.asInstanceOf[RegressionNode], numFeatures) + val tree = + new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) treeMetadata.getAndSetParams(tree) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 35875724b3cfa..82bf66ff66d8a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -271,13 +271,13 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode override def load(path: String): RandomForestRegressionModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => - val tree = new DecisionTreeRegressionModel(treeMetadata.uid, - root.asInstanceOf[RegressionNode], numFeatures) + val tree = + new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) treeMetadata.getAndSetParams(tree) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index 0242bc76698d0..d30be452a436e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -17,16 +17,14 @@ package org.apache.spark.ml.tree -import org.apache.spark.annotation.Since import org.apache.spark.ml.linalg.Vector import org.apache.spark.mllib.tree.impurity.ImpurityCalculator -import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, - Node => OldNode, Predict => OldPredict} +import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} /** * Decision tree node interface. */ -sealed trait Node extends Serializable { +sealed abstract class Node extends Serializable { // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree // code into the new API and deprecate the old API. SPARK-3727 @@ -86,86 +84,35 @@ private[ml] object Node { /** * Create a new Node from the old Node format, recursively creating child nodes as needed. */ - def fromOld( - oldNode: OldNode, - categoricalFeatures: Map[Int, Int], - isClassification: Boolean): Node = { + def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = { if (oldNode.isLeaf) { // TODO: Once the implementation has been moved to this API, then include sufficient // statistics here. - if (isClassification) { - new ClassificationLeafNode(prediction = oldNode.predict.predict, - impurity = oldNode.impurity, impurityStats = null) - } else { - new RegressionLeafNode(prediction = oldNode.predict.predict, - impurity = oldNode.impurity, impurityStats = null) - } + new LeafNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, impurityStats = null) } else { val gain = if (oldNode.stats.nonEmpty) { oldNode.stats.get.gain } else { 0.0 } - if (isClassification) { - new ClassificationInternalNode(prediction = oldNode.predict.predict, - impurity = oldNode.impurity, gain = gain, - leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, true) - .asInstanceOf[ClassificationNode], - rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, true) - .asInstanceOf[ClassificationNode], - split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) - } else { - new RegressionInternalNode(prediction = oldNode.predict.predict, - impurity = oldNode.impurity, gain = gain, - leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, false) - .asInstanceOf[RegressionNode], - rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, false) - .asInstanceOf[RegressionNode], - split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) - } + new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity, + gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures), + rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures), + split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) } } } -@Since("2.4.0") -sealed trait ClassificationNode extends Node { - - /** - * Get count of training examples for specified label in this node - * @param label label number in the range [0, numClasses) - */ - @Since("2.4.0") - def getLabelCount(label: Int): Double = { - require(label >= 0 && label < impurityStats.stats.length, - "label should be in the range between 0 (inclusive) " + - s"and ${impurityStats.stats.length} (exclusive).") - impurityStats.stats(label) - } -} - -@Since("2.4.0") -sealed trait RegressionNode extends Node { - - /** Number of training data points in this node */ - @Since("2.4.0") - def getCount: Double = impurityStats.stats(0) - - /** Sum over training data points of the labels in this node */ - @Since("2.4.0") - def getSum: Double = impurityStats.stats(1) - - /** Sum over training data points of the square of the labels in this node */ - @Since("2.4.0") - def getSumOfSquares: Double = impurityStats.stats(2) -} - -@Since("2.4.0") -sealed trait LeafNode extends Node { - - /** Prediction this node makes. */ - def prediction: Double - - def impurity: Double +/** + * Decision tree leaf node. + * @param prediction Prediction this node makes + * @param impurity Impurity measure at this node (for training data) + */ +class LeafNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + override private[ml] val impurityStats: ImpurityCalculator) extends Node { override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)" @@ -188,58 +135,32 @@ sealed trait LeafNode extends Node { override private[ml] def maxSplitFeatureIndex(): Int = -1 -} - -/** - * Decision tree leaf node for classification. - */ -@Since("2.4.0") -class ClassificationLeafNode private[ml] ( - override val prediction: Double, - override val impurity: Double, - override private[ml] val impurityStats: ImpurityCalculator) - extends ClassificationNode with LeafNode { - override private[tree] def deepCopy(): Node = { - new ClassificationLeafNode(prediction, impurity, impurityStats) + new LeafNode(prediction, impurity, impurityStats) } } /** - * Decision tree leaf node for regression. + * Internal Decision Tree node. + * @param prediction Prediction this node would make if it were a leaf node + * @param impurity Impurity measure at this node (for training data) + * @param gain Information gain value. Values less than 0 indicate missing values; + * this quirk will be removed with future updates. + * @param leftChild Left-hand child node + * @param rightChild Right-hand child node + * @param split Information about the test used to split to the left or right child. */ -@Since("2.4.0") -class RegressionLeafNode private[ml] ( +class InternalNode private[ml] ( override val prediction: Double, override val impurity: Double, - override private[ml] val impurityStats: ImpurityCalculator) - extends RegressionNode with LeafNode { - - override private[tree] def deepCopy(): Node = { - new RegressionLeafNode(prediction, impurity, impurityStats) - } -} - -/** - * Internal Decision Tree node. - */ -@Since("2.4.0") -sealed trait InternalNode extends Node { - - /** - * Information gain value. Values less than 0 indicate missing values; - * this quirk will be removed with future updates. - */ - def gain: Double - - /** Left-hand child node */ - def leftChild: Node - - /** Right-hand child node */ - def rightChild: Node + val gain: Double, + val leftChild: Node, + val rightChild: Node, + val split: Split, + override private[ml] val impurityStats: ImpurityCalculator) extends Node { - /** Information about the test used to split to the left or right child. */ - def split: Split + // Note to developers: The constructor argument impurityStats should be reconsidered before we + // make the constructor public. We may be able to improve the representation. override def toString: String = { s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" @@ -284,6 +205,11 @@ sealed trait InternalNode extends Node { math.max(split.featureIndex, math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex())) } + + override private[tree] def deepCopy(): Node = { + new InternalNode(prediction, impurity, gain, leftChild.deepCopy(), rightChild.deepCopy(), + split, impurityStats) + } } private object InternalNode { @@ -314,57 +240,6 @@ private object InternalNode { } } -/** - * Internal Decision Tree node for regression. - */ -@Since("2.4.0") -class ClassificationInternalNode private[ml] ( - override val prediction: Double, - override val impurity: Double, - override val gain: Double, - override val leftChild: ClassificationNode, - override val rightChild: ClassificationNode, - override val split: Split, - override private[ml] val impurityStats: ImpurityCalculator) - extends ClassificationNode with InternalNode { - - // Note to developers: The constructor argument impurityStats should be reconsidered before we - // make the constructor public. We may be able to improve the representation. - - override private[tree] def deepCopy(): Node = { - new ClassificationInternalNode(prediction, impurity, gain, - leftChild.deepCopy().asInstanceOf[ClassificationNode], - rightChild.deepCopy().asInstanceOf[ClassificationNode], - split, impurityStats) - } -} - -/** - * Internal Decision Tree node for regression. - */ -@Since("2.4.0") -class RegressionInternalNode private[ml] ( - override val prediction: Double, - override val impurity: Double, - override val gain: Double, - override val leftChild: RegressionNode, - override val rightChild: RegressionNode, - override val split: Split, - override private[ml] val impurityStats: ImpurityCalculator) - extends RegressionNode with InternalNode { - - // Note to developers: The constructor argument impurityStats should be reconsidered before we - // make the constructor public. We may be able to improve the representation. - - override private[tree] def deepCopy(): Node = { - new RegressionInternalNode(prediction, impurity, gain, - leftChild.deepCopy().asInstanceOf[RegressionNode], - rightChild.deepCopy().asInstanceOf[RegressionNode], - split, impurityStats) - } -} - - /** * Version of a node used in learning. This uses vars so that we can modify nodes as we split the * tree by adding children, etc. @@ -390,52 +265,30 @@ private[tree] class LearningNode( var isLeaf: Boolean, var stats: ImpurityStats) extends Serializable { - def toNode(isClassification: Boolean): Node = toNode(isClassification, prune = true) - - def toClassificationNode(prune: Boolean = true): ClassificationNode = { - toNode(true, prune).asInstanceOf[ClassificationNode] - } - - def toRegressionNode(prune: Boolean = true): RegressionNode = { - toNode(false, prune).asInstanceOf[RegressionNode] - } + def toNode: Node = toNode(prune = true) /** * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children. */ - def toNode(isClassification: Boolean, prune: Boolean): Node = { + def toNode(prune: Boolean = true): Node = { if (!leftChild.isEmpty || !rightChild.isEmpty) { assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null, "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.") - (leftChild.get.toNode(isClassification, prune), - rightChild.get.toNode(isClassification, prune)) match { + (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match { case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction => - if (isClassification) { - new ClassificationLeafNode(l.prediction, stats.impurity, stats.impurityCalculator) - } else { - new RegressionLeafNode(l.prediction, stats.impurity, stats.impurityCalculator) - } + new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator) case (l, r) => - if (isClassification) { - new ClassificationInternalNode(stats.impurityCalculator.predict, stats.impurity, - stats.gain, l.asInstanceOf[ClassificationNode], r.asInstanceOf[ClassificationNode], - split.get, stats.impurityCalculator) - } else { - new RegressionInternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, - l.asInstanceOf[RegressionNode], r.asInstanceOf[RegressionNode], - split.get, stats.impurityCalculator) - } + new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, + l, r, split.get, stats.impurityCalculator) } } else { - // Here we want to keep same behavior with the old mllib.DecisionTreeModel - val impurity = if (stats.valid) stats.impurity else -1.0 - if (isClassification) { - new ClassificationLeafNode(stats.impurityCalculator.predict, impurity, + if (stats.valid) { + new LeafNode(stats.impurityCalculator.predict, stats.impurity, stats.impurityCalculator) } else { - new RegressionLeafNode(stats.impurityCalculator.predict, impurity, - stats.impurityCalculator) + // Here we want to keep same behavior with the old mllib.DecisionTreeModel + new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 4cdd17266b771..822abd2d3522d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -226,23 +226,23 @@ private[spark] object RandomForest extends Logging with Serializable { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(uid, rootNode.toClassificationNode(prune), - numFeatures, strategy.getNumClasses) + new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures, + strategy.getNumClasses) } } else { topNodes.map { rootNode => - new DecisionTreeRegressionModel(uid, rootNode.toRegressionNode(prune), numFeatures) + new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures) } } case None => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(rootNode.toClassificationNode(prune), numFeatures, + new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures, strategy.getNumClasses) } } else { topNodes.map(rootNode => - new DecisionTreeRegressionModel(rootNode.toRegressionNode(prune), numFeatures)) + new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures)) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index f027b14f1d476..4aa4c3617e7fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -219,10 +219,8 @@ private[ml] object TreeEnsembleModel { importances.changeValue(feature, scaledGain, _ + scaledGain) computeFeatureImportance(n.leftChild, importances) computeFeatureImportance(n.rightChild, importances) - case _: LeafNode => + case n: LeafNode => // do nothing - case _ => - throw new IllegalArgumentException(s"Unknown node type: ${node.getClass.toString}") } } @@ -319,8 +317,6 @@ private[ml] object DecisionTreeModelReadWrite { (Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats, -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))), id) - case _ => - throw new IllegalArgumentException(s"Unknown node type: ${node.getClass.toString}") } } @@ -331,7 +327,7 @@ private[ml] object DecisionTreeModelReadWrite { def loadTreeNodes( path: String, metadata: DefaultParamsReader.Metadata, - sparkSession: SparkSession, isClassification: Boolean): Node = { + sparkSession: SparkSession): Node = { import sparkSession.implicits._ implicit val format = DefaultFormats @@ -343,7 +339,7 @@ private[ml] object DecisionTreeModelReadWrite { val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath).as[NodeData] - buildTreeFromNodes(data.collect(), impurityType, isClassification) + buildTreeFromNodes(data.collect(), impurityType) } /** @@ -352,8 +348,7 @@ private[ml] object DecisionTreeModelReadWrite { * @param impurityType Impurity type for this tree * @return Root node of reconstructed tree */ - def buildTreeFromNodes(data: Array[NodeData], impurityType: String, - isClassification: Boolean): Node = { + def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = { // Load all nodes, sorted by ID. val nodes = data.sortBy(_.id) // Sanity checks; could remove @@ -369,21 +364,10 @@ private[ml] object DecisionTreeModelReadWrite { val node = if (n.leftChild != -1) { val leftChild = finalNodes(n.leftChild) val rightChild = finalNodes(n.rightChild) - if (isClassification) { - new ClassificationInternalNode(n.prediction, n.impurity, n.gain, - leftChild.asInstanceOf[ClassificationNode], rightChild.asInstanceOf[ClassificationNode], - n.split.getSplit, impurityStats) - } else { - new RegressionInternalNode(n.prediction, n.impurity, n.gain, - leftChild.asInstanceOf[RegressionNode], rightChild.asInstanceOf[RegressionNode], - n.split.getSplit, impurityStats) - } + new InternalNode(n.prediction, n.impurity, n.gain, leftChild, rightChild, + n.split.getSplit, impurityStats) } else { - if (isClassification) { - new ClassificationLeafNode(n.prediction, n.impurity, impurityStats) - } else { - new RegressionLeafNode(n.prediction, n.impurity, impurityStats) - } + new LeafNode(n.prediction, n.impurity, impurityStats) } finalNodes(n.id) = node } @@ -437,8 +421,7 @@ private[ml] object EnsembleModelReadWrite { path: String, sql: SparkSession, className: String, - treeClassName: String, - isClassification: Boolean): (Metadata, Array[(Metadata, Node)], Array[Double]) = { + treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = { import sql.implicits._ implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className) @@ -466,8 +449,7 @@ private[ml] object EnsembleModelReadWrite { val rootNodesRDD: RDD[(Int, Node)] = nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map { case (treeID: Int, nodeData: Iterable[NodeData]) => - treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes( - nodeData.toArray, impurityType, isClassification) + treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType) } val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect() (metadata, treesMetadata.zip(rootNodes), treesWeights) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index d3dbb4e754d3d..2930f4900d50e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.ClassificationLeafNode +import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} @@ -61,8 +61,7 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new DecisionTreeClassifier) - val model = new DecisionTreeClassificationModel("dtc", - new ClassificationLeafNode(0.0, 0.0, null), 1, 2) + val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2) ParamsSuite.checkParams(model) } @@ -376,32 +375,6 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { testDefaultReadWrite(model) } - - test("label/impurity stats") { - val arr = Array( - LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), - LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), - LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) - val rdd = sc.parallelize(arr) - val df = TreeTests.setMetadata(rdd, Map.empty[Int, Int], 2) - val dt1 = new DecisionTreeClassifier() - .setImpurity("entropy") - .setMaxDepth(2) - .setMinInstancesPerNode(2) - val model1 = dt1.fit(df) - - val rootNode1 = model1.rootNode - assert(Array(rootNode1.getLabelCount(0), rootNode1.getLabelCount(1)) === Array(2.0, 1.0)) - - val dt2 = new DecisionTreeClassifier() - .setImpurity("gini") - .setMaxDepth(2) - .setMinInstancesPerNode(2) - val model2 = dt2.fit(df) - - val rootNode2 = model2.rootNode - assert(Array(rootNode2.getLabelCount(0), rootNode2.getLabelCount(1)) === Array(2.0, 1.0)) - } } private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index e6d2a8e2b900e..304977634189c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.tree.RegressionLeafNode +import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ @@ -70,7 +70,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new GBTClassifier) val model = new GBTClassificationModel("gbtc", - Array(new DecisionTreeRegressionModel("dtr", new RegressionLeafNode(0.0, 0.0, null), 1)), + Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null), 1)), Array(1.0), 1, 2) ParamsSuite.checkParams(model) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 3062aa9f3d274..ba4a9cf082785 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.ClassificationLeafNode +import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} @@ -71,8 +71,7 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new RandomForestClassifier) val model = new RandomForestClassificationModel("rfc", - Array(new DecisionTreeClassificationModel("dtc", - new ClassificationLeafNode(0.0, 0.0, null), 1, 2)), 2, 2) + Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)), 2, 2) ParamsSuite.checkParams(model) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 9ae27339b11d5..29a438396516b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -191,20 +191,6 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest { TreeTests.allParamSettings ++ Map("maxDepth" -> 0), TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData) } - - test("label/impurity stats") { - val categoricalFeatures = Map(0 -> 2, 1 -> 2) - val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) - val dtr = new DecisionTreeRegressor() - .setImpurity("variance") - .setMaxDepth(2) - .setMaxBins(8) - val model = dtr.fit(df) - val statInfo = model.rootNode - - assert(statInfo.getCount == 1000.0 && statInfo.getSum == 600.0 - && statInfo.getSumOfSquares == 600.0) - } } private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 4dbbd75d2466d..743dacf146fe7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -340,8 +340,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.stats.impurity > 0.0) // set impurity and predict for child nodes - assert(topNode.leftChild.get.toNode(isClassification = true).prediction === 0.0) - assert(topNode.rightChild.get.toNode(isClassification = true).prediction === 1.0) + assert(topNode.leftChild.get.toNode.prediction === 0.0) + assert(topNode.rightChild.get.toNode.prediction === 1.0) assert(topNode.leftChild.get.stats.impurity === 0.0) assert(topNode.rightChild.get.stats.impurity === 0.0) } @@ -382,8 +382,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.stats.impurity > 0.0) // set impurity and predict for child nodes - assert(topNode.leftChild.get.toNode(isClassification = true).prediction === 0.0) - assert(topNode.rightChild.get.toNode(isClassification = true).prediction === 1.0) + assert(topNode.leftChild.get.toNode.prediction === 0.0) + assert(topNode.rightChild.get.toNode.prediction === 1.0) assert(topNode.leftChild.get.stats.impurity === 0.0) assert(topNode.rightChild.get.stats.impurity === 0.0) } @@ -582,18 +582,18 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { left right */ val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0)) - val left = new ClassificationLeafNode(0.0, leftImp.calculate(), leftImp) + val left = new LeafNode(0.0, leftImp.calculate(), leftImp) val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0)) - val right = new ClassificationLeafNode(2.0, rightImp.calculate(), rightImp) + val right = new LeafNode(2.0, rightImp.calculate(), rightImp) - val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5), true) + val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5)) val parentImp = parent.impurityStats val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0)) - val left2 = new ClassificationLeafNode(0.0, left2Imp.calculate(), left2Imp) + val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp) - val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0), true) + val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0)) val grandImp = grandParent.impurityStats // Test feature importance computed at different subtrees. @@ -618,8 +618,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // Forest consisting of (full tree) + (internal node with 2 leafs) val trees = Array(parent, grandParent).map { root => - new DecisionTreeClassificationModel(root.asInstanceOf[ClassificationNode], - numFeatures = 2, numClasses = 3).asInstanceOf[DecisionTreeModel] + new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3) + .asInstanceOf[DecisionTreeModel] } val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2) val tree2norm = feature0importance + feature1importance diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index 3f03d909d4a4c..b6894b30b0c2b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -159,7 +159,7 @@ private[ml] object TreeTests extends SparkFunSuite { * @param split Split for parent node * @return Parent node with children attached */ - def buildParentNode(left: Node, right: Node, split: Split, isClassification: Boolean): Node = { + def buildParentNode(left: Node, right: Node, split: Split): Node = { val leftImp = left.impurityStats val rightImp = right.impurityStats val parentImp = leftImp.copy.add(rightImp) @@ -168,15 +168,7 @@ private[ml] object TreeTests extends SparkFunSuite { val gain = parentImp.calculate() - (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate()) val pred = parentImp.predict - if (isClassification) { - new ClassificationInternalNode(pred, parentImp.calculate(), gain, - left.asInstanceOf[ClassificationNode], right.asInstanceOf[ClassificationNode], - split, parentImp) - } else { - new RegressionInternalNode(pred, parentImp.calculate(), gain, - left.asInstanceOf[RegressionNode], right.asInstanceOf[RegressionNode], - split, parentImp) - } + new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp) } /** diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a931738032467..0b074fbf64eda 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -103,13 +103,6 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.defaultParamMap"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$defaultParamMap_="), - // [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.LeafNode"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.Node"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this"), - // [SPARK-7132][ML] Add fit with validation set to spark.ml GBT ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), From 219922422003e59cc8b3bece60778536759fa669 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 8 Oct 2018 15:07:06 +0800 Subject: [PATCH 1791/2461] [SPARK-25673][BUILD] Remove Travis CI which enables Java lint check ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/12980 added Travis CI file mainly for linter because we disabled Java lint check in Jenkins. It's enabled as of https://github.com/apache/spark/pull/21399 and now SBT runs it. Looks we can now remove the file added before. ## How was this patch tested? N/A Closes #22665 Closes #22667 from HyukjinKwon/SPARK-25673. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- .travis.yml | 50 -------------------------------------------------- 1 file changed, 50 deletions(-) delete mode 100644 .travis.yml diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 05b94adeeb93b..0000000000000 --- a/.travis.yml +++ /dev/null @@ -1,50 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Spark provides this Travis CI configuration file to help contributors -# check Scala/Java style conformance and JDK7/8 compilation easily -# during their preparing pull requests. -# - Scalastyle is executed during `maven install` implicitly. -# - Java Checkstyle is executed by `lint-java`. -# See the related discussion here. -# https://github.com/apache/spark/pull/12980 - -# 1. Choose OS (Ubuntu 14.04.3 LTS Server Edition 64bit, ~2 CORE, 7.5GB RAM) -sudo: required -dist: trusty - -# 2. Choose language and target JDKs for parallel builds. -language: java -jdk: - - oraclejdk8 - -# 3. Setup cache directory for SBT and Maven. -cache: - directories: - - $HOME/.sbt - - $HOME/.m2 - -# 4. Turn off notifications. -notifications: - email: false - -# 5. Run maven install before running lint-java. -install: - - export MAVEN_SKIP_RC=1 - - build/mvn -T 4 -q -DskipTests -Pkubernetes -Pmesos -Pyarn -Pkinesis-asl -Phive -Phive-thriftserver install - -# 6. Run lint-java. -script: - - dev/lint-java From cb90617f894fd51a092710271823ec7d1cd3a668 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 8 Oct 2018 15:18:08 +0800 Subject: [PATCH 1792/2461] [SPARK-25591][PYSPARK][SQL] Avoid overwriting deserialized accumulator ## What changes were proposed in this pull request? If we use accumulators in more than one UDFs, it is possible to overwrite deserialized accumulators and its values. We should check if an accumulator was deserialized before overwriting it in accumulator registry. ## How was this patch tested? Added test. Closes #22635 from viirya/SPARK-25591. Authored-by: Liang-Chi Hsieh Signed-off-by: hyukjinkwon --- python/pyspark/accumulators.py | 12 ++++++++---- python/pyspark/sql/tests.py | 25 +++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 30ad04297c682..00ec094e7e3b4 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -109,10 +109,14 @@ def _deserialize_accumulator(aid, zero_value, accum_param): from pyspark.accumulators import _accumulatorRegistry - accum = Accumulator(aid, zero_value, accum_param) - accum._deserialized = True - _accumulatorRegistry[aid] = accum - return accum + # If this certain accumulator was deserialized, don't overwrite it. + if aid in _accumulatorRegistry: + return _accumulatorRegistry[aid] + else: + accum = Accumulator(aid, zero_value, accum_param) + accum._deserialized = True + _accumulatorRegistry[aid] = accum + return accum class Accumulator(object): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index d3c29d061fc32..ac87ccddd689f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3603,6 +3603,31 @@ def test_repr_behaviors(self): self.assertEquals(None, df._repr_html_()) self.assertEquals(expected, df.__repr__()) + # SPARK-25591 + def test_same_accumulator_in_udfs(self): + from pyspark.sql.functions import udf + + data_schema = StructType([StructField("a", IntegerType(), True), + StructField("b", IntegerType(), True)]) + data = self.spark.createDataFrame([[1, 2]], schema=data_schema) + + test_accum = self.sc.accumulator(0) + + def first_udf(x): + test_accum.add(1) + return x + + def second_udf(x): + test_accum.add(100) + return x + + func_udf = udf(first_udf, IntegerType()) + func_udf2 = udf(second_udf, IntegerType()) + data = data.withColumn("out1", func_udf(data["a"])) + data = data.withColumn("out2", func_udf2(data["b"])) + data.collect() + self.assertEqual(test_accum.value, 101) + class HiveSparkSubmitTests(SparkSubmitTests): From 1a6815cd9f421a106f8d96a36a53042a00f02386 Mon Sep 17 00:00:00 2001 From: shivusondur Date: Mon, 8 Oct 2018 15:43:08 +0800 Subject: [PATCH 1793/2461] [SPARK-25677][DOC] spark.io.compression.codec = org.apache.spark.io.ZstdCompressionCodec throwing IllegalArgumentException Exception ## What changes were proposed in this pull request? Documentation is updated with proper classname org.apache.spark.io.ZStdCompressionCodec ## How was this patch tested? we used the spark.io.compression.codec = org.apache.spark.io.ZStdCompressionCodec and verified the logs. Closes #22669 from shivusondur/CompressionIssue. Authored-by: shivusondur Signed-off-by: hyukjinkwon --- docs/configuration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/configuration.md b/docs/configuration.md index 55773937d4d71..613e214783d59 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -954,7 +954,7 @@ Apart from these, the following properties are also available, and may be useful org.apache.spark.io.LZ4CompressionCodec, org.apache.spark.io.LZFCompressionCodec, org.apache.spark.io.SnappyCompressionCodec, - and org.apache.spark.io.ZstdCompressionCodec. + and org.apache.spark.io.ZStdCompressionCodec.
      From a853a80202032083ad411eec5ec97b304f732a61 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 8 Oct 2018 15:47:15 +0800 Subject: [PATCH 1794/2461] [SPARK-25666][PYTHON] Internally document type conversion between Python data and SQL types in normal UDFs ### What changes were proposed in this pull request? We are facing some problems about type conversions between Python data and SQL types in UDFs (Pandas UDFs as well). It's even difficult to identify the problems (see https://github.com/apache/spark/pull/20163 and https://github.com/apache/spark/pull/22610). This PR targets to internally document the type conversion table. Some of them looks buggy and we should fix them. ```python import sys import array import datetime from decimal import Decimal from pyspark.sql import Row from pyspark.sql.types import * from pyspark.sql.functions import udf if sys.version >= '3': long = int data = [ None, True, 1, long(1), "a", u"a", datetime.date(1970, 1, 1), datetime.datetime(1970, 1, 1, 0, 0), 1.0, array.array("i", [1]), [1], (1,), bytearray([65, 66, 67]), Decimal(1), {"a": 1}, Row(kwargs=1), Row("namedtuple")(1), ] types = [ BooleanType(), ByteType(), ShortType(), IntegerType(), LongType(), StringType(), DateType(), TimestampType(), FloatType(), DoubleType(), ArrayType(IntegerType()), BinaryType(), DecimalType(10, 0), MapType(StringType(), IntegerType()), StructType([StructField("_1", IntegerType())]), ] df = spark.range(1) results = [] count = 0 total = len(types) * len(data) spark.sparkContext.setLogLevel("FATAL") for t in types: result = [] for v in data: try: row = df.select(udf(lambda: v, t)()).first() ret_str = repr(row[0]) except Exception: ret_str = "X" result.append(ret_str) progress = "SQL Type: [%s]\n Python Value: [%s(%s)]\n Result Python Value: [%s]" % ( t.simpleString(), str(v), type(v).__name__, ret_str) count += 1 print("%s/%s:\n %s" % (count, total, progress)) results.append([t.simpleString()] + list(map(str, result))) schema = ["SQL Type \\ Python Value(Type)"] + list(map(lambda v: "%s(%s)" % (str(v), type(v).__name__), data)) strings = spark.createDataFrame(results, schema=schema)._jdf.showString(20, 20, False) print("\n".join(map(lambda line: " # %s # noqa" % line, strings.strip().split("\n")))) ``` This table was generated under Python 2 but the code above is Python 3 compatible as well. ## How was this patch tested? Manually tested and lint check. Closes #22655 from HyukjinKwon/SPARK-25666. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- python/pyspark/sql/functions.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index be089eea0b280..5425d311f8c7f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2733,6 +2733,39 @@ def udf(f=None, returnType=StringType()): | 8| JOHN DOE| 22| +----------+--------------+------------+ """ + + # The following table shows most of Python data and SQL type conversions in normal UDFs that + # are not yet visible to the user. Some of behaviors are buggy and might be changed in the near + # future. The table might have to be eventually documented externally. + # Please see SPARK-25666's PR to see the codes in order to generate the table below. + # + # +-----------------------------+--------------+----------+------+-------+---------------+---------------+--------------------+-----------------------------+----------+----------------------+---------+--------------------+-----------------+------------+--------------+------------------+----------------------+ # noqa + # |SQL Type \ Python Value(Type)|None(NoneType)|True(bool)|1(int)|1(long)| a(str)| a(unicode)| 1970-01-01(date)|1970-01-01 00:00:00(datetime)|1.0(float)|array('i', [1])(array)|[1](list)| (1,)(tuple)| ABC(bytearray)| 1(Decimal)|{'a': 1}(dict)|Row(kwargs=1)(Row)|Row(namedtuple=1)(Row)| # noqa + # +-----------------------------+--------------+----------+------+-------+---------------+---------------+--------------------+-----------------------------+----------+----------------------+---------+--------------------+-----------------+------------+--------------+------------------+----------------------+ # noqa + # | boolean| None| True| None| None| None| None| None| None| None| None| None| None| None| None| None| X| X| # noqa + # | tinyint| None| None| 1| 1| None| None| None| None| None| None| None| None| None| None| None| X| X| # noqa + # | smallint| None| None| 1| 1| None| None| None| None| None| None| None| None| None| None| None| X| X| # noqa + # | int| None| None| 1| 1| None| None| None| None| None| None| None| None| None| None| None| X| X| # noqa + # | bigint| None| None| 1| 1| None| None| None| None| None| None| None| None| None| None| None| X| X| # noqa + # | string| None| u'true'| u'1'| u'1'| u'a'| u'a'|u'java.util.Grego...| u'java.util.Grego...| u'1.0'| u'[I@24a83055'| u'[1]'|u'[Ljava.lang.Obj...| u'[B@49093632'| u'1'| u'{a=1}'| X| X| # noqa + # | date| None| X| X| X| X| X|datetime.date(197...| datetime.date(197...| X| X| X| X| X| X| X| X| X| # noqa + # | timestamp| None| X| X| X| X| X| X| datetime.datetime...| X| X| X| X| X| X| X| X| X| # noqa + # | float| None| None| None| None| None| None| None| None| 1.0| None| None| None| None| None| None| X| X| # noqa + # | double| None| None| None| None| None| None| None| None| 1.0| None| None| None| None| None| None| X| X| # noqa + # | array| None| None| None| None| None| None| None| None| None| [1]| [1]| [1]| [65, 66, 67]| None| None| X| X| # noqa + # | binary| None| None| None| None|bytearray(b'a')|bytearray(b'a')| None| None| None| None| None| None|bytearray(b'ABC')| None| None| X| X| # noqa + # | decimal(10,0)| None| None| None| None| None| None| None| None| None| None| None| None| None|Decimal('1')| None| X| X| # noqa + # | map| None| None| None| None| None| None| None| None| None| None| None| None| None| None| {u'a': 1}| X| X| # noqa + # | struct<_1:int>| None| X| X| X| X| X| X| X| X| X|Row(_1=1)| Row(_1=1)| X| X| Row(_1=None)| Row(_1=1)| Row(_1=1)| # noqa + # +-----------------------------+--------------+----------+------+-------+---------------+---------------+--------------------+-----------------------------+----------+----------------------+---------+--------------------+-----------------+------------+--------------+------------------+----------------------+ # noqa + # + # Note: DDL formatted string is used for 'SQL Type' for simplicity. This string can be + # used in `returnType`. + # Note: The values inside of the table are generated by `repr`. + # Note: Python 2 is used to generate this table since it is used to check the backward + # compatibility often in practice. + # Note: 'X' means it throws an exception during the conversion. + # decorator @udf, @udf(), @udf(dataType()) if f is None or isinstance(f, (str, DataType)): # If DataType has been passed as a positional argument From 1a28625355d75076bde4bcc95a72e9b187cda606 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Mon, 8 Oct 2018 09:58:52 -0500 Subject: [PATCH 1795/2461] [SPARK-25408] Move to more ideomatic Java8 While working on another PR, I noticed that there is quite some legacy Java in there that can be beautified. For example the use of features from Java8, such as: - Collection libraries - Try-with-resource blocks No logic has been changed. I think it is important to have a solid codebase with examples that will inspire next PR's to follow up on the best practices. What are your thoughts on this? This makes code easier to read, and using try-with-resource makes is less likely to forget to close something. ## What changes were proposed in this pull request? No changes in the logic of Spark, but more in the aesthetics of the code. ## How was this patch tested? Using the existing unit tests. Since no logic is changed, the existing unit tests should pass. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22637 from Fokko/SPARK-25408. Authored-by: Fokko Driesprong Signed-off-by: Sean Owen --- .../spark/util/kvstore/KVStoreSerializer.java | 10 +- .../spark/util/kvstore/LevelDBSuite.java | 2 +- .../network/ChunkFetchIntegrationSuite.java | 54 +++++---- .../shuffle/ShuffleIndexInformation.java | 8 +- .../ExternalShuffleBlockResolverSuite.java | 28 ++--- .../ExternalShuffleIntegrationSuite.java | 50 ++++---- .../shuffle/ExternalShuffleSecuritySuite.java | 18 +-- .../spark/util/sketch/CountMinSketch.java | 7 +- .../spark/util/sketch/CountMinSketchImpl.java | 8 +- .../apache/spark/io/ReadAheadInputStream.java | 102 ++++++++-------- .../sort/BypassMergeSortShuffleWriter.java | 6 +- .../shuffle/sort/ShuffleExternalSorter.java | 63 +++++----- .../org/apache/spark/JavaJdbcRDDSuite.java | 30 +++-- .../sort/UnsafeShuffleWriterSuite.java | 14 +-- .../test/org/apache/spark/JavaAPISuite.java | 27 ++--- .../expressions/RowBasedKeyValueBatch.java | 3 +- .../RowBasedKeyValueBatchSuite.java | 112 +++++++----------- 17 files changed, 249 insertions(+), 293 deletions(-) diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreSerializer.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreSerializer.java index bd8d9486acde5..771a9541bb349 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreSerializer.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreSerializer.java @@ -54,11 +54,8 @@ public final byte[] serialize(Object o) throws Exception { return ((String) o).getBytes(UTF_8); } else { ByteArrayOutputStream bytes = new ByteArrayOutputStream(); - GZIPOutputStream out = new GZIPOutputStream(bytes); - try { + try (GZIPOutputStream out = new GZIPOutputStream(bytes)) { mapper.writeValue(out, o); - } finally { - out.close(); } return bytes.toByteArray(); } @@ -69,11 +66,8 @@ public final T deserialize(byte[] data, Class klass) throws Exception { if (klass.equals(String.class)) { return (T) new String(data, UTF_8); } else { - GZIPInputStream in = new GZIPInputStream(new ByteArrayInputStream(data)); - try { + try (GZIPInputStream in = new GZIPInputStream(new ByteArrayInputStream(data))) { return mapper.readValue(in, klass); - } finally { - in.close(); } } } diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java index 205f7df87c5bc..39a952f2b0df9 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java @@ -217,7 +217,7 @@ public void testSkip() throws Exception { public void testNegativeIndexValues() throws Exception { List expected = Arrays.asList(-100, -50, 0, 50, 100); - expected.stream().forEach(i -> { + expected.forEach(i -> { try { db.write(createCustomType1(i)); } catch (Exception e) { diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 824482af08dd4..37a8664a52661 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -143,37 +143,39 @@ public void releaseBuffers() { } private FetchResult fetchChunks(List chunkIndices) throws Exception { - TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - final Semaphore sem = new Semaphore(0); - final FetchResult res = new FetchResult(); - res.successChunks = Collections.synchronizedSet(new HashSet()); - res.failedChunks = Collections.synchronizedSet(new HashSet()); - res.buffers = Collections.synchronizedList(new LinkedList()); - ChunkReceivedCallback callback = new ChunkReceivedCallback() { - @Override - public void onSuccess(int chunkIndex, ManagedBuffer buffer) { - buffer.retain(); - res.successChunks.add(chunkIndex); - res.buffers.add(buffer); - sem.release(); - } + try (TransportClient client = + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort())) { + final Semaphore sem = new Semaphore(0); + + res.successChunks = Collections.synchronizedSet(new HashSet()); + res.failedChunks = Collections.synchronizedSet(new HashSet()); + res.buffers = Collections.synchronizedList(new LinkedList()); + + ChunkReceivedCallback callback = new ChunkReceivedCallback() { + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + buffer.retain(); + res.successChunks.add(chunkIndex); + res.buffers.add(buffer); + sem.release(); + } - @Override - public void onFailure(int chunkIndex, Throwable e) { - res.failedChunks.add(chunkIndex); - sem.release(); - } - }; + @Override + public void onFailure(int chunkIndex, Throwable e) { + res.failedChunks.add(chunkIndex); + sem.release(); + } + }; - for (int chunkIndex : chunkIndices) { - client.fetchChunk(STREAM_ID, chunkIndex, callback); - } - if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) { - fail("Timeout getting response from the server"); + for (int chunkIndex : chunkIndices) { + client.fetchChunk(STREAM_ID, chunkIndex, callback); + } + if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } } - client.close(); return res; } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java index 386738ece51a6..371149bef3974 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java @@ -37,14 +37,8 @@ public ShuffleIndexInformation(File indexFile) throws IOException { size = (int)indexFile.length(); ByteBuffer buffer = ByteBuffer.allocate(size); offsets = buffer.asLongBuffer(); - DataInputStream dis = null; - try { - dis = new DataInputStream(Files.newInputStream(indexFile.toPath())); + try (DataInputStream dis = new DataInputStream(Files.newInputStream(indexFile.toPath()))) { dis.readFully(buffer.array()); - } finally { - if (dis != null) { - dis.close(); - } } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index d2072a54fa415..459629c5f05fe 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -98,19 +98,19 @@ public void testSortShuffleBlocks() throws IOException { resolver.registerExecutor("app0", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); - InputStream block0Stream = - resolver.getBlockData("app0", "exec0", 0, 0, 0).createInputStream(); - String block0 = CharStreams.toString( - new InputStreamReader(block0Stream, StandardCharsets.UTF_8)); - block0Stream.close(); - assertEquals(sortBlock0, block0); - - InputStream block1Stream = - resolver.getBlockData("app0", "exec0", 0, 0, 1).createInputStream(); - String block1 = CharStreams.toString( - new InputStreamReader(block1Stream, StandardCharsets.UTF_8)); - block1Stream.close(); - assertEquals(sortBlock1, block1); + try (InputStream block0Stream = resolver.getBlockData( + "app0", "exec0", 0, 0, 0).createInputStream()) { + String block0 = + CharStreams.toString(new InputStreamReader(block0Stream, StandardCharsets.UTF_8)); + assertEquals(sortBlock0, block0); + } + + try (InputStream block1Stream = resolver.getBlockData( + "app0", "exec0", 0, 0, 1).createInputStream()) { + String block1 = + CharStreams.toString(new InputStreamReader(block1Stream, StandardCharsets.UTF_8)); + assertEquals(sortBlock1, block1); + } } @Test @@ -149,7 +149,7 @@ public void testNormalizeAndInternPathname() { private void assertPathsMatch(String p1, String p2, String p3, String expectedPathname) { String normPathname = - ExternalShuffleBlockResolver.createNormalizedInternedPathname(p1, p2, p3); + ExternalShuffleBlockResolver.createNormalizedInternedPathname(p1, p2, p3); assertEquals(expectedPathname, normPathname); File file = new File(normPathname); String returnedPath = file.getPath(); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index a6a1b8d0ac3f1..526b96b364473 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -133,37 +133,37 @@ private FetchResult fetchBlocks( final Semaphore requestsRemaining = new Semaphore(0); - ExternalShuffleClient client = new ExternalShuffleClient(clientConf, null, false, 5000); - client.init(APP_ID); - client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, - new BlockFetchingListener() { - @Override - public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { - synchronized (this) { - if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { - data.retain(); - res.successBlocks.add(blockId); - res.buffers.add(data); - requestsRemaining.release(); + try (ExternalShuffleClient client = new ExternalShuffleClient(clientConf, null, false, 5000)) { + client.init(APP_ID); + client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, + new BlockFetchingListener() { + @Override + public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { + synchronized (this) { + if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { + data.retain(); + res.successBlocks.add(blockId); + res.buffers.add(data); + requestsRemaining.release(); + } } } - } - - @Override - public void onBlockFetchFailure(String blockId, Throwable exception) { - synchronized (this) { - if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { - res.failedBlocks.add(blockId); - requestsRemaining.release(); + + @Override + public void onBlockFetchFailure(String blockId, Throwable exception) { + synchronized (this) { + if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { + res.failedBlocks.add(blockId); + requestsRemaining.release(); + } } } - } - }, null); + }, null); - if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { - fail("Timeout getting response from the server"); + if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } } - client.close(); return res; } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index 16bad9f1b319d..82caf392b821b 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -96,14 +96,16 @@ private void validate(String appId, String secretKey, boolean encrypt) ImmutableMap.of("spark.authenticate.enableSaslEncryption", "true"))); } - ExternalShuffleClient client = - new ExternalShuffleClient(testConf, new TestSecretKeyHolder(appId, secretKey), true, 5000); - client.init(appId); - // Registration either succeeds or throws an exception. - client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0", - new ExecutorShuffleInfo(new String[0], 0, - "org.apache.spark.shuffle.sort.SortShuffleManager")); - client.close(); + try (ExternalShuffleClient client = + new ExternalShuffleClient( + testConf, new TestSecretKeyHolder(appId, secretKey), true, 5000)) { + client.init(appId); + // Registration either succeeds or throws an exception. + client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0", + new ExecutorShuffleInfo( + new String[0], 0, "org.apache.spark.shuffle.sort.SortShuffleManager") + ); + } } /** Provides a secret key holder which always returns the given secret key, for a single appId. */ diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index f7c22dddb8cc0..06a248c9a27c2 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -191,10 +191,9 @@ public static CountMinSketch readFrom(InputStream in) throws IOException { * Reads in a {@link CountMinSketch} from a byte array. */ public static CountMinSketch readFrom(byte[] bytes) throws IOException { - InputStream in = new ByteArrayInputStream(bytes); - CountMinSketch cms = readFrom(in); - in.close(); - return cms; + try (InputStream in = new ByteArrayInputStream(bytes)) { + return readFrom(in); + } } /** diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index fd1906d2e5ae9..b78c1677a1213 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -322,10 +322,10 @@ public void writeTo(OutputStream out) throws IOException { @Override public byte[] toByteArray() throws IOException { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - writeTo(out); - out.close(); - return out.toByteArray(); + try (ByteArrayOutputStream out = new ByteArrayOutputStream()) { + writeTo(out); + return out.toByteArray(); + } } public static CountMinSketchImpl readFrom(InputStream in) throws IOException { diff --git a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java index 0cced9e222952..2e18715b600e0 100644 --- a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java +++ b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java @@ -135,62 +135,58 @@ private void readAsync() throws IOException { } finally { stateChangeLock.unlock(); } - executorService.execute(new Runnable() { - - @Override - public void run() { - stateChangeLock.lock(); - try { - if (isClosed) { - readInProgress = false; - return; - } - // Flip this so that the close method will not close the underlying input stream when we - // are reading. - isReading = true; - } finally { - stateChangeLock.unlock(); + executorService.execute(() -> { + stateChangeLock.lock(); + try { + if (isClosed) { + readInProgress = false; + return; } + // Flip this so that the close method will not close the underlying input stream when we + // are reading. + isReading = true; + } finally { + stateChangeLock.unlock(); + } - // Please note that it is safe to release the lock and read into the read ahead buffer - // because either of following two conditions will hold - 1. The active buffer has - // data available to read so the reader will not read from the read ahead buffer. - // 2. This is the first time read is called or the active buffer is exhausted, - // in that case the reader waits for this async read to complete. - // So there is no race condition in both the situations. - int read = 0; - int off = 0, len = arr.length; - Throwable exception = null; - try { - // try to fill the read ahead buffer. - // if a reader is waiting, possibly return early. - do { - read = underlyingInputStream.read(arr, off, len); - if (read <= 0) break; - off += read; - len -= read; - } while (len > 0 && !isWaiting.get()); - } catch (Throwable ex) { - exception = ex; - if (ex instanceof Error) { - // `readException` may not be reported to the user. Rethrow Error to make sure at least - // The user can see Error in UncaughtExceptionHandler. - throw (Error) ex; - } - } finally { - stateChangeLock.lock(); - readAheadBuffer.limit(off); - if (read < 0 || (exception instanceof EOFException)) { - endOfStream = true; - } else if (exception != null) { - readAborted = true; - readException = exception; - } - readInProgress = false; - signalAsyncReadComplete(); - stateChangeLock.unlock(); - closeUnderlyingInputStreamIfNecessary(); + // Please note that it is safe to release the lock and read into the read ahead buffer + // because either of following two conditions will hold - 1. The active buffer has + // data available to read so the reader will not read from the read ahead buffer. + // 2. This is the first time read is called or the active buffer is exhausted, + // in that case the reader waits for this async read to complete. + // So there is no race condition in both the situations. + int read = 0; + int off = 0, len = arr.length; + Throwable exception = null; + try { + // try to fill the read ahead buffer. + // if a reader is waiting, possibly return early. + do { + read = underlyingInputStream.read(arr, off, len); + if (read <= 0) break; + off += read; + len -= read; + } while (len > 0 && !isWaiting.get()); + } catch (Throwable ex) { + exception = ex; + if (ex instanceof Error) { + // `readException` may not be reported to the user. Rethrow Error to make sure at least + // The user can see Error in UncaughtExceptionHandler. + throw (Error) ex; } + } finally { + stateChangeLock.lock(); + readAheadBuffer.limit(off); + if (read < 0 || (exception instanceof EOFException)) { + endOfStream = true; + } else if (exception != null) { + readAborted = true; + readException = exception; + } + readInProgress = false; + signalAsyncReadComplete(); + stateChangeLock.unlock(); + closeUnderlyingInputStreamIfNecessary(); } }); } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 323a5d3c52831..b020a6d99247b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -152,9 +152,9 @@ public void write(Iterator> records) throws IOException { } for (int i = 0; i < numPartitions; i++) { - final DiskBlockObjectWriter writer = partitionWriters[i]; - partitionWriterSegments[i] = writer.commitAndGet(); - writer.close(); + try (DiskBlockObjectWriter writer = partitionWriters[i]) { + partitionWriterSegments[i] = writer.commitAndGet(); + } } File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index c7d2db4217d96..1c0d664afb138 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -181,42 +181,43 @@ private void writeSortedFile(boolean isLastFile) { // around this, we pass a dummy no-op serializer. final SerializerInstance ser = DummySerializerInstance.INSTANCE; - final DiskBlockObjectWriter writer = - blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse); - int currentPartition = -1; - final int uaoSize = UnsafeAlignedOffset.getUaoSize(); - while (sortedRecords.hasNext()) { - sortedRecords.loadNext(); - final int partition = sortedRecords.packedRecordPointer.getPartitionId(); - assert (partition >= currentPartition); - if (partition != currentPartition) { - // Switch to the new partition - if (currentPartition != -1) { - final FileSegment fileSegment = writer.commitAndGet(); - spillInfo.partitionLengths[currentPartition] = fileSegment.length(); + final FileSegment committedSegment; + try (DiskBlockObjectWriter writer = + blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse)) { + + final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + while (sortedRecords.hasNext()) { + sortedRecords.loadNext(); + final int partition = sortedRecords.packedRecordPointer.getPartitionId(); + assert (partition >= currentPartition); + if (partition != currentPartition) { + // Switch to the new partition + if (currentPartition != -1) { + final FileSegment fileSegment = writer.commitAndGet(); + spillInfo.partitionLengths[currentPartition] = fileSegment.length(); + } + currentPartition = partition; } - currentPartition = partition; - } - final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); - final Object recordPage = taskMemoryManager.getPage(recordPointer); - final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer); - int dataRemaining = UnsafeAlignedOffset.getSize(recordPage, recordOffsetInPage); - long recordReadPosition = recordOffsetInPage + uaoSize; // skip over record length - while (dataRemaining > 0) { - final int toTransfer = Math.min(diskWriteBufferSize, dataRemaining); - Platform.copyMemory( - recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer); - writer.write(writeBuffer, 0, toTransfer); - recordReadPosition += toTransfer; - dataRemaining -= toTransfer; + final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); + final Object recordPage = taskMemoryManager.getPage(recordPointer); + final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer); + int dataRemaining = UnsafeAlignedOffset.getSize(recordPage, recordOffsetInPage); + long recordReadPosition = recordOffsetInPage + uaoSize; // skip over record length + while (dataRemaining > 0) { + final int toTransfer = Math.min(diskWriteBufferSize, dataRemaining); + Platform.copyMemory( + recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer); + writer.write(writeBuffer, 0, toTransfer); + recordReadPosition += toTransfer; + dataRemaining -= toTransfer; + } + writer.recordWritten(); } - writer.recordWritten(); - } - final FileSegment committedSegment = writer.commitAndGet(); - writer.close(); + committedSegment = writer.commitAndGet(); + } // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted, // then the file might be empty. Note that it might be better to avoid calling // writeSortedFile() in that case. diff --git a/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java b/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java index a6589d2898144..40a7c9486ae55 100644 --- a/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java +++ b/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java @@ -39,30 +39,28 @@ public void setUp() throws ClassNotFoundException, SQLException { sc = new JavaSparkContext("local", "JavaAPISuite"); Class.forName("org.apache.derby.jdbc.EmbeddedDriver"); - Connection connection = - DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb;create=true"); - try { - Statement create = connection.createStatement(); - create.execute( - "CREATE TABLE FOO(" + - "ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1)," + - "DATA INTEGER)"); - create.close(); + try (Connection connection = DriverManager.getConnection( + "jdbc:derby:target/JavaJdbcRDDSuiteDb;create=true")) { + + try (Statement create = connection.createStatement()) { + create.execute( + "CREATE TABLE FOO(ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY" + + " (START WITH 1, INCREMENT BY 1), DATA INTEGER)"); + } - PreparedStatement insert = connection.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)"); - for (int i = 1; i <= 100; i++) { - insert.setInt(1, i * 2); - insert.executeUpdate(); + try (PreparedStatement insert = connection.prepareStatement( + "INSERT INTO FOO(DATA) VALUES(?)")) { + for (int i = 1; i <= 100; i++) { + insert.setInt(1, i * 2); + insert.executeUpdate(); + } } - insert.close(); } catch (SQLException e) { // If table doesn't exist... if (e.getSQLState().compareTo("X0Y32") != 0) { throw e; } - } finally { - connection.close(); } } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 0d5c5ea7903e9..a07d0e84ea854 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -186,14 +186,14 @@ private List> readRecordsFromFile() throws IOException { if (conf.getBoolean("spark.shuffle.compress", true)) { in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in); } - DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in); - Iterator> records = recordsStream.asKeyValueIterator(); - while (records.hasNext()) { - Tuple2 record = records.next(); - assertEquals(i, hashPartitioner.getPartition(record._1())); - recordsList.add(record); + try (DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in)) { + Iterator> records = recordsStream.asKeyValueIterator(); + while (records.hasNext()) { + Tuple2 record = records.next(); + assertEquals(i, hashPartitioner.getPartition(record._1())); + recordsList.add(record); + } } - recordsStream.close(); startOffset += partitionSize; } } diff --git a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java index 01b5fb7b46684..3992ab7049bdd 100644 --- a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java @@ -997,10 +997,10 @@ public void binaryFiles() throws Exception { FileOutputStream fos1 = new FileOutputStream(file1); - FileChannel channel1 = fos1.getChannel(); - ByteBuffer bbuf = ByteBuffer.wrap(content1); - channel1.write(bbuf); - channel1.close(); + try (FileChannel channel1 = fos1.getChannel()) { + ByteBuffer bbuf = ByteBuffer.wrap(content1); + channel1.write(bbuf); + } JavaPairRDD readRDD = sc.binaryFiles(tempDirName, 3); List> result = readRDD.collect(); for (Tuple2 res : result) { @@ -1018,10 +1018,10 @@ public void binaryFilesCaching() throws Exception { FileOutputStream fos1 = new FileOutputStream(file1); - FileChannel channel1 = fos1.getChannel(); - ByteBuffer bbuf = ByteBuffer.wrap(content1); - channel1.write(bbuf); - channel1.close(); + try (FileChannel channel1 = fos1.getChannel()) { + ByteBuffer bbuf = ByteBuffer.wrap(content1); + channel1.write(bbuf); + } JavaPairRDD readRDD = sc.binaryFiles(tempDirName).cache(); readRDD.foreach(pair -> pair._2().toArray()); // force the file to read @@ -1042,13 +1042,12 @@ public void binaryRecords() throws Exception { FileOutputStream fos1 = new FileOutputStream(file1); - FileChannel channel1 = fos1.getChannel(); - - for (int i = 0; i < numOfCopies; i++) { - ByteBuffer bbuf = ByteBuffer.wrap(content1); - channel1.write(bbuf); + try (FileChannel channel1 = fos1.getChannel()) { + for (int i = 0; i < numOfCopies; i++) { + ByteBuffer bbuf = ByteBuffer.wrap(content1); + channel1.write(bbuf); + } } - channel1.close(); JavaRDD readRDD = sc.binaryRecords(tempDirName, content1.length); assertEquals(numOfCopies,readRDD.count()); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java index 551443a11298b..460513816dfd9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.catalyst.expressions; +import java.io.Closeable; import java.io.IOException; import org.apache.spark.memory.MemoryConsumer; @@ -45,7 +46,7 @@ * page requires an average size for key value pairs to be larger than 1024 bytes. * */ -public abstract class RowBasedKeyValueBatch extends MemoryConsumer { +public abstract class RowBasedKeyValueBatch extends MemoryConsumer implements Closeable { protected final Logger logger = LoggerFactory.getLogger(RowBasedKeyValueBatch.class); private static final int DEFAULT_CAPACITY = 1 << 16; diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java index 2da87113c6229..8da778800bb9f 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java @@ -123,9 +123,8 @@ public void tearDown() { @Test public void emptyBatch() throws Exception { - RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY); - try { + try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { Assert.assertEquals(0, batch.numRows()); try { batch.getKeyRow(-1); @@ -152,31 +151,24 @@ public void emptyBatch() throws Exception { // Expected exception; do nothing. } Assert.assertFalse(batch.rowIterator().next()); - } finally { - batch.close(); } } @Test - public void batchType() throws Exception { - RowBasedKeyValueBatch batch1 = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY); - RowBasedKeyValueBatch batch2 = RowBasedKeyValueBatch.allocate(fixedKeySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY); - try { + public void batchType() { + try (RowBasedKeyValueBatch batch1 = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY); + RowBasedKeyValueBatch batch2 = RowBasedKeyValueBatch.allocate(fixedKeySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { Assert.assertEquals(batch1.getClass(), VariableLengthRowBasedKeyValueBatch.class); Assert.assertEquals(batch2.getClass(), FixedLengthRowBasedKeyValueBatch.class); - } finally { - batch1.close(); - batch2.close(); } } @Test public void setAndRetrieve() { - RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY); - try { + try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { UnsafeRow ret1 = appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1)); Assert.assertTrue(checkValue(ret1, 1, 1)); UnsafeRow ret2 = appendRow(batch, makeKeyRow(2, "B"), makeValueRow(2, 2)); @@ -204,33 +196,27 @@ public void setAndRetrieve() { } catch (AssertionError e) { // Expected exception; do nothing. } - } finally { - batch.close(); } } @Test public void setUpdateAndRetrieve() { - RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY); - try { + try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1)); Assert.assertEquals(1, batch.numRows()); UnsafeRow retrievedValue = batch.getValueRow(0); updateValueRow(retrievedValue, 2, 2); UnsafeRow retrievedValue2 = batch.getValueRow(0); Assert.assertTrue(checkValue(retrievedValue2, 2, 2)); - } finally { - batch.close(); } } @Test public void iteratorTest() throws Exception { - RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY); - try { + try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { appendRow(batch, makeKeyRow(1, "A"), makeValueRow(1, 1)); appendRow(batch, makeKeyRow(2, "B"), makeValueRow(2, 2)); appendRow(batch, makeKeyRow(3, "C"), makeValueRow(3, 3)); @@ -253,16 +239,13 @@ public void iteratorTest() throws Exception { Assert.assertTrue(checkKey(key3, 3, "C")); Assert.assertTrue(checkValue(value3, 3, 3)); Assert.assertFalse(iterator.next()); - } finally { - batch.close(); } } @Test public void fixedLengthTest() throws Exception { - RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(fixedKeySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY); - try { + try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(fixedKeySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { appendRow(batch, makeKeyRow(11, 11), makeValueRow(1, 1)); appendRow(batch, makeKeyRow(22, 22), makeValueRow(2, 2)); appendRow(batch, makeKeyRow(33, 33), makeValueRow(3, 3)); @@ -293,16 +276,13 @@ public void fixedLengthTest() throws Exception { Assert.assertTrue(checkKey(key3, 33, 33)); Assert.assertTrue(checkValue(value3, 3, 3)); Assert.assertFalse(iterator.next()); - } finally { - batch.close(); } } @Test public void appendRowUntilExceedingCapacity() throws Exception { - RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, 10); - try { + try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, 10)) { UnsafeRow key = makeKeyRow(1, "A"); UnsafeRow value = makeValueRow(1, 1); for (int i = 0; i < 10; i++) { @@ -321,8 +301,6 @@ public void appendRowUntilExceedingCapacity() throws Exception { Assert.assertTrue(checkValue(value1, 1, 1)); } Assert.assertFalse(iterator.next()); - } finally { - batch.close(); } } @@ -330,9 +308,8 @@ public void appendRowUntilExceedingCapacity() throws Exception { public void appendRowUntilExceedingPageSize() throws Exception { // Use default size or spark.buffer.pageSize if specified int pageSizeToUse = (int) memoryManager.pageSizeBytes(); - RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, pageSizeToUse); //enough capacity - try { + try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, pageSizeToUse)) { UnsafeRow key = makeKeyRow(1, "A"); UnsafeRow value = makeValueRow(1, 1); int recordLength = 8 + key.getSizeInBytes() + value.getSizeInBytes() + 8; @@ -356,49 +333,44 @@ public void appendRowUntilExceedingPageSize() throws Exception { Assert.assertTrue(checkValue(value1, 1, 1)); } Assert.assertFalse(iterator.next()); - } finally { - batch.close(); } } @Test public void failureToAllocateFirstPage() throws Exception { memoryManager.limit(1024); - RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY); - try { + try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { UnsafeRow key = makeKeyRow(1, "A"); UnsafeRow value = makeValueRow(11, 11); UnsafeRow ret = appendRow(batch, key, value); Assert.assertNull(ret); Assert.assertFalse(batch.rowIterator().next()); - } finally { - batch.close(); } } @Test public void randomizedTest() { - RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, - valueSchema, taskMemoryManager, DEFAULT_CAPACITY); - int numEntry = 100; - long[] expectedK1 = new long[numEntry]; - String[] expectedK2 = new String[numEntry]; - long[] expectedV1 = new long[numEntry]; - long[] expectedV2 = new long[numEntry]; - - for (int i = 0; i < numEntry; i++) { - long k1 = rand.nextLong(); - String k2 = getRandomString(rand.nextInt(256)); - long v1 = rand.nextLong(); - long v2 = rand.nextLong(); - appendRow(batch, makeKeyRow(k1, k2), makeValueRow(v1, v2)); - expectedK1[i] = k1; - expectedK2[i] = k2; - expectedV1[i] = v1; - expectedV2[i] = v2; - } - try { + try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema, + valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) { + int numEntry = 100; + long[] expectedK1 = new long[numEntry]; + String[] expectedK2 = new String[numEntry]; + long[] expectedV1 = new long[numEntry]; + long[] expectedV2 = new long[numEntry]; + + for (int i = 0; i < numEntry; i++) { + long k1 = rand.nextLong(); + String k2 = getRandomString(rand.nextInt(256)); + long v1 = rand.nextLong(); + long v2 = rand.nextLong(); + appendRow(batch, makeKeyRow(k1, k2), makeValueRow(v1, v2)); + expectedK1[i] = k1; + expectedK2[i] = k2; + expectedV1[i] = v1; + expectedV2[i] = v2; + } + for (int j = 0; j < 10000; j++) { int rowId = rand.nextInt(numEntry); if (rand.nextBoolean()) { @@ -410,8 +382,6 @@ public void randomizedTest() { Assert.assertTrue(checkValue(value, expectedV1[rowId], expectedV2[rowId])); } } - } finally { - batch.close(); } } } From 6353425af76f9cc9de7ee4094f41df7a7390d898 Mon Sep 17 00:00:00 2001 From: Sanket Chintapalli Date: Mon, 8 Oct 2018 13:19:34 -0500 Subject: [PATCH 1796/2461] [SPARK-25641] Change the spark.shuffle.server.chunkFetchHandlerThreadsPercent default to 100 ## What changes were proposed in this pull request? We want to change the default percentage to 100 for spark.shuffle.server.chunkFetchHandlerThreadsPercent. The reason being currently this is set to 0. Which means currently if server.ioThreads > 0, the default number of threads would be 2 * #cores instead of server.io.Threads. We want the default to server.io.Threads in case this is not set at all. Also here a default of 0 would also mean 2 * #cores ## How was this patch tested? Manual (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22628 from redsanket/SPARK-25641. Lead-authored-by: Sanket Chintapalli Co-authored-by: Sanket Chintapalli Signed-off-by: Thomas Graves --- .../apache/spark/network/util/TransportConf.java | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 6d5cccd20b333..43a6bc7dc3d06 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -296,17 +296,21 @@ public long maxChunksBeingTransferred() { * and could take long time to process due to disk contentions. By configuring a slightly * higher number of shuffler server threads, we are able to reserve some threads for * handling other RPC messages, thus making the Client less likely to experience timeout - * when sending RPC messages to the shuffle server. Default to 0, which is 2*#cores - * or io.serverThreads. 90 would mean 90% of 2*#cores or 90% of io.serverThreads - * which equals 0.9 * 2*#cores or 0.9 * io.serverThreads. + * when sending RPC messages to the shuffle server. The number of threads used for handling + * chunked fetch requests are percentage of io.serverThreads (if defined) else it is a percentage + * of 2 * #cores. However, a percentage of 0 means netty default number of threads which + * is 2 * #cores ignoring io.serverThreads. The percentage here is configured via + * spark.shuffle.server.chunkFetchHandlerThreadsPercent. The returned value is rounded off to + * ceiling of the nearest integer. */ public int chunkFetchHandlerThreads() { if (!this.getModuleName().equalsIgnoreCase("shuffle")) { return 0; } int chunkFetchHandlerThreadsPercent = - conf.getInt("spark.shuffle.server.chunkFetchHandlerThreadsPercent", 0); - return this.serverThreads() > 0 ? (this.serverThreads() * chunkFetchHandlerThreadsPercent)/100: - (2 * NettyRuntime.availableProcessors() * chunkFetchHandlerThreadsPercent)/100; + conf.getInt("spark.shuffle.server.chunkFetchHandlerThreadsPercent", 100); + return (int)Math.ceil( + (this.serverThreads() > 0 ? this.serverThreads() : 2 * NettyRuntime.availableProcessors()) * + chunkFetchHandlerThreadsPercent/(double)100); } } From 6a60fb0aad62e98c8a0e1c365819f31b1fc0132e Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 8 Oct 2018 13:05:53 -0700 Subject: [PATCH 1797/2461] [SPARK-25630][TEST] Reduce test time of HadoopFsRelationTest ## What changes were proposed in this pull request? There was 5 suites extends `HadoopFsRelationTest`, for testing "orc"/"parquet"/"text"/"json" data sources. This PR refactor the base trait `HadoopFsRelationTest`: 1. Rename unnecessary loop for setting parquet conf 2. The test case `SPARK-8406: Avoids name collision while writing files` takes about 14 to 20 seconds. As now all the file format data source are using common code, for creating result files, we can test one data source(Parquet) only to reduce test time. To run related 5 suites: ``` ./build/sbt "hive/testOnly *HadoopFsRelationSuite" ``` The total test run time is reduced from 5 minutes 40 seconds to 3 minutes 50 seconds. ## How was this patch tested? Unit test Closes #22643 from gengliangwang/refactorHadoopFsRelationTest. Authored-by: Gengliang Wang Signed-off-by: gatorsmile --- .../sql/sources/HadoopFsRelationTest.scala | 50 +++++++------------ .../ParquetHadoopFsRelationSuite.scala | 31 +++++++++++- 2 files changed, 49 insertions(+), 32 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index b9ec940ac4925..6bd59fde550de 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -38,6 +38,10 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes val dataSourceName: String + protected val parquetDataSourceName: String = "parquet" + + private def isParquetDataSource: Boolean = dataSourceName == parquetDataSourceName + protected def supportsDataType(dataType: DataType): Boolean = true val dataSchema = @@ -114,10 +118,21 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes new UDT.MyDenseVectorUDT() ).filter(supportsDataType) - for (dataType <- supportedDataTypes) { - for (parquetDictionaryEncodingEnabled <- Seq(true, false)) { - test(s"test all data types - $dataType with parquet.enable.dictionary = " + - s"$parquetDictionaryEncodingEnabled") { + test(s"test all data types") { + val parquetDictionaryEncodingEnabledConfs = if (isParquetDataSource) { + // Run with/without Parquet dictionary encoding enabled for Parquet data source. + Seq(true, false) + } else { + Seq(false) + } + for (dataType <- supportedDataTypes) { + for (parquetDictionaryEncodingEnabled <- parquetDictionaryEncodingEnabledConfs) { + val extraMessage = if (isParquetDataSource) { + s" with parquet.enable.dictionary = $parquetDictionaryEncodingEnabled" + } else { + "" + } + logInfo(s"Testing $dataType data type$extraMessage") val extraOptions = Map[String, String]( "parquet.enable.dictionary" -> parquetDictionaryEncodingEnabled.toString @@ -754,33 +769,6 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } } - // NOTE: This test suite is not super deterministic. On nodes with only relatively few cores - // (4 or even 1), it's hard to reproduce the data loss issue. But on nodes with for example 8 or - // more cores, the issue can be reproduced steadily. Fortunately our Jenkins builder meets this - // requirement. We probably want to move this test case to spark-integration-tests or spark-perf - // later. - test("SPARK-8406: Avoids name collision while writing files") { - withTempPath { dir => - val path = dir.getCanonicalPath - spark - .range(10000) - .repartition(250) - .write - .mode(SaveMode.Overwrite) - .format(dataSourceName) - .save(path) - - assertResult(10000) { - spark - .read - .format(dataSourceName) - .option("dataSchema", StructType(StructField("id", LongType) :: Nil).json) - .load(path) - .count() - } - } - } - test("SPARK-8887: Explicitly define which data types can be used as dynamic partition columns") { val df = Seq( (1, "v1", Array(1, 2, 3), Map("k1" -> "v1"), Tuple2(1, "4")), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index 6858bbc441721..6ebc1d145848c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.types._ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { import testImplicits._ - override val dataSourceName: String = "parquet" + override val dataSourceName: String = parquetDataSourceName // Parquet does not play well with NullType. override protected def supportsDataType(dataType: DataType): Boolean = dataType match { @@ -232,4 +232,33 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { } } } + + // NOTE: This test suite is not super deterministic. On nodes with only relatively few cores + // (4 or even 1), it's hard to reproduce the data loss issue. But on nodes with for example 8 or + // more cores, the issue can be reproduced steadily. Fortunately our Jenkins builder meets this + // requirement. We probably want to move this test case to spark-integration-tests or spark-perf + // later. + // Also, this test is slow. As now all the file format data source are using common code + // for creating result files, we can test Parquet only to reduce test time. + test("SPARK-8406: Avoids name collision while writing files") { + withTempPath { dir => + val path = dir.getCanonicalPath + spark + .range(10000) + .repartition(250) + .write + .mode(SaveMode.Overwrite) + .format(dataSourceName) + .save(path) + + assertResult(10000) { + spark + .read + .format(dataSourceName) + .option("dataSchema", StructType(StructField("id", LongType) :: Nil).json) + .load(path) + .count() + } + } + } } From f9935a3f85f46deef2cb7b213c1c02c8ff627a8c Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 8 Oct 2018 14:32:04 -0700 Subject: [PATCH 1798/2461] [SPARK-25639][DOCS] Added docs for foreachBatch, python foreach and multiple watermarks ## What changes were proposed in this pull request? Added - Python foreach - Scala, Java and Python foreachBatch - Multiple watermark policy - The semantics of what changes are allowed to the streaming between restarts. ## How was this patch tested? No tests Closes #22627 from tdas/SPARK-25639. Authored-by: Tathagata Das Signed-off-by: Tathagata Das --- .../structured-streaming-programming-guide.md | 323 +++++++++++++++++- 1 file changed, 312 insertions(+), 11 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 73de1892977ac..b6e427735e74b 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1560,6 +1560,35 @@ streamingDf <- dropDuplicates(streamingDf, "guid", "eventTime") +### Policy for handling multiple watermarks +A streaming query can have multiple input streams that are unioned or joined together. +Each of the input streams can have a different threshold of late data that needs to +be tolerated for stateful operations. You specify these thresholds using +``withWatermarks("eventTime", delay)`` on each of the input streams. For example, consider +a query with stream-stream joins between `inputStream1` and `inputStream2`. + + inputStream1.withWatermark("eventTime1", "1 hour") + .join( + inputStream2.withWatermark("eventTime2", "2 hours"), + joinCondition) + +While executing the query, Structured Streaming individually tracks the maximum +event time seen in each input stream, calculates watermarks based on the corresponding delay, +and chooses a single global watermark with them to be used for stateful operations. By default, +the minimum is chosen as the global watermark because it ensures that no data is +accidentally dropped as too late if one of the streams falls behind the others +(for example, one of the streams stop receiving data due to upstream failures). In other words, +the global watermark will safely move at the pace of the slowest stream and the query output will +be delayed accordingly. + +However, in some cases, you may want to get faster results even if it means dropping data from the +slowest stream. Since Spark 2.4, you can set the multiple watermark policy to choose +the maximum value as the global watermark by setting the SQL configuration +``spark.sql.streaming.multipleWatermarkPolicy`` to ``max`` (default is ``min``). +This lets the global watermark move at the pace of the fastest stream. +However, as a side effect, data from the slower streams will be aggressively dropped. Hence, use +this configuration judiciously. + ### Arbitrary Stateful Operations Many usecases require more advanced stateful operations than aggregations. For example, in many usecases, you have to track sessions from data streams of events. For doing such sessionization, you will have to save arbitrary types of data as state, and perform arbitrary operations on the state using the data stream events in every trigger. Since Spark 2.2, this can be done using the operation `mapGroupsWithState` and the more powerful operation `flatMapGroupsWithState`. Both operations allow you to apply user-defined code on grouped Datasets to update user-defined state. For more concrete details, take a look at the API documentation ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.GroupState)/[Java](api/java/org/apache/spark/sql/streaming/GroupState.html)) and the examples ([Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala)/[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java)). @@ -1799,8 +1828,16 @@ Here are the details of all the sinks in Spark. - + + + + + + + + + @@ -1989,22 +2026,214 @@ head(sql("select * from aggregates")) -##### Using Foreach -The `foreach` operation allows arbitrary operations to be computed on the output data. As of Spark 2.1, this is available only for Scala and Java. To use this, you will have to implement the interface `ForeachWriter` -([Scala](api/scala/index.html#org.apache.spark.sql.ForeachWriter)/[Java](api/java/org/apache/spark/sql/ForeachWriter.html) docs), -which has methods that get called whenever there is a sequence of rows generated as output after a trigger. Note the following important points. +##### Using Foreach and ForeachBatch +The `foreach` and `foreachBatch` operations allow you to apply arbitrary operations and writing +logic on the output of a streaming query. They have slightly different use cases - while `foreach` +allows custom write logic on every row, `foreachBatch` allows arbitrary operations +and custom logic on the output of each micro-batch. Let's understand their usages in more detail. + +###### ForeachBatch +`foreachBatch(...)` allows you to specify a function that is executed on +the output data of every micro-batch of a streaming query. Since Spark 2.4, this is supported in Scala, Java and Python. +It takes two parameters: a DataFrame or Dataset that has the output data of a micro-batch and the unique ID of the micro-batch. + +
      +
      + +{% highlight scala %} +streamingDF.writeStream.foreachBatch { (batchDF: DataFrame, batchId: Long) => + // Transform and write batchDF +}.start() +{% endhighlight %} + +
      +
      + +{% highlight java %} +streamingDatasetOfString.writeStream().foreachBatch( + new VoidFunction2, Long> { + public void call(Dataset dataset, Long batchId) { + // Transform and write batchDF + } + } +).start(); +{% endhighlight %} + +
      +
      + +{% highlight python %} +def foreach_batch_function(df, epoch_id): + # Transform and write batchDF + pass + +streamingDF.writeStream.foreachBatch(foreach_batch_function).start() +{% endhighlight %} + +
      +
      +R is not yet supported. +
      +
      + +With `foreachBatch`, you can do the following. + +- **Reuse existing batch data sources** - For many storage systems, there may not be a streaming sink available yet, + but there may already exist a data writer for batch queries. Using `foreachBatch`, you can use the batch + data writers on the output of each micro-batch. +- **Write to multiple locations** - If you want to write the output of a streaming query to multiple locations, + then you can simply write the output DataFrame/Dataset multiple times. However, each attempt to write can + cause the output data to be recomputed (including possible re-reading of the input data). To avoid recomputations, + you should cache the output DataFrame/Dataset, write it to multiple locations, and then uncache it. Here is an outline. + + streamingDF.writeStream.foreachBatch { (batchDF: DataFrame, batchId: Long) => + batchDF.persist() + batchDF.write.format(...).save(...) // location 1 + batchDF.write.format(...).save(...) // location 2 + batchDF.unpersist() + } + +- **Apply additional DataFrame operations** - Many DataFrame and Dataset operations are not supported + in streaming DataFrames because Spark does not support generating incremental plans in those cases. + Using `foreachBatch`, you can apply some of these operations on each micro-batch output. However, you will have to reason about the end-to-end semantics of doing that operation yourself. + +**Note:** +- By default, `foreachBatch` provides only at-least-once write guarantees. However, you can use the + batchId provided to the function as way to deduplicate the output and get an exactly-once guarantee. +- `foreachBatch` does not work with the continuous processing mode as it fundamentally relies on the + micro-batch execution of a streaming query. If you write data in the continuous mode, use `foreach` instead. + + +###### Foreach +If `foreachBatch` is not an option (for example, corresponding batch data writer does not exist, or +continuous processing mode), then you can express you custom writer logic using `foreach`. +Specifically, you can express the data writing logic by dividing it into three methods: `open`, `process`, and `close`. +Since Spark 2.4, `foreach` is available in Scala, Java and Python. + +
      +
      + +In Scala, you have to extend the class `ForeachWriter` ([docs](api/scala/index.html#org.apache.spark.sql.ForeachWriter)). + +{% highlight scala %} +streamingDatasetOfString.writeStream.foreach( + new ForeachWriter[String] { + + def open(partitionId: Long, version: Long): Boolean = { + // Open connection + } + + def process(record: String): Unit = { + // Write string to connection + } + + def close(errorOrNull: Throwable): Unit = { + // Close the connection + } + } +).start() +{% endhighlight %} + +
      +
      + +In Java, you have to extend the class `ForeachWriter` ([docs](api/java/org/apache/spark/sql/ForeachWriter.html)). +{% highlight java %} +streamingDatasetOfString.writeStream().foreach( + new ForeachWriter[String] { + + @Override public boolean open(long partitionId, long version) { + // Open connection + } + + @Override public void process(String record) { + // Write string to connection + } + + @Override public void close(Throwable errorOrNull) { + // Close the connection + } + } +).start(); + +{% endhighlight %} + +
      +
      + +In Python, you can invoke foreach in two ways: in a function or in an object. +The function offers a simple way to express your processing logic but does not allow you to +deduplicate generated data when failures cause reprocessing of some input data. +For that situation you must specify the processing logic in an object. + +1. The function takes a row as input. + + {% highlight python %} + def process_row(row): + # Write row to storage + pass + + query = streamingDF.writeStream.foreach(process_row).start() + {% endhighlight %} + +2. The object has a process method and optional open and close methods: + + {% highlight python %} + class ForeachWriter: + def open(self, partition_id, epoch_id): + # Open connection. This method is optional in Python. + pass + + def process(self, row): + # Write row to connection. This method is NOT optional in Python. + pass + + def close(self, error): + # Close the connection. This method in optional in Python. + pass + + query = streamingDF.writeStream.foreach(ForeachWriter()).start() + {% endhighlight %} + +
      +
      +R is not yet supported. +
      +
      + + +**Execution semantics** +When the streaming query is started, Spark calls the function or the object’s methods in the following way: + +- A single copy of this object is responsible for all the data generated by a single task in a query. + In other words, one instance is responsible for processing one partition of the data generated in a distributed manner. + +- This object must be serializable, because each task will get a fresh serialized-deserialized copy + of the provided object. Hence, it is strongly recommended that any initialization for writing data + (for example. opening a connection or starting a transaction) is done after the open() method has + been called, which signifies that the task is ready to generate data. + +- The lifecycle of the methods are as follows: + + - For each partition with partition_id: -- The writer must be serializable, as it will be serialized and sent to the executors for execution. + - For each batch/epoch of streaming data with epoch_id: -- All the three methods, `open`, `process` and `close` will be called on the executors. + - Method open(partitionId, epochId) is called. -- The writer must do all the initialization (e.g. opening connections, starting a transaction, etc.) only when the `open` method is called. Be aware that, if there is any initialization in the class as soon as the object is created, then that initialization will happen in the driver (because that is where the instance is being created), which may not be what you intend. + - If open(...) returns true, for each row in the partition and batch/epoch, method process(row) is called. -- `version` and `partition` are two parameters in `open` that uniquely represent a set of rows that needs to be pushed out. `version` is a monotonically increasing id that increases with every trigger. `partition` is an id that represents a partition of the output, since the output is distributed and will be processed on multiple executors. + - Method close(error) is called with error (if any) seen while processing rows. -- `open` can use the `version` and `partition` to choose whether it needs to write the sequence of rows. Accordingly, it can return `true` (proceed with writing), or `false` (no need to write). If `false` is returned, then `process` will not be called on any row. For example, after a partial failure, some of the output partitions of the failed trigger may have already been committed to a database. Based on metadata stored in the database, the writer can identify partitions that have already been committed and accordingly return false to skip committing them again. +- The close() method (if it exists) is called if an open() method exists and returns successfully (irrespective of the return value), except if the JVM or Python process crashes in the middle. -- Whenever `open` is called, `close` will also be called (unless the JVM exits due to some error). This is true even if `open` returns false. If there is any error in processing and writing the data, `close` will be called with the error. It is your responsibility to clean up state (e.g. connections, transactions, etc.) that have been created in `open` such that there are no resource leaks. +- **Note:** The partitionId and epochId in the open() method can be used to deduplicate generated data + when failures cause reprocessing of some input data. This depends on the execution mode of the query. + If the streaming query is being executed in the micro-batch mode, then every partition represented + by a unique tuple (partition_id, epoch_id) is guaranteed to have the same data. + Hence, (partition_id, epoch_id) can be used to deduplicate and/or transactionally commit + data and achieve exactly-once guarantees. However, if the streaming query is being executed + in the continuous mode, then this guarantee does not hold and therefore should not be used for deduplication. #### Triggers The trigger settings of a streaming query defines the timing of streaming data processing, whether @@ -2709,6 +2938,78 @@ write.stream(aggDF, "memory", outputMode = "complete", checkpointLocation = "pat + +## Recovery Semantics after Changes in a Streaming Query +There are limitations on what changes in a streaming query are allowed between restarts from the +same checkpoint location. Here are a few kinds of changes that are either not allowed, or +the effect of the change is not well-defined. For all of them: + +- The term *allowed* means you can do the specified change but whether the semantics of its effect + is well-defined depends on the query and the change. + +- The term *not allowed* means you should not do the specified change as the restarted query is likely + to fail with unpredictable errors. `sdf` represents a streaming DataFrame/Dataset + generated with sparkSession.readStream. + +**Types of changes** + +- *Changes in the number or type (i.e. different source) of input sources*: This is not allowed. + +- *Changes in the parameters of input sources*: Whether this is allowed and whether the semantics + of the change are well-defined depends on the source and the query. Here are a few examples. + + - Addition/deletion/modification of rate limits is allowed: `spark.readStream.format("kafka").option("subscribe", "topic")` to `spark.readStream.format("kafka").option("subscribe", "topic").option("maxOffsetsPerTrigger", ...)` + + - Changes to subscribed topics/files is generally not allowed as the results are unpredictable: `spark.readStream.format("kafka").option("subscribe", "topic")` to `spark.readStream.format("kafka").option("subscribe", "newTopic")` + +- *Changes in the type of output sink*: Changes between a few specific combinations of sinks + are allowed. This needs to be verified on a case-by-case basis. Here are a few examples. + + - File sink to Kafka sink is allowed. Kafka will see only the new data. + + - Kafka sink to file sink is not allowed. + + - Kafka sink changed to foreach, or vice versa is allowed. + +- *Changes in the parameters of output sink*: Whether this is allowed and whether the semantics of + the change are well-defined depends on the sink and the query. Here are a few examples. + + - Changes to output directory of a file sink is not allowed: `sdf.writeStream.format("parquet").option("path", "/somePath")` to `sdf.writeStream.format("parquet").option("path", "/anotherPath")` + + - Changes to output topic is allowed: `sdf.writeStream.format("kafka").option("topic", "someTopic")` to `sdf.writeStream.format("kafka").option("topic", "anotherTopic")` + + - Changes to the user-defined foreach sink (that is, the `ForeachWriter` code) is allowed, but the semantics of the change depends on the code. + +- *Changes in projection / filter / map-like operations**: Some cases are allowed. For example: + + - Addition / deletion of filters is allowed: `sdf.selectExpr("a")` to `sdf.where(...).selectExpr("a").filter(...)`. + + - Changes in projections with same output schema is allowed: `sdf.selectExpr("stringColumn AS json").writeStream` to `sdf.selectExpr("anotherStringColumn AS json").writeStream` + + - Changes in projections with different output schema are conditionally allowed: `sdf.selectExpr("a").writeStream` to `sdf.selectExpr("b").writeStream` is allowed only if the output sink allows the schema change from `"a"` to `"b"`. + +- *Changes in stateful operations*: Some operations in streaming queries need to maintain + state data in order to continuously update the result. Structured Streaming automatically checkpoints + the state data to fault-tolerant storage (for example, HDFS, AWS S3, Azure Blob storage) and restores it after restart. + However, this assumes that the schema of the state data remains same across restarts. This means that + *any changes (that is, additions, deletions, or schema modifications) to the stateful operations of a streaming query are not allowed between restarts*. + Here is the list of stateful operations whose schema should not be changed between restarts in order to ensure state recovery: + + - *Streaming aggregation*: For example, `sdf.groupBy("a").agg(...)`. Any change in number or type of grouping keys or aggregates is not allowed. + + - *Streaming deduplication*: For example, `sdf.dropDuplicates("a")`. Any change in number or type of grouping keys or aggregates is not allowed. + + - *Stream-stream join*: For example, `sdf1.join(sdf2, ...)` (i.e. both inputs are generated with `sparkSession.readStream`). Changes + in the schema or equi-joining columns are not allowed. Changes in join type (outer or inner) not allowed. Other changes in the join condition are ill-defined. + + - *Arbitrary stateful operation*: For example, `sdf.groupByKey(...).mapGroupsWithState(...)` or `sdf.groupByKey(...).flatMapGroupsWithState(...)`. + Any change to the schema of the user-defined state and the type of timeout is not allowed. + Any change within the user-defined state-mapping function are allowed, but the semantic effect of the change depends on the user-defined logic. + If you really want to support state schema changes, then you can explicitly encode/decode your complex state data + structures into bytes using an encoding/decoding scheme that supports schema migration. For example, + if you save your state as Avro-encoded bytes, then you are free to change the Avro-state-schema between query + restarts as the binary state will always be restored successfully. + # Continuous Processing ## [Experimental] {:.no_toc} From f3fed28230e4e5e08d182715e8cf901daf8f3b73 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 9 Oct 2018 07:45:02 +0800 Subject: [PATCH 1799/2461] [SPARK-25659][PYTHON][TEST] Test type inference specification for createDataFrame in PySpark ## What changes were proposed in this pull request? This PR proposes to specify type inference and simple e2e tests. Looks we are not cleanly testing those logics. For instance, see https://github.com/apache/spark/blob/08c76b5d39127ae207d9d1fff99c2551e6ce2581/python/pyspark/sql/types.py#L894-L905 Looks we intended to support datetime.time and None for type inference too but it does not work: ``` >>> spark.createDataFrame([[datetime.time()]]) Traceback (most recent call last): File "", line 1, in File "/.../spark/python/pyspark/sql/session.py", line 751, in createDataFrame rdd, schema = self._createFromLocal(map(prepare, data), schema) File "/.../spark/python/pyspark/sql/session.py", line 432, in _createFromLocal data = [schema.toInternal(row) for row in data] File "/.../spark/python/pyspark/sql/types.py", line 604, in toInternal for f, v, c in zip(self.fields, obj, self._needConversion)) File "/.../spark/python/pyspark/sql/types.py", line 604, in for f, v, c in zip(self.fields, obj, self._needConversion)) File "/.../spark/python/pyspark/sql/types.py", line 442, in toInternal return self.dataType.toInternal(obj) File "/.../spark/python/pyspark/sql/types.py", line 193, in toInternal else time.mktime(dt.timetuple())) AttributeError: 'datetime.time' object has no attribute 'timetuple' >>> spark.createDataFrame([[None]]) Traceback (most recent call last): File "", line 1, in File "/.../spark/python/pyspark/sql/session.py", line 751, in createDataFrame rdd, schema = self._createFromLocal(map(prepare, data), schema) File "/.../spark/python/pyspark/sql/session.py", line 419, in _createFromLocal struct = self._inferSchemaFromList(data, names=schema) File "/.../python/pyspark/sql/session.py", line 353, in _inferSchemaFromList raise ValueError("Some of types cannot be determined after inferring") ValueError: Some of types cannot be determined after inferring ``` ## How was this patch tested? Manual tests and unit tests were added. Closes #22653 from HyukjinKwon/SPARK-25659. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- python/pyspark/sql/tests.py | 69 +++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ac87ccddd689f..85712df5f2ad1 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1149,6 +1149,75 @@ def test_infer_schema(self): result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'") self.assertEqual(1, result.head()[0]) + def test_infer_schema_specification(self): + from decimal import Decimal + + class A(object): + def __init__(self): + self.a = 1 + + data = [ + True, + 1, + "a", + u"a", + datetime.date(1970, 1, 1), + datetime.datetime(1970, 1, 1, 0, 0), + 1.0, + array.array("d", [1]), + [1], + (1, ), + {"a": 1}, + bytearray(1), + Decimal(1), + Row(a=1), + Row("a")(1), + A(), + ] + + df = self.spark.createDataFrame([data]) + actual = list(map(lambda x: x.dataType.simpleString(), df.schema)) + expected = [ + 'boolean', + 'bigint', + 'string', + 'string', + 'date', + 'timestamp', + 'double', + 'array', + 'array', + 'struct<_1:bigint>', + 'map', + 'binary', + 'decimal(38,18)', + 'struct', + 'struct', + 'struct', + ] + self.assertEqual(actual, expected) + + actual = list(df.first()) + expected = [ + True, + 1, + 'a', + u"a", + datetime.date(1970, 1, 1), + datetime.datetime(1970, 1, 1, 0, 0), + 1.0, + [1.0], + [1], + Row(_1=1), + {"a": 1}, + bytearray(b'\x00'), + Decimal('1.000000000000000000'), + Row(a=1), + Row(a=1), + Row(a=1), + ] + self.assertEqual(actual, expected) + def test_infer_schema_not_enough_names(self): df = self.spark.createDataFrame([["a", "b"]], ["col1"]) self.assertEqual(df.columns, ['col1', '_2']) From a4b14a9cf828572829ad74743e68a06eb376ba28 Mon Sep 17 00:00:00 2001 From: Shahid Date: Mon, 8 Oct 2018 19:07:05 -0500 Subject: [PATCH 1800/2461] [SPARK-25623][SPARK-25624][SPARK-25625][TEST] Reduce test time of LogisticRegressionSuite ...with intercept with L1 regularization ## What changes were proposed in this pull request? In the test, "multinomial logistic regression with intercept with L1 regularization" in the "LogisticRegressionSuite", taking more than a minute due to training of 2 logistic regression model. However after analysing the training cost over iteration, we can reduce the computation time by 50%. Training cost vs iteration for model 1 ![image](https://user-images.githubusercontent.com/23054875/46573805-ddab7680-c9b7-11e8-9ee9-63a99d498475.png) So, model1 is converging after iteration 150. Training cost vs iteration for model 2 ![image](https://user-images.githubusercontent.com/23054875/46573790-b3f24f80-c9b7-11e8-89c0-81045ad647cb.png) After around 100 iteration, model2 is converging. So, if we give maximum iteration for model1 and model2 as 175 and 125 respectively, we can reduce the computation time by half. ## How was this patch tested? Computation time in local setup : Before change: ~53 sec After change: ~26 sec Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22659 from shahidki31/SPARK-25623. Authored-by: Shahid Signed-off-by: Sean Owen --- .../LogisticRegressionSuite.scala | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 75c2aeb146786..84c10e2f85c81 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -79,7 +79,9 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { generateMultinomialLogisticInput(coefficients, xMean, xVariance, addIntercept = true, nPoints, seed) - sc.parallelize(testData, 4).toDF().withColumn("weight", rand(seed)) + val df = sc.parallelize(testData, 4).toDF().withColumn("weight", rand(seed)) + df.cache() + df } multinomialDataset = { @@ -1130,9 +1132,9 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { } test("binary logistic regression with intercept with ElasticNet regularization") { - val trainer1 = (new LogisticRegression).setFitIntercept(true).setMaxIter(200) + val trainer1 = (new LogisticRegression).setFitIntercept(true).setMaxIter(120) .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true).setWeightCol("weight") - val trainer2 = (new LogisticRegression).setFitIntercept(true) + val trainer2 = (new LogisticRegression).setFitIntercept(true).setMaxIter(30) .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false).setWeightCol("weight") val model1 = trainer1.fit(binaryDataset) @@ -1174,7 +1176,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { val coefficientsR = Vectors.dense(0.0, 0.0, -0.1846038, -0.0559614) val interceptR = 0.5024256 - assert(model1.intercept ~== interceptRStd relTol 6E-3) + assert(model1.intercept ~== interceptRStd relTol 6E-2) assert(model1.coefficients ~== coefficientsRStd absTol 5E-3) assert(model2.intercept ~== interceptR relTol 6E-3) assert(model2.coefficients ~= coefficientsR absTol 1E-3) @@ -1677,10 +1679,10 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { // use tighter constraints because OWL-QN solver takes longer to converge val trainer1 = (new LogisticRegression).setFitIntercept(true) .setElasticNetParam(1.0).setRegParam(0.05).setStandardization(true) - .setMaxIter(300).setTol(1e-10).setWeightCol("weight") + .setMaxIter(160).setTol(1e-10).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(true) .setElasticNetParam(1.0).setRegParam(0.05).setStandardization(false) - .setMaxIter(300).setTol(1e-10).setWeightCol("weight") + .setMaxIter(110).setTol(1e-10).setWeightCol("weight") val model1 = trainer1.fit(multinomialDataset) val model2 = trainer2.fit(multinomialDataset) @@ -1767,7 +1769,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { 0.0, 0.0, 0.0, 0.0), isTransposed = true) val interceptsR = Vectors.dense(-0.44215290, 0.76308326, -0.3209304) - assert(model1.coefficientMatrix ~== coefficientsRStd absTol 0.02) + assert(model1.coefficientMatrix ~== coefficientsRStd absTol 0.05) assert(model1.interceptVector ~== interceptsRStd relTol 0.1) assert(model1.interceptVector.toArray.sum ~== 0.0 absTol eps) assert(model2.coefficientMatrix ~== coefficientsR absTol 0.02) @@ -2145,10 +2147,10 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { test("multinomial logistic regression with intercept with elasticnet regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") .setElasticNetParam(0.5).setRegParam(0.1).setStandardization(true) - .setMaxIter(300).setTol(1e-10) + .setMaxIter(220).setTol(1e-10) val trainer2 = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") .setElasticNetParam(0.5).setRegParam(0.1).setStandardization(false) - .setMaxIter(300).setTol(1e-10) + .setMaxIter(90).setTol(1e-10) val model1 = trainer1.fit(multinomialDataset) val model2 = trainer2.fit(multinomialDataset) @@ -2234,8 +2236,8 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { 0.0, 0.0, 0.0, 0.0), isTransposed = true) val interceptsR = Vectors.dense(-0.38857157, 0.62492165, -0.2363501) - assert(model1.coefficientMatrix ~== coefficientsRStd absTol 0.01) - assert(model1.interceptVector ~== interceptsRStd absTol 0.01) + assert(model1.coefficientMatrix ~== coefficientsRStd absTol 0.05) + assert(model1.interceptVector ~== interceptsRStd absTol 0.1) assert(model1.interceptVector.toArray.sum ~== 0.0 absTol eps) assert(model2.coefficientMatrix ~== coefficientsR absTol 0.01) assert(model2.interceptVector ~== interceptsR absTol 0.01) @@ -2245,10 +2247,10 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { test("multinomial logistic regression without intercept with elasticnet regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false).setWeightCol("weight") .setElasticNetParam(0.5).setRegParam(0.1).setStandardization(true) - .setMaxIter(300).setTol(1e-10) + .setMaxIter(75).setTol(1e-10) val trainer2 = (new LogisticRegression).setFitIntercept(false).setWeightCol("weight") .setElasticNetParam(0.5).setRegParam(0.1).setStandardization(false) - .setMaxIter(300).setTol(1e-10) + .setMaxIter(50).setTol(1e-10) val model1 = trainer1.fit(multinomialDataset) val model2 = trainer2.fit(multinomialDataset) From 46fe40838aa682a7073dd6f1373518b0c8498a94 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 9 Oct 2018 14:35:00 +0800 Subject: [PATCH 1801/2461] [SPARK-25669][SQL] Check CSV header only when it exists ## What changes were proposed in this pull request? Currently the first row of dataset of CSV strings is compared to field names of user specified or inferred schema independently of presence of CSV header. It causes false-positive error messages. For example, parsing `"1,2"` outputs the error: ```java java.lang.IllegalArgumentException: CSV header does not conform to the schema. Header: 1, 2 Schema: _c0, _c1 Expected: _c0 but found: 1 ``` In the PR, I propose: - Checking CSV header only when it exists - Filter header from the input dataset only if it exists ## How was this patch tested? Added a test to `CSVSuite` which reproduces the issue. Closes #22656 from MaxGekk/inferred-header-check. Authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- .../main/scala/org/apache/spark/sql/DataFrameReader.scala | 7 +++++-- .../spark/sql/execution/datasources/csv/CSVSuite.scala | 6 ++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index fe69f252d43e0..72694463cedb5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -505,7 +505,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val actualSchema = StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) - val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine => + val linesWithoutHeader = if (parsedOptions.headerFlag && maybeFirstLine.isDefined) { + val firstLine = maybeFirstLine.get val parser = new CsvParser(parsedOptions.asParserSettings) val columnNames = parser.parseLine(firstLine) CSVDataSource.checkHeaderColumnNames( @@ -515,7 +516,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { parsedOptions.enforceSchema, sparkSession.sessionState.conf.caseSensitiveAnalysis) filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions)) - }.getOrElse(filteredLines.rdd) + } else { + filteredLines.rdd + } val parsed = linesWithoutHeader.mapPartitions { iter => val rawParser = new UnivocityParser(actualSchema, parsedOptions) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index f70df0bcecde7..5d4746cf90b3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1820,4 +1820,10 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te checkAnswer(spark.read.option("multiLine", true).schema(schema).csv(input), Row(null)) assert(spark.read.csv(input).collect().toSet == Set(Row())) } + + test("field names of inferred schema shouldn't compare to the first row") { + val input = Seq("1,2").toDS() + val df = spark.read.option("enforceSchema", false).csv(input) + checkAnswer(df, Row("1", "2")) + } } From e3133f4abf1cd5667abe5f0d05fa0af0df3033ae Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 9 Oct 2018 16:46:23 +0900 Subject: [PATCH 1802/2461] [SPARK-25497][SQL] Limit operation within whole stage codegen should not consume all the inputs ## What changes were proposed in this pull request? This PR is inspired by https://github.com/apache/spark/pull/22524, but proposes a safer fix. The current limit whole stage codegen has 2 problems: 1. It's only applied to `InputAdapter`, many leaf nodes can't stop earlier w.r.t. limit. 2. It needs to override a method, which will break if we have more than one limit in the whole-stage. The first problem is easy to fix, just figure out which nodes can stop earlier w.r.t. limit, and update them. This PR updates `RangeExec`, `ColumnarBatchScan`, `SortExec`, `HashAggregateExec`. The second problem is hard to fix. This PR proposes to propagate the limit counter variable name upstream, so that the upstream leaf/blocking nodes can check the limit counter and quit the loop earlier. For better performance, the implementation here follows `CodegenSupport.needStopCheck`, so that we only codegen the check only if there is limit in the query. For columnar node like range, we check the limit counter per-batch instead of per-row, to make the inner loop tight and fast. Why this is safer? 1. the leaf/blocking nodes don't have to check the limit counter and stop earlier. It's only for performance. (this is same as before) 2. The blocking operators can stop propagating the limit counter name, because the counter of limit after blocking operators will never increase, before blocking operators consume all the data from upstream operators. So the upstream operators don't care about limit after blocking operators. This is also for performance only, it's OK if we forget to do it for some new blocking operators. ## How was this patch tested? a new test Closes #22630 from cloud-fan/limit. Authored-by: Wenchen Fan Signed-off-by: Kazuaki Ishizaki --- .../sql/execution/BufferedRowIterator.java | 10 -- .../sql/execution/ColumnarBatchScan.scala | 4 +- .../apache/spark/sql/execution/SortExec.scala | 12 +- .../sql/execution/WholeStageCodegenExec.scala | 59 +++++++++- .../aggregate/HashAggregateExec.scala | 22 +--- .../execution/basicPhysicalOperators.scala | 91 +++++++++----- .../apache/spark/sql/execution/limit.scala | 31 +++-- .../execution/metric/SQLMetricsSuite.scala | 111 +++++++++++------- 8 files changed, 215 insertions(+), 125 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java index 74c9c05992719..3d0511b7ba838 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -73,16 +73,6 @@ public void append(InternalRow row) { currentRows.add(row); } - /** - * Returns whether this iterator should stop fetching next row from [[CodegenSupport#inputRDDs]]. - * - * If it returns true, the caller should exit the loop that [[InputAdapter]] generates. - * This interface is mainly used to limit the number of input rows. - */ - public boolean stopEarly() { - return false; - } - /** * Returns whether `processNext()` should stop processing next row from `input` or not. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 48abad9078650..9f6b593360802 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -136,7 +136,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { |if ($batch == null) { | $nextBatchFuncName(); |} - |while ($batch != null) { + |while ($limitNotReachedCond $batch != null) { | int $numRows = $batch.numRows(); | int $localEnd = $numRows - $idx; | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { @@ -166,7 +166,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } val inputRow = if (needsUnsafeRowConversion) null else row s""" - |while ($input.hasNext()) { + |while ($limitNotReachedCond $input.hasNext()) { | InternalRow $row = (InternalRow) $input.next(); | $numOutputRows.add(1); | ${consume(ctx, outputVars, inputRow).trim} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index 0dc16ba5ce281..f1470e45f1292 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -39,7 +39,7 @@ case class SortExec( global: Boolean, child: SparkPlan, testSpillFrequency: Int = 0) - extends UnaryExecNode with CodegenSupport { + extends UnaryExecNode with BlockingOperatorWithCodegen { override def output: Seq[Attribute] = child.output @@ -124,14 +124,6 @@ case class SortExec( // Name of sorter variable used in codegen. private var sorterVariable: String = _ - // The result rows come from the sort buffer, so this operator doesn't need to copy its result - // even if its child does. - override def needCopyResult: Boolean = false - - // Sort operator always consumes all the input rows before outputting any result, so we don't need - // a stop check before sorting. - override def needStopCheck: Boolean = false - override protected def doProduce(ctx: CodegenContext): String = { val needToSort = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "needToSort", v => s"$v = true;") @@ -172,7 +164,7 @@ case class SortExec( | $needToSort = false; | } | - | while ($sortedIterator.hasNext()) { + | while ($limitNotReachedCond $sortedIterator.hasNext()) { | UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next(); | ${consume(ctx, null, outputRow)} | if (shouldStop()) return; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 1fc4de9e56015..f5aee627fe901 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -345,6 +345,61 @@ trait CodegenSupport extends SparkPlan { * don't require shouldStop() in the loop of producing rows. */ def needStopCheck: Boolean = parent.needStopCheck + + /** + * A sequence of checks which evaluate to true if the downstream Limit operators have not received + * enough records and reached the limit. If current node is a data producing node, it can leverage + * this information to stop producing data and complete the data flow earlier. Common data + * producing nodes are leaf nodes like Range and Scan, and blocking nodes like Sort and Aggregate. + * These checks should be put into the loop condition of the data producing loop. + */ + def limitNotReachedChecks: Seq[String] = parent.limitNotReachedChecks + + /** + * A helper method to generate the data producing loop condition according to the + * limit-not-reached checks. + */ + final def limitNotReachedCond: String = { + // InputAdapter is also a leaf node. + val isLeafNode = children.isEmpty || this.isInstanceOf[InputAdapter] + if (!isLeafNode && !this.isInstanceOf[BlockingOperatorWithCodegen]) { + val errMsg = "Only leaf nodes and blocking nodes need to call 'limitNotReachedCond' " + + "in its data producing loop." + if (Utils.isTesting) { + throw new IllegalStateException(errMsg) + } else { + logWarning(s"[BUG] $errMsg Please open a JIRA ticket to report it.") + } + } + if (parent.limitNotReachedChecks.isEmpty) { + "" + } else { + parent.limitNotReachedChecks.mkString("", " && ", " &&") + } + } +} + +/** + * A special kind of operators which support whole stage codegen. Blocking means these operators + * will consume all the inputs first, before producing output. Typical blocking operators are + * sort and aggregate. + */ +trait BlockingOperatorWithCodegen extends CodegenSupport { + + // Blocking operators usually have some kind of buffer to keep the data before producing them, so + // then don't to copy its result even if its child does. + override def needCopyResult: Boolean = false + + // Blocking operators always consume all the input first, so its upstream operators don't need a + // stop check. + override def needStopCheck: Boolean = false + + // Blocking operators need to consume all the inputs before producing any output. This means, + // Limit operator after this blocking operator will never reach its limit during the execution of + // this blocking operator's upstream operators. Here we override this method to return Nil, so + // that upstream operators will not generate useless conditions (which are always evaluated to + // false) for the Limit operators after this blocking operator. + override def limitNotReachedChecks: Seq[String] = Nil } @@ -381,7 +436,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp forceInline = true) val row = ctx.freshName("row") s""" - | while ($input.hasNext() && !stopEarly()) { + | while ($limitNotReachedCond $input.hasNext()) { | InternalRow $row = (InternalRow) $input.next(); | ${consume(ctx, null, row).trim} | if (shouldStop()) return; @@ -677,6 +732,8 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) override def needStopCheck: Boolean = true + override def limitNotReachedChecks: Seq[String] = Nil + override protected def otherCopyArgs: Seq[AnyRef] = Seq(codegenStageId.asInstanceOf[Integer]) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 98adba50b2973..6155ec9d30db4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -45,7 +45,7 @@ case class HashAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode with CodegenSupport { + extends UnaryExecNode with BlockingOperatorWithCodegen { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) @@ -151,14 +151,6 @@ case class HashAggregateExec( child.asInstanceOf[CodegenSupport].inputRDDs() } - // The result rows come from the aggregate buffer, or a single row(no grouping keys), so this - // operator doesn't need to copy its result even if its child does. - override def needCopyResult: Boolean = false - - // Aggregate operator always consumes all the input rows before outputting any result, so we - // don't need a stop check before aggregating. - override def needStopCheck: Boolean = false - protected override def doProduce(ctx: CodegenContext): String = { if (groupingExpressions.isEmpty) { doProduceWithoutKeys(ctx) @@ -705,13 +697,16 @@ case class HashAggregateExec( def outputFromRegularHashMap: String = { s""" - |while ($iterTerm.next()) { + |while ($limitNotReachedCond $iterTerm.next()) { | UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); | UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); | $outputFunc($keyTerm, $bufferTerm); - | | if (shouldStop()) return; |} + |$iterTerm.close(); + |if ($sorterTerm == null) { + | $hashMapTerm.free(); + |} """.stripMargin } @@ -728,11 +723,6 @@ case class HashAggregateExec( // output the result $outputFromFastHashMap $outputFromRegularHashMap - - $iterTerm.close(); - if ($sorterTerm == null) { - $hashMapTerm.free(); - } """ } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 222a1b8bc7301..4cd2e788ade07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -378,7 +378,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val numOutput = metricTerm(ctx, "numOutputRows") val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange") - val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number") + val nextIndex = ctx.addMutableState(CodeGenerator.JAVA_LONG, "nextIndex") val value = ctx.freshName("value") val ev = ExprCode.forNonNullValue(JavaCode.variable(value, LongType)) @@ -397,7 +397,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // within a batch, while the code in the outer loop is setting batch parameters and updating // the metrics. - // Once number == batchEnd, it's time to progress to the next batch. + // Once nextIndex == batchEnd, it's time to progress to the next batch. val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd") // How many values should still be generated by this range operator. @@ -421,13 +421,13 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) | | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $number = Long.MAX_VALUE; + | $nextIndex = Long.MAX_VALUE; | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $number = Long.MIN_VALUE; + | $nextIndex = Long.MIN_VALUE; | } else { - | $number = st.longValue(); + | $nextIndex = st.longValue(); | } - | $batchEnd = $number; + | $batchEnd = $nextIndex; | | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) | .multiply(step).add(start); @@ -440,7 +440,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) | } | | $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract( - | $BigInt.valueOf($number)); + | $BigInt.valueOf($nextIndex)); | $numElementsTodo = startToEnd.divide(step).longValue(); | if ($numElementsTodo < 0) { | $numElementsTodo = 0; @@ -452,12 +452,42 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val localIdx = ctx.freshName("localIdx") val localEnd = ctx.freshName("localEnd") - val range = ctx.freshName("range") val shouldStop = if (parent.needStopCheck) { - s"if (shouldStop()) { $number = $value + ${step}L; return; }" + s"if (shouldStop()) { $nextIndex = $value + ${step}L; return; }" } else { "// shouldStop check is eliminated" } + val loopCondition = if (limitNotReachedChecks.isEmpty) { + "true" + } else { + limitNotReachedChecks.mkString(" && ") + } + + // An overview of the Range processing. + // + // For each partition, the Range task needs to produce records from partition start(inclusive) + // to end(exclusive). For better performance, we separate the partition range into batches, and + // use 2 loops to produce data. The outer while loop is used to iterate batches, and the inner + // for loop is used to iterate records inside a batch. + // + // `nextIndex` tracks the index of the next record that is going to be consumed, initialized + // with partition start. `batchEnd` tracks the end index of the current batch, initialized + // with `nextIndex`. In the outer loop, we first check if `nextIndex == batchEnd`. If it's true, + // it means the current batch is fully consumed, and we will update `batchEnd` to process the + // next batch. If `batchEnd` reaches partition end, exit the outer loop. Finally we enter the + // inner loop. Note that, when we enter inner loop, `nextIndex` must be different from + // `batchEnd`, otherwise we already exit the outer loop. + // + // The inner loop iterates from 0 to `localEnd`, which is calculated by + // `(batchEnd - nextIndex) / step`. Since `batchEnd` is increased by `nextBatchTodo * step` in + // the outer loop, and initialized with `nextIndex`, so `batchEnd - nextIndex` is always + // divisible by `step`. The `nextIndex` is increased by `step` during each iteration, and ends + // up being equal to `batchEnd` when the inner loop finishes. + // + // The inner loop can be interrupted, if the query has produced at least one result row, so that + // we don't buffer too many result rows and waste memory. It's ok to interrupt the inner loop, + // because `nextIndex` will be updated before interrupting. + s""" | // initialize Range | if (!$initTerm) { @@ -465,33 +495,30 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) | $initRangeFuncName(partitionIndex); | } | - | while (true) { - | long $range = $batchEnd - $number; - | if ($range != 0L) { - | int $localEnd = (int)($range / ${step}L); - | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { - | long $value = ((long)$localIdx * ${step}L) + $number; - | ${consume(ctx, Seq(ev))} - | $shouldStop + | while ($loopCondition) { + | if ($nextIndex == $batchEnd) { + | long $nextBatchTodo; + | if ($numElementsTodo > ${batchSize}L) { + | $nextBatchTodo = ${batchSize}L; + | $numElementsTodo -= ${batchSize}L; + | } else { + | $nextBatchTodo = $numElementsTodo; + | $numElementsTodo = 0; + | if ($nextBatchTodo == 0) break; | } - | $number = $batchEnd; + | $numOutput.add($nextBatchTodo); + | $inputMetrics.incRecordsRead($nextBatchTodo); + | $batchEnd += $nextBatchTodo * ${step}L; | } | - | $taskContext.killTaskIfInterrupted(); - | - | long $nextBatchTodo; - | if ($numElementsTodo > ${batchSize}L) { - | $nextBatchTodo = ${batchSize}L; - | $numElementsTodo -= ${batchSize}L; - | } else { - | $nextBatchTodo = $numElementsTodo; - | $numElementsTodo = 0; - | if ($nextBatchTodo == 0) break; + | int $localEnd = (int)(($batchEnd - $nextIndex) / ${step}L); + | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { + | long $value = ((long)$localIdx * ${step}L) + $nextIndex; + | ${consume(ctx, Seq(ev))} + | $shouldStop | } - | $numOutput.add($nextBatchTodo); - | $inputMetrics.incRecordsRead($nextBatchTodo); - | - | $batchEnd += $nextBatchTodo * ${step}L; + | $nextIndex = $batchEnd; + | $taskContext.killTaskIfInterrupted(); | } """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 66bcda8913738..9bfe1a79fc1e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -46,6 +46,15 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode } } +object BaseLimitExec { + private val curId = new java.util.concurrent.atomic.AtomicInteger() + + def newLimitCountTerm(): String = { + val id = curId.getAndIncrement() + s"_limit_counter_$id" + } +} + /** * Helper trait which defines methods that are shared by both * [[LocalLimitExec]] and [[GlobalLimitExec]]. @@ -66,27 +75,25 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { // to the parent operator. override def usedInputs: AttributeSet = AttributeSet.empty + private lazy val countTerm = BaseLimitExec.newLimitCountTerm() + + override lazy val limitNotReachedChecks: Seq[String] = { + s"$countTerm < $limit" +: super.limitNotReachedChecks + } + protected override def doProduce(ctx: CodegenContext): String = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val stopEarly = - ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false - - ctx.addNewFunction("stopEarly", s""" - @Override - protected boolean stopEarly() { - return $stopEarly; - } - """, inlineToOuterClass = true) - val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "count") // init as count = 0 + // The counter name is already obtained by the upstream operators via `limitNotReachedChecks`. + // Here we have to inline it to not change its name. This is fine as we won't have many limit + // operators in one query. + ctx.addMutableState(CodeGenerator.JAVA_INT, countTerm, forceInline = true, useFreshName = false) s""" | if ($countTerm < $limit) { | $countTerm += 1; | ${consume(ctx, input)} - | } else { - | $stopEarly = true; | } """.stripMargin } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 085a445488480..81db3e137964d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -19,12 +19,15 @@ package org.apache.spark.sql.execution.metric import java.io.File +import scala.reflect.{classTag, ClassTag} import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.execution.{FilterExec, RangeExec, SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -518,56 +521,80 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared testMetricsDynamicPartition("parquet", "parquet", "t1") } + private def collectNodeWithinWholeStage[T <: SparkPlan : ClassTag](plan: SparkPlan): Seq[T] = { + val stages = plan.collect { + case w: WholeStageCodegenExec => w + } + assert(stages.length == 1, "The query plan should have one and only one whole-stage.") + + val cls = classTag[T].runtimeClass + stages.head.collect { + case n if n.getClass == cls => n.asInstanceOf[T] + } + } + test("SPARK-25602: SparkPlan.getByteArrayRdd should not consume the input when not necessary") { def checkFilterAndRangeMetrics( df: DataFrame, filterNumOutputs: Int, rangeNumOutputs: Int): Unit = { - var filter: FilterExec = null - var range: RangeExec = null - val collectFilterAndRange: SparkPlan => Unit = { - case f: FilterExec => - assert(filter == null, "the query should only have one Filter") - filter = f - case r: RangeExec => - assert(range == null, "the query should only have one Range") - range = r - case _ => - } - if (SQLConf.get.wholeStageEnabled) { - df.queryExecution.executedPlan.foreach { - case w: WholeStageCodegenExec => - w.child.foreach(collectFilterAndRange) - case _ => - } - } else { - df.queryExecution.executedPlan.foreach(collectFilterAndRange) - } + val plan = df.queryExecution.executedPlan - assert(filter != null && range != null, "the query doesn't have Filter and Range") - assert(filter.metrics("numOutputRows").value == filterNumOutputs) - assert(range.metrics("numOutputRows").value == rangeNumOutputs) + val filters = collectNodeWithinWholeStage[FilterExec](plan) + assert(filters.length == 1, "The query plan should have one and only one Filter") + assert(filters.head.metrics("numOutputRows").value == filterNumOutputs) + + val ranges = collectNodeWithinWholeStage[RangeExec](plan) + assert(ranges.length == 1, "The query plan should have one and only one Range") + assert(ranges.head.metrics("numOutputRows").value == rangeNumOutputs) } - val df = spark.range(0, 3000, 1, 2).toDF().filter('id % 3 === 0) - val df2 = df.limit(2) - Seq(true, false).foreach { wholeStageEnabled => - withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholeStageEnabled.toString) { - df.collect() - checkFilterAndRangeMetrics(df, filterNumOutputs = 1000, rangeNumOutputs = 3000) - - df.queryExecution.executedPlan.foreach(_.resetMetrics()) - // For each partition, we get 2 rows. Then the Filter should produce 2 rows per-partition, - // and Range should produce 1000 rows (one batch) per-partition. Totally Filter produces - // 4 rows, and Range produces 2000 rows. - df.queryExecution.toRdd.mapPartitions(_.take(2)).collect() - checkFilterAndRangeMetrics(df, filterNumOutputs = 4, rangeNumOutputs = 2000) - - // Top-most limit will call `CollectLimitExec.executeCollect`, which will only run the first - // task, so totally the Filter produces 2 rows, and Range produces 1000 rows (one batch). - df2.collect() - checkFilterAndRangeMetrics(df2, filterNumOutputs = 2, rangeNumOutputs = 1000) - } + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") { + val df = spark.range(0, 3000, 1, 2).toDF().filter('id % 3 === 0) + df.collect() + checkFilterAndRangeMetrics(df, filterNumOutputs = 1000, rangeNumOutputs = 3000) + + df.queryExecution.executedPlan.foreach(_.resetMetrics()) + // For each partition, we get 2 rows. Then the Filter should produce 2 rows per-partition, + // and Range should produce 1000 rows (one batch) per-partition. Totally Filter produces + // 4 rows, and Range produces 2000 rows. + df.queryExecution.toRdd.mapPartitions(_.take(2)).collect() + checkFilterAndRangeMetrics(df, filterNumOutputs = 4, rangeNumOutputs = 2000) + + // Top-most limit will call `CollectLimitExec.executeCollect`, which will only run the first + // task, so totally the Filter produces 2 rows, and Range produces 1000 rows (one batch). + val df2 = df.limit(2) + df2.collect() + checkFilterAndRangeMetrics(df2, filterNumOutputs = 2, rangeNumOutputs = 1000) + } + } + + test("SPARK-25497: LIMIT within whole stage codegen should not consume all the inputs") { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") { + // A special query that only has one partition, so there is no shuffle and the entire query + // can be whole-stage-codegened. + val df = spark.range(0, 1500, 1, 1).limit(10).groupBy('id).count().limit(1).filter('id >= 0) + df.collect() + val plan = df.queryExecution.executedPlan + + val ranges = collectNodeWithinWholeStage[RangeExec](plan) + assert(ranges.length == 1, "The query plan should have one and only one Range") + // The Range should only produce the first batch, i.e. 1000 rows. + assert(ranges.head.metrics("numOutputRows").value == 1000) + + val aggs = collectNodeWithinWholeStage[HashAggregateExec](plan) + assert(aggs.length == 2, "The query plan should have two and only two Aggregate") + val partialAgg = aggs.filter(_.aggregateExpressions.head.mode == Partial).head + // The partial aggregate should output 10 rows, because its input is 10 rows. + assert(partialAgg.metrics("numOutputRows").value == 10) + val finalAgg = aggs.filter(_.aggregateExpressions.head.mode == Final).head + // The final aggregate should only produce 1 row, because the upstream limit only needs 1 row. + assert(finalAgg.metrics("numOutputRows").value == 1) + + val filters = collectNodeWithinWholeStage[FilterExec](plan) + assert(filters.length == 1, "The query plan should have one and only one Filter") + // The final Filter should produce 1 rows, because the input is just one row. + assert(filters.head.metrics("numOutputRows").value == 1) } } } From deb9588b2ab6596b30ab17f56c59951cabf57162 Mon Sep 17 00:00:00 2001 From: pgandhi Date: Tue, 9 Oct 2018 08:59:21 -0500 Subject: [PATCH 1803/2461] [SPARK-24851][UI] Map a Stage ID to it's Associated Job ID It would be nice to have a field in Stage Page UI which would show mapping of the current stage id to the job id's to which that stage belongs to. ## What changes were proposed in this pull request? Added a field in Stage UI to display the corresponding job id for that particular stage. ## How was this patch tested? screen shot 2018-07-25 at 1 33 07 pm Closes #21809 from pgandhi999/SPARK-24851. Authored-by: pgandhi Signed-off-by: Thomas Graves --- .../org/apache/spark/status/AppStatusStore.scala | 8 +++++--- .../apache/spark/status/api/v1/StagesResource.scala | 2 +- .../scala/org/apache/spark/ui/jobs/StagePage.scala | 13 +++++++++++-- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index e237281c552b1..9839cbb99f862 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -112,10 +112,12 @@ private[spark] class AppStatusStore( } } - def stageAttempt(stageId: Int, stageAttemptId: Int, details: Boolean = false): v1.StageData = { + def stageAttempt(stageId: Int, stageAttemptId: Int, + details: Boolean = false): (v1.StageData, Seq[Int]) = { val stageKey = Array(stageId, stageAttemptId) - val stage = store.read(classOf[StageDataWrapper], stageKey).info - if (details) stageWithDetails(stage) else stage + val stageDataWrapper = store.read(classOf[StageDataWrapper], stageKey) + val stage = if (details) stageWithDetails(stageDataWrapper.info) else stageDataWrapper.info + (stage, stageDataWrapper.jobIds.toSeq) } def taskCount(stageId: Int, stageAttemptId: Int): Long = { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala index 96249e4bfd5fa..30d52b97833e6 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala @@ -56,7 +56,7 @@ private[v1] class StagesResource extends BaseAppResource { @PathParam("stageAttemptId") stageAttemptId: Int, @QueryParam("details") @DefaultValue("true") details: Boolean): StageData = withUI { ui => try { - ui.store.stageAttempt(stageId, stageAttemptId, details = details) + ui.store.stageAttempt(stageId, stageAttemptId, details = details)._1 } catch { case _: NoSuchElementException => // Change the message depending on whether there are any attempts for the requested stage. diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 7428bbe6c5592..0f74b07a6265c 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -105,7 +105,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val stageAttemptId = parameterAttempt.toInt val stageHeader = s"Details for Stage $stageId (Attempt $stageAttemptId)" - val stageData = parent.store + val (stageData, stageJobIds) = parent.store .asOption(parent.store.stageAttempt(stageId, stageAttemptId, details = false)) .getOrElse { val content = @@ -183,6 +183,15 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We {Utils.bytesToString(stageData.diskBytesSpilled)} }} + {if (!stageJobIds.isEmpty) { +
    • + Associated Job Ids: + {stageJobIds.map(jobId => {val detailUrl = "%s/jobs/job/?id=%s".format( + UIUtils.prependBaseUri(request, parent.basePath), jobId) + {s"${jobId}"}    + })} +
    • + }} @@ -1048,7 +1057,7 @@ private[ui] object ApiHelper { } def lastStageNameAndDescription(store: AppStatusStore, job: JobData): (String, String) = { - val stage = store.asOption(store.stageAttempt(job.stageIds.max, 0)) + val stage = store.asOption(store.stageAttempt(job.stageIds.max, 0)._1) (stage.map(_.name).getOrElse(""), stage.flatMap(_.description).getOrElse(job.name)) } From 3eee9e02463e10570a29fad00823c953debd945e Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 9 Oct 2018 09:27:08 -0500 Subject: [PATCH 1804/2461] [SPARK-25535][CORE] Work around bad error handling in commons-crypto. The commons-crypto library does some questionable error handling internally, which can lead to JVM crashes if some call into native code fails and cleans up state it should not. While the library is not fixed, this change adds some workarounds in Spark code so that when an error is detected in the commons-crypto side, Spark avoids calling into the library further. Tested with existing and added unit tests. Closes #22557 from vanzin/SPARK-25535. Authored-by: Marcelo Vanzin Signed-off-by: Imran Rashid --- .../spark/network/crypto/AuthEngine.java | 95 ++++++++---- .../spark/network/crypto/TransportCipher.java | 60 +++++++- .../spark/network/crypto/AuthEngineSuite.java | 17 +++ .../spark/security/CryptoStreamUtils.scala | 137 ++++++++++++++++-- .../security/CryptoStreamUtilsSuite.scala | 37 ++++- 5 files changed, 295 insertions(+), 51 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java index 056505ef53356..64fdb32a67ada 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java @@ -159,15 +159,21 @@ public void close() throws IOException { // accurately report the errors when they happen. RuntimeException error = null; byte[] dummy = new byte[8]; - try { - doCipherOp(encryptor, dummy, true); - } catch (Exception e) { - error = new RuntimeException(e); + if (encryptor != null) { + try { + doCipherOp(Cipher.ENCRYPT_MODE, dummy, true); + } catch (Exception e) { + error = new RuntimeException(e); + } + encryptor = null; } - try { - doCipherOp(decryptor, dummy, true); - } catch (Exception e) { - error = new RuntimeException(e); + if (decryptor != null) { + try { + doCipherOp(Cipher.DECRYPT_MODE, dummy, true); + } catch (Exception e) { + error = new RuntimeException(e); + } + decryptor = null; } random.close(); @@ -189,11 +195,11 @@ byte[] rawResponse(byte[] challenge) { } private byte[] decrypt(byte[] in) throws GeneralSecurityException { - return doCipherOp(decryptor, in, false); + return doCipherOp(Cipher.DECRYPT_MODE, in, false); } private byte[] encrypt(byte[] in) throws GeneralSecurityException { - return doCipherOp(encryptor, in, false); + return doCipherOp(Cipher.ENCRYPT_MODE, in, false); } private void initializeForAuth(String cipher, byte[] nonce, SecretKeySpec key) @@ -205,11 +211,13 @@ private void initializeForAuth(String cipher, byte[] nonce, SecretKeySpec key) byte[] iv = new byte[conf.ivLength()]; System.arraycopy(nonce, 0, iv, 0, Math.min(nonce.length, iv.length)); - encryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf); - encryptor.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv)); + CryptoCipher _encryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf); + _encryptor.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv)); + this.encryptor = _encryptor; - decryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf); - decryptor.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv)); + CryptoCipher _decryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf); + _decryptor.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv)); + this.decryptor = _decryptor; } /** @@ -241,29 +249,52 @@ private SecretKeySpec generateKey(String kdf, int iterations, byte[] salt, int k return new SecretKeySpec(key.getEncoded(), conf.keyAlgorithm()); } - private byte[] doCipherOp(CryptoCipher cipher, byte[] in, boolean isFinal) + private byte[] doCipherOp(int mode, byte[] in, boolean isFinal) throws GeneralSecurityException { - Preconditions.checkState(cipher != null); + CryptoCipher cipher; + switch (mode) { + case Cipher.ENCRYPT_MODE: + cipher = encryptor; + break; + case Cipher.DECRYPT_MODE: + cipher = decryptor; + break; + default: + throw new IllegalArgumentException(String.valueOf(mode)); + } - int scale = 1; - while (true) { - int size = in.length * scale; - byte[] buffer = new byte[size]; - try { - int outSize = isFinal ? cipher.doFinal(in, 0, in.length, buffer, 0) - : cipher.update(in, 0, in.length, buffer, 0); - if (outSize != buffer.length) { - byte[] output = new byte[outSize]; - System.arraycopy(buffer, 0, output, 0, output.length); - return output; - } else { - return buffer; + Preconditions.checkState(cipher != null, "Cipher is invalid because of previous error."); + + try { + int scale = 1; + while (true) { + int size = in.length * scale; + byte[] buffer = new byte[size]; + try { + int outSize = isFinal ? cipher.doFinal(in, 0, in.length, buffer, 0) + : cipher.update(in, 0, in.length, buffer, 0); + if (outSize != buffer.length) { + byte[] output = new byte[outSize]; + System.arraycopy(buffer, 0, output, 0, output.length); + return output; + } else { + return buffer; + } + } catch (ShortBufferException e) { + // Try again with a bigger buffer. + scale *= 2; } - } catch (ShortBufferException e) { - // Try again with a bigger buffer. - scale *= 2; } + } catch (InternalError ie) { + // SPARK-25535. The commons-cryto library will throw InternalError if something goes wrong, + // and leave bad state behind in the Java wrappers, so it's not safe to use them afterwards. + if (mode == Cipher.ENCRYPT_MODE) { + this.encryptor = null; + } else { + this.decryptor = null; + } + throw ie; } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java index b64e4b7a970b5..2745052265f7f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java @@ -107,45 +107,72 @@ public void addToChannel(Channel ch) throws IOException { private static class EncryptionHandler extends ChannelOutboundHandlerAdapter { private final ByteArrayWritableChannel byteChannel; private final CryptoOutputStream cos; + private boolean isCipherValid; EncryptionHandler(TransportCipher cipher) throws IOException { byteChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE); cos = cipher.createOutputStream(byteChannel); + isCipherValid = true; } @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - ctx.write(new EncryptedMessage(cos, msg, byteChannel), promise); + ctx.write(new EncryptedMessage(this, cos, msg, byteChannel), promise); } @Override public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { try { - cos.close(); + if (isCipherValid) { + cos.close(); + } } finally { super.close(ctx, promise); } } + + /** + * SPARK-25535. Workaround for CRYPTO-141. Avoid further interaction with the underlying cipher + * after an error occurs. + */ + void reportError() { + this.isCipherValid = false; + } + + boolean isCipherValid() { + return isCipherValid; + } } private static class DecryptionHandler extends ChannelInboundHandlerAdapter { private final CryptoInputStream cis; private final ByteArrayReadableChannel byteChannel; + private boolean isCipherValid; DecryptionHandler(TransportCipher cipher) throws IOException { byteChannel = new ByteArrayReadableChannel(); cis = cipher.createInputStream(byteChannel); + isCipherValid = true; } @Override public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { + if (!isCipherValid) { + throw new IOException("Cipher is in invalid state."); + } byteChannel.feedData((ByteBuf) data); byte[] decryptedData = new byte[byteChannel.readableBytes()]; int offset = 0; while (offset < decryptedData.length) { - offset += cis.read(decryptedData, offset, decryptedData.length - offset); + // SPARK-25535: workaround for CRYPTO-141. + try { + offset += cis.read(decryptedData, offset, decryptedData.length - offset); + } catch (InternalError ie) { + isCipherValid = false; + throw ie; + } } ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, decryptedData.length)); @@ -154,7 +181,9 @@ public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { try { - cis.close(); + if (isCipherValid) { + cis.close(); + } } finally { super.channelInactive(ctx); } @@ -165,8 +194,9 @@ private static class EncryptedMessage extends AbstractFileRegion { private final boolean isByteBuf; private final ByteBuf buf; private final FileRegion region; + private final CryptoOutputStream cos; + private final EncryptionHandler handler; private long transferred; - private CryptoOutputStream cos; // Due to streaming issue CRYPTO-125: https://issues.apache.org/jira/browse/CRYPTO-125, it has // to utilize two helper ByteArrayWritableChannel for streaming. One is used to receive raw data @@ -176,9 +206,14 @@ private static class EncryptedMessage extends AbstractFileRegion { private ByteBuffer currentEncrypted; - EncryptedMessage(CryptoOutputStream cos, Object msg, ByteArrayWritableChannel ch) { + EncryptedMessage( + EncryptionHandler handler, + CryptoOutputStream cos, + Object msg, + ByteArrayWritableChannel ch) { Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion, "Unrecognized message type: %s", msg.getClass().getName()); + this.handler = handler; this.isByteBuf = msg instanceof ByteBuf; this.buf = isByteBuf ? (ByteBuf) msg : null; this.region = isByteBuf ? null : (FileRegion) msg; @@ -261,6 +296,9 @@ public long transferTo(WritableByteChannel target, long position) throws IOExcep } private void encryptMore() throws IOException { + if (!handler.isCipherValid()) { + throw new IOException("Cipher is in invalid state."); + } byteRawChannel.reset(); if (isByteBuf) { @@ -269,8 +307,14 @@ private void encryptMore() throws IOException { } else { region.transferTo(byteRawChannel, region.transferred()); } - cos.write(byteRawChannel.getData(), 0, byteRawChannel.length()); - cos.flush(); + + try { + cos.write(byteRawChannel.getData(), 0, byteRawChannel.length()); + cos.flush(); + } catch (InternalError ie) { + handler.reportError(); + throw ie; + } currentEncrypted = ByteBuffer.wrap(byteEncChannel.getData(), 0, byteEncChannel.length()); diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java index a3519fe4a423e..c0aa298a4017c 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java @@ -18,8 +18,11 @@ package org.apache.spark.network.crypto; import java.util.Arrays; +import java.util.Map; +import java.security.InvalidKeyException; import static java.nio.charset.StandardCharsets.UTF_8; +import com.google.common.collect.ImmutableMap; import org.junit.BeforeClass; import org.junit.Test; import static org.junit.Assert.*; @@ -104,4 +107,18 @@ public void testBadChallenge() throws Exception { challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge)); } + @Test(expected = InvalidKeyException.class) + public void testBadKeySize() throws Exception { + Map mconf = ImmutableMap.of("spark.network.crypto.keyLength", "42"); + TransportConf conf = new TransportConf("rpc", new MapConfigProvider(mconf)); + + try (AuthEngine engine = new AuthEngine("appId", "secret", conf)) { + engine.challenge(); + fail("Should have failed to create challenge message."); + + // Call close explicitly to make sure it's idempotent. + engine.close(); + } + } + } diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala index 00621976b77f4..18b735b8035ab 100644 --- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala +++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.security -import java.io.{InputStream, OutputStream} +import java.io.{Closeable, InputStream, IOException, OutputStream} import java.nio.ByteBuffer import java.nio.channels.{ReadableByteChannel, WritableByteChannel} import java.util.Properties @@ -54,8 +54,10 @@ private[spark] object CryptoStreamUtils extends Logging { val params = new CryptoParams(key, sparkConf) val iv = createInitializationVector(params.conf) os.write(iv) - new CryptoOutputStream(params.transformation, params.conf, os, params.keySpec, - new IvParameterSpec(iv)) + new ErrorHandlingOutputStream( + new CryptoOutputStream(params.transformation, params.conf, os, params.keySpec, + new IvParameterSpec(iv)), + os) } /** @@ -70,8 +72,10 @@ private[spark] object CryptoStreamUtils extends Logging { val helper = new CryptoHelperChannel(channel) helper.write(ByteBuffer.wrap(iv)) - new CryptoOutputStream(params.transformation, params.conf, helper, params.keySpec, - new IvParameterSpec(iv)) + new ErrorHandlingWritableChannel( + new CryptoOutputStream(params.transformation, params.conf, helper, params.keySpec, + new IvParameterSpec(iv)), + helper) } /** @@ -84,8 +88,10 @@ private[spark] object CryptoStreamUtils extends Logging { val iv = new Array[Byte](IV_LENGTH_IN_BYTES) ByteStreams.readFully(is, iv) val params = new CryptoParams(key, sparkConf) - new CryptoInputStream(params.transformation, params.conf, is, params.keySpec, - new IvParameterSpec(iv)) + new ErrorHandlingInputStream( + new CryptoInputStream(params.transformation, params.conf, is, params.keySpec, + new IvParameterSpec(iv)), + is) } /** @@ -100,8 +106,10 @@ private[spark] object CryptoStreamUtils extends Logging { JavaUtils.readFully(channel, buf) val params = new CryptoParams(key, sparkConf) - new CryptoInputStream(params.transformation, params.conf, channel, params.keySpec, - new IvParameterSpec(iv)) + new ErrorHandlingReadableChannel( + new CryptoInputStream(params.transformation, params.conf, channel, params.keySpec, + new IvParameterSpec(iv)), + channel) } def toCryptoConf(conf: SparkConf): Properties = { @@ -157,6 +165,117 @@ private[spark] object CryptoStreamUtils extends Logging { } + /** + * SPARK-25535. The commons-cryto library will throw InternalError if something goes + * wrong, and leave bad state behind in the Java wrappers, so it's not safe to use them + * afterwards. This wrapper detects that situation and avoids further calls into the + * commons-crypto code, while still allowing the underlying streams to be closed. + * + * This should be removed once CRYPTO-141 is fixed (and Spark upgrades its commons-crypto + * dependency). + */ + trait BaseErrorHandler extends Closeable { + + private var closed = false + + /** The encrypted stream that may get into an unhealthy state. */ + protected def cipherStream: Closeable + + /** + * The underlying stream that is being wrapped by the encrypted stream, so that it can be + * closed even if there's an error in the crypto layer. + */ + protected def original: Closeable + + protected def safeCall[T](fn: => T): T = { + if (closed) { + throw new IOException("Cipher stream is closed.") + } + try { + fn + } catch { + case ie: InternalError => + closed = true + original.close() + throw ie + } + } + + override def close(): Unit = { + if (!closed) { + cipherStream.close() + } + } + + } + + // Visible for testing. + class ErrorHandlingReadableChannel( + protected val cipherStream: ReadableByteChannel, + protected val original: ReadableByteChannel) + extends ReadableByteChannel with BaseErrorHandler { + + override def read(src: ByteBuffer): Int = safeCall { + cipherStream.read(src) + } + + override def isOpen(): Boolean = cipherStream.isOpen() + + } + + private class ErrorHandlingInputStream( + protected val cipherStream: InputStream, + protected val original: InputStream) + extends InputStream with BaseErrorHandler { + + override def read(b: Array[Byte]): Int = safeCall { + cipherStream.read(b) + } + + override def read(b: Array[Byte], off: Int, len: Int): Int = safeCall { + cipherStream.read(b, off, len) + } + + override def read(): Int = safeCall { + cipherStream.read() + } + } + + private class ErrorHandlingWritableChannel( + protected val cipherStream: WritableByteChannel, + protected val original: WritableByteChannel) + extends WritableByteChannel with BaseErrorHandler { + + override def write(src: ByteBuffer): Int = safeCall { + cipherStream.write(src) + } + + override def isOpen(): Boolean = cipherStream.isOpen() + + } + + private class ErrorHandlingOutputStream( + protected val cipherStream: OutputStream, + protected val original: OutputStream) + extends OutputStream with BaseErrorHandler { + + override def flush(): Unit = safeCall { + cipherStream.flush() + } + + override def write(b: Array[Byte]): Unit = safeCall { + cipherStream.write(b) + } + + override def write(b: Array[Byte], off: Int, len: Int): Unit = safeCall { + cipherStream.write(b, off, len) + } + + override def write(b: Int): Unit = safeCall { + cipherStream.write(b) + } + } + private class CryptoParams(key: Array[Byte], sparkConf: SparkConf) { val keySpec = new SecretKeySpec(key, "AES") diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala index 78f618f8a2163..0d3611c80b8d0 100644 --- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala @@ -16,13 +16,16 @@ */ package org.apache.spark.security -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream} -import java.nio.channels.Channels +import java.io._ +import java.nio.ByteBuffer +import java.nio.channels.{Channels, ReadableByteChannel} import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.Files import java.util.{Arrays, Random, UUID} import com.google.common.io.ByteStreams +import org.mockito.Matchers.any +import org.mockito.Mockito._ import org.apache.spark._ import org.apache.spark.internal.config._ @@ -164,6 +167,36 @@ class CryptoStreamUtilsSuite extends SparkFunSuite { } } + test("error handling wrapper") { + val wrapped = mock(classOf[ReadableByteChannel]) + val decrypted = mock(classOf[ReadableByteChannel]) + val errorHandler = new CryptoStreamUtils.ErrorHandlingReadableChannel(decrypted, wrapped) + + when(decrypted.read(any(classOf[ByteBuffer]))) + .thenThrow(new IOException()) + .thenThrow(new InternalError()) + .thenReturn(1) + + val out = ByteBuffer.allocate(1) + intercept[IOException] { + errorHandler.read(out) + } + intercept[InternalError] { + errorHandler.read(out) + } + + val e = intercept[IOException] { + errorHandler.read(out) + } + assert(e.getMessage().contains("is closed")) + errorHandler.close() + + verify(decrypted, times(2)).read(any(classOf[ByteBuffer])) + verify(wrapped, never()).read(any(classOf[ByteBuffer])) + verify(decrypted, never()).close() + verify(wrapped, times(1)).close() + } + private def createConf(extra: (String, String)*): SparkConf = { val conf = new SparkConf() extra.foreach { case (k, v) => conf.set(k, v) } From faf73dcd33d04365c28c2846d3a1f845785f69df Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 9 Oct 2018 21:10:33 +0000 Subject: [PATCH 1805/2461] [SPARK-25559][FOLLOW-UP] Add comments for partial pushdown of conjuncts in Parquet ## What changes were proposed in this pull request? This is a follow up of https://github.com/apache/spark/pull/22574. Renamed the parameter and added comments. ## How was this patch tested? N/A Closes #22679 from gatorsmile/followupSPARK-25559. Authored-by: gatorsmile Signed-off-by: DB Tsai --- .../datasources/parquet/ParquetFilters.scala | 31 +++++++++++++------ 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 44a0d209e6e69..21ab9c78e53d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -394,13 +394,22 @@ private[parquet] class ParquetFilters( */ def createFilter(schema: MessageType, predicate: sources.Filter): Option[FilterPredicate] = { val nameToParquetField = getFieldMap(schema) - createFilterHelper(nameToParquetField, predicate, canRemoveOneSideInAnd = true) + createFilterHelper(nameToParquetField, predicate, canPartialPushDownConjuncts = true) } + /** + * @param nameToParquetField a map from the field name to its field name and data type. + * This only includes the root fields whose types are primitive types. + * @param predicate the input filter predicates. Not all the predicates can be pushed down. + * @param canPartialPushDownConjuncts whether a subset of conjuncts of predicates can be pushed + * down safely. Pushing ONLY one side of AND down is safe to + * do at the top level or none of its ancestors is NOT and OR. + * @return the Parquet-native filter predicates that are eligible for pushdown. + */ private def createFilterHelper( nameToParquetField: Map[String, ParquetField], predicate: sources.Filter, - canRemoveOneSideInAnd: Boolean): Option[FilterPredicate] = { + canPartialPushDownConjuncts: Boolean): Option[FilterPredicate] = { // Decimal type must make sure that filter value's scale matched the file. // If doesn't matched, which would cause data corruption. def isDecimalMatched(value: Any, decimalMeta: DecimalMetadata): Boolean = value match { @@ -505,24 +514,28 @@ private[parquet] class ParquetFilters( // Pushing one side of AND down is only safe to do at the top level or in the child // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate // can be safely removed. - val lhsFilterOption = createFilterHelper(nameToParquetField, lhs, canRemoveOneSideInAnd) - val rhsFilterOption = createFilterHelper(nameToParquetField, rhs, canRemoveOneSideInAnd) + val lhsFilterOption = + createFilterHelper(nameToParquetField, lhs, canPartialPushDownConjuncts) + val rhsFilterOption = + createFilterHelper(nameToParquetField, rhs, canPartialPushDownConjuncts) (lhsFilterOption, rhsFilterOption) match { case (Some(lhsFilter), Some(rhsFilter)) => Some(FilterApi.and(lhsFilter, rhsFilter)) - case (Some(lhsFilter), None) if canRemoveOneSideInAnd => Some(lhsFilter) - case (None, Some(rhsFilter)) if canRemoveOneSideInAnd => Some(rhsFilter) + case (Some(lhsFilter), None) if canPartialPushDownConjuncts => Some(lhsFilter) + case (None, Some(rhsFilter)) if canPartialPushDownConjuncts => Some(rhsFilter) case _ => None } case sources.Or(lhs, rhs) => for { - lhsFilter <- createFilterHelper(nameToParquetField, lhs, canRemoveOneSideInAnd = false) - rhsFilter <- createFilterHelper(nameToParquetField, rhs, canRemoveOneSideInAnd = false) + lhsFilter <- + createFilterHelper(nameToParquetField, lhs, canPartialPushDownConjuncts = false) + rhsFilter <- + createFilterHelper(nameToParquetField, rhs, canPartialPushDownConjuncts = false) } yield FilterApi.or(lhsFilter, rhsFilter) case sources.Not(pred) => - createFilterHelper(nameToParquetField, pred, canRemoveOneSideInAnd = false) + createFilterHelper(nameToParquetField, pred, canPartialPushDownConjuncts = false) .map(FilterApi.not) case sources.In(name, values) if canMakeFilterOn(name, values.head) From 3caab872db22246c9ab5f3395498f05cb097c142 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 10 Oct 2018 21:07:59 +0800 Subject: [PATCH 1806/2461] [SPARK-20946][SPARK-25525][SQL][FOLLOW-UP] Update the migration guide. ## What changes were proposed in this pull request? This is a follow-up pr of #18536 and #22545 to update the migration guide. ## How was this patch tested? Build and check the doc locally. Closes #22682 from ueshin/issues/SPARK-20946_25525/migration_guide. Authored-by: Takuya UESHIN Signed-off-by: Wenchen Fan --- docs/sql-programming-guide.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index a1d7b1108bf73..0d2935769ae51 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1890,6 +1890,10 @@ working with timestamps in `pandas_udf`s to get the best performance, see # Migration Guide +## Upgrading From Spark SQL 2.4 to 3.0 + + - In PySpark, when creating a `SparkSession` with `SparkSession.builder.getOrCreate()`, if there is an existing `SparkContext`, the builder was trying to update the `SparkConf` of the existing `SparkContext` with configurations specified to the builder, but the `SparkContext` is shared by all `SparkSession`s, so we should not update them. Since 3.0, the builder come to not update the configurations. This is the same behavior as Java/Scala API in 2.3 and above. If you want to update them, you need to update them prior to creating a `SparkSession`. + ## Upgrading From Spark SQL 2.3 to 2.4 - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below. @@ -2135,6 +2139,8 @@ working with timestamps in `pandas_udf`s to get the best performance, see - In PySpark, `df.replace` does not allow to omit `value` when `to_replace` is not a dictionary. Previously, `value` could be omitted in the other cases and had `None` by default, which is counterintuitive and error-prone. - Un-aliased subquery's semantic has not been well defined with confusing behaviors. Since Spark 2.3, we invalidate such confusing cases, for example: `SELECT v.i from (SELECT i FROM v)`, Spark will throw an analysis exception in this case because users should not be able to use the qualifier inside a subquery. See [SPARK-20690](https://issues.apache.org/jira/browse/SPARK-20690) and [SPARK-21335](https://issues.apache.org/jira/browse/SPARK-21335) for more details. + - When creating a `SparkSession` with `SparkSession.builder.getOrCreate()`, if there is an existing `SparkContext`, the builder was trying to update the `SparkConf` of the existing `SparkContext` with configurations specified to the builder, but the `SparkContext` is shared by all `SparkSession`s, so we should not update them. Since 2.3, the builder come to not update the configurations. If you want to update them, you need to update them prior to creating a `SparkSession`. + ## Upgrading From Spark SQL 2.1 to 2.2 - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time-consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. From eaafcd8a22db187e87f09966826dcf677c4c38ea Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 10 Oct 2018 08:25:12 -0700 Subject: [PATCH 1807/2461] [SPARK-25605][TESTS] Alternate take. Run cast string to timestamp tests for a subset of timezones ## What changes were proposed in this pull request? Try testing timezones in parallel instead in CastSuite, instead of random sampling. See also #22631 ## How was this patch tested? Existing test. Closes #22672 from srowen/SPARK-25605.2. Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../org/apache/spark/sql/catalyst/expressions/CastSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 90c0bf7d8b3d7..94dee7ea048c3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -112,7 +112,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cast string to timestamp") { - for (tz <- Random.shuffle(ALL_TIMEZONES).take(50)) { + ALL_TIMEZONES.par.foreach { tz => def checkCastStringToTimestamp(str: String, expected: Timestamp): Unit = { checkEvaluation(cast(Literal(str), TimestampType, Option(tz.getID)), expected) } From 3528c08bebbcad3dee7557945ddcd31c99deb50e Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 10 Oct 2018 08:51:16 -0700 Subject: [PATCH 1808/2461] [SPARK-25611][SPARK-25612][SQL][TESTS] Improve test run time of CompressionCodecSuite ## What changes were proposed in this pull request? Reduced the combination of codecs from 9 to 3 to improve the test runtime. ## How was this patch tested? This is a test fix. Closes #22641 from dilipbiswal/SPARK-25611. Authored-by: Dilip Biswal Signed-off-by: Sean Owen --- .../sql/hive/CompressionCodecSuite.scala | 54 ++++++++----------- 1 file changed, 21 insertions(+), 33 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala index 1bd7e52c88ecf..398f4d2efbbf4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -229,8 +229,8 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo tableCompressionCodecs: List[String]) (assertionCompressionCodec: (Option[String], String, String, Long) => Unit): Unit = { withSQLConf(getConvertMetastoreConfName(format) -> convertMetastore.toString) { - tableCompressionCodecs.foreach { tableCompression => - compressionCodecs.foreach { sessionCompressionCodec => + tableCompressionCodecs.zipAll(compressionCodecs, null, "SNAPPY").foreach { + case (tableCompression, sessionCompressionCodec) => withSQLConf(getSparkCompressionConfName(format) -> sessionCompressionCodec) { // 'tableCompression = null' means no table-level compression val compression = Option(tableCompression) @@ -240,7 +240,6 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo compression, sessionCompressionCodec, realCompressionCodec, tableSize) } } - } } } } @@ -262,7 +261,10 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo } } - def checkForTableWithCompressProp(format: String, compressCodecs: List[String]): Unit = { + def checkForTableWithCompressProp( + format: String, + tableCompressCodecs: List[String], + sessionCompressCodecs: List[String]): Unit = { Seq(true, false).foreach { isPartitioned => Seq(true, false).foreach { convertMetastore => Seq(true, false).foreach { usingCTAS => @@ -271,10 +273,10 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo isPartitioned, convertMetastore, usingCTAS, - compressionCodecs = compressCodecs, - tableCompressionCodecs = compressCodecs) { + compressionCodecs = sessionCompressCodecs, + tableCompressionCodecs = tableCompressCodecs) { case (tableCodec, sessionCodec, realCodec, tableSize) => - val expectCodec = tableCodec.get + val expectCodec = tableCodec.getOrElse(sessionCodec) assert(expectCodec == realCodec) assert(checkTableSize( format, expectCodec, isPartitioned, convertMetastore, usingCTAS, tableSize)) @@ -284,36 +286,22 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo } } - def checkForTableWithoutCompressProp(format: String, compressCodecs: List[String]): Unit = { - Seq(true, false).foreach { isPartitioned => - Seq(true, false).foreach { convertMetastore => - Seq(true, false).foreach { usingCTAS => - checkTableCompressionCodecForCodecs( - format, - isPartitioned, - convertMetastore, - usingCTAS, - compressionCodecs = compressCodecs, - tableCompressionCodecs = List(null)) { - case (tableCodec, sessionCodec, realCodec, tableSize) => - // Always expect session-level take effect - assert(sessionCodec == realCodec) - assert(checkTableSize( - format, sessionCodec, isPartitioned, convertMetastore, usingCTAS, tableSize)) - } - } - } - } - } - test("both table-level and session-level compression are set") { - checkForTableWithCompressProp("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) - checkForTableWithCompressProp("orc", List("NONE", "SNAPPY", "ZLIB")) + checkForTableWithCompressProp("parquet", + tableCompressCodecs = List("UNCOMPRESSED", "SNAPPY", "GZIP"), + sessionCompressCodecs = List("SNAPPY", "GZIP", "SNAPPY")) + checkForTableWithCompressProp("orc", + tableCompressCodecs = List("NONE", "SNAPPY", "ZLIB"), + sessionCompressCodecs = List("SNAPPY", "ZLIB", "SNAPPY")) } test("table-level compression is not set but session-level compressions is set ") { - checkForTableWithoutCompressProp("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) - checkForTableWithoutCompressProp("orc", List("NONE", "SNAPPY", "ZLIB")) + checkForTableWithCompressProp("parquet", + tableCompressCodecs = List.empty, + sessionCompressCodecs = List("UNCOMPRESSED", "SNAPPY", "GZIP")) + checkForTableWithCompressProp("orc", + tableCompressCodecs = List.empty, + sessionCompressCodecs = List("NONE", "SNAPPY", "ZLIB")) } def checkTableWriteWithCompressionCodecs(format: String, compressCodecs: List[String]): Unit = { From 8a7872dc254710f9b29fdfdb2915a949ef606871 Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Wed, 10 Oct 2018 09:24:36 -0700 Subject: [PATCH 1809/2461] [SPARK-25636][CORE] spark-submit cuts off the failure reason when there is an error connecting to master ## What changes were proposed in this pull request? Cause of the error is wrapped with SparkException, now finding the cause from the wrapped exception and throwing the cause instead of the wrapped exception. ## How was this patch tested? Verified it manually by checking the cause of the error, it gives the error as shown below. ### Without the PR change ``` [apache-spark]$ ./bin/spark-submit --verbose --master spark://****** .... Error: Exception thrown in awaitResult: Run with --help for usage help or --verbose for debug output ``` ### With the PR change ``` [apache-spark]$ ./bin/spark-submit --verbose --master spark://****** .... Exception in thread "main" org.apache.spark.SparkException: Exception thrown in awaitResult: at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:226) at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75) .... at org.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala) Caused by: java.io.IOException: Failed to connect to devaraj-pc1/10.3.66.65:7077 at org.apache.spark.network.client.TransportClientFactory.createClient(TransportClientFactory.java:245) .... at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) at java.lang.Thread.run(Thread.java:748) Caused by: io.netty.channel.AbstractChannel$AnnotatedConnectException: Connection refused: devaraj-pc1/10.3.66.65:7077 at sun.nio.ch.SocketChannelImpl.checkConnect(Native Method) .... at io.netty.util.concurrent.DefaultThreadFactory$DefaultRunnableDecorator.run(DefaultThreadFactory.java:138) ... 1 more Caused by: java.net.ConnectException: Connection refused ... 11 more ``` Closes #22623 from devaraj-kavali/SPARK-25636. Authored-by: Devaraj K Signed-off-by: Marcelo Vanzin --- .../org/apache/spark/deploy/SparkSubmit.scala | 2 -- .../apache/spark/deploy/SparkSubmitSuite.scala | 17 +++++++++++------ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index d5f2865f87281..61b379f286802 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -927,8 +927,6 @@ object SparkSubmit extends CommandLineUtils with Logging { } catch { case e: SparkUserAppException => exitFn(e.exitCode) - case e: SparkException => - printErrorAndExit(e.getMessage()) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 9eae3605d0738..652c36ffa6e71 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -74,20 +74,25 @@ trait TestPrematureExit { @volatile var exitedCleanly = false mainObject.exitFn = (_) => exitedCleanly = true + @volatile var exception: Exception = null val thread = new Thread { override def run() = try { mainObject.main(input) } catch { - // If exceptions occur after the "exit" has happened, fine to ignore them. - // These represent code paths not reachable during normal execution. - case e: Exception => if (!exitedCleanly) throw e + // Capture the exception to check whether the exception contains searchString or not + case e: Exception => exception = e } } thread.start() thread.join() - val joined = printStream.lineBuffer.mkString("\n") - if (!joined.contains(searchString)) { - fail(s"Search string '$searchString' not found in $joined") + if (exitedCleanly) { + val joined = printStream.lineBuffer.mkString("\n") + assert(joined.contains(searchString)) + } else { + assert(exception != null) + if (!exception.getMessage.contains(searchString)) { + throw exception + } } } } From 6df2345794614c33c95fa453cabac755cf94d131 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 10 Oct 2018 18:18:56 +0000 Subject: [PATCH 1810/2461] [SPARK-25699][SQL] Partially push down conjunctive predicated in ORC ## What changes were proposed in this pull request? Inspired by https://github.com/apache/spark/pull/22574 . We can partially push down top level conjunctive predicates to Orc. This PR improves Orc predicate push down in both SQL and Hive module. ## How was this patch tested? New unit test. Closes #22684 from gengliangwang/pushOrcFilters. Authored-by: Gengliang Wang Signed-off-by: DB Tsai --- .../datasources/orc/OrcFilters.scala | 69 ++++++++++++++----- .../datasources/orc/OrcFilterSuite.scala | 37 +++++++++- .../spark/sql/hive/orc/OrcFilters.scala | 69 ++++++++++++++----- .../sql/hive/orc/HiveOrcFilterSuite.scala | 45 +++++++++++- 4 files changed, 186 insertions(+), 34 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index dbafc468c6c40..2b17b479432fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -138,6 +138,23 @@ private[sql] object OrcFilters { dataTypeMap: Map[String, DataType], expression: Filter, builder: Builder): Option[Builder] = { + createBuilder(dataTypeMap, expression, builder, canPartialPushDownConjuncts = true) + } + + /** + * @param dataTypeMap a map from the attribute name to its data type. + * @param expression the input filter predicates. + * @param builder the input SearchArgument.Builder. + * @param canPartialPushDownConjuncts whether a subset of conjuncts of predicates can be pushed + * down safely. Pushing ONLY one side of AND down is safe to + * do at the top level or none of its ancestors is NOT and OR. + * @return the builder so far. + */ + private def createBuilder( + dataTypeMap: Map[String, DataType], + expression: Filter, + builder: Builder, + canPartialPushDownConjuncts: Boolean): Option[Builder] = { def getType(attribute: String): PredicateLeaf.Type = getPredicateLeafType(dataTypeMap(attribute)) @@ -145,32 +162,52 @@ private[sql] object OrcFilters { expression match { case And(left, right) => - // At here, it is not safe to just convert one side if we do not understand the - // other side. Here is an example used to explain the reason. + // At here, it is not safe to just convert one side and remove the other side + // if we do not understand what the parent filters are. + // + // Here is an example used to explain the reason. // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to // convert b in ('1'). If we only convert a = 2, we will end up with a filter // NOT(a = 2), which will generate wrong results. - // Pushing one side of AND down is only safe to do at the top level. - // You can see ParquetRelation's initializeLocalJobFunc method as an example. - for { - _ <- buildSearchArgument(dataTypeMap, left, newBuilder) - _ <- buildSearchArgument(dataTypeMap, right, newBuilder) - lhs <- buildSearchArgument(dataTypeMap, left, builder.startAnd()) - rhs <- buildSearchArgument(dataTypeMap, right, lhs) - } yield rhs.end() + // + // Pushing one side of AND down is only safe to do at the top level or in the child + // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate + // can be safely removed. + val leftBuilderOption = + createBuilder(dataTypeMap, left, newBuilder, canPartialPushDownConjuncts) + val rightBuilderOption = + createBuilder(dataTypeMap, right, newBuilder, canPartialPushDownConjuncts) + (leftBuilderOption, rightBuilderOption) match { + case (Some(_), Some(_)) => + for { + lhs <- createBuilder(dataTypeMap, left, + builder.startAnd(), canPartialPushDownConjuncts) + rhs <- createBuilder(dataTypeMap, right, lhs, canPartialPushDownConjuncts) + } yield rhs.end() + + case (Some(_), None) if canPartialPushDownConjuncts => + createBuilder(dataTypeMap, left, builder, canPartialPushDownConjuncts) + + case (None, Some(_)) if canPartialPushDownConjuncts => + createBuilder(dataTypeMap, right, builder, canPartialPushDownConjuncts) + + case _ => None + } case Or(left, right) => for { - _ <- buildSearchArgument(dataTypeMap, left, newBuilder) - _ <- buildSearchArgument(dataTypeMap, right, newBuilder) - lhs <- buildSearchArgument(dataTypeMap, left, builder.startOr()) - rhs <- buildSearchArgument(dataTypeMap, right, lhs) + _ <- createBuilder(dataTypeMap, left, newBuilder, canPartialPushDownConjuncts = false) + _ <- createBuilder(dataTypeMap, right, newBuilder, canPartialPushDownConjuncts = false) + lhs <- createBuilder(dataTypeMap, left, + builder.startOr(), canPartialPushDownConjuncts = false) + rhs <- createBuilder(dataTypeMap, right, lhs, canPartialPushDownConjuncts = false) } yield rhs.end() case Not(child) => for { - _ <- buildSearchArgument(dataTypeMap, child, newBuilder) - negate <- buildSearchArgument(dataTypeMap, child, builder.startNot()) + _ <- createBuilder(dataTypeMap, child, newBuilder, canPartialPushDownConjuncts = false) + negate <- createBuilder(dataTypeMap, + child, builder.startNot(), canPartialPushDownConjuncts = false) } yield negate.end() // NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()` diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index 8680b86517b19..ee12f30892436 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -358,7 +358,7 @@ class OrcFilterSuite extends OrcTest with SharedSQLContext { } } - test("SPARK-12218 Converting conjunctions into ORC SearchArguments") { + test("SPARK-12218 and SPARK-25699 Converting conjunctions into ORC SearchArguments") { import org.apache.spark.sql.sources._ // The `LessThan` should be converted while the `StringContains` shouldn't val schema = new StructType( @@ -382,5 +382,40 @@ class OrcFilterSuite extends OrcTest with SharedSQLContext { )) )).get.toString } + + // Can not remove unsupported `StringContains` predicate since it is under `Or` operator. + assert(OrcFilters.createFilter(schema, Array( + Or( + LessThan("a", 10), + And( + StringContains("b", "prefix"), + GreaterThan("a", 1) + ) + ) + )).isEmpty) + + // Safely remove unsupported `StringContains` predicate and push down `LessThan` + assertResult("leaf-0 = (LESS_THAN a 10), expr = leaf-0") { + OrcFilters.createFilter(schema, Array( + And( + LessThan("a", 10), + StringContains("b", "prefix") + ) + )).get.toString + } + + // Safely remove unsupported `StringContains` predicate, push down `LessThan` and `GreaterThan`. + assertResult("leaf-0 = (LESS_THAN a 10), leaf-1 = (LESS_THAN_EQUALS a 1)," + + " expr = (and leaf-0 (not leaf-1))") { + OrcFilters.createFilter(schema, Array( + And( + And( + LessThan("a", 10), + StringContains("b", "prefix") + ), + GreaterThan("a", 1) + ) + )).get.toString + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index aee9cb58a031e..a82576a233acd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -79,6 +79,23 @@ private[orc] object OrcFilters extends Logging { dataTypeMap: Map[String, DataType], expression: Filter, builder: Builder): Option[Builder] = { + createBuilder(dataTypeMap, expression, builder, canPartialPushDownConjuncts = true) + } + + /** + * @param dataTypeMap a map from the attribute name to its data type. + * @param expression the input filter predicates. + * @param builder the input SearchArgument.Builder. + * @param canPartialPushDownConjuncts whether a subset of conjuncts of predicates can be pushed + * down safely. Pushing ONLY one side of AND down is safe to + * do at the top level or none of its ancestors is NOT and OR. + * @return the builder so far. + */ + private def createBuilder( + dataTypeMap: Map[String, DataType], + expression: Filter, + builder: Builder, + canPartialPushDownConjuncts: Boolean): Option[Builder] = { def isSearchableType(dataType: DataType): Boolean = dataType match { // Only the values in the Spark types below can be recognized by // the `SearchArgumentImpl.BuilderImpl.boxLiteral()` method. @@ -90,32 +107,52 @@ private[orc] object OrcFilters extends Logging { expression match { case And(left, right) => - // At here, it is not safe to just convert one side if we do not understand the - // other side. Here is an example used to explain the reason. + // At here, it is not safe to just convert one side and remove the other side + // if we do not understand what the parent filters are. + // + // Here is an example used to explain the reason. // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to // convert b in ('1'). If we only convert a = 2, we will end up with a filter // NOT(a = 2), which will generate wrong results. - // Pushing one side of AND down is only safe to do at the top level. - // You can see ParquetRelation's initializeLocalJobFunc method as an example. - for { - _ <- buildSearchArgument(dataTypeMap, left, newBuilder) - _ <- buildSearchArgument(dataTypeMap, right, newBuilder) - lhs <- buildSearchArgument(dataTypeMap, left, builder.startAnd()) - rhs <- buildSearchArgument(dataTypeMap, right, lhs) - } yield rhs.end() + // + // Pushing one side of AND down is only safe to do at the top level or in the child + // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate + // can be safely removed. + val leftBuilderOption = + createBuilder(dataTypeMap, left, newBuilder, canPartialPushDownConjuncts) + val rightBuilderOption = + createBuilder(dataTypeMap, right, newBuilder, canPartialPushDownConjuncts) + (leftBuilderOption, rightBuilderOption) match { + case (Some(_), Some(_)) => + for { + lhs <- createBuilder(dataTypeMap, left, + builder.startAnd(), canPartialPushDownConjuncts) + rhs <- createBuilder(dataTypeMap, right, lhs, canPartialPushDownConjuncts) + } yield rhs.end() + + case (Some(_), None) if canPartialPushDownConjuncts => + createBuilder(dataTypeMap, left, builder, canPartialPushDownConjuncts) + + case (None, Some(_)) if canPartialPushDownConjuncts => + createBuilder(dataTypeMap, right, builder, canPartialPushDownConjuncts) + + case _ => None + } case Or(left, right) => for { - _ <- buildSearchArgument(dataTypeMap, left, newBuilder) - _ <- buildSearchArgument(dataTypeMap, right, newBuilder) - lhs <- buildSearchArgument(dataTypeMap, left, builder.startOr()) - rhs <- buildSearchArgument(dataTypeMap, right, lhs) + _ <- createBuilder(dataTypeMap, left, newBuilder, canPartialPushDownConjuncts = false) + _ <- createBuilder(dataTypeMap, right, newBuilder, canPartialPushDownConjuncts = false) + lhs <- createBuilder(dataTypeMap, left, + builder.startOr(), canPartialPushDownConjuncts = false) + rhs <- createBuilder(dataTypeMap, right, lhs, canPartialPushDownConjuncts = false) } yield rhs.end() case Not(child) => for { - _ <- buildSearchArgument(dataTypeMap, child, newBuilder) - negate <- buildSearchArgument(dataTypeMap, child, builder.startNot()) + _ <- createBuilder(dataTypeMap, child, newBuilder, canPartialPushDownConjuncts = false) + negate <- createBuilder(dataTypeMap, + child, builder.startNot(), canPartialPushDownConjuncts = false) } yield negate.end() // NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()` diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcFilterSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcFilterSuite.scala index 283037caf4a9b..5094763b0cd2a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcFilterSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcFilterSuite.scala @@ -351,7 +351,7 @@ class HiveOrcFilterSuite extends OrcTest with TestHiveSingleton { } } - test("SPARK-12218 Converting conjunctions into ORC SearchArguments") { + test("SPARK-12218 and SPARK-25699 Converting conjunctions into ORC SearchArguments") { import org.apache.spark.sql.sources._ // The `LessThan` should be converted while the `StringContains` shouldn't val schema = new StructType( @@ -383,5 +383,48 @@ class HiveOrcFilterSuite extends OrcTest with TestHiveSingleton { )) )).get.toString } + + // Can not remove unsupported `StringContains` predicate since it is under `Or` operator. + assert(OrcFilters.createFilter(schema, Array( + Or( + LessThan("a", 10), + And( + StringContains("b", "prefix"), + GreaterThan("a", 1) + ) + ) + )).isEmpty) + + // Safely remove unsupported `StringContains` predicate and push down `LessThan` + assertResult( + """leaf-0 = (LESS_THAN a 10) + |expr = leaf-0 + """.stripMargin.trim + ) { + OrcFilters.createFilter(schema, Array( + And( + LessThan("a", 10), + StringContains("b", "prefix") + ) + )).get.toString + } + + // Safely remove unsupported `StringContains` predicate, push down `LessThan` and `GreaterThan`. + assertResult( + """leaf-0 = (LESS_THAN a 10) + |leaf-1 = (LESS_THAN_EQUALS a 1) + |expr = (and leaf-0 (not leaf-1)) + """.stripMargin.trim + ) { + OrcFilters.createFilter(schema, Array( + And( + And( + LessThan("a", 10), + StringContains("b", "prefix") + ), + GreaterThan("a", 1) + ) + )).get.toString + } } } From 80813e198033cd63cc6100ee6ffe7d1eb1dff27b Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 10 Oct 2018 12:07:53 -0700 Subject: [PATCH 1811/2461] [SPARK-25016][BUILD][CORE] Remove support for Hadoop 2.6 ## What changes were proposed in this pull request? Remove Hadoop 2.6 references and make 2.7 the default. Obviously, this is for master/3.0.0 only. After this we can also get rid of the separate test jobs for Hadoop 2.6. ## How was this patch tested? Existing tests Closes #22615 from srowen/SPARK-25016. Authored-by: Sean Owen Signed-off-by: Sean Owen --- dev/appveyor-install-dependencies.ps1 | 3 +- dev/create-release/release-build.sh | 43 ++-- dev/deps/spark-deps-hadoop-2.6 | 198 ------------------ dev/run-tests.py | 15 +- dev/test-dependencies.sh | 1 - docs/building-spark.md | 11 +- docs/index.md | 3 - docs/running-on-yarn.md | 3 +- hadoop-cloud/pom.xml | 59 +++--- pom.xml | 14 +- .../dev/dev-run-integration-tests.sh | 2 +- .../org/apache/spark/deploy/yarn/Client.scala | 13 +- .../apache/spark/sql/hive/TableReader.scala | 2 +- .../hive/client/IsolatedClientLoader.scala | 11 +- 14 files changed, 68 insertions(+), 310 deletions(-) delete mode 100644 dev/deps/spark-deps-hadoop-2.6 diff --git a/dev/appveyor-install-dependencies.ps1 b/dev/appveyor-install-dependencies.ps1 index 8a04b621f8ce4..c91882851847b 100644 --- a/dev/appveyor-install-dependencies.ps1 +++ b/dev/appveyor-install-dependencies.ps1 @@ -95,7 +95,8 @@ $env:MAVEN_OPTS = "-Xmx2g -XX:ReservedCodeCacheSize=512m" Pop-Location # ========================== Hadoop bin package -$hadoopVer = "2.6.4" +# This must match the version at https://github.com/steveloughran/winutils/tree/master/hadoop-2.7.1 +$hadoopVer = "2.7.1" $hadoopPath = "$tools\hadoop" if (!(Test-Path $hadoopPath)) { New-Item -ItemType Directory -Force -Path $hadoopPath | Out-Null diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index cce5f8b6975ca..89593cfa0107a 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -191,9 +191,19 @@ if [[ "$1" == "package" ]]; then make_binary_release() { NAME=$1 FLAGS="$MVN_EXTRA_OPTS -B $BASE_RELEASE_PROFILES $2" + # BUILD_PACKAGE can be "withpip", "withr", or both as "withpip,withr" BUILD_PACKAGE=$3 SCALA_VERSION=$4 + PIP_FLAG="" + if [[ $BUILD_PACKAGE == *"withpip"* ]]; then + PIP_FLAG="--pip" + fi + R_FLAG="" + if [[ $BUILD_PACKAGE == *"withr"* ]]; then + R_FLAG="--r" + fi + # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds # share the same Zinc server. ZINC_PORT=$((ZINC_PORT + 1)) @@ -217,18 +227,13 @@ if [[ "$1" == "package" ]]; then # Get maven home set by MVN MVN_HOME=`$MVN -version 2>&1 | grep 'Maven home' | awk '{print $NF}'` + echo "Creating distribution" + ./dev/make-distribution.sh --name $NAME --mvn $MVN_HOME/bin/mvn --tgz \ + $PIP_FLAG $R_FLAG $FLAGS \ + -DzincPort=$ZINC_PORT 2>&1 > ../binary-release-$NAME.log + cd .. - if [ -z "$BUILD_PACKAGE" ]; then - echo "Creating distribution without PIP/R package" - ./dev/make-distribution.sh --name $NAME --mvn $MVN_HOME/bin/mvn --tgz $FLAGS \ - -DzincPort=$ZINC_PORT 2>&1 > ../binary-release-$NAME.log - cd .. - elif [[ "$BUILD_PACKAGE" == "withr" ]]; then - echo "Creating distribution with R package" - ./dev/make-distribution.sh --name $NAME --mvn $MVN_HOME/bin/mvn --tgz --r $FLAGS \ - -DzincPort=$ZINC_PORT 2>&1 > ../binary-release-$NAME.log - cd .. - + if [[ -n $R_FLAG ]]; then echo "Copying and signing R source package" R_DIST_NAME=SparkR_$SPARK_VERSION.tar.gz cp spark-$SPARK_VERSION-bin-$NAME/R/$R_DIST_NAME . @@ -239,12 +244,9 @@ if [[ "$1" == "package" ]]; then echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 $R_DIST_NAME > \ $R_DIST_NAME.sha512 - else - echo "Creating distribution with PIP package" - ./dev/make-distribution.sh --name $NAME --mvn $MVN_HOME/bin/mvn --tgz --pip $FLAGS \ - -DzincPort=$ZINC_PORT 2>&1 > ../binary-release-$NAME.log - cd .. + fi + if [[ -n $PIP_FLAG ]]; then echo "Copying and signing python distribution" PYTHON_DIST_NAME=pyspark-$PYSPARK_VERSION.tar.gz cp spark-$SPARK_VERSION-bin-$NAME/python/dist/$PYTHON_DIST_NAME . @@ -277,8 +279,10 @@ if [[ "$1" == "package" ]]; then declare -A BINARY_PKGS_ARGS BINARY_PKGS_ARGS["hadoop2.7"]="-Phadoop-2.7 $HIVE_PROFILES" if ! is_dry_run; then - BINARY_PKGS_ARGS["hadoop2.6"]="-Phadoop-2.6 $HIVE_PROFILES" BINARY_PKGS_ARGS["without-hadoop"]="-Phadoop-provided" + if [[ $SPARK_VERSION < "3.0." ]]; then + BINARY_PKGS_ARGS["hadoop2.6"]="-Phadoop-2.6 $HIVE_PROFILES" + fi if [[ $SPARK_VERSION < "2.2." ]]; then BINARY_PKGS_ARGS["hadoop2.4"]="-Phadoop-2.4 $HIVE_PROFILES" BINARY_PKGS_ARGS["hadoop2.3"]="-Phadoop-2.3 $HIVE_PROFILES" @@ -286,10 +290,7 @@ if [[ "$1" == "package" ]]; then fi declare -A BINARY_PKGS_EXTRA - BINARY_PKGS_EXTRA["hadoop2.7"]="withpip" - if ! is_dry_run; then - BINARY_PKGS_EXTRA["hadoop2.6"]="withr" - fi + BINARY_PKGS_EXTRA["hadoop2.7"]="withpip,withr" echo "Packages to build: ${!BINARY_PKGS_ARGS[@]}" for key in ${!BINARY_PKGS_ARGS[@]}; do diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 deleted file mode 100644 index e0e3e0a82e730..0000000000000 --- a/dev/deps/spark-deps-hadoop-2.6 +++ /dev/null @@ -1,198 +0,0 @@ -JavaEWAH-0.3.2.jar -RoaringBitmap-0.5.11.jar -ST4-4.0.4.jar -activation-1.1.1.jar -aircompressor-0.10.jar -antlr-2.7.7.jar -antlr-runtime-3.4.jar -antlr4-runtime-4.7.jar -aopalliance-1.0.jar -aopalliance-repackaged-2.4.0-b34.jar -apache-log4j-extras-1.2.17.jar -apacheds-i18n-2.0.0-M15.jar -apacheds-kerberos-codec-2.0.0-M15.jar -api-asn1-api-1.0.0-M20.jar -api-util-1.0.0-M20.jar -arpack_combined_all-0.1.jar -arrow-format-0.10.0.jar -arrow-memory-0.10.0.jar -arrow-vector-0.10.0.jar -automaton-1.11-8.jar -avro-1.8.2.jar -avro-ipc-1.8.2.jar -avro-mapred-1.8.2-hadoop2.jar -bonecp-0.8.0.RELEASE.jar -breeze-macros_2.11-0.13.2.jar -breeze_2.11-0.13.2.jar -calcite-avatica-1.2.0-incubating.jar -calcite-core-1.2.0-incubating.jar -calcite-linq4j-1.2.0-incubating.jar -chill-java-0.9.3.jar -chill_2.11-0.9.3.jar -commons-beanutils-1.7.0.jar -commons-beanutils-core-1.8.0.jar -commons-cli-1.2.jar -commons-codec-1.10.jar -commons-collections-3.2.2.jar -commons-compiler-3.0.10.jar -commons-compress-1.8.1.jar -commons-configuration-1.6.jar -commons-crypto-1.0.0.jar -commons-dbcp-1.4.jar -commons-digester-1.8.jar -commons-httpclient-3.1.jar -commons-io-2.4.jar -commons-lang-2.6.jar -commons-lang3-3.5.jar -commons-logging-1.1.3.jar -commons-math3-3.4.1.jar -commons-net-3.1.jar -commons-pool-1.5.4.jar -compress-lzf-1.0.3.jar -core-1.1.2.jar -curator-client-2.6.0.jar -curator-framework-2.6.0.jar -curator-recipes-2.6.0.jar -datanucleus-api-jdo-3.2.6.jar -datanucleus-core-3.2.10.jar -datanucleus-rdbms-3.2.9.jar -derby-10.12.1.1.jar -eigenbase-properties-1.1.5.jar -flatbuffers-1.2.0-3f79e055.jar -generex-1.0.1.jar -gson-2.2.4.jar -guava-14.0.1.jar -guice-3.0.jar -guice-servlet-3.0.jar -hadoop-annotations-2.6.5.jar -hadoop-auth-2.6.5.jar -hadoop-client-2.6.5.jar -hadoop-common-2.6.5.jar -hadoop-hdfs-2.6.5.jar -hadoop-mapreduce-client-app-2.6.5.jar -hadoop-mapreduce-client-common-2.6.5.jar -hadoop-mapreduce-client-core-2.6.5.jar -hadoop-mapreduce-client-jobclient-2.6.5.jar -hadoop-mapreduce-client-shuffle-2.6.5.jar -hadoop-yarn-api-2.6.5.jar -hadoop-yarn-client-2.6.5.jar -hadoop-yarn-common-2.6.5.jar -hadoop-yarn-server-common-2.6.5.jar -hadoop-yarn-server-web-proxy-2.6.5.jar -hk2-api-2.4.0-b34.jar -hk2-locator-2.4.0-b34.jar -hk2-utils-2.4.0-b34.jar -hppc-0.7.2.jar -htrace-core-3.0.4.jar -httpclient-4.5.6.jar -httpcore-4.4.10.jar -ivy-2.4.0.jar -jackson-annotations-2.9.6.jar -jackson-core-2.9.6.jar -jackson-core-asl-1.9.13.jar -jackson-databind-2.9.6.jar -jackson-dataformat-yaml-2.9.6.jar -jackson-jaxrs-1.9.13.jar -jackson-mapper-asl-1.9.13.jar -jackson-module-jaxb-annotations-2.9.6.jar -jackson-module-paranamer-2.9.6.jar -jackson-module-scala_2.11-2.9.6.jar -jackson-xc-1.9.13.jar -janino-3.0.10.jar -javassist-3.18.1-GA.jar -javax.annotation-api-1.2.jar -javax.inject-1.jar -javax.inject-2.4.0-b34.jar -javax.servlet-api-3.1.0.jar -javax.ws.rs-api-2.0.1.jar -javolution-5.5.1.jar -jaxb-api-2.2.2.jar -jcl-over-slf4j-1.7.16.jar -jdo-api-3.0.1.jar -jersey-client-2.22.2.jar -jersey-common-2.22.2.jar -jersey-container-servlet-2.22.2.jar -jersey-container-servlet-core-2.22.2.jar -jersey-guava-2.22.2.jar -jersey-media-jaxb-2.22.2.jar -jersey-server-2.22.2.jar -jetty-6.1.26.jar -jetty-util-6.1.26.jar -jline-2.14.6.jar -joda-time-2.9.3.jar -jodd-core-3.5.2.jar -jpam-1.1.jar -json4s-ast_2.11-3.5.3.jar -json4s-core_2.11-3.5.3.jar -json4s-jackson_2.11-3.5.3.jar -json4s-scalap_2.11-3.5.3.jar -jsr305-1.3.9.jar -jta-1.1.jar -jtransforms-2.4.0.jar -jul-to-slf4j-1.7.16.jar -kryo-shaded-4.0.2.jar -kubernetes-client-3.0.0.jar -kubernetes-model-2.0.0.jar -leveldbjni-all-1.8.jar -libfb303-0.9.3.jar -libthrift-0.9.3.jar -log4j-1.2.17.jar -logging-interceptor-3.8.1.jar -lz4-java-1.5.0.jar -machinist_2.11-0.6.1.jar -macro-compat_2.11-1.1.1.jar -mesos-1.4.0-shaded-protobuf.jar -metrics-core-3.1.5.jar -metrics-graphite-3.1.5.jar -metrics-json-3.1.5.jar -metrics-jvm-3.1.5.jar -minlog-1.3.0.jar -netty-3.9.9.Final.jar -netty-all-4.1.17.Final.jar -objenesis-2.5.1.jar -okhttp-3.8.1.jar -okio-1.13.0.jar -opencsv-2.3.jar -orc-core-1.5.3-nohive.jar -orc-mapreduce-1.5.3-nohive.jar -orc-shims-1.5.3.jar -oro-2.0.8.jar -osgi-resource-locator-1.0.1.jar -paranamer-2.8.jar -parquet-column-1.10.0.jar -parquet-common-1.10.0.jar -parquet-encoding-1.10.0.jar -parquet-format-2.4.0.jar -parquet-hadoop-1.10.0.jar -parquet-hadoop-bundle-1.6.0.jar -parquet-jackson-1.10.0.jar -protobuf-java-2.5.0.jar -py4j-0.10.7.jar -pyrolite-4.13.jar -scala-compiler-2.11.12.jar -scala-library-2.11.12.jar -scala-parser-combinators_2.11-1.1.0.jar -scala-reflect-2.11.12.jar -scala-xml_2.11-1.0.5.jar -shapeless_2.11-2.3.2.jar -slf4j-api-1.7.16.jar -slf4j-log4j12-1.7.16.jar -snakeyaml-1.18.jar -snappy-0.2.jar -snappy-java-1.1.7.1.jar -spire-macros_2.11-0.13.0.jar -spire_2.11-0.13.0.jar -stax-api-1.0-2.jar -stax-api-1.0.1.jar -stream-2.7.0.jar -stringtemplate-3.2.1.jar -super-csv-2.2.0.jar -univocity-parsers-2.7.3.jar -validation-api-1.1.0.Final.jar -xbean-asm6-shaded-4.8.jar -xercesImpl-2.9.1.jar -xmlenc-0.52.jar -xz-1.5.jar -zjsonpatch-0.3.0.jar -zookeeper-3.4.6.jar -zstd-jni-1.3.2-2.jar diff --git a/dev/run-tests.py b/dev/run-tests.py index f534637b80d6b..271360b6048a3 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -305,7 +305,6 @@ def get_hadoop_profiles(hadoop_version): """ sbt_maven_hadoop_profiles = { - "hadoop2.6": ["-Phadoop-2.6"], "hadoop2.7": ["-Phadoop-2.7"], } @@ -369,15 +368,7 @@ def build_spark_assembly_sbt(hadoop_version, checkstyle=False): if checkstyle: run_java_style_checks() - # Note that we skip Unidoc build only if Hadoop 2.6 is explicitly set in this SBT build. - # Due to a different dependency resolution in SBT & Unidoc by an unknown reason, the - # documentation build fails on a specific machine & environment in Jenkins but it was unable - # to reproduce. Please see SPARK-20343. This is a band-aid fix that should be removed in - # the future. - is_hadoop_version_2_6 = os.environ.get("AMPLAB_JENKINS_BUILD_PROFILE") == "hadoop2.6" - if not is_hadoop_version_2_6: - # Make sure that Java and Scala API documentation can be generated - build_spark_unidoc_sbt(hadoop_version) + build_spark_unidoc_sbt(hadoop_version) def build_apache_spark(build_tool, hadoop_version): @@ -528,14 +519,14 @@ def main(): # if we're on the Amplab Jenkins build servers setup variables # to reflect the environment settings build_tool = os.environ.get("AMPLAB_JENKINS_BUILD_TOOL", "sbt") - hadoop_version = os.environ.get("AMPLAB_JENKINS_BUILD_PROFILE", "hadoop2.6") + hadoop_version = os.environ.get("AMPLAB_JENKINS_BUILD_PROFILE", "hadoop2.7") test_env = "amplab_jenkins" # add path for Python3 in Jenkins if we're calling from a Jenkins machine os.environ["PATH"] = "/home/anaconda/envs/py3k/bin:" + os.environ.get("PATH") else: # else we're running locally and can use local settings build_tool = "sbt" - hadoop_version = os.environ.get("HADOOP_PROFILE", "hadoop2.6") + hadoop_version = os.environ.get("HADOOP_PROFILE", "hadoop2.7") test_env = "local" print("[info] Using build tool", build_tool, "with Hadoop profile", hadoop_version, diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 2fbd6b5e98f7f..a3627c9b9b0a7 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -32,7 +32,6 @@ export LC_ALL=C HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pkubernetes -Pyarn -Pflume -Phive" MVN="build/mvn" HADOOP_PROFILES=( - hadoop-2.6 hadoop-2.7 hadoop-3.1 ) diff --git a/docs/building-spark.md b/docs/building-spark.md index 1501f0bb84544..b9e171547c3c0 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -49,25 +49,20 @@ To create a Spark distribution like those distributed by the to be runnable, use `./dev/make-distribution.sh` in the project root directory. It can be configured with Maven profile settings and so on like the direct Maven build. Example: - ./dev/make-distribution.sh --name custom-spark --pip --r --tgz -Psparkr -Phadoop-2.7 -Phive -Phive-thriftserver -Pmesos -Pyarn -Pkubernetes + ./dev/make-distribution.sh --name custom-spark --pip --r --tgz -Psparkr -Phive -Phive-thriftserver -Pmesos -Pyarn -Pkubernetes This will build Spark distribution along with Python pip and R packages. For more information on usage, run `./dev/make-distribution.sh --help` ## Specifying the Hadoop Version and Enabling YARN You can specify the exact version of Hadoop to compile against through the `hadoop.version` property. -If unset, Spark will build against Hadoop 2.6.X by default. You can enable the `yarn` profile and optionally set the `yarn.version` property if it is different from `hadoop.version`. -Examples: +Example: - # Apache Hadoop 2.6.X - ./build/mvn -Pyarn -DskipTests clean package - - # Apache Hadoop 2.7.X and later - ./build/mvn -Pyarn -Phadoop-2.7 -Dhadoop.version=2.7.3 -DskipTests clean package + ./build/mvn -Pyarn -Dhadoop.version=2.8.5 -DskipTests clean package ## Building With Hive and JDBC Support diff --git a/docs/index.md b/docs/index.md index 40f628b794c01..d269f54c73439 100644 --- a/docs/index.md +++ b/docs/index.md @@ -30,9 +30,6 @@ Spark runs on Java 8+, Python 2.7+/3.4+ and R 3.1+. For the Scala API, Spark {{s uses Scala {{site.SCALA_BINARY_VERSION}}. You will need to use a compatible Scala version ({{site.SCALA_BINARY_VERSION}}.x). -Note that support for Java 7, Python 2.6 and old Hadoop versions before 2.6.5 were removed as of Spark 2.2.0. -Support for Scala 2.10 was removed as of 2.3.0. - # Running the Examples and Shell Spark comes with several sample programs. Scala, Java, Python and R examples are in the diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 687f9e46c3285..bdf7b99966e4f 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -396,8 +396,7 @@ To use a custom metrics.properties for the application master and executors, upd and those log files will be aggregated in a rolling fashion. This will be used with YARN's rolling log aggregation, to enable this feature in YARN side yarn.nodemanager.log-aggregation.roll-monitoring-interval-seconds should be - configured in yarn-site.xml. - This feature can only be used with Hadoop 2.6.4+. The Spark log4j appender needs be changed to use + configured in yarn-site.xml. The Spark log4j appender needs be changed to use FileAppender or another appender that can handle the files being removed while it is running. Based on the file name configured in the log4j configuration (like spark.log), the user should set the regex (spark*) to include all the log files that need to be aggregated. diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index d48162007e675..3182ab15db5f5 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -166,45 +166,34 @@ httpcore ${hadoop.deps.scope} + + org.apache.hadoop + hadoop-azure + ${hadoop.version} + ${hadoop.deps.scope} + + + org.apache.hadoop + hadoop-common + + + org.codehaus.jackson + jackson-mapper-asl + + + com.fasterxml.jackson.core + jackson-core + + + com.google.guava + guava + + + - - hadoop-2.7 - - - - - - org.apache.hadoop - hadoop-azure - ${hadoop.version} - ${hadoop.deps.scope} - - - org.apache.hadoop - hadoop-common - - - org.codehaus.jackson - jackson-mapper-asl - - - com.fasterxml.jackson.core - jackson-core - - - com.google.guava - guava - - - - - - 1.2.1.spark2 @@ -2674,17 +2674,9 @@ http://hadoop.apache.org/docs/ra.b.c/hadoop-project-dist/hadoop-common/dependency-analysis.html --> - - hadoop-2.6 - - - hadoop-2.7 - - 2.7.3 - 2.7.1 - + diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh index b28b8b82ca016..e26c0b3a39c90 100755 --- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -103,4 +103,4 @@ then properties=( ${properties[@]} -Dtest.exclude.tags=$EXCLUDE_TAGS ) fi -$TEST_ROOT_DIR/build/mvn integration-test -f $TEST_ROOT_DIR/pom.xml -pl resource-managers/kubernetes/integration-tests -am -Pkubernetes -Phadoop-2.7 ${properties[@]} +$TEST_ROOT_DIR/build/mvn integration-test -f $TEST_ROOT_DIR/pom.xml -pl resource-managers/kubernetes/integration-tests -am -Pkubernetes ${properties[@]} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 4a85898ef880b..01bdebc000b9f 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -273,19 +273,10 @@ private[spark] class Client( sparkConf.get(ROLLED_LOG_INCLUDE_PATTERN).foreach { includePattern => try { val logAggregationContext = Records.newRecord(classOf[LogAggregationContext]) - - // These two methods were added in Hadoop 2.6.4, so we still need to use reflection to - // avoid compile error when building against Hadoop 2.6.0 ~ 2.6.3. - val setRolledLogsIncludePatternMethod = - logAggregationContext.getClass.getMethod("setRolledLogsIncludePattern", classOf[String]) - setRolledLogsIncludePatternMethod.invoke(logAggregationContext, includePattern) - + logAggregationContext.setRolledLogsIncludePattern(includePattern) sparkConf.get(ROLLED_LOG_EXCLUDE_PATTERN).foreach { excludePattern => - val setRolledLogsExcludePatternMethod = - logAggregationContext.getClass.getMethod("setRolledLogsExcludePattern", classOf[String]) - setRolledLogsExcludePatternMethod.invoke(logAggregationContext, excludePattern) + logAggregationContext.setRolledLogsExcludePattern(excludePattern) } - appContext.setLogAggregationContext(logAggregationContext) } catch { case NonFatal(e) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 7d57389947576..9443fbb4330a5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -71,7 +71,7 @@ class HadoopTableReader( // Hadoop honors "mapreduce.job.maps" as hint, // but will ignore when mapreduce.jobtracker.address is "local". - // https://hadoop.apache.org/docs/r2.6.5/hadoop-mapreduce-client/hadoop-mapreduce-client-core/ + // https://hadoop.apache.org/docs/r2.7.6/hadoop-mapreduce-client/hadoop-mapreduce-client-core/ // mapred-default.xml // // In order keep consistency with Hive, we will let it be 0 in local mode also. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 6a90c44a2633d..31899370454ba 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -53,7 +53,7 @@ private[hive] object IsolatedClientLoader extends Logging { sharesHadoopClasses: Boolean = true): IsolatedClientLoader = synchronized { val resolvedVersion = hiveVersion(hiveMetastoreVersion) // We will first try to share Hadoop classes. If we cannot resolve the Hadoop artifact - // with the given version, we will use Hadoop 2.6 and then will not share Hadoop classes. + // with the given version, we will use Hadoop 2.7 and then will not share Hadoop classes. var _sharesHadoopClasses = sharesHadoopClasses val files = if (resolvedVersions.contains((resolvedVersion, hadoopVersion))) { resolvedVersions((resolvedVersion, hadoopVersion)) @@ -65,13 +65,14 @@ private[hive] object IsolatedClientLoader extends Logging { case e: RuntimeException if e.getMessage.contains("hadoop") => // If the error message contains hadoop, it is probably because the hadoop // version cannot be resolved. - logWarning(s"Failed to resolve Hadoop artifacts for the version $hadoopVersion. " + - s"We will change the hadoop version from $hadoopVersion to 2.6.0 and try again. " + - "Hadoop classes will not be shared between Spark and Hive metastore client. " + + val fallbackVersion = "2.7.3" + logWarning(s"Failed to resolve Hadoop artifacts for the version $hadoopVersion. We " + + s"will change the hadoop version from $hadoopVersion to $fallbackVersion and try " + + "again. Hadoop classes will not be shared between Spark and Hive metastore client. " + "It is recommended to set jars used by Hive metastore client through " + "spark.sql.hive.metastore.jars in the production environment.") _sharesHadoopClasses = false - (downloadVersion(resolvedVersion, "2.6.5", ivyPath), "2.6.5") + (downloadVersion(resolvedVersion, fallbackVersion, ivyPath), fallbackVersion) } resolvedVersions.put((resolvedVersion, actualHadoopVersion), downloadedFiles) resolvedVersions((resolvedVersion, actualHadoopVersion)) From 83e19d5b80fac6ea4b29d8eb561a5ad06835171b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 11 Oct 2018 09:35:49 -0700 Subject: [PATCH 1812/2461] [SPARK-25700][SQL] Creates ReadSupport in only Append Mode in Data Source V2 write path ## What changes were proposed in this pull request? This PR proposes to avoid to make a readsupport and read schema when it writes in other save modes. https://github.com/apache/spark/commit/5fef6e3513d6023a837c427d183006d153c7102b happened to create a readsupport in write path, which ended up with reading schema from readsupport at write path. This breaks `spark.range(1).format("source").write.save("non-existent-path")` case since there's no way to read the schema from "non-existent-path". See also https://github.com/apache/spark/pull/22009#discussion_r223982672 See also https://github.com/apache/spark/pull/22697 See also http://apache-spark-developers-list.1001551.n3.nabble.com/Possible-bug-in-DatasourceV2-td25343.html ## How was this patch tested? Unit test and manual tests. Closes #22688 from HyukjinKwon/append-revert-2. Authored-by: hyukjinkwon Signed-off-by: Dongjoon Hyun --- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../sql/sources/v2/DataSourceV2Suite.scala | 29 +++++++++++++++++++ .../sources/v2/SimpleWritableDataSource.scala | 5 ++-- 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 188fce72efac5..55e538f49feda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -246,8 +246,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { df.sparkSession.sessionState.conf) val options = sessionOptions ++ extraOptions - val relation = DataSourceV2Relation.create(source, options) if (mode == SaveMode.Append) { + val relation = DataSourceV2Relation.create(source, options) runCommand(df.sparkSession, "save") { AppendData.byName(relation, df.logicalPlan) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 7cc8abc9f0428..e8f291af13baf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -351,6 +351,24 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-25700: do not read schema when writing in other modes except append mode") { + withTempPath { file => + val cls = classOf[SimpleWriteOnlyDataSource] + val path = file.getCanonicalPath + val df = spark.range(5).select('id as 'i, -'id as 'j) + try { + df.write.format(cls.getName).option("path", path).mode("error").save() + df.write.format(cls.getName).option("path", path).mode("overwrite").save() + df.write.format(cls.getName).option("path", path).mode("ignore").save() + } catch { + case e: SchemaReadAttemptException => fail("Schema read was attempted.", e) + } + intercept[SchemaReadAttemptException] { + df.write.format(cls.getName).option("path", path).mode("append").save() + } + } + } } @@ -640,3 +658,14 @@ object SpecificReaderFactory extends PartitionReaderFactory { } } } + +class SchemaReadAttemptException(m: String) extends RuntimeException(m) + +class SimpleWriteOnlyDataSource extends SimpleWritableDataSource { + override def fullSchema(): StructType = { + // This is a bit hacky since this source implements read support but throws + // during schema retrieval. Might have to rewrite but it's done + // such so for minimised changes. + throw new SchemaReadAttemptException("read is not supported") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index a0f4404f46140..a7dfc2d1deacc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -43,13 +43,13 @@ class SimpleWritableDataSource extends DataSourceV2 with BatchWriteSupportProvider with SessionConfigSupport { - private val schema = new StructType().add("i", "long").add("j", "long") + protected def fullSchema(): StructType = new StructType().add("i", "long").add("j", "long") override def keyPrefix: String = "simpleWritableDataSource" class ReadSupport(path: String, conf: Configuration) extends SimpleReadSupport { - override def fullSchema(): StructType = schema + override def fullSchema(): StructType = SimpleWritableDataSource.this.fullSchema() override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { val dataPath = new Path(path) @@ -116,7 +116,6 @@ class SimpleWritableDataSource extends DataSourceV2 schema: StructType, mode: SaveMode, options: DataSourceOptions): Optional[BatchWriteSupport] = { - assert(DataType.equalsStructurally(schema.asNullable, this.schema.asNullable)) assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false)) val path = new Path(options.get("path").get()) From 8115e6b26916c42491d712c06c73c045f4ee17e1 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 11 Oct 2018 20:27:07 +0000 Subject: [PATCH 1813/2461] [SPARK-25662][SQL][TEST] Refactor DataSourceReadBenchmark to use main method ## What changes were proposed in this pull request? 1. Refactor DataSourceReadBenchmark ## How was this patch tested? Manually tested and regenerated results. ``` SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain org.apache.spark.sql.execution.benchmark.DataSourceReadBenchmark" ``` Closes #22664 from peter-toth/SPARK-25662. Lead-authored-by: Peter Toth Co-authored-by: Dongjoon Hyun Signed-off-by: DB Tsai --- .../DataSourceReadBenchmark-results.txt | 269 ++++++++++++++++ .../benchmark/DataSourceReadBenchmark.scala | 300 +++--------------- 2 files changed, 316 insertions(+), 253 deletions(-) create mode 100644 sql/core/benchmarks/DataSourceReadBenchmark-results.txt diff --git a/sql/core/benchmarks/DataSourceReadBenchmark-results.txt b/sql/core/benchmarks/DataSourceReadBenchmark-results.txt new file mode 100644 index 0000000000000..2d3bae442cc50 --- /dev/null +++ b/sql/core/benchmarks/DataSourceReadBenchmark-results.txt @@ -0,0 +1,269 @@ +================================================================================================ +SQL Single Numeric Column Scan +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +SQL CSV 21508 / 22112 0.7 1367.5 1.0X +SQL Json 8705 / 8825 1.8 553.4 2.5X +SQL Parquet Vectorized 157 / 186 100.0 10.0 136.7X +SQL Parquet MR 1789 / 1794 8.8 113.8 12.0X +SQL ORC Vectorized 156 / 166 100.9 9.9 138.0X +SQL ORC Vectorized with copy 218 / 225 72.1 13.9 98.6X +SQL ORC MR 1448 / 1492 10.9 92.0 14.9X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Parquet Reader Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +ParquetReader Vectorized 202 / 211 77.7 12.9 1.0X +ParquetReader Vectorized -> Row 118 / 120 133.5 7.5 1.7X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +SQL CSV 23282 / 23312 0.7 1480.2 1.0X +SQL Json 9187 / 9189 1.7 584.1 2.5X +SQL Parquet Vectorized 204 / 218 77.0 13.0 114.0X +SQL Parquet MR 1941 / 1953 8.1 123.4 12.0X +SQL ORC Vectorized 217 / 225 72.6 13.8 107.5X +SQL ORC Vectorized with copy 279 / 289 56.3 17.8 83.4X +SQL ORC MR 1541 / 1549 10.2 98.0 15.1X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Parquet Reader Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +ParquetReader Vectorized 288 / 297 54.6 18.3 1.0X +ParquetReader Vectorized -> Row 255 / 257 61.7 16.2 1.1X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +SQL CSV 24990 / 25012 0.6 1588.8 1.0X +SQL Json 9837 / 9865 1.6 625.4 2.5X +SQL Parquet Vectorized 170 / 180 92.3 10.8 146.6X +SQL Parquet MR 2319 / 2328 6.8 147.4 10.8X +SQL ORC Vectorized 293 / 301 53.7 18.6 85.3X +SQL ORC Vectorized with copy 297 / 309 52.9 18.9 84.0X +SQL ORC MR 1667 / 1674 9.4 106.0 15.0X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Parquet Reader Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +ParquetReader Vectorized 257 / 274 61.3 16.3 1.0X +ParquetReader Vectorized -> Row 259 / 264 60.8 16.4 1.0X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +SQL CSV 32537 / 32554 0.5 2068.7 1.0X +SQL Json 12610 / 12668 1.2 801.7 2.6X +SQL Parquet Vectorized 258 / 276 61.0 16.4 126.2X +SQL Parquet MR 2422 / 2435 6.5 154.0 13.4X +SQL ORC Vectorized 378 / 385 41.6 24.0 86.2X +SQL ORC Vectorized with copy 381 / 389 41.3 24.2 85.4X +SQL ORC MR 1797 / 1819 8.8 114.3 18.1X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Parquet Reader Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +ParquetReader Vectorized 352 / 368 44.7 22.4 1.0X +ParquetReader Vectorized -> Row 351 / 359 44.8 22.3 1.0X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +SQL CSV 27179 / 27184 0.6 1728.0 1.0X +SQL Json 12578 / 12585 1.3 799.7 2.2X +SQL Parquet Vectorized 161 / 171 97.5 10.3 168.5X +SQL Parquet MR 2361 / 2395 6.7 150.1 11.5X +SQL ORC Vectorized 473 / 480 33.3 30.0 57.5X +SQL ORC Vectorized with copy 478 / 483 32.9 30.4 56.8X +SQL ORC MR 1858 / 1859 8.5 118.2 14.6X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Parquet Reader Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +ParquetReader Vectorized 251 / 255 62.7 15.9 1.0X +ParquetReader Vectorized -> Row 255 / 259 61.8 16.2 1.0X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +SQL CSV 34797 / 34830 0.5 2212.3 1.0X +SQL Json 17806 / 17828 0.9 1132.1 2.0X +SQL Parquet Vectorized 260 / 269 60.6 16.5 134.0X +SQL Parquet MR 2512 / 2534 6.3 159.7 13.9X +SQL ORC Vectorized 582 / 593 27.0 37.0 59.8X +SQL ORC Vectorized with copy 576 / 584 27.3 36.6 60.4X +SQL ORC MR 2309 / 2313 6.8 146.8 15.1X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Parquet Reader Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +ParquetReader Vectorized 350 / 363 44.9 22.3 1.0X +ParquetReader Vectorized -> Row 350 / 366 44.9 22.3 1.0X + + +================================================================================================ +Int and String Scan +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +SQL CSV 22486 / 22590 0.5 2144.5 1.0X +SQL Json 14124 / 14195 0.7 1347.0 1.6X +SQL Parquet Vectorized 2342 / 2347 4.5 223.4 9.6X +SQL Parquet MR 4660 / 4664 2.2 444.4 4.8X +SQL ORC Vectorized 2378 / 2379 4.4 226.8 9.5X +SQL ORC Vectorized with copy 2548 / 2571 4.1 243.0 8.8X +SQL ORC MR 4206 / 4211 2.5 401.1 5.3X + + +================================================================================================ +Repeated String Scan +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +SQL CSV 12150 / 12178 0.9 1158.7 1.0X +SQL Json 7012 / 7014 1.5 668.7 1.7X +SQL Parquet Vectorized 792 / 796 13.2 75.5 15.3X +SQL Parquet MR 1961 / 1975 5.3 187.0 6.2X +SQL ORC Vectorized 482 / 485 21.8 46.0 25.2X +SQL ORC Vectorized with copy 710 / 715 14.8 67.7 17.1X +SQL ORC MR 2081 / 2083 5.0 198.5 5.8X + + +================================================================================================ +Partitioned Table Scan +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Data column - CSV 31789 / 31791 0.5 2021.1 1.0X +Data column - Json 12873 / 12918 1.2 818.4 2.5X +Data column - Parquet Vectorized 267 / 280 58.9 17.0 119.1X +Data column - Parquet MR 3387 / 3402 4.6 215.3 9.4X +Data column - ORC Vectorized 391 / 453 40.2 24.9 81.2X +Data column - ORC Vectorized with copy 392 / 398 40.2 24.9 81.2X +Data column - ORC MR 2508 / 2512 6.3 159.4 12.7X +Partition column - CSV 6965 / 6977 2.3 442.8 4.6X +Partition column - Json 5563 / 5576 2.8 353.7 5.7X +Partition column - Parquet Vectorized 65 / 78 241.1 4.1 487.2X +Partition column - Parquet MR 1811 / 1811 8.7 115.1 17.6X +Partition column - ORC Vectorized 66 / 73 239.0 4.2 483.0X +Partition column - ORC Vectorized with copy 65 / 70 241.1 4.1 487.3X +Partition column - ORC MR 1775 / 1778 8.9 112.8 17.9X +Both columns - CSV 30032 / 30113 0.5 1909.4 1.1X +Both columns - Json 13941 / 13959 1.1 886.3 2.3X +Both columns - Parquet Vectorized 312 / 330 50.3 19.9 101.7X +Both columns - Parquet MR 3858 / 3862 4.1 245.3 8.2X +Both columns - ORC Vectorized 431 / 437 36.5 27.4 73.8X +Both column - ORC Vectorized with copy 523 / 529 30.1 33.3 60.7X +Both columns - ORC MR 2712 / 2805 5.8 172.4 11.7X + + +================================================================================================ +String with Nulls Scan +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +SQL CSV 13525 / 13823 0.8 1289.9 1.0X +SQL Json 9913 / 9921 1.1 945.3 1.4X +SQL Parquet Vectorized 1517 / 1517 6.9 144.7 8.9X +SQL Parquet MR 3996 / 4008 2.6 381.1 3.4X +ParquetReader Vectorized 1120 / 1128 9.4 106.8 12.1X +SQL ORC Vectorized 1203 / 1224 8.7 114.7 11.2X +SQL ORC Vectorized with copy 1639 / 1646 6.4 156.3 8.3X +SQL ORC MR 3720 / 3780 2.8 354.7 3.6X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +SQL CSV 15860 / 15877 0.7 1512.5 1.0X +SQL Json 7676 / 7688 1.4 732.0 2.1X +SQL Parquet Vectorized 1072 / 1084 9.8 102.2 14.8X +SQL Parquet MR 2890 / 2897 3.6 275.6 5.5X +ParquetReader Vectorized 1052 / 1053 10.0 100.4 15.1X +SQL ORC Vectorized 1248 / 1248 8.4 119.0 12.7X +SQL ORC Vectorized with copy 1627 / 1637 6.4 155.2 9.7X +SQL ORC MR 3365 / 3369 3.1 320.9 4.7X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +SQL CSV 13401 / 13561 0.8 1278.1 1.0X +SQL Json 5253 / 5303 2.0 500.9 2.6X +SQL Parquet Vectorized 233 / 242 45.0 22.2 57.6X +SQL Parquet MR 1791 / 1796 5.9 170.8 7.5X +ParquetReader Vectorized 236 / 238 44.4 22.5 56.7X +SQL ORC Vectorized 453 / 473 23.2 43.2 29.6X +SQL ORC Vectorized with copy 573 / 577 18.3 54.7 23.4X +SQL ORC MR 1846 / 1850 5.7 176.0 7.3X + + +================================================================================================ +Single Column Scan From Wide Columns +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Single Column Scan from 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +SQL CSV 3147 / 3148 0.3 3001.1 1.0X +SQL Json 2666 / 2693 0.4 2542.9 1.2X +SQL Parquet Vectorized 54 / 58 19.5 51.3 58.5X +SQL Parquet MR 220 / 353 4.8 209.9 14.3X +SQL ORC Vectorized 63 / 77 16.8 59.7 50.3X +SQL ORC Vectorized with copy 63 / 66 16.7 59.8 50.2X +SQL ORC MR 317 / 321 3.3 302.2 9.9X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Single Column Scan from 50 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +SQL CSV 7902 / 7921 0.1 7536.2 1.0X +SQL Json 9467 / 9491 0.1 9028.6 0.8X +SQL Parquet Vectorized 73 / 79 14.3 69.8 108.0X +SQL Parquet MR 239 / 247 4.4 228.0 33.1X +SQL ORC Vectorized 78 / 84 13.4 74.6 101.0X +SQL ORC Vectorized with copy 78 / 88 13.4 74.4 101.3X +SQL ORC MR 910 / 918 1.2 867.6 8.7X + +OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +SQL CSV 13539 / 13543 0.1 12912.0 1.0X +SQL Json 17420 / 17446 0.1 16613.1 0.8X +SQL Parquet Vectorized 103 / 120 10.2 98.1 131.6X +SQL Parquet MR 250 / 258 4.2 238.9 54.1X +SQL ORC Vectorized 99 / 104 10.6 94.6 136.5X +SQL ORC Vectorized with copy 100 / 106 10.5 95.6 135.1X +SQL ORC MR 1653 / 1659 0.6 1576.3 8.2X + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala index 51a7f9f1ef096..a1e7f9e36f4b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import scala.util.Random import org.apache.spark.SparkConf -import org.apache.spark.benchmark.Benchmark +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.SQLHelper @@ -34,10 +34,16 @@ import org.apache.spark.sql.vectorized.ColumnVector /** * Benchmark to measure data source read performance. - * To run this: - * spark-submit --class + * To run this benchmark: + * {{{ + * 1. without sbt: bin/spark-submit --class + * --jars , + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/DataSourceReadBenchmark-results.txt". + * }}} */ -object DataSourceReadBenchmark extends SQLHelper { +object DataSourceReadBenchmark extends BenchmarkBase with SQLHelper { val conf = new SparkConf() .setAppName("DataSourceReadBenchmark") // Since `spark.master` always exists, overrides this value @@ -93,11 +99,16 @@ object DataSourceReadBenchmark extends SQLHelper { def numericScanBenchmark(values: Int, dataType: DataType): Unit = { // Benchmarks running through spark sql. - val sqlBenchmark = new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values) + val sqlBenchmark = new Benchmark( + s"SQL Single ${dataType.sql} Column Scan", + values, + output = output) // Benchmarks driving reader component directly. val parquetReaderBenchmark = new Benchmark( - s"Parquet Reader Single ${dataType.sql} Column Scan", values) + s"Parquet Reader Single ${dataType.sql} Column Scan", + values, + output = output) withTempPath { dir => withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { @@ -140,74 +151,6 @@ object DataSourceReadBenchmark extends SQLHelper { } } - /* - OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 - Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz - SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - SQL CSV 22964 / 23096 0.7 1460.0 1.0X - SQL Json 8469 / 8593 1.9 538.4 2.7X - SQL Parquet Vectorized 164 / 177 95.8 10.4 139.9X - SQL Parquet MR 1687 / 1706 9.3 107.2 13.6X - SQL ORC Vectorized 191 / 197 82.3 12.2 120.2X - SQL ORC Vectorized with copy 215 / 219 73.2 13.7 106.9X - SQL ORC MR 1392 / 1412 11.3 88.5 16.5X - - - SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - SQL CSV 24090 / 24097 0.7 1531.6 1.0X - SQL Json 8791 / 8813 1.8 558.9 2.7X - SQL Parquet Vectorized 204 / 212 77.0 13.0 117.9X - SQL Parquet MR 1813 / 1850 8.7 115.3 13.3X - SQL ORC Vectorized 226 / 230 69.7 14.4 106.7X - SQL ORC Vectorized with copy 295 / 298 53.3 18.8 81.6X - SQL ORC MR 1526 / 1549 10.3 97.1 15.8X - - - SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - SQL CSV 25637 / 25791 0.6 1629.9 1.0X - SQL Json 9532 / 9570 1.7 606.0 2.7X - SQL Parquet Vectorized 181 / 191 86.8 11.5 141.5X - SQL Parquet MR 2210 / 2227 7.1 140.5 11.6X - SQL ORC Vectorized 309 / 317 50.9 19.6 83.0X - SQL ORC Vectorized with copy 316 / 322 49.8 20.1 81.2X - SQL ORC MR 1650 / 1680 9.5 104.9 15.5X - - - SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - SQL CSV 31617 / 31764 0.5 2010.1 1.0X - SQL Json 12440 / 12451 1.3 790.9 2.5X - SQL Parquet Vectorized 284 / 315 55.4 18.0 111.4X - SQL Parquet MR 2382 / 2390 6.6 151.5 13.3X - SQL ORC Vectorized 398 / 403 39.5 25.3 79.5X - SQL ORC Vectorized with copy 410 / 413 38.3 26.1 77.1X - SQL ORC MR 1783 / 1813 8.8 113.4 17.7X - - - SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - SQL CSV 26679 / 26742 0.6 1696.2 1.0X - SQL Json 12490 / 12541 1.3 794.1 2.1X - SQL Parquet Vectorized 174 / 183 90.4 11.1 153.3X - SQL Parquet MR 2201 / 2223 7.1 140.0 12.1X - SQL ORC Vectorized 415 / 429 37.9 26.4 64.3X - SQL ORC Vectorized with copy 422 / 428 37.2 26.9 63.2X - SQL ORC MR 1767 / 1773 8.9 112.3 15.1X - - - SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - SQL CSV 34223 / 34324 0.5 2175.8 1.0X - SQL Json 17784 / 17785 0.9 1130.7 1.9X - SQL Parquet Vectorized 277 / 283 56.7 17.6 123.4X - SQL Parquet MR 2356 / 2386 6.7 149.8 14.5X - SQL ORC Vectorized 533 / 536 29.5 33.9 64.2X - SQL ORC Vectorized with copy 541 / 546 29.1 34.4 63.3X - SQL ORC MR 2166 / 2177 7.3 137.7 15.8X - */ sqlBenchmark.run() // Driving the parquet reader in batch mode directly. @@ -279,51 +222,13 @@ object DataSourceReadBenchmark extends SQLHelper { } } - /* - OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 - Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz - Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 198 / 202 79.4 12.6 1.0X - ParquetReader Vectorized -> Row 119 / 121 132.3 7.6 1.7X - - - Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 282 / 287 55.8 17.9 1.0X - ParquetReader Vectorized -> Row 246 / 247 64.0 15.6 1.1X - - - Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 258 / 262 60.9 16.4 1.0X - ParquetReader Vectorized -> Row 259 / 260 60.8 16.5 1.0X - - - Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 361 / 369 43.6 23.0 1.0X - ParquetReader Vectorized -> Row 361 / 371 43.6 22.9 1.0X - - - Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 253 / 261 62.2 16.1 1.0X - ParquetReader Vectorized -> Row 254 / 256 61.9 16.2 1.0X - - - Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 357 / 364 44.0 22.7 1.0X - ParquetReader Vectorized -> Row 358 / 366 44.0 22.7 1.0X - */ parquetReaderBenchmark.run() } } } def intStringScanBenchmark(values: Int): Unit = { - val benchmark = new Benchmark("Int and String Scan", values) + val benchmark = new Benchmark("Int and String Scan", values, output = output) withTempPath { dir => withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { @@ -368,26 +273,13 @@ object DataSourceReadBenchmark extends SQLHelper { } } - /* - OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 - Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz - Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - SQL CSV 27145 / 27158 0.4 2588.7 1.0X - SQL Json 12969 / 13337 0.8 1236.8 2.1X - SQL Parquet Vectorized 2419 / 2448 4.3 230.7 11.2X - SQL Parquet MR 4631 / 4633 2.3 441.7 5.9X - SQL ORC Vectorized 2412 / 2465 4.3 230.0 11.3X - SQL ORC Vectorized with copy 2633 / 2675 4.0 251.1 10.3X - SQL ORC MR 4280 / 4350 2.4 408.2 6.3X - */ benchmark.run() } } } def repeatedStringScanBenchmark(values: Int): Unit = { - val benchmark = new Benchmark("Repeated String", values) + val benchmark = new Benchmark("Repeated String", values, output = output) withTempPath { dir => withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { @@ -432,26 +324,13 @@ object DataSourceReadBenchmark extends SQLHelper { } } - /* - OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 - Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz - Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - SQL CSV 17345 / 17424 0.6 1654.1 1.0X - SQL Json 8639 / 8664 1.2 823.9 2.0X - SQL Parquet Vectorized 839 / 854 12.5 80.0 20.7X - SQL Parquet MR 1771 / 1775 5.9 168.9 9.8X - SQL ORC Vectorized 550 / 569 19.1 52.4 31.6X - SQL ORC Vectorized with copy 785 / 849 13.4 74.9 22.1X - SQL ORC MR 2168 / 2202 4.8 206.7 8.0X - */ benchmark.run() } } } def partitionTableScanBenchmark(values: Int): Unit = { - val benchmark = new Benchmark("Partitioned Table", values) + val benchmark = new Benchmark("Partitioned Table", values, output = output) withTempPath { dir => withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { @@ -562,40 +441,13 @@ object DataSourceReadBenchmark extends SQLHelper { } } - /* - OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 - Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz - Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - Data column - CSV 32613 / 32841 0.5 2073.4 1.0X - Data column - Json 13343 / 13469 1.2 848.3 2.4X - Data column - Parquet Vectorized 302 / 318 52.1 19.2 108.0X - Data column - Parquet MR 2908 / 2924 5.4 184.9 11.2X - Data column - ORC Vectorized 412 / 425 38.1 26.2 79.1X - Data column - ORC Vectorized with copy 442 / 446 35.6 28.1 73.8X - Data column - ORC MR 2390 / 2396 6.6 152.0 13.6X - Partition column - CSV 9626 / 9683 1.6 612.0 3.4X - Partition column - Json 10909 / 10923 1.4 693.6 3.0X - Partition column - Parquet Vectorized 69 / 76 228.4 4.4 473.6X - Partition column - Parquet MR 1898 / 1933 8.3 120.7 17.2X - Partition column - ORC Vectorized 67 / 74 236.0 4.2 489.4X - Partition column - ORC Vectorized with copy 65 / 72 241.9 4.1 501.6X - Partition column - ORC MR 1743 / 1749 9.0 110.8 18.7X - Both columns - CSV 35523 / 35552 0.4 2258.5 0.9X - Both columns - Json 13676 / 13681 1.2 869.5 2.4X - Both columns - Parquet Vectorized 317 / 326 49.5 20.2 102.7X - Both columns - Parquet MR 3333 / 3336 4.7 211.9 9.8X - Both columns - ORC Vectorized 441 / 446 35.6 28.1 73.9X - Both column - ORC Vectorized with copy 517 / 524 30.4 32.9 63.1X - Both columns - ORC MR 2574 / 2577 6.1 163.6 12.7X - */ benchmark.run() } } } def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { - val benchmark = new Benchmark("String with Nulls Scan", values) + val benchmark = new Benchmark("String with Nulls Scan", values, output = output) withTempPath { dir => withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { @@ -673,51 +525,16 @@ object DataSourceReadBenchmark extends SQLHelper { } } - /* - OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 - Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz - String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - SQL CSV 14875 / 14920 0.7 1418.6 1.0X - SQL Json 10974 / 10992 1.0 1046.5 1.4X - SQL Parquet Vectorized 1711 / 1750 6.1 163.2 8.7X - SQL Parquet MR 3838 / 3884 2.7 366.0 3.9X - ParquetReader Vectorized 1155 / 1168 9.1 110.2 12.9X - SQL ORC Vectorized 1341 / 1380 7.8 127.9 11.1X - SQL ORC Vectorized with copy 1659 / 1716 6.3 158.2 9.0X - SQL ORC MR 3594 / 3634 2.9 342.7 4.1X - - - String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - SQL CSV 17219 / 17264 0.6 1642.1 1.0X - SQL Json 8843 / 8864 1.2 843.3 1.9X - SQL Parquet Vectorized 1169 / 1178 9.0 111.4 14.7X - SQL Parquet MR 2676 / 2697 3.9 255.2 6.4X - ParquetReader Vectorized 1068 / 1071 9.8 101.8 16.1X - SQL ORC Vectorized 1319 / 1319 7.9 125.8 13.1X - SQL ORC Vectorized with copy 1638 / 1639 6.4 156.2 10.5X - SQL ORC MR 3230 / 3257 3.2 308.1 5.3X - - - String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - SQL CSV 13976 / 14053 0.8 1332.8 1.0X - SQL Json 5166 / 5176 2.0 492.6 2.7X - SQL Parquet Vectorized 274 / 282 38.2 26.2 50.9X - SQL Parquet MR 1553 / 1555 6.8 148.1 9.0X - ParquetReader Vectorized 241 / 246 43.5 23.0 57.9X - SQL ORC Vectorized 476 / 479 22.0 45.4 29.3X - SQL ORC Vectorized with copy 584 / 588 17.9 55.7 23.9X - SQL ORC MR 1720 / 1734 6.1 164.1 8.1X - */ benchmark.run() } } } def columnsBenchmark(values: Int, width: Int): Unit = { - val benchmark = new Benchmark(s"Single Column Scan from $width columns", values) + val benchmark = new Benchmark( + s"Single Column Scan from $width columns", + values, + output = output) withTempPath { dir => withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { @@ -763,58 +580,35 @@ object DataSourceReadBenchmark extends SQLHelper { } } - /* - OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 - Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz - Single Column Scan from 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - SQL CSV 3478 / 3481 0.3 3316.4 1.0X - SQL Json 2646 / 2654 0.4 2523.6 1.3X - SQL Parquet Vectorized 67 / 72 15.8 63.5 52.2X - SQL Parquet MR 207 / 214 5.1 197.6 16.8X - SQL ORC Vectorized 69 / 76 15.2 66.0 50.3X - SQL ORC Vectorized with copy 70 / 76 15.0 66.5 49.9X - SQL ORC MR 299 / 303 3.5 285.1 11.6X - - - Single Column Scan from 50 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - SQL CSV 9214 / 9236 0.1 8786.7 1.0X - SQL Json 9943 / 9978 0.1 9482.7 0.9X - SQL Parquet Vectorized 77 / 86 13.6 73.3 119.8X - SQL Parquet MR 229 / 235 4.6 218.6 40.2X - SQL ORC Vectorized 84 / 96 12.5 80.0 109.9X - SQL ORC Vectorized with copy 83 / 91 12.6 79.4 110.7X - SQL ORC MR 843 / 854 1.2 804.0 10.9X - - - Single Column Scan from 100 columns Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - SQL CSV 16503 / 16622 0.1 15738.9 1.0X - SQL Json 19109 / 19184 0.1 18224.2 0.9X - SQL Parquet Vectorized 99 / 108 10.6 94.3 166.8X - SQL Parquet MR 253 / 264 4.1 241.6 65.1X - SQL ORC Vectorized 107 / 114 9.8 101.6 154.8X - SQL ORC Vectorized with copy 107 / 118 9.8 102.1 154.1X - SQL ORC MR 1526 / 1529 0.7 1455.3 10.8X - */ benchmark.run() } } } - def main(args: Array[String]): Unit = { - Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { dataType => - numericScanBenchmark(1024 * 1024 * 15, dataType) + override def runBenchmarkSuite(): Unit = { + runBenchmark("SQL Single Numeric Column Scan") { + Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { + dataType => numericScanBenchmark(1024 * 1024 * 15, dataType) + } + } + runBenchmark("Int and String Scan") { + intStringScanBenchmark(1024 * 1024 * 10) } - intStringScanBenchmark(1024 * 1024 * 10) - repeatedStringScanBenchmark(1024 * 1024 * 10) - partitionTableScanBenchmark(1024 * 1024 * 15) - for (fractionOfNulls <- List(0.0, 0.50, 0.95)) { - stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls) + runBenchmark("Repeated String Scan") { + repeatedStringScanBenchmark(1024 * 1024 * 10) } - for (columnWidth <- List(10, 50, 100)) { - columnsBenchmark(1024 * 1024 * 1, columnWidth) + runBenchmark("Partitioned Table Scan") { + partitionTableScanBenchmark(1024 * 1024 * 15) + } + runBenchmark("String with Nulls Scan") { + for (fractionOfNulls <- List(0.0, 0.50, 0.95)) { + stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls) + } + } + runBenchmark("Single Column Scan From Wide Columns") { + for (columnWidth <- List(10, 50, 100)) { + columnsBenchmark(1024 * 1024 * 1, columnWidth) + } } } } From 65f75db61176d711a120f9c50c617844811274dc Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Thu, 11 Oct 2018 14:03:41 -0700 Subject: [PATCH 1814/2461] [MINOR][SQL] remove Redundant semicolons MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? remove Redundant semicolons in SortMergeJoinExec, thanks. ## How was this patch tested? N/A Closes #22695 from heary-cao/RedundantSemicolons. Authored-by: caoxuewen Signed-off-by: Sean Owen --- .../apache/spark/sql/execution/joins/SortMergeJoinExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index f4b9d132122e4..d7d3f6d6078b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -494,7 +494,7 @@ case class SortMergeJoinExec( | $leftRow = null; | } else { | $matches.add((UnsafeRow) $rightRow); - | $rightRow = null;; + | $rightRow = null; | } | } while ($leftRow != null); | } From 1bb63ae5127609bd71748450c7c99287f98c72c8 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 11 Oct 2018 14:04:44 -0700 Subject: [PATCH 1815/2461] [SPARK-24109][CORE] Remove class SnappyOutputStreamWrapper ## What changes were proposed in this pull request? Remove SnappyOutputStreamWrapper and other workaround now that new Snappy fixes these. See also https://github.com/apache/spark/pull/21176 and comments it links to. ## How was this patch tested? Existing tests Closes #22691 from srowen/SPARK-24109. Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../apache/spark/io/CompressionCodec.scala | 63 ++----------------- project/MimaExcludes.scala | 1 + 2 files changed, 6 insertions(+), 58 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 7722db56ee297..0664c5ac752c1 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -154,72 +154,19 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { */ @DeveloperApi class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { - val version = SnappyCompressionCodec.version - override def compressedOutputStream(s: OutputStream): OutputStream = { - val blockSize = conf.getSizeAsBytes("spark.io.compression.snappy.blockSize", "32k").toInt - new SnappyOutputStreamWrapper(new SnappyOutputStream(s, blockSize)) - } - - override def compressedInputStream(s: InputStream): InputStream = new SnappyInputStream(s) -} - -/** - * Object guards against memory leak bug in snappy-java library: - * (https://github.com/xerial/snappy-java/issues/131). - * Before a new version of the library, we only call the method once and cache the result. - */ -private final object SnappyCompressionCodec { - private lazy val version: String = try { + try { Snappy.getNativeLibraryVersion } catch { case e: Error => throw new IllegalArgumentException(e) } -} -/** - * Wrapper over `SnappyOutputStream` which guards against write-after-close and double-close - * issues. See SPARK-7660 for more details. This wrapping can be removed if we upgrade to a version - * of snappy-java that contains the fix for https://github.com/xerial/snappy-java/issues/107. - */ -private final class SnappyOutputStreamWrapper(os: SnappyOutputStream) extends OutputStream { - - private[this] var closed: Boolean = false - - override def write(b: Int): Unit = { - if (closed) { - throw new IOException("Stream is closed") - } - os.write(b) - } - - override def write(b: Array[Byte]): Unit = { - if (closed) { - throw new IOException("Stream is closed") - } - os.write(b) - } - - override def write(b: Array[Byte], off: Int, len: Int): Unit = { - if (closed) { - throw new IOException("Stream is closed") - } - os.write(b, off, len) - } - - override def flush(): Unit = { - if (closed) { - throw new IOException("Stream is closed") - } - os.flush() + override def compressedOutputStream(s: OutputStream): OutputStream = { + val blockSize = conf.getSizeAsBytes("spark.io.compression.snappy.blockSize", "32k").toInt + new SnappyOutputStream(s, blockSize) } - override def close(): Unit = { - if (!closed) { - closed = true - os.close() - } - } + override def compressedInputStream(s: InputStream): InputStream = new SnappyInputStream(s) } /** diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 0b074fbf64eda..bf85fe0b4512c 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,7 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.io.SnappyCompressionCodec.version") ) // Exclude rules for 2.4.x From adf648b5be0e2479074e8d822e3563dc18f13586 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 11 Oct 2018 14:10:07 -0700 Subject: [PATCH 1816/2461] [SPARK-25615][SQL][TEST] Improve the test runtime of KafkaSinkSuite: streaming write to non-existing topic ## What changes were proposed in this pull request? Specify `kafka.max.block.ms` to 10 seconds while creating the kafka writer. In the absence of this overridden config, by default it uses a default time out of 60 seconds. With this change the test completes in close to 10 seconds as opposed to 1 minute. ## How was this patch tested? This is a test fix. Closes #22671 from dilipbiswal/SPARK-25615. Authored-by: Dilip Biswal Signed-off-by: Sean Owen --- .../scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 81832fbdcd7ec..d46c4139011da 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -427,6 +427,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext with KafkaTest { .format("kafka") .option("checkpointLocation", checkpointDir.getCanonicalPath) .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.max.block.ms", "5000") .queryName("kafkaStream") withTopic.foreach(stream.option("topic", _)) withOutputMode.foreach(stream.outputMode(_)) From 69f5e9cce14632a1f912c3632243a4e20b275365 Mon Sep 17 00:00:00 2001 From: liuxian Date: Thu, 11 Oct 2018 14:24:15 -0700 Subject: [PATCH 1817/2461] [SPARK-25674][SQL] If the records are incremented by more than 1 at a time,the number of bytes might rarely ever get updated MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? If the records are incremented by more than 1 at a time,the number of bytes might rarely ever get updated,because it might skip over the count that is an exact multiple of UPDATE_INPUT_METRICS_INTERVAL_RECORDS. This PR just checks whether the increment causes the value to exceed a higher multiple of UPDATE_INPUT_METRICS_INTERVAL_RECORDS. ## How was this patch tested? existed unit tests Closes #22594 from 10110346/inputMetrics. Authored-by: liuxian Signed-off-by: Sean Owen --- .../apache/spark/sql/execution/datasources/FileScanRDD.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 345c9d82ca0e7..dd3c154259c73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -104,12 +104,15 @@ class FileScanRDD( val nextElement = currentIterator.next() // TODO: we should have a better separation of row based and batch based scan, so that we // don't need to run this `if` for every record. + val preNumRecordsRead = inputMetrics.recordsRead if (nextElement.isInstanceOf[ColumnarBatch]) { inputMetrics.incRecordsRead(nextElement.asInstanceOf[ColumnarBatch].numRows()) } else { inputMetrics.incRecordsRead(1) } - if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { + // The records may be incremented by more than 1 at a time. + if (preNumRecordsRead / SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS != + inputMetrics.recordsRead / SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS) { updateBytesRead() } nextElement From a00181418911307725524641254439712e95445b Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 11 Oct 2018 14:28:06 -0700 Subject: [PATCH 1818/2461] [SPARK-25598][STREAMING][BUILD][TEST-MAVEN] Remove flume connector in Spark 3 ## What changes were proposed in this pull request? Removes all vestiges of Flume in the build, for Spark 3. I don't think this needs Jenkins config changes. ## How was this patch tested? Existing tests. Closes #22692 from srowen/SPARK-25598. Authored-by: Sean Owen Signed-off-by: Sean Owen --- dev/create-release/release-build.sh | 2 +- dev/mima | 2 +- dev/run-tests.py | 1 - dev/sbt-checkstyle | 1 - dev/scalastyle | 1 - dev/sparktestsupport/modules.py | 52 --- dev/test-dependencies.sh | 2 +- docs/building-spark.md | 7 - docs/streaming-custom-receivers.md | 2 +- docs/streaming-flume-integration.md | 169 ---------- docs/streaming-programming-guide.md | 29 +- .../main/python/streaming/flume_wordcount.py | 56 ---- external/flume-assembly/pom.xml | 167 ---------- external/flume-sink/pom.xml | 140 -------- .../flume-sink/src/main/avro/sparkflume.avdl | 40 --- .../spark/streaming/flume/sink/Logging.scala | 127 ------- .../flume/sink/SparkAvroCallbackHandler.scala | 166 ---------- .../streaming/flume/sink/SparkSink.scala | 171 ---------- .../flume/sink/SparkSinkThreadFactory.scala | 35 -- .../streaming/flume/sink/SparkSinkUtils.scala | 28 -- .../flume/sink/TransactionProcessor.scala | 252 -------------- .../src/test/resources/log4j.properties | 28 -- .../streaming/flume/sink/SparkSinkSuite.scala | 218 ------------ external/flume/pom.xml | 89 ----- .../spark/examples/JavaFlumeEventCount.java | 67 ---- .../spark/examples/FlumeEventCount.scala | 68 ---- .../examples/FlumePollingEventCount.scala | 65 ---- .../streaming/flume/EventTransformer.scala | 72 ---- .../streaming/flume/FlumeBatchFetcher.scala | 166 ---------- .../streaming/flume/FlumeInputDStream.scala | 208 ------------ .../flume/FlumePollingInputDStream.scala | 123 ------- .../streaming/flume/FlumeTestUtils.scala | 117 ------- .../spark/streaming/flume/FlumeUtils.scala | 312 ------------------ .../flume/PollingFlumeTestUtils.scala | 209 ------------ .../spark/streaming/flume/package-info.java | 21 -- .../spark/streaming/flume/package.scala | 23 -- .../streaming/LocalJavaStreamingContext.java | 44 --- .../flume/JavaFlumePollingStreamSuite.java | 44 --- .../streaming/flume/JavaFlumeStreamSuite.java | 37 --- .../flume/src/test/resources/log4j.properties | 28 -- .../spark/streaming/TestOutputStream.scala | 48 --- .../flume/FlumePollingStreamSuite.scala | 149 --------- .../streaming/flume/FlumeStreamSuite.scala | 103 ------ .../kafka010/DirectKafkaInputDStream.scala | 1 - .../kafka/DirectKafkaInputDStream.scala | 1 - pom.xml | 54 --- project/SparkBuild.scala | 18 +- python/docs/pyspark.streaming.rst | 7 - python/pyspark/streaming/dstream.py | 2 +- python/pyspark/streaming/flume.py | 156 --------- python/pyspark/streaming/tests.py | 176 +--------- .../spark/streaming/StreamingContext.scala | 2 +- .../streaming/api/java/JavaDStream.scala | 2 +- .../spark/streaming/dstream/DStream.scala | 2 +- .../streaming/dstream/InputDStream.scala | 1 - 55 files changed, 27 insertions(+), 4084 deletions(-) delete mode 100644 docs/streaming-flume-integration.md delete mode 100644 examples/src/main/python/streaming/flume_wordcount.py delete mode 100644 external/flume-assembly/pom.xml delete mode 100644 external/flume-sink/pom.xml delete mode 100644 external/flume-sink/src/main/avro/sparkflume.avdl delete mode 100644 external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala delete mode 100644 external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala delete mode 100644 external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala delete mode 100644 external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala delete mode 100644 external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkUtils.scala delete mode 100644 external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala delete mode 100644 external/flume-sink/src/test/resources/log4j.properties delete mode 100644 external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala delete mode 100644 external/flume/pom.xml delete mode 100644 external/flume/src/main/java/org/apache/spark/examples/JavaFlumeEventCount.java delete mode 100644 external/flume/src/main/scala/org/apache/spark/examples/FlumeEventCount.scala delete mode 100644 external/flume/src/main/scala/org/apache/spark/examples/FlumePollingEventCount.scala delete mode 100644 external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala delete mode 100644 external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala delete mode 100644 external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala delete mode 100644 external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala delete mode 100644 external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala delete mode 100644 external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala delete mode 100644 external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala delete mode 100644 external/flume/src/main/scala/org/apache/spark/streaming/flume/package-info.java delete mode 100644 external/flume/src/main/scala/org/apache/spark/streaming/flume/package.scala delete mode 100644 external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java delete mode 100644 external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumePollingStreamSuite.java delete mode 100644 external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java delete mode 100644 external/flume/src/test/resources/log4j.properties delete mode 100644 external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala delete mode 100644 external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala delete mode 100644 external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala delete mode 100644 python/pyspark/streaming/flume.py diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 89593cfa0107a..b80f55de98e16 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -114,7 +114,7 @@ PUBLISH_SCALA_2_10=0 SCALA_2_10_PROFILES="-Pscala-2.10" SCALA_2_11_PROFILES= if [[ $SPARK_VERSION > "2.3" ]]; then - BASE_PROFILES="$BASE_PROFILES -Pkubernetes -Pflume" + BASE_PROFILES="$BASE_PROFILES -Pkubernetes" SCALA_2_11_PROFILES="-Pkafka-0-8" else PUBLISH_SCALA_2_10=1 diff --git a/dev/mima b/dev/mima index cd2694ff4d3de..a9ac8aff11eb6 100755 --- a/dev/mima +++ b/dev/mima @@ -24,7 +24,7 @@ set -e FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" -SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pkubernetes -Pyarn -Pflume -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" +SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pkubernetes -Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | tail -n1)" OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)" diff --git a/dev/run-tests.py b/dev/run-tests.py index 271360b6048a3..a125f5b07b9c7 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -333,7 +333,6 @@ def build_spark_sbt(hadoop_version): build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags sbt_goals = ["test:package", # Build test jars as some tests depend on them "streaming-kafka-0-8-assembly/assembly", - "streaming-flume-assembly/assembly", "streaming-kinesis-asl-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals diff --git a/dev/sbt-checkstyle b/dev/sbt-checkstyle index 8821a7c0e4ccf..1e825dbf78a12 100755 --- a/dev/sbt-checkstyle +++ b/dev/sbt-checkstyle @@ -26,7 +26,6 @@ ERRORS=$(echo -e "q\n" \ -Pkafka-0-8 \ -Pkubernetes \ -Pyarn \ - -Pflume \ -Phive \ -Phive-thriftserver \ checkstyle test:checkstyle \ diff --git a/dev/scalastyle b/dev/scalastyle index b0ad02523826c..0448e1dd74d1d 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -26,7 +26,6 @@ ERRORS=$(echo -e "q\n" \ -Pkafka-0-8 \ -Pkubernetes \ -Pyarn \ - -Pflume \ -Phive \ -Phive-thriftserver \ -Pspark-ganglia-lgpl \ diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index e7ac063e234e3..bd5f00916668f 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -283,56 +283,6 @@ def __hash__(self): ] ) -streaming_flume_sink = Module( - name="streaming-flume-sink", - dependencies=[streaming], - source_file_regexes=[ - "external/flume-sink", - ], - build_profile_flags=[ - "-Pflume", - ], - environ={ - "ENABLE_FLUME_TESTS": "1" - }, - sbt_test_goals=[ - "streaming-flume-sink/test", - ] -) - - -streaming_flume = Module( - name="streaming-flume", - dependencies=[streaming], - source_file_regexes=[ - "external/flume", - ], - build_profile_flags=[ - "-Pflume", - ], - environ={ - "ENABLE_FLUME_TESTS": "1" - }, - sbt_test_goals=[ - "streaming-flume/test", - ] -) - - -streaming_flume_assembly = Module( - name="streaming-flume-assembly", - dependencies=[streaming_flume, streaming_flume_sink], - source_file_regexes=[ - "external/flume-assembly", - ], - build_profile_flags=[ - "-Pflume", - ], - environ={ - "ENABLE_FLUME_TESTS": "1" - } -) - mllib_local = Module( name="mllib-local", @@ -425,14 +375,12 @@ def __hash__(self): pyspark_core, streaming, streaming_kafka, - streaming_flume_assembly, streaming_kinesis_asl ], source_file_regexes=[ "python/pyspark/streaming" ], environ={ - "ENABLE_FLUME_TESTS": "1", "ENABLE_KAFKA_0_8_TESTS": "1" }, python_test_goals=[ diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index a3627c9b9b0a7..cc8f5d3a8e3a7 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -29,7 +29,7 @@ export LC_ALL=C # TODO: This would be much nicer to do in SBT, once SBT supports Maven-style resolution. # NOTE: These should match those in the release publishing script -HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pkubernetes -Pyarn -Pflume -Phive" +HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pkubernetes -Pyarn -Phive" MVN="build/mvn" HADOOP_PROFILES=( hadoop-2.7 diff --git a/docs/building-spark.md b/docs/building-spark.md index b9e171547c3c0..55830d38a9e24 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -99,13 +99,6 @@ Note: Kafka 0.8 support is deprecated as of Spark 2.3.0. Kafka 0.10 support is still automatically built. -## Building with Flume support - -Apache Flume support must be explicitly enabled with the `flume` profile. -Note: Flume support is deprecated as of Spark 2.3.0. - - ./build/mvn -Pflume -DskipTests clean package - ## Building submodules individually It's possible to build Spark submodules using the `mvn -pl` option. diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index 44ae52e81cd64..a83ebd9449fa4 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -4,7 +4,7 @@ title: Spark Streaming Custom Receivers --- Spark Streaming can receive streaming data from any arbitrary data source beyond -the ones for which it has built-in support (that is, beyond Flume, Kafka, Kinesis, files, sockets, etc.). +the ones for which it has built-in support (that is, beyond Kafka, Kinesis, files, sockets, etc.). This requires the developer to implement a *receiver* that is customized for receiving data from the concerned data source. This guide walks through the process of implementing a custom receiver and using it in a Spark Streaming application. Note that custom receivers can be implemented diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md deleted file mode 100644 index a1b6942ffe0a4..0000000000000 --- a/docs/streaming-flume-integration.md +++ /dev/null @@ -1,169 +0,0 @@ ---- -layout: global -title: Spark Streaming + Flume Integration Guide ---- - -[Apache Flume](https://flume.apache.org/) is a distributed, reliable, and available service for efficiently collecting, aggregating, and moving large amounts of log data. Here we explain how to configure Flume and Spark Streaming to receive data from Flume. There are two approaches to this. - -**Note: Flume support is deprecated as of Spark 2.3.0.** - -## Approach 1: Flume-style Push-based Approach -Flume is designed to push data between Flume agents. In this approach, Spark Streaming essentially sets up a receiver that acts an Avro agent for Flume, to which Flume can push the data. Here are the configuration steps. - -#### General Requirements -Choose a machine in your cluster such that - -- When your Flume + Spark Streaming application is launched, one of the Spark workers must run on that machine. - -- Flume can be configured to push data to a port on that machine. - -Due to the push model, the streaming application needs to be up, with the receiver scheduled and listening on the chosen port, for Flume to be able to push data. - -#### Configuring Flume -Configure Flume agent to send data to an Avro sink by having the following in the configuration file. - - agent.sinks = avroSink - agent.sinks.avroSink.type = avro - agent.sinks.avroSink.channel = memoryChannel - agent.sinks.avroSink.hostname = - agent.sinks.avroSink.port = - -See the [Flume's documentation](https://flume.apache.org/documentation.html) for more information about -configuring Flume agents. - -#### Configuring Spark Streaming Application -1. **Linking:** In your SBT/Maven project definition, link your streaming application against the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). - - groupId = org.apache.spark - artifactId = spark-streaming-flume_{{site.SCALA_BINARY_VERSION}} - version = {{site.SPARK_VERSION_SHORT}} - -2. **Programming:** In the streaming application code, import `FlumeUtils` and create input DStream as follows. - -
      -
      - import org.apache.spark.streaming.flume._ - - val flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]) - - See the [API docs](api/scala/index.html#org.apache.spark.streaming.flume.FlumeUtils$). -
      -
      - import org.apache.spark.streaming.flume.*; - - JavaReceiverInputDStream flumeStream = - FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]); - - See the [API docs](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html). -
      -
      - from pyspark.streaming.flume import FlumeUtils - - flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]) - - By default, the Python API will decode Flume event body as UTF8 encoded strings. You can specify your custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type. - See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils). -
      -
      - - Note that the hostname should be the same as the one used by the resource manager in the - cluster (Mesos, YARN or Spark Standalone), so that resource allocation can match the names and launch - the receiver in the right machine. - -3. **Deploying:** As with any Spark applications, `spark-submit` is used to launch your application. However, the details are slightly different for Scala/Java applications and Python applications. - - For Scala and Java applications, if you are using SBT or Maven for project management, then package `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). - - For Python applications which lack SBT/Maven project management, `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` and its dependencies can be directly added to `spark-submit` using `--packages` (see [Application Submission Guide](submitting-applications.html)). That is, - - ./bin/spark-submit --packages org.apache.spark:spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ... - - Alternatively, you can also download the JAR of the Maven artifact `spark-streaming-flume-assembly` from the - [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-flume-assembly_{{site.SCALA_BINARY_VERSION}}%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. - -## Approach 2: Pull-based Approach using a Custom Sink -Instead of Flume pushing data directly to Spark Streaming, this approach runs a custom Flume sink that allows the following. - -- Flume pushes data into the sink, and the data stays buffered. -- Spark Streaming uses a [reliable Flume receiver](streaming-programming-guide.html#receiver-reliability) - and transactions to pull data from the sink. Transactions succeed only after data is received and - replicated by Spark Streaming. - -This ensures stronger reliability and -[fault-tolerance guarantees](streaming-programming-guide.html#fault-tolerance-semantics) -than the previous approach. However, this requires configuring Flume to run a custom sink. -Here are the configuration steps. - -#### General Requirements -Choose a machine that will run the custom sink in a Flume agent. The rest of the Flume pipeline is configured to send data to that agent. Machines in the Spark cluster should have access to the chosen machine running the custom sink. - -#### Configuring Flume -Configuring Flume on the chosen machine requires the following two steps. - -1. **Sink JARs**: Add the following JARs to Flume's classpath (see [Flume's documentation](https://flume.apache.org/documentation.html) to see how) in the machine designated to run the custom sink. - - (i) *Custom sink JAR*: Download the JAR corresponding to the following artifact (or [direct link](http://search.maven.org/remotecontent?filepath=org/apache/spark/spark-streaming-flume-sink_{{site.SCALA_BINARY_VERSION}}/{{site.SPARK_VERSION_SHORT}}/spark-streaming-flume-sink_{{site.SCALA_BINARY_VERSION}}-{{site.SPARK_VERSION_SHORT}}.jar)). - - groupId = org.apache.spark - artifactId = spark-streaming-flume-sink_{{site.SCALA_BINARY_VERSION}} - version = {{site.SPARK_VERSION_SHORT}} - - (ii) *Scala library JAR*: Download the Scala library JAR for Scala {{site.SCALA_VERSION}}. It can be found with the following artifact detail (or, [direct link](http://search.maven.org/remotecontent?filepath=org/scala-lang/scala-library/{{site.SCALA_VERSION}}/scala-library-{{site.SCALA_VERSION}}.jar)). - - groupId = org.scala-lang - artifactId = scala-library - version = {{site.SCALA_VERSION}} - - (iii) *Commons Lang 3 JAR*: Download the Commons Lang 3 JAR. It can be found with the following artifact detail (or, [direct link](http://search.maven.org/remotecontent?filepath=org/apache/commons/commons-lang3/3.5/commons-lang3-3.5.jar)). - - groupId = org.apache.commons - artifactId = commons-lang3 - version = 3.5 - -2. **Configuration file**: On that machine, configure Flume agent to send data to an Avro sink by having the following in the configuration file. - - agent.sinks = spark - agent.sinks.spark.type = org.apache.spark.streaming.flume.sink.SparkSink - agent.sinks.spark.hostname = - agent.sinks.spark.port = - agent.sinks.spark.channel = memoryChannel - - Also, make sure that the upstream Flume pipeline is configured to send the data to the Flume agent running this sink. - -See the [Flume's documentation](https://flume.apache.org/documentation.html) for more information about -configuring Flume agents. - -#### Configuring Spark Streaming Application -1. **Linking:** In your SBT/Maven project definition, link your streaming application against the `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide). - -2. **Programming:** In the streaming application code, import `FlumeUtils` and create input DStream as follows. - -
      -
      - import org.apache.spark.streaming.flume._ - - val flumeStream = FlumeUtils.createPollingStream(streamingContext, [sink machine hostname], [sink port]) -
      -
      - import org.apache.spark.streaming.flume.*; - - JavaReceiverInputDStreamflumeStream = - FlumeUtils.createPollingStream(streamingContext, [sink machine hostname], [sink port]); -
      -
      - from pyspark.streaming.flume import FlumeUtils - - addresses = [([sink machine hostname 1], [sink port 1]), ([sink machine hostname 2], [sink port 2])] - flumeStream = FlumeUtils.createPollingStream(streamingContext, addresses) - - By default, the Python API will decode Flume event body as UTF8 encoded strings. You can specify your custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type. - See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils). -
      -
      - - Note that each input DStream can be configured to receive data from multiple sinks. - -3. **Deploying:** This is same as the first approach. - - - diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 0ca0f2a8b54d5..1103d5c73ff1f 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -11,7 +11,7 @@ description: Spark Streaming programming guide and tutorial for Spark SPARK_VERS # Overview Spark Streaming is an extension of the core Spark API that enables scalable, high-throughput, fault-tolerant stream processing of live data streams. Data can be ingested from many sources -like Kafka, Flume, Kinesis, or TCP sockets, and can be processed using complex +like Kafka, Kinesis, or TCP sockets, and can be processed using complex algorithms expressed with high-level functions like `map`, `reduce`, `join` and `window`. Finally, processed data can be pushed out to filesystems, databases, and live dashboards. In fact, you can apply Spark's @@ -40,7 +40,7 @@ stream of results in batches. Spark Streaming provides a high-level abstraction called *discretized stream* or *DStream*, which represents a continuous stream of data. DStreams can be created either from input data -streams from sources such as Kafka, Flume, and Kinesis, or by applying high-level +streams from sources such as Kafka, and Kinesis, or by applying high-level operations on other DStreams. Internally, a DStream is represented as a sequence of [RDDs](api/scala/index.html#org.apache.spark.rdd.RDD). @@ -393,7 +393,7 @@ Similar to Spark, Spark Streaming is available through Maven Central. To write y -For ingesting data from sources like Kafka, Flume, and Kinesis that are not present in the Spark +For ingesting data from sources like Kafka and Kinesis that are not present in the Spark Streaming core API, you will have to add the corresponding artifact `spark-streaming-xyz_{{site.SCALA_BINARY_VERSION}}` to the dependencies. For example, @@ -402,7 +402,6 @@ some of the common ones are as follows.
      Property NameDefaultMeaning
      spark.yarn.keytabspark.kerberos.keytab (none) The full path to the file that contains the keytab for the principal specified above. This keytab @@ -477,7 +477,7 @@ providers can be disabled individually by setting `spark.security.credentials.{s
      spark.yarn.principalspark.kerberos.principal (none) Principal to be used to login to KDC, while running on secure clusters. Equivalent to the diff --git a/docs/sparkr.md b/docs/sparkr.md index b4248e8bb21de..55e8f15da17ca 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -70,12 +70,12 @@ The following Spark driver properties can be set in `sparkConfig` with `sparkR.s --master
      spark.yarn.keytabspark.kerberos.keytab Application Properties --keytab
      spark.yarn.principalspark.kerberos.principal Application Properties --principal
      spark.kubernetes.executor.limit.cores
      spark.kubernetes.local.dirs.tmpfsfalse + false Configure the emptyDir volumes used to back SPARK_LOCAL_DIRS within the Spark driver and executor pods to use tmpfs backing i.e. RAM. See Local Storage earlier on this page for more discussion of this. From 1e437835e96c4417117f44c29eba5ebc0112926f Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 29 Sep 2018 11:43:58 +0800 Subject: [PATCH 1730/2461] [SPARK-25570][SQL][TEST] Replace 2.3.1 with 2.3.2 in HiveExternalCatalogVersionsSuite ## What changes were proposed in this pull request? This PR aims to prevent test slowdowns at `HiveExternalCatalogVersionsSuite` by using the latest Apache Spark 2.3.2 link because the Apache mirrors will remove the old Spark 2.3.1 binaries eventually. `HiveExternalCatalogVersionsSuite` will not fail because [SPARK-24813](https://issues.apache.org/jira/browse/SPARK-24813) implements a fallback logic. However, it will cause many trials and fallbacks in all builds over `branch-2.3/branch-2.4/master`. We had better fix this issue. ## How was this patch tested? Pass the Jenkins with the updated version. Closes #22587 from dongjoon-hyun/SPARK-25570. Authored-by: Dongjoon Hyun Signed-off-by: hyukjinkwon --- .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index a7d6972fa71f7..fd4985d131885 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -206,7 +206,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.1.3", "2.2.2", "2.3.1") + val testingVersions = Seq("2.1.3", "2.2.2", "2.3.2") protected var spark: SparkSession = _ From 1007cae20e8f566e7d7c25f0f81c9b84f352b6d5 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 29 Sep 2018 17:53:30 +0800 Subject: [PATCH 1731/2461] [SPARK-25447][SQL] Support JSON options by schema_of_json() ## What changes were proposed in this pull request? In the PR, I propose to extended the `schema_of_json()` function, and accept JSON options since they can impact on schema inferring. Purpose is to support the same options that `from_json` can use during schema inferring. ## How was this patch tested? Added SQL, Python and Scala tests (`JsonExpressionsSuite` and `JsonFunctionsSuite`) that checks JSON options are used. Closes #22442 from MaxGekk/schema_of_json-options. Authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- python/pyspark/sql/functions.py | 11 ++++++-- .../expressions/jsonExpressions.scala | 28 +++++++++++++++---- .../expressions/JsonExpressionsSuite.scala | 12 ++++++-- .../org/apache/spark/sql/functions.scala | 15 ++++++++++ .../sql-tests/inputs/json-functions.sql | 4 +++ .../sql-tests/results/json-functions.sql.out | 18 +++++++++++- .../apache/spark/sql/JsonFunctionsSuite.scala | 8 ++++++ 7 files changed, 85 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e5bc1eaaad21a..74f0685b3cca6 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2348,11 +2348,15 @@ def to_json(col, options={}): @ignore_unicode_prefix @since(2.4) -def schema_of_json(col): +def schema_of_json(col, options={}): """ Parses a column containing a JSON string and infers its schema in DDL format. :param col: string column in json format + :param options: options to control parsing. accepts the same options as the JSON datasource + + .. versionchanged:: 2.5 + It accepts `options` parameter to control schema inferring. >>> from pyspark.sql.types import * >>> data = [(1, '{"a": 1}')] @@ -2361,10 +2365,13 @@ def schema_of_json(col): [Row(json=u'struct')] >>> df.select(schema_of_json(lit('{"a": 0}')).alias("json")).collect() [Row(json=u'struct')] + >>> schema = schema_of_json(lit('{a: 1}'), {'allowUnquotedFieldNames':'true'}) + >>> df.select(schema.alias("json")).collect() + [Row(json=u'struct')] """ sc = SparkContext._active_spark_context - jc = sc._jvm.functions.schema_of_json(_to_java_column(col)) + jc = sc._jvm.functions.schema_of_json(_to_java_column(col), options) return Column(jc) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index bd9090a07471b..f5297dde10ed6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -740,15 +740,31 @@ case class StructsToJson( examples = """ Examples: > SELECT _FUNC_('[{"col":0}]'); - array> + array> + > SELECT _FUNC_('[{"col":01}]', map('allowNumericLeadingZeros', 'true')); + array> """, since = "2.4.0") -case class SchemaOfJson(child: Expression) +case class SchemaOfJson( + child: Expression, + options: Map[String, String]) extends UnaryExpression with String2StringExpression with CodegenFallback { - private val jsonOptions = new JSONOptions(Map.empty, "UTC") - private val jsonFactory = new JsonFactory() - jsonOptions.setJacksonOptions(jsonFactory) + def this(child: Expression) = this(child, Map.empty[String, String]) + + def this(child: Expression, options: Expression) = this( + child = child, + options = JsonExprUtils.convertToMapData(options)) + + @transient + private lazy val jsonOptions = new JSONOptions(options, "UTC") + + @transient + private lazy val jsonFactory = { + val factory = new JsonFactory() + jsonOptions.setJacksonOptions(factory) + factory + } override def convert(v: UTF8String): UTF8String = { val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, v)) { parser => @@ -764,7 +780,7 @@ object JsonExprUtils { def evalSchemaExpr(exp: Expression): DataType = exp match { case Literal(s, StringType) => DataType.fromDDL(s.toString) - case e @ SchemaOfJson(_: Literal) => + case e @ SchemaOfJson(_: Literal, _) => val ddlSchema = e.eval().asInstanceOf[UTF8String] DataType.fromDDL(ddlSchema.toString) case e => throw new AnalysisException( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 0e9c8abec33e4..34fdd0cc834f0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -707,9 +707,17 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with } test("SPARK-24709: infer schema of json strings") { - checkEvaluation(SchemaOfJson(Literal.create("""{"col":0}""")), "struct") + checkEvaluation(new SchemaOfJson(Literal.create("""{"col":0}""")), + "struct") checkEvaluation( - SchemaOfJson(Literal.create("""{"col0":["a"], "col1": {"col2": "b"}}""")), + new SchemaOfJson(Literal.create("""{"col0":["a"], "col1": {"col2": "b"}}""")), "struct,col1:struct>") } + + test("infer schema of JSON strings by using options") { + checkEvaluation( + new SchemaOfJson(Literal.create("""{"col":01}"""), + CreateMap(Seq(Literal.create("allowNumericLeadingZeros"), Literal.create("true")))), + "struct") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4c58e77df485e..59a1fcb5ba367 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3611,6 +3611,21 @@ object functions { */ def schema_of_json(e: Column): Column = withExpr(new SchemaOfJson(e.expr)) + /** + * Parses a column containing a JSON string and infers its schema using options. + * + * @param e a string column containing JSON data. + * @param options options to control how the json is parsed. accepts the same options and the + * json data source. See [[DataFrameReader#json]]. + * @return a column with string literal containing schema in DDL format. + * + * @group collection_funcs + * @since 2.5.0 + */ + def schema_of_json(e: Column, options: java.util.Map[String, String]): Column = { + withExpr(SchemaOfJson(e.expr, options.asScala.toMap)) + } + /** * (Scala-specific) Converts a column containing a `StructType`, `ArrayType` or * a `MapType` into a JSON string with the specified schema. diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql index 0f22c0eeed581..bdd1fe4074f3c 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -56,3 +56,7 @@ select from_json('[{"a": 1}, 2]', 'array>'); select to_json(array('1', '2', '3')); select to_json(array(array(1, 2, 3), array(4))); +-- infer schema of json literal using options +select schema_of_json('{"c1":1}', map('primitivesAsString', 'true')); +select schema_of_json('{"c1":01, "c2":0.1}', map('allowNumericLeadingZeros', 'true', 'prefersDecimal', 'true')); + diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index e550b43e08c28..77e9000401141 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 40 +-- Number of queries: 42 -- !query 0 @@ -370,3 +370,19 @@ select to_json(array(array(1, 2, 3), array(4))) struct -- !query 39 output [[1,2,3],[4]] + + +-- !query 40 +select schema_of_json('{"c1":1}', map('primitivesAsString', 'true')) +-- !query 40 schema +struct +-- !query 40 output +struct + + +-- !query 41 +select schema_of_json('{"c1":01, "c2":0.1}', map('allowNumericLeadingZeros', 'true', 'prefersDecimal', 'true')) +-- !query 41 schema +struct +-- !query 41 output +struct diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 853bc182f2f4a..5cbf10129a4da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import collection.JavaConverters._ + import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -402,6 +404,12 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { assert(out.schema == expected) } + test("infers schemas using options") { + val df = spark.range(1) + .select(schema_of_json(lit("{a:1}"), Map("allowUnquotedFieldNames" -> "true").asJava)) + checkAnswer(df, Seq(Row("struct"))) + } + test("from_json - array of primitive types") { val df = Seq("[1, 2, 3]").toDF("a") val schema = new ArrayType(IntegerType, false) From dcb9a97f3e16d4645529ac619c3197fcba1c9806 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 29 Sep 2018 18:18:37 +0800 Subject: [PATCH 1732/2461] [SPARK-25262][DOC][FOLLOWUP] Fix link tags in html table ## What changes were proposed in this pull request? Markdown links are not working inside html table. We should use html link tag. ## How was this patch tested? Verified in IntelliJ IDEA's markdown editor and online markdown editor. Closes #22588 from viirya/SPARK-25262-followup. Authored-by: Liang-Chi Hsieh Signed-off-by: hyukjinkwon --- docs/running-on-kubernetes.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index c7aea2709605d..b4088d79addff 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -680,15 +680,15 @@ specific to Spark on Kubernetes. spark.kubernetes.driver.limit.cores (none) - Specify a hard cpu [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for the driver pod. + Specify a hard cpu limit for the driver pod.
      spark.kubernetes.executor.request.cores (none) - Specify the cpu request for each executor pod. Values conform to the Kubernetes [convention](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#meaning-of-cpu). - Example values include 0.1, 500m, 1.5, 5, etc., with the definition of cpu units documented in [CPU units](https://kubernetes.io/docs/tasks/configure-pod-container/assign-cpu-resource/#cpu-units). + Specify the cpu request for each executor pod. Values conform to the Kubernetes convention. + Example values include 0.1, 500m, 1.5, 5, etc., with the definition of cpu units documented in CPU units. This is distinct from spark.executor.cores: it is only used and takes precedence over spark.executor.cores for specifying the executor pod cpu request if set. Task parallelism, e.g., number of tasks an executor can run concurrently is not affected by this. spark.kubernetes.executor.limit.cores (none) - Specify a hard cpu [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for each executor pod launched for the Spark Application. + Specify a hard cpu limit for each executor pod launched for the Spark Application.
      spark.ui.dagGraph.retainedRootRDDsInt.MaxValue + How many DAG graph nodes the Spark UI and status APIs remember before garbage collecting. +
      spark.ui.enabled true
      spark.ui.liveUpdate.period100ms + How often to update live entities. -1 means "never update" when replaying applications, + meaning only the last write will happen. For live applications, this avoids a few + operations that we can live without when rapidly processing incoming task events. +
      spark.ui.port 4040
      spark.history.fs.endEventReparseChunkSize1m + How many bytes to parse at the end of log files looking for the end event. + This is used to speed up generation of application listings by skipping unnecessary + parts of event log files. It can be disabled by setting this config to 0. +
      spark.history.fs.inProgressOptimization.enabledtrue + Enable optimized handling of in-progress logs. This option may leave finished + applications that fail to rename their event logs listed as in-progress. +
      spark.history.fs.numReplayThreads 25% of available cores
      spark.history.store.maxDiskUsage10g + Maximum disk usage for the local directory where the cache application history information + are stored. +
      spark.history.store.path (none)
      Append, Update, Complete None Depends on ForeachWriter implementationMore details in the next sectionMore details in the next section
      ForeachBatch SinkAppend, Update, CompleteNoneDepends on the implementationMore details in the next section
      Console Sink Append, Update, Complete
      -
      SourceArtifact
      Kafka spark-streaming-kafka-0-10_{{site.SCALA_BINARY_VERSION}}
      Flume spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}
      Kinesis
      spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}} [Amazon Software License]
      @@ -577,7 +576,7 @@ Spark Streaming provides two categories of built-in streaming sources. - *Basic sources*: Sources directly available in the StreamingContext API. Examples: file systems, and socket connections. -- *Advanced sources*: Sources like Kafka, Flume, Kinesis, etc. are available through +- *Advanced sources*: Sources like Kafka, Kinesis, etc. are available through extra utility classes. These require linking against extra dependencies as discussed in the [linking](#linking) section. @@ -597,7 +596,7 @@ as well as to run the receiver(s). - When running a Spark Streaming program locally, do not use "local" or "local[1]" as the master URL. Either of these means that only one thread will be used for running tasks locally. If you are using - an input DStream based on a receiver (e.g. sockets, Kafka, Flume, etc.), then the single thread will + an input DStream based on a receiver (e.g. sockets, Kafka, etc.), then the single thread will be used to run the receiver, leaving no thread for processing the received data. Hence, when running locally, always use "local[*n*]" as the master URL, where *n* > number of receivers to run (see [Spark Properties](configuration.html#spark-properties) for information on how to set @@ -732,10 +731,10 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea {:.no_toc} Python API As of Spark {{site.SPARK_VERSION_SHORT}}, -out of these sources, Kafka, Kinesis and Flume are available in the Python API. +out of these sources, Kafka and Kinesis are available in the Python API. This category of sources require interfacing with external non-Spark libraries, some of them with -complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts +complex dependencies (e.g., Kafka). Hence, to minimize issues related to version conflicts of dependencies, the functionality to create DStreams from these sources has been moved to separate libraries that can be [linked](#linking) to explicitly when necessary. @@ -748,8 +747,6 @@ Some of these advanced sources are as follows. - **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kafka broker versions 0.8.2.1 or higher. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details. -- **Flume:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Flume 1.6.0. See the [Flume Integration Guide](streaming-flume-integration.html) for more details. - - **Kinesis:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kinesis Client Library 1.2.1. See the [Kinesis Integration Guide](streaming-kinesis-integration.html) for more details. ### Custom Sources @@ -766,7 +763,7 @@ Guide](streaming-custom-receivers.html) for details. {:.no_toc} There can be two kinds of data sources based on their *reliability*. Sources -(like Kafka and Flume) allow the transferred data to be acknowledged. If the system receiving +(like Kafka) allow the transferred data to be acknowledged. If the system receiving data from these *reliable* sources acknowledges the received data correctly, it can be ensured that no data will be lost due to any kind of failure. This leads to two kinds of receivers: @@ -1603,7 +1600,7 @@ operations on the same data). For window-based operations like `reduceByWindow` Hence, DStreams generated by window-based operations are automatically persisted in memory, without the developer calling `persist()`. -For input streams that receive data over the network (such as, Kafka, Flume, sockets, etc.), the +For input streams that receive data over the network (such as, Kafka, sockets, etc.), the default persistence level is set to replicate the data to two nodes for fault-tolerance. Note that, unlike RDDs, the default persistence level of DStreams keeps the data serialized in @@ -1973,7 +1970,7 @@ To run a Spark Streaming applications, you need to have the following. - *Package the application JAR* - You have to compile your streaming application into a JAR. If you are using [`spark-submit`](submitting-applications.html) to start the application, then you will not need to provide Spark and Spark Streaming in the JAR. However, - if your application uses [advanced sources](#advanced-sources) (e.g. Kafka, Flume), + if your application uses [advanced sources](#advanced-sources) (e.g. Kafka), then you will have to package the extra artifact they link to, along with their dependencies, in the JAR that is used to deploy the application. For example, an application using `KafkaUtils` will have to include `spark-streaming-kafka-0-10_{{site.SCALA_BINARY_VERSION}}` and all its @@ -2060,7 +2057,7 @@ for graceful shutdown options) which ensure data that has been received is compl processed before shutdown. Then the upgraded application can be started, which will start processing from the same point where the earlier application left off. Note that this can be done only with input sources that support source-side buffering -(like Kafka, and Flume) as data needs to be buffered while the previous application was down and +(like Kafka) as data needs to be buffered while the previous application was down and the upgraded application is not yet up. And restarting from earlier checkpoint information of pre-upgrade code cannot be done. The checkpoint information essentially contains serialized Scala/Java/Python objects and trying to deserialize objects with new, @@ -2115,7 +2112,7 @@ highlights some of the most important ones. ### Level of Parallelism in Data Receiving {:.no_toc} -Receiving data over the network (like Kafka, Flume, socket, etc.) requires the data to be deserialized +Receiving data over the network (like Kafka, socket, etc.) requires the data to be deserialized and stored in Spark. If the data receiving becomes a bottleneck in the system, then consider parallelizing the data receiving. Note that each input DStream creates a single receiver (running on a worker machine) that receives a single stream of data. @@ -2475,14 +2472,12 @@ additional effort may be necessary to achieve exactly-once semantics. There are * [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) and [DStream](api/scala/index.html#org.apache.spark.streaming.dstream.DStream) * [KafkaUtils](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$), - [FlumeUtils](api/scala/index.html#org.apache.spark.streaming.flume.FlumeUtils$), [KinesisUtils](api/scala/index.html#org.apache.spark.streaming.kinesis.KinesisUtils$), - Java docs * [JavaStreamingContext](api/java/index.html?org/apache/spark/streaming/api/java/JavaStreamingContext.html), [JavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaDStream.html) and [JavaPairDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaPairDStream.html) * [KafkaUtils](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html), - [FlumeUtils](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html), [KinesisUtils](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) - Python docs * [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext) and [DStream](api/python/pyspark.streaming.html#pyspark.streaming.DStream) diff --git a/examples/src/main/python/streaming/flume_wordcount.py b/examples/src/main/python/streaming/flume_wordcount.py deleted file mode 100644 index c8ea92b61ca6e..0000000000000 --- a/examples/src/main/python/streaming/flume_wordcount.py +++ /dev/null @@ -1,56 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -r""" - Counts words in UTF8 encoded, '\n' delimited text received from the network every second. - Usage: flume_wordcount.py - - To run this on your local machine, you need to setup Flume first, see - https://flume.apache.org/documentation.html - - and then run the example - `$ bin/spark-submit --jars \ - external/flume-assembly/target/scala-*/spark-streaming-flume-assembly-*.jar \ - examples/src/main/python/streaming/flume_wordcount.py \ - localhost 12345 -""" -from __future__ import print_function - -import sys - -from pyspark import SparkContext -from pyspark.streaming import StreamingContext -from pyspark.streaming.flume import FlumeUtils - -if __name__ == "__main__": - if len(sys.argv) != 3: - print("Usage: flume_wordcount.py ", file=sys.stderr) - sys.exit(-1) - - sc = SparkContext(appName="PythonStreamingFlumeWordCount") - ssc = StreamingContext(sc, 1) - - hostname, port = sys.argv[1:] - kvs = FlumeUtils.createStream(ssc, hostname, int(port)) - lines = kvs.map(lambda x: x[1]) - counts = lines.flatMap(lambda line: line.split(" ")) \ - .map(lambda word: (word, 1)) \ - .reduceByKey(lambda a, b: a+b) - counts.pprint() - - ssc.start() - ssc.awaitTermination() diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml deleted file mode 100644 index 002bd6fb7f294..0000000000000 --- a/external/flume-assembly/pom.xml +++ /dev/null @@ -1,167 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.11 - 3.0.0-SNAPSHOT - ../../pom.xml - - - spark-streaming-flume-assembly_2.11 - jar - Spark Project External Flume Assembly - http://spark.apache.org/ - - - provided - streaming-flume-assembly - - - - - org.apache.spark - spark-streaming-flume_${scala.binary.version} - ${project.version} - - - org.mortbay.jetty - jetty - - - org.mortbay.jetty - jetty-util - - - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - provided - - - - commons-codec - commons-codec - provided - - - commons-lang - commons-lang - provided - - - commons-net - commons-net - provided - - - com.google.protobuf - protobuf-java - provided - - - org.apache.avro - avro - provided - - - org.apache.avro - avro-ipc - provided - - - org.apache.avro - avro-mapred - ${avro.mapred.classifier} - provided - - - org.scala-lang - scala-library - provided - - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - - org.apache.maven.plugins - maven-shade-plugin - - false - - - *:* - - - - - *:* - - META-INF/*.SF - META-INF/*.DSA - META-INF/*.RSA - - - - - - - package - - shade - - - - - - reference.conf - - - log4j.properties - - - - - - - - - - - - - - flume-provided - - provided - - - - - diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml deleted file mode 100644 index 168d9d3b2ae0a..0000000000000 --- a/external/flume-sink/pom.xml +++ /dev/null @@ -1,140 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.11 - 3.0.0-SNAPSHOT - ../../pom.xml - - - spark-streaming-flume-sink_2.11 - - streaming-flume-sink - - jar - Spark Project External Flume Sink - http://spark.apache.org/ - - - - org.apache.flume - flume-ng-sdk - - - - com.google.guava - guava - - - - org.apache.thrift - libthrift - - - - - org.apache.flume - flume-ng-core - - - com.google.guava - guava - - - org.apache.thrift - libthrift - - - - - org.scala-lang - scala-library - - - - com.google.guava - guava - test - - - - io.netty - netty - 3.4.0.Final - test - - - org.apache.spark - spark-tags_${scala.binary.version} - - - - - org.apache.spark - spark-tags_${scala.binary.version} - test-jar - test - - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - - org.apache.avro - avro-maven-plugin - ${avro.version} - - - ${project.basedir}/target/scala-${scala.binary.version}/src_managed/main/compiled_avro - - - - generate-sources - - idl-protocol - - - - - - org.apache.maven.plugins - maven-shade-plugin - - - - - - - - diff --git a/external/flume-sink/src/main/avro/sparkflume.avdl b/external/flume-sink/src/main/avro/sparkflume.avdl deleted file mode 100644 index 8806e863ac7c6..0000000000000 --- a/external/flume-sink/src/main/avro/sparkflume.avdl +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -@namespace("org.apache.spark.streaming.flume.sink") - -protocol SparkFlumeProtocol { - - record SparkSinkEvent { - map headers; - bytes body; - } - - record EventBatch { - string errorMsg = ""; // If this is empty it is a valid message, else it represents an error - string sequenceNumber; - array events; - } - - EventBatch getEventBatch (int n); - - void ack (string sequenceNumber); - - void nack (string sequenceNumber); -} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala deleted file mode 100644 index 09d3fe91e42c8..0000000000000 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.flume.sink - -import org.slf4j.{Logger, LoggerFactory} - -/** - * Copy of the org.apache.spark.Logging for being used in the Spark Sink. - * The org.apache.spark.Logging is not used so that all of Spark is not brought - * in as a dependency. - */ -private[sink] trait Logging { - // Make the log field transient so that objects with Logging can - // be serialized and used on another machine - @transient private var _log: Logger = null - - // Method to get or create the logger for this object - protected def log: Logger = { - if (_log == null) { - initializeIfNecessary() - var className = this.getClass.getName - // Ignore trailing $'s in the class names for Scala objects - if (className.endsWith("$")) { - className = className.substring(0, className.length - 1) - } - _log = LoggerFactory.getLogger(className) - } - _log - } - - // Log methods that take only a String - protected def logInfo(msg: => String) { - if (log.isInfoEnabled) log.info(msg) - } - - protected def logDebug(msg: => String) { - if (log.isDebugEnabled) log.debug(msg) - } - - protected def logTrace(msg: => String) { - if (log.isTraceEnabled) log.trace(msg) - } - - protected def logWarning(msg: => String) { - if (log.isWarnEnabled) log.warn(msg) - } - - protected def logError(msg: => String) { - if (log.isErrorEnabled) log.error(msg) - } - - // Log methods that take Throwables (Exceptions/Errors) too - protected def logInfo(msg: => String, throwable: Throwable) { - if (log.isInfoEnabled) log.info(msg, throwable) - } - - protected def logDebug(msg: => String, throwable: Throwable) { - if (log.isDebugEnabled) log.debug(msg, throwable) - } - - protected def logTrace(msg: => String, throwable: Throwable) { - if (log.isTraceEnabled) log.trace(msg, throwable) - } - - protected def logWarning(msg: => String, throwable: Throwable) { - if (log.isWarnEnabled) log.warn(msg, throwable) - } - - protected def logError(msg: => String, throwable: Throwable) { - if (log.isErrorEnabled) log.error(msg, throwable) - } - - protected def isTraceEnabled(): Boolean = { - log.isTraceEnabled - } - - private def initializeIfNecessary() { - if (!Logging.initialized) { - Logging.initLock.synchronized { - if (!Logging.initialized) { - initializeLogging() - } - } - } - } - - private def initializeLogging() { - Logging.initialized = true - - // Force a call into slf4j to initialize it. Avoids this happening from multiple threads - // and triggering this: http://mailman.qos.ch/pipermail/slf4j-dev/2010-April/002956.html - log - } -} - -private[sink] object Logging { - @volatile private var initialized = false - val initLock = new Object() - try { - // We use reflection here to handle the case where users remove the - // slf4j-to-jul bridge order to route their logs to JUL. - // scalastyle:off classforname - val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler") - // scalastyle:on classforname - bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null) - val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean] - if (!installed) { - bridgeClass.getMethod("install").invoke(null) - } - } catch { - case e: ClassNotFoundException => // can't log anything yet so just fail silently - } -} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala deleted file mode 100644 index 8050ec357e261..0000000000000 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ /dev/null @@ -1,166 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.flume.sink - -import java.util.UUID -import java.util.concurrent.{CountDownLatch, Executors} -import java.util.concurrent.atomic.AtomicLong - -import scala.collection.mutable - -import org.apache.flume.Channel - -/** - * Class that implements the SparkFlumeProtocol, that is used by the Avro Netty Server to process - * requests. Each getEvents, ack and nack call is forwarded to an instance of this class. - * @param threads Number of threads to use to process requests. - * @param channel The channel that the sink pulls events from - * @param transactionTimeout Timeout in millis after which the transaction if not acked by Spark - * is rolled back. - */ -// Flume forces transactions to be thread-local. So each transaction *must* be committed, or -// rolled back from the thread it was originally created in. So each getEvents call from Spark -// creates a TransactionProcessor which runs in a new thread, in which the transaction is created -// and events are pulled off the channel. Once the events are sent to spark, -// that thread is blocked and the TransactionProcessor is saved in a map, -// until an ACK or NACK comes back or the transaction times out (after the specified timeout). -// When the response comes or a timeout is hit, the TransactionProcessor is retrieved and then -// unblocked, at which point the transaction is committed or rolled back. - -private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Channel, - val transactionTimeout: Int, val backOffInterval: Int) extends SparkFlumeProtocol with Logging { - val transactionExecutorOpt = Option(Executors.newFixedThreadPool(threads, - new SparkSinkThreadFactory("Spark Sink Processor Thread - %d"))) - // Protected by `sequenceNumberToProcessor` - private val sequenceNumberToProcessor = mutable.HashMap[CharSequence, TransactionProcessor]() - // This sink will not persist sequence numbers and reuses them if it gets restarted. - // So it is possible to commit a transaction which may have been meant for the sink before the - // restart. - // Since the new txn may not have the same sequence number we must guard against accidentally - // committing a new transaction. To reduce the probability of that happening a random string is - // prepended to the sequence number. Does not change for life of sink - private val seqBase = UUID.randomUUID().toString.substring(0, 8) - private val seqCounter = new AtomicLong(0) - - // Protected by `sequenceNumberToProcessor` - private var stopped = false - - @volatile private var isTest = false - private var testLatch: CountDownLatch = null - - /** - * Returns a bunch of events to Spark over Avro RPC. - * @param n Maximum number of events to return in a batch - * @return [[EventBatch]] instance that has a sequence number and an array of at most n events - */ - override def getEventBatch(n: Int): EventBatch = { - logDebug("Got getEventBatch call from Spark.") - val sequenceNumber = seqBase + seqCounter.incrementAndGet() - createProcessor(sequenceNumber, n) match { - case Some(processor) => - transactionExecutorOpt.foreach(_.submit(processor)) - // Wait until a batch is available - will be an error if error message is non-empty - val batch = processor.getEventBatch - if (SparkSinkUtils.isErrorBatch(batch)) { - // Remove the processor if it is an error batch since no ACK is sent. - removeAndGetProcessor(sequenceNumber) - logWarning("Received an error batch - no events were received from channel! ") - } - batch - case None => - new EventBatch("Spark sink has been stopped!", "", java.util.Collections.emptyList()) - } - } - - private def createProcessor(seq: String, n: Int): Option[TransactionProcessor] = { - sequenceNumberToProcessor.synchronized { - if (!stopped) { - val processor = new TransactionProcessor( - channel, seq, n, transactionTimeout, backOffInterval, this) - sequenceNumberToProcessor.put(seq, processor) - if (isTest) { - processor.countDownWhenBatchAcked(testLatch) - } - Some(processor) - } else { - None - } - } - } - - /** - * Called by Spark to indicate successful commit of a batch - * @param sequenceNumber The sequence number of the event batch that was successful - */ - override def ack(sequenceNumber: CharSequence): Void = { - logDebug("Received Ack for batch with sequence number: " + sequenceNumber) - completeTransaction(sequenceNumber, success = true) - null - } - - /** - * Called by Spark to indicate failed commit of a batch - * @param sequenceNumber The sequence number of the event batch that failed - * @return - */ - override def nack(sequenceNumber: CharSequence): Void = { - completeTransaction(sequenceNumber, success = false) - logInfo("Spark failed to commit transaction. Will reattempt events.") - null - } - - /** - * Helper method to commit or rollback a transaction. - * @param sequenceNumber The sequence number of the batch that was completed - * @param success Whether the batch was successful or not. - */ - private def completeTransaction(sequenceNumber: CharSequence, success: Boolean) { - removeAndGetProcessor(sequenceNumber).foreach { processor => - processor.batchProcessed(success) - } - } - - /** - * Helper method to remove the TxnProcessor for a Sequence Number. Can be used to avoid a leak. - * @param sequenceNumber - * @return An `Option` of the transaction processor for the corresponding batch. Note that this - * instance is no longer tracked and the caller is responsible for that txn processor. - */ - private[sink] def removeAndGetProcessor(sequenceNumber: CharSequence): - Option[TransactionProcessor] = { - sequenceNumberToProcessor.synchronized { - sequenceNumberToProcessor.remove(sequenceNumber.toString) - } - } - - private[sink] def countDownWhenBatchAcked(latch: CountDownLatch) { - testLatch = latch - isTest = true - } - - /** - * Shuts down the executor used to process transactions. - */ - def shutdown() { - logInfo("Shutting down Spark Avro Callback Handler") - sequenceNumberToProcessor.synchronized { - stopped = true - sequenceNumberToProcessor.values.foreach(_.shutdown()) - } - transactionExecutorOpt.foreach(_.shutdownNow()) - } -} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala deleted file mode 100644 index e5b63aa1a77ef..0000000000000 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.flume.sink - -import java.net.InetSocketAddress -import java.util.concurrent._ - -import org.apache.avro.ipc.NettyServer -import org.apache.avro.ipc.specific.SpecificResponder -import org.apache.flume.Context -import org.apache.flume.Sink.Status -import org.apache.flume.conf.{Configurable, ConfigurationException} -import org.apache.flume.sink.AbstractSink - -/** - * A sink that uses Avro RPC to run a server that can be polled by Spark's - * FlumePollingInputDStream. This sink has the following configuration parameters: - * - * hostname - The hostname to bind to. Default: 0.0.0.0 - * port - The port to bind to. (No default - mandatory) - * timeout - Time in seconds after which a transaction is rolled back, - * if an ACK is not received from Spark within that time - * threads - Number of threads to use to receive requests from Spark (Default: 10) - * - * This sink is unlike other Flume sinks in the sense that it does not push data, - * instead the process method in this sink simply blocks the SinkRunner the first time it is - * called. This sink starts up an Avro IPC server that uses the SparkFlumeProtocol. - * - * Each time a getEventBatch call comes, creates a transaction and reads events - * from the channel. When enough events are read, the events are sent to the Spark receiver and - * the thread itself is blocked and a reference to it saved off. - * - * When the ack for that batch is received, - * the thread which created the transaction is retrieved and it commits the transaction with the - * channel from the same thread it was originally created in (since Flume transactions are - * thread local). If a nack is received instead, the sink rolls back the transaction. If no ack - * is received within the specified timeout, the transaction is rolled back too. If an ack comes - * after that, it is simply ignored and the events get re-sent. - * - */ - -class SparkSink extends AbstractSink with Logging with Configurable { - - // Size of the pool to use for holding transaction processors. - private var poolSize: Integer = SparkSinkConfig.DEFAULT_THREADS - - // Timeout for each transaction. If spark does not respond in this much time, - // rollback the transaction - private var transactionTimeout = SparkSinkConfig.DEFAULT_TRANSACTION_TIMEOUT - - // Address info to bind on - private var hostname: String = SparkSinkConfig.DEFAULT_HOSTNAME - private var port: Int = 0 - - private var backOffInterval: Int = 200 - - // Handle to the server - private var serverOpt: Option[NettyServer] = None - - // The handler that handles the callback from Avro - private var handler: Option[SparkAvroCallbackHandler] = None - - // Latch that blocks off the Flume framework from wasting 1 thread. - private val blockingLatch = new CountDownLatch(1) - - override def start() { - logInfo("Starting Spark Sink: " + getName + " on port: " + port + " and interface: " + - hostname + " with " + "pool size: " + poolSize + " and transaction timeout: " + - transactionTimeout + ".") - handler = Option(new SparkAvroCallbackHandler(poolSize, getChannel, transactionTimeout, - backOffInterval)) - val responder = new SpecificResponder(classOf[SparkFlumeProtocol], handler.get) - // Using the constructor that takes specific thread-pools requires bringing in netty - // dependencies which are being excluded in the build. In practice, - // Netty dependencies are already available on the JVM as Flume would have pulled them in. - serverOpt = Option(new NettyServer(responder, new InetSocketAddress(hostname, port))) - serverOpt.foreach { server => - logInfo("Starting Avro server for sink: " + getName) - server.start() - } - super.start() - } - - override def stop() { - logInfo("Stopping Spark Sink: " + getName) - handler.foreach { callbackHandler => - callbackHandler.shutdown() - } - serverOpt.foreach { server => - logInfo("Stopping Avro Server for sink: " + getName) - server.close() - server.join() - } - blockingLatch.countDown() - super.stop() - } - - override def configure(ctx: Context) { - import SparkSinkConfig._ - hostname = ctx.getString(CONF_HOSTNAME, DEFAULT_HOSTNAME) - port = Option(ctx.getInteger(CONF_PORT)). - getOrElse(throw new ConfigurationException("The port to bind to must be specified")) - poolSize = ctx.getInteger(THREADS, DEFAULT_THREADS) - transactionTimeout = ctx.getInteger(CONF_TRANSACTION_TIMEOUT, DEFAULT_TRANSACTION_TIMEOUT) - backOffInterval = ctx.getInteger(CONF_BACKOFF_INTERVAL, DEFAULT_BACKOFF_INTERVAL) - logInfo("Configured Spark Sink with hostname: " + hostname + ", port: " + port + ", " + - "poolSize: " + poolSize + ", transactionTimeout: " + transactionTimeout + ", " + - "backoffInterval: " + backOffInterval) - } - - override def process(): Status = { - // This method is called in a loop by the Flume framework - block it until the sink is - // stopped to save CPU resources. The sink runner will interrupt this thread when the sink is - // being shut down. - logInfo("Blocking Sink Runner, sink will continue to run..") - blockingLatch.await() - Status.BACKOFF - } - - private[flume] def getPort(): Int = { - serverOpt - .map(_.getPort) - .getOrElse( - throw new RuntimeException("Server was not started!") - ) - } - - /** - * Pass in a [[CountDownLatch]] for testing purposes. This batch is counted down when each - * batch is received. The test can simply call await on this latch till the expected number of - * batches are received. - * @param latch - */ - private[flume] def countdownWhenBatchReceived(latch: CountDownLatch) { - handler.foreach(_.countDownWhenBatchAcked(latch)) - } -} - -/** - * Configuration parameters and their defaults. - */ -private[flume] -object SparkSinkConfig { - val THREADS = "threads" - val DEFAULT_THREADS = 10 - - val CONF_TRANSACTION_TIMEOUT = "timeout" - val DEFAULT_TRANSACTION_TIMEOUT = 60 - - val CONF_HOSTNAME = "hostname" - val DEFAULT_HOSTNAME = "0.0.0.0" - - val CONF_PORT = "port" - - val CONF_BACKOFF_INTERVAL = "backoffInterval" - val DEFAULT_BACKOFF_INTERVAL = 200 -} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala deleted file mode 100644 index 845fc8debda75..0000000000000 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.flume.sink - -import java.util.concurrent.ThreadFactory -import java.util.concurrent.atomic.AtomicLong - -/** - * Thread factory that generates daemon threads with a specified name format. - */ -private[sink] class SparkSinkThreadFactory(nameFormat: String) extends ThreadFactory { - - private val threadId = new AtomicLong() - - override def newThread(r: Runnable): Thread = { - val t = new Thread(r, nameFormat.format(threadId.incrementAndGet())) - t.setDaemon(true) - t - } - -} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkUtils.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkUtils.scala deleted file mode 100644 index 47c0e294d6b52..0000000000000 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkUtils.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.flume.sink - -private[flume] object SparkSinkUtils { - /** - * This method determines if this batch represents an error or not. - * @param batch - The batch to check - * @return - true if the batch represents an error - */ - def isErrorBatch(batch: EventBatch): Boolean = { - !batch.getErrorMsg.toString.equals("") // If there is an error message, it is an error batch. - } -} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala deleted file mode 100644 index 19e736f016977..0000000000000 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala +++ /dev/null @@ -1,252 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.flume.sink - -import java.nio.ByteBuffer -import java.util -import java.util.concurrent.{Callable, CountDownLatch, TimeUnit} - -import scala.util.control.Breaks - -import org.apache.flume.{Channel, Transaction} - -// Flume forces transactions to be thread-local (horrible, I know!) -// So the sink basically spawns a new thread to pull the events out within a transaction. -// The thread fills in the event batch object that is set before the thread is scheduled. -// After filling it in, the thread waits on a condition - which is released only -// when the success message comes back for the specific sequence number for that event batch. -/** - * This class represents a transaction on the Flume channel. This class runs a separate thread - * which owns the transaction. The thread is blocked until the success call for that transaction - * comes back with an ACK or NACK. - * @param channel The channel from which to pull events - * @param seqNum The sequence number to use for the transaction. Must be unique - * @param maxBatchSize The maximum number of events to process per batch - * @param transactionTimeout Time in seconds after which a transaction must be rolled back - * without waiting for an ACK from Spark - * @param parent The parent [[SparkAvroCallbackHandler]] instance, for reporting timeouts - */ -private class TransactionProcessor(val channel: Channel, val seqNum: String, - var maxBatchSize: Int, val transactionTimeout: Int, val backOffInterval: Int, - val parent: SparkAvroCallbackHandler) extends Callable[Void] with Logging { - - // If a real batch is not returned, we always have to return an error batch. - @volatile private var eventBatch: EventBatch = new EventBatch("Unknown Error", "", - util.Collections.emptyList()) - - // Synchronization primitives - val batchGeneratedLatch = new CountDownLatch(1) - val batchAckLatch = new CountDownLatch(1) - - // Sanity check to ensure we don't loop like crazy - val totalAttemptsToRemoveFromChannel = Int.MaxValue / 2 - - // OK to use volatile, since the change would only make this true (otherwise it will be - // changed to false - we never apply a negation operation to this) - which means the transaction - // succeeded. - @volatile private var batchSuccess = false - - @volatile private var stopped = false - - @volatile private var isTest = false - - private var testLatch: CountDownLatch = null - - // The transaction that this processor would handle - var txOpt: Option[Transaction] = None - - /** - * Get an event batch from the channel. This method will block until a batch of events is - * available from the channel. If no events are available after a large number of attempts of - * polling the channel, this method will return an [[EventBatch]] with a non-empty error message - * - * @return An [[EventBatch]] instance with sequence number set to seqNum, filled with a - * maximum of maxBatchSize events - */ - def getEventBatch: EventBatch = { - batchGeneratedLatch.await() - eventBatch - } - - /** - * This method is to be called by the sink when it receives an ACK or NACK from Spark. This - * method is a no-op if it is called after transactionTimeout has expired since - * getEventBatch returned a batch of events. - * @param success True if an ACK was received and the transaction should be committed, else false. - */ - def batchProcessed(success: Boolean) { - logDebug("Batch processed for sequence number: " + seqNum) - batchSuccess = success - batchAckLatch.countDown() - } - - private[flume] def shutdown(): Unit = { - logDebug("Shutting down transaction processor") - stopped = true - } - - /** - * Populates events into the event batch. If the batch cannot be populated, - * this method will not set the events into the event batch, but it sets an error message. - */ - private def populateEvents() { - try { - txOpt = Option(channel.getTransaction) - if(txOpt.isEmpty) { - eventBatch.setErrorMsg("Something went wrong. Channel was " + - "unable to create a transaction!") - } - txOpt.foreach { tx => - tx.begin() - val events = new util.ArrayList[SparkSinkEvent](maxBatchSize) - val loop = new Breaks - var gotEventsInThisTxn = false - var loopCounter: Int = 0 - loop.breakable { - while (!stopped && events.size() < maxBatchSize - && loopCounter < totalAttemptsToRemoveFromChannel) { - loopCounter += 1 - Option(channel.take()) match { - case Some(event) => - events.add(new SparkSinkEvent(toCharSequenceMap(event.getHeaders), - ByteBuffer.wrap(event.getBody))) - gotEventsInThisTxn = true - case None => - if (!gotEventsInThisTxn && !stopped) { - logDebug("Sleeping for " + backOffInterval + " millis as no events were read in" + - " the current transaction") - TimeUnit.MILLISECONDS.sleep(backOffInterval) - } else { - loop.break() - } - } - } - } - if (!gotEventsInThisTxn && !stopped) { - val msg = "Tried several times, " + - "but did not get any events from the channel!" - logWarning(msg) - eventBatch.setErrorMsg(msg) - } else { - // At this point, the events are available, so fill them into the event batch - eventBatch = new EventBatch("", seqNum, events) - } - } - } catch { - case interrupted: InterruptedException => - // Don't pollute logs if the InterruptedException came from this being stopped - if (!stopped) { - logWarning("Error while processing transaction.", interrupted) - } - case e: Exception => - logWarning("Error while processing transaction.", e) - eventBatch.setErrorMsg(e.getMessage) - try { - txOpt.foreach { tx => - rollbackAndClose(tx, close = true) - } - } finally { - txOpt = None - } - } finally { - batchGeneratedLatch.countDown() - } - } - - /** - * Waits for upto transactionTimeout seconds for an ACK. If an ACK comes in - * this method commits the transaction with the channel. If the ACK does not come in within - * that time or a NACK comes in, this method rolls back the transaction. - */ - private def processAckOrNack() { - batchAckLatch.await(transactionTimeout, TimeUnit.SECONDS) - txOpt.foreach { tx => - if (batchSuccess) { - try { - logDebug("Committing transaction") - tx.commit() - } catch { - case e: Exception => - logWarning("Error while attempting to commit transaction. Transaction will be rolled " + - "back", e) - rollbackAndClose(tx, close = false) // tx will be closed later anyway - } finally { - tx.close() - if (isTest) { - testLatch.countDown() - } - } - } else { - logWarning("Spark could not commit transaction, NACK received. Rolling back transaction.") - rollbackAndClose(tx, close = true) - // This might have been due to timeout or a NACK. Either way the following call does not - // cause issues. This is required to ensure the TransactionProcessor instance is not leaked - parent.removeAndGetProcessor(seqNum) - } - } - } - - /** - * Helper method to rollback and optionally close a transaction - * @param tx The transaction to rollback - * @param close Whether the transaction should be closed or not after rolling back - */ - private def rollbackAndClose(tx: Transaction, close: Boolean) { - try { - logWarning("Spark was unable to successfully process the events. Transaction is being " + - "rolled back.") - tx.rollback() - } catch { - case e: Exception => - logError("Error rolling back transaction. Rollback may have failed!", e) - } finally { - if (close) { - tx.close() - } - } - } - - /** - * Helper method to convert a Map[String, String] to Map[CharSequence, CharSequence] - * @param inMap The map to be converted - * @return The converted map - */ - private def toCharSequenceMap(inMap: java.util.Map[String, String]): java.util.Map[CharSequence, - CharSequence] = { - val charSeqMap = new util.HashMap[CharSequence, CharSequence](inMap.size()) - charSeqMap.putAll(inMap) - charSeqMap - } - - /** - * When the thread is started it sets as many events as the batch size or less (if enough - * events aren't available) into the eventBatch and object and lets any threads waiting on the - * [[getEventBatch]] method to proceed. Then this thread waits for acks or nacks to come in, - * or for a specified timeout and commits or rolls back the transaction. - * @return - */ - override def call(): Void = { - populateEvents() - processAckOrNack() - null - } - - private[sink] def countDownWhenBatchAcked(latch: CountDownLatch) { - testLatch = latch - isTest = true - } -} diff --git a/external/flume-sink/src/test/resources/log4j.properties b/external/flume-sink/src/test/resources/log4j.properties deleted file mode 100644 index 1e3f163f95c09..0000000000000 --- a/external/flume-sink/src/test/resources/log4j.properties +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Set everything to be logged to the file streaming/target/unit-tests.log -log4j.rootCategory=INFO, file -log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file.append=true -log4j.appender.file.file=target/unit-tests.log -log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n - -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark_project.jetty=WARN - diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala deleted file mode 100644 index e8ca1e716394d..0000000000000 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ /dev/null @@ -1,218 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.flume.sink - -import java.net.InetSocketAddress -import java.nio.charset.StandardCharsets -import java.util.concurrent.{CountDownLatch, Executors, TimeUnit} -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.JavaConverters._ -import scala.concurrent.{ExecutionContext, Future} -import scala.util.{Failure, Success} - -import org.apache.avro.ipc.NettyTransceiver -import org.apache.avro.ipc.specific.SpecificRequestor -import org.apache.flume.Context -import org.apache.flume.channel.MemoryChannel -import org.apache.flume.event.EventBuilder -import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory - -// Due to MNG-1378, there is not a way to include test dependencies transitively. -// We cannot include Spark core tests as a dependency here because it depends on -// Spark core main, which has too many dependencies to require here manually. -// For this reason, we continue to use FunSuite and ignore the scalastyle checks -// that fail if this is detected. -// scalastyle:off -import org.scalatest.FunSuite - -class SparkSinkSuite extends FunSuite { -// scalastyle:on - - val eventsPerBatch = 1000 - val channelCapacity = 5000 - - test("Success with ack") { - val (channel, sink, latch) = initializeChannelAndSink() - channel.start() - sink.start() - - putEvents(channel, eventsPerBatch) - - val port = sink.getPort - val address = new InetSocketAddress("0.0.0.0", port) - - val (transceiver, client) = getTransceiverAndClient(address, 1)(0) - val events = client.getEventBatch(1000) - client.ack(events.getSequenceNumber) - assert(events.getEvents.size() === 1000) - latch.await(1, TimeUnit.SECONDS) - assertChannelIsEmpty(channel) - sink.stop() - channel.stop() - transceiver.close() - } - - test("Failure with nack") { - val (channel, sink, latch) = initializeChannelAndSink() - channel.start() - sink.start() - putEvents(channel, eventsPerBatch) - - val port = sink.getPort - val address = new InetSocketAddress("0.0.0.0", port) - - val (transceiver, client) = getTransceiverAndClient(address, 1)(0) - val events = client.getEventBatch(1000) - assert(events.getEvents.size() === 1000) - client.nack(events.getSequenceNumber) - latch.await(1, TimeUnit.SECONDS) - assert(availableChannelSlots(channel) === 4000) - sink.stop() - channel.stop() - transceiver.close() - } - - test("Failure with timeout") { - val (channel, sink, latch) = initializeChannelAndSink(Map(SparkSinkConfig - .CONF_TRANSACTION_TIMEOUT -> 1.toString)) - channel.start() - sink.start() - putEvents(channel, eventsPerBatch) - val port = sink.getPort - val address = new InetSocketAddress("0.0.0.0", port) - - val (transceiver, client) = getTransceiverAndClient(address, 1)(0) - val events = client.getEventBatch(1000) - assert(events.getEvents.size() === 1000) - latch.await(1, TimeUnit.SECONDS) - assert(availableChannelSlots(channel) === 4000) - sink.stop() - channel.stop() - transceiver.close() - } - - test("Multiple consumers") { - testMultipleConsumers(failSome = false) - } - - test("Multiple consumers with some failures") { - testMultipleConsumers(failSome = true) - } - - def testMultipleConsumers(failSome: Boolean): Unit = { - implicit val executorContext = ExecutionContext - .fromExecutorService(Executors.newFixedThreadPool(5)) - val (channel, sink, latch) = initializeChannelAndSink(Map.empty, 5) - channel.start() - sink.start() - (1 to 5).foreach(_ => putEvents(channel, eventsPerBatch)) - val port = sink.getPort - val address = new InetSocketAddress("0.0.0.0", port) - val transceiversAndClients = getTransceiverAndClient(address, 5) - val batchCounter = new CountDownLatch(5) - val counter = new AtomicInteger(0) - transceiversAndClients.foreach(x => { - Future { - val client = x._2 - val events = client.getEventBatch(1000) - if (!failSome || counter.getAndIncrement() % 2 == 0) { - client.ack(events.getSequenceNumber) - } else { - client.nack(events.getSequenceNumber) - throw new RuntimeException("Sending NACK for failure!") - } - events - }.onComplete { - case Success(events) => - assert(events.getEvents.size() === 1000) - batchCounter.countDown() - case Failure(t) => - // Don't re-throw the exception, causes a nasty unnecessary stack trace on stdout - batchCounter.countDown() - } - }) - batchCounter.await() - latch.await(1, TimeUnit.SECONDS) - executorContext.shutdown() - if(failSome) { - assert(availableChannelSlots(channel) === 3000) - } else { - assertChannelIsEmpty(channel) - } - sink.stop() - channel.stop() - transceiversAndClients.foreach(x => x._1.close()) - } - - private def initializeChannelAndSink(overrides: Map[String, String] = Map.empty, - batchCounter: Int = 1): (MemoryChannel, SparkSink, CountDownLatch) = { - val channel = new MemoryChannel() - val channelContext = new Context() - - channelContext.put("capacity", channelCapacity.toString) - channelContext.put("transactionCapacity", 1000.toString) - channelContext.put("keep-alive", 0.toString) - channelContext.putAll(overrides.asJava) - channel.setName(scala.util.Random.nextString(10)) - channel.configure(channelContext) - - val sink = new SparkSink() - val sinkContext = new Context() - sinkContext.put(SparkSinkConfig.CONF_HOSTNAME, "0.0.0.0") - sinkContext.put(SparkSinkConfig.CONF_PORT, 0.toString) - sink.configure(sinkContext) - sink.setChannel(channel) - val latch = new CountDownLatch(batchCounter) - sink.countdownWhenBatchReceived(latch) - (channel, sink, latch) - } - - private def putEvents(ch: MemoryChannel, count: Int): Unit = { - val tx = ch.getTransaction - tx.begin() - (1 to count).foreach(x => - ch.put(EventBuilder.withBody(x.toString.getBytes(StandardCharsets.UTF_8)))) - tx.commit() - tx.close() - } - - private def getTransceiverAndClient(address: InetSocketAddress, - count: Int): Seq[(NettyTransceiver, SparkFlumeProtocol.Callback)] = { - - (1 to count).map(_ => { - lazy val channelFactoryExecutor = Executors.newCachedThreadPool( - new SparkSinkThreadFactory("Flume Receiver Channel Thread - %d")) - lazy val channelFactory = - new NioClientSocketChannelFactory(channelFactoryExecutor, channelFactoryExecutor) - val transceiver = new NettyTransceiver(address, channelFactory) - val client = SpecificRequestor.getClient(classOf[SparkFlumeProtocol.Callback], transceiver) - (transceiver, client) - }) - } - - private def assertChannelIsEmpty(channel: MemoryChannel): Unit = { - assert(availableChannelSlots(channel) === channelCapacity) - } - - private def availableChannelSlots(channel: MemoryChannel): Int = { - val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") - queueRemaining.setAccessible(true) - val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") - m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] - } -} diff --git a/external/flume/pom.xml b/external/flume/pom.xml deleted file mode 100644 index 1410ef7f4702d..0000000000000 --- a/external/flume/pom.xml +++ /dev/null @@ -1,89 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.11 - 3.0.0-SNAPSHOT - ../../pom.xml - - - spark-streaming-flume_2.11 - - streaming-flume - - jar - Spark Project External Flume - http://spark.apache.org/ - - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - provided - - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - - - org.apache.spark - spark-streaming-flume-sink_${scala.binary.version} - ${project.version} - - - org.apache.flume - flume-ng-core - - - org.apache.flume - flume-ng-sdk - - - org.scalacheck - scalacheck_${scala.binary.version} - test - - - org.apache.spark - spark-tags_${scala.binary.version} - - - - - org.apache.spark - spark-tags_${scala.binary.version} - test-jar - test - - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - diff --git a/external/flume/src/main/java/org/apache/spark/examples/JavaFlumeEventCount.java b/external/flume/src/main/java/org/apache/spark/examples/JavaFlumeEventCount.java deleted file mode 100644 index 4e3420d9c3b06..0000000000000 --- a/external/flume/src/main/java/org/apache/spark/examples/JavaFlumeEventCount.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.streaming; - -import org.apache.spark.SparkConf; -import org.apache.spark.streaming.*; -import org.apache.spark.streaming.api.java.*; -import org.apache.spark.streaming.flume.FlumeUtils; -import org.apache.spark.streaming.flume.SparkFlumeEvent; - -/** - * Produces a count of events received from Flume. - * - * This should be used in conjunction with an AvroSink in Flume. It will start - * an Avro server on at the request host:port address and listen for requests. - * Your Flume AvroSink should be pointed to this address. - * - * Usage: JavaFlumeEventCount - * is the host the Flume receiver will be started on - a receiver - * creates a server and listens for flume events. - * is the port the Flume receiver will listen on. - * - * To run this example: - * `$ bin/run-example org.apache.spark.examples.streaming.JavaFlumeEventCount ` - */ -public final class JavaFlumeEventCount { - private JavaFlumeEventCount() { - } - - public static void main(String[] args) throws Exception { - if (args.length != 2) { - System.err.println("Usage: JavaFlumeEventCount "); - System.exit(1); - } - - String host = args[0]; - int port = Integer.parseInt(args[1]); - - Duration batchInterval = new Duration(2000); - SparkConf sparkConf = new SparkConf().setAppName("JavaFlumeEventCount"); - JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, batchInterval); - JavaReceiverInputDStream flumeStream = - FlumeUtils.createStream(ssc, host, port); - - flumeStream.count(); - - flumeStream.count().map(in -> "Received " + in + " flume events.").print(); - - ssc.start(); - ssc.awaitTermination(); - } -} diff --git a/external/flume/src/main/scala/org/apache/spark/examples/FlumeEventCount.scala b/external/flume/src/main/scala/org/apache/spark/examples/FlumeEventCount.scala deleted file mode 100644 index f877f79391b37..0000000000000 --- a/external/flume/src/main/scala/org/apache/spark/examples/FlumeEventCount.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.streaming - -import org.apache.spark.SparkConf -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming._ -import org.apache.spark.streaming.flume._ -import org.apache.spark.util.IntParam - -/** - * Produces a count of events received from Flume. - * - * This should be used in conjunction with an AvroSink in Flume. It will start - * an Avro server on at the request host:port address and listen for requests. - * Your Flume AvroSink should be pointed to this address. - * - * Usage: FlumeEventCount - * is the host the Flume receiver will be started on - a receiver - * creates a server and listens for flume events. - * is the port the Flume receiver will listen on. - * - * To run this example: - * `$ bin/run-example org.apache.spark.examples.streaming.FlumeEventCount ` - */ -object FlumeEventCount { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println( - "Usage: FlumeEventCount ") - System.exit(1) - } - - val Array(host, IntParam(port)) = args - - val batchInterval = Milliseconds(2000) - - // Create the context and set the batch size - val sparkConf = new SparkConf().setAppName("FlumeEventCount") - val ssc = new StreamingContext(sparkConf, batchInterval) - - // Create a flume stream - val stream = FlumeUtils.createStream(ssc, host, port, StorageLevel.MEMORY_ONLY_SER_2) - - // Print out the count of events received from this server in each batch - stream.count().map(cnt => "Received " + cnt + " flume events." ).print() - - ssc.start() - ssc.awaitTermination() - } -} -// scalastyle:on println diff --git a/external/flume/src/main/scala/org/apache/spark/examples/FlumePollingEventCount.scala b/external/flume/src/main/scala/org/apache/spark/examples/FlumePollingEventCount.scala deleted file mode 100644 index 79a4027ca5bde..0000000000000 --- a/external/flume/src/main/scala/org/apache/spark/examples/FlumePollingEventCount.scala +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.streaming - -import org.apache.spark.SparkConf -import org.apache.spark.streaming._ -import org.apache.spark.streaming.flume._ -import org.apache.spark.util.IntParam - -/** - * Produces a count of events received from Flume. - * - * This should be used in conjunction with the Spark Sink running in a Flume agent. See - * the Spark Streaming programming guide for more details. - * - * Usage: FlumePollingEventCount - * `host` is the host on which the Spark Sink is running. - * `port` is the port at which the Spark Sink is listening. - * - * To run this example: - * `$ bin/run-example org.apache.spark.examples.streaming.FlumePollingEventCount [host] [port] ` - */ -object FlumePollingEventCount { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println( - "Usage: FlumePollingEventCount ") - System.exit(1) - } - - val Array(host, IntParam(port)) = args - - val batchInterval = Milliseconds(2000) - - // Create the context and set the batch size - val sparkConf = new SparkConf().setAppName("FlumePollingEventCount") - val ssc = new StreamingContext(sparkConf, batchInterval) - - // Create a flume stream that polls the Spark Sink running in a Flume agent - val stream = FlumeUtils.createPollingStream(ssc, host, port) - - // Print out the count of events received from this server in each batch - stream.count().map(cnt => "Received " + cnt + " flume events." ).print() - - ssc.start() - ssc.awaitTermination() - } -} -// scalastyle:on println diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala deleted file mode 100644 index 07c5286477737..0000000000000 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.flume - -import java.io.{ObjectInput, ObjectOutput} - -import scala.collection.JavaConverters._ - -import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils - -/** - * A simple object that provides the implementation of readExternal and writeExternal for both - * the wrapper classes for Flume-style Events. - */ -private[streaming] object EventTransformer extends Logging { - def readExternal(in: ObjectInput): (java.util.HashMap[CharSequence, CharSequence], - Array[Byte]) = { - val bodyLength = in.readInt() - val bodyBuff = new Array[Byte](bodyLength) - in.readFully(bodyBuff) - - val numHeaders = in.readInt() - val headers = new java.util.HashMap[CharSequence, CharSequence] - - for (i <- 0 until numHeaders) { - val keyLength = in.readInt() - val keyBuff = new Array[Byte](keyLength) - in.readFully(keyBuff) - val key: String = Utils.deserialize(keyBuff) - - val valLength = in.readInt() - val valBuff = new Array[Byte](valLength) - in.readFully(valBuff) - val value: String = Utils.deserialize(valBuff) - - headers.put(key, value) - } - (headers, bodyBuff) - } - - def writeExternal(out: ObjectOutput, headers: java.util.Map[CharSequence, CharSequence], - body: Array[Byte]) { - out.writeInt(body.length) - out.write(body) - val numHeaders = headers.size() - out.writeInt(numHeaders) - for ((k, v) <- headers.asScala) { - val keyBuff = Utils.serialize(k.toString) - out.writeInt(keyBuff.length) - out.write(keyBuff) - val valBuff = Utils.serialize(v.toString) - out.writeInt(valBuff.length) - out.write(valBuff) - } - } -} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala deleted file mode 100644 index 8af7c23431063..0000000000000 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala +++ /dev/null @@ -1,166 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.flume - -import scala.collection.mutable.ArrayBuffer - -import com.google.common.base.Throwables - -import org.apache.spark.internal.Logging -import org.apache.spark.streaming.flume.sink._ - -/** - * This class implements the core functionality of [[FlumePollingReceiver]]. When started it - * pulls data from Flume, stores it to Spark and then sends an Ack or Nack. This class should be - * run via a [[java.util.concurrent.Executor]] as this implements [[Runnable]] - * - * @param receiver The receiver that owns this instance. - */ - -private[flume] class FlumeBatchFetcher(receiver: FlumePollingReceiver) extends Runnable with - Logging { - - def run(): Unit = { - while (!receiver.isStopped()) { - val connection = receiver.getConnections.poll() - val client = connection.client - var batchReceived = false - var seq: CharSequence = null - try { - getBatch(client) match { - case Some(eventBatch) => - batchReceived = true - seq = eventBatch.getSequenceNumber - val events = toSparkFlumeEvents(eventBatch.getEvents) - if (store(events)) { - sendAck(client, seq) - } else { - sendNack(batchReceived, client, seq) - } - case None => - } - } catch { - case e: Exception => - Throwables.getRootCause(e) match { - // If the cause was an InterruptedException, then check if the receiver is stopped - - // if yes, just break out of the loop. Else send a Nack and log a warning. - // In the unlikely case, the cause was not an Exception, - // then just throw it out and exit. - case interrupted: InterruptedException => - if (!receiver.isStopped()) { - logWarning("Interrupted while receiving data from Flume", interrupted) - sendNack(batchReceived, client, seq) - } - case exception: Exception => - logWarning("Error while receiving data from Flume", exception) - sendNack(batchReceived, client, seq) - } - } finally { - receiver.getConnections.add(connection) - } - } - } - - /** - * Gets a batch of events from the specified client. This method does not handle any exceptions - * which will be propagated to the caller. - * @param client Client to get events from - * @return [[Some]] which contains the event batch if Flume sent any events back, else [[None]] - */ - private def getBatch(client: SparkFlumeProtocol.Callback): Option[EventBatch] = { - val eventBatch = client.getEventBatch(receiver.getMaxBatchSize) - if (!SparkSinkUtils.isErrorBatch(eventBatch)) { - // No error, proceed with processing data - logDebug(s"Received batch of ${eventBatch.getEvents.size} events with sequence " + - s"number: ${eventBatch.getSequenceNumber}") - Some(eventBatch) - } else { - logWarning("Did not receive events from Flume agent due to error on the Flume agent: " + - eventBatch.getErrorMsg) - None - } - } - - /** - * Store the events in the buffer to Spark. This method will not propagate any exceptions, - * but will propagate any other errors. - * @param buffer The buffer to store - * @return true if the data was stored without any exception being thrown, else false - */ - private def store(buffer: ArrayBuffer[SparkFlumeEvent]): Boolean = { - try { - receiver.store(buffer) - true - } catch { - case e: Exception => - logWarning("Error while attempting to store data received from Flume", e) - false - } - } - - /** - * Send an ack to the client for the sequence number. This method does not handle any exceptions - * which will be propagated to the caller. - * @param client client to send the ack to - * @param seq sequence number of the batch to be ack-ed. - * @return - */ - private def sendAck(client: SparkFlumeProtocol.Callback, seq: CharSequence): Unit = { - logDebug("Sending ack for sequence number: " + seq) - client.ack(seq) - logDebug("Ack sent for sequence number: " + seq) - } - - /** - * This method sends a Nack if a batch was received to the client with the given sequence - * number. Any exceptions thrown by the RPC call is simply thrown out as is - no effort is made - * to handle it. - * @param batchReceived true if a batch was received. If this is false, no nack is sent - * @param client The client to which the nack should be sent - * @param seq The sequence number of the batch that is being nack-ed. - */ - private def sendNack(batchReceived: Boolean, client: SparkFlumeProtocol.Callback, - seq: CharSequence): Unit = { - if (batchReceived) { - // Let Flume know that the events need to be pushed back into the channel. - logDebug("Sending nack for sequence number: " + seq) - client.nack(seq) // If the agent is down, even this could fail and throw - logDebug("Nack sent for sequence number: " + seq) - } - } - - /** - * Utility method to convert [[SparkSinkEvent]]s to [[SparkFlumeEvent]]s - * @param events - Events to convert to SparkFlumeEvents - * @return - The SparkFlumeEvent generated from SparkSinkEvent - */ - private def toSparkFlumeEvents(events: java.util.List[SparkSinkEvent]): - ArrayBuffer[SparkFlumeEvent] = { - // Convert each Flume event to a serializable SparkFlumeEvent - val buffer = new ArrayBuffer[SparkFlumeEvent](events.size()) - var j = 0 - while (j < events.size()) { - val event = events.get(j) - val sparkFlumeEvent = new SparkFlumeEvent() - sparkFlumeEvent.event.setBody(event.getBody) - sparkFlumeEvent.event.setHeaders(event.getHeaders) - buffer += sparkFlumeEvent - j += 1 - } - buffer - } -} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala deleted file mode 100644 index 13aa817492f7b..0000000000000 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.flume - -import java.io.{Externalizable, ObjectInput, ObjectOutput} -import java.net.InetSocketAddress -import java.nio.ByteBuffer -import java.util.concurrent.Executors - -import scala.collection.JavaConverters._ -import scala.reflect.ClassTag - -import org.apache.avro.ipc.NettyServer -import org.apache.avro.ipc.specific.SpecificResponder -import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol, Status} -import org.jboss.netty.channel.{ChannelPipeline, ChannelPipelineFactory, Channels} -import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory -import org.jboss.netty.handler.codec.compression._ - -import org.apache.spark.internal.Logging -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.dstream._ -import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.util.Utils - -private[streaming] -class FlumeInputDStream[T: ClassTag]( - _ssc: StreamingContext, - host: String, - port: Int, - storageLevel: StorageLevel, - enableDecompression: Boolean -) extends ReceiverInputDStream[SparkFlumeEvent](_ssc) { - - override def getReceiver(): Receiver[SparkFlumeEvent] = { - new FlumeReceiver(host, port, storageLevel, enableDecompression) - } -} - -/** - * A wrapper class for AvroFlumeEvent's with a custom serialization format. - * - * This is necessary because AvroFlumeEvent uses inner data structures - * which are not serializable. - */ -class SparkFlumeEvent() extends Externalizable { - var event: AvroFlumeEvent = new AvroFlumeEvent() - - /* De-serialize from bytes. */ - def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { - val bodyLength = in.readInt() - val bodyBuff = new Array[Byte](bodyLength) - in.readFully(bodyBuff) - - val numHeaders = in.readInt() - val headers = new java.util.HashMap[CharSequence, CharSequence] - - for (i <- 0 until numHeaders) { - val keyLength = in.readInt() - val keyBuff = new Array[Byte](keyLength) - in.readFully(keyBuff) - val key: String = Utils.deserialize(keyBuff) - - val valLength = in.readInt() - val valBuff = new Array[Byte](valLength) - in.readFully(valBuff) - val value: String = Utils.deserialize(valBuff) - - headers.put(key, value) - } - - event.setBody(ByteBuffer.wrap(bodyBuff)) - event.setHeaders(headers) - } - - /* Serialize to bytes. */ - def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - val body = event.getBody - out.writeInt(body.remaining()) - Utils.writeByteBuffer(body, out) - - val numHeaders = event.getHeaders.size() - out.writeInt(numHeaders) - for ((k, v) <- event.getHeaders.asScala) { - val keyBuff = Utils.serialize(k.toString) - out.writeInt(keyBuff.length) - out.write(keyBuff) - val valBuff = Utils.serialize(v.toString) - out.writeInt(valBuff.length) - out.write(valBuff) - } - } -} - -private[streaming] object SparkFlumeEvent { - def fromAvroFlumeEvent(in: AvroFlumeEvent): SparkFlumeEvent = { - val event = new SparkFlumeEvent - event.event = in - event - } -} - -/** A simple server that implements Flume's Avro protocol. */ -private[streaming] -class FlumeEventServer(receiver: FlumeReceiver) extends AvroSourceProtocol { - override def append(event: AvroFlumeEvent): Status = { - receiver.store(SparkFlumeEvent.fromAvroFlumeEvent(event)) - Status.OK - } - - override def appendBatch(events: java.util.List[AvroFlumeEvent]): Status = { - events.asScala.foreach(event => receiver.store(SparkFlumeEvent.fromAvroFlumeEvent(event))) - Status.OK - } -} - -/** - * A NetworkReceiver which listens for events using the - * Flume Avro interface. - */ -private[streaming] -class FlumeReceiver( - host: String, - port: Int, - storageLevel: StorageLevel, - enableDecompression: Boolean - ) extends Receiver[SparkFlumeEvent](storageLevel) with Logging { - - lazy val responder = new SpecificResponder( - classOf[AvroSourceProtocol], new FlumeEventServer(this)) - var server: NettyServer = null - - private def initServer() = { - if (enableDecompression) { - val channelFactory = new NioServerSocketChannelFactory(Executors.newCachedThreadPool(), - Executors.newCachedThreadPool()) - val channelPipelineFactory = new CompressionChannelPipelineFactory() - - new NettyServer( - responder, - new InetSocketAddress(host, port), - channelFactory, - channelPipelineFactory, - null) - } else { - new NettyServer(responder, new InetSocketAddress(host, port)) - } - } - - def onStart() { - synchronized { - if (server == null) { - server = initServer() - server.start() - } else { - logWarning("Flume receiver being asked to start more then once with out close") - } - } - logInfo("Flume receiver started") - } - - def onStop() { - synchronized { - if (server != null) { - server.close() - server = null - } - } - logInfo("Flume receiver stopped") - } - - override def preferredLocation: Option[String] = Option(host) - - /** - * A Netty Pipeline factory that will decompress incoming data from - * and the Netty client and compress data going back to the client. - * - * The compression on the return is required because Flume requires - * a successful response to indicate it can remove the event/batch - * from the configured channel - */ - private[streaming] - class CompressionChannelPipelineFactory extends ChannelPipelineFactory { - def getPipeline(): ChannelPipeline = { - val pipeline = Channels.pipeline() - val encoder = new ZlibEncoder(6) - pipeline.addFirst("deflater", encoder) - pipeline.addFirst("inflater", new ZlibDecoder()) - pipeline - } - } -} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala deleted file mode 100644 index d84e289272c62..0000000000000 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.flume - - -import java.net.InetSocketAddress -import java.util.concurrent.{Executors, LinkedBlockingQueue, TimeUnit} - -import scala.collection.JavaConverters._ -import scala.reflect.ClassTag - -import com.google.common.util.concurrent.ThreadFactoryBuilder -import org.apache.avro.ipc.NettyTransceiver -import org.apache.avro.ipc.specific.SpecificRequestor -import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory - -import org.apache.spark.internal.Logging -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.flume.sink._ -import org.apache.spark.streaming.receiver.Receiver - -/** - * A `ReceiverInputDStream` that can be used to read data from several Flume agents running - * [[org.apache.spark.streaming.flume.sink.SparkSink]]s. - * @param _ssc Streaming context that will execute this input stream - * @param addresses List of addresses at which SparkSinks are listening - * @param maxBatchSize Maximum size of a batch - * @param parallelism Number of parallel connections to open - * @param storageLevel The storage level to use. - * @tparam T Class type of the object of this stream - */ -private[streaming] class FlumePollingInputDStream[T: ClassTag]( - _ssc: StreamingContext, - val addresses: Seq[InetSocketAddress], - val maxBatchSize: Int, - val parallelism: Int, - storageLevel: StorageLevel - ) extends ReceiverInputDStream[SparkFlumeEvent](_ssc) { - - override def getReceiver(): Receiver[SparkFlumeEvent] = { - new FlumePollingReceiver(addresses, maxBatchSize, parallelism, storageLevel) - } -} - -private[streaming] class FlumePollingReceiver( - addresses: Seq[InetSocketAddress], - maxBatchSize: Int, - parallelism: Int, - storageLevel: StorageLevel - ) extends Receiver[SparkFlumeEvent](storageLevel) with Logging { - - lazy val channelFactoryExecutor = - Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true). - setNameFormat("Flume Receiver Channel Thread - %d").build()) - - lazy val channelFactory = - new NioClientSocketChannelFactory(channelFactoryExecutor, channelFactoryExecutor) - - lazy val receiverExecutor = Executors.newFixedThreadPool(parallelism, - new ThreadFactoryBuilder().setDaemon(true).setNameFormat("Flume Receiver Thread - %d").build()) - - private lazy val connections = new LinkedBlockingQueue[FlumeConnection]() - - override def onStart(): Unit = { - // Create the connections to each Flume agent. - addresses.foreach { host => - val transceiver = new NettyTransceiver(host, channelFactory) - val client = SpecificRequestor.getClient(classOf[SparkFlumeProtocol.Callback], transceiver) - connections.add(new FlumeConnection(transceiver, client)) - } - for (i <- 0 until parallelism) { - logInfo("Starting Flume Polling Receiver worker threads..") - // Threads that pull data from Flume. - receiverExecutor.submit(new FlumeBatchFetcher(this)) - } - } - - override def onStop(): Unit = { - logInfo("Shutting down Flume Polling Receiver") - receiverExecutor.shutdown() - // Wait upto a minute for the threads to die - if (!receiverExecutor.awaitTermination(60, TimeUnit.SECONDS)) { - receiverExecutor.shutdownNow() - } - connections.asScala.foreach(_.transceiver.close()) - channelFactory.releaseExternalResources() - } - - private[flume] def getConnections: LinkedBlockingQueue[FlumeConnection] = { - this.connections - } - - private[flume] def getMaxBatchSize: Int = { - this.maxBatchSize - } -} - -/** - * A wrapper around the transceiver and the Avro IPC API. - * @param transceiver The transceiver to use for communication with Flume - * @param client The client that the callbacks are received on. - */ -private[flume] class FlumeConnection(val transceiver: NettyTransceiver, - val client: SparkFlumeProtocol.Callback) - - - diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala deleted file mode 100644 index e8623b4766aea..0000000000000 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.flume - -import java.net.{InetSocketAddress, ServerSocket} -import java.nio.ByteBuffer -import java.nio.charset.StandardCharsets -import java.util.{List => JList} -import java.util.Collections - -import scala.collection.JavaConverters._ - -import org.apache.avro.ipc.NettyTransceiver -import org.apache.avro.ipc.specific.SpecificRequestor -import org.apache.commons.lang3.RandomUtils -import org.apache.flume.source.avro -import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol} -import org.jboss.netty.channel.ChannelPipeline -import org.jboss.netty.channel.socket.SocketChannel -import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory -import org.jboss.netty.handler.codec.compression.{ZlibDecoder, ZlibEncoder} - -import org.apache.spark.SparkConf -import org.apache.spark.util.Utils - -/** - * Share codes for Scala and Python unit tests - */ -private[flume] class FlumeTestUtils { - - private var transceiver: NettyTransceiver = null - - private val testPort: Int = findFreePort() - - def getTestPort(): Int = testPort - - /** Find a free port */ - private def findFreePort(): Int = { - val candidatePort = RandomUtils.nextInt(1024, 65536) - Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { - val socket = new ServerSocket(trialPort) - socket.close() - (null, trialPort) - }, new SparkConf())._2 - } - - /** Send data to the flume receiver */ - def writeInput(input: JList[String], enableCompression: Boolean): Unit = { - val testAddress = new InetSocketAddress("localhost", testPort) - - val inputEvents = input.asScala.map { item => - val event = new AvroFlumeEvent - event.setBody(ByteBuffer.wrap(item.getBytes(StandardCharsets.UTF_8))) - event.setHeaders(Collections.singletonMap("test", "header")) - event - } - - // if last attempted transceiver had succeeded, close it - close() - - // Create transceiver - transceiver = { - if (enableCompression) { - new NettyTransceiver(testAddress, new CompressionChannelFactory(6)) - } else { - new NettyTransceiver(testAddress) - } - } - - // Create Avro client with the transceiver - val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver) - if (client == null) { - throw new AssertionError("Cannot create client") - } - - // Send data - val status = client.appendBatch(inputEvents.asJava) - if (status != avro.Status.OK) { - throw new AssertionError("Sent events unsuccessfully") - } - } - - def close(): Unit = { - if (transceiver != null) { - transceiver.close() - transceiver = null - } - } - - /** Class to create socket channel with compression */ - private class CompressionChannelFactory(compressionLevel: Int) - extends NioClientSocketChannelFactory { - - override def newChannel(pipeline: ChannelPipeline): SocketChannel = { - val encoder = new ZlibEncoder(compressionLevel) - pipeline.addFirst("deflater", encoder) - pipeline.addFirst("inflater", new ZlibDecoder()) - super.newChannel(pipeline) - } - } - -} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala deleted file mode 100644 index 707193a957700..0000000000000 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ /dev/null @@ -1,312 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.flume - -import java.io.{ByteArrayOutputStream, DataOutputStream} -import java.net.InetSocketAddress -import java.util.{List => JList, Map => JMap} - -import scala.collection.JavaConverters._ - -import org.apache.spark.api.java.function.PairFunction -import org.apache.spark.api.python.PythonRDD -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaReceiverInputDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.ReceiverInputDStream - -@deprecated("Deprecated without replacement", "2.3.0") -object FlumeUtils { - private val DEFAULT_POLLING_PARALLELISM = 5 - private val DEFAULT_POLLING_BATCH_SIZE = 1000 - - /** - * Create a input stream from a Flume source. - * @param ssc StreamingContext object - * @param hostname Hostname of the slave machine to which the flume data will be sent - * @param port Port of the slave machine to which the flume data will be sent - * @param storageLevel Storage level to use for storing the received objects - */ - def createStream ( - ssc: StreamingContext, - hostname: String, - port: Int, - storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 - ): ReceiverInputDStream[SparkFlumeEvent] = { - createStream(ssc, hostname, port, storageLevel, false) - } - - /** - * Create a input stream from a Flume source. - * @param ssc StreamingContext object - * @param hostname Hostname of the slave machine to which the flume data will be sent - * @param port Port of the slave machine to which the flume data will be sent - * @param storageLevel Storage level to use for storing the received objects - * @param enableDecompression should netty server decompress input stream - */ - def createStream ( - ssc: StreamingContext, - hostname: String, - port: Int, - storageLevel: StorageLevel, - enableDecompression: Boolean - ): ReceiverInputDStream[SparkFlumeEvent] = { - val inputStream = new FlumeInputDStream[SparkFlumeEvent]( - ssc, hostname, port, storageLevel, enableDecompression) - - inputStream - } - - /** - * Creates a input stream from a Flume source. - * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2. - * @param hostname Hostname of the slave machine to which the flume data will be sent - * @param port Port of the slave machine to which the flume data will be sent - */ - def createStream( - jssc: JavaStreamingContext, - hostname: String, - port: Int - ): JavaReceiverInputDStream[SparkFlumeEvent] = { - createStream(jssc.ssc, hostname, port) - } - - /** - * Creates a input stream from a Flume source. - * @param hostname Hostname of the slave machine to which the flume data will be sent - * @param port Port of the slave machine to which the flume data will be sent - * @param storageLevel Storage level to use for storing the received objects - */ - def createStream( - jssc: JavaStreamingContext, - hostname: String, - port: Int, - storageLevel: StorageLevel - ): JavaReceiverInputDStream[SparkFlumeEvent] = { - createStream(jssc.ssc, hostname, port, storageLevel, false) - } - - /** - * Creates a input stream from a Flume source. - * @param hostname Hostname of the slave machine to which the flume data will be sent - * @param port Port of the slave machine to which the flume data will be sent - * @param storageLevel Storage level to use for storing the received objects - * @param enableDecompression should netty server decompress input stream - */ - def createStream( - jssc: JavaStreamingContext, - hostname: String, - port: Int, - storageLevel: StorageLevel, - enableDecompression: Boolean - ): JavaReceiverInputDStream[SparkFlumeEvent] = { - createStream(jssc.ssc, hostname, port, storageLevel, enableDecompression) - } - - /** - * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. - * This stream will poll the sink for data and will pull events as they are available. - * This stream will use a batch size of 1000 events and run 5 threads to pull data. - * @param hostname Address of the host on which the Spark Sink is running - * @param port Port of the host at which the Spark Sink is listening - * @param storageLevel Storage level to use for storing the received objects - */ - def createPollingStream( - ssc: StreamingContext, - hostname: String, - port: Int, - storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 - ): ReceiverInputDStream[SparkFlumeEvent] = { - createPollingStream(ssc, Seq(new InetSocketAddress(hostname, port)), storageLevel) - } - - /** - * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. - * This stream will poll the sink for data and will pull events as they are available. - * This stream will use a batch size of 1000 events and run 5 threads to pull data. - * @param addresses List of InetSocketAddresses representing the hosts to connect to. - * @param storageLevel Storage level to use for storing the received objects - */ - def createPollingStream( - ssc: StreamingContext, - addresses: Seq[InetSocketAddress], - storageLevel: StorageLevel - ): ReceiverInputDStream[SparkFlumeEvent] = { - createPollingStream(ssc, addresses, storageLevel, - DEFAULT_POLLING_BATCH_SIZE, DEFAULT_POLLING_PARALLELISM) - } - - /** - * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. - * This stream will poll the sink for data and will pull events as they are available. - * @param addresses List of InetSocketAddresses representing the hosts to connect to. - * @param maxBatchSize Maximum number of events to be pulled from the Spark sink in a - * single RPC call - * @param parallelism Number of concurrent requests this stream should send to the sink. Note - * that having a higher number of requests concurrently being pulled will - * result in this stream using more threads - * @param storageLevel Storage level to use for storing the received objects - */ - def createPollingStream( - ssc: StreamingContext, - addresses: Seq[InetSocketAddress], - storageLevel: StorageLevel, - maxBatchSize: Int, - parallelism: Int - ): ReceiverInputDStream[SparkFlumeEvent] = { - new FlumePollingInputDStream[SparkFlumeEvent](ssc, addresses, maxBatchSize, - parallelism, storageLevel) - } - - /** - * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. - * This stream will poll the sink for data and will pull events as they are available. - * This stream will use a batch size of 1000 events and run 5 threads to pull data. - * @param hostname Hostname of the host on which the Spark Sink is running - * @param port Port of the host at which the Spark Sink is listening - */ - def createPollingStream( - jssc: JavaStreamingContext, - hostname: String, - port: Int - ): JavaReceiverInputDStream[SparkFlumeEvent] = { - createPollingStream(jssc, hostname, port, StorageLevel.MEMORY_AND_DISK_SER_2) - } - - /** - * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. - * This stream will poll the sink for data and will pull events as they are available. - * This stream will use a batch size of 1000 events and run 5 threads to pull data. - * @param hostname Hostname of the host on which the Spark Sink is running - * @param port Port of the host at which the Spark Sink is listening - * @param storageLevel Storage level to use for storing the received objects - */ - def createPollingStream( - jssc: JavaStreamingContext, - hostname: String, - port: Int, - storageLevel: StorageLevel - ): JavaReceiverInputDStream[SparkFlumeEvent] = { - createPollingStream(jssc, Array(new InetSocketAddress(hostname, port)), storageLevel) - } - - /** - * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. - * This stream will poll the sink for data and will pull events as they are available. - * This stream will use a batch size of 1000 events and run 5 threads to pull data. - * @param addresses List of InetSocketAddresses on which the Spark Sink is running. - * @param storageLevel Storage level to use for storing the received objects - */ - def createPollingStream( - jssc: JavaStreamingContext, - addresses: Array[InetSocketAddress], - storageLevel: StorageLevel - ): JavaReceiverInputDStream[SparkFlumeEvent] = { - createPollingStream(jssc, addresses, storageLevel, - DEFAULT_POLLING_BATCH_SIZE, DEFAULT_POLLING_PARALLELISM) - } - - /** - * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. - * This stream will poll the sink for data and will pull events as they are available. - * @param addresses List of InetSocketAddresses on which the Spark Sink is running - * @param maxBatchSize The maximum number of events to be pulled from the Spark sink in a - * single RPC call - * @param parallelism Number of concurrent requests this stream should send to the sink. Note - * that having a higher number of requests concurrently being pulled will - * result in this stream using more threads - * @param storageLevel Storage level to use for storing the received objects - */ - def createPollingStream( - jssc: JavaStreamingContext, - addresses: Array[InetSocketAddress], - storageLevel: StorageLevel, - maxBatchSize: Int, - parallelism: Int - ): JavaReceiverInputDStream[SparkFlumeEvent] = { - createPollingStream(jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) - } -} - -/** - * This is a helper class that wraps the methods in FlumeUtils into more Python-friendly class and - * function so that it can be easily instantiated and called from Python's FlumeUtils. - */ -private[flume] class FlumeUtilsPythonHelper { - - def createStream( - jssc: JavaStreamingContext, - hostname: String, - port: Int, - storageLevel: StorageLevel, - enableDecompression: Boolean - ): JavaPairDStream[Array[Byte], Array[Byte]] = { - val dstream = FlumeUtils.createStream(jssc, hostname, port, storageLevel, enableDecompression) - FlumeUtilsPythonHelper.toByteArrayPairDStream(dstream) - } - - def createPollingStream( - jssc: JavaStreamingContext, - hosts: JList[String], - ports: JList[Int], - storageLevel: StorageLevel, - maxBatchSize: Int, - parallelism: Int - ): JavaPairDStream[Array[Byte], Array[Byte]] = { - assert(hosts.size() == ports.size()) - val addresses = hosts.asScala.zip(ports.asScala).map { - case (host, port) => new InetSocketAddress(host, port) - } - val dstream = FlumeUtils.createPollingStream( - jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) - FlumeUtilsPythonHelper.toByteArrayPairDStream(dstream) - } - -} - -private object FlumeUtilsPythonHelper { - - private def stringMapToByteArray(map: JMap[CharSequence, CharSequence]): Array[Byte] = { - val byteStream = new ByteArrayOutputStream() - val output = new DataOutputStream(byteStream) - try { - output.writeInt(map.size) - map.asScala.foreach { kv => - PythonRDD.writeUTF(kv._1.toString, output) - PythonRDD.writeUTF(kv._2.toString, output) - } - byteStream.toByteArray - } - finally { - output.close() - } - } - - private def toByteArrayPairDStream(dstream: JavaReceiverInputDStream[SparkFlumeEvent]): - JavaPairDStream[Array[Byte], Array[Byte]] = { - dstream.mapToPair(new PairFunction[SparkFlumeEvent, Array[Byte], Array[Byte]] { - override def call(sparkEvent: SparkFlumeEvent): (Array[Byte], Array[Byte]) = { - val event = sparkEvent.event - val byteBuffer = event.getBody - val body = new Array[Byte](byteBuffer.remaining()) - byteBuffer.get(body) - (stringMapToByteArray(event.getHeaders), body) - } - }) - } -} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala deleted file mode 100644 index a3e784a4f32ee..0000000000000 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala +++ /dev/null @@ -1,209 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.flume - -import java.nio.charset.StandardCharsets -import java.util.{Collections, List => JList, Map => JMap} -import java.util.concurrent._ - -import scala.collection.mutable.ArrayBuffer - -import org.apache.flume.Context -import org.apache.flume.channel.MemoryChannel -import org.apache.flume.conf.Configurables -import org.apache.flume.event.EventBuilder - -import org.apache.spark.streaming.flume.sink.{SparkSink, SparkSinkConfig} - -/** - * Share codes for Scala and Python unit tests - */ -private[flume] class PollingFlumeTestUtils { - - private val batchCount = 5 - val eventsPerBatch = 100 - private val totalEventsPerChannel = batchCount * eventsPerBatch - private val channelCapacity = 5000 - - def getTotalEvents: Int = totalEventsPerChannel * channels.size - - private val channels = new ArrayBuffer[MemoryChannel] - private val sinks = new ArrayBuffer[SparkSink] - - /** - * Start a sink and return the port of this sink - */ - def startSingleSink(): Int = { - channels.clear() - sinks.clear() - - // Start the channel and sink. - val context = new Context() - context.put("capacity", channelCapacity.toString) - context.put("transactionCapacity", "1000") - context.put("keep-alive", "0") - val channel = new MemoryChannel() - Configurables.configure(channel, context) - - val sink = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink, context) - sink.setChannel(channel) - sink.start() - - channels += (channel) - sinks += sink - - sink.getPort() - } - - /** - * Start 2 sinks and return the ports - */ - def startMultipleSinks(): Seq[Int] = { - channels.clear() - sinks.clear() - - // Start the channel and sink. - val context = new Context() - context.put("capacity", channelCapacity.toString) - context.put("transactionCapacity", "1000") - context.put("keep-alive", "0") - val channel = new MemoryChannel() - Configurables.configure(channel, context) - - val channel2 = new MemoryChannel() - Configurables.configure(channel2, context) - - val sink = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink, context) - sink.setChannel(channel) - sink.start() - - val sink2 = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink2, context) - sink2.setChannel(channel2) - sink2.start() - - sinks += sink - sinks += sink2 - channels += channel - channels += channel2 - - sinks.map(_.getPort()) - } - - /** - * Send data and wait until all data has been received - */ - def sendDataAndEnsureAllDataHasBeenReceived(): Unit = { - val executor = Executors.newCachedThreadPool() - val executorCompletion = new ExecutorCompletionService[Void](executor) - - val latch = new CountDownLatch(batchCount * channels.size) - sinks.foreach(_.countdownWhenBatchReceived(latch)) - - channels.foreach { channel => - executorCompletion.submit(new TxnSubmitter(channel)) - } - - for (i <- 0 until channels.size) { - executorCompletion.take() - } - - latch.await(15, TimeUnit.SECONDS) // Ensure all data has been received. - } - - /** - * A Python-friendly method to assert the output - */ - def assertOutput( - outputHeaders: JList[JMap[String, String]], outputBodies: JList[String]): Unit = { - require(outputHeaders.size == outputBodies.size) - val eventSize = outputHeaders.size - if (eventSize != totalEventsPerChannel * channels.size) { - throw new AssertionError( - s"Expected ${totalEventsPerChannel * channels.size} events, but was $eventSize") - } - var counter = 0 - for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { - val eventBodyToVerify = s"${channels(k).getName}-$i" - val eventHeaderToVerify: JMap[String, String] = Collections.singletonMap(s"test-$i", "header") - var found = false - var j = 0 - while (j < eventSize && !found) { - if (eventBodyToVerify == outputBodies.get(j) && - eventHeaderToVerify == outputHeaders.get(j)) { - found = true - counter += 1 - } - j += 1 - } - } - if (counter != totalEventsPerChannel * channels.size) { - throw new AssertionError( - s"111 Expected ${totalEventsPerChannel * channels.size} events, but was $counter") - } - } - - def assertChannelsAreEmpty(): Unit = { - channels.foreach(assertChannelIsEmpty) - } - - private def assertChannelIsEmpty(channel: MemoryChannel): Unit = { - val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") - queueRemaining.setAccessible(true) - val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") - if (m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] != channelCapacity) { - throw new AssertionError(s"Channel ${channel.getName} is not empty") - } - } - - def close(): Unit = { - sinks.foreach(_.stop()) - sinks.clear() - channels.foreach(_.stop()) - channels.clear() - } - - private class TxnSubmitter(channel: MemoryChannel) extends Callable[Void] { - override def call(): Void = { - var t = 0 - for (i <- 0 until batchCount) { - val tx = channel.getTransaction - tx.begin() - for (j <- 0 until eventsPerBatch) { - channel.put(EventBuilder.withBody( - s"${channel.getName}-$t".getBytes(StandardCharsets.UTF_8), - Collections.singletonMap(s"test-$t", "header"))) - t += 1 - } - tx.commit() - tx.close() - Thread.sleep(500) // Allow some time for the events to reach - } - null - } - } - -} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/package-info.java b/external/flume/src/main/scala/org/apache/spark/streaming/flume/package-info.java deleted file mode 100644 index 4a5da226aded3..0000000000000 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/package-info.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * Spark streaming receiver for Flume. - */ -package org.apache.spark.streaming.flume; diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/package.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/package.scala deleted file mode 100644 index 9bfab68c4b8b7..0000000000000 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/package.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming - -/** - * Spark streaming receiver for Flume. - */ -package object flume diff --git a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java deleted file mode 100644 index cfedb5a042a35..0000000000000 --- a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming; - -import org.apache.spark.SparkConf; -import org.apache.spark.streaming.api.java.JavaStreamingContext; -import org.junit.After; -import org.junit.Before; - -public abstract class LocalJavaStreamingContext { - - protected transient JavaStreamingContext ssc; - - @Before - public void setUp() { - SparkConf conf = new SparkConf() - .setMaster("local[2]") - .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); - ssc = new JavaStreamingContext(conf, new Duration(1000)); - ssc.checkpoint("checkpoint"); - } - - @After - public void tearDown() { - ssc.stop(); - ssc = null; - } -} diff --git a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumePollingStreamSuite.java b/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumePollingStreamSuite.java deleted file mode 100644 index 79c5b91654b42..0000000000000 --- a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumePollingStreamSuite.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.flume; - -import java.net.InetSocketAddress; - -import org.apache.spark.storage.StorageLevel; -import org.apache.spark.streaming.LocalJavaStreamingContext; - -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; -import org.junit.Test; - -public class JavaFlumePollingStreamSuite extends LocalJavaStreamingContext { - @Test - public void testFlumeStream() { - // tests the API, does not actually test data receiving - InetSocketAddress[] addresses = new InetSocketAddress[] { - new InetSocketAddress("localhost", 12345) - }; - JavaReceiverInputDStream test1 = - FlumeUtils.createPollingStream(ssc, "localhost", 12345); - JavaReceiverInputDStream test2 = FlumeUtils.createPollingStream( - ssc, "localhost", 12345, StorageLevel.MEMORY_AND_DISK_SER_2()); - JavaReceiverInputDStream test3 = FlumeUtils.createPollingStream( - ssc, addresses, StorageLevel.MEMORY_AND_DISK_SER_2()); - JavaReceiverInputDStream test4 = FlumeUtils.createPollingStream( - ssc, addresses, StorageLevel.MEMORY_AND_DISK_SER_2(), 100, 5); - } -} diff --git a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java b/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java deleted file mode 100644 index ada05f203b6a8..0000000000000 --- a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.flume; - -import org.apache.spark.storage.StorageLevel; -import org.apache.spark.streaming.LocalJavaStreamingContext; - -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; -import org.junit.Test; - -public class JavaFlumeStreamSuite extends LocalJavaStreamingContext { - @Test - public void testFlumeStream() { - // tests the API, does not actually test data receiving - JavaReceiverInputDStream test1 = FlumeUtils.createStream(ssc, "localhost", - 12345); - JavaReceiverInputDStream test2 = FlumeUtils.createStream(ssc, "localhost", - 12345, StorageLevel.MEMORY_AND_DISK_SER_2()); - JavaReceiverInputDStream test3 = FlumeUtils.createStream(ssc, "localhost", - 12345, StorageLevel.MEMORY_AND_DISK_SER_2(), false); - } -} diff --git a/external/flume/src/test/resources/log4j.properties b/external/flume/src/test/resources/log4j.properties deleted file mode 100644 index fd51f8faf56b9..0000000000000 --- a/external/flume/src/test/resources/log4j.properties +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file -log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file.append=true -log4j.appender.file.file=target/unit-tests.log -log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n - -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark_project.jetty=WARN - diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala b/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala deleted file mode 100644 index c97a27ca7c7aa..0000000000000 --- a/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming - -import java.io.{IOException, ObjectInputStream} -import java.util.concurrent.ConcurrentLinkedQueue - -import scala.reflect.ClassTag - -import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.dstream.{DStream, ForEachDStream} -import org.apache.spark.util.Utils - -/** - * This is a output stream just for the testsuites. All the output is collected into a - * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. - * - * The buffer contains a sequence of RDD's, each containing a sequence of items - */ -class TestOutputStream[T: ClassTag](parent: DStream[T], - val output: ConcurrentLinkedQueue[Seq[T]] = new ConcurrentLinkedQueue[Seq[T]]()) - extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { - val collected = rdd.collect() - output.add(collected) - }, false) { - - // This is to clear the output buffer every it is read from a checkpoint - @throws(classOf[IOException]) - private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { - ois.defaultReadObject() - output.clear() - } -} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala deleted file mode 100644 index 9241b13c100f1..0000000000000 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.flume - -import java.net.InetSocketAddress -import java.util.concurrent.ConcurrentLinkedQueue - -import scala.collection.JavaConverters._ -import scala.concurrent.duration._ -import scala.language.postfixOps - -import org.scalatest.BeforeAndAfterAll -import org.scalatest.concurrent.Eventually._ - -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Seconds, StreamingContext, TestOutputStream} -import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.util.{ManualClock, Utils} - -class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { - - val maxAttempts = 5 - val batchDuration = Seconds(1) - - @transient private var _sc: SparkContext = _ - - val conf = new SparkConf() - .setMaster("local[2]") - .setAppName(this.getClass.getSimpleName) - .set("spark.streaming.clock", "org.apache.spark.util.ManualClock") - - val utils = new PollingFlumeTestUtils - - override def beforeAll(): Unit = { - super.beforeAll() - _sc = new SparkContext(conf) - } - - override def afterAll(): Unit = { - try { - if (_sc != null) { - _sc.stop() - _sc = null - } - } finally { - super.afterAll() - } - } - - test("flume polling test") { - testMultipleTimes(() => testFlumePolling()) - } - - test("flume polling test multiple hosts") { - testMultipleTimes(() => testFlumePollingMultipleHost()) - } - - /** - * Run the given test until no more java.net.BindException's are thrown. - * Do this only up to a certain attempt limit. - */ - private def testMultipleTimes(test: () => Unit): Unit = { - var testPassed = false - var attempt = 0 - while (!testPassed && attempt < maxAttempts) { - try { - test() - testPassed = true - } catch { - case e: Exception if Utils.isBindCollision(e) => - logWarning("Exception when running flume polling test: " + e) - attempt += 1 - } - } - assert(testPassed, s"Test failed after $attempt attempts!") - } - - private def testFlumePolling(): Unit = { - try { - val port = utils.startSingleSink() - - writeAndVerify(Seq(port)) - utils.assertChannelsAreEmpty() - } finally { - utils.close() - } - } - - private def testFlumePollingMultipleHost(): Unit = { - try { - val ports = utils.startMultipleSinks() - writeAndVerify(ports) - utils.assertChannelsAreEmpty() - } finally { - utils.close() - } - } - - def writeAndVerify(sinkPorts: Seq[Int]): Unit = { - // Set up the streaming context and input streams - val ssc = new StreamingContext(_sc, batchDuration) - val addresses = sinkPorts.map(port => new InetSocketAddress("localhost", port)) - val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = - FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, - utils.eventsPerBatch, 5) - val outputQueue = new ConcurrentLinkedQueue[Seq[SparkFlumeEvent]] - val outputStream = new TestOutputStream(flumeStream, outputQueue) - outputStream.register() - - ssc.start() - try { - utils.sendDataAndEnsureAllDataHasBeenReceived() - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - clock.advance(batchDuration.milliseconds) - - // The eventually is required to ensure that all data in the batch has been processed. - eventually(timeout(10 seconds), interval(100 milliseconds)) { - val flattenOutput = outputQueue.asScala.toSeq.flatten - val headers = flattenOutput.map(_.event.getHeaders.asScala.map { - case (key, value) => (key.toString, value.toString) - }).map(_.asJava) - val bodies = flattenOutput.map(e => JavaUtils.bytesToString(e.event.getBody)) - utils.assertOutput(headers.asJava, bodies.asJava) - } - } finally { - // here stop ssc only, but not underlying sparkcontext - ssc.stop(false) - } - } - -} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala deleted file mode 100644 index 7bac1cc4b0ae7..0000000000000 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.flume - -import java.util.concurrent.ConcurrentLinkedQueue - -import scala.collection.JavaConverters._ -import scala.concurrent.duration._ -import scala.language.postfixOps - -import org.jboss.netty.channel.ChannelPipeline -import org.jboss.netty.channel.socket.SocketChannel -import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory -import org.jboss.netty.handler.codec.compression._ -import org.scalatest.{BeforeAndAfter, Matchers} -import org.scalatest.concurrent.Eventually._ - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} - -class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { - val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite") - var ssc: StreamingContext = null - - test("flume input stream") { - testFlumeStream(testCompression = false) - } - - test("flume input compressed stream") { - testFlumeStream(testCompression = true) - } - - /** Run test on flume stream */ - private def testFlumeStream(testCompression: Boolean): Unit = { - val input = (1 to 100).map { _.toString } - val utils = new FlumeTestUtils - try { - val outputQueue = startContext(utils.getTestPort(), testCompression) - - eventually(timeout(10 seconds), interval(100 milliseconds)) { - utils.writeInput(input.asJava, testCompression) - } - - eventually(timeout(10 seconds), interval(100 milliseconds)) { - val outputEvents = outputQueue.asScala.toSeq.flatten.map { _.event } - outputEvents.foreach { - event => - event.getHeaders.get("test") should be("header") - } - val output = outputEvents.map(event => JavaUtils.bytesToString(event.getBody)) - output should be (input) - } - } finally { - if (ssc != null) { - ssc.stop() - } - utils.close() - } - } - - /** Setup and start the streaming context */ - private def startContext( - testPort: Int, testCompression: Boolean): (ConcurrentLinkedQueue[Seq[SparkFlumeEvent]]) = { - ssc = new StreamingContext(conf, Milliseconds(200)) - val flumeStream = FlumeUtils.createStream( - ssc, "localhost", testPort, StorageLevel.MEMORY_AND_DISK, testCompression) - val outputQueue = new ConcurrentLinkedQueue[Seq[SparkFlumeEvent]] - val outputStream = new TestOutputStream(flumeStream, outputQueue) - outputStream.register() - ssc.start() - outputQueue - } - - /** Class to create socket channel with compression */ - private class CompressionChannelFactory(compressionLevel: Int) - extends NioClientSocketChannelFactory { - - override def newChannel(pipeline: ChannelPipeline): SocketChannel = { - val encoder = new ZlibEncoder(compressionLevel) - pipeline.addFirst("deflater", encoder) - pipeline.addFirst("inflater", new ZlibDecoder()) - super.newChannel(pipeline) - } - } -} diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 0acc9b8d2a0cf..ba4009ef08856 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -108,7 +108,6 @@ private[spark] class DirectKafkaInputDStream[K, V]( } } - // Keep this consistent with how other streams are named (e.g. "Flume polling stream [2]") private[streaming] override def name: String = s"Kafka 0.10 direct stream [$id]" protected[streaming] override val checkpointData = diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 9297c39d170c4..2ec771e977147 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -67,7 +67,6 @@ class DirectKafkaInputDStream[ val maxRetries = context.sparkContext.getConf.getInt( "spark.streaming.kafka.maxRetries", 1) - // Keep this consistent with how other streams are named (e.g. "Flume polling stream [2]") private[streaming] override def name: String = s"Kafka direct stream [$id]" protected[streaming] override val checkpointData = diff --git a/pom.xml b/pom.xml index 59ba317aa8c96..7ce7c9f0280e5 100644 --- a/pom.xml +++ b/pom.xml @@ -121,7 +121,6 @@ 2.7.3 2.5.0 ${hadoop.version} - 1.6.0 3.4.6 2.7.1 org.spark-project.hive @@ -212,7 +211,6 @@ during compilation if the dependency is transivite (e.g. "graphx/" depending on "core/" and needing Hadoop classes in the classpath to compile). --> - compile compile compile compile @@ -1805,46 +1803,6 @@ ${hive.parquet.version} compile - - org.apache.flume - flume-ng-core - ${flume.version} - ${flume.deps.scope} - - - io.netty - netty - - - org.apache.flume - flume-ng-auth - - - org.apache.thrift - libthrift - - - org.mortbay.jetty - servlet-api - - - - - org.apache.flume - flume-ng-sdk - ${flume.version} - ${flume.deps.scope} - - - io.netty - netty - - - org.apache.thrift - libthrift - - - org.apache.calcite calcite-core @@ -2635,15 +2593,6 @@ - - flume - - external/flume - external/flume-sink - external/flume-assembly - - - spark-ganglia-lgpl @@ -2835,9 +2784,6 @@ maven does not complain when they're provided on the command line for a sub-module that does not have them. --> - - flume-provided - hadoop-provided diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index a5ed9088eaa4d..8b01b9079e6d7 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -55,16 +55,14 @@ object BuildCommons { ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects val optionallyEnabledProjects@Seq(kubernetes, mesos, yarn, - streamingFlumeSink, streamingFlume, streamingKafka, sparkGangliaLgpl, streamingKinesisAsl, dockerIntegrationTests, hadoopCloud, kubernetesIntegrationTests) = Seq("kubernetes", "mesos", "yarn", - "streaming-flume-sink", "streaming-flume", "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", "docker-integration-tests", "hadoop-cloud", "kubernetes-integration-tests").map(ProjectRef(buildLocation, _)) - val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKafka010Assembly, streamingKinesisAslAssembly) = - Seq("network-yarn", "streaming-flume-assembly", "streaming-kafka-0-8-assembly", "streaming-kafka-0-10-assembly", "streaming-kinesis-asl-assembly") + val assemblyProjects@Seq(networkYarn, streamingKafkaAssembly, streamingKafka010Assembly, streamingKinesisAslAssembly) = + Seq("network-yarn", "streaming-kafka-0-8-assembly", "streaming-kafka-0-10-assembly", "streaming-kinesis-asl-assembly") .map(ProjectRef(buildLocation, _)) val copyJarsProjects@Seq(assembly, examples) = Seq("assembly", "examples") @@ -373,8 +371,6 @@ object SparkBuild extends PomBuild { /* Hive console settings */ enable(Hive.settings)(hive) - enable(Flume.settings)(streamingFlumeSink) - // SPARK-14738 - Remove docker tests from main Spark build // enable(DockerIntegrationTests.settings)(dockerIntegrationTests) @@ -452,9 +448,6 @@ object Unsafe { ) } -object Flume { - lazy val settings = sbtavro.SbtAvro.avroSettings -} object DockerIntegrationTests { // This serves to override the override specified in DependencyOverrides: @@ -587,8 +580,7 @@ object Assembly { .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String]) }, jarName in assembly := { - if (moduleName.value.contains("streaming-flume-assembly") - || moduleName.value.contains("streaming-kafka-0-8-assembly") + if (moduleName.value.contains("streaming-kafka-0-8-assembly") || moduleName.value.contains("streaming-kafka-0-10-assembly") || moduleName.value.contains("streaming-kinesis-asl-assembly")) { // This must match the same name used in maven (see external/kafka-0-8-assembly/pom.xml) @@ -694,10 +686,10 @@ object Unidoc { publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, kubernetes, yarn, tags, streamingKafka010, sqlKafka010, avro), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, kubernetes, yarn, tags, streamingKafka010, sqlKafka010, avro), unidocAllClasspaths in (ScalaUnidoc, unidoc) := { diff --git a/python/docs/pyspark.streaming.rst b/python/docs/pyspark.streaming.rst index 25ceabac0a541..9c256284ad698 100644 --- a/python/docs/pyspark.streaming.rst +++ b/python/docs/pyspark.streaming.rst @@ -22,10 +22,3 @@ pyspark.streaming.kinesis module :members: :undoc-members: :show-inheritance: - -pyspark.streaming.flume.module ------------------------------- -.. automodule:: pyspark.streaming.flume - :members: - :undoc-members: - :show-inheritance: diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index ce42a857d0c06..946601e779d2f 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -45,7 +45,7 @@ class DStream(object): for more details on RDDs). DStreams can either be created from live data (such as, data from TCP - sockets, Kafka, Flume, etc.) using a L{StreamingContext} or it can be + sockets, Kafka, etc.) using a L{StreamingContext} or it can be generated by transforming existing DStreams using operations such as `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming program is running, each DStream periodically generates a RDD, either diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py deleted file mode 100644 index 5de448114ece8..0000000000000 --- a/python/pyspark/streaming/flume.py +++ /dev/null @@ -1,156 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import sys -if sys.version >= "3": - from io import BytesIO -else: - from StringIO import StringIO -import warnings - -from py4j.protocol import Py4JJavaError - -from pyspark.storagelevel import StorageLevel -from pyspark.serializers import PairDeserializer, NoOpSerializer, UTF8Deserializer, read_int -from pyspark.streaming import DStream - -__all__ = ['FlumeUtils', 'utf8_decoder'] - - -def utf8_decoder(s): - """ Decode the unicode as UTF-8 """ - if s is None: - return None - return s.decode('utf-8') - - -class FlumeUtils(object): - - @staticmethod - def createStream(ssc, hostname, port, - storageLevel=StorageLevel.MEMORY_AND_DISK_2, - enableDecompression=False, - bodyDecoder=utf8_decoder): - """ - Create an input stream that pulls events from Flume. - - :param ssc: StreamingContext object - :param hostname: Hostname of the slave machine to which the flume data will be sent - :param port: Port of the slave machine to which the flume data will be sent - :param storageLevel: Storage level to use for storing the received objects - :param enableDecompression: Should netty server decompress input stream - :param bodyDecoder: A function used to decode body (default is utf8_decoder) - :return: A DStream object - - .. note:: Deprecated in 2.3.0. Flume support is deprecated as of Spark 2.3.0. - See SPARK-22142. - """ - warnings.warn( - "Deprecated in 2.3.0. Flume support is deprecated as of Spark 2.3.0. " - "See SPARK-22142.", - DeprecationWarning) - jlevel = ssc._sc._getJavaStorageLevel(storageLevel) - helper = FlumeUtils._get_helper(ssc._sc) - jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression) - return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder) - - @staticmethod - def createPollingStream(ssc, addresses, - storageLevel=StorageLevel.MEMORY_AND_DISK_2, - maxBatchSize=1000, - parallelism=5, - bodyDecoder=utf8_decoder): - """ - Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. - This stream will poll the sink for data and will pull events as they are available. - - :param ssc: StreamingContext object - :param addresses: List of (host, port)s on which the Spark Sink is running. - :param storageLevel: Storage level to use for storing the received objects - :param maxBatchSize: The maximum number of events to be pulled from the Spark sink - in a single RPC call - :param parallelism: Number of concurrent requests this stream should send to the sink. - Note that having a higher number of requests concurrently being pulled - will result in this stream using more threads - :param bodyDecoder: A function used to decode body (default is utf8_decoder) - :return: A DStream object - - .. note:: Deprecated in 2.3.0. Flume support is deprecated as of Spark 2.3.0. - See SPARK-22142. - """ - warnings.warn( - "Deprecated in 2.3.0. Flume support is deprecated as of Spark 2.3.0. " - "See SPARK-22142.", - DeprecationWarning) - jlevel = ssc._sc._getJavaStorageLevel(storageLevel) - hosts = [] - ports = [] - for (host, port) in addresses: - hosts.append(host) - ports.append(port) - helper = FlumeUtils._get_helper(ssc._sc) - jstream = helper.createPollingStream( - ssc._jssc, hosts, ports, jlevel, maxBatchSize, parallelism) - return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder) - - @staticmethod - def _toPythonDStream(ssc, jstream, bodyDecoder): - ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) - stream = DStream(jstream, ssc, ser) - - def func(event): - headersBytes = BytesIO(event[0]) if sys.version >= "3" else StringIO(event[0]) - headers = {} - strSer = UTF8Deserializer() - for i in range(0, read_int(headersBytes)): - key = strSer.loads(headersBytes) - value = strSer.loads(headersBytes) - headers[key] = value - body = bodyDecoder(event[1]) - return (headers, body) - return stream.map(func) - - @staticmethod - def _get_helper(sc): - try: - return sc._jvm.org.apache.spark.streaming.flume.FlumeUtilsPythonHelper() - except TypeError as e: - if str(e) == "'JavaPackage' object is not callable": - FlumeUtils._printErrorMsg(sc) - raise - - @staticmethod - def _printErrorMsg(sc): - print(""" -________________________________________________________________________________________________ - - Spark Streaming's Flume libraries not found in class path. Try one of the following. - - 1. Include the Flume library and its dependencies with in the - spark-submit command as - - $ bin/spark-submit --packages org.apache.spark:spark-streaming-flume:%s ... - - 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, - Group Id = org.apache.spark, Artifact Id = spark-streaming-flume-assembly, Version = %s. - Then, include the jar in the spark-submit command as - - $ bin/spark-submit --jars ... - -________________________________________________________________________________________________ - -""" % (sc.version, sc.version)) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 5cef621a28e6e..4b995c04c07d0 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -48,7 +48,6 @@ from pyspark.storagelevel import StorageLevel from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition -from pyspark.streaming.flume import FlumeUtils from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream from pyspark.streaming.listener import StreamingListener @@ -1301,148 +1300,6 @@ def getKeyAndDoubleMessage(m): self._validateStreamResult({"aa": 1, "bb": 2, "cc": 3}, stream) -class FlumeStreamTests(PySparkStreamingTestCase): - timeout = 20 # seconds - duration = 1 - - def setUp(self): - super(FlumeStreamTests, self).setUp() - self._utils = self.ssc._jvm.org.apache.spark.streaming.flume.FlumeTestUtils() - - def tearDown(self): - if self._utils is not None: - self._utils.close() - self._utils = None - - super(FlumeStreamTests, self).tearDown() - - def _startContext(self, n, compressed): - # Start the StreamingContext and also collect the result - dstream = FlumeUtils.createStream(self.ssc, "localhost", self._utils.getTestPort(), - enableDecompression=compressed) - result = [] - - def get_output(_, rdd): - for event in rdd.collect(): - if len(result) < n: - result.append(event) - dstream.foreachRDD(get_output) - self.ssc.start() - return result - - def _validateResult(self, input, result): - # Validate both the header and the body - header = {"test": "header"} - self.assertEqual(len(input), len(result)) - for i in range(0, len(input)): - self.assertEqual(header, result[i][0]) - self.assertEqual(input[i], result[i][1]) - - def _writeInput(self, input, compressed): - # Try to write input to the receiver until success or timeout - start_time = time.time() - while True: - try: - self._utils.writeInput(input, compressed) - break - except: - if time.time() - start_time < self.timeout: - time.sleep(0.01) - else: - raise - - def test_flume_stream(self): - input = [str(i) for i in range(1, 101)] - result = self._startContext(len(input), False) - self._writeInput(input, False) - self.wait_for(result, len(input)) - self._validateResult(input, result) - - def test_compressed_flume_stream(self): - input = [str(i) for i in range(1, 101)] - result = self._startContext(len(input), True) - self._writeInput(input, True) - self.wait_for(result, len(input)) - self._validateResult(input, result) - - -class FlumePollingStreamTests(PySparkStreamingTestCase): - timeout = 20 # seconds - duration = 1 - maxAttempts = 5 - - def setUp(self): - self._utils = self.sc._jvm.org.apache.spark.streaming.flume.PollingFlumeTestUtils() - - def tearDown(self): - if self._utils is not None: - self._utils.close() - self._utils = None - - def _writeAndVerify(self, ports): - # Set up the streaming context and input streams - ssc = StreamingContext(self.sc, self.duration) - try: - addresses = [("localhost", port) for port in ports] - dstream = FlumeUtils.createPollingStream( - ssc, - addresses, - maxBatchSize=self._utils.eventsPerBatch(), - parallelism=5) - outputBuffer = [] - - def get_output(_, rdd): - for e in rdd.collect(): - outputBuffer.append(e) - - dstream.foreachRDD(get_output) - ssc.start() - self._utils.sendDataAndEnsureAllDataHasBeenReceived() - - self.wait_for(outputBuffer, self._utils.getTotalEvents()) - outputHeaders = [event[0] for event in outputBuffer] - outputBodies = [event[1] for event in outputBuffer] - self._utils.assertOutput(outputHeaders, outputBodies) - finally: - ssc.stop(False) - - def _testMultipleTimes(self, f): - attempt = 0 - while True: - try: - f() - break - except: - attempt += 1 - if attempt >= self.maxAttempts: - raise - else: - import traceback - traceback.print_exc() - - def _testFlumePolling(self): - try: - port = self._utils.startSingleSink() - self._writeAndVerify([port]) - self._utils.assertChannelsAreEmpty() - finally: - self._utils.close() - - def _testFlumePollingMultipleHosts(self): - try: - port = self._utils.startSingleSink() - self._writeAndVerify([port]) - self._utils.assertChannelsAreEmpty() - finally: - self._utils.close() - - def test_flume_polling(self): - self._testMultipleTimes(self._testFlumePolling) - - def test_flume_polling_multiple_hosts(self): - self._testMultipleTimes(self._testFlumePollingMultipleHosts) - - class KinesisStreamTests(PySparkStreamingTestCase): def test_kinesis_stream_api(self): @@ -1531,23 +1388,6 @@ def search_kafka_assembly_jar(): return jars[0] -def search_flume_assembly_jar(): - SPARK_HOME = os.environ["SPARK_HOME"] - flume_assembly_dir = os.path.join(SPARK_HOME, "external/flume-assembly") - jars = search_jar(flume_assembly_dir, "spark-streaming-flume-assembly") - if not jars: - raise Exception( - ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) + - "You need to build Spark with " - "'build/sbt -Pflume assembly/package streaming-flume-assembly/assembly' or " - "'build/mvn -DskipTests -Pflume package' before running this test.") - elif len(jars) > 1: - raise Exception(("Found multiple Spark Streaming Flume assembly JARs: %s; please " - "remove all but one") % (", ".join(jars))) - else: - return jars[0] - - def _kinesis_asl_assembly_dir(): SPARK_HOME = os.environ["SPARK_HOME"] return os.path.join(SPARK_HOME, "external/kinesis-asl-assembly") @@ -1564,9 +1404,6 @@ def search_kinesis_asl_assembly_jar(): return jars[0] -# Must be same as the variable and condition defined in modules.py -flume_test_environ_var = "ENABLE_FLUME_TESTS" -are_flume_tests_enabled = os.environ.get(flume_test_environ_var) == '1' # Must be same as the variable and condition defined in modules.py kafka_test_environ_var = "ENABLE_KAFKA_0_8_TESTS" are_kafka_tests_enabled = os.environ.get(kafka_test_environ_var) == '1' @@ -1577,15 +1414,14 @@ def search_kinesis_asl_assembly_jar(): if __name__ == "__main__": from pyspark.streaming.tests import * kafka_assembly_jar = search_kafka_assembly_jar() - flume_assembly_jar = search_flume_assembly_jar() kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar() if kinesis_asl_assembly_jar is None: kinesis_jar_present = False - jars = "%s,%s" % (kafka_assembly_jar, flume_assembly_jar) + jars = kafka_assembly_jar else: kinesis_jar_present = True - jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar) + jars = "%s,%s" % (kafka_assembly_jar, kinesis_asl_assembly_jar) existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") jars_args = "--jars %s" % jars @@ -1593,14 +1429,6 @@ def search_kinesis_asl_assembly_jar(): testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, StreamingListenerTests] - if are_flume_tests_enabled: - testcases.append(FlumeStreamTests) - testcases.append(FlumePollingStreamTests) - else: - sys.stderr.write( - "Skipped test_flume_stream (enable by setting environment variable %s=1" - % flume_test_environ_var) - if are_kafka_tests_enabled: testcases.append(KafkaStreamTests) else: diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 027403816f538..122f25b21a0d1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -537,7 +537,7 @@ class StreamingContext private[streaming] ( ExecutorAllocationManager.isDynamicAllocationEnabled(conf)) { logWarning("Dynamic Allocation is enabled for this application. " + "Enabling Dynamic allocation for Spark Streaming applications can cause data loss if " + - "Write Ahead Log is not enabled for non-replayable sources like Flume. " + + "Write Ahead Log is not enabled for non-replayable sources. " + "See the programming guide for details on how to enable the Write Ahead Log.") } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala index a59f4efccb575..99396865f7d28 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala @@ -30,7 +30,7 @@ import org.apache.spark.streaming.dstream.DStream /** * A Java-friendly interface to [[org.apache.spark.streaming.dstream.DStream]], the basic * abstraction in Spark Streaming that represents a continuous stream of data. - * DStreams can either be created from live data (such as, data from TCP sockets, Kafka, Flume, + * DStreams can either be created from live data (such as, data from TCP sockets, Kafka, * etc.) or it can be generated by transforming existing DStreams using operations such as `map`, * `window`. For operations applicable to key-value pair DStreams, see * [[org.apache.spark.streaming.api.java.JavaPairDStream]]. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 4a4d2c5d9d8c8..35243373daf9d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -40,7 +40,7 @@ import org.apache.spark.util.{CallSite, Utils} * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous * sequence of RDDs (of the same type) representing a continuous stream of data (see * org.apache.spark.rdd.RDD in the Spark core documentation for more details on RDDs). - * DStreams can either be created from live data (such as, data from TCP sockets, Kafka, Flume, + * DStreams can either be created from live data (such as, data from TCP sockets, Kafka, * etc.) using a [[org.apache.spark.streaming.StreamingContext]] or it can be generated by * transforming existing DStreams using operations such as `map`, * `window` and `reduceByKeyAndWindow`. While a Spark Streaming program is running, each DStream diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index 931f015f03b6f..6495c91247047 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -56,7 +56,6 @@ abstract class InputDStream[T: ClassTag](_ssc: StreamingContext) /** A human-readable name of this InputDStream */ private[streaming] def name: String = { - // e.g. FlumePollingDStream -> "Flume polling stream" val newName = Utils.getFormattedClassName(this) .replaceAll("InputDStream", "Stream") .split("(?=[A-Z])") From 39872af882e3d73667acfab93c9de962c9c8939d Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 12 Oct 2018 09:16:41 +0800 Subject: [PATCH 1819/2461] [SPARK-25684][SQL] Organize header related codes in CSV datasource ## What changes were proposed in this pull request? 1. Move `CSVDataSource.makeSafeHeader` to `CSVUtils.makeSafeHeader` (as is). - Historically and at the first place of refactoring (which I did), I intended to put all CSV specific handling (like options), filtering, extracting header, etc. - See `JsonDataSource`. Now `CSVDataSource` is quite consistent with `JsonDataSource`. Since CSV's code path is quite complicated, we might better match them as possible as we can. 2. Create `CSVHeaderChecker` and put `enforceSchema` logics into that. - The checking header and column pruning stuff were added (per https://github.com/apache/spark/pull/20894 and https://github.com/apache/spark/pull/21296) but some of codes such as https://github.com/apache/spark/pull/22123 are duplicated - Also, checking header code is basically here and there. We better put them in a single place, which was quite error-prone. See (https://github.com/apache/spark/pull/22656). 3. Move `CSVDataSource.checkHeaderColumnNames` to `CSVHeaderChecker.checkHeaderColumnNames` (as is). - Similar reasons above with 1. ## How was this patch tested? Existing tests should cover this. Closes #22676 from HyukjinKwon/refactoring-csv. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- .../apache/spark/sql/DataFrameReader.scala | 18 +- .../datasources/csv/CSVDataSource.scala | 161 ++---------------- .../datasources/csv/CSVFileFormat.scala | 11 +- .../datasources/csv/CSVHeaderChecker.scala | 131 ++++++++++++++ .../execution/datasources/csv/CSVUtils.scala | 44 ++++- .../datasources/csv/UnivocityParser.scala | 34 ++-- 6 files changed, 217 insertions(+), 182 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVHeaderChecker.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 72694463cedb5..3af70b5153c83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -505,20 +505,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val actualSchema = StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) - val linesWithoutHeader = if (parsedOptions.headerFlag && maybeFirstLine.isDefined) { - val firstLine = maybeFirstLine.get - val parser = new CsvParser(parsedOptions.asParserSettings) - val columnNames = parser.parseLine(firstLine) - CSVDataSource.checkHeaderColumnNames( + val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine => + val headerChecker = new CSVHeaderChecker( actualSchema, - columnNames, - csvDataset.getClass.getCanonicalName, - parsedOptions.enforceSchema, - sparkSession.sessionState.conf.caseSensitiveAnalysis) + parsedOptions, + source = s"CSV source: $csvDataset") + headerChecker.checkHeaderColumnNames(firstLine) filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions)) - } else { - filteredLines.rdd - } + }.getOrElse(filteredLines.rdd) val parsed = linesWithoutHeader.mapPartitions { iter => val rawParser = new UnivocityParser(actualSchema, parsedOptions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index b93f418bcb5be..0b5a719d427c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -51,11 +51,8 @@ abstract class CSVDataSource extends Serializable { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - requiredSchema: StructType, - // Actual schema of data in the csv file - dataSchema: StructType, - caseSensitive: Boolean, - columnPruning: Boolean): Iterator[InternalRow] + headerChecker: CSVHeaderChecker, + requiredSchema: StructType): Iterator[InternalRow] /** * Infers the schema from `inputPaths` files. @@ -75,48 +72,6 @@ abstract class CSVDataSource extends Serializable { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: CSVOptions): StructType - - /** - * Generates a header from the given row which is null-safe and duplicate-safe. - */ - protected def makeSafeHeader( - row: Array[String], - caseSensitive: Boolean, - options: CSVOptions): Array[String] = { - if (options.headerFlag) { - val duplicates = { - val headerNames = row.filter(_ != null) - // scalastyle:off caselocale - .map(name => if (caseSensitive) name else name.toLowerCase) - // scalastyle:on caselocale - headerNames.diff(headerNames.distinct).distinct - } - - row.zipWithIndex.map { case (value, index) => - if (value == null || value.isEmpty || value == options.nullValue) { - // When there are empty strings or the values set in `nullValue`, put the - // index as the suffix. - s"_c$index" - // scalastyle:off caselocale - } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) { - // scalastyle:on caselocale - // When there are case-insensitive duplicates, put the index as the suffix. - s"$value$index" - } else if (duplicates.contains(value)) { - // When there are duplicates, put the index as the suffix. - s"$value$index" - } else { - value - } - } - } else { - row.zipWithIndex.map { case (_, index) => - // Uses default column names, "_c#" where # is its position of fields - // when header option is disabled. - s"_c$index" - } - } - } } object CSVDataSource extends Logging { @@ -127,67 +82,6 @@ object CSVDataSource extends Logging { TextInputCSVDataSource } } - - /** - * Checks that column names in a CSV header and field names in the schema are the same - * by taking into account case sensitivity. - * - * @param schema - provided (or inferred) schema to which CSV must conform. - * @param columnNames - names of CSV columns that must be checked against to the schema. - * @param fileName - name of CSV file that are currently checked. It is used in error messages. - * @param enforceSchema - if it is `true`, column names are ignored otherwise the CSV column - * names are checked for conformance to the schema. In the case if - * the column name don't conform to the schema, an exception is thrown. - * @param caseSensitive - if it is set to `false`, comparison of column names and schema field - * names is not case sensitive. - */ - def checkHeaderColumnNames( - schema: StructType, - columnNames: Array[String], - fileName: String, - enforceSchema: Boolean, - caseSensitive: Boolean): Unit = { - if (columnNames != null) { - val fieldNames = schema.map(_.name).toIndexedSeq - val (headerLen, schemaSize) = (columnNames.size, fieldNames.length) - var errorMessage: Option[String] = None - - if (headerLen == schemaSize) { - var i = 0 - while (errorMessage.isEmpty && i < headerLen) { - var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i)) - if (!caseSensitive) { - // scalastyle:off caselocale - nameInSchema = nameInSchema.toLowerCase - nameInHeader = nameInHeader.toLowerCase - // scalastyle:on caselocale - } - if (nameInHeader != nameInSchema) { - errorMessage = Some( - s"""|CSV header does not conform to the schema. - | Header: ${columnNames.mkString(", ")} - | Schema: ${fieldNames.mkString(", ")} - |Expected: ${fieldNames(i)} but found: ${columnNames(i)} - |CSV file: $fileName""".stripMargin) - } - i += 1 - } - } else { - errorMessage = Some( - s"""|Number of column in CSV header is not equal to number of fields in the schema: - | Header length: $headerLen, schema size: $schemaSize - |CSV file: $fileName""".stripMargin) - } - - errorMessage.foreach { msg => - if (enforceSchema) { - logWarning(msg) - } else { - throw new IllegalArgumentException(msg) - } - } - } - } } object TextInputCSVDataSource extends CSVDataSource { @@ -197,10 +91,8 @@ object TextInputCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - requiredSchema: StructType, - dataSchema: StructType, - caseSensitive: Boolean, - columnPruning: Boolean): Iterator[InternalRow] = { + headerChecker: CSVHeaderChecker, + requiredSchema: StructType): Iterator[InternalRow] = { val lines = { val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close())) @@ -209,25 +101,7 @@ object TextInputCSVDataSource extends CSVDataSource { } } - val hasHeader = parser.options.headerFlag && file.start == 0 - if (hasHeader) { - // Checking that column names in the header are matched to field names of the schema. - // The header will be removed from lines. - // Note: if there are only comments in the first block, the header would probably - // be not extracted. - CSVUtils.extractHeader(lines, parser.options).foreach { header => - val schema = if (columnPruning) requiredSchema else dataSchema - val columnNames = parser.tokenizer.parseLine(header) - CSVDataSource.checkHeaderColumnNames( - schema, - columnNames, - file.filePath, - parser.options.enforceSchema, - caseSensitive) - } - } - - UnivocityParser.parseIterator(lines, parser, requiredSchema) + UnivocityParser.parseIterator(lines, parser, headerChecker, requiredSchema) } override def infer( @@ -251,7 +125,7 @@ object TextInputCSVDataSource extends CSVDataSource { maybeFirstLine.map(csvParser.parseLine(_)) match { case Some(firstRow) if firstRow != null => val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val header = CSVUtils.makeSafeHeader(firstRow, caseSensitive, parsedOptions) val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions) val tokenRDD = sampled.rdd.mapPartitions { iter => val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) @@ -298,26 +172,13 @@ object MultiLineCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - requiredSchema: StructType, - dataSchema: StructType, - caseSensitive: Boolean, - columnPruning: Boolean): Iterator[InternalRow] = { - def checkHeader(header: Array[String]): Unit = { - val schema = if (columnPruning) requiredSchema else dataSchema - CSVDataSource.checkHeaderColumnNames( - schema, - header, - file.filePath, - parser.options.enforceSchema, - caseSensitive) - } - + headerChecker: CSVHeaderChecker, + requiredSchema: StructType): Iterator[InternalRow] = { UnivocityParser.parseStream( CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))), - parser.options.headerFlag, parser, - requiredSchema, - checkHeader) + headerChecker, + requiredSchema) } override def infer( @@ -334,7 +195,7 @@ object MultiLineCSVDataSource extends CSVDataSource { }.take(1).headOption match { case Some(firstRow) => val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val header = CSVUtils.makeSafeHeader(firstRow, caseSensitive, parsedOptions) val tokenRDD = csv.flatMap { lines => UnivocityParser.tokenizeStream( CodecStreams.createInputStreamWithCloseResource( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 9aad0bd55e736..3de1c2d955d20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -130,7 +130,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { "df.filter($\"_corrupt_record\".isNotNull).count()." ) } - val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis val columnPruning = sparkSession.sessionState.conf.csvColumnPruning (file: PartitionedFile) => { @@ -139,14 +138,16 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), parsedOptions) + val schema = if (columnPruning) requiredSchema else dataSchema + val isStartOfFile = file.start == 0 + val headerChecker = new CSVHeaderChecker( + schema, parsedOptions, source = s"CSV file: ${file.filePath}", isStartOfFile) CSVDataSource(parsedOptions).readFile( conf, file, parser, - requiredSchema, - dataSchema, - caseSensitive, - columnPruning) + headerChecker, + requiredSchema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVHeaderChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVHeaderChecker.scala new file mode 100644 index 0000000000000..558ee91c419b9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVHeaderChecker.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import com.univocity.parsers.csv.CsvParser + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType + +/** + * Checks that column names in a CSV header and field names in the schema are the same + * by taking into account case sensitivity. + * + * @param schema provided (or inferred) schema to which CSV must conform. + * @param options parsed CSV options. + * @param source name of CSV source that are currently checked. It is used in error messages. + * @param isStartOfFile indicates if the currently processing partition is the start of the file. + * if unknown or not applicable (for instance when the input is a dataset), + * can be omitted. + */ +class CSVHeaderChecker( + schema: StructType, + options: CSVOptions, + source: String, + isStartOfFile: Boolean = false) extends Logging { + + // Indicates if it is set to `false`, comparison of column names and schema field + // names is not case sensitive. + private val caseSensitive = SQLConf.get.caseSensitiveAnalysis + + // Indicates if it is `true`, column names are ignored otherwise the CSV column + // names are checked for conformance to the schema. In the case if + // the column name don't conform to the schema, an exception is thrown. + private val enforceSchema = options.enforceSchema + + /** + * Checks that column names in a CSV header and field names in the schema are the same + * by taking into account case sensitivity. + * + * @param columnNames names of CSV columns that must be checked against to the schema. + */ + private def checkHeaderColumnNames(columnNames: Array[String]): Unit = { + if (columnNames != null) { + val fieldNames = schema.map(_.name).toIndexedSeq + val (headerLen, schemaSize) = (columnNames.size, fieldNames.length) + var errorMessage: Option[String] = None + + if (headerLen == schemaSize) { + var i = 0 + while (errorMessage.isEmpty && i < headerLen) { + var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i)) + if (!caseSensitive) { + // scalastyle:off caselocale + nameInSchema = nameInSchema.toLowerCase + nameInHeader = nameInHeader.toLowerCase + // scalastyle:on caselocale + } + if (nameInHeader != nameInSchema) { + errorMessage = Some( + s"""|CSV header does not conform to the schema. + | Header: ${columnNames.mkString(", ")} + | Schema: ${fieldNames.mkString(", ")} + |Expected: ${fieldNames(i)} but found: ${columnNames(i)} + |$source""".stripMargin) + } + i += 1 + } + } else { + errorMessage = Some( + s"""|Number of column in CSV header is not equal to number of fields in the schema: + | Header length: $headerLen, schema size: $schemaSize + |$source""".stripMargin) + } + + errorMessage.foreach { msg => + if (enforceSchema) { + logWarning(msg) + } else { + throw new IllegalArgumentException(msg) + } + } + } + } + + // This is currently only used to parse CSV from Dataset[String]. + def checkHeaderColumnNames(line: String): Unit = { + if (options.headerFlag) { + val parser = new CsvParser(options.asParserSettings) + checkHeaderColumnNames(parser.parseLine(line)) + } + } + + // This is currently only used to parse CSV with multiLine mode. + private[csv] def checkHeaderColumnNames(tokenizer: CsvParser): Unit = { + assert(options.multiLine, "This method should be executed with multiLine.") + if (options.headerFlag) { + val firstRecord = tokenizer.parseNext() + checkHeaderColumnNames(firstRecord) + } + } + + // This is currently only used to parse CSV with non-multiLine mode. + private[csv] def checkHeaderColumnNames(lines: Iterator[String], tokenizer: CsvParser): Unit = { + assert(!options.multiLine, "This method should not be executed with multiline.") + // Checking that column names in the header are matched to field names of the schema. + // The header will be removed from lines. + // Note: if there are only comments in the first block, the header would probably + // be not extracted. + if (options.headerFlag && isStartOfFile) { + CSVUtils.extractHeader(lines, options).foreach { header => + checkHeaderColumnNames(tokenizer.parseLine(header)) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 7ce65fa89b02d..b912f8add3afd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.datasources.csv import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types._ object CSVUtils { /** @@ -90,6 +89,49 @@ object CSVUtils { None } } + + /** + * Generates a header from the given row which is null-safe and duplicate-safe. + */ + def makeSafeHeader( + row: Array[String], + caseSensitive: Boolean, + options: CSVOptions): Array[String] = { + if (options.headerFlag) { + val duplicates = { + val headerNames = row.filter(_ != null) + // scalastyle:off caselocale + .map(name => if (caseSensitive) name else name.toLowerCase) + // scalastyle:on caselocale + headerNames.diff(headerNames.distinct).distinct + } + + row.zipWithIndex.map { case (value, index) => + if (value == null || value.isEmpty || value == options.nullValue) { + // When there are empty strings or the values set in `nullValue`, put the + // index as the suffix. + s"_c$index" + // scalastyle:off caselocale + } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) { + // scalastyle:on caselocale + // When there are case-insensitive duplicates, put the index as the suffix. + s"$value$index" + } else if (duplicates.contains(value)) { + // When there are duplicates, put the index as the suffix. + s"$value$index" + } else { + value + } + } + } else { + row.zipWithIndex.map { case (_, index) => + // Uses default column names, "_c#" where # is its position of fields + // when header option is disabled. + s"_c$index" + } + } + } + /** * Helper method that converts string representation of a character to actual character. * It handles some Java escaped strings and throws exception if given string is longer than one diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 9088d43905e28..fbd19c6e677e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -273,7 +273,10 @@ private[csv] object UnivocityParser { inputStream: InputStream, shouldDropHeader: Boolean, tokenizer: CsvParser): Iterator[Array[String]] = { - convertStream(inputStream, shouldDropHeader, tokenizer)(tokens => tokens) + val handleHeader: () => Unit = + () => if (shouldDropHeader) tokenizer.parseNext + + convertStream(inputStream, tokenizer, handleHeader)(tokens => tokens) } /** @@ -281,10 +284,9 @@ private[csv] object UnivocityParser { */ def parseStream( inputStream: InputStream, - shouldDropHeader: Boolean, parser: UnivocityParser, - schema: StructType, - checkHeader: Array[String] => Unit): Iterator[InternalRow] = { + headerChecker: CSVHeaderChecker, + schema: StructType): Iterator[InternalRow] = { val tokenizer = parser.tokenizer val safeParser = new FailureSafeParser[Array[String]]( input => Seq(parser.convert(input)), @@ -292,25 +294,26 @@ private[csv] object UnivocityParser { schema, parser.options.columnNameOfCorruptRecord, parser.options.multiLine) - convertStream(inputStream, shouldDropHeader, tokenizer, checkHeader) { tokens => + + val handleHeader: () => Unit = + () => headerChecker.checkHeaderColumnNames(tokenizer) + + convertStream(inputStream, tokenizer, handleHeader) { tokens => safeParser.parse(tokens) }.flatten } private def convertStream[T]( inputStream: InputStream, - shouldDropHeader: Boolean, tokenizer: CsvParser, - checkHeader: Array[String] => Unit = _ => ())( + handleHeader: () => Unit)( convert: Array[String] => T) = new Iterator[T] { tokenizer.beginParsing(inputStream) - private var nextRecord = { - if (shouldDropHeader) { - val firstRecord = tokenizer.parseNext() - checkHeader(firstRecord) - } - tokenizer.parseNext() - } + + // We can handle header here since here the stream is open. + handleHeader() + + private var nextRecord = tokenizer.parseNext() override def hasNext: Boolean = nextRecord != null @@ -330,7 +333,10 @@ private[csv] object UnivocityParser { def parseIterator( lines: Iterator[String], parser: UnivocityParser, + headerChecker: CSVHeaderChecker, schema: StructType): Iterator[InternalRow] = { + headerChecker.checkHeaderColumnNames(lines, parser.tokenizer) + val options = parser.options val filteredLines: Iterator[String] = CSVUtils.filterCommentAndEmpty(lines, options) From c9d7d83ed5790aa272e969af36fd0cb90231111f Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 12 Oct 2018 11:14:35 +0800 Subject: [PATCH 1820/2461] [SPARK-25388][TEST][SQL] Detect incorrect nullable of DataType in the result ## What changes were proposed in this pull request? This PR can correctly cause assertion failure when incorrect nullable of DataType in the result is generated by a target function to be tested. Let us think the following example. In the future, a developer would write incorrect code that returns unexpected result. We have to correctly cause fail in this test since `valueContainsNull=false` while `expr` includes `null`. However, without this PR, this test passes. This PR can correctly cause fail. ``` test("test TARGETFUNCTON") { val expr = TARGETMAPFUNCTON() // expr = UnsafeMap(3 -> 6, 7 -> null) // expr.dataType = (IntegerType, IntegerType, false) expected = Map(3 -> 6, 7 -> null) checkEvaluation(expr, expected) ``` In [`checkEvaluationWithUnsafeProjection`](https://github.com/apache/spark/blob/master/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala#L208-L235), the results are compared using `UnsafeRow`. When the given `expected` is [converted](https://github.com/apache/spark/blob/master/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala#L226-L227)) to `UnsafeRow` using the `DataType` of `expr`. ``` val expectedRow = UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit) ``` In summary, `expr` is `[0,1800000038,5000000038,18,2,0,700000003,2,0,6,18,2,0,700000003,2,0,6]` with and w/o this PR. `expected` is converted to * w/o this PR, `[0,1800000038,5000000038,18,2,0,700000003,2,0,6,18,2,0,700000003,2,0,6]` * with this PR, `[0,1800000038,5000000038,18,2,0,700000003,2,2,6,18,2,0,700000003,2,2,6]` As a result, w/o this PR, the test unexpectedly passes. This is because, w/o this PR, based on given `dataType`, generated code of projection for `expected` avoids to set nullbit. ``` // tmpInput_2 is expected /* 155 */ for (int index_1 = 0; index_1 < numElements_1; index_1++) { /* 156 */ mutableStateArray_1[1].write(index_1, tmpInput_2.getInt(index_1)); /* 157 */ } ``` With this PR, generated code of projection for `expected` always checks whether nullbit should be set by `isNullAt` ``` // tmpInput_2 is expected /* 161 */ for (int index_1 = 0; index_1 < numElements_1; index_1++) { /* 162 */ /* 163 */ if (tmpInput_2.isNullAt(index_1)) { /* 164 */ mutableStateArray_1[1].setNull4Bytes(index_1); /* 165 */ } else { /* 166 */ mutableStateArray_1[1].write(index_1, tmpInput_2.getInt(index_1)); /* 167 */ } /* 168 */ /* 169 */ } ``` ## How was this patch tested? Existing UTs Closes #22375 from kiszk/SPARK-25388. Authored-by: Kazuaki Ishizaki Signed-off-by: Wenchen Fan --- .../expressions/CodeGenerationSuite.scala | 14 +++--- .../expressions/ExpressionEvalHelper.scala | 46 +++++++++++++------ .../ExpressionEvalHelperSuite.scala | 27 ++++++++++- .../execution/ObjectHashAggregateSuite.scala | 2 +- 4 files changed, 64 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 5e8113ac8658e..7843003a4aac3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -113,7 +113,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(actual.length == 1) val expected = UTF8String.fromString("abc") - if (!checkResult(actual.head, expected, expressions.head.dataType)) { + if (!checkResult(actual.head, expected, expressions.head)) { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") } } @@ -126,7 +126,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(actual.length == 1) val expected = UnsafeArrayData.fromPrimitiveArray(Array.fill(length)(true)) - if (!checkResult(actual.head, expected, expressions.head.dataType)) { + if (!checkResult(actual.head, expected, expressions.head)) { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") } } @@ -142,7 +142,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(actual.length == 1) val expected = ArrayBasedMapData((0 until length).toArray, Array.fill(length)(true)) - if (!checkResult(actual.head, expected, expressions.head.dataType)) { + if (!checkResult(actual.head, expected, expressions.head)) { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") } } @@ -154,7 +154,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq(InternalRow(Seq.fill(length)(true): _*)) - if (!checkResult(actual, expected, expressions.head.dataType)) { + if (!checkResult(actual, expected, expressions.head)) { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") } } @@ -170,7 +170,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(actual.length == 1) val expected = InternalRow(Seq.fill(length)(true): _*) - if (!checkResult(actual.head, expected, expressions.head.dataType)) { + if (!checkResult(actual.head, expected, expressions.head)) { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") } } @@ -375,7 +375,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(actualOr.length == 1) val expectedOr = false - if (!checkResult(actualOr.head, expectedOr, exprOr.dataType)) { + if (!checkResult(actualOr.head, expectedOr, exprOr)) { fail(s"Incorrect Evaluation: expressions: $exprOr, actual: $actualOr, expected: $expectedOr") } @@ -389,7 +389,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(actualAnd.length == 1) val expectedAnd = false - if (!checkResult(actualAnd.head, expectedAnd, exprAnd.dataType)) { + if (!checkResult(actualAnd.head, expectedAnd, exprAnd)) { fail( s"Incorrect Evaluation: expressions: $exprAnd, actual: $actualAnd, expected: $expectedAnd") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index b5986aac65552..da18475276a13 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -69,11 +69,22 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa /** * Check the equality between result of expression and expected value, it will handle - * Array[Byte], Spread[Double], MapData and Row. + * Array[Byte], Spread[Double], MapData and Row. Also check whether nullable in expression is + * true if result is null */ - protected def checkResult(result: Any, expected: Any, exprDataType: DataType): Boolean = { + protected def checkResult(result: Any, expected: Any, expression: Expression): Boolean = { + checkResult(result, expected, expression.dataType, expression.nullable) + } + + protected def checkResult( + result: Any, + expected: Any, + exprDataType: DataType, + exprNullable: Boolean): Boolean = { val dataType = UserDefinedType.sqlType(exprDataType) + // The result is null for a non-nullable expression + assert(result != null || exprNullable, "exprNullable should be true if result is null") (result, expected) match { case (result: Array[Byte], expected: Array[Byte]) => java.util.Arrays.equals(result, expected) @@ -83,24 +94,24 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa val st = dataType.asInstanceOf[StructType] assert(result.numFields == st.length && expected.numFields == st.length) st.zipWithIndex.forall { case (f, i) => - checkResult(result.get(i, f.dataType), expected.get(i, f.dataType), f.dataType) + checkResult( + result.get(i, f.dataType), expected.get(i, f.dataType), f.dataType, f.nullable) } case (result: ArrayData, expected: ArrayData) => result.numElements == expected.numElements && { - val et = dataType.asInstanceOf[ArrayType].elementType + val ArrayType(et, cn) = dataType.asInstanceOf[ArrayType] var isSame = true var i = 0 while (isSame && i < result.numElements) { - isSame = checkResult(result.get(i, et), expected.get(i, et), et) + isSame = checkResult(result.get(i, et), expected.get(i, et), et, cn) i += 1 } isSame } case (result: MapData, expected: MapData) => - val kt = dataType.asInstanceOf[MapType].keyType - val vt = dataType.asInstanceOf[MapType].valueType - checkResult(result.keyArray, expected.keyArray, ArrayType(kt)) && - checkResult(result.valueArray, expected.valueArray, ArrayType(vt)) + val MapType(kt, vt, vcn) = dataType.asInstanceOf[MapType] + checkResult(result.keyArray, expected.keyArray, ArrayType(kt, false), false) && + checkResult(result.valueArray, expected.valueArray, ArrayType(vt, vcn), false) case (result: Double, expected: Double) => if (expected.isNaN) result.isNaN else expected == result case (result: Float, expected: Float) => @@ -175,7 +186,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa val actual = try evaluateWithoutCodegen(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } - if (!checkResult(actual, expected, expression.dataType)) { + if (!checkResult(actual, expected, expression)) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect evaluation (codegen off): $expression, " + s"actual: $actual, " + @@ -191,7 +202,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa for (fallbackMode <- modes) { withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) { val actual = evaluateWithMutableProjection(expression, inputRow) - if (!checkResult(actual, expected, expression.dataType)) { + if (!checkResult(actual, expected, expression)) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect evaluation (fallback mode = $fallbackMode): $expression, " + s"actual: $actual, expected: $expected$input") @@ -221,6 +232,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow) val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + val dataType = expression.dataType + if (!checkResult(unsafeRow.get(0, dataType), expected, dataType, expression.nullable)) { + fail("Incorrect evaluation in unsafe mode (fallback mode = $fallbackMode): " + + s"$expression, actual: $unsafeRow, expected: $expected, " + + s"dataType: $dataType, nullable: ${expression.nullable}") + } if (expected == null) { if (!unsafeRow.isNullAt(0)) { val expectedRow = InternalRow(expected, expected) @@ -229,8 +246,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa } } else { val lit = InternalRow(expected, expected) - val expectedRow = - UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit) + val expectedRow = UnsafeProjection.create(Array(dataType, dataType)).apply(lit) if (unsafeRow != expectedRow) { fail(s"Incorrect evaluation in unsafe mode (fallback mode = $fallbackMode): " + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") @@ -280,7 +296,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa expression) plan.initialize(0) var actual = plan(inputRow).get(0, expression.dataType) - assert(checkResult(actual, expected, expression.dataType)) + assert(checkResult(actual, expected, expression)) plan = generateProject( GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), @@ -288,7 +304,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa plan.initialize(0) actual = FromUnsafeProjection(expression.dataType :: Nil)( plan(inputRow)).get(0, expression.dataType) - assert(checkResult(actual, expected, expression.dataType)) + assert(checkResult(actual, expected, expression)) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala index 7c7c4cccee253..54ef9641bee0d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.types.{DataType, IntegerType, MapType} /** * A test suite for testing [[ExpressionEvalHelper]]. @@ -35,6 +36,13 @@ class ExpressionEvalHelperSuite extends SparkFunSuite with ExpressionEvalHelper val e = intercept[RuntimeException] { checkEvaluation(BadCodegenExpression(), 10) } assert(e.getMessage.contains("some_variable")) } + + test("SPARK-25388: checkEvaluation should fail if nullable in DataType is incorrect") { + val e = intercept[RuntimeException] { + checkEvaluation(MapIncorrectDataTypeExpression(), Map(3 -> 7, 6 -> null)) + } + assert(e.getMessage.contains("and exprNullable was")) + } } /** @@ -53,3 +61,18 @@ case class BadCodegenExpression() extends LeafExpression { } override def dataType: DataType = IntegerType } + +/** + * An expression that returns a MapData with incorrect DataType whose valueContainsNull is false + * while its value includes null + */ +case class MapIncorrectDataTypeExpression() extends LeafExpression with CodegenFallback { + override def nullable: Boolean = false + override def eval(input: InternalRow): Any = { + val keys = new GenericArrayData(Array(3, 6)) + val values = new GenericArrayData(Array(7, null)) + new ArrayBasedMapData(keys, values) + } + // since values includes null, valueContainsNull must be true + override def dataType: DataType = MapType(IntegerType, IntegerType, valueContainsNull = false) +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala index 0ef630bbd3670..c9309197791bd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala @@ -416,7 +416,7 @@ class ObjectHashAggregateSuite actual.zip(expected).foreach { case (lhs: Row, rhs: Row) => assert(lhs.length == rhs.length) lhs.toSeq.zip(rhs.toSeq).foreach { - case (a: Double, b: Double) => checkResult(a, b +- tolerance, DoubleType) + case (a: Double, b: Double) => checkResult(a, b +- tolerance, DoubleType, false) case (a, b) => a == b } } From 368513048198efcee8c9a35678b608be0cb9ad48 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Thu, 11 Oct 2018 20:45:08 -0700 Subject: [PATCH 1821/2461] [SPARK-25690][SQL] Analyzer rule HandleNullInputsForUDF does not stabilize and can be applied infinitely ## What changes were proposed in this pull request? The HandleNullInputsForUDF rule can generate new If node infinitely, thus causing problems like match of SQL cache missed. This was fixed in SPARK-24891 and was then broken by SPARK-25044. The unit test in `AnalysisSuite` added in SPARK-24891 should have failed but didn't because it wasn't properly updated after the `ScalaUDF` constructor signature change. So this PR also updates the test accordingly based on the new `ScalaUDF` constructor. ## How was this patch tested? Updated the original UT. This should be justified as the original UT became invalid after SPARK-25044. Closes #22701 from maryannxue/spark-25690. Authored-by: maryannxue Signed-off-by: gatorsmile --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 4 +++- .../apache/spark/sql/catalyst/analysis/AnalysisSuite.scala | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d72e512e0df56..7f641ace46298 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2150,8 +2150,10 @@ class Analyzer( // TODO: skip null handling for not-nullable primitive inputs after we can completely // trust the `nullable` information. + val needsNullCheck = (nullable: Boolean, expr: Expression) => + nullable && !expr.isInstanceOf[KnownNotNull] val inputsNullCheck = nullableTypes.zip(inputs) - .filter { case (nullable, _) => !nullable } + .filter { case (nullableType, expr) => needsNullCheck(!nullableType, expr) } .map { case (_, expr) => IsNull(expr) } .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2)) // Once we add an `If` check above the udf, it is safe to mark those checked inputs diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index f9facbb71a4e4..cf76c92b093b7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -351,8 +351,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { test("SPARK-24891 Fix HandleNullInputsForUDF rule") { val a = testRelation.output(0) val func = (x: Int, y: Int) => x + y - val udf1 = ScalaUDF(func, IntegerType, a :: a :: Nil) - val udf2 = ScalaUDF(func, IntegerType, a :: udf1 :: Nil) + val udf1 = ScalaUDF(func, IntegerType, a :: a :: Nil, nullableTypes = false :: false :: Nil) + val udf2 = ScalaUDF(func, IntegerType, a :: udf1 :: Nil, nullableTypes = false :: false :: Nil) val plan = Project(Alias(udf2, "")() :: Nil, testRelation) comparePlans(plan.analyze, plan.analyze.analyze) } From 78e133141ce8131c60181f947346802864b0951a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 12 Oct 2018 00:24:06 -0700 Subject: [PATCH 1822/2461] [SPARK-25708][SQL] HAVING without GROUP BY means global aggregate ## What changes were proposed in this pull request? According to the SQL standard, when a query contains `HAVING`, it indicates an aggregate operator. For more details please refer to https://blog.jooq.org/2014/12/04/do-you-really-understand-sqls-group-by-and-having-clauses/ However, in Spark SQL parser, we treat HAVING as a normal filter when there is no GROUP BY, which breaks SQL semantic and lead to wrong result. This PR fixes the parser. ## How was this patch tested? new test Closes #22696 from cloud-fan/having. Authored-by: Wenchen Fan Signed-off-by: gatorsmile --- docs/sql-programming-guide.md | 1 + .../sql/catalyst/parser/AstBuilder.scala | 41 +++++++++++++------ .../apache/spark/sql/internal/SQLConf.scala | 8 ++++ .../sql/catalyst/parser/PlanParserSuite.scala | 2 +- .../resources/sql-tests/inputs/group-by.sql | 7 ++++ .../sql-tests/results/group-by.sql.out | 27 +++++++++++- .../sql/hive/execution/HiveQuerySuite.scala | 4 -- 7 files changed, 71 insertions(+), 19 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 0d2935769ae51..fb03ed2e292b3 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1977,6 +1977,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, Metadata files (e.g. Parquet summary files) and temporary files are not counted as data files when calculating table size during Statistics computation. - Since Spark 2.4, empty strings are saved as quoted empty strings `""`. In version 2.3 and earlier, empty strings are equal to `null` values and do not reflect to any characters in saved CSV files. For example, the row of `"a", null, "", 1` was writted as `a,,,1`. Since Spark 2.4, the same row is saved as `a,,"",1`. To restore the previous behavior, set the CSV option `emptyValue` to empty (not quoted) string. - Since Spark 2.4, The LOAD DATA command supports wildcard `?` and `*`, which match any one character, and zero or more characters, respectively. Example: `LOAD DATA INPATH '/tmp/folder*/'` or `LOAD DATA INPATH '/tmp/part-?'`. Special Characters like `space` also now work in paths. Example: `LOAD DATA INPATH '/tmp/folder name/'`. + - In Spark version 2.3 and earlier, HAVING without GROUP BY is treated as WHERE. This means, `SELECT 1 FROM range(10) HAVING true` is executed as `SELECT 1 FROM range(10) WHERE true` and returns 10 rows. This violates SQL standard, and has been fixed in Spark 2.4. Since Spark 2.4, HAVING without GROUP BY is treated as a global aggregate, which means `SELECT 1 FROM range(10) HAVING true` will return only one row. To restore the previous behavior, set `spark.sql.legacy.parser.havingWithoutGroupByAsWhere` to `true`. ## Upgrading From Spark SQL 2.3.0 to 2.3.1 and above diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index ba0b72e747fc9..672bffcfc0cad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -394,6 +394,17 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging Filter(expression(ctx), plan) } + def withHaving(ctx: BooleanExpressionContext, plan: LogicalPlan): LogicalPlan = { + // Note that we add a cast to non-predicate expressions. If the expression itself is + // already boolean, the optimizer will get rid of the unnecessary cast. + val predicate = expression(ctx) match { + case p: Predicate => p + case e => Cast(e, BooleanType) + } + Filter(predicate, plan) + } + + // Expressions. val expressions = Option(namedExpressionSeq).toSeq .flatMap(_.namedExpression.asScala) @@ -446,30 +457,34 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case e: NamedExpression => e case e: Expression => UnresolvedAlias(e) } - val withProject = if (aggregation != null) { - withAggregation(aggregation, namedExpressions, withFilter) - } else if (namedExpressions.nonEmpty) { + + def createProject() = if (namedExpressions.nonEmpty) { Project(namedExpressions, withFilter) } else { withFilter } - // Having - val withHaving = withProject.optional(having) { - // Note that we add a cast to non-predicate expressions. If the expression itself is - // already boolean, the optimizer will get rid of the unnecessary cast. - val predicate = expression(having) match { - case p: Predicate => p - case e => Cast(e, BooleanType) + val withProject = if (aggregation == null && having != null) { + if (conf.getConf(SQLConf.LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE)) { + // If the legacy conf is set, treat HAVING without GROUP BY as WHERE. + withHaving(having, createProject()) + } else { + // According to SQL standard, HAVING without GROUP BY means global aggregate. + withHaving(having, Aggregate(Nil, namedExpressions, withFilter)) } - Filter(predicate, withProject) + } else if (aggregation != null) { + val aggregate = withAggregation(aggregation, namedExpressions, withFilter) + aggregate.optionalMap(having)(withHaving) + } else { + // When hitting this branch, `having` must be null. + createProject() } // Distinct val withDistinct = if (setQuantifier() != null && setQuantifier().DISTINCT() != null) { - Distinct(withHaving) + Distinct(withProject) } else { - withHaving + withProject } // Window diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b699707d85235..da70d7da7351b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1567,6 +1567,14 @@ object SQLConf { .internal() .booleanConf .createWithDefault(false) + + val LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE = + buildConf("spark.sql.legacy.parser.havingWithoutGroupByAsWhere") + .internal() + .doc("If it is set to true, the parser will treat HAVING without GROUP BY as a normal " + + "WHERE, which does not follow SQL standard.") + .booleanConf + .createWithDefault(false) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 422bf97e30e7e..f5da90f7cf0c6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -108,7 +108,7 @@ class PlanParserSuite extends AnalysisTest { assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b)) assertEqual( "select a, b from db.c having x < 1", - table("db", "c").select('a, 'b).where('x < 1)) + table("db", "c").groupBy()('a, 'b).where('x < 1)) assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b))) assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b)) assertEqual("select from tbl", OneRowRelation().select('from.as("tbl"))) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 2c18d6aaabdba..433db71527437 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -73,3 +73,10 @@ where b.z != b.z; -- SPARK-24369 multiple distinct aggregations having the same argument set SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y); + +-- SPARK-25708 HAVING without GROUP BY means global aggregate +SELECT 1 FROM range(10) HAVING true; + +SELECT 1 FROM range(10) HAVING MAX(id) > 0; + +SELECT id FROM range(10) HAVING id > 0; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 581aa1754ce14..f9d1ee8a6bcdb 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 27 +-- Number of queries: 30 -- !query 0 @@ -250,3 +250,28 @@ SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) struct -- !query 26 output 1.0 1.0 3 + + +-- !query 27 +SELECT 1 FROM range(10) HAVING true +-- !query 27 schema +struct<1:int> +-- !query 27 output +1 + + +-- !query 28 +SELECT 1 FROM range(10) HAVING MAX(id) > 0 +-- !query 28 schema +struct<1:int> +-- !query 28 output +1 + + +-- !query 29 +SELECT id FROM range(10) HAVING id > 0 +-- !query 29 schema +struct<> +-- !query 29 output +org.apache.spark.sql.AnalysisException +grouping expressions sequence is empty, and '`id`' is not an aggregate function. Wrap '()' in windowing function(s) or wrap '`id`' in first() (or first_value) if you don't care which value you get.; diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index b9c32e789a410..a5cff35abf37e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -740,10 +740,6 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd sql("select key, count(*) c from src group by key having c").collect() } - test("SPARK-2225: turn HAVING without GROUP BY into a simple filter") { - assert(sql("select key from src having key > 490").collect().size < 100) - } - test("union/except/intersect") { assertResult(Array(Row(1), Row(1))) { sql("select 1 as a union all select 1 as a").collect() From 3494b122814ed991b40dc4b80703c0ef55493d36 Mon Sep 17 00:00:00 2001 From: Shahid Date: Fri, 12 Oct 2018 12:36:35 -0500 Subject: [PATCH 1823/2461] [SPARK-25566][SPARK-25567][WEBUI][SQL] Support pagination for SQL tab to avoid OOM ## What changes were proposed in this pull request? Currently SQL tab in the WEBUI doesn't support pagination. Because of that following issues are happening. 1) For large number of executions, SQL page is throwing OOM exception (around 40,000) 2) For large number of executions, loading SQL page is taking time. 3) Difficult to analyse the execution table for large number of execution. [Note: spark.sql.ui.retainedExecutions = 50000] All the tabs, Jobs, Stages etc. supports pagination. So, to make it consistent with other tabs SQL tab also should support pagination. I have followed the similar flow of the pagination code in the Jobs and Stages page for SQL page. Also, this patch doesn't make any behavior change for the SQL tab except the pagination support. ## How was this patch tested? bin/spark-shell --conf spark.sql.ui.retainedExecutions=50000 Run 50,000 sql queries. **Before this PR** ![screenshot from 2018-10-05 23-48-27](https://user-images.githubusercontent.com/23054875/46552750-4ed82480-c8f9-11e8-8b05-d60bedddd1b8.png) ![screenshot from 2018-10-05 22-58-11](https://user-images.githubusercontent.com/23054875/46550276-33b5e680-c8f2-11e8-9e32-9ae9c5b181e0.png) **After this PR** Loading of the page is faster, and OOM issue doesn't happen. ![screenshot from 2018-10-05 23-50-32](https://user-images.githubusercontent.com/23054875/46552814-8050f000-c8f9-11e8-96e9-42502d2cfaea.png) Closes #22645 from shahidki31/SPARK-25566. Authored-by: Shahid Signed-off-by: Sean Owen --- .../org/apache/spark/ui/PagedTable.scala | 4 +- .../sql/execution/ui/AllExecutionsPage.scala | 404 ++++++++++++++---- 2 files changed, 312 insertions(+), 96 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala index 65fa38387b9ee..2fc0259c39d02 100644 --- a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.Utils * * @param pageSize the number of rows in a page */ -private[ui] abstract class PagedDataSource[T](val pageSize: Int) { +private[spark] abstract class PagedDataSource[T](val pageSize: Int) { if (pageSize <= 0) { throw new IllegalArgumentException("Page size must be positive") @@ -72,7 +72,7 @@ private[ui] case class PageData[T](totalPage: Int, data: Seq[T]) /** * A paged table that will generate a HTML table for a specified page and also the page navigation. */ -private[ui] trait PagedTable[T] { +private[spark] trait PagedTable[T] { def tableId: String diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index 1b2d8a821b364..1a25cd2a49e36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -17,16 +17,17 @@ package org.apache.spark.sql.execution.ui +import java.net.URLEncoder import javax.servlet.http.HttpServletRequest +import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.xml.{Node, NodeSeq} - -import org.apache.commons.lang3.StringEscapeUtils +import scala.xml.{Node, NodeSeq, Unparsed} import org.apache.spark.JobExecutionStatus import org.apache.spark.internal.Logging -import org.apache.spark.ui.{UIUtils, WebUIPage} +import org.apache.spark.ui.{PagedDataSource, PagedTable, UIUtils, WebUIPage} +import org.apache.spark.util.Utils private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with Logging { @@ -55,8 +56,8 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L val _content = mutable.ListBuffer[Node]() if (running.nonEmpty) { - val runningPageTable = new RunningExecutionTable( - parent, currentTime, running.sortBy(_.submissionTime).reverse).toNodeSeq(request) + val runningPageTable = + executionsTable(request, "running", running, currentTime, true, true, true) _content ++= - Running Queries: + Running Queries: {running.size} } @@ -129,7 +130,7 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L { if (completed.nonEmpty) {
    • - Completed Queries: + Completed Queries: {completed.size}
    • } @@ -137,50 +138,232 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L { if (failed.nonEmpty) {
    • - Failed Queries: + Failed Queries: {failed.size}
    • } } + UIUtils.headerSparkPage(request, "SQL", summary ++ content, parent, Some(5000)) } + + private def executionsTable( + request: HttpServletRequest, + executionTag: String, + executionData: Seq[SQLExecutionUIData], + currentTime: Long, + showRunningJobs: Boolean, + showSucceededJobs: Boolean, + showFailedJobs: Boolean): Seq[Node] = { + + // stripXSS is called to remove suspicious characters used in XSS attacks + val allParameters = request.getParameterMap.asScala.toMap.map { case (k, v) => + UIUtils.stripXSS(k) -> v.map(UIUtils.stripXSS).toSeq + } + val parameterOtherTable = allParameters.filterNot(_._1.startsWith(executionTag)) + .map(para => para._1 + "=" + para._2(0)) + + val parameterExecutionPage = UIUtils.stripXSS(request.getParameter(s"$executionTag.page")) + val parameterExecutionSortColumn = UIUtils.stripXSS(request + .getParameter(s"$executionTag.sort")) + val parameterExecutionSortDesc = UIUtils.stripXSS(request.getParameter(s"$executionTag.desc")) + val parameterExecutionPageSize = UIUtils.stripXSS(request + .getParameter(s"$executionTag.pageSize")) + val parameterExecutionPrevPageSize = UIUtils.stripXSS(request + .getParameter(s"$executionTag.prevPageSize")) + + val executionPage = Option(parameterExecutionPage).map(_.toInt).getOrElse(1) + val executionSortColumn = Option(parameterExecutionSortColumn).map { sortColumn => + UIUtils.decodeURLParameter(sortColumn) + }.getOrElse("ID") + val executionSortDesc = Option(parameterExecutionSortDesc).map(_.toBoolean).getOrElse( + // New executions should be shown above old executions by default. + executionSortColumn == "ID" + ) + val executionPageSize = Option(parameterExecutionPageSize).map(_.toInt).getOrElse(100) + val executionPrevPageSize = Option(parameterExecutionPrevPageSize).map(_.toInt) + .getOrElse(executionPageSize) + + // If the user has changed to a larger page size, then go to page 1 in order to avoid + // IndexOutOfBoundsException. + val page: Int = if (executionPageSize <= executionPrevPageSize) { + executionPage + } else { + 1 + } + val tableHeaderId = executionTag // "running", "completed" or "failed" + + try { + new ExecutionPagedTable( + request, + parent, + executionData, + tableHeaderId, + executionTag, + UIUtils.prependBaseUri(request, parent.basePath), + "SQL", // subPath + parameterOtherTable, + currentTime, + pageSize = executionPageSize, + sortColumn = executionSortColumn, + desc = executionSortDesc, + showRunningJobs, + showSucceededJobs, + showFailedJobs).table(page) + } catch { + case e@(_: IllegalArgumentException | _: IndexOutOfBoundsException) => +
      +

      Error while rendering execution table:

      +
      +            {Utils.exceptionString(e)}
      +          
      +
      + } + } } -private[ui] abstract class ExecutionTable( +private[ui] class ExecutionPagedTable( + request: HttpServletRequest, parent: SQLTab, - tableId: String, + data: Seq[SQLExecutionUIData], + tableHeaderId: String, + executionTag: String, + basePath: String, + subPath: String, + parameterOtherTable: Iterable[String], currentTime: Long, - executionUIDatas: Seq[SQLExecutionUIData], + pageSize: Int, + sortColumn: String, + desc: Boolean, showRunningJobs: Boolean, showSucceededJobs: Boolean, - showFailedJobs: Boolean) { + showFailedJobs: Boolean) extends PagedTable[ExecutionTableRowData] { - protected def baseHeader: Seq[String] = Seq( - "ID", - "Description", - "Submitted", - "Duration") + override val dataSource = new ExecutionDataSource( + request, + parent, + data, + basePath, + currentTime, + pageSize, + sortColumn, + desc, + showRunningJobs, + showSucceededJobs, + showFailedJobs) + + private val parameterPath = s"$basePath/$subPath/?${parameterOtherTable.mkString("&")}" + + override def tableId: String = s"$executionTag-table" + + override def tableCssClass: String = + "table table-bordered table-condensed table-striped " + + "table-head-clickable table-cell-width-limited" + + override def prevPageSizeFormField: String = s"$executionTag.prevPageSize" + + override def pageLink(page: Int): String = { + val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") + parameterPath + + s"&$pageNumberFormField=$page" + + s"&$executionTag.sort=$encodedSortColumn" + + s"&$executionTag.desc=$desc" + + s"&$pageSizeFormField=$pageSize" + + s"#$tableHeaderId" + } - protected def header: Seq[String] + override def pageSizeFormField: String = s"$executionTag.pageSize" - protected def row( - request: HttpServletRequest, - currentTime: Long, - executionUIData: SQLExecutionUIData): Seq[Node] = { - val submissionTime = executionUIData.submissionTime - val duration = executionUIData.completionTime.map(_.getTime()).getOrElse(currentTime) - - submissionTime + override def pageNumberFormField: String = s"$executionTag.page" + + override def goButtonFormPath: String = { + val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") + s"$parameterPath&$executionTag.sort=$encodedSortColumn&$executionTag.desc=$desc#$tableHeaderId" + } + + override def headers: Seq[Node] = { + // Information for each header: title, sortable + val executionHeadersAndCssClasses: Seq[(String, Boolean)] = + Seq( + ("ID", true), + ("Description", true), + ("Submitted", true), + ("Duration", true)) ++ { + if (showRunningJobs && showSucceededJobs && showFailedJobs) { + Seq( + ("Running Job IDs", true), + ("Succeeded Job IDs", true), + ("Failed Job IDs", true)) + } else if (showSucceededJobs && showFailedJobs) { + Seq( + ("Succeeded Job IDs", true), + ("Failed Job IDs", true)) + } else { + Seq(("Job IDs", true)) + } + } - def jobLinks(status: JobExecutionStatus): Seq[Node] = { - executionUIData.jobs.flatMap { case (jobId, jobStatus) => - if (jobStatus == status) { - [{jobId.toString}] + val sortableColumnHeaders = executionHeadersAndCssClasses.filter { + case (_, sortable) => sortable + }.map { case (title, _) => title } + + require(sortableColumnHeaders.contains(sortColumn), s"Unknown column: $sortColumn") + + val headerRow: Seq[Node] = { + executionHeadersAndCssClasses.map { case (header, sortable) => + if (header == sortColumn) { + val headerLink = Unparsed( + parameterPath + + s"&$executionTag.sort=${URLEncoder.encode(header, "UTF-8")}" + + s"&$executionTag.desc=${!desc}" + + s"&$executionTag.pageSize=$pageSize" + + s"#$tableHeaderId") + val arrow = if (desc) "▾" else "▴" // UP or DOWN + + + + {header} +  {Unparsed(arrow)} + + + } else { - None + if (sortable) { + val headerLink = Unparsed( + parameterPath + + s"&$executionTag.sort=${URLEncoder.encode(header, "UTF-8")}" + + s"&$executionTag.pageSize=$pageSize" + + s"#$tableHeaderId") + + + + {header} + + + } else { + + {header} + + } } - }.toSeq + } + } + + {headerRow} + + } + + override def row(executionTableRow: ExecutionTableRowData): Seq[Node] = { + val executionUIData = executionTableRow.executionUIData + val submissionTime = executionUIData.submissionTime + val duration = executionTableRow.duration + + def jobLinks(jobData: Seq[Int]): Seq[Node] = { + jobData.map { jobId => + [{jobId.toString}] + } } @@ -188,7 +371,7 @@ private[ui] abstract class ExecutionTable( {executionUIData.executionId.toString} - {descriptionCell(request, executionUIData)} + {descriptionCell(executionUIData)} {UIUtils.formatDate(submissionTime)} @@ -198,27 +381,26 @@ private[ui] abstract class ExecutionTable( {if (showRunningJobs) { - {jobLinks(JobExecutionStatus.RUNNING)} + {jobLinks(executionTableRow.runningJobData)} }} {if (showSucceededJobs) { - {jobLinks(JobExecutionStatus.SUCCEEDED)} + {jobLinks(executionTableRow.completedJobData)} }} {if (showFailedJobs) { - {jobLinks(JobExecutionStatus.FAILED)} + {jobLinks(executionTableRow.failedJobData)} }} } - private def descriptionCell( - request: HttpServletRequest, - execution: SQLExecutionUIData): Seq[Node] = { + private def descriptionCell(execution: SQLExecutionUIData): Seq[Node] = { val details = if (execution.details != null && execution.details.nonEmpty) { - + +details ++ +The extra options are also used during write operation. +For example, you can control bloom filters and dictionary encodings for ORC data sources. +The following ORC example will create bloom filter and use dictionary encoding only for `favorite_color`. +For Parquet, there exists `parquet.enable.dictionary`, too. +To find more detailed information about the extra ORC/Parquet options, +visit the official Apache ORC/Parquet websites. + +
      + +
      +{% include_example manual_save_options_orc scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %} +
      + +
      +{% include_example manual_save_options_orc java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %} +
      + +
      +{% include_example manual_save_options_orc python/sql/datasource.py %} +
      + +
      +{% include_example manual_save_options_orc r/RSparkSQLExample.R %} +
      + +
      + +{% highlight sql %} +CREATE TABLE users_with_options ( + name STRING, + favorite_color STRING, + favorite_numbers array +) USING ORC +OPTIONS ( + orc.bloom.filter.columns 'favorite_color', + orc.dictionary.key.threshold '1.0', + orc.column.encoding.direct 'name' +) +{% endhighlight %} + +
      + +
      + ### Run SQL on files directly Instead of using read API to load a file into DataFrame and query it, you can also query that diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index ef3c904775697..cbe9dfdaa907b 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -123,6 +123,13 @@ private static void runBasicDataSourceExample(SparkSession spark) { .option("header", "true") .load("examples/src/main/resources/people.csv"); // $example off:manual_load_options_csv$ + // $example on:manual_save_options_orc$ + usersDF.write().format("orc") + .option("orc.bloom.filter.columns", "favorite_color") + .option("orc.dictionary.key.threshold", "1.0") + .option("orc.column.encoding.direct", "name") + .save("users_with_options.orc"); + // $example off:manual_save_options_orc$ // $example on:direct_sql$ Dataset sqlDF = spark.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`"); diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py index d8c879dfe02ed..04660724b308d 100644 --- a/examples/src/main/python/sql/datasource.py +++ b/examples/src/main/python/sql/datasource.py @@ -57,6 +57,15 @@ def basic_datasource_example(spark): format="csv", sep=":", inferSchema="true", header="true") # $example off:manual_load_options_csv$ + # $example on:manual_save_options_orc$ + df = spark.read.orc("examples/src/main/resources/users.orc") + (df.write.format("orc") + .option("orc.bloom.filter.columns", "favorite_color") + .option("orc.dictionary.key.threshold", "1.0") + .option("orc.column.encoding.direct", "name") + .save("users_with_options.orc")) + # $example off:manual_save_options_orc$ + # $example on:write_sorting_and_bucketing$ df.write.bucketBy(42, "name").sortBy("age").saveAsTable("people_bucketed") # $example off:write_sorting_and_bucketing$ diff --git a/examples/src/main/r/RSparkSQLExample.R b/examples/src/main/r/RSparkSQLExample.R index effba948e5317..196a110f351ce 100644 --- a/examples/src/main/r/RSparkSQLExample.R +++ b/examples/src/main/r/RSparkSQLExample.R @@ -114,10 +114,14 @@ write.df(namesAndAges, "namesAndAges.parquet", "parquet") # $example on:manual_load_options_csv$ -df <- read.df("examples/src/main/resources/people.csv", "csv", sep=";", inferSchema=T, header=T) +df <- read.df("examples/src/main/resources/people.csv", "csv", sep = ";", inferSchema = TRUE, header = TRUE) namesAndAges <- select(df, "name", "age") # $example off:manual_load_options_csv$ +# $example on:manual_save_options_orc$ +df <- read.df("examples/src/main/resources/users.orc", "orc") +write.orc(df, "users_with_options.orc", orc.bloom.filter.columns = "favorite_color", orc.dictionary.key.threshold = 1.0, orc.column.encoding.direct = "name") +# $example off:manual_save_options_orc$ # $example on:direct_sql$ df <- sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`") diff --git a/examples/src/main/resources/users.orc b/examples/src/main/resources/users.orc new file mode 100644 index 0000000000000000000000000000000000000000..12478a5d03c26cb30b35af232a5764e076eaab1f GIT binary patch literal 547 zcmZ`#Jxc>Y5S`t<-He+fZZ-y&D1Mwxz(Or-4vQp$h^RSIU8P1nQP2b~QLz($LH>cQ z{tF8ce~yK{&c!BZT$uM}cG)*?rrFvo0%&DDwa5Zri!?d488{WO8PdpWktqx{m#HrO)INGvp)yr>5J3sxN1B9l z09)*Y`hL|YB_YBFybP~vd4M;e{Nxp&KdZC<;PMiYxg|pG772wj63;#7=$#qnd}2(&;MDi2oBw}Np|@jC6Rq*6F*-*nT9esXxyz3iqHb5=Cf&h^!ClJ)|Q zIq7{gBx=h%s>CV}haSWKJceUEhKjAnu?l}#tPS?J0iPHx>i*sY9Qt^a{vGU literal 0 HcmV?d00001 diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala index 7d83aacb11548..18615d9b9b908 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -56,6 +56,13 @@ object SQLDataSourceExample { .option("header", "true") .load("examples/src/main/resources/people.csv") // $example off:manual_load_options_csv$ + // $example on:manual_save_options_orc$ + usersDF.write.format("orc") + .option("orc.bloom.filter.columns", "favorite_color") + .option("orc.dictionary.key.threshold", "1.0") + .option("orc.column.encoding.direct", "name") + .save("users_with_options.orc") + // $example off:manual_save_options_orc$ // $example on:direct_sql$ val sqlDF = spark.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`") From 584e767d372d41071c3436f9ad4bf77a820f12b4 Mon Sep 17 00:00:00 2001 From: Vladimir Kuriatkov Date: Wed, 24 Oct 2018 09:29:40 +0800 Subject: [PATCH 1906/2461] [SPARK-25772][SQL] Fix java map of structs deserialization This is a follow-up PR for #22708. It considers another case of java beans deserialization: java maps with struct keys/values. When deserializing values of MapType with struct keys/values in java beans, fields of structs get mixed up. I suggest using struct data types retrieved from resolved input data instead of inferring them from java beans. ## What changes were proposed in this pull request? Invocations of "keyArray" and "valueArray" functions are used to extract arrays of keys and values. Struct type of keys or values is also inferred from java bean structure and ends up with mixed up field order. I created a new UnresolvedInvoke expression as a temporary substitution of Invoke expression while no actual data is available. It allows to provide the resulting data type during analysis based on the resolved input data, not on the java bean (similar to UnresolvedMapObjects). Key and value arrays are then fed to MapObjects expression which I replaced with UnresolvedMapObjects, just like in case of ArrayType. Finally I added resolution of UnresolvedInvoke expressions in Analyzer.resolveExpression method as an additional pattern matching case. ## How was this patch tested? Added a test case. Built complete project on travis. viirya kiszk cloud-fan michalsenkyr marmbrus liancheng Closes #22745 from vofque/SPARK-21402-FOLLOWUP. Lead-authored-by: Vladimir Kuriatkov Co-authored-by: Vladimir Kuriatkov Signed-off-by: Wenchen Fan --- .../sql/catalyst/JavaTypeInference.scala | 12 +- .../expressions/objects/objects.scala | 76 ++++++ .../sql/JavaBeanDeserializationSuite.java | 240 ++++++++++++++++++ .../spark/sql/JavaBeanWithArraySuite.java | 154 ----------- .../resources/test-data/with-map-fields.json | 5 + 5 files changed, 325 insertions(+), 162 deletions(-) create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java delete mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanWithArraySuite.java create mode 100644 sql/core/src/test/resources/test-data/with-map-fields.json diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 7a226d72f5977..60dd4a57139e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -278,24 +278,20 @@ object JavaTypeInference { case _ if mapType.isAssignableFrom(typeToken) => val (keyType, valueType) = mapKeyValueType(typeToken) - val keyDataType = inferDataType(keyType)._1 - val valueDataType = inferDataType(valueType)._1 val keyData = Invoke( - MapObjects( + UnresolvedMapObjects( p => deserializerFor(keyType, Some(p)), - Invoke(getPath, "keyArray", ArrayType(keyDataType)), - keyDataType), + GetKeyArrayFromMap(getPath)), "array", ObjectType(classOf[Array[Any]])) val valueData = Invoke( - MapObjects( + UnresolvedMapObjects( p => deserializerFor(valueType, Some(p)), - Invoke(getPath, "valueArray", ArrayType(valueDataType)), - valueDataType), + GetValueArrayFromMap(getPath)), "array", ObjectType(classOf[Array[Any]])) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 3189e6841a525..5bfa485f1569a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -30,6 +30,7 @@ import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -1787,3 +1788,78 @@ case class ValidateExternalType(child: Expression, expected: DataType) ev.copy(code = code, isNull = input.isNull) } } + +object GetKeyArrayFromMap { + + /** + * Construct an instance of GetArrayFromMap case class + * extracting a key array from a Map expression. + * + * @param child a Map expression to extract a key array from + */ + def apply(child: Expression): Expression = { + GetArrayFromMap( + child, + "keyArray", + _.keyArray(), + { case MapType(kt, _, _) => kt }) + } +} + +object GetValueArrayFromMap { + + /** + * Construct an instance of GetArrayFromMap case class + * extracting a value array from a Map expression. + * + * @param child a Map expression to extract a value array from + */ + def apply(child: Expression): Expression = { + GetArrayFromMap( + child, + "valueArray", + _.valueArray(), + { case MapType(_, vt, _) => vt }) + } +} + +/** + * Extracts a key/value array from a Map expression. + * + * @param child a Map expression to extract an array from + * @param functionName name of the function that is invoked to extract an array + * @param arrayGetter function extracting `ArrayData` from `MapData` + * @param elementTypeGetter function extracting array element `DataType` from `MapType` + */ +case class GetArrayFromMap private( + child: Expression, + functionName: String, + arrayGetter: MapData => ArrayData, + elementTypeGetter: MapType => DataType) extends UnaryExpression with NonSQLExpression { + + private lazy val encodedFunctionName: String = TermName(functionName).encodedName.toString + + lazy val dataType: DataType = { + val mt: MapType = child.dataType.asInstanceOf[MapType] + ArrayType(elementTypeGetter(mt)) + } + + override def checkInputDataTypes(): TypeCheckResult = { + if (child.dataType.isInstanceOf[MapType]) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"Can't extract array from $child: need map type but got ${child.dataType.catalogString}") + } + } + + override def nullSafeEval(input: Any): Any = { + arrayGetter(input.asInstanceOf[MapData]) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + defineCodeGen(ctx, ev, childValue => s"$childValue.$encodedFunctionName()") + } + + override def toString: String = s"$child.$functionName" +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java new file mode 100644 index 0000000000000..7f975a647c241 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql; + +import java.io.Serializable; +import java.util.*; + +import org.junit.*; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.test.TestSparkSession; + +public class JavaBeanDeserializationSuite implements Serializable { + + private TestSparkSession spark; + + @Before + public void setUp() { + spark = new TestSparkSession(); + } + + @After + public void tearDown() { + spark.stop(); + spark = null; + } + + private static final List ARRAY_RECORDS = new ArrayList<>(); + + static { + ARRAY_RECORDS.add( + new ArrayRecord(1, Arrays.asList(new Interval(111, 211), new Interval(121, 221))) + ); + ARRAY_RECORDS.add( + new ArrayRecord(2, Arrays.asList(new Interval(112, 212), new Interval(122, 222))) + ); + ARRAY_RECORDS.add( + new ArrayRecord(3, Arrays.asList(new Interval(113, 213), new Interval(123, 223))) + ); + } + + @Test + public void testBeanWithArrayFieldDeserialization() { + + Encoder encoder = Encoders.bean(ArrayRecord.class); + + Dataset dataset = spark + .read() + .format("json") + .schema("id int, intervals array>") + .load("src/test/resources/test-data/with-array-fields.json") + .as(encoder); + + List records = dataset.collectAsList(); + Assert.assertEquals(records, ARRAY_RECORDS); + } + + private static final List MAP_RECORDS = new ArrayList<>(); + + static { + MAP_RECORDS.add(new MapRecord(1, + toMap(Arrays.asList("a", "b"), Arrays.asList(new Interval(111, 211), new Interval(121, 221))) + )); + MAP_RECORDS.add(new MapRecord(2, + toMap(Arrays.asList("a", "b"), Arrays.asList(new Interval(112, 212), new Interval(122, 222))) + )); + MAP_RECORDS.add(new MapRecord(3, + toMap(Arrays.asList("a", "b"), Arrays.asList(new Interval(113, 213), new Interval(123, 223))) + )); + MAP_RECORDS.add(new MapRecord(4, new HashMap<>())); + MAP_RECORDS.add(new MapRecord(5, null)); + } + + private static Map toMap(Collection keys, Collection values) { + Map map = new HashMap<>(); + Iterator keyI = keys.iterator(); + Iterator valueI = values.iterator(); + while (keyI.hasNext() && valueI.hasNext()) { + map.put(keyI.next(), valueI.next()); + } + return map; + } + + @Test + public void testBeanWithMapFieldsDeserialization() { + + Encoder encoder = Encoders.bean(MapRecord.class); + + Dataset dataset = spark + .read() + .format("json") + .schema("id int, intervals map>") + .load("src/test/resources/test-data/with-map-fields.json") + .as(encoder); + + List records = dataset.collectAsList(); + + Assert.assertEquals(records, MAP_RECORDS); + } + + public static class ArrayRecord { + + private int id; + private List intervals; + + public ArrayRecord() { } + + ArrayRecord(int id, List intervals) { + this.id = id; + this.intervals = intervals; + } + + public int getId() { + return id; + } + + public void setId(int id) { + this.id = id; + } + + public List getIntervals() { + return intervals; + } + + public void setIntervals(List intervals) { + this.intervals = intervals; + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof ArrayRecord)) return false; + ArrayRecord other = (ArrayRecord) obj; + return (other.id == this.id) && other.intervals.equals(this.intervals); + } + + @Override + public String toString() { + return String.format("{ id: %d, intervals: %s }", id, intervals); + } + } + + public static class MapRecord { + + private int id; + private Map intervals; + + public MapRecord() { } + + MapRecord(int id, Map intervals) { + this.id = id; + this.intervals = intervals; + } + + public int getId() { + return id; + } + + public void setId(int id) { + this.id = id; + } + + public Map getIntervals() { + return intervals; + } + + public void setIntervals(Map intervals) { + this.intervals = intervals; + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof MapRecord)) return false; + MapRecord other = (MapRecord) obj; + return (other.id == this.id) && Objects.equals(other.intervals, this.intervals); + } + + @Override + public String toString() { + return String.format("{ id: %d, intervals: %s }", id, intervals); + } + } + + public static class Interval { + + private long startTime; + private long endTime; + + public Interval() { } + + Interval(long startTime, long endTime) { + this.startTime = startTime; + this.endTime = endTime; + } + + public long getStartTime() { + return startTime; + } + + public void setStartTime(long startTime) { + this.startTime = startTime; + } + + public long getEndTime() { + return endTime; + } + + public void setEndTime(long endTime) { + this.endTime = endTime; + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof Interval)) return false; + Interval other = (Interval) obj; + return (other.startTime == this.startTime) && (other.endTime == this.endTime); + } + + @Override + public String toString() { + return String.format("[%d,%d]", startTime, endTime); + } + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanWithArraySuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanWithArraySuite.java deleted file mode 100644 index 70dd11067253e..0000000000000 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanWithArraySuite.java +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package test.org.apache.spark.sql; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Encoder; -import org.apache.spark.sql.Encoders; -import org.apache.spark.sql.test.TestSparkSession; - -public class JavaBeanWithArraySuite { - - private static final List RECORDS = new ArrayList<>(); - - static { - RECORDS.add(new Record(1, Arrays.asList(new Interval(111, 211), new Interval(121, 221)))); - RECORDS.add(new Record(2, Arrays.asList(new Interval(112, 212), new Interval(122, 222)))); - RECORDS.add(new Record(3, Arrays.asList(new Interval(113, 213), new Interval(123, 223)))); - } - - private TestSparkSession spark; - - @Before - public void setUp() { - spark = new TestSparkSession(); - } - - @After - public void tearDown() { - spark.stop(); - spark = null; - } - - @Test - public void testBeanWithArrayFieldDeserialization() { - - Encoder encoder = Encoders.bean(Record.class); - - Dataset dataset = spark - .read() - .format("json") - .schema("id int, intervals array>") - .load("src/test/resources/test-data/with-array-fields.json") - .as(encoder); - - List records = dataset.collectAsList(); - Assert.assertEquals(records, RECORDS); - } - - public static class Record { - - private int id; - private List intervals; - - public Record() { } - - Record(int id, List intervals) { - this.id = id; - this.intervals = intervals; - } - - public int getId() { - return id; - } - - public void setId(int id) { - this.id = id; - } - - public List getIntervals() { - return intervals; - } - - public void setIntervals(List intervals) { - this.intervals = intervals; - } - - @Override - public boolean equals(Object obj) { - if (!(obj instanceof Record)) return false; - Record other = (Record) obj; - return (other.id == this.id) && other.intervals.equals(this.intervals); - } - - @Override - public String toString() { - return String.format("{ id: %d, intervals: %s }", id, intervals); - } - } - - public static class Interval { - - private long startTime; - private long endTime; - - public Interval() { } - - Interval(long startTime, long endTime) { - this.startTime = startTime; - this.endTime = endTime; - } - - public long getStartTime() { - return startTime; - } - - public void setStartTime(long startTime) { - this.startTime = startTime; - } - - public long getEndTime() { - return endTime; - } - - public void setEndTime(long endTime) { - this.endTime = endTime; - } - - @Override - public boolean equals(Object obj) { - if (!(obj instanceof Interval)) return false; - Interval other = (Interval) obj; - return (other.startTime == this.startTime) && (other.endTime == this.endTime); - } - - @Override - public String toString() { - return String.format("[%d,%d]", startTime, endTime); - } - } -} diff --git a/sql/core/src/test/resources/test-data/with-map-fields.json b/sql/core/src/test/resources/test-data/with-map-fields.json new file mode 100644 index 0000000000000..576fbb9b8758b --- /dev/null +++ b/sql/core/src/test/resources/test-data/with-map-fields.json @@ -0,0 +1,5 @@ +{ "id": 1, "intervals": { "a": { "startTime": 111, "endTime": 211 }, "b": { "startTime": 121, "endTime": 221 }}} +{ "id": 2, "intervals": { "a": { "startTime": 112, "endTime": 212 }, "b": { "startTime": 122, "endTime": 222 }}} +{ "id": 3, "intervals": { "a": { "startTime": 113, "endTime": 213 }, "b": { "startTime": 123, "endTime": 223 }}} +{ "id": 4, "intervals": { }} +{ "id": 5 } \ No newline at end of file From 4d6704db4d490bd1830ed3c757525f41058523e0 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 24 Oct 2018 19:09:15 +0800 Subject: [PATCH 1907/2461] [SPARK-25243][SQL] Use FailureSafeParser in from_json ## What changes were proposed in this pull request? In the PR, I propose to switch `from_json` on `FailureSafeParser`, and to make the function compatible to `PERMISSIVE` mode by default, and to support the `FAILFAST` mode as well. The `DROPMALFORMED` mode is not supported by `from_json`. ## How was this patch tested? It was tested by existing `JsonSuite`/`CSVSuite`, `JsonFunctionsSuite` and `JsonExpressionsSuite` as well as new tests for `from_json` which checks different modes. Closes #22237 from MaxGekk/from_json-failuresafe. Lead-authored-by: Maxim Gekk Co-authored-by: hyukjinkwon Signed-off-by: Wenchen Fan --- R/pkg/tests/fulltests/test_sparkSQL.R | 2 +- docs/sql-migration-guide-upgrade.md | 2 + python/pyspark/sql/functions.py | 2 +- .../expressions/jsonExpressions.scala | 64 ++++++++----------- .../sql/catalyst/json/JacksonParser.scala | 7 +- .../expressions/JsonExpressionsSuite.scala | 38 ++++++----- .../apache/spark/sql/DataFrameReader.scala | 2 +- .../datasources/json/JsonFileFormat.scala | 2 +- .../native/stringCastAndExpressions.sql.out | 2 +- .../apache/spark/sql/JsonFunctionsSuite.scala | 33 +++++++++- .../datasources/json/JsonSuite.scala | 2 +- 11 files changed, 95 insertions(+), 61 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 5ad5d78d3ed17..509f689ac521e 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1694,7 +1694,7 @@ test_that("column functions", { df <- as.DataFrame(list(list("col" = "{\"date\":\"21/10/2014\"}"))) schema2 <- structType(structField("date", "date")) s <- collect(select(df, from_json(df$col, schema2))) - expect_equal(s[[1]][[1]], NA) + expect_equal(s[[1]][[1]]$date, NA) s <- collect(select(df, from_json(df$col, schema2, dateFormat = "dd/MM/yyyy"))) expect_is(s[[1]][[1]]$date, "Date") expect_equal(as.character(s[[1]][[1]]$date), "2014-10-21") diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index b8b9ad8438554..dfa35b88369cb 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -13,6 +13,8 @@ displayTitle: Spark SQL Upgrading Guide - In Spark version 2.4 and earlier, the parser of JSON data source treats empty strings as null for some data types such as `IntegerType`. For `FloatType` and `DoubleType`, it fails on empty strings and throws exceptions. Since Spark 3.0, we disallow empty strings and will throw exceptions for data types except for `StringType` and `BinaryType`. + - Since Spark 3.0, the `from_json` functions supports two modes - `PERMISSIVE` and `FAILFAST`. The modes can be set via the `mode` option. The default mode became `PERMISSIVE`. In previous versions, behavior of `from_json` did not conform to either `PERMISSIVE` nor `FAILFAST`, especially in processing of malformed JSON records. For example, the JSON string `{"a" 1}` with the schema `a INT` is converted to `null` by previous versions but Spark 3.0 converts it to `Row(null)`. + ## Upgrading From Spark SQL 2.3 to 2.4 - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below. diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 32d7f02f61883..2694e777d8266 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2305,7 +2305,7 @@ def from_json(col, schema, options={}): [Row(json=[Row(a=1)])] >>> schema = schema_of_json(lit('''{"a": 0}''')) >>> df.select(from_json(df.value, schema).alias("json")).collect() - [Row(json=Row(a=1))] + [Row(json=Row(a=None))] >>> data = [(1, '''[1, 2, 3]''')] >>> schema = ArrayType(IntegerType()) >>> df = spark.createDataFrame(data, ("key", "value")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index b4815b47d1797..e966924293cf7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -554,18 +554,36 @@ case class JsonToStructs( @transient lazy val converter = nullableSchema match { case _: StructType => - (rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null + (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next() else null case _: ArrayType => - (rows: Seq[InternalRow]) => rows.head.getArray(0) + (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getArray(0) else null case _: MapType => - (rows: Seq[InternalRow]) => rows.head.getMap(0) + (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getMap(0) else null } - @transient - lazy val parser = - new JacksonParser( - nullableSchema, - new JSONOptions(options + ("mode" -> FailFastMode.name), timeZoneId.get)) + val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD) + @transient lazy val parser = { + val parsedOptions = new JSONOptions(options, timeZoneId.get, nameOfCorruptRecord) + val mode = parsedOptions.parseMode + if (mode != PermissiveMode && mode != FailFastMode) { + throw new IllegalArgumentException(s"from_json() doesn't support the ${mode.name} mode. " + + s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}.") + } + val rawParser = new JacksonParser(nullableSchema, parsedOptions, allowArrayAsStructs = false) + val createParser = CreateJacksonParser.utf8String _ + + val parserSchema = nullableSchema match { + case s: StructType => s + case other => StructType(StructField("value", other) :: Nil) + } + + new FailureSafeParser[UTF8String]( + input => rawParser.parse(input, createParser, identity[UTF8String]), + mode, + parserSchema, + parsedOptions.columnNameOfCorruptRecord, + parsedOptions.multiLine) + } override def dataType: DataType = nullableSchema @@ -573,35 +591,7 @@ case class JsonToStructs( copy(timeZoneId = Option(timeZoneId)) override def nullSafeEval(json: Any): Any = { - // When input is, - // - `null`: `null`. - // - invalid json: `null`. - // - empty string: `null`. - // - // When the schema is array, - // - json array: `Array(Row(...), ...)` - // - json object: `Array(Row(...))` - // - empty json array: `Array()`. - // - empty json object: `Array(Row(null))`. - // - // When the schema is a struct, - // - json object/array with single element: `Row(...)` - // - json array with multiple elements: `null` - // - empty json array: `null`. - // - empty json object: `Row(null)`. - - // We need `null` if the input string is an empty string. `JacksonParser` can - // deal with this but produces `Nil`. - if (json.toString.trim.isEmpty) return null - - try { - converter(parser.parse( - json.asInstanceOf[UTF8String], - CreateJacksonParser.utf8String, - identity[UTF8String])) - } catch { - case _: BadRecordException => null - } + converter(parser.parse(json.asInstanceOf[UTF8String])) } override def inputTypes: Seq[AbstractDataType] = StringType :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 918c9e71ad37a..57c7f2faf3107 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -38,7 +38,8 @@ import org.apache.spark.util.Utils */ class JacksonParser( schema: DataType, - val options: JSONOptions) extends Logging { + val options: JSONOptions, + allowArrayAsStructs: Boolean) extends Logging { import JacksonUtils._ import com.fasterxml.jackson.core.JsonToken._ @@ -84,7 +85,7 @@ class JacksonParser( // List([str_a_1,null]) // List([str_a_2,null], [null,str_b_3]) // - case START_ARRAY => + case START_ARRAY if allowArrayAsStructs => val array = convertArray(parser, elementConverter) // Here, as we support reading top level JSON arrays and take every element // in such an array as a row, this case is possible. @@ -93,6 +94,8 @@ class JacksonParser( } else { array.toArray[InternalRow](schema).toSeq } + case START_ARRAY => + throw new RuntimeException("Parsing JSON arrays as structs is forbidden.") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index f31b294fe25d4..304642161146b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Calendar -import org.apache.spark.SparkFunSuite +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.PlanTestBase @@ -409,14 +411,18 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), - null + InternalRow(null) ) - // Other modes should still return `null`. - checkEvaluation( - JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId), - null - ) + val exception = intercept[TestFailedException] { + checkEvaluation( + JsonToStructs(schema, Map("mode" -> FailFastMode.name), Literal(jsonData), gmtId), + InternalRow(null) + ) + }.getCause + assert(exception.isInstanceOf[SparkException]) + assert(exception.getMessage.contains( + "Malformed records are detected in record parsing. Parse Mode: FAILFAST")) } test("from_json - input=array, schema=array, output=array") { @@ -450,21 +456,23 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("from_json - input=array of single object, schema=struct, output=single row") { val input = """[{"a": 1}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) - val output = InternalRow(1) + val output = InternalRow(null) checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } - test("from_json - input=array, schema=struct, output=null") { + test("from_json - input=array, schema=struct, output=single row") { val input = """[{"a": 1}, {"a": 2}]""" - val schema = StructType(StructField("a", IntegerType) :: Nil) - val output = null - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + val corrupted = "corrupted" + val schema = new StructType().add("a", IntegerType).add(corrupted, StringType) + val output = InternalRow(null, UTF8String.fromString(input)) + val options = Map("columnNameOfCorruptRecord" -> corrupted) + checkEvaluation(JsonToStructs(schema, options, Literal(input), gmtId), output) } - test("from_json - input=empty array, schema=struct, output=null") { + test("from_json - input=empty array, schema=struct, output=single row with null") { val input = """[]""" val schema = StructType(StructField("a", IntegerType) :: Nil) - val output = null + val output = InternalRow(null) checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } @@ -487,7 +495,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( JsonToStructs(schema, Map.empty, Literal(badJson), gmtId), - null) + InternalRow(null)) } test("from_json with timestamp") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 4f6d8b8a0c34a..95c97e5c9433c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -446,7 +446,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val createParser = CreateJacksonParser.string _ val parsed = jsonDataset.rdd.mapPartitions { iter => - val rawParser = new JacksonParser(actualSchema, parsedOptions) + val rawParser = new JacksonParser(actualSchema, parsedOptions, allowArrayAsStructs = true) val parser = new FailureSafeParser[String]( input => rawParser.parse(input, createParser, UTF8String.fromString), parsedOptions.parseMode, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index a9241afba537b..1f7c9d73f19fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -130,7 +130,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { } (file: PartitionedFile) => { - val parser = new JacksonParser(actualSchema, parsedOptions) + val parser = new JacksonParser(actualSchema, parsedOptions, allowArrayAsStructs = true) JsonDataSource(parsedOptions).readFile( broadcastedHadoopConf.value.value, file, diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out index ba9bf76513f97..31ee700a8db95 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out @@ -258,4 +258,4 @@ select from_json(a, 'a INT') from t -- !query 31 schema struct> -- !query 31 output -NULL +{"a":null} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 5cbf10129a4da..797b274f42cdd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql import collection.JavaConverters._ +import org.apache.spark.SparkException import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -132,7 +134,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer( df.select(from_json($"value", schema)), - Row(null) :: Nil) + Row(Row(null)) :: Nil) } test("from_json - json doesn't conform to the array type") { @@ -547,4 +549,33 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Map("pretty" -> "true"))), Seq(Row(expected))) } + + test("from_json invalid json - check modes") { + withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { + val schema = new StructType() + .add("a", IntegerType) + .add("b", IntegerType) + .add("_unparsed", StringType) + val badRec = """{"a" 1, "b": 11}""" + val df = Seq(badRec, """{"a": 2, "b": 12}""").toDS() + + checkAnswer( + df.select(from_json($"value", schema, Map("mode" -> "PERMISSIVE"))), + Row(Row(null, null, badRec)) :: Row(Row(2, 12, null)) :: Nil) + + val exception1 = intercept[SparkException] { + df.select(from_json($"value", schema, Map("mode" -> "FAILFAST"))).collect() + }.getMessage + assert(exception1.contains( + "Malformed records are detected in record parsing. Parse Mode: FAILFAST.")) + + val exception2 = intercept[SparkException] { + df.select(from_json($"value", schema, Map("mode" -> "DROPMALFORMED"))) + .collect() + }.getMessage + assert(exception2.contains( + "from_json() doesn't support the DROPMALFORMED mode. " + + "Acceptable modes are PERMISSIVE and FAILFAST.")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 43e1a616e363c..06032ded42a53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -67,7 +67,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val dummyOption = new JSONOptions(Map.empty[String, String], "GMT") val dummySchema = StructType(Seq.empty) - val parser = new JacksonParser(dummySchema, dummyOption) + val parser = new JacksonParser(dummySchema, dummyOption, allowArrayAsStructs = true) Utils.tryWithResource(factory.createParser(writer.toString)) { jsonParser => jsonParser.nextToken() From b19a28dea098c7d6188f8540429c50f42952d678 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 24 Oct 2018 09:08:26 -0500 Subject: [PATCH 1908/2461] [SPARK-16775][CORE] Remove deprecated accumulator v1 APIs ## What changes were proposed in this pull request? Remove deprecated accumulator v1 ## How was this patch tested? Existing tests. Closes #22730 from srowen/SPARK-16775. Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../java/org/apache/spark/package-info.java | 4 +- .../scala/org/apache/spark/Accumulable.scala | 226 ------------------ .../scala/org/apache/spark/Accumulator.scala | 117 --------- .../scala/org/apache/spark/SparkContext.scala | 73 +----- .../spark/api/java/JavaSparkContext.scala | 113 --------- .../spark/scheduler/AccumulableInfo.scala | 2 +- .../org/apache/spark/util/AccumulatorV2.scala | 31 --- .../test/org/apache/spark/JavaAPISuite.java | 54 +---- .../org/apache/spark/AccumulatorSuite.scala | 148 +----------- .../spark/util/AccumulatorV2Suite.scala | 53 ---- .../org/apache/sparktest/ImplicitSuite.scala | 20 -- project/MimaExcludes.scala | 19 ++ 12 files changed, 30 insertions(+), 830 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/Accumulable.scala delete mode 100644 core/src/main/scala/org/apache/spark/Accumulator.scala diff --git a/core/src/main/java/org/apache/spark/package-info.java b/core/src/main/java/org/apache/spark/package-info.java index 4426c7afcebdd..a029931f9e4c0 100644 --- a/core/src/main/java/org/apache/spark/package-info.java +++ b/core/src/main/java/org/apache/spark/package-info.java @@ -16,8 +16,8 @@ */ /** - * Core Spark classes in Scala. A few classes here, such as {@link org.apache.spark.Accumulator} - * and {@link org.apache.spark.storage.StorageLevel}, are also used in Java, but the + * Core Spark classes in Scala. A few classes here, such as + * {@link org.apache.spark.storage.StorageLevel}, are also used in Java, but the * {@link org.apache.spark.api.java} package contains the main Java API. */ package org.apache.spark; diff --git a/core/src/main/scala/org/apache/spark/Accumulable.scala b/core/src/main/scala/org/apache/spark/Accumulable.scala deleted file mode 100644 index 3092074232d18..0000000000000 --- a/core/src/main/scala/org/apache/spark/Accumulable.scala +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import java.io.Serializable - -import scala.collection.generic.Growable -import scala.reflect.ClassTag - -import org.apache.spark.scheduler.AccumulableInfo -import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.util.{AccumulatorContext, AccumulatorMetadata, LegacyAccumulatorWrapper} - - -/** - * A data type that can be accumulated, i.e. has a commutative and associative "add" operation, - * but where the result type, `R`, may be different from the element type being added, `T`. - * - * You must define how to add data, and how to merge two of these together. For some data types, - * such as a counter, these might be the same operation. In that case, you can use the simpler - * [[org.apache.spark.Accumulator]]. They won't always be the same, though -- e.g., imagine you are - * accumulating a set. You will add items to the set, and you will union two sets together. - * - * Operations are not thread-safe. - * - * @param id ID of this accumulator; for internal use only. - * @param initialValue initial value of accumulator - * @param param helper object defining how to add elements of type `R` and `T` - * @param name human-readable name for use in Spark's web UI - * @param countFailedValues whether to accumulate values from failed tasks. This is set to true - * for system and time metrics like serialization time or bytes spilled, - * and false for things with absolute values like number of input rows. - * This should be used for internal metrics only. - * @tparam R the full accumulated data (result type) - * @tparam T partial data that can be added in - */ -@deprecated("use AccumulatorV2", "2.0.0") -class Accumulable[R, T] private ( - val id: Long, - // SI-8813: This must explicitly be a private val, or else scala 2.11 doesn't compile - @transient private val initialValue: R, - param: AccumulableParam[R, T], - val name: Option[String], - private[spark] val countFailedValues: Boolean) - extends Serializable { - - private[spark] def this( - initialValue: R, - param: AccumulableParam[R, T], - name: Option[String], - countFailedValues: Boolean) = { - this(AccumulatorContext.newId(), initialValue, param, name, countFailedValues) - } - - private[spark] def this(initialValue: R, param: AccumulableParam[R, T], name: Option[String]) = { - this(initialValue, param, name, false /* countFailedValues */) - } - - def this(initialValue: R, param: AccumulableParam[R, T]) = this(initialValue, param, None) - - val zero = param.zero(initialValue) - private[spark] val newAcc = new LegacyAccumulatorWrapper(initialValue, param) - newAcc.metadata = AccumulatorMetadata(id, name, countFailedValues) - // Register the new accumulator in ctor, to follow the previous behaviour. - AccumulatorContext.register(newAcc) - - /** - * Add more data to this accumulator / accumulable - * @param term the data to add - */ - def += (term: T) { newAcc.add(term) } - - /** - * Add more data to this accumulator / accumulable - * @param term the data to add - */ - def add(term: T) { newAcc.add(term) } - - /** - * Merge two accumulable objects together - * - * Normally, a user will not want to use this version, but will instead call `+=`. - * @param term the other `R` that will get merged with this - */ - def ++= (term: R) { newAcc._value = param.addInPlace(newAcc._value, term) } - - /** - * Merge two accumulable objects together - * - * Normally, a user will not want to use this version, but will instead call `add`. - * @param term the other `R` that will get merged with this - */ - def merge(term: R) { newAcc._value = param.addInPlace(newAcc._value, term) } - - /** - * Access the accumulator's current value; only allowed on driver. - */ - def value: R = { - if (newAcc.isAtDriverSide) { - newAcc.value - } else { - throw new UnsupportedOperationException("Can't read accumulator value in task") - } - } - - /** - * Get the current value of this accumulator from within a task. - * - * This is NOT the global value of the accumulator. To get the global value after a - * completed operation on the dataset, call `value`. - * - * The typical use of this method is to directly mutate the local value, eg., to add - * an element to a Set. - */ - def localValue: R = newAcc.value - - /** - * Set the accumulator's value; only allowed on driver. - */ - def value_= (newValue: R) { - if (newAcc.isAtDriverSide) { - newAcc._value = newValue - } else { - throw new UnsupportedOperationException("Can't assign accumulator value in task") - } - } - - /** - * Set the accumulator's value. For internal use only. - */ - def setValue(newValue: R): Unit = { newAcc._value = newValue } - - /** - * Set the accumulator's value. For internal use only. - */ - private[spark] def setValueAny(newValue: Any): Unit = { setValue(newValue.asInstanceOf[R]) } - - /** - * Create an [[AccumulableInfo]] representation of this [[Accumulable]] with the provided values. - */ - private[spark] def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { - val isInternal = name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX)) - new AccumulableInfo(id, name, update, value, isInternal, countFailedValues) - } - - override def toString: String = if (newAcc._value == null) "null" else newAcc._value.toString -} - - -/** - * Helper object defining how to accumulate values of a particular type. An implicit - * AccumulableParam needs to be available when you create [[Accumulable]]s of a specific type. - * - * @tparam R the full accumulated data (result type) - * @tparam T partial data that can be added in - */ -@deprecated("use AccumulatorV2", "2.0.0") -trait AccumulableParam[R, T] extends Serializable { - /** - * Add additional data to the accumulator value. Is allowed to modify and return `r` - * for efficiency (to avoid allocating objects). - * - * @param r the current value of the accumulator - * @param t the data to be added to the accumulator - * @return the new value of the accumulator - */ - def addAccumulator(r: R, t: T): R - - /** - * Merge two accumulated values together. Is allowed to modify and return the first value - * for efficiency (to avoid allocating objects). - * - * @param r1 one set of accumulated data - * @param r2 another set of accumulated data - * @return both data sets merged together - */ - def addInPlace(r1: R, r2: R): R - - /** - * Return the "zero" (identity) value for an accumulator type, given its initial value. For - * example, if R was a vector of N dimensions, this would return a vector of N zeroes. - */ - def zero(initialValue: R): R -} - - -@deprecated("use AccumulatorV2", "2.0.0") -private[spark] class -GrowableAccumulableParam[R : ClassTag, T] - (implicit rg: R => Growable[T] with TraversableOnce[T] with Serializable) - extends AccumulableParam[R, T] { - - def addAccumulator(growable: R, elem: T): R = { - growable += elem - growable - } - - def addInPlace(t1: R, t2: R): R = { - t1 ++= t2 - t1 - } - - def zero(initialValue: R): R = { - // We need to clone initialValue, but it's hard to specify that R should also be Cloneable. - // Instead we'll serialize it to a buffer and load it back. - val ser = new JavaSerializer(new SparkConf(false)).newInstance() - val copy = ser.deserialize[R](ser.serialize(initialValue)) - copy.clear() // In case it contained stuff - copy - } -} diff --git a/core/src/main/scala/org/apache/spark/Accumulator.scala b/core/src/main/scala/org/apache/spark/Accumulator.scala deleted file mode 100644 index 9d5fbefc824ad..0000000000000 --- a/core/src/main/scala/org/apache/spark/Accumulator.scala +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -/** - * A simpler value of [[Accumulable]] where the result type being accumulated is the same - * as the types of elements being merged, i.e. variables that are only "added" to through an - * associative and commutative operation and can therefore be efficiently supported in parallel. - * They can be used to implement counters (as in MapReduce) or sums. Spark natively supports - * accumulators of numeric value types, and programmers can add support for new types. - * - * An accumulator is created from an initial value `v` by calling `SparkContext.accumulator`. - * Tasks running on the cluster can then add to it using the `+=` operator. - * However, they cannot read its value. Only the driver program can read the accumulator's value, - * using its [[#value]] method. - * - * The interpreter session below shows an accumulator being used to add up the elements of an array: - * - * {{{ - * scala> val accum = sc.accumulator(0) - * accum: org.apache.spark.Accumulator[Int] = 0 - * - * scala> sc.parallelize(Array(1, 2, 3, 4)).foreach(x => accum += x) - * ... - * 10/09/29 18:41:08 INFO SparkContext: Tasks finished in 0.317106 s - * - * scala> accum.value - * res2: Int = 10 - * }}} - * - * @param initialValue initial value of accumulator - * @param param helper object defining how to add elements of type `T` - * @param name human-readable name associated with this accumulator - * @param countFailedValues whether to accumulate values from failed tasks - * @tparam T result type -*/ -@deprecated("use AccumulatorV2", "2.0.0") -class Accumulator[T] private[spark] ( - // SI-8813: This must explicitly be a private val, or else scala 2.11 doesn't compile - @transient private val initialValue: T, - param: AccumulatorParam[T], - name: Option[String] = None, - countFailedValues: Boolean = false) - extends Accumulable[T, T](initialValue, param, name, countFailedValues) - - -/** - * A simpler version of [[org.apache.spark.AccumulableParam]] where the only data type you can add - * in is the same type as the accumulated value. An implicit AccumulatorParam object needs to be - * available when you create Accumulators of a specific type. - * - * @tparam T type of value to accumulate - */ -@deprecated("use AccumulatorV2", "2.0.0") -trait AccumulatorParam[T] extends AccumulableParam[T, T] { - def addAccumulator(t1: T, t2: T): T = { - addInPlace(t1, t2) - } -} - - -@deprecated("use AccumulatorV2", "2.0.0") -object AccumulatorParam { - - // The following implicit objects were in SparkContext before 1.2 and users had to - // `import SparkContext._` to enable them. Now we move them here to make the compiler find - // them automatically. However, as there are duplicate codes in SparkContext for backward - // compatibility, please update them accordingly if you modify the following implicit objects. - - @deprecated("use AccumulatorV2", "2.0.0") - implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { - def addInPlace(t1: Double, t2: Double): Double = t1 + t2 - def zero(initialValue: Double): Double = 0.0 - } - - @deprecated("use AccumulatorV2", "2.0.0") - implicit object IntAccumulatorParam extends AccumulatorParam[Int] { - def addInPlace(t1: Int, t2: Int): Int = t1 + t2 - def zero(initialValue: Int): Int = 0 - } - - @deprecated("use AccumulatorV2", "2.0.0") - implicit object LongAccumulatorParam extends AccumulatorParam[Long] { - def addInPlace(t1: Long, t2: Long): Long = t1 + t2 - def zero(initialValue: Long): Long = 0L - } - - @deprecated("use AccumulatorV2", "2.0.0") - implicit object FloatAccumulatorParam extends AccumulatorParam[Float] { - def addInPlace(t1: Float, t2: Float): Float = t1 + t2 - def zero(initialValue: Float): Float = 0f - } - - // Note: when merging values, this param just adopts the newer value. This is used only - // internally for things that shouldn't really be accumulated across tasks, like input - // read method, which should be the same across all tasks in the same stage. - @deprecated("use AccumulatorV2", "2.0.0") - private[spark] object StringAccumulatorParam extends AccumulatorParam[String] { - def addInPlace(t1: String, t2: String): String = t2 - def zero(initialValue: String): String = "" - } -} diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 10f3168a4c2db..b3c9c030487cd 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -25,7 +25,6 @@ import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicReferenc import scala.collection.JavaConverters._ import scala.collection.Map -import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.language.implicitConversions import scala.reflect.{classTag, ClassTag} @@ -51,7 +50,7 @@ import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, StandaloneSchedulerBackend} +import org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend import org.apache.spark.scheduler.local.LocalSchedulerBackend import org.apache.spark.status.{AppStatusSource, AppStatusStore} import org.apache.spark.status.api.v1.ThreadStackTrace @@ -1337,76 +1336,6 @@ class SparkContext(config: SparkConf) extends Logging { // Methods for creating shared variables - /** - * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" - * values to using the `+=` method. Only the driver can access the accumulator's `value`. - */ - @deprecated("use AccumulatorV2", "2.0.0") - def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]): Accumulator[T] = { - val acc = new Accumulator(initialValue, param) - cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) - acc - } - - /** - * Create an [[org.apache.spark.Accumulator]] variable of a given type, with a name for display - * in the Spark UI. Tasks can "add" values to the accumulator using the `+=` method. Only the - * driver can access the accumulator's `value`. - */ - @deprecated("use AccumulatorV2", "2.0.0") - def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) - : Accumulator[T] = { - val acc = new Accumulator(initialValue, param, Option(name)) - cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) - acc - } - - /** - * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values - * with `+=`. Only the driver can access the accumulable's `value`. - * @tparam R accumulator result type - * @tparam T type that can be added to the accumulator - */ - @deprecated("use AccumulatorV2", "2.0.0") - def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) - : Accumulable[R, T] = { - val acc = new Accumulable(initialValue, param) - cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) - acc - } - - /** - * Create an [[org.apache.spark.Accumulable]] shared variable, with a name for display in the - * Spark UI. Tasks can add values to the accumulable using the `+=` operator. Only the driver can - * access the accumulable's `value`. - * @tparam R accumulator result type - * @tparam T type that can be added to the accumulator - */ - @deprecated("use AccumulatorV2", "2.0.0") - def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) - : Accumulable[R, T] = { - val acc = new Accumulable(initialValue, param, Option(name)) - cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) - acc - } - - /** - * Create an accumulator from a "mutable collection" type. - * - * Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by - * standard mutable collections. So you can use this with mutable Map, Set, etc. - */ - @deprecated("use AccumulatorV2", "2.0.0") - def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T] - (initialValue: R): Accumulable[R, T] = { - // TODO the context bound (<%) above should be replaced with simple type bound and implicit - // conversion but is a breaking change. This should be fixed in Spark 3.x. - val param = new GrowableAccumulableParam[R, T] - val acc = new Accumulable(initialValue, param) - cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) - acc - } - /** * Register the given accumulator. * diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 09c83849e26b2..ef15f95b3fe5b 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -30,7 +30,6 @@ import org.apache.hadoop.mapred.{InputFormat, JobConf} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark._ -import org.apache.spark.AccumulatorParam._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.broadcast.Broadcast import org.apache.spark.input.PortableDataStream @@ -530,118 +529,6 @@ class JavaSparkContext(val sc: SparkContext) new JavaDoubleRDD(sc.union(rdds)) } - /** - * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values - * to using the `add` method. Only the master can access the accumulator's `value`. - */ - @deprecated("use sc().longAccumulator()", "2.0.0") - def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] = - sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]] - - /** - * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values - * to using the `add` method. Only the master can access the accumulator's `value`. - * - * This version supports naming the accumulator for display in Spark's web UI. - */ - @deprecated("use sc().longAccumulator(String)", "2.0.0") - def intAccumulator(initialValue: Int, name: String): Accumulator[java.lang.Integer] = - sc.accumulator(initialValue, name)(IntAccumulatorParam) - .asInstanceOf[Accumulator[java.lang.Integer]] - - /** - * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values - * to using the `add` method. Only the master can access the accumulator's `value`. - */ - @deprecated("use sc().doubleAccumulator()", "2.0.0") - def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] = - sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]] - - /** - * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values - * to using the `add` method. Only the master can access the accumulator's `value`. - * - * This version supports naming the accumulator for display in Spark's web UI. - */ - @deprecated("use sc().doubleAccumulator(String)", "2.0.0") - def doubleAccumulator(initialValue: Double, name: String): Accumulator[java.lang.Double] = - sc.accumulator(initialValue, name)(DoubleAccumulatorParam) - .asInstanceOf[Accumulator[java.lang.Double]] - - /** - * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values - * to using the `add` method. Only the master can access the accumulator's `value`. - */ - @deprecated("use sc().longAccumulator()", "2.0.0") - def accumulator(initialValue: Int): Accumulator[java.lang.Integer] = intAccumulator(initialValue) - - /** - * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values - * to using the `add` method. Only the master can access the accumulator's `value`. - * - * This version supports naming the accumulator for display in Spark's web UI. - */ - @deprecated("use sc().longAccumulator(String)", "2.0.0") - def accumulator(initialValue: Int, name: String): Accumulator[java.lang.Integer] = - intAccumulator(initialValue, name) - - /** - * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values - * to using the `add` method. Only the master can access the accumulator's `value`. - */ - @deprecated("use sc().doubleAccumulator()", "2.0.0") - def accumulator(initialValue: Double): Accumulator[java.lang.Double] = - doubleAccumulator(initialValue) - - - /** - * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values - * to using the `add` method. Only the master can access the accumulator's `value`. - * - * This version supports naming the accumulator for display in Spark's web UI. - */ - @deprecated("use sc().doubleAccumulator(String)", "2.0.0") - def accumulator(initialValue: Double, name: String): Accumulator[java.lang.Double] = - doubleAccumulator(initialValue, name) - - /** - * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" - * values to using the `add` method. Only the master can access the accumulator's `value`. - */ - @deprecated("use AccumulatorV2", "2.0.0") - def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] = - sc.accumulator(initialValue)(accumulatorParam) - - /** - * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" - * values to using the `add` method. Only the master can access the accumulator's `value`. - * - * This version supports naming the accumulator for display in Spark's web UI. - */ - @deprecated("use AccumulatorV2", "2.0.0") - def accumulator[T](initialValue: T, name: String, accumulatorParam: AccumulatorParam[T]) - : Accumulator[T] = - sc.accumulator(initialValue, name)(accumulatorParam) - - /** - * Create an [[org.apache.spark.Accumulable]] shared variable of the given type, to which tasks - * can "add" values with `add`. Only the master can access the accumulable's `value`. - */ - @deprecated("use AccumulatorV2", "2.0.0") - def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] = - sc.accumulable(initialValue)(param) - - /** - * Create an [[org.apache.spark.Accumulable]] shared variable of the given type, to which tasks - * can "add" values with `add`. Only the master can access the accumulable's `value`. - * - * This version supports naming the accumulator for display in Spark's web UI. - */ - @deprecated("use AccumulatorV2", "2.0.0") - def accumulable[T, R](initialValue: T, name: String, param: AccumulableParam[T, R]) - : Accumulable[T, R] = - sc.accumulable(initialValue, name)(param) - /** * Broadcast a read-only variable to the cluster, returning a * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index 0a5fe5a1d3ee1..d745345f4e0d2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -22,7 +22,7 @@ import org.apache.spark.annotation.DeveloperApi /** * :: DeveloperApi :: - * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage. + * Information about an [[org.apache.spark.util.AccumulatorV2]] modified during a task or stage. * * @param id accumulator ID * @param name accumulator name diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index bf618b4afbce0..d5b3ce36e742a 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -485,34 +485,3 @@ class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] { _list.addAll(newValue) } } - - -class LegacyAccumulatorWrapper[R, T]( - initialValue: R, - param: org.apache.spark.AccumulableParam[R, T]) extends AccumulatorV2[T, R] { - private[spark] var _value = initialValue // Current value on driver - - @transient private lazy val _zero = param.zero(initialValue) - - override def isZero: Boolean = _value.asInstanceOf[AnyRef].eq(_zero.asInstanceOf[AnyRef]) - - override def copy(): LegacyAccumulatorWrapper[R, T] = { - val acc = new LegacyAccumulatorWrapper(initialValue, param) - acc._value = _value - acc - } - - override def reset(): Unit = { - _value = _zero - } - - override def add(v: T): Unit = _value = param.addAccumulator(_value, v) - - override def merge(other: AccumulatorV2[T, R]): Unit = other match { - case o: LegacyAccumulatorWrapper[R, T] => _value = param.addInPlace(_value, o.value) - case _ => throw new UnsupportedOperationException( - s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") - } - - override def value: R = _value -} diff --git a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java index 3992ab7049bdd..365a93d2601e7 100644 --- a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java @@ -33,8 +33,6 @@ import java.util.Map; import java.util.concurrent.*; -import org.apache.spark.Accumulator; -import org.apache.spark.AccumulatorParam; import org.apache.spark.Partitioner; import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; @@ -186,7 +184,7 @@ public void randomSplit() { long s1 = splits[1].count(); long s2 = splits[2].count(); assertTrue(s0 + " not within expected range", s0 > 150 && s0 < 250); - assertTrue(s1 + " not within expected range", s1 > 250 && s0 < 350); + assertTrue(s1 + " not within expected range", s1 > 250 && s1 < 350); assertTrue(s2 + " not within expected range", s2 > 430 && s2 < 570); } @@ -956,7 +954,7 @@ public void wholeTextFiles() throws Exception { } @Test - public void textFilesCompressed() throws IOException { + public void textFilesCompressed() { String outputDir = new File(tempDir, "output").getAbsolutePath(); JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); rdd.saveAsTextFile(outputDir, DefaultCodec.class); @@ -1183,46 +1181,6 @@ public void zipPartitions() { assertEquals("[3, 2, 3, 2]", sizes.collect().toString()); } - @SuppressWarnings("deprecation") - @Test - public void accumulators() { - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - - Accumulator intAccum = sc.intAccumulator(10); - rdd.foreach(intAccum::add); - assertEquals((Integer) 25, intAccum.value()); - - Accumulator doubleAccum = sc.doubleAccumulator(10.0); - rdd.foreach(x -> doubleAccum.add((double) x)); - assertEquals((Double) 25.0, doubleAccum.value()); - - // Try a custom accumulator type - AccumulatorParam floatAccumulatorParam = new AccumulatorParam() { - @Override - public Float addInPlace(Float r, Float t) { - return r + t; - } - - @Override - public Float addAccumulator(Float r, Float t) { - return r + t; - } - - @Override - public Float zero(Float initialValue) { - return 0.0f; - } - }; - - Accumulator floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); - rdd.foreach(x -> floatAccum.add((float) x)); - assertEquals((Float) 25.0f, floatAccum.value()); - - // Test the setValue method - floatAccum.setValue(5.0f); - assertEquals((Float) 5.0f, floatAccum.value()); - } - @Test public void keyBy() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2)); @@ -1410,13 +1368,13 @@ public void sampleByKeyExact() { JavaPairRDD wrExact = rdd2.sampleByKeyExact(true, fractions, 1L); Map wrExactCounts = wrExact.countByKey(); assertEquals(2, wrExactCounts.size()); - assertTrue(wrExactCounts.get(0) == 2); - assertTrue(wrExactCounts.get(1) == 4); + assertEquals(2, (long) wrExactCounts.get(0)); + assertEquals(4, (long) wrExactCounts.get(1)); JavaPairRDD worExact = rdd2.sampleByKeyExact(false, fractions, 1L); Map worExactCounts = worExact.countByKey(); assertEquals(2, worExactCounts.size()); - assertTrue(worExactCounts.get(0) == 2); - assertTrue(worExactCounts.get(1) == 4); + assertEquals(2, (long) worExactCounts.get(0)); + assertEquals(4, (long) worExactCounts.get(1)); } private static class SomeCustomClass implements Serializable { diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 5d0ffd92647bc..435665d8a1ce2 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -28,7 +28,6 @@ import scala.util.control.NonFatal import org.scalatest.Matchers import org.scalatest.exceptions.TestFailedException -import org.apache.spark.AccumulatorParam.StringAccumulatorParam import org.apache.spark.scheduler._ import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.{AccumulatorContext, AccumulatorMetadata, AccumulatorV2, LongAccumulator} @@ -45,21 +44,6 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex } } - implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] = - new AccumulableParam[mutable.Set[A], A] { - def addInPlace(t1: mutable.Set[A], t2: mutable.Set[A]) : mutable.Set[A] = { - t1 ++= t2 - t1 - } - def addAccumulator(t1: mutable.Set[A], t2: A) : mutable.Set[A] = { - t1 += t2 - t1 - } - def zero(t: mutable.Set[A]) : mutable.Set[A] = { - new mutable.HashSet[A]() - } - } - test("accumulator serialization") { val ser = new JavaSerializer(new SparkConf).newInstance() val acc = createLongAccum("x") @@ -81,122 +65,6 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex assert(acc3.isAtDriverSide) } - test ("basic accumulation") { - sc = new SparkContext("local", "test") - val acc: Accumulator[Int] = sc.accumulator(0) - - val d = sc.parallelize(1 to 20) - d.foreach{x => acc += x} - acc.value should be (210) - - val longAcc = sc.accumulator(0L) - val maxInt = Integer.MAX_VALUE.toLong - d.foreach{x => longAcc += maxInt + x} - longAcc.value should be (210L + maxInt * 20) - } - - test("value not assignable from tasks") { - sc = new SparkContext("local", "test") - val acc: Accumulator[Int] = sc.accumulator(0) - - val d = sc.parallelize(1 to 20) - intercept[SparkException] { - d.foreach(x => acc.value = x) - } - } - - test ("add value to collection accumulators") { - val maxI = 1000 - for (nThreads <- List(1, 10)) { // test single & multi-threaded - sc = new SparkContext("local[" + nThreads + "]", "test") - val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) - val d = sc.parallelize(1 to maxI) - d.foreach { - x => acc += x - } - val v = acc.value.asInstanceOf[mutable.Set[Int]] - for (i <- 1 to maxI) { - v should contain(i) - } - resetSparkContext() - } - } - - test("value not readable in tasks") { - val maxI = 1000 - for (nThreads <- List(1, 10)) { // test single & multi-threaded - sc = new SparkContext("local[" + nThreads + "]", "test") - val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) - val d = sc.parallelize(1 to maxI) - an [SparkException] should be thrownBy { - d.foreach { - x => acc.value += x - } - } - resetSparkContext() - } - } - - test ("collection accumulators") { - val maxI = 1000 - for (nThreads <- List(1, 10)) { - // test single & multi-threaded - sc = new SparkContext("local[" + nThreads + "]", "test") - val setAcc = sc.accumulableCollection(mutable.HashSet[Int]()) - val bufferAcc = sc.accumulableCollection(mutable.ArrayBuffer[Int]()) - val mapAcc = sc.accumulableCollection(mutable.HashMap[Int, String]()) - val d = sc.parallelize((1 to maxI) ++ (1 to maxI)) - d.foreach { - x => {setAcc += x; bufferAcc += x; mapAcc += (x -> x.toString)} - } - - // Note that this is typed correctly -- no casts necessary - setAcc.value.size should be (maxI) - bufferAcc.value.size should be (2 * maxI) - mapAcc.value.size should be (maxI) - for (i <- 1 to maxI) { - setAcc.value should contain(i) - bufferAcc.value should contain(i) - mapAcc.value should contain (i -> i.toString) - } - resetSparkContext() - } - } - - test ("localValue readable in tasks") { - val maxI = 1000 - for (nThreads <- List(1, 10)) { // test single & multi-threaded - sc = new SparkContext("local[" + nThreads + "]", "test") - val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) - val groupedInts = (1 to (maxI/20)).map {x => (20 * (x - 1) to 20 * x).toSet} - val d = sc.parallelize(groupedInts) - d.foreach { - x => acc.localValue ++= x - } - acc.value should be ((0 to maxI).toSet) - resetSparkContext() - } - } - - test ("garbage collection") { - // Create an accumulator and let it go out of scope to test that it's properly garbage collected - sc = new SparkContext("local", "test") - var acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) - val accId = acc.id - val ref = WeakReference(acc) - - // Ensure the accumulator is present - assert(ref.get.isDefined) - - // Remove the explicit reference to it and allow weak reference to get garbage collected - acc = null - System.gc() - assert(ref.get.isEmpty) - - AccumulatorContext.remove(accId) - assert(!AccumulatorContext.get(accId).isDefined) - } - test("get accum") { // Don't register with SparkContext for cleanup var acc = createLongAccum("a") @@ -221,20 +89,6 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex assert(AccumulatorContext.get(100000).isEmpty) } - test("string accumulator param") { - val acc = new Accumulator("", StringAccumulatorParam, Some("darkness")) - assert(acc.value === "") - acc.setValue("feeds") - assert(acc.value === "feeds") - acc.add("your") - assert(acc.value === "your") // value is overwritten, not concatenated - acc += "soul" - assert(acc.value === "soul") - acc ++= "with" - assert(acc.value === "with") - acc.merge("kindness") - assert(acc.value === "kindness") - } } private[spark] object AccumulatorSuite { @@ -256,7 +110,7 @@ private[spark] object AccumulatorSuite { } /** - * Make an `AccumulableInfo` out of an [[Accumulable]] with the intent to use the + * Make an `AccumulableInfo` out of an `AccumulatorV2` with the intent to use the * info as an accumulator update. */ def makeInfo(a: AccumulatorV2[_, _]): AccumulableInfo = a.toInfo(Some(a.value), None) diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala index 94c79388e3639..621399af731f7 100644 --- a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala +++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala @@ -18,7 +18,6 @@ package org.apache.spark.util import org.apache.spark._ -import org.apache.spark.serializer.JavaSerializer class AccumulatorV2Suite extends SparkFunSuite { @@ -128,58 +127,6 @@ class AccumulatorV2Suite extends SparkFunSuite { assert(acc3.value.isEmpty) } - test("LegacyAccumulatorWrapper") { - val acc = new LegacyAccumulatorWrapper("default", AccumulatorParam.StringAccumulatorParam) - assert(acc.value === "default") - assert(!acc.isZero) - - acc.add("foo") - assert(acc.value === "foo") - assert(!acc.isZero) - - acc.add(new java.lang.String("bar")) - - val acc2 = acc.copyAndReset() - assert(acc2.value === "") - assert(acc2.isZero) - - assert(acc.value === "bar") - assert(!acc.isZero) - - acc2.add("baz") - assert(acc2.value === "baz") - assert(!acc2.isZero) - - // Test merging - acc.merge(acc2) - assert(acc.value === "baz") - assert(!acc.isZero) - - val acc3 = acc.copy() - assert(acc3.value === "baz") - assert(!acc3.isZero) - - acc3.reset() - assert(acc3.isZero) - assert(acc3.value === "") - } - - test("LegacyAccumulatorWrapper with AccumulatorParam that has no equals/hashCode") { - val param = new AccumulatorParam[MyData] { - override def zero(initialValue: MyData): MyData = new MyData(0) - override def addInPlace(r1: MyData, r2: MyData): MyData = new MyData(r1.i + r2.i) - } - - val acc = new LegacyAccumulatorWrapper(new MyData(0), param) - acc.metadata = AccumulatorMetadata( - AccumulatorContext.newId(), - Some("test"), - countFailedValues = false) - AccumulatorContext.register(acc) - - val ser = new JavaSerializer(new SparkConf).newInstance() - ser.serialize(acc) - } } class MyData(val i: Int) extends Serializable diff --git a/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala b/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala index 2fb09ead4b2d8..24762ea2f4e6b 100644 --- a/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala +++ b/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala @@ -74,26 +74,6 @@ class ImplicitSuite { rdd.stats() } - def testDoubleAccumulatorParam(): Unit = { - val sc = mockSparkContext - sc.accumulator(123.4) - } - - def testIntAccumulatorParam(): Unit = { - val sc = mockSparkContext - sc.accumulator(123) - } - - def testLongAccumulatorParam(): Unit = { - val sc = mockSparkContext - sc.accumulator(123L) - } - - def testFloatAccumulatorParam(): Unit = { - val sc = mockSparkContext - sc.accumulator(123F) - } - def testIntWritableConverter(): Unit = { val sc = mockSparkContext sc.sequenceFile[Int, Int]("/a/test/path") diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a5d6d6366ede9..d6beac14bed66 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,25 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( + // [SPARK-16775] Remove deprecated accumulator v1 APIs + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Accumulable"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulatorParam"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Accumulator"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Accumulator$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulableParam"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulatorParam$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulatorParam$FloatAccumulatorParam$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulatorParam$DoubleAccumulatorParam$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulatorParam$LongAccumulatorParam$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulatorParam$IntAccumulatorParam$"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.accumulable"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.accumulableCollection"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.accumulator"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.LegacyAccumulatorWrapper"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.intAccumulator"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.accumulable"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.doubleAccumulator"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.accumulator"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.io.SnappyCompressionCodec.version"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.api.java.JavaPairRDD.flatMapValues"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaPairDStream.flatMapValues"), From 7251be0c04f0380208e0197e559158a9e1400868 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 24 Oct 2018 10:04:17 -0700 Subject: [PATCH 1909/2461] [SPARK-25798][PYTHON] Internally document type conversion between Pandas data and SQL types in Pandas UDFs ## What changes were proposed in this pull request? We are facing some problems about type conversions between Pandas data and SQL types in Pandas UDFs. It's even difficult to identify the problems (see #20163 and #22610). This PR targets to internally document the type conversion table. Some of them looks buggy and we should fix them. Table can be generated via the codes below: ```python from pyspark.sql.types import * from pyspark.sql.functions import pandas_udf columns = [ ('none', 'object(NoneType)'), ('bool', 'bool'), ('int8', 'int8'), ('int16', 'int16'), ('int32', 'int32'), ('int64', 'int64'), ('uint8', 'uint8'), ('uint16', 'uint16'), ('uint32', 'uint32'), ('uint64', 'uint64'), ('float64', 'float16'), ('float64', 'float32'), ('float64', 'float64'), ('date', 'datetime64[ns]'), ('tz_aware_dates', 'datetime64[ns, US/Eastern]'), ('string', 'object(string)'), ('decimal', 'object(Decimal)'), ('array', 'object(array[int32])'), ('float128', 'float128'), ('complex64', 'complex64'), ('complex128', 'complex128'), ('category', 'category'), ('tdeltas', 'timedelta64[ns]'), ] def create_dataframe(): import pandas as pd import numpy as np import decimal pdf = pd.DataFrame({ 'none': [None, None], 'bool': [True, False], 'int8': np.arange(1, 3).astype('int8'), 'int16': np.arange(1, 3).astype('int16'), 'int32': np.arange(1, 3).astype('int32'), 'int64': np.arange(1, 3).astype('int64'), 'uint8': np.arange(1, 3).astype('uint8'), 'uint16': np.arange(1, 3).astype('uint16'), 'uint32': np.arange(1, 3).astype('uint32'), 'uint64': np.arange(1, 3).astype('uint64'), 'float16': np.arange(1, 3).astype('float16'), 'float32': np.arange(1, 3).astype('float32'), 'float64': np.arange(1, 3).astype('float64'), 'float128': np.arange(1, 3).astype('float128'), 'complex64': np.arange(1, 3).astype('complex64'), 'complex128': np.arange(1, 3).astype('complex128'), 'string': list('ab'), 'array': pd.Series([np.array([1, 2, 3], dtype=np.int32), np.array([1, 2, 3], dtype=np.int32)]), 'decimal': pd.Series([decimal.Decimal('1'), decimal.Decimal('2')]), 'date': pd.date_range('19700101', periods=2).values, 'category': pd.Series(list("AB")).astype('category')}) pdf['tdeltas'] = [pdf.date.diff()[1], pdf.date.diff()[0]] pdf['tz_aware_dates'] = pd.date_range('19700101', periods=2, tz='US/Eastern') return pdf types = [ BooleanType(), ByteType(), ShortType(), IntegerType(), LongType(), FloatType(), DoubleType(), DateType(), TimestampType(), StringType(), DecimalType(10, 0), ArrayType(IntegerType()), MapType(StringType(), IntegerType()), StructType([StructField("_1", IntegerType())]), BinaryType(), ] df = spark.range(2).repartition(1) results = [] count = 0 total = len(types) * len(columns) values = [] spark.sparkContext.setLogLevel("FATAL") for t in types: result = [] for column, pandas_t in columns: v = create_dataframe()[column][0] values.append(v) try: row = df.select(pandas_udf(lambda _: create_dataframe()[column], t)(df.id)).first() ret_str = repr(row[0]) except Exception: ret_str = "X" result.append(ret_str) progress = "SQL Type: [%s]\n Pandas Value(Type): %s(%s)]\n Result Python Value: [%s]" % ( t.simpleString(), v, pandas_t, ret_str) count += 1 print("%s/%s:\n %s" % (count, total, progress)) results.append([t.simpleString()] + list(map(str, result))) schema = ["SQL Type \\ Pandas Value(Type)"] + list(map(lambda values_column: "%s(%s)" % (values_column[0], values_column[1][1]), zip(values, columns))) strings = spark.createDataFrame(results, schema=schema)._jdf.showString(20, 20, False) print("\n".join(map(lambda line: " # %s # noqa" % line, strings.strip().split("\n")))) ``` This code is compatible with both Python 2 and 3 but the table was generated under Python 2. ## How was this patch tested? Manually tested and lint check. Closes #22795 from HyukjinKwon/SPARK-25798. Authored-by: hyukjinkwon Signed-off-by: Bryan Cutler --- python/pyspark/sql/functions.py | 36 +++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 2694e777d8266..8b2e423d250cd 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -3023,6 +3023,42 @@ def pandas_udf(f=None, returnType=None, functionType=None): conversion on returned data. The conversion is not guaranteed to be correct and results should be checked for accuracy by users. """ + + # The following table shows most of Pandas data and SQL type conversions in Pandas UDFs that + # are not yet visible to the user. Some of behaviors are buggy and might be changed in the near + # future. The table might have to be eventually documented externally. + # Please see SPARK-25798's PR to see the codes in order to generate the table below. + # + # +-----------------------------+----------------------+----------+-------+--------+--------------------+--------------------+--------+---------+---------+---------+------------+------------+------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+-------------+-----------------+------------------+-----------+--------------------------------+ # noqa + # |SQL Type \ Pandas Value(Type)|None(object(NoneType))|True(bool)|1(int8)|1(int16)| 1(int32)| 1(int64)|1(uint8)|1(uint16)|1(uint32)|1(uint64)|1.0(float16)|1.0(float32)|1.0(float64)|1970-01-01 00:00:00(datetime64[ns])|1970-01-01 00:00:00-05:00(datetime64[ns, US/Eastern])|a(object(string))| 1(object(Decimal))|[1 2 3](object(array[int32]))|1.0(float128)|(1+0j)(complex64)|(1+0j)(complex128)|A(category)|1 days 00:00:00(timedelta64[ns])| # noqa + # +-----------------------------+----------------------+----------+-------+--------+--------------------+--------------------+--------+---------+---------+---------+------------+------------+------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+-------------+-----------------+------------------+-----------+--------------------------------+ # noqa + # | boolean| None| True| True| True| True| True| True| True| True| True| False| False| False| False| False| X| X| X| False| False| False| X| False| # noqa + # | tinyint| None| 1| 1| 1| 1| 1| X| X| X| X| 1| 1| 1| X| X| X| X| X| X| X| X| 0| X| # noqa + # | smallint| None| 1| 1| 1| 1| 1| 1| X| X| X| 1| 1| 1| X| X| X| X| X| X| X| X| X| X| # noqa + # | int| None| 1| 1| 1| 1| 1| 1| 1| X| X| 1| 1| 1| X| X| X| X| X| X| X| X| X| X| # noqa + # | bigint| None| 1| 1| 1| 1| 1| 1| 1| 1| X| 1| 1| 1| 0| 18000000000000| X| X| X| X| X| X| X| X| # noqa + # | float| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X|1.401298464324817...| X| X| X| X| X| X| # noqa + # | double| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X| X| X| X| X| X| X| X| # noqa + # | date| None| X| X| X|datetime.date(197...| X| X| X| X| X| X| X| X| datetime.date(197...| X| X| X| X| X| X| X| X| X| # noqa + # | timestamp| None| X| X| X| X|datetime.datetime...| X| X| X| X| X| X| X| datetime.datetime...| datetime.datetime...| X| X| X| X| X| X| X| X| # noqa + # | string| None| u''|u'\x01'| u'\x01'| u'\x01'| u'\x01'| u'\x01'| u'\x01'| u'\x01'| u'\x01'| u''| u''| u''| X| X| u'a'| X| X| u''| u''| u''| X| X| # noqa + # | decimal(10,0)| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| Decimal('1')| X| X| X| X| X| X| # noqa + # | array| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| [1, 2, 3]| X| X| X| X| X| # noqa + # | map| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| # noqa + # | struct<_1:int>| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| # noqa + # | binary| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| # noqa + # +-----------------------------+----------------------+----------+-------+--------+--------------------+--------------------+--------+---------+---------+---------+------------+------------+------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+-------------+-----------------+------------------+-----------+--------------------------------+ # noqa + # + # Note: DDL formatted string is used for 'SQL Type' for simplicity. This string can be + # used in `returnType`. + # Note: The values inside of the table are generated by `repr`. + # Note: Python 2 is used to generate this table since it is used to check the backward + # compatibility often in practice. + # Note: Pandas 0.19.2 and PyArrow 0.9.0 are used. + # Note: Timezone is Singapore timezone. + # Note: 'X' means it throws an exception during the conversion. + # Note: 'binary' type is only supported with PyArrow 0.10.0+ (SPARK-23555). + # decorator @pandas_udf(returnType, functionType) is_decorator = f is None or isinstance(f, (str, DataType)) From f83fedc9f20869ab4c62bb07bac50113d921207f Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 24 Oct 2018 14:43:51 -0500 Subject: [PATCH 1910/2461] [SPARK-25737][CORE] Remove JavaSparkContextVarargsWorkaround ## What changes were proposed in this pull request? Remove JavaSparkContextVarargsWorkaround ## How was this patch tested? Existing tests. Closes #22729 from srowen/SPARK-25737. Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../JavaSparkContextVarargsWorkaround.java | 67 ------------------- .../spark/api/java/JavaSparkContext.scala | 42 ++++++------ .../test/org/apache/spark/JavaAPISuite.java | 5 -- .../streaming/JavaKinesisWordCountASL.java | 2 +- project/MimaExcludes.scala | 7 ++ python/pyspark/context.py | 8 ++- python/pyspark/streaming/context.py | 8 ++- .../api/java/JavaStreamingContext.scala | 27 ++++---- 8 files changed, 53 insertions(+), 113 deletions(-) delete mode 100644 core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java diff --git a/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java b/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java deleted file mode 100644 index 0dd8fafbf2c82..0000000000000 --- a/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.java; - -import java.util.ArrayList; -import java.util.List; - -// See -// http://scala-programming-language.1934581.n4.nabble.com/Workaround-for-implementing-java-varargs-in-2-7-2-final-tp1944767p1944772.html -abstract class JavaSparkContextVarargsWorkaround { - - @SafeVarargs - public final JavaRDD union(JavaRDD... rdds) { - if (rdds.length == 0) { - throw new IllegalArgumentException("Union called on empty list"); - } - List> rest = new ArrayList<>(rdds.length - 1); - for (int i = 1; i < rdds.length; i++) { - rest.add(rdds[i]); - } - return union(rdds[0], rest); - } - - public JavaDoubleRDD union(JavaDoubleRDD... rdds) { - if (rdds.length == 0) { - throw new IllegalArgumentException("Union called on empty list"); - } - List rest = new ArrayList<>(rdds.length - 1); - for (int i = 1; i < rdds.length; i++) { - rest.add(rdds[i]); - } - return union(rdds[0], rest); - } - - @SafeVarargs - public final JavaPairRDD union(JavaPairRDD... rdds) { - if (rdds.length == 0) { - throw new IllegalArgumentException("Union called on empty list"); - } - List> rest = new ArrayList<>(rdds.length - 1); - for (int i = 1; i < rdds.length; i++) { - rest.add(rdds[i]); - } - return union(rdds[0], rest); - } - - // These methods take separate "first" and "rest" elements to avoid having the same type erasure - public abstract JavaRDD union(JavaRDD first, List> rest); - public abstract JavaDoubleRDD union(JavaDoubleRDD first, List rest); - public abstract JavaPairRDD union(JavaPairRDD first, List> - rest); -} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index ef15f95b3fe5b..03f259d73e975 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -21,6 +21,7 @@ import java.io.Closeable import java.util import java.util.{Map => JMap} +import scala.annotation.varargs import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -33,7 +34,7 @@ import org.apache.spark._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.broadcast.Broadcast import org.apache.spark.input.PortableDataStream -import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD} +import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD} /** * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns @@ -42,8 +43,7 @@ import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD} * Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details. */ -class JavaSparkContext(val sc: SparkContext) - extends JavaSparkContextVarargsWorkaround with Closeable { +class JavaSparkContext(val sc: SparkContext) extends Closeable { /** * Create a JavaSparkContext that loads settings from system properties (for instance, when @@ -506,27 +506,29 @@ class JavaSparkContext(val sc: SparkContext) new JavaNewHadoopRDD(rdd.asInstanceOf[NewHadoopRDD[K, V]]) } - /** Build the union of two or more RDDs. */ - override def union[T](first: JavaRDD[T], rest: java.util.List[JavaRDD[T]]): JavaRDD[T] = { - val rdds: Seq[RDD[T]] = (Seq(first) ++ rest.asScala).map(_.rdd) - implicit val ctag: ClassTag[T] = first.classTag - sc.union(rdds) + /** Build the union of JavaRDDs. */ + @varargs + def union[T](rdds: JavaRDD[T]*): JavaRDD[T] = { + require(rdds.nonEmpty, "Union called on no RDDs") + implicit val ctag: ClassTag[T] = rdds.head.classTag + sc.union(rdds.map(_.rdd)) } - /** Build the union of two or more RDDs. */ - override def union[K, V](first: JavaPairRDD[K, V], rest: java.util.List[JavaPairRDD[K, V]]) - : JavaPairRDD[K, V] = { - val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ rest.asScala).map(_.rdd) - implicit val ctag: ClassTag[(K, V)] = first.classTag - implicit val ctagK: ClassTag[K] = first.kClassTag - implicit val ctagV: ClassTag[V] = first.vClassTag - new JavaPairRDD(sc.union(rdds)) + /** Build the union of JavaPairRDDs. */ + @varargs + def union[K, V](rdds: JavaPairRDD[K, V]*): JavaPairRDD[K, V] = { + require(rdds.nonEmpty, "Union called on no RDDs") + implicit val ctag: ClassTag[(K, V)] = rdds.head.classTag + implicit val ctagK: ClassTag[K] = rdds.head.kClassTag + implicit val ctagV: ClassTag[V] = rdds.head.vClassTag + new JavaPairRDD(sc.union(rdds.map(_.rdd))) } - /** Build the union of two or more RDDs. */ - override def union(first: JavaDoubleRDD, rest: java.util.List[JavaDoubleRDD]): JavaDoubleRDD = { - val rdds: Seq[RDD[Double]] = (Seq(first) ++ rest.asScala).map(_.srdd) - new JavaDoubleRDD(sc.union(rdds)) + /** Build the union of JavaDoubleRDDs. */ + @varargs + def union(rdds: JavaDoubleRDD*): JavaDoubleRDD = { + require(rdds.nonEmpty, "Union called on no RDDs") + new JavaDoubleRDD(sc.union(rdds.map(_.srdd))) } /** diff --git a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java index 365a93d2601e7..f979f9e8bb956 100644 --- a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java @@ -106,11 +106,6 @@ public void sparkContextUnion() { // Varargs JavaRDD sUnion = sc.union(s1, s2); assertEquals(4, sUnion.count()); - // List - List> list = new ArrayList<>(); - list.add(s2); - sUnion = sc.union(s1, list); - assertEquals(4, sUnion.count()); // Union of JavaDoubleRDDs List doubles = Arrays.asList(1.0, 2.0); diff --git a/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java index 626bde48e1a86..86c42df9e8435 100644 --- a/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java +++ b/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -145,7 +145,7 @@ public static void main(String[] args) throws Exception { // Union all the streams if there is more than 1 stream JavaDStream unionStreams; if (streamsList.size() > 1) { - unionStreams = jssc.union(streamsList.get(0), streamsList.subList(1, streamsList.size())); + unionStreams = jssc.union(streamsList.toArray(new JavaDStream[0])); } else { // Otherwise, just use the 1 stream unionStreams = streamsList.get(0); diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d6beac14bed66..350d8ad6942ff 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,10 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( + // [SPARK-25737] Remove JavaSparkContextVarargsWorkaround + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.api.java.JavaSparkContext"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.union"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.union"), // [SPARK-16775] Remove deprecated accumulator v1 APIs ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Accumulable"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulatorParam"), @@ -55,9 +59,12 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.accumulable"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.doubleAccumulator"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.accumulator"), + // [SPARK-24109] Remove class SnappyOutputStreamWrapper ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.io.SnappyCompressionCodec.version"), + // [SPARK-19287] JavaPairRDD flatMapValues requires function returning Iterable, not Iterator ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.api.java.JavaPairRDD.flatMapValues"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaPairDStream.flatMapValues"), + // [SPARK-25680] SQL execution listener shouldn't happen on execution thread ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.util.ExecutionListenerManager.clone"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.util.ExecutionListenerManager.this") ) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 0924d3d95f044..1180bf91baa5a 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -834,9 +834,11 @@ def union(self, rdds): first_jrdd_deserializer = rdds[0]._jrdd_deserializer if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds): rdds = [x._reserialize() for x in rdds] - first = rdds[0]._jrdd - rest = [x._jrdd for x in rdds[1:]] - return RDD(self._jsc.union(first, rest), self, rdds[0]._jrdd_deserializer) + cls = SparkContext._jvm.org.apache.spark.api.java.JavaRDD + jrdds = SparkContext._gateway.new_array(cls, len(rdds)) + for i in range(0, len(rdds)): + jrdds[i] = rdds[i]._jrdd + return RDD(self._jsc.union(jrdds), self, rdds[0]._jrdd_deserializer) def broadcast(self, value): """ diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 3fa57ca85b37b..e1c194b446504 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -343,9 +343,11 @@ def union(self, *dstreams): raise ValueError("All DStreams should have same serializer") if len(set(s._slideDuration for s in dstreams)) > 1: raise ValueError("All DStreams should have same slide duration") - first = dstreams[0] - jrest = [d._jdstream for d in dstreams[1:]] - return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer) + cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaDStream + jdstreams = SparkContext._gateway.new_array(cls, len(dstreams)) + for i in range(0, len(dstreams)): + jdstreams[i] = dstreams[i]._jdstream + return DStream(self._jssc.union(jdstreams), self, dstreams[0]._jrdd_deserializer) def addStreamingListener(self, streamingListener): """ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 982e72cffbf3f..e61c0d4ea5afa 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -21,6 +21,7 @@ import java.io.{Closeable, InputStream} import java.lang.{Boolean => JBoolean} import java.util.{List => JList, Map => JMap} +import scala.annotation.varargs import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -36,7 +37,6 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ -import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.scheduler.StreamingListener @@ -431,24 +431,23 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { /** * Create a unified DStream from multiple DStreams of the same type and same slide duration. */ - def union[T](first: JavaDStream[T], rest: JList[JavaDStream[T]]): JavaDStream[T] = { - val dstreams: Seq[DStream[T]] = (Seq(first) ++ rest.asScala).map(_.dstream) - implicit val cm: ClassTag[T] = first.classTag - ssc.union(dstreams)(cm) + @varargs + def union[T](jdstreams: JavaDStream[T]*): JavaDStream[T] = { + require(jdstreams.nonEmpty, "Union called on no streams") + implicit val cm: ClassTag[T] = jdstreams.head.classTag + ssc.union(jdstreams.map(_.dstream))(cm) } /** * Create a unified DStream from multiple DStreams of the same type and same slide duration. */ - def union[K, V]( - first: JavaPairDStream[K, V], - rest: JList[JavaPairDStream[K, V]] - ): JavaPairDStream[K, V] = { - val dstreams: Seq[DStream[(K, V)]] = (Seq(first) ++ rest.asScala).map(_.dstream) - implicit val cm: ClassTag[(K, V)] = first.classTag - implicit val kcm: ClassTag[K] = first.kManifest - implicit val vcm: ClassTag[V] = first.vManifest - new JavaPairDStream[K, V](ssc.union(dstreams)(cm))(kcm, vcm) + @varargs + def union[K, V](jdstreams: JavaPairDStream[K, V]*): JavaPairDStream[K, V] = { + require(jdstreams.nonEmpty, "Union called on no streams") + implicit val cm: ClassTag[(K, V)] = jdstreams.head.classTag + implicit val kcm: ClassTag[K] = jdstreams.head.kManifest + implicit val vcm: ClassTag[V] = jdstreams.head.vManifest + new JavaPairDStream[K, V](ssc.union(jdstreams.map(_.dstream))(cm))(kcm, vcm) } /** From b2e3256256e409d6f7b6e68e6ee26d532d778268 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 24 Oct 2018 16:56:17 -0500 Subject: [PATCH 1911/2461] [SPARK-25490][SQL][TEST] Fix OOM of KryoBenchmark due to large 2D array and refactor it to use main method ## What changes were proposed in this pull request? Before the code changes, I tried to run it with 8G memory: ``` build/sbt -mem 8000 "core/testOnly org.apache.spark.serializer.KryoBenchmark" ``` Still I got got OOM. This is because the lengths of the arrays are random https://github.com/apache/spark/blob/669ade3a8eed0016b5ece57d776cea0616417088/core/src/test/scala/org/apache/spark/serializer/KryoBenchmark.scala#L90-L91 And the 2D array is usually large: `10000 * Random.nextInt(0, 10000)` This PR is to fix it and refactor it to use main method. The benchmark result is also reason compared to the original one. ## How was this patch tested? Run with ``` bin/spark-submit --class org.apache.spark.serializer.KryoBenchmark core/target/scala-2.11/spark-core_2.11-3.0.0-SNAPSHOT-tests.jar ``` and ``` SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "core/test:runMain org.apache.spark.serializer.KryoBenchmark" Closes #22663 from gengliangwang/kyroBenchmark. Authored-by: Gengliang Wang Signed-off-by: Sean Owen --- core/benchmarks/KryoBenchmark-results.txt | 29 +++++++ .../spark/serializer/KryoBenchmark.scala | 79 +++++++++---------- 2 files changed, 65 insertions(+), 43 deletions(-) create mode 100644 core/benchmarks/KryoBenchmark-results.txt diff --git a/core/benchmarks/KryoBenchmark-results.txt b/core/benchmarks/KryoBenchmark-results.txt new file mode 100644 index 0000000000000..91e22f3afc14f --- /dev/null +++ b/core/benchmarks/KryoBenchmark-results.txt @@ -0,0 +1,29 @@ +================================================================================================ +Benchmark Kryo Unsafe vs safe Serialization +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_131-b11 on Mac OS X 10.13.6 +Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + +Benchmark Kryo Unsafe vs safe Serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +basicTypes: Int with unsafe:true 138 / 149 7.2 138.0 1.0X +basicTypes: Long with unsafe:true 168 / 173 6.0 167.7 0.8X +basicTypes: Float with unsafe:true 153 / 174 6.5 153.1 0.9X +basicTypes: Double with unsafe:true 161 / 185 6.2 161.1 0.9X +Array: Int with unsafe:true 2 / 3 409.7 2.4 56.5X +Array: Long with unsafe:true 4 / 5 232.5 4.3 32.1X +Array: Float with unsafe:true 3 / 4 367.3 2.7 50.7X +Array: Double with unsafe:true 4 / 5 228.5 4.4 31.5X +Map of string->Double with unsafe:true 38 / 45 26.5 37.8 3.7X +basicTypes: Int with unsafe:false 176 / 187 5.7 175.9 0.8X +basicTypes: Long with unsafe:false 191 / 203 5.2 191.2 0.7X +basicTypes: Float with unsafe:false 166 / 176 6.0 166.2 0.8X +basicTypes: Double with unsafe:false 174 / 190 5.7 174.3 0.8X +Array: Int with unsafe:false 19 / 26 52.9 18.9 7.3X +Array: Long with unsafe:false 27 / 31 37.7 26.5 5.2X +Array: Float with unsafe:false 8 / 10 124.3 8.0 17.2X +Array: Double with unsafe:false 12 / 13 83.6 12.0 11.5X +Map of string->Double with unsafe:false 38 / 42 26.1 38.3 3.6X + + diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoBenchmark.scala b/core/src/test/scala/org/apache/spark/serializer/KryoBenchmark.scala index f4fc0080f3108..8a52c131af847 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoBenchmark.scala @@ -20,58 +20,48 @@ package org.apache.spark.serializer import scala.reflect.ClassTag import scala.util.Random -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.benchmark.Benchmark +import org.apache.spark.SparkConf +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.serializer.KryoTest._ -class KryoBenchmark extends SparkFunSuite { - val benchmark = new Benchmark("Benchmark Kryo Unsafe vs safe Serialization", 1024 * 1024 * 15, 10) - - ignore(s"Benchmark Kryo Unsafe vs safe Serialization") { - Seq (true, false).foreach (runBenchmark) - benchmark.run() - - // scalastyle:off - /* - Benchmark Kryo Unsafe vs safe Serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - basicTypes: Int with unsafe:true 151 / 170 104.2 9.6 1.0X - basicTypes: Long with unsafe:true 175 / 191 89.8 11.1 0.9X - basicTypes: Float with unsafe:true 177 / 184 88.8 11.3 0.9X - basicTypes: Double with unsafe:true 193 / 216 81.4 12.3 0.8X - Array: Int with unsafe:true 513 / 587 30.7 32.6 0.3X - Array: Long with unsafe:true 1211 / 1358 13.0 77.0 0.1X - Array: Float with unsafe:true 890 / 964 17.7 56.6 0.2X - Array: Double with unsafe:true 1335 / 1428 11.8 84.9 0.1X - Map of string->Double with unsafe:true 931 / 988 16.9 59.2 0.2X - basicTypes: Int with unsafe:false 197 / 217 79.9 12.5 0.8X - basicTypes: Long with unsafe:false 219 / 240 71.8 13.9 0.7X - basicTypes: Float with unsafe:false 208 / 217 75.7 13.2 0.7X - basicTypes: Double with unsafe:false 208 / 225 75.6 13.2 0.7X - Array: Int with unsafe:false 2559 / 2681 6.1 162.7 0.1X - Array: Long with unsafe:false 3425 / 3516 4.6 217.8 0.0X - Array: Float with unsafe:false 2025 / 2134 7.8 128.7 0.1X - Array: Double with unsafe:false 2241 / 2358 7.0 142.5 0.1X - Map of string->Double with unsafe:false 1044 / 1085 15.1 66.4 0.1X - */ - // scalastyle:on +/** + * Benchmark for Kryo Unsafe vs safe Serialization. + * To run this benchmark: + * {{{ + * 1. without sbt: + * bin/spark-submit --class --jars + * 2. build/sbt "core/test:runMain " + * 3. generate result: + * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "core/test:runMain " + * Results will be written to "benchmarks/KryoBenchmark-results.txt". + * }}} + */ +object KryoBenchmark extends BenchmarkBase { + + val N = 1000000 + override def runBenchmarkSuite(): Unit = { + val name = "Benchmark Kryo Unsafe vs safe Serialization" + runBenchmark(name) { + val benchmark = new Benchmark(name, N, 10, output = output) + Seq(true, false).foreach(useUnsafe => run(useUnsafe, benchmark)) + benchmark.run() + } } - private def runBenchmark(useUnsafe: Boolean): Unit = { + private def run(useUnsafe: Boolean, benchmark: Benchmark): Unit = { def check[T: ClassTag](t: T, ser: SerializerInstance): Int = { - if (ser.deserialize[T](ser.serialize(t)) === t) 1 else 0 + if (ser.deserialize[T](ser.serialize(t)) == t) 1 else 0 } // Benchmark Primitives - val basicTypeCount = 1000000 def basicTypes[T: ClassTag](name: String, gen: () => T): Unit = { lazy val ser = createSerializer(useUnsafe) - val arrayOfBasicType: Array[T] = Array.fill(basicTypeCount)(gen()) + val arrayOfBasicType: Array[T] = Array.fill(N)(gen()) benchmark.addCase(s"basicTypes: $name with unsafe:$useUnsafe") { _ => var sum = 0L var i = 0 - while (i < basicTypeCount) { + while (i < N) { sum += check(arrayOfBasicType(i), ser) i += 1 } @@ -84,11 +74,12 @@ class KryoBenchmark extends SparkFunSuite { basicTypes("Double", () => Random.nextDouble()) // Benchmark Array of Primitives - val arrayCount = 10000 + val arrayCount = 4000 + val arrayLength = N / arrayCount def basicTypeArray[T: ClassTag](name: String, gen: () => T): Unit = { lazy val ser = createSerializer(useUnsafe) val arrayOfArrays: Array[Array[T]] = - Array.fill(arrayCount)(Array.fill[T](Random.nextInt(arrayCount))(gen())) + Array.fill(arrayCount)(Array.fill[T](arrayLength + Random.nextInt(arrayLength / 4))(gen())) benchmark.addCase(s"Array: $name with unsafe:$useUnsafe") { _ => var sum = 0L @@ -107,11 +98,13 @@ class KryoBenchmark extends SparkFunSuite { basicTypeArray("Double", () => Random.nextDouble()) // Benchmark Maps - val mapsCount = 1000 + val mapsCount = 200 + val mapKeyLength = 20 + val mapLength = N / mapsCount / mapKeyLength lazy val ser = createSerializer(useUnsafe) val arrayOfMaps: Array[Map[String, Double]] = Array.fill(mapsCount) { - Array.fill(Random.nextInt(mapsCount)) { - (Random.nextString(mapsCount / 10), Random.nextDouble()) + Array.fill(mapLength + Random.nextInt(mapLength / 4)) { + (Random.nextString(mapKeyLength), Random.nextDouble()) }.toMap } From 19ada15d1b15256de4e3bf2f4b17d87ea0d65cc3 Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Wed, 24 Oct 2018 23:29:47 -0700 Subject: [PATCH 1912/2461] [SPARK-24516][K8S] Change Python default to Python3 ## What changes were proposed in this pull request? As this is targeted for 3.0.0 and Python2 will be deprecated by Jan 1st, 2020, I feel it is appropriate to change the default to Python3. Especially as these projects [found here](https://python3statement.org/) are deprecating their support. ## How was this patch tested? Unit and Integration tests Author: Ilan Filonenko Closes #22810 from ifilonenko/SPARK-24516. --- docs/running-on-kubernetes.md | 2 +- .../src/main/scala/org/apache/spark/deploy/k8s/Config.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index d629ed3b503ab..60c9279f2bce2 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -816,7 +816,7 @@ specific to Spark on Kubernetes. spark.kubernetes.pyspark.pythonVersion - "2" + "3" This sets the major Python version of the docker image used to run the driver and executor containers. Can either be 2 or 3. diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index c2ad80c4755a6..fff8fa4340c35 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -223,7 +223,7 @@ private[spark] object Config extends Logging { .stringConf .checkValue(pv => List("2", "3").contains(pv), "Ensure that major Python version is either Python2 or Python3") - .createWithDefault("2") + .createWithDefault("3") val KUBERNETES_KERBEROS_KRB5_FILE = ConfigBuilder("spark.kubernetes.kerberos.krb5.path") From ddd1b1e8aec023e61b186c494ccbc182db2eb3ca Mon Sep 17 00:00:00 2001 From: adrian555 Date: Wed, 24 Oct 2018 23:42:06 -0700 Subject: [PATCH 1913/2461] [SPARK-24572][SPARKR] "eager execution" for R shell, IDE ## What changes were proposed in this pull request? Check the `spark.sql.repl.eagerEval.enabled` configuration property in SparkDataFrame `show()` method. If the `SparkSession` has eager execution enabled, the data will be returned to the R client when the data frame is created. So instead of seeing this ``` > df <- createDataFrame(faithful) > df SparkDataFrame[eruptions:double, waiting:double] ``` you will see ``` > df <- createDataFrame(faithful) > df +---------+-------+ |eruptions|waiting| +---------+-------+ | 3.6| 79.0| | 1.8| 54.0| | 3.333| 74.0| | 2.283| 62.0| | 4.533| 85.0| | 2.883| 55.0| | 4.7| 88.0| | 3.6| 85.0| | 1.95| 51.0| | 4.35| 85.0| | 1.833| 54.0| | 3.917| 84.0| | 4.2| 78.0| | 1.75| 47.0| | 4.7| 83.0| | 2.167| 52.0| | 1.75| 62.0| | 4.8| 84.0| | 1.6| 52.0| | 4.25| 79.0| +---------+-------+ only showing top 20 rows ``` ## How was this patch tested? Manual tests as well as unit tests (one new test case is added). Author: adrian555 Closes #22455 from adrian555/eager_execution. --- R/pkg/R/DataFrame.R | 36 ++++++++-- R/pkg/tests/fulltests/test_sparkSQL_eager.R | 72 +++++++++++++++++++ docs/sparkr.md | 42 +++++++++++ .../apache/spark/sql/internal/SQLConf.scala | 7 +- 4 files changed, 148 insertions(+), 9 deletions(-) create mode 100644 R/pkg/tests/fulltests/test_sparkSQL_eager.R diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 34691883bc5a9..bf82d0c7882d7 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -226,7 +226,9 @@ setMethod("showDF", #' show #' -#' Print class and type information of a Spark object. +#' If eager evaluation is enabled and the Spark object is a SparkDataFrame, evaluate the +#' SparkDataFrame and print top rows of the SparkDataFrame, otherwise, print the class +#' and type information of the Spark object. #' #' @param object a Spark object. Can be a SparkDataFrame, Column, GroupedData, WindowSpec. #' @@ -244,11 +246,33 @@ setMethod("showDF", #' @note show(SparkDataFrame) since 1.4.0 setMethod("show", "SparkDataFrame", function(object) { - cols <- lapply(dtypes(object), function(l) { - paste(l, collapse = ":") - }) - s <- paste(cols, collapse = ", ") - cat(paste(class(object), "[", s, "]\n", sep = "")) + allConf <- sparkR.conf() + prop <- allConf[["spark.sql.repl.eagerEval.enabled"]] + if (!is.null(prop) && identical(prop, "true")) { + argsList <- list() + argsList$x <- object + prop <- allConf[["spark.sql.repl.eagerEval.maxNumRows"]] + if (!is.null(prop)) { + numRows <- as.integer(prop) + if (numRows > 0) { + argsList$numRows <- numRows + } + } + prop <- allConf[["spark.sql.repl.eagerEval.truncate"]] + if (!is.null(prop)) { + truncate <- as.integer(prop) + if (truncate > 0) { + argsList$truncate <- truncate + } + } + do.call(showDF, argsList) + } else { + cols <- lapply(dtypes(object), function(l) { + paste(l, collapse = ":") + }) + s <- paste(cols, collapse = ", ") + cat(paste(class(object), "[", s, "]\n", sep = "")) + } }) #' DataTypes diff --git a/R/pkg/tests/fulltests/test_sparkSQL_eager.R b/R/pkg/tests/fulltests/test_sparkSQL_eager.R new file mode 100644 index 0000000000000..df7354fa063e9 --- /dev/null +++ b/R/pkg/tests/fulltests/test_sparkSQL_eager.R @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("test show SparkDataFrame when eager execution is enabled.") + +test_that("eager execution is not enabled", { + # Start Spark session without eager execution enabled + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + + df <- createDataFrame(faithful) + expect_is(df, "SparkDataFrame") + expected <- "eruptions:double, waiting:double" + expect_output(show(df), expected) + + # Stop Spark session + sparkR.session.stop() +}) + +test_that("eager execution is enabled", { + # Start Spark session with eager execution enabled + sparkConfig <- list(spark.sql.repl.eagerEval.enabled = "true") + + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, sparkConfig = sparkConfig) + + df <- createDataFrame(faithful) + expect_is(df, "SparkDataFrame") + expected <- paste0("(+---------+-------+\n", + "|eruptions|waiting|\n", + "+---------+-------+\n)*", + "(only showing top 20 rows)") + expect_output(show(df), expected) + + # Stop Spark session + sparkR.session.stop() +}) + +test_that("eager execution is enabled with maxNumRows and truncate set", { + # Start Spark session with eager execution enabled + sparkConfig <- list(spark.sql.repl.eagerEval.enabled = "true", + spark.sql.repl.eagerEval.maxNumRows = as.integer(5), + spark.sql.repl.eagerEval.truncate = as.integer(2)) + + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, sparkConfig = sparkConfig) + + df <- arrange(createDataFrame(faithful), "waiting") + expect_is(df, "SparkDataFrame") + expected <- paste0("(+---------+-------+\n", + "|eruptions|waiting|\n", + "+---------+-------+\n", + "| 1.| 43|\n)*", + "(only showing top 5 rows)") + expect_output(show(df), expected) + + # Stop Spark session + sparkR.session.stop() +}) diff --git a/docs/sparkr.md b/docs/sparkr.md index ba4cca811b185..79f8ab81342be 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -450,6 +450,48 @@ print(model.summaries) {% endhighlight %} +### Eager execution + +If eager execution is enabled, the data will be returned to R client immediately when the `SparkDataFrame` is created. By default, eager execution is not enabled and can be enabled by setting the configuration property `spark.sql.repl.eagerEval.enabled` to `true` when the `SparkSession` is started up. + +Maximum number of rows and maximum number of characters per column of data to display can be controlled by `spark.sql.repl.eagerEval.maxNumRows` and `spark.sql.repl.eagerEval.truncate` configuration properties, respectively. These properties are only effective when eager execution is enabled. If these properties are not set explicitly, by default, data up to 20 rows and up to 20 characters per column will be showed. + +
      +{% highlight r %} + +# Start up spark session with eager execution enabled +sparkR.session(master = "local[*]", + sparkConfig = list(spark.sql.repl.eagerEval.enabled = "true", + spark.sql.repl.eagerEval.maxNumRows = as.integer(10))) + +# Create a grouped and sorted SparkDataFrame +df <- createDataFrame(faithful) +df2 <- arrange(summarize(groupBy(df, df$waiting), count = n(df$waiting)), "waiting") + +# Similar to R data.frame, displays the data returned, instead of SparkDataFrame class string +df2 + +##+-------+-----+ +##|waiting|count| +##+-------+-----+ +##| 43.0| 1| +##| 45.0| 3| +##| 46.0| 5| +##| 47.0| 4| +##| 48.0| 3| +##| 49.0| 5| +##| 50.0| 5| +##| 51.0| 6| +##| 52.0| 5| +##| 53.0| 7| +##+-------+-----+ +##only showing top 10 rows + +{% endhighlight %} +
      + +Note that to enable eager execution in `sparkR` shell, add `spark.sql.repl.eagerEval.enabled=true` configuration property to the `--conf` option. + ## Running SQL Queries from SparkR A SparkDataFrame can also be registered as a temporary view in Spark SQL and that allows you to run SQL queries over its data. The `sql` function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index da70d7da7351b..e8529550b8fca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1482,9 +1482,10 @@ object SQLConf { val REPL_EAGER_EVAL_ENABLED = buildConf("spark.sql.repl.eagerEval.enabled") .doc("Enables eager evaluation or not. When true, the top K rows of Dataset will be " + "displayed if and only if the REPL supports the eager evaluation. Currently, the " + - "eager evaluation is only supported in PySpark. For the notebooks like Jupyter, " + - "the HTML table (generated by _repr_html_) will be returned. For plain Python REPL, " + - "the returned outputs are formatted like dataframe.show().") + "eager evaluation is supported in PySpark and SparkR. In PySpark, for the notebooks like " + + "Jupyter, the HTML table (generated by _repr_html_) will be returned. For plain Python " + + "REPL, the returned outputs are formatted like dataframe.show(). In SparkR, the returned " + + "outputs are showed similar to R data.frame would.") .booleanConf .createWithDefault(false) From cb5ea201df5fae8aacb653ffb4147b9288bca1e9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 25 Oct 2018 19:27:45 +0800 Subject: [PATCH 1914/2461] [SPARK-25746][SQL] Refactoring ExpressionEncoder to get rid of flat flag ## What changes were proposed in this pull request? This is inspired during implementing #21732. For now `ScalaReflection` needs to consider how `ExpressionEncoder` uses generated serializers and deserializers. And `ExpressionEncoder` has a weird `flat` flag. After discussion with cloud-fan, it seems to be better to refactor `ExpressionEncoder`. It should make SPARK-24762 easier to do. To summarize the proposed changes: 1. `serializerFor` and `deserializerFor` return expressions for serializing/deserializing an input expression for a given type. They are private and should not be called directly. 2. `serializerForType` and `deserializerForType` returns an expression for serializing/deserializing for an object of type T to/from Spark SQL representation. It assumes the input object/Spark SQL representation is located at ordinal 0 of a row. So in other words, `serializerForType` and `deserializerForType` return expressions for atomically serializing/deserializing JVM object to/from Spark SQL value. A serializer returned by `serializerForType` will serialize an object at `row(0)` to a corresponding Spark SQL representation, e.g. primitive type, array, map, struct. A deserializer returned by `deserializerForType` will deserialize an input field at `row(0)` to an object with given type. 3. The construction of `ExpressionEncoder` takes a pair of serializer and deserializer for type `T`. It uses them to create serializer and deserializer for T <-> row serialization. Now `ExpressionEncoder` dones't need to remember if serializer is flat or not. When we need to construct new `ExpressionEncoder` based on existing ones, we only need to change input location in the atomic serializer and deserializer. ## How was this patch tested? Existing tests. Closes #22749 from viirya/SPARK-24762-refactor. Authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- .../scala/org/apache/spark/sql/Encoders.scala | 8 +- .../sql/catalyst/JavaTypeInference.scala | 78 +++---- .../spark/sql/catalyst/ScalaReflection.scala | 182 ++++++++-------- .../catalyst/encoders/ExpressionEncoder.scala | 201 +++++++++++------- .../sql/catalyst/encoders/RowEncoder.scala | 16 +- .../sql/catalyst/ScalaReflectionSuite.scala | 70 +++--- .../encoders/ExpressionEncoderSuite.scala | 6 +- .../catalyst/encoders/RowEncoderSuite.scala | 2 +- .../scala/org/apache/spark/sql/Dataset.scala | 10 +- .../spark/sql/KeyValueGroupedDataset.scala | 2 +- .../aggregate/TypedAggregateExpression.scala | 12 +- .../org/apache/spark/sql/DatasetSuite.scala | 2 +- 12 files changed, 304 insertions(+), 285 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index b47ec0b72c638..8a30c81912fe9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -203,12 +203,10 @@ object Encoders { validatePublicClass[T]() ExpressionEncoder[T]( - schema = new StructType().add("value", BinaryType), - flat = true, - serializer = Seq( + objSerializer = EncodeUsingSerializer( - BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), - deserializer = + BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo), + objDeserializer = DecodeUsingSerializer[T]( Cast(GetColumnByOrdinal(0, BinaryType), BinaryType), classTag[T], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 60dd4a57139e3..f32e080447317 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -187,26 +187,23 @@ object JavaTypeInference { } /** - * Returns an expression that can be used to deserialize an internal row to an object of java bean - * `T` with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes - * of the same name as the constructor arguments. Nested classes will have their fields accessed - * using UnresolvedExtractValue. + * Returns an expression that can be used to deserialize a Spark SQL representation to an object + * of java bean `T` with a compatible schema. The Spark SQL representation is located at ordinal + * 0 of a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed + * using `UnresolvedExtractValue`. */ def deserializerFor(beanClass: Class[_]): Expression = { - deserializerFor(TypeToken.of(beanClass), None) + val typeToken = TypeToken.of(beanClass) + deserializerFor(typeToken, GetColumnByOrdinal(0, inferDataType(typeToken)._1)) } - private def deserializerFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = { + private def deserializerFor(typeToken: TypeToken[_], path: Expression): Expression = { /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String): Expression = path - .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute(part)) - - /** Returns the current path or `GetColumnByOrdinal`. */ - def getPath: Expression = path.getOrElse(GetColumnByOrdinal(0, inferDataType(typeToken)._1)) + def addToPath(part: String): Expression = UnresolvedExtractValue(path, + expressions.Literal(part)) typeToken.getRawType match { - case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath + case c if !inferExternalType(c).isInstanceOf[ObjectType] => path case c if c == classOf[java.lang.Short] || c == classOf[java.lang.Integer] || @@ -219,7 +216,7 @@ object JavaTypeInference { c, ObjectType(c), "valueOf", - getPath :: Nil, + path :: Nil, returnNullable = false) case c if c == classOf[java.sql.Date] => @@ -227,7 +224,7 @@ object JavaTypeInference { DateTimeUtils.getClass, ObjectType(c), "toJavaDate", - getPath :: Nil, + path :: Nil, returnNullable = false) case c if c == classOf[java.sql.Timestamp] => @@ -235,14 +232,14 @@ object JavaTypeInference { DateTimeUtils.getClass, ObjectType(c), "toJavaTimestamp", - getPath :: Nil, + path :: Nil, returnNullable = false) case c if c == classOf[java.lang.String] => - Invoke(getPath, "toString", ObjectType(classOf[String])) + Invoke(path, "toString", ObjectType(classOf[String])) case c if c == classOf[java.math.BigDecimal] => - Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) case c if c.isArray => val elementType = c.getComponentType @@ -258,12 +255,12 @@ object JavaTypeInference { } primitiveMethod.map { method => - Invoke(getPath, method, ObjectType(c)) + Invoke(path, method, ObjectType(c)) }.getOrElse { Invoke( MapObjects( - p => deserializerFor(typeToken.getComponentType, Some(p)), - getPath, + p => deserializerFor(typeToken.getComponentType, p), + path, inferDataType(elementType)._1), "array", ObjectType(c)) @@ -272,8 +269,8 @@ object JavaTypeInference { case c if listType.isAssignableFrom(typeToken) => val et = elementType(typeToken) UnresolvedMapObjects( - p => deserializerFor(et, Some(p)), - getPath, + p => deserializerFor(et, p), + path, customCollectionCls = Some(c)) case _ if mapType.isAssignableFrom(typeToken) => @@ -282,16 +279,16 @@ object JavaTypeInference { val keyData = Invoke( UnresolvedMapObjects( - p => deserializerFor(keyType, Some(p)), - GetKeyArrayFromMap(getPath)), + p => deserializerFor(keyType, p), + GetKeyArrayFromMap(path)), "array", ObjectType(classOf[Array[Any]])) val valueData = Invoke( UnresolvedMapObjects( - p => deserializerFor(valueType, Some(p)), - GetValueArrayFromMap(getPath)), + p => deserializerFor(valueType, p), + GetValueArrayFromMap(path)), "array", ObjectType(classOf[Array[Any]])) @@ -307,7 +304,7 @@ object JavaTypeInference { other, ObjectType(other), "valueOf", - Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false) :: Nil, + Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false) :: Nil, returnNullable = false) case other => @@ -316,7 +313,7 @@ object JavaTypeInference { val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType val (_, nullable) = inferDataType(fieldType) - val constructor = deserializerFor(fieldType, Some(addToPath(fieldName))) + val constructor = deserializerFor(fieldType, addToPath(fieldName)) val setter = if (nullable) { constructor } else { @@ -328,28 +325,23 @@ object JavaTypeInference { val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false) val result = InitializeJavaBean(newInstance, setters) - if (path.nonEmpty) { - expressions.If( - IsNull(getPath), - expressions.Literal.create(null, ObjectType(other)), - result - ) - } else { + expressions.If( + IsNull(path), + expressions.Literal.create(null, ObjectType(other)), result - } + ) } } /** - * Returns an expression for serializing an object of the given type to an internal row. + * Returns an expression for serializing an object of the given type to a Spark SQL + * representation. The input object is located at ordinal 0 of a row, i.e., + * `BoundReference(0, _)`. */ - def serializerFor(beanClass: Class[_]): CreateNamedStruct = { + def serializerFor(beanClass: Class[_]): Expression = { val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean")) - serializerFor(nullSafeInput, TypeToken.of(beanClass)) match { - case expressions.If(_, _, s: CreateNamedStruct) => s - case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) - } + serializerFor(nullSafeInput, TypeToken.of(beanClass)) } private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index c27180e2a6b9b..40074b36f6a9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -24,7 +24,7 @@ import scala.util.Properties import org.apache.commons.lang3.reflect.ConstructorUtils import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData} @@ -129,21 +129,44 @@ object ScalaReflection extends ScalaReflection { } /** - * Returns an expression that can be used to deserialize an input row to an object of type `T` - * with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes - * of the same name as the constructor arguments. Nested classes will have their fields accessed - * using UnresolvedExtractValue. + * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff + * and lost the required data type, which may lead to runtime error if the real type doesn't + * match the encoder's schema. + * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type + * is [a: int, b: long], then we will hit runtime error and say that we can't construct class + * `Data` with int and long, because we lost the information that `b` should be a string. * - * When used on a primitive type, the constructor will instead default to extracting the value - * from ordinal 0 (since there are no names to map to). The actual location can be moved by - * calling resolve/bind with a new schema. + * This method help us "remember" the required data type by adding a `UpCast`. Note that we + * only need to do this for leaf nodes. */ - def deserializerFor[T : TypeTag]: Expression = { - val tpe = localTypeOf[T] + private def upCastToExpectedType(expr: Expression, expected: DataType, + walkedTypePath: Seq[String]): Expression = expected match { + case _: StructType => expr + case _: ArrayType => expr + // TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and + // it's not trivial to support by-name resolution for StructType inside MapType. + case _ => UpCast(expr, expected, walkedTypePath) + } + + /** + * Returns an expression that can be used to deserialize a Spark SQL representation to an object + * of type `T` with a compatible schema. The Spark SQL representation is located at ordinal 0 of + * a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed using + * `UnresolvedExtractValue`. + * + * The returned expression is used by `ExpressionEncoder`. The encoder will resolve and bind this + * deserializer expression when using it. + */ + def deserializerForType(tpe: `Type`): Expression = { val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "$clsName"""" :: Nil - val expr = deserializerFor(tpe, None, walkedTypePath) - val Schema(_, nullable) = schemaFor(tpe) + val Schema(dataType, nullable) = schemaFor(tpe) + + // Assumes we are deserializing the first column of a row. + val input = upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType, + walkedTypePath) + + val expr = deserializerFor(tpe, input, walkedTypePath) if (nullable) { expr } else { @@ -151,16 +174,22 @@ object ScalaReflection extends ScalaReflection { } } + /** + * Returns an expression that can be used to deserialize an input expression to an object of type + * `T` with a compatible schema. + * + * @param tpe The `Type` of deserialized object. + * @param path The expression which can be used to extract serialized value. + * @param walkedTypePath The paths from top to bottom to access current field when deserializing. + */ private def deserializerFor( tpe: `Type`, - path: Option[Expression], + path: Expression, walkedTypePath: Seq[String]): Expression = cleanUpReflectionObjects { /** Returns the current path with a sub-field extracted. */ def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = { - val newPath = path - .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute.quoted(part)) + val newPath = UnresolvedExtractValue(path, expressions.Literal(part)) upCastToExpectedType(newPath, dataType, walkedTypePath) } @@ -169,46 +198,12 @@ object ScalaReflection extends ScalaReflection { ordinal: Int, dataType: DataType, walkedTypePath: Seq[String]): Expression = { - val newPath = path - .map(p => GetStructField(p, ordinal)) - .getOrElse(GetColumnByOrdinal(ordinal, dataType)) + val newPath = GetStructField(path, ordinal) upCastToExpectedType(newPath, dataType, walkedTypePath) } - /** Returns the current path or `GetColumnByOrdinal`. */ - def getPath: Expression = { - val dataType = schemaFor(tpe).dataType - if (path.isDefined) { - path.get - } else { - upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType, walkedTypePath) - } - } - - /** - * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff - * and lost the required data type, which may lead to runtime error if the real type doesn't - * match the encoder's schema. - * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type - * is [a: int, b: long], then we will hit runtime error and say that we can't construct class - * `Data` with int and long, because we lost the information that `b` should be a string. - * - * This method help us "remember" the required data type by adding a `UpCast`. Note that we - * only need to do this for leaf nodes. - */ - def upCastToExpectedType( - expr: Expression, - expected: DataType, - walkedTypePath: Seq[String]): Expression = expected match { - case _: StructType => expr - case _: ArrayType => expr - // TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and - // it's not trivial to support by-name resolution for StructType inside MapType. - case _ => UpCast(expr, expected, walkedTypePath) - } - tpe.dealias match { - case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath + case t if !dataTypeFor(t).isInstanceOf[ObjectType] => path case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t @@ -219,44 +214,44 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Long] => val boxedType = classOf[java.lang.Long] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Double] => val boxedType = classOf[java.lang.Double] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Float] => val boxedType = classOf[java.lang.Float] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Short] => val boxedType = classOf[java.lang.Short] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Byte] => val boxedType = classOf[java.lang.Byte] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Boolean] => val boxedType = classOf[java.lang.Boolean] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", - getPath :: Nil, + path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.sql.Timestamp] => @@ -264,25 +259,25 @@ object ScalaReflection extends ScalaReflection { DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", - getPath :: Nil, + path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.String] => - Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false) + Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false) case t if t <:< localTypeOf[java.math.BigDecimal] => - Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), + Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), returnNullable = false) case t if t <:< localTypeOf[BigDecimal] => - Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false) + Invoke(path, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false) case t if t <:< localTypeOf[java.math.BigInteger] => - Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]), + Invoke(path, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]), returnNullable = false) case t if t <:< localTypeOf[scala.math.BigInt] => - Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]), + Invoke(path, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]), returnNullable = false) case t if t <:< localTypeOf[Array[_]] => @@ -294,7 +289,7 @@ object ScalaReflection extends ScalaReflection { val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. val casted = upCastToExpectedType(element, dataType, newTypePath) - val converter = deserializerFor(elementType, Some(casted), newTypePath) + val converter = deserializerFor(elementType, casted, newTypePath) if (elementNullable) { converter } else { @@ -302,7 +297,7 @@ object ScalaReflection extends ScalaReflection { } } - val arrayData = UnresolvedMapObjects(mapFunction, getPath) + val arrayData = UnresolvedMapObjects(mapFunction, path) val arrayCls = arrayClassFor(elementType) if (elementNullable) { @@ -334,7 +329,7 @@ object ScalaReflection extends ScalaReflection { val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. val casted = upCastToExpectedType(element, dataType, newTypePath) - val converter = deserializerFor(elementType, Some(casted), newTypePath) + val converter = deserializerFor(elementType, casted, newTypePath) if (elementNullable) { converter } else { @@ -349,16 +344,16 @@ object ScalaReflection extends ScalaReflection { classOf[scala.collection.Set[_]] case _ => mirror.runtimeClass(t.typeSymbol.asClass) } - UnresolvedMapObjects(mapFunction, getPath, Some(cls)) + UnresolvedMapObjects(mapFunction, path, Some(cls)) case t if t <:< localTypeOf[Map[_, _]] => // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t CatalystToExternalMap( - p => deserializerFor(keyType, Some(p), walkedTypePath), - p => deserializerFor(valueType, Some(p), walkedTypePath), - getPath, + p => deserializerFor(keyType, p, walkedTypePath), + p => deserializerFor(valueType, p, walkedTypePath), + path, mirror.runtimeClass(t.typeSymbol.asClass) ) @@ -368,7 +363,7 @@ object ScalaReflection extends ScalaReflection { udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) - Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) + Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil) case t if UDTRegistration.exists(getClassNameFromType(t)) => val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() @@ -377,7 +372,7 @@ object ScalaReflection extends ScalaReflection { udt.getClass, Nil, dataType = ObjectType(udt.getClass)) - Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) + Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil) case t if definedByConstructorParams(t) => val params = getConstructorParameters(t) @@ -392,12 +387,12 @@ object ScalaReflection extends ScalaReflection { val constructor = if (cls.getName startsWith "scala.Tuple") { deserializerFor( fieldType, - Some(addToPathOrdinal(i, dataType, newTypePath)), + addToPathOrdinal(i, dataType, newTypePath), newTypePath) } else { deserializerFor( fieldType, - Some(addToPath(fieldName, dataType, newTypePath)), + addToPath(fieldName, dataType, newTypePath), newTypePath) } @@ -410,20 +405,17 @@ object ScalaReflection extends ScalaReflection { val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false) - if (path.nonEmpty) { - expressions.If( - IsNull(getPath), - expressions.Literal.create(null, ObjectType(cls)), - newInstance - ) - } else { + expressions.If( + IsNull(path), + expressions.Literal.create(null, ObjectType(cls)), newInstance - } + ) } } /** - * Returns an expression for serializing an object of type T to an internal row. + * Returns an expression for serializing an object of type T to Spark SQL representation. The + * input object is located at ordinal 0 of a row, i.e., `BoundReference(0, _)`. * * If the given type is not supported, i.e. there is no encoder can be built for this type, * an [[UnsupportedOperationException]] will be thrown with detailed error message to explain @@ -434,17 +426,21 @@ object ScalaReflection extends ScalaReflection { * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"` * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` */ - def serializerFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { - val tpe = localTypeOf[T] + def serializerForType(tpe: `Type`): Expression = ScalaReflection.cleanUpReflectionObjects { val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "$clsName"""" :: Nil - serializerFor(inputObject, tpe, walkedTypePath) match { - case expressions.If(_, _, s: CreateNamedStruct) if definedByConstructorParams(tpe) => s - case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) - } + + // The input object to `ExpressionEncoder` is located at first column of an row. + val inputObject = BoundReference(0, dataTypeFor(tpe), + nullable = !tpe.typeSymbol.asClass.isPrimitive) + + serializerFor(inputObject, tpe, walkedTypePath) } - /** Helper for extracting internal fields from a case class. */ + /** + * Returns an expression for serializing the value of an input expression into Spark SQL + * internal representation. + */ private def serializerFor( inputObject: Expression, tpe: `Type`, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index cbea3c017a265..29f6136a75ee8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -25,10 +25,11 @@ import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaRefle import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} -import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance} +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, Invoke, NewInstance} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation} -import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType} +import org.apache.spark.sql.types.{ObjectType, StringType, StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** @@ -43,8 +44,8 @@ import org.apache.spark.util.Utils * to the name `value`. */ object ExpressionEncoder { + def apply[T : TypeTag](): ExpressionEncoder[T] = { - // We convert the not-serializable TypeTag into StructType and ClassTag. val mirror = ScalaReflection.mirror val tpe = typeTag[T].in(mirror).tpe @@ -58,25 +59,11 @@ object ExpressionEncoder { } val cls = mirror.runtimeClass(tpe) - val flat = !ScalaReflection.definedByConstructorParams(tpe) - - val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = !cls.isPrimitive) - val nullSafeInput = if (flat) { - inputObject - } else { - // For input object of Product type, we can't encode it to row if it's null, as Spark SQL - // doesn't allow top-level row to be null, only its columns can be null. - AssertNotNull(inputObject, Seq("top level Product input object")) - } - val serializer = ScalaReflection.serializerFor[T](nullSafeInput) - val deserializer = ScalaReflection.deserializerFor[T] - - val schema = serializer.dataType + val serializer = ScalaReflection.serializerForType(tpe) + val deserializer = ScalaReflection.deserializerForType(tpe) new ExpressionEncoder[T]( - schema, - flat, - serializer.flatten, + serializer, deserializer, ClassTag[T](cls)) } @@ -86,14 +73,12 @@ object ExpressionEncoder { val schema = JavaTypeInference.inferDataType(beanClass)._1 assert(schema.isInstanceOf[StructType]) - val serializer = JavaTypeInference.serializerFor(beanClass) - val deserializer = JavaTypeInference.deserializerFor(beanClass) + val objSerializer = JavaTypeInference.serializerFor(beanClass) + val objDeserializer = JavaTypeInference.deserializerFor(beanClass) new ExpressionEncoder[T]( - schema.asInstanceOf[StructType], - flat = false, - serializer.flatten, - deserializer, + objSerializer, + objDeserializer, ClassTag[T](beanClass)) } @@ -103,75 +88,59 @@ object ExpressionEncoder { * name/positional binding is preserved. */ def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { + // TODO: check if encoders length is more than 22 and throw exception for it. + encoders.foreach(_.assertUnresolved()) val schema = StructType(encoders.zipWithIndex.map { case (e, i) => - val (dataType, nullable) = if (e.flat) { - e.schema.head.dataType -> e.schema.head.nullable - } else { - e.schema -> true - } - StructField(s"_${i + 1}", dataType, nullable) + StructField(s"_${i + 1}", e.objSerializer.dataType, e.objSerializer.nullable) }) val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") - val serializer = encoders.zipWithIndex.map { case (enc, index) => - val originalInputObject = enc.serializer.head.collect { case b: BoundReference => b }.head + val serializers = encoders.zipWithIndex.map { case (enc, index) => + val boundRefs = enc.objSerializer.collect { case b: BoundReference => b }.distinct + assert(boundRefs.size == 1, "object serializer should have only one bound reference but " + + s"there are ${boundRefs.size}") + + val originalInputObject = boundRefs.head val newInputObject = Invoke( BoundReference(0, ObjectType(cls), nullable = true), s"_${index + 1}", - originalInputObject.dataType) - - val newSerializer = enc.serializer.map(_.transformUp { - case b: BoundReference if b == originalInputObject => newInputObject - }) + originalInputObject.dataType, + returnNullable = originalInputObject.nullable) - val serializerExpr = if (enc.flat) { - newSerializer.head - } else { - // For non-flat encoder, the input object is not top level anymore after being combined to - // a tuple encoder, thus it can be null and we should wrap the `CreateStruct` with `If` and - // null check to handle null case correctly. - // e.g. for Encoder[(Int, String)], the serializer expressions will create 2 columns, and is - // not able to handle the case when the input tuple is null. This is not a problem as there - // is a check to make sure the input object won't be null. However, if this encoder is used - // to create a bigger tuple encoder, the original input object becomes a filed of the new - // input tuple and can be null. So instead of creating a struct directly here, we should add - // a null/None check and return a null struct if the null/None check fails. - val struct = CreateStruct(newSerializer) - val nullCheck = Or( - IsNull(newInputObject), - Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil)) - If(nullCheck, Literal.create(null, struct.dataType), struct) + val newSerializer = enc.objSerializer.transformUp { + case b: BoundReference => newInputObject } - Alias(serializerExpr, s"_${index + 1}")() + + Alias(newSerializer, s"_${index + 1}")() } val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) => - if (enc.flat) { - enc.deserializer.transform { - case g: GetColumnByOrdinal => g.copy(ordinal = index) - } + val getColumnsByOrdinals = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c } + .distinct + assert(getColumnsByOrdinals.size == 1, "object deserializer should have only one " + + s"`GetColumnByOrdinal`, but there are ${getColumnsByOrdinals.size}") + + val input = GetStructField(GetColumnByOrdinal(0, schema), index) + val newDeserializer = enc.objDeserializer.transformUp { + case GetColumnByOrdinal(0, _) => input + } + if (schema(index).nullable) { + If(IsNull(input), Literal.create(null, newDeserializer.dataType), newDeserializer) } else { - val input = GetColumnByOrdinal(index, enc.schema) - val deserialized = enc.deserializer.transformUp { - case UnresolvedAttribute(nameParts) => - assert(nameParts.length == 1) - UnresolvedExtractValue(input, Literal(nameParts.head)) - case GetColumnByOrdinal(ordinal, _) => GetStructField(input, ordinal) - } - If(IsNull(input), Literal.create(null, deserialized.dataType), deserialized) + newDeserializer } } + val serializer = If(IsNull(BoundReference(0, ObjectType(cls), nullable = true)), + Literal.create(null, schema), CreateStruct(serializers)) val deserializer = NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false) new ExpressionEncoder[Any]( - schema, - flat = false, serializer, deserializer, ClassTag(cls)) @@ -212,21 +181,91 @@ object ExpressionEncoder { * A generic encoder for JVM objects that uses Catalyst Expressions for a `serializer` * and a `deserializer`. * - * @param schema The schema after converting `T` to a Spark SQL row. - * @param serializer A set of expressions, one for each top-level field that can be used to - * extract the values from a raw object into an [[InternalRow]]. - * @param deserializer An expression that will construct an object given an [[InternalRow]]. + * @param objSerializer An expression that can be used to encode a raw object to corresponding + * Spark SQL representation that can be a primitive column, array, map or a + * struct. This represents how Spark SQL generally serializes an object of + * type `T`. + * @param objDeserializer An expression that will construct an object given a Spark SQL + * representation. This represents how Spark SQL generally deserializes + * a serialized value in Spark SQL representation back to an object of + * type `T`. * @param clsTag A classtag for `T`. */ case class ExpressionEncoder[T]( - schema: StructType, - flat: Boolean, - serializer: Seq[Expression], - deserializer: Expression, + objSerializer: Expression, + objDeserializer: Expression, clsTag: ClassTag[T]) extends Encoder[T] { - if (flat) require(serializer.size == 1) + /** + * A sequence of expressions, one for each top-level field that can be used to + * extract the values from a raw object into an [[InternalRow]]: + * 1. If `serializer` encodes a raw object to a struct, strip the outer If-IsNull and get + * the `CreateNamedStruct`. + * 2. For other cases, wrap the single serializer with `CreateNamedStruct`. + */ + val serializer: Seq[NamedExpression] = { + val clsName = Utils.getSimpleName(clsTag.runtimeClass) + + if (isSerializedAsStruct) { + val nullSafeSerializer = objSerializer.transformUp { + case r: BoundReference => + // For input object of Product type, we can't encode it to row if it's null, as Spark SQL + // doesn't allow top-level row to be null, only its columns can be null. + AssertNotNull(r, Seq("top level Product or row object")) + } + nullSafeSerializer match { + case If(_: IsNull, _, s: CreateNamedStruct) => s + case s: CreateNamedStruct => s + case _ => + throw new RuntimeException(s"class $clsName has unexpected serializer: $objSerializer") + } + } else { + // For other input objects like primitive, array, map, etc., we construct a struct to wrap + // the serializer which is a column of an row. + CreateNamedStruct(Literal("value") :: objSerializer :: Nil) + } + }.flatten + + /** + * Returns an expression that can be used to deserialize an input row to an object of type `T` + * with a compatible schema. Fields of the row will be extracted using `UnresolvedAttribute`. + * of the same name as the constructor arguments. + * + * For complex objects that are encoded to structs, Fields of the struct will be extracted using + * `GetColumnByOrdinal` with corresponding ordinal. + */ + val deserializer: Expression = { + if (isSerializedAsStruct) { + // We serialized this kind of objects to root-level row. The input of general deserializer + // is a `GetColumnByOrdinal(0)` expression to extract first column of a row. We need to + // transform attributes accessors. + objDeserializer.transform { + case UnresolvedExtractValue(GetColumnByOrdinal(0, _), + Literal(part: UTF8String, StringType)) => + UnresolvedAttribute.quoted(part.toString) + case GetStructField(GetColumnByOrdinal(0, dt), ordinal, _) => + GetColumnByOrdinal(ordinal, dt) + case If(IsNull(GetColumnByOrdinal(0, _)), _, n: NewInstance) => n + case If(IsNull(GetColumnByOrdinal(0, _)), _, i: InitializeJavaBean) => i + } + } else { + // For other input objects like primitive, array, map, etc., we deserialize the first column + // of a row to the object. + objDeserializer + } + } + + // The schema after converting `T` to a Spark SQL row. This schema is dependent on the given + // serialier. + val schema: StructType = StructType(serializer.map { s => + StructField(s.name, s.dataType, s.nullable) + }) + + /** + * Returns true if the type `T` is serialized as a struct. + */ + def isSerializedAsStruct: Boolean = objSerializer.dataType.isInstanceOf[StructType] // serializer expressions are used to encode an object to a row, while the object is usually an // intermediate value produced inside an operator, not from the output of the child operator. This @@ -258,7 +297,7 @@ case class ExpressionEncoder[T]( analyzer.checkAnalysis(analyzedPlan) val resolved = SimplifyCasts(analyzedPlan).asInstanceOf[DeserializeToObject].deserializer val bound = BindReferences.bindReference(resolved, attrs) - copy(deserializer = bound) + copy(objDeserializer = bound) } @transient diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index ae89f98b19025..d905f8f9858e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -58,12 +58,10 @@ object RowEncoder { def apply(schema: StructType): ExpressionEncoder[Row] = { val cls = classOf[Row] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val serializer = serializerFor(AssertNotNull(inputObject, Seq("top level row object")), schema) - val deserializer = deserializerFor(schema) + val serializer = serializerFor(inputObject, schema) + val deserializer = deserializerFor(GetColumnByOrdinal(0, serializer.dataType), schema) new ExpressionEncoder[Row]( - schema, - flat = false, - serializer.asInstanceOf[CreateNamedStruct].flatten, + serializer, deserializer, ClassTag(cls)) } @@ -237,13 +235,9 @@ object RowEncoder { case udt: UserDefinedType[_] => ObjectType(udt.userClass) } - private def deserializerFor(schema: StructType): Expression = { + private def deserializerFor(input: Expression, schema: StructType): Expression = { val fields = schema.zipWithIndex.map { case (f, i) => - val dt = f.dataType match { - case p: PythonUserDefinedType => p.sqlType - case other => other - } - deserializerFor(GetColumnByOrdinal(i, dt)) + deserializerFor(GetStructField(input, i)) } CreateExternalRow(fields, schema) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index f9ee948b97e0a..d98589db323cc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} +import scala.reflect.runtime.universe.TypeTag + import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, SpecificInternalRow, UpCast} +import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue +import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, If, SpecificInternalRow, UpCast} import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String case class PrimitiveData( intField: Int, @@ -112,6 +113,14 @@ object TestingUDT { class ScalaReflectionSuite extends SparkFunSuite { import org.apache.spark.sql.catalyst.ScalaReflection._ + // A helper method used to test `ScalaReflection.serializerForType`. + private def serializerFor[T: TypeTag]: Expression = + serializerForType(ScalaReflection.localTypeOf[T]) + + // A helper method used to test `ScalaReflection.deserializerForType`. + private def deserializerFor[T: TypeTag]: Expression = + deserializerForType(ScalaReflection.localTypeOf[T]) + test("SQLUserDefinedType annotation on Scala structure") { val schema = schemaFor[TestingUDT.NestedStruct] assert(schema === Schema( @@ -263,13 +272,9 @@ class ScalaReflectionSuite extends SparkFunSuite { test("SPARK-15062: Get correct serializer for List[_]") { val list = List(1, 2, 3) - val serializer = serializerFor[List[Int]](BoundReference( - 0, ObjectType(list.getClass), nullable = false)) - assert(serializer.children.size == 2) - assert(serializer.children.head.isInstanceOf[Literal]) - assert(serializer.children.head.asInstanceOf[Literal].value === UTF8String.fromString("value")) - assert(serializer.children.last.isInstanceOf[NewInstance]) - assert(serializer.children.last.asInstanceOf[NewInstance] + val serializer = serializerFor[List[Int]] + assert(serializer.isInstanceOf[NewInstance]) + assert(serializer.asInstanceOf[NewInstance] .cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData])) } @@ -280,59 +285,58 @@ class ScalaReflectionSuite extends SparkFunSuite { test("serialize and deserialize arbitrary sequence types") { import scala.collection.immutable.Queue - val queueSerializer = serializerFor[Queue[Int]](BoundReference( - 0, ObjectType(classOf[Queue[Int]]), nullable = false)) - assert(queueSerializer.dataType.head.dataType == + val queueSerializer = serializerFor[Queue[Int]] + assert(queueSerializer.dataType == ArrayType(IntegerType, containsNull = false)) val queueDeserializer = deserializerFor[Queue[Int]] assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]])) import scala.collection.mutable.ArrayBuffer - val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]](BoundReference( - 0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false)) - assert(arrayBufferSerializer.dataType.head.dataType == + val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]] + assert(arrayBufferSerializer.dataType == ArrayType(IntegerType, containsNull = false)) val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]] assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]])) } test("serialize and deserialize arbitrary map types") { - val mapSerializer = serializerFor[Map[Int, Int]](BoundReference( - 0, ObjectType(classOf[Map[Int, Int]]), nullable = false)) - assert(mapSerializer.dataType.head.dataType == + val mapSerializer = serializerFor[Map[Int, Int]] + assert(mapSerializer.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) val mapDeserializer = deserializerFor[Map[Int, Int]] assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]])) import scala.collection.immutable.HashMap - val hashMapSerializer = serializerFor[HashMap[Int, Int]](BoundReference( - 0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false)) - assert(hashMapSerializer.dataType.head.dataType == + val hashMapSerializer = serializerFor[HashMap[Int, Int]] + assert(hashMapSerializer.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) val hashMapDeserializer = deserializerFor[HashMap[Int, Int]] assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]])) import scala.collection.mutable.{LinkedHashMap => LHMap} - val linkedHashMapSerializer = serializerFor[LHMap[Long, String]](BoundReference( - 0, ObjectType(classOf[LHMap[Long, String]]), nullable = false)) - assert(linkedHashMapSerializer.dataType.head.dataType == + val linkedHashMapSerializer = serializerFor[LHMap[Long, String]] + assert(linkedHashMapSerializer.dataType == MapType(LongType, StringType, valueContainsNull = true)) val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]] assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]])) } test("SPARK-22442: Generate correct field names for special characters") { - val serializer = serializerFor[SpecialCharAsFieldData](BoundReference( - 0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false)) + val serializer = serializerFor[SpecialCharAsFieldData] + .collect { + case If(_, _, s: CreateNamedStruct) => s + }.head val deserializer = deserializerFor[SpecialCharAsFieldData] assert(serializer.dataType(0).name == "field.1") assert(serializer.dataType(1).name == "field 2") - val argumentsFields = deserializer.asInstanceOf[NewInstance].arguments.flatMap { _.collect { - case UpCast(u: UnresolvedAttribute, _, _) => u.nameParts + val newInstance = deserializer.collect { case n: NewInstance => n }.head + + val argumentsFields = newInstance.arguments.flatMap { _.collect { + case UpCast(u: UnresolvedExtractValue, _, _) => u.extraction.toString }} - assert(argumentsFields(0) == Seq("field.1")) - assert(argumentsFields(1) == Seq("field 2")) + assert(argumentsFields(0) == "field.1") + assert(argumentsFields(1) == "field 2") } test("SPARK-22472: add null check for top-level primitive values") { @@ -351,8 +355,8 @@ class ScalaReflectionSuite extends SparkFunSuite { test("SPARK-23835: add null check to non-nullable types in Tuples") { def numberOfCheckedArguments(deserializer: Expression): Int = { - assert(deserializer.isInstanceOf[NewInstance]) - deserializer.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull]) + val newInstance = deserializer.collect { case n: NewInstance => n}.head + newInstance.arguments.count(_.isInstanceOf[AssertNotNull]) } assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]) == 2) assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index f0d61de97ffcd..e9b100b3b30db 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -28,9 +28,9 @@ import org.apache.spark.sql.{Encoder, Encoders} import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -348,7 +348,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes test("nullable of encoder serializer") { def checkNullable[T: Encoder](nullable: Boolean): Unit = { - assert(encoderFor[T].serializer.forall(_.nullable === nullable)) + assert(encoderFor[T].objSerializer.nullable === nullable) } // test for flat encoders diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 235732134d4b8..ab819bec72e85 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -239,7 +239,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { val encoder = RowEncoder(schema) val e = intercept[RuntimeException](encoder.toRow(null)) assert(e.getMessage.contains("Null value appeared in non-nullable field")) - assert(e.getMessage.contains("top level row object")) + assert(e.getMessage.contains("top level Product or row object")) } test("RowEncoder should validate external type") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 0fb3301b36162..c91b0d778fab1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1087,7 +1087,7 @@ class Dataset[T] private[sql]( // Note that we do this before joining them, to enable the join operator to return null for one // side, in cases like outer-join. val left = { - val combined = if (this.exprEnc.flat) { + val combined = if (!this.exprEnc.isSerializedAsStruct) { assert(joined.left.output.length == 1) Alias(joined.left.output.head, "_1")() } else { @@ -1097,7 +1097,7 @@ class Dataset[T] private[sql]( } val right = { - val combined = if (other.exprEnc.flat) { + val combined = if (!other.exprEnc.isSerializedAsStruct) { assert(joined.right.output.length == 1) Alias(joined.right.output.head, "_2")() } else { @@ -1110,14 +1110,14 @@ class Dataset[T] private[sql]( // combine the outputs of each join side. val conditionExpr = joined.condition.get transformUp { case a: Attribute if joined.left.outputSet.contains(a) => - if (this.exprEnc.flat) { + if (!this.exprEnc.isSerializedAsStruct) { left.output.head } else { val index = joined.left.output.indexWhere(_.exprId == a.exprId) GetStructField(left.output.head, index) } case a: Attribute if joined.right.outputSet.contains(a) => - if (other.exprEnc.flat) { + if (!other.exprEnc.isSerializedAsStruct) { right.output.head } else { val index = joined.right.output.indexWhere(_.exprId == a.exprId) @@ -1390,7 +1390,7 @@ class Dataset[T] private[sql]( implicit val encoder = c1.encoder val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan) - if (encoder.flat) { + if (!encoder.isSerializedAsStruct) { new Dataset[U1](sparkSession, project, encoder) } else { // Flattens inner fields of U1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 6bab21dca0cbd..555bcdffb6ee4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -457,7 +457,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( val encoders = columns.map(_.encoder) val namedColumns = columns.map(_.withInputType(vExprEnc, dataAttributes).named) - val keyColumn = if (kExprEnc.flat) { + val keyColumn = if (!kExprEnc.isSerializedAsStruct) { assert(groupingAttributes.length == 1) groupingAttributes.head } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 6d44890704f49..39200ec00e152 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -38,18 +38,14 @@ object TypedAggregateExpression { val bufferSerializer = bufferEncoder.namedExpressions val outputEncoder = encoderFor[OUT] - val outputType = if (outputEncoder.flat) { - outputEncoder.schema.head.dataType - } else { - outputEncoder.schema - } + val outputType = outputEncoder.objSerializer.dataType // Checks if the buffer object is simple, i.e. the buffer encoder is flat and the serializer // expression is an alias of `BoundReference`, which means the buffer object doesn't need // serialization. val isSimpleBuffer = { bufferSerializer.head match { - case Alias(_: BoundReference, _) if bufferEncoder.flat => true + case Alias(_: BoundReference, _) if !bufferEncoder.isSerializedAsStruct => true case _ => false } } @@ -71,7 +67,7 @@ object TypedAggregateExpression { outputEncoder.serializer, outputEncoder.deserializer.dataType, outputType, - !outputEncoder.flat || outputEncoder.schema.head.nullable) + outputEncoder.objSerializer.nullable) } else { ComplexTypedAggregateExpression( aggregator.asInstanceOf[Aggregator[Any, Any, Any]], @@ -82,7 +78,7 @@ object TypedAggregateExpression { bufferEncoder.resolveAndBind().deserializer, outputEncoder.serializer, outputType, - !outputEncoder.flat || outputEncoder.schema.head.nullable) + outputEncoder.objSerializer.nullable) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 4e593ff046a53..27b3b3d78d2bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1065,7 +1065,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("Dataset should throw RuntimeException if top-level product input object is null") { val e = intercept[RuntimeException](Seq(ClassData("a", 1), null).toDS()) assert(e.getMessage.contains("Null value appeared in non-nullable field")) - assert(e.getMessage.contains("top level Product input object")) + assert(e.getMessage.contains("top level Product or row object")) } test("dropDuplicates") { From 3123c7f4881225e912842e9897b0e0e6bd0f4b20 Mon Sep 17 00:00:00 2001 From: xiaoding Date: Thu, 25 Oct 2018 07:06:17 -0500 Subject: [PATCH 1915/2461] [SPARK-25808][BUILD] Upgrade jsr305 version from 1.3.9 to 3.0.0 ## What changes were proposed in this pull request? We find below warnings when build spark project: ``` [warn] * com.google.code.findbugs:jsr305:3.0.0 is selected over 1.3.9 [warn] +- org.apache.hadoop:hadoop-common:2.7.3 (depends on 3.0.0) [warn] +- org.apache.spark:spark-core_2.11:3.0.0-SNAPSHOT (depends on 1.3.9) [warn] +- org.apache.spark:spark-network-common_2.11:3.0.0-SNAPSHOT (depends on 1.3.9) [warn] +- org.apache.spark:spark-unsafe_2.11:3.0.0-SNAPSHOT (depends on 1.3.9) ``` So ideally we need to upgrade jsr305 from 1.3.9 to 3.0.0 to fix this warning Upgrade one of the dependencies jsr305 version from 1.3.9 to 3.0.0 ## How was this patch tested? sbt "core/testOnly" sbt "sql/testOnly" Closes #22803 from daviddingly/master. Authored-by: xiaoding Signed-off-by: Sean Owen --- dev/deps/spark-deps-hadoop-2.7 | 2 +- dev/deps/spark-deps-hadoop-3.1 | 2 +- pom.xml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 06173f73cc32a..537831ecac45b 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -127,7 +127,7 @@ json4s-core_2.11-3.5.3.jar json4s-jackson_2.11-3.5.3.jar json4s-scalap_2.11-3.5.3.jar jsp-api-2.1.jar -jsr305-1.3.9.jar +jsr305-3.0.0.jar jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 62fddf084c8bc..bc4ef31e3bac4 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -128,7 +128,7 @@ json4s-core_2.11-3.5.3.jar json4s-jackson_2.11-3.5.3.jar json4s-scalap_2.11-3.5.3.jar jsp-api-2.1.jar -jsr305-1.3.9.jar +jsr305-3.0.0.jar jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar diff --git a/pom.xml b/pom.xml index b1f0a53c8895f..92934c125f783 100644 --- a/pom.xml +++ b/pom.xml @@ -172,7 +172,7 @@ 2.22.2 2.9.3 3.5.2 - 1.3.9 + 3.0.0 0.9.3 4.7 1.1 From 65c653fb455336948c5af2e0f381d1d8f5640874 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 25 Oct 2018 08:35:27 -0500 Subject: [PATCH 1916/2461] [BUILD] Close stale PRs Closes #22567 Closes #18457 Closes #21517 Closes #21858 Closes #22383 Closes #19219 Closes #22401 Closes #22811 Closes #20405 Closes #21933 Closes #22819 from srowen/ClosePRs. Authored-by: Sean Owen Signed-off-by: Sean Owen From 002f9c169eb20b0d71b6d0296595f343c7f5bab2 Mon Sep 17 00:00:00 2001 From: Behroz Sikander Date: Thu, 25 Oct 2018 08:36:44 -0500 Subject: [PATCH 1917/2461] [SPARK-24794][CORE] Driver launched through rest should use all masters ## What changes were proposed in this pull request? In standalone cluster mode, one could launch driver with supervise mode enabled. StandaloneRestServer class uses the host and port of current master as the spark.master property while launching the driver (even if you are running in HA mode). This class also ignores the spark.master property passed as part of the request. Due to the above problem, if the Spark masters switch due to some reason and your driver is killed unexpectedly and relaunched, it will try to connect to the master which is in the driver command specified as -Dspark.master. But this master will be in STANDBY mode and after trying multiple times, the SparkContext will kill itself (even though secondary master was alive and healthy). This change picks the spark.master property from request and uses it to launch the driver process. Due to this, the driver process has both masters in -Dspark.master property. Even if the masters switch, SparkContext can still connect to the ALIVE master and work correctly. ## How was this patch tested? This patch was manually tested on a standalone cluster running 2.2.1. It was rebased on current master and all tests were executed. I have added a unit test for this change (but since I am new I hope I have covered all). Closes #21816 from bsikander/rest_driver_fix. Authored-by: Behroz Sikander Signed-off-by: Sean Owen --- .../deploy/rest/StandaloneRestServer.scala | 12 ++++++++++- .../rest/StandaloneRestSubmitSuite.scala | 20 +++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 22b65abce611a..afa1a5fbba792 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -138,6 +138,16 @@ private[rest] class StandaloneSubmitRequestServlet( val driverExtraClassPath = sparkProperties.get("spark.driver.extraClassPath") val driverExtraLibraryPath = sparkProperties.get("spark.driver.extraLibraryPath") val superviseDriver = sparkProperties.get("spark.driver.supervise") + // The semantics of "spark.master" and the masterUrl are different. While the + // property "spark.master" could contain all registered masters, masterUrl + // contains only the active master. To make sure a Spark driver can recover + // in a multi-master setup, we use the "spark.master" property while submitting + // the driver. + val masters = sparkProperties.get("spark.master") + val (_, masterPort) = Utils.extractHostPortFromSparkUrl(masterUrl) + val masterRestPort = this.conf.getInt("spark.master.rest.port", 6066) + val updatedMasters = masters.map( + _.replace(s":$masterRestPort", s":$masterPort")).getOrElse(masterUrl) val appArgs = request.appArgs // Filter SPARK_LOCAL_(IP|HOSTNAME) environment variables from being set on the remote system. val environmentVariables = @@ -146,7 +156,7 @@ private[rest] class StandaloneSubmitRequestServlet( // Construct driver description val conf = new SparkConf(false) .setAll(sparkProperties) - .set("spark.master", masterUrl) + .set("spark.master", updatedMasters) val extraClassPath = driverExtraClassPath.toSeq.flatMap(_.split(File.pathSeparator)) val extraLibraryPath = driverExtraLibraryPath.toSeq.flatMap(_.split(File.pathSeparator)) val extraJavaOpts = driverExtraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 54c168a8218f3..4839c842cc785 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -83,6 +83,26 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { assert(submitResponse.success) } + test("create submission with multiple masters") { + val submittedDriverId = "your-driver-id" + val submitMessage = "my driver is submitted" + val masterUrl = startDummyServer(submitId = submittedDriverId, submitMessage = submitMessage) + val conf = new SparkConf(loadDefaults = false) + val RANDOM_PORT = 9000 + val allMasters = s"$masterUrl,${Utils.localHostName()}:$RANDOM_PORT" + conf.set("spark.master", allMasters) + conf.set("spark.app.name", "dreamer") + val appArgs = Array("one", "two", "six") + // main method calls this + val response = new RestSubmissionClientApp().run("app-resource", "main-class", appArgs, conf) + val submitResponse = getSubmitResponse(response) + assert(submitResponse.action === Utils.getFormattedClassName(submitResponse)) + assert(submitResponse.serverSparkVersion === SPARK_VERSION) + assert(submitResponse.message === submitMessage) + assert(submitResponse.submissionId === submittedDriverId) + assert(submitResponse.success) + } + test("create submission from main method") { val submittedDriverId = "your-driver-id" val submitMessage = "my driver is submitted" From 6540c2f8f31bbde4df57e48698f46bb1815740ff Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 25 Oct 2018 23:03:16 +0800 Subject: [PATCH 1918/2461] [SPARK-25347][ML][DOC] Spark datasource for image/libsvm user guide ## What changes were proposed in this pull request? Spark datasource for image/libsvm user guide ## How was this patch tested? Scala: 1 Java: 2 Python: 3 R: 4 Closes #22675 from WeichenXu123/add_image_source_doc. Authored-by: WeichenXu Signed-off-by: Wenchen Fan --- docs/_data/menu-ml.yaml | 2 + docs/ml-datasource.md | 108 ++++++++++++++++++ .../ml/source/image/ImageDataSource.scala | 17 +-- 3 files changed, 120 insertions(+), 7 deletions(-) create mode 100644 docs/ml-datasource.md diff --git a/docs/_data/menu-ml.yaml b/docs/_data/menu-ml.yaml index b5a6641e2e7e2..8e366f7f029aa 100644 --- a/docs/_data/menu-ml.yaml +++ b/docs/_data/menu-ml.yaml @@ -1,5 +1,7 @@ - text: Basic statistics url: ml-statistics.html +- text: Data sources + url: ml-datasource - text: Pipelines url: ml-pipeline.html - text: Extracting, transforming and selecting features diff --git a/docs/ml-datasource.md b/docs/ml-datasource.md new file mode 100644 index 0000000000000..15083326240ac --- /dev/null +++ b/docs/ml-datasource.md @@ -0,0 +1,108 @@ +--- +layout: global +title: Data sources +displayTitle: Data sources +--- + +In this section, we introduce how to use data source in ML to load data. +Beside some general data sources such as Parquet, CSV, JSON and JDBC, we also provide some specific data sources for ML. + +**Table of Contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + +## Image data source + +This image data source is used to load image files from a directory, it can load compressed image (jpeg, png, etc.) into raw image representation via `ImageIO` in Java library. +The loaded DataFrame has one `StructType` column: "image", containing image data stored as image schema. +The schema of the `image` column is: + - origin: `StringType` (represents the file path of the image) + - height: `IntegerType` (height of the image) + - width: `IntegerType` (width of the image) + - nChannels: `IntegerType` (number of image channels) + - mode: `IntegerType` (OpenCV-compatible type) + - data: `BinaryType` (Image bytes in OpenCV-compatible order: row-wise BGR in most cases) + + +
      +
      +[`ImageDataSource`](api/scala/index.html#org.apache.spark.ml.source.image.ImageDataSource) +implements a Spark SQL data source API for loading image data as a DataFrame. + +{% highlight scala %} +scala> val df = spark.read.format("image").option("dropInvalid", true).load("data/mllib/images/origin/kittens") +df: org.apache.spark.sql.DataFrame = [image: struct] + +scala> df.select("image.origin", "image.width", "image.height").show(truncate=false) ++-----------------------------------------------------------------------+-----+------+ +|origin |width|height| ++-----------------------------------------------------------------------+-----+------+ +|file:///spark/data/mllib/images/origin/kittens/54893.jpg |300 |311 | +|file:///spark/data/mllib/images/origin/kittens/DP802813.jpg |199 |313 | +|file:///spark/data/mllib/images/origin/kittens/29.5.a_b_EGDP022204.jpg |300 |200 | +|file:///spark/data/mllib/images/origin/kittens/DP153539.jpg |300 |296 | ++-----------------------------------------------------------------------+-----+------+ +{% endhighlight %} +
      + +
      +[`ImageDataSource`](api/java/org/apache/spark/ml/source/image/ImageDataSource.html) +implements Spark SQL data source API for loading image data as DataFrame. + +{% highlight java %} +Dataset imagesDF = spark.read().format("image").option("dropInvalid", true).load("data/mllib/images/origin/kittens"); +imageDF.select("image.origin", "image.width", "image.height").show(false); +/* +Will output: ++-----------------------------------------------------------------------+-----+------+ +|origin |width|height| ++-----------------------------------------------------------------------+-----+------+ +|file:///spark/data/mllib/images/origin/kittens/54893.jpg |300 |311 | +|file:///spark/data/mllib/images/origin/kittens/DP802813.jpg |199 |313 | +|file:///spark/data/mllib/images/origin/kittens/29.5.a_b_EGDP022204.jpg |300 |200 | +|file:///spark/data/mllib/images/origin/kittens/DP153539.jpg |300 |296 | ++-----------------------------------------------------------------------+-----+------+ +*/ +{% endhighlight %} +
      + +
      +In PySpark we provide Spark SQL data source API for loading image data as DataFrame. + +{% highlight python %} +>>> df = spark.read.format("image").option("dropInvalid", true).load("data/mllib/images/origin/kittens") +>>> df.select("image.origin", "image.width", "image.height").show(truncate=False) ++-----------------------------------------------------------------------+-----+------+ +|origin |width|height| ++-----------------------------------------------------------------------+-----+------+ +|file:///spark/data/mllib/images/origin/kittens/54893.jpg |300 |311 | +|file:///spark/data/mllib/images/origin/kittens/DP802813.jpg |199 |313 | +|file:///spark/data/mllib/images/origin/kittens/29.5.a_b_EGDP022204.jpg |300 |200 | +|file:///spark/data/mllib/images/origin/kittens/DP153539.jpg |300 |296 | ++-----------------------------------------------------------------------+-----+------+ +{% endhighlight %} +
      + +
      +In SparkR we provide Spark SQL data source API for loading image data as DataFrame. + +{% highlight r %} +> df = read.df("data/mllib/images/origin/kittens", "image") +> head(select(df, df$image.origin, df$image.width, df$image.height)) + +1 file:///spark/data/mllib/images/origin/kittens/54893.jpg +2 file:///spark/data/mllib/images/origin/kittens/DP802813.jpg +3 file:///spark/data/mllib/images/origin/kittens/29.5.a_b_EGDP022204.jpg +4 file:///spark/data/mllib/images/origin/kittens/DP153539.jpg + width height +1 300 311 +2 199 313 +3 300 200 +4 300 296 + +{% endhighlight %} +
      + + +
      diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageDataSource.scala b/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageDataSource.scala index a111c95248cf5..d4d74082dc8c5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageDataSource.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/image/ImageDataSource.scala @@ -19,14 +19,17 @@ package org.apache.spark.ml.source.image /** * `image` package implements Spark SQL data source API for loading image data as `DataFrame`. - * The loaded `DataFrame` has one `StructType` column: `image`. + * It can load compressed image (jpeg, png, etc.) into raw image representation via `ImageIO` + * in Java library. + * The loaded `DataFrame` has one `StructType` column: `image`, containing image data stored + * as image schema. * The schema of the `image` column is: - * - origin: String (represents the file path of the image) - * - height: Int (height of the image) - * - width: Int (width of the image) - * - nChannels: Int (number of the image channels) - * - mode: Int (OpenCV-compatible type) - * - data: BinaryType (Image bytes in OpenCV-compatible order: row-wise BGR in most cases) + * - origin: `StringType` (represents the file path of the image) + * - height: `IntegerType` (height of the image) + * - width: `IntegerType` (width of the image) + * - nChannels: `IntegerType` (number of image channels) + * - mode: `IntegerType` (OpenCV-compatible type) + * - data: `BinaryType` (Image bytes in OpenCV-compatible order: row-wise BGR in most cases) * * To use image data source, you need to set "image" as the format in `DataFrameReader` and * optionally specify the data source options, for example: From ccd07b736640c87ac6980a1c7c2d706ef3bab1bf Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 25 Oct 2018 12:42:31 -0700 Subject: [PATCH 1919/2461] =?UTF-8?q?[SPARK-25665][SQL][TEST]=20Refactor?= =?UTF-8?q?=20ObjectHashAggregateExecBenchmark=20to=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Refactor ObjectHashAggregateExecBenchmark to use main method ## How was this patch tested? Manually tested: ``` bin/spark-submit --class org.apache.spark.sql.execution.benchmark.ObjectHashAggregateExecBenchmark --jars sql/catalyst/target/spark-catalyst_2.11-3.0.0-SNAPSHOT-tests.jar,core/target/spark-core_2.11-3.0.0-SNAPSHOT-tests.jar,sql/hive/target/spark-hive_2.11-3.0.0-SNAPSHOT.jar --packages org.spark-project.hive:hive-exec:1.2.1.spark2 sql/hive/target/spark-hive_2.11-3.0.0-SNAPSHOT-tests.jar ``` Generated results with: ``` SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "hive/test:runMain org.apache.spark.sql.execution.benchmark.ObjectHashAggregateExecBenchmark" ``` Closes #22804 from peter-toth/SPARK-25665. Lead-authored-by: Peter Toth Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- ...jectHashAggregateExecBenchmark-results.txt | 45 ++++ .../ObjectHashAggregateExecBenchmark.scala | 218 +++++++++--------- 2 files changed, 152 insertions(+), 111 deletions(-) create mode 100644 sql/hive/benchmarks/ObjectHashAggregateExecBenchmark-results.txt diff --git a/sql/hive/benchmarks/ObjectHashAggregateExecBenchmark-results.txt b/sql/hive/benchmarks/ObjectHashAggregateExecBenchmark-results.txt new file mode 100644 index 0000000000000..f3044da972497 --- /dev/null +++ b/sql/hive/benchmarks/ObjectHashAggregateExecBenchmark-results.txt @@ -0,0 +1,45 @@ +================================================================================================ +Hive UDAF vs Spark AF +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +hive udaf vs spark af: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +hive udaf w/o group by 6370 / 6400 0.0 97193.6 1.0X +spark af w/o group by 54 / 63 1.2 820.8 118.4X +hive udaf w/ group by 4492 / 4507 0.0 68539.5 1.4X +spark af w/ group by w/o fallback 58 / 64 1.1 881.7 110.2X +spark af w/ group by w/ fallback 136 / 142 0.5 2075.0 46.8X + + +================================================================================================ +ObjectHashAggregateExec vs SortAggregateExec - typed_count +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +object agg v.s. sort agg: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +sort agg w/ group by 41500 / 41630 2.5 395.8 1.0X +object agg w/ group by w/o fallback 10075 / 10122 10.4 96.1 4.1X +object agg w/ group by w/ fallback 28131 / 28205 3.7 268.3 1.5X +sort agg w/o group by 6182 / 6221 17.0 59.0 6.7X +object agg w/o group by w/o fallback 5435 / 5468 19.3 51.8 7.6X + + +================================================================================================ +ObjectHashAggregateExec vs SortAggregateExec - percentile_approx +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +object agg v.s. sort agg: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +sort agg w/ group by 970 / 1025 2.2 462.5 1.0X +object agg w/ group by w/o fallback 772 / 798 2.7 368.1 1.3X +object agg w/ group by w/ fallback 1013 / 1044 2.1 483.1 1.0X +sort agg w/o group by 751 / 781 2.8 358.0 1.3X +object agg w/o group by w/o fallback 772 / 814 2.7 368.0 1.3X + + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala index 3b33785cdfbb2..50ee09678e2cb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala @@ -21,207 +21,189 @@ import scala.concurrent.duration._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFPercentileApprox -import org.apache.spark.benchmark.Benchmark -import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.catalog.CatalogFunction +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} +import org.apache.spark.sql.{Column, SparkSession} import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile -import org.apache.spark.sql.hive.HiveSessionCatalog +import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.hive.execution.TestingTypedCount -import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.LongType -class ObjectHashAggregateExecBenchmark extends BenchmarkWithCodegen with TestHiveSingleton { - ignore("Hive UDAF vs Spark AF") { - val N = 2 << 15 +/** + * Benchmark to measure hash based aggregation. + * To run this benchmark: + * {{{ + * 1. without sbt: bin/spark-submit --class + * --jars ,, + * --packages org.spark-project.hive:hive-exec:1.2.1.spark2 + * + * 2. build/sbt "hive/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "hive/test:runMain " + * Results will be written to "benchmarks/ObjectHashAggregateExecBenchmark-results.txt". + * }}} + */ +object ObjectHashAggregateExecBenchmark extends BenchmarkBase with SQLHelper { + + private val spark: SparkSession = TestHive.sparkSession + private val sql = spark.sql _ + import spark.implicits._ + private def hiveUDAFvsSparkAF(N: Int): Unit = { val benchmark = new Benchmark( name = "hive udaf vs spark af", valuesPerIteration = N, minNumIters = 5, warmupTime = 5.seconds, minTime = 10.seconds, - outputPerIteration = true + outputPerIteration = true, + output = output ) - registerHiveFunction("hive_percentile_approx", classOf[GenericUDAFPercentileApprox]) + sql( + s"CREATE TEMPORARY FUNCTION hive_percentile_approx AS '" + + s"${classOf[GenericUDAFPercentileApprox].getName}'" + ) - sparkSession.range(N).createOrReplaceTempView("t") + spark.range(N).createOrReplaceTempView("t") benchmark.addCase("hive udaf w/o group by") { _ => - sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "false") - sparkSession.sql("SELECT hive_percentile_approx(id, 0.5) FROM t").collect() + withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") { + sql("SELECT hive_percentile_approx(id, 0.5) FROM t").collect() + } } benchmark.addCase("spark af w/o group by") { _ => - sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true") - sparkSession.sql("SELECT percentile_approx(id, 0.5) FROM t").collect() + withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") { + sql("SELECT percentile_approx(id, 0.5) FROM t").collect() + } } benchmark.addCase("hive udaf w/ group by") { _ => - sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "false") - sparkSession.sql( - s"SELECT hive_percentile_approx(id, 0.5) FROM t GROUP BY CAST(id / ${N / 4} AS BIGINT)" - ).collect() + withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") { + sql( + s"SELECT hive_percentile_approx(id, 0.5) FROM t GROUP BY CAST(id / ${N / 4} AS BIGINT)" + ).collect() + } } benchmark.addCase("spark af w/ group by w/o fallback") { _ => - sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true") - sparkSession.sql( - s"SELECT percentile_approx(id, 0.5) FROM t GROUP BY CAST(id / ${N / 4} AS BIGINT)" - ).collect() + withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") { + sql(s"SELECT percentile_approx(id, 0.5) FROM t GROUP BY CAST(id / ${N / 4} AS BIGINT)") + .collect() + } } benchmark.addCase("spark af w/ group by w/ fallback") { _ => - sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true") - sparkSession.conf.set(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key, "2") - sparkSession.sql( - s"SELECT percentile_approx(id, 0.5) FROM t GROUP BY CAST(id / ${N / 4} AS BIGINT)" - ).collect() + withSQLConf( + SQLConf.USE_OBJECT_HASH_AGG.key -> "true", + SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "2") { + sql(s"SELECT percentile_approx(id, 0.5) FROM t GROUP BY CAST(id / ${N / 4} AS BIGINT)") + .collect() + } } benchmark.run() - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - hive udaf vs spark af: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - hive udaf w/o group by 5326 / 5408 0.0 81264.2 1.0X - spark af w/o group by 93 / 111 0.7 1415.6 57.4X - hive udaf w/ group by 3804 / 3946 0.0 58050.1 1.4X - spark af w/ group by w/o fallback 71 / 90 0.9 1085.7 74.8X - spark af w/ group by w/ fallback 98 / 111 0.7 1501.6 54.1X - */ } - ignore("ObjectHashAggregateExec vs SortAggregateExec - typed_count") { - val N: Long = 1024 * 1024 * 100 - + private def objectHashAggregateExecVsSortAggregateExecUsingTypedCount(N: Int): Unit = { val benchmark = new Benchmark( name = "object agg v.s. sort agg", valuesPerIteration = N, minNumIters = 1, warmupTime = 10.seconds, minTime = 45.seconds, - outputPerIteration = true + outputPerIteration = true, + output = output ) - import sparkSession.implicits._ - def typed_count(column: Column): Column = Column(TestingTypedCount(column.expr).toAggregateExpression()) - val df = sparkSession.range(N) + val df = spark.range(N) benchmark.addCase("sort agg w/ group by") { _ => - sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "false") - df.groupBy($"id" < (N / 2)).agg(typed_count($"id")).collect() + withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") { + df.groupBy($"id" < (N / 2)).agg(typed_count($"id")).collect() + } } benchmark.addCase("object agg w/ group by w/o fallback") { _ => - sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true") - df.groupBy($"id" < (N / 2)).agg(typed_count($"id")).collect() + withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") { + df.groupBy($"id" < (N / 2)).agg(typed_count($"id")).collect() + } } benchmark.addCase("object agg w/ group by w/ fallback") { _ => - sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true") - sparkSession.conf.set(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key, "2") - df.groupBy($"id" < (N / 2)).agg(typed_count($"id")).collect() + withSQLConf( + SQLConf.USE_OBJECT_HASH_AGG.key -> "true", + SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "2") { + df.groupBy($"id" < (N / 2)).agg(typed_count($"id")).collect() + } } benchmark.addCase("sort agg w/o group by") { _ => - sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "false") - df.select(typed_count($"id")).collect() + withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") { + df.select(typed_count($"id")).collect() + } } benchmark.addCase("object agg w/o group by w/o fallback") { _ => - sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true") - df.select(typed_count($"id")).collect() + withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") { + df.select(typed_count($"id")).collect() + } } benchmark.run() - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - object agg v.s. sort agg: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - sort agg w/ group by 31251 / 31908 3.4 298.0 1.0X - object agg w/ group by w/o fallback 6903 / 7141 15.2 65.8 4.5X - object agg w/ group by w/ fallback 20945 / 21613 5.0 199.7 1.5X - sort agg w/o group by 4734 / 5463 22.1 45.2 6.6X - object agg w/o group by w/o fallback 4310 / 4529 24.3 41.1 7.3X - */ } - ignore("ObjectHashAggregateExec vs SortAggregateExec - percentile_approx") { - val N = 2 << 20 - + private def objectHashAggregateExecVsSortAggregateExecUsingPercentileApprox(N: Int): Unit = { val benchmark = new Benchmark( name = "object agg v.s. sort agg", valuesPerIteration = N, minNumIters = 5, warmupTime = 15.seconds, minTime = 45.seconds, - outputPerIteration = true + outputPerIteration = true, + output = output ) - import sparkSession.implicits._ - - val df = sparkSession.range(N).coalesce(1) + val df = spark.range(N).coalesce(1) benchmark.addCase("sort agg w/ group by") { _ => - sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "false") - df.groupBy($"id" / (N / 4) cast LongType).agg(percentile_approx($"id", 0.5)).collect() + withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") { + df.groupBy($"id" / (N / 4) cast LongType).agg(percentile_approx($"id", 0.5)).collect() + } } benchmark.addCase("object agg w/ group by w/o fallback") { _ => - sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true") - df.groupBy($"id" / (N / 4) cast LongType).agg(percentile_approx($"id", 0.5)).collect() + withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") { + df.groupBy($"id" / (N / 4) cast LongType).agg(percentile_approx($"id", 0.5)).collect() + } } benchmark.addCase("object agg w/ group by w/ fallback") { _ => - sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true") - sparkSession.conf.set(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key, "2") - df.groupBy($"id" / (N / 4) cast LongType).agg(percentile_approx($"id", 0.5)).collect() + withSQLConf( + SQLConf.USE_OBJECT_HASH_AGG.key -> "true", + SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "2") { + df.groupBy($"id" / (N / 4) cast LongType).agg(percentile_approx($"id", 0.5)).collect() + } } benchmark.addCase("sort agg w/o group by") { _ => - sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "false") - df.select(percentile_approx($"id", 0.5)).collect() + withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") { + df.select(percentile_approx($"id", 0.5)).collect() + } } benchmark.addCase("object agg w/o group by w/o fallback") { _ => - sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true") - df.select(percentile_approx($"id", 0.5)).collect() + withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") { + df.select(percentile_approx($"id", 0.5)).collect() + } } benchmark.run() - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - object agg v.s. sort agg: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - sort agg w/ group by 3418 / 3530 0.6 1630.0 1.0X - object agg w/ group by w/o fallback 3210 / 3314 0.7 1530.7 1.1X - object agg w/ group by w/ fallback 3419 / 3511 0.6 1630.1 1.0X - sort agg w/o group by 4336 / 4499 0.5 2067.3 0.8X - object agg w/o group by w/o fallback 4271 / 4372 0.5 2036.7 0.8X - */ - } - - private def registerHiveFunction(functionName: String, clazz: Class[_]): Unit = { - val sessionCatalog = sparkSession.sessionState.catalog.asInstanceOf[HiveSessionCatalog] - val functionIdentifier = FunctionIdentifier(functionName, database = None) - val func = CatalogFunction(functionIdentifier, clazz.getName, resources = Nil) - sessionCatalog.registerFunction(func, overrideIfExists = false) } private def percentile_approx( @@ -229,4 +211,18 @@ class ObjectHashAggregateExecBenchmark extends BenchmarkWithCodegen with TestHiv val approxPercentile = new ApproximatePercentile(column.expr, Literal(percentage)) Column(approxPercentile.toAggregateExpression(isDistinct)) } + + override def runBenchmarkSuite(): Unit = { + runBenchmark("Hive UDAF vs Spark AF") { + hiveUDAFvsSparkAF(2 << 15) + } + + runBenchmark("ObjectHashAggregateExec vs SortAggregateExec - typed_count") { + objectHashAggregateExecVsSortAggregateExecUsingTypedCount(1024 * 1024 * 100) + } + + runBenchmark("ObjectHashAggregateExec vs SortAggregateExec - percentile_approx") { + objectHashAggregateExecVsSortAggregateExecUsingPercentileApprox(2 << 20) + } + } } From 9b98d9166ee2c130ba38a09e8c0aa12e29676b76 Mon Sep 17 00:00:00 2001 From: Steve Date: Thu, 25 Oct 2018 13:00:59 -0700 Subject: [PATCH 1920/2461] [SPARK-25803][K8S] Fix docker-image-tool.sh -n option ## What changes were proposed in this pull request? docker-image-tool.sh uses getopts in which a colon signifies that an option takes an argument. Since -n does not take an argument it should not have a colon. ## How was this patch tested? Following the reproduction in [JIRA](https://issues.apache.org/jira/browse/SPARK-25803):- 0. Created a custom Dockerfile to use for the spark-r container image. In each of the steps below the path to this Dockerfile is passed with the '-R' option. (spark-r is used here simply as an example, the bug applies to all options) 1. Built container images without '-n'. The [result](https://gist.github.com/sel/59f0911bb1a6a485c2487cf7ca770f9d) is that the '-R' option is honoured and the hello-world image is built for spark-r, as expected. 2. Built container images with '-n' to reproduce the issue The [result](https://gist.github.com/sel/e5cabb9f3bdad5d087349e7fbed75141) is that the '-R' option is ignored and the default container image for spark-r is built 3. Applied the patch and re-built container images with '-n' and did not reproduce the issue The [result](https://gist.github.com/sel/6af14b95012ba8ff267a4fce6e3bd3bf) is that the '-R' option is honoured and the hello-world image is built for spark-r, as expected. Closes #22798 from sel/fix-docker-image-tool-nocache. Authored-by: Steve Signed-off-by: Marcelo Vanzin --- bin/docker-image-tool.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index 72563551e5b30..61959ca2a3041 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -182,7 +182,7 @@ PYDOCKERFILE= RDOCKERFILE= NOCACHEARG= BUILD_PARAMS= -while getopts f:p:R:mr:t:n:b: option +while getopts f:p:R:mr:t:nb: option do case "${option}" in From 46d2d2c74d9aaf30e158aeda58a189f6c8e48b9c Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Thu, 25 Oct 2018 13:16:08 -0700 Subject: [PATCH 1921/2461] [SPARK-24787][CORE] Revert hsync in EventLoggingListener and make FsHistoryProvider to read lastBlockBeingWritten data for logs ## What changes were proposed in this pull request? `hsync` has been added as part of SPARK-19531 to get the latest data in the history sever ui, but that is causing the performance overhead and also leading to drop many history log events. `hsync` uses the force `FileChannel.force` to sync the data to the disk and happens for the data pipeline, it is costly operation and making the application to face overhead and drop the events. I think getting the latest data in history server can be done in different way (no impact to application while writing events), there is an api `DFSInputStream.getFileLength()` which gives the file length including the `lastBlockBeingWrittenLength`(different from `FileStatus.getLen()`), this api can be used when the file status length and previously cached length are equal to verify whether any new data has been written or not, if there is any update in data length then the history server can update the in progress history log. And also I made this change as configurable with the default value false, and can be enabled for history server if users want to see the updated data in ui. ## How was this patch tested? Added new test and verified manually, with the added conf `spark.history.fs.inProgressAbsoluteLengthCheck.enabled=true`, history server is reading the logs including the last block data which is being written and updating the Web UI with the latest data. Closes #22752 from devaraj-kavali/SPARK-24787. Authored-by: Devaraj K Signed-off-by: Marcelo Vanzin --- .../deploy/history/FsHistoryProvider.scala | 22 ++++++++++- .../scheduler/EventLoggingListener.scala | 8 +--- .../history/FsHistoryProviderSuite.scala | 37 ++++++++++++++++++- 3 files changed, 56 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index c23a659e76df1..c4517d3dfd931 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -34,7 +34,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore import com.google.common.io.ByteStreams import com.google.common.util.concurrent.MoreExecutors import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} -import org.apache.hadoop.hdfs.DistributedFileSystem +import org.apache.hadoop.hdfs.{DFSInputStream, DistributedFileSystem} import org.apache.hadoop.hdfs.protocol.HdfsConstants import org.apache.hadoop.security.AccessControlException import org.fusesource.leveldbjni.internal.NativeDB @@ -449,7 +449,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) listing.write(info.copy(lastProcessed = newLastScanTime, fileSize = entry.getLen())) } - if (info.fileSize < entry.getLen()) { + if (shouldReloadLog(info, entry)) { if (info.appId.isDefined && fastInProgressParsing) { // When fast in-progress parsing is on, we don't need to re-parse when the // size changes, but we do need to invalidate any existing UIs. @@ -541,6 +541,24 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + private[history] def shouldReloadLog(info: LogInfo, entry: FileStatus): Boolean = { + var result = info.fileSize < entry.getLen + if (!result && info.logPath.endsWith(EventLoggingListener.IN_PROGRESS)) { + try { + result = Utils.tryWithResource(fs.open(entry.getPath)) { in => + in.getWrappedStream match { + case dfsIn: DFSInputStream => info.fileSize < dfsIn.getFileLength + case _ => false + } + } + } catch { + case e: Exception => + logDebug(s"Failed to check the length for the file : ${info.logPath}", e) + } + } + result + } + private def cleanAppData(appId: String, attemptId: Option[String], logPath: String): Unit = { try { val app = load(appId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 1629e1797977f..f89fcd18ef56b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -20,7 +20,6 @@ package org.apache.spark.scheduler import java.io._ import java.net.URI import java.nio.charset.StandardCharsets -import java.util.EnumSet import java.util.Locale import scala.collection.mutable.{ArrayBuffer, Map} @@ -28,8 +27,6 @@ import scala.collection.mutable.{ArrayBuffer, Map} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FSDataOutputStream, Path} import org.apache.hadoop.fs.permission.FsPermission -import org.apache.hadoop.hdfs.DFSOutputStream -import org.apache.hadoop.hdfs.client.HdfsDataOutputStream.SyncFlag import org.json4s.JsonAST.JValue import org.json4s.jackson.JsonMethods._ @@ -149,10 +146,7 @@ private[spark] class EventLoggingListener( // scalastyle:on println if (flushLogger) { writer.foreach(_.flush()) - hadoopDataStream.foreach(ds => ds.getWrappedStream match { - case wrapped: DFSOutputStream => wrapped.hsync(EnumSet.of(SyncFlag.UPDATE_LENGTH)) - case _ => ds.hflush() - }) + hadoopDataStream.foreach(_.hflush()) } if (testing) { loggedEvents += eventJson diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 444e8d6e11f88..6a761d43a5a68 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -27,8 +27,8 @@ import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.io.{ByteStreams, Files} -import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.hdfs.DistributedFileSystem +import org.apache.hadoop.fs.{FileStatus, FileSystem, FSDataInputStream, Path} +import org.apache.hadoop.hdfs.{DFSInputStream, DistributedFileSystem} import org.apache.hadoop.security.AccessControlException import org.json4s.jackson.JsonMethods._ import org.mockito.ArgumentMatcher @@ -856,6 +856,39 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc assert(!mockedProvider.isBlacklisted(accessDeniedPath)) } + test("check in-progress event logs absolute length") { + val path = new Path("testapp.inprogress") + val provider = new FsHistoryProvider(createTestConf()) + val mockedProvider = spy(provider) + val mockedFs = mock(classOf[FileSystem]) + val in = mock(classOf[FSDataInputStream]) + val dfsIn = mock(classOf[DFSInputStream]) + when(mockedProvider.fs).thenReturn(mockedFs) + when(mockedFs.open(path)).thenReturn(in) + when(in.getWrappedStream).thenReturn(dfsIn) + when(dfsIn.getFileLength).thenReturn(200) + // FileStatus.getLen is more than logInfo fileSize + var fileStatus = new FileStatus(200, false, 0, 0, 0, path) + var logInfo = new LogInfo(path.toString, 0, Some("appId"), Some("attemptId"), 100) + assert(mockedProvider.shouldReloadLog(logInfo, fileStatus)) + + fileStatus = new FileStatus() + fileStatus.setPath(path) + // DFSInputStream.getFileLength is more than logInfo fileSize + logInfo = new LogInfo(path.toString, 0, Some("appId"), Some("attemptId"), 100) + assert(mockedProvider.shouldReloadLog(logInfo, fileStatus)) + // DFSInputStream.getFileLength is equal to logInfo fileSize + logInfo = new LogInfo(path.toString, 0, Some("appId"), Some("attemptId"), 200) + assert(!mockedProvider.shouldReloadLog(logInfo, fileStatus)) + // in.getWrappedStream returns other than DFSInputStream + val bin = mock(classOf[BufferedInputStream]) + when(in.getWrappedStream).thenReturn(bin) + assert(!mockedProvider.shouldReloadLog(logInfo, fileStatus)) + // fs.open throws exception + when(mockedFs.open(path)).thenThrow(new IOException("Throwing intentionally")) + assert(!mockedProvider.shouldReloadLog(logInfo, fileStatus)) + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: From 72a23a6c43fe1b5a6583ea6b35b4fbb08474abbe Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 26 Oct 2018 10:19:35 +0800 Subject: [PATCH 1922/2461] [SPARK-25772][SQL][FOLLOWUP] remove GetArrayFromMap ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/22745 we introduced the `GetArrayFromMap` expression. Later on I realized this is duplicated as we already have `MapKeys` and `MapValues`. This PR removes `GetArrayFromMap` ## How was this patch tested? existing tests Closes #22825 from cloud-fan/minor. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../sql/catalyst/JavaTypeInference.scala | 6 +- .../expressions/objects/objects.scala | 75 ------------------- 2 files changed, 3 insertions(+), 78 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index f32e080447317..8ef8b2be6939c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -26,7 +26,7 @@ import scala.language.existentials import com.google.common.reflect.TypeToken -import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} @@ -280,7 +280,7 @@ object JavaTypeInference { Invoke( UnresolvedMapObjects( p => deserializerFor(keyType, p), - GetKeyArrayFromMap(path)), + MapKeys(path)), "array", ObjectType(classOf[Array[Any]])) @@ -288,7 +288,7 @@ object JavaTypeInference { Invoke( UnresolvedMapObjects( p => deserializerFor(valueType, p), - GetValueArrayFromMap(path)), + MapValues(path)), "array", ObjectType(classOf[Array[Any]])) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 5bfa485f1569a..b6f9b4734e940 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1788,78 +1788,3 @@ case class ValidateExternalType(child: Expression, expected: DataType) ev.copy(code = code, isNull = input.isNull) } } - -object GetKeyArrayFromMap { - - /** - * Construct an instance of GetArrayFromMap case class - * extracting a key array from a Map expression. - * - * @param child a Map expression to extract a key array from - */ - def apply(child: Expression): Expression = { - GetArrayFromMap( - child, - "keyArray", - _.keyArray(), - { case MapType(kt, _, _) => kt }) - } -} - -object GetValueArrayFromMap { - - /** - * Construct an instance of GetArrayFromMap case class - * extracting a value array from a Map expression. - * - * @param child a Map expression to extract a value array from - */ - def apply(child: Expression): Expression = { - GetArrayFromMap( - child, - "valueArray", - _.valueArray(), - { case MapType(_, vt, _) => vt }) - } -} - -/** - * Extracts a key/value array from a Map expression. - * - * @param child a Map expression to extract an array from - * @param functionName name of the function that is invoked to extract an array - * @param arrayGetter function extracting `ArrayData` from `MapData` - * @param elementTypeGetter function extracting array element `DataType` from `MapType` - */ -case class GetArrayFromMap private( - child: Expression, - functionName: String, - arrayGetter: MapData => ArrayData, - elementTypeGetter: MapType => DataType) extends UnaryExpression with NonSQLExpression { - - private lazy val encodedFunctionName: String = TermName(functionName).encodedName.toString - - lazy val dataType: DataType = { - val mt: MapType = child.dataType.asInstanceOf[MapType] - ArrayType(elementTypeGetter(mt)) - } - - override def checkInputDataTypes(): TypeCheckResult = { - if (child.dataType.isInstanceOf[MapType]) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure( - s"Can't extract array from $child: need map type but got ${child.dataType.catalogString}") - } - } - - override def nullSafeEval(input: Any): Any = { - arrayGetter(input.asInstanceOf[MapData]) - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, childValue => s"$childValue.$encodedFunctionName()") - } - - override def toString: String = s"$child.$functionName" -} From dc9b320807881403ca9f1e2e6d01de4b52db3975 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 26 Oct 2018 11:07:55 +0800 Subject: [PATCH 1923/2461] [SPARK-25793][ML] call SaveLoadV2_0.load for classNameV2_0 ## What changes were proposed in this pull request? The following code in BisectingKMeansModel.load calls the wrong version of load. ``` case (SaveLoadV2_0.thisClassName, SaveLoadV2_0.thisFormatVersion) => val model = SaveLoadV1_0.load(sc, path) ``` Closes #22790 from huaxingao/spark-25793. Authored-by: Huaxin Gao Signed-off-by: Wenchen Fan --- .../spark/mllib/clustering/BisectingKMeansModel.scala | 6 +++--- .../spark/mllib/clustering/BisectingKMeansSuite.scala | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala index 9d115afcea75d..4c5794fbffc8e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala @@ -109,10 +109,10 @@ class BisectingKMeansModel private[clustering] ( @Since("2.0.0") override def save(sc: SparkContext, path: String): Unit = { - BisectingKMeansModel.SaveLoadV1_0.save(sc, this, path) + BisectingKMeansModel.SaveLoadV2_0.save(sc, this, path) } - override protected def formatVersion: String = "1.0" + override protected def formatVersion: String = "2.0" } @Since("2.0.0") @@ -126,7 +126,7 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { val model = SaveLoadV1_0.load(sc, path) model case (SaveLoadV2_0.thisClassName, SaveLoadV2_0.thisFormatVersion) => - val model = SaveLoadV1_0.load(sc, path) + val model = SaveLoadV2_0.load(sc, path) model case _ => throw new Exception( s"BisectingKMeansModel.load did not recognize model with (className, format version):" + diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala index 35f7932ae8224..4a4d8b5c89de8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala @@ -187,11 +187,12 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { val points = (1 until 8).map(i => Vectors.dense(i)) val data = sc.parallelize(points, 2) - val model = new BisectingKMeans().run(data) + val model = new BisectingKMeans().setDistanceMeasure(DistanceMeasure.COSINE).run(data) try { model.save(sc, path) val sameModel = BisectingKMeansModel.load(sc, path) assert(model.k === sameModel.k) + assert(model.distanceMeasure === sameModel.distanceMeasure) model.clusterCenters.zip(sameModel.clusterCenters).foreach(c => c._1 === c._2) } finally { Utils.deleteRecursively(tempDir) From 79f3babcc6e189d7405464b9ac1eb1c017e51f5d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 25 Oct 2018 20:26:13 -0700 Subject: [PATCH 1924/2461] [SPARK-25840][BUILD] `make-distribution.sh` should not fail due to missing LICENSE-binary ## What changes were proposed in this pull request? We vote for the artifacts. All releases are in the form of the source materials needed to make changes to the software being released. (http://www.apache.org/legal/release-policy.html#artifacts) From Spark 2.4.0, the source artifact and binary artifact starts to contain own proper LICENSE files (LICENSE, LICENSE-binary). It's great to have them. However, unfortunately, `dev/make-distribution.sh` inside source artifacts start to fail because it expects `LICENSE-binary` and source artifact have only the LICENSE file. https://dist.apache.org/repos/dist/dev/spark/v2.4.0-rc4-bin/spark-2.4.0.tgz `dev/make-distribution.sh` is used during the voting phase because we are voting on that source artifact instead of GitHub repository. Individual contributors usually don't have the downstream repository and starts to try build the voting source artifacts to help the verification for the source artifact during voting phase. (Personally, I did before.) This PR aims to recover that script to work in any way. This doesn't aim for source artifacts to reproduce the compiled artifacts. ## How was this patch tested? Manual. ``` $ rm LICENSE-binary $ dev/make-distribution.sh ``` Closes #22840 from dongjoon-hyun/SPARK-25840. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- dev/make-distribution.sh | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 668682fbb913d..84f4ae9a64ff8 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -212,9 +212,13 @@ mkdir -p "$DISTDIR/examples/src/main" cp -r "$SPARK_HOME/examples/src/main" "$DISTDIR/examples/src/" # Copy license and ASF files -cp "$SPARK_HOME/LICENSE-binary" "$DISTDIR/LICENSE" -cp -r "$SPARK_HOME/licenses-binary" "$DISTDIR/licenses" -cp "$SPARK_HOME/NOTICE-binary" "$DISTDIR/NOTICE" +if [ -e "$SPARK_HOME/LICENSE-binary" ]; then + cp "$SPARK_HOME/LICENSE-binary" "$DISTDIR/LICENSE" + cp -r "$SPARK_HOME/licenses-binary" "$DISTDIR/licenses" + cp "$SPARK_HOME/NOTICE-binary" "$DISTDIR/NOTICE" +else + echo "Skipping copying LICENSE files" +fi if [ -e "$SPARK_HOME/CHANGES.txt" ]; then cp "$SPARK_HOME/CHANGES.txt" "$DISTDIR" From 24e8c27dfe31e6e0a53c89e6ddc36327e537931b Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 26 Oct 2018 11:39:38 +0800 Subject: [PATCH 1925/2461] [SPARK-25819][SQL] Support parse mode option for the function `from_avro` ## What changes were proposed in this pull request? Current the function `from_avro` throws exception on reading corrupt records. In practice, there could be various reasons of data corruption. It would be good to support `PERMISSIVE` mode and allow the function from_avro to process all the input file/streaming, which is consistent with from_json and from_csv. There is no obvious down side for supporting `PERMISSIVE` mode. Different from `from_csv` and `from_json`, the default parse mode is `FAILFAST` for the following reasons: 1. Since Avro is structured data format, input data is usually able to be parsed by certain schema. In such case, exposing the problems of input data to users is better than hiding it. 2. For `PERMISSIVE` mode, we have to force the data schema as fully nullable. This seems quite unnecessary for Avro. Reversing non-null schema might archive more perf optimizations in Spark. 3. To be consistent with the behavior in Spark 2.4 . ## How was this patch tested? Unit test Manual previewing generated html for the Avro data source doc: ![image](https://user-images.githubusercontent.com/1097932/47510100-02558880-d8aa-11e8-9d57-a43daee4c6b9.png) Closes #22814 from gengliangwang/improve_from_avro. Authored-by: Gengliang Wang Signed-off-by: hyukjinkwon --- docs/sql-data-sources-avro.md | 18 +++- .../spark/sql/avro/AvroDataToCatalyst.scala | 90 ++++++++++++++++--- .../apache/spark/sql/avro/AvroOptions.scala | 16 +++- .../org/apache/spark/sql/avro/package.scala | 28 +++++- .../AvroCatalystDataConversionSuite.scala | 58 ++++++++++-- .../spark/sql/avro/AvroFunctionsSuite.scala | 36 +++++++- 6 files changed, 219 insertions(+), 27 deletions(-) diff --git a/docs/sql-data-sources-avro.md b/docs/sql-data-sources-avro.md index d3b81f029d377..bfe641d1c6d1d 100644 --- a/docs/sql-data-sources-avro.md +++ b/docs/sql-data-sources-avro.md @@ -142,7 +142,10 @@ StreamingQuery query = output ## Data Source Option -Data source options of Avro can be set using the `.option` method on `DataFrameReader` or `DataFrameWriter`. +Data source options of Avro can be set via: + * the `.option` method on `DataFrameReader` or `DataFrameWriter`. + * the `options` parameter in function `from_avro`. + @@ -177,6 +180,19 @@ Data source options of Avro can be set using the `.option` method on `DataFrameR Currently supported codecs are uncompressed, snappy, deflate, bzip2 and xz.
      If the option is not set, the configuration spark.sql.avro.compression.codec config is taken into account. + + + + + +
      Property NameDefaultMeaningScope
      write
      modeFAILFASTThe mode option allows to specify parse mode for function from_avro.
      + Currently supported modes are: +
        +
      • FAILFAST: Throws an exception on processing corrupted record.
      • +
      • PERMISSIVE: Corrupt records are processed as null result. Therefore, the + data schema is forced to be fully nullable, which might be different from the one user provided.
      • +
      +
      function from_avro
      ## Configuration diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala index 915769fa708b0..43d3f6efb2a0c 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala @@ -17,20 +17,37 @@ package org.apache.spark.sql.avro +import scala.util.control.NonFatal + import org.apache.avro.Schema import org.apache.avro.generic.GenericDatumReader import org.apache.avro.io.{BinaryDecoder, DecoderFactory} -import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression} +import org.apache.spark.SparkException +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, SpecificInternalRow, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType} +import org.apache.spark.sql.catalyst.util.{FailFastMode, ParseMode, PermissiveMode} +import org.apache.spark.sql.types._ -case class AvroDataToCatalyst(child: Expression, jsonFormatSchema: String) +case class AvroDataToCatalyst( + child: Expression, + jsonFormatSchema: String, + options: Map[String, String]) extends UnaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType) - override lazy val dataType: DataType = SchemaConverters.toSqlType(avroSchema).dataType + override lazy val dataType: DataType = { + val dt = SchemaConverters.toSqlType(avroSchema).dataType + parseMode match { + // With PermissiveMode, the output Catalyst row might contain columns of null values for + // corrupt records, even if some of the columns are not nullable in the user-provided schema. + // Therefore we force the schema to be all nullable here. + case PermissiveMode => dt.asNullable + case _ => dt + } + } override def nullable: Boolean = true @@ -44,24 +61,75 @@ case class AvroDataToCatalyst(child: Expression, jsonFormatSchema: String) @transient private var result: Any = _ + @transient private lazy val parseMode: ParseMode = { + val mode = AvroOptions(options).parseMode + if (mode != PermissiveMode && mode != FailFastMode) { + throw new AnalysisException(unacceptableModeMessage(mode.name)) + } + mode + } + + private def unacceptableModeMessage(name: String): String = { + s"from_avro() doesn't support the $name mode. " + + s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}." + } + + @transient private lazy val nullResultRow: Any = dataType match { + case st: StructType => + val resultRow = new SpecificInternalRow(st.map(_.dataType)) + for(i <- 0 until st.length) { + resultRow.setNullAt(i) + } + resultRow + + case _ => + null + } + + override def nullSafeEval(input: Any): Any = { val binary = input.asInstanceOf[Array[Byte]] - decoder = DecoderFactory.get().binaryDecoder(binary, 0, binary.length, decoder) - result = reader.read(result, decoder) - deserializer.deserialize(result) + try { + decoder = DecoderFactory.get().binaryDecoder(binary, 0, binary.length, decoder) + result = reader.read(result, decoder) + deserializer.deserialize(result) + } catch { + // There could be multiple possible exceptions here, e.g. java.io.IOException, + // AvroRuntimeException, ArrayIndexOutOfBoundsException, etc. + // To make it simple, catch all the exceptions here. + case NonFatal(e) => parseMode match { + case PermissiveMode => nullResultRow + case FailFastMode => + throw new SparkException("Malformed records are detected in record parsing. " + + s"Current parse Mode: ${FailFastMode.name}. To process malformed records as null " + + "result, try setting the option 'mode' as 'PERMISSIVE'.", e.getCause) + case _ => + throw new AnalysisException(unacceptableModeMessage(parseMode.name)) + } + } } override def simpleString: String = { - s"from_avro(${child.sql}, ${dataType.simpleString})" + s"from_avro(${child.sql}, ${dataType.simpleString}, ${options.toString()})" } override def sql: String = { - s"from_avro(${child.sql}, ${dataType.catalogString})" + s"from_avro(${child.sql}, ${dataType.catalogString}, ${options.toString()})" } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val expr = ctx.addReferenceObj("this", this) - defineCodeGen(ctx, ev, input => - s"(${CodeGenerator.boxedType(dataType)})$expr.nullSafeEval($input)") + nullSafeCodeGen(ctx, ev, eval => { + val result = ctx.freshName("result") + val dt = CodeGenerator.boxedType(dataType) + s""" + $dt $result = ($dt) $expr.nullSafeEval($eval); + if ($result == null) { + ${ev.isNull} = true; + } else { + ${ev.value} = $result; + } + """ + }) } } diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala index 67f56343b4524..fec17bfff5424 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.avro import org.apache.hadoop.conf.Configuration import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, FailFastMode, ParseMode} import org.apache.spark.sql.internal.SQLConf /** @@ -79,4 +80,17 @@ class AvroOptions( val compression: String = { parameters.get("compression").getOrElse(SQLConf.get.avroCompressionCodec) } + + val parseMode: ParseMode = + parameters.get("mode").map(ParseMode.fromString).getOrElse(FailFastMode) +} + +object AvroOptions { + def apply(parameters: Map[String, String]): AvroOptions = { + val hadoopConf = SparkSession + .getActiveSession + .map(_.sessionState.newHadoopConf()) + .getOrElse(new Configuration()) + new AvroOptions(CaseInsensitiveMap(parameters), hadoopConf) + } } diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala index 97f9427f96c55..dee8575c621c8 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental package object avro { + /** * Converts a binary column of avro format into its corresponding catalyst value. The specified * schema must match the read data, otherwise the behavior is undefined: it may fail or return @@ -31,8 +34,29 @@ package object avro { * @since 2.4.0 */ @Experimental - def from_avro(data: Column, jsonFormatSchema: String): Column = { - new Column(AvroDataToCatalyst(data.expr, jsonFormatSchema)) + def from_avro( + data: Column, + jsonFormatSchema: String): Column = { + new Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, Map.empty)) + } + + /** + * Converts a binary column of avro format into its corresponding catalyst value. The specified + * schema must match the read data, otherwise the behavior is undefined: it may fail or return + * arbitrary result. + * + * @param data the binary column. + * @param jsonFormatSchema the avro schema in JSON string format. + * @param options options to control how the Avro record is parsed. + * + * @since 3.0.0 + */ + @Experimental + def from_avro( + data: Column, + jsonFormatSchema: String, + options: java.util.Map[String, String]): Column = { + new Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, options.asScala.toMap)) } /** diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala index 8334cca6cd8f1..80dd4c535ad9c 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala @@ -17,14 +17,19 @@ package org.apache.spark.sql.avro +import org.apache.avro.Schema + import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, GenericInternalRow, Literal} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -class AvroCatalystDataConversionSuite extends SparkFunSuite with ExpressionEvalHelper { +class AvroCatalystDataConversionSuite extends SparkFunSuite + with SharedSQLContext + with ExpressionEvalHelper { private def roundTripTest(data: Literal): Unit = { val avroType = SchemaConverters.toAvroType(data.dataType, data.nullable) @@ -33,14 +38,26 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite with ExpressionEvalH private def checkResult(data: Literal, schema: String, expected: Any): Unit = { checkEvaluation( - AvroDataToCatalyst(CatalystDataToAvro(data), schema), + AvroDataToCatalyst(CatalystDataToAvro(data), schema, Map.empty), prepareExpectedResult(expected)) } - private def assertFail(data: Literal, schema: String): Unit = { - intercept[java.io.EOFException] { - AvroDataToCatalyst(CatalystDataToAvro(data), schema).eval() + protected def checkUnsupportedRead(data: Literal, schema: String): Unit = { + val binary = CatalystDataToAvro(data) + intercept[Exception] { + AvroDataToCatalyst(binary, schema, Map("mode" -> "FAILFAST")).eval() + } + + val expected = { + val avroSchema = new Schema.Parser().parse(schema) + SchemaConverters.toSqlType(avroSchema).dataType match { + case st: StructType => Row.fromSeq((0 until st.length).map(_ => null)) + case _ => null + } } + + checkEvaluation(AvroDataToCatalyst(binary, schema, Map("mode" -> "PERMISSIVE")), + expected) } private val testingTypes = Seq( @@ -121,7 +138,7 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite with ExpressionEvalH """.stripMargin // When read int as string, avro reader is not able to parse the binary and fail. - assertFail(data, avroTypeJson) + checkUnsupportedRead(data, avroTypeJson) } test("read string as int") { @@ -151,7 +168,7 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite with ExpressionEvalH // When read float data as double, avro reader fails(trying to read 8 bytes while the data have // only 4 bytes). - assertFail(data, avroTypeJson) + checkUnsupportedRead(data, avroTypeJson) } test("read double as float") { @@ -167,4 +184,29 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite with ExpressionEvalH // avro reader reads the first 4 bytes of a double as a float, the result is totally undefined. checkResult(data, avroTypeJson, 5.848603E35f) } + + test("Handle unsupported input of record type") { + val actualSchema = StructType(Seq( + StructField("col_0", StringType, false), + StructField("col_1", ShortType, false), + StructField("col_2", DecimalType(8, 4), false), + StructField("col_3", BooleanType, true), + StructField("col_4", DecimalType(38, 38), false))) + + val expectedSchema = StructType(Seq( + StructField("col_0", BinaryType, false), + StructField("col_1", DoubleType, false), + StructField("col_2", DecimalType(18, 4), false), + StructField("col_3", StringType, true), + StructField("col_4", DecimalType(38, 38), false))) + + val seed = scala.util.Random.nextLong() + withClue(s"create random record with seed $seed") { + val data = RandomDataGenerator.randomRow(new scala.util.Random(seed), actualSchema) + val converter = CatalystTypeConverters.createToCatalystConverter(actualSchema) + val input = Literal.create(converter(data), actualSchema) + val avroSchema = SchemaConverters.toAvroType(expectedSchema).toString + checkUnsupportedRead(input, avroSchema) + } + } } diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala index 90a4cd6ccf9dd..46a37d8759da1 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.avro -import org.apache.avro.Schema +import scala.collection.JavaConverters._ -import org.apache.spark.sql.QueryTest +import org.apache.spark.SparkException +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.functions.struct -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} -class AvroFunctionsSuite extends QueryTest with SharedSQLContext { +class AvroFunctionsSuite extends QueryTest with SharedSQLContext with SQLTestUtils { import testImplicits._ test("roundtrip in to_avro and from_avro - int and string") { @@ -61,6 +62,33 @@ class AvroFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(avroStructDF.select(from_avro('avro, avroTypeStruct)), df) } + test("handle invalid input in from_avro") { + val count = 10 + val df = spark.range(count).select(struct('id, 'id.as("id2")).as("struct")) + val avroStructDF = df.select(to_avro('struct).as("avro")) + val avroTypeStruct = s""" + |{ + | "type": "record", + | "name": "struct", + | "fields": [ + | {"name": "col1", "type": "long"}, + | {"name": "col2", "type": "double"} + | ] + |} + """.stripMargin + + intercept[SparkException] { + avroStructDF.select( + from_avro('avro, avroTypeStruct, Map("mode" -> "FAILFAST").asJava)).collect() + } + + // For PERMISSIVE mode, the result should be row of null columns. + val expected = (0 until count).map(_ => Row(Row(null, null))) + checkAnswer( + avroStructDF.select(from_avro('avro, avroTypeStruct, Map("mode" -> "PERMISSIVE").asJava)), + expected) + } + test("roundtrip in to_avro and from_avro - array with null") { val dfOne = Seq(Tuple1(Tuple1(1) :: Nil), Tuple1(null :: Nil)).toDF("array") val avroTypeArrStruct = s""" From 86d469aeaa492c0642db09b27bb0879ead5d7166 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 26 Oct 2018 13:53:51 +0900 Subject: [PATCH 1926/2461] [SPARK-25822][PYSPARK] Fix a race condition when releasing a Python worker ## What changes were proposed in this pull request? There is a race condition when releasing a Python worker. If `ReaderIterator.handleEndOfDataSection` is not running in the task thread, when a task is early terminated (such as `take(N)`), the task completion listener may close the worker but "handleEndOfDataSection" can still put the worker into the worker pool to reuse. https://github.com/zsxwing/spark/commit/0e07b483d2e7c68f3b5c3c118d0bf58c501041b7 is a patch to reproduce this issue. I also found a user reported this in the mail list: http://mail-archives.apache.org/mod_mbox/spark-user/201610.mbox/%3CCAAUq=H+YLUEpd23nwvq13Ms5hOStkhX3ao4f4zQV6sgO5zM-xAmail.gmail.com%3E This PR fixes the issue by using `compareAndSet` to make sure we will never return a closed worker to the work pool. ## How was this patch tested? Jenkins. Closes #22816 from zsxwing/fix-socket-closed. Authored-by: Shixiong Zhu Signed-off-by: Takuya UESHIN --- .../spark/api/python/PythonRunner.scala | 21 ++++++++++--------- .../execution/python/ArrowPythonRunner.scala | 4 ++-- .../execution/python/PythonUDFRunner.scala | 4 ++-- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 6e53a044e9a8c..f73e95eac8f79 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -106,15 +106,17 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( envVars.put("PYSPARK_EXECUTOR_MEMORY_MB", memoryMb.get.toString) } val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) - // Whether is the worker released into idle pool - val released = new AtomicBoolean(false) + // Whether is the worker released into idle pool or closed. When any codes try to release or + // close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make + // sure there is only one winner that is going to release or close the worker. + val releasedOrClosed = new AtomicBoolean(false) // Start a thread to feed the process input from our parent's iterator val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) context.addTaskCompletionListener[Unit] { _ => writerThread.shutdownOnTaskCompletion() - if (!reuseWorker || !released.get) { + if (!reuseWorker || releasedOrClosed.compareAndSet(false, true)) { try { worker.close() } catch { @@ -131,7 +133,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) val stdoutIterator = newReaderIterator( - stream, writerThread, startTime, env, worker, released, context) + stream, writerThread, startTime, env, worker, releasedOrClosed, context) new InterruptibleIterator(context, stdoutIterator) } @@ -148,7 +150,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( startTime: Long, env: SparkEnv, worker: Socket, - released: AtomicBoolean, + releasedOrClosed: AtomicBoolean, context: TaskContext): Iterator[OUT] /** @@ -392,7 +394,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( startTime: Long, env: SparkEnv, worker: Socket, - released: AtomicBoolean, + releasedOrClosed: AtomicBoolean, context: TaskContext) extends Iterator[OUT] { @@ -463,9 +465,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( } // Check whether the worker is ready to be re-used. if (stream.readInt() == SpecialLengths.END_OF_STREAM) { - if (reuseWorker) { + if (reuseWorker && releasedOrClosed.compareAndSet(false, true)) { env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) - released.set(true) } } eos = true @@ -565,9 +566,9 @@ private[spark] class PythonRunner(funcs: Seq[ChainedPythonFunctions]) startTime: Long, env: SparkEnv, worker: Socket, - released: AtomicBoolean, + releasedOrClosed: AtomicBoolean, context: TaskContext): Iterator[Array[Byte]] = { - new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) { + new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) { protected override def read(): Array[Byte] = { if (writerThread.exception.isDefined) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 18992d7a9f974..04623b1ab3c2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -117,9 +117,9 @@ class ArrowPythonRunner( startTime: Long, env: SparkEnv, worker: Socket, - released: AtomicBoolean, + releasedOrClosed: AtomicBoolean, context: TaskContext): Iterator[ColumnarBatch] = { - new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) { + new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) { private val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"stdin reader for $pythonExec", 0, Long.MaxValue) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index cc61faa7e7051..752d271c4cc35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -59,9 +59,9 @@ class PythonUDFRunner( startTime: Long, env: SparkEnv, worker: Socket, - released: AtomicBoolean, + releasedOrClosed: AtomicBoolean, context: TaskContext): Iterator[Array[Byte]] = { - new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) { + new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) { protected override def read(): Array[Byte] = { if (writerThread.exception.isDefined) { From 89d748b33c8636a1b1411c505921b0a585e1e6cb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 26 Oct 2018 13:17:24 +0800 Subject: [PATCH 1927/2461] [SPARK-25842][SQL] Deprecate rangeBetween APIs introduced in SPARK-21608 ## What changes were proposed in this pull request? See the detailed information at https://issues.apache.org/jira/browse/SPARK-25841 on why these APIs should be deprecated and redesigned. This patch also reverts https://github.com/apache/spark/commit/8acb51f08b448628b65e90af3b268994f9550e45 which applies to 2.4. ## How was this patch tested? Only deprecation and doc changes. Closes #22841 from rxin/SPARK-25842. Authored-by: Reynold Xin Signed-off-by: Wenchen Fan --- python/pyspark/sql/functions.py | 30 -------- python/pyspark/sql/window.py | 70 +++++-------------- .../apache/spark/sql/expressions/Window.scala | 46 +----------- .../spark/sql/expressions/WindowSpec.scala | 45 +----------- .../org/apache/spark/sql/functions.scala | 12 ++-- 5 files changed, 28 insertions(+), 175 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 8b2e423d250cd..739496b4ecb5e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -858,36 +858,6 @@ def ntile(n): return Column(sc._jvm.functions.ntile(int(n))) -@since(2.4) -def unboundedPreceding(): - """ - Window function: returns the special frame boundary that represents the first row - in the window partition. - """ - sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.unboundedPreceding()) - - -@since(2.4) -def unboundedFollowing(): - """ - Window function: returns the special frame boundary that represents the last row - in the window partition. - """ - sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.unboundedFollowing()) - - -@since(2.4) -def currentRow(): - """ - Window function: returns the special frame boundary that represents the current row - in the window partition. - """ - sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.currentRow()) - - # ---------------------- Date/Timestamp functions ------------------------------ @since(1.5) diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index d19ced954f04e..e76563dfaa9c8 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -16,11 +16,9 @@ # import sys -if sys.version >= '3': - long = int from pyspark import since, SparkContext -from pyspark.sql.column import Column, _to_seq, _to_java_column +from pyspark.sql.column import _to_seq, _to_java_column __all__ = ["Window", "WindowSpec"] @@ -126,45 +124,20 @@ def rangeBetween(start, end): and "5" means the five off after the current row. We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, - ``Window.currentRow``, ``pyspark.sql.functions.unboundedPreceding``, - ``pyspark.sql.functions.unboundedFollowing`` and ``pyspark.sql.functions.currentRow`` - to specify special boundary values, rather than using integral values directly. + and ``Window.currentRow`` to specify special boundary values, rather than using integral + values directly. :param start: boundary start, inclusive. - The frame is unbounded if this is ``Window.unboundedPreceding``, - a column returned by ``pyspark.sql.functions.unboundedPreceding``, or + The frame is unbounded if this is ``Window.unboundedPreceding``, or any value less than or equal to max(-sys.maxsize, -9223372036854775808). :param end: boundary end, inclusive. - The frame is unbounded if this is ``Window.unboundedFollowing``, - a column returned by ``pyspark.sql.functions.unboundedFollowing``, or + The frame is unbounded if this is ``Window.unboundedFollowing``, or any value greater than or equal to min(sys.maxsize, 9223372036854775807). - - >>> from pyspark.sql import functions as F, SparkSession, Window - >>> spark = SparkSession.builder.getOrCreate() - >>> df = spark.createDataFrame( - ... [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")], ["id", "category"]) - >>> window = Window.orderBy("id").partitionBy("category").rangeBetween( - ... F.currentRow(), F.lit(1)) - >>> df.withColumn("sum", F.sum("id").over(window)).show() - +---+--------+---+ - | id|category|sum| - +---+--------+---+ - | 1| b| 3| - | 2| b| 5| - | 3| b| 3| - | 1| a| 4| - | 1| a| 4| - | 2| a| 2| - +---+--------+---+ """ - if isinstance(start, (int, long)) and isinstance(end, (int, long)): - if start <= Window._PRECEDING_THRESHOLD: - start = Window.unboundedPreceding - if end >= Window._FOLLOWING_THRESHOLD: - end = Window.unboundedFollowing - elif isinstance(start, Column) and isinstance(end, Column): - start = start._jc - end = end._jc + if start <= Window._PRECEDING_THRESHOLD: + start = Window.unboundedPreceding + if end >= Window._FOLLOWING_THRESHOLD: + end = Window.unboundedFollowing sc = SparkContext._active_spark_context jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rangeBetween(start, end) return WindowSpec(jspec) @@ -239,34 +212,27 @@ def rangeBetween(self, start, end): and "5" means the five off after the current row. We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, - ``Window.currentRow``, ``pyspark.sql.functions.unboundedPreceding``, - ``pyspark.sql.functions.unboundedFollowing`` and ``pyspark.sql.functions.currentRow`` - to specify special boundary values, rather than using integral values directly. + and ``Window.currentRow`` to specify special boundary values, rather than using integral + values directly. :param start: boundary start, inclusive. - The frame is unbounded if this is ``Window.unboundedPreceding``, - a column returned by ``pyspark.sql.functions.unboundedPreceding``, or + The frame is unbounded if this is ``Window.unboundedPreceding``, or any value less than or equal to max(-sys.maxsize, -9223372036854775808). :param end: boundary end, inclusive. - The frame is unbounded if this is ``Window.unboundedFollowing``, - a column returned by ``pyspark.sql.functions.unboundedFollowing``, or + The frame is unbounded if this is ``Window.unboundedFollowing``, or any value greater than or equal to min(sys.maxsize, 9223372036854775807). """ - if isinstance(start, (int, long)) and isinstance(end, (int, long)): - if start <= Window._PRECEDING_THRESHOLD: - start = Window.unboundedPreceding - if end >= Window._FOLLOWING_THRESHOLD: - end = Window.unboundedFollowing - elif isinstance(start, Column) and isinstance(end, Column): - start = start._jc - end = end._jc + if start <= Window._PRECEDING_THRESHOLD: + start = Window.unboundedPreceding + if end >= Window._FOLLOWING_THRESHOLD: + end = Window.unboundedFollowing return WindowSpec(self._jspec.rangeBetween(start, end)) def _test(): import doctest SparkContext('local[4]', 'PythonTest') - (failure_count, test_count) = doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE) + (failure_count, test_count) = doctest.testmod() if failure_count: sys.exit(-1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index cd819bab1b14c..14dec8f0810f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -215,52 +215,10 @@ object Window { } /** - * Creates a [[WindowSpec]] with the frame boundaries defined, - * from `start` (inclusive) to `end` (inclusive). - * - * Both `start` and `end` are relative to the current row. For example, "lit(0)" means - * "current row", while "lit(-1)" means one off before the current row, and "lit(5)" means the - * five off after the current row. - * - * Users should use `unboundedPreceding()`, `unboundedFollowing()`, and `currentRow()` from - * [[org.apache.spark.sql.functions]] to specify special boundary values, literals are not - * transformed to [[org.apache.spark.sql.catalyst.expressions.SpecialFrameBoundary]]s. - * - * A range-based boundary is based on the actual value of the ORDER BY - * expression(s). An offset is used to alter the value of the ORDER BY expression, for - * instance if the current order by expression has a value of 10 and the lower bound offset - * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a - * number of constraints on the ORDER BY expressions: there can be only one expression and this - * expression must have a numerical/date/timestamp data type. An exception can be made when the - * offset is unbounded, because no value modification is needed, in this case multiple and - * non-numerical/date/timestamp data type ORDER BY expression are allowed. - * - * {{{ - * import org.apache.spark.sql.expressions.Window - * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) - * .toDF("id", "category") - * val byCategoryOrderedById = - * Window.partitionBy('category).orderBy('id).rangeBetween(currentRow(), lit(1)) - * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() - * - * +---+--------+---+ - * | id|category|sum| - * +---+--------+---+ - * | 1| b| 3| - * | 2| b| 5| - * | 3| b| 3| - * | 1| a| 4| - * | 1| a| 4| - * | 2| a| 2| - * +---+--------+---+ - * }}} - * - * @param start boundary start, inclusive. The frame is unbounded if the expression is - * [[org.apache.spark.sql.catalyst.expressions.UnboundedPreceding]]. - * @param end boundary end, inclusive. The frame is unbounded if the expression is - * [[org.apache.spark.sql.catalyst.expressions.UnboundedFollowing]]. + * This function has been deprecated in Spark 2.4. See SPARK-25842 for more information. * @since 2.3.0 */ + @deprecated("Use the version with Long parameter types", "2.4.0") def rangeBetween(start: Column, end: Column): WindowSpec = { spec.rangeBetween(start, end) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 4c41aa3c5fb67..0cc43a58237df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -210,51 +210,10 @@ class WindowSpec private[sql]( } /** - * Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). - * - * Both `start` and `end` are relative to the current row. For example, "lit(0)" means - * "current row", while "lit(-1)" means one off before the current row, and "lit(5)" means the - * five off after the current row. - * - * Users should use `unboundedPreceding()`, `unboundedFollowing()`, and `currentRow()` from - * [[org.apache.spark.sql.functions]] to specify special boundary values, literals are not - * transformed to [[org.apache.spark.sql.catalyst.expressions.SpecialFrameBoundary]]s. - * - * A range-based boundary is based on the actual value of the ORDER BY - * expression(s). An offset is used to alter the value of the ORDER BY expression, for - * instance if the current order by expression has a value of 10 and the lower bound offset - * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a - * number of constraints on the ORDER BY expressions: there can be only one expression and this - * expression must have a numerical/date/timestamp data type. An exception can be made when the - * offset is unbounded, because no value modification is needed, in this case multiple and - * non-numerical/date/timestamp data type ORDER BY expression are allowed. - * - * {{{ - * import org.apache.spark.sql.expressions.Window - * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) - * .toDF("id", "category") - * val byCategoryOrderedById = - * Window.partitionBy('category).orderBy('id).rangeBetween(currentRow(), lit(1)) - * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() - * - * +---+--------+---+ - * | id|category|sum| - * +---+--------+---+ - * | 1| b| 3| - * | 2| b| 5| - * | 3| b| 3| - * | 1| a| 4| - * | 1| a| 4| - * | 2| a| 2| - * +---+--------+---+ - * }}} - * - * @param start boundary start, inclusive. The frame is unbounded if the expression is - * [[org.apache.spark.sql.catalyst.expressions.UnboundedPreceding]]. - * @param end boundary end, inclusive. The frame is unbounded if the expression is - * [[org.apache.spark.sql.catalyst.expressions.UnboundedFollowing]]. + * This function has been deprecated in Spark 2.4. See SPARK-25842 for more information. * @since 2.3.0 */ + @deprecated("Use the version with Long parameter types", "2.4.0") def rangeBetween(start: Column, end: Column): WindowSpec = { new WindowSpec( partitionSpec, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index dbf1f239e2d21..2748e64723a78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -830,30 +830,30 @@ object functions { // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Window function: returns the special frame boundary that represents the first row in the - * window partition. + * This function has been deprecated in Spark 2.4. See SPARK-25842 for more information. * * @group window_funcs * @since 2.3.0 */ + @deprecated("Use Window.unboundedPreceding", "2.4.0") def unboundedPreceding(): Column = Column(UnboundedPreceding) /** - * Window function: returns the special frame boundary that represents the last row in the - * window partition. + * This function has been deprecated in Spark 2.4. See SPARK-25842 for more information. * * @group window_funcs * @since 2.3.0 */ + @deprecated("Use Window.unboundedFollowing", "2.4.0") def unboundedFollowing(): Column = Column(UnboundedFollowing) /** - * Window function: returns the special frame boundary that represents the current row in the - * window partition. + * This function has been deprecated in Spark 2.4. See SPARK-25842 for more information. * * @group window_funcs * @since 2.3.0 */ + @deprecated("Use Window.currentRow", "2.4.0") def currentRow(): Column = Column(CurrentRow) /** From 6fd5ff3951ed9ac7c0b20f2666d8bc39929bfb5c Mon Sep 17 00:00:00 2001 From: seancxmao Date: Fri, 26 Oct 2018 18:53:55 +0800 Subject: [PATCH 1928/2461] [SPARK-25797][SQL][DOCS] Add migration doc for solving issues caused by view canonicalization approach change ## What changes were proposed in this pull request? Since Spark 2.2, view definitions are stored in a different way from prior versions. This may cause Spark unable to read views created by prior versions. See [SPARK-25797](https://issues.apache.org/jira/browse/SPARK-25797) for more details. Basically, we have 2 options. 1) Make Spark 2.2+ able to get older view definitions back. Since the expanded text is buggy and unusable, we have to use original text (this is possible with [SPARK-25459](https://issues.apache.org/jira/browse/SPARK-25459)). However, because older Spark versions don't save the context for the database, we cannot always get correct view definitions without view default database. 2) Recreate the views by `ALTER VIEW AS` or `CREATE OR REPLACE VIEW AS`. This PR aims to add migration doc to help users troubleshoot this issue by above option 2. ## How was this patch tested? N/A. Docs are generated and checked locally ``` cd docs SKIP_API=1 jekyll serve --watch ``` Closes #22846 from seancxmao/SPARK-25797. Authored-by: seancxmao Signed-off-by: Wenchen Fan --- docs/sql-migration-guide-upgrade.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index dfa35b88369cb..38c03d36caade 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -304,6 +304,8 @@ displayTitle: Spark SQL Upgrading Guide - Since Spark 2.2.1 and 2.3.0, the schema is always inferred at runtime when the data source tables have the columns that exist in both partition schema and data schema. The inferred schema does not have the partitioned columns. When reading the table, Spark respects the partition values of these overlapping columns instead of the values stored in the data source files. In 2.2.0 and 2.1.x release, the inferred schema is partitioned but the data of the table is invisible to users (i.e., the result set is empty). + - Since Spark 2.2, view definitions are stored in a different way from prior versions. This may cause Spark unable to read views created by prior versions. In such cases, you need to recreate the views using `ALTER VIEW AS` or `CREATE OR REPLACE VIEW AS` with newer Spark versions. + ## Upgrading From Spark SQL 2.0 to 2.1 - Datasource tables now store partition metadata in the Hive metastore. This means that Hive DDLs such as `ALTER TABLE PARTITION ... SET LOCATION` are now available for tables created with the Datasource API. From 7d44bc26408b2189804fd305797afcefb7b2b0e0 Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Fri, 26 Oct 2018 08:49:27 -0500 Subject: [PATCH 1929/2461] [SPARK-25835][K8S] Create kubernetes-tests profile and use the detected SCALA_VERSION ## What changes were proposed in this pull request? - Fixes the scala version propagation issue. - Disables the tests under the k8s profile, now we will run them manually. Adds a test specific profile otherwise tests will not run if we just remove the module from the kubernetes profile (quickest solution I can think of). ## How was this patch tested? Manually by running the tests with different versions of scala. Closes #22838 from skonto/propagate-scala2.12. Authored-by: Stavros Kontopoulos Signed-off-by: Sean Owen --- pom.xml | 7 +++++++ .../integration-tests/dev/dev-run-integration-tests.sh | 3 ++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 92934c125f783..597fb2fa1abd6 100644 --- a/pom.xml +++ b/pom.xml @@ -2656,6 +2656,13 @@ kubernetes resource-managers/kubernetes/core + +
      + + + + kubernetes-integration-tests + resource-managers/kubernetes/integration-tests diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh index e26c0b3a39c90..c3c843e001f21 100755 --- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -28,6 +28,7 @@ NAMESPACE= SERVICE_ACCOUNT= INCLUDE_TAGS="k8s" EXCLUDE_TAGS= +SCALA_VERSION="$($TEST_ROOT_DIR/build/mvn org.apache.maven.plugins:maven-help-plugin:2.1.1:evaluate -Dexpression=scala.binary.version | grep -v '\[' )" # Parse arguments while (( "$#" )); do @@ -103,4 +104,4 @@ then properties=( ${properties[@]} -Dtest.exclude.tags=$EXCLUDE_TAGS ) fi -$TEST_ROOT_DIR/build/mvn integration-test -f $TEST_ROOT_DIR/pom.xml -pl resource-managers/kubernetes/integration-tests -am -Pkubernetes ${properties[@]} +$TEST_ROOT_DIR/build/mvn integration-test -f $TEST_ROOT_DIR/pom.xml -pl resource-managers/kubernetes/integration-tests -am -Pscala-$SCALA_VERSION -Pkubernetes -Pkubernetes-integration-tests ${properties[@]} From 33e337c1180a12edf1ae97f0221e389f23192461 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 26 Oct 2018 22:14:43 +0800 Subject: [PATCH 1930/2461] [SPARK-24709][SQL][FOLLOW-UP] Make schema_of_json's input json as literal only ## What changes were proposed in this pull request? The main purpose of `schema_of_json` is the usage of combination with `from_json` (to make up the leak of schema inference) which takes its schema only as literal; however, currently `schema_of_json` allows JSON input as non-literal expressions (e.g, column). This was mistakenly allowed - we don't have to take other usages rather then the main purpose into account for now. This PR makes a followup to only allow literals for `schema_of_json`'s JSON input. We can allow non literal expressions later when it's needed or there are some usecase for it. ## How was this patch tested? Unit tests were added. Closes #22775 from HyukjinKwon/SPARK-25447-followup. Lead-authored-by: hyukjinkwon Co-authored-by: Hyukjin Kwon Signed-off-by: Wenchen Fan --- python/pyspark/sql/functions.py | 22 ++++++------ .../expressions/jsonExpressions.scala | 21 ++++++++--- .../org/apache/spark/sql/functions.scala | 24 +++++++++---- .../sql-tests/inputs/json-functions.sql | 6 +++- .../sql-tests/results/json-functions.sql.out | 36 ++++++++++++++++++- .../apache/spark/sql/JsonFunctionsSuite.scala | 2 +- 6 files changed, 87 insertions(+), 24 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 739496b4ecb5e..ca2a256983d67 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2335,30 +2335,32 @@ def to_json(col, options={}): @ignore_unicode_prefix @since(2.4) -def schema_of_json(col, options={}): +def schema_of_json(json, options={}): """ - Parses a column containing a JSON string and infers its schema in DDL format. + Parses a JSON string and infers its schema in DDL format. - :param col: string column in json format + :param json: a JSON string or a string literal containing a JSON string. :param options: options to control parsing. accepts the same options as the JSON datasource .. versionchanged:: 3.0 It accepts `options` parameter to control schema inferring. - >>> from pyspark.sql.types import * - >>> data = [(1, '{"a": 1}')] - >>> df = spark.createDataFrame(data, ("key", "value")) - >>> df.select(schema_of_json(df.value).alias("json")).collect() - [Row(json=u'struct')] + >>> df = spark.range(1) >>> df.select(schema_of_json(lit('{"a": 0}')).alias("json")).collect() [Row(json=u'struct')] - >>> schema = schema_of_json(lit('{a: 1}'), {'allowUnquotedFieldNames':'true'}) + >>> schema = schema_of_json('{a: 1}', {'allowUnquotedFieldNames':'true'}) >>> df.select(schema.alias("json")).collect() [Row(json=u'struct')] """ + if isinstance(json, basestring): + col = _create_column_from_literal(json) + elif isinstance(json, Column): + col = _to_java_column(json) + else: + raise TypeError("schema argument should be a column or string") sc = SparkContext._active_spark_context - jc = sc._jvm.functions.schema_of_json(_to_java_column(col), options) + jc = sc._jvm.functions.schema_of_json(col, options) return Column(jc) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index e966924293cf7..77af5906010f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -742,7 +742,7 @@ case class StructsToJson( case class SchemaOfJson( child: Expression, options: Map[String, String]) - extends UnaryExpression with String2StringExpression with CodegenFallback { + extends UnaryExpression with CodegenFallback { def this(child: Expression) = this(child, Map.empty[String, String]) @@ -750,6 +750,10 @@ case class SchemaOfJson( child = child, options = ExprUtils.convertToMapData(options)) + override def dataType: DataType = StringType + + override def nullable: Boolean = false + @transient private lazy val jsonOptions = new JSONOptions(options, "UTC") @@ -760,8 +764,17 @@ case class SchemaOfJson( factory } - override def convert(v: UTF8String): UTF8String = { - val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, v)) { parser => + @transient + private lazy val json = child.eval().asInstanceOf[UTF8String] + + override def checkInputDataTypes(): TypeCheckResult = child match { + case Literal(s, StringType) if s != null => super.checkInputDataTypes() + case _ => TypeCheckResult.TypeCheckFailure( + s"The input json should be a string literal and not null; however, got ${child.sql}.") + } + + override def eval(v: InternalRow): Any = { + val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser => parser.nextToken() inferField(parser, jsonOptions) } @@ -776,7 +789,7 @@ object JsonExprUtils { def evalSchemaExpr(exp: Expression): DataType = exp match { case Literal(s, StringType) => DataType.fromDDL(s.toString) case e @ SchemaOfJson(_: Literal, _) => - val ddlSchema = e.eval().asInstanceOf[UTF8String] + val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String] DataType.fromDDL(ddlSchema.toString) case e => throw new AnalysisException( "Schema should be specified in DDL format as a string literal" + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 2748e64723a78..757a3226855c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3626,19 +3626,29 @@ object functions { } /** - * Parses a column containing a JSON string and infers its schema. + * Parses a JSON string and infers its schema in DDL format. * - * @param e a string column containing JSON data. + * @param json a JSON string. * * @group collection_funcs * @since 2.4.0 */ - def schema_of_json(e: Column): Column = withExpr(new SchemaOfJson(e.expr)) + def schema_of_json(json: String): Column = schema_of_json(lit(json)) /** - * Parses a column containing a JSON string and infers its schema using options. + * Parses a JSON string and infers its schema in DDL format. * - * @param e a string column containing JSON data. + * @param json a string literal containing a JSON string. + * + * @group collection_funcs + * @since 2.4.0 + */ + def schema_of_json(json: Column): Column = withExpr(new SchemaOfJson(json.expr)) + + /** + * Parses a JSON string and infers its schema in DDL format using options. + * + * @param json a string column containing JSON data. * @param options options to control how the json is parsed. accepts the same options and the * json data source. See [[DataFrameReader#json]]. * @return a column with string literal containing schema in DDL format. @@ -3646,8 +3656,8 @@ object functions { * @group collection_funcs * @since 3.0.0 */ - def schema_of_json(e: Column, options: java.util.Map[String, String]): Column = { - withExpr(SchemaOfJson(e.expr, options.asScala.toMap)) + def schema_of_json(json: Column, options: java.util.Map[String, String]): Column = { + withExpr(SchemaOfJson(json.expr, options.asScala.toMap)) } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql index 8bfd7c0774398..6c14eee2e4e61 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -55,4 +55,8 @@ select to_json(array(array(1, 2, 3), array(4))); -- infer schema of json literal using options select schema_of_json('{"c1":1}', map('primitivesAsString', 'true')); select schema_of_json('{"c1":01, "c2":0.1}', map('allowNumericLeadingZeros', 'true', 'prefersDecimal', 'true')); - +select schema_of_json(null); +CREATE TEMPORARY VIEW jsonTable(jsonField, a) AS SELECT * FROM VALUES ('{"a": 1, "b": 2}', 'a'); +SELECT schema_of_json(jsonField) FROM jsonTable; +-- Clean up +DROP VIEW IF EXISTS jsonTable; diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index c70a81ea28aa5..ca0cd90d94fa7 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 38 +-- Number of queries: 42 -- !query 0 @@ -318,3 +318,37 @@ select schema_of_json('{"c1":01, "c2":0.1}', map('allowNumericLeadingZeros', 'tr struct -- !query 37 output struct + + +-- !query 38 +select schema_of_json(null) +-- !query 38 schema +struct<> +-- !query 38 output +org.apache.spark.sql.AnalysisException +cannot resolve 'schema_of_json(NULL)' due to data type mismatch: The input json should be a string literal and not null; however, got NULL.; line 1 pos 7 + + +-- !query 39 +CREATE TEMPORARY VIEW jsonTable(jsonField, a) AS SELECT * FROM VALUES ('{"a": 1, "b": 2}', 'a') +-- !query 39 schema +struct<> +-- !query 39 output + + + +-- !query 40 +SELECT schema_of_json(jsonField) FROM jsonTable +-- !query 40 schema +struct<> +-- !query 40 output +org.apache.spark.sql.AnalysisException +cannot resolve 'schema_of_json(jsontable.`jsonField`)' due to data type mismatch: The input json should be a string literal and not null; however, got jsontable.`jsonField`.; line 1 pos 7 + + +-- !query 41 +DROP VIEW IF EXISTS jsonTable +-- !query 41 schema +struct<> +-- !query 41 output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 797b274f42cdd..2b09782faeeaa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -395,7 +395,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { test("SPARK-24709: infers schemas of json strings and pass them to from_json") { val in = Seq("""{"a": [1, 2, 3]}""").toDS() - val out = in.select(from_json('value, schema_of_json(lit("""{"a": [1]}"""))) as "parsed") + val out = in.select(from_json('value, schema_of_json("""{"a": [1]}""")) as "parsed") val expected = StructType(StructField( "parsed", StructType(StructField( From f1891ff1e3f03668ac21b352b009bfea5e3c2b7f Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 26 Oct 2018 09:48:17 -0500 Subject: [PATCH 1931/2461] [SPARK-25760][DOCS][FOLLOWUP] Add note about AddJar return value change in migration guide ## What changes were proposed in this pull request? Add note about AddJar return value change in migration guide ## How was this patch tested? n/a Closes #22826 from srowen/SPARK-25760.2. Authored-by: Sean Owen Signed-off-by: Sean Owen --- docs/sql-migration-guide-upgrade.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 38c03d36caade..c9685b866774f 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -15,6 +15,8 @@ displayTitle: Spark SQL Upgrading Guide - Since Spark 3.0, the `from_json` functions supports two modes - `PERMISSIVE` and `FAILFAST`. The modes can be set via the `mode` option. The default mode became `PERMISSIVE`. In previous versions, behavior of `from_json` did not conform to either `PERMISSIVE` nor `FAILFAST`, especially in processing of malformed JSON records. For example, the JSON string `{"a" 1}` with the schema `a INT` is converted to `null` by previous versions but Spark 3.0 converts it to `Row(null)`. + - The `ADD JAR` command previously returned a result set with the single value 0. It now returns an empty result set. + ## Upgrading From Spark SQL 2.3 to 2.4 - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below. From d367bdcf521f564d2d7066257200be26b27ea926 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 26 Oct 2018 09:40:13 -0700 Subject: [PATCH 1932/2461] [SPARK-25255][PYTHON] Add getActiveSession to SparkSession in PySpark ## What changes were proposed in this pull request? add getActiveSession in session.py ## How was this patch tested? add doctest Closes #22295 from huaxingao/spark25255. Authored-by: Huaxin Gao Signed-off-by: Holden Karau --- python/pyspark/sql/session.py | 30 +++++++ python/pyspark/sql/tests.py | 151 ++++++++++++++++++++++++++++++++++ 2 files changed, 181 insertions(+) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 079af8c05705d..6f4b32757314d 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -192,6 +192,7 @@ def getOrCreate(self): """A class attribute having a :class:`Builder` to construct :class:`SparkSession` instances""" _instantiatedSession = None + _activeSession = None @ignore_unicode_prefix def __init__(self, sparkContext, jsparkSession=None): @@ -233,7 +234,9 @@ def __init__(self, sparkContext, jsparkSession=None): if SparkSession._instantiatedSession is None \ or SparkSession._instantiatedSession._sc._jsc is None: SparkSession._instantiatedSession = self + SparkSession._activeSession = self self._jvm.SparkSession.setDefaultSession(self._jsparkSession) + self._jvm.SparkSession.setActiveSession(self._jsparkSession) def _repr_html_(self): return """ @@ -255,6 +258,29 @@ def newSession(self): """ return self.__class__(self._sc, self._jsparkSession.newSession()) + @classmethod + @since(3.0) + def getActiveSession(cls): + """ + Returns the active SparkSession for the current thread, returned by the builder. + >>> s = SparkSession.getActiveSession() + >>> l = [('Alice', 1)] + >>> rdd = s.sparkContext.parallelize(l) + >>> df = s.createDataFrame(rdd, ['name', 'age']) + >>> df.select("age").collect() + [Row(age=1)] + """ + from pyspark import SparkContext + sc = SparkContext._active_spark_context + if sc is None: + return None + else: + if sc._jvm.SparkSession.getActiveSession().isDefined(): + SparkSession(sc, sc._jvm.SparkSession.getActiveSession().get()) + return SparkSession._activeSession + else: + return None + @property @since(2.0) def sparkContext(self): @@ -671,6 +697,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr ... Py4JJavaError: ... """ + SparkSession._activeSession = self + self._jvm.SparkSession.setActiveSession(self._jsparkSession) if isinstance(data, DataFrame): raise TypeError("data is already a DataFrame") @@ -826,7 +854,9 @@ def stop(self): self._sc.stop() # We should clean the default session up. See SPARK-23228. self._jvm.SparkSession.clearDefaultSession() + self._jvm.SparkSession.clearActiveSession() SparkSession._instantiatedSession = None + SparkSession._activeSession = None @since(2.0) def __enter__(self): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 82dc5a6751ad2..ad04270c1a361 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3985,6 +3985,157 @@ def test_jvm_default_session_already_set(self): spark.stop() +class SparkSessionTests2(unittest.TestCase): + + def test_active_session(self): + spark = SparkSession.builder \ + .master("local") \ + .getOrCreate() + try: + activeSession = SparkSession.getActiveSession() + df = activeSession.createDataFrame([(1, 'Alice')], ['age', 'name']) + self.assertEqual(df.collect(), [Row(age=1, name=u'Alice')]) + finally: + spark.stop() + + def test_get_active_session_when_no_active_session(self): + active = SparkSession.getActiveSession() + self.assertEqual(active, None) + spark = SparkSession.builder \ + .master("local") \ + .getOrCreate() + active = SparkSession.getActiveSession() + self.assertEqual(active, spark) + spark.stop() + active = SparkSession.getActiveSession() + self.assertEqual(active, None) + + def test_SparkSession(self): + spark = SparkSession.builder \ + .master("local") \ + .config("some-config", "v2") \ + .getOrCreate() + try: + self.assertEqual(spark.conf.get("some-config"), "v2") + self.assertEqual(spark.sparkContext._conf.get("some-config"), "v2") + self.assertEqual(spark.version, spark.sparkContext.version) + spark.sql("CREATE DATABASE test_db") + spark.catalog.setCurrentDatabase("test_db") + self.assertEqual(spark.catalog.currentDatabase(), "test_db") + spark.sql("CREATE TABLE table1 (name STRING, age INT) USING parquet") + self.assertEqual(spark.table("table1").columns, ['name', 'age']) + self.assertEqual(spark.range(3).count(), 3) + finally: + spark.stop() + + def test_global_default_session(self): + spark = SparkSession.builder \ + .master("local") \ + .getOrCreate() + try: + self.assertEqual(SparkSession.builder.getOrCreate(), spark) + finally: + spark.stop() + + def test_default_and_active_session(self): + spark = SparkSession.builder \ + .master("local") \ + .getOrCreate() + activeSession = spark._jvm.SparkSession.getActiveSession() + defaultSession = spark._jvm.SparkSession.getDefaultSession() + try: + self.assertEqual(activeSession, defaultSession) + finally: + spark.stop() + + def test_config_option_propagated_to_existing_session(self): + session1 = SparkSession.builder \ + .master("local") \ + .config("spark-config1", "a") \ + .getOrCreate() + self.assertEqual(session1.conf.get("spark-config1"), "a") + session2 = SparkSession.builder \ + .config("spark-config1", "b") \ + .getOrCreate() + try: + self.assertEqual(session1, session2) + self.assertEqual(session1.conf.get("spark-config1"), "b") + finally: + session1.stop() + + def test_new_session(self): + session = SparkSession.builder \ + .master("local") \ + .getOrCreate() + newSession = session.newSession() + try: + self.assertNotEqual(session, newSession) + finally: + session.stop() + newSession.stop() + + def test_create_new_session_if_old_session_stopped(self): + session = SparkSession.builder \ + .master("local") \ + .getOrCreate() + session.stop() + newSession = SparkSession.builder \ + .master("local") \ + .getOrCreate() + try: + self.assertNotEqual(session, newSession) + finally: + newSession.stop() + + def test_active_session_with_None_and_not_None_context(self): + from pyspark.context import SparkContext + from pyspark.conf import SparkConf + sc = None + session = None + try: + sc = SparkContext._active_spark_context + self.assertEqual(sc, None) + activeSession = SparkSession.getActiveSession() + self.assertEqual(activeSession, None) + sparkConf = SparkConf() + sc = SparkContext.getOrCreate(sparkConf) + activeSession = sc._jvm.SparkSession.getActiveSession() + self.assertFalse(activeSession.isDefined()) + session = SparkSession(sc) + activeSession = sc._jvm.SparkSession.getActiveSession() + self.assertTrue(activeSession.isDefined()) + activeSession2 = SparkSession.getActiveSession() + self.assertNotEqual(activeSession2, None) + finally: + if session is not None: + session.stop() + if sc is not None: + sc.stop() + + +class SparkSessionTests3(ReusedSQLTestCase): + + def test_get_active_session_after_create_dataframe(self): + session2 = None + try: + activeSession1 = SparkSession.getActiveSession() + session1 = self.spark + self.assertEqual(session1, activeSession1) + session2 = self.spark.newSession() + activeSession2 = SparkSession.getActiveSession() + self.assertEqual(session1, activeSession2) + self.assertNotEqual(session2, activeSession2) + session2.createDataFrame([(1, 'Alice')], ['age', 'name']) + activeSession3 = SparkSession.getActiveSession() + self.assertEqual(session2, activeSession3) + session1.createDataFrame([(1, 'Alice')], ['age', 'name']) + activeSession4 = SparkSession.getActiveSession() + self.assertEqual(session1, activeSession4) + finally: + if session2 is not None: + session2.stop() + + class UDFInitializationTests(unittest.TestCase): def tearDown(self): if SparkSession._instantiatedSession is not None: From 6aa506394958bfb30cd2a9085a5e8e8be927de51 Mon Sep 17 00:00:00 2001 From: shane knapp Date: Fri, 26 Oct 2018 16:37:36 -0500 Subject: [PATCH 1933/2461] [SPARK-25854][BUILD] fix `build/mvn` not to fail during Zinc server shutdown ## What changes were proposed in this pull request? the final line in the mvn helper script in build/ attempts to shut down the zinc server. due to the zinc server being set up w/a 30min timeout, by the time the mvn test instantiation finishes, the server times out. this means that when the mvn script tries to shut down zinc, it returns w/an exit code of 1. this will then automatically fail the entire build (even if the build passes). ## How was this patch tested? i set up a test build: https://amplab.cs.berkeley.edu/jenkins/job/sknapp-testing-spark-branch-2.4-test-maven-hadoop-2.7/ Closes #22854 from shaneknapp/fix-mvn-helper-script. Authored-by: shane knapp Signed-off-by: Sean Owen --- build/mvn | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/build/mvn b/build/mvn index b60ea644b262d..3816993b4e5c8 100755 --- a/build/mvn +++ b/build/mvn @@ -153,7 +153,7 @@ if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`"${ZINC_BIN}" -status -port ${ZINC_PORT}` export ZINC_OPTS=${ZINC_OPTS:-"$_COMPILE_JVM_OPTS"} "${ZINC_BIN}" -shutdown -port ${ZINC_PORT} "${ZINC_BIN}" -start -port ${ZINC_PORT} \ - -server 127.0.0.1 -idle-timeout 30m \ + -server 127.0.0.1 -idle-timeout 3h \ -scala-compiler "${SCALA_COMPILER}" \ -scala-library "${SCALA_LIBRARY}" &>/dev/null fi @@ -163,8 +163,12 @@ export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"} echo "Using \`mvn\` from path: $MVN_BIN" 1>&2 -# Last, call the `mvn` command as usual +# call the `mvn` command as usual +# SPARK-25854 "${MVN_BIN}" -DzincPort=${ZINC_PORT} "$@" +MVN_RETCODE=$? -# Try to shut down zinc explicitly +# Try to shut down zinc explicitly if the server is still running. "${ZINC_BIN}" -shutdown -port ${ZINC_PORT} + +exit $MVN_RETCODE From d325ffbf3a6b3555cbe5a3004ffb4dde41bff363 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 26 Oct 2018 16:45:56 -0500 Subject: [PATCH 1934/2461] [SPARK-25851][SQL][MINOR] Fix deprecated API warning in SQLListener ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/21596, Jackson is upgraded to 2.9.6. There are some deprecated API warnings in SQLListener. Create a trivial PR to fix them. ``` [warn] SQLListener.scala:92: method uncheckedSimpleType in class TypeFactory is deprecated: see corresponding Javadoc for more information. [warn] val objectType = typeFactory.uncheckedSimpleType(classOf[Object]) [warn] [warn] SQLListener.scala:93: method constructSimpleType in class TypeFactory is deprecated: see corresponding Javadoc for more information. [warn] typeFactory.constructSimpleType(classOf[(_, _)], classOf[(_, _)], Array(objectType, objectType)) [warn] [warn] SQLListener.scala:97: method uncheckedSimpleType in class TypeFactory is deprecated: see corresponding Javadoc for more information. [warn] val longType = typeFactory.uncheckedSimpleType(classOf[Long]) [warn] [warn] SQLListener.scala:98: method constructSimpleType in class TypeFactory is deprecated: see corresponding Javadoc for more information. [warn] typeFactory.constructSimpleType(classOf[(_, _)], classOf[(_, _)], Array(longType, longType)) ``` ## How was this patch tested? Existing unit tests. Closes #22848 from gengliangwang/fixSQLListenerWarning. Authored-by: Gengliang Wang Signed-off-by: Sean Owen --- .../org/apache/spark/sql/execution/ui/SQLListener.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index c04a31c428d11..03d75c4c1b82f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -89,12 +89,12 @@ private class LongLongTupleConverter extends Converter[(Object, Object), (Long, } override def getInputType(typeFactory: TypeFactory): JavaType = { - val objectType = typeFactory.uncheckedSimpleType(classOf[Object]) - typeFactory.constructSimpleType(classOf[(_, _)], classOf[(_, _)], Array(objectType, objectType)) + val objectType = typeFactory.constructType(classOf[Object]) + typeFactory.constructSimpleType(classOf[(_, _)], Array(objectType, objectType)) } override def getOutputType(typeFactory: TypeFactory): JavaType = { - val longType = typeFactory.uncheckedSimpleType(classOf[Long]) - typeFactory.constructSimpleType(classOf[(_, _)], classOf[(_, _)], Array(longType, longType)) + val longType = typeFactory.constructType(classOf[Long]) + typeFactory.constructSimpleType(classOf[(_, _)], Array(longType, longType)) } } From ca545f79410a464ef24e3986fac225f53bb2ef02 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 26 Oct 2018 16:49:48 -0500 Subject: [PATCH 1935/2461] [SPARK-25821][SQL] Remove SQLContext methods deprecated in 1.4 ## What changes were proposed in this pull request? Remove SQLContext methods deprecated in 1.4 ## How was this patch tested? Existing tests. Closes #22815 from srowen/SPARK-25821. Authored-by: Sean Owen Signed-off-by: Sean Owen --- R/pkg/NAMESPACE | 2 - R/pkg/R/SQLContext.R | 61 +--- R/pkg/tests/fulltests/test_sparkSQL.R | 25 +- docs/sparkr.md | 6 +- .../org/apache/spark/sql/SQLContext.scala | 283 ------------------ 5 files changed, 8 insertions(+), 369 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 36d7a9b38b37e..5a5dc20ff3b78 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -420,13 +420,11 @@ export("as.DataFrame", "currentDatabase", "dropTempTable", "dropTempView", - "jsonFile", "listColumns", "listDatabases", "listFunctions", "listTables", "loadDF", - "parquetFile", "read.df", "read.jdbc", "read.json", diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index c819a7d14ae98..3f89ee99e2564 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -343,7 +343,6 @@ setMethod("toDF", signature(x = "RDD"), #' path <- "path/to/file.json" #' df <- read.json(path) #' df <- read.json(path, multiLine = TRUE) -#' df <- jsonFile(path) #' } #' @name read.json #' @method read.json default @@ -363,51 +362,6 @@ read.json <- function(x, ...) { dispatchFunc("read.json(path)", x, ...) } -#' @rdname read.json -#' @name jsonFile -#' @method jsonFile default -#' @note jsonFile since 1.4.0 -jsonFile.default <- function(path) { - .Deprecated("read.json") - read.json(path) -} - -jsonFile <- function(x, ...) { - dispatchFunc("jsonFile(path)", x, ...) -} - -#' JSON RDD -#' -#' Loads an RDD storing one JSON object per string as a SparkDataFrame. -#' -#' @param sqlContext SQLContext to use -#' @param rdd An RDD of JSON string -#' @param schema A StructType object to use as schema -#' @param samplingRatio The ratio of simpling used to infer the schema -#' @return A SparkDataFrame -#' @noRd -#' @examples -#'\dontrun{ -#' sparkR.session() -#' rdd <- texFile(sc, "path/to/json") -#' df <- jsonRDD(sqlContext, rdd) -#'} - -# TODO: remove - this method is no longer exported -# TODO: support schema -jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { - .Deprecated("read.json") - rdd <- serializeToString(rdd) - if (is.null(schema)) { - read <- callJMethod(sqlContext, "read") - # samplingRatio is deprecated - sdf <- callJMethod(read, "json", callJMethod(getJRDD(rdd), "rdd")) - dataFrame(sdf) - } else { - stop("not implemented") - } -} - #' Create a SparkDataFrame from an ORC file. #' #' Loads an ORC file, returning the result as a SparkDataFrame. @@ -434,6 +388,7 @@ read.orc <- function(path, ...) { #' Loads a Parquet file, returning the result as a SparkDataFrame. #' #' @param path path of file to read. A vector of multiple paths is allowed. +#' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.parquet #' @name read.parquet @@ -454,20 +409,6 @@ read.parquet <- function(x, ...) { dispatchFunc("read.parquet(...)", x, ...) } -#' @param ... argument(s) passed to the method. -#' @rdname read.parquet -#' @name parquetFile -#' @method parquetFile default -#' @note parquetFile since 1.4.0 -parquetFile.default <- function(...) { - .Deprecated("read.parquet") - read.parquet(unlist(list(...))) -} - -parquetFile <- function(x, ...) { - dispatchFunc("parquetFile(...)", x, ...) -} - #' Create a SparkDataFrame from a text file. #' #' Loads text files and returns a SparkDataFrame whose schema starts with diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 509f689ac521e..68bf5eac98462 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -628,14 +628,10 @@ test_that("read/write json files", { jsonPath3 <- tempfile(pattern = "jsonPath3", fileext = ".json") write.json(df, jsonPath3) - # Test read.json()/jsonFile() works with multiple input paths + # Test read.json() works with multiple input paths jsonDF1 <- read.json(c(jsonPath2, jsonPath3)) expect_is(jsonDF1, "SparkDataFrame") expect_equal(count(jsonDF1), 6) - # Suppress warnings because jsonFile is deprecated - jsonDF2 <- suppressWarnings(jsonFile(c(jsonPath2, jsonPath3))) - expect_is(jsonDF2, "SparkDataFrame") - expect_equal(count(jsonDF2), 6) unlink(jsonPath2) unlink(jsonPath3) @@ -655,20 +651,6 @@ test_that("read/write json files - compression option", { unlink(jsonPath) }) -test_that("jsonRDD() on a RDD with json string", { - sqlContext <- suppressWarnings(sparkRSQL.init(sc)) - rdd <- parallelize(sc, mockLines) - expect_equal(countRDD(rdd), 3) - df <- suppressWarnings(jsonRDD(sqlContext, rdd)) - expect_is(df, "SparkDataFrame") - expect_equal(count(df), 3) - - rdd2 <- flatMap(rdd, function(x) c(x, x)) - df <- suppressWarnings(jsonRDD(sqlContext, rdd2)) - expect_is(df, "SparkDataFrame") - expect_equal(count(df), 6) -}) - test_that("test tableNames and tables", { count <- count(listTables()) @@ -2658,7 +2640,7 @@ test_that("read/write Parquet files", { expect_is(df2, "SparkDataFrame") expect_equal(count(df2), 3) - # Test write.parquet/saveAsParquetFile and read.parquet/parquetFile + # Test write.parquet/saveAsParquetFile and read.parquet parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") write.parquet(df, parquetPath2) parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") @@ -2666,9 +2648,6 @@ test_that("read/write Parquet files", { parquetDF <- read.parquet(c(parquetPath2, parquetPath3)) expect_is(parquetDF, "SparkDataFrame") expect_equal(count(parquetDF), count(df) * 2) - parquetDF2 <- suppressWarnings(parquetFile(parquetPath2, parquetPath3)) - expect_is(parquetDF2, "SparkDataFrame") - expect_equal(count(parquetDF2), count(df) * 2) # Test if varargs works with variables saveMode <- "overwrite" diff --git a/docs/sparkr.md b/docs/sparkr.md index 79f8ab81342be..5882ed7923aa7 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -709,8 +709,12 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma ## Upgrading to SparkR 2.3.1 and above - - In SparkR 2.3.0 and earlier, the `start` parameter of `substr` method was wrongly subtracted by one and considered as 0-based. This can lead to inconsistent substring results and also does not match with the behaviour with `substr` in R. In version 2.3.1 and later, it has been fixed so the `start` parameter of `substr` method is now 1-base. As an example, `substr(lit('abcdef'), 2, 4))` would result to `abc` in SparkR 2.3.0, and the result would be `bcd` in SparkR 2.3.1. + - In SparkR 2.3.0 and earlier, the `start` parameter of `substr` method was wrongly subtracted by one and considered as 0-based. This can lead to inconsistent substring results and also does not match with the behaviour with `substr` in R. In version 2.3.1 and later, it has been fixed so the `start` parameter of `substr` method is now 1-based. As an example, `substr(lit('abcdef'), 2, 4))` would result to `abc` in SparkR 2.3.0, and the result would be `bcd` in SparkR 2.3.1. ## Upgrading to SparkR 2.4.0 - Previously, we don't check the validity of the size of the last layer in `spark.mlp`. For example, if the training data only has two labels, a `layers` param like `c(1, 3)` doesn't cause an error previously, now it does. + +## Upgrading to SparkR 3.0.0 + + - The deprecated methods `parquetFile`, `jsonRDD` and `jsonFile` in `SQLContext` have been removed. Use `read.parquet` and `read.json`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index dfb12f272eb2f..1b7e969a7192e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -755,289 +755,6 @@ class SQLContext private[sql](val sparkSession: SparkSession) sessionState.catalog.listTables(databaseName).map(_.table).toArray } - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - // Deprecated methods - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - - /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. - */ - @deprecated("Use createDataFrame instead.", "1.3.0") - def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { - createDataFrame(rowRDD, schema) - } - - /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. - */ - @deprecated("Use createDataFrame instead.", "1.3.0") - def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { - createDataFrame(rowRDD, schema) - } - - /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. - */ - @deprecated("Use createDataFrame instead.", "1.3.0") - def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { - createDataFrame(rdd, beanClass) - } - - /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. - */ - @deprecated("Use createDataFrame instead.", "1.3.0") - def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { - createDataFrame(rdd, beanClass) - } - - /** - * Loads a Parquet file, returning the result as a `DataFrame`. This function returns an empty - * `DataFrame` if no paths are passed in. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().parquet()`. - */ - @deprecated("Use read.parquet() instead.", "1.4.0") - @scala.annotation.varargs - def parquetFile(paths: String*): DataFrame = { - if (paths.isEmpty) { - emptyDataFrame - } else { - read.parquet(paths : _*) - } - } - - /** - * Loads a JSON file (one object per line), returning the result as a `DataFrame`. - * It goes through the entire dataset once to determine the schema. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. - */ - @deprecated("Use read.json() instead.", "1.4.0") - def jsonFile(path: String): DataFrame = { - read.json(path) - } - - /** - * Loads a JSON file (one object per line) and applies the given schema, - * returning the result as a `DataFrame`. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. - */ - @deprecated("Use read.json() instead.", "1.4.0") - def jsonFile(path: String, schema: StructType): DataFrame = { - read.schema(schema).json(path) - } - - /** - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. - */ - @deprecated("Use read.json() instead.", "1.4.0") - def jsonFile(path: String, samplingRatio: Double): DataFrame = { - read.option("samplingRatio", samplingRatio.toString).json(path) - } - - /** - * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * `DataFrame`. - * It goes through the entire dataset once to determine the schema. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. - */ - @deprecated("Use read.json() instead.", "1.4.0") - def jsonRDD(json: RDD[String]): DataFrame = read.json(json) - - /** - * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * `DataFrame`. - * It goes through the entire dataset once to determine the schema. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. - */ - @deprecated("Use read.json() instead.", "1.4.0") - def jsonRDD(json: JavaRDD[String]): DataFrame = read.json(json) - - /** - * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, - * returning the result as a `DataFrame`. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. - */ - @deprecated("Use read.json() instead.", "1.4.0") - def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { - read.schema(schema).json(json) - } - - /** - * Loads an JavaRDD[String] storing JSON objects (one object per record) and applies the given - * schema, returning the result as a `DataFrame`. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. - */ - @deprecated("Use read.json() instead.", "1.4.0") - def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = { - read.schema(schema).json(json) - } - - /** - * Loads an RDD[String] storing JSON objects (one object per record) inferring the - * schema, returning the result as a `DataFrame`. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. - */ - @deprecated("Use read.json() instead.", "1.4.0") - def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { - read.option("samplingRatio", samplingRatio.toString).json(json) - } - - /** - * Loads a JavaRDD[String] storing JSON objects (one object per record) inferring the - * schema, returning the result as a `DataFrame`. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. - */ - @deprecated("Use read.json() instead.", "1.4.0") - def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = { - read.option("samplingRatio", samplingRatio.toString).json(json) - } - - /** - * Returns the dataset stored at path as a DataFrame, - * using the default data source configured by spark.sql.sources.default. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by `read().load(path)`. - */ - @deprecated("Use read.load(path) instead.", "1.4.0") - def load(path: String): DataFrame = { - read.load(path) - } - - /** - * Returns the dataset stored at path as a DataFrame, using the given data source. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by `read().format(source).load(path)`. - */ - @deprecated("Use read.format(source).load(path) instead.", "1.4.0") - def load(path: String, source: String): DataFrame = { - read.format(source).load(path) - } - - /** - * (Java-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. - */ - @deprecated("Use read.format(source).options(options).load() instead.", "1.4.0") - def load(source: String, options: java.util.Map[String, String]): DataFrame = { - read.options(options).format(source).load() - } - - /** - * (Scala-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. - */ - @deprecated("Use read.format(source).options(options).load() instead.", "1.4.0") - def load(source: String, options: Map[String, String]): DataFrame = { - read.options(options).format(source).load() - } - - /** - * (Java-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by - * `read().format(source).schema(schema).options(options).load()`. - */ - @deprecated("Use read.format(source).schema(schema).options(options).load() instead.", "1.4.0") - def load( - source: String, - schema: StructType, - options: java.util.Map[String, String]): DataFrame = { - read.format(source).schema(schema).options(options).load() - } - - /** - * (Scala-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by - * `read().format(source).schema(schema).options(options).load()`. - */ - @deprecated("Use read.format(source).schema(schema).options(options).load() instead.", "1.4.0") - def load(source: String, schema: StructType, options: Map[String, String]): DataFrame = { - read.format(source).schema(schema).options(options).load() - } - - /** - * Construct a `DataFrame` representing the database table accessible via JDBC URL - * url named table. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. - */ - @deprecated("Use read.jdbc() instead.", "1.4.0") - def jdbc(url: String, table: String): DataFrame = { - read.jdbc(url, table, new Properties) - } - - /** - * Construct a `DataFrame` representing the database table accessible via JDBC URL - * url named table. Partitions of the table will be retrieved in parallel based on the parameters - * passed to this function. - * - * @param columnName the name of a column of integral type that will be used for partitioning. - * @param lowerBound the minimum value of `columnName` used to decide partition stride - * @param upperBound the maximum value of `columnName` used to decide partition stride - * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split - * evenly into this many partitions - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. - */ - @deprecated("Use read.jdbc() instead.", "1.4.0") - def jdbc( - url: String, - table: String, - columnName: String, - lowerBound: Long, - upperBound: Long, - numPartitions: Int): DataFrame = { - read.jdbc(url, table, columnName, lowerBound, upperBound, numPartitions, new Properties) - } - - /** - * Construct a `DataFrame` representing the database table accessible via JDBC URL - * url named table. The theParts parameter gives a list expressions - * suitable for inclusion in WHERE clauses; each one defines one partition - * of the `DataFrame`. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. - */ - @deprecated("Use read.jdbc() instead.", "1.4.0") - def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = { - read.jdbc(url, table, theParts, new Properties) - } } /** From e9b71c8f017d2da3b9ae586017b2e5a040f023d2 Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Fri, 26 Oct 2018 15:59:12 -0700 Subject: [PATCH 1936/2461] [SPARK-25828][K8S] Bumping Kubernetes-Client version to 4.1.0 ## What changes were proposed in this pull request? Changed the `kubernetes-client` version and refactored code that broke as a result ## How was this patch tested? Unit and Integration tests Closes #22820 from ifilonenko/SPARK-25828. Authored-by: Ilan Filonenko Signed-off-by: Erik Erlandson --- dev/deps/spark-deps-hadoop-2.7 | 6 +++--- dev/deps/spark-deps-hadoop-3.1 | 6 +++--- docs/running-on-kubernetes.md | 3 ++- resource-managers/kubernetes/core/pom.xml | 2 +- .../scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala | 6 +++--- .../spark/deploy/k8s/features/MountVolumesFeatureStep.scala | 3 ++- .../spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala | 2 +- .../scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala | 2 +- resource-managers/kubernetes/integration-tests/pom.xml | 2 +- 9 files changed, 17 insertions(+), 15 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 537831ecac45b..0703b5b02b125 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -132,13 +132,13 @@ jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar kryo-shaded-4.0.2.jar -kubernetes-client-3.0.0.jar -kubernetes-model-2.0.0.jar +kubernetes-client-4.1.0.jar +kubernetes-model-4.1.0.jar leveldbjni-all-1.8.jar libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar -logging-interceptor-3.8.1.jar +logging-interceptor-3.9.1.jar lz4-java-1.5.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index bc4ef31e3bac4..513986820d5fc 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -147,13 +147,13 @@ kerby-pkix-1.0.1.jar kerby-util-1.0.1.jar kerby-xdr-1.0.1.jar kryo-shaded-4.0.2.jar -kubernetes-client-3.0.0.jar -kubernetes-model-2.0.0.jar +kubernetes-client-4.1.0.jar +kubernetes-model-4.1.0.jar leveldbjni-all-1.8.jar libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar -logging-interceptor-3.8.1.jar +logging-interceptor-3.9.1.jar lz4-java-1.5.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 60c9279f2bce2..7093ee5a9686d 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -45,7 +45,8 @@ logs and remains in "completed" state in the Kubernetes API until it's eventuall Note that in the completed state, the driver pod does *not* use any computational or memory resources. -The driver and executor pod scheduling is handled by Kubernetes. It is possible to schedule the +The driver and executor pod scheduling is handled by Kubernetes. Communication to the Kubernetes API is done via fabric8, and we are +currently running kubernetes-client version 4.1.0. Make sure that when you are making infrastructure additions that you are aware of said version. It is possible to schedule the driver and executor pods on a subset of available nodes through a [node selector](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#nodeselector) using the configuration property for it. It will be possible to use more advanced scheduling hints like [node/pod affinities](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#affinity-and-anti-affinity) in a future release. diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index 90bac19cba019..b89ea383bf872 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -29,7 +29,7 @@ Spark Project Kubernetes kubernetes - 3.0.0 + 4.1.0 diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 8f36fa12aed17..0f740454fafc4 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.k8s import scala.collection.JavaConverters._ -import io.fabric8.kubernetes.api.model.{ContainerStateRunning, ContainerStateTerminated, ContainerStateWaiting, ContainerStatus, Pod, Time} +import io.fabric8.kubernetes.api.model.{ContainerStateRunning, ContainerStateTerminated, ContainerStateWaiting, ContainerStatus, Pod} import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.util.Utils @@ -157,7 +157,7 @@ private[spark] object KubernetesUtils { }.getOrElse(Seq(("container state", "N/A"))) } - def formatTime(time: Time): String = { - if (time != null) time.getTime else "N/A" + def formatTime(time: String): String = { + if (time != null) time else "N/A" } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala index bb0e2b3128efd..e60259c4a9b5a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala @@ -56,8 +56,9 @@ private[spark] class MountVolumesFeatureStep( val volumeBuilder = spec.volumeConf match { case KubernetesHostPathVolumeConf(hostPath) => + /* "" means that no checks will be performed before mounting the hostPath volume */ new VolumeBuilder() - .withHostPath(new HostPathVolumeSource(hostPath)) + .withHostPath(new HostPathVolumeSource(hostPath, "")) case KubernetesPVCVolumeConf(claimName) => new VolumeBuilder() diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala index 1889fe5eb3e9b..79b55bc37afcd 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala @@ -20,7 +20,7 @@ import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.collection.JavaConverters._ -import io.fabric8.kubernetes.api.model.{ContainerStateRunning, ContainerStateTerminated, ContainerStateWaiting, ContainerStatus, Pod, Time} +import io.fabric8.kubernetes.api.model.Pod import io.fabric8.kubernetes.client.{KubernetesClientException, Watcher} import io.fabric8.kubernetes.client.Watcher.Action diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala index c6b667ed85e8c..2e883623a4b1c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala @@ -82,7 +82,7 @@ object ExecutorLifecycleTestUtils { def deletedExecutor(executorId: Long): Pod = { new PodBuilder(podWithAttachedContainerForId(executorId)) .editOrNewMetadata() - .withNewDeletionTimestamp("523012521") + .withDeletionTimestamp("523012521") .endMetadata() .build() } diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 23453c8957b28..a07fe1feea3eb 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -29,7 +29,7 @@ 1.3.0 1.4.0 - 3.0.0 + 4.1.0 3.2.2 1.0 kubernetes-integration-tests From 6f05669e4e9d1cc541f1843918970cc9bac24f35 Mon Sep 17 00:00:00 2001 From: laskfla Date: Sat, 27 Oct 2018 08:09:59 -0500 Subject: [PATCH 1937/2461] [MINOR][DOC] Fix comment error of HiveUtils ## What changes were proposed in this pull request? Change the version number in comment of `HiveUtils.newClientForExecution` from `13` to `1.2.1` . ## How was this patch tested? N/A Closes #22850 from laskfla/HiveUtils-Comment. Authored-by: laskfla Signed-off-by: Sean Owen --- .../src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index cd321d41f43e8..74f21532b22df 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -252,7 +252,7 @@ private[spark] object HiveUtils extends Logging { /** * Create a [[HiveClient]] used for execution. * - * Currently this must always be Hive 13 as this is the version of Hive that is packaged + * Currently this must always be Hive 1.2.1 as this is the version of Hive that is packaged * with Spark SQL. This copy of the client is used for execution related tasks like * registering temporary functions or ensuring that the ThreadLocal SessionState is * correctly populated. This copy of Hive is *not* used for storing persistent metadata, From d5573c578a1eea9ee04886d9df37c7178e67bb30 Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Sat, 27 Oct 2018 08:20:42 -0500 Subject: [PATCH 1938/2461] [SPARK-23367][BUILD] Include python document style checking ## What changes were proposed in this pull request? Includes python document style checking. - Use sphinx like check, run only if pydocstyle installed on machine/jenkins - use pydocstyle rather than single file version pep257.py, which is much older and had some known issues - verify pydocstyle latest 3.0.0 is in use, to ensure latest doc checks are getting executed - ignore (inclusion/exclusion error codes) features and support via tox.ini - Be non-breaking change and allow updating docstyle to standards at easy pace ## How was this patch tested? ./dev/run-tests Closes #22425 from rekhajoshm/SPARK-23367-2. Authored-by: Rekha Joshi Signed-off-by: Sean Owen --- dev/lint-python | 28 ++++++++++++++++++++++++++++ dev/tox.ini | 2 ++ 2 files changed, 30 insertions(+) diff --git a/dev/lint-python b/dev/lint-python index e26bd4bd4517c..2e353e142c143 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -21,9 +21,14 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" # Exclude auto-generated configuration file. PATHS_TO_CHECK="$( cd "$SPARK_ROOT_DIR" && find . -name "*.py" )" +DOC_PATHS_TO_CHECK="$( cd "$SPARK_ROOT_DIR" && find . -name "*.py" | grep -vF 'functions.py' )" PYCODESTYLE_REPORT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-report.txt" +PYDOCSTYLE_REPORT_PATH="$SPARK_ROOT_DIR/dev/pydocstyle-report.txt" PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt" +PYDOCSTYLEBUILD="pydocstyle" +EXPECTED_PYDOCSTYLEVERSION="3.0.0" +PYDOCSTYLEVERSION=$(python -c 'import pkg_resources; print(pkg_resources.get_distribution("pydocstyle").version)' 2> /dev/null) SPHINXBUILD=${SPHINXBUILD:=sphinx-build} SPHINX_REPORT_PATH="$SPARK_ROOT_DIR/dev/sphinx-report.txt" @@ -99,6 +104,29 @@ else echo "flake8 checks passed." fi +# Check python document style, skip check if pydocstyle is not installed. +if hash "$PYDOCSTYLEBUILD" 2> /dev/null; then + if [[ "$PYDOCSTYLEVERSION" == "$EXPECTED_PYDOCSTYLEVERSION" ]]; then + pydocstyle --config=dev/tox.ini $DOC_PATHS_TO_CHECK >> "$PYDOCSTYLE_REPORT_PATH" + pydocstyle_status="${PIPESTATUS[0]}" + + if [ "$compile_status" -eq 0 -a "$pydocstyle_status" -eq 0 ]; then + echo "pydocstyle checks passed." + rm "$PYDOCSTYLE_REPORT_PATH" + else + echo "pydocstyle checks failed." + cat "$PYDOCSTYLE_REPORT_PATH" + rm "$PYDOCSTYLE_REPORT_PATH" + exit 1 + fi + + else + echo "The pydocstyle version needs to be latest 3.0.0. Skipping pydoc checks for now" + fi +else + echo >&2 "The pydocstyle command was not found. Skipping pydoc checks for now" +fi + # Check that the documentation builds acceptably, skip check if sphinx is not installed. if hash "$SPHINXBUILD" 2> /dev/null; then cd python/docs diff --git a/dev/tox.ini b/dev/tox.ini index 6ec223b743b4e..11b1b040035b0 100644 --- a/dev/tox.ini +++ b/dev/tox.ini @@ -17,3 +17,5 @@ ignore=E226,E241,E305,E402,E722,E731,E741,W503,W504 max-line-length=100 exclude=cloudpickle.py,heapq3.py,shared.py,python/docs/conf.py,work/*/*.py,python/.eggs/*,dist/* +[pydocstyle] +ignore=D100,D101,D102,D103,D104,D105,D106,D107,D200,D201,D202,D203,D204,D205,D206,D207,D208,D209,D210,D211,D212,D213,D214,D215,D300,D301,D302,D400,D401,D402,D403,D404,D405,D406,D407,D408,D409,D410,D411,D412,D413,D414 From 41e1416f4d441415212bb5705898509ce5344ec4 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sat, 27 Oct 2018 15:11:29 -0700 Subject: [PATCH 1939/2461] [SPARK-16693][SPARKR] Remove methods deprecated ## What changes were proposed in this pull request? Remove deprecated functions which includes: SQLContext/HiveContext stuff sparkR.init jsonFile parquetFile registerTempTable saveAsParquetFile unionAll createExternalTable dropTempTable ## How was this patch tested? jenkins Author: Felix Cheung Closes #22843 from felixcheung/rrddapi. --- R/pkg/NAMESPACE | 11 +- R/pkg/R/DataFrame.R | 52 +--------- R/pkg/R/SQLContext.R | 95 ++--------------- R/pkg/R/catalog.R | 99 +----------------- R/pkg/R/generics.R | 9 -- R/pkg/R/sparkR.R | 142 +------------------------- R/pkg/tests/fulltests/test_context.R | 12 --- R/pkg/tests/fulltests/test_sparkSQL.R | 60 ++--------- docs/sparkr.md | 5 +- 9 files changed, 32 insertions(+), 453 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 5a5dc20ff3b78..f9f556e69a1fc 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -28,9 +28,8 @@ importFrom("utils", "download.file", "object.size", "packageVersion", "tail", "u # S3 methods exported export("sparkR.session") -export("sparkR.init") -export("sparkR.stop") export("sparkR.session.stop") +export("sparkR.stop") export("sparkR.conf") export("sparkR.version") export("sparkR.uiWebUrl") @@ -42,9 +41,6 @@ export("sparkR.callJStatic") export("install.spark") -export("sparkRSQL.init", - "sparkRHive.init") - # MLlib integration exportMethods("glm", "spark.glm", @@ -151,7 +147,6 @@ exportMethods("arrange", "printSchema", "randomSplit", "rbind", - "registerTempTable", "rename", "repartition", "repartitionByRange", @@ -159,7 +154,6 @@ exportMethods("arrange", "sample", "sample_frac", "sampleBy", - "saveAsParquetFile", "saveAsTable", "saveDF", "schema", @@ -175,7 +169,6 @@ exportMethods("arrange", "toJSON", "transform", "union", - "unionAll", "unionByName", "unique", "unpersist", @@ -415,10 +408,8 @@ export("as.DataFrame", "cacheTable", "clearCache", "createDataFrame", - "createExternalTable", "createTable", "currentDatabase", - "dropTempTable", "dropTempView", "listColumns", "listDatabases", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index bf82d0c7882d7..c99ad76f7643c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -227,7 +227,7 @@ setMethod("showDF", #' show #' #' If eager evaluation is enabled and the Spark object is a SparkDataFrame, evaluate the -#' SparkDataFrame and print top rows of the SparkDataFrame, otherwise, print the class +#' SparkDataFrame and print top rows of the SparkDataFrame, otherwise, print the class #' and type information of the Spark object. #' #' @param object a Spark object. Can be a SparkDataFrame, Column, GroupedData, WindowSpec. @@ -521,32 +521,6 @@ setMethod("createOrReplaceTempView", invisible(callJMethod(x@sdf, "createOrReplaceTempView", viewName)) }) -#' (Deprecated) Register Temporary Table -#' -#' Registers a SparkDataFrame as a Temporary Table in the SparkSession -#' @param x A SparkDataFrame -#' @param tableName A character vector containing the name of the table -#' -#' @seealso \link{createOrReplaceTempView} -#' @rdname registerTempTable-deprecated -#' @name registerTempTable -#' @aliases registerTempTable,SparkDataFrame,character-method -#' @examples -#'\dontrun{ -#' sparkR.session() -#' path <- "path/to/file.json" -#' df <- read.json(path) -#' registerTempTable(df, "json_df") -#' new_df <- sql("SELECT * FROM json_df") -#'} -#' @note registerTempTable since 1.4.0 -setMethod("registerTempTable", - signature(x = "SparkDataFrame", tableName = "character"), - function(x, tableName) { - .Deprecated("createOrReplaceTempView") - invisible(callJMethod(x@sdf, "createOrReplaceTempView", tableName)) - }) - #' insertInto #' #' Insert the contents of a SparkDataFrame into a table registered in the current SparkSession. @@ -956,7 +930,6 @@ setMethod("write.orc", #' path <- "path/to/file.json" #' df <- read.json(path) #' write.parquet(df, "/tmp/sparkr-tmp1/") -#' saveAsParquetFile(df, "/tmp/sparkr-tmp2/") #'} #' @note write.parquet since 1.6.0 setMethod("write.parquet", @@ -967,17 +940,6 @@ setMethod("write.parquet", invisible(handledCallJMethod(write, "parquet", path)) }) -#' @rdname write.parquet -#' @name saveAsParquetFile -#' @aliases saveAsParquetFile,SparkDataFrame,character-method -#' @note saveAsParquetFile since 1.4.0 -setMethod("saveAsParquetFile", - signature(x = "SparkDataFrame", path = "character"), - function(x, path) { - .Deprecated("write.parquet") - write.parquet(x, path) - }) - #' Save the content of SparkDataFrame in a text file at the specified path. #' #' Save the content of the SparkDataFrame in a text file at the specified path. @@ -2762,18 +2724,6 @@ setMethod("union", dataFrame(unioned) }) -#' unionAll is deprecated - use union instead -#' @rdname union -#' @name unionAll -#' @aliases unionAll,SparkDataFrame,SparkDataFrame-method -#' @note unionAll since 1.4.0 -setMethod("unionAll", - signature(x = "SparkDataFrame", y = "SparkDataFrame"), - function(x, y) { - .Deprecated("union") - union(x, y) - }) - #' Return a new SparkDataFrame containing the union of rows, matched by column names #' #' Return a new SparkDataFrame containing the union of rows in this SparkDataFrame diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 3f89ee99e2564..afcdd6faa849d 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -37,37 +37,6 @@ getInternalType <- function(x) { stop(paste("Unsupported type for SparkDataFrame:", class(x)))) } -#' Temporary function to reroute old S3 Method call to new -#' This function is specifically implemented to remove SQLContext from the parameter list. -#' It determines the target to route the call by checking the parent of this callsite (say 'func'). -#' The target should be called 'func.default'. -#' We need to check the class of x to ensure it is SQLContext/HiveContext before dispatching. -#' @param newFuncSig name of the function the user should call instead in the deprecation message -#' @param x the first parameter of the original call -#' @param ... the rest of parameter to pass along -#' @return whatever the target returns -#' @noRd -dispatchFunc <- function(newFuncSig, x, ...) { - # When called with SparkR::createDataFrame, sys.call()[[1]] returns c(::, SparkR, createDataFrame) - callsite <- as.character(sys.call(sys.parent())[[1]]) - funcName <- callsite[[length(callsite)]] - f <- get(paste0(funcName, ".default")) - # Strip sqlContext from list of parameters and then pass the rest along. - contextNames <- c("org.apache.spark.sql.SQLContext", - "org.apache.spark.sql.hive.HiveContext", - "org.apache.spark.sql.hive.test.TestHiveContext", - "org.apache.spark.sql.SparkSession") - if (missing(x) && length(list(...)) == 0) { - f() - } else if (class(x) == "jobj" && - any(grepl(paste(contextNames, collapse = "|"), getClassName.jobj(x)))) { - .Deprecated(newFuncSig, old = paste0(funcName, "(sqlContext...)")) - f(...) - } else { - f(x, ...) - } -} - #' return the SparkSession #' @noRd getSparkSession <- function() { @@ -198,11 +167,10 @@ getDefaultSqlSource <- function() { #' df4 <- createDataFrame(cars, numPartitions = 2) #' } #' @name createDataFrame -#' @method createDataFrame default #' @note createDataFrame since 1.4.0 # TODO(davies): support sampling and infer type from NA -createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0, - numPartitions = NULL) { +createDataFrame <- function(data, schema = NULL, samplingRatio = 1.0, + numPartitions = NULL) { sparkSession <- getSparkSession() if (is.data.frame(data)) { @@ -285,31 +253,18 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0, dataFrame(sdf) } -createDataFrame <- function(x, ...) { - dispatchFunc("createDataFrame(data, schema = NULL)", x, ...) -} - #' @rdname createDataFrame #' @aliases createDataFrame -#' @method as.DataFrame default #' @note as.DataFrame since 1.6.0 -as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0, numPartitions = NULL) { +as.DataFrame <- function(data, schema = NULL, samplingRatio = 1.0, numPartitions = NULL) { createDataFrame(data, schema, samplingRatio, numPartitions) } -#' @param ... additional argument(s). -#' @rdname createDataFrame -#' @aliases as.DataFrame -as.DataFrame <- function(data, ...) { - dispatchFunc("as.DataFrame(data, schema = NULL)", data, ...) -} - #' toDF #' #' Converts an RDD to a SparkDataFrame by infer the types. #' #' @param x An RDD -#' #' @rdname SparkDataFrame #' @noRd #' @examples @@ -345,9 +300,8 @@ setMethod("toDF", signature(x = "RDD"), #' df <- read.json(path, multiLine = TRUE) #' } #' @name read.json -#' @method read.json default #' @note read.json since 1.6.0 -read.json.default <- function(path, ...) { +read.json <- function(path, ...) { sparkSession <- getSparkSession() options <- varargsToStrEnv(...) # Allow the user to have a more flexible definition of the text file path @@ -358,10 +312,6 @@ read.json.default <- function(path, ...) { dataFrame(sdf) } -read.json <- function(x, ...) { - dispatchFunc("read.json(path)", x, ...) -} - #' Create a SparkDataFrame from an ORC file. #' #' Loads an ORC file, returning the result as a SparkDataFrame. @@ -388,13 +338,12 @@ read.orc <- function(path, ...) { #' Loads a Parquet file, returning the result as a SparkDataFrame. #' #' @param path path of file to read. A vector of multiple paths is allowed. -#' @param ... additional external data source specific named properties. +#' @param ... additional data source specific named properties. #' @return SparkDataFrame #' @rdname read.parquet #' @name read.parquet -#' @method read.parquet default #' @note read.parquet since 1.6.0 -read.parquet.default <- function(path, ...) { +read.parquet <- function(path, ...) { sparkSession <- getSparkSession() options <- varargsToStrEnv(...) # Allow the user to have a more flexible definition of the Parquet file path @@ -405,10 +354,6 @@ read.parquet.default <- function(path, ...) { dataFrame(sdf) } -read.parquet <- function(x, ...) { - dispatchFunc("read.parquet(...)", x, ...) -} - #' Create a SparkDataFrame from a text file. #' #' Loads text files and returns a SparkDataFrame whose schema starts with @@ -428,9 +373,8 @@ read.parquet <- function(x, ...) { #' df <- read.text(path) #' } #' @name read.text -#' @method read.text default #' @note read.text since 1.6.1 -read.text.default <- function(path, ...) { +read.text <- function(path, ...) { sparkSession <- getSparkSession() options <- varargsToStrEnv(...) # Allow the user to have a more flexible definition of the text file path @@ -441,10 +385,6 @@ read.text.default <- function(path, ...) { dataFrame(sdf) } -read.text <- function(x, ...) { - dispatchFunc("read.text(path)", x, ...) -} - #' SQL Query #' #' Executes a SQL query using Spark, returning the result as a SparkDataFrame. @@ -461,18 +401,13 @@ read.text <- function(x, ...) { #' new_df <- sql("SELECT * FROM table") #' } #' @name sql -#' @method sql default #' @note sql since 1.4.0 -sql.default <- function(sqlQuery) { +sql <- function(sqlQuery) { sparkSession <- getSparkSession() sdf <- callJMethod(sparkSession, "sql", sqlQuery) dataFrame(sdf) } -sql <- function(x, ...) { - dispatchFunc("sql(sqlQuery)", x, ...) -} - #' Create a SparkDataFrame from a SparkSQL table or view #' #' Returns the specified table or view as a SparkDataFrame. The table or view must already exist or @@ -531,9 +466,8 @@ tableToDF <- function(tableName) { #' df4 <- read.df(mapTypeJsonPath, "json", stringSchema, multiLine = TRUE) #' } #' @name read.df -#' @method read.df default #' @note read.df since 1.4.0 -read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.strings = "NA", ...) { +read.df <- function(path = NULL, source = NULL, schema = NULL, na.strings = "NA", ...) { if (!is.null(path) && !is.character(path)) { stop("path should be character, NULL or omitted.") } @@ -568,22 +502,13 @@ read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.string dataFrame(sdf) } -read.df <- function(x = NULL, ...) { - dispatchFunc("read.df(path = NULL, source = NULL, schema = NULL, ...)", x, ...) -} - #' @rdname read.df #' @name loadDF -#' @method loadDF default #' @note loadDF since 1.6.0 -loadDF.default <- function(path = NULL, source = NULL, schema = NULL, ...) { +loadDF <- function(path = NULL, source = NULL, schema = NULL, ...) { read.df(path, source, schema, ...) } -loadDF <- function(x = NULL, ...) { - dispatchFunc("loadDF(path = NULL, source = NULL, schema = NULL, ...)", x, ...) -} - #' Create a SparkDataFrame representing the database table accessible via JDBC URL #' #' Additional JDBC database connection properties can be set (...) diff --git a/R/pkg/R/catalog.R b/R/pkg/R/catalog.R index c2d0fc38786be..7641f8a7a0432 100644 --- a/R/pkg/R/catalog.R +++ b/R/pkg/R/catalog.R @@ -17,40 +17,6 @@ # catalog.R: SparkSession catalog functions -#' (Deprecated) Create an external table -#' -#' Creates an external table based on the dataset in a data source, -#' Returns a SparkDataFrame associated with the external table. -#' -#' The data source is specified by the \code{source} and a set of options(...). -#' If \code{source} is not specified, the default data source configured by -#' "spark.sql.sources.default" will be used. -#' -#' @param tableName a name of the table. -#' @param path the path of files to load. -#' @param source the name of external data source. -#' @param schema the schema of the data required for some data sources. -#' @param ... additional argument(s) passed to the method. -#' @return A SparkDataFrame. -#' @rdname createExternalTable-deprecated -#' @seealso \link{createTable} -#' @examples -#'\dontrun{ -#' sparkR.session() -#' df <- createExternalTable("myjson", path="path/to/json", source="json", schema) -#' } -#' @name createExternalTable -#' @method createExternalTable default -#' @note createExternalTable since 1.4.0 -createExternalTable.default <- function(tableName, path = NULL, source = NULL, schema = NULL, ...) { - .Deprecated("createTable", old = "createExternalTable") - createTable(tableName, path, source, schema, ...) -} - -createExternalTable <- function(x, ...) { - dispatchFunc("createExternalTable(tableName, path = NULL, source = NULL, ...)", x, ...) -} - #' Creates a table based on the dataset in a data source #' #' Creates a table based on the dataset in a data source. Returns a SparkDataFrame associated with @@ -116,18 +82,13 @@ createTable <- function(tableName, path = NULL, source = NULL, schema = NULL, .. #' cacheTable("table") #' } #' @name cacheTable -#' @method cacheTable default #' @note cacheTable since 1.4.0 -cacheTable.default <- function(tableName) { +cacheTable <- function(tableName) { sparkSession <- getSparkSession() catalog <- callJMethod(sparkSession, "catalog") invisible(handledCallJMethod(catalog, "cacheTable", tableName)) } -cacheTable <- function(x, ...) { - dispatchFunc("cacheTable(tableName)", x, ...) -} - #' Uncache Table #' #' Removes the specified table from the in-memory cache. @@ -145,18 +106,13 @@ cacheTable <- function(x, ...) { #' uncacheTable("table") #' } #' @name uncacheTable -#' @method uncacheTable default #' @note uncacheTable since 1.4.0 -uncacheTable.default <- function(tableName) { +uncacheTable <- function(tableName) { sparkSession <- getSparkSession() catalog <- callJMethod(sparkSession, "catalog") invisible(handledCallJMethod(catalog, "uncacheTable", tableName)) } -uncacheTable <- function(x, ...) { - dispatchFunc("uncacheTable(tableName)", x, ...) -} - #' Clear Cache #' #' Removes all cached tables from the in-memory cache. @@ -167,48 +123,13 @@ uncacheTable <- function(x, ...) { #' clearCache() #' } #' @name clearCache -#' @method clearCache default #' @note clearCache since 1.4.0 -clearCache.default <- function() { +clearCache <- function() { sparkSession <- getSparkSession() catalog <- callJMethod(sparkSession, "catalog") invisible(callJMethod(catalog, "clearCache")) } -clearCache <- function() { - dispatchFunc("clearCache()") -} - -#' (Deprecated) Drop Temporary Table -#' -#' Drops the temporary table with the given table name in the catalog. -#' If the table has been cached/persisted before, it's also unpersisted. -#' -#' @param tableName The name of the SparkSQL table to be dropped. -#' @seealso \link{dropTempView} -#' @rdname dropTempTable-deprecated -#' @examples -#' \dontrun{ -#' sparkR.session() -#' df <- read.df(path, "parquet") -#' createOrReplaceTempView(df, "table") -#' dropTempTable("table") -#' } -#' @name dropTempTable -#' @method dropTempTable default -#' @note dropTempTable since 1.4.0 -dropTempTable.default <- function(tableName) { - .Deprecated("dropTempView", old = "dropTempTable") - if (class(tableName) != "character") { - stop("tableName must be a string.") - } - dropTempView(tableName) -} - -dropTempTable <- function(x, ...) { - dispatchFunc("dropTempView(viewName)", x, ...) -} - #' Drops the temporary view with the given view name in the catalog. #' #' Drops the temporary view with the given view name in the catalog. @@ -249,17 +170,12 @@ dropTempView <- function(viewName) { #' tables("hive") #' } #' @name tables -#' @method tables default #' @note tables since 1.4.0 -tables.default <- function(databaseName = NULL) { +tables <- function(databaseName = NULL) { # rename column to match previous output schema withColumnRenamed(listTables(databaseName), "name", "tableName") } -tables <- function(x, ...) { - dispatchFunc("tables(databaseName = NULL)", x, ...) -} - #' Table Names #' #' Returns the names of tables in the given database as an array. @@ -273,9 +189,8 @@ tables <- function(x, ...) { #' tableNames("hive") #' } #' @name tableNames -#' @method tableNames default #' @note tableNames since 1.4.0 -tableNames.default <- function(databaseName = NULL) { +tableNames <- function(databaseName = NULL) { sparkSession <- getSparkSession() callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getTableNames", @@ -283,10 +198,6 @@ tableNames.default <- function(databaseName = NULL) { databaseName) } -tableNames <- function(x, ...) { - dispatchFunc("tableNames(databaseName = NULL)", x, ...) -} - #' Returns the current default database #' #' Returns the current default database. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 045e0754f4651..76e17c10843d2 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -528,9 +528,6 @@ setGeneric("persist", function(x, newLevel) { standardGeneric("persist") }) #' @rdname printSchema setGeneric("printSchema", function(x) { standardGeneric("printSchema") }) -#' @rdname registerTempTable-deprecated -setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") }) - #' @rdname rename setGeneric("rename", function(x, ...) { standardGeneric("rename") }) @@ -595,9 +592,6 @@ setGeneric("write.parquet", function(x, path, ...) { standardGeneric("write.parquet") }) -#' @rdname write.parquet -setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) - #' @rdname write.stream setGeneric("write.stream", function(df, source = NULL, outputMode = NULL, ...) { standardGeneric("write.stream") @@ -637,9 +631,6 @@ setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) #' @rdname union setGeneric("union", function(x, y) { standardGeneric("union") }) -#' @rdname union -setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) - #' @rdname unionByName setGeneric("unionByName", function(x, y) { standardGeneric("unionByName") }) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 038fefadaaeff..ac289d38d01bd 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -88,49 +88,6 @@ sparkR.stop <- function() { sparkR.session.stop() } -#' (Deprecated) Initialize a new Spark Context -#' -#' This function initializes a new SparkContext. -#' -#' @param master The Spark master URL -#' @param appName Application name to register with cluster manager -#' @param sparkHome Spark Home directory -#' @param sparkEnvir Named list of environment variables to set on worker nodes -#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors -#' @param sparkJars Character vector of jar files to pass to the worker nodes -#' @param sparkPackages Character vector of package coordinates -#' @seealso \link{sparkR.session} -#' @rdname sparkR.init-deprecated -#' @examples -#'\dontrun{ -#' sc <- sparkR.init("local[2]", "SparkR", "/home/spark") -#' sc <- sparkR.init("local[2]", "SparkR", "/home/spark", -#' list(spark.executor.memory="1g")) -#' sc <- sparkR.init("yarn-client", "SparkR", "/home/spark", -#' list(spark.executor.memory="4g"), -#' list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"), -#' c("one.jar", "two.jar", "three.jar"), -#' c("com.databricks:spark-avro_2.11:2.0.1")) -#'} -#' @note sparkR.init since 1.4.0 -sparkR.init <- function( - master = "", - appName = "SparkR", - sparkHome = Sys.getenv("SPARK_HOME"), - sparkEnvir = list(), - sparkExecutorEnv = list(), - sparkJars = "", - sparkPackages = "") { - .Deprecated("sparkR.session") - sparkR.sparkContext(master, - appName, - sparkHome, - convertNamedListToEnv(sparkEnvir), - convertNamedListToEnv(sparkExecutorEnv), - sparkJars, - sparkPackages) -} - # Internal function to handle creating the SparkContext. sparkR.sparkContext <- function( master = "", @@ -272,61 +229,6 @@ sparkR.sparkContext <- function( sc } -#' (Deprecated) Initialize a new SQLContext -#' -#' This function creates a SparkContext from an existing JavaSparkContext and -#' then uses it to initialize a new SQLContext -#' -#' Starting SparkR 2.0, a SparkSession is initialized and returned instead. -#' This API is deprecated and kept for backward compatibility only. -#' -#' @param jsc The existing JavaSparkContext created with SparkR.init() -#' @seealso \link{sparkR.session} -#' @rdname sparkRSQL.init-deprecated -#' @examples -#'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) -#'} -#' @note sparkRSQL.init since 1.4.0 -sparkRSQL.init <- function(jsc = NULL) { - .Deprecated("sparkR.session") - - if (exists(".sparkRsession", envir = .sparkREnv)) { - return(get(".sparkRsession", envir = .sparkREnv)) - } - - # Default to without Hive support for backward compatibility. - sparkR.session(enableHiveSupport = FALSE) -} - -#' (Deprecated) Initialize a new HiveContext -#' -#' This function creates a HiveContext from an existing JavaSparkContext -#' -#' Starting SparkR 2.0, a SparkSession is initialized and returned instead. -#' This API is deprecated and kept for backward compatibility only. -#' -#' @param jsc The existing JavaSparkContext created with SparkR.init() -#' @seealso \link{sparkR.session} -#' @rdname sparkRHive.init-deprecated -#' @examples -#'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRHive.init(sc) -#'} -#' @note sparkRHive.init since 1.4.0 -sparkRHive.init <- function(jsc = NULL) { - .Deprecated("sparkR.session") - - if (exists(".sparkRsession", envir = .sparkREnv)) { - return(get(".sparkRsession", envir = .sparkREnv)) - } - - # Default to without Hive support for backward compatibility. - sparkR.session(enableHiveSupport = TRUE) -} - #' Get the existing SparkSession or initialize a new SparkSession. #' #' SparkSession is the entry point into SparkR. \code{sparkR.session} gets the existing @@ -482,26 +384,11 @@ sparkR.uiWebUrl <- function() { #' setJobGroup("myJobGroup", "My job group description", TRUE) #'} #' @note setJobGroup since 1.5.0 -#' @method setJobGroup default -setJobGroup.default <- function(groupId, description, interruptOnCancel) { +setJobGroup <- function(groupId, description, interruptOnCancel) { sc <- getSparkContext() invisible(callJMethod(sc, "setJobGroup", groupId, description, interruptOnCancel)) } -setJobGroup <- function(sc, groupId, description, interruptOnCancel) { - if (class(sc) == "jobj" && any(grepl("JavaSparkContext", getClassName.jobj(sc)))) { - .Deprecated("setJobGroup(groupId, description, interruptOnCancel)", - old = "setJobGroup(sc, groupId, description, interruptOnCancel)") - setJobGroup.default(groupId, description, interruptOnCancel) - } else { - # Parameter order is shifted - groupIdToUse <- sc - descriptionToUse <- groupId - interruptOnCancelToUse <- description - setJobGroup.default(groupIdToUse, descriptionToUse, interruptOnCancelToUse) - } -} - #' Clear current job group ID and its description #' #' @rdname clearJobGroup @@ -512,22 +399,11 @@ setJobGroup <- function(sc, groupId, description, interruptOnCancel) { #' clearJobGroup() #'} #' @note clearJobGroup since 1.5.0 -#' @method clearJobGroup default -clearJobGroup.default <- function() { +clearJobGroup <- function() { sc <- getSparkContext() invisible(callJMethod(sc, "clearJobGroup")) } -clearJobGroup <- function(sc) { - if (!missing(sc) && - class(sc) == "jobj" && - any(grepl("JavaSparkContext", getClassName.jobj(sc)))) { - .Deprecated("clearJobGroup()", old = "clearJobGroup(sc)") - } - clearJobGroup.default() -} - - #' Cancel active jobs for the specified group #' #' @param groupId the ID of job group to be cancelled @@ -539,23 +415,11 @@ clearJobGroup <- function(sc) { #' cancelJobGroup("myJobGroup") #'} #' @note cancelJobGroup since 1.5.0 -#' @method cancelJobGroup default -cancelJobGroup.default <- function(groupId) { +cancelJobGroup <- function(groupId) { sc <- getSparkContext() invisible(callJMethod(sc, "cancelJobGroup", groupId)) } -cancelJobGroup <- function(sc, groupId) { - if (class(sc) == "jobj" && any(grepl("JavaSparkContext", getClassName.jobj(sc)))) { - .Deprecated("cancelJobGroup(groupId)", old = "cancelJobGroup(sc, groupId)") - cancelJobGroup.default(groupId) - } else { - # Parameter order is shifted - groupIdToUse <- sc - cancelJobGroup.default(groupIdToUse) - } -} - #' Set a human readable description of the current job. #' #' Set a description that is shown as a job description in UI. diff --git a/R/pkg/tests/fulltests/test_context.R b/R/pkg/tests/fulltests/test_context.R index 288a2714a554e..eb8d2a700e1ea 100644 --- a/R/pkg/tests/fulltests/test_context.R +++ b/R/pkg/tests/fulltests/test_context.R @@ -54,15 +54,6 @@ test_that("Check masked functions", { sort(namesOfMaskedCompletely, na.last = TRUE)) }) -test_that("repeatedly starting and stopping SparkR", { - for (i in 1:4) { - sc <- suppressWarnings(sparkR.init(master = sparkRTestMaster)) - rdd <- parallelize(sc, 1:20, 2L) - expect_equal(countRDD(rdd), 20) - suppressWarnings(sparkR.stop()) - } -}) - test_that("repeatedly starting and stopping SparkSession", { for (i in 1:4) { sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) @@ -101,9 +92,6 @@ test_that("job group functions can be called", { cancelJobGroup("groupId") clearJobGroup() - suppressWarnings(setJobGroup(sc, "groupId", "job description", TRUE)) - suppressWarnings(cancelJobGroup(sc, "groupId")) - suppressWarnings(clearJobGroup(sc)) sparkR.session.stop() }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 68bf5eac98462..58e0a54d2aacc 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -106,15 +106,6 @@ if (is_windows()) { Sys.setenv(TZ = "GMT") } -test_that("calling sparkRSQL.init returns existing SQL context", { - sqlContext <- suppressWarnings(sparkRSQL.init(sc)) - expect_equal(suppressWarnings(sparkRSQL.init(sc)), sqlContext) -}) - -test_that("calling sparkRSQL.init returns existing SparkSession", { - expect_equal(suppressWarnings(sparkRSQL.init(sc)), sparkSession) -}) - test_that("calling sparkR.session returns existing SparkSession", { expect_equal(sparkR.session(), sparkSession) }) @@ -221,7 +212,7 @@ test_that("structField type strings", { test_that("create DataFrame from RDD", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) - df <- createDataFrame(rdd, list("a", "b")) + df <- SparkR::createDataFrame(rdd, list("a", "b")) dfAsDF <- as.DataFrame(rdd, list("a", "b")) expect_is(df, "SparkDataFrame") expect_is(dfAsDF, "SparkDataFrame") @@ -287,7 +278,7 @@ test_that("create DataFrame from RDD", { df <- as.DataFrame(cars, numPartitions = 2) expect_equal(getNumPartitions(df), 2) - df <- createDataFrame(cars, numPartitions = 3) + df <- SparkR::createDataFrame(cars, numPartitions = 3) expect_equal(getNumPartitions(df), 3) # validate limit by num of rows df <- createDataFrame(cars, numPartitions = 60) @@ -308,7 +299,7 @@ test_that("create DataFrame from RDD", { sql("CREATE TABLE people (name string, age double, height float)") df <- read.df(jsonPathNa, "json", schema) insertInto(df, "people") - expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, + expect_equal(collect(SparkR::sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16)) expect_equal(collect(sql("SELECT height from people WHERE name ='Bob'"))$height, c(176.5)) @@ -665,10 +656,10 @@ test_that("test tableNames and tables", { expect_true("tableName" %in% colnames(tables())) expect_true(all(c("tableName", "database", "isTemporary") %in% colnames(tables()))) - suppressWarnings(registerTempTable(df, "table2")) + createOrReplaceTempView(df, "table2") tables <- listTables() expect_equal(count(tables), count + 2) - suppressWarnings(dropTempTable("table1")) + dropTempView("table1") expect_true(dropTempView("table2")) tables <- listTables() @@ -2461,7 +2452,7 @@ test_that("union(), unionByName(), rbind(), except(), and intersect() on a DataF expect_is(unioned, "SparkDataFrame") expect_equal(count(unioned), 6) expect_equal(first(unioned)$name, "Michael") - expect_equal(count(arrange(suppressWarnings(unionAll(df, df2)), df$age)), 6) + expect_equal(count(arrange(suppressWarnings(union(df, df2)), df$age)), 6) df1 <- select(df2, "age", "name") unioned1 <- arrange(unionByName(df1, df), df1$age) @@ -2640,11 +2631,11 @@ test_that("read/write Parquet files", { expect_is(df2, "SparkDataFrame") expect_equal(count(df2), 3) - # Test write.parquet/saveAsParquetFile and read.parquet + # Test write.parquet and read.parquet parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") write.parquet(df, parquetPath2) parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") - suppressWarnings(saveAsParquetFile(df, parquetPath3)) + write.parquet(df, parquetPath3) parquetDF <- read.parquet(c(parquetPath2, parquetPath3)) expect_is(parquetDF, "SparkDataFrame") expect_equal(count(parquetDF), count(df) * 2) @@ -3456,39 +3447,6 @@ test_that("Window functions on a DataFrame", { expect_equal(result, expected) }) -test_that("createDataFrame sqlContext parameter backward compatibility", { - sqlContext <- suppressWarnings(sparkRSQL.init(sc)) - a <- 1:3 - b <- c("a", "b", "c") - ldf <- data.frame(a, b) - # Call function with namespace :: operator - SPARK-16538 - df <- suppressWarnings(SparkR::createDataFrame(sqlContext, ldf)) - expect_equal(columns(df), c("a", "b")) - expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) - expect_equal(count(df), 3) - ldf2 <- collect(df) - expect_equal(ldf$a, ldf2$a) - - df2 <- suppressWarnings(createDataFrame(sqlContext, iris)) - expect_equal(count(df2), 150) - expect_equal(ncol(df2), 5) - - df3 <- suppressWarnings(read.df(sqlContext, jsonPath, "json")) - expect_is(df3, "SparkDataFrame") - expect_equal(count(df3), 3) - - before <- suppressWarnings(createDataFrame(sqlContext, iris)) - after <- suppressWarnings(createDataFrame(iris)) - expect_equal(collect(before), collect(after)) - - # more tests for SPARK-16538 - createOrReplaceTempView(df, "table") - SparkR::listTables() - SparkR::sql("SELECT 1") - suppressWarnings(SparkR::sql(sqlContext, "SELECT * FROM table")) - suppressWarnings(SparkR::dropTempTable(sqlContext, "table")) -}) - test_that("randomSplit", { num <- 4000 df <- createDataFrame(data.frame(id = 1:num)) @@ -3675,7 +3633,7 @@ test_that("catalog APIs, listTables, listColumns, listFunctions", { createOrReplaceTempView(as.DataFrame(cars), "cars") - tb <- listTables() + tb <- SparkR::listTables() expect_equal(nrow(tb), count + 1) tbs <- collect(tb) expect_true(nrow(tbs[tbs$name == "cars", ]) > 0) diff --git a/docs/sparkr.md b/docs/sparkr.md index 5882ed7923aa7..cc6bc6d14853d 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -487,7 +487,7 @@ df2 ##+-------+-----+ ##only showing top 10 rows -{% endhighlight %} +{% endhighlight %} Note that to enable eager execution in `sparkR` shell, add `spark.sql.repl.eagerEval.enabled=true` configuration property to the `--conf` option. @@ -717,4 +717,5 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma ## Upgrading to SparkR 3.0.0 - - The deprecated methods `parquetFile`, `jsonRDD` and `jsonFile` in `SQLContext` have been removed. Use `read.parquet` and `read.json`. + - The deprecated methods `sparkR.init`, `sparkRSQL.init`, `sparkRHive.init` have been removed. Use `sparkR.session` instead. + - The deprecated methods `parquetFile`, `saveAsParquetFile`, `jsonFile`, `registerTempTable`, `createExternalTable`, `dropTempTable`, `unionAll` have been removed. Use `read.parquet`, `write.parquet`, `read.json`, `createOrReplaceTempView`, `createTable`, `dropTempView`, `union` instead. From e545811346189cb9770bb54dc31ba93057cdc68e Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Sun, 28 Oct 2018 09:38:38 +0800 Subject: [PATCH 1940/2461] [SPARK-19851][SQL] Add support for EVERY and ANY (SOME) aggregates ## What changes were proposed in this pull request? Implements Every, Some, Any aggregates in SQL. These new aggregate expressions are analyzed in normal way and rewritten to equivalent existing aggregate expressions in the optimizer. Every(x) => Min(x) where x is boolean. Some(x) => Max(x) where x is boolean. Any is a synonym for Some. SQL ``` explain extended select every(v) from test_agg group by k; ``` Plan : ``` == Parsed Logical Plan == 'Aggregate ['k], [unresolvedalias('every('v), None)] +- 'UnresolvedRelation `test_agg` == Analyzed Logical Plan == every(v): boolean Aggregate [k#0], [every(v#1) AS every(v)#5] +- SubqueryAlias `test_agg` +- Project [k#0, v#1] +- SubqueryAlias `test_agg` +- LocalRelation [k#0, v#1] == Optimized Logical Plan == Aggregate [k#0], [min(v#1) AS every(v)#5] +- LocalRelation [k#0, v#1] == Physical Plan == *(2) HashAggregate(keys=[k#0], functions=[min(v#1)], output=[every(v)#5]) +- Exchange hashpartitioning(k#0, 200) +- *(1) HashAggregate(keys=[k#0], functions=[partial_min(v#1)], output=[k#0, min#7]) +- LocalTableScan [k#0, v#1] Time taken: 0.512 seconds, Fetched 1 row(s) ``` ## How was this patch tested? Added tests in SQLQueryTestSuite, DataframeAggregateSuite Closes #22809 from dilipbiswal/SPARK-19851-specific-rewrite. Authored-by: Dilip Biswal Signed-off-by: Wenchen Fan --- .../catalyst/analysis/FunctionRegistry.scala | 3 + .../sql/catalyst/expressions/Expression.scala | 26 +++ .../aggregate/UnevaluableAggs.scala | 62 +++++ .../catalyst/optimizer/finishAnalysis.scala | 18 +- .../ExpressionTypeCheckingSuite.scala | 3 + .../resources/sql-tests/inputs/group-by.sql | 66 ++++++ .../sql-tests/results/group-by.sql.out | 214 +++++++++++++++++- 7 files changed, 388 insertions(+), 4 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 38f5c02910f79..af6166bcb8692 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -300,6 +300,9 @@ object FunctionRegistry { expression[CollectList]("collect_list"), expression[CollectSet]("collect_set"), expression[CountMinSketchAgg]("count_min_sketch"), + expression[EveryAgg]("every"), + expression[AnyAgg]("any"), + expression[SomeAgg]("some"), // string functions expression[Ascii]("ascii"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c215735ab1c98..ccc5b9043a0aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -21,6 +21,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} +import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreeNode @@ -282,6 +283,31 @@ trait RuntimeReplaceable extends UnaryExpression with Unevaluable { override lazy val canonicalized: Expression = child.canonicalized } +/** + * An aggregate expression that gets rewritten (currently by the optimizer) into a + * different aggregate expression for evaluation. This is mainly used to provide compatibility + * with other databases. For example, we use this to support every, any/some aggregates by rewriting + * them with Min and Max respectively. + */ +trait UnevaluableAggregate extends DeclarativeAggregate { + + override def nullable: Boolean = true + + override lazy val aggBufferAttributes = + throw new UnsupportedOperationException(s"Cannot evaluate aggBufferAttributes: $this") + + override lazy val initialValues: Seq[Expression] = + throw new UnsupportedOperationException(s"Cannot evaluate initialValues: $this") + + override lazy val updateExpressions: Seq[Expression] = + throw new UnsupportedOperationException(s"Cannot evaluate updateExpressions: $this") + + override lazy val mergeExpressions: Seq[Expression] = + throw new UnsupportedOperationException(s"Cannot evaluate mergeExpressions: $this") + + override lazy val evaluateExpression: Expression = + throw new UnsupportedOperationException(s"Cannot evaluate evaluateExpression: $this") +} /** * Expressions that don't have SQL representation should extend this trait. Examples are diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala new file mode 100644 index 0000000000000..fc33ef919498b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +abstract class UnevaluableBooleanAggBase(arg: Expression) + extends UnevaluableAggregate with ImplicitCastInputTypes { + + override def children: Seq[Expression] = arg :: Nil + + override def dataType: DataType = BooleanType + + override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType) + + override def checkInputDataTypes(): TypeCheckResult = { + arg.dataType match { + case dt if dt != BooleanType => + TypeCheckResult.TypeCheckFailure(s"Input to function '$prettyName' should have been " + + s"${BooleanType.simpleString}, but it's [${arg.dataType.catalogString}].") + case _ => TypeCheckResult.TypeCheckSuccess + } + } +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if all values of `expr` are true.", + since = "3.0.0") +case class EveryAgg(arg: Expression) extends UnevaluableBooleanAggBase(arg) { + override def nodeName: String = "Every" +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.", + since = "3.0.0") +case class AnyAgg(arg: Expression) extends UnevaluableBooleanAggBase(arg) { + override def nodeName: String = "Any" +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.", + since = "3.0.0") +case class SomeAgg(arg: Expression) extends UnevaluableBooleanAggBase(arg) { + override def nodeName: String = "Some" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index af0837e36e8ad..fe196ec7c9d54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -28,13 +29,24 @@ import org.apache.spark.sql.types._ /** - * Finds all [[RuntimeReplaceable]] expressions and replace them with the expressions that can - * be evaluated. This is mainly used to provide compatibility with other databases. - * For example, we use this to support "nvl" by replacing it with "coalesce". + * Finds all the expressions that are unevaluable and replace/rewrite them with semantically + * equivalent expressions that can be evaluated. Currently we replace two kinds of expressions: + * 1) [[RuntimeReplaceable]] expressions + * 2) [[UnevaluableAggregate]] expressions such as Every, Some, Any + * This is mainly used to provide compatibility with other databases. + * Few examples are: + * we use this to support "nvl" by replacing it with "coalesce". + * we use this to replace Every and Any with Min and Max respectively. + * + * TODO: In future, explore an option to replace aggregate functions similar to + * how RruntimeReplaceable does. */ object ReplaceExpressions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e: RuntimeReplaceable => e.child + case SomeAgg(arg) => Max(arg) + case AnyAgg(arg) => Max(arg) + case EveryAgg(arg) => Min(arg) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 8eec14842c7e7..3eb3fe66cebc5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -144,6 +144,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(Sum('stringField)) assertSuccess(Average('stringField)) assertSuccess(Min('arrayField)) + assertSuccess(new EveryAgg('booleanField)) + assertSuccess(new AnyAgg('booleanField)) + assertSuccess(new SomeAgg('booleanField)) assertError(Min('mapField), "min does not support ordering on type") assertError(Max('mapField), "max does not support ordering on type") diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 433db71527437..ec263ea70bd4a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -80,3 +80,69 @@ SELECT 1 FROM range(10) HAVING true; SELECT 1 FROM range(10) HAVING MAX(id) > 0; SELECT id FROM range(10) HAVING id > 0; + +-- Test data +CREATE OR REPLACE TEMPORARY VIEW test_agg AS SELECT * FROM VALUES + (1, true), (1, false), + (2, true), + (3, false), (3, null), + (4, null), (4, null), + (5, null), (5, true), (5, false) AS test_agg(k, v); + +-- empty table +SELECT every(v), some(v), any(v) FROM test_agg WHERE 1 = 0; + +-- all null values +SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 4; + +-- aggregates are null Filtering +SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 5; + +-- group by +SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k; + +-- having +SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) = false; +SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) IS NULL; + +-- basic subquery path to make sure rewrite happens in both parent and child plans. +SELECT k, + Every(v) AS every +FROM test_agg +WHERE k = 2 + AND v IN (SELECT Any(v) + FROM test_agg + WHERE k = 1) +GROUP BY k; + +-- basic subquery path to make sure rewrite happens in both parent and child plans. +SELECT k, + Every(v) AS every +FROM test_agg +WHERE k = 2 + AND v IN (SELECT Every(v) + FROM test_agg + WHERE k = 1) +GROUP BY k; + +-- input type checking Int +SELECT every(1); + +-- input type checking Short +SELECT some(1S); + +-- input type checking Long +SELECT any(1L); + +-- input type checking String +SELECT every("true"); + +-- every/some/any aggregates are supported as windows expression. +SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; +SELECT k, v, some(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; +SELECT k, v, any(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; + +-- simple explain of queries having every/some/any agregates. Optimized +-- plan should show the rewritten aggregate expression. +EXPLAIN EXTENDED SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k; + diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index f9d1ee8a6bcdb..9a8d025331b67 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 30 +-- Number of queries: 47 -- !query 0 @@ -275,3 +275,215 @@ struct<> -- !query 29 output org.apache.spark.sql.AnalysisException grouping expressions sequence is empty, and '`id`' is not an aggregate function. Wrap '()' in windowing function(s) or wrap '`id`' in first() (or first_value) if you don't care which value you get.; + + +-- !query 30 +CREATE OR REPLACE TEMPORARY VIEW test_agg AS SELECT * FROM VALUES + (1, true), (1, false), + (2, true), + (3, false), (3, null), + (4, null), (4, null), + (5, null), (5, true), (5, false) AS test_agg(k, v) +-- !query 30 schema +struct<> +-- !query 30 output + + + +-- !query 31 +SELECT every(v), some(v), any(v) FROM test_agg WHERE 1 = 0 +-- !query 31 schema +struct +-- !query 31 output +NULL NULL NULL + + +-- !query 32 +SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 4 +-- !query 32 schema +struct +-- !query 32 output +NULL NULL NULL + + +-- !query 33 +SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 5 +-- !query 33 schema +struct +-- !query 33 output +false true true + + +-- !query 34 +SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k +-- !query 34 schema +struct +-- !query 34 output +1 false true true +2 true true true +3 false false false +4 NULL NULL NULL +5 false true true + + +-- !query 35 +SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) = false +-- !query 35 schema +struct +-- !query 35 output +1 false +3 false +5 false + + +-- !query 36 +SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) IS NULL +-- !query 36 schema +struct +-- !query 36 output +4 NULL + + +-- !query 37 +SELECT k, + Every(v) AS every +FROM test_agg +WHERE k = 2 + AND v IN (SELECT Any(v) + FROM test_agg + WHERE k = 1) +GROUP BY k +-- !query 37 schema +struct +-- !query 37 output +2 true + + +-- !query 38 +SELECT k, + Every(v) AS every +FROM test_agg +WHERE k = 2 + AND v IN (SELECT Every(v) + FROM test_agg + WHERE k = 1) +GROUP BY k +-- !query 38 schema +struct +-- !query 38 output + + + +-- !query 39 +SELECT every(1) +-- !query 39 schema +struct<> +-- !query 39 output +org.apache.spark.sql.AnalysisException +cannot resolve 'every(1)' due to data type mismatch: Input to function 'every' should have been boolean, but it's [int].; line 1 pos 7 + + +-- !query 40 +SELECT some(1S) +-- !query 40 schema +struct<> +-- !query 40 output +org.apache.spark.sql.AnalysisException +cannot resolve 'some(1S)' due to data type mismatch: Input to function 'some' should have been boolean, but it's [smallint].; line 1 pos 7 + + +-- !query 41 +SELECT any(1L) +-- !query 41 schema +struct<> +-- !query 41 output +org.apache.spark.sql.AnalysisException +cannot resolve 'any(1L)' due to data type mismatch: Input to function 'any' should have been boolean, but it's [bigint].; line 1 pos 7 + + +-- !query 42 +SELECT every("true") +-- !query 42 schema +struct<> +-- !query 42 output +org.apache.spark.sql.AnalysisException +cannot resolve 'every('true')' due to data type mismatch: Input to function 'every' should have been boolean, but it's [string].; line 1 pos 7 + + +-- !query 43 +SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg +-- !query 43 schema +struct +-- !query 43 output +1 false false +1 true false +2 true true +3 NULL NULL +3 false false +4 NULL NULL +4 NULL NULL +5 NULL NULL +5 false false +5 true false + + +-- !query 44 +SELECT k, v, some(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg +-- !query 44 schema +struct +-- !query 44 output +1 false false +1 true true +2 true true +3 NULL NULL +3 false false +4 NULL NULL +4 NULL NULL +5 NULL NULL +5 false false +5 true true + + +-- !query 45 +SELECT k, v, any(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg +-- !query 45 schema +struct +-- !query 45 output +1 false false +1 true true +2 true true +3 NULL NULL +3 false false +4 NULL NULL +4 NULL NULL +5 NULL NULL +5 false false +5 true true + + +-- !query 46 +EXPLAIN EXTENDED SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k +-- !query 46 schema +struct +-- !query 46 output +== Parsed Logical Plan == +'Aggregate ['k], ['k, unresolvedalias('every('v), None), unresolvedalias('some('v), None), unresolvedalias('any('v), None)] ++- 'UnresolvedRelation `test_agg` + +== Analyzed Logical Plan == +k: int, every(v): boolean, some(v): boolean, any(v): boolean +Aggregate [k#x], [k#x, every(v#x) AS every(v)#x, some(v#x) AS some(v)#x, any(v#x) AS any(v)#x] ++- SubqueryAlias `test_agg` + +- Project [k#x, v#x] + +- SubqueryAlias `test_agg` + +- LocalRelation [k#x, v#x] + +== Optimized Logical Plan == +Aggregate [k#x], [k#x, min(v#x) AS every(v)#x, max(v#x) AS some(v)#x, max(v#x) AS any(v)#x] ++- LocalRelation [k#x, v#x] + +== Physical Plan == +*HashAggregate(keys=[k#x], functions=[min(v#x), max(v#x)], output=[k#x, every(v)#x, some(v)#x, any(v)#x]) ++- Exchange hashpartitioning(k#x, 200) + +- *HashAggregate(keys=[k#x], functions=[partial_min(v#x), partial_max(v#x)], output=[k#x, min#x, max#x]) + +- LocalTableScan [k#x, v#x] From ff4bb836aa768082df9227628dfd5a837f8e4f4e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 28 Oct 2018 13:33:26 +0800 Subject: [PATCH 1941/2461] [SPARK-25817][SQL] Dataset encoder should support combination of map and product type ## What changes were proposed in this pull request? After https://github.com/apache/spark/pull/22745 , Dataset encoder supports the combination of java bean and map type. This PR is to fix the Scala side. The reason why it didn't work before is, `CatalystToExternalMap` tries to get the data type of the input map expression, while it can be unresolved and its data type is known. To fix it, we can follow `UnresolvedMapObjects`, to create a `UnresolvedCatalystToExternalMap`, and only create `CatalystToExternalMap` when the input map expression is resolved and the data type is known. ## How was this patch tested? enable a old test case Closes #22812 from cloud-fan/map. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/ScalaReflection.scala | 15 +++-- .../sql/catalyst/analysis/Analyzer.scala | 13 ++++- .../catalyst/encoders/ExpressionEncoder.scala | 8 +-- .../expressions/objects/objects.scala | 56 +++++++++---------- .../spark/sql/DatasetPrimitiveSuite.scala | 2 +- .../org/apache/spark/sql/DatasetSuite.scala | 9 +++ 6 files changed, 59 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 40074b36f6a9a..912744eab6a3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -143,8 +143,7 @@ object ScalaReflection extends ScalaReflection { walkedTypePath: Seq[String]): Expression = expected match { case _: StructType => expr case _: ArrayType => expr - // TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and - // it's not trivial to support by-name resolution for StructType inside MapType. + case _: MapType => expr case _ => UpCast(expr, expected, walkedTypePath) } @@ -163,8 +162,8 @@ object ScalaReflection extends ScalaReflection { val Schema(dataType, nullable) = schemaFor(tpe) // Assumes we are deserializing the first column of a row. - val input = upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType, - walkedTypePath) + val input = upCastToExpectedType( + GetColumnByOrdinal(0, dataType), dataType, walkedTypePath) val expr = deserializerFor(tpe, input, walkedTypePath) if (nullable) { @@ -350,10 +349,10 @@ object ScalaReflection extends ScalaReflection { // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t - CatalystToExternalMap( + UnresolvedCatalystToExternalMap( + path, p => deserializerFor(keyType, p, walkedTypePath), p => deserializerFor(valueType, p, walkedTypePath), - path, mirror.runtimeClass(t.typeSymbol.asClass) ) @@ -431,8 +430,8 @@ object ScalaReflection extends ScalaReflection { val walkedTypePath = s"""- root class: "$clsName"""" :: Nil // The input object to `ExpressionEncoder` is located at first column of an row. - val inputObject = BoundReference(0, dataTypeFor(tpe), - nullable = !tpe.typeSymbol.asClass.isPrimitive) + val isPrimitive = tpe.typeSymbol.asClass.isPrimitive + val inputObject = BoundReference(0, dataTypeFor(tpe), nullable = !isPrimitive) serializerFor(inputObject, tpe, walkedTypePath) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 63a07e3898f22..c2d22c5e7ce60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2384,14 +2384,23 @@ class Analyzer( case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved => inputData.dataType match { case ArrayType(et, cn) => - val expr = MapObjects(func, inputData, et, cn, cls) transformUp { + MapObjects(func, inputData, et, cn, cls) transformUp { case UnresolvedExtractValue(child, fieldName) if child.resolved => ExtractValue(child, fieldName, resolver) } - expr case other => throw new AnalysisException("need an array field but got " + other.catalogString) } + case u: UnresolvedCatalystToExternalMap if u.child.resolved => + u.child.dataType match { + case _: MapType => + CatalystToExternalMap(u) transformUp { + case UnresolvedExtractValue(child, fieldName) if child.resolved => + ExtractValue(child, fieldName, resolver) + } + case other => + throw new AnalysisException("need a map field but got " + other.catalogString) + } } validateNestedTupleFields(result) result diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 29f6136a75ee8..2c8e81ef17d72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -119,10 +119,9 @@ object ExpressionEncoder { } val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) => - val getColumnsByOrdinals = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c } - .distinct - assert(getColumnsByOrdinals.size == 1, "object deserializer should have only one " + - s"`GetColumnByOrdinal`, but there are ${getColumnsByOrdinals.size}") + val getColExprs = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }.distinct + assert(getColExprs.size == 1, "object deserializer should have only one " + + s"`GetColumnByOrdinal`, but there are ${getColExprs.size}") val input = GetStructField(GetColumnByOrdinal(0, schema), index) val newDeserializer = enc.objDeserializer.transformUp { @@ -216,7 +215,6 @@ case class ExpressionEncoder[T]( } nullSafeSerializer match { case If(_: IsNull, _, s: CreateNamedStruct) => s - case s: CreateNamedStruct => s case _ => throw new RuntimeException(s"class $clsName has unexpected serializer: $objSerializer") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index b6f9b4734e940..4fd36a47cef52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -30,14 +30,13 @@ import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException} import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import org.apache.spark.util.Utils /** @@ -963,25 +962,32 @@ case class MapObjects private( } } +/** + * Similar to [[UnresolvedMapObjects]], this is a placeholder of [[CatalystToExternalMap]]. + * + * @param child An expression that when evaluated returns a map object. + * @param keyFunction The function applied on the key collection elements. + * @param valueFunction The function applied on the value collection elements. + * @param collClass The type of the resulting collection. + */ +case class UnresolvedCatalystToExternalMap( + child: Expression, + @transient keyFunction: Expression => Expression, + @transient valueFunction: Expression => Expression, + collClass: Class[_]) extends UnaryExpression with Unevaluable { + + override lazy val resolved = false + + override def dataType: DataType = ObjectType(collClass) +} + object CatalystToExternalMap { private val curId = new java.util.concurrent.atomic.AtomicInteger() - /** - * Construct an instance of CatalystToExternalMap case class. - * - * @param keyFunction The function applied on the key collection elements. - * @param valueFunction The function applied on the value collection elements. - * @param inputData An expression that when evaluated returns a map object. - * @param collClass The type of the resulting collection. - */ - def apply( - keyFunction: Expression => Expression, - valueFunction: Expression => Expression, - inputData: Expression, - collClass: Class[_]): CatalystToExternalMap = { + def apply(u: UnresolvedCatalystToExternalMap): CatalystToExternalMap = { val id = curId.getAndIncrement() val keyLoopValue = s"CatalystToExternalMap_keyLoopValue$id" - val mapType = inputData.dataType.asInstanceOf[MapType] + val mapType = u.child.dataType.asInstanceOf[MapType] val keyLoopVar = LambdaVariable(keyLoopValue, "", mapType.keyType, nullable = false) val valueLoopValue = s"CatalystToExternalMap_valueLoopValue$id" val valueLoopIsNull = if (mapType.valueContainsNull) { @@ -991,9 +997,9 @@ object CatalystToExternalMap { } val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType) CatalystToExternalMap( - keyLoopValue, keyFunction(keyLoopVar), - valueLoopValue, valueLoopIsNull, valueFunction(valueLoopVar), - inputData, collClass) + keyLoopValue, u.keyFunction(keyLoopVar), + valueLoopValue, valueLoopIsNull, u.valueFunction(valueLoopVar), + u.child, u.collClass) } } @@ -1090,15 +1096,9 @@ case class CatalystToExternalMap private( val tupleLoopValue = ctx.freshName("tupleLoopValue") val builderValue = ctx.freshName("builderValue") - val getLength = s"${genInputData.value}.numElements()" - val keyArray = ctx.freshName("keyArray") val valueArray = ctx.freshName("valueArray") - val getKeyArray = - s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();" val getKeyLoopVar = CodeGenerator.getValue(keyArray, inputDataType(mapType.keyType), loopIndex) - val getValueArray = - s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();" val getValueLoopVar = CodeGenerator.getValue( valueArray, inputDataType(mapType.valueType), loopIndex) @@ -1147,10 +1147,10 @@ case class CatalystToExternalMap private( ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { - int $dataLength = $getLength; + int $dataLength = ${genInputData.value}.numElements(); $constructBuilder - $getKeyArray - $getValueArray + ArrayData $keyArray = ${genInputData.value}.keyArray(); + ArrayData $valueArray = ${genInputData.value}.valueArray(); int $loopIndex = 0; while ($loopIndex < $dataLength) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index edcdd77908d3a..96a6792f52f3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -295,7 +295,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { checkDataset(Seq(LHMap("test" -> 2.toLong)).toDS(), LHMap("test" -> 2.toLong)) } - ignore("SPARK-19104: map and product combinations") { + test("SPARK-25817: map and product combinations") { // Case classes checkDataset(Seq(MapClass(Map(1 -> 2))).toDS(), MapClass(Map(1 -> 2))) checkDataset(Seq(Map(1 -> MapClass(Map(2 -> 3)))).toDS(), Map(1 -> MapClass(Map(2 -> 3)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 27b3b3d78d2bb..82d3b22a48670 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -164,6 +164,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq(ClassData("a", 2)))) } + test("as map of case class - reorder fields by name") { + val df = spark.range(3).select(map(lit(1), struct($"id".cast("int").as("b"), lit("a").as("a")))) + val ds = df.as[Map[Int, ClassData]] + assert(ds.collect() === Array( + Map(1 -> ClassData("a", 0)), + Map(1 -> ClassData("a", 1)), + Map(1 -> ClassData("a", 2)))) + } + test("map") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkDataset( From a7ab7f2348cfcd665f7815f5a9ae4d9a48383b5d Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Sun, 28 Oct 2018 18:15:47 +0800 Subject: [PATCH 1942/2461] [SPARK-25845][SQL] Fix MatchError for calendar interval type in range frame left boundary ## What changes were proposed in this pull request? WindowSpecDefinition checks start < last, but CalendarIntervalType is not comparable, so it would throw the following exception at runtime: ``` scala.MatchError: CalendarIntervalType (of class org.apache.spark.sql.types.CalendarIntervalType$) at org.apache.spark.sql.catalyst.util.TypeUtils$.getInterpretedOrdering(TypeUtils.scala:58) at org.apache.spark.sql.catalyst.expressions.BinaryComparison.ordering$lzycompute(predicates.scala:592) at org.apache.spark.sql.catalyst.expressions.BinaryComparison.ordering(predicates.scala:592) at org.apache.spark.sql.catalyst.expressions.GreaterThan.nullSafeEval(predicates.scala:797) at org.apache.spark.sql.catalyst.expressions.BinaryExpression.eval(Expression.scala:496) at org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame.isGreaterThan(windowExpressions.scala:245) at org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame.checkInputDataTypes(windowExpressions.scala:216) at org.apache.spark.sql.catalyst.expressions.Expression.resolved$lzycompute(Expression.scala:171) at org.apache.spark.sql.catalyst.expressions.Expression.resolved(Expression.scala:171) at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$childrenResolved$1.apply(Expression.scala:183) at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$childrenResolved$1.apply(Expression.scala:183) at scala.collection.IndexedSeqOptimized$class.prefixLengthImpl(IndexedSeqOptimized.scala:38) at scala.collection.IndexedSeqOptimized$class.forall(IndexedSeqOptimized.scala:43) at scala.collection.mutable.ArrayBuffer.forall(ArrayBuffer.scala:48) at org.apache.spark.sql.catalyst.expressions.Expression.childrenResolved(Expression.scala:183) at org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition.resolved$lzycompute(windowExpressions.scala:48) at org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition.resolved(windowExpressions.scala:48) at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$childrenResolved$1.apply(Expression.scala:183) at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$childrenResolved$1.apply(Expression.scala:183) at scala.collection.LinearSeqOptimized$class.forall(LinearSeqOptimized.scala:83) ``` We fix the issue by only perform the check on boundary expressions that are AtomicType. ## How was this patch tested? Add new test case in `DataFrameWindowFramesSuite` Closes #22853 from jiangxb1987/windowBoundary. Authored-by: Xingbo Jiang Signed-off-by: Xingbo Jiang --- .../expressions/windowExpressions.scala | 8 ++++++-- .../sql/DataFrameWindowFramesSuite.scala | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 707f312499734..7de6dddda4d3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -242,8 +242,12 @@ case class SpecifiedWindowFrame( case e: Expression => e.sql + " FOLLOWING" } - private def isGreaterThan(l: Expression, r: Expression): Boolean = { - GreaterThan(l, r).eval().asInstanceOf[Boolean] + // Check whether the left boundary value is greater than the right boundary value. It's required + // that the both expressions have the same data type. + // Since CalendarIntervalType is not comparable, we only compare expressions that are AtomicType. + private def isGreaterThan(l: Expression, r: Expression): Boolean = l.dataType match { + case _: AtomicType => GreaterThan(l, r).eval().asInstanceOf[Boolean] + case _ => false } private def checkBoundary(b: Expression, location: String): TypeCheckResult = b match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala index 2a0b2b85e10a9..9c280744682b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala @@ -267,6 +267,25 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSQLContext { ) } + test("range between should accept interval values as both boundaries") { + def ts(timestamp: Long): Timestamp = new Timestamp(timestamp * 1000) + + val df = Seq((ts(1501545600), "1"), (ts(1501545600), "1"), (ts(1609372800), "1"), + (ts(1503000000), "2"), (ts(1502000000), "1"), (ts(1609372800), "2")) + .toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key") + .rangeBetween(lit(CalendarInterval.fromString("interval 3 hours")), + lit(CalendarInterval.fromString("interval 23 days 4 hours"))) + + checkAnswer( + df.select( + $"key", + count("key").over(window)), + Seq(Row(ts(1501545600), 1), Row(ts(1501545600), 1), Row(ts(1609372800), 0), + Row(ts(1503000000), 0), Row(ts(1502000000), 0), Row(ts(1609372800), 0)) + ) + } + test("unbounded rows/range between with aggregation") { val df = Seq(("one", 1), ("two", 2), ("one", 3), ("two", 4)).toDF("key", "value") val window = Window.partitionBy($"key").orderBy($"value") From 4427a96bcea625bc51fc5e0e999f170ad537a2fc Mon Sep 17 00:00:00 2001 From: liuxian Date: Sun, 28 Oct 2018 17:39:16 -0500 Subject: [PATCH 1943/2461] [SPARK-25806][SQL] The instance of FileSplit is redundant ## What changes were proposed in this pull request? The instance of `FileSplit` is redundant for `ParquetFileFormat` and `hive\orc\OrcFileFormat` class. ## How was this patch tested? Existing unit tests in `ParquetQuerySuite.scala` and `HiveOrcQuerySuite.scala` Closes #22802 from 10110346/FileSplitnotneed. Authored-by: liuxian Signed-off-by: Sean Owen --- .../datasources/parquet/ParquetFileFormat.scala | 13 +++++-------- .../apache/spark/sql/hive/orc/OrcFileFormat.scala | 3 +-- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index ea4f1592a7c2e..f04502d113acb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -352,17 +352,14 @@ class ParquetFileFormat (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) - val fileSplit = - new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty) - val filePath = fileSplit.getPath - + val filePath = new Path(new URI(file.filePath)) val split = new org.apache.parquet.hadoop.ParquetInputSplit( filePath, - fileSplit.getStart, - fileSplit.getStart + fileSplit.getLength, - fileSplit.getLength, - fileSplit.getLocations, + file.start, + file.start + file.length, + file.length, + Array.empty, null) val sharedConf = broadcastedHadoopConf.value.value diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index de8085f07db19..89e6ea8604974 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -154,13 +154,12 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable val job = Job.getInstance(conf) FileInputFormat.setInputPaths(job, file.filePath) - val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty) // Custom OrcRecordReader is used to get // ObjectInspector during recordReader creation itself and can // avoid NameNode call in unwrapOrcStructs per file. // Specifically would be helpful for partitioned datasets. val orcReader = OrcFile.createReader(filePath, OrcFile.readerOptions(conf)) - new SparkOrcNewRecordReader(orcReader, conf, fileSplit.getStart, fileSplit.getLength) + new SparkOrcNewRecordReader(orcReader, conf, file.start, file.length) } val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader) From ca2fca143277deaff58a69b7f1e0360cfc70561f Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sun, 28 Oct 2018 17:51:35 -0700 Subject: [PATCH 1944/2461] [SPARK-25816][SQL] Fix attribute resolution in nested extractors ## What changes were proposed in this pull request? Extractors are made of 2 expressions, one of them defines the the value to be extract from (called `child`) and the other defines the way of extraction (called `extraction`). In this term extractors have 2 children so they shouldn't be `UnaryExpression`s. `ResolveReferences` was changed in this commit: https://github.com/apache/spark/commit/36b826f5d17ae7be89135cb2c43ff797f9e7fe48 which resulted a regression with nested extractors. An extractor need to define its children as the set of both `child` and `extraction`; and should try to resolve both in `ResolveReferences`. This PR changes `UnresolvedExtractValue` to a `BinaryExpression`. ## How was this patch tested? added UT Closes #22817 from peter-toth/SPARK-25816. Authored-by: Peter Toth Signed-off-by: gatorsmile --- .../apache/spark/sql/catalyst/analysis/unresolved.scala | 5 ++++- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 7 +++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index c1ec736c32ed4..857cf382b8f2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -407,7 +407,10 @@ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star with Une * can be key of Map, index of Array, field name of Struct. */ case class UnresolvedExtractValue(child: Expression, extraction: Expression) - extends UnaryExpression with Unevaluable { + extends BinaryExpression with Unevaluable { + + override def left: Expression = child + override def right: Expression = extraction override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 3f9af29aa1af5..a430884581dad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2578,4 +2578,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row ("abc", 1)) } } + + test("SPARK-25816 ResolveReferences works with nested extractors") { + val df = Seq((1, Map(1 -> "a")), (2, Map(2 -> "b"))).toDF("key", "map") + val swappedDf = df.select($"key".as("map"), $"map".as("key")) + + checkAnswer(swappedDf.filter($"key"($"map") > "a"), Row(2, Map(2 -> "b"))) + } } From 4e990d9dd2407dc257712c4b12b507f0990ca4e9 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Mon, 29 Oct 2018 13:44:58 +0800 Subject: [PATCH 1945/2461] [DOC] Fix doc for spark.sql.parquet.recordLevelFilter.enabled ## What changes were proposed in this pull request? Updated the doc string value for spark.sql.parquet.recordLevelFilter.enabled to indicate that spark.sql.parquet.enableVectorizedReader must be disabled. The code in ParquetFileFormat uses spark.sql.parquet.recordLevelFilter.enabled only after falling back to parquet-mr (see else for this if statement): https://github.com/apache/spark/blob/d5573c578a1eea9ee04886d9df37c7178e67bb30/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala#L412 https://github.com/apache/spark/blob/d5573c578a1eea9ee04886d9df37c7178e67bb30/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala#L427-L430 Tests also bear this out. ## How was this patch tested? This is just a doc string fix: I built Spark and ran a single test. Closes #22865 from bersprockets/confdocfix. Authored-by: Bruce Robbins Signed-off-by: Wenchen Fan --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e8529550b8fca..4edffce120aac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -462,7 +462,8 @@ object SQLConf { val PARQUET_RECORD_FILTER_ENABLED = buildConf("spark.sql.parquet.recordLevelFilter.enabled") .doc("If true, enables Parquet's native record-level filtering using the pushed down " + "filters. This configuration only has an effect when 'spark.sql.parquet.filterPushdown' " + - "is enabled.") + "is enabled and the vectorized reader is not used. You can ensure the vectorized reader " + + "is not used by setting 'spark.sql.parquet.enableVectorizedReader' to false.") .booleanConf .createWithDefault(false) From fbaf150507a289ec0ac02fdbf4009c42cd9bc164 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 28 Oct 2018 23:01:35 -0700 Subject: [PATCH 1946/2461] [SPARK-25179][PYTHON][DOCS] Document BinaryType support in Arrow conversion ## What changes were proposed in this pull request? This PR targets to document binary type in "Apache Arrow in Spark". ## How was this patch tested? Manually built the documentation and checked. Closes #22871 from HyukjinKwon/SPARK-25179. Authored-by: hyukjinkwon Signed-off-by: gatorsmile --- docs/sql-pyspark-pandas-with-arrow.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/sql-pyspark-pandas-with-arrow.md b/docs/sql-pyspark-pandas-with-arrow.md index e8e9f55bd12b3..d04b955f9bf8b 100644 --- a/docs/sql-pyspark-pandas-with-arrow.md +++ b/docs/sql-pyspark-pandas-with-arrow.md @@ -127,8 +127,9 @@ For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/p ### Supported SQL Types -Currently, all Spark SQL data types are supported by Arrow-based conversion except `BinaryType`, `MapType`, -`ArrayType` of `TimestampType`, and nested `StructType`. +Currently, all Spark SQL data types are supported by Arrow-based conversion except `MapType`, +`ArrayType` of `TimestampType`, and nested `StructType`. `BinaryType` is supported only when +installed PyArrow is equal to or higher then 0.10.0. ### Setting Arrow Batch Size From 409d688fb6169ecbc41f6296c6341ae3ed7d1ec8 Mon Sep 17 00:00:00 2001 From: yucai Date: Mon, 29 Oct 2018 20:00:31 +0800 Subject: [PATCH 1947/2461] [SPARK-25864][SQL][TEST] Make main args accessible for BenchmarkBase's subclass ## What changes were proposed in this pull request? Set main args correctly in BenchmarkBase, to make it accessible for its subclass. It will benefit: - BuiltInDataSourceWriteBenchmark - AvroWriteBenchmark ## How was this patch tested? manual tests Closes #22872 from yucai/main_args. Authored-by: yucai Signed-off-by: Wenchen Fan --- .../test/scala/org/apache/spark/benchmark/BenchmarkBase.scala | 4 ++-- .../scala/org/apache/spark/serializer/KryoBenchmark.scala | 2 +- .../apache/spark/mllib/linalg/UDTSerializationBenchmark.scala | 2 +- .../src/test/scala/org/apache/spark/sql/HashBenchmark.scala | 2 +- .../scala/org/apache/spark/sql/HashByteArrayBenchmark.scala | 2 +- .../org/apache/spark/sql/UnsafeProjectionBenchmark.scala | 2 +- .../test/scala/org/apache/spark/sql/DatasetBenchmark.scala | 2 +- .../spark/sql/execution/benchmark/AggregateBenchmark.scala | 2 +- .../spark/sql/execution/benchmark/BloomFilterBenchmark.scala | 2 +- .../sql/execution/benchmark/DataSourceReadBenchmark.scala | 2 +- .../sql/execution/benchmark/FilterPushdownBenchmark.scala | 2 +- .../apache/spark/sql/execution/benchmark/JoinBenchmark.scala | 2 +- .../apache/spark/sql/execution/benchmark/MiscBenchmark.scala | 2 +- .../sql/execution/benchmark/PrimitiveArrayBenchmark.scala | 2 +- .../apache/spark/sql/execution/benchmark/RangeBenchmark.scala | 2 +- .../apache/spark/sql/execution/benchmark/SortBenchmark.scala | 2 +- .../sql/execution/benchmark/UnsafeArrayDataBenchmark.scala | 2 +- .../spark/sql/execution/benchmark/WideSchemaBenchmark.scala | 2 +- .../columnar/compression/CompressionSchemeBenchmark.scala | 2 +- .../sql/execution/vectorized/ColumnarBatchBenchmark.scala | 2 +- .../benchmark/ObjectHashAggregateExecBenchmark.scala | 2 +- .../org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala | 2 +- 22 files changed, 23 insertions(+), 23 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala b/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala index 89e927e5784d2..24e596e1ecdaf 100644 --- a/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala +++ b/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala @@ -30,7 +30,7 @@ abstract class BenchmarkBase { * Implementations of this method are supposed to use the wrapper method `runBenchmark` * for each benchmark scenario. */ - def runBenchmarkSuite(): Unit + def runBenchmarkSuite(mainArgs: Array[String]): Unit final def runBenchmark(benchmarkName: String)(func: => Any): Unit = { val separator = "=" * 96 @@ -51,7 +51,7 @@ abstract class BenchmarkBase { output = Some(new FileOutputStream(file)) } - runBenchmarkSuite() + runBenchmarkSuite(args) output.foreach { o => if (o != null) { diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoBenchmark.scala b/core/src/test/scala/org/apache/spark/serializer/KryoBenchmark.scala index 8a52c131af847..d7730f23da108 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoBenchmark.scala @@ -39,7 +39,7 @@ import org.apache.spark.serializer.KryoTest._ object KryoBenchmark extends BenchmarkBase { val N = 1000000 - override def runBenchmarkSuite(): Unit = { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { val name = "Benchmark Kryo Unsafe vs safe Serialization" runBenchmark(name) { val benchmark = new Benchmark(name, N, 10, output = output) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala index 6c1d58089867a..5f19e466ecad0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder */ object UDTSerializationBenchmark extends BenchmarkBase { - override def runBenchmarkSuite(): Unit = { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { runBenchmark("VectorUDT de/serialization") { val iters = 1e2.toInt diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala index 4226ab3773fe7..3b4b80daf0843 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala @@ -102,7 +102,7 @@ object HashBenchmark extends BenchmarkBase { } } - override def runBenchmarkSuite(): Unit = { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { val singleInt = new StructType().add("i", IntegerType) test("single ints", singleInt, 1 << 15, 1 << 14) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala index 7dc865d85af04..dbfa7bb18aa65 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala @@ -83,7 +83,7 @@ object HashByteArrayBenchmark extends BenchmarkBase { benchmark.run() } - override def runBenchmarkSuite(): Unit = { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { runBenchmark("Benchmark for MurMurHash 3 and xxHash64") { test(8, 42L, 1 << 10, 1 << 11) test(16, 42L, 1 << 10, 1 << 11) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala index e7a99485cdf04..42a4cfc91f826 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala @@ -41,7 +41,7 @@ object UnsafeProjectionBenchmark extends BenchmarkBase { (1 to numRows).map(_ => encoder.toRow(generator().asInstanceOf[Row]).copy()).toArray } - override def runBenchmarkSuite(): Unit = { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { runBenchmark("unsafe projection") { val iters = 1024 * 16 val numRows = 1024 * 16 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala index e3df449b41f0a..dba906f63aed4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -256,7 +256,7 @@ object DatasetBenchmark extends SqlBasedBenchmark { .getOrCreate() } - override def runBenchmarkSuite(): Unit = { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { val numRows = 100000000 val numChains = 10 runBenchmark("Dataset Benchmark") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index 86e0df2fea350..b7d28988274bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -44,7 +44,7 @@ import org.apache.spark.unsafe.map.BytesToBytesMap */ object AggregateBenchmark extends SqlBasedBenchmark { - override def runBenchmarkSuite(): Unit = { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { runBenchmark("aggregate without grouping") { val N = 500L << 22 codegenBenchmark("agg w/o group", N) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BloomFilterBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BloomFilterBenchmark.scala index 2f3caca849cdf..f727ebcf3fd1e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BloomFilterBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BloomFilterBenchmark.scala @@ -80,7 +80,7 @@ object BloomFilterBenchmark extends SqlBasedBenchmark { } } - override def runBenchmarkSuite(): Unit = { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { writeBenchmark() readBenchmark() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala index a1e7f9e36f4b0..a1f51f8e54805 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala @@ -585,7 +585,7 @@ object DataSourceReadBenchmark extends BenchmarkBase with SQLHelper { } } - override def runBenchmarkSuite(): Unit = { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { runBenchmark("SQL Single Numeric Column Scan") { Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { dataType => numericScanBenchmark(1024 * 1024 * 15, dataType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala index cf05ca3361711..017b74aabff70 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala @@ -198,7 +198,7 @@ object FilterPushdownBenchmark extends BenchmarkBase with SQLHelper { } } - override def runBenchmarkSuite(): Unit = { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { runBenchmark("Pushdown for many distinct value case") { withTempPath { dir => withTempTable("orcTable", "parquetTable") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala index 7bad4cb927b42..ad81711a13947 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala @@ -164,7 +164,7 @@ object JoinBenchmark extends SqlBasedBenchmark { } } - override def runBenchmarkSuite(): Unit = { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { runBenchmark("Join Benchmark") { broadcastHashJoinLongKey() broadcastHashJoinLongKeyWithDuplicates() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala index 43380869fefe4..c4662c8999e42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala @@ -148,7 +148,7 @@ object MiscBenchmark extends SqlBasedBenchmark { } } - override def runBenchmarkSuite(): Unit = { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { filterAndAggregateWithoutGroup(500L << 22) limitAndAggregateWithoutGroup(500L << 20) sample(500 << 18) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala index 83edf73abfae5..8b1c422e63a3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala @@ -36,7 +36,7 @@ object PrimitiveArrayBenchmark extends BenchmarkBase { .config("spark.sql.autoBroadcastJoinThreshold", 1) .getOrCreate() - override def runBenchmarkSuite(): Unit = { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { runBenchmark("Write primitive arrays in dataset") { writeDatasetArray(4) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/RangeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/RangeBenchmark.scala index a844e02dcba30..a9f873f9094ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/RangeBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/RangeBenchmark.scala @@ -32,7 +32,7 @@ import org.apache.spark.benchmark.Benchmark */ object RangeBenchmark extends SqlBasedBenchmark { - override def runBenchmarkSuite(): Unit = { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { import spark.implicits._ runBenchmark("range") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala index 9a54e2320b80f..784438cd43ebe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala @@ -119,7 +119,7 @@ object SortBenchmark extends BenchmarkBase { benchmark.run() } - override def runBenchmarkSuite(): Unit = { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { runBenchmark("radix sort") { sortBenchmark() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala index 79eaeab9c399f..f582d844cdc47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala @@ -194,7 +194,7 @@ object UnsafeArrayDataBenchmark extends BenchmarkBase { benchmark.run } - override def runBenchmarkSuite(): Unit = { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { runBenchmark("Benchmark UnsafeArrayData") { readUnsafeArray(10) writeUnsafeArray(10) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala index 124661986ca0b..f4642e7d353e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala @@ -208,7 +208,7 @@ object WideSchemaBenchmark extends SqlBasedBenchmark { deleteTmpFiles() } - override def runBenchmarkSuite(): Unit = { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { runBenchmarkWithDeleteTmpFiles("parsing large select expressions") { parsingLargeSelectExpressions() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala index 0f9079744a220..8ea20f28a37b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala @@ -233,7 +233,7 @@ object CompressionSchemeBenchmark extends BenchmarkBase with AllCompressionSchem runDecodeBenchmark("STRING Decode", iters, count, STRING, testData) } - override def runBenchmarkSuite(): Unit = { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { runBenchmark("Compression Scheme Benchmark") { bitEncodingBenchmark(1024) shortEncodingBenchmark(1024) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index f311465e582ac..953b3a67d976f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -443,7 +443,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { benchmark.run } - override def runBenchmarkSuite(): Unit = { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { runBenchmark("Int Read/Write") { intAccess(1024 * 40) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala index 50ee09678e2cb..3226e3a5f318a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala @@ -212,7 +212,7 @@ object ObjectHashAggregateExecBenchmark extends BenchmarkBase with SQLHelper { Column(approxPercentile.toAggregateExpression(isDistinct)) } - override def runBenchmarkSuite(): Unit = { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { runBenchmark("Hive UDAF vs Spark AF") { hiveUDAFvsSparkAF(2 << 15) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala index 870ad4818eb28..ec13288f759a6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala @@ -336,7 +336,7 @@ object OrcReadBenchmark extends BenchmarkBase with SQLHelper { } } - override def runBenchmarkSuite(): Unit = { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { runBenchmark("SQL Single Numeric Column Scan") { Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { dataType => numericScanBenchmark(1024 * 1024 * 15, dataType) From 7fe5cff0581ca9d8221533215098f40f69362018 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 29 Oct 2018 16:47:50 +0100 Subject: [PATCH 1948/2461] [SPARK-25767][SQL] Fix lazily evaluated stream of expressions in code generation ## What changes were proposed in this pull request? Code generation is incorrect if `outputVars` parameter of `consume` method in `CodegenSupport` contains a lazily evaluated stream of expressions. This PR fixes the issue by forcing the evaluation of `inputVars` before generating the code for UnsafeRow. ## How was this patch tested? Tested with the sample program provided in https://issues.apache.org/jira/browse/SPARK-25767 Closes #22789 from peter-toth/SPARK-25767. Authored-by: Peter Toth Signed-off-by: Herman van Hovell --- .../spark/sql/execution/WholeStageCodegenExec.scala | 5 ++++- .../spark/sql/execution/WholeStageCodegenSuite.scala | 11 +++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index f5aee627fe901..5f81b6fe743c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -146,7 +146,10 @@ trait CodegenSupport extends SparkPlan { if (outputVars != null) { assert(outputVars.length == output.length) // outputVars will be used to generate the code for UnsafeRow, so we should copy them - outputVars.map(_.copy()) + outputVars.map(_.copy()) match { + case stream: Stream[ExprCode] => stream.force + case other => other + } } else { assert(row != null, "outputVars and row cannot both be null.") ctx.currentVars = null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index b714dcd5269fc..09ad0fdd66369 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -319,4 +319,15 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { assert(df.limit(1).collect() === Array(Row("bat", 8.0))) } } + + test("SPARK-25767: Lazy evaluated stream of expressions handled correctly") { + val a = Seq(1).toDF("key") + val b = Seq((1, "a")).toDF("key", "value") + val c = Seq(1).toDF("key") + + val ab = a.join(b, Stream("key"), "left") + val abc = ab.join(c, Seq("key"), "left") + + checkAnswer(abc, Row(1, "a")) + } } From 5e5d886a2bc291a707cf4a6c70ecc6de6f8e990d Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 29 Oct 2018 12:56:06 -0500 Subject: [PATCH 1949/2461] [SPARK-25856][SQL][MINOR] Remove AverageLike and CountLike classes ## What changes were proposed in this pull request? These two classes were added for regr_ expression support (SPARK-23907). These have been removed and hence we can remove these base classes and inline the logic in the concrete classes. ## How was this patch tested? Existing tests. Closes #22856 from dilipbiswal/average_cleanup. Authored-by: Dilip Biswal Signed-off-by: Sean Owen --- .../expressions/aggregate/Average.scala | 33 ++++++++----------- .../expressions/aggregate/Count.scala | 28 +++++++--------- 2 files changed, 25 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 5ecb77be5965e..8dd80dc06ab2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -23,9 +23,21 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -abstract class AverageLike(child: Expression) extends DeclarativeAggregate { +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.") +case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { + + override def prettyName: String = "avg" + + override def children: Seq[Expression] = child :: Nil + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function average") override def nullable: Boolean = true + // Return data type. override def dataType: DataType = resultType @@ -63,28 +75,11 @@ abstract class AverageLike(child: Expression) extends DeclarativeAggregate { sum.cast(resultType) / count.cast(resultType) } - protected def updateExpressionsDef: Seq[Expression] = Seq( + override lazy val updateExpressions: Seq[Expression] = Seq( /* sum = */ Add( sum, coalesce(child.cast(sumDataType), Literal(0).cast(sumDataType))), /* count = */ If(child.isNull, count, count + 1L) ) - - override lazy val updateExpressions = updateExpressionsDef -} - -@ExpressionDescription( - usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.") -case class Average(child: Expression) - extends AverageLike(child) with ImplicitCastInputTypes { - - override def prettyName: String = "avg" - - override def children: Seq[Expression] = child :: Nil - - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function average") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 8cab8e4856997..d402f2d592b44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -21,10 +21,17 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -/** - * Base class for all counting aggregators. - */ -abstract class CountLike extends DeclarativeAggregate { +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(*) - Returns the total number of retrieved rows, including rows containing null. + + _FUNC_(expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are all non-null. + + _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-null. + """) +// scalastyle:on line.size.limit +case class Count(children: Seq[Expression]) extends DeclarativeAggregate { override def nullable: Boolean = false // Return data type. @@ -45,19 +52,6 @@ abstract class CountLike extends DeclarativeAggregate { override lazy val evaluateExpression = count override def defaultResult: Option[Literal] = Option(Literal(0L)) -} - -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = """ - _FUNC_(*) - Returns the total number of retrieved rows, including rows containing null. - - _FUNC_(expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are all non-null. - - _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-null. - """) -// scalastyle:on line.size.limit -case class Count(children: Seq[Expression]) extends CountLike { override lazy val updateExpressions = { val nullableChildren = children.filter(_.nullable) From 5bd5e1b9c84b5f7d4d67ab94e02d49ebdd02f177 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 30 Oct 2018 07:38:26 +0800 Subject: [PATCH 1950/2461] [MINOR][SQL] Avoid hardcoded configuration keys in SQLConf's `doc` ## What changes were proposed in this pull request? This PR proposes to avoid hardcorded configuration keys in SQLConf's `doc. ## How was this patch tested? Manually verified. Closes #22877 from HyukjinKwon/minor-conf-name. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- .../apache/spark/sql/internal/SQLConf.scala | 41 +++++++++++-------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4edffce120aac..535ec51e315d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -408,7 +408,8 @@ object SQLConf { val PARQUET_FILTER_PUSHDOWN_DATE_ENABLED = buildConf("spark.sql.parquet.filterPushdown.date") .doc("If true, enables Parquet filter push-down optimization for Date. " + - "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + s"This configuration only has an effect when '${PARQUET_FILTER_PUSHDOWN_ENABLED.key}' is " + + "enabled.") .internal() .booleanConf .createWithDefault(true) @@ -416,7 +417,7 @@ object SQLConf { val PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED = buildConf("spark.sql.parquet.filterPushdown.timestamp") .doc("If true, enables Parquet filter push-down optimization for Timestamp. " + - "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is " + + s"This configuration only has an effect when '${PARQUET_FILTER_PUSHDOWN_ENABLED.key}' is " + "enabled and Timestamp stored as TIMESTAMP_MICROS or TIMESTAMP_MILLIS type.") .internal() .booleanConf @@ -425,7 +426,8 @@ object SQLConf { val PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED = buildConf("spark.sql.parquet.filterPushdown.decimal") .doc("If true, enables Parquet filter push-down optimization for Decimal. " + - "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + s"This configuration only has an effect when '${PARQUET_FILTER_PUSHDOWN_ENABLED.key}' is " + + "enabled.") .internal() .booleanConf .createWithDefault(true) @@ -433,7 +435,8 @@ object SQLConf { val PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED = buildConf("spark.sql.parquet.filterPushdown.string.startsWith") .doc("If true, enables Parquet filter push-down optimization for string startsWith function. " + - "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + s"This configuration only has an effect when '${PARQUET_FILTER_PUSHDOWN_ENABLED.key}' is " + + "enabled.") .internal() .booleanConf .createWithDefault(true) @@ -444,7 +447,8 @@ object SQLConf { "Large threshold won't necessarily provide much better performance. " + "The experiment argued that 300 is the limit threshold. " + "By setting this value to 0 this feature can be disabled. " + - "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + s"This configuration only has an effect when '${PARQUET_FILTER_PUSHDOWN_ENABLED.key}' is " + + "enabled.") .internal() .intConf .checkValue(threshold => threshold >= 0, "The threshold must not be negative.") @@ -459,14 +463,6 @@ object SQLConf { .booleanConf .createWithDefault(false) - val PARQUET_RECORD_FILTER_ENABLED = buildConf("spark.sql.parquet.recordLevelFilter.enabled") - .doc("If true, enables Parquet's native record-level filtering using the pushed down " + - "filters. This configuration only has an effect when 'spark.sql.parquet.filterPushdown' " + - "is enabled and the vectorized reader is not used. You can ensure the vectorized reader " + - "is not used by setting 'spark.sql.parquet.enableVectorizedReader' to false.") - .booleanConf - .createWithDefault(false) - val PARQUET_OUTPUT_COMMITTER_CLASS = buildConf("spark.sql.parquet.output.committer.class") .doc("The output committer class used by Parquet. The specified class needs to be a " + "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + @@ -482,6 +478,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_RECORD_FILTER_ENABLED = buildConf("spark.sql.parquet.recordLevelFilter.enabled") + .doc("If true, enables Parquet's native record-level filtering using the pushed down " + + "filters. " + + s"This configuration only has an effect when '${PARQUET_FILTER_PUSHDOWN_ENABLED.key}' " + + "is enabled and the vectorized reader is not used. You can ensure the vectorized reader " + + s"is not used by setting '${PARQUET_VECTORIZED_READER_ENABLED.key}' to false.") + .booleanConf + .createWithDefault(false) + val PARQUET_VECTORIZED_READER_BATCH_SIZE = buildConf("spark.sql.parquet.columnarReaderBatchSize") .doc("The number of rows to include in a parquet vectorized reader batch. The number should " + "be carefully chosen to minimize overhead and avoid OOMs in reading data.") @@ -642,7 +647,7 @@ object SQLConf { .internal() .doc("When true, a table created by a Hive CTAS statement (no USING clause) " + "without specifying any storage property will be converted to a data source table, " + - "using the data source set by spark.sql.sources.default.") + s"using the data source set by ${DEFAULT_DATA_SOURCE_NAME.key}.") .booleanConf .createWithDefault(false) @@ -1108,7 +1113,7 @@ object SQLConf { val DEFAULT_SIZE_IN_BYTES = buildConf("spark.sql.defaultSizeInBytes") .internal() .doc("The default table size used in query planning. By default, it is set to Long.MaxValue " + - "which is larger than `spark.sql.autoBroadcastJoinThreshold` to be more conservative. " + + s"which is larger than `${AUTO_BROADCASTJOIN_THRESHOLD.key}` to be more conservative. " + "That is to say by default the optimizer will not choose to broadcast a table unless it " + "knows for sure its size is small enough.") .longConf @@ -1279,7 +1284,7 @@ object SQLConf { val ARROW_FALLBACK_ENABLED = buildConf("spark.sql.execution.arrow.fallback.enabled") - .doc("When true, optimizations enabled by 'spark.sql.execution.arrow.enabled' will " + + .doc(s"When true, optimizations enabled by '${ARROW_EXECUTION_ENABLED.key}' will " + "fallback automatically to non-optimized implementations if an error occurs.") .booleanConf .createWithDefault(true) @@ -1492,7 +1497,7 @@ object SQLConf { val REPL_EAGER_EVAL_MAX_NUM_ROWS = buildConf("spark.sql.repl.eagerEval.maxNumRows") .doc("The max number of rows that are returned by eager evaluation. This only takes " + - "effect when spark.sql.repl.eagerEval.enabled is set to true. The valid range of this " + + s"effect when ${REPL_EAGER_EVAL_ENABLED.key} is set to true. The valid range of this " + "config is from 0 to (Int.MaxValue - 1), so the invalid config like negative and " + "greater than (Int.MaxValue - 1) will be normalized to 0 and (Int.MaxValue - 1).") .intConf @@ -1500,7 +1505,7 @@ object SQLConf { val REPL_EAGER_EVAL_TRUNCATE = buildConf("spark.sql.repl.eagerEval.truncate") .doc("The max number of characters for each cell that is returned by eager evaluation. " + - "This only takes effect when spark.sql.repl.eagerEval.enabled is set to true.") + s"This only takes effect when ${REPL_EAGER_EVAL_ENABLED.key} is set to true.") .intConf .createWithDefault(20) From eab39f79e4c2fb51266ff5844114ee56b8ec2d91 Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Tue, 30 Oct 2018 20:13:18 +0800 Subject: [PATCH 1951/2461] [SPARK-25755][SQL][TEST] Supplementation of non-CodeGen unit tested for BroadcastHashJoinExec MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Currently, the BroadcastHashJoinExec physical plan supports CodeGen and non-codegen, but only CodeGen code is tested in the unit tests of InnerJoinSuite、OuterJoinSuite、ExistenceJoinSuite, and non-codegen code is not tested. This PR supplements this part of the test. ## How was this patch tested? add new unit tested. Closes #22755 from heary-cao/AddTestToBroadcastHashJoinExec. Authored-by: caoxuewen Signed-off-by: Wenchen Fan --- .../spark/sql/DataFrameAggregateSuite.scala | 30 ++++---- .../spark/sql/DataFrameFunctionsSuite.scala | 15 ++-- .../spark/sql/DataFrameRangeSuite.scala | 76 ++++++++----------- .../columnar/InMemoryColumnarQuerySuite.scala | 39 +++++----- .../execution/joins/ExistenceJoinSuite.scala | 2 +- .../sql/execution/joins/InnerJoinSuite.scala | 6 +- .../sql/execution/joins/OuterJoinSuite.scala | 2 +- .../apache/spark/sql/test/SQLTestUtils.scala | 15 ++++ 8 files changed, 90 insertions(+), 95 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d0106c44b7db2..d9ba6e2ce5120 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -669,23 +669,19 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } } - Seq(true, false).foreach { codegen => - test("SPARK-22951: dropDuplicates on empty dataFrames should produce correct aggregate " + - s"results when codegen is enabled: $codegen") { - withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, codegen.toString)) { - // explicit global aggregations - val emptyAgg = Map.empty[String, String] - checkAnswer(spark.emptyDataFrame.agg(emptyAgg), Seq(Row())) - checkAnswer(spark.emptyDataFrame.groupBy().agg(emptyAgg), Seq(Row())) - checkAnswer(spark.emptyDataFrame.groupBy().agg(count("*")), Seq(Row(0))) - checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(emptyAgg), Seq(Row())) - checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(emptyAgg), Seq(Row())) - checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(count("*")), Seq(Row(0))) - - // global aggregation is converted to grouping aggregation: - assert(spark.emptyDataFrame.dropDuplicates().count() == 0) - } - } + testWithWholeStageCodegenOnAndOff("SPARK-22951: dropDuplicates on empty dataFrames " + + "should produce correct aggregate") { _ => + // explicit global aggregations + val emptyAgg = Map.empty[String, String] + checkAnswer(spark.emptyDataFrame.agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.groupBy().agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.groupBy().agg(count("*")), Seq(Row(0))) + checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(count("*")), Seq(Row(0))) + + // global aggregation is converted to grouping aggregation: + assert(spark.emptyDataFrame.dropDuplicates().count() == 0) } test("SPARK-21896: Window functions inside aggregate functions") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 60ebc5e6cc09b..666ba35d7a8f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -458,15 +458,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df8.selectExpr("arrays_zip(v1, v2)"), expectedValue8) } - test("SPARK-24633: arrays_zip splits input processing correctly") { - Seq("true", "false").foreach { wholestageCodegenEnabled => - withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholestageCodegenEnabled) { - val df = spark.range(1) - val exprs = (0 to 5).map(x => array($"id" + lit(x))) - checkAnswer(df.select(arrays_zip(exprs: _*)), - Row(Seq(Row(0, 1, 2, 3, 4, 5)))) - } - } + testWithWholeStageCodegenOnAndOff("SPARK-24633: arrays_zip splits input " + + "processing correctly") { _ => + val df = spark.range(1) + val exprs = (0 to 5).map(x => array($"id" + lit(x))) + checkAnswer(df.select(arrays_zip(exprs: _*)), + Row(Seq(Row(0, 1, 2, 3, 4, 5)))) } def testSizeOfMap(sizeOfNull: Any): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index b0b46640ff317..8cc7020579431 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.test.SharedSQLContext class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventually { - import testImplicits._ test("SPARK-7150 range api") { // numSlice is greater than length @@ -107,7 +106,7 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall assert(res17.collect === (1 to 10).map(i => Row(i)).toArray) } - test("Range with randomized parameters") { + testWithWholeStageCodegenOnAndOff("Range with randomized parameters") { codegenEnabled => val MAX_NUM_STEPS = 10L * 1000 val seed = System.currentTimeMillis() @@ -133,25 +132,21 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall val expCount = (start until end by step).size val expSum = (start until end by step).sum - for (codegen <- List(false, true)) { - withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { - val res = spark.range(start, end, step, partitions).toDF("id"). - agg(count("id"), sum("id")).collect() - - withClue(s"seed = $seed start = $start end = $end step = $step partitions = " + - s"$partitions codegen = $codegen") { - assert(!res.isEmpty) - assert(res.head.getLong(0) == expCount) - if (expCount > 0) { - assert(res.head.getLong(1) == expSum) - } - } + val res = spark.range(start, end, step, partitions).toDF("id"). + agg(count("id"), sum("id")).collect() + + withClue(s"seed = $seed start = $start end = $end step = $step partitions = " + + s"$partitions codegen = $codegenEnabled") { + assert(!res.isEmpty) + assert(res.head.getLong(0) == expCount) + if (expCount > 0) { + assert(res.head.getLong(1) == expSum) } } } } - test("Cancelling stage in a query with Range.") { + testWithWholeStageCodegenOnAndOff("Cancelling stage in a query with Range.") { _ => val listener = new SparkListener { override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { sparkContext.cancelStage(taskStart.stageId) @@ -159,27 +154,25 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } sparkContext.addSparkListener(listener) - for (codegen <- Seq(true, false)) { - withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { - val ex = intercept[SparkException] { - spark.range(0, 100000000000L, 1, 1) - .toDF("id").agg(sum("id")).collect() - } - ex.getCause() match { - case null => - assert(ex.getMessage().contains("cancelled")) - case cause: SparkException => - assert(cause.getMessage().contains("cancelled")) - case cause: Throwable => - fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") - } - } - // Wait until all ListenerBus events consumed to make sure cancelStage called for all stages - sparkContext.listenerBus.waitUntilEmpty(20.seconds.toMillis) - eventually(timeout(20.seconds)) { - assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) - } + val ex = intercept[SparkException] { + spark.range(0, 100000000000L, 1, 1) + .toDF("id").agg(sum("id")).collect() + } + ex.getCause() match { + case null => + assert(ex.getMessage().contains("cancelled")) + case cause: SparkException => + assert(cause.getMessage().contains("cancelled")) + case cause: Throwable => + fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") } + + // Wait until all ListenerBus events consumed to make sure cancelStage called for all stages + sparkContext.listenerBus.waitUntilEmpty(20.seconds.toMillis) + eventually(timeout(20.seconds)) { + assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } + sparkContext.removeSparkListener(listener) } @@ -189,14 +182,11 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } } - test("SPARK-21041 SparkSession.range()'s behavior is inconsistent with SparkContext.range()") { + testWithWholeStageCodegenOnAndOff("SPARK-21041 SparkSession.range()'s behavior is " + + "inconsistent with SparkContext.range()") { _ => val start = java.lang.Long.MAX_VALUE - 3 val end = java.lang.Long.MIN_VALUE + 2 - Seq("false", "true").foreach { value => - withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value) { - assert(spark.range(start, end, 1).collect.length == 0) - assert(spark.range(start, start, 1).collect.length == 0) - } - } + assert(spark.range(start, end, 1).collect.length == 0) + assert(spark.range(start, start, 1).collect.length == 0) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index b1b23e4439878..e1567d06e23eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -463,29 +463,26 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { assert(tableScanExec.partitionFilters.isEmpty) } - test("SPARK-22348: table cache should do partition batch pruning") { - Seq("true", "false").foreach { enabled => - withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> enabled) { - val df1 = Seq((1, 1), (1, 1), (2, 2)).toDF("x", "y") - df1.unpersist() - df1.cache() - - // Push predicate to the cached table. - val df2 = df1.where("y = 3") - - val planBeforeFilter = df2.queryExecution.executedPlan.collect { - case f: FilterExec => f.child - } - assert(planBeforeFilter.head.isInstanceOf[InMemoryTableScanExec]) + testWithWholeStageCodegenOnAndOff("SPARK-22348: table cache " + + "should do partition batch pruning") { codegenEnabled => + val df1 = Seq((1, 1), (1, 1), (2, 2)).toDF("x", "y") + df1.unpersist() + df1.cache() - val execPlan = if (enabled == "true") { - WholeStageCodegenExec(planBeforeFilter.head)(codegenStageId = 0) - } else { - planBeforeFilter.head - } - assert(execPlan.executeCollectPublic().length == 0) - } + // Push predicate to the cached table. + val df2 = df1.where("y = 3") + + val planBeforeFilter = df2.queryExecution.executedPlan.collect { + case f: FilterExec => f.child + } + assert(planBeforeFilter.head.isInstanceOf[InMemoryTableScanExec]) + + val execPlan = if (codegenEnabled == "true") { + WholeStageCodegenExec(planBeforeFilter.head)(codegenStageId = 0) + } else { + planBeforeFilter.head } + assert(execPlan.executeCollectPublic().length == 0) } test("SPARK-25727 - otherCopyArgs in InMemoryRelation does not include outputOrdering") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index 38377164c10e6..22279a3a43eff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -120,7 +120,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { } } - test(s"$testName using BroadcastHashJoin") { + testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin") { _ => extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 4408ece112258..f5edd6bbd5e69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -127,7 +127,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin) } - test(s"$testName using BroadcastHashJoin (build=left)") { + testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin (build=left)") { _ => extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => @@ -139,7 +139,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } } - test(s"$testName using BroadcastHashJoin (build=right)") { + testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin (build=right)") { _ => extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => @@ -175,7 +175,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } } - test(s"$testName using SortMergeJoin") { + testWithWholeStageCodegenOnAndOff(s"$testName using SortMergeJoin") { _ => extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 001feb0f2b399..513248dae48be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -93,7 +93,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { } if (joinType != FullOuter) { - test(s"$testName using BroadcastHashJoin") { + testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin") { _ => val buildSide = joinType match { case LeftOuter => BuildRight case RightOuter => BuildLeft diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 6b03d1e5b7662..23419493e5368 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.FilterExec +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.UninterruptibleThread import org.apache.spark.util.Utils @@ -65,6 +66,20 @@ private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with } } + /** + * A helper function for turning off/on codegen. + */ + protected def testWithWholeStageCodegenOnAndOff(testName: String)(f: String => Unit): Unit = { + Seq("false", "true").foreach { codegenEnabled => + val isTurnOn = if (codegenEnabled == "true") "on" else "off" + test(s"$testName (whole-stage-codegen ${isTurnOn})") { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled) { + f(codegenEnabled) + } + } + } + } + /** * Materialize the test data immediately after the `SQLContext` is set up. * This is necessary if the data is accessed by name but not through direct reference. From 327456b482dec38a19bdc65061b3c2271f86819a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 30 Oct 2018 21:17:40 +0800 Subject: [PATCH 1952/2461] [BUILD][MINOR] release script should not interrupt by svn ## What changes were proposed in this pull request? When running the release script, you will be interrupted unexpectedly ``` ATTENTION! Your password for authentication realm: ASF Committers can only be stored to disk unencrypted! You are advised to configure your system so that Subversion can store passwords encrypted, if possible. See the documentation for details. You can avoid future appearances of this warning by setting the value of the 'store-plaintext-passwords' option to either 'yes' or 'no' in '/home/spark-rm/.subversion/servers'. ----------------------------------------------------------------------- Store password unencrypted (yes/no)? ``` We can avoid it by adding `--no-auth-cache` when running svn command. ## How was this patch tested? manually verified with 2.4.0 RC5 Closes #22885 from cloud-fan/svn. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- dev/create-release/release-build.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 26e08684cc6de..2fdb5c8dd38a1 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -326,7 +326,7 @@ if [[ "$1" == "package" ]]; then svn add "svn-spark/${DEST_DIR_NAME}-bin" cd svn-spark - svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION" + svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION" --no-auth-cache cd .. rm -rf svn-spark fi @@ -354,7 +354,7 @@ if [[ "$1" == "docs" ]]; then svn add "svn-spark/${DEST_DIR_NAME}-docs" cd svn-spark - svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION docs" + svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION docs" --no-auth-cache cd .. rm -rf svn-spark fi From ce40efa200e6cb6f10a289e2ab00f711b0ebd379 Mon Sep 17 00:00:00 2001 From: Shahid Date: Tue, 30 Oct 2018 08:39:30 -0500 Subject: [PATCH 1953/2461] [SPARK-25790][MLLIB] PCA: Support more than 65535 column matrix ## What changes were proposed in this pull request? Spark PCA supports maximum only ~65,535 columns matrix. This is due to the fact that, it computes the Covariance matrix first, then compute principle components. The main bottle neck was computing **covariance matrix.** The limit 65,500 came due to the integer size limit. Because we are passing an array of size n*(n+1)/2 to the breeze library and the size cannot be more than INT_MAX. so, the maximum column size we can give is 65,500. Currently we don't have such limitation for computing SVD in spark. So, we can make use of Spark SVD to compute the PCA, if the number of columns exceeds the limit. Computation of PCA can be done directly using SVD of matrix, instead of finding the covariance matrix. Following are the papers/links for the reference. https://arxiv.org/pdf/1404.1100.pdf https://en.wikipedia.org/wiki/Principal_component_analysis#Singular_value_decomposition http://www.ifis.uni-luebeck.de/~moeller/Lectures/WS-16-17/Web-Mining-Agents/PCA-SVD.pdf ## How was this patch tested? added UT, also manually verified with the existing test for pca, by removing the limit condition in the fit method. Closes #22784 from shahidki31/PCA. Authored-by: Shahid Signed-off-by: Sean Owen --- .../org/apache/spark/mllib/feature/PCA.scala | 20 ++++++++--- .../mllib/linalg/distributed/RowMatrix.scala | 33 ++++++++++++------- .../apache/spark/mllib/feature/PCASuite.scala | 23 ++++++++++++- 3 files changed, 58 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala index a01503f4b80a6..2fc517cad12db 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala @@ -21,6 +21,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.mllib.stat.Statistics import org.apache.spark.rdd.RDD /** @@ -44,12 +45,21 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) { require(k <= numFeatures, s"source vector size $numFeatures must be no less than k=$k") - require(PCAUtil.memoryCost(k, numFeatures) < Int.MaxValue, - "The param k and numFeatures is too large for SVD computation. " + - "Try reducing the parameter k for PCA, or reduce the input feature " + - "vector dimension to make this tractable.") + val mat = if (numFeatures > 65535) { + val meanVector = Statistics.colStats(sources).mean.asBreeze + val meanCentredRdd = sources.map { rowVector => + Vectors.fromBreeze(rowVector.asBreeze - meanVector) + } + new RowMatrix(meanCentredRdd) + } else { + require(PCAUtil.memoryCost(k, numFeatures) < Int.MaxValue, + "The param k and numFeatures is too large for SVD computation. " + + "Try reducing the parameter k for PCA, or reduce the input feature " + + "vector dimension to make this tractable.") + + new RowMatrix(sources) + } - val mat = new RowMatrix(sources) val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(k) val densePC = pc match { case dm: DenseMatrix => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 78a8810052aef..82ab716ed96a8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -370,32 +370,41 @@ class RowMatrix @Since("1.0.0") ( * Each column corresponds for one principal component, * and the columns are in descending order of component variance. * The row data do not need to be "centered" first; it is not necessary for - * the mean of each column to be 0. + * the mean of each column to be 0. But, if the number of columns are more than + * 65535, then the data need to be "centered". * * @param k number of top principal components. * @return a matrix of size n-by-k, whose columns are principal components, and * a vector of values which indicate how much variance each principal component * explains - * - * @note This cannot be computed on matrices with more than 65535 columns. */ @Since("1.6.0") def computePrincipalComponentsAndExplainedVariance(k: Int): (Matrix, Vector) = { val n = numCols().toInt require(k > 0 && k <= n, s"k = $k out of range (0, n = $n]") - val Cov = computeCovariance().asBreeze.asInstanceOf[BDM[Double]] + if (n > 65535) { + val svd = computeSVD(k) + val s = svd.s.toArray.map(eigValue => eigValue * eigValue / (n - 1)) + val eigenSum = s.sum + val explainedVariance = s.map(_ / eigenSum) - val brzSvd.SVD(u: BDM[Double], s: BDV[Double], _) = brzSvd(Cov) + (svd.V, Vectors.dense(explainedVariance)) + } else { - val eigenSum = s.data.sum - val explainedVariance = s.data.map(_ / eigenSum) + val Cov = computeCovariance().asBreeze.asInstanceOf[BDM[Double]] - if (k == n) { - (Matrices.dense(n, k, u.data), Vectors.dense(explainedVariance)) - } else { - (Matrices.dense(n, k, Arrays.copyOfRange(u.data, 0, n * k)), - Vectors.dense(Arrays.copyOfRange(explainedVariance, 0, k))) + val brzSvd.SVD(u: BDM[Double], s: BDV[Double], _) = brzSvd(Cov) + + val eigenSum = s.data.sum + val explainedVariance = s.data.map(_ / eigenSum) + + if (k == n) { + (Matrices.dense(n, k, u.data), Vectors.dense(explainedVariance)) + } else { + (Matrices.dense(n, k, Arrays.copyOfRange(u.data, 0, n * k)), + Vectors.dense(Arrays.copyOfRange(explainedVariance, 0, k))) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala index 8eab12416a698..fe49162c66426 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.feature import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -54,4 +54,25 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext { // check overflowing assert(PCAUtil.memoryCost(40000, 60000) > Int.MaxValue) } + + test("number of features more than 65535") { + val data1 = sc.parallelize(Array( + Vectors.dense((1 to 100000).map(_ => 2.0).to[scala.Vector].toArray), + Vectors.dense((1 to 100000).map(_ => 0.0).to[scala.Vector].toArray) + ), 2) + + val pca = new PCA(2).fit(data1) + // Eigen values should not be negative + assert(pca.explainedVariance.values.forall(_ >= 0)) + // Norm of the principal component should be 1.0 + assert(Math.sqrt(pca.pc.values.slice(0, 100000) + .map(Math.pow(_, 2)).sum) ~== 1.0 relTol 1e-8) + // Leading explainedVariance is 1.0 + assert(pca.explainedVariance(0) ~== 1.0 relTol 1e-12) + + // Leading principal component is '1' vector + val firstValue = pca.pc.values(0) + pca.pc.values.slice(0, 100000).map(values => + assert(values ~== firstValue relTol 1e-12)) + } } From a129f079557204e3694754a5f9184c7f178cdf2a Mon Sep 17 00:00:00 2001 From: LantaoJin Date: Tue, 30 Oct 2018 11:01:55 -0500 Subject: [PATCH 1954/2461] [SPARK-23429][CORE][FOLLOWUP] MetricGetter should rename to ExecutorMetricType in comments ## What changes were proposed in this pull request? MetricGetter should rename to ExecutorMetricType in comments. ## How was this patch tested? Just comments, no need to test. Closes #22884 from LantaoJin/SPARK-23429_FOLLOWUP. Authored-by: LantaoJin Signed-off-by: Imran Rashid --- core/src/main/scala/org/apache/spark/Heartbeater.scala | 2 +- .../main/scala/org/apache/spark/executor/ExecutorMetrics.scala | 2 +- .../main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 2 +- .../main/scala/org/apache/spark/scheduler/SparkListener.scala | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Heartbeater.scala b/core/src/main/scala/org/apache/spark/Heartbeater.scala index 5ba1b9b2d828e..84091eef04306 100644 --- a/core/src/main/scala/org/apache/spark/Heartbeater.scala +++ b/core/src/main/scala/org/apache/spark/Heartbeater.scala @@ -61,7 +61,7 @@ private[spark] class Heartbeater( /** * Get the current executor level metrics. These are returned as an array, with the index - * determined by MetricGetter.values + * determined by ExecutorMetricType.values */ def getCurrentMetrics(): ExecutorMetrics = { val metrics = ExecutorMetricType.values.map(_.getMetricValue(memoryManager)).toArray diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorMetrics.scala index 2933f3ba6d3b5..1befd27de1cba 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorMetrics.scala @@ -28,7 +28,7 @@ import org.apache.spark.metrics.ExecutorMetricType @DeveloperApi class ExecutorMetrics private[spark] extends Serializable { - // Metrics are indexed by MetricGetter.values + // Metrics are indexed by ExecutorMetricType.values private val metrics = new Array[Long](ExecutorMetricType.values.length) // the first element is initialized to -1, indicating that the values for the array diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index f93d8a8d5de55..34b1160dbbfc3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -265,7 +265,7 @@ private[spark] class DAGScheduler( // (taskId, stageId, stageAttemptId, accumUpdates) accumUpdates: Array[(Long, Int, Int, Seq[AccumulableInfo])], blockManagerId: BlockManagerId, - // executor metrics indexed by MetricGetter.values + // executor metrics indexed by ExecutorMetricType.values executorUpdates: ExecutorMetrics): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates, Some(executorUpdates))) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 293e8369677f0..e92b8a2718df0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -175,7 +175,7 @@ case class SparkListenerExecutorMetricsUpdate( * @param execId executor id * @param stageId stage id * @param stageAttemptId stage attempt - * @param executorMetrics executor level metrics, indexed by MetricGetter.values + * @param executorMetrics executor level metrics, indexed by ExecutorMetricType.values */ @DeveloperApi case class SparkListenerStageExecutorMetrics( From 94de5609be27e2618d6d241ec9aa032fbc601b6e Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Tue, 30 Oct 2018 09:18:55 -0700 Subject: [PATCH 1955/2461] [SPARK-25848][SQL][TEST] Refactor CSVBenchmarks to use main method ## What changes were proposed in this pull request? use spark-submit: `bin/spark-submit --class org.apache.spark.sql.execution.datasources.csv.CSVBenchmark --jars ./core/target/spark-core_2.11-3.0.0-SNAPSHOT-tests.jar,./sql/catalyst/target/spark-catalyst_2.11-3.0.0-SNAPSHOT-tests.jar ./sql/core/target/spark-sql_2.11-3.0.0-SNAPSHOT-tests.jar` Generate benchmark result: `SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain org.apache.spark.sql.execution.datasources.csv.CSVBenchmark"` ## How was this patch tested? manual tests Closes #22845 from heary-cao/CSVBenchmarks. Authored-by: caoxuewen Signed-off-by: Dongjoon Hyun --- sql/core/benchmarks/CSVBenchmark-results.txt | 27 +++++++ ...CSVBenchmarks.scala => CSVBenchmark.scala} | 70 +++++++------------ 2 files changed, 51 insertions(+), 46 deletions(-) create mode 100644 sql/core/benchmarks/CSVBenchmark-results.txt rename sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/{CSVBenchmarks.scala => CSVBenchmark.scala} (61%) diff --git a/sql/core/benchmarks/CSVBenchmark-results.txt b/sql/core/benchmarks/CSVBenchmark-results.txt new file mode 100644 index 0000000000000..865575bec83d8 --- /dev/null +++ b/sql/core/benchmarks/CSVBenchmark-results.txt @@ -0,0 +1,27 @@ +================================================================================================ +Benchmark to measure CSV read/write performance +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Parsing quoted values: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +One quoted string 64733 / 64839 0.0 1294653.1 1.0X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Wide rows with 1000 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Select 1000 columns 185609 / 189735 0.0 185608.6 1.0X +Select 100 columns 50195 / 51808 0.0 50194.8 3.7X +Select one column 39266 / 39293 0.0 39265.6 4.7X +count() 10959 / 11000 0.1 10958.5 16.9X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Count a dataset with 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Select 10 columns + count() 24637 / 24768 0.4 2463.7 1.0X +Select 1 column + count() 20026 / 20076 0.5 2002.6 1.2X +count() 3754 / 3877 2.7 375.4 6.6X + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmark.scala similarity index 61% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmark.scala index 5d1a874999c09..ce38b08b6fdf2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmark.scala @@ -16,30 +16,31 @@ */ package org.apache.spark.sql.execution.datasources.csv -import org.apache.spark.SparkConf import org.apache.spark.benchmark.Benchmark -import org.apache.spark.sql.{Column, Row, SparkSession} -import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.{Column, Row} +import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark import org.apache.spark.sql.functions.lit import org.apache.spark.sql.types._ /** * Benchmark to measure CSV read/write performance. - * To run this: - * spark-submit --class --jars + * To run this benchmark: + * {{{ + * 1. without sbt: + * bin/spark-submit --class --jars , + * + * 2. build/sbt "sql/test:runMain " + * 3. generate result: + * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/CSVBenchmark-results.txt". + * }}} */ -object CSVBenchmarks extends SQLHelper { - val conf = new SparkConf() - - val spark = SparkSession.builder - .master("local[1]") - .appName("benchmark-csv-datasource") - .config(conf) - .getOrCreate() + +object CSVBenchmark extends SqlBasedBenchmark { import spark.implicits._ def quotedValuesBenchmark(rowsNum: Int, numIters: Int): Unit = { - val benchmark = new Benchmark(s"Parsing quoted values", rowsNum) + val benchmark = new Benchmark(s"Parsing quoted values", rowsNum, output = output) withTempPath { path => val str = (0 until 10000).map(i => s""""$i"""").mkString(",") @@ -56,20 +57,13 @@ object CSVBenchmarks extends SQLHelper { ds.filter((_: Row) => true).count() } - /* - Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz - - Parsing quoted values: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - One quoted string 30273 / 30549 0.0 605451.2 1.0X - */ benchmark.run() } } def multiColumnsBenchmark(rowsNum: Int): Unit = { val colsNum = 1000 - val benchmark = new Benchmark(s"Wide rows with $colsNum columns", rowsNum) + val benchmark = new Benchmark(s"Wide rows with $colsNum columns", rowsNum, output = output) withTempPath { path => val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) @@ -98,23 +92,14 @@ object CSVBenchmarks extends SQLHelper { ds.count() } - /* - Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz - - Wide rows with 1000 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - Select 1000 columns 81091 / 81692 0.0 81090.7 1.0X - Select 100 columns 30003 / 34448 0.0 30003.0 2.7X - Select one column 24792 / 24855 0.0 24792.0 3.3X - count() 24344 / 24642 0.0 24343.8 3.3X - */ benchmark.run() } } def countBenchmark(rowsNum: Int): Unit = { val colsNum = 10 - val benchmark = new Benchmark(s"Count a dataset with $colsNum columns", rowsNum) + val benchmark = + new Benchmark(s"Count a dataset with $colsNum columns", rowsNum, output = output) withTempPath { path => val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) @@ -137,22 +122,15 @@ object CSVBenchmarks extends SQLHelper { ds.count() } - /* - Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz - - Count a dataset with 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - --------------------------------------------------------------------------------------------- - Select 10 columns + count() 12598 / 12740 0.8 1259.8 1.0X - Select 1 column + count() 7960 / 8175 1.3 796.0 1.6X - count() 2332 / 2386 4.3 233.2 5.4X - */ benchmark.run() } } - def main(args: Array[String]): Unit = { - quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3) - multiColumnsBenchmark(rowsNum = 1000 * 1000) - countBenchmark(10 * 1000 * 1000) + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("Benchmark to measure CSV read/write performance") { + quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3) + multiColumnsBenchmark(rowsNum = 1000 * 1000) + countBenchmark(10 * 1000 * 1000) + } } } From c36537fcfddc1eae1581b1b84d9d4384c5985c26 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 30 Oct 2018 10:48:04 -0700 Subject: [PATCH 1956/2461] [SPARK-25773][CORE] Cancel zombie tasks in a result stage when the job finishes ## What changes were proposed in this pull request? When a job finishes, there may be some zombie tasks still running due to stage retry. Since a result stage will never be used by other jobs, running these tasks are just wasting the cluster resource. This PR just asks TaskScheduler to cancel the running tasks of a result stage when it's already finished. Credits go to srinathshankar who suggested this idea to me. This PR also fixes two minor issues while I'm touching DAGScheduler: - Invalid spark.job.interruptOnCancel should not crash DAGScheduler. - Non fatal errors should not crash DAGScheduler. ## How was this patch tested? The new unit tests. Closes #22771 from zsxwing/SPARK-25773. Lead-authored-by: Shixiong Zhu Co-authored-by: Shixiong Zhu Signed-off-by: Shixiong Zhu --- .../apache/spark/scheduler/DAGScheduler.scala | 48 ++++++++++++++--- .../org/apache/spark/SparkContextSuite.scala | 53 ++++++++++++++++++- .../spark/scheduler/DAGSchedulerSuite.scala | 51 +++++++++++++----- 3 files changed, 129 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 34b1160dbbfc3..06966e77db81e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1295,6 +1295,27 @@ private[spark] class DAGScheduler( Utils.getFormattedClassName(event.task), event.reason, event.taskInfo, taskMetrics)) } + /** + * Check [[SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL]] in job properties to see if we should + * interrupt running tasks. Returns `false` if the property value is not a boolean value + */ + private def shouldInterruptTaskThread(job: ActiveJob): Boolean = { + if (job.properties == null) { + false + } else { + val shouldInterruptThread = + job.properties.getProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false") + try { + shouldInterruptThread.toBoolean + } catch { + case e: IllegalArgumentException => + logWarning(s"${SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL} in Job ${job.jobId} " + + s"is invalid: $shouldInterruptThread. Using 'false' instead", e) + false + } + } + } + /** * Responds to a task finishing. This is called inside the event loop so it assumes that it can * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside. @@ -1364,6 +1385,21 @@ private[spark] class DAGScheduler( if (job.numFinished == job.numPartitions) { markStageAsFinished(resultStage) cleanupStateForJobAndIndependentStages(job) + try { + // killAllTaskAttempts will fail if a SchedulerBackend does not implement + // killTask. + logInfo(s"Job ${job.jobId} is finished. Cancelling potential speculative " + + "or zombie tasks for this job") + // ResultStage is only used by this job. It's safe to kill speculative or + // zombie tasks in this stage. + taskScheduler.killAllTaskAttempts( + stageId, + shouldInterruptTaskThread(job), + reason = "Stage finished") + } catch { + case e: UnsupportedOperationException => + logWarning(s"Could not cancel tasks for stage $stageId", e) + } listenerBus.post( SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded)) } @@ -1373,7 +1409,7 @@ private[spark] class DAGScheduler( try { job.listener.taskSucceeded(rt.outputId, event.result) } catch { - case e: Exception => + case e: Throwable if !Utils.isFatalError(e) => // TODO: Perhaps we want to mark the resultStage as failed? job.listener.jobFailed(new SparkDriverExecutionException(e)) } @@ -1890,10 +1926,6 @@ private[spark] class DAGScheduler( val error = new SparkException(failureReason, exception.getOrElse(null)) var ableToCancelStages = true - val shouldInterruptThread = - if (job.properties == null) false - else job.properties.getProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false").toBoolean - // Cancel all independent, running stages. val stages = jobIdToStageIds(job.jobId) if (stages.isEmpty) { @@ -1913,12 +1945,12 @@ private[spark] class DAGScheduler( val stage = stageIdToStage(stageId) if (runningStages.contains(stage)) { try { // cancelTasks will fail if a SchedulerBackend does not implement killTask - taskScheduler.cancelTasks(stageId, shouldInterruptThread) + taskScheduler.cancelTasks(stageId, shouldInterruptTaskThread(job)) markStageAsFinished(stage, Some(failureReason)) } catch { case e: UnsupportedOperationException => - logInfo(s"Could not cancel tasks for stage $stageId", e) - ableToCancelStages = false + logWarning(s"Could not cancel tasks for stage $stageId", e) + ableToCancelStages = false } } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index e1666a35271d3..79192f3f3c92c 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -33,7 +33,9 @@ import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFor import org.scalatest.Matchers._ import org.scalatest.concurrent.Eventually -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskEnd, SparkListenerTaskStart} +import org.apache.spark.internal.config.EXECUTOR_HEARTBEAT_DROP_ZERO_ACCUMULATOR_UPDATES +import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorMetricsUpdate, SparkListenerJobStart, SparkListenerTaskEnd, SparkListenerTaskStart} +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.{ThreadUtils, Utils} @@ -672,6 +674,55 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) } } + + test("cancel zombie tasks in a result stage when the job finishes") { + val conf = new SparkConf() + .setMaster("local-cluster[1,2,1024]") + .setAppName("test-cluster") + .set("spark.ui.enabled", "false") + // Disable this so that if a task is running, we can make sure the executor will always send + // task metrics via heartbeat to driver. + .set(EXECUTOR_HEARTBEAT_DROP_ZERO_ACCUMULATOR_UPDATES.key, "false") + // Set a short heartbeat interval to send SparkListenerExecutorMetricsUpdate fast + .set("spark.executor.heartbeatInterval", "1s") + sc = new SparkContext(conf) + sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true") + @volatile var runningTaskIds: Seq[Long] = null + val listener = new SparkListener { + override def onExecutorMetricsUpdate( + executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = { + if (executorMetricsUpdate.execId != SparkContext.DRIVER_IDENTIFIER) { + runningTaskIds = executorMetricsUpdate.accumUpdates.map(_._1) + } + } + } + sc.addSparkListener(listener) + sc.range(0, 2).groupBy((x: Long) => x % 2, 2).map { case (x, _) => + val context = org.apache.spark.TaskContext.get() + if (context.stageAttemptNumber == 0) { + if (context.partitionId == 0) { + // Make the first task in the first stage attempt fail. + throw new FetchFailedException(SparkEnv.get.blockManager.blockManagerId, 0, 0, 0, + new java.io.IOException("fake")) + } else { + // Make the second task in the first stage attempt sleep to generate a zombie task + Thread.sleep(60000) + } + } else { + // Make the second stage attempt successful. + } + x + }.collect() + sc.listenerBus.waitUntilEmpty(10000) + // As executors will send the metrics of running tasks via heartbeat, we can use this to check + // whether there is any running task. + eventually(timeout(10.seconds)) { + // Make sure runningTaskIds has been set + assert(runningTaskIds != null) + // Verify there is no running task. + assert(runningTaskIds.isEmpty) + } + } } object SparkContextSuite { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index b41d2acab7152..5f4ffa151d19b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1901,27 +1901,50 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } /** - * The job will be failed on first task throwing a DAGSchedulerSuiteDummyException. + * The job will be failed on first task throwing an error. * Any subsequent task WILL throw a legitimate java.lang.UnsupportedOperationException. * If multiple tasks, there exists a race condition between the SparkDriverExecutionExceptions * and their differing causes as to which will represent result for job... */ test("misbehaved resultHandler should not crash DAGScheduler and SparkContext") { - val e = intercept[SparkDriverExecutionException] { - // Number of parallelized partitions implies number of tasks of job - val rdd = sc.parallelize(1 to 10, 2) - sc.runJob[Int, Int]( - rdd, - (context: TaskContext, iter: Iterator[Int]) => iter.size, - // For a robust test assertion, limit number of job tasks to 1; that is, - // if multiple RDD partitions, use id of any one partition, say, first partition id=0 - Seq(0), - (part: Int, result: Int) => throw new DAGSchedulerSuiteDummyException) + failAfter(1.minute) { // If DAGScheduler crashes, the following test will hang forever + for (error <- Seq( + new DAGSchedulerSuiteDummyException, + new AssertionError, // E.g., assert(foo == bar) fails + new NotImplementedError // E.g., call a method with `???` implementation. + )) { + val e = intercept[SparkDriverExecutionException] { + // Number of parallelized partitions implies number of tasks of job + val rdd = sc.parallelize(1 to 10, 2) + sc.runJob[Int, Int]( + rdd, + (context: TaskContext, iter: Iterator[Int]) => iter.size, + // For a robust test assertion, limit number of job tasks to 1; that is, + // if multiple RDD partitions, use id of any one partition, say, first partition id=0 + Seq(0), + (part: Int, result: Int) => throw error) + } + assert(e.getCause eq error) + + // Make sure we can still run commands on our SparkContext + assert(sc.parallelize(1 to 10, 2).count() === 10) + } } - assert(e.getCause.isInstanceOf[DAGSchedulerSuiteDummyException]) + } - // Make sure we can still run commands on our SparkContext - assert(sc.parallelize(1 to 10, 2).count() === 10) + test(s"invalid ${SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL} should not crash DAGScheduler") { + sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "invalid") + try { + intercept[SparkException] { + sc.parallelize(1 to 1, 1).foreach { _ => + throw new DAGSchedulerSuiteDummyException + } + } + // Verify the above job didn't crash DAGScheduler by running a simple job + assert(sc.parallelize(1 to 10, 2).count() === 10) + } finally { + sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, null) + } } test("getPartitions exceptions should not crash DAGScheduler and SparkContext (SPARK-8606)") { From f6cc354d83c2c9a757f9b507aadd4dbdc5825cca Mon Sep 17 00:00:00 2001 From: Onur Satici Date: Tue, 30 Oct 2018 13:52:44 -0700 Subject: [PATCH 1957/2461] [SPARK-24434][K8S] pod template files ## What changes were proposed in this pull request? New feature to pass podspec files for driver and executor pods. ## How was this patch tested? new unit and integration tests - [x] more overwrites in integration tests - [ ] invalid template integration test, documentation Author: Onur Satici Author: Yifei Huang Author: onursatici Closes #22146 from onursatici/pod-template. --- docs/running-on-kubernetes.md | 180 ++++++++++++++++++ .../org/apache/spark/deploy/k8s/Config.scala | 24 +++ .../apache/spark/deploy/k8s/Constants.scala | 10 +- .../deploy/k8s/KubernetesDriverSpec.scala | 7 - .../spark/deploy/k8s/KubernetesUtils.scala | 49 ++++- .../k8s/features/BasicDriverFeatureStep.scala | 8 +- .../features/BasicExecutorFeatureStep.scala | 10 +- .../features/PodTemplateConfigMapStep.scala | 72 +++++++ .../submit/KubernetesClientApplication.scala | 6 +- .../k8s/submit/KubernetesDriverBuilder.scala | 40 +++- .../k8s/ExecutorPodsLifecycleManager.scala | 1 - .../k8s/KubernetesClusterManager.scala | 14 +- .../k8s/KubernetesExecutorBuilder.scala | 38 +++- .../deploy/k8s/KubernetesUtilsSuite.scala | 68 +++++++ .../BasicDriverFeatureStepSuite.scala | 2 +- .../PodTemplateConfigMapStepSuite.scala | 97 ++++++++++ .../submit/KubernetesDriverBuilderSuite.scala | 116 ++++++++++- .../k8s/submit/PodBuilderSuiteUtils.scala | 142 ++++++++++++++ .../ExecutorPodsLifecycleManagerSuite.scala | 4 - .../k8s/KubernetesExecutorBuilderSuite.scala | 41 +++- .../src/test/resources/driver-template.yml | 26 +++ .../src/test/resources/executor-template.yml | 25 +++ .../k8s/integrationtest/KubernetesSuite.scala | 8 +- .../integrationtest/PodTemplateSuite.scala | 55 ++++++ 24 files changed, 991 insertions(+), 52 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/PodBuilderSuiteUtils.scala create mode 100644 resource-managers/kubernetes/integration-tests/src/test/resources/driver-template.yml create mode 100644 resource-managers/kubernetes/integration-tests/src/test/resources/executor-template.yml create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PodTemplateSuite.scala diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 7093ee5a9686d..2917197a2e2ec 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -186,6 +186,22 @@ To use a secret through an environment variable use the following options to the --conf spark.kubernetes.executor.secretKeyRef.ENV_NAME=name:key ``` +## Pod Template +Kubernetes allows defining pods from [template files](https://kubernetes.io/docs/concepts/workloads/pods/pod-overview/#pod-templates). +Spark users can similarly use template files to define the driver or executor pod configurations that Spark configurations do not support. +To do so, specify the spark properties `spark.kubernetes.driver.podTemplateFile` and `spark.kubernetes.executor.podTemplateFile` +to point to local files accessible to the `spark-submit` process. To allow the driver pod access the executor pod template +file, the file will be automatically mounted onto a volume in the driver pod when it's created. +Spark does not do any validation after unmarshalling these template files and relies on the Kubernetes API server for validation. + +It is important to note that Spark is opinionated about certain pod configurations so there are values in the +pod template that will always be overwritten by Spark. Therefore, users of this feature should note that specifying +the pod template file only lets Spark start with a template pod instead of an empty pod during the pod-building process. +For details, see the [full list](#pod-template-properties) of pod template values that will be overwritten by spark. + +Pod template files can also define multiple containers. In such cases, Spark will always assume that the first container in +the list will be the driver or executor container. + ## Using Kubernetes Volumes Starting with Spark 2.4.0, users can mount the following types of Kubernetes [volumes](https://kubernetes.io/docs/concepts/storage/volumes/) into the driver and executor pods: @@ -863,4 +879,168 @@ specific to Spark on Kubernetes. to provide any kerberos credentials for launching a job. + + spark.kubernetes.driver.podTemplateFile + (none) + + Specify the local file that contains the driver [pod template](#pod-template). For example + spark.kubernetes.driver.podTemplateFile=/path/to/driver-pod-template.yaml` + + + + spark.kubernetes.executor.podTemplateFile + (none) + + Specify the local file that contains the executor [pod template](#pod-template). For example + spark.kubernetes.executor.podTemplateFile=/path/to/executor-pod-template.yaml` + + + + +#### Pod template properties + +See the below table for the full list of pod specifications that will be overwritten by spark. + +### Pod Metadata + + + + + + + + + + + + + + + + + + + + + + + +
      Pod metadata keyModified valueDescription
      nameValue of spark.kubernetes.driver.pod.name + The driver pod name will be overwritten with either the configured or default value of + spark.kubernetes.driver.pod.name. The executor pod names will be unaffected. +
      namespaceValue of spark.kubernetes.namespace + Spark makes strong assumptions about the driver and executor namespaces. Both driver and executor namespaces will + be replaced by either the configured or default spark conf value. +
      labelsAdds the labels from spark.kubernetes.{driver,executor}.label.* + Spark will add additional labels specified by the spark configuration. +
      annotationsAdds the annotations from spark.kubernetes.{driver,executor}.annotation.* + Spark will add additional labels specified by the spark configuration. +
      + +### Pod Spec + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
      Pod spec keyModified valueDescription
      imagePullSecretsAdds image pull secrets from spark.kubernetes.container.image.pullSecrets + Additional pull secrets will be added from the spark configuration to both executor pods. +
      nodeSelectorAdds node selectors from spark.kubernetes.node.selector.* + Additional node selectors will be added from the spark configuration to both executor pods. +
      restartPolicy"never" + Spark assumes that both drivers and executors never restart. +
      serviceAccountValue of spark.kubernetes.authenticate.driver.serviceAccountName + Spark will override serviceAccount with the value of the spark configuration for only + driver pods, and only if the spark configuration is specified. Executor pods will remain unaffected. +
      serviceAccountNameValue of spark.kubernetes.authenticate.driver.serviceAccountName + Spark will override serviceAccountName with the value of the spark configuration for only + driver pods, and only if the spark configuration is specified. Executor pods will remain unaffected. +
      volumesAdds volumes from spark.kubernetes.{driver,executor}.volumes.[VolumeType].[VolumeName].mount.path + Spark will add volumes as specified by the spark conf, as well as additional volumes necessary for passing + spark conf and pod template files. +
      + +### Container spec + +The following affect the driver and executor containers. All other containers in the pod spec will be unaffected. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
      Container spec keyModified valueDescription
      envAdds env variables from spark.kubernetes.driverEnv.[EnvironmentVariableName] + Spark will add driver env variables from spark.kubernetes.driverEnv.[EnvironmentVariableName], and + executor env variables from spark.executorEnv.[EnvironmentVariableName]. +
      imageValue of spark.kubernetes.{driver,executor}.container.image + The image will be defined by the spark configurations. +
      imagePullPolicyValue of spark.kubernetes.container.image.pullPolicy + Spark will override the pull policy for both driver and executors. +
      nameSee description. + The container name will be assigned by spark ("spark-kubernetes-driver" for the driver container, and + "executor" for each executor container) if not defined by the pod template. If the container is defined by the + template, the template's name will be used. +
      resourcesSee description + The cpu limits are set by spark.kubernetes.{driver,executor}.limit.cores. The cpu is set by + spark.{driver,executor}.cores. The memory request and limit are set by summing the values of + spark.{driver,executor}.memory and spark.{driver,executor}.memoryOverhead. + +
      volumeMountsAdd volumes from spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.{path,readOnly} + Spark will add volumes as specified by the spark conf, as well as additional volumes necessary for passing + spark conf and pod template files. +
      diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index fff8fa4340c35..862f1d63ed39f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -278,6 +278,30 @@ private[spark] object Config extends Logging { .booleanConf .createWithDefault(false) + val KUBERNETES_DRIVER_PODTEMPLATE_FILE = + ConfigBuilder("spark.kubernetes.driver.podTemplateFile") + .doc("File containing a template pod spec for the driver") + .stringConf + .createOptional + + val KUBERNETES_EXECUTOR_PODTEMPLATE_FILE = + ConfigBuilder("spark.kubernetes.executor.podTemplateFile") + .doc("File containing a template pod spec for executors") + .stringConf + .createOptional + + val KUBERNETES_DRIVER_PODTEMPLATE_CONTAINER_NAME = + ConfigBuilder("spark.kubernetes.driver.podTemplateContainerName") + .doc("container name to be used as a basis for the driver in the given pod template") + .stringConf + .createOptional + + val KUBERNETES_EXECUTOR_PODTEMPLATE_CONTAINER_NAME = + ConfigBuilder("spark.kubernetes.executor.podTemplateContainerName") + .doc("container name to be used as a basis for executors in the given pod template") + .stringConf + .createOptional + val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = "spark.kubernetes.authenticate.submission" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 172a9054bb4f2..1c6d53c16871e 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -76,9 +76,17 @@ private[spark] object Constants { val ENV_R_PRIMARY = "R_PRIMARY" val ENV_R_ARGS = "R_APP_ARGS" + // Pod spec templates + val EXECUTOR_POD_SPEC_TEMPLATE_FILE_NAME = "pod-spec-template.yml" + val EXECUTOR_POD_SPEC_TEMPLATE_MOUNTPATH = "/opt/spark/pod-template" + val POD_TEMPLATE_VOLUME = "pod-template-volume" + val POD_TEMPLATE_CONFIGMAP = "podspec-configmap" + val POD_TEMPLATE_KEY = "podspec-configmap-key" + // Miscellaneous val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc" - val DRIVER_CONTAINER_NAME = "spark-kubernetes-driver" + val DEFAULT_DRIVER_CONTAINER_NAME = "spark-kubernetes-driver" + val DEFAULT_EXECUTOR_CONTAINER_NAME = "spark-kubernetes-executor" val MEMORY_OVERHEAD_MIN_MIB = 384L // Hadoop Configuration diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala index 0c5ae022f4070..fce8c6a4bf494 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala @@ -22,10 +22,3 @@ private[spark] case class KubernetesDriverSpec( pod: SparkPod, driverKubernetesResources: Seq[HasMetadata], systemProperties: Map[String, String]) - -private[spark] object KubernetesDriverSpec { - def initialSpec(initialProps: Map[String, String]): KubernetesDriverSpec = KubernetesDriverSpec( - SparkPod.initialPod(), - Seq.empty, - initialProps) -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 0f740454fafc4..6fafac3ee13c9 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -16,14 +16,18 @@ */ package org.apache.spark.deploy.k8s +import java.io.File + import scala.collection.JavaConverters._ -import io.fabric8.kubernetes.api.model.{ContainerStateRunning, ContainerStateTerminated, ContainerStateWaiting, ContainerStatus, Pod} +import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, ContainerStateRunning, ContainerStateTerminated, ContainerStateWaiting, ContainerStatus, Pod, PodBuilder} +import io.fabric8.kubernetes.client.KubernetesClient import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.internal.Logging import org.apache.spark.util.Utils -private[spark] object KubernetesUtils { +private[spark] object KubernetesUtils extends Logging { /** * Extract and parse Spark configuration properties with a given name prefix and @@ -82,6 +86,47 @@ private[spark] object KubernetesUtils { } } + def loadPodFromTemplate( + kubernetesClient: KubernetesClient, + templateFile: File, + containerName: Option[String]): SparkPod = { + try { + val pod = kubernetesClient.pods().load(templateFile).get() + selectSparkContainer(pod, containerName) + } catch { + case e: Exception => + logError( + s"Encountered exception while attempting to load initial pod spec from file", e) + throw new SparkException("Could not load pod from template file.", e) + } + } + + def selectSparkContainer(pod: Pod, containerName: Option[String]): SparkPod = { + def selectNamedContainer( + containers: List[Container], name: String): Option[(Container, List[Container])] = + containers.partition(_.getName == name) match { + case (sparkContainer :: Nil, rest) => Some((sparkContainer, rest)) + case _ => + logWarning( + s"specified container ${name} not found on pod template, " + + s"falling back to taking the first container") + Option.empty + } + val containers = pod.getSpec.getContainers.asScala.toList + containerName + .flatMap(selectNamedContainer(containers, _)) + .orElse(containers.headOption.map((_, containers.tail))) + .map { + case (sparkContainer: Container, rest: List[Container]) => SparkPod( + new PodBuilder(pod) + .editSpec() + .withContainers(rest.asJava) + .endSpec() + .build(), + sparkContainer) + }.getOrElse(SparkPod(pod, new ContainerBuilder().build())) + } + def parseMasterUrl(url: String): String = url.substring("k8s://".length) def formatPairsBundle(pairs: Seq[(String, String)], indent: Int = 1) : String = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala index 575bc54ffe2bb..96b14a0d82b4c 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala @@ -80,7 +80,7 @@ private[spark] class BasicDriverFeatureStep( ) val driverUIPort = SparkUI.getUIPort(conf.sparkConf) val driverContainer = new ContainerBuilder(pod.container) - .withName(DRIVER_CONTAINER_NAME) + .withName(Option(pod.container.getName).getOrElse(DEFAULT_DRIVER_CONTAINER_NAME)) .withImage(driverContainerImage) .withImagePullPolicy(conf.imagePullPolicy()) .addNewPort() @@ -105,7 +105,7 @@ private[spark] class BasicDriverFeatureStep( .withNewFieldRef("v1", "status.podIP") .build()) .endEnv() - .withNewResources() + .editOrNewResources() .addToRequests("cpu", driverCpuQuantity) .addToLimits(maybeCpuLimitQuantity.toMap.asJava) .addToRequests("memory", driverMemoryQuantity) @@ -119,9 +119,9 @@ private[spark] class BasicDriverFeatureStep( .addToLabels(conf.roleLabels.asJava) .addToAnnotations(conf.roleAnnotations.asJava) .endMetadata() - .withNewSpec() + .editOrNewSpec() .withRestartPolicy("Never") - .withNodeSelector(conf.nodeSelector().asJava) + .addToNodeSelector(conf.nodeSelector().asJava) .addToImagePullSecrets(conf.imagePullSecrets(): _*) .endSpec() .build() diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index d89995ba5e4f4..1dab2a834f3e7 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -139,10 +139,10 @@ private[spark] class BasicExecutorFeatureStep( } val executorContainer = new ContainerBuilder(pod.container) - .withName("executor") + .withName(Option(pod.container.getName).getOrElse(DEFAULT_EXECUTOR_CONTAINER_NAME)) .withImage(executorContainerImage) .withImagePullPolicy(kubernetesConf.imagePullPolicy()) - .withNewResources() + .editOrNewResources() .addToRequests("memory", executorMemoryQuantity) .addToLimits("memory", executorMemoryQuantity) .addToRequests("cpu", executorCpuQuantity) @@ -173,14 +173,14 @@ private[spark] class BasicExecutorFeatureStep( val executorPod = new PodBuilder(pod.pod) .editOrNewMetadata() .withName(name) - .withLabels(kubernetesConf.roleLabels.asJava) - .withAnnotations(kubernetesConf.roleAnnotations.asJava) + .addToLabels(kubernetesConf.roleLabels.asJava) + .addToAnnotations(kubernetesConf.roleAnnotations.asJava) .addToOwnerReferences(ownerReference.toSeq: _*) .endMetadata() .editOrNewSpec() .withHostname(hostname) .withRestartPolicy("Never") - .withNodeSelector(kubernetesConf.nodeSelector().asJava) + .addToNodeSelector(kubernetesConf.nodeSelector().asJava) .addToImagePullSecrets(kubernetesConf.imagePullSecrets(): _*) .endSpec() .build() diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala new file mode 100644 index 0000000000000..96a8013246b74 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import java.io.File +import java.nio.charset.StandardCharsets + +import com.google.common.io.Files +import io.fabric8.kubernetes.api.model.{ConfigMapBuilder, ContainerBuilder, HasMetadata, PodBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesRoleSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +private[spark] class PodTemplateConfigMapStep( + conf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) + extends KubernetesFeatureConfigStep { + def configurePod(pod: SparkPod): SparkPod = { + val podWithVolume = new PodBuilder(pod.pod) + .editSpec() + .addNewVolume() + .withName(POD_TEMPLATE_VOLUME) + .withNewConfigMap() + .withName(POD_TEMPLATE_CONFIGMAP) + .addNewItem() + .withKey(POD_TEMPLATE_KEY) + .withPath(EXECUTOR_POD_SPEC_TEMPLATE_FILE_NAME) + .endItem() + .endConfigMap() + .endVolume() + .endSpec() + .build() + + val containerWithVolume = new ContainerBuilder(pod.container) + .addNewVolumeMount() + .withName(POD_TEMPLATE_VOLUME) + .withMountPath(EXECUTOR_POD_SPEC_TEMPLATE_MOUNTPATH) + .endVolumeMount() + .build() + SparkPod(podWithVolume, containerWithVolume) + } + + def getAdditionalPodSystemProperties(): Map[String, String] = Map[String, String]( + KUBERNETES_EXECUTOR_PODTEMPLATE_FILE.key -> + (EXECUTOR_POD_SPEC_TEMPLATE_MOUNTPATH + "/" + EXECUTOR_POD_SPEC_TEMPLATE_FILE_NAME)) + + def getAdditionalKubernetesResources(): Seq[HasMetadata] = { + require(conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).isDefined) + val podTemplateFile = conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).get + val podTemplateString = Files.toString(new File(podTemplateFile), StandardCharsets.UTF_8) + Seq(new ConfigMapBuilder() + .withNewMetadata() + .withName(POD_TEMPLATE_CONFIGMAP) + .endMetadata() + .addToData(POD_TEMPLATE_KEY, podTemplateString) + .build()) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index c658756cc165b..4b58f8ba3c9bd 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -17,7 +17,8 @@ package org.apache.spark.deploy.k8s.submit import java.io.StringWriter -import java.util.{Collections, Locale, UUID} +import java.util.{Collections, Locale, Properties, UUID} +import java.util.{Collections, UUID} import java.util.Properties import io.fabric8.kubernetes.api.model._ @@ -227,7 +228,6 @@ private[spark] class KubernetesClientApplication extends SparkApplication { clientArguments.driverArgs, clientArguments.maybePyFiles, clientArguments.hadoopConfigDir) - val builder = new KubernetesDriverBuilder val namespace = kubernetesConf.namespace() // The master URL has been checked for validity already in SparkSubmit. // We just need to get rid of the "k8s://" prefix here. @@ -244,7 +244,7 @@ private[spark] class KubernetesClientApplication extends SparkApplication { None, None)) { kubernetesClient => val client = new Client( - builder, + KubernetesDriverBuilder(kubernetesClient, kubernetesConf.sparkConf), kubernetesConf, kubernetesClient, waitForAppCompletion, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index b0b53321abd25..5565cd74280e6 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -16,7 +16,12 @@ */ package org.apache.spark.deploy.k8s.submit -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} +import java.io.File + +import io.fabric8.kubernetes.client.KubernetesClient + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.{Config, KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, KubernetesUtils, SparkPod} import org.apache.spark.deploy.k8s.features._ import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep, RDriverFeatureStep} @@ -55,7 +60,11 @@ private[spark] class KubernetesDriverBuilder( provideHadoopGlobalStep: ( KubernetesConf[KubernetesDriverSpecificConf] => KerberosConfDriverFeatureStep) = - new KerberosConfDriverFeatureStep(_)) { + new KerberosConfDriverFeatureStep(_), + providePodTemplateConfigMapStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => PodTemplateConfigMapStep) = + new PodTemplateConfigMapStep(_), + provideInitialPod: () => SparkPod = SparkPod.initialPod) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]): KubernetesDriverSpec = { @@ -74,6 +83,10 @@ private[spark] class KubernetesDriverBuilder( val volumesFeature = if (kubernetesConf.roleVolumes.nonEmpty) { Seq(provideVolumesStep(kubernetesConf)) } else Nil + val podTemplateFeature = if ( + kubernetesConf.get(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).isDefined) { + Seq(providePodTemplateConfigMapStep(kubernetesConf)) + } else Nil val bindingsStep = kubernetesConf.roleSpecificConf.mainAppResource.map { case JavaMainAppResource(_) => @@ -86,14 +99,17 @@ private[spark] class KubernetesDriverBuilder( val maybeHadoopConfigStep = kubernetesConf.hadoopConfSpec.map { _ => - provideHadoopGlobalStep(kubernetesConf)} + provideHadoopGlobalStep(kubernetesConf)} val allFeatures: Seq[KubernetesFeatureConfigStep] = (baseFeatures :+ bindingsStep) ++ secretFeature ++ envSecretFeature ++ volumesFeature ++ - maybeHadoopConfigStep.toSeq + maybeHadoopConfigStep.toSeq ++ podTemplateFeature - var spec = KubernetesDriverSpec.initialSpec(kubernetesConf.sparkConf.getAll.toMap) + var spec = KubernetesDriverSpec( + provideInitialPod(), + driverKubernetesResources = Seq.empty, + kubernetesConf.sparkConf.getAll.toMap) for (feature <- allFeatures) { val configuredPod = feature.configurePod(spec.pod) val addedSystemProperties = feature.getAdditionalPodSystemProperties() @@ -106,3 +122,17 @@ private[spark] class KubernetesDriverBuilder( spec } } + +private[spark] object KubernetesDriverBuilder { + def apply(kubernetesClient: KubernetesClient, conf: SparkConf): KubernetesDriverBuilder = { + conf.get(Config.KUBERNETES_DRIVER_PODTEMPLATE_FILE) + .map(new File(_)) + .map(file => new KubernetesDriverBuilder(provideInitialPod = () => + KubernetesUtils.loadPodFromTemplate( + kubernetesClient, + file, + conf.get(Config.KUBERNETES_DRIVER_PODTEMPLATE_CONTAINER_NAME)) + )) + .getOrElse(new KubernetesDriverBuilder()) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala index 1a75ae00cbd98..77a1d6cfae3bd 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala @@ -31,7 +31,6 @@ import org.apache.spark.util.Utils private[spark] class ExecutorPodsLifecycleManager( conf: SparkConf, - executorBuilder: KubernetesExecutorBuilder, kubernetesClient: KubernetesClient, snapshotsStore: ExecutorPodsSnapshotsStore, // Use a best-effort to track which executors have been removed already. It's not generally diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index 9999c62c878df..ce10f766334ff 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -69,6 +69,13 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit defaultServiceAccountToken, defaultServiceAccountCaCrt) + if (sc.conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).isDefined) { + KubernetesUtils.loadPodFromTemplate( + kubernetesClient, + new File(sc.conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).get), + sc.conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_CONTAINER_NAME)) + } + val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( "kubernetes-executor-requests") @@ -81,13 +88,16 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit .build[java.lang.Long, java.lang.Long]() val executorPodsLifecycleEventHandler = new ExecutorPodsLifecycleManager( sc.conf, - new KubernetesExecutorBuilder(), kubernetesClient, snapshotsStore, removedExecutorsCache) val executorPodsAllocator = new ExecutorPodsAllocator( - sc.conf, new KubernetesExecutorBuilder(), kubernetesClient, snapshotsStore, new SystemClock()) + sc.conf, + KubernetesExecutorBuilder(kubernetesClient, sc.conf), + kubernetesClient, + snapshotsStore, + new SystemClock()) val podsWatchEventSource = new ExecutorPodsWatchSnapshotSource( snapshotsStore, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index 6199a8ae30430..089f84dec277f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -16,7 +16,12 @@ */ package org.apache.spark.scheduler.cluster.k8s -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} +import java.io.File + +import io.fabric8.kubernetes.client.KubernetesClient + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.features._ @@ -35,19 +40,20 @@ private[spark] class KubernetesExecutorBuilder( new LocalDirsFeatureStep(_), provideVolumesStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => MountVolumesFeatureStep) = - new MountVolumesFeatureStep(_), + new MountVolumesFeatureStep(_), provideHadoopConfStep: ( KubernetesConf[KubernetesExecutorSpecificConf] - => HadoopConfExecutorFeatureStep) = - new HadoopConfExecutorFeatureStep(_), + => HadoopConfExecutorFeatureStep) = + new HadoopConfExecutorFeatureStep(_), provideKerberosConfStep: ( KubernetesConf[KubernetesExecutorSpecificConf] - => KerberosConfExecutorFeatureStep) = - new KerberosConfExecutorFeatureStep(_), + => KerberosConfExecutorFeatureStep) = + new KerberosConfExecutorFeatureStep(_), provideHadoopSparkUserStep: ( KubernetesConf[KubernetesExecutorSpecificConf] - => HadoopSparkUserExecutorFeatureStep) = - new HadoopSparkUserExecutorFeatureStep(_)) { + => HadoopSparkUserExecutorFeatureStep) = + new HadoopSparkUserExecutorFeatureStep(_), + provideInitialPod: () => SparkPod = SparkPod.initialPod) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { @@ -85,10 +91,24 @@ private[spark] class KubernetesExecutorBuilder( volumesFeature ++ maybeHadoopConfFeatureSteps - var executorPod = SparkPod.initialPod() + var executorPod = provideInitialPod() for (feature <- allFeatures) { executorPod = feature.configurePod(executorPod) } executorPod } } + +private[spark] object KubernetesExecutorBuilder { + def apply(kubernetesClient: KubernetesClient, conf: SparkConf): KubernetesExecutorBuilder = { + conf.get(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE) + .map(new File(_)) + .map(file => new KubernetesExecutorBuilder(provideInitialPod = () => + KubernetesUtils.loadPodFromTemplate( + kubernetesClient, + file, + conf.get(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_CONTAINER_NAME)) + )) + .getOrElse(new KubernetesExecutorBuilder()) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsSuite.scala new file mode 100644 index 0000000000000..7c231586af935 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.k8s + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, PodBuilder} + +import org.apache.spark.SparkFunSuite + +class KubernetesUtilsSuite extends SparkFunSuite { + private val HOST = "test-host" + private val POD = new PodBuilder() + .withNewSpec() + .withHostname(HOST) + .withContainers( + new ContainerBuilder().withName("first").build(), + new ContainerBuilder().withName("second").build()) + .endSpec() + .build() + + test("Selects the given container as spark container.") { + val sparkPod = KubernetesUtils.selectSparkContainer(POD, Some("second")) + assert(sparkPod.pod.getSpec.getHostname == HOST) + assert(sparkPod.pod.getSpec.getContainers.asScala.toList.map(_.getName) == List("first")) + assert(sparkPod.container.getName == "second") + } + + test("Selects the first container if no container name is given.") { + val sparkPod = KubernetesUtils.selectSparkContainer(POD, Option.empty) + assert(sparkPod.pod.getSpec.getHostname == HOST) + assert(sparkPod.pod.getSpec.getContainers.asScala.toList.map(_.getName) == List("second")) + assert(sparkPod.container.getName == "first") + } + + test("Falls back to the first container if given container name does not exist.") { + val sparkPod = KubernetesUtils.selectSparkContainer(POD, Some("does-not-exist")) + assert(sparkPod.pod.getSpec.getHostname == HOST) + assert(sparkPod.pod.getSpec.getContainers.asScala.toList.map(_.getName) == List("second")) + assert(sparkPod.container.getName == "first") + } + + test("constructs spark pod correctly with pod template with no containers") { + val noContainersPod = new PodBuilder(POD).editSpec().withContainers().endSpec().build() + val sparkPod = KubernetesUtils.selectSparkContainer(noContainersPod, Some("does-not-exist")) + assert(sparkPod.pod.getSpec.getHostname == HOST) + assert(sparkPod.container.getName == null) + val sparkPodWithNoContainerName = + KubernetesUtils.selectSparkContainer(noContainersPod, Option.empty) + assert(sparkPodWithNoContainerName.pod.getSpec.getHostname == HOST) + assert(sparkPodWithNoContainerName.container.getName == null) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala index eebdd157da638..5c6bcc72158be 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -84,7 +84,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { val basePod = SparkPod.initialPod() val configuredPod = featureStep.configurePod(basePod) - assert(configuredPod.container.getName === DRIVER_CONTAINER_NAME) + assert(configuredPod.container.getName === DEFAULT_DRIVER_CONTAINER_NAME) assert(configuredPod.container.getImage === "spark-driver:latest") assert(configuredPod.container.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala new file mode 100644 index 0000000000000..d7bbbd121af72 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import java.io.{File, PrintWriter} +import java.nio.file.Files + +import io.fabric8.kubernetes.api.model.ConfigMap +import org.mockito.Mockito +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s._ + +class PodTemplateConfigMapStepSuite extends SparkFunSuite with BeforeAndAfter { + private var sparkConf: SparkConf = _ + private var kubernetesConf : KubernetesConf[_ <: KubernetesRoleSpecificConf] = _ + private var templateFile: File = _ + + before { + sparkConf = Mockito.mock(classOf[SparkConf]) + kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, + "app-name", + "main", + Seq.empty), + "resource", + "app-id", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Nil, + Seq.empty[String], + Option.empty) + templateFile = Files.createTempFile("pod-template", "yml").toFile + templateFile.deleteOnExit() + Mockito.doReturn(Option(templateFile.getAbsolutePath)).when(sparkConf) + .get(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE) + } + + test("Mounts executor template volume if config specified") { + val writer = new PrintWriter(templateFile) + writer.write("pod-template-contents") + writer.close() + + val step = new PodTemplateConfigMapStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val volume = configuredPod.pod.getSpec.getVolumes.get(0) + assert(volume.getName === Constants.POD_TEMPLATE_VOLUME) + assert(volume.getConfigMap.getName === Constants.POD_TEMPLATE_CONFIGMAP) + assert(volume.getConfigMap.getItems.size() === 1) + assert(volume.getConfigMap.getItems.get(0).getKey === Constants.POD_TEMPLATE_KEY) + assert(volume.getConfigMap.getItems.get(0).getPath === + Constants.EXECUTOR_POD_SPEC_TEMPLATE_FILE_NAME) + + assert(configuredPod.container.getVolumeMounts.size() === 1) + val volumeMount = configuredPod.container.getVolumeMounts.get(0) + assert(volumeMount.getMountPath === Constants.EXECUTOR_POD_SPEC_TEMPLATE_MOUNTPATH) + assert(volumeMount.getName === Constants.POD_TEMPLATE_VOLUME) + + val resources = step.getAdditionalKubernetesResources() + assert(resources.size === 1) + assert(resources.head.getMetadata.getName === Constants.POD_TEMPLATE_CONFIGMAP) + assert(resources.head.isInstanceOf[ConfigMap]) + val configMap = resources.head.asInstanceOf[ConfigMap] + assert(configMap.getData.size() === 1) + assert(configMap.getData.containsKey(Constants.POD_TEMPLATE_KEY)) + assert(configMap.getData.containsValue("pod-template-contents")) + + val systemProperties = step.getAdditionalPodSystemProperties() + assert(systemProperties.size === 1) + assert(systemProperties.contains(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE.key)) + assert(systemProperties.get(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE.key).get === + (Constants.EXECUTOR_POD_SPEC_TEMPLATE_MOUNTPATH + "/" + + Constants.EXECUTOR_POD_SPEC_TEMPLATE_FILE_NAME)) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index 051d7b6994f5d..84968c3523fc0 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -16,8 +16,13 @@ */ package org.apache.spark.deploy.k8s.submit -import org.apache.spark.{SparkConf, SparkFunSuite} +import io.fabric8.kubernetes.api.model.PodBuilder +import io.fabric8.kubernetes.client.KubernetesClient +import org.mockito.Mockito._ + +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.Config.{CONTAINER_IMAGE, KUBERNETES_DRIVER_PODTEMPLATE_FILE, KUBERNETES_EXECUTOR_PODTEMPLATE_FILE} import org.apache.spark.deploy.k8s.features._ import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep, RDriverFeatureStep} @@ -34,6 +39,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val ENV_SECRETS_STEP_TYPE = "env-secrets" private val HADOOP_GLOBAL_STEP_TYPE = "hadoop-global" private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes" + private val TEMPLATE_VOLUME_STEP_TYPE = "template-volume" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicDriverFeatureStep]) @@ -68,6 +74,10 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val mountVolumesStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( MOUNT_VOLUMES_STEP_TYPE, classOf[MountVolumesFeatureStep]) + private val templateVolumeStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + TEMPLATE_VOLUME_STEP_TYPE, classOf[PodTemplateConfigMapStep] + ) + private val builderUnderTest: KubernetesDriverBuilder = new KubernetesDriverBuilder( _ => basicFeatureStep, @@ -80,7 +90,8 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { _ => pythonStep, _ => rStep, _ => javaStep, - _ => hadoopGlobalStep) + _ => hadoopGlobalStep, + _ => templateVolumeStep) test("Apply fundamental steps all the time.") { val conf = KubernetesConf( @@ -252,6 +263,37 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { R_STEP_TYPE) } + test("Apply template volume step if executor template is present.") { + val sparkConf = spy(new SparkConf(false)) + doReturn(Option("filename")).when(sparkConf) + .get(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE) + val conf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + Some(JavaMainAppResource("example.jar")), + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Nil, + Seq.empty[String], + Option.empty) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + JAVA_STEP_TYPE, + TEMPLATE_VOLUME_STEP_TYPE) + } + test("Apply HadoopSteps if HADOOP_CONF_DIR is defined.") { val conf = KubernetesConf( new SparkConf(false), @@ -314,7 +356,6 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { HADOOP_GLOBAL_STEP_TYPE) } - private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) : Unit = { assert(resolvedSpec.systemProperties.size === stepTypes.size) @@ -325,4 +366,73 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { assert(resolvedSpec.systemProperties(stepType) === stepType) } } + + test("Start with empty pod if template is not specified") { + val kubernetesClient = mock(classOf[KubernetesClient]) + val driverBuilder = KubernetesDriverBuilder.apply(kubernetesClient, new SparkConf()) + verify(kubernetesClient, never()).pods() + } + + test("Starts with template if specified") { + val kubernetesClient = PodBuilderSuiteUtils.loadingMockKubernetesClient() + val sparkConf = new SparkConf(false) + .set(CONTAINER_IMAGE, "spark-driver:latest") + .set(KUBERNETES_DRIVER_PODTEMPLATE_FILE, "template-file.yaml") + val kubernetesConf = new KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + Some(JavaMainAppResource("example.jar")), + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Nil, + Seq.empty[String], + Option.empty) + val driverSpec = KubernetesDriverBuilder + .apply(kubernetesClient, sparkConf) + .buildFromFeatures(kubernetesConf) + PodBuilderSuiteUtils.verifyPodWithSupportedFeatures(driverSpec.pod) + } + + test("Throws on misconfigured pod template") { + val kubernetesClient = PodBuilderSuiteUtils.loadingMockKubernetesClient( + new PodBuilder() + .withNewMetadata() + .addToLabels("test-label-key", "test-label-value") + .endMetadata() + .build()) + val sparkConf = new SparkConf(false) + .set(CONTAINER_IMAGE, "spark-driver:latest") + .set(KUBERNETES_DRIVER_PODTEMPLATE_FILE, "template-file.yaml") + val kubernetesConf = new KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + Some(JavaMainAppResource("example.jar")), + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Nil, + Seq.empty[String], + Option.empty) + val exception = intercept[SparkException] { + KubernetesDriverBuilder + .apply(kubernetesClient, sparkConf) + .buildFromFeatures(kubernetesConf) + } + assert(exception.getMessage.contains("Could not load pod from template file.")) + } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/PodBuilderSuiteUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/PodBuilderSuiteUtils.scala new file mode 100644 index 0000000000000..c92e9e6e3b6b3 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/PodBuilderSuiteUtils.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit + +import java.io.File + +import io.fabric8.kubernetes.api.model._ +import io.fabric8.kubernetes.client.KubernetesClient +import io.fabric8.kubernetes.client.dsl.{MixedOperation, PodResource} +import org.mockito.Matchers.any +import org.mockito.Mockito.{mock, when} +import org.scalatest.FlatSpec +import scala.collection.JavaConverters._ + +import org.apache.spark.deploy.k8s.SparkPod + +object PodBuilderSuiteUtils extends FlatSpec { + + def loadingMockKubernetesClient(pod: Pod = podWithSupportedFeatures()): KubernetesClient = { + val kubernetesClient = mock(classOf[KubernetesClient]) + val pods = + mock(classOf[MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]]]) + val podResource = mock(classOf[PodResource[Pod, DoneablePod]]) + when(kubernetesClient.pods()).thenReturn(pods) + when(pods.load(any(classOf[File]))).thenReturn(podResource) + when(podResource.get()).thenReturn(pod) + kubernetesClient + } + + def verifyPodWithSupportedFeatures(pod: SparkPod): Unit = { + val metadata = pod.pod.getMetadata + assert(metadata.getLabels.containsKey("test-label-key")) + assert(metadata.getAnnotations.containsKey("test-annotation-key")) + assert(metadata.getNamespace === "namespace") + assert(metadata.getOwnerReferences.asScala.exists(_.getName == "owner-reference")) + val spec = pod.pod.getSpec + assert(!spec.getContainers.asScala.exists(_.getName == "executor-container")) + assert(spec.getDnsPolicy === "dns-policy") + assert(spec.getHostAliases.asScala.exists(_.getHostnames.asScala.exists(_ == "hostname"))) + assert(spec.getImagePullSecrets.asScala.exists(_.getName == "local-reference")) + assert(spec.getInitContainers.asScala.exists(_.getName == "init-container")) + assert(spec.getNodeName == "node-name") + assert(spec.getNodeSelector.get("node-selector-key") === "node-selector-value") + assert(spec.getSchedulerName === "scheduler") + assert(spec.getSecurityContext.getRunAsUser === 1000L) + assert(spec.getServiceAccount === "service-account") + assert(spec.getSubdomain === "subdomain") + assert(spec.getTolerations.asScala.exists(_.getKey == "toleration-key")) + assert(spec.getVolumes.asScala.exists(_.getName == "test-volume")) + val container = pod.container + assert(container.getName === "executor-container") + assert(container.getArgs.contains("arg")) + assert(container.getCommand.equals(List("command").asJava)) + assert(container.getEnv.asScala.exists(_.getName == "env-key")) + assert(container.getResources.getLimits.get("gpu") === + new QuantityBuilder().withAmount("1").build()) + assert(container.getSecurityContext.getRunAsNonRoot) + assert(container.getStdin) + assert(container.getTerminationMessagePath === "termination-message-path") + assert(container.getTerminationMessagePolicy === "termination-message-policy") + assert(pod.container.getVolumeMounts.asScala.exists(_.getName == "test-volume")) + + } + + + def podWithSupportedFeatures(): Pod = new PodBuilder() + .withNewMetadata() + .addToLabels("test-label-key", "test-label-value") + .addToAnnotations("test-annotation-key", "test-annotation-value") + .withNamespace("namespace") + .addNewOwnerReference() + .withController(true) + .withName("owner-reference") + .endOwnerReference() + .endMetadata() + .withNewSpec() + .withDnsPolicy("dns-policy") + .withHostAliases(new HostAliasBuilder().withHostnames("hostname").build()) + .withImagePullSecrets( + new LocalObjectReferenceBuilder().withName("local-reference").build()) + .withInitContainers(new ContainerBuilder().withName("init-container").build()) + .withNodeName("node-name") + .withNodeSelector(Map("node-selector-key" -> "node-selector-value").asJava) + .withSchedulerName("scheduler") + .withNewSecurityContext() + .withRunAsUser(1000L) + .endSecurityContext() + .withServiceAccount("service-account") + .withSubdomain("subdomain") + .withTolerations(new TolerationBuilder() + .withKey("toleration-key") + .withOperator("Equal") + .withEffect("NoSchedule") + .build()) + .addNewVolume() + .withNewHostPath() + .withPath("/test") + .endHostPath() + .withName("test-volume") + .endVolume() + .addNewContainer() + .withArgs("arg") + .withCommand("command") + .addNewEnv() + .withName("env-key") + .withValue("env-value") + .endEnv() + .withImagePullPolicy("Always") + .withName("executor-container") + .withNewResources() + .withLimits(Map("gpu" -> new QuantityBuilder().withAmount("1").build()).asJava) + .endResources() + .withNewSecurityContext() + .withRunAsNonRoot(true) + .endSecurityContext() + .withStdin(true) + .withTerminationMessagePath("termination-message-path") + .withTerminationMessagePolicy("termination-message-policy") + .addToVolumeMounts( + new VolumeMountBuilder() + .withName("test-volume") + .withMountPath("/test") + .build()) + .endContainer() + .endSpec() + .build() + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala index d8409383b4a1c..3995b2afe7c45 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala @@ -45,9 +45,6 @@ class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfte @Mock private var podOperations: PODS = _ - @Mock - private var executorBuilder: KubernetesExecutorBuilder = _ - @Mock private var schedulerBackend: KubernetesClusterSchedulerBackend = _ @@ -64,7 +61,6 @@ class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfte when(podOperations.withName(any(classOf[String]))).thenAnswer(namedPodsAnswer()) eventHandlerUnderTest = new ExecutorPodsLifecycleManager( new SparkConf(), - executorBuilder, kubernetesClient, snapshotsStore, removedExecutorsCache) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index b572dac2bf624..fb2509fc1bda5 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -16,12 +16,15 @@ */ package org.apache.spark.scheduler.cluster.k8s -import io.fabric8.kubernetes.api.model.PodBuilder +import io.fabric8.kubernetes.api.model.{Config => _, _} +import io.fabric8.kubernetes.client.KubernetesClient +import org.mockito.Mockito.{mock, never, verify} import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.features._ +import org.apache.spark.deploy.k8s.submit.PodBuilderSuiteUtils class KubernetesExecutorBuilderSuite extends SparkFunSuite { private val BASIC_STEP_TYPE = "basic" @@ -193,4 +196,40 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { assert(resolvedPod.pod.getMetadata.getLabels.get(stepType) === stepType) } } + + test("Starts with empty executor pod if template is not specified") { + val kubernetesClient = mock(classOf[KubernetesClient]) + val executorBuilder = KubernetesExecutorBuilder.apply(kubernetesClient, new SparkConf()) + verify(kubernetesClient, never()).pods() + } + + test("Starts with executor template if specified") { + val kubernetesClient = PodBuilderSuiteUtils.loadingMockKubernetesClient() + val sparkConf = new SparkConf(false) + .set("spark.driver.host", "https://driver.host.com") + .set(Config.CONTAINER_IMAGE, "spark-executor:latest") + .set(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE, "template-file.yaml") + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesExecutorSpecificConf( + "executor-id", Some(new PodBuilder() + .withNewMetadata() + .withName("driver") + .endMetadata() + .build())), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Nil, + Seq.empty[String], + Option.empty) + val sparkPod = KubernetesExecutorBuilder + .apply(kubernetesClient, sparkConf) + .buildFromFeatures(kubernetesConf) + PodBuilderSuiteUtils.verifyPodWithSupportedFeatures(sparkPod) + } } diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/driver-template.yml b/resource-managers/kubernetes/integration-tests/src/test/resources/driver-template.yml new file mode 100644 index 0000000000000..0c185be81d59e --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/driver-template.yml @@ -0,0 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +apiVersion: v1 +Kind: Pod +metadata: + labels: + template-label-key: driver-template-label-value +spec: + containers: + - name: test-driver-container + image: will-be-overwritten + diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/executor-template.yml b/resource-managers/kubernetes/integration-tests/src/test/resources/executor-template.yml new file mode 100644 index 0000000000000..0282e23a39bd2 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/executor-template.yml @@ -0,0 +1,25 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +apiVersion: v1 +Kind: Pod +metadata: + labels: + template-label-key: executor-template-label-value +spec: + containers: + - name: test-executor-container + image: will-be-overwritten diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index c99a907f98d0a..e2e5880255e2c 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -38,7 +38,7 @@ import org.apache.spark.internal.Logging private[spark] class KubernetesSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter with BasicTestsSuite with SecretsTestsSuite - with PythonTestsSuite with ClientModeTestsSuite + with PythonTestsSuite with ClientModeTestsSuite with PodTemplateSuite with Logging with Eventually with Matchers { import KubernetesSuite._ @@ -288,21 +288,21 @@ private[spark] class KubernetesSuite extends SparkFunSuite protected def doBasicExecutorPodCheck(executorPod: Pod): Unit = { assert(executorPod.getSpec.getContainers.get(0).getImage === image) - assert(executorPod.getSpec.getContainers.get(0).getName === "executor") + assert(executorPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-executor") assert(executorPod.getSpec.getContainers.get(0).getResources.getRequests.get("memory").getAmount === baseMemory) } protected def doBasicExecutorPyPodCheck(executorPod: Pod): Unit = { assert(executorPod.getSpec.getContainers.get(0).getImage === pyImage) - assert(executorPod.getSpec.getContainers.get(0).getName === "executor") + assert(executorPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-executor") assert(executorPod.getSpec.getContainers.get(0).getResources.getRequests.get("memory").getAmount === standardNonJVMMemory) } protected def doBasicExecutorRPodCheck(executorPod: Pod): Unit = { assert(executorPod.getSpec.getContainers.get(0).getImage === rImage) - assert(executorPod.getSpec.getContainers.get(0).getName === "executor") + assert(executorPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-executor") assert(executorPod.getSpec.getContainers.get(0).getResources.getRequests.get("memory").getAmount === standardNonJVMMemory) } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PodTemplateSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PodTemplateSuite.scala new file mode 100644 index 0000000000000..e5a847e7210cb --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PodTemplateSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.io.File + +import io.fabric8.kubernetes.api.model.Pod + +import org.apache.spark.deploy.k8s.integrationtest.KubernetesSuite.k8sTestTag + +private[spark] trait PodTemplateSuite { k8sSuite: KubernetesSuite => + + import PodTemplateSuite._ + + test("Start pod creation from template", k8sTestTag) { + sparkAppConf + .set("spark.kubernetes.driver.podTemplateFile", DRIVER_TEMPLATE_FILE.getAbsolutePath) + .set("spark.kubernetes.executor.podTemplateFile", EXECUTOR_TEMPLATE_FILE.getAbsolutePath) + runSparkPiAndVerifyCompletion( + driverPodChecker = (driverPod: Pod) => { + assert(driverPod.getMetadata.getName === driverPodName) + assert(driverPod.getSpec.getContainers.get(0).getImage === image) + assert(driverPod.getSpec.getContainers.get(0).getName === "test-driver-container") + assert(driverPod.getMetadata.getLabels.containsKey(LABEL_KEY)) + assert(driverPod.getMetadata.getLabels.get(LABEL_KEY) === "driver-template-label-value") + }, + executorPodChecker = (executorPod: Pod) => { + assert(executorPod.getSpec.getContainers.get(0).getImage === image) + assert(executorPod.getSpec.getContainers.get(0).getName === "test-executor-container") + assert(executorPod.getMetadata.getLabels.containsKey(LABEL_KEY)) + assert(executorPod.getMetadata.getLabels.get(LABEL_KEY) === "executor-template-label-value") + } + ) + } +} + +private[spark] object PodTemplateSuite { + val LABEL_KEY = "template-label-key" + val DRIVER_TEMPLATE_FILE = new File(getClass.getResource("/driver-template.yml").getFile) + val EXECUTOR_TEMPLATE_FILE = new File(getClass.getResource("/executor-template.yml").getFile) +} From 891032da6f5b3c6a690e2ae44396873aa6a6b91d Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 31 Oct 2018 09:18:53 +0800 Subject: [PATCH 1958/2461] [SPARK-25691][SQL] Use semantic equality in AliasViewChild in order to compare attributes ## What changes were proposed in this pull request? When we compare attributes, in general, we should always refer to semantic equality, as the default `equal` method can return false when there are "cosmetic" differences between them, but still they are the same thing; at least we have to consider them so when analyzing/optimizing queries. The PR focuses on the usage and comparison of the `output` of a `LogicalPlan`, which is a `Seq[Attribute]` in `AliasViewChild`. In this case, using equality implicitly fails to check the semantic equality. This results in the operator failing to stabilize. ## How was this patch tested? running the tests with the patch provided by maryannxue in https://github.com/apache/spark/pull/22060 Closes #22713 from mgaido91/SPARK-25691. Authored-by: Marco Gaido Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/analysis/view.scala | 8 +++--- .../sql/catalyst/optimizer/Optimizer.scala | 5 +--- .../catalyst/plans/logical/LogicalPlan.scala | 14 +++++++++++ .../sql/catalyst/analysis/AnalysisSuite.scala | 25 ++++++++++++++++++- 4 files changed, 43 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index af74693000c44..6134d54531a19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -49,7 +49,7 @@ import org.apache.spark.sql.internal.SQLConf */ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { - case v @ View(desc, output, child) if child.resolved && output != child.output => + case v @ View(desc, output, child) if child.resolved && !v.sameOutput(child) => val resolver = conf.resolver val queryColumnNames = desc.viewQueryColumnNames val queryOutput = if (queryColumnNames.nonEmpty) { @@ -70,7 +70,7 @@ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupp } // Map the attributes in the query output to the attributes in the view output by index. val newOutput = output.zip(queryOutput).map { - case (attr, originAttr) if attr != originAttr => + case (attr, originAttr) if !attr.semanticEquals(originAttr) => // The dataType of the output attributes may be not the same with that of the view // output, so we should cast the attribute to the dataType of the view output attribute. // Will throw an AnalysisException if the cast can't perform or might truncate. @@ -112,8 +112,8 @@ object EliminateView extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // The child should have the same output attributes with the View operator, so we simply // remove the View operator. - case View(_, output, child) => - assert(output == child.output, + case v @ View(_, output, child) => + assert(v.sameOutput(child), s"The output of the child ${child.output.mkString("[", ",", "]")} is different from the " + s"view output ${output.mkString("[", ",", "]")}") child diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index da8009d50b5ec..95455ffc0495a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -530,9 +530,6 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper * p2 is usually inserted by this rule and useless, p1 could prune the columns anyway. */ object ColumnPruning extends Rule[LogicalPlan] { - private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = - output1.size == output2.size && - output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) def apply(plan: LogicalPlan): LogicalPlan = removeProjectBeforeFilter(plan transform { // Prunes the unused columns from project list of Project/Aggregate/Expand @@ -607,7 +604,7 @@ object ColumnPruning extends Rule[LogicalPlan] { case w: Window if w.windowExpressions.isEmpty => w.child // Eliminate no-op Projects - case p @ Project(_, child) if sameOutput(child.output, p.output) => child + case p @ Project(_, child) if child.sameOutput(p) => child // Can't prune the columns on LeafNode case p @ Project(_, _: LeafNode) => p diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 339fbb8d8b57a..a520eba001af1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -130,6 +130,20 @@ abstract class LogicalPlan * Returns the output ordering that this plan generates. */ def outputOrdering: Seq[SortOrder] = Nil + + /** + * Returns true iff `other`'s output is semantically the same, ie.: + * - it contains the same number of `Attribute`s; + * - references are the same; + * - the order is equal too. + */ + def sameOutput(other: LogicalPlan): Boolean = { + val thisOutput = this.output + val otherOutput = other.output + thisOutput.length == otherOutput.length && thisOutput.zip(otherOutput).forall { + case (a1, a2) => a1.semanticEquals(a2) + } + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index d8cb6f7caa99e..da3ae72c3682a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import java.util.TimeZone +import java.util.{Locale, TimeZone} import scala.reflect.ClassTag @@ -25,6 +25,7 @@ import org.scalatest.Matchers import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -32,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, Inner} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning} +import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -604,4 +606,25 @@ class AnalysisSuite extends AnalysisTest with Matchers { checkAnalysis(input, expected) } } + + test("SPARK-25691: AliasViewChild with different nullabilities") { + object ViewAnalyzer extends RuleExecutor[LogicalPlan] { + val batches = Batch("View", Once, AliasViewChild(conf), EliminateView) :: Nil + } + val relation = LocalRelation('a.int.notNull, 'b.string) + val view = View(CatalogTable( + identifier = TableIdentifier("v1"), + tableType = CatalogTableType.VIEW, + storage = CatalogStorageFormat.empty, + schema = StructType(Seq(StructField("a", IntegerType), StructField("b", StringType)))), + output = Seq('a.int, 'b.string), + child = relation) + val tz = Option(conf.sessionLocalTimeZone) + val expected = Project(Seq( + Alias(Cast('a.int.notNull, IntegerType, tz), "a")(), + Alias(Cast('b.string, StringType, tz), "b")()), + relation) + val res = ViewAnalyzer.execute(view) + comparePlans(res, expected) + } } From f6ff6329eee720e19a56b90c0ffda9da5cecca5b Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Wed, 31 Oct 2018 10:28:17 +0800 Subject: [PATCH 1959/2461] [SPARK-25847][SQL][TEST] Refactor JSONBenchmarks to use main method ## What changes were proposed in this pull request? Refactor JSONBenchmark to use main method use spark-submit: `bin/spark-submit --class org.apache.spark.sql.execution.datasources.json.JSONBenchmark --jars ./core/target/spark-core_2.11-3.0.0-SNAPSHOT-tests.jar,./sql/catalyst/target/spark-catalyst_2.11-3.0.0-SNAPSHOT-tests.jar ./sql/core/target/spark-sql_2.11-3.0.0-SNAPSHOT-tests.jar` Generate benchmark result: `SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain org.apache.spark.sql.execution.datasources.json.JSONBenchmark"` ## How was this patch tested? manual tests Closes #22844 from heary-cao/JSONBenchmarks. Lead-authored-by: caoxuewen Co-authored-by: heary Co-authored-by: Dongjoon Hyun Signed-off-by: hyukjinkwon --- sql/core/benchmarks/JSONBenchmark-results.txt | 37 ++++++++ ...onBenchmarks.scala => JsonBenchmark.scala} | 86 ++++++------------- 2 files changed, 63 insertions(+), 60 deletions(-) create mode 100644 sql/core/benchmarks/JSONBenchmark-results.txt rename sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/{JsonBenchmarks.scala => JsonBenchmark.scala} (61%) diff --git a/sql/core/benchmarks/JSONBenchmark-results.txt b/sql/core/benchmarks/JSONBenchmark-results.txt new file mode 100644 index 0000000000000..99937309a4145 --- /dev/null +++ b/sql/core/benchmarks/JSONBenchmark-results.txt @@ -0,0 +1,37 @@ +================================================================================================ +Benchmark for performance of JSON parsing +================================================================================================ + +Preparing data for benchmarking ... +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +JSON schema inferring: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +No encoding 62946 / 63310 1.6 629.5 1.0X +UTF-8 is set 112814 / 112866 0.9 1128.1 0.6X + +Preparing data for benchmarking ... +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +JSON per-line parsing: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +No encoding 16468 / 16553 6.1 164.7 1.0X +UTF-8 is set 16420 / 16441 6.1 164.2 1.0X + +Preparing data for benchmarking ... +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +JSON parsing of wide lines: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +No encoding 39789 / 40053 0.3 3978.9 1.0X +UTF-8 is set 39505 / 39584 0.3 3950.5 1.0X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Count a dataset with 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Select 10 columns + count() 15997 / 16015 0.6 1599.7 1.0X +Select 1 column + count() 13280 / 13326 0.8 1328.0 1.2X +count() 3006 / 3021 3.3 300.6 5.3X + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala similarity index 61% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala index 3c4a5ab32724b..04f724ec8638f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala @@ -16,32 +16,31 @@ */ package org.apache.spark.sql.execution.datasources.json -import java.io.File - -import org.apache.spark.SparkConf import org.apache.spark.benchmark.Benchmark -import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark import org.apache.spark.sql.functions.lit import org.apache.spark.sql.types._ /** * The benchmarks aims to measure performance of JSON parsing when encoding is set and isn't. - * To run this: - * spark-submit --class --jars + * To run this benchmark: + * {{{ + * 1. without sbt: + * bin/spark-submit --class --jars , + * + * 2. build/sbt "sql/test:runMain " + * 3. generate result: + * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/JSONBenchmark-results.txt". + * }}} */ -object JSONBenchmarks extends SQLHelper { - val conf = new SparkConf() - - val spark = SparkSession.builder - .master("local[1]") - .appName("benchmark-json-datasource") - .config(conf) - .getOrCreate() + +object JSONBenchmark extends SqlBasedBenchmark { import spark.implicits._ def schemaInferring(rowsNum: Int): Unit = { - val benchmark = new Benchmark("JSON schema inferring", rowsNum) + val benchmark = new Benchmark("JSON schema inferring", rowsNum, output = output) withTempPath { path => // scalastyle:off println @@ -65,21 +64,12 @@ object JSONBenchmarks extends SQLHelper { .json(path.getAbsolutePath) } - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_172-b11 on Mac OS X 10.13.5 - Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz - - JSON schema inferring: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - --------------------------------------------------------------------------------------------- - No encoding 45908 / 46480 2.2 459.1 1.0X - UTF-8 is set 68469 / 69762 1.5 684.7 0.7X - */ benchmark.run() } } def perlineParsing(rowsNum: Int): Unit = { - val benchmark = new Benchmark("JSON per-line parsing", rowsNum) + val benchmark = new Benchmark("JSON per-line parsing", rowsNum, output = output) withTempPath { path => // scalastyle:off println @@ -107,21 +97,12 @@ object JSONBenchmarks extends SQLHelper { .count() } - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_172-b11 on Mac OS X 10.13.5 - Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz - - JSON per-line parsing: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - --------------------------------------------------------------------------------------------- - No encoding 9982 / 10237 10.0 99.8 1.0X - UTF-8 is set 16373 / 16806 6.1 163.7 0.6X - */ benchmark.run() } } def perlineParsingOfWideColumn(rowsNum: Int): Unit = { - val benchmark = new Benchmark("JSON parsing of wide lines", rowsNum) + val benchmark = new Benchmark("JSON parsing of wide lines", rowsNum, output = output) withTempPath { path => // scalastyle:off println @@ -156,22 +137,14 @@ object JSONBenchmarks extends SQLHelper { .count() } - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_172-b11 on Mac OS X 10.13.5 - Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz - - JSON parsing of wide lines: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - --------------------------------------------------------------------------------------------- - No encoding 26038 / 26386 0.4 2603.8 1.0X - UTF-8 is set 28343 / 28557 0.4 2834.3 0.9X - */ benchmark.run() } } def countBenchmark(rowsNum: Int): Unit = { val colsNum = 10 - val benchmark = new Benchmark(s"Count a dataset with $colsNum columns", rowsNum) + val benchmark = + new Benchmark(s"Count a dataset with $colsNum columns", rowsNum, output = output) withTempPath { path => val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) @@ -195,23 +168,16 @@ object JSONBenchmarks extends SQLHelper { ds.count() } - /* - Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz - - Count a dataset with 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - --------------------------------------------------------------------------------------------- - Select 10 columns + count() 9961 / 10006 1.0 996.1 1.0X - Select 1 column + count() 8355 / 8470 1.2 835.5 1.2X - count() 2104 / 2156 4.8 210.4 4.7X - */ benchmark.run() } } - def main(args: Array[String]): Unit = { - schemaInferring(100 * 1000 * 1000) - perlineParsing(100 * 1000 * 1000) - perlineParsingOfWideColumn(10 * 1000 * 1000) - countBenchmark(10 * 1000 * 1000) + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("Benchmark for performance of JSON parsing") { + schemaInferring(100 * 1000 * 1000) + perlineParsing(100 * 1000 * 1000) + perlineParsingOfWideColumn(10 * 1000 * 1000) + countBenchmark(10 * 1000 * 1000) + } } } From 243ce319a06f20365d5b08d479642d75748645d9 Mon Sep 17 00:00:00 2001 From: shane knapp Date: Wed, 31 Oct 2018 10:32:26 +0800 Subject: [PATCH 1960/2461] [SPARKR] found some extra whitespace in the R tests ## What changes were proposed in this pull request? during my ubuntu-port testing, i found some extra whitespace that for some reason wasn't getting caught on the centos lint-r build step. ## How was this patch tested? the build system will test this! i used one of my ubuntu testing builds and scped over the modified file. before my fix: https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.7-ubuntu-testing/22/console after my fix: https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.7-ubuntu-testing/23/console Closes #22896 from shaneknapp/remove-extra-whitespace. Authored-by: shane knapp Signed-off-by: hyukjinkwon --- R/pkg/tests/fulltests/test_sparkSQL_eager.R | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL_eager.R b/R/pkg/tests/fulltests/test_sparkSQL_eager.R index df7354fa063e9..9b4489a47b655 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL_eager.R +++ b/R/pkg/tests/fulltests/test_sparkSQL_eager.R @@ -22,12 +22,12 @@ context("test show SparkDataFrame when eager execution is enabled.") test_that("eager execution is not enabled", { # Start Spark session without eager execution enabled sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) - + df <- createDataFrame(faithful) expect_is(df, "SparkDataFrame") expected <- "eruptions:double, waiting:double" expect_output(show(df), expected) - + # Stop Spark session sparkR.session.stop() }) @@ -35,9 +35,9 @@ test_that("eager execution is not enabled", { test_that("eager execution is enabled", { # Start Spark session with eager execution enabled sparkConfig <- list(spark.sql.repl.eagerEval.enabled = "true") - + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, sparkConfig = sparkConfig) - + df <- createDataFrame(faithful) expect_is(df, "SparkDataFrame") expected <- paste0("(+---------+-------+\n", @@ -45,7 +45,7 @@ test_that("eager execution is enabled", { "+---------+-------+\n)*", "(only showing top 20 rows)") expect_output(show(df), expected) - + # Stop Spark session sparkR.session.stop() }) @@ -55,9 +55,9 @@ test_that("eager execution is enabled with maxNumRows and truncate set", { sparkConfig <- list(spark.sql.repl.eagerEval.enabled = "true", spark.sql.repl.eagerEval.maxNumRows = as.integer(5), spark.sql.repl.eagerEval.truncate = as.integer(2)) - + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, sparkConfig = sparkConfig) - + df <- arrange(createDataFrame(faithful), "waiting") expect_is(df, "SparkDataFrame") expected <- paste0("(+---------+-------+\n", @@ -66,7 +66,7 @@ test_that("eager execution is enabled with maxNumRows and truncate set", { "| 1.| 43|\n)*", "(only showing top 5 rows)") expect_output(show(df), expected) - + # Stop Spark session sparkR.session.stop() }) From 9cf9a83afafb88668c95ca704a1f65a91b5e591c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 30 Oct 2018 21:27:17 -0700 Subject: [PATCH 1961/2461] [SPARK-25862][SQL] Remove rangeBetween APIs introduced in SPARK-21608 ## What changes were proposed in this pull request? This patch removes the rangeBetween functions introduced in SPARK-21608. As explained in SPARK-25841, these functions are confusing and don't quite work. We will redesign them and introduce better ones in SPARK-25843. ## How was this patch tested? Removed relevant test cases as well. These test cases will need to be added back in SPARK-25843. Closes #22870 from rxin/SPARK-25862. Lead-authored-by: Reynold Xin Co-authored-by: hyukjinkwon Signed-off-by: gatorsmile --- .../expressions/windowExpressions.scala | 2 +- .../apache/spark/sql/expressions/Window.scala | 9 --- .../spark/sql/expressions/WindowSpec.scala | 12 ---- .../org/apache/spark/sql/functions.scala | 26 ------- .../sql-tests/results/window.sql.out | 2 +- .../sql/DataFrameWindowFramesSuite.scala | 68 +------------------ 6 files changed, 3 insertions(+), 116 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 7de6dddda4d3d..0b674d025d1ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -206,7 +206,7 @@ case class SpecifiedWindowFrame( // Check combination (of expressions). (lower, upper) match { case (l: Expression, u: Expression) if !isValidFrameBoundary(l, u) => - TypeCheckFailure(s"Window frame upper bound '$upper' does not followes the lower bound " + + TypeCheckFailure(s"Window frame upper bound '$upper' does not follow the lower bound " + s"'$lower'.") case (l: SpecialFrameBoundary, _) => TypeCheckSuccess case (_, u: SpecialFrameBoundary) => TypeCheckSuccess diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index 14dec8f0810f2..d50031bb20621 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -214,15 +214,6 @@ object Window { spec.rangeBetween(start, end) } - /** - * This function has been deprecated in Spark 2.4. See SPARK-25842 for more information. - * @since 2.3.0 - */ - @deprecated("Use the version with Long parameter types", "2.4.0") - def rangeBetween(start: Column, end: Column): WindowSpec = { - spec.rangeBetween(start, end) - } - private[sql] def spec: WindowSpec = { new WindowSpec(Seq.empty, Seq.empty, UnspecifiedFrame) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 0cc43a58237df..b7f3000880aca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -209,18 +209,6 @@ class WindowSpec private[sql]( SpecifiedWindowFrame(RangeFrame, boundaryStart, boundaryEnd)) } - /** - * This function has been deprecated in Spark 2.4. See SPARK-25842 for more information. - * @since 2.3.0 - */ - @deprecated("Use the version with Long parameter types", "2.4.0") - def rangeBetween(start: Column, end: Column): WindowSpec = { - new WindowSpec( - partitionSpec, - orderSpec, - SpecifiedWindowFrame(RangeFrame, start.expr, end.expr)) - } - /** * Converts this [[WindowSpec]] into a [[Column]] with an aggregate expression. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 757a3226855c5..5348b65d43b38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -829,32 +829,6 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// - /** - * This function has been deprecated in Spark 2.4. See SPARK-25842 for more information. - * - * @group window_funcs - * @since 2.3.0 - */ - @deprecated("Use Window.unboundedPreceding", "2.4.0") - def unboundedPreceding(): Column = Column(UnboundedPreceding) - - /** - * This function has been deprecated in Spark 2.4. See SPARK-25842 for more information. - * - * @group window_funcs - * @since 2.3.0 - */ - @deprecated("Use Window.unboundedFollowing", "2.4.0") - def unboundedFollowing(): Column = Column(UnboundedFollowing) - - /** - * This function has been deprecated in Spark 2.4. See SPARK-25842 for more information. - * - * @group window_funcs - * @since 2.3.0 - */ - @deprecated("Use Window.currentRow", "2.4.0") - def currentRow(): Column = Column(CurrentRow) /** * Window function: returns the cumulative distribution of values within a window partition, diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out index 4afbcd62853dc..5071e0bd26b2a 100644 --- a/sql/core/src/test/resources/sql-tests/results/window.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -191,7 +191,7 @@ ROWS BETWEEN UNBOUNDED FOLLOWING AND 1 FOLLOWING) FROM testData ORDER BY cate, v struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -cannot resolve 'ROWS BETWEEN UNBOUNDED FOLLOWING AND 1 FOLLOWING' due to data type mismatch: Window frame upper bound '1' does not followes the lower bound 'unboundedfollowing$()'.; line 1 pos 33 +cannot resolve 'ROWS BETWEEN UNBOUNDED FOLLOWING AND 1 FOLLOWING' due to data type mismatch: Window frame upper bound '1' does not follow the lower bound 'unboundedfollowing$()'.; line 1 pos 33 -- !query 12 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala index 9c280744682b8..002c17f4cce4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql -import java.sql.{Date, Timestamp} +import java.sql.Date import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.unsafe.types.CalendarInterval /** * Window frame testing for DataFrame API. @@ -219,71 +218,6 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSQLContext { Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))), Seq(Row(1, 2), Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1)) ) - - def dt(date: String): Date = Date.valueOf(date) - - val df2 = Seq((dt("2017-08-01"), "1"), (dt("2017-08-01"), "1"), (dt("2020-12-31"), "1"), - (dt("2017-08-03"), "2"), (dt("2017-08-02"), "1"), (dt("2020-12-31"), "2")) - .toDF("key", "value") - val window = Window.partitionBy($"value").orderBy($"key").rangeBetween(lit(0), lit(2)) - - checkAnswer( - df2.select( - $"key", - count("key").over(window)), - Seq(Row(dt("2017-08-01"), 3), Row(dt("2017-08-01"), 3), Row(dt("2020-12-31"), 1), - Row(dt("2017-08-03"), 1), Row(dt("2017-08-02"), 1), Row(dt("2020-12-31"), 1)) - ) - } - - test("range between should accept double values as boundary") { - val df = Seq((1.0D, "1"), (1.0D, "1"), (100.001D, "1"), (3.3D, "2"), (2.02D, "1"), - (100.001D, "2")).toDF("key", "value") - val window = Window.partitionBy($"value").orderBy($"key").rangeBetween(currentRow, lit(2.5D)) - - checkAnswer( - df.select( - $"key", - count("key").over(window)), - Seq(Row(1.0, 3), Row(1.0, 3), Row(100.001, 1), Row(3.3, 1), Row(2.02, 1), Row(100.001, 1)) - ) - } - - test("range between should accept interval values as boundary") { - def ts(timestamp: Long): Timestamp = new Timestamp(timestamp * 1000) - - val df = Seq((ts(1501545600), "1"), (ts(1501545600), "1"), (ts(1609372800), "1"), - (ts(1503000000), "2"), (ts(1502000000), "1"), (ts(1609372800), "2")) - .toDF("key", "value") - val window = Window.partitionBy($"value").orderBy($"key") - .rangeBetween(currentRow, lit(CalendarInterval.fromString("interval 23 days 4 hours"))) - - checkAnswer( - df.select( - $"key", - count("key").over(window)), - Seq(Row(ts(1501545600), 3), Row(ts(1501545600), 3), Row(ts(1609372800), 1), - Row(ts(1503000000), 1), Row(ts(1502000000), 1), Row(ts(1609372800), 1)) - ) - } - - test("range between should accept interval values as both boundaries") { - def ts(timestamp: Long): Timestamp = new Timestamp(timestamp * 1000) - - val df = Seq((ts(1501545600), "1"), (ts(1501545600), "1"), (ts(1609372800), "1"), - (ts(1503000000), "2"), (ts(1502000000), "1"), (ts(1609372800), "2")) - .toDF("key", "value") - val window = Window.partitionBy($"value").orderBy($"key") - .rangeBetween(lit(CalendarInterval.fromString("interval 3 hours")), - lit(CalendarInterval.fromString("interval 23 days 4 hours"))) - - checkAnswer( - df.select( - $"key", - count("key").over(window)), - Seq(Row(ts(1501545600), 1), Row(ts(1501545600), 1), Row(ts(1609372800), 0), - Row(ts(1503000000), 0), Row(ts(1502000000), 0), Row(ts(1609372800), 0)) - ) } test("unbounded rows/range between with aggregation") { From 49bea5a7e87ec3ce9cd9466725d81096a54a591b Mon Sep 17 00:00:00 2001 From: seancxmao Date: Tue, 30 Oct 2018 23:05:31 -0700 Subject: [PATCH 1962/2461] [SPARK-25833][SQL][DOCS] Update migration guide for Hive view compatibility ## What changes were proposed in this pull request? Both Spark and Hive support views. However in some cases views created by Hive are not readable by Spark. For example, if column aliases are not specified in view definition queries, both Spark and Hive will generate alias names, but in different ways. In order for Spark to be able to read views created by Hive, users should explicitly specify column aliases in view definition queries. Given that it's not uncommon that Hive and Spark are used together in enterprise data warehouse, this PR aims to explicitly describe this compatibility issue to help users troubleshoot this issue easily. ## How was this patch tested? Docs are manually generated and checked locally. ``` SKIP_API=1 jekyll serve ``` Closes #22868 from seancxmao/SPARK-25833. Authored-by: seancxmao Signed-off-by: Dongjoon Hyun --- docs/sql-migration-guide-hive-compatibility.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/docs/sql-migration-guide-hive-compatibility.md b/docs/sql-migration-guide-hive-compatibility.md index 0234ea28bb333..94849418030ef 100644 --- a/docs/sql-migration-guide-hive-compatibility.md +++ b/docs/sql-migration-guide-hive-compatibility.md @@ -51,6 +51,21 @@ Spark SQL supports the vast majority of Hive features, such as: * Explain * Partitioned tables including dynamic partition insertion * View + * If column aliases are not specified in view definition queries, both Spark and Hive will + generate alias names, but in different ways. In order for Spark to be able to read views created + by Hive, users should explicitly specify column aliases in view definition queries. As an + example, Spark cannot read `v1` created as below by Hive. + + ``` + CREATE VIEW v1 AS SELECT * FROM (SELECT c + 1 FROM (SELECT 1 c) t1) t2; + ``` + + Instead, you should create `v1` as below with column aliases explicitly specified. + + ``` + CREATE VIEW v1 AS SELECT * FROM (SELECT c + 1 AS inc_c FROM (SELECT 1 c) t1) t2; + ``` + * All Hive DDL Functions, including: * `CREATE TABLE` * `CREATE TABLE AS SELECT` From 0ad93b0931683a58d4372a934656b5c2dbe9300a Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 30 Oct 2018 23:59:37 -0700 Subject: [PATCH 1963/2461] [SPARK-25883][SQL][MINOR] Override method `prettyName` in `from_avro`/`to_avro` ## What changes were proposed in this pull request? Previously in from_avro/to_avro, we override the method `simpleString` and `sql` for the string output. However, the override only affects the alias naming: ``` Project [from_avro('col, ... , (mode,PERMISSIVE)) AS from_avro(col, struct, Map(mode -> PERMISSIVE))#11] ``` It only makes the alias name quite long: `from_avro(col, struct, Map(mode -> PERMISSIVE))`). We should follow `from_csv`/`from_json` here, to override the method prettyName only, and we will get a clean alias name ``` ... AS from_avro(col)#11 ``` ## How was this patch tested? Manual check Closes #22890 from gengliangwang/revise_from_to_avro. Authored-by: Gengliang Wang Signed-off-by: gatorsmile --- .../org/apache/spark/sql/avro/AvroDataToCatalyst.scala | 8 +------- .../org/apache/spark/sql/avro/CatalystDataToAvro.scala | 8 +------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala index 43d3f6efb2a0c..ae61587fe1bb7 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala @@ -109,13 +109,7 @@ case class AvroDataToCatalyst( } } - override def simpleString: String = { - s"from_avro(${child.sql}, ${dataType.simpleString}, ${options.toString()})" - } - - override def sql: String = { - s"from_avro(${child.sql}, ${dataType.catalogString}, ${options.toString()})" - } + override def prettyName: String = "from_avro" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val expr = ctx.addReferenceObj("this", this) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala index 141ff3782adfb..6ed330d92f5e6 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala @@ -52,13 +52,7 @@ case class CatalystDataToAvro(child: Expression) extends UnaryExpression { out.toByteArray } - override def simpleString: String = { - s"to_avro(${child.sql}, ${child.dataType.simpleString})" - } - - override def sql: String = { - s"to_avro(${child.sql}, ${child.dataType.catalogString})" - } + override def prettyName: String = "to_avro" override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val expr = ctx.addReferenceObj("this", this) From 34c3bc9f1e2750bbcb91a8706ab78c6a58113350 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 31 Oct 2018 02:57:39 -0700 Subject: [PATCH 1964/2461] [SPARK-25618][SQL][TEST] Reduce time taken to execute KafkaContinuousSourceStressForDontFailOnDataLossSuite ## What changes were proposed in this pull request? In this test, i have reduced the test time to 20 secs from 1 minute while reducing the sleep time from 1 sec to 100 milliseconds. With this change, i was able to run the test in 20+ seconds consistently on my laptop. I would like see if it passes in jenkins consistently. ## How was this patch tested? Its a test fix. Closes #22900 from dilipbiswal/SPARK-25618. Authored-by: Dilip Biswal Signed-off-by: Dongjoon Hyun --- .../spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala index 39c4e3fda1a4b..491a9c669bdbe 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDontFailOnDataLossSuite.scala @@ -221,7 +221,7 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with KafkaM .as[(String, String)] val query = startStream(kafka.map(kv => kv._2.toInt)) - val testTime = 1.minutes + val testTime = 20.seconds val startTime = System.currentTimeMillis() // Track the current existing topics val topics = mutable.ArrayBuffer[String]() @@ -252,7 +252,7 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with KafkaM testUtils.createTopic(topic, partitions = 1, overwrite = true) logInfo(s"Create topic $topic") case 3 => - Thread.sleep(1000) + Thread.sleep(100) case _ => // Push random messages for (topic <- topics) { val size = Random.nextInt(10) From f8484e49ef83445dd57f8f5ba4b39d2f47bd3c80 Mon Sep 17 00:00:00 2001 From: yucai Date: Wed, 31 Oct 2018 03:03:42 -0700 Subject: [PATCH 1965/2461] [SPARK-25663][SPARK-25661][SQL][TEST] Refactor BuiltInDataSourceWriteBenchmark, DataSourceWriteBenchmark and AvroWriteBenchmark to use main method ## What changes were proposed in this pull request? Refactor BuiltInDataSourceWriteBenchmark, DataSourceWriteBenchmark and AvroWriteBenchmark to use main method. ``` SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain org.apache.spark.sql.execution.benchmark.BuiltInDataSourceWriteBenchmark" SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "avro/test:runMain org.apache.spark.sql.execution.benchmark.AvroWriteBenchmark" ``` ## How was this patch tested? manual tests Closes #22861 from yucai/BuiltInDataSourceWriteBenchmark. Lead-authored-by: yucai Co-authored-by: Yucai Yu Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../benchmarks/AvroWriteBenchmark-results.txt | 10 +++ .../benchmark/AvroWriteBenchmark.scala | 27 ++++---- ...uiltInDataSourceWriteBenchmark-results.txt | 60 ++++++++++++++++ .../BuiltInDataSourceWriteBenchmark.scala | 68 +++++++------------ .../benchmark/DataSourceWriteBenchmark.scala | 15 +--- 5 files changed, 108 insertions(+), 72 deletions(-) create mode 100644 external/avro/benchmarks/AvroWriteBenchmark-results.txt create mode 100644 sql/core/benchmarks/BuiltInDataSourceWriteBenchmark-results.txt diff --git a/external/avro/benchmarks/AvroWriteBenchmark-results.txt b/external/avro/benchmarks/AvroWriteBenchmark-results.txt new file mode 100644 index 0000000000000..fb2a77333eec5 --- /dev/null +++ b/external/avro/benchmarks/AvroWriteBenchmark-results.txt @@ -0,0 +1,10 @@ +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Avro writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Output Single Int Column 3213 / 3373 4.9 204.3 1.0X +Output Single Double Column 3313 / 3345 4.7 210.7 1.0X +Output Int and String Column 7303 / 7316 2.2 464.3 0.4X +Output Partitions 5309 / 5691 3.0 337.5 0.6X +Output Buckets 7031 / 7557 2.2 447.0 0.5X + diff --git a/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroWriteBenchmark.scala b/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroWriteBenchmark.scala index df13b4a1c2d3a..0b11434757c93 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroWriteBenchmark.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroWriteBenchmark.scala @@ -19,22 +19,19 @@ package org.apache.spark.sql.execution.benchmark /** * Benchmark to measure Avro data sources write performance. - * Usage: - * 1. with spark-submit: bin/spark-submit --class - * 2. with sbt: build/sbt "avro/test:runMain " + * To run this benchmark: + * {{{ + * 1. without sbt: bin/spark-submit --class + * --jars ,, + * , + * + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "avro/test:runMain " + * Results will be written to "benchmarks/AvroWriteBenchmark-results.txt". + * }}} */ object AvroWriteBenchmark extends DataSourceWriteBenchmark { - def main(args: Array[String]): Unit = { - /* - Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz - Avro writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Output Single Int Column 2481 / 2499 6.3 157.8 1.0X - Output Single Double Column 2705 / 2710 5.8 172.0 0.9X - Output Int and String Column 5539 / 5639 2.8 352.2 0.4X - Output Partitions 4613 / 5004 3.4 293.3 0.5X - Output Buckets 5554 / 5561 2.8 353.1 0.4X - */ - runBenchmark("Avro") + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runDataSourceBenchmark("Avro") } } diff --git a/sql/core/benchmarks/BuiltInDataSourceWriteBenchmark-results.txt b/sql/core/benchmarks/BuiltInDataSourceWriteBenchmark-results.txt new file mode 100644 index 0000000000000..9d656fc10dce4 --- /dev/null +++ b/sql/core/benchmarks/BuiltInDataSourceWriteBenchmark-results.txt @@ -0,0 +1,60 @@ +================================================================================================ +Parquet writer benchmark +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Parquet writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Output Single Int Column 2354 / 2438 6.7 149.7 1.0X +Output Single Double Column 2462 / 2485 6.4 156.5 1.0X +Output Int and String Column 8083 / 8100 1.9 513.9 0.3X +Output Partitions 5015 / 5027 3.1 318.8 0.5X +Output Buckets 6883 / 6887 2.3 437.6 0.3X + + +================================================================================================ +ORC writer benchmark +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +ORC writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Output Single Int Column 1769 / 1789 8.9 112.4 1.0X +Output Single Double Column 1989 / 2009 7.9 126.5 0.9X +Output Int and String Column 7323 / 7400 2.1 465.6 0.2X +Output Partitions 4374 / 4381 3.6 278.1 0.4X +Output Buckets 6086 / 6104 2.6 386.9 0.3X + + +================================================================================================ +JSON writer benchmark +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +JSON writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Output Single Int Column 2954 / 4085 5.3 187.8 1.0X +Output Single Double Column 3832 / 3837 4.1 243.6 0.8X +Output Int and String Column 9591 / 10336 1.6 609.8 0.3X +Output Partitions 4956 / 4994 3.2 315.1 0.6X +Output Buckets 6608 / 6676 2.4 420.1 0.4X + + +================================================================================================ +CSV writer benchmark +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +CSV writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Output Single Int Column 4118 / 4125 3.8 261.8 1.0X +Output Single Double Column 4888 / 4891 3.2 310.8 0.8X +Output Int and String Column 9788 / 9872 1.6 622.3 0.4X +Output Partitions 6578 / 6640 2.4 418.2 0.6X +Output Buckets 9125 / 9171 1.7 580.2 0.5X + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala index 2de516c19da9e..cd97324c997f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala @@ -18,62 +18,40 @@ package org.apache.spark.sql.execution.benchmark /** * Benchmark to measure built-in data sources write performance. - * By default it measures 4 data source format: Parquet, ORC, JSON, CSV. Run it with spark-submit: - * spark-submit --class - * Or with sbt: - * build/sbt "sql/test:runMain " + * To run this benchmark: + * {{{ + * By default it measures 4 data source format: Parquet, ORC, JSON, CSV. + * 1. without sbt: bin/spark-submit --class + * --jars , + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/BuiltInDataSourceWriteBenchmark-results.txt". + * + * To measure specified formats, run it with arguments. + * 1. without sbt: + * bin/spark-submit --class format1 [format2] [...] + * 2. build/sbt "sql/test:runMain format1 [format2] [...]" + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt + * "sql/test:runMain format1 [format2] [...]" + * Results will be written to "benchmarks/BuiltInDataSourceWriteBenchmark-results.txt". + * }}} * - * To measure specified formats, run it with arguments: - * spark-submit --class format1 [format2] [...] - * Or with sbt: - * build/sbt "sql/test:runMain format1 [format2] [...]" */ object BuiltInDataSourceWriteBenchmark extends DataSourceWriteBenchmark { - def main(args: Array[String]): Unit = { - val formats: Seq[String] = if (args.isEmpty) { + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + val formats: Seq[String] = if (mainArgs.isEmpty) { Seq("Parquet", "ORC", "JSON", "CSV") } else { - args + mainArgs } spark.conf.set("spark.sql.parquet.compression.codec", "snappy") spark.conf.set("spark.sql.orc.compression.codec", "snappy") - /* - Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz - Parquet writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Output Single Int Column 1815 / 1932 8.7 115.4 1.0X - Output Single Double Column 1877 / 1878 8.4 119.3 1.0X - Output Int and String Column 6265 / 6543 2.5 398.3 0.3X - Output Partitions 4067 / 4457 3.9 258.6 0.4X - Output Buckets 5608 / 5820 2.8 356.6 0.3X - - ORC writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Output Single Int Column 1201 / 1239 13.1 76.3 1.0X - Output Single Double Column 1542 / 1600 10.2 98.0 0.8X - Output Int and String Column 6495 / 6580 2.4 412.9 0.2X - Output Partitions 3648 / 3842 4.3 231.9 0.3X - Output Buckets 5022 / 5145 3.1 319.3 0.2X - - JSON writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Output Single Int Column 1988 / 2093 7.9 126.4 1.0X - Output Single Double Column 2854 / 2911 5.5 181.4 0.7X - Output Int and String Column 6467 / 6653 2.4 411.1 0.3X - Output Partitions 4548 / 5055 3.5 289.1 0.4X - Output Buckets 5664 / 5765 2.8 360.1 0.4X - CSV writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Output Single Int Column 3025 / 3190 5.2 192.3 1.0X - Output Single Double Column 3575 / 3634 4.4 227.3 0.8X - Output Int and String Column 7313 / 7399 2.2 464.9 0.4X - Output Partitions 5105 / 5190 3.1 324.6 0.6X - Output Buckets 6986 / 6992 2.3 444.1 0.4X - */ formats.foreach { format => - runBenchmark(format) + runBenchmark(s"$format writer benchmark") { + runDataSourceBenchmark(format) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala index 994d6b5b7d334..405d60794ede0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala @@ -16,18 +16,9 @@ */ package org.apache.spark.sql.execution.benchmark -import org.apache.spark.SparkConf import org.apache.spark.benchmark.Benchmark -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.internal.SQLConf -trait DataSourceWriteBenchmark { - val conf = new SparkConf() - .setAppName("DataSourceWriteBenchmark") - .setIfMissing("spark.master", "local[1]") - .set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") - - val spark = SparkSession.builder.config(conf).getOrCreate() +trait DataSourceWriteBenchmark extends SqlBasedBenchmark { val tempTable = "temp" val numRows = 1024 * 1024 * 15 @@ -75,7 +66,7 @@ trait DataSourceWriteBenchmark { } } - def runBenchmark(format: String): Unit = { + def runDataSourceBenchmark(format: String): Unit = { val tableInt = "tableInt" val tableDouble = "tableDouble" val tableIntString = "tableIntString" @@ -84,7 +75,7 @@ trait DataSourceWriteBenchmark { withTempTable(tempTable) { spark.range(numRows).createOrReplaceTempView(tempTable) withTable(tableInt, tableDouble, tableIntString, tablePartition, tableBucket) { - val benchmark = new Benchmark(s"$format writer benchmark", numRows) + val benchmark = new Benchmark(s"$format writer benchmark", numRows, output = output) writeNumeric(tableInt, format, benchmark, "Int") writeNumeric(tableDouble, format, benchmark, "Double") writeIntString(tableIntString, format, benchmark) From 3c0e9ce944d98859939bbcbf21c610f4b9b224dd Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Wed, 31 Oct 2018 18:39:15 +0800 Subject: [PATCH 1966/2461] [SPARK-24901][SQL] Merge the codegen of RegularHashMap and fastHashMap to reduce compiler maxCodesize when VectorizedHashMap is false. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Currently, Generate code of update UnsafeRow in hash aggregation. FastHashMap and RegularHashMap are two separate codes,These two separate codes need only when VectorizedHashMap is true. but other cases, we can merge together to reduce compiler maxCodesize. thanks. ``` import org.apache.spark.sql.execution.debug._ sparkSession.range(1).selectExpr("id AS key", "id AS value").groupBy("key").sum("value").debugCodegen ``` Generate code like: **Before modified:** ``` Generated code: /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIteratorForCodegenStage1(references); /* 003 */ } /* 004 */ ............... /* 420 */ if (agg_fastAggBuffer_0 != null) { /* 421 */ // common sub-expressions /* 422 */ /* 423 */ // evaluate aggregate function /* 424 */ agg_agg_isNull_14_0 = true; /* 425 */ long agg_value_15 = -1L; /* 426 */ do { /* 427 */ boolean agg_isNull_15 = agg_fastAggBuffer_0.isNullAt(0); /* 428 */ long agg_value_16 = agg_isNull_15 ? /* 429 */ -1L : (agg_fastAggBuffer_0.getLong(0)); /* 430 */ if (!agg_isNull_15) { /* 431 */ agg_agg_isNull_14_0 = false; /* 432 */ agg_value_15 = agg_value_16; /* 433 */ continue; /* 434 */ } /* 435 */ /* 436 */ // This comment is added for manually tracking reference of 0, false /* 437 */ /* 438 */ boolean agg_isNull_16 = false; /* 439 */ long agg_value_17 = -1L; /* 440 */ if (!false) { /* 441 */ agg_value_17 = (long) 0; /* 442 */ } /* 443 */ if (!agg_isNull_16) { /* 444 */ agg_agg_isNull_14_0 = false; /* 445 */ agg_value_15 = agg_value_17; /* 446 */ continue; /* 447 */ } /* 448 */ /* 449 */ } while (false); /* 450 */ /* 451 */ long agg_value_14 = -1L; /* 452 */ agg_value_14 = agg_value_15 + agg_expr_1_0; /* 453 */ // update fast row /* 454 */ agg_fastAggBuffer_0.setLong(0, agg_value_14); /* 455 */ } else { /* 456 */ // common sub-expressions /* 457 */ /* 458 */ // evaluate aggregate function /* 459 */ agg_agg_isNull_8_0 = true; /* 460 */ long agg_value_9 = -1L; /* 461 */ do { /* 462 */ boolean agg_isNull_9 = agg_unsafeRowAggBuffer_0.isNullAt(0); /* 463 */ long agg_value_10 = agg_isNull_9 ? /* 464 */ -1L : (agg_unsafeRowAggBuffer_0.getLong(0)); /* 465 */ if (!agg_isNull_9) { /* 466 */ agg_agg_isNull_8_0 = false; /* 467 */ agg_value_9 = agg_value_10; /* 468 */ continue; /* 469 */ } /* 470 */ /* 471 */ // This comment is added for manually tracking reference of 0, false /* 472 */ /* 473 */ boolean agg_isNull_10 = false; /* 474 */ long agg_value_11 = -1L; /* 475 */ if (!false) { /* 476 */ agg_value_11 = (long) 0; /* 477 */ } /* 478 */ if (!agg_isNull_10) { /* 479 */ agg_agg_isNull_8_0 = false; /* 480 */ agg_value_9 = agg_value_11; /* 481 */ continue; /* 482 */ } /* 483 */ /* 484 */ } while (false); /* 485 */ /* 486 */ long agg_value_8 = -1L; /* 487 */ agg_value_8 = agg_value_9 + agg_expr_1_0; /* 488 */ // update unsafe row buffer /* 489 */ agg_unsafeRowAggBuffer_0.setLong(0, agg_value_8); /* 490 */ /* 491 */ } ...................... ``` **After modified:** ``` Generated code: /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIteratorForCodegenStage1(references); /* 003 */ } /* 004 */ ............. /* 423 */ // Updates the proper row buffer /* 424 */ UnsafeRow agg_aggBuffer_0 = null; /* 425 */ if (agg_fastAggBuffer_0 != null) { /* 426 */ agg_aggBuffer_0 = agg_fastAggBuffer_0; /* 427 */ } else { /* 428 */ agg_aggBuffer_0 = agg_unsafeRowAggBuffer_0; /* 429 */ } /* 430 */ /* 431 */ // common sub-expressions /* 432 */ /* 433 */ // evaluate aggregate function /* 434 */ agg_agg_isNull_8_0 = true; /* 435 */ long agg_value_9 = -1L; /* 436 */ do { /* 437 */ boolean agg_isNull_9 = agg_aggBuffer_0.isNullAt(0); /* 438 */ long agg_value_10 = agg_isNull_9 ? /* 439 */ -1L : (agg_aggBuffer_0.getLong(0)); /* 440 */ if (!agg_isNull_9) { /* 441 */ agg_agg_isNull_8_0 = false; /* 442 */ agg_value_9 = agg_value_10; /* 443 */ continue; /* 444 */ } /* 445 */ /* 446 */ // This comment is added for manually tracking reference of 0, false /* 447 */ /* 448 */ boolean agg_isNull_10 = false; /* 449 */ long agg_value_11 = -1L; /* 450 */ if (!false) { /* 451 */ agg_value_11 = (long) 0; /* 452 */ } /* 453 */ if (!agg_isNull_10) { /* 454 */ agg_agg_isNull_8_0 = false; /* 455 */ agg_value_9 = agg_value_11; /* 456 */ continue; /* 457 */ } /* 458 */ /* 459 */ } while (false); /* 460 */ /* 461 */ long agg_value_8 = -1L; /* 462 */ agg_value_8 = agg_value_9 + agg_expr_1_0; /* 463 */ // update unsafe row buffer /* 464 */ agg_aggBuffer_0.setLong(0, agg_value_8); ........... ``` ## How was this patch tested? the Existed test cases. Closes #21860 from heary-cao/fastHashMap. Authored-by: caoxuewen Signed-off-by: Wenchen Fan --- .../aggregate/HashAggregateExec.scala | 66 +++++++++++-------- 1 file changed, 40 insertions(+), 26 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 6155ec9d30db4..25d8e7dff3d99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -844,33 +844,47 @@ case class HashAggregateExec( val updateRowInHashMap: String = { if (isFastHashMapEnabled) { - ctx.INPUT_ROW = fastRowBuffer - val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) - val effectiveCodes = subExprs.codes.mkString("\n") - val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExpr.map(_.genCode(ctx)) - } - val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) => - val dt = updateExpr(i).dataType - CodeGenerator.updateColumn( - fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorizedHashMapEnabled) - } + if (isVectorizedHashMapEnabled) { + ctx.INPUT_ROW = fastRowBuffer + val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val effectiveCodes = subExprs.codes.mkString("\n") + val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExpr.map(_.genCode(ctx)) + } + val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) => + val dt = updateExpr(i).dataType + CodeGenerator.updateColumn( + fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorized = true) + } - // If fast hash map is on, we first generate code to update row in fast hash map, if the - // previous loop up hit fast hash map. Otherwise, update row in regular hash map. - s""" - |if ($fastRowBuffer != null) { - | // common sub-expressions - | $effectiveCodes - | // evaluate aggregate function - | ${evaluateVariables(fastRowEvals)} - | // update fast row - | ${updateFastRow.mkString("\n").trim} - |} else { - | $updateRowInRegularHashMap - |} - """.stripMargin + // If vectorized fast hash map is on, we first generate code to update row + // in vectorized fast hash map, if the previous loop up hit vectorized fast hash map. + // Otherwise, update row in regular hash map. + s""" + |if ($fastRowBuffer != null) { + | // common sub-expressions + | $effectiveCodes + | // evaluate aggregate function + | ${evaluateVariables(fastRowEvals)} + | // update fast row + | ${updateFastRow.mkString("\n").trim} + |} else { + | $updateRowInRegularHashMap + |} + """.stripMargin + } else { + // If row-based hash map is on and the previous loop up hit fast hash map, + // we reuse regular hash buffer to update row of fast hash map. + // Otherwise, update row in regular hash map. + s""" + |// Updates the proper row buffer + |if ($fastRowBuffer != null) { + | $unsafeRowBuffer = $fastRowBuffer; + |} + |$updateRowInRegularHashMap + """.stripMargin + } } else { updateRowInRegularHashMap } From 57eddc7182ece0030f6d0cc02339c0b8d8c0be5c Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 31 Oct 2018 20:22:57 +0800 Subject: [PATCH 1967/2461] [SPARK-25886][SQL][MINOR] Improve error message of `FailureSafeParser` and `from_avro` in FAILFAST mode ## What changes were proposed in this pull request? Currently in `FailureSafeParser` and `from_avro`, the exception is created with such code ``` throw new SparkException("Malformed records are detected in record parsing. " + s"Parse Mode: ${FailFastMode.name}.", e.cause) ``` 1. The cause part should be `e` instead of `e.cause` 2. If `e` contains non-null message, it should be shown in `from_json`/`from_csv`/`from_avro`, e.g. ``` com.fasterxml.jackson.core.JsonParseException: Unexpected character ('1' (code 49)): was expecting a colon to separate field name and value at [Source: (InputStreamReader); line: 1, column: 7] ``` 3.Kindly show hint for trying PERMISSIVE in error message. ## How was this patch tested? Unit test. Closes #22895 from gengliangwang/improve_error_msg. Authored-by: Gengliang Wang Signed-off-by: hyukjinkwon --- .../scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala | 2 +- .../org/apache/spark/sql/catalyst/util/FailureSafeParser.scala | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala index ae61587fe1bb7..5656ac7f38e1b 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala @@ -102,7 +102,7 @@ case class AvroDataToCatalyst( case FailFastMode => throw new SparkException("Malformed records are detected in record parsing. " + s"Current parse Mode: ${FailFastMode.name}. To process malformed records as null " + - "result, try setting the option 'mode' as 'PERMISSIVE'.", e.getCause) + "result, try setting the option 'mode' as 'PERMISSIVE'.", e) case _ => throw new AnalysisException(unacceptableModeMessage(parseMode.name)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala index fecfff5789a5c..76745b11c84c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala @@ -73,7 +73,8 @@ class FailureSafeParser[IN]( Iterator.empty case FailFastMode => throw new SparkException("Malformed records are detected in record parsing. " + - s"Parse Mode: ${FailFastMode.name}.", e.cause) + s"Parse Mode: ${FailFastMode.name}. To process malformed records as null " + + "result, try setting the option 'mode' as 'PERMISSIVE'.", e) } } } From b3af917e76997faca9bd3ed3c5cb4dafd6fac1f3 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 31 Oct 2018 09:20:19 -0700 Subject: [PATCH 1968/2461] [SPARK-25893][SQL] Show a directional error message for unsupported Hive Metastore versions ## What changes were proposed in this pull request? When `spark.sql.hive.metastore.version` is misconfigured, we had better give a directional error message. **BEFORE** ```scala scala> sql("show databases").show scala.MatchError: 2.4 (of class java.lang.String) ``` **AFTER** ```scala scala> sql("show databases").show java.lang.UnsupportedOperationException: Unsupported Hive Metastore version (2.4). Please set spark.sql.hive.metastore.version with a valid version. ``` ## How was this patch tested? Manual. Closes #22902 from dongjoon-hyun/SPARK-25893. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../apache/spark/sql/hive/client/IsolatedClientLoader.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 31899370454ba..1e7a0b187c8b3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -100,6 +100,9 @@ private[hive] object IsolatedClientLoader extends Logging { case "2.1" | "2.1.0" | "2.1.1" => hive.v2_1 case "2.2" | "2.2.0" => hive.v2_2 case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" | "2.3.3" => hive.v2_3 + case version => + throw new UnsupportedOperationException(s"Unsupported Hive Metastore version ($version). " + + s"Please set ${HiveUtils.HIVE_METASTORE_VERSION.key} with a valid version.") } private def downloadVersion( From e4cb42ad89307ebc5a1bd9660c86219340d71ff6 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 31 Oct 2018 09:55:03 -0700 Subject: [PATCH 1969/2461] [SPARK-25891][PYTHON] Upgrade to Py4J 0.10.8.1 ## What changes were proposed in this pull request? Py4J 0.10.8.1 is released on October 21st and is the first release of Py4J to support Python 3.7 officially. We had better have this to get the official support. Also, there are some patches related to garbage collections. https://www.py4j.org/changelog.html#py4j-0-10-8-and-py4j-0-10-8-1 ## How was this patch tested? Pass the Jenkins. Closes #22901 from dongjoon-hyun/SPARK-25891. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- bin/pyspark | 2 +- bin/pyspark2.cmd | 2 +- core/pom.xml | 2 +- .../apache/spark/api/python/PythonUtils.scala | 3 ++- dev/deps/spark-deps-hadoop-2.7 | 2 +- dev/deps/spark-deps-hadoop-3.1 | 2 +- python/README.md | 2 +- python/docs/Makefile | 2 +- python/lib/py4j-0.10.7-src.zip | Bin 42437 -> 0 bytes python/lib/py4j-0.10.8.1-src.zip | Bin 0 -> 41255 bytes python/setup.py | 2 +- .../org/apache/spark/deploy/yarn/Client.scala | 2 +- .../spark/deploy/yarn/YarnClusterSuite.scala | 2 +- sbin/spark-config.sh | 2 +- 14 files changed, 13 insertions(+), 12 deletions(-) delete mode 100644 python/lib/py4j-0.10.7-src.zip create mode 100644 python/lib/py4j-0.10.8.1-src.zip diff --git a/bin/pyspark b/bin/pyspark index 5d5affb1f97c3..1dcddcc6196b8 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -57,7 +57,7 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.8.1-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 15fa910c277b3..479fd464c7d3e 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( ) set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH% -set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.7-src.zip;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.8.1-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py diff --git a/core/pom.xml b/core/pom.xml index eff3aa1d19423..f23d09f73657b 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -350,7 +350,7 @@ net.sf.py4j py4j - 0.10.7 + 0.10.8.1 org.apache.spark diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index cdce371dfcbfa..b6b0cac910d69 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -32,7 +32,8 @@ private[spark] object PythonUtils { val pythonPath = new ArrayBuffer[String] for (sparkHome <- sys.env.get("SPARK_HOME")) { pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator) - pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.7-src.zip").mkString(File.separator) + pythonPath += + Seq(sparkHome, "python", "lib", "py4j-0.10.8.1-src.zip").mkString(File.separator) } pythonPath ++= SparkContext.jarOfObject(this) pythonPath.mkString(File.pathSeparator) diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 0703b5b02b125..db84b85618d8f 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -168,7 +168,7 @@ parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar -py4j-0.10.7.jar +py4j-0.10.8.1.jar pyrolite-4.13.jar scala-compiler-2.11.12.jar scala-library-2.11.12.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 513986820d5fc..befb93da94887 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -186,7 +186,7 @@ parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar -py4j-0.10.7.jar +py4j-0.10.8.1.jar pyrolite-4.13.jar re2j-1.1.jar scala-compiler-2.11.12.jar diff --git a/python/README.md b/python/README.md index c020d84b01ffd..ffb6147dbee8a 100644 --- a/python/README.md +++ b/python/README.md @@ -29,4 +29,4 @@ The Python packaging for Spark is not intended to replace all of the other use c ## Python Requirements -At its core PySpark depends on Py4J (currently version 0.10.7), but some additional sub-packages have their own extra requirements for some features (including numpy, pandas, and pyarrow). +At its core PySpark depends on Py4J (currently version 0.10.8.1), but some additional sub-packages have their own extra requirements for some features (including numpy, pandas, and pyarrow). diff --git a/python/docs/Makefile b/python/docs/Makefile index 1ed1f33af2326..4767fd9f1c038 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -37,7 +37,7 @@ BUILDDIR ?= _build # 2. If both are set, SPHINXBUILD has a higher priority over SPHINXPYTHON # 3. By default, SPHINXBUILD is used as 'sphinx-build'. -export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.7-src.zip) +export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.8.1-src.zip) # Internal variables. PAPEROPT_a4 = -D latex_paper_size=a4 diff --git a/python/lib/py4j-0.10.7-src.zip b/python/lib/py4j-0.10.7-src.zip deleted file mode 100644 index 128e321078793f41154613544ab7016878d5617e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 42437 zcmagEV~j39x2FBHZQHhcw{6?DZQHhX+qP}nHg?>CvkN2FC`PIKdTjbLAX`34>8PU>@i^KjzbxRdnn zB+RF;R|mYL%{iA-2?6_Z zU7!jh>L$P>>K^$-VC%T%aRpM%5D^BJeMvMC>o3P@>OwN1h1~0HkiLZyal6>nXCpXU zA|vKIC@@Rx`$TDd0~n@W!fuGfSr1hzmQ`yWq~XanN(?kcRyPuEDIz*KgU(MqualOl z+=_@y&<$zZq}dBEC{hsI0*=$S>pvAplua&-aa%mWL7=*UfT}-@@j_Ht;JH{)Oxoc! zj|`hX=uZWvajH_HY&+Ntg%azO8S(7ITriEi#GVB6GiJsJ2C+7>xf*q=k?k2_+wg@< zrLWtr*9~{zL4pZOPDKo&Fx&H&n9Vt@r>SiIdBBeXHR#Rq#3|}BsbNL9i##!W3g;ca zEm7DH7{`Vsa3niI_E2NUfk-vT4h8sbC%9C-4{_w!sF2How@~5F^|w)J+V;{nTjry zxbfjP18pM4NSH|&=62wWSMu~AoJD2JQ(!w~R8;EtEe4LZgdq;0JCPRKNkb z|4+q^#QbsUKV+RG8#wC_iq?c+y3ARM7W?$ReEoBD9JI;1O{Q@ZjEIv~e+#W0gzQOC zE2M}>mKN%qu`E6|=QwD$`i9FsSxe6u5+qSJW|>C&z4+i$vZ*ifeOMIEz{9N(<@5%u zHU(khTf6NFod9N?L(#J){c%UN;%WFb+}H1JZP5`ozj1t^Gl%%AwRx2^!BwchFMwJO z>j)lHhpy{d+9m>}bSXMT)464g=JY@mXYtLEI$bDe&c&;-Dr3RTowZjsf;4~J+&bvd zB1hVO=E=Zxsz%`V5?RzT*3ytwP6Z$0Uv5UMyA<_SNTHD_@?3pLboUCx?J;{b$ncaB zAJ3p-s%jf&eXSSD>GUQOPG=S0gmP@1yswUphw6AS)>U~w;xifo9~;?!_}3-;uC;$T zAozDC)|axyy*?gEc^&DOlNL>`K@bIml~+c3U}w`s}ngJPJQQra!J< z7EYZxt$HoA3lI&vO}PD0z~bZL1|~dhlY%zGIDzb4}z5eUA zEN3yJ9;b@VLUS`u02WueBo_1QA@c(2ZeZJf*tKG=@V|2)TO?2B$NBAs84~~?vH$>3 z|2qeaZ7fagT%1jv+)SPRI}6Y_&$Rw){U2Fy&1>zv#hJACO;bM^Uf_ndGCO0pl2+-X z%BMZyvDV<+%F@xXL6?<8m}wlfoTS3jS@C(hh5#6jZu}={b~{u~9Vr3~m%q1TL0C6@ zwbzIKXqA{=75V1ydHSq>H?d&A=%G&68j5=DyBg)wLv^f`W=q1{FrF^0IfUS$J_Xwq z<=T=EnlbUBIQc+{;y=S@df3%%*QQCdYx?-voihLYUE5P9m6%>r&CwFlHzqEAMh2L~ z31d2})$Ev$J=3H*mVf-L){#`4I{mA?vCoqa_O+>EN`eyr$yjXB*urzD89S0)85V~n z=S%uP2CFZ9No!DFI&-N;Utg4U9q!zM4;(x10W2PX2h~IgeXd$V);1bWVrA+VHP$VF z+n?{$tJ>cX8ptEojgLonWkt8>(qY@yBBMNPr`l$ueF{lV=hQE)#!8YnlT`_jplajf zQlg8tY5}gk`YX%kpFkIUEn&1dxGa8{mUo43n>%wto6Ua?5zp*hexfvE+w{20qx7-R3IO(1)UnqZec6Dk7!00=$_SwuKdc?*pmp5OB%}Kr zv6TmG>6YHs)%{#~`|UjQ@#gmS$=c8D$=2}+OhqRzU9_oVm@irP zsN_g;aTJxwYTL3Slh>C6#M>`fw_0C^;9Ib>vpcP-JXIsA4{p(4&Q`pJg2L&3ooU(@ zy{x^sTs-?em-NeAb%9pzugD~fRr$*)X0S2&bS2ET(U9x4>H=4m;JWum_LUPRZFNik zHX6}j%|cD%t0m5{=nw?r;N#&Dp-gJwTHE0Y-@@Bf!5Ne=(Rl zYI2$Gm&|5BV%Fi}&w&iwaONeq|7koGj)q44#WL6*@T86w3%f=}-tS1_R&bZt@nH(l zJSZ3Ru6p-KU1_o*3lxn4UgxSY;~(UgoDWy8B#a-%nb&f%hzM`om9$45ZxX^fAC7kg zXI6%}5hsTrt(|F#3 zwK}d2^Q6$B8)#-72?Xa?<8fPw1X!qFFsF|kW~lKy?JsC6rHqY}RQ{1cHCw5K7XmK5 zYA$c2j3}k57pC#Iif5=o-b&x7PPovjt@26_f)w^|o_Rs6@`Od7l-oSD*0g1-7X>$v z+W+PA`0=_~YV}GWb8Yj;gU$Y7BJ#52ksF7sBoLrQ&3fnMHE}Z@s)%(yjajE!ZQ#_8 z(rx*MhT3MeI1HHO;J_%0$>jzt?Q-Ydp1wN$BrB;$9B)`o*~O429jfm?(L=N&ktr2aQx!?dT` zN`c*(d{Ab#iLRPYefN)&wxXgHSRMaEnJmz-g&`C)Y&1(<8PUKU(-PH2fJD~cEw1!t zW>K~;Dfl1v+nj^IfGkscH(8psqxaTPk8^a^EA9-BT||(r`T>Jf)heQ{ydW&okScM8 za-?&J0fuIA+zS&=;kPjb3A@d6Y&vDZ=wT$yso*ph4@HCQI57Q5YC%Ni|eq^0{9ar>J1HzAV7w^Ezy`tt?kJq zEXP8M{&Y=AzAJ#dkyNAHx*P2zzTrQ{k;g8Ld3#t;85bbqBN-4;9P=+pU^+`iM)4}<}I#IYg z3GWQ*EF|oN&Ot5rG93anmX(i9m<%SkJ86qCiLIS&D71WVc_4LE#+y=3{D=Gi|nPm~W$T_IPrnphZ4+w{Wev!O(ZEB;u zkM35Kn{VW={4TT*xPQk}Itl~Bub@0BdQs=DTUKq85<+;A8nKB+2QR}l<8dV2rk~(s zT0ZJ~d8@Ad;a#^jT|tegDXlY_&3N5Z#D^Hzp_BmE=MGA+kUx$KcHsyHP1IbqHK7JgoShdU6m-E}MsvS$ zNTIt&g(*D-Z25^yUohuR52^o5JBY09%uf19qN63s&y?+q{{TqtkJ7e`qR`fE$fL)d z1n-&!y7JYwPN;@sZM4;h`kkaC4sfddJLtah&D3pyNUFJKTAOx@N@9;p*~@f!%nHKu$FbTo8!X&$u+$eJ#JN$=>q)*w3i0-${nfgS zXR|?wCkhQqVR5hf(;QezAMNfsVjCbY5FdYi4-?WY>cW@dvj?iC@D` zZ+MiT##BD*afvwl)4ax5N@6EK3IX=}k-0g5i6$X_rFPL)VjED;M*6Qu-lmW$%8ZZ# zsh73)OtXEl9^!=xtPd=?p|NWb1S0(6?F^mUQBr12{X)%0+9x|biPaK(;vgnoBiVM; ze%R>-UsB8C(3t6a`4Tyo#8E!BLl#3D>+&PM-$05ze8vRl+sG}ReUNiR9_Oz5e9fp> zX!L$mid_{R5n+~FNs1!gt|QvwRkwCp&HnBOK>xQbDfcWCGqZKR6&D41j}Hygjm?wfZS|IN9=vxOhYz17fo z29W+`mmvA^;A{VoFEVd6(2W(geLm)&Lnmf85wVr<*aLYZ!+xPm`dvljTa3x(=Nm?5 zumCPC6l+wz9KO;Ex{3y~o=dW3>22$%j?6P}3q#0;`4mR@Axv1=vIO&_2WUqvPKmzd zI1$u9wTTVm8tK@VL?*~<8-U?=0t;GMjVSWV;jrY^>L1X5Nr!M4=w<;kAaZ)3RDusa z?c_D_(6cTUHI9dxX50dqO%|(-Imb}yzfwBE%3qSUV{NXYaYaSp*Vo??zmKan2rCb4 z9((;*ITH1+FN{pW+BAHC*{32uu!5mXPN%bjU+!(!`2C;M_5&Uvs)d1yiuqtpXXO*U zXP||N)dl)W`U=cZIm$Ry#9F-<$1-yD$%?k-lz93c_=8xFJzrX6jHUUS7vBWM_$nr7 z?~mD7CBwvb3#t&p@FD^?CSL+lJ9i_Oy>J>iJVPS$9%ZrU-uN7mQyQALO25$Ow9-ht^r&{C@*{B}1j7+k?ngnC2)+>esdc zeB4nK9N5jwLq3u^H87?Kkd3clJP;y_Jj=s$h=)7$cx#Ht$>=524MGI>ZTVlJg6~UI zY_8n~_PE^@mBf(tV{fdp1Tmw=*=vP3fJs&j4_nz5%U?ksk1$$n z;ZZ0QB_mf0^;Un<&r8b%qP_tTbSZP@>1y^X3K~V1_1FLUb{k_FrqY!g(vD)y(Go4< z&j;=|>U`2LJiDBMu8?z%&YPdC?U^;SmX-qJ(ZZyCCpfpyJ%g!=J%+%_FEul)%;2zI zVqmRo4W4CJtPY<~GS;ML?`V|n7SzMn<>Sp%kt{{#cG4^pIckO<;0U5uaY73nj692G zSS5y;Fzklk1fd=p=q?qFD5-eez&3DL(fvA#imuNiGylraa+w@5T(&X$A0=~Kyq`R& zd6OPxS1t9GaI~#0CA$wm^-y|qmcZUG$KqqY_EzXVkQTqKyBRL_7O-yAyc~^jZBed~ z`gzP<@-Lce&Nx;G;K6<();!`!B0phyP8wo@6q1dBcg*4+Bg?D9Uv3aCP8jQM??kX( z8nRj8vCG%k`e4l|Rm5n|4Xi`bD5>NcNMimi_x7Tqg>a*(q`faQfM)lC@O@Xv1z8J6 zl85~y9jVN8@J(Zb=YM;T|NVS@IIBCT8_1Xi5q(R6=Z=@>+xE-O)9dkeck)cr%j?dL zI~<(n2P*4{&Z@S-FOe)f`bp!%%k3B0X-%3_E-OE>Fbh!fN@GkBh%z0=1Yhv=K}}yB z$c6A&L#Z^J3uxhYJw)VY61|(ILJ)YBaSN3JZ5mHcVt<3fRTYy7 zzTncV1O)Zk($7e37h&~ym-*-n5B(!WJ-p!yFv%B2=Cu(=sy-TOqQLW7gBjv^gMzJg z{T4P#?Azlc^bxTCYmyz?*+$W}6i>0ot;83WP^1-rwB9~xUm^vNU7K|*i(7D-eMxtw9(-WwMhj{TCSp=r9=8y07I|I>czux3|U?y9R zl4f0+|F0#=7TNr2XR(LNSgtQ8HfpP?iNukS!2m zL7HeS+YKYI480WkrN+zj3X|cGWBYleiD(&%DZxI8NDrPFGo*b>SSB{i&N$>Ux_%P> zj2n{4ARao!V*am0#{AmS7Q$NkMc2j5%*573y*Y8tqzNX!5AVD9Jhh)FlwUI&zuHbw zY~@D0WiMxpa)xSLp@W6Gc5K;B+fs5gjH}o*JJ0=wTb1k@;-=a|5EneN~c zUZmcYg};&25S~#V+&=|e?VYXUt6#w2z#=_;X<6HxmIZEK&kRPRgv`Ss#HgT%$!ks5 zo)q*Eoew83n3;_3C=aTruTBI(iS&9pfe$5t)YjIvJ?z&U&Ap3H(Fx{6Iaya; zp4?ntwpOpi*`W~_*yyd#fMOf1xn$Y=HmBdqiMkr0A(j4CviYS4AI}A#>(9)<>z^lC z>qxM2oMTCS-5Fxw&7SpZg*(0AYJri4BXSsyp(2=YKf|Dv$cVxVeAGROIOcN10fZnY ziffYYvtAD5ngom?$7^K7mI@0_dY4$!4IXh5_(eW9dcy?* zHkZtl0H-;pPp)YHcD7eT%0h0Icx$)$VV9r`mvCQyXiwgGecGn(d>qny8Fj>$XQGLo zojU!3$`!8J{&BNszw2L8yu%lW}g zBtBsrZ?q66>7Ie#`1@xMKY3^G8$Y@T?HC@u3=vw}EkdKPuq)Kw zQ9HFQSfI`dQ!#d zu06-gXT`gQZxipkNgO%aMNRB&_QmbB#V*A@;%LN=!nzk+-6A7tb6MUJr!!$N%4x@L zNXo>woxm7nR$3a>-k9&vzdc1Ka)l^GPi^84-P9m@u^AWdg(=u-bnNq}ET&nO@7*^y zGWGJ_DUz(9NxrXYlLzboko~)ED?#-8FaiNIah^w_LVerd3>n_;-9Q53(Sj&do%K~9 zJsY#_mMt!m2!9=&e?{;VRx$S7AvR}RE7VdeHfxVj|D2>{-{j$Ro`(m#ukBw+v*ZfL z|B!uuaPr9?Tf2hM&_~+~93)A3jmwIkP8W6PlSJY|VT#$Gy^1M2cUJdbjGMXGZQa|E zL&n>?Kkiy+b(cS=?j4?7<&SUpQY0q~zCF{ELEN_mtqH^Ec`I@*xUW_ek?rXAtdd80 zYeMNuInCdfkCXSxBSN@nI9}?O;UC`#eT{020by-}yjA7V=A&S=`fd3i4x$RX0XM!f z2c6}VbE-^w?PdJLp-e*%buLZ#9F+fEUc!K%Ux7!5T7E7b8~6?C|D6b2o85YTy~=;l zg}DG)-1_}!V3ygXgYK3 zz9DgB)pR0tTL||owE}gN>CToawE?P}IEIRYTJ0RHUSz6}8Pi(~kyVD?C9=Y@>?7Y- zrb%)9U)E;A9xftid^sx}=_zq~eJlI|VqyOu*g%4+6T0d(M0$)nuU_Z?KYC$L4_5** zTHia@H`={KULR2q2OyyvjeyY6z`49>=*P14ysEuzXmMKIx3P&thL|m?`{{s*h}BCj z=CuXxkwF602)r~-QSMVM&WgvP!J)+aKR=ukA`5U)+hT1khnNK|2DvcxdB zwUXrD6|CP^QZt^}6SaGv);O+lOD2*IMnh)TbZU~f3>zA5K{+pcCLj0NJ8Urjih3c4 zZ5wGX`D$L|j?AH*?jc@(2in=n&i&1h+~!}Yt)AK%C#pW*n3JnKk`*Na`G^)}@w!>~ zEJ`JUio~CIGS2*s78fC&5`cMQHoq~~i4XfPMK0H&KJ{}4f;+6N$dGb-0dw$hoX>4AV}^Zaa5U z@`!U|fMO5#x40jk?Q&jOjFo_p+7xS+MvEGf%Z+IwH>8XGjT5(Z`OF*5h(qWyc z6SREUb6M3-$9euwGQQ8HBk&c0Yv_mI9nSy1zLYTfi|BE!z7xGAhUmYh&!j_Og`r(cHc!lXt~(`2Xi3z=mwKozO=9g_rQ6UIbrkY- zoBN$Ww(+pju@){=Na4-;x9@!h_5(a^kB&$kEIk|cr@Ke<=a_V2WJL5q46R0>9<52i zV1f+Wz+s3C<%`5>9E=5=HA_r#XsuX_TK1P?TVwecF@shDCbO7das!Kef&n{o1UMyx zLuM!*=IvnQl_81LxTYp$c0bEfyP)3-qOTE7YBW~9V!1MGrP9W>%CQAlk3&k^f~*YL zfs7N(Y>+mfCKkD7M0H*wsHBFno~DJx=TqzC_1?6}`<;G9bw{D&o~@%wHi@)gQQQ6- zg2pj^K&EET(C$nV_(>s5-GQUP4Pslc4asUoWVnV6G(gq6+aUViZr6fFp;#tLUM>g- zOP-OGt7l8#WZG6#@gEeUnTJo$EDwS6onzm;^7|?rj&{9!0u(3Y>(-C7h5iQ6j=+Uq8stvPN zNT+=EJP)vUIjX}eBYe387{vz;6+m%h%iuWBQ&GAeo=e1iUd15H89!Z>rA}Qe-=RTq&NvcM zR^MOp*K^Wz=`Tz(pr=W$!XM4>2VwjOstMxAbNGT^dB>UEl50@LATckD*Wtu6>;Ft& zQUj26+5b3U_WxP`+XB zG;~ySQnB&s*q6`~`V@18n3_xB3q3T1RM8VHgp zF#pG+;I<*%OBD_P02TuPF#o%0SsA(+>KohJ*q9o-SlZh;|38j*tz&Dy#fkiLqd(x> zC!xJ2(fN6RXoEg(w$t^L#2T>!hQ1?IV#}pQBh9BMcS!Tnw>Nz!TC9glA-6@V03vGZ zWX73;#g{`=!eh+f=P^LO!75h=zrM9lh4DL_CfBxgonyMBN~}SD%q(rn6L#A~3vu|Q zbqvN>zxI_ibW4hqOgyW^F4>+U)JSZMab*-o*Vw{(Gf7KAYm{h;@l(&MvX%<%g1%9RV!8>ILE@7>2rWD_-^ZV4mKY=c-_eN3> zc4rU7ReFg2cBb=CI6eU0fLjMM@R(D#!q}jbwBkSPrg_ZEeu~@|5Ldu5EAF7}0!H3* zV9kQv`7rZFGTkb^H#clVD=?T!Sr`haP+swi(W5oRdNIvI!^((utmx$)+TgP{!npdY zhw#9&WZb`v!ILW{eOGKX0mexeTkF{f+;w_Lon{bCDVI(|T1E&8Cz)|QqZYb`C-kPU zpZGv=enUJJ?9_Sy1QlQaWG*!-yH8FS9~1wCIjm+7An>S|26AU@b=n~D&DXmj>SVlqLsK~OcptN|m*I+5CJQ34SB zlK#oSjQFMYh(duVlT~A|xrJgTPQ_2&EYb#vnWuZ@&Eh_`HdC=dzYrws5+9;SEOP1(RwWH)DF;%( zbP6z0-hjW%wjAgI(G+>I>MZa=-{os{_oGdF)-(jb#UM7Y01l`x1A{9L7c0_GYgFqW2IUpr9@PU;;Pmi?dNs>5t$P>zJW-{YkA(9$eCDW;!a^vO}B5$c?|7bv`tN~2{0Hg$F7`u7)B@vmysjgZSTznLR8qV_oj zu*G~BSX77R)L{s{x$0n@cePZ+Od<wr5Ixb8)(cT=A|cl5 z4y||FYu@*?p9wLL0`X*hl_|iT6Ya_boE{_GKwxJh0Ahs{vO)+-*)kp-N*ymghV6Q(Inw&h7&@X%Ib z-`yG)e-%hRx^4YnU<+wA1iwGis~OxL)Lf))-ZLb^OPNmMT!-PYtp}@R_S%$zqKS4xvoJ4wRPAfv&ph@p^y$$C$tr zQ)%Z6d5B3qlf2Z3Y=|-sE6o%KSBZ3(+Fw-wUFySk(bQBCf8h~_xw-gRpXiymImpYG zyu5Y#jeY>+P+Ol}aT}IhisH6F2q+O|)2I1nO~w6Jy;~6S`NguUpIr)m|GZdebkS5| z$#F$DQU_ZILAo!M0VEym{O?+Sg_jQsWv1;&&fi~g=12-NZqTbCgbPk9jw1wl`8bGp zV?R)6G`Ebu)V$vsJI92FU)-JM!nq*P+-dcKsJH=Is_Ms6li(G$Bap!#AL<5))HWhu zEtFGt`E2-K28Se1Or@1#{|+2=YatMJJ1vFtk!%pK1D3%}%vt~rL%E1ysOkw$uNG}V zpj&C=EwThN0IvNPJkdM2Im$xHYBM6J3XFq+$~LiaR9@E~(5_Y%#RvKmT?3mhhI$b{ zOk>DoO#e#iM->QMK`L4s4O$H^n#U1YwwCK1g>-*jNPItQ$XX=pXto4CQ_l6vc2j8; z0VoYsR&qMhG9-3bZC-K`=>Hgh0-@rNIUOv8hZTuB)=5$6PyZxF`6CWL@>|={biU$p zpc$-C-!2}2VLMdJVUk{!8!~ggPKNEu-2G6#^}lo;+JazZ19dH z9<3~wVFeOL3EUTG3J?HB(I__?I$c?8wCof6=^=-si|!l(TzEPBT!g&5`wrC0%6H0- zP&(4NmwHq0ZH}*;5hOR4lDOVr-1+S%>%})DjI9mes_6{F&vw82W3NPAk%lQyy5cO| z?|QYehSG&BuJW%RuPmyZ@Gbe6x3OwTg9zl9!7T`r|O+`m-X3*$&p)c#mxt@^5{Bi0HBCh#o)37K|ghtlDhH zLB1{&tiwFMV~DG19s#y2EI7#=h>&E>ps?_x-K}rYU|5r`F7k459d<$G9Y>8cK^@=8 z6qy;%Yc3>pMasI3Hd*h1+@RQfZNJ8j+*ZQ9<*GGjDV)yBR6*))=a)9_=^8QpUuyA5 zLHMbb#@}AojI?J9^+R2dp5nk4?Guh`oa|5zRV(fEuD6ZC;}sK}NxkdlnGy!SRvQ2N z$2;VVc7g)fQTe7oOY*&`gHflqAMuwmilGJc5~jorkk8w*=e@Zrg0t=)Yucz8gb-h- zA&zC_w+;HPF?qVWFtFkfyL1Eo@kF=gvh(51V2bh!-ML)E*mNzqGBp9UX(C=D-J90QkcX-hp@y~zH&NQvNxEM*FF8HD zJ1YwO-lGTxx{d7oj%87sVP%BRM1$Xx zQy=7x?k8KuMcrB{?Mwlas&iMcJIU*MtZAGDm?7nxxBNmOPo-T zsk@zZ8*l{KiC>%oZs{(bJDcalb519;SD>#JkWk9cA{gYlx9B0xpRIFd6zpdmv(q{Y zti?!gN!y#3tbp6;Vuz_gRY`*jIq7DWo0&Mr)OC@?*2=UFZ9V$!X8vQ8t@VL?GTX1Q z^3&+0?v=a>w0ba^QAQa*?%E`D*pDGEJ}h^uS+mon4`a>WdzlgNGw{sUMt2-#c}kSj z_yTEm6!cLa7XRAQpUv;s_F0jaNzerDmOjLVPPhEYP_R6L2s|h&7Ft@Jwuz|zbFb&m z;mv0}=FFrRU(C_ROMA9Ktytw01CxR^bu>WH3wATE6i)>X_DiaS$-Bm?W?(CJD(ieT zQrjVU8-@;>lRiI#+b}22D|+6SbE9M(q_y5&SI5ulb-n}sB&z+T5!6LYsVGGI4H6XL za5PIPM2#b6?(}Z?eLkF?f=2ai)y$XR#%1eoM|dZI+~1Ts|F}giL_mQUul<`9SjnQ2QFYpVoD|era2t3sJ_=ar1}O zj*T0^hZ$Q3KXzAV*W2aujGn#X&i{Y@mj25f|A#O_)usvns5t`tk30Uq#x!$77gKjb z&;K6L{x9zMKOI}=EwQAZS$!sH{gohiu&KBv_bF=M1hP%V!nRaO?mu-fxX8lD=#Yc} z%FOJCpO>q2PY?-esT}E>p+8iSpk`iPUY*`~H+c&v4X-ZpGLM`Sry!0jUOlh+Z^cX3 zZgUSFsnm*R#1mGkla_Qh$zDVc6XI+&G#=E&Z*sem_v}~k(!wdq5pJh zuSD}&_QG;!wKd6|k}9Q^VKk#nne~^AvL>wtNS0y9dcD2LE49+R6fz@K(W@G>PD?I9 zS^stD0IwZV&fTb}Q823HMAO!ZR+PRwRWxE|?nxF;dde6W2jQW1%Oq>3)_V6JDOQ-& z7y+rouc-L!U@*MA{_bAD9z&a{!Vl9Bb}3Y7)s8geRLPK|N!17^#}{r?>ZMXgi8RtW z85pFJ5t^cvKut2&dQRB3C=>&EgA=HwEYRadYiucOQpt6dQPLVFWHNf5|zxrYf z$5lw~S zAM?!pr4GoHqX6Mr2;zCbAyT1}`-~01n{dzaH(1vlgaI!Lj!2dy~$5kjUJ~XAs%Jm0FZ(AUIX{hL+4Wn#q zhz3`@&8Y|%FE6i`i|2!{6JM_fj4ykqV2UGmia_5FL)YKm2@6j*Cl5Yei~{3q_-PeG z*N(rlbKqUA5`Ncxt@83k)IFBp>O-+s3O2pVO9i`IIN0}h1dSg9vPK_2AXJR>9skJ8{L>O(bPLM zJ1QgJx^NQgONvdU2p@`B0Ld11L~H=S0i~#@A|l7Kw1W{-c%gpLEnp8E{J9D}HeR@CgX}CcZeEGiM0wGrZ^SbHhfz1`B!xo-r z+ITf#!goNVcNseO;`zBnXU|-p9p_J=Zu&fV0*ZuvRn|8U%T(3JYc95u zV9PfO#mc~Ujss$#lMpd>nAjt;KH9=HFRUJzS9G}CP&M#d4G1hQkWoAN9=mx-D?9_7KGexvX0(EDWywpcSw_W(SDD%w-i`^?SCW+iR!g_wTP z8KpmU`9&aC2Ma5A4y-Y=;kSqcxIm2&yE?ktI(|16{~|1umKGmsAx&$k(SY`W6x9M1 zX`qR6%7IMOgS4P`l#apdd03#w_LqE<0)JGrN-wOBI&}}*M>uS$?TORWkw>X zTjrutUBX&IumET%S%^6lfv4k(cuCa@Jm6>qm5i4IVrZWx1yUpq2!S)tcMc}XU*ANn z#sO|be&OBR*3L`llq&V^i>)OYlP zkJi#ZolW~h%N*JAcsw72YzjP8FY)**motgf?n-3Ak!rl zL^Blo&iLW96=Ow$7RkbN^7eY|alSx-!nQ(C5taf~3u6Lz?(?CNB?h|}fq~a!L+Tv% z?4=q)0rCwq1J-u{sVV)ccJcX($;K=a)F01ZmJggUUt#3%iQ6cf2kll#j2k zFXz|4u?NBcCTUIkOlkXQ>9-(9;RWmtZxFNZ$LRljQq%jen!{1CWaAQs-q8?7z9z2Z zB^II8XG8HKk9mO$Jmu1|XcPu~+$$+7j4T-4-P>X9kX`oH4T%R(m$0Ouc#Y!70d+zr zRsbBv1x;1UEt#bSh?83|yKanA5T6K%1@sAh1D9$7!WQvu74k=6SbhZd(an2_!1{~p zjSc8ykF+L+_?CR?RpGbmWdb!}R5%enSr;5sDJ(sp1N{&*U(FWP)ECavjt`ddaC#3w zTm-OXmq3XI0#CSOvg*oPxja;7eC0ChPdRzGq0dS{$x@*mM9V2ulywsLyR}03N_Thj&~?wYHC*_M>oAfzhNsG|VQ& zCgDmet7-&=+oe&hG%3mA^$%RJg4>f~ljZ{gL5TkmL6sah$9E%AI-!UFwe=fwh0ZOd^_O>An<4upV_eQ@RR#2g1x)h3}9iY89 zz<1L~R)bL}x!H2fk-`iLxiavHfRlmzI}#bii1b2Lfu(8Gi)~U|+P*dtV2NO%y|KDx zUp|?H6Qu#%I+f&KG2YFOnOb#?Dglw9{uQx(8_~e#ycJOnTjhzQHj2HE*;=fOjM1Na zkCV2ZaTO+zXB|#1#J14dn2a16ODv-i4AT?N#-$w6yn?UcN}U{fxFf{_^;9(7t6IJzXZg-J2Er=4D8gF1;E zgQJBKaj}6`$dMW4$_ETYB|?MlatNBdx`1Xu%hIrdL4O7VkZTQQ9n`}zqGaScC{B3d zt(O@p1#=g|q5zkxG*X5wTqFvId8wetv_?J#g{Fd4BiZjunXo~7Rw{Wi=n{QuJ?UU7 zga*?xJXvKKeGnNY2j0^Jhk*NUu`1^Ij}MNOqZzcSTAH;B)u#4l!L6yO4quoK@fWgF=TpKU-=s%d026H2~G$?`kSV7-Rf(4Gi|1?-ucRbrh$G$FC8 zwOG>rzrt%qx0`%!s3PuS;xZh7LH{(_p{;w zZ=d_(1)Sf=WOn#?RfDx+4tEVDwoTKG>V9CgeY+>3h{;l28+r30cA8+V8kQLmlgPZD zlAshzG2USjG>S4REflEdKFR7hc(cWvcx8g{?Yx5C%BFZ>%00q!6v(ngANUhfT);7OY5! z*I*6mlUx?5>XRY(8_!W)rh7-vbYIrU?bNZH?5As-; zAO}#AG`UlKofccd1JEw3I0WJ=JdssR0b7k zw7~T_K-UZ*VNwDZo2xqLTGQ$}ACk;mmu9qOi8Cb^EZX)E{ekIX;eqVW5pCwY+r+qa zN5{0UWs-VWO~kokpG9$C=ufv5A!-@*+Bs=^#4G3?5r83U!(1GFpy#LuI%nYTT zVd~TYt;&==K?A(19Yb%6(5#3mRj`x{Q7`7ynHHC(eQ(7AGTgK6{5bUk#nFtW;A1nx zFT{I3y`=S-D?Hr~Mrc0Xa-N=tw~3~`J!rNEFnCg;+U^xep}IEn?;G#~B)HigHtRNK zY1&kENdbAUQErIUWHk?TB#OY&06?-aBPf5VOfJ|imBuAzlT5V7rCm^Op&)76p=>py zxx?ZHfIE~|Ni^NiXak3EV3qDYHE)f9zOBRR4}&C?sgwsBJ357&SS4sK=R;ixAN97d zNJl_7jn2MrQL+u5h&dnGgf+s$3Sr{GQnGRDLdeRD)E7r$|C>7ZY3T$T7mU{WwN0 zSrT>xq|q1eTa*0oD0zB@EM>pfc#gg%2(x7=|Zn?5zyD+H(z<3 zsbij*ScK!An2WV^Vu$9zsRFLx5)QY|$uvBAOL?mZVVfWW{}E3S4WL6PH><5Wsk6la zSxVnm(QSo#-8$mq|MryAEKzC92#ABvzMw}Drp;Yy^-9%c0?}fnV?Ko_1X{#m?3$0A z)r~!hWB+uM$;0WS*ld+DP8|dHIdgL`!@W)CZq5Fgv!f@3&tzNGa5o;KWP2t|r!7sE zhlcQbI{vvMoyF0(nkP#61)hHc0=LW4uTO@j$J>k3{OZk+YSTMQ({2u7#O1t3k=k+k z8*9>y0UlkTUPyaBzi|j(7poi3bhX-Pt{L0N==>r&8=k!yUKutqTVhA>&Xw3PcAoTa zdZcca>*b=eD49u;;{g|gwKjG&C(-^6>GxWbc3`QE%^MB2+TKQKzZs9|I{yQUw2R5+ zcJ0}(cjD_oL+=Y>sCCZkK~o0{X&58P;9Qh|xLJQY9NiAf|dOy#SoGWNzj&~9bc1PqUcA$U1xOnET>h&eA3EIARg zM}SGqC9=XJhg3lYx+!N=@q(uDBAPb_-CEw1FN_D61*Y_{j;dm2tJ!NRxg7MH{2i0= zyL!Ui^muyIT@Xks3c|88g?YokGXYDsp$fmO7Y`_vQIw>L%mkx+7U)LX#-YTMcP=xF zsj^(5VCf*fq;|&MC}+^PP(X=MPTsQE9NFd)XeI`>`e$m#hT4Pcg&i(554+mapab0f zIEBr3QbmAjwuS&grd-n`JL&5i(?{5;nU>M?b){;{$i2oQH@7VO!j*3{Ws@h+ShSWC%Y!X1mzv^=L z$Gc@nX-{xt`a*U%S{LdD58x)%iJs=gEND)sq3^yAbqogEjNpofuXKv~PKK$><{`06 z9@EXhRo&1--Qkg}+Md$@H-zdJvLq6A16$mZR@=j-T{U&n6gajt(z;32scBx8sw(m< zrlN{$v%k$b`3gjzvAI@E7MUM<<4+8>owWXeZ%yo9(P?UKVtfYXcIisiqRj76B6tME zyTj+n!*m^p3)dGLpC`d4>0@ z&1n+LVC7N}EHGxcF4tauT#p{IjI_lN1$SG~HE$-vZ3Dh6NyI=(-9AzPH$Kbh$XKtj z80J1%z;{`!x8GbAY+2EbF*HJI{cBdFMSmc!9SOW}ruo>tGA#ms&H}ZV2wUO{ufM$+ zMOVk~uD)#P=zbo`vmx$XD~VwzdI&%Ta{EoL>x=00ieMX+QujH$LewR$u19p` z04AVeTV76+iL*sh{Piv%o4O_5IjNOV2z~NrhP64I4IZ20b7%`&QN#2!OBTV7)S!p{ zQiHqNXA}&eo<5i09eu+977MxhLKM%xMrp^I=b|ylp*|w z9loi0X_ROAYzCX!bN!uV4G;D?ZhGkE*n_vmcugE2<3HUI;!*^Yzb46F`)2gM9kypy zeY#N7TkqPnUHd*sb@NWy5q8k053B{#u^bt%c@d^6xl?#q#{88JN;D=0((Cg9yl6na z5|j_G_0}cDJO>PmI6Q2$E%iRAYw~EEYrBtAZ*$JXyCK-p0@gEkOr_ym%+&dYjqYXC z3+K=}-sHqH(8p8V6Yf{lxxSoXsWsKYUv+bvZ$I4M-+yjGJKR1iW!0qNN4&R|5QzUS zc+qei`i7dXrfrMXwctrd8+d9oXved65(^#Z#mh_|e+Kta))eFHI`$kVRguayJ4UPX$NUs^ z6NjR|qN6W%>)Toc9kZ>u?j~QXDew$K4Sjj4nVTvH9**+BHV)w;H#!KzMLdrwo{%|F zHdh(u;@WDiWip_u*Pi}n4|&d#v{t2#CUK;yaVm&M=mKng`Z~|MbU?@DAI{eBkp2vf z_xlwt1fKVzBp?Nl@KkS;BR)_|EZ~VXS%XJ{-DECtcO*9O&<&o_Hi0?pE;%=({3O%y zBViNO_xW@de>umC1y8#zIAImSJ03(yPqNtrXWR&|?MYKTDxlzsb`C8AeqfLEvEJZu zIH0bN@Yg!ipc!w*fo7atbQ;|2$E}l1PD3#@VS4gqcKi3b5}g^~hA`12m8HqXjFH6H zW*}`Mw>=9zNaugr`|{>Cjx5jr^(m^Mi7?=SV5y^LVuIUr4^7cBZIRTHl-&+l0Th9v zSQZGtC_og~>+in#uFRKL0i@)scM#ncN#uRJeCO{mErJDrt!u1!6>(nz($vMv*%^ZL zLw9CaPdZ8h47HzM9Xx-L4HCq8xMC+?e0{LrstbjW8NtBszncmkgioaFSkRrX zsBDV`2ubUmBNI`|4j0&n;2&#Q_n+^boo%(Mrqh49m1MnKwe0lpz6W2!!1O|!1{76` zt@ZGU+{L7Zb9_fv13{O{d$I6VCstOqL^3jIPB-V<5a1;SQfLc@hKBHSxr*X(FlhOaeMATQYt zTFi;IjNlN=xhwPpa4tXf3?SMJ{j5M^1pHOMP`(06D}-z|y*>os2FF7?fMU@{!b6ig zjFn^X0quLDHnQ=s4JszNrHwHIYLR0eRm!VZ38YhHo*c)SLOZqlOPEK9@=stI-` ziF+q4Sxy_z3B8&eT>x(@0CT`yG!)Y8`tq&1`I%Pql4)0}U2jPK zQZ?%_I!ejPp1eGN@$!85?a85~%L9PqGGp(oSC^0eOOY?EZ%#VWWFum9xv@5vUV#T= z>VJBAaymRcIDdJ1JUlu$JKKA9(AwR@;{4?4NdX&CoPqv=+7nBlb+X1k|06$A{JwXW zL$gm{3V}O0=sHNW!O4-}mT4m(nn@iIuwhlCD6)4iAg_DTwOpbx3$#e)4lz288nR)l z{T4KiS|??>$H~Y+VK2j)U@U#P-jTYd~`9LB;!O2Dx<$DN8z6>LZ5mp-)n27;TZ# zBxQBM`&J!j^+!l&=4J(_M!16406MYSL8oN`r&cQliLnDB4Kd&gRqhLvZ$Pz6!K!a( z^Bd9yVWg=*4<8jM-GScuy&I5-xgY`zP_#|b&2{2c?RL^j}0$GkB|R2!T6>( zpgg#!fMF`SN~=f!Vv~A|Lq`~v>iuF`kLwl7ht7qO7vbX8nsOwLL;)|d#(%Vp-aF<> zHz@Oia_syV77{Ghz#&y5@hVW0D^dfJwFrwg9FA;ffIr}2GriQS;m~5Yl3OWwLld|3 zDL6Zue?DcmXQgppFU%PHAwhuC^s+aw&4$%oZBuSwgBG1ryI+7t0S`6);NJ1HJHGhg z2hXVPY6y-Gh=*$aIE2WE;ZHVPyshQXU4NUF2}@eFY(*T`ZU9wI0}j_^<^W?KUlhLGxilmy>N;QIY$uf#rIJbyY1y1sEcb_X^n)__3>whvl;R zpa*qF4AglTZ)(_IzmZLA-Sr+H9HP5~q0{!}TPWd$;`kn+0AK`R81A|8vXZyBSSE+l zx>;?$ufgc=cwsrD0{-!206hkOaUT}nNl?ldW(n78B`!G)1q?S9P{Q=A2q7x#g4TWn zcZ2%PjEw--gLo^w(SQS*{=J3$O;&#Oy7h3{RT%&RZZ|{EZJ{vUH(Lsw-}OOu=I>cz zp#DUX)_i+@`Gs?+f!(=w7BMI8I}{rQt4$sf+AxF?a!?`kCAs5(Q3Dm;^vcd>m-QQP z`YMZ&at0c;oa<;i`hqOE8$fmUF;Zq%Lv;5G<+_%DPr+|-QOHuke<0v?0fVVP`7i{3 z+Ef!wGldEO@WJK}j!U>U-0$*sXrxV3ANZeyF+^K&xV94n@ek@;37Y{QF{RN;kh7Ek zZW``}!S1dk*wzK$R^|kk-WU`xGXsx&vLtt7U{{OL+1c}bn%@ZJHLvoB44YT;W@YU? zBZB;(P%|DAdIX^^#H-#VMN^M7>fA`z@ckSxVz|9keMyqgiVQk%UV<}$u}p@+gXnB= zKU0eo8HCdzk4VYCf{3Q}j*2cEf*L zg>DsA0({mBxN6wnZE;nj?U}4UjTlyx`zPA_-?}o<$eLn`*dCW=qfsbJuPR>C2vEdr zl>-#MxK(@WcNenx;z)Xn&ut)^SP3W^3K0l@Ftkhl;&H(cBmh2=#09a78aR4}w`_R1 z77n%LgwZeT;Ui!A;KcAjV}Qve=%Ng@M(r4gA<(KX4>c>;ZQex>bOSz4I+AASI@5d_ z=P#sI#*xWrljJ$4zjbMynZ?5hE6 zFDNQ1->7;8FA7-}L50$!e4^AG6|vuucy%J%hDv&8Q7+SUIjI6*AsiB{rlYL!VrJeD zjdD`K)v=&D4d5scx`Ee|au)1*4&hc4f!;Z~nI0Svz6-%q!S_O{u#^v#;>GekH_W&{ zFCHVTNW)Ms7&8p4%NSi3%gdG0nZn1>Y4KQE&CS*D6AM_M2ODQ+Cz8Cr{JP?Z#Dy7E z>LKz3`n8}&mJutwvEhp*_!WX`#%LEIT7}fnFE}!6@;vU!Lm4oU6zyxHf?YlP%s^H9 zOQH}_yj?{?Q-zpJ?Rwy`4C!B-P{tXynzw2LzCk`{z1olITB^K7#1s5yYmf(wH&NgC;eZJd+4By?H z9%m!lEZF}Qc{4jd(G1)Fd-2%6D|c$Qb4Rm(gb7rU7Z4RGBvj7wbCftjcelC{e^4|C z5?r_$#9-~Xx zASo~G#-j(8;m$;R@gS+rN2p?o8OH`gK;3zKRX5ps}^yp8MZA;-91sGoZS{9FW=vY zT_Zy2qI5CSyg&rzYNM@ZhSWlEx8EJgR~J=C24S^B9-0Jf5&zNMFBn;*$TTZcUm}Tz z#mC|YuII3WKB+y+Mlx&qM1V-?TfoDbTZ~gD*%W6L#Vp{Fn_Ev34*RSxy~Z@0tTpFJ z?osr0ye!9+!pKLQ?Rk8oTv^BngLtFmw7~p*a#$FIbWDvqb2++dei%cKY#=dXe>8^P zJ{IO%`*WsF+*J7J9nA zzP`m=vA*;Z*`1=qo>HnDV+c1Jr;vd%_hwgRU=Xm52m1ny$Fsq-O1;#t!e5D`}$#fjEG~w;eND^I9h)Gz+rAnpkyD*H7 zVIY|3quIeS859nq1ZjS5n*6e{)y-ACg^u1EC;JHBmaPR#!ip=l&YWA9FV#vM;QMx$BoA& z2W<>`qIk@+*Nr?mn2Of|!@BOZ0zNf@j7~nh2D&!?sf~3{^pw@aQ__=N^0euy_deqo z%ddan2@D_bZ%SCCBszG>zL3Rqj`7$LI*~GHDw9>p&ZK(O=iGs9*|39yt3e)Hj?B zH9Dv9Ij`Tr*$y<}36Kz(r?`jPL}$ScAzBEkWF3%(s8?B55hNvr@oODm^oWJ8B5l_W znlCAMIqA~ir5?HVy|FGTr$WRrOzX>vxG}aQNjeHNl?YJOs&+O}*)~V^xpbQHCvHV> z{kOW8gJ`{b@uAT`Ki(GsE8uYG)kz9TN%BAmmBc8i6ugY=X?${`K_zD7PM3akiZG|v{sy>N~)e0t4W zma??K=?guhJkH+vN1*|2f(Q#J#=Bn0R~o|D5;LRS@;bSxID-tDr^p(EIt3Z2^`CKl zSVtN%BAAiK`HH5phbf#g`*afHtI4a4UO_!}j8|2We zrnK|RKtUtHLVmrXJ74nP5V(95I97&w20(#jS$>p!RczirTw;}M#vU3DdbcsD&IRe) zqoNDb<`Q&H)L^0$)1zg2v!-l8355s6I5l&N0pXpB5p6z}qRsym=%qr!(WEmQ;;+WX zXES;!ZH8uQxu4@~hPX1xry4#CemJ}=SLKv|XS>p7FYYzc6%vIDnhf=f@Zlv(F(EB% zL`eL>Znp@vq8K?XO{7)|R4(D`qAIHxCDZl*t(pA*mMF0nx@}L0%`O&}j&_(Z{NGyi&<(v;Q0f~8oXx-gf+^LWwXt+KYlzPIjFIaWnVge-RPffHeZIB3GkKA z<_WZpSjAd9`V-ufHfEl$N()cSjb&^dQAO36x4jY|nA+q}vnXkOXQU2?1B%MjT|c(` zT}o3(o5*bWC*7%|GK%RR~v>nZ+Wb4tSv2`|L_H1&$Ispkjf#MVaK{N?522=^Hv z^^uDO1iI2u%h;&N!aS|mzCr)zl z^Z&q#dF$(%Qp$s=>D{~r1{P&M*A*^Wn>zjvsO&8-hrs&5!KSCw@JrAhzE!&GzH7Gk z>hWvyzd%h(hs?G^|8%*1g*6F|2Ir>-dqqR*)l#1|5-z+C#JTSjc^6DjsTT%3AGday$OY{3rA21#smc=-m8(tm>9wB7$vWM zT2aSmnA`4VFI-9iVvzz=Ly%?)lHGPY%Zm^jCuh}(OAR$hOImo^7F-5TK!_XD*{5(O z4+};UPx0ssrr?lj)DUuLHSoD^-axWWK&C;tB#TNDhHSy3*NM#_9d?=P{`~ltou8ll z(z6&%xs7XI2s6GM?2o6HH|2~_pgGmt4#n1h%$uVerUI?uHJpr&nc)vJGsPcfcgLYh zlfBk4cvbv0vDkYFS?4+s#X0NC(8?J#=V`HNUUcc2b4AC@*TZ6u^FuG@H3TU@NI$1^ zT~wTOoRvPvELY_$(MUi=nmIItg^T%mY3B2;Y6j`mQnF}H3T%Z%N{en^B0^Mz!<4yu z2{ENl!gNY!st|bs%Bd7uOs5>tIlN@?obh@nV|3#|fz$XbN_CRH*4$1=WGx=KB9m?2 zRu2KmYo$ERFRDv`USK45p%)D1&O{jVNz=r1TOkxqY6KA=0yWuIAZQ`aAh^fw*EP4K zBG>O}ZgDmi-PnE1Ey~8Cn`$+;+|_s*SWRUFZYMr;a+1lI?*(=+)pV>+lZxO4K;j~BAazuh4&^Vb}xTWJww0O3FEBxDz6l z;66R0oEAY_+#B;UDxOUNLpW^Q)clt9qT2&~wobgTVdG^<@isfJIq*4#F^19|>Cmai zk8)~hD}SYPYa36V+|z^ogTrqRoRhxFmlu2J3%n27#4$5Nw;ogj$%5GPy~Uje;fT^I z%;q=tK|^9Bcs1xAPj2DBWPM3cIPh0V?Wf)sOnYd)k_?4KA$m{&RCaUNjMYX-0HuMY za;pqsvcpgk1lHNb2K+B`hvDT-5VxB`Afn@MqkAa8i*++FyibFpCpSsa z5ry1~bl=)cU6~$4%lku4j>L1c!BO}yn(@rlnTLP8?}BS$zk@%C^abk4&fv#+9XKew ziV}Xf(}hW~^}Uioyp9)K_oZg4Z3n#AZ9RPZR+!o7xYpVzQgxoSdX_JtE^rk$e zY)&$?ip!!vk}D2{^Ch6<8uq=FrVUgF=~H#c;1lM2Ap>4()|b@#85Hq0^*#w%`U%r~ zTZR^|htN?*$@}nm2sUQcZ}GvfiQW+wAxKnVsInw-v~A?GY3n&;U!#HInJj~$h6u;m zQda<#lr7y!-l>5RkXGgRsv95qbVx3oFU!z{=lmm{CI~?P9nMm4n6^1Uk^i6G-Bf0n zKkA0-L3%O@u@LPt%GWA;Hp+GsOA0_55XFmR$wFo>*|A4PKzN+3E+pHC3JI4HTaDA9 zPr$xmy0c!m0V3eHM+|22%*gvnNi5OA-yY?CMI;fQ2hvCqh4V)dznUt5mbI;?CPTMp z3qfXhCrM`Lq3UfCLUO2kL)q5d;^$w2udq1gZ6YF|LteN9WUL!5=8N^T%sMh;g%2qtb2tFCe~joH;eIjH#Mlf zxe>d;Ok$QalP?9i#9S_w@fh0HS%=`0N}6K!MP*DeyU8#MnxAEw6-c7FzSA=erWuCM zxpaL&FP@ke?Dzi&&AFBS5UCmV12RN-ilZQ&5fws)=@*V9&_F(f3!0;W8=}AvEy6D2? zLEUZhj~+sZ0ood7DX(Xk-`-p>-0gZc2^=xGwW1KVpWuaHGC)y+Vt?%wd&Rkt9!P`9 zcNarBxkAsSRHKwG*A)nXIOWwc)EgItpxhO3{t{Z#J+pW9#d<|OK8W^>Y+U~95$Q9! z;S-fmCv+kH0lw>d%!@C7Vu3DkO<@DjkF=3dS40mP5UfmdxC9Ai%3nZ33Cx{kAj5v> zlkG4(ZsyAdHp#4LGd7RKvYv-1*nsIY23fMEB!!dt8X?h*yiUo@H99$Gt{)__MaV9^ zf_c3*%coS+v4#TGxO6iEs@P9b1W$qjI*6yBZfru`G;#38q{+jiPJjYY9bqJB!EUUs zXeJ+5`pd(E=TA2#dG)t4)~L7jl9%M<<4bfQceXjC%N?REMLWOa^wmh*pbj7KM954` zi`I*YX|JmXF)+_=z3Re}h3waTHqwD5eh0VNy${{FJMY@Z*)b7&a+~87 z23CT|;*2WB*Kqsi)8B@>&~`oc`_g#Vb1d9B?_$e%)BBuc^uqIsWMqrH(p2u*!Nz3W z+nrNo!}~Bb(%p`(HMP5tvEvhS3y;1yu-6tIhtQw()#KM;5GxrC7{=V?9)?e>i~qH+ zOUygT0Lj_YW92+RzMrf_fSj{05|ATAQn2Cn10OU9pzYdZbXYx+JHD4}3MW_}(0M`^I|%Tck6G>DXLTPYA502myz3q-?+fDhcr$YWvPHESK) z(02hN$N|ik;e0^=L}dda<}Q~O_s@J@lIvXGYmOZtR83{IsnLjPHeo$pft&e-!!cb+ zmK_ln$w|bIX&n9d?t0{*&zB-|_ls)qW?lzXtSb8lE1i=1@%h z19p1w<@1C6^TU(lE%;$5jO_5&d*AL2pYI(%8}2{fJ3~=!vys(1Zrd5*?{={CT?4NZ z=3|S`iznoPNbazYUpq$l?B)J`1Xs#i{!tHOa4b5ld|nU)67>rabWW38y@%|owY5C@ zfrBuBUhf`SansX~kl3bP`@&Y5X#=zI>`aNv@Rk;kK&X8khg)8n(xmc zDMk{C3I7EnytyY3$Oj0}$kkA`Fl-8bXH^tv;i~r~G{@Pl;GOuku5Ki%^PTybTL1j> z&x^TPKKRRkUG;h(^`h3v)O6wseQRqi&2G%ErK(jctylGQaEjk08dBd%IuMuLO$J#v z^1O*793zd<=-t4TPK#Rm;#L~#tAT%t*5Y?%js6PKHV(TuCluXsGJ*f2cSkuTMPk)_ z_JS7M(1pp|b&Cqd#qHmI#2L$Jf}|#Zkc5AupKT5Gx2_tr;jhiH9`Ig~Joto#1gK`S zf9&e1aydC;JVwfzf=6L&PZH@c382l+|B|CcRn}@6T18}*TB(|WK1h9_x3}{(*1X4b z02_jq^$Eqe!V^<2ZRUD6WM3Eg&B6CT7XL2Ce!f)a{qv5eH0G7asg3z-PjL*?l+>ma z`Om#56*Bx??lLb-m{}&L0k$3fp>@FbQvRCj_3nCDPeR~RfjN|~?;%?oDK6h&QW=ph9{WUn9I1T^&$tMw;@lw~ zAuxwrA&s;I90?r>PiPy6(Dsql1}~4lIR<=+O?m?vR$4XS9h~1xd!V)=wBE9o!1S^I z&EB(vx?%X92y`*N=Y>UH@I5#N57-X*YEb}u=` z+T08F#pr3sI=Ol34QVz9Q>Umo=}(xZ)1PQb(nS7lWY1e5{iSSZ#(iwa8y0`HyzN>M} zm4Y=qJos)qH=Zn&kA&~Gsl9sqI*Rr9lL*1>Rx|{nQGI2vo8hcMmfO;C?T_itoIkb+ zwJ}WBYN)^;Wo7{;|@nP!=WYFx%+W*zF|jQW3RbZ z!{f(gxU%*KD=ox2QlBs9WLjxtrv}cTXqf(e=x8d z77s(8N}&J>{RKw&MWBlI_>;m5vDdq@fXqB_n1T`z*yg z5CYIH%cLY2ta`-3!4zM|S~QHqqnqfDiTpLqkC0NCP^jFtcv~B&Hb+)7HkqHBuDZPu zo?Fl1hFY;DRdcGAb=m=ZwPiV}9?dT=?IA}a)U1XFp0v)8*<-%qWXX_jVBNq2-6Fn4 z+)60RvGnH10r}MVmiz_YHSnW+4|CADl)pd!%Dm4nPEY>%eIl_kK2bvNK{K;64hQ6QI@;%Ybhi^@RO3pIDYYX7V>TB@j6lmyqXcjXhrnr> zrL5P^3HY1d)Q#H922!ha2vksv5%D@?H*-f+CQIOvn-x(2k9-*}&Km@N0lQv-?+-St-!&NRo8S zv_-jCy-mx7aq8si>cv$I=p7QQ>+^YsZn(xBI&w?5T{o`$9Xd%qh}r&PDMt?=*%mPg z;lW7YP0NULXD;Ouoj<&4HKMd0D%lOA)@-%DejU5U(L`iEsil`MW9c+fN3-t=Wqb3# z?S=#|`{;n=l>08wbw+#S>bk#Wo*0)ZCpl{QzU$rVvud?o1Xw|S`;e?4n0ZTVJ+&8* zR%=)eW*eT;5Lc%qPe|@wQp4`sZZB$WtMi&@HF>kT8=BOr>GUxQG}X}&sW6obwm+`Lx^XZ`DLj(Gc%^o-Ir- ztBw#GEY`WqS=nYS{_^qpDF{M`$LGV7FTOt5KOY`G&2EXkEOt{8S1={Dg^2$gE8<^n z@&jEyo=MP$g>k9reYVO-mkUR^K4dAopqr>ADFVBbj zC&%9&oSv)HFPmNzzb!B1Z$p$~2m_d*sGIDmq4rL{Q3_`-p-WGm;f;a#bP8fs)W+^4 zFUZ-!safMY_xKKc3dJU$cp!47>)qH7D9#;KPZGrrpWY=>#*5)rN+4ZPp)6vX>QM`b zaXO0|IB1>S8X5O`z;R*c^>uYrPp5Sg=S<5aLZO#Was{OlpVz`av|t03{TV>?Dh)wOcVe^9S01qgv7pJGrhHSyNsO)GeHl8Ya+vtC~Q_q zF<*b>PbZW1#JSzK$vhrr!hQ7ut)le?<}LI@i1}5I(+*GUmg8tkC6eryIMH<4L!Bxj zRFLdriZ{ca+bnPLvcgRk-Ig`0C5+k`LeR?1dpqhVhT7l&ktMU5o9@5%_J~H}%d>5) zspk1IhS)9lb+>KHgC)q)LeMM0;{;5VdP;UX<8`x|U&F)I2@>An@`3HZ)I~U1WXRBm zzc$nm5>AoZbz>MVSD>0TYs1dDvnY7p?2>+6PS8gbZa!A9c2&>~PwVl@rDs;2>+k`C z(~Q7T&|%)4fM*CefKD4#IldD6-SvF3#w=OL3L@!|p%ZtB&Tim&UMjEx)YM*7rg%MH z_F@tb3%668EZ9Gvf**)*3g0>?p31CdHe0y2-|JcZ6F|gJCKP~LoYc*#o|)-G1LTt@ z?pov};`yn;r;8@c8OD1=j6i4GE`OCnNwt?M5cMq?= z@|aYE?fDb|F=LnDYSP`6~%-6}Qrf0W@i}rUHv%{Fb=%rx`d_MB}mbJ+8K@eT2b`3m-g22MaM;Yik zj=Z7tc0@N;rl}HE2yj!SL8a{sule5TX^?j@eG$1FUGqENaCcldDo#qGDM5w_?~2-N z|ACdLz}>iYTPL38zwiEnA0B4#+#k{LR2G^7$yfkikHAK+G8d`!?c9MA;`Kg`2mfm<+Q?ro!x+o^R3bczR3x z1f_&)GLKh=V-7bn%rv{L_zNDZ!JH!x6L8mpCt%PA;`HfoPR6dSnvH zdF*?{TPKcYhX20vQT_kuEas%zbO0I{tTDeVDie6s6Y}q!g~M1z`xigs_UvMA)B7J> zmn!@t9a`Ot=hFFGoMoNgJhaED!T5s7pYp_0F6ckUahF~Bu!C;6>a32mm2U_SLRY*G zs!2}_I;1ou^?}*LmV7WR?@8;9CE*S~)VxZkAUZDn;%!weGTw$hqnV-}Lv>YJXa*=( ztwvYfbR$BjTws!3ep$^I73MGX)UY;FSJB0gYrN;42tux`BL`!U&~l`vy8}%LK}nqh zip233&ulIp9Pr;;Q9f^T9;fT3eix<=+Gh0z;XpUm@J5ZLVV#rOw*+Bp%VQK;<(Pnk zI_Bh3di`(hmV;dkQVr&PSbV>p%w0=)8*-&xs!;>rrLSjn63$Pp5rbJ2ce~saoW7@EQNEgpH#;jL zDHGSZ>!A-elsDr6SJ0G>l$FhjFL^zCH-9SwA?F0NR@c`x_OqmKR`YpbWXoHgse6w# zWJAWq9zI$vLrsugCNXh%cJ%^xFPBx-D}or#pd+T^V})G?)Zp*vmVr!^%S7VDDNhQGUPQ@oj29$ zhflwp&pYO9+$wFu(XIJoAAVB*eo+$QM?}Ie5Ek2%z{p@3ADqJLay7m(<0{{ja5~8F zGh5Y=nY(^7D@}iuZ~r-;QzdfDSV;;Mj3tq)0k+K!{c>!+?Edg+CtfAD|H$Uj8L*8u zH6jja-MoG6BlTg!EDNp<{Xw>g4;A2Yn@%{609wO@0A5P~p$-Z8CSI?(<`POu{7Ov* zs=Ms6LvAkJ<8ltBq>;_sIS+!Oz!QtV&xUU2|4ciP?h(MgvH#%@Z zim{i|dTf2Y_06`InyJEn<7HKuc7RMw%k^x0WmPf(H_`OTo%u05B1_Z9)o!QLr(17- z141-kteV|dul?t9Qdh(I%&_>i4I+_fo?&C&#`P)@0-{S8;fm?iq%^d9R{x)BDBm;P zLMli@fgghhoUNFktzRtXz}OAnEa&TmjcFmTHE_I6@s-bUae^2$vFRbo5H(7XPXdp{ zxK-|mD&A3^*uO`hQv>z&=Od2x02V+8r_`Ffo>f(`b7@AnGxB#a)w*B8bO6K_JPBOe zGYZ569+}8HzUj~b8KNUfYfc7bxkS^X88LMdGZb)9$Qgr*1~;TxRAajPh%E(D_2R0! z2AiamM*=gc_#?Q8hfbbA6s%`x@tqhY@DAC|JA7aG|XOaV8W&mO^( zJt>zHBV)3&J>=QA6vhO{6((c@+u)Av{L|BfUFCR_Y{%YU2%q}EWkM8Y`O@IJow|c)?mNu?v1>qcolY(t^Lj#iQo>?#O z7PR~6l3I^!!I67Q{L2c={zwKj8uVfsZ^=m{O5xz-_{*)tcf-^|9xq_8%+~QHN@-}f zp*$+Z+Zr|_TCsq7?LfoJia=lJ8<#y6OZcz1xdDi0*<}h52n|>RQ}F4K$70?LCe^!{ zc`*kr77L>^7oZ{V+PDPm#-wV-%X)zjWTkmnD}39RNSVxUgW@|_7ci~hNrZqII`OMI zFp<-3Y9u>uEIkIq`lBz1Y4gvuAn`keN7b<;5E{F_Mivy@yexC$Kb+Fa1YZ^`LXK?J zI}BkJPv77ON1!rrrh(lM;i;UFxECx-Q*>#u`@Bu9Pjd8%_loQK%@yU6qyWQ3wVd9P zWIG8sSoS`KvTMTFEASmP@9;G-(DjVKL4X*6b`|y?^G1bR5zeBIe#-P}O-od8;-j9V zd_y9kd<6SUi^Vzc-w+}HtCORH&o1iOXFp>2q3^EeiXeAMq0F3-2v4MTHYDaGrsWy_ zB;aG={?yK!PBls0gGuX01-YsL_i}+a z4<0;tSnTgTfBwbZ{x`+hi-Y~cFAw(xy(I7`4FOC*-x631k(+21<`u3PLD&2~t|+P; z-Lp}X?bNRi?65Faos*JT%EdUbTIY>4b6Fuf9_kG3f-xA@lJehbw;^S!SdcPZ63P(E za7RoDa-j`%26sVb7Ca{9%Nx`=D^yUC$~o6K!B|T%7>VjiW1_>>xDxA zgeUl0&nFdndy7Y@btq_vvs$967j+Isrk{GK+yq*9YZc7nat(3}oUln%PRW;k;LmXi zGuSOoVJ4x8y>5nL^)p5MG^4rz@OeY8lqKa9Rg!Qaa8XpjrO<{krG#M%909A?mY4(p zkpf^w01;i!y1jv>8Rg{=-GX=7?_>bcaOOCngM~hZeRrB6*{u0raNOufQ7f)F`lrj? z^ySZIC;Q(F&(2Q|_Kx(jf^VleEko|J8@KFAr(IRN!UH4DO}g1>gqW6pd$d|Fm`|t| z@BXu2{MG-9k4gedBsa_N)~q4?1bX>o+z$9vR{-NNbaxzD(~Fua&7=g2K5<1u?;9{a z1YF~F1DSI=d+ftReoIRWO~0~m;5Nq2y*Zs!3igh;v;6${mz|%V{L=ID86{TRYlPi+ z?EUSU0p{@kbbu9<>lc*=fN)gaUQ~iU3Na)N1(BHZUzM}u3h4eGnYn^#};JPiYUouB+ASZ2=I3Hw?9EGBrMS?=gIYnwZaL% z6WgL$Cr9v}T17VMu1C<+h{&qwMy_*B_JgDGAF~z-$B9$oiuPD41 zA#lb(hX`mS;78@4jRwV@ZYmU?G(;i$MZn-hH3jkfU_(L-;7IR&KShS_RJX?cZ|Raw zj+F9($6kaRA8^`s=W1wK%!1bjcg-wq(izBm3}|D!|B)nsX+Or!8NH(JOqks_gtjw? ziVD%a23NjG1sqc1)%0JaUN0Op6{T&(IE*yzMxTo!rR;g9umA#vhaWO!#Ut<^HhOhH zatL=KmEo_PfQ_96P#jyBs0Vk4;O_434#C~s-GX}{!6jI54IbPzxI4jZaCaT_k*&LL z@2!2i_tjKSbx%!w{nu$ZJ^wlLf3H7`0C*fPp>f)rq@NK2Xl|62$$!Tp28~|64V&Vp z*|#T%vtR__J&XPhZ{BO1fEG&nz|@coRivDNTWmM$gS?x46J?SPmv;&MEX^~Y0@|v* zw7>QSeVzt*pe@KceXZ~eTVE;8rTAVEW}`JWN4Ux`3RSiCH7^XG2)B zjY3kpA>Zuuv7Q6HEPg!=_*}F=zK%z+&xkLc5nSkj(L|3<)kySns|N&^$^5}p%9S}A zCJlH-j2#!=JCf9Xhi+Cf%3$9gut%ha70r=!M5_5jY21EQpQ0NMhMuXps+bXLIMz#- z{A$jd6^u50ZFTgxuo+D3eeP3-Z&yL(E^<-L07fkKFRz*T!C(mfbN$O6* zn2&ugwZDQ}cS#-xi_xkA4$a+InG(d{z?2y2JSYcV#+A8DP-BqS^A3hd{s`3cA`<6xs7OA?@TER#2})$U>+fW=4P zM(db#zMpsI?RO&9joyWQfUGgT8m47OlPIci)@uQh+N<4wGzj39)rCq^q5-bJIzby5 zwo}hSlna$}WlV%^1;95{(ZmKsp6pS)P5I?7E83uwP$8Ct{Nc@YHZ8(_S^Gf&mPo)~n zX#unMoKIlY?gvj1TI#h>HW~C{$x@-xAjt^mVuSN70t23b7RmE4B@O`|R-I4Q3^lcg z%y35ps4ATcz75w0!^b6XLVD|nGdl(ZZjCwmdoI#WbMhjsrF!})xAl1>-&cCvp67F% z$En>q`wx;6zkl@}e~kQb!-;f=4d>^tg=-wn0p8-bhb^pXIp}-dx5sj!#`hhg?L0np zgzvk^-bhZ)56rR)lftqr!I&m&rP&kk$uDtKUU(3ls4wmSFcAdy2*`nWKN;Lq>&0xX z;5o2Cn!Amit@b_JZ0y}kf+wxz0GXX$Y?+(G6tI%)0l1g zj3WHD{K*kh7F3$6YfO!39cE{oyMnzVUu5czvt{_0u9Rn7y>azGmtm|6?9}m;qREt^ zWmJw#y|l`msLLoqXn+lBL{{gAdhOY#3)zj^)vxwrWt~CC%bz#J0_Gk`O)DK9)wsL4 z`buNK^aTUHJPBkQ-mBR%%*j9HUp}?X)gGF7A5(HRvZVJp_4D$%*PY8?0Z?=VNrW<3 zC9j1D@B9b!3BV`@aIKlaXGcCj>%)%IiCPoqw|%)v<|s#J#O+c^#LFE8Q_bZStzC;s z!}=vTB*=I1wlgQgE}()BgfS%jiOtG;VNrtp)ck#zSj;ajo}}bj5t5Vw@-Cp=^&YxD z$g|z{BoYSeEkBpTbE_>CeCBjjtkb+Cs~+|io!dtluhqty>U9shdGNx*s@kUf34wi# zVrBwbQVNPP$;KFNeIfZHxh(`ZTRg3=_HN1pR;K_5qCI#pQzff?l!`Yi-)SvDIS;ta zb|g~+P|A%TD8ioiTu5FkofOR|W+Y2bzOWGPzu>I=xg;tT2Mie++7jGYf*BU9+QEc} zBu?8q*kpde8mnW8>lgtjhHGP{8mCoB>DwhGlY`_ybQ^^hu6X5(@@lMh1$=B4u5tAF z?ddjHBKd$0AIuZ2Wn1skhqRJL6ijO?j!Hhi^$LWK}qJ8$i~+T zhOH5QDJO7)F}sgXYf+sivq&vHKqg`=cW#;$vW`oAyeQ0DZ>_AuUXS8*Jt8Z;tjc|M zuWYr>raY(m)Q1c~|0nQgQ-h^m?Hw7X%+vtb)zk}imwL56(@nk}_jiHL?|SucecfD( zAs4@xM~aBd%)?2)(o6{F8X>xw?U=X;QhrG?PewHaj{e7(Xvik$o|(sgs=mE5H3_Q4 zlN^OuDy%d~W=U0$ohb+m^y*}UGTF)NNBUUyK0$)rBeVDkOJRu4ydpYajV1IA<)1}hOCSZRG|&J5Ri3{oM4VlnJ~^2= z+5bxc*arUh0OelJ3-2fcpn-RsbzGTde*cpa(7i8^F+cie8YbHca_S@+9 z-(SeqaT1$aNZh+LQidf|3?Ch=66O4wm+^yf9-s_f`ovPo+@XI9#MtEda&{hO+LjeAkkhy{qk5zk2uW53j z&8mxe^YC+0^b$YlF*kakiU3~tPC8!kE-`(6w2T0;fH`m%VKdNJ_2QH__S)yc5@#JT z5M?kC%3Cx$OZaX7+9===1QOmI^CGj#J}5QXDZAXU6~JX)@A<($0W=?&o5oU01nxSs z24HBjD5Z<^!XYai^g&5|23dRIhZbn{;>NyS!XUn8LF#-71HT7CCib5=c=1C8^c?md z8`r}Qeuvd-nXu+d>+162gLAT)T`_L-vg?(ogkXUBg|zYY+W3wtb!ql4^LPxq>kZZ! zN^5%q4LAw5zki#7zt7f;w{E~J_Fz!SB-A{AVD5zhq7(Uw&9fBnvjm(GKUx1YY^H>b@>Nc>63EZi6=VHOhvum{HjW zLA$v#unM%{Y_jo@8J?>bt52Ugb-ak+iF%<0k4Z$`XGdxAdd(VhZ2!Ssf|v2k5(8*0 z6jlKQZcItMDeeJ-$_|v1ZAEFC5Ln>*aakpE!CFums&yGu#kAhI;kqv@pGnA$H?4i# z*h6!gOi?}EpsmQ+A&ZO|e*1h#Biglrk1hWsC9a$-R%*{7X0Npke5~bc=%(e-F_!Tp z#|*$AkqQEpxApJdG|84sC}WgK&^Sttr$JhE(NyhDJ4|5npx~!XENiP}`eNyCG+KJ_ zkpFDum2`nuJjigl(6hUOP=x*!_jyIp@{Y%!*HO-7ys$bIsyAbPm0({4M*P$44}qU# zg^1liIUc*~wy615hmJ51T>K26g{U{}4E&>Q?y>O4z69 zxP1uP$bm5s`2s!TdWKoYaXDv#)iXV+Zh+6s;CkH$NNz5}*F+S$^Z;i2p21 z>ou9|VW=!tv^<6q;Ri^S=^}`|Yy^=JR00^UT&T$)q3aWaavKpJI0Y@ck3zS@E+SOk z0|MVRGg?fLJd_e4wsW7kg@f^`54?AvsJ!}vTU;9^INDN7y=S&_vLU(M#wxLD&uXyK z;2amohE7hQ+#^?^n2suyzY8HDAGbH#T~ZC(grt9?@T}l5Z!hiYi53)rN~`g?#fC|^ zFWOu0*w;WHs&5ha?rgcwhBRJRZN9WzVP)A(6T~O0_<_OnQm|JBx)8!9VIs;4U;N8C z63?;n^~9o39$H9xx*%#Osj?Tga=zK1iF+s>p}!2A#eJ;(6~$%;KqydiADkb z`Yg(rV-l8EB9`~43m-u0vC(T<6Gut@eDf@PV;d~mjR{`lJpLsOXG}k?%2ak(_cxYk zX>z0Nkb!%N_X+=I)f~)+aEu;`Tkm^c9~;ugSXN?^T}S&T|h$MG6~7YykGQ8 zjDD}9)!8m>|1m$F%3{fr_7s)^GsLy#5VA=3=q!Fgh$unP?u$J%F1;6=Z%{jn*}Zca zf>Z(F0sIZ|bi(GIqc?HH^KS~WQgYpUWrytNtHp2AUD>V{?O#5?!;HNV!%NepM4}mE z5#&)KR^Rq~zE84`rd-sEtc%gBGub+Jl$j%|L-U84Tu8Cv@J))@>|^cwvF|`T*X__B zmi2+%;rc%d=+JfhVhi?q!F z8`Wrio^|GMr%5E7PG7(75!F~c?R$@2!hr%`a;7&zEcOVN1G?ju&g+Dn&5 z$ZEAdc2!?yhiF+&OK)5?Ajlj}@%8z4$x<=C>0u4(f@n}e0ZUes&RTPIlo zyXnjlRi#KS3AXcxB&UK|WLoJ)d_P<~5v^dS7M|b?&b3qYu+k$IXsg8Np&niNMSPBd zbZOTf9=-w&?4MoX&rxeOnpz|p43mD_5)Dl*6j5_b6>5!jMNfgtowD8K_;3hcz#fD% zq8$u>S6R*~XM`6I$<&6b*>WB0XiNe%Z#*Vsn=o$X`!Whui>D^tIt^@fVRg4eMP`+OeZw6%&fc%@AXeR}z|8#iZ2*!bKPDBc z#XQb(7lO`7j{Z(0&Lj!=g@?zly;<5;VCM$(<#zZW5iR)Z(3Zc#Lk07rV^W$}WlsWD zd27Cc=jsSI9b^RF6uby{u_AZlPcO)r+|^*HmaHq;2wClx_VlLwj71l6=R9LcmF3m8 zmftB2B6$)N1AWIxM0B&52ym%b-ibZs6$gdd2Z=V&77ozxG;)3tq;908IN7jR_Py{t zKzYHRsICWg`m)g|w5qM|^7;l^SjA~{ibY& zI-@BjQ4whJhb-cAp6P#XyS|^N5RPM6Ir-Art8q>_<2ea^o)I=zITZ{OLp;j5QnC2# zY$R$%H}=k8Q@_14KBxnK#jMf$_51>14x;&>n|NtIr}G*1Jn}aEo@>Ks7qkL8HNIf& z_DF7}W<*=H%{ z?G?KBJ`*}{j!@kqI7ws1^U;S&sO3?Rp3xhOFT*@eFb|{H1XX(!6MZ`gnj9bi?=qLR=vB^U@gxYKf3~K^d{}RhHAOeM&hr-efEkJ zU6dz!f0tvA*4?G;HDstF@1?0oDg%OHiz|_>(praSTHxdf5!#`wmdn%yvX zM0#~NgXu&F>vsRFeXJt5h~(99NN1!XB41O*(3SzNVJ!}u;8M9L6YQ_x+S2{-1qs(| zPMgD|LgtcUB1XIE1bR$FjUy7cd4BxjLVGYIp$hg2kOUhr49dtFnGequNSzR9F-+vg z11}He6F53dpgMVdNHMvc#1z+-9nN?P4X(3R+i%E2$~Q61cEJWJ{yG%3Vre#|O|I_g zHg_Kd==*1D#dk8qsfd;?iW{+BhC=hz2`QorHX}IDbnzGwF8T1PG5KT`k?IW514D;I zQgCQfp0%yzM_cK%Z8^Ogx#Q=jKV|mxb99uFzE6SonNitK?KP^F3iGbc|SZRHR=#<+vtZ*0nhXu+I;}F!O;IF>50iCbB)f`PJzo%p}D>0V5 zvU8t$6C(FI!S+L(`-1n-O5s%!;!z9vz(uB3i`pqqp ztJT#Zul69(K$S=C1--E1^-C-#$qY+DHXl}0Jscj4^im@*Ug=SMovG4ugt{%Vt7?Fc ztqI+-IoqOI{ck?Kt1}mX$052CWx$hZ2;S`i*HAA^DiZfG9vYu%yfM{8c@@Xw_r~S! zeWR6!Pz@CC%S(Pwd0DFYv_eC*nvJ?8p+R4t4XTWdUnX{uPC%$)E7|-re7`mvw*1*Z zN(j)PKrPr;1M6*EX{3twlna@MCV{!iA>vw1MBq8Mrm-AaW8Y%>Haha!madnAf-!9= zllMJAyDiNgs4+z|*?8f@5p#7=#{oGs8}5^_dnV8bRL^^o>;ms5z;WJ9tL3IYnt>)b zImDuz*0vp&d&r7i zR4&MsL|yWZyJ@SRwMwoE51W}3Mib1Yq}EC0Ci4yEztS^*{OaN@#10Km008WJiU#FB zRjO__R__Ane>137Xeq?4aG`b_YoQKH>hi1;`OqLfOpz(nh@=A)T;ai)8CAII&`T54 zxnA%6#lK-%@DG2eU`mmlSjgb+op!t1+Jh>0C!aIv6>1N>-NM$T%@*{%D|Q@EWbxdI zMUo2q7E>L~PX7cps!4=-&Tzc5?(C+E-T$=R&gZ-zsv@s@Cz*-?Hyn$<$8iG(t3#c{ z(&Xdxr8|NKz*((dFLf_Z#~;4Tx;P6h+=zWvl3_DT7z5n-v>H%x9A>(VuP7X7xX>Cl z(QMbu7oJapH(i@89G+!A;-T)|K#N1PM0E>d6O?L33r-R>?pV>8ZfG-TFug{9F)pQJ zi?Gh4S#S5hemXm&T%nqEsz+a>If~UAUH6{)*zp$a9AcAuf!|b%9A10)V^VS)%boyM zEabbc48U3kSF{0ZyUZk}9*`^K8WHh zTZJvEH=DPCz;#?I0HTmSYvSbDqF1NOju7A%I6#tZ!f*FZ4X5*$Vx30QNrsC1UmT(A zEV@%qLdtsh`4|OAvAK(kD>@Tq;h*f?-I1)htg2l(;RiCjp73}4I_1@o;RnmTAsh4I zk-VWs31@E0QlB~XprjNsgMKma%q2vZJE1SHzs>BJ?XZ5Zfzk);M}YDVX|K{3q%S=B z@5pmiXc!Iobr1s8{ce+BpwpN*?|cCTD-2*=_RL}qx9!LCa!q?7dUuzz1acM0pTUSn zGLmYO8mliMRQ$qd2?8TbUdpuZl z;Zm?whIT&eHc}i&}Ec-vdsFpsc4KYsqCQla%AEfvG}Yn9rD1Xb2=CWdbt++d-2 zx7kmLCEP6>1+Z6bl1aQ}YH?tG2;V!AD1G@+S6l^3wzhrwJqYzh4j*)v2}28t1hwR@ z*ghz7QV;G6H+&uMAumfT>6be%0b@Y$pcRd?cOz|FlY(K6^YKVujMygUqi33A8n2v= zNBgJ)z2@dtQr6<1tMW=xfw}CyHj3)@F zL~*5}*3jCjK4)S7a|Yx1`?;3j&}$>9Om01=8Q=*1o?=!4qEK4^VtaWMbl`h zqNiCTkDlo{0oV5M`rlY+!26diPe+Bz<_lJMCc>@4{3~iDY38vWZ>wlL2d-wOSaCLhZX#Y!Y z-pJ^$?-?192$E2(g*xq5J z)_`}j+~29N(0@_Q>}@O@Ke<`BzK8jr8MkdGBy-eEuuce`$n2*ni5||K14DRDWUr zp=$s4ru{?xr(6B+)U5xc{y$CY|4#q!O2EIi!Ph%lixZ_iONfiWKVoR0jay-(NQG LCzn+IpRNA`+u!Y8 diff --git a/python/lib/py4j-0.10.8.1-src.zip b/python/lib/py4j-0.10.8.1-src.zip new file mode 100644 index 0000000000000000000000000000000000000000..1b5dede8f2d627760e571cee6ac08b472eb31d34 GIT binary patch literal 41255 zcmZs?Q*bXpv+n(mZQHhO+jg>I+qUhj*tTukHdbsW-@8xMKKoRC`(iHUX6m_^>7MT2 z(~2@6pr`-<01`kM+bl_H2(m8spPl`G79s!_;NZz>#b97yX=mwTU_kHSsiq1A0I9AO z(AYNE&~f#E0RVwM0|5a4=L8jX$a2X(_Sxp3~6c~Ww|CJ>zK_&MY z5gu!T2>|d}0syH0Ez88l($vny+0@C+)ak#a;fkc{6#xIFUGrHxZ;2=EJ=2(7AON{r zs&{YHY#OK5N>ePvxtUIoP28U}QbCf2BNHR;1LvFS|NHrTHZfylP)xBgzZKC$3kNmx zGW*!!<14~AJRvZM@@yHO+KTk>|Nc1HyPFs=Wc1LXV-H2W)?1D8X`w%sPO~LpZWvFO z(HcVV(3pblic)Jy2+f#y5t@7;MDd^FH#-dPw&Tzu+BJLZ>Q0$|-f8Wrl1@yosa|Ud z@gIYLC?f+*;)F3B)@pUk$DV0X9hbf9sqacFO`RIxZ0z&qgMDdhn3CWGKr$9vG^Xc}yD^#0x#GJ9>&t?a07^M5z}f-xlQnbCevuw~mn7VJMW$kuJW=q|?=ME|w8 zVZSc)qLf(952OH_mRKdZ4$SJ7zZn zcm-6_V)@&%YxQ78CSP0@y_&CGEKE$CJpHTmzPmYm$m+_X5YYGccg*eW>VB&ucz0d< zeDnCY;qKx0<>~Q2pqq<}E88>3`76uttn5gBu^p4kZrh?SlhI!YjC)YIWV5sz3rg6O zYud)FRxH2f`fSN_rC8elJGTXm9a(*Fi>@}!6Ha{F>i_=PTual>;l~FNS@X~PxmYVu zu^47z+BdqmN}UEV+y#!o1M3Zy)LkY0Vyt>yGRztDG3@>95mrNdKMVo_Vmzu5Jv@tI zZ2mJQmkcC}lp$M*r5|5oudk24Z}ACjI1^oMIeghY9rfk^=Nkwe8ib_|fdZmACsa3x^3-&<)N*m$n9JBG z=vCPF?ji#pST~LJ(W(hNVi%I+9}_N4{g7R`m&zg~`dfHccfy40bxdSl>ux8)xzXzO z5Njt_?Ai*p_}x}u?@?)ZAw~(2%hxo3?{xr2is5ySc z^y<-z?Gx-LUG$R)d>NAD9=o~NWeq+}xK_teoBo)>cnA_&s3$gCc+z!oN=RggzdVe4 zBnh-q90?BfGn^a5u@>x=qzh_r*Np4ja$eln_M<%AaJ^r4v4$c3^^)eMK^eyNZQtV?evH;>uXHX!3=_QHpyFZuEFsn%)}`7IpXB^ZRJ zO=`}vCD2hri^si-qFz*qE}T^jj?^-2lPePh>5+`uPO2{z@0Amp3z0U)Mo}jogp^2u zDQw-+1C@_mI;TMXn0X8JxbXYEZqN5sk5hfD8fDpuW0;s@Jdlx)1|f0gEUDyQZ7qA7 zdkB>|?sYA%>ItAHhfcn;yz+BK8}Uw@l<~91@gL?Ncv+gt#`UPoXyPY?#H{kH^4|nW zUy)^O%LH)DY9Pr}lR7#$8VvaM8EmlB(XFVi9A6u#8%jV$+UOdZr0j!+#)j`2d4?7V z^dF=SP~qgoSd-L{oQBF%IZ)POt6|(?1ir0})xsJc+`)aQq$8*t!5rxFX{$dBi!Ymz zsnNX^^VkR&S|=DFnORLm)tGQ~I{q>b0Yfa5y{c-fb1VXjk!hv436@PU!f>Ero(}XlTq&5YXhq=Xe=_l79Qlqvl;7wbG%b2iG0zC8cMMCc;;Xy;5 zm+6fqHxA|FrW0VqzEh{AU(`T6&^4me{Jhtaa)@^(R(9BQdtLGaTNCmg(pqAj$UP?t zbsDowfW}+GA?SQy#vh)n;wn(*dstCeY(|+8IeaGxNT&;}!cB!d0%`vZw_LTUsF;4} zq6pDv{Hs$deE2JAYjF)0SbyY<-BIAM7|R)CvG{Dd@orML>vk-WCg#rPwxx0Vy!eUt zlt@BP0^_YmD85j!D2E7MX|FH1Gq;av_S(mmh{Y0!miTh`~F zig>e!{e}5Y{^~raKe(rmr~|41rP$Yc5X@LfAu)L(0Ez@x^&{s@9E;ftMkra{b;oYH z3`5KigZiyZJpkul+f=g_?!W4>ypWaIlb}txR(JGi*}!dhU69SP*?{>d01JI$DxXms z3p~8IZjdC0JCnrV%bY?F;iL{S_zkE%Fa=@L;JRB`V}}}QXNd-_2~Ha` z$08Id&@;L$h5P43eS7}}PPowK9S-8c4jhWOsd_6`4U+h8 zL8xHRIcqtM-Sz>c&LI`L>;zEjcSOdLDR*Xg-B;#8L~YmK)b}J>YT~>sg|5VJAo0U= zxV8~gn%W&X^u&{(eak>Mp_6{Smonj>TRKNTB4_rVUcb}CXcTgUqz!a% z!UK6N*-ZYr8W}6v!h|9A%@pfcBh6#uMB1XfrWqKl!q@b>?Hx8`6yClI&sjzjs~B8X zy|)-&NL!284cp9)XEAbI#fuS_utNadbFQ^?SH8C>&~85p4;N69d8l8pL-bWzJ<{<+ z=iS`vG%`Jy5po#yhQS#|Rv^xElxUgRxj82sK^qERNPD`Ity7P6^7tR$NVA!y`7KW( zt#q#l$kC?=wgWvE7RIrc#NtE@X67-0H136GlCI^5xsZmMg1FBwh+H4FdEDGb;AVa)3&Xhsr%7+)Ovc-0=qFfX$)Ki6RTegi(gu z5lk3){(9v~O)?Kl9q#|=cA+WpYPUEdwA4Cnip>ztHvxsD3e2kJw4$4;f-NX*t#XCj zGAoOPyc<3Ss9*^NC43Rm7hDzvvEXT2=Xj(cAt;{GX1UUr`3$b~A*SyW`>SR@+Uh0Z zTv(WTHwi5MHnU}mwfI2wayo#QpQ7k9+fc`vrR4Cu#R5s^zRw zly-5qqW&bD)z#?Gh!%8dWNVGCbf%7NSc*15OR7ru-L9BObt0yxI|-Jl^YtLvb;pPv z9$|H9^o{?seB==`nA=-b#H#g&UmKd}_&_rPPkxUiGCL1*$NeA*MM~3M+dlQW@Ro#v zu7y6Bb3e4b-!KfD1cY+#5Gd$E?5}?_M@5%v&IMfgV4?1f@A|C?_tIm|>N;((7!Qb} zY1faLBjD@H+#m_IM^|};3OuB;mHC)m=3>YA{*hinaE5Ny)&8(Y&Z4VIqZ_^p#)}R` zH2F`Rlakb;ox3FK}bTWF$ zzeZt=`&Y#8Ornn?D0WwPV{_H++Dc+bx3M?YS%R2Rk?gg?EucwOO%Fa;#_#O08j!nnS|jzt$*8t*G6B4FVa1#c1;J}vw}m|gvzj= z!P`Umb`csrqq}Xit2>JIt(Y0iwSu*5TrN7dj}>oUuuQGBywu1~{?vc5SoaAzEs>QL zwd+}Z=N^7VeLT`=vxP^YPz;P*F4S9brJt9U3r2kdAm~zd%+uBFw-q&uDjTnR06l?N zCh2u$#&lyi@dm^T`18TRWHmy^W}b;{IvDlz<6JTFGJP-P!J+JYcyGmj+LzB_hGKKs zGlZi!bOS1LLo3Wy6O87@^H<%}_f%`h^$p@YdyZlIh`|VaTjIZ;EY#5w#hDQ zk>$SNE;k7kCy#ZuuOV434_Pg9+vI9ze=z5kC}XwdI(H+f7ngGn05kC$yuK-G!e6PY znjb0*q}V^fe|!~iL)0RWB<(rb@d{ZV5T(=Z})P+;;^BZcseygwsQCFa=*_A2UIkyg8>WRrIDM@ZYdB z=APIxe}%+tpxN}EmQ}fwl)vpjtUTQ>q;Vdcjg0#8DuYsMHt*jg;Bk!jn@8+%It`KU zP2SZ=9;9_FBenGn9!E`GI^eQGy8{krz+G4{|gj+wX=-4VSLXX7m+`|O`&l{d?&l;@*=#50FR15WL$}^AJS&? zWU3F&cy(yawQJ=F#!>lV`kbIsJeE6CW#$Zr*S-bJrX!o#hi}M($UJ-Nk0G~Hd=_+q z5GPc;>Zka)stq4xYU+Ex3z4M0`Xr}IVW8_wU|TS<=tAdapBagseRWTuIv&QwYh=-h zfx17w*Q*?Cul;(H=Yg4AISRUcY5rA9lr6G3#O`7bw~2gTPHfa=bTW+N@&g`AbGjJV z2x0#+t~JAb5?0iG&^;_}dhys_{Q)gmP)rJw)F;xgZzm)ZPd&n&sS*ArHwj=jN%ATQ;Dk>KBvAyB?tJ%0Y4xxIh&_3(N!;l0TRD`dm& zm#QY_gsOgUl$zNDRCkJED>oP{d)MPtGSuJ;9V|3-V#{{gj*6OCT#cvMdG9yes^l~f zw++E^H+I22&$_~qb8jS{v?V@KRB&z(yKD`^57+Gc}BJ8HWmRtwiv?BbRg}xE? zHO!_JT^xREF>-#yz~glWnScKb!f=y|ga&uU1S*%>z^8~w^>|y z+j5tKc5+VlyNEv0+J}Lc;U3CAYt4{=uJxtc7(D8P)br^#oRUMcpB7p~^D>=XgNiM{ zAxAipT4AHa?}GEUr9L5VJF4PDX^ux6bi1fUuUW9;V0ujFklZolLldYcOSrbO$tLqz z(dUpdgJl!`>T5nhAnQ4ikR4$MNW@AZE&qAvW+9Ao3LFy=Ov?HuYg48T(QFYHT=`$)%W6`0`K@4{J!7)q>}sQ8D@SGT#v<=+uErL z-(kze1Xgt89eD5W3s%ayxzSQL8Xm#jV8x`~VoW4+5^EWr#rw}b&nbJ)oF~>`-slT? z_AuQQHE*BPFE{v&H1FgQcgU~oPKGN)!5Q2(y$)iV1$V%l@JS1PUtfo86BcZX6Z=S8 z!X!70I*U)wUgH=xh^yRGyzL4fhr4zieC^*C^HxiXN^t;L(<;`9&^pgjYx`cGFKzsU z1UeI&KJgUVD$uqgtLsJ+;uhNv^wb^4K$u&JW*P8C*=iIc!&>(0j=u6F-i5rxXzF5w z=Qwd%7ovY-rsfxw%hMN+PqO{8RMiqngOSp&*>Y?*&Gm=|%!gU3ddpURlG1JKoul+u z1P=m9)k#%ogSKF+*4FA@;qOStmq*DN^Ng>=z>)qd@hs&M?}C}0`}M)}n+E?s7Lj~q zdkJKC>9Ox=4D1VN|A>mX39iZdL`EOu7W59plX|`b33@uPD(EkG4sqeg9>g_{$jH}S zX0Ih4tm$?08*#hhbMPIf(E1r$HYT+XriOK%($UD=p$kwmQ3Rv5wD8`n`X$E{{w_B= zwfGh0GtDWE30VA^d-)ya=~PlRKHc8QgKq&50=&0c;5w%n01M?JmQdeLIJ_n|J4BEG zJUT47YV6@!?I$aiqtcZ{DrKL+Jlc$Mk(rOv-y6j44&#+dF}3Wh9LX|Rm!+Lpk#rCw z#@`=*--S|V;lg4^7|4^txc;v+uvh!q7)+TPe81dsD`*bkWyh2cmn< zW$^hfubtt1$OCCU_8k9NhQqy?z2_P1iqc zUprt`m^YxtXJKG-oojkuAvAU-?s%zPu4V|1@ZW+)`e#bx z_E~+qb=cJV#6P=qh%kt3QkZRvi%X_LAkwovBed}n04+qF-swj33W9bqyT1_FR-E(Xf78}0g!(r4d)&f)u<5{d<9L@r1cxq| z0DdlXK{Yk*Ne=u5f>0L#kmr?Lw&AQ=`n%N$Nj;lVS7}>A)06DTFvud3c3%Ju;=_1% zO^VAyk3TIe)lzxv;VkbKt)%Z~BvW-76L+vI)#}bV=8%~aU-@n_I*r7n`IcQucnmK{ zXYa)>;)N?od#N{v>>tcT_6Sz{_S(PD_6%?S#Bv9g^zKpVO|i>b^3hlmp^=L%U8-CI zWq))wu$;K>BM!Qp1J|NIy@TPn_j@ZZTz9-l-FU*gz9YQ-0+vHH$~%uYMT?sEim#>C z)H_o4_|qO-=R&P2;?0M-D@~NoAZt)76qY6T%8+vwuC&C7a#aFYHl`3VXE_l#ev4Hk ze2SC}qV5DjvwM}B;CF$c1;k!1$>o%Lm$&;s&)NKN+4wEqIhP)EKgTj(c5dGTEzQcu z{WRy<`12big?E^EVK|m_=AG?{Gx-NJ3`+Wu*ebIrVX_bq*Bax55f;f1Xl?Xj z4eGlzqJ9`?g_m6tkmluLsk7W;!T_nAkzRjtq&lbCq60SJY9+*X`f$1hS2XosP9|;X(>;*1W#=8Q2f-v^_c^ z4Y2fV*q`nmt)FAkiIEYp2XV9-!FseNMZ*a)Y(s}3GL$b8t8p+EaMmnwsiC!EZECq+ zj&055W5f(vNtn!H0jUiv@(Bj)%n{&}5DwX)c$l|?kynN!Qj?mRl-d0(OPzv#FNnTI zIO)+?g^K0Mu$4+1+bYKvV0{j09SgEDWCt=%Ftb70fSOq3nh~{m$)J)N%6gg>7N1Y; zlh=FGChvFp8MPh7j(fI_D!C-mf<+zsZwMO4_yO6PJtMm_E#N1`FbxNe0yl_lp*AF| z8PVYyHqZby?{35BtKF^z%|h`^l-$1{AS}7Y(ypE@fs<)lQN>&+#xoC}o>?9QZD?Y! zf$i~Xt-!5l(t|lM8O>&>InND(smq*@sh;Ni;9LScs|zPxfK!KHni35@XU;Y2Y_XppjB6UOid-tv9 zCN2Yyt%RUUm1`&ksb!4nfzbcqi^xrGW22hFx;HKbn#qVT_;3d8%|J<4L3tn!b4mFA z@jSt+2-CFgNUAt$$AIJZhsxvk&TQ*~&G|bMygS!B#}qy<8x|XO=_t`r(4GZBG}Rp`*mgLHY8=&DL7gwRXS=1Q&gMGZUm*E z!S3oN!3v~|M8feRt_6MLZtq8%)*H!JElj7(3jObi9dF-zu7}Sr6 zhy>-D-ek@q?3fMv3%fvT6jb7Z%}rY7(P;Cks1rm;RQTh7v*%A@4sXIlzJWpf@$b20 z+~;Ks!ko#|Wm($f6JHrJtXxLF>3m%9hDA$(+JD+vlMFx9ehX!{m%iRty(4&B5AH9% z*~33iXx_&`cm{R;?j5;qE%+Up6z7a1A!YUbC3&8crb~G+&48XJ`3iqDqaTFvBd8{b zBhTRr0hJwRc1!L-S;NHKFg}M9%dFr3Cp!FBGL_On*U*6bk3~To0Kok3=wM~!W@KPu zZ)0O>;$mrU=luVs)&DI>=-S$EaU%cR7z{Y~N$RXgc77fp+Mth{?Q}gQu}18Gq3;Nn z*mA4W$nY!4AJV+^?M>f_73 zVf+rK$+vA?=a??35o^*PGs~FrhTS&NLL5G6AA>R0uYF|=-I5|D6VEEMOSPv6Hxip* zTpGvGwKf&;eA=+alj%?%+|lV6b11BPsMV7msW6!yv1LbL8IXzfEnVQA1!2rXKj=TE z^!RL;FucCJa{mSC4N`gld7+t$?vNZdp7A@pVk!X#WWWUD8wd?hXR|0xb+a0`DOU(}HKu$JKWwy3MTp?E zuOaaDkliUUnUMM*sJc+rfU#7aXl=F_0SJCc|72iB{8D>FpX7rBA*r z(gw+yr+byn;y$)EQ}IE+5G3ppAEI@l%ft9uGLKPlESxdMBhuW+2`}Jif$V{10FW-r z5ib+J3=YW=GsYTYF3Gi3;kwugme&g=thB#*6m_xWMiS<8HF6LYR7rh&#hp0^7dT|x z8&I}D#tWxfqR~@=HWCgNIrRstvL>^%11VrS1(+yzz+ZM-9`t}{iac3u7I>lW-)nXE zqfL9(Gz7rKFgCCN4yZ2!gF6lvGt}6|>L@thSlwAG3yl+FrCK$mL`<6Eveq5#=XL)P znG-m^fk@bEdET1HnOL5!c|r?T#wknOOGg0l0#U+ zVDuLyaEuvk0=y-X3M2YTdzrNtrTn8Gh)V?V1+M>1WJq1(jx$v!V>;nSqWUt8=0=eu z2+lK-6+hdq>P~4882(i$|I_of4&=r2x6Y2>VH}JB_!yo)w7qL6{gUzNoN40=6kb@h z(J}#>y1H`x`wOf1S1syB*yUNk%n=(==bQrAVm=HkszYn)FofP*ZLrR}S~_C#MGQ+U zxG8#pV8MZLP!Gi-*!^~hp6E&Yg(x}D0MRimqaKb@iB)l+x@DFW$d%Sr=KuH`MU+J`ANtLmk zvRgVni5K&k;soc)a)tu`0Z_1WE27Mz42$t-Ob)t91#DOYg%4S(G9zL&q*Xd#Y$EF( z>cne`SAST{2hJ2DMZIHxQa9|0d`ckkz!Q2^lL3tMX;a&P0J{gQUULyjj2o@Z&TpuU zK%Pga6NS*8YT zdDl%}uG*Pu3lc)r%p)X2*?n2riHoBN&IA#dW+LmFEeA;@WYU@vl?_?uVWFO8?;@QA zTf^S~(562Blt@pP@E0AkpPh}H^^cy7p9Q~q&&yq<-{}J&h1vV>NZLPcQx>%cLP7~M zn!V07>!=*P8(rg)%`a74er;3o1r)?WV@RZv$V{ksk~mpI3ebNl_aka+=Kt0Osyri7 zC^BzJbN4|<{6SWjc7t3ECs=e`wi_kL&BaB)ANzttrM_a6QS)RQ|_smI|l56@>);{9M;hthOEjZKafUAY{$^HZUx4ZX&A|yFPT(sf&Q$ z?Ya~uNV1O43aEe^pSA+p591+%rK%;oyjidVgKDKyw94ks0&?lU=84+J$yF6q(wq@S zR%RGHQnQaupz`{M3hicVUU;lG+10=PW?&fk#W;#c&h(?Ic3Kg~6{M`WUawgX&^V3G zrncVdD&YS0Lgf8gMbsr(MYksOoph*QDVRzt4?wQ3w2;-7lqa#lZ1s{AN1Yec4}^|K z07bh@!w zYuY3YFhGsOmfYF}xNvj&y9j!D^&hB~m+hCGI=82Buk_?R-X7mL;!13vGwPGC7W>x1n4+(W4WA5kk9YS7C@(QqKVZw^%LWCq~ zhlYop?(BYwhr*e6b&{4#8L$bf9oTE9^XvOgrpruw-0&c&DpNIVv`YI7{SApd)C+9r z{M$-&v{b$3AdAy}l`cZlZTjnb$w+s(P&?cO?j;3s-8N~z z$-xd|U%A{$?|$DTHc>Ufp3<{vnJr`RWux(Pcy>U>=qMnJ6_aNMvLxG^J`{U$_Z@$w zpb%C>Cv8I91o5&rciES>CNS&yy{3tpO#t(W9A;lmcGslm7N4iB4+kR&wL?GPmq2oD zp|B9y0->O|)RoIagw4>BBVQ9xogw1&!8z{nb^7&Y(G0^{7_3wf=4TK=?l&o!ZqM9A z;2fpA;L);1jx=1A1T(ZrxP#WFN!InYbIs-L-C0#+jgzHZ!PT2v9?hY(?A}ml<8U7* zZ>bkUp)e|j@ESwV-}%SZ?@%7C1x`WepGe4OYWlO>;mu^*n7C&PrK1TTwJLWFtCOOk z$AZ>Tm}yW5C&nj*G*Cn_O4Dp=aWVW%OS^$Zn=uWRhdoOJPo)WC*G!*nK}5llBTuTG z;xZfmlC+bf=|iMygPKnJ&3#LEFIV8hOC9>hmr%pZ^ZL8o`$l(&BfTr@{{C)O9(Ipv zs>GvzTN~um>F#sisa>VyqvW*n;NrpDCSg)7j^Sq3bHEOCCvI^DXjf-Z#?7)Io^vL- ztrBguh=g2m4#^EPc`uQy z_@u(KW7S5RK9VK>=y{soPv0Y755swk?J+@G=M%fpNyJ}mSn79AcebEw%Wp+iK2aUG zN96<$CezxPv0!l&0eDD3DzvmFa|2P!`LO55?!{*`?##FlU&`JuTywTbtx)|06N8c^ zZ7f*92V*<21YZRf`bVmm(W}<7x_>if`tx)(TF*Io7mf~-v#v0k%Q!FID{9_fY^P)u zw58tBz}Vm7X}$~YJi7I*0n}Ygu`opQ0|FH3VDzt2h$?%`?AwEqy6H}KLZsC71i8jm z)6mUOY`irJdA;R@@T*>!K?10pZUgs8`{?1@Xff;{_)u_!`+Q^_HI4GeiiI!Xwaezu zx!`^vX@EIR!C9MJ2){fbe%p^6h?;ph>A}@GG4{73(|JP(!>G#c&vYzvxF=7?_d`IGY4mqfI{!TYBbEcPcBQ3RPL#}jdFaJZg_r(9s>S&cZd&nA^#NeTJPcQ1E)_C_HDOLQbK?YKdUlH}ei^BN)^uKlaZ~|+h z{yaoW_**4Qw|b~Kt4hBFU9ws*F?R2wRws=%Myy8*V`!LKMrev=PBqC~B}_D?mRMtj z<{SH}H9?0`eMTs}Oe111i6HFg z3N4!<|0=zkX0^$*Efx4B&<7Gb01x)v zzj|5vnb&KHHC^k2L5&@&BTeObv)gH=263v8-jFS%`mIS~q1XD(qr>lgC!aQ&xj4SO zXm`z7Z4cvwPN3l$K`2|%lL9$?+nY8Bq;L;RXeGH*#i8y_F#Kt!yY`5R))cM!G^Nj4p8ESL{5WV! z1V%n6@)~G5%!;OORyo=1VE4mXzO*@^WtH~Y(~36&uZG3IhcE6WgyZ&w%e*X&WA9|6 zy%26DHHa-(;?M3SigRQZbc^IAz{L+c35OEp8~&FrP_`@?+>?eE9vr{r$@R2R6QYUfx1+nQ=*PE*e}~}& z-{=f@3q?f(Cye;HIdJrL`9l{63;SEqOGg7}aYZ^EN^$TC3I=%iz61scwY$K1GdJ@m z*|R6{41LjceF7aZ@$~Wv5E4Vl(ay$Rm(lcX1o}JsKK?hgYAbMJ?y&+^A4;{VuPn4K zCDNMi%Lh5WhPo>wPEdAlf>p4zNSs^TeLEx55EwB|JH!u>afYm(0EVh@qXBdQrdaW>Y3?4ZE?azXs1 z#7?D|r-O0exfCfT_p+GM5v*f0ZSYhq$4(+Su zyStyWj?)0Pxr_Ue2sg)(jhJl7$gvkq_JNHCal_c`Bgv=HEaA?mi1|Z>z#v6H(yiDs zNl?4_J)NESj}alfK$g>vT?A2F){(K|FpcDZiX=Kh?hYC1g-%`94v=y|t|J+9KZ?ok zS#Mt&L`wdhRO|Ycwo+Fuuj;dsI4CX2&QhPB?|ZtquyNwTA3hH7gn&x`s1G?bFgVil zc(B=t6PBBsyJ>(ks-#B(Tm@6q440#Vz$yM&g9R0&3Adq+Byr?Yff?Ri@XrYOQ`0TM zd=io}=@|@b1ZQ6Y9pWvy7yPj!9#m5=n}T%>YYxT=Kuh6D#GwH?9lOU%s#5F@OC_vf zk{1*|^)e--F9sko$$s73{aN?^B4+y&a2b>raMItJum1J%amd=+%lYBQ2)jeVv7*R( z>p2*Dk0If-LL>Z4Iv<456s~qdmeGB(H(K0;_a4_nEh_ z$C=i@N17kFFCpYC(?WT(4wqSec6OH6^=ku502HA$?lq81z2!zoHmLC{x>EgcaI7YH89pC z=96CesVET#P?I*NqZCmKs$QI=b@ghs}y8yR;cCSjE81+#NJ#Zd;xLi2!|&2bi|@?NpbKl2 zszls-k9Tq;w{}K}w>3xPzHHyY5#*aJcqElTLc!DSDQ)|R)wrSqXDMxx1kv4GL)p+s zV3?bSpN-wWA8ZKvoam2Im}7>Gu5r|9+vX5=04;%jxcz3jA861VA$rJI;c#K|?gPrR zmFN-b)>He%ofL>Xow<>O5tG%Ro^Es72%L<;K3kmVfK*_0Kt&w&kE8rQT(FOuSLKa< zqXj0J^(Vb<4$_P=iLI#01N9BoD*3>1hF_$fM-IvrN0Jzfn?7ekJyLh~;HrpV48Sl6*Qd(3>TZ(tZ45o~>Y;{mtLQD{)> ze4b$?mFgZk=Fh*wN;9Sv5#hE;7BT0Vxz&@%tR8i$nMtWuFJEBlRct;qE3Dt3u)Mq< zC=IgdF*imS#A<#)7I$W_OkdKxzdFso%*&I0W1+yQdsMiW`x}+ic?F~9St$T-00|*o z=-zsTy#Z96fSZ3>iHRZ1P!2&Z*}lMQ^Imm^mH-U~-P$fK>H@32bRj)MQp3wH5-r$^0v1{-y3$_Kgxc#g!6f5&0! zu(ZWSOg8von9x-N#Tm(PlQ{xyWAzwC#E5L z#&!0h5Mve&YLf2Gv=r5?lxY(>1>BSz`h-U!O4Y7Mf`vOK$shx)3ffWgDO}AFmAGe9 z;MlfcBMn(Q83vYtuQqUEy>JhZ(?ub$X^AqN>)2vQZtm%_5!0ff zqrE|^+Q7|st}&FWcdylyu9km;EcThBVbRz?<~=}F6)97_uNa-TQ1T7qRu!jho^TWf zQxmsFb|8TI!_LrdsolZ2Q}~|6tlObD)uC-zv(h|rCn`b_yt@(h3Ga5HHoA!08{5q3 z97;_y#ny>>O?N&2+Q>{#AVMGV;hA`AButZNI+zbPM5pO-UaMk1jG*oWhXoa9#I4p& z6P>taERlg+5CB?Q$0;Cu#TM40THe2tQS2(uEBBc?d>_!PlVqH?DlswHI6RU8t=Opi z?+eaIv%3Dkz7T~q%$bIbK|Nm#F|m`4M9RT!!3~qk9Z`9lk6AXRFn2EdOyQGqAL*l4 z>72Y+uWjZfkw}O^I!wFQB5x(yM4NafYPANrjS9MxC^c>C24`8HGyQay*${@A?P-H} zJ~BNw8D`5=yQLfc$}qBXD9ePDp3&Mfo`_k8RkQV4A=@38Ni~z+n)X2YKG`*Oky~c+ z^l$_6I0zTglGaQmhZ6c4u4QS5%kmE2fY$JcE|*6o$X=qB(Gp{8X`_5*GGF~tmIX;G-`u`}&@@X4fk|m(!g4a@L$+;1F){!Q=N}p5R#hty z*LKpKBkmu#_A%5p>H!e>M&KR_mIQTf2=F_^-Gy9u4@0CVCae0Uipc76e8ClV-6Pq5 zV0O;9#^S0bJF?2FX#pc)7*QLkmZ{J!45F6C{2v0nJU-r-{@#x49_;eYz}@ZgHw{B! zDgU_%l7br5VI?g^BX8vwGVX$8hhtPA|9WY5w7!nY=(`(1rc~dbVRT&Y2w@!p_B8p} zC5f~L*nn7l5VGSyXf+cN>jtowDLORGs7XccIC6_Rr(fc7FD6HQ6MtJyNeoV`w0N*7 zRQ|?UlvA`T&0gp@@sFDr1>p)xH9pCQa?`a2K;-MGJ3Mf(hR@y4}+Z9x&=~4^s@T<LKN(%hqSb*B7<_nO^{e zRl_;Eu?spRl33U&TxvYUaPFr-d^ZIvvx z58EVFytFsN{G&r9mE%dNF=KVt+y*#=KTt=O;$YkEBf?iN_LQC0FwM8QYJM9fG&83@ z7@A0&1;_8ZW=*{-Vhl?3(9tdy%*hos-Rh*+1XCR1x4Lc$bK(!z$emt-<+itguOux4 zK~19Y9qzZD>maWxc2FI-1?v^EgSy+D*?B+X+B9pE&+~dLcYyLi_KG?p6KGGtQt}8A zEv2)%=P^|krJJ)nK0H;-C9JHPz$T%mkmO&d9$B23!6FWz5VX_Az~9M+Q^E0BhEAbf zn#iWWBf^7nrb?1{g|w>{I-)y+#cO+~&YuPR?P5u*8=`aqFtx2n-z5ZFu#%~TBzU^Fq1N)}{n^9LJxkM17@EmBByy4TB03Fo0CNAs-1VtOIbiOhTV z8xc~-rS6Ulqw_e$)Y}TgCCa3Q-~m~N9Qs-(5#Wo>ZKj?Le&pJ}9fT8LULo^?xPT}8 zN~?`j=1}ULqtOF@aDSZA)UmkII-x{6Y$rQ3q{&~+^Kz>7qv;MFyl`cPN!SzFqY>Gl z{Mi>cq$&k+fjlv5QKN8eFNR%a+BLrx+=%igEl0+0KxW9ZYp&%k6IcG8cr$ zN904kRTBG|6!Hg@f3xm+R-t%*nT_pL42V z53ZWMtCi*E9AQj<#ul|4x3s63YzKRH4*aaUS@a(J_}Zzt<560*BJVu2E-rqCalf&D zwX#21AGv|d6Xaey9M|xe;J-xT`Y6nR@mFj@E(sT*9kaTnp^7UqjHh_dT5dD2gq_`< zb=}Gvr{vp=xAgK0fp(FeJbSCfyl*$M4g!u3by)E-XDzrC9(|F~sW}NrX~Xh1M-rSN z3Otrl(1<)n0<$Vo*K!RyQJ8xQ;}jLLH1LmJO0IwIxUP95^!j#_FuBX_SNtPHf9GEK z>CDVKE(Gg=R{LQrR+xOI*G=%vx`oaI0oV22&9+T@WExs_T@G*6+Cn2`qyAZ2N7d!8 znEOh`mqu6j+HY5#z?koM;n0U(&Pr*~AUX^f`R;@2mkkAZVn1VQ9r@%s&Z}ToswWB~ zFLv1ePCnAfz!%ABw`BdX=rleMIa~TTbUdyS7}~j6j?xHC=4iCVE~bmM=l=&kK)}CK zzA))y7MRk*I;u*Tt!A%@wMp4o#HWQ4CX`mZz8;254-nq;?p~`ZFf~AA_lG+)2qnttGLIEX4IeE)sb7Y%K zpqc2~>Yu3{8)^@(7k0SNJnU*ugZ6Rv;}kaELKOk3*%|@}nQ={%>?G88q>r#uGcBX( z>q^y@k$a6rZYF;mC;nB~#$88b_`qX%YZA$)WBGiLOcYOl${$6Gr}Ho|&^j72*oxR% z-m&Sc9Tl-yRkgnu4XJMk*d&0Af7Rvgk9UiR(w^YP^cC!AxL&FoJb;^4CwiI})1Wz_ zhQ9kg)G-)rGlDA`zS1e`I~ip%pGCwnc}zD0S9L=Vb%#f?YI{xt+z_f?#F9wa4Qz2s zT5S)TcGc8PQ{dRrNb4q5rzS;Ns%lx}2^CdjoBeIh$yXryjLo%TvdH|<8-HT3?WFY& zd~0I=icV8=6XP>5w@X*D=4Ek@62T)N-W@(q9;RzwT)4h?VrfLwMvHYo$MxtD%Sc-cQE;~v zUGru#+&19Ll0*!o)D4jWxbbO0N5*=M#W45D0=~;)z5V90V9SbbjG+-y>tC}XE&6?N z?MUE-GtI~Dm1z<9a~7z@MAQ;rbp7qkFupo|clBjUNB8qko(*yDT1gB$(L(?#klSx^ zU0=kfC+BYlR|MOrl)BI16{0S2bv>jj2QUE*+lq3Mj-4%<;;(lB+0-rZ&PlC|Lg@M5Z1C6|pF^A5iW;V;S+WRrqy|0oM;hGKKBHg&_4K&}@8}x_uvo~|7ow2C|p6l-{Yk08FannOL#~!>j z#%tmL8UL}42$v!l|20kj8k*5VJ8aLa`ed#qx8AjDyY?YTb@NWy5q8k053B{#u^gGK zc@ZWmy;FEu#{69mN;IYh((Cg9yl6na5|j_G_0}cDtN;v)I6P{!E%iRAYw~EEYrBtA zZ*$JXyCK-p0@gEkOr_ym%+&dYjqYXC3+K=}-sHs7*T+-c6Yf{lg}$6&sWsKYUv+bv zZ$H}K-+yjGJKR1iW!0qNN4&R|5QzUSc+qei`i46w`yKU%ZJy*~j@);>(P1_c&Yr&I_Ht^JF(2i&CBo;c*i=>=iAM;bxO&p5;ZjQd#t#4})bj-Krx|@8lrob}{HT31F zW^StNdpODi+c<=a+~^<-7x6r%ctYkx*-Yh_i)*X7mPwzgUVHkRJ>)q}vs#rpn#7T+ z#;G74p$o9}>Fc8C(g7Wpe>hviL;4Fe-tSkq5P05;l7JLI!c)CX4*5VWv4AJmWDOn( zc9WUJ-I3V9LpOLz+XUvYyA<4z@{>%*kAzKBAM)ue{&J2N3!ZjcaKb8tcRYxao@BEr z&bSd^+mohxR6xNM?HpPL{J%4c^eRG)2p_MN&sn_Bdn(&;*)dTOa_V0Z}}kzx!3b zsix2DPTPS?Y2nK%t-Bj=(d?H&f6Kro|K}TOv*_Ia|WUY6OOhhd^TwtSuf2?KQf4+Bi zw$-YdPXFarlI?QUveU!!9()Z0(+h1HP+Tp3t%pzKE+#dc<2yP-qi*w#;yO!p1t?b_ zskw_AAVz${?raIDPv6>^1bT{poBF%aaDJrGBmlvSUm z0q(s#|H@;4xg(kw3@F?rD;Gci(szsV^Dl$^BT%N>ei3*U!iVWY{FpRMUssb!=+)%- z0(e^im;>&j;gDw6mv7a{&$ObKOuJI;dPDM;s#%ZGQA$?!N8PYx|#9snel z8GC2Fx_tCsihW^yL+Qwqjf&Cb#@bwZ1s;sK|LN(;>G1U6{N?HK@aW*|Z134YYj+Qe z^OL711#Cod2KEbTPhJA8lQsVNANi5u_r1Fi%|3xC1n%Tu>mbzzCr64~rj3AUCT&C@ zhEkd0XFx1e#EsgENc~~5w{f&V#@nyBdKuL1k zn;Y%LSyvbCMIa&+WzWi5r1z~+(ps^Q+s(~(O^u2N{lloOjndh=z_r#YWAfwz@s1dF zhKBzIYFS`lrjXgUv-u5Kq%hu9Adim<)G@((hV-a#ec49JR(K0=RbO9Mle$E!h+M{S z0>?%oqb2gMNEqMr222v6qeE!8N5u{O&Mnzl;;6__UzB;f*0tmi;nQ1-B+YG>BZ>a*Lm)$#E-cG*{Gnzyp|1(ew8G ztz%Uv$N}$ZD8}jit@Fx+3LxydK&tn3C9l{|1;HAxcMOmPk8CPqLFm|gvb>eUgLRMR zt7@l!XhM#e;&3|gvOq80hVH;{G7Ja2lHDTj1}uxbKyaFiX?cAyDcy@CT2)|<^*gve zT!gnK5hp}=5fid#PHsCmK0gH%+r#5?(f1FZwlIBjq;UgYNmH)_>BDwx!zdK7d{f;9 z7cYF6bKWrURLna8u?KBx!J!57I14nR;d9A>N zuE_jOue4aSVHg0rArENb3_+VCRIi3Z3ztmp+z^F8TFR#oWMKaJl*0|wOV@f~#^4VL z0t=^?y@72utnO-?aswMI?1D-B0yGK`knx8Hn#Uw!S|t7@5MOmyL!eG*5dJuXh@|09 z7N()C?V+0mZzDoT?4;LvV1X#Y7I3uNv=D+- zgbiIzwrNEnX1E5H>lxNnwGfSCh>-a!R14t8t|AYs-Spx8!I2o4(J{{3sGWWzo7TGP z$#yuzo;`$Nu*Bw5C}F)4a1x1;l)!RYUJU^$ct{_%8g zJqCYq9~R$9FxVJo3CC+Cu5b+n>Ngh1#q_Lhc{Gv-t^EjYKJ}X!KfYZL;;r;X1CF2i z_ZGD?S^3rLR^qg)G63XhrCxi&iY*kz*=I{((1|8w=XNqn3^XNF1zt3kep#k-Xr0}; zcG5Bw_Z^yR1S3%{6523?@Ny76^d-5$ft-YwBFUBW*=79(oGr^@q@3_ZEp#1iM_&+K zcLQ|%K1QhRYE=AwpvL?dg8DPrnenvF(VM|xH9m_|T8ZmS%S_rak`lv5D1@z;HuzrAywF}1(o8(@;#=^xIZr*Bb-me5K9=V52DK$ zU53lcm3pgykK@zgDY%-Oqv0p?u)aptl42*4JidIp?y#i+LkI9FD@1%je=>XtVWb#u zYt3k|;zRZ&$IS<3%cWNin(6`-Q1lI&`sTML0cK>hL$s5g$ko`z1#IDEt;M~1_*JabxgKM!a4Z* z-#{StyZ&`#;xH*x@NO?L{380gpnk>M5^`}Y*+ycqFL!%T{kxkpQf(BQCFe#(Q3TFE zJ0tf0UNZLYDxKQx+|lA6VR2RDeMdzK2~oB393^bg-L0;~4;vri38vmpF+oiOAW$F| z!P85i#$)_3#(W7J2h?ksgC0TvL`y9|9R^ZvFeT8e^n&_Fk5OH0aH1FPX2s;YNH@@m zlU6=yJjIMfod;?mP!(g?R~fq6hQz(>F@&w9g#f%m=t%msFgLN(;6?+T7{QM~!6M~E zuR1$w+00ITqz_uCKDJMetkgwT$v#JrUwc;nL#DJtQX%Suo>#lq5}4Ki_q8-?gHZ>~ z3fl&}v{Na%`Q2nR7zbw9V22zL)gnnXBerF^yQgZDv)f|irPv&?OH!y^ln$$a7l=TV zEy}^nuv!3B@Vi6(>Y@tSAgp#MLz99n>OZ>K1|y3MnPz3`OJwn|_*neF^&EE4C$neS zNDf<{2@orN3pjLhi!JOlo8qjZSRo|2x%J@UaO?WgYfQs9+XrmP@{2ykm*u!pxhW9e zejcAFR~9lc$ulOV%oJi_@ZT{t?%Y(+Mf$@8dK3dm82h6M^!Bna-`bxuMfAm|iw{#5 z-N%&O;nw!}G`NUumfb?*zdR8;s4<iZXK$-Ki`y^X|(~h*aIvj#@vE0=85TPjT(-<%ihSD*=TtwCUrx54f@@!o$ zCmR-^Ux;&>JA>hikV~r+7$*5!7X-B%=?ks17RLbdSG|y>ytrN&uDko_NxFvaSihZi zp;m9p#gqi!zDK+mDiba~dO-%mH}$(}h)L$OBUdoEHs{c2A`%H5ZYp*dsq&{}K|(_) zTtR4cfcr=@S}`IzS=w_5lUuQLcuKI0kDBt5%&)pF&;!I+J)b?ARES;bT2ldyAO$>f zz)H^8Pj+sb5m9L@*baNP`WY`^9Q@#1;fUC*iZ2FEyzuAu9bSHzQ4$}7#^#vpW6v%_aY zJxdU3SZMo;l4fXMLa#EHBh#M=w_vkk7M@uvWRQ;^xfZRPf!uXFlX*@1DXcN#A#bcSh zZsf|r6xg08VQhd;jUZ!_508Pa&3|eW-4i`!E%B6WXqP-~y6V2q_?z?XA9&!y2mG57 zFH#bfzhsBsVme10ga{Q(J!q;YtJFJ_>QU3V1G#H=8wFk;0-G19+RygR55C*`US$3a z4MZg2pRmY9)oq>*OpJ=ahlX=>aCBm3C5i>e82v93GD91uV&1Rx>W@S!sUJxf}!H3uq2g0h74{#oT87Qez>sj-=0~{LD8;Xq}@w%ZDFJr=LW>9*W6{PmlmMD&@;-D?45rUKA=qy*#y;i*DLu-L-=T7 zX0%&gCpQ&mZo=m&dW}Jwf{fJq&$vFUBMm(wn34Q1)6OU6^y8C*V4na4QhB2fAg!5e z2zP7go`y|h7hoFTI6Om}E=!Vrki%y+Wu0FJ1{w*E^XnDe`H}~R2?JJvV`XS(zD zU1_ry_Zry>iNgg=hIU5y@RFs3kQOl_Wd87Ow+OAGh*%&^q*e+{F5&H>DytYJ)Aj&g zGy4G?SYj=7+ny4eT`Vje?J%tjJ(3TR-F5N5eeY~9W~u#v>knSm;66(r8&YOEo8_tf z^5gjkp~gyLjYuwFi!C(t@#g>m5cPjF7!gn2$HEiy4TmU-)lDyq)B z?Uew*)Fy|PMMm3DeN2awW20{$KXpC2Kp0>C;COD3Rh<=^BDFKFCX^@! z^A)g-9swT7jmg^SUF_-wyv0|SmtYby4BxEP8z*)+n9xX2R$cBoRII-`2~O5mAey5>m@mgSH^Ol&+k#=l_8X^VZijWt4}frg!ri1X$GlTvxbgZRYqt zps}~S9D?WvgiTkek(Z!7ysLEAeb;R7)#KOZe}S2nh|IP_|8%*1h4BiG2Ir>-dq>`z zL1LVq`&L2&jNm#qw6Uo_?ID=!ZleWy0MJhiC7M89aV_=PBH_Zz^lArTsI4maZ)=Xf zfrw-VMgY|K#6zI3`y+NjakK1)AQ`)AHT(r{!9XcIHMPBL1e&331Zb2?n4JhOB2d0A z9784eAXJoiuSOPCWdx^Vlsx)rLmksFx82PixRe3JA_b_1AiWzTneKL$7a=B9?o}r) zHPqlpY2|5KP$#5-5U;7TPvJ}+29su<;?WsQ!68JcA>>e^;B(!)fh4(r2!(n{77Z#4 z*@8>26F-A=*k!K!^W$H3etz;x&tjwHHm-dk%=mI}Tb^Cslru_!mQ-^)6u$;!?i}SX z6<7_gfigN~hCj^A6n~iA9se!O_FBi_Rq@xvVqPY6o$Ejq=WI(ut7p`lAIYXy(xq$8 z(H;w54~spGj6?ZdIhFDP3B?(GP*|?YS)!4Ei!^g+2n!eU_0r7eUDXWItEF_&oEF#$ zN0wIIzC>iU2>&c|_7ZAJpOop8%~WAf1)NhEw3tshrgM17>Nz7+QKt6BgW#v}S#{WJ)c|L8;fr2KF=-cjYT)r`rLBY$J5}|R7T)-LhS8sRj(%l z?ZOg?1smzXZn>3z3A{1Gogw|k)06Lxhw|XBIdzCC*#bY7qP%wMlM?N3jT*QAFRqXw zhWgFbif(N1|3{ZST`~9a9^CQC2d1{S_w18p_EG8gcsi1FZPC+_!c_K7&>!c;=vRMg zNFHsUT=xCh?I2c}$T<$Z)^-w}(uc)&n8yX#7}~ZlzYAVMnEa)dBrOEW1kNYsGP{Mq zghnz7yT<6cld6C#ot)FSx?$&Ob&PcDZ7@cSuTJEQe zrESd8iG~`!O0ms6ns7`Ldp-^NG{NFyBtO#M$;G1qF{MjW@X>n{ZvuAc^uwzsJFf$x z_zL|u8iqaapDYEy5znHLPKa58`}B-*$pme2Z!F8Gcs2zL;V`LF^IO)7ZV&9)I^oHN zjh7|G)9k$Fz~>mo7;1N9L#G};3f0n9{z`Of8&6X1>B0WN;kO6QNnhouj6L)P-Ur{r zF(XR19#jL#KHBrW#hnM?h|-zN<~R02LuMp+HRv8sZs8zgeMwL_@K;HxsqPm{H)_6; z%%DXfc2EH_cXQZ`)keuTrGce#s|?}k!%z|g*4f1d{4Z={^gtT`|6>v7<+tlPOdGsx zz>5EeuuDqYaAnS#(?Js%huz5OPoXV#K9$Pbq)m?I4Tu$FC{DHE3iuUm2Ga?G*n!cV z_Uper(;n@fDHBKEJjA=6zm%k(bz|{l4@toO?8WWV>fLA6YK#b>oT^8)f_YqAdMm@S zkP%5hB70RqsmOHUO`;{<#-D3G|3iil`W(X`*&Tu>b1$e!brtw!0)ikPfB=jTv`74! zAok?X5TX|U6+Ky&Z}8NaFQbQa6dc`qxG9rKKB^nERALeJvSgHgMyD5iDl!nLBVto5kdQ7kC{X+RV&vLy?d zxn##483EyOwz`mPqbej^Mr<`QE*Zgc1~5L(3pYRn{Pu{!Or9BiUm1xdTKLGNuzN7DB@RB1<!^M2Do|aihhOE*8pSRU*(%KWOh00zVGx)iD}N}6*ObBP!`ZgE!p7l zb+ai}H}hh#UM}F_I~>!Ko@{i{g~Nlk+vXoVgb)L?wHqXh*?7Lhj3Bt%^=uM2VsdLm zA!0wl3&CW7qXfnN+AH>obE7?w1(VON zy2yZGWtzh!NC;E@0vb+W&MX5N_Cud+hv9KEUpBBwW<{H^xh$6TJVe0;Os6sEk~Jj< zoy^w=iEi|D%q-%elhaCv2vgoJyn=bXHp{0>)3JpD&A4rf#@>G5 zzG|O^P{N81T6oS4MYoFR($k_3TIfX}G9N1$EmqX~! z`s(rPFo=~N4H(AU<@|WT3k4r2&$pFdO(-Y-9K)#=>Mu3vDuM&_WL{hQg_5&X@ z2%zoSWOP^^&a$Tr8Wx&t3MW&h{gs)%v~-m?w|R*B-gpV z*Bm=QsG924rp8B9vkB|*3f#;u9FFNqvh0YuNKPVtOylUsch@5ieZCZ(yI)j;Hv?M zoow*ax;8^*bXrOIHK1G5$n;P(hhpLr*y+KS&ky#`4^NJ_kcXi#io;*;eY-b&zIXg= zxc_|b3{|Tq$q)MQ1MkN4d4B4jT2C-miN}H`TiV|VkDuM@Lza@H|GQb`2YbLr5ef>hE2imY>EOa zT=l+$<~Z9G+!No{)r~}TzB4~l>z{xAc`-N32Y(r`t6mSZUbH%ynoeDzZ*8rm*^T+N zRJCTM^{So@PVu`$L+V>e2ja53$sp@St~YUnW27+}y&Jf)X;Eum+)878HSkaIwfJ3G zqrZZ*jl(X^2}QS@OyK|M-BC`-kXSXJy`aT5d|@(o-J*eUar?I)amGSTkkkYalJIZz zv#p{2)>VTx^0guB0q+&*gHKpVfND1T$F80#mlGP}F;dnPB!#i1B(h-=K%1TaB}a>@ ztkpEMir6f*QZ)ljNPXb9xAQgDyvK9^8-kYg3B|a=6;m#4=6W||-xm4J!S}!x|1M-d zUn+Y4yyF>-c_ngYWB%GR90NBcvnfUXb1%w-41bri%nK7{mdROwZHIqo8}Pl1zvg(o zyB^k)5O?k4Sb`0yQS)u4M=oJ)W}F(tFp=Tm(+8Qd^utE;A@l54Y5qJgMEUw2y0y{b z@(m`95$WQwPt?cJ+PCnGi;ybL9pVuJbI1|WNJ}7*5J`AS+d!1Ik8C!0dHl^W;8Sd} z8_2NIrU8#|elzQV+KTY?maPP)kNt1b#|SjF+XS0;x&tNM0NC)d)am<~M>U>VbL%NS0on zV=vp6yCu-B%Xf(Q_Pg{h!JB3Gl4GpRxnN(6o`$TGlc(;GW^*ugikp-Dgn2ssiIybI zk_KrP(RZO(*mSXqFzE#8mQmPx{R@E(|I5<5x-akMClsVh3l|kUl?JLa> zi!*lh#WZSGT(u9z-T6$CFo@3Dv)R^yrtrbY;tvhN_bMx}V-Dqdc2zBFij4`{ai$e| z(>g3}u*$T4TRCJ=&N2(sX((-4r4oEzA|Cvk_0))E2!JATV1_U9!tbi}3$jq*<=OzU zty!KSsF3uSwp(d9;S@c6q( zo7$Ff2U5*QXi0JIe$t%(u%nK#*Ietv!cno{^zD(BZc7z3`)Lt1E%q^xr!_=kX<-gjA&S!FmMwUc zZ5+W7%p{lKb6i8l8(La4sBG6Sx4lv@3j=KN1vMB9B{v9dRU8tL-E}ZBlDRjSGj40! zae!gT_UrdbB!WI^_=v-r5sQp{)?ywA18CP}QW^|aJ>uYCimziU8ph$#N%Y4={+i}T zNU2OHRBl_mtqn|@BWoF(%+F0%-QEc4)^oU_R%}VtoT_D=b^u>(Sx%})^UF(1sBe+966$iSy*YM3K6U;}{sPY$_))%xIp|!<-=BYF z?&lY$C;$9Dkysh8sG%@GCQj~wQ`pi&oBXXZtYbL@l)}_FVgzR%e$YcyAT=8 zxYA=v?MT6x&BYNT(6ZAg0bBwRI4iT1_1ZZBf76?~QG3}yYPAl53d& z5@h6NMI69mUq*`a27zAyu;G%@N%BjBBqx!RCZyrtk1=JmXVus;cXYn(h|CfqojcW9 zcJL0K{w4O`QRmV97-o;o(;~=l*)PMTXd&r9Z#$mr{3|!rQW0D;7o>EU8PT4!oFqMV zG5BKd?4bD1zXh`34WYxEd8}tuGcFg^5G;a-=$Y9X1uCCPa@n?4xx<;jR>n-=AzR+W zKba>^Ta$%f0xw^V>l>mB&1nZhU@6Zj`xI;&7 z>9*^}mA^wL$qO;tU##Wm0VLZZAt5A;1m3iaI7f3Sm*_n4s?~_ndZ=VKj9RnR`ucV3 z8b=e6`J|R!zKpfg$Q;eSE0pcc|F#=4yzHd|np5t(z}6YI8 zuFtB~dJ$j+`Rzlpf?(z?vGvqmz*_CYaxmMFN<& z0G+>>&!^p9f2$^PhK6|O^I~CoS#^ZiV6o0+&dN4x@t2R!PeBnnJU$Q%ID#pyEkyk1SP}nllQ$aZmzu2#ehgJ{rrHGzW?6&tRkY)9r(kjuWt%bD zKKev4c8dA67iNPr|3gQ^$vaR*z=s80a%NLAiakt9flp89Z*)PK^&CKWNPo~899Qt4 zQFGshgXBesYyh8qJ~}*p@$!7Qe{%fo!Rfh5{j%vs@!Rr3{x(D@hA@B`j=JeRHMHL8 zJ4)f~C2Z+Q8r~R)PgD?_qBeFXc|pz&PR$zMxyN_lQz&opi3=k4biEt<0mZqa>Pe#5 z;nTZB%6Ku{$_S(@DwIWRQ(bBSGfro50|%|MTchJ%4>&IDyuPlE>glv@;+$!jL@0Fg zNv@z&;xjuyeB(^qjyWK+plpx?FWJ;an@&D!Ua0Kd#dM<7F?Kb#28_`3AtsMQ(q2UY z?L#IFbXi>c$%(>&`A5NVs~Cs(6xWi{Q%Nc!K3oX83&Ruh2D`n&dnPJ?p^XCuS3+W6 zcr(4Vb-Rq62{SPi*3t3t0f+_GmM~Bn)i0p zaSXM=0U}FgH8-7q?d=hb#MfurL{lyEWgcR;+}GW}4ZFpLbS1vuX@?3`xJUGn=90eWb9R)l?zyWmHsLJt`yx(2VCu_`-g`yy` z9vM1Gm+0&Up68_kD?m%_MP-WD^JOn)@vv|^#mR#G^C|d&7^ldsgW{>oYG$*Ad;7hf z)jt753}r$AsKrU$tm>JWJ~Ti+dE%}`ULu~K8hpC=ggK*mqklGD4Opvm?bVZBg75D| zK4V7MhHmIpwe*C3l!R=e+3n)G#p7wSe-y|)`= z7I`|gIXA`8I7#_3cu3nIZlb$~S6_Kds=@Yrj)0gk%nq;=?tQ=bbYdjD89e_rIi`Sz z1nVR1Lp;#PH8H}sD%Z*=Iz+B4r2^3De**9u7bC%Ui#Jj1nvY+Afg^pnaRAu^ZX#50 zfAy>D=UA_&(W@u@zrG%^Pop8P$$E^?J0q+J&Qa#;yb)od>u#LPt!(+`6q3PxIe*e@)l=T%h+B3`A!>mE^Us z!_J>6mB#kW)MzMMc{7Mj*b8coI?e z2}$ARZ7C+hY?7%kyQb$`^gEv3k~~2v;hN0lm64dk%?vZmZma%+$7(R=$Oh==|Hd|x zfzdi4T=XT*%cYY`Elr?WWga~;3FSQTJ(8^xq?wVw?@X%yAJJk?t4#->!GksCmqlX& zuX;lNy|Z!{>uCStG;V1ZbDQ3ua9yhKk3_V(8PBEjx1eR6-#mPeQ-kpplRqWJQ!eN~ z2f53xd>Ek{sXA*TZIv5BLgqQd;8o*C9=>MFVza*gNQ zQ$fg;b>v_S5>}4Xba$XBVJIm&phz5l@yzBT;eh|%it>3IdYrDC`dyehXq(j=gah4J z!y7f0hILL_-x7?iEss(7D#r{Yv@s`_((8X~w;b$ZkZLgR!{Yn(WX@X3BWH_qi6JAK zcOY&%eaEVEs1oJ$wY|kZH%H}Lx@yTF!Lu}#T?{KeTX0CFenYObOEp>my!7>qPQv-A zHDWNU;%=9lg46dDtjbpt@n&a3BxT|{cRlps4du;vz!fy7BV%Q=;zM4~-p$|2K*%`( zU#siu8v9w&H>>%)FuLU}&(z(=8nPkdVhJ>o@XV8(_ zvy>Gr1gilGK;%~1|97QPh8N`u%$@9&ZPv-~d9;2J*tWp72GJ>|!>9BOTt)~(yFiGN zi}`wG5o^GcU=~veuZrn@#|v(?Rw}5d(4j6Vf%f=HcjYaM@NR*~og}=DX`VX;ygNSP z->e7qWjzIGGRnJ&p$vJ8Rp(8$`r*?r=kt!Cjay}HIJz}|?88s$-!Do+{D@fi1;S#R z5*QgQ;{z(ZE?46#Gp_PY3DiLkKeJU0nYrsXv(ofe`R_mHbE-s+87oPlg0UoWG~jKs zL%$rGFS|c{+KE@m?LUgSL<6?brbfgeZJW2ZePljtSY*M~p+CqrF;M|7x9NoA2%t4g z2;j8@5bBVSZ{qcuYc64=#IMw3pt_@1a2=Q0hmy$eOik z2@VhE#eF5jY1n2sLImQRKGA^-QjEQvwqxt#txvW+)XWwB8xO0>v;%BnTCQi~D{GPo zq=}|a?#z$j5?Pu)u68?}KAn2|8!)2zV%6-vdhK7Ile!wtXGX-YZ4ik>^9&nvH?CKS z5D;C$C|As{CZ*xsv-8IyF#Ve?H=94-f%#a7wM&>seJ5 zJC|mJJ0pJ=Q?2_YOb5Vh!IQwXJ)=NOkdcYKq^G$W=?Vu1n< z3OQp?(cpwMi)u_~AF-uCt6p4H*YGAO<&l9$!Yk7x52A#3&p8VYfBPD~HQk;7Q*%N+ z`)HW&{EMY*=!QnI1yjJy=Cem|WlzfG#ORppY!7)hE=4fGafJ!l;O+1dZ}(L(YHpi> z5l%1b_ak!Iz+9uViXL`xm|oXjOY7N0M?_eOoev*P+x}|rxXdmjR*%l@j!l37Y`1#9 zT9!7hXa(UMgp-18b;AcF*FB3~;4NtP(=o^bfb zXW3;A5hx8<15@znkjG-)3?|jPnYl3sE*1-;G#6kY@Y=Wp>&B#N#>;wv5M-sfSS!5S zm&lpSZ-e4HSQmI&K}v*x8KU@A9XyfKX=)@pZmc~9%=)7*h-vfBwIK05g{1045(tf5 zU!w>LZeCWn@h?tkWr8;g79mHr>K%r#il=Wt!V#zp&@_k}qCAx|Qul&oX^Ji_cAvMY z^-7Lz@m_IVzqz7Zk`!RLsFu@P(rhPz1k2vXPr6x zV5uH{BQi=hC!yL#Hn=etA)P=;QtiiH!vP8}!~NBA`e^^Nv%_bH$LA0=`pJp}@8x>2 za_tw`V1`4eEsGs z+8!ky20J1?p3pr+v4YM<@;Kd@0aA4J0s%0xz{j0`Ka@ZmW_&!Hq15Db>K~1$}*U=Q#_g1Hte8NE--QQJQT!B_aNk zRvGM1TWSrtFZ}9SsZc+AkuaLFyviIbI&`t1X8mCsl+iGPbIH}|C!-tK?f+S7Hxt23 zaWFHggw$#B!KC%0f?U;rd$~ZI2M-=REcW-FKmTHH|C{3M#limJmxud;UJ`heh5#mD zZwX!uQJQEL<`%9QLD&2~t|+P;-Luh>?aZ$a?65Faos*JT&c!&f+UAWsb6KG{9@-4; zf-xA@lJehbw;}aZu^?l*B$OeM;f|ORBV{P2h#M*3djI*Pzb;rA?}GNt-mgf2N3^Zd4Zld2i^I zvMimVOA@XGE}AMh9QuJw8Dsb*jsR9{OUw#@Sph&JfQhbW-QGa+j`DJd?!mk4e=>lm zI73$GXralm@2DD*&YJ%PWJjb$t+?jspDuUPmp`AK?0+*nJ3l?xJJQPvKAwhJhTLa2 zZrPPiyR3MH2S%-%bhFb4u`U1hXtiFjyihOR{b#@UtN#}todlLhPMF`VSwr{Of+RQQt@Ot*QKfx^|Jkc8I$pwn7!wLTr+oD+~$MBw7gErc*N6^%W*s9p)AGbbf zrKmMG6@9?-vyszrzX=_YeP9yIiyIxGb=m%cL76i!Fqj{@ULFQ~7noknn^nd)FEf9+ zo$V3;_sueUHwTn`(9ntMin|v@J5gd_QFu2(;FN)m5zt5=mkObc2F0FkDpaF1R3ZCE zz~ID01#$iGmV{Wqk^TLC1`XY*ZjJlj(lwnND&+=`y$Ls7;I!?|<HV(8hNEBgp{M{*0Y7xn#U^lzkIFB~)#rESG{ zj5Pj6pNk@=?0HmJ00Tqfhm2YA2>gqUZXM7d!ktKE`YR`3N9h-U3Qa%ZC00#{;etO3 z{gd_#wIc%nbNqj@*hXO+kK_=Z3&0%yAZkQhbX#WwA4(6G+R6e`Dank}wP`(va!&fi zFfD}8cmw$(9Ba#baoKpb`TTsrXaK{A>!dfs*1JRR?Nd-1%+|xSE9UABZkFYO*KWhm z!wryiuP0Ry3*wD;wpcsC*sa)g`L=>bj&5gp(MeZTcR&~jZ@U4LTapp5T%IBwbZMMP&j_64R z>P@~j`Ovmh-9jJOidT6TAd272xNk~)aU^shy)yy*rq?6h8f51!8$+9crFIjiD^e!5 z9;}kW`3jsb7D6Mr8_$DIaOaPUN6cYmN?v;jGSp!fv0^`Sj;$$18cOvLoBtAUUU^U}j{lSjth6wxQ}MBIq6ZZ&7eh$`@7aZM~MfT}8u&I&Un zkBOTqQf$_*i2O z<6i2kaK#q$X0`RX4IgmTB}axgFAGInYQ(R$Rii-3fguL5B<&Blet<BPv;g^xG>!mPU+6`A3Nbr`_)299}rq zxy9~>S2SEWFIrY&ya$GJi7amUd<~&5bwdrkC+l@E9X`$=4 zKB#86ft~g0^3i{F66-`Uu;GNB?VWzH_w2yT<~byq+j?$ELP?e!vZNpYQap8~RrA(v zJ@f2Q@l(K^9-O~CJsuvJpUn)nS)iQdNQ@fShe_miFKs@baDXM5hA>FMgXr=RdqnFmS)m{K9r+YPo;js&<%d3{TR z;GacIoeg256t@vYVExR-C>7=d1tEB&Y1D%=rTmdJ;JQ3cv~_G&jSt^~^i7F;pTwk_ z;|v8p66VcWGcDh#LVXeIL$NjMrD6cy={!;A0%X!72AV%5TG-UjCJxg9#OhbtF#9sN zi|6z-oZ#$X>3r1V0Ih9QLL-nVy-7Z=2CJy!sHWgXg}C`y8^)L6Lt4GLxnC9FMr&AJ z^r5%9er8~o2I{!TrOaAA19r`1c%zZ4k$o$f7W2W(;xFlPzL0iE{3%c|X!Pb)9vb91 z>|f48-ji+u4FM!IS1jl~^|75wL3PX#sUhjl)AG`oj=xMh^xYax8FXr9eMl559_Toe zZ3=uQZl9NJ<}WK-uHp%GTk?9`dU@}?YWQVvj4-vQNDb22q%H)BT_gKAh=DqN1o#}E zW|Ns#oXKoTU<7=_NbrYzCQD^&&HUBmlLMIh6$02`=UrKFH6C%vnD#qf5;Q58PQ~r~ zBi)0;3@NudQgVvGnYGWJBQZwuDH+@>3@k#whIb{3TJvc3~c7?38yCcsI2Ul2g$Zm3D$ zT^NNkkX-~ePSTCnygeOft$fwf(byNp3h$MPqhbEWgB#VnD>P-XIVGE^D zMyryNLphux2f&zUHkIt>SGbQYuA~FLj}wsj2B9xOjHhwHFgMbR42p+~ljh!pIb@`C z#2#zD{Qc030c4%HXyeqnbsH|Hk;MazBlAUa>lj7ww3<&@I+xUTmV7ta#p0D3E{_(l ztf9_Aw^~u!Ce|n$S}FH|#lOfXUvE#M!5-Dao4LUXD>cxr54t|knGm)HMV3D=laO&$fVU=bM;0GQzg0MP$ZmN060C3lj1 zxCnB#I7=cWxU%T+<}izURJzM!-b=5@pF?07vORH$eV=mN(Uo-d(6BJZO`z00GL_V@ zCC)D;rFtD1U-?8{b<%O=BC}55tOuOD~VZBHolYp6m$;n8+V6*hDh3ce+j24^Oe5C@ZD|Sqj-3oSXQ}6 ztKN@B{IbFLR>~i#GiaH#vnEp+wF&L9()`4gu@}JsWm^lf(Uca!KL%)Bj7(HwZRI2O z6sda%s75Gc4N3$Y(j7hD6DC6uQ8{Yx@z4RaV{=CW3_3OY1cPT;bcm6e&5$*1u37^uENgU11{V3G_P zpn|;`FLdBDVmo{*Cyicuvn^;TBAjcMvQ0JrT{*Q(3tfte9#eY4Jg)a49u}$#P$JqY5c1+;L8^Dz?L0tt1@wF>_tirt>%AvJ6T*yR%;hp#`+B?K7 z{eDLSv=!t}>76H?R&Qx`x8~;~{n~D~@5Wf3>7*o0dC$esp2#hu`ZV*-te$b<`17z` zE!0T(Osv4|pb@DL-bgz##8#|*v!;B<0E%Z@p2n-d^^#8P)#o$#mdEg0>+p1_*|pov zyGuIt=65@bU(E8Z#B_pY|WiP_SG;Cm&s(b4*)_(d@mw(VioYH{-`L zUb~h$Ft7ZCNIU(!hWm{P8T2B%wL-yh{g+rO4{sv(I1+?mfqY>xx^lP+|FCP5_@4i<$TtTH~-Z?!r!n#V<9Pl3Va0lu> zN%Q}_J#NZPbIn-nJQWH|!3id&-SuQG!9_`HJln0HPPvzd(xUE+@^e~uTDBzK9W_&n zS`AFyd3I*~s`haKkl)}ReE0yF^~Z^KU3Qsk^hQslbn|vbSTc`5Ome^F|gVPO{ixSp49Ppj7g4e?0_TEh=2P zoRSwJcq4b!HjKC{Bc2~W2s(Dl+`~(F16tz-Mi`fSJWh9A@h;>@qyP)SvM z>XbLN4Kd9Q1U6Z#fKzf*LX1wSz;*M>5~eE!I^AFdX=8uA7$e07U@y*iosd1OaY&;P z%7v?Vj^03GnJI2EMv$cwXQ6moYV(<2pB1tl&$a9_NG0t9k&y}T+{jS?ph;an=y>hl zeppD$^s{k&lvdiUM~C2mAw`=DqG#lQB=Gma4%b|9+{e6(uY_A z$K5mwu1Wks-p%dQjR?-_^Xe11<7NpVWIbux0%5gykMVu1|H-jfVB#^k z9a9d&?KDzPP<>oQ^W~NOBIN-1=#_fJ5*+hExp0agCu&Jcqun{*Pphc3ILe06x$iAS z9iI{fgkir_+g#(p$KN#VthDZ`BiYw}SxTw2v}8vaoy<1cFIJ6u_mqR?5)T^D9?10y zi9?o!SE0$=o}m7OxO(LDD`zLOTK}CG94Dt2DS}w`>CYVT{^;&=04LvlEY#33leVSf zh!^nD2Qi;3rJm3-*Etl5Hz5etEKQ^;Q`3E` zl64Q1QrRaz$wNc}zW_hpKpnweR51fM(W@lv7QNE8$X*jPImuh+R-sGV0FlluDC(0` zEY&miqegGsNTgrckYU_14tde?VhJXj)_M(WxQkZoL-k~=>u~C3d@#gOpTbrm&Z18F zBPD`SIF>qAJm{!~YH{)bfz{+SgRHAZy{U(HM}JUq{x9g0GqCD}zV6wxBEk@+07!Q= z{>*pgLrJYQ9Z`SMT6CfAdhzh$9*q0z?DG>HM2XD0mOi0!Gq(DR2z^kJeocwPr|t5? zrYMi2N+T>TT3N?XAMQc-LypZspPm9rL~E5~AH2zEVNncPymSSl;^1S{AEQTP`y3Vn z7O`xBDV1IWNUE-A&3u*|OP|VxU&xs7nU9li;s!4GFl2})T(7t92U%k52vETEnti%f z0S+CHFTy-3Z0apQjZO@mE`aVBle^|%U+SmUd|G`sdaifgM<_+IW?Z=yYj+Emi3Ks$ zZ20^ksh5pR`SguUccIqkD*4x80V+3lO{1sC>zVQI%GauB-bi7H<3JWk19Z?dq+X30ofw*syT|tWx0Cu!M-4mKxX` zx|5LFd-{XEa=o=BKnoz01(6LzeT>BSq&!cZ?M8h_-dKC@sTe$|rfG25cGn5G{Us5^ zENBc_y(Q+BA*iyqE`yP8FyKxaiE2rzIFTzgd1Sk^bq`j&%HYc4r2ZSZ#Y?H9z zV1ug}1$V70G{#xwIs=pX$n!r7ZxHlJ#6@O98^<~)CEhTuJ)kb)rRLib^s5U|cDZTb zOIGC)a+6N?&O^ZxA#S)=#>e%xVP(4Dt`-Wi)nG6jf@M*|vC+sGnPAgCc3)O)($W}z z0U(yCDA>i9v+6*H<$R_ZSQUuqvr+mwhuQ+MBIS{lgl?LZ;@+So0Ogy+Wz{0~S10+0 zP0{%SYDN6ivp4j%Xe%^yUqU^4F-NFUTZr|Clu?USgI-gQbx%>;v!aqIfDu*|@(=+o zp^Fx3GctoRZ&>TGD#tpW+Y4*$(S0J^GLk(=oF-|H7pw(HzUzc>ph+mCl}|4o5%F~i z8PFJY=!p3QtY2DD?&*5Y_pG*Zv)?(z5r+_DG-G#U9tOiN;$Pi55qQi&Az~4wmg=UZlRSFXNsF07NYqwM& z2wh36MzO07sB2Z5sO#2R$LP-M1d}R5>j0(KruX|owqvp}5_Ky{8?Xdv*6eQoF zJfJ5{A3`I_t6Kx&;!(w82m*N$YJ%1!B7%I@78nv zc6htbaW%7z`@D^FjCdQJGQzSxrb7+hwF%$P?Rtf9>j6XZuyx#x=de}${#g6P$ERii zhwn7|%#_!n87KYqF@qJ11Pislm4=|9E_4`@F@TC6HGvWNhIgeM=+Xa9ADaBDV!OBC z!};Aj9@-*S8ISmEehCxefmH=XM&u$EW~HX7*9i643z2V*Cjf8xh2rV)aX99*T&u$U z!}*hXR(s|uF5GJ@TiqhBY0L?`)2fSujNE$k<82b%qA)=rqki?z8=nv`FU`|UMT&8w z3o;nn=juagp`LSCK-XXToo;+1jmF+ijBnD%CidF%#U9gtQlm1&pzk>d4t#~G z@^~EqAI0Z_iZKoy8+~xaWlz!KS+Zq{8$WjHj=8)XVS^*(_d%5K=ArsC)#q{(bS2_AXAjLb;mxD{w=CT9(j!}vl>Q8?F7sGUs1yy+V->8vo{phVYUtf<9~gsljmEXsBLr5;|Z2$%bzM<1y) z;-H&64rl0{+;URM#7{TZG%53+CqO@f27rzA3h7ym6?d?Rw z?gly|fZKY{x57&{EX8)1lWHkD3nI4zUs3+EL&dPr1j*r-BgWC;oh}~3u<4_CJcF)lyR;zCsy&0w&Lw_*~4>!{` zQz8Dv|BH|od9Rr{L5;W?%=s1!h6nM?p=2lqj8%=uIFnWj+|&qnFnHnAP&YLk>lx>P zmRge!=gx26MVy1-+xA8*f6+r!!k#>fyA?W z$ta`d6Zip7OxF3n!i>~!go4OuE|{y_QEhza{i~Ru_HDWbkLy@26F_@=6^t%o6;+=G zD8*q9ced-j4HS}$OK0=$nmRPw&y>2#VwmQ7rS=Y}>#i950W)ymyz=MvbL)?DOt~<9 z?p7-Ss?3%=DclrIcl*(UEdB#l>c$dV{aIEM!DpQPP)S@1+h>_ZCRm-lb_oXasX1Xw29_*O3K~ z0id*QS4pbA1B0G&LnLQ2jjOv}W7n>_3R!b)9f`4}hv4n4M7I12$1}DOGGyI`Ax|-3 zr!u>HB{WWI*>VEJisyp|Nl9QK@!F zm!;Tpvqf~gkxcSE$bwBV=a3!gz=km8;b$TwJjf&aKo~g zO@ce$pc|VsG7albKwJmVD@N@-uL}h9CtLI;Mpc!UWDYBu&SroQrqP!3m70W4nQ z8OZCjM4aeywiuTR$Q6`d82)QislfDETI|PbyE<5i2p0nYK>KS{$;sLh;>rFm|1AO_ zQ}cKE+n`dJrhLpI7kcZVCiLHv*Yk~! zSRA%F{~&G|bF$3O*>vvi38(9g9oP~ViW%c>!4|)(4LnHmA>eiW-M&wO)om*pMba-W zs`B$2h6muV1_|~Fxi4(?NvtWRIAmkHLM=d~9gZ;=t$VapFi^$*+9G0M$yYx31qEjMC)7y$dXVcbav7gJzru4JPT>QnLb$&dyEE%%$^A} zNz?rby^2uIain`xj5V_jgbzy>I)jrE|w?s8P`Jz8m=b!uVke(hY#klzJ zK!35R&{=>oIx^(|MUHRW`WDpsVT%C!#&nmmj+e%ZbVDX1@d$Q97)n(2@?Ir2#)y2xrGRPsW&5iv9E|+iG`;v;~6;h$PDo&j_&CD3L$_)%6yvlk= zDCGXZk%w^JERo1ttQrIKMoMWzq4MHKU;OyCaAVVp-<4QTc>h*sfjB6yKtNOOoE=wz zlcs-Hs4ir*i=sH8uvhM+5Xgw+N+%Ly>qJ(+ED6sNPXYpy8f*g_5k?38&d;UM;!ii9+C9W$Ulym%bg^^ap6z1Tn>e zrl9JI9tRTsZRN+ z+yvQoCY@ek^yIix*J^@y`Vz*O54VV2?S3_f&lr)V3-G<(*!%JRB9|ScC-Q}k`tt;< z#K9v2=Q~#SHF$;MYmg~2e(I%kunIecejW1?3Wu#2Okid&gFRr?h3$h#LZ42tWK;jz zelM|QyAM<5%qI2H$#aZ@@fsP;2L6{`U*?0IW=6FSmFKM+-cc(e{WVRg_WTUR`+Ut7 zY||w2E1K_FiV;k;^nRWMI78ody^|bLijX6e@=?N5OjdPR%urcGve2Jl=YC=k(cE=b zP?{2{tSEj1CP3()F&5Cm3E zhyM(~*BFnr`H;vpmo7*jrP z-U+kZ3Hykv;s|uNScYfd@75m?k`Vv^mfxt$m&|&A2q62v%YU{0uI3+|Ao}yaQVQ{Y zBf3k5^M6k%7#jZbbB2bb0;Dvo%xtX8Jj`q~h)QxmC=%%3SqB6FZB#A9=K?(^4|(?i zQqUo?_1`Z@UIOqf;P;+`2Zs%dp`o?CwX>n&KYkJx09{%x^!J<1pHx{XI{AmN&?r*~ zOu!ld@dExyP5FyzYHMw7@9bp$5mM2=$RM$#zaPk-V00&6f^)Zi4Fi@{2S8^^iQURHRP6S?fI{&{vSd3 zPmBGZK}aI~C-xtz#J@Z3AtC-hD*d18#XsqS%zx7V6fr``zeS9H>nZ#h;{Qx#{YjP% z{wMk05?c`Jf0NYullq@snLnvnkaOZ6f83uP8wm8j>D2sx=>K$u{0Ys-{3rC!>jeV+ beZBn0@x#LXKG|PEHY~_#@C^AX0Pz0+FY9IN literal 0 HcmV?d00001 diff --git a/python/setup.py b/python/setup.py index c447f2d40343d..7da67a4109ed1 100644 --- a/python/setup.py +++ b/python/setup.py @@ -201,7 +201,7 @@ def _supports_symlinks(): 'pyspark.examples.src.main.python': ['*.py', '*/*.py']}, scripts=scripts, license='http://www.apache.org/licenses/LICENSE-2.0', - install_requires=['py4j==0.10.7'], + install_requires=['py4j==0.10.8.1'], setup_requires=['pypandoc'], extras_require={ 'ml': ['numpy>=1.7'], diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 67d2c8610e91e..49b7f6261e8db 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1169,7 +1169,7 @@ private[spark] class Client( val pyArchivesFile = new File(pyLibPath, "pyspark.zip") require(pyArchivesFile.exists(), s"$pyArchivesFile not found; cannot run pyspark application in YARN mode.") - val py4jFile = new File(pyLibPath, "py4j-0.10.7-src.zip") + val py4jFile = new File(pyLibPath, "py4j-0.10.8.1-src.zip") require(py4jFile.exists(), s"$py4jFile not found; cannot run pyspark application in YARN mode.") Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath()) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 58d11e96942e1..506b27c677f55 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -265,7 +265,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { // needed locations. val sparkHome = sys.props("spark.test.home") val pythonPath = Seq( - s"$sparkHome/python/lib/py4j-0.10.7-src.zip", + s"$sparkHome/python/lib/py4j-0.10.8.1-src.zip", s"$sparkHome/python") val extraEnvVars = Map( "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index bf3da18c3706e..0771e2a044757 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -28,6 +28,6 @@ export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}/conf"}" # Add the PySpark classes to the PYTHONPATH: if [ -z "${PYSPARK_PYTHONPATH_SET}" ]; then export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}" - export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:${PYTHONPATH}" + export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.8.1-src.zip:${PYTHONPATH}" export PYSPARK_PYTHONPATH_SET=1 fi From af3b8160704b27dd8ed2b95b61edeec6968685be Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 31 Oct 2018 10:52:51 -0700 Subject: [PATCH 1970/2461] [SPARK-25855][CORE] Don't use erasure coding for event logs by default ## What changes were proposed in this pull request? This turns off hdfs erasure coding by default for event logs, regardless of filesystem defaults. Because this requires apis only available in hadoop 3, this uses reflection. EC isn't a very good choice for event logs, as hflush() is a no-op, and so updates to the file are not visible for a long time. This can still be configured by setting "spark.eventLog.allowErasureCoding=true", which will use filesystem defaults. ## How was this patch tested? deployed a cluster with the changes with HDFS EC on. By default, event logs didn't use EC, but configuration still would allow EC. Also tried writing to the local fs (which doesn't support EC at all) and things worked fine. Closes #22881 from squito/SPARK-25855. Authored-by: Imran Rashid Signed-off-by: Marcelo Vanzin --- .../apache/spark/deploy/SparkHadoopUtil.scala | 32 ++++++++++++++++++- .../spark/internal/config/package.scala | 5 +++ .../scheduler/EventLoggingListener.scala | 7 +++- docs/configuration.md | 11 +++++++ 4 files changed, 53 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 4cc0063d010ef..78a7cf648e7b8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, File, IOException} +import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import java.text.DateFormat import java.util.{Arrays, Comparator, Date, Locale} @@ -30,7 +31,7 @@ import scala.util.control.NonFatal import com.google.common.primitives.Longs import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} +import org.apache.hadoop.fs._ import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.{Token, TokenIdentifier} @@ -471,4 +472,33 @@ object SparkHadoopUtil { hadoopConf.set(key.substring("spark.hadoop.".length), value) } } + + // scalastyle:off line.size.limit + /** + * Create a path that uses replication instead of erasure coding (ec), regardless of the default + * configuration in hdfs for the given path. This can be helpful as hdfs ec doesn't support + * hflush(), hsync(), or append() + * https://hadoop.apache.org/docs/r3.0.0/hadoop-project-dist/hadoop-hdfs/HDFSErasureCoding.html#Limitations + */ + // scalastyle:on line.size.limit + def createNonECFile(fs: FileSystem, path: Path): FSDataOutputStream = { + try { + // Use reflection as this uses apis only avialable in hadoop 3 + val builderMethod = fs.getClass().getMethod("createFile", classOf[Path]) + val builder = builderMethod.invoke(fs, path) + val builderCls = builder.getClass() + // this may throw a NoSuchMethodException if the path is not on hdfs + val replicateMethod = builderCls.getMethod("replicate") + val buildMethod = builderCls.getMethod("build") + val b2 = replicateMethod.invoke(builder) + buildMethod.invoke(b2).asInstanceOf[FSDataOutputStream] + } catch { + case _: NoSuchMethodException => + // No createFile() method, we're using an older hdfs client, which doesn't give us control + // over EC vs. replication. Older hdfs doesn't have EC anyway, so just create a file with + // old apis. + fs.create(path) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index e8b1d8859cc44..356cf9e76c85b 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -58,6 +58,11 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val EVENT_LOG_ALLOW_EC = + ConfigBuilder("spark.eventLog.allowErasureCoding") + .booleanConf + .createWithDefault(false) + private[spark] val EVENT_LOG_TESTING = ConfigBuilder("spark.eventLog.testing") .internal() diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index f89fcd18ef56b..788b23d1bfb03 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -67,6 +67,7 @@ private[spark] class EventLoggingListener( private val shouldCompress = sparkConf.get(EVENT_LOG_COMPRESS) private val shouldOverwrite = sparkConf.get(EVENT_LOG_OVERWRITE) private val shouldLogBlockUpdates = sparkConf.get(EVENT_LOG_BLOCK_UPDATES) + private val shouldAllowECLogs = sparkConf.get(EVENT_LOG_ALLOW_EC) private val shouldLogStageExecutorMetrics = sparkConf.get(EVENT_LOG_STAGE_EXECUTOR_METRICS) private val testing = sparkConf.get(EVENT_LOG_TESTING) private val outputBufferSize = sparkConf.get(EVENT_LOG_OUTPUT_BUFFER_SIZE).toInt @@ -119,7 +120,11 @@ private[spark] class EventLoggingListener( if ((isDefaultLocal && uri.getScheme == null) || uri.getScheme == "file") { new FileOutputStream(uri.getPath) } else { - hadoopDataStream = Some(fileSystem.create(path)) + hadoopDataStream = Some(if (shouldAllowECLogs) { + fileSystem.create(path) + } else { + SparkHadoopUtil.createNonECFile(fileSystem, path) + }) hadoopDataStream.get } diff --git a/docs/configuration.md b/docs/configuration.md index 432b4cda47db2..8cb0ed1502126 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -761,6 +761,17 @@ Apart from these, the following properties are also available, and may be useful Compression will use spark.io.compression.codec. + + spark.eventLog.allowErasureCoding + false + + Whether to allow event logs to use erasure coding, or turn erasure coding off, regardless of + filesystem defaults. On HDFS, erasure coded files will not update as quickly as regular + replicated files, so the application updates will take longer to appear in the History Server. + Note that even if this is true, Spark will still not force the file to use erasure coding, it + will simply use filesystem defaults. + + spark.eventLog.dir file:///tmp/spark-events From 68dde3481ea458b0b8deeec2f99233c2d4c1e056 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 31 Oct 2018 13:00:10 -0500 Subject: [PATCH 1971/2461] [SPARK-23781][CORE] Merge token renewer functionality into HadoopDelegationTokenManager. This avoids having two classes to deal with tokens; now the above class is a one-stop shop for dealing with delegation tokens. The YARN backend extends that class instead of doing composition like before, resulting in a bit less code there too. The renewer functionality is basically the same code that used to be in YARN's AMCredentialRenewer. That is also the reason why the public API of HadoopDelegationTokenManager is a little bit odd; the YARN AM has some odd requirements for how this all should be initialized, and the weirdness is needed currently to support that. Tested: - YARN with stress app for DT renewal - Mesos and K8S with basic kerberos tests (both tgt and keytab) Closes #22624 from vanzin/SPARK-23781. Authored-by: Marcelo Vanzin Signed-off-by: Imran Rashid --- .../scala/org/apache/spark/SparkConf.scala | 4 +- .../apache/spark/deploy/SparkHadoopUtil.scala | 14 - .../HadoopDelegationTokenManager.scala | 278 ++++++++++++++---- .../HadoopFSDelegationTokenProvider.scala | 5 +- .../spark/internal/config/package.scala | 4 + .../CoarseGrainedSchedulerBackend.scala | 40 ++- .../HadoopDelegationTokenManagerSuite.scala | 142 +++------ .../spark/deploy/k8s/KubernetesConf.scala | 3 +- .../hadooputils/HadoopKerberosLogin.scala | 10 +- ...bernetesHadoopDelegationTokenManager.scala | 35 +-- .../MesosCoarseGrainedSchedulerBackend.scala | 19 +- .../MesosHadoopDelegationTokenManager.scala | 160 ---------- .../spark/deploy/yarn/ApplicationMaster.scala | 24 +- .../org/apache/spark/deploy/yarn/Client.scala | 2 +- .../org/apache/spark/deploy/yarn/config.scala | 4 - .../yarn/security/AMCredentialRenewer.scala | 177 ----------- .../YARNHadoopDelegationTokenManager.scala | 48 ++- .../cluster/YarnSchedulerBackend.scala | 5 +- ...ARNHadoopDelegationTokenManagerSuite.scala | 5 +- 19 files changed, 355 insertions(+), 624 deletions(-) delete mode 100644 resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala delete mode 100644 resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 5166543933b32..8537c536887e6 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -731,7 +731,9 @@ private[spark] object SparkConf extends Logging { KEYTAB.key -> Seq( AlternateConfig("spark.yarn.keytab", "3.0")), PRINCIPAL.key -> Seq( - AlternateConfig("spark.yarn.principal", "3.0")) + AlternateConfig("spark.yarn.principal", "3.0")), + KERBEROS_RELOGIN_PERIOD.key -> Seq( + AlternateConfig("spark.yarn.kerberos.relogin.period", "3.0")) ) /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 78a7cf648e7b8..5979151345415 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -413,20 +413,6 @@ object SparkHadoopUtil { def get: SparkHadoopUtil = instance - /** - * Given an expiration date for the current set of credentials, calculate the time when new - * credentials should be created. - * - * @param expirationDate Drop-dead expiration date - * @param conf Spark configuration - * @return Timestamp when new credentials should be created. - */ - private[spark] def nextCredentialRenewalTime(expirationDate: Long, conf: SparkConf): Long = { - val ct = System.currentTimeMillis - val ratio = conf.get(CREDENTIALS_RENEWAL_INTERVAL_RATIO) - (ct + (ratio * (expirationDate - ct))).toLong - } - /** * Returns a Configuration object with Spark configuration applied on top. Unlike * the instance method, this will always return a Configuration instance, and not a diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index ab8d8d96a9b08..10cd8742f2b49 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -17,76 +17,158 @@ package org.apache.spark.deploy.security +import java.io.File +import java.security.PrivilegedExceptionAction +import java.util.concurrent.{ScheduledExecutorService, TimeUnit} +import java.util.concurrent.atomic.AtomicReference + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.security.Credentials +import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.UpdateDelegationTokens +import org.apache.spark.ui.UIUtils +import org.apache.spark.util.ThreadUtils /** - * Manages all the registered HadoopDelegationTokenProviders and offer APIs for other modules to - * obtain delegation tokens and their renewal time. By default [[HadoopFSDelegationTokenProvider]], - * [[HiveDelegationTokenProvider]] and [[HBaseDelegationTokenProvider]] will be loaded in if not - * explicitly disabled. + * Manager for delegation tokens in a Spark application. + * + * This manager has two modes of operation: + * + * 1. When configured with a principal and a keytab, it will make sure long-running apps can run + * without interruption while accessing secured services. It periodically logs in to the KDC with + * user-provided credentials, and contacts all the configured secure services to obtain delegation + * tokens to be distributed to the rest of the application. + * + * Because the Hadoop UGI API does not expose the TTL of the TGT, a configuration controls how often + * to check that a relogin is necessary. This is done reasonably often since the check is a no-op + * when the relogin is not yet needed. The check period can be overridden in the configuration. * - * Also, each HadoopDelegationTokenProvider is controlled by - * spark.security.credentials.{service}.enabled, and will not be loaded if this config is set to - * false. For example, Hive's delegation token provider [[HiveDelegationTokenProvider]] can be - * enabled/disabled by the configuration spark.security.credentials.hive.enabled. + * New delegation tokens are created once 75% of the renewal interval of the original tokens has + * elapsed. The new tokens are sent to the Spark driver endpoint once it's registered with the AM. + * The driver is tasked with distributing the tokens to other processes that might need them. * - * @param sparkConf Spark configuration - * @param hadoopConf Hadoop configuration - * @param fileSystems Delegation tokens will be fetched for these Hadoop filesystems. + * 2. When operating without an explicit principal and keytab, token renewal will not be available. + * Starting the manager will distribute an initial set of delegation tokens to the provided Spark + * driver, but the app will not get new tokens when those expire. + * + * It can also be used just to create delegation tokens, by calling the `obtainDelegationTokens` + * method. This option does not require calling the `start` method, but leaves it up to the + * caller to distribute the tokens that were generated. */ private[spark] class HadoopDelegationTokenManager( - sparkConf: SparkConf, - hadoopConf: Configuration, - fileSystems: Configuration => Set[FileSystem]) - extends Logging { + protected val sparkConf: SparkConf, + protected val hadoopConf: Configuration) extends Logging { private val deprecatedProviderEnabledConfigs = List( "spark.yarn.security.tokens.%s.enabled", "spark.yarn.security.credentials.%s.enabled") private val providerEnabledConfig = "spark.security.credentials.%s.enabled" - // Maintain all the registered delegation token providers - private val delegationTokenProviders = getDelegationTokenProviders + private val principal = sparkConf.get(PRINCIPAL).orNull + private val keytab = sparkConf.get(KEYTAB).orNull + + require((principal == null) == (keytab == null), + "Both principal and keytab must be defined, or neither.") + require(keytab == null || new File(keytab).isFile(), s"Cannot find keytab at $keytab.") + + private val delegationTokenProviders = loadProviders() logDebug("Using the following builtin delegation token providers: " + s"${delegationTokenProviders.keys.mkString(", ")}.") - /** Construct a [[HadoopDelegationTokenManager]] for the default Hadoop filesystem */ - def this(sparkConf: SparkConf, hadoopConf: Configuration) = { - this( - sparkConf, - hadoopConf, - hadoopConf => Set(FileSystem.get(hadoopConf).getHomeDirectory.getFileSystem(hadoopConf))) + private var renewalExecutor: ScheduledExecutorService = _ + private val driverRef = new AtomicReference[RpcEndpointRef]() + + /** Set the endpoint used to send tokens to the driver. */ + def setDriverRef(ref: RpcEndpointRef): Unit = { + driverRef.set(ref) } - private def getDelegationTokenProviders: Map[String, HadoopDelegationTokenProvider] = { - val providers = Seq(new HadoopFSDelegationTokenProvider(fileSystems)) ++ - safeCreateProvider(new HiveDelegationTokenProvider) ++ - safeCreateProvider(new HBaseDelegationTokenProvider) + /** @return Whether delegation token renewal is enabled. */ + def renewalEnabled: Boolean = principal != null - // Filter out providers for which spark.security.credentials.{service}.enabled is false. - providers - .filter { p => isServiceEnabled(p.serviceName) } - .map { p => (p.serviceName, p) } - .toMap + /** + * Start the token renewer. Requires a principal and keytab. Upon start, the renewer will: + * + * - log in the configured principal, and set up a task to keep that user's ticket renewed + * - obtain delegation tokens from all available providers + * - send the tokens to the driver, if it's already registered + * - schedule a periodic task to update the tokens when needed. + * + * @return The newly logged in user. + */ + def start(): UserGroupInformation = { + require(renewalEnabled, "Token renewal must be enabled to start the renewer.") + renewalExecutor = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("Credential Renewal Thread") + + val originalCreds = UserGroupInformation.getCurrentUser().getCredentials() + val ugi = doLogin() + + val tgtRenewalTask = new Runnable() { + override def run(): Unit = { + ugi.checkTGTAndReloginFromKeytab() + } + } + val tgtRenewalPeriod = sparkConf.get(KERBEROS_RELOGIN_PERIOD) + renewalExecutor.scheduleAtFixedRate(tgtRenewalTask, tgtRenewalPeriod, tgtRenewalPeriod, + TimeUnit.SECONDS) + + val creds = obtainTokensAndScheduleRenewal(ugi) + ugi.addCredentials(creds) + + val driver = driverRef.get() + if (driver != null) { + val tokens = SparkHadoopUtil.get.serialize(creds) + driver.send(UpdateDelegationTokens(tokens)) + } + + // Transfer the original user's tokens to the new user, since it may contain needed tokens + // (such as those user to connect to YARN). Explicitly avoid overwriting tokens that already + // exist in the current user's credentials, since those were freshly obtained above + // (see SPARK-23361). + val existing = ugi.getCredentials() + existing.mergeAll(originalCreds) + ugi.addCredentials(existing) + ugi } - private def safeCreateProvider( - createFn: => HadoopDelegationTokenProvider): Option[HadoopDelegationTokenProvider] = { - try { - Some(createFn) - } catch { - case t: Throwable => - logDebug(s"Failed to load built in provider.", t) - None + def stop(): Unit = { + if (renewalExecutor != null) { + renewalExecutor.shutdown() } } - def isServiceEnabled(serviceName: String): Boolean = { + /** + * Fetch new delegation tokens for configured services, storing them in the given credentials. + * Tokens are fetched for the current logged in user. + * + * @param creds Credentials object where to store the delegation tokens. + * @return The time by which the tokens must be renewed. + */ + def obtainDelegationTokens(creds: Credentials): Long = { + delegationTokenProviders.values.flatMap { provider => + if (provider.delegationTokensRequired(sparkConf, hadoopConf)) { + provider.obtainDelegationTokens(hadoopConf, sparkConf, creds) + } else { + logDebug(s"Service ${provider.serviceName} does not require a token." + + s" Check your configuration to see if security is disabled or not.") + None + } + }.foldLeft(Long.MaxValue)(math.min) + } + + // Visible for testing. + def isProviderLoaded(serviceName: String): Boolean = { + delegationTokenProviders.contains(serviceName) + } + + protected def isServiceEnabled(serviceName: String): Boolean = { val key = providerEnabledConfig.format(serviceName) deprecatedProviderEnabledConfigs.foreach { pattern => @@ -110,32 +192,104 @@ private[spark] class HadoopDelegationTokenManager( } /** - * Get delegation token provider for the specified service. + * List of file systems for which to obtain delegation tokens. The base implementation + * returns just the default file system in the given Hadoop configuration. */ - def getServiceDelegationTokenProvider(service: String): Option[HadoopDelegationTokenProvider] = { - delegationTokenProviders.get(service) + protected def fileSystemsToAccess(): Set[FileSystem] = { + Set(FileSystem.get(hadoopConf)) + } + + private def scheduleRenewal(delay: Long): Unit = { + val _delay = math.max(0, delay) + logInfo(s"Scheduling login from keytab in ${UIUtils.formatDuration(delay)}.") + + val renewalTask = new Runnable() { + override def run(): Unit = { + updateTokensTask() + } + } + renewalExecutor.schedule(renewalTask, _delay, TimeUnit.MILLISECONDS) } /** - * Writes delegation tokens to creds. Delegation tokens are fetched from all registered - * providers. - * - * @param hadoopConf hadoop Configuration - * @param creds Credentials that will be updated in place (overwritten) - * @return Time after which the fetched delegation tokens should be renewed. + * Periodic task to login to the KDC and create new delegation tokens. Re-schedules itself + * to fetch the next set of tokens when needed. */ - def obtainDelegationTokens( - hadoopConf: Configuration, - creds: Credentials): Long = { - delegationTokenProviders.values.flatMap { provider => - if (provider.delegationTokensRequired(sparkConf, hadoopConf)) { - provider.obtainDelegationTokens(hadoopConf, sparkConf, creds) + private def updateTokensTask(): Unit = { + try { + val freshUGI = doLogin() + val creds = obtainTokensAndScheduleRenewal(freshUGI) + val tokens = SparkHadoopUtil.get.serialize(creds) + + val driver = driverRef.get() + if (driver != null) { + logInfo("Updating delegation tokens.") + driver.send(UpdateDelegationTokens(tokens)) } else { - logDebug(s"Service ${provider.serviceName} does not require a token." + - s" Check your configuration to see if security is disabled or not.") - None + // This shouldn't really happen, since the driver should register way before tokens expire. + logWarning("Delegation tokens close to expiration but no driver has registered yet.") + SparkHadoopUtil.get.addDelegationTokens(tokens, sparkConf) } - }.foldLeft(Long.MaxValue)(math.min) + } catch { + case e: Exception => + val delay = TimeUnit.SECONDS.toMillis(sparkConf.get(CREDENTIALS_RENEWAL_RETRY_WAIT)) + logWarning(s"Failed to update tokens, will try again in ${UIUtils.formatDuration(delay)}!" + + " If this happens too often tasks will fail.", e) + scheduleRenewal(delay) + } + } + + /** + * Obtain new delegation tokens from the available providers. Schedules a new task to fetch + * new tokens before the new set expires. + * + * @return Credentials containing the new tokens. + */ + private def obtainTokensAndScheduleRenewal(ugi: UserGroupInformation): Credentials = { + ugi.doAs(new PrivilegedExceptionAction[Credentials]() { + override def run(): Credentials = { + val creds = new Credentials() + val nextRenewal = obtainDelegationTokens(creds) + + // Calculate the time when new credentials should be created, based on the configured + // ratio. + val now = System.currentTimeMillis + val ratio = sparkConf.get(CREDENTIALS_RENEWAL_INTERVAL_RATIO) + val delay = (ratio * (nextRenewal - now)).toLong + scheduleRenewal(delay) + creds + } + }) + } + + private def doLogin(): UserGroupInformation = { + logInfo(s"Attempting to login to KDC using principal: $principal") + val ugi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) + logInfo("Successfully logged into KDC.") + ugi + } + + private def loadProviders(): Map[String, HadoopDelegationTokenProvider] = { + val providers = Seq(new HadoopFSDelegationTokenProvider(fileSystemsToAccess)) ++ + safeCreateProvider(new HiveDelegationTokenProvider) ++ + safeCreateProvider(new HBaseDelegationTokenProvider) + + // Filter out providers for which spark.security.credentials.{service}.enabled is false. + providers + .filter { p => isServiceEnabled(p.serviceName) } + .map { p => (p.serviceName, p) } + .toMap } -} + private def safeCreateProvider( + createFn: => HadoopDelegationTokenProvider): Option[HadoopDelegationTokenProvider] = { + try { + Some(createFn) + } catch { + case t: Throwable => + logDebug(s"Failed to load built in provider.", t) + None + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala index 21ca669ea98f0..767b5521e8d7b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala @@ -30,7 +30,7 @@ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ -private[deploy] class HadoopFSDelegationTokenProvider(fileSystems: Configuration => Set[FileSystem]) +private[deploy] class HadoopFSDelegationTokenProvider(fileSystems: () => Set[FileSystem]) extends HadoopDelegationTokenProvider with Logging { // This tokenRenewalInterval will be set in the first call to obtainDelegationTokens. @@ -44,8 +44,7 @@ private[deploy] class HadoopFSDelegationTokenProvider(fileSystems: Configuration hadoopConf: Configuration, sparkConf: SparkConf, creds: Credentials): Option[Long] = { - - val fsToGetTokens = fileSystems(hadoopConf) + val fsToGetTokens = fileSystems() val fetchCreds = fetchDelegationTokens(getTokenRenewer(hadoopConf), fsToGetTokens, creds) // Get the token renewal interval if it is not set. It will only be called once. diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 356cf9e76c85b..034e5ebbd293d 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -179,6 +179,10 @@ package object config { .doc("Name of the Kerberos principal.") .stringConf.createOptional + private[spark] val KERBEROS_RELOGIN_PERIOD = ConfigBuilder("spark.kerberos.relogin.period") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("1m") + private[spark] val EXECUTOR_INSTANCES = ConfigBuilder("spark.executor.instances") .intConf .createOptional diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index de7c0d813ae65..329158a44d369 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -18,13 +18,17 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.TimeUnit -import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.concurrent.Future +import org.apache.hadoop.security.UserGroupInformation + import org.apache.spark.{ExecutorAllocationClient, SparkEnv, SparkException, TaskState} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.Logging import org.apache.spark.rpc._ import org.apache.spark.scheduler._ @@ -95,6 +99,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // The num of current max ExecutorId used to re-register appMaster @volatile protected var currentExecutorIdCounter = 0 + // Current set of delegation tokens to send to executors. + private val delegationTokens = new AtomicReference[Array[Byte]]() + + // The token manager used to create security tokens. + private var delegationTokenManager: Option[HadoopDelegationTokenManager] = None + private val reviveThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread") @@ -152,6 +162,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } case UpdateDelegationTokens(newDelegationTokens) => + SparkHadoopUtil.get.addDelegationTokens(newDelegationTokens, conf) + delegationTokens.set(newDelegationTokens) executorDataMap.values.foreach { ed => ed.executorEndpoint.send(UpdateDelegationTokens(newDelegationTokens)) } @@ -230,7 +242,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val reply = SparkAppConfig( sparkProperties, SparkEnv.get.securityManager.getIOEncryptionKey(), - fetchHadoopDelegationTokens()) + Option(delegationTokens.get())) context.reply(reply) } @@ -390,6 +402,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // TODO (prashant) send conf instead of properties driverEndpoint = createDriverEndpointRef(properties) + + if (UserGroupInformation.isSecurityEnabled()) { + delegationTokenManager = createTokenManager() + delegationTokenManager.foreach { dtm => + dtm.setDriverRef(driverEndpoint) + val creds = if (dtm.renewalEnabled) { + dtm.start().getCredentials() + } else { + val creds = UserGroupInformation.getCurrentUser().getCredentials() + dtm.obtainDelegationTokens(creds) + creds + } + delegationTokens.set(SparkHadoopUtil.get.serialize(creds)) + } + } } protected def createDriverEndpointRef( @@ -416,6 +443,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override def stop() { reviveThread.shutdownNow() stopExecutors() + delegationTokenManager.foreach(_.stop()) try { if (driverEndpoint != null) { driverEndpoint.askSync[Boolean](StopDriver) @@ -684,7 +712,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp true } - protected def fetchHadoopDelegationTokens(): Option[Array[Byte]] = { None } + /** + * Create the delegation token manager to be used for the application. This method is called + * once during the start of the scheduler backend (so after the object has already been + * fully constructed), only if security is enabled in the Hadoop configuration. + */ + protected def createTokenManager(): Option[HadoopDelegationTokenManager] = None + } private[spark] object CoarseGrainedSchedulerBackend { diff --git a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala index 2849a10a2c81e..e0e630e3be63b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala @@ -21,94 +21,36 @@ import org.apache.commons.io.IOUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.security.Credentials -import org.scalatest.Matchers import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.util.Utils -class HadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers { - private var delegationTokenManager: HadoopDelegationTokenManager = null - private var sparkConf: SparkConf = null - private var hadoopConf: Configuration = null +class HadoopDelegationTokenManagerSuite extends SparkFunSuite { + private val hadoopConf = new Configuration() - override def beforeAll(): Unit = { - super.beforeAll() - - sparkConf = new SparkConf() - hadoopConf = new Configuration() - } - - test("Correctly load default credential providers") { - delegationTokenManager = new HadoopDelegationTokenManager( - sparkConf, - hadoopConf, - hadoopFSsToAccess) - - delegationTokenManager.getServiceDelegationTokenProvider("hadoopfs") should not be (None) - delegationTokenManager.getServiceDelegationTokenProvider("hbase") should not be (None) - delegationTokenManager.getServiceDelegationTokenProvider("hive") should not be (None) - delegationTokenManager.getServiceDelegationTokenProvider("bogus") should be (None) + test("default configuration") { + val manager = new HadoopDelegationTokenManager(new SparkConf(false), hadoopConf) + assert(manager.isProviderLoaded("hadoopfs")) + assert(manager.isProviderLoaded("hbase")) + assert(manager.isProviderLoaded("hive")) } test("disable hive credential provider") { - sparkConf.set("spark.security.credentials.hive.enabled", "false") - delegationTokenManager = new HadoopDelegationTokenManager( - sparkConf, - hadoopConf, - hadoopFSsToAccess) - - delegationTokenManager.getServiceDelegationTokenProvider("hadoopfs") should not be (None) - delegationTokenManager.getServiceDelegationTokenProvider("hbase") should not be (None) - delegationTokenManager.getServiceDelegationTokenProvider("hive") should be (None) + val sparkConf = new SparkConf(false).set("spark.security.credentials.hive.enabled", "false") + val manager = new HadoopDelegationTokenManager(sparkConf, hadoopConf) + assert(manager.isProviderLoaded("hadoopfs")) + assert(manager.isProviderLoaded("hbase")) + assert(!manager.isProviderLoaded("hive")) } test("using deprecated configurations") { - sparkConf.set("spark.yarn.security.tokens.hadoopfs.enabled", "false") - sparkConf.set("spark.yarn.security.credentials.hive.enabled", "false") - delegationTokenManager = new HadoopDelegationTokenManager( - sparkConf, - hadoopConf, - hadoopFSsToAccess) - - delegationTokenManager.getServiceDelegationTokenProvider("hadoopfs") should be (None) - delegationTokenManager.getServiceDelegationTokenProvider("hive") should be (None) - delegationTokenManager.getServiceDelegationTokenProvider("hbase") should not be (None) - } - - test("verify no credentials are obtained") { - delegationTokenManager = new HadoopDelegationTokenManager( - sparkConf, - hadoopConf, - hadoopFSsToAccess) - val creds = new Credentials() - - // Tokens cannot be obtained from HDFS, Hive, HBase in unit tests. - delegationTokenManager.obtainDelegationTokens(hadoopConf, creds) - val tokens = creds.getAllTokens - tokens.size() should be (0) - } - - test("obtain tokens For HiveMetastore") { - val hadoopConf = new Configuration() - hadoopConf.set("hive.metastore.kerberos.principal", "bob") - // thrift picks up on port 0 and bails out, without trying to talk to endpoint - hadoopConf.set("hive.metastore.uris", "http://localhost:0") - - val hiveCredentialProvider = new HiveDelegationTokenProvider() - val credentials = new Credentials() - hiveCredentialProvider.obtainDelegationTokens(hadoopConf, sparkConf, credentials) - - credentials.getAllTokens.size() should be (0) - } - - test("Obtain tokens For HBase") { - val hadoopConf = new Configuration() - hadoopConf.set("hbase.security.authentication", "kerberos") - - val hbaseTokenProvider = new HBaseDelegationTokenProvider() - val creds = new Credentials() - hbaseTokenProvider.obtainDelegationTokens(hadoopConf, sparkConf, creds) - - creds.getAllTokens.size should be (0) + val sparkConf = new SparkConf(false) + .set("spark.yarn.security.tokens.hadoopfs.enabled", "false") + .set("spark.yarn.security.credentials.hive.enabled", "false") + val manager = new HadoopDelegationTokenManager(sparkConf, hadoopConf) + assert(!manager.isProviderLoaded("hadoopfs")) + assert(manager.isProviderLoaded("hbase")) + assert(!manager.isProviderLoaded("hive")) } test("SPARK-23209: obtain tokens when Hive classes are not available") { @@ -123,43 +65,41 @@ class HadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers { throw new ClassNotFoundException(name) } - if (name.startsWith("java") || name.startsWith("scala")) { - currentLoader.loadClass(name) - } else { - val classFileName = name.replaceAll("\\.", "/") + ".class" - val in = currentLoader.getResourceAsStream(classFileName) - if (in != null) { - val bytes = IOUtils.toByteArray(in) - defineClass(name, bytes, 0, bytes.length) - } else { - throw new ClassNotFoundException(name) - } + val prefixBlacklist = Seq("java", "scala", "com.sun.", "sun.") + if (prefixBlacklist.exists(name.startsWith(_))) { + return currentLoader.loadClass(name) } + + val found = findLoadedClass(name) + if (found != null) { + return found + } + + val classFileName = name.replaceAll("\\.", "/") + ".class" + val in = currentLoader.getResourceAsStream(classFileName) + if (in != null) { + val bytes = IOUtils.toByteArray(in) + return defineClass(name, bytes, 0, bytes.length) + } + + throw new ClassNotFoundException(name) } } - try { - Thread.currentThread().setContextClassLoader(noHive) + Utils.withContextClassLoader(noHive) { val test = noHive.loadClass(NoHiveTest.getClass.getName().stripSuffix("$")) test.getMethod("runTest").invoke(null) - } finally { - Thread.currentThread().setContextClassLoader(currentLoader) } } - - private[spark] def hadoopFSsToAccess(hadoopConf: Configuration): Set[FileSystem] = { - Set(FileSystem.get(hadoopConf)) - } } /** Test code for SPARK-23209 to avoid using too much reflection above. */ -private object NoHiveTest extends Matchers { +private object NoHiveTest { def runTest(): Unit = { try { - val manager = new HadoopDelegationTokenManager(new SparkConf(), new Configuration(), - _ => Set()) - manager.getServiceDelegationTokenProvider("hive") should be (None) + val manager = new HadoopDelegationTokenManager(new SparkConf(), new Configuration()) + require(!manager.isProviderLoaded("hive")) } catch { case e: Throwable => // Throw a better exception in case the test fails, since there may be a lot of nesting. diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index 3e30ab2c8353e..066547dcbb408 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -27,7 +27,6 @@ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.security.KubernetesHadoopDelegationTokenManager import org.apache.spark.deploy.k8s.submit._ import org.apache.spark.deploy.k8s.submit.KubernetesClientApplication._ -import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.config.ConfigEntry @@ -79,7 +78,7 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( def krbConfigMapName: String = s"$appResourceNamePrefix-krb5-file" def tokenManager(conf: SparkConf, hConf: Configuration): KubernetesHadoopDelegationTokenManager = - new KubernetesHadoopDelegationTokenManager(new HadoopDelegationTokenManager(conf, hConf)) + new KubernetesHadoopDelegationTokenManager(conf, hConf) def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/HadoopKerberosLogin.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/HadoopKerberosLogin.scala index 67a58491e442e..0022d8f242a72 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/HadoopKerberosLogin.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/HadoopKerberosLogin.scala @@ -38,16 +38,14 @@ private[spark] object HadoopKerberosLogin { submissionSparkConf: SparkConf, kubernetesResourceNamePrefix: String, tokenManager: KubernetesHadoopDelegationTokenManager): KerberosConfigSpec = { - val hadoopConf = SparkHadoopUtil.get.newConfiguration(submissionSparkConf) // The JobUserUGI will be taken fom the Local Ticket Cache or via keytab+principal // The login happens in the SparkSubmit so login logic is not necessary to include val jobUserUGI = tokenManager.getCurrentUser val originalCredentials = jobUserUGI.getCredentials - val (tokenData, renewalInterval) = tokenManager.getDelegationTokens( - originalCredentials, - submissionSparkConf, - hadoopConf) - require(tokenData.nonEmpty, "Did not obtain any delegation tokens") + tokenManager.obtainDelegationTokens(originalCredentials) + + val tokenData = SparkHadoopUtil.get.serialize(originalCredentials) + val initialTokenDataKeyName = KERBEROS_SECRET_KEY val newSecretName = s"$kubernetesResourceNamePrefix-$KERBEROS_DELEGEGATION_TOKEN_SECRET_NAME" val secretDT = diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/security/KubernetesHadoopDelegationTokenManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/security/KubernetesHadoopDelegationTokenManager.scala index 135e2c482bbbc..3e98d5811d83f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/security/KubernetesHadoopDelegationTokenManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/security/KubernetesHadoopDelegationTokenManager.scala @@ -18,45 +18,20 @@ package org.apache.spark.deploy.k8s.security import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.SparkConf -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.security.HadoopDelegationTokenManager -import org.apache.spark.internal.Logging /** - * The KubernetesHadoopDelegationTokenManager fetches Hadoop delegation tokens - * on the behalf of the Kubernetes submission client. The new credentials - * (called Tokens when they are serialized) are stored in Secrets accessible - * to the driver and executors, when new Tokens are received they overwrite the current Secrets. + * Adds Kubernetes-specific functionality to HadoopDelegationTokenManager. */ private[spark] class KubernetesHadoopDelegationTokenManager( - tokenManager: HadoopDelegationTokenManager) extends Logging { + _sparkConf: SparkConf, + _hadoopConf: Configuration) + extends HadoopDelegationTokenManager(_sparkConf, _hadoopConf) { - // HadoopUGI Util methods def getCurrentUser: UserGroupInformation = UserGroupInformation.getCurrentUser - def getShortUserName: String = getCurrentUser.getShortUserName - def getFileSystem(hadoopConf: Configuration): FileSystem = FileSystem.get(hadoopConf) def isSecurityEnabled: Boolean = UserGroupInformation.isSecurityEnabled - def loginUserFromKeytabAndReturnUGI(principal: String, keytab: String): UserGroupInformation = - UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) - def serializeCreds(creds: Credentials): Array[Byte] = SparkHadoopUtil.get.serialize(creds) - def nextRT(rt: Long, conf: SparkConf): Long = SparkHadoopUtil.nextCredentialRenewalTime(rt, conf) - def getDelegationTokens( - creds: Credentials, - conf: SparkConf, - hadoopConf: Configuration): (Array[Byte], Long) = { - try { - val rt = tokenManager.obtainDelegationTokens(hadoopConf, creds) - logDebug(s"Initialized tokens") - (serializeCreds(creds), nextRT(rt, conf)) - } catch { - case e: Exception => - logError(s"Failed to fetch Hadoop delegation tokens $e") - throw e - } - } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index bac0246b7ddc5..f5866651dc90b 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -26,12 +26,12 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.concurrent.Future -import org.apache.hadoop.security.UserGroupInformation import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.mesos.SchedulerDriver import org.apache.spark.{SecurityManager, SparkConf, SparkContext, SparkException, TaskState} import org.apache.spark.deploy.mesos.config._ +import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.config import org.apache.spark.internal.config.EXECUTOR_HEARTBEAT_INTERVAL import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} @@ -60,9 +60,6 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with org.apache.mesos.Scheduler with MesosSchedulerUtils { - private lazy val hadoopDelegationTokenManager: MesosHadoopDelegationTokenManager = - new MesosHadoopDelegationTokenManager(conf, sc.hadoopConfiguration, driverEndpoint) - // Blacklist a slave after this many failures private val MAX_SLAVE_FAILURES = 2 @@ -678,7 +675,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( launcherBackend.close() } - private def stopSchedulerBackend() { + private def stopSchedulerBackend(): Unit = { // Make sure we're not launching tasks during shutdown stateLock.synchronized { if (stopCalled) { @@ -777,6 +774,10 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } } + override protected def createTokenManager(): Option[HadoopDelegationTokenManager] = { + Some(new HadoopDelegationTokenManager(conf, sc.hadoopConfiguration)) + } + private def numExecutors(): Int = { slaves.values.map(_.taskIDs.size).sum } @@ -789,14 +790,6 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( offer.getHostname } } - - override def fetchHadoopDelegationTokens(): Option[Array[Byte]] = { - if (UserGroupInformation.isSecurityEnabled) { - Some(hadoopDelegationTokenManager.getTokens()) - } else { - None - } - } } private class Slave(val hostname: String) { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala deleted file mode 100644 index a1bf4f0c048fe..0000000000000 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.cluster.mesos - -import java.security.PrivilegedExceptionAction -import java.util.concurrent.{ScheduledExecutorService, TimeUnit} - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.security.UserGroupInformation - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.security.HadoopDelegationTokenManager -import org.apache.spark.internal.{config, Logging} -import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.UpdateDelegationTokens -import org.apache.spark.ui.UIUtils -import org.apache.spark.util.ThreadUtils - - -/** - * The MesosHadoopDelegationTokenManager fetches and updates Hadoop delegation tokens on the behalf - * of the MesosCoarseGrainedSchedulerBackend. It is modeled after the YARN AMCredentialRenewer, - * and similarly will renew the Credentials when 75% of the renewal interval has passed. - * The principal difference is that instead of writing the new credentials to HDFS and - * incrementing the timestamp of the file, the new credentials (called Tokens when they are - * serialized) are broadcast to all running executors. On the executor side, when new Tokens are - * received they overwrite the current credentials. - */ -private[spark] class MesosHadoopDelegationTokenManager( - conf: SparkConf, - hadoopConfig: Configuration, - driverEndpoint: RpcEndpointRef) - extends Logging { - - require(driverEndpoint != null, "DriverEndpoint is not initialized") - - private val credentialRenewerThread: ScheduledExecutorService = - ThreadUtils.newDaemonSingleThreadScheduledExecutor("Credential Renewal Thread") - - private val tokenManager: HadoopDelegationTokenManager = - new HadoopDelegationTokenManager(conf, hadoopConfig) - - private val principal: String = conf.get(config.PRINCIPAL).orNull - - private var (tokens: Array[Byte], timeOfNextRenewal: Long) = { - try { - val creds = UserGroupInformation.getCurrentUser.getCredentials - val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - val rt = tokenManager.obtainDelegationTokens(hadoopConf, creds) - logInfo(s"Initialized tokens: ${SparkHadoopUtil.get.dumpTokens(creds)}") - (SparkHadoopUtil.get.serialize(creds), SparkHadoopUtil.nextCredentialRenewalTime(rt, conf)) - } catch { - case e: Exception => - logError(s"Failed to fetch Hadoop delegation tokens $e") - throw e - } - } - - private val keytabFile: Option[String] = conf.get(config.KEYTAB) - - scheduleTokenRenewal() - - private def scheduleTokenRenewal(): Unit = { - if (keytabFile.isDefined) { - require(principal != null, "Principal is required for Keytab-based authentication") - logInfo(s"Using keytab: ${keytabFile.get} and principal $principal") - } else { - logInfo("Using ticket cache for Kerberos authentication, no token renewal.") - return - } - - def scheduleRenewal(runnable: Runnable): Unit = { - val remainingTime = timeOfNextRenewal - System.currentTimeMillis() - if (remainingTime <= 0) { - logInfo("Credentials have expired, creating new ones now.") - runnable.run() - } else { - logInfo(s"Scheduling login from keytab in $remainingTime millis.") - credentialRenewerThread.schedule(runnable, remainingTime, TimeUnit.MILLISECONDS) - } - } - - val credentialRenewerRunnable = - new Runnable { - override def run(): Unit = { - try { - getNewDelegationTokens() - broadcastDelegationTokens(tokens) - } catch { - case e: Exception => - // Log the error and try to write new tokens back in an hour - val delay = TimeUnit.SECONDS.toMillis(conf.get(config.CREDENTIALS_RENEWAL_RETRY_WAIT)) - logWarning( - s"Couldn't broadcast tokens, trying again in ${UIUtils.formatDuration(delay)}", e) - credentialRenewerThread.schedule(this, delay, TimeUnit.MILLISECONDS) - return - } - scheduleRenewal(this) - } - } - scheduleRenewal(credentialRenewerRunnable) - } - - private def getNewDelegationTokens(): Unit = { - logInfo(s"Attempting to login to KDC with principal ${principal}") - // Get new delegation tokens by logging in with a new UGI inspired by AMCredentialRenewer.scala - // Don't protect against keytabFile being empty because it's guarded above. - val ugi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytabFile.get) - logInfo("Successfully logged into KDC") - val tempCreds = ugi.getCredentials - val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - val nextRenewalTime = ugi.doAs(new PrivilegedExceptionAction[Long] { - override def run(): Long = { - tokenManager.obtainDelegationTokens(hadoopConf, tempCreds) - } - }) - - val currTime = System.currentTimeMillis() - timeOfNextRenewal = if (nextRenewalTime <= currTime) { - logWarning(s"Next credential renewal time ($nextRenewalTime) is earlier than " + - s"current time ($currTime), which is unexpected, please check your credential renewal " + - "related configurations in the target services.") - currTime - } else { - SparkHadoopUtil.nextCredentialRenewalTime(nextRenewalTime, conf) - } - logInfo(s"Time of next renewal is in ${timeOfNextRenewal - System.currentTimeMillis()} ms") - - // Add the temp credentials back to the original ones. - UserGroupInformation.getCurrentUser.addCredentials(tempCreds) - // update tokens for late or dynamically added executors - tokens = SparkHadoopUtil.get.serialize(tempCreds) - } - - private def broadcastDelegationTokens(tokens: Array[Byte]) = { - logInfo("Sending new tokens to all executors") - driverEndpoint.send(UpdateDelegationTokens(tokens)) - } - - def getTokens(): Array[Byte] = { - tokens - } -} - diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 8f94e3f731007..c1f3211bcab29 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -41,7 +41,7 @@ import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.AMCredentialRenewer +import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.metrics.MetricsSystem @@ -99,20 +99,18 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } } - private val credentialRenewer: Option[AMCredentialRenewer] = sparkConf.get(KEYTAB).map { _ => - new AMCredentialRenewer(sparkConf, yarnConf) + private val tokenManager: Option[YARNHadoopDelegationTokenManager] = { + sparkConf.get(KEYTAB).map { _ => + new YARNHadoopDelegationTokenManager(sparkConf, yarnConf) + } } - private val ugi = credentialRenewer match { - case Some(cr) => + private val ugi = tokenManager match { + case Some(tm) => // Set the context class loader so that the token renewer has access to jars distributed // by the user. - val currentLoader = Thread.currentThread().getContextClassLoader() - Thread.currentThread().setContextClassLoader(userClassLoader) - try { - cr.start() - } finally { - Thread.currentThread().setContextClassLoader(currentLoader) + Utils.withContextClassLoader(userClassLoader) { + tm.start() } case _ => @@ -380,7 +378,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends userClassThread.interrupt() } if (!inShutdown) { - credentialRenewer.foreach(_.stop()) + tokenManager.foreach(_.stop()) } } } @@ -440,7 +438,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends securityMgr, localResources) - credentialRenewer.foreach(_.setDriverRef(driverRef)) + tokenManager.foreach(_.setDriverRef(driverRef)) // Initialize the AM endpoint *after* the allocator has been initialized. This ensures // that when the driver sends an initial executor request (e.g. after an AM restart), diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 49b7f6261e8db..6240f7b68d2c8 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -310,7 +310,7 @@ private[spark] class Client( private def setupSecurityToken(amContainer: ContainerLaunchContext): Unit = { val credentials = UserGroupInformation.getCurrentUser().getCredentials() val credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf) - credentialManager.obtainDelegationTokens(hadoopConf, credentials) + credentialManager.obtainDelegationTokens(credentials) // When using a proxy user, copy the delegation tokens to the user's credentials. Avoid // that for regular users, since in those case the user already has access to the TGT, diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index f2ed555edc1df..b257d8fdd3b1a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -325,10 +325,6 @@ package object config { .stringConf .createOptional - private[spark] val KERBEROS_RELOGIN_PERIOD = ConfigBuilder("spark.yarn.kerberos.relogin.period") - .timeConf(TimeUnit.SECONDS) - .createWithDefaultString("1m") - // The list of cache-related config entries. This is used by Client and the AM to clean // up the environment so that these settings do not appear on the web UI. private[yarn] val CACHE_CONFIGS = Seq( diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala deleted file mode 100644 index bc8d47dbd54c6..0000000000000 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala +++ /dev/null @@ -1,177 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.yarn.security - -import java.security.PrivilegedExceptionAction -import java.util.concurrent.{ScheduledExecutorService, TimeUnit} -import java.util.concurrent.atomic.AtomicReference - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.security.{Credentials, UserGroupInformation} - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.internal.Logging -import org.apache.spark.internal.config._ -import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.UpdateDelegationTokens -import org.apache.spark.ui.UIUtils -import org.apache.spark.util.ThreadUtils - -/** - * A manager tasked with periodically updating delegation tokens needed by the application. - * - * This manager is meant to make sure long-running apps (such as Spark Streaming apps) can run - * without interruption while accessing secured services. It periodically logs in to the KDC with - * user-provided credentials, and contacts all the configured secure services to obtain delegation - * tokens to be distributed to the rest of the application. - * - * This class will manage the kerberos login, by renewing the TGT when needed. Because the UGI API - * does not expose the TTL of the TGT, a configuration controls how often to check that a relogin is - * necessary. This is done reasonably often since the check is a no-op when the relogin is not yet - * needed. The check period can be overridden in the configuration. - * - * New delegation tokens are created once 75% of the renewal interval of the original tokens has - * elapsed. The new tokens are sent to the Spark driver endpoint once it's registered with the AM. - * The driver is tasked with distributing the tokens to other processes that might need them. - */ -private[yarn] class AMCredentialRenewer( - sparkConf: SparkConf, - hadoopConf: Configuration) extends Logging { - - private val principal = sparkConf.get(PRINCIPAL).get - private val keytab = sparkConf.get(KEYTAB).get - private val credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf) - - private val renewalExecutor: ScheduledExecutorService = - ThreadUtils.newDaemonSingleThreadScheduledExecutor("Credential Refresh Thread") - - private val driverRef = new AtomicReference[RpcEndpointRef]() - - private val renewalTask = new Runnable() { - override def run(): Unit = { - updateTokensTask() - } - } - - def setDriverRef(ref: RpcEndpointRef): Unit = { - driverRef.set(ref) - } - - /** - * Start the token renewer. Upon start, the renewer will: - * - * - log in the configured user, and set up a task to keep that user's ticket renewed - * - obtain delegation tokens from all available providers - * - schedule a periodic task to update the tokens when needed. - * - * @return The newly logged in user. - */ - def start(): UserGroupInformation = { - val originalCreds = UserGroupInformation.getCurrentUser().getCredentials() - val ugi = doLogin() - - val tgtRenewalTask = new Runnable() { - override def run(): Unit = { - ugi.checkTGTAndReloginFromKeytab() - } - } - val tgtRenewalPeriod = sparkConf.get(KERBEROS_RELOGIN_PERIOD) - renewalExecutor.scheduleAtFixedRate(tgtRenewalTask, tgtRenewalPeriod, tgtRenewalPeriod, - TimeUnit.SECONDS) - - val creds = obtainTokensAndScheduleRenewal(ugi) - ugi.addCredentials(creds) - - // Transfer the original user's tokens to the new user, since that's needed to connect to - // YARN. Explicitly avoid overwriting tokens that already exist in the current user's - // credentials, since those were freshly obtained above (see SPARK-23361). - val existing = ugi.getCredentials() - existing.mergeAll(originalCreds) - ugi.addCredentials(existing) - - ugi - } - - def stop(): Unit = { - renewalExecutor.shutdown() - } - - private def scheduleRenewal(delay: Long): Unit = { - val _delay = math.max(0, delay) - logInfo(s"Scheduling login from keytab in ${UIUtils.formatDuration(delay)}.") - renewalExecutor.schedule(renewalTask, _delay, TimeUnit.MILLISECONDS) - } - - /** - * Periodic task to login to the KDC and create new delegation tokens. Re-schedules itself - * to fetch the next set of tokens when needed. - */ - private def updateTokensTask(): Unit = { - try { - val freshUGI = doLogin() - val creds = obtainTokensAndScheduleRenewal(freshUGI) - val tokens = SparkHadoopUtil.get.serialize(creds) - - val driver = driverRef.get() - if (driver != null) { - logInfo("Updating delegation tokens.") - driver.send(UpdateDelegationTokens(tokens)) - } else { - // This shouldn't really happen, since the driver should register way before tokens expire - // (or the AM should time out the application). - logWarning("Delegation tokens close to expiration but no driver has registered yet.") - SparkHadoopUtil.get.addDelegationTokens(tokens, sparkConf) - } - } catch { - case e: Exception => - val delay = TimeUnit.SECONDS.toMillis(sparkConf.get(CREDENTIALS_RENEWAL_RETRY_WAIT)) - logWarning(s"Failed to update tokens, will try again in ${UIUtils.formatDuration(delay)}!" + - " If this happens too often tasks will fail.", e) - scheduleRenewal(delay) - } - } - - /** - * Obtain new delegation tokens from the available providers. Schedules a new task to fetch - * new tokens before the new set expires. - * - * @return Credentials containing the new tokens. - */ - private def obtainTokensAndScheduleRenewal(ugi: UserGroupInformation): Credentials = { - ugi.doAs(new PrivilegedExceptionAction[Credentials]() { - override def run(): Credentials = { - val creds = new Credentials() - val nextRenewal = credentialManager.obtainDelegationTokens(hadoopConf, creds) - - val timeToWait = SparkHadoopUtil.nextCredentialRenewalTime(nextRenewal, sparkConf) - - System.currentTimeMillis() - scheduleRenewal(timeToWait) - creds - } - }) - } - - private def doLogin(): UserGroupInformation = { - logInfo(s"Attempting to login to KDC using principal: $principal") - val ugi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) - logInfo("Successfully logged into KDC.") - ugi - } - -} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala index 26a2e5d730218..2d9a3f0c83fd2 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala @@ -22,12 +22,13 @@ import java.util.ServiceLoader import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.security.Credentials import org.apache.spark.SparkConf import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil -import org.apache.spark.internal.Logging +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils /** @@ -36,27 +37,25 @@ import org.apache.spark.util.Utils * in [[HadoopDelegationTokenManager]]. */ private[yarn] class YARNHadoopDelegationTokenManager( - sparkConf: SparkConf, - hadoopConf: Configuration) extends Logging { + _sparkConf: SparkConf, + _hadoopConf: Configuration) + extends HadoopDelegationTokenManager(_sparkConf, _hadoopConf) { - private val delegationTokenManager = new HadoopDelegationTokenManager(sparkConf, hadoopConf, - conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) - - // public for testing - val credentialProviders = getCredentialProviders + private val credentialProviders = { + ServiceLoader.load(classOf[ServiceCredentialProvider], Utils.getContextOrSparkClassLoader) + .asScala + .toList + .filter { p => isServiceEnabled(p.serviceName) } + .map { p => (p.serviceName, p) } + .toMap + } if (credentialProviders.nonEmpty) { logDebug("Using the following YARN-specific credential providers: " + s"${credentialProviders.keys.mkString(", ")}.") } - /** - * Writes delegation tokens to creds. Delegation tokens are fetched from all registered - * providers. - * - * @return Time after which the fetched delegation tokens should be renewed. - */ - def obtainDelegationTokens(hadoopConf: Configuration, creds: Credentials): Long = { - val superInterval = delegationTokenManager.obtainDelegationTokens(hadoopConf, creds) + override def obtainDelegationTokens(creds: Credentials): Long = { + val superInterval = super.obtainDelegationTokens(creds) credentialProviders.values.flatMap { provider => if (provider.credentialsRequired(hadoopConf)) { @@ -69,18 +68,13 @@ private[yarn] class YARNHadoopDelegationTokenManager( }.foldLeft(superInterval)(math.min) } - private def getCredentialProviders: Map[String, ServiceCredentialProvider] = { - val providers = loadCredentialProviders - - providers. - filter { p => delegationTokenManager.isServiceEnabled(p.serviceName) } - .map { p => (p.serviceName, p) } - .toMap + // For testing. + override def isProviderLoaded(serviceName: String): Boolean = { + credentialProviders.contains(serviceName) || super.isProviderLoaded(serviceName) } - private def loadCredentialProviders: List[ServiceCredentialProvider] = { - ServiceLoader.load(classOf[ServiceCredentialProvider], Utils.getContextOrSparkClassLoader) - .asScala - .toList + override protected def fileSystemsToAccess(): Set[FileSystem] = { + YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, hadoopConf) } + } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 63bea3e7a5003..67c36aac49266 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -19,16 +19,14 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.atomic.{AtomicBoolean} -import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Future import scala.util.{Failure, Success} import scala.util.control.NonFatal -import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} import org.apache.spark.SparkContext -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.rpc._ import org.apache.spark.scheduler._ @@ -270,7 +268,6 @@ private[spark] abstract class YarnSchedulerBackend( case u @ UpdateDelegationTokens(tokens) => // Add the tokens to the current user and send a message to the scheduler so that it // notifies all registered executors of the new tokens. - SparkHadoopUtil.get.addDelegationTokens(tokens, sc.conf) driverEndpoint.send(u) } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala index 9fa749b14c98c..98315e4235741 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.deploy.yarn.security import org.apache.hadoop.conf.Configuration import org.apache.hadoop.security.Credentials -import org.scalatest.Matchers import org.apache.spark.{SparkConf, SparkFunSuite} -class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers { +class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite { private var credentialManager: YARNHadoopDelegationTokenManager = null private var sparkConf: SparkConf = null private var hadoopConf: Configuration = null @@ -36,7 +35,7 @@ class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers test("Correctly loads credential providers") { credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf) - credentialManager.credentialProviders.get("yarn-test") should not be (None) + assert(credentialManager.isProviderLoaded("yarn-test")) } } From bc9f9b4d6e6ac983a903a0b9a3a668950dc0b2a7 Mon Sep 17 00:00:00 2001 From: Anton Okolnychyi Date: Wed, 31 Oct 2018 18:35:33 +0000 Subject: [PATCH 1972/2461] [SPARK-25860][SQL] Replace Literal(null, _) with FalseLiteral whenever possible MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR proposes a new optimization rule that replaces `Literal(null, _)` with `FalseLiteral` in conditions in `Join` and `Filter`, predicates in `If`, conditions in `CaseWhen`. The idea is that some expressions evaluate to `false` if the underlying expression is `null` (as an example see `GeneratePredicate$create` or `doGenCode` and `eval` methods in `If` and `CaseWhen`). Therefore, we can replace `Literal(null, _)` with `FalseLiteral`, which can lead to more optimizations later on. Let’s consider a few examples. ``` val df = spark.range(1, 100).select($"id".as("l"), ($"id" > 50).as("b")) df.createOrReplaceTempView("t") df.createOrReplaceTempView("p") ``` **Case 1** ``` spark.sql("SELECT * FROM t WHERE if(l > 10, false, NULL)").explain(true) // without the new rule … == Optimized Logical Plan == Project [id#0L AS l#2L, cast(id#0L as string) AS s#3] +- Filter if ((id#0L > 10)) false else null +- Range (1, 100, step=1, splits=Some(12)) == Physical Plan == *(1) Project [id#0L AS l#2L, cast(id#0L as string) AS s#3] +- *(1) Filter if ((id#0L > 10)) false else null +- *(1) Range (1, 100, step=1, splits=12) // with the new rule … == Optimized Logical Plan == LocalRelation , [l#2L, s#3] == Physical Plan == LocalTableScan , [l#2L, s#3] ``` **Case 2** ``` spark.sql("SELECT * FROM t WHERE CASE WHEN l < 10 THEN null WHEN l > 40 THEN false ELSE null END”).explain(true) // without the new rule ... == Optimized Logical Plan == Project [id#0L AS l#2L, cast(id#0L as string) AS s#3] +- Filter CASE WHEN (id#0L < 10) THEN null WHEN (id#0L > 40) THEN false ELSE null END +- Range (1, 100, step=1, splits=Some(12)) == Physical Plan == *(1) Project [id#0L AS l#2L, cast(id#0L as string) AS s#3] +- *(1) Filter CASE WHEN (id#0L < 10) THEN null WHEN (id#0L > 40) THEN false ELSE null END +- *(1) Range (1, 100, step=1, splits=12) // with the new rule ... == Optimized Logical Plan == LocalRelation , [l#2L, s#3] == Physical Plan == LocalTableScan , [l#2L, s#3] ``` **Case 3** ``` spark.sql("SELECT * FROM t JOIN p ON IF(t.l > p.l, null, false)").explain(true) // without the new rule ... == Optimized Logical Plan == Join Inner, if ((l#2L > l#37L)) null else false :- Project [id#0L AS l#2L, cast(id#0L as string) AS s#3] : +- Range (1, 100, step=1, splits=Some(12)) +- Project [id#0L AS l#37L, cast(id#0L as string) AS s#38] +- Range (1, 100, step=1, splits=Some(12)) == Physical Plan == BroadcastNestedLoopJoin BuildRight, Inner, if ((l#2L > l#37L)) null else false :- *(1) Project [id#0L AS l#2L, cast(id#0L as string) AS s#3] : +- *(1) Range (1, 100, step=1, splits=12) +- BroadcastExchange IdentityBroadcastMode +- *(2) Project [id#0L AS l#37L, cast(id#0L as string) AS s#38] +- *(2) Range (1, 100, step=1, splits=12) // with the new rule ... == Optimized Logical Plan == LocalRelation , [l#2L, s#3, l#37L, s#38] ``` ## How was this patch tested? This PR comes with a set of dedicated tests. Closes #22857 from aokolnychyi/spark-25860. Authored-by: Anton Okolnychyi Signed-off-by: DB Tsai --- .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../sql/catalyst/optimizer/expressions.scala | 57 ++++ .../optimizer/ReplaceNullWithFalseSuite.scala | 323 ++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 4 +- .../ReplaceNullWithFalseEndToEndSuite.scala | 71 ++++ 5 files changed, 454 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 95455ffc0495a..a330a84a3a24f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -84,6 +84,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) SimplifyConditionals, RemoveDispensableExpressions, SimplifyBinaryComparison, + ReplaceNullWithFalse, PruneFilters, EliminateSorts, SimplifyCasts, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 468a950fb1087..2b29b49d00ab9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -736,3 +736,60 @@ object CombineConcats extends Rule[LogicalPlan] { flattenConcats(concat) } } + +/** + * A rule that replaces `Literal(null, _)` with `FalseLiteral` for further optimizations. + * + * This rule applies to conditions in [[Filter]] and [[Join]]. Moreover, it transforms predicates + * in all [[If]] expressions as well as branch conditions in all [[CaseWhen]] expressions. + * + * For example, `Filter(Literal(null, _))` is equal to `Filter(FalseLiteral)`. + * + * Another example containing branches is `Filter(If(cond, FalseLiteral, Literal(null, _)))`; + * this can be optimized to `Filter(If(cond, FalseLiteral, FalseLiteral))`, and eventually + * `Filter(FalseLiteral)`. + * + * As this rule is not limited to conditions in [[Filter]] and [[Join]], arbitrary plans can + * benefit from it. For example, `Project(If(And(cond, Literal(null)), Literal(1), Literal(2)))` + * can be simplified into `Project(Literal(2))`. + * + * As a result, many unnecessary computations can be removed in the query optimization phase. + */ +object ReplaceNullWithFalse extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) + case j @ Join(_, _, _, Some(cond)) => j.copy(condition = Some(replaceNullWithFalse(cond))) + case p: LogicalPlan => p transformExpressions { + case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred)) + case cw @ CaseWhen(branches, _) => + val newBranches = branches.map { case (cond, value) => + replaceNullWithFalse(cond) -> value + } + cw.copy(branches = newBranches) + } + } + + /** + * Recursively replaces `Literal(null, _)` with `FalseLiteral`. + * + * Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit + * an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or `Literal(null, _)`. + */ + private def replaceNullWithFalse(e: Expression): Expression = e match { + case cw: CaseWhen if cw.dataType == BooleanType => + val newBranches = cw.branches.map { case (cond, value) => + replaceNullWithFalse(cond) -> replaceNullWithFalse(value) + } + val newElseValue = cw.elseValue.map(replaceNullWithFalse) + CaseWhen(newBranches, newElseValue) + case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType => + If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal)) + case And(left, right) => + And(replaceNullWithFalse(left), replaceNullWithFalse(right)) + case Or(left, right) => + Or(replaceNullWithFalse(left), replaceNullWithFalse(right)) + case Literal(null, _) => FalseLiteral + case _ => e + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala new file mode 100644 index 0000000000000..c6b5d0ec96776 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala @@ -0,0 +1,323 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Expression, GreaterThan, If, Literal, Or} +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.{BooleanType, IntegerType} + +class ReplaceNullWithFalseSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Replace null literals", FixedPoint(10), + NullPropagation, + ConstantFolding, + BooleanSimplification, + SimplifyConditionals, + ReplaceNullWithFalse) :: Nil + } + + private val testRelation = LocalRelation('i.int, 'b.boolean) + private val anotherTestRelation = LocalRelation('d.int) + + test("replace null inside filter and join conditions") { + testFilter(originalCond = Literal(null), expectedCond = FalseLiteral) + testJoin(originalCond = Literal(null), expectedCond = FalseLiteral) + } + + test("replace null in branches of If") { + val originalCond = If( + UnresolvedAttribute("i") > Literal(10), + FalseLiteral, + Literal(null, BooleanType)) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("replace nulls in nested expressions in branches of If") { + val originalCond = If( + UnresolvedAttribute("i") > Literal(10), + TrueLiteral && Literal(null, BooleanType), + UnresolvedAttribute("b") && Literal(null, BooleanType)) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("replace null in elseValue of CaseWhen") { + val branches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral, + (UnresolvedAttribute("i") > Literal(40)) -> FalseLiteral) + val originalCond = CaseWhen(branches, Literal(null, BooleanType)) + val expectedCond = CaseWhen(branches, FalseLiteral) + testFilter(originalCond, expectedCond) + testJoin(originalCond, expectedCond) + } + + test("replace null in branch values of CaseWhen") { + val branches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> Literal(null, BooleanType), + (UnresolvedAttribute("i") > Literal(40)) -> FalseLiteral) + val originalCond = CaseWhen(branches, Literal(null)) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("replace null in branches of If inside CaseWhen") { + val originalBranches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> + If(UnresolvedAttribute("i") < Literal(20), Literal(null, BooleanType), FalseLiteral), + (UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral) + val originalCond = CaseWhen(originalBranches) + + val expectedBranches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral, + (UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral) + val expectedCond = CaseWhen(expectedBranches) + + testFilter(originalCond, expectedCond) + testJoin(originalCond, expectedCond) + } + + test("replace null in complex CaseWhen expressions") { + val originalBranches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral, + (Literal(6) <= Literal(1)) -> FalseLiteral, + (Literal(4) === Literal(5)) -> FalseLiteral, + (UnresolvedAttribute("i") > Literal(10)) -> Literal(null, BooleanType), + (Literal(4) === Literal(4)) -> TrueLiteral) + val originalCond = CaseWhen(originalBranches) + + val expectedBranches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral, + (UnresolvedAttribute("i") > Literal(10)) -> FalseLiteral, + TrueLiteral -> TrueLiteral) + val expectedCond = CaseWhen(expectedBranches) + + testFilter(originalCond, expectedCond) + testJoin(originalCond, expectedCond) + } + + test("replace null in Or") { + val originalCond = Or(UnresolvedAttribute("b"), Literal(null)) + val expectedCond = UnresolvedAttribute("b") + testFilter(originalCond, expectedCond) + testJoin(originalCond, expectedCond) + } + + test("replace null in And") { + val originalCond = And(UnresolvedAttribute("b"), Literal(null)) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("replace nulls in nested And/Or expressions") { + val originalCond = And( + And(UnresolvedAttribute("b"), Literal(null)), + Or(Literal(null), And(Literal(null), And(UnresolvedAttribute("b"), Literal(null))))) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("replace null in And inside branches of If") { + val originalCond = If( + UnresolvedAttribute("i") > Literal(10), + FalseLiteral, + And(UnresolvedAttribute("b"), Literal(null, BooleanType))) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("replace null in branches of If inside And") { + val originalCond = And( + UnresolvedAttribute("b"), + If( + UnresolvedAttribute("i") > Literal(10), + Literal(null), + And(FalseLiteral, UnresolvedAttribute("b")))) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("replace null in branches of If inside another If") { + val originalCond = If( + If(UnresolvedAttribute("b"), Literal(null), FalseLiteral), + TrueLiteral, + Literal(null)) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("replace null in CaseWhen inside another CaseWhen") { + val nestedCaseWhen = CaseWhen(Seq(UnresolvedAttribute("b") -> FalseLiteral), Literal(null)) + val originalCond = CaseWhen(Seq(nestedCaseWhen -> TrueLiteral), Literal(null)) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("inability to replace null in non-boolean branches of If") { + val condition = If( + UnresolvedAttribute("i") > Literal(10), + Literal(5) > If( + UnresolvedAttribute("i") === Literal(15), + Literal(null, IntegerType), + Literal(3)), + FalseLiteral) + testFilter(originalCond = condition, expectedCond = condition) + testJoin(originalCond = condition, expectedCond = condition) + } + + test("inability to replace null in non-boolean values of CaseWhen") { + val nestedCaseWhen = CaseWhen( + Seq((UnresolvedAttribute("i") > Literal(20)) -> Literal(2)), + Literal(null, IntegerType)) + val branchValue = If( + Literal(2) === nestedCaseWhen, + TrueLiteral, + FalseLiteral) + val branches = Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue) + val condition = CaseWhen(branches) + testFilter(originalCond = condition, expectedCond = condition) + testJoin(originalCond = condition, expectedCond = condition) + } + + test("inability to replace null in non-boolean branches of If inside another If") { + val condition = If( + Literal(5) > If( + UnresolvedAttribute("i") === Literal(15), + Literal(null, IntegerType), + Literal(3)), + TrueLiteral, + FalseLiteral) + testFilter(originalCond = condition, expectedCond = condition) + testJoin(originalCond = condition, expectedCond = condition) + } + + test("replace null in If used as a join condition") { + // this test is only for joins as the condition involves columns from different relations + val originalCond = If( + UnresolvedAttribute("d") > UnresolvedAttribute("i"), + Literal(null), + FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("replace null in CaseWhen used as a join condition") { + // this test is only for joins as the condition involves columns from different relations + val originalBranches = Seq( + (UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> Literal(null), + (UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral) + + val expectedBranches = Seq( + (UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> FalseLiteral, + (UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral) + + testJoin( + originalCond = CaseWhen(originalBranches, FalseLiteral), + expectedCond = CaseWhen(expectedBranches, FalseLiteral)) + } + + test("inability to replace null in CaseWhen inside EqualTo used as a join condition") { + // this test is only for joins as the condition involves columns from different relations + val branches = Seq( + (UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> Literal(null, BooleanType), + (UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral) + val condition = UnresolvedAttribute("b") === CaseWhen(branches, FalseLiteral) + testJoin(originalCond = condition, expectedCond = condition) + } + + test("replace null in predicates of If") { + val predicate = And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null)) + testProjection( + originalExpr = If(predicate, Literal(5), Literal(1)).as("out"), + expectedExpr = Literal(1).as("out")) + } + + test("replace null in predicates of If inside another If") { + val predicate = If( + And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null)), + TrueLiteral, + FalseLiteral) + testProjection( + originalExpr = If(predicate, Literal(5), Literal(1)).as("out"), + expectedExpr = Literal(1).as("out")) + } + + test("inability to replace null in non-boolean expressions inside If predicates") { + val predicate = GreaterThan( + UnresolvedAttribute("i"), + If(UnresolvedAttribute("b"), Literal(null, IntegerType), Literal(4))) + val column = If(predicate, Literal(5), Literal(1)).as("out") + testProjection(originalExpr = column, expectedExpr = column) + } + + test("replace null in conditions of CaseWhen") { + val branches = Seq( + And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null)) -> Literal(5)) + testProjection( + originalExpr = CaseWhen(branches, Literal(2)).as("out"), + expectedExpr = Literal(2).as("out")) + } + + test("replace null in conditions of CaseWhen inside another CaseWhen") { + val nestedCaseWhen = CaseWhen( + Seq(And(UnresolvedAttribute("b"), Literal(null)) -> Literal(5)), + Literal(2)) + val branches = Seq(GreaterThan(Literal(3), nestedCaseWhen) -> Literal(1)) + testProjection( + originalExpr = CaseWhen(branches).as("out"), + expectedExpr = Literal(1).as("out")) + } + + test("inability to replace null in non-boolean exprs inside CaseWhen conditions") { + val condition = GreaterThan( + UnresolvedAttribute("i"), + If(UnresolvedAttribute("b"), Literal(null, IntegerType), Literal(4))) + val column = CaseWhen(Seq(condition -> Literal(5)), Literal(2)).as("out") + testProjection(originalExpr = column, expectedExpr = column) + } + + private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = { + test((rel, exp) => rel.where(exp), originalCond, expectedCond) + } + + private def testJoin(originalCond: Expression, expectedCond: Expression): Unit = { + test((rel, exp) => rel.join(anotherTestRelation, Inner, Some(exp)), originalCond, expectedCond) + } + + private def testProjection(originalExpr: Expression, expectedExpr: Expression): Unit = { + test((rel, exp) => rel.select(exp), originalExpr, expectedExpr) + } + + private def test( + func: (LogicalPlan, Expression) => LogicalPlan, + originalExpr: Expression, + expectedExpr: Expression): Unit = { + + val originalPlan = func(testRelation, originalExpr).analyze + val optimizedPlan = Optimize.execute(originalPlan) + val expectedPlan = func(testRelation, expectedExpr).analyze + comparePlans(optimizedPlan, expectedPlan) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index a430884581dad..4afae56ecdb76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -31,14 +31,14 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Uuid import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation -import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union} import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} -import org.apache.spark.sql.test.SQLTestData.{NullInts, NullStrings, TestData2} +import org.apache.spark.sql.test.SQLTestData.{NullStrings, TestData2} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala new file mode 100644 index 0000000000000..fc6ecc4e032f6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.{CaseWhen, If} +import org.apache.spark.sql.execution.LocalTableScanExec +import org.apache.spark.sql.functions.{lit, when} +import org.apache.spark.sql.test.SharedSQLContext + +class ReplaceNullWithFalseEndToEndSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("SPARK-25860: Replace Literal(null, _) with FalseLiteral whenever possible") { + withTable("t1", "t2") { + Seq((1, true), (2, false)).toDF("l", "b").write.saveAsTable("t1") + Seq(2, 3).toDF("l").write.saveAsTable("t2") + val df1 = spark.table("t1") + val df2 = spark.table("t2") + + val q1 = df1.where("IF(l > 10, false, b AND null)") + checkAnswer(q1, Seq.empty) + checkPlanIsEmptyLocalScan(q1) + + val q2 = df1.where("CASE WHEN l < 10 THEN null WHEN l > 40 THEN false ELSE null END") + checkAnswer(q2, Seq.empty) + checkPlanIsEmptyLocalScan(q2) + + val q3 = df1.join(df2, when(df1("l") > df2("l"), lit(null)).otherwise(df1("b") && lit(null))) + checkAnswer(q3, Seq.empty) + checkPlanIsEmptyLocalScan(q3) + + val q4 = df1.where("IF(IF(b, null, false), true, null)") + checkAnswer(q4, Seq.empty) + checkPlanIsEmptyLocalScan(q4) + + val q5 = df1.selectExpr("IF(l > 1 AND null, 5, 1) AS out") + checkAnswer(q5, Row(1) :: Row(1) :: Nil) + q5.queryExecution.executedPlan.foreach { p => + assert(p.expressions.forall(e => e.find(_.isInstanceOf[If]).isEmpty)) + } + + val q6 = df1.selectExpr("CASE WHEN (l > 2 AND null) THEN 3 ELSE 2 END") + checkAnswer(q6, Row(2) :: Row(2) :: Nil) + q6.queryExecution.executedPlan.foreach { p => + assert(p.expressions.forall(e => e.find(_.isInstanceOf[CaseWhen]).isEmpty)) + } + + checkAnswer(df1.where("IF(l > 10, false, b OR null)"), Row(1, true)) + } + + def checkPlanIsEmptyLocalScan(df: DataFrame): Unit = df.queryExecution.executedPlan match { + case s: LocalTableScanExec => assert(s.rows.isEmpty) + case p => fail(s"$p is not LocalTableScanExec") + } + } +} From 6be3cce751fd0abf00d668c771f56093f2fa6817 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 31 Oct 2018 15:14:10 -0700 Subject: [PATCH 1973/2461] [SPARK-25899][TESTS] Fix flaky CoarseGrainedSchedulerBackendSuite ## What changes were proposed in this pull request? I saw CoarseGrainedSchedulerBackendSuite failed in my PR and finally reproduced the following error on a very busy machine: ``` sbt.ForkMain$ForkError: org.scalatest.exceptions.TestFailedDueToTimeoutException: The code passed to eventually never returned normally. Attempted 400 times over 10.009828643999999 seconds. Last failure message: ArrayBuffer("2", "0", "3") had length 3 instead of expected length 4. ``` The logs in this test shows executor 1 was not up when the test failed. ``` 18/10/30 11:34:03.563 dispatcher-event-loop-12 INFO CoarseGrainedSchedulerBackend$DriverEndpoint: Registered executor NettyRpcEndpointRef(spark-client://Executor) (172.17.0.2:43656) with ID 2 18/10/30 11:34:03.593 dispatcher-event-loop-3 INFO CoarseGrainedSchedulerBackend$DriverEndpoint: Registered executor NettyRpcEndpointRef(spark-client://Executor) (172.17.0.2:43658) with ID 3 18/10/30 11:34:03.629 dispatcher-event-loop-6 INFO CoarseGrainedSchedulerBackend$DriverEndpoint: Registered executor NettyRpcEndpointRef(spark-client://Executor) (172.17.0.2:43654) with ID 0 18/10/30 11:34:03.885 pool-1-thread-1-ScalaTest-running-CoarseGrainedSchedulerBackendSuite INFO CoarseGrainedSchedulerBackendSuite: ===== FINISHED o.a.s.scheduler.CoarseGrainedSchedulerBackendSuite: 'compute max number of concurrent tasks can be launched' ===== ``` And the following logs in executor 1 shows it was still doing the initialization when the timeout happened (at 18/10/30 11:34:03.885). ``` 18/10/30 11:34:03.463 netty-rpc-connection-0 INFO TransportClientFactory: Successfully created connection to 54b6b6217301/172.17.0.2:33741 after 37 ms (0 ms spent in bootstraps) 18/10/30 11:34:03.959 main INFO DiskBlockManager: Created local directory at /home/jenkins/workspace/core/target/tmp/spark-383518bc-53bd-4d9c-885b-d881f03875bf/executor-61c406e4-178f-40a6-ac2c-7314ee6fb142/blockmgr-03fb84a1-eedc-4055-8743-682eb3ac5c67 18/10/30 11:34:03.993 main INFO MemoryStore: MemoryStore started with capacity 546.3 MB ``` Hence, I think our current 10 seconds is not enough on a slow Jenkins machine. This PR just increases the timeout from 10 seconds to 60 seconds to make the test more stable. ## How was this patch tested? Jenkins Closes #22910 from zsxwing/fix-flaky-test. Authored-by: Shixiong Zhu Signed-off-by: gatorsmile --- .../scheduler/CoarseGrainedSchedulerBackendSuite.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index 80c9c6f0422a8..c5a39669366ce 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -30,6 +30,8 @@ import org.apache.spark.util.{RpcUtils, SerializableBuffer} class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext with Eventually { + private val executorUpTimeout = 60.seconds + test("serialized task larger than max RPC message size") { val conf = new SparkConf conf.set("spark.rpc.message.maxSize", "1") @@ -51,7 +53,7 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo .setMaster("local-cluster[4, 3, 1024]") .setAppName("test") sc = new SparkContext(conf) - eventually(timeout(10.seconds)) { + eventually(timeout(executorUpTimeout)) { // Ensure all executors have been launched. assert(sc.getExecutorIds().length == 4) } @@ -64,7 +66,7 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo .setMaster("local-cluster[4, 3, 1024]") .setAppName("test") sc = new SparkContext(conf) - eventually(timeout(10.seconds)) { + eventually(timeout(executorUpTimeout)) { // Ensure all executors have been launched. assert(sc.getExecutorIds().length == 4) } @@ -96,7 +98,7 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo try { sc.addSparkListener(listener) - eventually(timeout(10.seconds)) { + eventually(timeout(executorUpTimeout)) { // Ensure all executors have been launched. assert(sc.getExecutorIds().length == 4) } From c5ef477d2f6b0a6351ab3332c8647d4b89b705b0 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 31 Oct 2018 16:25:19 -0700 Subject: [PATCH 1974/2461] [INFRA] Close stale PR. Closes #22860 From c9667aff4f4888b650fad2ed41698025b1e84166 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 1 Nov 2018 09:14:16 +0800 Subject: [PATCH 1975/2461] [SPARK-25672][SQL] schema_of_csv() - schema inference from an example ## What changes were proposed in this pull request? In the PR, I propose to add new function - *schema_of_csv()* which infers schema of CSV string literal. The result of the function is a string containing a schema in DDL format. For example: ```sql select schema_of_csv('1|abc', map('delimiter', '|')) ``` ``` struct<_c0:int,_c1:string> ``` ## How was this patch tested? Added new tests to `CsvFunctionsSuite`, `CsvExpressionsSuite` and SQL tests to `csv-functions.sql` Closes #22666 from MaxGekk/schema_of_csv-function. Lead-authored-by: hyukjinkwon Co-authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- python/pyspark/sql/functions.py | 41 +++++++++++--- .../catalyst/analysis/FunctionRegistry.scala | 3 +- .../sql/catalyst}/csv/CSVInferSchema.scala | 30 ++++++----- .../sql/catalyst/expressions/ExprUtils.scala | 33 ++++++++++-- .../catalyst/expressions/csvExpressions.scala | 54 +++++++++++++++++++ .../expressions/jsonExpressions.scala | 16 +----- .../catalyst}/csv/CSVInferSchemaSuite.scala | 3 +- .../catalyst}/csv/UnivocityParserSuite.scala | 3 +- .../expressions/CsvExpressionsSuite.scala | 10 ++++ .../datasources/csv/CSVDataSource.scala | 2 +- .../org/apache/spark/sql/functions.scala | 35 ++++++++++++ .../sql-tests/inputs/csv-functions.sql | 8 +++ .../sql-tests/results/csv-functions.sql.out | 54 ++++++++++++++++++- .../apache/spark/sql/CsvFunctionsSuite.scala | 15 ++++++ 14 files changed, 262 insertions(+), 45 deletions(-) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/datasources => catalyst/src/main/scala/org/apache/spark/sql/catalyst}/csv/CSVInferSchema.scala (92%) rename sql/{core/src/test/scala/org/apache/spark/sql/execution/datasources => catalyst/src/test/scala/org/apache/spark/sql/catalyst}/csv/CSVInferSchemaSuite.scala (98%) rename sql/{core/src/test/scala/org/apache/spark/sql/execution/datasources => catalyst/src/test/scala/org/apache/spark/sql/catalyst}/csv/UnivocityParserSuite.scala (98%) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ca2a256983d67..beb1a065d2803 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2364,6 +2364,33 @@ def schema_of_json(json, options={}): return Column(jc) +@ignore_unicode_prefix +@since(3.0) +def schema_of_csv(csv, options={}): + """ + Parses a CSV string and infers its schema in DDL format. + + :param col: a CSV string or a string literal containing a CSV string. + :param options: options to control parsing. accepts the same options as the CSV datasource + + >>> df = spark.range(1) + >>> df.select(schema_of_csv(lit('1|a'), {'sep':'|'}).alias("csv")).collect() + [Row(csv=u'struct<_c0:int,_c1:string>')] + >>> df.select(schema_of_csv('1|a', {'sep':'|'}).alias("csv")).collect() + [Row(csv=u'struct<_c0:int,_c1:string>')] + """ + if isinstance(csv, basestring): + col = _create_column_from_literal(csv) + elif isinstance(csv, Column): + col = _to_java_column(csv) + else: + raise TypeError("schema argument should be a column or string") + + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.schema_of_csv(col, options) + return Column(jc) + + @since(1.5) def size(col): """ @@ -2664,13 +2691,13 @@ def from_csv(col, schema, options={}): :param schema: a string with schema in DDL format to use when parsing the CSV column. :param options: options to control parsing. accepts the same options as the CSV datasource - >>> data = [(1, '1')] - >>> df = spark.createDataFrame(data, ("key", "value")) - >>> df.select(from_csv(df.value, "a INT").alias("csv")).collect() - [Row(csv=Row(a=1))] - >>> df = spark.createDataFrame(data, ("key", "value")) - >>> df.select(from_csv(df.value, lit("a INT")).alias("csv")).collect() - [Row(csv=Row(a=1))] + >>> data = [("1,2,3",)] + >>> df = spark.createDataFrame(data, ("value",)) + >>> df.select(from_csv(df.value, "a INT, b INT, c INT").alias("csv")).collect() + [Row(csv=Row(a=1, b=2, c=3))] + >>> value = data[0][0] + >>> df.select(from_csv(df.value, schema_of_csv(value)).alias("csv")).collect() + [Row(csv=Row(_c0=1, _c1=2, _c2=3))] """ sc = SparkContext._active_spark_context diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index af6166bcb8692..cf8fb7eea9580 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -526,7 +526,8 @@ object FunctionRegistry { castAlias("string", StringType), // csv - expression[CsvToStructs]("from_csv") + expression[CsvToStructs]("from_csv"), + expression[SchemaOfCsv]("schema_of_csv") ) val builtin: SimpleFunctionRegistry = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala similarity index 92% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index 4326a186d6d5f..799e9994451b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -15,19 +15,18 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.csv +package org.apache.spark.sql.catalyst.csv import java.math.BigDecimal -import scala.util.control.Exception._ +import scala.util.control.Exception.allCatch import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion -import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -private[csv] object CSVInferSchema { +object CSVInferSchema { /** * Similar to the JSON schema inference @@ -44,13 +43,7 @@ private[csv] object CSVInferSchema { val rootTypes: Array[DataType] = tokenRDD.aggregate(startType)(inferRowType(options), mergeRowTypes) - header.zip(rootTypes).map { case (thisHeader, rootType) => - val dType = rootType match { - case _: NullType => StringType - case other => other - } - StructField(thisHeader, dType, nullable = true) - } + toStructFields(rootTypes, header, options) } else { // By default fields are assumed to be StringType header.map(fieldName => StructField(fieldName, StringType, nullable = true)) @@ -59,7 +52,20 @@ private[csv] object CSVInferSchema { StructType(fields) } - private def inferRowType(options: CSVOptions) + def toStructFields( + fieldTypes: Array[DataType], + header: Array[String], + options: CSVOptions): Array[StructField] = { + header.zip(fieldTypes).map { case (thisHeader, rootType) => + val dType = rootType match { + case _: NullType => StringType + case other => other + } + StructField(thisHeader, dType, nullable = true) + } + } + + def inferRowType(options: CSVOptions) (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = { var i = 0 while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index e5708894f22b4..040b56cc1caea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -19,14 +19,39 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.ArrayBasedMapData -import org.apache.spark.sql.types.{MapType, StringType, StructType} +import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType} +import org.apache.spark.unsafe.types.UTF8String object ExprUtils { - def evalSchemaExpr(exp: Expression): StructType = exp match { - case Literal(s, StringType) => StructType.fromDDL(s.toString) + def evalSchemaExpr(exp: Expression): StructType = { + // Use `DataType.fromDDL` since the type string can be struct<...>. + val dataType = exp match { + case Literal(s, StringType) => + DataType.fromDDL(s.toString) + case e @ SchemaOfCsv(_: Literal, _) => + val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String] + DataType.fromDDL(ddlSchema.toString) + case e => throw new AnalysisException( + "Schema should be specified in DDL format as a string literal or output of " + + s"the schema_of_csv function instead of ${e.sql}") + } + + if (!dataType.isInstanceOf[StructType]) { + throw new AnalysisException( + s"Schema should be struct type but got ${dataType.sql}.") + } + dataType.asInstanceOf[StructType] + } + + def evalTypeExpr(exp: Expression): DataType = exp match { + case Literal(s, StringType) => DataType.fromDDL(s.toString) + case e @ SchemaOfJson(_: Literal, _) => + val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String] + DataType.fromDDL(ddlSchema.toString) case e => throw new AnalysisException( - s"Schema should be specified in DDL format as a string literal instead of ${e.sql}") + "Schema should be specified in DDL format as a string literal or output of " + + s"the schema_of_json function instead of ${e.sql}") } def convertToMapData(exp: Expression): Map[String, String] = exp match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index 853b1ea6a5f1c..e70296fe31292 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.catalyst.expressions +import com.univocity.parsers.csv.CsvParser + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.csv._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util._ @@ -120,3 +123,54 @@ case class CsvToStructs( override def prettyName: String = "from_csv" } + +/** + * A function infers schema of CSV string. + */ +@ExpressionDescription( + usage = "_FUNC_(csv[, options]) - Returns schema in the DDL format of CSV string.", + examples = """ + Examples: + > SELECT _FUNC_('1,abc'); + struct<_c0:int,_c1:string> + """, + since = "3.0.0") +case class SchemaOfCsv( + child: Expression, + options: Map[String, String]) + extends UnaryExpression with CodegenFallback { + + def this(child: Expression) = this(child, Map.empty[String, String]) + + def this(child: Expression, options: Expression) = this( + child = child, + options = ExprUtils.convertToMapData(options)) + + override def dataType: DataType = StringType + + override def nullable: Boolean = false + + @transient + private lazy val csv = child.eval().asInstanceOf[UTF8String] + + override def checkInputDataTypes(): TypeCheckResult = child match { + case Literal(s, StringType) if s != null => super.checkInputDataTypes() + case _ => TypeCheckResult.TypeCheckFailure( + s"The input csv should be a string literal and not null; however, got ${child.sql}.") + } + + override def eval(v: InternalRow): Any = { + val parsedOptions = new CSVOptions(options, true, "UTC") + val parser = new CsvParser(parsedOptions.asParserSettings) + val row = parser.parseLine(csv.toString) + assert(row != null, "Parsed CSV record should not be null.") + + val header = row.zipWithIndex.map { case (_, index) => s"_c$index" } + val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) + val fieldTypes = CSVInferSchema.inferRowType(parsedOptions)(startType, row) + val st = StructType(CSVInferSchema.toStructFields(fieldTypes, header, parsedOptions)) + UTF8String.fromString(st.catalogString) + } + + override def prettyName: String = "schema_of_csv" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 77af5906010f3..eafcb6161036e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -529,7 +529,7 @@ case class JsonToStructs( // Used in `FunctionRegistry` def this(child: Expression, schema: Expression, options: Map[String, String]) = this( - schema = JsonExprUtils.evalSchemaExpr(schema), + schema = ExprUtils.evalTypeExpr(schema), options = options, child = child, timeZoneId = None) @@ -538,7 +538,7 @@ case class JsonToStructs( def this(child: Expression, schema: Expression, options: Expression) = this( - schema = JsonExprUtils.evalSchemaExpr(schema), + schema = ExprUtils.evalTypeExpr(schema), options = ExprUtils.convertToMapData(options), child = child, timeZoneId = None) @@ -784,15 +784,3 @@ case class SchemaOfJson( override def prettyName: String = "schema_of_json" } - -object JsonExprUtils { - def evalSchemaExpr(exp: Expression): DataType = exp match { - case Literal(s, StringType) => DataType.fromDDL(s.toString) - case e @ SchemaOfJson(_: Literal, _) => - val ddlSchema = e.eval(EmptyRow).asInstanceOf[UTF8String] - DataType.fromDDL(ddlSchema.toString) - case e => throw new AnalysisException( - "Schema should be specified in DDL format as a string literal" + - s" or output of the schema_of_json function instead of ${e.sql}") - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala index 6b64f2ffa98dd..651846d2ebcb5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala @@ -15,10 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.csv +package org.apache.spark.sql.catalyst.csv import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.types._ class CSVInferSchemaSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala index 6f231142949d1..e4e7dc2e8c0e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala @@ -15,12 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.csv +package org.apache.spark.sql.catalyst.csv import java.math.BigDecimal import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.csv.{CSVOptions, UnivocityParser} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala index 65987af710750..386e0d133dff6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -155,4 +155,14 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P }.getCause assert(exception.getMessage.contains("from_csv() doesn't support the DROPMALFORMED mode")) } + + test("infer schema of CSV strings") { + checkEvaluation(new SchemaOfCsv(Literal.create("1,abc")), "struct<_c0:int,_c1:string>") + } + + test("infer schema of CSV strings by using options") { + checkEvaluation( + new SchemaOfCsv(Literal.create("1|abc"), Map("delimiter" -> "|")), + "struct<_c0:int,_c1:string>") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 9e7b45db9f280..4808e8ef042d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -34,7 +34,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} +import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVInferSchema, CSVOptions, UnivocityParser} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5348b65d43b38..f8c4d88cb1f7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3870,6 +3870,41 @@ object functions { withExpr(new CsvToStructs(e.expr, schema.expr, options.asScala.toMap)) } + /** + * Parses a CSV string and infers its schema in DDL format. + * + * @param csv a CSV string. + * + * @group collection_funcs + * @since 3.0.0 + */ + def schema_of_csv(csv: String): Column = schema_of_csv(lit(csv)) + + /** + * Parses a CSV string and infers its schema in DDL format. + * + * @param csv a string literal containing a CSV string. + * + * @group collection_funcs + * @since 3.0.0 + */ + def schema_of_csv(csv: Column): Column = withExpr(new SchemaOfCsv(csv.expr)) + + /** + * Parses a CSV string and infers its schema in DDL format using options. + * + * @param csv a string literal containing a CSV string. + * @param options options to control how the CSV is parsed. accepts the same options and the + * json data source. See [[DataFrameReader#csv]]. + * @return a column with string literal containing schema in DDL format. + * + * @group collection_funcs + * @since 3.0.0 + */ + def schema_of_csv(csv: Column, options: java.util.Map[String, String]): Column = { + withExpr(SchemaOfCsv(csv.expr, options.asScala.toMap)) + } + // scalastyle:off line.size.limit // scalastyle:off parameter.number diff --git a/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql index d2214fd016028..5be6f807931b8 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql @@ -7,3 +7,11 @@ select from_csv('1', 'a InvalidType'); select from_csv('1', 'a INT', named_struct('mode', 'PERMISSIVE')); select from_csv('1', 'a INT', map('mode', 1)); select from_csv(); +-- infer schema of json literal +select from_csv('1,abc', schema_of_csv('1,abc')); +select schema_of_csv('1|abc', map('delimiter', '|')); +select schema_of_csv(null); +CREATE TEMPORARY VIEW csvTable(csvField, a) AS SELECT * FROM VALUES ('1,abc', 'a'); +SELECT schema_of_csv(csvField) FROM csvTable; +-- Clean up +DROP VIEW IF EXISTS csvTable; diff --git a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out index f19f34a773c16..677bbd97c549d 100644 --- a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 7 +-- Number of queries: 13 -- !query 0 @@ -24,7 +24,7 @@ select from_csv('1', 1) struct<> -- !query 2 output org.apache.spark.sql.AnalysisException -Schema should be specified in DDL format as a string literal instead of 1;; line 1 pos 7 +Schema should be specified in DDL format as a string literal or output of the schema_of_csv function instead of 1;; line 1 pos 7 -- !query 3 @@ -67,3 +67,53 @@ struct<> -- !query 6 output org.apache.spark.sql.AnalysisException Invalid number of arguments for function from_csv. Expected: one of 2 and 3; Found: 0; line 1 pos 7 + + +-- !query 7 +select from_csv('1,abc', schema_of_csv('1,abc')) +-- !query 7 schema +struct> +-- !query 7 output +{"_c0":1,"_c1":"abc"} + + +-- !query 8 +select schema_of_csv('1|abc', map('delimiter', '|')) +-- !query 8 schema +struct +-- !query 8 output +struct<_c0:int,_c1:string> + + +-- !query 9 +select schema_of_csv(null) +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +cannot resolve 'schema_of_csv(NULL)' due to data type mismatch: The input csv should be a string literal and not null; however, got NULL.; line 1 pos 7 + + +-- !query 10 +CREATE TEMPORARY VIEW csvTable(csvField, a) AS SELECT * FROM VALUES ('1,abc', 'a') +-- !query 10 schema +struct<> +-- !query 10 output + + + +-- !query 11 +SELECT schema_of_csv(csvField) FROM csvTable +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +cannot resolve 'schema_of_csv(csvtable.`csvField`)' due to data type mismatch: The input csv should be a string literal and not null; however, got csvtable.`csvField`.; line 1 pos 7 + + +-- !query 12 +DROP VIEW IF EXISTS csvTable +-- !query 12 schema +struct<> +-- !query 12 output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index 38a2143d6d0f0..9395f050b41ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -59,4 +59,19 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext { Row(Row(null, null, "0,2013-111-11 12:13:14")), Row(Row(1, java.sql.Date.valueOf("1983-08-04"), null)))) } + + test("schema_of_csv - infers schemas") { + checkAnswer( + spark.range(1).select(schema_of_csv(lit("0.1,1"))), + Seq(Row("struct<_c0:double,_c1:int>"))) + checkAnswer( + spark.range(1).select(schema_of_csv("0.1,1")), + Seq(Row("struct<_c0:double,_c1:int>"))) + } + + test("schema_of_csv - infers schemas using options") { + val df = spark.range(1) + .select(schema_of_csv(lit("0.1 1"), Map("sep" -> " ").asJava)) + checkAnswer(df, Seq(Row("struct<_c0:double,_c1:int>"))) + } } From cc82b9fed857503448c9ae6bd74ee2fe1ba9ba0b Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 1 Nov 2018 10:00:14 +0800 Subject: [PATCH 1976/2461] [SPARK-25884][SQL] Add TBLPROPERTIES and COMMENT, and use LOCATION when SHOW CREATE TABLE. ## What changes were proposed in this pull request? When `SHOW CREATE TABLE` for Datasource tables, we are missing `TBLPROPERTIES` and `COMMENT`, and we should use `LOCATION` instead of path in `OPTION`. ## How was this patch tested? Splitted `ShowCreateTableSuite` to confirm to work with both `InMemoryCatalog` and `HiveExternalCatalog`, and added some tests. Closes #22892 from ueshin/issues/SPARK-25884/show_create_table. Authored-by: Takuya UESHIN Signed-off-by: Wenchen Fan --- .../spark/sql/execution/command/tables.scala | 36 +-- .../src/test/resources/sample.json | 0 .../sql-tests/inputs/show-create-table.sql | 61 +++++ .../results/show-create-table.sql.out | 222 ++++++++++++++++++ .../spark/sql}/ShowCreateTableSuite.scala | 201 +++------------- .../spark/sql/hive/HiveExternalCatalog.scala | 5 +- .../sql/hive/HiveShowCreateTableSuite.scala | 198 ++++++++++++++++ 7 files changed, 538 insertions(+), 185 deletions(-) rename sql/{hive => core}/src/test/resources/sample.json (100%) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/show-create-table.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/show-create-table.sql.out rename sql/{hive/src/test/scala/org/apache/spark/sql/hive => core/src/test/scala/org/apache/spark/sql}/ShowCreateTableSuite.scala (56%) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShowCreateTableSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 64831e5089a67..871eba49dfbd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -957,9 +957,11 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman builder ++= metadata.viewText.mkString(" AS\n", "", "\n") } else { showHiveTableHeader(metadata, builder) + showTableComment(metadata, builder) showHiveTableNonDataColumns(metadata, builder) showHiveTableStorageInfo(metadata, builder) - showHiveTableProperties(metadata, builder) + showTableLocation(metadata, builder) + showTableProperties(metadata, builder) } builder.toString() @@ -973,14 +975,8 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman if (columns.nonEmpty) { builder ++= columns.mkString("(", ", ", ")\n") } - - metadata - .comment - .map("COMMENT '" + escapeSingleQuotedString(_) + "'\n") - .foreach(builder.append) } - private def showHiveTableNonDataColumns(metadata: CatalogTable, builder: StringBuilder): Unit = { if (metadata.partitionColumnNames.nonEmpty) { val partCols = metadata.partitionSchema.map(_.toDDL) @@ -1023,15 +1019,24 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman builder ++= s" OUTPUTFORMAT '${escapeSingleQuotedString(format)}'\n" } } + } + private def showTableLocation(metadata: CatalogTable, builder: StringBuilder): Unit = { if (metadata.tableType == EXTERNAL) { - storage.locationUri.foreach { uri => - builder ++= s"LOCATION '$uri'\n" + metadata.storage.locationUri.foreach { location => + builder ++= s"LOCATION '${escapeSingleQuotedString(CatalogUtils.URIToString(location))}'\n" } } } - private def showHiveTableProperties(metadata: CatalogTable, builder: StringBuilder): Unit = { + private def showTableComment(metadata: CatalogTable, builder: StringBuilder): Unit = { + metadata + .comment + .map("COMMENT '" + escapeSingleQuotedString(_) + "'\n") + .foreach(builder.append) + } + + private def showTableProperties(metadata: CatalogTable, builder: StringBuilder): Unit = { if (metadata.properties.nonEmpty) { val props = metadata.properties.map { case (key, value) => s"'${escapeSingleQuotedString(key)}' = '${escapeSingleQuotedString(value)}'" @@ -1048,6 +1053,9 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman showDataSourceTableDataColumns(metadata, builder) showDataSourceTableOptions(metadata, builder) showDataSourceTableNonDataColumns(metadata, builder) + showTableComment(metadata, builder) + showTableLocation(metadata, builder) + showTableProperties(metadata, builder) builder.toString() } @@ -1063,14 +1071,6 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman val dataSourceOptions = metadata.storage.properties.map { case (key, value) => s"${quoteIdentifier(key)} '${escapeSingleQuotedString(value)}'" - } ++ metadata.storage.locationUri.flatMap { location => - if (metadata.tableType == MANAGED) { - // If it's a managed table, omit PATH option. Spark SQL always creates external table - // when the table creation DDL contains the PATH option. - None - } else { - Some(s"path '${escapeSingleQuotedString(CatalogUtils.URIToString(location))}'") - } } if (dataSourceOptions.nonEmpty) { diff --git a/sql/hive/src/test/resources/sample.json b/sql/core/src/test/resources/sample.json similarity index 100% rename from sql/hive/src/test/resources/sample.json rename to sql/core/src/test/resources/sample.json diff --git a/sql/core/src/test/resources/sql-tests/inputs/show-create-table.sql b/sql/core/src/test/resources/sql-tests/inputs/show-create-table.sql new file mode 100644 index 0000000000000..852bfbd63847d --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/show-create-table.sql @@ -0,0 +1,61 @@ +-- simple +CREATE TABLE tbl (a INT, b STRING, c INT) USING parquet; + +SHOW CREATE TABLE tbl; +DROP TABLE tbl; + + +-- options +CREATE TABLE tbl (a INT, b STRING, c INT) USING parquet +OPTIONS ('a' 1); + +SHOW CREATE TABLE tbl; +DROP TABLE tbl; + + +-- path option +CREATE TABLE tbl (a INT, b STRING, c INT) USING parquet +OPTIONS ('path' '/path/to/table'); + +SHOW CREATE TABLE tbl; +DROP TABLE tbl; + + +-- location +CREATE TABLE tbl (a INT, b STRING, c INT) USING parquet +LOCATION '/path/to/table'; + +SHOW CREATE TABLE tbl; +DROP TABLE tbl; + + +-- partition by +CREATE TABLE tbl (a INT, b STRING, c INT) USING parquet +PARTITIONED BY (a); + +SHOW CREATE TABLE tbl; +DROP TABLE tbl; + + +-- clustered by +CREATE TABLE tbl (a INT, b STRING, c INT) USING parquet +CLUSTERED BY (a) SORTED BY (b ASC) INTO 2 BUCKETS; + +SHOW CREATE TABLE tbl; +DROP TABLE tbl; + + +-- comment +CREATE TABLE tbl (a INT, b STRING, c INT) USING parquet +COMMENT 'This is a comment'; + +SHOW CREATE TABLE tbl; +DROP TABLE tbl; + + +-- tblproperties +CREATE TABLE tbl (a INT, b STRING, c INT) USING parquet +TBLPROPERTIES ('a' = '1'); + +SHOW CREATE TABLE tbl; +DROP TABLE tbl; diff --git a/sql/core/src/test/resources/sql-tests/results/show-create-table.sql.out b/sql/core/src/test/resources/sql-tests/results/show-create-table.sql.out new file mode 100644 index 0000000000000..1faf16cc30509 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/show-create-table.sql.out @@ -0,0 +1,222 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 24 + + +-- !query 0 +CREATE TABLE tbl (a INT, b STRING, c INT) USING parquet +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SHOW CREATE TABLE tbl +-- !query 1 schema +struct +-- !query 1 output +CREATE TABLE `tbl` (`a` INT, `b` STRING, `c` INT) +USING parquet + + +-- !query 2 +DROP TABLE tbl +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +CREATE TABLE tbl (a INT, b STRING, c INT) USING parquet +OPTIONS ('a' 1) +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +SHOW CREATE TABLE tbl +-- !query 4 schema +struct +-- !query 4 output +CREATE TABLE `tbl` (`a` INT, `b` STRING, `c` INT) +USING parquet +OPTIONS ( + `a` '1' +) + + +-- !query 5 +DROP TABLE tbl +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +CREATE TABLE tbl (a INT, b STRING, c INT) USING parquet +OPTIONS ('path' '/path/to/table') +-- !query 6 schema +struct<> +-- !query 6 output + + + +-- !query 7 +SHOW CREATE TABLE tbl +-- !query 7 schema +struct +-- !query 7 output +CREATE TABLE `tbl` (`a` INT, `b` STRING, `c` INT) +USING parquet +LOCATION 'file:/path/to/table' + + +-- !query 8 +DROP TABLE tbl +-- !query 8 schema +struct<> +-- !query 8 output + + + +-- !query 9 +CREATE TABLE tbl (a INT, b STRING, c INT) USING parquet +LOCATION '/path/to/table' +-- !query 9 schema +struct<> +-- !query 9 output + + + +-- !query 10 +SHOW CREATE TABLE tbl +-- !query 10 schema +struct +-- !query 10 output +CREATE TABLE `tbl` (`a` INT, `b` STRING, `c` INT) +USING parquet +LOCATION 'file:/path/to/table' + + +-- !query 11 +DROP TABLE tbl +-- !query 11 schema +struct<> +-- !query 11 output + + + +-- !query 12 +CREATE TABLE tbl (a INT, b STRING, c INT) USING parquet +PARTITIONED BY (a) +-- !query 12 schema +struct<> +-- !query 12 output + + + +-- !query 13 +SHOW CREATE TABLE tbl +-- !query 13 schema +struct +-- !query 13 output +CREATE TABLE `tbl` (`b` STRING, `c` INT, `a` INT) +USING parquet +PARTITIONED BY (a) + + +-- !query 14 +DROP TABLE tbl +-- !query 14 schema +struct<> +-- !query 14 output + + + +-- !query 15 +CREATE TABLE tbl (a INT, b STRING, c INT) USING parquet +CLUSTERED BY (a) SORTED BY (b ASC) INTO 2 BUCKETS +-- !query 15 schema +struct<> +-- !query 15 output + + + +-- !query 16 +SHOW CREATE TABLE tbl +-- !query 16 schema +struct +-- !query 16 output +CREATE TABLE `tbl` (`a` INT, `b` STRING, `c` INT) +USING parquet +CLUSTERED BY (a) +SORTED BY (b) +INTO 2 BUCKETS + + +-- !query 17 +DROP TABLE tbl +-- !query 17 schema +struct<> +-- !query 17 output + + + +-- !query 18 +CREATE TABLE tbl (a INT, b STRING, c INT) USING parquet +COMMENT 'This is a comment' +-- !query 18 schema +struct<> +-- !query 18 output + + + +-- !query 19 +SHOW CREATE TABLE tbl +-- !query 19 schema +struct +-- !query 19 output +CREATE TABLE `tbl` (`a` INT, `b` STRING, `c` INT) +USING parquet +COMMENT 'This is a comment' + + +-- !query 20 +DROP TABLE tbl +-- !query 20 schema +struct<> +-- !query 20 output + + + +-- !query 21 +CREATE TABLE tbl (a INT, b STRING, c INT) USING parquet +TBLPROPERTIES ('a' = '1') +-- !query 21 schema +struct<> +-- !query 21 output + + + +-- !query 22 +SHOW CREATE TABLE tbl +-- !query 22 schema +struct +-- !query 22 output +CREATE TABLE `tbl` (`a` INT, `b` STRING, `c` INT) +USING parquet +TBLPROPERTIES ( + 'a' = '1' +) + + +-- !query 23 +DROP TABLE tbl +-- !query 23 schema +struct<> +-- !query 23 output + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ShowCreateTableSuite.scala similarity index 56% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/ShowCreateTableSuite.scala index 34ca790299859..5c347d2677d5e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ShowCreateTableSuite.scala @@ -15,16 +15,16 @@ * limitations under the License. */ -package org.apache.spark.sql.hive +package org.apache.spark.sql -import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.util.Utils -class ShowCreateTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { +class SimpleShowCreateTableSuite extends ShowCreateTableSuite with SharedSQLContext + +abstract class ShowCreateTableSuite extends QueryTest with SQLTestUtils { import testImplicits._ test("data source table with user specified schema") { @@ -105,193 +105,67 @@ class ShowCreateTableSuite extends QueryTest with SQLTestUtils with TestHiveSing } } - test("data source table using Dataset API") { + test("data source table with a comment") { withTable("ddl_test") { - spark - .range(3) - .select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd, 'id as 'e) - .write - .mode("overwrite") - .partitionBy("a", "b") - .bucketBy(2, "c", "d") - .saveAsTable("ddl_test") - - checkCreateTable("ddl_test") - } - } - - test("simple hive table") { - withTable("t1") { - sql( - s"""CREATE TABLE t1 ( - | c1 INT COMMENT 'bla', - | c2 STRING - |) - |TBLPROPERTIES ( - | 'prop1' = 'value1', - | 'prop2' = 'value2' - |) - """.stripMargin - ) - - checkCreateTable("t1") - } - } - - test("simple external hive table") { - withTempDir { dir => - withTable("t1") { - sql( - s"""CREATE TABLE t1 ( - | c1 INT COMMENT 'bla', - | c2 STRING - |) - |LOCATION '${dir.toURI}' - |TBLPROPERTIES ( - | 'prop1' = 'value1', - | 'prop2' = 'value2' - |) - """.stripMargin - ) - - checkCreateTable("t1") - } - } - } - - test("partitioned hive table") { - withTable("t1") { sql( - s"""CREATE TABLE t1 ( - | c1 INT COMMENT 'bla', - | c2 STRING - |) - |COMMENT 'bla' - |PARTITIONED BY ( - | p1 BIGINT COMMENT 'bla', - | p2 STRING - |) - """.stripMargin - ) - - checkCreateTable("t1") - } - } - - test("hive table with explicit storage info") { - withTable("t1") { - sql( - s"""CREATE TABLE t1 ( - | c1 INT COMMENT 'bla', - | c2 STRING - |) - |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' - |COLLECTION ITEMS TERMINATED BY '@' - |MAP KEYS TERMINATED BY '#' - |NULL DEFINED AS 'NaN' + s"""CREATE TABLE ddl_test + |USING json + |COMMENT 'This is a comment' + |AS SELECT 1 AS a, "foo" AS b, 2.5 AS c """.stripMargin ) - checkCreateTable("t1") + checkCreateTable("ddl_test") } } - test("hive table with STORED AS clause") { - withTable("t1") { + test("data source table with table properties") { + withTable("ddl_test") { sql( - s"""CREATE TABLE t1 ( - | c1 INT COMMENT 'bla', - | c2 STRING - |) - |STORED AS PARQUET + s"""CREATE TABLE ddl_test + |USING json + |TBLPROPERTIES ('a' = '1') + |AS SELECT 1 AS a, "foo" AS b, 2.5 AS c """.stripMargin ) - checkCreateTable("t1") + checkCreateTable("ddl_test") } } - test("hive table with serde info") { - withTable("t1") { - sql( - s"""CREATE TABLE t1 ( - | c1 INT COMMENT 'bla', - | c2 STRING - |) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |WITH SERDEPROPERTIES ( - | 'mapkey.delim' = ',', - | 'field.delim' = ',' - |) - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - """.stripMargin - ) + test("data source table using Dataset API") { + withTable("ddl_test") { + spark + .range(3) + .select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd, 'id as 'e) + .write + .mode("overwrite") + .partitionBy("a", "b") + .bucketBy(2, "c", "d") + .saveAsTable("ddl_test") - checkCreateTable("t1") + checkCreateTable("ddl_test") } } - test("hive view") { + test("view") { withView("v1") { sql("CREATE VIEW v1 AS SELECT 1 AS a") checkCreateView("v1") } } - test("hive view with output columns") { + test("view with output columns") { withView("v1") { sql("CREATE VIEW v1 (b) AS SELECT 1 AS a") checkCreateView("v1") } } - test("hive bucketing is supported") { - withTable("t1") { - sql( - s"""CREATE TABLE t1 (a INT, b STRING) - |CLUSTERED BY (a) - |SORTED BY (b) - |INTO 2 BUCKETS - """.stripMargin - ) - checkCreateTable("t1") - } - } - - test("hive partitioned view is not supported") { - withTable("t1") { - withView("v1") { - sql( - s""" - |CREATE TABLE t1 (c1 INT, c2 STRING) - |PARTITIONED BY ( - | p1 BIGINT COMMENT 'bla', - | p2 STRING ) - """.stripMargin) - - createRawHiveTable( - s""" - |CREATE VIEW v1 - |PARTITIONED ON (p1, p2) - |AS SELECT * from t1 - """.stripMargin - ) - - val cause = intercept[AnalysisException] { - sql("SHOW CREATE TABLE v1") - } - - assert(cause.getMessage.contains(" - partitioned view")) - } - } - } - test("SPARK-24911: keep quotes for nested fields") { withTable("t1") { - val createTable = "CREATE TABLE `t1`(`a` STRUCT<`b`: STRING>)" - sql(createTable) + val createTable = "CREATE TABLE `t1` (`a` STRUCT<`b`: STRING>)" + sql(s"$createTable USING json") val shownDDL = sql(s"SHOW CREATE TABLE t1") .head() .getString(0) @@ -303,16 +177,11 @@ class ShowCreateTableSuite extends QueryTest with SQLTestUtils with TestHiveSing } } - private def createRawHiveTable(ddl: String): Unit = { - hiveContext.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] - .client.runSqlHive(ddl) - } - - private def checkCreateTable(table: String): Unit = { + protected def checkCreateTable(table: String): Unit = { checkCreateTableOrView(TableIdentifier(table, Some("default")), "TABLE") } - private def checkCreateView(table: String): Unit = { + protected def checkCreateView(table: String): Unit = { checkCreateTableOrView(TableIdentifier(table, Some("default")), "VIEW") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 445161d5de1c2..c1178ad4a84fb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -27,6 +27,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants.DDL_TIME import org.apache.hadoop.hive.ql.metadata.HiveException import org.apache.hadoop.hive.serde.serdeConstants.SERIALIZATION_FORMAT import org.apache.thrift.TException @@ -821,7 +822,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat schema = reorderedSchema, partitionColumnNames = partColumnNames, bucketSpec = getBucketSpecFromTableProperties(table), - tracksPartitionsInCatalog = partitionProvider == Some(TABLE_PARTITION_PROVIDER_CATALOG)) + tracksPartitionsInCatalog = partitionProvider == Some(TABLE_PARTITION_PROVIDER_CATALOG), + properties = table.properties.filterKeys(!HIVE_GENERATED_TABLE_PROPERTIES(_))) } override def tableExists(db: String, table: String): Boolean = withClient { @@ -1328,6 +1330,7 @@ object HiveExternalCatalog { val CREATED_SPARK_VERSION = SPARK_SQL_PREFIX + "create.version" + val HIVE_GENERATED_TABLE_PROPERTIES = Set(DDL_TIME) val HIVE_GENERATED_STORAGE_PROPERTIES = Set(SERIALIZATION_FORMAT) // When storing data source tables in hive metastore, we need to set data schema to empty if the diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShowCreateTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShowCreateTableSuite.scala new file mode 100644 index 0000000000000..0386dc79804c6 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShowCreateTableSuite.scala @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.{AnalysisException, ShowCreateTableSuite} +import org.apache.spark.sql.hive.test.TestHiveSingleton + +class HiveShowCreateTableSuite extends ShowCreateTableSuite with TestHiveSingleton { + + test("simple hive table") { + withTable("t1") { + sql( + s"""CREATE TABLE t1 ( + | c1 INT COMMENT 'bla', + | c2 STRING + |) + |TBLPROPERTIES ( + | 'prop1' = 'value1', + | 'prop2' = 'value2' + |) + """.stripMargin + ) + + checkCreateTable("t1") + } + } + + test("simple external hive table") { + withTempDir { dir => + withTable("t1") { + sql( + s"""CREATE TABLE t1 ( + | c1 INT COMMENT 'bla', + | c2 STRING + |) + |LOCATION '${dir.toURI}' + |TBLPROPERTIES ( + | 'prop1' = 'value1', + | 'prop2' = 'value2' + |) + """.stripMargin + ) + + checkCreateTable("t1") + } + } + } + + test("partitioned hive table") { + withTable("t1") { + sql( + s"""CREATE TABLE t1 ( + | c1 INT COMMENT 'bla', + | c2 STRING + |) + |COMMENT 'bla' + |PARTITIONED BY ( + | p1 BIGINT COMMENT 'bla', + | p2 STRING + |) + """.stripMargin + ) + + checkCreateTable("t1") + } + } + + test("hive table with explicit storage info") { + withTable("t1") { + sql( + s"""CREATE TABLE t1 ( + | c1 INT COMMENT 'bla', + | c2 STRING + |) + |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' + |COLLECTION ITEMS TERMINATED BY '@' + |MAP KEYS TERMINATED BY '#' + |NULL DEFINED AS 'NaN' + """.stripMargin + ) + + checkCreateTable("t1") + } + } + + test("hive table with STORED AS clause") { + withTable("t1") { + sql( + s"""CREATE TABLE t1 ( + | c1 INT COMMENT 'bla', + | c2 STRING + |) + |STORED AS PARQUET + """.stripMargin + ) + + checkCreateTable("t1") + } + } + + test("hive table with serde info") { + withTable("t1") { + sql( + s"""CREATE TABLE t1 ( + | c1 INT COMMENT 'bla', + | c2 STRING + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |WITH SERDEPROPERTIES ( + | 'mapkey.delim' = ',', + | 'field.delim' = ',' + |) + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin + ) + + checkCreateTable("t1") + } + } + + test("hive bucketing is supported") { + withTable("t1") { + sql( + s"""CREATE TABLE t1 (a INT, b STRING) + |CLUSTERED BY (a) + |SORTED BY (b) + |INTO 2 BUCKETS + """.stripMargin + ) + checkCreateTable("t1") + } + } + + test("hive partitioned view is not supported") { + withTable("t1") { + withView("v1") { + sql( + s""" + |CREATE TABLE t1 (c1 INT, c2 STRING) + |PARTITIONED BY ( + | p1 BIGINT COMMENT 'bla', + | p2 STRING ) + """.stripMargin) + + createRawHiveTable( + s""" + |CREATE VIEW v1 + |PARTITIONED ON (p1, p2) + |AS SELECT * from t1 + """.stripMargin + ) + + val cause = intercept[AnalysisException] { + sql("SHOW CREATE TABLE v1") + } + + assert(cause.getMessage.contains(" - partitioned view")) + } + } + } + + test("SPARK-24911: keep quotes for nested fields in hive") { + withTable("t1") { + val createTable = "CREATE TABLE `t1`(`a` STRUCT<`b`: STRING>)" + sql(createTable) + val shownDDL = sql(s"SHOW CREATE TABLE t1") + .head() + .getString(0) + .split("\n") + .head + assert(shownDDL == createTable) + + checkCreateTable("t1") + } + } + + private def createRawHiveTable(ddl: String): Unit = { + hiveContext.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.runSqlHive(ddl) + } +} From cd92f25be5a221e0d4618925f7bc9dfd3bb8cb59 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 1 Nov 2018 12:47:32 +0800 Subject: [PATCH 1977/2461] [SPARK-25746][SQL][FOLLOWUP] do not add unnecessary If expression ## What changes were proposed in this pull request? a followup of https://github.com/apache/spark/pull/22749. When we construct the new serializer in `ExpressionEncoder.tuple`, we don't need to add `if(isnull ...)` check for each field. They are either simple expressions that can propagate null correctly(e.g. `GetStructField(GetColumnByOrdinal(0, schema), index)`), or complex expression that already have the isnull check. ## How was this patch tested? existing tests Closes #22898 from cloud-fan/minor. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../catalyst/encoders/ExpressionEncoder.scala | 36 ++++++++----------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 2c8e81ef17d72..592520c59a761 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -89,16 +89,11 @@ object ExpressionEncoder { */ def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { // TODO: check if encoders length is more than 22 and throw exception for it. - encoders.foreach(_.assertUnresolved()) - val schema = StructType(encoders.zipWithIndex.map { - case (e, i) => - StructField(s"_${i + 1}", e.objSerializer.dataType, e.objSerializer.nullable) - }) - val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") + val newSerializerInput = BoundReference(0, ObjectType(cls), nullable = true) val serializers = encoders.zipWithIndex.map { case (enc, index) => val boundRefs = enc.objSerializer.collect { case b: BoundReference => b }.distinct assert(boundRefs.size == 1, "object serializer should have only one bound reference but " + @@ -106,42 +101,39 @@ object ExpressionEncoder { val originalInputObject = boundRefs.head val newInputObject = Invoke( - BoundReference(0, ObjectType(cls), nullable = true), + newSerializerInput, s"_${index + 1}", originalInputObject.dataType, returnNullable = originalInputObject.nullable) val newSerializer = enc.objSerializer.transformUp { - case b: BoundReference => newInputObject + case BoundReference(0, _, _) => newInputObject } Alias(newSerializer, s"_${index + 1}")() } + val newSerializer = CreateStruct(serializers) - val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) => + val newDeserializerInput = GetColumnByOrdinal(0, newSerializer.dataType) + val deserializers = encoders.zipWithIndex.map { case (enc, index) => val getColExprs = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }.distinct assert(getColExprs.size == 1, "object deserializer should have only one " + s"`GetColumnByOrdinal`, but there are ${getColExprs.size}") - val input = GetStructField(GetColumnByOrdinal(0, schema), index) - val newDeserializer = enc.objDeserializer.transformUp { + val input = GetStructField(newDeserializerInput, index) + enc.objDeserializer.transformUp { case GetColumnByOrdinal(0, _) => input } - if (schema(index).nullable) { - If(IsNull(input), Literal.create(null, newDeserializer.dataType), newDeserializer) - } else { - newDeserializer - } } + val newDeserializer = NewInstance(cls, deserializers, ObjectType(cls), propagateNull = false) - val serializer = If(IsNull(BoundReference(0, ObjectType(cls), nullable = true)), - Literal.create(null, schema), CreateStruct(serializers)) - val deserializer = - NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false) + def nullSafe(input: Expression, result: Expression): Expression = { + If(IsNull(input), Literal.create(null, result.dataType), result) + } new ExpressionEncoder[Any]( - serializer, - deserializer, + nullSafe(newSerializerInput, newSerializer), + nullSafe(newDeserializerInput, newDeserializer), ClassTag(cls)) } From fc8222298e26d9e4bb9ea1c0baa48cadba8ca673 Mon Sep 17 00:00:00 2001 From: Rob Vesse Date: Thu, 1 Nov 2018 09:33:55 -0700 Subject: [PATCH 1978/2461] [SPARK-25809][K8S][TEST] New K8S integration testing backends ## What changes were proposed in this pull request? Currently K8S integration tests are hardcoded to use a `minikube` based backend. `minikube` is VM based so can be resource hungry and also doesn't cope well with certain networking setups (for example using Cisco AnyConnect software VPN `minikube` is unusable as it detects its own IP incorrectly). This PR Adds a new K8S integration testing backend that allows for using the Kubernetes support in [Docker for Desktop](https://blog.docker.com/2018/07/kubernetes-is-now-available-in-docker-desktop-stable-channel/). It also generalises the framework to be able to run the integration tests against an arbitrary Kubernetes cluster. To Do: - [x] General Kubernetes cluster backend - [x] Documentation on Kubernetes integration testing - [x] Testing of general K8S backend - [x] Check whether change from timestamps being `Time` to `String` in Fabric 8 upgrade needs additional fix up ## How was this patch tested? Ran integration tests with Docker for Desktop and all passed: ![screen shot 2018-10-23 at 14 19 56](https://user-images.githubusercontent.com/2104864/47363460-c5816a00-d6ce-11e8-9c15-56b34698e797.png) Suggested Reviewers: ifilonenko srowen Author: Rob Vesse Closes #22805 from rvesse/SPARK-25809. --- .../k8s/SparkKubernetesClientFactory.scala | 5 + .../k8s/submit/LoggingPodStatusWatcher.scala | 3 - .../kubernetes/integration-tests/README.md | 183 ++++++++++++++++-- .../dev/dev-run-integration-tests.sh | 10 + .../kubernetes/integration-tests/pom.xml | 10 + .../scripts/setup-integration-test-env.sh | 43 ++-- .../k8s/integrationtest/KubernetesSuite.scala | 3 +- .../KubernetesTestComponents.scala | 5 +- .../k8s/integrationtest/ProcessUtils.scala | 5 +- .../k8s/integrationtest/TestConfig.scala | 6 +- .../k8s/integrationtest/TestConstants.scala | 15 +- .../backend/IntegrationTestBackend.scala | 21 +- .../backend/cloud/KubeConfigBackend.scala | 70 +++++++ .../docker/DockerForDesktopBackend.scala | 25 +++ 14 files changed, 356 insertions(+), 48 deletions(-) create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/cloud/KubeConfigBackend.scala create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/docker/DockerForDesktopBackend.scala diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala index c47e78cbf19e3..77bd66b608e7c 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala @@ -42,6 +42,9 @@ private[spark] object SparkKubernetesClientFactory { sparkConf: SparkConf, defaultServiceAccountToken: Option[File], defaultServiceAccountCaCert: Option[File]): KubernetesClient = { + + // TODO [SPARK-25887] Support configurable context + val oauthTokenFileConf = s"$kubernetesAuthConfPrefix.$OAUTH_TOKEN_FILE_CONF_SUFFIX" val oauthTokenConf = s"$kubernetesAuthConfPrefix.$OAUTH_TOKEN_CONF_SUFFIX" val oauthTokenFile = sparkConf.getOption(oauthTokenFileConf) @@ -63,6 +66,8 @@ private[spark] object SparkKubernetesClientFactory { .getOption(s"$kubernetesAuthConfPrefix.$CLIENT_CERT_FILE_CONF_SUFFIX") val dispatcher = new Dispatcher( ThreadUtils.newDaemonCachedThreadPool("kubernetes-dispatcher")) + + // TODO [SPARK-25887] Create builder in a way that respects configurable context val config = new ConfigBuilder() .withApiVersion("v1") .withMasterUrl(master) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala index 79b55bc37afcd..a2430c05e2568 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/LoggingPodStatusWatcher.scala @@ -18,13 +18,10 @@ package org.apache.spark.deploy.k8s.submit import java.util.concurrent.{CountDownLatch, TimeUnit} -import scala.collection.JavaConverters._ - import io.fabric8.kubernetes.api.model.Pod import io.fabric8.kubernetes.client.{KubernetesClientException, Watcher} import io.fabric8.kubernetes.client.Watcher.Action -import org.apache.spark.SparkException import org.apache.spark.deploy.k8s.KubernetesUtils._ import org.apache.spark.internal.Logging import org.apache.spark.util.ThreadUtils diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md index b3863e6b7d1af..64f8e77597eba 100644 --- a/resource-managers/kubernetes/integration-tests/README.md +++ b/resource-managers/kubernetes/integration-tests/README.md @@ -8,26 +8,59 @@ title: Spark on Kubernetes Integration Tests Note that the integration test framework is currently being heavily revised and is subject to change. Note that currently the integration tests only run with Java 8. -The simplest way to run the integration tests is to install and run Minikube, then run the following: +The simplest way to run the integration tests is to install and run Minikube, then run the following from this +directory: dev/dev-run-integration-tests.sh The minimum tested version of Minikube is 0.23.0. The kube-dns addon must be enabled. Minikube should -run with a minimum of 3 CPUs and 4G of memory: +run with a minimum of 4 CPUs and 6G of memory: - minikube start --cpus 3 --memory 4096 + minikube start --cpus 4 --memory 6144 You can download Minikube [here](https://github.com/kubernetes/minikube/releases). # Integration test customization -Configuration of the integration test runtime is done through passing different arguments to the test script. The main useful options are outlined below. +Configuration of the integration test runtime is done through passing different arguments to the test script. +The main useful options are outlined below. + +## Using a different backend + +The integration test backend i.e. the K8S cluster used for testing is controlled by the `--deploy-mode` option. By +default this is set to `minikube`, the available backends are their perequisites are as follows. + +### `minikube` + +Uses the local `minikube` cluster, this requires that `minikube` 0.23.0 or greater be installed and that it be allocated +at least 4 CPUs and 6GB memory (some users have reported success with as few as 3 CPUs and 4GB memory). The tests will +check if `minikube` is started and abort early if it isn't currently running. + +### `docker-for-desktop` + +Since July 2018 Docker for Desktop provide an optional Kubernetes cluster that can be enabled as described in this +[blog post](https://blog.docker.com/2018/07/kubernetes-is-now-available-in-docker-desktop-stable-channel/). Assuming +this is enabled using this backend will auto-configure itself from the `docker-for-desktop` context that Docker creates +in your `~/.kube/config` file. If your config file is in a different location you should set the `KUBECONFIG` +environment variable appropriately. + +### `cloud` + +These cloud backend configures the tests to use an arbitrary Kubernetes cluster running in the cloud or otherwise. + +The `cloud` backend auto-configures the cluster to use from your K8S config file, this is assumed to be `~/.kube/config` +unless the `KUBECONFIG` environment variable is set to override this location. By default this will use whatever your +current context is in the config file, to use an alternative context from your config file you can specify the +`--context ` flag with the desired context. + +You can optionally use a different K8S master URL than the one your K8S config file specified, this should be supplied +via the `--spark-master ` flag. ## Re-using Docker Images By default, the test framework will build new Docker images on every test execution. A unique image tag is generated, -and it is written to file at `target/imageTag.txt`. To reuse the images built in a previous run, or to use a Docker image tag -that you have built by other means already, pass the tag to the test script: +and it is written to file at `target/imageTag.txt`. To reuse the images built in a previous run, or to use a Docker +image tag that you have built by other means already, pass the tag to the test script: dev/dev-run-integration-tests.sh --image-tag @@ -37,16 +70,140 @@ where if you still want to use images that were built before by the test framewo ## Spark Distribution Under Test -The Spark code to test is handed to the integration test system via a tarball. Here is the option that is used to specify the tarball: +The Spark code to test is handed to the integration test system via a tarball. Here is the option that is used to +specify the tarball: * `--spark-tgz ` - set `` to point to a tarball containing the Spark distribution to test. -TODO: Don't require the packaging of the built Spark artifacts into this tarball, just read them out of the current tree. +This Tarball should be created by first running `dev/make-distribution.sh` passing the `--tgz` flag and `-Pkubernetes` +as one of the options to ensure that Kubernetes support is included in the distribution. For more details on building a +runnable distribution please see the +[Building Spark](https://spark.apache.org/docs/latest/building-spark.html#building-a-runnable-distribution) +documentation. + +**TODO:** Don't require the packaging of the built Spark artifacts into this tarball, just read them out of the current +tree. ## Customizing the Namespace and Service Account -* `--namespace ` - set `` to the namespace in which the tests should be run. -* `--service-account ` - set `` to the name of the Kubernetes service account to -use in the namespace specified by the `--namespace`. The service account is expected to have permissions to get, list, watch, -and create pods. For clusters with RBAC turned on, it's important that the right permissions are granted to the service account -in the namespace through an appropriate role and role binding. A reference RBAC configuration is provided in `dev/spark-rbac.yaml`. +If no namespace is specified then a temporary namespace will be created and deleted during the test run. Similarly if +no service account is specified then the `default` service account for the namespace will be used. + +Using the `--namespace ` flag sets `` to the namespace in which the tests should be run. If this +is supplied then the tests assume this namespace exists in the K8S cluster and will not attempt to create it. +Additionally this namespace must have an appropriately authorized service account which can be customised via the +`--service-account` flag. + +The `--service-account ` flag sets `` to the name of the Kubernetes service +account to use in the namespace specified by the `--namespace` flag. The service account is expected to have permissions +to get, list, watch, and create pods. For clusters with RBAC turned on, it's important that the right permissions are +granted to the service account in the namespace through an appropriate role and role binding. A reference RBAC +configuration is provided in `dev/spark-rbac.yaml`. + +# Running the Test Directly + +If you prefer to run just the integration tests directly, then you can customise the behaviour via passing system +properties to Maven. For example: + + mvn integration-test -am -pl :spark-kubernetes-integration-tests_2.11 \ + -Pkubernetes -Pkubernetes-integration-tests \ + -Phadoop-2.7 -Dhadoop.version=2.7.3 \ + -Dspark.kubernetes.test.sparkTgz=spark-3.0.0-SNAPSHOT-bin-example.tgz \ + -Dspark.kubernetes.test.imageTag=sometag \ + -Dspark.kubernetes.test.imageRepo=docker.io/somerepo \ + -Dspark.kubernetes.test.namespace=spark-int-tests \ + -Dspark.kubernetes.test.deployMode=docker-for-desktop \ + -Dtest.include.tags=k8s + + +## Available Maven Properties + +The following are the available Maven properties that can be passed. For the most part these correspond to flags passed +to the wrapper scripts and using the wrapper scripts will simply set these appropriately behind the scenes. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
      PropertyDescriptionDefault
      spark.kubernetes.test.sparkTgz + A runnable Spark distribution to test. +
      spark.kubernetes.test.unpackSparkDir + The directory where the runnable Spark distribution will be unpacked. + ${project.build.directory}/spark-dist-unpacked
      spark.kubernetes.test.deployMode + The integration test backend to use. Acceptable values are minikube, + docker-for-desktop and cloud. + minikube
      spark.kubernetes.test.kubeConfigContext + When using the cloud backend specifies the context from the users K8S config file that should be used + as the target cluster for integration testing. If not set and using the cloud backend then your + current context will be used. +
      spark.kubernetes.test.master + When using the cloud-url backend must be specified to indicate the K8S master URL to communicate + with. +
      spark.kubernetes.test.imageTag + A specific image tag to use, when set assumes images with those tags are already built and available in the + specified image repository. When set to N/A (the default) fresh images will be built. + N/A +
      spark.kubernetes.test.imageTagFile + A file containing the image tag to use, if no specific image tag is set then fresh images will be built with a + generated tag and that tag written to this file. + ${project.build.directory}/imageTag.txt
      spark.kubernetes.test.imageRepo + The Docker image repository that contains the images to be used if a specific image tag is set or to which the + images will be pushed to if fresh images are being built. + docker.io/kubespark
      spark.kubernetes.test.namespace + A specific Kubernetes namespace to run the tests in. If specified then the tests assume that this namespace + already exists. When not specified a temporary namespace for the tests will be created and deleted as part of the + test run. +
      spark.kubernetes.test.serviceAccountName + A specific Kubernetes service account to use for running the tests. If not specified then the namespaces default + service account will be used and that must have sufficient permissions or the tests will fail. +
      diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh index c3c843e001f21..3c7cc9369047a 100755 --- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -26,6 +26,7 @@ IMAGE_TAG="N/A" SPARK_MASTER= NAMESPACE= SERVICE_ACCOUNT= +CONTEXT= INCLUDE_TAGS="k8s" EXCLUDE_TAGS= SCALA_VERSION="$($TEST_ROOT_DIR/build/mvn org.apache.maven.plugins:maven-help-plugin:2.1.1:evaluate -Dexpression=scala.binary.version | grep -v '\[' )" @@ -61,6 +62,10 @@ while (( "$#" )); do SERVICE_ACCOUNT="$2" shift ;; + --context) + CONTEXT="$2" + shift + ;; --include-tags) INCLUDE_TAGS="k8s,$2" shift @@ -94,6 +99,11 @@ then properties=( ${properties[@]} -Dspark.kubernetes.test.serviceAccountName=$SERVICE_ACCOUNT ) fi +if [ -n $CONTEXT ]; +then + properties=( ${properties[@]} -Dspark.kubernetes.test.kubeConfigContext=$CONTEXT ) +fi + if [ -n $SPARK_MASTER ]; then properties=( ${properties[@]} -Dspark.kubernetes.test.master=$SPARK_MASTER ) diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index a07fe1feea3eb..07288c97bd527 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -33,11 +33,20 @@ 3.2.2 1.0 kubernetes-integration-tests + + + + ${project.build.directory}/spark-dist-unpacked N/A ${project.build.directory}/imageTag.txt minikube docker.io/kubespark + + + + + @@ -135,6 +144,7 @@ ${spark.kubernetes.test.unpackSparkDir} ${spark.kubernetes.test.imageRepo} ${spark.kubernetes.test.deployMode} + ${spark.kubernetes.test.kubeConfigContext} ${spark.kubernetes.test.master} ${spark.kubernetes.test.namespace} ${spark.kubernetes.test.serviceAccountName} diff --git a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh index ccfb8e767c529..a4a9f5b7da131 100755 --- a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh +++ b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh @@ -71,19 +71,36 @@ if [[ $IMAGE_TAG == "N/A" ]]; then IMAGE_TAG=$(uuidgen); cd $UNPACKED_SPARK_TGZ - if [[ $DEPLOY_MODE == cloud ]] ; - then - $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG build - if [[ $IMAGE_REPO == gcr.io* ]] ; - then - gcloud docker -- push $IMAGE_REPO/spark:$IMAGE_TAG - else - $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG push - fi - else - # -m option for minikube. - $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -m -r $IMAGE_REPO -t $IMAGE_TAG build - fi + + case $DEPLOY_MODE in + cloud) + # Build images + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG build + + # Push images appropriately + if [[ $IMAGE_REPO == gcr.io* ]] ; + then + gcloud docker -- push $IMAGE_REPO/spark:$IMAGE_TAG + else + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG push + fi + ;; + + docker-for-desktop) + # Only need to build as this will place it in our local Docker repo which is all + # we need for Docker for Desktop to work so no need to also push + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG build + ;; + + minikube) + # Only need to build and if we do this with the -m option for minikube we will + # build the images directly using the minikube Docker daemon so no need to push + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -m -r $IMAGE_REPO -t $IMAGE_TAG build + ;; + *) + echo "Unrecognized deploy mode $DEPLOY_MODE" && exit 1 + ;; + esac cd - fi diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index e2e5880255e2c..6aa1d57085068 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -33,6 +33,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite import org.apache.spark.deploy.k8s.integrationtest.TestConfig._ +import org.apache.spark.deploy.k8s.integrationtest.TestConstants._ import org.apache.spark.deploy.k8s.integrationtest.backend.{IntegrationTestBackend, IntegrationTestBackendFactory} import org.apache.spark.internal.Logging @@ -77,7 +78,7 @@ private[spark] class KubernetesSuite extends SparkFunSuite System.clearProperty(key) } - val sparkDirProp = System.getProperty("spark.kubernetes.test.unpackSparkDir") + val sparkDirProp = System.getProperty(CONFIG_KEY_UNPACK_DIR) require(sparkDirProp != null, "Spark home directory must be provided in system properties.") sparkHomeDir = Paths.get(sparkDirProp) require(sparkHomeDir.toFile.isDirectory, diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala index 5615d6173eebd..c0b435efb8c9c 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala @@ -25,15 +25,16 @@ import scala.collection.mutable import io.fabric8.kubernetes.client.DefaultKubernetesClient import org.scalatest.concurrent.Eventually +import org.apache.spark.deploy.k8s.integrationtest.TestConstants._ import org.apache.spark.internal.Logging private[spark] class KubernetesTestComponents(defaultClient: DefaultKubernetesClient) { - val namespaceOption = Option(System.getProperty("spark.kubernetes.test.namespace")) + val namespaceOption = Option(System.getProperty(CONFIG_KEY_KUBE_NAMESPACE)) val hasUserSpecifiedNamespace = namespaceOption.isDefined val namespace = namespaceOption.getOrElse(UUID.randomUUID().toString.replaceAll("-", "")) val serviceAccountName = - Option(System.getProperty("spark.kubernetes.test.serviceAccountName")) + Option(System.getProperty(CONFIG_KEY_KUBE_SVC_ACCOUNT)) .getOrElse("default") val kubernetesClient = defaultClient.inNamespace(namespace) val clientConfig = kubernetesClient.getConfiguration diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala index d8f3a6cec05c3..004a942c1cdb3 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala @@ -28,7 +28,7 @@ object ProcessUtils extends Logging { * executeProcess is used to run a command and return the output if it * completes within timeout seconds. */ - def executeProcess(fullCommand: Array[String], timeout: Long): Seq[String] = { + def executeProcess(fullCommand: Array[String], timeout: Long, dumpErrors: Boolean = false): Seq[String] = { val pb = new ProcessBuilder().command(fullCommand: _*) pb.redirectErrorStream(true) val proc = pb.start() @@ -40,7 +40,8 @@ object ProcessUtils extends Logging { }) assert(proc.waitFor(timeout, TimeUnit.SECONDS), s"Timed out while executing ${fullCommand.mkString(" ")}") - assert(proc.exitValue == 0, s"Failed to execute ${fullCommand.mkString(" ")}") + assert(proc.exitValue == 0, + s"Failed to execute ${fullCommand.mkString(" ")}${if (dumpErrors) "\n" + outputLines.mkString("\n")}") outputLines } } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConfig.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConfig.scala index 5a49e0779160c..363ec0a6016bb 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConfig.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConfig.scala @@ -21,9 +21,11 @@ import java.io.File import com.google.common.base.Charsets import com.google.common.io.Files +import org.apache.spark.deploy.k8s.integrationtest.TestConstants._ + object TestConfig { def getTestImageTag: String = { - val imageTagFileProp = System.getProperty("spark.kubernetes.test.imageTagFile") + val imageTagFileProp = System.getProperty(CONFIG_KEY_IMAGE_TAG_FILE) require(imageTagFileProp != null, "Image tag file must be provided in system properties.") val imageTagFile = new File(imageTagFileProp) require(imageTagFile.isFile, s"No file found for image tag at ${imageTagFile.getAbsolutePath}.") @@ -31,7 +33,7 @@ object TestConfig { } def getTestImageRepo: String = { - val imageRepo = System.getProperty("spark.kubernetes.test.imageRepo") + val imageRepo = System.getProperty(CONFIG_KEY_IMAGE_REPO) require(imageRepo != null, "Image repo must be provided in system properties.") imageRepo } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConstants.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConstants.scala index 8595d0eab1126..eeae70cd68571 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConstants.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConstants.scala @@ -17,6 +17,17 @@ package org.apache.spark.deploy.k8s.integrationtest object TestConstants { - val MINIKUBE_TEST_BACKEND = "minikube" - val GCE_TEST_BACKEND = "gce" + val BACKEND_MINIKUBE = "minikube" + val BACKEND_DOCKER_FOR_DESKTOP = "docker-for-desktop" + val BACKEND_CLOUD = "cloud" + + val CONFIG_KEY_DEPLOY_MODE = "spark.kubernetes.test.deployMode" + val CONFIG_KEY_KUBE_CONFIG_CONTEXT = "spark.kubernetes.test.kubeConfigContext" + val CONFIG_KEY_KUBE_MASTER_URL = "spark.kubernetes.test.master" + val CONFIG_KEY_KUBE_NAMESPACE = "spark.kubernetes.test.namespace" + val CONFIG_KEY_KUBE_SVC_ACCOUNT = "spark.kubernetes.test.serviceAccountName" + val CONFIG_KEY_IMAGE_TAG = "spark.kubernetes.test.imageTagF" + val CONFIG_KEY_IMAGE_TAG_FILE = "spark.kubernetes.test.imageTagFile" + val CONFIG_KEY_IMAGE_REPO = "spark.kubernetes.test.imageRepo" + val CONFIG_KEY_UNPACK_DIR = "spark.kubernetes.test.unpackSparkDir" } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala index 284712c6d250e..7bf324c6c4a14 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala @@ -18,7 +18,9 @@ package org.apache.spark.deploy.k8s.integrationtest.backend import io.fabric8.kubernetes.client.DefaultKubernetesClient - +import org.apache.spark.deploy.k8s.integrationtest.TestConstants._ +import org.apache.spark.deploy.k8s.integrationtest.backend.cloud.KubeConfigBackend +import org.apache.spark.deploy.k8s.integrationtest.backend.docker.DockerForDesktopBackend import org.apache.spark.deploy.k8s.integrationtest.backend.minikube.MinikubeTestBackend private[spark] trait IntegrationTestBackend { @@ -28,16 +30,15 @@ private[spark] trait IntegrationTestBackend { } private[spark] object IntegrationTestBackendFactory { - val deployModeConfigKey = "spark.kubernetes.test.deployMode" - def getTestBackend: IntegrationTestBackend = { - val deployMode = Option(System.getProperty(deployModeConfigKey)) - .getOrElse("minikube") - if (deployMode == "minikube") { - MinikubeTestBackend - } else { - throw new IllegalArgumentException( - "Invalid " + deployModeConfigKey + ": " + deployMode) + val deployMode = Option(System.getProperty(CONFIG_KEY_DEPLOY_MODE)) + .getOrElse(BACKEND_MINIKUBE) + deployMode match { + case BACKEND_MINIKUBE => MinikubeTestBackend + case BACKEND_CLOUD => new KubeConfigBackend(System.getProperty(CONFIG_KEY_KUBE_CONFIG_CONTEXT)) + case BACKEND_DOCKER_FOR_DESKTOP => DockerForDesktopBackend + case _ => throw new IllegalArgumentException("Invalid " + + CONFIG_KEY_DEPLOY_MODE + ": " + deployMode) } } } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/cloud/KubeConfigBackend.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/cloud/KubeConfigBackend.scala new file mode 100644 index 0000000000000..333526ba3ef98 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/cloud/KubeConfigBackend.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest.backend.cloud + +import java.nio.file.Paths + +import io.fabric8.kubernetes.client.utils.Utils +import io.fabric8.kubernetes.client.{Config, DefaultKubernetesClient} +import org.apache.commons.lang3.StringUtils +import org.apache.spark.deploy.k8s.integrationtest.TestConstants +import org.apache.spark.deploy.k8s.integrationtest.backend.IntegrationTestBackend +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils.checkAndGetK8sMasterUrl + +private[spark] class KubeConfigBackend(var context: String) + extends IntegrationTestBackend with Logging { + logInfo(s"K8S Integration tests will run against " + + s"${if (context != null) s"context ${context}" else "default context"}" + + s" from users K8S config file") + + private var defaultClient: DefaultKubernetesClient = _ + + override def initialize(): Unit = { + // Auto-configure K8S client from K8S config file + if (Utils.getSystemPropertyOrEnvVar(Config.KUBERNETES_KUBECONFIG_FILE, null: String) == null) { + // Fabric 8 client will automatically assume a default location in this case + logWarning(s"No explicit KUBECONFIG specified, will assume .kube/config under your home directory") + } + val config = Config.autoConfigure(context) + + // If an explicit master URL was specified then override that detected from the + // K8S config if it is different + var masterUrl = Option(System.getProperty(TestConstants.CONFIG_KEY_KUBE_MASTER_URL)) + .getOrElse(null) + if (StringUtils.isNotBlank(masterUrl)) { + // Clean up master URL which would have been specified in Spark format into a normal + // K8S master URL + masterUrl = checkAndGetK8sMasterUrl(masterUrl).replaceFirst("k8s://", "") + if (!StringUtils.equals(config.getMasterUrl, masterUrl)) { + logInfo(s"Overriding K8S master URL ${config.getMasterUrl} from K8S config file " + + s"with user specified master URL ${masterUrl}") + config.setMasterUrl(masterUrl) + } + } + + defaultClient = new DefaultKubernetesClient(config) + } + + override def cleanUp(): Unit = { + super.cleanUp() + } + + override def getKubernetesClient: DefaultKubernetesClient = { + defaultClient + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/docker/DockerForDesktopBackend.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/docker/DockerForDesktopBackend.scala new file mode 100644 index 0000000000000..81a11ae9dcdc6 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/docker/DockerForDesktopBackend.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest.backend.docker + +import org.apache.spark.deploy.k8s.integrationtest.TestConstants +import org.apache.spark.deploy.k8s.integrationtest.backend.cloud.KubeConfigBackend + +private[spark] object DockerForDesktopBackend + extends KubeConfigBackend(TestConstants.BACKEND_DOCKER_FOR_DESKTOP) { + +} From e9d3ca0b7993995f24f5c555a570bc2521119e12 Mon Sep 17 00:00:00 2001 From: Patrick Brown Date: Thu, 1 Nov 2018 09:34:29 -0700 Subject: [PATCH 1979/2461] [SPARK-25837][CORE] Fix potential slowdown in AppStatusListener when cleaning up stages MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? * Update `AppStatusListener` `cleanupStages` method to remove tasks for those stages in a single pass instead of 1 for each stage. * This fixes an issue where the cleanupStages method would get backed up, causing a backup in the executor in ElementTrackingStore, resulting in stages and jobs not getting cleaned up properly. Tasks seem most susceptible to this as there are a lot of them, however a similar issue could arise in other locations the `KVStore` `view` method is used. A broader fix might involve updates to `KVStoreView` and `InMemoryView` as it appears this interface and implementation can lead to multiple and inefficient traversals of the stored data. ## How was this patch tested? Using existing tests in AppStatusListenerSuite This is my original work and I license the work to the project under the project’s open source license. Closes #22883 from patrickbrownsync/cleanup-stages-fix. Authored-by: Patrick Brown Signed-off-by: Marcelo Vanzin --- .../spark/status/AppStatusListener.scala | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index d52b7e8dae71e..e2c190ea198e0 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -1073,16 +1073,6 @@ private[spark] class AppStatusListener( kvstore.delete(e.getClass(), e.id) } - val tasks = kvstore.view(classOf[TaskDataWrapper]) - .index("stage") - .first(key) - .last(key) - .asScala - - tasks.foreach { t => - kvstore.delete(t.getClass(), t.taskId) - } - // Check whether there are remaining attempts for the same stage. If there aren't, then // also delete the RDD graph data. val remainingAttempts = kvstore.view(classOf[StageDataWrapper]) @@ -1105,6 +1095,15 @@ private[spark] class AppStatusListener( cleanupCachedQuantiles(key) } + + // Delete tasks for all stages in one pass, as deleting them for each stage individually is slow + val tasks = kvstore.view(classOf[TaskDataWrapper]).asScala + val keys = stages.map { s => (s.info.stageId, s.info.attemptId) }.toSet + tasks.foreach { t => + if (keys.contains((t.stageId, t.stageAttemptId))) { + kvstore.delete(t.getClass(), t.taskId) + } + } } private def cleanupTasks(stage: LiveStage): Unit = { From e91b607719886b57d1550a70c0f9df4342d72989 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 1 Nov 2018 23:18:20 -0700 Subject: [PATCH 1980/2461] [SPARK-25918][SQL] LOAD DATA LOCAL INPATH should handle a relative path ## What changes were proposed in this pull request? Unfortunately, it seems that we missed this in 2.4.0. In Spark 2.4, if the default file system is not the local file system, `LOAD DATA LOCAL INPATH` only works in case of absolute paths. This PR aims to fix it to support relative paths. This is a regression in 2.4.0. ```scala $ ls kv1.txt kv1.txt scala> spark.sql("LOAD DATA LOCAL INPATH 'kv1.txt' INTO TABLE t") org.apache.spark.sql.AnalysisException: LOAD DATA input path does not exist: kv1.txt; ``` ## How was this patch tested? Pass the Jenkins Closes #22927 from dongjoon-hyun/SPARK-LOAD. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/sql/execution/command/tables.scala | 5 +++-- .../spark/sql/hive/execution/HiveCommandSuite.scala | 9 +++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 871eba49dfbd0..823dc0d5ed387 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -376,7 +376,8 @@ object LoadDataCommand { * @return qualified path object */ private[sql] def makeQualified(defaultUri: URI, workingDir: Path, path: Path): Path = { - val pathUri = if (path.isAbsolute()) path.toUri() else new Path(workingDir, path).toUri() + val newPath = new Path(workingDir, path) + val pathUri = if (path.isAbsolute()) path.toUri() else newPath.toUri() if (pathUri.getScheme == null || pathUri.getAuthority == null && defaultUri.getAuthority != null) { val scheme = if (pathUri.getScheme == null) defaultUri.getScheme else pathUri.getScheme @@ -393,7 +394,7 @@ object LoadDataCommand { throw new IllegalArgumentException(e) } } else { - path + newPath } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala index 6937e97a47dc6..9147a98c94457 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala @@ -20,11 +20,13 @@ package org.apache.spark.sql.hive.execution import java.io.File import com.google.common.io.Files +import org.apache.hadoop.fs.{FileContext, FsConstants, Path} import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.execution.command.LoadDataCommand import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.StructType @@ -439,4 +441,11 @@ class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleto } } + test("SPARK-25918: LOAD DATA LOCAL INPATH should handle a relative path") { + val localFS = FileContext.getLocalFSFileContext() + val workingDir = localFS.getWorkingDirectory + val r = LoadDataCommand.makeQualified( + FsConstants.LOCAL_FS_URI, workingDir, new Path("kv1.txt")) + assert(r === new Path(s"$workingDir/kv1.txt")) + } } From c00186f90cfcc33492d760f874ead34f0e3da6ed Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Fri, 2 Nov 2018 10:56:30 -0500 Subject: [PATCH 1981/2461] [SPARK-25023] Clarify Spark security documentation ## What changes were proposed in this pull request? Clarify documentation about security. ## How was this patch tested? None, just documentation Closes #22852 from tgravescs/SPARK-25023. Authored-by: Thomas Graves Signed-off-by: Thomas Graves --- docs/index.md | 5 +++++ docs/quick-start.md | 5 +++++ docs/running-on-kubernetes.md | 5 +++++ docs/running-on-mesos.md | 5 +++++ docs/running-on-yarn.md | 5 +++++ docs/security.md | 17 +++++++++++++++-- docs/spark-standalone.md | 5 +++++ 7 files changed, 45 insertions(+), 2 deletions(-) diff --git a/docs/index.md b/docs/index.md index d269f54c73439..ac38f1d4c53c2 100644 --- a/docs/index.md +++ b/docs/index.md @@ -10,6 +10,11 @@ It provides high-level APIs in Java, Scala, Python and R, and an optimized engine that supports general execution graphs. It also supports a rich set of higher-level tools including [Spark SQL](sql-programming-guide.html) for SQL and structured data processing, [MLlib](ml-guide.html) for machine learning, [GraphX](graphx-programming-guide.html) for graph processing, and [Spark Streaming](streaming-programming-guide.html). +# Security + +Security in Spark is OFF by default. This could mean you are vulnerable to attack by default. +Please see [Spark Security](security.html) before downloading and running Spark. + # Downloading Get Spark from the [downloads page](https://spark.apache.org/downloads.html) of the project website. This documentation is for Spark version {{site.SPARK_VERSION}}. Spark uses Hadoop's client libraries for HDFS and YARN. Downloads are pre-packaged for a handful of popular Hadoop versions. diff --git a/docs/quick-start.md b/docs/quick-start.md index ef7af6c3f6cec..28186c11887fc 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -17,6 +17,11 @@ you can download a package for any version of Hadoop. Note that, before Spark 2.0, the main programming interface of Spark was the Resilient Distributed Dataset (RDD). After Spark 2.0, RDDs are replaced by Dataset, which is strongly-typed like an RDD, but with richer optimizations under the hood. The RDD interface is still supported, and you can get a more detailed reference at the [RDD programming guide](rdd-programming-guide.html). However, we highly recommend you to switch to use Dataset, which has better performance than RDD. See the [SQL programming guide](sql-programming-guide.html) to get more information about Dataset. +# Security + +Security in Spark is OFF by default. This could mean you are vulnerable to attack by default. +Please see [Spark Security](security.html) before running Spark. + # Interactive Analysis with the Spark Shell ## Basics diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 2917197a2e2ec..905226877720a 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -12,6 +12,11 @@ Kubernetes scheduler that has been added to Spark. In future versions, there may be behavioral changes around configuration, container images and entrypoints.** +# Security + +Security in Spark is OFF by default. This could mean you are vulnerable to attack by default. +Please see [Spark Security](security.html) and the specific security sections in this doc before running Spark. + # Prerequisites * A runnable distribution of Spark 2.3 or above. diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index b473e654563d6..2502cd4ca86f4 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -13,6 +13,11 @@ The advantages of deploying Spark with Mesos include: [frameworks](https://mesos.apache.org/documentation/latest/frameworks/) - scalable partitioning between multiple instances of Spark +# Security + +Security in Spark is OFF by default. This could mean you are vulnerable to attack by default. +Please see [Spark Security](security.html) and the specific security sections in this doc before running Spark. + # How it Works In a standalone cluster deployment, the cluster manager in the below diagram is a Spark master diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 3b725cf295537..a7a448fbeb65e 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -9,6 +9,11 @@ Support for running on [YARN (Hadoop NextGen)](http://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/YARN.html) was added to Spark in version 0.6.0, and improved in subsequent releases. +# Security + +Security in Spark is OFF by default. This could mean you are vulnerable to attack by default. +Please see [Spark Security](security.html) and the specific security sections in this doc before running Spark. + # Launching Spark on YARN Ensure that `HADOOP_CONF_DIR` or `YARN_CONF_DIR` points to the directory which contains the (client side) configuration files for the Hadoop cluster. diff --git a/docs/security.md b/docs/security.md index ffae683df6256..2f7fa9c6179f4 100644 --- a/docs/security.md +++ b/docs/security.md @@ -6,7 +6,20 @@ title: Security * This will become a table of contents (this text will be scraped). {:toc} -# Spark RPC +# Spark Security: Things You Need To Know + +Security in Spark is OFF by default. This could mean you are vulnerable to attack by default. +Spark supports multiple deployments types and each one supports different levels of security. Not +all deployment types will be secure in all environments and none are secure by default. Be +sure to evaluate your environment, what Spark supports, and take the appropriate measure to secure +your Spark deployment. + +There are many different types of security concerns. Spark does not necessarily protect against +all things. Listed below are some of the things Spark supports. Also check the deployment +documentation for the type of deployment you are using for deployment specific settings. Anything +not documented, Spark does not support. + +# Spark RPC (Communication protocol between Spark processes) ## Authentication @@ -123,7 +136,7 @@ The following table describes the different options available for configuring th Spark supports encrypting temporary data written to local disks. This covers shuffle files, shuffle spills and data blocks stored on disk (for both caching and broadcast variables). It does not cover encrypting output data generated by applications with APIs such as `saveAsHadoopFile` or -`saveAsTable`. +`saveAsTable`. It also may not cover temporary files created explicitly by the user. The following settings cover enabling encryption for data written to disk: diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 7975b0c8b11ca..49ef2e1ce2a1b 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -8,6 +8,11 @@ title: Spark Standalone Mode In addition to running on the Mesos or YARN cluster managers, Spark also provides a simple standalone deploy mode. You can launch a standalone cluster either manually, by starting a master and workers by hand, or use our provided [launch scripts](#cluster-launch-scripts). It is also possible to run these daemons on a single machine for testing. +# Security + +Security in Spark is OFF by default. This could mean you are vulnerable to attack by default. +Please see [Spark Security](security.html) and the specific security sections in this doc before running Spark. + # Installing Spark Standalone to a Cluster To install Spark Standalone mode, you simply place a compiled version of Spark on each node on the cluster. You can obtain pre-built versions of Spark with each release or [build it yourself](building-spark.html). From c71db43e11fb90d6675421604ad29f596f2b8bfe Mon Sep 17 00:00:00 2001 From: James Lamb Date: Fri, 2 Nov 2018 11:05:10 -0500 Subject: [PATCH 1982/2461] [SPARK-25909] fix documentation on cluster managers ## What changes were proposed in this pull request? Propose changing the documentation to state that there are 4, not 3, cluster managers available. ## How was this patch tested? This is a docs-only patch and doesn't need any new testing beyond the normal CI process for Spark. Closes #22922 from jameslamb/bugfix/cluster_docs. Authored-by: James Lamb Signed-off-by: Sean Owen --- docs/cluster-overview.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index 7277e2fb2731d..1f0822f7a317b 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -45,7 +45,7 @@ There are several useful things to note about this architecture: # Cluster Manager Types -The system currently supports three cluster managers: +The system currently supports several cluster managers: * [Standalone](spark-standalone.html) -- a simple cluster manager included with Spark that makes it easy to set up a cluster. From 7ea594e7876258296f340daddefcaf71a64ab824 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 2 Nov 2018 13:24:55 -0700 Subject: [PATCH 1983/2461] [SPARK-25827][CORE] Avoid converting incoming encrypted blocks to byte buffers ## What changes were proposed in this pull request? Avoid converting encrypted bocks to regular ByteBuffers, to ensure they can be sent over the network for replication & remote reads even when > 2GB. Also updates some TODOs with links to a SPARK-25905 for improving the handling here. ## How was this patch tested? Tested on a cluster with encrypted data > 2GB (after SPARK-25904 was applied as well). Closes #22917 from squito/real_SPARK-25827. Authored-by: Imran Rashid Signed-off-by: Marcelo Vanzin --- .../org/apache/spark/network/BlockTransferService.scala | 4 +++- .../main/scala/org/apache/spark/storage/BlockManager.scala | 2 +- .../src/main/scala/org/apache/spark/storage/DiskStore.scala | 5 +++-- .../scala/org/apache/spark/util/io/ChunkedByteBuffer.scala | 6 ++++-- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index eef8c31e05ab1..a58c8fa2e763f 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -27,7 +27,7 @@ import scala.reflect.ClassTag import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ShuffleClient} -import org.apache.spark.storage.{BlockId, StorageLevel} +import org.apache.spark.storage.{BlockId, EncryptedManagedBuffer, StorageLevel} import org.apache.spark.util.ThreadUtils private[spark] @@ -104,6 +104,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo data match { case f: FileSegmentManagedBuffer => result.success(f) + case e: EncryptedManagedBuffer => + result.success(e) case _ => val ret = ByteBuffer.allocate(data.size.toInt) ret.put(data.nioByteBuffer()) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index c01a453151911..e35dd72521247 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -721,7 +721,7 @@ private[spark] class BlockManager( * Get block from remote block managers as serialized bytes. */ def getRemoteBytes(blockId: BlockId): Option[ChunkedByteBuffer] = { - // TODO if we change this method to return the ManagedBuffer, then getRemoteValues + // TODO SPARK-25905 if we change this method to return the ManagedBuffer, then getRemoteValues // could just use the inputStream on the temp file, rather than reading the file into memory. // Until then, replication can cause the process to use too much memory and get killed // even though we've read the data to disk. diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index d88bd710d1ead..841e16afc7549 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -201,7 +201,7 @@ private class DiskBlockData( private def open() = new FileInputStream(file).getChannel } -private class EncryptedBlockData( +private[spark] class EncryptedBlockData( file: File, blockSize: Long, conf: SparkConf, @@ -263,7 +263,8 @@ private class EncryptedBlockData( } } -private class EncryptedManagedBuffer(val blockData: EncryptedBlockData) extends ManagedBuffer { +private[spark] class EncryptedManagedBuffer( + val blockData: EncryptedBlockData) extends ManagedBuffer { // This is the size of the decrypted data override def size(): Long = blockData.size diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 9547cb49bbee8..da2be84723a07 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -29,7 +29,7 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.config import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.util.{ByteArrayWritableChannel, LimitedInputStream} -import org.apache.spark.storage.StorageUtils +import org.apache.spark.storage.{EncryptedManagedBuffer, StorageUtils} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.Utils @@ -173,11 +173,13 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { private[spark] object ChunkedByteBuffer { - // TODO eliminate this method if we switch BlockManager to getting InputStreams + // TODO SPARK-25905 eliminate this method if we switch BlockManager to getting InputStreams def fromManagedBuffer(data: ManagedBuffer): ChunkedByteBuffer = { data match { case f: FileSegmentManagedBuffer => fromFile(f.getFile, f.getOffset, f.getLength) + case e: EncryptedManagedBuffer => + e.blockData.toChunkedByteBuffer(ByteBuffer.allocate _) case other => new ChunkedByteBuffer(other.nioByteBuffer()) } From 3404a73f4cf7be37e574026d08ad5cf82cfac871 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 2 Nov 2018 13:58:08 -0700 Subject: [PATCH 1984/2461] [SPARK-25875][K8S] Merge code to set up driver command into a single step. Right now there are 3 different classes dealing with building the driver command to run inside the pod, one for each "binding" supported by Spark. This has two main shortcomings: - the code in the 3 classes is very similar; changing things in one place would probably mean making a similar change in the others. - it gives the false impression that the step implementation is the only place where binding-specific logic is needed. That is not true; there was code in KubernetesConf that was binding-specific, and there's also code in the executor-specific config step. So the 3 classes weren't really working as a language-specific abstraction. On top of that, the current code was propagating command line parameters in a different way depending on the binding. That doesn't seem necessary, and in fact using environment variables for command line parameters is in general a really bad idea, since you can't handle special characters (e.g. spaces) that way. This change merges the 3 different code paths for Java, Python and R into a single step, and also merges the 3 code paths to start the Spark driver in the k8s entry point script. This increases the amount of shared code, and also moves more feature logic into the step itself, so it doesn't live in KubernetesConf. Note that not all logic related to setting up the driver lives in that step. For example, the memory overhead calculation still lives separately, except it now happens in the driver config step instead of outside the step hierarchy altogether. Some of the noise in the diff is because of changes to KubernetesConf, which will be addressed in a separate change. Tested with new and updated unit tests + integration tests. Author: Marcelo Vanzin Closes #22897 from vanzin/SPARK-25875. --- .../org/apache/spark/deploy/k8s/Config.scala | 30 +--- .../apache/spark/deploy/k8s/Constants.scala | 10 +- .../spark/deploy/k8s/KubernetesConf.scala | 70 ++------- .../k8s/features/BasicDriverFeatureStep.scala | 42 +++-- .../features/BasicExecutorFeatureStep.scala | 21 +-- .../features/DriverCommandFeatureStep.scala | 134 ++++++++++++++++ .../KubernetesFeatureConfigStep.scala | 4 +- .../features/PodTemplateConfigMapStep.scala | 4 +- .../bindings/JavaDriverFeatureStep.scala | 46 ------ .../bindings/PythonDriverFeatureStep.scala | 76 --------- .../bindings/RDriverFeatureStep.scala | 62 -------- .../submit/KubernetesClientApplication.scala | 10 +- .../k8s/submit/KubernetesDriverBuilder.scala | 26 +--- .../deploy/k8s/submit/MainAppResource.scala | 3 +- .../deploy/k8s/KubernetesConfSuite.scala | 103 +------------ .../BasicDriverFeatureStepSuite.scala | 69 +++++++-- .../BasicExecutorFeatureStepSuite.scala | 4 - .../DriverCommandFeatureStepSuite.scala | 144 ++++++++++++++++++ ...ubernetesCredentialsFeatureStepSuite.scala | 3 - .../DriverServiceFeatureStepSuite.scala | 19 +-- .../features/EnvSecretsFeatureStepSuite.scala | 1 - .../features/LocalDirsFeatureStepSuite.scala | 4 +- .../MountSecretsFeatureStepSuite.scala | 1 - .../MountVolumesFeatureStepSuite.scala | 4 +- .../PodTemplateConfigMapStepSuite.scala | 4 +- .../bindings/JavaDriverFeatureStepSuite.scala | 61 -------- .../PythonDriverFeatureStepSuite.scala | 112 -------------- .../bindings/RDriverFeatureStepSuite.scala | 64 -------- .../spark/deploy/k8s/submit/ClientSuite.scala | 4 +- .../submit/KubernetesDriverBuilderSuite.scala | 136 +++-------------- .../k8s/KubernetesExecutorBuilderSuite.scala | 6 - .../src/main/dockerfiles/spark/entrypoint.sh | 16 -- 32 files changed, 438 insertions(+), 855 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStepSuite.scala diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 862f1d63ed39f..a32bd93bb65bc 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.k8s import java.util.concurrent.TimeUnit +import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.ConfigBuilder @@ -125,34 +126,6 @@ private[spark] object Config extends Logging { .stringConf .createOptional - val KUBERNETES_PYSPARK_MAIN_APP_RESOURCE = - ConfigBuilder("spark.kubernetes.python.mainAppResource") - .doc("The main app resource for pyspark jobs") - .internal() - .stringConf - .createOptional - - val KUBERNETES_PYSPARK_APP_ARGS = - ConfigBuilder("spark.kubernetes.python.appArgs") - .doc("The app arguments for PySpark Jobs") - .internal() - .stringConf - .createOptional - - val KUBERNETES_R_MAIN_APP_RESOURCE = - ConfigBuilder("spark.kubernetes.r.mainAppResource") - .doc("The main app resource for SparkR jobs") - .internal() - .stringConf - .createOptional - - val KUBERNETES_R_APP_ARGS = - ConfigBuilder("spark.kubernetes.r.appArgs") - .doc("The app arguments for SparkR Jobs") - .internal() - .stringConf - .createOptional - val KUBERNETES_ALLOCATION_BATCH_SIZE = ConfigBuilder("spark.kubernetes.allocation.batch.size") .doc("Number of pods to launch at once in each round of executor allocation.") @@ -267,6 +240,7 @@ private[spark] object Config extends Logging { .doc("This sets the resource type internally") .internal() .stringConf + .checkValues(Set(APP_RESOURCE_TYPE_JAVA, APP_RESOURCE_TYPE_PYTHON, APP_RESOURCE_TYPE_R)) .createOptional val KUBERNETES_LOCAL_DIRS_TMPFS = diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 1c6d53c16871e..85917b88e912a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -69,12 +69,8 @@ private[spark] object Constants { val ENV_HADOOP_TOKEN_FILE_LOCATION = "HADOOP_TOKEN_FILE_LOCATION" // BINDINGS - val ENV_PYSPARK_PRIMARY = "PYSPARK_PRIMARY" val ENV_PYSPARK_FILES = "PYSPARK_FILES" - val ENV_PYSPARK_ARGS = "PYSPARK_APP_ARGS" val ENV_PYSPARK_MAJOR_PYTHON_VERSION = "PYSPARK_MAJOR_PYTHON_VERSION" - val ENV_R_PRIMARY = "R_PRIMARY" - val ENV_R_ARGS = "R_APP_ARGS" // Pod spec templates val EXECUTOR_POD_SPEC_TEMPLATE_FILE_NAME = "pod-spec-template.yml" @@ -88,6 +84,7 @@ private[spark] object Constants { val DEFAULT_DRIVER_CONTAINER_NAME = "spark-kubernetes-driver" val DEFAULT_EXECUTOR_CONTAINER_NAME = "spark-kubernetes-executor" val MEMORY_OVERHEAD_MIN_MIB = 384L + val NON_JVM_MEMORY_OVERHEAD_FACTOR = 0.4d // Hadoop Configuration val HADOOP_FILE_VOLUME = "hadoop-properties" @@ -113,4 +110,9 @@ private[spark] object Constants { // Hadoop credentials secrets for the Spark app. val SPARK_APP_HADOOP_CREDENTIALS_BASE_DIR = "/mnt/secrets/hadoop-credentials" val SPARK_APP_HADOOP_SECRET_VOLUME_NAME = "hadoop-secret" + + // Application resource types. + val APP_RESOURCE_TYPE_JAVA = "java" + val APP_RESOURCE_TYPE_PYTHON = "python" + val APP_RESOURCE_TYPE_R = "r" } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index 066547dcbb408..ebb81540bbbbe 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -28,6 +28,7 @@ import org.apache.spark.deploy.k8s.security.KubernetesHadoopDelegationTokenManag import org.apache.spark.deploy.k8s.submit._ import org.apache.spark.deploy.k8s.submit.KubernetesClientApplication._ import org.apache.spark.internal.config.ConfigEntry +import org.apache.spark.util.Utils private[spark] sealed trait KubernetesRoleSpecificConf @@ -36,10 +37,15 @@ private[spark] sealed trait KubernetesRoleSpecificConf * Structure containing metadata for Kubernetes logic that builds a Spark driver. */ private[spark] case class KubernetesDriverSpecificConf( - mainAppResource: Option[MainAppResource], + mainAppResource: MainAppResource, mainClass: String, appName: String, - appArgs: Seq[String]) extends KubernetesRoleSpecificConf + appArgs: Seq[String], + pyFiles: Seq[String] = Nil) extends KubernetesRoleSpecificConf { + + require(mainAppResource != null, "Main resource must be provided.") + +} /* * Structure containing metadata for Kubernetes logic that builds a Spark executor. @@ -70,7 +76,6 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( roleSecretEnvNamesToKeyRefs: Map[String, String], roleEnvs: Map[String, String], roleVolumes: Iterable[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]], - sparkFiles: Seq[String], hadoopConfSpec: Option[HadoopConfSpec]) { def hadoopConfigMapName: String = s"$appResourceNamePrefix-hadoop-config" @@ -82,23 +87,6 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE) - def sparkJars(): Seq[String] = sparkConf - .getOption("spark.jars") - .map(str => str.split(",").toSeq) - .getOrElse(Seq.empty[String]) - - def pyFiles(): Option[String] = sparkConf - .get(KUBERNETES_PYSPARK_PY_FILES) - - def pySparkMainResource(): Option[String] = sparkConf - .get(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE) - - def pySparkPythonVersion(): String = sparkConf - .get(PYSPARK_MAJOR_PYTHON_VERSION) - - def sparkRMainResource(): Option[String] = sparkConf - .get(KUBERNETES_R_MAIN_APP_RESOURCE) - def imagePullPolicy(): String = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) def imagePullSecrets(): Seq[LocalObjectReference] = { @@ -130,38 +118,11 @@ private[spark] object KubernetesConf { appName: String, appResourceNamePrefix: String, appId: String, - mainAppResource: Option[MainAppResource], + mainAppResource: MainAppResource, mainClass: String, appArgs: Array[String], maybePyFiles: Option[String], hadoopConfDir: Option[String]): KubernetesConf[KubernetesDriverSpecificConf] = { - val sparkConfWithMainAppJar = sparkConf.clone() - val additionalFiles = mutable.ArrayBuffer.empty[String] - mainAppResource.foreach { - case JavaMainAppResource(res) => - val previousJars = sparkConf - .getOption("spark.jars") - .map(_.split(",")) - .getOrElse(Array.empty) - if (!previousJars.contains(res)) { - sparkConfWithMainAppJar.setJars(previousJars ++ Seq(res)) - } - // The function of this outer match is to account for multiple nonJVM - // bindings that will all have increased default MEMORY_OVERHEAD_FACTOR to 0.4 - case nonJVM: NonJVMResource => - nonJVM match { - case PythonMainAppResource(res) => - additionalFiles += res - maybePyFiles.foreach{maybePyFiles => - additionalFiles.appendAll(maybePyFiles.split(","))} - sparkConfWithMainAppJar.set(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE, res) - case RMainAppResource(res) => - additionalFiles += res - sparkConfWithMainAppJar.set(KUBERNETES_R_MAIN_APP_RESOURCE, res) - } - sparkConfWithMainAppJar.setIfMissing(MEMORY_OVERHEAD_FACTOR, 0.4) - } - val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_DRIVER_LABEL_PREFIX) require(!driverCustomLabels.contains(SPARK_APP_ID_LABEL), "Label with key " + @@ -188,11 +149,6 @@ private[spark] object KubernetesConf { KubernetesVolumeUtils.parseVolumesWithPrefix( sparkConf, KUBERNETES_EXECUTOR_VOLUMES_PREFIX).map(_.get) - val sparkFiles = sparkConf - .getOption("spark.files") - .map(str => str.split(",").toSeq) - .getOrElse(Seq.empty[String]) ++ additionalFiles - val hadoopConfigMapName = sparkConf.get(KUBERNETES_HADOOP_CONF_CONFIG_MAP) KubernetesUtils.requireNandDefined( hadoopConfDir, @@ -205,10 +161,12 @@ private[spark] object KubernetesConf { } else { None } + val pyFiles = maybePyFiles.map(Utils.stringToSeq).getOrElse(Nil) + KubernetesConf( - sparkConfWithMainAppJar, - KubernetesDriverSpecificConf(mainAppResource, mainClass, appName, appArgs), + sparkConf.clone(), + KubernetesDriverSpecificConf(mainAppResource, mainClass, appName, appArgs, pyFiles), appResourceNamePrefix, appId, driverLabels, @@ -217,7 +175,6 @@ private[spark] object KubernetesConf { driverSecretEnvNamesToKeyRefs, driverEnvs, driverVolumes, - sparkFiles, hadoopConfSpec) } @@ -274,7 +231,6 @@ private[spark] object KubernetesConf { executorEnvSecrets, executorEnv, executorVolumes, - Seq.empty[String], None) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala index 96b14a0d82b4c..5ddf73cb16a6f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala @@ -28,6 +28,7 @@ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit._ import org.apache.spark.internal.config._ import org.apache.spark.ui.SparkUI +import org.apache.spark.util.Utils private[spark] class BasicDriverFeatureStep( conf: KubernetesConf[KubernetesDriverSpecificConf]) @@ -47,10 +48,23 @@ private[spark] class BasicDriverFeatureStep( // Memory settings private val driverMemoryMiB = conf.get(DRIVER_MEMORY) + + // The memory overhead factor to use. If the user has not set it, then use a different + // value for non-JVM apps. This value is propagated to executors. + private val overheadFactor = + if (conf.roleSpecificConf.mainAppResource.isInstanceOf[NonJVMResource]) { + if (conf.sparkConf.contains(MEMORY_OVERHEAD_FACTOR)) { + conf.get(MEMORY_OVERHEAD_FACTOR) + } else { + NON_JVM_MEMORY_OVERHEAD_FACTOR + } + } else { + conf.get(MEMORY_OVERHEAD_FACTOR) + } + private val memoryOverheadMiB = conf .get(DRIVER_MEMORY_OVERHEAD) - .getOrElse(math.max((conf.get(MEMORY_OVERHEAD_FACTOR) * driverMemoryMiB).toInt, - MEMORY_OVERHEAD_MIN_MIB)) + .getOrElse(math.max((overheadFactor * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB override def configurePod(pod: SparkPod): SparkPod = { @@ -134,20 +148,18 @@ private[spark] class BasicDriverFeatureStep( KUBERNETES_DRIVER_POD_NAME.key -> driverPodName, "spark.app.id" -> conf.appId, KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> conf.appResourceNamePrefix, - KUBERNETES_DRIVER_SUBMIT_CHECK.key -> "true") - - val resolvedSparkJars = KubernetesUtils.resolveFileUrisAndPath( - conf.sparkJars()) - val resolvedSparkFiles = KubernetesUtils.resolveFileUrisAndPath( - conf.sparkFiles) - if (resolvedSparkJars.nonEmpty) { - additionalProps.put("spark.jars", resolvedSparkJars.mkString(",")) - } - if (resolvedSparkFiles.nonEmpty) { - additionalProps.put("spark.files", resolvedSparkFiles.mkString(",")) + KUBERNETES_DRIVER_SUBMIT_CHECK.key -> "true", + MEMORY_OVERHEAD_FACTOR.key -> overheadFactor.toString) + + Seq("spark.jars", "spark.files").foreach { key => + conf.getOption(key).foreach { value => + val resolved = KubernetesUtils.resolveFileUrisAndPath(Utils.stringToSeq(value)) + if (resolved.nonEmpty) { + additionalProps.put(key, resolved.mkString(",")) + } + } } + additionalProps.toMap } - - override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index 1dab2a834f3e7..7f397e6e84fa5 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -58,16 +58,13 @@ private[spark] class BasicExecutorFeatureStep( (kubernetesConf.get(MEMORY_OVERHEAD_FACTOR) * executorMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB - private val executorMemoryTotal = kubernetesConf.sparkConf - .getOption(APP_RESOURCE_TYPE.key).map{ res => - val additionalPySparkMemory = res match { - case "python" => - kubernetesConf.sparkConf - .get(PYSPARK_EXECUTOR_MEMORY).map(_.toInt).getOrElse(0) - case _ => 0 - } - executorMemoryWithOverhead + additionalPySparkMemory - }.getOrElse(executorMemoryWithOverhead) + private val executorMemoryTotal = + if (kubernetesConf.get(APP_RESOURCE_TYPE) == Some(APP_RESOURCE_TYPE_PYTHON)) { + executorMemoryWithOverhead + + kubernetesConf.get(PYSPARK_EXECUTOR_MEMORY).map(_.toInt).getOrElse(0) + } else { + executorMemoryWithOverhead + } private val executorCores = kubernetesConf.sparkConf.getInt("spark.executor.cores", 1) private val executorCoresRequest = @@ -187,8 +184,4 @@ private[spark] class BasicExecutorFeatureStep( SparkPod(executorPod, containerWithLimitCores) } - - override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty - - override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStep.scala new file mode 100644 index 0000000000000..8b8f0d01d49f7 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStep.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder} + +import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit._ +import org.apache.spark.internal.config._ +import org.apache.spark.launcher.SparkLauncher +import org.apache.spark.util.Utils + +/** + * Creates the driver command for running the user app, and propagates needed configuration so + * executors can also find the app code. + */ +private[spark] class DriverCommandFeatureStep(conf: KubernetesConf[KubernetesDriverSpecificConf]) + extends KubernetesFeatureConfigStep { + + private val driverConf = conf.roleSpecificConf + + override def configurePod(pod: SparkPod): SparkPod = { + driverConf.mainAppResource match { + case JavaMainAppResource(_) => + configureForJava(pod) + + case PythonMainAppResource(res) => + configureForPython(pod, res) + + case RMainAppResource(res) => + configureForR(pod, res) + } + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = { + driverConf.mainAppResource match { + case JavaMainAppResource(res) => + res.map(additionalJavaProperties).getOrElse(Map.empty) + + case PythonMainAppResource(res) => + additionalPythonProperties(res) + + case RMainAppResource(res) => + additionalRProperties(res) + } + } + + private def configureForJava(pod: SparkPod): SparkPod = { + // The user application jar is merged into the spark.jars list and managed through that + // property, so use a "blank" resource for the Java driver. + val driverContainer = baseDriverContainer(pod, SparkLauncher.NO_RESOURCE).build() + SparkPod(pod.pod, driverContainer) + } + + private def configureForPython(pod: SparkPod, res: String): SparkPod = { + val maybePythonFiles = if (driverConf.pyFiles.nonEmpty) { + // Delineation by ":" is to append the PySpark Files to the PYTHONPATH + // of the respective PySpark pod + val resolved = KubernetesUtils.resolveFileUrisAndPath(driverConf.pyFiles) + Some(new EnvVarBuilder() + .withName(ENV_PYSPARK_FILES) + .withValue(resolved.mkString(":")) + .build()) + } else { + None + } + val pythonEnvs = + Seq(new EnvVarBuilder() + .withName(ENV_PYSPARK_MAJOR_PYTHON_VERSION) + .withValue(conf.sparkConf.get(PYSPARK_MAJOR_PYTHON_VERSION)) + .build()) ++ + maybePythonFiles + + val pythonContainer = baseDriverContainer(pod, KubernetesUtils.resolveFileUri(res)) + .addAllToEnv(pythonEnvs.asJava) + .build() + + SparkPod(pod.pod, pythonContainer) + } + + private def configureForR(pod: SparkPod, res: String): SparkPod = { + val rContainer = baseDriverContainer(pod, KubernetesUtils.resolveFileUri(res)).build() + SparkPod(pod.pod, rContainer) + } + + private def baseDriverContainer(pod: SparkPod, resource: String): ContainerBuilder = { + new ContainerBuilder(pod.container) + .addToArgs("driver") + .addToArgs("--properties-file", SPARK_CONF_PATH) + .addToArgs("--class", driverConf.mainClass) + .addToArgs(resource) + .addToArgs(driverConf.appArgs: _*) + } + + private def additionalJavaProperties(resource: String): Map[String, String] = { + resourceType(APP_RESOURCE_TYPE_JAVA) ++ mergeFileList("spark.jars", Seq(resource)) + } + + private def additionalPythonProperties(resource: String): Map[String, String] = { + resourceType(APP_RESOURCE_TYPE_PYTHON) ++ + mergeFileList("spark.files", Seq(resource) ++ driverConf.pyFiles) + } + + private def additionalRProperties(resource: String): Map[String, String] = { + resourceType(APP_RESOURCE_TYPE_R) ++ mergeFileList("spark.files", Seq(resource)) + } + + private def mergeFileList(key: String, filesToAdd: Seq[String]): Map[String, String] = { + val existing = Utils.stringToSeq(conf.sparkConf.get(key, "")) + Map(key -> (existing ++ filesToAdd).distinct.mkString(",")) + } + + private def resourceType(resType: String): Map[String, String] = { + Map(APP_RESOURCE_TYPE.key -> resType) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala index 4c1be3bb13293..58cdaa3cadd6b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala @@ -61,11 +61,11 @@ private[spark] trait KubernetesFeatureConfigStep { /** * Return any system properties that should be set on the JVM in accordance to this feature. */ - def getAdditionalPodSystemProperties(): Map[String, String] + def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty /** * Return any additional Kubernetes resources that should be added to support this feature. Only * applicable when creating the driver in cluster mode. */ - def getAdditionalKubernetesResources(): Seq[HasMetadata] + def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala index 96a8013246b74..28e2d1726ae27 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala @@ -54,11 +54,11 @@ private[spark] class PodTemplateConfigMapStep( SparkPod(podWithVolume, containerWithVolume) } - def getAdditionalPodSystemProperties(): Map[String, String] = Map[String, String]( + override def getAdditionalPodSystemProperties(): Map[String, String] = Map[String, String]( KUBERNETES_EXECUTOR_PODTEMPLATE_FILE.key -> (EXECUTOR_POD_SPEC_TEMPLATE_MOUNTPATH + "/" + EXECUTOR_POD_SPEC_TEMPLATE_FILE_NAME)) - def getAdditionalKubernetesResources(): Seq[HasMetadata] = { + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { require(conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).isDefined) val podTemplateFile = conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).get val podTemplateString = Files.toString(new File(podTemplateFile), StandardCharsets.UTF_8) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala deleted file mode 100644 index 6f063b253cd73..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.features.bindings - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata} - -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.Config.APP_RESOURCE_TYPE -import org.apache.spark.deploy.k8s.Constants.SPARK_CONF_PATH -import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep -import org.apache.spark.launcher.SparkLauncher - -private[spark] class JavaDriverFeatureStep( - kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]) - extends KubernetesFeatureConfigStep { - override def configurePod(pod: SparkPod): SparkPod = { - val withDriverArgs = new ContainerBuilder(pod.container) - .addToArgs("driver") - .addToArgs("--properties-file", SPARK_CONF_PATH) - .addToArgs("--class", kubernetesConf.roleSpecificConf.mainClass) - // The user application jar is merged into the spark.jars list and managed through that - // property, so there is no need to reference it explicitly here. - .addToArgs(SparkLauncher.NO_RESOURCE) - .addToArgs(kubernetesConf.roleSpecificConf.appArgs: _*) - .build() - SparkPod(pod.pod, withDriverArgs) - } - override def getAdditionalPodSystemProperties(): Map[String, String] = - Map(APP_RESOURCE_TYPE.key -> "java") - - override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala deleted file mode 100644 index cf0c03b22bd7e..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.features.bindings - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, HasMetadata} - -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} -import org.apache.spark.deploy.k8s.Config.APP_RESOURCE_TYPE -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep - -private[spark] class PythonDriverFeatureStep( - kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]) - extends KubernetesFeatureConfigStep { - override def configurePod(pod: SparkPod): SparkPod = { - val roleConf = kubernetesConf.roleSpecificConf - require(roleConf.mainAppResource.isDefined, "PySpark Main Resource must be defined") - // Delineation is done by " " because that is input into PythonRunner - val maybePythonArgs = Option(roleConf.appArgs).filter(_.nonEmpty).map( - pyArgs => - new EnvVarBuilder() - .withName(ENV_PYSPARK_ARGS) - .withValue(pyArgs.mkString(" ")) - .build()) - val maybePythonFiles = kubernetesConf.pyFiles().map( - // Dilineation by ":" is to append the PySpark Files to the PYTHONPATH - // of the respective PySpark pod - pyFiles => - new EnvVarBuilder() - .withName(ENV_PYSPARK_FILES) - .withValue(KubernetesUtils.resolveFileUrisAndPath(pyFiles.split(",")) - .mkString(":")) - .build()) - val envSeq = - Seq(new EnvVarBuilder() - .withName(ENV_PYSPARK_PRIMARY) - .withValue(KubernetesUtils.resolveFileUri(kubernetesConf.pySparkMainResource().get)) - .build(), - new EnvVarBuilder() - .withName(ENV_PYSPARK_MAJOR_PYTHON_VERSION) - .withValue(kubernetesConf.pySparkPythonVersion()) - .build()) - val pythonEnvs = envSeq ++ - maybePythonArgs.toSeq ++ - maybePythonFiles.toSeq - - val withPythonPrimaryContainer = new ContainerBuilder(pod.container) - .addAllToEnv(pythonEnvs.asJava) - .addToArgs("driver-py") - .addToArgs("--properties-file", SPARK_CONF_PATH) - .addToArgs("--class", roleConf.mainClass) - .build() - - SparkPod(pod.pod, withPythonPrimaryContainer) - } - override def getAdditionalPodSystemProperties(): Map[String, String] = - Map(APP_RESOURCE_TYPE.key -> "python") - - override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStep.scala deleted file mode 100644 index 1a7ef52fefe70..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStep.scala +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.features.bindings - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, HasMetadata} - -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} -import org.apache.spark.deploy.k8s.Config.APP_RESOURCE_TYPE -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep - -private[spark] class RDriverFeatureStep( - kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]) - extends KubernetesFeatureConfigStep { - override def configurePod(pod: SparkPod): SparkPod = { - val roleConf = kubernetesConf.roleSpecificConf - require(roleConf.mainAppResource.isDefined, "R Main Resource must be defined") - // Delineation is done by " " because that is input into RRunner - val maybeRArgs = Option(roleConf.appArgs).filter(_.nonEmpty).map( - rArgs => - new EnvVarBuilder() - .withName(ENV_R_ARGS) - .withValue(rArgs.mkString(" ")) - .build()) - val envSeq = - Seq(new EnvVarBuilder() - .withName(ENV_R_PRIMARY) - .withValue(KubernetesUtils.resolveFileUri(kubernetesConf.sparkRMainResource().get)) - .build()) - val rEnvs = envSeq ++ - maybeRArgs.toSeq - - val withRPrimaryContainer = new ContainerBuilder(pod.container) - .addAllToEnv(rEnvs.asJava) - .addToArgs("driver-r") - .addToArgs("--properties-file", SPARK_CONF_PATH) - .addToArgs("--class", roleConf.mainClass) - .build() - - SparkPod(pod.pod, withRPrimaryContainer) - } - override def getAdditionalPodSystemProperties(): Map[String, String] = - Map(APP_RESOURCE_TYPE.key -> "r") - - override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index 4b58f8ba3c9bd..543d6b16d6ae2 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -44,7 +44,7 @@ import org.apache.spark.util.Utils * @param maybePyFiles additional Python files via --py-files */ private[spark] case class ClientArguments( - mainAppResource: Option[MainAppResource], + mainAppResource: MainAppResource, mainClass: String, driverArgs: Array[String], maybePyFiles: Option[String], @@ -53,18 +53,18 @@ private[spark] case class ClientArguments( private[spark] object ClientArguments { def fromCommandLineArgs(args: Array[String]): ClientArguments = { - var mainAppResource: Option[MainAppResource] = None + var mainAppResource: MainAppResource = JavaMainAppResource(None) var mainClass: Option[String] = None val driverArgs = mutable.ArrayBuffer.empty[String] var maybePyFiles : Option[String] = None args.sliding(2, 2).toList.foreach { case Array("--primary-java-resource", primaryJavaResource: String) => - mainAppResource = Some(JavaMainAppResource(primaryJavaResource)) + mainAppResource = JavaMainAppResource(Some(primaryJavaResource)) case Array("--primary-py-file", primaryPythonResource: String) => - mainAppResource = Some(PythonMainAppResource(primaryPythonResource)) + mainAppResource = PythonMainAppResource(primaryPythonResource) case Array("--primary-r-file", primaryRFile: String) => - mainAppResource = Some(RMainAppResource(primaryRFile)) + mainAppResource = RMainAppResource(primaryRFile) case Array("--other-py-files", pyFiles: String) => maybePyFiles = Some(pyFiles) case Array("--main-class", clazz: String) => diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index 5565cd74280e6..be4daec3b1bb9 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -23,7 +23,6 @@ import io.fabric8.kubernetes.client.KubernetesClient import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s.{Config, KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, KubernetesUtils, SparkPod} import org.apache.spark.deploy.k8s.features._ -import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep, RDriverFeatureStep} private[spark] class KubernetesDriverBuilder( provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep = @@ -45,18 +44,10 @@ private[spark] class KubernetesDriverBuilder( provideVolumesStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => MountVolumesFeatureStep) = new MountVolumesFeatureStep(_), - providePythonStep: ( + provideDriverCommandStep: ( KubernetesConf[KubernetesDriverSpecificConf] - => PythonDriverFeatureStep) = - new PythonDriverFeatureStep(_), - provideRStep: ( - KubernetesConf[KubernetesDriverSpecificConf] - => RDriverFeatureStep) = - new RDriverFeatureStep(_), - provideJavaStep: ( - KubernetesConf[KubernetesDriverSpecificConf] - => JavaDriverFeatureStep) = - new JavaDriverFeatureStep(_), + => DriverCommandFeatureStep) = + new DriverCommandFeatureStep(_), provideHadoopGlobalStep: ( KubernetesConf[KubernetesDriverSpecificConf] => KerberosConfDriverFeatureStep) = @@ -88,21 +79,14 @@ private[spark] class KubernetesDriverBuilder( Seq(providePodTemplateConfigMapStep(kubernetesConf)) } else Nil - val bindingsStep = kubernetesConf.roleSpecificConf.mainAppResource.map { - case JavaMainAppResource(_) => - provideJavaStep(kubernetesConf) - case PythonMainAppResource(_) => - providePythonStep(kubernetesConf) - case RMainAppResource(_) => - provideRStep(kubernetesConf)} - .getOrElse(provideJavaStep(kubernetesConf)) + val driverCommandStep = provideDriverCommandStep(kubernetesConf) val maybeHadoopConfigStep = kubernetesConf.hadoopConfSpec.map { _ => provideHadoopGlobalStep(kubernetesConf)} val allFeatures: Seq[KubernetesFeatureConfigStep] = - (baseFeatures :+ bindingsStep) ++ + baseFeatures ++ Seq(driverCommandStep) ++ secretFeature ++ envSecretFeature ++ volumesFeature ++ maybeHadoopConfigStep.toSeq ++ podTemplateFeature diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala index dd5a4549743df..a2e01fa2d9a0e 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala @@ -20,7 +20,8 @@ private[spark] sealed trait MainAppResource private[spark] sealed trait NonJVMResource -private[spark] case class JavaMainAppResource(primaryResource: String) extends MainAppResource +private[spark] case class JavaMainAppResource(primaryResource: Option[String]) + extends MainAppResource private[spark] case class PythonMainAppResource(primaryResource: String) extends MainAppResource with NonJVMResource diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala index bb2b94f9976e2..41ca8d186c17b 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala @@ -56,7 +56,7 @@ class KubernetesConfSuite extends SparkFunSuite { APP_NAME, RESOURCE_NAME_PREFIX, APP_ID, - mainAppResource = None, + mainAppResource = JavaMainAppResource(None), MAIN_CLASS, APP_ARGS, maybePyFiles = None, @@ -65,109 +65,10 @@ class KubernetesConfSuite extends SparkFunSuite { assert(conf.sparkConf.getAll.toMap === sparkConf.getAll.toMap) assert(conf.appResourceNamePrefix === RESOURCE_NAME_PREFIX) assert(conf.roleSpecificConf.appName === APP_NAME) - assert(conf.roleSpecificConf.mainAppResource.isEmpty) assert(conf.roleSpecificConf.mainClass === MAIN_CLASS) assert(conf.roleSpecificConf.appArgs === APP_ARGS) } - test("Creating driver conf with and without the main app jar influences spark.jars") { - val sparkConf = new SparkConf(false) - .setJars(Seq("local:///opt/spark/jar1.jar")) - val mainAppJar = Some(JavaMainAppResource("local:///opt/spark/main.jar")) - val kubernetesConfWithMainJar = KubernetesConf.createDriverConf( - sparkConf, - APP_NAME, - RESOURCE_NAME_PREFIX, - APP_ID, - mainAppJar, - MAIN_CLASS, - APP_ARGS, - maybePyFiles = None, - hadoopConfDir = None) - assert(kubernetesConfWithMainJar.sparkConf.get("spark.jars") - .split(",") - === Array("local:///opt/spark/jar1.jar", "local:///opt/spark/main.jar")) - val kubernetesConfWithoutMainJar = KubernetesConf.createDriverConf( - sparkConf, - APP_NAME, - RESOURCE_NAME_PREFIX, - APP_ID, - mainAppResource = None, - MAIN_CLASS, - APP_ARGS, - maybePyFiles = None, - hadoopConfDir = None) - assert(kubernetesConfWithoutMainJar.sparkConf.get("spark.jars").split(",") - === Array("local:///opt/spark/jar1.jar")) - assert(kubernetesConfWithoutMainJar.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.1) - } - - test("Creating driver conf with a python primary file") { - val mainResourceFile = "local:///opt/spark/main.py" - val inputPyFiles = Array("local:///opt/spark/example2.py", "local:///example3.py") - val sparkConf = new SparkConf(false) - .setJars(Seq("local:///opt/spark/jar1.jar")) - .set("spark.files", "local:///opt/spark/example4.py") - val mainAppResource = Some(PythonMainAppResource(mainResourceFile)) - val kubernetesConfWithMainResource = KubernetesConf.createDriverConf( - sparkConf, - APP_NAME, - RESOURCE_NAME_PREFIX, - APP_ID, - mainAppResource, - MAIN_CLASS, - APP_ARGS, - Some(inputPyFiles.mkString(",")), - hadoopConfDir = None) - assert(kubernetesConfWithMainResource.sparkConf.get("spark.jars").split(",") - === Array("local:///opt/spark/jar1.jar")) - assert(kubernetesConfWithMainResource.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.4) - assert(kubernetesConfWithMainResource.sparkFiles - === Array("local:///opt/spark/example4.py", mainResourceFile) ++ inputPyFiles) - } - - test("Creating driver conf with a r primary file") { - val mainResourceFile = "local:///opt/spark/main.R" - val sparkConf = new SparkConf(false) - .setJars(Seq("local:///opt/spark/jar1.jar")) - .set("spark.files", "local:///opt/spark/example2.R") - val mainAppResource = Some(RMainAppResource(mainResourceFile)) - val kubernetesConfWithMainResource = KubernetesConf.createDriverConf( - sparkConf, - APP_NAME, - RESOURCE_NAME_PREFIX, - APP_ID, - mainAppResource, - MAIN_CLASS, - APP_ARGS, - maybePyFiles = None, - hadoopConfDir = None) - assert(kubernetesConfWithMainResource.sparkConf.get("spark.jars").split(",") - === Array("local:///opt/spark/jar1.jar")) - assert(kubernetesConfWithMainResource.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.4) - assert(kubernetesConfWithMainResource.sparkFiles - === Array("local:///opt/spark/example2.R", mainResourceFile)) - } - - test("Testing explicit setting of memory overhead on non-JVM tasks") { - val sparkConf = new SparkConf(false) - .set(MEMORY_OVERHEAD_FACTOR, 0.3) - - val mainResourceFile = "local:///opt/spark/main.py" - val mainAppResource = Some(PythonMainAppResource(mainResourceFile)) - val conf = KubernetesConf.createDriverConf( - sparkConf, - APP_NAME, - RESOURCE_NAME_PREFIX, - APP_ID, - mainAppResource, - MAIN_CLASS, - APP_ARGS, - maybePyFiles = None, - hadoopConfDir = None) - assert(conf.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.3) - } - test("Resolve driver labels, annotations, secret mount paths, envs, and memory overhead") { val sparkConf = new SparkConf(false) .set(MEMORY_OVERHEAD_FACTOR, 0.3) @@ -192,7 +93,7 @@ class KubernetesConfSuite extends SparkFunSuite { APP_NAME, RESOURCE_NAME_PREFIX, APP_ID, - mainAppResource = None, + mainAppResource = JavaMainAppResource(None), MAIN_CLASS, APP_ARGS, maybePyFiles = None, diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala index 5c6bcc72158be..1e7dfbeffdb24 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -24,8 +24,8 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.JavaMainAppResource -import org.apache.spark.deploy.k8s.submit.PythonMainAppResource +import org.apache.spark.deploy.k8s.submit._ +import org.apache.spark.internal.config._ import org.apache.spark.ui.SparkUI class BasicDriverFeatureStepSuite extends SparkFunSuite { @@ -52,7 +52,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { new LocalObjectReferenceBuilder().withName(secret).build() } private val emptyDriverSpecificConf = KubernetesDriverSpecificConf( - None, + JavaMainAppResource(None), APP_NAME, MAIN_CLASS, APP_ARGS) @@ -62,8 +62,8 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod") .set("spark.driver.cores", "2") .set(KUBERNETES_DRIVER_LIMIT_CORES, "4") - .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "256M") - .set(org.apache.spark.internal.config.DRIVER_MEMORY_OVERHEAD, 200L) + .set(DRIVER_MEMORY.key, "256M") + .set(DRIVER_MEMORY_OVERHEAD, 200L) .set(CONTAINER_IMAGE, "spark-driver:latest") .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) val kubernetesConf = KubernetesConf( @@ -77,7 +77,6 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { Map.empty, DRIVER_ENVS, Nil, - Seq.empty[String], hadoopConfSpec = None) val featureStep = new BasicDriverFeatureStep(kubernetesConf) @@ -130,21 +129,22 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", "spark.app.id" -> APP_ID, KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, - "spark.kubernetes.submitInDriver" -> "true") + "spark.kubernetes.submitInDriver" -> "true", + MEMORY_OVERHEAD_FACTOR.key -> MEMORY_OVERHEAD_FACTOR.defaultValue.get.toString) assert(featureStep.getAdditionalPodSystemProperties() === expectedSparkConf) } test("Check appropriate entrypoint rerouting for various bindings") { val javaSparkConf = new SparkConf() - .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "4g") + .set(DRIVER_MEMORY.key, "4g") .set(CONTAINER_IMAGE, "spark-driver:latest") val pythonSparkConf = new SparkConf() - .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "4g") + .set(DRIVER_MEMORY.key, "4g") .set(CONTAINER_IMAGE, "spark-driver-py:latest") val javaKubernetesConf = KubernetesConf( javaSparkConf, KubernetesDriverSpecificConf( - Some(JavaMainAppResource("")), + JavaMainAppResource(None), APP_NAME, PY_MAIN_CLASS, APP_ARGS), @@ -156,13 +156,12 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { Map.empty, DRIVER_ENVS, Nil, - Seq.empty[String], hadoopConfSpec = None) val pythonKubernetesConf = KubernetesConf( pythonSparkConf, KubernetesDriverSpecificConf( - Some(PythonMainAppResource("")), + PythonMainAppResource(""), APP_NAME, PY_MAIN_CLASS, APP_ARGS), @@ -174,7 +173,6 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { Map.empty, DRIVER_ENVS, Nil, - Seq.empty[String], hadoopConfSpec = None) val javaFeatureStep = new BasicDriverFeatureStep(javaKubernetesConf) val pythonFeatureStep = new BasicDriverFeatureStep(pythonKubernetesConf) @@ -204,7 +202,6 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { Map.empty, DRIVER_ENVS, Nil, - allFiles, hadoopConfSpec = None) val step = new BasicDriverFeatureStep(kubernetesConf) @@ -215,10 +212,52 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, "spark.kubernetes.submitInDriver" -> "true", "spark.jars" -> "/opt/spark/jar1.jar,hdfs:///opt/spark/jar2.jar", - "spark.files" -> "https://localhost:9000/file1.txt,/opt/spark/file2.txt") + "spark.files" -> "https://localhost:9000/file1.txt,/opt/spark/file2.txt", + MEMORY_OVERHEAD_FACTOR.key -> MEMORY_OVERHEAD_FACTOR.defaultValue.get.toString) assert(additionalProperties === expectedSparkConf) } + // Memory overhead tests. Tuples are: + // test name, main resource, overhead factor, expected factor + Seq( + ("java", JavaMainAppResource(None), None, MEMORY_OVERHEAD_FACTOR.defaultValue.get), + ("python default", PythonMainAppResource(null), None, NON_JVM_MEMORY_OVERHEAD_FACTOR), + ("python w/ override", PythonMainAppResource(null), Some(0.9d), 0.9d), + ("r default", RMainAppResource(null), None, NON_JVM_MEMORY_OVERHEAD_FACTOR) + ).foreach { case (name, resource, factor, expectedFactor) => + test(s"memory overhead factor: $name") { + // Choose a driver memory where the default memory overhead is > MEMORY_OVERHEAD_MIN_MIB + val driverMem = MEMORY_OVERHEAD_MIN_MIB / MEMORY_OVERHEAD_FACTOR.defaultValue.get * 2 + + // main app resource, overhead factor + val sparkConf = new SparkConf(false) + .set(CONTAINER_IMAGE, "spark-driver:latest") + .set(DRIVER_MEMORY.key, s"${driverMem.toInt}m") + factor.foreach { value => sparkConf.set(MEMORY_OVERHEAD_FACTOR, value) } + val driverConf = emptyDriverSpecificConf.copy(mainAppResource = resource) + val conf = KubernetesConf( + sparkConf, + driverConf, + RESOURCE_NAME_PREFIX, + APP_ID, + DRIVER_LABELS, + DRIVER_ANNOTATIONS, + Map.empty, + Map.empty, + DRIVER_ENVS, + Nil, + hadoopConfSpec = None) + val step = new BasicDriverFeatureStep(conf) + val pod = step.configurePod(SparkPod.initialPod()) + val mem = pod.container.getResources.getRequests.get("memory").getAmount() + val expected = (driverMem + driverMem * expectedFactor).toInt + assert(mem === s"${expected}Mi") + + val systemProperties = step.getAdditionalPodSystemProperties() + assert(systemProperties(MEMORY_OVERHEAD_FACTOR.key) === expectedFactor.toString) + } + } + def containerPort(name: String, portNumber: Int): ContainerPort = new ContainerPortBuilder() .withName(name) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index 41f34bd45cd5b..e9a16aab6ccc2 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -91,7 +91,6 @@ class BasicExecutorFeatureStepSuite Map.empty, Map.empty, Nil, - Seq.empty[String], hadoopConfSpec = None)) val executor = step.configurePod(SparkPod.initialPod()) @@ -132,7 +131,6 @@ class BasicExecutorFeatureStepSuite Map.empty, Map.empty, Nil, - Seq.empty[String], hadoopConfSpec = None)) assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63) } @@ -154,7 +152,6 @@ class BasicExecutorFeatureStepSuite Map.empty, Map("qux" -> "quux"), Nil, - Seq.empty[String], hadoopConfSpec = None)) val executor = step.configurePod(SparkPod.initialPod()) @@ -182,7 +179,6 @@ class BasicExecutorFeatureStepSuite Map.empty, Map.empty, Nil, - Seq.empty[String], hadoopConfSpec = None)) val executor = step.configurePod(SparkPod.initialPod()) // This is checking that basic executor + executorMemory = 1408 + 42 = 1450 diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStepSuite.scala new file mode 100644 index 0000000000000..30672952aaf6f --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStepSuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit._ +import org.apache.spark.util.Utils + +class DriverCommandFeatureStepSuite extends SparkFunSuite { + + private val MAIN_CLASS = "mainClass" + + test("java resource") { + val mainResource = "local:///main.jar" + val spec = applyFeatureStep( + JavaMainAppResource(Some(mainResource)), + appArgs = Array("5", "7")) + assert(spec.pod.container.getArgs.asScala === List( + "driver", + "--properties-file", SPARK_CONF_PATH, + "--class", MAIN_CLASS, + "spark-internal", "5", "7")) + + val jars = Utils.stringToSeq(spec.systemProperties("spark.jars")) + assert(jars.toSet === Set(mainResource)) + } + + test("python resource with no extra files") { + val mainResource = "local:///main.py" + val sparkConf = new SparkConf(false) + .set(PYSPARK_MAJOR_PYTHON_VERSION, "3") + + val spec = applyFeatureStep( + PythonMainAppResource(mainResource), + conf = sparkConf) + assert(spec.pod.container.getArgs.asScala === List( + "driver", + "--properties-file", SPARK_CONF_PATH, + "--class", MAIN_CLASS, + "/main.py")) + val envs = spec.pod.container.getEnv.asScala + .map { env => (env.getName, env.getValue) } + .toMap + assert(envs(ENV_PYSPARK_MAJOR_PYTHON_VERSION) === "3") + + val files = Utils.stringToSeq(spec.systemProperties("spark.files")) + assert(files.toSet === Set(mainResource)) + } + + test("python resource with extra files") { + val expectedMainResource = "/main.py" + val expectedPySparkFiles = "/example2.py:/example3.py" + val filesInConf = Set("local:///example.py") + + val mainResource = s"local://$expectedMainResource" + val pyFiles = Seq("local:///example2.py", "local:///example3.py") + + val sparkConf = new SparkConf(false) + .set("spark.files", filesInConf.mkString(",")) + .set(PYSPARK_MAJOR_PYTHON_VERSION, "2") + val spec = applyFeatureStep( + PythonMainAppResource(mainResource), + conf = sparkConf, + appArgs = Array("5", "7", "9"), + pyFiles = pyFiles) + + assert(spec.pod.container.getArgs.asScala === List( + "driver", + "--properties-file", SPARK_CONF_PATH, + "--class", MAIN_CLASS, + "/main.py", "5", "7", "9")) + + val envs = spec.pod.container.getEnv.asScala + .map { env => (env.getName, env.getValue) } + .toMap + val expected = Map( + ENV_PYSPARK_FILES -> expectedPySparkFiles, + ENV_PYSPARK_MAJOR_PYTHON_VERSION -> "2") + assert(envs === expected) + + val files = Utils.stringToSeq(spec.systemProperties("spark.files")) + assert(files.toSet === pyFiles.toSet ++ filesInConf ++ Set(mainResource)) + } + + test("R resource") { + val expectedMainResource = "/main.R" + val mainResource = s"local://$expectedMainResource" + + val spec = applyFeatureStep( + RMainAppResource(mainResource), + appArgs = Array("5", "7", "9")) + + assert(spec.pod.container.getArgs.asScala === List( + "driver", + "--properties-file", SPARK_CONF_PATH, + "--class", MAIN_CLASS, + "/main.R", "5", "7", "9")) + } + + private def applyFeatureStep( + resource: MainAppResource, + conf: SparkConf = new SparkConf(false), + appArgs: Array[String] = Array(), + pyFiles: Seq[String] = Nil): KubernetesDriverSpec = { + val driverConf = new KubernetesDriverSpecificConf( + resource, MAIN_CLASS, "appName", appArgs, pyFiles = pyFiles) + val kubernetesConf = KubernetesConf( + conf, + driverConf, + "resource-prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Nil, + hadoopConfSpec = None) + val step = new DriverCommandFeatureStep(kubernetesConf) + val pod = step.configurePod(SparkPod.initialPod()) + val props = step.getAdditionalPodSystemProperties() + KubernetesDriverSpec(pod, Nil, props) + } + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala index 8675ceb48cf6d..36c6616a87b0a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala @@ -62,7 +62,6 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Nil, - Seq.empty[String], hadoopConfSpec = None) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) @@ -95,7 +94,6 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Nil, - Seq.empty[String], hadoopConfSpec = None) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) @@ -135,7 +133,6 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Nil, - Seq.empty[String], hadoopConfSpec = None) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala index 5c3e801501513..3c46667c3042e 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.JavaMainAppResource import org.apache.spark.util.Clock class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { @@ -59,7 +60,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { KubernetesConf( sparkConf, KubernetesDriverSpecificConf( - None, "main", "app", Seq.empty), + JavaMainAppResource(None), "main", "app", Seq.empty), SHORT_RESOURCE_NAME_PREFIX, "app-id", DRIVER_LABELS, @@ -68,7 +69,6 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Nil, - Seq.empty[String], hadoopConfSpec = None)) assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod()) assert(configurationStep.getAdditionalKubernetesResources().size === 1) @@ -92,7 +92,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080) .set(KUBERNETES_NAMESPACE, "my-namespace"), KubernetesDriverSpecificConf( - None, "main", "app", Seq.empty), + JavaMainAppResource(None), "main", "app", Seq.empty), SHORT_RESOURCE_NAME_PREFIX, "app-id", DRIVER_LABELS, @@ -101,7 +101,6 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Nil, - Seq.empty[String], hadoopConfSpec = None)) val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX @@ -115,7 +114,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { KubernetesConf( sparkConf, KubernetesDriverSpecificConf( - None, "main", "app", Seq.empty), + JavaMainAppResource(None), "main", "app", Seq.empty), SHORT_RESOURCE_NAME_PREFIX, "app-id", DRIVER_LABELS, @@ -124,7 +123,6 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Nil, - Seq.empty[String], hadoopConfSpec = None)) val resolvedService = configurationStep .getAdditionalKubernetesResources() @@ -147,7 +145,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { KubernetesConf( sparkConf.set(KUBERNETES_NAMESPACE, "my-namespace"), KubernetesDriverSpecificConf( - None, "main", "app", Seq.empty), + JavaMainAppResource(None), "main", "app", Seq.empty), LONG_RESOURCE_NAME_PREFIX, "app-id", DRIVER_LABELS, @@ -156,7 +154,6 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Nil, - Seq.empty[String], hadoopConfSpec = None), clock) val driverService = configurationStep @@ -176,7 +173,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { KubernetesConf( sparkConf.set(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS, "host"), KubernetesDriverSpecificConf( - None, "main", "app", Seq.empty), + JavaMainAppResource(None), "main", "app", Seq.empty), LONG_RESOURCE_NAME_PREFIX, "app-id", DRIVER_LABELS, @@ -185,7 +182,6 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Nil, - Seq.empty[String], hadoopConfSpec = None), clock) fail("The driver bind address should not be allowed.") @@ -203,7 +199,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { KubernetesConf( sparkConf, KubernetesDriverSpecificConf( - None, "main", "app", Seq.empty), + JavaMainAppResource(None), "main", "app", Seq.empty), LONG_RESOURCE_NAME_PREFIX, "app-id", DRIVER_LABELS, @@ -212,7 +208,6 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Nil, - Seq.empty[String], hadoopConfSpec = None), clock) fail("The driver host address should not be allowed.") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala index 43796b77efdc7..3d253079c3ce7 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala @@ -46,7 +46,6 @@ class EnvSecretsFeatureStepSuite extends SparkFunSuite{ envVarsToKeys, Map.empty, Nil, - Seq.empty[String], hadoopConfSpec = None) val step = new EnvSecretsFeatureStep(kubernetesConf) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala index 3a4e60547d7f2..894d824999aac 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.submit.JavaMainAppResource class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { private val defaultLocalDir = "/var/data/default-local-dir" @@ -36,7 +37,7 @@ class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { kubernetesConf = KubernetesConf( sparkConf, KubernetesDriverSpecificConf( - None, + JavaMainAppResource(None), "app-name", "main", Seq.empty), @@ -48,7 +49,6 @@ class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Nil, - Seq.empty[String], hadoopConfSpec = None) } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala index 18e3d773f690d..1555f6a9c6527 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala @@ -44,7 +44,6 @@ class MountSecretsFeatureStepSuite extends SparkFunSuite { Map.empty, Map.empty, Nil, - Seq.empty[String], hadoopConfSpec = None) val step = new MountSecretsFeatureStep(kubernetesConf) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala index 0d0a5fb951f64..2a957460ca8e0 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala @@ -18,13 +18,14 @@ package org.apache.spark.deploy.k8s.features import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.submit.JavaMainAppResource class MountVolumesFeatureStepSuite extends SparkFunSuite { private val sparkConf = new SparkConf(false) private val emptyKubernetesConf = KubernetesConf( sparkConf = sparkConf, roleSpecificConf = KubernetesDriverSpecificConf( - None, + JavaMainAppResource(None), "app-name", "main", Seq.empty), @@ -36,7 +37,6 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { roleSecretEnvNamesToKeyRefs = Map.empty, roleEnvs = Map.empty, roleVolumes = Nil, - sparkFiles = Nil, hadoopConfSpec = None) test("Mounts hostPath volumes") { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala index d7bbbd121af72..370948c9502e4 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.submit.JavaMainAppResource class PodTemplateConfigMapStepSuite extends SparkFunSuite with BeforeAndAfter { private var sparkConf: SparkConf = _ @@ -36,7 +37,7 @@ class PodTemplateConfigMapStepSuite extends SparkFunSuite with BeforeAndAfter { kubernetesConf = KubernetesConf( sparkConf, KubernetesDriverSpecificConf( - None, + JavaMainAppResource(None), "app-name", "main", Seq.empty), @@ -48,7 +49,6 @@ class PodTemplateConfigMapStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Nil, - Seq.empty[String], Option.empty) templateFile = Files.createTempFile("pod-template", "yml").toFile templateFile.deleteOnExit() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala deleted file mode 100644 index 9172e0c3dc408..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.features.bindings - -import scala.collection.JavaConverters._ - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.PythonMainAppResource - -class JavaDriverFeatureStepSuite extends SparkFunSuite { - - test("Java Step modifies container correctly") { - val baseDriverPod = SparkPod.initialPod() - val sparkConf = new SparkConf(false) - val kubernetesConf = KubernetesConf( - sparkConf, - KubernetesDriverSpecificConf( - Some(PythonMainAppResource("local:///main.jar")), - "test-class", - "java-runner", - Seq("5 7")), - appResourceNamePrefix = "", - appId = "", - roleLabels = Map.empty, - roleAnnotations = Map.empty, - roleSecretNamesToMountPaths = Map.empty, - roleSecretEnvNamesToKeyRefs = Map.empty, - roleEnvs = Map.empty, - roleVolumes = Nil, - sparkFiles = Seq.empty[String], - hadoopConfSpec = None) - - val step = new JavaDriverFeatureStep(kubernetesConf) - val driverPod = step.configurePod(baseDriverPod).pod - val driverContainerwithJavaStep = step.configurePod(baseDriverPod).container - assert(driverContainerwithJavaStep.getArgs.size === 7) - val args = driverContainerwithJavaStep - .getArgs.asScala - assert(args === List( - "driver", - "--properties-file", SPARK_CONF_PATH, - "--class", "test-class", - "spark-internal", "5 7")) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala deleted file mode 100644 index 2bcc6465b79d6..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.features.bindings - -import scala.collection.JavaConverters._ - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.PythonMainAppResource - -class PythonDriverFeatureStepSuite extends SparkFunSuite { - - test("Python Step modifies container correctly") { - val expectedMainResource = "/main.py" - val mainResource = "local:///main.py" - val pyFiles = Seq("local:///example2.py", "local:///example3.py") - val expectedPySparkFiles = - "/example2.py:/example3.py" - val baseDriverPod = SparkPod.initialPod() - val sparkConf = new SparkConf(false) - .set(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE, mainResource) - .set(KUBERNETES_PYSPARK_PY_FILES, pyFiles.mkString(",")) - .set("spark.files", "local:///example.py") - .set(PYSPARK_MAJOR_PYTHON_VERSION, "2") - val kubernetesConf = KubernetesConf( - sparkConf, - KubernetesDriverSpecificConf( - Some(PythonMainAppResource("local:///main.py")), - "test-app", - "python-runner", - Seq("5", "7", "9")), - appResourceNamePrefix = "", - appId = "", - roleLabels = Map.empty, - roleAnnotations = Map.empty, - roleSecretNamesToMountPaths = Map.empty, - roleSecretEnvNamesToKeyRefs = Map.empty, - roleEnvs = Map.empty, - roleVolumes = Nil, - sparkFiles = Seq.empty[String], - hadoopConfSpec = None) - - val step = new PythonDriverFeatureStep(kubernetesConf) - val driverPod = step.configurePod(baseDriverPod).pod - val driverContainerwithPySpark = step.configurePod(baseDriverPod).container - assert(driverContainerwithPySpark.getEnv.size === 4) - val envs = driverContainerwithPySpark - .getEnv - .asScala - .map(env => (env.getName, env.getValue)) - .toMap - assert(envs(ENV_PYSPARK_PRIMARY) === expectedMainResource) - assert(envs(ENV_PYSPARK_FILES) === expectedPySparkFiles) - assert(envs(ENV_PYSPARK_ARGS) === "5 7 9") - assert(envs(ENV_PYSPARK_MAJOR_PYTHON_VERSION) === "2") - } - test("Python Step testing empty pyfiles") { - val mainResource = "local:///main.py" - val baseDriverPod = SparkPod.initialPod() - val sparkConf = new SparkConf(false) - .set(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE, mainResource) - .set(PYSPARK_MAJOR_PYTHON_VERSION, "3") - val kubernetesConf = KubernetesConf( - sparkConf, - KubernetesDriverSpecificConf( - Some(PythonMainAppResource("local:///main.py")), - "test-class-py", - "python-runner", - Seq.empty[String]), - appResourceNamePrefix = "", - appId = "", - roleLabels = Map.empty, - roleAnnotations = Map.empty, - roleSecretNamesToMountPaths = Map.empty, - roleSecretEnvNamesToKeyRefs = Map.empty, - roleEnvs = Map.empty, - roleVolumes = Nil, - sparkFiles = Seq.empty[String], - hadoopConfSpec = None) - val step = new PythonDriverFeatureStep(kubernetesConf) - val driverContainerwithPySpark = step.configurePod(baseDriverPod).container - val args = driverContainerwithPySpark - .getArgs.asScala - assert(driverContainerwithPySpark.getArgs.size === 5) - assert(args === List( - "driver-py", - "--properties-file", SPARK_CONF_PATH, - "--class", "test-class-py")) - val envs = driverContainerwithPySpark - .getEnv - .asScala - .map(env => (env.getName, env.getValue)) - .toMap - assert(envs(ENV_PYSPARK_MAJOR_PYTHON_VERSION) === "3") - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStepSuite.scala deleted file mode 100644 index 17af6011a17d5..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/RDriverFeatureStepSuite.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.features.bindings - -import scala.collection.JavaConverters._ - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.RMainAppResource - -class RDriverFeatureStepSuite extends SparkFunSuite { - - test("R Step modifies container correctly") { - val expectedMainResource = "/main.R" - val mainResource = "local:///main.R" - val baseDriverPod = SparkPod.initialPod() - val sparkConf = new SparkConf(false) - .set(KUBERNETES_R_MAIN_APP_RESOURCE, mainResource) - val kubernetesConf = KubernetesConf( - sparkConf, - KubernetesDriverSpecificConf( - Some(RMainAppResource(mainResource)), - "test-app", - "r-runner", - Seq("5", "7", "9")), - appResourceNamePrefix = "", - appId = "", - roleLabels = Map.empty, - roleAnnotations = Map.empty, - roleSecretNamesToMountPaths = Map.empty, - roleSecretEnvNamesToKeyRefs = Map.empty, - roleEnvs = Map.empty, - roleVolumes = Seq.empty, - sparkFiles = Seq.empty[String], - hadoopConfSpec = None) - - val step = new RDriverFeatureStep(kubernetesConf) - val driverContainerwithR = step.configurePod(baseDriverPod).container - assert(driverContainerwithR.getEnv.size === 2) - val envs = driverContainerwithR - .getEnv - .asScala - .map(env => (env.getName, env.getValue)) - .toMap - assert(envs(ENV_R_PRIMARY) === expectedMainResource) - assert(envs(ENV_R_ARGS) === "5 7 9") - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index ae13df39b7a76..81e3822389f30 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.deploy.k8s.submit.JavaMainAppResource class ClientSuite extends SparkFunSuite with BeforeAndAfter { @@ -133,7 +134,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { sparkConf = new SparkConf(false) kubernetesConf = KubernetesConf[KubernetesDriverSpecificConf]( sparkConf, - KubernetesDriverSpecificConf(None, MAIN_CLASS, APP_NAME, APP_ARGS), + KubernetesDriverSpecificConf(JavaMainAppResource(None), MAIN_CLASS, APP_NAME, APP_ARGS), KUBERNETES_RESOURCE_PREFIX, APP_ID, Map.empty, @@ -142,7 +143,6 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Nil, - Seq.empty[String], hadoopConfSpec = None) when(driverBuilder.buildFromFeatures(kubernetesConf)).thenReturn(BUILT_KUBERNETES_SPEC) when(kubernetesClient.pods()).thenReturn(podOperations) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index 84968c3523fc0..fe900fda6e545 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -24,7 +24,6 @@ import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Config.{CONTAINER_IMAGE, KUBERNETES_DRIVER_PODTEMPLATE_FILE, KUBERNETES_EXECUTOR_PODTEMPLATE_FILE} import org.apache.spark.deploy.k8s.features._ -import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep, RDriverFeatureStep} class KubernetesDriverBuilderSuite extends SparkFunSuite { @@ -33,9 +32,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val SERVICE_STEP_TYPE = "service" private val LOCAL_DIRS_STEP_TYPE = "local-dirs" private val SECRETS_STEP_TYPE = "mount-secrets" - private val JAVA_STEP_TYPE = "java-bindings" - private val R_STEP_TYPE = "r-bindings" - private val PYSPARK_STEP_TYPE = "pyspark-bindings" + private val DRIVER_CMD_STEP_TYPE = "driver-command" private val ENV_SECRETS_STEP_TYPE = "env-secrets" private val HADOOP_GLOBAL_STEP_TYPE = "hadoop-global" private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes" @@ -56,14 +53,8 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) - private val javaStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - JAVA_STEP_TYPE, classOf[JavaDriverFeatureStep]) - - private val pythonStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - PYSPARK_STEP_TYPE, classOf[PythonDriverFeatureStep]) - - private val rStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - R_STEP_TYPE, classOf[RDriverFeatureStep]) + private val driverCommandStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + DRIVER_CMD_STEP_TYPE, classOf[DriverCommandFeatureStep]) private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) @@ -87,9 +78,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { _ => envSecretsStep, _ => localDirsStep, _ => mountVolumesStep, - _ => pythonStep, - _ => rStep, - _ => javaStep, + _ => driverCommandStep, _ => hadoopGlobalStep, _ => templateVolumeStep) @@ -97,7 +86,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { val conf = KubernetesConf( new SparkConf(false), KubernetesDriverSpecificConf( - Some(JavaMainAppResource("example.jar")), + JavaMainAppResource(Some("example.jar")), "test-app", "main", Seq.empty), @@ -109,7 +98,6 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Nil, - Seq.empty[String], hadoopConfSpec = None) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -117,14 +105,14 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { CREDENTIALS_STEP_TYPE, SERVICE_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, - JAVA_STEP_TYPE) + DRIVER_CMD_STEP_TYPE) } test("Apply secrets step if secrets are present.") { val conf = KubernetesConf( new SparkConf(false), KubernetesDriverSpecificConf( - None, + JavaMainAppResource(None), "test-app", "main", Seq.empty), @@ -136,7 +124,6 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map("EnvName" -> "SecretName:secretKey"), Map.empty, Nil, - Seq.empty[String], hadoopConfSpec = None) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -146,61 +133,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { LOCAL_DIRS_STEP_TYPE, SECRETS_STEP_TYPE, ENV_SECRETS_STEP_TYPE, - JAVA_STEP_TYPE) - } - - test("Apply Java step if main resource is none.") { - val conf = KubernetesConf( - new SparkConf(false), - KubernetesDriverSpecificConf( - None, - "test-app", - "main", - Seq.empty), - "prefix", - "appId", - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - Seq.empty[String], - hadoopConfSpec = None) - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf), - BASIC_STEP_TYPE, - CREDENTIALS_STEP_TYPE, - SERVICE_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, - JAVA_STEP_TYPE) - } - - test("Apply Python step if main resource is python.") { - val conf = KubernetesConf( - new SparkConf(false), - KubernetesDriverSpecificConf( - Some(PythonMainAppResource("example.py")), - "test-app", - "main", - Seq.empty), - "prefix", - "appId", - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - Seq.empty[String], - hadoopConfSpec = None) - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf), - BASIC_STEP_TYPE, - CREDENTIALS_STEP_TYPE, - SERVICE_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, - PYSPARK_STEP_TYPE) + DRIVER_CMD_STEP_TYPE) } test("Apply volumes step if mounts are present.") { @@ -212,7 +145,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { val conf = KubernetesConf( new SparkConf(false), KubernetesDriverSpecificConf( - None, + JavaMainAppResource(None), "test-app", "main", Seq.empty), @@ -224,7 +157,6 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, volumeSpec :: Nil, - Seq.empty[String], hadoopConfSpec = None) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -233,34 +165,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { SERVICE_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, MOUNT_VOLUMES_STEP_TYPE, - JAVA_STEP_TYPE) - } - - test("Apply R step if main resource is R.") { - val conf = KubernetesConf( - new SparkConf(false), - KubernetesDriverSpecificConf( - Some(RMainAppResource("example.R")), - "test-app", - "main", - Seq.empty), - "prefix", - "appId", - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - Seq.empty[String], - hadoopConfSpec = None) - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf), - BASIC_STEP_TYPE, - CREDENTIALS_STEP_TYPE, - SERVICE_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, - R_STEP_TYPE) + DRIVER_CMD_STEP_TYPE) } test("Apply template volume step if executor template is present.") { @@ -270,7 +175,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { val conf = KubernetesConf( sparkConf, KubernetesDriverSpecificConf( - Some(JavaMainAppResource("example.jar")), + JavaMainAppResource(Some("example.jar")), "test-app", "main", Seq.empty), @@ -282,7 +187,6 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Nil, - Seq.empty[String], Option.empty) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -290,7 +194,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { CREDENTIALS_STEP_TYPE, SERVICE_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, - JAVA_STEP_TYPE, + DRIVER_CMD_STEP_TYPE, TEMPLATE_VOLUME_STEP_TYPE) } @@ -298,7 +202,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { val conf = KubernetesConf( new SparkConf(false), KubernetesDriverSpecificConf( - None, + JavaMainAppResource(None), "test-app", "main", Seq.empty), @@ -310,7 +214,6 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Nil, - Seq.empty[String], hadoopConfSpec = Some( HadoopConfSpec( Some("/var/hadoop-conf"), @@ -321,7 +224,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { CREDENTIALS_STEP_TYPE, SERVICE_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, - JAVA_STEP_TYPE, + DRIVER_CMD_STEP_TYPE, HADOOP_GLOBAL_STEP_TYPE) } @@ -329,7 +232,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { val conf = KubernetesConf( new SparkConf(false), KubernetesDriverSpecificConf( - None, + JavaMainAppResource(None), "test-app", "main", Seq.empty), @@ -341,7 +244,6 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Nil, - Seq.empty[String], hadoopConfSpec = Some( HadoopConfSpec( None, @@ -352,7 +254,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { CREDENTIALS_STEP_TYPE, SERVICE_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, - JAVA_STEP_TYPE, + DRIVER_CMD_STEP_TYPE, HADOOP_GLOBAL_STEP_TYPE) } @@ -381,7 +283,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { val kubernetesConf = new KubernetesConf( sparkConf, KubernetesDriverSpecificConf( - Some(JavaMainAppResource("example.jar")), + JavaMainAppResource(Some("example.jar")), "test-app", "main", Seq.empty), @@ -393,7 +295,6 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Nil, - Seq.empty[String], Option.empty) val driverSpec = KubernetesDriverBuilder .apply(kubernetesClient, sparkConf) @@ -414,7 +315,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { val kubernetesConf = new KubernetesConf( sparkConf, KubernetesDriverSpecificConf( - Some(JavaMainAppResource("example.jar")), + JavaMainAppResource(Some("example.jar")), "test-app", "main", Seq.empty), @@ -426,7 +327,6 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Nil, - Seq.empty[String], Option.empty) val exception = intercept[SparkException] { KubernetesDriverBuilder diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index fb2509fc1bda5..1fea08c37ccc6 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -76,7 +76,6 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Nil, - Seq.empty[String], None) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE) @@ -95,7 +94,6 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map("secret-name" -> "secret-key"), Map.empty, Nil, - Seq.empty[String], None) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -123,7 +121,6 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, volumeSpec :: Nil, - Seq.empty[String], None) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -152,7 +149,6 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Nil, - Seq.empty[String], Some(HadoopConfSpec(Some("/var/hadoop-conf"), None))) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -180,7 +176,6 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Nil, - Seq.empty[String], Some(HadoopConfSpec(None, Some("pre-defined-onfigMapName")))) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -225,7 +220,6 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Nil, - Seq.empty[String], Option.empty) val sparkPod = KubernetesExecutorBuilder .apply(kubernetesClient, sparkConf) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 4958b7363fee0..2b2a4e4cf6bcc 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -96,22 +96,6 @@ case "$SPARK_K8S_CMD" in "$@" ) ;; - driver-py) - CMD=( - "$SPARK_HOME/bin/spark-submit" - --conf "spark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS" - --deploy-mode client - "$@" $PYSPARK_PRIMARY $PYSPARK_ARGS - ) - ;; - driver-r) - CMD=( - "$SPARK_HOME/bin/spark-submit" - --conf "spark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS" - --deploy-mode client - "$@" $R_PRIMARY $R_ARGS - ) - ;; executor) CMD=( ${JAVA_HOME}/bin/java From ed0c57e10dd619464c44d7af5cc0f358bd7f189d Mon Sep 17 00:00:00 2001 From: Shahid Date: Fri, 2 Nov 2018 17:17:48 -0500 Subject: [PATCH 1985/2461] [SPARK-25861][MINOR][WEBUI] Remove unused refreshInterval parameter from the headerSparkPage method. ## What changes were proposed in this pull request? 'refreshInterval' is not used any where in the headerSparkPage method. So, we don't need to pass the parameter while calling the 'headerSparkPage' method. ## How was this patch tested? Existing tests Closes #22864 from shahidki31/unusedCode. Authored-by: Shahid Signed-off-by: Sean Owen --- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 1 - .../org/apache/spark/sql/execution/ui/AllExecutionsPage.scala | 2 +- .../scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala | 2 +- .../spark/sql/hive/thriftserver/ui/ThriftServerPage.scala | 2 +- .../sql/hive/thriftserver/ui/ThriftServerSessionPage.scala | 2 +- .../scala/org/apache/spark/streaming/ui/StreamingPage.scala | 2 +- 6 files changed, 5 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 732b7528f499e..3aed4647a96f0 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -218,7 +218,6 @@ private[spark] object UIUtils extends Logging { title: String, content: => Seq[Node], activeTab: SparkUITab, - refreshInterval: Option[Int] = None, helpText: Option[String] = None, showVisualization: Boolean = false, useDataTables: Boolean = false): Seq[Node] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index 1a25cd2a49e36..311f805f78832 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -146,7 +146,7 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L - UIUtils.headerSparkPage(request, "SQL", summary ++ content, parent, Some(5000)) + UIUtils.headerSparkPage(request, "SQL", summary ++ content, parent) } private def executionsTable( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index 877176b030f8b..e4c119e6d06c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -84,7 +84,7 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging } UIUtils.headerSparkPage( - request, s"Details for Query $executionId", content, parent, Some(5000)) + request, s"Details for Query $executionId", content, parent) } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 771104ceb8842..27d2c997ca3e8 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -50,7 +50,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" generateSessionStatsTable(request) ++ generateSQLStatsTable(request) } - UIUtils.headerSparkPage(request, "JDBC/ODBC Server", content, parent, Some(5000)) + UIUtils.headerSparkPage(request, "JDBC/ODBC Server", content, parent) } /** Generate basic stats of the thrift server program */ diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 163eb43aabc72..f46eeea941540 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -58,7 +58,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) ++ generateSQLStatsTable(request, sessionStat.sessionId) } - UIUtils.headerSparkPage(request, "JDBC/ODBC Session", content, parent, Some(5000)) + UIUtils.headerSparkPage(request, "JDBC/ODBC Session", content, parent) } /** Generate basic stats of the thrift server program */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 4ce661bc1144e..d16611f412034 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -156,7 +156,7 @@ private[ui] class StreamingPage(parent: StreamingTab) generateStatTable() ++ generateBatchListTables() } - SparkUIUtils.headerSparkPage(request, "Streaming Statistics", content, parent, Some(5000)) + SparkUIUtils.headerSparkPage(request, "Streaming Statistics", content, parent) } /** From 0e318acd0cc3b42e8be9cb2a53cccfdc4a0805f9 Mon Sep 17 00:00:00 2001 From: Yogesh Garg <1059168+yogeshg@users.noreply.github.com> Date: Sat, 3 Nov 2018 14:03:50 +0800 Subject: [PATCH 1986/2461] [SPARK-25901][CORE] Use only one thread in BarrierTaskContext companion object ## What changes were proposed in this pull request? Now we use only one `timer` (and thus a backing thread) in `BarrierTaskContext` companion object, and the objects can add `timerTasks` to that `timer`. ## How was this patch tested? This was tested manually by generating logs and seeing that they look the same as ones before, namely, that is, a partition waiting on another partition for 5seconds generates 4-5 log messages when the frequency of logging is set to 1second. Closes #22912 from yogeshg/thread. Authored-by: Yogesh Garg <1059168+yogeshg@users.noreply.github.com> Signed-off-by: Xingbo Jiang --- .../main/scala/org/apache/spark/BarrierTaskContext.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index 90a5c4130f799..7ce421e5479ee 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -41,14 +41,14 @@ import org.apache.spark.util._ class BarrierTaskContext private[spark] ( taskContext: TaskContext) extends TaskContext with Logging { + import BarrierTaskContext._ + // Find the driver side RPCEndpointRef of the coordinator that handles all the barrier() calls. private val barrierCoordinator: RpcEndpointRef = { val env = SparkEnv.get RpcUtils.makeDriverRef("barrierSync", env.conf, env.rpcEnv) } - private val timer = new Timer("Barrier task timer for barrier() calls.") - // Local barrierEpoch that identify a barrier() call from current task, it shall be identical // with the driver side epoch. private var barrierEpoch = 0 @@ -234,4 +234,7 @@ object BarrierTaskContext { @Experimental @Since("2.4.0") def get(): BarrierTaskContext = TaskContext.get().asInstanceOf[BarrierTaskContext] + + private val timer = new Timer("Barrier task timer for barrier() calls.") + } From 42b6c1fb05ead89331791cd27ea7c97ff7fd8e16 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 3 Nov 2018 09:09:39 -0700 Subject: [PATCH 1987/2461] [SPARK-25931][SQL] Benchmarking creation of Jackson parser ## What changes were proposed in this pull request? Added new benchmark which forcibly invokes Jackson parser to check overhead of its creation for short and wide JSON strings. Existing benchmarks do not allow to check that due to an optimisation introduced by #21909 for empty schema pushed down to JSON datasource. The `count()` action passes empty schema as required schema to the datasource, and Jackson parser is not created at all in that case. Besides of new benchmark I also refactored existing benchmarks: - Added `numIters` to control number of iteration in each benchmark - Renamed `JSON per-line parsing` -> `count a short column`, `JSON parsing of wide lines` -> `count a wide column`, and `Count a dataset with 10 columns` -> `Select a subset of 10 columns`. Closes #22920 from MaxGekk/json-benchmark-follow-up. Lead-authored-by: Maxim Gekk Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- sql/core/benchmarks/JSONBenchmark-results.txt | 35 ++-- .../datasources/json/JsonBenchmark.scala | 156 ++++++++++++------ 2 files changed, 131 insertions(+), 60 deletions(-) diff --git a/sql/core/benchmarks/JSONBenchmark-results.txt b/sql/core/benchmarks/JSONBenchmark-results.txt index 99937309a4145..477429430cdd0 100644 --- a/sql/core/benchmarks/JSONBenchmark-results.txt +++ b/sql/core/benchmarks/JSONBenchmark-results.txt @@ -7,31 +7,42 @@ OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz JSON schema inferring: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -No encoding 62946 / 63310 1.6 629.5 1.0X -UTF-8 is set 112814 / 112866 0.9 1128.1 0.6X +No encoding 71832 / 72149 1.4 718.3 1.0X +UTF-8 is set 101700 / 101819 1.0 1017.0 0.7X Preparing data for benchmarking ... OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -JSON per-line parsing: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +count a short column: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -No encoding 16468 / 16553 6.1 164.7 1.0X -UTF-8 is set 16420 / 16441 6.1 164.2 1.0X +No encoding 16501 / 16519 6.1 165.0 1.0X +UTF-8 is set 16477 / 16516 6.1 164.8 1.0X Preparing data for benchmarking ... OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -JSON parsing of wide lines: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +count a wide column: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -No encoding 39789 / 40053 0.3 3978.9 1.0X -UTF-8 is set 39505 / 39584 0.3 3950.5 1.0X +No encoding 39871 / 40242 0.3 3987.1 1.0X +UTF-8 is set 39581 / 39721 0.3 3958.1 1.0X +Preparing data for benchmarking ... +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Select a subset of 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Select 10 columns + count() 16011 / 16033 0.6 1601.1 1.0X +Select 1 column + count() 14350 / 14392 0.7 1435.0 1.1X +count() 3007 / 3034 3.3 300.7 5.3X + +Preparing data for benchmarking ... OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -Count a dataset with 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +creation of JSON parser per line: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Select 10 columns + count() 15997 / 16015 0.6 1599.7 1.0X -Select 1 column + count() 13280 / 13326 0.8 1328.0 1.2X -count() 3006 / 3021 3.3 300.6 5.3X +Short column without encoding 8334 / 8453 1.2 833.4 1.0X +Short column with UTF-8 13627 / 13784 0.7 1362.7 0.6X +Wide column without encoding 155073 / 155351 0.1 15507.3 0.1X +Wide column with UTF-8 212114 / 212263 0.0 21211.4 0.0X diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala index 04f724ec8638f..f50c25ecfc1f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala @@ -39,13 +39,17 @@ import org.apache.spark.sql.types._ object JSONBenchmark extends SqlBasedBenchmark { import spark.implicits._ - def schemaInferring(rowsNum: Int): Unit = { + def prepareDataInfo(benchmark: Benchmark): Unit = { + // scalastyle:off println + benchmark.out.println("Preparing data for benchmarking ...") + // scalastyle:on println + } + + def schemaInferring(rowsNum: Int, numIters: Int): Unit = { val benchmark = new Benchmark("JSON schema inferring", rowsNum, output = output) withTempPath { path => - // scalastyle:off println - benchmark.out.println("Preparing data for benchmarking ...") - // scalastyle:on println + prepareDataInfo(benchmark) spark.sparkContext.range(0, rowsNum, 1) .map(_ => "a") @@ -54,11 +58,11 @@ object JSONBenchmark extends SqlBasedBenchmark { .option("encoding", "UTF-8") .json(path.getAbsolutePath) - benchmark.addCase("No encoding", 3) { _ => + benchmark.addCase("No encoding", numIters) { _ => spark.read.json(path.getAbsolutePath) } - benchmark.addCase("UTF-8 is set", 3) { _ => + benchmark.addCase("UTF-8 is set", numIters) { _ => spark.read .option("encoding", "UTF-8") .json(path.getAbsolutePath) @@ -68,28 +72,29 @@ object JSONBenchmark extends SqlBasedBenchmark { } } - def perlineParsing(rowsNum: Int): Unit = { - val benchmark = new Benchmark("JSON per-line parsing", rowsNum, output = output) + def writeShortColumn(path: String, rowsNum: Int): StructType = { + spark.sparkContext.range(0, rowsNum, 1) + .map(_ => "a") + .toDF("fieldA") + .write.json(path) + new StructType().add("fieldA", StringType) + } - withTempPath { path => - // scalastyle:off println - benchmark.out.println("Preparing data for benchmarking ...") - // scalastyle:on println + def countShortColumn(rowsNum: Int, numIters: Int): Unit = { + val benchmark = new Benchmark("count a short column", rowsNum, output = output) - spark.sparkContext.range(0, rowsNum, 1) - .map(_ => "a") - .toDF("fieldA") - .write.json(path.getAbsolutePath) - val schema = new StructType().add("fieldA", StringType) + withTempPath { path => + prepareDataInfo(benchmark) + val schema = writeShortColumn(path.getAbsolutePath, rowsNum) - benchmark.addCase("No encoding", 3) { _ => + benchmark.addCase("No encoding", numIters) { _ => spark.read .schema(schema) .json(path.getAbsolutePath) .count() } - benchmark.addCase("UTF-8 is set", 3) { _ => + benchmark.addCase("UTF-8 is set", numIters) { _ => spark.read .option("encoding", "UTF-8") .schema(schema) @@ -101,35 +106,36 @@ object JSONBenchmark extends SqlBasedBenchmark { } } - def perlineParsingOfWideColumn(rowsNum: Int): Unit = { - val benchmark = new Benchmark("JSON parsing of wide lines", rowsNum, output = output) + def writeWideColumn(path: String, rowsNum: Int): StructType = { + spark.sparkContext.range(0, rowsNum, 1) + .map { i => + val s = "abcdef0123456789ABCDEF" * 20 + s"""{"a":"$s","b": $i,"c":"$s","d":$i,"e":"$s","f":$i,"x":"$s","y":$i,"z":"$s"}""" + } + .toDF().write.text(path) + new StructType() + .add("a", StringType).add("b", LongType) + .add("c", StringType).add("d", LongType) + .add("e", StringType).add("f", LongType) + .add("x", StringType).add("y", LongType) + .add("z", StringType) + } + + def countWideColumn(rowsNum: Int, numIters: Int): Unit = { + val benchmark = new Benchmark("count a wide column", rowsNum, output = output) withTempPath { path => - // scalastyle:off println - benchmark.out.println("Preparing data for benchmarking ...") - // scalastyle:on println + prepareDataInfo(benchmark) + val schema = writeWideColumn(path.getAbsolutePath, rowsNum) - spark.sparkContext.range(0, rowsNum, 1) - .map { i => - val s = "abcdef0123456789ABCDEF" * 20 - s"""{"a":"$s","b": $i,"c":"$s","d":$i,"e":"$s","f":$i,"x":"$s","y":$i,"z":"$s"}""" - } - .toDF().write.text(path.getAbsolutePath) - val schema = new StructType() - .add("a", StringType).add("b", LongType) - .add("c", StringType).add("d", LongType) - .add("e", StringType).add("f", LongType) - .add("x", StringType).add("y", LongType) - .add("z", StringType) - - benchmark.addCase("No encoding", 3) { _ => + benchmark.addCase("No encoding", numIters) { _ => spark.read .schema(schema) .json(path.getAbsolutePath) .count() } - benchmark.addCase("UTF-8 is set", 3) { _ => + benchmark.addCase("UTF-8 is set", numIters) { _ => spark.read .option("encoding", "UTF-8") .schema(schema) @@ -141,12 +147,14 @@ object JSONBenchmark extends SqlBasedBenchmark { } } - def countBenchmark(rowsNum: Int): Unit = { + def selectSubsetOfColumns(rowsNum: Int, numIters: Int): Unit = { val colsNum = 10 val benchmark = - new Benchmark(s"Count a dataset with $colsNum columns", rowsNum, output = output) + new Benchmark(s"Select a subset of $colsNum columns", rowsNum, output = output) withTempPath { path => + prepareDataInfo(benchmark) + val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) val schema = StructType(fields) val columnNames = schema.fieldNames @@ -158,13 +166,13 @@ object JSONBenchmark extends SqlBasedBenchmark { val ds = spark.read.schema(schema).json(path.getAbsolutePath) - benchmark.addCase(s"Select $colsNum columns + count()", 3) { _ => + benchmark.addCase(s"Select $colsNum columns + count()", numIters) { _ => ds.select("*").filter((_: Row) => true).count() } - benchmark.addCase(s"Select 1 column + count()", 3) { _ => + benchmark.addCase(s"Select 1 column + count()", numIters) { _ => ds.select($"col1").filter((_: Row) => true).count() } - benchmark.addCase(s"count()", 3) { _ => + benchmark.addCase(s"count()", numIters) { _ => ds.count() } @@ -172,12 +180,64 @@ object JSONBenchmark extends SqlBasedBenchmark { } } + def jsonParserCreation(rowsNum: Int, numIters: Int): Unit = { + val benchmark = new Benchmark("creation of JSON parser per line", rowsNum, output = output) + + withTempPath { path => + prepareDataInfo(benchmark) + + val shortColumnPath = path.getAbsolutePath + "/short" + val shortSchema = writeShortColumn(shortColumnPath, rowsNum) + + val wideColumnPath = path.getAbsolutePath + "/wide" + val wideSchema = writeWideColumn(wideColumnPath, rowsNum) + + benchmark.addCase("Short column without encoding", numIters) { _ => + spark.read + .schema(shortSchema) + .json(shortColumnPath) + .filter((_: Row) => true) + .count() + } + + benchmark.addCase("Short column with UTF-8", numIters) { _ => + spark.read + .option("encoding", "UTF-8") + .schema(shortSchema) + .json(shortColumnPath) + .filter((_: Row) => true) + .count() + } + + benchmark.addCase("Wide column without encoding", numIters) { _ => + spark.read + .schema(wideSchema) + .json(wideColumnPath) + .filter((_: Row) => true) + .count() + } + + benchmark.addCase("Wide column with UTF-8", numIters) { _ => + spark.read + .option("encoding", "UTF-8") + .schema(wideSchema) + .json(wideColumnPath) + .filter((_: Row) => true) + .count() + } + + benchmark.run() + } + } + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + val numIters = 3 runBenchmark("Benchmark for performance of JSON parsing") { - schemaInferring(100 * 1000 * 1000) - perlineParsing(100 * 1000 * 1000) - perlineParsingOfWideColumn(10 * 1000 * 1000) - countBenchmark(10 * 1000 * 1000) + schemaInferring(100 * 1000 * 1000, numIters) + countShortColumn(100 * 1000 * 1000, numIters) + countWideColumn(10 * 1000 * 1000, numIters) + selectSubsetOfColumns(10 * 1000 * 1000, numIters) + jsonParserCreation(10 * 1000 * 1000, numIters) } } } From 1a7abf3f453f7d6012d7e842cf05f29f3afbb3bc Mon Sep 17 00:00:00 2001 From: Alex Hagerman Date: Sat, 3 Nov 2018 12:56:59 -0500 Subject: [PATCH 1988/2461] [SPARK-25933][DOCUMENTATION] Fix pstats.Stats() reference in configuration.md ## What changes were proposed in this pull request? Change ptats.Stats() to pstats.Stats() for `spark.python.profile.dump` in configuration.md. ## How was this patch tested? Doc test Closes #22933 from AlexHagerman/doc_fix. Authored-by: Alex Hagerman Signed-off-by: Sean Owen --- docs/configuration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/configuration.md b/docs/configuration.md index 8cb0ed1502126..11ee7a9610602 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -445,7 +445,7 @@ Apart from these, the following properties are also available, and may be useful The directory which is used to dump the profile result before driver exiting. The results will be dumped as separated file for each RDD. They can be loaded - by ptats.Stats(). If this is specified, the profile result will not be displayed + by pstats.Stats(). If this is specified, the profile result will not be displayed automatically. From 39399f40b861f7d8e60d0e25d2f8801343477834 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 4 Nov 2018 14:57:38 +0800 Subject: [PATCH 1989/2461] [SPARK-25638][SQL] Adding new function - to_csv() ## What changes were proposed in this pull request? New functions takes a struct and converts it to a CSV strings using passed CSV options. It accepts the same CSV options as CSV data source does. ## How was this patch tested? Added `CsvExpressionsSuite`, `CsvFunctionsSuite` as well as R, Python and SQL tests similar to tests for `to_json()` Closes #22626 from MaxGekk/to_csv. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 31 +++++++-- R/pkg/R/generics.R | 4 ++ R/pkg/tests/fulltests/test_sparkSQL.R | 5 ++ python/pyspark/sql/functions.py | 22 ++++++ .../catalyst/analysis/FunctionRegistry.scala | 3 +- .../catalyst}/csv/UnivocityGenerator.scala | 9 ++- .../catalyst/expressions/csvExpressions.scala | 67 +++++++++++++++++++ .../expressions/CsvExpressionsSuite.scala | 44 ++++++++++++ .../datasources/csv/CSVFileFormat.scala | 2 +- .../org/apache/spark/sql/functions.scala | 26 +++++++ .../sql-tests/inputs/csv-functions.sql | 6 ++ .../sql-tests/results/csv-functions.sql.out | 36 +++++++++- .../apache/spark/sql/CsvFunctionsSuite.scala | 14 +++- 14 files changed, 258 insertions(+), 12 deletions(-) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/datasources => catalyst/src/main/scala/org/apache/spark/sql/catalyst}/csv/UnivocityGenerator.scala (94%) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index f9f556e69a1fc..9d4f05af75afd 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -380,6 +380,7 @@ exportMethods("%<=>%", "tanh", "toDegrees", "toRadians", + "to_csv", "to_date", "to_json", "to_timestamp", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index d2ca1d6c00bb4..9292363d1ad2f 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -187,6 +187,7 @@ NULL #' \itemize{ #' \item \code{to_json}: it is the column containing the struct, array of the structs, #' the map or array of maps. +#' \item \code{to_csv}: it is the column containing the struct. #' \item \code{from_json}: it is the column containing the JSON string. #' \item \code{from_csv}: it is the column containing the CSV string. #' } @@ -204,11 +205,11 @@ NULL #' also supported for the schema. #' \item \code{from_csv}: a DDL-formatted string #' } -#' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains -#' additional named properties to control how it is converted, accepts the same -#' options as the JSON data source. Additionally \code{to_json} supports the "pretty" -#' option which enables pretty JSON generation. In \code{arrays_zip}, this contains -#' additional Columns of arrays to be merged. +#' @param ... additional argument(s). In \code{to_json}, \code{to_csv} and \code{from_json}, +#' this contains additional named properties to control how it is converted, accepts +#' the same options as the JSON/CSV data source. Additionally \code{to_json} supports +#' the "pretty" option which enables pretty JSON generation. In \code{arrays_zip}, +#' this contains additional Columns of arrays to be merged. #' @name column_collection_functions #' @rdname column_collection_functions #' @family collection functions @@ -1740,6 +1741,26 @@ setMethod("to_json", signature(x = "Column"), column(jc) }) +#' @details +#' \code{to_csv}: Converts a column containing a \code{structType} into a Column of CSV string. +#' Resolving the Column can fail if an unsupported type is encountered. +#' +#' @rdname column_collection_functions +#' @aliases to_csv to_csv,Column-method +#' @examples +#' +#' \dontrun{ +#' # Converts a struct into a CSV string +#' df2 <- sql("SELECT named_struct('date', cast('2000-01-01' as date)) as d") +#' select(df2, to_csv(df2$d, dateFormat = 'dd/MM/yyyy'))} +#' @note to_csv since 3.0.0 +setMethod("to_csv", signature(x = "Column"), + function(x, ...) { + options <- varargsToStrEnv(...) + jc <- callJStatic("org.apache.spark.sql.functions", "to_csv", x@jc, options) + column(jc) + }) + #' @details #' \code{to_timestamp}: Converts the column into a TimestampType. You may optionally specify #' a format according to the rules in: diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 76e17c10843d2..463102c780b52 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1303,6 +1303,10 @@ setGeneric("to_date", function(x, format) { standardGeneric("to_date") }) #' @name NULL setGeneric("to_json", function(x, ...) { standardGeneric("to_json") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("to_csv", function(x, ...) { standardGeneric("to_csv") }) + #' @rdname column_datetime_functions #' @name NULL setGeneric("to_timestamp", function(x, format) { standardGeneric("to_timestamp") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 58e0a54d2aacc..faec387ce4eff 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1689,6 +1689,11 @@ test_that("column functions", { expect_equal(arr$arrcol[[1]][[2]]$name, "Alice") } + # Test to_csv() + df <- sql("SELECT named_struct('name', 'Bob') as people") + j <- collect(select(df, alias(to_csv(df$people), "csv"))) + expect_equal(j[order(j$csv), ][1], "Bob") + # Test create_array() and create_map() df <- as.DataFrame(data.frame( x = c(1.0, 2.0), y = c(-1.0, 3.0), z = c(-2.0, 5.0) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index beb1a065d2803..24824efb47362 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2391,6 +2391,28 @@ def schema_of_csv(csv, options={}): return Column(jc) +@ignore_unicode_prefix +@since(3.0) +def to_csv(col, options={}): + """ + Converts a column containing a :class:`StructType` into a CSV string. + Throws an exception, in the case of an unsupported type. + + :param col: name of column containing a struct. + :param options: options to control converting. accepts the same options as the CSV datasource. + + >>> from pyspark.sql import Row + >>> data = [(1, Row(name='Alice', age=2))] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(to_csv(df.value).alias("csv")).collect() + [Row(csv=u'2,Alice')] + """ + + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.to_csv(_to_java_column(col), options) + return Column(jc) + + @since(1.5) def size(col): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index cf8fb7eea9580..c79f9906d266b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -527,7 +527,8 @@ object FunctionRegistry { // csv expression[CsvToStructs]("from_csv"), - expression[SchemaOfCsv]("schema_of_csv") + expression[SchemaOfCsv]("schema_of_csv"), + expression[StructsToCsv]("to_csv") ) val builtin: SimpleFunctionRegistry = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala similarity index 94% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala index 37d9d9abc8680..1218f9242afeb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala @@ -15,18 +15,17 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.csv +package org.apache.spark.sql.catalyst.csv import java.io.Writer import com.univocity.parsers.csv.CsvWriter import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -private[csv] class UnivocityGenerator( +class UnivocityGenerator( schema: StructType, writer: Writer, options: CSVOptions) { @@ -84,6 +83,10 @@ private[csv] class UnivocityGenerator( printHeader = false } + def writeToString(row: InternalRow): String = { + gen.writeRowToString(convertRow(row): _*) + } + def close(): Unit = gen.close() def flush(): Unit = gen.flush() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index e70296fe31292..74b670ae4b68a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.io.CharArrayWriter + import com.univocity.parsers.csv.CsvParser import org.apache.spark.sql.AnalysisException @@ -174,3 +176,68 @@ case class SchemaOfCsv( override def prettyName: String = "schema_of_csv" } + +/** + * Converts a [[StructType]] to a CSV output string. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(expr[, options]) - Returns a CSV string with a given struct value", + examples = """ + Examples: + > SELECT _FUNC_(named_struct('a', 1, 'b', 2)); + 1,2 + > SELECT _FUNC_(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); + "26/08/2015" + """, + since = "3.0.0") +// scalastyle:on line.size.limit +case class StructsToCsv( + options: Map[String, String], + child: Expression, + timeZoneId: Option[String] = None) + extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + override def nullable: Boolean = true + + def this(options: Map[String, String], child: Expression) = this(options, child, None) + + // Used in `FunctionRegistry` + def this(child: Expression) = this(Map.empty, child, None) + + def this(child: Expression, options: Expression) = + this( + options = ExprUtils.convertToMapData(options), + child = child, + timeZoneId = None) + + @transient + lazy val writer = new CharArrayWriter() + + @transient + lazy val inputSchema: StructType = child.dataType match { + case st: StructType => st + case other => + throw new IllegalArgumentException(s"Unsupported input type ${other.catalogString}") + } + + @transient + lazy val gen = new UnivocityGenerator( + inputSchema, writer, new CSVOptions(options, columnPruning = true, timeZoneId.get)) + + // This converts rows to the CSV output according to the given schema. + @transient + lazy val converter: Any => UTF8String = { + (row: Any) => UTF8String.fromString(gen.writeToString(row.asInstanceOf[InternalRow])) + } + + override def dataType: DataType = StringType + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + + override def nullSafeEval(value: Any): Any = converter(value) + + override def inputTypes: Seq[AbstractDataType] = StructType :: Nil + + override def prettyName: String = "to_csv" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala index 386e0d133dff6..d006197bd5678 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -165,4 +165,48 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P new SchemaOfCsv(Literal.create("1|abc"), Map("delimiter" -> "|")), "struct<_c0:int,_c1:string>") } + + test("to_csv - struct") { + val schema = StructType(StructField("a", IntegerType) :: Nil) + val struct = Literal.create(create_row(1), schema) + checkEvaluation(StructsToCsv(Map.empty, struct, gmtId), "1") + } + + test("to_csv null input column") { + val schema = StructType(StructField("a", IntegerType) :: Nil) + val struct = Literal.create(null, schema) + checkEvaluation( + StructsToCsv(Map.empty, struct, gmtId), + null + ) + } + + test("to_csv with timestamp") { + val schema = StructType(StructField("t", TimestampType) :: Nil) + val c = Calendar.getInstance(DateTimeUtils.TimeZoneGMT) + c.set(2016, 0, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + val struct = Literal.create(create_row(c.getTimeInMillis * 1000L), schema) + + checkEvaluation(StructsToCsv(Map.empty, struct, gmtId), "2016-01-01T00:00:00.000Z") + checkEvaluation( + StructsToCsv(Map.empty, struct, Option("PST")), "2015-12-31T16:00:00.000-08:00") + + checkEvaluation( + StructsToCsv( + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", + DateTimeUtils.TIMEZONE_OPTION -> gmtId.get), + struct, + gmtId), + "2016-01-01T00:00:00" + ) + checkEvaluation( + StructsToCsv( + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", + DateTimeUtils.TIMEZONE_OPTION -> "PST"), + struct, + gmtId), + "2015-12-31T16:00:00" + ) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 954a5a9cdecbb..964b56e706a0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.mapreduce._ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} +import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityGenerator, UnivocityParser} import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f8c4d88cb1f7c..6bb1a490d8c3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3905,6 +3905,32 @@ object functions { withExpr(SchemaOfCsv(csv.expr, options.asScala.toMap)) } + /** + * (Java-specific) Converts a column containing a `StructType` into a CSV string with + * the specified schema. Throws an exception, in the case of an unsupported type. + * + * @param e a column containing a struct. + * @param options options to control how the struct column is converted into a CSV string. + * It accepts the same options and the json data source. + * + * @group collection_funcs + * @since 3.0.0 + */ + def to_csv(e: Column, options: java.util.Map[String, String]): Column = withExpr { + StructsToCsv(options.asScala.toMap, e.expr) + } + + /** + * Converts a column containing a `StructType` into a CSV string with the specified schema. + * Throws an exception, in the case of an unsupported type. + * + * @param e a column containing a struct. + * + * @group collection_funcs + * @since 3.0.0 + */ + def to_csv(e: Column): Column = to_csv(e, Map.empty[String, String].asJava) + // scalastyle:off line.size.limit // scalastyle:off parameter.number diff --git a/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql index 5be6f807931b8..a1a4bc9de3f97 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql @@ -15,3 +15,9 @@ CREATE TEMPORARY VIEW csvTable(csvField, a) AS SELECT * FROM VALUES ('1,abc', 'a SELECT schema_of_csv(csvField) FROM csvTable; -- Clean up DROP VIEW IF EXISTS csvTable; +-- to_csv +select to_csv(named_struct('a', 1, 'b', 2)); +select to_csv(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); +-- Check if errors handled +select to_csv(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE')); +select to_csv(named_struct('a', 1, 'b', 2), map('mode', 1)); diff --git a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out index 677bbd97c549d..03d4bfffa8923 100644 --- a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 13 +-- Number of queries: 17 -- !query 0 @@ -117,3 +117,37 @@ DROP VIEW IF EXISTS csvTable struct<> -- !query 12 output + + +-- !query 13 +select to_csv(named_struct('a', 1, 'b', 2)) +-- !query 13 schema +struct +-- !query 13 output +1,2 + + +-- !query 14 +select to_csv(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')) +-- !query 14 schema +struct +-- !query 14 output +26/08/2015 + + +-- !query 15 +select to_csv(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE')) +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.AnalysisException +Must use a map() function for options;; line 1 pos 7 + + +-- !query 16 +select to_csv(named_struct('a', 1, 'b', 2), map('mode', 1)) +-- !query 16 schema +struct<> +-- !query 16 output +org.apache.spark.sql.AnalysisException +A type of keys and values in map() must be string, but got map;; line 1 pos 7 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index 9395f050b41ed..eb6b248e895f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -45,7 +45,6 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext { Row(Row(java.sql.Timestamp.valueOf("2015-08-26 18:00:00.0")))) } - test("checking the columnNameOfCorruptRecord option") { val columnNameOfCorruptRecord = "_unparsed" val df = Seq("0,2013-111-11 12:13:14", "1,1983-08-04").toDS() @@ -74,4 +73,17 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext { .select(schema_of_csv(lit("0.1 1"), Map("sep" -> " ").asJava)) checkAnswer(df, Seq(Row("struct<_c0:double,_c1:int>"))) } + + test("to_csv - struct") { + val df = Seq(Tuple1(Tuple1(1))).toDF("a") + + checkAnswer(df.select(to_csv($"a")), Row("1") :: Nil) + } + + test("to_csv with option") { + val df = Seq(Tuple1(Tuple1(java.sql.Timestamp.valueOf("2015-08-26 18:00:00.0")))).toDF("a") + val options = Map("timestampFormat" -> "dd/MM/yyyy HH:mm").asJava + + checkAnswer(df.select(to_csv($"a", options)), Row("26/08/2015 18:00") :: Nil) + } } From 463a6766876942e90f10d1ce2d1e36a8284bfbc2 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 4 Nov 2018 14:59:33 +0800 Subject: [PATCH 1990/2461] [INFRA] Close stale PRs Closes https://github.com/apache/spark/pull/22859 Closes https://github.com/apache/spark/pull/22849 Closes https://github.com/apache/spark/pull/22591 Closes https://github.com/apache/spark/pull/22322 Closes https://github.com/apache/spark/pull/22312 Closes https://github.com/apache/spark/pull/19590 Closes #22934 from wangyum/CloseStalePRs. Authored-by: Yuming Wang Signed-off-by: hyukjinkwon From 6c9e5ac9de3d0ae5ea86b768608b42b5feb46df4 Mon Sep 17 00:00:00 2001 From: liuxian Date: Mon, 5 Nov 2018 01:55:13 +0900 Subject: [PATCH 1991/2461] [SPARK-25776][CORE]The disk write buffer size must be greater than 12 ## What changes were proposed in this pull request? In `UnsafeSorterSpillWriter.java`, when we write a record to a spill file wtih ` void write(Object baseObject, long baseOffset, int recordLength, long keyPrefix)`, `recordLength` and `keyPrefix` will be written the disk write buffer first, and these will take 12 bytes, so the disk write buffer size must be greater than 12. If `diskWriteBufferSize` is 10, it will print this exception info: _java.lang.ArrayIndexOutOfBoundsException: 10 at org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter.writeLongToBuffer (UnsafeSorterSpillWriter.java:91) at org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter.write(UnsafeSorterSpillWriter.java:123) at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.spillIterator(UnsafeExternalSorter.java:498) at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.spill(UnsafeExternalSorter.java:222) at org.apache.spark.memory.MemoryConsumer.spill(MemoryConsumer.java:65)_ ## How was this patch tested? Existing UT in `UnsafeExternalSorterSuite` Closes #22754 from 10110346/diskWriteBufferSize. Authored-by: liuxian Signed-off-by: Kazuaki Ishizaki --- .../collection/unsafe/sort/UnsafeSorterSpillWriter.java | 5 ++++- .../scala/org/apache/spark/internal/config/package.scala | 6 ++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index 9399024f01783..c1d71a23b1dbe 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -42,7 +42,10 @@ public final class UnsafeSorterSpillWriter { private final SparkConf conf = new SparkConf(); - /** The buffer size to use when writing the sorted records to an on-disk file */ + /** + * The buffer size to use when writing the sorted records to an on-disk file, and + * this space used by prefix + len + recordLength must be greater than 4 + 8 bytes. + */ private final int diskWriteBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE()); diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 034e5ebbd293d..c8993e17bba67 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -21,6 +21,7 @@ import java.util.concurrent.TimeUnit import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.ByteUnit +import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.Utils package object config { @@ -504,8 +505,9 @@ package object config { ConfigBuilder("spark.shuffle.spill.diskWriteBufferSize") .doc("The buffer size, in bytes, to use when writing the sorted records to an on-disk file.") .bytesConf(ByteUnit.BYTE) - .checkValue(v => v > 0 && v <= Int.MaxValue, - s"The buffer size must be greater than 0 and less than ${Int.MaxValue}.") + .checkValue(v => v > 12 && v <= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH, + s"The buffer size must be greater than 12 and less than or equal to " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") .createWithDefault(1024 * 1024) private[spark] val UNROLL_MEMORY_CHECK_PERIOD = From 950e7374a89cf45742a442afc08a74b6b4a7aa66 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 4 Nov 2018 17:41:42 -0800 Subject: [PATCH 1992/2461] [SPARK-25913][SQL] Extend UnaryExecNode by unary SparkPlan nodes ## What changes were proposed in this pull request? In the PR, I propose to extend `UnaryExecNode` instead of `SparkPlan` by unary nodes. Closes #22925 from MaxGekk/unary-exec-node. Authored-by: Maxim Gekk Signed-off-by: gatorsmile --- .../spark/sql/execution/command/commands.scala | 6 ++---- .../datasources/v2/WriteToDataSourceV2Exec.scala | 6 +++--- .../spark/sql/execution/python/EvalPythonExec.scala | 6 ++---- .../continuous/ContinuousCoalesceExec.scala | 12 +++--------- .../continuous/WriteToContinuousDataSourceExec.scala | 6 +++--- 5 files changed, 13 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 2cc0e38adc2ee..ab40936eb3cc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} -import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan} +import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.debug._ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata} @@ -95,7 +95,7 @@ case class ExecutedCommandExec(cmd: RunnableCommand) extends LeafExecNode { * @param child the physical plan child ran by the `DataWritingCommand`. */ case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan) - extends SparkPlan { + extends UnaryExecNode { override lazy val metrics: Map[String, SQLMetric] = cmd.metrics @@ -106,8 +106,6 @@ case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan) rows.map(converter(_).asInstanceOf[InternalRow]) } - override def children: Seq[SparkPlan] = child :: Nil - override def output: Seq[Attribute] = cmd.output override def nodeName: String = "Execute " + cmd.nodeName diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index c3f7b690ef636..9a1fe1e0a328b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.util.Utils @@ -45,9 +45,9 @@ case class WriteToDataSourceV2(writeSupport: BatchWriteSupport, query: LogicalPl * The physical plan for writing data into data source v2. */ case class WriteToDataSourceV2Exec(writeSupport: BatchWriteSupport, query: SparkPlan) - extends SparkPlan { + extends UnaryExecNode { - override def children: Seq[SparkPlan] = Seq(query) + override def child: SparkPlan = query override def output: Seq[Attribute] = Nil override protected def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala index 942a6db57416e..67dcdd3732b4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala @@ -26,7 +26,7 @@ import org.apache.spark.api.python.ChainedPythonFunctions import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils @@ -58,9 +58,7 @@ import org.apache.spark.util.Utils * RowQueue ALWAYS happened after pushing into it. */ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) - extends SparkPlan { - - def children: Seq[SparkPlan] = child :: Nil + extends UnaryExecNode { override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala index 5f60343bacfaa..4c621890c9793 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala @@ -17,26 +17,20 @@ package org.apache.spark.sql.execution.streaming.continuous -import java.util.UUID - -import org.apache.spark.{HashPartitioner, SparkEnv} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition} -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.continuous.shuffle.{ContinuousShuffleReadPartition, ContinuousShuffleReadRDD} +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} /** * Physical plan for coalescing a continuous processing plan. * * Currently, only coalesces to a single partition are supported. `numPartitions` must be 1. */ -case class ContinuousCoalesceExec(numPartitions: Int, child: SparkPlan) extends SparkPlan { +case class ContinuousCoalesceExec(numPartitions: Int, child: SparkPlan) extends UnaryExecNode { override def output: Seq[Attribute] = child.output - override def children: Seq[SparkPlan] = child :: Nil - override def outputPartitioning: Partitioning = SinglePartition override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala index a797ac1879f41..2178466d63142 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport @@ -32,8 +32,8 @@ import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport * The physical plan for writing data into a continuous processing [[StreamingWriteSupport]]. */ case class WriteToContinuousDataSourceExec(writeSupport: StreamingWriteSupport, query: SparkPlan) - extends SparkPlan with Logging { - override def children: Seq[SparkPlan] = Seq(query) + extends UnaryExecNode with Logging { + override def child: SparkPlan = query override def output: Seq[Attribute] = Nil override protected def doExecute(): RDD[InternalRow] = { From 4afb3503346455dced7f310a0b722d7f579ef5cb Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 5 Nov 2018 15:53:06 +0800 Subject: [PATCH 1993/2461] [SPARK-25884][SQL][FOLLOW-UP] Add sample.json back. ## What changes were proposed in this pull request? This is a follow-up pr of #22892 which moved `sample.json` from hive module to sql module, but we still need the file in hive module. ## How was this patch tested? Existing tests. Closes #22942 from ueshin/issues/SPARK-25884/sample.json. Authored-by: Takuya UESHIN Signed-off-by: Wenchen Fan --- sql/hive/src/test/resources/sample.json | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 sql/hive/src/test/resources/sample.json diff --git a/sql/hive/src/test/resources/sample.json b/sql/hive/src/test/resources/sample.json new file mode 100644 index 0000000000000..a2c2ffd5e0330 --- /dev/null +++ b/sql/hive/src/test/resources/sample.json @@ -0,0 +1,2 @@ +{"a" : "2" ,"b" : "blah", "c_!@(3)":1} +{"" : {"d!" : [4, 5], "=" : [{"Dd2": null}, {"Dd2" : true}]}} From e017cb39642a5039abd8ce8127ad41712901bdbc Mon Sep 17 00:00:00 2001 From: yucai Date: Mon, 5 Nov 2018 20:09:39 +0800 Subject: [PATCH 1994/2461] [SPARK-25850][SQL] Make the split threshold for the code generated function configurable ## What changes were proposed in this pull request? As per the discussion in [#22823](https://github.com/apache/spark/pull/22823/files#r228400706), add a new configuration to make the split threshold for the code generated function configurable. When the generated Java function source code exceeds `spark.sql.codegen.methodSplitThreshold`, it will be split into multiple small functions. ## How was this patch tested? manual tests Closes #22847 from yucai/splitThreshold. Authored-by: yucai Signed-off-by: Wenchen Fan --- .../sql/catalyst/expressions/Expression.scala | 4 +++- .../expressions/codegen/CodeGenerator.scala | 3 ++- .../org/apache/spark/sql/internal/SQLConf.scala | 14 ++++++++++++++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index ccc5b9043a0aa..141fcffcb6fab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -121,7 +122,8 @@ abstract class Expression extends TreeNode[Expression] { private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { // TODO: support whole stage codegen too - if (eval.code.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { + val splitThreshold = SQLConf.get.methodSplitThreshold + if (eval.code.length > splitThreshold && ctx.INPUT_ROW != null && ctx.currentVars == null) { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d5857e060a2c4..b868a0f4fa284 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -910,12 +910,13 @@ class CodegenContext { val blocks = new ArrayBuffer[String]() val blockBuilder = new StringBuilder() var length = 0 + val splitThreshold = SQLConf.get.methodSplitThreshold for (code <- expressions) { // We can't know how many bytecode will be generated, so use the length of source code // as metric. A method should not go beyond 8K, otherwise it will not be JITted, should // also not be too small, or it will have many function calls (for wide table), see the // results in BenchmarkWideTable. - if (length > 1024) { + if (length > splitThreshold) { blocks += blockBuilder.toString() blockBuilder.clear() length = 0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 535ec51e315d7..fa59fa578a969 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -818,6 +818,18 @@ object SQLConf { .intConf .createWithDefault(65535) + val CODEGEN_METHOD_SPLIT_THRESHOLD = buildConf("spark.sql.codegen.methodSplitThreshold") + .internal() + .doc("The threshold of source-code splitting in the codegen. When the number of characters " + + "in a single Java function (without comment) exceeds the threshold, the function will be " + + "automatically split to multiple smaller ones. We cannot know how many bytecode will be " + + "generated, so use the code length as metric. When running on HotSpot, a function's " + + "bytecode should not go beyond 8KB, otherwise it will not be JITted; it also should not " + + "be too small, otherwise there will be many function calls.") + .intConf + .checkValue(threshold => threshold > 0, "The threshold must be a positive integer.") + .createWithDefault(1024) + val WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR = buildConf("spark.sql.codegen.splitConsumeFuncByOperator") .internal() @@ -1739,6 +1751,8 @@ class SQLConf extends Serializable with Logging { def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT) + def methodSplitThreshold: Int = getConf(CODEGEN_METHOD_SPLIT_THRESHOLD) + def wholeStageSplitConsumeFuncByOperator: Boolean = getConf(WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR) From 1fb3759f2b60a2e7c5e2a82afe1a580d848e0f8c Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Mon, 5 Nov 2018 08:40:25 -0600 Subject: [PATCH 1995/2461] [SPARK-25930][K8S] Fix scala string detection in k8s tests ## What changes were proposed in this pull request? - Issue is described in detail in [SPARK-25930](https://issues.apache.org/jira/browse/SPARK-25930). Since we rely on the std output, pick always the last line which contains the wanted value. Although minor, current implementation breaks tests. ## How was this patch tested? manually. rm -rf ~/.m2 and then run the tests. Closes #22931 from skonto/fix_scala_detection. Authored-by: Stavros Kontopoulos Signed-off-by: Sean Owen --- .../integration-tests/dev/dev-run-integration-tests.sh | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh index 3c7cc9369047a..68f284ca1d1ce 100755 --- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -29,7 +29,12 @@ SERVICE_ACCOUNT= CONTEXT= INCLUDE_TAGS="k8s" EXCLUDE_TAGS= -SCALA_VERSION="$($TEST_ROOT_DIR/build/mvn org.apache.maven.plugins:maven-help-plugin:2.1.1:evaluate -Dexpression=scala.binary.version | grep -v '\[' )" +MVN="$TEST_ROOT_DIR/build/mvn" + +SCALA_VERSION=$("$MVN" help:evaluate -Dexpression=scala.binary.version 2>/dev/null\ + | grep -v "INFO"\ + | grep -v "WARNING"\ + | tail -n 1) # Parse arguments while (( "$#" )); do From fc65b4af00c0a813613a7977126e942df8440bbb Mon Sep 17 00:00:00 2001 From: Shahid Date: Mon, 5 Nov 2018 09:13:53 -0600 Subject: [PATCH 1996/2461] [SPARK-25900][WEBUI] When the page number is more than the total page size, then fall back to the first page ## What changes were proposed in this pull request? When we give the page number more than the maximum page number, webui is throwing an exception. It would be better if fall back to the default page, instead of throwing the exception in the web ui. ## How was this patch tested? Before PR: ![screenshot from 2018-10-31 23-41-37](https://user-images.githubusercontent.com/23054875/47816448-354fbe80-dd79-11e8-83d8-6aab196642f7.png) After PR: ![screenshot from 2018-10-31 23-54-23](https://user-images.githubusercontent.com/23054875/47816461-3ed92680-dd79-11e8-959d-0c531b3a6b2d.png) Closes #22914 from shahidki31/pageFallBack. Authored-by: Shahid Signed-off-by: Sean Owen --- .../org/apache/spark/ui/PagedTable.scala | 51 ++++++++++++------- .../apache/spark/ui/jobs/AllJobsPage.scala | 17 +------ .../org/apache/spark/ui/jobs/StagePage.scala | 16 +----- .../org/apache/spark/ui/jobs/StageTable.scala | 19 +------ .../org/apache/spark/ui/storage/RDDPage.scala | 16 +----- .../org/apache/spark/ui/PagedTableSuite.scala | 17 ++----- .../sql/execution/ui/AllExecutionsPage.scala | 15 +----- 7 files changed, 45 insertions(+), 106 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala index 0bbb10a995bcb..6c2c1f6827948 100644 --- a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala @@ -33,10 +33,6 @@ import org.apache.spark.util.Utils */ private[spark] abstract class PagedDataSource[T](val pageSize: Int) { - if (pageSize <= 0) { - throw new IllegalArgumentException("Page size must be positive") - } - /** * Return the size of all data. */ @@ -51,13 +47,24 @@ private[spark] abstract class PagedDataSource[T](val pageSize: Int) { * Slice the data for this page */ def pageData(page: Int): PageData[T] = { - val totalPages = (dataSize + pageSize - 1) / pageSize - if (page <= 0 || page > totalPages) { - throw new IndexOutOfBoundsException( - s"Page $page is out of range. Please select a page number between 1 and $totalPages.") + // Display all the data in one page, if the pageSize is less than or equal to zero. + val pageTableSize = if (pageSize <= 0) { + dataSize + } else { + pageSize + } + val totalPages = (dataSize + pageTableSize - 1) / pageTableSize + + val pageToShow = if (page <= 0) { + 1 + } else if (page > totalPages) { + totalPages + } else { + page } - val from = (page - 1) * pageSize - val to = dataSize.min(page * pageSize) + + val (from, to) = ((pageToShow - 1) * pageSize, dataSize.min(pageToShow * pageTableSize)) + PageData(totalPages, sliceData(from, to)) } @@ -80,8 +87,6 @@ private[spark] trait PagedTable[T] { def pageSizeFormField: String - def prevPageSizeFormField: String - def pageNumberFormField: String def dataSource: PagedDataSource[T] @@ -94,7 +99,23 @@ private[spark] trait PagedTable[T] { val _dataSource = dataSource try { val PageData(totalPages, data) = _dataSource.pageData(page) - val pageNavi = pageNavigation(page, _dataSource.pageSize, totalPages) + + val pageToShow = if (page <= 0) { + 1 + } else if (page > totalPages) { + totalPages + } else { + page + } + // Display all the data in one page, if the pageSize is less than or equal to zero. + val pageSize = if (_dataSource.pageSize <= 0) { + data.size + } else { + _dataSource.pageSize + } + + val pageNavi = pageNavigation(pageToShow, pageSize, totalPages) +
      {pageNavi} @@ -180,7 +201,6 @@ private[spark] trait PagedTable[T] { .split(search) .asScala .filterKeys(_ != pageSizeFormField) - .filterKeys(_ != prevPageSizeFormField) .filterKeys(_ != pageNumberFormField) .mapValues(URLDecoder.decode(_, "UTF-8")) .map { case (k, v) => @@ -198,9 +218,6 @@ private[spark] trait PagedTable[T] { action={Unparsed(goButtonFormPath)} class="form-inline pull-right" style="margin-bottom: 0px;"> - {hiddenFormFields} @@ -231,17 +230,7 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We jobSortColumn == jobIdTitle ) val jobPageSize = Option(parameterJobPageSize).map(_.toInt).getOrElse(100) - val jobPrevPageSize = Option(parameterJobPrevPageSize).map(_.toInt).getOrElse(jobPageSize) - - val page: Int = { - // If the user has changed to a larger page size, then go to page 1 in order to avoid - // IndexOutOfBoundsException. - if (jobPageSize <= jobPrevPageSize) { - jobPage - } else { - 1 - } - } + val currentTime = System.currentTimeMillis() try { @@ -259,7 +248,7 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We pageSize = jobPageSize, sortColumn = jobSortColumn, desc = jobSortDesc - ).table(page) + ).table(jobPage) } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) =>
      @@ -526,8 +515,6 @@ private[ui] class JobPagedTable( override def pageSizeFormField: String = jobTag + ".pageSize" - override def prevPageSizeFormField: String = jobTag + ".prevPageSize" - override def pageNumberFormField: String = jobTag + ".page" override val dataSource = new JobDataSource( diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 0f74b07a6265c..477b9ce7f7848 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -91,7 +91,6 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val parameterTaskSortColumn = UIUtils.stripXSS(request.getParameter("task.sort")) val parameterTaskSortDesc = UIUtils.stripXSS(request.getParameter("task.desc")) val parameterTaskPageSize = UIUtils.stripXSS(request.getParameter("task.pageSize")) - val parameterTaskPrevPageSize = UIUtils.stripXSS(request.getParameter("task.prevPageSize")) val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1) val taskSortColumn = Option(parameterTaskSortColumn).map { sortColumn => @@ -99,8 +98,6 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We }.getOrElse("Index") val taskSortDesc = Option(parameterTaskSortDesc).map(_.toBoolean).getOrElse(false) val taskPageSize = Option(parameterTaskPageSize).map(_.toInt).getOrElse(100) - val taskPrevPageSize = Option(parameterTaskPrevPageSize).map(_.toInt).getOrElse(taskPageSize) - val stageId = parameterId.toInt val stageAttemptId = parameterAttempt.toInt @@ -278,15 +275,6 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We accumulableRow, stageData.accumulatorUpdates.toSeq) - val page: Int = { - // If the user has changed to a larger page size, then go to page 1 in order to avoid - // IndexOutOfBoundsException. - if (taskPageSize <= taskPrevPageSize) { - taskPage - } else { - 1 - } - } val currentTime = System.currentTimeMillis() val (taskTable, taskTableHTML) = try { val _taskTable = new TaskPagedTable( @@ -299,7 +287,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We desc = taskSortDesc, store = parent.store ) - (_taskTable, _taskTable.table(page)) + (_taskTable, _taskTable.table(taskPage)) } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => val errorMessage = @@ -732,8 +720,6 @@ private[ui] class TaskPagedTable( override def pageSizeFormField: String = "task.pageSize" - override def prevPageSizeFormField: String = "task.prevPageSize" - override def pageNumberFormField: String = "task.page" override val dataSource: TaskDataSource = new TaskDataSource( diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index d01acdae59c9f..b9abd39b4705d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -53,8 +53,6 @@ private[ui] class StageTableBase( val parameterStageSortColumn = UIUtils.stripXSS(request.getParameter(stageTag + ".sort")) val parameterStageSortDesc = UIUtils.stripXSS(request.getParameter(stageTag + ".desc")) val parameterStagePageSize = UIUtils.stripXSS(request.getParameter(stageTag + ".pageSize")) - val parameterStagePrevPageSize = - UIUtils.stripXSS(request.getParameter(stageTag + ".prevPageSize")) val stagePage = Option(parameterStagePage).map(_.toInt).getOrElse(1) val stageSortColumn = Option(parameterStageSortColumn).map { sortColumn => @@ -65,18 +63,7 @@ private[ui] class StageTableBase( stageSortColumn == "Stage Id" ) val stagePageSize = Option(parameterStagePageSize).map(_.toInt).getOrElse(100) - val stagePrevPageSize = Option(parameterStagePrevPageSize).map(_.toInt) - .getOrElse(stagePageSize) - - val page: Int = { - // If the user has changed to a larger page size, then go to page 1 in order to avoid - // IndexOutOfBoundsException. - if (stagePageSize <= stagePrevPageSize) { - stagePage - } else { - 1 - } - } + val currentTime = System.currentTimeMillis() val toNodeSeq = try { @@ -96,7 +83,7 @@ private[ui] class StageTableBase( isFailedStage, parameterOtherTable, request - ).table(page) + ).table(stagePage) } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) =>
      @@ -161,8 +148,6 @@ private[ui] class StagePagedTable( override def pageSizeFormField: String = stageTag + ".pageSize" - override def prevPageSizeFormField: String = stageTag + ".prevPageSize" - override def pageNumberFormField: String = stageTag + ".page" val parameterPath = UIUtils.prependBaseUri(request, basePath) + s"/$subPath/?" + diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 238cd31433660..87da290c83057 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -39,13 +39,11 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web val parameterBlockSortColumn = UIUtils.stripXSS(request.getParameter("block.sort")) val parameterBlockSortDesc = UIUtils.stripXSS(request.getParameter("block.desc")) val parameterBlockPageSize = UIUtils.stripXSS(request.getParameter("block.pageSize")) - val parameterBlockPrevPageSize = UIUtils.stripXSS(request.getParameter("block.prevPageSize")) val blockPage = Option(parameterBlockPage).map(_.toInt).getOrElse(1) val blockSortColumn = Option(parameterBlockSortColumn).getOrElse("Block Name") val blockSortDesc = Option(parameterBlockSortDesc).map(_.toBoolean).getOrElse(false) val blockPageSize = Option(parameterBlockPageSize).map(_.toInt).getOrElse(100) - val blockPrevPageSize = Option(parameterBlockPrevPageSize).map(_.toInt).getOrElse(blockPageSize) val rddId = parameterId.toInt val rddStorageInfo = try { @@ -60,16 +58,6 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web val workerTable = UIUtils.listingTable(workerHeader, workerRow, rddStorageInfo.dataDistribution.get, id = Some("rdd-storage-by-worker-table")) - // Block table - val page: Int = { - // If the user has changed to a larger page size, then go to page 1 in order to avoid - // IndexOutOfBoundsException. - if (blockPageSize <= blockPrevPageSize) { - blockPage - } else { - 1 - } - } val blockTableHTML = try { val _blockTable = new BlockPagedTable( UIUtils.prependBaseUri(request, parent.basePath) + s"/storage/rdd/?id=${rddId}", @@ -78,7 +66,7 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web blockSortColumn, blockSortDesc, store.executorList(true)) - _blockTable.table(page) + _blockTable.table(blockPage) } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) =>
      {e.getMessage}
      @@ -242,8 +230,6 @@ private[ui] class BlockPagedTable( override def pageSizeFormField: String = "block.pageSize" - override def prevPageSizeFormField: String = "block.prevPageSize" - override def pageNumberFormField: String = "block.page" override val dataSource: BlockDataSource = new BlockDataSource( diff --git a/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala b/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala index cda98ae25a57a..d18f55474bdb3 100644 --- a/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala @@ -32,19 +32,12 @@ class PagedDataSourceSuite extends SparkFunSuite { val dataSource3 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) assert(dataSource3.pageData(3) === PageData(3, Seq(5))) - + // If the page number is more than maximum page, fall back to the last page val dataSource4 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) - val e1 = intercept[IndexOutOfBoundsException] { - dataSource4.pageData(4) - } - assert(e1.getMessage === "Page 4 is out of range. Please select a page number between 1 and 3.") - + assert(dataSource4.pageData(4) === PageData(3, Seq(5))) + // If the page number is less than or equal to zero, fall back to the first page val dataSource5 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) - val e2 = intercept[IndexOutOfBoundsException] { - dataSource5.pageData(0) - } - assert(e2.getMessage === "Page 0 is out of range. Please select a page number between 1 and 3.") - + assert(dataSource5.pageData(0) === PageData(3, 1 to 2)) } } @@ -66,8 +59,6 @@ class PagedTableSuite extends SparkFunSuite { override def pageSizeFormField: String = "pageSize" - override def prevPageSizeFormField: String = "prevPageSize" - override def pageNumberFormField: String = "page" override def goButtonFormPath: String = "" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index 311f805f78832..4958f154e625f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -171,8 +171,6 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L val parameterExecutionSortDesc = UIUtils.stripXSS(request.getParameter(s"$executionTag.desc")) val parameterExecutionPageSize = UIUtils.stripXSS(request .getParameter(s"$executionTag.pageSize")) - val parameterExecutionPrevPageSize = UIUtils.stripXSS(request - .getParameter(s"$executionTag.prevPageSize")) val executionPage = Option(parameterExecutionPage).map(_.toInt).getOrElse(1) val executionSortColumn = Option(parameterExecutionSortColumn).map { sortColumn => @@ -183,16 +181,7 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L executionSortColumn == "ID" ) val executionPageSize = Option(parameterExecutionPageSize).map(_.toInt).getOrElse(100) - val executionPrevPageSize = Option(parameterExecutionPrevPageSize).map(_.toInt) - .getOrElse(executionPageSize) - // If the user has changed to a larger page size, then go to page 1 in order to avoid - // IndexOutOfBoundsException. - val page: Int = if (executionPageSize <= executionPrevPageSize) { - executionPage - } else { - 1 - } val tableHeaderId = executionTag // "running", "completed" or "failed" try { @@ -211,7 +200,7 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L desc = executionSortDesc, showRunningJobs, showSucceededJobs, - showFailedJobs).table(page) + showFailedJobs).table(executionPage) } catch { case e@(_: IllegalArgumentException | _: IndexOutOfBoundsException) =>
      @@ -262,8 +251,6 @@ private[ui] class ExecutionPagedTable( "table table-bordered table-condensed table-striped " + "table-head-clickable table-cell-width-limited" - override def prevPageSizeFormField: String = s"$executionTag.prevPageSize" - override def pageLink(page: Int): String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") parameterPath + From fc10c898f45a25cf3751f0cd042e4c0743f1adba Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 5 Nov 2018 22:13:20 +0000 Subject: [PATCH 1997/2461] [SPARK-25758][ML] Deprecate computeCost in BisectingKMeans ## What changes were proposed in this pull request? The PR proposes to deprecate the `computeCost` method on `BisectingKMeans` in favor of the adoption of `ClusteringEvaluator` in order to evaluate the clustering. ## How was this patch tested? NA Closes #22869 from mgaido91/SPARK-25758_3.0. Authored-by: Marco Gaido Signed-off-by: DB Tsai --- .../org/apache/spark/ml/clustering/BisectingKMeans.scala | 7 +++++++ python/pyspark/ml/clustering.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 5cb16cc765887..1a94aefa3f563 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -125,8 +125,15 @@ class BisectingKMeansModel private[ml] ( /** * Computes the sum of squared distances between the input points and their corresponding cluster * centers. + * + * @deprecated This method is deprecated and will be removed in future versions. Use + * ClusteringEvaluator instead. You can also get the cost on the training dataset in + * the summary. */ @Since("2.0.0") + @deprecated("This method is deprecated and will be removed in future versions. Use " + + "ClusteringEvaluator instead. You can also get the cost on the training dataset in the " + + "summary.", "3.0.0") def computeCost(dataset: Dataset[_]): Double = { SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol) val data = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 5ef4e765ea4e1..b37129428f491 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -540,7 +540,14 @@ def computeCost(self, dataset): """ Computes the sum of squared distances between the input points and their corresponding cluster centers. + + ..note:: Deprecated in 3.0.0. It will be removed in future versions. Use + ClusteringEvaluator instead. You can also get the cost on the training dataset in the + summary. """ + warnings.warn("Deprecated in 3.0.0. It will be removed in future versions. Use " + "ClusteringEvaluator instead. You can also get the cost on the training " + "dataset in the summary.", DeprecationWarning) return self._call_java("computeCost", dataset) @property From 486acda8c5a421b440571629730dfa6b02af9b80 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 5 Nov 2018 14:26:22 -0800 Subject: [PATCH 1998/2461] [SPARK-25944][R][BUILD] AppVeyor change to latest R version (3.5.1) ## What changes were proposed in this pull request? R 3.5.1 is released 2018-07-02. This PR targets to changes R version from 3.4.1 to 3.5.1. ## How was this patch tested? AppVeyor Closes #22948 from HyukjinKwon/SPARK-25944. Authored-by: hyukjinkwon Signed-off-by: Dongjoon Hyun --- dev/appveyor-install-dependencies.ps1 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/appveyor-install-dependencies.ps1 b/dev/appveyor-install-dependencies.ps1 index c91882851847b..06d9d70af311a 100644 --- a/dev/appveyor-install-dependencies.ps1 +++ b/dev/appveyor-install-dependencies.ps1 @@ -115,7 +115,7 @@ $env:Path += ";$env:HADOOP_HOME\bin" Pop-Location # ========================== R -$rVer = "3.4.1" +$rVer = "3.5.1" $rToolsVer = "3.4.0" InstallR From 0b59170001be1cc1198cfc1c0486ca34633e64d5 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 5 Nov 2018 22:42:04 +0000 Subject: [PATCH 1999/2461] [SPARK-25764][ML][EXAMPLES] Update BisectingKMeans example to use ClusteringEvaluator ## What changes were proposed in this pull request? Using `computeCost` for evaluating a model is a very poor approach. We should advice the users to a better approach which is available, ie. using the `ClusteringEvaluator` to evaluate their models. The PR updates the examples for `BisectingKMeans` in order to do that. ## How was this patch tested? running examples Closes #22786 from mgaido91/SPARK-25764. Authored-by: Marco Gaido Signed-off-by: DB Tsai --- .../examples/ml/JavaBisectingKMeansExample.java | 12 +++++++++--- .../src/main/python/ml/bisecting_k_means_example.py | 12 +++++++++--- .../spark/examples/ml/BisectingKMeansExample.scala | 12 +++++++++--- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java index 8c82aaaacca38..f517dc314b2b7 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java @@ -20,6 +20,7 @@ // $example on$ import org.apache.spark.ml.clustering.BisectingKMeans; import org.apache.spark.ml.clustering.BisectingKMeansModel; +import org.apache.spark.ml.evaluation.ClusteringEvaluator; import org.apache.spark.ml.linalg.Vector; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; @@ -50,9 +51,14 @@ public static void main(String[] args) { BisectingKMeans bkm = new BisectingKMeans().setK(2).setSeed(1); BisectingKMeansModel model = bkm.fit(dataset); - // Evaluate clustering. - double cost = model.computeCost(dataset); - System.out.println("Within Set Sum of Squared Errors = " + cost); + // Make predictions + Dataset predictions = model.transform(dataset); + + // Evaluate clustering by computing Silhouette score + ClusteringEvaluator evaluator = new ClusteringEvaluator(); + + double silhouette = evaluator.evaluate(predictions); + System.out.println("Silhouette with squared euclidean distance = " + silhouette); // Shows the result. System.out.println("Cluster Centers: "); diff --git a/examples/src/main/python/ml/bisecting_k_means_example.py b/examples/src/main/python/ml/bisecting_k_means_example.py index 7842d2009e238..82adb338b5d91 100644 --- a/examples/src/main/python/ml/bisecting_k_means_example.py +++ b/examples/src/main/python/ml/bisecting_k_means_example.py @@ -24,6 +24,7 @@ # $example on$ from pyspark.ml.clustering import BisectingKMeans +from pyspark.ml.evaluation import ClusteringEvaluator # $example off$ from pyspark.sql import SparkSession @@ -41,9 +42,14 @@ bkm = BisectingKMeans().setK(2).setSeed(1) model = bkm.fit(dataset) - # Evaluate clustering. - cost = model.computeCost(dataset) - print("Within Set Sum of Squared Errors = " + str(cost)) + # Make predictions + predictions = model.transform(dataset) + + # Evaluate clustering by computing Silhouette score + evaluator = ClusteringEvaluator() + + silhouette = evaluator.evaluate(predictions) + print("Silhouette with squared euclidean distance = " + str(silhouette)) # Shows the result. print("Cluster Centers: ") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BisectingKMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BisectingKMeansExample.scala index 5f8f2c99cbaf4..14e13df02733b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/BisectingKMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BisectingKMeansExample.scala @@ -21,6 +21,7 @@ package org.apache.spark.examples.ml // $example on$ import org.apache.spark.ml.clustering.BisectingKMeans +import org.apache.spark.ml.evaluation.ClusteringEvaluator // $example off$ import org.apache.spark.sql.SparkSession @@ -48,9 +49,14 @@ object BisectingKMeansExample { val bkm = new BisectingKMeans().setK(2).setSeed(1) val model = bkm.fit(dataset) - // Evaluate clustering. - val cost = model.computeCost(dataset) - println(s"Within Set Sum of Squared Errors = $cost") + // Make predictions + val predictions = model.transform(dataset) + + // Evaluate clustering by computing Silhouette score + val evaluator = new ClusteringEvaluator() + + val silhouette = evaluator.evaluate(predictions) + println(s"Silhouette with squared euclidean distance = $silhouette") // Shows the result. println("Cluster Centers: ") From c0d1bf0322be12230c30cb200f19a02e4d5e0d49 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 5 Nov 2018 17:34:23 -0600 Subject: [PATCH 2000/2461] [MINOR] Fix typos and misspellings ## What changes were proposed in this pull request? Fix typos and misspellings, per https://github.com/apache/spark-website/pull/158#issuecomment-435790366 ## How was this patch tested? Existing tests. Closes #22950 from srowen/Typos. Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../java/org/apache/spark/ExecutorPlugin.java | 6 +++--- .../org/apache/spark/ExecutorPluginSuite.java | 4 ++-- docs/sql-migration-guide-upgrade.md | 2 +- .../ml/r/AFTSurvivalRegressionWrapper.scala | 6 +++--- .../org/apache/spark/ml/stat/Summarizer.scala | 4 ++-- .../stat/MultivariateOnlineSummarizer.scala | 2 +- python/pyspark/ml/stat.py | 2 +- .../spark/sql/hive/CachedTableSuite.scala | 17 ++++++++--------- 8 files changed, 21 insertions(+), 22 deletions(-) diff --git a/core/src/main/java/org/apache/spark/ExecutorPlugin.java b/core/src/main/java/org/apache/spark/ExecutorPlugin.java index ec0b57f1a2819..f86520c81df33 100644 --- a/core/src/main/java/org/apache/spark/ExecutorPlugin.java +++ b/core/src/main/java/org/apache/spark/ExecutorPlugin.java @@ -20,18 +20,18 @@ import org.apache.spark.annotation.DeveloperApi; /** - * A plugin which can be automaticaly instantiated within each Spark executor. Users can specify + * A plugin which can be automatically instantiated within each Spark executor. Users can specify * plugins which should be created with the "spark.executor.plugins" configuration. An instance * of each plugin will be created for every executor, including those created by dynamic allocation, * before the executor starts running any tasks. * * The specific api exposed to the end users still considered to be very unstable. We will - * hopefully be able to keep compatability by providing default implementations for any methods + * hopefully be able to keep compatibility by providing default implementations for any methods * added, but make no guarantees this will always be possible across all Spark releases. * * Spark does nothing to verify the plugin is doing legitimate things, or to manage the resources * it uses. A plugin acquires the same privileges as the user running the task. A bad plugin - * could also intefere with task execution and make the executor fail in unexpected ways. + * could also interfere with task execution and make the executor fail in unexpected ways. */ @DeveloperApi public interface ExecutorPlugin { diff --git a/core/src/test/java/org/apache/spark/ExecutorPluginSuite.java b/core/src/test/java/org/apache/spark/ExecutorPluginSuite.java index 686eb28010c6a..80cd70282a51d 100644 --- a/core/src/test/java/org/apache/spark/ExecutorPluginSuite.java +++ b/core/src/test/java/org/apache/spark/ExecutorPluginSuite.java @@ -63,10 +63,10 @@ private SparkConf initializeSparkConf(String pluginNames) { @Test public void testPluginClassDoesNotExist() { - SparkConf conf = initializeSparkConf("nonexistant.plugin"); + SparkConf conf = initializeSparkConf("nonexistent.plugin"); try { sc = new JavaSparkContext(conf); - fail("No exception thrown for nonexistant plugin"); + fail("No exception thrown for nonexistent plugin"); } catch (Exception e) { // We cannot catch ClassNotFoundException directly because Java doesn't think it'll be thrown assertTrue(e.toString().startsWith("java.lang.ClassNotFoundException")); diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index c9685b866774f..50458e96f7c3f 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -117,7 +117,7 @@ displayTitle: Spark SQL Upgrading Guide - Since Spark 2.4, Metadata files (e.g. Parquet summary files) and temporary files are not counted as data files when calculating table size during Statistics computation. - - Since Spark 2.4, empty strings are saved as quoted empty strings `""`. In version 2.3 and earlier, empty strings are equal to `null` values and do not reflect to any characters in saved CSV files. For example, the row of `"a", null, "", 1` was writted as `a,,,1`. Since Spark 2.4, the same row is saved as `a,,"",1`. To restore the previous behavior, set the CSV option `emptyValue` to empty (not quoted) string. + - Since Spark 2.4, empty strings are saved as quoted empty strings `""`. In version 2.3 and earlier, empty strings are equal to `null` values and do not reflect to any characters in saved CSV files. For example, the row of `"a", null, "", 1` was written as `a,,,1`. Since Spark 2.4, the same row is saved as `a,,"",1`. To restore the previous behavior, set the CSV option `emptyValue` to empty (not quoted) string. - Since Spark 2.4, The LOAD DATA command supports wildcard `?` and `*`, which match any one character, and zero or more characters, respectively. Example: `LOAD DATA INPATH '/tmp/folder*/'` or `LOAD DATA INPATH '/tmp/part-?'`. Special Characters like `space` also now work in paths. Example: `LOAD DATA INPATH '/tmp/folder name/'`. diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala index 48485e02edda8..1b5f77a9ae897 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala @@ -62,7 +62,7 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg private val FORMULA_REGEXP = """Surv\(([^,]+), ([^,]+)\) ~ (.+)""".r private def formulaRewrite(formula: String): (String, String) = { - var rewritedFormula: String = null + var rewrittenFormula: String = null var censorCol: String = null try { val FORMULA_REGEXP(label, censor, features) = formula @@ -71,14 +71,14 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg throw new UnsupportedOperationException( "Terms of survreg formula can not support dot operator.") } - rewritedFormula = label.trim + "~" + features.trim + rewrittenFormula = label.trim + "~" + features.trim censorCol = censor.trim } catch { case e: MatchError => throw new SparkException(s"Could not parse formula: $formula") } - (rewritedFormula, censorCol) + (rewrittenFormula, censorCol) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index d40827edb6d64..ed7d7e0852647 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -96,7 +96,7 @@ object Summarizer extends Logging { * - numNonzeros: a vector with the number of non-zeros for each coefficients * - max: the maximum for each coefficient. * - min: the minimum for each coefficient. - * - normL2: the Euclidian norm for each coefficient. + * - normL2: the Euclidean norm for each coefficient. * - normL1: the L1 norm of each coefficient (sum of the absolute values). * @param metrics metrics that can be provided. * @return a builder. @@ -536,7 +536,7 @@ private[ml] object SummaryBuilderImpl extends Logging { } /** - * L2 (Euclidian) norm of each dimension. + * L2 (Euclidean) norm of each dimension. */ def normL2: Vector = { require(requestedMetrics.contains(NormL2)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 8121880cfb233..0554b6d8ff5b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -273,7 +273,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** - * L2 (Euclidian) norm of each dimension. + * L2 (Euclidean) norm of each dimension. * */ @Since("1.2.0") diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index 370154fc6d62a..3f421024acdce 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -336,7 +336,7 @@ def metrics(*metrics): - numNonzeros: a vector with the number of non-zeros for each coefficients - max: the maximum for each coefficient. - min: the minimum for each coefficient. - - normL2: the Euclidian norm for each coefficient. + - normL2: the Euclidean norm for each coefficient. - normL1: the L1 norm of each coefficient (sum of the absolute values). :param metrics: diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 569f00c053e5f..b492f39df62f2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.hive import java.io.File import org.apache.spark.sql.{AnalysisException, Dataset, QueryTest, SaveMode} -import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation} @@ -97,24 +96,24 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto } } - test("DROP nonexistant table") { - sql("DROP TABLE IF EXISTS nonexistantTable") + test("DROP nonexistent table") { + sql("DROP TABLE IF EXISTS nonexistentTable") } - test("uncache of nonexistant tables") { - val expectedErrorMsg = "Table or view not found: nonexistantTable" + test("uncache of nonexistent tables") { + val expectedErrorMsg = "Table or view not found: nonexistentTable" // make sure table doesn't exist - var e = intercept[AnalysisException](spark.table("nonexistantTable")).getMessage + var e = intercept[AnalysisException](spark.table("nonexistentTable")).getMessage assert(e.contains(expectedErrorMsg)) e = intercept[AnalysisException] { - spark.catalog.uncacheTable("nonexistantTable") + spark.catalog.uncacheTable("nonexistentTable") }.getMessage assert(e.contains(expectedErrorMsg)) e = intercept[AnalysisException] { - sql("UNCACHE TABLE nonexistantTable") + sql("UNCACHE TABLE nonexistentTable") }.getMessage assert(e.contains(expectedErrorMsg)) - sql("UNCACHE TABLE IF EXISTS nonexistantTable") + sql("UNCACHE TABLE IF EXISTS nonexistentTable") } test("no error on uncache of non-cached table") { From 78fa1be29bc9fbe98dd0226418aafc221c5e5309 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 6 Nov 2018 09:18:17 +0800 Subject: [PATCH 2001/2461] [SPARK-25926][CORE] Move config entries in core module to internal.config. ## What changes were proposed in this pull request? Currently definitions of config entries in `core` module are in several files separately. We should move them into `internal/config` to be easy to manage. ## How was this patch tested? Existing tests. Closes #22928 from ueshin/issues/SPARK-25926/single_config_file. Authored-by: Takuya UESHIN Signed-off-by: Wenchen Fan --- core/src/main/scala/org/apache/spark/SparkConf.scala | 2 +- .../spark/deploy/history/FsHistoryProvider.scala | 4 ++-- .../apache/spark/deploy/history/HistoryServer.scala | 2 +- .../deploy/history/HistoryServerDiskManager.scala | 3 +-- .../config.scala => internal/config/History.scala} | 6 ++---- .../config.scala => internal/config/Status.scala} | 12 ++++++++---- .../org/apache/spark/status/AppStatusListener.scala | 3 +-- .../org/apache/spark/status/AppStatusSource.scala | 11 ++--------- .../apache/spark/status/ElementTrackingStore.scala | 3 +-- .../test/scala/org/apache/spark/SparkConfSuite.scala | 2 +- .../deploy/history/FsHistoryProviderSuite.scala | 2 +- .../history/HistoryServerDiskManagerSuite.scala | 3 +-- .../spark/deploy/history/HistoryServerSuite.scala | 2 +- .../apache/spark/status/AppStatusListenerSuite.scala | 3 +-- .../spark/status/ElementTrackingStoreSuite.scala | 3 +-- .../scala/org/apache/spark/ui/StagePageSuite.scala | 2 +- .../scala/org/apache/spark/ui/UISeleniumSuite.scala | 2 +- .../cluster/mesos/MesosSchedulerUtils.scala | 2 +- .../sql/execution/ui/SQLAppStatusListener.scala | 2 +- .../sql/execution/ui/SQLAppStatusListenerSuite.scala | 2 +- 20 files changed, 30 insertions(+), 41 deletions(-) rename core/src/main/scala/org/apache/spark/{deploy/history/config.scala => internal/config/History.scala} (95%) rename core/src/main/scala/org/apache/spark/{status/config.scala => internal/config/Status.scala} (83%) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 8537c536887e6..21c5cbc04d813 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -25,9 +25,9 @@ import scala.collection.mutable.LinkedHashSet import org.apache.avro.{Schema, SchemaNormalization} -import org.apache.spark.deploy.history.config._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.History._ import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.Utils diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index c4517d3dfd931..2230bc8d6c641 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -42,13 +42,14 @@ import org.fusesource.leveldbjni.internal.NativeDB import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.History._ +import org.apache.spark.internal.config.Status._ import org.apache.spark.io.CompressionCodec import org.apache.spark.scheduler._ import org.apache.spark.scheduler.ReplayListenerBus._ import org.apache.spark.status._ import org.apache.spark.status.KVUtils._ import org.apache.spark.status.api.v1.{ApplicationAttemptInfo, ApplicationInfo} -import org.apache.spark.status.config._ import org.apache.spark.ui.SparkUI import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} import org.apache.spark.util.kvstore._ @@ -86,7 +87,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) this(conf, new SystemClock()) } - import config._ import FsHistoryProvider._ // Interval between safemode checks. diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 56f3f59504a7d..5856c7057b745 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -28,9 +28,9 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.history.config.HISTORY_SERVER_UI_PORT import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.History.HISTORY_SERVER_UI_PORT import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, UIRoot} import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala index ad0dd23cb59c8..0a1f33395ad62 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerDiskManager.scala @@ -27,6 +27,7 @@ import org.apache.commons.io.FileUtils import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.History._ import org.apache.spark.status.KVUtils._ import org.apache.spark.util.{Clock, Utils} import org.apache.spark.util.kvstore.KVStore @@ -50,8 +51,6 @@ private class HistoryServerDiskManager( listing: KVStore, clock: Clock) extends Logging { - import config._ - private val appStoreDir = new File(path, "apps") if (!appStoreDir.isDirectory() && !appStoreDir.mkdir()) { throw new IllegalArgumentException(s"Failed to create app directory ($appStoreDir).") diff --git a/core/src/main/scala/org/apache/spark/deploy/history/config.scala b/core/src/main/scala/org/apache/spark/internal/config/History.scala similarity index 95% rename from core/src/main/scala/org/apache/spark/deploy/history/config.scala rename to core/src/main/scala/org/apache/spark/internal/config/History.scala index 25ba9edb9e014..3f74eb3d13d4b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/config.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/History.scala @@ -15,14 +15,13 @@ * limitations under the License. */ -package org.apache.spark.deploy.history +package org.apache.spark.internal.config import java.util.concurrent.TimeUnit -import org.apache.spark.internal.config.ConfigBuilder import org.apache.spark.network.util.ByteUnit -private[spark] object config { +private[spark] object History { val DEFAULT_LOG_DIR = "file:/tmp/spark-events" @@ -63,5 +62,4 @@ private[spark] object config { "parts of event log files. It can be disabled by setting this config to 0.") .bytesConf(ByteUnit.BYTE) .createWithDefaultString("1m") - } diff --git a/core/src/main/scala/org/apache/spark/status/config.scala b/core/src/main/scala/org/apache/spark/internal/config/Status.scala similarity index 83% rename from core/src/main/scala/org/apache/spark/status/config.scala rename to core/src/main/scala/org/apache/spark/internal/config/Status.scala index 67801b8f046f4..c56157227f8fc 100644 --- a/core/src/main/scala/org/apache/spark/status/config.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/Status.scala @@ -15,13 +15,11 @@ * limitations under the License. */ -package org.apache.spark.status +package org.apache.spark.internal.config import java.util.concurrent.TimeUnit -import org.apache.spark.internal.config._ - -private[spark] object config { +private[spark] object Status { val ASYNC_TRACKING_ENABLED = ConfigBuilder("spark.appStateStore.asyncTracking.enable") .booleanConf @@ -51,4 +49,10 @@ private[spark] object config { .intConf .createWithDefault(Int.MaxValue) + val APP_STATUS_METRICS_ENABLED = + ConfigBuilder("spark.app.status.metrics.enabled") + .doc("Whether Dropwizard/Codahale metrics " + + "will be reported for the status of the running spark app.") + .booleanConf + .createWithDefault(false) } diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index e2c190ea198e0..81d39e0407fed 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -27,6 +27,7 @@ import scala.collection.mutable.HashMap import org.apache.spark._ import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.Status._ import org.apache.spark.scheduler._ import org.apache.spark.status.api.v1 import org.apache.spark.storage._ @@ -47,8 +48,6 @@ private[spark] class AppStatusListener( appStatusSource: Option[AppStatusSource] = None, lastUpdateTime: Option[Long] = None) extends SparkListener with Logging { - import config._ - private var sparkVersion = SPARK_VERSION private var appInfo: v1.ApplicationInfo = null private var appSummary = new AppSummary(0, 0) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusSource.scala b/core/src/main/scala/org/apache/spark/status/AppStatusSource.scala index 3ab293dd648b5..f6a21578ff499 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusSource.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusSource.scala @@ -22,7 +22,7 @@ import AppStatusSource.getCounter import com.codahale.metrics.{Counter, Gauge, MetricRegistry} import org.apache.spark.SparkConf -import org.apache.spark.internal.config.ConfigBuilder +import org.apache.spark.internal.config.Status.APP_STATUS_METRICS_ENABLED import org.apache.spark.metrics.source.Source private [spark] class JobDuration(val value: AtomicLong) extends Gauge[Long] { @@ -71,15 +71,8 @@ private[spark] object AppStatusSource { } def createSource(conf: SparkConf): Option[AppStatusSource] = { - Option(conf.get(AppStatusSource.APP_STATUS_METRICS_ENABLED)) + Option(conf.get(APP_STATUS_METRICS_ENABLED)) .filter(identity) .map { _ => new AppStatusSource() } } - - val APP_STATUS_METRICS_ENABLED = - ConfigBuilder("spark.app.status.metrics.enabled") - .doc("Whether Dropwizard/Codahale metrics " + - "will be reported for the status of the running spark app.") - .booleanConf - .createWithDefault(false) } diff --git a/core/src/main/scala/org/apache/spark/status/ElementTrackingStore.scala b/core/src/main/scala/org/apache/spark/status/ElementTrackingStore.scala index 863b0967f765e..5ec7d90bfaaba 100644 --- a/core/src/main/scala/org/apache/spark/status/ElementTrackingStore.scala +++ b/core/src/main/scala/org/apache/spark/status/ElementTrackingStore.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.{HashMap, ListBuffer} import com.google.common.util.concurrent.MoreExecutors import org.apache.spark.SparkConf +import org.apache.spark.internal.config.Status._ import org.apache.spark.util.{ThreadUtils, Utils} import org.apache.spark.util.kvstore._ @@ -45,8 +46,6 @@ import org.apache.spark.util.kvstore._ */ private[spark] class ElementTrackingStore(store: KVStore, conf: SparkConf) extends KVStore { - import config._ - private val triggers = new HashMap[Class[_], Seq[Trigger[_]]]() private val flushTriggers = new ListBuffer[() => Unit]() private val executor = if (conf.get(ASYNC_TRACKING_ENABLED)) { diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 0d06b02e74e34..df274d949bae3 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -26,8 +26,8 @@ import scala.util.{Random, Try} import com.esotericsoftware.kryo.Kryo -import org.apache.spark.deploy.history.config._ import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.History._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.serializer.{JavaSerializer, KryoRegistrator, KryoSerializer} import org.apache.spark.util.{ResetSystemProperties, RpcUtils} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 6a761d43a5a68..87d585278a747 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -39,8 +39,8 @@ import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} -import org.apache.spark.deploy.history.config._ import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.History._ import org.apache.spark.io._ import org.apache.spark.scheduler._ import org.apache.spark.security.GroupMappingServiceProvider diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerDiskManagerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerDiskManagerSuite.scala index 4b1b921582e00..341a1e2443df0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerDiskManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerDiskManagerSuite.scala @@ -25,14 +25,13 @@ import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config.History._ import org.apache.spark.status.KVUtils import org.apache.spark.util.{ManualClock, Utils} import org.apache.spark.util.kvstore.KVStore class HistoryServerDiskManagerSuite extends SparkFunSuite with BeforeAndAfter { - import config._ - private val MAX_USAGE = 3L private var testDir: File = _ diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 11a2db81f7c6d..7c9f8aba17f3c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -45,7 +45,7 @@ import org.scalatest.mockito.MockitoSugar import org.scalatest.selenium.WebBrowser import org.apache.spark._ -import org.apache.spark.deploy.history.config._ +import org.apache.spark.internal.config.History._ import org.apache.spark.status.api.v1.ApplicationInfo import org.apache.spark.status.api.v1.JobData import org.apache.spark.ui.SparkUI diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 0b2bbd2fa8a78..bfd73069fbff8 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -29,6 +29,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} +import org.apache.spark.internal.config.Status._ import org.apache.spark.metrics.ExecutorMetricType import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster._ @@ -38,8 +39,6 @@ import org.apache.spark.util.Utils class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { - import config._ - private val conf = new SparkConf() .set(LIVE_ENTITY_UPDATE_PERIOD, 0L) .set(ASYNC_TRACKING_ENABLED, false) diff --git a/core/src/test/scala/org/apache/spark/status/ElementTrackingStoreSuite.scala b/core/src/test/scala/org/apache/spark/status/ElementTrackingStoreSuite.scala index 07a7b58404c29..a99c1ec7e1f07 100644 --- a/core/src/test/scala/org/apache/spark/status/ElementTrackingStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/ElementTrackingStoreSuite.scala @@ -20,12 +20,11 @@ package org.apache.spark.status import org.mockito.Mockito._ import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config.Status._ import org.apache.spark.util.kvstore._ class ElementTrackingStoreSuite extends SparkFunSuite { - import config._ - test("tracking for multiple types") { val store = mock(classOf[KVStore]) val tracking = new ElementTrackingStore(store, new SparkConf() diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 6044563f7dde7..2945c3ee0a9d9 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -26,10 +26,10 @@ import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} import org.apache.spark._ import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.config.Status._ import org.apache.spark.scheduler._ import org.apache.spark.status.AppStatusStore import org.apache.spark.status.api.v1.{AccumulableInfo => UIAccumulableInfo, StageData, StageStatus} -import org.apache.spark.status.config._ import org.apache.spark.ui.jobs.{ApiHelper, StagePage, StagesTab, TaskPagedTable} class StagePageSuite extends SparkFunSuite with LocalSparkContext { diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index e86cadfeebcff..b04b065f9ecb5 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -40,9 +40,9 @@ import org.apache.spark.LocalSparkContext._ import org.apache.spark.api.java.StorageLevels import org.apache.spark.deploy.history.HistoryServerSuite import org.apache.spark.internal.config.MEMORY_OFFHEAP_SIZE +import org.apache.spark.internal.config.Status._ import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.status.api.v1.{JacksonMessageWriter, RDDDataDistribution, StageStatus} -import org.apache.spark.status.config._ private[spark] class SparkUICssErrorHandler extends DefaultCssErrorHandler { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 8ef1e18f83de3..634460686bb2b 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -37,7 +37,7 @@ import org.apache.mesos.protobuf.{ByteString, GeneratedMessageV3} import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.TaskState import org.apache.spark.internal.Logging -import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.{Status => _, _} import org.apache.spark.util.Utils /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 1199eeca959d5..6978ec3a85715 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -24,12 +24,12 @@ import scala.collection.JavaConverters._ import org.apache.spark.{JobExecutionStatus, SparkConf} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.Status._ import org.apache.spark.scheduler._ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.metric._ import org.apache.spark.sql.internal.StaticSQLConf._ import org.apache.spark.status.{ElementTrackingStore, KVUtils, LiveEntity} -import org.apache.spark.status.config._ class SQLAppStatusListener( conf: SparkConf, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index 02df45d1b7989..d79c0cf5e1c2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -26,6 +26,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark._ import org.apache.spark.LocalSparkContext._ import org.apache.spark.internal.config +import org.apache.spark.internal.config.Status._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SparkSession} @@ -38,7 +39,6 @@ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.internal.StaticSQLConf.UI_RETAINED_EXECUTIONS import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.status.ElementTrackingStore -import org.apache.spark.status.config._ import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator} import org.apache.spark.util.kvstore.InMemoryStore From cc38abc27a671f345e3b4c170977a1976a02a0d0 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 6 Nov 2018 10:39:58 +0800 Subject: [PATCH 2002/2461] [SPARK-25906][SHELL] Documents '-I' option (from Scala REPL) in spark-shell ## What changes were proposed in this pull request? This PR targets to document `-I` option from Spark 2.4.x (previously `-i` option until Spark 2.3.x). After we upgraded Scala to 2.11.12, `-i` option (`:load`) was replaced to `-I`(SI-7898). Existing `-i` became `:paste` which does not respect Spark's implicit import (for instance `toDF`, symbol as column, etc.). Therefore, `-i` option does not correctly from Spark 2.4.x and it's not documented. I checked other Scala REPL options but looks not applicable or working from quick tests. This PR only targets to document `-I` for now. ## How was this patch tested? Manually tested. **Mac:** ```bash $ ./bin/spark-shell --help Usage: ./bin/spark-shell [options] Scala REPL options: -I preload , enforcing line-by-line interpretation Options: --master MASTER_URL spark://host:port, mesos://host:port, yarn, k8s://https://host:port, or local (Default: local[*]). --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or on one of the worker machines inside the cluster ("cluster") (Default: client). ... ``` **Windows:** ```cmd C:\...\spark>.\bin\spark-shell --help Usage: .\bin\spark-shell.cmd [options] Scala REPL options: -I preload , enforcing line-by-line interpretation Options: --master MASTER_URL spark://host:port, mesos://host:port, yarn, k8s://https://host:port, or local (Default: local[*]). --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or on one of the worker machines inside the cluster ("cluster") (Default: client). ... ``` Closes #22919 from HyukjinKwon/SPARK-25906. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- bin/spark-shell | 5 ++++- bin/spark-shell2.cmd | 8 +++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/bin/spark-shell b/bin/spark-shell index 421f36cac3d47..e920137974980 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -32,7 +32,10 @@ if [ -z "${SPARK_HOME}" ]; then source "$(dirname "$0")"/find-spark-home fi -export _SPARK_CMD_USAGE="Usage: ./bin/spark-shell [options]" +export _SPARK_CMD_USAGE="Usage: ./bin/spark-shell [options] + +Scala REPL options: + -I preload , enforcing line-by-line interpretation" # SPARK-4161: scala does not assume use of the java classpath, # so we need to add the "-Dscala.usejavacp=true" flag manually. We diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd index aaf71906c6526..549bf43bb6078 100644 --- a/bin/spark-shell2.cmd +++ b/bin/spark-shell2.cmd @@ -20,7 +20,13 @@ rem rem Figure out where the Spark framework is installed call "%~dp0find-spark-home.cmd" -set _SPARK_CMD_USAGE=Usage: .\bin\spark-shell.cmd [options] +set LF=^ + + +rem two empty lines are required +set _SPARK_CMD_USAGE=Usage: .\bin\spark-shell.cmd [options]^%LF%%LF%^%LF%%LF%^ +Scala REPL options:^%LF%%LF%^ + -I ^ preload ^, enforcing line-by-line interpretation rem SPARK-4161: scala does not assume use of the java classpath, rem so we need to add the "-Dscala.usejavacp=true" flag manually. We From 3ed91c9b8998f2512716f906cd1cba25578111ff Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 6 Nov 2018 05:38:59 +0000 Subject: [PATCH 2003/2461] [SPARK-25946][BUILD] Upgrade ASM to 7.x to support JDK11 ## What changes were proposed in this pull request? Upgrade ASM to 7.x to support JDK11 ## How was this patch tested? Existing tests. Closes #22953 from dbtsai/asm7. Authored-by: DB Tsai Signed-off-by: DB Tsai --- core/pom.xml | 2 +- .../org/apache/spark/util/ClosureCleaner.scala | 18 +++++++++--------- dev/deps/spark-deps-hadoop-2.7 | 2 +- dev/deps/spark-deps-hadoop-3.1 | 2 +- graphx/pom.xml | 2 +- .../spark/graphx/util/BytecodeUtils.scala | 8 ++++---- pom.xml | 8 ++++---- repl/pom.xml | 2 +- .../spark/repl/ExecutorClassLoader.scala | 6 +++--- sql/core/pom.xml | 2 +- 10 files changed, 26 insertions(+), 26 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index f23d09f73657b..5c26f9a5ea3c6 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -56,7 +56,7 @@ org.apache.xbean - xbean-asm6-shaded + xbean-asm7-shaded org.apache.hadoop diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 6c4740c002103..1b3e525644f00 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -23,8 +23,8 @@ import java.lang.invoke.SerializedLambda import scala.collection.mutable.{Map, Set, Stack} import scala.language.existentials -import org.apache.xbean.asm6.{ClassReader, ClassVisitor, MethodVisitor, Type} -import org.apache.xbean.asm6.Opcodes._ +import org.apache.xbean.asm7.{ClassReader, ClassVisitor, MethodVisitor, Type} +import org.apache.xbean.asm7.Opcodes._ import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.internal.Logging @@ -424,7 +424,7 @@ private[spark] class ReturnStatementInClosureException extends SparkException("Return statements aren't allowed in Spark closures") private class ReturnStatementFinder(targetMethodName: Option[String] = None) - extends ClassVisitor(ASM6) { + extends ClassVisitor(ASM7) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { @@ -438,7 +438,7 @@ private class ReturnStatementFinder(targetMethodName: Option[String] = None) val isTargetMethod = targetMethodName.isEmpty || name == targetMethodName.get || name == targetMethodName.get.stripSuffix("$adapted") - new MethodVisitor(ASM6) { + new MethodVisitor(ASM7) { override def visitTypeInsn(op: Int, tp: String) { if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl") && isTargetMethod) { throw new ReturnStatementInClosureException @@ -446,7 +446,7 @@ private class ReturnStatementFinder(targetMethodName: Option[String] = None) } } } else { - new MethodVisitor(ASM6) {} + new MethodVisitor(ASM7) {} } } } @@ -470,7 +470,7 @@ private[util] class FieldAccessFinder( findTransitively: Boolean, specificMethod: Option[MethodIdentifier[_]] = None, visitedMethods: Set[MethodIdentifier[_]] = Set.empty) - extends ClassVisitor(ASM6) { + extends ClassVisitor(ASM7) { override def visitMethod( access: Int, @@ -485,7 +485,7 @@ private[util] class FieldAccessFinder( return null } - new MethodVisitor(ASM6) { + new MethodVisitor(ASM7) { override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) { if (op == GETFIELD) { for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) { @@ -525,7 +525,7 @@ private[util] class FieldAccessFinder( } } -private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM6) { +private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM7) { var myName: String = null // TODO: Recursively find inner closures that we indirectly reference, e.g. @@ -540,7 +540,7 @@ private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - new MethodVisitor(ASM6) { + new MethodVisitor(ASM7) { override def visitMethodInsn( op: Int, owner: String, name: String, desc: String, itf: Boolean) { val argTypes = Type.getArgumentTypes(desc) diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index db84b85618d8f..15a570908cc9a 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -190,7 +190,7 @@ stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-2.7.3.jar validation-api-1.1.0.Final.jar -xbean-asm6-shaded-4.8.jar +xbean-asm7-shaded-4.12.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.5.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index befb93da94887..6d9191a4abb4c 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -211,7 +211,7 @@ token-provider-1.0.1.jar univocity-parsers-2.7.3.jar validation-api-1.1.0.Final.jar woodstox-core-5.0.3.jar -xbean-asm6-shaded-4.8.jar +xbean-asm7-shaded-4.12.jar xz-1.5.jar zjsonpatch-0.3.0.jar zookeeper-3.4.9.jar diff --git a/graphx/pom.xml b/graphx/pom.xml index d65a8ceb62b9b..22bc148e068a5 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -53,7 +53,7 @@ org.apache.xbean - xbean-asm6-shaded + xbean-asm7-shaded com.google.guava diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index 50b03f71379a1..4ea09ec91d3a8 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -22,8 +22,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable.HashSet import scala.language.existentials -import org.apache.xbean.asm6.{ClassReader, ClassVisitor, MethodVisitor} -import org.apache.xbean.asm6.Opcodes._ +import org.apache.xbean.asm7.{ClassReader, ClassVisitor, MethodVisitor} +import org.apache.xbean.asm7.Opcodes._ import org.apache.spark.util.Utils @@ -109,14 +109,14 @@ private[graphx] object BytecodeUtils { * determine the actual method invoked by inspecting the bytecode. */ private class MethodInvocationFinder(className: String, methodName: String) - extends ClassVisitor(ASM6) { + extends ClassVisitor(ASM7) { val methodsInvoked = new HashSet[(Class[_], String)] override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { if (name == methodName) { - new MethodVisitor(ASM6) { + new MethodVisitor(ASM7) { override def visitMethodInsn( op: Int, owner: String, name: String, desc: String, itf: Boolean) { if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) { diff --git a/pom.xml b/pom.xml index 597fb2fa1abd6..a08b7fda33387 100644 --- a/pom.xml +++ b/pom.xml @@ -311,13 +311,13 @@ chill-java ${chill.version} - org.apache.xbean - xbean-asm6-shaded - 4.8 + xbean-asm7-shaded + 4.12 diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 88eb0ad1da3d7..3176502b9e7ce 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -22,8 +22,8 @@ import java.net.{URI, URL, URLEncoder} import java.nio.channels.Channels import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.xbean.asm6._ -import org.apache.xbean.asm6.Opcodes._ +import org.apache.xbean.asm7._ +import org.apache.xbean.asm7.Opcodes._ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil @@ -187,7 +187,7 @@ class ExecutorClassLoader( } class ConstructorCleaner(className: String, cv: ClassVisitor) -extends ClassVisitor(ASM6, cv) { +extends ClassVisitor(ASM7, cv) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { val mv = cv.visitMethod(access, name, desc, sig, exceptions) diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 2f72ff6cfdbfb..95e98c5444721 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -118,7 +118,7 @@ org.apache.xbean - xbean-asm6-shaded + xbean-asm7-shaded org.scalacheck From fdd3bace1da01e5958fe0345c38e889e740ce25e Mon Sep 17 00:00:00 2001 From: Dhruve Ashar Date: Tue, 6 Nov 2018 08:25:32 -0600 Subject: [PATCH 2004/2461] [SPARK-22148][SPARK-15815][SCHEDULER] Acquire new executors to avoid hang because of blacklisting ## What changes were proposed in this pull request? Every time a task is unschedulable because of the condition where no. of task failures < no. of executors available, we currently abort the taskSet - failing the job. This change tries to acquire new executors so that we can complete the job successfully. We try to acquire a new executor only when we can kill an existing idle executor. We fallback to the older implementation where we abort the job if we cannot find an idle executor. ## How was this patch tested? I performed some manual tests to check and validate the behavior. ```scala val rdd = sc.parallelize(Seq(1 to 10), 3) import org.apache.spark.TaskContext val mapped = rdd.mapPartitionsWithIndex ( (index, iterator) => { if (index == 2) { Thread.sleep(30 * 1000); val attemptNum = TaskContext.get.attemptNumber; if (attemptNum < 3) throw new Exception("Fail for blacklisting")}; iterator.toList.map (x => x + " -> " + index).iterator } ) mapped.collect ``` Closes #22288 from dhruve/bug/SPARK-22148. Lead-authored-by: Dhruve Ashar Co-authored-by: Dhruve Ashar Co-authored-by: Tom Graves Signed-off-by: Thomas Graves --- .../spark/internal/config/package.scala | 8 + .../spark/scheduler/BlacklistTracker.scala | 30 ++- .../spark/scheduler/TaskSchedulerImpl.scala | 71 ++++++- .../spark/scheduler/TaskSetManager.scala | 41 ++-- .../scheduler/BlacklistIntegrationSuite.scala | 7 +- .../scheduler/TaskSchedulerImplSuite.scala | 189 +++++++++++++++++- docs/configuration.md | 8 + 7 files changed, 318 insertions(+), 36 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index c8993e17bba67..2b3ba3c7daccb 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -622,6 +622,14 @@ package object config { .checkValue(v => v > 0, "The value should be a positive time value.") .createWithDefaultString("365d") + private[spark] val UNSCHEDULABLE_TASKSET_TIMEOUT = + ConfigBuilder("spark.scheduler.blacklist.unschedulableTaskSetTimeout") + .doc("The timeout in seconds to wait to acquire a new executor and schedule a task " + + "before aborting a TaskSet which is unschedulable because of being completely blacklisted.") + .timeConf(TimeUnit.SECONDS) + .checkValue(v => v >= 0, "The value should be a non negative time value.") + .createWithDefault(120) + private[spark] val BARRIER_MAX_CONCURRENT_TASKS_CHECK_INTERVAL = ConfigBuilder("spark.scheduler.barrier.maxConcurrentTasksCheck.interval") .doc("Time in seconds to wait between a max concurrent tasks check failure and the next " + diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index 980fbbe516b91..ef6d02d85c27b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -146,21 +146,31 @@ private[scheduler] class BlacklistTracker ( nextExpiryTime = math.min(execMinExpiry, nodeMinExpiry) } + private def killExecutor(exec: String, msg: String): Unit = { + allocationClient match { + case Some(a) => + logInfo(msg) + a.killExecutors(Seq(exec), adjustTargetNumExecutors = false, countFailures = false, + force = true) + case None => + logInfo(s"Not attempting to kill blacklisted executor id $exec " + + s"since allocation client is not defined.") + } + } + private def killBlacklistedExecutor(exec: String): Unit = { if (conf.get(config.BLACKLIST_KILL_ENABLED)) { - allocationClient match { - case Some(a) => - logInfo(s"Killing blacklisted executor id $exec " + - s"since ${config.BLACKLIST_KILL_ENABLED.key} is set.") - a.killExecutors(Seq(exec), adjustTargetNumExecutors = false, countFailures = false, - force = true) - case None => - logWarning(s"Not attempting to kill blacklisted executor id $exec " + - s"since allocation client is not defined.") - } + killExecutor(exec, + s"Killing blacklisted executor id $exec since ${config.BLACKLIST_KILL_ENABLED.key} is set.") } } + private[scheduler] def killBlacklistedIdleExecutor(exec: String): Unit = { + killExecutor(exec, + s"Killing blacklisted idle executor id $exec because of task unschedulability and trying " + + "to acquire a new executor.") + } + private def killExecutorsOnBlacklistedNode(node: String): Unit = { if (conf.get(config.BLACKLIST_KILL_ENABLED)) { allocationClient match { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 4f870e85ad38d..61556ea642614 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -35,7 +35,7 @@ import org.apache.spark.rpc.RpcEndpoint import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils} +import org.apache.spark.util.{AccumulatorV2, SystemClock, ThreadUtils, Utils} /** * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend. @@ -117,6 +117,11 @@ private[spark] class TaskSchedulerImpl( protected val executorIdToHost = new HashMap[String, String] + private val abortTimer = new Timer(true) + private val clock = new SystemClock + // Exposed for testing + val unschedulableTaskSetToExpiryTime = new HashMap[TaskSetManager, Long] + // Listener object to pass upcalls into var dagScheduler: DAGScheduler = null @@ -415,9 +420,53 @@ private[spark] class TaskSchedulerImpl( launchedAnyTask |= launchedTaskAtCurrentMaxLocality } while (launchedTaskAtCurrentMaxLocality) } + if (!launchedAnyTask) { - taskSet.abortIfCompletelyBlacklisted(hostToExecutors) + taskSet.getCompletelyBlacklistedTaskIfAny(hostToExecutors).foreach { taskIndex => + // If the taskSet is unschedulable we try to find an existing idle blacklisted + // executor. If we cannot find one, we abort immediately. Else we kill the idle + // executor and kick off an abortTimer which if it doesn't schedule a task within the + // the timeout will abort the taskSet if we were unable to schedule any task from the + // taskSet. + // Note 1: We keep track of schedulability on a per taskSet basis rather than on a per + // task basis. + // Note 2: The taskSet can still be aborted when there are more than one idle + // blacklisted executors and dynamic allocation is on. This can happen when a killed + // idle executor isn't replaced in time by ExecutorAllocationManager as it relies on + // pending tasks and doesn't kill executors on idle timeouts, resulting in the abort + // timer to expire and abort the taskSet. + executorIdToRunningTaskIds.find(x => !isExecutorBusy(x._1)) match { + case Some ((executorId, _)) => + if (!unschedulableTaskSetToExpiryTime.contains(taskSet)) { + blacklistTrackerOpt.foreach(blt => blt.killBlacklistedIdleExecutor(executorId)) + + val timeout = conf.get(config.UNSCHEDULABLE_TASKSET_TIMEOUT) * 1000 + unschedulableTaskSetToExpiryTime(taskSet) = clock.getTimeMillis() + timeout + logInfo(s"Waiting for $timeout ms for completely " + + s"blacklisted task to be schedulable again before aborting $taskSet.") + abortTimer.schedule( + createUnschedulableTaskSetAbortTimer(taskSet, taskIndex), timeout) + } + case None => // Abort Immediately + logInfo("Cannot schedule any task because of complete blacklisting. No idle" + + s" executors can be found to kill. Aborting $taskSet." ) + taskSet.abortSinceCompletelyBlacklisted(taskIndex) + } + } + } else { + // We want to defer killing any taskSets as long as we have a non blacklisted executor + // which can be used to schedule a task from any active taskSets. This ensures that the + // job can make progress. + // Note: It is theoretically possible that a taskSet never gets scheduled on a + // non-blacklisted executor and the abort timer doesn't kick in because of a constant + // submission of new TaskSets. See the PR for more details. + if (unschedulableTaskSetToExpiryTime.nonEmpty) { + logInfo("Clearing the expiry times for all unschedulable taskSets as a task was " + + "recently scheduled.") + unschedulableTaskSetToExpiryTime.clear() + } } + if (launchedAnyTask && taskSet.isBarrier) { // Check whether the barrier tasks are partially launched. // TODO SPARK-24818 handle the assert failure case (that can happen when some locality @@ -453,6 +502,23 @@ private[spark] class TaskSchedulerImpl( return tasks } + private def createUnschedulableTaskSetAbortTimer( + taskSet: TaskSetManager, + taskIndex: Int): TimerTask = { + new TimerTask() { + override def run() { + if (unschedulableTaskSetToExpiryTime.contains(taskSet) && + unschedulableTaskSetToExpiryTime(taskSet) <= clock.getTimeMillis()) { + logInfo("Cannot schedule any task because of complete blacklisting. " + + s"Wait time for scheduling expired. Aborting $taskSet.") + taskSet.abortSinceCompletelyBlacklisted(taskIndex) + } else { + this.cancel() + } + } + } + } + /** * Shuffle offers around to avoid always placing tasks on the same workers. Exposed to allow * overriding in tests, so it can be deterministic. @@ -590,6 +656,7 @@ private[spark] class TaskSchedulerImpl( barrierCoordinator.stop() } starvationTimer.cancel() + abortTimer.cancel() } override def defaultParallelism(): Int = backend.defaultParallelism() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index d5e85a11cb279..6bf60dd8e9dfa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -623,8 +623,8 @@ private[spark] class TaskSetManager( * * It is possible that this taskset has become impossible to schedule *anywhere* due to the * blacklist. The most common scenario would be if there are fewer executors than - * spark.task.maxFailures. We need to detect this so we can fail the task set, otherwise the job - * will hang. + * spark.task.maxFailures. We need to detect this so we can avoid the job from being hung. + * We try to acquire new executor/s by killing an existing idle blacklisted executor. * * There's a tradeoff here: we could make sure all tasks in the task set are schedulable, but that * would add extra time to each iteration of the scheduling loop. Here, we take the approach of @@ -635,9 +635,9 @@ private[spark] class TaskSetManager( * failures (this is because the method picks one unscheduled task, and then iterates through each * executor until it finds one that the task isn't blacklisted on). */ - private[scheduler] def abortIfCompletelyBlacklisted( - hostToExecutors: HashMap[String, HashSet[String]]): Unit = { - taskSetBlacklistHelperOpt.foreach { taskSetBlacklist => + private[scheduler] def getCompletelyBlacklistedTaskIfAny( + hostToExecutors: HashMap[String, HashSet[String]]): Option[Int] = { + taskSetBlacklistHelperOpt.flatMap { taskSetBlacklist => val appBlacklist = blacklistTracker.get // Only look for unschedulable tasks when at least one executor has registered. Otherwise, // task sets will be (unnecessarily) aborted in cases when no executors have registered yet. @@ -658,11 +658,11 @@ private[spark] class TaskSetManager( } } - pendingTask.foreach { indexInTaskSet => + pendingTask.find { indexInTaskSet => // try to find some executor this task can run on. Its possible that some *other* // task isn't schedulable anywhere, but we will discover that in some later call, // when that unschedulable task is the last task remaining. - val blacklistedEverywhere = hostToExecutors.forall { case (host, execsOnHost) => + hostToExecutors.forall { case (host, execsOnHost) => // Check if the task can run on the node val nodeBlacklisted = appBlacklist.isNodeBlacklisted(host) || @@ -679,22 +679,27 @@ private[spark] class TaskSetManager( } } } - if (blacklistedEverywhere) { - val partition = tasks(indexInTaskSet).partitionId - abort(s""" - |Aborting $taskSet because task $indexInTaskSet (partition $partition) - |cannot run anywhere due to node and executor blacklist. - |Most recent failure: - |${taskSetBlacklist.getLatestFailureReason} - | - |Blacklisting behavior can be configured via spark.blacklist.*. - |""".stripMargin) - } } + } else { + None } } } + private[scheduler] def abortSinceCompletelyBlacklisted(indexInTaskSet: Int): Unit = { + taskSetBlacklistHelperOpt.foreach { taskSetBlacklist => + val partition = tasks(indexInTaskSet).partitionId + abort(s""" + |Aborting $taskSet because task $indexInTaskSet (partition $partition) + |cannot run anywhere due to node and executor blacklist. + |Most recent failure: + |${taskSetBlacklist.getLatestFailureReason} + | + |Blacklisting behavior can be configured via spark.blacklist.*. + |""".stripMargin) + } + } + /** * Marks the task as getting result and notifies the DAG Scheduler */ diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala index fe22d70850c7d..29bb8232f44f5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala @@ -96,15 +96,16 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM assertDataStructuresEmpty(noFailure = true) } - // Make sure that if we've failed on all executors, but haven't hit task.maxFailures yet, the job - // doesn't hang + // Make sure that if we've failed on all executors, but haven't hit task.maxFailures yet, we try + // to acquire a new executor and if we aren't able to get one, the job doesn't hang and we abort testScheduler( "SPARK-15865 Progress with fewer executors than maxTaskFailures", extraConfs = Seq( config.BLACKLIST_ENABLED.key -> "true", "spark.testing.nHosts" -> "2", "spark.testing.nExecutorsPerHost" -> "1", - "spark.testing.nCoresPerExecutor" -> "1" + "spark.testing.nCoresPerExecutor" -> "1", + "spark.scheduler.blacklist.unschedulableTaskSetTimeout" -> "0s" ) ) { def runBackend(): Unit = { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 9e1d13e369ad9..29172b4664e32 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -20,10 +20,12 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer import scala.collection.mutable.HashMap +import scala.concurrent.duration._ import org.mockito.Matchers.{anyInt, anyObject, anyString, eq => meq} import org.mockito.Mockito.{atLeast, atMost, never, spy, times, verify, when} import org.scalatest.BeforeAndAfterEach +import org.scalatest.concurrent.Eventually import org.scalatest.mockito.MockitoSugar import org.apache.spark._ @@ -40,7 +42,7 @@ class FakeSchedulerBackend extends SchedulerBackend { } class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfterEach - with Logging with MockitoSugar { + with Logging with MockitoSugar with Eventually { var failedTaskSetException: Option[Throwable] = None var failedTaskSetReason: String = null @@ -82,10 +84,12 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B setupHelper() } - def setupSchedulerWithMockTaskSetBlacklist(): TaskSchedulerImpl = { + def setupSchedulerWithMockTaskSetBlacklist(confs: (String, String)*): TaskSchedulerImpl = { blacklist = mock[BlacklistTracker] val conf = new SparkConf().setMaster("local").setAppName("TaskSchedulerImplSuite") conf.set(config.BLACKLIST_ENABLED, true) + confs.foreach { case (k, v) => conf.set(k, v) } + sc = new SparkContext(conf) taskScheduler = new TaskSchedulerImpl(sc, sc.conf.getInt("spark.task.maxFailures", 4)) { @@ -466,7 +470,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B } } - test("abort stage when all executors are blacklisted") { + test("abort stage when all executors are blacklisted and we cannot acquire new executor") { taskScheduler = setupSchedulerWithMockTaskSetBlacklist() val taskSet = FakeTask.createTaskSet(numTasks = 10, stageAttemptId = 0) taskScheduler.submitTasks(taskSet) @@ -503,6 +507,185 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B verify(tsm).abort(anyString(), anyObject()) } + test("SPARK-22148 abort timer should kick in when task is completely blacklisted & no new " + + "executor can be acquired") { + // set the abort timer to fail immediately + taskScheduler = setupSchedulerWithMockTaskSetBlacklist( + config.UNSCHEDULABLE_TASKSET_TIMEOUT.key -> "0") + + // We have only 1 task remaining with 1 executor + val taskSet = FakeTask.createTaskSet(numTasks = 1, stageAttemptId = 0) + taskScheduler.submitTasks(taskSet) + val tsm = stageToMockTaskSetManager(0) + + // submit an offer with one executor + val firstTaskAttempts = taskScheduler.resourceOffers(IndexedSeq( + WorkerOffer("executor0", "host0", 1) + )).flatten + + // Fail the running task + val failedTask = firstTaskAttempts.find(_.executorId == "executor0").get + taskScheduler.statusUpdate(failedTask.taskId, TaskState.FAILED, ByteBuffer.allocate(0)) + // we explicitly call the handleFailedTask method here to avoid adding a sleep in the test suite + // Reason being - handleFailedTask is run by an executor service and there is a momentary delay + // before it is launched and this fails the assertion check. + tsm.handleFailedTask(failedTask.taskId, TaskState.FAILED, UnknownReason) + when(tsm.taskSetBlacklistHelperOpt.get.isExecutorBlacklistedForTask( + "executor0", failedTask.index)).thenReturn(true) + + // make an offer on the blacklisted executor. We won't schedule anything, and set the abort + // timer to kick in immediately + assert(taskScheduler.resourceOffers(IndexedSeq( + WorkerOffer("executor0", "host0", 1) + )).flatten.size === 0) + // Wait for the abort timer to kick in. Even though we configure the timeout to be 0, there is a + // slight delay as the abort timer is launched in a separate thread. + eventually(timeout(500.milliseconds)) { + assert(tsm.isZombie) + } + } + + test("SPARK-22148 try to acquire a new executor when task is unschedulable with 1 executor") { + taskScheduler = setupSchedulerWithMockTaskSetBlacklist( + config.UNSCHEDULABLE_TASKSET_TIMEOUT.key -> "10") + + // We have only 1 task remaining with 1 executor + val taskSet = FakeTask.createTaskSet(numTasks = 1, stageAttemptId = 0) + taskScheduler.submitTasks(taskSet) + val tsm = stageToMockTaskSetManager(0) + + // submit an offer with one executor + val firstTaskAttempts = taskScheduler.resourceOffers(IndexedSeq( + WorkerOffer("executor0", "host0", 1) + )).flatten + + // Fail the running task + val failedTask = firstTaskAttempts.head + taskScheduler.statusUpdate(failedTask.taskId, TaskState.FAILED, ByteBuffer.allocate(0)) + // we explicitly call the handleFailedTask method here to avoid adding a sleep in the test suite + // Reason being - handleFailedTask is run by an executor service and there is a momentary delay + // before it is launched and this fails the assertion check. + tsm.handleFailedTask(failedTask.taskId, TaskState.FAILED, UnknownReason) + when(tsm.taskSetBlacklistHelperOpt.get.isExecutorBlacklistedForTask( + "executor0", failedTask.index)).thenReturn(true) + + // make an offer on the blacklisted executor. We won't schedule anything, and set the abort + // timer to expire if no new executors could be acquired. We kill the existing idle blacklisted + // executor and try to acquire a new one. + assert(taskScheduler.resourceOffers(IndexedSeq( + WorkerOffer("executor0", "host0", 1) + )).flatten.size === 0) + assert(taskScheduler.unschedulableTaskSetToExpiryTime.contains(tsm)) + assert(!tsm.isZombie) + + // Offer a new executor which should be accepted + assert(taskScheduler.resourceOffers(IndexedSeq( + WorkerOffer("executor1", "host0", 1) + )).flatten.size === 1) + assert(taskScheduler.unschedulableTaskSetToExpiryTime.isEmpty) + assert(!tsm.isZombie) + } + + // This is to test a scenario where we have two taskSets completely blacklisted and on acquiring + // a new executor we don't want the abort timer for the second taskSet to expire and abort the job + test("SPARK-22148 abort timer should clear unschedulableTaskSetToExpiryTime for all TaskSets") { + taskScheduler = setupSchedulerWithMockTaskSetBlacklist() + + // We have 2 taskSets with 1 task remaining in each with 1 executor completely blacklisted + val taskSet1 = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0) + taskScheduler.submitTasks(taskSet1) + val taskSet2 = FakeTask.createTaskSet(numTasks = 1, stageId = 1, stageAttemptId = 0) + taskScheduler.submitTasks(taskSet2) + val tsm = stageToMockTaskSetManager(0) + + // submit an offer with one executor + val firstTaskAttempts = taskScheduler.resourceOffers(IndexedSeq( + WorkerOffer("executor0", "host0", 1) + )).flatten + + assert(taskScheduler.unschedulableTaskSetToExpiryTime.isEmpty) + + // Fail the running task + val failedTask = firstTaskAttempts.head + taskScheduler.statusUpdate(failedTask.taskId, TaskState.FAILED, ByteBuffer.allocate(0)) + tsm.handleFailedTask(failedTask.taskId, TaskState.FAILED, UnknownReason) + when(tsm.taskSetBlacklistHelperOpt.get.isExecutorBlacklistedForTask( + "executor0", failedTask.index)).thenReturn(true) + + // make an offer. We will schedule the task from the second taskSet. Since a task was scheduled + // we do not kick off the abort timer for taskSet1 + val secondTaskAttempts = taskScheduler.resourceOffers(IndexedSeq( + WorkerOffer("executor0", "host0", 1) + )).flatten + + assert(taskScheduler.unschedulableTaskSetToExpiryTime.isEmpty) + + val tsm2 = stageToMockTaskSetManager(1) + val failedTask2 = secondTaskAttempts.head + taskScheduler.statusUpdate(failedTask2.taskId, TaskState.FAILED, ByteBuffer.allocate(0)) + tsm2.handleFailedTask(failedTask2.taskId, TaskState.FAILED, UnknownReason) + when(tsm2.taskSetBlacklistHelperOpt.get.isExecutorBlacklistedForTask( + "executor0", failedTask2.index)).thenReturn(true) + + // make an offer on the blacklisted executor. We won't schedule anything, and set the abort + // timer for taskSet1 and taskSet2 + assert(taskScheduler.resourceOffers(IndexedSeq( + WorkerOffer("executor0", "host0", 1) + )).flatten.size === 0) + assert(taskScheduler.unschedulableTaskSetToExpiryTime.contains(tsm)) + assert(taskScheduler.unschedulableTaskSetToExpiryTime.contains(tsm2)) + assert(taskScheduler.unschedulableTaskSetToExpiryTime.size == 2) + + // Offer a new executor which should be accepted + assert(taskScheduler.resourceOffers(IndexedSeq( + WorkerOffer("executor1", "host1", 1) + )).flatten.size === 1) + + // Check if all the taskSets are cleared + assert(taskScheduler.unschedulableTaskSetToExpiryTime.isEmpty) + + assert(!tsm.isZombie) + } + + // this test is to check that we don't abort a taskSet which is not being scheduled on other + // executors as it is waiting on locality timeout and not being aborted because it is still not + // completely blacklisted. + test("SPARK-22148 Ensure we don't abort the taskSet if we haven't been completely blacklisted") { + taskScheduler = setupSchedulerWithMockTaskSetBlacklist( + config.UNSCHEDULABLE_TASKSET_TIMEOUT.key -> "0", + // This is to avoid any potential flakiness in the test because of large pauses in jenkins + config.LOCALITY_WAIT.key -> "30s" + ) + + val preferredLocation = Seq(ExecutorCacheTaskLocation("host0", "executor0")) + val taskSet1 = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0, + preferredLocation) + taskScheduler.submitTasks(taskSet1) + + val tsm = stageToMockTaskSetManager(0) + + // submit an offer with one executor + var taskAttempts = taskScheduler.resourceOffers(IndexedSeq( + WorkerOffer("executor0", "host0", 1) + )).flatten + + // Fail the running task + val failedTask = taskAttempts.head + taskScheduler.statusUpdate(failedTask.taskId, TaskState.FAILED, ByteBuffer.allocate(0)) + tsm.handleFailedTask(failedTask.taskId, TaskState.FAILED, UnknownReason) + when(tsm.taskSetBlacklistHelperOpt.get.isExecutorBlacklistedForTask( + "executor0", failedTask.index)).thenReturn(true) + + // make an offer but we won't schedule anything yet as scheduler locality is still PROCESS_LOCAL + assert(taskScheduler.resourceOffers(IndexedSeq( + WorkerOffer("executor1", "host0", 1) + )).flatten.isEmpty) + + assert(taskScheduler.unschedulableTaskSetToExpiryTime.isEmpty) + + assert(!tsm.isZombie) + } + /** * Helper for performance tests. Takes the explicitly blacklisted nodes and executors; verifies * that the blacklists are used efficiently to ensure scheduling is not O(numPendingTasks). diff --git a/docs/configuration.md b/docs/configuration.md index 11ee7a9610602..f8937b0bc61a0 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1610,6 +1610,14 @@ Apart from these, the following properties are also available, and may be useful driver using more memory. +
      + + + + {createExecutorTable(stage)} @@ -92,16 +93,7 @@ private[ui] class ExecutorTable(stage: StageData, store: AppStatusStore) { executorSummary.toSeq.sortBy(_._1).map { case (k, v) => val executor = store.asOption(store.executorSummary(k)) - + @@ -145,6 +137,11 @@ private[ui] class ExecutorTable(stage: StageData, store: AppStatusStore) { } } + + } } From 9a5fda60e532dc7203d21d5fbe385cd561906ccb Mon Sep 17 00:00:00 2001 From: Shanyu Zhao Date: Thu, 15 Nov 2018 10:30:16 -0600 Subject: [PATCH 2068/2461] [SPARK-26011][SPARK-SUBMIT] Yarn mode pyspark app without python main resource does not honor "spark.jars.packages" SparkSubmit determines pyspark app by the suffix of primary resource but Livy uses "spark-internal" as the primary resource when calling spark-submit, therefore args.isPython is set to false in SparkSubmit.scala. In Yarn mode, SparkSubmit module is responsible for resolving maven coordinates and adding them to "spark.submit.pyFiles" so that python's system path can be set correctly. The fix is to resolve maven coordinates not only when args.isPython is true, but also when primary resource is spark-internal. Tested the patch with Livy submitting pyspark app, spark-submit, pyspark with or without packages config. Signed-off-by: Shanyu Zhao Closes #23009 from shanyu/shanyu-26011. Authored-by: Shanyu Zhao Signed-off-by: Sean Owen --- core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 0fc8c9bd789e0..324f6f8894d34 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -318,7 +318,7 @@ private[spark] class SparkSubmit extends Logging { if (!StringUtils.isBlank(resolvedMavenCoordinates)) { args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates) - if (args.isPython) { + if (args.isPython || isInternal(args.primaryResource)) { args.pyFiles = mergeFileLists(args.pyFiles, resolvedMavenCoordinates) } } From 3649fe599f1aa27fea0abd61c18d3ffa275d267b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 16 Nov 2018 07:58:09 +0800 Subject: [PATCH 2069/2461] [SPARK-26035][PYTHON] Break large streaming/tests.py files into smaller files MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR continues to break down a big large file into smaller files. See https://github.com/apache/spark/pull/23021. It targets to follow https://github.com/numpy/numpy/tree/master/numpy. Basically this PR proposes to break down `pyspark/streaming/tests.py` into ...: ``` pyspark ├── __init__.py ... ├── streaming │   ├── __init__.py ... │   ├── tests │   │   ├── __init__.py │   │   ├── test_context.py │   │   ├── test_dstream.py │   │   ├── test_kinesis.py │   │   └── test_listener.py ... ├── testing ... │   ├── streamingutils.py ... ``` ## How was this patch tested? Existing tests should cover. `cd python` and .`/run-tests-with-coverage`. Manually checked they are actually being ran. Each test (not officially) can be ran via: ```bash SPARK_TESTING=1 ./bin/pyspark pyspark.tests.test_context ``` Note that if you're using Mac and Python 3, you might have to `OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES`. Closes #23034 from HyukjinKwon/SPARK-26035. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- dev/sparktestsupport/modules.py | 7 +- python/pyspark/streaming/tests/__init__.py | 16 + .../pyspark/streaming/tests/test_context.py | 184 ++++++ .../{tests.py => tests/test_dstream.py} | 575 +----------------- .../pyspark/streaming/tests/test_kinesis.py | 89 +++ .../pyspark/streaming/tests/test_listener.py | 158 +++++ python/pyspark/testing/streamingutils.py | 190 ++++++ 7 files changed, 658 insertions(+), 561 deletions(-) create mode 100644 python/pyspark/streaming/tests/__init__.py create mode 100644 python/pyspark/streaming/tests/test_context.py rename python/pyspark/streaming/{tests.py => tests/test_dstream.py} (50%) create mode 100644 python/pyspark/streaming/tests/test_kinesis.py create mode 100644 python/pyspark/streaming/tests/test_listener.py create mode 100644 python/pyspark/testing/streamingutils.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index d5fcc060616f2..58b48f43f6468 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -398,8 +398,13 @@ def __hash__(self): "python/pyspark/streaming" ], python_test_goals=[ + # doctests "pyspark.streaming.util", - "pyspark.streaming.tests", + # unittests + "pyspark.streaming.tests.test_context", + "pyspark.streaming.tests.test_dstream", + "pyspark.streaming.tests.test_kinesis", + "pyspark.streaming.tests.test_listener", ] ) diff --git a/python/pyspark/streaming/tests/__init__.py b/python/pyspark/streaming/tests/__init__.py new file mode 100644 index 0000000000000..cce3acad34a49 --- /dev/null +++ b/python/pyspark/streaming/tests/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/streaming/tests/test_context.py b/python/pyspark/streaming/tests/test_context.py new file mode 100644 index 0000000000000..b44121462a920 --- /dev/null +++ b/python/pyspark/streaming/tests/test_context.py @@ -0,0 +1,184 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import struct +import tempfile +import time + +from pyspark.streaming import StreamingContext +from pyspark.testing.streamingutils import PySparkStreamingTestCase + + +class StreamingContextTests(PySparkStreamingTestCase): + + duration = 0.1 + setupCalled = False + + def _add_input_stream(self): + inputs = [range(1, x) for x in range(101)] + stream = self.ssc.queueStream(inputs) + self._collect(stream, 1, block=False) + + def test_stop_only_streaming_context(self): + self._add_input_stream() + self.ssc.start() + self.ssc.stop(False) + self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5) + + def test_stop_multiple_times(self): + self._add_input_stream() + self.ssc.start() + self.ssc.stop(False) + self.ssc.stop(False) + + def test_queue_stream(self): + input = [list(range(i + 1)) for i in range(3)] + dstream = self.ssc.queueStream(input) + result = self._collect(dstream, 3) + self.assertEqual(input, result) + + def test_text_file_stream(self): + d = tempfile.mkdtemp() + self.ssc = StreamingContext(self.sc, self.duration) + dstream2 = self.ssc.textFileStream(d).map(int) + result = self._collect(dstream2, 2, block=False) + self.ssc.start() + for name in ('a', 'b'): + time.sleep(1) + with open(os.path.join(d, name), "w") as f: + f.writelines(["%d\n" % i for i in range(10)]) + self.wait_for(result, 2) + self.assertEqual([list(range(10)), list(range(10))], result) + + def test_binary_records_stream(self): + d = tempfile.mkdtemp() + self.ssc = StreamingContext(self.sc, self.duration) + dstream = self.ssc.binaryRecordsStream(d, 10).map( + lambda v: struct.unpack("10b", bytes(v))) + result = self._collect(dstream, 2, block=False) + self.ssc.start() + for name in ('a', 'b'): + time.sleep(1) + with open(os.path.join(d, name), "wb") as f: + f.write(bytearray(range(10))) + self.wait_for(result, 2) + self.assertEqual([list(range(10)), list(range(10))], [list(v[0]) for v in result]) + + def test_union(self): + input = [list(range(i + 1)) for i in range(3)] + dstream = self.ssc.queueStream(input) + dstream2 = self.ssc.queueStream(input) + dstream3 = self.ssc.union(dstream, dstream2) + result = self._collect(dstream3, 3) + expected = [i * 2 for i in input] + self.assertEqual(expected, result) + + def test_transform(self): + dstream1 = self.ssc.queueStream([[1]]) + dstream2 = self.ssc.queueStream([[2]]) + dstream3 = self.ssc.queueStream([[3]]) + + def func(rdds): + rdd1, rdd2, rdd3 = rdds + return rdd2.union(rdd3).union(rdd1) + + dstream = self.ssc.transform([dstream1, dstream2, dstream3], func) + + self.assertEqual([2, 3, 1], self._take(dstream, 3)) + + def test_transform_pairrdd(self): + # This regression test case is for SPARK-17756. + dstream = self.ssc.queueStream( + [[1], [2], [3]]).transform(lambda rdd: rdd.cartesian(rdd)) + self.assertEqual([(1, 1), (2, 2), (3, 3)], self._take(dstream, 3)) + + def test_get_active(self): + self.assertEqual(StreamingContext.getActive(), None) + + # Verify that getActive() returns the active context + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + + # Verify that getActive() returns None + self.ssc.stop(False) + self.assertEqual(StreamingContext.getActive(), None) + + # Verify that if the Java context is stopped, then getActive() returns None + self.ssc = StreamingContext(self.sc, self.duration) + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + self.ssc._jssc.stop(False) + self.assertEqual(StreamingContext.getActive(), None) + + def test_get_active_or_create(self): + # Test StreamingContext.getActiveOrCreate() without checkpoint data + # See CheckpointTests for tests with checkpoint data + self.ssc = None + self.assertEqual(StreamingContext.getActive(), None) + + def setupFunc(): + ssc = StreamingContext(self.sc, self.duration) + ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.setupCalled = True + return ssc + + # Verify that getActiveOrCreate() (w/o checkpoint) calls setupFunc when no context is active + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + # Verify that getActiveOrCreate() returns active context and does not call the setupFunc + self.ssc.start() + self.setupCalled = False + self.assertEqual(StreamingContext.getActiveOrCreate(None, setupFunc), self.ssc) + self.assertFalse(self.setupCalled) + + # Verify that getActiveOrCreate() calls setupFunc after active context is stopped + self.ssc.stop(False) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + # Verify that if the Java context is stopped, then getActive() returns None + self.ssc = StreamingContext(self.sc, self.duration) + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + self.ssc._jssc.stop(False) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + def test_await_termination_or_timeout(self): + self._add_input_stream() + self.ssc.start() + self.assertFalse(self.ssc.awaitTerminationOrTimeout(0.001)) + self.ssc.stop(False) + self.assertTrue(self.ssc.awaitTerminationOrTimeout(0.001)) + + +if __name__ == "__main__": + import unittest + from pyspark.streaming.tests.test_context import * + + try: + import xmlrunner + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + except ImportError: + unittest.main(verbosity=2) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests/test_dstream.py similarity index 50% rename from python/pyspark/streaming/tests.py rename to python/pyspark/streaming/tests/test_dstream.py index 8df00bc988430..d14e346b7a688 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests/test_dstream.py @@ -14,155 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -import glob -import os -import sys -from itertools import chain -import time import operator -import tempfile -import random -import struct +import os import shutil +import tempfile +import time +import unittest from functools import reduce +from itertools import chain -try: - import xmlrunner -except ImportError: - xmlrunner = None - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - -if sys.version >= "3": - long = int - -from pyspark.context import SparkConf, SparkContext, RDD -from pyspark.storagelevel import StorageLevel -from pyspark.streaming.context import StreamingContext -from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream -from pyspark.streaming.listener import StreamingListener - - -class PySparkStreamingTestCase(unittest.TestCase): - - timeout = 30 # seconds - duration = .5 - - @classmethod - def setUpClass(cls): - class_name = cls.__name__ - conf = SparkConf().set("spark.default.parallelism", 1) - cls.sc = SparkContext(appName=class_name, conf=conf) - cls.sc.setCheckpointDir(tempfile.mkdtemp()) - - @classmethod - def tearDownClass(cls): - cls.sc.stop() - # Clean up in the JVM just in case there has been some issues in Python API - try: - jSparkContextOption = SparkContext._jvm.SparkContext.get() - if jSparkContextOption.nonEmpty(): - jSparkContextOption.get().stop() - except: - pass - - def setUp(self): - self.ssc = StreamingContext(self.sc, self.duration) - - def tearDown(self): - if self.ssc is not None: - self.ssc.stop(False) - # Clean up in the JVM just in case there has been some issues in Python API - try: - jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive() - if jStreamingContextOption.nonEmpty(): - jStreamingContextOption.get().stop(False) - except: - pass - - def wait_for(self, result, n): - start_time = time.time() - while len(result) < n and time.time() - start_time < self.timeout: - time.sleep(0.01) - if len(result) < n: - print("timeout after", self.timeout) - - def _take(self, dstream, n): - """ - Return the first `n` elements in the stream (will start and stop). - """ - results = [] - - def take(_, rdd): - if rdd and len(results) < n: - results.extend(rdd.take(n - len(results))) - - dstream.foreachRDD(take) - - self.ssc.start() - self.wait_for(results, n) - return results - - def _collect(self, dstream, n, block=True): - """ - Collect each RDDs into the returned list. - - :return: list, which will have the collected items. - """ - result = [] - - def get_output(_, rdd): - if rdd and len(result) < n: - r = rdd.collect() - if r: - result.append(r) - - dstream.foreachRDD(get_output) - - if not block: - return result - - self.ssc.start() - self.wait_for(result, n) - return result - - def _test_func(self, input, func, expected, sort=False, input2=None): - """ - @param input: dataset for the test. This should be list of lists. - @param func: wrapped function. This function should return PythonDStream object. - @param expected: expected output for this testcase. - """ - if not isinstance(input[0], RDD): - input = [self.sc.parallelize(d, 1) for d in input] - input_stream = self.ssc.queueStream(input) - if input2 and not isinstance(input2[0], RDD): - input2 = [self.sc.parallelize(d, 1) for d in input2] - input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None - - # Apply test function to stream. - if input2: - stream = func(input_stream, input_stream2) - else: - stream = func(input_stream) - - result = self._collect(stream, len(expected)) - if sort: - self._sort_result_based_on_key(result) - self._sort_result_based_on_key(expected) - self.assertEqual(expected, result) - - def _sort_result_based_on_key(self, outputs): - """Sort the list based on first value.""" - for output in outputs: - output.sort(key=lambda x: x[0]) +from pyspark import SparkConf, SparkContext, RDD +from pyspark.streaming import StreamingContext +from pyspark.testing.streamingutils import PySparkStreamingTestCase class BasicOperationTests(PySparkStreamingTestCase): @@ -526,135 +389,6 @@ def failed_func(i): self.fail("a failed func should throw an error") -class StreamingListenerTests(PySparkStreamingTestCase): - - duration = .5 - - class BatchInfoCollector(StreamingListener): - - def __init__(self): - super(StreamingListener, self).__init__() - self.batchInfosCompleted = [] - self.batchInfosStarted = [] - self.batchInfosSubmitted = [] - self.streamingStartedTime = [] - - def onStreamingStarted(self, streamingStarted): - self.streamingStartedTime.append(streamingStarted.time) - - def onBatchSubmitted(self, batchSubmitted): - self.batchInfosSubmitted.append(batchSubmitted.batchInfo()) - - def onBatchStarted(self, batchStarted): - self.batchInfosStarted.append(batchStarted.batchInfo()) - - def onBatchCompleted(self, batchCompleted): - self.batchInfosCompleted.append(batchCompleted.batchInfo()) - - def test_batch_info_reports(self): - batch_collector = self.BatchInfoCollector() - self.ssc.addStreamingListener(batch_collector) - input = [[1], [2], [3], [4]] - - def func(dstream): - return dstream.map(int) - expected = [[1], [2], [3], [4]] - self._test_func(input, func, expected) - - batchInfosSubmitted = batch_collector.batchInfosSubmitted - batchInfosStarted = batch_collector.batchInfosStarted - batchInfosCompleted = batch_collector.batchInfosCompleted - streamingStartedTime = batch_collector.streamingStartedTime - - self.wait_for(batchInfosCompleted, 4) - - self.assertEqual(len(streamingStartedTime), 1) - - self.assertGreaterEqual(len(batchInfosSubmitted), 4) - for info in batchInfosSubmitted: - self.assertGreaterEqual(info.batchTime().milliseconds(), 0) - self.assertGreaterEqual(info.submissionTime(), 0) - - for streamId in info.streamIdToInputInfo(): - streamInputInfo = info.streamIdToInputInfo()[streamId] - self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) - self.assertGreaterEqual(streamInputInfo.numRecords, 0) - for key in streamInputInfo.metadata(): - self.assertIsNotNone(streamInputInfo.metadata()[key]) - self.assertIsNotNone(streamInputInfo.metadataDescription()) - - for outputOpId in info.outputOperationInfos(): - outputInfo = info.outputOperationInfos()[outputOpId] - self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) - self.assertGreaterEqual(outputInfo.id(), 0) - self.assertIsNotNone(outputInfo.name()) - self.assertIsNotNone(outputInfo.description()) - self.assertGreaterEqual(outputInfo.startTime(), -1) - self.assertGreaterEqual(outputInfo.endTime(), -1) - self.assertIsNone(outputInfo.failureReason()) - - self.assertEqual(info.schedulingDelay(), -1) - self.assertEqual(info.processingDelay(), -1) - self.assertEqual(info.totalDelay(), -1) - self.assertEqual(info.numRecords(), 0) - - self.assertGreaterEqual(len(batchInfosStarted), 4) - for info in batchInfosStarted: - self.assertGreaterEqual(info.batchTime().milliseconds(), 0) - self.assertGreaterEqual(info.submissionTime(), 0) - - for streamId in info.streamIdToInputInfo(): - streamInputInfo = info.streamIdToInputInfo()[streamId] - self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) - self.assertGreaterEqual(streamInputInfo.numRecords, 0) - for key in streamInputInfo.metadata(): - self.assertIsNotNone(streamInputInfo.metadata()[key]) - self.assertIsNotNone(streamInputInfo.metadataDescription()) - - for outputOpId in info.outputOperationInfos(): - outputInfo = info.outputOperationInfos()[outputOpId] - self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) - self.assertGreaterEqual(outputInfo.id(), 0) - self.assertIsNotNone(outputInfo.name()) - self.assertIsNotNone(outputInfo.description()) - self.assertGreaterEqual(outputInfo.startTime(), -1) - self.assertGreaterEqual(outputInfo.endTime(), -1) - self.assertIsNone(outputInfo.failureReason()) - - self.assertGreaterEqual(info.schedulingDelay(), 0) - self.assertEqual(info.processingDelay(), -1) - self.assertEqual(info.totalDelay(), -1) - self.assertEqual(info.numRecords(), 0) - - self.assertGreaterEqual(len(batchInfosCompleted), 4) - for info in batchInfosCompleted: - self.assertGreaterEqual(info.batchTime().milliseconds(), 0) - self.assertGreaterEqual(info.submissionTime(), 0) - - for streamId in info.streamIdToInputInfo(): - streamInputInfo = info.streamIdToInputInfo()[streamId] - self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) - self.assertGreaterEqual(streamInputInfo.numRecords, 0) - for key in streamInputInfo.metadata(): - self.assertIsNotNone(streamInputInfo.metadata()[key]) - self.assertIsNotNone(streamInputInfo.metadataDescription()) - - for outputOpId in info.outputOperationInfos(): - outputInfo = info.outputOperationInfos()[outputOpId] - self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) - self.assertGreaterEqual(outputInfo.id(), 0) - self.assertIsNotNone(outputInfo.name()) - self.assertIsNotNone(outputInfo.description()) - self.assertGreaterEqual(outputInfo.startTime(), 0) - self.assertGreaterEqual(outputInfo.endTime(), 0) - self.assertIsNone(outputInfo.failureReason()) - - self.assertGreaterEqual(info.schedulingDelay(), 0) - self.assertGreaterEqual(info.processingDelay(), 0) - self.assertGreaterEqual(info.totalDelay(), 0) - self.assertEqual(info.numRecords(), 0) - - class WindowFunctionTests(PySparkStreamingTestCase): timeout = 15 @@ -732,156 +466,6 @@ def func(dstream): self._test_func(input, func, expected) -class StreamingContextTests(PySparkStreamingTestCase): - - duration = 0.1 - setupCalled = False - - def _add_input_stream(self): - inputs = [range(1, x) for x in range(101)] - stream = self.ssc.queueStream(inputs) - self._collect(stream, 1, block=False) - - def test_stop_only_streaming_context(self): - self._add_input_stream() - self.ssc.start() - self.ssc.stop(False) - self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5) - - def test_stop_multiple_times(self): - self._add_input_stream() - self.ssc.start() - self.ssc.stop(False) - self.ssc.stop(False) - - def test_queue_stream(self): - input = [list(range(i + 1)) for i in range(3)] - dstream = self.ssc.queueStream(input) - result = self._collect(dstream, 3) - self.assertEqual(input, result) - - def test_text_file_stream(self): - d = tempfile.mkdtemp() - self.ssc = StreamingContext(self.sc, self.duration) - dstream2 = self.ssc.textFileStream(d).map(int) - result = self._collect(dstream2, 2, block=False) - self.ssc.start() - for name in ('a', 'b'): - time.sleep(1) - with open(os.path.join(d, name), "w") as f: - f.writelines(["%d\n" % i for i in range(10)]) - self.wait_for(result, 2) - self.assertEqual([list(range(10)), list(range(10))], result) - - def test_binary_records_stream(self): - d = tempfile.mkdtemp() - self.ssc = StreamingContext(self.sc, self.duration) - dstream = self.ssc.binaryRecordsStream(d, 10).map( - lambda v: struct.unpack("10b", bytes(v))) - result = self._collect(dstream, 2, block=False) - self.ssc.start() - for name in ('a', 'b'): - time.sleep(1) - with open(os.path.join(d, name), "wb") as f: - f.write(bytearray(range(10))) - self.wait_for(result, 2) - self.assertEqual([list(range(10)), list(range(10))], [list(v[0]) for v in result]) - - def test_union(self): - input = [list(range(i + 1)) for i in range(3)] - dstream = self.ssc.queueStream(input) - dstream2 = self.ssc.queueStream(input) - dstream3 = self.ssc.union(dstream, dstream2) - result = self._collect(dstream3, 3) - expected = [i * 2 for i in input] - self.assertEqual(expected, result) - - def test_transform(self): - dstream1 = self.ssc.queueStream([[1]]) - dstream2 = self.ssc.queueStream([[2]]) - dstream3 = self.ssc.queueStream([[3]]) - - def func(rdds): - rdd1, rdd2, rdd3 = rdds - return rdd2.union(rdd3).union(rdd1) - - dstream = self.ssc.transform([dstream1, dstream2, dstream3], func) - - self.assertEqual([2, 3, 1], self._take(dstream, 3)) - - def test_transform_pairrdd(self): - # This regression test case is for SPARK-17756. - dstream = self.ssc.queueStream( - [[1], [2], [3]]).transform(lambda rdd: rdd.cartesian(rdd)) - self.assertEqual([(1, 1), (2, 2), (3, 3)], self._take(dstream, 3)) - - def test_get_active(self): - self.assertEqual(StreamingContext.getActive(), None) - - # Verify that getActive() returns the active context - self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) - self.ssc.start() - self.assertEqual(StreamingContext.getActive(), self.ssc) - - # Verify that getActive() returns None - self.ssc.stop(False) - self.assertEqual(StreamingContext.getActive(), None) - - # Verify that if the Java context is stopped, then getActive() returns None - self.ssc = StreamingContext(self.sc, self.duration) - self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) - self.ssc.start() - self.assertEqual(StreamingContext.getActive(), self.ssc) - self.ssc._jssc.stop(False) - self.assertEqual(StreamingContext.getActive(), None) - - def test_get_active_or_create(self): - # Test StreamingContext.getActiveOrCreate() without checkpoint data - # See CheckpointTests for tests with checkpoint data - self.ssc = None - self.assertEqual(StreamingContext.getActive(), None) - - def setupFunc(): - ssc = StreamingContext(self.sc, self.duration) - ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) - self.setupCalled = True - return ssc - - # Verify that getActiveOrCreate() (w/o checkpoint) calls setupFunc when no context is active - self.setupCalled = False - self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) - self.assertTrue(self.setupCalled) - - # Verify that getActiveOrCreate() returns active context and does not call the setupFunc - self.ssc.start() - self.setupCalled = False - self.assertEqual(StreamingContext.getActiveOrCreate(None, setupFunc), self.ssc) - self.assertFalse(self.setupCalled) - - # Verify that getActiveOrCreate() calls setupFunc after active context is stopped - self.ssc.stop(False) - self.setupCalled = False - self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) - self.assertTrue(self.setupCalled) - - # Verify that if the Java context is stopped, then getActive() returns None - self.ssc = StreamingContext(self.sc, self.duration) - self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) - self.ssc.start() - self.assertEqual(StreamingContext.getActive(), self.ssc) - self.ssc._jssc.stop(False) - self.setupCalled = False - self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) - self.assertTrue(self.setupCalled) - - def test_await_termination_or_timeout(self): - self._add_input_stream() - self.ssc.start() - self.assertFalse(self.ssc.awaitTerminationOrTimeout(0.001)) - self.ssc.stop(False) - self.assertTrue(self.ssc.awaitTerminationOrTimeout(0.001)) - - class CheckpointTests(unittest.TestCase): setupCalled = False @@ -1046,140 +630,11 @@ def check_output(n): self.ssc.stop(True, True) -class KinesisStreamTests(PySparkStreamingTestCase): - - def test_kinesis_stream_api(self): - # Don't start the StreamingContext because we cannot test it in Jenkins - kinesisStream1 = KinesisUtils.createStream( - self.ssc, "myAppNam", "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", "us-west-2", - InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2) - kinesisStream2 = KinesisUtils.createStream( - self.ssc, "myAppNam", "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", "us-west-2", - InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2, - "awsAccessKey", "awsSecretKey") - - def test_kinesis_stream(self): - if not are_kinesis_tests_enabled: - sys.stderr.write( - "Skipped test_kinesis_stream (enable by setting environment variable %s=1" - % kinesis_test_environ_var) - return - - import random - kinesisAppName = ("KinesisStreamTests-%d" % abs(random.randint(0, 10000000))) - kinesisTestUtils = self.ssc._jvm.org.apache.spark.streaming.kinesis.KinesisTestUtils(2) - try: - kinesisTestUtils.createStream() - aWSCredentials = kinesisTestUtils.getAWSCredentials() - stream = KinesisUtils.createStream( - self.ssc, kinesisAppName, kinesisTestUtils.streamName(), - kinesisTestUtils.endpointUrl(), kinesisTestUtils.regionName(), - InitialPositionInStream.LATEST, 10, StorageLevel.MEMORY_ONLY, - aWSCredentials.getAWSAccessKeyId(), aWSCredentials.getAWSSecretKey()) - - outputBuffer = [] - - def get_output(_, rdd): - for e in rdd.collect(): - outputBuffer.append(e) - - stream.foreachRDD(get_output) - self.ssc.start() - - testData = [i for i in range(1, 11)] - expectedOutput = set([str(i) for i in testData]) - start_time = time.time() - while time.time() - start_time < 120: - kinesisTestUtils.pushData(testData) - if expectedOutput == set(outputBuffer): - break - time.sleep(10) - self.assertEqual(expectedOutput, set(outputBuffer)) - except: - import traceback - traceback.print_exc() - raise - finally: - self.ssc.stop(False) - kinesisTestUtils.deleteStream() - kinesisTestUtils.deleteDynamoDBTable(kinesisAppName) - - -# Search jar in the project dir using the jar name_prefix for both sbt build and maven build because -# the artifact jars are in different directories. -def search_jar(dir, name_prefix): - # We should ignore the following jars - ignored_jar_suffixes = ("javadoc.jar", "sources.jar", "test-sources.jar", "tests.jar") - jars = (glob.glob(os.path.join(dir, "target/scala-*/" + name_prefix + "-*.jar")) + # sbt build - glob.glob(os.path.join(dir, "target/" + name_prefix + "_*.jar"))) # maven build - return [jar for jar in jars if not jar.endswith(ignored_jar_suffixes)] - - -def _kinesis_asl_assembly_dir(): - SPARK_HOME = os.environ["SPARK_HOME"] - return os.path.join(SPARK_HOME, "external/kinesis-asl-assembly") - - -def search_kinesis_asl_assembly_jar(): - jars = search_jar(_kinesis_asl_assembly_dir(), "spark-streaming-kinesis-asl-assembly") - if not jars: - return None - elif len(jars) > 1: - raise Exception(("Found multiple Spark Streaming Kinesis ASL assembly JARs: %s; please " - "remove all but one") % (", ".join(jars))) - else: - return jars[0] - - -# Must be same as the variable and condition defined in KinesisTestUtils.scala and modules.py -kinesis_test_environ_var = "ENABLE_KINESIS_TESTS" -are_kinesis_tests_enabled = os.environ.get(kinesis_test_environ_var) == '1' - if __name__ == "__main__": - from pyspark.streaming.tests import * - kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar() - - if kinesis_asl_assembly_jar is None: - kinesis_jar_present = False - jars_args = "" - else: - kinesis_jar_present = True - jars_args = "--jars %s" % kinesis_asl_assembly_jar - - existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") - os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args]) - testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, - StreamingListenerTests] - - if kinesis_jar_present is True: - testcases.append(KinesisStreamTests) - elif are_kinesis_tests_enabled is False: - sys.stderr.write("Skipping all Kinesis Python tests as the optional Kinesis project was " - "not compiled into a JAR. To run these tests, " - "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/package " - "streaming-kinesis-asl-assembly/assembly' or " - "'build/mvn -Pkinesis-asl package' before running this test.") - else: - raise Exception( - ("Failed to find Spark Streaming Kinesis assembly jar in %s. " - % _kinesis_asl_assembly_dir()) + - "You need to build Spark with 'build/sbt -Pkinesis-asl " - "assembly/package streaming-kinesis-asl-assembly/assembly'" - "or 'build/mvn -Pkinesis-asl package' before running this test.") - - sys.stderr.write("Running tests: %s \n" % (str(testcases))) - failed = False - for testcase in testcases: - sys.stderr.write("[Running %s]\n" % (testcase)) - tests = unittest.TestLoader().loadTestsFromTestCase(testcase) - if xmlrunner: - result = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2).run(tests) - if not result.wasSuccessful(): - failed = True - else: - result = unittest.TextTestRunner(verbosity=2).run(tests) - if not result.wasSuccessful(): - failed = True - sys.exit(failed) + from pyspark.streaming.tests.test_dstream import * + + try: + import xmlrunner + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + except ImportError: + unittest.main(verbosity=2) diff --git a/python/pyspark/streaming/tests/test_kinesis.py b/python/pyspark/streaming/tests/test_kinesis.py new file mode 100644 index 0000000000000..d8a0b47f04097 --- /dev/null +++ b/python/pyspark/streaming/tests/test_kinesis.py @@ -0,0 +1,89 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import time +import unittest + +from pyspark import StorageLevel +from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream +from pyspark.testing.streamingutils import should_test_kinesis, kinesis_requirement_message, \ + PySparkStreamingTestCase + + +@unittest.skipIf(not should_test_kinesis, kinesis_requirement_message) +class KinesisStreamTests(PySparkStreamingTestCase): + + def test_kinesis_stream_api(self): + # Don't start the StreamingContext because we cannot test it in Jenkins + KinesisUtils.createStream( + self.ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2) + KinesisUtils.createStream( + self.ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2, + "awsAccessKey", "awsSecretKey") + + def test_kinesis_stream(self): + import random + kinesisAppName = ("KinesisStreamTests-%d" % abs(random.randint(0, 10000000))) + kinesisTestUtils = self.ssc._jvm.org.apache.spark.streaming.kinesis.KinesisTestUtils(2) + try: + kinesisTestUtils.createStream() + aWSCredentials = kinesisTestUtils.getAWSCredentials() + stream = KinesisUtils.createStream( + self.ssc, kinesisAppName, kinesisTestUtils.streamName(), + kinesisTestUtils.endpointUrl(), kinesisTestUtils.regionName(), + InitialPositionInStream.LATEST, 10, StorageLevel.MEMORY_ONLY, + aWSCredentials.getAWSAccessKeyId(), aWSCredentials.getAWSSecretKey()) + + outputBuffer = [] + + def get_output(_, rdd): + for e in rdd.collect(): + outputBuffer.append(e) + + stream.foreachRDD(get_output) + self.ssc.start() + + testData = [i for i in range(1, 11)] + expectedOutput = set([str(i) for i in testData]) + start_time = time.time() + while time.time() - start_time < 120: + kinesisTestUtils.pushData(testData) + if expectedOutput == set(outputBuffer): + break + time.sleep(10) + self.assertEqual(expectedOutput, set(outputBuffer)) + except: + import traceback + traceback.print_exc() + raise + finally: + self.ssc.stop(False) + kinesisTestUtils.deleteStream() + kinesisTestUtils.deleteDynamoDBTable(kinesisAppName) + + +if __name__ == "__main__": + from pyspark.streaming.tests.test_kinesis import * + + try: + import xmlrunner + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + except ImportError: + unittest.main(verbosity=2) diff --git a/python/pyspark/streaming/tests/test_listener.py b/python/pyspark/streaming/tests/test_listener.py new file mode 100644 index 0000000000000..7c874b6b32500 --- /dev/null +++ b/python/pyspark/streaming/tests/test_listener.py @@ -0,0 +1,158 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from pyspark.streaming import StreamingListener +from pyspark.testing.streamingutils import PySparkStreamingTestCase + + +class StreamingListenerTests(PySparkStreamingTestCase): + + duration = .5 + + class BatchInfoCollector(StreamingListener): + + def __init__(self): + super(StreamingListener, self).__init__() + self.batchInfosCompleted = [] + self.batchInfosStarted = [] + self.batchInfosSubmitted = [] + self.streamingStartedTime = [] + + def onStreamingStarted(self, streamingStarted): + self.streamingStartedTime.append(streamingStarted.time) + + def onBatchSubmitted(self, batchSubmitted): + self.batchInfosSubmitted.append(batchSubmitted.batchInfo()) + + def onBatchStarted(self, batchStarted): + self.batchInfosStarted.append(batchStarted.batchInfo()) + + def onBatchCompleted(self, batchCompleted): + self.batchInfosCompleted.append(batchCompleted.batchInfo()) + + def test_batch_info_reports(self): + batch_collector = self.BatchInfoCollector() + self.ssc.addStreamingListener(batch_collector) + input = [[1], [2], [3], [4]] + + def func(dstream): + return dstream.map(int) + expected = [[1], [2], [3], [4]] + self._test_func(input, func, expected) + + batchInfosSubmitted = batch_collector.batchInfosSubmitted + batchInfosStarted = batch_collector.batchInfosStarted + batchInfosCompleted = batch_collector.batchInfosCompleted + streamingStartedTime = batch_collector.streamingStartedTime + + self.wait_for(batchInfosCompleted, 4) + + self.assertEqual(len(streamingStartedTime), 1) + + self.assertGreaterEqual(len(batchInfosSubmitted), 4) + for info in batchInfosSubmitted: + self.assertGreaterEqual(info.batchTime().milliseconds(), 0) + self.assertGreaterEqual(info.submissionTime(), 0) + + for streamId in info.streamIdToInputInfo(): + streamInputInfo = info.streamIdToInputInfo()[streamId] + self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) + self.assertGreaterEqual(streamInputInfo.numRecords, 0) + for key in streamInputInfo.metadata(): + self.assertIsNotNone(streamInputInfo.metadata()[key]) + self.assertIsNotNone(streamInputInfo.metadataDescription()) + + for outputOpId in info.outputOperationInfos(): + outputInfo = info.outputOperationInfos()[outputOpId] + self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) + self.assertGreaterEqual(outputInfo.id(), 0) + self.assertIsNotNone(outputInfo.name()) + self.assertIsNotNone(outputInfo.description()) + self.assertGreaterEqual(outputInfo.startTime(), -1) + self.assertGreaterEqual(outputInfo.endTime(), -1) + self.assertIsNone(outputInfo.failureReason()) + + self.assertEqual(info.schedulingDelay(), -1) + self.assertEqual(info.processingDelay(), -1) + self.assertEqual(info.totalDelay(), -1) + self.assertEqual(info.numRecords(), 0) + + self.assertGreaterEqual(len(batchInfosStarted), 4) + for info in batchInfosStarted: + self.assertGreaterEqual(info.batchTime().milliseconds(), 0) + self.assertGreaterEqual(info.submissionTime(), 0) + + for streamId in info.streamIdToInputInfo(): + streamInputInfo = info.streamIdToInputInfo()[streamId] + self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) + self.assertGreaterEqual(streamInputInfo.numRecords, 0) + for key in streamInputInfo.metadata(): + self.assertIsNotNone(streamInputInfo.metadata()[key]) + self.assertIsNotNone(streamInputInfo.metadataDescription()) + + for outputOpId in info.outputOperationInfos(): + outputInfo = info.outputOperationInfos()[outputOpId] + self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) + self.assertGreaterEqual(outputInfo.id(), 0) + self.assertIsNotNone(outputInfo.name()) + self.assertIsNotNone(outputInfo.description()) + self.assertGreaterEqual(outputInfo.startTime(), -1) + self.assertGreaterEqual(outputInfo.endTime(), -1) + self.assertIsNone(outputInfo.failureReason()) + + self.assertGreaterEqual(info.schedulingDelay(), 0) + self.assertEqual(info.processingDelay(), -1) + self.assertEqual(info.totalDelay(), -1) + self.assertEqual(info.numRecords(), 0) + + self.assertGreaterEqual(len(batchInfosCompleted), 4) + for info in batchInfosCompleted: + self.assertGreaterEqual(info.batchTime().milliseconds(), 0) + self.assertGreaterEqual(info.submissionTime(), 0) + + for streamId in info.streamIdToInputInfo(): + streamInputInfo = info.streamIdToInputInfo()[streamId] + self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) + self.assertGreaterEqual(streamInputInfo.numRecords, 0) + for key in streamInputInfo.metadata(): + self.assertIsNotNone(streamInputInfo.metadata()[key]) + self.assertIsNotNone(streamInputInfo.metadataDescription()) + + for outputOpId in info.outputOperationInfos(): + outputInfo = info.outputOperationInfos()[outputOpId] + self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) + self.assertGreaterEqual(outputInfo.id(), 0) + self.assertIsNotNone(outputInfo.name()) + self.assertIsNotNone(outputInfo.description()) + self.assertGreaterEqual(outputInfo.startTime(), 0) + self.assertGreaterEqual(outputInfo.endTime(), 0) + self.assertIsNone(outputInfo.failureReason()) + + self.assertGreaterEqual(info.schedulingDelay(), 0) + self.assertGreaterEqual(info.processingDelay(), 0) + self.assertGreaterEqual(info.totalDelay(), 0) + self.assertEqual(info.numRecords(), 0) + + +if __name__ == "__main__": + import unittest + from pyspark.streaming.tests.test_listener import * + + try: + import xmlrunner + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + except ImportError: + unittest.main(verbosity=2) diff --git a/python/pyspark/testing/streamingutils.py b/python/pyspark/testing/streamingutils.py new file mode 100644 index 0000000000000..85a2fa14b936c --- /dev/null +++ b/python/pyspark/testing/streamingutils.py @@ -0,0 +1,190 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import glob +import os +import tempfile +import time +import unittest + +from pyspark import SparkConf, SparkContext, RDD +from pyspark.streaming import StreamingContext + + +def search_kinesis_asl_assembly_jar(): + kinesis_asl_assembly_dir = os.path.join( + os.environ["SPARK_HOME"], "external/kinesis-asl-assembly") + + # We should ignore the following jars + ignored_jar_suffixes = ("javadoc.jar", "sources.jar", "test-sources.jar", "tests.jar") + + # Search jar in the project dir using the jar name_prefix for both sbt build and maven + # build because the artifact jars are in different directories. + name_prefix = "spark-streaming-kinesis-asl-assembly" + sbt_build = glob.glob(os.path.join( + kinesis_asl_assembly_dir, "target/scala-*/%s-*.jar" % name_prefix)) + maven_build = glob.glob(os.path.join( + kinesis_asl_assembly_dir, "target/%s_*.jar" % name_prefix)) + jar_paths = sbt_build + maven_build + jars = [jar for jar in jar_paths if not jar.endswith(ignored_jar_suffixes)] + + if not jars: + return None + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming Kinesis ASL assembly JARs: %s; please " + "remove all but one") % (", ".join(jars))) + else: + return jars[0] + + +# Must be same as the variable and condition defined in KinesisTestUtils.scala and modules.py +kinesis_test_environ_var = "ENABLE_KINESIS_TESTS" +should_skip_kinesis_tests = not os.environ.get(kinesis_test_environ_var) == '1' + +if should_skip_kinesis_tests: + kinesis_requirement_message = ( + "Skipping all Kinesis Python tests as environmental variable 'ENABLE_KINESIS_TESTS' " + "was not set.") +else: + kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar() + if kinesis_asl_assembly_jar is None: + kinesis_requirement_message = ( + "Skipping all Kinesis Python tests as the optional Kinesis project was " + "not compiled into a JAR. To run these tests, " + "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/package " + "streaming-kinesis-asl-assembly/assembly' or " + "'build/mvn -Pkinesis-asl package' before running this test.") + else: + existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") + jars_args = "--jars %s" % kinesis_asl_assembly_jar + os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args]) + kinesis_requirement_message = None + +should_test_kinesis = kinesis_requirement_message is None + + +class PySparkStreamingTestCase(unittest.TestCase): + + timeout = 30 # seconds + duration = .5 + + @classmethod + def setUpClass(cls): + class_name = cls.__name__ + conf = SparkConf().set("spark.default.parallelism", 1) + cls.sc = SparkContext(appName=class_name, conf=conf) + cls.sc.setCheckpointDir(tempfile.mkdtemp()) + + @classmethod + def tearDownClass(cls): + cls.sc.stop() + # Clean up in the JVM just in case there has been some issues in Python API + try: + jSparkContextOption = SparkContext._jvm.SparkContext.get() + if jSparkContextOption.nonEmpty(): + jSparkContextOption.get().stop() + except: + pass + + def setUp(self): + self.ssc = StreamingContext(self.sc, self.duration) + + def tearDown(self): + if self.ssc is not None: + self.ssc.stop(False) + # Clean up in the JVM just in case there has been some issues in Python API + try: + jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive() + if jStreamingContextOption.nonEmpty(): + jStreamingContextOption.get().stop(False) + except: + pass + + def wait_for(self, result, n): + start_time = time.time() + while len(result) < n and time.time() - start_time < self.timeout: + time.sleep(0.01) + if len(result) < n: + print("timeout after", self.timeout) + + def _take(self, dstream, n): + """ + Return the first `n` elements in the stream (will start and stop). + """ + results = [] + + def take(_, rdd): + if rdd and len(results) < n: + results.extend(rdd.take(n - len(results))) + + dstream.foreachRDD(take) + + self.ssc.start() + self.wait_for(results, n) + return results + + def _collect(self, dstream, n, block=True): + """ + Collect each RDDs into the returned list. + + :return: list, which will have the collected items. + """ + result = [] + + def get_output(_, rdd): + if rdd and len(result) < n: + r = rdd.collect() + if r: + result.append(r) + + dstream.foreachRDD(get_output) + + if not block: + return result + + self.ssc.start() + self.wait_for(result, n) + return result + + def _test_func(self, input, func, expected, sort=False, input2=None): + """ + @param input: dataset for the test. This should be list of lists. + @param func: wrapped function. This function should return PythonDStream object. + @param expected: expected output for this testcase. + """ + if not isinstance(input[0], RDD): + input = [self.sc.parallelize(d, 1) for d in input] + input_stream = self.ssc.queueStream(input) + if input2 and not isinstance(input2[0], RDD): + input2 = [self.sc.parallelize(d, 1) for d in input2] + input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None + + # Apply test function to stream. + if input2: + stream = func(input_stream, input_stream2) + else: + stream = func(input_stream) + + result = self._collect(stream, len(expected)) + if sort: + self._sort_result_based_on_key(result) + self._sort_result_based_on_key(expected) + self.assertEqual(expected, result) + + def _sort_result_based_on_key(self, outputs): + """Sort the list based on first value.""" + for output in outputs: + output.sort(key=lambda x: x[0]) From dad2d826ae9138f06751e5d092531a9e06028c21 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 16 Nov 2018 12:46:57 +0800 Subject: [PATCH 2070/2461] [SPARK-23207][SQL][FOLLOW-UP] Use `SQLConf.get.enableRadixSort` instead of `SparkEnv.get.conf.get(SQLConf.RADIX_SORT_ENABLED)`. ## What changes were proposed in this pull request? This is a follow-up of #20393. We should read the conf `"spark.sql.sort.enableRadixSort"` from `SQLConf` instead of `SparkConf`, i.e., use `SQLConf.get.enableRadixSort` instead of `SparkEnv.get.conf.get(SQLConf.RADIX_SORT_ENABLED)`, otherwise the config is never read. ## How was this patch tested? Existing tests. Closes #23046 from ueshin/issues/SPARK-23207/conf. Authored-by: Takuya UESHIN Signed-off-by: Wenchen Fan --- .../spark/sql/execution/exchange/ShuffleExchangeExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 3b6eebd41e886..d6742ab3e0f31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -280,7 +280,7 @@ object ShuffleExchangeExec { } // The comparator for comparing row hashcode, which should always be Integer. val prefixComparator = PrefixComparators.LONG - val canUseRadixSort = SparkEnv.get.conf.get(SQLConf.RADIX_SORT_ENABLED) + val canUseRadixSort = SQLConf.get.enableRadixSort // The prefix computer generates row hashcode as the prefix, so we may decrease the // probability that the prefixes are equal when input rows choose column values from a // limited range. From 4ac8f9becda42e83131df87c68bcd1b0dfb50ac8 Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Fri, 16 Nov 2018 13:10:44 +0800 Subject: [PATCH 2071/2461] [SPARK-26073][SQL][FOLLOW-UP] remove invalid comment as we don't use it anymore ## What changes were proposed in this pull request? remove invalid comment as we don't use it anymore More details: https://github.com/apache/spark/pull/22976#discussion_r233764857 ## How was this patch tested? N/A Closes #23044 from heary-cao/followUpOrdering. Authored-by: caoxuewen Signed-off-by: Wenchen Fan --- .../sql/catalyst/expressions/codegen/GenerateOrdering.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index c3b95b6c67fdd..283fd2a6e9383 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -143,8 +143,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR }) ctx.currentVars = oldCurrentVars ctx.INPUT_ROW = oldInputRow - // make sure INPUT_ROW is declared even if splitExpressions - // returns an inlined block code } From 2aef79a65a145b76a88f1d4d9367091fd238b949 Mon Sep 17 00:00:00 2001 From: Rob Vesse Date: Fri, 16 Nov 2018 08:53:29 -0600 Subject: [PATCH 2072/2461] [SPARK-25023] More detailed security guidance for K8S ## What changes were proposed in this pull request? Highlights specific security issues to be aware of with Spark on K8S and recommends K8S mechanisms that should be used to secure clusters. ## How was this patch tested? N/A - Documentation only CC felixcheung tgravescs skonto Closes #23013 from rvesse/SPARK-25023. Authored-by: Rob Vesse Signed-off-by: Sean Owen --- docs/running-on-kubernetes.md | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 905226877720a..a7b6fd12a3e5f 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -15,7 +15,19 @@ container images and entrypoints.** # Security Security in Spark is OFF by default. This could mean you are vulnerable to attack by default. -Please see [Spark Security](security.html) and the specific security sections in this doc before running Spark. +Please see [Spark Security](security.html) and the specific advice below before running Spark. + +## User Identity + +Images built from the project provided Dockerfiles do not contain any [`USER`](https://docs.docker.com/engine/reference/builder/#user) directives. This means that the resulting images will be running the Spark processes as `root` inside the container. On unsecured clusters this may provide an attack vector for privilege escalation and container breakout. Therefore security conscious deployments should consider providing custom images with `USER` directives specifying an unprivileged UID and GID. + +Alternatively the [Pod Template](#pod-template) feature can be used to add a [Security Context](https://kubernetes.io/docs/tasks/configure-pod-container/security-context/#volumes-and-file-systems) with a `runAsUser` to the pods that Spark submits. Please bear in mind that this requires cooperation from your users and as such may not be a suitable solution for shared environments. Cluster administrators should use [Pod Security Policies](https://kubernetes.io/docs/concepts/policy/pod-security-policy/#users-and-groups) if they wish to limit the users that pods may run as. + +## Volume Mounts + +As described later in this document under [Using Kubernetes Volumes](#using-kubernetes-volumes) Spark on K8S provides configuration options that allow for mounting certain volume types into the driver and executor pods. In particular it allows for [`hostPath`](https://kubernetes.io/docs/concepts/storage/volumes/#hostpath) volumes which as described in the Kubernetes documentation have known security vulnerabilities. + +Cluster administrators should use [Pod Security Policies](https://kubernetes.io/docs/concepts/policy/pod-security-policy/) to limit the ability to mount `hostPath` volumes appropriately for their environments. # Prerequisites @@ -214,6 +226,8 @@ Starting with Spark 2.4.0, users can mount the following types of Kubernetes [vo * [emptyDir](https://kubernetes.io/docs/concepts/storage/volumes/#emptydir): an initially empty volume created when a pod is assigned to a node. * [persistentVolumeClaim](https://kubernetes.io/docs/concepts/storage/volumes/#persistentvolumeclaim): used to mount a `PersistentVolume` into a pod. +**NB:** Please see the [Security](#security) section of this document for security issues related to volume mounts. + To mount a volume of any of the types above into the driver pod, use the following configuration property: ``` From 696b75a81013ad61d25e0552df2b019c7531f983 Mon Sep 17 00:00:00 2001 From: Matt Molek Date: Fri, 16 Nov 2018 10:00:21 -0600 Subject: [PATCH 2073/2461] [SPARK-25934][MESOS] Don't propagate SPARK_CONF_DIR from spark submit ## What changes were proposed in this pull request? Don't propagate SPARK_CONF_DIR to the driver in mesos cluster mode. ## How was this patch tested? I built the 2.3.2 tag with this patch added and deployed a test job to a mesos cluster to confirm that the incorrect SPARK_CONF_DIR was no longer passed from the submit command. Closes #22937 from mpmolek/fix-conf-dir. Authored-by: Matt Molek Signed-off-by: Sean Owen --- .../spark/deploy/rest/RestSubmissionClient.scala | 8 +++++--- .../deploy/rest/StandaloneRestSubmitSuite.scala | 12 ++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 31a8e3e60c067..afa413fe165df 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -408,6 +408,10 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { } private[spark] object RestSubmissionClient { + + // SPARK_HOME and SPARK_CONF_DIR are filtered out because they are usually wrong + // on the remote machine (SPARK-12345) (SPARK-25934) + private val BLACKLISTED_SPARK_ENV_VARS = Set("SPARK_ENV_LOADED", "SPARK_HOME", "SPARK_CONF_DIR") private val REPORT_DRIVER_STATUS_INTERVAL = 1000 private val REPORT_DRIVER_STATUS_MAX_TRIES = 10 val PROTOCOL_VERSION = "v1" @@ -417,9 +421,7 @@ private[spark] object RestSubmissionClient { */ private[rest] def filterSystemEnvironment(env: Map[String, String]): Map[String, String] = { env.filterKeys { k => - // SPARK_HOME is filtered out because it is usually wrong on the remote machine (SPARK-12345) - (k.startsWith("SPARK_") && k != "SPARK_ENV_LOADED" && k != "SPARK_HOME") || - k.startsWith("MESOS_") + (k.startsWith("SPARK_") && !BLACKLISTED_SPARK_ENV_VARS.contains(k)) || k.startsWith("MESOS_") } } } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 4839c842cc785..89b8bb4ff7d03 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -396,6 +396,18 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { assert(filteredVariables == Map("SPARK_VAR" -> "1")) } + test("client does not send 'SPARK_HOME' env var by default") { + val environmentVariables = Map("SPARK_VAR" -> "1", "SPARK_HOME" -> "1") + val filteredVariables = RestSubmissionClient.filterSystemEnvironment(environmentVariables) + assert(filteredVariables == Map("SPARK_VAR" -> "1")) + } + + test("client does not send 'SPARK_CONF_DIR' env var by default") { + val environmentVariables = Map("SPARK_VAR" -> "1", "SPARK_CONF_DIR" -> "1") + val filteredVariables = RestSubmissionClient.filterSystemEnvironment(environmentVariables) + assert(filteredVariables == Map("SPARK_VAR" -> "1")) + } + test("client includes mesos env vars") { val environmentVariables = Map("SPARK_VAR" -> "1", "MESOS_VAR" -> "1", "OTHER_VAR" -> "1") val filteredVariables = RestSubmissionClient.filterSystemEnvironment(environmentVariables) From a2fc48c28c06192d1f650582d128d60c7188ec62 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Sat, 17 Nov 2018 00:12:17 +0800 Subject: [PATCH 2074/2461] [SPARK-26034][PYTHON][TESTS] Break large mllib/tests.py file into smaller files MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR breaks down the large mllib/tests.py file that contains all Python MLlib unit tests into several smaller test files to be easier to read and maintain. The tests are broken down as follows: ``` pyspark ├── __init__.py ... ├── mllib │ ├── __init__.py ... │ ├── tests │ │ ├── __init__.py │ │ ├── test_algorithms.py │ │ ├── test_feature.py │ │ ├── test_linalg.py │ │ ├── test_stat.py │ │ ├── test_streaming_algorithms.py │ │ └── test_util.py ... ├── testing ... │ ├── mllibutils.py ... ``` ## How was this patch tested? Ran tests manually by module to ensure test count was the same, and ran `python/run-tests --modules=pyspark-mllib` to verify all passing with Python 2.7 and Python 3.6. Also installed scipy to include optional tests in test_linalg. Closes #23056 from BryanCutler/python-test-breakup-mllib-SPARK-26034. Authored-by: Bryan Cutler Signed-off-by: hyukjinkwon --- dev/sparktestsupport/modules.py | 9 +- python/pyspark/mllib/tests.py | 1787 ----------------- python/pyspark/mllib/tests/__init__.py | 16 + python/pyspark/mllib/tests/test_algorithms.py | 313 +++ python/pyspark/mllib/tests/test_feature.py | 201 ++ python/pyspark/mllib/tests/test_linalg.py | 642 ++++++ python/pyspark/mllib/tests/test_stat.py | 197 ++ .../mllib/tests/test_streaming_algorithms.py | 523 +++++ python/pyspark/mllib/tests/test_util.py | 115 ++ python/pyspark/testing/mllibutils.py | 44 + 10 files changed, 2059 insertions(+), 1788 deletions(-) delete mode 100644 python/pyspark/mllib/tests.py create mode 100644 python/pyspark/mllib/tests/__init__.py create mode 100644 python/pyspark/mllib/tests/test_algorithms.py create mode 100644 python/pyspark/mllib/tests/test_feature.py create mode 100644 python/pyspark/mllib/tests/test_linalg.py create mode 100644 python/pyspark/mllib/tests/test_stat.py create mode 100644 python/pyspark/mllib/tests/test_streaming_algorithms.py create mode 100644 python/pyspark/mllib/tests/test_util.py create mode 100644 python/pyspark/testing/mllibutils.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 58b48f43f6468..547635a412913 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -416,6 +416,7 @@ def __hash__(self): "python/pyspark/mllib" ], python_test_goals=[ + # doctests "pyspark.mllib.classification", "pyspark.mllib.clustering", "pyspark.mllib.evaluation", @@ -430,7 +431,13 @@ def __hash__(self): "pyspark.mllib.stat.KernelDensity", "pyspark.mllib.tree", "pyspark.mllib.util", - "pyspark.mllib.tests", + # unittests + "pyspark.mllib.tests.test_algorithms", + "pyspark.mllib.tests.test_feature", + "pyspark.mllib.tests.test_linalg", + "pyspark.mllib.tests.test_stat", + "pyspark.mllib.tests.test_streaming_algorithms", + "pyspark.mllib.tests.test_util", ], blacklisted_python_implementations=[ "PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py deleted file mode 100644 index 4c2ce137e331c..0000000000000 --- a/python/pyspark/mllib/tests.py +++ /dev/null @@ -1,1787 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Fuller unit tests for Python MLlib. -""" - -import os -import sys -import tempfile -import array as pyarray -from math import sqrt -from time import time, sleep -from shutil import rmtree - -from numpy import ( - array, array_equal, zeros, inf, random, exp, dot, all, mean, abs, arange, tile, ones) -from numpy import sum as array_sum - -from py4j.protocol import Py4JJavaError -try: - import xmlrunner -except ImportError: - xmlrunner = None - -if sys.version > '3': - basestring = str - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - -from pyspark import SparkContext -import pyspark.ml.linalg as newlinalg -from pyspark.mllib.common import _to_java_object_rdd -from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel -from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ - DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT -from pyspark.mllib.linalg.distributed import RowMatrix -from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD -from pyspark.mllib.fpm import FPGrowth -from pyspark.mllib.recommendation import Rating -from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD -from pyspark.mllib.random import RandomRDDs -from pyspark.mllib.stat import Statistics -from pyspark.mllib.feature import HashingTF -from pyspark.mllib.feature import Word2Vec -from pyspark.mllib.feature import IDF -from pyspark.mllib.feature import StandardScaler, ElementwiseProduct -from pyspark.mllib.util import LinearDataGenerator -from pyspark.mllib.util import MLUtils -from pyspark.serializers import PickleSerializer -from pyspark.streaming import StreamingContext -from pyspark.sql import SparkSession -from pyspark.sql.utils import IllegalArgumentException -from pyspark.streaming import StreamingContext - -_have_scipy = False -try: - import scipy.sparse - _have_scipy = True -except: - # No SciPy, but that's okay, we'll skip those tests - pass - -ser = PickleSerializer() - - -class MLlibTestCase(unittest.TestCase): - def setUp(self): - self.sc = SparkContext('local[4]', "MLlib tests") - self.spark = SparkSession(self.sc) - - def tearDown(self): - self.spark.stop() - - -class MLLibStreamingTestCase(unittest.TestCase): - def setUp(self): - self.sc = SparkContext('local[4]', "MLlib tests") - self.ssc = StreamingContext(self.sc, 1.0) - - def tearDown(self): - self.ssc.stop(False) - self.sc.stop() - - @staticmethod - def _eventually(condition, timeout=30.0, catch_assertions=False): - """ - Wait a given amount of time for a condition to pass, else fail with an error. - This is a helper utility for streaming ML tests. - :param condition: Function that checks for termination conditions. - condition() can return: - - True: Conditions met. Return without error. - - other value: Conditions not met yet. Continue. Upon timeout, - include last such value in error message. - Note that this method may be called at any time during - streaming execution (e.g., even before any results - have been created). - :param timeout: Number of seconds to wait. Default 30 seconds. - :param catch_assertions: If False (default), do not catch AssertionErrors. - If True, catch AssertionErrors; continue, but save - error to throw upon timeout. - """ - start_time = time() - lastValue = None - while time() - start_time < timeout: - if catch_assertions: - try: - lastValue = condition() - except AssertionError as e: - lastValue = e - else: - lastValue = condition() - if lastValue is True: - return - sleep(0.01) - if isinstance(lastValue, AssertionError): - raise lastValue - else: - raise AssertionError( - "Test failed due to timeout after %g sec, with last condition returning: %s" - % (timeout, lastValue)) - - -def _squared_distance(a, b): - if isinstance(a, Vector): - return a.squared_distance(b) - else: - return b.squared_distance(a) - - -class VectorTests(MLlibTestCase): - - def _test_serialize(self, v): - self.assertEqual(v, ser.loads(ser.dumps(v))) - jvec = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(v))) - nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvec))) - self.assertEqual(v, nv) - vs = [v] * 100 - jvecs = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(vs))) - nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvecs))) - self.assertEqual(vs, nvs) - - def test_serialize(self): - self._test_serialize(DenseVector(range(10))) - self._test_serialize(DenseVector(array([1., 2., 3., 4.]))) - self._test_serialize(DenseVector(pyarray.array('d', range(10)))) - self._test_serialize(SparseVector(4, {1: 1, 3: 2})) - self._test_serialize(SparseVector(3, {})) - self._test_serialize(DenseMatrix(2, 3, range(6))) - sm1 = SparseMatrix( - 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) - self._test_serialize(sm1) - - def test_dot(self): - sv = SparseVector(4, {1: 1, 3: 2}) - dv = DenseVector(array([1., 2., 3., 4.])) - lst = DenseVector([1, 2, 3, 4]) - mat = array([[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]) - arr = pyarray.array('d', [0, 1, 2, 3]) - self.assertEqual(10.0, sv.dot(dv)) - self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) - self.assertEqual(30.0, dv.dot(dv)) - self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) - self.assertEqual(30.0, lst.dot(dv)) - self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) - self.assertEqual(7.0, sv.dot(arr)) - - def test_squared_distance(self): - sv = SparseVector(4, {1: 1, 3: 2}) - dv = DenseVector(array([1., 2., 3., 4.])) - lst = DenseVector([4, 3, 2, 1]) - lst1 = [4, 3, 2, 1] - arr = pyarray.array('d', [0, 2, 1, 3]) - narr = array([0, 2, 1, 3]) - self.assertEqual(15.0, _squared_distance(sv, dv)) - self.assertEqual(25.0, _squared_distance(sv, lst)) - self.assertEqual(20.0, _squared_distance(dv, lst)) - self.assertEqual(15.0, _squared_distance(dv, sv)) - self.assertEqual(25.0, _squared_distance(lst, sv)) - self.assertEqual(20.0, _squared_distance(lst, dv)) - self.assertEqual(0.0, _squared_distance(sv, sv)) - self.assertEqual(0.0, _squared_distance(dv, dv)) - self.assertEqual(0.0, _squared_distance(lst, lst)) - self.assertEqual(25.0, _squared_distance(sv, lst1)) - self.assertEqual(3.0, _squared_distance(sv, arr)) - self.assertEqual(3.0, _squared_distance(sv, narr)) - - def test_hash(self): - v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEqual(hash(v1), hash(v2)) - self.assertEqual(hash(v1), hash(v3)) - self.assertEqual(hash(v2), hash(v3)) - self.assertFalse(hash(v1) == hash(v4)) - self.assertFalse(hash(v2) == hash(v4)) - - def test_eq(self): - v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) - v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) - v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEqual(v1, v2) - self.assertEqual(v1, v3) - self.assertFalse(v2 == v4) - self.assertFalse(v1 == v5) - self.assertFalse(v1 == v6) - - def test_equals(self): - indices = [1, 2, 4] - values = [1., 3., 2.] - self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.])) - - def test_conversion(self): - # numpy arrays should be automatically upcast to float64 - # tests for fix of [SPARK-5089] - v = array([1, 2, 3, 4], dtype='float64') - dv = DenseVector(v) - self.assertTrue(dv.array.dtype == 'float64') - v = array([1, 2, 3, 4], dtype='float32') - dv = DenseVector(v) - self.assertTrue(dv.array.dtype == 'float64') - - def test_sparse_vector_indexing(self): - sv = SparseVector(5, {1: 1, 3: 2}) - self.assertEqual(sv[0], 0.) - self.assertEqual(sv[3], 2.) - self.assertEqual(sv[1], 1.) - self.assertEqual(sv[2], 0.) - self.assertEqual(sv[4], 0.) - self.assertEqual(sv[-1], 0.) - self.assertEqual(sv[-2], 2.) - self.assertEqual(sv[-3], 0.) - self.assertEqual(sv[-5], 0.) - for ind in [5, -6]: - self.assertRaises(IndexError, sv.__getitem__, ind) - for ind in [7.8, '1']: - self.assertRaises(TypeError, sv.__getitem__, ind) - - zeros = SparseVector(4, {}) - self.assertEqual(zeros[0], 0.0) - self.assertEqual(zeros[3], 0.0) - for ind in [4, -5]: - self.assertRaises(IndexError, zeros.__getitem__, ind) - - empty = SparseVector(0, {}) - for ind in [-1, 0, 1]: - self.assertRaises(IndexError, empty.__getitem__, ind) - - def test_sparse_vector_iteration(self): - self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0]) - self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0]) - - def test_matrix_indexing(self): - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) - expected = [[0, 6], [1, 8], [4, 10]] - for i in range(3): - for j in range(2): - self.assertEqual(mat[i, j], expected[i][j]) - - for i, j in [(-1, 0), (4, 1), (3, 4)]: - self.assertRaises(IndexError, mat.__getitem__, (i, j)) - - def test_repr_dense_matrix(self): - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) - self.assertTrue( - repr(mat), - 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') - - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True) - self.assertTrue( - repr(mat), - 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') - - mat = DenseMatrix(6, 3, zeros(18)) - self.assertTrue( - repr(mat), - 'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)') - - def test_repr_sparse_matrix(self): - sm1t = SparseMatrix( - 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], - isTransposed=True) - self.assertTrue( - repr(sm1t), - 'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)') - - indices = tile(arange(6), 3) - values = ones(18) - sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values) - self.assertTrue( - repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \ - [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \ - [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \ - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)") - - self.assertTrue( - str(sm), - "6 X 3 CSCMatrix\n\ - (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\ - (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\ - (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..") - - sm = SparseMatrix(1, 18, zeros(19), [], []) - self.assertTrue( - repr(sm), - 'SparseMatrix(1, 18, \ - [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)') - - def test_sparse_matrix(self): - # Test sparse matrix creation. - sm1 = SparseMatrix( - 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) - self.assertEqual(sm1.numRows, 3) - self.assertEqual(sm1.numCols, 4) - self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) - self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2]) - self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) - self.assertTrue( - repr(sm1), - 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)') - - # Test indexing - expected = [ - [0, 0, 0, 0], - [1, 0, 4, 0], - [2, 0, 5, 0]] - - for i in range(3): - for j in range(4): - self.assertEqual(expected[i][j], sm1[i, j]) - self.assertTrue(array_equal(sm1.toArray(), expected)) - - for i, j in [(-1, 1), (4, 3), (3, 5)]: - self.assertRaises(IndexError, sm1.__getitem__, (i, j)) - - # Test conversion to dense and sparse. - smnew = sm1.toDense().toSparse() - self.assertEqual(sm1.numRows, smnew.numRows) - self.assertEqual(sm1.numCols, smnew.numCols) - self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) - self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) - self.assertTrue(array_equal(sm1.values, smnew.values)) - - sm1t = SparseMatrix( - 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], - isTransposed=True) - self.assertEqual(sm1t.numRows, 3) - self.assertEqual(sm1t.numCols, 4) - self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) - self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) - self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) - - expected = [ - [3, 2, 0, 0], - [0, 0, 4, 0], - [9, 0, 8, 0]] - - for i in range(3): - for j in range(4): - self.assertEqual(expected[i][j], sm1t[i, j]) - self.assertTrue(array_equal(sm1t.toArray(), expected)) - - def test_dense_matrix_is_transposed(self): - mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) - mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) - self.assertEqual(mat1, mat) - - expected = [[0, 4], [1, 6], [3, 9]] - for i in range(3): - for j in range(2): - self.assertEqual(mat1[i, j], expected[i][j]) - self.assertTrue(array_equal(mat1.toArray(), expected)) - - sm = mat1.toSparse() - self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2])) - self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5])) - self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) - - def test_parse_vector(self): - a = DenseVector([]) - self.assertEqual(str(a), '[]') - self.assertEqual(Vectors.parse(str(a)), a) - a = DenseVector([3, 4, 6, 7]) - self.assertEqual(str(a), '[3.0,4.0,6.0,7.0]') - self.assertEqual(Vectors.parse(str(a)), a) - a = SparseVector(4, [], []) - self.assertEqual(str(a), '(4,[],[])') - self.assertEqual(SparseVector.parse(str(a)), a) - a = SparseVector(4, [0, 2], [3, 4]) - self.assertEqual(str(a), '(4,[0,2],[3.0,4.0])') - self.assertEqual(Vectors.parse(str(a)), a) - a = SparseVector(10, [0, 1], [4, 5]) - self.assertEqual(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a) - - def test_norms(self): - a = DenseVector([0, 2, 3, -1]) - self.assertAlmostEqual(a.norm(2), 3.742, 3) - self.assertTrue(a.norm(1), 6) - self.assertTrue(a.norm(inf), 3) - a = SparseVector(4, [0, 2], [3, -4]) - self.assertAlmostEqual(a.norm(2), 5) - self.assertTrue(a.norm(1), 7) - self.assertTrue(a.norm(inf), 4) - - tmp = SparseVector(4, [0, 2], [3, 0]) - self.assertEqual(tmp.numNonzeros(), 1) - - def test_ml_mllib_vector_conversion(self): - # to ml - # dense - mllibDV = Vectors.dense([1, 2, 3]) - mlDV1 = newlinalg.Vectors.dense([1, 2, 3]) - mlDV2 = mllibDV.asML() - self.assertEqual(mlDV2, mlDV1) - # sparse - mllibSV = Vectors.sparse(4, {1: 1.0, 3: 5.5}) - mlSV1 = newlinalg.Vectors.sparse(4, {1: 1.0, 3: 5.5}) - mlSV2 = mllibSV.asML() - self.assertEqual(mlSV2, mlSV1) - # from ml - # dense - mllibDV1 = Vectors.dense([1, 2, 3]) - mlDV = newlinalg.Vectors.dense([1, 2, 3]) - mllibDV2 = Vectors.fromML(mlDV) - self.assertEqual(mllibDV1, mllibDV2) - # sparse - mllibSV1 = Vectors.sparse(4, {1: 1.0, 3: 5.5}) - mlSV = newlinalg.Vectors.sparse(4, {1: 1.0, 3: 5.5}) - mllibSV2 = Vectors.fromML(mlSV) - self.assertEqual(mllibSV1, mllibSV2) - - def test_ml_mllib_matrix_conversion(self): - # to ml - # dense - mllibDM = Matrices.dense(2, 2, [0, 1, 2, 3]) - mlDM1 = newlinalg.Matrices.dense(2, 2, [0, 1, 2, 3]) - mlDM2 = mllibDM.asML() - self.assertEqual(mlDM2, mlDM1) - # transposed - mllibDMt = DenseMatrix(2, 2, [0, 1, 2, 3], True) - mlDMt1 = newlinalg.DenseMatrix(2, 2, [0, 1, 2, 3], True) - mlDMt2 = mllibDMt.asML() - self.assertEqual(mlDMt2, mlDMt1) - # sparse - mllibSM = Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) - mlSM1 = newlinalg.Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) - mlSM2 = mllibSM.asML() - self.assertEqual(mlSM2, mlSM1) - # transposed - mllibSMt = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) - mlSMt1 = newlinalg.SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) - mlSMt2 = mllibSMt.asML() - self.assertEqual(mlSMt2, mlSMt1) - # from ml - # dense - mllibDM1 = Matrices.dense(2, 2, [1, 2, 3, 4]) - mlDM = newlinalg.Matrices.dense(2, 2, [1, 2, 3, 4]) - mllibDM2 = Matrices.fromML(mlDM) - self.assertEqual(mllibDM1, mllibDM2) - # transposed - mllibDMt1 = DenseMatrix(2, 2, [1, 2, 3, 4], True) - mlDMt = newlinalg.DenseMatrix(2, 2, [1, 2, 3, 4], True) - mllibDMt2 = Matrices.fromML(mlDMt) - self.assertEqual(mllibDMt1, mllibDMt2) - # sparse - mllibSM1 = Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) - mlSM = newlinalg.Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) - mllibSM2 = Matrices.fromML(mlSM) - self.assertEqual(mllibSM1, mllibSM2) - # transposed - mllibSMt1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) - mlSMt = newlinalg.SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) - mllibSMt2 = Matrices.fromML(mlSMt) - self.assertEqual(mllibSMt1, mllibSMt2) - - -class ListTests(MLlibTestCase): - - """ - Test MLlib algorithms on plain lists, to make sure they're passed through - as NumPy arrays. - """ - - def test_bisecting_kmeans(self): - from pyspark.mllib.clustering import BisectingKMeans - data = array([0.0, 0.0, 1.0, 1.0, 9.0, 8.0, 8.0, 9.0]).reshape(4, 2) - bskm = BisectingKMeans() - model = bskm.train(self.sc.parallelize(data, 2), k=4) - p = array([0.0, 0.0]) - rdd_p = self.sc.parallelize([p]) - self.assertEqual(model.predict(p), model.predict(rdd_p).first()) - self.assertEqual(model.computeCost(p), model.computeCost(rdd_p)) - self.assertEqual(model.k, len(model.clusterCenters)) - - def test_kmeans(self): - from pyspark.mllib.clustering import KMeans - data = [ - [0, 1.1], - [0, 1.2], - [1.1, 0], - [1.2, 0], - ] - clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||", - initializationSteps=7, epsilon=1e-4) - self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) - self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) - - def test_kmeans_deterministic(self): - from pyspark.mllib.clustering import KMeans - X = range(0, 100, 10) - Y = range(0, 100, 10) - data = [[x, y] for x, y in zip(X, Y)] - clusters1 = KMeans.train(self.sc.parallelize(data), - 3, initializationMode="k-means||", - seed=42, initializationSteps=7, epsilon=1e-4) - clusters2 = KMeans.train(self.sc.parallelize(data), - 3, initializationMode="k-means||", - seed=42, initializationSteps=7, epsilon=1e-4) - centers1 = clusters1.centers - centers2 = clusters2.centers - for c1, c2 in zip(centers1, centers2): - # TODO: Allow small numeric difference. - self.assertTrue(array_equal(c1, c2)) - - def test_gmm(self): - from pyspark.mllib.clustering import GaussianMixture - data = self.sc.parallelize([ - [1, 2], - [8, 9], - [-4, -3], - [-6, -7], - ]) - clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=10, seed=1) - labels = clusters.predict(data).collect() - self.assertEqual(labels[0], labels[1]) - self.assertEqual(labels[2], labels[3]) - - def test_gmm_deterministic(self): - from pyspark.mllib.clustering import GaussianMixture - x = range(0, 100, 10) - y = range(0, 100, 10) - data = self.sc.parallelize([[a, b] for a, b in zip(x, y)]) - clusters1 = GaussianMixture.train(data, 5, convergenceTol=0.001, - maxIterations=10, seed=63) - clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001, - maxIterations=10, seed=63) - for c1, c2 in zip(clusters1.weights, clusters2.weights): - self.assertEqual(round(c1, 7), round(c2, 7)) - - def test_gmm_with_initial_model(self): - from pyspark.mllib.clustering import GaussianMixture - data = self.sc.parallelize([ - (-10, -5), (-9, -4), (10, 5), (9, 4) - ]) - - gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=10, seed=63) - gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=10, seed=63, initialModel=gmm1) - self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0) - - def test_classification(self): - from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes - from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\ - RandomForestModel, GradientBoostedTrees, GradientBoostedTreesModel - data = [ - LabeledPoint(0.0, [1, 0, 0]), - LabeledPoint(1.0, [0, 1, 1]), - LabeledPoint(0.0, [2, 0, 0]), - LabeledPoint(1.0, [0, 2, 1]) - ] - rdd = self.sc.parallelize(data) - features = [p.features.tolist() for p in data] - - temp_dir = tempfile.mkdtemp() - - lr_model = LogisticRegressionWithSGD.train(rdd, iterations=10) - self.assertTrue(lr_model.predict(features[0]) <= 0) - self.assertTrue(lr_model.predict(features[1]) > 0) - self.assertTrue(lr_model.predict(features[2]) <= 0) - self.assertTrue(lr_model.predict(features[3]) > 0) - - svm_model = SVMWithSGD.train(rdd, iterations=10) - self.assertTrue(svm_model.predict(features[0]) <= 0) - self.assertTrue(svm_model.predict(features[1]) > 0) - self.assertTrue(svm_model.predict(features[2]) <= 0) - self.assertTrue(svm_model.predict(features[3]) > 0) - - nb_model = NaiveBayes.train(rdd) - self.assertTrue(nb_model.predict(features[0]) <= 0) - self.assertTrue(nb_model.predict(features[1]) > 0) - self.assertTrue(nb_model.predict(features[2]) <= 0) - self.assertTrue(nb_model.predict(features[3]) > 0) - - categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories - dt_model = DecisionTree.trainClassifier( - rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4) - self.assertTrue(dt_model.predict(features[0]) <= 0) - self.assertTrue(dt_model.predict(features[1]) > 0) - self.assertTrue(dt_model.predict(features[2]) <= 0) - self.assertTrue(dt_model.predict(features[3]) > 0) - - dt_model_dir = os.path.join(temp_dir, "dt") - dt_model.save(self.sc, dt_model_dir) - same_dt_model = DecisionTreeModel.load(self.sc, dt_model_dir) - self.assertEqual(same_dt_model.toDebugString(), dt_model.toDebugString()) - - rf_model = RandomForest.trainClassifier( - rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, - maxBins=4, seed=1) - self.assertTrue(rf_model.predict(features[0]) <= 0) - self.assertTrue(rf_model.predict(features[1]) > 0) - self.assertTrue(rf_model.predict(features[2]) <= 0) - self.assertTrue(rf_model.predict(features[3]) > 0) - - rf_model_dir = os.path.join(temp_dir, "rf") - rf_model.save(self.sc, rf_model_dir) - same_rf_model = RandomForestModel.load(self.sc, rf_model_dir) - self.assertEqual(same_rf_model.toDebugString(), rf_model.toDebugString()) - - gbt_model = GradientBoostedTrees.trainClassifier( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4) - self.assertTrue(gbt_model.predict(features[0]) <= 0) - self.assertTrue(gbt_model.predict(features[1]) > 0) - self.assertTrue(gbt_model.predict(features[2]) <= 0) - self.assertTrue(gbt_model.predict(features[3]) > 0) - - gbt_model_dir = os.path.join(temp_dir, "gbt") - gbt_model.save(self.sc, gbt_model_dir) - same_gbt_model = GradientBoostedTreesModel.load(self.sc, gbt_model_dir) - self.assertEqual(same_gbt_model.toDebugString(), gbt_model.toDebugString()) - - try: - rmtree(temp_dir) - except OSError: - pass - - def test_regression(self): - from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ - RidgeRegressionWithSGD - from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees - data = [ - LabeledPoint(-1.0, [0, -1]), - LabeledPoint(1.0, [0, 1]), - LabeledPoint(-1.0, [0, -2]), - LabeledPoint(1.0, [0, 2]) - ] - rdd = self.sc.parallelize(data) - features = [p.features.tolist() for p in data] - - lr_model = LinearRegressionWithSGD.train(rdd, iterations=10) - self.assertTrue(lr_model.predict(features[0]) <= 0) - self.assertTrue(lr_model.predict(features[1]) > 0) - self.assertTrue(lr_model.predict(features[2]) <= 0) - self.assertTrue(lr_model.predict(features[3]) > 0) - - lasso_model = LassoWithSGD.train(rdd, iterations=10) - self.assertTrue(lasso_model.predict(features[0]) <= 0) - self.assertTrue(lasso_model.predict(features[1]) > 0) - self.assertTrue(lasso_model.predict(features[2]) <= 0) - self.assertTrue(lasso_model.predict(features[3]) > 0) - - rr_model = RidgeRegressionWithSGD.train(rdd, iterations=10) - self.assertTrue(rr_model.predict(features[0]) <= 0) - self.assertTrue(rr_model.predict(features[1]) > 0) - self.assertTrue(rr_model.predict(features[2]) <= 0) - self.assertTrue(rr_model.predict(features[3]) > 0) - - categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories - dt_model = DecisionTree.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4) - self.assertTrue(dt_model.predict(features[0]) <= 0) - self.assertTrue(dt_model.predict(features[1]) > 0) - self.assertTrue(dt_model.predict(features[2]) <= 0) - self.assertTrue(dt_model.predict(features[3]) > 0) - - rf_model = RandomForest.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, maxBins=4, seed=1) - self.assertTrue(rf_model.predict(features[0]) <= 0) - self.assertTrue(rf_model.predict(features[1]) > 0) - self.assertTrue(rf_model.predict(features[2]) <= 0) - self.assertTrue(rf_model.predict(features[3]) > 0) - - gbt_model = GradientBoostedTrees.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4) - self.assertTrue(gbt_model.predict(features[0]) <= 0) - self.assertTrue(gbt_model.predict(features[1]) > 0) - self.assertTrue(gbt_model.predict(features[2]) <= 0) - self.assertTrue(gbt_model.predict(features[3]) > 0) - - try: - LinearRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) - LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) - RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) - except ValueError: - self.fail() - - # Verify that maxBins is being passed through - GradientBoostedTrees.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=32) - with self.assertRaises(Exception) as cm: - GradientBoostedTrees.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=1) - - -class StatTests(MLlibTestCase): - # SPARK-4023 - def test_col_with_different_rdds(self): - # numpy - data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) - summary = Statistics.colStats(data) - self.assertEqual(1000, summary.count()) - # array - data = self.sc.parallelize([range(10)] * 10) - summary = Statistics.colStats(data) - self.assertEqual(10, summary.count()) - # array - data = self.sc.parallelize([pyarray.array("d", range(10))] * 10) - summary = Statistics.colStats(data) - self.assertEqual(10, summary.count()) - - def test_col_norms(self): - data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) - summary = Statistics.colStats(data) - self.assertEqual(10, len(summary.normL1())) - self.assertEqual(10, len(summary.normL2())) - - data2 = self.sc.parallelize(range(10)).map(lambda x: Vectors.dense(x)) - summary2 = Statistics.colStats(data2) - self.assertEqual(array([45.0]), summary2.normL1()) - import math - expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, range(10)))) - self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14) - - -class VectorUDTTests(MLlibTestCase): - - dv0 = DenseVector([]) - dv1 = DenseVector([1.0, 2.0]) - sv0 = SparseVector(2, [], []) - sv1 = SparseVector(2, [1], [2.0]) - udt = VectorUDT() - - def test_json_schema(self): - self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) - - def test_serialization(self): - for v in [self.dv0, self.dv1, self.sv0, self.sv1]: - self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) - - def test_infer_schema(self): - rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) - df = rdd.toDF() - schema = df.schema - field = [f for f in schema.fields if f.name == "features"][0] - self.assertEqual(field.dataType, self.udt) - vectors = df.rdd.map(lambda p: p.features).collect() - self.assertEqual(len(vectors), 2) - for v in vectors: - if isinstance(v, SparseVector): - self.assertEqual(v, self.sv1) - elif isinstance(v, DenseVector): - self.assertEqual(v, self.dv1) - else: - raise TypeError("expecting a vector but got %r of type %r" % (v, type(v))) - - -class MatrixUDTTests(MLlibTestCase): - - dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10]) - dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True) - sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0]) - sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True) - udt = MatrixUDT() - - def test_json_schema(self): - self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt) - - def test_serialization(self): - for m in [self.dm1, self.dm2, self.sm1, self.sm2]: - self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m))) - - def test_infer_schema(self): - rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)]) - df = rdd.toDF() - schema = df.schema - self.assertTrue(schema.fields[1].dataType, self.udt) - matrices = df.rdd.map(lambda x: x._2).collect() - self.assertEqual(len(matrices), 2) - for m in matrices: - if isinstance(m, DenseMatrix): - self.assertTrue(m, self.dm1) - elif isinstance(m, SparseMatrix): - self.assertTrue(m, self.sm1) - else: - raise ValueError("Expected a matrix but got type %r" % type(m)) - - -@unittest.skipIf(not _have_scipy, "SciPy not installed") -class SciPyTests(MLlibTestCase): - - """ - Test both vector operations and MLlib algorithms with SciPy sparse matrices, - if SciPy is available. - """ - - def test_serialize(self): - from scipy.sparse import lil_matrix - lil = lil_matrix((4, 1)) - lil[1, 0] = 1 - lil[3, 0] = 2 - sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEqual(sv, _convert_to_vector(lil)) - self.assertEqual(sv, _convert_to_vector(lil.tocsc())) - self.assertEqual(sv, _convert_to_vector(lil.tocoo())) - self.assertEqual(sv, _convert_to_vector(lil.tocsr())) - self.assertEqual(sv, _convert_to_vector(lil.todok())) - - def serialize(l): - return ser.loads(ser.dumps(_convert_to_vector(l))) - self.assertEqual(sv, serialize(lil)) - self.assertEqual(sv, serialize(lil.tocsc())) - self.assertEqual(sv, serialize(lil.tocsr())) - self.assertEqual(sv, serialize(lil.todok())) - - def test_convert_to_vector(self): - from scipy.sparse import csc_matrix - # Create a CSC matrix with non-sorted indices - indptr = array([0, 2]) - indices = array([3, 1]) - data = array([2.0, 1.0]) - csc = csc_matrix((data, indices, indptr)) - self.assertFalse(csc.has_sorted_indices) - sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEqual(sv, _convert_to_vector(csc)) - - def test_dot(self): - from scipy.sparse import lil_matrix - lil = lil_matrix((4, 1)) - lil[1, 0] = 1 - lil[3, 0] = 2 - dv = DenseVector(array([1., 2., 3., 4.])) - self.assertEqual(10.0, dv.dot(lil)) - - def test_squared_distance(self): - from scipy.sparse import lil_matrix - lil = lil_matrix((4, 1)) - lil[1, 0] = 3 - lil[3, 0] = 2 - dv = DenseVector(array([1., 2., 3., 4.])) - sv = SparseVector(4, {0: 1, 1: 2, 2: 3, 3: 4}) - self.assertEqual(15.0, dv.squared_distance(lil)) - self.assertEqual(15.0, sv.squared_distance(lil)) - - def scipy_matrix(self, size, values): - """Create a column SciPy matrix from a dictionary of values""" - from scipy.sparse import lil_matrix - lil = lil_matrix((size, 1)) - for key, value in values.items(): - lil[key, 0] = value - return lil - - def test_clustering(self): - from pyspark.mllib.clustering import KMeans - data = [ - self.scipy_matrix(3, {1: 1.0}), - self.scipy_matrix(3, {1: 1.1}), - self.scipy_matrix(3, {2: 1.0}), - self.scipy_matrix(3, {2: 1.1}) - ] - clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||") - self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) - self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) - - def test_classification(self): - from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes - from pyspark.mllib.tree import DecisionTree - data = [ - LabeledPoint(0.0, self.scipy_matrix(2, {0: 1.0})), - LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), - LabeledPoint(0.0, self.scipy_matrix(2, {0: 2.0})), - LabeledPoint(1.0, self.scipy_matrix(2, {1: 2.0})) - ] - rdd = self.sc.parallelize(data) - features = [p.features for p in data] - - lr_model = LogisticRegressionWithSGD.train(rdd) - self.assertTrue(lr_model.predict(features[0]) <= 0) - self.assertTrue(lr_model.predict(features[1]) > 0) - self.assertTrue(lr_model.predict(features[2]) <= 0) - self.assertTrue(lr_model.predict(features[3]) > 0) - - svm_model = SVMWithSGD.train(rdd) - self.assertTrue(svm_model.predict(features[0]) <= 0) - self.assertTrue(svm_model.predict(features[1]) > 0) - self.assertTrue(svm_model.predict(features[2]) <= 0) - self.assertTrue(svm_model.predict(features[3]) > 0) - - nb_model = NaiveBayes.train(rdd) - self.assertTrue(nb_model.predict(features[0]) <= 0) - self.assertTrue(nb_model.predict(features[1]) > 0) - self.assertTrue(nb_model.predict(features[2]) <= 0) - self.assertTrue(nb_model.predict(features[3]) > 0) - - categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories - dt_model = DecisionTree.trainClassifier(rdd, numClasses=2, - categoricalFeaturesInfo=categoricalFeaturesInfo) - self.assertTrue(dt_model.predict(features[0]) <= 0) - self.assertTrue(dt_model.predict(features[1]) > 0) - self.assertTrue(dt_model.predict(features[2]) <= 0) - self.assertTrue(dt_model.predict(features[3]) > 0) - - def test_regression(self): - from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ - RidgeRegressionWithSGD - from pyspark.mllib.tree import DecisionTree - data = [ - LabeledPoint(-1.0, self.scipy_matrix(2, {1: -1.0})), - LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), - LabeledPoint(-1.0, self.scipy_matrix(2, {1: -2.0})), - LabeledPoint(1.0, self.scipy_matrix(2, {1: 2.0})) - ] - rdd = self.sc.parallelize(data) - features = [p.features for p in data] - - lr_model = LinearRegressionWithSGD.train(rdd) - self.assertTrue(lr_model.predict(features[0]) <= 0) - self.assertTrue(lr_model.predict(features[1]) > 0) - self.assertTrue(lr_model.predict(features[2]) <= 0) - self.assertTrue(lr_model.predict(features[3]) > 0) - - lasso_model = LassoWithSGD.train(rdd) - self.assertTrue(lasso_model.predict(features[0]) <= 0) - self.assertTrue(lasso_model.predict(features[1]) > 0) - self.assertTrue(lasso_model.predict(features[2]) <= 0) - self.assertTrue(lasso_model.predict(features[3]) > 0) - - rr_model = RidgeRegressionWithSGD.train(rdd) - self.assertTrue(rr_model.predict(features[0]) <= 0) - self.assertTrue(rr_model.predict(features[1]) > 0) - self.assertTrue(rr_model.predict(features[2]) <= 0) - self.assertTrue(rr_model.predict(features[3]) > 0) - - categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories - dt_model = DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) - self.assertTrue(dt_model.predict(features[0]) <= 0) - self.assertTrue(dt_model.predict(features[1]) > 0) - self.assertTrue(dt_model.predict(features[2]) <= 0) - self.assertTrue(dt_model.predict(features[3]) > 0) - - -class ChiSqTestTests(MLlibTestCase): - def test_goodness_of_fit(self): - from numpy import inf - - observed = Vectors.dense([4, 6, 5]) - pearson = Statistics.chiSqTest(observed) - - # Validated against the R command `chisq.test(c(4, 6, 5), p=c(1/3, 1/3, 1/3))` - self.assertEqual(pearson.statistic, 0.4) - self.assertEqual(pearson.degreesOfFreedom, 2) - self.assertAlmostEqual(pearson.pValue, 0.8187, 4) - - # Different expected and observed sum - observed1 = Vectors.dense([21, 38, 43, 80]) - expected1 = Vectors.dense([3, 5, 7, 20]) - pearson1 = Statistics.chiSqTest(observed1, expected1) - - # Results validated against the R command - # `chisq.test(c(21, 38, 43, 80), p=c(3/35, 1/7, 1/5, 4/7))` - self.assertAlmostEqual(pearson1.statistic, 14.1429, 4) - self.assertEqual(pearson1.degreesOfFreedom, 3) - self.assertAlmostEqual(pearson1.pValue, 0.002717, 4) - - # Vectors with different sizes - observed3 = Vectors.dense([1.0, 2.0, 3.0]) - expected3 = Vectors.dense([1.0, 2.0, 3.0, 4.0]) - self.assertRaises(ValueError, Statistics.chiSqTest, observed3, expected3) - - # Negative counts in observed - neg_obs = Vectors.dense([1.0, 2.0, 3.0, -4.0]) - self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_obs, expected1) - - # Count = 0.0 in expected but not observed - zero_expected = Vectors.dense([1.0, 0.0, 3.0]) - pearson_inf = Statistics.chiSqTest(observed, zero_expected) - self.assertEqual(pearson_inf.statistic, inf) - self.assertEqual(pearson_inf.degreesOfFreedom, 2) - self.assertEqual(pearson_inf.pValue, 0.0) - - # 0.0 in expected and observed simultaneously - zero_observed = Vectors.dense([2.0, 0.0, 1.0]) - self.assertRaises( - IllegalArgumentException, Statistics.chiSqTest, zero_observed, zero_expected) - - def test_matrix_independence(self): - data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0] - chi = Statistics.chiSqTest(Matrices.dense(3, 4, data)) - - # Results validated against R command - # `chisq.test(rbind(c(40, 56, 31, 30),c(24, 32, 10, 15), c(29, 42, 0, 12)))` - self.assertAlmostEqual(chi.statistic, 21.9958, 4) - self.assertEqual(chi.degreesOfFreedom, 6) - self.assertAlmostEqual(chi.pValue, 0.001213, 4) - - # Negative counts - neg_counts = Matrices.dense(2, 2, [4.0, 5.0, 3.0, -3.0]) - self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_counts) - - # Row sum = 0.0 - row_zero = Matrices.dense(2, 2, [0.0, 1.0, 0.0, 2.0]) - self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, row_zero) - - # Column sum = 0.0 - col_zero = Matrices.dense(2, 2, [0.0, 0.0, 2.0, 2.0]) - self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, col_zero) - - def test_chi_sq_pearson(self): - data = [ - LabeledPoint(0.0, Vectors.dense([0.5, 10.0])), - LabeledPoint(0.0, Vectors.dense([1.5, 20.0])), - LabeledPoint(1.0, Vectors.dense([1.5, 30.0])), - LabeledPoint(0.0, Vectors.dense([3.5, 30.0])), - LabeledPoint(0.0, Vectors.dense([3.5, 40.0])), - LabeledPoint(1.0, Vectors.dense([3.5, 40.0])) - ] - - for numParts in [2, 4, 6, 8]: - chi = Statistics.chiSqTest(self.sc.parallelize(data, numParts)) - feature1 = chi[0] - self.assertEqual(feature1.statistic, 0.75) - self.assertEqual(feature1.degreesOfFreedom, 2) - self.assertAlmostEqual(feature1.pValue, 0.6873, 4) - - feature2 = chi[1] - self.assertEqual(feature2.statistic, 1.5) - self.assertEqual(feature2.degreesOfFreedom, 3) - self.assertAlmostEqual(feature2.pValue, 0.6823, 4) - - def test_right_number_of_results(self): - num_cols = 1001 - sparse_data = [ - LabeledPoint(0.0, Vectors.sparse(num_cols, [(100, 2.0)])), - LabeledPoint(0.1, Vectors.sparse(num_cols, [(200, 1.0)])) - ] - chi = Statistics.chiSqTest(self.sc.parallelize(sparse_data)) - self.assertEqual(len(chi), num_cols) - self.assertIsNotNone(chi[1000]) - - -class KolmogorovSmirnovTest(MLlibTestCase): - - def test_R_implementation_equivalence(self): - data = self.sc.parallelize([ - 1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501, - -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555, - -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063, - -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691, - 0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942 - ]) - model = Statistics.kolmogorovSmirnovTest(data, "norm") - self.assertAlmostEqual(model.statistic, 0.189, 3) - self.assertAlmostEqual(model.pValue, 0.422, 3) - - model = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1) - self.assertAlmostEqual(model.statistic, 0.189, 3) - self.assertAlmostEqual(model.pValue, 0.422, 3) - - -class SerDeTest(MLlibTestCase): - def test_to_java_object_rdd(self): # SPARK-6660 - data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0) - self.assertEqual(_to_java_object_rdd(data).count(), 10) - - -class FeatureTest(MLlibTestCase): - def test_idf_model(self): - data = [ - Vectors.dense([1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3]), - Vectors.dense([1, 3, 0, 1, 3, 0, 0, 2, 0, 0, 1]), - Vectors.dense([1, 4, 1, 0, 0, 4, 9, 0, 1, 2, 0]), - Vectors.dense([2, 1, 0, 3, 0, 0, 5, 0, 2, 3, 9]) - ] - model = IDF().fit(self.sc.parallelize(data, 2)) - idf = model.idf() - self.assertEqual(len(idf), 11) - - -class Word2VecTests(MLlibTestCase): - def test_word2vec_setters(self): - model = Word2Vec() \ - .setVectorSize(2) \ - .setLearningRate(0.01) \ - .setNumPartitions(2) \ - .setNumIterations(10) \ - .setSeed(1024) \ - .setMinCount(3) \ - .setWindowSize(6) - self.assertEqual(model.vectorSize, 2) - self.assertTrue(model.learningRate < 0.02) - self.assertEqual(model.numPartitions, 2) - self.assertEqual(model.numIterations, 10) - self.assertEqual(model.seed, 1024) - self.assertEqual(model.minCount, 3) - self.assertEqual(model.windowSize, 6) - - def test_word2vec_get_vectors(self): - data = [ - ["a", "b", "c", "d", "e", "f", "g"], - ["a", "b", "c", "d", "e", "f"], - ["a", "b", "c", "d", "e"], - ["a", "b", "c", "d"], - ["a", "b", "c"], - ["a", "b"], - ["a"] - ] - model = Word2Vec().fit(self.sc.parallelize(data)) - self.assertEqual(len(model.getVectors()), 3) - - -class StandardScalerTests(MLlibTestCase): - def test_model_setters(self): - data = [ - [1.0, 2.0, 3.0], - [2.0, 3.0, 4.0], - [3.0, 4.0, 5.0] - ] - model = StandardScaler().fit(self.sc.parallelize(data)) - self.assertIsNotNone(model.setWithMean(True)) - self.assertIsNotNone(model.setWithStd(True)) - self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([-1.0, -1.0, -1.0])) - - def test_model_transform(self): - data = [ - [1.0, 2.0, 3.0], - [2.0, 3.0, 4.0], - [3.0, 4.0, 5.0] - ] - model = StandardScaler().fit(self.sc.parallelize(data)) - self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0])) - - -class ElementwiseProductTests(MLlibTestCase): - def test_model_transform(self): - weight = Vectors.dense([3, 2, 1]) - - densevec = Vectors.dense([4, 5, 6]) - sparsevec = Vectors.sparse(3, [0], [1]) - eprod = ElementwiseProduct(weight) - self.assertEqual(eprod.transform(densevec), DenseVector([12, 10, 6])) - self.assertEqual( - eprod.transform(sparsevec), SparseVector(3, [0], [3])) - - -class StreamingKMeansTest(MLLibStreamingTestCase): - def test_model_params(self): - """Test that the model params are set correctly""" - stkm = StreamingKMeans() - stkm.setK(5).setDecayFactor(0.0) - self.assertEqual(stkm._k, 5) - self.assertEqual(stkm._decayFactor, 0.0) - - # Model not set yet. - self.assertIsNone(stkm.latestModel()) - self.assertRaises(ValueError, stkm.trainOn, [0.0, 1.0]) - - stkm.setInitialCenters( - centers=[[0.0, 0.0], [1.0, 1.0]], weights=[1.0, 1.0]) - self.assertEqual( - stkm.latestModel().centers, [[0.0, 0.0], [1.0, 1.0]]) - self.assertEqual(stkm.latestModel().clusterWeights, [1.0, 1.0]) - - def test_accuracy_for_single_center(self): - """Test that parameters obtained are correct for a single center.""" - centers, batches = self.streamingKMeansDataGenerator( - batches=5, numPoints=5, k=1, d=5, r=0.1, seed=0) - stkm = StreamingKMeans(1) - stkm.setInitialCenters([[0., 0., 0., 0., 0.]], [0.]) - input_stream = self.ssc.queueStream( - [self.sc.parallelize(batch, 1) for batch in batches]) - stkm.trainOn(input_stream) - - self.ssc.start() - - def condition(): - self.assertEqual(stkm.latestModel().clusterWeights, [25.0]) - return True - self._eventually(condition, catch_assertions=True) - - realCenters = array_sum(array(centers), axis=0) - for i in range(5): - modelCenters = stkm.latestModel().centers[0][i] - self.assertAlmostEqual(centers[0][i], modelCenters, 1) - self.assertAlmostEqual(realCenters[i], modelCenters, 1) - - def streamingKMeansDataGenerator(self, batches, numPoints, - k, d, r, seed, centers=None): - rng = random.RandomState(seed) - - # Generate centers. - centers = [rng.randn(d) for i in range(k)] - - return centers, [[Vectors.dense(centers[j % k] + r * rng.randn(d)) - for j in range(numPoints)] - for i in range(batches)] - - def test_trainOn_model(self): - """Test the model on toy data with four clusters.""" - stkm = StreamingKMeans() - initCenters = [[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]] - stkm.setInitialCenters( - centers=initCenters, weights=[1.0, 1.0, 1.0, 1.0]) - - # Create a toy dataset by setting a tiny offset for each point. - offsets = [[0, 0.1], [0, -0.1], [0.1, 0], [-0.1, 0]] - batches = [] - for offset in offsets: - batches.append([[offset[0] + center[0], offset[1] + center[1]] - for center in initCenters]) - - batches = [self.sc.parallelize(batch, 1) for batch in batches] - input_stream = self.ssc.queueStream(batches) - stkm.trainOn(input_stream) - self.ssc.start() - - # Give enough time to train the model. - def condition(): - finalModel = stkm.latestModel() - self.assertTrue(all(finalModel.centers == array(initCenters))) - self.assertEqual(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) - return True - self._eventually(condition, catch_assertions=True) - - def test_predictOn_model(self): - """Test that the model predicts correctly on toy data.""" - stkm = StreamingKMeans() - stkm._model = StreamingKMeansModel( - clusterCenters=[[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]], - clusterWeights=[1.0, 1.0, 1.0, 1.0]) - - predict_data = [[[1.5, 1.5]], [[-1.5, 1.5]], [[-1.5, -1.5]], [[1.5, -1.5]]] - predict_data = [self.sc.parallelize(batch, 1) for batch in predict_data] - predict_stream = self.ssc.queueStream(predict_data) - predict_val = stkm.predictOn(predict_stream) - - result = [] - - def update(rdd): - rdd_collect = rdd.collect() - if rdd_collect: - result.append(rdd_collect) - - predict_val.foreachRDD(update) - self.ssc.start() - - def condition(): - self.assertEqual(result, [[0], [1], [2], [3]]) - return True - - self._eventually(condition, catch_assertions=True) - - @unittest.skip("SPARK-10086: Flaky StreamingKMeans test in PySpark") - def test_trainOn_predictOn(self): - """Test that prediction happens on the updated model.""" - stkm = StreamingKMeans(decayFactor=0.0, k=2) - stkm.setInitialCenters([[0.0], [1.0]], [1.0, 1.0]) - - # Since decay factor is set to zero, once the first batch - # is passed the clusterCenters are updated to [-0.5, 0.7] - # which causes 0.2 & 0.3 to be classified as 1, even though the - # classification based in the initial model would have been 0 - # proving that the model is updated. - batches = [[[-0.5], [0.6], [0.8]], [[0.2], [-0.1], [0.3]]] - batches = [self.sc.parallelize(batch) for batch in batches] - input_stream = self.ssc.queueStream(batches) - predict_results = [] - - def collect(rdd): - rdd_collect = rdd.collect() - if rdd_collect: - predict_results.append(rdd_collect) - - stkm.trainOn(input_stream) - predict_stream = stkm.predictOn(input_stream) - predict_stream.foreachRDD(collect) - - self.ssc.start() - - def condition(): - self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]]) - return True - - self._eventually(condition, catch_assertions=True) - - -class LinearDataGeneratorTests(MLlibTestCase): - def test_dim(self): - linear_data = LinearDataGenerator.generateLinearInput( - intercept=0.0, weights=[0.0, 0.0, 0.0], - xMean=[0.0, 0.0, 0.0], xVariance=[0.33, 0.33, 0.33], - nPoints=4, seed=0, eps=0.1) - self.assertEqual(len(linear_data), 4) - for point in linear_data: - self.assertEqual(len(point.features), 3) - - linear_data = LinearDataGenerator.generateLinearRDD( - sc=self.sc, nexamples=6, nfeatures=2, eps=0.1, - nParts=2, intercept=0.0).collect() - self.assertEqual(len(linear_data), 6) - for point in linear_data: - self.assertEqual(len(point.features), 2) - - -class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase): - - @staticmethod - def generateLogisticInput(offset, scale, nPoints, seed): - """ - Generate 1 / (1 + exp(-x * scale + offset)) - - where, - x is randomnly distributed and the threshold - and labels for each sample in x is obtained from a random uniform - distribution. - """ - rng = random.RandomState(seed) - x = rng.randn(nPoints) - sigmoid = 1. / (1 + exp(-(dot(x, scale) + offset))) - y_p = rng.rand(nPoints) - cut_off = y_p <= sigmoid - y_p[cut_off] = 1.0 - y_p[~cut_off] = 0.0 - return [ - LabeledPoint(y_p[i], Vectors.dense([x[i]])) - for i in range(nPoints)] - - def test_parameter_accuracy(self): - """ - Test that the final value of weights is close to the desired value. - """ - input_batches = [ - self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) - for i in range(20)] - input_stream = self.ssc.queueStream(input_batches) - - slr = StreamingLogisticRegressionWithSGD( - stepSize=0.2, numIterations=25) - slr.setInitialWeights([0.0]) - slr.trainOn(input_stream) - - self.ssc.start() - - def condition(): - rel = (1.5 - slr.latestModel().weights.array[0]) / 1.5 - self.assertAlmostEqual(rel, 0.1, 1) - return True - - self._eventually(condition, catch_assertions=True) - - def test_convergence(self): - """ - Test that weights converge to the required value on toy data. - """ - input_batches = [ - self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) - for i in range(20)] - input_stream = self.ssc.queueStream(input_batches) - models = [] - - slr = StreamingLogisticRegressionWithSGD( - stepSize=0.2, numIterations=25) - slr.setInitialWeights([0.0]) - slr.trainOn(input_stream) - input_stream.foreachRDD( - lambda x: models.append(slr.latestModel().weights[0])) - - self.ssc.start() - - def condition(): - self.assertEqual(len(models), len(input_batches)) - return True - - # We want all batches to finish for this test. - self._eventually(condition, 60.0, catch_assertions=True) - - t_models = array(models) - diff = t_models[1:] - t_models[:-1] - # Test that weights improve with a small tolerance - self.assertTrue(all(diff >= -0.1)) - self.assertTrue(array_sum(diff > 0) > 1) - - @staticmethod - def calculate_accuracy_error(true, predicted): - return sum(abs(array(true) - array(predicted))) / len(true) - - def test_predictions(self): - """Test predicted values on a toy model.""" - input_batches = [] - for i in range(20): - batch = self.sc.parallelize( - self.generateLogisticInput(0, 1.5, 100, 42 + i)) - input_batches.append(batch.map(lambda x: (x.label, x.features))) - input_stream = self.ssc.queueStream(input_batches) - - slr = StreamingLogisticRegressionWithSGD( - stepSize=0.2, numIterations=25) - slr.setInitialWeights([1.5]) - predict_stream = slr.predictOnValues(input_stream) - true_predicted = [] - predict_stream.foreachRDD(lambda x: true_predicted.append(x.collect())) - self.ssc.start() - - def condition(): - self.assertEqual(len(true_predicted), len(input_batches)) - return True - - self._eventually(condition, catch_assertions=True) - - # Test that the accuracy error is no more than 0.4 on each batch. - for batch in true_predicted: - true, predicted = zip(*batch) - self.assertTrue( - self.calculate_accuracy_error(true, predicted) < 0.4) - - def test_training_and_prediction(self): - """Test that the model improves on toy data with no. of batches""" - input_batches = [ - self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) - for i in range(20)] - predict_batches = [ - b.map(lambda lp: (lp.label, lp.features)) for b in input_batches] - - slr = StreamingLogisticRegressionWithSGD( - stepSize=0.01, numIterations=25) - slr.setInitialWeights([-0.1]) - errors = [] - - def collect_errors(rdd): - true, predicted = zip(*rdd.collect()) - errors.append(self.calculate_accuracy_error(true, predicted)) - - true_predicted = [] - input_stream = self.ssc.queueStream(input_batches) - predict_stream = self.ssc.queueStream(predict_batches) - slr.trainOn(input_stream) - ps = slr.predictOnValues(predict_stream) - ps.foreachRDD(lambda x: collect_errors(x)) - - self.ssc.start() - - def condition(): - # Test that the improvement in error is > 0.3 - if len(errors) == len(predict_batches): - self.assertGreater(errors[1] - errors[-1], 0.3) - if len(errors) >= 3 and errors[1] - errors[-1] > 0.3: - return True - return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) - - self._eventually(condition) - - -class StreamingLinearRegressionWithTests(MLLibStreamingTestCase): - - def assertArrayAlmostEqual(self, array1, array2, dec): - for i, j in array1, array2: - self.assertAlmostEqual(i, j, dec) - - def test_parameter_accuracy(self): - """Test that coefs are predicted accurately by fitting on toy data.""" - - # Test that fitting (10*X1 + 10*X2), (X1, X2) gives coefficients - # (10, 10) - slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) - slr.setInitialWeights([0.0, 0.0]) - xMean = [0.0, 0.0] - xVariance = [1.0 / 3.0, 1.0 / 3.0] - - # Create ten batches with 100 sample points in each. - batches = [] - for i in range(10): - batch = LinearDataGenerator.generateLinearInput( - 0.0, [10.0, 10.0], xMean, xVariance, 100, 42 + i, 0.1) - batches.append(self.sc.parallelize(batch)) - - input_stream = self.ssc.queueStream(batches) - slr.trainOn(input_stream) - self.ssc.start() - - def condition(): - self.assertArrayAlmostEqual( - slr.latestModel().weights.array, [10., 10.], 1) - self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1) - return True - - self._eventually(condition, catch_assertions=True) - - def test_parameter_convergence(self): - """Test that the model parameters improve with streaming data.""" - slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) - slr.setInitialWeights([0.0]) - - # Create ten batches with 100 sample points in each. - batches = [] - for i in range(10): - batch = LinearDataGenerator.generateLinearInput( - 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) - batches.append(self.sc.parallelize(batch)) - - model_weights = [] - input_stream = self.ssc.queueStream(batches) - input_stream.foreachRDD( - lambda x: model_weights.append(slr.latestModel().weights[0])) - slr.trainOn(input_stream) - self.ssc.start() - - def condition(): - self.assertEqual(len(model_weights), len(batches)) - return True - - # We want all batches to finish for this test. - self._eventually(condition, catch_assertions=True) - - w = array(model_weights) - diff = w[1:] - w[:-1] - self.assertTrue(all(diff >= -0.1)) - - def test_prediction(self): - """Test prediction on a model with weights already set.""" - # Create a model with initial Weights equal to coefs - slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) - slr.setInitialWeights([10.0, 10.0]) - - # Create ten batches with 100 sample points in each. - batches = [] - for i in range(10): - batch = LinearDataGenerator.generateLinearInput( - 0.0, [10.0, 10.0], [0.0, 0.0], [1.0 / 3.0, 1.0 / 3.0], - 100, 42 + i, 0.1) - batches.append( - self.sc.parallelize(batch).map(lambda lp: (lp.label, lp.features))) - - input_stream = self.ssc.queueStream(batches) - output_stream = slr.predictOnValues(input_stream) - samples = [] - output_stream.foreachRDD(lambda x: samples.append(x.collect())) - - self.ssc.start() - - def condition(): - self.assertEqual(len(samples), len(batches)) - return True - - # We want all batches to finish for this test. - self._eventually(condition, catch_assertions=True) - - # Test that mean absolute error on each batch is less than 0.1 - for batch in samples: - true, predicted = zip(*batch) - self.assertTrue(mean(abs(array(true) - array(predicted))) < 0.1) - - def test_train_prediction(self): - """Test that error on test data improves as model is trained.""" - slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) - slr.setInitialWeights([0.0]) - - # Create ten batches with 100 sample points in each. - batches = [] - for i in range(10): - batch = LinearDataGenerator.generateLinearInput( - 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) - batches.append(self.sc.parallelize(batch)) - - predict_batches = [ - b.map(lambda lp: (lp.label, lp.features)) for b in batches] - errors = [] - - def func(rdd): - true, predicted = zip(*rdd.collect()) - errors.append(mean(abs(true) - abs(predicted))) - - input_stream = self.ssc.queueStream(batches) - output_stream = self.ssc.queueStream(predict_batches) - slr.trainOn(input_stream) - output_stream = slr.predictOnValues(output_stream) - output_stream.foreachRDD(func) - self.ssc.start() - - def condition(): - if len(errors) == len(predict_batches): - self.assertGreater(errors[1] - errors[-1], 2) - if len(errors) >= 3 and errors[1] - errors[-1] > 2: - return True - return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) - - self._eventually(condition) - - -class MLUtilsTests(MLlibTestCase): - def test_append_bias(self): - data = [2.0, 2.0, 2.0] - ret = MLUtils.appendBias(data) - self.assertEqual(ret[3], 1.0) - self.assertEqual(type(ret), DenseVector) - - def test_append_bias_with_vector(self): - data = Vectors.dense([2.0, 2.0, 2.0]) - ret = MLUtils.appendBias(data) - self.assertEqual(ret[3], 1.0) - self.assertEqual(type(ret), DenseVector) - - def test_append_bias_with_sp_vector(self): - data = Vectors.sparse(3, {0: 2.0, 2: 2.0}) - expected = Vectors.sparse(4, {0: 2.0, 2: 2.0, 3: 1.0}) - # Returned value must be SparseVector - ret = MLUtils.appendBias(data) - self.assertEqual(ret, expected) - self.assertEqual(type(ret), SparseVector) - - def test_load_vectors(self): - import shutil - data = [ - [1.0, 2.0, 3.0], - [1.0, 2.0, 3.0] - ] - temp_dir = tempfile.mkdtemp() - load_vectors_path = os.path.join(temp_dir, "test_load_vectors") - try: - self.sc.parallelize(data).saveAsTextFile(load_vectors_path) - ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path) - ret = ret_rdd.collect() - self.assertEqual(len(ret), 2) - self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0])) - self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0])) - except: - self.fail() - finally: - shutil.rmtree(load_vectors_path) - - -class ALSTests(MLlibTestCase): - - def test_als_ratings_serialize(self): - r = Rating(7, 1123, 3.14) - jr = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(r))) - nr = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jr))) - self.assertEqual(r.user, nr.user) - self.assertEqual(r.product, nr.product) - self.assertAlmostEqual(r.rating, nr.rating, 2) - - def test_als_ratings_id_long_error(self): - r = Rating(1205640308657491975, 50233468418, 1.0) - # rating user id exceeds max int value, should fail when pickled - self.assertRaises(Py4JJavaError, self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads, - bytearray(ser.dumps(r))) - - -class HashingTFTest(MLlibTestCase): - - def test_binary_term_freqs(self): - hashingTF = HashingTF(100).setBinary(True) - doc = "a a b c c c".split(" ") - n = hashingTF.numFeatures - output = hashingTF.transform(doc).toArray() - expected = Vectors.sparse(n, {hashingTF.indexOf("a"): 1.0, - hashingTF.indexOf("b"): 1.0, - hashingTF.indexOf("c"): 1.0}).toArray() - for i in range(0, n): - self.assertAlmostEqual(output[i], expected[i], 14, "Error at " + str(i) + - ": expected " + str(expected[i]) + ", got " + str(output[i])) - - -class DimensionalityReductionTests(MLlibTestCase): - - denseData = [ - Vectors.dense([0.0, 1.0, 2.0]), - Vectors.dense([3.0, 4.0, 5.0]), - Vectors.dense([6.0, 7.0, 8.0]), - Vectors.dense([9.0, 0.0, 1.0]) - ] - sparseData = [ - Vectors.sparse(3, [(1, 1.0), (2, 2.0)]), - Vectors.sparse(3, [(0, 3.0), (1, 4.0), (2, 5.0)]), - Vectors.sparse(3, [(0, 6.0), (1, 7.0), (2, 8.0)]), - Vectors.sparse(3, [(0, 9.0), (2, 1.0)]) - ] - - def assertEqualUpToSign(self, vecA, vecB): - eq1 = vecA - vecB - eq2 = vecA + vecB - self.assertTrue(sum(abs(eq1)) < 1e-6 or sum(abs(eq2)) < 1e-6) - - def test_svd(self): - denseMat = RowMatrix(self.sc.parallelize(self.denseData)) - sparseMat = RowMatrix(self.sc.parallelize(self.sparseData)) - m = 4 - n = 3 - for mat in [denseMat, sparseMat]: - for k in range(1, 4): - rm = mat.computeSVD(k, computeU=True) - self.assertEqual(rm.s.size, k) - self.assertEqual(rm.U.numRows(), m) - self.assertEqual(rm.U.numCols(), k) - self.assertEqual(rm.V.numRows, n) - self.assertEqual(rm.V.numCols, k) - - # Test that U returned is None if computeU is set to False. - self.assertEqual(mat.computeSVD(1).U, None) - - # Test that low rank matrices cannot have number of singular values - # greater than a limit. - rm = RowMatrix(self.sc.parallelize(tile([1, 2, 3], (3, 1)))) - self.assertEqual(rm.computeSVD(3, False, 1e-6).s.size, 1) - - def test_pca(self): - expected_pcs = array([ - [0.0, 1.0, 0.0], - [sqrt(2.0) / 2.0, 0.0, sqrt(2.0) / 2.0], - [sqrt(2.0) / 2.0, 0.0, -sqrt(2.0) / 2.0] - ]) - n = 3 - denseMat = RowMatrix(self.sc.parallelize(self.denseData)) - sparseMat = RowMatrix(self.sc.parallelize(self.sparseData)) - for mat in [denseMat, sparseMat]: - for k in range(1, 4): - pcs = mat.computePrincipalComponents(k) - self.assertEqual(pcs.numRows, n) - self.assertEqual(pcs.numCols, k) - - # We can just test the updated principal component for equality. - self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1]) - - -class FPGrowthTest(MLlibTestCase): - - def test_fpgrowth(self): - data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] - rdd = self.sc.parallelize(data, 2) - model1 = FPGrowth.train(rdd, 0.6, 2) - # use default data partition number when numPartitions is not specified - model2 = FPGrowth.train(rdd, 0.6) - self.assertEqual(sorted(model1.freqItemsets().collect()), - sorted(model2.freqItemsets().collect())) - -if __name__ == "__main__": - from pyspark.mllib.tests import * - if not _have_scipy: - print("NOTE: Skipping SciPy tests as it does not seem to be installed") - if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) - else: - unittest.main(verbosity=2) - if not _have_scipy: - print("NOTE: SciPy tests were skipped as it does not seem to be installed") - sc.stop() diff --git a/python/pyspark/mllib/tests/__init__.py b/python/pyspark/mllib/tests/__init__.py new file mode 100644 index 0000000000000..cce3acad34a49 --- /dev/null +++ b/python/pyspark/mllib/tests/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/mllib/tests/test_algorithms.py b/python/pyspark/mllib/tests/test_algorithms.py new file mode 100644 index 0000000000000..8a3454144a115 --- /dev/null +++ b/python/pyspark/mllib/tests/test_algorithms.py @@ -0,0 +1,313 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import sys +import tempfile +from shutil import rmtree + +from numpy import array, array_equal + +from py4j.protocol import Py4JJavaError + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.mllib.fpm import FPGrowth +from pyspark.mllib.recommendation import Rating +from pyspark.mllib.regression import LabeledPoint +from pyspark.sql.utils import IllegalArgumentException +from pyspark.testing.mllibutils import make_serializer, MLlibTestCase + + +ser = make_serializer() + + +class ListTests(MLlibTestCase): + + """ + Test MLlib algorithms on plain lists, to make sure they're passed through + as NumPy arrays. + """ + + def test_bisecting_kmeans(self): + from pyspark.mllib.clustering import BisectingKMeans + data = array([0.0, 0.0, 1.0, 1.0, 9.0, 8.0, 8.0, 9.0]).reshape(4, 2) + bskm = BisectingKMeans() + model = bskm.train(self.sc.parallelize(data, 2), k=4) + p = array([0.0, 0.0]) + rdd_p = self.sc.parallelize([p]) + self.assertEqual(model.predict(p), model.predict(rdd_p).first()) + self.assertEqual(model.computeCost(p), model.computeCost(rdd_p)) + self.assertEqual(model.k, len(model.clusterCenters)) + + def test_kmeans(self): + from pyspark.mllib.clustering import KMeans + data = [ + [0, 1.1], + [0, 1.2], + [1.1, 0], + [1.2, 0], + ] + clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||", + initializationSteps=7, epsilon=1e-4) + self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) + self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) + + def test_kmeans_deterministic(self): + from pyspark.mllib.clustering import KMeans + X = range(0, 100, 10) + Y = range(0, 100, 10) + data = [[x, y] for x, y in zip(X, Y)] + clusters1 = KMeans.train(self.sc.parallelize(data), + 3, initializationMode="k-means||", + seed=42, initializationSteps=7, epsilon=1e-4) + clusters2 = KMeans.train(self.sc.parallelize(data), + 3, initializationMode="k-means||", + seed=42, initializationSteps=7, epsilon=1e-4) + centers1 = clusters1.centers + centers2 = clusters2.centers + for c1, c2 in zip(centers1, centers2): + # TODO: Allow small numeric difference. + self.assertTrue(array_equal(c1, c2)) + + def test_gmm(self): + from pyspark.mllib.clustering import GaussianMixture + data = self.sc.parallelize([ + [1, 2], + [8, 9], + [-4, -3], + [-6, -7], + ]) + clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=10, seed=1) + labels = clusters.predict(data).collect() + self.assertEqual(labels[0], labels[1]) + self.assertEqual(labels[2], labels[3]) + + def test_gmm_deterministic(self): + from pyspark.mllib.clustering import GaussianMixture + x = range(0, 100, 10) + y = range(0, 100, 10) + data = self.sc.parallelize([[a, b] for a, b in zip(x, y)]) + clusters1 = GaussianMixture.train(data, 5, convergenceTol=0.001, + maxIterations=10, seed=63) + clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001, + maxIterations=10, seed=63) + for c1, c2 in zip(clusters1.weights, clusters2.weights): + self.assertEqual(round(c1, 7), round(c2, 7)) + + def test_gmm_with_initial_model(self): + from pyspark.mllib.clustering import GaussianMixture + data = self.sc.parallelize([ + (-10, -5), (-9, -4), (10, 5), (9, 4) + ]) + + gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=10, seed=63) + gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=10, seed=63, initialModel=gmm1) + self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0) + + def test_classification(self): + from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes + from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest, \ + RandomForestModel, GradientBoostedTrees, GradientBoostedTreesModel + data = [ + LabeledPoint(0.0, [1, 0, 0]), + LabeledPoint(1.0, [0, 1, 1]), + LabeledPoint(0.0, [2, 0, 0]), + LabeledPoint(1.0, [0, 2, 1]) + ] + rdd = self.sc.parallelize(data) + features = [p.features.tolist() for p in data] + + temp_dir = tempfile.mkdtemp() + + lr_model = LogisticRegressionWithSGD.train(rdd, iterations=10) + self.assertTrue(lr_model.predict(features[0]) <= 0) + self.assertTrue(lr_model.predict(features[1]) > 0) + self.assertTrue(lr_model.predict(features[2]) <= 0) + self.assertTrue(lr_model.predict(features[3]) > 0) + + svm_model = SVMWithSGD.train(rdd, iterations=10) + self.assertTrue(svm_model.predict(features[0]) <= 0) + self.assertTrue(svm_model.predict(features[1]) > 0) + self.assertTrue(svm_model.predict(features[2]) <= 0) + self.assertTrue(svm_model.predict(features[3]) > 0) + + nb_model = NaiveBayes.train(rdd) + self.assertTrue(nb_model.predict(features[0]) <= 0) + self.assertTrue(nb_model.predict(features[1]) > 0) + self.assertTrue(nb_model.predict(features[2]) <= 0) + self.assertTrue(nb_model.predict(features[3]) > 0) + + categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories + dt_model = DecisionTree.trainClassifier( + rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + + dt_model_dir = os.path.join(temp_dir, "dt") + dt_model.save(self.sc, dt_model_dir) + same_dt_model = DecisionTreeModel.load(self.sc, dt_model_dir) + self.assertEqual(same_dt_model.toDebugString(), dt_model.toDebugString()) + + rf_model = RandomForest.trainClassifier( + rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, + maxBins=4, seed=1) + self.assertTrue(rf_model.predict(features[0]) <= 0) + self.assertTrue(rf_model.predict(features[1]) > 0) + self.assertTrue(rf_model.predict(features[2]) <= 0) + self.assertTrue(rf_model.predict(features[3]) > 0) + + rf_model_dir = os.path.join(temp_dir, "rf") + rf_model.save(self.sc, rf_model_dir) + same_rf_model = RandomForestModel.load(self.sc, rf_model_dir) + self.assertEqual(same_rf_model.toDebugString(), rf_model.toDebugString()) + + gbt_model = GradientBoostedTrees.trainClassifier( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4) + self.assertTrue(gbt_model.predict(features[0]) <= 0) + self.assertTrue(gbt_model.predict(features[1]) > 0) + self.assertTrue(gbt_model.predict(features[2]) <= 0) + self.assertTrue(gbt_model.predict(features[3]) > 0) + + gbt_model_dir = os.path.join(temp_dir, "gbt") + gbt_model.save(self.sc, gbt_model_dir) + same_gbt_model = GradientBoostedTreesModel.load(self.sc, gbt_model_dir) + self.assertEqual(same_gbt_model.toDebugString(), gbt_model.toDebugString()) + + try: + rmtree(temp_dir) + except OSError: + pass + + def test_regression(self): + from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ + RidgeRegressionWithSGD + from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees + data = [ + LabeledPoint(-1.0, [0, -1]), + LabeledPoint(1.0, [0, 1]), + LabeledPoint(-1.0, [0, -2]), + LabeledPoint(1.0, [0, 2]) + ] + rdd = self.sc.parallelize(data) + features = [p.features.tolist() for p in data] + + lr_model = LinearRegressionWithSGD.train(rdd, iterations=10) + self.assertTrue(lr_model.predict(features[0]) <= 0) + self.assertTrue(lr_model.predict(features[1]) > 0) + self.assertTrue(lr_model.predict(features[2]) <= 0) + self.assertTrue(lr_model.predict(features[3]) > 0) + + lasso_model = LassoWithSGD.train(rdd, iterations=10) + self.assertTrue(lasso_model.predict(features[0]) <= 0) + self.assertTrue(lasso_model.predict(features[1]) > 0) + self.assertTrue(lasso_model.predict(features[2]) <= 0) + self.assertTrue(lasso_model.predict(features[3]) > 0) + + rr_model = RidgeRegressionWithSGD.train(rdd, iterations=10) + self.assertTrue(rr_model.predict(features[0]) <= 0) + self.assertTrue(rr_model.predict(features[1]) > 0) + self.assertTrue(rr_model.predict(features[2]) <= 0) + self.assertTrue(rr_model.predict(features[3]) > 0) + + categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories + dt_model = DecisionTree.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + + rf_model = RandomForest.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, maxBins=4, seed=1) + self.assertTrue(rf_model.predict(features[0]) <= 0) + self.assertTrue(rf_model.predict(features[1]) > 0) + self.assertTrue(rf_model.predict(features[2]) <= 0) + self.assertTrue(rf_model.predict(features[3]) > 0) + + gbt_model = GradientBoostedTrees.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4) + self.assertTrue(gbt_model.predict(features[0]) <= 0) + self.assertTrue(gbt_model.predict(features[1]) > 0) + self.assertTrue(gbt_model.predict(features[2]) <= 0) + self.assertTrue(gbt_model.predict(features[3]) > 0) + + try: + LinearRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) + LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) + RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) + except ValueError: + self.fail() + + # Verify that maxBins is being passed through + GradientBoostedTrees.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=32) + with self.assertRaises(Exception) as cm: + GradientBoostedTrees.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=1) + + +class ALSTests(MLlibTestCase): + + def test_als_ratings_serialize(self): + r = Rating(7, 1123, 3.14) + jr = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(r))) + nr = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jr))) + self.assertEqual(r.user, nr.user) + self.assertEqual(r.product, nr.product) + self.assertAlmostEqual(r.rating, nr.rating, 2) + + def test_als_ratings_id_long_error(self): + r = Rating(1205640308657491975, 50233468418, 1.0) + # rating user id exceeds max int value, should fail when pickled + self.assertRaises(Py4JJavaError, self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads, + bytearray(ser.dumps(r))) + + +class FPGrowthTest(MLlibTestCase): + + def test_fpgrowth(self): + data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] + rdd = self.sc.parallelize(data, 2) + model1 = FPGrowth.train(rdd, 0.6, 2) + # use default data partition number when numPartitions is not specified + model2 = FPGrowth.train(rdd, 0.6) + self.assertEqual(sorted(model1.freqItemsets().collect()), + sorted(model2.freqItemsets().collect())) + + +if __name__ == "__main__": + from pyspark.mllib.tests.test_algorithms import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/mllib/tests/test_feature.py b/python/pyspark/mllib/tests/test_feature.py new file mode 100644 index 0000000000000..48ed810fa6fcb --- /dev/null +++ b/python/pyspark/mllib/tests/test_feature.py @@ -0,0 +1,201 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +from math import sqrt + +from numpy import array, random, exp, abs, tile + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, Vectors +from pyspark.mllib.linalg.distributed import RowMatrix +from pyspark.mllib.feature import HashingTF, IDF, StandardScaler, ElementwiseProduct, Word2Vec +from pyspark.testing.mllibutils import MLlibTestCase + + +class FeatureTest(MLlibTestCase): + def test_idf_model(self): + data = [ + Vectors.dense([1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3]), + Vectors.dense([1, 3, 0, 1, 3, 0, 0, 2, 0, 0, 1]), + Vectors.dense([1, 4, 1, 0, 0, 4, 9, 0, 1, 2, 0]), + Vectors.dense([2, 1, 0, 3, 0, 0, 5, 0, 2, 3, 9]) + ] + model = IDF().fit(self.sc.parallelize(data, 2)) + idf = model.idf() + self.assertEqual(len(idf), 11) + + +class Word2VecTests(MLlibTestCase): + def test_word2vec_setters(self): + model = Word2Vec() \ + .setVectorSize(2) \ + .setLearningRate(0.01) \ + .setNumPartitions(2) \ + .setNumIterations(10) \ + .setSeed(1024) \ + .setMinCount(3) \ + .setWindowSize(6) + self.assertEqual(model.vectorSize, 2) + self.assertTrue(model.learningRate < 0.02) + self.assertEqual(model.numPartitions, 2) + self.assertEqual(model.numIterations, 10) + self.assertEqual(model.seed, 1024) + self.assertEqual(model.minCount, 3) + self.assertEqual(model.windowSize, 6) + + def test_word2vec_get_vectors(self): + data = [ + ["a", "b", "c", "d", "e", "f", "g"], + ["a", "b", "c", "d", "e", "f"], + ["a", "b", "c", "d", "e"], + ["a", "b", "c", "d"], + ["a", "b", "c"], + ["a", "b"], + ["a"] + ] + model = Word2Vec().fit(self.sc.parallelize(data)) + self.assertEqual(len(model.getVectors()), 3) + + +class StandardScalerTests(MLlibTestCase): + def test_model_setters(self): + data = [ + [1.0, 2.0, 3.0], + [2.0, 3.0, 4.0], + [3.0, 4.0, 5.0] + ] + model = StandardScaler().fit(self.sc.parallelize(data)) + self.assertIsNotNone(model.setWithMean(True)) + self.assertIsNotNone(model.setWithStd(True)) + self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([-1.0, -1.0, -1.0])) + + def test_model_transform(self): + data = [ + [1.0, 2.0, 3.0], + [2.0, 3.0, 4.0], + [3.0, 4.0, 5.0] + ] + model = StandardScaler().fit(self.sc.parallelize(data)) + self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0])) + + +class ElementwiseProductTests(MLlibTestCase): + def test_model_transform(self): + weight = Vectors.dense([3, 2, 1]) + + densevec = Vectors.dense([4, 5, 6]) + sparsevec = Vectors.sparse(3, [0], [1]) + eprod = ElementwiseProduct(weight) + self.assertEqual(eprod.transform(densevec), DenseVector([12, 10, 6])) + self.assertEqual( + eprod.transform(sparsevec), SparseVector(3, [0], [3])) + + +class HashingTFTest(MLlibTestCase): + + def test_binary_term_freqs(self): + hashingTF = HashingTF(100).setBinary(True) + doc = "a a b c c c".split(" ") + n = hashingTF.numFeatures + output = hashingTF.transform(doc).toArray() + expected = Vectors.sparse(n, {hashingTF.indexOf("a"): 1.0, + hashingTF.indexOf("b"): 1.0, + hashingTF.indexOf("c"): 1.0}).toArray() + for i in range(0, n): + self.assertAlmostEqual(output[i], expected[i], 14, "Error at " + str(i) + + ": expected " + str(expected[i]) + ", got " + str(output[i])) + + +class DimensionalityReductionTests(MLlibTestCase): + + denseData = [ + Vectors.dense([0.0, 1.0, 2.0]), + Vectors.dense([3.0, 4.0, 5.0]), + Vectors.dense([6.0, 7.0, 8.0]), + Vectors.dense([9.0, 0.0, 1.0]) + ] + sparseData = [ + Vectors.sparse(3, [(1, 1.0), (2, 2.0)]), + Vectors.sparse(3, [(0, 3.0), (1, 4.0), (2, 5.0)]), + Vectors.sparse(3, [(0, 6.0), (1, 7.0), (2, 8.0)]), + Vectors.sparse(3, [(0, 9.0), (2, 1.0)]) + ] + + def assertEqualUpToSign(self, vecA, vecB): + eq1 = vecA - vecB + eq2 = vecA + vecB + self.assertTrue(sum(abs(eq1)) < 1e-6 or sum(abs(eq2)) < 1e-6) + + def test_svd(self): + denseMat = RowMatrix(self.sc.parallelize(self.denseData)) + sparseMat = RowMatrix(self.sc.parallelize(self.sparseData)) + m = 4 + n = 3 + for mat in [denseMat, sparseMat]: + for k in range(1, 4): + rm = mat.computeSVD(k, computeU=True) + self.assertEqual(rm.s.size, k) + self.assertEqual(rm.U.numRows(), m) + self.assertEqual(rm.U.numCols(), k) + self.assertEqual(rm.V.numRows, n) + self.assertEqual(rm.V.numCols, k) + + # Test that U returned is None if computeU is set to False. + self.assertEqual(mat.computeSVD(1).U, None) + + # Test that low rank matrices cannot have number of singular values + # greater than a limit. + rm = RowMatrix(self.sc.parallelize(tile([1, 2, 3], (3, 1)))) + self.assertEqual(rm.computeSVD(3, False, 1e-6).s.size, 1) + + def test_pca(self): + expected_pcs = array([ + [0.0, 1.0, 0.0], + [sqrt(2.0) / 2.0, 0.0, sqrt(2.0) / 2.0], + [sqrt(2.0) / 2.0, 0.0, -sqrt(2.0) / 2.0] + ]) + n = 3 + denseMat = RowMatrix(self.sc.parallelize(self.denseData)) + sparseMat = RowMatrix(self.sc.parallelize(self.sparseData)) + for mat in [denseMat, sparseMat]: + for k in range(1, 4): + pcs = mat.computePrincipalComponents(k) + self.assertEqual(pcs.numRows, n) + self.assertEqual(pcs.numCols, k) + + # We can just test the updated principal component for equality. + self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1]) + + +if __name__ == "__main__": + from pyspark.mllib.tests.test_feature import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/mllib/tests/test_linalg.py b/python/pyspark/mllib/tests/test_linalg.py new file mode 100644 index 0000000000000..550e32a9af024 --- /dev/null +++ b/python/pyspark/mllib/tests/test_linalg.py @@ -0,0 +1,642 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import array as pyarray + +from numpy import array, array_equal, zeros, arange, tile, ones, inf +from numpy import sum as array_sum + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +import pyspark.ml.linalg as newlinalg +from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector, \ + DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT +from pyspark.mllib.regression import LabeledPoint +from pyspark.testing.mllibutils import make_serializer, MLlibTestCase + +_have_scipy = False +try: + import scipy.sparse + _have_scipy = True +except: + # No SciPy, but that's okay, we'll skip those tests + pass + + +ser = make_serializer() + + +def _squared_distance(a, b): + if isinstance(a, Vector): + return a.squared_distance(b) + else: + return b.squared_distance(a) + + +class VectorTests(MLlibTestCase): + + def _test_serialize(self, v): + self.assertEqual(v, ser.loads(ser.dumps(v))) + jvec = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(v))) + nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvec))) + self.assertEqual(v, nv) + vs = [v] * 100 + jvecs = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(vs))) + nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvecs))) + self.assertEqual(vs, nvs) + + def test_serialize(self): + self._test_serialize(DenseVector(range(10))) + self._test_serialize(DenseVector(array([1., 2., 3., 4.]))) + self._test_serialize(DenseVector(pyarray.array('d', range(10)))) + self._test_serialize(SparseVector(4, {1: 1, 3: 2})) + self._test_serialize(SparseVector(3, {})) + self._test_serialize(DenseMatrix(2, 3, range(6))) + sm1 = SparseMatrix( + 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) + self._test_serialize(sm1) + + def test_dot(self): + sv = SparseVector(4, {1: 1, 3: 2}) + dv = DenseVector(array([1., 2., 3., 4.])) + lst = DenseVector([1, 2, 3, 4]) + mat = array([[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]) + arr = pyarray.array('d', [0, 1, 2, 3]) + self.assertEqual(10.0, sv.dot(dv)) + self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) + self.assertEqual(30.0, dv.dot(dv)) + self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) + self.assertEqual(30.0, lst.dot(dv)) + self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) + self.assertEqual(7.0, sv.dot(arr)) + + def test_squared_distance(self): + sv = SparseVector(4, {1: 1, 3: 2}) + dv = DenseVector(array([1., 2., 3., 4.])) + lst = DenseVector([4, 3, 2, 1]) + lst1 = [4, 3, 2, 1] + arr = pyarray.array('d', [0, 2, 1, 3]) + narr = array([0, 2, 1, 3]) + self.assertEqual(15.0, _squared_distance(sv, dv)) + self.assertEqual(25.0, _squared_distance(sv, lst)) + self.assertEqual(20.0, _squared_distance(dv, lst)) + self.assertEqual(15.0, _squared_distance(dv, sv)) + self.assertEqual(25.0, _squared_distance(lst, sv)) + self.assertEqual(20.0, _squared_distance(lst, dv)) + self.assertEqual(0.0, _squared_distance(sv, sv)) + self.assertEqual(0.0, _squared_distance(dv, dv)) + self.assertEqual(0.0, _squared_distance(lst, lst)) + self.assertEqual(25.0, _squared_distance(sv, lst1)) + self.assertEqual(3.0, _squared_distance(sv, arr)) + self.assertEqual(3.0, _squared_distance(sv, narr)) + + def test_hash(self): + v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + self.assertEqual(hash(v1), hash(v2)) + self.assertEqual(hash(v1), hash(v3)) + self.assertEqual(hash(v2), hash(v3)) + self.assertFalse(hash(v1) == hash(v4)) + self.assertFalse(hash(v2) == hash(v4)) + + def test_eq(self): + v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) + v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) + v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + self.assertEqual(v1, v2) + self.assertEqual(v1, v3) + self.assertFalse(v2 == v4) + self.assertFalse(v1 == v5) + self.assertFalse(v1 == v6) + + def test_equals(self): + indices = [1, 2, 4] + values = [1., 3., 2.] + self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.])) + + def test_conversion(self): + # numpy arrays should be automatically upcast to float64 + # tests for fix of [SPARK-5089] + v = array([1, 2, 3, 4], dtype='float64') + dv = DenseVector(v) + self.assertTrue(dv.array.dtype == 'float64') + v = array([1, 2, 3, 4], dtype='float32') + dv = DenseVector(v) + self.assertTrue(dv.array.dtype == 'float64') + + def test_sparse_vector_indexing(self): + sv = SparseVector(5, {1: 1, 3: 2}) + self.assertEqual(sv[0], 0.) + self.assertEqual(sv[3], 2.) + self.assertEqual(sv[1], 1.) + self.assertEqual(sv[2], 0.) + self.assertEqual(sv[4], 0.) + self.assertEqual(sv[-1], 0.) + self.assertEqual(sv[-2], 2.) + self.assertEqual(sv[-3], 0.) + self.assertEqual(sv[-5], 0.) + for ind in [5, -6]: + self.assertRaises(IndexError, sv.__getitem__, ind) + for ind in [7.8, '1']: + self.assertRaises(TypeError, sv.__getitem__, ind) + + zeros = SparseVector(4, {}) + self.assertEqual(zeros[0], 0.0) + self.assertEqual(zeros[3], 0.0) + for ind in [4, -5]: + self.assertRaises(IndexError, zeros.__getitem__, ind) + + empty = SparseVector(0, {}) + for ind in [-1, 0, 1]: + self.assertRaises(IndexError, empty.__getitem__, ind) + + def test_sparse_vector_iteration(self): + self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0]) + self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0]) + + def test_matrix_indexing(self): + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) + expected = [[0, 6], [1, 8], [4, 10]] + for i in range(3): + for j in range(2): + self.assertEqual(mat[i, j], expected[i][j]) + + for i, j in [(-1, 0), (4, 1), (3, 4)]: + self.assertRaises(IndexError, mat.__getitem__, (i, j)) + + def test_repr_dense_matrix(self): + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) + self.assertTrue( + repr(mat), + 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') + + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True) + self.assertTrue( + repr(mat), + 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') + + mat = DenseMatrix(6, 3, zeros(18)) + self.assertTrue( + repr(mat), + 'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)') + + def test_repr_sparse_matrix(self): + sm1t = SparseMatrix( + 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], + isTransposed=True) + self.assertTrue( + repr(sm1t), + 'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)') + + indices = tile(arange(6), 3) + values = ones(18) + sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values) + self.assertTrue( + repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \ + [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \ + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)") + + self.assertTrue( + str(sm), + "6 X 3 CSCMatrix\n\ + (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\ + (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\ + (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..") + + sm = SparseMatrix(1, 18, zeros(19), [], []) + self.assertTrue( + repr(sm), + 'SparseMatrix(1, 18, \ + [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)') + + def test_sparse_matrix(self): + # Test sparse matrix creation. + sm1 = SparseMatrix( + 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) + self.assertEqual(sm1.numRows, 3) + self.assertEqual(sm1.numCols, 4) + self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) + self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2]) + self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) + self.assertTrue( + repr(sm1), + 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)') + + # Test indexing + expected = [ + [0, 0, 0, 0], + [1, 0, 4, 0], + [2, 0, 5, 0]] + + for i in range(3): + for j in range(4): + self.assertEqual(expected[i][j], sm1[i, j]) + self.assertTrue(array_equal(sm1.toArray(), expected)) + + for i, j in [(-1, 1), (4, 3), (3, 5)]: + self.assertRaises(IndexError, sm1.__getitem__, (i, j)) + + # Test conversion to dense and sparse. + smnew = sm1.toDense().toSparse() + self.assertEqual(sm1.numRows, smnew.numRows) + self.assertEqual(sm1.numCols, smnew.numCols) + self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) + self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) + self.assertTrue(array_equal(sm1.values, smnew.values)) + + sm1t = SparseMatrix( + 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], + isTransposed=True) + self.assertEqual(sm1t.numRows, 3) + self.assertEqual(sm1t.numCols, 4) + self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) + self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) + self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) + + expected = [ + [3, 2, 0, 0], + [0, 0, 4, 0], + [9, 0, 8, 0]] + + for i in range(3): + for j in range(4): + self.assertEqual(expected[i][j], sm1t[i, j]) + self.assertTrue(array_equal(sm1t.toArray(), expected)) + + def test_dense_matrix_is_transposed(self): + mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) + mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) + self.assertEqual(mat1, mat) + + expected = [[0, 4], [1, 6], [3, 9]] + for i in range(3): + for j in range(2): + self.assertEqual(mat1[i, j], expected[i][j]) + self.assertTrue(array_equal(mat1.toArray(), expected)) + + sm = mat1.toSparse() + self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2])) + self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5])) + self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) + + def test_parse_vector(self): + a = DenseVector([]) + self.assertEqual(str(a), '[]') + self.assertEqual(Vectors.parse(str(a)), a) + a = DenseVector([3, 4, 6, 7]) + self.assertEqual(str(a), '[3.0,4.0,6.0,7.0]') + self.assertEqual(Vectors.parse(str(a)), a) + a = SparseVector(4, [], []) + self.assertEqual(str(a), '(4,[],[])') + self.assertEqual(SparseVector.parse(str(a)), a) + a = SparseVector(4, [0, 2], [3, 4]) + self.assertEqual(str(a), '(4,[0,2],[3.0,4.0])') + self.assertEqual(Vectors.parse(str(a)), a) + a = SparseVector(10, [0, 1], [4, 5]) + self.assertEqual(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a) + + def test_norms(self): + a = DenseVector([0, 2, 3, -1]) + self.assertAlmostEqual(a.norm(2), 3.742, 3) + self.assertTrue(a.norm(1), 6) + self.assertTrue(a.norm(inf), 3) + a = SparseVector(4, [0, 2], [3, -4]) + self.assertAlmostEqual(a.norm(2), 5) + self.assertTrue(a.norm(1), 7) + self.assertTrue(a.norm(inf), 4) + + tmp = SparseVector(4, [0, 2], [3, 0]) + self.assertEqual(tmp.numNonzeros(), 1) + + def test_ml_mllib_vector_conversion(self): + # to ml + # dense + mllibDV = Vectors.dense([1, 2, 3]) + mlDV1 = newlinalg.Vectors.dense([1, 2, 3]) + mlDV2 = mllibDV.asML() + self.assertEqual(mlDV2, mlDV1) + # sparse + mllibSV = Vectors.sparse(4, {1: 1.0, 3: 5.5}) + mlSV1 = newlinalg.Vectors.sparse(4, {1: 1.0, 3: 5.5}) + mlSV2 = mllibSV.asML() + self.assertEqual(mlSV2, mlSV1) + # from ml + # dense + mllibDV1 = Vectors.dense([1, 2, 3]) + mlDV = newlinalg.Vectors.dense([1, 2, 3]) + mllibDV2 = Vectors.fromML(mlDV) + self.assertEqual(mllibDV1, mllibDV2) + # sparse + mllibSV1 = Vectors.sparse(4, {1: 1.0, 3: 5.5}) + mlSV = newlinalg.Vectors.sparse(4, {1: 1.0, 3: 5.5}) + mllibSV2 = Vectors.fromML(mlSV) + self.assertEqual(mllibSV1, mllibSV2) + + def test_ml_mllib_matrix_conversion(self): + # to ml + # dense + mllibDM = Matrices.dense(2, 2, [0, 1, 2, 3]) + mlDM1 = newlinalg.Matrices.dense(2, 2, [0, 1, 2, 3]) + mlDM2 = mllibDM.asML() + self.assertEqual(mlDM2, mlDM1) + # transposed + mllibDMt = DenseMatrix(2, 2, [0, 1, 2, 3], True) + mlDMt1 = newlinalg.DenseMatrix(2, 2, [0, 1, 2, 3], True) + mlDMt2 = mllibDMt.asML() + self.assertEqual(mlDMt2, mlDMt1) + # sparse + mllibSM = Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) + mlSM1 = newlinalg.Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) + mlSM2 = mllibSM.asML() + self.assertEqual(mlSM2, mlSM1) + # transposed + mllibSMt = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) + mlSMt1 = newlinalg.SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) + mlSMt2 = mllibSMt.asML() + self.assertEqual(mlSMt2, mlSMt1) + # from ml + # dense + mllibDM1 = Matrices.dense(2, 2, [1, 2, 3, 4]) + mlDM = newlinalg.Matrices.dense(2, 2, [1, 2, 3, 4]) + mllibDM2 = Matrices.fromML(mlDM) + self.assertEqual(mllibDM1, mllibDM2) + # transposed + mllibDMt1 = DenseMatrix(2, 2, [1, 2, 3, 4], True) + mlDMt = newlinalg.DenseMatrix(2, 2, [1, 2, 3, 4], True) + mllibDMt2 = Matrices.fromML(mlDMt) + self.assertEqual(mllibDMt1, mllibDMt2) + # sparse + mllibSM1 = Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) + mlSM = newlinalg.Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) + mllibSM2 = Matrices.fromML(mlSM) + self.assertEqual(mllibSM1, mllibSM2) + # transposed + mllibSMt1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) + mlSMt = newlinalg.SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) + mllibSMt2 = Matrices.fromML(mlSMt) + self.assertEqual(mllibSMt1, mllibSMt2) + + +class VectorUDTTests(MLlibTestCase): + + dv0 = DenseVector([]) + dv1 = DenseVector([1.0, 2.0]) + sv0 = SparseVector(2, [], []) + sv1 = SparseVector(2, [1], [2.0]) + udt = VectorUDT() + + def test_json_schema(self): + self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for v in [self.dv0, self.dv1, self.sv0, self.sv1]: + self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) + + def test_infer_schema(self): + rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) + df = rdd.toDF() + schema = df.schema + field = [f for f in schema.fields if f.name == "features"][0] + self.assertEqual(field.dataType, self.udt) + vectors = df.rdd.map(lambda p: p.features).collect() + self.assertEqual(len(vectors), 2) + for v in vectors: + if isinstance(v, SparseVector): + self.assertEqual(v, self.sv1) + elif isinstance(v, DenseVector): + self.assertEqual(v, self.dv1) + else: + raise TypeError("expecting a vector but got %r of type %r" % (v, type(v))) + + +class MatrixUDTTests(MLlibTestCase): + + dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10]) + dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True) + sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0]) + sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True) + udt = MatrixUDT() + + def test_json_schema(self): + self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for m in [self.dm1, self.dm2, self.sm1, self.sm2]: + self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m))) + + def test_infer_schema(self): + rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)]) + df = rdd.toDF() + schema = df.schema + self.assertTrue(schema.fields[1].dataType, self.udt) + matrices = df.rdd.map(lambda x: x._2).collect() + self.assertEqual(len(matrices), 2) + for m in matrices: + if isinstance(m, DenseMatrix): + self.assertTrue(m, self.dm1) + elif isinstance(m, SparseMatrix): + self.assertTrue(m, self.sm1) + else: + raise ValueError("Expected a matrix but got type %r" % type(m)) + + +@unittest.skipIf(not _have_scipy, "SciPy not installed") +class SciPyTests(MLlibTestCase): + + """ + Test both vector operations and MLlib algorithms with SciPy sparse matrices, + if SciPy is available. + """ + + def test_serialize(self): + from scipy.sparse import lil_matrix + lil = lil_matrix((4, 1)) + lil[1, 0] = 1 + lil[3, 0] = 2 + sv = SparseVector(4, {1: 1, 3: 2}) + self.assertEqual(sv, _convert_to_vector(lil)) + self.assertEqual(sv, _convert_to_vector(lil.tocsc())) + self.assertEqual(sv, _convert_to_vector(lil.tocoo())) + self.assertEqual(sv, _convert_to_vector(lil.tocsr())) + self.assertEqual(sv, _convert_to_vector(lil.todok())) + + def serialize(l): + return ser.loads(ser.dumps(_convert_to_vector(l))) + self.assertEqual(sv, serialize(lil)) + self.assertEqual(sv, serialize(lil.tocsc())) + self.assertEqual(sv, serialize(lil.tocsr())) + self.assertEqual(sv, serialize(lil.todok())) + + def test_convert_to_vector(self): + from scipy.sparse import csc_matrix + # Create a CSC matrix with non-sorted indices + indptr = array([0, 2]) + indices = array([3, 1]) + data = array([2.0, 1.0]) + csc = csc_matrix((data, indices, indptr)) + self.assertFalse(csc.has_sorted_indices) + sv = SparseVector(4, {1: 1, 3: 2}) + self.assertEqual(sv, _convert_to_vector(csc)) + + def test_dot(self): + from scipy.sparse import lil_matrix + lil = lil_matrix((4, 1)) + lil[1, 0] = 1 + lil[3, 0] = 2 + dv = DenseVector(array([1., 2., 3., 4.])) + self.assertEqual(10.0, dv.dot(lil)) + + def test_squared_distance(self): + from scipy.sparse import lil_matrix + lil = lil_matrix((4, 1)) + lil[1, 0] = 3 + lil[3, 0] = 2 + dv = DenseVector(array([1., 2., 3., 4.])) + sv = SparseVector(4, {0: 1, 1: 2, 2: 3, 3: 4}) + self.assertEqual(15.0, dv.squared_distance(lil)) + self.assertEqual(15.0, sv.squared_distance(lil)) + + def scipy_matrix(self, size, values): + """Create a column SciPy matrix from a dictionary of values""" + from scipy.sparse import lil_matrix + lil = lil_matrix((size, 1)) + for key, value in values.items(): + lil[key, 0] = value + return lil + + def test_clustering(self): + from pyspark.mllib.clustering import KMeans + data = [ + self.scipy_matrix(3, {1: 1.0}), + self.scipy_matrix(3, {1: 1.1}), + self.scipy_matrix(3, {2: 1.0}), + self.scipy_matrix(3, {2: 1.1}) + ] + clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||") + self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) + self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) + + def test_classification(self): + from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes + from pyspark.mllib.tree import DecisionTree + data = [ + LabeledPoint(0.0, self.scipy_matrix(2, {0: 1.0})), + LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), + LabeledPoint(0.0, self.scipy_matrix(2, {0: 2.0})), + LabeledPoint(1.0, self.scipy_matrix(2, {1: 2.0})) + ] + rdd = self.sc.parallelize(data) + features = [p.features for p in data] + + lr_model = LogisticRegressionWithSGD.train(rdd) + self.assertTrue(lr_model.predict(features[0]) <= 0) + self.assertTrue(lr_model.predict(features[1]) > 0) + self.assertTrue(lr_model.predict(features[2]) <= 0) + self.assertTrue(lr_model.predict(features[3]) > 0) + + svm_model = SVMWithSGD.train(rdd) + self.assertTrue(svm_model.predict(features[0]) <= 0) + self.assertTrue(svm_model.predict(features[1]) > 0) + self.assertTrue(svm_model.predict(features[2]) <= 0) + self.assertTrue(svm_model.predict(features[3]) > 0) + + nb_model = NaiveBayes.train(rdd) + self.assertTrue(nb_model.predict(features[0]) <= 0) + self.assertTrue(nb_model.predict(features[1]) > 0) + self.assertTrue(nb_model.predict(features[2]) <= 0) + self.assertTrue(nb_model.predict(features[3]) > 0) + + categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories + dt_model = DecisionTree.trainClassifier(rdd, numClasses=2, + categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + + def test_regression(self): + from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ + RidgeRegressionWithSGD + from pyspark.mllib.tree import DecisionTree + data = [ + LabeledPoint(-1.0, self.scipy_matrix(2, {1: -1.0})), + LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), + LabeledPoint(-1.0, self.scipy_matrix(2, {1: -2.0})), + LabeledPoint(1.0, self.scipy_matrix(2, {1: 2.0})) + ] + rdd = self.sc.parallelize(data) + features = [p.features for p in data] + + lr_model = LinearRegressionWithSGD.train(rdd) + self.assertTrue(lr_model.predict(features[0]) <= 0) + self.assertTrue(lr_model.predict(features[1]) > 0) + self.assertTrue(lr_model.predict(features[2]) <= 0) + self.assertTrue(lr_model.predict(features[3]) > 0) + + lasso_model = LassoWithSGD.train(rdd) + self.assertTrue(lasso_model.predict(features[0]) <= 0) + self.assertTrue(lasso_model.predict(features[1]) > 0) + self.assertTrue(lasso_model.predict(features[2]) <= 0) + self.assertTrue(lasso_model.predict(features[3]) > 0) + + rr_model = RidgeRegressionWithSGD.train(rdd) + self.assertTrue(rr_model.predict(features[0]) <= 0) + self.assertTrue(rr_model.predict(features[1]) > 0) + self.assertTrue(rr_model.predict(features[2]) <= 0) + self.assertTrue(rr_model.predict(features[3]) > 0) + + categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories + dt_model = DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + + +if __name__ == "__main__": + from pyspark.mllib.tests.test_linalg import * + if not _have_scipy: + print("NOTE: Skipping SciPy tests as it does not seem to be installed") + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) + if not _have_scipy: + print("NOTE: SciPy tests were skipped as it does not seem to be installed") diff --git a/python/pyspark/mllib/tests/test_stat.py b/python/pyspark/mllib/tests/test_stat.py new file mode 100644 index 0000000000000..5e74087d8fa7b --- /dev/null +++ b/python/pyspark/mllib/tests/test_stat.py @@ -0,0 +1,197 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import array as pyarray + +from numpy import array + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector, \ + DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT +from pyspark.mllib.random import RandomRDDs +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.stat import Statistics +from pyspark.sql.utils import IllegalArgumentException +from pyspark.testing.mllibutils import MLlibTestCase + + +class StatTests(MLlibTestCase): + # SPARK-4023 + def test_col_with_different_rdds(self): + # numpy + data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) + summary = Statistics.colStats(data) + self.assertEqual(1000, summary.count()) + # array + data = self.sc.parallelize([range(10)] * 10) + summary = Statistics.colStats(data) + self.assertEqual(10, summary.count()) + # array + data = self.sc.parallelize([pyarray.array("d", range(10))] * 10) + summary = Statistics.colStats(data) + self.assertEqual(10, summary.count()) + + def test_col_norms(self): + data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) + summary = Statistics.colStats(data) + self.assertEqual(10, len(summary.normL1())) + self.assertEqual(10, len(summary.normL2())) + + data2 = self.sc.parallelize(range(10)).map(lambda x: Vectors.dense(x)) + summary2 = Statistics.colStats(data2) + self.assertEqual(array([45.0]), summary2.normL1()) + import math + expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, range(10)))) + self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14) + + +class ChiSqTestTests(MLlibTestCase): + def test_goodness_of_fit(self): + from numpy import inf + + observed = Vectors.dense([4, 6, 5]) + pearson = Statistics.chiSqTest(observed) + + # Validated against the R command `chisq.test(c(4, 6, 5), p=c(1/3, 1/3, 1/3))` + self.assertEqual(pearson.statistic, 0.4) + self.assertEqual(pearson.degreesOfFreedom, 2) + self.assertAlmostEqual(pearson.pValue, 0.8187, 4) + + # Different expected and observed sum + observed1 = Vectors.dense([21, 38, 43, 80]) + expected1 = Vectors.dense([3, 5, 7, 20]) + pearson1 = Statistics.chiSqTest(observed1, expected1) + + # Results validated against the R command + # `chisq.test(c(21, 38, 43, 80), p=c(3/35, 1/7, 1/5, 4/7))` + self.assertAlmostEqual(pearson1.statistic, 14.1429, 4) + self.assertEqual(pearson1.degreesOfFreedom, 3) + self.assertAlmostEqual(pearson1.pValue, 0.002717, 4) + + # Vectors with different sizes + observed3 = Vectors.dense([1.0, 2.0, 3.0]) + expected3 = Vectors.dense([1.0, 2.0, 3.0, 4.0]) + self.assertRaises(ValueError, Statistics.chiSqTest, observed3, expected3) + + # Negative counts in observed + neg_obs = Vectors.dense([1.0, 2.0, 3.0, -4.0]) + self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_obs, expected1) + + # Count = 0.0 in expected but not observed + zero_expected = Vectors.dense([1.0, 0.0, 3.0]) + pearson_inf = Statistics.chiSqTest(observed, zero_expected) + self.assertEqual(pearson_inf.statistic, inf) + self.assertEqual(pearson_inf.degreesOfFreedom, 2) + self.assertEqual(pearson_inf.pValue, 0.0) + + # 0.0 in expected and observed simultaneously + zero_observed = Vectors.dense([2.0, 0.0, 1.0]) + self.assertRaises( + IllegalArgumentException, Statistics.chiSqTest, zero_observed, zero_expected) + + def test_matrix_independence(self): + data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0] + chi = Statistics.chiSqTest(Matrices.dense(3, 4, data)) + + # Results validated against R command + # `chisq.test(rbind(c(40, 56, 31, 30),c(24, 32, 10, 15), c(29, 42, 0, 12)))` + self.assertAlmostEqual(chi.statistic, 21.9958, 4) + self.assertEqual(chi.degreesOfFreedom, 6) + self.assertAlmostEqual(chi.pValue, 0.001213, 4) + + # Negative counts + neg_counts = Matrices.dense(2, 2, [4.0, 5.0, 3.0, -3.0]) + self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_counts) + + # Row sum = 0.0 + row_zero = Matrices.dense(2, 2, [0.0, 1.0, 0.0, 2.0]) + self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, row_zero) + + # Column sum = 0.0 + col_zero = Matrices.dense(2, 2, [0.0, 0.0, 2.0, 2.0]) + self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, col_zero) + + def test_chi_sq_pearson(self): + data = [ + LabeledPoint(0.0, Vectors.dense([0.5, 10.0])), + LabeledPoint(0.0, Vectors.dense([1.5, 20.0])), + LabeledPoint(1.0, Vectors.dense([1.5, 30.0])), + LabeledPoint(0.0, Vectors.dense([3.5, 30.0])), + LabeledPoint(0.0, Vectors.dense([3.5, 40.0])), + LabeledPoint(1.0, Vectors.dense([3.5, 40.0])) + ] + + for numParts in [2, 4, 6, 8]: + chi = Statistics.chiSqTest(self.sc.parallelize(data, numParts)) + feature1 = chi[0] + self.assertEqual(feature1.statistic, 0.75) + self.assertEqual(feature1.degreesOfFreedom, 2) + self.assertAlmostEqual(feature1.pValue, 0.6873, 4) + + feature2 = chi[1] + self.assertEqual(feature2.statistic, 1.5) + self.assertEqual(feature2.degreesOfFreedom, 3) + self.assertAlmostEqual(feature2.pValue, 0.6823, 4) + + def test_right_number_of_results(self): + num_cols = 1001 + sparse_data = [ + LabeledPoint(0.0, Vectors.sparse(num_cols, [(100, 2.0)])), + LabeledPoint(0.1, Vectors.sparse(num_cols, [(200, 1.0)])) + ] + chi = Statistics.chiSqTest(self.sc.parallelize(sparse_data)) + self.assertEqual(len(chi), num_cols) + self.assertIsNotNone(chi[1000]) + + +class KolmogorovSmirnovTest(MLlibTestCase): + + def test_R_implementation_equivalence(self): + data = self.sc.parallelize([ + 1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501, + -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555, + -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063, + -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691, + 0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942 + ]) + model = Statistics.kolmogorovSmirnovTest(data, "norm") + self.assertAlmostEqual(model.statistic, 0.189, 3) + self.assertAlmostEqual(model.pValue, 0.422, 3) + + model = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1) + self.assertAlmostEqual(model.statistic, 0.189, 3) + self.assertAlmostEqual(model.pValue, 0.422, 3) + + +if __name__ == "__main__": + from pyspark.mllib.tests.test_stat import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/mllib/tests/test_streaming_algorithms.py b/python/pyspark/mllib/tests/test_streaming_algorithms.py new file mode 100644 index 0000000000000..ba95855fd4f00 --- /dev/null +++ b/python/pyspark/mllib/tests/test_streaming_algorithms.py @@ -0,0 +1,523 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +from time import time, sleep + +from numpy import array, random, exp, dot, all, mean, abs +from numpy import sum as array_sum + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark import SparkContext +from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel +from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD +from pyspark.mllib.util import LinearDataGenerator +from pyspark.streaming import StreamingContext + + +class MLLibStreamingTestCase(unittest.TestCase): + def setUp(self): + self.sc = SparkContext('local[4]', "MLlib tests") + self.ssc = StreamingContext(self.sc, 1.0) + + def tearDown(self): + self.ssc.stop(False) + self.sc.stop() + + @staticmethod + def _eventually(condition, timeout=30.0, catch_assertions=False): + """ + Wait a given amount of time for a condition to pass, else fail with an error. + This is a helper utility for streaming ML tests. + :param condition: Function that checks for termination conditions. + condition() can return: + - True: Conditions met. Return without error. + - other value: Conditions not met yet. Continue. Upon timeout, + include last such value in error message. + Note that this method may be called at any time during + streaming execution (e.g., even before any results + have been created). + :param timeout: Number of seconds to wait. Default 30 seconds. + :param catch_assertions: If False (default), do not catch AssertionErrors. + If True, catch AssertionErrors; continue, but save + error to throw upon timeout. + """ + start_time = time() + lastValue = None + while time() - start_time < timeout: + if catch_assertions: + try: + lastValue = condition() + except AssertionError as e: + lastValue = e + else: + lastValue = condition() + if lastValue is True: + return + sleep(0.01) + if isinstance(lastValue, AssertionError): + raise lastValue + else: + raise AssertionError( + "Test failed due to timeout after %g sec, with last condition returning: %s" + % (timeout, lastValue)) + + +class StreamingKMeansTest(MLLibStreamingTestCase): + def test_model_params(self): + """Test that the model params are set correctly""" + stkm = StreamingKMeans() + stkm.setK(5).setDecayFactor(0.0) + self.assertEqual(stkm._k, 5) + self.assertEqual(stkm._decayFactor, 0.0) + + # Model not set yet. + self.assertIsNone(stkm.latestModel()) + self.assertRaises(ValueError, stkm.trainOn, [0.0, 1.0]) + + stkm.setInitialCenters( + centers=[[0.0, 0.0], [1.0, 1.0]], weights=[1.0, 1.0]) + self.assertEqual( + stkm.latestModel().centers, [[0.0, 0.0], [1.0, 1.0]]) + self.assertEqual(stkm.latestModel().clusterWeights, [1.0, 1.0]) + + def test_accuracy_for_single_center(self): + """Test that parameters obtained are correct for a single center.""" + centers, batches = self.streamingKMeansDataGenerator( + batches=5, numPoints=5, k=1, d=5, r=0.1, seed=0) + stkm = StreamingKMeans(1) + stkm.setInitialCenters([[0., 0., 0., 0., 0.]], [0.]) + input_stream = self.ssc.queueStream( + [self.sc.parallelize(batch, 1) for batch in batches]) + stkm.trainOn(input_stream) + + self.ssc.start() + + def condition(): + self.assertEqual(stkm.latestModel().clusterWeights, [25.0]) + return True + self._eventually(condition, catch_assertions=True) + + realCenters = array_sum(array(centers), axis=0) + for i in range(5): + modelCenters = stkm.latestModel().centers[0][i] + self.assertAlmostEqual(centers[0][i], modelCenters, 1) + self.assertAlmostEqual(realCenters[i], modelCenters, 1) + + def streamingKMeansDataGenerator(self, batches, numPoints, + k, d, r, seed, centers=None): + rng = random.RandomState(seed) + + # Generate centers. + centers = [rng.randn(d) for i in range(k)] + + return centers, [[Vectors.dense(centers[j % k] + r * rng.randn(d)) + for j in range(numPoints)] + for i in range(batches)] + + def test_trainOn_model(self): + """Test the model on toy data with four clusters.""" + stkm = StreamingKMeans() + initCenters = [[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]] + stkm.setInitialCenters( + centers=initCenters, weights=[1.0, 1.0, 1.0, 1.0]) + + # Create a toy dataset by setting a tiny offset for each point. + offsets = [[0, 0.1], [0, -0.1], [0.1, 0], [-0.1, 0]] + batches = [] + for offset in offsets: + batches.append([[offset[0] + center[0], offset[1] + center[1]] + for center in initCenters]) + + batches = [self.sc.parallelize(batch, 1) for batch in batches] + input_stream = self.ssc.queueStream(batches) + stkm.trainOn(input_stream) + self.ssc.start() + + # Give enough time to train the model. + def condition(): + finalModel = stkm.latestModel() + self.assertTrue(all(finalModel.centers == array(initCenters))) + self.assertEqual(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) + return True + self._eventually(condition, catch_assertions=True) + + def test_predictOn_model(self): + """Test that the model predicts correctly on toy data.""" + stkm = StreamingKMeans() + stkm._model = StreamingKMeansModel( + clusterCenters=[[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]], + clusterWeights=[1.0, 1.0, 1.0, 1.0]) + + predict_data = [[[1.5, 1.5]], [[-1.5, 1.5]], [[-1.5, -1.5]], [[1.5, -1.5]]] + predict_data = [self.sc.parallelize(batch, 1) for batch in predict_data] + predict_stream = self.ssc.queueStream(predict_data) + predict_val = stkm.predictOn(predict_stream) + + result = [] + + def update(rdd): + rdd_collect = rdd.collect() + if rdd_collect: + result.append(rdd_collect) + + predict_val.foreachRDD(update) + self.ssc.start() + + def condition(): + self.assertEqual(result, [[0], [1], [2], [3]]) + return True + + self._eventually(condition, catch_assertions=True) + + @unittest.skip("SPARK-10086: Flaky StreamingKMeans test in PySpark") + def test_trainOn_predictOn(self): + """Test that prediction happens on the updated model.""" + stkm = StreamingKMeans(decayFactor=0.0, k=2) + stkm.setInitialCenters([[0.0], [1.0]], [1.0, 1.0]) + + # Since decay factor is set to zero, once the first batch + # is passed the clusterCenters are updated to [-0.5, 0.7] + # which causes 0.2 & 0.3 to be classified as 1, even though the + # classification based in the initial model would have been 0 + # proving that the model is updated. + batches = [[[-0.5], [0.6], [0.8]], [[0.2], [-0.1], [0.3]]] + batches = [self.sc.parallelize(batch) for batch in batches] + input_stream = self.ssc.queueStream(batches) + predict_results = [] + + def collect(rdd): + rdd_collect = rdd.collect() + if rdd_collect: + predict_results.append(rdd_collect) + + stkm.trainOn(input_stream) + predict_stream = stkm.predictOn(input_stream) + predict_stream.foreachRDD(collect) + + self.ssc.start() + + def condition(): + self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]]) + return True + + self._eventually(condition, catch_assertions=True) + + +class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase): + + @staticmethod + def generateLogisticInput(offset, scale, nPoints, seed): + """ + Generate 1 / (1 + exp(-x * scale + offset)) + + where, + x is randomnly distributed and the threshold + and labels for each sample in x is obtained from a random uniform + distribution. + """ + rng = random.RandomState(seed) + x = rng.randn(nPoints) + sigmoid = 1. / (1 + exp(-(dot(x, scale) + offset))) + y_p = rng.rand(nPoints) + cut_off = y_p <= sigmoid + y_p[cut_off] = 1.0 + y_p[~cut_off] = 0.0 + return [ + LabeledPoint(y_p[i], Vectors.dense([x[i]])) + for i in range(nPoints)] + + def test_parameter_accuracy(self): + """ + Test that the final value of weights is close to the desired value. + """ + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + input_stream = self.ssc.queueStream(input_batches) + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + slr.trainOn(input_stream) + + self.ssc.start() + + def condition(): + rel = (1.5 - slr.latestModel().weights.array[0]) / 1.5 + self.assertAlmostEqual(rel, 0.1, 1) + return True + + self._eventually(condition, catch_assertions=True) + + def test_convergence(self): + """ + Test that weights converge to the required value on toy data. + """ + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + input_stream = self.ssc.queueStream(input_batches) + models = [] + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + slr.trainOn(input_stream) + input_stream.foreachRDD( + lambda x: models.append(slr.latestModel().weights[0])) + + self.ssc.start() + + def condition(): + self.assertEqual(len(models), len(input_batches)) + return True + + # We want all batches to finish for this test. + self._eventually(condition, 60.0, catch_assertions=True) + + t_models = array(models) + diff = t_models[1:] - t_models[:-1] + # Test that weights improve with a small tolerance + self.assertTrue(all(diff >= -0.1)) + self.assertTrue(array_sum(diff > 0) > 1) + + @staticmethod + def calculate_accuracy_error(true, predicted): + return sum(abs(array(true) - array(predicted))) / len(true) + + def test_predictions(self): + """Test predicted values on a toy model.""" + input_batches = [] + for i in range(20): + batch = self.sc.parallelize( + self.generateLogisticInput(0, 1.5, 100, 42 + i)) + input_batches.append(batch.map(lambda x: (x.label, x.features))) + input_stream = self.ssc.queueStream(input_batches) + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([1.5]) + predict_stream = slr.predictOnValues(input_stream) + true_predicted = [] + predict_stream.foreachRDD(lambda x: true_predicted.append(x.collect())) + self.ssc.start() + + def condition(): + self.assertEqual(len(true_predicted), len(input_batches)) + return True + + self._eventually(condition, catch_assertions=True) + + # Test that the accuracy error is no more than 0.4 on each batch. + for batch in true_predicted: + true, predicted = zip(*batch) + self.assertTrue( + self.calculate_accuracy_error(true, predicted) < 0.4) + + def test_training_and_prediction(self): + """Test that the model improves on toy data with no. of batches""" + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + predict_batches = [ + b.map(lambda lp: (lp.label, lp.features)) for b in input_batches] + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.01, numIterations=25) + slr.setInitialWeights([-0.1]) + errors = [] + + def collect_errors(rdd): + true, predicted = zip(*rdd.collect()) + errors.append(self.calculate_accuracy_error(true, predicted)) + + true_predicted = [] + input_stream = self.ssc.queueStream(input_batches) + predict_stream = self.ssc.queueStream(predict_batches) + slr.trainOn(input_stream) + ps = slr.predictOnValues(predict_stream) + ps.foreachRDD(lambda x: collect_errors(x)) + + self.ssc.start() + + def condition(): + # Test that the improvement in error is > 0.3 + if len(errors) == len(predict_batches): + self.assertGreater(errors[1] - errors[-1], 0.3) + if len(errors) >= 3 and errors[1] - errors[-1] > 0.3: + return True + return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) + + self._eventually(condition) + + +class StreamingLinearRegressionWithTests(MLLibStreamingTestCase): + + def assertArrayAlmostEqual(self, array1, array2, dec): + for i, j in array1, array2: + self.assertAlmostEqual(i, j, dec) + + def test_parameter_accuracy(self): + """Test that coefs are predicted accurately by fitting on toy data.""" + + # Test that fitting (10*X1 + 10*X2), (X1, X2) gives coefficients + # (10, 10) + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0, 0.0]) + xMean = [0.0, 0.0] + xVariance = [1.0 / 3.0, 1.0 / 3.0] + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0, 10.0], xMean, xVariance, 100, 42 + i, 0.1) + batches.append(self.sc.parallelize(batch)) + + input_stream = self.ssc.queueStream(batches) + slr.trainOn(input_stream) + self.ssc.start() + + def condition(): + self.assertArrayAlmostEqual( + slr.latestModel().weights.array, [10., 10.], 1) + self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1) + return True + + self._eventually(condition, catch_assertions=True) + + def test_parameter_convergence(self): + """Test that the model parameters improve with streaming data.""" + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) + batches.append(self.sc.parallelize(batch)) + + model_weights = [] + input_stream = self.ssc.queueStream(batches) + input_stream.foreachRDD( + lambda x: model_weights.append(slr.latestModel().weights[0])) + slr.trainOn(input_stream) + self.ssc.start() + + def condition(): + self.assertEqual(len(model_weights), len(batches)) + return True + + # We want all batches to finish for this test. + self._eventually(condition, catch_assertions=True) + + w = array(model_weights) + diff = w[1:] - w[:-1] + self.assertTrue(all(diff >= -0.1)) + + def test_prediction(self): + """Test prediction on a model with weights already set.""" + # Create a model with initial Weights equal to coefs + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([10.0, 10.0]) + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0, 10.0], [0.0, 0.0], [1.0 / 3.0, 1.0 / 3.0], + 100, 42 + i, 0.1) + batches.append( + self.sc.parallelize(batch).map(lambda lp: (lp.label, lp.features))) + + input_stream = self.ssc.queueStream(batches) + output_stream = slr.predictOnValues(input_stream) + samples = [] + output_stream.foreachRDD(lambda x: samples.append(x.collect())) + + self.ssc.start() + + def condition(): + self.assertEqual(len(samples), len(batches)) + return True + + # We want all batches to finish for this test. + self._eventually(condition, catch_assertions=True) + + # Test that mean absolute error on each batch is less than 0.1 + for batch in samples: + true, predicted = zip(*batch) + self.assertTrue(mean(abs(array(true) - array(predicted))) < 0.1) + + def test_train_prediction(self): + """Test that error on test data improves as model is trained.""" + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) + batches.append(self.sc.parallelize(batch)) + + predict_batches = [ + b.map(lambda lp: (lp.label, lp.features)) for b in batches] + errors = [] + + def func(rdd): + true, predicted = zip(*rdd.collect()) + errors.append(mean(abs(true) - abs(predicted))) + + input_stream = self.ssc.queueStream(batches) + output_stream = self.ssc.queueStream(predict_batches) + slr.trainOn(input_stream) + output_stream = slr.predictOnValues(output_stream) + output_stream.foreachRDD(func) + self.ssc.start() + + def condition(): + if len(errors) == len(predict_batches): + self.assertGreater(errors[1] - errors[-1], 2) + if len(errors) >= 3 and errors[1] - errors[-1] > 2: + return True + return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) + + self._eventually(condition) + + +if __name__ == "__main__": + from pyspark.mllib.tests.test_streaming_algorithms import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/mllib/tests/test_util.py b/python/pyspark/mllib/tests/test_util.py new file mode 100644 index 0000000000000..c924eba80484c --- /dev/null +++ b/python/pyspark/mllib/tests/test_util.py @@ -0,0 +1,115 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import sys +import tempfile + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.mllib.common import _to_java_object_rdd +from pyspark.mllib.util import LinearDataGenerator +from pyspark.mllib.util import MLUtils +from pyspark.mllib.linalg import SparseVector, DenseVector, SparseMatrix, Vectors +from pyspark.mllib.random import RandomRDDs +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.stat import Statistics +from pyspark.testing.mllibutils import MLlibTestCase + + +class MLUtilsTests(MLlibTestCase): + def test_append_bias(self): + data = [2.0, 2.0, 2.0] + ret = MLUtils.appendBias(data) + self.assertEqual(ret[3], 1.0) + self.assertEqual(type(ret), DenseVector) + + def test_append_bias_with_vector(self): + data = Vectors.dense([2.0, 2.0, 2.0]) + ret = MLUtils.appendBias(data) + self.assertEqual(ret[3], 1.0) + self.assertEqual(type(ret), DenseVector) + + def test_append_bias_with_sp_vector(self): + data = Vectors.sparse(3, {0: 2.0, 2: 2.0}) + expected = Vectors.sparse(4, {0: 2.0, 2: 2.0, 3: 1.0}) + # Returned value must be SparseVector + ret = MLUtils.appendBias(data) + self.assertEqual(ret, expected) + self.assertEqual(type(ret), SparseVector) + + def test_load_vectors(self): + import shutil + data = [ + [1.0, 2.0, 3.0], + [1.0, 2.0, 3.0] + ] + temp_dir = tempfile.mkdtemp() + load_vectors_path = os.path.join(temp_dir, "test_load_vectors") + try: + self.sc.parallelize(data).saveAsTextFile(load_vectors_path) + ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path) + ret = ret_rdd.collect() + self.assertEqual(len(ret), 2) + self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0])) + self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0])) + except: + self.fail() + finally: + shutil.rmtree(load_vectors_path) + + +class LinearDataGeneratorTests(MLlibTestCase): + def test_dim(self): + linear_data = LinearDataGenerator.generateLinearInput( + intercept=0.0, weights=[0.0, 0.0, 0.0], + xMean=[0.0, 0.0, 0.0], xVariance=[0.33, 0.33, 0.33], + nPoints=4, seed=0, eps=0.1) + self.assertEqual(len(linear_data), 4) + for point in linear_data: + self.assertEqual(len(point.features), 3) + + linear_data = LinearDataGenerator.generateLinearRDD( + sc=self.sc, nexamples=6, nfeatures=2, eps=0.1, + nParts=2, intercept=0.0).collect() + self.assertEqual(len(linear_data), 6) + for point in linear_data: + self.assertEqual(len(point.features), 2) + + +class SerDeTest(MLlibTestCase): + def test_to_java_object_rdd(self): # SPARK-6660 + data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0) + self.assertEqual(_to_java_object_rdd(data).count(), 10) + + +if __name__ == "__main__": + from pyspark.mllib.tests.test_util import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/testing/mllibutils.py b/python/pyspark/testing/mllibutils.py new file mode 100644 index 0000000000000..9248182658f84 --- /dev/null +++ b/python/pyspark/testing/mllibutils.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark import SparkContext +from pyspark.serializers import PickleSerializer +from pyspark.sql import SparkSession + + +def make_serializer(): + return PickleSerializer() + + +class MLlibTestCase(unittest.TestCase): + def setUp(self): + self.sc = SparkContext('local[4]', "MLlib tests") + self.spark = SparkSession(self.sc) + + def tearDown(self): + self.spark.stop() From 99cbc51b3250c07a3e8cc95c9b74e9d1725bac77 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 16 Nov 2018 09:51:41 -0800 Subject: [PATCH 2075/2461] [SPARK-26069][TESTS] Fix flaky test: RpcIntegrationSuite.sendRpcWithStreamFailures ## What changes were proposed in this pull request? The test failure is because `assertErrorAndClosed` misses one possible error message: `java.nio.channels.ClosedChannelException`. This happens when the second `uploadStream` is called after the channel has been closed. This can be reproduced by adding `Thread.sleep(1000)` below this line: https://github.com/apache/spark/blob/03306a6df39c9fd6cb581401c13c4dfc6bbd632e/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java#L217 This PR fixes the above issue and also improves the test failure messages of `assertErrorAndClosed`. ## How was this patch tested? Jenkins Closes #23041 from zsxwing/SPARK-26069. Authored-by: Shixiong Zhu Signed-off-by: Shixiong Zhu --- .../spark/network/RpcIntegrationSuite.java | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 1f4d75c7e2ec5..45f4a1808562d 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -371,7 +371,10 @@ private void assertErrorsContain(Set errors, Set contains) { private void assertErrorAndClosed(RpcResult result, String expectedError) { assertTrue("unexpected success: " + result.successMessages, result.successMessages.isEmpty()); - // we expect 1 additional error, which contains *either* "closed" or "Connection reset" + // we expect 1 additional error, which should contain one of the follow messages: + // - "closed" + // - "Connection reset" + // - "java.nio.channels.ClosedChannelException" Set errors = result.errorMessages; assertEquals("Expected 2 errors, got " + errors.size() + "errors: " + errors, 2, errors.size()); @@ -379,15 +382,18 @@ private void assertErrorAndClosed(RpcResult result, String expectedError) { Set containsAndClosed = Sets.newHashSet(expectedError); containsAndClosed.add("closed"); containsAndClosed.add("Connection reset"); + containsAndClosed.add("java.nio.channels.ClosedChannelException"); Pair, Set> r = checkErrorsContain(errors, containsAndClosed); - Set errorsNotFound = r.getRight(); - assertEquals(1, errorsNotFound.size()); - String err = errorsNotFound.iterator().next(); - assertTrue(err.equals("closed") || err.equals("Connection reset")); + assertTrue("Got a non-empty set " + r.getLeft(), r.getLeft().isEmpty()); - assertTrue(r.getLeft().isEmpty()); + Set errorsNotFound = r.getRight(); + assertEquals( + "The size of " + errorsNotFound.toString() + " was not 2", 2, errorsNotFound.size()); + for (String err: errorsNotFound) { + assertTrue("Found a wrong error " + err, containsAndClosed.contains(err)); + } } private Pair, Set> checkErrorsContain( From 058c4602b000b24deb764a810ef8b43c41fe63ae Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 16 Nov 2018 15:43:27 -0800 Subject: [PATCH 2076/2461] [SPARK-26092][SS] Use CheckpointFileManager to write the streaming metadata file ## What changes were proposed in this pull request? Use CheckpointFileManager to write the streaming `metadata` file so that the `metadata` file will never be a partial file. ## How was this patch tested? Jenkins Closes #23060 from zsxwing/SPARK-26092. Authored-by: Shixiong Zhu Signed-off-by: Shixiong Zhu --- .../streaming/CheckpointFileManager.scala | 2 +- .../execution/streaming/StreamExecution.scala | 1 + .../execution/streaming/StreamMetadata.scala | 23 +++++++++++++------ 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala index 606ba250ad9d2..b3e4240c315bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala @@ -56,7 +56,7 @@ trait CheckpointFileManager { * @param overwriteIfPossible If true, then the implementations must do a best-effort attempt to * overwrite the file if it already exists. It should not throw * any exception if the file exists. However, if false, then the - * implementation must not overwrite if the file alraedy exists and + * implementation must not overwrite if the file already exists and * must throw `FileAlreadyExistsException` in that case. */ def createAtomic(path: Path, overwriteIfPossible: Boolean): CancellableFSDataOutputStream diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 631a6eb649ffb..89b4f40c9c0b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -88,6 +88,7 @@ abstract class StreamExecution( val resolvedCheckpointRoot = { val checkpointPath = new Path(checkpointRoot) val fs = checkpointPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) + fs.mkdirs(checkpointPath) checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toUri.toString } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetadata.scala index 0bc54eac4ee8e..516afbea5d9de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetadata.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetadata.scala @@ -19,16 +19,18 @@ package org.apache.spark.sql.execution.streaming import java.io.{InputStreamReader, OutputStreamWriter} import java.nio.charset.StandardCharsets +import java.util.ConcurrentModificationException import scala.util.control.NonFatal import org.apache.commons.io.IOUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, FSDataOutputStream, Path} +import org.apache.hadoop.fs.{FileAlreadyExistsException, FSDataInputStream, Path} import org.json4s.NoTypeHints import org.json4s.jackson.Serialization import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream import org.apache.spark.sql.streaming.StreamingQuery /** @@ -70,19 +72,26 @@ object StreamMetadata extends Logging { metadata: StreamMetadata, metadataFile: Path, hadoopConf: Configuration): Unit = { - var output: FSDataOutputStream = null + var output: CancellableFSDataOutputStream = null try { - val fs = metadataFile.getFileSystem(hadoopConf) - output = fs.create(metadataFile) + val fileManager = CheckpointFileManager.create(metadataFile.getParent, hadoopConf) + output = fileManager.createAtomic(metadataFile, overwriteIfPossible = false) val writer = new OutputStreamWriter(output) Serialization.write(metadata, writer) writer.close() } catch { - case NonFatal(e) => + case e: FileAlreadyExistsException => + if (output != null) { + output.cancel() + } + throw new ConcurrentModificationException( + s"Multiple streaming queries are concurrently using $metadataFile", e) + case e: Throwable => + if (output != null) { + output.cancel() + } logError(s"Error writing stream metadata $metadata to $metadataFile", e) throw e - } finally { - IOUtils.closeQuietly(output) } } } From d2792046a1b10a07b65fc30be573983f1237e450 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 16 Nov 2018 15:57:38 -0800 Subject: [PATCH 2077/2461] [SPARK-26095][BUILD] Disable parallelization in make-distibution.sh. It makes the build slower, but at least it doesn't hang. Seems that maven-shade-plugin has some issue with parallelization. Closes #23061 from vanzin/SPARK-26095. Authored-by: Marcelo Vanzin Signed-off-by: Marcelo Vanzin --- dev/make-distribution.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 84f4ae9a64ff8..a550af93feecd 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -165,7 +165,7 @@ export MAVEN_OPTS="${MAVEN_OPTS:--Xmx2g -XX:ReservedCodeCacheSize=512m}" # Store the command as an array because $MVN variable might have spaces in it. # Normal quoting tricks don't work. # See: http://mywiki.wooledge.org/BashFAQ/050 -BUILD_COMMAND=("$MVN" -T 1C clean package -DskipTests $@) +BUILD_COMMAND=("$MVN" clean package -DskipTests $@) # Actually build the jar echo -e "\nBuilding with..." From 23cd0e6e9e20a224a71859c158437e0a31982259 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Sat, 17 Nov 2018 15:07:20 +0800 Subject: [PATCH 2078/2461] [SPARK-26079][SQL] Ensure listener event delivery in StreamingQueryListenersConfSuite. Events are dispatched on a separate thread, so need to wait for them to be actually delivered before checking that the listener got them. Closes #23050 from vanzin/SPARK-26079. Authored-by: Marcelo Vanzin Signed-off-by: hyukjinkwon --- .../spark/sql/streaming/StreamingQueryListenersConfSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala index 1aaf8a9aa2d55..ddbc175e7ea48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala @@ -30,7 +30,6 @@ class StreamingQueryListenersConfSuite extends StreamTest with BeforeAndAfter { import testImplicits._ - override protected def sparkConf: SparkConf = super.sparkConf.set("spark.sql.streaming.streamingQueryListeners", "org.apache.spark.sql.streaming.TestListener") @@ -41,6 +40,8 @@ class StreamingQueryListenersConfSuite extends StreamTest with BeforeAndAfter { StopStream ) + spark.sparkContext.listenerBus.waitUntilEmpty(5000) + assert(TestListener.queryStartedEvent != null) assert(TestListener.queryTerminatedEvent != null) } From b538c442cb3982cc4c3aac812a7d4764209dfbb7 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 17 Nov 2018 18:18:41 +0800 Subject: [PATCH 2079/2461] [MINOR][SQL] Fix typo in CTAS plan database string ## What changes were proposed in this pull request? Since [Spark 1.6.0](https://github.com/apache/spark/commit/56d7da14ab8f89bf4f303b27f51fd22d23967ffb#diff-6f38a103058a6e233b7ad80718452387R96), there was a redundant '}' character in CTAS string plan's database argument string; `default}`. This PR aims to fix it. **BEFORE** ```scala scala> sc.version res1: String = 1.6.0 scala> sql("create table t as select 1").explain == Physical Plan == ExecutedCommand CreateTableAsSelect [Database:default}, TableName: t, InsertIntoHiveTable] +- Project [1 AS _c0#3] +- OneRowRelation$ ``` **AFTER** ```scala scala> sql("create table t as select 1").explain == Physical Plan == Execute CreateHiveTableAsSelectCommand CreateHiveTableAsSelectCommand [Database:default, TableName: t, InsertIntoHiveTable] +- *(1) Project [1 AS 1#4] +- Scan OneRowRelation[] ``` ## How was this patch tested? Manual. Closes #23064 from dongjoon-hyun/SPARK-FIX. Authored-by: Dongjoon Hyun Signed-off-by: hyukjinkwon --- .../sql/hive/execution/CreateHiveTableAsSelectCommand.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index aa573b54a2b62..630bea5161f19 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -96,7 +96,7 @@ case class CreateHiveTableAsSelectCommand( } override def argString: String = { - s"[Database:${tableDesc.database}}, " + + s"[Database:${tableDesc.database}, " + s"TableName: ${tableDesc.identifier.table}, " + s"InsertIntoHiveTable]" } From ed46ac9f4736d23c2f7294133d4def93dc99cce1 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 17 Nov 2018 03:28:43 -0800 Subject: [PATCH 2080/2461] [SPARK-26091][SQL] Upgrade to 2.3.4 for Hive Metastore Client 2.3 ## What changes were proposed in this pull request? [Hive 2.3.4 is released on Nov. 7th](https://hive.apache.org/downloads.html#7-november-2018-release-234-available). This PR aims to support that version. ## How was this patch tested? Pass the Jenkins with the updated version Closes #23059 from dongjoon-hyun/SPARK-26091. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- docs/sql-data-sources-hive-tables.md | 2 +- docs/sql-migration-guide-hive-compatibility.md | 2 +- .../src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala | 2 +- .../org/apache/spark/sql/hive/client/IsolatedClientLoader.scala | 2 +- .../main/scala/org/apache/spark/sql/hive/client/package.scala | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/sql-data-sources-hive-tables.md b/docs/sql-data-sources-hive-tables.md index 687e6f8e0a7cc..28e1a39626666 100644 --- a/docs/sql-data-sources-hive-tables.md +++ b/docs/sql-data-sources-hive-tables.md @@ -115,7 +115,7 @@ The following options can be used to configure the version of Hive that is used diff --git a/docs/sql-migration-guide-hive-compatibility.md b/docs/sql-migration-guide-hive-compatibility.md index 94849418030ef..dd7b06225714f 100644 --- a/docs/sql-migration-guide-hive-compatibility.md +++ b/docs/sql-migration-guide-hive-compatibility.md @@ -10,7 +10,7 @@ displayTitle: Compatibility with Apache Hive Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently, Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 2.3.3. Also see [Interacting with Different Versions of Hive Metastore](sql-data-sources-hive-tables.html#interacting-with-different-versions-of-hive-metastore)). +(from 0.12.0 to 2.3.4. Also see [Interacting with Different Versions of Hive Metastore](sql-data-sources-hive-tables.html#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 74f21532b22df..66067704195dd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -62,7 +62,7 @@ private[spark] object HiveUtils extends Logging { val HIVE_METASTORE_VERSION = buildConf("spark.sql.hive.metastore.version") .doc("Version of the Hive metastore. Available options are " + - s"0.12.0 through 2.3.3.") + s"0.12.0 through 2.3.4.") .stringConf .createWithDefault(builtinHiveVersion) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index c1d8fe53a9e8c..f56ca8cb08553 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -99,7 +99,7 @@ private[hive] object IsolatedClientLoader extends Logging { case "2.0" | "2.0.0" | "2.0.1" => hive.v2_0 case "2.1" | "2.1.0" | "2.1.1" => hive.v2_1 case "2.2" | "2.2.0" => hive.v2_2 - case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" | "2.3.3" => hive.v2_3 + case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" | "2.3.3" | "2.3.4" => hive.v2_3 case version => throw new UnsupportedOperationException(s"Unsupported Hive Metastore version ($version). " + s"Please set ${HiveUtils.HIVE_METASTORE_VERSION.key} with a valid version.") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 25e9886fa6576..e4cf7299d2af6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -75,7 +75,7 @@ package object client { exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) - case object v2_3 extends HiveVersion("2.3.3", + case object v2_3 extends HiveVersion("2.3.4", exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) From e557c53c59a98f601b15850bb89fd4e252135556 Mon Sep 17 00:00:00 2001 From: Shahid Date: Sat, 17 Nov 2018 09:43:33 -0600 Subject: [PATCH 2081/2461] [SPARK-26006][MLLIB] unpersist 'dataInternalRepr' in the PrefixSpan ## What changes were proposed in this pull request? Mllib's Prefixspan - run method - cached RDD stays in cache. After run is comlpeted , rdd remain in cache. We need to unpersist the cached RDD after run method. ## How was this patch tested? Existing tests Closes #23016 from shahidki31/SPARK-26006. Authored-by: Shahid Signed-off-by: Sean Owen --- .../main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 7aed2f3bd8a61..64d6a0bc47b97 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -174,6 +174,13 @@ class PrefixSpan private ( val freqSequences = results.map { case (seq: Array[Int], count: Long) => new FreqSequence(toPublicRepr(seq), count) } + // Cache the final RDD to the same storage level as input + if (data.getStorageLevel != StorageLevel.NONE) { + freqSequences.persist(data.getStorageLevel) + freqSequences.count() + } + dataInternalRepr.unpersist(false) + new PrefixSpanModel(freqSequences) } From e00cac989821aea238c7bf20b69068ef7cf2eef3 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 17 Nov 2018 09:46:45 -0600 Subject: [PATCH 2082/2461] [SPARK-25959][ML] GBTClassifier picks wrong impurity stats on loading ## What changes were proposed in this pull request? Our `GBTClassifier` supports only `variance` impurity. But unfortunately, its `impurity` param by default contains the value `gini`: it is not even modifiable by the user and it differs from the actual impurity used, which is `variance`. This issue does not limit to a wrong value returned for it if the user queries by `getImpurity`, but it also affect the load of a saved model, as its `impurityStats` are created as `gini` (since this is the value stored for the model impurity) which leads to wrong `featureImportances` in model loaded from saved ones. The PR changes the `impurity` param used to one which allows only the value `variance`. ## How was this patch tested? modified UT Closes #22986 from mgaido91/SPARK-25959. Authored-by: Marco Gaido Signed-off-by: Sean Owen --- .../ml/classification/GBTClassifier.scala | 4 +++- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 2 +- .../org/apache/spark/ml/tree/treeParams.scala | 19 ++++++++++--------- .../classification/GBTClassifierSuite.scala | 1 + project/MimaExcludes.scala | 11 +++++++++++ 6 files changed, 27 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 62cfa39746ff0..62c6bdbdeb285 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -427,7 +427,9 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { s" trees based on metadata but found ${trees.length} trees.") val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures) - metadata.getAndSetParams(model) + // We ignore the impurity while loading models because in previous models it was wrongly + // set to gini (see SPARK-25959). + metadata.getAndSetParams(model, Some(List("impurity"))) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 6fa656275c1fd..c9de85de42fa5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -145,7 +145,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S @Since("1.4.0") object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor] { /** Accessor for supported impurities: variance */ - final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities + final val supportedImpurities: Array[String] = HasVarianceImpurity.supportedImpurities @Since("2.0.0") override def load(path: String): DecisionTreeRegressor = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 82bf66ff66d8a..66d57ad6c4348 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -146,7 +146,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor]{ /** Accessor for supported impurity settings: variance */ @Since("1.4.0") - final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities + final val supportedImpurities: Array[String] = HasVarianceImpurity.supportedImpurities /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 00157fe63af41..f1e3836ebe476 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -258,11 +258,7 @@ private[ml] object TreeClassifierParams { private[ml] trait DecisionTreeClassifierParams extends DecisionTreeParams with TreeClassifierParams -/** - * Parameters for Decision Tree-based regression algorithms. - */ -private[ml] trait TreeRegressorParams extends Params { - +private[ml] trait HasVarianceImpurity extends Params { /** * Criterion used for information gain calculation (case-insensitive). * Supported: "variance". @@ -271,9 +267,9 @@ private[ml] trait TreeRegressorParams extends Params { */ final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). Supported options:" + - s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}", + s" ${HasVarianceImpurity.supportedImpurities.mkString(", ")}", (value: String) => - TreeRegressorParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) + HasVarianceImpurity.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) setDefault(impurity -> "variance") @@ -299,12 +295,17 @@ private[ml] trait TreeRegressorParams extends Params { } } -private[ml] object TreeRegressorParams { +private[ml] object HasVarianceImpurity { // These options should be lowercase. final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase(Locale.ROOT)) } +/** + * Parameters for Decision Tree-based regression algorithms. + */ +private[ml] trait TreeRegressorParams extends HasVarianceImpurity + private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams with TreeRegressorParams with HasVarianceCol { @@ -538,7 +539,7 @@ private[ml] object GBTClassifierParams { Array("logistic").map(_.toLowerCase(Locale.ROOT)) } -private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams { +private[ml] trait GBTClassifierParams extends GBTParams with HasVarianceImpurity { /** * Loss function which GBT tries to minimize. (case-insensitive) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 304977634189c..cedbaf1858ef4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -448,6 +448,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { model2: GBTClassificationModel): Unit = { TreeTests.checkEqual(model, model2) assert(model.numFeatures === model2.numFeatures) + assert(model.featureImportances == model2.featureImportances) } val gbt = new GBTClassifier() diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b030b6ca2922f..a8d2b5d1d9cb6 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,17 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( + // [SPARK-25959] GBTClassifier picks wrong impurity stats on loading + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setImpurity"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setImpurity"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setImpurity"), + // [SPARK-25908][CORE][SQL] Remove old deprecated items in Spark 3 ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.BarrierTaskContext.isRunningLocally"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskContext.isRunningLocally"), From 034ae305c33b1990b3c1a284044002874c343b4d Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Sun, 18 Nov 2018 16:02:15 +0800 Subject: [PATCH 2083/2461] [SPARK-26033][PYTHON][TESTS] Break large ml/tests.py file into smaller files MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR breaks down the large ml/tests.py file that contains all Python ML unit tests into several smaller test files to be easier to read and maintain. The tests are broken down as follows: ``` pyspark ├── __init__.py ... ├── ml │ ├── __init__.py ... │ ├── tests │ │ ├── __init__.py │ │ ├── test_algorithms.py │ │ ├── test_base.py │ │ ├── test_evaluation.py │ │ ├── test_feature.py │ │ ├── test_image.py │ │ ├── test_linalg.py │ │ ├── test_param.py │ │ ├── test_persistence.py │ │ ├── test_pipeline.py │ │ ├── test_stat.py │ │ ├── test_training_summary.py │ │ ├── test_tuning.py │ │ └── test_wrapper.py ... ├── testing ... │ ├── mlutils.py ... ``` ## How was this patch tested? Ran tests manually by module to ensure test count was the same, and ran `python/run-tests --modules=pyspark-ml` to verify all passing with Python 2.7 and Python 3.6. Closes #23063 from BryanCutler/python-test-breakup-ml-SPARK-26033. Authored-by: Bryan Cutler Signed-off-by: hyukjinkwon --- dev/sparktestsupport/modules.py | 16 +- python/pyspark/ml/tests.py | 2762 ----------------- python/pyspark/ml/tests/__init__.py | 16 + python/pyspark/ml/tests/test_algorithms.py | 349 +++ python/pyspark/ml/tests/test_base.py | 85 + python/pyspark/ml/tests/test_evaluation.py | 71 + python/pyspark/ml/tests/test_feature.py | 318 ++ python/pyspark/ml/tests/test_image.py | 118 + python/pyspark/ml/tests/test_linalg.py | 392 +++ python/pyspark/ml/tests/test_param.py | 372 +++ python/pyspark/ml/tests/test_persistence.py | 369 +++ python/pyspark/ml/tests/test_pipeline.py | 77 + python/pyspark/ml/tests/test_stat.py | 58 + .../pyspark/ml/tests/test_training_summary.py | 258 ++ python/pyspark/ml/tests/test_tuning.py | 552 ++++ python/pyspark/ml/tests/test_wrapper.py | 120 + python/pyspark/testing/mlutils.py | 161 + 17 files changed, 3331 insertions(+), 2763 deletions(-) delete mode 100755 python/pyspark/ml/tests.py create mode 100644 python/pyspark/ml/tests/__init__.py create mode 100644 python/pyspark/ml/tests/test_algorithms.py create mode 100644 python/pyspark/ml/tests/test_base.py create mode 100644 python/pyspark/ml/tests/test_evaluation.py create mode 100644 python/pyspark/ml/tests/test_feature.py create mode 100644 python/pyspark/ml/tests/test_image.py create mode 100644 python/pyspark/ml/tests/test_linalg.py create mode 100644 python/pyspark/ml/tests/test_param.py create mode 100644 python/pyspark/ml/tests/test_persistence.py create mode 100644 python/pyspark/ml/tests/test_pipeline.py create mode 100644 python/pyspark/ml/tests/test_stat.py create mode 100644 python/pyspark/ml/tests/test_training_summary.py create mode 100644 python/pyspark/ml/tests/test_tuning.py create mode 100644 python/pyspark/ml/tests/test_wrapper.py create mode 100644 python/pyspark/testing/mlutils.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 547635a412913..eef7f259391b8 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -452,6 +452,7 @@ def __hash__(self): "python/pyspark/ml/" ], python_test_goals=[ + # doctests "pyspark.ml.classification", "pyspark.ml.clustering", "pyspark.ml.evaluation", @@ -463,7 +464,20 @@ def __hash__(self): "pyspark.ml.regression", "pyspark.ml.stat", "pyspark.ml.tuning", - "pyspark.ml.tests", + # unittests + "pyspark.ml.tests.test_algorithms", + "pyspark.ml.tests.test_base", + "pyspark.ml.tests.test_evaluation", + "pyspark.ml.tests.test_feature", + "pyspark.ml.tests.test_image", + "pyspark.ml.tests.test_linalg", + "pyspark.ml.tests.test_param", + "pyspark.ml.tests.test_persistence", + "pyspark.ml.tests.test_pipeline", + "pyspark.ml.tests.test_stat", + "pyspark.ml.tests.test_training_summary", + "pyspark.ml.tests.test_tuning", + "pyspark.ml.tests.test_wrapper", ], blacklisted_python_implementations=[ "PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py deleted file mode 100755 index 2b4b7315d98c0..0000000000000 --- a/python/pyspark/ml/tests.py +++ /dev/null @@ -1,2762 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Unit tests for MLlib Python DataFrame-based APIs. -""" -import sys -if sys.version > '3': - xrange = range - basestring = str - -try: - import xmlrunner -except ImportError: - xmlrunner = None - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - -from shutil import rmtree -import tempfile -import array as pyarray -import numpy as np -from numpy import abs, all, arange, array, array_equal, inf, ones, tile, zeros -import inspect -import py4j - -from pyspark import keyword_only, SparkContext -from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer, UnaryTransformer -from pyspark.ml.classification import * -from pyspark.ml.clustering import * -from pyspark.ml.common import _java2py, _py2java -from pyspark.ml.evaluation import BinaryClassificationEvaluator, ClusteringEvaluator, \ - MulticlassClassificationEvaluator, RegressionEvaluator -from pyspark.ml.feature import * -from pyspark.ml.fpm import FPGrowth, FPGrowthModel -from pyspark.ml.image import ImageSchema -from pyspark.ml.linalg import DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT, \ - SparseMatrix, SparseVector, Vector, VectorUDT, Vectors -from pyspark.ml.param import Param, Params, TypeConverters -from pyspark.ml.param.shared import HasInputCol, HasMaxIter, HasSeed -from pyspark.ml.recommendation import ALS -from pyspark.ml.regression import DecisionTreeRegressor, GeneralizedLinearRegression, \ - LinearRegression -from pyspark.ml.stat import ChiSquareTest -from pyspark.ml.tuning import * -from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaParams, JavaWrapper -from pyspark.serializers import PickleSerializer -from pyspark.sql import DataFrame, Row, SparkSession, HiveContext -from pyspark.sql.functions import rand -from pyspark.sql.types import DoubleType, IntegerType -from pyspark.storagelevel import * -from pyspark.testing.utils import QuietTest, ReusedPySparkTestCase as PySparkTestCase - -ser = PickleSerializer() - - -class MLlibTestCase(unittest.TestCase): - def setUp(self): - self.sc = SparkContext('local[4]', "MLlib tests") - self.spark = SparkSession(self.sc) - - def tearDown(self): - self.spark.stop() - - -class SparkSessionTestCase(PySparkTestCase): - @classmethod - def setUpClass(cls): - PySparkTestCase.setUpClass() - cls.spark = SparkSession(cls.sc) - - @classmethod - def tearDownClass(cls): - PySparkTestCase.tearDownClass() - cls.spark.stop() - - -class MockDataset(DataFrame): - - def __init__(self): - self.index = 0 - - -class HasFake(Params): - - def __init__(self): - super(HasFake, self).__init__() - self.fake = Param(self, "fake", "fake param") - - def getFake(self): - return self.getOrDefault(self.fake) - - -class MockTransformer(Transformer, HasFake): - - def __init__(self): - super(MockTransformer, self).__init__() - self.dataset_index = None - - def _transform(self, dataset): - self.dataset_index = dataset.index - dataset.index += 1 - return dataset - - -class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable): - - shift = Param(Params._dummy(), "shift", "The amount by which to shift " + - "data in a DataFrame", - typeConverter=TypeConverters.toFloat) - - def __init__(self, shiftVal=1): - super(MockUnaryTransformer, self).__init__() - self._setDefault(shift=1) - self._set(shift=shiftVal) - - def getShift(self): - return self.getOrDefault(self.shift) - - def setShift(self, shift): - self._set(shift=shift) - - def createTransformFunc(self): - shiftVal = self.getShift() - return lambda x: x + shiftVal - - def outputDataType(self): - return DoubleType() - - def validateInputType(self, inputType): - if inputType != DoubleType(): - raise TypeError("Bad input type: {}. ".format(inputType) + - "Requires Double.") - - -class MockEstimator(Estimator, HasFake): - - def __init__(self): - super(MockEstimator, self).__init__() - self.dataset_index = None - - def _fit(self, dataset): - self.dataset_index = dataset.index - model = MockModel() - self._copyValues(model) - return model - - -class MockModel(MockTransformer, Model, HasFake): - pass - - -class JavaWrapperMemoryTests(SparkSessionTestCase): - - def test_java_object_gets_detached(self): - df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], []))], - ["label", "weight", "features"]) - lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", weightCol="weight", - fitIntercept=False) - - model = lr.fit(df) - summary = model.summary - - self.assertIsInstance(model, JavaWrapper) - self.assertIsInstance(summary, JavaWrapper) - self.assertIsInstance(model, JavaParams) - self.assertNotIsInstance(summary, JavaParams) - - error_no_object = 'Target Object ID does not exist for this gateway' - - self.assertIn("LinearRegression_", model._java_obj.toString()) - self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) - - model.__del__() - - with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): - model._java_obj.toString() - self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) - - try: - summary.__del__() - except: - pass - - with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): - model._java_obj.toString() - with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): - summary._java_obj.toString() - - -class ParamTypeConversionTests(PySparkTestCase): - """ - Test that param type conversion happens. - """ - - def test_int(self): - lr = LogisticRegression(maxIter=5.0) - self.assertEqual(lr.getMaxIter(), 5) - self.assertTrue(type(lr.getMaxIter()) == int) - self.assertRaises(TypeError, lambda: LogisticRegression(maxIter="notAnInt")) - self.assertRaises(TypeError, lambda: LogisticRegression(maxIter=5.1)) - - def test_float(self): - lr = LogisticRegression(tol=1) - self.assertEqual(lr.getTol(), 1.0) - self.assertTrue(type(lr.getTol()) == float) - self.assertRaises(TypeError, lambda: LogisticRegression(tol="notAFloat")) - - def test_vector(self): - ewp = ElementwiseProduct(scalingVec=[1, 3]) - self.assertEqual(ewp.getScalingVec(), DenseVector([1.0, 3.0])) - ewp = ElementwiseProduct(scalingVec=np.array([1.2, 3.4])) - self.assertEqual(ewp.getScalingVec(), DenseVector([1.2, 3.4])) - self.assertRaises(TypeError, lambda: ElementwiseProduct(scalingVec=["a", "b"])) - - def test_list(self): - l = [0, 1] - for lst_like in [l, np.array(l), DenseVector(l), SparseVector(len(l), - range(len(l)), l), pyarray.array('l', l), xrange(2), tuple(l)]: - converted = TypeConverters.toList(lst_like) - self.assertEqual(type(converted), list) - self.assertListEqual(converted, l) - - def test_list_int(self): - for indices in [[1.0, 2.0], np.array([1.0, 2.0]), DenseVector([1.0, 2.0]), - SparseVector(2, {0: 1.0, 1: 2.0}), xrange(1, 3), (1.0, 2.0), - pyarray.array('d', [1.0, 2.0])]: - vs = VectorSlicer(indices=indices) - self.assertListEqual(vs.getIndices(), [1, 2]) - self.assertTrue(all([type(v) == int for v in vs.getIndices()])) - self.assertRaises(TypeError, lambda: VectorSlicer(indices=["a", "b"])) - - def test_list_float(self): - b = Bucketizer(splits=[1, 4]) - self.assertEqual(b.getSplits(), [1.0, 4.0]) - self.assertTrue(all([type(v) == float for v in b.getSplits()])) - self.assertRaises(TypeError, lambda: Bucketizer(splits=["a", 1.0])) - - def test_list_string(self): - for labels in [np.array(['a', u'b']), ['a', u'b'], np.array(['a', 'b'])]: - idx_to_string = IndexToString(labels=labels) - self.assertListEqual(idx_to_string.getLabels(), ['a', 'b']) - self.assertRaises(TypeError, lambda: IndexToString(labels=['a', 2])) - - def test_string(self): - lr = LogisticRegression() - for col in ['features', u'features', np.str_('features')]: - lr.setFeaturesCol(col) - self.assertEqual(lr.getFeaturesCol(), 'features') - self.assertRaises(TypeError, lambda: LogisticRegression(featuresCol=2.3)) - - def test_bool(self): - self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept=1)) - self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept="false")) - - -class PipelineTests(PySparkTestCase): - - def test_pipeline(self): - dataset = MockDataset() - estimator0 = MockEstimator() - transformer1 = MockTransformer() - estimator2 = MockEstimator() - transformer3 = MockTransformer() - pipeline = Pipeline(stages=[estimator0, transformer1, estimator2, transformer3]) - pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1}) - model0, transformer1, model2, transformer3 = pipeline_model.stages - self.assertEqual(0, model0.dataset_index) - self.assertEqual(0, model0.getFake()) - self.assertEqual(1, transformer1.dataset_index) - self.assertEqual(1, transformer1.getFake()) - self.assertEqual(2, dataset.index) - self.assertIsNone(model2.dataset_index, "The last model shouldn't be called in fit.") - self.assertIsNone(transformer3.dataset_index, - "The last transformer shouldn't be called in fit.") - dataset = pipeline_model.transform(dataset) - self.assertEqual(2, model0.dataset_index) - self.assertEqual(3, transformer1.dataset_index) - self.assertEqual(4, model2.dataset_index) - self.assertEqual(5, transformer3.dataset_index) - self.assertEqual(6, dataset.index) - - def test_identity_pipeline(self): - dataset = MockDataset() - - def doTransform(pipeline): - pipeline_model = pipeline.fit(dataset) - return pipeline_model.transform(dataset) - # check that empty pipeline did not perform any transformation - self.assertEqual(dataset.index, doTransform(Pipeline(stages=[])).index) - # check that failure to set stages param will raise KeyError for missing param - self.assertRaises(KeyError, lambda: doTransform(Pipeline())) - - -class TestParams(HasMaxIter, HasInputCol, HasSeed): - """ - A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed. - """ - @keyword_only - def __init__(self, seed=None): - super(TestParams, self).__init__() - self._setDefault(maxIter=10) - kwargs = self._input_kwargs - self.setParams(**kwargs) - - @keyword_only - def setParams(self, seed=None): - """ - setParams(self, seed=None) - Sets params for this test. - """ - kwargs = self._input_kwargs - return self._set(**kwargs) - - -class OtherTestParams(HasMaxIter, HasInputCol, HasSeed): - """ - A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed. - """ - @keyword_only - def __init__(self, seed=None): - super(OtherTestParams, self).__init__() - self._setDefault(maxIter=10) - kwargs = self._input_kwargs - self.setParams(**kwargs) - - @keyword_only - def setParams(self, seed=None): - """ - setParams(self, seed=None) - Sets params for this test. - """ - kwargs = self._input_kwargs - return self._set(**kwargs) - - -class HasThrowableProperty(Params): - - def __init__(self): - super(HasThrowableProperty, self).__init__() - self.p = Param(self, "none", "empty param") - - @property - def test_property(self): - raise RuntimeError("Test property to raise error when invoked") - - -class ParamTests(SparkSessionTestCase): - - def test_copy_new_parent(self): - testParams = TestParams() - # Copying an instantiated param should fail - with self.assertRaises(ValueError): - testParams.maxIter._copy_new_parent(testParams) - # Copying a dummy param should succeed - TestParams.maxIter._copy_new_parent(testParams) - maxIter = testParams.maxIter - self.assertEqual(maxIter.name, "maxIter") - self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") - self.assertTrue(maxIter.parent == testParams.uid) - - def test_param(self): - testParams = TestParams() - maxIter = testParams.maxIter - self.assertEqual(maxIter.name, "maxIter") - self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") - self.assertTrue(maxIter.parent == testParams.uid) - - def test_hasparam(self): - testParams = TestParams() - self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params])) - self.assertFalse(testParams.hasParam("notAParameter")) - self.assertTrue(testParams.hasParam(u"maxIter")) - - def test_resolveparam(self): - testParams = TestParams() - self.assertEqual(testParams._resolveParam(testParams.maxIter), testParams.maxIter) - self.assertEqual(testParams._resolveParam("maxIter"), testParams.maxIter) - - self.assertEqual(testParams._resolveParam(u"maxIter"), testParams.maxIter) - if sys.version_info[0] >= 3: - # In Python 3, it is allowed to get/set attributes with non-ascii characters. - e_cls = AttributeError - else: - e_cls = UnicodeEncodeError - self.assertRaises(e_cls, lambda: testParams._resolveParam(u"아")) - - def test_params(self): - testParams = TestParams() - maxIter = testParams.maxIter - inputCol = testParams.inputCol - seed = testParams.seed - - params = testParams.params - self.assertEqual(params, [inputCol, maxIter, seed]) - - self.assertTrue(testParams.hasParam(maxIter.name)) - self.assertTrue(testParams.hasDefault(maxIter)) - self.assertFalse(testParams.isSet(maxIter)) - self.assertTrue(testParams.isDefined(maxIter)) - self.assertEqual(testParams.getMaxIter(), 10) - testParams.setMaxIter(100) - self.assertTrue(testParams.isSet(maxIter)) - self.assertEqual(testParams.getMaxIter(), 100) - - self.assertTrue(testParams.hasParam(inputCol.name)) - self.assertFalse(testParams.hasDefault(inputCol)) - self.assertFalse(testParams.isSet(inputCol)) - self.assertFalse(testParams.isDefined(inputCol)) - with self.assertRaises(KeyError): - testParams.getInputCol() - - otherParam = Param(Params._dummy(), "otherParam", "Parameter used to test that " + - "set raises an error for a non-member parameter.", - typeConverter=TypeConverters.toString) - with self.assertRaises(ValueError): - testParams.set(otherParam, "value") - - # Since the default is normally random, set it to a known number for debug str - testParams._setDefault(seed=41) - testParams.setSeed(43) - - self.assertEqual( - testParams.explainParams(), - "\n".join(["inputCol: input column name. (undefined)", - "maxIter: max number of iterations (>= 0). (default: 10, current: 100)", - "seed: random seed. (default: 41, current: 43)"])) - - def test_kmeans_param(self): - algo = KMeans() - self.assertEqual(algo.getInitMode(), "k-means||") - algo.setK(10) - self.assertEqual(algo.getK(), 10) - algo.setInitSteps(10) - self.assertEqual(algo.getInitSteps(), 10) - self.assertEqual(algo.getDistanceMeasure(), "euclidean") - algo.setDistanceMeasure("cosine") - self.assertEqual(algo.getDistanceMeasure(), "cosine") - - def test_hasseed(self): - noSeedSpecd = TestParams() - withSeedSpecd = TestParams(seed=42) - other = OtherTestParams() - # Check that we no longer use 42 as the magic number - self.assertNotEqual(noSeedSpecd.getSeed(), 42) - origSeed = noSeedSpecd.getSeed() - # Check that we only compute the seed once - self.assertEqual(noSeedSpecd.getSeed(), origSeed) - # Check that a specified seed is honored - self.assertEqual(withSeedSpecd.getSeed(), 42) - # Check that a different class has a different seed - self.assertNotEqual(other.getSeed(), noSeedSpecd.getSeed()) - - def test_param_property_error(self): - param_store = HasThrowableProperty() - self.assertRaises(RuntimeError, lambda: param_store.test_property) - params = param_store.params # should not invoke the property 'test_property' - self.assertEqual(len(params), 1) - - def test_word2vec_param(self): - model = Word2Vec().setWindowSize(6) - # Check windowSize is set properly - self.assertEqual(model.getWindowSize(), 6) - - def test_copy_param_extras(self): - tp = TestParams(seed=42) - extra = {tp.getParam(TestParams.inputCol.name): "copy_input"} - tp_copy = tp.copy(extra=extra) - self.assertEqual(tp.uid, tp_copy.uid) - self.assertEqual(tp.params, tp_copy.params) - for k, v in extra.items(): - self.assertTrue(tp_copy.isDefined(k)) - self.assertEqual(tp_copy.getOrDefault(k), v) - copied_no_extra = {} - for k, v in tp_copy._paramMap.items(): - if k not in extra: - copied_no_extra[k] = v - self.assertEqual(tp._paramMap, copied_no_extra) - self.assertEqual(tp._defaultParamMap, tp_copy._defaultParamMap) - - def test_logistic_regression_check_thresholds(self): - self.assertIsInstance( - LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]), - LogisticRegression - ) - - self.assertRaisesRegexp( - ValueError, - "Logistic Regression getThreshold found inconsistent.*$", - LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5] - ) - - def test_preserve_set_state(self): - dataset = self.spark.createDataFrame([(0.5,)], ["data"]) - binarizer = Binarizer(inputCol="data") - self.assertFalse(binarizer.isSet("threshold")) - binarizer.transform(dataset) - binarizer._transfer_params_from_java() - self.assertFalse(binarizer.isSet("threshold"), - "Params not explicitly set should remain unset after transform") - - def test_default_params_transferred(self): - dataset = self.spark.createDataFrame([(0.5,)], ["data"]) - binarizer = Binarizer(inputCol="data") - # intentionally change the pyspark default, but don't set it - binarizer._defaultParamMap[binarizer.outputCol] = "my_default" - result = binarizer.transform(dataset).select("my_default").collect() - self.assertFalse(binarizer.isSet(binarizer.outputCol)) - self.assertEqual(result[0][0], 1.0) - - @staticmethod - def check_params(test_self, py_stage, check_params_exist=True): - """ - Checks common requirements for Params.params: - - set of params exist in Java and Python and are ordered by names - - param parent has the same UID as the object's UID - - default param value from Java matches value in Python - - optionally check if all params from Java also exist in Python - """ - py_stage_str = "%s %s" % (type(py_stage), py_stage) - if not hasattr(py_stage, "_to_java"): - return - java_stage = py_stage._to_java() - if java_stage is None: - return - test_self.assertEqual(py_stage.uid, java_stage.uid(), msg=py_stage_str) - if check_params_exist: - param_names = [p.name for p in py_stage.params] - java_params = list(java_stage.params()) - java_param_names = [jp.name() for jp in java_params] - test_self.assertEqual( - param_names, sorted(java_param_names), - "Param list in Python does not match Java for %s:\nJava = %s\nPython = %s" - % (py_stage_str, java_param_names, param_names)) - for p in py_stage.params: - test_self.assertEqual(p.parent, py_stage.uid) - java_param = java_stage.getParam(p.name) - py_has_default = py_stage.hasDefault(p) - java_has_default = java_stage.hasDefault(java_param) - test_self.assertEqual(py_has_default, java_has_default, - "Default value mismatch of param %s for Params %s" - % (p.name, str(py_stage))) - if py_has_default: - if p.name == "seed": - continue # Random seeds between Spark and PySpark are different - java_default = _java2py(test_self.sc, - java_stage.clear(java_param).getOrDefault(java_param)) - py_stage._clear(p) - py_default = py_stage.getOrDefault(p) - # equality test for NaN is always False - if isinstance(java_default, float) and np.isnan(java_default): - java_default = "NaN" - py_default = "NaN" if np.isnan(py_default) else "not NaN" - test_self.assertEqual( - java_default, py_default, - "Java default %s != python default %s of param %s for Params %s" - % (str(java_default), str(py_default), p.name, str(py_stage))) - - -class EvaluatorTests(SparkSessionTestCase): - - def test_java_params(self): - """ - This tests a bug fixed by SPARK-18274 which causes multiple copies - of a Params instance in Python to be linked to the same Java instance. - """ - evaluator = RegressionEvaluator(metricName="r2") - df = self.spark.createDataFrame([Row(label=1.0, prediction=1.1)]) - evaluator.evaluate(df) - self.assertEqual(evaluator._java_obj.getMetricName(), "r2") - evaluatorCopy = evaluator.copy({evaluator.metricName: "mae"}) - evaluator.evaluate(df) - evaluatorCopy.evaluate(df) - self.assertEqual(evaluator._java_obj.getMetricName(), "r2") - self.assertEqual(evaluatorCopy._java_obj.getMetricName(), "mae") - - def test_clustering_evaluator_with_cosine_distance(self): - featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]), - [([1.0, 1.0], 1.0), ([10.0, 10.0], 1.0), ([1.0, 0.5], 2.0), - ([10.0, 4.4], 2.0), ([-1.0, 1.0], 3.0), ([-100.0, 90.0], 3.0)]) - dataset = self.spark.createDataFrame(featureAndPredictions, ["features", "prediction"]) - evaluator = ClusteringEvaluator(predictionCol="prediction", distanceMeasure="cosine") - self.assertEqual(evaluator.getDistanceMeasure(), "cosine") - self.assertTrue(np.isclose(evaluator.evaluate(dataset), 0.992671213, atol=1e-5)) - - -class FeatureTests(SparkSessionTestCase): - - def test_binarizer(self): - b0 = Binarizer() - self.assertListEqual(b0.params, [b0.inputCol, b0.outputCol, b0.threshold]) - self.assertTrue(all([~b0.isSet(p) for p in b0.params])) - self.assertTrue(b0.hasDefault(b0.threshold)) - self.assertEqual(b0.getThreshold(), 0.0) - b0.setParams(inputCol="input", outputCol="output").setThreshold(1.0) - self.assertTrue(all([b0.isSet(p) for p in b0.params])) - self.assertEqual(b0.getThreshold(), 1.0) - self.assertEqual(b0.getInputCol(), "input") - self.assertEqual(b0.getOutputCol(), "output") - - b0c = b0.copy({b0.threshold: 2.0}) - self.assertEqual(b0c.uid, b0.uid) - self.assertListEqual(b0c.params, b0.params) - self.assertEqual(b0c.getThreshold(), 2.0) - - b1 = Binarizer(threshold=2.0, inputCol="input", outputCol="output") - self.assertNotEqual(b1.uid, b0.uid) - self.assertEqual(b1.getThreshold(), 2.0) - self.assertEqual(b1.getInputCol(), "input") - self.assertEqual(b1.getOutputCol(), "output") - - def test_idf(self): - dataset = self.spark.createDataFrame([ - (DenseVector([1.0, 2.0]),), - (DenseVector([0.0, 1.0]),), - (DenseVector([3.0, 0.2]),)], ["tf"]) - idf0 = IDF(inputCol="tf") - self.assertListEqual(idf0.params, [idf0.inputCol, idf0.minDocFreq, idf0.outputCol]) - idf0m = idf0.fit(dataset, {idf0.outputCol: "idf"}) - self.assertEqual(idf0m.uid, idf0.uid, - "Model should inherit the UID from its parent estimator.") - output = idf0m.transform(dataset) - self.assertIsNotNone(output.head().idf) - # Test that parameters transferred to Python Model - ParamTests.check_params(self, idf0m) - - def test_ngram(self): - dataset = self.spark.createDataFrame([ - Row(input=["a", "b", "c", "d", "e"])]) - ngram0 = NGram(n=4, inputCol="input", outputCol="output") - self.assertEqual(ngram0.getN(), 4) - self.assertEqual(ngram0.getInputCol(), "input") - self.assertEqual(ngram0.getOutputCol(), "output") - transformedDF = ngram0.transform(dataset) - self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"]) - - def test_stopwordsremover(self): - dataset = self.spark.createDataFrame([Row(input=["a", "panda"])]) - stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") - # Default - self.assertEqual(stopWordRemover.getInputCol(), "input") - transformedDF = stopWordRemover.transform(dataset) - self.assertEqual(transformedDF.head().output, ["panda"]) - self.assertEqual(type(stopWordRemover.getStopWords()), list) - self.assertTrue(isinstance(stopWordRemover.getStopWords()[0], basestring)) - # Custom - stopwords = ["panda"] - stopWordRemover.setStopWords(stopwords) - self.assertEqual(stopWordRemover.getInputCol(), "input") - self.assertEqual(stopWordRemover.getStopWords(), stopwords) - transformedDF = stopWordRemover.transform(dataset) - self.assertEqual(transformedDF.head().output, ["a"]) - # with language selection - stopwords = StopWordsRemover.loadDefaultStopWords("turkish") - dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])]) - stopWordRemover.setStopWords(stopwords) - self.assertEqual(stopWordRemover.getStopWords(), stopwords) - transformedDF = stopWordRemover.transform(dataset) - self.assertEqual(transformedDF.head().output, []) - # with locale - stopwords = ["BELKİ"] - dataset = self.spark.createDataFrame([Row(input=["belki"])]) - stopWordRemover.setStopWords(stopwords).setLocale("tr") - self.assertEqual(stopWordRemover.getStopWords(), stopwords) - transformedDF = stopWordRemover.transform(dataset) - self.assertEqual(transformedDF.head().output, []) - - def test_count_vectorizer_with_binary(self): - dataset = self.spark.createDataFrame([ - (0, "a a a b b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), - (1, "a a".split(' '), SparseVector(3, {0: 1.0}),), - (2, "a b".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), - (3, "c".split(' '), SparseVector(3, {2: 1.0}),)], ["id", "words", "expected"]) - cv = CountVectorizer(binary=True, inputCol="words", outputCol="features") - model = cv.fit(dataset) - - transformedList = model.transform(dataset).select("features", "expected").collect() - - for r in transformedList: - feature, expected = r - self.assertEqual(feature, expected) - - def test_count_vectorizer_with_maxDF(self): - dataset = self.spark.createDataFrame([ - (0, "a b c d".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), - (1, "a b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), - (2, "a b".split(' '), SparseVector(3, {0: 1.0}),), - (3, "a".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) - cv = CountVectorizer(inputCol="words", outputCol="features") - model1 = cv.setMaxDF(3).fit(dataset) - self.assertEqual(model1.vocabulary, ['b', 'c', 'd']) - - transformedList1 = model1.transform(dataset).select("features", "expected").collect() - - for r in transformedList1: - feature, expected = r - self.assertEqual(feature, expected) - - model2 = cv.setMaxDF(0.75).fit(dataset) - self.assertEqual(model2.vocabulary, ['b', 'c', 'd']) - - transformedList2 = model2.transform(dataset).select("features", "expected").collect() - - for r in transformedList2: - feature, expected = r - self.assertEqual(feature, expected) - - def test_count_vectorizer_from_vocab(self): - model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words", - outputCol="features", minTF=2) - self.assertEqual(model.vocabulary, ["a", "b", "c"]) - self.assertEqual(model.getMinTF(), 2) - - dataset = self.spark.createDataFrame([ - (0, "a a a b b c".split(' '), SparseVector(3, {0: 3.0, 1: 2.0}),), - (1, "a a".split(' '), SparseVector(3, {0: 2.0}),), - (2, "a b".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) - - transformed_list = model.transform(dataset).select("features", "expected").collect() - - for r in transformed_list: - feature, expected = r - self.assertEqual(feature, expected) - - # Test an empty vocabulary - with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"): - CountVectorizerModel.from_vocabulary([], inputCol="words") - - # Test model with default settings can transform - model_default = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words") - transformed_list = model_default.transform(dataset)\ - .select(model_default.getOrDefault(model_default.outputCol)).collect() - self.assertEqual(len(transformed_list), 3) - - def test_rformula_force_index_label(self): - df = self.spark.createDataFrame([ - (1.0, 1.0, "a"), - (0.0, 2.0, "b"), - (1.0, 0.0, "a")], ["y", "x", "s"]) - # Does not index label by default since it's numeric type. - rf = RFormula(formula="y ~ x + s") - model = rf.fit(df) - transformedDF = model.transform(df) - self.assertEqual(transformedDF.head().label, 1.0) - # Force to index label. - rf2 = RFormula(formula="y ~ x + s").setForceIndexLabel(True) - model2 = rf2.fit(df) - transformedDF2 = model2.transform(df) - self.assertEqual(transformedDF2.head().label, 0.0) - - def test_rformula_string_indexer_order_type(self): - df = self.spark.createDataFrame([ - (1.0, 1.0, "a"), - (0.0, 2.0, "b"), - (1.0, 0.0, "a")], ["y", "x", "s"]) - rf = RFormula(formula="y ~ x + s", stringIndexerOrderType="alphabetDesc") - self.assertEqual(rf.getStringIndexerOrderType(), 'alphabetDesc') - transformedDF = rf.fit(df).transform(df) - observed = transformedDF.select("features").collect() - expected = [[1.0, 0.0], [2.0, 1.0], [0.0, 0.0]] - for i in range(0, len(expected)): - self.assertTrue(all(observed[i]["features"].toArray() == expected[i])) - - def test_string_indexer_handle_invalid(self): - df = self.spark.createDataFrame([ - (0, "a"), - (1, "d"), - (2, None)], ["id", "label"]) - - si1 = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep", - stringOrderType="alphabetAsc") - model1 = si1.fit(df) - td1 = model1.transform(df) - actual1 = td1.select("id", "indexed").collect() - expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0), Row(id=2, indexed=2.0)] - self.assertEqual(actual1, expected1) - - si2 = si1.setHandleInvalid("skip") - model2 = si2.fit(df) - td2 = model2.transform(df) - actual2 = td2.select("id", "indexed").collect() - expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)] - self.assertEqual(actual2, expected2) - - def test_string_indexer_from_labels(self): - model = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label", - outputCol="indexed", handleInvalid="keep") - self.assertEqual(model.labels, ["a", "b", "c"]) - - df1 = self.spark.createDataFrame([ - (0, "a"), - (1, "c"), - (2, None), - (3, "b"), - (4, "b")], ["id", "label"]) - - result1 = model.transform(df1) - actual1 = result1.select("id", "indexed").collect() - expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=2.0), Row(id=2, indexed=3.0), - Row(id=3, indexed=1.0), Row(id=4, indexed=1.0)] - self.assertEqual(actual1, expected1) - - model_empty_labels = StringIndexerModel.from_labels( - [], inputCol="label", outputCol="indexed", handleInvalid="keep") - actual2 = model_empty_labels.transform(df1).select("id", "indexed").collect() - expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=0.0), Row(id=2, indexed=0.0), - Row(id=3, indexed=0.0), Row(id=4, indexed=0.0)] - self.assertEqual(actual2, expected2) - - # Test model with default settings can transform - model_default = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label") - df2 = self.spark.createDataFrame([ - (0, "a"), - (1, "c"), - (2, "b"), - (3, "b"), - (4, "b")], ["id", "label"]) - transformed_list = model_default.transform(df2)\ - .select(model_default.getOrDefault(model_default.outputCol)).collect() - self.assertEqual(len(transformed_list), 5) - - def test_vector_size_hint(self): - df = self.spark.createDataFrame( - [(0, Vectors.dense([0.0, 10.0, 0.5])), - (1, Vectors.dense([1.0, 11.0, 0.5, 0.6])), - (2, Vectors.dense([2.0, 12.0]))], - ["id", "vector"]) - - sizeHint = VectorSizeHint( - inputCol="vector", - handleInvalid="skip") - sizeHint.setSize(3) - self.assertEqual(sizeHint.getSize(), 3) - - output = sizeHint.transform(df).head().vector - expected = DenseVector([0.0, 10.0, 0.5]) - self.assertEqual(output, expected) - - -class HasInducedError(Params): - - def __init__(self): - super(HasInducedError, self).__init__() - self.inducedError = Param(self, "inducedError", - "Uniformly-distributed error added to feature") - - def getInducedError(self): - return self.getOrDefault(self.inducedError) - - -class InducedErrorModel(Model, HasInducedError): - - def __init__(self): - super(InducedErrorModel, self).__init__() - - def _transform(self, dataset): - return dataset.withColumn("prediction", - dataset.feature + (rand(0) * self.getInducedError())) - - -class InducedErrorEstimator(Estimator, HasInducedError): - - def __init__(self, inducedError=1.0): - super(InducedErrorEstimator, self).__init__() - self._set(inducedError=inducedError) - - def _fit(self, dataset): - model = InducedErrorModel() - self._copyValues(model) - return model - - -class CrossValidatorTests(SparkSessionTestCase): - - def test_copy(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="rmse") - - grid = (ParamGridBuilder() - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) - .build()) - cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - cvCopied = cv.copy() - self.assertEqual(cv.getEstimator().uid, cvCopied.getEstimator().uid) - - cvModel = cv.fit(dataset) - cvModelCopied = cvModel.copy() - for index in range(len(cvModel.avgMetrics)): - self.assertTrue(abs(cvModel.avgMetrics[index] - cvModelCopied.avgMetrics[index]) - < 0.0001) - - def test_fit_minimize_metric(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="rmse") - - grid = (ParamGridBuilder() - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) - .build()) - cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - cvModel = cv.fit(dataset) - bestModel = cvModel.bestModel - bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) - - self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), - "Best model should have zero induced error") - self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0") - - def test_fit_maximize_metric(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="r2") - - grid = (ParamGridBuilder() - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) - .build()) - cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - cvModel = cv.fit(dataset) - bestModel = cvModel.bestModel - bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) - - self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), - "Best model should have zero induced error") - self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") - - def test_param_grid_type_coercion(self): - lr = LogisticRegression(maxIter=10) - paramGrid = ParamGridBuilder().addGrid(lr.regParam, [0.5, 1]).build() - for param in paramGrid: - for v in param.values(): - assert(type(v) == float) - - def test_save_load_trained_model(self): - # This tests saving and loading the trained model only. - # Save/load for CrossValidator will be added later: SPARK-13786 - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - cvModel = cv.fit(dataset) - lrModel = cvModel.bestModel - - cvModelPath = temp_path + "/cvModel" - lrModel.save(cvModelPath) - loadedLrModel = LogisticRegressionModel.load(cvModelPath) - self.assertEqual(loadedLrModel.uid, lrModel.uid) - self.assertEqual(loadedLrModel.intercept, lrModel.intercept) - - def test_save_load_simple_estimator(self): - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - - # test save/load of CrossValidator - cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - cvModel = cv.fit(dataset) - cvPath = temp_path + "/cv" - cv.save(cvPath) - loadedCV = CrossValidator.load(cvPath) - self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid) - self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid) - self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps()) - - # test save/load of CrossValidatorModel - cvModelPath = temp_path + "/cvModel" - cvModel.save(cvModelPath) - loadedModel = CrossValidatorModel.load(cvModelPath) - self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) - - def test_parallel_evaluation(self): - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build() - evaluator = BinaryClassificationEvaluator() - - # test save/load of CrossValidator - cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - cv.setParallelism(1) - cvSerialModel = cv.fit(dataset) - cv.setParallelism(2) - cvParallelModel = cv.fit(dataset) - self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics) - - def test_expose_sub_models(self): - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - - numFolds = 3 - cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, - numFolds=numFolds, collectSubModels=True) - - def checkSubModels(subModels): - self.assertEqual(len(subModels), numFolds) - for i in range(numFolds): - self.assertEqual(len(subModels[i]), len(grid)) - - cvModel = cv.fit(dataset) - checkSubModels(cvModel.subModels) - - # Test the default value for option "persistSubModel" to be "true" - testSubPath = temp_path + "/testCrossValidatorSubModels" - savingPathWithSubModels = testSubPath + "cvModel3" - cvModel.save(savingPathWithSubModels) - cvModel3 = CrossValidatorModel.load(savingPathWithSubModels) - checkSubModels(cvModel3.subModels) - cvModel4 = cvModel3.copy() - checkSubModels(cvModel4.subModels) - - savingPathWithoutSubModels = testSubPath + "cvModel2" - cvModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) - cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels) - self.assertEqual(cvModel2.subModels, None) - - for i in range(numFolds): - for j in range(len(grid)): - self.assertEqual(cvModel.subModels[i][j].uid, cvModel3.subModels[i][j].uid) - - def test_save_load_nested_estimator(self): - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - - ova = OneVsRest(classifier=LogisticRegression()) - lr1 = LogisticRegression().setMaxIter(100) - lr2 = LogisticRegression().setMaxIter(150) - grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build() - evaluator = MulticlassClassificationEvaluator() - - # test save/load of CrossValidator - cv = CrossValidator(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator) - cvModel = cv.fit(dataset) - cvPath = temp_path + "/cv" - cv.save(cvPath) - loadedCV = CrossValidator.load(cvPath) - self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid) - self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid) - - originalParamMap = cv.getEstimatorParamMaps() - loadedParamMap = loadedCV.getEstimatorParamMaps() - for i, param in enumerate(loadedParamMap): - for p in param: - if p.name == "classifier": - self.assertEqual(param[p].uid, originalParamMap[i][p].uid) - else: - self.assertEqual(param[p], originalParamMap[i][p]) - - # test save/load of CrossValidatorModel - cvModelPath = temp_path + "/cvModel" - cvModel.save(cvModelPath) - loadedModel = CrossValidatorModel.load(cvModelPath) - self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) - - -class TrainValidationSplitTests(SparkSessionTestCase): - - def test_fit_minimize_metric(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="rmse") - - grid = ParamGridBuilder() \ - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ - .build() - tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - bestModel = tvsModel.bestModel - bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) - validationMetrics = tvsModel.validationMetrics - - self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), - "Best model should have zero induced error") - self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0") - self.assertEqual(len(grid), len(validationMetrics), - "validationMetrics has the same size of grid parameter") - self.assertEqual(0.0, min(validationMetrics)) - - def test_fit_maximize_metric(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="r2") - - grid = ParamGridBuilder() \ - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ - .build() - tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - bestModel = tvsModel.bestModel - bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) - validationMetrics = tvsModel.validationMetrics - - self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), - "Best model should have zero induced error") - self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") - self.assertEqual(len(grid), len(validationMetrics), - "validationMetrics has the same size of grid parameter") - self.assertEqual(1.0, max(validationMetrics)) - - def test_save_load_trained_model(self): - # This tests saving and loading the trained model only. - # Save/load for TrainValidationSplit will be added later: SPARK-13786 - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - lrModel = tvsModel.bestModel - - tvsModelPath = temp_path + "/tvsModel" - lrModel.save(tvsModelPath) - loadedLrModel = LogisticRegressionModel.load(tvsModelPath) - self.assertEqual(loadedLrModel.uid, lrModel.uid) - self.assertEqual(loadedLrModel.intercept, lrModel.intercept) - - def test_save_load_simple_estimator(self): - # This tests saving and loading the trained model only. - # Save/load for TrainValidationSplit will be added later: SPARK-13786 - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - - tvsPath = temp_path + "/tvs" - tvs.save(tvsPath) - loadedTvs = TrainValidationSplit.load(tvsPath) - self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) - self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid) - self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps()) - - tvsModelPath = temp_path + "/tvsModel" - tvsModel.save(tvsModelPath) - loadedModel = TrainValidationSplitModel.load(tvsModelPath) - self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) - - def test_parallel_evaluation(self): - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build() - evaluator = BinaryClassificationEvaluator() - tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - tvs.setParallelism(1) - tvsSerialModel = tvs.fit(dataset) - tvs.setParallelism(2) - tvsParallelModel = tvs.fit(dataset) - self.assertEqual(tvsSerialModel.validationMetrics, tvsParallelModel.validationMetrics) - - def test_expose_sub_models(self): - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, - collectSubModels=True) - tvsModel = tvs.fit(dataset) - self.assertEqual(len(tvsModel.subModels), len(grid)) - - # Test the default value for option "persistSubModel" to be "true" - testSubPath = temp_path + "/testTrainValidationSplitSubModels" - savingPathWithSubModels = testSubPath + "cvModel3" - tvsModel.save(savingPathWithSubModels) - tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels) - self.assertEqual(len(tvsModel3.subModels), len(grid)) - tvsModel4 = tvsModel3.copy() - self.assertEqual(len(tvsModel4.subModels), len(grid)) - - savingPathWithoutSubModels = testSubPath + "cvModel2" - tvsModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) - tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels) - self.assertEqual(tvsModel2.subModels, None) - - for i in range(len(grid)): - self.assertEqual(tvsModel.subModels[i].uid, tvsModel3.subModels[i].uid) - - def test_save_load_nested_estimator(self): - # This tests saving and loading the trained model only. - # Save/load for TrainValidationSplit will be added later: SPARK-13786 - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - ova = OneVsRest(classifier=LogisticRegression()) - lr1 = LogisticRegression().setMaxIter(100) - lr2 = LogisticRegression().setMaxIter(150) - grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build() - evaluator = MulticlassClassificationEvaluator() - - tvs = TrainValidationSplit(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - tvsPath = temp_path + "/tvs" - tvs.save(tvsPath) - loadedTvs = TrainValidationSplit.load(tvsPath) - self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) - self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid) - - originalParamMap = tvs.getEstimatorParamMaps() - loadedParamMap = loadedTvs.getEstimatorParamMaps() - for i, param in enumerate(loadedParamMap): - for p in param: - if p.name == "classifier": - self.assertEqual(param[p].uid, originalParamMap[i][p].uid) - else: - self.assertEqual(param[p], originalParamMap[i][p]) - - tvsModelPath = temp_path + "/tvsModel" - tvsModel.save(tvsModelPath) - loadedModel = TrainValidationSplitModel.load(tvsModelPath) - self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) - - def test_copy(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="r2") - - grid = ParamGridBuilder() \ - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ - .build() - tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - tvsCopied = tvs.copy() - tvsModelCopied = tvsModel.copy() - - self.assertEqual(tvs.getEstimator().uid, tvsCopied.getEstimator().uid, - "Copied TrainValidationSplit has the same uid of Estimator") - - self.assertEqual(tvsModel.bestModel.uid, tvsModelCopied.bestModel.uid) - self.assertEqual(len(tvsModel.validationMetrics), - len(tvsModelCopied.validationMetrics), - "Copied validationMetrics has the same size of the original") - for index in range(len(tvsModel.validationMetrics)): - self.assertEqual(tvsModel.validationMetrics[index], - tvsModelCopied.validationMetrics[index]) - - -class PersistenceTest(SparkSessionTestCase): - - def test_linear_regression(self): - lr = LinearRegression(maxIter=1) - path = tempfile.mkdtemp() - lr_path = path + "/lr" - lr.save(lr_path) - lr2 = LinearRegression.load(lr_path) - self.assertEqual(lr.uid, lr2.uid) - self.assertEqual(type(lr.uid), type(lr2.uid)) - self.assertEqual(lr2.uid, lr2.maxIter.parent, - "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)" - % (lr2.uid, lr2.maxIter.parent)) - self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter], - "Loaded LinearRegression instance default params did not match " + - "original defaults") - try: - rmtree(path) - except OSError: - pass - - def test_linear_regression_pmml_basic(self): - # Most of the validation is done in the Scala side, here we just check - # that we output text rather than parquet (e.g. that the format flag - # was respected). - df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], []))], - ["label", "weight", "features"]) - lr = LinearRegression(maxIter=1) - model = lr.fit(df) - path = tempfile.mkdtemp() - lr_path = path + "/lr-pmml" - model.write().format("pmml").save(lr_path) - pmml_text_list = self.sc.textFile(lr_path).collect() - pmml_text = "\n".join(pmml_text_list) - self.assertIn("Apache Spark", pmml_text) - self.assertIn("PMML", pmml_text) - - def test_logistic_regression(self): - lr = LogisticRegression(maxIter=1) - path = tempfile.mkdtemp() - lr_path = path + "/logreg" - lr.save(lr_path) - lr2 = LogisticRegression.load(lr_path) - self.assertEqual(lr2.uid, lr2.maxIter.parent, - "Loaded LogisticRegression instance uid (%s) " - "did not match Param's uid (%s)" - % (lr2.uid, lr2.maxIter.parent)) - self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter], - "Loaded LogisticRegression instance default params did not match " + - "original defaults") - try: - rmtree(path) - except OSError: - pass - - def _compare_params(self, m1, m2, param): - """ - Compare 2 ML Params instances for the given param, and assert both have the same param value - and parent. The param must be a parameter of m1. - """ - # Prevent key not found error in case of some param in neither paramMap nor defaultParamMap. - if m1.isDefined(param): - paramValue1 = m1.getOrDefault(param) - paramValue2 = m2.getOrDefault(m2.getParam(param.name)) - if isinstance(paramValue1, Params): - self._compare_pipelines(paramValue1, paramValue2) - else: - self.assertEqual(paramValue1, paramValue2) # for general types param - # Assert parents are equal - self.assertEqual(param.parent, m2.getParam(param.name).parent) - else: - # If m1 is not defined param, then m2 should not, too. See SPARK-14931. - self.assertFalse(m2.isDefined(m2.getParam(param.name))) - - def _compare_pipelines(self, m1, m2): - """ - Compare 2 ML types, asserting that they are equivalent. - This currently supports: - - basic types - - Pipeline, PipelineModel - - OneVsRest, OneVsRestModel - This checks: - - uid - - type - - Param values and parents - """ - self.assertEqual(m1.uid, m2.uid) - self.assertEqual(type(m1), type(m2)) - if isinstance(m1, JavaParams) or isinstance(m1, Transformer): - self.assertEqual(len(m1.params), len(m2.params)) - for p in m1.params: - self._compare_params(m1, m2, p) - elif isinstance(m1, Pipeline): - self.assertEqual(len(m1.getStages()), len(m2.getStages())) - for s1, s2 in zip(m1.getStages(), m2.getStages()): - self._compare_pipelines(s1, s2) - elif isinstance(m1, PipelineModel): - self.assertEqual(len(m1.stages), len(m2.stages)) - for s1, s2 in zip(m1.stages, m2.stages): - self._compare_pipelines(s1, s2) - elif isinstance(m1, OneVsRest) or isinstance(m1, OneVsRestModel): - for p in m1.params: - self._compare_params(m1, m2, p) - if isinstance(m1, OneVsRestModel): - self.assertEqual(len(m1.models), len(m2.models)) - for x, y in zip(m1.models, m2.models): - self._compare_pipelines(x, y) - else: - raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1)) - - def test_pipeline_persistence(self): - """ - Pipeline[HashingTF, PCA] - """ - temp_path = tempfile.mkdtemp() - - try: - df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) - tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") - pca = PCA(k=2, inputCol="features", outputCol="pca_features") - pl = Pipeline(stages=[tf, pca]) - model = pl.fit(df) - - pipeline_path = temp_path + "/pipeline" - pl.save(pipeline_path) - loaded_pipeline = Pipeline.load(pipeline_path) - self._compare_pipelines(pl, loaded_pipeline) - - model_path = temp_path + "/pipeline-model" - model.save(model_path) - loaded_model = PipelineModel.load(model_path) - self._compare_pipelines(model, loaded_model) - finally: - try: - rmtree(temp_path) - except OSError: - pass - - def test_nested_pipeline_persistence(self): - """ - Pipeline[HashingTF, Pipeline[PCA]] - """ - temp_path = tempfile.mkdtemp() - - try: - df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) - tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") - pca = PCA(k=2, inputCol="features", outputCol="pca_features") - p0 = Pipeline(stages=[pca]) - pl = Pipeline(stages=[tf, p0]) - model = pl.fit(df) - - pipeline_path = temp_path + "/pipeline" - pl.save(pipeline_path) - loaded_pipeline = Pipeline.load(pipeline_path) - self._compare_pipelines(pl, loaded_pipeline) - - model_path = temp_path + "/pipeline-model" - model.save(model_path) - loaded_model = PipelineModel.load(model_path) - self._compare_pipelines(model, loaded_model) - finally: - try: - rmtree(temp_path) - except OSError: - pass - - def test_python_transformer_pipeline_persistence(self): - """ - Pipeline[MockUnaryTransformer, Binarizer] - """ - temp_path = tempfile.mkdtemp() - - try: - df = self.spark.range(0, 10).toDF('input') - tf = MockUnaryTransformer(shiftVal=2)\ - .setInputCol("input").setOutputCol("shiftedInput") - tf2 = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized") - pl = Pipeline(stages=[tf, tf2]) - model = pl.fit(df) - - pipeline_path = temp_path + "/pipeline" - pl.save(pipeline_path) - loaded_pipeline = Pipeline.load(pipeline_path) - self._compare_pipelines(pl, loaded_pipeline) - - model_path = temp_path + "/pipeline-model" - model.save(model_path) - loaded_model = PipelineModel.load(model_path) - self._compare_pipelines(model, loaded_model) - finally: - try: - rmtree(temp_path) - except OSError: - pass - - def test_onevsrest(self): - temp_path = tempfile.mkdtemp() - df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), - (1.0, Vectors.sparse(2, [], [])), - (2.0, Vectors.dense(0.5, 0.5))] * 10, - ["label", "features"]) - lr = LogisticRegression(maxIter=5, regParam=0.01) - ovr = OneVsRest(classifier=lr) - model = ovr.fit(df) - ovrPath = temp_path + "/ovr" - ovr.save(ovrPath) - loadedOvr = OneVsRest.load(ovrPath) - self._compare_pipelines(ovr, loadedOvr) - modelPath = temp_path + "/ovrModel" - model.save(modelPath) - loadedModel = OneVsRestModel.load(modelPath) - self._compare_pipelines(model, loadedModel) - - def test_decisiontree_classifier(self): - dt = DecisionTreeClassifier(maxDepth=1) - path = tempfile.mkdtemp() - dtc_path = path + "/dtc" - dt.save(dtc_path) - dt2 = DecisionTreeClassifier.load(dtc_path) - self.assertEqual(dt2.uid, dt2.maxDepth.parent, - "Loaded DecisionTreeClassifier instance uid (%s) " - "did not match Param's uid (%s)" - % (dt2.uid, dt2.maxDepth.parent)) - self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth], - "Loaded DecisionTreeClassifier instance default params did not match " + - "original defaults") - try: - rmtree(path) - except OSError: - pass - - def test_decisiontree_regressor(self): - dt = DecisionTreeRegressor(maxDepth=1) - path = tempfile.mkdtemp() - dtr_path = path + "/dtr" - dt.save(dtr_path) - dt2 = DecisionTreeClassifier.load(dtr_path) - self.assertEqual(dt2.uid, dt2.maxDepth.parent, - "Loaded DecisionTreeRegressor instance uid (%s) " - "did not match Param's uid (%s)" - % (dt2.uid, dt2.maxDepth.parent)) - self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth], - "Loaded DecisionTreeRegressor instance default params did not match " + - "original defaults") - try: - rmtree(path) - except OSError: - pass - - def test_default_read_write(self): - temp_path = tempfile.mkdtemp() - - lr = LogisticRegression() - lr.setMaxIter(50) - lr.setThreshold(.75) - writer = DefaultParamsWriter(lr) - - savePath = temp_path + "/lr" - writer.save(savePath) - - reader = DefaultParamsReadable.read() - lr2 = reader.load(savePath) - - self.assertEqual(lr.uid, lr2.uid) - self.assertEqual(lr.extractParamMap(), lr2.extractParamMap()) - - # test overwrite - lr.setThreshold(.8) - writer.overwrite().save(savePath) - - reader = DefaultParamsReadable.read() - lr3 = reader.load(savePath) - - self.assertEqual(lr.uid, lr3.uid) - self.assertEqual(lr.extractParamMap(), lr3.extractParamMap()) - - def test_default_read_write_default_params(self): - lr = LogisticRegression() - self.assertFalse(lr.isSet(lr.getParam("threshold"))) - - lr.setMaxIter(50) - lr.setThreshold(.75) - - # `threshold` is set by user, default param `predictionCol` is not set by user. - self.assertTrue(lr.isSet(lr.getParam("threshold"))) - self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) - self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) - - writer = DefaultParamsWriter(lr) - metadata = json.loads(writer._get_metadata_to_save(lr, self.sc)) - self.assertTrue("defaultParamMap" in metadata) - - reader = DefaultParamsReadable.read() - metadataStr = json.dumps(metadata, separators=[',', ':']) - loadedMetadata = reader._parseMetaData(metadataStr, ) - reader.getAndSetParams(lr, loadedMetadata) - - self.assertTrue(lr.isSet(lr.getParam("threshold"))) - self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) - self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) - - # manually create metadata without `defaultParamMap` section. - del metadata['defaultParamMap'] - metadataStr = json.dumps(metadata, separators=[',', ':']) - loadedMetadata = reader._parseMetaData(metadataStr, ) - with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"): - reader.getAndSetParams(lr, loadedMetadata) - - # Prior to 2.4.0, metadata doesn't have `defaultParamMap`. - metadata['sparkVersion'] = '2.3.0' - metadataStr = json.dumps(metadata, separators=[',', ':']) - loadedMetadata = reader._parseMetaData(metadataStr, ) - reader.getAndSetParams(lr, loadedMetadata) - - -class LDATest(SparkSessionTestCase): - - def _compare(self, m1, m2): - """ - Temp method for comparing instances. - TODO: Replace with generic implementation once SPARK-14706 is merged. - """ - self.assertEqual(m1.uid, m2.uid) - self.assertEqual(type(m1), type(m2)) - self.assertEqual(len(m1.params), len(m2.params)) - for p in m1.params: - if m1.isDefined(p): - self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p)) - self.assertEqual(p.parent, m2.getParam(p.name).parent) - if isinstance(m1, LDAModel): - self.assertEqual(m1.vocabSize(), m2.vocabSize()) - self.assertEqual(m1.topicsMatrix(), m2.topicsMatrix()) - - def test_persistence(self): - # Test save/load for LDA, LocalLDAModel, DistributedLDAModel. - df = self.spark.createDataFrame([ - [1, Vectors.dense([0.0, 1.0])], - [2, Vectors.sparse(2, {0: 1.0})], - ], ["id", "features"]) - # Fit model - lda = LDA(k=2, seed=1, optimizer="em") - distributedModel = lda.fit(df) - self.assertTrue(distributedModel.isDistributed()) - localModel = distributedModel.toLocal() - self.assertFalse(localModel.isDistributed()) - # Define paths - path = tempfile.mkdtemp() - lda_path = path + "/lda" - dist_model_path = path + "/distLDAModel" - local_model_path = path + "/localLDAModel" - # Test LDA - lda.save(lda_path) - lda2 = LDA.load(lda_path) - self._compare(lda, lda2) - # Test DistributedLDAModel - distributedModel.save(dist_model_path) - distributedModel2 = DistributedLDAModel.load(dist_model_path) - self._compare(distributedModel, distributedModel2) - # Test LocalLDAModel - localModel.save(local_model_path) - localModel2 = LocalLDAModel.load(local_model_path) - self._compare(localModel, localModel2) - # Clean up - try: - rmtree(path) - except OSError: - pass - - -class TrainingSummaryTest(SparkSessionTestCase): - - def test_linear_regression_summary(self): - df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], []))], - ["label", "weight", "features"]) - lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight", - fitIntercept=False) - model = lr.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - # test that api is callable and returns expected types - self.assertGreater(s.totalIterations, 0) - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.predictionCol, "prediction") - self.assertEqual(s.labelCol, "label") - self.assertEqual(s.featuresCol, "features") - objHist = s.objectiveHistory - self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) - self.assertAlmostEqual(s.explainedVariance, 0.25, 2) - self.assertAlmostEqual(s.meanAbsoluteError, 0.0) - self.assertAlmostEqual(s.meanSquaredError, 0.0) - self.assertAlmostEqual(s.rootMeanSquaredError, 0.0) - self.assertAlmostEqual(s.r2, 1.0, 2) - self.assertAlmostEqual(s.r2adj, 1.0, 2) - self.assertTrue(isinstance(s.residuals, DataFrame)) - self.assertEqual(s.numInstances, 2) - self.assertEqual(s.degreesOfFreedom, 1) - devResiduals = s.devianceResiduals - self.assertTrue(isinstance(devResiduals, list) and isinstance(devResiduals[0], float)) - coefStdErr = s.coefficientStandardErrors - self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float)) - tValues = s.tValues - self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float)) - pValues = s.pValues - self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float)) - # test evaluation (with training dataset) produces a summary with same values - # one check is enough to verify a summary is returned - # The child class LinearRegressionTrainingSummary runs full test - sameSummary = model.evaluate(df) - self.assertAlmostEqual(sameSummary.explainedVariance, s.explainedVariance) - - def test_glr_summary(self): - from pyspark.ml.linalg import Vectors - df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], []))], - ["label", "weight", "features"]) - glr = GeneralizedLinearRegression(family="gaussian", link="identity", weightCol="weight", - fitIntercept=False) - model = glr.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - # test that api is callable and returns expected types - self.assertEqual(s.numIterations, 1) # this should default to a single iteration of WLS - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.predictionCol, "prediction") - self.assertEqual(s.numInstances, 2) - self.assertTrue(isinstance(s.residuals(), DataFrame)) - self.assertTrue(isinstance(s.residuals("pearson"), DataFrame)) - coefStdErr = s.coefficientStandardErrors - self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float)) - tValues = s.tValues - self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float)) - pValues = s.pValues - self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float)) - self.assertEqual(s.degreesOfFreedom, 1) - self.assertEqual(s.residualDegreeOfFreedom, 1) - self.assertEqual(s.residualDegreeOfFreedomNull, 2) - self.assertEqual(s.rank, 1) - self.assertTrue(isinstance(s.solver, basestring)) - self.assertTrue(isinstance(s.aic, float)) - self.assertTrue(isinstance(s.deviance, float)) - self.assertTrue(isinstance(s.nullDeviance, float)) - self.assertTrue(isinstance(s.dispersion, float)) - # test evaluation (with training dataset) produces a summary with same values - # one check is enough to verify a summary is returned - # The child class GeneralizedLinearRegressionTrainingSummary runs full test - sameSummary = model.evaluate(df) - self.assertAlmostEqual(sameSummary.deviance, s.deviance) - - def test_binary_logistic_regression_summary(self): - df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], []))], - ["label", "weight", "features"]) - lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False) - model = lr.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - # test that api is callable and returns expected types - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.probabilityCol, "probability") - self.assertEqual(s.labelCol, "label") - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - objHist = s.objectiveHistory - self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) - self.assertGreater(s.totalIterations, 0) - self.assertTrue(isinstance(s.labels, list)) - self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.precisionByLabel, list)) - self.assertTrue(isinstance(s.recallByLabel, list)) - self.assertTrue(isinstance(s.fMeasureByLabel(), list)) - self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) - self.assertTrue(isinstance(s.roc, DataFrame)) - self.assertAlmostEqual(s.areaUnderROC, 1.0, 2) - self.assertTrue(isinstance(s.pr, DataFrame)) - self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame)) - self.assertTrue(isinstance(s.precisionByThreshold, DataFrame)) - self.assertTrue(isinstance(s.recallByThreshold, DataFrame)) - self.assertAlmostEqual(s.accuracy, 1.0, 2) - self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2) - self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2) - self.assertAlmostEqual(s.weightedRecall, 1.0, 2) - self.assertAlmostEqual(s.weightedPrecision, 1.0, 2) - self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2) - self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2) - # test evaluation (with training dataset) produces a summary with same values - # one check is enough to verify a summary is returned, Scala version runs full test - sameSummary = model.evaluate(df) - self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) - - def test_multiclass_logistic_regression_summary(self): - df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], [])), - (2.0, 2.0, Vectors.dense(2.0)), - (2.0, 2.0, Vectors.dense(1.9))], - ["label", "weight", "features"]) - lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False) - model = lr.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - # test that api is callable and returns expected types - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.probabilityCol, "probability") - self.assertEqual(s.labelCol, "label") - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - objHist = s.objectiveHistory - self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) - self.assertGreater(s.totalIterations, 0) - self.assertTrue(isinstance(s.labels, list)) - self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.precisionByLabel, list)) - self.assertTrue(isinstance(s.recallByLabel, list)) - self.assertTrue(isinstance(s.fMeasureByLabel(), list)) - self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) - self.assertAlmostEqual(s.accuracy, 0.75, 2) - self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2) - self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2) - self.assertAlmostEqual(s.weightedRecall, 0.75, 2) - self.assertAlmostEqual(s.weightedPrecision, 0.583, 2) - self.assertAlmostEqual(s.weightedFMeasure(), 0.65, 2) - self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.65, 2) - # test evaluation (with training dataset) produces a summary with same values - # one check is enough to verify a summary is returned, Scala version runs full test - sameSummary = model.evaluate(df) - self.assertAlmostEqual(sameSummary.accuracy, s.accuracy) - - def test_gaussian_mixture_summary(self): - data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), - (Vectors.sparse(1, [], []),)] - df = self.spark.createDataFrame(data, ["features"]) - gmm = GaussianMixture(k=2) - model = gmm.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.probabilityCol, "probability") - self.assertTrue(isinstance(s.probability, DataFrame)) - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - self.assertTrue(isinstance(s.cluster, DataFrame)) - self.assertEqual(len(s.clusterSizes), 2) - self.assertEqual(s.k, 2) - self.assertEqual(s.numIter, 3) - - def test_bisecting_kmeans_summary(self): - data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), - (Vectors.sparse(1, [], []),)] - df = self.spark.createDataFrame(data, ["features"]) - bkm = BisectingKMeans(k=2) - model = bkm.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - self.assertTrue(isinstance(s.cluster, DataFrame)) - self.assertEqual(len(s.clusterSizes), 2) - self.assertEqual(s.k, 2) - self.assertEqual(s.numIter, 20) - - def test_kmeans_summary(self): - data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), - (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)] - df = self.spark.createDataFrame(data, ["features"]) - kmeans = KMeans(k=2, seed=1) - model = kmeans.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - self.assertTrue(isinstance(s.cluster, DataFrame)) - self.assertEqual(len(s.clusterSizes), 2) - self.assertEqual(s.k, 2) - self.assertEqual(s.numIter, 1) - - -class KMeansTests(SparkSessionTestCase): - - def test_kmeans_cosine_distance(self): - data = [(Vectors.dense([1.0, 1.0]),), (Vectors.dense([10.0, 10.0]),), - (Vectors.dense([1.0, 0.5]),), (Vectors.dense([10.0, 4.4]),), - (Vectors.dense([-1.0, 1.0]),), (Vectors.dense([-100.0, 90.0]),)] - df = self.spark.createDataFrame(data, ["features"]) - kmeans = KMeans(k=3, seed=1, distanceMeasure="cosine") - model = kmeans.fit(df) - result = model.transform(df).collect() - self.assertTrue(result[0].prediction == result[1].prediction) - self.assertTrue(result[2].prediction == result[3].prediction) - self.assertTrue(result[4].prediction == result[5].prediction) - - -class OneVsRestTests(SparkSessionTestCase): - - def test_copy(self): - df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), - (1.0, Vectors.sparse(2, [], [])), - (2.0, Vectors.dense(0.5, 0.5))], - ["label", "features"]) - lr = LogisticRegression(maxIter=5, regParam=0.01) - ovr = OneVsRest(classifier=lr) - ovr1 = ovr.copy({lr.maxIter: 10}) - self.assertEqual(ovr.getClassifier().getMaxIter(), 5) - self.assertEqual(ovr1.getClassifier().getMaxIter(), 10) - model = ovr.fit(df) - model1 = model.copy({model.predictionCol: "indexed"}) - self.assertEqual(model1.getPredictionCol(), "indexed") - - def test_output_columns(self): - df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), - (1.0, Vectors.sparse(2, [], [])), - (2.0, Vectors.dense(0.5, 0.5))], - ["label", "features"]) - lr = LogisticRegression(maxIter=5, regParam=0.01) - ovr = OneVsRest(classifier=lr, parallelism=1) - model = ovr.fit(df) - output = model.transform(df) - self.assertEqual(output.columns, ["label", "features", "prediction"]) - - def test_parallelism_doesnt_change_output(self): - df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), - (1.0, Vectors.sparse(2, [], [])), - (2.0, Vectors.dense(0.5, 0.5))], - ["label", "features"]) - ovrPar1 = OneVsRest(classifier=LogisticRegression(maxIter=5, regParam=.01), parallelism=1) - modelPar1 = ovrPar1.fit(df) - ovrPar2 = OneVsRest(classifier=LogisticRegression(maxIter=5, regParam=.01), parallelism=2) - modelPar2 = ovrPar2.fit(df) - for i, model in enumerate(modelPar1.models): - self.assertTrue(np.allclose(model.coefficients.toArray(), - modelPar2.models[i].coefficients.toArray(), atol=1E-4)) - self.assertTrue(np.allclose(model.intercept, modelPar2.models[i].intercept, atol=1E-4)) - - def test_support_for_weightCol(self): - df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0), - (1.0, Vectors.sparse(2, [], []), 1.0), - (2.0, Vectors.dense(0.5, 0.5), 1.0)], - ["label", "features", "weight"]) - # classifier inherits hasWeightCol - lr = LogisticRegression(maxIter=5, regParam=0.01) - ovr = OneVsRest(classifier=lr, weightCol="weight") - self.assertIsNotNone(ovr.fit(df)) - # classifier doesn't inherit hasWeightCol - dt = DecisionTreeClassifier() - ovr2 = OneVsRest(classifier=dt, weightCol="weight") - self.assertIsNotNone(ovr2.fit(df)) - - -class HashingTFTest(SparkSessionTestCase): - - def test_apply_binary_term_freqs(self): - - df = self.spark.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"]) - n = 10 - hashingTF = HashingTF() - hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n).setBinary(True) - output = hashingTF.transform(df) - features = output.select("features").first().features.toArray() - expected = Vectors.dense([1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).toArray() - for i in range(0, n): - self.assertAlmostEqual(features[i], expected[i], 14, "Error at " + str(i) + - ": expected " + str(expected[i]) + ", got " + str(features[i])) - - -class GeneralizedLinearRegressionTest(SparkSessionTestCase): - - def test_tweedie_distribution(self): - - df = self.spark.createDataFrame( - [(1.0, Vectors.dense(0.0, 0.0)), - (1.0, Vectors.dense(1.0, 2.0)), - (2.0, Vectors.dense(0.0, 0.0)), - (2.0, Vectors.dense(1.0, 1.0)), ], ["label", "features"]) - - glr = GeneralizedLinearRegression(family="tweedie", variancePower=1.6) - model = glr.fit(df) - self.assertTrue(np.allclose(model.coefficients.toArray(), [-0.4645, 0.3402], atol=1E-4)) - self.assertTrue(np.isclose(model.intercept, 0.7841, atol=1E-4)) - - model2 = glr.setLinkPower(-1.0).fit(df) - self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4)) - self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4)) - - def test_offset(self): - - df = self.spark.createDataFrame( - [(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)), - (0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)), - (0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0)), - (0.7, 0.7, 0.0, Vectors.dense(3.0, 3.0))], ["label", "weight", "offset", "features"]) - - glr = GeneralizedLinearRegression(family="poisson", weightCol="weight", offsetCol="offset") - model = glr.fit(df) - self.assertTrue(np.allclose(model.coefficients.toArray(), [0.664647, -0.3192581], - atol=1E-4)) - self.assertTrue(np.isclose(model.intercept, -1.561613, atol=1E-4)) - - -class LinearRegressionTest(SparkSessionTestCase): - - def test_linear_regression_with_huber_loss(self): - - data_path = "data/mllib/sample_linear_regression_data.txt" - df = self.spark.read.format("libsvm").load(data_path) - - lir = LinearRegression(loss="huber", epsilon=2.0) - model = lir.fit(df) - - expectedCoefficients = [0.136, 0.7648, -0.7761, 2.4236, 0.537, - 1.2612, -0.333, -0.5694, -0.6311, 0.6053] - expectedIntercept = 0.1607 - expectedScale = 9.758 - - self.assertTrue( - np.allclose(model.coefficients.toArray(), expectedCoefficients, atol=1E-3)) - self.assertTrue(np.isclose(model.intercept, expectedIntercept, atol=1E-3)) - self.assertTrue(np.isclose(model.scale, expectedScale, atol=1E-3)) - - -class LogisticRegressionTest(SparkSessionTestCase): - - def test_binomial_logistic_regression_with_bound(self): - - df = self.spark.createDataFrame( - [(1.0, 1.0, Vectors.dense(0.0, 5.0)), - (0.0, 2.0, Vectors.dense(1.0, 2.0)), - (1.0, 3.0, Vectors.dense(2.0, 1.0)), - (0.0, 4.0, Vectors.dense(3.0, 3.0)), ], ["label", "weight", "features"]) - - lor = LogisticRegression(regParam=0.01, weightCol="weight", - lowerBoundsOnCoefficients=Matrices.dense(1, 2, [-1.0, -1.0]), - upperBoundsOnIntercepts=Vectors.dense(0.0)) - model = lor.fit(df) - self.assertTrue( - np.allclose(model.coefficients.toArray(), [-0.2944, -0.0484], atol=1E-4)) - self.assertTrue(np.isclose(model.intercept, 0.0, atol=1E-4)) - - def test_multinomial_logistic_regression_with_bound(self): - - data_path = "data/mllib/sample_multiclass_classification_data.txt" - df = self.spark.read.format("libsvm").load(data_path) - - lor = LogisticRegression(regParam=0.01, - lowerBoundsOnCoefficients=Matrices.dense(3, 4, range(12)), - upperBoundsOnIntercepts=Vectors.dense(0.0, 0.0, 0.0)) - model = lor.fit(df) - expected = [[4.593, 4.5516, 9.0099, 12.2904], - [1.0, 8.1093, 7.0, 10.0], - [3.041, 5.0, 8.0, 11.0]] - for i in range(0, len(expected)): - self.assertTrue( - np.allclose(model.coefficientMatrix.toArray()[i], expected[i], atol=1E-4)) - self.assertTrue( - np.allclose(model.interceptVector.toArray(), [-0.9057, -1.1392, -0.0033], atol=1E-4)) - - -class MultilayerPerceptronClassifierTest(SparkSessionTestCase): - - def test_raw_and_probability_prediction(self): - - data_path = "data/mllib/sample_multiclass_classification_data.txt" - df = self.spark.read.format("libsvm").load(data_path) - - mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[4, 5, 4, 3], - blockSize=128, seed=123) - model = mlp.fit(df) - test = self.sc.parallelize([Row(features=Vectors.dense(0.1, 0.1, 0.25, 0.25))]).toDF() - result = model.transform(test).head() - expected_prediction = 2.0 - expected_probability = [0.0, 0.0, 1.0] - expected_rawPrediction = [57.3955, -124.5462, 67.9943] - self.assertTrue(result.prediction, expected_prediction) - self.assertTrue(np.allclose(result.probability, expected_probability, atol=1E-4)) - self.assertTrue(np.allclose(result.rawPrediction, expected_rawPrediction, atol=1E-4)) - - -class FPGrowthTests(SparkSessionTestCase): - def setUp(self): - super(FPGrowthTests, self).setUp() - self.data = self.spark.createDataFrame( - [([1, 2], ), ([1, 2], ), ([1, 2, 3], ), ([1, 3], )], - ["items"]) - - def test_association_rules(self): - fp = FPGrowth() - fpm = fp.fit(self.data) - - expected_association_rules = self.spark.createDataFrame( - [([3], [1], 1.0, 1.0), ([2], [1], 1.0, 1.0)], - ["antecedent", "consequent", "confidence", "lift"] - ) - actual_association_rules = fpm.associationRules - - self.assertEqual(actual_association_rules.subtract(expected_association_rules).count(), 0) - self.assertEqual(expected_association_rules.subtract(actual_association_rules).count(), 0) - - def test_freq_itemsets(self): - fp = FPGrowth() - fpm = fp.fit(self.data) - - expected_freq_itemsets = self.spark.createDataFrame( - [([1], 4), ([2], 3), ([2, 1], 3), ([3], 2), ([3, 1], 2)], - ["items", "freq"] - ) - actual_freq_itemsets = fpm.freqItemsets - - self.assertEqual(actual_freq_itemsets.subtract(expected_freq_itemsets).count(), 0) - self.assertEqual(expected_freq_itemsets.subtract(actual_freq_itemsets).count(), 0) - - def tearDown(self): - del self.data - - -class ImageReaderTest(SparkSessionTestCase): - - def test_read_images(self): - data_path = 'data/mllib/images/origin/kittens' - df = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) - self.assertEqual(df.count(), 4) - first_row = df.take(1)[0][0] - array = ImageSchema.toNDArray(first_row) - self.assertEqual(len(array), first_row[1]) - self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row) - self.assertEqual(df.schema, ImageSchema.imageSchema) - self.assertEqual(df.schema["image"].dataType, ImageSchema.columnSchema) - expected = {'CV_8UC3': 16, 'Undefined': -1, 'CV_8U': 0, 'CV_8UC1': 0, 'CV_8UC4': 24} - self.assertEqual(ImageSchema.ocvTypes, expected) - expected = ['origin', 'height', 'width', 'nChannels', 'mode', 'data'] - self.assertEqual(ImageSchema.imageFields, expected) - self.assertEqual(ImageSchema.undefinedImageType, "Undefined") - - with QuietTest(self.sc): - self.assertRaisesRegexp( - TypeError, - "image argument should be pyspark.sql.types.Row; however", - lambda: ImageSchema.toNDArray("a")) - - with QuietTest(self.sc): - self.assertRaisesRegexp( - ValueError, - "image argument should have attributes specified in", - lambda: ImageSchema.toNDArray(Row(a=1))) - - with QuietTest(self.sc): - self.assertRaisesRegexp( - TypeError, - "array argument should be numpy.ndarray; however, it got", - lambda: ImageSchema.toImage("a")) - - -class ImageReaderTest2(PySparkTestCase): - - @classmethod - def setUpClass(cls): - super(ImageReaderTest2, cls).setUpClass() - cls.hive_available = True - # Note that here we enable Hive's support. - cls.spark = None - try: - cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() - except py4j.protocol.Py4JError: - cls.tearDownClass() - cls.hive_available = False - except TypeError: - cls.tearDownClass() - cls.hive_available = False - if cls.hive_available: - cls.spark = HiveContext._createForTesting(cls.sc) - - def setUp(self): - if not self.hive_available: - self.skipTest("Hive is not available.") - - @classmethod - def tearDownClass(cls): - super(ImageReaderTest2, cls).tearDownClass() - if cls.spark is not None: - cls.spark.sparkSession.stop() - cls.spark = None - - def test_read_images_multiple_times(self): - # This test case is to check if `ImageSchema.readImages` tries to - # initiate Hive client multiple times. See SPARK-22651. - data_path = 'data/mllib/images/origin/kittens' - ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) - ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) - - -class ALSTest(SparkSessionTestCase): - - def test_storage_levels(self): - df = self.spark.createDataFrame( - [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)], - ["user", "item", "rating"]) - als = ALS().setMaxIter(1).setRank(1) - # test default params - als.fit(df) - self.assertEqual(als.getIntermediateStorageLevel(), "MEMORY_AND_DISK") - self.assertEqual(als._java_obj.getIntermediateStorageLevel(), "MEMORY_AND_DISK") - self.assertEqual(als.getFinalStorageLevel(), "MEMORY_AND_DISK") - self.assertEqual(als._java_obj.getFinalStorageLevel(), "MEMORY_AND_DISK") - # test non-default params - als.setIntermediateStorageLevel("MEMORY_ONLY_2") - als.setFinalStorageLevel("DISK_ONLY") - als.fit(df) - self.assertEqual(als.getIntermediateStorageLevel(), "MEMORY_ONLY_2") - self.assertEqual(als._java_obj.getIntermediateStorageLevel(), "MEMORY_ONLY_2") - self.assertEqual(als.getFinalStorageLevel(), "DISK_ONLY") - self.assertEqual(als._java_obj.getFinalStorageLevel(), "DISK_ONLY") - - -class DefaultValuesTests(PySparkTestCase): - """ - Test :py:class:`JavaParams` classes to see if their default Param values match - those in their Scala counterparts. - """ - - def test_java_params(self): - import pyspark.ml.feature - import pyspark.ml.classification - import pyspark.ml.clustering - import pyspark.ml.evaluation - import pyspark.ml.pipeline - import pyspark.ml.recommendation - import pyspark.ml.regression - - modules = [pyspark.ml.feature, pyspark.ml.classification, pyspark.ml.clustering, - pyspark.ml.evaluation, pyspark.ml.pipeline, pyspark.ml.recommendation, - pyspark.ml.regression] - for module in modules: - for name, cls in inspect.getmembers(module, inspect.isclass): - if not name.endswith('Model') and not name.endswith('Params')\ - and issubclass(cls, JavaParams) and not inspect.isabstract(cls): - # NOTE: disable check_params_exist until there is parity with Scala API - ParamTests.check_params(self, cls(), check_params_exist=False) - - # Additional classes that need explicit construction - from pyspark.ml.feature import CountVectorizerModel, StringIndexerModel - ParamTests.check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'), - check_params_exist=False) - ParamTests.check_params(self, StringIndexerModel.from_labels(['a', 'b'], 'input'), - check_params_exist=False) - - -def _squared_distance(a, b): - if isinstance(a, Vector): - return a.squared_distance(b) - else: - return b.squared_distance(a) - - -class VectorTests(MLlibTestCase): - - def _test_serialize(self, v): - self.assertEqual(v, ser.loads(ser.dumps(v))) - jvec = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(v))) - nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvec))) - self.assertEqual(v, nv) - vs = [v] * 100 - jvecs = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(vs))) - nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvecs))) - self.assertEqual(vs, nvs) - - def test_serialize(self): - self._test_serialize(DenseVector(range(10))) - self._test_serialize(DenseVector(array([1., 2., 3., 4.]))) - self._test_serialize(DenseVector(pyarray.array('d', range(10)))) - self._test_serialize(SparseVector(4, {1: 1, 3: 2})) - self._test_serialize(SparseVector(3, {})) - self._test_serialize(DenseMatrix(2, 3, range(6))) - sm1 = SparseMatrix( - 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) - self._test_serialize(sm1) - - def test_dot(self): - sv = SparseVector(4, {1: 1, 3: 2}) - dv = DenseVector(array([1., 2., 3., 4.])) - lst = DenseVector([1, 2, 3, 4]) - mat = array([[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]) - arr = pyarray.array('d', [0, 1, 2, 3]) - self.assertEqual(10.0, sv.dot(dv)) - self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) - self.assertEqual(30.0, dv.dot(dv)) - self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) - self.assertEqual(30.0, lst.dot(dv)) - self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) - self.assertEqual(7.0, sv.dot(arr)) - - def test_squared_distance(self): - sv = SparseVector(4, {1: 1, 3: 2}) - dv = DenseVector(array([1., 2., 3., 4.])) - lst = DenseVector([4, 3, 2, 1]) - lst1 = [4, 3, 2, 1] - arr = pyarray.array('d', [0, 2, 1, 3]) - narr = array([0, 2, 1, 3]) - self.assertEqual(15.0, _squared_distance(sv, dv)) - self.assertEqual(25.0, _squared_distance(sv, lst)) - self.assertEqual(20.0, _squared_distance(dv, lst)) - self.assertEqual(15.0, _squared_distance(dv, sv)) - self.assertEqual(25.0, _squared_distance(lst, sv)) - self.assertEqual(20.0, _squared_distance(lst, dv)) - self.assertEqual(0.0, _squared_distance(sv, sv)) - self.assertEqual(0.0, _squared_distance(dv, dv)) - self.assertEqual(0.0, _squared_distance(lst, lst)) - self.assertEqual(25.0, _squared_distance(sv, lst1)) - self.assertEqual(3.0, _squared_distance(sv, arr)) - self.assertEqual(3.0, _squared_distance(sv, narr)) - - def test_hash(self): - v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEqual(hash(v1), hash(v2)) - self.assertEqual(hash(v1), hash(v3)) - self.assertEqual(hash(v2), hash(v3)) - self.assertFalse(hash(v1) == hash(v4)) - self.assertFalse(hash(v2) == hash(v4)) - - def test_eq(self): - v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) - v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) - v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEqual(v1, v2) - self.assertEqual(v1, v3) - self.assertFalse(v2 == v4) - self.assertFalse(v1 == v5) - self.assertFalse(v1 == v6) - - def test_equals(self): - indices = [1, 2, 4] - values = [1., 3., 2.] - self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.])) - - def test_conversion(self): - # numpy arrays should be automatically upcast to float64 - # tests for fix of [SPARK-5089] - v = array([1, 2, 3, 4], dtype='float64') - dv = DenseVector(v) - self.assertTrue(dv.array.dtype == 'float64') - v = array([1, 2, 3, 4], dtype='float32') - dv = DenseVector(v) - self.assertTrue(dv.array.dtype == 'float64') - - def test_sparse_vector_indexing(self): - sv = SparseVector(5, {1: 1, 3: 2}) - self.assertEqual(sv[0], 0.) - self.assertEqual(sv[3], 2.) - self.assertEqual(sv[1], 1.) - self.assertEqual(sv[2], 0.) - self.assertEqual(sv[4], 0.) - self.assertEqual(sv[-1], 0.) - self.assertEqual(sv[-2], 2.) - self.assertEqual(sv[-3], 0.) - self.assertEqual(sv[-5], 0.) - for ind in [5, -6]: - self.assertRaises(IndexError, sv.__getitem__, ind) - for ind in [7.8, '1']: - self.assertRaises(TypeError, sv.__getitem__, ind) - - zeros = SparseVector(4, {}) - self.assertEqual(zeros[0], 0.0) - self.assertEqual(zeros[3], 0.0) - for ind in [4, -5]: - self.assertRaises(IndexError, zeros.__getitem__, ind) - - empty = SparseVector(0, {}) - for ind in [-1, 0, 1]: - self.assertRaises(IndexError, empty.__getitem__, ind) - - def test_sparse_vector_iteration(self): - self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0]) - self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0]) - - def test_matrix_indexing(self): - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) - expected = [[0, 6], [1, 8], [4, 10]] - for i in range(3): - for j in range(2): - self.assertEqual(mat[i, j], expected[i][j]) - - for i, j in [(-1, 0), (4, 1), (3, 4)]: - self.assertRaises(IndexError, mat.__getitem__, (i, j)) - - def test_repr_dense_matrix(self): - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) - self.assertTrue( - repr(mat), - 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') - - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True) - self.assertTrue( - repr(mat), - 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') - - mat = DenseMatrix(6, 3, zeros(18)) - self.assertTrue( - repr(mat), - 'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)') - - def test_repr_sparse_matrix(self): - sm1t = SparseMatrix( - 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], - isTransposed=True) - self.assertTrue( - repr(sm1t), - 'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)') - - indices = tile(arange(6), 3) - values = ones(18) - sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values) - self.assertTrue( - repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \ - [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \ - [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \ - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)") - - self.assertTrue( - str(sm), - "6 X 3 CSCMatrix\n\ - (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\ - (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\ - (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..") - - sm = SparseMatrix(1, 18, zeros(19), [], []) - self.assertTrue( - repr(sm), - 'SparseMatrix(1, 18, \ - [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)') - - def test_sparse_matrix(self): - # Test sparse matrix creation. - sm1 = SparseMatrix( - 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) - self.assertEqual(sm1.numRows, 3) - self.assertEqual(sm1.numCols, 4) - self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) - self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2]) - self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) - self.assertTrue( - repr(sm1), - 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)') - - # Test indexing - expected = [ - [0, 0, 0, 0], - [1, 0, 4, 0], - [2, 0, 5, 0]] - - for i in range(3): - for j in range(4): - self.assertEqual(expected[i][j], sm1[i, j]) - self.assertTrue(array_equal(sm1.toArray(), expected)) - - for i, j in [(-1, 1), (4, 3), (3, 5)]: - self.assertRaises(IndexError, sm1.__getitem__, (i, j)) - - # Test conversion to dense and sparse. - smnew = sm1.toDense().toSparse() - self.assertEqual(sm1.numRows, smnew.numRows) - self.assertEqual(sm1.numCols, smnew.numCols) - self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) - self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) - self.assertTrue(array_equal(sm1.values, smnew.values)) - - sm1t = SparseMatrix( - 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], - isTransposed=True) - self.assertEqual(sm1t.numRows, 3) - self.assertEqual(sm1t.numCols, 4) - self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) - self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) - self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) - - expected = [ - [3, 2, 0, 0], - [0, 0, 4, 0], - [9, 0, 8, 0]] - - for i in range(3): - for j in range(4): - self.assertEqual(expected[i][j], sm1t[i, j]) - self.assertTrue(array_equal(sm1t.toArray(), expected)) - - def test_dense_matrix_is_transposed(self): - mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) - mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) - self.assertEqual(mat1, mat) - - expected = [[0, 4], [1, 6], [3, 9]] - for i in range(3): - for j in range(2): - self.assertEqual(mat1[i, j], expected[i][j]) - self.assertTrue(array_equal(mat1.toArray(), expected)) - - sm = mat1.toSparse() - self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2])) - self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5])) - self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) - - def test_norms(self): - a = DenseVector([0, 2, 3, -1]) - self.assertAlmostEqual(a.norm(2), 3.742, 3) - self.assertTrue(a.norm(1), 6) - self.assertTrue(a.norm(inf), 3) - a = SparseVector(4, [0, 2], [3, -4]) - self.assertAlmostEqual(a.norm(2), 5) - self.assertTrue(a.norm(1), 7) - self.assertTrue(a.norm(inf), 4) - - tmp = SparseVector(4, [0, 2], [3, 0]) - self.assertEqual(tmp.numNonzeros(), 1) - - -class VectorUDTTests(MLlibTestCase): - - dv0 = DenseVector([]) - dv1 = DenseVector([1.0, 2.0]) - sv0 = SparseVector(2, [], []) - sv1 = SparseVector(2, [1], [2.0]) - udt = VectorUDT() - - def test_json_schema(self): - self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) - - def test_serialization(self): - for v in [self.dv0, self.dv1, self.sv0, self.sv1]: - self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) - - def test_infer_schema(self): - rdd = self.sc.parallelize([Row(label=1.0, features=self.dv1), - Row(label=0.0, features=self.sv1)]) - df = rdd.toDF() - schema = df.schema - field = [f for f in schema.fields if f.name == "features"][0] - self.assertEqual(field.dataType, self.udt) - vectors = df.rdd.map(lambda p: p.features).collect() - self.assertEqual(len(vectors), 2) - for v in vectors: - if isinstance(v, SparseVector): - self.assertEqual(v, self.sv1) - elif isinstance(v, DenseVector): - self.assertEqual(v, self.dv1) - else: - raise TypeError("expecting a vector but got %r of type %r" % (v, type(v))) - - -class MatrixUDTTests(MLlibTestCase): - - dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10]) - dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True) - sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0]) - sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True) - udt = MatrixUDT() - - def test_json_schema(self): - self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt) - - def test_serialization(self): - for m in [self.dm1, self.dm2, self.sm1, self.sm2]: - self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m))) - - def test_infer_schema(self): - rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)]) - df = rdd.toDF() - schema = df.schema - self.assertTrue(schema.fields[1].dataType, self.udt) - matrices = df.rdd.map(lambda x: x._2).collect() - self.assertEqual(len(matrices), 2) - for m in matrices: - if isinstance(m, DenseMatrix): - self.assertTrue(m, self.dm1) - elif isinstance(m, SparseMatrix): - self.assertTrue(m, self.sm1) - else: - raise ValueError("Expected a matrix but got type %r" % type(m)) - - -class WrapperTests(MLlibTestCase): - - def test_new_java_array(self): - # test array of strings - str_list = ["a", "b", "c"] - java_class = self.sc._gateway.jvm.java.lang.String - java_array = JavaWrapper._new_java_array(str_list, java_class) - self.assertEqual(_java2py(self.sc, java_array), str_list) - # test array of integers - int_list = [1, 2, 3] - java_class = self.sc._gateway.jvm.java.lang.Integer - java_array = JavaWrapper._new_java_array(int_list, java_class) - self.assertEqual(_java2py(self.sc, java_array), int_list) - # test array of floats - float_list = [0.1, 0.2, 0.3] - java_class = self.sc._gateway.jvm.java.lang.Double - java_array = JavaWrapper._new_java_array(float_list, java_class) - self.assertEqual(_java2py(self.sc, java_array), float_list) - # test array of bools - bool_list = [False, True, True] - java_class = self.sc._gateway.jvm.java.lang.Boolean - java_array = JavaWrapper._new_java_array(bool_list, java_class) - self.assertEqual(_java2py(self.sc, java_array), bool_list) - # test array of Java DenseVectors - v1 = DenseVector([0.0, 1.0]) - v2 = DenseVector([1.0, 0.0]) - vec_java_list = [_py2java(self.sc, v1), _py2java(self.sc, v2)] - java_class = self.sc._gateway.jvm.org.apache.spark.ml.linalg.DenseVector - java_array = JavaWrapper._new_java_array(vec_java_list, java_class) - self.assertEqual(_java2py(self.sc, java_array), [v1, v2]) - # test empty array - java_class = self.sc._gateway.jvm.java.lang.Integer - java_array = JavaWrapper._new_java_array([], java_class) - self.assertEqual(_java2py(self.sc, java_array), []) - - -class ChiSquareTestTests(SparkSessionTestCase): - - def test_chisquaretest(self): - data = [[0, Vectors.dense([0, 1, 2])], - [1, Vectors.dense([1, 1, 1])], - [2, Vectors.dense([2, 1, 0])]] - df = self.spark.createDataFrame(data, ['label', 'feat']) - res = ChiSquareTest.test(df, 'feat', 'label') - # This line is hitting the collect bug described in #17218, commented for now. - # pValues = res.select("degreesOfFreedom").collect()) - self.assertIsInstance(res, DataFrame) - fieldNames = set(field.name for field in res.schema.fields) - expectedFields = ["pValues", "degreesOfFreedom", "statistics"] - self.assertTrue(all(field in fieldNames for field in expectedFields)) - - -class UnaryTransformerTests(SparkSessionTestCase): - - def test_unary_transformer_validate_input_type(self): - shiftVal = 3 - transformer = MockUnaryTransformer(shiftVal=shiftVal)\ - .setInputCol("input").setOutputCol("output") - - # should not raise any errors - transformer.validateInputType(DoubleType()) - - with self.assertRaises(TypeError): - # passing the wrong input type should raise an error - transformer.validateInputType(IntegerType()) - - def test_unary_transformer_transform(self): - shiftVal = 3 - transformer = MockUnaryTransformer(shiftVal=shiftVal)\ - .setInputCol("input").setOutputCol("output") - - df = self.spark.range(0, 10).toDF('input') - df = df.withColumn("input", df.input.cast(dataType="double")) - - transformed_df = transformer.transform(df) - results = transformed_df.select("input", "output").collect() - - for res in results: - self.assertEqual(res.input + shiftVal, res.output) - - -class EstimatorTest(unittest.TestCase): - - def testDefaultFitMultiple(self): - N = 4 - data = MockDataset() - estimator = MockEstimator() - params = [{estimator.fake: i} for i in range(N)] - modelIter = estimator.fitMultiple(data, params) - indexList = [] - for index, model in modelIter: - self.assertEqual(model.getFake(), index) - indexList.append(index) - self.assertEqual(sorted(indexList), list(range(N))) - - -if __name__ == "__main__": - from pyspark.ml.tests import * - if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) - else: - unittest.main(verbosity=2) diff --git a/python/pyspark/ml/tests/__init__.py b/python/pyspark/ml/tests/__init__.py new file mode 100644 index 0000000000000..cce3acad34a49 --- /dev/null +++ b/python/pyspark/ml/tests/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/ml/tests/test_algorithms.py b/python/pyspark/ml/tests/test_algorithms.py new file mode 100644 index 0000000000000..1a72e124962c8 --- /dev/null +++ b/python/pyspark/ml/tests/test_algorithms.py @@ -0,0 +1,349 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from shutil import rmtree +import sys +import tempfile + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +import numpy as np + +from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, \ + MultilayerPerceptronClassifier, OneVsRest +from pyspark.ml.clustering import DistributedLDAModel, KMeans, LocalLDAModel, LDA, LDAModel +from pyspark.ml.fpm import FPGrowth +from pyspark.ml.linalg import Matrices, Vectors +from pyspark.ml.recommendation import ALS +from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression +from pyspark.sql import Row +from pyspark.testing.mlutils import SparkSessionTestCase + + +class LogisticRegressionTest(SparkSessionTestCase): + + def test_binomial_logistic_regression_with_bound(self): + + df = self.spark.createDataFrame( + [(1.0, 1.0, Vectors.dense(0.0, 5.0)), + (0.0, 2.0, Vectors.dense(1.0, 2.0)), + (1.0, 3.0, Vectors.dense(2.0, 1.0)), + (0.0, 4.0, Vectors.dense(3.0, 3.0)), ], ["label", "weight", "features"]) + + lor = LogisticRegression(regParam=0.01, weightCol="weight", + lowerBoundsOnCoefficients=Matrices.dense(1, 2, [-1.0, -1.0]), + upperBoundsOnIntercepts=Vectors.dense(0.0)) + model = lor.fit(df) + self.assertTrue( + np.allclose(model.coefficients.toArray(), [-0.2944, -0.0484], atol=1E-4)) + self.assertTrue(np.isclose(model.intercept, 0.0, atol=1E-4)) + + def test_multinomial_logistic_regression_with_bound(self): + + data_path = "data/mllib/sample_multiclass_classification_data.txt" + df = self.spark.read.format("libsvm").load(data_path) + + lor = LogisticRegression(regParam=0.01, + lowerBoundsOnCoefficients=Matrices.dense(3, 4, range(12)), + upperBoundsOnIntercepts=Vectors.dense(0.0, 0.0, 0.0)) + model = lor.fit(df) + expected = [[4.593, 4.5516, 9.0099, 12.2904], + [1.0, 8.1093, 7.0, 10.0], + [3.041, 5.0, 8.0, 11.0]] + for i in range(0, len(expected)): + self.assertTrue( + np.allclose(model.coefficientMatrix.toArray()[i], expected[i], atol=1E-4)) + self.assertTrue( + np.allclose(model.interceptVector.toArray(), [-0.9057, -1.1392, -0.0033], atol=1E-4)) + + +class MultilayerPerceptronClassifierTest(SparkSessionTestCase): + + def test_raw_and_probability_prediction(self): + + data_path = "data/mllib/sample_multiclass_classification_data.txt" + df = self.spark.read.format("libsvm").load(data_path) + + mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[4, 5, 4, 3], + blockSize=128, seed=123) + model = mlp.fit(df) + test = self.sc.parallelize([Row(features=Vectors.dense(0.1, 0.1, 0.25, 0.25))]).toDF() + result = model.transform(test).head() + expected_prediction = 2.0 + expected_probability = [0.0, 0.0, 1.0] + expected_rawPrediction = [57.3955, -124.5462, 67.9943] + self.assertTrue(result.prediction, expected_prediction) + self.assertTrue(np.allclose(result.probability, expected_probability, atol=1E-4)) + self.assertTrue(np.allclose(result.rawPrediction, expected_rawPrediction, atol=1E-4)) + + +class OneVsRestTests(SparkSessionTestCase): + + def test_copy(self): + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), + (1.0, Vectors.sparse(2, [], [])), + (2.0, Vectors.dense(0.5, 0.5))], + ["label", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01) + ovr = OneVsRest(classifier=lr) + ovr1 = ovr.copy({lr.maxIter: 10}) + self.assertEqual(ovr.getClassifier().getMaxIter(), 5) + self.assertEqual(ovr1.getClassifier().getMaxIter(), 10) + model = ovr.fit(df) + model1 = model.copy({model.predictionCol: "indexed"}) + self.assertEqual(model1.getPredictionCol(), "indexed") + + def test_output_columns(self): + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), + (1.0, Vectors.sparse(2, [], [])), + (2.0, Vectors.dense(0.5, 0.5))], + ["label", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01) + ovr = OneVsRest(classifier=lr, parallelism=1) + model = ovr.fit(df) + output = model.transform(df) + self.assertEqual(output.columns, ["label", "features", "prediction"]) + + def test_parallelism_doesnt_change_output(self): + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), + (1.0, Vectors.sparse(2, [], [])), + (2.0, Vectors.dense(0.5, 0.5))], + ["label", "features"]) + ovrPar1 = OneVsRest(classifier=LogisticRegression(maxIter=5, regParam=.01), parallelism=1) + modelPar1 = ovrPar1.fit(df) + ovrPar2 = OneVsRest(classifier=LogisticRegression(maxIter=5, regParam=.01), parallelism=2) + modelPar2 = ovrPar2.fit(df) + for i, model in enumerate(modelPar1.models): + self.assertTrue(np.allclose(model.coefficients.toArray(), + modelPar2.models[i].coefficients.toArray(), atol=1E-4)) + self.assertTrue(np.allclose(model.intercept, modelPar2.models[i].intercept, atol=1E-4)) + + def test_support_for_weightCol(self): + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0), + (1.0, Vectors.sparse(2, [], []), 1.0), + (2.0, Vectors.dense(0.5, 0.5), 1.0)], + ["label", "features", "weight"]) + # classifier inherits hasWeightCol + lr = LogisticRegression(maxIter=5, regParam=0.01) + ovr = OneVsRest(classifier=lr, weightCol="weight") + self.assertIsNotNone(ovr.fit(df)) + # classifier doesn't inherit hasWeightCol + dt = DecisionTreeClassifier() + ovr2 = OneVsRest(classifier=dt, weightCol="weight") + self.assertIsNotNone(ovr2.fit(df)) + + +class KMeansTests(SparkSessionTestCase): + + def test_kmeans_cosine_distance(self): + data = [(Vectors.dense([1.0, 1.0]),), (Vectors.dense([10.0, 10.0]),), + (Vectors.dense([1.0, 0.5]),), (Vectors.dense([10.0, 4.4]),), + (Vectors.dense([-1.0, 1.0]),), (Vectors.dense([-100.0, 90.0]),)] + df = self.spark.createDataFrame(data, ["features"]) + kmeans = KMeans(k=3, seed=1, distanceMeasure="cosine") + model = kmeans.fit(df) + result = model.transform(df).collect() + self.assertTrue(result[0].prediction == result[1].prediction) + self.assertTrue(result[2].prediction == result[3].prediction) + self.assertTrue(result[4].prediction == result[5].prediction) + + +class LDATest(SparkSessionTestCase): + + def _compare(self, m1, m2): + """ + Temp method for comparing instances. + TODO: Replace with generic implementation once SPARK-14706 is merged. + """ + self.assertEqual(m1.uid, m2.uid) + self.assertEqual(type(m1), type(m2)) + self.assertEqual(len(m1.params), len(m2.params)) + for p in m1.params: + if m1.isDefined(p): + self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p)) + self.assertEqual(p.parent, m2.getParam(p.name).parent) + if isinstance(m1, LDAModel): + self.assertEqual(m1.vocabSize(), m2.vocabSize()) + self.assertEqual(m1.topicsMatrix(), m2.topicsMatrix()) + + def test_persistence(self): + # Test save/load for LDA, LocalLDAModel, DistributedLDAModel. + df = self.spark.createDataFrame([ + [1, Vectors.dense([0.0, 1.0])], + [2, Vectors.sparse(2, {0: 1.0})], + ], ["id", "features"]) + # Fit model + lda = LDA(k=2, seed=1, optimizer="em") + distributedModel = lda.fit(df) + self.assertTrue(distributedModel.isDistributed()) + localModel = distributedModel.toLocal() + self.assertFalse(localModel.isDistributed()) + # Define paths + path = tempfile.mkdtemp() + lda_path = path + "/lda" + dist_model_path = path + "/distLDAModel" + local_model_path = path + "/localLDAModel" + # Test LDA + lda.save(lda_path) + lda2 = LDA.load(lda_path) + self._compare(lda, lda2) + # Test DistributedLDAModel + distributedModel.save(dist_model_path) + distributedModel2 = DistributedLDAModel.load(dist_model_path) + self._compare(distributedModel, distributedModel2) + # Test LocalLDAModel + localModel.save(local_model_path) + localModel2 = LocalLDAModel.load(local_model_path) + self._compare(localModel, localModel2) + # Clean up + try: + rmtree(path) + except OSError: + pass + + +class FPGrowthTests(SparkSessionTestCase): + def setUp(self): + super(FPGrowthTests, self).setUp() + self.data = self.spark.createDataFrame( + [([1, 2], ), ([1, 2], ), ([1, 2, 3], ), ([1, 3], )], + ["items"]) + + def test_association_rules(self): + fp = FPGrowth() + fpm = fp.fit(self.data) + + expected_association_rules = self.spark.createDataFrame( + [([3], [1], 1.0, 1.0), ([2], [1], 1.0, 1.0)], + ["antecedent", "consequent", "confidence", "lift"] + ) + actual_association_rules = fpm.associationRules + + self.assertEqual(actual_association_rules.subtract(expected_association_rules).count(), 0) + self.assertEqual(expected_association_rules.subtract(actual_association_rules).count(), 0) + + def test_freq_itemsets(self): + fp = FPGrowth() + fpm = fp.fit(self.data) + + expected_freq_itemsets = self.spark.createDataFrame( + [([1], 4), ([2], 3), ([2, 1], 3), ([3], 2), ([3, 1], 2)], + ["items", "freq"] + ) + actual_freq_itemsets = fpm.freqItemsets + + self.assertEqual(actual_freq_itemsets.subtract(expected_freq_itemsets).count(), 0) + self.assertEqual(expected_freq_itemsets.subtract(actual_freq_itemsets).count(), 0) + + def tearDown(self): + del self.data + + +class ALSTest(SparkSessionTestCase): + + def test_storage_levels(self): + df = self.spark.createDataFrame( + [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)], + ["user", "item", "rating"]) + als = ALS().setMaxIter(1).setRank(1) + # test default params + als.fit(df) + self.assertEqual(als.getIntermediateStorageLevel(), "MEMORY_AND_DISK") + self.assertEqual(als._java_obj.getIntermediateStorageLevel(), "MEMORY_AND_DISK") + self.assertEqual(als.getFinalStorageLevel(), "MEMORY_AND_DISK") + self.assertEqual(als._java_obj.getFinalStorageLevel(), "MEMORY_AND_DISK") + # test non-default params + als.setIntermediateStorageLevel("MEMORY_ONLY_2") + als.setFinalStorageLevel("DISK_ONLY") + als.fit(df) + self.assertEqual(als.getIntermediateStorageLevel(), "MEMORY_ONLY_2") + self.assertEqual(als._java_obj.getIntermediateStorageLevel(), "MEMORY_ONLY_2") + self.assertEqual(als.getFinalStorageLevel(), "DISK_ONLY") + self.assertEqual(als._java_obj.getFinalStorageLevel(), "DISK_ONLY") + + +class GeneralizedLinearRegressionTest(SparkSessionTestCase): + + def test_tweedie_distribution(self): + + df = self.spark.createDataFrame( + [(1.0, Vectors.dense(0.0, 0.0)), + (1.0, Vectors.dense(1.0, 2.0)), + (2.0, Vectors.dense(0.0, 0.0)), + (2.0, Vectors.dense(1.0, 1.0)), ], ["label", "features"]) + + glr = GeneralizedLinearRegression(family="tweedie", variancePower=1.6) + model = glr.fit(df) + self.assertTrue(np.allclose(model.coefficients.toArray(), [-0.4645, 0.3402], atol=1E-4)) + self.assertTrue(np.isclose(model.intercept, 0.7841, atol=1E-4)) + + model2 = glr.setLinkPower(-1.0).fit(df) + self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4)) + self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4)) + + def test_offset(self): + + df = self.spark.createDataFrame( + [(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)), + (0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)), + (0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0)), + (0.7, 0.7, 0.0, Vectors.dense(3.0, 3.0))], ["label", "weight", "offset", "features"]) + + glr = GeneralizedLinearRegression(family="poisson", weightCol="weight", offsetCol="offset") + model = glr.fit(df) + self.assertTrue(np.allclose(model.coefficients.toArray(), [0.664647, -0.3192581], + atol=1E-4)) + self.assertTrue(np.isclose(model.intercept, -1.561613, atol=1E-4)) + + +class LinearRegressionTest(SparkSessionTestCase): + + def test_linear_regression_with_huber_loss(self): + + data_path = "data/mllib/sample_linear_regression_data.txt" + df = self.spark.read.format("libsvm").load(data_path) + + lir = LinearRegression(loss="huber", epsilon=2.0) + model = lir.fit(df) + + expectedCoefficients = [0.136, 0.7648, -0.7761, 2.4236, 0.537, + 1.2612, -0.333, -0.5694, -0.6311, 0.6053] + expectedIntercept = 0.1607 + expectedScale = 9.758 + + self.assertTrue( + np.allclose(model.coefficients.toArray(), expectedCoefficients, atol=1E-3)) + self.assertTrue(np.isclose(model.intercept, expectedIntercept, atol=1E-3)) + self.assertTrue(np.isclose(model.scale, expectedScale, atol=1E-3)) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_algorithms import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_base.py b/python/pyspark/ml/tests/test_base.py new file mode 100644 index 0000000000000..59c45f638dd45 --- /dev/null +++ b/python/pyspark/ml/tests/test_base.py @@ -0,0 +1,85 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.sql.types import DoubleType, IntegerType +from pyspark.testing.mlutils import MockDataset, MockEstimator, MockUnaryTransformer, \ + SparkSessionTestCase + + +class UnaryTransformerTests(SparkSessionTestCase): + + def test_unary_transformer_validate_input_type(self): + shiftVal = 3 + transformer = MockUnaryTransformer(shiftVal=shiftVal) \ + .setInputCol("input").setOutputCol("output") + + # should not raise any errors + transformer.validateInputType(DoubleType()) + + with self.assertRaises(TypeError): + # passing the wrong input type should raise an error + transformer.validateInputType(IntegerType()) + + def test_unary_transformer_transform(self): + shiftVal = 3 + transformer = MockUnaryTransformer(shiftVal=shiftVal) \ + .setInputCol("input").setOutputCol("output") + + df = self.spark.range(0, 10).toDF('input') + df = df.withColumn("input", df.input.cast(dataType="double")) + + transformed_df = transformer.transform(df) + results = transformed_df.select("input", "output").collect() + + for res in results: + self.assertEqual(res.input + shiftVal, res.output) + + +class EstimatorTest(unittest.TestCase): + + def testDefaultFitMultiple(self): + N = 4 + data = MockDataset() + estimator = MockEstimator() + params = [{estimator.fake: i} for i in range(N)] + modelIter = estimator.fitMultiple(data, params) + indexList = [] + for index, model in modelIter: + self.assertEqual(model.getFake(), index) + indexList.append(index) + self.assertEqual(sorted(indexList), list(range(N))) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_base import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_evaluation.py b/python/pyspark/ml/tests/test_evaluation.py new file mode 100644 index 0000000000000..6c3e5c6734509 --- /dev/null +++ b/python/pyspark/ml/tests/test_evaluation.py @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +import numpy as np + +from pyspark.ml.evaluation import ClusteringEvaluator, RegressionEvaluator +from pyspark.ml.linalg import Vectors +from pyspark.sql import Row +from pyspark.testing.mlutils import SparkSessionTestCase + + +class EvaluatorTests(SparkSessionTestCase): + + def test_java_params(self): + """ + This tests a bug fixed by SPARK-18274 which causes multiple copies + of a Params instance in Python to be linked to the same Java instance. + """ + evaluator = RegressionEvaluator(metricName="r2") + df = self.spark.createDataFrame([Row(label=1.0, prediction=1.1)]) + evaluator.evaluate(df) + self.assertEqual(evaluator._java_obj.getMetricName(), "r2") + evaluatorCopy = evaluator.copy({evaluator.metricName: "mae"}) + evaluator.evaluate(df) + evaluatorCopy.evaluate(df) + self.assertEqual(evaluator._java_obj.getMetricName(), "r2") + self.assertEqual(evaluatorCopy._java_obj.getMetricName(), "mae") + + def test_clustering_evaluator_with_cosine_distance(self): + featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]), + [([1.0, 1.0], 1.0), ([10.0, 10.0], 1.0), ([1.0, 0.5], 2.0), + ([10.0, 4.4], 2.0), ([-1.0, 1.0], 3.0), ([-100.0, 90.0], 3.0)]) + dataset = self.spark.createDataFrame(featureAndPredictions, ["features", "prediction"]) + evaluator = ClusteringEvaluator(predictionCol="prediction", distanceMeasure="cosine") + self.assertEqual(evaluator.getDistanceMeasure(), "cosine") + self.assertTrue(np.isclose(evaluator.evaluate(dataset), 0.992671213, atol=1e-5)) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_evaluation import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py new file mode 100644 index 0000000000000..23f66e73b4820 --- /dev/null +++ b/python/pyspark/ml/tests/test_feature.py @@ -0,0 +1,318 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +if sys.version > '3': + basestring = str + +from pyspark.ml.feature import Binarizer, CountVectorizer, CountVectorizerModel, HashingTF, IDF, \ + NGram, RFormula, StopWordsRemover, StringIndexer, StringIndexerModel, VectorSizeHint +from pyspark.ml.linalg import DenseVector, SparseVector, Vectors +from pyspark.sql import Row +from pyspark.testing.utils import QuietTest +from pyspark.testing.mlutils import check_params, SparkSessionTestCase + + +class FeatureTests(SparkSessionTestCase): + + def test_binarizer(self): + b0 = Binarizer() + self.assertListEqual(b0.params, [b0.inputCol, b0.outputCol, b0.threshold]) + self.assertTrue(all([~b0.isSet(p) for p in b0.params])) + self.assertTrue(b0.hasDefault(b0.threshold)) + self.assertEqual(b0.getThreshold(), 0.0) + b0.setParams(inputCol="input", outputCol="output").setThreshold(1.0) + self.assertTrue(all([b0.isSet(p) for p in b0.params])) + self.assertEqual(b0.getThreshold(), 1.0) + self.assertEqual(b0.getInputCol(), "input") + self.assertEqual(b0.getOutputCol(), "output") + + b0c = b0.copy({b0.threshold: 2.0}) + self.assertEqual(b0c.uid, b0.uid) + self.assertListEqual(b0c.params, b0.params) + self.assertEqual(b0c.getThreshold(), 2.0) + + b1 = Binarizer(threshold=2.0, inputCol="input", outputCol="output") + self.assertNotEqual(b1.uid, b0.uid) + self.assertEqual(b1.getThreshold(), 2.0) + self.assertEqual(b1.getInputCol(), "input") + self.assertEqual(b1.getOutputCol(), "output") + + def test_idf(self): + dataset = self.spark.createDataFrame([ + (DenseVector([1.0, 2.0]),), + (DenseVector([0.0, 1.0]),), + (DenseVector([3.0, 0.2]),)], ["tf"]) + idf0 = IDF(inputCol="tf") + self.assertListEqual(idf0.params, [idf0.inputCol, idf0.minDocFreq, idf0.outputCol]) + idf0m = idf0.fit(dataset, {idf0.outputCol: "idf"}) + self.assertEqual(idf0m.uid, idf0.uid, + "Model should inherit the UID from its parent estimator.") + output = idf0m.transform(dataset) + self.assertIsNotNone(output.head().idf) + # Test that parameters transferred to Python Model + check_params(self, idf0m) + + def test_ngram(self): + dataset = self.spark.createDataFrame([ + Row(input=["a", "b", "c", "d", "e"])]) + ngram0 = NGram(n=4, inputCol="input", outputCol="output") + self.assertEqual(ngram0.getN(), 4) + self.assertEqual(ngram0.getInputCol(), "input") + self.assertEqual(ngram0.getOutputCol(), "output") + transformedDF = ngram0.transform(dataset) + self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"]) + + def test_stopwordsremover(self): + dataset = self.spark.createDataFrame([Row(input=["a", "panda"])]) + stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") + # Default + self.assertEqual(stopWordRemover.getInputCol(), "input") + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, ["panda"]) + self.assertEqual(type(stopWordRemover.getStopWords()), list) + self.assertTrue(isinstance(stopWordRemover.getStopWords()[0], basestring)) + # Custom + stopwords = ["panda"] + stopWordRemover.setStopWords(stopwords) + self.assertEqual(stopWordRemover.getInputCol(), "input") + self.assertEqual(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, ["a"]) + # with language selection + stopwords = StopWordsRemover.loadDefaultStopWords("turkish") + dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])]) + stopWordRemover.setStopWords(stopwords) + self.assertEqual(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, []) + # with locale + stopwords = ["BELKİ"] + dataset = self.spark.createDataFrame([Row(input=["belki"])]) + stopWordRemover.setStopWords(stopwords).setLocale("tr") + self.assertEqual(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, []) + + def test_count_vectorizer_with_binary(self): + dataset = self.spark.createDataFrame([ + (0, "a a a b b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), + (1, "a a".split(' '), SparseVector(3, {0: 1.0}),), + (2, "a b".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), + (3, "c".split(' '), SparseVector(3, {2: 1.0}),)], ["id", "words", "expected"]) + cv = CountVectorizer(binary=True, inputCol="words", outputCol="features") + model = cv.fit(dataset) + + transformedList = model.transform(dataset).select("features", "expected").collect() + + for r in transformedList: + feature, expected = r + self.assertEqual(feature, expected) + + def test_count_vectorizer_with_maxDF(self): + dataset = self.spark.createDataFrame([ + (0, "a b c d".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), + (1, "a b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), + (2, "a b".split(' '), SparseVector(3, {0: 1.0}),), + (3, "a".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) + cv = CountVectorizer(inputCol="words", outputCol="features") + model1 = cv.setMaxDF(3).fit(dataset) + self.assertEqual(model1.vocabulary, ['b', 'c', 'd']) + + transformedList1 = model1.transform(dataset).select("features", "expected").collect() + + for r in transformedList1: + feature, expected = r + self.assertEqual(feature, expected) + + model2 = cv.setMaxDF(0.75).fit(dataset) + self.assertEqual(model2.vocabulary, ['b', 'c', 'd']) + + transformedList2 = model2.transform(dataset).select("features", "expected").collect() + + for r in transformedList2: + feature, expected = r + self.assertEqual(feature, expected) + + def test_count_vectorizer_from_vocab(self): + model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words", + outputCol="features", minTF=2) + self.assertEqual(model.vocabulary, ["a", "b", "c"]) + self.assertEqual(model.getMinTF(), 2) + + dataset = self.spark.createDataFrame([ + (0, "a a a b b c".split(' '), SparseVector(3, {0: 3.0, 1: 2.0}),), + (1, "a a".split(' '), SparseVector(3, {0: 2.0}),), + (2, "a b".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) + + transformed_list = model.transform(dataset).select("features", "expected").collect() + + for r in transformed_list: + feature, expected = r + self.assertEqual(feature, expected) + + # Test an empty vocabulary + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"): + CountVectorizerModel.from_vocabulary([], inputCol="words") + + # Test model with default settings can transform + model_default = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words") + transformed_list = model_default.transform(dataset) \ + .select(model_default.getOrDefault(model_default.outputCol)).collect() + self.assertEqual(len(transformed_list), 3) + + def test_rformula_force_index_label(self): + df = self.spark.createDataFrame([ + (1.0, 1.0, "a"), + (0.0, 2.0, "b"), + (1.0, 0.0, "a")], ["y", "x", "s"]) + # Does not index label by default since it's numeric type. + rf = RFormula(formula="y ~ x + s") + model = rf.fit(df) + transformedDF = model.transform(df) + self.assertEqual(transformedDF.head().label, 1.0) + # Force to index label. + rf2 = RFormula(formula="y ~ x + s").setForceIndexLabel(True) + model2 = rf2.fit(df) + transformedDF2 = model2.transform(df) + self.assertEqual(transformedDF2.head().label, 0.0) + + def test_rformula_string_indexer_order_type(self): + df = self.spark.createDataFrame([ + (1.0, 1.0, "a"), + (0.0, 2.0, "b"), + (1.0, 0.0, "a")], ["y", "x", "s"]) + rf = RFormula(formula="y ~ x + s", stringIndexerOrderType="alphabetDesc") + self.assertEqual(rf.getStringIndexerOrderType(), 'alphabetDesc') + transformedDF = rf.fit(df).transform(df) + observed = transformedDF.select("features").collect() + expected = [[1.0, 0.0], [2.0, 1.0], [0.0, 0.0]] + for i in range(0, len(expected)): + self.assertTrue(all(observed[i]["features"].toArray() == expected[i])) + + def test_string_indexer_handle_invalid(self): + df = self.spark.createDataFrame([ + (0, "a"), + (1, "d"), + (2, None)], ["id", "label"]) + + si1 = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep", + stringOrderType="alphabetAsc") + model1 = si1.fit(df) + td1 = model1.transform(df) + actual1 = td1.select("id", "indexed").collect() + expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0), Row(id=2, indexed=2.0)] + self.assertEqual(actual1, expected1) + + si2 = si1.setHandleInvalid("skip") + model2 = si2.fit(df) + td2 = model2.transform(df) + actual2 = td2.select("id", "indexed").collect() + expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)] + self.assertEqual(actual2, expected2) + + def test_string_indexer_from_labels(self): + model = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label", + outputCol="indexed", handleInvalid="keep") + self.assertEqual(model.labels, ["a", "b", "c"]) + + df1 = self.spark.createDataFrame([ + (0, "a"), + (1, "c"), + (2, None), + (3, "b"), + (4, "b")], ["id", "label"]) + + result1 = model.transform(df1) + actual1 = result1.select("id", "indexed").collect() + expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=2.0), Row(id=2, indexed=3.0), + Row(id=3, indexed=1.0), Row(id=4, indexed=1.0)] + self.assertEqual(actual1, expected1) + + model_empty_labels = StringIndexerModel.from_labels( + [], inputCol="label", outputCol="indexed", handleInvalid="keep") + actual2 = model_empty_labels.transform(df1).select("id", "indexed").collect() + expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=0.0), Row(id=2, indexed=0.0), + Row(id=3, indexed=0.0), Row(id=4, indexed=0.0)] + self.assertEqual(actual2, expected2) + + # Test model with default settings can transform + model_default = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label") + df2 = self.spark.createDataFrame([ + (0, "a"), + (1, "c"), + (2, "b"), + (3, "b"), + (4, "b")], ["id", "label"]) + transformed_list = model_default.transform(df2) \ + .select(model_default.getOrDefault(model_default.outputCol)).collect() + self.assertEqual(len(transformed_list), 5) + + def test_vector_size_hint(self): + df = self.spark.createDataFrame( + [(0, Vectors.dense([0.0, 10.0, 0.5])), + (1, Vectors.dense([1.0, 11.0, 0.5, 0.6])), + (2, Vectors.dense([2.0, 12.0]))], + ["id", "vector"]) + + sizeHint = VectorSizeHint( + inputCol="vector", + handleInvalid="skip") + sizeHint.setSize(3) + self.assertEqual(sizeHint.getSize(), 3) + + output = sizeHint.transform(df).head().vector + expected = DenseVector([0.0, 10.0, 0.5]) + self.assertEqual(output, expected) + + +class HashingTFTest(SparkSessionTestCase): + + def test_apply_binary_term_freqs(self): + + df = self.spark.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"]) + n = 10 + hashingTF = HashingTF() + hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n).setBinary(True) + output = hashingTF.transform(df) + features = output.select("features").first().features.toArray() + expected = Vectors.dense([1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).toArray() + for i in range(0, n): + self.assertAlmostEqual(features[i], expected[i], 14, "Error at " + str(i) + + ": expected " + str(expected[i]) + ", got " + str(features[i])) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_feature import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_image.py b/python/pyspark/ml/tests/test_image.py new file mode 100644 index 0000000000000..dcc7a32c9fd70 --- /dev/null +++ b/python/pyspark/ml/tests/test_image.py @@ -0,0 +1,118 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import sys +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +import py4j + +from pyspark.ml.image import ImageSchema +from pyspark.testing.mlutils import PySparkTestCase, SparkSessionTestCase +from pyspark.sql import HiveContext, Row +from pyspark.testing.utils import QuietTest + + +class ImageReaderTest(SparkSessionTestCase): + + def test_read_images(self): + data_path = 'data/mllib/images/origin/kittens' + df = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) + self.assertEqual(df.count(), 4) + first_row = df.take(1)[0][0] + array = ImageSchema.toNDArray(first_row) + self.assertEqual(len(array), first_row[1]) + self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row) + self.assertEqual(df.schema, ImageSchema.imageSchema) + self.assertEqual(df.schema["image"].dataType, ImageSchema.columnSchema) + expected = {'CV_8UC3': 16, 'Undefined': -1, 'CV_8U': 0, 'CV_8UC1': 0, 'CV_8UC4': 24} + self.assertEqual(ImageSchema.ocvTypes, expected) + expected = ['origin', 'height', 'width', 'nChannels', 'mode', 'data'] + self.assertEqual(ImageSchema.imageFields, expected) + self.assertEqual(ImageSchema.undefinedImageType, "Undefined") + + with QuietTest(self.sc): + self.assertRaisesRegexp( + TypeError, + "image argument should be pyspark.sql.types.Row; however", + lambda: ImageSchema.toNDArray("a")) + + with QuietTest(self.sc): + self.assertRaisesRegexp( + ValueError, + "image argument should have attributes specified in", + lambda: ImageSchema.toNDArray(Row(a=1))) + + with QuietTest(self.sc): + self.assertRaisesRegexp( + TypeError, + "array argument should be numpy.ndarray; however, it got", + lambda: ImageSchema.toImage("a")) + + +class ImageReaderTest2(PySparkTestCase): + + @classmethod + def setUpClass(cls): + super(ImageReaderTest2, cls).setUpClass() + cls.hive_available = True + # Note that here we enable Hive's support. + cls.spark = None + try: + cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + except py4j.protocol.Py4JError: + cls.tearDownClass() + cls.hive_available = False + except TypeError: + cls.tearDownClass() + cls.hive_available = False + if cls.hive_available: + cls.spark = HiveContext._createForTesting(cls.sc) + + def setUp(self): + if not self.hive_available: + self.skipTest("Hive is not available.") + + @classmethod + def tearDownClass(cls): + super(ImageReaderTest2, cls).tearDownClass() + if cls.spark is not None: + cls.spark.sparkSession.stop() + cls.spark = None + + def test_read_images_multiple_times(self): + # This test case is to check if `ImageSchema.readImages` tries to + # initiate Hive client multiple times. See SPARK-22651. + data_path = 'data/mllib/images/origin/kittens' + ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) + ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_image import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_linalg.py b/python/pyspark/ml/tests/test_linalg.py new file mode 100644 index 0000000000000..76e5386e86125 --- /dev/null +++ b/python/pyspark/ml/tests/test_linalg.py @@ -0,0 +1,392 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +import array as pyarray +from numpy import arange, array, array_equal, inf, ones, tile, zeros + +from pyspark.ml.linalg import DenseMatrix, DenseVector, MatrixUDT, SparseMatrix, SparseVector, \ + Vector, VectorUDT, Vectors +from pyspark.testing.mllibutils import make_serializer, MLlibTestCase +from pyspark.sql import Row + + +ser = make_serializer() + + +def _squared_distance(a, b): + if isinstance(a, Vector): + return a.squared_distance(b) + else: + return b.squared_distance(a) + + +class VectorTests(MLlibTestCase): + + def _test_serialize(self, v): + self.assertEqual(v, ser.loads(ser.dumps(v))) + jvec = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(v))) + nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvec))) + self.assertEqual(v, nv) + vs = [v] * 100 + jvecs = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(vs))) + nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvecs))) + self.assertEqual(vs, nvs) + + def test_serialize(self): + self._test_serialize(DenseVector(range(10))) + self._test_serialize(DenseVector(array([1., 2., 3., 4.]))) + self._test_serialize(DenseVector(pyarray.array('d', range(10)))) + self._test_serialize(SparseVector(4, {1: 1, 3: 2})) + self._test_serialize(SparseVector(3, {})) + self._test_serialize(DenseMatrix(2, 3, range(6))) + sm1 = SparseMatrix( + 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) + self._test_serialize(sm1) + + def test_dot(self): + sv = SparseVector(4, {1: 1, 3: 2}) + dv = DenseVector(array([1., 2., 3., 4.])) + lst = DenseVector([1, 2, 3, 4]) + mat = array([[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]) + arr = pyarray.array('d', [0, 1, 2, 3]) + self.assertEqual(10.0, sv.dot(dv)) + self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) + self.assertEqual(30.0, dv.dot(dv)) + self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) + self.assertEqual(30.0, lst.dot(dv)) + self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) + self.assertEqual(7.0, sv.dot(arr)) + + def test_squared_distance(self): + sv = SparseVector(4, {1: 1, 3: 2}) + dv = DenseVector(array([1., 2., 3., 4.])) + lst = DenseVector([4, 3, 2, 1]) + lst1 = [4, 3, 2, 1] + arr = pyarray.array('d', [0, 2, 1, 3]) + narr = array([0, 2, 1, 3]) + self.assertEqual(15.0, _squared_distance(sv, dv)) + self.assertEqual(25.0, _squared_distance(sv, lst)) + self.assertEqual(20.0, _squared_distance(dv, lst)) + self.assertEqual(15.0, _squared_distance(dv, sv)) + self.assertEqual(25.0, _squared_distance(lst, sv)) + self.assertEqual(20.0, _squared_distance(lst, dv)) + self.assertEqual(0.0, _squared_distance(sv, sv)) + self.assertEqual(0.0, _squared_distance(dv, dv)) + self.assertEqual(0.0, _squared_distance(lst, lst)) + self.assertEqual(25.0, _squared_distance(sv, lst1)) + self.assertEqual(3.0, _squared_distance(sv, arr)) + self.assertEqual(3.0, _squared_distance(sv, narr)) + + def test_hash(self): + v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + self.assertEqual(hash(v1), hash(v2)) + self.assertEqual(hash(v1), hash(v3)) + self.assertEqual(hash(v2), hash(v3)) + self.assertFalse(hash(v1) == hash(v4)) + self.assertFalse(hash(v2) == hash(v4)) + + def test_eq(self): + v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) + v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) + v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + self.assertEqual(v1, v2) + self.assertEqual(v1, v3) + self.assertFalse(v2 == v4) + self.assertFalse(v1 == v5) + self.assertFalse(v1 == v6) + + def test_equals(self): + indices = [1, 2, 4] + values = [1., 3., 2.] + self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.])) + + def test_conversion(self): + # numpy arrays should be automatically upcast to float64 + # tests for fix of [SPARK-5089] + v = array([1, 2, 3, 4], dtype='float64') + dv = DenseVector(v) + self.assertTrue(dv.array.dtype == 'float64') + v = array([1, 2, 3, 4], dtype='float32') + dv = DenseVector(v) + self.assertTrue(dv.array.dtype == 'float64') + + def test_sparse_vector_indexing(self): + sv = SparseVector(5, {1: 1, 3: 2}) + self.assertEqual(sv[0], 0.) + self.assertEqual(sv[3], 2.) + self.assertEqual(sv[1], 1.) + self.assertEqual(sv[2], 0.) + self.assertEqual(sv[4], 0.) + self.assertEqual(sv[-1], 0.) + self.assertEqual(sv[-2], 2.) + self.assertEqual(sv[-3], 0.) + self.assertEqual(sv[-5], 0.) + for ind in [5, -6]: + self.assertRaises(IndexError, sv.__getitem__, ind) + for ind in [7.8, '1']: + self.assertRaises(TypeError, sv.__getitem__, ind) + + zeros = SparseVector(4, {}) + self.assertEqual(zeros[0], 0.0) + self.assertEqual(zeros[3], 0.0) + for ind in [4, -5]: + self.assertRaises(IndexError, zeros.__getitem__, ind) + + empty = SparseVector(0, {}) + for ind in [-1, 0, 1]: + self.assertRaises(IndexError, empty.__getitem__, ind) + + def test_sparse_vector_iteration(self): + self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0]) + self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0]) + + def test_matrix_indexing(self): + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) + expected = [[0, 6], [1, 8], [4, 10]] + for i in range(3): + for j in range(2): + self.assertEqual(mat[i, j], expected[i][j]) + + for i, j in [(-1, 0), (4, 1), (3, 4)]: + self.assertRaises(IndexError, mat.__getitem__, (i, j)) + + def test_repr_dense_matrix(self): + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) + self.assertTrue( + repr(mat), + 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') + + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True) + self.assertTrue( + repr(mat), + 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') + + mat = DenseMatrix(6, 3, zeros(18)) + self.assertTrue( + repr(mat), + 'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)') + + def test_repr_sparse_matrix(self): + sm1t = SparseMatrix( + 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], + isTransposed=True) + self.assertTrue( + repr(sm1t), + 'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)') + + indices = tile(arange(6), 3) + values = ones(18) + sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values) + self.assertTrue( + repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \ + [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \ + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)") + + self.assertTrue( + str(sm), + "6 X 3 CSCMatrix\n\ + (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\ + (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\ + (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..") + + sm = SparseMatrix(1, 18, zeros(19), [], []) + self.assertTrue( + repr(sm), + 'SparseMatrix(1, 18, \ + [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)') + + def test_sparse_matrix(self): + # Test sparse matrix creation. + sm1 = SparseMatrix( + 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) + self.assertEqual(sm1.numRows, 3) + self.assertEqual(sm1.numCols, 4) + self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) + self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2]) + self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) + self.assertTrue( + repr(sm1), + 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)') + + # Test indexing + expected = [ + [0, 0, 0, 0], + [1, 0, 4, 0], + [2, 0, 5, 0]] + + for i in range(3): + for j in range(4): + self.assertEqual(expected[i][j], sm1[i, j]) + self.assertTrue(array_equal(sm1.toArray(), expected)) + + for i, j in [(-1, 1), (4, 3), (3, 5)]: + self.assertRaises(IndexError, sm1.__getitem__, (i, j)) + + # Test conversion to dense and sparse. + smnew = sm1.toDense().toSparse() + self.assertEqual(sm1.numRows, smnew.numRows) + self.assertEqual(sm1.numCols, smnew.numCols) + self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) + self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) + self.assertTrue(array_equal(sm1.values, smnew.values)) + + sm1t = SparseMatrix( + 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], + isTransposed=True) + self.assertEqual(sm1t.numRows, 3) + self.assertEqual(sm1t.numCols, 4) + self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) + self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) + self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) + + expected = [ + [3, 2, 0, 0], + [0, 0, 4, 0], + [9, 0, 8, 0]] + + for i in range(3): + for j in range(4): + self.assertEqual(expected[i][j], sm1t[i, j]) + self.assertTrue(array_equal(sm1t.toArray(), expected)) + + def test_dense_matrix_is_transposed(self): + mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) + mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) + self.assertEqual(mat1, mat) + + expected = [[0, 4], [1, 6], [3, 9]] + for i in range(3): + for j in range(2): + self.assertEqual(mat1[i, j], expected[i][j]) + self.assertTrue(array_equal(mat1.toArray(), expected)) + + sm = mat1.toSparse() + self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2])) + self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5])) + self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) + + def test_norms(self): + a = DenseVector([0, 2, 3, -1]) + self.assertAlmostEqual(a.norm(2), 3.742, 3) + self.assertTrue(a.norm(1), 6) + self.assertTrue(a.norm(inf), 3) + a = SparseVector(4, [0, 2], [3, -4]) + self.assertAlmostEqual(a.norm(2), 5) + self.assertTrue(a.norm(1), 7) + self.assertTrue(a.norm(inf), 4) + + tmp = SparseVector(4, [0, 2], [3, 0]) + self.assertEqual(tmp.numNonzeros(), 1) + + +class VectorUDTTests(MLlibTestCase): + + dv0 = DenseVector([]) + dv1 = DenseVector([1.0, 2.0]) + sv0 = SparseVector(2, [], []) + sv1 = SparseVector(2, [1], [2.0]) + udt = VectorUDT() + + def test_json_schema(self): + self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for v in [self.dv0, self.dv1, self.sv0, self.sv1]: + self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) + + def test_infer_schema(self): + rdd = self.sc.parallelize([Row(label=1.0, features=self.dv1), + Row(label=0.0, features=self.sv1)]) + df = rdd.toDF() + schema = df.schema + field = [f for f in schema.fields if f.name == "features"][0] + self.assertEqual(field.dataType, self.udt) + vectors = df.rdd.map(lambda p: p.features).collect() + self.assertEqual(len(vectors), 2) + for v in vectors: + if isinstance(v, SparseVector): + self.assertEqual(v, self.sv1) + elif isinstance(v, DenseVector): + self.assertEqual(v, self.dv1) + else: + raise TypeError("expecting a vector but got %r of type %r" % (v, type(v))) + + +class MatrixUDTTests(MLlibTestCase): + + dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10]) + dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True) + sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0]) + sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True) + udt = MatrixUDT() + + def test_json_schema(self): + self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for m in [self.dm1, self.dm2, self.sm1, self.sm2]: + self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m))) + + def test_infer_schema(self): + rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)]) + df = rdd.toDF() + schema = df.schema + self.assertTrue(schema.fields[1].dataType, self.udt) + matrices = df.rdd.map(lambda x: x._2).collect() + self.assertEqual(len(matrices), 2) + for m in matrices: + if isinstance(m, DenseMatrix): + self.assertTrue(m, self.dm1) + elif isinstance(m, SparseMatrix): + self.assertTrue(m, self.sm1) + else: + raise ValueError("Expected a matrix but got type %r" % type(m)) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_linalg import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_param.py b/python/pyspark/ml/tests/test_param.py new file mode 100644 index 0000000000000..1f36d4544ab92 --- /dev/null +++ b/python/pyspark/ml/tests/test_param.py @@ -0,0 +1,372 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import inspect +import sys +import array as pyarray +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +if sys.version > '3': + xrange = range + +import numpy as np + +from pyspark import keyword_only +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.clustering import KMeans +from pyspark.ml.feature import Binarizer, Bucketizer, ElementwiseProduct, IndexToString, \ + VectorSlicer, Word2Vec +from pyspark.ml.linalg import DenseVector, SparseVector +from pyspark.ml.param import Param, Params, TypeConverters +from pyspark.ml.param.shared import HasInputCol, HasMaxIter, HasSeed +from pyspark.ml.wrapper import JavaParams +from pyspark.testing.mlutils import check_params, PySparkTestCase, SparkSessionTestCase + + +class ParamTypeConversionTests(PySparkTestCase): + """ + Test that param type conversion happens. + """ + + def test_int(self): + lr = LogisticRegression(maxIter=5.0) + self.assertEqual(lr.getMaxIter(), 5) + self.assertTrue(type(lr.getMaxIter()) == int) + self.assertRaises(TypeError, lambda: LogisticRegression(maxIter="notAnInt")) + self.assertRaises(TypeError, lambda: LogisticRegression(maxIter=5.1)) + + def test_float(self): + lr = LogisticRegression(tol=1) + self.assertEqual(lr.getTol(), 1.0) + self.assertTrue(type(lr.getTol()) == float) + self.assertRaises(TypeError, lambda: LogisticRegression(tol="notAFloat")) + + def test_vector(self): + ewp = ElementwiseProduct(scalingVec=[1, 3]) + self.assertEqual(ewp.getScalingVec(), DenseVector([1.0, 3.0])) + ewp = ElementwiseProduct(scalingVec=np.array([1.2, 3.4])) + self.assertEqual(ewp.getScalingVec(), DenseVector([1.2, 3.4])) + self.assertRaises(TypeError, lambda: ElementwiseProduct(scalingVec=["a", "b"])) + + def test_list(self): + l = [0, 1] + for lst_like in [l, np.array(l), DenseVector(l), SparseVector(len(l), range(len(l)), l), + pyarray.array('l', l), xrange(2), tuple(l)]: + converted = TypeConverters.toList(lst_like) + self.assertEqual(type(converted), list) + self.assertListEqual(converted, l) + + def test_list_int(self): + for indices in [[1.0, 2.0], np.array([1.0, 2.0]), DenseVector([1.0, 2.0]), + SparseVector(2, {0: 1.0, 1: 2.0}), xrange(1, 3), (1.0, 2.0), + pyarray.array('d', [1.0, 2.0])]: + vs = VectorSlicer(indices=indices) + self.assertListEqual(vs.getIndices(), [1, 2]) + self.assertTrue(all([type(v) == int for v in vs.getIndices()])) + self.assertRaises(TypeError, lambda: VectorSlicer(indices=["a", "b"])) + + def test_list_float(self): + b = Bucketizer(splits=[1, 4]) + self.assertEqual(b.getSplits(), [1.0, 4.0]) + self.assertTrue(all([type(v) == float for v in b.getSplits()])) + self.assertRaises(TypeError, lambda: Bucketizer(splits=["a", 1.0])) + + def test_list_string(self): + for labels in [np.array(['a', u'b']), ['a', u'b'], np.array(['a', 'b'])]: + idx_to_string = IndexToString(labels=labels) + self.assertListEqual(idx_to_string.getLabels(), ['a', 'b']) + self.assertRaises(TypeError, lambda: IndexToString(labels=['a', 2])) + + def test_string(self): + lr = LogisticRegression() + for col in ['features', u'features', np.str_('features')]: + lr.setFeaturesCol(col) + self.assertEqual(lr.getFeaturesCol(), 'features') + self.assertRaises(TypeError, lambda: LogisticRegression(featuresCol=2.3)) + + def test_bool(self): + self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept=1)) + self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept="false")) + + +class TestParams(HasMaxIter, HasInputCol, HasSeed): + """ + A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed. + """ + @keyword_only + def __init__(self, seed=None): + super(TestParams, self).__init__() + self._setDefault(maxIter=10) + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, seed=None): + """ + setParams(self, seed=None) + Sets params for this test. + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + +class OtherTestParams(HasMaxIter, HasInputCol, HasSeed): + """ + A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed. + """ + @keyword_only + def __init__(self, seed=None): + super(OtherTestParams, self).__init__() + self._setDefault(maxIter=10) + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, seed=None): + """ + setParams(self, seed=None) + Sets params for this test. + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + +class HasThrowableProperty(Params): + + def __init__(self): + super(HasThrowableProperty, self).__init__() + self.p = Param(self, "none", "empty param") + + @property + def test_property(self): + raise RuntimeError("Test property to raise error when invoked") + + +class ParamTests(SparkSessionTestCase): + + def test_copy_new_parent(self): + testParams = TestParams() + # Copying an instantiated param should fail + with self.assertRaises(ValueError): + testParams.maxIter._copy_new_parent(testParams) + # Copying a dummy param should succeed + TestParams.maxIter._copy_new_parent(testParams) + maxIter = testParams.maxIter + self.assertEqual(maxIter.name, "maxIter") + self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") + self.assertTrue(maxIter.parent == testParams.uid) + + def test_param(self): + testParams = TestParams() + maxIter = testParams.maxIter + self.assertEqual(maxIter.name, "maxIter") + self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") + self.assertTrue(maxIter.parent == testParams.uid) + + def test_hasparam(self): + testParams = TestParams() + self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params])) + self.assertFalse(testParams.hasParam("notAParameter")) + self.assertTrue(testParams.hasParam(u"maxIter")) + + def test_resolveparam(self): + testParams = TestParams() + self.assertEqual(testParams._resolveParam(testParams.maxIter), testParams.maxIter) + self.assertEqual(testParams._resolveParam("maxIter"), testParams.maxIter) + + self.assertEqual(testParams._resolveParam(u"maxIter"), testParams.maxIter) + if sys.version_info[0] >= 3: + # In Python 3, it is allowed to get/set attributes with non-ascii characters. + e_cls = AttributeError + else: + e_cls = UnicodeEncodeError + self.assertRaises(e_cls, lambda: testParams._resolveParam(u"아")) + + def test_params(self): + testParams = TestParams() + maxIter = testParams.maxIter + inputCol = testParams.inputCol + seed = testParams.seed + + params = testParams.params + self.assertEqual(params, [inputCol, maxIter, seed]) + + self.assertTrue(testParams.hasParam(maxIter.name)) + self.assertTrue(testParams.hasDefault(maxIter)) + self.assertFalse(testParams.isSet(maxIter)) + self.assertTrue(testParams.isDefined(maxIter)) + self.assertEqual(testParams.getMaxIter(), 10) + testParams.setMaxIter(100) + self.assertTrue(testParams.isSet(maxIter)) + self.assertEqual(testParams.getMaxIter(), 100) + + self.assertTrue(testParams.hasParam(inputCol.name)) + self.assertFalse(testParams.hasDefault(inputCol)) + self.assertFalse(testParams.isSet(inputCol)) + self.assertFalse(testParams.isDefined(inputCol)) + with self.assertRaises(KeyError): + testParams.getInputCol() + + otherParam = Param(Params._dummy(), "otherParam", "Parameter used to test that " + + "set raises an error for a non-member parameter.", + typeConverter=TypeConverters.toString) + with self.assertRaises(ValueError): + testParams.set(otherParam, "value") + + # Since the default is normally random, set it to a known number for debug str + testParams._setDefault(seed=41) + testParams.setSeed(43) + + self.assertEqual( + testParams.explainParams(), + "\n".join(["inputCol: input column name. (undefined)", + "maxIter: max number of iterations (>= 0). (default: 10, current: 100)", + "seed: random seed. (default: 41, current: 43)"])) + + def test_kmeans_param(self): + algo = KMeans() + self.assertEqual(algo.getInitMode(), "k-means||") + algo.setK(10) + self.assertEqual(algo.getK(), 10) + algo.setInitSteps(10) + self.assertEqual(algo.getInitSteps(), 10) + self.assertEqual(algo.getDistanceMeasure(), "euclidean") + algo.setDistanceMeasure("cosine") + self.assertEqual(algo.getDistanceMeasure(), "cosine") + + def test_hasseed(self): + noSeedSpecd = TestParams() + withSeedSpecd = TestParams(seed=42) + other = OtherTestParams() + # Check that we no longer use 42 as the magic number + self.assertNotEqual(noSeedSpecd.getSeed(), 42) + origSeed = noSeedSpecd.getSeed() + # Check that we only compute the seed once + self.assertEqual(noSeedSpecd.getSeed(), origSeed) + # Check that a specified seed is honored + self.assertEqual(withSeedSpecd.getSeed(), 42) + # Check that a different class has a different seed + self.assertNotEqual(other.getSeed(), noSeedSpecd.getSeed()) + + def test_param_property_error(self): + param_store = HasThrowableProperty() + self.assertRaises(RuntimeError, lambda: param_store.test_property) + params = param_store.params # should not invoke the property 'test_property' + self.assertEqual(len(params), 1) + + def test_word2vec_param(self): + model = Word2Vec().setWindowSize(6) + # Check windowSize is set properly + self.assertEqual(model.getWindowSize(), 6) + + def test_copy_param_extras(self): + tp = TestParams(seed=42) + extra = {tp.getParam(TestParams.inputCol.name): "copy_input"} + tp_copy = tp.copy(extra=extra) + self.assertEqual(tp.uid, tp_copy.uid) + self.assertEqual(tp.params, tp_copy.params) + for k, v in extra.items(): + self.assertTrue(tp_copy.isDefined(k)) + self.assertEqual(tp_copy.getOrDefault(k), v) + copied_no_extra = {} + for k, v in tp_copy._paramMap.items(): + if k not in extra: + copied_no_extra[k] = v + self.assertEqual(tp._paramMap, copied_no_extra) + self.assertEqual(tp._defaultParamMap, tp_copy._defaultParamMap) + + def test_logistic_regression_check_thresholds(self): + self.assertIsInstance( + LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]), + LogisticRegression + ) + + self.assertRaisesRegexp( + ValueError, + "Logistic Regression getThreshold found inconsistent.*$", + LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5] + ) + + def test_preserve_set_state(self): + dataset = self.spark.createDataFrame([(0.5,)], ["data"]) + binarizer = Binarizer(inputCol="data") + self.assertFalse(binarizer.isSet("threshold")) + binarizer.transform(dataset) + binarizer._transfer_params_from_java() + self.assertFalse(binarizer.isSet("threshold"), + "Params not explicitly set should remain unset after transform") + + def test_default_params_transferred(self): + dataset = self.spark.createDataFrame([(0.5,)], ["data"]) + binarizer = Binarizer(inputCol="data") + # intentionally change the pyspark default, but don't set it + binarizer._defaultParamMap[binarizer.outputCol] = "my_default" + result = binarizer.transform(dataset).select("my_default").collect() + self.assertFalse(binarizer.isSet(binarizer.outputCol)) + self.assertEqual(result[0][0], 1.0) + + +class DefaultValuesTests(PySparkTestCase): + """ + Test :py:class:`JavaParams` classes to see if their default Param values match + those in their Scala counterparts. + """ + + def test_java_params(self): + import pyspark.ml.feature + import pyspark.ml.classification + import pyspark.ml.clustering + import pyspark.ml.evaluation + import pyspark.ml.pipeline + import pyspark.ml.recommendation + import pyspark.ml.regression + + modules = [pyspark.ml.feature, pyspark.ml.classification, pyspark.ml.clustering, + pyspark.ml.evaluation, pyspark.ml.pipeline, pyspark.ml.recommendation, + pyspark.ml.regression] + for module in modules: + for name, cls in inspect.getmembers(module, inspect.isclass): + if not name.endswith('Model') and not name.endswith('Params') \ + and issubclass(cls, JavaParams) and not inspect.isabstract(cls): + # NOTE: disable check_params_exist until there is parity with Scala API + check_params(self, cls(), check_params_exist=False) + + # Additional classes that need explicit construction + from pyspark.ml.feature import CountVectorizerModel, StringIndexerModel + check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'), + check_params_exist=False) + check_params(self, StringIndexerModel.from_labels(['a', 'b'], 'input'), + check_params_exist=False) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_param import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_persistence.py b/python/pyspark/ml/tests/test_persistence.py new file mode 100644 index 0000000000000..b5a2e16df5532 --- /dev/null +++ b/python/pyspark/ml/tests/test_persistence.py @@ -0,0 +1,369 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +from shutil import rmtree +import sys +import tempfile +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.ml import Transformer +from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, OneVsRest, \ + OneVsRestModel +from pyspark.ml.feature import Binarizer, HashingTF, PCA +from pyspark.ml.linalg import Vectors +from pyspark.ml.param import Params +from pyspark.ml.pipeline import Pipeline, PipelineModel +from pyspark.ml.regression import DecisionTreeRegressor, LinearRegression +from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWriter +from pyspark.ml.wrapper import JavaParams +from pyspark.testing.mlutils import MockUnaryTransformer, SparkSessionTestCase + + +class PersistenceTest(SparkSessionTestCase): + + def test_linear_regression(self): + lr = LinearRegression(maxIter=1) + path = tempfile.mkdtemp() + lr_path = path + "/lr" + lr.save(lr_path) + lr2 = LinearRegression.load(lr_path) + self.assertEqual(lr.uid, lr2.uid) + self.assertEqual(type(lr.uid), type(lr2.uid)) + self.assertEqual(lr2.uid, lr2.maxIter.parent, + "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)" + % (lr2.uid, lr2.maxIter.parent)) + self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter], + "Loaded LinearRegression instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + + def test_linear_regression_pmml_basic(self): + # Most of the validation is done in the Scala side, here we just check + # that we output text rather than parquet (e.g. that the format flag + # was respected). + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=1) + model = lr.fit(df) + path = tempfile.mkdtemp() + lr_path = path + "/lr-pmml" + model.write().format("pmml").save(lr_path) + pmml_text_list = self.sc.textFile(lr_path).collect() + pmml_text = "\n".join(pmml_text_list) + self.assertIn("Apache Spark", pmml_text) + self.assertIn("PMML", pmml_text) + + def test_logistic_regression(self): + lr = LogisticRegression(maxIter=1) + path = tempfile.mkdtemp() + lr_path = path + "/logreg" + lr.save(lr_path) + lr2 = LogisticRegression.load(lr_path) + self.assertEqual(lr2.uid, lr2.maxIter.parent, + "Loaded LogisticRegression instance uid (%s) " + "did not match Param's uid (%s)" + % (lr2.uid, lr2.maxIter.parent)) + self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter], + "Loaded LogisticRegression instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + + def _compare_params(self, m1, m2, param): + """ + Compare 2 ML Params instances for the given param, and assert both have the same param value + and parent. The param must be a parameter of m1. + """ + # Prevent key not found error in case of some param in neither paramMap nor defaultParamMap. + if m1.isDefined(param): + paramValue1 = m1.getOrDefault(param) + paramValue2 = m2.getOrDefault(m2.getParam(param.name)) + if isinstance(paramValue1, Params): + self._compare_pipelines(paramValue1, paramValue2) + else: + self.assertEqual(paramValue1, paramValue2) # for general types param + # Assert parents are equal + self.assertEqual(param.parent, m2.getParam(param.name).parent) + else: + # If m1 is not defined param, then m2 should not, too. See SPARK-14931. + self.assertFalse(m2.isDefined(m2.getParam(param.name))) + + def _compare_pipelines(self, m1, m2): + """ + Compare 2 ML types, asserting that they are equivalent. + This currently supports: + - basic types + - Pipeline, PipelineModel + - OneVsRest, OneVsRestModel + This checks: + - uid + - type + - Param values and parents + """ + self.assertEqual(m1.uid, m2.uid) + self.assertEqual(type(m1), type(m2)) + if isinstance(m1, JavaParams) or isinstance(m1, Transformer): + self.assertEqual(len(m1.params), len(m2.params)) + for p in m1.params: + self._compare_params(m1, m2, p) + elif isinstance(m1, Pipeline): + self.assertEqual(len(m1.getStages()), len(m2.getStages())) + for s1, s2 in zip(m1.getStages(), m2.getStages()): + self._compare_pipelines(s1, s2) + elif isinstance(m1, PipelineModel): + self.assertEqual(len(m1.stages), len(m2.stages)) + for s1, s2 in zip(m1.stages, m2.stages): + self._compare_pipelines(s1, s2) + elif isinstance(m1, OneVsRest) or isinstance(m1, OneVsRestModel): + for p in m1.params: + self._compare_params(m1, m2, p) + if isinstance(m1, OneVsRestModel): + self.assertEqual(len(m1.models), len(m2.models)) + for x, y in zip(m1.models, m2.models): + self._compare_pipelines(x, y) + else: + raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1)) + + def test_pipeline_persistence(self): + """ + Pipeline[HashingTF, PCA] + """ + temp_path = tempfile.mkdtemp() + + try: + df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) + tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") + pca = PCA(k=2, inputCol="features", outputCol="pca_features") + pl = Pipeline(stages=[tf, pca]) + model = pl.fit(df) + + pipeline_path = temp_path + "/pipeline" + pl.save(pipeline_path) + loaded_pipeline = Pipeline.load(pipeline_path) + self._compare_pipelines(pl, loaded_pipeline) + + model_path = temp_path + "/pipeline-model" + model.save(model_path) + loaded_model = PipelineModel.load(model_path) + self._compare_pipelines(model, loaded_model) + finally: + try: + rmtree(temp_path) + except OSError: + pass + + def test_nested_pipeline_persistence(self): + """ + Pipeline[HashingTF, Pipeline[PCA]] + """ + temp_path = tempfile.mkdtemp() + + try: + df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) + tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") + pca = PCA(k=2, inputCol="features", outputCol="pca_features") + p0 = Pipeline(stages=[pca]) + pl = Pipeline(stages=[tf, p0]) + model = pl.fit(df) + + pipeline_path = temp_path + "/pipeline" + pl.save(pipeline_path) + loaded_pipeline = Pipeline.load(pipeline_path) + self._compare_pipelines(pl, loaded_pipeline) + + model_path = temp_path + "/pipeline-model" + model.save(model_path) + loaded_model = PipelineModel.load(model_path) + self._compare_pipelines(model, loaded_model) + finally: + try: + rmtree(temp_path) + except OSError: + pass + + def test_python_transformer_pipeline_persistence(self): + """ + Pipeline[MockUnaryTransformer, Binarizer] + """ + temp_path = tempfile.mkdtemp() + + try: + df = self.spark.range(0, 10).toDF('input') + tf = MockUnaryTransformer(shiftVal=2)\ + .setInputCol("input").setOutputCol("shiftedInput") + tf2 = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized") + pl = Pipeline(stages=[tf, tf2]) + model = pl.fit(df) + + pipeline_path = temp_path + "/pipeline" + pl.save(pipeline_path) + loaded_pipeline = Pipeline.load(pipeline_path) + self._compare_pipelines(pl, loaded_pipeline) + + model_path = temp_path + "/pipeline-model" + model.save(model_path) + loaded_model = PipelineModel.load(model_path) + self._compare_pipelines(model, loaded_model) + finally: + try: + rmtree(temp_path) + except OSError: + pass + + def test_onevsrest(self): + temp_path = tempfile.mkdtemp() + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), + (1.0, Vectors.sparse(2, [], [])), + (2.0, Vectors.dense(0.5, 0.5))] * 10, + ["label", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01) + ovr = OneVsRest(classifier=lr) + model = ovr.fit(df) + ovrPath = temp_path + "/ovr" + ovr.save(ovrPath) + loadedOvr = OneVsRest.load(ovrPath) + self._compare_pipelines(ovr, loadedOvr) + modelPath = temp_path + "/ovrModel" + model.save(modelPath) + loadedModel = OneVsRestModel.load(modelPath) + self._compare_pipelines(model, loadedModel) + + def test_decisiontree_classifier(self): + dt = DecisionTreeClassifier(maxDepth=1) + path = tempfile.mkdtemp() + dtc_path = path + "/dtc" + dt.save(dtc_path) + dt2 = DecisionTreeClassifier.load(dtc_path) + self.assertEqual(dt2.uid, dt2.maxDepth.parent, + "Loaded DecisionTreeClassifier instance uid (%s) " + "did not match Param's uid (%s)" + % (dt2.uid, dt2.maxDepth.parent)) + self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth], + "Loaded DecisionTreeClassifier instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + + def test_decisiontree_regressor(self): + dt = DecisionTreeRegressor(maxDepth=1) + path = tempfile.mkdtemp() + dtr_path = path + "/dtr" + dt.save(dtr_path) + dt2 = DecisionTreeClassifier.load(dtr_path) + self.assertEqual(dt2.uid, dt2.maxDepth.parent, + "Loaded DecisionTreeRegressor instance uid (%s) " + "did not match Param's uid (%s)" + % (dt2.uid, dt2.maxDepth.parent)) + self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth], + "Loaded DecisionTreeRegressor instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + + def test_default_read_write(self): + temp_path = tempfile.mkdtemp() + + lr = LogisticRegression() + lr.setMaxIter(50) + lr.setThreshold(.75) + writer = DefaultParamsWriter(lr) + + savePath = temp_path + "/lr" + writer.save(savePath) + + reader = DefaultParamsReadable.read() + lr2 = reader.load(savePath) + + self.assertEqual(lr.uid, lr2.uid) + self.assertEqual(lr.extractParamMap(), lr2.extractParamMap()) + + # test overwrite + lr.setThreshold(.8) + writer.overwrite().save(savePath) + + reader = DefaultParamsReadable.read() + lr3 = reader.load(savePath) + + self.assertEqual(lr.uid, lr3.uid) + self.assertEqual(lr.extractParamMap(), lr3.extractParamMap()) + + def test_default_read_write_default_params(self): + lr = LogisticRegression() + self.assertFalse(lr.isSet(lr.getParam("threshold"))) + + lr.setMaxIter(50) + lr.setThreshold(.75) + + # `threshold` is set by user, default param `predictionCol` is not set by user. + self.assertTrue(lr.isSet(lr.getParam("threshold"))) + self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) + self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) + + writer = DefaultParamsWriter(lr) + metadata = json.loads(writer._get_metadata_to_save(lr, self.sc)) + self.assertTrue("defaultParamMap" in metadata) + + reader = DefaultParamsReadable.read() + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + reader.getAndSetParams(lr, loadedMetadata) + + self.assertTrue(lr.isSet(lr.getParam("threshold"))) + self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) + self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) + + # manually create metadata without `defaultParamMap` section. + del metadata['defaultParamMap'] + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"): + reader.getAndSetParams(lr, loadedMetadata) + + # Prior to 2.4.0, metadata doesn't have `defaultParamMap`. + metadata['sparkVersion'] = '2.3.0' + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + reader.getAndSetParams(lr, loadedMetadata) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_persistence import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_pipeline.py b/python/pyspark/ml/tests/test_pipeline.py new file mode 100644 index 0000000000000..31ef02c2e601f --- /dev/null +++ b/python/pyspark/ml/tests/test_pipeline.py @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import sys +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.ml.pipeline import Pipeline +from pyspark.testing.mlutils import MockDataset, MockEstimator, MockTransformer, PySparkTestCase + + +class PipelineTests(PySparkTestCase): + + def test_pipeline(self): + dataset = MockDataset() + estimator0 = MockEstimator() + transformer1 = MockTransformer() + estimator2 = MockEstimator() + transformer3 = MockTransformer() + pipeline = Pipeline(stages=[estimator0, transformer1, estimator2, transformer3]) + pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1}) + model0, transformer1, model2, transformer3 = pipeline_model.stages + self.assertEqual(0, model0.dataset_index) + self.assertEqual(0, model0.getFake()) + self.assertEqual(1, transformer1.dataset_index) + self.assertEqual(1, transformer1.getFake()) + self.assertEqual(2, dataset.index) + self.assertIsNone(model2.dataset_index, "The last model shouldn't be called in fit.") + self.assertIsNone(transformer3.dataset_index, + "The last transformer shouldn't be called in fit.") + dataset = pipeline_model.transform(dataset) + self.assertEqual(2, model0.dataset_index) + self.assertEqual(3, transformer1.dataset_index) + self.assertEqual(4, model2.dataset_index) + self.assertEqual(5, transformer3.dataset_index) + self.assertEqual(6, dataset.index) + + def test_identity_pipeline(self): + dataset = MockDataset() + + def doTransform(pipeline): + pipeline_model = pipeline.fit(dataset) + return pipeline_model.transform(dataset) + # check that empty pipeline did not perform any transformation + self.assertEqual(dataset.index, doTransform(Pipeline(stages=[])).index) + # check that failure to set stages param will raise KeyError for missing param + self.assertRaises(KeyError, lambda: doTransform(Pipeline())) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_pipeline import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_stat.py b/python/pyspark/ml/tests/test_stat.py new file mode 100644 index 0000000000000..bdc4853bc05c2 --- /dev/null +++ b/python/pyspark/ml/tests/test_stat.py @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.ml.linalg import Vectors +from pyspark.ml.stat import ChiSquareTest +from pyspark.sql import DataFrame +from pyspark.testing.mlutils import SparkSessionTestCase + + +class ChiSquareTestTests(SparkSessionTestCase): + + def test_chisquaretest(self): + data = [[0, Vectors.dense([0, 1, 2])], + [1, Vectors.dense([1, 1, 1])], + [2, Vectors.dense([2, 1, 0])]] + df = self.spark.createDataFrame(data, ['label', 'feat']) + res = ChiSquareTest.test(df, 'feat', 'label') + # This line is hitting the collect bug described in #17218, commented for now. + # pValues = res.select("degreesOfFreedom").collect()) + self.assertIsInstance(res, DataFrame) + fieldNames = set(field.name for field in res.schema.fields) + expectedFields = ["pValues", "degreesOfFreedom", "statistics"] + self.assertTrue(all(field in fieldNames for field in expectedFields)) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_stat import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py new file mode 100644 index 0000000000000..d5464f7be6372 --- /dev/null +++ b/python/pyspark/ml/tests/test_training_summary.py @@ -0,0 +1,258 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +if sys.version > '3': + basestring = str + +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.clustering import BisectingKMeans, GaussianMixture, KMeans +from pyspark.ml.linalg import Vectors +from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression +from pyspark.sql import DataFrame +from pyspark.testing.mlutils import SparkSessionTestCase + + +class TrainingSummaryTest(SparkSessionTestCase): + + def test_linear_regression_summary(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight", + fitIntercept=False) + model = lr.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + # test that api is callable and returns expected types + self.assertGreater(s.totalIterations, 0) + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.predictionCol, "prediction") + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.featuresCol, "features") + objHist = s.objectiveHistory + self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) + self.assertAlmostEqual(s.explainedVariance, 0.25, 2) + self.assertAlmostEqual(s.meanAbsoluteError, 0.0) + self.assertAlmostEqual(s.meanSquaredError, 0.0) + self.assertAlmostEqual(s.rootMeanSquaredError, 0.0) + self.assertAlmostEqual(s.r2, 1.0, 2) + self.assertAlmostEqual(s.r2adj, 1.0, 2) + self.assertTrue(isinstance(s.residuals, DataFrame)) + self.assertEqual(s.numInstances, 2) + self.assertEqual(s.degreesOfFreedom, 1) + devResiduals = s.devianceResiduals + self.assertTrue(isinstance(devResiduals, list) and isinstance(devResiduals[0], float)) + coefStdErr = s.coefficientStandardErrors + self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float)) + tValues = s.tValues + self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float)) + pValues = s.pValues + self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float)) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned + # The child class LinearRegressionTrainingSummary runs full test + sameSummary = model.evaluate(df) + self.assertAlmostEqual(sameSummary.explainedVariance, s.explainedVariance) + + def test_glr_summary(self): + from pyspark.ml.linalg import Vectors + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + glr = GeneralizedLinearRegression(family="gaussian", link="identity", weightCol="weight", + fitIntercept=False) + model = glr.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + # test that api is callable and returns expected types + self.assertEqual(s.numIterations, 1) # this should default to a single iteration of WLS + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.predictionCol, "prediction") + self.assertEqual(s.numInstances, 2) + self.assertTrue(isinstance(s.residuals(), DataFrame)) + self.assertTrue(isinstance(s.residuals("pearson"), DataFrame)) + coefStdErr = s.coefficientStandardErrors + self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float)) + tValues = s.tValues + self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float)) + pValues = s.pValues + self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float)) + self.assertEqual(s.degreesOfFreedom, 1) + self.assertEqual(s.residualDegreeOfFreedom, 1) + self.assertEqual(s.residualDegreeOfFreedomNull, 2) + self.assertEqual(s.rank, 1) + self.assertTrue(isinstance(s.solver, basestring)) + self.assertTrue(isinstance(s.aic, float)) + self.assertTrue(isinstance(s.deviance, float)) + self.assertTrue(isinstance(s.nullDeviance, float)) + self.assertTrue(isinstance(s.dispersion, float)) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned + # The child class GeneralizedLinearRegressionTrainingSummary runs full test + sameSummary = model.evaluate(df) + self.assertAlmostEqual(sameSummary.deviance, s.deviance) + + def test_binary_logistic_regression_summary(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False) + model = lr.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + # test that api is callable and returns expected types + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.probabilityCol, "probability") + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + objHist = s.objectiveHistory + self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) + self.assertGreater(s.totalIterations, 0) + self.assertTrue(isinstance(s.labels, list)) + self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.precisionByLabel, list)) + self.assertTrue(isinstance(s.recallByLabel, list)) + self.assertTrue(isinstance(s.fMeasureByLabel(), list)) + self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) + self.assertTrue(isinstance(s.roc, DataFrame)) + self.assertAlmostEqual(s.areaUnderROC, 1.0, 2) + self.assertTrue(isinstance(s.pr, DataFrame)) + self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame)) + self.assertTrue(isinstance(s.precisionByThreshold, DataFrame)) + self.assertTrue(isinstance(s.recallByThreshold, DataFrame)) + self.assertAlmostEqual(s.accuracy, 1.0, 2) + self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2) + self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2) + self.assertAlmostEqual(s.weightedRecall, 1.0, 2) + self.assertAlmostEqual(s.weightedPrecision, 1.0, 2) + self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2) + self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned, Scala version runs full test + sameSummary = model.evaluate(df) + self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) + + def test_multiclass_logistic_regression_summary(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], [])), + (2.0, 2.0, Vectors.dense(2.0)), + (2.0, 2.0, Vectors.dense(1.9))], + ["label", "weight", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False) + model = lr.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + # test that api is callable and returns expected types + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.probabilityCol, "probability") + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + objHist = s.objectiveHistory + self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) + self.assertGreater(s.totalIterations, 0) + self.assertTrue(isinstance(s.labels, list)) + self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.precisionByLabel, list)) + self.assertTrue(isinstance(s.recallByLabel, list)) + self.assertTrue(isinstance(s.fMeasureByLabel(), list)) + self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) + self.assertAlmostEqual(s.accuracy, 0.75, 2) + self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2) + self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2) + self.assertAlmostEqual(s.weightedRecall, 0.75, 2) + self.assertAlmostEqual(s.weightedPrecision, 0.583, 2) + self.assertAlmostEqual(s.weightedFMeasure(), 0.65, 2) + self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.65, 2) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned, Scala version runs full test + sameSummary = model.evaluate(df) + self.assertAlmostEqual(sameSummary.accuracy, s.accuracy) + + def test_gaussian_mixture_summary(self): + data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), + (Vectors.sparse(1, [], []),)] + df = self.spark.createDataFrame(data, ["features"]) + gmm = GaussianMixture(k=2) + model = gmm.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.probabilityCol, "probability") + self.assertTrue(isinstance(s.probability, DataFrame)) + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + self.assertTrue(isinstance(s.cluster, DataFrame)) + self.assertEqual(len(s.clusterSizes), 2) + self.assertEqual(s.k, 2) + self.assertEqual(s.numIter, 3) + + def test_bisecting_kmeans_summary(self): + data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), + (Vectors.sparse(1, [], []),)] + df = self.spark.createDataFrame(data, ["features"]) + bkm = BisectingKMeans(k=2) + model = bkm.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + self.assertTrue(isinstance(s.cluster, DataFrame)) + self.assertEqual(len(s.clusterSizes), 2) + self.assertEqual(s.k, 2) + self.assertEqual(s.numIter, 20) + + def test_kmeans_summary(self): + data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), + (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)] + df = self.spark.createDataFrame(data, ["features"]) + kmeans = KMeans(k=2, seed=1) + model = kmeans.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + self.assertTrue(isinstance(s.cluster, DataFrame)) + self.assertEqual(len(s.clusterSizes), 2) + self.assertEqual(s.k, 2) + self.assertEqual(s.numIter, 1) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_training_summary import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py new file mode 100644 index 0000000000000..af00d1de7ab6a --- /dev/null +++ b/python/pyspark/ml/tests/test_tuning.py @@ -0,0 +1,552 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import tempfile +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.ml import Estimator, Model +from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel, OneVsRest +from pyspark.ml.evaluation import BinaryClassificationEvaluator, \ + MulticlassClassificationEvaluator, RegressionEvaluator +from pyspark.ml.linalg import Vectors +from pyspark.ml.param import Param, Params +from pyspark.ml.tuning import CrossValidator, CrossValidatorModel, ParamGridBuilder, \ + TrainValidationSplit, TrainValidationSplitModel +from pyspark.sql.functions import rand +from pyspark.testing.mlutils import SparkSessionTestCase + + +class HasInducedError(Params): + + def __init__(self): + super(HasInducedError, self).__init__() + self.inducedError = Param(self, "inducedError", + "Uniformly-distributed error added to feature") + + def getInducedError(self): + return self.getOrDefault(self.inducedError) + + +class InducedErrorModel(Model, HasInducedError): + + def __init__(self): + super(InducedErrorModel, self).__init__() + + def _transform(self, dataset): + return dataset.withColumn("prediction", + dataset.feature + (rand(0) * self.getInducedError())) + + +class InducedErrorEstimator(Estimator, HasInducedError): + + def __init__(self, inducedError=1.0): + super(InducedErrorEstimator, self).__init__() + self._set(inducedError=inducedError) + + def _fit(self, dataset): + model = InducedErrorModel() + self._copyValues(model) + return model + + +class CrossValidatorTests(SparkSessionTestCase): + + def test_copy(self): + dataset = self.spark.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="rmse") + + grid = (ParamGridBuilder() + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) + .build()) + cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + cvCopied = cv.copy() + self.assertEqual(cv.getEstimator().uid, cvCopied.getEstimator().uid) + + cvModel = cv.fit(dataset) + cvModelCopied = cvModel.copy() + for index in range(len(cvModel.avgMetrics)): + self.assertTrue(abs(cvModel.avgMetrics[index] - cvModelCopied.avgMetrics[index]) + < 0.0001) + + def test_fit_minimize_metric(self): + dataset = self.spark.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="rmse") + + grid = (ParamGridBuilder() + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) + .build()) + cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + bestModel = cvModel.bestModel + bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) + + self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), + "Best model should have zero induced error") + self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0") + + def test_fit_maximize_metric(self): + dataset = self.spark.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="r2") + + grid = (ParamGridBuilder() + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) + .build()) + cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + bestModel = cvModel.bestModel + bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) + + self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), + "Best model should have zero induced error") + self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") + + def test_param_grid_type_coercion(self): + lr = LogisticRegression(maxIter=10) + paramGrid = ParamGridBuilder().addGrid(lr.regParam, [0.5, 1]).build() + for param in paramGrid: + for v in param.values(): + assert(type(v) == float) + + def test_save_load_trained_model(self): + # This tests saving and loading the trained model only. + # Save/load for CrossValidator will be added later: SPARK-13786 + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + lrModel = cvModel.bestModel + + cvModelPath = temp_path + "/cvModel" + lrModel.save(cvModelPath) + loadedLrModel = LogisticRegressionModel.load(cvModelPath) + self.assertEqual(loadedLrModel.uid, lrModel.uid) + self.assertEqual(loadedLrModel.intercept, lrModel.intercept) + + def test_save_load_simple_estimator(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + + # test save/load of CrossValidator + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + cvPath = temp_path + "/cv" + cv.save(cvPath) + loadedCV = CrossValidator.load(cvPath) + self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid) + self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid) + self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps()) + + # test save/load of CrossValidatorModel + cvModelPath = temp_path + "/cvModel" + cvModel.save(cvModelPath) + loadedModel = CrossValidatorModel.load(cvModelPath) + self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) + + def test_parallel_evaluation(self): + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build() + evaluator = BinaryClassificationEvaluator() + + # test save/load of CrossValidator + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + cv.setParallelism(1) + cvSerialModel = cv.fit(dataset) + cv.setParallelism(2) + cvParallelModel = cv.fit(dataset) + self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics) + + def test_expose_sub_models(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + + numFolds = 3 + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + numFolds=numFolds, collectSubModels=True) + + def checkSubModels(subModels): + self.assertEqual(len(subModels), numFolds) + for i in range(numFolds): + self.assertEqual(len(subModels[i]), len(grid)) + + cvModel = cv.fit(dataset) + checkSubModels(cvModel.subModels) + + # Test the default value for option "persistSubModel" to be "true" + testSubPath = temp_path + "/testCrossValidatorSubModels" + savingPathWithSubModels = testSubPath + "cvModel3" + cvModel.save(savingPathWithSubModels) + cvModel3 = CrossValidatorModel.load(savingPathWithSubModels) + checkSubModels(cvModel3.subModels) + cvModel4 = cvModel3.copy() + checkSubModels(cvModel4.subModels) + + savingPathWithoutSubModels = testSubPath + "cvModel2" + cvModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) + cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels) + self.assertEqual(cvModel2.subModels, None) + + for i in range(numFolds): + for j in range(len(grid)): + self.assertEqual(cvModel.subModels[i][j].uid, cvModel3.subModels[i][j].uid) + + def test_save_load_nested_estimator(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + ova = OneVsRest(classifier=LogisticRegression()) + lr1 = LogisticRegression().setMaxIter(100) + lr2 = LogisticRegression().setMaxIter(150) + grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build() + evaluator = MulticlassClassificationEvaluator() + + # test save/load of CrossValidator + cv = CrossValidator(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + cvPath = temp_path + "/cv" + cv.save(cvPath) + loadedCV = CrossValidator.load(cvPath) + self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid) + self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid) + + originalParamMap = cv.getEstimatorParamMaps() + loadedParamMap = loadedCV.getEstimatorParamMaps() + for i, param in enumerate(loadedParamMap): + for p in param: + if p.name == "classifier": + self.assertEqual(param[p].uid, originalParamMap[i][p].uid) + else: + self.assertEqual(param[p], originalParamMap[i][p]) + + # test save/load of CrossValidatorModel + cvModelPath = temp_path + "/cvModel" + cvModel.save(cvModelPath) + loadedModel = CrossValidatorModel.load(cvModelPath) + self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) + + +class TrainValidationSplitTests(SparkSessionTestCase): + + def test_fit_minimize_metric(self): + dataset = self.spark.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="rmse") + + grid = ParamGridBuilder() \ + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ + .build() + tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + bestModel = tvsModel.bestModel + bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) + validationMetrics = tvsModel.validationMetrics + + self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), + "Best model should have zero induced error") + self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0") + self.assertEqual(len(grid), len(validationMetrics), + "validationMetrics has the same size of grid parameter") + self.assertEqual(0.0, min(validationMetrics)) + + def test_fit_maximize_metric(self): + dataset = self.spark.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="r2") + + grid = ParamGridBuilder() \ + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ + .build() + tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + bestModel = tvsModel.bestModel + bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) + validationMetrics = tvsModel.validationMetrics + + self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), + "Best model should have zero induced error") + self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") + self.assertEqual(len(grid), len(validationMetrics), + "validationMetrics has the same size of grid parameter") + self.assertEqual(1.0, max(validationMetrics)) + + def test_save_load_trained_model(self): + # This tests saving and loading the trained model only. + # Save/load for TrainValidationSplit will be added later: SPARK-13786 + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + lrModel = tvsModel.bestModel + + tvsModelPath = temp_path + "/tvsModel" + lrModel.save(tvsModelPath) + loadedLrModel = LogisticRegressionModel.load(tvsModelPath) + self.assertEqual(loadedLrModel.uid, lrModel.uid) + self.assertEqual(loadedLrModel.intercept, lrModel.intercept) + + def test_save_load_simple_estimator(self): + # This tests saving and loading the trained model only. + # Save/load for TrainValidationSplit will be added later: SPARK-13786 + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + + tvsPath = temp_path + "/tvs" + tvs.save(tvsPath) + loadedTvs = TrainValidationSplit.load(tvsPath) + self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) + self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid) + self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps()) + + tvsModelPath = temp_path + "/tvsModel" + tvsModel.save(tvsModelPath) + loadedModel = TrainValidationSplitModel.load(tvsModelPath) + self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) + + def test_parallel_evaluation(self): + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + tvs.setParallelism(1) + tvsSerialModel = tvs.fit(dataset) + tvs.setParallelism(2) + tvsParallelModel = tvs.fit(dataset) + self.assertEqual(tvsSerialModel.validationMetrics, tvsParallelModel.validationMetrics) + + def test_expose_sub_models(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + collectSubModels=True) + tvsModel = tvs.fit(dataset) + self.assertEqual(len(tvsModel.subModels), len(grid)) + + # Test the default value for option "persistSubModel" to be "true" + testSubPath = temp_path + "/testTrainValidationSplitSubModels" + savingPathWithSubModels = testSubPath + "cvModel3" + tvsModel.save(savingPathWithSubModels) + tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels) + self.assertEqual(len(tvsModel3.subModels), len(grid)) + tvsModel4 = tvsModel3.copy() + self.assertEqual(len(tvsModel4.subModels), len(grid)) + + savingPathWithoutSubModels = testSubPath + "cvModel2" + tvsModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) + tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels) + self.assertEqual(tvsModel2.subModels, None) + + for i in range(len(grid)): + self.assertEqual(tvsModel.subModels[i].uid, tvsModel3.subModels[i].uid) + + def test_save_load_nested_estimator(self): + # This tests saving and loading the trained model only. + # Save/load for TrainValidationSplit will be added later: SPARK-13786 + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + ova = OneVsRest(classifier=LogisticRegression()) + lr1 = LogisticRegression().setMaxIter(100) + lr2 = LogisticRegression().setMaxIter(150) + grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build() + evaluator = MulticlassClassificationEvaluator() + + tvs = TrainValidationSplit(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + tvsPath = temp_path + "/tvs" + tvs.save(tvsPath) + loadedTvs = TrainValidationSplit.load(tvsPath) + self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) + self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid) + + originalParamMap = tvs.getEstimatorParamMaps() + loadedParamMap = loadedTvs.getEstimatorParamMaps() + for i, param in enumerate(loadedParamMap): + for p in param: + if p.name == "classifier": + self.assertEqual(param[p].uid, originalParamMap[i][p].uid) + else: + self.assertEqual(param[p], originalParamMap[i][p]) + + tvsModelPath = temp_path + "/tvsModel" + tvsModel.save(tvsModelPath) + loadedModel = TrainValidationSplitModel.load(tvsModelPath) + self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) + + def test_copy(self): + dataset = self.spark.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="r2") + + grid = ParamGridBuilder() \ + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ + .build() + tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + tvsCopied = tvs.copy() + tvsModelCopied = tvsModel.copy() + + self.assertEqual(tvs.getEstimator().uid, tvsCopied.getEstimator().uid, + "Copied TrainValidationSplit has the same uid of Estimator") + + self.assertEqual(tvsModel.bestModel.uid, tvsModelCopied.bestModel.uid) + self.assertEqual(len(tvsModel.validationMetrics), + len(tvsModelCopied.validationMetrics), + "Copied validationMetrics has the same size of the original") + for index in range(len(tvsModel.validationMetrics)): + self.assertEqual(tvsModel.validationMetrics[index], + tvsModelCopied.validationMetrics[index]) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_tuning import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_wrapper.py b/python/pyspark/ml/tests/test_wrapper.py new file mode 100644 index 0000000000000..4326d8e060dd7 --- /dev/null +++ b/python/pyspark/ml/tests/test_wrapper.py @@ -0,0 +1,120 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +import py4j + +from pyspark.ml.linalg import DenseVector, Vectors +from pyspark.ml.regression import LinearRegression +from pyspark.ml.wrapper import _java2py, _py2java, JavaParams, JavaWrapper +from pyspark.testing.mllibutils import MLlibTestCase +from pyspark.testing.mlutils import SparkSessionTestCase + + +class JavaWrapperMemoryTests(SparkSessionTestCase): + + def test_java_object_gets_detached(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", weightCol="weight", + fitIntercept=False) + + model = lr.fit(df) + summary = model.summary + + self.assertIsInstance(model, JavaWrapper) + self.assertIsInstance(summary, JavaWrapper) + self.assertIsInstance(model, JavaParams) + self.assertNotIsInstance(summary, JavaParams) + + error_no_object = 'Target Object ID does not exist for this gateway' + + self.assertIn("LinearRegression_", model._java_obj.toString()) + self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) + + model.__del__() + + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + model._java_obj.toString() + self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) + + try: + summary.__del__() + except: + pass + + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + model._java_obj.toString() + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + summary._java_obj.toString() + + +class WrapperTests(MLlibTestCase): + + def test_new_java_array(self): + # test array of strings + str_list = ["a", "b", "c"] + java_class = self.sc._gateway.jvm.java.lang.String + java_array = JavaWrapper._new_java_array(str_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), str_list) + # test array of integers + int_list = [1, 2, 3] + java_class = self.sc._gateway.jvm.java.lang.Integer + java_array = JavaWrapper._new_java_array(int_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), int_list) + # test array of floats + float_list = [0.1, 0.2, 0.3] + java_class = self.sc._gateway.jvm.java.lang.Double + java_array = JavaWrapper._new_java_array(float_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), float_list) + # test array of bools + bool_list = [False, True, True] + java_class = self.sc._gateway.jvm.java.lang.Boolean + java_array = JavaWrapper._new_java_array(bool_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), bool_list) + # test array of Java DenseVectors + v1 = DenseVector([0.0, 1.0]) + v2 = DenseVector([1.0, 0.0]) + vec_java_list = [_py2java(self.sc, v1), _py2java(self.sc, v2)] + java_class = self.sc._gateway.jvm.org.apache.spark.ml.linalg.DenseVector + java_array = JavaWrapper._new_java_array(vec_java_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), [v1, v2]) + # test empty array + java_class = self.sc._gateway.jvm.java.lang.Integer + java_array = JavaWrapper._new_java_array([], java_class) + self.assertEqual(_java2py(self.sc, java_array), []) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_wrapper import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/testing/mlutils.py b/python/pyspark/testing/mlutils.py new file mode 100644 index 0000000000000..12bf650a28ee1 --- /dev/null +++ b/python/pyspark/testing/mlutils.py @@ -0,0 +1,161 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np + +from pyspark.ml import Estimator, Model, Transformer, UnaryTransformer +from pyspark.ml.param import Param, Params, TypeConverters +from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable +from pyspark.ml.wrapper import _java2py +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.types import DoubleType +from pyspark.testing.utils import ReusedPySparkTestCase as PySparkTestCase + + +def check_params(test_self, py_stage, check_params_exist=True): + """ + Checks common requirements for Params.params: + - set of params exist in Java and Python and are ordered by names + - param parent has the same UID as the object's UID + - default param value from Java matches value in Python + - optionally check if all params from Java also exist in Python + """ + py_stage_str = "%s %s" % (type(py_stage), py_stage) + if not hasattr(py_stage, "_to_java"): + return + java_stage = py_stage._to_java() + if java_stage is None: + return + test_self.assertEqual(py_stage.uid, java_stage.uid(), msg=py_stage_str) + if check_params_exist: + param_names = [p.name for p in py_stage.params] + java_params = list(java_stage.params()) + java_param_names = [jp.name() for jp in java_params] + test_self.assertEqual( + param_names, sorted(java_param_names), + "Param list in Python does not match Java for %s:\nJava = %s\nPython = %s" + % (py_stage_str, java_param_names, param_names)) + for p in py_stage.params: + test_self.assertEqual(p.parent, py_stage.uid) + java_param = java_stage.getParam(p.name) + py_has_default = py_stage.hasDefault(p) + java_has_default = java_stage.hasDefault(java_param) + test_self.assertEqual(py_has_default, java_has_default, + "Default value mismatch of param %s for Params %s" + % (p.name, str(py_stage))) + if py_has_default: + if p.name == "seed": + continue # Random seeds between Spark and PySpark are different + java_default = _java2py(test_self.sc, + java_stage.clear(java_param).getOrDefault(java_param)) + py_stage._clear(p) + py_default = py_stage.getOrDefault(p) + # equality test for NaN is always False + if isinstance(java_default, float) and np.isnan(java_default): + java_default = "NaN" + py_default = "NaN" if np.isnan(py_default) else "not NaN" + test_self.assertEqual( + java_default, py_default, + "Java default %s != python default %s of param %s for Params %s" + % (str(java_default), str(py_default), p.name, str(py_stage))) + + +class SparkSessionTestCase(PySparkTestCase): + @classmethod + def setUpClass(cls): + PySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + + @classmethod + def tearDownClass(cls): + PySparkTestCase.tearDownClass() + cls.spark.stop() + + +class MockDataset(DataFrame): + + def __init__(self): + self.index = 0 + + +class HasFake(Params): + + def __init__(self): + super(HasFake, self).__init__() + self.fake = Param(self, "fake", "fake param") + + def getFake(self): + return self.getOrDefault(self.fake) + + +class MockTransformer(Transformer, HasFake): + + def __init__(self): + super(MockTransformer, self).__init__() + self.dataset_index = None + + def _transform(self, dataset): + self.dataset_index = dataset.index + dataset.index += 1 + return dataset + + +class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable): + + shift = Param(Params._dummy(), "shift", "The amount by which to shift " + + "data in a DataFrame", + typeConverter=TypeConverters.toFloat) + + def __init__(self, shiftVal=1): + super(MockUnaryTransformer, self).__init__() + self._setDefault(shift=1) + self._set(shift=shiftVal) + + def getShift(self): + return self.getOrDefault(self.shift) + + def setShift(self, shift): + self._set(shift=shift) + + def createTransformFunc(self): + shiftVal = self.getShift() + return lambda x: x + shiftVal + + def outputDataType(self): + return DoubleType() + + def validateInputType(self, inputType): + if inputType != DoubleType(): + raise TypeError("Bad input type: {}. ".format(inputType) + + "Requires Double.") + + +class MockEstimator(Estimator, HasFake): + + def __init__(self): + super(MockEstimator, self).__init__() + self.dataset_index = None + + def _fit(self, dataset): + self.dataset_index = dataset.index + model = MockModel() + self._copyValues(model) + return model + + +class MockModel(MockTransformer, Model, HasFake): + pass From bbbdaa82a4f4fc7a84be6641518264d9bb7bde2b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 19 Nov 2018 09:22:32 +0800 Subject: [PATCH 2084/2461] [SPARK-26105][PYTHON] Clean unittest2 imports up that were added for Python 2.6 before ## What changes were proposed in this pull request? Currently, some of PySpark tests sill assume the tests could be ran in Python 2.6 by importing `unittest2`. For instance: ```python if sys.version_info[:2] <= (2, 6): try: import unittest2 as unittest except ImportError: sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') sys.exit(1) else: import unittest ``` While I am here, I removed some of unused imports and reordered imports per PEP 8. We officially dropped Python 2.6 support a while ago and started to discuss about Python 2 drop. It's better to remove them out. ## How was this patch tested? Manually tests, and existing tests via Jenkins. Closes #23077 from HyukjinKwon/SPARK-26105. Lead-authored-by: hyukjinkwon Co-authored-by: Bryan Cutler Signed-off-by: hyukjinkwon --- python/pyspark/ml/tests/test_algorithms.py | 11 +---------- python/pyspark/ml/tests/test_base.py | 10 +--------- python/pyspark/ml/tests/test_evaluation.py | 10 +--------- python/pyspark/ml/tests/test_feature.py | 9 +-------- python/pyspark/ml/tests/test_image.py | 10 +--------- python/pyspark/ml/tests/test_linalg.py | 12 ++---------- python/pyspark/ml/tests/test_param.py | 16 +++++----------- python/pyspark/ml/tests/test_persistence.py | 10 +--------- python/pyspark/ml/tests/test_pipeline.py | 10 +--------- python/pyspark/ml/tests/test_stat.py | 10 +--------- python/pyspark/ml/tests/test_training_summary.py | 9 +-------- python/pyspark/ml/tests/test_tuning.py | 10 +--------- python/pyspark/ml/tests/test_wrapper.py | 10 +--------- python/pyspark/mllib/tests/test_algorithms.py | 13 +------------ python/pyspark/mllib/tests/test_feature.py | 11 +---------- python/pyspark/mllib/tests/test_linalg.py | 11 +---------- python/pyspark/mllib/tests/test_stat.py | 11 +---------- .../mllib/tests/test_streaming_algorithms.py | 11 +---------- python/pyspark/mllib/tests/test_util.py | 15 ++------------- python/pyspark/testing/mllibutils.py | 11 +---------- 20 files changed, 26 insertions(+), 194 deletions(-) diff --git a/python/pyspark/ml/tests/test_algorithms.py b/python/pyspark/ml/tests/test_algorithms.py index 1a72e124962c8..516bb563402e0 100644 --- a/python/pyspark/ml/tests/test_algorithms.py +++ b/python/pyspark/ml/tests/test_algorithms.py @@ -16,17 +16,8 @@ # from shutil import rmtree -import sys import tempfile - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest import numpy as np diff --git a/python/pyspark/ml/tests/test_base.py b/python/pyspark/ml/tests/test_base.py index 59c45f638dd45..31e3deb53046c 100644 --- a/python/pyspark/ml/tests/test_base.py +++ b/python/pyspark/ml/tests/test_base.py @@ -15,15 +15,7 @@ # limitations under the License. # -import sys -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest from pyspark.sql.types import DoubleType, IntegerType from pyspark.testing.mlutils import MockDataset, MockEstimator, MockUnaryTransformer, \ diff --git a/python/pyspark/ml/tests/test_evaluation.py b/python/pyspark/ml/tests/test_evaluation.py index 6c3e5c6734509..5438455a6f756 100644 --- a/python/pyspark/ml/tests/test_evaluation.py +++ b/python/pyspark/ml/tests/test_evaluation.py @@ -15,15 +15,7 @@ # limitations under the License. # -import sys -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest import numpy as np diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py index 23f66e73b4820..325feaba66957 100644 --- a/python/pyspark/ml/tests/test_feature.py +++ b/python/pyspark/ml/tests/test_feature.py @@ -17,14 +17,7 @@ # import sys -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest if sys.version > '3': basestring = str diff --git a/python/pyspark/ml/tests/test_image.py b/python/pyspark/ml/tests/test_image.py index dcc7a32c9fd70..4c280a4a67894 100644 --- a/python/pyspark/ml/tests/test_image.py +++ b/python/pyspark/ml/tests/test_image.py @@ -14,15 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import sys -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest import py4j diff --git a/python/pyspark/ml/tests/test_linalg.py b/python/pyspark/ml/tests/test_linalg.py index 76e5386e86125..71cad5d7f5ad7 100644 --- a/python/pyspark/ml/tests/test_linalg.py +++ b/python/pyspark/ml/tests/test_linalg.py @@ -15,17 +15,9 @@ # limitations under the License. # -import sys -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - +import unittest import array as pyarray + from numpy import arange, array, array_equal, inf, ones, tile, zeros from pyspark.ml.linalg import DenseMatrix, DenseVector, MatrixUDT, SparseMatrix, SparseVector, \ diff --git a/python/pyspark/ml/tests/test_param.py b/python/pyspark/ml/tests/test_param.py index 1f36d4544ab92..17c1b0bf65dde 100644 --- a/python/pyspark/ml/tests/test_param.py +++ b/python/pyspark/ml/tests/test_param.py @@ -19,17 +19,7 @@ import inspect import sys import array as pyarray -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - -if sys.version > '3': - xrange = range +import unittest import numpy as np @@ -45,6 +35,10 @@ from pyspark.testing.mlutils import check_params, PySparkTestCase, SparkSessionTestCase +if sys.version > '3': + xrange = range + + class ParamTypeConversionTests(PySparkTestCase): """ Test that param type conversion happens. diff --git a/python/pyspark/ml/tests/test_persistence.py b/python/pyspark/ml/tests/test_persistence.py index b5a2e16df5532..34d687039ab34 100644 --- a/python/pyspark/ml/tests/test_persistence.py +++ b/python/pyspark/ml/tests/test_persistence.py @@ -17,16 +17,8 @@ import json from shutil import rmtree -import sys import tempfile -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest from pyspark.ml import Transformer from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, OneVsRest, \ diff --git a/python/pyspark/ml/tests/test_pipeline.py b/python/pyspark/ml/tests/test_pipeline.py index 31ef02c2e601f..9e3e6c4a75d7a 100644 --- a/python/pyspark/ml/tests/test_pipeline.py +++ b/python/pyspark/ml/tests/test_pipeline.py @@ -14,15 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import sys -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest from pyspark.ml.pipeline import Pipeline from pyspark.testing.mlutils import MockDataset, MockEstimator, MockTransformer, PySparkTestCase diff --git a/python/pyspark/ml/tests/test_stat.py b/python/pyspark/ml/tests/test_stat.py index bdc4853bc05c2..11aaf2e8083e1 100644 --- a/python/pyspark/ml/tests/test_stat.py +++ b/python/pyspark/ml/tests/test_stat.py @@ -15,15 +15,7 @@ # limitations under the License. # -import sys -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest from pyspark.ml.linalg import Vectors from pyspark.ml.stat import ChiSquareTest diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py index d5464f7be6372..8575111c84025 100644 --- a/python/pyspark/ml/tests/test_training_summary.py +++ b/python/pyspark/ml/tests/test_training_summary.py @@ -16,14 +16,7 @@ # import sys -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest if sys.version > '3': basestring = str diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py index af00d1de7ab6a..39bb921aaf43d 100644 --- a/python/pyspark/ml/tests/test_tuning.py +++ b/python/pyspark/ml/tests/test_tuning.py @@ -15,16 +15,8 @@ # limitations under the License. # -import sys import tempfile -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest from pyspark.ml import Estimator, Model from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel, OneVsRest diff --git a/python/pyspark/ml/tests/test_wrapper.py b/python/pyspark/ml/tests/test_wrapper.py index 4326d8e060dd7..ae672a00c1dc1 100644 --- a/python/pyspark/ml/tests/test_wrapper.py +++ b/python/pyspark/ml/tests/test_wrapper.py @@ -15,15 +15,7 @@ # limitations under the License. # -import sys -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest import py4j diff --git a/python/pyspark/mllib/tests/test_algorithms.py b/python/pyspark/mllib/tests/test_algorithms.py index 8a3454144a115..cc3b64b1cb284 100644 --- a/python/pyspark/mllib/tests/test_algorithms.py +++ b/python/pyspark/mllib/tests/test_algorithms.py @@ -16,27 +16,16 @@ # import os -import sys import tempfile from shutil import rmtree +import unittest from numpy import array, array_equal - from py4j.protocol import Py4JJavaError -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - from pyspark.mllib.fpm import FPGrowth from pyspark.mllib.recommendation import Rating from pyspark.mllib.regression import LabeledPoint -from pyspark.sql.utils import IllegalArgumentException from pyspark.testing.mllibutils import make_serializer, MLlibTestCase diff --git a/python/pyspark/mllib/tests/test_feature.py b/python/pyspark/mllib/tests/test_feature.py index 48ed810fa6fcb..3da841c408558 100644 --- a/python/pyspark/mllib/tests/test_feature.py +++ b/python/pyspark/mllib/tests/test_feature.py @@ -15,20 +15,11 @@ # limitations under the License. # -import sys from math import sqrt +import unittest from numpy import array, random, exp, abs, tile -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, Vectors from pyspark.mllib.linalg.distributed import RowMatrix from pyspark.mllib.feature import HashingTF, IDF, StandardScaler, ElementwiseProduct, Word2Vec diff --git a/python/pyspark/mllib/tests/test_linalg.py b/python/pyspark/mllib/tests/test_linalg.py index 550e32a9af024..d0ebd9bc3db79 100644 --- a/python/pyspark/mllib/tests/test_linalg.py +++ b/python/pyspark/mllib/tests/test_linalg.py @@ -17,18 +17,9 @@ import sys import array as pyarray +import unittest from numpy import array, array_equal, zeros, arange, tile, ones, inf -from numpy import sum as array_sum - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest import pyspark.ml.linalg as newlinalg from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector, \ diff --git a/python/pyspark/mllib/tests/test_stat.py b/python/pyspark/mllib/tests/test_stat.py index 5e74087d8fa7b..f23ae291d317a 100644 --- a/python/pyspark/mllib/tests/test_stat.py +++ b/python/pyspark/mllib/tests/test_stat.py @@ -15,20 +15,11 @@ # limitations under the License. # -import sys import array as pyarray +import unittest from numpy import array -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector, \ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT from pyspark.mllib.random import RandomRDDs diff --git a/python/pyspark/mllib/tests/test_streaming_algorithms.py b/python/pyspark/mllib/tests/test_streaming_algorithms.py index ba95855fd4f00..4bc8904acd31c 100644 --- a/python/pyspark/mllib/tests/test_streaming_algorithms.py +++ b/python/pyspark/mllib/tests/test_streaming_algorithms.py @@ -15,21 +15,12 @@ # limitations under the License. # -import sys from time import time, sleep +import unittest from numpy import array, random, exp, dot, all, mean, abs from numpy import sum as array_sum -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - from pyspark import SparkContext from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD diff --git a/python/pyspark/mllib/tests/test_util.py b/python/pyspark/mllib/tests/test_util.py index c924eba80484c..e95716278f122 100644 --- a/python/pyspark/mllib/tests/test_util.py +++ b/python/pyspark/mllib/tests/test_util.py @@ -16,25 +16,14 @@ # import os -import sys import tempfile - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest from pyspark.mllib.common import _to_java_object_rdd from pyspark.mllib.util import LinearDataGenerator from pyspark.mllib.util import MLUtils -from pyspark.mllib.linalg import SparseVector, DenseVector, SparseMatrix, Vectors +from pyspark.mllib.linalg import SparseVector, DenseVector, Vectors from pyspark.mllib.random import RandomRDDs -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.stat import Statistics from pyspark.testing.mllibutils import MLlibTestCase diff --git a/python/pyspark/testing/mllibutils.py b/python/pyspark/testing/mllibutils.py index 9248182658f84..25f1bba8d37ac 100644 --- a/python/pyspark/testing/mllibutils.py +++ b/python/pyspark/testing/mllibutils.py @@ -15,16 +15,7 @@ # limitations under the License. # -import sys - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest +import unittest from pyspark import SparkContext from pyspark.serializers import PickleSerializer From 630e25e35506c02a0b1e202ef82b1b0f69e50966 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 19 Nov 2018 08:06:33 -0600 Subject: [PATCH 2085/2461] [SPARK-26026][BUILD] Published Scaladoc jars missing from Maven Central ## What changes were proposed in this pull request? This restores scaladoc artifact generation, which got dropped with the Scala 2.12 update. The change looks large, but is almost all due to needing to make the InterfaceStability annotations top-level classes (i.e. `InterfaceStability.Stable` -> `Stable`), unfortunately. A few inner class references had to be qualified too. Lots of scaladoc warnings now reappear. We can choose to disable generation by default and enable for releases, later. ## How was this patch tested? N/A; build runs scaladoc now. Closes #23069 from srowen/SPARK-26026. Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../network/protocol/ChunkFetchFailure.java | 2 +- .../network/protocol/ChunkFetchRequest.java | 2 +- .../network/protocol/ChunkFetchSuccess.java | 2 +- .../spark/network/protocol/OneWayMessage.java | 2 +- .../spark/network/protocol/RpcFailure.java | 2 +- .../spark/network/protocol/RpcRequest.java | 2 +- .../spark/network/protocol/RpcResponse.java | 2 +- .../spark/network/protocol/StreamFailure.java | 2 +- .../spark/network/protocol/StreamRequest.java | 2 +- .../network/protocol/StreamResponse.java | 2 +- .../spark/network/protocol/UploadStream.java | 2 +- .../spark/network/sasl/SaslMessage.java | 3 +- .../network/shuffle/RetryingBlockFetcher.java | 2 +- .../org/apache/spark/annotation/Evolving.java | 30 +++++++++ .../spark/annotation/InterfaceStability.java | 58 ----------------- .../org/apache/spark/annotation/Stable.java | 31 ++++++++++ .../org/apache/spark/annotation/Unstable.java | 30 +++++++++ .../kinesis/KinesisInputDStream.scala | 6 +- .../kinesis/SparkAWSCredentials.scala | 9 ++- .../spark/launcher/AbstractAppHandle.java | 12 ++-- .../org/apache/spark/ml/util/ReadWrite.scala | 10 +-- pom.xml | 8 ++- .../java/org/apache/spark/sql/RowFactory.java | 4 +- .../execution/UnsafeExternalRowSorter.java | 10 +-- .../sql/streaming/GroupStateTimeout.java | 4 +- .../spark/sql/streaming/OutputMode.java | 4 +- .../org/apache/spark/sql/types/DataTypes.java | 4 +- .../spark/sql/types/SQLUserDefinedType.java | 4 +- .../apache/spark/sql/AnalysisException.scala | 5 +- .../scala/org/apache/spark/sql/Encoder.scala | 5 +- .../scala/org/apache/spark/sql/Encoders.scala | 4 +- .../main/scala/org/apache/spark/sql/Row.scala | 6 +- .../spark/sql/types/AbstractDataType.scala | 4 +- .../apache/spark/sql/types/ArrayType.scala | 6 +- .../apache/spark/sql/types/BinaryType.scala | 7 +-- .../apache/spark/sql/types/BooleanType.scala | 7 +-- .../org/apache/spark/sql/types/ByteType.scala | 6 +- .../sql/types/CalendarIntervalType.scala | 6 +- .../org/apache/spark/sql/types/DataType.scala | 6 +- .../org/apache/spark/sql/types/DateType.scala | 6 +- .../org/apache/spark/sql/types/Decimal.scala | 6 +- .../apache/spark/sql/types/DecimalType.scala | 7 +-- .../apache/spark/sql/types/DoubleType.scala | 6 +- .../apache/spark/sql/types/FloatType.scala | 6 +- .../apache/spark/sql/types/IntegerType.scala | 6 +- .../org/apache/spark/sql/types/LongType.scala | 6 +- .../org/apache/spark/sql/types/MapType.scala | 6 +- .../org/apache/spark/sql/types/Metadata.scala | 8 +-- .../org/apache/spark/sql/types/NullType.scala | 7 +-- .../apache/spark/sql/types/ObjectType.scala | 6 +- .../apache/spark/sql/types/ShortType.scala | 6 +- .../apache/spark/sql/types/StringType.scala | 6 +- .../apache/spark/sql/types/StructField.scala | 4 +- .../apache/spark/sql/types/StructType.scala | 8 +-- .../spark/sql/types/TimestampType.scala | 6 +- .../FlatMapGroupsWithStateFunction.java | 4 +- .../function/MapGroupsWithStateFunction.java | 4 +- .../java/org/apache/spark/sql/SaveMode.java | 4 +- .../org/apache/spark/sql/api/java/UDF0.java | 4 +- .../org/apache/spark/sql/api/java/UDF1.java | 4 +- .../org/apache/spark/sql/api/java/UDF10.java | 4 +- .../org/apache/spark/sql/api/java/UDF11.java | 4 +- .../org/apache/spark/sql/api/java/UDF12.java | 4 +- .../org/apache/spark/sql/api/java/UDF13.java | 4 +- .../org/apache/spark/sql/api/java/UDF14.java | 4 +- .../org/apache/spark/sql/api/java/UDF15.java | 4 +- .../org/apache/spark/sql/api/java/UDF16.java | 4 +- .../org/apache/spark/sql/api/java/UDF17.java | 4 +- .../org/apache/spark/sql/api/java/UDF18.java | 4 +- .../org/apache/spark/sql/api/java/UDF19.java | 4 +- .../org/apache/spark/sql/api/java/UDF2.java | 4 +- .../org/apache/spark/sql/api/java/UDF20.java | 4 +- .../org/apache/spark/sql/api/java/UDF21.java | 4 +- .../org/apache/spark/sql/api/java/UDF22.java | 4 +- .../org/apache/spark/sql/api/java/UDF3.java | 4 +- .../org/apache/spark/sql/api/java/UDF4.java | 4 +- .../org/apache/spark/sql/api/java/UDF5.java | 4 +- .../org/apache/spark/sql/api/java/UDF6.java | 4 +- .../org/apache/spark/sql/api/java/UDF7.java | 4 +- .../org/apache/spark/sql/api/java/UDF8.java | 4 +- .../org/apache/spark/sql/api/java/UDF9.java | 4 +- ...emaColumnConvertNotSupportedException.java | 4 +- .../spark/sql/expressions/javalang/typed.java | 4 +- .../sources/v2/BatchReadSupportProvider.java | 4 +- .../sources/v2/BatchWriteSupportProvider.java | 4 +- .../v2/ContinuousReadSupportProvider.java | 4 +- .../sql/sources/v2/DataSourceOptions.java | 4 +- .../spark/sql/sources/v2/DataSourceV2.java | 4 +- .../v2/MicroBatchReadSupportProvider.java | 4 +- .../sql/sources/v2/SessionConfigSupport.java | 4 +- .../v2/StreamingWriteSupportProvider.java | 4 +- .../sources/v2/reader/BatchReadSupport.java | 4 +- .../sql/sources/v2/reader/InputPartition.java | 4 +- .../sources/v2/reader/PartitionReader.java | 4 +- .../v2/reader/PartitionReaderFactory.java | 4 +- .../sql/sources/v2/reader/ReadSupport.java | 4 +- .../sql/sources/v2/reader/ScanConfig.java | 4 +- .../sources/v2/reader/ScanConfigBuilder.java | 4 +- .../sql/sources/v2/reader/Statistics.java | 4 +- .../v2/reader/SupportsPushDownFilters.java | 4 +- .../SupportsPushDownRequiredColumns.java | 4 +- .../v2/reader/SupportsReportPartitioning.java | 4 +- .../v2/reader/SupportsReportStatistics.java | 4 +- .../partitioning/ClusteredDistribution.java | 4 +- .../v2/reader/partitioning/Distribution.java | 4 +- .../v2/reader/partitioning/Partitioning.java | 4 +- .../streaming/ContinuousPartitionReader.java | 4 +- .../ContinuousPartitionReaderFactory.java | 4 +- .../streaming/ContinuousReadSupport.java | 4 +- .../streaming/MicroBatchReadSupport.java | 4 +- .../sources/v2/reader/streaming/Offset.java | 4 +- .../v2/reader/streaming/PartitionOffset.java | 4 +- .../sources/v2/writer/BatchWriteSupport.java | 4 +- .../sql/sources/v2/writer/DataWriter.java | 4 +- .../sources/v2/writer/DataWriterFactory.java | 4 +- .../v2/writer/WriterCommitMessage.java | 4 +- .../streaming/StreamingDataWriterFactory.java | 4 +- .../streaming/StreamingWriteSupport.java | 4 +- .../apache/spark/sql/streaming/Trigger.java | 4 +- .../sql/vectorized/ArrowColumnVector.java | 4 +- .../spark/sql/vectorized/ColumnVector.java | 4 +- .../spark/sql/vectorized/ColumnarArray.java | 4 +- .../spark/sql/vectorized/ColumnarBatch.java | 4 +- .../spark/sql/vectorized/ColumnarRow.java | 4 +- .../scala/org/apache/spark/sql/Column.scala | 8 +-- .../spark/sql/DataFrameNaFunctions.scala | 5 +- .../apache/spark/sql/DataFrameReader.scala | 4 +- .../spark/sql/DataFrameStatFunctions.scala | 4 +- .../apache/spark/sql/DataFrameWriter.scala | 4 +- .../scala/org/apache/spark/sql/Dataset.scala | 62 +++++++++---------- .../org/apache/spark/sql/DatasetHolder.scala | 4 +- .../spark/sql/ExperimentalMethods.scala | 4 +- .../org/apache/spark/sql/ForeachWriter.scala | 4 +- .../spark/sql/KeyValueGroupedDataset.scala | 16 ++--- .../spark/sql/RelationalGroupedDataset.scala | 4 +- .../org/apache/spark/sql/RuntimeConfig.scala | 5 +- .../org/apache/spark/sql/SQLContext.scala | 34 +++++----- .../org/apache/spark/sql/SQLImplicits.scala | 4 +- .../org/apache/spark/sql/SparkSession.scala | 48 +++++++------- .../spark/sql/SparkSessionExtensions.scala | 4 +- .../apache/spark/sql/UDFRegistration.scala | 4 +- .../apache/spark/sql/catalog/Catalog.scala | 16 ++--- .../apache/spark/sql/catalog/interface.scala | 10 +-- .../sql/execution/streaming/Triggers.scala | 4 +- .../continuous/ContinuousTrigger.scala | 6 +- .../spark/sql/expressions/Aggregator.scala | 6 +- .../sql/expressions/UserDefinedFunction.scala | 4 +- .../apache/spark/sql/expressions/Window.scala | 6 +- .../spark/sql/expressions/WindowSpec.scala | 4 +- .../sql/expressions/scalalang/typed.scala | 4 +- .../apache/spark/sql/expressions/udaf.scala | 6 +- .../org/apache/spark/sql/functions.scala | 4 +- .../internal/BaseSessionStateBuilder.scala | 4 +- .../spark/sql/internal/SessionState.scala | 6 +- .../apache/spark/sql/jdbc/JdbcDialects.scala | 9 ++- .../scala/org/apache/spark/sql/package.scala | 4 +- .../apache/spark/sql/sources/filters.scala | 34 +++++----- .../apache/spark/sql/sources/interfaces.scala | 26 ++++---- .../sql/streaming/DataStreamReader.scala | 4 +- .../sql/streaming/DataStreamWriter.scala | 8 +-- .../spark/sql/streaming/GroupState.scala | 5 +- .../spark/sql/streaming/ProcessingTime.scala | 6 +- .../spark/sql/streaming/StreamingQuery.scala | 4 +- .../streaming/StreamingQueryException.scala | 4 +- .../streaming/StreamingQueryListener.scala | 14 ++--- .../sql/streaming/StreamingQueryManager.scala | 4 +- .../sql/streaming/StreamingQueryStatus.scala | 4 +- .../apache/spark/sql/streaming/progress.scala | 10 +-- .../sql/util/QueryExecutionListener.scala | 6 +- .../hive/service/cli/thrift/TColumn.java | 2 +- .../hive/service/cli/thrift/TColumnValue.java | 2 +- .../service/cli/thrift/TGetInfoValue.java | 2 +- .../hive/service/cli/thrift/TTypeEntry.java | 2 +- .../cli/thrift/TTypeQualifierValue.java | 2 +- .../apache/hive/service/AbstractService.java | 8 +-- .../apache/hive/service/FilterService.java | 2 +- .../sql/hive/HiveSessionStateBuilder.scala | 4 +- 177 files changed, 590 insertions(+), 563 deletions(-) create mode 100644 common/tags/src/main/java/org/apache/spark/annotation/Evolving.java delete mode 100644 common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java create mode 100644 common/tags/src/main/java/org/apache/spark/annotation/Stable.java create mode 100644 common/tags/src/main/java/org/apache/spark/annotation/Unstable.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java index 7b28a9a969486..a7afbfa8621c8 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java @@ -33,7 +33,7 @@ public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) { } @Override - public Type type() { return Type.ChunkFetchFailure; } + public Message.Type type() { return Type.ChunkFetchFailure; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java index 26d063feb5fe3..fe54fcc50dc86 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java @@ -32,7 +32,7 @@ public ChunkFetchRequest(StreamChunkId streamChunkId) { } @Override - public Type type() { return Type.ChunkFetchRequest; } + public Message.Type type() { return Type.ChunkFetchRequest; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java index 94c2ac9b20e43..d5c9a9b3202fb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java @@ -39,7 +39,7 @@ public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) { } @Override - public Type type() { return Type.ChunkFetchSuccess; } + public Message.Type type() { return Type.ChunkFetchSuccess; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java index f7ffb1bd49bb6..1632fb9e03687 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java @@ -34,7 +34,7 @@ public OneWayMessage(ManagedBuffer body) { } @Override - public Type type() { return Type.OneWayMessage; } + public Message.Type type() { return Type.OneWayMessage; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java index a76624ef5dc96..61061903de23f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java @@ -31,7 +31,7 @@ public RpcFailure(long requestId, String errorString) { } @Override - public Type type() { return Type.RpcFailure; } + public Message.Type type() { return Type.RpcFailure; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java index 2b30920f0598d..cc1bb95d2d566 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java @@ -38,7 +38,7 @@ public RpcRequest(long requestId, ManagedBuffer message) { } @Override - public Type type() { return Type.RpcRequest; } + public Message.Type type() { return Type.RpcRequest; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java index d73014ecd8506..c03291e9c0b23 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java @@ -33,7 +33,7 @@ public RpcResponse(long requestId, ManagedBuffer message) { } @Override - public Type type() { return Type.RpcResponse; } + public Message.Type type() { return Type.RpcResponse; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java index 258ef81c6783d..68fcfa7748611 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java @@ -33,7 +33,7 @@ public StreamFailure(String streamId, String error) { } @Override - public Type type() { return Type.StreamFailure; } + public Message.Type type() { return Type.StreamFailure; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java index dc183c043ed9a..1b135af752bd8 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java @@ -34,7 +34,7 @@ public StreamRequest(String streamId) { } @Override - public Type type() { return Type.StreamRequest; } + public Message.Type type() { return Type.StreamRequest; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java index 50b811604b84b..568108c4fe5e8 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java @@ -40,7 +40,7 @@ public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) { } @Override - public Type type() { return Type.StreamResponse; } + public Message.Type type() { return Type.StreamResponse; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java index fa1d26e76b852..7d21151e01074 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java @@ -52,7 +52,7 @@ private UploadStream(long requestId, ManagedBuffer meta, long bodyByteCount) { } @Override - public Type type() { return Type.UploadStream; } + public Message.Type type() { return Type.UploadStream; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java index 7331c2b481fb1..1b03300d948e2 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -23,6 +23,7 @@ import org.apache.spark.network.buffer.NettyManagedBuffer; import org.apache.spark.network.protocol.Encoders; import org.apache.spark.network.protocol.AbstractMessage; +import org.apache.spark.network.protocol.Message; /** * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged @@ -46,7 +47,7 @@ class SaslMessage extends AbstractMessage { } @Override - public Type type() { return Type.User; } + public Message.Type type() { return Type.User; } @Override public int encodedLength() { diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java index f309dda8afca6..6bf3da94030d4 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java @@ -101,7 +101,7 @@ void createAndStart(String[] blockIds, BlockFetchingListener listener) public RetryingBlockFetcher( TransportConf conf, - BlockFetchStarter fetchStarter, + RetryingBlockFetcher.BlockFetchStarter fetchStarter, String[] blockIds, BlockFetchingListener listener) { this.fetchStarter = fetchStarter; diff --git a/common/tags/src/main/java/org/apache/spark/annotation/Evolving.java b/common/tags/src/main/java/org/apache/spark/annotation/Evolving.java new file mode 100644 index 0000000000000..87e8948f204ff --- /dev/null +++ b/common/tags/src/main/java/org/apache/spark/annotation/Evolving.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.annotation; + +import java.lang.annotation.*; + +/** + * APIs that are meant to evolve towards becoming stable APIs, but are not stable APIs yet. + * Evolving interfaces can change from one feature release to another release (i.e. 2.1 to 2.2). + */ +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, + ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) +public @interface Evolving {} diff --git a/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java b/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java deleted file mode 100644 index 02bcec737e80e..0000000000000 --- a/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.annotation; - -import java.lang.annotation.*; - -/** - * Annotation to inform users of how much to rely on a particular package, - * class or method not changing over time. - */ -public class InterfaceStability { - - /** - * Stable APIs that retain source and binary compatibility within a major release. - * These interfaces can change from one major release to another major release - * (e.g. from 1.0 to 2.0). - */ - @Documented - @Retention(RetentionPolicy.RUNTIME) - @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, - ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) - public @interface Stable {}; - - /** - * APIs that are meant to evolve towards becoming stable APIs, but are not stable APIs yet. - * Evolving interfaces can change from one feature release to another release (i.e. 2.1 to 2.2). - */ - @Documented - @Retention(RetentionPolicy.RUNTIME) - @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, - ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) - public @interface Evolving {}; - - /** - * Unstable APIs, with no guarantee on stability. - * Classes that are unannotated are considered Unstable. - */ - @Documented - @Retention(RetentionPolicy.RUNTIME) - @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, - ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) - public @interface Unstable {}; -} diff --git a/common/tags/src/main/java/org/apache/spark/annotation/Stable.java b/common/tags/src/main/java/org/apache/spark/annotation/Stable.java new file mode 100644 index 0000000000000..b198bfbe91e10 --- /dev/null +++ b/common/tags/src/main/java/org/apache/spark/annotation/Stable.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.annotation; + +import java.lang.annotation.*; + +/** + * Stable APIs that retain source and binary compatibility within a major release. + * These interfaces can change from one major release to another major release + * (e.g. from 1.0 to 2.0). + */ +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, + ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) +public @interface Stable {} diff --git a/common/tags/src/main/java/org/apache/spark/annotation/Unstable.java b/common/tags/src/main/java/org/apache/spark/annotation/Unstable.java new file mode 100644 index 0000000000000..88ee72125b23f --- /dev/null +++ b/common/tags/src/main/java/org/apache/spark/annotation/Unstable.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.annotation; + +import java.lang.annotation.*; + +/** + * Unstable APIs, with no guarantee on stability. + * Classes that are unannotated are considered Unstable. + */ +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, + ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) +public @interface Unstable {} diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index 1ffec01df9f00..d4a428f45c110 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.Record -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.rdd.RDD import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming.{Duration, StreamingContext, Time} @@ -84,14 +84,14 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( } } -@InterfaceStability.Evolving +@Evolving object KinesisInputDStream { /** * Builder for [[KinesisInputDStream]] instances. * * @since 2.2.0 */ - @InterfaceStability.Evolving + @Evolving class Builder { // Required params private var streamingContext: Option[StreamingContext] = None diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala index 9facfe8ff2b0f..dcb60b21d9851 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala @@ -14,13 +14,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.streaming.kinesis -import scala.collection.JavaConverters._ +package org.apache.spark.streaming.kinesis import com.amazonaws.auth._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.internal.Logging /** @@ -84,14 +83,14 @@ private[kinesis] final case class STSCredentials( } } -@InterfaceStability.Evolving +@Evolving object SparkAWSCredentials { /** * Builder for [[SparkAWSCredentials]] instances. * * @since 2.2.0 */ - @InterfaceStability.Evolving + @Evolving class Builder { private var basicCreds: Option[BasicCredentials] = None private var stsCreds: Option[STSCredentials] = None diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java index 9cbebdaeb33d3..0999cbd216871 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java @@ -31,8 +31,8 @@ abstract class AbstractAppHandle implements SparkAppHandle { private final LauncherServer server; private LauncherServer.ServerConnection connection; - private List listeners; - private AtomicReference state; + private List listeners; + private AtomicReference state; private volatile String appId; private volatile boolean disposed; @@ -42,7 +42,7 @@ protected AbstractAppHandle(LauncherServer server) { } @Override - public synchronized void addListener(Listener l) { + public synchronized void addListener(SparkAppHandle.Listener l) { if (listeners == null) { listeners = new CopyOnWriteArrayList<>(); } @@ -50,7 +50,7 @@ public synchronized void addListener(Listener l) { } @Override - public State getState() { + public SparkAppHandle.State getState() { return state.get(); } @@ -120,11 +120,11 @@ synchronized void dispose() { } } - void setState(State s) { + void setState(SparkAppHandle.State s) { setState(s, false); } - void setState(State s, boolean force) { + void setState(SparkAppHandle.State s, boolean force) { if (force) { state.set(s); fireEvent(false); diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index d985f8ca1ecc7..fbc7be25a5640 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -31,7 +31,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} +import org.apache.spark.annotation.{DeveloperApi, Since, Unstable} import org.apache.spark.internal.Logging import org.apache.spark.ml._ import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel} @@ -84,7 +84,7 @@ private[util] sealed trait BaseReadWrite { * * @since 2.4.0 */ -@InterfaceStability.Unstable +@Unstable @Since("2.4.0") trait MLWriterFormat { /** @@ -108,7 +108,7 @@ trait MLWriterFormat { * * @since 2.4.0 */ -@InterfaceStability.Unstable +@Unstable @Since("2.4.0") trait MLFormatRegister extends MLWriterFormat { /** @@ -208,7 +208,7 @@ abstract class MLWriter extends BaseReadWrite with Logging { /** * A ML Writer which delegates based on the requested format. */ -@InterfaceStability.Unstable +@Unstable @Since("2.4.0") class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { private var source: String = "internal" @@ -291,7 +291,7 @@ trait MLWritable { * Trait for classes that provide `GeneralMLWriter`. */ @Since("2.4.0") -@InterfaceStability.Unstable +@Unstable trait GeneralMLWritable extends MLWritable { /** * Returns an `MLWriter` instance for this ML instance. diff --git a/pom.xml b/pom.xml index 59e3d0fa772b4..fcec295eee128 100644 --- a/pom.xml +++ b/pom.xml @@ -2016,7 +2016,6 @@ net.alchim31.maven scala-maven-plugin - 3.4.4 @@ -2037,6 +2036,13 @@ testCompile + + attach-scaladocs + verify + + doc-jar + + ${scala.version} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java index 2ce1fdcbf56ae..0258e66ffb6e5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java @@ -17,7 +17,7 @@ package org.apache.spark.sql; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; import org.apache.spark.sql.catalyst.expressions.GenericRow; /** @@ -25,7 +25,7 @@ * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable public class RowFactory { /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 1b2f5eee5ccdd..5395e4035e680 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -50,7 +50,7 @@ public final class UnsafeExternalRowSorter { private long numRowsInserted = 0; private final StructType schema; - private final PrefixComputer prefixComputer; + private final UnsafeExternalRowSorter.PrefixComputer prefixComputer; private final UnsafeExternalSorter sorter; public abstract static class PrefixComputer { @@ -74,7 +74,7 @@ public static UnsafeExternalRowSorter createWithRecordComparator( StructType schema, Supplier recordComparatorSupplier, PrefixComparator prefixComparator, - PrefixComputer prefixComputer, + UnsafeExternalRowSorter.PrefixComputer prefixComputer, long pageSizeBytes, boolean canUseRadixSort) throws IOException { return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator, @@ -85,7 +85,7 @@ public static UnsafeExternalRowSorter create( StructType schema, Ordering ordering, PrefixComparator prefixComparator, - PrefixComputer prefixComputer, + UnsafeExternalRowSorter.PrefixComputer prefixComputer, long pageSizeBytes, boolean canUseRadixSort) throws IOException { Supplier recordComparatorSupplier = @@ -98,9 +98,9 @@ private UnsafeExternalRowSorter( StructType schema, Supplier recordComparatorSupplier, PrefixComparator prefixComparator, - PrefixComputer prefixComputer, + UnsafeExternalRowSorter.PrefixComputer prefixComputer, long pageSizeBytes, - boolean canUseRadixSort) throws IOException { + boolean canUseRadixSort) { this.schema = schema; this.prefixComputer = prefixComputer; final SparkEnv sparkEnv = SparkEnv.get(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java index 5f1032d1229da..5f6a46f2b8e89 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java @@ -17,8 +17,8 @@ package org.apache.spark.sql.streaming; +import org.apache.spark.annotation.Evolving; import org.apache.spark.annotation.Experimental; -import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.plans.logical.*; /** @@ -29,7 +29,7 @@ * @since 2.2.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving public class GroupStateTimeout { /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java index 470c128ee6c3d..a3d72a1f5d49f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.streaming.InternalOutputModes; /** @@ -26,7 +26,7 @@ * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving public class OutputMode { /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java index 0f8570fe470bd..d786374f69e20 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java @@ -19,7 +19,7 @@ import java.util.*; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * To get/create specific data type, users should use singleton objects and factory methods @@ -27,7 +27,7 @@ * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable public class DataTypes { /** * Gets the StringType object. diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java index 1290614a3207d..a54398324fc66 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java @@ -20,7 +20,7 @@ import java.lang.annotation.*; import org.apache.spark.annotation.DeveloperApi; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * ::DeveloperApi:: @@ -31,7 +31,7 @@ @DeveloperApi @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) -@InterfaceStability.Evolving +@Evolving public @interface SQLUserDefinedType { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index 50ee6cd4085ea..f5c87677ab9eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -17,16 +17,15 @@ package org.apache.spark.sql -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan - /** * Thrown when a query fails to analyze, usually because the query itself is invalid. * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class AnalysisException protected[sql] ( val message: String, val line: Option[Int] = None, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 7b02317b8538f..9853a4fcc2f9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -20,10 +20,9 @@ package org.apache.spark.sql import scala.annotation.implicitNotFound import scala.reflect.ClassTag -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Evolving, Experimental} import org.apache.spark.sql.types._ - /** * :: Experimental :: * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. @@ -67,7 +66,7 @@ import org.apache.spark.sql.types._ * @since 1.6.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving @implicitNotFound("Unable to find encoder for type ${T}. An implicit Encoder[${T}] is needed to " + "store ${T} instances in a Dataset. Primitive types (Int, String, etc) and Product types (case " + "classes) are supported by importing spark.implicits._ Support for serializing other types " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index 8a30c81912fe9..42b865c027205 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -22,7 +22,7 @@ import java.lang.reflect.Modifier import scala.reflect.{classTag, ClassTag} import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Evolving, Experimental} import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{BoundReference, Cast} @@ -36,7 +36,7 @@ import org.apache.spark.sql.types._ * @since 1.6.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving object Encoders { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 180c2d130074e..e12bf9616e2de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ import scala.util.hashing.MurmurHash3 -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types.StructType /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable object Row { /** * This method can be used to extract fields from a [[Row]] object in a pattern match. Example: @@ -124,7 +124,7 @@ object Row { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait Row extends Serializable { /** Number of elements in the Row. */ def size: Int = length diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index c43cc748655e8..5367ce2af8e9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.types import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions.Expression /** @@ -134,7 +134,7 @@ object AtomicType { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable abstract class NumericType extends AtomicType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 58c75b5dc7a35..7465569868f07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -21,7 +21,7 @@ import scala.math.Ordering import org.json4s.JsonDSL._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.util.ArrayData /** @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.util.ArrayData * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable object ArrayType extends AbstractDataType { /** * Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. @@ -60,7 +60,7 @@ object ArrayType extends AbstractDataType { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { /** No-arg constructor for kryo. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala index 032d6b54aeb79..cc8b3e6e399a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala @@ -20,15 +20,14 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.util.TypeUtils - /** * The data type representing `Array[Byte]` values. * Please use the singleton `DataTypes.BinaryType`. */ -@InterfaceStability.Stable +@Stable class BinaryType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "BinaryType$" in byte code. @@ -55,5 +54,5 @@ class BinaryType private() extends AtomicType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object BinaryType extends BinaryType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala index 63f354d2243cf..5e3de71caa37e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala @@ -20,15 +20,14 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability - +import org.apache.spark.annotation.Stable /** * The data type representing `Boolean` values. Please use the singleton `DataTypes.BooleanType`. * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class BooleanType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "BooleanType$" in byte code. @@ -48,5 +47,5 @@ class BooleanType private() extends AtomicType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object BooleanType extends BooleanType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala index 5854c3f5ba116..9d400eefc0f8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.types import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * The data type representing `Byte` values. Please use the singleton `DataTypes.ByteType`. * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class ByteType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "ByteType$" in byte code. @@ -52,5 +52,5 @@ class ByteType private() extends IntegralType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object ByteType extends ByteType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala index 2342036a57460..8e297874a0d62 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.types -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * The data type representing calendar time intervals. The calendar time interval is stored @@ -29,7 +29,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 1.5.0 */ -@InterfaceStability.Stable +@Stable class CalendarIntervalType private() extends DataType { override def defaultSize: Int = 16 @@ -40,5 +40,5 @@ class CalendarIntervalType private() extends DataType { /** * @since 1.5.0 */ -@InterfaceStability.Stable +@Stable case object CalendarIntervalType extends CalendarIntervalType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 33fc4b9480126..c58f7a2397374 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -26,7 +26,7 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -38,7 +38,7 @@ import org.apache.spark.util.Utils * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable abstract class DataType extends AbstractDataType { /** * Enables matching against DataType for expressions: @@ -111,7 +111,7 @@ abstract class DataType extends AbstractDataType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable object DataType { private val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala index 9e70dd486a125..7491014b22dab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * A date type, supporting "0001-01-01" through "9999-12-31". @@ -31,7 +31,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class DateType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "DateType$" in byte code. @@ -53,5 +53,5 @@ class DateType private() extends AtomicType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object DateType extends DateType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 9eed2eb202045..a3a844670e0c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import java.lang.{Long => JLong} import java.math.{BigInteger, MathContext, RoundingMode} -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Unstable import org.apache.spark.sql.AnalysisException /** @@ -31,7 +31,7 @@ import org.apache.spark.sql.AnalysisException * - If decimalVal is set, it represents the whole decimal value * - Otherwise, the decimal value is longVal / (10 ** _scale) */ -@InterfaceStability.Unstable +@Unstable final class Decimal extends Ordered[Decimal] with Serializable { import org.apache.spark.sql.types.Decimal._ @@ -407,7 +407,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } -@InterfaceStability.Unstable +@Unstable object Decimal { val ROUND_HALF_UP = BigDecimal.RoundingMode.HALF_UP val ROUND_HALF_EVEN = BigDecimal.RoundingMode.HALF_EVEN diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 15004e4b9667d..25eddaf06a780 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -21,11 +21,10 @@ import java.util.Locale import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} - /** * The data type representing `java.math.BigDecimal` values. * A Decimal that must have fixed precision (the maximum number of digits) and scale (the number @@ -39,7 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class DecimalType(precision: Int, scale: Int) extends FractionalType { if (scale > precision) { @@ -110,7 +109,7 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable object DecimalType extends AbstractDataType { import scala.math.min diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala index a5c79ff01ca06..afd3353397019 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala @@ -21,7 +21,7 @@ import scala.math.{Fractional, Numeric, Ordering} import scala.math.Numeric.DoubleAsIfIntegral import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.util.Utils /** @@ -29,7 +29,7 @@ import org.apache.spark.util.Utils * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class DoubleType private() extends FractionalType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "DoubleType$" in byte code. @@ -54,5 +54,5 @@ class DoubleType private() extends FractionalType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object DoubleType extends DoubleType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala index 352147ec936c9..6d98987304081 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala @@ -21,7 +21,7 @@ import scala.math.{Fractional, Numeric, Ordering} import scala.math.Numeric.FloatAsIfIntegral import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.util.Utils /** @@ -29,7 +29,7 @@ import org.apache.spark.util.Utils * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class FloatType private() extends FractionalType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "FloatType$" in byte code. @@ -55,5 +55,5 @@ class FloatType private() extends FractionalType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object FloatType extends FloatType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala index a85e3729188d9..0755202d20df1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.types import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * The data type representing `Int` values. Please use the singleton `DataTypes.IntegerType`. * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class IntegerType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "IntegerType$" in byte code. @@ -51,5 +51,5 @@ class IntegerType private() extends IntegralType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object IntegerType extends IntegerType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala index 0997028fc1057..3c49c721fdc88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.types import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * The data type representing `Long` values. Please use the singleton `DataTypes.LongType`. * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class LongType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "LongType$" in byte code. @@ -51,5 +51,5 @@ class LongType private() extends IntegralType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object LongType extends LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 594e155268bf6..29b9ffc0c3549 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * The data type for Maps. Keys in a map are not allowed to have `null` values. @@ -31,7 +31,7 @@ import org.apache.spark.annotation.InterfaceStability * @param valueType The data type of map values. * @param valueContainsNull Indicates if map values have `null` values. */ -@InterfaceStability.Stable +@Stable case class MapType( keyType: DataType, valueType: DataType, @@ -78,7 +78,7 @@ case class MapType( /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable object MapType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 7c15dc0de4b6b..4979aced145c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** @@ -37,7 +37,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable sealed class Metadata private[types] (private[types] val map: Map[String, Any]) extends Serializable { @@ -117,7 +117,7 @@ sealed class Metadata private[types] (private[types] val map: Map[String, Any]) /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable object Metadata { private[this] val _empty = new Metadata(Map.empty) @@ -228,7 +228,7 @@ object Metadata { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class MetadataBuilder { private val map: mutable.Map[String, Any] = mutable.Map.empty diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala index 494225b47a270..14097a5280d50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.types -import org.apache.spark.annotation.InterfaceStability - +import org.apache.spark.annotation.Stable /** * The data type representing `NULL` values. Please use the singleton `DataTypes.NullType`. * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class NullType private() extends DataType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "NullType$" in byte code. @@ -38,5 +37,5 @@ class NullType private() extends DataType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object NullType extends NullType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala index 203e85e1c99bd..6756b209f432e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.types import scala.language.existentials -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving -@InterfaceStability.Evolving +@Evolving object ObjectType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = throw new UnsupportedOperationException( @@ -38,7 +38,7 @@ object ObjectType extends AbstractDataType { /** * Represents a JVM object that is passing through Spark SQL expression evaluation. */ -@InterfaceStability.Evolving +@Evolving case class ObjectType(cls: Class[_]) extends DataType { override def defaultSize: Int = 4096 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala index ee655c338b59f..9b5ddfef1ccf5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.types import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * The data type representing `Short` values. Please use the singleton `DataTypes.ShortType`. * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class ShortType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "ShortType$" in byte code. @@ -51,5 +51,5 @@ class ShortType private() extends IntegralType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object ShortType extends ShortType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala index 59b124cda7d14..8ce1cd078e312 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.unsafe.types.UTF8String /** @@ -28,7 +28,7 @@ import org.apache.spark.unsafe.types.UTF8String * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class StringType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "StringType$" in byte code. @@ -48,6 +48,6 @@ class StringType private() extends AtomicType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object StringType extends StringType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala index 35f9970a0aaec..6f6b561d67d49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier} /** @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdenti * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class StructField( name: String, dataType: DataType, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 06289b1483203..3bef75d5bdb6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -24,10 +24,10 @@ import scala.util.control.NonFatal import org.json4s.JsonDSL._ import org.apache.spark.SparkException -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser} -import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier} +import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.util.Utils /** @@ -95,7 +95,7 @@ import org.apache.spark.util.Utils * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { /** No-arg constructor for kryo. */ @@ -422,7 +422,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable object StructType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = new StructType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala index fdb91e0499920..a20f155418f8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * The data type representing `java.sql.Timestamp` values. @@ -28,7 +28,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class TimestampType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "TimestampType$" in byte code. @@ -50,5 +50,5 @@ class TimestampType private() extends AtomicType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object TimestampType extends TimestampType diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java index 802949c0ddb60..d4e1d89491f43 100644 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java @@ -20,8 +20,8 @@ import java.io.Serializable; import java.util.Iterator; +import org.apache.spark.annotation.Evolving; import org.apache.spark.annotation.Experimental; -import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.streaming.GroupState; /** @@ -33,7 +33,7 @@ * @since 2.1.1 */ @Experimental -@InterfaceStability.Evolving +@Evolving public interface FlatMapGroupsWithStateFunction extends Serializable { Iterator call(K key, Iterator values, GroupState state) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java index 353e9886a8a57..f0abfde843cc5 100644 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java @@ -20,8 +20,8 @@ import java.io.Serializable; import java.util.Iterator; +import org.apache.spark.annotation.Evolving; import org.apache.spark.annotation.Experimental; -import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.streaming.GroupState; /** @@ -32,7 +32,7 @@ * @since 2.1.1 */ @Experimental -@InterfaceStability.Evolving +@Evolving public interface MapGroupsWithStateFunction extends Serializable { R call(K key, Iterator values, GroupState state) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java b/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java index 1c3c9794fb6bb..9cc073f53a3eb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java +++ b/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java @@ -16,14 +16,14 @@ */ package org.apache.spark.sql; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * SaveMode is used to specify the expected behavior of saving a DataFrame to a data source. * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable public enum SaveMode { /** * Append mode means that when saving a DataFrame to a data source, if data/table already exists, diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF0.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF0.java index 4eeb7be3f5abb..631d6eb1cfb03 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF0.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF0.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 0 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF0 extends Serializable { R call() throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java index 1460daf27dc20..a5d01406edd8c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 1 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF1 extends Serializable { R call(T1 t1) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java index 7c4f1e4897084..effe99e30b2a5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 10 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF10 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java index 26a05106aebd6..e70b18b84b08f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 11 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF11 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java index 8ef7a99042025..339feb34135e1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 12 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF12 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java index 5c3b2ec1222e2..d346e5c908c6f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 13 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF13 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java index 97e744d843466..d27f9f5270f4b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 14 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF14 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java index 7ddbf914fc11a..b99b57a91d465 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 15 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF15 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java index 0ae5dc7195ad6..7899fc4b7ad65 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 16 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF16 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java index 03543a556c614..40a7e95724fc2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 17 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF17 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java index 46740d3443916..47935a935891c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 18 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF18 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java index 33fefd8ecaf1d..578b796ff03a3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 19 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF19 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java index 9822f19217d76..2f856aa3cf630 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 2 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF2 extends Serializable { R call(T1 t1, T2 t2) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java index 8c5e90182da1c..aa8a9fa897040 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 20 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF20 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java index e3b09f5167cff..0fe52bce2eca2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 21 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF21 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java index dc6cfa9097bab..69fd8ca422833 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 22 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF22 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21, T22 t22) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java index 7c264b69ba195..84ffd655672a2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 3 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF3 extends Serializable { R call(T1 t1, T2 t2, T3 t3) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java index 58df38fc3c911..dd2dc285c226d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 4 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF4 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java index 4146f96e2eed5..795cc21c3f76e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 5 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF5 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java index 25d39654c1095..a954684c3c9a9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 6 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF6 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java index ce63b6a91adbb..03761f2c9ebbf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 7 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF7 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java index 0e00209ef6b9f..8cd3583b2cbf0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 8 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF8 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java index 077981bb3e3ee..78a7097791963 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 9 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF9 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java index 82a1169cbe7ae..7d1fbe64fc960 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java @@ -17,12 +17,12 @@ package org.apache.spark.sql.execution.datasources; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Unstable; /** * Exception thrown when the parquet reader find column type mismatches. */ -@InterfaceStability.Unstable +@Unstable public class SchemaColumnConvertNotSupportedException extends RuntimeException { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java b/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java index ec9c107b1c119..5a72f0c6a2555 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java +++ b/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java @@ -17,8 +17,8 @@ package org.apache.spark.sql.expressions.javalang; +import org.apache.spark.annotation.Evolving; import org.apache.spark.annotation.Experimental; -import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.TypedColumn; import org.apache.spark.sql.execution.aggregate.TypedAverage; @@ -35,7 +35,7 @@ * @since 2.0.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving public class typed { // Note: make sure to keep in sync with typed.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java index f403dc619e86c..2a4933d75e8d0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; import org.apache.spark.sql.sources.v2.reader.BatchReadSupport; import org.apache.spark.sql.types.StructType; @@ -29,7 +29,7 @@ * This interface is used to create {@link BatchReadSupport} instances when end users run * {@code SparkSession.read.format(...).option(...).load()}. */ -@InterfaceStability.Evolving +@Evolving public interface BatchReadSupportProvider extends DataSourceV2 { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java index bd10c3353bf12..df439e2c02fe3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java @@ -19,7 +19,7 @@ import java.util.Optional; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.SaveMode; import org.apache.spark.sql.sources.v2.writer.BatchWriteSupport; import org.apache.spark.sql.types.StructType; @@ -31,7 +31,7 @@ * This interface is used to create {@link BatchWriteSupport} instances when end users run * {@code Dataset.write.format(...).option(...).save()}. */ -@InterfaceStability.Evolving +@Evolving public interface BatchWriteSupportProvider extends DataSourceV2 { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java index 824c290518acf..b4f2eb34a1560 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport; import org.apache.spark.sql.types.StructType; @@ -29,7 +29,7 @@ * This interface is used to create {@link ContinuousReadSupport} instances when end users run * {@code SparkSession.readStream.format(...).option(...).load()} with a continuous trigger. */ -@InterfaceStability.Evolving +@Evolving public interface ContinuousReadSupportProvider extends DataSourceV2 { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java index 83df3be747085..1c5e3a0cd31e7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java @@ -26,7 +26,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * An immutable string-to-string map in which keys are case-insensitive. This is used to represent @@ -73,7 +73,7 @@ * *
      spark.scheduler.blacklist.unschedulableTaskSetTimeout120s + The timeout in seconds to wait to acquire a new executor and schedule a task before aborting a + TaskSet which is unschedulable because of being completely blacklisted. +
      spark.blacklist.enabled From 6b425874d311146d8fbf7685c1b5d8e97d73b101 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 6 Nov 2018 23:18:55 +0800 Subject: [PATCH 2005/2461] [SPARK-25866][ML] Update KMeans formatVersion ## What changes were proposed in this pull request? When we added the `distanceMeasure`, we didn't update the `formatVersion` for `KMeans`. Despite this is not a big issue, as that information is used nowhere, we are returning a wrong information. ## How was this patch tested? NA Closes #22873 from mgaido91/SPARK-25866. Authored-by: Marco Gaido Signed-off-by: Wenchen Fan --- .../scala/org/apache/spark/mllib/clustering/KMeansModel.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index d5c8188144ce2..b0709547ab1be 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -113,7 +113,7 @@ class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector], KMeansModel.SaveLoadV2_0.save(sc, this, path) } - override protected def formatVersion: String = "1.0" + override protected def formatVersion: String = "2.0" } @Since("1.4.0") From cee230160ba2c3a210892f71e019190b02e34071 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 6 Nov 2018 10:52:33 -0800 Subject: [PATCH 2006/2461] [SPARK-25871][STREAMING] Don't use EC for streaming WAL The write ahead log expects to be able to call hflush, but that is a no-op when writing to a file with hdfs erasure coding. So ensure that file is always written with replication instead, regardless of filesystem defaults. None yet. I'm posting this mostly to make it visible. Closes #22882 from squito/SPARK-25871. Authored-by: Imran Rashid Signed-off-by: Marcelo Vanzin --- .../scala/org/apache/spark/streaming/util/HdfsUtils.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala index a6997359d64d2..8cb68b2be4ecf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala @@ -21,6 +21,8 @@ import java.io.{FileNotFoundException, IOException} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ +import org.apache.spark.deploy.SparkHadoopUtil + private[streaming] object HdfsUtils { def getOutputStream(path: String, conf: Configuration): FSDataOutputStream = { @@ -37,7 +39,8 @@ private[streaming] object HdfsUtils { throw new IllegalStateException("File exists and there is no append support!") } } else { - dfs.create(dfsPath) + // we dont' want to use hdfs erasure coding, as that lacks support for append and hflush + SparkHadoopUtil.createNonECFile(dfs, dfsPath) } } stream From a241a150d52b24ce952efab0830af4c0c9343c1b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 6 Nov 2018 14:52:02 -0800 Subject: [PATCH 2007/2461] [MINOR] update known_translations ## What changes were proposed in this pull request? update known_translations after running `translate-contributors.py` during 2.4.0 release ## How was this patch tested? N/A Closes #22949 from cloud-fan/contributors. Authored-by: Wenchen Fan Signed-off-by: gatorsmile --- dev/create-release/known_translations | 58 +++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/dev/create-release/known_translations b/dev/create-release/known_translations index 87bf2f220481d..65c00cce8c9c6 100644 --- a/dev/create-release/known_translations +++ b/dev/create-release/known_translations @@ -203,3 +203,61 @@ shenh062326 - Shen Hong aokolnychyi - Anton Okolnychyi linbojin - Linbo Jin lw-lin - Liwei Lin +10110346 - Xian Liu +Achuth17 - Achuth Narayan Rajagopal +Adamyuanyuan - Adam Wang +DylanGuedes - Dylan Guedes +JiahuiJiang - Jiahui Jiang +KevinZwx - Kevin Zhang +LantaoJin - Lantao Jin +Lemonjing - Rann Tao +LucaCanali - Luca Canali +XD-DENG - Xiaodong Deng +aai95 - Aleksei Izmalkin +akonopko - Alexander Konopko +ankuriitg - Ankur Gupta +arucard21 - Riaas Mokiem +attilapiros - Attila Zsolt Piros +bravo-zhang - Bravo Zhang +caneGuy - Kang Zhou +chaoslawful - Xiaozhe Wang +cluo512 - Chuan Luo +codeatri - Neha Patil +crafty-coder - Carlos Pena +debugger87 - Chaozhong Yang +e-dorigatti - Emilio Dorigatti +eric-maynard - Eric Maynard +felixalbani - Felix Albani +fjh100456 - Jinhua Fu +guoxiaolongzte - Xiaolong Guo +heary-cao - Xuewen Cao +huangweizhe123 - Weizhe Huang +ivoson - Tengfei Huang +jinxing64 - Jin Xing +liu-zhaokun - Zhaokun Liu +liutang123 - Lijia Liu +maropu - Takeshi Yamamuro +maryannxue - Maryann Xue +mcteo - Thomas Dunne +mn-mikke - Marek Novotny +myroslavlisniak - Myroslav Lisniak +npoggi - Nicolas Poggi +pgandhi999 - Parth Gandhi +rimolive - Ricardo Martinelli De Oliveira +sadhen - Darcy Shen +sandeep-katta - Sandeep Katta +seancxmao - Chenxiao Mao +sel - Steve Larkin +shimamoto - Takako Shimamoto +shivusondur - Shivakumar Sondur +skonto - Stavros Kontopoulos +trystanleftwich - Trystan Leftwich +ueshin - Takuya Ueshin +uzmijnlm - Weizhe Huang +xuanyuanking - Yuanjian Li +xubo245 - Bo Xu +xueyumusic - Xue Yu +yanlin-Lynn - Yanlin Wang +yucai - Yucai Yu +zhengruifeng - Ruifeng Zheng +zuotingbing - Tingbing Zuo From 63ca4bbe792718029f6d6196e8a6bb11d1f20fca Mon Sep 17 00:00:00 2001 From: yucai Date: Tue, 6 Nov 2018 15:40:56 -0800 Subject: [PATCH 2008/2461] [SPARK-25676][SQL][TEST] Rename and refactor BenchmarkWideTable to use main method ## What changes were proposed in this pull request? Refactor BenchmarkWideTable to use main method. Generate benchmark result: ``` SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain org.apache.spark.sql.execution.benchmark.WideTableBenchmark" ``` ## How was this patch tested? manual tests Closes #22823 from yucai/BenchmarkWideTable. Lead-authored-by: yucai Co-authored-by: Yucai Yu Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../benchmarks/WideTableBenchmark-results.txt | 17 ++++++ .../benchmark/BenchmarkWideTable.scala | 52 ------------------- .../benchmark/WideTableBenchmark.scala | 52 +++++++++++++++++++ 3 files changed, 69 insertions(+), 52 deletions(-) create mode 100644 sql/core/benchmarks/WideTableBenchmark-results.txt delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWideTable.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala diff --git a/sql/core/benchmarks/WideTableBenchmark-results.txt b/sql/core/benchmarks/WideTableBenchmark-results.txt new file mode 100644 index 0000000000000..3b41a3e036c4d --- /dev/null +++ b/sql/core/benchmarks/WideTableBenchmark-results.txt @@ -0,0 +1,17 @@ +================================================================================================ +projection on wide table +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +projection on wide table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +split threshold 10 38932 / 39307 0.0 37128.1 1.0X +split threshold 100 31991 / 32556 0.0 30508.8 1.2X +split threshold 1024 10993 / 11041 0.1 10483.5 3.5X +split threshold 2048 8959 / 8998 0.1 8543.8 4.3X +split threshold 4096 8116 / 8134 0.1 7739.8 4.8X +split threshold 8196 8069 / 8098 0.1 7695.5 4.8X +split threshold 65536 57068 / 57339 0.0 54424.3 0.7X + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWideTable.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWideTable.scala deleted file mode 100644 index 76367cbbe5342..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWideTable.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.benchmark - -import org.apache.spark.benchmark.Benchmark - -/** - * Benchmark to measure performance for wide table. - * To run this: - * build/sbt "sql/test-only *benchmark.BenchmarkWideTable" - * - * Benchmarks in this file are skipped in normal builds. - */ -class BenchmarkWideTable extends BenchmarkWithCodegen { - - ignore("project on wide table") { - val N = 1 << 20 - val df = sparkSession.range(N) - val columns = (0 until 400).map{ i => s"id as id$i"} - val benchmark = new Benchmark("projection on wide table", N) - benchmark.addCase("wide table", numIters = 5) { iter => - df.selectExpr(columns : _*).queryExecution.toRdd.count() - } - benchmark.run() - - /** - * Here are some numbers with different split threshold: - * - * Split threshold methods Rate(M/s) Per Row(ns) - * 10 400 0.4 2279 - * 100 200 0.6 1554 - * 1k 37 0.9 1116 - * 8k 5 0.5 2025 - * 64k 1 0.0 21649 - */ - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala new file mode 100644 index 0000000000000..ffefef1d4fce3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.internal.SQLConf + +/** + * Benchmark to measure performance for wide table. + * {{{ + * To run this benchmark: + * 1. without sbt: bin/spark-submit --class + * --jars , + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/WideTableBenchmark-results.txt". + * }}} + */ +object WideTableBenchmark extends SqlBasedBenchmark { + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("projection on wide table") { + val N = 1 << 20 + val df = spark.range(N) + val columns = (0 until 400).map{ i => s"id as id$i"} + val benchmark = new Benchmark("projection on wide table", N, output = output) + Seq("10", "100", "1024", "2048", "4096", "8192", "65536").foreach { n => + benchmark.addCase(s"split threshold $n", numIters = 5) { iter => + withSQLConf(SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> n) { + df.selectExpr(columns: _*).foreach(identity(_)) + } + } + } + benchmark.run() + } + } +} From 76813cfa1e2607ea3b669a79e59b568e96395b2e Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 7 Nov 2018 11:26:17 +0800 Subject: [PATCH 2009/2461] [SPARK-25950][SQL] from_csv should respect to spark.sql.columnNameOfCorruptRecord ## What changes were proposed in this pull request? Fix for `CsvToStructs` to take into account SQL config `spark.sql.columnNameOfCorruptRecord` similar to `from_json`. ## How was this patch tested? Added new test where `spark.sql.columnNameOfCorruptRecord` is set to corrupt column name different from default. Closes #22956 from MaxGekk/csv-tests. Authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- .../catalyst/expressions/csvExpressions.scala | 9 +++++- .../apache/spark/sql/CsvFunctionsSuite.scala | 31 +++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index 74b670ae4b68a..aff372b899f86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.csv._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -92,8 +93,14 @@ case class CsvToStructs( } } + val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD) + @transient lazy val parser = { - val parsedOptions = new CSVOptions(options, columnPruning = true, timeZoneId.get) + val parsedOptions = new CSVOptions( + options, + columnPruning = true, + defaultTimeZoneId = timeZoneId.get, + defaultColumnNameOfCorruptRecord = nameOfCorruptRecord) val mode = parsedOptions.parseMode if (mode != PermissiveMode && mode != FailFastMode) { throw new AnalysisException(s"from_csv() doesn't support the ${mode.name} mode. " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index eb6b248e895f6..1dd8ec31ee111 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ +import org.apache.spark.SparkException import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -86,4 +88,33 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df.select(to_csv($"a", options)), Row("26/08/2015 18:00") :: Nil) } + + test("from_csv invalid csv - check modes") { + withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { + val schema = new StructType() + .add("a", IntegerType) + .add("b", IntegerType) + .add("_unparsed", StringType) + val badRec = "\"" + val df = Seq(badRec, "2,12").toDS() + + checkAnswer( + df.select(from_csv($"value", schema, Map("mode" -> "PERMISSIVE"))), + Row(Row(null, null, badRec)) :: Row(Row(2, 12, null)) :: Nil) + + val exception1 = intercept[SparkException] { + df.select(from_csv($"value", schema, Map("mode" -> "FAILFAST"))).collect() + }.getMessage + assert(exception1.contains( + "Malformed records are detected in record parsing. Parse Mode: FAILFAST.")) + + val exception2 = intercept[SparkException] { + df.select(from_csv($"value", schema, Map("mode" -> "DROPMALFORMED"))) + .collect() + }.getMessage + assert(exception2.contains( + "from_csv() doesn't support the DROPMALFORMED mode. " + + "Acceptable modes are PERMISSIVE and FAILFAST.")) + } + } } From 9e9fa2f69f3fd8be34c8e99efcf6cf9db70a4cd0 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 6 Nov 2018 21:26:28 -0800 Subject: [PATCH 2010/2461] [SPARK-25098][SQL] Trim the string when cast stringToTimestamp and stringToDate ## What changes were proposed in this pull request? **Hive** and **Oracle** trim the string when cast `stringToTimestamp` and `stringToDate`. this PR support this feature: ![image](https://user-images.githubusercontent.com/5399861/47979721-793b1e80-e0ff-11e8-97c8-24b10950ee9e.png) ![image](https://user-images.githubusercontent.com/5399861/47979725-7dffd280-e0ff-11e8-87d4-5767a00ed46e.png) ## How was this patch tested? unit tests Closes https://github.com/apache/spark/pull/22089 Closes #22943 from wangyum/SPARK-25098. Authored-by: Yuming Wang Signed-off-by: Dongjoon Hyun --- .../sql/catalyst/util/DateTimeUtils.scala | 8 +++---- .../catalyst/util/DateTimeUtilsSuite.scala | 21 +++++++------------ 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 81d7274607ac8..5ae75dc939303 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -274,7 +274,7 @@ object DateTimeUtils { } /** - * Parses a given UTF8 date string to the corresponding a corresponding [[Long]] value. + * Trim and parse a given UTF8 date string to the corresponding a corresponding [[Long]] value. * The return type is [[Option]] in order to distinguish between 0L and null. The following * formats are allowed: * @@ -311,7 +311,7 @@ object DateTimeUtils { val segments: Array[Int] = Array[Int](1, 1, 1, 0, 0, 0, 0, 0, 0) var i = 0 var currentSegmentValue = 0 - val bytes = s.getBytes + val bytes = s.trim.getBytes var j = 0 var digitsMilli = 0 var justTime = false @@ -441,7 +441,7 @@ object DateTimeUtils { } /** - * Parses a given UTF8 date string to a corresponding [[Int]] value. + * Trim and parse a given UTF8 date string to a corresponding [[Int]] value. * The return type is [[Option]] in order to distinguish between 0 and null. The following * formats are allowed: * @@ -459,7 +459,7 @@ object DateTimeUtils { val segments: Array[Int] = Array[Int](1, 1, 1) var i = 0 var currentSegmentValue = 0 - val bytes = s.getBytes + val bytes = s.trim.getBytes var j = 0 while (j < bytes.length && (i < 3 && !(bytes(j) == ' ' || bytes(j) == 'T'))) { val b = bytes(j) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 2423668392231..0182eeb171215 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -140,16 +140,10 @@ class DateTimeUtilsSuite extends SparkFunSuite { c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(stringToDate(UTF8String.fromString("2015-03-18")).get === - millisToDays(c.getTimeInMillis)) - assert(stringToDate(UTF8String.fromString("2015-03-18 ")).get === - millisToDays(c.getTimeInMillis)) - assert(stringToDate(UTF8String.fromString("2015-03-18 123142")).get === - millisToDays(c.getTimeInMillis)) - assert(stringToDate(UTF8String.fromString("2015-03-18T123123")).get === - millisToDays(c.getTimeInMillis)) - assert(stringToDate(UTF8String.fromString("2015-03-18T")).get === - millisToDays(c.getTimeInMillis)) + Seq("2015-03-18", "2015-03-18 ", " 2015-03-18", " 2015-03-18 ", "2015-03-18 123142", + "2015-03-18T123123", "2015-03-18T").foreach { s => + assert(stringToDate(UTF8String.fromString(s)).get === millisToDays(c.getTimeInMillis)) + } assert(stringToDate(UTF8String.fromString("2015-03-18X")).isEmpty) assert(stringToDate(UTF8String.fromString("2015/03/18")).isEmpty) @@ -214,9 +208,10 @@ class DateTimeUtilsSuite extends SparkFunSuite { c = Calendar.getInstance(tz) c.set(2015, 2, 18, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - checkStringToTimestamp("2015-03-18", Option(c.getTimeInMillis * 1000)) - checkStringToTimestamp("2015-03-18 ", Option(c.getTimeInMillis * 1000)) - checkStringToTimestamp("2015-03-18T", Option(c.getTimeInMillis * 1000)) + + Seq("2015-03-18", "2015-03-18 ", " 2015-03-18", " 2015-03-18 ", "2015-03-18T").foreach { s => + checkStringToTimestamp(s, Option(c.getTimeInMillis * 1000)) + } c = Calendar.getInstance(tz) c.set(2015, 2, 18, 12, 3, 17) From 8fbc1830f962c446b915d0d8ff2b13c5c75d22fc Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 7 Nov 2018 13:18:52 +0100 Subject: [PATCH 2011/2461] [SPARK-25904][CORE] Allocate arrays smaller than Int.MaxValue JVMs can't allocate arrays of length exactly Int.MaxValue, so ensure we never try to allocate an array that big. This commit changes some defaults & configs to gracefully fallover to something that doesn't require one large array in some cases; in other cases it simply improves an error message for cases which will still fail. Closes #22818 from squito/SPARK-25827. Authored-by: Imran Rashid Signed-off-by: Imran Rashid --- .../apache/spark/internal/config/package.scala | 17 ++++++++++------- .../org/apache/spark/storage/DiskStore.scala | 6 ++++-- .../spark/storage/memory/MemoryStore.scala | 7 ++++--- .../spark/util/io/ChunkedByteBuffer.scala | 2 +- .../apache/spark/mllib/linalg/Matrices.scala | 13 +++++++------ .../org/apache/spark/sql/internal/SQLConf.scala | 6 +++--- .../scala/org/apache/spark/sql/Dataset.scala | 7 +++---- 7 files changed, 32 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 2b3ba3c7daccb..d34601358d896 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -416,8 +416,9 @@ package object config { .internal() .doc("The chunk size in bytes during writing out the bytes of ChunkedByteBuffer.") .bytesConf(ByteUnit.BYTE) - .checkValue(_ <= Int.MaxValue, "The chunk size during writing out the bytes of" + - " ChunkedByteBuffer should not larger than Int.MaxValue.") + .checkValue(_ <= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH, + "The chunk size during writing out the bytes of" + + " ChunkedByteBuffer should not larger than Int.MaxValue - 15.") .createWithDefault(64 * 1024 * 1024) private[spark] val CHECKPOINT_COMPRESS = @@ -488,8 +489,9 @@ package object config { "otherwise specified. These buffers reduce the number of disk seeks and system calls " + "made in creating intermediate shuffle files.") .bytesConf(ByteUnit.KiB) - .checkValue(v => v > 0 && v <= Int.MaxValue / 1024, - s"The file buffer size must be greater than 0 and less than ${Int.MaxValue / 1024}.") + .checkValue(v => v > 0 && v <= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH / 1024, + s"The file buffer size must be greater than 0 and less than" + + s" ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH / 1024}.") .createWithDefaultString("32k") private[spark] val SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE = @@ -497,8 +499,9 @@ package object config { .doc("The file system for this buffer size after each partition " + "is written in unsafe shuffle writer. In KiB unless otherwise specified.") .bytesConf(ByteUnit.KiB) - .checkValue(v => v > 0 && v <= Int.MaxValue / 1024, - s"The buffer size must be greater than 0 and less than ${Int.MaxValue / 1024}.") + .checkValue(v => v > 0 && v <= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH / 1024, + s"The buffer size must be greater than 0 and less than" + + s" ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH / 1024}.") .createWithDefaultString("32k") private[spark] val SHUFFLE_DISK_WRITE_BUFFER_SIZE = @@ -610,7 +613,7 @@ package object config { .internal() .doc("For testing only, controls the size of chunks when memory mapping a file") .bytesConf(ByteUnit.BYTE) - .createWithDefault(Int.MaxValue) + .createWithDefault(ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) private[spark] val BARRIER_SYNC_TIMEOUT = ConfigBuilder("spark.barrier.sync.timeout") diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 841e16afc7549..29963a95cb074 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -33,6 +33,7 @@ import org.apache.spark.internal.{config, Logging} import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.util.{AbstractFileRegion, JavaUtils} import org.apache.spark.security.CryptoStreamUtils +import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.Utils import org.apache.spark.util.io.ChunkedByteBuffer @@ -217,7 +218,7 @@ private[spark] class EncryptedBlockData( var remaining = blockSize val chunks = new ListBuffer[ByteBuffer]() while (remaining > 0) { - val chunkSize = math.min(remaining, Int.MaxValue) + val chunkSize = math.min(remaining, ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) val chunk = allocator(chunkSize.toInt) remaining -= chunkSize JavaUtils.readFully(source, chunk) @@ -235,7 +236,8 @@ private[spark] class EncryptedBlockData( // This is used by the block transfer service to replicate blocks. The upload code reads // all bytes into memory to send the block to the remote executor, so it's ok to do this // as long as the block fits in a Java array. - assert(blockSize <= Int.MaxValue, "Block is too large to be wrapped in a byte buffer.") + assert(blockSize <= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH, + "Block is too large to be wrapped in a byte buffer.") val dst = ByteBuffer.allocate(blockSize.toInt) val in = open() try { diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 06fd56e54d9c8..8513359934bec 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -34,6 +34,7 @@ import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.serializer.{SerializationStream, SerializerManager} import org.apache.spark.storage._ import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.{SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} @@ -333,11 +334,11 @@ private[spark] class MemoryStore( // Initial per-task memory to request for unrolling blocks (bytes). val initialMemoryThreshold = unrollMemoryThreshold - val chunkSize = if (initialMemoryThreshold > Int.MaxValue) { + val chunkSize = if (initialMemoryThreshold > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { logWarning(s"Initial memory threshold of ${Utils.bytesToString(initialMemoryThreshold)} " + s"is too large to be set as chunk size. Chunk size has been capped to " + - s"${Utils.bytesToString(Int.MaxValue)}") - Int.MaxValue + s"${Utils.bytesToString(ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH)}") + ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH } else { initialMemoryThreshold.toInt } diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index da2be84723a07..870830fff4c3e 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -97,7 +97,7 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { * @throws UnsupportedOperationException if this buffer's size exceeds the maximum array size. */ def toArray: Array[Byte] = { - if (size >= Integer.MAX_VALUE) { + if (size >= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw new UnsupportedOperationException( s"cannot call toArray because buffer size ($size bytes) exceeds maximum array size") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index bf9b4cfe15b2c..e474cfa002fad 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -30,6 +30,7 @@ import org.apache.spark.ml.{linalg => newlinalg} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.array.ByteArrayMethods /** * Trait for a local matrix. @@ -456,7 +457,7 @@ object DenseMatrix { */ @Since("1.3.0") def zeros(numRows: Int, numCols: Int): DenseMatrix = { - require(numRows.toLong * numCols <= Int.MaxValue, + require(numRows.toLong * numCols <= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH, s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, new Array[Double](numRows * numCols)) } @@ -469,7 +470,7 @@ object DenseMatrix { */ @Since("1.3.0") def ones(numRows: Int, numCols: Int): DenseMatrix = { - require(numRows.toLong * numCols <= Int.MaxValue, + require(numRows.toLong * numCols <= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH, s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(1.0)) } @@ -499,7 +500,7 @@ object DenseMatrix { */ @Since("1.3.0") def rand(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { - require(numRows.toLong * numCols <= Int.MaxValue, + require(numRows.toLong * numCols <= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH, s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextDouble())) } @@ -513,7 +514,7 @@ object DenseMatrix { */ @Since("1.3.0") def randn(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { - require(numRows.toLong * numCols <= Int.MaxValue, + require(numRows.toLong * numCols <= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH, s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextGaussian())) } @@ -846,8 +847,8 @@ object SparseMatrix { s"density must be a double in the range 0.0 <= d <= 1.0. Currently, density: $density") val size = numRows.toLong * numCols val expected = size * density - assert(expected < Int.MaxValue, - "The expected number of nonzeros cannot be greater than Int.MaxValue.") + assert(expected < ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH, + "The expected number of nonzeros cannot be greater than Int.MaxValue - 15.") val nnz = math.ceil(expected).toInt if (density == 0.0) { new SparseMatrix(numRows, numCols, new Array[Int](numCols + 1), Array.empty, Array.empty) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index fa59fa578a969..518115dafd011 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,7 +27,6 @@ import scala.collection.immutable import scala.util.matching.Regex import org.apache.hadoop.fs.Path -import org.tukaani.xz.LZMA2Options import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.internal.Logging @@ -36,6 +35,7 @@ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator +import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1246,7 +1246,7 @@ object SQLConf { .doc("Threshold for number of rows guaranteed to be held in memory by the sort merge " + "join operator") .intConf - .createWithDefault(Int.MaxValue) + .createWithDefault(ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) val SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD = buildConf("spark.sql.sortMergeJoinExec.buffer.spill.threshold") @@ -1480,7 +1480,7 @@ object SQLConf { "'SELECT x FROM t ORDER BY y LIMIT m', if m is under this threshold, do a top-K sort" + " in memory, otherwise do a global sort which spills to disk if necessary.") .intConf - .createWithDefault(Int.MaxValue) + .createWithDefault(ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c91b0d778fab1..d53400512cb84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import java.io.CharArrayWriter -import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ import scala.language.implicitConversions @@ -46,7 +45,6 @@ import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.arrow.{ArrowBatchStreamWriter, ArrowConverters} import org.apache.spark.sql.execution.command._ @@ -57,6 +55,7 @@ import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.Utils @@ -287,7 +286,7 @@ class Dataset[T] private[sql]( _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { - val numRows = _numRows.max(0).min(Int.MaxValue - 1) + val numRows = _numRows.max(0).min(ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - 1) // Get rows represented by Seq[Seq[String]], we may get one more line if it has more data. val tmpRows = getRows(numRows, truncate) @@ -3264,7 +3263,7 @@ class Dataset[T] private[sql]( _numRows: Int, truncate: Int): Array[Any] = { EvaluatePython.registerPicklers() - val numRows = _numRows.max(0).min(Int.MaxValue - 1) + val numRows = _numRows.max(0).min(ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - 1) val rows = getRows(numRows, truncate).map(_.toArray).toArray val toJava: (Any) => Any = EvaluatePython.toJava(_, ArrayType(ArrayType(StringType))) val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler( From 0a32238d034e4d73d56ec69b7069b9d61e05582b Mon Sep 17 00:00:00 2001 From: koraseg Date: Wed, 7 Nov 2018 09:12:13 -0600 Subject: [PATCH 2012/2461] [SPARK-25885][CORE][MINOR] HighlyCompressedMapStatus deserialization/construction optimization ## What changes were proposed in this pull request? Removal of intermediate structures in HighlyCompressedMapStatus will speed up its creation and deserialization time. https://issues.apache.org/jira/browse/SPARK-25885 ## How was this patch tested? Additional tests are not necessary for the patch. Closes #22894 from Koraseg/mapStatusesOptimization. Authored-by: koraseg Signed-off-by: Sean Owen --- .../org/apache/spark/scheduler/MapStatus.scala | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 0e221edf3965a..64f0a060a247c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -20,7 +20,6 @@ package org.apache.spark.scheduler import java.io.{Externalizable, ObjectInput, ObjectOutput} import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer import org.roaringbitmap.RoaringBitmap @@ -149,7 +148,7 @@ private[spark] class HighlyCompressedMapStatus private ( private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, - private var hugeBlockSizes: Map[Int, Byte]) + private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte]) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization @@ -189,13 +188,13 @@ private[spark] class HighlyCompressedMapStatus private ( emptyBlocks.readExternal(in) avgSize = in.readLong() val count = in.readInt() - val hugeBlockSizesArray = mutable.ArrayBuffer[Tuple2[Int, Byte]]() + val hugeBlockSizesImpl = mutable.Map.empty[Int, Byte] (0 until count).foreach { _ => val block = in.readInt() val size = in.readByte() - hugeBlockSizesArray += Tuple2(block, size) + hugeBlockSizesImpl(block) = size } - hugeBlockSizes = hugeBlockSizesArray.toMap + hugeBlockSizes = hugeBlockSizesImpl } } @@ -215,7 +214,7 @@ private[spark] object HighlyCompressedMapStatus { val threshold = Option(SparkEnv.get) .map(_.conf.get(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD)) .getOrElse(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.defaultValue.get) - val hugeBlockSizesArray = ArrayBuffer[Tuple2[Int, Byte]]() + val hugeBlockSizes = mutable.Map.empty[Int, Byte] while (i < totalNumBlocks) { val size = uncompressedSizes(i) if (size > 0) { @@ -226,7 +225,7 @@ private[spark] object HighlyCompressedMapStatus { totalSmallBlockSize += size numSmallBlocks += 1 } else { - hugeBlockSizesArray += Tuple2(i, MapStatus.compressSize(uncompressedSizes(i))) + hugeBlockSizes(i) = MapStatus.compressSize(uncompressedSizes(i)) } } else { emptyBlocks.add(i) @@ -241,6 +240,6 @@ private[spark] object HighlyCompressedMapStatus { emptyBlocks.trim() emptyBlocks.runOptimize() new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, - hugeBlockSizesArray.toMap) + hugeBlockSizes) } } From e4561e1c552cdd7291254dc3787a70aadfb05f0a Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 7 Nov 2018 13:19:31 -0800 Subject: [PATCH 2013/2461] [SPARK-25897][K8S] Hook up k8s integration tests to sbt build. The integration tests can now be run in sbt if the right profile is enabled, using the "test" task under the respective project. This avoids having to fall back to maven to run the tests, which invalidates all your compiled stuff when you go back to sbt, making development way slower than it should. There's also a task to run the tests directly without refreshing the docker images, which is helpful if you just made a change to the submission code which should not affect the code in the images. The sbt tasks currently are not very customizable; there's some very minor things you can set in the sbt shell itself, but otherwise it's hardcoded to run on minikube. I also had to make some slight adjustments to the IT code itself, mostly to remove assumptions about the existing harness. Tested on sbt and maven. Closes #22909 from vanzin/SPARK-25897. Authored-by: Marcelo Vanzin Signed-off-by: Marcelo Vanzin --- bin/docker-image-tool.sh | 3 + project/SparkBuild.scala | 61 ++++++++++++++++++ .../kubernetes/integration-tests/pom.xml | 6 +- .../k8s/integrationtest/KubernetesSuite.scala | 63 +++++++++++++------ .../k8s/integrationtest/ProcessUtils.scala | 8 ++- .../integrationtest/PythonTestsSuite.scala | 12 ++-- .../k8s/integrationtest/RTestsSuite.scala | 5 +- .../k8s/integrationtest/TestConfig.scala | 40 ------------ .../k8s/integrationtest/TestConstants.scala | 2 +- .../backend/IntegrationTestBackend.scala | 4 +- .../backend/cloud/KubeConfigBackend.scala | 5 +- 11 files changed, 126 insertions(+), 83 deletions(-) delete mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConfig.scala diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index 61959ca2a3041..aa5d847f4be2f 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -197,6 +197,9 @@ do if ! which minikube 1>/dev/null; then error "Cannot find minikube." fi + if ! minikube status 1>/dev/null; then + error "Cannot contact minikube. Make sure it's running." + fi eval $(minikube docker-env) ;; esac diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index a0aaef293e96f..ca57df0e31a7f 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -374,6 +374,8 @@ object SparkBuild extends PomBuild { // SPARK-14738 - Remove docker tests from main Spark build // enable(DockerIntegrationTests.settings)(dockerIntegrationTests) + enable(KubernetesIntegrationTests.settings)(kubernetesIntegrationTests) + /** * Adds the ability to run the spark shell directly from SBT without building an assembly * jar. @@ -458,6 +460,65 @@ object DockerIntegrationTests { ) } +/** + * These settings run a hardcoded configuration of the Kubernetes integration tests using + * minikube. Docker images will have the "dev" tag, and will be overwritten every time the + * integration tests are run. The integration tests are actually bound to the "test" phase, + * so running "test" on this module will run the integration tests. + * + * There are two ways to run the tests: + * - the "tests" task builds docker images and runs the test, so it's a little slow. + * - the "run-its" task just runs the tests on a pre-built set of images. + * + * Note that this does not use the shell scripts that the maven build uses, which are more + * configurable. This is meant as a quick way for developers to run these tests against their + * local changes. + */ +object KubernetesIntegrationTests { + import BuildCommons._ + + val dockerBuild = TaskKey[Unit]("docker-imgs", "Build the docker images for ITs.") + val runITs = TaskKey[Unit]("run-its", "Only run ITs, skip image build.") + val imageTag = settingKey[String]("Tag to use for images built during the test.") + val namespace = settingKey[String]("Namespace where to run pods.") + + // Hack: this variable is used to control whether to build docker images. It's updated by + // the tasks below in a non-obvious way, so that you get the functionality described in + // the scaladoc above. + private var shouldBuildImage = true + + lazy val settings = Seq( + imageTag := "dev", + namespace := "default", + dockerBuild := { + if (shouldBuildImage) { + val dockerTool = s"$sparkHome/bin/docker-image-tool.sh" + val cmd = Seq(dockerTool, "-m", "-t", imageTag.value, "build") + val ec = Process(cmd).! + if (ec != 0) { + throw new IllegalStateException(s"Process '${cmd.mkString(" ")}' exited with $ec.") + } + } + shouldBuildImage = true + }, + runITs := Def.taskDyn { + shouldBuildImage = false + Def.task { + (test in Test).value + } + }.value, + test in Test := (test in Test).dependsOn(dockerBuild).value, + javaOptions in Test ++= Seq( + "-Dspark.kubernetes.test.deployMode=minikube", + s"-Dspark.kubernetes.test.imageTag=${imageTag.value}", + s"-Dspark.kubernetes.test.namespace=${namespace.value}", + s"-Dspark.kubernetes.test.unpackSparkDir=$sparkHome" + ), + // Force packaging before building images, so that the latest code is tested. + dockerBuild := dockerBuild.dependsOn(packageBin in Compile in assembly).value + ) +} + /** * Overrides to work around sbt's dependency resolution being different from Maven's. */ diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 07288c97bd527..301b6fe8eee56 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -155,14 +155,10 @@ test + none test - - - (?<!Suite) - integration-test diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 6aa1d57085068..b746a01eb5294 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -19,9 +19,11 @@ package org.apache.spark.deploy.k8s.integrationtest import java.io.File import java.nio.file.{Path, Paths} import java.util.UUID -import java.util.regex.Pattern -import com.google.common.io.PatternFilenameFilter +import scala.collection.JavaConverters._ + +import com.google.common.base.Charsets +import com.google.common.io.Files import io.fabric8.kubernetes.api.model.Pod import io.fabric8.kubernetes.client.{KubernetesClientException, Watcher} import io.fabric8.kubernetes.client.Watcher.Action @@ -29,24 +31,22 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, Tag} import org.scalatest.Matchers import org.scalatest.concurrent.{Eventually, PatienceConfiguration} import org.scalatest.time.{Minutes, Seconds, Span} -import scala.collection.JavaConverters._ -import org.apache.spark.SparkFunSuite -import org.apache.spark.deploy.k8s.integrationtest.TestConfig._ +import org.apache.spark.{SPARK_VERSION, SparkFunSuite} import org.apache.spark.deploy.k8s.integrationtest.TestConstants._ import org.apache.spark.deploy.k8s.integrationtest.backend.{IntegrationTestBackend, IntegrationTestBackendFactory} import org.apache.spark.internal.Logging -private[spark] class KubernetesSuite extends SparkFunSuite +class KubernetesSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter with BasicTestsSuite with SecretsTestsSuite with PythonTestsSuite with ClientModeTestsSuite with PodTemplateSuite with Logging with Eventually with Matchers { import KubernetesSuite._ - private var sparkHomeDir: Path = _ - private var pyImage: String = _ - private var rImage: String = _ + protected var sparkHomeDir: Path = _ + protected var pyImage: String = _ + protected var rImage: String = _ protected var image: String = _ protected var testBackend: IntegrationTestBackend = _ @@ -67,6 +67,30 @@ private[spark] class KubernetesSuite extends SparkFunSuite private val extraExecTotalMemory = s"${(1024 + memOverheadConstant*1024 + additionalMemory).toInt}Mi" + /** + * Build the image ref for the given image name, taking the repo and tag from the + * test configuration. + */ + private def testImageRef(name: String): String = { + val tag = sys.props.get(CONFIG_KEY_IMAGE_TAG_FILE) + .map { path => + val tagFile = new File(path) + require(tagFile.isFile, + s"No file found for image tag at ${tagFile.getAbsolutePath}.") + Files.toString(tagFile, Charsets.UTF_8).trim + } + .orElse(sys.props.get(CONFIG_KEY_IMAGE_TAG)) + .getOrElse { + throw new IllegalArgumentException( + s"One of $CONFIG_KEY_IMAGE_TAG_FILE or $CONFIG_KEY_IMAGE_TAG is required.") + } + val repo = sys.props.get(CONFIG_KEY_IMAGE_REPO) + .map { _ + "/" } + .getOrElse("") + + s"$repo$name:$tag" + } + override def beforeAll(): Unit = { super.beforeAll() // The scalatest-maven-plugin gives system properties that are referenced but not set null @@ -83,17 +107,16 @@ private[spark] class KubernetesSuite extends SparkFunSuite sparkHomeDir = Paths.get(sparkDirProp) require(sparkHomeDir.toFile.isDirectory, s"No directory found for spark home specified at $sparkHomeDir.") - val imageTag = getTestImageTag - val imageRepo = getTestImageRepo - image = s"$imageRepo/spark:$imageTag" - pyImage = s"$imageRepo/spark-py:$imageTag" - rImage = s"$imageRepo/spark-r:$imageTag" - - val sparkDistroExamplesJarFile: File = sparkHomeDir.resolve(Paths.get("examples", "jars")) - .toFile - .listFiles(new PatternFilenameFilter(Pattern.compile("^spark-examples_.*\\.jar$")))(0) - containerLocalSparkDistroExamplesJar = s"local:///opt/spark/examples/jars/" + - s"${sparkDistroExamplesJarFile.getName}" + image = testImageRef("spark") + pyImage = testImageRef("spark-py") + rImage = testImageRef("spark-r") + + val scalaVersion = scala.util.Properties.versionNumberString + .split("\\.") + .take(2) + .mkString(".") + containerLocalSparkDistroExamplesJar = + s"local:///opt/spark/examples/jars/spark-examples_$scalaVersion-${SPARK_VERSION}.jar" testBackend = IntegrationTestBackendFactory.getTestBackend testBackend.initialize() kubernetesTestComponents = new KubernetesTestComponents(testBackend.getKubernetesClient) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala index 004a942c1cdb3..9ead70f670891 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala @@ -28,7 +28,10 @@ object ProcessUtils extends Logging { * executeProcess is used to run a command and return the output if it * completes within timeout seconds. */ - def executeProcess(fullCommand: Array[String], timeout: Long, dumpErrors: Boolean = false): Seq[String] = { + def executeProcess( + fullCommand: Array[String], + timeout: Long, + dumpErrors: Boolean = false): Seq[String] = { val pb = new ProcessBuilder().command(fullCommand: _*) pb.redirectErrorStream(true) val proc = pb.start() @@ -41,7 +44,8 @@ object ProcessUtils extends Logging { assert(proc.waitFor(timeout, TimeUnit.SECONDS), s"Timed out while executing ${fullCommand.mkString(" ")}") assert(proc.exitValue == 0, - s"Failed to execute ${fullCommand.mkString(" ")}${if (dumpErrors) "\n" + outputLines.mkString("\n")}") + s"Failed to execute ${fullCommand.mkString(" ")}" + + s"${if (dumpErrors) "\n" + outputLines.mkString("\n")}") outputLines } } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala index 06b73107ec236..904279923334f 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala @@ -16,18 +16,14 @@ */ package org.apache.spark.deploy.k8s.integrationtest -import org.apache.spark.deploy.k8s.integrationtest.TestConfig.{getTestImageRepo, getTestImageTag} - private[spark] trait PythonTestsSuite { k8sSuite: KubernetesSuite => import PythonTestsSuite._ import KubernetesSuite.k8sTestTag - private val pySparkDockerImage = - s"${getTestImageRepo}/spark-py:${getTestImageTag}" test("Run PySpark on simple pi.py example", k8sTestTag) { sparkAppConf - .set("spark.kubernetes.container.image", pySparkDockerImage) + .set("spark.kubernetes.container.image", pyImage) runSparkApplicationAndVerifyCompletion( appResource = PYSPARK_PI, mainClass = "", @@ -41,7 +37,7 @@ private[spark] trait PythonTestsSuite { k8sSuite: KubernetesSuite => test("Run PySpark with Python2 to test a pyfiles example", k8sTestTag) { sparkAppConf - .set("spark.kubernetes.container.image", pySparkDockerImage) + .set("spark.kubernetes.container.image", pyImage) .set("spark.kubernetes.pyspark.pythonVersion", "2") runSparkApplicationAndVerifyCompletion( appResource = PYSPARK_FILES, @@ -59,7 +55,7 @@ private[spark] trait PythonTestsSuite { k8sSuite: KubernetesSuite => test("Run PySpark with Python3 to test a pyfiles example", k8sTestTag) { sparkAppConf - .set("spark.kubernetes.container.image", pySparkDockerImage) + .set("spark.kubernetes.container.image", pyImage) .set("spark.kubernetes.pyspark.pythonVersion", "3") runSparkApplicationAndVerifyCompletion( appResource = PYSPARK_FILES, @@ -77,7 +73,7 @@ private[spark] trait PythonTestsSuite { k8sSuite: KubernetesSuite => test("Run PySpark with memory customization", k8sTestTag) { sparkAppConf - .set("spark.kubernetes.container.image", pySparkDockerImage) + .set("spark.kubernetes.container.image", pyImage) .set("spark.kubernetes.pyspark.pythonVersion", "3") .set("spark.kubernetes.memoryOverheadFactor", s"$memOverheadConstant") .set("spark.executor.pyspark.memory", s"${additionalMemory}m") diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/RTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/RTestsSuite.scala index 885a23cfb4864..e81562a923228 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/RTestsSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/RTestsSuite.scala @@ -16,16 +16,13 @@ */ package org.apache.spark.deploy.k8s.integrationtest -import org.apache.spark.deploy.k8s.integrationtest.TestConfig.{getTestImageRepo, getTestImageTag} - private[spark] trait RTestsSuite { k8sSuite: KubernetesSuite => import RTestsSuite._ import KubernetesSuite.k8sTestTag test("Run SparkR on simple dataframe.R example", k8sTestTag) { - sparkAppConf - .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-r:${getTestImageTag}") + sparkAppConf.set("spark.kubernetes.container.image", rImage) runSparkApplicationAndVerifyCompletion( appResource = SPARK_R_DATAFRAME_TEST, mainClass = "", diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConfig.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConfig.scala deleted file mode 100644 index 363ec0a6016bb..0000000000000 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConfig.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.integrationtest - -import java.io.File - -import com.google.common.base.Charsets -import com.google.common.io.Files - -import org.apache.spark.deploy.k8s.integrationtest.TestConstants._ - -object TestConfig { - def getTestImageTag: String = { - val imageTagFileProp = System.getProperty(CONFIG_KEY_IMAGE_TAG_FILE) - require(imageTagFileProp != null, "Image tag file must be provided in system properties.") - val imageTagFile = new File(imageTagFileProp) - require(imageTagFile.isFile, s"No file found for image tag at ${imageTagFile.getAbsolutePath}.") - Files.toString(imageTagFile, Charsets.UTF_8).trim - } - - def getTestImageRepo: String = { - val imageRepo = System.getProperty(CONFIG_KEY_IMAGE_REPO) - require(imageRepo != null, "Image repo must be provided in system properties.") - imageRepo - } -} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConstants.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConstants.scala index eeae70cd68571..ecc4df716330d 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConstants.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/TestConstants.scala @@ -26,7 +26,7 @@ object TestConstants { val CONFIG_KEY_KUBE_MASTER_URL = "spark.kubernetes.test.master" val CONFIG_KEY_KUBE_NAMESPACE = "spark.kubernetes.test.namespace" val CONFIG_KEY_KUBE_SVC_ACCOUNT = "spark.kubernetes.test.serviceAccountName" - val CONFIG_KEY_IMAGE_TAG = "spark.kubernetes.test.imageTagF" + val CONFIG_KEY_IMAGE_TAG = "spark.kubernetes.test.imageTag" val CONFIG_KEY_IMAGE_TAG_FILE = "spark.kubernetes.test.imageTagFile" val CONFIG_KEY_IMAGE_REPO = "spark.kubernetes.test.imageRepo" val CONFIG_KEY_UNPACK_DIR = "spark.kubernetes.test.unpackSparkDir" diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala index 7bf324c6c4a14..56ddae0c9c57c 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.k8s.integrationtest.backend import io.fabric8.kubernetes.client.DefaultKubernetesClient + import org.apache.spark.deploy.k8s.integrationtest.TestConstants._ import org.apache.spark.deploy.k8s.integrationtest.backend.cloud.KubeConfigBackend import org.apache.spark.deploy.k8s.integrationtest.backend.docker.DockerForDesktopBackend @@ -35,7 +36,8 @@ private[spark] object IntegrationTestBackendFactory { .getOrElse(BACKEND_MINIKUBE) deployMode match { case BACKEND_MINIKUBE => MinikubeTestBackend - case BACKEND_CLOUD => new KubeConfigBackend(System.getProperty(CONFIG_KEY_KUBE_CONFIG_CONTEXT)) + case BACKEND_CLOUD => + new KubeConfigBackend(System.getProperty(CONFIG_KEY_KUBE_CONFIG_CONTEXT)) case BACKEND_DOCKER_FOR_DESKTOP => DockerForDesktopBackend case _ => throw new IllegalArgumentException("Invalid " + CONFIG_KEY_DEPLOY_MODE + ": " + deployMode) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/cloud/KubeConfigBackend.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/cloud/KubeConfigBackend.scala index 333526ba3ef98..be1834c0b5dea 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/cloud/KubeConfigBackend.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/cloud/KubeConfigBackend.scala @@ -18,9 +18,10 @@ package org.apache.spark.deploy.k8s.integrationtest.backend.cloud import java.nio.file.Paths -import io.fabric8.kubernetes.client.utils.Utils import io.fabric8.kubernetes.client.{Config, DefaultKubernetesClient} +import io.fabric8.kubernetes.client.utils.Utils import org.apache.commons.lang3.StringUtils + import org.apache.spark.deploy.k8s.integrationtest.TestConstants import org.apache.spark.deploy.k8s.integrationtest.backend.IntegrationTestBackend import org.apache.spark.internal.Logging @@ -38,7 +39,7 @@ private[spark] class KubeConfigBackend(var context: String) // Auto-configure K8S client from K8S config file if (Utils.getSystemPropertyOrEnvVar(Config.KUBERNETES_KUBECONFIG_FILE, null: String) == null) { // Fabric 8 client will automatically assume a default location in this case - logWarning(s"No explicit KUBECONFIG specified, will assume .kube/config under your home directory") + logWarning("No explicit KUBECONFIG specified, will assume $HOME/.kube/config") } val config = Config.autoConfigure(context) From a8e1c9815fef0deb45c9a516d415cea6be511415 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 8 Nov 2018 12:26:21 +0800 Subject: [PATCH 2014/2461] [SPARK-25962][BUILD][PYTHON] Specify minimum versions for both pydocstyle and flake8 in 'lint-python' script ## What changes were proposed in this pull request? This PR explicitly specifies `flake8` and `pydocstyle` versions. - It checks flake8 binary executable - flake8 version check >= 3.5.0 - pydocstyle >= 3.0.0 (previously it was == 3.0.0) ## How was this patch tested? Manually tested. Closes #22963 from HyukjinKwon/SPARK-25962. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- dev/lint-python | 58 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/dev/lint-python b/dev/lint-python index 2e353e142c143..27d87f6b56680 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -26,9 +26,13 @@ PYCODESTYLE_REPORT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-report.txt" PYDOCSTYLE_REPORT_PATH="$SPARK_ROOT_DIR/dev/pydocstyle-report.txt" PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt" + PYDOCSTYLEBUILD="pydocstyle" -EXPECTED_PYDOCSTYLEVERSION="3.0.0" -PYDOCSTYLEVERSION=$(python -c 'import pkg_resources; print(pkg_resources.get_distribution("pydocstyle").version)' 2> /dev/null) +MINIMUM_PYDOCSTYLEVERSION="3.0.0" + +FLAKE8BUILD="flake8" +MINIMUM_FLAKE8="3.5.0" + SPHINXBUILD=${SPHINXBUILD:=sphinx-build} SPHINX_REPORT_PATH="$SPARK_ROOT_DIR/dev/sphinx-report.txt" @@ -87,27 +91,47 @@ else rm "$PYCODESTYLE_REPORT_PATH" fi -# stop the build if there are Python syntax errors or undefined names -flake8 . --count --select=E901,E999,F821,F822,F823 --max-line-length=100 --show-source --statistics -flake8_status="${PIPESTATUS[0]}" +# Check by flake8 +if hash "$FLAKE8BUILD" 2> /dev/null; then + FLAKE8VERSION="$( $FLAKE8BUILD --version 2> /dev/null )" + VERSION=($FLAKE8VERSION) + IS_EXPECTED_FLAKE8=$(python -c 'from distutils.version import LooseVersion; \ +print(LooseVersion("""'${VERSION[0]}'""") >= LooseVersion("""'$MINIMUM_FLAKE8'"""))' 2> /dev/null) + if [[ "$IS_EXPECTED_FLAKE8" == "True" ]]; then + # stop the build if there are Python syntax errors or undefined names + $FLAKE8BUILD . --count --select=E901,E999,F821,F822,F823 --max-line-length=100 --show-source --statistics + flake8_status="${PIPESTATUS[0]}" + + if [ "$flake8_status" -eq 0 ]; then + lint_status=0 + else + lint_status=1 + fi -if [ "$flake8_status" -eq 0 ]; then - lint_status=0 + if [ "$lint_status" -ne 0 ]; then + echo "flake8 checks failed." + exit "$lint_status" + else + echo "flake8 checks passed." + fi + else + echo "The flake8 version needs to be "$MINIMUM_FLAKE8" at latest. Your current version is '"$FLAKE8VERSION"'." + echo "flake8 checks failed." + exit 1 + fi else - lint_status=1 -fi - -if [ "$lint_status" -ne 0 ]; then + echo >&2 "The flake8 command was not found." echo "flake8 checks failed." - exit "$lint_status" -else - echo "flake8 checks passed." + exit 1 fi # Check python document style, skip check if pydocstyle is not installed. if hash "$PYDOCSTYLEBUILD" 2> /dev/null; then - if [[ "$PYDOCSTYLEVERSION" == "$EXPECTED_PYDOCSTYLEVERSION" ]]; then - pydocstyle --config=dev/tox.ini $DOC_PATHS_TO_CHECK >> "$PYDOCSTYLE_REPORT_PATH" + PYDOCSTYLEVERSION="$( $PYDOCSTYLEBUILD --version 2> /dev/null )" + IS_EXPECTED_PYDOCSTYLEVERSION=$(python -c 'from distutils.version import LooseVersion; \ +print(LooseVersion("""'$PYDOCSTYLEVERSION'""") >= LooseVersion("""'$MINIMUM_PYDOCSTYLEVERSION'"""))') + if [[ "$IS_EXPECTED_PYDOCSTYLEVERSION" == "True" ]]; then + $PYDOCSTYLEBUILD --config=dev/tox.ini $DOC_PATHS_TO_CHECK >> "$PYDOCSTYLE_REPORT_PATH" pydocstyle_status="${PIPESTATUS[0]}" if [ "$compile_status" -eq 0 -a "$pydocstyle_status" -eq 0 ]; then @@ -121,7 +145,7 @@ if hash "$PYDOCSTYLEBUILD" 2> /dev/null; then fi else - echo "The pydocstyle version needs to be latest 3.0.0. Skipping pydoc checks for now" + echo "The pydocstyle version needs to be "$MINIMUM_PYDOCSTYLEVERSION" at latest. Your current version is "$PYDOCSTYLEVERSION". Skipping pydoc checks for now." fi else echo >&2 "The pydocstyle command was not found. Skipping pydoc checks for now" From 0025a8397f8723011917239fe47518457d4d6860 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 7 Nov 2018 22:48:50 -0600 Subject: [PATCH 2015/2461] [SPARK-25908][CORE][SQL] Remove old deprecated items in Spark 3 ## What changes were proposed in this pull request? - Remove some AccumulableInfo .apply() methods - Remove non-label-specific multiclass precision/recall/fScore in favor of accuracy - Remove toDegrees/toRadians in favor of degrees/radians (SparkR: only deprecated) - Remove approxCountDistinct in favor of approx_count_distinct (SparkR: only deprecated) - Remove unused Python StorageLevel constants - Remove Dataset unionAll in favor of union - Remove unused multiclass option in libsvm parsing - Remove references to deprecated spark configs like spark.yarn.am.port - Remove TaskContext.isRunningLocally - Remove ShuffleMetrics.shuffle* methods - Remove BaseReadWrite.context in favor of session - Remove Column.!== in favor of =!= - Remove Dataset.explode - Remove Dataset.registerTempTable - Remove SQLContext.getOrCreate, setActive, clearActive, constructors Not touched yet - everything else in MLLib - HiveContext - Anything deprecated more recently than 2.0.0, generally ## How was this patch tested? Existing tests Closes #22921 from srowen/SPARK-25908. Lead-authored-by: Sean Owen Co-authored-by: hyukjinkwon Co-authored-by: Sean Owen Signed-off-by: Sean Owen --- R/pkg/NAMESPACE | 3 + R/pkg/R/functions.R | 73 ++++++++++-- R/pkg/R/generics.R | 12 ++ R/pkg/tests/fulltests/test_sparkSQL.R | 4 +- .../org/apache/spark/BarrierTaskContext.scala | 2 - .../scala/org/apache/spark/TaskContext.scala | 7 -- .../org/apache/spark/TaskContextImpl.scala | 2 - .../spark/executor/ShuffleWriteMetrics.scala | 10 -- .../spark/scheduler/AccumulableInfo.scala | 30 ----- .../org/apache/spark/ml/util/ReadWrite.scala | 22 ---- .../mllib/evaluation/MulticlassMetrics.scala | 25 ---- .../evaluation/MulticlassMetricsSuite.scala | 3 - project/MimaExcludes.scala | 18 +++ python/pyspark/mllib/evaluation.py | 38 ++---- python/pyspark/mllib/util.py | 8 +- python/pyspark/sql/dataframe.py | 33 ------ python/pyspark/sql/functions.py | 11 -- python/pyspark/sql/tests.py | 2 +- python/pyspark/storagelevel.py | 13 --- .../expressions/MathExpressionsSuite.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 18 --- .../scala/org/apache/spark/sql/Dataset.scala | 110 ------------------ .../org/apache/spark/sql/SQLContext.scala | 50 +------- .../sql/expressions/scalalang/typed.scala | 2 +- .../org/apache/spark/sql/functions.scala | 79 ------------- .../org/apache/spark/sql/DataFrameSuite.scala | 53 ++------- .../apache/spark/sql/SQLContextSuite.scala | 30 +---- 27 files changed, 132 insertions(+), 528 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 9d4f05af75afd..de56061b4c1c7 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -194,6 +194,7 @@ exportMethods("%<=>%", "acos", "add_months", "alias", + "approx_count_distinct", "approxCountDistinct", "approxQuantile", "array_contains", @@ -252,6 +253,7 @@ exportMethods("%<=>%", "dayofweek", "dayofyear", "decode", + "degrees", "dense_rank", "desc", "element_at", @@ -334,6 +336,7 @@ exportMethods("%<=>%", "posexplode", "posexplode_outer", "quarter", + "radians", "rand", "randn", "rank", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 9292363d1ad2f..9abb7fc1fadb4 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -112,7 +112,7 @@ NULL #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) #' tmp <- mutate(df, v1 = log(df$mpg), v2 = cbrt(df$disp), #' v3 = bround(df$wt, 1), v4 = bin(df$cyl), -#' v5 = hex(df$wt), v6 = toDegrees(df$gear), +#' v5 = hex(df$wt), v6 = degrees(df$gear), #' v7 = atan2(df$cyl, df$am), v8 = hypot(df$cyl, df$am), #' v9 = pmod(df$hp, df$cyl), v10 = shiftLeft(df$disp, 1), #' v11 = conv(df$hp, 10, 16), v12 = sign(df$vs - 0.5), @@ -320,23 +320,37 @@ setMethod("acos", }) #' @details -#' \code{approxCountDistinct}: Returns the approximate number of distinct items in a group. +#' \code{approx_count_distinct}: Returns the approximate number of distinct items in a group. #' #' @rdname column_aggregate_functions -#' @aliases approxCountDistinct approxCountDistinct,Column-method +#' @aliases approx_count_distinct approx_count_distinct,Column-method #' @examples #' #' \dontrun{ -#' head(select(df, approxCountDistinct(df$gear))) -#' head(select(df, approxCountDistinct(df$gear, 0.02))) +#' head(select(df, approx_count_distinct(df$gear))) +#' head(select(df, approx_count_distinct(df$gear, 0.02))) #' head(select(df, countDistinct(df$gear, df$cyl))) #' head(select(df, n_distinct(df$gear))) #' head(distinct(select(df, "gear")))} +#' @note approx_count_distinct(Column) since 3.0.0 +setMethod("approx_count_distinct", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "approx_count_distinct", x@jc) + column(jc) + }) + +#' @details +#' \code{approxCountDistinct}: Returns the approximate number of distinct items in a group. +#' +#' @rdname column_aggregate_functions +#' @aliases approxCountDistinct approxCountDistinct,Column-method #' @note approxCountDistinct(Column) since 1.4.0 setMethod("approxCountDistinct", signature(x = "Column"), function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc) + .Deprecated("approx_count_distinct") + jc <- callJStatic("org.apache.spark.sql.functions", "approx_count_distinct", x@jc) column(jc) }) @@ -1651,7 +1665,22 @@ setMethod("tanh", setMethod("toDegrees", signature(x = "Column"), function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "toDegrees", x@jc) + .Deprecated("degrees") + jc <- callJStatic("org.apache.spark.sql.functions", "degrees", x@jc) + column(jc) + }) + +#' @details +#' \code{degrees}: Converts an angle measured in radians to an approximately equivalent angle +#' measured in degrees. +#' +#' @rdname column_math_functions +#' @aliases degrees degrees,Column-method +#' @note degrees since 3.0.0 +setMethod("degrees", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "degrees", x@jc) column(jc) }) @@ -1665,7 +1694,22 @@ setMethod("toDegrees", setMethod("toRadians", signature(x = "Column"), function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "toRadians", x@jc) + .Deprecated("radians") + jc <- callJStatic("org.apache.spark.sql.functions", "radians", x@jc) + column(jc) + }) + +#' @details +#' \code{radians}: Converts an angle measured in degrees to an approximately equivalent angle +#' measured in radians. +#' +#' @rdname column_math_functions +#' @aliases radians radians,Column-method +#' @note radians since 3.0.0 +setMethod("radians", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "radians", x@jc) column(jc) }) @@ -2065,13 +2109,24 @@ setMethod("pmod", signature(y = "Column"), #' @param rsd maximum estimation error allowed (default = 0.05). #' +#' @rdname column_aggregate_functions +#' @aliases approx_count_distinct,Column-method +#' @note approx_count_distinct(Column, numeric) since 3.0.0 +setMethod("approx_count_distinct", + signature(x = "Column"), + function(x, rsd = 0.05) { + jc <- callJStatic("org.apache.spark.sql.functions", "approx_count_distinct", x@jc, rsd) + column(jc) + }) + #' @rdname column_aggregate_functions #' @aliases approxCountDistinct,Column-method #' @note approxCountDistinct(Column, numeric) since 1.4.0 setMethod("approxCountDistinct", signature(x = "Column"), function(x, rsd = 0.05) { - jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) + .Deprecated("approx_count_distinct") + jc <- callJStatic("org.apache.spark.sql.functions", "approx_count_distinct", x@jc, rsd) column(jc) }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 463102c780b52..cbed276274ac1 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -746,6 +746,10 @@ setGeneric("windowOrderBy", function(col, ...) { standardGeneric("windowOrderBy" #' @name NULL setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) +#' @rdname column_aggregate_functions +#' @name NULL +setGeneric("approx_count_distinct", function(x, ...) { standardGeneric("approx_count_distinct") }) + #' @rdname column_aggregate_functions #' @name NULL setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) @@ -1287,10 +1291,18 @@ setGeneric("substring_index", function(x, delim, count) { standardGeneric("subst #' @name NULL setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) +#' @rdname column_math_functions +#' @name NULL +setGeneric("degrees", function(x) { standardGeneric("degrees") }) + #' @rdname column_math_functions #' @name NULL setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) +#' @rdname column_math_functions +#' @name NULL +setGeneric("radians", function(x) { standardGeneric("radians") }) + #' @rdname column_math_functions #' @name NULL setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index faec387ce4eff..059c9f3057242 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1379,7 +1379,7 @@ test_that("column operators", { test_that("column functions", { c <- column("a") - c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c) + c1 <- abs(c) + acos(c) + approx_count_distinct(c) + ascii(c) + asin(c) + atan(c) c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) c3 <- cosh(c) + count(c) + crc32(c) + hash(c) + exp(c) c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c) @@ -1388,7 +1388,7 @@ test_that("column functions", { c7 <- mean(c) + min(c) + month(c) + negate(c) + posexplode(c) + quarter(c) c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) + monotonically_increasing_id() c9 <- signum(c) + sin(c) + sinh(c) + size(c) + stddev(c) + soundex(c) + sqrt(c) + sum(c) - c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c) + c10 <- sumDistinct(c) + tan(c) + tanh(c) + degrees(c) + radians(c) c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) c12 <- variance(c) + ltrim(c, "a") + rtrim(c, "b") + trim(c, "c") c13 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1) diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index 7ce421e5479ee..6a497afac444d 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -158,8 +158,6 @@ class BarrierTaskContext private[spark] ( override def isInterrupted(): Boolean = taskContext.isInterrupted() - override def isRunningLocally(): Boolean = taskContext.isRunningLocally() - override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { taskContext.addTaskCompletionListener(listener) this diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 2b939dabb1105..959f246f3f9f6 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -96,13 +96,6 @@ abstract class TaskContext extends Serializable { */ def isInterrupted(): Boolean - /** - * Returns true if the task is running locally in the driver program. - * @return false - */ - @deprecated("Local execution was removed, so this always returns false", "2.0.0") - def isRunningLocally(): Boolean - /** * Adds a (Java friendly) listener to be executed on task completion. * This will be called in all situations - success, failure, or cancellation. Adding a listener diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 89730424e5acf..76296c5d0abd3 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -157,8 +157,6 @@ private[spark] class TaskContextImpl( @GuardedBy("this") override def isCompleted(): Boolean = synchronized(completed) - override def isRunningLocally(): Boolean = false - override def isInterrupted(): Boolean = reasonIfKilled.isDefined override def getLocalProperty(key: String): String = localProperties.getProperty(key) diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala index ada2e1bc08593..0c9da657c2b60 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala @@ -56,14 +56,4 @@ class ShuffleWriteMetrics private[spark] () extends Serializable { private[spark] def decRecordsWritten(v: Long): Unit = { _recordsWritten.setValue(recordsWritten - v) } - - // Legacy methods for backward compatibility. - // TODO: remove these once we make this class private. - @deprecated("use bytesWritten instead", "2.0.0") - def shuffleBytesWritten: Long = bytesWritten - @deprecated("use writeTime instead", "2.0.0") - def shuffleWriteTime: Long = writeTime - @deprecated("use recordsWritten instead", "2.0.0") - def shuffleRecordsWritten: Long = recordsWritten - } diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index d745345f4e0d2..bd0fe90b1f3b6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -47,33 +47,3 @@ case class AccumulableInfo private[spark] ( private[spark] val countFailedValues: Boolean, // TODO: use this to identify internal task metrics instead of encoding it in the name private[spark] val metadata: Option[String] = None) - - -/** - * A collection of deprecated constructors. This will be removed soon. - */ -object AccumulableInfo { - - @deprecated("do not create AccumulableInfo", "2.0.0") - def apply( - id: Long, - name: String, - update: Option[String], - value: String, - internal: Boolean): AccumulableInfo = { - new AccumulableInfo( - id, Option(name), update, Option(value), internal, countFailedValues = false) - } - - @deprecated("do not create AccumulableInfo", "2.0.0") - def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = { - new AccumulableInfo( - id, Option(name), update, Option(value), internal = false, countFailedValues = false) - } - - @deprecated("do not create AccumulableInfo", "2.0.0") - def apply(id: Long, name: String, value: String): AccumulableInfo = { - new AccumulableInfo( - id, Option(name), None, Option(value), internal = false, countFailedValues = false) - } -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 72a60e04360d6..a0ac26a34d8c8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -47,18 +47,6 @@ import org.apache.spark.util.{Utils, VersionUtils} private[util] sealed trait BaseReadWrite { private var optionSparkSession: Option[SparkSession] = None - /** - * Sets the Spark SQLContext to use for saving/loading. - * - * @deprecated Use session instead. This method will be removed in 3.0.0. - */ - @Since("1.6.0") - @deprecated("Use session instead. This method will be removed in 3.0.0.", "2.0.0") - def context(sqlContext: SQLContext): this.type = { - optionSparkSession = Option(sqlContext.sparkSession) - this - } - /** * Sets the Spark Session to use for saving/loading. */ @@ -215,10 +203,6 @@ abstract class MLWriter extends BaseReadWrite with Logging { // override for Java compatibility @Since("1.6.0") override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) - - // override for Java compatibility - @Since("1.6.0") - override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) } /** @@ -281,9 +265,6 @@ class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { // override for Java compatibility override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) - - // override for Java compatibility - override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) } /** @@ -352,9 +333,6 @@ abstract class MLReader[T] extends BaseReadWrite { // override for Java compatibility override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) - - // override for Java compatibility - override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 9a6a8dbdccbf3..980e0c92531a2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -134,31 +134,6 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl @Since("1.1.0") def fMeasure(label: Double): Double = fMeasure(label, 1.0) - /** - * Returns precision - */ - @Since("1.1.0") - @deprecated("Use accuracy.", "2.0.0") - lazy val precision: Double = accuracy - - /** - * Returns recall - * (equals to precision for multiclass classifier - * because sum of all false positives is equal to sum - * of all false negatives) - */ - @Since("1.1.0") - @deprecated("Use accuracy.", "2.0.0") - lazy val recall: Double = accuracy - - /** - * Returns f-measure - * (equals to precision and recall because precision equals recall) - */ - @Since("1.1.0") - @deprecated("Use accuracy.", "2.0.0") - lazy val fMeasure: Double = accuracy - /** * Returns accuracy * (equals to the total number of correctly classified instances diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index 142d1e9812ef1..5394baab94bcf 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -77,9 +77,6 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { assert(math.abs(metrics.accuracy - (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta) - assert(math.abs(metrics.accuracy - metrics.precision) < delta) - assert(math.abs(metrics.accuracy - metrics.recall) < delta) - assert(math.abs(metrics.accuracy - metrics.fMeasure) < delta) assert(math.abs(metrics.accuracy - metrics.weightedRecall) < delta) assert(math.abs(metrics.weightedTruePositiveRate - ((4.0 / 9) * tpRate0 + (4.0 / 9) * tpRate1 + (1.0 / 9) * tpRate2)) < delta) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 350d8ad6942ff..b6bd6b82d94fd 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,24 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskContext.isRunningLocally"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.shuffleBytesWritten"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.shuffleWriteTime"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.shuffleRecordsWritten"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.AccumulableInfo.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.approxCountDistinct"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.toRadians"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.toDegrees"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.monotonicallyIncreasingId"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLContext.clearActive"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLContext.getOrCreate"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLContext.setActive"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.SQLContext.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.evaluation.MulticlassMetrics.fMeasure"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.evaluation.MulticlassMetrics.recall"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.evaluation.MulticlassMetrics.precision"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.MLWriter.context"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.MLReader.context"), // [SPARK-25737] Remove JavaSparkContextVarargsWorkaround ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.api.java.JavaSparkContext"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.union"), diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 0bb0ca37c1ab6..b171e46871fdf 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -229,46 +229,28 @@ def falsePositiveRate(self, label): return self.call("falsePositiveRate", label) @since('1.4.0') - def precision(self, label=None): + def precision(self, label): """ - Returns precision or precision for a given label (category) if specified. + Returns precision. """ - if label is None: - # note:: Deprecated in 2.0.0. Use accuracy. - warnings.warn("Deprecated in 2.0.0. Use accuracy.", DeprecationWarning) - return self.call("precision") - else: - return self.call("precision", float(label)) + return self.call("precision", float(label)) @since('1.4.0') - def recall(self, label=None): + def recall(self, label): """ - Returns recall or recall for a given label (category) if specified. + Returns recall. """ - if label is None: - # note:: Deprecated in 2.0.0. Use accuracy. - warnings.warn("Deprecated in 2.0.0. Use accuracy.", DeprecationWarning) - return self.call("recall") - else: - return self.call("recall", float(label)) + return self.call("recall", float(label)) @since('1.4.0') - def fMeasure(self, label=None, beta=None): + def fMeasure(self, label, beta=None): """ - Returns f-measure or f-measure for a given label (category) if specified. + Returns f-measure. """ if beta is None: - if label is None: - # note:: Deprecated in 2.0.0. Use accuracy. - warnings.warn("Deprecated in 2.0.0. Use accuracy.", DeprecationWarning) - return self.call("fMeasure") - else: - return self.call("fMeasure", label) + return self.call("fMeasure", label) else: - if label is None: - raise Exception("If the beta parameter is specified, label can not be none") - else: - return self.call("fMeasure", label, beta) + return self.call("fMeasure", label, beta) @property @since('2.0.0') diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index fc7809387b13a..51f20db2927e2 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -38,12 +38,10 @@ class MLUtils(object): """ @staticmethod - def _parse_libsvm_line(line, multiclass=None): + def _parse_libsvm_line(line): """ Parses a line in LIBSVM format into (label, indices, values). """ - if multiclass is not None: - warnings.warn("deprecated", DeprecationWarning) items = line.split(None) label = float(items[0]) nnz = len(items) - 1 @@ -73,7 +71,7 @@ def _convert_labeled_point_to_libsvm(p): @staticmethod @since("1.0.0") - def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None): + def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None): """ Loads labeled data in the LIBSVM format into an RDD of LabeledPoint. The LIBSVM format is a text-based format used by @@ -116,8 +114,6 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None LabeledPoint(-1.0, (6,[1,3,5],[4.0,5.0,6.0])) """ from pyspark.mllib.regression import LabeledPoint - if multiclass is not None: - warnings.warn("deprecated", DeprecationWarning) lines = sc.textFile(path, minPartitions) parsed = lines.map(lambda l: MLUtils._parse_libsvm_line(l)) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index bf6b990487617..5748f6c6bd5eb 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -119,25 +119,6 @@ def toJSON(self, use_unicode=True): rdd = self._jdf.toJSON() return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) - @since(1.3) - def registerTempTable(self, name): - """Registers this DataFrame as a temporary table using the given name. - - The lifetime of this temporary table is tied to the :class:`SparkSession` - that was used to create this :class:`DataFrame`. - - >>> df.registerTempTable("people") - >>> df2 = spark.sql("select * from people") - >>> sorted(df.collect()) == sorted(df2.collect()) - True - >>> spark.catalog.dropTempView("people") - - .. note:: Deprecated in 2.0, use createOrReplaceTempView instead. - """ - warnings.warn( - "Deprecated in 2.0, use createOrReplaceTempView instead.", DeprecationWarning) - self._jdf.createOrReplaceTempView(name) - @since(2.0) def createTempView(self, name): """Creates a local temporary view with this DataFrame. @@ -1462,20 +1443,6 @@ def union(self, other): """ return DataFrame(self._jdf.union(other._jdf), self.sql_ctx) - @since(1.3) - def unionAll(self, other): - """ Return a new :class:`DataFrame` containing union of rows in this and another frame. - - This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union - (that does deduplication of elements), use this function followed by :func:`distinct`. - - Also as standard in SQL, this function resolves columns by position (not by name). - - .. note:: Deprecated in 2.0, use :func:`union` instead. - """ - warnings.warn("Deprecated in 2.0, use union instead.", DeprecationWarning) - return self.union(other) - @since(2.3) def unionByName(self, other): """ Returns a new :class:`DataFrame` containing union of rows in this and another frame. diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 24824efb47362..e86749cc15c35 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -252,8 +252,6 @@ def _(): # Wraps deprecated functions (keys) with the messages (values). _functions_deprecated = { - 'toDegrees': 'Deprecated in 2.1, use degrees instead.', - 'toRadians': 'Deprecated in 2.1, use radians instead.', } for _name, _doc in _functions.items(): @@ -275,15 +273,6 @@ def _(): del _name, _doc -@since(1.3) -def approxCountDistinct(col, rsd=None): - """ - .. note:: Deprecated in 2.1, use :func:`approx_count_distinct` instead. - """ - warnings.warn("Deprecated in 2.1, use approx_count_distinct instead.", DeprecationWarning) - return approx_count_distinct(col, rsd) - - @since(2.1) def approx_count_distinct(col, rsd=None): """Aggregate function: returns a new :class:`Column` for approximate distinct count of diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ad04270c1a361..ea0269162d62a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1610,7 +1610,7 @@ def test_aggregator(self): from pyspark.sql import functions self.assertEqual((0, u'99'), tuple(g.agg(functions.first(df.key), functions.last(df.value)).first())) - self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0]) + self.assertTrue(95 < g.agg(functions.approx_count_distinct(df.key)).first()[0]) self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) def test_first_last_ignorenulls(self): diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py index 7f29646c07432..951af45bb3227 100644 --- a/python/pyspark/storagelevel.py +++ b/python/pyspark/storagelevel.py @@ -56,16 +56,3 @@ def __str__(self): StorageLevel.MEMORY_AND_DISK = StorageLevel(True, True, False, False) StorageLevel.MEMORY_AND_DISK_2 = StorageLevel(True, True, False, False, 2) StorageLevel.OFF_HEAP = StorageLevel(True, True, True, False, 1) - -""" -.. note:: The following four storage level constants are deprecated in 2.0, since the records - will always be serialized in Python. -""" -StorageLevel.MEMORY_ONLY_SER = StorageLevel.MEMORY_ONLY -""".. note:: Deprecated in 2.0, use ``StorageLevel.MEMORY_ONLY`` instead.""" -StorageLevel.MEMORY_ONLY_SER_2 = StorageLevel.MEMORY_ONLY_2 -""".. note:: Deprecated in 2.0, use ``StorageLevel.MEMORY_ONLY_2`` instead.""" -StorageLevel.MEMORY_AND_DISK_SER = StorageLevel.MEMORY_AND_DISK -""".. note:: Deprecated in 2.0, use ``StorageLevel.MEMORY_AND_DISK`` instead.""" -StorageLevel.MEMORY_AND_DISK_SER_2 = StorageLevel.MEMORY_AND_DISK_2 -""".. note:: Deprecated in 2.0, use ``StorageLevel.MEMORY_AND_DISK_2`` instead.""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 3a094079380fd..48105571b2798 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -246,7 +246,7 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("toDegrees") { testUnary(ToDegrees, math.toDegrees) - checkConsistencyBetweenInterpretedAndCodegen(Acos, DoubleType) + checkConsistencyBetweenInterpretedAndCodegen(ToDegrees, DoubleType) } test("toRadians") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index a046127c3edb4..a9a19aa8a1001 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -302,24 +302,6 @@ class Column(val expr: Expression) extends Logging { */ def =!= (other: Any): Column = withExpr{ Not(EqualTo(expr, lit(other).expr)) } - /** - * Inequality test. - * {{{ - * // Scala: - * df.select( df("colA") !== df("colB") ) - * df.select( !(df("colA") === df("colB")) ) - * - * // Java: - * import static org.apache.spark.sql.functions.*; - * df.filter( col("colA").notEqual(col("colB")) ); - * }}} - * - * @group expr_ops - * @since 1.3.0 - */ - @deprecated("!== does not have the same precedence as ===, use =!= instead", "2.0.0") - def !== (other: Any): Column = this =!= other - /** * Inequality test. * {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d53400512cb84..f98eaa3d4eb90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1809,20 +1809,6 @@ class Dataset[T] private[sql]( Limit(Literal(n), logicalPlan) } - /** - * Returns a new Dataset containing union of rows in this Dataset and another Dataset. - * - * This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does - * deduplication of elements), use this function followed by a [[distinct]]. - * - * Also as standard in SQL, this function resolves columns by position (not by name). - * - * @group typedrel - * @since 2.0.0 - */ - @deprecated("use union()", "2.0.0") - def unionAll(other: Dataset[T]): Dataset[T] = union(other) - /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. * @@ -2122,90 +2108,6 @@ class Dataset[T] private[sql]( randomSplit(weights.toArray, seed) } - /** - * (Scala-specific) Returns a new Dataset where each row has been expanded to zero or more - * rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of - * the input row are implicitly joined with each row that is output by the function. - * - * Given that this is deprecated, as an alternative, you can explode columns either using - * `functions.explode()` or `flatMap()`. The following example uses these alternatives to count - * the number of books that contain a given word: - * - * {{{ - * case class Book(title: String, words: String) - * val ds: Dataset[Book] - * - * val allWords = ds.select('title, explode(split('words, " ")).as("word")) - * - * val bookCountPerWord = allWords.groupBy("word").agg(countDistinct("title")) - * }}} - * - * Using `flatMap()` this can similarly be exploded as: - * - * {{{ - * ds.flatMap(_.words.split(" ")) - * }}} - * - * @group untypedrel - * @since 2.0.0 - */ - @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0") - def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { - val elementSchema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] - - val convert = CatalystTypeConverters.createToCatalystConverter(elementSchema) - - val rowFunction = - f.andThen(_.map(convert(_).asInstanceOf[InternalRow])) - val generator = UserDefinedGenerator(elementSchema, rowFunction, input.map(_.expr)) - - withPlan { - Generate(generator, unrequiredChildIndex = Nil, outer = false, - qualifier = None, generatorOutput = Nil, logicalPlan) - } - } - - /** - * (Scala-specific) Returns a new Dataset where a single column has been expanded to zero - * or more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All - * columns of the input row are implicitly joined with each value that is output by the function. - * - * Given that this is deprecated, as an alternative, you can explode columns either using - * `functions.explode()`: - * - * {{{ - * ds.select(explode(split('words, " ")).as("word")) - * }}} - * - * or `flatMap()`: - * - * {{{ - * ds.flatMap(_.words.split(" ")) - * }}} - * - * @group untypedrel - * @since 2.0.0 - */ - @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0") - def explode[A, B : TypeTag](inputColumn: String, outputColumn: String)(f: A => TraversableOnce[B]) - : DataFrame = { - val dataType = ScalaReflection.schemaFor[B].dataType - val attributes = AttributeReference(outputColumn, dataType)() :: Nil - // TODO handle the metadata? - val elementSchema = attributes.toStructType - - def rowFunction(row: Row): TraversableOnce[InternalRow] = { - val convert = CatalystTypeConverters.createToCatalystConverter(dataType) - f(row(0).asInstanceOf[A]).map(o => InternalRow(convert(o))) - } - val generator = UserDefinedGenerator(elementSchema, rowFunction, apply(inputColumn).expr :: Nil) - - withPlan { - Generate(generator, unrequiredChildIndex = Nil, outer = false, - qualifier = None, generatorOutput = Nil, logicalPlan) - } - } - /** * Returns a new Dataset by adding a column or replacing the existing column that has * the same name. @@ -3053,18 +2955,6 @@ class Dataset[T] private[sql]( */ def javaRDD: JavaRDD[T] = toJavaRDD - /** - * Registers this Dataset as a temporary table using the given name. The lifetime of this - * temporary table is tied to the [[SparkSession]] that was used to create this Dataset. - * - * @group basic - * @since 1.6.0 - */ - @deprecated("Use createOrReplaceTempView(viewName) instead.", "2.0.0") - def registerTempTable(tableName: String): Unit = { - createOrReplaceTempView(tableName) - } - /** * Creates a local temporary view using the given name. The lifetime of this * temporary view is tied to the [[SparkSession]] that was used to create this Dataset. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 1b7e969a7192e..9982b60fefe60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -24,7 +24,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} -import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.internal.config.ConfigEntry import org.apache.spark.rdd.RDD @@ -64,15 +64,6 @@ class SQLContext private[sql](val sparkSession: SparkSession) // Note: Since Spark 2.0 this class has become a wrapper of SparkSession, where the // real functionality resides. This class remains mainly for backward compatibility. - - @deprecated("Use SparkSession.builder instead", "2.0.0") - def this(sc: SparkContext) = { - this(SparkSession.builder().sparkContext(sc).getOrCreate()) - } - - @deprecated("Use SparkSession.builder instead", "2.0.0") - def this(sparkContext: JavaSparkContext) = this(sparkContext.sc) - // TODO: move this logic into SparkSession private[sql] def sessionState: SessionState = sparkSession.sessionState @@ -767,45 +758,6 @@ class SQLContext private[sql](val sparkSession: SparkSession) */ object SQLContext { - /** - * Get the singleton SQLContext if it exists or create a new one using the given SparkContext. - * - * This function can be used to create a singleton SQLContext object that can be shared across - * the JVM. - * - * If there is an active SQLContext for current thread, it will be returned instead of the global - * one. - * - * @since 1.5.0 - */ - @deprecated("Use SparkSession.builder instead", "2.0.0") - def getOrCreate(sparkContext: SparkContext): SQLContext = { - SparkSession.builder().sparkContext(sparkContext).getOrCreate().sqlContext - } - - /** - * Changes the SQLContext that will be returned in this thread and its children when - * SQLContext.getOrCreate() is called. This can be used to ensure that a given thread receives - * a SQLContext with an isolated session, instead of the global (first created) context. - * - * @since 1.6.0 - */ - @deprecated("Use SparkSession.setActiveSession instead", "2.0.0") - def setActive(sqlContext: SQLContext): Unit = { - SparkSession.setActiveSession(sqlContext.sparkSession) - } - - /** - * Clears the active SQLContext for current thread. Subsequent calls to getOrCreate will - * return the first created context instead of a thread-local override. - * - * @since 1.6.0 - */ - @deprecated("Use SparkSession.clearActiveSession instead", "2.0.0") - def clearActive(): Unit = { - SparkSession.clearActiveSession() - } - /** * Converts an iterator of Java Beans to InternalRow using the provided * bean info & schema. This is not related to the singleton, but is a static diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala index 650ffd4586592..3e637d594caf3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala @@ -80,7 +80,7 @@ object typed { // TODO: // stddevOf: Double // varianceOf: Double - // approxCountDistinct: Long + // approx_count_distinct: Long // minOf: T // maxOf: T diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6bb1a490d8c3a..b2a6e22cbfc86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -206,36 +206,6 @@ object functions { // Aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// - /** - * @group agg_funcs - * @since 1.3.0 - */ - @deprecated("Use approx_count_distinct", "2.1.0") - def approxCountDistinct(e: Column): Column = approx_count_distinct(e) - - /** - * @group agg_funcs - * @since 1.3.0 - */ - @deprecated("Use approx_count_distinct", "2.1.0") - def approxCountDistinct(columnName: String): Column = approx_count_distinct(columnName) - - /** - * @group agg_funcs - * @since 1.3.0 - */ - @deprecated("Use approx_count_distinct", "2.1.0") - def approxCountDistinct(e: Column, rsd: Double): Column = approx_count_distinct(e, rsd) - - /** - * @group agg_funcs - * @since 1.3.0 - */ - @deprecated("Use approx_count_distinct", "2.1.0") - def approxCountDistinct(columnName: String, rsd: Double): Column = { - approx_count_distinct(Column(columnName), rsd) - } - /** * Aggregate function: returns the approximate number of distinct items in a group. * @@ -1114,27 +1084,6 @@ object functions { */ def isnull(e: Column): Column = withExpr { IsNull(e.expr) } - /** - * A column expression that generates monotonically increasing 64-bit integers. - * - * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. - * The current implementation puts the partition ID in the upper 31 bits, and the record number - * within each partition in the lower 33 bits. The assumption is that the data frame has - * less than 1 billion partitions, and each partition has less than 8 billion records. - * - * As an example, consider a `DataFrame` with two partitions, each with 3 records. - * This expression would return the following IDs: - * - * {{{ - * 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. - * }}} - * - * @group normal_funcs - * @since 1.4.0 - */ - @deprecated("Use monotonically_increasing_id()", "2.0.0") - def monotonicallyIncreasingId(): Column = monotonically_increasing_id() - /** * A column expression that generates monotonically increasing 64-bit integers. * @@ -2116,20 +2065,6 @@ object functions { */ def tanh(columnName: String): Column = tanh(Column(columnName)) - /** - * @group math_funcs - * @since 1.4.0 - */ - @deprecated("Use degrees", "2.1.0") - def toDegrees(e: Column): Column = degrees(e) - - /** - * @group math_funcs - * @since 1.4.0 - */ - @deprecated("Use degrees", "2.1.0") - def toDegrees(columnName: String): Column = degrees(Column(columnName)) - /** * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. * @@ -2152,20 +2087,6 @@ object functions { */ def degrees(columnName: String): Column = degrees(Column(columnName)) - /** - * @group math_funcs - * @since 1.4.0 - */ - @deprecated("Use radians", "2.1.0") - def toRadians(e: Column): Column = radians(e) - - /** - * @group math_funcs - * @since 1.4.0 - */ - @deprecated("Use radians", "2.1.0") - def toRadians(columnName: String): Column = radians(Column(columnName)) - /** * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 4afae56ecdb76..edde9bfd088cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -220,31 +220,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { dfAlias.col("t2.c") } - test("simple explode") { - val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDF("words") - - checkAnswer( - df.explode("words", "word") { word: String => word.split(" ").toSeq }.select('word), - Row("a") :: Row("b") :: Row("c") :: Row("d") ::Row("e") :: Nil - ) - } - - test("explode") { - val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters") - val df2 = - df.explode('letters) { - case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq - } - - checkAnswer( - df2 - .select('_1 as 'letter, 'number) - .groupBy('letter) - .agg(countDistinct('number)), - Row("a", 3) :: Row("b", 2) :: Row("c", 1) :: Nil - ) - } - test("Star Expansion - CreateStruct and CreateArray") { val structDf = testData2.select("a", "b").as("record") // CreateStruct and CreateArray in aggregateExpressions @@ -280,24 +255,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("Star Expansion - explode should fail with a meaningful message if it takes a star") { - val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv") + val df = Seq(("1,2"), ("4"), ("7,8,9")).toDF("csv") val e = intercept[AnalysisException] { - df.explode($"*") { case Row(prefix: String, csv: String) => - csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq - }.queryExecution.assertAnalyzed() - } - assert(e.getMessage.contains("Invalid usage of '*' in explode/json_tuple/UDTF")) - - checkAnswer( - df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) => - csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq - }, - Row("1", "1,2", "1:1") :: - Row("1", "1,2", "1:2") :: - Row("2", "4", "2:4") :: - Row("3", "7,8,9", "3:7") :: - Row("3", "7,8,9", "3:8") :: - Row("3", "7,8,9", "3:9") :: Nil) + df.select(explode($"*")) + } + assert(e.getMessage.contains("Invalid usage of '*' in expression 'explode'")) + } + + test("explode on output of array-valued function") { + val df = Seq(("1,2"), ("4"), ("7,8,9")).toDF("csv") + checkAnswer( + df.select(explode(split($"csv", ","))), + Row("1") :: Row("2") :: Row("4") :: Row("7") :: Row("8") :: Row("9") :: Nil) } test("Star Expansion - explode alias and star") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index a1799829932b8..aab2ae4afc7f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -24,32 +24,14 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} -@deprecated("This suite is deprecated to silent compiler deprecation warnings", "2.0.0") class SQLContextSuite extends SparkFunSuite with SharedSparkContext { object DummyRule extends Rule[LogicalPlan] { def apply(p: LogicalPlan): LogicalPlan = p } - test("getOrCreate instantiates SQLContext") { - val sqlContext = SQLContext.getOrCreate(sc) - assert(sqlContext != null, "SQLContext.getOrCreate returned null") - assert(SQLContext.getOrCreate(sc).eq(sqlContext), - "SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate") - } - - test("getOrCreate return the original SQLContext") { - val sqlContext = SQLContext.getOrCreate(sc) - val newSession = sqlContext.newSession() - assert(SQLContext.getOrCreate(sc).eq(sqlContext), - "SQLContext.getOrCreate after explicitly created SQLContext did not return the context") - SparkSession.setActiveSession(newSession.sparkSession) - assert(SQLContext.getOrCreate(sc).eq(newSession), - "SQLContext.getOrCreate after explicitly setActive() did not return the active context") - } - test("Sessions of SQLContext") { - val sqlContext = SQLContext.getOrCreate(sc) + val sqlContext = SparkSession.builder().sparkContext(sc).getOrCreate().sqlContext val session1 = sqlContext.newSession() val session2 = sqlContext.newSession() @@ -77,13 +59,13 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext { } test("Catalyst optimization passes are modifiable at runtime") { - val sqlContext = SQLContext.getOrCreate(sc) + val sqlContext = SparkSession.builder().sparkContext(sc).getOrCreate().sqlContext sqlContext.experimental.extraOptimizations = Seq(DummyRule) assert(sqlContext.sessionState.optimizer.batches.flatMap(_.rules).contains(DummyRule)) } test("get all tables") { - val sqlContext = SQLContext.getOrCreate(sc) + val sqlContext = SparkSession.builder().sparkContext(sc).getOrCreate().sqlContext val df = sqlContext.range(10) df.createOrReplaceTempView("listtablessuitetable") assert( @@ -100,7 +82,7 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext { } test("getting all tables with a database name has no impact on returned table names") { - val sqlContext = SQLContext.getOrCreate(sc) + val sqlContext = SparkSession.builder().sparkContext(sc).getOrCreate().sqlContext val df = sqlContext.range(10) df.createOrReplaceTempView("listtablessuitetable") assert( @@ -117,7 +99,7 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext { } test("query the returned DataFrame of tables") { - val sqlContext = SQLContext.getOrCreate(sc) + val sqlContext = SparkSession.builder().sparkContext(sc).getOrCreate().sqlContext val df = sqlContext.range(10) df.createOrReplaceTempView("listtablessuitetable") @@ -127,7 +109,7 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext { StructField("isTemporary", BooleanType, false) :: Nil) Seq(sqlContext.tables(), sqlContext.sql("SHOW TABLes")).foreach { - case tableDF => + tableDF => assert(expectedSchema === tableDF.schema) tableDF.createOrReplaceTempView("tables") From d68f3a726ffb4280d85268ef5a13b408b123ff48 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 8 Nov 2018 05:54:48 +0000 Subject: [PATCH 2016/2461] [SPARK-25676][FOLLOWUP][BUILD] Fix Scala 2.12 build error ## What changes were proposed in this pull request? This PR fixes the Scala-2.12 build. ## How was this patch tested? Manual build with Scala-2.12 profile. Closes #22970 from dongjoon-hyun/SPARK-25676-2.12. Authored-by: Dongjoon Hyun Signed-off-by: DB Tsai --- sql/core/benchmarks/WideTableBenchmark-results.txt | 14 +++++++------- .../execution/benchmark/WideTableBenchmark.scala | 3 ++- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/sql/core/benchmarks/WideTableBenchmark-results.txt b/sql/core/benchmarks/WideTableBenchmark-results.txt index 3b41a3e036c4d..7bc388aaa549f 100644 --- a/sql/core/benchmarks/WideTableBenchmark-results.txt +++ b/sql/core/benchmarks/WideTableBenchmark-results.txt @@ -6,12 +6,12 @@ OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz projection on wide table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -split threshold 10 38932 / 39307 0.0 37128.1 1.0X -split threshold 100 31991 / 32556 0.0 30508.8 1.2X -split threshold 1024 10993 / 11041 0.1 10483.5 3.5X -split threshold 2048 8959 / 8998 0.1 8543.8 4.3X -split threshold 4096 8116 / 8134 0.1 7739.8 4.8X -split threshold 8196 8069 / 8098 0.1 7695.5 4.8X -split threshold 65536 57068 / 57339 0.0 54424.3 0.7X +split threshold 10 39634 / 39829 0.0 37798.3 1.0X +split threshold 100 30121 / 30571 0.0 28725.8 1.3X +split threshold 1024 9678 / 9725 0.1 9229.9 4.1X +split threshold 2048 8634 / 8662 0.1 8233.6 4.6X +split threshold 4096 8561 / 8576 0.1 8164.6 4.6X +split threshold 8192 8393 / 8408 0.1 8003.8 4.7X +split threshold 65536 57063 / 57273 0.0 54419.1 0.7X diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala index ffefef1d4fce3..c61db3ce4b949 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.benchmark import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.Row import org.apache.spark.sql.internal.SQLConf /** @@ -42,7 +43,7 @@ object WideTableBenchmark extends SqlBasedBenchmark { Seq("10", "100", "1024", "2048", "4096", "8192", "65536").foreach { n => benchmark.addCase(s"split threshold $n", numIters = 5) { iter => withSQLConf(SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> n) { - df.selectExpr(columns: _*).foreach(identity(_)) + df.selectExpr(columns: _*).foreach((x => x): Row => Unit) } } } From 17449a2e6b28ecce7a273284eab037e8aceb3611 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 8 Nov 2018 14:48:23 +0800 Subject: [PATCH 2017/2461] [SPARK-25952][SQL] Passing actual schema to JacksonParser ## What changes were proposed in this pull request? The PR fixes an issue when the corrupt record column specified via `spark.sql.columnNameOfCorruptRecord` or JSON options `columnNameOfCorruptRecord` is propagated to JacksonParser, and returned row breaks an assumption in `FailureSafeParser` that the row must contain only actual data. The issue is fixed by passing actual schema without the corrupt record field into `JacksonParser`. ## How was this patch tested? Added a test with the corrupt record column in the middle of user's schema. Closes #22958 from MaxGekk/from_json-corrupt-record-schema. Authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- .../sql/catalyst/expressions/jsonExpressions.scala | 14 ++++++++------ .../org/apache/spark/sql/JsonFunctionsSuite.scala | 13 +++++++++++++ 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index eafcb6161036e..52d0677f4022f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -569,14 +569,16 @@ case class JsonToStructs( throw new IllegalArgumentException(s"from_json() doesn't support the ${mode.name} mode. " + s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}.") } - val rawParser = new JacksonParser(nullableSchema, parsedOptions, allowArrayAsStructs = false) - val createParser = CreateJacksonParser.utf8String _ - - val parserSchema = nullableSchema match { - case s: StructType => s - case other => StructType(StructField("value", other) :: Nil) + val (parserSchema, actualSchema) = nullableSchema match { + case s: StructType => + (s, StructType(s.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))) + case other => + (StructType(StructField("value", other) :: Nil), other) } + val rawParser = new JacksonParser(actualSchema, parsedOptions, allowArrayAsStructs = false) + val createParser = CreateJacksonParser.utf8String _ + new FailureSafeParser[UTF8String]( input => rawParser.parse(input, createParser, identity[UTF8String]), mode, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 2b09782faeeaa..d6b73387e84b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -578,4 +578,17 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { "Acceptable modes are PERMISSIVE and FAILFAST.")) } } + + test("corrupt record column in the middle") { + val schema = new StructType() + .add("a", IntegerType) + .add("_unparsed", StringType) + .add("b", IntegerType) + val badRec = """{"a" 1, "b": 11}""" + val df = Seq(badRec, """{"a": 2, "b": 12}""").toDS() + + checkAnswer( + df.select(from_json($"value", schema, Map("columnNameOfCorruptRecord" -> "_unparsed"))), + Row(Row(null, badRec, null)) :: Row(Row(2, null, 12)) :: Nil) + } } From ee03f760b305e70a57c3b4409ec25897af348600 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 8 Nov 2018 14:51:29 +0800 Subject: [PATCH 2018/2461] [SPARK-25955][TEST] Porting JSON tests for CSV functions ## What changes were proposed in this pull request? In the PR, I propose to port existing JSON tests from `JsonFunctionsSuite` that are applicable for CSV, and put them to `CsvFunctionsSuite`. In particular: - roundtrip `from_csv` to `to_csv`, and `to_csv` to `from_csv` - using `schema_of_csv` in `from_csv` - Java API `from_csv` - using `from_csv` and `to_csv` in exprs. Closes #22960 from MaxGekk/csv-additional-tests. Authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- .../apache/spark/sql/CsvFunctionsSuite.scala | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index 1dd8ec31ee111..b97ac380def63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -117,4 +117,51 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext { "Acceptable modes are PERMISSIVE and FAILFAST.")) } } + + test("from_csv uses DDL strings for defining a schema - java") { + val df = Seq("""1,"haa"""").toDS() + checkAnswer( + df.select( + from_csv($"value", lit("a INT, b STRING"), new java.util.HashMap[String, String]())), + Row(Row(1, "haa")) :: Nil) + } + + test("roundtrip to_csv -> from_csv") { + val df = Seq(Tuple1(Tuple1(1)), Tuple1(null)).toDF("struct") + val schema = df.schema(0).dataType.asInstanceOf[StructType] + val options = Map.empty[String, String] + val readback = df.select(to_csv($"struct").as("csv")) + .select(from_csv($"csv", schema, options).as("struct")) + + checkAnswer(df, readback) + } + + test("roundtrip from_csv -> to_csv") { + val df = Seq(Some("1"), None).toDF("csv") + val schema = new StructType().add("a", IntegerType) + val options = Map.empty[String, String] + val readback = df.select(from_csv($"csv", schema, options).as("struct")) + .select(to_csv($"struct").as("csv")) + + checkAnswer(df, readback) + } + + test("infers schemas of a CSV string and pass to to from_csv") { + val in = Seq("""0.123456789,987654321,"San Francisco"""").toDS() + val options = Map.empty[String, String].asJava + val out = in.select(from_csv('value, schema_of_csv("0.1,1,a"), options) as "parsed") + val expected = StructType(Seq(StructField( + "parsed", + StructType(Seq( + StructField("_c0", DoubleType, true), + StructField("_c1", IntegerType, true), + StructField("_c2", StringType, true)))))) + + assert(out.schema == expected) + } + + test("Support to_csv in SQL") { + val df1 = Seq(Tuple1(Tuple1(1))).toDF("a") + checkAnswer(df1.selectExpr("to_csv(a)"), Row("1") :: Nil) + } } From 0a2e45fdb8baadf7a57eb06f319e96f95eedf298 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 8 Nov 2018 16:32:25 +0800 Subject: [PATCH 2019/2461] Revert "[SPARK-23831][SQL] Add org.apache.derby to IsolatedClientLoader" This reverts commit a75571b46f813005a6d4b076ec39081ffab11844. --- .../apache/spark/sql/hive/client/IsolatedClientLoader.scala | 1 - .../apache/spark/sql/hive/HiveExternalCatalogSuite.scala | 6 ------ 2 files changed, 7 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 1e7a0b187c8b3..c1d8fe53a9e8c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -186,7 +186,6 @@ private[hive] class IsolatedClientLoader( name.startsWith("org.slf4j") || name.startsWith("org.apache.log4j") || // log4j1.x name.startsWith("org.apache.logging.log4j") || // log4j2 - name.startsWith("org.apache.derby.") || name.startsWith("org.apache.spark.") || (sharesHadoopClasses && isHadoopClass) || name.startsWith("scala.") || diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index 1de258f060943..0a522b6a11c80 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -113,10 +113,4 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { catalog.createDatabase(newDb("dbWithNullDesc").copy(description = null), ignoreIfExists = false) assert(catalog.getDatabase("dbWithNullDesc").description == "") } - - test("SPARK-23831: Add org.apache.derby to IsolatedClientLoader") { - val client1 = HiveUtils.newClientForMetadata(new SparkConf, new Configuration) - val client2 = HiveUtils.newClientForMetadata(new SparkConf, new Configuration) - assert(!client1.equals(client2)) - } } From a3004d084c654237c60d02df1507333b92b860c6 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 8 Nov 2018 03:40:28 -0800 Subject: [PATCH 2020/2461] [SPARK-25971][SQL] Ignore partition byte-size statistics in SQLQueryTestSuite ## What changes were proposed in this pull request? Currently, `SQLQueryTestSuite` is sensitive in terms of the bytes of parquet files in table partitions. If we change the default file format (from Parquet to ORC) or update the metadata of them, the test case should be changed accordingly. This PR aims to make `SQLQueryTestSuite` more robust by ignoring the partition byte statistics. ``` -Partition Statistics 1144 bytes, 2 rows +Partition Statistics [not included in comparison] bytes, 2 rows ``` ## How was this patch tested? Pass the Jenkins with the newly updated test cases. Closes #22972 from dongjoon-hyun/SPARK-25971. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../results/describe-part-after-analyze.sql.out | 12 ++++++------ .../org/apache/spark/sql/SQLQueryTestSuite.scala | 1 + 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out index 8ba69c698b551..17dd317f63b70 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out @@ -93,7 +93,7 @@ Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 Created Time [not included in comparison] Last Access [not included in comparison] -Partition Statistics 1121 bytes, 3 rows +Partition Statistics [not included in comparison] bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -128,7 +128,7 @@ Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 Created Time [not included in comparison] Last Access [not included in comparison] -Partition Statistics 1121 bytes, 3 rows +Partition Statistics [not included in comparison] bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -155,7 +155,7 @@ Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 Created Time [not included in comparison] Last Access [not included in comparison] -Partition Statistics 1098 bytes, 4 rows +Partition Statistics [not included in comparison] bytes, 4 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -190,7 +190,7 @@ Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 Created Time [not included in comparison] Last Access [not included in comparison] -Partition Statistics 1121 bytes, 3 rows +Partition Statistics [not included in comparison] bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -217,7 +217,7 @@ Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 Created Time [not included in comparison] Last Access [not included in comparison] -Partition Statistics 1098 bytes, 4 rows +Partition Statistics [not included in comparison] bytes, 4 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -244,7 +244,7 @@ Partition Values [ds=2017-09-01, hr=5] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-09-01/hr=5 Created Time [not included in comparison] Last Access [not included in comparison] -Partition Statistics 1144 bytes, 2 rows +Partition Statistics [not included in comparison] bytes, 2 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 826408c7161e9..6ca3ac596e5f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -272,6 +272,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { .replaceAll("Created By.*", s"Created By $notIncludedMsg") .replaceAll("Created Time.*", s"Created Time $notIncludedMsg") .replaceAll("Last Access.*", s"Last Access $notIncludedMsg") + .replaceAll("Partition Statistics\t\\d+", s"Partition Statistics\t$notIncludedMsg") .replaceAll("\\*\\(\\d+\\) ", "*")) // remove the WholeStageCodegen codegenStageIds // If the output is not pre-sorted, sort it. From 0d7396f3af2d4348ae53e6a274df952b7f17c37c Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 8 Nov 2018 03:51:55 -0800 Subject: [PATCH 2021/2461] [SPARK-22827][SQL][FOLLOW-UP] Throw `SparkOutOfMemoryError` in `HashAggregateExec`, too. ## What changes were proposed in this pull request? This is a follow-up pr of #20014 which introduced `SparkOutOfMemoryError` to avoid killing the entire executor when an `OutOfMemoryError` is thrown. We should throw `SparkOutOfMemoryError` in `HashAggregateExec`, too. ## How was this patch tested? Existing tests. Closes #22969 from ueshin/issues/SPARK-22827/oome. Authored-by: Takuya UESHIN Signed-off-by: Dongjoon Hyun --- .../spark/sql/execution/aggregate/HashAggregateExec.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 25d8e7dff3d99..08dcdf33fb8f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.TaskContext -import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ @@ -762,6 +762,8 @@ case class HashAggregateExec( ("true", "true", "", "") } + val oomeClassName = classOf[SparkOutOfMemoryError].getName + val findOrInsertRegularHashMap: String = s""" |// generate grouping key @@ -787,7 +789,7 @@ case class HashAggregateExec( | $unsafeRowKeys, ${hashEval.value}); | if ($unsafeRowBuffer == null) { | // failed to allocate the first page - | throw new OutOfMemoryError("No enough memory for aggregation"); + | throw new $oomeClassName("No enough memory for aggregation"); | } |} """.stripMargin From 6abe90625efeb8140531a875700e87ed7e981044 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 8 Nov 2018 23:37:14 +0800 Subject: [PATCH 2022/2461] [SPARK-25676][SQL][FOLLOWUP] Use 'foreach(_ => ())' ## What changes were proposed in this pull request? #22970 fixed Scala 2.12 build error, and this PR updates the function according to the review comments. ## How was this patch tested? This is also manually tested with Scala 2.12 build. Closes #22978 from dongjoon-hyun/SPARK-25676-3. Authored-by: Dongjoon Hyun Signed-off-by: Wenchen Fan --- sql/core/benchmarks/WideTableBenchmark-results.txt | 14 +++++++------- .../execution/benchmark/WideTableBenchmark.scala | 3 +-- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/sql/core/benchmarks/WideTableBenchmark-results.txt b/sql/core/benchmarks/WideTableBenchmark-results.txt index 7bc388aaa549f..8c09f9ca11307 100644 --- a/sql/core/benchmarks/WideTableBenchmark-results.txt +++ b/sql/core/benchmarks/WideTableBenchmark-results.txt @@ -6,12 +6,12 @@ OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz projection on wide table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -split threshold 10 39634 / 39829 0.0 37798.3 1.0X -split threshold 100 30121 / 30571 0.0 28725.8 1.3X -split threshold 1024 9678 / 9725 0.1 9229.9 4.1X -split threshold 2048 8634 / 8662 0.1 8233.6 4.6X -split threshold 4096 8561 / 8576 0.1 8164.6 4.6X -split threshold 8192 8393 / 8408 0.1 8003.8 4.7X -split threshold 65536 57063 / 57273 0.0 54419.1 0.7X +split threshold 10 40571 / 40937 0.0 38691.7 1.0X +split threshold 100 31116 / 31669 0.0 29674.6 1.3X +split threshold 1024 10077 / 10199 0.1 9609.7 4.0X +split threshold 2048 8654 / 8692 0.1 8253.2 4.7X +split threshold 4096 8006 / 8038 0.1 7634.7 5.1X +split threshold 8192 8069 / 8107 0.1 7695.3 5.0X +split threshold 65536 56973 / 57204 0.0 54333.7 0.7X diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala index c61db3ce4b949..52426d81bd1a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.benchmark import org.apache.spark.benchmark.Benchmark -import org.apache.spark.sql.Row import org.apache.spark.sql.internal.SQLConf /** @@ -43,7 +42,7 @@ object WideTableBenchmark extends SqlBasedBenchmark { Seq("10", "100", "1024", "2048", "4096", "8192", "65536").foreach { n => benchmark.addCase(s"split threshold $n", numIters = 5) { iter => withSQLConf(SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> n) { - df.selectExpr(columns: _*).foreach((x => x): Row => Unit) + df.selectExpr(columns: _*).foreach(_ => ()) } } } From 7bb901aa28d3000c2e18cc769fe5769abd650770 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 8 Nov 2018 10:08:14 -0800 Subject: [PATCH 2023/2461] [SPARK-25964][SQL][MINOR] Revise OrcReadBenchmark/DataSourceReadBenchmark case names and execution instructions ## What changes were proposed in this pull request? 1. OrcReadBenchmark is under hive module, so the way to run it should be ``` build/sbt "hive/test:runMain " ``` 2. The benchmark "String with Nulls Scan" should be with case "String with Nulls Scan(5%/50%/95%)", not "(0.05%/0.5%/0.95%)" 3. Add the null value percentages in the test case names of DataSourceReadBenchmark, for the benchmark "String with Nulls Scan" . ## How was this patch tested? Re-run benchmarks Closes #22965 from gengliangwang/fixHiveOrcReadBenchmark. Lead-authored-by: Gengliang Wang Co-authored-by: Gengliang Wang Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../DataSourceReadBenchmark-results.txt | 336 +++++++++--------- .../benchmark/DataSourceReadBenchmark.scala | 4 +- .../benchmarks/OrcReadBenchmark-results.txt | 170 ++++----- .../spark/sql/hive/orc/OrcReadBenchmark.scala | 11 +- 4 files changed, 263 insertions(+), 258 deletions(-) diff --git a/sql/core/benchmarks/DataSourceReadBenchmark-results.txt b/sql/core/benchmarks/DataSourceReadBenchmark-results.txt index 2d3bae442cc50..b07e8b1197ff0 100644 --- a/sql/core/benchmarks/DataSourceReadBenchmark-results.txt +++ b/sql/core/benchmarks/DataSourceReadBenchmark-results.txt @@ -2,268 +2,268 @@ SQL Single Numeric Column Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 21508 / 22112 0.7 1367.5 1.0X -SQL Json 8705 / 8825 1.8 553.4 2.5X -SQL Parquet Vectorized 157 / 186 100.0 10.0 136.7X -SQL Parquet MR 1789 / 1794 8.8 113.8 12.0X -SQL ORC Vectorized 156 / 166 100.9 9.9 138.0X -SQL ORC Vectorized with copy 218 / 225 72.1 13.9 98.6X -SQL ORC MR 1448 / 1492 10.9 92.0 14.9X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 26366 / 26562 0.6 1676.3 1.0X +SQL Json 8709 / 8724 1.8 553.7 3.0X +SQL Parquet Vectorized 166 / 187 94.8 10.5 159.0X +SQL Parquet MR 1706 / 1720 9.2 108.4 15.5X +SQL ORC Vectorized 167 / 174 94.2 10.6 157.9X +SQL ORC Vectorized with copy 226 / 231 69.6 14.4 116.7X +SQL ORC MR 1433 / 1465 11.0 91.1 18.4X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Parquet Reader Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 202 / 211 77.7 12.9 1.0X -ParquetReader Vectorized -> Row 118 / 120 133.5 7.5 1.7X +ParquetReader Vectorized 200 / 207 78.7 12.7 1.0X +ParquetReader Vectorized -> Row 117 / 119 134.7 7.4 1.7X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 23282 / 23312 0.7 1480.2 1.0X -SQL Json 9187 / 9189 1.7 584.1 2.5X -SQL Parquet Vectorized 204 / 218 77.0 13.0 114.0X -SQL Parquet MR 1941 / 1953 8.1 123.4 12.0X -SQL ORC Vectorized 217 / 225 72.6 13.8 107.5X -SQL ORC Vectorized with copy 279 / 289 56.3 17.8 83.4X -SQL ORC MR 1541 / 1549 10.2 98.0 15.1X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 26489 / 26547 0.6 1684.1 1.0X +SQL Json 8990 / 8998 1.7 571.5 2.9X +SQL Parquet Vectorized 209 / 221 75.1 13.3 126.5X +SQL Parquet MR 1949 / 1949 8.1 123.9 13.6X +SQL ORC Vectorized 221 / 228 71.3 14.0 120.1X +SQL ORC Vectorized with copy 315 / 319 49.9 20.1 84.0X +SQL ORC MR 1527 / 1549 10.3 97.1 17.3X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Parquet Reader Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 288 / 297 54.6 18.3 1.0X -ParquetReader Vectorized -> Row 255 / 257 61.7 16.2 1.1X +ParquetReader Vectorized 286 / 296 54.9 18.2 1.0X +ParquetReader Vectorized -> Row 249 / 253 63.1 15.8 1.1X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 24990 / 25012 0.6 1588.8 1.0X -SQL Json 9837 / 9865 1.6 625.4 2.5X -SQL Parquet Vectorized 170 / 180 92.3 10.8 146.6X -SQL Parquet MR 2319 / 2328 6.8 147.4 10.8X -SQL ORC Vectorized 293 / 301 53.7 18.6 85.3X -SQL ORC Vectorized with copy 297 / 309 52.9 18.9 84.0X -SQL ORC MR 1667 / 1674 9.4 106.0 15.0X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 27701 / 27744 0.6 1761.2 1.0X +SQL Json 9703 / 9733 1.6 616.9 2.9X +SQL Parquet Vectorized 176 / 182 89.2 11.2 157.0X +SQL Parquet MR 2164 / 2173 7.3 137.6 12.8X +SQL ORC Vectorized 307 / 314 51.2 19.5 90.2X +SQL ORC Vectorized with copy 312 / 319 50.4 19.8 88.7X +SQL ORC MR 1690 / 1700 9.3 107.4 16.4X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Parquet Reader Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 257 / 274 61.3 16.3 1.0X -ParquetReader Vectorized -> Row 259 / 264 60.8 16.4 1.0X +ParquetReader Vectorized 259 / 277 60.7 16.5 1.0X +ParquetReader Vectorized -> Row 261 / 265 60.3 16.6 1.0X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 32537 / 32554 0.5 2068.7 1.0X -SQL Json 12610 / 12668 1.2 801.7 2.6X -SQL Parquet Vectorized 258 / 276 61.0 16.4 126.2X -SQL Parquet MR 2422 / 2435 6.5 154.0 13.4X -SQL ORC Vectorized 378 / 385 41.6 24.0 86.2X -SQL ORC Vectorized with copy 381 / 389 41.3 24.2 85.4X -SQL ORC MR 1797 / 1819 8.8 114.3 18.1X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 34813 / 34900 0.5 2213.3 1.0X +SQL Json 12570 / 12617 1.3 799.2 2.8X +SQL Parquet Vectorized 270 / 308 58.2 17.2 128.9X +SQL Parquet MR 2427 / 2431 6.5 154.3 14.3X +SQL ORC Vectorized 388 / 398 40.6 24.6 89.8X +SQL ORC Vectorized with copy 395 / 402 39.9 25.1 88.2X +SQL ORC MR 1819 / 1851 8.6 115.7 19.1X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Parquet Reader Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 352 / 368 44.7 22.4 1.0X -ParquetReader Vectorized -> Row 351 / 359 44.8 22.3 1.0X +ParquetReader Vectorized 372 / 379 42.3 23.7 1.0X +ParquetReader Vectorized -> Row 357 / 368 44.1 22.7 1.0X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 27179 / 27184 0.6 1728.0 1.0X -SQL Json 12578 / 12585 1.3 799.7 2.2X -SQL Parquet Vectorized 161 / 171 97.5 10.3 168.5X -SQL Parquet MR 2361 / 2395 6.7 150.1 11.5X -SQL ORC Vectorized 473 / 480 33.3 30.0 57.5X -SQL ORC Vectorized with copy 478 / 483 32.9 30.4 56.8X -SQL ORC MR 1858 / 1859 8.5 118.2 14.6X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 28753 / 28781 0.5 1828.0 1.0X +SQL Json 12039 / 12215 1.3 765.4 2.4X +SQL Parquet Vectorized 170 / 177 92.4 10.8 169.0X +SQL Parquet MR 2184 / 2196 7.2 138.9 13.2X +SQL ORC Vectorized 432 / 440 36.4 27.5 66.5X +SQL ORC Vectorized with copy 439 / 442 35.9 27.9 65.6X +SQL ORC MR 1812 / 1833 8.7 115.2 15.9X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Parquet Reader Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 251 / 255 62.7 15.9 1.0X -ParquetReader Vectorized -> Row 255 / 259 61.8 16.2 1.0X +ParquetReader Vectorized 253 / 260 62.2 16.1 1.0X +ParquetReader Vectorized -> Row 256 / 257 61.6 16.2 1.0X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 34797 / 34830 0.5 2212.3 1.0X -SQL Json 17806 / 17828 0.9 1132.1 2.0X -SQL Parquet Vectorized 260 / 269 60.6 16.5 134.0X -SQL Parquet MR 2512 / 2534 6.3 159.7 13.9X -SQL ORC Vectorized 582 / 593 27.0 37.0 59.8X -SQL ORC Vectorized with copy 576 / 584 27.3 36.6 60.4X -SQL ORC MR 2309 / 2313 6.8 146.8 15.1X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 36177 / 36188 0.4 2300.1 1.0X +SQL Json 18895 / 18898 0.8 1201.3 1.9X +SQL Parquet Vectorized 267 / 276 58.9 17.0 135.6X +SQL Parquet MR 2355 / 2363 6.7 149.7 15.4X +SQL ORC Vectorized 543 / 546 29.0 34.5 66.6X +SQL ORC Vectorized with copy 548 / 557 28.7 34.8 66.0X +SQL ORC MR 2246 / 2258 7.0 142.8 16.1X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Parquet Reader Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 350 / 363 44.9 22.3 1.0X -ParquetReader Vectorized -> Row 350 / 366 44.9 22.3 1.0X +ParquetReader Vectorized 353 / 367 44.6 22.4 1.0X +ParquetReader Vectorized -> Row 351 / 357 44.7 22.3 1.0X ================================================================================================ Int and String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 22486 / 22590 0.5 2144.5 1.0X -SQL Json 14124 / 14195 0.7 1347.0 1.6X -SQL Parquet Vectorized 2342 / 2347 4.5 223.4 9.6X -SQL Parquet MR 4660 / 4664 2.2 444.4 4.8X -SQL ORC Vectorized 2378 / 2379 4.4 226.8 9.5X -SQL ORC Vectorized with copy 2548 / 2571 4.1 243.0 8.8X -SQL ORC MR 4206 / 4211 2.5 401.1 5.3X +SQL CSV 21130 / 21246 0.5 2015.1 1.0X +SQL Json 12145 / 12174 0.9 1158.2 1.7X +SQL Parquet Vectorized 2363 / 2377 4.4 225.3 8.9X +SQL Parquet MR 4555 / 4557 2.3 434.4 4.6X +SQL ORC Vectorized 2361 / 2388 4.4 225.1 9.0X +SQL ORC Vectorized with copy 2540 / 2557 4.1 242.2 8.3X +SQL ORC MR 4186 / 4209 2.5 399.2 5.0X ================================================================================================ Repeated String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 12150 / 12178 0.9 1158.7 1.0X -SQL Json 7012 / 7014 1.5 668.7 1.7X -SQL Parquet Vectorized 792 / 796 13.2 75.5 15.3X -SQL Parquet MR 1961 / 1975 5.3 187.0 6.2X -SQL ORC Vectorized 482 / 485 21.8 46.0 25.2X -SQL ORC Vectorized with copy 710 / 715 14.8 67.7 17.1X -SQL ORC MR 2081 / 2083 5.0 198.5 5.8X +SQL CSV 11693 / 11729 0.9 1115.1 1.0X +SQL Json 7025 / 7025 1.5 669.9 1.7X +SQL Parquet Vectorized 803 / 821 13.1 76.6 14.6X +SQL Parquet MR 1776 / 1790 5.9 169.4 6.6X +SQL ORC Vectorized 491 / 494 21.4 46.8 23.8X +SQL ORC Vectorized with copy 723 / 725 14.5 68.9 16.2X +SQL ORC MR 2050 / 2063 5.1 195.5 5.7X ================================================================================================ Partitioned Table Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Data column - CSV 31789 / 31791 0.5 2021.1 1.0X -Data column - Json 12873 / 12918 1.2 818.4 2.5X -Data column - Parquet Vectorized 267 / 280 58.9 17.0 119.1X -Data column - Parquet MR 3387 / 3402 4.6 215.3 9.4X -Data column - ORC Vectorized 391 / 453 40.2 24.9 81.2X -Data column - ORC Vectorized with copy 392 / 398 40.2 24.9 81.2X -Data column - ORC MR 2508 / 2512 6.3 159.4 12.7X -Partition column - CSV 6965 / 6977 2.3 442.8 4.6X -Partition column - Json 5563 / 5576 2.8 353.7 5.7X -Partition column - Parquet Vectorized 65 / 78 241.1 4.1 487.2X -Partition column - Parquet MR 1811 / 1811 8.7 115.1 17.6X -Partition column - ORC Vectorized 66 / 73 239.0 4.2 483.0X -Partition column - ORC Vectorized with copy 65 / 70 241.1 4.1 487.3X -Partition column - ORC MR 1775 / 1778 8.9 112.8 17.9X -Both columns - CSV 30032 / 30113 0.5 1909.4 1.1X -Both columns - Json 13941 / 13959 1.1 886.3 2.3X -Both columns - Parquet Vectorized 312 / 330 50.3 19.9 101.7X -Both columns - Parquet MR 3858 / 3862 4.1 245.3 8.2X -Both columns - ORC Vectorized 431 / 437 36.5 27.4 73.8X -Both column - ORC Vectorized with copy 523 / 529 30.1 33.3 60.7X -Both columns - ORC MR 2712 / 2805 5.8 172.4 11.7X +Data column - CSV 30965 / 31041 0.5 1968.7 1.0X +Data column - Json 12876 / 12882 1.2 818.6 2.4X +Data column - Parquet Vectorized 277 / 282 56.7 17.6 111.6X +Data column - Parquet MR 3398 / 3402 4.6 216.0 9.1X +Data column - ORC Vectorized 399 / 407 39.4 25.4 77.5X +Data column - ORC Vectorized with copy 407 / 447 38.6 25.9 76.0X +Data column - ORC MR 2583 / 2589 6.1 164.2 12.0X +Partition column - CSV 7403 / 7427 2.1 470.7 4.2X +Partition column - Json 5587 / 5625 2.8 355.2 5.5X +Partition column - Parquet Vectorized 71 / 78 222.6 4.5 438.3X +Partition column - Parquet MR 1798 / 1808 8.7 114.3 17.2X +Partition column - ORC Vectorized 72 / 75 219.0 4.6 431.2X +Partition column - ORC Vectorized with copy 71 / 77 221.1 4.5 435.4X +Partition column - ORC MR 1772 / 1778 8.9 112.6 17.5X +Both columns - CSV 30211 / 30212 0.5 1920.7 1.0X +Both columns - Json 13382 / 13391 1.2 850.8 2.3X +Both columns - Parquet Vectorized 321 / 333 49.0 20.4 96.4X +Both columns - Parquet MR 3656 / 3661 4.3 232.4 8.5X +Both columns - ORC Vectorized 443 / 448 35.5 28.2 69.9X +Both column - ORC Vectorized with copy 527 / 533 29.9 33.5 58.8X +Both columns - ORC MR 2626 / 2633 6.0 167.0 11.8X ================================================================================================ String with Nulls Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 13525 / 13823 0.8 1289.9 1.0X -SQL Json 9913 / 9921 1.1 945.3 1.4X -SQL Parquet Vectorized 1517 / 1517 6.9 144.7 8.9X -SQL Parquet MR 3996 / 4008 2.6 381.1 3.4X -ParquetReader Vectorized 1120 / 1128 9.4 106.8 12.1X -SQL ORC Vectorized 1203 / 1224 8.7 114.7 11.2X -SQL ORC Vectorized with copy 1639 / 1646 6.4 156.3 8.3X -SQL ORC MR 3720 / 3780 2.8 354.7 3.6X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 13918 / 13979 0.8 1327.3 1.0X +SQL Json 10068 / 10068 1.0 960.1 1.4X +SQL Parquet Vectorized 1563 / 1564 6.7 149.0 8.9X +SQL Parquet MR 3835 / 3836 2.7 365.8 3.6X +ParquetReader Vectorized 1115 / 1118 9.4 106.4 12.5X +SQL ORC Vectorized 1172 / 1208 8.9 111.8 11.9X +SQL ORC Vectorized with copy 1630 / 1644 6.4 155.5 8.5X +SQL ORC MR 3708 / 3711 2.8 353.6 3.8X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +String with Nulls Scan (50.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 15860 / 15877 0.7 1512.5 1.0X -SQL Json 7676 / 7688 1.4 732.0 2.1X -SQL Parquet Vectorized 1072 / 1084 9.8 102.2 14.8X -SQL Parquet MR 2890 / 2897 3.6 275.6 5.5X -ParquetReader Vectorized 1052 / 1053 10.0 100.4 15.1X -SQL ORC Vectorized 1248 / 1248 8.4 119.0 12.7X -SQL ORC Vectorized with copy 1627 / 1637 6.4 155.2 9.7X -SQL ORC MR 3365 / 3369 3.1 320.9 4.7X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 13972 / 14043 0.8 1332.5 1.0X +SQL Json 7436 / 7469 1.4 709.1 1.9X +SQL Parquet Vectorized 1103 / 1112 9.5 105.2 12.7X +SQL Parquet MR 2841 / 2847 3.7 271.0 4.9X +ParquetReader Vectorized 992 / 1012 10.6 94.6 14.1X +SQL ORC Vectorized 1275 / 1349 8.2 121.6 11.0X +SQL ORC Vectorized with copy 1631 / 1644 6.4 155.5 8.6X +SQL ORC MR 3244 / 3259 3.2 309.3 4.3X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +String with Nulls Scan (95.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 13401 / 13561 0.8 1278.1 1.0X -SQL Json 5253 / 5303 2.0 500.9 2.6X -SQL Parquet Vectorized 233 / 242 45.0 22.2 57.6X -SQL Parquet MR 1791 / 1796 5.9 170.8 7.5X -ParquetReader Vectorized 236 / 238 44.4 22.5 56.7X -SQL ORC Vectorized 453 / 473 23.2 43.2 29.6X -SQL ORC Vectorized with copy 573 / 577 18.3 54.7 23.4X -SQL ORC MR 1846 / 1850 5.7 176.0 7.3X +SQL CSV 11228 / 11244 0.9 1070.8 1.0X +SQL Json 5200 / 5247 2.0 495.9 2.2X +SQL Parquet Vectorized 238 / 242 44.1 22.7 47.2X +SQL Parquet MR 1730 / 1734 6.1 165.0 6.5X +ParquetReader Vectorized 237 / 238 44.3 22.6 47.4X +SQL ORC Vectorized 459 / 462 22.8 43.8 24.4X +SQL ORC Vectorized with copy 581 / 583 18.1 55.4 19.3X +SQL ORC MR 1767 / 1783 5.9 168.5 6.4X ================================================================================================ Single Column Scan From Wide Columns ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single Column Scan from 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 3147 / 3148 0.3 3001.1 1.0X -SQL Json 2666 / 2693 0.4 2542.9 1.2X -SQL Parquet Vectorized 54 / 58 19.5 51.3 58.5X -SQL Parquet MR 220 / 353 4.8 209.9 14.3X -SQL ORC Vectorized 63 / 77 16.8 59.7 50.3X -SQL ORC Vectorized with copy 63 / 66 16.7 59.8 50.2X -SQL ORC MR 317 / 321 3.3 302.2 9.9X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 3322 / 3356 0.3 3167.9 1.0X +SQL Json 2808 / 2843 0.4 2678.2 1.2X +SQL Parquet Vectorized 56 / 63 18.9 52.9 59.8X +SQL Parquet MR 215 / 219 4.9 205.4 15.4X +SQL ORC Vectorized 64 / 76 16.4 60.9 52.0X +SQL ORC Vectorized with copy 64 / 67 16.3 61.3 51.7X +SQL ORC MR 314 / 316 3.3 299.6 10.6X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single Column Scan from 50 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 7902 / 7921 0.1 7536.2 1.0X -SQL Json 9467 / 9491 0.1 9028.6 0.8X -SQL Parquet Vectorized 73 / 79 14.3 69.8 108.0X -SQL Parquet MR 239 / 247 4.4 228.0 33.1X -SQL ORC Vectorized 78 / 84 13.4 74.6 101.0X -SQL ORC Vectorized with copy 78 / 88 13.4 74.4 101.3X -SQL ORC MR 910 / 918 1.2 867.6 8.7X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 7978 / 7989 0.1 7608.5 1.0X +SQL Json 10294 / 10325 0.1 9816.9 0.8X +SQL Parquet Vectorized 72 / 85 14.5 69.0 110.3X +SQL Parquet MR 237 / 241 4.4 226.4 33.6X +SQL ORC Vectorized 82 / 92 12.7 78.5 97.0X +SQL ORC Vectorized with copy 82 / 88 12.7 78.5 97.0X +SQL ORC MR 900 / 909 1.2 858.5 8.9X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 13539 / 13543 0.1 12912.0 1.0X -SQL Json 17420 / 17446 0.1 16613.1 0.8X -SQL Parquet Vectorized 103 / 120 10.2 98.1 131.6X -SQL Parquet MR 250 / 258 4.2 238.9 54.1X -SQL ORC Vectorized 99 / 104 10.6 94.6 136.5X -SQL ORC Vectorized with copy 100 / 106 10.5 95.6 135.1X -SQL ORC MR 1653 / 1659 0.6 1576.3 8.2X +SQL CSV 13489 / 13508 0.1 12864.3 1.0X +SQL Json 18813 / 18827 0.1 17941.4 0.7X +SQL Parquet Vectorized 107 / 111 9.8 101.8 126.3X +SQL Parquet MR 275 / 286 3.8 262.3 49.0X +SQL ORC Vectorized 107 / 115 9.8 101.7 126.4X +SQL ORC Vectorized with copy 107 / 115 9.8 102.3 125.8X +SQL ORC MR 1659 / 1664 0.6 1582.3 8.1X diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala index a1f51f8e54805..ecd9ead0ae39a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala @@ -447,7 +447,9 @@ object DataSourceReadBenchmark extends BenchmarkBase with SQLHelper { } def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { - val benchmark = new Benchmark("String with Nulls Scan", values, output = output) + val percentageOfNulls = fractionOfNulls * 100 + val benchmark = + new Benchmark(s"String with Nulls Scan ($percentageOfNulls%)", values, output = output) withTempPath { dir => withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { diff --git a/sql/hive/benchmarks/OrcReadBenchmark-results.txt b/sql/hive/benchmarks/OrcReadBenchmark-results.txt index c77f966723d71..80c2f5e93405a 100644 --- a/sql/hive/benchmarks/OrcReadBenchmark-results.txt +++ b/sql/hive/benchmarks/OrcReadBenchmark-results.txt @@ -2,172 +2,172 @@ SQL Single Numeric Column Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1630 / 1639 9.7 103.6 1.0X -Native ORC Vectorized 253 / 288 62.2 16.1 6.4X -Native ORC Vectorized with copy 227 / 244 69.2 14.5 7.2X -Hive built-in ORC 1980 / 1991 7.9 125.9 0.8X +Native ORC MR 1725 / 1759 9.1 109.7 1.0X +Native ORC Vectorized 272 / 316 57.8 17.3 6.3X +Native ORC Vectorized with copy 239 / 254 65.7 15.2 7.2X +Hive built-in ORC 1970 / 1987 8.0 125.3 0.9X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1587 / 1589 9.9 100.9 1.0X -Native ORC Vectorized 227 / 242 69.2 14.5 7.0X -Native ORC Vectorized with copy 228 / 238 69.0 14.5 7.0X -Hive built-in ORC 2323 / 2332 6.8 147.7 0.7X +Native ORC MR 1633 / 1672 9.6 103.8 1.0X +Native ORC Vectorized 238 / 255 66.0 15.1 6.9X +Native ORC Vectorized with copy 235 / 253 66.8 15.0 6.9X +Hive built-in ORC 2293 / 2305 6.9 145.8 0.7X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1726 / 1771 9.1 109.7 1.0X -Native ORC Vectorized 309 / 333 50.9 19.7 5.6X -Native ORC Vectorized with copy 313 / 321 50.2 19.9 5.5X -Hive built-in ORC 2668 / 2672 5.9 169.6 0.6X +Native ORC MR 1677 / 1699 9.4 106.6 1.0X +Native ORC Vectorized 325 / 342 48.3 20.7 5.2X +Native ORC Vectorized with copy 328 / 341 47.9 20.9 5.1X +Hive built-in ORC 2561 / 2569 6.1 162.8 0.7X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1722 / 1747 9.1 109.5 1.0X -Native ORC Vectorized 395 / 403 39.8 25.1 4.4X -Native ORC Vectorized with copy 399 / 405 39.4 25.4 4.3X -Hive built-in ORC 2767 / 2777 5.7 175.9 0.6X +Native ORC MR 1791 / 1795 8.8 113.9 1.0X +Native ORC Vectorized 400 / 408 39.3 25.4 4.5X +Native ORC Vectorized with copy 410 / 417 38.4 26.1 4.4X +Hive built-in ORC 2713 / 2720 5.8 172.5 0.7X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1797 / 1824 8.8 114.2 1.0X -Native ORC Vectorized 434 / 441 36.2 27.6 4.1X -Native ORC Vectorized with copy 437 / 447 36.0 27.8 4.1X -Hive built-in ORC 2701 / 2710 5.8 171.7 0.7X +Native ORC MR 1791 / 1805 8.8 113.8 1.0X +Native ORC Vectorized 433 / 438 36.3 27.5 4.1X +Native ORC Vectorized with copy 441 / 447 35.7 28.0 4.1X +Hive built-in ORC 2690 / 2803 5.8 171.0 0.7X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1931 / 2028 8.1 122.8 1.0X -Native ORC Vectorized 542 / 557 29.0 34.5 3.6X -Native ORC Vectorized with copy 550 / 564 28.6 35.0 3.5X -Hive built-in ORC 2816 / 3206 5.6 179.1 0.7X +Native ORC MR 1911 / 1930 8.2 121.5 1.0X +Native ORC Vectorized 543 / 552 29.0 34.5 3.5X +Native ORC Vectorized with copy 547 / 555 28.8 34.8 3.5X +Hive built-in ORC 2967 / 3065 5.3 188.6 0.6X ================================================================================================ Int and String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 4012 / 4068 2.6 382.6 1.0X -Native ORC Vectorized 2337 / 2339 4.5 222.9 1.7X -Native ORC Vectorized with copy 2520 / 2540 4.2 240.3 1.6X -Hive built-in ORC 5503 / 5575 1.9 524.8 0.7X +Native ORC MR 4160 / 4188 2.5 396.7 1.0X +Native ORC Vectorized 2405 / 2406 4.4 229.4 1.7X +Native ORC Vectorized with copy 2588 / 2592 4.1 246.8 1.6X +Hive built-in ORC 5514 / 5562 1.9 525.9 0.8X ================================================================================================ Partitioned Table Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Data column - Native ORC MR 2020 / 2025 7.8 128.4 1.0X -Data column - Native ORC Vectorized 398 / 409 39.5 25.3 5.1X -Data column - Native ORC Vectorized with copy 406 / 411 38.8 25.8 5.0X -Data column - Hive built-in ORC 2967 / 2969 5.3 188.6 0.7X -Partition column - Native ORC MR 1494 / 1505 10.5 95.0 1.4X -Partition column - Native ORC Vectorized 73 / 82 216.3 4.6 27.8X -Partition column - Native ORC Vectorized with copy 71 / 80 221.4 4.5 28.4X -Partition column - Hive built-in ORC 1932 / 1937 8.1 122.8 1.0X -Both columns - Native ORC MR 2057 / 2071 7.6 130.8 1.0X -Both columns - Native ORC Vectorized 445 / 448 35.4 28.3 4.5X -Both column - Native ORC Vectorized with copy 534 / 539 29.4 34.0 3.8X -Both columns - Hive built-in ORC 2994 / 2994 5.3 190.3 0.7X +Data column - Native ORC MR 1863 / 1867 8.4 118.4 1.0X +Data column - Native ORC Vectorized 411 / 418 38.2 26.2 4.5X +Data column - Native ORC Vectorized with copy 417 / 422 37.8 26.5 4.5X +Data column - Hive built-in ORC 3297 / 3308 4.8 209.6 0.6X +Partition column - Native ORC MR 1505 / 1506 10.4 95.7 1.2X +Partition column - Native ORC Vectorized 80 / 93 195.6 5.1 23.2X +Partition column - Native ORC Vectorized with copy 78 / 86 201.4 5.0 23.9X +Partition column - Hive built-in ORC 1960 / 1979 8.0 124.6 1.0X +Both columns - Native ORC MR 2076 / 2090 7.6 132.0 0.9X +Both columns - Native ORC Vectorized 450 / 463 34.9 28.6 4.1X +Both column - Native ORC Vectorized with copy 532 / 538 29.6 33.8 3.5X +Both columns - Hive built-in ORC 3528 / 3548 4.5 224.3 0.5X ================================================================================================ Repeated String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1771 / 1785 5.9 168.9 1.0X -Native ORC Vectorized 372 / 375 28.2 35.5 4.8X -Native ORC Vectorized with copy 543 / 576 19.3 51.8 3.3X -Hive built-in ORC 2671 / 2671 3.9 254.7 0.7X +Native ORC MR 1727 / 1733 6.1 164.7 1.0X +Native ORC Vectorized 375 / 379 28.0 35.7 4.6X +Native ORC Vectorized with copy 552 / 556 19.0 52.6 3.1X +Hive built-in ORC 2665 / 2666 3.9 254.2 0.6X ================================================================================================ String with Nulls Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 3276 / 3302 3.2 312.5 1.0X -Native ORC Vectorized 1057 / 1080 9.9 100.8 3.1X -Native ORC Vectorized with copy 1420 / 1431 7.4 135.4 2.3X -Hive built-in ORC 5377 / 5407 2.0 512.8 0.6X +Native ORC MR 3324 / 3325 3.2 317.0 1.0X +Native ORC Vectorized 1085 / 1106 9.7 103.4 3.1X +Native ORC Vectorized with copy 1463 / 1471 7.2 139.5 2.3X +Hive built-in ORC 5272 / 5299 2.0 502.8 0.6X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -String with Nulls Scan (0.5%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +String with Nulls Scan (50.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 3147 / 3147 3.3 300.1 1.0X -Native ORC Vectorized 1305 / 1319 8.0 124.4 2.4X -Native ORC Vectorized with copy 1685 / 1686 6.2 160.7 1.9X -Hive built-in ORC 4077 / 4085 2.6 388.8 0.8X +Native ORC MR 3045 / 3046 3.4 290.4 1.0X +Native ORC Vectorized 1248 / 1260 8.4 119.0 2.4X +Native ORC Vectorized with copy 1609 / 1624 6.5 153.5 1.9X +Hive built-in ORC 3989 / 3999 2.6 380.4 0.8X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -String with Nulls Scan (0.95%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +String with Nulls Scan (95.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1739 / 1744 6.0 165.8 1.0X -Native ORC Vectorized 500 / 501 21.0 47.7 3.5X -Native ORC Vectorized with copy 618 / 631 17.0 58.9 2.8X -Hive built-in ORC 2411 / 2427 4.3 229.9 0.7X +Native ORC MR 1692 / 1694 6.2 161.3 1.0X +Native ORC Vectorized 471 / 493 22.3 44.9 3.6X +Native ORC Vectorized with copy 588 / 590 17.8 56.1 2.9X +Hive built-in ORC 2398 / 2411 4.4 228.7 0.7X ================================================================================================ Single Column Scan From Wide Columns ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1348 / 1366 0.8 1285.3 1.0X -Native ORC Vectorized 119 / 134 8.8 113.5 11.3X -Native ORC Vectorized with copy 119 / 148 8.8 113.9 11.3X -Hive built-in ORC 487 / 507 2.2 464.8 2.8X +Native ORC MR 1371 / 1379 0.8 1307.5 1.0X +Native ORC Vectorized 121 / 135 8.6 115.8 11.3X +Native ORC Vectorized with copy 122 / 138 8.6 116.2 11.3X +Hive built-in ORC 521 / 561 2.0 497.1 2.6X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 2667 / 2837 0.4 2543.6 1.0X -Native ORC Vectorized 203 / 222 5.2 193.4 13.2X -Native ORC Vectorized with copy 217 / 255 4.8 207.0 12.3X -Hive built-in ORC 737 / 741 1.4 702.4 3.6X +Native ORC MR 2711 / 2767 0.4 2585.5 1.0X +Native ORC Vectorized 210 / 232 5.0 200.5 12.9X +Native ORC Vectorized with copy 208 / 219 5.0 198.4 13.0X +Hive built-in ORC 764 / 775 1.4 728.3 3.5X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 3954 / 3956 0.3 3770.4 1.0X -Native ORC Vectorized 348 / 360 3.0 331.7 11.4X -Native ORC Vectorized with copy 349 / 359 3.0 333.2 11.3X -Hive built-in ORC 1057 / 1067 1.0 1008.0 3.7X +Native ORC MR 3979 / 3988 0.3 3794.4 1.0X +Native ORC Vectorized 357 / 366 2.9 340.2 11.2X +Native ORC Vectorized with copy 361 / 371 2.9 344.5 11.0X +Hive built-in ORC 1091 / 1095 1.0 1040.5 3.6X diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala index ec13288f759a6..eb3cde8472dac 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala @@ -32,9 +32,11 @@ import org.apache.spark.sql.types._ * Benchmark to measure ORC read performance. * {{{ * To run this benchmark: - * 1. without sbt: bin/spark-submit --class - * 2. build/sbt "sql/test:runMain " - * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * 1. without sbt: bin/spark-submit --class + * --jars ,,,, + * + * 2. build/sbt "hive/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "hive/test:runMain " * Results will be written to "benchmarks/OrcReadBenchmark-results.txt". * }}} * @@ -266,8 +268,9 @@ object OrcReadBenchmark extends BenchmarkBase with SQLHelper { s"SELECT IF(RAND(1) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c1, " + s"IF(RAND(2) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c2 FROM t1")) + val percentageOfNulls = fractionOfNulls * 100 val benchmark = - new Benchmark(s"String with Nulls Scan ($fractionOfNulls%)", values, output = output) + new Benchmark(s"String with Nulls Scan ($percentageOfNulls%)", values, output = output) benchmark.addCase("Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { From 973f7c01df0788b6f5d21224d96c33f14c5b8c64 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 8 Nov 2018 15:49:36 -0800 Subject: [PATCH 2024/2461] [MINOR] update HiveExternalCatalogVersionsSuite to test 2.4.0 ## What changes were proposed in this pull request? Since Spark 2.4.0 is released, we should test it in HiveExternalCatalogVersionsSuite ## How was this patch tested? N/A Closes #22984 from cloud-fan/minor. Authored-by: Wenchen Fan Signed-off-by: Dongjoon Hyun --- .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index fd4985d131885..f1e842334416c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -206,7 +206,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.1.3", "2.2.2", "2.3.2") + val testingVersions = Seq("2.2.2", "2.3.2", "2.4.0") protected var spark: SparkSession = _ From 79551f558dafed41177b605b0436e9340edf5712 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Fri, 9 Nov 2018 09:45:06 +0800 Subject: [PATCH 2025/2461] [SPARK-25945][SQL] Support locale while parsing date/timestamp from CSV/JSON MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In the PR, I propose to add new option `locale` into CSVOptions/JSONOptions to make parsing date/timestamps in local languages possible. Currently the locale is hard coded to `Locale.US`. ## How was this patch tested? Added two tests for parsing a date from CSV/JSON - `ноя 2018`. Closes #22951 from MaxGekk/locale. Authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- python/pyspark/sql/readwriter.py | 15 +++++++++++---- python/pyspark/sql/streaming.py | 14 ++++++++++---- .../spark/sql/catalyst/csv/CSVOptions.scala | 7 +++++-- .../spark/sql/catalyst/json/JSONOptions.scala | 7 +++++-- .../expressions/CsvExpressionsSuite.scala | 19 ++++++++++++++++++- .../expressions/JsonExpressionsSuite.scala | 19 ++++++++++++++++++- .../apache/spark/sql/DataFrameReader.scala | 4 ++++ .../sql/streaming/DataStreamReader.scala | 4 ++++ .../apache/spark/sql/CsvFunctionsSuite.scala | 17 +++++++++++++++++ .../apache/spark/sql/JsonFunctionsSuite.scala | 17 +++++++++++++++++ 10 files changed, 109 insertions(+), 14 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 690b13072244b..726de4a965418 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -177,7 +177,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None, - dropFieldIfAllNull=None, encoding=None): + dropFieldIfAllNull=None, encoding=None, locale=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -249,6 +249,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param dropFieldIfAllNull: whether to ignore column of all null values or empty array/struct during schema inference. If None is set, it uses the default value, ``false``. + :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set, + it uses the default value, ``en-US``. For instance, ``locale`` is used while + parsing dates and timestamps. >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes @@ -267,7 +270,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, - samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding) + samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding, + locale=locale) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -349,7 +353,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, - samplingRatio=None, enforceSchema=None, emptyValue=None): + samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None): r"""Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -446,6 +450,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non If None is set, it uses the default value, ``1.0``. :param emptyValue: sets the string representation of an empty value. If None is set, it uses the default value, empty string. + :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set, + it uses the default value, ``en-US``. For instance, ``locale`` is used while + parsing dates and timestamps. >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes @@ -465,7 +472,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio, - enforceSchema=enforceSchema, emptyValue=emptyValue) + enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale) if isinstance(path, basestring): path = [path] if type(path) == list: diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index b18453b2a4f96..02b14ea187cba 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -404,7 +404,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None, lineSep=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None, locale=None): """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. @@ -469,6 +469,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, including tab and line feed characters) or not. :param lineSep: defines the line separator that should be used for parsing. If None is set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. + :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set, + it uses the default value, ``en-US``. For instance, ``locale`` is used while + parsing dates and timestamps. >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema) >>> json_sdf.isStreaming @@ -483,7 +486,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, locale=locale) if isinstance(path, basestring): return self._df(self._jreader.json(path)) else: @@ -564,7 +567,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, - enforceSchema=None, emptyValue=None): + enforceSchema=None, emptyValue=None, locale=None): r"""Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -660,6 +663,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non different, ``\0`` otherwise.. :param emptyValue: sets the string representation of an empty value. If None is set, it uses the default value, empty string. + :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set, + it uses the default value, ``en-US``. For instance, ``locale`` is used while + parsing dates and timestamps. >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema) >>> csv_sdf.isStreaming @@ -677,7 +683,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema, - emptyValue=emptyValue) + emptyValue=emptyValue, locale=locale) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala index cdaaa172e8367..642823582a645 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala @@ -131,13 +131,16 @@ class CSVOptions( val timeZone: TimeZone = DateTimeUtils.getTimeZone( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) + // A language tag in IETF BCP 47 format + val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US) + // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. val dateFormat: FastDateFormat = - FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US) + FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), locale) val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US) + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, locale) val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 64152e04928d2..e10b8a327c01a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -76,16 +76,19 @@ private[sql] class JSONOptions( // Whether to ignore column of all null values or empty array/struct during schema inference val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false) + // A language tag in IETF BCP 47 format + val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US) + val timeZone: TimeZone = DateTimeUtils.getTimeZone( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. val dateFormat: FastDateFormat = - FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US) + FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), locale) val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US) + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, locale) val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala index d006197bd5678..f5aaaec456153 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.expressions -import java.util.Calendar +import java.text.SimpleDateFormat +import java.util.{Calendar, Locale} import org.scalatest.exceptions.TestFailedException @@ -209,4 +210,20 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P "2015-12-31T16:00:00" ) } + + test("parse date with locale") { + Seq("en-US", "ru-RU").foreach { langTag => + val locale = Locale.forLanguageTag(langTag) + val date = new SimpleDateFormat("yyyy-MM-dd").parse("2018-11-05") + val schema = new StructType().add("d", DateType) + val dateFormat = "MMM yyyy" + val sdf = new SimpleDateFormat(dateFormat, locale) + val dateStr = sdf.format(date) + val options = Map("dateFormat" -> dateFormat, "locale" -> langTag) + + checkEvaluation( + CsvToStructs(schema, options, Literal.create(dateStr), gmtId), + InternalRow(17836)) // number of days from 1970-01-01 + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 304642161146b..6ee8c74010d3d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.expressions -import java.util.Calendar +import java.text.SimpleDateFormat +import java.util.{Calendar, Locale} import org.scalatest.exceptions.TestFailedException @@ -737,4 +738,20 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with CreateMap(Seq(Literal.create("allowNumericLeadingZeros"), Literal.create("true")))), "struct") } + + test("parse date with locale") { + Seq("en-US", "ru-RU").foreach { langTag => + val locale = Locale.forLanguageTag(langTag) + val date = new SimpleDateFormat("yyyy-MM-dd").parse("2018-11-05") + val schema = new StructType().add("d", DateType) + val dateFormat = "MMM yyyy" + val sdf = new SimpleDateFormat(dateFormat, locale) + val dateStr = s"""{"d":"${sdf.format(date)}"}""" + val options = Map("dateFormat" -> dateFormat, "locale" -> langTag) + + checkEvaluation( + JsonToStructs(schema, options, Literal.create(dateStr), gmtId), + InternalRow(17836)) // number of days from 1970-01-01 + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 95c97e5c9433c..02ffc940184db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -384,6 +384,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * for schema inferring. *
    • `dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or * empty array/struct during schema inference.
    • + *
    • `locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format. + * For instance, this is used while parsing dates and timestamps.
    • * * * @since 2.0.0 @@ -604,6 +606,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`. *
    • `multiLine` (default `false`): parse one record, which may span multiple lines.
    • + *
    • `locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format. + * For instance, this is used while parsing dates and timestamps.
    • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 4c7dcedafeeae..20c84305776ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -296,6 +296,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * that should be used for parsing. *
    • `dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or * empty array/struct during schema inference.
    • + *
    • `locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format. + * For instance, this is used while parsing dates and timestamps.
    • * * * @since 2.0.0 @@ -372,6 +374,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`. *
    • `multiLine` (default `false`): parse one record, which may span multiple lines.
    • + *
    • `locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format. + * For instance, this is used while parsing dates and timestamps.
    • * * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index b97ac380def63..1c359ce1d2014 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql +import java.text.SimpleDateFormat +import java.util.Locale + import scala.collection.JavaConverters._ import org.apache.spark.SparkException @@ -164,4 +167,18 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext { val df1 = Seq(Tuple1(Tuple1(1))).toDF("a") checkAnswer(df1.selectExpr("to_csv(a)"), Row("1") :: Nil) } + + test("parse timestamps with locale") { + Seq("en-US", "ko-KR", "zh-CN", "ru-RU").foreach { langTag => + val locale = Locale.forLanguageTag(langTag) + val ts = new SimpleDateFormat("dd/MM/yyyy HH:mm").parse("06/11/2018 18:00") + val timestampFormat = "dd MMM yyyy HH:mm" + val sdf = new SimpleDateFormat(timestampFormat, locale) + val input = Seq(s"""${sdf.format(ts)}""").toDS() + val options = Map("timestampFormat" -> timestampFormat, "locale" -> langTag) + val df = input.select(from_csv($"value", lit("time timestamp"), options.asJava)) + + checkAnswer(df, Row(Row(java.sql.Timestamp.valueOf("2018-11-06 18:00:00.0")))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index d6b73387e84b3..24e7564259c83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql +import java.text.SimpleDateFormat +import java.util.Locale + import collection.JavaConverters._ import org.apache.spark.SparkException @@ -591,4 +594,18 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { df.select(from_json($"value", schema, Map("columnNameOfCorruptRecord" -> "_unparsed"))), Row(Row(null, badRec, null)) :: Row(Row(2, null, 12)) :: Nil) } + + test("parse timestamps with locale") { + Seq("en-US", "ko-KR", "zh-CN", "ru-RU").foreach { langTag => + val locale = Locale.forLanguageTag(langTag) + val ts = new SimpleDateFormat("dd/MM/yyyy HH:mm").parse("06/11/2018 18:00") + val timestampFormat = "dd MMM yyyy HH:mm" + val sdf = new SimpleDateFormat(timestampFormat, locale) + val input = Seq(s"""{"time": "${sdf.format(ts)}"}""").toDS() + val options = Map("timestampFormat" -> timestampFormat, "locale" -> langTag) + val df = input.select(from_json($"value", "time timestamp", options)) + + checkAnswer(df, Row(Row(java.sql.Timestamp.valueOf("2018-11-06 18:00:00.0")))) + } + } } From 0558d021cc0aeae37ef0e043d244fd0300a57cd5 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 9 Nov 2018 11:45:03 +0800 Subject: [PATCH 2026/2461] [SPARK-25510][SQL][TEST][FOLLOW-UP] Remove BenchmarkWithCodegen ## What changes were proposed in this pull request? Remove `BenchmarkWithCodegen` as we don't use it anymore. More details: https://github.com/apache/spark/pull/22484#discussion_r221397904 ## How was this patch tested? N/A Closes #22985 from wangyum/SPARK-25510. Authored-by: Yuming Wang Signed-off-by: hyukjinkwon --- .../benchmark/BenchmarkWithCodegen.scala | 54 ------------------- 1 file changed, 54 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWithCodegen.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWithCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWithCodegen.scala deleted file mode 100644 index 51331500479a3..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWithCodegen.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.benchmark - -import org.apache.spark.SparkFunSuite -import org.apache.spark.benchmark.Benchmark -import org.apache.spark.sql.SparkSession - -/** - * Common base trait for micro benchmarks that are supposed to run standalone (i.e. not together - * with other test suites). - */ -private[benchmark] trait BenchmarkWithCodegen extends SparkFunSuite { - - lazy val sparkSession = SparkSession.builder - .master("local[1]") - .appName("microbenchmark") - .config("spark.sql.shuffle.partitions", 1) - .config("spark.sql.autoBroadcastJoinThreshold", 1) - .getOrCreate() - - /** Runs function `f` with whole stage codegen on and off. */ - def runBenchmark(name: String, cardinality: Long)(f: => Unit): Unit = { - val benchmark = new Benchmark(name, cardinality) - - benchmark.addCase(s"$name wholestage off", numIters = 2) { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", value = false) - f - } - - benchmark.addCase(s"$name wholestage on", numIters = 5) { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) - f - } - - benchmark.run() - } - -} From 297b81e0eb1493b12838c3c48c6f754289ce1c1f Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Fri, 9 Nov 2018 07:55:02 -0600 Subject: [PATCH 2027/2461] [SPARK-20156][SQL][ML][FOLLOW-UP] Java String toLowerCase with Locale.ROOT ## What changes were proposed in this pull request? Add `Locale.ROOT` to all internal calls to String `toLowerCase`, `toUpperCase` ## How was this patch tested? existing tests Closes #22975 from zhengruifeng/Tokenizer_Locale. Authored-by: zhengruifeng Signed-off-by: Sean Owen --- project/SparkBuild.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index ca57df0e31a7f..5e034f9fe2a95 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -17,6 +17,7 @@ import java.io._ import java.nio.file.Files +import java.util.Locale import scala.io.Source import scala.util.Properties @@ -650,10 +651,13 @@ object Assembly { }, jarName in (Test, assembly) := s"${moduleName.value}-test-${version.value}.jar", mergeStrategy in assembly := { - case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard - case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard + case m if m.toLowerCase(Locale.ROOT).endsWith("manifest.mf") + => MergeStrategy.discard + case m if m.toLowerCase(Locale.ROOT).matches("meta-inf.*\\.sf$") + => MergeStrategy.discard case "log4j.properties" => MergeStrategy.discard - case m if m.toLowerCase.startsWith("meta-inf/services/") => MergeStrategy.filterDistinctLines + case m if m.toLowerCase(Locale.ROOT).startsWith("meta-inf/services/") + => MergeStrategy.filterDistinctLines case "reference.conf" => MergeStrategy.concat case _ => MergeStrategy.first } From 25f506e2ad865ed671cfc618ca9d272bfb5712b7 Mon Sep 17 00:00:00 2001 From: William Montaz Date: Fri, 9 Nov 2018 08:02:53 -0600 Subject: [PATCH 2028/2461] [SPARK-25973][CORE] Spark History Main page performance improvement HistoryPage.scala counts applications (with a predicate depending on if it is displaying incomplete or complete applications) to check if it must display the dataTable. Since it only checks if allAppsSize > 0, we could use exists method on the iterator. This way we stop iterating at the first occurence found. Such a change has been relevant (roughly 12s improvement on page loading) on our cluster that runs tens of thousands of jobs per day. Closes #22982 from Willymontaz/SPARK-25973. Authored-by: William Montaz Signed-off-by: Sean Owen --- .../org/apache/spark/deploy/history/HistoryPage.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 32667ddf5c7ea..00ca4efa4d266 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -31,8 +31,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val requestedIncomplete = Option(UIUtils.stripXSS(request.getParameter("showIncomplete"))).getOrElse("false").toBoolean - val allAppsSize = parent.getApplicationList() - .count(isApplicationCompleted(_) != requestedIncomplete) + val displayApplications = parent.getApplicationList() + .exists(isApplicationCompleted(_) != requestedIncomplete) val eventLogsUnderProcessCount = parent.getEventLogsUnderProcess() val lastUpdatedTime = parent.getLastUpdatedTime() val providerConfig = parent.getProviderConfig() @@ -63,9 +63,9 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") } { - if (allAppsSize > 0) { + if (displayApplications) { ++ + request, "/static/dataTables.rowsGroup.js")}> ++
      ++ ++ From 657fd00b5204859c2e6d7c19a71a3ec5ecf7c869 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 9 Nov 2018 08:22:26 -0800 Subject: [PATCH 2029/2461] [SPARK-25988][SQL] Keep names unchanged when deduplicating the column names in Analyzer ## What changes were proposed in this pull request? When the queries do not use the column names with the same case, users might hit various errors. Below is a typical test failure they can hit. ``` Expected only partition pruning predicates: ArrayBuffer(isnotnull(tdate#237), (cast(tdate#237 as string) >= 2017-08-15)); org.apache.spark.sql.AnalysisException: Expected only partition pruning predicates: ArrayBuffer(isnotnull(tdate#237), (cast(tdate#237 as string) >= 2017-08-15)); at org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils$.prunePartitionsByFilter(ExternalCatalogUtils.scala:146) at org.apache.spark.sql.catalyst.catalog.InMemoryCatalog.listPartitionsByFilter(InMemoryCatalog.scala:560) at org.apache.spark.sql.catalyst.catalog.SessionCatalog.listPartitionsByFilter(SessionCatalog.scala:925) ``` ## How was this patch tested? Added two test cases. Closes #22990 from gatorsmile/fix1283. Authored-by: gatorsmile Signed-off-by: gatorsmile --- .../sql/catalyst/analysis/Analyzer.scala | 3 +- .../sql/catalyst/analysis/unresolved.scala | 1 + .../expressions/namedExpressions.scala | 5 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 53 +++++++++++++++++++ 4 files changed, 60 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c2d22c5e7ce60..6dc5b3f28b914 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -824,7 +824,8 @@ class Analyzer( } private def dedupAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { - attrMap.get(attr).getOrElse(attr).withQualifier(attr.qualifier) + val exprId = attrMap.getOrElse(attr, attr).exprId + attr.withExprId(exprId) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 857cf382b8f2c..36cad3cf74785 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -112,6 +112,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un override def withQualifier(newQualifier: Seq[String]): UnresolvedAttribute = this override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) override def withMetadata(newMetadata: Metadata): Attribute = this + override def withExprId(newExprId: ExprId): UnresolvedAttribute = this override def toString: String = s"'$name" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 584a2946bd564..049ea77691395 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -115,6 +115,7 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn def withQualifier(newQualifier: Seq[String]): Attribute def withName(newName: String): Attribute def withMetadata(newMetadata: Metadata): Attribute + def withExprId(newExprId: ExprId): Attribute override def toAttribute: Attribute = this def newInstance(): Attribute @@ -299,7 +300,7 @@ case class AttributeReference( } } - def withExprId(newExprId: ExprId): AttributeReference = { + override def withExprId(newExprId: ExprId): AttributeReference = { if (exprId == newExprId) { this } else { @@ -362,6 +363,8 @@ case class PrettyAttribute( throw new UnsupportedOperationException override def qualifier: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException + override def withExprId(newExprId: ExprId): Attribute = + throw new UnsupportedOperationException override def nullable: Boolean = true } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 631ab1b7ece7f..dbb0790a4682c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2856,6 +2856,59 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(sql("select 26393499451 / (1e6 * 1000)"), Row(BigDecimal("26.3934994510000"))) } } + + test("SPARK-25988: self join with aliases on partitioned tables #1") { + withTempView("tmpView1", "tmpView2") { + withTable("tab1", "tab2") { + sql( + """ + |CREATE TABLE `tab1` (`col1` INT, `TDATE` DATE) + |USING CSV + |PARTITIONED BY (TDATE) + """.stripMargin) + spark.table("tab1").where("TDATE >= '2017-08-15'").createOrReplaceTempView("tmpView1") + sql("CREATE TABLE `tab2` (`TDATE` DATE) USING parquet") + sql( + """ + |CREATE OR REPLACE TEMPORARY VIEW tmpView2 AS + |SELECT N.tdate, col1 AS aliasCol1 + |FROM tmpView1 N + |JOIN tab2 Z + |ON N.tdate = Z.tdate + """.stripMargin) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + sql("SELECT * FROM tmpView2 x JOIN tmpView2 y ON x.tdate = y.tdate").collect() + } + } + } + } + + test("SPARK-25988: self join with aliases on partitioned tables #2") { + withTempView("tmp") { + withTable("tab1", "tab2") { + sql( + """ + |CREATE TABLE `tab1` (`EX` STRING, `TDATE` DATE) + |USING parquet + |PARTITIONED BY (tdate) + """.stripMargin) + sql("CREATE TABLE `tab2` (`TDATE` DATE) USING parquet") + sql( + """ + |CREATE OR REPLACE TEMPORARY VIEW TMP as + |SELECT N.tdate, EX AS new_ex + |FROM tab1 N + |JOIN tab2 Z + |ON N.tdate = Z.tdate + """.stripMargin) + sql( + """ + |SELECT * FROM TMP x JOIN TMP y + |ON x.tdate = y.tdate + """.stripMargin).queryExecution.executedPlan + } + } + } } case class Foo(bar: Option[String]) From 1db799795cf3c15798fbfb6043ec5775e16ba5ea Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 9 Nov 2018 09:44:04 -0800 Subject: [PATCH 2030/2461] [SPARK-25979][SQL] Window function: allow parentheses around window reference ## What changes were proposed in this pull request? Very minor parser bug, but possibly problematic for code-generated queries: Consider the following two queries: ``` SELECT avg(k) OVER (w) FROM kv WINDOW w AS (PARTITION BY v ORDER BY w) ORDER BY 1 ``` and ``` SELECT avg(k) OVER w FROM kv WINDOW w AS (PARTITION BY v ORDER BY w) ORDER BY 1 ``` The former, with parens around the OVER condition, fails to parse while the latter, without parens, succeeds: ``` Error in SQL statement: ParseException: mismatched input '(' expecting {, ',', 'FROM', 'WHERE', 'GROUP', 'ORDER', 'HAVING', 'LIMIT', 'LATERAL', 'WINDOW', 'UNION', 'EXCEPT', 'MINUS', 'INTERSECT', 'SORT', 'CLUSTER', 'DISTRIBUTE'}(line 1, pos 19) == SQL == SELECT avg(k) OVER (w) FROM kv WINDOW w AS (PARTITION BY v ORDER BY w) ORDER BY 1 -------------------^^^ ``` This was found when running the cockroach DB tests. I tried PostgreSQL, The SQL with parentheses is also workable. ## How was this patch tested? Unit test Closes #22987 from gengliangwang/windowParentheses. Authored-by: Gengliang Wang Signed-off-by: gatorsmile --- .../spark/sql/catalyst/parser/SqlBase.g4 | 1 + .../resources/sql-tests/inputs/window.sql | 6 ++++++ .../sql-tests/results/window.sql.out | 19 ++++++++++++++++++- 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index e2d34d1650ddc..5e732edb17baa 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -691,6 +691,7 @@ namedWindow windowSpec : name=identifier #windowRef + | '('name=identifier')' #windowRef | '(' ( CLUSTER BY partition+=expression (',' partition+=expression)* | ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)? diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql index cda4db4b449fe..faab4c61c8640 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/window.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql @@ -109,3 +109,9 @@ last_value(false, false) OVER w AS last_value_contain_null FROM testData WINDOW w AS () ORDER BY cate, val; + +-- parentheses around window reference +SELECT cate, sum(val) OVER (w) +FROM testData +WHERE val is not null +WINDOW w AS (PARTITION BY cate ORDER BY val); diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out index 5071e0bd26b2a..367dc4f513635 100644 --- a/sql/core/src/test/resources/sql-tests/results/window.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 22 +-- Number of queries: 23 -- !query 0 @@ -363,3 +363,20 @@ NULL a false true false false true false 1 b false true false false true false 2 b false true false false true false 3 b false true false false true false + + +-- !query 22 +SELECT cate, sum(val) OVER (w) +FROM testData +WHERE val is not null +WINDOW w AS (PARTITION BY cate ORDER BY val) +-- !query 22 schema +struct +-- !query 22 output +NULL 3 +a 2 +a 2 +a 4 +b 1 +b 3 +b 6 From 8e5f3c6ba6ef9b92578a6b292cfa1c480370cbfc Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Fri, 9 Nov 2018 15:40:15 -0600 Subject: [PATCH 2031/2461] [SPARK-24101][ML][MLLIB] ML Evaluators should use weight column - added weight column for multiclass classification evaluator ## What changes were proposed in this pull request? The evaluators BinaryClassificationEvaluator, RegressionEvaluator, and MulticlassClassificationEvaluator and the corresponding metrics classes BinaryClassificationMetrics, RegressionMetrics and MulticlassMetrics should use sample weight data. I've closed the PR: https://github.com/apache/spark/pull/16557 as recommended in favor of creating three pull requests, one for each of the evaluators (binary/regression/multiclass) to make it easier to review/update. Note: I've updated the JIRA to: https://issues.apache.org/jira/browse/SPARK-24101 Which is a child of JIRA: https://issues.apache.org/jira/browse/SPARK-18693 ## How was this patch tested? I added tests to the metrics class. Closes #17086 from imatiach-msft/ilmat/multiclass-evaluate. Authored-by: Ilya Matiach Signed-off-by: Sean Owen --- .../MulticlassClassificationEvaluator.scala | 19 ++- .../mllib/evaluation/MulticlassMetrics.scala | 55 +++--- .../evaluation/MulticlassMetricsSuite.scala | 158 ++++++++++++++---- 3 files changed, 170 insertions(+), 62 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index 794b1e7d9d881..f1602c1bc5333 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.sql.{Dataset, Row} @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.DoubleType @Since("1.5.0") @Experimental class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { + extends Evaluator with HasPredictionCol with HasLabelCol + with HasWeightCol with DefaultParamsWritable { @Since("1.5.0") def this() = this(Identifiable.randomUID("mcEval")) @@ -67,6 +68,10 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid @Since("1.5.0") def setLabelCol(value: String): this.type = set(labelCol, value) + /** @group setParam */ + @Since("3.0.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(metricName -> "f1") @Since("2.0.0") @@ -75,11 +80,13 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) SchemaUtils.checkNumericType(schema, $(labelCol)) - val predictionAndLabels = - dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map { - case Row(prediction: Double, label: Double) => (prediction, label) + val predictionAndLabelsWithWeights = + dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType), + if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))) + .rdd.map { + case Row(prediction: Double, label: Double, weight: Double) => (prediction, label, weight) } - val metrics = new MulticlassMetrics(predictionAndLabels) + val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights) val metric = $(metricName) match { case "f1" => metrics.weightedFMeasure case "weightedPrecision" => metrics.weightedPrecision diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 980e0c92531a2..ad83c24ede964 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -27,10 +27,19 @@ import org.apache.spark.sql.DataFrame /** * Evaluator for multiclass classification. * - * @param predictionAndLabels an RDD of (prediction, label) pairs. + * @param predAndLabelsWithOptWeight an RDD of (prediction, label, weight) or + * (prediction, label) pairs. */ @Since("1.1.0") -class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) { +class MulticlassMetrics @Since("1.1.0") (predAndLabelsWithOptWeight: RDD[_ <: Product]) { + val predLabelsWeight: RDD[(Double, Double, Double)] = predAndLabelsWithOptWeight.map { + case (prediction: Double, label: Double, weight: Double) => + (prediction, label, weight) + case (prediction: Double, label: Double) => + (prediction, label, 1.0) + case other => + throw new IllegalArgumentException(s"Expected tuples, got $other") + } /** * An auxiliary constructor taking a DataFrame. @@ -39,21 +48,29 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl private[mllib] def this(predictionAndLabels: DataFrame) = this(predictionAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) - private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue() - private lazy val labelCount: Long = labelCountByClass.values.sum - private lazy val tpByClass: Map[Double, Int] = predictionAndLabels - .map { case (prediction, label) => - (label, if (label == prediction) 1 else 0) + private lazy val labelCountByClass: Map[Double, Double] = + predLabelsWeight.map { + case (_: Double, label: Double, weight: Double) => + (label, weight) + }.reduceByKey(_ + _) + .collectAsMap() + private lazy val labelCount: Double = labelCountByClass.values.sum + private lazy val tpByClass: Map[Double, Double] = predLabelsWeight + .map { + case (prediction: Double, label: Double, weight: Double) => + (label, if (label == prediction) weight else 0.0) }.reduceByKey(_ + _) .collectAsMap() - private lazy val fpByClass: Map[Double, Int] = predictionAndLabels - .map { case (prediction, label) => - (prediction, if (prediction != label) 1 else 0) + private lazy val fpByClass: Map[Double, Double] = predLabelsWeight + .map { + case (prediction: Double, label: Double, weight: Double) => + (prediction, if (prediction != label) weight else 0.0) }.reduceByKey(_ + _) .collectAsMap() - private lazy val confusions = predictionAndLabels - .map { case (prediction, label) => - ((label, prediction), 1) + private lazy val confusions = predLabelsWeight + .map { + case (prediction: Double, label: Double, weight: Double) => + ((label, prediction), weight) }.reduceByKey(_ + _) .collectAsMap() @@ -71,7 +88,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl while (i < n) { var j = 0 while (j < n) { - values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0).toDouble + values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0.0) j += 1 } i += 1 @@ -92,8 +109,8 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl */ @Since("1.1.0") def falsePositiveRate(label: Double): Double = { - val fp = fpByClass.getOrElse(label, 0) - fp.toDouble / (labelCount - labelCountByClass(label)) + val fp = fpByClass.getOrElse(label, 0.0) + fp / (labelCount - labelCountByClass(label)) } /** @@ -103,7 +120,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl @Since("1.1.0") def precision(label: Double): Double = { val tp = tpByClass(label) - val fp = fpByClass.getOrElse(label, 0) + val fp = fpByClass.getOrElse(label, 0.0) if (tp + fp == 0) 0 else tp.toDouble / (tp + fp) } @@ -112,7 +129,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl * @param label the label. */ @Since("1.1.0") - def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label) + def recall(label: Double): Double = tpByClass(label) / labelCountByClass(label) /** * Returns f-measure for a given label (category) @@ -140,7 +157,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl * out of the total number of instances.) */ @Since("2.0.0") - lazy val accuracy: Double = tpByClass.values.sum.toDouble / labelCount + lazy val accuracy: Double = tpByClass.values.sum / labelCount /** * Returns weighted true positive rate diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index 5394baab94bcf..8779de590a256 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -18,10 +18,14 @@ package org.apache.spark.mllib.evaluation import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.Matrices +import org.apache.spark.ml.linalg.Matrices +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { + + private val delta = 1e-7 + test("Multiclass evaluation metrics") { /* * Confusion matrix for 3-class classification with total 9 instances: @@ -35,7 +39,6 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2) val metrics = new MulticlassMetrics(predictionAndLabels) - val delta = 0.0000001 val tpRate0 = 2.0 / (2 + 2) val tpRate1 = 3.0 / (3 + 1) val tpRate2 = 1.0 / (1 + 0) @@ -55,41 +58,122 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1) val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2) - assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray)) - assert(math.abs(metrics.truePositiveRate(0.0) - tpRate0) < delta) - assert(math.abs(metrics.truePositiveRate(1.0) - tpRate1) < delta) - assert(math.abs(metrics.truePositiveRate(2.0) - tpRate2) < delta) - assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta) - assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta) - assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta) - assert(math.abs(metrics.precision(0.0) - precision0) < delta) - assert(math.abs(metrics.precision(1.0) - precision1) < delta) - assert(math.abs(metrics.precision(2.0) - precision2) < delta) - assert(math.abs(metrics.recall(0.0) - recall0) < delta) - assert(math.abs(metrics.recall(1.0) - recall1) < delta) - assert(math.abs(metrics.recall(2.0) - recall2) < delta) - assert(math.abs(metrics.fMeasure(0.0) - f1measure0) < delta) - assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta) - assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta) - assert(math.abs(metrics.fMeasure(0.0, 2.0) - f2measure0) < delta) - assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta) - assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta) + assert(metrics.confusionMatrix.asML ~== confusionMatrix relTol delta) + assert(metrics.truePositiveRate(0.0) ~== tpRate0 relTol delta) + assert(metrics.truePositiveRate(1.0) ~== tpRate1 relTol delta) + assert(metrics.truePositiveRate(2.0) ~== tpRate2 relTol delta) + assert(metrics.falsePositiveRate(0.0) ~== fpRate0 relTol delta) + assert(metrics.falsePositiveRate(1.0) ~== fpRate1 relTol delta) + assert(metrics.falsePositiveRate(2.0) ~== fpRate2 relTol delta) + assert(metrics.precision(0.0) ~== precision0 relTol delta) + assert(metrics.precision(1.0) ~== precision1 relTol delta) + assert(metrics.precision(2.0) ~== precision2 relTol delta) + assert(metrics.recall(0.0) ~== recall0 relTol delta) + assert(metrics.recall(1.0) ~== recall1 relTol delta) + assert(metrics.recall(2.0) ~== recall2 relTol delta) + assert(metrics.fMeasure(0.0) ~== f1measure0 relTol delta) + assert(metrics.fMeasure(1.0) ~== f1measure1 relTol delta) + assert(metrics.fMeasure(2.0) ~== f1measure2 relTol delta) + assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 relTol delta) + assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 relTol delta) + assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 relTol delta) + + assert(metrics.accuracy ~== + (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1)) relTol delta) + assert(metrics.accuracy ~== metrics.weightedRecall relTol delta) + val weight0 = 4.0 / 9 + val weight1 = 4.0 / 9 + val weight2 = 1.0 / 9 + assert(metrics.weightedTruePositiveRate ~== + (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) relTol delta) + assert(metrics.weightedFalsePositiveRate ~== + (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) relTol delta) + assert(metrics.weightedPrecision ~== + (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) relTol delta) + assert(metrics.weightedRecall ~== + (weight0 * recall0 + weight1 * recall1 + weight2 * recall2) relTol delta) + assert(metrics.weightedFMeasure ~== + (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) relTol delta) + assert(metrics.weightedFMeasure(2.0) ~== + (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) relTol delta) + assert(metrics.labels === labels) + } + + test("Multiclass evaluation metrics with weights") { + /* + * Confusion matrix for 3-class classification with total 9 instances with 2 weights: + * |2 * w1|1 * w2 |1 * w1| true class0 (4 instances) + * |1 * w2|2 * w1 + 1 * w2|0 | true class1 (4 instances) + * |0 |0 |1 * w2| true class2 (1 instance) + */ + val w1 = 2.2 + val w2 = 1.5 + val tw = 2.0 * w1 + 1.0 * w2 + 1.0 * w1 + 1.0 * w2 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2 + val confusionMatrix = Matrices.dense(3, 3, + Array(2 * w1, 1 * w2, 0, 1 * w2, 2 * w1 + 1 * w2, 0, 1 * w1, 0, 1 * w2)) + val labels = Array(0.0, 1.0, 2.0) + val predictionAndLabelsWithWeights = sc.parallelize( + Seq((0.0, 0.0, w1), (0.0, 1.0, w2), (0.0, 0.0, w1), (1.0, 0.0, w2), + (1.0, 1.0, w1), (1.0, 1.0, w2), (1.0, 1.0, w1), (2.0, 2.0, w2), + (2.0, 0.0, w1)), 2) + val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights) + val tpRate0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1) + val tpRate1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2) + val tpRate2 = (1.0 * w2) / (1.0 * w2 + 0) + val fpRate0 = (1.0 * w2) / (tw - (2.0 * w1 + 1.0 * w2 + 1.0 * w1)) + val fpRate1 = (1.0 * w2) / (tw - (1.0 * w2 + 2.0 * w1 + 1.0 * w2)) + val fpRate2 = (1.0 * w1) / (tw - (1.0 * w2)) + val precision0 = (2.0 * w1) / (2 * w1 + 1 * w2) + val precision1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2) + val precision2 = (1.0 * w2) / (1 * w1 + 1 * w2) + val recall0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1) + val recall1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2) + val recall2 = (1.0 * w2) / (1.0 * w2 + 0) + val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0) + val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1) + val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2) + val f2measure0 = (1 + 2 * 2) * precision0 * recall0 / (2 * 2 * precision0 + recall0) + val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1) + val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2) + + assert(metrics.confusionMatrix.asML ~== confusionMatrix relTol delta) + assert(metrics.truePositiveRate(0.0) ~== tpRate0 relTol delta) + assert(metrics.truePositiveRate(1.0) ~== tpRate1 relTol delta) + assert(metrics.truePositiveRate(2.0) ~== tpRate2 relTol delta) + assert(metrics.falsePositiveRate(0.0) ~== fpRate0 relTol delta) + assert(metrics.falsePositiveRate(1.0) ~== fpRate1 relTol delta) + assert(metrics.falsePositiveRate(2.0) ~== fpRate2 relTol delta) + assert(metrics.precision(0.0) ~== precision0 relTol delta) + assert(metrics.precision(1.0) ~== precision1 relTol delta) + assert(metrics.precision(2.0) ~== precision2 relTol delta) + assert(metrics.recall(0.0) ~== recall0 relTol delta) + assert(metrics.recall(1.0) ~== recall1 relTol delta) + assert(metrics.recall(2.0) ~== recall2 relTol delta) + assert(metrics.fMeasure(0.0) ~== f1measure0 relTol delta) + assert(metrics.fMeasure(1.0) ~== f1measure1 relTol delta) + assert(metrics.fMeasure(2.0) ~== f1measure2 relTol delta) + assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 relTol delta) + assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 relTol delta) + assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 relTol delta) - assert(math.abs(metrics.accuracy - - (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta) - assert(math.abs(metrics.accuracy - metrics.weightedRecall) < delta) - assert(math.abs(metrics.weightedTruePositiveRate - - ((4.0 / 9) * tpRate0 + (4.0 / 9) * tpRate1 + (1.0 / 9) * tpRate2)) < delta) - assert(math.abs(metrics.weightedFalsePositiveRate - - ((4.0 / 9) * fpRate0 + (4.0 / 9) * fpRate1 + (1.0 / 9) * fpRate2)) < delta) - assert(math.abs(metrics.weightedPrecision - - ((4.0 / 9) * precision0 + (4.0 / 9) * precision1 + (1.0 / 9) * precision2)) < delta) - assert(math.abs(metrics.weightedRecall - - ((4.0 / 9) * recall0 + (4.0 / 9) * recall1 + (1.0 / 9) * recall2)) < delta) - assert(math.abs(metrics.weightedFMeasure - - ((4.0 / 9) * f1measure0 + (4.0 / 9) * f1measure1 + (1.0 / 9) * f1measure2)) < delta) - assert(math.abs(metrics.weightedFMeasure(2.0) - - ((4.0 / 9) * f2measure0 + (4.0 / 9) * f2measure1 + (1.0 / 9) * f2measure2)) < delta) - assert(metrics.labels.sameElements(labels)) + assert(metrics.accuracy ~== + (2.0 * w1 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2) / tw relTol delta) + assert(metrics.accuracy ~== metrics.weightedRecall relTol delta) + val weight0 = (2 * w1 + 1 * w2 + 1 * w1) / tw + val weight1 = (1 * w2 + 2 * w1 + 1 * w2) / tw + val weight2 = 1 * w2 / tw + assert(metrics.weightedTruePositiveRate ~== + (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) relTol delta) + assert(metrics.weightedFalsePositiveRate ~== + (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) relTol delta) + assert(metrics.weightedPrecision ~== + (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) relTol delta) + assert(metrics.weightedRecall ~== + (weight0 * recall0 + weight1 * recall1 + weight2 * recall2) relTol delta) + assert(metrics.weightedFMeasure ~== + (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) relTol delta) + assert(metrics.weightedFMeasure(2.0) ~== + (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) relTol delta) + assert(metrics.labels === labels) } } From d66a4e82eceb89a274edeb22c2fb4384bed5078b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 9 Nov 2018 22:42:48 -0800 Subject: [PATCH 2032/2461] [SPARK-25102][SQL] Write Spark version to ORC/Parquet file metadata ## What changes were proposed in this pull request? Currently, Spark writes Spark version number into Hive Table properties with `spark.sql.create.version`. ``` parameters:{ spark.sql.sources.schema.part.0={ "type":"struct", "fields":[{"name":"a","type":"integer","nullable":true,"metadata":{}}] }, transient_lastDdlTime=1541142761, spark.sql.sources.schema.numParts=1, spark.sql.create.version=2.4.0 } ``` This PR aims to write Spark versions to ORC/Parquet file metadata with `org.apache.spark.sql.create.version` because we used `org.apache.` prefix in Parquet metadata already. It's different from Hive Table property key `spark.sql.create.version`, but it seems that we cannot change Hive Table property for backward compatibility. After this PR, ORC and Parquet file generated by Spark will have the following metadata. **ORC (`native` and `hive` implmentation)** ``` $ orc-tools meta /tmp/o File Version: 0.12 with ... ... User Metadata: org.apache.spark.sql.create.version=3.0.0 ``` **PARQUET** ``` $ parquet-tools meta /tmp/p ... creator: parquet-mr version 1.10.0 (build 031a6654009e3b82020012a18434c582bd74c73a) extra: org.apache.spark.sql.create.version = 3.0.0 extra: org.apache.spark.sql.parquet.row.metadata = {"type":"struct","fields":[{"name":"id","type":"long","nullable":false,"metadata":{}}]} ``` ## How was this patch tested? Pass the Jenkins with newly added test cases. This closes #22255. Closes #22932 from dongjoon-hyun/SPARK-25102. Authored-by: Dongjoon Hyun Signed-off-by: gatorsmile --- .../main/scala/org/apache/spark/package.scala | 3 +++ .../org/apache/spark/util/VersionUtils.scala | 14 ++++++++++ .../apache/spark/util/VersionUtilsSuite.scala | 25 ++++++++++++++++++ .../datasources/orc/OrcOutputWriter.scala | 15 ++++++++--- .../execution/datasources/orc/OrcUtils.scala | 14 +++++++--- .../parquet/ParquetWriteSupport.scala | 7 ++++- .../scala/org/apache/spark/sql/package.scala | 9 +++++++ .../columnar/InMemoryColumnarQuerySuite.scala | 4 +-- .../datasources/HadoopFsRelationSuite.scala | 2 +- .../datasources/orc/OrcSourceSuite.scala | 20 +++++++++++++- .../datasources/parquet/ParquetIOSuite.scala | 21 ++++++++++++++- .../spark/sql/hive/orc/OrcFileFormat.scala | 26 ++++++++++++++++--- 12 files changed, 144 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 8058a4d5dbdea..5d0639e92c36a 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -19,6 +19,8 @@ package org.apache import java.util.Properties +import org.apache.spark.util.VersionUtils + /** * Core Spark functionality. [[org.apache.spark.SparkContext]] serves as the main entry point to * Spark, while [[org.apache.spark.rdd.RDD]] is the data type representing a distributed collection, @@ -89,6 +91,7 @@ package object spark { } val SPARK_VERSION = SparkBuildInfo.spark_version + val SPARK_VERSION_SHORT = VersionUtils.shortVersion(SparkBuildInfo.spark_version) val SPARK_BRANCH = SparkBuildInfo.spark_branch val SPARK_REVISION = SparkBuildInfo.spark_revision val SPARK_BUILD_USER = SparkBuildInfo.spark_build_user diff --git a/core/src/main/scala/org/apache/spark/util/VersionUtils.scala b/core/src/main/scala/org/apache/spark/util/VersionUtils.scala index 828153b868420..c0f8866dd58dc 100644 --- a/core/src/main/scala/org/apache/spark/util/VersionUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/VersionUtils.scala @@ -23,6 +23,7 @@ package org.apache.spark.util private[spark] object VersionUtils { private val majorMinorRegex = """^(\d+)\.(\d+)(\..*)?$""".r + private val shortVersionRegex = """^(\d+\.\d+\.\d+)(.*)?$""".r /** * Given a Spark version string, return the major version number. @@ -36,6 +37,19 @@ private[spark] object VersionUtils { */ def minorVersion(sparkVersion: String): Int = majorMinorVersion(sparkVersion)._2 + /** + * Given a Spark version string, return the short version string. + * E.g., for 3.0.0-SNAPSHOT, return '3.0.0'. + */ + def shortVersion(sparkVersion: String): String = { + shortVersionRegex.findFirstMatchIn(sparkVersion) match { + case Some(m) => m.group(1) + case None => + throw new IllegalArgumentException(s"Spark tried to parse '$sparkVersion' as a Spark" + + s" version string, but it could not find the major/minor/maintenance version numbers.") + } + } + /** * Given a Spark version string, return the (major version number, minor version number). * E.g., for 2.0.1-SNAPSHOT, return (2, 0). diff --git a/core/src/test/scala/org/apache/spark/util/VersionUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/VersionUtilsSuite.scala index b36d6be231d39..56623ebea1651 100644 --- a/core/src/test/scala/org/apache/spark/util/VersionUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/VersionUtilsSuite.scala @@ -73,4 +73,29 @@ class VersionUtilsSuite extends SparkFunSuite { } } } + + test("Return short version number") { + assert(shortVersion("3.0.0") === "3.0.0") + assert(shortVersion("3.0.0-SNAPSHOT") === "3.0.0") + withClue("shortVersion parsing should fail for missing maintenance version number") { + intercept[IllegalArgumentException] { + shortVersion("3.0") + } + } + withClue("shortVersion parsing should fail for invalid major version number") { + intercept[IllegalArgumentException] { + shortVersion("x.0.0") + } + } + withClue("shortVersion parsing should fail for invalid minor version number") { + intercept[IllegalArgumentException] { + shortVersion("3.x.0") + } + } + withClue("shortVersion parsing should fail for invalid maintenance version number") { + intercept[IllegalArgumentException] { + shortVersion("3.0.x") + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala index 84755bfa301f0..7e38fc651a31f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.execution.datasources.orc import org.apache.hadoop.fs.Path import org.apache.hadoop.io.NullWritable import org.apache.hadoop.mapreduce.TaskAttemptContext -import org.apache.orc.mapred.OrcStruct -import org.apache.orc.mapreduce.OrcOutputFormat +import org.apache.orc.OrcFile +import org.apache.orc.mapred.{OrcOutputFormat => OrcMapRedOutputFormat, OrcStruct} +import org.apache.orc.mapreduce.{OrcMapreduceRecordWriter, OrcOutputFormat} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.OutputWriter @@ -36,11 +37,17 @@ private[orc] class OrcOutputWriter( private[this] val serializer = new OrcSerializer(dataSchema) private val recordWriter = { - new OrcOutputFormat[OrcStruct]() { + val orcOutputFormat = new OrcOutputFormat[OrcStruct]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { new Path(path) } - }.getRecordWriter(context) + } + val filename = orcOutputFormat.getDefaultWorkFile(context, ".orc") + val options = OrcMapRedOutputFormat.buildOptions(context.getConfiguration) + val writer = OrcFile.createWriter(filename, options) + val recordWriter = new OrcMapreduceRecordWriter[OrcStruct](writer) + OrcUtils.addSparkVersionMetadata(writer) + recordWriter } override def write(row: InternalRow): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 95fb25bf5addb..57d2c56e87b4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -17,18 +17,19 @@ package org.apache.spark.sql.execution.datasources.orc +import java.nio.charset.StandardCharsets.UTF_8 import java.util.Locale import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.orc.{OrcFile, Reader, TypeDescription} +import org.apache.orc.{OrcFile, Reader, TypeDescription, Writer} -import org.apache.spark.SparkException +import org.apache.spark.{SPARK_VERSION_SHORT, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{SPARK_VERSION_METADATA_KEY, SparkSession} import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.types._ @@ -144,4 +145,11 @@ object OrcUtils extends Logging { } } } + + /** + * Add a metadata specifying Spark version. + */ + def addSparkVersionMetadata(writer: Writer): Unit = { + writer.addUserMetadata(SPARK_VERSION_METADATA_KEY, UTF_8.encode(SPARK_VERSION_SHORT)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala index b40b8c2e61f33..8814e3c6ccf94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala @@ -29,7 +29,9 @@ import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext import org.apache.parquet.io.api.{Binary, RecordConsumer} +import org.apache.spark.SPARK_VERSION_SHORT import org.apache.spark.internal.Logging +import org.apache.spark.sql.SPARK_VERSION_METADATA_KEY import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -93,7 +95,10 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit this.rootFieldWriters = schema.map(_.dataType).map(makeWriter).toArray[ValueWriter] val messageType = new SparkToParquetSchemaConverter(configuration).convert(schema) - val metadata = Map(ParquetReadSupport.SPARK_METADATA_KEY -> schemaString).asJava + val metadata = Map( + SPARK_VERSION_METADATA_KEY -> SPARK_VERSION_SHORT, + ParquetReadSupport.SPARK_METADATA_KEY -> schemaString + ).asJava logInfo( s"""Initialized Parquet WriteSupport with Catalyst schema: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 161e0102f0b43..354660e9d5943 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -44,4 +44,13 @@ package object sql { type Strategy = SparkStrategy type DataFrame = Dataset[Row] + + /** + * Metadata key which is used to write Spark version in the followings: + * - Parquet file metadata + * - ORC file metadata + * + * Note that Hive table property `spark.sql.create.version` also has Spark version. + */ + private[sql] val SPARK_VERSION_METADATA_KEY = "org.apache.spark.version" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index e1567d06e23eb..861aa179a4a81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -506,7 +506,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { case plan: InMemoryRelation => plan }.head // InMemoryRelation's stats is file size before the underlying RDD is materialized - assert(inMemoryRelation.computeStats().sizeInBytes === 800) + assert(inMemoryRelation.computeStats().sizeInBytes === 868) // InMemoryRelation's stats is updated after materializing RDD dfFromFile.collect() @@ -519,7 +519,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { // Even CBO enabled, InMemoryRelation's stats keeps as the file size before table's stats // is calculated - assert(inMemoryRelation2.computeStats().sizeInBytes === 800) + assert(inMemoryRelation2.computeStats().sizeInBytes === 868) // InMemoryRelation's stats should be updated after calculating stats of the table // clear cache to simulate a fresh environment diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala index c1f2c18d1417d..6e08ee3c4ba3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala @@ -45,7 +45,7 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { import testImplicits._ Seq(1.0, 0.5).foreach { compressionFactor => withSQLConf("spark.sql.sources.fileCompressionFactor" -> compressionFactor.toString, - "spark.sql.autoBroadcastJoinThreshold" -> "400") { + "spark.sql.autoBroadcastJoinThreshold" -> "434") { withTempPath { workDir => // the file size is 740 bytes val workDirPath = workDir.getAbsolutePath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index dc81c0585bf18..48910103e702a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.orc import java.io.File +import java.nio.charset.StandardCharsets.UTF_8 import java.sql.Timestamp import java.util.Locale @@ -30,7 +31,8 @@ import org.apache.orc.OrcProto.Stream.Kind import org.apache.orc.impl.RecordReaderImpl import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.Row +import org.apache.spark.SPARK_VERSION_SHORT +import org.apache.spark.sql.{Row, SPARK_VERSION_METADATA_KEY} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -314,6 +316,22 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { checkAnswer(spark.read.orc(path.getCanonicalPath), Row(ts)) } } + + test("Write Spark version into ORC file metadata") { + withTempPath { path => + spark.range(1).repartition(1).write.orc(path.getCanonicalPath) + + val partFiles = path.listFiles() + .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_")) + assert(partFiles.length === 1) + + val orcFilePath = new Path(partFiles.head.getAbsolutePath) + val readerOptions = OrcFile.readerOptions(new Configuration()) + val reader = OrcFile.createReader(orcFilePath, readerOptions) + val version = UTF_8.decode(reader.getMetadataValue(SPARK_VERSION_METADATA_KEY)).toString + assert(version === SPARK_VERSION_SHORT) + } + } } class OrcSourceSuite extends OrcSuite with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 002c42f23bd64..6b05b9c0f7207 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -27,6 +27,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.parquet.HadoopReadOptions import org.apache.parquet.column.{Encoding, ParquetProperties} import org.apache.parquet.example.data.{Group, GroupWriter} import org.apache.parquet.example.data.simple.SimpleGroup @@ -34,10 +35,11 @@ import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext import org.apache.parquet.hadoop.metadata.CompressionCodecName +import org.apache.parquet.hadoop.util.HadoopInputFile import org.apache.parquet.io.api.RecordConsumer import org.apache.parquet.schema.{MessageType, MessageTypeParser} -import org.apache.spark.SparkException +import org.apache.spark.{SPARK_VERSION_SHORT, SparkException} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} @@ -799,6 +801,23 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { checkAnswer(spark.read.parquet(file.getAbsolutePath), Seq(Row(Row(1, null, "foo")))) } } + + test("Write Spark version into Parquet metadata") { + withTempPath { dir => + val path = dir.getAbsolutePath + spark.range(1).repartition(1).write.parquet(path) + val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0) + + val conf = new Configuration() + val hadoopInputFile = HadoopInputFile.fromPath(new Path(file), conf) + val parquetReadOptions = HadoopReadOptions.builder(conf).build() + val m = ParquetFileReader.open(hadoopInputFile, parquetReadOptions) + val metaData = m.getFileMetaData.getKeyValueMetaData + m.close() + + assert(metaData.get(SPARK_VERSION_METADATA_KEY) === SPARK_VERSION_SHORT) + } + } } class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 89e6ea8604974..4e641e34c18d9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -18,9 +18,11 @@ package org.apache.spark.sql.hive.orc import java.net.URI +import java.nio.charset.StandardCharsets.UTF_8 import java.util.Properties import scala.collection.JavaConverters._ +import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} @@ -31,10 +33,12 @@ import org.apache.hadoop.hive.serde2.typeinfo.{StructTypeInfo, TypeInfoUtils} import org.apache.hadoop.io.{NullWritable, Writable} import org.apache.hadoop.mapred.{JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.orc.OrcConf.COMPRESS -import org.apache.spark.TaskContext +import org.apache.spark.{SPARK_VERSION_SHORT, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SPARK_VERSION_METADATA_KEY import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -274,12 +278,14 @@ private[orc] class OrcOutputWriter( override def close(): Unit = { if (recordWriterInstantiated) { + // Hive 1.2.1 ORC initializes its private `writer` field at the first write. + OrcFileFormat.addSparkVersionMetadata(recordWriter) recordWriter.close(Reporter.NULL) } } } -private[orc] object OrcFileFormat extends HiveInspectors { +private[orc] object OrcFileFormat extends HiveInspectors with Logging { // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public. private[orc] val SARG_PUSHDOWN = "sarg.pushdown" @@ -339,4 +345,18 @@ private[orc] object OrcFileFormat extends HiveInspectors { val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip HiveShim.appendReadColumns(conf, sortedIDs, sortedNames) } + + /** + * Add a metadata specifying Spark version. + */ + def addSparkVersionMetadata(recordWriter: RecordWriter[NullWritable, Writable]): Unit = { + try { + val writerField = recordWriter.getClass.getDeclaredField("writer") + writerField.setAccessible(true) + val writer = writerField.get(recordWriter).asInstanceOf[Writer] + writer.addUserMetadata(SPARK_VERSION_METADATA_KEY, UTF_8.encode(SPARK_VERSION_SHORT)) + } catch { + case NonFatal(e) => log.warn(e.toString, e) + } + } } From 2d085c13b7f715dbff23dd1f81af45ff903d1a79 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 10 Nov 2018 09:52:14 -0600 Subject: [PATCH 2033/2461] [SPARK-25984][CORE][SQL][STREAMING] Remove deprecated .newInstance(), primitive box class constructor calls ## What changes were proposed in this pull request? Deprecated in Java 11, replace Class.newInstance with Class.getConstructor.getInstance, and primtive wrapper class constructors with valueOf or equivalent ## How was this patch tested? Existing tests. Closes #22988 from srowen/SPARK-25984. Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../scala/org/apache/spark/SparkEnv.scala | 2 +- .../spark/api/python/PythonHadoopUtil.scala | 3 +- .../scala/org/apache/spark/api/r/SerDe.scala | 6 ++-- .../org/apache/spark/deploy/SparkSubmit.scala | 2 +- .../io/HadoopMapReduceCommitProtocol.scala | 2 +- .../spark/internal/io/SparkHadoopWriter.scala | 4 +-- .../apache/spark/metrics/MetricsSystem.scala | 2 +- .../org/apache/spark/rdd/BinaryFileRDD.scala | 2 +- .../org/apache/spark/rdd/NewHadoopRDD.scala | 4 +-- .../apache/spark/rdd/WholeTextFileRDD.scala | 2 +- .../spark/serializer/KryoSerializer.scala | 3 +- .../apache/spark/storage/BlockManager.scala | 2 +- .../scala/org/apache/spark/util/Utils.scala | 3 +- .../scala/org/apache/spark/FileSuite.scala | 2 +- .../spark/rdd/PairRDDFunctionsSuite.scala | 14 ++++---- .../scheduler/TaskResultGetterSuite.scala | 2 +- .../spark/util/AccumulatorV2Suite.scala | 6 ++-- .../util/MutableURLClassLoaderSuite.scala | 15 ++++----- .../spark/util/SizeEstimatorSuite.scala | 32 +++++++++---------- .../spark/util/collection/SorterSuite.scala | 16 +++++----- .../unsafe/sort/RadixSortSuite.scala | 2 +- .../org/apache/spark/sql/avro/AvroSuite.scala | 2 +- .../sql/kafka010/KafkaSourceProvider.scala | 2 +- .../kafka010/DirectKafkaInputDStream.scala | 2 +- .../org/apache/spark/ml/util/ReadWrite.scala | 2 +- .../spark/ml/linalg/MatrixUDTSuite.scala | 2 +- .../spark/ml/linalg/VectorUDTSuite.scala | 2 +- .../spark/repl/ExecutorClassLoaderSuite.scala | 12 +++---- .../cluster/SchedulerExtensionService.scala | 2 +- .../sql/catalyst/JavaTypeInference.scala | 4 +-- .../spark/sql/catalyst/ScalaReflection.scala | 20 ++++++------ .../sql/catalyst/catalog/SessionCatalog.scala | 3 +- .../expressions/codegen/CodeGenerator.scala | 2 +- .../org/apache/spark/sql/types/DataType.scala | 2 +- .../encoders/ExpressionEncoderSuite.scala | 16 +++++----- .../expressions/ObjectExpressionsSuite.scala | 8 ++--- .../aggregate/PercentileSuite.scala | 4 +-- .../apache/spark/sql/DataFrameReader.scala | 2 +- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../org/apache/spark/sql/SparkSession.scala | 2 +- .../apache/spark/sql/UDFRegistration.scala | 4 +-- .../org/apache/spark/sql/api/r/SQLUtils.scala | 2 +- .../spark/sql/execution/command/tables.scala | 3 +- .../execution/datasources/DataSource.scala | 12 +++---- .../datasources/jdbc/DriverRegistry.scala | 2 +- .../streaming/state/StateStore.scala | 2 +- .../sql/streaming/DataStreamReader.scala | 3 +- .../sql/streaming/DataStreamWriter.scala | 2 +- .../org/apache/spark/sql/JavaRowSuite.java | 14 ++++---- .../apache/spark/sql/JavaStringLength.java | 2 +- .../sql/ApproximatePercentileQuerySuite.scala | 6 ++-- .../org/apache/spark/sql/DataFrameSuite.scala | 14 ++++---- .../org/apache/spark/sql/DatasetSuite.scala | 16 +++++----- .../scala/org/apache/spark/sql/UDFSuite.scala | 2 +- .../execution/columnar/ColumnStatsSuite.scala | 4 +-- .../sources/RateStreamProviderSuite.scala | 8 +++-- .../sources/TextSocketStreamSuite.scala | 2 +- .../sources/v2/DataSourceV2UtilsSuite.scala | 2 +- .../sources/StreamingDataSourceV2Suite.scala | 8 +++-- .../org/apache/hive/service/cli/Column.java | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- .../org/apache/spark/sql/hive/HiveShim.scala | 2 +- .../apache/spark/sql/hive/TableReader.scala | 6 ++-- .../spark/sql/hive/client/HiveShim.scala | 2 +- .../sql/hive/execution/HiveFileFormat.scala | 3 +- .../hive/execution/HiveTableScanExec.scala | 2 +- .../execution/ScriptTransformationExec.scala | 11 ++++--- .../sql/hive/HiveParquetMetastoreSuite.scala | 2 +- .../streaming/scheduler/JobGenerator.scala | 4 +-- .../spark/streaming/CheckpointSuite.scala | 3 +- 70 files changed, 190 insertions(+), 174 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 72123f2232532..66038eeaea54f 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -261,7 +261,7 @@ object SparkEnv extends Logging { // SparkConf, then one taking no arguments try { cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE) - .newInstance(conf, new java.lang.Boolean(isDriver)) + .newInstance(conf, java.lang.Boolean.valueOf(isDriver)) .asInstanceOf[T] } catch { case _: NoSuchMethodException => diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index 6259bead3ea88..2ab8add63efae 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -43,7 +43,8 @@ private[python] object Converter extends Logging { defaultConverter: Converter[Any, Any]): Converter[Any, Any] = { converterClass.map { cc => Try { - val c = Utils.classForName(cc).newInstance().asInstanceOf[Converter[Any, Any]] + val c = Utils.classForName(cc).getConstructor(). + newInstance().asInstanceOf[Converter[Any, Any]] logInfo(s"Loaded converter: $cc") c } match { diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 537ab57f9664d..6e0a3f63988d4 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -74,9 +74,9 @@ private[spark] object SerDe { jvmObjectTracker: JVMObjectTracker): Object = { dataType match { case 'n' => null - case 'i' => new java.lang.Integer(readInt(dis)) - case 'd' => new java.lang.Double(readDouble(dis)) - case 'b' => new java.lang.Boolean(readBoolean(dis)) + case 'i' => java.lang.Integer.valueOf(readInt(dis)) + case 'd' => java.lang.Double.valueOf(readDouble(dis)) + case 'b' => java.lang.Boolean.valueOf(readBoolean(dis)) case 'c' => readString(dis) case 'e' => readMap(dis, jvmObjectTracker) case 'r' => readBytes(dis) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 88df7324a354a..0fc8c9bd789e0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -829,7 +829,7 @@ private[spark] class SparkSubmit extends Logging { } val app: SparkApplication = if (classOf[SparkApplication].isAssignableFrom(mainClass)) { - mainClass.newInstance().asInstanceOf[SparkApplication] + mainClass.getConstructor().newInstance().asInstanceOf[SparkApplication] } else { // SPARK-4170 if (classOf[scala.App].isAssignableFrom(mainClass)) { diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 3e60c50ada59b..7477e03bfaa76 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -91,7 +91,7 @@ class HadoopMapReduceCommitProtocol( private def stagingDir = new Path(path, ".spark-staging-" + jobId) protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { - val format = context.getOutputFormatClass.newInstance() + val format = context.getOutputFormatClass.getConstructor().newInstance() // If OutputFormat is Configurable, we should set conf to it. format match { case c: Configurable => c.setConf(context.getConfiguration) diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala index 9ebd0aa301592..3a58ea816937b 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala @@ -256,7 +256,7 @@ class HadoopMapRedWriteConfigUtil[K, V: ClassTag](conf: SerializableJobConf) private def getOutputFormat(): OutputFormat[K, V] = { require(outputFormat != null, "Must call initOutputFormat first.") - outputFormat.newInstance() + outputFormat.getConstructor().newInstance() } // -------------------------------------------------------------------------- @@ -379,7 +379,7 @@ class HadoopMapReduceWriteConfigUtil[K, V: ClassTag](conf: SerializableConfigura private def getOutputFormat(): NewOutputFormat[K, V] = { require(outputFormat != null, "Must call initOutputFormat first.") - outputFormat.newInstance() + outputFormat.getConstructor().newInstance() } // -------------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 3457a2632277d..bb7b434e9a113 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -179,7 +179,7 @@ private[spark] class MetricsSystem private ( sourceConfigs.foreach { kv => val classPath = kv._2.getProperty("class") try { - val source = Utils.classForName(classPath).newInstance() + val source = Utils.classForName(classPath).getConstructor().newInstance() registerSource(source.asInstanceOf[Source]) } catch { case e: Exception => logError("Source class " + classPath + " cannot be instantiated", e) diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index a14bad47dfe10..039dbcbd5e035 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -41,7 +41,7 @@ private[spark] class BinaryFileRDD[T]( // traversing a large number of directories and files. Parallelize it. conf.setIfUnset(FileInputFormat.LIST_STATUS_NUM_THREADS, Runtime.getRuntime.availableProcessors().toString) - val inputFormat = inputFormatClass.newInstance + val inputFormat = inputFormatClass.getConstructor().newInstance() inputFormat match { case configurable: Configurable => configurable.setConf(conf) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 2d66d25ba39fa..483de28d92ab7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -120,7 +120,7 @@ class NewHadoopRDD[K, V]( } override def getPartitions: Array[Partition] = { - val inputFormat = inputFormatClass.newInstance + val inputFormat = inputFormatClass.getConstructor().newInstance() inputFormat match { case configurable: Configurable => configurable.setConf(_conf) @@ -183,7 +183,7 @@ class NewHadoopRDD[K, V]( } } - private val format = inputFormatClass.newInstance + private val format = inputFormatClass.getConstructor().newInstance() format match { case configurable: Configurable => configurable.setConf(conf) diff --git a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala index 9f3d0745c33c9..eada762b99c8e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala @@ -44,7 +44,7 @@ private[spark] class WholeTextFileRDD( // traversing a large number of directories and files. Parallelize it. conf.setIfUnset(FileInputFormat.LIST_STATUS_NUM_THREADS, Runtime.getRuntime.availableProcessors().toString) - val inputFormat = inputFormatClass.newInstance + val inputFormat = inputFormatClass.getConstructor().newInstance() inputFormat match { case configurable: Configurable => configurable.setConf(conf) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 72427dd6ce4d4..218c84352ce88 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -132,7 +132,8 @@ class KryoSerializer(conf: SparkConf) .foreach { className => kryo.register(Class.forName(className, true, classLoader)) } // Allow the user to register their own classes by setting spark.kryo.registrator. userRegistrators - .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]) + .map(Class.forName(_, true, classLoader).getConstructor(). + newInstance().asInstanceOf[KryoRegistrator]) .foreach { reg => reg.registerClasses(kryo) } // scalastyle:on classforname } catch { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e35dd72521247..edae2f95fce33 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -237,7 +237,7 @@ private[spark] class BlockManager( val priorityClass = conf.get( "spark.storage.replication.policy", classOf[RandomBlockReplicationPolicy].getName) val clazz = Utils.classForName(priorityClass) - val ret = clazz.newInstance.asInstanceOf[BlockReplicationPolicy] + val ret = clazz.getConstructor().newInstance().asInstanceOf[BlockReplicationPolicy] logInfo(s"Using $priorityClass for block replication policy") ret } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 93b5826f8a74b..a07eee6ad8a4b 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2430,7 +2430,8 @@ private[spark] object Utils extends Logging { "org.apache.spark.security.ShellBasedGroupsMappingProvider") if (groupProviderClassName != "") { try { - val groupMappingServiceProvider = classForName(groupProviderClassName).newInstance. + val groupMappingServiceProvider = classForName(groupProviderClassName). + getConstructor().newInstance(). asInstanceOf[org.apache.spark.security.GroupMappingServiceProvider] val currentUserGroups = groupMappingServiceProvider.getGroups(username) return currentUserGroups diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 34efcdf4bc886..df04a5ea1d99e 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -202,7 +202,7 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { sc = new SparkContext("local", "test") val objs = sc.makeRDD(1 to 3).map { x => val loader = Thread.currentThread().getContextClassLoader - Class.forName(className, true, loader).newInstance() + Class.forName(className, true, loader).getConstructor().newInstance() } val outputDir = new File(tempDir, "output").getAbsolutePath objs.saveAsObjectFile(outputDir) diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 47af5c3320dd9..0ec359d1c94f3 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -574,7 +574,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } test("saveNewAPIHadoopFile should call setConf if format is configurable") { - val pairs = sc.parallelize(Array((new Integer(1), new Integer(1)))) + val pairs = sc.parallelize(Array((Integer.valueOf(1), Integer.valueOf(1)))) // No error, non-configurable formats still work pairs.saveAsNewAPIHadoopFile[NewFakeFormat]("ignored") @@ -591,14 +591,14 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { test("The JobId on the driver and executors should be the same during the commit") { // Create more than one rdd to mimic stageId not equal to rddId val pairs = sc.parallelize(Array((1, 2), (2, 3)), 2) - .map { p => (new Integer(p._1 + 1), new Integer(p._2 + 1)) } + .map { p => (Integer.valueOf(p._1 + 1), Integer.valueOf(p._2 + 1)) } .filter { p => p._1 > 0 } pairs.saveAsNewAPIHadoopFile[YetAnotherFakeFormat]("ignored") assert(JobID.jobid != -1) } test("saveAsHadoopFile should respect configured output committers") { - val pairs = sc.parallelize(Array((new Integer(1), new Integer(1)))) + val pairs = sc.parallelize(Array((Integer.valueOf(1), Integer.valueOf(1)))) val conf = new JobConf() conf.setOutputCommitter(classOf[FakeOutputCommitter]) @@ -610,7 +610,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } test("failure callbacks should be called before calling writer.close() in saveNewAPIHadoopFile") { - val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) + val pairs = sc.parallelize(Array((Integer.valueOf(1), Integer.valueOf(2))), 1) FakeWriterWithCallback.calledBy = "" FakeWriterWithCallback.exception = null @@ -625,7 +625,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } test("failure callbacks should be called before calling writer.close() in saveAsHadoopFile") { - val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) + val pairs = sc.parallelize(Array((Integer.valueOf(1), Integer.valueOf(2))), 1) val conf = new JobConf() FakeWriterWithCallback.calledBy = "" @@ -643,7 +643,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { test("saveAsNewAPIHadoopDataset should support invalid output paths when " + "there are no files to be committed to an absolute output location") { - val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) + val pairs = sc.parallelize(Array((Integer.valueOf(1), Integer.valueOf(2))), 1) def saveRddWithPath(path: String): Unit = { val job = NewJob.getInstance(new Configuration(sc.hadoopConfiguration)) @@ -671,7 +671,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { // for non-null invalid paths. test("saveAsHadoopDataset should respect empty output directory when " + "there are no files to be committed to an absolute output location") { - val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) + val pairs = sc.parallelize(Array((Integer.valueOf(1), Integer.valueOf(2))), 1) val conf = new JobConf() conf.setOutputKeyClass(classOf[Integer]) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 1bddba8f6c82b..f8eb8bd71c170 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -194,7 +194,7 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local // jar. sc = new SparkContext("local", "test", conf) val rdd = sc.parallelize(Seq(1), 1).map { _ => - val exc = excClass.newInstance().asInstanceOf[Exception] + val exc = excClass.getConstructor().newInstance().asInstanceOf[Exception] throw exc } diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala index 621399af731f7..172bebbfec61d 100644 --- a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala +++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala @@ -40,7 +40,7 @@ class AccumulatorV2Suite extends SparkFunSuite { assert(acc.avg == 0.5) // Also test add using non-specialized add function - acc.add(new java.lang.Long(2)) + acc.add(java.lang.Long.valueOf(2)) assert(acc.count == 3) assert(acc.sum == 3) assert(acc.avg == 1.0) @@ -73,7 +73,7 @@ class AccumulatorV2Suite extends SparkFunSuite { assert(acc.avg == 0.5) // Also test add using non-specialized add function - acc.add(new java.lang.Double(2.0)) + acc.add(java.lang.Double.valueOf(2.0)) assert(acc.count == 3) assert(acc.sum == 3.0) assert(acc.avg == 1.0) @@ -96,7 +96,7 @@ class AccumulatorV2Suite extends SparkFunSuite { assert(acc.value.contains(0.0)) assert(!acc.isZero) - acc.add(new java.lang.Double(1.0)) + acc.add(java.lang.Double.valueOf(1.0)) val acc2 = acc.copyAndReset() assert(acc2.value.isEmpty) diff --git a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala index f6ac89fc2742a..8d844bd08771c 100644 --- a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala @@ -22,7 +22,6 @@ import java.net.URLClassLoader import scala.collection.JavaConverters._ import org.scalatest.Matchers -import org.scalatest.Matchers._ import org.apache.spark.{SparkContext, SparkException, SparkFunSuite, TestUtils} @@ -46,10 +45,10 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { test("child first") { val parentLoader = new URLClassLoader(urls2, null) val classLoader = new ChildFirstURLClassLoader(urls, parentLoader) - val fakeClass = classLoader.loadClass("FakeClass2").newInstance() + val fakeClass = classLoader.loadClass("FakeClass2").getConstructor().newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "1") - val fakeClass2 = classLoader.loadClass("FakeClass2").newInstance() + val fakeClass2 = classLoader.loadClass("FakeClass2").getConstructor().newInstance() assert(fakeClass.getClass === fakeClass2.getClass) classLoader.close() parentLoader.close() @@ -58,10 +57,10 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { test("parent first") { val parentLoader = new URLClassLoader(urls2, null) val classLoader = new MutableURLClassLoader(urls, parentLoader) - val fakeClass = classLoader.loadClass("FakeClass1").newInstance() + val fakeClass = classLoader.loadClass("FakeClass1").getConstructor().newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") - val fakeClass2 = classLoader.loadClass("FakeClass1").newInstance() + val fakeClass2 = classLoader.loadClass("FakeClass1").getConstructor().newInstance() assert(fakeClass.getClass === fakeClass2.getClass) classLoader.close() parentLoader.close() @@ -70,7 +69,7 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { test("child first can fall back") { val parentLoader = new URLClassLoader(urls2, null) val classLoader = new ChildFirstURLClassLoader(urls, parentLoader) - val fakeClass = classLoader.loadClass("FakeClass3").newInstance() + val fakeClass = classLoader.loadClass("FakeClass3").getConstructor().newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") classLoader.close() @@ -81,7 +80,7 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { val parentLoader = new URLClassLoader(urls2, null) val classLoader = new ChildFirstURLClassLoader(urls, parentLoader) intercept[java.lang.ClassNotFoundException] { - classLoader.loadClass("FakeClassDoesNotExist").newInstance() + classLoader.loadClass("FakeClassDoesNotExist").getConstructor().newInstance() } classLoader.close() parentLoader.close() @@ -137,7 +136,7 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { sc.makeRDD(1 to 5, 2).mapPartitions { x => val loader = Thread.currentThread().getContextClassLoader // scalastyle:off classforname - Class.forName(className, true, loader).newInstance() + Class.forName(className, true, loader).getConstructor().newInstance() // scalastyle:on classforname Seq().iterator }.count() diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 2695295d451d5..63f9f82adf3e0 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -92,14 +92,14 @@ class SizeEstimatorSuite } test("primitive wrapper objects") { - assertResult(16)(SizeEstimator.estimate(new java.lang.Boolean(true))) - assertResult(16)(SizeEstimator.estimate(new java.lang.Byte("1"))) - assertResult(16)(SizeEstimator.estimate(new java.lang.Character('1'))) - assertResult(16)(SizeEstimator.estimate(new java.lang.Short("1"))) - assertResult(16)(SizeEstimator.estimate(new java.lang.Integer(1))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Long(1))) - assertResult(16)(SizeEstimator.estimate(new java.lang.Float(1.0))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Double(1.0d))) + assertResult(16)(SizeEstimator.estimate(java.lang.Boolean.TRUE)) + assertResult(16)(SizeEstimator.estimate(java.lang.Byte.valueOf("1"))) + assertResult(16)(SizeEstimator.estimate(java.lang.Character.valueOf('1'))) + assertResult(16)(SizeEstimator.estimate(java.lang.Short.valueOf("1"))) + assertResult(16)(SizeEstimator.estimate(java.lang.Integer.valueOf(1))) + assertResult(24)(SizeEstimator.estimate(java.lang.Long.valueOf(1))) + assertResult(16)(SizeEstimator.estimate(java.lang.Float.valueOf(1.0f))) + assertResult(24)(SizeEstimator.estimate(java.lang.Double.valueOf(1.0))) } test("class field blocks rounding") { @@ -202,14 +202,14 @@ class SizeEstimatorSuite assertResult(72)(SizeEstimator.estimate(DummyString("abcdefgh"))) // primitive wrapper classes - assertResult(24)(SizeEstimator.estimate(new java.lang.Boolean(true))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Byte("1"))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Character('1'))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Short("1"))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Integer(1))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Long(1))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Float(1.0))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Double(1.0d))) + assertResult(24)(SizeEstimator.estimate(java.lang.Boolean.TRUE)) + assertResult(24)(SizeEstimator.estimate(java.lang.Byte.valueOf("1"))) + assertResult(24)(SizeEstimator.estimate(java.lang.Character.valueOf('1'))) + assertResult(24)(SizeEstimator.estimate(java.lang.Short.valueOf("1"))) + assertResult(24)(SizeEstimator.estimate(java.lang.Integer.valueOf(1))) + assertResult(24)(SizeEstimator.estimate(java.lang.Long.valueOf(1))) + assertResult(24)(SizeEstimator.estimate(java.lang.Float.valueOf(1.0f))) + assertResult(24)(SizeEstimator.estimate(java.lang.Double.valueOf(1.0))) } test("class field blocks rounding on 64-bit VM without useCompressedOops") { diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala index 65bf857e22c02..46a05e2ba798b 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.util.collection -import java.lang.{Float => JFloat, Integer => JInteger} +import java.lang.{Float => JFloat} import java.util.{Arrays, Comparator} import org.apache.spark.SparkFunSuite @@ -48,7 +48,7 @@ class SorterSuite extends SparkFunSuite with Logging { // alternate. Keys are random doubles, values are ordinals from 0 to length. val keys = Array.tabulate[Double](5000) { i => rand.nextDouble() } val keyValueArray = Array.tabulate[Number](10000) { i => - if (i % 2 == 0) keys(i / 2) else new Integer(i / 2) + if (i % 2 == 0) keys(i / 2) else Integer.valueOf(i / 2) } // Map from generated keys to values, to verify correctness later @@ -112,7 +112,7 @@ class SorterSuite extends SparkFunSuite with Logging { // Test our key-value pairs where each element is a Tuple2[Float, Integer]. val kvTuples = Array.tabulate(numElements) { i => - (new JFloat(rand.nextFloat()), new JInteger(i)) + (JFloat.valueOf(rand.nextFloat()), Integer.valueOf(i)) } val kvTupleArray = new Array[AnyRef](numElements) @@ -167,23 +167,23 @@ class SorterSuite extends SparkFunSuite with Logging { val ints = Array.fill(numElements)(rand.nextInt()) val intObjects = { - val data = new Array[JInteger](numElements) + val data = new Array[Integer](numElements) var i = 0 while (i < numElements) { - data(i) = new JInteger(ints(i)) + data(i) = Integer.valueOf(ints(i)) i += 1 } data } - val intObjectArray = new Array[JInteger](numElements) + val intObjectArray = new Array[Integer](numElements) val prepareIntObjectArray = () => { System.arraycopy(intObjects, 0, intObjectArray, 0, numElements) } runExperiment("Java Arrays.sort() on non-primitive int array")({ - Arrays.sort(intObjectArray, new Comparator[JInteger] { - override def compare(x: JInteger, y: JInteger): Int = x.compareTo(y) + Arrays.sort(intObjectArray, new Comparator[Integer] { + override def compare(x: Integer, y: Integer): Int = x.compareTo(y) }) }, prepareIntObjectArray) diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala index d5956ea32096a..d570630c1a095 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala @@ -78,7 +78,7 @@ class RadixSortSuite extends SparkFunSuite with Logging { private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = { val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand } val extended = ref ++ Array.fill[Long](Ints.checkedCast(size))(0) - (ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended))) + (ref.map(i => JLong.valueOf(i)), new LongArray(MemoryBlock.fromLongArray(extended))) } private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = { diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 4fea2cb969446..8d6cca8e48c3d 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -508,7 +508,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val union2 = spark.read.format("avro").load(testAvro).select("union_float_double").collect() assert( union2 - .map(x => new java.lang.Double(x(0).toString)) + .map(x => java.lang.Double.valueOf(x(0).toString)) .exists(p => Math.abs(p - Math.PI) < 0.001)) val fixed = spark.read.format("avro").load(testAvro).select("fixed3").collect() diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 28c9853bfea9c..5034bd73d6e74 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -510,7 +510,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") // So that the driver does not pull too much data - .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1)) + .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, java.lang.Integer.valueOf(1)) // If buffer config is not set, set it to reasonable value to work around // buffer issues (see KAFKA-3135) diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index ba4009ef08856..224f41a683955 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -70,7 +70,7 @@ private[spark] class DirectKafkaInputDStream[K, V]( @transient private var kc: Consumer[K, V] = null def consumer(): Consumer[K, V] = this.synchronized { if (null == kc) { - kc = consumerStrategy.onStart(currentOffsets.mapValues(l => new java.lang.Long(l)).asJava) + kc = consumerStrategy.onStart(currentOffsets.mapValues(l => java.lang.Long.valueOf(l)).asJava) } kc } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index a0ac26a34d8c8..d985f8ca1ecc7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -256,7 +256,7 @@ class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { s"Multiple writers found for $source+$stageName, try using the class name of the writer") } if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) { - val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat] + val writer = writerCls.getConstructor().newInstance().asInstanceOf[MLWriterFormat] writer.write(path, sparkSession, optionMap, stage) } else { throw new SparkException(s"ML source $source is not a valid MLWriterFormat") diff --git a/mllib/src/test/scala/org/apache/spark/ml/linalg/MatrixUDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/linalg/MatrixUDTSuite.scala index bdceba7887cac..8371c33a209dc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/linalg/MatrixUDTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/linalg/MatrixUDTSuite.scala @@ -31,7 +31,7 @@ class MatrixUDTSuite extends SparkFunSuite { val sm3 = dm3.toSparse for (m <- Seq(dm1, dm2, dm3, sm1, sm2, sm3)) { - val udt = UDTRegistration.getUDTFor(m.getClass.getName).get.newInstance() + val udt = UDTRegistration.getUDTFor(m.getClass.getName).get.getConstructor().newInstance() .asInstanceOf[MatrixUDT] assert(m === udt.deserialize(udt.serialize(m))) assert(udt.typeName == "matrix") diff --git a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala index 6ddb12cb76aac..67c64f762b25e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala @@ -31,7 +31,7 @@ class VectorUDTSuite extends SparkFunSuite { val sv2 = Vectors.sparse(2, Array(1), Array(2.0)) for (v <- Seq(dv1, dv2, sv1, sv2)) { - val udt = UDTRegistration.getUDTFor(v.getClass.getName).get.newInstance() + val udt = UDTRegistration.getUDTFor(v.getClass.getName).get.getConstructor().newInstance() .asInstanceOf[VectorUDT] assert(v === udt.deserialize(udt.serialize(v))) assert(udt.typeName == "vector") diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index e5e2094368fb0..ac528ecb829b0 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -126,7 +126,7 @@ class ExecutorClassLoaderSuite test("child first") { val parentLoader = new URLClassLoader(urls2, null) val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) - val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() + val fakeClass = classLoader.loadClass("ReplFakeClass2").getConstructor().newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "1") } @@ -134,7 +134,7 @@ class ExecutorClassLoaderSuite test("parent first") { val parentLoader = new URLClassLoader(urls2, null) val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, false) - val fakeClass = classLoader.loadClass("ReplFakeClass1").newInstance() + val fakeClass = classLoader.loadClass("ReplFakeClass1").getConstructor().newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") } @@ -142,7 +142,7 @@ class ExecutorClassLoaderSuite test("child first can fall back") { val parentLoader = new URLClassLoader(urls2, null) val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) - val fakeClass = classLoader.loadClass("ReplFakeClass3").newInstance() + val fakeClass = classLoader.loadClass("ReplFakeClass3").getConstructor().newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") } @@ -151,7 +151,7 @@ class ExecutorClassLoaderSuite val parentLoader = new URLClassLoader(urls2, null) val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) intercept[java.lang.ClassNotFoundException] { - classLoader.loadClass("ReplFakeClassDoesNotExist").newInstance() + classLoader.loadClass("ReplFakeClassDoesNotExist").getConstructor().newInstance() } } @@ -202,11 +202,11 @@ class ExecutorClassLoaderSuite val classLoader = new ExecutorClassLoader(new SparkConf(), env, "spark://localhost:1234", getClass().getClassLoader(), false) - val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() + val fakeClass = classLoader.loadClass("ReplFakeClass2").getConstructor().newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "1") intercept[java.lang.ClassNotFoundException] { - classLoader.loadClass("ReplFakeClassDoesNotExist").newInstance() + classLoader.loadClass("ReplFakeClassDoesNotExist").getConstructor().newInstance() } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala index 4ed285230ff81..7d15f0e2fbac8 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala @@ -107,7 +107,7 @@ private[spark] class SchedulerExtensionServices extends SchedulerExtensionServic services = sparkContext.conf.get(SCHEDULER_SERVICES).map { sClass => val instance = Utils.classForName(sClass) - .newInstance() + .getConstructor().newInstance() .asInstanceOf[SchedulerExtensionService] // bind this service instance.start(binding) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 8ef8b2be6939c..311060e5961cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -73,10 +73,10 @@ object JavaTypeInference { : (DataType, Boolean) = { typeToken.getRawType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => - (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) + (c.getAnnotation(classOf[SQLUserDefinedType]).udt().getConstructor().newInstance(), true) case c: Class[_] if UDTRegistration.exists(c.getName) => - val udt = UDTRegistration.getUDTFor(c.getName).get.newInstance() + val udt = UDTRegistration.getUDTFor(c.getName).get.getConstructor().newInstance() .asInstanceOf[UserDefinedType[_ >: Null]] (udt, true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 912744eab6a3a..64ea236532839 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -357,7 +357,8 @@ object ScalaReflection extends ScalaReflection { ) case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => - val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt(). + getConstructor().newInstance() val obj = NewInstance( udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, @@ -365,8 +366,8 @@ object ScalaReflection extends ScalaReflection { Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil) case t if UDTRegistration.exists(getClassNameFromType(t)) => - val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() - .asInstanceOf[UserDefinedType[_]] + val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor(). + newInstance().asInstanceOf[UserDefinedType[_]] val obj = NewInstance( udt.getClass, Nil, @@ -601,7 +602,7 @@ object ScalaReflection extends ScalaReflection { case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t) - .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + .getAnnotation(classOf[SQLUserDefinedType]).udt().getConstructor().newInstance() val obj = NewInstance( udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, @@ -609,8 +610,8 @@ object ScalaReflection extends ScalaReflection { Invoke(obj, "serialize", udt, inputObject :: Nil) case t if UDTRegistration.exists(getClassNameFromType(t)) => - val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() - .asInstanceOf[UserDefinedType[_]] + val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor(). + newInstance().asInstanceOf[UserDefinedType[_]] val obj = NewInstance( udt.getClass, Nil, @@ -721,11 +722,12 @@ object ScalaReflection extends ScalaReflection { // Null type would wrongly match the first of them, which is Option as of now case t if t <:< definitions.NullTpe => Schema(NullType, nullable = true) case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => - val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt(). + getConstructor().newInstance() Schema(udt, nullable = true) case t if UDTRegistration.exists(getClassNameFromType(t)) => - val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() - .asInstanceOf[UserDefinedType[_]] + val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor(). + newInstance().asInstanceOf[UserDefinedType[_]] Schema(udt, nullable = true) case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index c11b444212946..b6771ec4dffe9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1134,7 +1134,8 @@ class SessionCatalog( if (clsForUDAF.isAssignableFrom(clazz)) { val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaUDAF") val e = cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int]) - .newInstance(input, clazz.newInstance().asInstanceOf[Object], Int.box(1), Int.box(1)) + .newInstance(input, + clazz.getConstructor().newInstance().asInstanceOf[Object], Int.box(1), Int.box(1)) .asInstanceOf[ImplicitCastInputTypes] // Check input argument size diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index b868a0f4fa284..7c8f7cd4315b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1305,7 +1305,7 @@ object CodeGenerator extends Logging { throw new CompileException(msg, e.getLocation) } - (evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass], maxCodeSize) + (evaluator.getClazz().getConstructor().newInstance().asInstanceOf[GeneratedClass], maxCodeSize) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index e53628d11ccf3..33fc4b9480126 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -180,7 +180,7 @@ object DataType { ("pyClass", _), ("sqlType", _), ("type", JString("udt"))) => - Utils.classForName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] + Utils.classForName(udtClass).getConstructor().newInstance().asInstanceOf[UserDefinedType[_]] // Python UDT case JSortedObject( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index e9b100b3b30db..be8fd90c4c52a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -128,13 +128,13 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes encodeDecodeTest(-3.7f, "primitive float") encodeDecodeTest(-3.7, "primitive double") - encodeDecodeTest(new java.lang.Boolean(false), "boxed boolean") - encodeDecodeTest(new java.lang.Byte(-3.toByte), "boxed byte") - encodeDecodeTest(new java.lang.Short(-3.toShort), "boxed short") - encodeDecodeTest(new java.lang.Integer(-3), "boxed int") - encodeDecodeTest(new java.lang.Long(-3L), "boxed long") - encodeDecodeTest(new java.lang.Float(-3.7f), "boxed float") - encodeDecodeTest(new java.lang.Double(-3.7), "boxed double") + encodeDecodeTest(java.lang.Boolean.FALSE, "boxed boolean") + encodeDecodeTest(java.lang.Byte.valueOf(-3: Byte), "boxed byte") + encodeDecodeTest(java.lang.Short.valueOf(-3: Short), "boxed short") + encodeDecodeTest(java.lang.Integer.valueOf(-3), "boxed int") + encodeDecodeTest(java.lang.Long.valueOf(-3L), "boxed long") + encodeDecodeTest(java.lang.Float.valueOf(-3.7f), "boxed float") + encodeDecodeTest(java.lang.Double.valueOf(-3.7), "boxed double") encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal") encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal") @@ -224,7 +224,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes productTest( RepeatedData( Seq(1, 2), - Seq(new Integer(1), null, new Integer(2)), + Seq(Integer.valueOf(1), null, Integer.valueOf(2)), Map(1 -> 2L), Map(1 -> null), PrimitiveData(1, 1, 1, 1, 1, 1, true))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index d145fd0aaba47..16842c1bcc8cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -307,7 +307,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val conf = new SparkConf() Seq(true, false).foreach { useKryo => val serializer = if (useKryo) new KryoSerializer(conf) else new JavaSerializer(conf) - val expected = serializer.newInstance().serialize(new Integer(1)).array() + val expected = serializer.newInstance().serialize(Integer.valueOf(1)).array() val encodeUsingSerializer = EncodeUsingSerializer(inputObject, useKryo) checkEvaluation(encodeUsingSerializer, expected, InternalRow.fromSeq(Seq(1))) checkEvaluation(encodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) @@ -384,9 +384,9 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val conf = new SparkConf() Seq(true, false).foreach { useKryo => val serializer = if (useKryo) new KryoSerializer(conf) else new JavaSerializer(conf) - val input = serializer.newInstance().serialize(new Integer(1)).array() + val input = serializer.newInstance().serialize(Integer.valueOf(1)).array() val decodeUsingSerializer = DecodeUsingSerializer(inputObject, ClassTag(cls), useKryo) - checkEvaluation(decodeUsingSerializer, new Integer(1), InternalRow.fromSeq(Seq(input))) + checkEvaluation(decodeUsingSerializer, Integer.valueOf(1), InternalRow.fromSeq(Seq(input))) checkEvaluation(decodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) } } @@ -575,7 +575,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // NULL key test val scalaMapHasNullKey = scala.collection.Map[java.lang.Integer, String]( - null.asInstanceOf[java.lang.Integer] -> "v0", new java.lang.Integer(1) -> "v1") + null.asInstanceOf[java.lang.Integer] -> "v0", java.lang.Integer.valueOf(1) -> "v1") val javaMapHasNullKey = new java.util.HashMap[java.lang.Integer, java.lang.String]() { { put(null, "v0") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala index 294fce8e9a10f..63c7b42978025 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala @@ -41,9 +41,9 @@ class PercentileSuite extends SparkFunSuite { val buffer = new OpenHashMap[AnyRef, Long]() assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer)) - // Check non-empty buffer serializa and deserialize. + // Check non-empty buffer serialize and deserialize. data.foreach { key => - buffer.changeValue(new Integer(key), 1L, _ + 1L) + buffer.changeValue(Integer.valueOf(key), 1L, _ + 1L) } assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 02ffc940184db..df18623e42a02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -194,7 +194,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance().asInstanceOf[DataSourceV2] + val ds = cls.getConstructor().newInstance().asInstanceOf[DataSourceV2] if (ds.isInstanceOf[BatchReadSupportProvider]) { val sessionOptions = DataSourceV2Utils.extractSessionConfigs( ds = ds, conf = sparkSession.sessionState.conf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 5a28870f5d3c2..1b4998f94b25d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -243,7 +243,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val source = cls.newInstance().asInstanceOf[DataSourceV2] + val source = cls.getConstructor().newInstance().asInstanceOf[DataSourceV2] source match { case provider: BatchWriteSupportProvider => val sessionOptions = DataSourceV2Utils.extractSessionConfigs( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 71f967a59d77e..c0727e844a1ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -1144,7 +1144,7 @@ object SparkSession extends Logging { val extensionConfClassName = extensionOption.get try { val extensionConfClass = Utils.classForName(extensionConfClassName) - val extensionConf = extensionConfClass.newInstance() + val extensionConf = extensionConfClass.getConstructor().newInstance() .asInstanceOf[SparkSessionExtensions => Unit] extensionConf(extensions) } catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index aa3a6c3bf122f..84da097be53c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -670,7 +670,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends throw new AnalysisException(s"It is invalid to implement multiple UDF interfaces, UDF class $className") } else { try { - val udf = clazz.newInstance() + val udf = clazz.getConstructor().newInstance() val udfReturnType = udfInterfaces(0).getActualTypeArguments.last var returnType = returnDataType if (returnType == null) { @@ -727,7 +727,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends if (!classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) { throw new AnalysisException(s"class $className doesn't implement interface UserDefinedAggregateFunction") } - val udaf = clazz.newInstance().asInstanceOf[UserDefinedAggregateFunction] + val udaf = clazz.getConstructor().newInstance().asInstanceOf[UserDefinedAggregateFunction] register(name, udaf) } catch { case e: ClassNotFoundException => throw new AnalysisException(s"Can not load class ${className}, please make sure it is on the classpath") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index af20764f9a968..becb05cf72aba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -111,7 +111,7 @@ private[sql] object SQLUtils extends Logging { private[this] def doConversion(data: Object, dataType: DataType): Object = { data match { case d: java.lang.Double if dataType == FloatType => - new java.lang.Float(d) + java.lang.Float.valueOf(d.toFloat) // Scala Map is the only allowed external type of map type in Row. case m: java.util.Map[_, _] => m.asScala case _ => data diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 823dc0d5ed387..e2cd40906f401 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -231,7 +231,8 @@ case class AlterTableAddColumnsCommand( } if (DDLUtils.isDatasourceTable(catalogTable)) { - DataSource.lookupDataSource(catalogTable.provider.get, conf).newInstance() match { + DataSource.lookupDataSource(catalogTable.provider.get, conf). + getConstructor().newInstance() match { // For datasource table, this command can only support the following File format. // TextFileFormat only default to one column "value" // Hive type is already considered as hive serde table, so the logic will not diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index ce3bc3dd48327..795a6d0b6b040 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -204,7 +204,7 @@ case class DataSource( /** Returns the name and schema of the source that can be used to continually read data. */ private def sourceSchema(): SourceInfo = { - providingClass.newInstance() match { + providingClass.getConstructor().newInstance() match { case s: StreamSourceProvider => val (name, schema) = s.sourceSchema( sparkSession.sqlContext, userSpecifiedSchema, className, caseInsensitiveOptions) @@ -250,7 +250,7 @@ case class DataSource( /** Returns a source that can be used to continually read data. */ def createSource(metadataPath: String): Source = { - providingClass.newInstance() match { + providingClass.getConstructor().newInstance() match { case s: StreamSourceProvider => s.createSource( sparkSession.sqlContext, @@ -279,7 +279,7 @@ case class DataSource( /** Returns a sink that can be used to continually write data. */ def createSink(outputMode: OutputMode): Sink = { - providingClass.newInstance() match { + providingClass.getConstructor().newInstance() match { case s: StreamSinkProvider => s.createSink(sparkSession.sqlContext, caseInsensitiveOptions, partitionColumns, outputMode) @@ -310,7 +310,7 @@ case class DataSource( * that files already exist, we don't need to check them again. */ def resolveRelation(checkFilesExist: Boolean = true): BaseRelation = { - val relation = (providingClass.newInstance(), userSpecifiedSchema) match { + val relation = (providingClass.getConstructor().newInstance(), userSpecifiedSchema) match { // TODO: Throw when too much is given. case (dataSource: SchemaRelationProvider, Some(schema)) => dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions, schema) @@ -479,7 +479,7 @@ case class DataSource( throw new AnalysisException("Cannot save interval data type into external storage.") } - providingClass.newInstance() match { + providingClass.getConstructor().newInstance() match { case dataSource: CreatableRelationProvider => dataSource.createRelation( sparkSession.sqlContext, mode, caseInsensitiveOptions, Dataset.ofRows(sparkSession, data)) @@ -516,7 +516,7 @@ case class DataSource( throw new AnalysisException("Cannot save interval data type into external storage.") } - providingClass.newInstance() match { + providingClass.getConstructor().newInstance() match { case dataSource: CreatableRelationProvider => SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode) case format: FileFormat => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala index 1723596de1db2..530d836d9fde3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala @@ -50,7 +50,7 @@ object DriverRegistry extends Logging { } else { synchronized { if (wrapperMap.get(className).isEmpty) { - val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver]) + val wrapper = new DriverWrapper(cls.getConstructor().newInstance().asInstanceOf[Driver]) DriverManager.registerDriver(wrapper) wrapperMap(className) = wrapper logTrace(s"Wrapper for $className registered") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index d3313b8a315c9..7d785aa09cd9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -213,7 +213,7 @@ object StateStoreProvider { */ def create(providerClassName: String): StateStoreProvider = { val providerClass = Utils.classForName(providerClassName) - providerClass.newInstance().asInstanceOf[StateStoreProvider] + providerClass.getConstructor().newInstance().asInstanceOf[StateStoreProvider] } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 20c84305776ae..bf6021e692382 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -158,7 +158,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo "read files of Hive data source directly.") } - val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf).newInstance() + val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf). + getConstructor().newInstance() // We need to generate the V1 data source so we can pass it to the V2 relation as a shim. // We can't be sure at this point whether we'll actually want to use V2, since we don't know the // writer or whether the query is continuous. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 4a8c7fdb58ff1..b36a8f3f6f15b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -307,7 +307,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",") var options = extraOptions.toMap - val sink = ds.newInstance() match { + val sink = ds.getConstructor().newInstance() match { case w: StreamingWriteSupportProvider if !disabledSources.contains(w.getClass.getCanonicalName) => val sessionOptions = DataSourceV2Utils.extractSessionConfigs( diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java index 3ab4db2a035d3..ca78d6489ef5c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java @@ -67,20 +67,20 @@ public void setUp() { public void constructSimpleRow() { Row simpleRow = RowFactory.create( byteValue, // ByteType - new Byte(byteValue), + Byte.valueOf(byteValue), shortValue, // ShortType - new Short(shortValue), + Short.valueOf(shortValue), intValue, // IntegerType - new Integer(intValue), + Integer.valueOf(intValue), longValue, // LongType - new Long(longValue), + Long.valueOf(longValue), floatValue, // FloatType - new Float(floatValue), + Float.valueOf(floatValue), doubleValue, // DoubleType - new Double(doubleValue), + Double.valueOf(doubleValue), decimalValue, // DecimalType booleanValue, // BooleanType - new Boolean(booleanValue), + Boolean.valueOf(booleanValue), stringValue, // StringType binaryValue, // BinaryType dateValue, // DateType diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java index b90224f2ae397..5955eabe496df 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java @@ -25,6 +25,6 @@ public class JavaStringLength implements UDF1 { @Override public Integer call(String str) throws Exception { - return new Integer(str.length()); + return Integer.valueOf(str.length()); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala index d635912cf7205..52708f5fe4108 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala @@ -208,7 +208,7 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { test("percentile_approx(col, ...), input rows contains null, with out group by") { withTempView(table) { - (1 to 1000).map(new Integer(_)).flatMap(Seq(null: Integer, _)).toDF("col") + (1 to 1000).map(Integer.valueOf(_)).flatMap(Seq(null: Integer, _)).toDF("col") .createOrReplaceTempView(table) checkAnswer( spark.sql( @@ -226,8 +226,8 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { withTempView(table) { val rand = new java.util.Random() (1 to 1000) - .map(new Integer(_)) - .map(v => (new Integer(v % 2), v)) + .map(Integer.valueOf(_)) + .map(v => (Integer.valueOf(v % 2), v)) // Add some nulls .flatMap(Seq(_, (null: Integer, null: Integer))) .toDF("key", "value").createOrReplaceTempView(table) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index edde9bfd088cf..2bb18f48e0ae2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1986,7 +1986,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-11725: correctly handle null inputs for ScalaUDF") { val df = sparkContext.parallelize(Seq( - new java.lang.Integer(22) -> "John", + java.lang.Integer.valueOf(22) -> "John", null.asInstanceOf[java.lang.Integer] -> "Lucy")).toDF("age", "name") // passing null into the UDF that could handle it @@ -2219,9 +2219,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-17957: no change on nullability in FilterExec output") { val df = sparkContext.parallelize(Seq( - null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3), - new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer], - new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF() + null.asInstanceOf[java.lang.Integer] -> java.lang.Integer.valueOf(3), + java.lang.Integer.valueOf(1) -> null.asInstanceOf[java.lang.Integer], + java.lang.Integer.valueOf(2) -> java.lang.Integer.valueOf(4))).toDF() verifyNullabilityInFilterExec(df, expr = "Rand()", expectedNonNullableColumns = Seq.empty[String]) @@ -2236,9 +2236,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-17957: set nullability to false in FilterExec output") { val df = sparkContext.parallelize(Seq( - null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3), - new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer], - new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF() + null.asInstanceOf[java.lang.Integer] -> java.lang.Integer.valueOf(3), + java.lang.Integer.valueOf(1) -> null.asInstanceOf[java.lang.Integer], + java.lang.Integer.valueOf(2) -> java.lang.Integer.valueOf(4))).toDF() verifyNullabilityInFilterExec(df, expr = "_1 + _2 * 3", expectedNonNullableColumns = Seq("_1", "_2")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 82d3b22a48670..75d06510376ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -697,15 +697,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("SPARK-11894: Incorrect results are returned when using null") { val nullInt = null.asInstanceOf[java.lang.Integer] - val ds1 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() - val ds2 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() + val ds1 = Seq((nullInt, "1"), (java.lang.Integer.valueOf(22), "2")).toDS() + val ds2 = Seq((nullInt, "1"), (java.lang.Integer.valueOf(22), "2")).toDS() checkDataset( ds1.joinWith(ds2, lit(true), "cross"), ((nullInt, "1"), (nullInt, "1")), - ((nullInt, "1"), (new java.lang.Integer(22), "2")), - ((new java.lang.Integer(22), "2"), (nullInt, "1")), - ((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2"))) + ((nullInt, "1"), (java.lang.Integer.valueOf(22), "2")), + ((java.lang.Integer.valueOf(22), "2"), (nullInt, "1")), + ((java.lang.Integer.valueOf(22), "2"), (java.lang.Integer.valueOf(22), "2"))) } test("change encoder with compatible schema") { @@ -881,7 +881,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.rdd.map(r => r.id).count === 2) assert(ds2.rdd.map(r => r.id).count === 2) - val ds3 = ds.map(g => new java.lang.Long(g.id)) + val ds3 = ds.map(g => java.lang.Long.valueOf(g.id)) assert(ds3.rdd.map(r => r).count === 2) } @@ -1499,7 +1499,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(e.getCause.isInstanceOf[NullPointerException]) withTempPath { path => - Seq(new Integer(1), null).toDF("i").write.parquet(path.getCanonicalPath) + Seq(Integer.valueOf(1), null).toDF("i").write.parquet(path.getCanonicalPath) // If the primitive values are from files, we need to do runtime null check. val ds = spark.read.parquet(path.getCanonicalPath).as[Int] intercept[NullPointerException](ds.collect()) @@ -1553,7 +1553,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val df = Seq("Amsterdam", "San Francisco", "X").toDF("city") checkAnswer(df.where('city === 'X'), Seq(Row("X"))) checkAnswer( - df.where($"city".contains(new java.lang.Character('A'))), + df.where($"city".contains(java.lang.Character.valueOf('A'))), Seq(Row("Amsterdam"))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 3b301a4f8144a..20dcefa7e3cad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -413,7 +413,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { test("SPARK-25044 Verify null input handling for primitive types - with udf.register") { withTable("t") { - Seq((null, new Integer(1), "x"), ("M", null, "y"), ("N", new Integer(3), null)) + Seq((null, Integer.valueOf(1), "x"), ("M", null, "y"), ("N", Integer.valueOf(3), null)) .toDF("a", "b", "c").write.format("json").saveAsTable("t") spark.udf.register("f", (a: String, b: Int, c: Any) => a + b + c) val df = spark.sql("SELECT f(a, b, c) FROM t") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala index d4e7e362c6c8c..3121b7e99c99d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala @@ -39,7 +39,7 @@ class ColumnStatsSuite extends SparkFunSuite { val columnStatsName = columnStatsClass.getSimpleName test(s"$columnStatsName: empty") { - val columnStats = columnStatsClass.newInstance() + val columnStats = columnStatsClass.getConstructor().newInstance() columnStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) => assert(actual === expected) } @@ -48,7 +48,7 @@ class ColumnStatsSuite extends SparkFunSuite { test(s"$columnStatsName: non-empty") { import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ - val columnStats = columnStatsClass.newInstance() + val columnStats = columnStatsClass.getConstructor().newInstance() val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index dd74af873c2e5..be3efed714030 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -53,7 +53,8 @@ class RateSourceSuite extends StreamTest { test("microbatch in registry") { withTempDir { temp => - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + DataSource.lookupDataSource("rate", spark.sqlContext.conf). + getConstructor().newInstance() match { case ds: MicroBatchReadSupportProvider => val readSupport = ds.createMicroBatchReadSupport( temp.getCanonicalPath, DataSourceOptions.empty()) @@ -66,7 +67,7 @@ class RateSourceSuite extends StreamTest { test("compatible with old path in registry") { DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", - spark.sqlContext.conf).newInstance() match { + spark.sqlContext.conf).getConstructor().newInstance() match { case ds: MicroBatchReadSupportProvider => assert(ds.isInstanceOf[RateStreamProvider]) case _ => @@ -320,7 +321,8 @@ class RateSourceSuite extends StreamTest { } test("continuous in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + DataSource.lookupDataSource("rate", spark.sqlContext.conf). + getConstructor().newInstance() match { case ds: ContinuousReadSupportProvider => val readSupport = ds.createContinuousReadSupport( "", DataSourceOptions.empty()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index 409156e5ebc70..635ea6fca649c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -84,7 +84,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before test("backward compatibility with old path") { DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider", - spark.sqlContext.conf).newInstance() match { + spark.sqlContext.conf).getConstructor().newInstance() match { case ds: MicroBatchReadSupportProvider => assert(ds.isInstanceOf[TextSocketSourceProvider]) case _ => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala index 4911e3225552d..f903c17923d0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala @@ -33,7 +33,7 @@ class DataSourceV2UtilsSuite extends SparkFunSuite { conf.setConfString(s"spark.sql.$keyPrefix.config.name", "false") conf.setConfString("spark.datasource.another.config.name", "123") conf.setConfString(s"spark.datasource.$keyPrefix.", "123") - val cs = classOf[DataSourceV2WithSessionConfig].newInstance() + val cs = classOf[DataSourceV2WithSessionConfig].getConstructor().newInstance() val confs = DataSourceV2Utils.extractSessionConfigs(cs.asInstanceOf[DataSourceV2], conf) assert(confs.size == 2) assert(confs.keySet.filter(_.startsWith("spark.datasource")).size == 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 3a0e780a73915..31fce46c2daba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -261,7 +261,7 @@ class StreamingDataSourceV2Suite extends StreamTest { ).foreach { case (source, trigger) => test(s"SPARK-25460: session options are respected in structured streaming sources - $source") { // `keyPrefix` and `shortName` are the same in this test case - val readSource = source.newInstance().shortName() + val readSource = source.getConstructor().newInstance().shortName() val writeSource = "fake-write-microbatch-continuous" val readOptionName = "optionA" @@ -299,8 +299,10 @@ class StreamingDataSourceV2Suite extends StreamTest { for ((read, write, trigger) <- cases) { testQuietly(s"stream with read format $read, write format $write, trigger $trigger") { - val readSource = DataSource.lookupDataSource(read, spark.sqlContext.conf).newInstance() - val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf).newInstance() + val readSource = DataSource.lookupDataSource(read, spark.sqlContext.conf). + getConstructor().newInstance() + val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf). + getConstructor().newInstance() (readSource, writeSource, trigger) match { // Valid microbatch queries. case (_: MicroBatchReadSupportProvider, _: StreamingWriteSupportProvider, t) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java index adb269aa235ea..26d0f718f383a 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java @@ -349,7 +349,7 @@ public void addValue(Type type, Object field) { break; case FLOAT_TYPE: nulls.set(size, field == null); - doubleVars()[size] = field == null ? 0 : new Double(field.toString()); + doubleVars()[size] = field == null ? 0 : Double.valueOf(field.toString()); break; case DOUBLE_TYPE: nulls.set(size, field == null); diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index d047953327958..5823548a8063c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -124,7 +124,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val lazyPruningEnabled = sparkSession.sqlContext.conf.manageFilesourcePartitions val tablePath = new Path(relation.tableMeta.location) - val fileFormat = fileFormatClass.newInstance() + val fileFormat = fileFormatClass.getConstructor().newInstance() val result = if (relation.isPartitioned) { val partitionSchema = relation.tableMeta.partitionSchema diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index 11afe1af32809..c9fc3d4a02c4b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -217,7 +217,7 @@ private[hive] object HiveShim { instance.asInstanceOf[UDFType] } else { val func = Utils.getContextOrSparkClassLoader - .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] + .loadClass(functionClassName).getConstructor().newInstance().asInstanceOf[UDFType] if (!func.isInstanceOf[UDF]) { // We cache the function if it's no the Simple UDF, // as we always have to create new instance for Simple UDF diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 9443fbb4330a5..536bc4a3f4ec4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -132,7 +132,7 @@ class HadoopTableReader( val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter => val hconf = broadcastedHadoopConf.value.value - val deserializer = deserializerClass.newInstance() + val deserializer = deserializerClass.getConstructor().newInstance() deserializer.initialize(hconf, localTableDesc.getProperties) HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow, deserializer) } @@ -245,7 +245,7 @@ class HadoopTableReader( val localTableDesc = tableDesc createHadoopRdd(localTableDesc, inputPathStr, ifc).mapPartitions { iter => val hconf = broadcastedHiveConf.value.value - val deserializer = localDeserializer.newInstance() + val deserializer = localDeserializer.getConstructor().newInstance() // SPARK-13709: For SerDes like AvroSerDe, some essential information (e.g. Avro schema // information) may be defined in table properties. Here we should merge table properties // and partition properties before initializing the deserializer. Note that partition @@ -257,7 +257,7 @@ class HadoopTableReader( } deserializer.initialize(hconf, props) // get the table deserializer - val tableSerDe = localTableDesc.getDeserializerClass.newInstance() + val tableSerDe = localTableDesc.getDeserializerClass.getConstructor().newInstance() tableSerDe.initialize(hconf, localTableDesc.getProperties) // fill the non partition key attributes diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index bc9d4cd7f4181..4d484904d2c27 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -987,7 +987,7 @@ private[client] class Shim_v1_2 extends Shim_v1_1 { part: JList[String], deleteData: Boolean, purge: Boolean): Unit = { - val dropOptions = dropOptionsClass.newInstance().asInstanceOf[Object] + val dropOptions = dropOptionsClass.getConstructor().newInstance().asInstanceOf[Object] dropOptionsDeleteData.setBoolean(dropOptions, deleteData) dropOptionsPurge.setBoolean(dropOptions, purge) dropPartitionMethod.invoke(hive, dbName, tableName, part, dropOptions) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala index 4a7cd6901923b..d8d2a80e0e8b7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala @@ -115,7 +115,8 @@ class HiveOutputWriter( private def tableDesc = fileSinkConf.getTableInfo private val serializer = { - val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] + val serializer = tableDesc.getDeserializerClass.getConstructor(). + newInstance().asInstanceOf[Serializer] serializer.initialize(jobConf, tableDesc.getProperties) serializer } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 92c6632ad7863..fa940fe73bd13 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -120,7 +120,7 @@ case class HiveTableScanExec( HiveShim.appendReadColumns(hiveConf, neededColumnIDs, output.map(_.name)) - val deserializer = tableDesc.getDeserializerClass.newInstance + val deserializer = tableDesc.getDeserializerClass.getConstructor().newInstance() deserializer.initialize(hiveConf, tableDesc.getProperties) // Specifies types and object inspectors of columns to be scanned. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala index 3328400b214fb..7b35a5f920ae9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala @@ -123,7 +123,7 @@ case class ScriptTransformationExec( var scriptOutputWritable: Writable = null val reusedWritableObject: Writable = if (null != outputSerde) { - outputSerde.getSerializedClass().newInstance + outputSerde.getSerializedClass().getConstructor().newInstance() } else { null } @@ -404,7 +404,8 @@ case class HiveScriptIOSchema ( columnTypes: Seq[DataType], serdeProps: Seq[(String, String)]): AbstractSerDe = { - val serde = Utils.classForName(serdeClassName).newInstance.asInstanceOf[AbstractSerDe] + val serde = Utils.classForName(serdeClassName).getConstructor(). + newInstance().asInstanceOf[AbstractSerDe] val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",") @@ -424,7 +425,8 @@ case class HiveScriptIOSchema ( inputStream: InputStream, conf: Configuration): Option[RecordReader] = { recordReaderClass.map { klass => - val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordReader] + val instance = Utils.classForName(klass).getConstructor(). + newInstance().asInstanceOf[RecordReader] val props = new Properties() // Can not use props.putAll(outputSerdeProps.toMap.asJava) in scala-2.12 // See https://github.com/scala/bug/issues/10418 @@ -436,7 +438,8 @@ case class HiveScriptIOSchema ( def recordWriter(outputStream: OutputStream, conf: Configuration): Option[RecordWriter] = { recordWriterClass.map { klass => - val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordWriter] + val instance = Utils.classForName(klass).getConstructor(). + newInstance().asInstanceOf[RecordWriter] instance.initialize(outputStream, conf) instance } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetMetastoreSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetMetastoreSuite.scala index 0d4f040156084..68a0c1213ec20 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetMetastoreSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetMetastoreSuite.scala @@ -152,7 +152,7 @@ class HiveParquetMetastoreSuite extends ParquetPartitioningTest { } (1 to 10).map(i => (i, s"str$i")).toDF("a", "b").createOrReplaceTempView("jt") - (1 to 10).map(i => Tuple1(Seq(new Integer(i), null))).toDF("a") + (1 to 10).map(i => Tuple1(Seq(Integer.valueOf(i), null))).toDF("a") .createOrReplaceTempView("jt_array") assert(spark.sqlContext.getConf(HiveUtils.CONVERT_METASTORE_PARQUET.key) == "true") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 8d83dc8a8fc04..6f0b46b6a4cb3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -49,11 +49,11 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val clockClass = ssc.sc.conf.get( "spark.streaming.clock", "org.apache.spark.util.SystemClock") try { - Utils.classForName(clockClass).newInstance().asInstanceOf[Clock] + Utils.classForName(clockClass).getConstructor().newInstance().asInstanceOf[Clock] } catch { case e: ClassNotFoundException if clockClass.startsWith("org.apache.spark.streaming") => val newClockClass = clockClass.replace("org.apache.spark.streaming", "org.apache.spark") - Utils.classForName(newClockClass).newInstance().asInstanceOf[Clock] + Utils.classForName(newClockClass).getConstructor().newInstance().asInstanceOf[Clock] } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 19b621f11759d..2332ee2ab9de1 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -808,7 +808,8 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester // visible to mutableURLClassLoader val loader = new MutableURLClassLoader( Array(jar), appClassLoader) - assert(loader.loadClass("testClz").newInstance().toString == "testStringValue") + assert(loader.loadClass("testClz").getConstructor().newInstance().toString === + "testStringValue") // create and serialize Array[testClz] // scalastyle:off classforname From 6cd23482d1ae8c6a9fe9817ed51ee2a039d46649 Mon Sep 17 00:00:00 2001 From: Patrick Brown Date: Sat, 10 Nov 2018 12:51:24 -0600 Subject: [PATCH 2034/2461] [SPARK-25839][CORE] Implement use of KryoPool in KryoSerializer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? * Implement (optional) use of KryoPool in KryoSerializer, an alternative to the existing implementation of caching a Kryo instance inside KryoSerializerInstance * Add config key & documentation of spark.kryo.pool in order to turn this on * Add benchmark KryoSerializerBenchmark to compare new and old implementation * Add results of benchmark ## How was this patch tested? Added new tests inside KryoSerializerSuite to test the pool implementation as well as added the pool option to the existing regression testing for SPARK-7766 This is my original work and I license the work to the project under the project’s open source license. Closes #22855 from patrickbrownsync/kryo-pool. Authored-by: Patrick Brown Signed-off-by: Sean Owen --- .../KryoSerializerBenchmark-results.txt | 12 +++ .../spark/serializer/KryoSerializer.scala | 72 ++++++++++++--- .../spark/benchmark/BenchmarkBase.scala | 7 ++ .../serializer/KryoSerializerBenchmark.scala | 90 +++++++++++++++++++ .../serializer/KryoSerializerSuite.scala | 66 ++++++++++++-- 5 files changed, 230 insertions(+), 17 deletions(-) create mode 100644 core/benchmarks/KryoSerializerBenchmark-results.txt create mode 100644 core/src/test/scala/org/apache/spark/serializer/KryoSerializerBenchmark.scala diff --git a/core/benchmarks/KryoSerializerBenchmark-results.txt b/core/benchmarks/KryoSerializerBenchmark-results.txt new file mode 100644 index 0000000000000..c3ce336d93241 --- /dev/null +++ b/core/benchmarks/KryoSerializerBenchmark-results.txt @@ -0,0 +1,12 @@ +================================================================================================ +Benchmark KryoPool vs "pool of 1" +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_131-b11 on Mac OS X 10.14 +Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz +Benchmark KryoPool vs "pool of 1": Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +KryoPool:true 2682 / 3425 0.0 5364627.9 1.0X +KryoPool:false 8176 / 9292 0.0 16351252.2 0.3X + + diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 218c84352ce88..3795d5c3b38e3 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -30,6 +30,7 @@ import scala.util.control.NonFatal import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSerializer} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.io.{UnsafeInput => KryoUnsafeInput, UnsafeOutput => KryoUnsafeOutput} +import com.esotericsoftware.kryo.pool.{KryoCallback, KryoFactory, KryoPool} import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.avro.generic.{GenericData, GenericRecord} @@ -84,6 +85,7 @@ class KryoSerializer(conf: SparkConf) private val avroSchemas = conf.getAvroSchema // whether to use unsafe based IO for serialization private val useUnsafe = conf.getBoolean("spark.kryo.unsafe", false) + private val usePool = conf.getBoolean("spark.kryo.pool", true) def newKryoOutput(): KryoOutput = if (useUnsafe) { @@ -92,6 +94,36 @@ class KryoSerializer(conf: SparkConf) new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) } + @transient + private lazy val factory: KryoFactory = new KryoFactory() { + override def create: Kryo = { + newKryo() + } + } + + private class PoolWrapper extends KryoPool { + private var pool: KryoPool = getPool + + override def borrow(): Kryo = pool.borrow() + + override def release(kryo: Kryo): Unit = pool.release(kryo) + + override def run[T](kryoCallback: KryoCallback[T]): T = pool.run(kryoCallback) + + def reset(): Unit = { + pool = getPool + } + + private def getPool: KryoPool = { + new KryoPool.Builder(factory).softReferences.build + } + } + + @transient + private lazy val internalPool = new PoolWrapper + + def pool: KryoPool = internalPool + def newKryo(): Kryo = { val instantiator = new EmptyScalaKryoInstantiator val kryo = instantiator.newKryo() @@ -215,8 +247,14 @@ class KryoSerializer(conf: SparkConf) kryo } + override def setDefaultClassLoader(classLoader: ClassLoader): Serializer = { + super.setDefaultClassLoader(classLoader) + internalPool.reset() + this + } + override def newInstance(): SerializerInstance = { - new KryoSerializerInstance(this, useUnsafe) + new KryoSerializerInstance(this, useUnsafe, usePool) } private[spark] override lazy val supportsRelocationOfSerializedObjects: Boolean = { @@ -299,7 +337,8 @@ class KryoDeserializationStream( } } -private[spark] class KryoSerializerInstance(ks: KryoSerializer, useUnsafe: Boolean) +private[spark] class KryoSerializerInstance( + ks: KryoSerializer, useUnsafe: Boolean, usePool: Boolean) extends SerializerInstance { /** * A re-used [[Kryo]] instance. Methods will borrow this instance by calling `borrowKryo()`, do @@ -307,22 +346,29 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer, useUnsafe: Boole * pool of size one. SerializerInstances are not thread-safe, hence accesses to this field are * not synchronized. */ - @Nullable private[this] var cachedKryo: Kryo = borrowKryo() + @Nullable private[this] var cachedKryo: Kryo = if (usePool) null else borrowKryo() /** * Borrows a [[Kryo]] instance. If possible, this tries to re-use a cached Kryo instance; * otherwise, it allocates a new instance. */ private[serializer] def borrowKryo(): Kryo = { - if (cachedKryo != null) { - val kryo = cachedKryo - // As a defensive measure, call reset() to clear any Kryo state that might have been modified - // by the last operation to borrow this instance (see SPARK-7766 for discussion of this issue) + if (usePool) { + val kryo = ks.pool.borrow() kryo.reset() - cachedKryo = null kryo } else { - ks.newKryo() + if (cachedKryo != null) { + val kryo = cachedKryo + // As a defensive measure, call reset() to clear any Kryo state that might have + // been modified by the last operation to borrow this instance + // (see SPARK-7766 for discussion of this issue) + kryo.reset() + cachedKryo = null + kryo + } else { + ks.newKryo() + } } } @@ -332,8 +378,12 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer, useUnsafe: Boole * re-use. */ private[serializer] def releaseKryo(kryo: Kryo): Unit = { - if (cachedKryo == null) { - cachedKryo = kryo + if (usePool) { + ks.pool.release(kryo) + } else { + if (cachedKryo == null) { + cachedKryo = kryo + } } } diff --git a/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala b/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala index 24e596e1ecdaf..a6666db4e95c3 100644 --- a/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala +++ b/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala @@ -58,5 +58,12 @@ abstract class BenchmarkBase { o.close() } } + + afterAll() } + + /** + * Any shutdown code to ensure a clean shutdown + */ + def afterAll(): Unit = {} } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerBenchmark.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerBenchmark.scala new file mode 100644 index 0000000000000..2a15c6f6a2d96 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerBenchmark.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import scala.concurrent._ +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.duration._ + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} +import org.apache.spark.serializer.KryoTest._ +import org.apache.spark.util.ThreadUtils + +/** + * Benchmark for KryoPool vs old "pool of 1". + * To run this benchmark: + * {{{ + * 1. without sbt: + * bin/spark-submit --class --jars + * 2. build/sbt "core/test:runMain " + * 3. generate result: + * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "core/test:runMain " + * Results will be written to "benchmarks/KryoSerializerBenchmark-results.txt". + * }}} + */ +object KryoSerializerBenchmark extends BenchmarkBase { + + var sc: SparkContext = null + val N = 500 + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + val name = "Benchmark KryoPool vs old\"pool of 1\" implementation" + runBenchmark(name) { + val benchmark = new Benchmark(name, N, 10, output = output) + Seq(true, false).foreach(usePool => run(usePool, benchmark)) + benchmark.run() + } + } + + private def run(usePool: Boolean, benchmark: Benchmark): Unit = { + lazy val sc = createSparkContext(usePool) + + benchmark.addCase(s"KryoPool:$usePool") { _ => + val futures = for (_ <- 0 until N) yield { + Future { + sc.parallelize(0 until 10).map(i => i + 1).count() + } + } + + val future = Future.sequence(futures) + + ThreadUtils.awaitResult(future, 10.minutes) + } + } + + def createSparkContext(usePool: Boolean): SparkContext = { + val conf = new SparkConf() + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.kryo.registrator", classOf[MyRegistrator].getName) + conf.set("spark.kryo.pool", usePool.toString) + + if (sc != null) { + sc.stop() + } + + sc = new SparkContext("local-cluster[4,1,1024]", "test", conf) + sc + } + + override def afterAll(): Unit = { + if (sc != null) { + sc.stop() + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index ac25bcef54349..84af73b08d3e7 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -18,9 +18,12 @@ package org.apache.spark.serializer import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream} +import java.util.concurrent.Executors import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration._ import scala.reflect.ClassTag import com.esotericsoftware.kryo.{Kryo, KryoException} @@ -31,7 +34,7 @@ import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.scheduler.HighlyCompressedMapStatus import org.apache.spark.serializer.KryoTest._ import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") @@ -308,7 +311,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val conf = new SparkConf(false) conf.set("spark.kryo.registrator", "this.class.does.not.exist") - val thrown = intercept[SparkException](new KryoSerializer(conf).newInstance()) + val thrown = intercept[SparkException](new KryoSerializer(conf).newInstance().serialize(1)) assert(thrown.getMessage.contains("Failed to register classes with Kryo")) } @@ -431,9 +434,11 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { ser.deserialize[HashMap[Int, List[String]]](serializedMap) } - private def testSerializerInstanceReuse(autoReset: Boolean, referenceTracking: Boolean): Unit = { + private def testSerializerInstanceReuse( + autoReset: Boolean, referenceTracking: Boolean, usePool: Boolean): Unit = { val conf = new SparkConf(loadDefaults = false) .set("spark.kryo.referenceTracking", referenceTracking.toString) + .set("spark.kryo.pool", usePool.toString) if (!autoReset) { conf.set("spark.kryo.registrator", classOf[RegistratorWithoutAutoReset].getName) } @@ -456,9 +461,58 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { // Regression test for SPARK-7766, an issue where disabling auto-reset and enabling // reference-tracking would lead to corrupted output when serializer instances are re-used - for (referenceTracking <- Set(true, false); autoReset <- Set(true, false)) { - test(s"instance reuse with autoReset = $autoReset, referenceTracking = $referenceTracking") { - testSerializerInstanceReuse(autoReset = autoReset, referenceTracking = referenceTracking) + for { + referenceTracking <- Seq(true, false) + autoReset <- Seq(true, false) + usePool <- Seq(true, false) + } { + test(s"instance reuse with autoReset = $autoReset, referenceTracking = $referenceTracking" + + s", usePool = $usePool") { + testSerializerInstanceReuse( + autoReset, referenceTracking, usePool) + } + } + + test("SPARK-25839 KryoPool implementation works correctly in multi-threaded environment") { + implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor( + Executors.newFixedThreadPool(4)) + + val ser = new KryoSerializer(conf.clone.set("spark.kryo.pool", "true")) + + val tests = mutable.ListBuffer[Future[Boolean]]() + + def check[T: ClassTag](t: T) { + tests += Future { + val serializerInstance = ser.newInstance() + serializerInstance.deserialize[T](serializerInstance.serialize(t)) === t + } + } + + check((1, 3)) + check(Array((1, 3))) + check(List((1, 3))) + check(List[Int]()) + check(List[Int](1, 2, 3)) + check(List[String]()) + check(List[String]("x", "y", "z")) + check(None) + check(Some(1)) + check(Some("hi")) + check(1 -> 1) + check(mutable.ArrayBuffer(1, 2, 3)) + check(mutable.ArrayBuffer("1", "2", "3")) + check(mutable.Map()) + check(mutable.Map(1 -> "one", 2 -> "two")) + check(mutable.Map("one" -> 1, "two" -> 2)) + check(mutable.HashMap(1 -> "one", 2 -> "two")) + check(mutable.HashMap("one" -> 1, "two" -> 2)) + check(List(Some(mutable.HashMap(1 -> 1, 2 -> 2)), None, Some(mutable.HashMap(3 -> 4)))) + check(List( + mutable.HashMap("one" -> 1, "two" -> 2), + mutable.HashMap(1 -> "one", 2 -> "two", 3 -> "three"))) + + tests.foreach { f => + assert(ThreadUtils.awaitResult(f, 10.seconds)) } } } From a3ba3a899b3b43958820dc82fcdd3a8b28653bcb Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 11 Nov 2018 14:05:19 +0800 Subject: [PATCH 2035/2461] [INFRA] Close stale PRs Closes https://github.com/apache/spark/pull/21766 Closes https://github.com/apache/spark/pull/21679 Closes https://github.com/apache/spark/pull/21161 Closes https://github.com/apache/spark/pull/20846 Closes https://github.com/apache/spark/pull/19434 Closes https://github.com/apache/spark/pull/18080 Closes https://github.com/apache/spark/pull/17648 Closes https://github.com/apache/spark/pull/17169 Add: Closes #22813 Closes #21994 Closes #22005 Closes #22463 Add: Closes #15899 Add: Closes #22539 Closes #21868 Closes #21514 Closes #21402 Closes #21322 Closes #21257 Closes #20163 Closes #19691 Closes #18697 Closes #18636 Closes #17176 Closes #23001 from wangyum/CloseStalePRs. Authored-by: Yuming Wang Signed-off-by: hyukjinkwon From aec0af4a952df2957e21d39d1e0546a36ab7ab86 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 11 Nov 2018 21:01:29 +0800 Subject: [PATCH 2036/2461] [SPARK-25972][PYTHON] Missed JSON options in streaming.py ## What changes were proposed in this pull request? Added JSON options for `json()` in streaming.py that are presented in the similar method in readwriter.py. In particular, missed options are `dropFieldIfAllNull` and `encoding`. Closes #22973 from MaxGekk/streaming-missed-options. Authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- python/pyspark/sql/streaming.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 02b14ea187cba..58ca7b83e5b2b 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -404,7 +404,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None, lineSep=None, locale=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None, locale=None, + dropFieldIfAllNull=None, encoding=None): """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. @@ -472,6 +473,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set, it uses the default value, ``en-US``. For instance, ``locale`` is used while parsing dates and timestamps. + :param dropFieldIfAllNull: whether to ignore column of all null values or empty + array/struct during schema inference. If None is set, it + uses the default value, ``false``. + :param encoding: allows to forcibly set one of standard basic or extended encoding for + the JSON files. For example UTF-16BE, UTF-32LE. If None is set, + the encoding of input JSON will be detected automatically + when the multiLine option is set to ``true``. >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema) >>> json_sdf.isStreaming @@ -486,7 +494,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, locale=locale) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, locale=locale, + dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding) if isinstance(path, basestring): return self._df(self._jreader.json(path)) else: From 510ec77a601db1c0fa338dd76a0ea7af63441fd3 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 11 Nov 2018 09:21:40 -0600 Subject: [PATCH 2037/2461] [SPARK-19714][DOCS] Clarify Bucketizer handling of invalid input ## What changes were proposed in this pull request? Clarify Bucketizer handleInvalid docs. Just a resubmit of https://github.com/apache/spark/pull/17169 ## How was this patch tested? N/A Closes #23003 from srowen/SPARK-19714. Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../main/scala/org/apache/spark/ml/feature/Bucketizer.scala | 6 ++++-- python/pyspark/ml/feature.py | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index f99649f7fa164..0b989b0d7d253 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -89,7 +89,8 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String def setOutputCol(value: String): this.type = set(outputCol, value) /** - * Param for how to handle invalid entries. Options are 'skip' (filter out rows with + * Param for how to handle invalid entries containing NaN values. Values outside the splits + * will always be treated as errors. Options are 'skip' (filter out rows with * invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special * additional bucket). Note that in the multiple column case, the invalid handling is applied * to all columns. That said for 'error' it will throw an error if any invalids are found in @@ -99,7 +100,8 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String */ @Since("2.1.0") override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", - "how to handle invalid entries. Options are skip (filter out rows with invalid values), " + + "how to handle invalid entries containing NaN values. Values outside the splits will always " + + "be treated as errorsOptions are skip (filter out rows with invalid values), " + "error (throw an error), or keep (keep invalid values in a special additional bucket).", ParamValidators.inArray(Bucketizer.supportedHandleInvalids)) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index eccb7acae5b98..3d23700242594 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -361,8 +361,9 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid, "splits specified will be treated as errors.", typeConverter=TypeConverters.toListFloat) - handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " + - "Options are 'skip' (filter out rows with invalid values), " + + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries " + "containing NaN values. Values outside the splits will always be treated " + "as errors. Options are 'skip' (filter out rows with invalid values), " + "'error' (throw an error), or 'keep' (keep invalid values in a special " + "additional bucket).", typeConverter=TypeConverters.toString) From d0ae48497c093cef23fb95c10aa448b3b498c758 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Mon, 12 Nov 2018 15:16:15 +0800 Subject: [PATCH 2038/2461] [SPARK-25949][SQL] Add test for PullOutPythonUDFInJoinCondition ## What changes were proposed in this pull request? As comment in https://github.com/apache/spark/pull/22326#issuecomment-424923967, we test the new added optimizer rule by end-to-end test in python side, need to add suites under `org.apache.spark.sql.catalyst.optimizer` like other optimizer rules. ## How was this patch tested? new added UT Closes #22955 from xuanyuanking/SPARK-25949. Authored-by: Yuanjian Li Signed-off-by: Wenchen Fan --- ...PullOutPythonUDFInJoinConditionSuite.scala | 171 ++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala new file mode 100644 index 0000000000000..d3867f2b6bd0e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.scalatest.Matchers._ + +import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf._ +import org.apache.spark.sql.types.BooleanType + +class PullOutPythonUDFInJoinConditionSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Extract PythonUDF From JoinCondition", Once, + PullOutPythonUDFInJoinCondition) :: + Batch("Check Cartesian Products", Once, + CheckCartesianProducts) :: Nil + } + + val testRelationLeft = LocalRelation('a.int, 'b.int) + val testRelationRight = LocalRelation('c.int, 'd.int) + + // Dummy python UDF for testing. Unable to execute. + val pythonUDF = PythonUDF("pythonUDF", null, + BooleanType, + Seq.empty, + PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = true) + + val unsupportedJoinTypes = Seq(LeftOuter, RightOuter, FullOuter, LeftAnti) + + private def comparePlanWithCrossJoinEnable(query: LogicalPlan, expected: LogicalPlan): Unit = { + // AnalysisException thrown by CheckCartesianProducts while spark.sql.crossJoin.enabled=false + val exception = intercept[AnalysisException] { + Optimize.execute(query.analyze) + } + assert(exception.message.startsWith("Detected implicit cartesian product")) + + // pull out the python udf while set spark.sql.crossJoin.enabled=true + withSQLConf(CROSS_JOINS_ENABLED.key -> "true") { + val optimized = Optimize.execute(query.analyze) + comparePlans(optimized, expected) + } + } + + test("inner join condition with python udf only") { + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(pythonUDF)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = None).where(pythonUDF).analyze + comparePlanWithCrossJoinEnable(query, expected) + } + + test("left semi join condition with python udf only") { + val query = testRelationLeft.join( + testRelationRight, + joinType = LeftSemi, + condition = Some(pythonUDF)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = None).where(pythonUDF).select('a, 'b).analyze + comparePlanWithCrossJoinEnable(query, expected) + } + + test("python udf and common condition") { + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(pythonUDF && 'a.attr === 'c.attr)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some('a.attr === 'c.attr)).where(pythonUDF).analyze + val optimized = Optimize.execute(query.analyze) + comparePlans(optimized, expected) + } + + test("python udf or common condition") { + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(pythonUDF || 'a.attr === 'c.attr)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = None).where(pythonUDF || 'a.attr === 'c.attr).analyze + comparePlanWithCrossJoinEnable(query, expected) + } + + test("pull out whole complex condition with multiple python udf") { + val pythonUDF1 = PythonUDF("pythonUDF1", null, + BooleanType, + Seq.empty, + PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = true) + val condition = (pythonUDF || 'a.attr === 'c.attr) && pythonUDF1 + + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(condition)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = None).where(condition).analyze + comparePlanWithCrossJoinEnable(query, expected) + } + + test("partial pull out complex condition with multiple python udf") { + val pythonUDF1 = PythonUDF("pythonUDF1", null, + BooleanType, + Seq.empty, + PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = true) + val condition = (pythonUDF || pythonUDF1) && 'a.attr === 'c.attr + + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(condition)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some('a.attr === 'c.attr)).where(pythonUDF || pythonUDF1).analyze + val optimized = Optimize.execute(query.analyze) + comparePlans(optimized, expected) + } + + test("throw an exception for not support join type") { + for (joinType <- unsupportedJoinTypes) { + val thrownException = the [AnalysisException] thrownBy { + val query = testRelationLeft.join( + testRelationRight, + joinType, + condition = Some(pythonUDF)) + Optimize.execute(query.analyze) + } + assert(thrownException.message.contentEquals( + s"Using PythonUDF in join condition of join type $joinType is not supported.")) + } + } +} + From 0ba9715c7d1ef1eabc276320c81f0acb20bafb59 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 11 Nov 2018 23:21:47 -0800 Subject: [PATCH 2039/2461] [SPARK-26005][SQL] Upgrade ANTRL from 4.7 to 4.7.1 ## What changes were proposed in this pull request? Based on the release description of ANTRL 4.7.1., https://github.com/antlr/antlr4/releases, let us upgrade our parser to 4.7.1. ## How was this patch tested? N/A Closes #23005 from gatorsmile/upgradeAntlr4.7. Authored-by: gatorsmile Signed-off-by: gatorsmile --- dev/deps/spark-deps-hadoop-2.7 | 2 +- dev/deps/spark-deps-hadoop-3.1 | 2 +- pom.xml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 15a570908cc9a..a3030bd601534 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -5,7 +5,7 @@ activation-1.1.1.jar aircompressor-0.10.jar antlr-2.7.7.jar antlr-runtime-3.4.jar -antlr4-runtime-4.7.jar +antlr4-runtime-4.7.1.jar aopalliance-1.0.jar aopalliance-repackaged-2.4.0-b34.jar apache-log4j-extras-1.2.17.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 6d9191a4abb4c..4354e76b521fc 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -7,7 +7,7 @@ activation-1.1.1.jar aircompressor-0.10.jar antlr-2.7.7.jar antlr-runtime-3.4.jar -antlr4-runtime-4.7.jar +antlr4-runtime-4.7.1.jar aopalliance-1.0.jar aopalliance-repackaged-2.4.0-b34.jar apache-log4j-extras-1.2.17.jar diff --git a/pom.xml b/pom.xml index a08b7fda33387..f58959b665e1b 100644 --- a/pom.xml +++ b/pom.xml @@ -174,7 +174,7 @@ 3.5.2 3.0.0 0.9.3 - 4.7 + 4.7.1 1.1 2.52.0 2.6 - 3.5 + 3.8.1 3.2.10 3.0.10 2.22.2 @@ -2016,7 +2016,7 @@ net.alchim31.maven scala-maven-plugin - 3.2.2 + 3.4.4 eclipse-add-source @@ -2281,7 +2281,19 @@ org.apache.maven.plugins maven-shade-plugin - 3.1.0 + 3.2.0 + + + org.ow2.asm + asm + 7.0 + + + org.ow2.asm + asm-commons + 7.0 + + org.apache.maven.plugins @@ -2296,7 +2308,7 @@ org.apache.maven.plugins maven-dependency-plugin - 3.0.2 + 3.1.1 default-cli diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java index 341a7fdbb59b8..a10245b372d71 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java @@ -19,7 +19,6 @@ package org.apache.hive.service.cli.thrift; import java.util.Arrays; -import java.util.concurrent.ExecutorService; import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; @@ -65,7 +64,7 @@ public void run() { // Server thread pool // Start with minWorkerThreads, expand till maxWorkerThreads and reject subsequent requests String threadPoolName = "HiveServer2-HttpHandler-Pool"; - ExecutorService executorService = new ThreadPoolExecutor(minWorkerThreads, maxWorkerThreads, + ThreadPoolExecutor executorService = new ThreadPoolExecutor(minWorkerThreads, maxWorkerThreads, workerKeepAliveTime, TimeUnit.SECONDS, new SynchronousQueue(), new ThreadFactoryWithGarbageCleanup(threadPoolName)); ExecutorThreadPool threadPool = new ExecutorThreadPool(executorService); From 2b671e729250b980aa9e4ea2d483f44fa0e129cb Mon Sep 17 00:00:00 2001 From: gss2002 Date: Wed, 14 Nov 2018 13:02:13 -0800 Subject: [PATCH 2057/2461] =?UTF-8?q?[SPARK-25778]=20WriteAheadLogBackedBl?= =?UTF-8?q?ockRDD=20in=20YARN=20Cluster=20Mode=20Fails=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …due lack of access to tmpDir from $PWD to HDFS WriteAheadLogBackedBlockRDD usage of java.io.tmpdir will fail if $PWD resolves to a folder in HDFS and the Spark YARN Cluster job does not have the correct access to this folder in regards to the dummy folder. So this patch provides an option to set spark.streaming.receiver.blockStore.tmpdir to override java.io.tmpdir which sets $PWD from YARN Cluster mode. ## What changes were proposed in this pull request? This change provides an option to override the java.io.tmpdir option so that when $PWD is resolved in YARN Cluster mode Spark does not attempt to use this folder and instead use the folder provided with the following option: spark.streaming.receiver.blockStore.tmpdir ## How was this patch tested? Patch was manually tested on a Spark Streaming Job with Write Ahead logs in Cluster mode. Closes #22867 from gss2002/SPARK-25778. Authored-by: gss2002 Signed-off-by: Marcelo Vanzin --- .../spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index 844760ab61d2e..f677c492d561f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -136,7 +136,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( // this dummy directory should not already exist otherwise the WAL will try to recover // past events from the directory and throw errors. val nonExistentDirectory = new File( - System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString).getAbsolutePath + System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString).toURI.toString writeAheadLog = WriteAheadLogUtils.createLogForReceiver( SparkEnv.get.conf, nonExistentDirectory, hadoopConf) dataRead = writeAheadLog.read(partition.walRecordHandle) From 2977e2312d9690c9ced3c86b0ce937819e957775 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Wed, 14 Nov 2018 13:05:18 -0800 Subject: [PATCH 2058/2461] [SPARK-25986][BUILD] Add rules to ban throw Errors in application code ## What changes were proposed in this pull request? Add scala and java lint check rules to ban the usage of `throw new xxxErrors` and fix up all exists instance followed by https://github.com/apache/spark/pull/22989#issuecomment-437939830. See more details in https://github.com/apache/spark/pull/22969. ## How was this patch tested? Local test with lint-scala and lint-java. Closes #22989 from xuanyuanking/SPARK-25986. Authored-by: Yuanjian Li Signed-off-by: Sean Owen --- .../spark/unsafe/UnsafeAlignedOffset.java | 4 +++ .../apache/spark/memory/MemoryConsumer.java | 2 ++ .../spark/memory/TaskMemoryManager.java | 4 +++ .../unsafe/sort/UnsafeInMemorySorter.java | 2 ++ .../spark/util/random/RandomSampler.scala | 2 +- .../scala/org/apache/spark/FailureSuite.scala | 2 ++ .../apache/spark/executor/ExecutorSuite.scala | 2 ++ .../scheduler/TaskResultGetterSuite.scala | 2 ++ .../spark/storage/BlockManagerSuite.scala | 2 +- dev/checkstyle.xml | 13 +++++--- .../spark/streaming/kafka010/KafkaUtils.scala | 2 +- .../org/apache/spark/ml/linalg/Vectors.scala | 2 +- .../spark/ml/classification/NaiveBayes.scala | 8 ++--- .../org/apache/spark/ml/param/params.scala | 4 +-- .../spark/ml/tuning/ValidatorParams.scala | 4 +-- .../mllib/classification/NaiveBayes.scala | 2 +- .../apache/spark/mllib/linalg/Vectors.scala | 2 +- .../org/apache/spark/ml/PredictorSuite.scala | 6 ++-- .../ml/classification/ClassifierSuite.scala | 11 ++++--- .../ml/classification/NaiveBayesSuite.scala | 4 +-- .../ml/classification/OneVsRestSuite.scala | 16 +++++----- .../spark/ml/feature/VectorIndexerSuite.scala | 4 ++- .../ml/tree/impl/RandomForestSuite.scala | 6 ++-- .../apache/spark/ml/tree/impl/TreeTests.scala | 6 ++-- .../spark/ml/tuning/CrossValidatorSuite.scala | 32 +++++++++---------- .../ml/tuning/TrainValidationSplitSuite.scala | 12 +++---- .../tuning/ValidatorParamsSuiteHelpers.scala | 3 +- .../classification/NaiveBayesSuite.scala | 2 +- .../spark/mllib/clustering/KMeansSuite.scala | 2 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 15 +++++---- scalastyle-config.xml | 11 +++++++ .../TungstenAggregationIterator.scala | 2 ++ .../spark/sql/FileBasedDataSourceSuite.scala | 9 +++--- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../datasources/FileSourceStrategySuite.scala | 2 +- .../vectorized/ColumnarBatchSuite.scala | 2 +- .../spark/streaming/util/StateMap.scala | 2 +- .../spark/streaming/InputStreamsSuite.scala | 5 +-- .../spark/streaming/StateMapSuite.scala | 2 +- 39 files changed, 128 insertions(+), 87 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java index be62e40412f83..546e8780a6606 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java @@ -39,7 +39,9 @@ public static int getSize(Object object, long offset) { case 8: return (int)Platform.getLong(object, offset); default: + // checkstyle.off: RegexpSinglelineJava throw new AssertionError("Illegal UAO_SIZE"); + // checkstyle.on: RegexpSinglelineJava } } @@ -52,7 +54,9 @@ public static void putSize(Object object, long offset, int value) { Platform.putLong(object, offset, value); break; default: + // checkstyle.off: RegexpSinglelineJava throw new AssertionError("Illegal UAO_SIZE"); + // checkstyle.on: RegexpSinglelineJava } } } diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 115e1fbb79a2e..8371deca7311d 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -154,7 +154,9 @@ private void throwOom(final MemoryBlock page, final long required) { taskMemoryManager.freePage(page, this); } taskMemoryManager.showMemoryUsage(); + // checkstyle.off: RegexpSinglelineJava throw new SparkOutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); + // checkstyle.on: RegexpSinglelineJava } } diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index d07faf1da1248..28b646ba3c951 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -194,8 +194,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { throw new RuntimeException(e.getMessage()); } catch (IOException e) { logger.error("error while calling spill() on " + c, e); + // checkstyle.off: RegexpSinglelineJava throw new SparkOutOfMemoryError("error while calling spill() on " + c + " : " + e.getMessage()); + // checkstyle.on: RegexpSinglelineJava } } } @@ -215,8 +217,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { throw new RuntimeException(e.getMessage()); } catch (IOException e) { logger.error("error while calling spill() on " + consumer, e); + // checkstyle.off: RegexpSinglelineJava throw new SparkOutOfMemoryError("error while calling spill() on " + consumer + " : " + e.getMessage()); + // checkstyle.on: RegexpSinglelineJava } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 75690ae264838..1a9453a8b3e80 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -214,7 +214,9 @@ public boolean hasSpaceForAnotherRecord() { public void expandPointerArray(LongArray newArray) { if (newArray.size() < array.size()) { + // checkstyle.off: RegexpSinglelineJava throw new SparkOutOfMemoryError("Not enough memory to grow pointer array"); + // checkstyle.on: RegexpSinglelineJava } Platform.copyMemory( array.getBaseObject(), diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index ea99a7e5b4847..70554f1d03067 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -49,7 +49,7 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable /** return a copy of the RandomSampler object */ override def clone: RandomSampler[T, U] = - throw new NotImplementedError("clone() is not implemented.") + throw new UnsupportedOperationException("clone() is not implemented.") } private[spark] diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index d805c67714ff8..f2d97d452ddb0 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -257,7 +257,9 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { sc = new SparkContext("local[1,2]", "test") intercept[SparkException] { sc.parallelize(1 to 2).foreach { i => + // scalastyle:off throwerror throw new LinkageError() + // scalastyle:on throwerror } } } diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 1f8a65707b2f7..32a94e60484e3 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -467,7 +467,9 @@ class FetchFailureHidingRDD( } catch { case t: Throwable => if (throwOOM) { + // scalastyle:off throwerror throw new OutOfMemoryError("OOM while handling another exception") + // scalastyle:on throwerror } else if (interrupt) { // make sure our test is setup correctly assert(TaskContext.get().asInstanceOf[TaskContextImpl].fetchFailed.isDefined) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index f8eb8bd71c170..efb8b15cf6b4d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -265,7 +265,9 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local private class UndeserializableException extends Exception { private def readObject(in: ObjectInputStream): Unit = { + // scalastyle:off throwerror throw new NoClassDefFoundError() + // scalastyle:on throwerror } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 32d6e8b94e1a2..cf00c1c3aad39 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -574,7 +574,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE "list1", StorageLevel.MEMORY_ONLY, ClassTag.Any, - () => throw new AssertionError("attempted to compute locally")).isLeft) + () => fail("attempted to compute locally")).isLeft) } test("in-memory LRU storage") { diff --git a/dev/checkstyle.xml b/dev/checkstyle.xml index 53c284888ebb0..e8859c01f2bd8 100644 --- a/dev/checkstyle.xml +++ b/dev/checkstyle.xml @@ -71,13 +71,13 @@ If you wish to turn off checking for a section of code, you can put a comment in the source before and after the section, with the following syntax: - // checkstyle:off no.XXX (such as checkstyle.off: NoFinalizer) + // checkstyle.off: XXX (such as checkstyle.off: NoFinalizer) ... // stuff that breaks the styles - // checkstyle:on + // checkstyle.on: XXX (such as checkstyle.on: NoFinalizer) --> - - + + @@ -180,5 +180,10 @@ + + + + + diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala index 64b6ef6c53b6d..2516b948f6650 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala @@ -56,7 +56,7 @@ object KafkaUtils extends Logging { ): RDD[ConsumerRecord[K, V]] = { val preferredHosts = locationStrategy match { case PreferBrokers => - throw new AssertionError( + throw new IllegalArgumentException( "If you want to prefer brokers, you must provide a mapping using PreferFixed " + "A single KafkaRDD does not have a driver consumer and cannot look up brokers for you.") case PreferConsistent => ju.Collections.emptyMap[TopicPartition, String]() diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index 5824e463ca1aa..6e950f968a65d 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -106,7 +106,7 @@ sealed trait Vector extends Serializable { */ @Since("2.0.0") def copy: Vector = { - throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.") + throw new UnsupportedOperationException(s"copy is not implemented for ${this.getClass}.") } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 51495c1a74e69..1a7a5e7a52344 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -146,7 +146,7 @@ class NaiveBayes @Since("1.5.0") ( requireZeroOneBernoulliValues case _ => // This should never happen. - throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") + throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.") } } @@ -196,7 +196,7 @@ class NaiveBayes @Since("1.5.0") ( case Bernoulli => math.log(n + 2.0 * lambda) case _ => // This should never happen. - throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") + throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.") } var j = 0 while (j < numFeatures) { @@ -295,7 +295,7 @@ class NaiveBayesModel private[ml] ( (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones))) case _ => // This should never happen. - throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") + throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.") } @Since("1.6.0") @@ -329,7 +329,7 @@ class NaiveBayesModel private[ml] ( bernoulliCalculation(features) case _ => // This should never happen. - throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") + throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.") } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index e6c347ed17c15..4c50f1e3292bc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -97,7 +97,7 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali case m: Matrix => JsonMatrixConverter.toJson(m) case _ => - throw new NotImplementedError( + throw new UnsupportedOperationException( "The default jsonEncode only supports string, vector and matrix. " + s"${this.getClass.getName} must override jsonEncode for ${value.getClass.getName}.") } @@ -151,7 +151,7 @@ private[ml] object Param { } case _ => - throw new NotImplementedError( + throw new UnsupportedOperationException( "The default jsonDecode only supports string, vector and matrix. " + s"${this.getClass.getName} must override jsonDecode to support its value type.") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala index 135828815504a..6d46ea0adcc9a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -140,8 +140,8 @@ private[ml] object ValidatorParams { "value" -> compact(render(JString(relativePath))), "isJson" -> compact(render(JBool(false)))) case _: MLWritable => - throw new NotImplementedError("ValidatorParams.saveImpl does not handle parameters " + - "of type: MLWritable that are not DefaultParamsWritable") + throw new UnsupportedOperationException("ValidatorParams.saveImpl does not handle" + + " parameters of type: MLWritable that are not DefaultParamsWritable") case _ => Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v), "isJson" -> compact(render(JBool(true)))) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 9e8774732efe6..16ba6cabdc823 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -83,7 +83,7 @@ class NaiveBayesModel private[spark] ( (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones))) case _ => // This should never happen. - throw new UnknownError(s"Invalid modelType: $modelType.") + throw new IllegalArgumentException(s"Invalid modelType: $modelType.") } @Since("1.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 6e68d9684a672..9cdf1944329b8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -117,7 +117,7 @@ sealed trait Vector extends Serializable { */ @Since("1.1.0") def copy: Vector = { - throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.") + throw new UnsupportedOperationException(s"copy is not implemented for ${this.getClass}.") } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala index ec45e32d412a9..dff00eade620f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala @@ -73,7 +73,7 @@ object PredictorSuite { } override def copy(extra: ParamMap): MockPredictor = - throw new NotImplementedError() + throw new UnsupportedOperationException() } class MockPredictionModel(override val uid: String) @@ -82,9 +82,9 @@ object PredictorSuite { def this() = this(Identifiable.randomUID("mockpredictormodel")) override def predict(features: Vector): Double = - throw new NotImplementedError() + throw new UnsupportedOperationException() override def copy(extra: ParamMap): MockPredictionModel = - throw new NotImplementedError() + throw new UnsupportedOperationException() } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala index 87bf2be06c2be..be52d99e54d3b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -117,10 +117,10 @@ object ClassifierSuite { def this() = this(Identifiable.randomUID("mockclassifier")) - override def copy(extra: ParamMap): MockClassifier = throw new NotImplementedError() + override def copy(extra: ParamMap): MockClassifier = throw new UnsupportedOperationException() override def train(dataset: Dataset[_]): MockClassificationModel = - throw new NotImplementedError() + throw new UnsupportedOperationException() // Make methods public override def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = @@ -133,11 +133,12 @@ object ClassifierSuite { def this() = this(Identifiable.randomUID("mockclassificationmodel")) - protected def predictRaw(features: Vector): Vector = throw new NotImplementedError() + protected def predictRaw(features: Vector): Vector = throw new UnsupportedOperationException() - override def copy(extra: ParamMap): MockClassificationModel = throw new NotImplementedError() + override def copy(extra: ParamMap): MockClassificationModel = + throw new UnsupportedOperationException() - override def numClasses: Int = throw new NotImplementedError() + override def numClasses: Int = throw new UnsupportedOperationException() } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 5f9ab98a2c3ce..a8c4f091b2aed 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -103,7 +103,7 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest { case Bernoulli => expectedBernoulliProbabilities(model, features) case _ => - throw new UnknownError(s"Invalid modelType: $modelType.") + throw new IllegalArgumentException(s"Invalid modelType: $modelType.") } assert(probability ~== expected relTol 1.0e-10) } @@ -378,7 +378,7 @@ object NaiveBayesSuite { counts.toArray.sortBy(_._1).map(_._2) case _ => // This should never happen. - throw new UnknownError(s"Invalid modelType: $modelType.") + throw new IllegalArgumentException(s"Invalid modelType: $modelType.") } LabeledPoint(y, Vectors.dense(xi)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 2c3417c7e4028..519ec1720eb98 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -134,8 +134,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { assert(lrModel1.coefficients ~== lrModel2.coefficients relTol 1E-3) assert(lrModel1.intercept ~== lrModel2.intercept relTol 1E-3) case other => - throw new AssertionError(s"Loaded OneVsRestModel expected model of type" + - s" LogisticRegressionModel but found ${other.getClass.getName}") + fail("Loaded OneVsRestModel expected model of type LogisticRegressionModel " + + s"but found ${other.getClass.getName}") } } @@ -247,8 +247,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { assert(lr.getMaxIter === lr2.getMaxIter) assert(lr.getRegParam === lr2.getRegParam) case other => - throw new AssertionError(s"Loaded OneVsRest expected classifier of type" + - s" LogisticRegression but found ${other.getClass.getName}") + fail("Loaded OneVsRest expected classifier of type LogisticRegression" + + s" but found ${other.getClass.getName}") } } @@ -267,8 +267,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { assert(classifier.getMaxIter === lr2.getMaxIter) assert(classifier.getRegParam === lr2.getRegParam) case other => - throw new AssertionError(s"Loaded OneVsRestModel expected classifier of type" + - s" LogisticRegression but found ${other.getClass.getName}") + fail("Loaded OneVsRestModel expected classifier of type LogisticRegression" + + s" but found ${other.getClass.getName}") } assert(model.labelMetadata === model2.labelMetadata) @@ -278,8 +278,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { assert(lrModel1.coefficients === lrModel2.coefficients) assert(lrModel1.intercept === lrModel2.intercept) case other => - throw new AssertionError(s"Loaded OneVsRestModel expected model of type" + - s" LogisticRegressionModel but found ${other.getClass.getName}") + fail(s"Loaded OneVsRestModel expected model of type LogisticRegressionModel" + + s" but found ${other.getClass.getName}") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index e5675e31bbecf..fb5789f945dec 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -283,7 +283,9 @@ class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging { points.zip(rows.map(_(0))).foreach { case (orig: SparseVector, indexed: SparseVector) => assert(orig.indices.length == indexed.indices.length) - case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen + case _ => + // should never happen + fail("Unit test has a bug in it.") } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 743dacf146fe7..5caa5117d5752 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -417,9 +417,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { case n: InternalNode => n.split match { case s: CategoricalSplit => assert(s.leftCategories === Array(1.0)) - case _ => throw new AssertionError("model.rootNode.split was not a CategoricalSplit") + case _ => fail("model.rootNode.split was not a CategoricalSplit") } - case _ => throw new AssertionError("model.rootNode was not an InternalNode") + case _ => fail("model.rootNode was not an InternalNode") } } @@ -444,7 +444,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(n.leftChild.isInstanceOf[InternalNode]) assert(n.rightChild.isInstanceOf[InternalNode]) Array(n.leftChild.asInstanceOf[InternalNode], n.rightChild.asInstanceOf[InternalNode]) - case _ => throw new AssertionError("rootNode was not an InternalNode") + case _ => fail("rootNode was not an InternalNode") } // Single group second level tree construction. diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index b6894b30b0c2b..ae9794b87b08d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -112,7 +112,7 @@ private[ml] object TreeTests extends SparkFunSuite { checkEqual(a.rootNode, b.rootNode) } catch { case ex: Exception => - throw new AssertionError("checkEqual failed since the two trees were not identical.\n" + + fail("checkEqual failed since the two trees were not identical.\n" + "TREE A:\n" + a.toDebugString + "\n" + "TREE B:\n" + b.toDebugString + "\n", ex) } @@ -133,7 +133,7 @@ private[ml] object TreeTests extends SparkFunSuite { checkEqual(aye.rightChild, bee.rightChild) case (aye: LeafNode, bee: LeafNode) => // do nothing case _ => - throw new AssertionError("Found mismatched nodes") + fail("Found mismatched nodes") } } @@ -148,7 +148,7 @@ private[ml] object TreeTests extends SparkFunSuite { } assert(a.treeWeights === b.treeWeights) } catch { - case ex: Exception => throw new AssertionError( + case ex: Exception => fail( "checkEqual failed since the two tree ensembles were not identical") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index e6ee7220d2279..a30428ec2d283 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -190,8 +190,8 @@ class CrossValidatorSuite assert(lr.uid === lr2.uid) assert(lr.getMaxIter === lr2.getMaxIter) case other => - throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + - s" LogisticRegression but found ${other.getClass.getName}") + fail("Loaded CrossValidator expected estimator of type LogisticRegression" + + s" but found ${other.getClass.getName}") } ValidatorParamsSuiteHelpers @@ -281,13 +281,13 @@ class CrossValidatorSuite assert(ova.getClassifier.asInstanceOf[LogisticRegression].getMaxIter === lr.getMaxIter) case other => - throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + - s" LogisticRegression but found ${other.getClass.getName}") + fail("Loaded CrossValidator expected estimator of type LogisticRegression" + + s" but found ${other.getClass.getName}") } case other => - throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + - s" OneVsRest but found ${other.getClass.getName}") + fail("Loaded CrossValidator expected estimator of type OneVsRest but " + + s"found ${other.getClass.getName}") } ValidatorParamsSuiteHelpers @@ -364,8 +364,8 @@ class CrossValidatorSuite assert(lr.uid === lr2.uid) assert(lr.getMaxIter === lr2.getMaxIter) case other => - throw new AssertionError(s"Loaded internal CrossValidator expected to be" + - s" LogisticRegression but found type ${other.getClass.getName}") + fail("Loaded internal CrossValidator expected to be LogisticRegression" + + s" but found type ${other.getClass.getName}") } assert(lrcv.uid === lrcv2.uid) assert(lrcv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) @@ -373,12 +373,12 @@ class CrossValidatorSuite ValidatorParamsSuiteHelpers .compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps) case other => - throw new AssertionError("Loaded Pipeline expected stages (HashingTF, CrossValidator)" + - " but found: " + other.map(_.getClass.getName).mkString(", ")) + fail("Loaded Pipeline expected stages (HashingTF, CrossValidator) but found: " + + other.map(_.getClass.getName).mkString(", ")) } case other => - throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + - s" CrossValidator but found ${other.getClass.getName}") + fail("Loaded CrossValidator expected estimator of type CrossValidator but found" + + s" ${other.getClass.getName}") } } @@ -433,8 +433,8 @@ class CrossValidatorSuite assert(lr.uid === lr2.uid) assert(lr.getThreshold === lr2.getThreshold) case other => - throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + - s" LogisticRegression but found ${other.getClass.getName}") + fail("Loaded CrossValidator expected estimator of type LogisticRegression" + + s" but found ${other.getClass.getName}") } ValidatorParamsSuiteHelpers @@ -447,8 +447,8 @@ class CrossValidatorSuite assert(lrModel.coefficients === lrModel2.coefficients) assert(lrModel.intercept === lrModel2.intercept) case other => - throw new AssertionError(s"Loaded CrossValidator expected bestModel of type" + - s" LogisticRegressionModel but found ${other.getClass.getName}") + fail("Loaded CrossValidator expected bestModel of type LogisticRegressionModel" + + s" but found ${other.getClass.getName}") } assert(cv.avgMetrics === cv2.avgMetrics) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index cd76acf9c67bc..289db336eca5d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -187,8 +187,8 @@ class TrainValidationSplitSuite assert(lr.uid === lr2.uid) assert(lr.getMaxIter === lr2.getMaxIter) case other => - throw new AssertionError(s"Loaded TrainValidationSplit expected estimator of type" + - s" LogisticRegression but found ${other.getClass.getName}") + fail("Loaded TrainValidationSplit expected estimator of type LogisticRegression" + + s" but found ${other.getClass.getName}") } } @@ -264,13 +264,13 @@ class TrainValidationSplitSuite assert(ova.getClassifier.asInstanceOf[LogisticRegression].getMaxIter === lr.getMaxIter) case other => - throw new AssertionError(s"Loaded TrainValidationSplit expected estimator of type" + - s" LogisticRegression but found ${other.getClass.getName}") + fail(s"Loaded TrainValidationSplit expected estimator of type LogisticRegression" + + s" but found ${other.getClass.getName}") } case other => - throw new AssertionError(s"Loaded TrainValidationSplit expected estimator of type" + - s" OneVsRest but found ${other.getClass.getName}") + fail(s"Loaded TrainValidationSplit expected estimator of type OneVsRest" + + s" but found ${other.getClass.getName}") } ValidatorParamsSuiteHelpers diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala index eae1f5adc8842..cea2f50d3470c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala @@ -47,8 +47,7 @@ object ValidatorParamsSuiteHelpers extends Assertions { val estimatorParamMap2 = Array(estimator2.extractParamMap()) compareParamMaps(estimatorParamMap, estimatorParamMap2) case other => - throw new AssertionError(s"Expected parameter of type Params but" + - s" found ${otherParam.getClass.getName}") + fail(s"Expected parameter of type Params but found ${otherParam.getClass.getName}") } case _ => assert(otherParam === v) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 5ec4c15387e94..8c7d583923b32 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -71,7 +71,7 @@ object NaiveBayesSuite { counts.toArray.sortBy(_._1).map(_._2) case _ => // This should never happen. - throw new UnknownError(s"Invalid modelType: $modelType.") + throw new IllegalArgumentException(s"Invalid modelType: $modelType.") } LabeledPoint(y, Vectors.dense(xi)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 1b98250061c7a..d18cef7e264db 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -349,7 +349,7 @@ object KMeansSuite extends SparkFunSuite { case (ca: DenseVector, cb: DenseVector) => assert(ca === cb) case _ => - throw new AssertionError("checkEqual failed since the two clusters were not identical.\n") + fail("checkEqual failed since the two clusters were not identical.\n") } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index bc59f3f4125fb..34bc303ac6079 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -607,7 +607,7 @@ object DecisionTreeSuite extends SparkFunSuite { checkEqual(a.topNode, b.topNode) } catch { case ex: Exception => - throw new AssertionError("checkEqual failed since the two trees were not identical.\n" + + fail("checkEqual failed since the two trees were not identical.\n" + "TREE A:\n" + a.toDebugString + "\n" + "TREE B:\n" + b.toDebugString + "\n", ex) } @@ -628,20 +628,21 @@ object DecisionTreeSuite extends SparkFunSuite { // TODO: Check other fields besides the information gain. case (Some(aStats), Some(bStats)) => assert(aStats.gain === bStats.gain) case (None, None) => - case _ => throw new AssertionError( - s"Only one instance has stats defined. (a.stats: ${a.stats}, b.stats: ${b.stats})") + case _ => fail(s"Only one instance has stats defined. (a.stats: ${a.stats}, " + + s"b.stats: ${b.stats})") } (a.leftNode, b.leftNode) match { case (Some(aNode), Some(bNode)) => checkEqual(aNode, bNode) case (None, None) => - case _ => throw new AssertionError("Only one instance has leftNode defined. " + - s"(a.leftNode: ${a.leftNode}, b.leftNode: ${b.leftNode})") + case _ => + fail("Only one instance has leftNode defined. (a.leftNode: ${a.leftNode}," + + " b.leftNode: ${b.leftNode})") } (a.rightNode, b.rightNode) match { case (Some(aNode: Node), Some(bNode: Node)) => checkEqual(aNode, bNode) case (None, None) => - case _ => throw new AssertionError("Only one instance has rightNode defined. " + - s"(a.rightNode: ${a.rightNode}, b.rightNode: ${b.rightNode})") + case _ => fail("Only one instance has rightNode defined. (a.rightNode: ${a.rightNode}, " + + "b.rightNode: ${b.rightNode})") } } } diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 36a73e3362218..4892819ae9973 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -240,6 +240,17 @@ This file is divided into 3 sections: ]]> + + throw new \w+Error\( + + + JavaConversions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 72505f7fac0c6..6d849869b577a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -206,7 +206,9 @@ class TungstenAggregationIterator( buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) if (buffer == null) { // failed to allocate the first page + // scalastyle:off throwerror throw new SparkOutOfMemoryError("No enough memory for aggregation") + // scalastyle:on throwerror } } processRow(buffer, newInput) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 94f163708832c..64b42c32b8b1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -509,9 +509,9 @@ object TestingUDT { override def sqlType: DataType = CalendarIntervalType override def serialize(obj: IntervalData): Any = - throw new NotImplementedError("Not implemented") + throw new UnsupportedOperationException("Not implemented") override def deserialize(datum: Any): IntervalData = - throw new NotImplementedError("Not implemented") + throw new UnsupportedOperationException("Not implemented") override def userClass: Class[IntervalData] = classOf[IntervalData] } @@ -521,9 +521,10 @@ object TestingUDT { private[sql] class NullUDT extends UserDefinedType[NullData] { override def sqlType: DataType = NullType - override def serialize(obj: NullData): Any = throw new NotImplementedError("Not implemented") + override def serialize(obj: NullData): Any = + throw new UnsupportedOperationException("Not implemented") override def deserialize(datum: Any): NullData = - throw new NotImplementedError("Not implemented") + throw new UnsupportedOperationException("Not implemented") override def userClass: Class[NullData] = classOf[NullData] } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index e4e224df7607f..142ab6170a734 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -790,6 +790,6 @@ private case class DummySparkPlan( override val requiredChildDistribution: Seq[Distribution] = Nil, override val requiredChildOrdering: Seq[Seq[SortOrder]] = Nil ) extends SparkPlan { - override protected def doExecute(): RDD[InternalRow] = throw new NotImplementedError + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException override def output: Seq[Attribute] = Seq.empty } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index bceaf1a9ec061..955c3e3fa6f74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -614,7 +614,7 @@ class TestFileFormat extends TextBasedFileFormat { job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - throw new NotImplementedError("JUST FOR TESTING") + throw new UnsupportedOperationException("JUST FOR TESTING") } override def buildReader( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index f57f07b498261..e8062dbb91e35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -1123,7 +1123,7 @@ class ColumnarBatchSuite extends SparkFunSuite { compareStruct(childFields, r1.getStruct(ordinal, fields.length), r2.getStruct(ordinal), seed) case _ => - throw new NotImplementedError("Not implemented " + field.dataType) + throw new UnsupportedOperationException("Not implemented " + field.dataType) } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 89524cd84ff32..618c036377aee 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -70,7 +70,7 @@ private[streaming] object StateMap { /** Implementation of StateMap interface representing an empty map */ private[streaming] class EmptyStateMap[K, S] extends StateMap[K, S] { override def put(key: K, session: S, updateTime: Long): Unit = { - throw new NotImplementedError("put() should not be called on an EmptyStateMap") + throw new UnsupportedOperationException("put() should not be called on an EmptyStateMap") } override def get(key: K): Option[S] = None override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = Iterator.empty diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 1cf21e8a28033..7376741f64a12 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -31,6 +31,7 @@ import org.apache.commons.io.IOUtils import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat +import org.scalatest.Assertions import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ @@ -532,7 +533,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { /** This is a server to test the network input stream */ -class TestServer(portToBind: Int = 0) extends Logging { +class TestServer(portToBind: Int = 0) extends Logging with Assertions { val queue = new ArrayBlockingQueue[String](100) @@ -592,7 +593,7 @@ class TestServer(portToBind: Int = 0) extends Logging { servingThread.start() if (!waitForStart(10000)) { stop() - throw new AssertionError("Timeout: TestServer cannot start in 10 seconds") + fail("Timeout: TestServer cannot start in 10 seconds") } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index 484f3733e8423..e444132d3a626 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -35,7 +35,7 @@ class StateMapSuite extends SparkFunSuite { test("EmptyStateMap") { val map = new EmptyStateMap[Int, Int] - intercept[scala.NotImplementedError] { + intercept[UnsupportedOperationException] { map.put(1, 1, 1) } assert(map.get(1) === None) From ad853c56788fd32e035369d1fe3d96aaf6c4ef16 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 14 Nov 2018 16:22:23 -0800 Subject: [PATCH 2059/2461] [SPARK-25956] Make Scala 2.12 as default Scala version in Spark 3.0 ## What changes were proposed in this pull request? This PR makes Spark's default Scala version as 2.12, and Scala 2.11 will be the alternative version. This implies that Scala 2.12 will be used by our CI builds including pull request builds. We'll update the Jenkins to include a new compile-only jobs for Scala 2.11 to ensure the code can be still compiled with Scala 2.11. ## How was this patch tested? existing tests Closes #22967 from dbtsai/scala2.12. Authored-by: DB Tsai Signed-off-by: Dongjoon Hyun --- assembly/pom.xml | 4 +-- common/kvstore/pom.xml | 4 +-- common/network-common/pom.xml | 4 +-- common/network-shuffle/pom.xml | 4 +-- common/network-yarn/pom.xml | 4 +-- common/sketch/pom.xml | 4 +-- common/tags/pom.xml | 4 +-- common/unsafe/pom.xml | 4 +-- core/pom.xml | 4 +-- dev/deps/spark-deps-hadoop-2.7 | 36 +++++++++---------- dev/deps/spark-deps-hadoop-3.1 | 36 +++++++++---------- docs/_config.yml | 4 +-- docs/_plugins/copy_api_dirs.rb | 2 +- docs/building-spark.md | 18 +++++----- docs/cloud-integration.md | 2 +- docs/sparkr.md | 2 +- examples/pom.xml | 4 +-- external/avro/pom.xml | 4 +-- external/docker-integration-tests/pom.xml | 4 +-- external/kafka-0-10-assembly/pom.xml | 4 +-- external/kafka-0-10-sql/pom.xml | 4 +-- external/kafka-0-10/pom.xml | 4 +-- external/kinesis-asl-assembly/pom.xml | 4 +-- external/kinesis-asl/pom.xml | 4 +-- external/spark-ganglia-lgpl/pom.xml | 4 +-- graphx/pom.xml | 4 +-- hadoop-cloud/pom.xml | 4 +-- launcher/pom.xml | 4 +-- mllib-local/pom.xml | 4 +-- mllib/pom.xml | 4 +-- pom.xml | 20 ++++++----- project/MimaBuild.scala | 2 +- project/SparkBuild.scala | 14 ++++---- python/run-tests.py | 4 +-- repl/pom.xml | 4 +-- resource-managers/kubernetes/core/pom.xml | 4 +-- .../kubernetes/integration-tests/pom.xml | 4 +-- resource-managers/mesos/pom.xml | 4 +-- resource-managers/yarn/pom.xml | 4 +-- sql/catalyst/pom.xml | 4 +-- sql/core/pom.xml | 4 +-- sql/hive-thriftserver/pom.xml | 4 +-- sql/hive/pom.xml | 4 +-- streaming/pom.xml | 4 +-- tools/pom.xml | 4 +-- 45 files changed, 138 insertions(+), 138 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index b0337e58cca71..68ebfadb668ab 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-assembly_2.11 + spark-assembly_2.12 Spark Project Assembly http://spark.apache.org/ pom diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml index 23a0f49206909..f042a12fda3d2 100644 --- a/common/kvstore/pom.xml +++ b/common/kvstore/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-kvstore_2.11 + spark-kvstore_2.12 jar Spark Project Local DB http://spark.apache.org/ diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 41fcbf0589499..56d01fa0e8b3d 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-network-common_2.11 + spark-network-common_2.12 jar Spark Project Networking http://spark.apache.org/ diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index ff717057bb25d..a6d99813a8501 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-network-shuffle_2.11 + spark-network-shuffle_2.12 jar Spark Project Shuffle Streaming Service http://spark.apache.org/ diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index a1cf761d12d8b..55cdc3140aa08 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-network-yarn_2.11 + spark-network-yarn_2.12 jar Spark Project YARN Shuffle Service http://spark.apache.org/ diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index adbbcb1cb3040..3c3c0d2d96a1c 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-sketch_2.11 + spark-sketch_2.12 jar Spark Project Sketch http://spark.apache.org/ diff --git a/common/tags/pom.xml b/common/tags/pom.xml index f6627beabe84b..883b73a69c9de 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-tags_2.11 + spark-tags_2.12 jar Spark Project Tags http://spark.apache.org/ diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index 62c493a5e1ed8..7e4b08217f1b0 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-unsafe_2.11 + spark-unsafe_2.12 jar Spark Project Unsafe http://spark.apache.org/ diff --git a/core/pom.xml b/core/pom.xml index 5c26f9a5ea3c6..36d93212ba9f9 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-core_2.11 + spark-core_2.12 core diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 01691811fd3eb..c2f5755ca9925 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -22,13 +22,13 @@ avro-1.8.2.jar avro-ipc-1.8.2.jar avro-mapred-1.8.2-hadoop2.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.11-0.13.2.jar -breeze_2.11-0.13.2.jar +breeze-macros_2.12-0.13.2.jar +breeze_2.12-0.13.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar chill-java-0.9.3.jar -chill_2.11-0.9.3.jar +chill_2.12-0.9.3.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -96,7 +96,7 @@ jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar jackson-module-jaxb-annotations-2.9.6.jar jackson-module-paranamer-2.9.6.jar -jackson-module-scala_2.11-2.9.6.jar +jackson-module-scala_2.12-2.9.6.jar jackson-xc-1.9.13.jar janino-3.0.10.jar javassist-3.18.1-GA.jar @@ -122,10 +122,10 @@ jline-2.14.6.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar -json4s-ast_2.11-3.5.3.jar -json4s-core_2.11-3.5.3.jar -json4s-jackson_2.11-3.5.3.jar -json4s-scalap_2.11-3.5.3.jar +json4s-ast_2.12-3.5.3.jar +json4s-core_2.12-3.5.3.jar +json4s-jackson_2.12-3.5.3.jar +json4s-scalap_2.12-3.5.3.jar jsp-api-2.1.jar jsr305-3.0.0.jar jta-1.1.jar @@ -140,8 +140,8 @@ libthrift-0.9.3.jar log4j-1.2.17.jar logging-interceptor-3.9.1.jar lz4-java-1.5.0.jar -machinist_2.11-0.6.1.jar -macro-compat_2.11-1.1.1.jar +machinist_2.12-0.6.1.jar +macro-compat_2.12-1.1.1.jar mesos-1.4.0-shaded-protobuf.jar metrics-core-3.1.5.jar metrics-graphite-3.1.5.jar @@ -170,19 +170,19 @@ parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar py4j-0.10.8.1.jar pyrolite-4.13.jar -scala-compiler-2.11.12.jar -scala-library-2.11.12.jar -scala-parser-combinators_2.11-1.1.0.jar -scala-reflect-2.11.12.jar -scala-xml_2.11-1.0.5.jar -shapeless_2.11-2.3.2.jar +scala-compiler-2.12.7.jar +scala-library-2.12.7.jar +scala-parser-combinators_2.12-1.1.0.jar +scala-reflect-2.12.7.jar +scala-xml_2.12-1.0.5.jar +shapeless_2.12-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snakeyaml-1.18.jar snappy-0.2.jar snappy-java-1.1.7.1.jar -spire-macros_2.11-0.13.0.jar -spire_2.11-0.13.0.jar +spire-macros_2.12-0.13.0.jar +spire_2.12-0.13.0.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index fd46f1491874a..811febf22940d 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -20,13 +20,13 @@ avro-1.8.2.jar avro-ipc-1.8.2.jar avro-mapred-1.8.2-hadoop2.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.11-0.13.2.jar -breeze_2.11-0.13.2.jar +breeze-macros_2.12-0.13.2.jar +breeze_2.12-0.13.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar chill-java-0.9.3.jar -chill_2.11-0.9.3.jar +chill_2.12-0.9.3.jar commons-beanutils-1.9.3.jar commons-cli-1.2.jar commons-codec-1.10.jar @@ -96,7 +96,7 @@ jackson-jaxrs-json-provider-2.7.8.jar jackson-mapper-asl-1.9.13.jar jackson-module-jaxb-annotations-2.9.6.jar jackson-module-paranamer-2.9.6.jar -jackson-module-scala_2.11-2.9.6.jar +jackson-module-scala_2.12-2.9.6.jar janino-3.0.10.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar @@ -123,10 +123,10 @@ joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar json-smart-2.3.jar -json4s-ast_2.11-3.5.3.jar -json4s-core_2.11-3.5.3.jar -json4s-jackson_2.11-3.5.3.jar -json4s-scalap_2.11-3.5.3.jar +json4s-ast_2.12-3.5.3.jar +json4s-core_2.12-3.5.3.jar +json4s-jackson_2.12-3.5.3.jar +json4s-scalap_2.12-3.5.3.jar jsp-api-2.1.jar jsr305-3.0.0.jar jta-1.1.jar @@ -155,8 +155,8 @@ libthrift-0.9.3.jar log4j-1.2.17.jar logging-interceptor-3.9.1.jar lz4-java-1.5.0.jar -machinist_2.11-0.6.1.jar -macro-compat_2.11-1.1.1.jar +machinist_2.12-0.6.1.jar +macro-compat_2.12-1.1.1.jar mesos-1.4.0-shaded-protobuf.jar metrics-core-3.1.5.jar metrics-graphite-3.1.5.jar @@ -189,19 +189,19 @@ protobuf-java-2.5.0.jar py4j-0.10.8.1.jar pyrolite-4.13.jar re2j-1.1.jar -scala-compiler-2.11.12.jar -scala-library-2.11.12.jar -scala-parser-combinators_2.11-1.1.0.jar -scala-reflect-2.11.12.jar -scala-xml_2.11-1.0.5.jar -shapeless_2.11-2.3.2.jar +scala-compiler-2.12.7.jar +scala-library-2.12.7.jar +scala-parser-combinators_2.12-1.1.0.jar +scala-reflect-2.12.7.jar +scala-xml_2.12-1.0.5.jar +shapeless_2.12-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snakeyaml-1.18.jar snappy-0.2.jar snappy-java-1.1.7.1.jar -spire-macros_2.11-0.13.0.jar -spire_2.11-0.13.0.jar +spire-macros_2.12-0.13.0.jar +spire_2.12-0.13.0.jar stax-api-1.0.1.jar stax2-api-3.1.4.jar stream-2.7.0.jar diff --git a/docs/_config.yml b/docs/_config.yml index c3ef98575fa62..649d18bf72b57 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -16,8 +16,8 @@ include: # of Spark, Scala, and Mesos. SPARK_VERSION: 3.0.0-SNAPSHOT SPARK_VERSION_SHORT: 3.0.0 -SCALA_BINARY_VERSION: "2.11" -SCALA_VERSION: "2.11.12" +SCALA_BINARY_VERSION: "2.12" +SCALA_VERSION: "2.12.7" MESOS_VERSION: 1.0.0 SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK SPARK_GITHUB_URL: https://github.com/apache/spark diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 4d0d043a349bb..2d1a9547e3731 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -37,7 +37,7 @@ # Copy over the unified ScalaDoc for all projects to api/scala. # This directory will be copied over to _site when `jekyll` command is run. - source = "../target/scala-2.11/unidoc" + source = "../target/scala-2.12/unidoc" dest = "api/scala" puts "Making directory " + dest diff --git a/docs/building-spark.md b/docs/building-spark.md index 8af90db9a19dd..dfcd53c48e85c 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -96,9 +96,9 @@ It's possible to build Spark submodules using the `mvn -pl` option. For instance, you can build the Spark Streaming module using: - ./build/mvn -pl :spark-streaming_2.11 clean install + ./build/mvn -pl :spark-streaming_{{site.SCALA_BINARY_VERSION}} clean install -where `spark-streaming_2.11` is the `artifactId` as defined in `streaming/pom.xml` file. +where `spark-streaming_{{site.SCALA_BINARY_VERSION}}` is the `artifactId` as defined in `streaming/pom.xml` file. ## Continuous Compilation @@ -230,7 +230,7 @@ Once installed, the `docker` service needs to be started, if not already running On Linux, this can be done by `sudo service docker start`. ./build/mvn install -DskipTests - ./build/mvn test -Pdocker-integration-tests -pl :spark-docker-integration-tests_2.11 + ./build/mvn test -Pdocker-integration-tests -pl :spark-docker-integration-tests_{{site.SCALA_BINARY_VERSION}} or @@ -238,17 +238,17 @@ or ## Change Scala Version -To build Spark using another supported Scala version, please change the major Scala version using (e.g. 2.12): +To build Spark using another supported Scala version, please change the major Scala version using (e.g. 2.11): - ./dev/change-scala-version.sh 2.12 + ./dev/change-scala-version.sh 2.11 -For Maven, please enable the profile (e.g. 2.12): +For Maven, please enable the profile (e.g. 2.11): - ./build/mvn -Pscala-2.12 compile + ./build/mvn -Pscala-2.11 compile -For SBT, specify a complete scala version using (e.g. 2.12.6): +For SBT, specify a complete scala version using (e.g. 2.11.12): - ./build/sbt -Dscala.version=2.12.6 + ./build/sbt -Dscala.version=2.11.12 Otherwise, the sbt-pom-reader plugin will use the `scala.version` specified in the spark-parent pom. diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md index 36753f6373b55..5368e13727334 100644 --- a/docs/cloud-integration.md +++ b/docs/cloud-integration.md @@ -85,7 +85,7 @@ is set to the chosen version of Spark: ... org.apache.spark - hadoop-cloud_2.11 + hadoop-cloud_{{site.SCALA_BINARY_VERSION}} ${spark.version} ... diff --git a/docs/sparkr.md b/docs/sparkr.md index cc6bc6d14853d..acd0e77c4d71a 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -133,7 +133,7 @@ specifying `--packages` with `spark-submit` or `sparkR` commands, or if initiali
      {% highlight r %} -sparkR.session(sparkPackages = "com.databricks:spark-avro_2.11:3.0.0") +sparkR.session(sparkPackages = "org.apache.spark:spark-avro_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION}}") {% endhighlight %}
      diff --git a/examples/pom.xml b/examples/pom.xml index 756c475b4748d..0636406595f6e 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-examples_2.11 + spark-examples_2.12 jar Spark Project Examples http://spark.apache.org/ diff --git a/external/avro/pom.xml b/external/avro/pom.xml index 9d8f319cc9396..ba6f20bfdbf58 100644 --- a/external/avro/pom.xml +++ b/external/avro/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-avro_2.11 + spark-avro_2.12 avro diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index f24254b698080..b39db7540b7d2 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-docker-integration-tests_2.11 + spark-docker-integration-tests_2.12 jar Spark Project Docker Integration Tests http://spark.apache.org/ diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index 4f9c3163b2408..f2dcf5d217a89 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-streaming-kafka-0-10-assembly_2.11 + spark-streaming-kafka-0-10-assembly_2.12 jar Spark Integration for Kafka 0.10 Assembly http://spark.apache.org/ diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index efd0862fb58ee..3f1055a75076f 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-sql-kafka-0-10_2.11 + spark-sql-kafka-0-10_2.12 sql-kafka-0-10 diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index f59f07265a0f4..d75b13da8fb70 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-streaming-kafka-0-10_2.11 + spark-streaming-kafka-0-10_2.12 streaming-kafka-0-10 diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index 0bf4c265939e7..0ce922349ea66 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-streaming-kinesis-asl-assembly_2.11 + spark-streaming-kinesis-asl-assembly_2.12 jar Spark Project Kinesis Assembly http://spark.apache.org/ diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index 0aef25329db99..7d69764b77de7 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -19,13 +19,13 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-streaming-kinesis-asl_2.11 + spark-streaming-kinesis-asl_2.12 jar Spark Kinesis Integration diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index 35a55b70baf33..a23d255f9187c 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -19,13 +19,13 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-ganglia-lgpl_2.11 + spark-ganglia-lgpl_2.12 jar Spark Ganglia Integration diff --git a/graphx/pom.xml b/graphx/pom.xml index 22bc148e068a5..444568a03d6c7 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-graphx_2.11 + spark-graphx_2.12 graphx diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index 3182ab15db5f5..2e5b04622cf1c 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-hadoop-cloud_2.11 + spark-hadoop-cloud_2.12 jar Spark Project Cloud Integration through Hadoop Libraries diff --git a/launcher/pom.xml b/launcher/pom.xml index b1b6126ea5934..e75e8345cd51d 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-launcher_2.11 + spark-launcher_2.12 jar Spark Project Launcher http://spark.apache.org/ diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index ec5f9b0e92c8f..2eab868ac0dc8 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-mllib-local_2.11 + spark-mllib-local_2.12 mllib-local diff --git a/mllib/pom.xml b/mllib/pom.xml index 17ddb87c4d86a..0b17345064a71 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-mllib_2.11 + spark-mllib_2.12 mllib diff --git a/pom.xml b/pom.xml index ee1fd472a3ea7..59e3d0fa772b4 100644 --- a/pom.xml +++ b/pom.xml @@ -25,7 +25,7 @@ 18 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT pom Spark Project Parent POM @@ -154,8 +154,8 @@ 3.4.1 3.2.2 - 2.11.12 - 2.11 + 2.12.7 + 2.12 1.9.13 2.9.6 1.1.7.1 @@ -1998,6 +1998,7 @@ --> org.jboss.netty org.codehaus.groovy + *:*_2.11 *:*_2.10 true @@ -2705,14 +2706,14 @@ - scala-2.11 + scala-2.12 - scala-2.12 + scala-2.11 - 2.12.7 - 2.12 + 2.11.12 + 2.11 @@ -2728,8 +2729,9 @@ - - *:*_2.11 + + org.jboss.netty + org.codehaus.groovy *:*_2.10 diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 79e6745977e5b..10c02103aeddb 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -90,7 +90,7 @@ object MimaBuild { val organization = "org.apache.spark" val previousSparkVersion = "2.4.0" val project = projectRef.project - val fullId = "spark-" + project + "_2.11" + val fullId = "spark-" + project + "_2.12" mimaDefaultSettings ++ Seq(mimaPreviousArtifacts := Set(organization % fullId % previousSparkVersion), mimaBinaryIssueFilters ++= ignoredABIProblems(sparkHome, version.value)) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 5e034f9fe2a95..08e22fab65165 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -95,15 +95,15 @@ object SparkBuild extends PomBuild { } Option(System.getProperty("scala.version")) - .filter(_.startsWith("2.12")) + .filter(_.startsWith("2.11")) .foreach { versionString => - System.setProperty("scala-2.12", "true") + System.setProperty("scala-2.11", "true") } - if (System.getProperty("scala-2.12") == "") { + if (System.getProperty("scala-2.11") == "") { // To activate scala-2.10 profile, replace empty property value to non-empty value // in the same way as Maven which handles -Dname as -Dname=true before executes build process. // see: https://github.com/apache/maven/blob/maven-3.0.4/maven-embedder/src/main/java/org/apache/maven/cli/MavenCli.java#L1082 - System.setProperty("scala-2.12", "true") + System.setProperty("scala-2.11", "true") } profiles } @@ -849,10 +849,10 @@ object TestSettings { import BuildCommons._ private val scalaBinaryVersion = - if (System.getProperty("scala-2.12") == "true") { - "2.12" - } else { + if (System.getProperty("scala-2.11") == "true") { "2.11" + } else { + "2.12" } lazy val settings = Seq ( // Fork new JVMs for tests and set Java options for those diff --git a/python/run-tests.py b/python/run-tests.py index 44305741afe3e..9fd1c9b94ac6f 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -59,9 +59,7 @@ def print_red(text): LOGGER = logging.getLogger() # Find out where the assembly jars are located. -# Later, add back 2.12 to this list: -# for scala in ["2.11", "2.12"]: -for scala in ["2.11"]: +for scala in ["2.11", "2.12"]: build_dir = os.path.join(SPARK_HOME, "assembly", "target", "scala-" + scala) if os.path.isdir(build_dir): SPARK_DIST_CLASSPATH = os.path.join(build_dir, "jars", "*") diff --git a/repl/pom.xml b/repl/pom.xml index fa015b69d45d4..c7de67e41ca94 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-repl_2.11 + spark-repl_2.12 jar Spark Project REPL http://spark.apache.org/ diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index b89ea383bf872..8d594ee8f1478 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -19,12 +19,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../../pom.xml - spark-kubernetes_2.11 + spark-kubernetes_2.12 jar Spark Project Kubernetes diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 301b6fe8eee56..17af0e03f2bbb 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -19,12 +19,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../../pom.xml - spark-kubernetes-integration-tests_2.11 + spark-kubernetes-integration-tests_2.12 1.3.0 1.4.0 diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index 9585bdfafdcf4..7b3aad4d6ce35 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -19,12 +19,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-mesos_2.11 + spark-mesos_2.12 jar Spark Project Mesos diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index e55b814be8465..d18df9955bb1f 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -19,12 +19,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-yarn_2.11 + spark-yarn_2.12 jar Spark Project YARN diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 16ecebf159c1f..20cc5d03fbe52 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-catalyst_2.11 + spark-catalyst_2.12 jar Spark Project Catalyst http://spark.apache.org/ diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 95e98c5444721..ac5f1fc923e7d 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-sql_2.11 + spark-sql_2.12 jar Spark Project SQL http://spark.apache.org/ diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 55e051c3ed1be..4a4629fae2706 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-hive-thriftserver_2.11 + spark-hive-thriftserver_2.12 jar Spark Project Hive Thrift Server http://spark.apache.org/ diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index ef22e2abfb53e..9994689936033 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-hive_2.11 + spark-hive_2.12 jar Spark Project Hive http://spark.apache.org/ diff --git a/streaming/pom.xml b/streaming/pom.xml index f9a5029a8e818..1d1ea469f7d18 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-streaming_2.11 + spark-streaming_2.12 streaming diff --git a/tools/pom.xml b/tools/pom.xml index 247f5a6df4b08..6286fad403c83 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -19,12 +19,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-tools_2.11 + spark-tools_2.12 tools From f6255d7b7cc4cc5d1f4fe0e5e493a1efee22f38f Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 15 Nov 2018 08:33:06 +0800 Subject: [PATCH 2060/2461] [MINOR][SQL] Add disable bucketedRead workaround when throw RuntimeException ## What changes were proposed in this pull request? It will throw `RuntimeException` when read from bucketed table(about 1.7G per bucket file): ![image](https://user-images.githubusercontent.com/5399861/48346889-8041ce00-e6b7-11e8-83b0-ead83fb15821.png) Default(enable bucket read): ![image](https://user-images.githubusercontent.com/5399861/48347084-2c83b480-e6b8-11e8-913a-9cafc043e9e4.png) Disable bucket read: ![image](https://user-images.githubusercontent.com/5399861/48347099-3a393a00-e6b8-11e8-94af-cb814e1ba277.png) The reason is that each bucket file is too big. a workaround is disable bucket read. This PR add this workaround to Spark. ## How was this patch tested? manual tests Closes #23014 from wangyum/anotherWorkaround. Authored-by: Yuming Wang Signed-off-by: hyukjinkwon --- .../spark/sql/execution/vectorized/WritableColumnVector.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index b0e119d658cb4..4f5e72c1326ac 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -101,10 +101,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { String message = "Cannot reserve additional contiguous bytes in the vectorized reader (" + (requiredCapacity >= 0 ? "requested " + requiredCapacity + " bytes" : "integer overflow") + "). As a workaround, you can reduce the vectorized reader batch size, or disable the " + - "vectorized reader. For parquet file format, refer to " + + "vectorized reader, or disable " + SQLConf.BUCKETING_ENABLED().key() + " if you read " + + "from bucket table. For Parquet file format, refer to " + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() + " (default " + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().defaultValueString() + - ") and " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + "; for orc file format, " + + ") and " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + "; for ORC file format, " + "refer to " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() + " (default " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().defaultValueString() + ") and " + SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + "."; From 03306a6df39c9fd6cb581401c13c4dfc6bbd632e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 15 Nov 2018 12:30:52 +0800 Subject: [PATCH 2061/2461] [SPARK-26036][PYTHON] Break large tests.py files into smaller files MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR continues to break down a big large file into smaller files. See https://github.com/apache/spark/pull/23021. It targets to follow https://github.com/numpy/numpy/tree/master/numpy. Basically this PR proposes to break down `pyspark/tests.py` into ...: ``` pyspark ... ├── testing ... │   └── utils.py ├── tests │   ├── __init__.py │   ├── test_appsubmit.py │   ├── test_broadcast.py │   ├── test_conf.py │   ├── test_context.py │   ├── test_daemon.py │   ├── test_join.py │   ├── test_profiler.py │   ├── test_rdd.py │   ├── test_readwrite.py │   ├── test_serializers.py │   ├── test_shuffle.py │   ├── test_taskcontext.py │   ├── test_util.py │   └── test_worker.py ... ``` ## How was this patch tested? Existing tests should cover. `cd python` and .`/run-tests-with-coverage`. Manually checked they are actually being ran. Each test (not officially) can be ran via: ```bash SPARK_TESTING=1 ./bin/pyspark pyspark.tests.test_context ``` Note that if you're using Mac and Python 3, you might have to `OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES`. Closes #23033 from HyukjinKwon/SPARK-26036. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- dev/sparktestsupport/modules.py | 19 +- python/pyspark/ml/tests.py | 2 +- python/pyspark/sql/tests/test_appsubmit.py | 7 +- python/pyspark/sql/tests/test_arrow.py | 7 +- python/pyspark/sql/tests/test_catalog.py | 5 +- python/pyspark/sql/tests/test_column.py | 5 +- python/pyspark/sql/tests/test_conf.py | 5 +- python/pyspark/sql/tests/test_context.py | 7 +- python/pyspark/sql/tests/test_dataframe.py | 7 +- python/pyspark/sql/tests/test_datasources.py | 5 +- python/pyspark/sql/tests/test_functions.py | 5 +- python/pyspark/sql/tests/test_group.py | 5 +- python/pyspark/sql/tests/test_pandas_udf.py | 7 +- .../sql/tests/test_pandas_udf_grouped_agg.py | 7 +- .../sql/tests/test_pandas_udf_grouped_map.py | 7 +- .../sql/tests/test_pandas_udf_scalar.py | 7 +- .../sql/tests/test_pandas_udf_window.py | 7 +- python/pyspark/sql/tests/test_readwriter.py | 5 +- python/pyspark/sql/tests/test_serde.py | 5 +- python/pyspark/sql/tests/test_session.py | 7 +- python/pyspark/sql/tests/test_streaming.py | 5 +- python/pyspark/sql/tests/test_types.py | 5 +- python/pyspark/sql/tests/test_udf.py | 7 +- python/pyspark/sql/tests/test_utils.py | 5 +- python/pyspark/test_serializers.py | 90 - python/pyspark/testing/sqlutils.py | 2 +- python/pyspark/testing/utils.py | 102 + python/pyspark/tests.py | 2502 ----------------- python/pyspark/tests/__init__.py | 16 + python/pyspark/tests/test_appsubmit.py | 248 ++ python/pyspark/{ => tests}/test_broadcast.py | 24 +- python/pyspark/tests/test_conf.py | 43 + python/pyspark/tests/test_context.py | 258 ++ python/pyspark/tests/test_daemon.py | 80 + python/pyspark/tests/test_join.py | 69 + python/pyspark/tests/test_profiler.py | 112 + python/pyspark/tests/test_rdd.py | 739 +++++ python/pyspark/tests/test_readwrite.py | 499 ++++ python/pyspark/tests/test_serializers.py | 237 ++ python/pyspark/tests/test_shuffle.py | 181 ++ python/pyspark/tests/test_taskcontext.py | 161 ++ python/pyspark/tests/test_util.py | 86 + python/pyspark/tests/test_worker.py | 157 ++ 43 files changed, 3093 insertions(+), 2666 deletions(-) delete mode 100644 python/pyspark/test_serializers.py create mode 100644 python/pyspark/testing/utils.py delete mode 100644 python/pyspark/tests.py create mode 100644 python/pyspark/tests/__init__.py create mode 100644 python/pyspark/tests/test_appsubmit.py rename python/pyspark/{ => tests}/test_broadcast.py (91%) create mode 100644 python/pyspark/tests/test_conf.py create mode 100644 python/pyspark/tests/test_context.py create mode 100644 python/pyspark/tests/test_daemon.py create mode 100644 python/pyspark/tests/test_join.py create mode 100644 python/pyspark/tests/test_profiler.py create mode 100644 python/pyspark/tests/test_rdd.py create mode 100644 python/pyspark/tests/test_readwrite.py create mode 100644 python/pyspark/tests/test_serializers.py create mode 100644 python/pyspark/tests/test_shuffle.py create mode 100644 python/pyspark/tests/test_taskcontext.py create mode 100644 python/pyspark/tests/test_util.py create mode 100644 python/pyspark/tests/test_worker.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 9dbe4e4f20e03..d5fcc060616f2 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -310,6 +310,7 @@ def __hash__(self): "python/(?!pyspark/(ml|mllib|sql|streaming))" ], python_test_goals=[ + # doctests "pyspark.rdd", "pyspark.context", "pyspark.conf", @@ -318,10 +319,22 @@ def __hash__(self): "pyspark.serializers", "pyspark.profiler", "pyspark.shuffle", - "pyspark.tests", - "pyspark.test_broadcast", - "pyspark.test_serializers", "pyspark.util", + # unittests + "pyspark.tests.test_appsubmit", + "pyspark.tests.test_broadcast", + "pyspark.tests.test_conf", + "pyspark.tests.test_context", + "pyspark.tests.test_daemon", + "pyspark.tests.test_join", + "pyspark.tests.test_profiler", + "pyspark.tests.test_rdd", + "pyspark.tests.test_readwrite", + "pyspark.tests.test_serializers", + "pyspark.tests.test_shuffle", + "pyspark.tests.test_taskcontext", + "pyspark.tests.test_util", + "pyspark.tests.test_worker", ] ) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 821e037af0271..2b4b7315d98c0 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -72,7 +72,7 @@ from pyspark.sql.functions import rand from pyspark.sql.types import DoubleType, IntegerType from pyspark.storagelevel import * -from pyspark.tests import QuietTest, ReusedPySparkTestCase as PySparkTestCase +from pyspark.testing.utils import QuietTest, ReusedPySparkTestCase as PySparkTestCase ser = PickleSerializer() diff --git a/python/pyspark/sql/tests/test_appsubmit.py b/python/pyspark/sql/tests/test_appsubmit.py index 3c71151e396b9..43abcde7785d8 100644 --- a/python/pyspark/sql/tests/test_appsubmit.py +++ b/python/pyspark/sql/tests/test_appsubmit.py @@ -22,7 +22,7 @@ import py4j from pyspark import SparkContext -from pyspark.tests import SparkSubmitTests +from pyspark.tests.test_appsubmit import SparkSubmitTests class HiveSparkSubmitTests(SparkSubmitTests): @@ -91,6 +91,7 @@ def test_hivecontext(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 44f703569703a..6e75e82d58009 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -26,7 +26,7 @@ from pyspark.sql.types import * from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message -from pyspark.tests import QuietTest +from pyspark.testing.utils import QuietTest from pyspark.util import _exception_message @@ -394,6 +394,7 @@ def conf(cls): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_catalog.py b/python/pyspark/sql/tests/test_catalog.py index 23d25770d4b01..873405a2c6aa3 100644 --- a/python/pyspark/sql/tests/test_catalog.py +++ b/python/pyspark/sql/tests/test_catalog.py @@ -194,6 +194,7 @@ def test_list_columns(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index faadde9527f6f..01d4f7e223a41 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -152,6 +152,7 @@ def test_bitwise_operations(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_conf.py b/python/pyspark/sql/tests/test_conf.py index f5d68a8f48851..53ac4a66f4645 100644 --- a/python/pyspark/sql/tests/test_conf.py +++ b/python/pyspark/sql/tests/test_conf.py @@ -50,6 +50,7 @@ def test_conf(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_context.py b/python/pyspark/sql/tests/test_context.py index d9d408a0b9663..918f4ad2d62f4 100644 --- a/python/pyspark/sql/tests/test_context.py +++ b/python/pyspark/sql/tests/test_context.py @@ -25,7 +25,7 @@ from pyspark import HiveContext, Row from pyspark.sql.types import * from pyspark.sql.window import Window -from pyspark.tests import ReusedPySparkTestCase +from pyspark.testing.utils import ReusedPySparkTestCase class HiveContextSQLTests(ReusedPySparkTestCase): @@ -258,6 +258,7 @@ def range_frame_match(): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index eba00b5687d96..908d400e00092 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -25,7 +25,7 @@ from pyspark.sql.utils import AnalysisException, IllegalArgumentException from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils, have_pyarrow, have_pandas, \ pandas_requirement_message, pyarrow_requirement_message -from pyspark.tests import QuietTest +from pyspark.testing.utils import QuietTest class DataFrameTests(ReusedSQLTestCase): @@ -732,6 +732,7 @@ def test_query_execution_listener_on_collect_with_arrow(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_datasources.py b/python/pyspark/sql/tests/test_datasources.py index b82737855a760..5579620bc2be1 100644 --- a/python/pyspark/sql/tests/test_datasources.py +++ b/python/pyspark/sql/tests/test_datasources.py @@ -165,6 +165,7 @@ def test_ignore_column_of_all_nulls(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index f0b59e86af178..fe6660272e323 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -273,6 +273,7 @@ def test_sort_with_nulls_order(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_group.py b/python/pyspark/sql/tests/test_group.py index 076899f377598..6de1b8ea0b3ce 100644 --- a/python/pyspark/sql/tests/test_group.py +++ b/python/pyspark/sql/tests/test_group.py @@ -40,6 +40,7 @@ def test_aggregator(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_pandas_udf.py b/python/pyspark/sql/tests/test_pandas_udf.py index 54a34a7dc5b94..c4b5478a7e893 100644 --- a/python/pyspark/sql/tests/test_pandas_udf.py +++ b/python/pyspark/sql/tests/test_pandas_udf.py @@ -21,7 +21,7 @@ from pyspark.sql.utils import ParseException from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message -from pyspark.tests import QuietTest +from pyspark.testing.utils import QuietTest @unittest.skipIf( @@ -211,6 +211,7 @@ def foofoo(x, y): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py index bca47cc3a69bf..5383704434c85 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py +++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py @@ -21,7 +21,7 @@ from pyspark.sql.utils import AnalysisException from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message -from pyspark.tests import QuietTest +from pyspark.testing.utils import QuietTest @unittest.skipIf( @@ -498,6 +498,7 @@ def test_register_vectorized_udf_basic(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py index 4d443887c0ed2..bfecc071386e9 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py @@ -22,7 +22,7 @@ from pyspark.sql.types import * from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message -from pyspark.tests import QuietTest +from pyspark.testing.utils import QuietTest @unittest.skipIf( @@ -525,6 +525,7 @@ def test_mixed_scalar_udfs_followed_by_grouby_apply(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index 394ee978dcaed..2f585a3725988 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -28,7 +28,7 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled,\ test_not_compiled_message, have_pandas, have_pyarrow, pandas_requirement_message, \ pyarrow_requirement_message -from pyspark.tests import QuietTest +from pyspark.testing.utils import QuietTest @unittest.skipIf( @@ -802,6 +802,7 @@ def test_datasource_with_udf(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_pandas_udf_window.py b/python/pyspark/sql/tests/test_pandas_udf_window.py index 26e7993f1d9d9..f0e6d2696df62 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_window.py +++ b/python/pyspark/sql/tests/test_pandas_udf_window.py @@ -21,7 +21,7 @@ from pyspark.sql.window import Window from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message -from pyspark.tests import QuietTest +from pyspark.testing.utils import QuietTest @unittest.skipIf( @@ -257,6 +257,7 @@ def test_invalid_args(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py index 064d308b552c1..2f8712d7631f5 100644 --- a/python/pyspark/sql/tests/test_readwriter.py +++ b/python/pyspark/sql/tests/test_readwriter.py @@ -148,6 +148,7 @@ def count_bucketed_cols(names, table="pyspark_bucket"): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_serde.py b/python/pyspark/sql/tests/test_serde.py index 5ea0636dcbb6f..8707f46b6a25a 100644 --- a/python/pyspark/sql/tests/test_serde.py +++ b/python/pyspark/sql/tests/test_serde.py @@ -133,6 +133,7 @@ def test_BinaryType_serialization(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py index b81104796fda8..c6b9e0b2ca554 100644 --- a/python/pyspark/sql/tests/test_session.py +++ b/python/pyspark/sql/tests/test_session.py @@ -21,7 +21,7 @@ from pyspark import SparkConf, SparkContext from pyspark.sql import SparkSession, SQLContext, Row from pyspark.testing.sqlutils import ReusedSQLTestCase -from pyspark.tests import PySparkTestCase +from pyspark.testing.utils import PySparkTestCase class SparkSessionTests(ReusedSQLTestCase): @@ -315,6 +315,7 @@ def test_use_custom_class_for_extensions(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_streaming.py b/python/pyspark/sql/tests/test_streaming.py index cc0cab4881dc8..4b71759f74a55 100644 --- a/python/pyspark/sql/tests/test_streaming.py +++ b/python/pyspark/sql/tests/test_streaming.py @@ -561,6 +561,7 @@ def collectBatch(df, id): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 3b32c58a86639..fb673f2a385ef 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -939,6 +939,7 @@ def __init__(self, **kwargs): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index 630b21517712f..d2dfb52f54475 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -27,7 +27,7 @@ from pyspark.sql.types import * from pyspark.sql.utils import AnalysisException from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message -from pyspark.tests import QuietTest +from pyspark.testing.utils import QuietTest class UDFTests(ReusedSQLTestCase): @@ -649,6 +649,7 @@ def test_udf_init_shouldnt_initialize_context(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py index 63a8614d2effd..5bb921da5c2f3 100644 --- a/python/pyspark/sql/tests/test_utils.py +++ b/python/pyspark/sql/tests/test_utils.py @@ -49,6 +49,7 @@ def test_capture_illegalargument_exception(self): try: import xmlrunner - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: - unittest.main(verbosity=2) + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/test_serializers.py b/python/pyspark/test_serializers.py deleted file mode 100644 index 5b43729f9ebb1..0000000000000 --- a/python/pyspark/test_serializers.py +++ /dev/null @@ -1,90 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import io -import math -import struct -import sys -import unittest - -try: - import xmlrunner -except ImportError: - xmlrunner = None - -from pyspark import serializers - - -def read_int(b): - return struct.unpack("!i", b)[0] - - -def write_int(i): - return struct.pack("!i", i) - - -class SerializersTest(unittest.TestCase): - - def test_chunked_stream(self): - original_bytes = bytearray(range(100)) - for data_length in [1, 10, 100]: - for buffer_length in [1, 2, 3, 5, 20, 99, 100, 101, 500]: - dest = ByteArrayOutput() - stream_out = serializers.ChunkedStream(dest, buffer_length) - stream_out.write(original_bytes[:data_length]) - stream_out.close() - num_chunks = int(math.ceil(float(data_length) / buffer_length)) - # length for each chunk, and a final -1 at the very end - exp_size = (num_chunks + 1) * 4 + data_length - self.assertEqual(len(dest.buffer), exp_size) - dest_pos = 0 - data_pos = 0 - for chunk_idx in range(num_chunks): - chunk_length = read_int(dest.buffer[dest_pos:(dest_pos + 4)]) - if chunk_idx == num_chunks - 1: - exp_length = data_length % buffer_length - if exp_length == 0: - exp_length = buffer_length - else: - exp_length = buffer_length - self.assertEqual(chunk_length, exp_length) - dest_pos += 4 - dest_chunk = dest.buffer[dest_pos:dest_pos + chunk_length] - orig_chunk = original_bytes[data_pos:data_pos + chunk_length] - self.assertEqual(dest_chunk, orig_chunk) - dest_pos += chunk_length - data_pos += chunk_length - # ends with a -1 - self.assertEqual(dest.buffer[-4:], write_int(-1)) - - -class ByteArrayOutput(object): - def __init__(self): - self.buffer = bytearray() - - def write(self, b): - self.buffer += b - - def close(self): - pass - -if __name__ == '__main__': - from pyspark.test_serializers import * - if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) - else: - unittest.main(verbosity=2) diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py index 3951776554847..afc40ccf4139d 100644 --- a/python/pyspark/testing/sqlutils.py +++ b/python/pyspark/testing/sqlutils.py @@ -23,7 +23,7 @@ from pyspark.sql import SparkSession from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row -from pyspark.tests import ReusedPySparkTestCase +from pyspark.testing.utils import ReusedPySparkTestCase from pyspark.util import _exception_message diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py new file mode 100644 index 0000000000000..7df0acae026f3 --- /dev/null +++ b/python/pyspark/testing/utils.py @@ -0,0 +1,102 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import struct +import sys +import unittest + +from pyspark import SparkContext, SparkConf + + +have_scipy = False +have_numpy = False +try: + import scipy.sparse + have_scipy = True +except: + # No SciPy, but that's okay, we'll skip those tests + pass +try: + import numpy as np + have_numpy = True +except: + # No NumPy, but that's okay, we'll skip those tests + pass + + +SPARK_HOME = os.environ["SPARK_HOME"] + + +def read_int(b): + return struct.unpack("!i", b)[0] + + +def write_int(i): + return struct.pack("!i", i) + + +class QuietTest(object): + def __init__(self, sc): + self.log4j = sc._jvm.org.apache.log4j + + def __enter__(self): + self.old_level = self.log4j.LogManager.getRootLogger().getLevel() + self.log4j.LogManager.getRootLogger().setLevel(self.log4j.Level.FATAL) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.log4j.LogManager.getRootLogger().setLevel(self.old_level) + + +class PySparkTestCase(unittest.TestCase): + + def setUp(self): + self._old_sys_path = list(sys.path) + class_name = self.__class__.__name__ + self.sc = SparkContext('local[4]', class_name) + + def tearDown(self): + self.sc.stop() + sys.path = self._old_sys_path + + +class ReusedPySparkTestCase(unittest.TestCase): + + @classmethod + def conf(cls): + """ + Override this in subclasses to supply a more specific conf + """ + return SparkConf() + + @classmethod + def setUpClass(cls): + cls.sc = SparkContext('local[4]', cls.__name__, conf=cls.conf()) + + @classmethod + def tearDownClass(cls): + cls.sc.stop() + + +class ByteArrayOutput(object): + def __init__(self): + self.buffer = bytearray() + + def write(self, b): + self.buffer += b + + def close(self): + pass diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py deleted file mode 100644 index 131c51e108cad..0000000000000 --- a/python/pyspark/tests.py +++ /dev/null @@ -1,2502 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Unit tests for PySpark; additional tests are implemented as doctests in -individual modules. -""" - -from array import array -from glob import glob -import os -import re -import shutil -import subprocess -import sys -import tempfile -import time -import zipfile -import random -import threading -import hashlib - -from py4j.protocol import Py4JJavaError -try: - import xmlrunner -except ImportError: - xmlrunner = None - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - if sys.version_info[0] >= 3: - xrange = range - basestring = str - -if sys.version >= "3": - from io import StringIO -else: - from StringIO import StringIO - - -from pyspark import keyword_only -from pyspark.conf import SparkConf -from pyspark.context import SparkContext -from pyspark.rdd import RDD -from pyspark.files import SparkFiles -from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ - CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer, \ - PairDeserializer, CartesianDeserializer, AutoBatchedSerializer, AutoSerializer, \ - FlattenedValuesSerializer -from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter -from pyspark import shuffle -from pyspark.profiler import BasicProfiler -from pyspark.taskcontext import BarrierTaskContext, TaskContext - -_have_scipy = False -_have_numpy = False -try: - import scipy.sparse - _have_scipy = True -except: - # No SciPy, but that's okay, we'll skip those tests - pass -try: - import numpy as np - _have_numpy = True -except: - # No NumPy, but that's okay, we'll skip those tests - pass - - -SPARK_HOME = os.environ["SPARK_HOME"] - - -class MergerTests(unittest.TestCase): - - def setUp(self): - self.N = 1 << 12 - self.l = [i for i in xrange(self.N)] - self.data = list(zip(self.l, self.l)) - self.agg = Aggregator(lambda x: [x], - lambda x, y: x.append(y) or x, - lambda x, y: x.extend(y) or x) - - def test_small_dataset(self): - m = ExternalMerger(self.agg, 1000) - m.mergeValues(self.data) - self.assertEqual(m.spills, 0) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - - m = ExternalMerger(self.agg, 1000) - m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), self.data)) - self.assertEqual(m.spills, 0) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - - def test_medium_dataset(self): - m = ExternalMerger(self.agg, 20) - m.mergeValues(self.data) - self.assertTrue(m.spills >= 1) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - - m = ExternalMerger(self.agg, 10) - m.mergeCombiners(map(lambda x_y2: (x_y2[0], [x_y2[1]]), self.data * 3)) - self.assertTrue(m.spills >= 1) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N)) * 3) - - def test_huge_dataset(self): - m = ExternalMerger(self.agg, 5, partitions=3) - m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10)) - self.assertTrue(m.spills >= 1) - self.assertEqual(sum(len(v) for k, v in m.items()), - self.N * 10) - m._cleanup() - - def test_group_by_key(self): - - def gen_data(N, step): - for i in range(1, N + 1, step): - for j in range(i): - yield (i, [j]) - - def gen_gs(N, step=1): - return shuffle.GroupByKey(gen_data(N, step)) - - self.assertEqual(1, len(list(gen_gs(1)))) - self.assertEqual(2, len(list(gen_gs(2)))) - self.assertEqual(100, len(list(gen_gs(100)))) - self.assertEqual(list(range(1, 101)), [k for k, _ in gen_gs(100)]) - self.assertTrue(all(list(range(k)) == list(vs) for k, vs in gen_gs(100))) - - for k, vs in gen_gs(50002, 10000): - self.assertEqual(k, len(vs)) - self.assertEqual(list(range(k)), list(vs)) - - ser = PickleSerializer() - l = ser.loads(ser.dumps(list(gen_gs(50002, 30000)))) - for k, vs in l: - self.assertEqual(k, len(vs)) - self.assertEqual(list(range(k)), list(vs)) - - def test_stopiteration_is_raised(self): - - def stopit(*args, **kwargs): - raise StopIteration() - - def legit_create_combiner(x): - return [x] - - def legit_merge_value(x, y): - return x.append(y) or x - - def legit_merge_combiners(x, y): - return x.extend(y) or x - - data = [(x % 2, x) for x in range(100)] - - # wrong create combiner - m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) - with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: - m.mergeValues(data) - - # wrong merge value - m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) - with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: - m.mergeValues(data) - - # wrong merge combiners - m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) - with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: - m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) - - -class SorterTests(unittest.TestCase): - def test_in_memory_sort(self): - l = list(range(1024)) - random.shuffle(l) - sorter = ExternalSorter(1024) - self.assertEqual(sorted(l), list(sorter.sorted(l))) - self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) - self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) - self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), - list(sorter.sorted(l, key=lambda x: -x, reverse=True))) - - def test_external_sort(self): - class CustomizedSorter(ExternalSorter): - def _next_limit(self): - return self.memory_limit - l = list(range(1024)) - random.shuffle(l) - sorter = CustomizedSorter(1) - self.assertEqual(sorted(l), list(sorter.sorted(l))) - self.assertGreater(shuffle.DiskBytesSpilled, 0) - last = shuffle.DiskBytesSpilled - self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) - self.assertGreater(shuffle.DiskBytesSpilled, last) - last = shuffle.DiskBytesSpilled - self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) - self.assertGreater(shuffle.DiskBytesSpilled, last) - last = shuffle.DiskBytesSpilled - self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), - list(sorter.sorted(l, key=lambda x: -x, reverse=True))) - self.assertGreater(shuffle.DiskBytesSpilled, last) - - def test_external_sort_in_rdd(self): - conf = SparkConf().set("spark.python.worker.memory", "1m") - sc = SparkContext(conf=conf) - l = list(range(10240)) - random.shuffle(l) - rdd = sc.parallelize(l, 4) - self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) - sc.stop() - - -class SerializationTestCase(unittest.TestCase): - - def test_namedtuple(self): - from collections import namedtuple - from pickle import dumps, loads - P = namedtuple("P", "x y") - p1 = P(1, 3) - p2 = loads(dumps(p1, 2)) - self.assertEqual(p1, p2) - - from pyspark.cloudpickle import dumps - P2 = loads(dumps(P)) - p3 = P2(1, 3) - self.assertEqual(p1, p3) - - def test_itemgetter(self): - from operator import itemgetter - ser = CloudPickleSerializer() - d = range(10) - getter = itemgetter(1) - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - - getter = itemgetter(0, 3) - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - - def test_function_module_name(self): - ser = CloudPickleSerializer() - func = lambda x: x - func2 = ser.loads(ser.dumps(func)) - self.assertEqual(func.__module__, func2.__module__) - - def test_attrgetter(self): - from operator import attrgetter - ser = CloudPickleSerializer() - - class C(object): - def __getattr__(self, item): - return item - d = C() - getter = attrgetter("a") - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - getter = attrgetter("a", "b") - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - - d.e = C() - getter = attrgetter("e.a") - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - getter = attrgetter("e.a", "e.b") - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - - # Regression test for SPARK-3415 - def test_pickling_file_handles(self): - # to be corrected with SPARK-11160 - if not xmlrunner: - ser = CloudPickleSerializer() - out1 = sys.stderr - out2 = ser.loads(ser.dumps(out1)) - self.assertEqual(out1, out2) - - def test_func_globals(self): - - class Unpicklable(object): - def __reduce__(self): - raise Exception("not picklable") - - global exit - exit = Unpicklable() - - ser = CloudPickleSerializer() - self.assertRaises(Exception, lambda: ser.dumps(exit)) - - def foo(): - sys.exit(0) - - self.assertTrue("exit" in foo.__code__.co_names) - ser.dumps(foo) - - def test_compressed_serializer(self): - ser = CompressedSerializer(PickleSerializer()) - try: - from StringIO import StringIO - except ImportError: - from io import BytesIO as StringIO - io = StringIO() - ser.dump_stream(["abc", u"123", range(5)], io) - io.seek(0) - self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io))) - ser.dump_stream(range(1000), io) - io.seek(0) - self.assertEqual(["abc", u"123", range(5)] + list(range(1000)), list(ser.load_stream(io))) - io.close() - - def test_hash_serializer(self): - hash(NoOpSerializer()) - hash(UTF8Deserializer()) - hash(PickleSerializer()) - hash(MarshalSerializer()) - hash(AutoSerializer()) - hash(BatchedSerializer(PickleSerializer())) - hash(AutoBatchedSerializer(MarshalSerializer())) - hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer())) - hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer())) - hash(CompressedSerializer(PickleSerializer())) - hash(FlattenedValuesSerializer(PickleSerializer())) - - -class QuietTest(object): - def __init__(self, sc): - self.log4j = sc._jvm.org.apache.log4j - - def __enter__(self): - self.old_level = self.log4j.LogManager.getRootLogger().getLevel() - self.log4j.LogManager.getRootLogger().setLevel(self.log4j.Level.FATAL) - - def __exit__(self, exc_type, exc_val, exc_tb): - self.log4j.LogManager.getRootLogger().setLevel(self.old_level) - - -class PySparkTestCase(unittest.TestCase): - - def setUp(self): - self._old_sys_path = list(sys.path) - class_name = self.__class__.__name__ - self.sc = SparkContext('local[4]', class_name) - - def tearDown(self): - self.sc.stop() - sys.path = self._old_sys_path - - -class ReusedPySparkTestCase(unittest.TestCase): - - @classmethod - def conf(cls): - """ - Override this in subclasses to supply a more specific conf - """ - return SparkConf() - - @classmethod - def setUpClass(cls): - cls.sc = SparkContext('local[4]', cls.__name__, conf=cls.conf()) - - @classmethod - def tearDownClass(cls): - cls.sc.stop() - - -class CheckpointTests(ReusedPySparkTestCase): - - def setUp(self): - self.checkpointDir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(self.checkpointDir.name) - self.sc.setCheckpointDir(self.checkpointDir.name) - - def tearDown(self): - shutil.rmtree(self.checkpointDir.name) - - def test_basic_checkpointing(self): - parCollection = self.sc.parallelize([1, 2, 3, 4]) - flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) - - self.assertFalse(flatMappedRDD.isCheckpointed()) - self.assertTrue(flatMappedRDD.getCheckpointFile() is None) - - flatMappedRDD.checkpoint() - result = flatMappedRDD.collect() - time.sleep(1) # 1 second - self.assertTrue(flatMappedRDD.isCheckpointed()) - self.assertEqual(flatMappedRDD.collect(), result) - self.assertEqual("file:" + self.checkpointDir.name, - os.path.dirname(os.path.dirname(flatMappedRDD.getCheckpointFile()))) - - def test_checkpoint_and_restore(self): - parCollection = self.sc.parallelize([1, 2, 3, 4]) - flatMappedRDD = parCollection.flatMap(lambda x: [x]) - - self.assertFalse(flatMappedRDD.isCheckpointed()) - self.assertTrue(flatMappedRDD.getCheckpointFile() is None) - - flatMappedRDD.checkpoint() - flatMappedRDD.count() # forces a checkpoint to be computed - time.sleep(1) # 1 second - - self.assertTrue(flatMappedRDD.getCheckpointFile() is not None) - recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(), - flatMappedRDD._jrdd_deserializer) - self.assertEqual([1, 2, 3, 4], recovered.collect()) - - -class LocalCheckpointTests(ReusedPySparkTestCase): - - def test_basic_localcheckpointing(self): - parCollection = self.sc.parallelize([1, 2, 3, 4]) - flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) - - self.assertFalse(flatMappedRDD.isCheckpointed()) - self.assertFalse(flatMappedRDD.isLocallyCheckpointed()) - - flatMappedRDD.localCheckpoint() - result = flatMappedRDD.collect() - time.sleep(1) # 1 second - self.assertTrue(flatMappedRDD.isCheckpointed()) - self.assertTrue(flatMappedRDD.isLocallyCheckpointed()) - self.assertEqual(flatMappedRDD.collect(), result) - - -class AddFileTests(PySparkTestCase): - - def test_add_py_file(self): - # To ensure that we're actually testing addPyFile's effects, check that - # this job fails due to `userlibrary` not being on the Python path: - # disable logging in log4j temporarily - def func(x): - from userlibrary import UserClass - return UserClass().hello() - with QuietTest(self.sc): - self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first) - - # Add the file, so the job should now succeed: - path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") - self.sc.addPyFile(path) - res = self.sc.parallelize(range(2)).map(func).first() - self.assertEqual("Hello World!", res) - - def test_add_file_locally(self): - path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - self.sc.addFile(path) - download_path = SparkFiles.get("hello.txt") - self.assertNotEqual(path, download_path) - with open(download_path) as test_file: - self.assertEqual("Hello World!\n", test_file.readline()) - - def test_add_file_recursively_locally(self): - path = os.path.join(SPARK_HOME, "python/test_support/hello") - self.sc.addFile(path, True) - download_path = SparkFiles.get("hello") - self.assertNotEqual(path, download_path) - with open(download_path + "/hello.txt") as test_file: - self.assertEqual("Hello World!\n", test_file.readline()) - with open(download_path + "/sub_hello/sub_hello.txt") as test_file: - self.assertEqual("Sub Hello World!\n", test_file.readline()) - - def test_add_py_file_locally(self): - # To ensure that we're actually testing addPyFile's effects, check that - # this fails due to `userlibrary` not being on the Python path: - def func(): - from userlibrary import UserClass - self.assertRaises(ImportError, func) - path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") - self.sc.addPyFile(path) - from userlibrary import UserClass - self.assertEqual("Hello World!", UserClass().hello()) - - def test_add_egg_file_locally(self): - # To ensure that we're actually testing addPyFile's effects, check that - # this fails due to `userlibrary` not being on the Python path: - def func(): - from userlib import UserClass - self.assertRaises(ImportError, func) - path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip") - self.sc.addPyFile(path) - from userlib import UserClass - self.assertEqual("Hello World from inside a package!", UserClass().hello()) - - def test_overwrite_system_module(self): - self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py")) - - import SimpleHTTPServer - self.assertEqual("My Server", SimpleHTTPServer.__name__) - - def func(x): - import SimpleHTTPServer - return SimpleHTTPServer.__name__ - - self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect()) - - -class TaskContextTests(PySparkTestCase): - - def setUp(self): - self._old_sys_path = list(sys.path) - class_name = self.__class__.__name__ - # Allow retries even though they are normally disabled in local mode - self.sc = SparkContext('local[4, 2]', class_name) - - def test_stage_id(self): - """Test the stage ids are available and incrementing as expected.""" - rdd = self.sc.parallelize(range(10)) - stage1 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0] - stage2 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0] - # Test using the constructor directly rather than the get() - stage3 = rdd.map(lambda x: TaskContext().stageId()).take(1)[0] - self.assertEqual(stage1 + 1, stage2) - self.assertEqual(stage1 + 2, stage3) - self.assertEqual(stage2 + 1, stage3) - - def test_partition_id(self): - """Test the partition id.""" - rdd1 = self.sc.parallelize(range(10), 1) - rdd2 = self.sc.parallelize(range(10), 2) - pids1 = rdd1.map(lambda x: TaskContext.get().partitionId()).collect() - pids2 = rdd2.map(lambda x: TaskContext.get().partitionId()).collect() - self.assertEqual(0, pids1[0]) - self.assertEqual(0, pids1[9]) - self.assertEqual(0, pids2[0]) - self.assertEqual(1, pids2[9]) - - def test_attempt_number(self): - """Verify the attempt numbers are correctly reported.""" - rdd = self.sc.parallelize(range(10)) - # Verify a simple job with no failures - attempt_numbers = rdd.map(lambda x: TaskContext.get().attemptNumber()).collect() - map(lambda attempt: self.assertEqual(0, attempt), attempt_numbers) - - def fail_on_first(x): - """Fail on the first attempt so we get a positive attempt number""" - tc = TaskContext.get() - attempt_number = tc.attemptNumber() - partition_id = tc.partitionId() - attempt_id = tc.taskAttemptId() - if attempt_number == 0 and partition_id == 0: - raise Exception("Failing on first attempt") - else: - return [x, partition_id, attempt_number, attempt_id] - result = rdd.map(fail_on_first).collect() - # We should re-submit the first partition to it but other partitions should be attempt 0 - self.assertEqual([0, 0, 1], result[0][0:3]) - self.assertEqual([9, 3, 0], result[9][0:3]) - first_partition = filter(lambda x: x[1] == 0, result) - map(lambda x: self.assertEqual(1, x[2]), first_partition) - other_partitions = filter(lambda x: x[1] != 0, result) - map(lambda x: self.assertEqual(0, x[2]), other_partitions) - # The task attempt id should be different - self.assertTrue(result[0][3] != result[9][3]) - - def test_tc_on_driver(self): - """Verify that getting the TaskContext on the driver returns None.""" - tc = TaskContext.get() - self.assertTrue(tc is None) - - def test_get_local_property(self): - """Verify that local properties set on the driver are available in TaskContext.""" - key = "testkey" - value = "testvalue" - self.sc.setLocalProperty(key, value) - try: - rdd = self.sc.parallelize(range(1), 1) - prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0] - self.assertEqual(prop1, value) - prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0] - self.assertTrue(prop2 is None) - finally: - self.sc.setLocalProperty(key, None) - - def test_barrier(self): - """ - Verify that BarrierTaskContext.barrier() performs global sync among all barrier tasks - within a stage. - """ - rdd = self.sc.parallelize(range(10), 4) - - def f(iterator): - yield sum(iterator) - - def context_barrier(x): - tc = BarrierTaskContext.get() - time.sleep(random.randint(1, 10)) - tc.barrier() - return time.time() - - times = rdd.barrier().mapPartitions(f).map(context_barrier).collect() - self.assertTrue(max(times) - min(times) < 1) - - def test_barrier_with_python_worker_reuse(self): - """ - Verify that BarrierTaskContext.barrier() with reused python worker. - """ - self.sc._conf.set("spark.python.work.reuse", "true") - rdd = self.sc.parallelize(range(4), 4) - # start a normal job first to start all worker - result = rdd.map(lambda x: x ** 2).collect() - self.assertEqual([0, 1, 4, 9], result) - # make sure `spark.python.work.reuse=true` - self.assertEqual(self.sc._conf.get("spark.python.work.reuse"), "true") - - # worker will be reused in this barrier job - self.test_barrier() - - def test_barrier_infos(self): - """ - Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the - barrier stage. - """ - rdd = self.sc.parallelize(range(10), 4) - - def f(iterator): - yield sum(iterator) - - taskInfos = rdd.barrier().mapPartitions(f).map(lambda x: BarrierTaskContext.get() - .getTaskInfos()).collect() - self.assertTrue(len(taskInfos) == 4) - self.assertTrue(len(taskInfos[0]) == 4) - - -class RDDTests(ReusedPySparkTestCase): - - def test_range(self): - self.assertEqual(self.sc.range(1, 1).count(), 0) - self.assertEqual(self.sc.range(1, 0, -1).count(), 1) - self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2) - - def test_id(self): - rdd = self.sc.parallelize(range(10)) - id = rdd.id() - self.assertEqual(id, rdd.id()) - rdd2 = rdd.map(str).filter(bool) - id2 = rdd2.id() - self.assertEqual(id + 1, id2) - self.assertEqual(id2, rdd2.id()) - - def test_empty_rdd(self): - rdd = self.sc.emptyRDD() - self.assertTrue(rdd.isEmpty()) - - def test_sum(self): - self.assertEqual(0, self.sc.emptyRDD().sum()) - self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum()) - - def test_to_localiterator(self): - from time import sleep - rdd = self.sc.parallelize([1, 2, 3]) - it = rdd.toLocalIterator() - sleep(5) - self.assertEqual([1, 2, 3], sorted(it)) - - rdd2 = rdd.repartition(1000) - it2 = rdd2.toLocalIterator() - sleep(5) - self.assertEqual([1, 2, 3], sorted(it2)) - - def test_save_as_textfile_with_unicode(self): - # Regression test for SPARK-970 - x = u"\u00A1Hola, mundo!" - data = self.sc.parallelize([x]) - tempFile = tempfile.NamedTemporaryFile(delete=True) - tempFile.close() - data.saveAsTextFile(tempFile.name) - raw_contents = b''.join(open(p, 'rb').read() - for p in glob(tempFile.name + "/part-0000*")) - self.assertEqual(x, raw_contents.strip().decode("utf-8")) - - def test_save_as_textfile_with_utf8(self): - x = u"\u00A1Hola, mundo!" - data = self.sc.parallelize([x.encode("utf-8")]) - tempFile = tempfile.NamedTemporaryFile(delete=True) - tempFile.close() - data.saveAsTextFile(tempFile.name) - raw_contents = b''.join(open(p, 'rb').read() - for p in glob(tempFile.name + "/part-0000*")) - self.assertEqual(x, raw_contents.strip().decode('utf8')) - - def test_transforming_cartesian_result(self): - # Regression test for SPARK-1034 - rdd1 = self.sc.parallelize([1, 2]) - rdd2 = self.sc.parallelize([3, 4]) - cart = rdd1.cartesian(rdd2) - result = cart.map(lambda x_y3: x_y3[0] + x_y3[1]).collect() - - def test_transforming_pickle_file(self): - # Regression test for SPARK-2601 - data = self.sc.parallelize([u"Hello", u"World!"]) - tempFile = tempfile.NamedTemporaryFile(delete=True) - tempFile.close() - data.saveAsPickleFile(tempFile.name) - pickled_file = self.sc.pickleFile(tempFile.name) - pickled_file.map(lambda x: x).collect() - - def test_cartesian_on_textfile(self): - # Regression test for - path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - a = self.sc.textFile(path) - result = a.cartesian(a).collect() - (x, y) = result[0] - self.assertEqual(u"Hello World!", x.strip()) - self.assertEqual(u"Hello World!", y.strip()) - - def test_cartesian_chaining(self): - # Tests for SPARK-16589 - rdd = self.sc.parallelize(range(10), 2) - self.assertSetEqual( - set(rdd.cartesian(rdd).cartesian(rdd).collect()), - set([((x, y), z) for x in range(10) for y in range(10) for z in range(10)]) - ) - - self.assertSetEqual( - set(rdd.cartesian(rdd.cartesian(rdd)).collect()), - set([(x, (y, z)) for x in range(10) for y in range(10) for z in range(10)]) - ) - - self.assertSetEqual( - set(rdd.cartesian(rdd.zip(rdd)).collect()), - set([(x, (y, y)) for x in range(10) for y in range(10)]) - ) - - def test_zip_chaining(self): - # Tests for SPARK-21985 - rdd = self.sc.parallelize('abc', 2) - self.assertSetEqual( - set(rdd.zip(rdd).zip(rdd).collect()), - set([((x, x), x) for x in 'abc']) - ) - self.assertSetEqual( - set(rdd.zip(rdd.zip(rdd)).collect()), - set([(x, (x, x)) for x in 'abc']) - ) - - def test_deleting_input_files(self): - # Regression test for SPARK-1025 - tempFile = tempfile.NamedTemporaryFile(delete=False) - tempFile.write(b"Hello World!") - tempFile.close() - data = self.sc.textFile(tempFile.name) - filtered_data = data.filter(lambda x: True) - self.assertEqual(1, filtered_data.count()) - os.unlink(tempFile.name) - with QuietTest(self.sc): - self.assertRaises(Exception, lambda: filtered_data.count()) - - def test_sampling_default_seed(self): - # Test for SPARK-3995 (default seed setting) - data = self.sc.parallelize(xrange(1000), 1) - subset = data.takeSample(False, 10) - self.assertEqual(len(subset), 10) - - def test_aggregate_mutable_zero_value(self): - # Test for SPARK-9021; uses aggregate and treeAggregate to build dict - # representing a counter of ints - # NOTE: dict is used instead of collections.Counter for Python 2.6 - # compatibility - from collections import defaultdict - - # Show that single or multiple partitions work - data1 = self.sc.range(10, numSlices=1) - data2 = self.sc.range(10, numSlices=2) - - def seqOp(x, y): - x[y] += 1 - return x - - def comboOp(x, y): - for key, val in y.items(): - x[key] += val - return x - - counts1 = data1.aggregate(defaultdict(int), seqOp, comboOp) - counts2 = data2.aggregate(defaultdict(int), seqOp, comboOp) - counts3 = data1.treeAggregate(defaultdict(int), seqOp, comboOp, 2) - counts4 = data2.treeAggregate(defaultdict(int), seqOp, comboOp, 2) - - ground_truth = defaultdict(int, dict((i, 1) for i in range(10))) - self.assertEqual(counts1, ground_truth) - self.assertEqual(counts2, ground_truth) - self.assertEqual(counts3, ground_truth) - self.assertEqual(counts4, ground_truth) - - def test_aggregate_by_key_mutable_zero_value(self): - # Test for SPARK-9021; uses aggregateByKey to make a pair RDD that - # contains lists of all values for each key in the original RDD - - # list(range(...)) for Python 3.x compatibility (can't use * operator - # on a range object) - # list(zip(...)) for Python 3.x compatibility (want to parallelize a - # collection, not a zip object) - tuples = list(zip(list(range(10))*2, [1]*20)) - # Show that single or multiple partitions work - data1 = self.sc.parallelize(tuples, 1) - data2 = self.sc.parallelize(tuples, 2) - - def seqOp(x, y): - x.append(y) - return x - - def comboOp(x, y): - x.extend(y) - return x - - values1 = data1.aggregateByKey([], seqOp, comboOp).collect() - values2 = data2.aggregateByKey([], seqOp, comboOp).collect() - # Sort lists to ensure clean comparison with ground_truth - values1.sort() - values2.sort() - - ground_truth = [(i, [1]*2) for i in range(10)] - self.assertEqual(values1, ground_truth) - self.assertEqual(values2, ground_truth) - - def test_fold_mutable_zero_value(self): - # Test for SPARK-9021; uses fold to merge an RDD of dict counters into - # a single dict - # NOTE: dict is used instead of collections.Counter for Python 2.6 - # compatibility - from collections import defaultdict - - counts1 = defaultdict(int, dict((i, 1) for i in range(10))) - counts2 = defaultdict(int, dict((i, 1) for i in range(3, 8))) - counts3 = defaultdict(int, dict((i, 1) for i in range(4, 7))) - counts4 = defaultdict(int, dict((i, 1) for i in range(5, 6))) - all_counts = [counts1, counts2, counts3, counts4] - # Show that single or multiple partitions work - data1 = self.sc.parallelize(all_counts, 1) - data2 = self.sc.parallelize(all_counts, 2) - - def comboOp(x, y): - for key, val in y.items(): - x[key] += val - return x - - fold1 = data1.fold(defaultdict(int), comboOp) - fold2 = data2.fold(defaultdict(int), comboOp) - - ground_truth = defaultdict(int) - for counts in all_counts: - for key, val in counts.items(): - ground_truth[key] += val - self.assertEqual(fold1, ground_truth) - self.assertEqual(fold2, ground_truth) - - def test_fold_by_key_mutable_zero_value(self): - # Test for SPARK-9021; uses foldByKey to make a pair RDD that contains - # lists of all values for each key in the original RDD - - tuples = [(i, range(i)) for i in range(10)]*2 - # Show that single or multiple partitions work - data1 = self.sc.parallelize(tuples, 1) - data2 = self.sc.parallelize(tuples, 2) - - def comboOp(x, y): - x.extend(y) - return x - - values1 = data1.foldByKey([], comboOp).collect() - values2 = data2.foldByKey([], comboOp).collect() - # Sort lists to ensure clean comparison with ground_truth - values1.sort() - values2.sort() - - # list(range(...)) for Python 3.x compatibility - ground_truth = [(i, list(range(i))*2) for i in range(10)] - self.assertEqual(values1, ground_truth) - self.assertEqual(values2, ground_truth) - - def test_aggregate_by_key(self): - data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2) - - def seqOp(x, y): - x.add(y) - return x - - def combOp(x, y): - x |= y - return x - - sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect()) - self.assertEqual(3, len(sets)) - self.assertEqual(set([1]), sets[1]) - self.assertEqual(set([2]), sets[3]) - self.assertEqual(set([1, 3]), sets[5]) - - def test_itemgetter(self): - rdd = self.sc.parallelize([range(10)]) - from operator import itemgetter - self.assertEqual([1], rdd.map(itemgetter(1)).collect()) - self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect()) - - def test_namedtuple_in_rdd(self): - from collections import namedtuple - Person = namedtuple("Person", "id firstName lastName") - jon = Person(1, "Jon", "Doe") - jane = Person(2, "Jane", "Doe") - theDoes = self.sc.parallelize([jon, jane]) - self.assertEqual([jon, jane], theDoes.collect()) - - def test_large_broadcast(self): - N = 10000 - data = [[float(i) for i in range(300)] for i in range(N)] - bdata = self.sc.broadcast(data) # 27MB - m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() - self.assertEqual(N, m) - - def test_unpersist(self): - N = 1000 - data = [[float(i) for i in range(300)] for i in range(N)] - bdata = self.sc.broadcast(data) # 3MB - bdata.unpersist() - m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() - self.assertEqual(N, m) - bdata.destroy() - try: - self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() - except Exception as e: - pass - else: - raise Exception("job should fail after destroy the broadcast") - - def test_multiple_broadcasts(self): - N = 1 << 21 - b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM - r = list(range(1 << 15)) - random.shuffle(r) - s = str(r).encode() - checksum = hashlib.md5(s).hexdigest() - b2 = self.sc.broadcast(s) - r = list(set(self.sc.parallelize(range(10), 10).map( - lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) - self.assertEqual(1, len(r)) - size, csum = r[0] - self.assertEqual(N, size) - self.assertEqual(checksum, csum) - - random.shuffle(r) - s = str(r).encode() - checksum = hashlib.md5(s).hexdigest() - b2 = self.sc.broadcast(s) - r = list(set(self.sc.parallelize(range(10), 10).map( - lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) - self.assertEqual(1, len(r)) - size, csum = r[0] - self.assertEqual(N, size) - self.assertEqual(checksum, csum) - - def test_multithread_broadcast_pickle(self): - import threading - - b1 = self.sc.broadcast(list(range(3))) - b2 = self.sc.broadcast(list(range(3))) - - def f1(): - return b1.value - - def f2(): - return b2.value - - funcs_num_pickled = {f1: None, f2: None} - - def do_pickle(f, sc): - command = (f, None, sc.serializer, sc.serializer) - ser = CloudPickleSerializer() - ser.dumps(command) - - def process_vars(sc): - broadcast_vars = list(sc._pickled_broadcast_vars) - num_pickled = len(broadcast_vars) - sc._pickled_broadcast_vars.clear() - return num_pickled - - def run(f, sc): - do_pickle(f, sc) - funcs_num_pickled[f] = process_vars(sc) - - # pickle f1, adds b1 to sc._pickled_broadcast_vars in main thread local storage - do_pickle(f1, self.sc) - - # run all for f2, should only add/count/clear b2 from worker thread local storage - t = threading.Thread(target=run, args=(f2, self.sc)) - t.start() - t.join() - - # count number of vars pickled in main thread, only b1 should be counted and cleared - funcs_num_pickled[f1] = process_vars(self.sc) - - self.assertEqual(funcs_num_pickled[f1], 1) - self.assertEqual(funcs_num_pickled[f2], 1) - self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0) - - def test_large_closure(self): - N = 200000 - data = [float(i) for i in xrange(N)] - rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data)) - self.assertEqual(N, rdd.first()) - # regression test for SPARK-6886 - self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count()) - - def test_zip_with_different_serializers(self): - a = self.sc.parallelize(range(5)) - b = self.sc.parallelize(range(100, 105)) - self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) - a = a._reserialize(BatchedSerializer(PickleSerializer(), 2)) - b = b._reserialize(MarshalSerializer()) - self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) - # regression test for SPARK-4841 - path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - t = self.sc.textFile(path) - cnt = t.count() - self.assertEqual(cnt, t.zip(t).count()) - rdd = t.map(str) - self.assertEqual(cnt, t.zip(rdd).count()) - # regression test for bug in _reserializer() - self.assertEqual(cnt, t.zip(rdd).count()) - - def test_zip_with_different_object_sizes(self): - # regress test for SPARK-5973 - a = self.sc.parallelize(xrange(10000)).map(lambda i: '*' * i) - b = self.sc.parallelize(xrange(10000, 20000)).map(lambda i: '*' * i) - self.assertEqual(10000, a.zip(b).count()) - - def test_zip_with_different_number_of_items(self): - a = self.sc.parallelize(range(5), 2) - # different number of partitions - b = self.sc.parallelize(range(100, 106), 3) - self.assertRaises(ValueError, lambda: a.zip(b)) - with QuietTest(self.sc): - # different number of batched items in JVM - b = self.sc.parallelize(range(100, 104), 2) - self.assertRaises(Exception, lambda: a.zip(b).count()) - # different number of items in one pair - b = self.sc.parallelize(range(100, 106), 2) - self.assertRaises(Exception, lambda: a.zip(b).count()) - # same total number of items, but different distributions - a = self.sc.parallelize([2, 3], 2).flatMap(range) - b = self.sc.parallelize([3, 2], 2).flatMap(range) - self.assertEqual(a.count(), b.count()) - self.assertRaises(Exception, lambda: a.zip(b).count()) - - def test_count_approx_distinct(self): - rdd = self.sc.parallelize(xrange(1000)) - self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050) - self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050) - self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050) - self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.03) < 1050) - - rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7) - self.assertTrue(18 < rdd.countApproxDistinct() < 22) - self.assertTrue(18 < rdd.map(float).countApproxDistinct() < 22) - self.assertTrue(18 < rdd.map(str).countApproxDistinct() < 22) - self.assertTrue(18 < rdd.map(lambda x: (x, -x)).countApproxDistinct() < 22) - - self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.00000001)) - - def test_histogram(self): - # empty - rdd = self.sc.parallelize([]) - self.assertEqual([0], rdd.histogram([0, 10])[1]) - self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) - self.assertRaises(ValueError, lambda: rdd.histogram(1)) - - # out of range - rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEqual([0], rdd.histogram([0, 10])[1]) - self.assertEqual([0, 0], rdd.histogram((0, 4, 10))[1]) - - # in range with one bucket - rdd = self.sc.parallelize(range(1, 5)) - self.assertEqual([4], rdd.histogram([0, 10])[1]) - self.assertEqual([3, 1], rdd.histogram([0, 4, 10])[1]) - - # in range with one bucket exact match - self.assertEqual([4], rdd.histogram([1, 4])[1]) - - # out of range with two buckets - rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEqual([0, 0], rdd.histogram([0, 5, 10])[1]) - - # out of range with two uneven buckets - rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) - - # in range with two buckets - rdd = self.sc.parallelize([1, 2, 3, 5, 6]) - self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) - - # in range with two bucket and None - rdd = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')]) - self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) - - # in range with two uneven buckets - rdd = self.sc.parallelize([1, 2, 3, 5, 6]) - self.assertEqual([3, 2], rdd.histogram([0, 5, 11])[1]) - - # mixed range with two uneven buckets - rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01]) - self.assertEqual([4, 3], rdd.histogram([0, 5, 11])[1]) - - # mixed range with four uneven buckets - rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1]) - self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) - - # mixed range with uneven buckets and NaN - rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, - 199.0, 200.0, 200.1, None, float('nan')]) - self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) - - # out of range with infinite buckets - rdd = self.sc.parallelize([10.01, -0.01, float('nan'), float("inf")]) - self.assertEqual([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1]) - - # invalid buckets - self.assertRaises(ValueError, lambda: rdd.histogram([])) - self.assertRaises(ValueError, lambda: rdd.histogram([1])) - self.assertRaises(ValueError, lambda: rdd.histogram(0)) - self.assertRaises(TypeError, lambda: rdd.histogram({})) - - # without buckets - rdd = self.sc.parallelize(range(1, 5)) - self.assertEqual(([1, 4], [4]), rdd.histogram(1)) - - # without buckets single element - rdd = self.sc.parallelize([1]) - self.assertEqual(([1, 1], [1]), rdd.histogram(1)) - - # without bucket no range - rdd = self.sc.parallelize([1] * 4) - self.assertEqual(([1, 1], [4]), rdd.histogram(1)) - - # without buckets basic two - rdd = self.sc.parallelize(range(1, 5)) - self.assertEqual(([1, 2.5, 4], [2, 2]), rdd.histogram(2)) - - # without buckets with more requested than elements - rdd = self.sc.parallelize([1, 2]) - buckets = [1 + 0.2 * i for i in range(6)] - hist = [1, 0, 0, 0, 1] - self.assertEqual((buckets, hist), rdd.histogram(5)) - - # invalid RDDs - rdd = self.sc.parallelize([1, float('inf')]) - self.assertRaises(ValueError, lambda: rdd.histogram(2)) - rdd = self.sc.parallelize([float('nan')]) - self.assertRaises(ValueError, lambda: rdd.histogram(2)) - - # string - rdd = self.sc.parallelize(["ab", "ac", "b", "bd", "ef"], 2) - self.assertEqual([2, 2], rdd.histogram(["a", "b", "c"])[1]) - self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1)) - self.assertRaises(TypeError, lambda: rdd.histogram(2)) - - def test_repartitionAndSortWithinPartitions_asc(self): - rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) - - repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, True) - partitions = repartitioned.glom().collect() - self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)]) - self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)]) - - def test_repartitionAndSortWithinPartitions_desc(self): - rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) - - repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, False) - partitions = repartitioned.glom().collect() - self.assertEqual(partitions[0], [(2, 6), (0, 5), (0, 8)]) - self.assertEqual(partitions[1], [(3, 8), (3, 8), (1, 3)]) - - def test_repartition_no_skewed(self): - num_partitions = 20 - a = self.sc.parallelize(range(int(1000)), 2) - l = a.repartition(num_partitions).glom().map(len).collect() - zeros = len([x for x in l if x == 0]) - self.assertTrue(zeros == 0) - l = a.coalesce(num_partitions, True).glom().map(len).collect() - zeros = len([x for x in l if x == 0]) - self.assertTrue(zeros == 0) - - def test_repartition_on_textfile(self): - path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - rdd = self.sc.textFile(path) - result = rdd.repartition(1).collect() - self.assertEqual(u"Hello World!", result[0]) - - def test_distinct(self): - rdd = self.sc.parallelize((1, 2, 3)*10, 10) - self.assertEqual(rdd.getNumPartitions(), 10) - self.assertEqual(rdd.distinct().count(), 3) - result = rdd.distinct(5) - self.assertEqual(result.getNumPartitions(), 5) - self.assertEqual(result.count(), 3) - - def test_external_group_by_key(self): - self.sc._conf.set("spark.python.worker.memory", "1m") - N = 200001 - kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x)) - gkv = kv.groupByKey().cache() - self.assertEqual(3, gkv.count()) - filtered = gkv.filter(lambda kv: kv[0] == 1) - self.assertEqual(1, filtered.count()) - self.assertEqual([(1, N // 3)], filtered.mapValues(len).collect()) - self.assertEqual([(N // 3, N // 3)], - filtered.values().map(lambda x: (len(x), len(list(x)))).collect()) - result = filtered.collect()[0][1] - self.assertEqual(N // 3, len(result)) - self.assertTrue(isinstance(result.data, shuffle.ExternalListOfList)) - - def test_sort_on_empty_rdd(self): - self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect()) - - def test_sample(self): - rdd = self.sc.parallelize(range(0, 100), 4) - wo = rdd.sample(False, 0.1, 2).collect() - wo_dup = rdd.sample(False, 0.1, 2).collect() - self.assertSetEqual(set(wo), set(wo_dup)) - wr = rdd.sample(True, 0.2, 5).collect() - wr_dup = rdd.sample(True, 0.2, 5).collect() - self.assertSetEqual(set(wr), set(wr_dup)) - wo_s10 = rdd.sample(False, 0.3, 10).collect() - wo_s20 = rdd.sample(False, 0.3, 20).collect() - self.assertNotEqual(set(wo_s10), set(wo_s20)) - wr_s11 = rdd.sample(True, 0.4, 11).collect() - wr_s21 = rdd.sample(True, 0.4, 21).collect() - self.assertNotEqual(set(wr_s11), set(wr_s21)) - - def test_null_in_rdd(self): - jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc) - rdd = RDD(jrdd, self.sc, UTF8Deserializer()) - self.assertEqual([u"a", None, u"b"], rdd.collect()) - rdd = RDD(jrdd, self.sc, NoOpSerializer()) - self.assertEqual([b"a", None, b"b"], rdd.collect()) - - def test_multiple_python_java_RDD_conversions(self): - # Regression test for SPARK-5361 - data = [ - (u'1', {u'director': u'David Lean'}), - (u'2', {u'director': u'Andrew Dominik'}) - ] - data_rdd = self.sc.parallelize(data) - data_java_rdd = data_rdd._to_java_object_rdd() - data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd) - converted_rdd = RDD(data_python_rdd, self.sc) - self.assertEqual(2, converted_rdd.count()) - - # conversion between python and java RDD threw exceptions - data_java_rdd = converted_rdd._to_java_object_rdd() - data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd) - converted_rdd = RDD(data_python_rdd, self.sc) - self.assertEqual(2, converted_rdd.count()) - - def test_narrow_dependency_in_join(self): - rdd = self.sc.parallelize(range(10)).map(lambda x: (x, x)) - parted = rdd.partitionBy(2) - self.assertEqual(2, parted.union(parted).getNumPartitions()) - self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions()) - self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions()) - - tracker = self.sc.statusTracker() - - self.sc.setJobGroup("test1", "test", True) - d = sorted(parted.join(parted).collect()) - self.assertEqual(10, len(d)) - self.assertEqual((0, (0, 0)), d[0]) - jobId = tracker.getJobIdsForGroup("test1")[0] - self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) - - self.sc.setJobGroup("test2", "test", True) - d = sorted(parted.join(rdd).collect()) - self.assertEqual(10, len(d)) - self.assertEqual((0, (0, 0)), d[0]) - jobId = tracker.getJobIdsForGroup("test2")[0] - self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) - - self.sc.setJobGroup("test3", "test", True) - d = sorted(parted.cogroup(parted).collect()) - self.assertEqual(10, len(d)) - self.assertEqual([[0], [0]], list(map(list, d[0][1]))) - jobId = tracker.getJobIdsForGroup("test3")[0] - self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) - - self.sc.setJobGroup("test4", "test", True) - d = sorted(parted.cogroup(rdd).collect()) - self.assertEqual(10, len(d)) - self.assertEqual([[0], [0]], list(map(list, d[0][1]))) - jobId = tracker.getJobIdsForGroup("test4")[0] - self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) - - # Regression test for SPARK-6294 - def test_take_on_jrdd(self): - rdd = self.sc.parallelize(xrange(1 << 20)).map(lambda x: str(x)) - rdd._jrdd.first() - - def test_sortByKey_uses_all_partitions_not_only_first_and_last(self): - # Regression test for SPARK-5969 - seq = [(i * 59 % 101, i) for i in range(101)] # unsorted sequence - rdd = self.sc.parallelize(seq) - for ascending in [True, False]: - sort = rdd.sortByKey(ascending=ascending, numPartitions=5) - self.assertEqual(sort.collect(), sorted(seq, reverse=not ascending)) - sizes = sort.glom().map(len).collect() - for size in sizes: - self.assertGreater(size, 0) - - def test_pipe_functions(self): - data = ['1', '2', '3'] - rdd = self.sc.parallelize(data) - with QuietTest(self.sc): - self.assertEqual([], rdd.pipe('cc').collect()) - self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect) - result = rdd.pipe('cat').collect() - result.sort() - for x, y in zip(data, result): - self.assertEqual(x, y) - self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect) - self.assertEqual([], rdd.pipe('grep 4').collect()) - - def test_pipe_unicode(self): - # Regression test for SPARK-20947 - data = [u'\u6d4b\u8bd5', '1'] - rdd = self.sc.parallelize(data) - result = rdd.pipe('cat').collect() - self.assertEqual(data, result) - - def test_stopiteration_in_user_code(self): - - def stopit(*x): - raise StopIteration() - - seq_rdd = self.sc.parallelize(range(10)) - keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) - msg = "Caught StopIteration thrown from user's code; failing the task" - - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect) - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect) - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit) - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit) - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) - self.assertRaisesRegexp(Py4JJavaError, msg, - seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) - - # these methods call the user function both in the driver and in the executor - # the exception raised is different according to where the StopIteration happens - # RuntimeError is raised if in the driver - # Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker) - self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, - keyed_rdd.reduceByKeyLocally, stopit) - self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, - seq_rdd.aggregate, 0, stopit, lambda *x: 1) - self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, - seq_rdd.aggregate, 0, lambda *x: 1, stopit) - - -class ProfilerTests(PySparkTestCase): - - def setUp(self): - self._old_sys_path = list(sys.path) - class_name = self.__class__.__name__ - conf = SparkConf().set("spark.python.profile", "true") - self.sc = SparkContext('local[4]', class_name, conf=conf) - - def test_profiler(self): - self.do_computation() - - profilers = self.sc.profiler_collector.profilers - self.assertEqual(1, len(profilers)) - id, profiler, _ = profilers[0] - stats = profiler.stats() - self.assertTrue(stats is not None) - width, stat_list = stats.get_print_list([]) - func_names = [func_name for fname, n, func_name in stat_list] - self.assertTrue("heavy_foo" in func_names) - - old_stdout = sys.stdout - sys.stdout = io = StringIO() - self.sc.show_profiles() - self.assertTrue("heavy_foo" in io.getvalue()) - sys.stdout = old_stdout - - d = tempfile.gettempdir() - self.sc.dump_profiles(d) - self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) - - def test_custom_profiler(self): - class TestCustomProfiler(BasicProfiler): - def show(self, id): - self.result = "Custom formatting" - - self.sc.profiler_collector.profiler_cls = TestCustomProfiler - - self.do_computation() - - profilers = self.sc.profiler_collector.profilers - self.assertEqual(1, len(profilers)) - _, profiler, _ = profilers[0] - self.assertTrue(isinstance(profiler, TestCustomProfiler)) - - self.sc.show_profiles() - self.assertEqual("Custom formatting", profiler.result) - - def do_computation(self): - def heavy_foo(x): - for i in range(1 << 18): - x = 1 - - rdd = self.sc.parallelize(range(100)) - rdd.foreach(heavy_foo) - - -class ProfilerTests2(unittest.TestCase): - def test_profiler_disabled(self): - sc = SparkContext(conf=SparkConf().set("spark.python.profile", "false")) - try: - self.assertRaisesRegexp( - RuntimeError, - "'spark.python.profile' configuration must be set", - lambda: sc.show_profiles()) - self.assertRaisesRegexp( - RuntimeError, - "'spark.python.profile' configuration must be set", - lambda: sc.dump_profiles("/tmp/abc")) - finally: - sc.stop() - - -class InputFormatTests(ReusedPySparkTestCase): - - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.tempdir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(cls.tempdir.name) - cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc) - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - shutil.rmtree(cls.tempdir.name) - - @unittest.skipIf(sys.version >= "3", "serialize array of byte") - def test_sequencefiles(self): - basepath = self.tempdir.name - ints = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfint/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text").collect()) - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.assertEqual(ints, ei) - - doubles = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfdouble/", - "org.apache.hadoop.io.DoubleWritable", - "org.apache.hadoop.io.Text").collect()) - ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] - self.assertEqual(doubles, ed) - - bytes = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbytes/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.BytesWritable").collect()) - ebs = [(1, bytearray('aa', 'utf-8')), - (1, bytearray('aa', 'utf-8')), - (2, bytearray('aa', 'utf-8')), - (2, bytearray('bb', 'utf-8')), - (2, bytearray('bb', 'utf-8')), - (3, bytearray('cc', 'utf-8'))] - self.assertEqual(bytes, ebs) - - text = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sftext/", - "org.apache.hadoop.io.Text", - "org.apache.hadoop.io.Text").collect()) - et = [(u'1', u'aa'), - (u'1', u'aa'), - (u'2', u'aa'), - (u'2', u'bb'), - (u'2', u'bb'), - (u'3', u'cc')] - self.assertEqual(text, et) - - bools = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbool/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.BooleanWritable").collect()) - eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] - self.assertEqual(bools, eb) - - nulls = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfnull/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.BooleanWritable").collect()) - en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] - self.assertEqual(nulls, en) - - maps = self.sc.sequenceFile(basepath + "/sftestdata/sfmap/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable").collect() - em = [(1, {}), - (1, {3.0: u'bb'}), - (2, {1.0: u'aa'}), - (2, {1.0: u'cc'}), - (3, {2.0: u'dd'})] - for v in maps: - self.assertTrue(v in em) - - # arrays get pickled to tuples by default - tuples = sorted(self.sc.sequenceFile( - basepath + "/sftestdata/sfarray/", - "org.apache.hadoop.io.IntWritable", - "org.apache.spark.api.python.DoubleArrayWritable").collect()) - et = [(1, ()), - (2, (3.0, 4.0, 5.0)), - (3, (4.0, 5.0, 6.0))] - self.assertEqual(tuples, et) - - # with custom converters, primitive arrays can stay as arrays - arrays = sorted(self.sc.sequenceFile( - basepath + "/sftestdata/sfarray/", - "org.apache.hadoop.io.IntWritable", - "org.apache.spark.api.python.DoubleArrayWritable", - valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) - ea = [(1, array('d')), - (2, array('d', [3.0, 4.0, 5.0])), - (3, array('d', [4.0, 5.0, 6.0]))] - self.assertEqual(arrays, ea) - - clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", - "org.apache.hadoop.io.Text", - "org.apache.spark.api.python.TestWritable").collect()) - cname = u'org.apache.spark.api.python.TestWritable' - ec = [(u'1', {u'__class__': cname, u'double': 1.0, u'int': 1, u'str': u'test1'}), - (u'2', {u'__class__': cname, u'double': 2.3, u'int': 2, u'str': u'test2'}), - (u'3', {u'__class__': cname, u'double': 3.1, u'int': 3, u'str': u'test3'}), - (u'4', {u'__class__': cname, u'double': 4.2, u'int': 4, u'str': u'test4'}), - (u'5', {u'__class__': cname, u'double': 5.5, u'int': 5, u'str': u'test56'})] - self.assertEqual(clazz, ec) - - unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", - "org.apache.hadoop.io.Text", - "org.apache.spark.api.python.TestWritable", - ).collect()) - self.assertEqual(unbatched_clazz, ec) - - def test_oldhadoop(self): - basepath = self.tempdir.name - ints = sorted(self.sc.hadoopFile(basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text").collect()) - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.assertEqual(ints, ei) - - hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - oldconf = {"mapreduce.input.fileinputformat.inputdir": hellopath} - hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat", - "org.apache.hadoop.io.LongWritable", - "org.apache.hadoop.io.Text", - conf=oldconf).collect() - result = [(0, u'Hello World!')] - self.assertEqual(hello, result) - - def test_newhadoop(self): - basepath = self.tempdir.name - ints = sorted(self.sc.newAPIHadoopFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text").collect()) - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.assertEqual(ints, ei) - - hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - newconf = {"mapreduce.input.fileinputformat.inputdir": hellopath} - hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat", - "org.apache.hadoop.io.LongWritable", - "org.apache.hadoop.io.Text", - conf=newconf).collect() - result = [(0, u'Hello World!')] - self.assertEqual(hello, result) - - def test_newolderror(self): - basepath = self.tempdir.name - self.assertRaises(Exception, lambda: self.sc.hadoopFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text")) - - self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text")) - - def test_bad_inputs(self): - basepath = self.tempdir.name - self.assertRaises(Exception, lambda: self.sc.sequenceFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.io.NotValidWritable", - "org.apache.hadoop.io.Text")) - self.assertRaises(Exception, lambda: self.sc.hadoopFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapred.NotValidInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text")) - self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapreduce.lib.input.NotValidInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text")) - - def test_converters(self): - # use of custom converters - basepath = self.tempdir.name - maps = sorted(self.sc.sequenceFile( - basepath + "/sftestdata/sfmap/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable", - keyConverter="org.apache.spark.api.python.TestInputKeyConverter", - valueConverter="org.apache.spark.api.python.TestInputValueConverter").collect()) - em = [(u'\x01', []), - (u'\x01', [3.0]), - (u'\x02', [1.0]), - (u'\x02', [1.0]), - (u'\x03', [2.0])] - self.assertEqual(maps, em) - - def test_binary_files(self): - path = os.path.join(self.tempdir.name, "binaryfiles") - os.mkdir(path) - data = b"short binary data" - with open(os.path.join(path, "part-0000"), 'wb') as f: - f.write(data) - [(p, d)] = self.sc.binaryFiles(path).collect() - self.assertTrue(p.endswith("part-0000")) - self.assertEqual(d, data) - - def test_binary_records(self): - path = os.path.join(self.tempdir.name, "binaryrecords") - os.mkdir(path) - with open(os.path.join(path, "part-0000"), 'w') as f: - for i in range(100): - f.write('%04d' % i) - result = self.sc.binaryRecords(path, 4).map(int).collect() - self.assertEqual(list(range(100)), result) - - -class OutputFormatTests(ReusedPySparkTestCase): - - def setUp(self): - self.tempdir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(self.tempdir.name) - - def tearDown(self): - shutil.rmtree(self.tempdir.name, ignore_errors=True) - - @unittest.skipIf(sys.version >= "3", "serialize array of byte") - def test_sequencefiles(self): - basepath = self.tempdir.name - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.sc.parallelize(ei).saveAsSequenceFile(basepath + "/sfint/") - ints = sorted(self.sc.sequenceFile(basepath + "/sfint/").collect()) - self.assertEqual(ints, ei) - - ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] - self.sc.parallelize(ed).saveAsSequenceFile(basepath + "/sfdouble/") - doubles = sorted(self.sc.sequenceFile(basepath + "/sfdouble/").collect()) - self.assertEqual(doubles, ed) - - ebs = [(1, bytearray(b'\x00\x07spam\x08')), (2, bytearray(b'\x00\x07spam\x08'))] - self.sc.parallelize(ebs).saveAsSequenceFile(basepath + "/sfbytes/") - bytes = sorted(self.sc.sequenceFile(basepath + "/sfbytes/").collect()) - self.assertEqual(bytes, ebs) - - et = [(u'1', u'aa'), - (u'2', u'bb'), - (u'3', u'cc')] - self.sc.parallelize(et).saveAsSequenceFile(basepath + "/sftext/") - text = sorted(self.sc.sequenceFile(basepath + "/sftext/").collect()) - self.assertEqual(text, et) - - eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] - self.sc.parallelize(eb).saveAsSequenceFile(basepath + "/sfbool/") - bools = sorted(self.sc.sequenceFile(basepath + "/sfbool/").collect()) - self.assertEqual(bools, eb) - - en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] - self.sc.parallelize(en).saveAsSequenceFile(basepath + "/sfnull/") - nulls = sorted(self.sc.sequenceFile(basepath + "/sfnull/").collect()) - self.assertEqual(nulls, en) - - em = [(1, {}), - (1, {3.0: u'bb'}), - (2, {1.0: u'aa'}), - (2, {1.0: u'cc'}), - (3, {2.0: u'dd'})] - self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/") - maps = self.sc.sequenceFile(basepath + "/sfmap/").collect() - for v in maps: - self.assertTrue(v, em) - - def test_oldhadoop(self): - basepath = self.tempdir.name - dict_data = [(1, {}), - (1, {"row1": 1.0}), - (2, {"row2": 2.0})] - self.sc.parallelize(dict_data).saveAsHadoopFile( - basepath + "/oldhadoop/", - "org.apache.hadoop.mapred.SequenceFileOutputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable") - result = self.sc.hadoopFile( - basepath + "/oldhadoop/", - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable").collect() - for v in result: - self.assertTrue(v, dict_data) - - conf = { - "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", - "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.job.output.value.class": "org.apache.hadoop.io.MapWritable", - "mapreduce.output.fileoutputformat.outputdir": basepath + "/olddataset/" - } - self.sc.parallelize(dict_data).saveAsHadoopDataset(conf) - input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/olddataset/"} - result = self.sc.hadoopRDD( - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable", - conf=input_conf).collect() - for v in result: - self.assertTrue(v, dict_data) - - def test_newhadoop(self): - basepath = self.tempdir.name - data = [(1, ""), - (1, "a"), - (2, "bcdf")] - self.sc.parallelize(data).saveAsNewAPIHadoopFile( - basepath + "/newhadoop/", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text") - result = sorted(self.sc.newAPIHadoopFile( - basepath + "/newhadoop/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text").collect()) - self.assertEqual(result, data) - - conf = { - "mapreduce.job.outputformat.class": - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.job.output.value.class": "org.apache.hadoop.io.Text", - "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/" - } - self.sc.parallelize(data).saveAsNewAPIHadoopDataset(conf) - input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"} - new_dataset = sorted(self.sc.newAPIHadoopRDD( - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - conf=input_conf).collect()) - self.assertEqual(new_dataset, data) - - @unittest.skipIf(sys.version >= "3", "serialize of array") - def test_newhadoop_with_array(self): - basepath = self.tempdir.name - # use custom ArrayWritable types and converters to handle arrays - array_data = [(1, array('d')), - (1, array('d', [1.0, 2.0, 3.0])), - (2, array('d', [3.0, 4.0, 5.0]))] - self.sc.parallelize(array_data).saveAsNewAPIHadoopFile( - basepath + "/newhadoop/", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.spark.api.python.DoubleArrayWritable", - valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") - result = sorted(self.sc.newAPIHadoopFile( - basepath + "/newhadoop/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.spark.api.python.DoubleArrayWritable", - valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) - self.assertEqual(result, array_data) - - conf = { - "mapreduce.job.outputformat.class": - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.job.output.value.class": "org.apache.spark.api.python.DoubleArrayWritable", - "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/" - } - self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset( - conf, - valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") - input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"} - new_dataset = sorted(self.sc.newAPIHadoopRDD( - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.spark.api.python.DoubleArrayWritable", - valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter", - conf=input_conf).collect()) - self.assertEqual(new_dataset, array_data) - - def test_newolderror(self): - basepath = self.tempdir.name - rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) - self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( - basepath + "/newolderror/saveAsHadoopFile/", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat")) - self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( - basepath + "/newolderror/saveAsNewAPIHadoopFile/", - "org.apache.hadoop.mapred.SequenceFileOutputFormat")) - - def test_bad_inputs(self): - basepath = self.tempdir.name - rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) - self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( - basepath + "/badinputs/saveAsHadoopFile/", - "org.apache.hadoop.mapred.NotValidOutputFormat")) - self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( - basepath + "/badinputs/saveAsNewAPIHadoopFile/", - "org.apache.hadoop.mapreduce.lib.output.NotValidOutputFormat")) - - def test_converters(self): - # use of custom converters - basepath = self.tempdir.name - data = [(1, {3.0: u'bb'}), - (2, {1.0: u'aa'}), - (3, {2.0: u'dd'})] - self.sc.parallelize(data).saveAsNewAPIHadoopFile( - basepath + "/converters/", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - keyConverter="org.apache.spark.api.python.TestOutputKeyConverter", - valueConverter="org.apache.spark.api.python.TestOutputValueConverter") - converted = sorted(self.sc.sequenceFile(basepath + "/converters/").collect()) - expected = [(u'1', 3.0), - (u'2', 1.0), - (u'3', 2.0)] - self.assertEqual(converted, expected) - - def test_reserialization(self): - basepath = self.tempdir.name - x = range(1, 5) - y = range(1001, 1005) - data = list(zip(x, y)) - rdd = self.sc.parallelize(x).zip(self.sc.parallelize(y)) - rdd.saveAsSequenceFile(basepath + "/reserialize/sequence") - result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect()) - self.assertEqual(result1, data) - - rdd.saveAsHadoopFile( - basepath + "/reserialize/hadoop", - "org.apache.hadoop.mapred.SequenceFileOutputFormat") - result2 = sorted(self.sc.sequenceFile(basepath + "/reserialize/hadoop").collect()) - self.assertEqual(result2, data) - - rdd.saveAsNewAPIHadoopFile( - basepath + "/reserialize/newhadoop", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") - result3 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newhadoop").collect()) - self.assertEqual(result3, data) - - conf4 = { - "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", - "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/dataset"} - rdd.saveAsHadoopDataset(conf4) - result4 = sorted(self.sc.sequenceFile(basepath + "/reserialize/dataset").collect()) - self.assertEqual(result4, data) - - conf5 = {"mapreduce.job.outputformat.class": - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/newdataset" - } - rdd.saveAsNewAPIHadoopDataset(conf5) - result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) - self.assertEqual(result5, data) - - def test_malformed_RDD(self): - basepath = self.tempdir.name - # non-batch-serialized RDD[[(K, V)]] should be rejected - data = [[(1, "a")], [(2, "aa")], [(3, "aaa")]] - rdd = self.sc.parallelize(data, len(data)) - self.assertRaises(Exception, lambda: rdd.saveAsSequenceFile( - basepath + "/malformed/sequence")) - - -class DaemonTests(unittest.TestCase): - def connect(self, port): - from socket import socket, AF_INET, SOCK_STREAM - sock = socket(AF_INET, SOCK_STREAM) - sock.connect(('127.0.0.1', port)) - # send a split index of -1 to shutdown the worker - sock.send(b"\xFF\xFF\xFF\xFF") - sock.close() - return True - - def do_termination_test(self, terminator): - from subprocess import Popen, PIPE - from errno import ECONNREFUSED - - # start daemon - daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py") - python_exec = sys.executable or os.environ.get("PYSPARK_PYTHON") - daemon = Popen([python_exec, daemon_path], stdin=PIPE, stdout=PIPE) - - # read the port number - port = read_int(daemon.stdout) - - # daemon should accept connections - self.assertTrue(self.connect(port)) - - # request shutdown - terminator(daemon) - time.sleep(1) - - # daemon should no longer accept connections - try: - self.connect(port) - except EnvironmentError as exception: - self.assertEqual(exception.errno, ECONNREFUSED) - else: - self.fail("Expected EnvironmentError to be raised") - - def test_termination_stdin(self): - """Ensure that daemon and workers terminate when stdin is closed.""" - self.do_termination_test(lambda daemon: daemon.stdin.close()) - - def test_termination_sigterm(self): - """Ensure that daemon and workers terminate on SIGTERM.""" - from signal import SIGTERM - self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) - - -class WorkerTests(ReusedPySparkTestCase): - def test_cancel_task(self): - temp = tempfile.NamedTemporaryFile(delete=True) - temp.close() - path = temp.name - - def sleep(x): - import os - import time - with open(path, 'w') as f: - f.write("%d %d" % (os.getppid(), os.getpid())) - time.sleep(100) - - # start job in background thread - def run(): - try: - self.sc.parallelize(range(1), 1).foreach(sleep) - except Exception: - pass - import threading - t = threading.Thread(target=run) - t.daemon = True - t.start() - - daemon_pid, worker_pid = 0, 0 - while True: - if os.path.exists(path): - with open(path) as f: - data = f.read().split(' ') - daemon_pid, worker_pid = map(int, data) - break - time.sleep(0.1) - - # cancel jobs - self.sc.cancelAllJobs() - t.join() - - for i in range(50): - try: - os.kill(worker_pid, 0) - time.sleep(0.1) - except OSError: - break # worker was killed - else: - self.fail("worker has not been killed after 5 seconds") - - try: - os.kill(daemon_pid, 0) - except OSError: - self.fail("daemon had been killed") - - # run a normal job - rdd = self.sc.parallelize(xrange(100), 1) - self.assertEqual(100, rdd.map(str).count()) - - def test_after_exception(self): - def raise_exception(_): - raise Exception() - rdd = self.sc.parallelize(xrange(100), 1) - with QuietTest(self.sc): - self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) - self.assertEqual(100, rdd.map(str).count()) - - def test_after_jvm_exception(self): - tempFile = tempfile.NamedTemporaryFile(delete=False) - tempFile.write(b"Hello World!") - tempFile.close() - data = self.sc.textFile(tempFile.name, 1) - filtered_data = data.filter(lambda x: True) - self.assertEqual(1, filtered_data.count()) - os.unlink(tempFile.name) - with QuietTest(self.sc): - self.assertRaises(Exception, lambda: filtered_data.count()) - - rdd = self.sc.parallelize(xrange(100), 1) - self.assertEqual(100, rdd.map(str).count()) - - def test_accumulator_when_reuse_worker(self): - from pyspark.accumulators import INT_ACCUMULATOR_PARAM - acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) - self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x)) - self.assertEqual(sum(range(100)), acc1.value) - - acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) - self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x)) - self.assertEqual(sum(range(100)), acc2.value) - self.assertEqual(sum(range(100)), acc1.value) - - def test_reuse_worker_after_take(self): - rdd = self.sc.parallelize(xrange(100000), 1) - self.assertEqual(0, rdd.first()) - - def count(): - try: - rdd.count() - except Exception: - pass - - t = threading.Thread(target=count) - t.daemon = True - t.start() - t.join(5) - self.assertTrue(not t.isAlive()) - self.assertEqual(100000, rdd.count()) - - def test_with_different_versions_of_python(self): - rdd = self.sc.parallelize(range(10)) - rdd.count() - version = self.sc.pythonVer - self.sc.pythonVer = "2.0" - try: - with QuietTest(self.sc): - self.assertRaises(Py4JJavaError, lambda: rdd.count()) - finally: - self.sc.pythonVer = version - - -class SparkSubmitTests(unittest.TestCase): - - def setUp(self): - self.programDir = tempfile.mkdtemp() - tmp_dir = tempfile.gettempdir() - self.sparkSubmit = [ - os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit"), - "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), - "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), - ] - - def tearDown(self): - shutil.rmtree(self.programDir) - - def createTempFile(self, name, content, dir=None): - """ - Create a temp file with the given name and content and return its path. - Strips leading spaces from content up to the first '|' in each line. - """ - pattern = re.compile(r'^ *\|', re.MULTILINE) - content = re.sub(pattern, '', content.strip()) - if dir is None: - path = os.path.join(self.programDir, name) - else: - os.makedirs(os.path.join(self.programDir, dir)) - path = os.path.join(self.programDir, dir, name) - with open(path, "w") as f: - f.write(content) - return path - - def createFileInZip(self, name, content, ext=".zip", dir=None, zip_name=None): - """ - Create a zip archive containing a file with the given content and return its path. - Strips leading spaces from content up to the first '|' in each line. - """ - pattern = re.compile(r'^ *\|', re.MULTILINE) - content = re.sub(pattern, '', content.strip()) - if dir is None: - path = os.path.join(self.programDir, name + ext) - else: - path = os.path.join(self.programDir, dir, zip_name + ext) - zip = zipfile.ZipFile(path, 'w') - zip.writestr(name, content) - zip.close() - return path - - def create_spark_package(self, artifact_name): - group_id, artifact_id, version = artifact_name.split(":") - self.createTempFile("%s-%s.pom" % (artifact_id, version), (""" - | - | - | 4.0.0 - | %s - | %s - | %s - | - """ % (group_id, artifact_id, version)).lstrip(), - os.path.join(group_id, artifact_id, version)) - self.createFileInZip("%s.py" % artifact_id, """ - |def myfunc(x): - | return x + 1 - """, ".jar", os.path.join(group_id, artifact_id, version), - "%s-%s" % (artifact_id, version)) - - def test_single_script(self): - """Submit and test a single script file""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect()) - """) - proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 4, 6]", out.decode('utf-8')) - - def test_script_with_local_functions(self): - """Submit and test a single script file calling a global function""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - | - |def foo(x): - | return x * 3 - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(foo).collect()) - """) - proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[3, 6, 9]", out.decode('utf-8')) - - def test_module_dependency(self): - """Submit and test a script with a dependency on another module""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - |from mylib import myfunc - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) - """) - zip = self.createFileInZip("mylib.py", """ - |def myfunc(x): - | return x + 1 - """) - proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, script], - stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out.decode('utf-8')) - - def test_module_dependency_on_cluster(self): - """Submit and test a script with a dependency on another module on a cluster""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - |from mylib import myfunc - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) - """) - zip = self.createFileInZip("mylib.py", """ - |def myfunc(x): - | return x + 1 - """) - proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, "--master", - "local-cluster[1,1,1024]", script], - stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out.decode('utf-8')) - - def test_package_dependency(self): - """Submit and test a script with a dependency on a Spark Package""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - |from mylib import myfunc - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) - """) - self.create_spark_package("a:mylib:0.1") - proc = subprocess.Popen( - self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", - "file:" + self.programDir, script], - stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out.decode('utf-8')) - - def test_package_dependency_on_cluster(self): - """Submit and test a script with a dependency on a Spark Package on a cluster""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - |from mylib import myfunc - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) - """) - self.create_spark_package("a:mylib:0.1") - proc = subprocess.Popen( - self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", - "file:" + self.programDir, "--master", "local-cluster[1,1,1024]", - script], - stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out.decode('utf-8')) - - def test_single_script_on_cluster(self): - """Submit and test a single script on a cluster""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - | - |def foo(x): - | return x * 2 - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(foo).collect()) - """) - # this will fail if you have different spark.executor.memory - # in conf/spark-defaults.conf - proc = subprocess.Popen( - self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", script], - stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 4, 6]", out.decode('utf-8')) - - def test_user_configuration(self): - """Make sure user configuration is respected (SPARK-19307)""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkConf, SparkContext - | - |conf = SparkConf().set("spark.test_config", "1") - |sc = SparkContext(conf = conf) - |try: - | if sc._conf.get("spark.test_config") != "1": - | raise Exception("Cannot find spark.test_config in SparkContext's conf.") - |finally: - | sc.stop() - """) - proc = subprocess.Popen( - self.sparkSubmit + ["--master", "local", script], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode, msg="Process failed with error:\n {0}".format(out)) - - -class ContextTests(unittest.TestCase): - - def test_failed_sparkcontext_creation(self): - # Regression test for SPARK-1550 - self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name")) - - def test_get_or_create(self): - with SparkContext.getOrCreate() as sc: - self.assertTrue(SparkContext.getOrCreate() is sc) - - def test_parallelize_eager_cleanup(self): - with SparkContext() as sc: - temp_files = os.listdir(sc._temp_dir) - rdd = sc.parallelize([0, 1, 2]) - post_parallalize_temp_files = os.listdir(sc._temp_dir) - self.assertEqual(temp_files, post_parallalize_temp_files) - - def test_set_conf(self): - # This is for an internal use case. When there is an existing SparkContext, - # SparkSession's builder needs to set configs into SparkContext's conf. - sc = SparkContext() - sc._conf.set("spark.test.SPARK16224", "SPARK16224") - self.assertEqual(sc._jsc.sc().conf().get("spark.test.SPARK16224"), "SPARK16224") - sc.stop() - - def test_stop(self): - sc = SparkContext() - self.assertNotEqual(SparkContext._active_spark_context, None) - sc.stop() - self.assertEqual(SparkContext._active_spark_context, None) - - def test_with(self): - with SparkContext() as sc: - self.assertNotEqual(SparkContext._active_spark_context, None) - self.assertEqual(SparkContext._active_spark_context, None) - - def test_with_exception(self): - try: - with SparkContext() as sc: - self.assertNotEqual(SparkContext._active_spark_context, None) - raise Exception() - except: - pass - self.assertEqual(SparkContext._active_spark_context, None) - - def test_with_stop(self): - with SparkContext() as sc: - self.assertNotEqual(SparkContext._active_spark_context, None) - sc.stop() - self.assertEqual(SparkContext._active_spark_context, None) - - def test_progress_api(self): - with SparkContext() as sc: - sc.setJobGroup('test_progress_api', '', True) - rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100)) - - def run(): - try: - rdd.count() - except Exception: - pass - t = threading.Thread(target=run) - t.daemon = True - t.start() - # wait for scheduler to start - time.sleep(1) - - tracker = sc.statusTracker() - jobIds = tracker.getJobIdsForGroup('test_progress_api') - self.assertEqual(1, len(jobIds)) - job = tracker.getJobInfo(jobIds[0]) - self.assertEqual(1, len(job.stageIds)) - stage = tracker.getStageInfo(job.stageIds[0]) - self.assertEqual(rdd.getNumPartitions(), stage.numTasks) - - sc.cancelAllJobs() - t.join() - # wait for event listener to update the status - time.sleep(1) - - job = tracker.getJobInfo(jobIds[0]) - self.assertEqual('FAILED', job.status) - self.assertEqual([], tracker.getActiveJobsIds()) - self.assertEqual([], tracker.getActiveStageIds()) - - sc.stop() - - def test_startTime(self): - with SparkContext() as sc: - self.assertGreater(sc.startTime, 0) - - -class ConfTests(unittest.TestCase): - def test_memory_conf(self): - memoryList = ["1T", "1G", "1M", "1024K"] - for memory in memoryList: - sc = SparkContext(conf=SparkConf().set("spark.python.worker.memory", memory)) - l = list(range(1024)) - random.shuffle(l) - rdd = sc.parallelize(l, 4) - self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) - sc.stop() - - -class KeywordOnlyTests(unittest.TestCase): - class Wrapped(object): - @keyword_only - def set(self, x=None, y=None): - if "x" in self._input_kwargs: - self._x = self._input_kwargs["x"] - if "y" in self._input_kwargs: - self._y = self._input_kwargs["y"] - return x, y - - def test_keywords(self): - w = self.Wrapped() - x, y = w.set(y=1) - self.assertEqual(y, 1) - self.assertEqual(y, w._y) - self.assertIsNone(x) - self.assertFalse(hasattr(w, "_x")) - - def test_non_keywords(self): - w = self.Wrapped() - self.assertRaises(TypeError, lambda: w.set(0, y=1)) - - def test_kwarg_ownership(self): - # test _input_kwargs is owned by each class instance and not a shared static variable - class Setter(object): - @keyword_only - def set(self, x=None, other=None, other_x=None): - if "other" in self._input_kwargs: - self._input_kwargs["other"].set(x=self._input_kwargs["other_x"]) - self._x = self._input_kwargs["x"] - - a = Setter() - b = Setter() - a.set(x=1, other=b, other_x=2) - self.assertEqual(a._x, 1) - self.assertEqual(b._x, 2) - - -class UtilTests(PySparkTestCase): - def test_py4j_exception_message(self): - from pyspark.util import _exception_message - - with self.assertRaises(Py4JJavaError) as context: - # This attempts java.lang.String(null) which throws an NPE. - self.sc._jvm.java.lang.String(None) - - self.assertTrue('NullPointerException' in _exception_message(context.exception)) - - def test_parsing_version_string(self): - from pyspark.util import VersionUtils - self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced")) - - -@unittest.skipIf(not _have_scipy, "SciPy not installed") -class SciPyTests(PySparkTestCase): - - """General PySpark tests that depend on scipy """ - - def test_serialize(self): - from scipy.special import gammaln - x = range(1, 5) - expected = list(map(gammaln, x)) - observed = self.sc.parallelize(x).map(gammaln).collect() - self.assertEqual(expected, observed) - - -@unittest.skipIf(not _have_numpy, "NumPy not installed") -class NumPyTests(PySparkTestCase): - - """General PySpark tests that depend on numpy """ - - def test_statcounter_array(self): - x = self.sc.parallelize([np.array([1.0, 1.0]), np.array([2.0, 2.0]), np.array([3.0, 3.0])]) - s = x.stats() - self.assertSequenceEqual([2.0, 2.0], s.mean().tolist()) - self.assertSequenceEqual([1.0, 1.0], s.min().tolist()) - self.assertSequenceEqual([3.0, 3.0], s.max().tolist()) - self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist()) - - stats_dict = s.asDict() - self.assertEqual(3, stats_dict['count']) - self.assertSequenceEqual([2.0, 2.0], stats_dict['mean'].tolist()) - self.assertSequenceEqual([1.0, 1.0], stats_dict['min'].tolist()) - self.assertSequenceEqual([3.0, 3.0], stats_dict['max'].tolist()) - self.assertSequenceEqual([6.0, 6.0], stats_dict['sum'].tolist()) - self.assertSequenceEqual([1.0, 1.0], stats_dict['stdev'].tolist()) - self.assertSequenceEqual([1.0, 1.0], stats_dict['variance'].tolist()) - - stats_sample_dict = s.asDict(sample=True) - self.assertEqual(3, stats_dict['count']) - self.assertSequenceEqual([2.0, 2.0], stats_sample_dict['mean'].tolist()) - self.assertSequenceEqual([1.0, 1.0], stats_sample_dict['min'].tolist()) - self.assertSequenceEqual([3.0, 3.0], stats_sample_dict['max'].tolist()) - self.assertSequenceEqual([6.0, 6.0], stats_sample_dict['sum'].tolist()) - self.assertSequenceEqual( - [0.816496580927726, 0.816496580927726], stats_sample_dict['stdev'].tolist()) - self.assertSequenceEqual( - [0.6666666666666666, 0.6666666666666666], stats_sample_dict['variance'].tolist()) - - -if __name__ == "__main__": - from pyspark.tests import * - if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) - else: - unittest.main(verbosity=2) diff --git a/python/pyspark/tests/__init__.py b/python/pyspark/tests/__init__.py new file mode 100644 index 0000000000000..12bdf0d0175b6 --- /dev/null +++ b/python/pyspark/tests/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/tests/test_appsubmit.py b/python/pyspark/tests/test_appsubmit.py new file mode 100644 index 0000000000000..92bcb11561307 --- /dev/null +++ b/python/pyspark/tests/test_appsubmit.py @@ -0,0 +1,248 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import re +import shutil +import subprocess +import tempfile +import unittest +import zipfile + + +class SparkSubmitTests(unittest.TestCase): + + def setUp(self): + self.programDir = tempfile.mkdtemp() + tmp_dir = tempfile.gettempdir() + self.sparkSubmit = [ + os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit"), + "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + ] + + def tearDown(self): + shutil.rmtree(self.programDir) + + def createTempFile(self, name, content, dir=None): + """ + Create a temp file with the given name and content and return its path. + Strips leading spaces from content up to the first '|' in each line. + """ + pattern = re.compile(r'^ *\|', re.MULTILINE) + content = re.sub(pattern, '', content.strip()) + if dir is None: + path = os.path.join(self.programDir, name) + else: + os.makedirs(os.path.join(self.programDir, dir)) + path = os.path.join(self.programDir, dir, name) + with open(path, "w") as f: + f.write(content) + return path + + def createFileInZip(self, name, content, ext=".zip", dir=None, zip_name=None): + """ + Create a zip archive containing a file with the given content and return its path. + Strips leading spaces from content up to the first '|' in each line. + """ + pattern = re.compile(r'^ *\|', re.MULTILINE) + content = re.sub(pattern, '', content.strip()) + if dir is None: + path = os.path.join(self.programDir, name + ext) + else: + path = os.path.join(self.programDir, dir, zip_name + ext) + zip = zipfile.ZipFile(path, 'w') + zip.writestr(name, content) + zip.close() + return path + + def create_spark_package(self, artifact_name): + group_id, artifact_id, version = artifact_name.split(":") + self.createTempFile("%s-%s.pom" % (artifact_id, version), (""" + | + | + | 4.0.0 + | %s + | %s + | %s + | + """ % (group_id, artifact_id, version)).lstrip(), + os.path.join(group_id, artifact_id, version)) + self.createFileInZip("%s.py" % artifact_id, """ + |def myfunc(x): + | return x + 1 + """, ".jar", os.path.join(group_id, artifact_id, version), + "%s-%s" % (artifact_id, version)) + + def test_single_script(self): + """Submit and test a single script file""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect()) + """) + proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 4, 6]", out.decode('utf-8')) + + def test_script_with_local_functions(self): + """Submit and test a single script file calling a global function""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + | + |def foo(x): + | return x * 3 + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(foo).collect()) + """) + proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[3, 6, 9]", out.decode('utf-8')) + + def test_module_dependency(self): + """Submit and test a script with a dependency on another module""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + |from mylib import myfunc + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) + """) + zip = self.createFileInZip("mylib.py", """ + |def myfunc(x): + | return x + 1 + """) + proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, script], + stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) + + def test_module_dependency_on_cluster(self): + """Submit and test a script with a dependency on another module on a cluster""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + |from mylib import myfunc + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) + """) + zip = self.createFileInZip("mylib.py", """ + |def myfunc(x): + | return x + 1 + """) + proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, "--master", + "local-cluster[1,1,1024]", script], + stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) + + def test_package_dependency(self): + """Submit and test a script with a dependency on a Spark Package""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + |from mylib import myfunc + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) + """) + self.create_spark_package("a:mylib:0.1") + proc = subprocess.Popen( + self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, script], + stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) + + def test_package_dependency_on_cluster(self): + """Submit and test a script with a dependency on a Spark Package on a cluster""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + |from mylib import myfunc + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) + """) + self.create_spark_package("a:mylib:0.1") + proc = subprocess.Popen( + self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, "--master", "local-cluster[1,1,1024]", + script], + stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) + + def test_single_script_on_cluster(self): + """Submit and test a single script on a cluster""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + | + |def foo(x): + | return x * 2 + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(foo).collect()) + """) + # this will fail if you have different spark.executor.memory + # in conf/spark-defaults.conf + proc = subprocess.Popen( + self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", script], + stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 4, 6]", out.decode('utf-8')) + + def test_user_configuration(self): + """Make sure user configuration is respected (SPARK-19307)""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkConf, SparkContext + | + |conf = SparkConf().set("spark.test_config", "1") + |sc = SparkContext(conf = conf) + |try: + | if sc._conf.get("spark.test_config") != "1": + | raise Exception("Cannot find spark.test_config in SparkContext's conf.") + |finally: + | sc.stop() + """) + proc = subprocess.Popen( + self.sparkSubmit + ["--master", "local", script], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode, msg="Process failed with error:\n {0}".format(out)) + + +if __name__ == "__main__": + from pyspark.tests.test_appsubmit import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/test_broadcast.py b/python/pyspark/tests/test_broadcast.py similarity index 91% rename from python/pyspark/test_broadcast.py rename to python/pyspark/tests/test_broadcast.py index a00329c18ad8f..a98626e8f4bc9 100644 --- a/python/pyspark/test_broadcast.py +++ b/python/pyspark/tests/test_broadcast.py @@ -14,20 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # - import os import random import tempfile import unittest -try: - import xmlrunner -except ImportError: - xmlrunner = None - -from pyspark.broadcast import Broadcast -from pyspark.conf import SparkConf -from pyspark.context import SparkContext +from pyspark import SparkConf, SparkContext from pyspark.java_gateway import launch_gateway from pyspark.serializers import ChunkedStream @@ -118,9 +110,13 @@ def random_bytes(n): for buffer_length in [1, 2, 5, 8192]: self._test_chunked_stream(random_bytes(data_length), buffer_length) + if __name__ == '__main__': - from pyspark.test_broadcast import * - if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) - else: - unittest.main(verbosity=2) + from pyspark.tests.test_broadcast import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_conf.py b/python/pyspark/tests/test_conf.py new file mode 100644 index 0000000000000..f5a9accc3fe6e --- /dev/null +++ b/python/pyspark/tests/test_conf.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import random +import unittest + +from pyspark import SparkContext, SparkConf + + +class ConfTests(unittest.TestCase): + def test_memory_conf(self): + memoryList = ["1T", "1G", "1M", "1024K"] + for memory in memoryList: + sc = SparkContext(conf=SparkConf().set("spark.python.worker.memory", memory)) + l = list(range(1024)) + random.shuffle(l) + rdd = sc.parallelize(l, 4) + self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) + sc.stop() + + +if __name__ == "__main__": + from pyspark.tests.test_conf import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_context.py b/python/pyspark/tests/test_context.py new file mode 100644 index 0000000000000..201baf420354d --- /dev/null +++ b/python/pyspark/tests/test_context.py @@ -0,0 +1,258 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import shutil +import tempfile +import threading +import time +import unittest + +from pyspark import SparkFiles, SparkContext +from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest, SPARK_HOME + + +class CheckpointTests(ReusedPySparkTestCase): + + def setUp(self): + self.checkpointDir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(self.checkpointDir.name) + self.sc.setCheckpointDir(self.checkpointDir.name) + + def tearDown(self): + shutil.rmtree(self.checkpointDir.name) + + def test_basic_checkpointing(self): + parCollection = self.sc.parallelize([1, 2, 3, 4]) + flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) + + self.assertFalse(flatMappedRDD.isCheckpointed()) + self.assertTrue(flatMappedRDD.getCheckpointFile() is None) + + flatMappedRDD.checkpoint() + result = flatMappedRDD.collect() + time.sleep(1) # 1 second + self.assertTrue(flatMappedRDD.isCheckpointed()) + self.assertEqual(flatMappedRDD.collect(), result) + self.assertEqual("file:" + self.checkpointDir.name, + os.path.dirname(os.path.dirname(flatMappedRDD.getCheckpointFile()))) + + def test_checkpoint_and_restore(self): + parCollection = self.sc.parallelize([1, 2, 3, 4]) + flatMappedRDD = parCollection.flatMap(lambda x: [x]) + + self.assertFalse(flatMappedRDD.isCheckpointed()) + self.assertTrue(flatMappedRDD.getCheckpointFile() is None) + + flatMappedRDD.checkpoint() + flatMappedRDD.count() # forces a checkpoint to be computed + time.sleep(1) # 1 second + + self.assertTrue(flatMappedRDD.getCheckpointFile() is not None) + recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(), + flatMappedRDD._jrdd_deserializer) + self.assertEqual([1, 2, 3, 4], recovered.collect()) + + +class LocalCheckpointTests(ReusedPySparkTestCase): + + def test_basic_localcheckpointing(self): + parCollection = self.sc.parallelize([1, 2, 3, 4]) + flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) + + self.assertFalse(flatMappedRDD.isCheckpointed()) + self.assertFalse(flatMappedRDD.isLocallyCheckpointed()) + + flatMappedRDD.localCheckpoint() + result = flatMappedRDD.collect() + time.sleep(1) # 1 second + self.assertTrue(flatMappedRDD.isCheckpointed()) + self.assertTrue(flatMappedRDD.isLocallyCheckpointed()) + self.assertEqual(flatMappedRDD.collect(), result) + + +class AddFileTests(PySparkTestCase): + + def test_add_py_file(self): + # To ensure that we're actually testing addPyFile's effects, check that + # this job fails due to `userlibrary` not being on the Python path: + # disable logging in log4j temporarily + def func(x): + from userlibrary import UserClass + return UserClass().hello() + with QuietTest(self.sc): + self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first) + + # Add the file, so the job should now succeed: + path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") + self.sc.addPyFile(path) + res = self.sc.parallelize(range(2)).map(func).first() + self.assertEqual("Hello World!", res) + + def test_add_file_locally(self): + path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + self.sc.addFile(path) + download_path = SparkFiles.get("hello.txt") + self.assertNotEqual(path, download_path) + with open(download_path) as test_file: + self.assertEqual("Hello World!\n", test_file.readline()) + + def test_add_file_recursively_locally(self): + path = os.path.join(SPARK_HOME, "python/test_support/hello") + self.sc.addFile(path, True) + download_path = SparkFiles.get("hello") + self.assertNotEqual(path, download_path) + with open(download_path + "/hello.txt") as test_file: + self.assertEqual("Hello World!\n", test_file.readline()) + with open(download_path + "/sub_hello/sub_hello.txt") as test_file: + self.assertEqual("Sub Hello World!\n", test_file.readline()) + + def test_add_py_file_locally(self): + # To ensure that we're actually testing addPyFile's effects, check that + # this fails due to `userlibrary` not being on the Python path: + def func(): + from userlibrary import UserClass + self.assertRaises(ImportError, func) + path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") + self.sc.addPyFile(path) + from userlibrary import UserClass + self.assertEqual("Hello World!", UserClass().hello()) + + def test_add_egg_file_locally(self): + # To ensure that we're actually testing addPyFile's effects, check that + # this fails due to `userlibrary` not being on the Python path: + def func(): + from userlib import UserClass + self.assertRaises(ImportError, func) + path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip") + self.sc.addPyFile(path) + from userlib import UserClass + self.assertEqual("Hello World from inside a package!", UserClass().hello()) + + def test_overwrite_system_module(self): + self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py")) + + import SimpleHTTPServer + self.assertEqual("My Server", SimpleHTTPServer.__name__) + + def func(x): + import SimpleHTTPServer + return SimpleHTTPServer.__name__ + + self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect()) + + +class ContextTests(unittest.TestCase): + + def test_failed_sparkcontext_creation(self): + # Regression test for SPARK-1550 + self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name")) + + def test_get_or_create(self): + with SparkContext.getOrCreate() as sc: + self.assertTrue(SparkContext.getOrCreate() is sc) + + def test_parallelize_eager_cleanup(self): + with SparkContext() as sc: + temp_files = os.listdir(sc._temp_dir) + rdd = sc.parallelize([0, 1, 2]) + post_parallalize_temp_files = os.listdir(sc._temp_dir) + self.assertEqual(temp_files, post_parallalize_temp_files) + + def test_set_conf(self): + # This is for an internal use case. When there is an existing SparkContext, + # SparkSession's builder needs to set configs into SparkContext's conf. + sc = SparkContext() + sc._conf.set("spark.test.SPARK16224", "SPARK16224") + self.assertEqual(sc._jsc.sc().conf().get("spark.test.SPARK16224"), "SPARK16224") + sc.stop() + + def test_stop(self): + sc = SparkContext() + self.assertNotEqual(SparkContext._active_spark_context, None) + sc.stop() + self.assertEqual(SparkContext._active_spark_context, None) + + def test_with(self): + with SparkContext() as sc: + self.assertNotEqual(SparkContext._active_spark_context, None) + self.assertEqual(SparkContext._active_spark_context, None) + + def test_with_exception(self): + try: + with SparkContext() as sc: + self.assertNotEqual(SparkContext._active_spark_context, None) + raise Exception() + except: + pass + self.assertEqual(SparkContext._active_spark_context, None) + + def test_with_stop(self): + with SparkContext() as sc: + self.assertNotEqual(SparkContext._active_spark_context, None) + sc.stop() + self.assertEqual(SparkContext._active_spark_context, None) + + def test_progress_api(self): + with SparkContext() as sc: + sc.setJobGroup('test_progress_api', '', True) + rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100)) + + def run(): + try: + rdd.count() + except Exception: + pass + t = threading.Thread(target=run) + t.daemon = True + t.start() + # wait for scheduler to start + time.sleep(1) + + tracker = sc.statusTracker() + jobIds = tracker.getJobIdsForGroup('test_progress_api') + self.assertEqual(1, len(jobIds)) + job = tracker.getJobInfo(jobIds[0]) + self.assertEqual(1, len(job.stageIds)) + stage = tracker.getStageInfo(job.stageIds[0]) + self.assertEqual(rdd.getNumPartitions(), stage.numTasks) + + sc.cancelAllJobs() + t.join() + # wait for event listener to update the status + time.sleep(1) + + job = tracker.getJobInfo(jobIds[0]) + self.assertEqual('FAILED', job.status) + self.assertEqual([], tracker.getActiveJobsIds()) + self.assertEqual([], tracker.getActiveStageIds()) + + sc.stop() + + def test_startTime(self): + with SparkContext() as sc: + self.assertGreater(sc.startTime, 0) + + +if __name__ == "__main__": + from pyspark.tests.test_context import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_daemon.py b/python/pyspark/tests/test_daemon.py new file mode 100644 index 0000000000000..fccd74fff1516 --- /dev/null +++ b/python/pyspark/tests/test_daemon.py @@ -0,0 +1,80 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import sys +import time +import unittest + +from pyspark.serializers import read_int + + +class DaemonTests(unittest.TestCase): + def connect(self, port): + from socket import socket, AF_INET, SOCK_STREAM + sock = socket(AF_INET, SOCK_STREAM) + sock.connect(('127.0.0.1', port)) + # send a split index of -1 to shutdown the worker + sock.send(b"\xFF\xFF\xFF\xFF") + sock.close() + return True + + def do_termination_test(self, terminator): + from subprocess import Popen, PIPE + from errno import ECONNREFUSED + + # start daemon + daemon_path = os.path.join(os.path.dirname(__file__), "..", "daemon.py") + python_exec = sys.executable or os.environ.get("PYSPARK_PYTHON") + daemon = Popen([python_exec, daemon_path], stdin=PIPE, stdout=PIPE) + + # read the port number + port = read_int(daemon.stdout) + + # daemon should accept connections + self.assertTrue(self.connect(port)) + + # request shutdown + terminator(daemon) + time.sleep(1) + + # daemon should no longer accept connections + try: + self.connect(port) + except EnvironmentError as exception: + self.assertEqual(exception.errno, ECONNREFUSED) + else: + self.fail("Expected EnvironmentError to be raised") + + def test_termination_stdin(self): + """Ensure that daemon and workers terminate when stdin is closed.""" + self.do_termination_test(lambda daemon: daemon.stdin.close()) + + def test_termination_sigterm(self): + """Ensure that daemon and workers terminate on SIGTERM.""" + from signal import SIGTERM + self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) + + +if __name__ == "__main__": + from pyspark.tests.test_daemon import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_join.py b/python/pyspark/tests/test_join.py new file mode 100644 index 0000000000000..e97e695f8b20d --- /dev/null +++ b/python/pyspark/tests/test_join.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from pyspark.testing.utils import ReusedPySparkTestCase + + +class JoinTests(ReusedPySparkTestCase): + + def test_narrow_dependency_in_join(self): + rdd = self.sc.parallelize(range(10)).map(lambda x: (x, x)) + parted = rdd.partitionBy(2) + self.assertEqual(2, parted.union(parted).getNumPartitions()) + self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions()) + self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions()) + + tracker = self.sc.statusTracker() + + self.sc.setJobGroup("test1", "test", True) + d = sorted(parted.join(parted).collect()) + self.assertEqual(10, len(d)) + self.assertEqual((0, (0, 0)), d[0]) + jobId = tracker.getJobIdsForGroup("test1")[0] + self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) + + self.sc.setJobGroup("test2", "test", True) + d = sorted(parted.join(rdd).collect()) + self.assertEqual(10, len(d)) + self.assertEqual((0, (0, 0)), d[0]) + jobId = tracker.getJobIdsForGroup("test2")[0] + self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) + + self.sc.setJobGroup("test3", "test", True) + d = sorted(parted.cogroup(parted).collect()) + self.assertEqual(10, len(d)) + self.assertEqual([[0], [0]], list(map(list, d[0][1]))) + jobId = tracker.getJobIdsForGroup("test3")[0] + self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) + + self.sc.setJobGroup("test4", "test", True) + d = sorted(parted.cogroup(rdd).collect()) + self.assertEqual(10, len(d)) + self.assertEqual([[0], [0]], list(map(list, d[0][1]))) + jobId = tracker.getJobIdsForGroup("test4")[0] + self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) + + +if __name__ == "__main__": + import unittest + from pyspark.tests.test_join import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_profiler.py b/python/pyspark/tests/test_profiler.py new file mode 100644 index 0000000000000..56cbcff01657c --- /dev/null +++ b/python/pyspark/tests/test_profiler.py @@ -0,0 +1,112 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import sys +import tempfile +import unittest + +from pyspark import SparkConf, SparkContext, BasicProfiler +from pyspark.testing.utils import PySparkTestCase + +if sys.version >= "3": + from io import StringIO +else: + from StringIO import StringIO + + +class ProfilerTests(PySparkTestCase): + + def setUp(self): + self._old_sys_path = list(sys.path) + class_name = self.__class__.__name__ + conf = SparkConf().set("spark.python.profile", "true") + self.sc = SparkContext('local[4]', class_name, conf=conf) + + def test_profiler(self): + self.do_computation() + + profilers = self.sc.profiler_collector.profilers + self.assertEqual(1, len(profilers)) + id, profiler, _ = profilers[0] + stats = profiler.stats() + self.assertTrue(stats is not None) + width, stat_list = stats.get_print_list([]) + func_names = [func_name for fname, n, func_name in stat_list] + self.assertTrue("heavy_foo" in func_names) + + old_stdout = sys.stdout + sys.stdout = io = StringIO() + self.sc.show_profiles() + self.assertTrue("heavy_foo" in io.getvalue()) + sys.stdout = old_stdout + + d = tempfile.gettempdir() + self.sc.dump_profiles(d) + self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) + + def test_custom_profiler(self): + class TestCustomProfiler(BasicProfiler): + def show(self, id): + self.result = "Custom formatting" + + self.sc.profiler_collector.profiler_cls = TestCustomProfiler + + self.do_computation() + + profilers = self.sc.profiler_collector.profilers + self.assertEqual(1, len(profilers)) + _, profiler, _ = profilers[0] + self.assertTrue(isinstance(profiler, TestCustomProfiler)) + + self.sc.show_profiles() + self.assertEqual("Custom formatting", profiler.result) + + def do_computation(self): + def heavy_foo(x): + for i in range(1 << 18): + x = 1 + + rdd = self.sc.parallelize(range(100)) + rdd.foreach(heavy_foo) + + +class ProfilerTests2(unittest.TestCase): + def test_profiler_disabled(self): + sc = SparkContext(conf=SparkConf().set("spark.python.profile", "false")) + try: + self.assertRaisesRegexp( + RuntimeError, + "'spark.python.profile' configuration must be set", + lambda: sc.show_profiles()) + self.assertRaisesRegexp( + RuntimeError, + "'spark.python.profile' configuration must be set", + lambda: sc.dump_profiles("/tmp/abc")) + finally: + sc.stop() + + +if __name__ == "__main__": + from pyspark.tests.test_profiler import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py new file mode 100644 index 0000000000000..b2a544b8de78a --- /dev/null +++ b/python/pyspark/tests/test_rdd.py @@ -0,0 +1,739 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import hashlib +import os +import random +import sys +import tempfile +from glob import glob + +from py4j.protocol import Py4JJavaError + +from pyspark import shuffle, RDD +from pyspark.serializers import CloudPickleSerializer, BatchedSerializer, PickleSerializer,\ + MarshalSerializer, UTF8Deserializer, NoOpSerializer +from pyspark.testing.utils import ReusedPySparkTestCase, SPARK_HOME, QuietTest + +if sys.version_info[0] >= 3: + xrange = range + + +class RDDTests(ReusedPySparkTestCase): + + def test_range(self): + self.assertEqual(self.sc.range(1, 1).count(), 0) + self.assertEqual(self.sc.range(1, 0, -1).count(), 1) + self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2) + + def test_id(self): + rdd = self.sc.parallelize(range(10)) + id = rdd.id() + self.assertEqual(id, rdd.id()) + rdd2 = rdd.map(str).filter(bool) + id2 = rdd2.id() + self.assertEqual(id + 1, id2) + self.assertEqual(id2, rdd2.id()) + + def test_empty_rdd(self): + rdd = self.sc.emptyRDD() + self.assertTrue(rdd.isEmpty()) + + def test_sum(self): + self.assertEqual(0, self.sc.emptyRDD().sum()) + self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum()) + + def test_to_localiterator(self): + from time import sleep + rdd = self.sc.parallelize([1, 2, 3]) + it = rdd.toLocalIterator() + sleep(5) + self.assertEqual([1, 2, 3], sorted(it)) + + rdd2 = rdd.repartition(1000) + it2 = rdd2.toLocalIterator() + sleep(5) + self.assertEqual([1, 2, 3], sorted(it2)) + + def test_save_as_textfile_with_unicode(self): + # Regression test for SPARK-970 + x = u"\u00A1Hola, mundo!" + data = self.sc.parallelize([x]) + tempFile = tempfile.NamedTemporaryFile(delete=True) + tempFile.close() + data.saveAsTextFile(tempFile.name) + raw_contents = b''.join(open(p, 'rb').read() + for p in glob(tempFile.name + "/part-0000*")) + self.assertEqual(x, raw_contents.strip().decode("utf-8")) + + def test_save_as_textfile_with_utf8(self): + x = u"\u00A1Hola, mundo!" + data = self.sc.parallelize([x.encode("utf-8")]) + tempFile = tempfile.NamedTemporaryFile(delete=True) + tempFile.close() + data.saveAsTextFile(tempFile.name) + raw_contents = b''.join(open(p, 'rb').read() + for p in glob(tempFile.name + "/part-0000*")) + self.assertEqual(x, raw_contents.strip().decode('utf8')) + + def test_transforming_cartesian_result(self): + # Regression test for SPARK-1034 + rdd1 = self.sc.parallelize([1, 2]) + rdd2 = self.sc.parallelize([3, 4]) + cart = rdd1.cartesian(rdd2) + result = cart.map(lambda x_y3: x_y3[0] + x_y3[1]).collect() + + def test_transforming_pickle_file(self): + # Regression test for SPARK-2601 + data = self.sc.parallelize([u"Hello", u"World!"]) + tempFile = tempfile.NamedTemporaryFile(delete=True) + tempFile.close() + data.saveAsPickleFile(tempFile.name) + pickled_file = self.sc.pickleFile(tempFile.name) + pickled_file.map(lambda x: x).collect() + + def test_cartesian_on_textfile(self): + # Regression test for + path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + a = self.sc.textFile(path) + result = a.cartesian(a).collect() + (x, y) = result[0] + self.assertEqual(u"Hello World!", x.strip()) + self.assertEqual(u"Hello World!", y.strip()) + + def test_cartesian_chaining(self): + # Tests for SPARK-16589 + rdd = self.sc.parallelize(range(10), 2) + self.assertSetEqual( + set(rdd.cartesian(rdd).cartesian(rdd).collect()), + set([((x, y), z) for x in range(10) for y in range(10) for z in range(10)]) + ) + + self.assertSetEqual( + set(rdd.cartesian(rdd.cartesian(rdd)).collect()), + set([(x, (y, z)) for x in range(10) for y in range(10) for z in range(10)]) + ) + + self.assertSetEqual( + set(rdd.cartesian(rdd.zip(rdd)).collect()), + set([(x, (y, y)) for x in range(10) for y in range(10)]) + ) + + def test_zip_chaining(self): + # Tests for SPARK-21985 + rdd = self.sc.parallelize('abc', 2) + self.assertSetEqual( + set(rdd.zip(rdd).zip(rdd).collect()), + set([((x, x), x) for x in 'abc']) + ) + self.assertSetEqual( + set(rdd.zip(rdd.zip(rdd)).collect()), + set([(x, (x, x)) for x in 'abc']) + ) + + def test_deleting_input_files(self): + # Regression test for SPARK-1025 + tempFile = tempfile.NamedTemporaryFile(delete=False) + tempFile.write(b"Hello World!") + tempFile.close() + data = self.sc.textFile(tempFile.name) + filtered_data = data.filter(lambda x: True) + self.assertEqual(1, filtered_data.count()) + os.unlink(tempFile.name) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: filtered_data.count()) + + def test_sampling_default_seed(self): + # Test for SPARK-3995 (default seed setting) + data = self.sc.parallelize(xrange(1000), 1) + subset = data.takeSample(False, 10) + self.assertEqual(len(subset), 10) + + def test_aggregate_mutable_zero_value(self): + # Test for SPARK-9021; uses aggregate and treeAggregate to build dict + # representing a counter of ints + # NOTE: dict is used instead of collections.Counter for Python 2.6 + # compatibility + from collections import defaultdict + + # Show that single or multiple partitions work + data1 = self.sc.range(10, numSlices=1) + data2 = self.sc.range(10, numSlices=2) + + def seqOp(x, y): + x[y] += 1 + return x + + def comboOp(x, y): + for key, val in y.items(): + x[key] += val + return x + + counts1 = data1.aggregate(defaultdict(int), seqOp, comboOp) + counts2 = data2.aggregate(defaultdict(int), seqOp, comboOp) + counts3 = data1.treeAggregate(defaultdict(int), seqOp, comboOp, 2) + counts4 = data2.treeAggregate(defaultdict(int), seqOp, comboOp, 2) + + ground_truth = defaultdict(int, dict((i, 1) for i in range(10))) + self.assertEqual(counts1, ground_truth) + self.assertEqual(counts2, ground_truth) + self.assertEqual(counts3, ground_truth) + self.assertEqual(counts4, ground_truth) + + def test_aggregate_by_key_mutable_zero_value(self): + # Test for SPARK-9021; uses aggregateByKey to make a pair RDD that + # contains lists of all values for each key in the original RDD + + # list(range(...)) for Python 3.x compatibility (can't use * operator + # on a range object) + # list(zip(...)) for Python 3.x compatibility (want to parallelize a + # collection, not a zip object) + tuples = list(zip(list(range(10))*2, [1]*20)) + # Show that single or multiple partitions work + data1 = self.sc.parallelize(tuples, 1) + data2 = self.sc.parallelize(tuples, 2) + + def seqOp(x, y): + x.append(y) + return x + + def comboOp(x, y): + x.extend(y) + return x + + values1 = data1.aggregateByKey([], seqOp, comboOp).collect() + values2 = data2.aggregateByKey([], seqOp, comboOp).collect() + # Sort lists to ensure clean comparison with ground_truth + values1.sort() + values2.sort() + + ground_truth = [(i, [1]*2) for i in range(10)] + self.assertEqual(values1, ground_truth) + self.assertEqual(values2, ground_truth) + + def test_fold_mutable_zero_value(self): + # Test for SPARK-9021; uses fold to merge an RDD of dict counters into + # a single dict + # NOTE: dict is used instead of collections.Counter for Python 2.6 + # compatibility + from collections import defaultdict + + counts1 = defaultdict(int, dict((i, 1) for i in range(10))) + counts2 = defaultdict(int, dict((i, 1) for i in range(3, 8))) + counts3 = defaultdict(int, dict((i, 1) for i in range(4, 7))) + counts4 = defaultdict(int, dict((i, 1) for i in range(5, 6))) + all_counts = [counts1, counts2, counts3, counts4] + # Show that single or multiple partitions work + data1 = self.sc.parallelize(all_counts, 1) + data2 = self.sc.parallelize(all_counts, 2) + + def comboOp(x, y): + for key, val in y.items(): + x[key] += val + return x + + fold1 = data1.fold(defaultdict(int), comboOp) + fold2 = data2.fold(defaultdict(int), comboOp) + + ground_truth = defaultdict(int) + for counts in all_counts: + for key, val in counts.items(): + ground_truth[key] += val + self.assertEqual(fold1, ground_truth) + self.assertEqual(fold2, ground_truth) + + def test_fold_by_key_mutable_zero_value(self): + # Test for SPARK-9021; uses foldByKey to make a pair RDD that contains + # lists of all values for each key in the original RDD + + tuples = [(i, range(i)) for i in range(10)]*2 + # Show that single or multiple partitions work + data1 = self.sc.parallelize(tuples, 1) + data2 = self.sc.parallelize(tuples, 2) + + def comboOp(x, y): + x.extend(y) + return x + + values1 = data1.foldByKey([], comboOp).collect() + values2 = data2.foldByKey([], comboOp).collect() + # Sort lists to ensure clean comparison with ground_truth + values1.sort() + values2.sort() + + # list(range(...)) for Python 3.x compatibility + ground_truth = [(i, list(range(i))*2) for i in range(10)] + self.assertEqual(values1, ground_truth) + self.assertEqual(values2, ground_truth) + + def test_aggregate_by_key(self): + data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2) + + def seqOp(x, y): + x.add(y) + return x + + def combOp(x, y): + x |= y + return x + + sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect()) + self.assertEqual(3, len(sets)) + self.assertEqual(set([1]), sets[1]) + self.assertEqual(set([2]), sets[3]) + self.assertEqual(set([1, 3]), sets[5]) + + def test_itemgetter(self): + rdd = self.sc.parallelize([range(10)]) + from operator import itemgetter + self.assertEqual([1], rdd.map(itemgetter(1)).collect()) + self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect()) + + def test_namedtuple_in_rdd(self): + from collections import namedtuple + Person = namedtuple("Person", "id firstName lastName") + jon = Person(1, "Jon", "Doe") + jane = Person(2, "Jane", "Doe") + theDoes = self.sc.parallelize([jon, jane]) + self.assertEqual([jon, jane], theDoes.collect()) + + def test_large_broadcast(self): + N = 10000 + data = [[float(i) for i in range(300)] for i in range(N)] + bdata = self.sc.broadcast(data) # 27MB + m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + self.assertEqual(N, m) + + def test_unpersist(self): + N = 1000 + data = [[float(i) for i in range(300)] for i in range(N)] + bdata = self.sc.broadcast(data) # 3MB + bdata.unpersist() + m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + self.assertEqual(N, m) + bdata.destroy() + try: + self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + except Exception as e: + pass + else: + raise Exception("job should fail after destroy the broadcast") + + def test_multiple_broadcasts(self): + N = 1 << 21 + b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM + r = list(range(1 << 15)) + random.shuffle(r) + s = str(r).encode() + checksum = hashlib.md5(s).hexdigest() + b2 = self.sc.broadcast(s) + r = list(set(self.sc.parallelize(range(10), 10).map( + lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) + self.assertEqual(1, len(r)) + size, csum = r[0] + self.assertEqual(N, size) + self.assertEqual(checksum, csum) + + random.shuffle(r) + s = str(r).encode() + checksum = hashlib.md5(s).hexdigest() + b2 = self.sc.broadcast(s) + r = list(set(self.sc.parallelize(range(10), 10).map( + lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) + self.assertEqual(1, len(r)) + size, csum = r[0] + self.assertEqual(N, size) + self.assertEqual(checksum, csum) + + def test_multithread_broadcast_pickle(self): + import threading + + b1 = self.sc.broadcast(list(range(3))) + b2 = self.sc.broadcast(list(range(3))) + + def f1(): + return b1.value + + def f2(): + return b2.value + + funcs_num_pickled = {f1: None, f2: None} + + def do_pickle(f, sc): + command = (f, None, sc.serializer, sc.serializer) + ser = CloudPickleSerializer() + ser.dumps(command) + + def process_vars(sc): + broadcast_vars = list(sc._pickled_broadcast_vars) + num_pickled = len(broadcast_vars) + sc._pickled_broadcast_vars.clear() + return num_pickled + + def run(f, sc): + do_pickle(f, sc) + funcs_num_pickled[f] = process_vars(sc) + + # pickle f1, adds b1 to sc._pickled_broadcast_vars in main thread local storage + do_pickle(f1, self.sc) + + # run all for f2, should only add/count/clear b2 from worker thread local storage + t = threading.Thread(target=run, args=(f2, self.sc)) + t.start() + t.join() + + # count number of vars pickled in main thread, only b1 should be counted and cleared + funcs_num_pickled[f1] = process_vars(self.sc) + + self.assertEqual(funcs_num_pickled[f1], 1) + self.assertEqual(funcs_num_pickled[f2], 1) + self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0) + + def test_large_closure(self): + N = 200000 + data = [float(i) for i in xrange(N)] + rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data)) + self.assertEqual(N, rdd.first()) + # regression test for SPARK-6886 + self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count()) + + def test_zip_with_different_serializers(self): + a = self.sc.parallelize(range(5)) + b = self.sc.parallelize(range(100, 105)) + self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) + a = a._reserialize(BatchedSerializer(PickleSerializer(), 2)) + b = b._reserialize(MarshalSerializer()) + self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) + # regression test for SPARK-4841 + path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + t = self.sc.textFile(path) + cnt = t.count() + self.assertEqual(cnt, t.zip(t).count()) + rdd = t.map(str) + self.assertEqual(cnt, t.zip(rdd).count()) + # regression test for bug in _reserializer() + self.assertEqual(cnt, t.zip(rdd).count()) + + def test_zip_with_different_object_sizes(self): + # regress test for SPARK-5973 + a = self.sc.parallelize(xrange(10000)).map(lambda i: '*' * i) + b = self.sc.parallelize(xrange(10000, 20000)).map(lambda i: '*' * i) + self.assertEqual(10000, a.zip(b).count()) + + def test_zip_with_different_number_of_items(self): + a = self.sc.parallelize(range(5), 2) + # different number of partitions + b = self.sc.parallelize(range(100, 106), 3) + self.assertRaises(ValueError, lambda: a.zip(b)) + with QuietTest(self.sc): + # different number of batched items in JVM + b = self.sc.parallelize(range(100, 104), 2) + self.assertRaises(Exception, lambda: a.zip(b).count()) + # different number of items in one pair + b = self.sc.parallelize(range(100, 106), 2) + self.assertRaises(Exception, lambda: a.zip(b).count()) + # same total number of items, but different distributions + a = self.sc.parallelize([2, 3], 2).flatMap(range) + b = self.sc.parallelize([3, 2], 2).flatMap(range) + self.assertEqual(a.count(), b.count()) + self.assertRaises(Exception, lambda: a.zip(b).count()) + + def test_count_approx_distinct(self): + rdd = self.sc.parallelize(xrange(1000)) + self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.03) < 1050) + + rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7) + self.assertTrue(18 < rdd.countApproxDistinct() < 22) + self.assertTrue(18 < rdd.map(float).countApproxDistinct() < 22) + self.assertTrue(18 < rdd.map(str).countApproxDistinct() < 22) + self.assertTrue(18 < rdd.map(lambda x: (x, -x)).countApproxDistinct() < 22) + + self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.00000001)) + + def test_histogram(self): + # empty + rdd = self.sc.parallelize([]) + self.assertEqual([0], rdd.histogram([0, 10])[1]) + self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) + self.assertRaises(ValueError, lambda: rdd.histogram(1)) + + # out of range + rdd = self.sc.parallelize([10.01, -0.01]) + self.assertEqual([0], rdd.histogram([0, 10])[1]) + self.assertEqual([0, 0], rdd.histogram((0, 4, 10))[1]) + + # in range with one bucket + rdd = self.sc.parallelize(range(1, 5)) + self.assertEqual([4], rdd.histogram([0, 10])[1]) + self.assertEqual([3, 1], rdd.histogram([0, 4, 10])[1]) + + # in range with one bucket exact match + self.assertEqual([4], rdd.histogram([1, 4])[1]) + + # out of range with two buckets + rdd = self.sc.parallelize([10.01, -0.01]) + self.assertEqual([0, 0], rdd.histogram([0, 5, 10])[1]) + + # out of range with two uneven buckets + rdd = self.sc.parallelize([10.01, -0.01]) + self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) + + # in range with two buckets + rdd = self.sc.parallelize([1, 2, 3, 5, 6]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) + + # in range with two bucket and None + rdd = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) + + # in range with two uneven buckets + rdd = self.sc.parallelize([1, 2, 3, 5, 6]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 11])[1]) + + # mixed range with two uneven buckets + rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01]) + self.assertEqual([4, 3], rdd.histogram([0, 5, 11])[1]) + + # mixed range with four uneven buckets + rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1]) + self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) + + # mixed range with uneven buckets and NaN + rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, + 199.0, 200.0, 200.1, None, float('nan')]) + self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) + + # out of range with infinite buckets + rdd = self.sc.parallelize([10.01, -0.01, float('nan'), float("inf")]) + self.assertEqual([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1]) + + # invalid buckets + self.assertRaises(ValueError, lambda: rdd.histogram([])) + self.assertRaises(ValueError, lambda: rdd.histogram([1])) + self.assertRaises(ValueError, lambda: rdd.histogram(0)) + self.assertRaises(TypeError, lambda: rdd.histogram({})) + + # without buckets + rdd = self.sc.parallelize(range(1, 5)) + self.assertEqual(([1, 4], [4]), rdd.histogram(1)) + + # without buckets single element + rdd = self.sc.parallelize([1]) + self.assertEqual(([1, 1], [1]), rdd.histogram(1)) + + # without bucket no range + rdd = self.sc.parallelize([1] * 4) + self.assertEqual(([1, 1], [4]), rdd.histogram(1)) + + # without buckets basic two + rdd = self.sc.parallelize(range(1, 5)) + self.assertEqual(([1, 2.5, 4], [2, 2]), rdd.histogram(2)) + + # without buckets with more requested than elements + rdd = self.sc.parallelize([1, 2]) + buckets = [1 + 0.2 * i for i in range(6)] + hist = [1, 0, 0, 0, 1] + self.assertEqual((buckets, hist), rdd.histogram(5)) + + # invalid RDDs + rdd = self.sc.parallelize([1, float('inf')]) + self.assertRaises(ValueError, lambda: rdd.histogram(2)) + rdd = self.sc.parallelize([float('nan')]) + self.assertRaises(ValueError, lambda: rdd.histogram(2)) + + # string + rdd = self.sc.parallelize(["ab", "ac", "b", "bd", "ef"], 2) + self.assertEqual([2, 2], rdd.histogram(["a", "b", "c"])[1]) + self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1)) + self.assertRaises(TypeError, lambda: rdd.histogram(2)) + + def test_repartitionAndSortWithinPartitions_asc(self): + rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) + + repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, True) + partitions = repartitioned.glom().collect() + self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)]) + self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)]) + + def test_repartitionAndSortWithinPartitions_desc(self): + rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) + + repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, False) + partitions = repartitioned.glom().collect() + self.assertEqual(partitions[0], [(2, 6), (0, 5), (0, 8)]) + self.assertEqual(partitions[1], [(3, 8), (3, 8), (1, 3)]) + + def test_repartition_no_skewed(self): + num_partitions = 20 + a = self.sc.parallelize(range(int(1000)), 2) + l = a.repartition(num_partitions).glom().map(len).collect() + zeros = len([x for x in l if x == 0]) + self.assertTrue(zeros == 0) + l = a.coalesce(num_partitions, True).glom().map(len).collect() + zeros = len([x for x in l if x == 0]) + self.assertTrue(zeros == 0) + + def test_repartition_on_textfile(self): + path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + rdd = self.sc.textFile(path) + result = rdd.repartition(1).collect() + self.assertEqual(u"Hello World!", result[0]) + + def test_distinct(self): + rdd = self.sc.parallelize((1, 2, 3)*10, 10) + self.assertEqual(rdd.getNumPartitions(), 10) + self.assertEqual(rdd.distinct().count(), 3) + result = rdd.distinct(5) + self.assertEqual(result.getNumPartitions(), 5) + self.assertEqual(result.count(), 3) + + def test_external_group_by_key(self): + self.sc._conf.set("spark.python.worker.memory", "1m") + N = 200001 + kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x)) + gkv = kv.groupByKey().cache() + self.assertEqual(3, gkv.count()) + filtered = gkv.filter(lambda kv: kv[0] == 1) + self.assertEqual(1, filtered.count()) + self.assertEqual([(1, N // 3)], filtered.mapValues(len).collect()) + self.assertEqual([(N // 3, N // 3)], + filtered.values().map(lambda x: (len(x), len(list(x)))).collect()) + result = filtered.collect()[0][1] + self.assertEqual(N // 3, len(result)) + self.assertTrue(isinstance(result.data, shuffle.ExternalListOfList)) + + def test_sort_on_empty_rdd(self): + self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect()) + + def test_sample(self): + rdd = self.sc.parallelize(range(0, 100), 4) + wo = rdd.sample(False, 0.1, 2).collect() + wo_dup = rdd.sample(False, 0.1, 2).collect() + self.assertSetEqual(set(wo), set(wo_dup)) + wr = rdd.sample(True, 0.2, 5).collect() + wr_dup = rdd.sample(True, 0.2, 5).collect() + self.assertSetEqual(set(wr), set(wr_dup)) + wo_s10 = rdd.sample(False, 0.3, 10).collect() + wo_s20 = rdd.sample(False, 0.3, 20).collect() + self.assertNotEqual(set(wo_s10), set(wo_s20)) + wr_s11 = rdd.sample(True, 0.4, 11).collect() + wr_s21 = rdd.sample(True, 0.4, 21).collect() + self.assertNotEqual(set(wr_s11), set(wr_s21)) + + def test_null_in_rdd(self): + jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc) + rdd = RDD(jrdd, self.sc, UTF8Deserializer()) + self.assertEqual([u"a", None, u"b"], rdd.collect()) + rdd = RDD(jrdd, self.sc, NoOpSerializer()) + self.assertEqual([b"a", None, b"b"], rdd.collect()) + + def test_multiple_python_java_RDD_conversions(self): + # Regression test for SPARK-5361 + data = [ + (u'1', {u'director': u'David Lean'}), + (u'2', {u'director': u'Andrew Dominik'}) + ] + data_rdd = self.sc.parallelize(data) + data_java_rdd = data_rdd._to_java_object_rdd() + data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd) + converted_rdd = RDD(data_python_rdd, self.sc) + self.assertEqual(2, converted_rdd.count()) + + # conversion between python and java RDD threw exceptions + data_java_rdd = converted_rdd._to_java_object_rdd() + data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd) + converted_rdd = RDD(data_python_rdd, self.sc) + self.assertEqual(2, converted_rdd.count()) + + # Regression test for SPARK-6294 + def test_take_on_jrdd(self): + rdd = self.sc.parallelize(xrange(1 << 20)).map(lambda x: str(x)) + rdd._jrdd.first() + + def test_sortByKey_uses_all_partitions_not_only_first_and_last(self): + # Regression test for SPARK-5969 + seq = [(i * 59 % 101, i) for i in range(101)] # unsorted sequence + rdd = self.sc.parallelize(seq) + for ascending in [True, False]: + sort = rdd.sortByKey(ascending=ascending, numPartitions=5) + self.assertEqual(sort.collect(), sorted(seq, reverse=not ascending)) + sizes = sort.glom().map(len).collect() + for size in sizes: + self.assertGreater(size, 0) + + def test_pipe_functions(self): + data = ['1', '2', '3'] + rdd = self.sc.parallelize(data) + with QuietTest(self.sc): + self.assertEqual([], rdd.pipe('cc').collect()) + self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect) + result = rdd.pipe('cat').collect() + result.sort() + for x, y in zip(data, result): + self.assertEqual(x, y) + self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect) + self.assertEqual([], rdd.pipe('grep 4').collect()) + + def test_pipe_unicode(self): + # Regression test for SPARK-20947 + data = [u'\u6d4b\u8bd5', '1'] + rdd = self.sc.parallelize(data) + result = rdd.pipe('cat').collect() + self.assertEqual(data, result) + + def test_stopiteration_in_user_code(self): + + def stopit(*x): + raise StopIteration() + + seq_rdd = self.sc.parallelize(range(10)) + keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) + msg = "Caught StopIteration thrown from user's code; failing the task" + + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, + seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) + + # these methods call the user function both in the driver and in the executor + # the exception raised is different according to where the StopIteration happens + # RuntimeError is raised if in the driver + # Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + keyed_rdd.reduceByKeyLocally, stopit) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + seq_rdd.aggregate, 0, stopit, lambda *x: 1) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + seq_rdd.aggregate, 0, lambda *x: 1, stopit) + + +if __name__ == "__main__": + import unittest + from pyspark.tests.test_rdd import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_readwrite.py b/python/pyspark/tests/test_readwrite.py new file mode 100644 index 0000000000000..e45f5b371f461 --- /dev/null +++ b/python/pyspark/tests/test_readwrite.py @@ -0,0 +1,499 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import shutil +import sys +import tempfile +import unittest +from array import array + +from pyspark.testing.utils import ReusedPySparkTestCase, SPARK_HOME + + +class InputFormatTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name) + + @unittest.skipIf(sys.version >= "3", "serialize array of byte") + def test_sequencefiles(self): + basepath = self.tempdir.name + ints = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfint/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text").collect()) + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.assertEqual(ints, ei) + + doubles = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfdouble/", + "org.apache.hadoop.io.DoubleWritable", + "org.apache.hadoop.io.Text").collect()) + ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] + self.assertEqual(doubles, ed) + + bytes = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbytes/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.BytesWritable").collect()) + ebs = [(1, bytearray('aa', 'utf-8')), + (1, bytearray('aa', 'utf-8')), + (2, bytearray('aa', 'utf-8')), + (2, bytearray('bb', 'utf-8')), + (2, bytearray('bb', 'utf-8')), + (3, bytearray('cc', 'utf-8'))] + self.assertEqual(bytes, ebs) + + text = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sftext/", + "org.apache.hadoop.io.Text", + "org.apache.hadoop.io.Text").collect()) + et = [(u'1', u'aa'), + (u'1', u'aa'), + (u'2', u'aa'), + (u'2', u'bb'), + (u'2', u'bb'), + (u'3', u'cc')] + self.assertEqual(text, et) + + bools = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbool/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.BooleanWritable").collect()) + eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] + self.assertEqual(bools, eb) + + nulls = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfnull/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.BooleanWritable").collect()) + en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] + self.assertEqual(nulls, en) + + maps = self.sc.sequenceFile(basepath + "/sftestdata/sfmap/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable").collect() + em = [(1, {}), + (1, {3.0: u'bb'}), + (2, {1.0: u'aa'}), + (2, {1.0: u'cc'}), + (3, {2.0: u'dd'})] + for v in maps: + self.assertTrue(v in em) + + # arrays get pickled to tuples by default + tuples = sorted(self.sc.sequenceFile( + basepath + "/sftestdata/sfarray/", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable").collect()) + et = [(1, ()), + (2, (3.0, 4.0, 5.0)), + (3, (4.0, 5.0, 6.0))] + self.assertEqual(tuples, et) + + # with custom converters, primitive arrays can stay as arrays + arrays = sorted(self.sc.sequenceFile( + basepath + "/sftestdata/sfarray/", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) + ea = [(1, array('d')), + (2, array('d', [3.0, 4.0, 5.0])), + (3, array('d', [4.0, 5.0, 6.0]))] + self.assertEqual(arrays, ea) + + clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", + "org.apache.hadoop.io.Text", + "org.apache.spark.api.python.TestWritable").collect()) + cname = u'org.apache.spark.api.python.TestWritable' + ec = [(u'1', {u'__class__': cname, u'double': 1.0, u'int': 1, u'str': u'test1'}), + (u'2', {u'__class__': cname, u'double': 2.3, u'int': 2, u'str': u'test2'}), + (u'3', {u'__class__': cname, u'double': 3.1, u'int': 3, u'str': u'test3'}), + (u'4', {u'__class__': cname, u'double': 4.2, u'int': 4, u'str': u'test4'}), + (u'5', {u'__class__': cname, u'double': 5.5, u'int': 5, u'str': u'test56'})] + self.assertEqual(clazz, ec) + + unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", + "org.apache.hadoop.io.Text", + "org.apache.spark.api.python.TestWritable", + ).collect()) + self.assertEqual(unbatched_clazz, ec) + + def test_oldhadoop(self): + basepath = self.tempdir.name + ints = sorted(self.sc.hadoopFile(basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text").collect()) + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.assertEqual(ints, ei) + + hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + oldconf = {"mapreduce.input.fileinputformat.inputdir": hellopath} + hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat", + "org.apache.hadoop.io.LongWritable", + "org.apache.hadoop.io.Text", + conf=oldconf).collect() + result = [(0, u'Hello World!')] + self.assertEqual(hello, result) + + def test_newhadoop(self): + basepath = self.tempdir.name + ints = sorted(self.sc.newAPIHadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text").collect()) + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.assertEqual(ints, ei) + + hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + newconf = {"mapreduce.input.fileinputformat.inputdir": hellopath} + hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat", + "org.apache.hadoop.io.LongWritable", + "org.apache.hadoop.io.Text", + conf=newconf).collect() + result = [(0, u'Hello World!')] + self.assertEqual(hello, result) + + def test_newolderror(self): + basepath = self.tempdir.name + self.assertRaises(Exception, lambda: self.sc.hadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text")) + + self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text")) + + def test_bad_inputs(self): + basepath = self.tempdir.name + self.assertRaises(Exception, lambda: self.sc.sequenceFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.io.NotValidWritable", + "org.apache.hadoop.io.Text")) + self.assertRaises(Exception, lambda: self.sc.hadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapred.NotValidInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text")) + self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapreduce.lib.input.NotValidInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text")) + + def test_converters(self): + # use of custom converters + basepath = self.tempdir.name + maps = sorted(self.sc.sequenceFile( + basepath + "/sftestdata/sfmap/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable", + keyConverter="org.apache.spark.api.python.TestInputKeyConverter", + valueConverter="org.apache.spark.api.python.TestInputValueConverter").collect()) + em = [(u'\x01', []), + (u'\x01', [3.0]), + (u'\x02', [1.0]), + (u'\x02', [1.0]), + (u'\x03', [2.0])] + self.assertEqual(maps, em) + + def test_binary_files(self): + path = os.path.join(self.tempdir.name, "binaryfiles") + os.mkdir(path) + data = b"short binary data" + with open(os.path.join(path, "part-0000"), 'wb') as f: + f.write(data) + [(p, d)] = self.sc.binaryFiles(path).collect() + self.assertTrue(p.endswith("part-0000")) + self.assertEqual(d, data) + + def test_binary_records(self): + path = os.path.join(self.tempdir.name, "binaryrecords") + os.mkdir(path) + with open(os.path.join(path, "part-0000"), 'w') as f: + for i in range(100): + f.write('%04d' % i) + result = self.sc.binaryRecords(path, 4).map(int).collect() + self.assertEqual(list(range(100)), result) + + +class OutputFormatTests(ReusedPySparkTestCase): + + def setUp(self): + self.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(self.tempdir.name) + + def tearDown(self): + shutil.rmtree(self.tempdir.name, ignore_errors=True) + + @unittest.skipIf(sys.version >= "3", "serialize array of byte") + def test_sequencefiles(self): + basepath = self.tempdir.name + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.sc.parallelize(ei).saveAsSequenceFile(basepath + "/sfint/") + ints = sorted(self.sc.sequenceFile(basepath + "/sfint/").collect()) + self.assertEqual(ints, ei) + + ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] + self.sc.parallelize(ed).saveAsSequenceFile(basepath + "/sfdouble/") + doubles = sorted(self.sc.sequenceFile(basepath + "/sfdouble/").collect()) + self.assertEqual(doubles, ed) + + ebs = [(1, bytearray(b'\x00\x07spam\x08')), (2, bytearray(b'\x00\x07spam\x08'))] + self.sc.parallelize(ebs).saveAsSequenceFile(basepath + "/sfbytes/") + bytes = sorted(self.sc.sequenceFile(basepath + "/sfbytes/").collect()) + self.assertEqual(bytes, ebs) + + et = [(u'1', u'aa'), + (u'2', u'bb'), + (u'3', u'cc')] + self.sc.parallelize(et).saveAsSequenceFile(basepath + "/sftext/") + text = sorted(self.sc.sequenceFile(basepath + "/sftext/").collect()) + self.assertEqual(text, et) + + eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] + self.sc.parallelize(eb).saveAsSequenceFile(basepath + "/sfbool/") + bools = sorted(self.sc.sequenceFile(basepath + "/sfbool/").collect()) + self.assertEqual(bools, eb) + + en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] + self.sc.parallelize(en).saveAsSequenceFile(basepath + "/sfnull/") + nulls = sorted(self.sc.sequenceFile(basepath + "/sfnull/").collect()) + self.assertEqual(nulls, en) + + em = [(1, {}), + (1, {3.0: u'bb'}), + (2, {1.0: u'aa'}), + (2, {1.0: u'cc'}), + (3, {2.0: u'dd'})] + self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/") + maps = self.sc.sequenceFile(basepath + "/sfmap/").collect() + for v in maps: + self.assertTrue(v, em) + + def test_oldhadoop(self): + basepath = self.tempdir.name + dict_data = [(1, {}), + (1, {"row1": 1.0}), + (2, {"row2": 2.0})] + self.sc.parallelize(dict_data).saveAsHadoopFile( + basepath + "/oldhadoop/", + "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable") + result = self.sc.hadoopFile( + basepath + "/oldhadoop/", + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable").collect() + for v in result: + self.assertTrue(v, dict_data) + + conf = { + "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.MapWritable", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/olddataset/" + } + self.sc.parallelize(dict_data).saveAsHadoopDataset(conf) + input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/olddataset/"} + result = self.sc.hadoopRDD( + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable", + conf=input_conf).collect() + for v in result: + self.assertTrue(v, dict_data) + + def test_newhadoop(self): + basepath = self.tempdir.name + data = [(1, ""), + (1, "a"), + (2, "bcdf")] + self.sc.parallelize(data).saveAsNewAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text") + result = sorted(self.sc.newAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text").collect()) + self.assertEqual(result, data) + + conf = { + "mapreduce.job.outputformat.class": + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.Text", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/" + } + self.sc.parallelize(data).saveAsNewAPIHadoopDataset(conf) + input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"} + new_dataset = sorted(self.sc.newAPIHadoopRDD( + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text", + conf=input_conf).collect()) + self.assertEqual(new_dataset, data) + + @unittest.skipIf(sys.version >= "3", "serialize of array") + def test_newhadoop_with_array(self): + basepath = self.tempdir.name + # use custom ArrayWritable types and converters to handle arrays + array_data = [(1, array('d')), + (1, array('d', [1.0, 2.0, 3.0])), + (2, array('d', [3.0, 4.0, 5.0]))] + self.sc.parallelize(array_data).saveAsNewAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") + result = sorted(self.sc.newAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) + self.assertEqual(result, array_data) + + conf = { + "mapreduce.job.outputformat.class": + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.spark.api.python.DoubleArrayWritable", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/" + } + self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset( + conf, + valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") + input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"} + new_dataset = sorted(self.sc.newAPIHadoopRDD( + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter", + conf=input_conf).collect()) + self.assertEqual(new_dataset, array_data) + + def test_newolderror(self): + basepath = self.tempdir.name + rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) + self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( + basepath + "/newolderror/saveAsHadoopFile/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat")) + self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( + basepath + "/newolderror/saveAsNewAPIHadoopFile/", + "org.apache.hadoop.mapred.SequenceFileOutputFormat")) + + def test_bad_inputs(self): + basepath = self.tempdir.name + rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) + self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( + basepath + "/badinputs/saveAsHadoopFile/", + "org.apache.hadoop.mapred.NotValidOutputFormat")) + self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( + basepath + "/badinputs/saveAsNewAPIHadoopFile/", + "org.apache.hadoop.mapreduce.lib.output.NotValidOutputFormat")) + + def test_converters(self): + # use of custom converters + basepath = self.tempdir.name + data = [(1, {3.0: u'bb'}), + (2, {1.0: u'aa'}), + (3, {2.0: u'dd'})] + self.sc.parallelize(data).saveAsNewAPIHadoopFile( + basepath + "/converters/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + keyConverter="org.apache.spark.api.python.TestOutputKeyConverter", + valueConverter="org.apache.spark.api.python.TestOutputValueConverter") + converted = sorted(self.sc.sequenceFile(basepath + "/converters/").collect()) + expected = [(u'1', 3.0), + (u'2', 1.0), + (u'3', 2.0)] + self.assertEqual(converted, expected) + + def test_reserialization(self): + basepath = self.tempdir.name + x = range(1, 5) + y = range(1001, 1005) + data = list(zip(x, y)) + rdd = self.sc.parallelize(x).zip(self.sc.parallelize(y)) + rdd.saveAsSequenceFile(basepath + "/reserialize/sequence") + result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect()) + self.assertEqual(result1, data) + + rdd.saveAsHadoopFile( + basepath + "/reserialize/hadoop", + "org.apache.hadoop.mapred.SequenceFileOutputFormat") + result2 = sorted(self.sc.sequenceFile(basepath + "/reserialize/hadoop").collect()) + self.assertEqual(result2, data) + + rdd.saveAsNewAPIHadoopFile( + basepath + "/reserialize/newhadoop", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") + result3 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newhadoop").collect()) + self.assertEqual(result3, data) + + conf4 = { + "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/dataset"} + rdd.saveAsHadoopDataset(conf4) + result4 = sorted(self.sc.sequenceFile(basepath + "/reserialize/dataset").collect()) + self.assertEqual(result4, data) + + conf5 = {"mapreduce.job.outputformat.class": + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/newdataset" + } + rdd.saveAsNewAPIHadoopDataset(conf5) + result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) + self.assertEqual(result5, data) + + def test_malformed_RDD(self): + basepath = self.tempdir.name + # non-batch-serialized RDD[[(K, V)]] should be rejected + data = [[(1, "a")], [(2, "aa")], [(3, "aaa")]] + rdd = self.sc.parallelize(data, len(data)) + self.assertRaises(Exception, lambda: rdd.saveAsSequenceFile( + basepath + "/malformed/sequence")) + + +if __name__ == "__main__": + from pyspark.tests.test_readwrite import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_serializers.py b/python/pyspark/tests/test_serializers.py new file mode 100644 index 0000000000000..bce94062c8af7 --- /dev/null +++ b/python/pyspark/tests/test_serializers.py @@ -0,0 +1,237 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import math +import sys +import unittest + +from pyspark import serializers +from pyspark.serializers import * +from pyspark.serializers import CloudPickleSerializer, CompressedSerializer, \ + AutoBatchedSerializer, BatchedSerializer, AutoSerializer, NoOpSerializer, PairDeserializer, \ + FlattenedValuesSerializer, CartesianDeserializer +from pyspark.testing.utils import PySparkTestCase, read_int, write_int, ByteArrayOutput, \ + have_numpy, have_scipy + + +class SerializationTestCase(unittest.TestCase): + + def test_namedtuple(self): + from collections import namedtuple + from pickle import dumps, loads + P = namedtuple("P", "x y") + p1 = P(1, 3) + p2 = loads(dumps(p1, 2)) + self.assertEqual(p1, p2) + + from pyspark.cloudpickle import dumps + P2 = loads(dumps(P)) + p3 = P2(1, 3) + self.assertEqual(p1, p3) + + def test_itemgetter(self): + from operator import itemgetter + ser = CloudPickleSerializer() + d = range(10) + getter = itemgetter(1) + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + getter = itemgetter(0, 3) + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + def test_function_module_name(self): + ser = CloudPickleSerializer() + func = lambda x: x + func2 = ser.loads(ser.dumps(func)) + self.assertEqual(func.__module__, func2.__module__) + + def test_attrgetter(self): + from operator import attrgetter + ser = CloudPickleSerializer() + + class C(object): + def __getattr__(self, item): + return item + d = C() + getter = attrgetter("a") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + getter = attrgetter("a", "b") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + d.e = C() + getter = attrgetter("e.a") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + getter = attrgetter("e.a", "e.b") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + # Regression test for SPARK-3415 + def test_pickling_file_handles(self): + # to be corrected with SPARK-11160 + try: + import xmlrunner + except ImportError: + ser = CloudPickleSerializer() + out1 = sys.stderr + out2 = ser.loads(ser.dumps(out1)) + self.assertEqual(out1, out2) + + def test_func_globals(self): + + class Unpicklable(object): + def __reduce__(self): + raise Exception("not picklable") + + global exit + exit = Unpicklable() + + ser = CloudPickleSerializer() + self.assertRaises(Exception, lambda: ser.dumps(exit)) + + def foo(): + sys.exit(0) + + self.assertTrue("exit" in foo.__code__.co_names) + ser.dumps(foo) + + def test_compressed_serializer(self): + ser = CompressedSerializer(PickleSerializer()) + try: + from StringIO import StringIO + except ImportError: + from io import BytesIO as StringIO + io = StringIO() + ser.dump_stream(["abc", u"123", range(5)], io) + io.seek(0) + self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io))) + ser.dump_stream(range(1000), io) + io.seek(0) + self.assertEqual(["abc", u"123", range(5)] + list(range(1000)), list(ser.load_stream(io))) + io.close() + + def test_hash_serializer(self): + hash(NoOpSerializer()) + hash(UTF8Deserializer()) + hash(PickleSerializer()) + hash(MarshalSerializer()) + hash(AutoSerializer()) + hash(BatchedSerializer(PickleSerializer())) + hash(AutoBatchedSerializer(MarshalSerializer())) + hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer())) + hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer())) + hash(CompressedSerializer(PickleSerializer())) + hash(FlattenedValuesSerializer(PickleSerializer())) + + +@unittest.skipIf(not have_scipy, "SciPy not installed") +class SciPyTests(PySparkTestCase): + + """General PySpark tests that depend on scipy """ + + def test_serialize(self): + from scipy.special import gammaln + + x = range(1, 5) + expected = list(map(gammaln, x)) + observed = self.sc.parallelize(x).map(gammaln).collect() + self.assertEqual(expected, observed) + + +@unittest.skipIf(not have_numpy, "NumPy not installed") +class NumPyTests(PySparkTestCase): + + """General PySpark tests that depend on numpy """ + + def test_statcounter_array(self): + import numpy as np + + x = self.sc.parallelize([np.array([1.0, 1.0]), np.array([2.0, 2.0]), np.array([3.0, 3.0])]) + s = x.stats() + self.assertSequenceEqual([2.0, 2.0], s.mean().tolist()) + self.assertSequenceEqual([1.0, 1.0], s.min().tolist()) + self.assertSequenceEqual([3.0, 3.0], s.max().tolist()) + self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist()) + + stats_dict = s.asDict() + self.assertEqual(3, stats_dict['count']) + self.assertSequenceEqual([2.0, 2.0], stats_dict['mean'].tolist()) + self.assertSequenceEqual([1.0, 1.0], stats_dict['min'].tolist()) + self.assertSequenceEqual([3.0, 3.0], stats_dict['max'].tolist()) + self.assertSequenceEqual([6.0, 6.0], stats_dict['sum'].tolist()) + self.assertSequenceEqual([1.0, 1.0], stats_dict['stdev'].tolist()) + self.assertSequenceEqual([1.0, 1.0], stats_dict['variance'].tolist()) + + stats_sample_dict = s.asDict(sample=True) + self.assertEqual(3, stats_dict['count']) + self.assertSequenceEqual([2.0, 2.0], stats_sample_dict['mean'].tolist()) + self.assertSequenceEqual([1.0, 1.0], stats_sample_dict['min'].tolist()) + self.assertSequenceEqual([3.0, 3.0], stats_sample_dict['max'].tolist()) + self.assertSequenceEqual([6.0, 6.0], stats_sample_dict['sum'].tolist()) + self.assertSequenceEqual( + [0.816496580927726, 0.816496580927726], stats_sample_dict['stdev'].tolist()) + self.assertSequenceEqual( + [0.6666666666666666, 0.6666666666666666], stats_sample_dict['variance'].tolist()) + + +class SerializersTest(unittest.TestCase): + + def test_chunked_stream(self): + original_bytes = bytearray(range(100)) + for data_length in [1, 10, 100]: + for buffer_length in [1, 2, 3, 5, 20, 99, 100, 101, 500]: + dest = ByteArrayOutput() + stream_out = serializers.ChunkedStream(dest, buffer_length) + stream_out.write(original_bytes[:data_length]) + stream_out.close() + num_chunks = int(math.ceil(float(data_length) / buffer_length)) + # length for each chunk, and a final -1 at the very end + exp_size = (num_chunks + 1) * 4 + data_length + self.assertEqual(len(dest.buffer), exp_size) + dest_pos = 0 + data_pos = 0 + for chunk_idx in range(num_chunks): + chunk_length = read_int(dest.buffer[dest_pos:(dest_pos + 4)]) + if chunk_idx == num_chunks - 1: + exp_length = data_length % buffer_length + if exp_length == 0: + exp_length = buffer_length + else: + exp_length = buffer_length + self.assertEqual(chunk_length, exp_length) + dest_pos += 4 + dest_chunk = dest.buffer[dest_pos:dest_pos + chunk_length] + orig_chunk = original_bytes[data_pos:data_pos + chunk_length] + self.assertEqual(dest_chunk, orig_chunk) + dest_pos += chunk_length + data_pos += chunk_length + # ends with a -1 + self.assertEqual(dest.buffer[-4:], write_int(-1)) + + +if __name__ == "__main__": + from pyspark.tests.test_serializers import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_shuffle.py b/python/pyspark/tests/test_shuffle.py new file mode 100644 index 0000000000000..0489426061b75 --- /dev/null +++ b/python/pyspark/tests/test_shuffle.py @@ -0,0 +1,181 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import random +import sys +import unittest + +from py4j.protocol import Py4JJavaError + +from pyspark import shuffle, PickleSerializer, SparkConf, SparkContext +from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter + +if sys.version_info[0] >= 3: + xrange = range + + +class MergerTests(unittest.TestCase): + + def setUp(self): + self.N = 1 << 12 + self.l = [i for i in xrange(self.N)] + self.data = list(zip(self.l, self.l)) + self.agg = Aggregator(lambda x: [x], + lambda x, y: x.append(y) or x, + lambda x, y: x.extend(y) or x) + + def test_small_dataset(self): + m = ExternalMerger(self.agg, 1000) + m.mergeValues(self.data) + self.assertEqual(m.spills, 0) + self.assertEqual(sum(sum(v) for k, v in m.items()), + sum(xrange(self.N))) + + m = ExternalMerger(self.agg, 1000) + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), self.data)) + self.assertEqual(m.spills, 0) + self.assertEqual(sum(sum(v) for k, v in m.items()), + sum(xrange(self.N))) + + def test_medium_dataset(self): + m = ExternalMerger(self.agg, 20) + m.mergeValues(self.data) + self.assertTrue(m.spills >= 1) + self.assertEqual(sum(sum(v) for k, v in m.items()), + sum(xrange(self.N))) + + m = ExternalMerger(self.agg, 10) + m.mergeCombiners(map(lambda x_y2: (x_y2[0], [x_y2[1]]), self.data * 3)) + self.assertTrue(m.spills >= 1) + self.assertEqual(sum(sum(v) for k, v in m.items()), + sum(xrange(self.N)) * 3) + + def test_huge_dataset(self): + m = ExternalMerger(self.agg, 5, partitions=3) + m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10)) + self.assertTrue(m.spills >= 1) + self.assertEqual(sum(len(v) for k, v in m.items()), + self.N * 10) + m._cleanup() + + def test_group_by_key(self): + + def gen_data(N, step): + for i in range(1, N + 1, step): + for j in range(i): + yield (i, [j]) + + def gen_gs(N, step=1): + return shuffle.GroupByKey(gen_data(N, step)) + + self.assertEqual(1, len(list(gen_gs(1)))) + self.assertEqual(2, len(list(gen_gs(2)))) + self.assertEqual(100, len(list(gen_gs(100)))) + self.assertEqual(list(range(1, 101)), [k for k, _ in gen_gs(100)]) + self.assertTrue(all(list(range(k)) == list(vs) for k, vs in gen_gs(100))) + + for k, vs in gen_gs(50002, 10000): + self.assertEqual(k, len(vs)) + self.assertEqual(list(range(k)), list(vs)) + + ser = PickleSerializer() + l = ser.loads(ser.dumps(list(gen_gs(50002, 30000)))) + for k, vs in l: + self.assertEqual(k, len(vs)) + self.assertEqual(list(range(k)), list(vs)) + + def test_stopiteration_is_raised(self): + + def stopit(*args, **kwargs): + raise StopIteration() + + def legit_create_combiner(x): + return [x] + + def legit_merge_value(x, y): + return x.append(y) or x + + def legit_merge_combiners(x, y): + return x.extend(y) or x + + data = [(x % 2, x) for x in range(100)] + + # wrong create combiner + m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge value + m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge combiners + m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) + + +class SorterTests(unittest.TestCase): + def test_in_memory_sort(self): + l = list(range(1024)) + random.shuffle(l) + sorter = ExternalSorter(1024) + self.assertEqual(sorted(l), list(sorter.sorted(l))) + self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) + self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) + self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), + list(sorter.sorted(l, key=lambda x: -x, reverse=True))) + + def test_external_sort(self): + class CustomizedSorter(ExternalSorter): + def _next_limit(self): + return self.memory_limit + l = list(range(1024)) + random.shuffle(l) + sorter = CustomizedSorter(1) + self.assertEqual(sorted(l), list(sorter.sorted(l))) + self.assertGreater(shuffle.DiskBytesSpilled, 0) + last = shuffle.DiskBytesSpilled + self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) + self.assertGreater(shuffle.DiskBytesSpilled, last) + last = shuffle.DiskBytesSpilled + self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) + self.assertGreater(shuffle.DiskBytesSpilled, last) + last = shuffle.DiskBytesSpilled + self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), + list(sorter.sorted(l, key=lambda x: -x, reverse=True))) + self.assertGreater(shuffle.DiskBytesSpilled, last) + + def test_external_sort_in_rdd(self): + conf = SparkConf().set("spark.python.worker.memory", "1m") + sc = SparkContext(conf=conf) + l = list(range(10240)) + random.shuffle(l) + rdd = sc.parallelize(l, 4) + self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) + sc.stop() + + +if __name__ == "__main__": + from pyspark.tests.test_shuffle import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py new file mode 100644 index 0000000000000..b3a967440a9b2 --- /dev/null +++ b/python/pyspark/tests/test_taskcontext.py @@ -0,0 +1,161 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import random +import sys +import time + +from pyspark import SparkContext, TaskContext, BarrierTaskContext +from pyspark.testing.utils import PySparkTestCase + + +class TaskContextTests(PySparkTestCase): + + def setUp(self): + self._old_sys_path = list(sys.path) + class_name = self.__class__.__name__ + # Allow retries even though they are normally disabled in local mode + self.sc = SparkContext('local[4, 2]', class_name) + + def test_stage_id(self): + """Test the stage ids are available and incrementing as expected.""" + rdd = self.sc.parallelize(range(10)) + stage1 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0] + stage2 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0] + # Test using the constructor directly rather than the get() + stage3 = rdd.map(lambda x: TaskContext().stageId()).take(1)[0] + self.assertEqual(stage1 + 1, stage2) + self.assertEqual(stage1 + 2, stage3) + self.assertEqual(stage2 + 1, stage3) + + def test_partition_id(self): + """Test the partition id.""" + rdd1 = self.sc.parallelize(range(10), 1) + rdd2 = self.sc.parallelize(range(10), 2) + pids1 = rdd1.map(lambda x: TaskContext.get().partitionId()).collect() + pids2 = rdd2.map(lambda x: TaskContext.get().partitionId()).collect() + self.assertEqual(0, pids1[0]) + self.assertEqual(0, pids1[9]) + self.assertEqual(0, pids2[0]) + self.assertEqual(1, pids2[9]) + + def test_attempt_number(self): + """Verify the attempt numbers are correctly reported.""" + rdd = self.sc.parallelize(range(10)) + # Verify a simple job with no failures + attempt_numbers = rdd.map(lambda x: TaskContext.get().attemptNumber()).collect() + map(lambda attempt: self.assertEqual(0, attempt), attempt_numbers) + + def fail_on_first(x): + """Fail on the first attempt so we get a positive attempt number""" + tc = TaskContext.get() + attempt_number = tc.attemptNumber() + partition_id = tc.partitionId() + attempt_id = tc.taskAttemptId() + if attempt_number == 0 and partition_id == 0: + raise Exception("Failing on first attempt") + else: + return [x, partition_id, attempt_number, attempt_id] + result = rdd.map(fail_on_first).collect() + # We should re-submit the first partition to it but other partitions should be attempt 0 + self.assertEqual([0, 0, 1], result[0][0:3]) + self.assertEqual([9, 3, 0], result[9][0:3]) + first_partition = filter(lambda x: x[1] == 0, result) + map(lambda x: self.assertEqual(1, x[2]), first_partition) + other_partitions = filter(lambda x: x[1] != 0, result) + map(lambda x: self.assertEqual(0, x[2]), other_partitions) + # The task attempt id should be different + self.assertTrue(result[0][3] != result[9][3]) + + def test_tc_on_driver(self): + """Verify that getting the TaskContext on the driver returns None.""" + tc = TaskContext.get() + self.assertTrue(tc is None) + + def test_get_local_property(self): + """Verify that local properties set on the driver are available in TaskContext.""" + key = "testkey" + value = "testvalue" + self.sc.setLocalProperty(key, value) + try: + rdd = self.sc.parallelize(range(1), 1) + prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0] + self.assertEqual(prop1, value) + prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0] + self.assertTrue(prop2 is None) + finally: + self.sc.setLocalProperty(key, None) + + def test_barrier(self): + """ + Verify that BarrierTaskContext.barrier() performs global sync among all barrier tasks + within a stage. + """ + rdd = self.sc.parallelize(range(10), 4) + + def f(iterator): + yield sum(iterator) + + def context_barrier(x): + tc = BarrierTaskContext.get() + time.sleep(random.randint(1, 10)) + tc.barrier() + return time.time() + + times = rdd.barrier().mapPartitions(f).map(context_barrier).collect() + self.assertTrue(max(times) - min(times) < 1) + + def test_barrier_with_python_worker_reuse(self): + """ + Verify that BarrierTaskContext.barrier() with reused python worker. + """ + self.sc._conf.set("spark.python.work.reuse", "true") + rdd = self.sc.parallelize(range(4), 4) + # start a normal job first to start all worker + result = rdd.map(lambda x: x ** 2).collect() + self.assertEqual([0, 1, 4, 9], result) + # make sure `spark.python.work.reuse=true` + self.assertEqual(self.sc._conf.get("spark.python.work.reuse"), "true") + + # worker will be reused in this barrier job + self.test_barrier() + + def test_barrier_infos(self): + """ + Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the + barrier stage. + """ + rdd = self.sc.parallelize(range(10), 4) + + def f(iterator): + yield sum(iterator) + + taskInfos = rdd.barrier().mapPartitions(f).map(lambda x: BarrierTaskContext.get() + .getTaskInfos()).collect() + self.assertTrue(len(taskInfos) == 4) + self.assertTrue(len(taskInfos[0]) == 4) + + +if __name__ == "__main__": + import unittest + from pyspark.tests.test_taskcontext import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_util.py b/python/pyspark/tests/test_util.py new file mode 100644 index 0000000000000..11cda8fd2f5cd --- /dev/null +++ b/python/pyspark/tests/test_util.py @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import unittest + +from py4j.protocol import Py4JJavaError + +from pyspark import keyword_only +from pyspark.testing.utils import PySparkTestCase + + +class KeywordOnlyTests(unittest.TestCase): + class Wrapped(object): + @keyword_only + def set(self, x=None, y=None): + if "x" in self._input_kwargs: + self._x = self._input_kwargs["x"] + if "y" in self._input_kwargs: + self._y = self._input_kwargs["y"] + return x, y + + def test_keywords(self): + w = self.Wrapped() + x, y = w.set(y=1) + self.assertEqual(y, 1) + self.assertEqual(y, w._y) + self.assertIsNone(x) + self.assertFalse(hasattr(w, "_x")) + + def test_non_keywords(self): + w = self.Wrapped() + self.assertRaises(TypeError, lambda: w.set(0, y=1)) + + def test_kwarg_ownership(self): + # test _input_kwargs is owned by each class instance and not a shared static variable + class Setter(object): + @keyword_only + def set(self, x=None, other=None, other_x=None): + if "other" in self._input_kwargs: + self._input_kwargs["other"].set(x=self._input_kwargs["other_x"]) + self._x = self._input_kwargs["x"] + + a = Setter() + b = Setter() + a.set(x=1, other=b, other_x=2) + self.assertEqual(a._x, 1) + self.assertEqual(b._x, 2) + + +class UtilTests(PySparkTestCase): + def test_py4j_exception_message(self): + from pyspark.util import _exception_message + + with self.assertRaises(Py4JJavaError) as context: + # This attempts java.lang.String(null) which throws an NPE. + self.sc._jvm.java.lang.String(None) + + self.assertTrue('NullPointerException' in _exception_message(context.exception)) + + def test_parsing_version_string(self): + from pyspark.util import VersionUtils + self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced")) + + +if __name__ == "__main__": + from pyspark.tests.test_util import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py new file mode 100644 index 0000000000000..a33b77d983419 --- /dev/null +++ b/python/pyspark/tests/test_worker.py @@ -0,0 +1,157 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import sys +import tempfile +import threading +import time + +from py4j.protocol import Py4JJavaError + +from pyspark.testing.utils import ReusedPySparkTestCase, QuietTest + +if sys.version_info[0] >= 3: + xrange = range + + +class WorkerTests(ReusedPySparkTestCase): + def test_cancel_task(self): + temp = tempfile.NamedTemporaryFile(delete=True) + temp.close() + path = temp.name + + def sleep(x): + import os + import time + with open(path, 'w') as f: + f.write("%d %d" % (os.getppid(), os.getpid())) + time.sleep(100) + + # start job in background thread + def run(): + try: + self.sc.parallelize(range(1), 1).foreach(sleep) + except Exception: + pass + import threading + t = threading.Thread(target=run) + t.daemon = True + t.start() + + daemon_pid, worker_pid = 0, 0 + while True: + if os.path.exists(path): + with open(path) as f: + data = f.read().split(' ') + daemon_pid, worker_pid = map(int, data) + break + time.sleep(0.1) + + # cancel jobs + self.sc.cancelAllJobs() + t.join() + + for i in range(50): + try: + os.kill(worker_pid, 0) + time.sleep(0.1) + except OSError: + break # worker was killed + else: + self.fail("worker has not been killed after 5 seconds") + + try: + os.kill(daemon_pid, 0) + except OSError: + self.fail("daemon had been killed") + + # run a normal job + rdd = self.sc.parallelize(xrange(100), 1) + self.assertEqual(100, rdd.map(str).count()) + + def test_after_exception(self): + def raise_exception(_): + raise Exception() + rdd = self.sc.parallelize(xrange(100), 1) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) + self.assertEqual(100, rdd.map(str).count()) + + def test_after_jvm_exception(self): + tempFile = tempfile.NamedTemporaryFile(delete=False) + tempFile.write(b"Hello World!") + tempFile.close() + data = self.sc.textFile(tempFile.name, 1) + filtered_data = data.filter(lambda x: True) + self.assertEqual(1, filtered_data.count()) + os.unlink(tempFile.name) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: filtered_data.count()) + + rdd = self.sc.parallelize(xrange(100), 1) + self.assertEqual(100, rdd.map(str).count()) + + def test_accumulator_when_reuse_worker(self): + from pyspark.accumulators import INT_ACCUMULATOR_PARAM + acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) + self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x)) + self.assertEqual(sum(range(100)), acc1.value) + + acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) + self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x)) + self.assertEqual(sum(range(100)), acc2.value) + self.assertEqual(sum(range(100)), acc1.value) + + def test_reuse_worker_after_take(self): + rdd = self.sc.parallelize(xrange(100000), 1) + self.assertEqual(0, rdd.first()) + + def count(): + try: + rdd.count() + except Exception: + pass + + t = threading.Thread(target=count) + t.daemon = True + t.start() + t.join(5) + self.assertTrue(not t.isAlive()) + self.assertEqual(100000, rdd.count()) + + def test_with_different_versions_of_python(self): + rdd = self.sc.parallelize(range(10)) + rdd.count() + version = self.sc.pythonVer + self.sc.pythonVer = "2.0" + try: + with QuietTest(self.sc): + self.assertRaises(Py4JJavaError, lambda: rdd.count()) + finally: + self.sc.pythonVer = version + + +if __name__ == "__main__": + import unittest + from pyspark.tests.test_worker import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) From d4130ec1f3461dcc961eee9802005ba7a15212d1 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 15 Nov 2018 17:20:49 +0800 Subject: [PATCH 2062/2461] [SPARK-26014][R] Deprecate R prior to version 3.4 in SparkR ## What changes were proposed in this pull request? This PR proposes to bump up the minimum versions of R from 3.1 to 3.4. R version. 3.1.x is too old. It's released 4.5 years ago. R 3.4.0 is released 1.5 years ago. Considering the timing for Spark 3.0, deprecating lower versions, bumping up R to 3.4 might be reasonable option. It should be good to deprecate and drop < R 3.4 support. ## How was this patch tested? Jenkins tests. Closes #23012 from HyukjinKwon/SPARK-26014. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- R/WINDOWS.md | 2 +- R/pkg/DESCRIPTION | 2 +- R/pkg/inst/profile/general.R | 4 ++++ R/pkg/inst/profile/shell.R | 4 ++++ docs/index.md | 3 ++- 5 files changed, 12 insertions(+), 3 deletions(-) diff --git a/R/WINDOWS.md b/R/WINDOWS.md index da668a69b8679..33a4c850cfdac 100644 --- a/R/WINDOWS.md +++ b/R/WINDOWS.md @@ -3,7 +3,7 @@ To build SparkR on Windows, the following steps are required 1. Install R (>= 3.1) and [Rtools](http://cran.r-project.org/bin/windows/Rtools/). Make sure to -include Rtools and R in `PATH`. +include Rtools and R in `PATH`. Note that support for R prior to version 3.4 is deprecated as of Spark 3.0.0. 2. Install [JDK8](http://www.oracle.com/technetwork/java/javase/downloads/jdk8-downloads-2133151.html) and set diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index cdaaa6104e6a9..736da46eaa8d3 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -15,7 +15,7 @@ URL: http://www.apache.org/ http://spark.apache.org/ BugReports: http://spark.apache.org/contributing.html SystemRequirements: Java (== 8) Depends: - R (>= 3.0), + R (>= 3.1), methods Suggests: knitr, diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R index 8c75c19ca7ac3..3efb460846fc2 100644 --- a/R/pkg/inst/profile/general.R +++ b/R/pkg/inst/profile/general.R @@ -16,6 +16,10 @@ # .First <- function() { + if (utils::compareVersion(paste0(R.version$major, ".", R.version$minor), "3.4.0") == -1) { + warning("Support for R prior to version 3.4 is deprecated since Spark 3.0.0") + } + packageDir <- Sys.getenv("SPARKR_PACKAGE_DIR") dirs <- strsplit(packageDir, ",")[[1]] .libPaths(c(dirs, .libPaths())) diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R index 8a8111a8c5419..32eb3671b5941 100644 --- a/R/pkg/inst/profile/shell.R +++ b/R/pkg/inst/profile/shell.R @@ -16,6 +16,10 @@ # .First <- function() { + if (utils::compareVersion(paste0(R.version$major, ".", R.version$minor), "3.4.0") == -1) { + warning("Support for R prior to version 3.4 is deprecated since Spark 3.0.0") + } + home <- Sys.getenv("SPARK_HOME") .libPaths(c(file.path(home, "R", "lib"), .libPaths())) Sys.setenv(NOAWT = 1) diff --git a/docs/index.md b/docs/index.md index ac38f1d4c53c2..bd287e3f8d83f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -31,7 +31,8 @@ Spark runs on both Windows and UNIX-like systems (e.g. Linux, Mac OS). It's easy locally on one machine --- all you need is to have `java` installed on your system `PATH`, or the `JAVA_HOME` environment variable pointing to a Java installation. -Spark runs on Java 8+, Python 2.7+/3.4+ and R 3.1+. For the Scala API, Spark {{site.SPARK_VERSION}} +Spark runs on Java 8+, Python 2.7+/3.4+ and R 3.1+. R prior to version 3.4 support is deprecated as of Spark 3.0.0. +For the Scala API, Spark {{site.SPARK_VERSION}} uses Scala {{site.SCALA_BINARY_VERSION}}. You will need to use a compatible Scala version ({{site.SCALA_BINARY_VERSION}}.x). From 44d4ef60b8015fd8701a685cfb7c96c5ea57d3b1 Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Thu, 15 Nov 2018 18:25:18 +0800 Subject: [PATCH 2063/2461] [SPARK-25974][SQL] Optimizes Generates bytecode for ordering based on the given order ## What changes were proposed in this pull request? Currently, when generates the code for ordering based on the given order, too many variables and assignment statements will be generated, which is not necessary. This PR will eliminate redundant variables. Optimizes Generates bytecode for ordering based on the given order. The generated code looks like: ``` spark.range(1).selectExpr( "id as key", "(id & 1023) as value1", "cast(id & 1023 as double) as value2", "cast(id & 1023 as int) as value3" ).select("value1", "value2", "value3").orderBy("value1", "value2").collect() ``` before PR(codegen size: 178) ``` Generated Ordering by input[0, bigint, false] ASC NULLS FIRST,input[1, double, false] ASC NULLS FIRST: /* 001 */ public SpecificOrdering generate(Object[] references) { /* 002 */ return new SpecificOrdering(references); /* 003 */ } /* 004 */ /* 005 */ class SpecificOrdering extends org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering { /* 006 */ /* 007 */ private Object[] references; /* 008 */ /* 009 */ /* 010 */ public SpecificOrdering(Object[] references) { /* 011 */ this.references = references; /* 012 */ /* 013 */ } /* 014 */ /* 015 */ public int compare(InternalRow a, InternalRow b) { /* 016 */ /* 017 */ InternalRow i = null; /* 018 */ /* 019 */ i = a; /* 020 */ boolean isNullA_0; /* 021 */ long primitiveA_0; /* 022 */ { /* 023 */ long value_0 = i.getLong(0); /* 024 */ isNullA_0 = false; /* 025 */ primitiveA_0 = value_0; /* 026 */ } /* 027 */ i = b; /* 028 */ boolean isNullB_0; /* 029 */ long primitiveB_0; /* 030 */ { /* 031 */ long value_0 = i.getLong(0); /* 032 */ isNullB_0 = false; /* 033 */ primitiveB_0 = value_0; /* 034 */ } /* 035 */ if (isNullA_0 && isNullB_0) { /* 036 */ // Nothing /* 037 */ } else if (isNullA_0) { /* 038 */ return -1; /* 039 */ } else if (isNullB_0) { /* 040 */ return 1; /* 041 */ } else { /* 042 */ int comp = (primitiveA_0 > primitiveB_0 ? 1 : primitiveA_0 < primitiveB_0 ? -1 : 0); /* 043 */ if (comp != 0) { /* 044 */ return comp; /* 045 */ } /* 046 */ } /* 047 */ /* 048 */ i = a; /* 049 */ boolean isNullA_1; /* 050 */ double primitiveA_1; /* 051 */ { /* 052 */ double value_1 = i.getDouble(1); /* 053 */ isNullA_1 = false; /* 054 */ primitiveA_1 = value_1; /* 055 */ } /* 056 */ i = b; /* 057 */ boolean isNullB_1; /* 058 */ double primitiveB_1; /* 059 */ { /* 060 */ double value_1 = i.getDouble(1); /* 061 */ isNullB_1 = false; /* 062 */ primitiveB_1 = value_1; /* 063 */ } /* 064 */ if (isNullA_1 && isNullB_1) { /* 065 */ // Nothing /* 066 */ } else if (isNullA_1) { /* 067 */ return -1; /* 068 */ } else if (isNullB_1) { /* 069 */ return 1; /* 070 */ } else { /* 071 */ int comp = org.apache.spark.util.Utils.nanSafeCompareDoubles(primitiveA_1, primitiveB_1); /* 072 */ if (comp != 0) { /* 073 */ return comp; /* 074 */ } /* 075 */ } /* 076 */ /* 077 */ /* 078 */ return 0; /* 079 */ } /* 080 */ /* 081 */ /* 082 */ } ``` After PR(codegen size: 89) ``` Generated Ordering by input[0, bigint, false] ASC NULLS FIRST,input[1, double, false] ASC NULLS FIRST: /* 001 */ public SpecificOrdering generate(Object[] references) { /* 002 */ return new SpecificOrdering(references); /* 003 */ } /* 004 */ /* 005 */ class SpecificOrdering extends org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering { /* 006 */ /* 007 */ private Object[] references; /* 008 */ /* 009 */ /* 010 */ public SpecificOrdering(Object[] references) { /* 011 */ this.references = references; /* 012 */ /* 013 */ } /* 014 */ /* 015 */ public int compare(InternalRow a, InternalRow b) { /* 016 */ /* 017 */ /* 018 */ long value_0 = a.getLong(0); /* 019 */ long value_2 = b.getLong(0); /* 020 */ if (false && false) { /* 021 */ // Nothing /* 022 */ } else if (false) { /* 023 */ return -1; /* 024 */ } else if (false) { /* 025 */ return 1; /* 026 */ } else { /* 027 */ int comp = (value_0 > value_2 ? 1 : value_0 < value_2 ? -1 : 0); /* 028 */ if (comp != 0) { /* 029 */ return comp; /* 030 */ } /* 031 */ } /* 032 */ /* 033 */ double value_1 = a.getDouble(1); /* 034 */ double value_3 = b.getDouble(1); /* 035 */ if (false && false) { /* 036 */ // Nothing /* 037 */ } else if (false) { /* 038 */ return -1; /* 039 */ } else if (false) { /* 040 */ return 1; /* 041 */ } else { /* 042 */ int comp = org.apache.spark.util.Utils.nanSafeCompareDoubles(value_1, value_3); /* 043 */ if (comp != 0) { /* 044 */ return comp; /* 045 */ } /* 046 */ } /* 047 */ /* 048 */ /* 049 */ return 0; /* 050 */ } /* 051 */ /* 052 */ /* 053 */ } ``` ## How was this patch tested? the existed test cases. Closes #22976 from heary-cao/GenArrayData. Authored-by: caoxuewen Signed-off-by: Wenchen Fan --- .../codegen/GenerateOrdering.scala | 113 ++++++++---------- 1 file changed, 51 insertions(+), 62 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 9a51be6ed5aeb..c3b95b6c67fdd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -68,62 +68,55 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR genComparisons(ctx, ordering) } + /** + * Creates the variables for ordering based on the given order. + */ + private def createOrderKeys( + ctx: CodegenContext, + row: String, + ordering: Seq[SortOrder]): Seq[ExprCode] = { + ctx.INPUT_ROW = row + // to use INPUT_ROW we must make sure currentVars is null + ctx.currentVars = null + ordering.map(_.child.genCode(ctx)) + } + /** * Generates the code for ordering based on the given order. */ def genComparisons(ctx: CodegenContext, ordering: Seq[SortOrder]): String = { val oldInputRow = ctx.INPUT_ROW val oldCurrentVars = ctx.currentVars - val inputRow = "i" - ctx.INPUT_ROW = inputRow - // to use INPUT_ROW we must make sure currentVars is null - ctx.currentVars = null - - val comparisons = ordering.map { order => - val eval = order.child.genCode(ctx) - val asc = order.isAscending - val isNullA = ctx.freshName("isNullA") - val primitiveA = ctx.freshName("primitiveA") - val isNullB = ctx.freshName("isNullB") - val primitiveB = ctx.freshName("primitiveB") + val rowAKeys = createOrderKeys(ctx, "a", ordering) + val rowBKeys = createOrderKeys(ctx, "b", ordering) + val comparisons = rowAKeys.zip(rowBKeys).zipWithIndex.map { case ((l, r), i) => + val dt = ordering(i).child.dataType + val asc = ordering(i).isAscending + val nullOrdering = ordering(i).nullOrdering + val lRetValue = nullOrdering match { + case NullsFirst => "-1" + case NullsLast => "1" + } + val rRetValue = nullOrdering match { + case NullsFirst => "1" + case NullsLast => "-1" + } s""" - ${ctx.INPUT_ROW} = a; - boolean $isNullA; - ${CodeGenerator.javaType(order.child.dataType)} $primitiveA; - { - ${eval.code} - $isNullA = ${eval.isNull}; - $primitiveA = ${eval.value}; - } - ${ctx.INPUT_ROW} = b; - boolean $isNullB; - ${CodeGenerator.javaType(order.child.dataType)} $primitiveB; - { - ${eval.code} - $isNullB = ${eval.isNull}; - $primitiveB = ${eval.value}; - } - if ($isNullA && $isNullB) { - // Nothing - } else if ($isNullA) { - return ${ - order.nullOrdering match { - case NullsFirst => "-1" - case NullsLast => "1" - }}; - } else if ($isNullB) { - return ${ - order.nullOrdering match { - case NullsFirst => "1" - case NullsLast => "-1" - }}; - } else { - int comp = ${ctx.genComp(order.child.dataType, primitiveA, primitiveB)}; - if (comp != 0) { - return ${if (asc) "comp" else "-comp"}; - } - } - """ + |${l.code} + |${r.code} + |if (${l.isNull} && ${r.isNull}) { + | // Nothing + |} else if (${l.isNull}) { + | return $lRetValue; + |} else if (${r.isNull}) { + | return $rRetValue; + |} else { + | int comp = ${ctx.genComp(dt, l.value, r.value)}; + | if (comp != 0) { + | return ${if (asc) "comp" else "-comp"}; + | } + |} + """.stripMargin } val code = ctx.splitExpressions( @@ -133,30 +126,26 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR returnType = "int", makeSplitFunction = { body => s""" - InternalRow ${ctx.INPUT_ROW} = null; // Holds current row being evaluated. - $body - return 0; - """ + |$body + |return 0; + """.stripMargin }, foldFunctions = { funCalls => funCalls.zipWithIndex.map { case (funCall, i) => val comp = ctx.freshName("comp") s""" - int $comp = $funCall; - if ($comp != 0) { - return $comp; - } - """ + |int $comp = $funCall; + |if ($comp != 0) { + | return $comp; + |} + """.stripMargin }.mkString }) ctx.currentVars = oldCurrentVars ctx.INPUT_ROW = oldInputRow // make sure INPUT_ROW is declared even if splitExpressions // returns an inlined block - s""" - |InternalRow $inputRow = null; - |$code - """.stripMargin + code } protected def create(ordering: Seq[SortOrder]): BaseOrdering = { From b46f75a5af372422de0f8e07ff920fa6ccd33c7e Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 15 Nov 2018 20:09:53 +0800 Subject: [PATCH 2064/2461] [SPARK-26057][SQL] Transform also analyzed plans when dedup references ## What changes were proposed in this pull request? In SPARK-24865 `AnalysisBarrier` was removed and in order to improve resolution speed, the `analyzed` flag was (re-)introduced in order to process only plans which are not yet analyzed. This should not be the case when performing attribute deduplication as in that case we need to transform also the plans which were already analyzed, otherwise we can miss to rewrite some attributes leading to invalid plans. ## How was this patch tested? added UT Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #23035 from mgaido91/SPARK-26057. Authored-by: Marco Gaido Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../org/apache/spark/sql/DataFrameSuite.scala | 25 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c4e526081f4a2..ab2312fdcdeef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -871,7 +871,7 @@ class Analyzer( private def dedupOuterReferencesInSubquery( plan: LogicalPlan, attrMap: AttributeMap[Attribute]): LogicalPlan = { - plan resolveOperatorsDown { case currentFragment => + plan transformDown { case currentFragment => currentFragment transformExpressions { case OuterReference(a: Attribute) => OuterReference(dedupAttr(a, attrMap)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 2bb18f48e0ae2..0ee2627814ba0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2554,4 +2554,29 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(swappedDf.filter($"key"($"map") > "a"), Row(2, Map(2 -> "b"))) } + + test("SPARK-26057: attribute deduplication on already analyzed plans") { + withTempView("a", "b", "v") { + val df1 = Seq(("1-1", 6)).toDF("id", "n") + df1.createOrReplaceTempView("a") + val df3 = Seq("1-1").toDF("id") + df3.createOrReplaceTempView("b") + spark.sql( + """ + |SELECT a.id, n as m + |FROM a + |WHERE EXISTS( + | SELECT 1 + | FROM b + | WHERE b.id = a.id) + """.stripMargin).createOrReplaceTempView("v") + val res = spark.sql( + """ + |SELECT a.id, n, m + | FROM a + | LEFT OUTER JOIN v ON v.id = a.id + """.stripMargin) + checkAnswer(res, Row("1-1", 6, 6)) + } + } } From 9610efc252c94f93689d45e320df1c5815d97b25 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 15 Nov 2018 20:25:27 +0800 Subject: [PATCH 2065/2461] [SPARK-26055][CORE] InterfaceStability annotations should be retained at runtime ## What changes were proposed in this pull request? It's good to have annotations available at runtime, so that tools like MiMa can detect them and deal with then specially. e.g. we don't want to track compatibility for unstable classes. This PR makes `InterfaceStability` annotations to be retained at runtime, to be consistent with `Experimental` and `DeveloperApi` ## How was this patch tested? N/A Closes #23029 from cloud-fan/annotation. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../org/apache/spark/annotation/DeveloperApi.java | 1 + .../org/apache/spark/annotation/Experimental.java | 1 + .../apache/spark/annotation/InterfaceStability.java | 11 ++++++++++- .../java/org/apache/spark/annotation/Private.java | 6 ++---- 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/common/tags/src/main/java/org/apache/spark/annotation/DeveloperApi.java b/common/tags/src/main/java/org/apache/spark/annotation/DeveloperApi.java index 0ecef6db0e039..890f2faca28b0 100644 --- a/common/tags/src/main/java/org/apache/spark/annotation/DeveloperApi.java +++ b/common/tags/src/main/java/org/apache/spark/annotation/DeveloperApi.java @@ -29,6 +29,7 @@ * of the known issue that Scaladoc displays only either the annotation or the comment, whichever * comes first. */ +@Documented @Retention(RetentionPolicy.RUNTIME) @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) diff --git a/common/tags/src/main/java/org/apache/spark/annotation/Experimental.java b/common/tags/src/main/java/org/apache/spark/annotation/Experimental.java index ff8120291455f..96875920cd9c3 100644 --- a/common/tags/src/main/java/org/apache/spark/annotation/Experimental.java +++ b/common/tags/src/main/java/org/apache/spark/annotation/Experimental.java @@ -30,6 +30,7 @@ * of the known issue that Scaladoc displays only either the annotation or the comment, whichever * comes first. */ +@Documented @Retention(RetentionPolicy.RUNTIME) @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) diff --git a/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java b/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java index 323098f69c6e1..02bcec737e80e 100644 --- a/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java +++ b/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java @@ -17,7 +17,7 @@ package org.apache.spark.annotation; -import java.lang.annotation.Documented; +import java.lang.annotation.*; /** * Annotation to inform users of how much to rely on a particular package, @@ -31,6 +31,9 @@ public class InterfaceStability { * (e.g. from 1.0 to 2.0). */ @Documented + @Retention(RetentionPolicy.RUNTIME) + @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, + ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) public @interface Stable {}; /** @@ -38,6 +41,9 @@ public class InterfaceStability { * Evolving interfaces can change from one feature release to another release (i.e. 2.1 to 2.2). */ @Documented + @Retention(RetentionPolicy.RUNTIME) + @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, + ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) public @interface Evolving {}; /** @@ -45,5 +51,8 @@ public class InterfaceStability { * Classes that are unannotated are considered Unstable. */ @Documented + @Retention(RetentionPolicy.RUNTIME) + @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, + ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) public @interface Unstable {}; } diff --git a/common/tags/src/main/java/org/apache/spark/annotation/Private.java b/common/tags/src/main/java/org/apache/spark/annotation/Private.java index 9082fcf0c84bc..a460d608ae16b 100644 --- a/common/tags/src/main/java/org/apache/spark/annotation/Private.java +++ b/common/tags/src/main/java/org/apache/spark/annotation/Private.java @@ -17,10 +17,7 @@ package org.apache.spark.annotation; -import java.lang.annotation.ElementType; -import java.lang.annotation.Retention; -import java.lang.annotation.RetentionPolicy; -import java.lang.annotation.Target; +import java.lang.annotation.*; /** * A class that is considered private to the internals of Spark -- there is a high-likelihood @@ -35,6 +32,7 @@ * of the known issue that Scaladoc displays only either the annotation or the comment, whichever * comes first. */ +@Documented @Retention(RetentionPolicy.RUNTIME) @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) From 91405b3b6eb4fa8047123d951859b6e2a1e46b6a Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Thu, 15 Nov 2018 09:22:31 -0600 Subject: [PATCH 2066/2461] [SPARK-22450][WIP][CORE][MLLIB][FOLLOWUP] Safely register MultivariateGaussian ## What changes were proposed in this pull request? register following classes in Kryo: "org.apache.spark.ml.stat.distribution.MultivariateGaussian", "org.apache.spark.mllib.stat.distribution.MultivariateGaussian" ## How was this patch tested? added tests Due to existing module dependency, I can not import spark-core in mllib-local's testsuits, so I do not add testsuite in `org.apache.spark.ml.stat.distribution.MultivariateGaussianSuite`. And I notice that class `ClusterStats` in `ClusteringEvaluator` is registered in a different way, should it be modified to keep in line with others in ML? srowen Closes #22974 from zhengruifeng/kryo_MultivariateGaussian. Authored-by: zhengruifeng Signed-off-by: Sean Owen --- .../spark/serializer/KryoSerializer.scala | 10 ++++++++- .../distribution/MultivariateGaussian.scala | 4 ++-- .../distribution/MultivariateGaussian.scala | 4 ++-- .../spark/ml/attribute/AttributeSuite.scala | 19 +++++++++++++++- .../MultivariateGaussianSuite.scala | 22 ++++++++++++++++++- 5 files changed, 52 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 3795d5c3b38e3..66812a54846c6 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -215,6 +215,12 @@ class KryoSerializer(conf: SparkConf) // We can't load those class directly in order to avoid unnecessary jar dependencies. // We load them safely, ignore it if the class not found. Seq( + "org.apache.spark.ml.attribute.Attribute", + "org.apache.spark.ml.attribute.AttributeGroup", + "org.apache.spark.ml.attribute.BinaryAttribute", + "org.apache.spark.ml.attribute.NominalAttribute", + "org.apache.spark.ml.attribute.NumericAttribute", + "org.apache.spark.ml.feature.Instance", "org.apache.spark.ml.feature.LabeledPoint", "org.apache.spark.ml.feature.OffsetInstance", @@ -224,6 +230,7 @@ class KryoSerializer(conf: SparkConf) "org.apache.spark.ml.linalg.SparseMatrix", "org.apache.spark.ml.linalg.SparseVector", "org.apache.spark.ml.linalg.Vector", + "org.apache.spark.ml.stat.distribution.MultivariateGaussian", "org.apache.spark.ml.tree.impl.TreePoint", "org.apache.spark.mllib.clustering.VectorWithNorm", "org.apache.spark.mllib.linalg.DenseMatrix", @@ -232,7 +239,8 @@ class KryoSerializer(conf: SparkConf) "org.apache.spark.mllib.linalg.SparseMatrix", "org.apache.spark.mllib.linalg.SparseVector", "org.apache.spark.mllib.linalg.Vector", - "org.apache.spark.mllib.regression.LabeledPoint" + "org.apache.spark.mllib.regression.LabeledPoint", + "org.apache.spark.mllib.stat.distribution.MultivariateGaussian" ).foreach { name => try { val clazz = Utils.classForName(name) diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala index 3167e0c286d47..e7f7a8e07d7f2 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala @@ -48,14 +48,14 @@ class MultivariateGaussian @Since("2.0.0") ( this(Vectors.fromBreeze(mean), Matrices.fromBreeze(cov)) } - private val breezeMu = mean.asBreeze.toDenseVector + @transient private lazy val breezeMu = mean.asBreeze.toDenseVector /** * Compute distribution dependent constants: * rootSigmaInv = D^(-1/2)^ * U.t, where sigma = U * D * U.t * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) */ - private val (rootSigmaInv: BDM[Double], u: Double) = calculateCovarianceConstants + @transient private lazy val (rootSigmaInv: BDM[Double], u: Double) = calculateCovarianceConstants /** * Returns density of this multivariate Gaussian at given point, x diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index 4cf662e036346..9a746dcf35556 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -43,7 +43,7 @@ class MultivariateGaussian @Since("1.3.0") ( require(sigma.numCols == sigma.numRows, "Covariance matrix must be square") require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size") - private val breezeMu = mu.asBreeze.toDenseVector + @transient private lazy val breezeMu = mu.asBreeze.toDenseVector /** * private[mllib] constructor @@ -60,7 +60,7 @@ class MultivariateGaussian @Since("1.3.0") ( * rootSigmaInv = D^(-1/2)^ * U.t, where sigma = U * D * U.t * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) */ - private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants + @transient private lazy val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants /** * Returns density of this multivariate Gaussian at given point, x diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala index 6355e0f179496..eb5f3ca45940d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.ml.attribute -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.types._ class AttributeSuite extends SparkFunSuite { @@ -221,4 +222,20 @@ class AttributeSuite extends SparkFunSuite { val decimalFldWithMeta = new StructField("x", DecimalType(38, 18), false, metadata) assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric) } + + test("Kryo class register") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + + val ser = new KryoSerializer(conf).newInstance() + + val numericAttr = new NumericAttribute(Some("numeric"), Some(1), Some(1.0), Some(2.0)) + val nominalAttr = new NominalAttribute(Some("nominal"), Some(2), Some(false)) + val binaryAttr = new BinaryAttribute(Some("binary"), Some(3), Some(Array("i", "j"))) + + Seq(numericAttr, nominalAttr, binaryAttr).foreach { i => + val i2 = ser.deserialize[Attribute](ser.serialize(i)) + assert(i === i2) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala index 669d44223d713..5b4a2607f0b25 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.mllib.stat.distribution -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.mllib.linalg.{Matrices, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.serializer.KryoSerializer class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext { test("univariate") { @@ -80,4 +81,23 @@ class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext assert(dist.pdf(x) ~== 7.154782224045512E-5 absTol 1E-9) } + test("Kryo class register") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + + val ser = new KryoSerializer(conf).newInstance() + + val mu = Vectors.dense(0.0, 0.0) + val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0)) + val dist1 = new MultivariateGaussian(mu, sigma1) + + val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0)) + val dist2 = new MultivariateGaussian(mu, sigma2) + + Seq(dist1, dist2).foreach { i => + val i2 = ser.deserialize[MultivariateGaussian](ser.serialize(i)) + assert(i.sigma === i2.sigma) + assert(i.mu === i2.mu) + } + } } From cae5879dbe5881a88c4925f6c5408f32d6f3860e Mon Sep 17 00:00:00 2001 From: Shahid Date: Thu, 15 Nov 2018 10:27:57 -0600 Subject: [PATCH 2067/2461] [SPARK-26044][WEBUI] Aggregated Metrics table sort based on executor ID ## What changes were proposed in this pull request? Aggregated Metrics table in the stage page is not sorted based on the executorID properly. Because executorID is string and also the logs of the executors are in the same column. In this PR, I created a new column for executor logs. ## How was this patch tested? Before patch: ![screenshot from 2018-11-14 02-05-12](https://user-images.githubusercontent.com/23054875/48441529-caa77580-e7b1-11e8-90ea-b16f63438102.png) After patch: ![screenshot from 2018-11-14 02-05-29](https://user-images.githubusercontent.com/23054875/48441540-d2671a00-e7b1-11e8-9059-890bfe80c961.png) Closes #23024 from shahidki31/AggSort. Authored-by: Shahid Signed-off-by: Sean Owen --- .../apache/spark/ui/jobs/ExecutorTable.scala | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 0ff64f053f371..1be81e5ef9952 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -70,6 +70,7 @@ private[ui] class ExecutorTable(stage: StageData, store: AppStatusStore) { Blacklisted +
      Logs
      -
      {k}
      -
      - { - executor.map(_.executorLogs).getOrElse(Map.empty).map { - case (logName, logUrl) => - } - } -
      -
      {k} {executor.map { e => e.hostPort }.getOrElse("CANNOT FIND ADDRESS")} {UIUtils.formatDuration(v.taskTime)} {v.failedTasks + v.succeededTasks + v.killedTasks}false {executor.map(_.executorLogs).getOrElse(Map.empty).map { + case (logName, logUrl) => + }} +
      1.2.1 Version of the Hive metastore. Available - options are 0.12.0 through 2.3.3. + options are 0.12.0 through 2.3.4.
      */ -@InterfaceStability.Evolving +@Evolving public class DataSourceOptions { private final Map keyLowerCasedMap; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java index 6e31e84bf6c72..eae7a45d1d446 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * The base interface for data source v2. Implementations must have a public, 0-arg constructor. @@ -30,5 +30,5 @@ * If Spark fails to execute any methods in the implementations of this interface (by throwing an * exception), the read action will fail and no Spark job will be submitted. */ -@InterfaceStability.Evolving +@Evolving public interface DataSourceV2 {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java index 61c08e7fa89df..c4d9ef88f607e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport; import org.apache.spark.sql.types.StructType; @@ -29,7 +29,7 @@ * This interface is used to create {@link MicroBatchReadSupport} instances when end users run * {@code SparkSession.readStream.format(...).option(...).load()} with a micro-batch trigger. */ -@InterfaceStability.Evolving +@Evolving public interface MicroBatchReadSupportProvider extends DataSourceV2 { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java index bbe430e299261..c00abd9b685b5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java @@ -17,14 +17,14 @@ package org.apache.spark.sql.sources.v2; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to * propagate session configs with the specified key-prefix to all data source operations in this * session. */ -@InterfaceStability.Evolving +@Evolving public interface SessionConfigSupport extends DataSourceV2 { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java index f9ca85d8089b4..8ac9c51750865 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; import org.apache.spark.sql.streaming.OutputMode; @@ -30,7 +30,7 @@ * This interface is used to create {@link StreamingWriteSupport} instances when end users run * {@code Dataset.writeStream.format(...).option(...).start()}. */ -@InterfaceStability.Evolving +@Evolving public interface StreamingWriteSupportProvider extends DataSourceV2, BaseStreamingSink { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java index 452ee86675b42..518a8b03a2c6e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * An interface that defines how to load the data from data source for batch processing. @@ -29,7 +29,7 @@ * {@link ScanConfig}. The {@link ScanConfig} will be used to create input partitions and reader * factory to scan data from the data source with a Spark job. */ -@InterfaceStability.Evolving +@Evolving public interface BatchReadSupport extends ReadSupport { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java index 95c30de907e44..5f5248084bad6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java @@ -19,7 +19,7 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * A serializable representation of an input partition returned by @@ -32,7 +32,7 @@ * the actual reading. So {@link InputPartition} must be serializable while {@link PartitionReader} * doesn't need to be. */ -@InterfaceStability.Evolving +@Evolving public interface InputPartition extends Serializable { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java index 04ff8d0a19fc3..2945925959538 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java @@ -20,7 +20,7 @@ import java.io.Closeable; import java.io.IOException; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * A partition reader returned by {@link PartitionReaderFactory#createReader(InputPartition)} or @@ -32,7 +32,7 @@ * data sources(whose {@link PartitionReaderFactory#supportColumnarReads(InputPartition)} * returns true). */ -@InterfaceStability.Evolving +@Evolving public interface PartitionReader extends Closeable { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java index f35de9310eee3..97f4a473953fc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java @@ -19,7 +19,7 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.vectorized.ColumnarBatch; @@ -30,7 +30,7 @@ * {@link PartitionReader} (by throwing an exception), corresponding Spark task would fail and * get retried until hitting the maximum retry times. */ -@InterfaceStability.Evolving +@Evolving public interface PartitionReaderFactory extends Serializable { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java index a58ddb288f1ed..b1f610a82e8a2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.types.StructType; /** @@ -27,7 +27,7 @@ * If Spark fails to execute any methods in the implementations of this interface (by throwing an * exception), the read action will fail and no Spark job will be submitted. */ -@InterfaceStability.Evolving +@Evolving public interface ReadSupport { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java index 7462ce2820585..a69872a527746 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.types.StructType; /** @@ -31,7 +31,7 @@ * {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}, implementations mostly need to * cast the input {@link ScanConfig} to the concrete {@link ScanConfig} class of the data source. */ -@InterfaceStability.Evolving +@Evolving public interface ScanConfig { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java index 4c0eedfddfe22..4922962f70655 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java @@ -17,14 +17,14 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * An interface for building the {@link ScanConfig}. Implementations can mixin those * SupportsPushDownXYZ interfaces to do operator pushdown, and keep the operator pushdown result in * the returned {@link ScanConfig}. */ -@InterfaceStability.Evolving +@Evolving public interface ScanConfigBuilder { ScanConfig build(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java index 44799c7d49137..14776f37fed46 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java @@ -19,13 +19,13 @@ import java.util.OptionalLong; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * An interface to represent statistics for a data source, which is returned by * {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}. */ -@InterfaceStability.Evolving +@Evolving public interface Statistics { OptionalLong sizeInBytes(); OptionalLong numRows(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index 5e7985f645a06..3a89baa1b44c2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -17,14 +17,14 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.Filter; /** * A mix-in interface for {@link ScanConfigBuilder}. Data sources can implement this interface to * push down filters to the data source and reduce the size of the data to be read. */ -@InterfaceStability.Evolving +@Evolving public interface SupportsPushDownFilters extends ScanConfigBuilder { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java index edb164937d6ef..1934763224881 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.types.StructType; /** @@ -25,7 +25,7 @@ * interface to push down required columns to the data source and only read these columns during * scan to reduce the size of the data to be read. */ -@InterfaceStability.Evolving +@Evolving public interface SupportsPushDownRequiredColumns extends ScanConfigBuilder { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index db62cd4515362..0335c7775c2af 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; /** @@ -27,7 +27,7 @@ * Note that, when a {@link ReadSupport} implementation creates exactly one {@link InputPartition}, * Spark may avoid adding a shuffle even if the reader does not implement this interface. */ -@InterfaceStability.Evolving +@Evolving public interface SupportsReportPartitioning extends ReadSupport { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java index 1831488ba096f..917372cdd25b3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * A mix in interface for {@link BatchReadSupport}. Data sources can implement this interface to @@ -27,7 +27,7 @@ * data source. Implementations that return more accurate statistics based on pushed operators will * not improve query performance until the planner can push operators before getting stats. */ -@InterfaceStability.Evolving +@Evolving public interface SupportsReportStatistics extends ReadSupport { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java index 6764d4b7665c7..1cdc02f5736b1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.reader.PartitionReader; /** @@ -25,7 +25,7 @@ * share the same values for the {@link #clusteredColumns} will be produced by the same * {@link PartitionReader}. */ -@InterfaceStability.Evolving +@Evolving public class ClusteredDistribution implements Distribution { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java index 364a3f553923c..02b0e68974919 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.reader.PartitionReader; /** @@ -37,5 +37,5 @@ *
    • {@link ClusteredDistribution}
    • * */ -@InterfaceStability.Evolving +@Evolving public interface Distribution {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java index fb0b6f1df43bb..c9a00262c1287 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.ScanConfig; import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning; @@ -28,7 +28,7 @@ * like a snapshot. Once created, it should be deterministic and always report the same number of * partitions and the same "satisfy" result for a certain distribution. */ -@InterfaceStability.Evolving +@Evolving public interface Partitioning { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java index 9101c8a44d34e..c7f6fce6e81af 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java @@ -17,13 +17,13 @@ package org.apache.spark.sql.sources.v2.reader.streaming; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.reader.PartitionReader; /** * A variation on {@link PartitionReader} for use with continuous streaming processing. */ -@InterfaceStability.Evolving +@Evolving public interface ContinuousPartitionReader extends PartitionReader { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java index 2d9f1ca1686a1..41195befe5e57 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader.streaming; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory; @@ -28,7 +28,7 @@ * instead of {@link org.apache.spark.sql.sources.v2.reader.PartitionReader}. It's used for * continuous streaming processing. */ -@InterfaceStability.Evolving +@Evolving public interface ContinuousPartitionReaderFactory extends PartitionReaderFactory { @Override ContinuousPartitionReader createReader(InputPartition partition); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java index 9a3ad2eb8a801..2b784ac0e9f35 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader.streaming; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.ScanConfig; @@ -36,7 +36,7 @@ * {@link #stop()} will be called when the streaming execution is completed. Note that a single * query may have multiple executions due to restart or failure recovery. */ -@InterfaceStability.Evolving +@Evolving public interface ContinuousReadSupport extends StreamingReadSupport, BaseStreamingSource { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java index edb0db11bff2c..f56066c639388 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader.streaming; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; import org.apache.spark.sql.sources.v2.reader.*; @@ -33,7 +33,7 @@ * will be called when the streaming execution is completed. Note that a single query may have * multiple executions due to restart or failure recovery. */ -@InterfaceStability.Evolving +@Evolving public interface MicroBatchReadSupport extends StreamingReadSupport, BaseStreamingSource { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java index 6cf27734867cb..6104175d2c9e3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader.streaming; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * An abstract representation of progress through a {@link MicroBatchReadSupport} or @@ -30,7 +30,7 @@ * maintain compatibility with DataSource V1 APIs. This extension will be removed once we * get rid of V1 completely. */ -@InterfaceStability.Evolving +@Evolving public abstract class Offset extends org.apache.spark.sql.execution.streaming.Offset { /** * A JSON-serialized representation of an Offset that is diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java index 383e73db6762b..2c97d924a0629 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java @@ -19,7 +19,7 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * Used for per-partition offsets in continuous processing. ContinuousReader implementations will @@ -27,6 +27,6 @@ * * These offsets must be serializable. */ -@InterfaceStability.Evolving +@Evolving public interface PartitionOffset extends Serializable { } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java index 0ec9e05d6a02b..efe1ac4f78db1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.writer; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * An interface that defines how to write the data to data source for batch processing. @@ -37,7 +37,7 @@ * * Please refer to the documentation of commit/abort methods for detailed specifications. */ -@InterfaceStability.Evolving +@Evolving public interface BatchWriteSupport { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 5fb067966ee67..d142ee523ef9f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -19,7 +19,7 @@ import java.io.IOException; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * A data writer returned by {@link DataWriterFactory#createWriter(int, long)} and is @@ -55,7 +55,7 @@ * * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}. */ -@InterfaceStability.Evolving +@Evolving public interface DataWriter { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index 19a36dd232456..65105f46b82d5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -20,7 +20,7 @@ import java.io.Serializable; import org.apache.spark.TaskContext; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.InternalRow; /** @@ -31,7 +31,7 @@ * will be created on executors and do the actual writing. So this interface must be * serializable and {@link DataWriter} doesn't need to be. */ -@InterfaceStability.Evolving +@Evolving public interface DataWriterFactory extends Serializable { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java index 123335c414e9f..9216e34399092 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java @@ -19,8 +19,8 @@ import java.io.Serializable; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; -import org.apache.spark.annotation.InterfaceStability; /** * A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side @@ -30,5 +30,5 @@ * This is an empty interface, data sources should define their own message class and use it when * generating messages at executor side and handling the messages at driver side. */ -@InterfaceStability.Evolving +@Evolving public interface WriterCommitMessage extends Serializable {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java index a4da24fc5ae68..7d3d21cb2b637 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java @@ -20,7 +20,7 @@ import java.io.Serializable; import org.apache.spark.TaskContext; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.sources.v2.writer.DataWriter; @@ -33,7 +33,7 @@ * will be created on executors and do the actual writing. So this interface must be * serializable and {@link DataWriter} doesn't need to be. */ -@InterfaceStability.Evolving +@Evolving public interface StreamingDataWriterFactory extends Serializable { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java index 3fdfac5e1c84a..84cfbf2dda483 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.writer.streaming; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.writer.DataWriter; import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; @@ -27,7 +27,7 @@ * Streaming queries are divided into intervals of data called epochs, with a monotonically * increasing numeric ID. This writer handles commits and aborts for each successive epoch. */ -@InterfaceStability.Evolving +@Evolving public interface StreamingWriteSupport { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java index 5371a23230c98..fd6f7be2abc5a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java +++ b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java @@ -19,9 +19,9 @@ import java.util.concurrent.TimeUnit; +import org.apache.spark.annotation.Evolving; import scala.concurrent.duration.Duration; -import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger; import org.apache.spark.sql.execution.streaming.OneTimeTrigger$; @@ -30,7 +30,7 @@ * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving public class Trigger { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 5f58b031f6aef..906e9bc26ef53 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -22,7 +22,7 @@ import org.apache.arrow.vector.complex.*; import org.apache.arrow.vector.holders.NullableVarCharHolder; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.execution.arrow.ArrowUtils; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.UTF8String; @@ -31,7 +31,7 @@ * A column vector backed by Apache Arrow. Currently calendar interval type and map type are not * supported. */ -@InterfaceStability.Evolving +@Evolving public final class ArrowColumnVector extends ColumnVector { private final ArrowVectorAccessor accessor; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index ad99b450a4809..14caaeaedbe2b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.vectorized; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.types.CalendarInterval; @@ -47,7 +47,7 @@ * format. Since it is expected to reuse the ColumnVector instance while loading data, the storage * footprint is negligible. */ -@InterfaceStability.Evolving +@Evolving public abstract class ColumnVector implements AutoCloseable { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 72a192d089b9f..dd2bd789c26d0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.vectorized; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; @@ -25,7 +25,7 @@ /** * Array abstraction in {@link ColumnVector}. */ -@InterfaceStability.Evolving +@Evolving public final class ColumnarArray extends ArrayData { // The data for this array. This array contains elements from // data[offset] to data[offset + length). diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java index d206c1df42abb..07546a54013ec 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java @@ -18,7 +18,7 @@ import java.util.*; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.MutableColumnarRow; @@ -27,7 +27,7 @@ * batch so that Spark can access the data row by row. Instance of it is meant to be reused during * the entire data loading process. */ -@InterfaceStability.Evolving +@Evolving public final class ColumnarBatch { private int numRows; private final ColumnVector[] columns; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index f2f2279590023..4b9d3c5f59915 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.vectorized; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.types.*; @@ -26,7 +26,7 @@ /** * Row abstraction in {@link ColumnVector}. */ -@InterfaceStability.Evolving +@Evolving public final class ColumnarRow extends InternalRow { // The data for this row. // E.g. the value of 3rd int field is `data.getChild(3).getInt(rowId)`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 49d2a34080b13..5a408b29f9337 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ import scala.language.implicitConversions -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} @@ -60,7 +60,7 @@ private[sql] object Column { * * @since 1.6.0 */ -@InterfaceStability.Stable +@Stable class TypedColumn[-T, U]( expr: Expression, private[sql] val encoder: ExpressionEncoder[U]) @@ -130,7 +130,7 @@ class TypedColumn[-T, U]( * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class Column(val expr: Expression) extends Logging { def this(name: String) = this(name match { @@ -1227,7 +1227,7 @@ class Column(val expr: Expression) extends Logging { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class ColumnName(name: String) extends Column(name) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 5288907b7d7ff..53e9f810d7c85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -22,18 +22,17 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ - /** * Functionality for working with missing data in `DataFrame`s. * * @since 1.3.1 */ -@InterfaceStability.Stable +@Stable final class DataFrameNaFunctions private[sql](df: DataFrame) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index df18623e42a02..52df13d39caa7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import com.fasterxml.jackson.databind.ObjectMapper import org.apache.spark.Partition -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD @@ -48,7 +48,7 @@ import org.apache.spark.unsafe.types.UTF8String * * @since 1.4.0 */ -@InterfaceStability.Stable +@Stable class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 7c12432d33c33..b2f6a6ba83108 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -21,7 +21,7 @@ import java.{lang => jl, util => ju} import scala.collection.JavaConverters._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.stat._ import org.apache.spark.sql.functions.col @@ -33,7 +33,7 @@ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} * * @since 1.4.0 */ -@InterfaceStability.Stable +@Stable final class DataFrameStatFunctions private[sql](df: DataFrame) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 1b4998f94b25d..29d479f542115 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -21,7 +21,7 @@ import java.util.{Locale, Properties, UUID} import scala.collection.JavaConverters._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog._ @@ -40,7 +40,7 @@ import org.apache.spark.sql.types.StructType * * @since 1.4.0 */ -@InterfaceStability.Stable +@Stable final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private val df = ds.toDF() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f98eaa3d4eb90..f5caaf3f7fc87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -21,19 +21,17 @@ import java.io.CharArrayWriter import scala.collection.JavaConverters._ import scala.language.implicitConversions -import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal import org.apache.commons.lang3.StringUtils import org.apache.spark.TaskContext -import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable, Unstable} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.function._ import org.apache.spark.api.python.{PythonRDD, SerDeUtil} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.encoders._ @@ -166,10 +164,10 @@ private[sql] object Dataset { * * @since 1.6.0 */ -@InterfaceStability.Stable +@Stable class Dataset[T] private[sql]( @transient val sparkSession: SparkSession, - @DeveloperApi @InterfaceStability.Unstable @transient val queryExecution: QueryExecution, + @DeveloperApi @Unstable @transient val queryExecution: QueryExecution, encoder: Encoder[T]) extends Serializable { @@ -426,7 +424,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan) /** @@ -544,7 +542,7 @@ class Dataset[T] private[sql]( * @group streaming * @since 2.0.0 */ - @InterfaceStability.Evolving + @Evolving def isStreaming: Boolean = logicalPlan.isStreaming /** @@ -557,7 +555,7 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def checkpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = true) /** @@ -570,7 +568,7 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def checkpoint(eager: Boolean): Dataset[T] = checkpoint(eager = eager, reliableCheckpoint = true) /** @@ -583,7 +581,7 @@ class Dataset[T] private[sql]( * @since 2.3.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def localCheckpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = false) /** @@ -596,7 +594,7 @@ class Dataset[T] private[sql]( * @since 2.3.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def localCheckpoint(eager: Boolean): Dataset[T] = checkpoint( eager = eager, reliableCheckpoint = false @@ -671,7 +669,7 @@ class Dataset[T] private[sql]( * @group streaming * @since 2.1.0 */ - @InterfaceStability.Evolving + @Evolving // We only accept an existing column name, not a derived column here as a watermark that is // defined on a derived column cannot referenced elsewhere in the plan. def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withTypedPlan { @@ -1066,7 +1064,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { // Creates a Join node and resolve it first, to get join condition resolved, self-join resolved, // etc. @@ -1142,7 +1140,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { joinWith(other, condition, "inner") } @@ -1384,7 +1382,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { implicit val encoder = c1.encoder val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan) @@ -1418,7 +1416,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] = selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]] @@ -1430,7 +1428,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def select[U1, U2, U3]( c1: TypedColumn[T, U1], c2: TypedColumn[T, U2], @@ -1445,7 +1443,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def select[U1, U2, U3, U4]( c1: TypedColumn[T, U1], c2: TypedColumn[T, U2], @@ -1461,7 +1459,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def select[U1, U2, U3, U4, U5]( c1: TypedColumn[T, U1], c2: TypedColumn[T, U2], @@ -1632,7 +1630,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def reduce(func: (T, T) => T): T = withNewRDDExecutionId { rdd.reduce(func) } @@ -1647,7 +1645,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _)) /** @@ -1659,7 +1657,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { val withGroupingKey = AppendColumns(func, logicalPlan) val executed = sparkSession.sessionState.executePlan(withGroupingKey) @@ -1681,7 +1679,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = groupByKey(func.call(_))(encoder) @@ -2483,7 +2481,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def filter(func: T => Boolean): Dataset[T] = { withTypedPlan(TypedFilter(func, logicalPlan)) } @@ -2497,7 +2495,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def filter(func: FilterFunction[T]): Dataset[T] = { withTypedPlan(TypedFilter(func, logicalPlan)) } @@ -2511,7 +2509,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { MapElements[T, U](func, logicalPlan) } @@ -2525,7 +2523,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { implicit val uEnc = encoder withTypedPlan(MapElements[T, U](func, logicalPlan)) @@ -2540,7 +2538,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { new Dataset[U]( sparkSession, @@ -2557,7 +2555,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = { val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala mapPartitions(func)(encoder) @@ -2588,7 +2586,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] = mapPartitions(_.flatMap(func)) @@ -2602,7 +2600,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { val func: (T) => Iterator[U] = x => f.call(x).asScala flatMap(func)(encoder) @@ -3064,7 +3062,7 @@ class Dataset[T] private[sql]( * @group basic * @since 2.0.0 */ - @InterfaceStability.Evolving + @Evolving def writeStream: DataStreamWriter[T] = { if (!isStreaming) { logicalPlan.failAnalysis( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala index 08aa1bbe78fae..1c4ffefb897ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * A container for a [[Dataset]], used for implicit conversions in Scala. @@ -30,7 +30,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 1.6.0 */ -@InterfaceStability.Stable +@Stable case class DatasetHolder[T] private[sql](private val ds: Dataset[T]) { // This is declared with parentheses to prevent the Scala compiler from treating diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala index bd8dd6ea3fe0f..302d38cde1430 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Experimental, Unstable} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.rules.Rule * @since 1.3.0 */ @Experimental -@InterfaceStability.Unstable +@Unstable class ExperimentalMethods private[sql]() { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala index 52b8c839643e7..5c0fe798b1044 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving /** * The abstract class for writing custom logic to process data generated by a query. @@ -104,7 +104,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving abstract class ForeachWriter[T] extends Serializable { // TODO: Move this to org.apache.spark.sql.util or consolidate this with batch API. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 555bcdffb6ee4..7a47242f69381 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Evolving, Experimental} import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} @@ -37,7 +37,7 @@ import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode * @since 2.0.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving class KeyValueGroupedDataset[K, V] private[sql]( kEncoder: Encoder[K], vEncoder: Encoder[V], @@ -237,7 +237,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def mapGroupsWithState[S: Encoder, U: Encoder]( func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s)) @@ -272,7 +272,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def mapGroupsWithState[S: Encoder, U: Encoder]( timeoutConf: GroupStateTimeout)( func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { @@ -309,7 +309,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def mapGroupsWithState[S, U]( func: MapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], @@ -340,7 +340,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def mapGroupsWithState[S, U]( func: MapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], @@ -371,7 +371,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def flatMapGroupsWithState[S: Encoder, U: Encoder]( outputMode: OutputMode, timeoutConf: GroupStateTimeout)( @@ -413,7 +413,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def flatMapGroupsWithState[S, U]( func: FlatMapGroupsWithStateFunction[K, V, S, U], outputMode: OutputMode, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index d4e75b5ebd405..e85636d82a62c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -22,7 +22,7 @@ import java.util.Locale import scala.collection.JavaConverters._ import scala.language.implicitConversions -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.api.python.PythonEvalType import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} @@ -45,7 +45,7 @@ import org.apache.spark.sql.types.{NumericType, StructType} * * @since 2.0.0 */ -@InterfaceStability.Stable +@Stable class RelationalGroupedDataset protected[sql]( df: DataFrame, groupingExprs: Seq[Expression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala index 3c39579149fff..5a554eff02e3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.internal.config.{ConfigEntry, OptionalConfigEntry} import org.apache.spark.sql.internal.SQLConf - /** * Runtime configuration interface for Spark. To access this, use `SparkSession.conf`. * @@ -29,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf * * @since 2.0.0 */ -@InterfaceStability.Stable +@Stable class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 9982b60fefe60..43f34e6ff4b85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -23,7 +23,7 @@ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.annotation._ import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.internal.config.ConfigEntry @@ -54,7 +54,7 @@ import org.apache.spark.sql.util.ExecutionListenerManager * @groupname Ungrouped Support functions for language integrated queries * @since 1.0.0 */ -@InterfaceStability.Stable +@Stable class SQLContext private[sql](val sparkSession: SparkSession) extends Logging with Serializable { @@ -86,7 +86,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * that listen for execution metrics. */ @Experimental - @InterfaceStability.Evolving + @Evolving def listenerManager: ExecutionListenerManager = sparkSession.listenerManager /** @@ -158,7 +158,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) */ @Experimental @transient - @InterfaceStability.Unstable + @Unstable def experimental: ExperimentalMethods = sparkSession.experimental /** @@ -244,7 +244,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving object implicits extends SQLImplicits with Serializable { protected override def _sqlContext: SQLContext = self } @@ -258,7 +258,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { sparkSession.createDataFrame(rdd) } @@ -271,7 +271,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { sparkSession.createDataFrame(data) } @@ -319,7 +319,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @DeveloperApi - @InterfaceStability.Evolving + @Evolving def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { sparkSession.createDataFrame(rowRDD, schema) } @@ -363,7 +363,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataset */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { sparkSession.createDataset(data) } @@ -401,7 +401,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataset */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { sparkSession.createDataset(data) } @@ -428,7 +428,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @DeveloperApi - @InterfaceStability.Evolving + @Evolving def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { sparkSession.createDataFrame(rowRDD, schema) } @@ -443,7 +443,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.6.0 */ @DeveloperApi - @InterfaceStability.Evolving + @Evolving def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { sparkSession.createDataFrame(rows, schema) } @@ -507,7 +507,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * * @since 2.0.0 */ - @InterfaceStability.Evolving + @Evolving def readStream: DataStreamReader = sparkSession.readStream @@ -631,7 +631,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataframe */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(end: Long): DataFrame = sparkSession.range(end).toDF() /** @@ -643,7 +643,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataframe */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(start: Long, end: Long): DataFrame = sparkSession.range(start, end).toDF() /** @@ -655,7 +655,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataframe */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(start: Long, end: Long, step: Long): DataFrame = { sparkSession.range(start, end, step).toDF() } @@ -670,7 +670,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataframe */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = { sparkSession.range(start, end, step, numPartitions).toDF() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 05db292bd41b1..d329af0145c2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -21,7 +21,7 @@ import scala.collection.Map import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder * * @since 1.6.0 */ -@InterfaceStability.Evolving +@Evolving abstract class SQLImplicits extends LowPrioritySQLImplicits { protected def _sqlContext: SQLContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index c0727e844a1ca..725db97df4ed1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -25,7 +25,7 @@ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} -import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable, Unstable} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD @@ -73,7 +73,7 @@ import org.apache.spark.util.{CallSite, Utils} * @param parentSessionState If supplied, inherit all session state (i.e. temporary * views, SQL config, UDFs etc) from parent. */ -@InterfaceStability.Stable +@Stable class SparkSession private( @transient val sparkContext: SparkContext, @transient private val existingSharedState: Option[SharedState], @@ -124,7 +124,7 @@ class SparkSession private( * * @since 2.2.0 */ - @InterfaceStability.Unstable + @Unstable @transient lazy val sharedState: SharedState = { existingSharedState.getOrElse(new SharedState(sparkContext)) @@ -145,7 +145,7 @@ class SparkSession private( * * @since 2.2.0 */ - @InterfaceStability.Unstable + @Unstable @transient lazy val sessionState: SessionState = { parentSessionState @@ -186,7 +186,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def listenerManager: ExecutionListenerManager = sessionState.listenerManager /** @@ -197,7 +197,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Unstable + @Unstable def experimental: ExperimentalMethods = sessionState.experimentalMethods /** @@ -231,7 +231,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Unstable + @Unstable def streams: StreamingQueryManager = sessionState.streamingQueryManager /** @@ -289,7 +289,7 @@ class SparkSession private( * @return 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def emptyDataset[T: Encoder]: Dataset[T] = { val encoder = implicitly[Encoder[T]] new Dataset(self, LocalRelation(encoder.schema.toAttributes), encoder) @@ -302,7 +302,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { SparkSession.setActiveSession(this) val encoder = Encoders.product[A] @@ -316,7 +316,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { SparkSession.setActiveSession(this) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] @@ -356,7 +356,7 @@ class SparkSession private( * @since 2.0.0 */ @DeveloperApi - @InterfaceStability.Evolving + @Evolving def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD, schema, needsConversion = true) } @@ -370,7 +370,7 @@ class SparkSession private( * @since 2.0.0 */ @DeveloperApi - @InterfaceStability.Evolving + @Evolving def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD.rdd, schema) } @@ -384,7 +384,7 @@ class SparkSession private( * @since 2.0.0 */ @DeveloperApi - @InterfaceStability.Evolving + @Evolving def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { Dataset.ofRows(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala)) } @@ -474,7 +474,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { val enc = encoderFor[T] val attributes = enc.schema.toAttributes @@ -493,7 +493,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = { Dataset[T](self, ExternalRDD(data, self)) } @@ -515,7 +515,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { createDataset(data.asScala) } @@ -528,7 +528,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(end: Long): Dataset[java.lang.Long] = range(0, end) /** @@ -539,7 +539,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(start: Long, end: Long): Dataset[java.lang.Long] = { range(start, end, step = 1, numPartitions = sparkContext.defaultParallelism) } @@ -552,7 +552,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = { range(start, end, step, numPartitions = sparkContext.defaultParallelism) } @@ -566,7 +566,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = { new Dataset(self, Range(start, end, step, numPartitions), Encoders.LONG) } @@ -672,7 +672,7 @@ class SparkSession private( * * @since 2.0.0 */ - @InterfaceStability.Evolving + @Evolving def readStream: DataStreamReader = new DataStreamReader(self) /** @@ -706,7 +706,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving object implicits extends SQLImplicits with Serializable { protected override def _sqlContext: SQLContext = SparkSession.this.sqlContext } @@ -775,13 +775,13 @@ class SparkSession private( } -@InterfaceStability.Stable +@Stable object SparkSession extends Logging { /** * Builder for [[SparkSession]]. */ - @InterfaceStability.Stable + @Stable class Builder extends Logging { private[this] val options = new scala.collection.mutable.HashMap[String, String] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala index a4864344b2d25..5ed76789786bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import scala.collection.mutable -import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Unstable} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder @@ -66,7 +66,7 @@ import org.apache.spark.sql.catalyst.rules.Rule */ @DeveloperApi @Experimental -@InterfaceStability.Unstable +@Unstable class SparkSessionExtensions { type RuleBuilder = SparkSession => Rule[LogicalPlan] type CheckRuleBuilder = SparkSession => LogicalPlan => Unit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 84da097be53c1..5a3f556c9c074 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -22,7 +22,7 @@ import java.lang.reflect.ParameterizedType import scala.reflect.runtime.universe.TypeTag import scala.util.Try -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.api.python.PythonEvalType import org.apache.spark.internal.Logging import org.apache.spark.sql.api.java._ @@ -44,7 +44,7 @@ import org.apache.spark.util.Utils * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends Logging { protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index ab81725def3f4..44668610d8052 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalog import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Evolving, Experimental, Stable} import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset} import org.apache.spark.sql.types.StructType import org.apache.spark.storage.StorageLevel @@ -29,7 +29,7 @@ import org.apache.spark.storage.StorageLevel * * @since 2.0.0 */ -@InterfaceStability.Stable +@Stable abstract class Catalog { /** @@ -233,7 +233,7 @@ abstract class Catalog { * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createTable(tableName: String, path: String): DataFrame /** @@ -261,7 +261,7 @@ abstract class Catalog { * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createTable(tableName: String, path: String, source: String): DataFrame /** @@ -292,7 +292,7 @@ abstract class Catalog { * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createTable( tableName: String, source: String, @@ -330,7 +330,7 @@ abstract class Catalog { * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createTable( tableName: String, source: String, @@ -366,7 +366,7 @@ abstract class Catalog { * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createTable( tableName: String, source: String, @@ -406,7 +406,7 @@ abstract class Catalog { * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createTable( tableName: String, source: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala index c0c5ebc2ba2d6..cb270875228ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalog import javax.annotation.Nullable -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.DefinedByConstructorParams @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.DefinedByConstructorParams * @param locationUri path (in the form of a uri) to data files. * @since 2.0.0 */ -@InterfaceStability.Stable +@Stable class Database( val name: String, @Nullable val description: String, @@ -61,7 +61,7 @@ class Database( * @param isTemporary whether the table is a temporary table. * @since 2.0.0 */ -@InterfaceStability.Stable +@Stable class Table( val name: String, @Nullable val database: String, @@ -93,7 +93,7 @@ class Table( * @param isBucket whether the column is a bucket column. * @since 2.0.0 */ -@InterfaceStability.Stable +@Stable class Column( val name: String, @Nullable val description: String, @@ -126,7 +126,7 @@ class Column( * @param isTemporary whether the function is a temporary function or not. * @since 2.0.0 */ -@InterfaceStability.Stable +@Stable class Function( val name: String, @Nullable val database: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala index 19e3e55cb2829..4c0db3cb42a82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.streaming -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Evolving, Experimental} import org.apache.spark.sql.streaming.Trigger /** @@ -25,5 +25,5 @@ import org.apache.spark.sql.streaming.Trigger * the query. */ @Experimental -@InterfaceStability.Evolving +@Evolving case object OneTimeTrigger extends Trigger diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala index 90e1766c4d9f1..caffcc3c4c1a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala @@ -23,15 +23,15 @@ import scala.concurrent.duration.Duration import org.apache.commons.lang3.StringUtils -import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.streaming.{ProcessingTime, Trigger} +import org.apache.spark.annotation.Evolving +import org.apache.spark.sql.streaming.Trigger import org.apache.spark.unsafe.types.CalendarInterval /** * A [[Trigger]] that continuously processes streaming data, asynchronously checkpointing at * the specified interval. */ -@InterfaceStability.Evolving +@Evolving case class ContinuousTrigger(intervalMs: Long) extends Trigger { require(intervalMs >= 0, "the interval of trigger should not be negative") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 1e076207bc607..6b4def35e1955 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.{Dataset, Encoder, TypedColumn} +import org.apache.spark.annotation.{Evolving, Experimental} +import org.apache.spark.sql.{Encoder, TypedColumn} import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression @@ -51,7 +51,7 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression * @since 1.6.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving abstract class Aggregator[-IN, BUF, OUT] extends Serializable { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index eb956c4b3e888..58a942afe28c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.ScalaUDF @@ -37,7 +37,7 @@ import org.apache.spark.sql.types.DataType * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class UserDefinedFunction protected[sql] ( f: AnyRef, dataType: DataType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index d50031bb20621..3d8d931af218e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions._ @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions._ * * @since 1.4.0 */ -@InterfaceStability.Stable +@Stable object Window { /** @@ -234,5 +234,5 @@ object Window { * * @since 1.4.0 */ -@InterfaceStability.Stable +@Stable class Window private() // So we can see Window in JavaDoc. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index b7f3000880aca..58227f075f2c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.{AnalysisException, Column} import org.apache.spark.sql.catalyst.expressions._ @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._ * * @since 1.4.0 */ -@InterfaceStability.Stable +@Stable class WindowSpec private[sql]( partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala index 3e637d594caf3..1cb579c4faa76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions.scalalang -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Evolving, Experimental} import org.apache.spark.sql._ import org.apache.spark.sql.execution.aggregate._ @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.aggregate._ * @since 2.0.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving // scalastyle:off object typed { // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index 4976b875fa298..4e8cb3a6ddd66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.ScalaUDAF @@ -28,7 +28,7 @@ import org.apache.spark.sql.types._ * * @since 1.5.0 */ -@InterfaceStability.Stable +@Stable abstract class UserDefinedAggregateFunction extends Serializable { /** @@ -159,7 +159,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { * * @since 1.5.0 */ -@InterfaceStability.Stable +@Stable abstract class MutableAggregationBuffer extends Row { /** Update the ith value of this buffer. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b2a6e22cbfc86..1cf2a30c0c8bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -23,7 +23,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag} import scala.util.Try import scala.util.control.NonFatal -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} @@ -68,7 +68,7 @@ import org.apache.spark.util.Utils * @groupname Ungrouped Support functions for DataFrames * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable // scalastyle:off object functions { // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index f67cc32c15dd2..ac07e1f6bb4f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkConf -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Experimental, Unstable} import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _} import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog.SessionCatalog @@ -50,7 +50,7 @@ import org.apache.spark.sql.util.ExecutionListenerManager * and `catalog` fields. Note that the state is cloned when `build` is called, and not before. */ @Experimental -@InterfaceStability.Unstable +@Unstable abstract class BaseSessionStateBuilder( val session: SparkSession, val parentState: Option[SessionState] = None) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index accbea41b9603..b34db581ca2c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.SparkContext -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Experimental, Unstable} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog._ @@ -124,7 +124,7 @@ private[sql] object SessionState { * Concrete implementation of a [[BaseSessionStateBuilder]]. */ @Experimental -@InterfaceStability.Unstable +@Unstable class SessionStateBuilder( session: SparkSession, parentState: Option[SessionState] = None) @@ -135,7 +135,7 @@ class SessionStateBuilder( /** * Session shared [[FunctionResourceLoader]]. */ -@InterfaceStability.Unstable +@Unstable class SessionResourceLoader(session: SparkSession) extends FunctionResourceLoader { override def loadResource(resource: FunctionResource): Unit = { resource.resourceType match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index f76c1fae562c6..230b43022b02b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -21,8 +21,7 @@ import java.sql.{Connection, Date, Timestamp} import org.apache.commons.lang3.StringUtils -import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} -import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions +import org.apache.spark.annotation.{DeveloperApi, Evolving, Since} import org.apache.spark.sql.types._ /** @@ -34,7 +33,7 @@ import org.apache.spark.sql.types._ * send a null value to the database. */ @DeveloperApi -@InterfaceStability.Evolving +@Evolving case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int) /** @@ -57,7 +56,7 @@ case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int) * for the given Catalyst type. */ @DeveloperApi -@InterfaceStability.Evolving +@Evolving abstract class JdbcDialect extends Serializable { /** * Check if this dialect instance can handle a certain jdbc url. @@ -197,7 +196,7 @@ abstract class JdbcDialect extends Serializable { * sure to register your dialects first. */ @DeveloperApi -@InterfaceStability.Evolving +@Evolving object JdbcDialects { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 354660e9d5943..61875931d226e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -17,7 +17,7 @@ package org.apache.spark -import org.apache.spark.annotation.{DeveloperApi, InterfaceStability} +import org.apache.spark.annotation.{DeveloperApi, Unstable} import org.apache.spark.sql.execution.SparkStrategy /** @@ -40,7 +40,7 @@ package object sql { * [[org.apache.spark.sql.sources]] */ @DeveloperApi - @InterfaceStability.Unstable + @Unstable type Strategy = SparkStrategy type DataFrame = Dataset[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index bdd8c4da6bd30..3f941cc6e1072 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines all the filters that we can push down to the data sources. @@ -28,7 +28,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable abstract class Filter { /** * List of columns that are referenced by this filter. @@ -48,7 +48,7 @@ abstract class Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class EqualTo(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -60,7 +60,7 @@ case class EqualTo(attribute: String, value: Any) extends Filter { * * @since 1.5.0 */ -@InterfaceStability.Stable +@Stable case class EqualNullSafe(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -71,7 +71,7 @@ case class EqualNullSafe(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class GreaterThan(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -82,7 +82,7 @@ case class GreaterThan(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -93,7 +93,7 @@ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class LessThan(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -104,7 +104,7 @@ case class LessThan(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class LessThanOrEqual(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -114,7 +114,7 @@ case class LessThanOrEqual(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class In(attribute: String, values: Array[Any]) extends Filter { override def hashCode(): Int = { var h = attribute.hashCode @@ -141,7 +141,7 @@ case class In(attribute: String, values: Array[Any]) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class IsNull(attribute: String) extends Filter { override def references: Array[String] = Array(attribute) } @@ -151,7 +151,7 @@ case class IsNull(attribute: String) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class IsNotNull(attribute: String) extends Filter { override def references: Array[String] = Array(attribute) } @@ -161,7 +161,7 @@ case class IsNotNull(attribute: String) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class And(left: Filter, right: Filter) extends Filter { override def references: Array[String] = left.references ++ right.references } @@ -171,7 +171,7 @@ case class And(left: Filter, right: Filter) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class Or(left: Filter, right: Filter) extends Filter { override def references: Array[String] = left.references ++ right.references } @@ -181,7 +181,7 @@ case class Or(left: Filter, right: Filter) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class Not(child: Filter) extends Filter { override def references: Array[String] = child.references } @@ -192,7 +192,7 @@ case class Not(child: Filter) extends Filter { * * @since 1.3.1 */ -@InterfaceStability.Stable +@Stable case class StringStartsWith(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) } @@ -203,7 +203,7 @@ case class StringStartsWith(attribute: String, value: String) extends Filter { * * @since 1.3.1 */ -@InterfaceStability.Stable +@Stable case class StringEndsWith(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) } @@ -214,7 +214,7 @@ case class StringEndsWith(attribute: String, value: String) extends Filter { * * @since 1.3.1 */ -@InterfaceStability.Stable +@Stable case class StringContains(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 6057a795c8bf5..6ad054c9f6403 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources -import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.annotation._ import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow @@ -35,7 +35,7 @@ import org.apache.spark.sql.types.StructType * * @since 1.5.0 */ -@InterfaceStability.Stable +@Stable trait DataSourceRegister { /** @@ -65,7 +65,7 @@ trait DataSourceRegister { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait RelationProvider { /** * Returns a new base relation with the given parameters. @@ -96,7 +96,7 @@ trait RelationProvider { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait SchemaRelationProvider { /** * Returns a new base relation with the given parameters and user defined schema. @@ -117,7 +117,7 @@ trait SchemaRelationProvider { * @since 2.0.0 */ @Experimental -@InterfaceStability.Unstable +@Unstable trait StreamSourceProvider { /** @@ -148,7 +148,7 @@ trait StreamSourceProvider { * @since 2.0.0 */ @Experimental -@InterfaceStability.Unstable +@Unstable trait StreamSinkProvider { def createSink( sqlContext: SQLContext, @@ -160,7 +160,7 @@ trait StreamSinkProvider { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait CreatableRelationProvider { /** * Saves a DataFrame to a destination (using data source-specific parameters) @@ -192,7 +192,7 @@ trait CreatableRelationProvider { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable abstract class BaseRelation { def sqlContext: SQLContext def schema: StructType @@ -242,7 +242,7 @@ abstract class BaseRelation { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait TableScan { def buildScan(): RDD[Row] } @@ -253,7 +253,7 @@ trait TableScan { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait PrunedScan { def buildScan(requiredColumns: Array[String]): RDD[Row] } @@ -271,7 +271,7 @@ trait PrunedScan { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait PrunedFilteredScan { def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] } @@ -293,7 +293,7 @@ trait PrunedFilteredScan { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait InsertableRelation { def insert(data: DataFrame, overwrite: Boolean): Unit } @@ -309,7 +309,7 @@ trait InsertableRelation { * @since 1.3.0 */ @Experimental -@InterfaceStability.Unstable +@Unstable trait CatalystScan { def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index bf6021e692382..e4250145a1ae2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -21,7 +21,7 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.execution.command.DDLUtils @@ -40,7 +40,7 @@ import org.apache.spark.util.Utils * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving final class DataStreamReader private[sql](sparkSession: SparkSession) extends Logging { /** * Specifies the input data source format. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index b36a8f3f6f15b..5733258a6b310 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -21,7 +21,7 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.api.java.function.VoidFunction2 import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes @@ -39,7 +39,7 @@ import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { private val df = ds.toDF() @@ -365,7 +365,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * * @since 2.4.0 */ - @InterfaceStability.Evolving + @Evolving def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = { this.source = "foreachBatch" if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null") @@ -386,7 +386,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * * @since 2.4.0 */ - @InterfaceStability.Evolving + @Evolving def foreachBatch(function: VoidFunction2[Dataset[T], java.lang.Long]): DataStreamWriter[T] = { foreachBatch((batchDs: Dataset[T], batchId: Long) => function.call(batchDs, batchId)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala index e9510c903acae..ab68eba81b843 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.streaming -import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.KeyValueGroupedDataset +import org.apache.spark.annotation.{Evolving, Experimental} import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState /** @@ -192,7 +191,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState * @since 2.2.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving trait GroupState[S] extends LogicalGroupState[S] { /** Whether state exists or not. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala index a033575d3d38f..236bd55ee6212 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala @@ -23,7 +23,7 @@ import scala.concurrent.duration.Duration import org.apache.commons.lang3.StringUtils -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.unsafe.types.CalendarInterval /** @@ -48,7 +48,7 @@ import org.apache.spark.unsafe.types.CalendarInterval * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving @deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0") case class ProcessingTime(intervalMs: Long) extends Trigger { require(intervalMs >= 0, "the interval of trigger should not be negative") @@ -59,7 +59,7 @@ case class ProcessingTime(intervalMs: Long) extends Trigger { * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving @deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0") object ProcessingTime { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala index f2dfbe42260d7..47ddc88e964e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.streaming import java.util.UUID -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.sql.SparkSession /** @@ -27,7 +27,7 @@ import org.apache.spark.sql.SparkSession * All these methods are thread-safe. * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving trait StreamingQuery { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala index 03aeb14de502a..646d6888b2a16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving /** * Exception that stopped a [[StreamingQuery]]. Use `cause` get the actual exception @@ -28,7 +28,7 @@ import org.apache.spark.annotation.InterfaceStability * @param endOffset Ending offset in json of the range of data in exception occurred * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving class StreamingQueryException private[sql]( private val queryDebugString: String, val message: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala index 6aa82b89ede81..916d6a0365965 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.streaming import java.util.UUID -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.scheduler.SparkListenerEvent /** @@ -28,7 +28,7 @@ import org.apache.spark.scheduler.SparkListenerEvent * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving abstract class StreamingQueryListener { import StreamingQueryListener._ @@ -67,14 +67,14 @@ abstract class StreamingQueryListener { * Companion object of [[StreamingQueryListener]] that defines the listener events. * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving object StreamingQueryListener { /** * Base type of [[StreamingQueryListener]] events * @since 2.0.0 */ - @InterfaceStability.Evolving + @Evolving trait Event extends SparkListenerEvent /** @@ -84,7 +84,7 @@ object StreamingQueryListener { * @param name User-specified name of the query, null if not specified. * @since 2.1.0 */ - @InterfaceStability.Evolving + @Evolving class QueryStartedEvent private[sql]( val id: UUID, val runId: UUID, @@ -95,7 +95,7 @@ object StreamingQueryListener { * @param progress The query progress updates. * @since 2.1.0 */ - @InterfaceStability.Evolving + @Evolving class QueryProgressEvent private[sql](val progress: StreamingQueryProgress) extends Event /** @@ -107,7 +107,7 @@ object StreamingQueryListener { * with an exception. Otherwise, it will be `None`. * @since 2.1.0 */ - @InterfaceStability.Evolving + @Evolving class QueryTerminatedEvent private[sql]( val id: UUID, val runId: UUID, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index cd52d991d55c9..d9ea8dc9d4ac9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -25,7 +25,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path import org.apache.spark.SparkException -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker @@ -42,7 +42,7 @@ import org.apache.spark.util.{Clock, SystemClock, Utils} * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Logging { private[sql] val stateStoreCoordinator = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala index a0c9bcc8929eb..9dc62b7aac891 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala @@ -22,7 +22,7 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving /** * Reports information about the instantaneous status of a streaming query. @@ -34,7 +34,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 2.1.0 */ -@InterfaceStability.Evolving +@Evolving class StreamingQueryStatus protected[sql]( val message: String, val isDataAvailable: Boolean, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala index f2173aa1e59c2..3cd6700efef5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala @@ -29,12 +29,12 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving /** * Information about updates made to stateful operators in a [[StreamingQuery]] during a trigger. */ -@InterfaceStability.Evolving +@Evolving class StateOperatorProgress private[sql]( val numRowsTotal: Long, val numRowsUpdated: Long, @@ -94,7 +94,7 @@ class StateOperatorProgress private[sql]( * @param sources detailed statistics on data being read from each of the streaming sources. * @since 2.1.0 */ -@InterfaceStability.Evolving +@Evolving class StreamingQueryProgress private[sql]( val id: UUID, val runId: UUID, @@ -165,7 +165,7 @@ class StreamingQueryProgress private[sql]( * Spark. * @since 2.1.0 */ -@InterfaceStability.Evolving +@Evolving class SourceProgress protected[sql]( val description: String, val startOffset: String, @@ -209,7 +209,7 @@ class SourceProgress protected[sql]( * @param description Description of the source corresponding to this status. * @since 2.1.0 */ -@InterfaceStability.Evolving +@Evolving class SinkProgress protected[sql]( val description: String) extends Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index 1310fdfa1356b..77ae047705de0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.util import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental} import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent} import org.apache.spark.sql.SparkSession @@ -36,7 +36,7 @@ import org.apache.spark.util.{ListenerBus, Utils} * multiple different threads. */ @Experimental -@InterfaceStability.Evolving +@Evolving trait QueryExecutionListener { /** @@ -73,7 +73,7 @@ trait QueryExecutionListener { * Manager for [[QueryExecutionListener]]. See `org.apache.spark.sql.SQLContext.listenerManager`. */ @Experimental -@InterfaceStability.Evolving +@Evolving // The `session` is used to indicate which session carries this listener manager, and we only // catch SQL executions which are launched by the same session. // The `loadExtensions` flag is used to indicate whether we should load the pre-defined, diff --git a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumn.java b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumn.java index bfe50c7810f73..fc2171dc99e4c 100644 --- a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumn.java +++ b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumn.java @@ -148,7 +148,7 @@ public TColumn() { super(); } - public TColumn(_Fields setField, Object value) { + public TColumn(TColumn._Fields setField, Object value) { super(setField, value); } diff --git a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumnValue.java b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumnValue.java index 44da2cdd089d6..8504c6d608d42 100644 --- a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumnValue.java +++ b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumnValue.java @@ -142,7 +142,7 @@ public TColumnValue() { super(); } - public TColumnValue(_Fields setField, Object value) { + public TColumnValue(TColumnValue._Fields setField, Object value) { super(setField, value); } diff --git a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TGetInfoValue.java b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TGetInfoValue.java index 4fe59b1c51462..fe2a211c46309 100644 --- a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TGetInfoValue.java +++ b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TGetInfoValue.java @@ -136,7 +136,7 @@ public TGetInfoValue() { super(); } - public TGetInfoValue(_Fields setField, Object value) { + public TGetInfoValue(TGetInfoValue._Fields setField, Object value) { super(setField, value); } diff --git a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeEntry.java b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeEntry.java index af7c0b4f15d95..d0d70c1279572 100644 --- a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeEntry.java +++ b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeEntry.java @@ -136,7 +136,7 @@ public TTypeEntry() { super(); } - public TTypeEntry(_Fields setField, Object value) { + public TTypeEntry(TTypeEntry._Fields setField, Object value) { super(setField, value); } diff --git a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeQualifierValue.java b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeQualifierValue.java index 8c40687a0aab7..a3e3829372276 100644 --- a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeQualifierValue.java +++ b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeQualifierValue.java @@ -112,7 +112,7 @@ public TTypeQualifierValue() { super(); } - public TTypeQualifierValue(_Fields setField, Object value) { + public TTypeQualifierValue(TTypeQualifierValue._Fields setField, Object value) { super(setField, value); } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/AbstractService.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/AbstractService.java index 9dd0efc03968d..7e557aeccf5b0 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/AbstractService.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/AbstractService.java @@ -36,7 +36,7 @@ public abstract class AbstractService implements Service { /** * Service state: initially {@link STATE#NOTINITED}. */ - private STATE state = STATE.NOTINITED; + private Service.STATE state = STATE.NOTINITED; /** * Service name. @@ -70,7 +70,7 @@ public AbstractService(String name) { } @Override - public synchronized STATE getServiceState() { + public synchronized Service.STATE getServiceState() { return state; } @@ -159,7 +159,7 @@ public long getStartTime() { * if the service state is different from * the desired state */ - private void ensureCurrentState(STATE currentState) { + private void ensureCurrentState(Service.STATE currentState) { ServiceOperations.ensureCurrentState(state, currentState); } @@ -173,7 +173,7 @@ private void ensureCurrentState(STATE currentState) { * @param newState * new service state */ - private void changeState(STATE newState) { + private void changeState(Service.STATE newState) { state = newState; // notify listeners for (ServiceStateChangeListener l : listeners) { diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/FilterService.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/FilterService.java index 5a508745414a7..15551da4785f6 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/FilterService.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/FilterService.java @@ -71,7 +71,7 @@ public HiveConf getHiveConf() { } @Override - public STATE getServiceState() { + public Service.STATE getServiceState() { return service.getServiceState(); } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 2882672f327c4..4f3914740ec20 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Experimental, Unstable} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.Analyzer import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener @@ -32,7 +32,7 @@ import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLo * Builder that produces a Hive-aware `SessionState`. */ @Experimental -@InterfaceStability.Unstable +@Unstable class HiveSessionStateBuilder(session: SparkSession, parentState: Option[SessionState] = None) extends BaseSessionStateBuilder(session, parentState) { From ce2cdc36e29742dda22200963cfd3f9876170455 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 19 Nov 2018 08:07:20 -0600 Subject: [PATCH 2086/2461] [SPARK-26043][CORE] Make SparkHadoopUtil private to Spark ## What changes were proposed in this pull request? Make SparkHadoopUtil private to Spark ## How was this patch tested? Existing tests. Closes #23066 from srowen/SPARK-26043. Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../apache/spark/deploy/SparkHadoopUtil.scala | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 5979151345415..217e5145f1c56 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -18,10 +18,9 @@ package org.apache.spark.deploy import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, File, IOException} -import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import java.text.DateFormat -import java.util.{Arrays, Comparator, Date, Locale} +import java.util.{Arrays, Date, Locale} import scala.collection.JavaConverters._ import scala.collection.immutable.Map @@ -38,17 +37,13 @@ import org.apache.hadoop.security.token.{Token, TokenIdentifier} import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdentifier import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging -import org.apache.spark.internal.config._ import org.apache.spark.util.Utils /** - * :: DeveloperApi :: * Contains util methods to interact with Hadoop from Spark. */ -@DeveloperApi -class SparkHadoopUtil extends Logging { +private[spark] class SparkHadoopUtil extends Logging { private val sparkConf = new SparkConf(false).loadFromSystemProperties(true) val conf: Configuration = newConfiguration(sparkConf) UserGroupInformation.setConfiguration(conf) @@ -274,11 +269,10 @@ class SparkHadoopUtil extends Logging { name.startsWith(prefix) && !name.endsWith(exclusionSuffix) } }) - Arrays.sort(fileStatuses, new Comparator[FileStatus] { - override def compare(o1: FileStatus, o2: FileStatus): Int = { + Arrays.sort(fileStatuses, + (o1: FileStatus, o2: FileStatus) => { Longs.compare(o1.getModificationTime, o2.getModificationTime) - } - }) + }) fileStatuses } catch { case NonFatal(e) => @@ -388,7 +382,7 @@ class SparkHadoopUtil extends Logging { } -object SparkHadoopUtil { +private[spark] object SparkHadoopUtil { private lazy val instance = new SparkHadoopUtil From b58b1fdf906d9609321824fc0bb892b986763b3e Mon Sep 17 00:00:00 2001 From: "Liu,Linhong" Date: Mon, 19 Nov 2018 22:09:44 +0800 Subject: [PATCH 2087/2461] [SPARK-26068][CORE] ChunkedByteBufferInputStream should handle empty chunks correctly ## What changes were proposed in this pull request? Empty chunk in ChunkedByteBuffer will truncate the ChunkedByteBufferInputStream. The detail reason is described in: https://issues.apache.org/jira/browse/SPARK-26068 ## How was this patch tested? Modified current UT to cover this case. Closes #23040 from LinhongLiu/fix-empty-chunked-byte-buffer. Lead-authored-by: Liu,Linhong Co-authored-by: Xianjin YE Signed-off-by: Wenchen Fan --- .../scala/org/apache/spark/util/io/ChunkedByteBuffer.scala | 3 ++- .../scala/org/apache/spark/io/ChunkedByteBufferSuite.scala | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 870830fff4c3e..128d6ff8cd746 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -222,7 +222,8 @@ private[spark] class ChunkedByteBufferInputStream( dispose: Boolean) extends InputStream { - private[this] var chunks = chunkedByteBuffer.getChunks().iterator + // Filter out empty chunks since `read()` assumes all chunks are non-empty. + private[this] var chunks = chunkedByteBuffer.getChunks().filter(_.hasRemaining).iterator private[this] var currentChunk: ByteBuffer = { if (chunks.hasNext) { chunks.next() diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala index ff117b1c21cb1..083c5e696b753 100644 --- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala @@ -90,7 +90,7 @@ class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext { val empty = ByteBuffer.wrap(Array.empty[Byte]) val bytes1 = ByteBuffer.wrap(Array.tabulate(256)(_.toByte)) val bytes2 = ByteBuffer.wrap(Array.tabulate(128)(_.toByte)) - val chunkedByteBuffer = new ChunkedByteBuffer(Array(empty, bytes1, bytes2)) + val chunkedByteBuffer = new ChunkedByteBuffer(Array(empty, bytes1, empty, bytes2)) assert(chunkedByteBuffer.size === bytes1.limit() + bytes2.limit()) val inputStream = chunkedByteBuffer.toInputStream(dispose = false) From 48ea64bf5bd4201c6a7adca67e20b75d23c223f6 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 19 Nov 2018 22:18:20 +0800 Subject: [PATCH 2088/2461] [SPARK-26112][SQL] Update since versions of new built-in functions. ## What changes were proposed in this pull request? The following 5 functions were removed from branch-2.4: - map_entries - map_filter - transform_values - transform_keys - map_zip_with We should update the since version to 3.0.0. ## How was this patch tested? Existing tests. Closes #23082 from ueshin/issues/SPARK-26112/since. Authored-by: Takuya UESHIN Signed-off-by: Wenchen Fan --- R/pkg/R/functions.R | 2 +- python/pyspark/sql/functions.py | 2 +- .../catalyst/expressions/collectionOperations.scala | 2 +- .../catalyst/expressions/higherOrderFunctions.scala | 12 ++++++------ .../main/scala/org/apache/spark/sql/functions.scala | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 9abb7fc1fadb4..f72645a257796 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3370,7 +3370,7 @@ setMethod("flatten", #' #' @rdname column_collection_functions #' @aliases map_entries map_entries,Column-method -#' @note map_entries since 2.4.0 +#' @note map_entries since 3.0.0 setMethod("map_entries", signature(x = "Column"), function(x) { diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e86749cc15c35..286ef219a69e9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2576,7 +2576,7 @@ def map_values(col): return Column(sc._jvm.functions.map_values(_to_java_column(col))) -@since(2.4) +@since(3.0) def map_entries(col): """ Collection function: Returns an unordered array of all entries in the given map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b24d7486f3454..3c260954a72a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -350,7 +350,7 @@ case class MapValues(child: Expression) > SELECT _FUNC_(map(1, 'a', 2, 'b')); [{"key":1,"value":"a"},{"key":2,"value":"b"}] """, - since = "2.4.0") + since = "3.0.0") case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(MapType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index b07d9466ba0d1..0b698f9290711 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -264,13 +264,13 @@ case class ArrayTransform( * Filters entries in a map using the provided function. */ @ExpressionDescription( -usage = "_FUNC_(expr, func) - Filters entries in a map using the function.", -examples = """ + usage = "_FUNC_(expr, func) - Filters entries in a map using the function.", + examples = """ Examples: > SELECT _FUNC_(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v); {1:0,3:-1} """, -since = "2.4.0") + since = "3.0.0") case class MapFilter( argument: Expression, function: Expression) @@ -504,7 +504,7 @@ case class ArrayAggregate( > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); {2:1,4:2,6:3} """, - since = "2.4.0") + since = "3.0.0") case class TransformKeys( argument: Expression, function: Expression) @@ -554,7 +554,7 @@ case class TransformKeys( > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); {1:2,2:4,3:6} """, - since = "2.4.0") + since = "3.0.0") case class TransformValues( argument: Expression, function: Expression) @@ -605,7 +605,7 @@ case class TransformValues( > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)); {1:"ax",2:"by"} """, - since = "2.4.0") + since = "3.0.0") case class MapZipWith(left: Expression, right: Expression, function: Expression) extends HigherOrderFunction with CodegenFallback { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 1cf2a30c0c8bd..efa8f8526387f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3731,7 +3731,7 @@ object functions { /** * Returns an unordered array of all entries in the given map. * @group collection_funcs - * @since 2.4.0 + * @since 3.0.0 */ def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) } From 35c55163555f3671edd02ed0543785af82de07ca Mon Sep 17 00:00:00 2001 From: Julien Date: Mon, 19 Nov 2018 22:24:53 +0800 Subject: [PATCH 2089/2461] [SPARK-26024][SQL] Update documentation for repartitionByRange Following [SPARK-26024](https://issues.apache.org/jira/browse/SPARK-26024), I noticed the number of elements in each partition after repartitioning using `df.repartitionByRange` can vary for the same setup: ```scala // Shuffle numbers from 0 to 1000, and make a DataFrame val df = Random.shuffle(0.to(1000)).toDF("val") // Repartition it using 3 partitions // Sum up number of elements in each partition, and collect it. // And do it several times for (i <- 0 to 9) { var counts = df.repartitionByRange(3, col("val")) .mapPartitions{part => Iterator(part.size)} .collect() println(counts.toList) } // -> the number of elements in each partition varies ``` This is expected as for performance reasons this method uses sampling to estimate the ranges (with default size of 100). Hence, the output may not be consistent, since sampling can return different values. But documentation was not mentioning it at all, leading to misunderstanding. ## What changes were proposed in this pull request? Update the documentation (Spark & PySpark) to mention the impact of `spark.sql.execution.rangeExchange.sampleSizePerPartition` on the resulting partitioned DataFrame. Closes #23025 from JulienPeloton/SPARK-26024. Authored-by: Julien Signed-off-by: Wenchen Fan --- R/pkg/R/DataFrame.R | 8 ++++++++ python/pyspark/sql/dataframe.py | 5 +++++ .../src/main/scala/org/apache/spark/sql/Dataset.scala | 11 +++++++++++ 3 files changed, 24 insertions(+) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index c99ad76f7643c..52e76570139e2 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -767,6 +767,14 @@ setMethod("repartition", #' using \code{spark.sql.shuffle.partitions} as number of partitions.} #'} #' +#' At least one partition-by expression must be specified. +#' When no explicit sort order is specified, "ascending nulls first" is assumed. +#' +#' Note that due to performance reasons this method uses sampling to estimate the ranges. +#' Hence, the output may not be consistent, since sampling can return different values. +#' The sample size can be controlled by the config +#' \code{spark.sql.execution.rangeExchange.sampleSizePerPartition}. +#' #' @param x a SparkDataFrame. #' @param numPartitions the number of partitions to use. #' @param col the column by which the range partitioning will be performed. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5748f6c6bd5eb..c4f4d81999544 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -732,6 +732,11 @@ def repartitionByRange(self, numPartitions, *cols): At least one partition-by expression must be specified. When no explicit sort order is specified, "ascending nulls first" is assumed. + Note that due to performance reasons this method uses sampling to estimate the ranges. + Hence, the output may not be consistent, since sampling can return different values. + The sample size can be controlled by the config + `spark.sql.execution.rangeExchange.sampleSizePerPartition`. + >>> df.repartitionByRange(2, "age").rdd.getNumPartitions() 2 >>> df.show() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f5caaf3f7fc87..0e77ec0406257 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2787,6 +2787,12 @@ class Dataset[T] private[sql]( * When no explicit sort order is specified, "ascending nulls first" is assumed. * Note, the rows are not sorted in each partition of the resulting Dataset. * + * + * Note that due to performance reasons this method uses sampling to estimate the ranges. + * Hence, the output may not be consistent, since sampling can return different values. + * The sample size can be controlled by the config + * `spark.sql.execution.rangeExchange.sampleSizePerPartition`. + * * @group typedrel * @since 2.3.0 */ @@ -2811,6 +2817,11 @@ class Dataset[T] private[sql]( * When no explicit sort order is specified, "ascending nulls first" is assumed. * Note, the rows are not sorted in each partition of the resulting Dataset. * + * Note that due to performance reasons this method uses sampling to estimate the ranges. + * Hence, the output may not be consistent, since sampling can return different values. + * The sample size can be controlled by the config + * `spark.sql.execution.rangeExchange.sampleSizePerPartition`. + * * @group typedrel * @since 2.3.0 */ From 219b037f05636a3a7c8116987c319773f4145b63 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 19 Nov 2018 22:42:24 +0800 Subject: [PATCH 2090/2461] [SPARK-26071][SQL] disallow map as map key ## What changes were proposed in this pull request? Due to implementation limitation, currently Spark can't compare or do equality check between map types. As a result, map values can't appear in EQUAL or comparison expressions, can't be grouping key, etc. The more important thing is, map loop up needs to do equality check of the map key, and thus can't support map as map key when looking up values from a map. Thus it's not useful to have map as map key. This PR proposes to stop users from creating maps using map type as key. The list of expressions that are updated: `CreateMap`, `MapFromArrays`, `MapFromEntries`, `MapConcat`, `TransformKeys`. I manually checked all the places that create `MapType`, and came up with this list. Note that, maps with map type key still exist, via reading from parquet files, converting from scala/java map, etc. This PR is not to completely forbid map as map key, but to avoid creating it by Spark itself. Motivation: when I was trying to fix the duplicate key problem, I found it's impossible to do it with map type map key. I think it's reasonable to avoid map type map key for builtin functions. ## How was this patch tested? updated test Closes #23045 from cloud-fan/map-key. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- docs/sql-migration-guide-upgrade.md | 6 +- .../expressions/collectionOperations.scala | 12 +- .../expressions/complexTypeCreator.scala | 14 ++- .../expressions/higherOrderFunctions.scala | 4 + .../spark/sql/catalyst/util/TypeUtils.scala | 10 +- .../CollectionExpressionsSuite.scala | 113 ++++++++++-------- .../expressions/ComplexTypeSuite.scala | 83 +++++++------ .../expressions/ExpressionEvalHelper.scala | 21 +++- .../HigherOrderFunctionsSuite.scala | 41 ++++--- .../inputs/typeCoercion/native/mapconcat.sql | 9 +- .../typeCoercion/native/mapconcat.sql.out | 19 ++- 11 files changed, 203 insertions(+), 129 deletions(-) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 50458e96f7c3f..07079d93f25b6 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -17,6 +17,8 @@ displayTitle: Spark SQL Upgrading Guide - The `ADD JAR` command previously returned a result set with the single value 0. It now returns an empty result set. + - In Spark version 2.4 and earlier, users can create map values with map type key via built-in function like `CreateMap`, `MapFromArrays`, etc. Since Spark 3.0, it's not allowed to create map values with map type key with these built-in functions. Users can still read map values with map type key from data source or Java/Scala collections, though they are not very useful. + ## Upgrading From Spark SQL 2.3 to 2.4 - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below. @@ -117,7 +119,7 @@ displayTitle: Spark SQL Upgrading Guide - Since Spark 2.4, Metadata files (e.g. Parquet summary files) and temporary files are not counted as data files when calculating table size during Statistics computation. - - Since Spark 2.4, empty strings are saved as quoted empty strings `""`. In version 2.3 and earlier, empty strings are equal to `null` values and do not reflect to any characters in saved CSV files. For example, the row of `"a", null, "", 1` was written as `a,,,1`. Since Spark 2.4, the same row is saved as `a,,"",1`. To restore the previous behavior, set the CSV option `emptyValue` to empty (not quoted) string. + - Since Spark 2.4, empty strings are saved as quoted empty strings `""`. In version 2.3 and earlier, empty strings are equal to `null` values and do not reflect to any characters in saved CSV files. For example, the row of `"a", null, "", 1` was written as `a,,,1`. Since Spark 2.4, the same row is saved as `a,,"",1`. To restore the previous behavior, set the CSV option `emptyValue` to empty (not quoted) string. - Since Spark 2.4, The LOAD DATA command supports wildcard `?` and `*`, which match any one character, and zero or more characters, respectively. Example: `LOAD DATA INPATH '/tmp/folder*/'` or `LOAD DATA INPATH '/tmp/part-?'`. Special Characters like `space` also now work in paths. Example: `LOAD DATA INPATH '/tmp/folder name/'`. @@ -303,7 +305,7 @@ displayTitle: Spark SQL Upgrading Guide ## Upgrading From Spark SQL 2.1 to 2.2 - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time-consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. - + - Since Spark 2.2.1 and 2.3.0, the schema is always inferred at runtime when the data source tables have the columns that exist in both partition schema and data schema. The inferred schema does not have the partitioned columns. When reading the table, Spark respects the partition values of these overlapping columns instead of the values stored in the data source files. In 2.2.0 and 2.1.x release, the inferred schema is partitioned but the data of the table is invisible to users (i.e., the result set is empty). - Since Spark 2.2, view definitions are stored in a different way from prior versions. This may cause Spark unable to read views created by prior versions. In such cases, you need to recreate the views using `ALTER VIEW AS` or `CREATE OR REPLACE VIEW AS` with newer Spark versions. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 3c260954a72a2..43116743e9952 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -521,13 +521,18 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpression { override def checkInputDataTypes(): TypeCheckResult = { - var funcName = s"function $prettyName" + val funcName = s"function $prettyName" if (children.exists(!_.dataType.isInstanceOf[MapType])) { TypeCheckResult.TypeCheckFailure( s"input to $funcName should all be of type map, but it's " + children.map(_.dataType.catalogString).mkString("[", ", ", "]")) } else { - TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName) + val sameTypeCheck = TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName) + if (sameTypeCheck.isFailure) { + sameTypeCheck + } else { + TypeUtils.checkForMapKeyType(dataType.keyType) + } } } @@ -740,7 +745,8 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { @transient override lazy val dataType: MapType = dataTypeDetails.get._1 override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { - case Some(_) => TypeCheckResult.TypeCheckSuccess + case Some((mapType, _, _)) => + TypeUtils.checkForMapKeyType(mapType.keyType) case None => TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " + s"${child.dataType.catalogString} type. $prettyName accepts only arrays of pair structs.") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 0361372b6b732..6b77996789f1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -161,11 +161,11 @@ case class CreateMap(children: Seq[Expression]) extends Expression { "The given values of function map should all be the same type, but they are " + values.map(_.dataType.catalogString).mkString("[", ", ", "]")) } else { - TypeCheckResult.TypeCheckSuccess + TypeUtils.checkForMapKeyType(dataType.keyType) } } - override def dataType: DataType = { + override def dataType: MapType = { MapType( keyType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(keys.map(_.dataType)) .getOrElse(StringType), @@ -224,6 +224,16 @@ case class MapFromArrays(left: Expression, right: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) + override def checkInputDataTypes(): TypeCheckResult = { + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + defaultCheck + } else { + val keyType = left.dataType.asInstanceOf[ArrayType].elementType + TypeUtils.checkForMapKeyType(keyType) + } + } + override def dataType: DataType = { MapType( keyType = left.dataType.asInstanceOf[ArrayType].elementType, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 0b698f9290711..8b31021866220 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -514,6 +514,10 @@ case class TransformKeys( override def dataType: DataType = MapType(function.dataType, valueType, valueContainsNull) + override def checkInputDataTypes(): TypeCheckResult = { + TypeUtils.checkForMapKeyType(function.dataType) + } + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): TransformKeys = { copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 76218b459ef0d..2a71fdb7592bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -46,12 +46,20 @@ object TypeUtils { if (TypeCoercion.haveSameType(types)) { TypeCheckResult.TypeCheckSuccess } else { - return TypeCheckResult.TypeCheckFailure( + TypeCheckResult.TypeCheckFailure( s"input to $caller should all be the same type, but it's " + types.map(_.catalogString).mkString("[", ", ", "]")) } } + def checkForMapKeyType(keyType: DataType): TypeCheckResult = { + if (keyType.existsRecursively(_.isInstanceOf[MapType])) { + TypeCheckResult.TypeCheckFailure("The key of map cannot be/contain map.") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + def getNumeric(t: DataType): Numeric[Any] = t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 2e0adbb465008..1415b7da6fca1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -25,6 +25,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.DateTimeTestUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -108,32 +109,28 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("Map Concat") { - val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType, + val m0 = Literal.create(create_map("a" -> "1", "b" -> "2"), MapType(StringType, StringType, valueContainsNull = false)) - val m1 = Literal.create(Map("c" -> "3", "a" -> "4"), MapType(StringType, StringType, + val m1 = Literal.create(create_map("c" -> "3", "a" -> "4"), MapType(StringType, StringType, valueContainsNull = false)) - val m2 = Literal.create(Map("d" -> "4", "e" -> "5"), MapType(StringType, StringType)) - val m3 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) - val m4 = Literal.create(Map("a" -> null, "c" -> "3"), MapType(StringType, StringType)) - val m5 = Literal.create(Map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType)) - val m6 = Literal.create(Map("a" -> null, "c" -> 3), MapType(StringType, IntegerType)) - val m7 = Literal.create(Map(List(1, 2) -> 1, List(3, 4) -> 2), + val m2 = Literal.create(create_map("d" -> "4", "e" -> "5"), MapType(StringType, StringType)) + val m3 = Literal.create(create_map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) + val m4 = Literal.create(create_map("a" -> null, "c" -> "3"), MapType(StringType, StringType)) + val m5 = Literal.create(create_map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType)) + val m6 = Literal.create(create_map("a" -> null, "c" -> 3), MapType(StringType, IntegerType)) + val m7 = Literal.create(create_map(List(1, 2) -> 1, List(3, 4) -> 2), MapType(ArrayType(IntegerType), IntegerType)) - val m8 = Literal.create(Map(List(5, 6) -> 3, List(1, 2) -> 4), + val m8 = Literal.create(create_map(List(5, 6) -> 3, List(1, 2) -> 4), MapType(ArrayType(IntegerType), IntegerType)) - val m9 = Literal.create(Map(Map(1 -> 2, 3 -> 4) -> 1, Map(5 -> 6, 7 -> 8) -> 2), - MapType(MapType(IntegerType, IntegerType), IntegerType)) - val m10 = Literal.create(Map(Map(9 -> 10, 11 -> 12) -> 3, Map(1 -> 2, 3 -> 4) -> 4), - MapType(MapType(IntegerType, IntegerType), IntegerType)) - val m11 = Literal.create(Map(1 -> "1", 2 -> "2"), MapType(IntegerType, StringType, + val m9 = Literal.create(create_map(1 -> "1", 2 -> "2"), MapType(IntegerType, StringType, valueContainsNull = false)) - val m12 = Literal.create(Map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType, + val m10 = Literal.create(create_map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType, valueContainsNull = false)) - val m13 = Literal.create(Map(1 -> 2, 3 -> 4), + val m11 = Literal.create(create_map(1 -> 2, 3 -> 4), MapType(IntegerType, IntegerType, valueContainsNull = false)) - val m14 = Literal.create(Map(5 -> 6), + val m12 = Literal.create(create_map(5 -> 6), MapType(IntegerType, IntegerType, valueContainsNull = false)) - val m15 = Literal.create(Map(7 -> null), + val m13 = Literal.create(create_map(7 -> null), MapType(IntegerType, IntegerType, valueContainsNull = true)) val mNull = Literal.create(null, MapType(StringType, StringType)) @@ -147,7 +144,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper // maps with no overlap checkEvaluation(MapConcat(Seq(m0, m2)), - Map("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5")) + create_map("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5")) // 3 maps checkEvaluation(MapConcat(Seq(m0, m1, m2)), @@ -174,7 +171,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ) // keys that are primitive - checkEvaluation(MapConcat(Seq(m11, m12)), + checkEvaluation(MapConcat(Seq(m9, m10)), ( Array(1, 2, 3, 4), // keys Array("1", "2", "3", "4") // values @@ -189,20 +186,11 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ) ) - // keys that are maps, with overlap - checkEvaluation(MapConcat(Seq(m9, m10)), - ( - Array(Map(1 -> 2, 3 -> 4), Map(5 -> 6, 7 -> 8), Map(9 -> 10, 11 -> 12), - Map(1 -> 2, 3 -> 4)), // keys - Array(1, 2, 3, 4) // values - ) - ) - // both keys and value are primitive and valueContainsNull = false - checkEvaluation(MapConcat(Seq(m13, m14)), Map(1 -> 2, 3 -> 4, 5 -> 6)) + checkEvaluation(MapConcat(Seq(m11, m12)), create_map(1 -> 2, 3 -> 4, 5 -> 6)) // both keys and value are primitive and valueContainsNull = true - checkEvaluation(MapConcat(Seq(m13, m15)), Map(1 -> 2, 3 -> 4, 7 -> null)) + checkEvaluation(MapConcat(Seq(m11, m13)), create_map(1 -> 2, 3 -> 4, 7 -> null)) // null map checkEvaluation(MapConcat(Seq(m0, mNull)), null) @@ -211,7 +199,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapConcat(Seq(mNull)), null) // single map - checkEvaluation(MapConcat(Seq(m0)), Map("a" -> "1", "b" -> "2")) + checkEvaluation(MapConcat(Seq(m0)), create_map("a" -> "1", "b" -> "2")) // no map checkEvaluation(MapConcat(Seq.empty), Map.empty) @@ -245,12 +233,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper assert(MapConcat(Seq(m1, mNull)).nullable) val mapConcat = MapConcat(Seq( - Literal.create(Map(Seq(1, 2) -> Seq("a", "b")), + Literal.create(create_map(Seq(1, 2) -> Seq("a", "b")), MapType( ArrayType(IntegerType, containsNull = false), ArrayType(StringType, containsNull = false), valueContainsNull = false)), - Literal.create(Map(Seq(3, 4, null) -> Seq("c", "d", null), Seq(6) -> null), + Literal.create(create_map(Seq(3, 4, null) -> Seq("c", "d", null), Seq(6) -> null), MapType( ArrayType(IntegerType, containsNull = true), ArrayType(StringType, containsNull = true), @@ -264,6 +252,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper Seq(1, 2) -> Seq("a", "b"), Seq(3, 4, null) -> Seq("c", "d", null), Seq(6) -> null)) + + // map key can't be map + val mapOfMap = Literal.create(Map(Map(1 -> 2, 3 -> 4) -> 1, Map(5 -> 6, 7 -> 8) -> 2), + MapType(MapType(IntegerType, IntegerType), IntegerType)) + val mapOfMap2 = Literal.create(Map(Map(9 -> 10, 11 -> 12) -> 3, Map(1 -> 2, 3 -> 4) -> 4), + MapType(MapType(IntegerType, IntegerType), IntegerType)) + val map = MapConcat(Seq(mapOfMap, mapOfMap2)) + map.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key") + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("The key of map cannot be/contain map")) + } } test("MapFromEntries") { @@ -274,20 +274,20 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper StructField("b", valueType))), true) } - def r(values: Any*): InternalRow = create_row(values: _*) + def row(values: Any*): InternalRow = create_row(values: _*) // Primitive-type keys and values val aiType = arrayType(IntegerType, IntegerType) - val ai0 = Literal.create(Seq(r(1, 10), r(2, 20), r(3, 20)), aiType) - val ai1 = Literal.create(Seq(r(1, null), r(2, 20), r(3, null)), aiType) + val ai0 = Literal.create(Seq(row(1, 10), row(2, 20), row(3, 20)), aiType) + val ai1 = Literal.create(Seq(row(1, null), row(2, 20), row(3, null)), aiType) val ai2 = Literal.create(Seq.empty, aiType) val ai3 = Literal.create(null, aiType) - val ai4 = Literal.create(Seq(r(1, 10), r(1, 20)), aiType) - val ai5 = Literal.create(Seq(r(1, 10), r(null, 20)), aiType) - val ai6 = Literal.create(Seq(null, r(2, 20), null), aiType) + val ai4 = Literal.create(Seq(row(1, 10), row(1, 20)), aiType) + val ai5 = Literal.create(Seq(row(1, 10), row(null, 20)), aiType) + val ai6 = Literal.create(Seq(null, row(2, 20), null), aiType) - checkEvaluation(MapFromEntries(ai0), Map(1 -> 10, 2 -> 20, 3 -> 20)) - checkEvaluation(MapFromEntries(ai1), Map(1 -> null, 2 -> 20, 3 -> null)) + checkEvaluation(MapFromEntries(ai0), create_map(1 -> 10, 2 -> 20, 3 -> 20)) + checkEvaluation(MapFromEntries(ai1), create_map(1 -> null, 2 -> 20, 3 -> null)) checkEvaluation(MapFromEntries(ai2), Map.empty) checkEvaluation(MapFromEntries(ai3), null) checkEvaluation(MapKeys(MapFromEntries(ai4)), Seq(1, 1)) @@ -298,23 +298,36 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper // Non-primitive-type keys and values val asType = arrayType(StringType, StringType) - val as0 = Literal.create(Seq(r("a", "aa"), r("b", "bb"), r("c", "bb")), asType) - val as1 = Literal.create(Seq(r("a", null), r("b", "bb"), r("c", null)), asType) + val as0 = Literal.create(Seq(row("a", "aa"), row("b", "bb"), row("c", "bb")), asType) + val as1 = Literal.create(Seq(row("a", null), row("b", "bb"), row("c", null)), asType) val as2 = Literal.create(Seq.empty, asType) val as3 = Literal.create(null, asType) - val as4 = Literal.create(Seq(r("a", "aa"), r("a", "bb")), asType) - val as5 = Literal.create(Seq(r("a", "aa"), r(null, "bb")), asType) - val as6 = Literal.create(Seq(null, r("b", "bb"), null), asType) + val as4 = Literal.create(Seq(row("a", "aa"), row("a", "bb")), asType) + val as5 = Literal.create(Seq(row("a", "aa"), row(null, "bb")), asType) + val as6 = Literal.create(Seq(null, row("b", "bb"), null), asType) - checkEvaluation(MapFromEntries(as0), Map("a" -> "aa", "b" -> "bb", "c" -> "bb")) - checkEvaluation(MapFromEntries(as1), Map("a" -> null, "b" -> "bb", "c" -> null)) + checkEvaluation(MapFromEntries(as0), create_map("a" -> "aa", "b" -> "bb", "c" -> "bb")) + checkEvaluation(MapFromEntries(as1), create_map("a" -> null, "b" -> "bb", "c" -> null)) checkEvaluation(MapFromEntries(as2), Map.empty) checkEvaluation(MapFromEntries(as3), null) checkEvaluation(MapKeys(MapFromEntries(as4)), Seq("a", "a")) + checkEvaluation(MapFromEntries(as6), null) + + // Map key can't be null checkExceptionInExpression[RuntimeException]( MapFromEntries(as5), "The first field from a struct (key) can't be null.") - checkEvaluation(MapFromEntries(as6), null) + + // map key can't be map + val structOfMap = row(create_map(1 -> 1), 1) + val map = MapFromEntries(Literal.create( + Seq(structOfMap), + arrayType(keyType = MapType(IntegerType, IntegerType), valueType = IntegerType))) + map.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key") + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("The key of map cannot be/contain map")) + } } test("Sort Array") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 77aaf55480ec2..d95f42e04e37c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types._ @@ -158,40 +158,32 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { keys.zip(values).flatMap { case (k, v) => Seq(k, v) } } - def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { - // catalyst map is order-sensitive, so we create ListMap here to preserve the elements order. - scala.collection.immutable.ListMap(keys.zip(values): _*) - } - val intSeq = Seq(5, 10, 15, 20, 25) val longSeq = intSeq.map(_.toLong) val strSeq = intSeq.map(_.toString) + checkEvaluation(CreateMap(Nil), Map.empty) checkEvaluation( CreateMap(interlace(intSeq.map(Literal(_)), longSeq.map(Literal(_)))), - createMap(intSeq, longSeq)) + create_map(intSeq, longSeq)) checkEvaluation( CreateMap(interlace(strSeq.map(Literal(_)), longSeq.map(Literal(_)))), - createMap(strSeq, longSeq)) + create_map(strSeq, longSeq)) checkEvaluation( CreateMap(interlace(longSeq.map(Literal(_)), strSeq.map(Literal(_)))), - createMap(longSeq, strSeq)) + create_map(longSeq, strSeq)) val strWithNull = strSeq.drop(1).map(Literal(_)) :+ Literal.create(null, StringType) checkEvaluation( CreateMap(interlace(intSeq.map(Literal(_)), strWithNull)), - createMap(intSeq, strWithNull.map(_.value))) - intercept[RuntimeException] { - checkEvaluationWithoutCodegen( - CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), - null, null) - } - intercept[RuntimeException] { - checkEvaluationWithUnsafeProjection( - CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), - null, null) - } + create_map(intSeq, strWithNull.map(_.value))) + // Map key can't be null + checkExceptionInExpression[RuntimeException]( + CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), + "Cannot use null as map key") + + // ArrayType map key and value val map = CreateMap(Seq( Literal.create(intSeq, ArrayType(IntegerType, containsNull = false)), Literal.create(strSeq, ArrayType(StringType, containsNull = false)), @@ -202,15 +194,21 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { ArrayType(IntegerType, containsNull = true), ArrayType(StringType, containsNull = true), valueContainsNull = false)) - checkEvaluation(map, createMap(Seq(intSeq, intSeq :+ null), Seq(strSeq, strSeq :+ null))) + checkEvaluation(map, create_map(intSeq -> strSeq, (intSeq :+ null) -> (strSeq :+ null))) + + // map key can't be map + val map2 = CreateMap(Seq( + Literal.create(create_map(1 -> 1), MapType(IntegerType, IntegerType)), + Literal(1) + )) + map2.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key") + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("The key of map cannot be/contain map")) + } } test("MapFromArrays") { - def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { - // catalyst map is order-sensitive, so we create ListMap here to preserve the elements order. - scala.collection.immutable.ListMap(keys.zip(values): _*) - } - val intSeq = Seq(5, 10, 15, 20, 25) val longSeq = intSeq.map(_.toLong) val strSeq = intSeq.map(_.toString) @@ -228,24 +226,33 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val nullArray = Literal.create(null, ArrayType(StringType, false)) - checkEvaluation(MapFromArrays(intArray, longArray), createMap(intSeq, longSeq)) - checkEvaluation(MapFromArrays(intArray, strArray), createMap(intSeq, strSeq)) - checkEvaluation(MapFromArrays(integerArray, strArray), createMap(integerSeq, strSeq)) + checkEvaluation(MapFromArrays(intArray, longArray), create_map(intSeq, longSeq)) + checkEvaluation(MapFromArrays(intArray, strArray), create_map(intSeq, strSeq)) + checkEvaluation(MapFromArrays(integerArray, strArray), create_map(integerSeq, strSeq)) checkEvaluation( - MapFromArrays(strArray, intWithNullArray), createMap(strSeq, intWithNullSeq)) + MapFromArrays(strArray, intWithNullArray), create_map(strSeq, intWithNullSeq)) checkEvaluation( - MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq)) + MapFromArrays(strArray, longWithNullArray), create_map(strSeq, longWithNullSeq)) checkEvaluation( - MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq)) + MapFromArrays(strArray, longWithNullArray), create_map(strSeq, longWithNullSeq)) checkEvaluation(MapFromArrays(nullArray, nullArray), null) - intercept[RuntimeException] { - checkEvaluation(MapFromArrays(intWithNullArray, strArray), null) - } - intercept[RuntimeException] { - checkEvaluation( - MapFromArrays(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null) + // Map key can't be null + checkExceptionInExpression[RuntimeException]( + MapFromArrays(intWithNullArray, strArray), + "Cannot use null as map key") + + // map key can't be map + val arrayOfMap = Seq(create_map(1 -> "a", 2 -> "b")) + val map = MapFromArrays( + Literal.create(arrayOfMap, ArrayType(MapType(IntegerType, StringType))), + Literal.create(Seq(1), ArrayType(IntegerType)) + ) + map.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key") + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("The key of map cannot be/contain map")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index da18475276a13..eb33325d0b31a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} -import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -48,6 +48,25 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst)) } + // Currently MapData just stores the key and value arrays. Its equality is not well implemented, + // as the order of the map entries should not matter for equality. This method creates MapData + // with the entries ordering preserved, so that we can deterministically test expressions with + // map input/output. + protected def create_map(entries: (_, _)*): ArrayBasedMapData = { + create_map(entries.map(_._1), entries.map(_._2)) + } + + protected def create_map(keys: Seq[_], values: Seq[_]): ArrayBasedMapData = { + assert(keys.length == values.length) + val keyArray = CatalystTypeConverters + .convertToCatalyst(keys) + .asInstanceOf[ArrayData] + val valueArray = CatalystTypeConverters + .convertToCatalyst(values) + .asInstanceOf[ArrayData] + new ArrayBasedMapData(keyArray, valueArray) + } + private def prepareEvaluation(expression: Expression): Expression = { val serializer = new JavaSerializer(new SparkConf()).newInstance val resolver = ResolveTimeZone(new SQLConf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index e13f4d98295be..66bf18af95799 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.ArrayBasedMapData import org.apache.spark.sql.types._ @@ -310,13 +311,13 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper test("TransformKeys") { val ai0 = Literal.create( - Map(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4), + create_map(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4), MapType(IntegerType, IntegerType, valueContainsNull = false)) val ai1 = Literal.create( Map.empty[Int, Int], MapType(IntegerType, IntegerType, valueContainsNull = true)) val ai2 = Literal.create( - Map(1 -> 1, 2 -> null, 3 -> 3), + create_map(1 -> 1, 2 -> null, 3 -> 3), MapType(IntegerType, IntegerType, valueContainsNull = true)) val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) @@ -324,26 +325,27 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val plusValue: (Expression, Expression) => Expression = (k, v) => k + v val modKey: (Expression, Expression) => Expression = (k, v) => k % 3 - checkEvaluation(transformKeys(ai0, plusOne), Map(2 -> 1, 3 -> 2, 4 -> 3, 5 -> 4)) - checkEvaluation(transformKeys(ai0, plusValue), Map(2 -> 1, 4 -> 2, 6 -> 3, 8 -> 4)) + checkEvaluation(transformKeys(ai0, plusOne), create_map(2 -> 1, 3 -> 2, 4 -> 3, 5 -> 4)) + checkEvaluation(transformKeys(ai0, plusValue), create_map(2 -> 1, 4 -> 2, 6 -> 3, 8 -> 4)) checkEvaluation( - transformKeys(transformKeys(ai0, plusOne), plusValue), Map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4)) + transformKeys(transformKeys(ai0, plusOne), plusValue), + create_map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4)) checkEvaluation(transformKeys(ai0, modKey), ArrayBasedMapData(Array(1, 2, 0, 1), Array(1, 2, 3, 4))) checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int]) checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int]) checkEvaluation( transformKeys(transformKeys(ai1, plusOne), plusValue), Map.empty[Int, Int]) - checkEvaluation(transformKeys(ai2, plusOne), Map(2 -> 1, 3 -> null, 4 -> 3)) + checkEvaluation(transformKeys(ai2, plusOne), create_map(2 -> 1, 3 -> null, 4 -> 3)) checkEvaluation( - transformKeys(transformKeys(ai2, plusOne), plusOne), Map(3 -> 1, 4 -> null, 5 -> 3)) + transformKeys(transformKeys(ai2, plusOne), plusOne), create_map(3 -> 1, 4 -> null, 5 -> 3)) checkEvaluation(transformKeys(ai3, plusOne), null) val as0 = Literal.create( - Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), + create_map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), MapType(StringType, StringType, valueContainsNull = false)) val as1 = Literal.create( - Map("a" -> "xy", "bb" -> "yz", "ccc" -> null), + create_map("a" -> "xy", "bb" -> "yz", "ccc" -> null), MapType(StringType, StringType, valueContainsNull = true)) val as2 = Literal.create(null, MapType(StringType, StringType, valueContainsNull = false)) @@ -355,26 +357,35 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper (k, v) => Length(k) + 1 checkEvaluation( - transformKeys(as0, concatValue), Map("axy" -> "xy", "bbyz" -> "yz", "ccczx" -> "zx")) + transformKeys(as0, concatValue), create_map("axy" -> "xy", "bbyz" -> "yz", "ccczx" -> "zx")) checkEvaluation( transformKeys(transformKeys(as0, concatValue), concatValue), - Map("axyxy" -> "xy", "bbyzyz" -> "yz", "ccczxzx" -> "zx")) + create_map("axyxy" -> "xy", "bbyzyz" -> "yz", "ccczxzx" -> "zx")) checkEvaluation(transformKeys(as3, concatValue), Map.empty[String, String]) checkEvaluation( transformKeys(transformKeys(as3, concatValue), convertKeyToKeyLength), Map.empty[Int, String]) checkEvaluation(transformKeys(as0, convertKeyToKeyLength), - Map(2 -> "xy", 3 -> "yz", 4 -> "zx")) + create_map(2 -> "xy", 3 -> "yz", 4 -> "zx")) checkEvaluation(transformKeys(as1, convertKeyToKeyLength), - Map(2 -> "xy", 3 -> "yz", 4 -> null)) + create_map(2 -> "xy", 3 -> "yz", 4 -> null)) checkEvaluation(transformKeys(as2, convertKeyToKeyLength), null) checkEvaluation(transformKeys(as3, convertKeyToKeyLength), Map.empty[Int, String]) val ax0 = Literal.create( - Map(1 -> "x", 2 -> "y", 3 -> "z"), + create_map(1 -> "x", 2 -> "y", 3 -> "z"), MapType(IntegerType, StringType, valueContainsNull = false)) - checkEvaluation(transformKeys(ax0, plusOne), Map(2 -> "x", 3 -> "y", 4 -> "z")) + checkEvaluation(transformKeys(ax0, plusOne), create_map(2 -> "x", 3 -> "y", 4 -> "z")) + + // map key can't be map + val makeMap: (Expression, Expression) => Expression = (k, v) => CreateMap(Seq(k, v)) + val map = transformKeys(ai0, makeMap) + map.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key") + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("The key of map cannot be/contain map")) + } } test("TransformValues") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql index 69da67fc66fc0..60895020fcc83 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql @@ -13,7 +13,6 @@ CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( map('a', 'b'), map('c', 'd'), map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')), map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)), - map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)), map('a', 1), map('c', 2), map(1, 'a'), map(2, 'c') ) AS various_maps ( @@ -31,7 +30,6 @@ CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( string_map1, string_map2, array_map1, array_map2, struct_map1, struct_map2, - map_map1, map_map2, string_int_map1, string_int_map2, int_string_map1, int_string_map2 ); @@ -51,7 +49,6 @@ SELECT map_concat(string_map1, string_map2) string_map, map_concat(array_map1, array_map2) array_map, map_concat(struct_map1, struct_map2) struct_map, - map_concat(map_map1, map_map2) map_map, map_concat(string_int_map1, string_int_map2) string_int_map, map_concat(int_string_map1, int_string_map2) int_string_map FROM various_maps; @@ -71,7 +68,7 @@ FROM various_maps; -- Concatenate map of incompatible types 1 SELECT - map_concat(tinyint_map1, map_map2) tm_map + map_concat(tinyint_map1, array_map1) tm_map FROM various_maps; -- Concatenate map of incompatible types 2 @@ -86,10 +83,10 @@ FROM various_maps; -- Concatenate map of incompatible types 4 SELECT - map_concat(map_map1, array_map2) ma_map + map_concat(struct_map1, array_map2) ma_map FROM various_maps; -- Concatenate map of incompatible types 5 SELECT - map_concat(map_map1, struct_map2) ms_map + map_concat(int_map1, array_map2) ms_map FROM various_maps; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out index efc88e47209a6..79e00860e4c05 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out @@ -18,7 +18,6 @@ CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( map('a', 'b'), map('c', 'd'), map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')), map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)), - map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)), map('a', 1), map('c', 2), map(1, 'a'), map(2, 'c') ) AS various_maps ( @@ -36,7 +35,6 @@ CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( string_map1, string_map2, array_map1, array_map2, struct_map1, struct_map2, - map_map1, map_map2, string_int_map1, string_int_map2, int_string_map1, int_string_map2 ) @@ -61,14 +59,13 @@ SELECT map_concat(string_map1, string_map2) string_map, map_concat(array_map1, array_map2) array_map, map_concat(struct_map1, struct_map2) struct_map, - map_concat(map_map1, map_map2) map_map, map_concat(string_int_map1, string_int_map2) string_int_map, map_concat(int_string_map1, int_string_map2) int_string_map FROM various_maps -- !query 1 schema -struct,tinyint_map:map,smallint_map:map,int_map:map,bigint_map:map,decimal_map:map,float_map:map,double_map:map,date_map:map,timestamp_map:map,string_map:map,array_map:map,array>,struct_map:map,struct>,map_map:map,map>,string_int_map:map,int_string_map:map> +struct,tinyint_map:map,smallint_map:map,int_map:map,bigint_map:map,decimal_map:map,float_map:map,double_map:map,date_map:map,timestamp_map:map,string_map:map,array_map:map,array>,struct_map:map,struct>,string_int_map:map,int_string_map:map> -- !query 1 output -{false:true,true:false} {1:2,3:4} {1:2,3:4} {4:6,7:8} {6:7,8:9} {9223372036854775808:9223372036854775809,9223372036854775809:9223372036854775808} {1.0:2.0,3.0:4.0} {1.0:2.0,3.0:4.0} {2016-03-12:2016-03-11,2016-03-14:2016-03-13} {2016-11-11 20:54:00.0:2016-11-09 20:54:00.0,2016-11-15 20:54:00.0:2016-11-12 20:54:00.0} {"a":"b","c":"d"} {["a","b"]:["c","d"],["e"]:["f"]} {{"col1":"a","col2":1}:{"col1":"b","col2":2},{"col1":"c","col2":3}:{"col1":"d","col2":4}} {{"a":1}:{"b":2},{"c":3}:{"d":4}} {"a":1,"c":2} {1:"a",2:"c"} +{false:true,true:false} {1:2,3:4} {1:2,3:4} {4:6,7:8} {6:7,8:9} {9223372036854775808:9223372036854775809,9223372036854775809:9223372036854775808} {1.0:2.0,3.0:4.0} {1.0:2.0,3.0:4.0} {2016-03-12:2016-03-11,2016-03-14:2016-03-13} {2016-11-11 20:54:00.0:2016-11-09 20:54:00.0,2016-11-15 20:54:00.0:2016-11-12 20:54:00.0} {"a":"b","c":"d"} {["a","b"]:["c","d"],["e"]:["f"]} {{"col1":"a","col2":1}:{"col1":"b","col2":2},{"col1":"c","col2":3}:{"col1":"d","col2":4}} {"a":1,"c":2} {1:"a",2:"c"} -- !query 2 @@ -91,13 +88,13 @@ struct,si_map:map,ib_map:map -- !query 3 output org.apache.spark.sql.AnalysisException -cannot resolve 'map_concat(various_maps.`tinyint_map1`, various_maps.`map_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,map>]; line 2 pos 4 +cannot resolve 'map_concat(various_maps.`tinyint_map1`, various_maps.`array_map1`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,array>]; line 2 pos 4 -- !query 4 @@ -124,21 +121,21 @@ cannot resolve 'map_concat(various_maps.`int_map1`, various_maps.`struct_map2`)' -- !query 6 SELECT - map_concat(map_map1, array_map2) ma_map + map_concat(struct_map1, array_map2) ma_map FROM various_maps -- !query 6 schema struct<> -- !query 6 output org.apache.spark.sql.AnalysisException -cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`array_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map,map>, map,array>]; line 2 pos 4 +cannot resolve 'map_concat(various_maps.`struct_map1`, various_maps.`array_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map,struct>, map,array>]; line 2 pos 4 -- !query 7 SELECT - map_concat(map_map1, struct_map2) ms_map + map_concat(int_map1, array_map2) ms_map FROM various_maps -- !query 7 schema struct<> -- !query 7 output org.apache.spark.sql.AnalysisException -cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`struct_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map,map>, map,struct>]; line 2 pos 4 +cannot resolve 'map_concat(various_maps.`int_map1`, various_maps.`array_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,array>]; line 2 pos 4 From 32365f8177f913533d348f7079605a282f1014ef Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 19 Nov 2018 09:16:42 -0600 Subject: [PATCH 2091/2461] [SPARK-26090][CORE][SQL][ML] Resolve most miscellaneous deprecation and build warnings for Spark 3 ## What changes were proposed in this pull request? The build has a lot of deprecation warnings. Some are new in Scala 2.12 and Java 11. We've fixed some, but I wanted to take a pass at fixing lots of easy miscellaneous ones here. They're too numerous and small to list here; see the pull request. Some highlights: - `BeanInfo` is deprecated in 2.12, and BeanInfo classes are pretty ancient in Java. Instead, case classes can explicitly declare getters - Eta expansion of zero-arg methods; foo() becomes () => foo() in many cases - Floating-point Range is inexact and deprecated, like 0.0 to 100.0 by 1.0 - finalize() is finally deprecated (just needs to be suppressed) - StageInfo.attempId was deprecated and easiest to remove here I'm not now going to touch some chunks of deprecation warnings: - Parquet deprecations - Hive deprecations (particularly serde2 classes) - Deprecations in generated code (mostly Thriftserver CLI) - ProcessingTime deprecations (we may need to revive this class as internal) - many MLlib deprecations because they concern methods that may be removed anyway - a few Kinesis deprecations I couldn't figure out - Mesos get/setRole, which I don't know well - Kafka/ZK deprecations (e.g. poll()) - Kinesis - a few other ones that will probably resolve by deleting a deprecated method ## How was this patch tested? Existing tests, including manual testing with the 2.11 build and Java 11. Closes #23065 from srowen/SPARK-26090. Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../spark/util/kvstore/LevelDBIterator.java | 1 + common/unsafe/pom.xml | 5 ++++ .../types/UTF8StringPropertyCheckSuite.scala | 4 +-- .../spark/io/NioBufferedFileInputStream.java | 3 +- .../scala/org/apache/spark/SparkContext.scala | 4 ++- .../org/apache/spark/api/r/RBackend.scala | 8 ++--- .../HadoopDelegationTokenManager.scala | 4 ++- .../org/apache/spark/executor/Executor.scala | 7 +++-- .../apache/spark/scheduler/StageInfo.scala | 4 ++- .../rdd/ParallelCollectionSplitSuite.scala | 4 +-- .../serializer/KryoSerializerSuite.scala | 10 +++---- .../spark/status/AppStatusListenerSuite.scala | 17 +++++------ .../ExternalAppendOnlyMapSuite.scala | 1 + .../org/apache/spark/sql/avro/AvroSuite.scala | 20 ++++++------- .../sql/kafka010/KafkaContinuousTest.scala | 4 ++- .../streaming/kafka010/ConsumerStrategy.scala | 6 ++-- .../spark/ml/feature/LabeledPoint.scala | 8 +++-- .../ml/feature/QuantileDiscretizer.scala | 10 ++----- .../spark/mllib/regression/LabeledPoint.scala | 8 +++-- .../spark/mllib/stat/test/StreamingTest.scala | 5 ++-- .../apache/spark/ml/feature/DCTSuite.scala | 8 ++--- .../apache/spark/ml/feature/NGramSuite.scala | 9 +++--- .../ml/feature/QuantileDiscretizerSuite.scala | 12 ++++---- .../spark/ml/feature/TokenizerSuite.scala | 8 ++--- .../spark/ml/feature/VectorIndexerSuite.scala | 9 +++--- .../spark/ml/recommendation/ALSSuite.scala | 2 +- pom.xml | 5 ++++ project/MimaExcludes.scala | 5 ++++ .../k8s/submit/KubernetesDriverBuilder.scala | 2 +- .../k8s/KubernetesExecutorBuilder.scala | 2 +- .../spark/deploy/k8s/submit/ClientSuite.scala | 3 +- .../k8s/ExecutorPodsAllocatorSuite.scala | 2 +- .../deploy/yarn/YarnAllocatorSuite.scala | 3 +- .../util/HyperLogLogPlusPlusHelper.scala | 2 +- .../analysis/AnalysisErrorSuite.scala | 12 ++++---- .../sql/streaming/StreamingQueryManager.scala | 2 +- .../sql/JavaBeanDeserializationSuite.java | 17 ++++++++++- .../sources/v2/JavaRangeInputPartition.java | 30 +++++++++++++++++++ .../sql/sources/v2/JavaSimpleReadSupport.java | 9 ------ .../spark/sql/UserDefinedTypeSuite.scala | 10 +++---- .../compression/IntegralDeltaSuite.scala | 3 +- .../ProcessingTimeExecutorSuite.scala | 5 +--- .../sources/TextSocketStreamSuite.scala | 5 ++-- .../sql/util/DataFrameCallbackSuite.scala | 2 +- .../HiveCliSessionStateSuite.scala | 2 +- 45 files changed, 177 insertions(+), 125 deletions(-) create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaRangeInputPartition.java diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java index f62e85d435318..e3efc92c4a54a 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java @@ -196,6 +196,7 @@ public synchronized void close() throws IOException { * when Scala wrappers are used, this makes sure that, hopefully, the JNI resources held by * the iterator will eventually be released. */ + @SuppressWarnings("deprecation") @Override protected void finalize() throws Throwable { db.closeIterator(this); diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index 7e4b08217f1b0..93a4f67fd23f2 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -89,6 +89,11 @@ commons-lang3 test + + org.apache.commons + commons-text + test + target/scala-${scala.binary.version}/classes diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala index 9656951810daf..fdb81a06d41c9 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.types -import org.apache.commons.lang3.StringUtils +import org.apache.commons.text.similarity.LevenshteinDistance import org.scalacheck.{Arbitrary, Gen} import org.scalatest.prop.GeneratorDrivenPropertyChecks // scalastyle:off @@ -232,7 +232,7 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty test("levenshteinDistance") { forAll { (one: String, another: String) => assert(toUTF8(one).levenshteinDistance(toUTF8(another)) === - StringUtils.getLevenshteinDistance(one, another)) + LevenshteinDistance.getDefaultInstance.apply(one, another)) } } diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java index f6d1288cb263d..92bf0ecc1b5cb 100644 --- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java +++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java @@ -27,7 +27,7 @@ * to read a file to avoid extra copy of data between Java and * native memory which happens when using {@link java.io.BufferedInputStream}. * Unfortunately, this is not something already available in JDK, - * {@link sun.nio.ch.ChannelInputStream} supports reading a file using nio, + * {@code sun.nio.ch.ChannelInputStream} supports reading a file using nio, * but does not support buffering. */ public final class NioBufferedFileInputStream extends InputStream { @@ -130,6 +130,7 @@ public synchronized void close() throws IOException { StorageUtils.dispose(byteBuffer); } + @SuppressWarnings("deprecation") @Override protected void finalize() throws IOException { close(); diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index cb91717dfa121..845a3d5f6d6f9 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -502,7 +502,9 @@ class SparkContext(config: SparkConf) extends Logging { _heartbeatReceiver.ask[Boolean](TaskSchedulerIsSet) // create and start the heartbeater for collecting memory metrics - _heartbeater = new Heartbeater(env.memoryManager, reportHeartBeat, "driver-heartbeater", + _heartbeater = new Heartbeater(env.memoryManager, + () => SparkContext.this.reportHeartBeat(), + "driver-heartbeater", conf.get(EXECUTOR_HEARTBEAT_INTERVAL)) _heartbeater.start() diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 7ce2581555014..50c8fdf5316d6 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.r -import java.io.{DataInputStream, DataOutputStream, File, FileOutputStream, IOException} +import java.io.{DataOutputStream, File, FileOutputStream, IOException} import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket} import java.util.concurrent.TimeUnit @@ -32,8 +32,6 @@ import io.netty.handler.timeout.ReadTimeoutHandler import org.apache.spark.SparkConf import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.util.Utils /** * Netty-based backend server that is used to communicate between R and Java. @@ -99,7 +97,7 @@ private[spark] class RBackend { if (bootstrap != null && bootstrap.config().group() != null) { bootstrap.config().group().shutdownGracefully() } - if (bootstrap != null && bootstrap.childGroup() != null) { + if (bootstrap != null && bootstrap.config().childGroup() != null) { bootstrap.config().childGroup().shutdownGracefully() } bootstrap = null @@ -147,7 +145,7 @@ private[spark] object RBackend extends Logging { new Thread("wait for socket to close") { setDaemon(true) override def run(): Unit = { - // any un-catched exception will also shutdown JVM + // any uncaught exception will also shutdown JVM val buf = new Array[Byte](1024) // shutdown JVM if R does not connect back in 10 seconds serverSocket.setSoTimeout(10000) diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index 10cd8742f2b49..1169b2878e993 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -270,7 +270,9 @@ private[spark] class HadoopDelegationTokenManager( } private def loadProviders(): Map[String, HadoopDelegationTokenProvider] = { - val providers = Seq(new HadoopFSDelegationTokenProvider(fileSystemsToAccess)) ++ + val providers = Seq( + new HadoopFSDelegationTokenProvider( + () => HadoopDelegationTokenManager.this.fileSystemsToAccess())) ++ safeCreateProvider(new HiveDelegationTokenProvider) ++ safeCreateProvider(new HBaseDelegationTokenProvider) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 61deb543d8747..a30a501e5d4a1 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -190,8 +190,11 @@ private[spark] class Executor( private val HEARTBEAT_INTERVAL_MS = conf.get(EXECUTOR_HEARTBEAT_INTERVAL) // Executor for the heartbeat task. - private val heartbeater = new Heartbeater(env.memoryManager, reportHeartBeat, - "executor-heartbeater", HEARTBEAT_INTERVAL_MS) + private val heartbeater = new Heartbeater( + env.memoryManager, + () => Executor.this.reportHeartBeat(), + "executor-heartbeater", + HEARTBEAT_INTERVAL_MS) // must be initialized before running startDriverHeartbeat() private val heartbeatReceiverRef = diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 903e25b7986f2..33a68f24bd53a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -30,7 +30,7 @@ import org.apache.spark.storage.RDDInfo @DeveloperApi class StageInfo( val stageId: Int, - @deprecated("Use attemptNumber instead", "2.3.0") val attemptId: Int, + private val attemptId: Int, val name: String, val numTasks: Int, val rddInfos: Seq[RDDInfo], @@ -56,6 +56,8 @@ class StageInfo( completionTime = Some(System.currentTimeMillis) } + // This would just be the second constructor arg, except we need to maintain this method + // with parentheses for compatibility def attemptNumber(): Int = attemptId private[spark] def getStatusString: String = { diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala index 31ce9483cf20a..424d9f825c465 100644 --- a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala @@ -215,7 +215,7 @@ class ParallelCollectionSplitSuite extends SparkFunSuite with Checkers { } test("exclusive ranges of doubles") { - val data = 1.0 until 100.0 by 1.0 + val data = Range.BigDecimal(1, 100, 1) val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) assert(slices.map(_.size).sum === 99) @@ -223,7 +223,7 @@ class ParallelCollectionSplitSuite extends SparkFunSuite with Checkers { } test("inclusive ranges of doubles") { - val data = 1.0 to 100.0 by 1.0 + val data = Range.BigDecimal.inclusive(1, 100, 1) val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) assert(slices.map(_.size).sum === 100) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 84af73b08d3e7..e413fe3b774d0 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -202,7 +202,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) // Check that very long ranges don't get written one element at a time - assert(ser.serialize(t).limit() < 100) + assert(ser.serialize(t).limit() < 200) } check(1 to 1000000) check(1 to 1000000 by 2) @@ -212,10 +212,10 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { check(1L to 1000000L by 2L) check(1L until 1000000L) check(1L until 1000000L by 2L) - check(1.0 to 1000000.0 by 1.0) - check(1.0 to 1000000.0 by 2.0) - check(1.0 until 1000000.0 by 1.0) - check(1.0 until 1000000.0 by 2.0) + check(Range.BigDecimal.inclusive(1, 1000000, 1)) + check(Range.BigDecimal.inclusive(1, 1000000, 2)) + check(Range.BigDecimal(1, 1000000, 1)) + check(Range.BigDecimal(1, 1000000, 2)) } test("asJavaIterable") { diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index bfd73069fbff8..5f757b757ac61 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -18,8 +18,7 @@ package org.apache.spark.status import java.io.File -import java.lang.{Integer => JInteger, Long => JLong} -import java.util.{Arrays, Date, Properties} +import java.util.{Date, Properties} import scala.collection.JavaConverters._ import scala.collection.immutable.Map @@ -1171,12 +1170,12 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Stop task 2 before task 1 time += 1 tasks(1).markFinished(TaskState.FINISHED, time) - listener.onTaskEnd( - SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(1), null)) + listener.onTaskEnd(SparkListenerTaskEnd( + stage1.stageId, stage1.attemptNumber, "taskType", Success, tasks(1), null)) time += 1 tasks(0).markFinished(TaskState.FINISHED, time) - listener.onTaskEnd( - SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(0), null)) + listener.onTaskEnd(SparkListenerTaskEnd( + stage1.stageId, stage1.attemptNumber, "taskType", Success, tasks(0), null)) // Start task 3 and task 2 should be evicted. listener.onTaskStart(SparkListenerTaskStart(stage1.stageId, stage1.attemptNumber, tasks(2))) @@ -1241,8 +1240,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Task 1 Finished time += 1 tasks(0).markFinished(TaskState.FINISHED, time) - listener.onTaskEnd( - SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(0), null)) + listener.onTaskEnd(SparkListenerTaskEnd( + stage1.stageId, stage1.attemptNumber, "taskType", Success, tasks(0), null)) // Stage 1 Completed stage1.failureReason = Some("Failed") @@ -1256,7 +1255,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 tasks(1).markFinished(TaskState.FINISHED, time) listener.onTaskEnd( - SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", + SparkListenerTaskEnd(stage1.stageId, stage1.attemptNumber, "taskType", TaskKilled(reason = "Killed"), tasks(1), null)) // Ensure killed task metrics are updated diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index cd25265784136..35fba1a3b73c6 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.util.collection import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ +import scala.language.postfixOps import scala.ref.WeakReference import org.scalatest.Matchers diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 8d6cca8e48c3d..207c54ce75f4c 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -138,7 +138,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("test NULL avro type") { withTempPath { dir => val fields = - Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava + Seq(new Field("null", Schema.create(Type.NULL), "doc", null.asInstanceOf[AnyVal])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) val datumWriter = new GenericDatumWriter[GenericRecord](schema) @@ -161,7 +161,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val avroSchema: Schema = { val union = Schema.createUnion(List(Schema.create(Type.INT), Schema.create(Type.LONG)).asJava) - val fields = Seq(new Field("field1", union, "doc", null)).asJava + val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[AnyVal])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -189,7 +189,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val avroSchema: Schema = { val union = Schema.createUnion(List(Schema.create(Type.FLOAT), Schema.create(Type.DOUBLE)).asJava) - val fields = Seq(new Field("field1", union, "doc", null)).asJava + val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[AnyVal])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -221,7 +221,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Schema.create(Type.NULL) ).asJava ) - val fields = Seq(new Field("field1", union, "doc", null)).asJava + val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[AnyVal])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -247,7 +247,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("Union of a single type") { withTempPath { dir => val UnionOfOne = Schema.createUnion(List(Schema.create(Type.INT)).asJava) - val fields = Seq(new Field("field1", UnionOfOne, "doc", null)).asJava + val fields = Seq(new Field("field1", UnionOfOne, "doc", null.asInstanceOf[AnyVal])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) @@ -274,10 +274,10 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val complexUnionType = Schema.createUnion( List(Schema.create(Type.INT), Schema.create(Type.STRING), fixedSchema, enumSchema).asJava) val fields = Seq( - new Field("field1", complexUnionType, "doc", null), - new Field("field2", complexUnionType, "doc", null), - new Field("field3", complexUnionType, "doc", null), - new Field("field4", complexUnionType, "doc", null) + new Field("field1", complexUnionType, "doc", null.asInstanceOf[AnyVal]), + new Field("field2", complexUnionType, "doc", null.asInstanceOf[AnyVal]), + new Field("field3", complexUnionType, "doc", null.asInstanceOf[AnyVal]), + new Field("field4", complexUnionType, "doc", null.asInstanceOf[AnyVal]) ).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) @@ -941,7 +941,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val avroArrayType = resolveNullable(Schema.createArray(avroType), nullable) val avroMapType = resolveNullable(Schema.createMap(avroType), nullable) val name = "foo" - val avroField = new Field(name, avroType, "", null) + val avroField = new Field(name, avroType, "", null.asInstanceOf[AnyVal]) val recordSchema = Schema.createRecord("name", "doc", "space", true, Seq(avroField).asJava) val avroRecordType = resolveNullable(recordSchema, nullable) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index fa6bdc20bd4f9..aa21f1271b817 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -56,7 +56,7 @@ trait KafkaContinuousTest extends KafkaSourceTest { } // Continuous processing tasks end asynchronously, so test that they actually end. - private val tasksEndedListener = new SparkListener() { + private class TasksEndedListener extends SparkListener { val activeTaskIdCount = new AtomicInteger(0) override def onTaskStart(start: SparkListenerTaskStart): Unit = { @@ -68,6 +68,8 @@ trait KafkaContinuousTest extends KafkaSourceTest { } } + private val tasksEndedListener = new TasksEndedListener() + override def beforeEach(): Unit = { super.beforeEach() spark.sparkContext.addSparkListener(tasksEndedListener) diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala index cf283a5c3e11e..07960d14b0bfc 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala @@ -228,7 +228,7 @@ object ConsumerStrategies { new Subscribe[K, V]( new ju.ArrayList(topics.asJavaCollection), new ju.HashMap[String, Object](kafkaParams.asJava), - new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(l => new jl.Long(l)).asJava)) + new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(jl.Long.valueOf).asJava)) } /** @@ -307,7 +307,7 @@ object ConsumerStrategies { new SubscribePattern[K, V]( pattern, new ju.HashMap[String, Object](kafkaParams.asJava), - new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(l => new jl.Long(l)).asJava)) + new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(jl.Long.valueOf).asJava)) } /** @@ -391,7 +391,7 @@ object ConsumerStrategies { new Assign[K, V]( new ju.ArrayList(topicPartitions.asJavaCollection), new ju.HashMap[String, Object](kafkaParams.asJava), - new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(l => new jl.Long(l)).asJava)) + new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(jl.Long.valueOf).asJava)) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala index c5d0ec1a8d350..412954f7b2d5a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala @@ -17,8 +17,6 @@ package org.apache.spark.ml.feature -import scala.beans.BeanInfo - import org.apache.spark.annotation.Since import org.apache.spark.ml.linalg.Vector @@ -30,8 +28,12 @@ import org.apache.spark.ml.linalg.Vector * @param features List of features for this data point. */ @Since("2.0.0") -@BeanInfo case class LabeledPoint(@Since("2.0.0") label: Double, @Since("2.0.0") features: Vector) { + + def getLabel: Double = label + + def getFeatures: Vector = features + override def toString: String = { s"($label,$features)" } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 56e2c543d100a..5bfaa3b7f3f52 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -17,10 +17,6 @@ package org.apache.spark.ml.feature -import org.json4s.JsonDSL._ -import org.json4s.JValue -import org.json4s.jackson.JsonMethods._ - import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml._ @@ -209,7 +205,7 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui if (isSet(inputCols)) { val splitsArray = if (isSet(numBucketsArray)) { val probArrayPerCol = $(numBucketsArray).map { numOfBuckets => - (0.0 to 1.0 by 1.0 / numOfBuckets).toArray + (0 to numOfBuckets).map(_.toDouble / numOfBuckets).toArray } val probabilityArray = probArrayPerCol.flatten.sorted.distinct @@ -229,12 +225,12 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui } } else { dataset.stat.approxQuantile($(inputCols), - (0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError)) + (0 to $(numBuckets)).map(_.toDouble / $(numBuckets)).toArray, $(relativeError)) } bucketizer.setSplitsArray(splitsArray.map(getDistinctSplits)) } else { val splits = dataset.stat.approxQuantile($(inputCol), - (0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError)) + (0 to $(numBuckets)).map(_.toDouble / $(numBuckets)).toArray, $(relativeError)) bucketizer.setSplits(getDistinctSplits(splits)) } copyValues(bucketizer.setParent(this)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 4381d6ab20cc0..b320057b25276 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -17,8 +17,6 @@ package org.apache.spark.mllib.regression -import scala.beans.BeanInfo - import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.ml.feature.{LabeledPoint => NewLabeledPoint} @@ -32,10 +30,14 @@ import org.apache.spark.mllib.util.NumericParser * @param features List of features for this data point. */ @Since("0.8.0") -@BeanInfo case class LabeledPoint @Since("1.0.0") ( @Since("0.8.0") label: Double, @Since("1.0.0") features: Vector) { + + def getLabel: Double = label + + def getFeatures: Vector = features + override def toString: String = { s"($label,$features)" } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala index 80c6ef0ea1aa1..85ed11d6553d9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala @@ -17,8 +17,6 @@ package org.apache.spark.mllib.stat.test -import scala.beans.BeanInfo - import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.streaming.api.java.JavaDStream @@ -32,10 +30,11 @@ import org.apache.spark.util.StatCounter * @param value numeric value of the observation. */ @Since("1.6.0") -@BeanInfo case class BinarySample @Since("1.6.0") ( @Since("1.6.0") isExperiment: Boolean, @Since("1.6.0") value: Double) { + def getIsExperiment: Boolean = isExperiment + def getValue: Double = value override def toString: String = { s"($isExperiment, $value)" } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala index 6734336aac39c..985e396000d05 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala @@ -17,16 +17,16 @@ package org.apache.spark.ml.feature -import scala.beans.BeanInfo - import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.Row -@BeanInfo -case class DCTTestData(vec: Vector, wantedVec: Vector) +case class DCTTestData(vec: Vector, wantedVec: Vector) { + def getVec: Vector = vec + def getWantedVec: Vector = wantedVec +} class DCTSuite extends MLTest with DefaultReadWriteTest { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index 201a335e0d7be..1483d5df4d224 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.ml.feature -import scala.beans.BeanInfo - import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.{DataFrame, Row} - -@BeanInfo -case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) +case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) { + def getInputTokens: Array[String] = inputTokens + def getWantedNGrams: Array[String] = wantedNGrams +} class NGramSuite extends MLTest with DefaultReadWriteTest { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index b009038bbd833..82af05039653e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -31,7 +31,7 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { val datasetSize = 100000 val numBuckets = 5 - val df = sc.parallelize(1.0 to datasetSize by 1.0).map(Tuple1.apply).toDF("input") + val df = sc.parallelize(1 to datasetSize).map(_.toDouble).map(Tuple1.apply).toDF("input") val discretizer = new QuantileDiscretizer() .setInputCol("input") .setOutputCol("result") @@ -114,8 +114,8 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { val spark = this.spark import spark.implicits._ - val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input") - val testDF = sc.parallelize(-10.0 to 110.0 by 1.0).map(Tuple1.apply).toDF("input") + val trainDF = sc.parallelize((1 to 100).map(_.toDouble)).map(Tuple1.apply).toDF("input") + val testDF = sc.parallelize((-10 to 110).map(_.toDouble)).map(Tuple1.apply).toDF("input") val discretizer = new QuantileDiscretizer() .setInputCol("input") .setOutputCol("result") @@ -276,10 +276,10 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0) val data2 = Array.range(1, 40, 2).map(_.toDouble) val expected2 = Array (0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, - 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0) + 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0) val data3 = Array.range(1, 60, 3).map(_.toDouble) - val expected3 = Array (0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 4.0, 4.0, 5.0, - 5.0, 5.0, 6.0, 6.0, 7.0, 8.0, 8.0, 9.0, 9.0, 9.0) + val expected3 = Array (0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, + 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0, 9.0, 9.0) val data = (0 until 20).map { idx => (data1(idx), data2(idx), data3(idx), expected1(idx), expected2(idx), expected3(idx)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index be59b0af2c78e..ba8e79f14de95 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.ml.feature -import scala.beans.BeanInfo - import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.{DataFrame, Row} -@BeanInfo -case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) +case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) { + def getRawText: String = rawText + def getWantedTokens: Array[String] = wantedTokens +} class TokenizerSuite extends MLTest with DefaultReadWriteTest { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index fb5789f945dec..44b0f8f8ae7d8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.ml.feature -import scala.beans.{BeanInfo, BeanProperty} - import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.ml.attribute._ @@ -26,7 +24,7 @@ import org.apache.spark.ml.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.DataFrame class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging { @@ -339,6 +337,7 @@ class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging { } private[feature] object VectorIndexerSuite { - @BeanInfo - case class FeatureData(@BeanProperty features: Vector) + case class FeatureData(features: Vector) { + def getFeatures: Vector = features + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 9a59c41740daf..2fc9754ecfe1e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -601,7 +601,7 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { val df = maybeDf.get._2 val expected = estimator.fit(df) - val actuals = dfs.filter(_ != baseType).map(t => (t, estimator.fit(t._2))) + val actuals = dfs.map(t => (t, estimator.fit(t._2))) actuals.foreach { case (_, actual) => check(expected, actual) } actuals.foreach { case (t, actual) => check2(expected, actual, t._2, t._1.encoder) } diff --git a/pom.xml b/pom.xml index fcec295eee128..9130773cb5094 100644 --- a/pom.xml +++ b/pom.xml @@ -407,6 +407,11 @@ commons-lang3 ${commons-lang3.version} + + org.apache.commons + commons-text + 1.6 + commons-lang commons-lang diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a8d2b5d1d9cb6..e35e74aa33045 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,11 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( + // [SPARK-26090] Resolve most miscellaneous deprecation and build warnings for Spark 3 + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.stat.test.BinarySampleBeanInfo"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.regression.LabeledPointBeanInfo"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.feature.LabeledPointBeanInfo"), + // [SPARK-25959] GBTClassifier picks wrong impurity stats on loading ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index be4daec3b1bb9..167fb402cd402 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -55,7 +55,7 @@ private[spark] class KubernetesDriverBuilder( providePodTemplateConfigMapStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => PodTemplateConfigMapStep) = new PodTemplateConfigMapStep(_), - provideInitialPod: () => SparkPod = SparkPod.initialPod) { + provideInitialPod: () => SparkPod = () => SparkPod.initialPod()) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]): KubernetesDriverSpec = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index 089f84dec277f..fc41a4770bce6 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -53,7 +53,7 @@ private[spark] class KubernetesExecutorBuilder( KubernetesConf[KubernetesExecutorSpecificConf] => HadoopSparkUserExecutorFeatureStep) = new HadoopSparkUserExecutorFeatureStep(_), - provideInitialPod: () => SparkPod = SparkPod.initialPod) { + provideInitialPod: () => SparkPod = () => SparkPod.initialPod()) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index 81e3822389f30..08f28758ef485 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.k8s.submit import io.fabric8.kubernetes.api.model._ import io.fabric8.kubernetes.client.{KubernetesClient, Watch} -import io.fabric8.kubernetes.client.dsl.{MixedOperation, NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable, PodResource} +import io.fabric8.kubernetes.client.dsl.PodResource import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} import org.mockito.Mockito.{doReturn, verify, when} import org.scalatest.BeforeAndAfter @@ -28,7 +28,6 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.Fabric8Aliases._ -import org.apache.spark.deploy.k8s.submit.JavaMainAppResource class ClientSuite extends SparkFunSuite with BeforeAndAfter { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala index b336774838bcb..2f984e5d89808 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala @@ -157,7 +157,7 @@ class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter { private def kubernetesConfWithCorrectFields(): KubernetesConf[KubernetesExecutorSpecificConf] = Matchers.argThat(new ArgumentMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { override def matches(argument: scala.Any): Boolean = { - if (!argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]]) { + if (!argument.isInstanceOf[KubernetesConf[_]]) { false } else { val k8sConf = argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 35299166d9814..c3070de3d17cf 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -116,8 +116,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter } def createContainer(host: String, resource: Resource = containerResource): Container = { - // When YARN 2.6+ is required, avoid deprecation by using version with long second arg - val containerId = ContainerId.newInstance(appAttemptId, containerNum) + val containerId = ContainerId.newContainerId(appAttemptId, containerNum) containerNum += 1 val nodeId = NodeId.newInstance(host, 1000) Container.newInstance(containerId, nodeId, "", resource, RM_REQUEST_PRIORITY, null) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala index 9bacd3b925be3..ea619c6a7666c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala @@ -199,7 +199,7 @@ class HyperLogLogPlusPlusHelper(relativeSD: Double) extends Serializable { var shift = 0 while (idx < m && i < REGISTERS_PER_WORD) { val Midx = (word >>> shift) & REGISTER_WORD_MASK - zInverse += 1.0 / (1 << Midx) + zInverse += 1.0 / (1L << Midx) if (Midx == 0) { V += 1.0d } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 94778840d706b..117e96175e92a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import scala.beans.{BeanInfo, BeanProperty} - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -30,8 +28,9 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} import org.apache.spark.sql.types._ -@BeanInfo -private[sql] case class GroupableData(@BeanProperty data: Int) +private[sql] case class GroupableData(data: Int) { + def getData: Int = data +} private[sql] class GroupableUDT extends UserDefinedType[GroupableData] { @@ -50,8 +49,9 @@ private[sql] class GroupableUDT extends UserDefinedType[GroupableData] { private[spark] override def asNullable: GroupableUDT = this } -@BeanInfo -private[sql] case class UngroupableData(@BeanProperty data: Map[Int, Int]) +private[sql] case class UngroupableData(data: Map[Int, Int]) { + def getData: Map[Int, Int] = data +} private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index d9ea8dc9d4ac9..d9fe1a992a093 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -311,7 +311,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo outputMode: OutputMode, useTempCheckpointLocation: Boolean = false, recoverFromCheckpointLocation: Boolean = true, - trigger: Trigger = ProcessingTime(0), + trigger: Trigger = Trigger.ProcessingTime(0), triggerClock: Clock = new SystemClock()): StreamingQuery = { val query = createQuery( userSpecifiedName, diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java index 7f975a647c241..8f35abeb579b5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java @@ -143,11 +143,16 @@ public void setIntervals(List intervals) { this.intervals = intervals; } + @Override + public int hashCode() { + return id ^ Objects.hashCode(intervals); + } + @Override public boolean equals(Object obj) { if (!(obj instanceof ArrayRecord)) return false; ArrayRecord other = (ArrayRecord) obj; - return (other.id == this.id) && other.intervals.equals(this.intervals); + return (other.id == this.id) && Objects.equals(other.intervals, this.intervals); } @Override @@ -184,6 +189,11 @@ public void setIntervals(Map intervals) { this.intervals = intervals; } + @Override + public int hashCode() { + return id ^ Objects.hashCode(intervals); + } + @Override public boolean equals(Object obj) { if (!(obj instanceof MapRecord)) return false; @@ -225,6 +235,11 @@ public void setEndTime(long endTime) { this.endTime = endTime; } + @Override + public int hashCode() { + return Long.hashCode(startTime) ^ Long.hashCode(endTime); + } + @Override public boolean equals(Object obj) { if (!(obj instanceof Interval)) return false; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaRangeInputPartition.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaRangeInputPartition.java new file mode 100644 index 0000000000000..438f489a3eea7 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaRangeInputPartition.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import org.apache.spark.sql.sources.v2.reader.InputPartition; + +class JavaRangeInputPartition implements InputPartition { + int start; + int end; + + JavaRangeInputPartition(int start, int end) { + this.start = start; + this.end = end; + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java index 685f9b9747e85..ced51dde6997b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java @@ -88,12 +88,3 @@ public void close() throws IOException { } } -class JavaRangeInputPartition implements InputPartition { - int start; - int end; - - JavaRangeInputPartition(int start, int end) { - this.start = start; - this.end = end; - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index cc8b600efa46a..cf956316057eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import scala.beans.{BeanInfo, BeanProperty} - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Cast, ExpressionEvalHelper, GenericInternalRow, Literal} @@ -28,10 +26,10 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -@BeanInfo -private[sql] case class MyLabeledPoint( - @BeanProperty label: Double, - @BeanProperty features: UDT.MyDenseVector) +private[sql] case class MyLabeledPoint(label: Double, features: UDT.MyDenseVector) { + def getLabel: Double = label + def getFeatures: UDT.MyDenseVector = features +} // Wrapped in an object to check Scala compatibility. See SPARK-13929 object UDT { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala index 0d9f1fb0c02c9..fb3388452e4e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala @@ -46,6 +46,7 @@ class IntegralDeltaSuite extends SparkFunSuite { (input.tail, input.init).zipped.map { case (x: Int, y: Int) => (x - y).toLong case (x: Long, y: Long) => x - y + case other => fail(s"Unexpected input $other") } } @@ -116,7 +117,7 @@ class IntegralDeltaSuite extends SparkFunSuite { val row = new GenericInternalRow(1) val nullRow = new GenericInternalRow(1) nullRow.setNullAt(0) - input.map { value => + input.foreach { value => if (value == nullValue) { builder.appendFrom(nullRow, 0) } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala index 80c76915e4c23..2d338ab92211e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala @@ -19,9 +19,6 @@ package org.apache.spark.sql.execution.streaming import java.util.concurrent.ConcurrentHashMap -import scala.collection.mutable - -import org.eclipse.jetty.util.ConcurrentHashSet import org.scalatest.concurrent.{Eventually, Signaler, ThreadSignaler, TimeLimits} import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ @@ -48,7 +45,7 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite with TimeLimits { } test("trigger timing") { - val triggerTimes = new ConcurrentHashSet[Int] + val triggerTimes = ConcurrentHashMap.newKeySet[Int]() val clock = new StreamManualClock() @volatile var continueExecuting = true @volatile var clockIncrementInTrigger = 0L diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index 635ea6fca649c..7db31f1f8f699 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -382,10 +382,9 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before tasks.foreach { case t: TextSocketContinuousInputPartition => val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader] - for (i <- 0 until numRecords / 2) { + for (_ <- 0 until numRecords / 2) { r.next() - assert(r.get().get(0, TextSocketReader.SCHEMA_TIMESTAMP) - .isInstanceOf[(String, Timestamp)]) + assert(r.get().get(0, TextSocketReader.SCHEMA_TIMESTAMP).isInstanceOf[(_, _)]) } case _ => throw new IllegalStateException("Unexpected task type") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index e8710aeb40bd4..ddc5dbb148cb5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -150,7 +150,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { def getPeakExecutionMemory(stageId: Int): Long = { val peakMemoryAccumulator = sparkListener.getCompletedStageInfos(stageId).accumulables - .filter(_._2.name == InternalAccumulator.PEAK_EXECUTION_MEMORY) + .filter(_._2.name == Some(InternalAccumulator.PEAK_EXECUTION_MEMORY)) assert(peakMemoryAccumulator.size == 1) peakMemoryAccumulator.head._2.value.get.asInstanceOf[Long] diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveCliSessionStateSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveCliSessionStateSuite.scala index 5f9ea4d26790b..035b71a37a692 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveCliSessionStateSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveCliSessionStateSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.hive.HiveUtils class HiveCliSessionStateSuite extends SparkFunSuite { def withSessionClear(f: () => Unit): Unit = { - try f finally SessionState.detachSession() + try f() finally SessionState.detachSession() } test("CliSessionState will be reused") { From 86cc907448f0102ad0c185e87fcc897d0a32707f Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 19 Nov 2018 15:11:42 -0600 Subject: [PATCH 2092/2461] This is a dummy commit to trigger ASF git sync From a09d5ba88680d07121ce94a4e68c3f42fc635f4f Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Tue, 20 Nov 2018 09:27:46 +0800 Subject: [PATCH 2093/2461] [SPARK-26107][SQL] Extend ReplaceNullWithFalseInPredicate to support higher-order functions: ArrayExists, ArrayFilter, MapFilter ## What changes were proposed in this pull request? Extend the `ReplaceNullWithFalse` optimizer rule introduced in SPARK-25860 (https://github.com/apache/spark/pull/22857) to also support optimizing predicates in higher-order functions of `ArrayExists`, `ArrayFilter`, `MapFilter`. Also rename the rule to `ReplaceNullWithFalseInPredicate` to better reflect its intent. Example: ```sql select filter(a, e -> if(e is null, null, true)) as b from ( select array(null, 1, null, 3) as a) ``` The optimized logical plan: **Before**: ``` == Optimized Logical Plan == Project [filter([null,1,null,3], lambdafunction(if (isnull(lambda e#13)) null else true, lambda e#13, false)) AS b#9] +- OneRowRelation ``` **After**: ``` == Optimized Logical Plan == Project [filter([null,1,null,3], lambdafunction(if (isnull(lambda e#13)) false else true, lambda e#13, false)) AS b#9] +- OneRowRelation ``` ## How was this patch tested? Added new unit test cases to the `ReplaceNullWithFalseInPredicateSuite` (renamed from `ReplaceNullWithFalseSuite`). Closes #23079 from rednaxelafx/catalyst-master. Authored-by: Kris Mok Signed-off-by: Wenchen Fan --- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../sql/catalyst/optimizer/expressions.scala | 11 ++++- ...eplaceNullWithFalseInPredicateSuite.scala} | 48 +++++++++++++++++-- ...llWithFalseInPredicateEndToEndSuite.scala} | 45 ++++++++++++++++- 4 files changed, 98 insertions(+), 8 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{ReplaceNullWithFalseSuite.scala => ReplaceNullWithFalseInPredicateSuite.scala} (87%) rename sql/core/src/test/scala/org/apache/spark/sql/{ReplaceNullWithFalseEndToEndSuite.scala => ReplaceNullWithFalseInPredicateEndToEndSuite.scala} (63%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a330a84a3a24f..8d251eeab8484 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -84,7 +84,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) SimplifyConditionals, RemoveDispensableExpressions, SimplifyBinaryComparison, - ReplaceNullWithFalse, + ReplaceNullWithFalseInPredicate, PruneFilters, EliminateSorts, SimplifyCasts, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 2b29b49d00ab9..354efd883f814 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -755,7 +755,7 @@ object CombineConcats extends Rule[LogicalPlan] { * * As a result, many unnecessary computations can be removed in the query optimization phase. */ -object ReplaceNullWithFalse extends Rule[LogicalPlan] { +object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) @@ -767,6 +767,15 @@ object ReplaceNullWithFalse extends Rule[LogicalPlan] { replaceNullWithFalse(cond) -> value } cw.copy(branches = newBranches) + case af @ ArrayFilter(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + af.copy(function = newLambda) + case ae @ ArrayExists(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + ae.copy(function = newLambda) + case mf @ MapFilter(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + mf.copy(function = newLambda) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala similarity index 87% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index c6b5d0ec96776..3a9e6cae0fd87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Expression, GreaterThan, If, Literal, Or} +import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.types.{BooleanType, IntegerType} -class ReplaceNullWithFalseSuite extends PlanTest { +class ReplaceNullWithFalseInPredicateSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = @@ -36,10 +36,11 @@ class ReplaceNullWithFalseSuite extends PlanTest { ConstantFolding, BooleanSimplification, SimplifyConditionals, - ReplaceNullWithFalse) :: Nil + ReplaceNullWithFalseInPredicate) :: Nil } - private val testRelation = LocalRelation('i.int, 'b.boolean) + private val testRelation = + LocalRelation('i.int, 'b.boolean, 'a.array(IntegerType), 'm.map(IntegerType, IntegerType)) private val anotherTestRelation = LocalRelation('d.int) test("replace null inside filter and join conditions") { @@ -298,6 +299,26 @@ class ReplaceNullWithFalseSuite extends PlanTest { testProjection(originalExpr = column, expectedExpr = column) } + test("replace nulls in lambda function of ArrayFilter") { + testHigherOrderFunc('a, ArrayFilter, Seq('e)) + } + + test("replace nulls in lambda function of ArrayExists") { + testHigherOrderFunc('a, ArrayExists, Seq('e)) + } + + test("replace nulls in lambda function of MapFilter") { + testHigherOrderFunc('m, MapFilter, Seq('k, 'v)) + } + + test("inability to replace nulls in arbitrary higher-order function") { + val lambdaFunc = LambdaFunction( + function = If('e > 0, Literal(null, BooleanType), TrueLiteral), + arguments = Seq[NamedExpression]('e)) + val column = ArrayTransform('a, lambdaFunc) + testProjection(originalExpr = column, expectedExpr = column) + } + private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = { test((rel, exp) => rel.where(exp), originalCond, expectedCond) } @@ -310,6 +331,25 @@ class ReplaceNullWithFalseSuite extends PlanTest { test((rel, exp) => rel.select(exp), originalExpr, expectedExpr) } + private def testHigherOrderFunc( + argument: Expression, + createExpr: (Expression, Expression) => Expression, + lambdaArgs: Seq[NamedExpression]): Unit = { + val condArg = lambdaArgs.last + // the lambda body is: if(arg > 0, null, true) + val cond = GreaterThan(condArg, Literal(0)) + val lambda1 = LambdaFunction( + function = If(cond, Literal(null, BooleanType), TrueLiteral), + arguments = lambdaArgs) + // the optimized lambda body is: if(arg > 0, false, true) + val lambda2 = LambdaFunction( + function = If(cond, FalseLiteral, TrueLiteral), + arguments = lambdaArgs) + testProjection( + originalExpr = createExpr(argument, lambda1) as 'x, + expectedExpr = createExpr(argument, lambda2) as 'x) + } + private def test( func: (LogicalPlan, Expression) => LogicalPlan, originalExpr: Expression, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala similarity index 63% rename from sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala index fc6ecc4e032f6..0f84b0c961a10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.expressions.{CaseWhen, If} +import org.apache.spark.sql.catalyst.expressions.{CaseWhen, If, Literal} import org.apache.spark.sql.execution.LocalTableScanExec import org.apache.spark.sql.functions.{lit, when} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.BooleanType -class ReplaceNullWithFalseEndToEndSuite extends QueryTest with SharedSQLContext { +class ReplaceNullWithFalseInPredicateEndToEndSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("SPARK-25860: Replace Literal(null, _) with FalseLiteral whenever possible") { @@ -68,4 +69,44 @@ class ReplaceNullWithFalseEndToEndSuite extends QueryTest with SharedSQLContext case p => fail(s"$p is not LocalTableScanExec") } } + + test("SPARK-26107: Replace Literal(null, _) with FalseLiteral in higher-order functions") { + def assertNoLiteralNullInPlan(df: DataFrame): Unit = { + df.queryExecution.executedPlan.foreach { p => + assert(p.expressions.forall(_.find { + case Literal(null, BooleanType) => true + case _ => false + }.isEmpty)) + } + } + + withTable("t1", "t2") { + // to test ArrayFilter and ArrayExists + spark.sql("select array(null, 1, null, 3) as a") + .write.saveAsTable("t1") + // to test MapFilter + spark.sql(""" + select map_from_entries(arrays_zip(a, transform(a, e -> if(mod(e, 2) = 0, null, e)))) as m + from (select array(0, 1, 2, 3) as a) + """).write.saveAsTable("t2") + + val df1 = spark.table("t1") + val df2 = spark.table("t2") + + // ArrayExists + val q1 = df1.selectExpr("EXISTS(a, e -> IF(e is null, null, true))") + checkAnswer(q1, Row(true) :: Nil) + assertNoLiteralNullInPlan(q1) + + // ArrayFilter + val q2 = df1.selectExpr("FILTER(a, e -> IF(e is null, null, true))") + checkAnswer(q2, Row(Seq[Any](1, 3)) :: Nil) + assertNoLiteralNullInPlan(q2) + + // MapFilter + val q3 = df2.selectExpr("MAP_FILTER(m, (k, v) -> IF(v is null, null, true))") + checkAnswer(q3, Row(Map[Any, Any](1 -> 1, 3 -> 3))) + assertNoLiteralNullInPlan(q3) + } + } } From a00aaf649cb5a14648102b2980ce21393804f2c7 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 20 Nov 2018 08:27:57 -0600 Subject: [PATCH 2094/2461] [MINOR][YARN] Make memLimitExceededLogMessage more clean MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Current `memLimitExceededLogMessage`: It‘s not very clear, because physical memory exceeds but suggestion contains virtual memory config. This pr makes it more clear and replace deprecated config: ```spark.yarn.executor.memoryOverhead```. ## How was this patch tested? manual tests Closes #23030 from wangyum/EXECUTOR_MEMORY_OVERHEAD. Authored-by: Yuming Wang Signed-off-by: Sean Owen --- .../spark/deploy/yarn/YarnAllocator.scala | 33 ++++++++----------- .../deploy/yarn/YarnAllocatorSuite.scala | 12 ------- 2 files changed, 14 insertions(+), 31 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index ebdcf45603cea..9497530805c1a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -20,7 +20,6 @@ package org.apache.spark.deploy.yarn import java.util.Collections import java.util.concurrent._ import java.util.concurrent.atomic.AtomicInteger -import java.util.regex.Pattern import scala.collection.JavaConverters._ import scala.collection.mutable @@ -598,13 +597,21 @@ private[yarn] class YarnAllocator( (false, s"Container ${containerId}${onHostStr} was preempted.") // Should probably still count memory exceeded exit codes towards task failures case VMEM_EXCEEDED_EXIT_CODE => - (true, memLimitExceededLogMessage( - completedContainer.getDiagnostics, - VMEM_EXCEEDED_PATTERN)) + val vmemExceededPattern = raw"$MEM_REGEX of $MEM_REGEX virtual memory used".r + val diag = vmemExceededPattern.findFirstIn(completedContainer.getDiagnostics) + .map(_.concat(".")).getOrElse("") + val message = "Container killed by YARN for exceeding virtual memory limits. " + + s"$diag Consider boosting ${EXECUTOR_MEMORY_OVERHEAD.key} or boosting " + + s"${YarnConfiguration.NM_VMEM_PMEM_RATIO} or disabling " + + s"${YarnConfiguration.NM_VMEM_CHECK_ENABLED} because of YARN-4714." + (true, message) case PMEM_EXCEEDED_EXIT_CODE => - (true, memLimitExceededLogMessage( - completedContainer.getDiagnostics, - PMEM_EXCEEDED_PATTERN)) + val pmemExceededPattern = raw"$MEM_REGEX of $MEM_REGEX physical memory used".r + val diag = pmemExceededPattern.findFirstIn(completedContainer.getDiagnostics) + .map(_.concat(".")).getOrElse("") + val message = "Container killed by YARN for exceeding physical memory limits. " + + s"$diag Consider boosting ${EXECUTOR_MEMORY_OVERHEAD.key}." + (true, message) case _ => // all the failures which not covered above, like: // disk failure, kill by app master or resource manager, ... @@ -735,18 +742,6 @@ private[yarn] class YarnAllocator( private object YarnAllocator { val MEM_REGEX = "[0-9.]+ [KMG]B" - val PMEM_EXCEEDED_PATTERN = - Pattern.compile(s"$MEM_REGEX of $MEM_REGEX physical memory used") - val VMEM_EXCEEDED_PATTERN = - Pattern.compile(s"$MEM_REGEX of $MEM_REGEX virtual memory used") val VMEM_EXCEEDED_EXIT_CODE = -103 val PMEM_EXCEEDED_EXIT_CODE = -104 - - def memLimitExceededLogMessage(diagnostics: String, pattern: Pattern): String = { - val matcher = pattern.matcher(diagnostics) - val diag = if (matcher.find()) " " + matcher.group() + "." else "" - s"Container killed by YARN for exceeding memory limits. $diag " + - "Consider boosting spark.yarn.executor.memoryOverhead or " + - "disabling yarn.nodemanager.vmem-check-enabled because of YARN-4714." - } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index c3070de3d17cf..b61e7df4420ef 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -29,7 +29,6 @@ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfterEach, Matchers} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} -import org.apache.spark.deploy.yarn.YarnAllocator._ import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.rpc.RpcEndpointRef @@ -376,17 +375,6 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter verify(mockAmClient).updateBlacklist(Seq[String]().asJava, Seq("hostA", "hostB").asJava) } - test("memory exceeded diagnostic regexes") { - val diagnostics = - "Container [pid=12465,containerID=container_1412887393566_0003_01_000002] is running " + - "beyond physical memory limits. Current usage: 2.1 MB of 2 GB physical memory used; " + - "5.8 GB of 4.2 GB virtual memory used. Killing container." - val vmemMsg = memLimitExceededLogMessage(diagnostics, VMEM_EXCEEDED_PATTERN) - val pmemMsg = memLimitExceededLogMessage(diagnostics, PMEM_EXCEEDED_PATTERN) - assert(vmemMsg.contains("5.8 GB of 4.2 GB virtual memory used.")) - assert(pmemMsg.contains("2.1 MB of 2 GB physical memory used.")) - } - test("window based failure executor counting") { sparkConf.set("spark.yarn.executor.failuresValidityInterval", "100s") val handler = createAllocator(4) From c34c42234f308872ebe9c7cdaee32000c0726eea Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 20 Nov 2018 08:29:59 -0600 Subject: [PATCH 2095/2461] [SPARK-26076][BUILD][MINOR] Revise ambiguous error message from load-spark-env.sh ## What changes were proposed in this pull request? When I try to run scripts (e.g. `start-master.sh`/`start-history-server.sh ` in latest master, I got such error: ``` Presence of build for multiple Scala versions detected. Either clean one of them or, export SPARK_SCALA_VERSION in spark-env.sh. ``` The error message is quite confusing. Without reading `load-spark-env.sh`, I didn't know which directory to remove, or where to find and edit the `spark-evn.sh`. This PR is to make the error message more clear. Also change the script for less maintenance when we add or drop Scala versions in the future. As now with https://github.com/apache/spark/pull/22967, we can revise the error message as following(in my local setup): ``` Presence of build for multiple Scala versions detected (/Users/gengliangwang/IdeaProjects/spark/assembly/target/scala-2.12 and /Users/gengliangwang/IdeaProjects/spark/assembly/target/scala-2.11). Remove one of them or, export SPARK_SCALA_VERSION=2.12 in /Users/gengliangwang/IdeaProjects/spark/conf/spark-env.sh. Visit https://spark.apache.org/docs/latest/configuration.html#environment-variables for more details about setting environment variables in spark-env.sh. ``` ## How was this patch tested? Manual test Closes #23049 from gengliangwang/reviseEnvScript. Authored-by: Gengliang Wang Signed-off-by: Sean Owen --- bin/load-spark-env.sh | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index 0b5006dbd63ac..0ada5d8d0fc1d 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -26,15 +26,17 @@ if [ -z "${SPARK_HOME}" ]; then source "$(dirname "$0")"/find-spark-home fi +SPARK_ENV_SH="spark-env.sh" if [ -z "$SPARK_ENV_LOADED" ]; then export SPARK_ENV_LOADED=1 export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}"/conf}" - if [ -f "${SPARK_CONF_DIR}/spark-env.sh" ]; then + SPARK_ENV_SH="${SPARK_CONF_DIR}/${SPARK_ENV_SH}" + if [[ -f "${SPARK_ENV_SH}" ]]; then # Promote all variable declarations to environment (exported) variables set -a - . "${SPARK_CONF_DIR}/spark-env.sh" + . ${SPARK_ENV_SH} set +a fi fi @@ -42,19 +44,22 @@ fi # Setting SPARK_SCALA_VERSION if not already set. if [ -z "$SPARK_SCALA_VERSION" ]; then + SCALA_VERSION_1=2.12 + SCALA_VERSION_2=2.11 - ASSEMBLY_DIR2="${SPARK_HOME}/assembly/target/scala-2.11" - ASSEMBLY_DIR1="${SPARK_HOME}/assembly/target/scala-2.12" - - if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then - echo -e "Presence of build for multiple Scala versions detected." 1>&2 - echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION in spark-env.sh.' 1>&2 + ASSEMBLY_DIR_1="${SPARK_HOME}/assembly/target/scala-${SCALA_VERSION_1}" + ASSEMBLY_DIR_2="${SPARK_HOME}/assembly/target/scala-${SCALA_VERSION_2}" + ENV_VARIABLE_DOC="https://spark.apache.org/docs/latest/configuration.html#environment-variables" + if [[ -d "$ASSEMBLY_DIR_1" && -d "$ASSEMBLY_DIR_2" ]]; then + echo "Presence of build for multiple Scala versions detected ($ASSEMBLY_DIR_1 and $ASSEMBLY_DIR_2)." 1>&2 + echo "Remove one of them or, export SPARK_SCALA_VERSION=$SCALA_VERSION_1 in ${SPARK_ENV_SH}." 1>&2 + echo "Visit ${ENV_VARIABLE_DOC} for more details about setting environment variables in spark-env.sh." 1>&2 exit 1 fi - if [ -d "$ASSEMBLY_DIR2" ]; then - export SPARK_SCALA_VERSION="2.11" + if [[ -d "$ASSEMBLY_DIR_1" ]]; then + export SPARK_SCALA_VERSION=${SCALA_VERSION_1} else - export SPARK_SCALA_VERSION="2.12" + export SPARK_SCALA_VERSION=${SCALA_VERSION_2} fi fi From ab61ddb34d58ab5701191c8fd3a24a62f6ebf37b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Tue, 20 Nov 2018 08:56:22 -0600 Subject: [PATCH 2096/2461] [SPARK-26118][WEB UI] Introducing spark.ui.requestHeaderSize for setting HTTP requestHeaderSize MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Introducing spark.ui.requestHeaderSize for configuring Jetty's HTTP requestHeaderSize. This way long authorization field does not lead to HTTP 413. ## How was this patch tested? Manually with curl (which version must be at least 7.55). With the original default value (8k limit): ```bash # Starting history server with default requestHeaderSize $ ./sbin/start-history-server.sh starting org.apache.spark.deploy.history.HistoryServer, logging to /Users/attilapiros/github/spark/logs/spark-attilapiros-org.apache.spark.deploy.history.HistoryServer-1-apiros-MBP.lan.out # Creating huge header $ echo -n "X-Custom-Header: " > cookie $ printf 'A%.0s' {1..9500} >> cookie # HTTP GET with huge header fails with 431 $ curl -H cookie http://458apiros-MBP.lan:18080/

      Bad Message 431

      reason: Request Header Fields Too Large
      # The log contains the error $ tail -1 /Users/attilapiros/github/spark/logs/spark-attilapiros-org.apache.spark.deploy.history.HistoryServer-1-apiros-MBP.lan.out 18/11/19 21:24:28 WARN HttpParser: Header is too large 8193>8192 ``` After: ```bash # Creating the history properties file with the increased requestHeaderSize $ echo spark.ui.requestHeaderSize=10000 > history.properties # Starting Spark History Server with the settings $ ./sbin/start-history-server.sh --properties-file history.properties starting org.apache.spark.deploy.history.HistoryServer, logging to /Users/attilapiros/github/spark/logs/spark-attilapiros-org.apache.spark.deploy.history.HistoryServer-1-apiros-MBP.lan.out # HTTP GET with huge header gives back HTML5 (I have added here only just a part of the response) $ curl -H cookie http://458apiros-MBP.lan:18080/ ... History Server ... ``` Closes #23090 from attilapiros/JettyHeaderSize. Authored-by: “attilapiros” Signed-off-by: Imran Rashid --- .../scala/org/apache/spark/internal/config/package.scala | 6 ++++++ core/src/main/scala/org/apache/spark/ui/JettyUtils.scala | 6 ++++-- docs/configuration.md | 8 ++++++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index ab2b872c5551e..9cc48f6375003 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -570,6 +570,12 @@ package object config { .stringConf .createOptional + private[spark] val UI_REQUEST_HEADER_SIZE = + ConfigBuilder("spark.ui.requestHeaderSize") + .doc("Value for HTTP request header size in bytes.") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("8k") + private[spark] val EXTRA_LISTENERS = ConfigBuilder("spark.extraListeners") .doc("Class names of listeners to add to SparkContext during initialization.") .stringConf diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 52a955111231a..316af9b79d286 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -356,13 +356,15 @@ private[spark] object JettyUtils extends Logging { (connector, connector.getLocalPort()) } + val httpConfig = new HttpConfiguration() + httpConfig.setRequestHeaderSize(conf.get(UI_REQUEST_HEADER_SIZE).toInt) // If SSL is configured, create the secure connector first. val securePort = sslOptions.createJettySslContextFactory().map { factory => val securePort = sslOptions.port.getOrElse(if (port > 0) Utils.userPort(port, 400) else 0) val secureServerName = if (serverName.nonEmpty) s"$serverName (HTTPS)" else serverName val connectionFactories = AbstractConnectionFactory.getFactories(factory, - new HttpConnectionFactory()) + new HttpConnectionFactory(httpConfig)) def sslConnect(currentPort: Int): (ServerConnector, Int) = { newConnector(connectionFactories, currentPort) @@ -377,7 +379,7 @@ private[spark] object JettyUtils extends Logging { // Bind the HTTP port. def httpConnect(currentPort: Int): (ServerConnector, Int) = { - newConnector(Array(new HttpConnectionFactory()), currentPort) + newConnector(Array(new HttpConnectionFactory(httpConfig)), currentPort) } val (httpConnector, httpPort) = Utils.startServiceOnPort[ServerConnector](port, httpConnect, diff --git a/docs/configuration.md b/docs/configuration.md index 2915fb5fa9197..04210d855b110 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -973,6 +973,14 @@ Apart from these, the following properties are also available, and may be useful
      spark.com.test.filter1.param.name2=bar + + spark.ui.requestHeaderSize + 8k + + The maximum allowed size for a HTTP request header, in bytes unless otherwise specified. + This setting applies for the Spark History Server too. + + ### Compression and Serialization From db136d360e54e13f1d7071a0428964a202cf7e31 Mon Sep 17 00:00:00 2001 From: Simeon Simeonov Date: Tue, 20 Nov 2018 21:29:56 +0100 Subject: [PATCH 2097/2461] [SPARK-26084][SQL] Fixes unresolved AggregateExpression.references exception ## What changes were proposed in this pull request? This PR fixes an exception in `AggregateExpression.references` called on unresolved expressions. It implements the solution proposed in [SPARK-26084](https://issues.apache.org/jira/browse/SPARK-26084), a minor refactoring that removes the unnecessary dependence on `AttributeSet.toSeq`, which requires expression IDs and, therefore, can only execute successfully for resolved expressions. The refactored implementation is both simpler and faster, eliminating the conversion of a `Set` to a `Seq` and back to `Set`. ## How was this patch tested? Added a new test based on the failing case in [SPARK-26084](https://issues.apache.org/jira/browse/SPARK-26084). hvanhovell Closes #23075 from ssimeonov/ss_SPARK-26084. Authored-by: Simeon Simeonov Signed-off-by: Herman van Hovell --- .../expressions/aggregate/interfaces.scala | 8 ++--- .../aggregate/AggregateExpressionSuite.scala | 34 +++++++++++++++++++ 2 files changed, 37 insertions(+), 5 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index e1d16a2cd38b0..56c2ee6b53fe5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -128,12 +128,10 @@ case class AggregateExpression( override def nullable: Boolean = aggregateFunction.nullable override def references: AttributeSet = { - val childReferences = mode match { - case Partial | Complete => aggregateFunction.references.toSeq - case PartialMerge | Final => aggregateFunction.aggBufferAttributes + mode match { + case Partial | Complete => aggregateFunction.references + case PartialMerge | Final => AttributeSet(aggregateFunction.aggBufferAttributes) } - - AttributeSet(childReferences) } override def toString: String = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala new file mode 100644 index 0000000000000..8e9c9972071ad --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{Add, AttributeSet} + +class AggregateExpressionSuite extends SparkFunSuite { + + test("test references from unresolved aggregate functions") { + val x = UnresolvedAttribute("x") + val y = UnresolvedAttribute("y") + val actual = AggregateExpression(Sum(Add(x, y)), mode = Complete, isDistinct = false).references + val expected = AttributeSet(x :: y :: Nil) + assert(expected == actual, s"Expected: $expected. Actual: $actual") + } + +} From 42c48387c047d96154bcfeb95fcb816a43e60d7c Mon Sep 17 00:00:00 2001 From: shane knapp Date: Tue, 20 Nov 2018 12:38:40 -0800 Subject: [PATCH 2098/2461] [BUILD] refactor dev/lint-python in to something readable ## What changes were proposed in this pull request? `dev/lint-python` is a mess of nearly unreadable bash. i would like to fix that as best as i can. ## How was this patch tested? the build system will test this. Closes #22994 from shaneknapp/lint-python-refactor. Authored-by: shane knapp Signed-off-by: shane knapp --- dev/lint-python | 359 +++++++++++++++++++++++++++++------------------- 1 file changed, 220 insertions(+), 139 deletions(-) diff --git a/dev/lint-python b/dev/lint-python index 27d87f6b56680..06816932e754a 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -1,5 +1,4 @@ #!/usr/bin/env bash - # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -16,160 +15,242 @@ # See the License for the specific language governing permissions and # limitations under the License. # +# define test binaries + versions +PYDOCSTYLE_BUILD="pydocstyle" +MINIMUM_PYDOCSTYLE="3.0.0" -SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" -SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" -# Exclude auto-generated configuration file. -PATHS_TO_CHECK="$( cd "$SPARK_ROOT_DIR" && find . -name "*.py" )" -DOC_PATHS_TO_CHECK="$( cd "$SPARK_ROOT_DIR" && find . -name "*.py" | grep -vF 'functions.py' )" -PYCODESTYLE_REPORT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-report.txt" -PYDOCSTYLE_REPORT_PATH="$SPARK_ROOT_DIR/dev/pydocstyle-report.txt" -PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" -PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt" - -PYDOCSTYLEBUILD="pydocstyle" -MINIMUM_PYDOCSTYLEVERSION="3.0.0" - -FLAKE8BUILD="flake8" +FLAKE8_BUILD="flake8" MINIMUM_FLAKE8="3.5.0" -SPHINXBUILD=${SPHINXBUILD:=sphinx-build} -SPHINX_REPORT_PATH="$SPARK_ROOT_DIR/dev/sphinx-report.txt" +PYCODESTYLE_BUILD="pycodestyle" +MINIMUM_PYCODESTYLE="2.4.0" -cd "$SPARK_ROOT_DIR" +SPHINX_BUILD="sphinx-build" -# compileall: https://docs.python.org/2/library/compileall.html -python -B -m compileall -q -l $PATHS_TO_CHECK > "$PYCODESTYLE_REPORT_PATH" -compile_status="${PIPESTATUS[0]}" +function compile_python_test { + local COMPILE_STATUS= + local COMPILE_REPORT= + + if [[ ! "$1" ]]; then + echo "No python files found! Something is very wrong -- exiting." + exit 1; + fi -# Get pycodestyle at runtime so that we don't rely on it being installed on the build server. -# See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 -# Updated to the latest official version of pep8. pep8 is formally renamed to pycodestyle. -PYCODESTYLE_VERSION="2.4.0" -PYCODESTYLE_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-$PYCODESTYLE_VERSION.py" -PYCODESTYLE_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/PyCQA/pycodestyle/$PYCODESTYLE_VERSION/pycodestyle.py" + # compileall: https://docs.python.org/2/library/compileall.html + echo "starting python compilation test..." + COMPILE_REPORT=$( (python -B -mcompileall -q -l $1) 2>&1) + COMPILE_STATUS=$? + + if [ $COMPILE_STATUS -ne 0 ]; then + echo "Python compilation failed with the following errors:" + echo "$COMPILE_REPORT" + echo "$COMPILE_STATUS" + exit "$COMPILE_STATUS" + else + echo "python compilation succeeded." + echo + fi +} -if [ ! -e "$PYCODESTYLE_SCRIPT_PATH" ]; then - curl --silent -o "$PYCODESTYLE_SCRIPT_PATH" "$PYCODESTYLE_SCRIPT_REMOTE_PATH" - curl_status="$?" +function pycodestyle_test { + local PYCODESTYLE_STATUS= + local PYCODESTYLE_REPORT= + local RUN_LOCAL_PYCODESTYLE= + local VERSION= + local EXPECTED_PYCODESTYLE= + local PYCODESTYLE_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-$MINIMUM_PYCODESTYLE.py" + local PYCODESTYLE_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/PyCQA/pycodestyle/$MINIMUM_PYCODESTYLE/pycodestyle.py" - if [ "$curl_status" -ne 0 ]; then - echo "Failed to download pycodestyle.py from \"$PYCODESTYLE_SCRIPT_REMOTE_PATH\"." - exit "$curl_status" + if [[ ! "$1" ]]; then + echo "No python files found! Something is very wrong -- exiting." + exit 1; fi -fi - -# Easy install pylint in /dev/pylint. To easy_install into a directory, the PYTHONPATH should -# be set to the directory. -# dev/pylint should be appended to the PATH variable as well. -# Jenkins by default installs the pylint3 version, so for now this just checks the code quality -# of python3. -export "PYTHONPATH=$SPARK_ROOT_DIR/dev/pylint" -export "PYLINT_HOME=$PYTHONPATH" -export "PATH=$PYTHONPATH:$PATH" - -# There is no need to write this output to a file -# first, but we do so so that the check status can -# be output before the report, like with the -# scalastyle and RAT checks. -python "$PYCODESTYLE_SCRIPT_PATH" --config=dev/tox.ini $PATHS_TO_CHECK >> "$PYCODESTYLE_REPORT_PATH" -pycodestyle_status="${PIPESTATUS[0]}" - -if [ "$compile_status" -eq 0 -a "$pycodestyle_status" -eq 0 ]; then - lint_status=0 -else - lint_status=1 -fi - -if [ "$lint_status" -ne 0 ]; then - echo "pycodestyle checks failed." - cat "$PYCODESTYLE_REPORT_PATH" - rm "$PYCODESTYLE_REPORT_PATH" - exit "$lint_status" -else - echo "pycodestyle checks passed." - rm "$PYCODESTYLE_REPORT_PATH" -fi - -# Check by flake8 -if hash "$FLAKE8BUILD" 2> /dev/null; then - FLAKE8VERSION="$( $FLAKE8BUILD --version 2> /dev/null )" - VERSION=($FLAKE8VERSION) - IS_EXPECTED_FLAKE8=$(python -c 'from distutils.version import LooseVersion; \ -print(LooseVersion("""'${VERSION[0]}'""") >= LooseVersion("""'$MINIMUM_FLAKE8'"""))' 2> /dev/null) - if [[ "$IS_EXPECTED_FLAKE8" == "True" ]]; then - # stop the build if there are Python syntax errors or undefined names - $FLAKE8BUILD . --count --select=E901,E999,F821,F822,F823 --max-line-length=100 --show-source --statistics - flake8_status="${PIPESTATUS[0]}" - - if [ "$flake8_status" -eq 0 ]; then - lint_status=0 - else - lint_status=1 + + # check for locally installed pycodestyle & version + RUN_LOCAL_PYCODESTYLE="False" + if hash "$PYCODESTYLE_BUILD" 2> /dev/null; then + VERSION=$( $PYCODESTYLE_BUILD --version 2> /dev/null) + EXPECTED_PYCODESTYLE=$( (python -c 'from distutils.version import LooseVersion; + print(LooseVersion("""'${VERSION[0]}'""") >= LooseVersion("""'$MINIMUM_PYCODESTYLE'"""))')\ + 2> /dev/null) + + if [ "$EXPECTED_PYCODESTYLE" == "True" ]; then + RUN_LOCAL_PYCODESTYLE="True" fi + fi - if [ "$lint_status" -ne 0 ]; then - echo "flake8 checks failed." - exit "$lint_status" - else - echo "flake8 checks passed." + # download the right version or run locally + if [ $RUN_LOCAL_PYCODESTYLE == "False" ]; then + # Get pycodestyle at runtime so that we don't rely on it being installed on the build server. + # See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 + # Updated to the latest official version of pep8. pep8 is formally renamed to pycodestyle. + echo "downloading pycodestyle from $PYCODESTYLE_SCRIPT_REMOTE_PATH..." + if [ ! -e "$PYCODESTYLE_SCRIPT_PATH" ]; then + curl --silent -o "$PYCODESTYLE_SCRIPT_PATH" "$PYCODESTYLE_SCRIPT_REMOTE_PATH" + local curl_status="$?" + + if [ "$curl_status" -ne 0 ]; then + echo "Failed to download pycodestyle.py from $PYCODESTYLE_SCRIPT_REMOTE_PATH" + exit "$curl_status" + fi fi + + echo "starting pycodestyle test..." + PYCODESTYLE_REPORT=$( (python "$PYCODESTYLE_SCRIPT_PATH" --config=dev/tox.ini $1) 2>&1) + PYCODESTYLE_STATUS=$? + else + # we have the right version installed, so run locally + echo "starting pycodestyle test..." + PYCODESTYLE_REPORT=$( ($PYCODESTYLE_BUILD --config=dev/tox.ini $1) 2>&1) + PYCODESTYLE_STATUS=$? + fi + + if [ $PYCODESTYLE_STATUS -ne 0 ]; then + echo "pycodestyle checks failed:" + echo "$PYCODESTYLE_REPORT" + exit "$PYCODESTYLE_STATUS" else - echo "The flake8 version needs to be "$MINIMUM_FLAKE8" at latest. Your current version is '"$FLAKE8VERSION"'." + echo "pycodestyle checks passed." + echo + fi +} + +function flake8_test { + local FLAKE8_VERSION= + local VERSION= + local EXPECTED_FLAKE8= + local FLAKE8_REPORT= + local FLAKE8_STATUS= + + if ! hash "$FLAKE8_BUILD" 2> /dev/null; then + echo "The flake8 command was not found." echo "flake8 checks failed." exit 1 fi -else - echo >&2 "The flake8 command was not found." - echo "flake8 checks failed." - exit 1 -fi - -# Check python document style, skip check if pydocstyle is not installed. -if hash "$PYDOCSTYLEBUILD" 2> /dev/null; then - PYDOCSTYLEVERSION="$( $PYDOCSTYLEBUILD --version 2> /dev/null )" - IS_EXPECTED_PYDOCSTYLEVERSION=$(python -c 'from distutils.version import LooseVersion; \ -print(LooseVersion("""'$PYDOCSTYLEVERSION'""") >= LooseVersion("""'$MINIMUM_PYDOCSTYLEVERSION'"""))') - if [[ "$IS_EXPECTED_PYDOCSTYLEVERSION" == "True" ]]; then - $PYDOCSTYLEBUILD --config=dev/tox.ini $DOC_PATHS_TO_CHECK >> "$PYDOCSTYLE_REPORT_PATH" - pydocstyle_status="${PIPESTATUS[0]}" - - if [ "$compile_status" -eq 0 -a "$pydocstyle_status" -eq 0 ]; then - echo "pydocstyle checks passed." - rm "$PYDOCSTYLE_REPORT_PATH" - else - echo "pydocstyle checks failed." - cat "$PYDOCSTYLE_REPORT_PATH" - rm "$PYDOCSTYLE_REPORT_PATH" - exit 1 - fi + FLAKE8_VERSION="$($FLAKE8_BUILD --version 2> /dev/null)" + VERSION=($FLAKE8_VERSION) + EXPECTED_FLAKE8=$( (python -c 'from distutils.version import LooseVersion; + print(LooseVersion("""'${VERSION[0]}'""") >= LooseVersion("""'$MINIMUM_FLAKE8'"""))') \ + 2> /dev/null) + + if [[ "$EXPECTED_FLAKE8" == "False" ]]; then + echo "\ +The minimum flake8 version needs to be $MINIMUM_FLAKE8. Your current version is $FLAKE8_VERSION + +flake8 checks failed." + exit 1 + fi + + echo "starting $FLAKE8_BUILD test..." + FLAKE8_REPORT=$( ($FLAKE8_BUILD . --count --select=E901,E999,F821,F822,F823 \ + --max-line-length=100 --show-source --statistics) 2>&1) + FLAKE8_STATUS=$? + + if [ "$FLAKE8_STATUS" -ne 0 ]; then + echo "flake8 checks failed:" + echo "$FLAKE8_REPORT" + echo "$FLAKE8_STATUS" + exit "$FLAKE8_STATUS" else - echo "The pydocstyle version needs to be "$MINIMUM_PYDOCSTYLEVERSION" at latest. Your current version is "$PYDOCSTYLEVERSION". Skipping pydoc checks for now." + echo "flake8 checks passed." + echo fi -else - echo >&2 "The pydocstyle command was not found. Skipping pydoc checks for now" -fi - -# Check that the documentation builds acceptably, skip check if sphinx is not installed. -if hash "$SPHINXBUILD" 2> /dev/null; then - cd python/docs - make clean - # Treat warnings as errors so we stop correctly - SPHINXOPTS="-a -W" make html &> "$SPHINX_REPORT_PATH" || lint_status=1 - if [ "$lint_status" -ne 0 ]; then - echo "pydoc checks failed." - cat "$SPHINX_REPORT_PATH" - echo "re-running make html to print full warning list" - make clean - SPHINXOPTS="-a" make html - rm "$SPHINX_REPORT_PATH" - exit "$lint_status" - else - echo "pydoc checks passed." - rm "$SPHINX_REPORT_PATH" - fi - cd ../.. -else - echo >&2 "The $SPHINXBUILD command was not found. Skipping pydoc checks for now" -fi +} + +function pydocstyle_test { + local PYDOCSTYLE_REPORT= + local PYDOCSTYLE_STATUS= + local PYDOCSTYLE_VERSION= + local EXPECTED_PYDOCSTYLE= + + # Exclude auto-generated configuration file. + local DOC_PATHS_TO_CHECK="$( cd "${SPARK_ROOT_DIR}" && find . -name "*.py" | grep -vF 'functions.py' )" + + # Check python document style, skip check if pydocstyle is not installed. + if ! hash "$PYDOCSTYLE_BUILD" 2> /dev/null; then + echo "The pydocstyle command was not found. Skipping pydocstyle checks for now." + echo + return + fi + + PYDOCSTYLE_VERSION="$($PYDOCSTYLEBUILD --version 2> /dev/null)" + EXPECTED_PYDOCSTYLE=$(python -c 'from distutils.version import LooseVersion; \ + print(LooseVersion("""'$PYDOCSTYLE_VERSION'""") >= LooseVersion("""'$MINIMUM_PYDOCSTYLE'"""))' \ + 2> /dev/null) + + if [[ "$EXPECTED_PYDOCSTYLE" == "False" ]]; then + echo "\ +The minimum version of pydocstyle needs to be $MINIMUM_PYDOCSTYLE. +Your current version is $PYDOCSTYLE_VERSION. +Skipping pydocstyle checks for now." + echo + return + fi + + echo "starting $PYDOCSTYLE_BUILD test..." + PYDOCSTYLE_REPORT=$( ($PYDOCSTYLE_BUILD --config=dev/tox.ini $DOC_PATHS_TO_CHECK) 2>&1) + PYDOCSTYLE_STATUS=$? + + if [ "$PYDOCSTYLE_STATUS" -ne 0 ]; then + echo "pydocstyle checks failed:" + echo "$PYDOCSTYLE_REPORT" + exit "$PYDOCSTYLE_STATUS" + else + echo "pydocstyle checks passed." + echo + fi +} + +function sphinx_test { + local SPHINX_REPORT= + local SPHINX_STATUS= + + # Check that the documentation builds acceptably, skip check if sphinx is not installed. + if ! hash "$SPHINX_BUILD" 2> /dev/null; then + echo "The $SPHINX_BUILD command was not found. Skipping pydoc checks for now." + echo + return + fi + + echo "starting $SPHINX_BUILD tests..." + pushd python/docs &> /dev/null + make clean &> /dev/null + # Treat warnings as errors so we stop correctly + SPHINX_REPORT=$( (SPHINXOPTS="-a -W" make html) 2>&1) + SPHINX_STATUS=$? + + if [ "$SPHINX_STATUS" -ne 0 ]; then + echo "$SPHINX_BUILD checks failed:" + echo "$SPHINX_REPORT" + echo + echo "re-running make html to print full warning list:" + make clean &> /dev/null + SPHINX_REPORT=$( (SPHINXOPTS="-a" make html) 2>&1) + echo "$SPHINX_REPORT" + exit "$SPHINX_STATUS" + else + echo "$SPHINX_BUILD checks passed." + echo + fi + + popd &> /dev/null +} + +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +SPARK_ROOT_DIR="$(dirname "${SCRIPT_DIR}")" + +pushd "$SPARK_ROOT_DIR" &> /dev/null + +PYTHON_SOURCE="$(find . -name "*.py")" + +compile_python_test "$PYTHON_SOURCE" +pycodestyle_test "$PYTHON_SOURCE" +flake8_test +pydocstyle_test +sphinx_test + +echo +echo "all lint-python tests passed!" + +popd &> /dev/null From 23bcd6ce458f1e49f307c89ca2794dc9a173077c Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 20 Nov 2018 18:03:54 -0600 Subject: [PATCH 2099/2461] [SPARK-26043][HOTFIX] Hotfix a change to SparkHadoopUtil that doesn't work in 2.11 ## What changes were proposed in this pull request? Hotfix a change to SparkHadoopUtil that doesn't work in 2.11 ## How was this patch tested? Existing tests. Closes #23097 from srowen/SPARK-26043.2. Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../scala/org/apache/spark/deploy/SparkHadoopUtil.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 217e5145f1c56..7bb2a419107d6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, File, IOException} import java.security.PrivilegedExceptionAction import java.text.DateFormat -import java.util.{Arrays, Date, Locale} +import java.util.{Arrays, Comparator, Date, Locale} import scala.collection.JavaConverters._ import scala.collection.immutable.Map @@ -269,10 +269,11 @@ private[spark] class SparkHadoopUtil extends Logging { name.startsWith(prefix) && !name.endsWith(exclusionSuffix) } }) - Arrays.sort(fileStatuses, - (o1: FileStatus, o2: FileStatus) => { + Arrays.sort(fileStatuses, new Comparator[FileStatus] { + override def compare(o1: FileStatus, o2: FileStatus): Int = { Longs.compare(o1.getModificationTime, o2.getModificationTime) - }) + } + }) fileStatuses } catch { case NonFatal(e) => From 47851056c20c5d981b1ca66bac3f00c19a882727 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 20 Nov 2018 18:05:39 -0600 Subject: [PATCH 2100/2461] [SPARK-26124][BUILD] Update plugins to latest versions ## What changes were proposed in this pull request? Update many plugins we use to the latest version, especially MiMa, which entails excluding some new errors on old changes. ## How was this patch tested? N/A Closes #23087 from srowen/Plugins. Authored-by: Sean Owen Signed-off-by: Sean Owen --- pom.xml | 40 +++++++++++++++++++++++--------------- project/MimaExcludes.scala | 10 +++++++++- project/plugins.sbt | 14 ++++++------- 3 files changed, 40 insertions(+), 24 deletions(-) diff --git a/pom.xml b/pom.xml index 9130773cb5094..08a29d2d52310 100644 --- a/pom.xml +++ b/pom.xml @@ -1977,7 +1977,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 3.0.0-M1 + 3.0.0-M2 enforce-versions @@ -2077,7 +2077,7 @@ org.apache.maven.plugins maven-compiler-plugin - 3.7.0 + 3.8.0 ${java.version} ${java.version} @@ -2094,7 +2094,7 @@ org.apache.maven.plugins maven-surefire-plugin - 2.22.0 + 3.0.0-M1 @@ -2148,7 +2148,7 @@ org.scalatest scalatest-maven-plugin - 1.0 + 2.0.0 ${project.build.directory}/surefire-reports @@ -2195,7 +2195,7 @@ org.apache.maven.plugins maven-jar-plugin - 3.0.2 + 3.1.0 org.apache.maven.plugins @@ -2222,7 +2222,7 @@ org.apache.maven.plugins maven-clean-plugin - 3.0.0 + 3.1.0 @@ -2240,9 +2240,12 @@ org.apache.maven.plugins maven-javadoc-plugin - 3.0.0-M1 + 3.0.1 - -Xdoclint:all -Xdoclint:-missing + + -Xdoclint:all + -Xdoclint:-missing + example @@ -2293,7 +2296,7 @@ org.apache.maven.plugins maven-shade-plugin - 3.2.0 + 3.2.1 org.ow2.asm @@ -2310,12 +2313,12 @@ org.apache.maven.plugins maven-install-plugin - 2.5.2 + 3.0.0-M1 org.apache.maven.plugins maven-deploy-plugin - 2.8.2 + 3.0.0-M1 org.apache.maven.plugins @@ -2361,7 +2364,7 @@ org.apache.maven.plugins maven-jar-plugin - [2.6,) + 3.1.0 test-jar @@ -2518,12 +2521,17 @@ org.apache.maven.plugins maven-checkstyle-plugin - 2.17 + 3.0.0 false true - ${basedir}/src/main/java,${basedir}/src/main/scala - ${basedir}/src/test/java + + ${basedir}/src/main/java + ${basedir}/src/main/scala + + + ${basedir}/src/test/java + dev/checkstyle.xml ${basedir}/target/checkstyle-output.xml ${project.build.sourceEncoding} @@ -2533,7 +2541,7 @@ com.puppycrawl.tools checkstyle - 8.2 + 8.14 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index e35e74aa33045..b750535e8a70b 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,7 +36,15 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( - // [SPARK-26090] Resolve most miscellaneous deprecation and build warnings for Spark 3 + // [SPARK-26124] Update plugins, including MiMa + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsPushDownRequiredColumns.build"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics.fullSchema"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics.planInputPartitions"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning.fullSchema"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning.planInputPartitions"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsPushDownFilters.build"), + + // [SPARK-26090] Resolve most miscellaneous deprecation and build warnings for Spark 3 ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.stat.test.BinarySampleBeanInfo"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.regression.LabeledPointBeanInfo"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.feature.LabeledPointBeanInfo"), diff --git a/project/plugins.sbt b/project/plugins.sbt index ffbd417b0f145..c9354735a62f5 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,7 +1,7 @@ addSbtPlugin("com.etsy" % "sbt-checkstyle-plugin" % "3.1.1") // sbt-checkstyle-plugin uses an old version of checkstyle. Match it to Maven's. -libraryDependencies += "com.puppycrawl.tools" % "checkstyle" % "8.2" +libraryDependencies += "com.puppycrawl.tools" % "checkstyle" % "8.14" // checkstyle uses guava 23.0. libraryDependencies += "com.google.guava" % "guava" % "23.0" @@ -9,13 +9,13 @@ libraryDependencies += "com.google.guava" % "guava" % "23.0" // need to make changes to uptake sbt 1.0 support in "com.eed3si9n" % "sbt-assembly" % "1.14.5" addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") -addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.2.3") +addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.2.4") -addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.9.0") +addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.9.2") addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "1.0.0") -addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.17") +addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.3.0") // sbt 1.0.0 support: https://github.com/AlpineNow/junit_xml_listener/issues/6 addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") @@ -28,12 +28,12 @@ addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") addSbtPlugin("io.spray" % "sbt-revolver" % "0.9.1") -libraryDependencies += "org.ow2.asm" % "asm" % "5.1" +libraryDependencies += "org.ow2.asm" % "asm" % "7.0" -libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.1" +libraryDependencies += "org.ow2.asm" % "asm-commons" % "7.0" // sbt 1.0.0 support: https://github.com/ihji/sbt-antlr4/issues/14 -addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.11") +addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.12") // Spark uses a custom fork of the sbt-pom-reader plugin which contains a patch to fix issues // related to test-jar dependencies (https://github.com/sbt/sbt-pom-reader/pull/14). The source for From 2df34db586bec379e40b5cf30021f5b7a2d79271 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 21 Nov 2018 09:29:22 +0800 Subject: [PATCH 2101/2461] [SPARK-26122][SQL] Support encoding for multiLine in CSV datasource ## What changes were proposed in this pull request? In the PR, I propose to pass the CSV option `encoding`/`charset` to `uniVocity` parser to allow parsing CSV files in different encodings when `multiLine` is enabled. The value of the option is passed to the `beginParsing` method of `CSVParser`. ## How was this patch tested? Added new test to `CSVSuite` for different encodings and enabled/disabled header. Closes #23091 from MaxGekk/csv-miltiline-encoding. Authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- .../sql/catalyst/csv/UnivocityParser.scala | 12 ++++++----- .../datasources/csv/CSVDataSource.scala | 6 ++++-- .../execution/datasources/csv/CSVSuite.scala | 21 +++++++++++++++++++ 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index 46ed58ed92830..ed196935e357f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -271,11 +271,12 @@ private[sql] object UnivocityParser { def tokenizeStream( inputStream: InputStream, shouldDropHeader: Boolean, - tokenizer: CsvParser): Iterator[Array[String]] = { + tokenizer: CsvParser, + encoding: String): Iterator[Array[String]] = { val handleHeader: () => Unit = () => if (shouldDropHeader) tokenizer.parseNext - convertStream(inputStream, tokenizer, handleHeader)(tokens => tokens) + convertStream(inputStream, tokenizer, handleHeader, encoding)(tokens => tokens) } /** @@ -297,7 +298,7 @@ private[sql] object UnivocityParser { val handleHeader: () => Unit = () => headerChecker.checkHeaderColumnNames(tokenizer) - convertStream(inputStream, tokenizer, handleHeader) { tokens => + convertStream(inputStream, tokenizer, handleHeader, parser.options.charset) { tokens => safeParser.parse(tokens) }.flatten } @@ -305,9 +306,10 @@ private[sql] object UnivocityParser { private def convertStream[T]( inputStream: InputStream, tokenizer: CsvParser, - handleHeader: () => Unit)( + handleHeader: () => Unit, + encoding: String)( convert: Array[String] => T) = new Iterator[T] { - tokenizer.beginParsing(inputStream) + tokenizer.beginParsing(inputStream, encoding) // We can handle header here since here the stream is open. handleHeader() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 4808e8ef042d1..554baaf1a9b3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -192,7 +192,8 @@ object MultiLineCSVDataSource extends CSVDataSource { UnivocityParser.tokenizeStream( CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, path), shouldDropHeader = false, - new CsvParser(parsedOptions.asParserSettings)) + new CsvParser(parsedOptions.asParserSettings), + encoding = parsedOptions.charset) }.take(1).headOption match { case Some(firstRow) => val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis @@ -203,7 +204,8 @@ object MultiLineCSVDataSource extends CSVDataSource { lines.getConfiguration, new Path(lines.getPath())), parsedOptions.headerFlag, - new CsvParser(parsedOptions.asParserSettings)) + new CsvParser(parsedOptions.asParserSettings), + encoding = parsedOptions.charset) } val sampled = CSVUtils.sample(tokenRDD, parsedOptions) CSVInferSchema.infer(sampled, header, parsedOptions) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 2efe1dda475c5..e29cd2aa7c4e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1859,4 +1859,25 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te checkAnswer(df, Row(null, csv)) } } + + test("encoding in multiLine mode") { + val df = spark.range(3).toDF() + Seq("UTF-8", "ISO-8859-1", "CP1251", "US-ASCII", "UTF-16BE", "UTF-32LE").foreach { encoding => + Seq(true, false).foreach { header => + withTempPath { path => + df.write + .option("encoding", encoding) + .option("header", header) + .csv(path.getCanonicalPath) + val readback = spark.read + .option("multiLine", true) + .option("encoding", encoding) + .option("inferSchema", true) + .option("header", header) + .csv(path.getCanonicalPath) + checkAnswer(readback, df) + } + } + } + } } From 4b7f7ef5007c2c8a5090f22c6e08927e9f9a407b Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 21 Nov 2018 09:31:12 +0800 Subject: [PATCH 2102/2461] [SPARK-26120][TESTS][SS][SPARKR] Fix a streaming query leak in Structured Streaming R tests ## What changes were proposed in this pull request? Stop the streaming query in `Specify a schema by using a DDL-formatted string when reading` to avoid outputting annoying logs. ## How was this patch tested? Jenkins Closes #23089 from zsxwing/SPARK-26120. Authored-by: Shixiong Zhu Signed-off-by: hyukjinkwon --- R/pkg/tests/fulltests/test_streaming.R | 1 + 1 file changed, 1 insertion(+) diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R index bfb1a046490ec..6f0d2aefee886 100644 --- a/R/pkg/tests/fulltests/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -127,6 +127,7 @@ test_that("Specify a schema by using a DDL-formatted string when reading", { expect_false(awaitTermination(q, 5 * 1000)) callJMethod(q@ssq, "processAllAvailable") expect_equal(head(sql("SELECT count(*) FROM people3"))[[1]], 3) + stopQuery(q) expect_error(read.stream(path = parquetPath, schema = "name stri"), "DataType stri is not supported.") From a480a6256318b43b963fb7414ccb789e4b950c8b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 21 Nov 2018 00:24:34 -0800 Subject: [PATCH 2103/2461] [SPARK-25954][SS] Upgrade to Kafka 2.1.0 ## What changes were proposed in this pull request? [Kafka 2.1.0 vote](https://lists.apache.org/thread.html/9f487094491e512b556a1c9c3c6034ac642b088e3f797e3d192ebc9d%3Cdev.kafka.apache.org%3E) passed. Since Kafka 2.1.0 includes official JDK 11 support [KAFKA-7264](https://issues.apache.org/jira/browse/KAFKA-7264), we had better use that. ## How was this patch tested? Pass the Jenkins. Closes #23099 from dongjoon-hyun/SPARK-25954. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- external/kafka-0-10-sql/pom.xml | 2 +- external/kafka-0-10/pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 3f1055a75076f..d97e8cf18605e 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -30,7 +30,7 @@ sql-kafka-0-10 - 2.0.0 + 2.1.0 jar Kafka 0.10+ Source for Structured Streaming diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index d75b13da8fb70..cfc45559d8e34 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -29,7 +29,7 @@ streaming-kafka-0-10 - 2.0.0 + 2.1.0 jar Spark Integration for Kafka 0.10 From 540afc2b18ef61cceb50b9a5b327e6fcdbe1e7e4 Mon Sep 17 00:00:00 2001 From: Shahid Date: Wed, 21 Nov 2018 09:31:35 -0600 Subject: [PATCH 2104/2461] [SPARK-26109][WEBUI] Duration in the task summary metrics table and the task table are different ## What changes were proposed in this pull request? Task summary table displays the summary of the task table in the stage page. However, the 'Duration' metrics of 'task summary' table and 'task table' are not matching. The reason is because, in the 'task summary' we display 'executorRunTime' as the duration, and in the 'task table' the actual duration of the task. Except duration metrics, all other metrics are properly displaying in the task summary. In Spark2.2, used to show 'executorRunTime' as duration in the 'taskTable'. That is why, in summary metrics also the 'exeuctorRunTime' shows as the duration. So, we need to show 'executorRunTime' as the duration in the tasks table to follow the same behaviour as the previous versions of spark. ## How was this patch tested? Before patch: ![screenshot from 2018-11-19 04-32-06](https://user-images.githubusercontent.com/23054875/48679263-1e4fff80-ebb4-11e8-9ed5-16d892039e01.png) After patch: ![screenshot from 2018-11-19 04-37-39](https://user-images.githubusercontent.com/23054875/48679343-e39a9700-ebb4-11e8-8df9-9dc3a28d4bce.png) Closes #23081 from shahidki31/duratinSummary. Authored-by: Shahid Signed-off-by: Sean Owen --- .../src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 477b9ce7f7848..7e6cc4297d6b1 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -843,7 +843,7 @@ private[ui] class TaskPagedTable(
      {UIUtils.formatDate(task.launchTime)} - {formatDuration(task.duration)} + {formatDuration(task.taskMetrics.map(_.executorRunTime))} {UIUtils.formatDuration(AppStatusUtils.schedulerDelay(task))} @@ -996,7 +996,9 @@ private[ui] object ApiHelper { HEADER_EXECUTOR -> TaskIndexNames.EXECUTOR, HEADER_HOST -> TaskIndexNames.HOST, HEADER_LAUNCH_TIME -> TaskIndexNames.LAUNCH_TIME, - HEADER_DURATION -> TaskIndexNames.DURATION, + // SPARK-26109: Duration of task as executorRunTime to make it consistent with the + // aggregated tasks summary metrics table and the previous versions of Spark. + HEADER_DURATION -> TaskIndexNames.EXEC_RUN_TIME, HEADER_SCHEDULER_DELAY -> TaskIndexNames.SCHEDULER_DELAY, HEADER_DESER_TIME -> TaskIndexNames.DESER_TIME, HEADER_GC_TIME -> TaskIndexNames.GC_TIME, From 6bbdf34baed7b2bab1fbfbce7782b3093a72812f Mon Sep 17 00:00:00 2001 From: Drew Robb Date: Wed, 21 Nov 2018 09:38:06 -0600 Subject: [PATCH 2105/2461] [SPARK-8288][SQL] ScalaReflection can use companion object constructor ## What changes were proposed in this pull request? This change fixes a particular scenario where default spark SQL can't encode (thrift) types that are generated by twitter scrooge. These types are a trait that extends `scala.ProductX` with a constructor defined only in a companion object, rather than a actual case class. The actual case class used is child class, but that type is almost never referred to in code. The type has no corresponding constructor symbol and causes an exception. For all other purposes, these classes act just like case classes, so it is unfortunate that spark SQL can't serialize them nicely as it can actual case classes. For an full example of a scrooge codegen class, see https://gist.github.com/anonymous/ba13d4b612396ca72725eaa989900314. This change catches the case where the type has no constructor but does have an `apply` method on the type's companion object. This allows for thrift types to be serialized/deserialized with implicit encoders the same way as normal case classes. This fix had to be done in three places where the constructor is assumed to be an actual constructor: 1) In serializing, determining the schema for the dataframe relies on inspecting its constructor (`ScalaReflection.constructParams`). Here we fall back to using the companion constructor arguments. 2) In deserializing or evaluating, in the java codegen ( `NewInstance.doGenCode`), the type couldn't be constructed with the new keyword. If there is no constructor, we change the constructor call to try the companion constructor. 3) In deserializing or evaluating, without codegen, the constructor is directly invoked (`NewInstance.constructor`). This was fixed with scala reflection to get the actual companion apply method. The return type of `findConstructor` was changed because the companion apply method constructor can't be represented as a `java.lang.reflect.Constructor`. There might be situations in which this approach would also fail in a new way, but it does at a minimum work for the specific scrooge example and will not impact cases that were already succeeding prior to this change Note: this fix does not enable using scrooge thrift enums, additional work for this is necessary. With this patch, it seems like you could patch `com.twitter.scrooge.ThriftEnum` to extend `_root_.scala.Product1[Int]` with `def _1 = value` to get spark's implicit encoders to handle enums, but I've yet to use this method myself. Note: I previously opened a PR for this issue, but only was able to fix case 1) there: https://github.com/apache/spark/pull/18766 ## How was this patch tested? I've fixed all 3 cases and added two tests that use a case class that is similar to scrooge generated one. The test in ScalaReflectionSuite checks 1), and the additional asserting in ObjectExpressionsSuite checks 2) and 3). Closes #23062 from drewrobb/SPARK-8288. Authored-by: Drew Robb Signed-off-by: Sean Owen --- .../spark/sql/catalyst/ScalaReflection.scala | 48 ++++++++++++++++--- .../expressions/objects/objects.scala | 18 ++++--- .../sql/catalyst/ScalaReflectionSuite.scala | 31 ++++++++++++ .../expressions/ObjectExpressionsSuite.scala | 10 ++++ .../org/apache/spark/sql/DatasetSuite.scala | 8 ++++ 5 files changed, 103 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 64ea236532839..c8542d0f2f7de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -788,12 +788,37 @@ object ScalaReflection extends ScalaReflection { } /** - * Finds an accessible constructor with compatible parameters. This is a more flexible search - * than the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible - * matching constructor is returned. Otherwise, it returns `None`. + * Finds an accessible constructor with compatible parameters. This is a more flexible search than + * the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible + * matching constructor is returned if it exists. Otherwise, we check for additional compatible + * constructors defined in the companion object as `apply` methods. Otherwise, it returns `None`. */ - def findConstructor(cls: Class[_], paramTypes: Seq[Class[_]]): Option[Constructor[_]] = { - Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*)) + def findConstructor[T](cls: Class[T], paramTypes: Seq[Class[_]]): Option[Seq[AnyRef] => T] = { + Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*)) match { + case Some(c) => Some(x => c.newInstance(x: _*).asInstanceOf[T]) + case None => + val companion = mirror.staticClass(cls.getName).companion + val moduleMirror = mirror.reflectModule(companion.asModule) + val applyMethods = companion.asTerm.typeSignature + .member(universe.TermName("apply")).asTerm.alternatives + applyMethods.find { method => + val params = method.typeSignature.paramLists.head + // Check that the needed params are the same length and of matching types + params.size == paramTypes.tail.size && + params.zip(paramTypes.tail).forall { case(ps, pc) => + ps.typeSignature.typeSymbol == mirror.classSymbol(pc) + } + }.map { applyMethodSymbol => + val expectedArgsCount = applyMethodSymbol.typeSignature.paramLists.head.size + val instanceMirror = mirror.reflect(moduleMirror.instance) + val method = instanceMirror.reflectMethod(applyMethodSymbol.asMethod) + (_args: Seq[AnyRef]) => { + // Drop the "outer" argument if it is provided + val args = if (_args.size == expectedArgsCount) _args else _args.tail + method.apply(args: _*).asInstanceOf[T] + } + } + } } /** @@ -973,8 +998,19 @@ trait ScalaReflection extends Logging { } } + /** + * If our type is a Scala trait it may have a companion object that + * only defines a constructor via `apply` method. + */ + private def getCompanionConstructor(tpe: Type): Symbol = { + tpe.typeSymbol.asClass.companion.asTerm.typeSignature.member(universe.TermName("apply")) + } + protected def constructParams(tpe: Type): Seq[Symbol] = { - val constructorSymbol = tpe.dealias.member(termNames.CONSTRUCTOR) + val constructorSymbol = tpe.member(termNames.CONSTRUCTOR) match { + case NoSymbol => getCompanionConstructor(tpe) + case sym => sym + } val params = if (constructorSymbol.isMethod) { constructorSymbol.asMethod.paramLists } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 4fd36a47cef52..59c897b6a53ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -462,12 +462,12 @@ case class NewInstance( val d = outerObj.getClass +: paramTypes val c = getConstructor(outerObj.getClass +: paramTypes) (args: Seq[AnyRef]) => { - c.newInstance(outerObj +: args: _*) + c(outerObj +: args) } }.getOrElse { val c = getConstructor(paramTypes) (args: Seq[AnyRef]) => { - c.newInstance(args: _*) + c(args) } } } @@ -486,10 +486,16 @@ case class NewInstance( ev.isNull = resultIsNull - val constructorCall = outer.map { gen => - s"${gen.value}.new ${cls.getSimpleName}($argString)" - }.getOrElse { - s"new $className($argString)" + val constructorCall = cls.getConstructors.size match { + // If there are no constructors, the `new` method will fail. In + // this case we can try to call the apply method constructor + // that might be defined on the companion object. + case 0 => s"$className$$.MODULE$$.apply($argString)" + case _ => outer.map { gen => + s"${gen.value}.new ${cls.getSimpleName}($argString)" + }.getOrElse { + s"new $className($argString)" + } } val code = code""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index d98589db323cc..80824cc2a7f21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -109,6 +109,30 @@ object TestingUDT { } } +/** An example derived from Twitter/Scrooge codegen for thrift */ +object ScroogeLikeExample { + def apply(x: Int): ScroogeLikeExample = new Immutable(x) + + def unapply(_item: ScroogeLikeExample): Option[Int] = Some(_item.x) + + class Immutable(val x: Int) extends ScroogeLikeExample +} + +trait ScroogeLikeExample extends Product1[Int] with Serializable { + import ScroogeLikeExample._ + + def x: Int + + def _1: Int = x + + override def canEqual(other: Any): Boolean = other.isInstanceOf[ScroogeLikeExample] + + override def equals(other: Any): Boolean = + canEqual(other) && + this.x == other.asInstanceOf[ScroogeLikeExample].x + + override def hashCode: Int = x +} class ScalaReflectionSuite extends SparkFunSuite { import org.apache.spark.sql.catalyst.ScalaReflection._ @@ -362,4 +386,11 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1) assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0) } + + test("SPARK-8288: schemaFor works for a class with only a companion object constructor") { + val schema = schemaFor[ScroogeLikeExample] + assert(schema === Schema( + StructType(Seq( + StructField("x", IntegerType, nullable = false))), nullable = true)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 16842c1bcc8cb..436675bf50353 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, JavaTypeInference, ScalaReflection} +import org.apache.spark.sql.catalyst.ScroogeLikeExample import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer, UnresolvedDeserializer} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.encoders._ @@ -410,6 +411,15 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { dataType = ObjectType(classOf[outerObj.Inner]), outerPointer = Some(() => outerObj)) checkObjectExprEvaluation(newInst2, new outerObj.Inner(1)) + + // SPARK-8288: A class with only a companion object constructor + val newInst3 = NewInstance( + cls = classOf[ScroogeLikeExample], + arguments = Literal(1) :: Nil, + propagateNull = false, + dataType = ObjectType(classOf[ScroogeLikeExample]), + outerPointer = Some(() => outerObj)) + checkObjectExprEvaluation(newInst3, ScroogeLikeExample(1)) } test("LambdaVariable should support interpreted execution") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index ac677e8ec6bc2..540fbff6a3a63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -21,6 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.sql.{Date, Timestamp} import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.ScroogeLikeExample import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} import org.apache.spark.sql.catalyst.util.sideBySide @@ -1570,6 +1571,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val agg = ds.groupByKey(x => x).agg(sum("_1").as[Long], sum($"_2" + 1).as[Long]) checkDatasetUnorderly(agg, ((1, 2), 1L, 3L), ((2, 3), 2L, 4L), ((3, 4), 3L, 5L)) } + + test("SPARK-8288: class with only a companion object constructor") { + val data = Seq(ScroogeLikeExample(1), ScroogeLikeExample(2)) + val ds = data.toDS + checkDataset(ds, data: _*) + checkAnswer(ds.select("x"), Seq(Row(1), Row(2))) + } } case class TestDataUnion(x: Int, y: Int, z: Int) From 07a700b3711057553dfbb7b047216565726509c7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 21 Nov 2018 16:41:12 +0100 Subject: [PATCH 2106/2461] [SPARK-26129][SQL] Instrumentation for per-query planning time ## What changes were proposed in this pull request? We currently don't have good visibility into query planning time (analysis vs optimization vs physical planning). This patch adds a simple utility to track the runtime of various rules and various planning phases. ## How was this patch tested? Added unit tests and end-to-end integration tests. Closes #23096 from rxin/SPARK-26129. Authored-by: Reynold Xin Signed-off-by: Reynold Xin --- .../sql/catalyst/QueryPlanningTracker.scala | 127 ++++++++++++++++++ .../sql/catalyst/analysis/Analyzer.scala | 22 +-- .../sql/catalyst/rules/RuleExecutor.scala | 19 ++- .../catalyst/QueryPlanningTrackerSuite.scala | 78 +++++++++++ .../sql/catalyst/analysis/AnalysisTest.scala | 3 +- .../ResolveGroupingAnalyticsSuite.scala | 3 +- .../ResolvedUuidExpressionsSuite.scala | 10 +- .../scala/org/apache/spark/sql/Dataset.scala | 9 ++ .../org/apache/spark/sql/SparkSession.scala | 6 +- .../spark/sql/execution/QueryExecution.scala | 21 ++- .../QueryPlanningTrackerEndToEndSuite.scala | 52 +++++++ .../apache/spark/sql/hive/test/TestHive.scala | 16 ++- 12 files changed, 338 insertions(+), 28 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala new file mode 100644 index 0000000000000..420f2a1f20997 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import scala.collection.JavaConverters._ + +import org.apache.spark.util.BoundedPriorityQueue + + +/** + * A simple utility for tracking runtime and associated stats in query planning. + * + * There are two separate concepts we track: + * + * 1. Phases: These are broad scope phases in query planning, as listed below, i.e. analysis, + * optimizationm and physical planning (just planning). + * + * 2. Rules: These are the individual Catalyst rules that we track. In addition to time, we also + * track the number of invocations and effective invocations. + */ +object QueryPlanningTracker { + + // Define a list of common phases here. + val PARSING = "parsing" + val ANALYSIS = "analysis" + val OPTIMIZATION = "optimization" + val PLANNING = "planning" + + class RuleSummary( + var totalTimeNs: Long, var numInvocations: Long, var numEffectiveInvocations: Long) { + + def this() = this(totalTimeNs = 0, numInvocations = 0, numEffectiveInvocations = 0) + + override def toString: String = { + s"RuleSummary($totalTimeNs, $numInvocations, $numEffectiveInvocations)" + } + } + + /** + * A thread local variable to implicitly pass the tracker around. This assumes the query planner + * is single-threaded, and avoids passing the same tracker context in every function call. + */ + private val localTracker = new ThreadLocal[QueryPlanningTracker]() { + override def initialValue: QueryPlanningTracker = null + } + + /** Returns the current tracker in scope, based on the thread local variable. */ + def get: Option[QueryPlanningTracker] = Option(localTracker.get()) + + /** Sets the current tracker for the execution of function f. We assume f is single-threaded. */ + def withTracker[T](tracker: QueryPlanningTracker)(f: => T): T = { + val originalTracker = localTracker.get() + localTracker.set(tracker) + try f finally { localTracker.set(originalTracker) } + } +} + + +class QueryPlanningTracker { + + import QueryPlanningTracker._ + + // Mapping from the name of a rule to a rule's summary. + // Use a Java HashMap for less overhead. + private val rulesMap = new java.util.HashMap[String, RuleSummary] + + // From a phase to time in ns. + private val phaseToTimeNs = new java.util.HashMap[String, Long] + + /** Measure the runtime of function f, and add it to the time for the specified phase. */ + def measureTime[T](phase: String)(f: => T): T = { + val startTime = System.nanoTime() + val ret = f + val timeTaken = System.nanoTime() - startTime + phaseToTimeNs.put(phase, phaseToTimeNs.getOrDefault(phase, 0) + timeTaken) + ret + } + + /** + * Record a specific invocation of a rule. + * + * @param rule name of the rule + * @param timeNs time taken to run this invocation + * @param effective whether the invocation has resulted in a plan change + */ + def recordRuleInvocation(rule: String, timeNs: Long, effective: Boolean): Unit = { + var s = rulesMap.get(rule) + if (s eq null) { + s = new RuleSummary + rulesMap.put(rule, s) + } + + s.totalTimeNs += timeNs + s.numInvocations += 1 + s.numEffectiveInvocations += (if (effective) 1 else 0) + } + + // ------------ reporting functions below ------------ + + def rules: Map[String, RuleSummary] = rulesMap.asScala.toMap + + def phases: Map[String, Long] = phaseToTimeNs.asScala.toMap + + /** Returns the top k most expensive rules (as measured by time). */ + def topRulesByTime(k: Int): Seq[(String, RuleSummary)] = { + val orderingByTime: Ordering[(String, RuleSummary)] = Ordering.by(e => e._2.totalTimeNs) + val q = new BoundedPriorityQueue(k)(orderingByTime) + rulesMap.asScala.foreach(q.+=) + q.toSeq.sortBy(r => -r._2.totalTimeNs) + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ab2312fdcdeef..b977fa07db5c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -102,16 +102,18 @@ class Analyzer( this(catalog, conf, conf.optimizerMaxIterations) } - def executeAndCheck(plan: LogicalPlan): LogicalPlan = AnalysisHelper.markInAnalyzer { - val analyzed = execute(plan) - try { - checkAnalysis(analyzed) - analyzed - } catch { - case e: AnalysisException => - val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed)) - ae.setStackTrace(e.getStackTrace) - throw ae + def executeAndCheck(plan: LogicalPlan, tracker: QueryPlanningTracker): LogicalPlan = { + AnalysisHelper.markInAnalyzer { + val analyzed = executeAndTrack(plan, tracker) + try { + checkAnalysis(analyzed) + analyzed + } catch { + case e: AnalysisException => + val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed)) + ae.setStackTrace(e.getStackTrace) + throw ae + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index e991a2dc7462f..cf6ff4f986399 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.rules import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide @@ -66,6 +67,17 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { */ protected def isPlanIntegral(plan: TreeType): Boolean = true + /** + * Executes the batches of rules defined by the subclass, and also tracks timing info for each + * rule using the provided tracker. + * @see [[execute]] + */ + def executeAndTrack(plan: TreeType, tracker: QueryPlanningTracker): TreeType = { + QueryPlanningTracker.withTracker(tracker) { + execute(plan) + } + } + /** * Executes the batches of rules defined by the subclass. The batches are executed serially * using the defined execution strategy. Within each batch, rules are also executed serially. @@ -74,6 +86,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { var curPlan = plan val queryExecutionMetrics = RuleExecutor.queryExecutionMeter val planChangeLogger = new PlanChangeLogger() + val tracker: Option[QueryPlanningTracker] = QueryPlanningTracker.get batches.foreach { batch => val batchStartPlan = curPlan @@ -88,8 +101,9 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { val startTime = System.nanoTime() val result = rule(plan) val runTime = System.nanoTime() - startTime + val effective = !result.fastEquals(plan) - if (!result.fastEquals(plan)) { + if (effective) { queryExecutionMetrics.incNumEffectiveExecution(rule.ruleName) queryExecutionMetrics.incTimeEffectiveExecutionBy(rule.ruleName, runTime) planChangeLogger.log(rule.ruleName, plan, result) @@ -97,6 +111,9 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { queryExecutionMetrics.incExecutionTimeBy(rule.ruleName, runTime) queryExecutionMetrics.incNumExecution(rule.ruleName) + // Record timing information using QueryPlanningTracker + tracker.foreach(_.recordRuleInvocation(rule.ruleName, runTime, effective)) + // Run the structural integrity checker against the plan after each rule. if (!isPlanIntegral(result)) { val message = s"After applying rule ${rule.ruleName} in batch ${batch.name}, " + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala new file mode 100644 index 0000000000000..f42c262dfbdd8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.SparkFunSuite + +class QueryPlanningTrackerSuite extends SparkFunSuite { + + test("phases") { + val t = new QueryPlanningTracker + t.measureTime("p1") { + Thread.sleep(1) + } + + assert(t.phases("p1") > 0) + assert(!t.phases.contains("p2")) + + val old = t.phases("p1") + + t.measureTime("p1") { + Thread.sleep(1) + } + assert(t.phases("p1") > old) + } + + test("rules") { + val t = new QueryPlanningTracker + t.recordRuleInvocation("r1", 1, effective = false) + t.recordRuleInvocation("r2", 2, effective = true) + t.recordRuleInvocation("r3", 1, effective = false) + t.recordRuleInvocation("r3", 2, effective = true) + + val rules = t.rules + + assert(rules("r1").totalTimeNs == 1) + assert(rules("r1").numInvocations == 1) + assert(rules("r1").numEffectiveInvocations == 0) + + assert(rules("r2").totalTimeNs == 2) + assert(rules("r2").numInvocations == 1) + assert(rules("r2").numEffectiveInvocations == 1) + + assert(rules("r3").totalTimeNs == 3) + assert(rules("r3").numInvocations == 2) + assert(rules("r3").numEffectiveInvocations == 1) + } + + test("topRulesByTime") { + val t = new QueryPlanningTracker + t.recordRuleInvocation("r2", 2, effective = true) + t.recordRuleInvocation("r4", 4, effective = true) + t.recordRuleInvocation("r1", 1, effective = false) + t.recordRuleInvocation("r3", 3, effective = false) + + val top = t.topRulesByTime(2) + assert(top.size == 2) + assert(top(0)._1 == "r4") + assert(top(1)._1 == "r3") + + // Don't crash when k > total size + assert(t.topRulesByTime(10).size == 4) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 3d7c91870133b..fab1b776a3c72 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -21,6 +21,7 @@ import java.net.URI import java.util.Locale import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ @@ -54,7 +55,7 @@ trait AnalysisTest extends PlanTest { expectedPlan: LogicalPlan, caseSensitive: Boolean = true): Unit = { val analyzer = getAnalyzer(caseSensitive) - val actualPlan = analyzer.executeAndCheck(inputPlan) + val actualPlan = analyzer.executeAndCheck(inputPlan, new QueryPlanningTracker) comparePlans(actualPlan, expectedPlan) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala index 8da4d7e3aa372..aa5eda8e5ba87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import java.util.TimeZone +import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -109,7 +110,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Seq(UnresolvedAlias(Multiply(unresolved_a, Literal(2))), unresolved_b, UnresolvedAlias(count(unresolved_c)))) - val resultPlan = getAnalyzer(true).executeAndCheck(originalPlan2) + val resultPlan = getAnalyzer(true).executeAndCheck(originalPlan2, new QueryPlanningTracker) val gExpressions = resultPlan.asInstanceOf[Aggregate].groupingExpressions assert(gExpressions.size == 3) val firstGroupingExprAttrName = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala index fe57c199b8744..64bd07534b19b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -34,6 +35,7 @@ class ResolvedUuidExpressionsSuite extends AnalysisTest { private lazy val uuid3 = Uuid().as('_uuid3) private lazy val uuid1Ref = uuid1.toAttribute + private val tracker = new QueryPlanningTracker private val analyzer = getAnalyzer(caseSensitive = true) private def getUuidExpressions(plan: LogicalPlan): Seq[Uuid] = { @@ -47,7 +49,7 @@ class ResolvedUuidExpressionsSuite extends AnalysisTest { test("analyzed plan sets random seed for Uuid expression") { val plan = r.select(a, uuid1) - val resolvedPlan = analyzer.executeAndCheck(plan) + val resolvedPlan = analyzer.executeAndCheck(plan, tracker) getUuidExpressions(resolvedPlan).foreach { u => assert(u.resolved) assert(u.randomSeed.isDefined) @@ -56,14 +58,14 @@ class ResolvedUuidExpressionsSuite extends AnalysisTest { test("Uuid expressions should have different random seeds") { val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3) - val resolvedPlan = analyzer.executeAndCheck(plan) + val resolvedPlan = analyzer.executeAndCheck(plan, tracker) assert(getUuidExpressions(resolvedPlan).map(_.randomSeed.get).distinct.length == 3) } test("Different analyzed plans should have different random seeds in Uuids") { val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3) - val resolvedPlan1 = analyzer.executeAndCheck(plan) - val resolvedPlan2 = analyzer.executeAndCheck(plan) + val resolvedPlan1 = analyzer.executeAndCheck(plan, tracker) + val resolvedPlan2 = analyzer.executeAndCheck(plan, tracker) val uuids1 = getUuidExpressions(resolvedPlan1) val uuids2 = getUuidExpressions(resolvedPlan2) assert(uuids1.distinct.length == 3) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 0e77ec0406257..e757921b485df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -32,6 +32,7 @@ import org.apache.spark.api.java.function._ import org.apache.spark.api.python.{PythonRDD, SerDeUtil} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.encoders._ @@ -76,6 +77,14 @@ private[sql] object Dataset { qe.assertAnalyzed() new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema)) } + + /** A variant of ofRows that allows passing in a tracker so we can track query parsing time. */ + def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan, tracker: QueryPlanningTracker) + : DataFrame = { + val qe = new QueryExecution(sparkSession, logicalPlan, tracker) + qe.assertAnalyzed() + new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema)) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 725db97df4ed1..739c6b54b4cb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -648,7 +648,11 @@ class SparkSession private( * @since 2.0.0 */ def sql(sqlText: String): DataFrame = { - Dataset.ofRows(self, sessionState.sqlParser.parsePlan(sqlText)) + val tracker = new QueryPlanningTracker + val plan = tracker.measureTime(QueryPlanningTracker.PARSING) { + sessionState.sqlParser.parsePlan(sqlText) + } + Dataset.ofRows(self, plan, tracker) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 905d035b64275..87a4ceb91aae6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SparkSession} -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.Rule @@ -43,7 +43,10 @@ import org.apache.spark.util.Utils * While this is not a public class, we should avoid changing the function names for the sake of * changing them, because a lot of developers use the feature for debugging. */ -class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { +class QueryExecution( + val sparkSession: SparkSession, + val logical: LogicalPlan, + val tracker: QueryPlanningTracker = new QueryPlanningTracker) { // TODO: Move the planner an optimizer into here from SessionState. protected def planner = sparkSession.sessionState.planner @@ -56,9 +59,9 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { } } - lazy val analyzed: LogicalPlan = { + lazy val analyzed: LogicalPlan = tracker.measureTime(QueryPlanningTracker.ANALYSIS) { SparkSession.setActiveSession(sparkSession) - sparkSession.sessionState.analyzer.executeAndCheck(logical) + sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker) } lazy val withCachedData: LogicalPlan = { @@ -67,9 +70,11 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { sparkSession.sharedState.cacheManager.useCachedData(analyzed) } - lazy val optimizedPlan: LogicalPlan = sparkSession.sessionState.optimizer.execute(withCachedData) + lazy val optimizedPlan: LogicalPlan = tracker.measureTime(QueryPlanningTracker.OPTIMIZATION) { + sparkSession.sessionState.optimizer.executeAndTrack(withCachedData, tracker) + } - lazy val sparkPlan: SparkPlan = { + lazy val sparkPlan: SparkPlan = tracker.measureTime(QueryPlanningTracker.PLANNING) { SparkSession.setActiveSession(sparkSession) // TODO: We use next(), i.e. take the first plan returned by the planner, here for now, // but we will implement to choose the best plan. @@ -78,7 +83,9 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) + lazy val executedPlan: SparkPlan = tracker.measureTime(QueryPlanningTracker.PLANNING) { + prepareForExecution(sparkPlan) + } /** Internal version of the RDD. Avoids copies and has no schema */ lazy val toRdd: RDD[InternalRow] = executedPlan.execute() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala new file mode 100644 index 0000000000000..0af4c85400e9e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.test.SharedSQLContext + +class QueryPlanningTrackerEndToEndSuite extends SharedSQLContext { + + test("programmatic API") { + val df = spark.range(1000).selectExpr("count(*)") + df.collect() + val tracker = df.queryExecution.tracker + + assert(tracker.phases.size == 3) + assert(tracker.phases("analysis") > 0) + assert(tracker.phases("optimization") > 0) + assert(tracker.phases("planning") > 0) + + assert(tracker.rules.nonEmpty) + } + + test("sql") { + val df = spark.sql("select * from range(1)") + df.collect() + + val tracker = df.queryExecution.tracker + + assert(tracker.phases.size == 4) + assert(tracker.phases("parsing") > 0) + assert(tracker.phases("analysis") > 0) + assert(tracker.phases("optimization") > 0) + assert(tracker.phases("planning") > 0) + + assert(tracker.rules.nonEmpty) + } + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 634b3db19ec27..3508affda241a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -33,9 +33,9 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, ExternalCatalogWithListener} +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} @@ -219,6 +219,16 @@ private[hive] class TestHiveSparkSession( sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.newSession() } + /** + * This is a temporary hack to override SparkSession.sql so we can still use the version of + * Dataset.ofRows that creates a TestHiveQueryExecution (rather than a normal QueryExecution + * which wouldn't load all the test tables). + */ + override def sql(sqlText: String): DataFrame = { + val plan = sessionState.sqlParser.parsePlan(sqlText) + Dataset.ofRows(self, plan) + } + override def newSession(): TestHiveSparkSession = { new TestHiveSparkSession(sc, Some(sharedState), None, loadTestTables) } @@ -586,7 +596,7 @@ private[hive] class TestHiveQueryExecution( logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(sparkSession.loadTestTable) // Proceed with analysis. - sparkSession.sessionState.analyzer.executeAndCheck(logical) + sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker) } } From 81550b38e43fb20f89f529d2127575c71a54a538 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 21 Nov 2018 11:16:54 -0800 Subject: [PATCH 2107/2461] [SPARK-26066][SQL] Move truncatedString to sql/catalyst and add spark.sql.debug.maxToStringFields conf ## What changes were proposed in this pull request? In the PR, I propose: - new SQL config `spark.sql.debug.maxToStringFields` to control maximum number fields up to which `truncatedString` cuts its input sequences. - Moving `truncatedString` out of `core` to `sql/catalyst` because it is used only in the `sql/catalyst` packages for restricting number of fields converted to strings from `TreeNode` and expressions of`StructType`. ## How was this patch tested? Added a test to `QueryExecutionSuite` to check that `spark.sql.debug.maxToStringFields` impacts to behavior of `truncatedString`. Closes #23039 from MaxGekk/truncated-string-catalyst. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: Dongjoon Hyun --- .../scala/org/apache/spark/util/Utils.scala | 48 ------------------- .../org/apache/spark/util/UtilsSuite.scala | 8 ---- .../sql/catalyst/expressions/Expression.scala | 4 +- .../plans/logical/basicLogicalOperators.scala | 4 +- .../spark/sql/catalyst/trees/TreeNode.scala | 10 ++-- .../spark/sql/catalyst/util/package.scala | 37 +++++++++++++- .../apache/spark/sql/internal/SQLConf.scala | 9 ++++ .../apache/spark/sql/types/StructType.scala | 4 +- .../org/apache/spark/sql/util/UtilSuite.scala | 31 ++++++++++++ .../sql/execution/DataSourceScanExec.scala | 5 +- .../spark/sql/execution/ExistingRDD.scala | 4 +- .../spark/sql/execution/QueryExecution.scala | 3 +- .../aggregate/HashAggregateExec.scala | 7 +-- .../aggregate/ObjectHashAggregateExec.scala | 8 ++-- .../aggregate/SortAggregateExec.scala | 8 ++-- .../execution/columnar/InMemoryRelation.scala | 5 +- .../datasources/LogicalRelation.scala | 4 +- .../datasources/jdbc/JDBCRelation.scala | 5 +- .../v2/DataSourceV2StringFormat.scala | 5 +- .../apache/spark/sql/execution/limit.scala | 6 +-- .../streaming/MicroBatchExecution.scala | 7 +-- .../continuous/ContinuousExecution.scala | 7 +-- .../sql/execution/streaming/memory.scala | 4 +- .../sql/execution/QueryExecutionSuite.scala | 26 ++++++++++ 24 files changed, 156 insertions(+), 103 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/util/UtilSuite.scala diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 743fd5d75b2db..227c9e734f0af 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -31,7 +31,6 @@ import java.security.SecureRandom import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ import java.util.concurrent.TimeUnit.NANOSECONDS -import java.util.concurrent.atomic.AtomicBoolean import java.util.zip.GZIPInputStream import scala.annotation.tailrec @@ -93,53 +92,6 @@ private[spark] object Utils extends Logging { private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @volatile private var localRootDirs: Array[String] = null - /** - * The performance overhead of creating and logging strings for wide schemas can be large. To - * limit the impact, we bound the number of fields to include by default. This can be overridden - * by setting the 'spark.debug.maxToStringFields' conf in SparkEnv. - */ - val DEFAULT_MAX_TO_STRING_FIELDS = 25 - - private[spark] def maxNumToStringFields = { - if (SparkEnv.get != null) { - SparkEnv.get.conf.getInt("spark.debug.maxToStringFields", DEFAULT_MAX_TO_STRING_FIELDS) - } else { - DEFAULT_MAX_TO_STRING_FIELDS - } - } - - /** Whether we have warned about plan string truncation yet. */ - private val truncationWarningPrinted = new AtomicBoolean(false) - - /** - * Format a sequence with semantics similar to calling .mkString(). Any elements beyond - * maxNumToStringFields will be dropped and replaced by a "... N more fields" placeholder. - * - * @return the trimmed and formatted string. - */ - def truncatedString[T]( - seq: Seq[T], - start: String, - sep: String, - end: String, - maxNumFields: Int = maxNumToStringFields): String = { - if (seq.length > maxNumFields) { - if (truncationWarningPrinted.compareAndSet(false, true)) { - logWarning( - "Truncated the string representation of a plan since it was too large. This " + - "behavior can be adjusted by setting 'spark.debug.maxToStringFields' in SparkEnv.conf.") - } - val numFields = math.max(0, maxNumFields - 1) - seq.take(numFields).mkString( - start, sep, sep + "... " + (seq.length - numFields) + " more fields" + end) - } else { - seq.mkString(start, sep, end) - } - } - - /** Shorthand for calling truncatedString() without start or end strings. */ - def truncatedString[T](seq: Seq[T], sep: String): String = truncatedString(seq, "", sep, "") - /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 5293645cab058..f5e912b50d1ab 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -45,14 +45,6 @@ import org.apache.spark.scheduler.SparkListener class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { - test("truncatedString") { - assert(Utils.truncatedString(Nil, "[", ", ", "]", 2) == "[]") - assert(Utils.truncatedString(Seq(1, 2), "[", ", ", "]", 2) == "[1, 2]") - assert(Utils.truncatedString(Seq(1, 2, 3), "[", ", ", "]", 2) == "[1, ... 2 more fields]") - assert(Utils.truncatedString(Seq(1, 2, 3), "[", ", ", "]", -5) == "[, ... 3 more fields]") - assert(Utils.truncatedString(Seq(1, 2, 3), ", ") == "1, 2, 3") - } - test("timeConversion") { // Test -1 assert(Utils.timeStringAsSeconds("-1") === -1) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 141fcffcb6fab..d51b11024a09d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the basic expression abstract classes in Catalyst. @@ -237,7 +237,7 @@ abstract class Expression extends TreeNode[Expression] { override def simpleString: String = toString - override def toString: String = prettyName + Utils.truncatedString( + override def toString: String = prettyName + truncatedString( flatArguments.toSeq, "(", ", ", ")") /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index f09c5ceefed13..07fa17b233a47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -24,8 +24,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils import org.apache.spark.util.random.RandomSampler /** @@ -485,7 +485,7 @@ case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)]) override def output: Seq[Attribute] = child.output override def simpleString: String = { - val cteAliases = Utils.truncatedString(cteRelations.map(_._1), "[", ", ", "]") + val cteAliases = truncatedString(cteRelations.map(_._1), "[", ", ", "]") s"CTE $cteAliases" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 1027216165005..2e9f9f53e94ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -37,9 +37,9 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.Utils /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */ private class MutableInt(var i: Int) @@ -440,10 +440,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case tn: TreeNode[_] => tn.simpleString :: Nil case seq: Seq[Any] if seq.toSet.subsetOf(allChildren.asInstanceOf[Set[Any]]) => Nil case iter: Iterable[_] if iter.isEmpty => Nil - case seq: Seq[_] => Utils.truncatedString(seq, "[", ", ", "]") :: Nil - case set: Set[_] => Utils.truncatedString(set.toSeq, "{", ", ", "}") :: Nil + case seq: Seq[_] => truncatedString(seq, "[", ", ", "]") :: Nil + case set: Set[_] => truncatedString(set.toSeq, "{", ", ", "}") :: Nil case array: Array[_] if array.isEmpty => Nil - case array: Array[_] => Utils.truncatedString(array, "[", ", ", "]") :: Nil + case array: Array[_] => truncatedString(array, "[", ", ", "]") :: Nil case null => Nil case None => Nil case Some(null) => Nil @@ -664,7 +664,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { t.forall(_.isInstanceOf[Partitioning]) || t.forall(_.isInstanceOf[DataType]) => JArray(t.map(parseToJson).toList) case t: Seq[_] if t.length > 0 && t.head.isInstanceOf[String] => - JString(Utils.truncatedString(t, "[", ", ", "]")) + JString(truncatedString(t, "[", ", ", "]")) case t: Seq[_] => JNull case m: Map[_, _] => JNull // if it's a scala object, we can simply keep the full class path. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 0978e92dd4f72..277584b20dcd2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -19,13 +19,16 @@ package org.apache.spark.sql.catalyst import java.io._ import java.nio.charset.StandardCharsets +import java.util.concurrent.atomic.AtomicBoolean +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{NumericType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -package object util { +package object util extends Logging { /** Silences output to stderr or stdout for the duration of f */ def quietly[A](f: => A): A = { @@ -167,6 +170,38 @@ package object util { builder.toString() } + /** Whether we have warned about plan string truncation yet. */ + private val truncationWarningPrinted = new AtomicBoolean(false) + + /** + * Format a sequence with semantics similar to calling .mkString(). Any elements beyond + * maxNumToStringFields will be dropped and replaced by a "... N more fields" placeholder. + * + * @return the trimmed and formatted string. + */ + def truncatedString[T]( + seq: Seq[T], + start: String, + sep: String, + end: String, + maxNumFields: Int = SQLConf.get.maxToStringFields): String = { + if (seq.length > maxNumFields) { + if (truncationWarningPrinted.compareAndSet(false, true)) { + logWarning( + "Truncated the string representation of a plan since it was too large. This " + + s"behavior can be adjusted by setting '${SQLConf.MAX_TO_STRING_FIELDS.key}'.") + } + val numFields = math.max(0, maxNumFields - 1) + seq.take(numFields).mkString( + start, sep, sep + "... " + (seq.length - numFields) + " more fields" + end) + } else { + seq.mkString(start, sep, end) + } + } + + /** Shorthand for calling truncatedString() without start or end strings. */ + def truncatedString[T](seq: Seq[T], sep: String): String = truncatedString(seq, "", sep, "") + /* FIX ME implicit class debugLogging(a: Any) { def debugLogging() { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 518115dafd011..cc0e9727812db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1594,6 +1594,13 @@ object SQLConf { "WHERE, which does not follow SQL standard.") .booleanConf .createWithDefault(false) + + val MAX_TO_STRING_FIELDS = buildConf("spark.sql.debug.maxToStringFields") + .doc("Maximum number of fields of sequence-like entries can be converted to strings " + + "in debug output. Any elements beyond the limit will be dropped and replaced by a" + + """ "... N more fields" placeholder.""") + .intConf + .createWithDefault(25) } /** @@ -2009,6 +2016,8 @@ class SQLConf extends Serializable with Logging { def integralDivideReturnLong: Boolean = getConf(SQLConf.LEGACY_INTEGRALDIVIDE_RETURN_LONG) + def maxToStringFields: Int = getConf(SQLConf.MAX_TO_STRING_FIELDS) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 3bef75d5bdb6e..6e8bbde7787a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -27,7 +27,7 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser} -import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.catalyst.util.{quoteIdentifier, truncatedString} import org.apache.spark.util.Utils /** @@ -346,7 +346,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru override def simpleString: String = { val fieldTypes = fields.view.map(field => s"${field.name}:${field.dataType.simpleString}") - Utils.truncatedString(fieldTypes, "struct<", ",", ">") + truncatedString(fieldTypes, "struct<", ",", ">") } override def catalogString: String = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/UtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/UtilSuite.scala new file mode 100644 index 0000000000000..9c162026942f6 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/UtilSuite.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.truncatedString + +class UtilSuite extends SparkFunSuite { + test("truncatedString") { + assert(truncatedString(Nil, "[", ", ", "]", 2) == "[]") + assert(truncatedString(Seq(1, 2), "[", ", ", "]", 2) == "[1, 2]") + assert(truncatedString(Seq(1, 2, 3), "[", ", ", "]", 2) == "[1, ... 2 more fields]") + assert(truncatedString(Seq(1, 2, 3), "[", ", ", "]", -5) == "[, ... 3 more fields]") + assert(truncatedString(Seq(1, 2, 3), ", ") == "1, 2, 3") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index a9b18ab57237d..77e381ef6e6b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -56,8 +57,8 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport { case (key, value) => key + ": " + StringUtils.abbreviate(redact(value), 100) } - val metadataStr = Utils.truncatedString(metadataEntries, " ", ", ", "") - s"$nodeNamePrefix$nodeName${Utils.truncatedString(output, "[", ",", "]")}$metadataStr" + val metadataStr = truncatedString(metadataEntries, " ", ", ", "") + s"$nodeNamePrefix$nodeName${truncatedString(output, "[", ",", "]")}$metadataStr" } override def verboseString: String = redact(super.verboseString) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 2962becb64e88..9f67d556af362 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -24,9 +24,9 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType -import org.apache.spark.util.Utils object RDDConversions { def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = { @@ -197,6 +197,6 @@ case class RDDScanExec( } override def simpleString: String = { - s"$nodeName${Utils.truncatedString(output, "[", ",", "]")}" + s"$nodeName${truncatedString(output, "[", ",", "]")}" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 87a4ceb91aae6..cfb5e43207b03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _} @@ -213,7 +214,7 @@ class QueryExecution( writer.write("== Parsed Logical Plan ==\n") writeOrError(writer)(logical.treeString(_, verbose, addSuffix)) writer.write("\n== Analyzed Logical Plan ==\n") - val analyzedOutput = stringOrError(Utils.truncatedString( + val analyzedOutput = stringOrError(truncatedString( analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ")) writer.write(analyzedOutput) writer.write("\n") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 08dcdf33fb8f2..4827f838fc514 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.vectorized.MutableColumnarRow @@ -930,9 +931,9 @@ case class HashAggregateExec( testFallbackStartsAt match { case None => - val keyString = Utils.truncatedString(groupingExpressions, "[", ", ", "]") - val functionString = Utils.truncatedString(allAggregateExpressions, "[", ", ", "]") - val outputString = Utils.truncatedString(output, "[", ", ", "]") + val keyString = truncatedString(groupingExpressions, "[", ", ", "]") + val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]") + val outputString = truncatedString(output, "[", ", ", "]") if (verbose) { s"HashAggregate(keys=$keyString, functions=$functionString, output=$outputString)" } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 66955b8ef723c..7145bb03028d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.Utils /** * A hash-based aggregate operator that supports [[TypedImperativeAggregate]] functions that may @@ -143,9 +143,9 @@ case class ObjectHashAggregateExec( private def toString(verbose: Boolean): String = { val allAggregateExpressions = aggregateExpressions - val keyString = Utils.truncatedString(groupingExpressions, "[", ", ", "]") - val functionString = Utils.truncatedString(allAggregateExpressions, "[", ", ", "]") - val outputString = Utils.truncatedString(output, "[", ", ", "]") + val keyString = truncatedString(groupingExpressions, "[", ", ", "]") + val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]") + val outputString = truncatedString(output, "[", ", ", "]") if (verbose) { s"ObjectHashAggregate(keys=$keyString, functions=$functionString, output=$outputString)" } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index fc87de2c52e41..d732b905dcdd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.Utils /** * Sort-based aggregate operator. @@ -114,9 +114,9 @@ case class SortAggregateExec( private def toString(verbose: Boolean): String = { val allAggregateExpressions = aggregateExpressions - val keyString = Utils.truncatedString(groupingExpressions, "[", ", ", "]") - val functionString = Utils.truncatedString(allAggregateExpressions, "[", ", ", "]") - val outputString = Utils.truncatedString(output, "[", ", ", "]") + val keyString = truncatedString(groupingExpressions, "[", ", ", "]") + val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]") + val outputString = truncatedString(output, "[", ", ", "]") if (verbose) { s"SortAggregate(key=$keyString, functions=$functionString, output=$outputString)" } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 3b6588587c35a..73eb65f84489c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -27,9 +27,10 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{LongAccumulator, Utils} +import org.apache.spark.util.LongAccumulator /** @@ -209,5 +210,5 @@ case class InMemoryRelation( override protected def otherCopyArgs: Seq[AnyRef] = Seq(statsOfPlanToCache) override def simpleString: String = - s"InMemoryRelation [${Utils.truncatedString(output, ", ")}], ${cacheBuilder.storageLevel}" + s"InMemoryRelation [${truncatedString(output, ", ")}], ${cacheBuilder.storageLevel}" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 8d715f6342988..1023572d19e2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -21,8 +21,8 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.sources.BaseRelation -import org.apache.spark.util.Utils /** * Used to link a [[BaseRelation]] in to a logical query plan. @@ -63,7 +63,7 @@ case class LogicalRelation( case _ => // Do nothing. } - override def simpleString: String = s"Relation[${Utils.truncatedString(output, ",")}] $relation" + override def simpleString: String = s"Relation[${truncatedString(output, ",")}] $relation" } object LogicalRelation { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index f15014442e3fb..51c385e25bee3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -27,10 +27,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, DateType, NumericType, StructType, TimestampType} -import org.apache.spark.util.Utils /** * Instructions on how to partition the table among workers. @@ -159,8 +159,9 @@ private[sql] object JDBCRelation extends Logging { val column = schema.find { f => resolver(f.name, columnName) || resolver(dialect.quoteIdentifier(f.name), columnName) }.getOrElse { + val maxNumToStringFields = SQLConf.get.maxToStringFields throw new AnalysisException(s"User-defined partition column $columnName not " + - s"found in the JDBC relation: ${schema.simpleString(Utils.maxNumToStringFields)}") + s"found in the JDBC relation: ${schema.simpleString(maxNumToStringFields)}") } column.dataType match { case _: NumericType | DateType | TimestampType => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala index 97e6c6d702acb..e829f621b4ea3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.util.Utils @@ -72,10 +73,10 @@ trait DataSourceV2StringFormat { }.mkString("[", ",", "]") } - val outputStr = Utils.truncatedString(output, "[", ", ", "]") + val outputStr = truncatedString(output, "[", ", ", "]") val entriesStr = if (entries.nonEmpty) { - Utils.truncatedString(entries.map { + truncatedString(entries.map { case (key, value) => key + ": " + StringUtils.abbreviate(value, 100) }, " (", ", ", ")") } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 9bfe1a79fc1e1..90dafcf535914 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, LazilyGeneratedOrdering} import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec -import org.apache.spark.util.Utils /** * Take the first `limit` elements and collect them to a single partition. @@ -177,8 +177,8 @@ case class TakeOrderedAndProjectExec( override def outputPartitioning: Partitioning = SinglePartition override def simpleString: String = { - val orderByString = Utils.truncatedString(sortOrder, "[", ",", "]") - val outputString = Utils.truncatedString(output, "[", ",", "]") + val orderByString = truncatedString(sortOrder, "[", ",", "]") + val outputString = truncatedString(output, "[", ",", "]") s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 2cac86599ef19..5defca391a355 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -24,13 +24,14 @@ import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWritSupport, RateControlMicroBatchReadSupport} import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} -import org.apache.spark.util.{Clock, Utils} +import org.apache.spark.util.Clock class MicroBatchExecution( sparkSession: SparkSession, @@ -475,8 +476,8 @@ class MicroBatchExecution( case StreamingExecutionRelation(source, output) => newData.get(source).map { dataPlan => assert(output.size == dataPlan.output.size, - s"Invalid batch: ${Utils.truncatedString(output, ",")} != " + - s"${Utils.truncatedString(dataPlan.output, ",")}") + s"Invalid batch: ${truncatedString(output, ",")} != " + + s"${truncatedString(dataPlan.output, ",")}") val aliases = output.zip(dataPlan.output).map { case (to, from) => Alias(from, to.name)(exprId = to.exprId, explicitMetadata = Some(from.metadata)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 4a7df731da67d..1eab55122e84b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -28,6 +28,7 @@ import org.apache.spark.SparkEnv import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, StreamingDataSourceV2Relation} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} @@ -35,7 +36,7 @@ import org.apache.spark.sql.sources.v2 import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, StreamingWriteSupportProvider} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} -import org.apache.spark.util.{Clock, Utils} +import org.apache.spark.util.Clock class ContinuousExecution( sparkSession: SparkSession, @@ -164,8 +165,8 @@ class ContinuousExecution( val newOutput = readSupport.fullSchema().toAttributes assert(output.size == newOutput.size, - s"Invalid reader: ${Utils.truncatedString(output, ",")} != " + - s"${Utils.truncatedString(newOutput, ",")}") + s"Invalid reader: ${truncatedString(output, ",")} != " + + s"${truncatedString(newOutput, ",")}") replacements ++= output.zip(newOutput) val loggedOffset = offsets.offsets(0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index adf52aba21a04..daee089f3871d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -31,11 +31,11 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils object MemoryStream { protected val currentBlockId = new AtomicInteger(0) @@ -117,7 +117,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } } - override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]" + override def toString: String = s"MemoryStream[${truncatedString(output, ",")}]" override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index a5922d7c825db..0c47a2040f171 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -20,9 +20,20 @@ import scala.io.Source import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext +case class QueryExecutionTestRecord( + c0: Int, c1: Int, c2: Int, c3: Int, c4: Int, + c5: Int, c6: Int, c7: Int, c8: Int, c9: Int, + c10: Int, c11: Int, c12: Int, c13: Int, c14: Int, + c15: Int, c16: Int, c17: Int, c18: Int, c19: Int, + c20: Int, c21: Int, c22: Int, c23: Int, c24: Int, + c25: Int, c26: Int) + class QueryExecutionSuite extends SharedSQLContext { + import testImplicits._ + def checkDumpedPlans(path: String, expected: Int): Unit = { assert(Source.fromFile(path).getLines.toList .takeWhile(_ != "== Whole Stage Codegen ==") == List( @@ -80,6 +91,21 @@ class QueryExecutionSuite extends SharedSQLContext { assert(exception.getMessage.contains("Illegal character in scheme name")) } + test("limit number of fields by sql config") { + def relationPlans: String = { + val ds = spark.createDataset(Seq(QueryExecutionTestRecord( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26))) + ds.queryExecution.toString + } + withSQLConf(SQLConf.MAX_TO_STRING_FIELDS.key -> "26") { + assert(relationPlans.contains("more fields")) + } + withSQLConf(SQLConf.MAX_TO_STRING_FIELDS.key -> "27") { + assert(!relationPlans.contains("more fields")) + } + } + test("toString() exception/error handling") { spark.experimental.extraStrategies = Seq( new SparkStrategy { From 4aa9ccbde7870fb2750712e9e38e6aad740e0770 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 21 Nov 2018 17:03:57 -0600 Subject: [PATCH 2108/2461] [SPARK-26127][ML] Remove deprecated setters from tree regression and classification models ## What changes were proposed in this pull request? The setter methods are deprecated since 2.1 for the models of regression and classification using trees. The deprecation was stating that the method would have been removed in 3.0. Hence the PR removes the deprecated method. ## How was this patch tested? NA Closes #23093 from mgaido91/SPARK-26127. Authored-by: Marco Gaido Signed-off-by: Sean Owen --- .../DecisionTreeClassifier.scala | 18 +-- .../ml/classification/GBTClassifier.scala | 26 ++--- .../RandomForestClassifier.scala | 24 ++-- .../ml/regression/DecisionTreeRegressor.scala | 18 +-- .../spark/ml/regression/GBTRegressor.scala | 27 +++-- .../ml/regression/RandomForestRegressor.scala | 24 ++-- .../org/apache/spark/ml/tree/treeParams.scala | 105 ------------------ project/MimaExcludes.scala | 74 +++++++++++- 8 files changed, 138 insertions(+), 178 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 6648e78d8eafa..bcf89766b0873 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -55,27 +55,27 @@ class DecisionTreeClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -87,15 +87,15 @@ class DecisionTreeClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = set(impurity, value) + def setImpurity(value: String): this.type = set(impurity, value) /** @group setParam */ @Since("1.6.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) override protected def train( dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr => diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 62c6bdbdeb285..fab8155add5a8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -69,27 +69,27 @@ class GBTClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -101,7 +101,7 @@ class GBTClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** * The impurity setting is ignored for GBT models. @@ -110,7 +110,7 @@ class GBTClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = { + def setImpurity(value: String): this.type = { logWarning("GBTClassifier.setImpurity should NOT be used") this } @@ -119,25 +119,25 @@ class GBTClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) // Parameters from GBTParams: /** @group setParam */ @Since("1.4.0") - override def setMaxIter(value: Int): this.type = set(maxIter, value) + def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ @Since("1.4.0") - override def setStepSize(value: Double): this.type = set(stepSize, value) + def setStepSize(value: Double): this.type = set(stepSize, value) /** @group setParam */ @Since("2.3.0") - override def setFeatureSubsetStrategy(value: String): this.type = + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) // Parameters from GBTClassifierParams: diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 57132381b6474..05fff8885fbf2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -57,27 +57,27 @@ class RandomForestClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -89,31 +89,31 @@ class RandomForestClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = set(impurity, value) + def setImpurity(value: String): this.type = set(impurity, value) // Parameters from TreeEnsembleParams: /** @group setParam */ @Since("1.4.0") - override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) // Parameters from RandomForestParams: /** @group setParam */ @Since("1.4.0") - override def setNumTrees(value: Int): this.type = set(numTrees, value) + def setNumTrees(value: Int): this.type = set(numTrees, value) /** @group setParam */ @Since("1.4.0") - override def setFeatureSubsetStrategy(value: String): this.type = + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) override protected def train( diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index c9de85de42fa5..faadc4d7b4ccc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -54,27 +54,27 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S // Override parameter setters from parent trait for Java API compatibility. /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -86,15 +86,15 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = set(impurity, value) + def setImpurity(value: String): this.type = set(impurity, value) /** @group setParam */ @Since("1.6.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) /** @group setParam */ @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 07f88d8d5f84d..186fa2399af05 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -34,7 +34,6 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -69,27 +68,27 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -101,7 +100,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** * The impurity setting is ignored for GBT models. @@ -110,7 +109,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) * @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = { + def setImpurity(value: String): this.type = { logWarning("GBTRegressor.setImpurity should NOT be used") this } @@ -119,21 +118,21 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) /** @group setParam */ @Since("1.4.0") - override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) // Parameters from GBTParams: /** @group setParam */ @Since("1.4.0") - override def setMaxIter(value: Int): this.type = set(maxIter, value) + def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ @Since("1.4.0") - override def setStepSize(value: Double): this.type = set(stepSize, value) + def setStepSize(value: Double): this.type = set(stepSize, value) // Parameters from GBTRegressorParams: @@ -143,7 +142,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) /** @group setParam */ @Since("2.3.0") - override def setFeatureSubsetStrategy(value: String): this.type = + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) /** @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 66d57ad6c4348..7f5e668ca71db 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -56,27 +56,27 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -88,31 +88,31 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = set(impurity, value) + def setImpurity(value: String): this.type = set(impurity, value) // Parameters from TreeEnsembleParams: /** @group setParam */ @Since("1.4.0") - override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) // Parameters from RandomForestParams: /** @group setParam */ @Since("1.4.0") - override def setNumTrees(value: Int): this.type = set(numTrees, value) + def setNumTrees(value: Int): this.type = set(numTrees, value) /** @group setParam */ @Since("1.4.0") - override def setFeatureSubsetStrategy(value: String): this.type = + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) override protected def train( diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index f1e3836ebe476..c06c68d44ae1c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -110,80 +110,24 @@ private[ml] trait DecisionTreeParams extends PredictorParams setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setMaxDepth(value: Int): this.type = set(maxDepth, value) - /** @group getParam */ final def getMaxDepth: Int = $(maxDepth) - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setMaxBins(value: Int): this.type = set(maxBins, value) - /** @group getParam */ final def getMaxBins: Int = $(maxBins) - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) - /** @group getParam */ final def getMinInstancesPerNode: Int = $(minInstancesPerNode) - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) - /** @group getParam */ final def getMinInfoGain: Double = $(minInfoGain) - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setSeed(value: Long): this.type = set(seed, value) - - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group expertSetParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) - /** @group expertGetParam */ final def getMaxMemoryInMB: Int = $(maxMemoryInMB) - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group expertSetParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) - /** @group expertGetParam */ final def getCacheNodeIds: Boolean = $(cacheNodeIds) - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) - /** (private[ml]) Create a Strategy instance to use with the old API. */ private[ml] def getOldStrategy( categoricalFeatures: Map[Int, Int], @@ -226,13 +170,6 @@ private[ml] trait TreeClassifierParams extends Params { setDefault(impurity -> "gini") - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setImpurity(value: String): this.type = set(impurity, value) - /** @group getParam */ final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) @@ -273,13 +210,6 @@ private[ml] trait HasVarianceImpurity extends Params { setDefault(impurity -> "variance") - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setImpurity(value: String): this.type = set(impurity, value) - /** @group getParam */ final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) @@ -346,13 +276,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams { setDefault(subsamplingRate -> 1.0) - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) - /** @group getParam */ final def getSubsamplingRate: Double = $(subsamplingRate) @@ -406,13 +329,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams { setDefault(featureSubsetStrategy -> "auto") - /** - * @deprecated This method is deprecated and will be removed in 3.0.0 - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) - /** @group getParam */ final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT) } @@ -440,13 +356,6 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { setDefault(numTrees -> 20) - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setNumTrees(value: Int): this.type = set(numTrees, value) - /** @group getParam */ final def getNumTrees: Int = $(numTrees) } @@ -491,13 +400,6 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS @Since("2.4.0") final def getValidationTol: Double = $(validationTol) - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setMaxIter(value: Int): this.type = set(maxIter, value) - /** * Param for Step size (a.k.a. learning rate) in interval (0, 1] for shrinking * the contribution of each estimator. @@ -508,13 +410,6 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS "(a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.", ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setStepSize(value: Double): this.type = set(stepSize, value) - setDefault(maxIter -> 20, stepSize -> 0.1, validationTol -> 0.01) setDefault(featureSubsetStrategy -> "all") diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b750535e8a70b..9089c7d9ffc70 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,76 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( + // [SPARK-26127] Remove deprecated setters from tree regression and classification models + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setCheckpointInterval"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setSubsamplingRate"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxIter"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setCheckpointInterval"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setStepSize"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setFeatureSubsetStrategy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setSubsamplingRate"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setCheckpointInterval"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setNumTrees"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setCheckpointInterval"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setSubsamplingRate"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxIter"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setCheckpointInterval"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setStepSize"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setFeatureSubsetStrategy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setSubsamplingRate"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setCheckpointInterval"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setNumTrees"), + // [SPARK-26124] Update plugins, including MiMa ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsPushDownRequiredColumns.build"), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics.fullSchema"), @@ -50,15 +120,11 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.feature.LabeledPointBeanInfo"), // [SPARK-25959] GBTClassifier picks wrong impurity stats on loading - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setImpurity"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setImpurity"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setImpurity"), // [SPARK-25908][CORE][SQL] Remove old deprecated items in Spark 3 ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.BarrierTaskContext.isRunningLocally"), From 9b48107f9c84631e0ddaf0f2223296a3cbc16f83 Mon Sep 17 00:00:00 2001 From: Nagaram Prasad Addepally Date: Wed, 21 Nov 2018 15:51:37 -0800 Subject: [PATCH 2109/2461] [SPARK-25957][K8S] Make building alternate language binding docker images optional ## What changes were proposed in this pull request? bin/docker-image-tool.sh tries to build all docker images (JVM, PySpark and SparkR) by default. But not all spark distributions are built with SparkR and hence this script will fail on such distros. With this change, we make building alternate language binding docker images (PySpark and SparkR) optional. User has to specify dockerfile for those language bindings using -p and -R flags accordingly, to build the binding docker images. ## How was this patch tested? Tested following scenarios. *bin/docker-image-tool.sh -r -t build* --> Builds only JVM docker image (default behavior) *bin/docker-image-tool.sh -r -t -p kubernetes/dockerfiles/spark/bindings/python/Dockerfile build* --> Builds both JVM and PySpark docker images *bin/docker-image-tool.sh -r -t -p kubernetes/dockerfiles/spark/bindings/python/Dockerfile -R kubernetes/dockerfiles/spark/bindings/R/Dockerfile build* --> Builds JVM, PySpark and SparkR docker images. Author: Nagaram Prasad Addepally Closes #23053 from ramaddepally/SPARK-25957. --- bin/docker-image-tool.sh | 63 +++++++++++-------- docs/running-on-kubernetes.md | 12 ++++ .../scripts/setup-integration-test-env.sh | 12 +++- 3 files changed, 59 insertions(+), 28 deletions(-) diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index aa5d847f4be2f..e51201a77cb5d 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -41,6 +41,18 @@ function image_ref { echo "$image" } +function docker_push { + local image_name="$1" + if [ ! -z $(docker images -q "$(image_ref ${image_name})") ]; then + docker push "$(image_ref ${image_name})" + if [ $? -ne 0 ]; then + error "Failed to push $image_name Docker image." + fi + else + echo "$(image_ref ${image_name}) image not found. Skipping push for this image." + fi +} + function build { local BUILD_ARGS local IMG_PATH @@ -92,8 +104,8 @@ function build { base_img=$(image_ref spark) ) local BASEDOCKERFILE=${BASEDOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} - local PYDOCKERFILE=${PYDOCKERFILE:-"$IMG_PATH/spark/bindings/python/Dockerfile"} - local RDOCKERFILE=${RDOCKERFILE:-"$IMG_PATH/spark/bindings/R/Dockerfile"} + local PYDOCKERFILE=${PYDOCKERFILE:-false} + local RDOCKERFILE=${RDOCKERFILE:-false} docker build $NOCACHEARG "${BUILD_ARGS[@]}" \ -t $(image_ref spark) \ @@ -102,33 +114,29 @@ function build { error "Failed to build Spark JVM Docker image, please refer to Docker build output for details." fi - docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ - -t $(image_ref spark-py) \ - -f "$PYDOCKERFILE" . + if [ "${PYDOCKERFILE}" != "false" ]; then + docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ + -t $(image_ref spark-py) \ + -f "$PYDOCKERFILE" . + if [ $? -ne 0 ]; then + error "Failed to build PySpark Docker image, please refer to Docker build output for details." + fi + fi + + if [ "${RDOCKERFILE}" != "false" ]; then + docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ + -t $(image_ref spark-r) \ + -f "$RDOCKERFILE" . if [ $? -ne 0 ]; then - error "Failed to build PySpark Docker image, please refer to Docker build output for details." + error "Failed to build SparkR Docker image, please refer to Docker build output for details." fi - docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ - -t $(image_ref spark-r) \ - -f "$RDOCKERFILE" . - if [ $? -ne 0 ]; then - error "Failed to build SparkR Docker image, please refer to Docker build output for details." fi } function push { - docker push "$(image_ref spark)" - if [ $? -ne 0 ]; then - error "Failed to push Spark JVM Docker image." - fi - docker push "$(image_ref spark-py)" - if [ $? -ne 0 ]; then - error "Failed to push PySpark Docker image." - fi - docker push "$(image_ref spark-r)" - if [ $? -ne 0 ]; then - error "Failed to push SparkR Docker image." - fi + docker_push "spark" + docker_push "spark-py" + docker_push "spark-r" } function usage { @@ -143,8 +151,10 @@ Commands: Options: -f file Dockerfile to build for JVM based Jobs. By default builds the Dockerfile shipped with Spark. - -p file Dockerfile to build for PySpark Jobs. Builds Python dependencies and ships with Spark. - -R file Dockerfile to build for SparkR Jobs. Builds R dependencies and ships with Spark. + -p file (Optional) Dockerfile to build for PySpark Jobs. Builds Python dependencies and ships with Spark. + Skips building PySpark docker image if not specified. + -R file (Optional) Dockerfile to build for SparkR Jobs. Builds R dependencies and ships with Spark. + Skips building SparkR docker image if not specified. -r repo Repository address. -t tag Tag to apply to the built image, or to identify the image to be pushed. -m Use minikube's Docker daemon. @@ -164,6 +174,9 @@ Examples: - Build image in minikube with tag "testing" $0 -m -t testing build + - Build PySpark docker image + $0 -r docker.io/myrepo -t v2.3.0 -p kubernetes/dockerfiles/spark/bindings/python/Dockerfile build + - Build and push image with tag "v2.3.0" to docker.io/myrepo $0 -r docker.io/myrepo -t v2.3.0 build $0 -r docker.io/myrepo -t v2.3.0 push diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index a7b6fd12a3e5f..a9d448820e700 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -88,6 +88,18 @@ $ ./bin/docker-image-tool.sh -r -t my-tag build $ ./bin/docker-image-tool.sh -r -t my-tag push ``` +By default `bin/docker-image-tool.sh` builds docker image for running JVM jobs. You need to opt-in to build additional +language binding docker images. + +Example usage is +```bash +# To build additional PySpark docker image +$ ./bin/docker-image-tool.sh -r -t my-tag -p ./kubernetes/dockerfiles/spark/bindings/python/Dockerfile build + +# To build additional SparkR docker image +$ ./bin/docker-image-tool.sh -r -t my-tag -R ./kubernetes/dockerfiles/spark/bindings/R/Dockerfile build +``` + ## Cluster Mode To launch Spark Pi in cluster mode, diff --git a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh index a4a9f5b7da131..36e30d7b2cffb 100755 --- a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh +++ b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh @@ -72,10 +72,16 @@ then IMAGE_TAG=$(uuidgen); cd $UNPACKED_SPARK_TGZ + # Build PySpark image + LANGUAGE_BINDING_BUILD_ARGS="-p $UNPACKED_SPARK_TGZ/kubernetes/dockerfiles/spark/bindings/python/Dockerfile" + + # Build SparkR image + LANGUAGE_BINDING_BUILD_ARGS="$LANGUAGE_BINDING_BUILD_ARGS -R $UNPACKED_SPARK_TGZ/kubernetes/dockerfiles/spark/bindings/R/Dockerfile" + case $DEPLOY_MODE in cloud) # Build images - $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG build + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build # Push images appropriately if [[ $IMAGE_REPO == gcr.io* ]] ; @@ -89,13 +95,13 @@ then docker-for-desktop) # Only need to build as this will place it in our local Docker repo which is all # we need for Docker for Desktop to work so no need to also push - $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG build + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build ;; minikube) # Only need to build and if we do this with the -m option for minikube we will # build the images directly using the minikube Docker daemon so no need to push - $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -m -r $IMAGE_REPO -t $IMAGE_TAG build + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -m -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build ;; *) echo "Unrecognized deploy mode $DEPLOY_MODE" && exit 1 From ce7b57cb5d552ac3df8557a3863792c425005994 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 22 Nov 2018 08:02:23 +0800 Subject: [PATCH 2110/2461] [SPARK-26106][PYTHON] Prioritizes ML unittests over the doctests in PySpark ## What changes were proposed in this pull request? Arguably, unittests usually takes longer then doctests. We better prioritize unittests over doctests. Other modules are already being prioritized over doctests. Looks ML module was missed at the very first place. ## How was this patch tested? Jenkins tests. Closes #23078 from HyukjinKwon/SPARK-26106. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- python/run-tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/run-tests.py b/python/run-tests.py index 9fd1c9b94ac6f..01a6e81264dd6 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -250,7 +250,7 @@ def main(): if python_implementation not in module.blacklisted_python_implementations: for test_goal in module.python_test_goals: heavy_tests = ['pyspark.streaming.tests', 'pyspark.mllib.tests', - 'pyspark.tests', 'pyspark.sql.tests'] + 'pyspark.tests', 'pyspark.sql.tests', 'pyspark.ml.tests'] if any(map(lambda prefix: test_goal.startswith(prefix), heavy_tests)): priority = 0 else: From 38628dd1b8298d2686e5d00de17c461c70db99a8 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 22 Nov 2018 09:35:29 +0800 Subject: [PATCH 2111/2461] [SPARK-25935][SQL] Prevent null rows from JSON parser ## What changes were proposed in this pull request? An input without valid JSON tokens on the root level will be treated as a bad record, and handled according to `mode`. Previously such input was converted to `null`. After the changes, the input is converted to a row with `null`s in the `PERMISSIVE` mode according the schema. This allows to remove a code in the `from_json` function which can produce `null` as result rows. ## How was this patch tested? It was tested by existing test suites. Some of them I have to modify (`JsonSuite` for example) because previously bad input was just silently ignored. For now such input is handled according to specified `mode`. Closes #22938 from MaxGekk/json-nulls. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: Wenchen Fan --- R/pkg/tests/fulltests/test_sparkSQL.R | 2 +- docs/sql-migration-guide-upgrade.md | 2 ++ .../expressions/jsonExpressions.scala | 26 ++++++++++++------- .../sql/catalyst/json/JacksonParser.scala | 2 +- .../expressions/JsonExpressionsSuite.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 10 ------- .../datasources/json/JsonSuite.scala | 12 ++++++--- 7 files changed, 31 insertions(+), 25 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 059c9f3057242..f355a515935c8 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1674,7 +1674,7 @@ test_that("column functions", { # check for unparseable df <- as.DataFrame(list(list("a" = ""))) - expect_equal(collect(select(df, from_json(df$a, schema)))[[1]][[1]], NA) + expect_equal(collect(select(df, from_json(df$a, schema)))[[1]][[1]]$a, NA) # check if array type in string is correctly supported. jsonArr <- "[{\"name\":\"Bob\"}, {\"name\":\"Alice\"}]" diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 07079d93f25b6..e8f2bcc9adfb4 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -15,6 +15,8 @@ displayTitle: Spark SQL Upgrading Guide - Since Spark 3.0, the `from_json` functions supports two modes - `PERMISSIVE` and `FAILFAST`. The modes can be set via the `mode` option. The default mode became `PERMISSIVE`. In previous versions, behavior of `from_json` did not conform to either `PERMISSIVE` nor `FAILFAST`, especially in processing of malformed JSON records. For example, the JSON string `{"a" 1}` with the schema `a INT` is converted to `null` by previous versions but Spark 3.0 converts it to `Row(null)`. + - In Spark version 2.4 and earlier, the `from_json` function produces `null`s for JSON strings and JSON datasource skips the same independetly of its mode if there is no valid root JSON token in its input (` ` for example). Since Spark 3.0, such input is treated as a bad record and handled according to specified mode. For example, in the `PERMISSIVE` mode the ` ` input is converted to `Row(null, null)` if specified schema is `key STRING, value INT`. + - The `ADD JAR` command previously returned a result set with the single value 0. It now returns an empty result set. - In Spark version 2.4 and earlier, users can create map values with map type key via built-in function like `CreateMap`, `MapFromArrays`, etc. Since Spark 3.0, it's not allowed to create map values with map type key with these built-in functions. Users can still read map values with map type key from data source or Java/Scala collections, though they are not very useful. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 52d0677f4022f..543c6c41de58a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -550,15 +550,23 @@ case class JsonToStructs( s"Input schema ${nullableSchema.catalogString} must be a struct, an array or a map.") } - // This converts parsed rows to the desired output by the given schema. @transient - lazy val converter = nullableSchema match { - case _: StructType => - (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next() else null - case _: ArrayType => - (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getArray(0) else null - case _: MapType => - (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getMap(0) else null + private lazy val castRow = nullableSchema match { + case _: StructType => (row: InternalRow) => row + case _: ArrayType => (row: InternalRow) => row.getArray(0) + case _: MapType => (row: InternalRow) => row.getMap(0) + } + + // This converts parsed rows to the desired output by the given schema. + private def convertRow(rows: Iterator[InternalRow]) = { + if (rows.hasNext) { + val result = rows.next() + // JSON's parser produces one record only. + assert(!rows.hasNext) + castRow(result) + } else { + throw new IllegalArgumentException("Expected one row from JSON parser.") + } } val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD) @@ -593,7 +601,7 @@ case class JsonToStructs( copy(timeZoneId = Option(timeZoneId)) override def nullSafeEval(json: Any): Any = { - converter(parser.parse(json.asInstanceOf[UTF8String])) + convertRow(parser.parse(json.asInstanceOf[UTF8String])) } override def inputTypes: Seq[AbstractDataType] = StringType :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 57c7f2faf3107..773ff5a7a4013 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -399,7 +399,7 @@ class JacksonParser( // a null first token is equivalent to testing for input.trim.isEmpty // but it works on any token stream and not just strings parser.nextToken() match { - case null => Nil + case null => throw new RuntimeException("Not found any JSON token") case _ => rootConverter.apply(parser) match { case null => throw new RuntimeException("Root converter returned null") case rows => rows diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 6ee8c74010d3d..34bd2a99b2b4d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -547,7 +547,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId), - null + InternalRow(null) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index dbb0790a4682c..4cc8a45391996 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -240,16 +240,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Seq(Row("1"), Row("2"))) } - test("SPARK-11226 Skip empty line in json file") { - spark.read - .json(Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}", "").toDS()) - .createOrReplaceTempView("d") - - checkAnswer( - sql("select count(1) from d"), - Seq(Row(3))) - } - test("SPARK-8828 sum should return null if all input values are null") { checkAnswer( sql("select sum(a), avg(a) from allNulls"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 06032ded42a53..9ea9189cdf7f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1115,6 +1115,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(null, null, null), Row(null, null, null), Row(null, null, null), + Row(null, null, null), Row("str_a_4", "str_b_4", "str_c_4"), Row(null, null, null)) ) @@ -1136,6 +1137,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { checkAnswer( jsonDF.select($"a", $"b", $"c", $"_unparsed"), Row(null, null, null, "{") :: + Row(null, null, null, "") :: Row(null, null, null, """{"a":1, b:2}""") :: Row(null, null, null, """{"a":{, b:3}""") :: Row("str_a_4", "str_b_4", "str_c_4", null) :: @@ -1150,6 +1152,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { checkAnswer( jsonDF.filter($"_unparsed".isNotNull).select($"_unparsed"), Row("{") :: + Row("") :: Row("""{"a":1, b:2}""") :: Row("""{"a":{, b:3}""") :: Row("]") :: Nil @@ -1171,6 +1174,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { checkAnswer( jsonDF.selectExpr("a", "b", "c", "_malformed"), Row(null, null, null, "{") :: + Row(null, null, null, "") :: Row(null, null, null, """{"a":1, b:2}""") :: Row(null, null, null, """{"a":{, b:3}""") :: Row("str_a_4", "str_b_4", "str_c_4", null) :: @@ -1813,6 +1817,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val path = dir.getCanonicalPath primitiveFieldAndType .toDF("value") + .repartition(1) .write .option("compression", "GzIp") .text(path) @@ -1838,6 +1843,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val path = dir.getCanonicalPath primitiveFieldAndType .toDF("value") + .repartition(1) .write .text(path) @@ -1892,7 +1898,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .text(path) val jsonDF = spark.read.option("multiLine", true).option("mode", "PERMISSIVE").json(path) - assert(jsonDF.count() === corruptRecordCount) + assert(jsonDF.count() === corruptRecordCount + 1) // null row for empty file assert(jsonDF.schema === new StructType() .add("_corrupt_record", StringType) .add("dummy", StringType)) @@ -1905,7 +1911,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { F.count($"dummy").as("valid"), F.count($"_corrupt_record").as("corrupt"), F.count("*").as("count")) - checkAnswer(counts, Row(1, 4, 6)) + checkAnswer(counts, Row(1, 5, 7)) // null row for empty file } } @@ -2513,7 +2519,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } checkCount(2) - countForMalformedJSON(0, Seq("")) + countForMalformedJSON(1, Seq("")) } test("SPARK-25040: empty strings should be disallowed") { From ab2eafb3cdc7631452650c6cac03a92629255347 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 22 Nov 2018 10:50:01 +0800 Subject: [PATCH 2112/2461] [SPARK-26085][SQL] Key attribute of non-struct type under typed aggregation should be named as "key" too ## What changes were proposed in this pull request? When doing typed aggregation on a Dataset, for struct key type, the key attribute is named as "key". But for non-struct type, the key attribute is named as "value". This key attribute should also be named as "key" for non-struct type. ## How was this patch tested? Added test. Closes #23054 from viirya/SPARK-26085. Authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- docs/sql-migration-guide-upgrade.md | 2 ++ .../org/apache/spark/sql/internal/SQLConf.scala | 12 ++++++++++++ .../apache/spark/sql/KeyValueGroupedDataset.scala | 7 ++++++- .../scala/org/apache/spark/sql/DatasetSuite.scala | 10 ++++++++++ 4 files changed, 30 insertions(+), 1 deletion(-) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index e8f2bcc9adfb4..397ca59d96497 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -20,6 +20,8 @@ displayTitle: Spark SQL Upgrading Guide - The `ADD JAR` command previously returned a result set with the single value 0. It now returns an empty result set. - In Spark version 2.4 and earlier, users can create map values with map type key via built-in function like `CreateMap`, `MapFromArrays`, etc. Since Spark 3.0, it's not allowed to create map values with map type key with these built-in functions. Users can still read map values with map type key from data source or Java/Scala collections, though they are not very useful. + + - In Spark version 2.4 and earlier, `Dataset.groupByKey` results to a grouped dataset with key attribute wrongly named as "value", if the key is non-struct type, e.g. int, string, array, etc. This is counterintuitive and makes the schema of aggregation queries weird. For example, the schema of `ds.groupByKey(...).count()` is `(value, count)`. Since Spark 3.0, we name the grouping attribute to "key". The old behaviour is preserved under a newly added configuration `spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue` with a default value of `false`. ## Upgrading From Spark SQL 2.3 to 2.4 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index cc0e9727812db..7bcf21595ce5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1595,6 +1595,15 @@ object SQLConf { .booleanConf .createWithDefault(false) + val NAME_NON_STRUCT_GROUPING_KEY_AS_VALUE = + buildConf("spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue") + .internal() + .doc("When set to true, the key attribute resulted from running `Dataset.groupByKey` " + + "for non-struct key type, will be named as `value`, following the behavior of Spark " + + "version 2.4 and earlier.") + .booleanConf + .createWithDefault(false) + val MAX_TO_STRING_FIELDS = buildConf("spark.sql.debug.maxToStringFields") .doc("Maximum number of fields of sequence-like entries can be converted to strings " + "in debug output. Any elements beyond the limit will be dropped and replaced by a" + @@ -2016,6 +2025,9 @@ class SQLConf extends Serializable with Logging { def integralDivideReturnLong: Boolean = getConf(SQLConf.LEGACY_INTEGRALDIVIDE_RETURN_LONG) + def nameNonStructGroupingKeyAsValue: Boolean = + getConf(SQLConf.NAME_NON_STRUCT_GROUPING_KEY_AS_VALUE) + def maxToStringFields: Int = getConf(SQLConf.MAX_TO_STRING_FIELDS) /** ********************** SQLConf functionality methods ************ */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 7a47242f69381..2d849c65997a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode} /** @@ -459,7 +460,11 @@ class KeyValueGroupedDataset[K, V] private[sql]( columns.map(_.withInputType(vExprEnc, dataAttributes).named) val keyColumn = if (!kExprEnc.isSerializedAsStruct) { assert(groupingAttributes.length == 1) - groupingAttributes.head + if (SQLConf.get.nameNonStructGroupingKeyAsValue) { + groupingAttributes.head + } else { + Alias(groupingAttributes.head, "key")() + } } else { Alias(CreateStruct(groupingAttributes), "key")() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 540fbff6a3a63..baece2ddac7eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1572,6 +1572,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkDatasetUnorderly(agg, ((1, 2), 1L, 3L), ((2, 3), 2L, 4L), ((3, 4), 3L, 5L)) } + test("SPARK-26085: fix key attribute name for atomic type for typed aggregation") { + val ds = Seq(1, 2, 3).toDS() + assert(ds.groupByKey(x => x).count().schema.head.name == "key") + + // Enable legacy flag to follow previous Spark behavior + withSQLConf(SQLConf.NAME_NON_STRUCT_GROUPING_KEY_AS_VALUE.key -> "true") { + assert(ds.groupByKey(x => x).count().schema.head.name == "value") + } + } + test("SPARK-8288: class with only a companion object constructor") { val data = Seq(ScroogeLikeExample(1), ScroogeLikeExample(2)) val ds = data.toDS From 8d54bf79f215378fbd95794591a87604a5eaf7a3 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 22 Nov 2018 10:57:19 +0800 Subject: [PATCH 2113/2461] [SPARK-26099][SQL] Verification of the corrupt column in from_csv/from_json ## What changes were proposed in this pull request? The corrupt column specified via JSON/CSV option *columnNameOfCorruptRecord* must have the `string` type and be `nullable`. This has been already checked in `DataFrameReader`.`csv`/`json` and in `Json`/`CsvFileFormat` but not in `from_json`/`from_csv`. The PR adds such checks inside functions as well. ## How was this patch tested? Added tests to `Json`/`CsvExpressionSuite` for checking type of the corrupt column. They don't check the `nullable` property because `schema` is forcibly casted to nullable. Closes #23070 from MaxGekk/verify-corrupt-column-csv-json. Authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- .../sql/catalyst/expressions/ExprUtils.scala | 16 ++++++++++++++ .../catalyst/expressions/csvExpressions.scala | 4 ++++ .../expressions/jsonExpressions.scala | 1 + .../expressions/CsvExpressionsSuite.scala | 11 ++++++++++ .../expressions/JsonExpressionsSuite.scala | 11 ++++++++++ .../apache/spark/sql/DataFrameReader.scala | 21 +++---------------- .../datasources/csv/CSVFileFormat.scala | 9 ++------ .../datasources/json/JsonFileFormat.scala | 11 +++------- 8 files changed, 51 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index 040b56cc1caea..89e9071324eff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -67,4 +67,20 @@ object ExprUtils { case _ => throw new AnalysisException("Must use a map() function for options") } + + /** + * A convenient function for schema validation in datasources supporting + * `columnNameOfCorruptRecord` as an option. + */ + def verifyColumnNameOfCorruptRecord( + schema: StructType, + columnNameOfCorruptRecord: String): Unit = { + schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex => + val f = schema(corruptFieldIndex) + if (f.dataType != StringType || !f.nullable) { + throw new AnalysisException( + "The field for corrupt records must be string type and nullable") + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index aff372b899f86..1e4e1c663c90e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -106,6 +106,10 @@ case class CsvToStructs( throw new AnalysisException(s"from_csv() doesn't support the ${mode.name} mode. " + s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}.") } + ExprUtils.verifyColumnNameOfCorruptRecord( + nullableSchema, + parsedOptions.columnNameOfCorruptRecord) + val actualSchema = StructType(nullableSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) val rawParser = new UnivocityParser(actualSchema, actualSchema, parsedOptions) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 543c6c41de58a..47304d835fdf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -579,6 +579,7 @@ case class JsonToStructs( } val (parserSchema, actualSchema) = nullableSchema match { case s: StructType => + ExprUtils.verifyColumnNameOfCorruptRecord(s, parsedOptions.columnNameOfCorruptRecord) (s, StructType(s.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))) case other => (StructType(StructField("value", other) :: Nil), other) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala index f5aaaec456153..98c93a4946f4f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -23,6 +23,7 @@ import java.util.{Calendar, Locale} import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.util._ @@ -226,4 +227,14 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P InternalRow(17836)) // number of days from 1970-01-01 } } + + test("verify corrupt column") { + checkExceptionInExpression[AnalysisException]( + CsvToStructs( + schema = StructType.fromDDL("i int, _unparsed boolean"), + options = Map("columnNameOfCorruptRecord" -> "_unparsed"), + child = Literal.create("a"), + timeZoneId = gmtId), + expectedErrMsg = "The field for corrupt records must be string type and nullable") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 34bd2a99b2b4d..9b89a27c23770 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -23,6 +23,7 @@ import java.util.{Calendar, Locale} import org.scalatest.exceptions.TestFailedException import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.PlanTestBase @@ -754,4 +755,14 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with InternalRow(17836)) // number of days from 1970-01-01 } } + + test("verify corrupt column") { + checkExceptionInExpression[AnalysisException]( + JsonToStructs( + schema = StructType.fromDDL("i int, _unparsed boolean"), + options = Map("columnNameOfCorruptRecord" -> "_unparsed"), + child = Literal.create("""{"i":"a"}"""), + timeZoneId = gmtId), + expectedErrMsg = "The field for corrupt records must be string type and nullable") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 52df13d39caa7..f08fd64acd9a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} +import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.FailureSafeParser import org.apache.spark.sql.execution.command.DDLUtils @@ -442,7 +443,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { TextInputJsonDataSource.inferFromDataset(jsonDataset, parsedOptions) } - verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) + ExprUtils.verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) val actualSchema = StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) @@ -504,7 +505,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { parsedOptions) } - verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) + ExprUtils.verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) val actualSchema = StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) @@ -765,22 +766,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } } - /** - * A convenient function for schema validation in datasources supporting - * `columnNameOfCorruptRecord` as an option. - */ - private def verifyColumnNameOfCorruptRecord( - schema: StructType, - columnNameOfCorruptRecord: String): Unit = { - schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex => - val f = schema(corruptFieldIndex) - if (f.dataType != StringType || !f.nullable) { - throw new AnalysisException( - "The field for corrupt records must be string type and nullable") - } - } - } - /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 964b56e706a0b..ff1911d69a6b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityGenerator, UnivocityParser} +import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ @@ -110,13 +111,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession.sessionState.conf.columnNameOfCorruptRecord) // Check a field requirement for corrupt records here to throw an exception in a driver side - dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex => - val f = dataSchema(corruptFieldIndex) - if (f.dataType != StringType || !f.nullable) { - throw new AnalysisException( - "The field for corrupt records must be string type and nullable") - } - } + ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord) if (requiredSchema.length == 1 && requiredSchema.head.name == parsedOptions.columnNameOfCorruptRecord) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 1f7c9d73f19fe..610f0d1619fc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -26,7 +26,8 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions, JSONOptionsInRead} +import org.apache.spark.sql.catalyst.expressions.ExprUtils +import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ @@ -107,13 +108,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val actualSchema = StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) // Check a field requirement for corrupt records here to throw an exception in a driver side - dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex => - val f = dataSchema(corruptFieldIndex) - if (f.dataType != StringType || !f.nullable) { - throw new AnalysisException( - "The field for corrupt records must be string type and nullable") - } - } + ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord) if (requiredSchema.length == 1 && requiredSchema.head.name == parsedOptions.columnNameOfCorruptRecord) { From 15c038497791e7735898356db2464b8732695365 Mon Sep 17 00:00:00 2001 From: Takanobu Asanuma Date: Wed, 21 Nov 2018 23:09:57 -0800 Subject: [PATCH 2114/2461] [SPARK-26134][CORE] Upgrading Hadoop to 2.7.4 to fix java.version problem ## What changes were proposed in this pull request? When I ran spark-shell on JDK11+28(2018-09-25), It failed with the error below. ``` Exception in thread "main" java.lang.ExceptionInInitializerError at org.apache.hadoop.util.StringUtils.(StringUtils.java:80) at org.apache.hadoop.security.SecurityUtil.getAuthenticationMethod(SecurityUtil.java:611) at org.apache.hadoop.security.UserGroupInformation.initialize(UserGroupInformation.java:273) at org.apache.hadoop.security.UserGroupInformation.ensureInitialized(UserGroupInformation.java:261) at org.apache.hadoop.security.UserGroupInformation.loginUserFromSubject(UserGroupInformation.java:791) at org.apache.hadoop.security.UserGroupInformation.getLoginUser(UserGroupInformation.java:761) at org.apache.hadoop.security.UserGroupInformation.getCurrentUser(UserGroupInformation.java:634) at org.apache.spark.util.Utils$.$anonfun$getCurrentUserName$1(Utils.scala:2427) at scala.Option.getOrElse(Option.scala:121) at org.apache.spark.util.Utils$.getCurrentUserName(Utils.scala:2427) at org.apache.spark.SecurityManager.(SecurityManager.scala:79) at org.apache.spark.deploy.SparkSubmit.secMgr$lzycompute$1(SparkSubmit.scala:359) at org.apache.spark.deploy.SparkSubmit.secMgr$1(SparkSubmit.scala:359) at org.apache.spark.deploy.SparkSubmit.$anonfun$prepareSubmitEnvironment$9(SparkSubmit.scala:367) at scala.Option.map(Option.scala:146) at org.apache.spark.deploy.SparkSubmit.prepareSubmitEnvironment(SparkSubmit.scala:367) at org.apache.spark.deploy.SparkSubmit.submit(SparkSubmit.scala:143) at org.apache.spark.deploy.SparkSubmit.doSubmit(SparkSubmit.scala:86) at org.apache.spark.deploy.SparkSubmit$$anon$2.doSubmit(SparkSubmit.scala:927) at org.apache.spark.deploy.SparkSubmit$.main(SparkSubmit.scala:936) at org.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala) Caused by: java.lang.StringIndexOutOfBoundsException: begin 0, end 3, length 2 at java.base/java.lang.String.checkBoundsBeginEnd(String.java:3319) at java.base/java.lang.String.substring(String.java:1874) at org.apache.hadoop.util.Shell.(Shell.java:52) ``` This is a Hadoop issue that fails to parse some java.version. It has been fixed from Hadoop-2.7.4(see [HADOOP-14586](https://issues.apache.org/jira/browse/HADOOP-14586)). Note, Hadoop-2.7.5 or upper have another problem with Spark ([SPARK-25330](https://issues.apache.org/jira/browse/SPARK-25330)). So upgrading to 2.7.4 would be fine for now. ## How was this patch tested? Existing tests. Closes #23101 from tasanuma/SPARK-26134. Authored-by: Takanobu Asanuma Signed-off-by: Dongjoon Hyun --- assembly/README | 2 +- dev/deps/spark-deps-hadoop-2.7 | 31 ++++++++++--------- pom.xml | 2 +- .../kubernetes/integration-tests/README.md | 2 +- .../hive/client/IsolatedClientLoader.scala | 2 +- 5 files changed, 20 insertions(+), 19 deletions(-) diff --git a/assembly/README b/assembly/README index d5dafab477410..1fd6d8858348c 100644 --- a/assembly/README +++ b/assembly/README @@ -9,4 +9,4 @@ This module is off by default. To activate it specify the profile in the command If you need to build an assembly for a different version of Hadoop the hadoop-version system property needs to be set as in this example: - -Dhadoop.version=2.7.3 + -Dhadoop.version=2.7.4 diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index c2f5755ca9925..ec7c304c9e36b 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -64,21 +64,21 @@ gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar guice-servlet-3.0.jar -hadoop-annotations-2.7.3.jar -hadoop-auth-2.7.3.jar -hadoop-client-2.7.3.jar -hadoop-common-2.7.3.jar -hadoop-hdfs-2.7.3.jar -hadoop-mapreduce-client-app-2.7.3.jar -hadoop-mapreduce-client-common-2.7.3.jar -hadoop-mapreduce-client-core-2.7.3.jar -hadoop-mapreduce-client-jobclient-2.7.3.jar -hadoop-mapreduce-client-shuffle-2.7.3.jar -hadoop-yarn-api-2.7.3.jar -hadoop-yarn-client-2.7.3.jar -hadoop-yarn-common-2.7.3.jar -hadoop-yarn-server-common-2.7.3.jar -hadoop-yarn-server-web-proxy-2.7.3.jar +hadoop-annotations-2.7.4.jar +hadoop-auth-2.7.4.jar +hadoop-client-2.7.4.jar +hadoop-common-2.7.4.jar +hadoop-hdfs-2.7.4.jar +hadoop-mapreduce-client-app-2.7.4.jar +hadoop-mapreduce-client-common-2.7.4.jar +hadoop-mapreduce-client-core-2.7.4.jar +hadoop-mapreduce-client-jobclient-2.7.4.jar +hadoop-mapreduce-client-shuffle-2.7.4.jar +hadoop-yarn-api-2.7.4.jar +hadoop-yarn-client-2.7.4.jar +hadoop-yarn-common-2.7.4.jar +hadoop-yarn-server-common-2.7.4.jar +hadoop-yarn-server-web-proxy-2.7.4.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar @@ -117,6 +117,7 @@ jersey-guava-2.22.2.jar jersey-media-jaxb-2.22.2.jar jersey-server-2.22.2.jar jetty-6.1.26.jar +jetty-sslengine-6.1.26.jar jetty-util-6.1.26.jar jline-2.14.6.jar joda-time-2.9.3.jar diff --git a/pom.xml b/pom.xml index 08a29d2d52310..93075e9b06a68 100644 --- a/pom.xml +++ b/pom.xml @@ -118,7 +118,7 @@ spark 1.7.16 1.2.17 - 2.7.3 + 2.7.4 2.5.0 ${hadoop.version} 3.4.6 diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md index 64f8e77597eba..73fc0581d64f5 100644 --- a/resource-managers/kubernetes/integration-tests/README.md +++ b/resource-managers/kubernetes/integration-tests/README.md @@ -107,7 +107,7 @@ properties to Maven. For example: mvn integration-test -am -pl :spark-kubernetes-integration-tests_2.11 \ -Pkubernetes -Pkubernetes-integration-tests \ - -Phadoop-2.7 -Dhadoop.version=2.7.3 \ + -Phadoop-2.7 -Dhadoop.version=2.7.4 \ -Dspark.kubernetes.test.sparkTgz=spark-3.0.0-SNAPSHOT-bin-example.tgz \ -Dspark.kubernetes.test.imageTag=sometag \ -Dspark.kubernetes.test.imageRepo=docker.io/somerepo \ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index f56ca8cb08553..ca98c30add168 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -65,7 +65,7 @@ private[hive] object IsolatedClientLoader extends Logging { case e: RuntimeException if e.getMessage.contains("hadoop") => // If the error message contains hadoop, it is probably because the hadoop // version cannot be resolved. - val fallbackVersion = "2.7.3" + val fallbackVersion = "2.7.4" logWarning(s"Failed to resolve Hadoop artifacts for the version $hadoopVersion. We " + s"will change the hadoop version from $hadoopVersion to $fallbackVersion and try " + "again. Hadoop classes will not be shared between Spark and Hive metastore client. " + From ab00533490953164cb2360bf2b9adc2c9fa962db Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 22 Nov 2018 02:27:06 -0800 Subject: [PATCH 2115/2461] [SPARK-26129][SQL] edge behavior for QueryPlanningTracker.topRulesByTime - followup patch ## What changes were proposed in this pull request? This is an addendum patch for SPARK-26129 that defines the edge case behavior for QueryPlanningTracker.topRulesByTime. ## How was this patch tested? Added unit tests for each behavior. Closes #23110 from rxin/SPARK-26129-1. Authored-by: Reynold Xin Signed-off-by: Dongjoon Hyun --- .../sql/catalyst/QueryPlanningTracker.scala | 17 ++++++++++++----- .../catalyst/QueryPlanningTrackerSuite.scala | 9 ++++++++- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala index 420f2a1f20997..244081cd160b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala @@ -116,12 +116,19 @@ class QueryPlanningTracker { def phases: Map[String, Long] = phaseToTimeNs.asScala.toMap - /** Returns the top k most expensive rules (as measured by time). */ + /** + * Returns the top k most expensive rules (as measured by time). If k is larger than the rules + * seen so far, return all the rules. If there is no rule seen so far or k <= 0, return empty seq. + */ def topRulesByTime(k: Int): Seq[(String, RuleSummary)] = { - val orderingByTime: Ordering[(String, RuleSummary)] = Ordering.by(e => e._2.totalTimeNs) - val q = new BoundedPriorityQueue(k)(orderingByTime) - rulesMap.asScala.foreach(q.+=) - q.toSeq.sortBy(r => -r._2.totalTimeNs) + if (k <= 0) { + Seq.empty + } else { + val orderingByTime: Ordering[(String, RuleSummary)] = Ordering.by(e => e._2.totalTimeNs) + val q = new BoundedPriorityQueue(k)(orderingByTime) + rulesMap.asScala.foreach(q.+=) + q.toSeq.sortBy(r => -r._2.totalTimeNs) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala index f42c262dfbdd8..120b284a77854 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala @@ -62,17 +62,24 @@ class QueryPlanningTrackerSuite extends SparkFunSuite { test("topRulesByTime") { val t = new QueryPlanningTracker + + // Return empty seq when k = 0 + assert(t.topRulesByTime(0) == Seq.empty) + assert(t.topRulesByTime(1) == Seq.empty) + t.recordRuleInvocation("r2", 2, effective = true) t.recordRuleInvocation("r4", 4, effective = true) t.recordRuleInvocation("r1", 1, effective = false) t.recordRuleInvocation("r3", 3, effective = false) + // k <= total size + assert(t.topRulesByTime(0) == Seq.empty) val top = t.topRulesByTime(2) assert(top.size == 2) assert(top(0)._1 == "r4") assert(top(1)._1 == "r3") - // Don't crash when k > total size + // k > total size assert(t.topRulesByTime(10).size == 4) } } From aeda76e2b74ef07b2814770d68cf145cdbb0197c Mon Sep 17 00:00:00 2001 From: Huon Wilson Date: Thu, 22 Nov 2018 15:43:04 -0600 Subject: [PATCH 2116/2461] [GRAPHX] Remove unused variables left over by previous refactoring. ## What changes were proposed in this pull request? Some variables were previously used for indexing the routing table's backing array, but that indexing now happens elsewhere, and so the variables aren't needed. ## How was this patch tested? Unit tests. (This contribution is my original work and I license the work to Spark under its open source license.) Closes #23112 from huonw/remove-unused-variables. Authored-by: Huon Wilson Signed-off-by: Sean Owen --- .../apache/spark/graphx/impl/ShippableVertexPartition.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala index a4e293d74a012..184b96426fa9b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala @@ -117,13 +117,11 @@ class ShippableVertexPartition[VD: ClassTag]( val initialSize = if (shipSrc && shipDst) routingTable.partitionSize(pid) else 64 val vids = new PrimitiveVector[VertexId](initialSize) val attrs = new PrimitiveVector[VD](initialSize) - var i = 0 routingTable.foreachWithinEdgePartition(pid, shipSrc, shipDst) { vid => if (isDefined(vid)) { vids += vid attrs += this(vid) } - i += 1 } (pid, new VertexAttributeBlock(vids.trim().array, attrs.trim().array)) } @@ -137,12 +135,10 @@ class ShippableVertexPartition[VD: ClassTag]( def shipVertexIds(): Iterator[(PartitionID, Array[VertexId])] = { Iterator.tabulate(routingTable.numEdgePartitions) { pid => val vids = new PrimitiveVector[VertexId](routingTable.partitionSize(pid)) - var i = 0 routingTable.foreachWithinEdgePartition(pid, true, true) { vid => if (isDefined(vid)) { vids += vid } - i += 1 } (pid, vids.trim().array) } From dd8c179c28c5df20210b70a69d93d866ccaca4cc Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 22 Nov 2018 15:45:25 -0600 Subject: [PATCH 2117/2461] [SPARK-25867][ML] Remove KMeans computeCost ## What changes were proposed in this pull request? The PR removes the deprecated method `computeCost` of `KMeans`. ## How was this patch tested? NA Closes #22875 from mgaido91/SPARK-25867. Authored-by: Marco Gaido Signed-off-by: Sean Owen --- .../org/apache/spark/ml/clustering/KMeans.scala | 16 ---------------- .../apache/spark/ml/clustering/KMeansSuite.scala | 12 +++++------- project/MimaExcludes.scala | 3 +++ python/pyspark/ml/clustering.py | 16 ---------------- 4 files changed, 8 insertions(+), 39 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 498310d6644e1..919496aa1a840 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -143,22 +143,6 @@ class KMeansModel private[ml] ( @Since("2.0.0") def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML) - /** - * Return the K-means cost (sum of squared distances of points to their nearest center) for this - * model on the given data. - * - * @deprecated This method is deprecated and will be removed in 3.0.0. Use ClusteringEvaluator - * instead. You can also get the cost on the training dataset in the summary. - */ - @deprecated("This method is deprecated and will be removed in 3.0.0. Use ClusteringEvaluator " + - "instead. You can also get the cost on the training dataset in the summary.", "2.4.0") - @Since("2.0.0") - def computeCost(dataset: Dataset[_]): Double = { - SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol) - val data = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) - parentModel.computeCost(data) - } - /** * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance. * diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index ccbceab53bb66..4f47d91f0d0d5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -117,7 +117,6 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes assert(clusters === Set(0, 1, 2, 3, 4)) } - assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) // Check validity of model summary @@ -132,7 +131,6 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes } assert(summary.cluster.columns === Array(predictionColName)) assert(summary.trainingCost < 0.1) - assert(model.computeCost(dataset) == summary.trainingCost) val clusterSizes = summary.clusterSizes assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) @@ -201,15 +199,15 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes } test("KMean with Array input") { - def trainAndComputeCost(dataset: Dataset[_]): Double = { + def trainAndGetCost(dataset: Dataset[_]): Double = { val model = new KMeans().setK(k).setMaxIter(1).setSeed(1).fit(dataset) - model.computeCost(dataset) + model.summary.trainingCost } val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) - val trueCost = trainAndComputeCost(newDataset) - val doubleArrayCost = trainAndComputeCost(newDatasetD) - val floatArrayCost = trainAndComputeCost(newDatasetF) + val trueCost = trainAndGetCost(newDataset) + val doubleArrayCost = trainAndGetCost(newDatasetD) + val floatArrayCost = trainAndGetCost(newDatasetF) // checking the cost is fine enough as a sanity check assert(trueCost ~== doubleArrayCost absTol 1e-6) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 9089c7d9ffc70..333adb0c84025 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,9 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( + // [SPARK-25867] Remove KMeans computeCost + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.KMeansModel.computeCost"), + // [SPARK-26127] Remove deprecated setters from tree regression and classification models ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setSeed"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMinInfoGain"), diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index aaeeeb82d3d86..d0b507ec5dad4 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -335,20 +335,6 @@ def clusterCenters(self): """Get the cluster centers, represented as a list of NumPy arrays.""" return [c.toArray() for c in self._call_java("clusterCenters")] - @since("2.0.0") - def computeCost(self, dataset): - """ - Return the K-means cost (sum of squared distances of points to their nearest center) - for this model on the given data. - - ..note:: Deprecated in 2.4.0. It will be removed in 3.0.0. Use ClusteringEvaluator instead. - You can also get the cost on the training dataset in the summary. - """ - warnings.warn("Deprecated in 2.4.0. It will be removed in 3.0.0. Use ClusteringEvaluator " - "instead. You can also get the cost on the training dataset in the summary.", - DeprecationWarning) - return self._call_java("computeCost", dataset) - @property @since("2.1.0") def hasSummary(self): @@ -387,8 +373,6 @@ class KMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol >>> centers = model.clusterCenters() >>> len(centers) 2 - >>> model.computeCost(df) - 2.0 >>> transformed = model.transform(df).select("features", "prediction") >>> rows = transformed.collect() >>> rows[0].prediction == rows[1].prediction From d81d95a7e8a621e42c9c61305c32df72b6e868be Mon Sep 17 00:00:00 2001 From: oraviv Date: Thu, 22 Nov 2018 15:48:01 -0600 Subject: [PATCH 2118/2461] [SPARK-19368][MLLIB] BlockMatrix.toIndexedRowMatrix() optimization for sparse matrices ## What changes were proposed in this pull request? Optimization [SPARK-12869] was made for dense matrices but caused great performance issue for sparse matrices because manipulating them is very inefficient. When manipulating sparse matrices in Breeze we better use VectorBuilder. ## How was this patch tested? checked it against a use case that we have that after moving to Spark 2 took 6.5 hours instead of 20 mins. After the change it is back to 20 mins again. Closes #16732 from uzadude/SparseVector_optimization. Authored-by: oraviv Signed-off-by: Sean Owen --- .../linalg/distributed/BlockMatrix.scala | 45 ++++++++++++------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index 7caacd13b3459..e58860fea97d0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -17,10 +17,9 @@ package org.apache.spark.mllib.linalg.distributed +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Matrix => BM} import scala.collection.mutable.ArrayBuffer -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Matrix => BM, SparseVector => BSV, Vector => BV} - import org.apache.spark.{Partitioner, SparkException} import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging @@ -28,6 +27,7 @@ import org.apache.spark.mllib.linalg._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel + /** * A grid partitioner, which uses a regular grid to partition coordinates. * @@ -273,24 +273,37 @@ class BlockMatrix @Since("1.3.0") ( require(cols < Int.MaxValue, s"The number of columns should be less than Int.MaxValue ($cols).") val rows = blocks.flatMap { case ((blockRowIdx, blockColIdx), mat) => - mat.rowIter.zipWithIndex.map { + mat.rowIter.zipWithIndex.filter(_._1.size > 0).map { case (vector, rowIdx) => - blockRowIdx * rowsPerBlock + rowIdx -> ((blockColIdx, vector.asBreeze)) + blockRowIdx * rowsPerBlock + rowIdx -> ((blockColIdx, vector)) } }.groupByKey().map { case (rowIdx, vectors) => - val numberNonZeroPerRow = vectors.map(_._2.activeSize).sum.toDouble / cols.toDouble - - val wholeVector = if (numberNonZeroPerRow <= 0.1) { // Sparse at 1/10th nnz - BSV.zeros[Double](cols) - } else { - BDV.zeros[Double](cols) - } + val numberNonZero = vectors.map(_._2.numActives).sum + val numberNonZeroPerRow = numberNonZero.toDouble / cols.toDouble + + val wholeVector = + if (numberNonZeroPerRow <= 0.1) { // Sparse at 1/10th nnz + val arrBufferIndices = new ArrayBuffer[Int](numberNonZero) + val arrBufferValues = new ArrayBuffer[Double](numberNonZero) + + vectors.foreach { case (blockColIdx: Int, vec: Vector) => + val offset = colsPerBlock * blockColIdx + vec.foreachActive { case (colIdx: Int, value: Double) => + arrBufferIndices += offset + colIdx + arrBufferValues += value + } + } + Vectors.sparse(cols, arrBufferIndices.toArray, arrBufferValues.toArray) + } else { + val wholeVectorBuf = BDV.zeros[Double](cols) + vectors.foreach { case (blockColIdx: Int, vec: Vector) => + val offset = colsPerBlock * blockColIdx + wholeVectorBuf(offset until Math.min(cols, offset + colsPerBlock)) := vec.asBreeze + } + Vectors.fromBreeze(wholeVectorBuf) + } - vectors.foreach { case (blockColIdx: Int, vec: BV[_]) => - val offset = colsPerBlock * blockColIdx - wholeVector(offset until Math.min(cols, offset + colsPerBlock)) := vec - } - new IndexedRow(rowIdx, Vectors.fromBreeze(wholeVector)) + IndexedRow(rowIdx, wholeVector) } new IndexedRowMatrix(rows) } From 1d766f0e222c24e8e8cad68e664e83f4f71f7541 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 22 Nov 2018 14:49:41 -0800 Subject: [PATCH 2119/2461] [SPARK-26144][BUILD] `build/mvn` should detect `scala.version` based on `scala.binary.version` ## What changes were proposed in this pull request? Currently, `build/mvn` downloads and uses **Scala 2.12.7** in `Scala-2.11` Jenkins job. The root cause is `build/mvn` got the first match from `pom.xml` blindly. - https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-maven-hadoop-2.7-ubuntu-scala-2.11/6/consoleFull ``` exec: curl -s -L https://downloads.lightbend.com/zinc/0.3.15/zinc-0.3.15.tgz exec: curl -s -L https://downloads.lightbend.com/scala/2.12.7/scala-2.12.7.tgz exec: curl -s -L https://www.apache.org/dyn/closer.lua?action=download&filename=/maven/maven-3/3.5.4/binaries/apache-maven-3.5.4-bin.tar.gz ``` ## How was this patch tested? Manual. ``` $ build/mvn clean exec: curl --progress-bar -L https://downloads.lightbend.com/scala/2.12.7/scala-2.12.7.tgz ... $ git clean -fdx $ dev/change-scala-version.sh 2.11 $ build/mvn clean exec: curl --progress-bar -L https://downloads.lightbend.com/scala/2.11.12/scala-2.11.12.tgz ``` Closes #23118 from dongjoon-hyun/SPARK-26144. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- build/mvn | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/build/mvn b/build/mvn index 3816993b4e5c8..4cb10e0d03fa4 100755 --- a/build/mvn +++ b/build/mvn @@ -116,7 +116,8 @@ install_zinc() { # the build/ folder install_scala() { # determine the Scala version used in Spark - local scala_version=`grep "scala.version" "${_DIR}/../pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'` + local scala_binary_version=`grep "scala.binary.version" "${_DIR}/../pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'` + local scala_version=`grep "scala.version" "${_DIR}/../pom.xml" | grep ${scala_binary_version} | head -n1 | awk -F '[<>]' '{print $3}'` local scala_bin="${_DIR}/scala-${scala_version}/bin/scala" local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.lightbend.com} From 76aae7f1fd512f150ffcdb618107b12e1e97fe43 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 22 Nov 2018 14:54:00 -0800 Subject: [PATCH 2120/2461] [SPARK-24553][UI][FOLLOWUP] Fix unnecessary UI redirect ## What changes were proposed in this pull request? This PR is a follow-up PR of #21600 to fix the unnecessary UI redirect. ## How was this patch tested? Local verification Closes #23116 from jerryshao/SPARK-24553. Authored-by: jerryshao Signed-off-by: Dongjoon Hyun --- .../main/scala/org/apache/spark/ui/jobs/StageTable.scala | 2 +- .../scala/org/apache/spark/ui/storage/StoragePage.scala | 2 +- .../org/apache/spark/ui/storage/StoragePageSuite.scala | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index b9abd39b4705d..766efc15e26ba 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -368,7 +368,7 @@ private[ui] class StagePagedTable( {if (cachedRddInfos.nonEmpty) { Text("RDD: ") ++ cachedRddInfos.map { i => -
      {i.name} + {i.name} } }}
      {s.details}
      diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 3eb546e336e99..2488197814ffd 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -78,7 +78,7 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends {rdd.id} - {rdd.name} diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala index cdc7f541b9552..06f01a60868f9 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala @@ -81,19 +81,19 @@ class StoragePageSuite extends SparkFunSuite { Seq("1", "rdd1", "Memory Deserialized 1x Replicated", "10", "100%", "100.0 B", "0.0 B")) // Check the url assert(((xmlNodes \\ "tr")(0) \\ "td" \ "a")(0).attribute("href").map(_.text) === - Some("http://localhost:4040/storage/rdd?id=1")) + Some("http://localhost:4040/storage/rdd/?id=1")) assert(((xmlNodes \\ "tr")(1) \\ "td").map(_.text.trim) === Seq("2", "rdd2", "Disk Serialized 1x Replicated", "5", "50%", "0.0 B", "200.0 B")) // Check the url assert(((xmlNodes \\ "tr")(1) \\ "td" \ "a")(0).attribute("href").map(_.text) === - Some("http://localhost:4040/storage/rdd?id=2")) + Some("http://localhost:4040/storage/rdd/?id=2")) assert(((xmlNodes \\ "tr")(2) \\ "td").map(_.text.trim) === Seq("3", "rdd3", "Disk Memory Serialized 1x Replicated", "10", "100%", "400.0 B", "500.0 B")) // Check the url assert(((xmlNodes \\ "tr")(2) \\ "td" \ "a")(0).attribute("href").map(_.text) === - Some("http://localhost:4040/storage/rdd?id=3")) + Some("http://localhost:4040/storage/rdd/?id=3")) } test("empty rddTable") { From 0ec7b99ea2b638453ed38bb092905bee4f907fe5 Mon Sep 17 00:00:00 2001 From: Alon Doron Date: Fri, 23 Nov 2018 08:55:00 +0800 Subject: [PATCH 2121/2461] [SPARK-26021][SQL] replace minus zero with zero in Platform.putDouble/Float GROUP BY treats -0.0 and 0.0 as different values which is unlike hive's behavior. In addition current behavior with codegen is unpredictable (see example in JIRA ticket). ## What changes were proposed in this pull request? In Platform.putDouble/Float() checking if the value is -0.0, and if so replacing with 0.0. This is used by UnsafeRow so it won't have -0.0 values. ## How was this patch tested? Added tests Closes #23043 from adoron/adoron-spark-26021-replace-minus-zero-with-zero. Authored-by: Alon Doron Signed-off-by: Wenchen Fan --- .../java/org/apache/spark/unsafe/Platform.java | 10 ++++++++++ .../org/apache/spark/unsafe/PlatformUtilSuite.java | 14 ++++++++++++++ .../spark/sql/catalyst/expressions/UnsafeRow.java | 6 ------ .../catalyst/expressions/codegen/UnsafeWriter.java | 6 ------ .../apache/spark/sql/DataFrameAggregateSuite.scala | 14 ++++++++++++++ .../scala/org/apache/spark/sql/QueryTest.scala | 5 ++++- 6 files changed, 42 insertions(+), 13 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 076b693f81c88..4563efcfcf474 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -174,6 +174,11 @@ public static float getFloat(Object object, long offset) { } public static void putFloat(Object object, long offset, float value) { + if (Float.isNaN(value)) { + value = Float.NaN; + } else if (value == -0.0f) { + value = 0.0f; + } _UNSAFE.putFloat(object, offset, value); } @@ -182,6 +187,11 @@ public static double getDouble(Object object, long offset) { } public static void putDouble(Object object, long offset, double value) { + if (Double.isNaN(value)) { + value = Double.NaN; + } else if (value == -0.0d) { + value = 0.0d; + } _UNSAFE.putDouble(object, offset, value); } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 3ad9ac7b4de9c..ab34324eb54cc 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -157,4 +157,18 @@ public void heapMemoryReuse() { Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7); Assert.assertEquals(obj3, onheap4.getBaseObject()); } + + @Test + // SPARK-26021 + public void writeMinusZeroIsReplacedWithZero() { + byte[] doubleBytes = new byte[Double.BYTES]; + byte[] floatBytes = new byte[Float.BYTES]; + Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, -0.0d); + Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f); + double doubleFromPlatform = Platform.getDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET); + float floatFromPlatform = Platform.getFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET); + + Assert.assertEquals(Double.doubleToLongBits(0.0d), Double.doubleToLongBits(doubleFromPlatform)); + Assert.assertEquals(Float.floatToIntBits(0.0f), Float.floatToIntBits(floatFromPlatform)); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index a76e6ef8c91c1..9bf9452855f5f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -224,9 +224,6 @@ public void setLong(int ordinal, long value) { public void setDouble(int ordinal, double value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - if (Double.isNaN(value)) { - value = Double.NaN; - } Platform.putDouble(baseObject, getFieldOffset(ordinal), value); } @@ -255,9 +252,6 @@ public void setByte(int ordinal, byte value) { public void setFloat(int ordinal, float value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - if (Float.isNaN(value)) { - value = Float.NaN; - } Platform.putFloat(baseObject, getFieldOffset(ordinal), value); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index 2781655002000..95263a0da95a8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -199,16 +199,10 @@ protected final void writeLong(long offset, long value) { } protected final void writeFloat(long offset, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } Platform.putFloat(getBuffer(), offset, value); } protected final void writeDouble(long offset, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } Platform.putDouble(getBuffer(), offset, value); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d9ba6e2ce5120..ff64edcd07f4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -723,4 +723,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { "grouping expressions: [current_date(None)], value: [key: int, value: string], " + "type: GroupBy]")) } + + test("SPARK-26021: Double and Float 0.0/-0.0 should be equal when grouping") { + val colName = "i" + val doubles = Seq(0.0d, -0.0d, 0.0d).toDF(colName).groupBy(colName).count().collect() + val floats = Seq(0.0f, -0.0f, 0.0f).toDF(colName).groupBy(colName).count().collect() + + assert(doubles.length == 1) + assert(floats.length == 1) + // using compare since 0.0 == -0.0 is true + assert(java.lang.Double.compare(doubles(0).getDouble(0), 0.0d) == 0) + assert(java.lang.Float.compare(floats(0).getFloat(0), 0.0f) == 0) + assert(doubles(0).getLong(1) == 3) + assert(floats(0).getLong(1) == 3) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index baca9c1cfb9a0..8ba67239fb907 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -289,7 +289,7 @@ object QueryTest { def prepareRow(row: Row): Row = { Row.fromSeq(row.toSeq.map { case null => null - case d: java.math.BigDecimal => BigDecimal(d) + case bd: java.math.BigDecimal => BigDecimal(bd) // Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+ case seq: Seq[_] => seq.map { case b: java.lang.Byte => b.byteValue @@ -303,6 +303,9 @@ object QueryTest { // Convert array to Seq for easy equality check. case b: Array[_] => b.toSeq case r: Row => prepareRow(r) + // spark treats -0.0 as 0.0 + case d: Double if d == -0.0d => 0.0d + case f: Float if f == -0.0f => 0.0f case o => o }) } From 1d3dd58d21400b5652b75af7e7e53aad85a31528 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 22 Nov 2018 22:45:08 -0800 Subject: [PATCH 2122/2461] [SPARK-25954][SS][FOLLOWUP][TEST-MAVEN] Add Zookeeper 3.4.7 test dependency to Kafka modules ## What changes were proposed in this pull request? This is a followup of #23099 . After upgrading to Kafka 2.1.0, maven test fails due to Zookeeper test dependency while sbt test succeeds. - [sbt test on master branch](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.7/5203/) - [maven test on master branch](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-maven-hadoop-2.7/5653/) The root cause is that the embedded Kafka server is using [Zookeepr 3.4.7 API](https://zookeeper.apache.org/doc/r3.4.7/api/org/apache/zookeeper/AsyncCallback.MultiCallback.html ) while Apache Spark provides Zookeeper 3.4.6. This PR adds a test dependency. ``` KafkaMicroBatchV2SourceSuite: *** RUN ABORTED *** ... org.apache.spark.sql.kafka010.KafkaTestUtils.setupEmbeddedKafkaServer(KafkaTestUtils.scala:123) ... Cause: java.lang.ClassNotFoundException: org.apache.zookeeper.AsyncCallback$MultiCallback at java.net.URLClassLoader.findClass(URLClassLoader.java:381) at java.lang.ClassLoader.loadClass(ClassLoader.java:424) at sun.misc.Launcher$AppClassLoader.loadClass(Launcher.java:331) at java.lang.ClassLoader.loadClass(ClassLoader.java:357) at kafka.zk.KafkaZkClient$.apply(KafkaZkClient.scala:1693) at kafka.server.KafkaServer.createZkClient$1(KafkaServer.scala:348) at kafka.server.KafkaServer.initZkClient(KafkaServer.scala:372) at kafka.server.KafkaServer.startup(KafkaServer.scala:202) at org.apache.spark.sql.kafka010.KafkaTestUtils.$anonfun$setupEmbeddedKafkaServer$2(KafkaTestUtils.scala:120) at org.apache.spark.sql.kafka010.KafkaTestUtils.$anonfun$setupEmbeddedKafkaServer$2$adapted(KafkaTestUtils.scala:116) ... ``` ## How was this patch tested? Pass the maven Jenkins test. Closes #23119 from dongjoon-hyun/SPARK-25954-2. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- external/kafka-0-10-sql/pom.xml | 7 +++++++ external/kafka-0-10/pom.xml | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index d97e8cf18605e..1af407167597b 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -89,6 +89,13 @@
      + + + org.apache.zookeeper + zookeeper + 3.4.7 + test + net.sf.jopt-simple jopt-simple diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index cfc45559d8e34..ea18b7e035915 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -74,6 +74,13 @@ + + + org.apache.zookeeper + zookeeper + 3.4.7 + test + net.sf.jopt-simple jopt-simple From 92fc0a8f9619a8e7f8382d6a5c288aeceb03a472 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 23 Nov 2018 06:18:44 -0600 Subject: [PATCH 2123/2461] [SPARK-26069][TESTS][FOLLOWUP] Add another possible error message ## What changes were proposed in this pull request? `org.apache.spark.network.RpcIntegrationSuite.sendRpcWithStreamFailures` is still flaky and here is error message: ``` sbt.ForkMain$ForkError: java.lang.AssertionError: Got a non-empty set [Failed to send RPC RPC 8249697863992194475 to /172.17.0.2:41177: java.io.IOException: Broken pipe] at org.junit.Assert.fail(Assert.java:88) at org.junit.Assert.assertTrue(Assert.java:41) at org.apache.spark.network.RpcIntegrationSuite.assertErrorAndClosed(RpcIntegrationSuite.java:389) at org.apache.spark.network.RpcIntegrationSuite.sendRpcWithStreamFailures(RpcIntegrationSuite.java:347) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at org.junit.runners.model.FrameworkMethod$1.runReflectiveCall(FrameworkMethod.java:50) at org.junit.internal.runners.model.ReflectiveCallable.run(ReflectiveCallable.java:12) at org.junit.runners.model.FrameworkMethod.invokeExplosively(FrameworkMethod.java:47) at org.junit.internal.runners.statements.InvokeMethod.evaluate(InvokeMethod.java:17) at org.junit.runners.ParentRunner.runLeaf(ParentRunner.java:325) at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:78) at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:57) at org.junit.runners.ParentRunner$3.run(ParentRunner.java:290) at org.junit.runners.ParentRunner$1.schedule(ParentRunner.java:71) at org.junit.runners.ParentRunner.runChildren(ParentRunner.java:288) at org.junit.runners.ParentRunner.access$000(ParentRunner.java:58) at org.junit.runners.ParentRunner$2.evaluate(ParentRunner.java:268) at org.junit.internal.runners.statements.RunBefores.evaluate(RunBefores.java:26) at org.junit.internal.runners.statements.RunAfters.evaluate(RunAfters.java:27) at org.junit.runners.ParentRunner.run(ParentRunner.java:363) at org.junit.runners.Suite.runChild(Suite.java:128) at org.junit.runners.Suite.runChild(Suite.java:27) at org.junit.runners.ParentRunner$3.run(ParentRunner.java:290) at org.junit.runners.ParentRunner$1.schedule(ParentRunner.java:71) at org.junit.runners.ParentRunner.runChildren(ParentRunner.java:288) at org.junit.runners.ParentRunner.access$000(ParentRunner.java:58) at org.junit.runners.ParentRunner$2.evaluate(ParentRunner.java:268) at org.junit.runners.ParentRunner.run(ParentRunner.java:363) at org.junit.runner.JUnitCore.run(JUnitCore.java:137) at org.junit.runner.JUnitCore.run(JUnitCore.java:115) at com.novocode.junit.JUnitRunner$1.execute(JUnitRunner.java:132) at sbt.ForkMain$Run$2.call(ForkMain.java:296) at sbt.ForkMain$Run$2.call(ForkMain.java:286) at java.util.concurrent.FutureTask.run(FutureTask.java:266) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) at java.lang.Thread.run(Thread.java:748) ``` This happened when the second RPC message was being sent but the connection was closed at the same time. ## How was this patch tested? Jenkins Closes #23109 from zsxwing/SPARK-26069-2. Authored-by: Shixiong Zhu Signed-off-by: Sean Owen --- .../spark/network/RpcIntegrationSuite.java | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 45f4a1808562d..1c0aa4da27ff9 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -371,18 +371,20 @@ private void assertErrorsContain(Set errors, Set contains) { private void assertErrorAndClosed(RpcResult result, String expectedError) { assertTrue("unexpected success: " + result.successMessages, result.successMessages.isEmpty()); - // we expect 1 additional error, which should contain one of the follow messages: - // - "closed" - // - "Connection reset" - // - "java.nio.channels.ClosedChannelException" Set errors = result.errorMessages; assertEquals("Expected 2 errors, got " + errors.size() + "errors: " + errors, 2, errors.size()); + // We expect 1 additional error due to closed connection and here are possible keywords in the + // error message. + Set possibleClosedErrors = Sets.newHashSet( + "closed", + "Connection reset", + "java.nio.channels.ClosedChannelException", + "java.io.IOException: Broken pipe" + ); Set containsAndClosed = Sets.newHashSet(expectedError); - containsAndClosed.add("closed"); - containsAndClosed.add("Connection reset"); - containsAndClosed.add("java.nio.channels.ClosedChannelException"); + containsAndClosed.addAll(possibleClosedErrors); Pair, Set> r = checkErrorsContain(errors, containsAndClosed); @@ -390,7 +392,9 @@ private void assertErrorAndClosed(RpcResult result, String expectedError) { Set errorsNotFound = r.getRight(); assertEquals( - "The size of " + errorsNotFound.toString() + " was not 2", 2, errorsNotFound.size()); + "The size of " + errorsNotFound + " was not " + (possibleClosedErrors.size() - 1), + possibleClosedErrors.size() - 1, + errorsNotFound.size()); for (String err: errorsNotFound) { assertTrue("Found a wrong error " + err, containsAndClosed.contains(err)); } From 466d011d3515723653e41d8b1d0b6150b9945f52 Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Fri, 23 Nov 2018 21:12:25 +0800 Subject: [PATCH 2124/2461] [SPARK-26117][CORE][SQL] use SparkOutOfMemoryError instead of OutOfMemoryError when catch exception ## What changes were proposed in this pull request? the pr #20014 which introduced `SparkOutOfMemoryError` to avoid killing the entire executor when an `OutOfMemoryError `is thrown. so apply for memory using `MemoryConsumer. allocatePage `when catch exception, use `SparkOutOfMemoryError `instead of `OutOfMemoryError` ## How was this patch tested? N / A Closes #23084 from heary-cao/SparkOutOfMemoryError. Authored-by: caoxuewen Signed-off-by: Wenchen Fan --- .../java/org/apache/spark/memory/MemoryConsumer.java | 10 +++++----- .../org/apache/spark/unsafe/map/BytesToBytesMap.java | 5 +++-- .../unsafe/sort/UnsafeExternalSorterSuite.java | 7 ++++--- .../unsafe/sort/UnsafeInMemorySorterSuite.java | 5 +++-- .../catalyst/expressions/RowBasedKeyValueBatch.java | 3 ++- .../apache/spark/sql/execution/python/RowQueue.scala | 4 ++-- 6 files changed, 19 insertions(+), 15 deletions(-) diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 8371deca7311d..4bfd2d358f36f 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -83,10 +83,10 @@ public void spill() throws IOException { public abstract long spill(long size, MemoryConsumer trigger) throws IOException; /** - * Allocates a LongArray of `size`. Note that this method may throw `OutOfMemoryError` if Spark - * doesn't have enough memory for this allocation, or throw `TooLargePageException` if this - * `LongArray` is too large to fit in a single page. The caller side should take care of these - * two exceptions, or make sure the `size` is small enough that won't trigger exceptions. + * Allocates a LongArray of `size`. Note that this method may throw `SparkOutOfMemoryError` + * if Spark doesn't have enough memory for this allocation, or throw `TooLargePageException` + * if this `LongArray` is too large to fit in a single page. The caller side should take care of + * these two exceptions, or make sure the `size` is small enough that won't trigger exceptions. * * @throws SparkOutOfMemoryError * @throws TooLargePageException @@ -111,7 +111,7 @@ public void freeArray(LongArray array) { /** * Allocate a memory block with at least `required` bytes. * - * @throws OutOfMemoryError + * @throws SparkOutOfMemoryError */ protected MemoryBlock allocatePage(long required) { MemoryBlock page = taskMemoryManager.allocatePage(Math.max(pageSize, required), this); diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 9b6cbab38cbcc..a4e88598f7607 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -31,6 +31,7 @@ import org.apache.spark.SparkEnv; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockManager; @@ -741,7 +742,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff if (numKeys >= growthThreshold && longArray.size() < MAX_CAPACITY) { try { growAndRehash(); - } catch (OutOfMemoryError oom) { + } catch (SparkOutOfMemoryError oom) { canGrowArray = false; } } @@ -757,7 +758,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff private boolean acquireNewPage(long required) { try { currentPage = allocatePage(required); - } catch (OutOfMemoryError e) { + } catch (SparkOutOfMemoryError e) { return false; } dataPages.add(currentPage); diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 411cd5cb57331..d1b29d90ad913 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -38,6 +38,7 @@ import org.apache.spark.executor.TaskMetrics; import org.apache.spark.internal.config.package$; import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.serializer.JavaSerializer; import org.apache.spark.serializer.SerializerInstance; @@ -534,10 +535,10 @@ public void testOOMDuringSpill() throws Exception { insertNumber(sorter, 1024); fail("expected OutOfMmoryError but it seems operation surprisingly succeeded"); } - // we expect an OutOfMemoryError here, anything else (i.e the original NPE is a failure) - catch (OutOfMemoryError oom){ + // we expect an SparkOutOfMemoryError here, anything else (i.e the original NPE is a failure) + catch (SparkOutOfMemoryError oom){ String oomStackTrace = Utils.exceptionString(oom); - assertThat("expected OutOfMemoryError in " + + assertThat("expected SparkOutOfMemoryError in " + "org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset", oomStackTrace, Matchers.containsString( diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 85ffdca436e14..b0d485f0c953f 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -27,6 +27,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.memory.TestMemoryConsumer; import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -178,8 +179,8 @@ public int compare( testMemoryManager.markExecutionAsOutOfMemoryOnce(); try { sorter.reset(); - fail("expected OutOfMmoryError but it seems operation surprisingly succeeded"); - } catch (OutOfMemoryError oom) { + fail("expected SparkOutOfMemoryError but it seems operation surprisingly succeeded"); + } catch (SparkOutOfMemoryError oom) { // as expected } // [SPARK-21907] this failed on NPE at diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java index 460513816dfd9..6344cf18c11b8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java @@ -20,6 +20,7 @@ import java.io.IOException; import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -126,7 +127,7 @@ public final void close() { private boolean acquirePage(long requiredSize) { try { page = allocatePage(requiredSize); - } catch (OutOfMemoryError e) { + } catch (SparkOutOfMemoryError e) { logger.warn("Failed to allocate page ({} bytes).", requiredSize); return false; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala index d2820ff335ecf..eb12641f548ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala @@ -23,7 +23,7 @@ import com.google.common.io.Closeables import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.io.NioBufferedFileInputStream -import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager} +import org.apache.spark.memory.{MemoryConsumer, SparkOutOfMemoryError, TaskMemoryManager} import org.apache.spark.serializer.SerializerManager import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.unsafe.Platform @@ -226,7 +226,7 @@ private[python] case class HybridRowQueue( val page = try { allocatePage(required) } catch { - case _: OutOfMemoryError => + case _: SparkOutOfMemoryError => null } val buffer = if (page != null) { From 8e8d1177e623d5f995fb9ba1d9574675e1e70d56 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 24 Nov 2018 00:50:20 +0900 Subject: [PATCH 2125/2461] [SPARK-26108][SQL] Support custom lineSep in CSV datasource ## What changes were proposed in this pull request? In the PR, I propose new options for CSV datasource - `lineSep` similar to Text and JSON datasource. The option allows to specify custom line separator of maximum length of 2 characters (because of a restriction in `uniVocity` parser). New option can be used in reading and writing CSV files. ## How was this patch tested? Added a few tests with custom `lineSep` for enabled/disabled `multiLine` in read as well as tests in write. Also I added roundtrip tests. Closes #23080 from MaxGekk/csv-line-sep. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- python/pyspark/sql/readwriter.py | 13 ++- python/pyspark/sql/streaming.py | 7 +- .../spark/sql/catalyst/csv/CSVOptions.scala | 23 +++- .../apache/spark/sql/DataFrameReader.scala | 2 + .../apache/spark/sql/DataFrameWriter.scala | 2 + .../datasources/csv/CSVDataSource.scala | 2 +- .../sql/streaming/DataStreamReader.scala | 2 + .../execution/datasources/csv/CSVSuite.scala | 110 +++++++++++++++++- 8 files changed, 151 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 726de4a965418..1d2dd4d808930 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -353,7 +353,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, - samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None): + samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None, lineSep=None): r"""Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -453,6 +453,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set, it uses the default value, ``en-US``. For instance, ``locale`` is used while parsing dates and timestamps. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. + Maximum length is 1 character. >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes @@ -472,7 +475,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio, - enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale) + enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale, lineSep=lineSep) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -868,7 +871,7 @@ def text(self, path, compression=None, lineSep=None): def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None, header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None, timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, - charToEscapeQuoteEscaping=None, encoding=None, emptyValue=None): + charToEscapeQuoteEscaping=None, encoding=None, emptyValue=None, lineSep=None): r"""Saves the content of the :class:`DataFrame` in CSV format at the specified path. :param path: the path in any Hadoop supported file system @@ -922,6 +925,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No the default UTF-8 charset will be used. :param emptyValue: sets the string representation of an empty value. If None is set, it uses the default value, ``""``. + :param lineSep: defines the line separator that should be used for writing. If None is + set, it uses the default value, ``\\n``. Maximum length is 1 character. >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) """ @@ -932,7 +937,7 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace, ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, - encoding=encoding, emptyValue=emptyValue) + encoding=encoding, emptyValue=emptyValue, lineSep=lineSep) self._jwrite.csv(path) @since(1.5) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 58ca7b83e5b2b..d92b0d5677e25 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -576,7 +576,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, - enforceSchema=None, emptyValue=None, locale=None): + enforceSchema=None, emptyValue=None, locale=None, lineSep=None): r"""Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -675,6 +675,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set, it uses the default value, ``en-US``. For instance, ``locale`` is used while parsing dates and timestamps. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. + Maximum length is 1 character. >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema) >>> csv_sdf.isStreaming @@ -692,7 +695,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema, - emptyValue=emptyValue, locale=locale) + emptyValue=emptyValue, locale=locale, lineSep=lineSep) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala index 6bb50b42a369c..94bdb72d675d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala @@ -192,6 +192,20 @@ class CSVOptions( */ val emptyValueInWrite = emptyValue.getOrElse("\"\"") + /** + * A string between two consecutive JSON records. + */ + val lineSeparator: Option[String] = parameters.get("lineSep").map { sep => + require(sep.nonEmpty, "'lineSep' cannot be an empty string.") + require(sep.length == 1, "'lineSep' can contain only 1 character.") + sep + } + + val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => + lineSep.getBytes(charset) + } + val lineSeparatorInWrite: Option[String] = lineSeparator + def asWriterSettings: CsvWriterSettings = { val writerSettings = new CsvWriterSettings() val format = writerSettings.getFormat @@ -200,6 +214,8 @@ class CSVOptions( format.setQuoteEscape(escape) charToEscapeQuoteEscaping.foreach(format.setCharToEscapeQuoteEscaping) format.setComment(comment) + lineSeparatorInWrite.foreach(format.setLineSeparator) + writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite) writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite) writerSettings.setNullValue(nullValue) @@ -216,8 +232,10 @@ class CSVOptions( format.setDelimiter(delimiter) format.setQuote(quote) format.setQuoteEscape(escape) + lineSeparator.foreach(format.setLineSeparator) charToEscapeQuoteEscaping.foreach(format.setCharToEscapeQuoteEscaping) format.setComment(comment) + settings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceInRead) settings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceInRead) settings.setReadInputOnSeparateThread(false) @@ -227,7 +245,10 @@ class CSVOptions( settings.setEmptyValue(emptyValueInRead) settings.setMaxCharsPerColumn(maxCharsPerColumn) settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER) - settings.setLineSeparatorDetectionEnabled(multiLine == true) + settings.setLineSeparatorDetectionEnabled(lineSeparatorInRead.isEmpty && multiLine) + lineSeparatorInRead.foreach { _ => + settings.setNormalizeLineEndingsWithinQuotes(!multiLine) + } settings } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index f08fd64acd9a1..da88598eed061 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -609,6 +609,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
    • `multiLine` (default `false`): parse one record, which may span multiple lines.
    • *
    • `locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format. * For instance, this is used while parsing dates and timestamps.
    • + *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing. Maximum length is 1 character.
    • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 29d479f542115..5a807d3d4b93e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -658,6 +658,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * whitespaces from values being written should be skipped. *
    • `ignoreTrailingWhiteSpace` (default `true`): a flag indicating defines whether or not * trailing whitespaces from values being written should be skipped.
    • + *
    • `lineSep` (default `\n`): defines the line separator that should be used for writing. + * Maximum length is 1 character.
    • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 554baaf1a9b3b..b35b8851918b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -95,7 +95,7 @@ object TextInputCSVDataSource extends CSVDataSource { headerChecker: CSVHeaderChecker, requiredSchema: StructType): Iterator[InternalRow] = { val lines = { - val linesReader = new HadoopFileLinesReader(file, conf) + val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close())) linesReader.map { line => new String(line.getBytes, 0, line.getLength, parser.options.charset) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index e4250145a1ae2..c8e3e1c191044 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -377,6 +377,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
    • `multiLine` (default `false`): parse one record, which may span multiple lines.
    • *
    • `locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format. * For instance, this is used while parsing dates and timestamps.
    • + *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing. Maximum length is 1 character.
    • * * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index e29cd2aa7c4e6..c275d63d32cc8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.csv import java.io.File -import java.nio.charset.{Charset, UnsupportedCharsetException} +import java.nio.charset.{Charset, StandardCharsets, UnsupportedCharsetException} import java.nio.file.Files import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat @@ -33,7 +33,7 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.log4j.{AppenderSkeleton, LogManager} import org.apache.log4j.spi.LoggingEvent -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, TestUtils} import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf @@ -1880,4 +1880,110 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } } } + + test("""Support line separator - default value \r, \r\n and \n""") { + val data = "\"a\",1\r\"c\",2\r\n\"d\",3\n" + + withTempPath { path => + Files.write(path.toPath, data.getBytes(StandardCharsets.UTF_8)) + val df = spark.read.option("inferSchema", true).csv(path.getAbsolutePath) + val expectedSchema = + StructType(StructField("_c0", StringType) :: StructField("_c1", IntegerType) :: Nil) + checkAnswer(df, Seq(("a", 1), ("c", 2), ("d", 3)).toDF()) + assert(df.schema === expectedSchema) + } + } + + def testLineSeparator(lineSep: String, encoding: String, inferSchema: Boolean, id: Int): Unit = { + test(s"Support line separator in ${encoding} #${id}") { + // Read + val data = + s""""a",1$lineSep + |c,2$lineSep" + |d",3""".stripMargin + val dataWithTrailingLineSep = s"$data$lineSep" + + Seq(data, dataWithTrailingLineSep).foreach { lines => + withTempPath { path => + Files.write(path.toPath, lines.getBytes(encoding)) + val schema = StructType(StructField("_c0", StringType) + :: StructField("_c1", LongType) :: Nil) + + val expected = Seq(("a", 1), ("\nc", 2), ("\nd", 3)) + .toDF("_c0", "_c1") + Seq(false, true).foreach { multiLine => + val reader = spark + .read + .option("lineSep", lineSep) + .option("multiLine", multiLine) + .option("encoding", encoding) + val df = if (inferSchema) { + reader.option("inferSchema", true).csv(path.getAbsolutePath) + } else { + reader.schema(schema).csv(path.getAbsolutePath) + } + checkAnswer(df, expected) + } + } + } + + // Write + withTempPath { path => + Seq("a", "b", "c").toDF("value").coalesce(1) + .write + .option("lineSep", lineSep) + .option("encoding", encoding) + .csv(path.getAbsolutePath) + val partFile = TestUtils.recursiveList(path).filter(f => f.getName.startsWith("part-")).head + val readBack = new String(Files.readAllBytes(partFile.toPath), encoding) + assert( + readBack === s"a${lineSep}b${lineSep}c${lineSep}") + } + + // Roundtrip + withTempPath { path => + val df = Seq("a", "b", "c").toDF() + df.write + .option("lineSep", lineSep) + .option("encoding", encoding) + .csv(path.getAbsolutePath) + val readBack = spark + .read + .option("lineSep", lineSep) + .option("encoding", encoding) + .csv(path.getAbsolutePath) + checkAnswer(df, readBack) + } + } + } + + // scalastyle:off nonascii + List( + (0, "|", "UTF-8", false), + (1, "^", "UTF-16BE", true), + (2, ":", "ISO-8859-1", true), + (3, "!", "UTF-32LE", false), + (4, 0x1E.toChar.toString, "UTF-8", true), + (5, "아", "UTF-32BE", false), + (6, "у", "CP1251", true), + (8, "\r", "UTF-16LE", true), + (9, "\u000d", "UTF-32BE", false), + (10, "=", "US-ASCII", false), + (11, "$", "utf-32le", true) + ).foreach { case (testNum, sep, encoding, inferSchema) => + testLineSeparator(sep, encoding, inferSchema, testNum) + } + // scalastyle:on nonascii + + test("lineSep restrictions") { + val errMsg1 = intercept[IllegalArgumentException] { + spark.read.option("lineSep", "").csv(testFile(carsFile)).collect + }.getMessage + assert(errMsg1.contains("'lineSep' cannot be an empty string")) + + val errMsg2 = intercept[IllegalArgumentException] { + spark.read.option("lineSep", "123").csv(testFile(carsFile)).collect + }.getMessage + assert(errMsg2.contains("'lineSep' can contain only 1 character")) + } } From ecb785f4e471ce3add66c67d0d8152dd237dbfaf Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 23 Nov 2018 21:08:06 +0100 Subject: [PATCH 2126/2461] [SPARK-26038] Decimal toScalaBigInt/toJavaBigInteger for decimals not fitting in long ## What changes were proposed in this pull request? Fix Decimal `toScalaBigInt` and `toJavaBigInteger` used to only work for decimals not fitting long. ## How was this patch tested? Added test to DecimalSuite. Closes #23022 from juliuszsompolski/SPARK-26038. Authored-by: Juliusz Sompolski Signed-off-by: Herman van Hovell --- .../org/apache/spark/sql/types/Decimal.scala | 16 ++++++++++++++-- .../apache/spark/sql/types/DecimalSuite.scala | 11 +++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index a3a844670e0c6..0192059a3a39f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -185,9 +185,21 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } - def toScalaBigInt: BigInt = BigInt(toLong) + def toScalaBigInt: BigInt = { + if (decimalVal.ne(null)) { + decimalVal.toBigInt() + } else { + BigInt(toLong) + } + } - def toJavaBigInteger: java.math.BigInteger = java.math.BigInteger.valueOf(toLong) + def toJavaBigInteger: java.math.BigInteger = { + if (decimalVal.ne(null)) { + decimalVal.underlying().toBigInteger() + } else { + java.math.BigInteger.valueOf(toLong) + } + } def toUnscaledLong: Long = { if (decimalVal.ne(null)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 10de90c6a44ca..8abd7625c21aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -228,4 +228,15 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { val decimal = Decimal.apply(bigInt) assert(decimal.toJavaBigDecimal.unscaledValue.toString === "9223372036854775808") } + + test("SPARK-26038: toScalaBigInt/toJavaBigInteger") { + // not fitting long + val decimal = Decimal("1234568790123456789012348790.1234879012345678901234568790") + assert(decimal.toScalaBigInt == scala.math.BigInt("1234568790123456789012348790")) + assert(decimal.toJavaBigInteger == new java.math.BigInteger("1234568790123456789012348790")) + // fitting long + val decimalLong = Decimal(123456789123456789L, 18, 9) + assert(decimalLong.toScalaBigInt == scala.math.BigInt("123456789")) + assert(decimalLong.toJavaBigInteger == new java.math.BigInteger("123456789")) + } } From de84899204f3428f3d1d688b277dc06b021d860a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 23 Nov 2018 14:14:21 -0800 Subject: [PATCH 2127/2461] [SPARK-26140] Enable custom metrics implementation in shuffle reader ## What changes were proposed in this pull request? This patch defines an internal Spark interface for reporting shuffle metrics and uses that in shuffle reader. Before this patch, shuffle metrics is tied to a specific implementation (using a thread local temporary data structure and accumulators). After this patch, callers that define their own shuffle RDDs can create a custom metrics implementation. With this patch, we would be able to create a better metrics for the SQL layer, e.g. reporting shuffle metrics in the SQL UI, for each exchange operator. Note that I'm separating read side and write side implementations, as they are very different, to simplify code review. Write side change is at https://github.com/apache/spark/pull/23106 ## How was this patch tested? No behavior change expected, as it is a straightforward refactoring. Updated all existing test cases. Closes #23105 from rxin/SPARK-26140. Authored-by: Reynold Xin Signed-off-by: gatorsmile --- .../spark/executor/ShuffleReadMetrics.scala | 18 ++++--- .../org/apache/spark/rdd/CoGroupedRDD.scala | 4 +- .../org/apache/spark/rdd/ShuffledRDD.scala | 4 +- .../org/apache/spark/rdd/SubtractedRDD.scala | 7 ++- .../shuffle/BlockStoreShuffleReader.scala | 5 +- .../apache/spark/shuffle/ShuffleManager.scala | 3 +- .../shuffle/ShuffleMetricsReporter.scala | 33 ++++++++++++ .../org/apache/spark/shuffle/metrics.scala | 52 +++++++++++++++++++ .../shuffle/sort/SortShuffleManager.scala | 6 ++- .../storage/ShuffleBlockFetcherIterator.scala | 10 ++-- .../scala/org/apache/spark/ShuffleSuite.scala | 6 ++- .../spark/scheduler/CustomShuffledRDD.scala | 3 +- .../BlockStoreShuffleReaderSuite.scala | 5 +- .../ShuffleBlockFetcherIteratorSuite.scala | 31 +++++++---- .../spark/sql/execution/ShuffledRowRDD.scala | 4 +- 15 files changed, 155 insertions(+), 36 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/shuffle/ShuffleMetricsReporter.scala create mode 100644 core/src/main/scala/org/apache/spark/shuffle/metrics.scala diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala index 4be395c8358b2..2f97e969d2dd2 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala @@ -18,6 +18,7 @@ package org.apache.spark.executor import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.shuffle.ShuffleMetricsReporter import org.apache.spark.util.LongAccumulator @@ -123,12 +124,13 @@ class ShuffleReadMetrics private[spark] () extends Serializable { } } + /** * A temporary shuffle read metrics holder that is used to collect shuffle read metrics for each * shuffle dependency, and all temporary metrics will be merged into the [[ShuffleReadMetrics]] at * last. */ -private[spark] class TempShuffleReadMetrics { +private[spark] class TempShuffleReadMetrics extends ShuffleMetricsReporter { private[this] var _remoteBlocksFetched = 0L private[this] var _localBlocksFetched = 0L private[this] var _remoteBytesRead = 0L @@ -137,13 +139,13 @@ private[spark] class TempShuffleReadMetrics { private[this] var _fetchWaitTime = 0L private[this] var _recordsRead = 0L - def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched += v - def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched += v - def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead += v - def incRemoteBytesReadToDisk(v: Long): Unit = _remoteBytesReadToDisk += v - def incLocalBytesRead(v: Long): Unit = _localBytesRead += v - def incFetchWaitTime(v: Long): Unit = _fetchWaitTime += v - def incRecordsRead(v: Long): Unit = _recordsRead += v + override def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched += v + override def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched += v + override def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead += v + override def incRemoteBytesReadToDisk(v: Long): Unit = _remoteBytesReadToDisk += v + override def incLocalBytesRead(v: Long): Unit = _localBytesRead += v + override def incFetchWaitTime(v: Long): Unit = _fetchWaitTime += v + override def incRecordsRead(v: Long): Unit = _recordsRead += v def remoteBlocksFetched: Long = _remoteBlocksFetched def localBlocksFetched: Long = _localBlocksFetched diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 4574c3724962e..7e76731f5e454 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -143,8 +143,10 @@ class CoGroupedRDD[K: ClassTag]( case shuffleDependency: ShuffleDependency[_, _, _] => // Read map outputs of shuffle + val metrics = context.taskMetrics().createTempShuffleReadMetrics() val it = SparkEnv.get.shuffleManager - .getReader(shuffleDependency.shuffleHandle, split.index, split.index + 1, context) + .getReader( + shuffleDependency.shuffleHandle, split.index, split.index + 1, context, metrics) .read() rddIterators += ((it, depNum)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index e8f9b27b7eb55..5ec99b7f4f3ab 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -101,7 +101,9 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag]( override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = { val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]] - SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) + val metrics = context.taskMetrics().createTempShuffleReadMetrics() + SparkEnv.get.shuffleManager.getReader( + dep.shuffleHandle, split.index, split.index + 1, context, metrics) .read() .asInstanceOf[Iterator[(K, C)]] } diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index a733eaa5d7e53..42d190377f104 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -107,9 +107,14 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( .asInstanceOf[Iterator[Product2[K, V]]].foreach(op) case shuffleDependency: ShuffleDependency[_, _, _] => + val metrics = context.taskMetrics().createTempShuffleReadMetrics() val iter = SparkEnv.get.shuffleManager .getReader( - shuffleDependency.shuffleHandle, partition.index, partition.index + 1, context) + shuffleDependency.shuffleHandle, + partition.index, + partition.index + 1, + context, + metrics) .read() iter.foreach(op) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 74b0e0b3a741a..7cb031ce318b7 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -33,6 +33,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( startPartition: Int, endPartition: Int, context: TaskContext, + readMetrics: ShuffleMetricsReporter, serializerManager: SerializerManager = SparkEnv.get.serializerManager, blockManager: BlockManager = SparkEnv.get.blockManager, mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) @@ -53,7 +54,8 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), - SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) + SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true), + readMetrics) val serializerInstance = dep.serializer.newInstance() @@ -66,7 +68,6 @@ private[spark] class BlockStoreShuffleReader[K, C]( } // Update the context task metrics for each record read. - val readMetrics = context.taskMetrics.createTempShuffleReadMetrics() val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( recordIter.map { record => readMetrics.incRecordsRead(1) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index 4ea8a7120a9cc..d1061d83cb85a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -48,7 +48,8 @@ private[spark] trait ShuffleManager { handle: ShuffleHandle, startPartition: Int, endPartition: Int, - context: TaskContext): ShuffleReader[K, C] + context: TaskContext, + metrics: ShuffleMetricsReporter): ShuffleReader[K, C] /** * Remove a shuffle's metadata from the ShuffleManager. diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMetricsReporter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMetricsReporter.scala new file mode 100644 index 0000000000000..32865149c97c2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMetricsReporter.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +/** + * An interface for reporting shuffle information, for each shuffle. This interface assumes + * all the methods are called on a single-threaded, i.e. concrete implementations would not need + * to synchronize anything. + */ +private[spark] trait ShuffleMetricsReporter { + def incRemoteBlocksFetched(v: Long): Unit + def incLocalBlocksFetched(v: Long): Unit + def incRemoteBytesRead(v: Long): Unit + def incRemoteBytesReadToDisk(v: Long): Unit + def incLocalBytesRead(v: Long): Unit + def incFetchWaitTime(v: Long): Unit + def incRecordsRead(v: Long): Unit +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/metrics.scala b/core/src/main/scala/org/apache/spark/shuffle/metrics.scala new file mode 100644 index 0000000000000..33be677bc90cb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/metrics.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +/** + * An interface for reporting shuffle read metrics, for each shuffle. This interface assumes + * all the methods are called on a single-threaded, i.e. concrete implementations would not need + * to synchronize. + * + * All methods have additional Spark visibility modifier to allow public, concrete implementations + * that still have these methods marked as private[spark]. + */ +private[spark] trait ShuffleReadMetricsReporter { + private[spark] def incRemoteBlocksFetched(v: Long): Unit + private[spark] def incLocalBlocksFetched(v: Long): Unit + private[spark] def incRemoteBytesRead(v: Long): Unit + private[spark] def incRemoteBytesReadToDisk(v: Long): Unit + private[spark] def incLocalBytesRead(v: Long): Unit + private[spark] def incFetchWaitTime(v: Long): Unit + private[spark] def incRecordsRead(v: Long): Unit +} + + +/** + * An interface for reporting shuffle write metrics. This interface assumes all the methods are + * called on a single-threaded, i.e. concrete implementations would not need to synchronize. + * + * All methods have additional Spark visibility modifier to allow public, concrete implementations + * that still have these methods marked as private[spark]. + */ +private[spark] trait ShuffleWriteMetricsReporter { + private[spark] def incBytesWritten(v: Long): Unit + private[spark] def incRecordsWritten(v: Long): Unit + private[spark] def incWriteTime(v: Long): Unit + private[spark] def decBytesWritten(v: Long): Unit + private[spark] def decRecordsWritten(v: Long): Unit +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 0caf84c6050a8..57c3150e5a697 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -114,9 +114,11 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager handle: ShuffleHandle, startPartition: Int, endPartition: Int, - context: TaskContext): ShuffleReader[K, C] = { + context: TaskContext, + metrics: ShuffleMetricsReporter): ShuffleReader[K, C] = { new BlockStoreShuffleReader( - handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + startPartition, endPartition, context, metrics) } /** Get a writer for a given partition. Called on executors by map tasks. */ diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index aecc2284a9588..a2e0713e70b04 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -30,7 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ import org.apache.spark.network.util.TransportConf -import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.shuffle.{FetchFailedException, ShuffleMetricsReporter} import org.apache.spark.util.Utils import org.apache.spark.util.io.ChunkedByteBufferOutputStream @@ -51,7 +51,7 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * For each block we also require the size (in bytes as a long field) in * order to throttle the memory usage. Note that zero-sized blocks are * already excluded, which happened in - * [[MapOutputTracker.convertMapStatuses]]. + * [[org.apache.spark.MapOutputTracker.convertMapStatuses]]. * @param streamWrapper A function to wrap the returned input stream. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. @@ -59,6 +59,7 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * for a given remote host:port. * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. * @param detectCorrupt whether to detect any corruption in fetched blocks. + * @param shuffleMetrics used to report shuffle metrics. */ private[spark] final class ShuffleBlockFetcherIterator( @@ -71,7 +72,8 @@ final class ShuffleBlockFetcherIterator( maxReqsInFlight: Int, maxBlocksInFlightPerAddress: Int, maxReqSizeShuffleToMem: Long, - detectCorrupt: Boolean) + detectCorrupt: Boolean, + shuffleMetrics: ShuffleMetricsReporter) extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { import ShuffleBlockFetcherIterator._ @@ -137,8 +139,6 @@ final class ShuffleBlockFetcherIterator( */ private[this] val corruptedBlocks = mutable.HashSet[BlockId]() - private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() - /** * Whether the iterator is still active. If isZombie is true, the callback interface will no * longer place fetched blocks into [[results]]. diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index b917469e48747..419a26b857ea2 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -397,8 +397,10 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC mapTrackerMaster.registerMapOutput(0, 0, mapStatus) } - val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, - new TaskContextImpl(1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)) + val taskContext = new TaskContextImpl( + 1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem) + val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() + val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, taskContext, metrics) val readData = reader.read().toIndexedSeq assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) diff --git a/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala b/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala index 838686923767e..1be2e2a067115 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala @@ -104,8 +104,9 @@ class CustomShuffledRDD[K, V, C]( override def compute(p: Partition, context: TaskContext): Iterator[(K, C)] = { val part = p.asInstanceOf[CustomShuffledRDDPartition] + val metrics = context.taskMetrics().createTempShuffleReadMetrics() SparkEnv.get.shuffleManager.getReader( - dependency.shuffleHandle, part.startIndexInParent, part.endIndexInParent, context) + dependency.shuffleHandle, part.startIndexInParent, part.endIndexInParent, context, metrics) .read() .asInstanceOf[Iterator[(K, C)]] } diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 2d8a83c6fabed..eb97d5a1e5074 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -126,11 +126,14 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext .set("spark.shuffle.compress", "false") .set("spark.shuffle.spill.compress", "false")) + val taskContext = TaskContext.empty() + val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, - TaskContext.empty(), + taskContext, + metrics, serializerManager, blockManager, mapOutputTracker) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index b268195e09a5b..01ee9ef0825f8 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -102,8 +102,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq) ).toIterator + val taskContext = TaskContext.empty() + val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() val iterator = new ShuffleBlockFetcherIterator( - TaskContext.empty(), + taskContext, transfer, blockManager, blocksByAddress, @@ -112,7 +114,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true) + true, + metrics) // 3 local blocks fetched in initialization verify(blockManager, times(3)).getBlockData(any()) @@ -190,7 +193,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true) + true, + taskContext.taskMetrics.createTempShuffleReadMetrics()) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() iterator.next()._2.close() // close() first block's input stream @@ -258,7 +262,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true) + true, + taskContext.taskMetrics.createTempShuffleReadMetrics()) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -328,7 +333,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true) + true, + taskContext.taskMetrics.createTempShuffleReadMetrics()) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -392,7 +398,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true) + true, + taskContext.taskMetrics.createTempShuffleReadMetrics()) // Blocks should be returned without exceptions. assert(Set(iterator.next()._1, iterator.next()._1) === Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0))) @@ -446,7 +453,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - false) + false, + taskContext.taskMetrics.createTempShuffleReadMetrics()) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -496,8 +504,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here. + val taskContext = TaskContext.empty() new ShuffleBlockFetcherIterator( - TaskContext.empty(), + taskContext, transfer, blockManager, blocksByAddress, @@ -506,7 +515,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT maxReqsInFlight = Int.MaxValue, maxBlocksInFlightPerAddress = Int.MaxValue, maxReqSizeShuffleToMem = 200, - detectCorrupt = true) + detectCorrupt = true, + taskContext.taskMetrics.createTempShuffleReadMetrics()) } val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( @@ -552,7 +562,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true) + true, + taskContext.taskMetrics.createTempShuffleReadMetrics()) // All blocks fetched return zero length and should trigger a receive-side error: val e = intercept[FetchFailedException] { iterator.next() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala index 862ee05392f37..542266bc1ae07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -154,6 +154,7 @@ class ShuffledRowRDD( override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { val shuffledRowPartition = split.asInstanceOf[ShuffledRowRDDPartition] + val metrics = context.taskMetrics().createTempShuffleReadMetrics() // The range of pre-shuffle partitions that we are fetching at here is // [startPreShufflePartitionIndex, endPreShufflePartitionIndex - 1]. val reader = @@ -161,7 +162,8 @@ class ShuffledRowRDD( dependency.shuffleHandle, shuffledRowPartition.startPreShufflePartitionIndex, shuffledRowPartition.endPreShufflePartitionIndex, - context) + context, + metrics) reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2) } From 7f5f7a967d36d78f73d8fa1e178dfdb324d73bf1 Mon Sep 17 00:00:00 2001 From: liuxian Date: Sat, 24 Nov 2018 09:10:15 -0600 Subject: [PATCH 2128/2461] [SPARK-25786][CORE] If the ByteBuffer.hasArray is false , it will throw UnsupportedOperationException for Kryo ## What changes were proposed in this pull request? `deserialize` for kryo, the type of input parameter is ByteBuffer, if it is not backed by an accessible byte array. it will throw `UnsupportedOperationException` Exception Info: ``` java.lang.UnsupportedOperationException was thrown. java.lang.UnsupportedOperationException at java.nio.ByteBuffer.array(ByteBuffer.java:994) at org.apache.spark.serializer.KryoSerializerInstance.deserialize(KryoSerializer.scala:362) ``` ## How was this patch tested? Added a unit test Closes #22779 from 10110346/InputStreamKryo. Authored-by: liuxian Signed-off-by: Sean Owen --- .../apache/spark/serializer/KryoSerializer.scala | 16 +++++++++++++--- .../spark/serializer/KryoSerializerSuite.scala | 12 ++++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 66812a54846c6..1e1c27c477877 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -42,7 +42,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ -import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf, Utils} +import org.apache.spark.util.{BoundedPriorityQueue, ByteBufferInputStream, SerializableConfiguration, SerializableJobConf, Utils} import org.apache.spark.util.collection.CompactBuffer /** @@ -417,7 +417,12 @@ private[spark] class KryoSerializerInstance( override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { val kryo = borrowKryo() try { - input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) + if (bytes.hasArray) { + input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) + } else { + input.setBuffer(new Array[Byte](4096)) + input.setInputStream(new ByteBufferInputStream(bytes)) + } kryo.readClassAndObject(input).asInstanceOf[T] } finally { releaseKryo(kryo) @@ -429,7 +434,12 @@ private[spark] class KryoSerializerInstance( val oldClassLoader = kryo.getClassLoader try { kryo.setClassLoader(loader) - input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) + if (bytes.hasArray) { + input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) + } else { + input.setBuffer(new Array[Byte](4096)) + input.setInputStream(new ByteBufferInputStream(bytes)) + } kryo.readClassAndObject(input).asInstanceOf[T] } finally { kryo.setClassLoader(oldClassLoader) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index e413fe3b774d0..a7eed4b6a8b88 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.serializer import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream} +import java.nio.ByteBuffer import java.util.concurrent.Executors import scala.collection.JavaConverters._ @@ -551,6 +552,17 @@ class KryoSerializerAutoResetDisabledSuite extends SparkFunSuite with SharedSpar deserializationStream.close() assert(serInstance.deserialize[Any](helloHello) === ((hello, hello))) } + + test("SPARK-25786: ByteBuffer.array -- UnsupportedOperationException") { + val serInstance = new KryoSerializer(conf).newInstance().asInstanceOf[KryoSerializerInstance] + val obj = "UnsupportedOperationException" + val serObj = serInstance.serialize(obj) + val byteBuffer = ByteBuffer.allocateDirect(serObj.array().length) + byteBuffer.put(serObj.array()) + byteBuffer.flip() + assert(serInstance.deserialize[Any](serObj) === (obj)) + assert(serInstance.deserialize[Any](byteBuffer) === (obj)) + } } class ClassLoaderTestingObject From 0f56977f8c9bfc48230d499925e31ff81bcd0f86 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sat, 24 Nov 2018 09:12:05 -0600 Subject: [PATCH 2129/2461] [SPARK-26156][WEBUI] Revise summary section of stage page ## What changes were proposed in this pull request? In the summary section of stage page: ![image](https://user-images.githubusercontent.com/1097932/48935518-ebef2b00-ef42-11e8-8672-eaa4cac92c5e.png) 1. the following metrics names can be revised: Output => Output Size / Records Shuffle Read: => Shuffle Read Size / Records Shuffle Write => Shuffle Write Size / Records After changes, the names are more clear, and consistent with the other names in the same page. 2. The associated job id URL should not contain the 3 tails spaces. Reduce the number of spaces to one, and exclude the space from link. This is consistent with SQL execution page. ## How was this patch tested? Manual check: ![image](https://user-images.githubusercontent.com/1097932/48935538-f7425680-ef42-11e8-8b2a-a4f388d3ea52.png) Closes #23125 from gengliangwang/reviseStagePage. Authored-by: Gengliang Wang Signed-off-by: Sean Owen --- .../org/apache/spark/ui/jobs/StagePage.scala | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 7e6cc4297d6b1..2b436b9234144 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -152,20 +152,20 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We }} {if (hasOutput(stageData)) {
    • - Output: + Output Size / Records: {s"${Utils.bytesToString(stageData.outputBytes)} / ${stageData.outputRecords}"}
    • }} {if (hasShuffleRead(stageData)) {
    • - Shuffle Read: + Shuffle Read Size / Records: {s"${Utils.bytesToString(stageData.shuffleReadBytes)} / " + s"${stageData.shuffleReadRecords}"}
    • }} {if (hasShuffleWrite(stageData)) {
    • - Shuffle Write: + Shuffle Write Size / Records: {s"${Utils.bytesToString(stageData.shuffleWriteBytes)} / " + s"${stageData.shuffleWriteRecords}"}
    • @@ -183,10 +183,11 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We {if (!stageJobIds.isEmpty) {
    • Associated Job Ids: - {stageJobIds.map(jobId => {val detailUrl = "%s/jobs/job/?id=%s".format( - UIUtils.prependBaseUri(request, parent.basePath), jobId) - {s"${jobId}"}    - })} + {stageJobIds.sorted.map { jobId => + val jobURL = "%s/jobs/job/?id=%s" + .format(UIUtils.prependBaseUri(request, parent.basePath), jobId) + {jobId.toString}  + }}
    • }} From eea4a0330b913cd45e369f09ec3d1dbb1b81f1b5 Mon Sep 17 00:00:00 2001 From: Lee moon soo Date: Sat, 24 Nov 2018 16:09:13 -0800 Subject: [PATCH 2130/2461] [MINOR][K8S] Invalid property "spark.driver.pod.name" is referenced in docs. ## What changes were proposed in this pull request? "Running on Kubernetes" references `spark.driver.pod.name` few places, and it should be `spark.kubernetes.driver.pod.name`. ## How was this patch tested? See changes Closes #23133 from Leemoonsoo/fix-driver-pod-name-prop. Authored-by: Lee moon soo Signed-off-by: Dongjoon Hyun --- docs/running-on-kubernetes.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index a9d448820e700..e940d9a63b7af 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -166,7 +166,7 @@ hostname via `spark.driver.host` and your spark driver's port to `spark.driver.p ### Client Mode Executor Pod Garbage Collection -If you run your Spark driver in a pod, it is highly recommended to set `spark.driver.pod.name` to the name of that pod. +If you run your Spark driver in a pod, it is highly recommended to set `spark.kubernetes.driver.pod.name` to the name of that pod. When this property is set, the Spark scheduler will deploy the executor pods with an [OwnerReference](https://kubernetes.io/docs/concepts/workloads/controllers/garbage-collection/), which in turn will ensure that once the driver pod is deleted from the cluster, all of the application's executor pods will also be deleted. @@ -175,7 +175,7 @@ an OwnerReference pointing to that pod will be added to each executor pod's Owne setting the OwnerReference to a pod that is not actually that driver pod, or else the executors may be terminated prematurely when the wrong pod is deleted. -If your application is not running inside a pod, or if `spark.driver.pod.name` is not set when your application is +If your application is not running inside a pod, or if `spark.kubernetes.driver.pod.name` is not set when your application is actually running in a pod, keep in mind that the executor pods may not be properly deleted from the cluster when the application exits. The Spark scheduler attempts to delete these pods, but if the network request to the API server fails for any reason, these pods will remain in the cluster. The executor processes should exit when they cannot reach the From 41d5aaec840234b1fcfd6f87f5e9e7729a3f0fe2 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 26 Nov 2018 00:26:24 +0900 Subject: [PATCH 2131/2461] [SPARK-26148][PYTHON][TESTS] Increases default parallelism in PySpark tests to speed up ## What changes were proposed in this pull request? This PR proposes to increase parallelism in PySpark tests to speed up from 4 to 8. It decreases the elapsed time from https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/99163/consoleFull Tests passed in 1770 seconds to https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/99186/testReport/ Tests passed in 1027 seconds ## How was this patch tested? Jenkins tests Closes #23111 from HyukjinKwon/parallelism. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- dev/run-tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 7ec73347d16bf..27f7527052e29 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -460,7 +460,7 @@ def parse_opts(): prog="run-tests" ) parser.add_option( - "-p", "--parallelism", type="int", default=4, + "-p", "--parallelism", type="int", default=8, help="The number of suites to test in parallel (default %default)" ) From c5daccb1dafca528ccb4be65d63c943bf9a7b0f2 Mon Sep 17 00:00:00 2001 From: Katrin Leinweber <9948149+katrinleinweber@users.noreply.github.com> Date: Sun, 25 Nov 2018 17:43:55 -0600 Subject: [PATCH 2132/2461] [MINOR] Update all DOI links to preferred resolver ## What changes were proposed in this pull request? The DOI foundation recommends [this new resolver](https://www.doi.org/doi_handbook/3_Resolution.html#3.8). Accordingly, this PR re`sed`s all static DOI links ;-) ## How was this patch tested? It wasn't, since it seems as safe as a "[typo fix](https://spark.apache.org/contributing.html)". In case any of the files is included from other projects, and should be updated there, please let me know. Closes #23129 from katrinleinweber/resolve-DOIs-securely. Authored-by: Katrin Leinweber <9948149+katrinleinweber@users.noreply.github.com> Signed-off-by: Sean Owen --- R/pkg/R/stats.R | 4 ++-- .../scala/org/apache/spark/api/java/JavaPairRDD.scala | 6 +++--- .../scala/org/apache/spark/api/java/JavaRDDLike.scala | 2 +- .../scala/org/apache/spark/rdd/PairRDDFunctions.scala | 8 ++++---- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 4 ++-- docs/ml-classification-regression.md | 4 ++-- docs/ml-collaborative-filtering.md | 4 ++-- docs/ml-frequent-pattern-mining.md | 8 ++++---- docs/mllib-collaborative-filtering.md | 4 ++-- docs/mllib-frequent-pattern-mining.md | 6 +++--- docs/mllib-isotonic-regression.md | 4 ++-- .../scala/org/apache/spark/ml/clustering/KMeans.scala | 2 +- .../main/scala/org/apache/spark/ml/fpm/FPGrowth.scala | 4 ++-- .../scala/org/apache/spark/ml/fpm/PrefixSpan.scala | 2 +- .../scala/org/apache/spark/ml/recommendation/ALS.scala | 2 +- .../scala/org/apache/spark/mllib/fpm/FPGrowth.scala | 4 ++-- .../scala/org/apache/spark/mllib/fpm/PrefixSpan.scala | 2 +- .../spark/mllib/linalg/distributed/RowMatrix.scala | 2 +- .../org/apache/spark/mllib/recommendation/ALS.scala | 2 +- python/pyspark/ml/fpm.py | 6 +++--- python/pyspark/ml/recommendation.py | 2 +- python/pyspark/mllib/fpm.py | 2 +- python/pyspark/mllib/linalg/distributed.py | 2 +- python/pyspark/rdd.py | 2 +- python/pyspark/sql/dataframe.py | 4 ++-- .../spark/sql/catalyst/util/QuantileSummaries.scala | 2 +- .../org/apache/spark/sql/DataFrameStatFunctions.scala | 10 +++++----- .../spark/sql/execution/stat/FrequentItems.scala | 2 +- .../spark/sql/execution/stat/StatFunctions.scala | 2 +- 29 files changed, 54 insertions(+), 54 deletions(-) diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index 497f18c763048..7252351ebebb2 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -109,7 +109,7 @@ setMethod("corr", #' #' Finding frequent items for columns, possibly with false positives. #' Using the frequent element count algorithm described in -#' \url{http://dx.doi.org/10.1145/762471.762473}, proposed by Karp, Schenker, and Papadimitriou. +#' \url{https://doi.org/10.1145/762471.762473}, proposed by Karp, Schenker, and Papadimitriou. #' #' @param x A SparkDataFrame. #' @param cols A vector column names to search frequent items in. @@ -143,7 +143,7 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"), #' *exact* rank of x is close to (p * N). More precisely, #' floor((p - err) * N) <= rank(x) <= ceil((p + err) * N). #' This method implements a variation of the Greenwald-Khanna algorithm (with some speed -#' optimizations). The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 +#' optimizations). The algorithm was first present in [[https://doi.org/10.1145/375663.375670 #' Space-efficient Online Computation of Quantile Summaries]] by Greenwald and Khanna. #' Note that NA values will be ignored in numerical columns before calculation. For #' columns only containing NA values, an empty list is returned. diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 80a4f84087466..50ed8d9bd3f68 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -952,7 +952,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. @@ -969,7 +969,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. @@ -985,7 +985,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 91ae1002abd21..5ba821935ac69 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -685,7 +685,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index e68c6b1366c7f..4bf4f082d0382 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -394,7 +394,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero (`sp` is * greater than `p`) would trigger sparse representation of registers, which may reduce the @@ -436,7 +436,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. @@ -456,7 +456,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. @@ -473,7 +473,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 743e3441eea55..6a25ee20b2c68 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1258,7 +1258,7 @@ abstract class RDD[T: ClassTag]( * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero (`sp` is greater * than `p`) would trigger sparse representation of registers, which may reduce the memory @@ -1290,7 +1290,7 @@ abstract class RDD[T: ClassTag]( * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index b3d109039da4d..42912a2e2bc31 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -941,9 +941,9 @@ Essentially isotonic regression is a best fitting the original data points. We implement a -[pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111) +[pool adjacent violators algorithm](https://doi.org/10.1198/TECH.2010.10111) which uses an approach to -[parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). +[parallelizing isotonic regression](https://doi.org/10.1007/978-3-642-99789-1_10). The training input is a DataFrame which contains three columns label, features and weight. Additionally, IsotonicRegression algorithm has one optional parameter called $isotonic$ defaulting to true. diff --git a/docs/ml-collaborative-filtering.md b/docs/ml-collaborative-filtering.md index 8b0f287dc39ad..58646642bfbcc 100644 --- a/docs/ml-collaborative-filtering.md +++ b/docs/ml-collaborative-filtering.md @@ -41,7 +41,7 @@ for example, users giving ratings to movies. It is common in many real-world use cases to only have access to *implicit feedback* (e.g. views, clicks, purchases, likes, shares etc.). The approach used in `spark.ml` to deal with such data is taken -from [Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22). +from [Collaborative Filtering for Implicit Feedback Datasets](https://doi.org/10.1109/ICDM.2008.22). Essentially, instead of trying to model the matrix of ratings directly, this approach treats the data as numbers representing the *strength* in observations of user actions (such as the number of clicks, or the cumulative duration someone spent viewing a movie). Those numbers are then related to the level of @@ -55,7 +55,7 @@ We scale the regularization parameter `regParam` in solving each least squares p the number of ratings the user generated in updating user factors, or the number of ratings the product received in updating product factors. This approach is named "ALS-WR" and discussed in the paper -"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](http://dx.doi.org/10.1007/978-3-540-68880-8_32)". +"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](https://doi.org/10.1007/978-3-540-68880-8_32)". It makes `regParam` less dependent on the scale of the dataset, so we can apply the best parameter learned from a sampled subset to the full dataset and expect similar performance. diff --git a/docs/ml-frequent-pattern-mining.md b/docs/ml-frequent-pattern-mining.md index c2043d495c149..f613664271ec6 100644 --- a/docs/ml-frequent-pattern-mining.md +++ b/docs/ml-frequent-pattern-mining.md @@ -18,7 +18,7 @@ for more information. ## FP-Growth The FP-growth algorithm is described in the paper -[Han et al., Mining frequent patterns without candidate generation](http://dx.doi.org/10.1145/335191.335372), +[Han et al., Mining frequent patterns without candidate generation](https://doi.org/10.1145/335191.335372), where "FP" stands for frequent pattern. Given a dataset of transactions, the first step of FP-growth is to calculate item frequencies and identify frequent items. Different from [Apriori-like](http://en.wikipedia.org/wiki/Apriori_algorithm) algorithms designed for the same purpose, @@ -26,7 +26,7 @@ the second step of FP-growth uses a suffix tree (FP-tree) structure to encode tr explicitly, which are usually expensive to generate. After the second step, the frequent itemsets can be extracted from the FP-tree. In `spark.mllib`, we implemented a parallel version of FP-growth called PFP, -as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). +as described in [Li et al., PFP: Parallel FP-growth for query recommendation](https://doi.org/10.1145/1454008.1454027). PFP distributes the work of growing FP-trees based on the suffixes of transactions, and hence is more scalable than a single-machine implementation. We refer users to the papers for more details. @@ -90,7 +90,7 @@ Refer to the [R API docs](api/R/spark.fpGrowth.html) for more details. PrefixSpan is a sequential pattern mining algorithm described in [Pei et al., Mining Sequential Patterns by Pattern-Growth: The -PrefixSpan Approach](http://dx.doi.org/10.1109%2FTKDE.2004.77). We refer +PrefixSpan Approach](https://doi.org/10.1109%2FTKDE.2004.77). We refer the reader to the referenced paper for formalizing the sequential pattern mining problem. @@ -137,4 +137,4 @@ Refer to the [R API docs](api/R/spark.prefixSpan.html) for more details. {% include_example r/ml/prefixSpan.R %} - \ No newline at end of file + diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index b2300028e151b..aeebb26bb45f3 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -37,7 +37,7 @@ for example, users giving ratings to movies. It is common in many real-world use cases to only have access to *implicit feedback* (e.g. views, clicks, purchases, likes, shares etc.). The approach used in `spark.mllib` to deal with such data is taken -from [Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22). +from [Collaborative Filtering for Implicit Feedback Datasets](https://doi.org/10.1109/ICDM.2008.22). Essentially, instead of trying to model the matrix of ratings directly, this approach treats the data as numbers representing the *strength* in observations of user actions (such as the number of clicks, or the cumulative duration someone spent viewing a movie). Those numbers are then related to the level of @@ -51,7 +51,7 @@ Since v1.1, we scale the regularization parameter `lambda` in solving each least the number of ratings the user generated in updating user factors, or the number of ratings the product received in updating product factors. This approach is named "ALS-WR" and discussed in the paper -"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](http://dx.doi.org/10.1007/978-3-540-68880-8_32)". +"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](https://doi.org/10.1007/978-3-540-68880-8_32)". It makes `lambda` less dependent on the scale of the dataset, so we can apply the best parameter learned from a sampled subset to the full dataset and expect similar performance. diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index 0d3192c6b1d9c..8e4505756b275 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -15,7 +15,7 @@ a popular algorithm to mining frequent itemsets. ## FP-growth The FP-growth algorithm is described in the paper -[Han et al., Mining frequent patterns without candidate generation](http://dx.doi.org/10.1145/335191.335372), +[Han et al., Mining frequent patterns without candidate generation](https://doi.org/10.1145/335191.335372), where "FP" stands for frequent pattern. Given a dataset of transactions, the first step of FP-growth is to calculate item frequencies and identify frequent items. Different from [Apriori-like](http://en.wikipedia.org/wiki/Apriori_algorithm) algorithms designed for the same purpose, @@ -23,7 +23,7 @@ the second step of FP-growth uses a suffix tree (FP-tree) structure to encode tr explicitly, which are usually expensive to generate. After the second step, the frequent itemsets can be extracted from the FP-tree. In `spark.mllib`, we implemented a parallel version of FP-growth called PFP, -as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). +as described in [Li et al., PFP: Parallel FP-growth for query recommendation](https://doi.org/10.1145/1454008.1454027). PFP distributes the work of growing FP-trees based on the suffixes of transactions, and hence more scalable than a single-machine implementation. We refer users to the papers for more details. @@ -122,7 +122,7 @@ Refer to the [`AssociationRules` Java docs](api/java/org/apache/spark/mllib/fpm/ PrefixSpan is a sequential pattern mining algorithm described in [Pei et al., Mining Sequential Patterns by Pattern-Growth: The -PrefixSpan Approach](http://dx.doi.org/10.1109%2FTKDE.2004.77). We refer +PrefixSpan Approach](https://doi.org/10.1109%2FTKDE.2004.77). We refer the reader to the referenced paper for formalizing the sequential pattern mining problem. diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index 99cab98c690c6..9964fce3273be 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -24,9 +24,9 @@ Essentially isotonic regression is a best fitting the original data points. `spark.mllib` supports a -[pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111) +[pool adjacent violators algorithm](https://doi.org/10.1198/TECH.2010.10111) which uses an approach to -[parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). +[parallelizing isotonic regression](https://doi.org/10.1007/978-3-642-99789-1_10). The training input is an RDD of tuples of three double values that represent label, feature and weight in this order. Additionally, IsotonicRegression algorithm has one optional parameter called $isotonic$ defaulting to true. diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 919496aa1a840..2eed84d51782a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -263,7 +263,7 @@ object KMeansModel extends MLReadable[KMeansModel] { /** * K-means clustering with support for k-means|| initialization proposed by Bahmani et al. * - * @see Bahmani et al., Scalable k-means++. + * @see Bahmani et al., Scalable k-means++. */ @Since("1.5.0") class KMeans @Since("1.5.0") ( diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 840a89b76d26b..7322815c12ab8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -118,10 +118,10 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { /** * :: Experimental :: * A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in - * Li et al., PFP: Parallel FP-Growth for Query + * Li et al., PFP: Parallel FP-Growth for Query * Recommendation. PFP distributes computation in such a way that each worker executes an * independent group of mining tasks. The FP-Growth algorithm is described in - * Han et al., Mining frequent patterns without + * Han et al., Mining frequent patterns without * candidate generation. Note null values in the itemsCol column are ignored during fit(). * * @see diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala index bd1c1a8885201..2a3413553a6af 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType} * A parallel PrefixSpan algorithm to mine frequent sequential patterns. * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns * Efficiently by Prefix-Projected Pattern Growth - * (see here). + * (see here). * This class is not yet an Estimator/Transformer, use `findFrequentSequentialPatterns` method to * run the PrefixSpan algorithm. * diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index ffe592789b3cc..50ef4330ddc80 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -557,7 +557,7 @@ object ALSModel extends MLReadable[ALSModel] { * * For implicit preference data, the algorithm used is based on * "Collaborative Filtering for Implicit Feedback Datasets", available at - * http://dx.doi.org/10.1109/ICDM.2008.22, adapted for the blocked approach used here. + * https://doi.org/10.1109/ICDM.2008.22, adapted for the blocked approach used here. * * Essentially instead of finding the low-rank approximations to the rating matrix `R`, * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 3a1bc35186dc3..519c1ea47c1db 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -152,10 +152,10 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { /** * A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in - * Li et al., PFP: Parallel FP-Growth for Query + * Li et al., PFP: Parallel FP-Growth for Query * Recommendation. PFP distributes computation in such a way that each worker executes an * independent group of mining tasks. The FP-Growth algorithm is described in - * Han et al., Mining frequent patterns without + * Han et al., Mining frequent patterns without * candidate generation. * * @param minSupport the minimal support level of the frequent pattern, any pattern that appears diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 64d6a0bc47b97..b2c09b408b40b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -45,7 +45,7 @@ import org.apache.spark.storage.StorageLevel * A parallel PrefixSpan algorithm to mine frequent sequential patterns. * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns * Efficiently by Prefix-Projected Pattern Growth - * (see here). + * (see here). * * @param minSupport the minimal support level of the sequential pattern, any pattern that appears * more than (minSupport * size-of-the-dataset) times will be output diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 82ab716ed96a8..c12b751bfb8e4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -540,7 +540,7 @@ class RowMatrix @Since("1.0.0") ( * decomposition (factorization) for the [[RowMatrix]] of a tall and skinny shape. * Reference: * Paul G. Constantine, David F. Gleich. "Tall and skinny QR factorizations in MapReduce - * architectures" (see here) + * architectures" (see here) * * @param computeQ whether to computeQ * @return QRDecomposition(Q, R), Q = null if computeQ = false. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 14288221b6945..12870f819b147 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -54,7 +54,7 @@ case class Rating @Since("0.8.0") ( * * For implicit preference data, the algorithm used is based on * "Collaborative Filtering for Implicit Feedback Datasets", available at - * here, adapted for the blocked approach + * here, adapted for the blocked approach * used here. * * Essentially instead of finding the low-rank approximations to the rating matrix `R`, diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index 886ad8409ca66..734763ebd3fa6 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -167,8 +167,8 @@ class FPGrowth(JavaEstimator, HasItemsCol, HasPredictionCol, independent group of mining tasks. The FP-Growth algorithm is described in Han et al., Mining frequent patterns without candidate generation [HAN2000]_ - .. [LI2008] http://dx.doi.org/10.1145/1454008.1454027 - .. [HAN2000] http://dx.doi.org/10.1145/335191.335372 + .. [LI2008] https://doi.org/10.1145/1454008.1454027 + .. [HAN2000] https://doi.org/10.1145/335191.335372 .. note:: null values in the feature column are ignored during fit(). .. note:: Internally `transform` `collects` and `broadcasts` association rules. @@ -254,7 +254,7 @@ class PrefixSpan(JavaParams): A parallel PrefixSpan algorithm to mine frequent sequential patterns. The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns Efficiently by Prefix-Projected Pattern Growth - (see here). + (see here). This class is not yet an Estimator/Transformer, use :py:func:`findFrequentSequentialPatterns` method to run the PrefixSpan algorithm. diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index a8eae9bd268d3..520d7912c1a10 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -57,7 +57,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha For implicit preference data, the algorithm used is based on `"Collaborative Filtering for Implicit Feedback Datasets", - `_, adapted for the blocked + `_, adapted for the blocked approach used here. Essentially instead of finding the low-rank approximations to the diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index de18dad1f675d..6accb9b4926e8 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -132,7 +132,7 @@ class PrefixSpan(object): A parallel PrefixSpan algorithm to mine frequent sequential patterns. The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns Efficiently by Prefix-Projected Pattern Growth - ([[http://doi.org/10.1109/ICDE.2001.914830]]). + ([[https://doi.org/10.1109/ICDE.2001.914830]]). .. versionadded:: 1.6.0 """ diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index 7e8b15056cabe..b7f09782be9dd 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -270,7 +270,7 @@ def tallSkinnyQR(self, computeQ=False): Reference: Paul G. Constantine, David F. Gleich. "Tall and skinny QR factorizations in MapReduce architectures" - ([[http://dx.doi.org/10.1145/1996092.1996103]]) + ([[https://doi.org/10.1145/1996092.1996103]]) :param: computeQ: whether to computeQ :return: QRDecomposition(Q: RowMatrix, R: Matrix), where diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index ccf39e1ffbe96..8bd6897df925f 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2354,7 +2354,7 @@ def countApproxDistinct(self, relativeSD=0.05): The algorithm used is based on streamlib's implementation of `"HyperLogLog in Practice: Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available here - `_. + `_. :param relativeSD: Relative accuracy. Smaller values create counters that require more space. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index c4f4d81999544..4abbeacfd56b4 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1806,7 +1806,7 @@ def approxQuantile(self, col, probabilities, relativeError): This method implements a variation of the Greenwald-Khanna algorithm (with some speed optimizations). The algorithm was first - present in [[http://dx.doi.org/10.1145/375663.375670 + present in [[https://doi.org/10.1145/375663.375670 Space-efficient Online Computation of Quantile Summaries]] by Greenwald and Khanna. @@ -1928,7 +1928,7 @@ def freqItems(self, cols, support=None): """ Finding frequent items for columns, possibly with false positives. Using the frequent element count algorithm described in - "http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou". + "https://doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou". :func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases. .. note:: This function is meant for exploratory data analysis, as we make no diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index 3190e511e2cb5..2a03f85ab594b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.util.QuantileSummaries.Stats * Helper class to compute approximate quantile summary. * This implementation is based on the algorithm proposed in the paper: * "Space-efficient Online Computation of Quantile Summaries" by Greenwald, Michael - * and Khanna, Sanjeev. (http://dx.doi.org/10.1145/375663.375670) + * and Khanna, Sanjeev. (https://doi.org/10.1145/375663.375670) * * In order to optimize for speed, it maintains an internal buffer of the last seen samples, * and only inserts them after crossing a certain size threshold. This guarantees a near-constant diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index b2f6a6ba83108..0b22b898557f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -51,7 +51,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * This method implements a variation of the Greenwald-Khanna algorithm (with some speed * optimizations). - * The algorithm was first present in + * The algorithm was first present in * Space-efficient Online Computation of Quantile Summaries by Greenwald and Khanna. * * @param col the name of the numerical column @@ -218,7 +218,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { /** * Finding frequent items for columns, possibly with false positives. Using the * frequent element count algorithm described in - * here, proposed by Karp, + * here, proposed by Karp, * Schenker, and Papadimitriou. * The `support` should be greater than 1e-4. * @@ -265,7 +265,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { /** * Finding frequent items for columns, possibly with false positives. Using the * frequent element count algorithm described in - * here, proposed by Karp, + * here, proposed by Karp, * Schenker, and Papadimitriou. * Uses a `default` support of 1%. * @@ -284,7 +284,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { /** * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the * frequent element count algorithm described in - * here, proposed by Karp, Schenker, + * here, proposed by Karp, Schenker, * and Papadimitriou. * * This function is meant for exploratory data analysis, as we make no guarantee about the @@ -328,7 +328,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { /** * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the * frequent element count algorithm described in - * here, proposed by Karp, Schenker, + * here, proposed by Karp, Schenker, * and Papadimitriou. * Uses a `default` support of 1%. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 86f6307254332..420faa6f24734 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -69,7 +69,7 @@ object FrequentItems extends Logging { /** * Finding frequent items for columns, possibly with false positives. Using the * frequent element count algorithm described in - * here, proposed by Karp, Schenker, + * here, proposed by Karp, Schenker, * and Papadimitriou. * The `support` should be greater than 1e-4. * For Internal use only. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index bea652cc33076..ac25a8fd90bc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -45,7 +45,7 @@ object StatFunctions extends Logging { * * This method implements a variation of the Greenwald-Khanna algorithm (with some speed * optimizations). - * The algorithm was first present in + * The algorithm was first present in * Space-efficient Online Computation of Quantile Summaries by Greenwald and Khanna. * * @param df the dataframe From 94145786a5b91a7f0bca44f27599a61c72f3a18f Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 25 Nov 2018 15:53:07 -0800 Subject: [PATCH 2133/2461] [SPARK-25908][SQL][FOLLOW-UP] Add back unionAll ## What changes were proposed in this pull request? This PR is to add back `unionAll`, which is widely used. The name is also consistent with our ANSI SQL. We also have the corresponding `intersectAll` and `exceptAll`, which were introduced in Spark 2.4. ## How was this patch tested? Added a test case in DataFrameSuite Closes #23131 from gatorsmile/addBackUnionAll. Authored-by: gatorsmile Signed-off-by: gatorsmile --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 14 ++++++++++++++ R/pkg/R/generics.R | 3 +++ R/pkg/tests/fulltests/test_sparkSQL.R | 1 + docs/sparkr.md | 2 +- docs/sql-migration-guide-upgrade.md | 2 ++ python/pyspark/sql/dataframe.py | 11 +++++++++++ .../main/scala/org/apache/spark/sql/Dataset.scala | 14 ++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 6 ++++++ 9 files changed, 53 insertions(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index de56061b4c1c7..cdeafdd90ce4a 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -169,6 +169,7 @@ exportMethods("arrange", "toJSON", "transform", "union", + "unionAll", "unionByName", "unique", "unpersist", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 52e76570139e2..ad9cd845f696c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2732,6 +2732,20 @@ setMethod("union", dataFrame(unioned) }) +#' Return a new SparkDataFrame containing the union of rows +#' +#' This is an alias for `union`. +#' +#' @rdname union +#' @name unionAll +#' @aliases unionAll,SparkDataFrame,SparkDataFrame-method +#' @note unionAll since 1.4.0 +setMethod("unionAll", + signature(x = "SparkDataFrame", y = "SparkDataFrame"), + function(x, y) { + union(x, y) + }) + #' Return a new SparkDataFrame containing the union of rows, matched by column names #' #' Return a new SparkDataFrame containing the union of rows in this SparkDataFrame diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index cbed276274ac1..b2ca6e62175e7 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -631,6 +631,9 @@ setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) #' @rdname union setGeneric("union", function(x, y) { standardGeneric("union") }) +#' @rdname union +setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) + #' @rdname unionByName setGeneric("unionByName", function(x, y) { standardGeneric("unionByName") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index f355a515935c8..77a29c9ecad86 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -2458,6 +2458,7 @@ test_that("union(), unionByName(), rbind(), except(), and intersect() on a DataF expect_equal(count(unioned), 6) expect_equal(first(unioned)$name, "Michael") expect_equal(count(arrange(suppressWarnings(union(df, df2)), df$age)), 6) + expect_equal(count(arrange(suppressWarnings(unionAll(df, df2)), df$age)), 6) df1 <- select(df2, "age", "name") unioned1 <- arrange(unionByName(df1, df), df1$age) diff --git a/docs/sparkr.md b/docs/sparkr.md index acd0e77c4d71a..5972435a0e409 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -718,4 +718,4 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma ## Upgrading to SparkR 3.0.0 - The deprecated methods `sparkR.init`, `sparkRSQL.init`, `sparkRHive.init` have been removed. Use `sparkR.session` instead. - - The deprecated methods `parquetFile`, `saveAsParquetFile`, `jsonFile`, `registerTempTable`, `createExternalTable`, `dropTempTable`, `unionAll` have been removed. Use `read.parquet`, `write.parquet`, `read.json`, `createOrReplaceTempView`, `createTable`, `dropTempView`, `union` instead. + - The deprecated methods `parquetFile`, `saveAsParquetFile`, `jsonFile`, `registerTempTable`, `createExternalTable`, and `dropTempTable` have been removed. Use `read.parquet`, `write.parquet`, `read.json`, `createOrReplaceTempView`, `createTable`, `dropTempView`, `union` instead. diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 397ca59d96497..68cb8f5a0d18c 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -9,6 +9,8 @@ displayTitle: Spark SQL Upgrading Guide ## Upgrading From Spark SQL 2.4 to 3.0 + - Since Spark 3.0, the Dataset and DataFrame API `unionAll` is not deprecated any more. It is an alias for `union`. + - In PySpark, when creating a `SparkSession` with `SparkSession.builder.getOrCreate()`, if there is an existing `SparkContext`, the builder was trying to update the `SparkConf` of the existing `SparkContext` with configurations specified to the builder, but the `SparkContext` is shared by all `SparkSession`s, so we should not update them. Since 3.0, the builder comes to not update the configurations. This is the same behavior as Java/Scala API in 2.3 and above. If you want to update them, you need to update them prior to creating a `SparkSession`. - In Spark version 2.4 and earlier, the parser of JSON data source treats empty strings as null for some data types such as `IntegerType`. For `FloatType` and `DoubleType`, it fails on empty strings and throws exceptions. Since Spark 3.0, we disallow empty strings and will throw exceptions for data types except for `StringType` and `BinaryType`. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 4abbeacfd56b4..ca15b36699166 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1448,6 +1448,17 @@ def union(self, other): """ return DataFrame(self._jdf.union(other._jdf), self.sql_ctx) + @since(1.3) + def unionAll(self, other): + """ Return a new :class:`DataFrame` containing union of rows in this and another frame. + + This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union + (that does deduplication of elements), use this function followed by :func:`distinct`. + + Also as standard in SQL, this function resolves columns by position (not by name). + """ + return self.union(other) + @since(2.3) def unionByName(self, other): """ Returns a new :class:`DataFrame` containing union of rows in this and another frame. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index e757921b485df..f361bde281732 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1852,6 +1852,20 @@ class Dataset[T] private[sql]( CombineUnions(Union(logicalPlan, other.logicalPlan)) } + /** + * Returns a new Dataset containing union of rows in this Dataset and another Dataset. + * This is an alias for `union`. + * + * This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does + * deduplication of elements), use this function followed by a [[distinct]]. + * + * Also as standard in SQL, this function resolves columns by position (not by name). + * + * @group typedrel + * @since 2.0.0 + */ + def unionAll(other: Dataset[T]): Dataset[T] = union(other) + /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 0ee2627814ba0..7a0767a883f15 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -97,6 +97,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { unionDF.agg(avg('key), max('key), min('key), sum('key)), Row(50.5, 100, 1, 25250) :: Nil ) + + // unionAll is an alias of union + val unionAllDF = testData.unionAll(testData).unionAll(testData) + .unionAll(testData).unionAll(testData) + + checkAnswer(unionDF, unionAllDF) } test("union should union DataFrames with UDTs (SPARK-13410)") { From 6339c8c2c6b80a85e4ad6a7fa7595cf567a1113e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 26 Nov 2018 11:13:28 +0800 Subject: [PATCH 2134/2461] [SPARK-24762][SQL] Enable Option of Product encoders ## What changes were proposed in this pull request? SparkSQL doesn't support to encode `Option[Product]` as a top-level row now, because in SparkSQL entire top-level row can't be null. However for use cases like Aggregator, it is reasonable to use `Option[Product]` as buffer and output column types. Due to above limitation, we don't do it for now. This patch proposes to encode `Option[Product]` at top-level as single struct column. So we can work around the issue that entire top-level row can't be null. To summarize encoding of `Product` and `Option[Product]`. For `Product`, 1. at root level, the schema is all fields are flatten it into multiple columns. The `Product ` can't be null, otherwise it throws an exception. ```scala val df = Seq((1 -> "a"), (2 -> "b")).toDF() df.printSchema() root |-- _1: integer (nullable = false) |-- _2: string (nullable = true) ``` 2. At non-root level, `Product` is a struct type column. ```scala val df = Seq((1, (1 -> "a")), (2, (2 -> "b")), (3, null)).toDF() df.printSchema() root |-- _1: integer (nullable = false) |-- _2: struct (nullable = true) | |-- _1: integer (nullable = false) | |-- _2: string (nullable = true) ``` For `Option[Product]`, 1. it was not supported at root level. After this change, it is a struct type column. ```scala val df = Seq(Some(1 -> "a"), Some(2 -> "b"), None).toDF() df.printSchema root |-- value: struct (nullable = true) | |-- _1: integer (nullable = false) | |-- _2: string (nullable = true) ``` 2. At non-root level, it is also a struct type column. ```scala val df = Seq((1, Some(1 -> "a")), (2, Some(2 -> "b")), (3, None)).toDF() df.printSchema root |-- _1: integer (nullable = false) |-- _2: struct (nullable = true) | |-- _1: integer (nullable = false) | |-- _2: string (nullable = true) ``` 3. For use case like Aggregator, it was not supported too. After this change, we support to use `Option[Product]` as buffer/output column type. ```scala val df = Seq( OptionBooleanIntData("bob", Some((true, 1))), OptionBooleanIntData("bob", Some((false, 2))), OptionBooleanIntData("bob", None)).toDF() val group = df .groupBy("name") .agg(OptionBooleanIntAggregator("isGood").toColumn.alias("isGood")) group.printSchema root |-- name: string (nullable = true) |-- isGood: struct (nullable = true) | |-- _1: boolean (nullable = false) | |-- _2: integer (nullable = false) ``` The buffer and output type of `OptionBooleanIntAggregator` is both `Option[(Boolean, Int)`. ## How was this patch tested? Added test. Closes #21732 from viirya/SPARK-24762. Authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- .../catalyst/encoders/ExpressionEncoder.scala | 32 +++++--- .../scala/org/apache/spark/sql/Dataset.scala | 10 +-- .../spark/sql/KeyValueGroupedDataset.scala | 2 +- .../aggregate/TypedAggregateExpression.scala | 18 ++--- .../spark/sql/DatasetAggregatorSuite.scala | 64 ++++++++++++++- .../org/apache/spark/sql/DatasetSuite.scala | 77 ++++++++++++++++--- 6 files changed, 163 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 592520c59a761..d019924711e3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -49,15 +49,6 @@ object ExpressionEncoder { val mirror = ScalaReflection.mirror val tpe = typeTag[T].in(mirror).tpe - if (ScalaReflection.optionOfProductType(tpe)) { - throw new UnsupportedOperationException( - "Cannot create encoder for Option of Product type, because Product type is represented " + - "as a row, and the entire row can not be null in Spark SQL like normal databases. " + - "You can wrap your type with Tuple1 if you do want top level null Product objects, " + - "e.g. instead of creating `Dataset[Option[MyClass]]`, you can do something like " + - "`val ds: Dataset[Tuple1[MyClass]] = Seq(Tuple1(MyClass(...)), Tuple1(null)).toDS`") - } - val cls = mirror.runtimeClass(tpe) val serializer = ScalaReflection.serializerForType(tpe) val deserializer = ScalaReflection.deserializerForType(tpe) @@ -198,7 +189,7 @@ case class ExpressionEncoder[T]( val serializer: Seq[NamedExpression] = { val clsName = Utils.getSimpleName(clsTag.runtimeClass) - if (isSerializedAsStruct) { + if (isSerializedAsStructForTopLevel) { val nullSafeSerializer = objSerializer.transformUp { case r: BoundReference => // For input object of Product type, we can't encode it to row if it's null, as Spark SQL @@ -213,6 +204,9 @@ case class ExpressionEncoder[T]( } else { // For other input objects like primitive, array, map, etc., we construct a struct to wrap // the serializer which is a column of an row. + // + // Note: Because Spark SQL doesn't allow top-level row to be null, to encode + // top-level Option[Product] type, we make it as a top-level struct column. CreateNamedStruct(Literal("value") :: objSerializer :: Nil) } }.flatten @@ -226,7 +220,7 @@ case class ExpressionEncoder[T]( * `GetColumnByOrdinal` with corresponding ordinal. */ val deserializer: Expression = { - if (isSerializedAsStruct) { + if (isSerializedAsStructForTopLevel) { // We serialized this kind of objects to root-level row. The input of general deserializer // is a `GetColumnByOrdinal(0)` expression to extract first column of a row. We need to // transform attributes accessors. @@ -253,10 +247,24 @@ case class ExpressionEncoder[T]( }) /** - * Returns true if the type `T` is serialized as a struct. + * Returns true if the type `T` is serialized as a struct by `objSerializer`. */ def isSerializedAsStruct: Boolean = objSerializer.dataType.isInstanceOf[StructType] + /** + * Returns true if the type `T` is an `Option` type. + */ + def isOptionType: Boolean = classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass) + + /** + * If the type `T` is serialized as a struct, when it is encoded to a Spark SQL row, fields in + * the struct are naturally mapped to top-level columns in a row. In other words, the serialized + * struct is flattened to row. But in case of the `T` is also an `Option` type, it can't be + * flattened to top-level row, because in Spark SQL top-level row can't be null. This method + * returns true if `T` is serialized as struct and is not `Option` type. + */ + def isSerializedAsStructForTopLevel: Boolean = isSerializedAsStruct && !isOptionType + // serializer expressions are used to encode an object to a row, while the object is usually an // intermediate value produced inside an operator, not from the output of the child operator. This // is quite different from normal expressions, and `AttributeReference` doesn't work here diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f361bde281732..b10d66dfb1aef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1093,7 +1093,7 @@ class Dataset[T] private[sql]( // Note that we do this before joining them, to enable the join operator to return null for one // side, in cases like outer-join. val left = { - val combined = if (!this.exprEnc.isSerializedAsStruct) { + val combined = if (!this.exprEnc.isSerializedAsStructForTopLevel) { assert(joined.left.output.length == 1) Alias(joined.left.output.head, "_1")() } else { @@ -1103,7 +1103,7 @@ class Dataset[T] private[sql]( } val right = { - val combined = if (!other.exprEnc.isSerializedAsStruct) { + val combined = if (!other.exprEnc.isSerializedAsStructForTopLevel) { assert(joined.right.output.length == 1) Alias(joined.right.output.head, "_2")() } else { @@ -1116,14 +1116,14 @@ class Dataset[T] private[sql]( // combine the outputs of each join side. val conditionExpr = joined.condition.get transformUp { case a: Attribute if joined.left.outputSet.contains(a) => - if (!this.exprEnc.isSerializedAsStruct) { + if (!this.exprEnc.isSerializedAsStructForTopLevel) { left.output.head } else { val index = joined.left.output.indexWhere(_.exprId == a.exprId) GetStructField(left.output.head, index) } case a: Attribute if joined.right.outputSet.contains(a) => - if (!other.exprEnc.isSerializedAsStruct) { + if (!other.exprEnc.isSerializedAsStructForTopLevel) { right.output.head } else { val index = joined.right.output.indexWhere(_.exprId == a.exprId) @@ -1396,7 +1396,7 @@ class Dataset[T] private[sql]( implicit val encoder = c1.encoder val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan) - if (!encoder.isSerializedAsStruct) { + if (!encoder.isSerializedAsStructForTopLevel) { new Dataset[U1](sparkSession, project, encoder) } else { // Flattens inner fields of U1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 2d849c65997a7..a3cbea9021f22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -458,7 +458,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( val encoders = columns.map(_.encoder) val namedColumns = columns.map(_.withInputType(vExprEnc, dataAttributes).named) - val keyColumn = if (!kExprEnc.isSerializedAsStruct) { + val keyColumn = if (!kExprEnc.isSerializedAsStructForTopLevel) { assert(groupingAttributes.length == 1) if (SQLConf.get.nameNonStructGroupingKeyAsValue) { groupingAttributes.head diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 39200ec00e152..b75752945a492 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -40,9 +40,9 @@ object TypedAggregateExpression { val outputEncoder = encoderFor[OUT] val outputType = outputEncoder.objSerializer.dataType - // Checks if the buffer object is simple, i.e. the buffer encoder is flat and the serializer - // expression is an alias of `BoundReference`, which means the buffer object doesn't need - // serialization. + // Checks if the buffer object is simple, i.e. the `BUF` type is not serialized as struct + // and the serializer expression is an alias of `BoundReference`, which means the buffer + // object doesn't need serialization. val isSimpleBuffer = { bufferSerializer.head match { case Alias(_: BoundReference, _) if !bufferEncoder.isSerializedAsStruct => true @@ -76,7 +76,7 @@ object TypedAggregateExpression { None, bufferSerializer, bufferEncoder.resolveAndBind().deserializer, - outputEncoder.serializer, + outputEncoder.objSerializer, outputType, outputEncoder.objSerializer.nullable) } @@ -213,7 +213,7 @@ case class ComplexTypedAggregateExpression( inputSchema: Option[StructType], bufferSerializer: Seq[NamedExpression], bufferDeserializer: Expression, - outputSerializer: Seq[Expression], + outputSerializer: Expression, dataType: DataType, nullable: Boolean, mutableAggBufferOffset: Int = 0, @@ -245,13 +245,7 @@ case class ComplexTypedAggregateExpression( aggregator.merge(buffer, input) } - private lazy val resultObjToRow = dataType match { - case _: StructType => - UnsafeProjection.create(CreateStruct(outputSerializer)) - case _ => - assert(outputSerializer.length == 1) - UnsafeProjection.create(outputSerializer.head) - } + private lazy val resultObjToRow = UnsafeProjection.create(outputSerializer) override def eval(buffer: Any): Any = { val resultObj = aggregator.finish(buffer) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 538ea3c66c40e..97c3f358c0e76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, StructType} object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] { @@ -149,6 +149,7 @@ object VeryComplexResultAgg extends Aggregator[Row, String, ComplexAggData] { case class OptionBooleanData(name: String, isGood: Option[Boolean]) +case class OptionBooleanIntData(name: String, isGood: Option[(Boolean, Int)]) case class OptionBooleanAggregator(colName: String) extends Aggregator[Row, Option[Boolean], Option[Boolean]] { @@ -183,6 +184,43 @@ case class OptionBooleanAggregator(colName: String) def OptionalBoolEncoder: Encoder[Option[Boolean]] = ExpressionEncoder() } +case class OptionBooleanIntAggregator(colName: String) + extends Aggregator[Row, Option[(Boolean, Int)], Option[(Boolean, Int)]] { + + override def zero: Option[(Boolean, Int)] = None + + override def reduce(buffer: Option[(Boolean, Int)], row: Row): Option[(Boolean, Int)] = { + val index = row.fieldIndex(colName) + val value = if (row.isNullAt(index)) { + Option.empty[(Boolean, Int)] + } else { + val nestedRow = row.getStruct(index) + Some((nestedRow.getBoolean(0), nestedRow.getInt(1))) + } + merge(buffer, value) + } + + override def merge( + b1: Option[(Boolean, Int)], + b2: Option[(Boolean, Int)]): Option[(Boolean, Int)] = { + if ((b1.isDefined && b1.get._1) || (b2.isDefined && b2.get._1)) { + val newInt = b1.map(_._2).getOrElse(0) + b2.map(_._2).getOrElse(0) + Some((true, newInt)) + } else if (b1.isDefined) { + b1 + } else { + b2 + } + } + + override def finish(reduction: Option[(Boolean, Int)]): Option[(Boolean, Int)] = reduction + + override def bufferEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder + override def outputEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder + + def OptionalBoolIntEncoder: Encoder[Option[(Boolean, Int)]] = ExpressionEncoder() +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -393,4 +431,28 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { assert(grouped.schema == df.schema) checkDataset(grouped.as[OptionBooleanData], OptionBooleanData("bob", Some(true))) } + + test("SPARK-24762: Aggregator should be able to use Option of Product encoder") { + val df = Seq( + OptionBooleanIntData("bob", Some((true, 1))), + OptionBooleanIntData("bob", Some((false, 2))), + OptionBooleanIntData("bob", None)).toDF() + + val group = df + .groupBy("name") + .agg(OptionBooleanIntAggregator("isGood").toColumn.alias("isGood")) + + val expectedSchema = new StructType() + .add("name", StringType, nullable = true) + .add("isGood", + new StructType() + .add("_1", BooleanType, nullable = false) + .add("_2", IntegerType, nullable = false), + nullable = true) + + assert(df.schema == expectedSchema) + assert(group.schema == expectedSchema) + checkAnswer(group, Row("bob", Row(true, 3)) :: Nil) + checkDataset(group.as[OptionBooleanIntData], OptionBooleanIntData("bob", Some((true, 3)))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index baece2ddac7eb..0f900833d2cfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1312,15 +1312,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkDataset(dsString, arrayString) } - test("SPARK-18251: the type of Dataset can't be Option of Product type") { - checkDataset(Seq(Some(1), None).toDS(), Some(1), None) - - val e = intercept[UnsupportedOperationException] { - Seq(Some(1 -> "a"), None).toDS() - } - assert(e.getMessage.contains("Cannot create encoder for Option of Product type")) - } - test ("SPARK-17460: the sizeInBytes in Statistics shouldn't overflow to a negative number") { // Since the sizeInBytes in Statistics could exceed the limit of an Int, we should use BigInt // instead of Int for avoiding possible overflow. @@ -1558,6 +1549,74 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq(Row("Amsterdam"))) } + test("SPARK-24762: Enable top-level Option of Product encoders") { + val data = Seq(Some((1, "a")), Some((2, "b")), None) + val ds = data.toDS() + + checkDataset( + ds, + data: _*) + + val schema = new StructType().add( + "value", + new StructType() + .add("_1", IntegerType, nullable = false) + .add("_2", StringType, nullable = true), + nullable = true) + + assert(ds.schema == schema) + + val nestedOptData = Seq(Some((Some((1, "a")), 2.0)), Some((Some((2, "b")), 3.0))) + val nestedDs = nestedOptData.toDS() + + checkDataset( + nestedDs, + nestedOptData: _*) + + val nestedSchema = StructType(Seq( + StructField("value", StructType(Seq( + StructField("_1", StructType(Seq( + StructField("_1", IntegerType, nullable = false), + StructField("_2", StringType, nullable = true)))), + StructField("_2", DoubleType, nullable = false) + )), nullable = true) + )) + assert(nestedDs.schema == nestedSchema) + } + + test("SPARK-24762: Resolving Option[Product] field") { + val ds = Seq((1, ("a", 1.0)), (2, ("b", 2.0)), (3, null)).toDS() + .as[(Int, Option[(String, Double)])] + checkDataset(ds, + (1, Some(("a", 1.0))), (2, Some(("b", 2.0))), (3, None)) + } + + test("SPARK-24762: select Option[Product] field") { + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + val ds1 = ds.select(expr("struct(_2, _2 + 1)").as[Option[(Int, Int)]]) + checkDataset(ds1, + Some((1, 2)), Some((2, 3)), Some((3, 4))) + + val ds2 = ds.select(expr("if(_2 > 2, struct(_2, _2 + 1), null)").as[Option[(Int, Int)]]) + checkDataset(ds2, + None, None, Some((3, 4))) + } + + test("SPARK-24762: joinWith on Option[Product]") { + val ds1 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("a") + val ds2 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("b") + val joined = ds1.joinWith(ds2, $"a.value._1" === $"b.value._2", "inner") + checkDataset(joined, (Some((2, 3)), Some((1, 2)))) + } + + test("SPARK-24762: typed agg on Option[Product] type") { + val ds = Seq(Some((1, 2)), Some((2, 3)), Some((1, 3))).toDS() + assert(ds.groupByKey(_.get._1).count().collect() === Seq((1, 2), (2, 1))) + + assert(ds.groupByKey(x => x).count().collect() === + Seq((Some((1, 2)), 1), (Some((2, 3)), 1), (Some((1, 3)), 1))) + } + test("SPARK-25942: typed aggregation on primitive type") { val ds = Seq(1, 2, 3).toDS() From 6ab8485da21035778920da0d9332709f9acaff45 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 26 Nov 2018 15:47:04 +0800 Subject: [PATCH 2135/2461] [SPARK-26169] Create DataFrameSetOperationsSuite ## What changes were proposed in this pull request? Create a new suite DataFrameSetOperationsSuite for the test cases of DataFrame/Dataset's set operations. Also, add test cases of NULL handling for Array Except and Array Intersect. ## How was this patch tested? N/A Closes #23137 from gatorsmile/setOpsTest. Authored-by: gatorsmile Signed-off-by: Wenchen Fan --- .../CollectionExpressionsSuite.scala | 26 + .../sql/DataFrameSetOperationsSuite.scala | 509 ++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 478 ---------------- 3 files changed, 535 insertions(+), 478 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 1415b7da6fca1..d2edb2f24688d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1658,6 +1658,19 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper assert(ArrayExcept(a24, a22).dataType.asInstanceOf[ArrayType].containsNull === true) } + test("Array Except - null handling") { + val empty = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false)) + val oneNull = Literal.create(Seq(null), ArrayType(IntegerType)) + val twoNulls = Literal.create(Seq(null, null), ArrayType(IntegerType)) + + checkEvaluation(ArrayExcept(oneNull, oneNull), Seq.empty) + checkEvaluation(ArrayExcept(twoNulls, twoNulls), Seq.empty) + checkEvaluation(ArrayExcept(twoNulls, oneNull), Seq.empty) + checkEvaluation(ArrayExcept(empty, oneNull), Seq.empty) + checkEvaluation(ArrayExcept(oneNull, empty), Seq(null)) + checkEvaluation(ArrayExcept(twoNulls, empty), Seq(null)) + } + test("Array Intersect") { val a00 = Literal.create(Seq(1, 2, 4), ArrayType(IntegerType, false)) val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, false)) @@ -1769,4 +1782,17 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper assert(ArrayIntersect(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) assert(ArrayIntersect(a23, a24).dataType.asInstanceOf[ArrayType].containsNull === true) } + + test("Array Intersect - null handling") { + val empty = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false)) + val oneNull = Literal.create(Seq(null), ArrayType(IntegerType)) + val twoNulls = Literal.create(Seq(null, null), ArrayType(IntegerType)) + + checkEvaluation(ArrayIntersect(oneNull, oneNull), Seq(null)) + checkEvaluation(ArrayIntersect(twoNulls, twoNulls), Seq(null)) + checkEvaluation(ArrayIntersect(twoNulls, oneNull), Seq(null)) + checkEvaluation(ArrayIntersect(oneNull, twoNulls), Seq(null)) + checkEvaluation(ArrayIntersect(empty, oneNull), Seq.empty) + checkEvaluation(ArrayIntersect(oneNull, empty), Seq.empty) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala new file mode 100644 index 0000000000000..30452af1fad64 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -0,0 +1,509 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.catalyst.plans.logical.Union +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} +import org.apache.spark.sql.test.SQLTestData.NullStrings +import org.apache.spark.sql.types._ + +class DataFrameSetOperationsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("except") { + checkAnswer( + lowerCaseData.except(upperCaseData), + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.except(lowerCaseData), Nil) + checkAnswer(upperCaseData.except(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.except(nullInts.filter("0 = 1")), + nullInts) + checkAnswer( + nullInts.except(nullInts), + Nil) + + // check if values are de-duplicated + checkAnswer( + allNulls.except(allNulls.filter("0 = 1")), + Row(null) :: Nil) + checkAnswer( + allNulls.except(allNulls), + Nil) + + // check if values are de-duplicated + val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value") + checkAnswer( + df.except(df.filter("0 = 1")), + Row("id1", 1) :: + Row("id", 1) :: + Row("id1", 2) :: Nil) + + // check if the empty set on the left side works + checkAnswer( + allNulls.filter("0 = 1").except(allNulls), + Nil) + } + + test("SPARK-23274: except between two projects without references used in filter") { + val df = Seq((1, 2, 4), (1, 3, 5), (2, 2, 3), (2, 4, 5)).toDF("a", "b", "c") + val df1 = df.filter($"a" === 1) + val df2 = df.filter($"a" === 2) + checkAnswer(df1.select("b").except(df2.select("b")), Row(3) :: Nil) + checkAnswer(df1.select("b").except(df2.select("c")), Row(2) :: Nil) + } + + test("except distinct - SQL compliance") { + val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") + val df_right = Seq(1, 3).toDF("id") + + checkAnswer( + df_left.except(df_right), + Row(2) :: Row(4) :: Nil + ) + } + + test("except - nullability") { + val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.except(nullInts) + checkAnswer(df1, Row(11) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.except(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil) + assert(df2.schema.forall(_.nullable)) + + val df3 = nullInts.except(nullInts) + checkAnswer(df3, Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.except(nonNullableInts) + checkAnswer(df4, Nil) + assert(df4.schema.forall(!_.nullable)) + } + + test("except all") { + checkAnswer( + lowerCaseData.exceptAll(upperCaseData), + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.exceptAll(lowerCaseData), Nil) + checkAnswer(upperCaseData.exceptAll(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.exceptAll(nullInts.filter("0 = 1")), + nullInts) + checkAnswer( + nullInts.exceptAll(nullInts), + Nil) + + // check that duplicate values are preserved + checkAnswer( + allNulls.exceptAll(allNulls.filter("0 = 1")), + Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) + checkAnswer( + allNulls.exceptAll(allNulls.limit(2)), + Row(null) :: Row(null) :: Nil) + + // check that duplicates are retained. + val df = spark.sparkContext.parallelize( + NullStrings(1, "id1") :: + NullStrings(1, "id1") :: + NullStrings(2, "id1") :: + NullStrings(3, null) :: Nil).toDF("id", "value") + + checkAnswer( + df.exceptAll(df.filter("0 = 1")), + Row(1, "id1") :: + Row(1, "id1") :: + Row(2, "id1") :: + Row(3, null) :: Nil) + + // check if the empty set on the left side works + checkAnswer( + allNulls.filter("0 = 1").exceptAll(allNulls), + Nil) + + } + + test("exceptAll - nullability") { + val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.exceptAll(nullInts) + checkAnswer(df1, Row(11) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.exceptAll(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil) + assert(df2.schema.forall(_.nullable)) + + val df3 = nullInts.exceptAll(nullInts) + checkAnswer(df3, Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.exceptAll(nonNullableInts) + checkAnswer(df4, Nil) + assert(df4.schema.forall(!_.nullable)) + } + + test("intersect") { + checkAnswer( + lowerCaseData.intersect(lowerCaseData), + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.intersect(nullInts), + Row(1) :: + Row(2) :: + Row(3) :: + Row(null) :: Nil) + + // check if values are de-duplicated + checkAnswer( + allNulls.intersect(allNulls), + Row(null) :: Nil) + + // check if values are de-duplicated + val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value") + checkAnswer( + df.intersect(df), + Row("id1", 1) :: + Row("id", 1) :: + Row("id1", 2) :: Nil) + } + + test("intersect - nullability") { + val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.intersect(nullInts) + checkAnswer(df1, Row(1) :: Row(3) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.intersect(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(3) :: Nil) + assert(df2.schema.forall(!_.nullable)) + + val df3 = nullInts.intersect(nullInts) + checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.intersect(nonNullableInts) + checkAnswer(df4, Row(1) :: Row(3) :: Nil) + assert(df4.schema.forall(!_.nullable)) + } + + test("intersectAll") { + checkAnswer( + lowerCaseDataWithDuplicates.intersectAll(lowerCaseDataWithDuplicates), + Row(1, "a") :: + Row(2, "b") :: + Row(2, "b") :: + Row(3, "c") :: + Row(3, "c") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.intersectAll(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.intersectAll(nullInts), + Row(1) :: + Row(2) :: + Row(3) :: + Row(null) :: Nil) + + // Duplicate nulls are preserved. + checkAnswer( + allNulls.intersectAll(allNulls), + Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) + + val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") + val df_right = Seq(1, 2, 2, 3).toDF("id") + + checkAnswer( + df_left.intersectAll(df_right), + Row(1) :: Row(2) :: Row(2) :: Row(3) :: Nil) + } + + test("intersectAll - nullability") { + val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.intersectAll(nullInts) + checkAnswer(df1, Row(1) :: Row(3) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.intersectAll(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(3) :: Nil) + assert(df2.schema.forall(!_.nullable)) + + val df3 = nullInts.intersectAll(nullInts) + checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.intersectAll(nonNullableInts) + checkAnswer(df4, Row(1) :: Row(3) :: Nil) + assert(df4.schema.forall(!_.nullable)) + } + + test("SPARK-10539: Project should not be pushed down through Intersect or Except") { + val df1 = (1 to 100).map(Tuple1.apply).toDF("i") + val df2 = (1 to 30).map(Tuple1.apply).toDF("i") + val intersect = df1.intersect(df2) + val except = df1.except(df2) + assert(intersect.count() === 30) + assert(except.count() === 70) + } + + test("SPARK-10740: handle nondeterministic expressions correctly for set operations") { + val df1 = (1 to 20).map(Tuple1.apply).toDF("i") + val df2 = (1 to 10).map(Tuple1.apply).toDF("i") + + // When generating expected results at here, we need to follow the implementation of + // Rand expression. + def expected(df: DataFrame): Seq[Row] = { + df.rdd.collectPartitions().zipWithIndex.flatMap { + case (data, index) => + val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) + data.filter(_.getInt(0) < rng.nextDouble() * 10) + } + } + + val union = df1.union(df2) + checkAnswer( + union.filter('i < rand(7) * 10), + expected(union) + ) + checkAnswer( + union.select(rand(7)), + union.rdd.collectPartitions().zipWithIndex.flatMap { + case (data, index) => + val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) + data.map(_ => rng.nextDouble()).map(i => Row(i)) + } + ) + + val intersect = df1.intersect(df2) + checkAnswer( + intersect.filter('i < rand(7) * 10), + expected(intersect) + ) + + val except = df1.except(df2) + checkAnswer( + except.filter('i < rand(7) * 10), + expected(except) + ) + } + + test("SPARK-17123: Performing set operations that combine non-scala native types") { + val dates = Seq( + (new Date(0), BigDecimal.valueOf(1), new Timestamp(2)), + (new Date(3), BigDecimal.valueOf(4), new Timestamp(5)) + ).toDF("date", "timestamp", "decimal") + + val widenTypedRows = Seq( + (new Timestamp(2), 10.5D, "string") + ).toDF("date", "timestamp", "decimal") + + dates.union(widenTypedRows).collect() + dates.except(widenTypedRows).collect() + dates.intersect(widenTypedRows).collect() + } + + test("SPARK-19893: cannot run set operations with map type") { + val df = spark.range(1).select(map(lit("key"), $"id").as("m")) + val e = intercept[AnalysisException](df.intersect(df)) + assert(e.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + val e2 = intercept[AnalysisException](df.except(df)) + assert(e2.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + val e3 = intercept[AnalysisException](df.distinct()) + assert(e3.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + withTempView("v") { + df.createOrReplaceTempView("v") + val e4 = intercept[AnalysisException](sql("SELECT DISTINCT m FROM v")) + assert(e4.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + } + } + + test("union all") { + val unionDF = testData.union(testData).union(testData) + .union(testData).union(testData) + + // Before optimizer, Union should be combined. + assert(unionDF.queryExecution.analyzed.collect { + case j: Union if j.children.size == 5 => j }.size === 1) + + checkAnswer( + unionDF.agg(avg('key), max('key), min('key), sum('key)), + Row(50.5, 100, 1, 25250) :: Nil + ) + + // unionAll is an alias of union + val unionAllDF = testData.unionAll(testData).unionAll(testData) + .unionAll(testData).unionAll(testData) + + checkAnswer(unionDF, unionAllDF) + } + + test("union should union DataFrames with UDTs (SPARK-13410)") { + val rowRDD1 = sparkContext.parallelize(Seq(Row(1, new ExamplePoint(1.0, 2.0)))) + val schema1 = StructType(Array(StructField("label", IntegerType, false), + StructField("point", new ExamplePointUDT(), false))) + val rowRDD2 = sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0, 4.0)))) + val schema2 = StructType(Array(StructField("label", IntegerType, false), + StructField("point", new ExamplePointUDT(), false))) + val df1 = spark.createDataFrame(rowRDD1, schema1) + val df2 = spark.createDataFrame(rowRDD2, schema2) + + checkAnswer( + df1.union(df2).orderBy("label"), + Seq(Row(1, new ExamplePoint(1.0, 2.0)), Row(2, new ExamplePoint(3.0, 4.0))) + ) + } + + test("union by name") { + var df1 = Seq((1, 2, 3)).toDF("a", "b", "c") + var df2 = Seq((3, 1, 2)).toDF("c", "a", "b") + val df3 = Seq((2, 3, 1)).toDF("b", "c", "a") + val unionDf = df1.unionByName(df2.unionByName(df3)) + checkAnswer(unionDf, + Row(1, 2, 3) :: Row(1, 2, 3) :: Row(1, 2, 3) :: Nil + ) + + // Check if adjacent unions are combined into a single one + assert(unionDf.queryExecution.optimizedPlan.collect { case u: Union => true }.size == 1) + + // Check failure cases + df1 = Seq((1, 2)).toDF("a", "c") + df2 = Seq((3, 4, 5)).toDF("a", "b", "c") + var errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains( + "Union can only be performed on tables with the same number of columns, " + + "but the first table has 2 columns and the second table has 3 columns")) + + df1 = Seq((1, 2, 3)).toDF("a", "b", "c") + df2 = Seq((4, 5, 6)).toDF("a", "c", "d") + errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains("""Cannot resolve column name "b" among (a, c, d)""")) + } + + test("union by name - type coercion") { + var df1 = Seq((1, "a")).toDF("c0", "c1") + var df2 = Seq((3, 1L)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(1L, "a") :: Row(1L, "3") :: Nil) + + df1 = Seq((1, 1.0)).toDF("c0", "c1") + df2 = Seq((8L, 3.0)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(1.0, 1.0) :: Row(3.0, 8.0) :: Nil) + + df1 = Seq((2.0f, 7.4)).toDF("c0", "c1") + df2 = Seq(("a", 4.0)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(2.0, "7.4") :: Row(4.0, "a") :: Nil) + + df1 = Seq((1, "a", 3.0)).toDF("c0", "c1", "c2") + df2 = Seq((1.2, 2, "bc")).toDF("c2", "c0", "c1") + val df3 = Seq(("def", 1.2, 3)).toDF("c1", "c2", "c0") + checkAnswer(df1.unionByName(df2.unionByName(df3)), + Row(1, "a", 3.0) :: Row(2, "bc", 1.2) :: Row(3, "def", 1.2) :: Nil + ) + } + + test("union by name - check case sensitivity") { + def checkCaseSensitiveTest(): Unit = { + val df1 = Seq((1, 2, 3)).toDF("ab", "cd", "ef") + val df2 = Seq((4, 5, 6)).toDF("cd", "ef", "AB") + checkAnswer(df1.unionByName(df2), Row(1, 2, 3) :: Row(6, 4, 5) :: Nil) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val errMsg2 = intercept[AnalysisException] { + checkCaseSensitiveTest() + }.getMessage + assert(errMsg2.contains("""Cannot resolve column name "ab" among (cd, ef, AB)""")) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkCaseSensitiveTest() + } + } + + test("union by name - check name duplication") { + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + var df1 = Seq((1, 1)).toDF(c0, c1) + var df2 = Seq((1, 1)).toDF("c0", "c1") + var errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the left attributes:")) + df1 = Seq((1, 1)).toDF("c0", "c1") + df2 = Seq((1, 1)).toDF(c0, c1) + errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the right attributes:")) + } + } + } + + test("SPARK-25368 Incorrect predicate pushdown returns wrong result") { + def check(newCol: Column, filter: Column, result: Seq[Row]): Unit = { + val df1 = spark.createDataFrame(Seq( + (1, 1) + )).toDF("a", "b").withColumn("c", newCol) + + val df2 = df1.union(df1).withColumn("d", spark_partition_id).filter(filter) + checkAnswer(df2, result) + } + + check(lit(null).cast("int"), $"c".isNull, Seq(Row(1, 1, null, 0), Row(1, 1, null, 1))) + check(lit(null).cast("int"), $"c".isNotNull, Seq()) + check(lit(2).cast("int"), $"c".isNull, Seq()) + check(lit(2).cast("int"), $"c".isNotNull, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) + check(lit(2).cast("int"), $"c" === 2, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) + check(lit(2).cast("int"), $"c" =!= 2, Seq()) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 7a0767a883f15..fc3faa08d55f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -85,129 +85,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData.collect().toSeq) } - test("union all") { - val unionDF = testData.union(testData).union(testData) - .union(testData).union(testData) - - // Before optimizer, Union should be combined. - assert(unionDF.queryExecution.analyzed.collect { - case j: Union if j.children.size == 5 => j }.size === 1) - - checkAnswer( - unionDF.agg(avg('key), max('key), min('key), sum('key)), - Row(50.5, 100, 1, 25250) :: Nil - ) - - // unionAll is an alias of union - val unionAllDF = testData.unionAll(testData).unionAll(testData) - .unionAll(testData).unionAll(testData) - - checkAnswer(unionDF, unionAllDF) - } - - test("union should union DataFrames with UDTs (SPARK-13410)") { - val rowRDD1 = sparkContext.parallelize(Seq(Row(1, new ExamplePoint(1.0, 2.0)))) - val schema1 = StructType(Array(StructField("label", IntegerType, false), - StructField("point", new ExamplePointUDT(), false))) - val rowRDD2 = sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0, 4.0)))) - val schema2 = StructType(Array(StructField("label", IntegerType, false), - StructField("point", new ExamplePointUDT(), false))) - val df1 = spark.createDataFrame(rowRDD1, schema1) - val df2 = spark.createDataFrame(rowRDD2, schema2) - - checkAnswer( - df1.union(df2).orderBy("label"), - Seq(Row(1, new ExamplePoint(1.0, 2.0)), Row(2, new ExamplePoint(3.0, 4.0))) - ) - } - - test("union by name") { - var df1 = Seq((1, 2, 3)).toDF("a", "b", "c") - var df2 = Seq((3, 1, 2)).toDF("c", "a", "b") - val df3 = Seq((2, 3, 1)).toDF("b", "c", "a") - val unionDf = df1.unionByName(df2.unionByName(df3)) - checkAnswer(unionDf, - Row(1, 2, 3) :: Row(1, 2, 3) :: Row(1, 2, 3) :: Nil - ) - - // Check if adjacent unions are combined into a single one - assert(unionDf.queryExecution.optimizedPlan.collect { case u: Union => true }.size == 1) - - // Check failure cases - df1 = Seq((1, 2)).toDF("a", "c") - df2 = Seq((3, 4, 5)).toDF("a", "b", "c") - var errMsg = intercept[AnalysisException] { - df1.unionByName(df2) - }.getMessage - assert(errMsg.contains( - "Union can only be performed on tables with the same number of columns, " + - "but the first table has 2 columns and the second table has 3 columns")) - - df1 = Seq((1, 2, 3)).toDF("a", "b", "c") - df2 = Seq((4, 5, 6)).toDF("a", "c", "d") - errMsg = intercept[AnalysisException] { - df1.unionByName(df2) - }.getMessage - assert(errMsg.contains("""Cannot resolve column name "b" among (a, c, d)""")) - } - - test("union by name - type coercion") { - var df1 = Seq((1, "a")).toDF("c0", "c1") - var df2 = Seq((3, 1L)).toDF("c1", "c0") - checkAnswer(df1.unionByName(df2), Row(1L, "a") :: Row(1L, "3") :: Nil) - - df1 = Seq((1, 1.0)).toDF("c0", "c1") - df2 = Seq((8L, 3.0)).toDF("c1", "c0") - checkAnswer(df1.unionByName(df2), Row(1.0, 1.0) :: Row(3.0, 8.0) :: Nil) - - df1 = Seq((2.0f, 7.4)).toDF("c0", "c1") - df2 = Seq(("a", 4.0)).toDF("c1", "c0") - checkAnswer(df1.unionByName(df2), Row(2.0, "7.4") :: Row(4.0, "a") :: Nil) - - df1 = Seq((1, "a", 3.0)).toDF("c0", "c1", "c2") - df2 = Seq((1.2, 2, "bc")).toDF("c2", "c0", "c1") - val df3 = Seq(("def", 1.2, 3)).toDF("c1", "c2", "c0") - checkAnswer(df1.unionByName(df2.unionByName(df3)), - Row(1, "a", 3.0) :: Row(2, "bc", 1.2) :: Row(3, "def", 1.2) :: Nil - ) - } - - test("union by name - check case sensitivity") { - def checkCaseSensitiveTest(): Unit = { - val df1 = Seq((1, 2, 3)).toDF("ab", "cd", "ef") - val df2 = Seq((4, 5, 6)).toDF("cd", "ef", "AB") - checkAnswer(df1.unionByName(df2), Row(1, 2, 3) :: Row(6, 4, 5) :: Nil) - } - withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { - val errMsg2 = intercept[AnalysisException] { - checkCaseSensitiveTest() - }.getMessage - assert(errMsg2.contains("""Cannot resolve column name "ab" among (cd, ef, AB)""")) - } - withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { - checkCaseSensitiveTest() - } - } - - test("union by name - check name duplication") { - Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => - withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { - var df1 = Seq((1, 1)).toDF(c0, c1) - var df2 = Seq((1, 1)).toDF("c0", "c1") - var errMsg = intercept[AnalysisException] { - df1.unionByName(df2) - }.getMessage - assert(errMsg.contains("Found duplicate column(s) in the left attributes:")) - df1 = Seq((1, 1)).toDF("c0", "c1") - df2 = Seq((1, 1)).toDF(c0, c1) - errMsg = intercept[AnalysisException] { - df1.unionByName(df2) - }.getMessage - assert(errMsg.contains("Found duplicate column(s) in the right attributes:")) - } - } - } - test("empty data frame") { assert(spark.emptyDataFrame.columns.toSeq === Seq.empty[String]) assert(spark.emptyDataFrame.count() === 0) @@ -528,259 +405,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { ) } - test("except") { - checkAnswer( - lowerCaseData.except(upperCaseData), - Row(1, "a") :: - Row(2, "b") :: - Row(3, "c") :: - Row(4, "d") :: Nil) - checkAnswer(lowerCaseData.except(lowerCaseData), Nil) - checkAnswer(upperCaseData.except(upperCaseData), Nil) - - // check null equality - checkAnswer( - nullInts.except(nullInts.filter("0 = 1")), - nullInts) - checkAnswer( - nullInts.except(nullInts), - Nil) - - // check if values are de-duplicated - checkAnswer( - allNulls.except(allNulls.filter("0 = 1")), - Row(null) :: Nil) - checkAnswer( - allNulls.except(allNulls), - Nil) - - // check if values are de-duplicated - val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value") - checkAnswer( - df.except(df.filter("0 = 1")), - Row("id1", 1) :: - Row("id", 1) :: - Row("id1", 2) :: Nil) - - // check if the empty set on the left side works - checkAnswer( - allNulls.filter("0 = 1").except(allNulls), - Nil) - } - - test("SPARK-23274: except between two projects without references used in filter") { - val df = Seq((1, 2, 4), (1, 3, 5), (2, 2, 3), (2, 4, 5)).toDF("a", "b", "c") - val df1 = df.filter($"a" === 1) - val df2 = df.filter($"a" === 2) - checkAnswer(df1.select("b").except(df2.select("b")), Row(3) :: Nil) - checkAnswer(df1.select("b").except(df2.select("c")), Row(2) :: Nil) - } - - test("except distinct - SQL compliance") { - val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") - val df_right = Seq(1, 3).toDF("id") - - checkAnswer( - df_left.except(df_right), - Row(2) :: Row(4) :: Nil - ) - } - - test("except - nullability") { - val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF() - assert(nonNullableInts.schema.forall(!_.nullable)) - - val df1 = nonNullableInts.except(nullInts) - checkAnswer(df1, Row(11) :: Nil) - assert(df1.schema.forall(!_.nullable)) - - val df2 = nullInts.except(nonNullableInts) - checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil) - assert(df2.schema.forall(_.nullable)) - - val df3 = nullInts.except(nullInts) - checkAnswer(df3, Nil) - assert(df3.schema.forall(_.nullable)) - - val df4 = nonNullableInts.except(nonNullableInts) - checkAnswer(df4, Nil) - assert(df4.schema.forall(!_.nullable)) - } - - test("except all") { - checkAnswer( - lowerCaseData.exceptAll(upperCaseData), - Row(1, "a") :: - Row(2, "b") :: - Row(3, "c") :: - Row(4, "d") :: Nil) - checkAnswer(lowerCaseData.exceptAll(lowerCaseData), Nil) - checkAnswer(upperCaseData.exceptAll(upperCaseData), Nil) - - // check null equality - checkAnswer( - nullInts.exceptAll(nullInts.filter("0 = 1")), - nullInts) - checkAnswer( - nullInts.exceptAll(nullInts), - Nil) - - // check that duplicate values are preserved - checkAnswer( - allNulls.exceptAll(allNulls.filter("0 = 1")), - Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) - checkAnswer( - allNulls.exceptAll(allNulls.limit(2)), - Row(null) :: Row(null) :: Nil) - - // check that duplicates are retained. - val df = spark.sparkContext.parallelize( - NullStrings(1, "id1") :: - NullStrings(1, "id1") :: - NullStrings(2, "id1") :: - NullStrings(3, null) :: Nil).toDF("id", "value") - - checkAnswer( - df.exceptAll(df.filter("0 = 1")), - Row(1, "id1") :: - Row(1, "id1") :: - Row(2, "id1") :: - Row(3, null) :: Nil) - - // check if the empty set on the left side works - checkAnswer( - allNulls.filter("0 = 1").exceptAll(allNulls), - Nil) - - } - - test("exceptAll - nullability") { - val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF() - assert(nonNullableInts.schema.forall(!_.nullable)) - - val df1 = nonNullableInts.exceptAll(nullInts) - checkAnswer(df1, Row(11) :: Nil) - assert(df1.schema.forall(!_.nullable)) - - val df2 = nullInts.exceptAll(nonNullableInts) - checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil) - assert(df2.schema.forall(_.nullable)) - - val df3 = nullInts.exceptAll(nullInts) - checkAnswer(df3, Nil) - assert(df3.schema.forall(_.nullable)) - - val df4 = nonNullableInts.exceptAll(nonNullableInts) - checkAnswer(df4, Nil) - assert(df4.schema.forall(!_.nullable)) - } - - test("intersect") { - checkAnswer( - lowerCaseData.intersect(lowerCaseData), - Row(1, "a") :: - Row(2, "b") :: - Row(3, "c") :: - Row(4, "d") :: Nil) - checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) - - // check null equality - checkAnswer( - nullInts.intersect(nullInts), - Row(1) :: - Row(2) :: - Row(3) :: - Row(null) :: Nil) - - // check if values are de-duplicated - checkAnswer( - allNulls.intersect(allNulls), - Row(null) :: Nil) - - // check if values are de-duplicated - val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value") - checkAnswer( - df.intersect(df), - Row("id1", 1) :: - Row("id", 1) :: - Row("id1", 2) :: Nil) - } - - test("intersect - nullability") { - val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() - assert(nonNullableInts.schema.forall(!_.nullable)) - - val df1 = nonNullableInts.intersect(nullInts) - checkAnswer(df1, Row(1) :: Row(3) :: Nil) - assert(df1.schema.forall(!_.nullable)) - - val df2 = nullInts.intersect(nonNullableInts) - checkAnswer(df2, Row(1) :: Row(3) :: Nil) - assert(df2.schema.forall(!_.nullable)) - - val df3 = nullInts.intersect(nullInts) - checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) - assert(df3.schema.forall(_.nullable)) - - val df4 = nonNullableInts.intersect(nonNullableInts) - checkAnswer(df4, Row(1) :: Row(3) :: Nil) - assert(df4.schema.forall(!_.nullable)) - } - - test("intersectAll") { - checkAnswer( - lowerCaseDataWithDuplicates.intersectAll(lowerCaseDataWithDuplicates), - Row(1, "a") :: - Row(2, "b") :: - Row(2, "b") :: - Row(3, "c") :: - Row(3, "c") :: - Row(3, "c") :: - Row(4, "d") :: Nil) - checkAnswer(lowerCaseData.intersectAll(upperCaseData), Nil) - - // check null equality - checkAnswer( - nullInts.intersectAll(nullInts), - Row(1) :: - Row(2) :: - Row(3) :: - Row(null) :: Nil) - - // Duplicate nulls are preserved. - checkAnswer( - allNulls.intersectAll(allNulls), - Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) - - val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") - val df_right = Seq(1, 2, 2, 3).toDF("id") - - checkAnswer( - df_left.intersectAll(df_right), - Row(1) :: Row(2) :: Row(2) :: Row(3) :: Nil) - } - - test("intersectAll - nullability") { - val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() - assert(nonNullableInts.schema.forall(!_.nullable)) - - val df1 = nonNullableInts.intersectAll(nullInts) - checkAnswer(df1, Row(1) :: Row(3) :: Nil) - assert(df1.schema.forall(!_.nullable)) - - val df2 = nullInts.intersectAll(nonNullableInts) - checkAnswer(df2, Row(1) :: Row(3) :: Nil) - assert(df2.schema.forall(!_.nullable)) - - val df3 = nullInts.intersectAll(nullInts) - checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) - assert(df3.schema.forall(_.nullable)) - - val df4 = nonNullableInts.intersectAll(nonNullableInts) - checkAnswer(df4, Row(1) :: Row(3) :: Nil) - assert(df4.schema.forall(!_.nullable)) - } - test("udf") { val foo = udf((a: Int, b: String) => a.toString + b) @@ -1782,56 +1406,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } - test("SPARK-10539: Project should not be pushed down through Intersect or Except") { - val df1 = (1 to 100).map(Tuple1.apply).toDF("i") - val df2 = (1 to 30).map(Tuple1.apply).toDF("i") - val intersect = df1.intersect(df2) - val except = df1.except(df2) - assert(intersect.count() === 30) - assert(except.count() === 70) - } - - test("SPARK-10740: handle nondeterministic expressions correctly for set operations") { - val df1 = (1 to 20).map(Tuple1.apply).toDF("i") - val df2 = (1 to 10).map(Tuple1.apply).toDF("i") - - // When generating expected results at here, we need to follow the implementation of - // Rand expression. - def expected(df: DataFrame): Seq[Row] = { - df.rdd.collectPartitions().zipWithIndex.flatMap { - case (data, index) => - val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) - data.filter(_.getInt(0) < rng.nextDouble() * 10) - } - } - - val union = df1.union(df2) - checkAnswer( - union.filter('i < rand(7) * 10), - expected(union) - ) - checkAnswer( - union.select(rand(7)), - union.rdd.collectPartitions().zipWithIndex.flatMap { - case (data, index) => - val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) - data.map(_ => rng.nextDouble()).map(i => Row(i)) - } - ) - - val intersect = df1.intersect(df2) - checkAnswer( - intersect.filter('i < rand(7) * 10), - expected(intersect) - ) - - val except = df1.except(df2) - checkAnswer( - except.filter('i < rand(7) * 10), - expected(except) - ) - } - test("SPARK-10743: keep the name of expression if possible when do cast") { val df = (1 to 10).map(Tuple1.apply).toDF("i").as("src") assert(df.select($"src.i".cast(StringType)).columns.head === "i") @@ -2280,21 +1854,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } - test("SPARK-17123: Performing set operations that combine non-scala native types") { - val dates = Seq( - (new Date(0), BigDecimal.valueOf(1), new Timestamp(2)), - (new Date(3), BigDecimal.valueOf(4), new Timestamp(5)) - ).toDF("date", "timestamp", "decimal") - - val widenTypedRows = Seq( - (new Timestamp(2), 10.5D, "string") - ).toDF("date", "timestamp", "decimal") - - dates.union(widenTypedRows).collect() - dates.except(widenTypedRows).collect() - dates.intersect(widenTypedRows).collect() - } - test("SPARK-18070 binary operator should not consider nullability when comparing input types") { val rows = Seq(Row(Seq(1), Seq(1))) val schema = new StructType() @@ -2314,25 +1873,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(BigDecimal(0)) :: Nil) } - test("SPARK-19893: cannot run set operations with map type") { - val df = spark.range(1).select(map(lit("key"), $"id").as("m")) - val e = intercept[AnalysisException](df.intersect(df)) - assert(e.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - val e2 = intercept[AnalysisException](df.except(df)) - assert(e2.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - val e3 = intercept[AnalysisException](df.distinct()) - assert(e3.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - withTempView("v") { - df.createOrReplaceTempView("v") - val e4 = intercept[AnalysisException](sql("SELECT DISTINCT m FROM v")) - assert(e4.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - } - } - test("SPARK-20359: catalyst outer join optimization should not throw npe") { val df1 = Seq("a", "b", "c").toDF("x") .withColumn("y", udf{ (x: String) => x.substring(0, 1) + "!" }.apply($"x")) @@ -2517,24 +2057,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } - test("SPARK-25368 Incorrect predicate pushdown returns wrong result") { - def check(newCol: Column, filter: Column, result: Seq[Row]): Unit = { - val df1 = spark.createDataFrame(Seq( - (1, 1) - )).toDF("a", "b").withColumn("c", newCol) - - val df2 = df1.union(df1).withColumn("d", spark_partition_id).filter(filter) - checkAnswer(df2, result) - } - - check(lit(null).cast("int"), $"c".isNull, Seq(Row(1, 1, null, 0), Row(1, 1, null, 1))) - check(lit(null).cast("int"), $"c".isNotNull, Seq()) - check(lit(2).cast("int"), $"c".isNull, Seq()) - check(lit(2).cast("int"), $"c".isNotNull, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) - check(lit(2).cast("int"), $"c" === 2, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) - check(lit(2).cast("int"), $"c" =!= 2, Seq()) - } - test("SPARK-25402 Null handling in BooleanSimplification") { val schema = StructType.fromDDL("a boolean, b int") val rows = Seq(Row(null, 1)) From 6bb60b30fd74b2c38640a4e54e5bb19eb890793e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 26 Nov 2018 15:51:28 +0800 Subject: [PATCH 2136/2461] [SPARK-26168][SQL] Update the code comments in Expression and Aggregate ## What changes were proposed in this pull request? This PR is to improve the code comments to document some common traits and traps about the expression. ## How was this patch tested? N/A Closes #23135 from gatorsmile/addcomments. Authored-by: gatorsmile Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/TypeCoercion.scala | 5 ++- .../sql/catalyst/expressions/Expression.scala | 44 +++++++++++++++---- .../expressions/namedExpressions.scala | 3 ++ .../plans/logical/basicLogicalOperators.scala | 16 ++++++- 4 files changed, 56 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 72ac80e0a0a18..133fa119b7aa6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -181,8 +181,9 @@ object TypeCoercion { } /** - * The method finds a common type for data types that differ only in nullable, containsNull - * and valueContainsNull flags. If the input types are too different, None is returned. + * The method finds a common type for data types that differ only in nullable flags, including + * `nullable`, `containsNull` of [[ArrayType]] and `valueContainsNull` of [[MapType]]. + * If the input types are different besides nullable flags, None is returned. */ def findCommonTypeDifferentOnlyInNullFlags(t1: DataType, t2: DataType): Option[DataType] = { if (t1 == t2) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index d51b11024a09d..2ecec61adb0ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.internal.SQLConf @@ -40,12 +41,28 @@ import org.apache.spark.sql.types._ * "name(arguments...)", the concrete implementation must be a case class whose constructor * arguments are all Expressions types. See [[Substring]] for an example. * - * There are a few important traits: + * There are a few important traits or abstract classes: * * - [[Nondeterministic]]: an expression that is not deterministic. + * - [[Stateful]]: an expression that contains mutable state. For example, MonotonicallyIncreasingID + * and Rand. A stateful expression is always non-deterministic. * - [[Unevaluable]]: an expression that is not supposed to be evaluated. * - [[CodegenFallback]]: an expression that does not have code gen implemented and falls back to * interpreted mode. + * - [[NullIntolerant]]: an expression that is null intolerant (i.e. any null input will result in + * null output). + * - [[NonSQLExpression]]: a common base trait for the expressions that do not have SQL + * expressions like representation. For example, `ScalaUDF`, `ScalaUDAF`, + * and object `MapObjects` and `Invoke`. + * - [[UserDefinedExpression]]: a common base trait for user-defined functions, including + * UDF/UDAF/UDTF. + * - [[HigherOrderFunction]]: a common base trait for higher order functions that take one or more + * (lambda) functions and applies these to some objects. The function + * produces a number of variables which can be consumed by some lambda + * functions. + * - [[NamedExpression]]: An [[Expression]] that is named. + * - [[TimeZoneAwareExpression]]: A common base trait for time zone aware expressions. + * - [[SubqueryExpression]]: A base interface for expressions that contain a [[LogicalPlan]]. * * - [[LeafExpression]]: an expression that has no child. * - [[UnaryExpression]]: an expression that has one child. @@ -54,12 +71,20 @@ import org.apache.spark.sql.types._ * - [[BinaryOperator]]: a special case of [[BinaryExpression]] that requires two children to have * the same output data type. * + * A few important traits used for type coercion rules: + * - [[ExpectsInputTypes]]: an expression that has the expected input types. This trait is typically + * used by operator expressions (e.g. [[Add]], [[Subtract]]) to define + * expected input types without any implicit casting. + * - [[ImplicitCastInputTypes]]: an expression that has the expected input types, which can be + * implicitly castable using [[TypeCoercion.ImplicitTypeCasts]]. + * - [[ComplexTypeMergingExpression]]: to resolve output types of the complex expressions + * (e.g., [[CaseWhen]]). */ abstract class Expression extends TreeNode[Expression] { /** * Returns true when an expression is a candidate for static evaluation before the query is - * executed. + * executed. A typical use case: [[org.apache.spark.sql.catalyst.optimizer.ConstantFolding]] * * The following conditions are used to determine suitability for constant folding: * - A [[Coalesce]] is foldable if all of its children are foldable @@ -72,7 +97,8 @@ abstract class Expression extends TreeNode[Expression] { /** * Returns true when the current expression always return the same result for fixed inputs from - * children. + * children. The non-deterministic expressions should not change in number and order. They should + * not be evaluated during the query planning. * * Note that this means that an expression should be considered as non-deterministic if: * - it relies on some mutable internal state, or @@ -252,8 +278,9 @@ abstract class Expression extends TreeNode[Expression] { /** - * An expression that cannot be evaluated. Some expressions don't live past analysis or optimization - * time (e.g. Star). This trait is used by those expressions. + * An expression that cannot be evaluated. These expressions don't live past analysis or + * optimization time (e.g. Star) and should not be evaluated during query planning and + * execution. */ trait Unevaluable extends Expression { @@ -724,9 +751,10 @@ abstract class TernaryExpression extends Expression { } /** - * A trait resolving nullable, containsNull, valueContainsNull flags of the output date type. - * This logic is usually utilized by expressions combining data from multiple child expressions - * of non-primitive types (e.g. [[CaseWhen]]). + * A trait used for resolving nullable flags, including `nullable`, `containsNull` of [[ArrayType]] + * and `valueContainsNull` of [[MapType]], containsNull, valueContainsNull flags of the output date + * type. This is usually utilized by the expressions (e.g. [[CaseWhen]]) that combine data from + * multiple child expressions of non-primitive types. */ trait ComplexTypeMergingExpression extends Expression { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 049ea77691395..02b48f9e30f2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -130,6 +130,9 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn * Note that exprId and qualifiers are in a separate parameter list because * we only pattern match on child and name. * + * Note that when creating a new Alias, all the [[AttributeReference]] that refer to + * the original alias should be updated to the new one. + * * @param child The computation being performed * @param name The name to be associated with the result of computing [[child]]. * @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 07fa17b233a47..a26ec4eed8648 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.{AliasIdentifier} +import org.apache.spark.sql.catalyst.AliasIdentifier import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning} import org.apache.spark.sql.catalyst.util.truncatedString @@ -575,6 +575,18 @@ case class Range( } } +/** + * This is a Group by operator with the aggregate functions and projections. + * + * @param groupingExpressions expressions for grouping keys + * @param aggregateExpressions expressions for a project list, which could contain + * [[AggregateFunction]]s. + * + * Note: Currently, aggregateExpressions is the project list of this Group by operator. Before + * separating projection from grouping and aggregate, we should avoid expression-level optimization + * on aggregateExpressions, which could reference an expression in groupingExpressions. + * For example, see the rule [[org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps]] + */ case class Aggregate( groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], From 1bb60ab8392adf8b896cc04fb1d060620cf09d8a Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Mon, 26 Nov 2018 05:57:33 -0600 Subject: [PATCH 2137/2461] [SPARK-26153][ML] GBT & RandomForest avoid unnecessary `first` job to compute `numFeatures` ## What changes were proposed in this pull request? use base models' `numFeature` instead of `first` job ## How was this patch tested? existing tests Closes #23123 from zhengruifeng/avoid_first_job. Authored-by: zhengruifeng Signed-off-by: Sean Owen --- .../org/apache/spark/ml/classification/GBTClassifier.scala | 5 +++-- .../spark/ml/classification/RandomForestClassifier.scala | 2 +- .../scala/org/apache/spark/ml/regression/GBTRegressor.scala | 6 ++++-- .../apache/spark/ml/regression/RandomForestRegressor.scala | 2 +- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index fab8155add5a8..09a9df6d15ece 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -180,7 +180,6 @@ class GBTClassifier @Since("1.4.0") ( (convert2LabeledPoint(dataset), null) } - val numFeatures = trainDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) val numClasses = 2 @@ -196,7 +195,6 @@ class GBTClassifier @Since("1.4.0") ( maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy, validationIndicatorCol) - instr.logNumFeatures(numFeatures) instr.logNumClasses(numClasses) val (baseLearners, learnerWeights) = if (withValidation) { @@ -206,6 +204,9 @@ class GBTClassifier @Since("1.4.0") ( GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) } + val numFeatures = baseLearners.head.numFeatures + instr.logNumFeatures(numFeatures) + new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 05fff8885fbf2..0a3bfd1f85e08 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -142,7 +142,7 @@ class RandomForestClassifier @Since("1.4.0") ( .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr)) .map(_.asInstanceOf[DecisionTreeClassificationModel]) - val numFeatures = oldDataset.first().features.size + val numFeatures = trees.head.numFeatures instr.logNumClasses(numClasses) instr.logNumFeatures(numFeatures) new RandomForestClassificationModel(uid, trees, numFeatures, numClasses) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 186fa2399af05..9b386ef5eed8f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -165,7 +165,6 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) } else { (extractLabeledPoints(dataset), null) } - val numFeatures = trainDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) instr.logPipelineStage(this) @@ -173,7 +172,6 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) - instr.logNumFeatures(numFeatures) val (baseLearners, learnerWeights) = if (withValidation) { GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy, @@ -182,6 +180,10 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) } + + val numFeatures = baseLearners.head.numFeatures + instr.logNumFeatures(numFeatures) + new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 7f5e668ca71db..afa9a646412b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -133,7 +133,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr)) .map(_.asInstanceOf[DecisionTreeRegressionModel]) - val numFeatures = oldDataset.first().features.size + val numFeatures = trees.head.numFeatures instr.logNamedValue(Instrumentation.loggerTags.numFeatures, numFeatures) new RandomForestRegressionModel(uid, trees, numFeatures) } From 2512a1d42911370854ca42d987c851128fa0b263 Mon Sep 17 00:00:00 2001 From: Anastasios Zouzias Date: Mon, 26 Nov 2018 11:10:38 -0600 Subject: [PATCH 2138/2461] [SPARK-26121][STRUCTURED STREAMING] Allow users to define prefix of Kafka's consumer group (group.id) ## What changes were proposed in this pull request? Allow the Spark Structured Streaming user to specify the prefix of the consumer group (group.id), compared to force consumer group ids of the form `spark-kafka-source-*` ## How was this patch tested? Unit tests provided by Spark (backwards compatible change, i.e., user can optionally use the functionality) `mvn test -pl external/kafka-0-10` Closes #23103 from zouzias/SPARK-26121. Authored-by: Anastasios Zouzias Signed-off-by: cody koeninger --- .../structured-streaming-kafka-integration.md | 37 ++++++++++++------- .../sql/kafka010/KafkaSourceProvider.scala | 18 +++++++-- 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index 71fd5b10cc407..a549ce2a6a05f 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -123,7 +123,7 @@ df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") -### Creating a Kafka Source for Batch Queries +### Creating a Kafka Source for Batch Queries If you have a use case that is better suited to batch processing, you can create a Dataset/DataFrame for a defined range of offsets. @@ -374,17 +374,24 @@ The following configurations are optional: streaming and batch Rate limit on maximum number of offsets processed per trigger interval. The specified total number of offsets will be proportionally split across topicPartitions of different volume. + + groupIdPrefix + string + spark-kafka-source + streaming and batch + Prefix of consumer group identifiers (`group.id`) that are generated by structured streaming queries + ## Writing Data to Kafka -Here, we describe the support for writing Streaming Queries and Batch Queries to Apache Kafka. Take note that +Here, we describe the support for writing Streaming Queries and Batch Queries to Apache Kafka. Take note that Apache Kafka only supports at least once write semantics. Consequently, when writing---either Streaming Queries or Batch Queries---to Kafka, some records may be duplicated; this can happen, for example, if Kafka needs to retry a message that was not acknowledged by a Broker, even though that Broker received and wrote the message record. -Structured Streaming cannot prevent such duplicates from occurring due to these Kafka write semantics. However, +Structured Streaming cannot prevent such duplicates from occurring due to these Kafka write semantics. However, if writing the query is successful, then you can assume that the query output was written at least once. A possible -solution to remove duplicates when reading the written data could be to introduce a primary (unique) key +solution to remove duplicates when reading the written data could be to introduce a primary (unique) key that can be used to perform de-duplication when reading. The Dataframe being written to Kafka should have the following columns in schema: @@ -405,8 +412,8 @@ The Dataframe being written to Kafka should have the following columns in schema \* The topic column is required if the "topic" configuration option is not specified.
      -The value column is the only required option. If a key column is not specified then -a ```null``` valued key column will be automatically added (see Kafka semantics on +The value column is the only required option. If a key column is not specified then +a ```null``` valued key column will be automatically added (see Kafka semantics on how ```null``` valued key values are handled). If a topic column exists then its value is used as the topic when writing the given row to Kafka, unless the "topic" configuration option is set i.e., the "topic" configuration option overrides the topic column. @@ -568,7 +575,7 @@ df.selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") \ .format("kafka") \ .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ .save() - + {% endhighlight %} @@ -576,23 +583,25 @@ df.selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") \ ## Kafka Specific Configurations -Kafka's own configurations can be set via `DataStreamReader.option` with `kafka.` prefix, e.g, -`stream.option("kafka.bootstrap.servers", "host:port")`. For possible kafka parameters, see +Kafka's own configurations can be set via `DataStreamReader.option` with `kafka.` prefix, e.g, +`stream.option("kafka.bootstrap.servers", "host:port")`. For possible kafka parameters, see [Kafka consumer config docs](http://kafka.apache.org/documentation.html#newconsumerconfigs) for parameters related to reading data, and [Kafka producer config docs](http://kafka.apache.org/documentation/#producerconfigs) for parameters related to writing data. Note that the following Kafka params cannot be set and the Kafka source or sink will throw an exception: -- **group.id**: Kafka source will create a unique group id for each query automatically. +- **group.id**: Kafka source will create a unique group id for each query automatically. The user can +set the prefix of the automatically generated group.id's via the optional source option `groupIdPrefix`, default value +is "spark-kafka-source". - **auto.offset.reset**: Set the source option `startingOffsets` to specify - where to start instead. Structured Streaming manages which offsets are consumed internally, rather - than rely on the kafka Consumer to do it. This will ensure that no data is missed when new + where to start instead. Structured Streaming manages which offsets are consumed internally, rather + than rely on the kafka Consumer to do it. This will ensure that no data is missed when new topics/partitions are dynamically subscribed. Note that `startingOffsets` only applies when a new streaming query is started, and that resuming will always pick up from where the query left off. -- **key.deserializer**: Keys are always deserialized as byte arrays with ByteArrayDeserializer. Use +- **key.deserializer**: Keys are always deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame operations to explicitly deserialize the keys. -- **value.deserializer**: Values are always deserialized as byte arrays with ByteArrayDeserializer. +- **value.deserializer**: Values are always deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame operations to explicitly deserialize the values. - **key.serializer**: Keys are always serialized with ByteArraySerializer or StringSerializer. Use DataFrame operations to explicitly serialize the keys into either strings or byte arrays. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 5034bd73d6e74..f770f0c2a04c2 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -77,7 +77,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister // Each running query should use its own group id. Otherwise, the query may be only assigned // partial data since Kafka will assign partitions to multiple consumers having the same group // id. Hence, we should generate a unique id for each query. - val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + val uniqueGroupId = streamingUniqueGroupId(parameters, metadataPath) val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedKafkaParams = @@ -119,7 +119,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister // Each running query should use its own group id. Otherwise, the query may be only assigned // partial data since Kafka will assign partitions to multiple consumers having the same group // id. Hence, we should generate a unique id for each query. - val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + val uniqueGroupId = streamingUniqueGroupId(parameters, metadataPath) val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedKafkaParams = @@ -159,7 +159,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister // Each running query should use its own group id. Otherwise, the query may be only assigned // partial data since Kafka will assign partitions to multiple consumers having the same group // id. Hence, we should generate a unique id for each query. - val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + val uniqueGroupId = streamingUniqueGroupId(parameters, metadataPath) val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedKafkaParams = @@ -538,6 +538,18 @@ private[kafka010] object KafkaSourceProvider extends Logging { .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) .build() + /** + * Returns a unique consumer group (group.id), allowing the user to set the prefix of + * the consumer group + */ + private def streamingUniqueGroupId( + parameters: Map[String, String], + metadataPath: String): String = { + val groupIdPrefix = parameters + .getOrElse("groupIdPrefix", "spark-kafka-source") + s"${groupIdPrefix}-${UUID.randomUUID}-${metadataPath.hashCode}" + } + /** Class to conveniently update Kafka config params, while logging the changes */ private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) { private val map = new ju.HashMap[String, Object](kafkaParams.asJava) From 3df307aa515b3564686e75d1b71754bbcaaf2dec Mon Sep 17 00:00:00 2001 From: Nihar Sheth Date: Mon, 26 Nov 2018 11:06:02 -0800 Subject: [PATCH 2139/2461] [SPARK-25960][K8S] Support subpath mounting with Kubernetes ## What changes were proposed in this pull request? This PR adds configurations to use subpaths with Spark on k8s. Subpaths (https://kubernetes.io/docs/concepts/storage/volumes/#using-subpath) allow the user to specify a path within a volume to use instead of the volume's root. ## How was this patch tested? Added unit tests. Ran SparkPi on a cluster with event logging pointed at a subpath-mount and verified the driver host created and used the subpath. Closes #23026 from NiharS/k8s_subpath. Authored-by: Nihar Sheth Signed-off-by: Marcelo Vanzin --- docs/running-on-kubernetes.md | 17 ++++ .../org/apache/spark/deploy/k8s/Config.scala | 1 + .../deploy/k8s/KubernetesVolumeSpec.scala | 1 + .../deploy/k8s/KubernetesVolumeUtils.scala | 2 + .../features/MountVolumesFeatureStep.scala | 1 + .../k8s/KubernetesVolumeUtilsSuite.scala | 12 +++ .../MountVolumesFeatureStepSuite.scala | 79 +++++++++++++++++++ .../submit/KubernetesDriverBuilderSuite.scala | 34 ++++++++ .../k8s/KubernetesExecutorBuilderSuite.scala | 1 + 9 files changed, 148 insertions(+) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index e940d9a63b7af..2c01e1e7155ef 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -245,6 +245,7 @@ To mount a volume of any of the types above into the driver pod, use the followi ``` --conf spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.path= --conf spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.readOnly= +--conf spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.subPath= ``` Specifically, `VolumeType` can be one of the following values: `hostPath`, `emptyDir`, and `persistentVolumeClaim`. `VolumeName` is the name you want to use for the volume under the `volumes` field in the pod specification. @@ -806,6 +807,14 @@ specific to Spark on Kubernetes. spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.path=/checkpoint. + + spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.subPath + (none) + + Specifies a subpath to be mounted from the volume into the driver pod. + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.subPath=checkpoint. + + spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.readOnly (none) @@ -830,6 +839,14 @@ specific to Spark on Kubernetes. spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.mount.path=/checkpoint. + + spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.subPath + (none) + + Specifies a subpath to be mounted from the volume into the executor pod. + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.mount.subPath=checkpoint. + + spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.readOnly false diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index a32bd93bb65bc..724acd231a6cb 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -297,6 +297,7 @@ private[spark] object Config extends Logging { val KUBERNETES_VOLUMES_PVC_TYPE = "persistentVolumeClaim" val KUBERNETES_VOLUMES_EMPTYDIR_TYPE = "emptyDir" val KUBERNETES_VOLUMES_MOUNT_PATH_KEY = "mount.path" + val KUBERNETES_VOLUMES_MOUNT_SUBPATH_KEY = "mount.subPath" val KUBERNETES_VOLUMES_MOUNT_READONLY_KEY = "mount.readOnly" val KUBERNETES_VOLUMES_OPTIONS_PATH_KEY = "options.path" val KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY = "options.claimName" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala index b1762d1efe2ea..1a214fad96618 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala @@ -34,5 +34,6 @@ private[spark] case class KubernetesEmptyDirVolumeConf( private[spark] case class KubernetesVolumeSpec[T <: KubernetesVolumeSpecificConf]( volumeName: String, mountPath: String, + mountSubPath: String, mountReadOnly: Boolean, volumeConf: T) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala index 713df5fffc3a2..155326469235b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala @@ -39,6 +39,7 @@ private[spark] object KubernetesVolumeUtils { getVolumeTypesAndNames(properties).map { case (volumeType, volumeName) => val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_PATH_KEY" val readOnlyKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_READONLY_KEY" + val subPathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_SUBPATH_KEY" for { path <- properties.getTry(pathKey) @@ -46,6 +47,7 @@ private[spark] object KubernetesVolumeUtils { } yield KubernetesVolumeSpec( volumeName = volumeName, mountPath = path, + mountSubPath = properties.get(subPathKey).getOrElse(""), mountReadOnly = properties.get(readOnlyKey).exists(_.toBoolean), volumeConf = volumeConf ) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala index e60259c4a9b5a..1473a7d3ee7f6 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala @@ -51,6 +51,7 @@ private[spark] class MountVolumesFeatureStep( val volumeMount = new VolumeMountBuilder() .withMountPath(spec.mountPath) .withReadOnly(spec.mountReadOnly) + .withSubPath(spec.mountSubPath) .withName(spec.volumeName) .build() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala index d795d159773a8..de79a58a3a756 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala @@ -33,6 +33,18 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite { KubernetesHostPathVolumeConf("/hostPath")) } + test("Parses subPath correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mount.path", "/path") + sparkConf.set("test.emptyDir.volumeName.mount.readOnly", "true") + sparkConf.set("test.emptyDir.volumeName.mount.subPath", "subPath") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountSubPath === "subPath") + } + test("Parses persistentVolumeClaim volumes correctly") { val sparkConf = new SparkConf(false) sparkConf.set("test.persistentVolumeClaim.volumeName.mount.path", "/path") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala index 2a957460ca8e0..aadbf16897f46 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala @@ -43,6 +43,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { val volumeConf = KubernetesVolumeSpec( "testVolume", "/tmp", + "", false, KubernetesHostPathVolumeConf("/hostPath/tmp") ) @@ -62,6 +63,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { val volumeConf = KubernetesVolumeSpec( "testVolume", "/tmp", + "", true, KubernetesPVCVolumeConf("pvcClaim") ) @@ -83,6 +85,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { val volumeConf = KubernetesVolumeSpec( "testVolume", "/tmp", + "", false, KubernetesEmptyDirVolumeConf(Some("Memory"), Some("6G")) ) @@ -104,6 +107,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { val volumeConf = KubernetesVolumeSpec( "testVolume", "/tmp", + "", false, KubernetesEmptyDirVolumeConf(None, None) ) @@ -125,12 +129,14 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { val hpVolumeConf = KubernetesVolumeSpec( "hpVolume", "/tmp", + "", false, KubernetesHostPathVolumeConf("/hostPath/tmp") ) val pvcVolumeConf = KubernetesVolumeSpec( "checkpointVolume", "/checkpoints", + "", true, KubernetesPVCVolumeConf("pvcClaim") ) @@ -142,4 +148,77 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { assert(configuredPod.pod.getSpec.getVolumes.size() === 2) assert(configuredPod.container.getVolumeMounts.size() === 2) } + + test("Mounts subpath on emptyDir") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + "foo", + false, + KubernetesEmptyDirVolumeConf(None, None) + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val emptyDirMount = configuredPod.container.getVolumeMounts.get(0) + assert(emptyDirMount.getMountPath === "/tmp") + assert(emptyDirMount.getName === "testVolume") + assert(emptyDirMount.getSubPath === "foo") + } + + test("Mounts subpath on persistentVolumeClaims") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + "bar", + true, + KubernetesPVCVolumeConf("pvcClaim") + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val pvcClaim = configuredPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim + assert(pvcClaim.getClaimName === "pvcClaim") + assert(configuredPod.container.getVolumeMounts.size() === 1) + val pvcMount = configuredPod.container.getVolumeMounts.get(0) + assert(pvcMount.getMountPath === "/tmp") + assert(pvcMount.getName === "testVolume") + assert(pvcMount.getSubPath === "bar") + } + + test("Mounts multiple subpaths") { + val volumeConf = KubernetesEmptyDirVolumeConf(None, None) + val emptyDirSpec = KubernetesVolumeSpec( + "testEmptyDir", + "/tmp/foo", + "foo", + true, + KubernetesEmptyDirVolumeConf(None, None) + ) + val pvcSpec = KubernetesVolumeSpec( + "testPVC", + "/tmp/bar", + "bar", + true, + KubernetesEmptyDirVolumeConf(None, None) + ) + val kubernetesConf = emptyKubernetesConf.copy( + roleVolumes = emptyDirSpec :: pvcSpec :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 2) + val mounts = configuredPod.container.getVolumeMounts + assert(mounts.size() === 2) + assert(mounts.get(0).getName === "testEmptyDir") + assert(mounts.get(0).getMountPath === "/tmp/foo") + assert(mounts.get(0).getSubPath === "foo") + assert(mounts.get(1).getName === "testPVC") + assert(mounts.get(1).getMountPath === "/tmp/bar") + assert(mounts.get(1).getSubPath === "bar") + } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index fe900fda6e545..3708864592d75 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -140,6 +140,40 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { val volumeSpec = KubernetesVolumeSpec( "volume", "/tmp", + "", + false, + KubernetesHostPathVolumeConf("/path")) + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + JavaMainAppResource(None), + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + volumeSpec :: Nil, + hadoopConfSpec = None) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + MOUNT_VOLUMES_STEP_TYPE, + DRIVER_CMD_STEP_TYPE) + } + + test("Apply volumes step if a mount subpath is present.") { + val volumeSpec = KubernetesVolumeSpec( + "volume", + "/tmp", + "foo", false, KubernetesHostPathVolumeConf("/path")) val conf = KubernetesConf( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index 1fea08c37ccc6..a59f6d072023e 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -107,6 +107,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { val volumeSpec = KubernetesVolumeSpec( "volume", "/tmp", + "", false, KubernetesHostPathVolumeConf("/checkpoint")) val conf = KubernetesConf( From 76ef02e499db49c0c6a37fa9dff3d731aeac9898 Mon Sep 17 00:00:00 2001 From: pgandhi Date: Mon, 26 Nov 2018 14:08:32 -0600 Subject: [PATCH 2140/2461] [SPARK-21809] Change Stage Page to use datatables to support sorting columns and searching Support column sort, pagination and search for Stage Page using jQuery DataTable and REST API. Before this commit, the Stage page generated a hard-coded HTML table that could not support search. Supporting search and sort (over all applications rather than the 20 entries in the current page) in any case will greatly improve the user experience. Created the stagespage-template.html for displaying application information in datables. Added REST api endpoint and javascript code to fetch data from the endpoint and display it on the data table. Because of the above change, certain functionalities in the page had to be modified to support the addition of datatables. For example, the toggle checkbox 'Select All' previously would add the checked fields as columns in the Task table and as rows in the Summary Metrics table, but after the change, only columns are added in the Task Table as it got tricky to add rows dynamically in the datatables. ## How was this patch tested? I have attached the screenshots of the Stage Page UI before and after the fix. **Before:** 30564304-35991e1c-9c8a-11e7-850f-2ac7a347f600 31360592-cbaa2bae-ad14-11e7-941d-95b4c7d14970 **After:** 31360591-c5650ee4-ad14-11e7-9665-5a08d8f21830 31360604-d266b6b0-ad14-11e7-94b5-dcc4bb5443f4 Closes #21688 from pgandhi999/SPARK-21809-2.3. Authored-by: pgandhi Signed-off-by: Thomas Graves --- .../ui/static/executorspage-template.html | 8 +- .../apache/spark/ui/static/executorspage.js | 84 +- .../spark/ui/static/images/sort_asc.png | Bin 0 -> 160 bytes .../ui/static/images/sort_asc_disabled.png | Bin 0 -> 148 bytes .../spark/ui/static/images/sort_both.png | Bin 0 -> 201 bytes .../spark/ui/static/images/sort_desc.png | Bin 0 -> 158 bytes .../ui/static/images/sort_desc_disabled.png | Bin 0 -> 146 bytes .../org/apache/spark/ui/static/stagepage.js | 958 ++++++++++++ .../spark/ui/static/stagespage-template.html | 124 ++ .../org/apache/spark/ui/static/utils.js | 113 +- .../spark/ui/static/webui-dataTables.css | 20 + .../org/apache/spark/ui/static/webui.css | 101 ++ .../apache/spark/status/AppStatusStore.scala | 26 +- .../spark/status/api/v1/StagesResource.scala | 121 +- .../org/apache/spark/status/api/v1/api.scala | 5 +- .../org/apache/spark/status/storeTypes.scala | 5 +- .../scala/org/apache/spark/ui/UIUtils.scala | 2 + .../apache/spark/ui/jobs/ExecutorTable.scala | 149 -- .../org/apache/spark/ui/jobs/StagePage.scala | 325 +---- .../blacklisting_for_stage_expectation.json | 1287 +++++++++-------- ...acklisting_node_for_stage_expectation.json | 112 +- .../one_stage_attempt_json_expectation.json | 40 +- .../one_stage_json_expectation.json | 40 +- .../stage_task_list_expectation.json | 100 +- ...multi_attempt_app_json_1__expectation.json | 40 +- ...multi_attempt_app_json_2__expectation.json | 40 +- ...k_list_w__offset___length_expectation.json | 250 +++- ...stage_task_list_w__sortBy_expectation.json | 100 +- ...tBy_short_names___runtime_expectation.json | 100 +- ...rtBy_short_names__runtime_expectation.json | 100 +- ...age_with_accumulable_json_expectation.json | 150 +- .../spark/status/AppStatusUtilsSuite.scala | 10 +- .../org/apache/spark/ui/StagePageSuite.scala | 12 - 33 files changed, 3064 insertions(+), 1358 deletions(-) create mode 100644 core/src/main/resources/org/apache/spark/ui/static/images/sort_asc.png create mode 100644 core/src/main/resources/org/apache/spark/ui/static/images/sort_asc_disabled.png create mode 100644 core/src/main/resources/org/apache/spark/ui/static/images/sort_both.png create mode 100644 core/src/main/resources/org/apache/spark/ui/static/images/sort_desc.png create mode 100644 core/src/main/resources/org/apache/spark/ui/static/images/sort_desc_disabled.png create mode 100644 core/src/main/resources/org/apache/spark/ui/static/stagepage.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/stagespage-template.html create mode 100644 core/src/main/resources/org/apache/spark/ui/static/webui-dataTables.css delete mode 100644 core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html b/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html index 5c91304e49fd7..f2c17aef097a4 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html @@ -16,10 +16,10 @@ --> diff --git a/core/src/main/resources/org/apache/spark/ui/static/utils.js b/core/src/main/resources/org/apache/spark/ui/static/utils.js index 4f63f6413d6de..deeafad4eb5f5 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/utils.js +++ b/core/src/main/resources/org/apache/spark/ui/static/utils.js @@ -18,7 +18,7 @@ // this function works exactly the same as UIUtils.formatDuration function formatDuration(milliseconds) { if (milliseconds < 100) { - return milliseconds + " ms"; + return parseInt(milliseconds).toFixed(1) + " ms"; } var seconds = milliseconds * 1.0 / 1000; if (seconds < 1) { @@ -74,3 +74,114 @@ function getTimeZone() { return new Date().toString().match(/\((.*)\)/)[1]; } } + +function formatLogsCells(execLogs, type) { + if (type !== 'display') return Object.keys(execLogs); + if (!execLogs) return; + var result = ''; + $.each(execLogs, function (logName, logUrl) { + result += '' + }); + return result; +} + +function getStandAloneAppId(cb) { + var words = document.baseURI.split('/'); + var ind = words.indexOf("proxy"); + if (ind > 0) { + var appId = words[ind + 1]; + cb(appId); + return; + } + ind = words.indexOf("history"); + if (ind > 0) { + var appId = words[ind + 1]; + cb(appId); + return; + } + // Looks like Web UI is running in standalone mode + // Let's get application-id using REST End Point + $.getJSON(location.origin + "/api/v1/applications", function(response, status, jqXHR) { + if (response && response.length > 0) { + var appId = response[0].id; + cb(appId); + return; + } + }); +} + +// This function is a helper function for sorting in datatable. +// When the data is in duration (e.g. 12ms 2s 2min 2h ) +// It will convert the string into integer for correct ordering +function ConvertDurationString(data) { + data = data.toString(); + var units = data.replace(/[\d\.]/g, '' ) + .replace(' ', '') + .toLowerCase(); + var multiplier = 1; + + switch(units) { + case 's': + multiplier = 1000; + break; + case 'min': + multiplier = 600000; + break; + case 'h': + multiplier = 3600000; + break; + default: + break; + } + return parseFloat(data) * multiplier; +} + +function createTemplateURI(appId, templateName) { + var words = document.baseURI.split('/'); + var ind = words.indexOf("proxy"); + if (ind > 0) { + var baseURI = words.slice(0, ind + 1).join('/') + '/' + appId + '/static/' + templateName + '-template.html'; + return baseURI; + } + ind = words.indexOf("history"); + if(ind > 0) { + var baseURI = words.slice(0, ind).join('/') + '/static/' + templateName + '-template.html'; + return baseURI; + } + return location.origin + "/static/" + templateName + "-template.html"; +} + +function setDataTableDefaults() { + $.extend($.fn.dataTable.defaults, { + stateSave: true, + lengthMenu: [[20, 40, 60, 100, -1], [20, 40, 60, 100, "All"]], + pageLength: 20 + }); +} + +function formatDate(date) { + if (date <= 0) return "-"; + else return date.split(".")[0].replace("T", " "); +} + +function createRESTEndPointForExecutorsPage(appId) { + var words = document.baseURI.split('/'); + var ind = words.indexOf("proxy"); + if (ind > 0) { + var appId = words[ind + 1]; + var newBaseURI = words.slice(0, ind + 2).join('/'); + return newBaseURI + "/api/v1/applications/" + appId + "/allexecutors" + } + ind = words.indexOf("history"); + if (ind > 0) { + var appId = words[ind + 1]; + var attemptId = words[ind + 2]; + var newBaseURI = words.slice(0, ind).join('/'); + if (isNaN(attemptId)) { + return newBaseURI + "/api/v1/applications/" + appId + "/allexecutors"; + } else { + return newBaseURI + "/api/v1/applications/" + appId + "/" + attemptId + "/allexecutors"; + } + } + return location.origin + "/api/v1/applications/" + appId + "/allexecutors"; +} diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui-dataTables.css b/core/src/main/resources/org/apache/spark/ui/static/webui-dataTables.css new file mode 100644 index 0000000000000..f6b4abed21e0d --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/webui-dataTables.css @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +table.dataTable thead .sorting_asc { background: url('images/sort_asc.png') no-repeat bottom right; } + +table.dataTable thead .sorting_desc { background: url('images/sort_desc.png') no-repeat bottom right; } \ No newline at end of file diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 266eeec55576e..fe5bb25687af1 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -260,4 +260,105 @@ a.expandbutton { .paginate_button.active > a { color: #999999; text-decoration: underline; +} + +.title-table { + clear: left; + display: inline-block; +} + +.table-dataTable { + width: 100%; +} + +.container-fluid-div { + width: 200px; +} + +.scheduler-delay-checkbox-div { + width: 120px; +} + +.task-deserialization-time-checkbox-div { + width: 175px; +} + +.shuffle-read-blocked-time-checkbox-div { + width: 187px; +} + +.shuffle-remote-reads-checkbox-div { + width: 157px; +} + +.result-serialization-time-checkbox-div { + width: 171px; +} + +.getting-result-time-checkbox-div { + width: 141px; +} + +.peak-execution-memory-checkbox-div { + width: 170px; +} + +#active-tasks-table th { + border-top: 1px solid #dddddd; + border-bottom: 1px solid #dddddd; + border-right: 1px solid #dddddd; +} + +#active-tasks-table th:first-child { + border-left: 1px solid #dddddd; +} + +#accumulator-table th { + border-top: 1px solid #dddddd; + border-bottom: 1px solid #dddddd; + border-right: 1px solid #dddddd; +} + +#accumulator-table th:first-child { + border-left: 1px solid #dddddd; +} + +#summary-executor-table th { + border-top: 1px solid #dddddd; + border-bottom: 1px solid #dddddd; + border-right: 1px solid #dddddd; +} + +#summary-executor-table th:first-child { + border-left: 1px solid #dddddd; +} + +#summary-metrics-table th { + border-top: 1px solid #dddddd; + border-bottom: 1px solid #dddddd; + border-right: 1px solid #dddddd; +} + +#summary-metrics-table th:first-child { + border-left: 1px solid #dddddd; +} + +#summary-execs-table th { + border-top: 1px solid #dddddd; + border-bottom: 1px solid #dddddd; + border-right: 1px solid #dddddd; +} + +#summary-execs-table th:first-child { + border-left: 1px solid #dddddd; +} + +#active-executors-table th { + border-top: 1px solid #dddddd; + border-bottom: 1px solid #dddddd; + border-right: 1px solid #dddddd; +} + +#active-executors-table th:first-child { + border-left: 1px solid #dddddd; } \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 63b9d8988499d..5c0ed4d5d8f4c 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -351,7 +351,9 @@ private[spark] class AppStatusStore( def taskList(stageId: Int, stageAttemptId: Int, maxTasks: Int): Seq[v1.TaskData] = { val stageKey = Array(stageId, stageAttemptId) store.view(classOf[TaskDataWrapper]).index("stage").first(stageKey).last(stageKey).reverse() - .max(maxTasks).asScala.map(_.toApi).toSeq.reverse + .max(maxTasks).asScala.map { taskDataWrapper => + constructTaskData(taskDataWrapper) + }.toSeq.reverse } def taskList( @@ -390,7 +392,9 @@ private[spark] class AppStatusStore( } val ordered = if (ascending) indexed else indexed.reverse() - ordered.skip(offset).max(length).asScala.map(_.toApi).toSeq + ordered.skip(offset).max(length).asScala.map { taskDataWrapper => + constructTaskData(taskDataWrapper) + }.toSeq } def executorSummary(stageId: Int, attemptId: Int): Map[String, v1.ExecutorStageSummary] = { @@ -496,6 +500,24 @@ private[spark] class AppStatusStore( store.close() } + def constructTaskData(taskDataWrapper: TaskDataWrapper) : v1.TaskData = { + val taskDataOld: v1.TaskData = taskDataWrapper.toApi + val executorLogs: Option[Map[String, String]] = try { + Some(executorSummary(taskDataOld.executorId).executorLogs) + } catch { + case e: NoSuchElementException => e.getMessage + None + } + new v1.TaskData(taskDataOld.taskId, taskDataOld.index, + taskDataOld.attempt, taskDataOld.launchTime, taskDataOld.resultFetchStart, + taskDataOld.duration, taskDataOld.executorId, taskDataOld.host, taskDataOld.status, + taskDataOld.taskLocality, taskDataOld.speculative, taskDataOld.accumulatorUpdates, + taskDataOld.errorMessage, taskDataOld.taskMetrics, + executorLogs.getOrElse(Map[String, String]()), + AppStatusUtils.schedulerDelay(taskDataOld), + AppStatusUtils.gettingResultTime(taskDataOld)) + } + } private[spark] object AppStatusStore { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala index 30d52b97833e6..f81892734c2de 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala @@ -16,15 +16,16 @@ */ package org.apache.spark.status.api.v1 -import java.util.{List => JList} +import java.util.{HashMap, List => JList, Locale} import javax.ws.rs._ -import javax.ws.rs.core.MediaType +import javax.ws.rs.core.{Context, MediaType, MultivaluedMap, UriInfo} import org.apache.spark.SparkException import org.apache.spark.scheduler.StageInfo import org.apache.spark.status.api.v1.StageStatus._ import org.apache.spark.status.api.v1.TaskSorting._ import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.jobs.ApiHelper._ @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class StagesResource extends BaseAppResource { @@ -102,4 +103,120 @@ private[v1] class StagesResource extends BaseAppResource { withUI(_.store.taskList(stageId, stageAttemptId, offset, length, sortBy)) } + // This api needs to stay formatted exactly as it is below, since, it is being used by the + // datatables for the stages page. + @GET + @Path("{stageId: \\d+}/{stageAttemptId: \\d+}/taskTable") + def taskTable( + @PathParam("stageId") stageId: Int, + @PathParam("stageAttemptId") stageAttemptId: Int, + @QueryParam("details") @DefaultValue("true") details: Boolean, + @Context uriInfo: UriInfo): + HashMap[String, Object] = { + withUI { ui => + val uriQueryParameters = uriInfo.getQueryParameters(true) + val totalRecords = uriQueryParameters.getFirst("numTasks") + var isSearch = false + var searchValue: String = null + var filteredRecords = totalRecords + // The datatables client API sends a list of query parameters to the server which contain + // information like the columns to be sorted, search value typed by the user in the search + // box, pagination index etc. For more information on these query parameters, + // refer https://datatables.net/manual/server-side. + if (uriQueryParameters.getFirst("search[value]") != null && + uriQueryParameters.getFirst("search[value]").length > 0) { + isSearch = true + searchValue = uriQueryParameters.getFirst("search[value]") + } + val _tasksToShow: Seq[TaskData] = doPagination(uriQueryParameters, stageId, stageAttemptId, + isSearch, totalRecords.toInt) + val ret = new HashMap[String, Object]() + if (_tasksToShow.nonEmpty) { + // Performs server-side search based on input from user + if (isSearch) { + val filteredTaskList = filterTaskList(_tasksToShow, searchValue) + filteredRecords = filteredTaskList.length.toString + if (filteredTaskList.length > 0) { + val pageStartIndex = uriQueryParameters.getFirst("start").toInt + val pageLength = uriQueryParameters.getFirst("length").toInt + ret.put("aaData", filteredTaskList.slice( + pageStartIndex, pageStartIndex + pageLength)) + } else { + ret.put("aaData", filteredTaskList) + } + } else { + ret.put("aaData", _tasksToShow) + } + } else { + ret.put("aaData", _tasksToShow) + } + ret.put("recordsTotal", totalRecords) + ret.put("recordsFiltered", filteredRecords) + ret + } + } + + // Performs pagination on the server side + def doPagination(queryParameters: MultivaluedMap[String, String], stageId: Int, + stageAttemptId: Int, isSearch: Boolean, totalRecords: Int): Seq[TaskData] = { + var columnNameToSort = queryParameters.getFirst("columnNameToSort") + // Sorting on Logs column will default to Index column sort + if (columnNameToSort.equalsIgnoreCase("Logs")) { + columnNameToSort = "Index" + } + val isAscendingStr = queryParameters.getFirst("order[0][dir]") + var pageStartIndex = 0 + var pageLength = totalRecords + // We fetch only the desired rows upto the specified page length for all cases except when a + // search query is present, in that case, we need to fetch all the rows to perform the search + // on the entire table + if (!isSearch) { + pageStartIndex = queryParameters.getFirst("start").toInt + pageLength = queryParameters.getFirst("length").toInt + } + withUI(_.store.taskList(stageId, stageAttemptId, pageStartIndex, pageLength, + indexName(columnNameToSort), isAscendingStr.equalsIgnoreCase("asc"))) + } + + // Filters task list based on search parameter + def filterTaskList( + taskDataList: Seq[TaskData], + searchValue: String): Seq[TaskData] = { + val defaultOptionString: String = "d" + val searchValueLowerCase = searchValue.toLowerCase(Locale.ROOT) + val containsValue = (taskDataParams: Any) => taskDataParams.toString.toLowerCase( + Locale.ROOT).contains(searchValueLowerCase) + val taskMetricsContainsValue = (task: TaskData) => task.taskMetrics match { + case None => false + case Some(metrics) => + (containsValue(task.taskMetrics.get.executorDeserializeTime) + || containsValue(task.taskMetrics.get.executorRunTime) + || containsValue(task.taskMetrics.get.jvmGcTime) + || containsValue(task.taskMetrics.get.resultSerializationTime) + || containsValue(task.taskMetrics.get.memoryBytesSpilled) + || containsValue(task.taskMetrics.get.diskBytesSpilled) + || containsValue(task.taskMetrics.get.peakExecutionMemory) + || containsValue(task.taskMetrics.get.inputMetrics.bytesRead) + || containsValue(task.taskMetrics.get.inputMetrics.recordsRead) + || containsValue(task.taskMetrics.get.outputMetrics.bytesWritten) + || containsValue(task.taskMetrics.get.outputMetrics.recordsWritten) + || containsValue(task.taskMetrics.get.shuffleReadMetrics.fetchWaitTime) + || containsValue(task.taskMetrics.get.shuffleReadMetrics.recordsRead) + || containsValue(task.taskMetrics.get.shuffleWriteMetrics.bytesWritten) + || containsValue(task.taskMetrics.get.shuffleWriteMetrics.recordsWritten) + || containsValue(task.taskMetrics.get.shuffleWriteMetrics.writeTime)) + } + val filteredTaskDataSequence: Seq[TaskData] = taskDataList.filter(f => + (containsValue(f.taskId) || containsValue(f.index) || containsValue(f.attempt) + || containsValue(f.launchTime) + || containsValue(f.resultFetchStart.getOrElse(defaultOptionString)) + || containsValue(f.duration.getOrElse(defaultOptionString)) + || containsValue(f.executorId) || containsValue(f.host) || containsValue(f.status) + || containsValue(f.taskLocality) || containsValue(f.speculative) + || containsValue(f.errorMessage.getOrElse(defaultOptionString)) + || taskMetricsContainsValue(f) + || containsValue(f.schedulerDelay) || containsValue(f.gettingResultTime))) + filteredTaskDataSequence + } + } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 30afd8b769720..aa21da2b66ab2 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -253,7 +253,10 @@ class TaskData private[spark]( val speculative: Boolean, val accumulatorUpdates: Seq[AccumulableInfo], val errorMessage: Option[String] = None, - val taskMetrics: Option[TaskMetrics] = None) + val taskMetrics: Option[TaskMetrics] = None, + val executorLogs: Map[String, String], + val schedulerDelay: Long, + val gettingResultTime: Long) class TaskMetrics private[spark]( val executorDeserializeTime: Long, diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index 646cf25880e37..ef19e86f3135f 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -283,7 +283,10 @@ private[spark] class TaskDataWrapper( speculative, accumulatorUpdates, errorMessage, - metrics) + metrics, + executorLogs = null, + schedulerDelay = 0L, + gettingResultTime = 0L) } @JsonIgnore @KVIndex(TaskIndexNames.STAGE) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 3aed4647a96f0..60a929375baae 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -204,6 +204,8 @@ private[spark] object UIUtils extends Logging { href={prependBaseUri(request, "/static/dataTables.bootstrap.css")} type="text/css"/> + diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala deleted file mode 100644 index 1be81e5ef9952..0000000000000 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ui.jobs - -import scala.xml.{Node, Unparsed} - -import org.apache.spark.status.AppStatusStore -import org.apache.spark.status.api.v1.StageData -import org.apache.spark.ui.{ToolTips, UIUtils} -import org.apache.spark.util.Utils - -/** Stage summary grouped by executors. */ -private[ui] class ExecutorTable(stage: StageData, store: AppStatusStore) { - - import ApiHelper._ - - def toNodeSeq: Seq[Node] = { - - - - - - - - - - {if (hasInput(stage)) { - - }} - {if (hasOutput(stage)) { - - }} - {if (hasShuffleRead(stage)) { - - }} - {if (hasShuffleWrite(stage)) { - - }} - {if (hasBytesSpilled(stage)) { - - - }} - - - - - {createExecutorTable(stage)} - -
      Executor IDAddressTask TimeTotal TasksFailed TasksKilled TasksSucceeded Tasks - Input Size / Records - - Output Size / Records - - - Shuffle Read Size / Records - - - Shuffle Write Size / Records - Shuffle Spill (Memory)Shuffle Spill (Disk) - - Blacklisted - - Logs
      - - } - - private def createExecutorTable(stage: StageData) : Seq[Node] = { - val executorSummary = store.executorSummary(stage.stageId, stage.attemptId) - - executorSummary.toSeq.sortBy(_._1).map { case (k, v) => - val executor = store.asOption(store.executorSummary(k)) - - {k} - {executor.map { e => e.hostPort }.getOrElse("CANNOT FIND ADDRESS")} - {UIUtils.formatDuration(v.taskTime)} - {v.failedTasks + v.succeededTasks + v.killedTasks} - {v.failedTasks} - {v.killedTasks} - {v.succeededTasks} - {if (hasInput(stage)) { - - {s"${Utils.bytesToString(v.inputBytes)} / ${v.inputRecords}"} - - }} - {if (hasOutput(stage)) { - - {s"${Utils.bytesToString(v.outputBytes)} / ${v.outputRecords}"} - - }} - {if (hasShuffleRead(stage)) { - - {s"${Utils.bytesToString(v.shuffleRead)} / ${v.shuffleReadRecords}"} - - }} - {if (hasShuffleWrite(stage)) { - - {s"${Utils.bytesToString(v.shuffleWrite)} / ${v.shuffleWriteRecords}"} - - }} - {if (hasBytesSpilled(stage)) { - - {Utils.bytesToString(v.memoryBytesSpilled)} - - - {Utils.bytesToString(v.diskBytesSpilled)} - - }} - { - if (executor.map(_.isBlacklisted).getOrElse(false)) { - for application - } else if (v.isBlacklistedForStage) { - for stage - } else { - false - } - } - {executor.map(_.executorLogs).getOrElse(Map.empty).map { - case (logName, logUrl) => - }} - - - - } - } - -} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 2b436b9234144..a213b764abea7 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -92,6 +92,14 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val parameterTaskSortDesc = UIUtils.stripXSS(request.getParameter("task.desc")) val parameterTaskPageSize = UIUtils.stripXSS(request.getParameter("task.pageSize")) + val eventTimelineParameterTaskPage = UIUtils.stripXSS( + request.getParameter("task.eventTimelinePageNumber")) + val eventTimelineParameterTaskPageSize = UIUtils.stripXSS( + request.getParameter("task.eventTimelinePageSize")) + var eventTimelineTaskPage = Option(eventTimelineParameterTaskPage).map(_.toInt).getOrElse(1) + var eventTimelineTaskPageSize = Option( + eventTimelineParameterTaskPageSize).map(_.toInt).getOrElse(100) + val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1) val taskSortColumn = Option(parameterTaskSortColumn).map { sortColumn => UIUtils.decodeURLParameter(sortColumn) @@ -132,6 +140,14 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We } else { s"$totalTasks, showing $storedTasks" } + if (eventTimelineTaskPageSize < 1 || eventTimelineTaskPageSize > totalTasks) { + eventTimelineTaskPageSize = totalTasks + } + val eventTimelineTotalPages = + (totalTasks + eventTimelineTaskPageSize - 1) / eventTimelineTaskPageSize + if (eventTimelineTaskPage < 1 || eventTimelineTaskPage > eventTimelineTotalPages) { + eventTimelineTaskPage = 1 + } val summary =
      @@ -193,73 +209,6 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
      - val showAdditionalMetrics = -
      - - - Show Additional Metrics - - -
      - val stageGraph = parent.store.asOption(parent.store.operationGraphForStage(stageId)) val dagViz = UIUtils.showDagVizForStage(stageId, stageGraph) @@ -277,7 +226,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We stageData.accumulatorUpdates.toSeq) val currentTime = System.currentTimeMillis() - val (taskTable, taskTableHTML) = try { + val taskTable = try { val _taskTable = new TaskPagedTable( stageData, UIUtils.prependBaseUri(request, parent.basePath) + @@ -288,17 +237,10 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We desc = taskSortDesc, store = parent.store ) - (_taskTable, _taskTable.table(taskPage)) + _taskTable } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => - val errorMessage = -
      -

      Error while rendering stage table:

      -
      -              {Utils.exceptionString(e)}
      -            
      -
      - (null, errorMessage) + null } val jsForScrollingDownToTaskTable = @@ -316,190 +258,36 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We } - val metricsSummary = store.taskSummary(stageData.stageId, stageData.attemptId, - Array(0, 0.25, 0.5, 0.75, 1.0)) - - val summaryTable = metricsSummary.map { metrics => - def timeQuantiles(data: IndexedSeq[Double]): Seq[Node] = { - data.map { millis => - {UIUtils.formatDuration(millis.toLong)} - } - } - - def sizeQuantiles(data: IndexedSeq[Double]): Seq[Node] = { - data.map { size => - {Utils.bytesToString(size.toLong)} - } - } - - def sizeQuantilesWithRecords( - data: IndexedSeq[Double], - records: IndexedSeq[Double]) : Seq[Node] = { - data.zip(records).map { case (d, r) => - {s"${Utils.bytesToString(d.toLong)} / ${r.toLong}"} - } - } - - def titleCell(title: String, tooltip: String): Seq[Node] = { - - - {title} - - - } - - def simpleTitleCell(title: String): Seq[Node] = {title} - - val deserializationQuantiles = titleCell("Task Deserialization Time", - ToolTips.TASK_DESERIALIZATION_TIME) ++ timeQuantiles(metrics.executorDeserializeTime) - - val serviceQuantiles = simpleTitleCell("Duration") ++ timeQuantiles(metrics.executorRunTime) - - val gcQuantiles = titleCell("GC Time", ToolTips.GC_TIME) ++ timeQuantiles(metrics.jvmGcTime) - - val serializationQuantiles = titleCell("Result Serialization Time", - ToolTips.RESULT_SERIALIZATION_TIME) ++ timeQuantiles(metrics.resultSerializationTime) - - val gettingResultQuantiles = titleCell("Getting Result Time", ToolTips.GETTING_RESULT_TIME) ++ - timeQuantiles(metrics.gettingResultTime) - - val peakExecutionMemoryQuantiles = titleCell("Peak Execution Memory", - ToolTips.PEAK_EXECUTION_MEMORY) ++ sizeQuantiles(metrics.peakExecutionMemory) - - // The scheduler delay includes the network delay to send the task to the worker - // machine and to send back the result (but not the time to fetch the task result, - // if it needed to be fetched from the block manager on the worker). - val schedulerDelayQuantiles = titleCell("Scheduler Delay", ToolTips.SCHEDULER_DELAY) ++ - timeQuantiles(metrics.schedulerDelay) - - def inputQuantiles: Seq[Node] = { - simpleTitleCell("Input Size / Records") ++ - sizeQuantilesWithRecords(metrics.inputMetrics.bytesRead, metrics.inputMetrics.recordsRead) - } - - def outputQuantiles: Seq[Node] = { - simpleTitleCell("Output Size / Records") ++ - sizeQuantilesWithRecords(metrics.outputMetrics.bytesWritten, - metrics.outputMetrics.recordsWritten) - } - - def shuffleReadBlockedQuantiles: Seq[Node] = { - titleCell("Shuffle Read Blocked Time", ToolTips.SHUFFLE_READ_BLOCKED_TIME) ++ - timeQuantiles(metrics.shuffleReadMetrics.fetchWaitTime) - } - - def shuffleReadTotalQuantiles: Seq[Node] = { - titleCell("Shuffle Read Size / Records", ToolTips.SHUFFLE_READ) ++ - sizeQuantilesWithRecords(metrics.shuffleReadMetrics.readBytes, - metrics.shuffleReadMetrics.readRecords) - } - - def shuffleReadRemoteQuantiles: Seq[Node] = { - titleCell("Shuffle Remote Reads", ToolTips.SHUFFLE_READ_REMOTE_SIZE) ++ - sizeQuantiles(metrics.shuffleReadMetrics.remoteBytesRead) - } - - def shuffleWriteQuantiles: Seq[Node] = { - simpleTitleCell("Shuffle Write Size / Records") ++ - sizeQuantilesWithRecords(metrics.shuffleWriteMetrics.writeBytes, - metrics.shuffleWriteMetrics.writeRecords) - } - - def memoryBytesSpilledQuantiles: Seq[Node] = { - simpleTitleCell("Shuffle spill (memory)") ++ sizeQuantiles(metrics.memoryBytesSpilled) - } - - def diskBytesSpilledQuantiles: Seq[Node] = { - simpleTitleCell("Shuffle spill (disk)") ++ sizeQuantiles(metrics.diskBytesSpilled) - } - - val listings: Seq[Seq[Node]] = Seq( - {serviceQuantiles}, - {schedulerDelayQuantiles}, - - {deserializationQuantiles} - - {gcQuantiles}, - - {serializationQuantiles} - , - {gettingResultQuantiles}, - - {peakExecutionMemoryQuantiles} - , - if (hasInput(stageData)) {inputQuantiles} else Nil, - if (hasOutput(stageData)) {outputQuantiles} else Nil, - if (hasShuffleRead(stageData)) { - - {shuffleReadBlockedQuantiles} - - {shuffleReadTotalQuantiles} - - {shuffleReadRemoteQuantiles} - - } else { - Nil - }, - if (hasShuffleWrite(stageData)) {shuffleWriteQuantiles} else Nil, - if (hasBytesSpilled(stageData)) {memoryBytesSpilledQuantiles} else Nil, - if (hasBytesSpilled(stageData)) {diskBytesSpilledQuantiles} else Nil) - - val quantileHeaders = Seq("Metric", "Min", "25th percentile", "Median", "75th percentile", - "Max") - // The summary table does not use CSS to stripe rows, which doesn't work with hidden - // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows). - UIUtils.listingTable( - quantileHeaders, - identity[Seq[Node]], - listings, - fixedWidth = true, - id = Some("task-summary-table"), - stripeRowsWithCss = false) - } - - val executorTable = new ExecutorTable(stageData, parent.store) - - val maybeAccumulableTable: Seq[Node] = - if (hasAccumulators(stageData)) {

      Accumulators

      ++ accumulableTable } else Seq() - - val aggMetrics = - -

      - - Aggregated Metrics by Executor -

      -
      -
      - {executorTable.toNodeSeq} -
      - val content = summary ++ - dagViz ++ - showAdditionalMetrics ++ + dagViz ++
      ++ makeTimeline( // Only show the tasks in the table - Option(taskTable).map(_.dataSource.tasks).getOrElse(Nil), - currentTime) ++ -

      Summary Metrics for {numCompleted} Completed Tasks

      ++ -
      {summaryTable.getOrElse("No tasks have reported metrics yet.")}
      ++ - aggMetrics ++ - maybeAccumulableTable ++ - -

      - - Tasks ({totalTasksNumStr}) -

      -
      ++ -
      - {taskTableHTML ++ jsForScrollingDownToTaskTable} -
      - UIUtils.headerSparkPage(request, stageHeader, content, parent, showVisualization = true) + Option(taskTable).map({ taskPagedTable => + val from = (eventTimelineTaskPage - 1) * eventTimelineTaskPageSize + val to = taskPagedTable.dataSource.dataSize.min( + eventTimelineTaskPage * eventTimelineTaskPageSize) + taskPagedTable.dataSource.sliceData(from, to)}).getOrElse(Nil), currentTime, + eventTimelineTaskPage, eventTimelineTaskPageSize, eventTimelineTotalPages, stageId, + stageAttemptId, totalTasks) ++ +
      + + +
      + UIUtils.headerSparkPage(request, stageHeader, content, parent, showVisualization = true, + useDataTables = true) + } - def makeTimeline(tasks: Seq[TaskData], currentTime: Long): Seq[Node] = { + def makeTimeline( + tasks: Seq[TaskData], + currentTime: Long, + page: Int, + pageSize: Int, + totalPages: Int, + stageId: Int, + stageAttemptId: Int, + totalTasks: Int): Seq[Node] = { val executorsSet = new HashSet[(String, String)] var minLaunchTime = Long.MaxValue var maxFinishTime = Long.MinValue @@ -658,6 +446,31 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We Enable zooming +
      +
      + + + + + + + + + + +
      +
      {TIMELINE_LEGEND} ++ @@ -959,7 +772,7 @@ private[ui] class TaskPagedTable( } } -private[ui] object ApiHelper { +private[spark] object ApiHelper { val HEADER_ID = "ID" val HEADER_TASK_INDEX = "Index" diff --git a/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json b/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json index 5e9e8230e2745..62e5c123fd3d4 100644 --- a/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json @@ -1,639 +1,708 @@ { - "status": "COMPLETE", - "stageId": 0, - "attemptId": 0, - "numTasks": 10, - "numActiveTasks": 0, - "numCompleteTasks": 10, - "numFailedTasks": 2, - "numKilledTasks": 0, - "numCompletedIndices": 10, - "executorRunTime": 761, - "executorCpuTime": 269916000, - "submissionTime": "2018-01-09T10:21:18.152GMT", - "firstTaskLaunchedTime": "2018-01-09T10:21:18.347GMT", - "completionTime": "2018-01-09T10:21:19.062GMT", - "inputBytes": 0, - "inputRecords": 0, - "outputBytes": 0, - "outputRecords": 0, - "shuffleReadBytes": 0, - "shuffleReadRecords": 0, - "shuffleWriteBytes": 460, - "shuffleWriteRecords": 10, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "name": "map at :26", - "details": "org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)", - "schedulingPool": "default", - "rddIds": [ - 1, - 0 - ], - "accumulatorUpdates": [], - "tasks": { - "0": { - "taskId": 0, - "index": 0, - "attempt": 0, - "launchTime": "2018-01-09T10:21:18.347GMT", - "duration": 562, - "executorId": "0", - "host": "172.30.65.138", - "status": "FAILED", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "errorMessage": "java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", - "taskMetrics": { - "executorDeserializeTime": 0, - "executorDeserializeCpuTime": 0, - "executorRunTime": 460, - "executorCpuTime": 0, - "resultSize": 0, - "jvmGcTime": 14, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 0, - "writeTime": 3873006, - "recordsWritten": 0 + "status" : "COMPLETE", + "stageId" : 0, + "attemptId" : 0, + "numTasks" : 10, + "numActiveTasks" : 0, + "numCompleteTasks" : 10, + "numFailedTasks" : 2, + "numKilledTasks" : 0, + "numCompletedIndices" : 10, + "executorRunTime" : 761, + "executorCpuTime" : 269916000, + "submissionTime" : "2018-01-09T10:21:18.152GMT", + "firstTaskLaunchedTime" : "2018-01-09T10:21:18.347GMT", + "completionTime" : "2018-01-09T10:21:19.062GMT", + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleReadBytes" : 0, + "shuffleReadRecords" : 0, + "shuffleWriteBytes" : 460, + "shuffleWriteRecords" : 10, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "name" : "map at :26", + "details" : "org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)", + "schedulingPool" : "default", + "rddIds" : [ 1, 0 ], + "accumulatorUpdates" : [ ], + "tasks" : { + "0" : { + "taskId" : 0, + "index" : 0, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:18.347GMT", + "duration" : 562, + "executorId" : "0", + "host" : "172.30.65.138", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 460, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 14, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 3873006, + "recordsWritten" : 0 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64279/logPage/?appId=app-20180109111548-0000&executorId=0&logType=stdout", + "stderr" : "http://172.30.65.138:64279/logPage/?appId=app-20180109111548-0000&executorId=0&logType=stderr" + }, + "schedulerDelay" : 102, + "gettingResultTime" : 0 }, - "5": { - "taskId": 5, - "index": 3, - "attempt": 0, - "launchTime": "2018-01-09T10:21:18.958GMT", - "duration": 22, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 3, - "executorDeserializeCpuTime": 2586000, - "executorRunTime": 9, - "executorCpuTime": 9635000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 262919, - "recordsWritten": 1 + "5" : { + "taskId" : 5, + "index" : 3, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:18.958GMT", + "duration" : 22, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 2586000, + "executorRunTime" : 9, + "executorCpuTime" : 9635000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 262919, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 10, + "gettingResultTime" : 0 }, - "10": { - "taskId": 10, - "index": 8, - "attempt": 0, - "launchTime": "2018-01-09T10:21:19.034GMT", - "duration": 12, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 2, - "executorDeserializeCpuTime": 1803000, - "executorRunTime": 6, - "executorCpuTime": 6157000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 243647, - "recordsWritten": 1 + "10" : { + "taskId" : 10, + "index" : 8, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:19.034GMT", + "duration" : 12, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 1803000, + "executorRunTime" : 6, + "executorCpuTime" : 6157000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 243647, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, - "1": { - "taskId": 1, - "index": 1, - "attempt": 0, - "launchTime": "2018-01-09T10:21:18.364GMT", - "duration": 565, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 301, - "executorDeserializeCpuTime": 200029000, - "executorRunTime": 212, - "executorCpuTime": 198479000, - "resultSize": 1115, - "jvmGcTime": 13, - "resultSerializationTime": 1, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 2409488, - "recordsWritten": 1 + "1" : { + "taskId" : 1, + "index" : 1, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:18.364GMT", + "duration" : 565, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 301, + "executorDeserializeCpuTime" : 200029000, + "executorRunTime" : 212, + "executorCpuTime" : 198479000, + "resultSize" : 1115, + "jvmGcTime" : 13, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 2409488, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 51, + "gettingResultTime" : 0 }, - "6": { - "taskId": 6, - "index": 4, - "attempt": 0, - "launchTime": "2018-01-09T10:21:18.980GMT", - "duration": 16, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 3, - "executorDeserializeCpuTime": 2610000, - "executorRunTime": 10, - "executorCpuTime": 9622000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 385110, - "recordsWritten": 1 + "6" : { + "taskId" : 6, + "index" : 4, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:18.980GMT", + "duration" : 16, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 2610000, + "executorRunTime" : 10, + "executorCpuTime" : 9622000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 385110, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, - "9": { - "taskId": 9, - "index": 7, - "attempt": 0, - "launchTime": "2018-01-09T10:21:19.022GMT", - "duration": 12, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 2, - "executorDeserializeCpuTime": 1981000, - "executorRunTime": 7, - "executorCpuTime": 6335000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 259354, - "recordsWritten": 1 + "9" : { + "taskId" : 9, + "index" : 7, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:19.022GMT", + "duration" : 12, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 1981000, + "executorRunTime" : 7, + "executorCpuTime" : 6335000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 259354, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, - "2": { - "taskId": 2, - "index": 2, - "attempt": 0, - "launchTime": "2018-01-09T10:21:18.899GMT", - "duration": 27, - "executorId": "0", - "host": "172.30.65.138", - "status": "FAILED", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "errorMessage": "java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", - "taskMetrics": { - "executorDeserializeTime": 0, - "executorDeserializeCpuTime": 0, - "executorRunTime": 16, - "executorCpuTime": 0, - "resultSize": 0, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 0, - "writeTime": 126128, - "recordsWritten": 0 + "2" : { + "taskId" : 2, + "index" : 2, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:18.899GMT", + "duration" : 27, + "executorId" : "0", + "host" : "172.30.65.138", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 16, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 126128, + "recordsWritten" : 0 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64279/logPage/?appId=app-20180109111548-0000&executorId=0&logType=stdout", + "stderr" : "http://172.30.65.138:64279/logPage/?appId=app-20180109111548-0000&executorId=0&logType=stderr" + }, + "schedulerDelay" : 11, + "gettingResultTime" : 0 }, - "7": { - "taskId": 7, - "index": 5, - "attempt": 0, - "launchTime": "2018-01-09T10:21:18.996GMT", - "duration": 15, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 2, - "executorDeserializeCpuTime": 2231000, - "executorRunTime": 9, - "executorCpuTime": 8407000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 205520, - "recordsWritten": 1 + "7" : { + "taskId" : 7, + "index" : 5, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:18.996GMT", + "duration" : 15, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 2231000, + "executorRunTime" : 9, + "executorCpuTime" : 8407000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 205520, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, - "3": { - "taskId": 3, - "index": 0, - "attempt": 1, - "launchTime": "2018-01-09T10:21:18.919GMT", - "duration": 24, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 8, - "executorDeserializeCpuTime": 8878000, - "executorRunTime": 10, - "executorCpuTime": 9364000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 207014, - "recordsWritten": 1 + "3" : { + "taskId" : 3, + "index" : 0, + "attempt" : 1, + "launchTime" : "2018-01-09T10:21:18.919GMT", + "duration" : 24, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 8, + "executorDeserializeCpuTime" : 8878000, + "executorRunTime" : 10, + "executorCpuTime" : 9364000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 207014, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, - "11": { - "taskId": 11, - "index": 9, - "attempt": 0, - "launchTime": "2018-01-09T10:21:19.045GMT", - "duration": 15, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 3, - "executorDeserializeCpuTime": 2017000, - "executorRunTime": 6, - "executorCpuTime": 6676000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 233652, - "recordsWritten": 1 + "11" : { + "taskId" : 11, + "index" : 9, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:19.045GMT", + "duration" : 15, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 2017000, + "executorRunTime" : 6, + "executorCpuTime" : 6676000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 233652, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, - "8": { - "taskId": 8, - "index": 6, - "attempt": 0, - "launchTime": "2018-01-09T10:21:19.011GMT", - "duration": 11, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 1, - "executorDeserializeCpuTime": 1554000, - "executorRunTime": 7, - "executorCpuTime": 6034000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 213296, - "recordsWritten": 1 + "8" : { + "taskId" : 8, + "index" : 6, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:19.011GMT", + "duration" : 11, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 1, + "executorDeserializeCpuTime" : 1554000, + "executorRunTime" : 7, + "executorCpuTime" : 6034000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 213296, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, - "4": { - "taskId": 4, - "index": 2, - "attempt": 1, - "launchTime": "2018-01-09T10:21:18.943GMT", - "duration": 16, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 2, - "executorDeserializeCpuTime": 2211000, - "executorRunTime": 9, - "executorCpuTime": 9207000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 292381, - "recordsWritten": 1 + "4" : { + "taskId" : 4, + "index" : 2, + "attempt" : 1, + "launchTime" : "2018-01-09T10:21:18.943GMT", + "duration" : 16, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 2211000, + "executorRunTime" : 9, + "executorCpuTime" : 9207000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 292381, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 } }, - "executorSummary": { - "0": { - "taskTime": 589, - "failedTasks": 2, - "succeededTasks": 0, - "killedTasks": 0, - "inputBytes": 0, - "inputRecords": 0, - "outputBytes": 0, - "outputRecords": 0, - "shuffleRead": 0, - "shuffleReadRecords": 0, - "shuffleWrite": 0, - "shuffleWriteRecords": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "isBlacklistedForStage": true + "executorSummary" : { + "0" : { + "taskTime" : 589, + "failedTasks" : 2, + "succeededTasks" : 0, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 0, + "shuffleWriteRecords" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : true }, - "1": { - "taskTime": 708, - "failedTasks": 0, - "succeededTasks": 10, - "killedTasks": 0, - "inputBytes": 0, - "inputRecords": 0, - "outputBytes": 0, - "outputRecords": 0, - "shuffleRead": 0, - "shuffleReadRecords": 0, - "shuffleWrite": 460, - "shuffleWriteRecords": 10, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "isBlacklistedForStage": false + "1" : { + "taskTime" : 708, + "failedTasks" : 0, + "succeededTasks" : 10, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 460, + "shuffleWriteRecords" : 10, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : false } }, - "killedTasksSummary": {} + "killedTasksSummary" : { } } diff --git a/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json b/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json index acd4cc53de6cd..6e46c881b2a21 100644 --- a/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json @@ -74,7 +74,13 @@ "writeTime" : 3662221, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 68, + "gettingResultTime" : 0 }, "5" : { "taskId" : 5, @@ -122,7 +128,13 @@ "writeTime" : 191901, "recordsWritten" : 0 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000007/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000007/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 46, + "gettingResultTime" : 0 }, "10" : { "taskId" : 10, @@ -169,7 +181,13 @@ "writeTime" : 301705, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 50, + "gettingResultTime" : 0 }, "1" : { "taskId" : 1, @@ -217,7 +235,13 @@ "writeTime" : 3075188, "recordsWritten" : 0 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000007/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000007/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 174, + "gettingResultTime" : 0 }, "6" : { "taskId" : 6, @@ -265,7 +289,13 @@ "writeTime" : 183718, "recordsWritten" : 0 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000005/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000005/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 16, + "gettingResultTime" : 0 }, "9" : { "taskId" : 9, @@ -312,7 +342,13 @@ "writeTime" : 366050, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 42, + "gettingResultTime" : 0 }, "13" : { "taskId" : 13, @@ -359,7 +395,13 @@ "writeTime" : 369513, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 26, + "gettingResultTime" : 0 }, "2" : { "taskId" : 2, @@ -406,7 +448,13 @@ "writeTime" : 3322956, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000004/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000004/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 74, + "gettingResultTime" : 0 }, "12" : { "taskId" : 12, @@ -453,7 +501,13 @@ "writeTime" : 319101, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, "7" : { "taskId" : 7, @@ -500,7 +554,13 @@ "writeTime" : 377601, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, "3" : { "taskId" : 3, @@ -547,7 +607,13 @@ "writeTime" : 3587839, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 63, + "gettingResultTime" : 0 }, "11" : { "taskId" : 11, @@ -594,7 +660,13 @@ "writeTime" : 323898, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 12, + "gettingResultTime" : 0 }, "8" : { "taskId" : 8, @@ -641,7 +713,13 @@ "writeTime" : 311940, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 84, + "gettingResultTime" : 0 }, "4" : { "taskId" : 4, @@ -689,7 +767,13 @@ "writeTime" : 16858066, "recordsWritten" : 0 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000005/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000005/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 338, + "gettingResultTime" : 0 } }, "executorSummary" : { diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index 03f886afa5413..aa9471301fe3e 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -74,7 +74,10 @@ "writeTime" : 76000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 19, + "gettingResultTime" : 0 }, "14" : { "taskId" : 14, @@ -121,7 +124,10 @@ "writeTime" : 88000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, "9" : { "taskId" : 9, @@ -168,7 +174,10 @@ "writeTime" : 98000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 17, + "gettingResultTime" : 0 }, "13" : { "taskId" : 13, @@ -215,7 +224,10 @@ "writeTime" : 73000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 14, + "gettingResultTime" : 0 }, "12" : { "taskId" : 12, @@ -262,7 +274,10 @@ "writeTime" : 101000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 17, + "gettingResultTime" : 0 }, "11" : { "taskId" : 11, @@ -309,7 +324,10 @@ "writeTime" : 83000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 17, + "gettingResultTime" : 0 }, "8" : { "taskId" : 8, @@ -356,7 +374,10 @@ "writeTime" : 94000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 16, + "gettingResultTime" : 0 }, "15" : { "taskId" : 15, @@ -403,7 +424,10 @@ "writeTime" : 79000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 } }, "executorSummary" : { diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index 947c89906955d..584803b5e8631 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -74,7 +74,10 @@ "writeTime" : 76000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 19, + "gettingResultTime" : 0 }, "14" : { "taskId" : 14, @@ -121,7 +124,10 @@ "writeTime" : 88000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, "9" : { "taskId" : 9, @@ -168,7 +174,10 @@ "writeTime" : 98000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 17, + "gettingResultTime" : 0 }, "13" : { "taskId" : 13, @@ -215,7 +224,10 @@ "writeTime" : 73000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 14, + "gettingResultTime" : 0 }, "12" : { "taskId" : 12, @@ -262,7 +274,10 @@ "writeTime" : 101000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 17, + "gettingResultTime" : 0 }, "11" : { "taskId" : 11, @@ -309,7 +324,10 @@ "writeTime" : 83000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 17, + "gettingResultTime" : 0 }, "8" : { "taskId" : 8, @@ -356,7 +374,10 @@ "writeTime" : 94000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 16, + "gettingResultTime" : 0 }, "15" : { "taskId" : 15, @@ -403,7 +424,10 @@ "writeTime" : 79000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 } }, "executorSummary" : { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json index a15ee23523365..f859ab6fff240 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json @@ -43,7 +43,10 @@ "writeTime" : 3842811, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 53, + "gettingResultTime" : 0 }, { "taskId" : 1, "index" : 1, @@ -89,7 +92,10 @@ "writeTime" : 3934399, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 40, + "gettingResultTime" : 0 }, { "taskId" : 2, "index" : 2, @@ -135,7 +141,10 @@ "writeTime" : 89885, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 37, + "gettingResultTime" : 0 }, { "taskId" : 3, "index" : 3, @@ -181,7 +190,10 @@ "writeTime" : 1311694, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 41, + "gettingResultTime" : 0 }, { "taskId" : 4, "index" : 4, @@ -227,7 +239,10 @@ "writeTime" : 83022, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 38, + "gettingResultTime" : 0 }, { "taskId" : 5, "index" : 5, @@ -273,7 +288,10 @@ "writeTime" : 3675510, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 33, + "gettingResultTime" : 0 }, { "taskId" : 6, "index" : 6, @@ -319,7 +337,10 @@ "writeTime" : 4016617, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 38, + "gettingResultTime" : 0 }, { "taskId" : 7, "index" : 7, @@ -365,7 +386,10 @@ "writeTime" : 2579051, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 43, + "gettingResultTime" : 0 }, { "taskId" : 8, "index" : 8, @@ -411,7 +435,10 @@ "writeTime" : 121551, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 9, "index" : 9, @@ -457,7 +484,10 @@ "writeTime" : 101664, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 8, + "gettingResultTime" : 0 }, { "taskId" : 10, "index" : 10, @@ -503,7 +533,10 @@ "writeTime" : 94709, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 18, + "gettingResultTime" : 0 }, { "taskId" : 11, "index" : 11, @@ -549,7 +582,10 @@ "writeTime" : 94507, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 8, + "gettingResultTime" : 0 }, { "taskId" : 12, "index" : 12, @@ -595,7 +631,10 @@ "writeTime" : 102476, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 7, + "gettingResultTime" : 0 }, { "taskId" : 13, "index" : 13, @@ -641,7 +680,10 @@ "writeTime" : 95004, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 53, + "gettingResultTime" : 0 }, { "taskId" : 14, "index" : 14, @@ -687,7 +729,10 @@ "writeTime" : 95646, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 15, "index" : 15, @@ -733,7 +778,10 @@ "writeTime" : 602780, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 16, "index" : 16, @@ -779,7 +827,10 @@ "writeTime" : 108320, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 17, "index" : 17, @@ -825,7 +876,10 @@ "writeTime" : 99944, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, { "taskId" : 18, "index" : 18, @@ -871,7 +925,10 @@ "writeTime" : 100836, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 19, "index" : 19, @@ -917,5 +974,8 @@ "writeTime" : 95788, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json index f9182b1658334..ea88ca116707a 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json @@ -48,7 +48,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 30, + "gettingResultTime" : 0 }, { "taskId" : 1, "index" : 1, @@ -99,7 +102,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, { "taskId" : 2, "index" : 2, @@ -150,7 +156,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 18, + "gettingResultTime" : 0 }, { "taskId" : 3, "index" : 3, @@ -201,7 +210,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, { "taskId" : 4, "index" : 4, @@ -252,7 +264,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 24, + "gettingResultTime" : 0 }, { "taskId" : 5, "index" : 5, @@ -303,7 +318,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 23, + "gettingResultTime" : 0 }, { "taskId" : 6, "index" : 6, @@ -354,7 +372,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, { "taskId" : 7, "index" : 7, @@ -405,5 +426,8 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json index 76dd2f710b90f..efd0a45bf01d0 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json @@ -48,7 +48,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 30, + "gettingResultTime" : 0 }, { "taskId" : 1, "index" : 1, @@ -99,7 +102,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, { "taskId" : 2, "index" : 2, @@ -150,7 +156,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 18, + "gettingResultTime" : 0 }, { "taskId" : 3, "index" : 3, @@ -201,7 +210,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, { "taskId" : 4, "index" : 4, @@ -252,7 +264,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 24, + "gettingResultTime" : 0 }, { "taskId" : 5, "index" : 5, @@ -303,7 +318,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 23, + "gettingResultTime" : 0 }, { "taskId" : 6, "index" : 6, @@ -354,7 +372,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, { "taskId" : 7, "index" : 7, @@ -405,5 +426,8 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json index 6bdc10465d89e..d83528d84972c 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json @@ -43,7 +43,10 @@ "writeTime" : 94709, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 18, + "gettingResultTime" : 0 }, { "taskId" : 11, "index" : 11, @@ -89,7 +92,10 @@ "writeTime" : 94507, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 8, + "gettingResultTime" : 0 }, { "taskId" : 12, "index" : 12, @@ -135,7 +141,10 @@ "writeTime" : 102476, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 7, + "gettingResultTime" : 0 }, { "taskId" : 13, "index" : 13, @@ -181,7 +190,10 @@ "writeTime" : 95004, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 53, + "gettingResultTime" : 0 }, { "taskId" : 14, "index" : 14, @@ -227,7 +239,10 @@ "writeTime" : 95646, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 15, "index" : 15, @@ -273,7 +288,10 @@ "writeTime" : 602780, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 16, "index" : 16, @@ -319,7 +337,10 @@ "writeTime" : 108320, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 17, "index" : 17, @@ -365,7 +386,10 @@ "writeTime" : 99944, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, { "taskId" : 18, "index" : 18, @@ -411,7 +435,10 @@ "writeTime" : 100836, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 19, "index" : 19, @@ -457,7 +484,10 @@ "writeTime" : 95788, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 20, "index" : 20, @@ -503,7 +533,10 @@ "writeTime" : 97716, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 21, "index" : 21, @@ -549,7 +582,10 @@ "writeTime" : 100270, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 22, "index" : 22, @@ -595,7 +631,10 @@ "writeTime" : 143427, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 23, "index" : 23, @@ -641,7 +680,10 @@ "writeTime" : 91844, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 16, + "gettingResultTime" : 0 }, { "taskId" : 24, "index" : 24, @@ -687,7 +729,10 @@ "writeTime" : 157194, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 25, "index" : 25, @@ -733,7 +778,10 @@ "writeTime" : 94134, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 9, + "gettingResultTime" : 0 }, { "taskId" : 26, "index" : 26, @@ -779,7 +827,10 @@ "writeTime" : 108213, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 8, + "gettingResultTime" : 0 }, { "taskId" : 27, "index" : 27, @@ -825,7 +876,10 @@ "writeTime" : 102019, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 28, "index" : 28, @@ -871,7 +925,10 @@ "writeTime" : 104299, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 17, + "gettingResultTime" : 0 }, { "taskId" : 29, "index" : 29, @@ -917,7 +974,10 @@ "writeTime" : 114938, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 11, + "gettingResultTime" : 0 }, { "taskId" : 30, "index" : 30, @@ -963,7 +1023,10 @@ "writeTime" : 119770, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 24, + "gettingResultTime" : 0 }, { "taskId" : 31, "index" : 31, @@ -1009,7 +1072,10 @@ "writeTime" : 92619, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 14, + "gettingResultTime" : 0 }, { "taskId" : 32, "index" : 32, @@ -1055,7 +1121,10 @@ "writeTime" : 89603, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, { "taskId" : 33, "index" : 33, @@ -1101,7 +1170,10 @@ "writeTime" : 118329, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 10, + "gettingResultTime" : 0 }, { "taskId" : 34, "index" : 34, @@ -1147,7 +1219,10 @@ "writeTime" : 127746, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 9, + "gettingResultTime" : 0 }, { "taskId" : 35, "index" : 35, @@ -1193,7 +1268,10 @@ "writeTime" : 160963, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 10, + "gettingResultTime" : 0 }, { "taskId" : 36, "index" : 36, @@ -1239,7 +1317,10 @@ "writeTime" : 123855, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 37, "index" : 37, @@ -1285,7 +1366,10 @@ "writeTime" : 111869, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 38, "index" : 38, @@ -1331,7 +1415,10 @@ "writeTime" : 131158, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, { "taskId" : 39, "index" : 39, @@ -1377,7 +1464,10 @@ "writeTime" : 98748, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, { "taskId" : 40, "index" : 40, @@ -1423,7 +1513,10 @@ "writeTime" : 94792, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 41, "index" : 41, @@ -1469,7 +1562,10 @@ "writeTime" : 90765, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 42, "index" : 42, @@ -1515,7 +1611,10 @@ "writeTime" : 103713, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 15, + "gettingResultTime" : 0 }, { "taskId" : 43, "index" : 43, @@ -1561,7 +1660,10 @@ "writeTime" : 171516, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 21, + "gettingResultTime" : 0 }, { "taskId" : 44, "index" : 44, @@ -1607,7 +1709,10 @@ "writeTime" : 98293, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 16, + "gettingResultTime" : 0 }, { "taskId" : 45, "index" : 45, @@ -1653,7 +1758,10 @@ "writeTime" : 92985, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 16, + "gettingResultTime" : 0 }, { "taskId" : 46, "index" : 46, @@ -1699,7 +1807,10 @@ "writeTime" : 113322, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 11, + "gettingResultTime" : 0 }, { "taskId" : 47, "index" : 47, @@ -1745,7 +1856,10 @@ "writeTime" : 103015, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, { "taskId" : 48, "index" : 48, @@ -1791,7 +1905,10 @@ "writeTime" : 139844, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 49, "index" : 49, @@ -1837,7 +1954,10 @@ "writeTime" : 94984, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 50, "index" : 50, @@ -1883,7 +2003,10 @@ "writeTime" : 90836, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 51, "index" : 51, @@ -1929,7 +2052,10 @@ "writeTime" : 96013, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 2, + "gettingResultTime" : 0 }, { "taskId" : 52, "index" : 52, @@ -1975,7 +2101,10 @@ "writeTime" : 89664, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 53, "index" : 53, @@ -2021,7 +2150,10 @@ "writeTime" : 92835, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 54, "index" : 54, @@ -2067,7 +2199,10 @@ "writeTime" : 90506, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 38, + "gettingResultTime" : 0 }, { "taskId" : 55, "index" : 55, @@ -2113,7 +2248,10 @@ "writeTime" : 108309, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 56, "index" : 56, @@ -2159,7 +2297,10 @@ "writeTime" : 90329, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 57, "index" : 57, @@ -2205,7 +2346,10 @@ "writeTime" : 96849, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 2, + "gettingResultTime" : 0 }, { "taskId" : 58, "index" : 58, @@ -2251,7 +2395,10 @@ "writeTime" : 97521, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 59, "index" : 59, @@ -2297,5 +2444,8 @@ "writeTime" : 100753, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json index bc1cd49909d31..82e339c8f56dd 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json @@ -43,7 +43,10 @@ "writeTime" : 4016617, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 38, + "gettingResultTime" : 0 }, { "taskId" : 5, "index" : 5, @@ -89,7 +92,10 @@ "writeTime" : 3675510, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 33, + "gettingResultTime" : 0 }, { "taskId" : 1, "index" : 1, @@ -135,7 +141,10 @@ "writeTime" : 3934399, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 40, + "gettingResultTime" : 0 }, { "taskId" : 7, "index" : 7, @@ -181,7 +190,10 @@ "writeTime" : 2579051, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 43, + "gettingResultTime" : 0 }, { "taskId" : 4, "index" : 4, @@ -227,7 +239,10 @@ "writeTime" : 83022, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 38, + "gettingResultTime" : 0 }, { "taskId" : 3, "index" : 3, @@ -273,7 +288,10 @@ "writeTime" : 1311694, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 41, + "gettingResultTime" : 0 }, { "taskId" : 0, "index" : 0, @@ -319,7 +337,10 @@ "writeTime" : 3842811, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 53, + "gettingResultTime" : 0 }, { "taskId" : 2, "index" : 2, @@ -365,7 +386,10 @@ "writeTime" : 89885, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 37, + "gettingResultTime" : 0 }, { "taskId" : 22, "index" : 22, @@ -411,7 +435,10 @@ "writeTime" : 143427, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 18, "index" : 18, @@ -457,7 +484,10 @@ "writeTime" : 100836, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 17, "index" : 17, @@ -503,7 +533,10 @@ "writeTime" : 99944, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, { "taskId" : 21, "index" : 21, @@ -549,7 +582,10 @@ "writeTime" : 100270, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 19, "index" : 19, @@ -595,7 +631,10 @@ "writeTime" : 95788, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 16, "index" : 16, @@ -641,7 +680,10 @@ "writeTime" : 108320, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 9, "index" : 9, @@ -687,7 +729,10 @@ "writeTime" : 101664, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 8, + "gettingResultTime" : 0 }, { "taskId" : 20, "index" : 20, @@ -733,7 +778,10 @@ "writeTime" : 97716, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 14, "index" : 14, @@ -779,7 +827,10 @@ "writeTime" : 95646, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 8, "index" : 8, @@ -825,7 +876,10 @@ "writeTime" : 121551, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 12, "index" : 12, @@ -871,7 +925,10 @@ "writeTime" : 102476, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 7, + "gettingResultTime" : 0 }, { "taskId" : 15, "index" : 15, @@ -917,5 +974,8 @@ "writeTime" : 602780, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json index bc1cd49909d31..82e339c8f56dd 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json @@ -43,7 +43,10 @@ "writeTime" : 4016617, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 38, + "gettingResultTime" : 0 }, { "taskId" : 5, "index" : 5, @@ -89,7 +92,10 @@ "writeTime" : 3675510, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 33, + "gettingResultTime" : 0 }, { "taskId" : 1, "index" : 1, @@ -135,7 +141,10 @@ "writeTime" : 3934399, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 40, + "gettingResultTime" : 0 }, { "taskId" : 7, "index" : 7, @@ -181,7 +190,10 @@ "writeTime" : 2579051, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 43, + "gettingResultTime" : 0 }, { "taskId" : 4, "index" : 4, @@ -227,7 +239,10 @@ "writeTime" : 83022, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 38, + "gettingResultTime" : 0 }, { "taskId" : 3, "index" : 3, @@ -273,7 +288,10 @@ "writeTime" : 1311694, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 41, + "gettingResultTime" : 0 }, { "taskId" : 0, "index" : 0, @@ -319,7 +337,10 @@ "writeTime" : 3842811, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 53, + "gettingResultTime" : 0 }, { "taskId" : 2, "index" : 2, @@ -365,7 +386,10 @@ "writeTime" : 89885, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 37, + "gettingResultTime" : 0 }, { "taskId" : 22, "index" : 22, @@ -411,7 +435,10 @@ "writeTime" : 143427, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 18, "index" : 18, @@ -457,7 +484,10 @@ "writeTime" : 100836, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 17, "index" : 17, @@ -503,7 +533,10 @@ "writeTime" : 99944, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, { "taskId" : 21, "index" : 21, @@ -549,7 +582,10 @@ "writeTime" : 100270, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 19, "index" : 19, @@ -595,7 +631,10 @@ "writeTime" : 95788, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 16, "index" : 16, @@ -641,7 +680,10 @@ "writeTime" : 108320, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 9, "index" : 9, @@ -687,7 +729,10 @@ "writeTime" : 101664, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 8, + "gettingResultTime" : 0 }, { "taskId" : 20, "index" : 20, @@ -733,7 +778,10 @@ "writeTime" : 97716, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 14, "index" : 14, @@ -779,7 +827,10 @@ "writeTime" : 95646, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 8, "index" : 8, @@ -825,7 +876,10 @@ "writeTime" : 121551, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 12, "index" : 12, @@ -871,7 +925,10 @@ "writeTime" : 102476, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 7, + "gettingResultTime" : 0 }, { "taskId" : 15, "index" : 15, @@ -917,5 +974,8 @@ "writeTime" : 602780, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json index 09857cb401acd..01eef1b565bf6 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json @@ -43,7 +43,10 @@ "writeTime" : 94792, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 41, "index" : 41, @@ -89,7 +92,10 @@ "writeTime" : 90765, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 43, "index" : 43, @@ -135,7 +141,10 @@ "writeTime" : 171516, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 21, + "gettingResultTime" : 0 }, { "taskId" : 57, "index" : 57, @@ -181,7 +190,10 @@ "writeTime" : 96849, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 2, + "gettingResultTime" : 0 }, { "taskId" : 58, "index" : 58, @@ -227,7 +239,10 @@ "writeTime" : 97521, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 68, "index" : 68, @@ -273,7 +288,10 @@ "writeTime" : 101750, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 86, "index" : 86, @@ -319,7 +337,10 @@ "writeTime" : 95848, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 8, + "gettingResultTime" : 0 }, { "taskId" : 32, "index" : 32, @@ -365,7 +386,10 @@ "writeTime" : 89603, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, { "taskId" : 39, "index" : 39, @@ -411,7 +435,10 @@ "writeTime" : 98748, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, { "taskId" : 42, "index" : 42, @@ -457,7 +484,10 @@ "writeTime" : 103713, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 15, + "gettingResultTime" : 0 }, { "taskId" : 51, "index" : 51, @@ -503,7 +533,10 @@ "writeTime" : 96013, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 2, + "gettingResultTime" : 0 }, { "taskId" : 59, "index" : 59, @@ -549,7 +582,10 @@ "writeTime" : 100753, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 63, "index" : 63, @@ -595,7 +631,10 @@ "writeTime" : 102779, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 87, "index" : 87, @@ -641,7 +680,10 @@ "writeTime" : 102159, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 7, + "gettingResultTime" : 0 }, { "taskId" : 90, "index" : 90, @@ -687,7 +729,10 @@ "writeTime" : 98472, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 99, "index" : 99, @@ -733,7 +778,10 @@ "writeTime" : 133964, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 44, "index" : 44, @@ -779,7 +827,10 @@ "writeTime" : 98293, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 16, + "gettingResultTime" : 0 }, { "taskId" : 47, "index" : 47, @@ -825,7 +876,10 @@ "writeTime" : 103015, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, { "taskId" : 50, "index" : 50, @@ -871,7 +925,10 @@ "writeTime" : 90836, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 52, "index" : 52, @@ -917,5 +974,8 @@ "writeTime" : 89664, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index 963f010968b62..a8e1fd303a42a 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -83,14 +83,17 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 30, + "gettingResultTime" : 0 }, - "1" : { - "taskId" : 1, - "index" : 1, + "5" : { + "taskId" : 5, + "index" : 5, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.521GMT", - "duration" : 53, + "launchTime" : "2015-03-16T19:25:36.523GMT", + "duration" : 52, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -99,11 +102,11 @@ "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "247", - "value" : "2175" + "update" : "897", + "value" : "3750" } ], "taskMetrics" : { - "executorDeserializeTime" : 14, + "executorDeserializeTime" : 12, "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, "executorCpuTime" : 0, @@ -135,14 +138,17 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 23, + "gettingResultTime" : 0 }, - "2" : { - "taskId" : 2, - "index" : 2, + "1" : { + "taskId" : 1, + "index" : 1, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.522GMT", - "duration" : 48, + "launchTime" : "2015-03-16T19:25:36.521GMT", + "duration" : 53, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -151,11 +157,11 @@ "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "378", - "value" : "378" + "update" : "247", + "value" : "2175" } ], "taskMetrics" : { - "executorDeserializeTime" : 13, + "executorDeserializeTime" : 14, "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, "executorCpuTime" : 0, @@ -187,14 +193,17 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, - "3" : { - "taskId" : 3, - "index" : 3, + "6" : { + "taskId" : 6, + "index" : 6, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.522GMT", - "duration" : 50, + "launchTime" : "2015-03-16T19:25:36.523GMT", + "duration" : 51, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -203,11 +212,11 @@ "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "572", - "value" : "950" + "update" : "978", + "value" : "1928" } ], "taskMetrics" : { - "executorDeserializeTime" : 13, + "executorDeserializeTime" : 12, "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, "executorCpuTime" : 0, @@ -239,14 +248,17 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, - "4" : { - "taskId" : 4, - "index" : 4, + "2" : { + "taskId" : 2, + "index" : 2, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.522GMT", - "duration" : 52, + "duration" : 48, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -255,17 +267,17 @@ "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "678", - "value" : "2853" + "update" : "378", + "value" : "378" } ], "taskMetrics" : { - "executorDeserializeTime" : 12, + "executorDeserializeTime" : 13, "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, - "resultSerializationTime" : 1, + "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "peakExecutionMemory" : 0, @@ -291,14 +303,17 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 18, + "gettingResultTime" : 0 }, - "5" : { - "taskId" : 5, - "index" : 5, + "7" : { + "taskId" : 7, + "index" : 7, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.523GMT", - "duration" : 52, + "launchTime" : "2015-03-16T19:25:36.524GMT", + "duration" : 51, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -307,8 +322,8 @@ "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "897", - "value" : "3750" + "update" : "1222", + "value" : "4972" } ], "taskMetrics" : { "executorDeserializeTime" : 12, @@ -343,14 +358,17 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, - "6" : { - "taskId" : 6, - "index" : 6, + "3" : { + "taskId" : 3, + "index" : 3, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.523GMT", - "duration" : 51, + "launchTime" : "2015-03-16T19:25:36.522GMT", + "duration" : 50, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -359,11 +377,11 @@ "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "978", - "value" : "1928" + "update" : "572", + "value" : "950" } ], "taskMetrics" : { - "executorDeserializeTime" : 12, + "executorDeserializeTime" : 13, "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, "executorCpuTime" : 0, @@ -395,14 +413,17 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, - "7" : { - "taskId" : 7, - "index" : 7, + "4" : { + "taskId" : 4, + "index" : 4, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.524GMT", - "duration" : 51, + "launchTime" : "2015-03-16T19:25:36.522GMT", + "duration" : 52, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -411,8 +432,8 @@ "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "1222", - "value" : "4972" + "update" : "678", + "value" : "2853" } ], "taskMetrics" : { "executorDeserializeTime" : 12, @@ -421,7 +442,7 @@ "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, - "resultSerializationTime" : 2, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "peakExecutionMemory" : 0, @@ -447,7 +468,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 24, + "gettingResultTime" : 0 } }, "executorSummary" : { diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala index 9e74e86ad54b9..a01b24d323d28 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala @@ -52,7 +52,10 @@ class AppStatusUtilsSuite extends SparkFunSuite { inputMetrics = null, outputMetrics = null, shuffleReadMetrics = null, - shuffleWriteMetrics = null))) + shuffleWriteMetrics = null)), + executorLogs = null, + schedulerDelay = 0L, + gettingResultTime = 0L) assert(AppStatusUtils.schedulerDelay(runningTask) === 0L) val finishedTask = new TaskData( @@ -83,7 +86,10 @@ class AppStatusUtilsSuite extends SparkFunSuite { inputMetrics = null, outputMetrics = null, shuffleReadMetrics = null, - shuffleWriteMetrics = null))) + shuffleWriteMetrics = null)), + executorLogs = null, + schedulerDelay = 0L, + gettingResultTime = 0L) assert(AppStatusUtils.schedulerDelay(finishedTask) === 3L) } } diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 2945c3ee0a9d9..5e976ae4e91da 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -96,18 +96,6 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { } } - test("peak execution memory should displayed") { - val html = renderStagePage().toString().toLowerCase(Locale.ROOT) - val targetString = "peak execution memory" - assert(html.contains(targetString)) - } - - test("SPARK-10543: peak execution memory should be per-task rather than cumulative") { - val html = renderStagePage().toString().toLowerCase(Locale.ROOT) - // verify min/25/50/75/max show task value not cumulative values - assert(html.contains(s"$peakExecutionMemory.0 b" * 5)) - } - /** * Render a stage page started with the given conf and return the HTML. * This also runs a dummy stage to populate the page with useful content. From fbf62b7100be992cbc4eb67e154682db6c91e60e Mon Sep 17 00:00:00 2001 From: Shahid Date: Mon, 26 Nov 2018 13:13:06 -0800 Subject: [PATCH 2141/2461] [SPARK-25451][SPARK-26100][CORE] Aggregated metrics table doesn't show the right number of the total tasks Total tasks in the aggregated table and the tasks table are not matching some times in the WEBUI. We need to force update the executor summary of the particular executorId, when ever last task of that executor has reached. Currently it force update based on last task on the stage end. So, for some particular executorId task might miss at the stage end. Tests to reproduce: ``` bin/spark-shell --master yarn --conf spark.executor.instances=3 sc.parallelize(1 to 10000, 10).map{ x => throw new RuntimeException("Bad executor")}.collect() ``` Before patch: ![screenshot from 2018-11-15 02-24-05](https://user-images.githubusercontent.com/23054875/48511776-b0d36480-e87d-11e8-89a8-ab97216e2c21.png) After patch: ![screenshot from 2018-11-15 02-32-38](https://user-images.githubusercontent.com/23054875/48512141-c39a6900-e87e-11e8-8535-903e1d11d13e.png) Closes #23038 from shahidki31/SPARK-25451. Authored-by: Shahid Signed-off-by: Marcelo Vanzin --- .../spark/status/AppStatusListener.scala | 19 +++++++- .../org/apache/spark/status/LiveEntity.scala | 2 + .../spark/status/AppStatusListenerSuite.scala | 45 +++++++++++++++++++ 3 files changed, 64 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 81d39e0407fed..8e845573a903d 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -473,6 +473,7 @@ private[spark] class AppStatusListener( val locality = event.taskInfo.taskLocality.toString() val count = stage.localitySummary.getOrElse(locality, 0L) + 1L stage.localitySummary = stage.localitySummary ++ Map(locality -> count) + stage.activeTasksPerExecutor(event.taskInfo.executorId) += 1 maybeUpdate(stage, now) stage.jobs.foreach { job => @@ -558,6 +559,7 @@ private[spark] class AppStatusListener( if (killedDelta > 0) { stage.killedSummary = killedTasksSummary(event.reason, stage.killedSummary) } + stage.activeTasksPerExecutor(event.taskInfo.executorId) -= 1 // [SPARK-24415] Wait for all tasks to finish before removing stage from live list val removeStage = stage.activeTasks == 0 && @@ -582,7 +584,11 @@ private[spark] class AppStatusListener( if (killedDelta > 0) { job.killedSummary = killedTasksSummary(event.reason, job.killedSummary) } - conditionalLiveUpdate(job, now, removeStage) + if (removeStage) { + update(job, now) + } else { + maybeUpdate(job, now) + } } val esummary = stage.executorSummary(event.taskInfo.executorId) @@ -593,7 +599,16 @@ private[spark] class AppStatusListener( if (metricsDelta != null) { esummary.metrics = LiveEntityHelpers.addMetrics(esummary.metrics, metricsDelta) } - conditionalLiveUpdate(esummary, now, removeStage) + + val isLastTask = stage.activeTasksPerExecutor(event.taskInfo.executorId) == 0 + + // If the last task of the executor finished, then update the esummary + // for both live and history events. + if (isLastTask) { + update(esummary, now) + } else { + maybeUpdate(esummary, now) + } if (!stage.cleaning && stage.savedTasks.get() > maxTasksPerStage) { stage.cleaning = true diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 80663318c1ba1..47e45a66ecccb 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -376,6 +376,8 @@ private class LiveStage extends LiveEntity { val executorSummaries = new HashMap[String, LiveExecutorStageSummary]() + val activeTasksPerExecutor = new HashMap[String, Int]().withDefaultValue(0) + var blackListedExecutors = new HashSet[String]() // Used for cleanup of tasks after they reach the configured limit. Not written to the store. diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 5f757b757ac61..1c787ff43b9ac 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -1273,6 +1273,51 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(allJobs.head.numFailedStages == 1) } + test("SPARK-25451: total tasks in the executor summary should match total stage tasks") { + val testConf = conf.clone.set(LIVE_ENTITY_UPDATE_PERIOD, Long.MaxValue) + + val listener = new AppStatusListener(store, testConf, true) + + val stage = new StageInfo(1, 0, "stage", 4, Nil, Nil, "details") + listener.onJobStart(SparkListenerJobStart(1, time, Seq(stage), null)) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage, new Properties())) + + val tasks = createTasks(4, Array("1", "2")) + tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(stage.stageId, stage.attemptNumber, task)) + } + + time += 1 + tasks(0).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptId, "taskType", + Success, tasks(0), null)) + time += 1 + tasks(1).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptId, "taskType", + Success, tasks(1), null)) + + stage.failureReason = Some("Failed") + listener.onStageCompleted(SparkListenerStageCompleted(stage)) + time += 1 + listener.onJobEnd(SparkListenerJobEnd(1, time, JobFailed(new RuntimeException("Bad Executor")))) + + time += 1 + tasks(2).markFinished(TaskState.FAILED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptId, "taskType", + ExecutorLostFailure("1", true, Some("Lost executor")), tasks(2), null)) + time += 1 + tasks(3).markFinished(TaskState.FAILED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptId, "taskType", + ExecutorLostFailure("2", true, Some("Lost executor")), tasks(3), null)) + + val esummary = store.view(classOf[ExecutorStageSummaryWrapper]).asScala.map(_.info) + esummary.foreach { execSummary => + assert(execSummary.failedTasks === 1) + assert(execSummary.succeededTasks === 1) + assert(execSummary.killedTasks === 0) + } + } + test("driver logs") { val listener = new AppStatusListener(store, conf, true) From 6f1a1c1248e0341a690aee655af05da9e9cbff90 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 26 Nov 2018 14:37:41 -0800 Subject: [PATCH 2142/2461] [SPARK-25451][HOTFIX] Call stage.attemptNumber instead of attemptId. Closes #23149 from vanzin/SPARK-25451.hotfix. Authored-by: Marcelo Vanzin Signed-off-by: Marcelo Vanzin --- .../org/apache/spark/status/AppStatusListenerSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 1c787ff43b9ac..7860a0df4bb2d 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -1289,11 +1289,11 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 tasks(0).markFinished(TaskState.FINISHED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptId, "taskType", + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", Success, tasks(0), null)) time += 1 tasks(1).markFinished(TaskState.FINISHED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptId, "taskType", + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", Success, tasks(1), null)) stage.failureReason = Some("Failed") @@ -1303,11 +1303,11 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 tasks(2).markFinished(TaskState.FAILED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptId, "taskType", + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", ExecutorLostFailure("1", true, Some("Lost executor")), tasks(2), null)) time += 1 tasks(3).markFinished(TaskState.FAILED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptId, "taskType", + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", ExecutorLostFailure("2", true, Some("Lost executor")), tasks(3), null)) val esummary = store.view(classOf[ExecutorStageSummaryWrapper]).asScala.map(_.info) From 9deaa726ef1645746892a23d369c3d14677a48ff Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 26 Nov 2018 15:33:21 -0800 Subject: [PATCH 2143/2461] [INFRA] Close stale PR. Closes #23107 From c995e0737de66441052fbf0fb941c5ea05d0163f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 26 Nov 2018 17:01:56 -0800 Subject: [PATCH 2144/2461] [SPARK-26140] followup: rename ShuffleMetricsReporter ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/23105, due to working on two parallel PRs at once, I made the mistake of committing the copy of the PR that used the name ShuffleMetricsReporter for the interface, rather than the appropriate one ShuffleReadMetricsReporter. This patch fixes that. ## How was this patch tested? This should be fine as long as compilation passes. Closes #23147 from rxin/ShuffleReadMetricsReporter. Authored-by: Reynold Xin Signed-off-by: gatorsmile --- .../spark/executor/ShuffleReadMetrics.scala | 4 +-- .../shuffle/BlockStoreShuffleReader.scala | 2 +- .../apache/spark/shuffle/ShuffleManager.scala | 2 +- .../shuffle/ShuffleMetricsReporter.scala | 33 ------------------- .../shuffle/sort/SortShuffleManager.scala | 2 +- .../storage/ShuffleBlockFetcherIterator.scala | 4 +-- 6 files changed, 7 insertions(+), 40 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/shuffle/ShuffleMetricsReporter.scala diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala index 2f97e969d2dd2..12c4b8f67f71c 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala @@ -18,7 +18,7 @@ package org.apache.spark.executor import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.shuffle.ShuffleMetricsReporter +import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.util.LongAccumulator @@ -130,7 +130,7 @@ class ShuffleReadMetrics private[spark] () extends Serializable { * shuffle dependency, and all temporary metrics will be merged into the [[ShuffleReadMetrics]] at * last. */ -private[spark] class TempShuffleReadMetrics extends ShuffleMetricsReporter { +private[spark] class TempShuffleReadMetrics extends ShuffleReadMetricsReporter { private[this] var _remoteBlocksFetched = 0L private[this] var _localBlocksFetched = 0L private[this] var _remoteBytesRead = 0L diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 7cb031ce318b7..27e2f98c58f0c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -33,7 +33,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( startPartition: Int, endPartition: Int, context: TaskContext, - readMetrics: ShuffleMetricsReporter, + readMetrics: ShuffleReadMetricsReporter, serializerManager: SerializerManager = SparkEnv.get.serializerManager, blockManager: BlockManager = SparkEnv.get.blockManager, mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index d1061d83cb85a..df601cbdb2050 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -49,7 +49,7 @@ private[spark] trait ShuffleManager { startPartition: Int, endPartition: Int, context: TaskContext, - metrics: ShuffleMetricsReporter): ShuffleReader[K, C] + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] /** * Remove a shuffle's metadata from the ShuffleManager. diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMetricsReporter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMetricsReporter.scala deleted file mode 100644 index 32865149c97c2..0000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMetricsReporter.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle - -/** - * An interface for reporting shuffle information, for each shuffle. This interface assumes - * all the methods are called on a single-threaded, i.e. concrete implementations would not need - * to synchronize anything. - */ -private[spark] trait ShuffleMetricsReporter { - def incRemoteBlocksFetched(v: Long): Unit - def incLocalBlocksFetched(v: Long): Unit - def incRemoteBytesRead(v: Long): Unit - def incRemoteBytesReadToDisk(v: Long): Unit - def incLocalBytesRead(v: Long): Unit - def incFetchWaitTime(v: Long): Unit - def incRecordsRead(v: Long): Unit -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 57c3150e5a697..4f8be198e4a72 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -115,7 +115,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager startPartition: Int, endPartition: Int, context: TaskContext, - metrics: ShuffleMetricsReporter): ShuffleReader[K, C] = { + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context, metrics) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index a2e0713e70b04..86f7c08eddcb5 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -30,7 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ import org.apache.spark.network.util.TransportConf -import org.apache.spark.shuffle.{FetchFailedException, ShuffleMetricsReporter} +import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} import org.apache.spark.util.Utils import org.apache.spark.util.io.ChunkedByteBufferOutputStream @@ -73,7 +73,7 @@ final class ShuffleBlockFetcherIterator( maxBlocksInFlightPerAddress: Int, maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean, - shuffleMetrics: ShuffleMetricsReporter) + shuffleMetrics: ShuffleReadMetricsReporter) extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { import ShuffleBlockFetcherIterator._ From 1c487f7d1442a7043e7faff76ab67a633edc7b05 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 27 Nov 2018 12:13:48 +0800 Subject: [PATCH 2145/2461] [SPARK-24762][SQL][FOLLOWUP] Enable Option of Product encoders ## What changes were proposed in this pull request? This is follow-up of #21732. This patch inlines `isOptionType` method. ## How was this patch tested? Existing tests. Closes #23143 from viirya/SPARK-24762-followup. Authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/encoders/ExpressionEncoder.scala | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index d019924711e3e..589e215c55e44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -251,11 +251,6 @@ case class ExpressionEncoder[T]( */ def isSerializedAsStruct: Boolean = objSerializer.dataType.isInstanceOf[StructType] - /** - * Returns true if the type `T` is an `Option` type. - */ - def isOptionType: Boolean = classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass) - /** * If the type `T` is serialized as a struct, when it is encoded to a Spark SQL row, fields in * the struct are naturally mapped to top-level columns in a row. In other words, the serialized @@ -263,7 +258,9 @@ case class ExpressionEncoder[T]( * flattened to top-level row, because in Spark SQL top-level row can't be null. This method * returns true if `T` is serialized as struct and is not `Option` type. */ - def isSerializedAsStructForTopLevel: Boolean = isSerializedAsStruct && !isOptionType + def isSerializedAsStructForTopLevel: Boolean = { + isSerializedAsStruct && !classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass) + } // serializer expressions are used to encode an object to a row, while the object is usually an // intermediate value produced inside an operator, not from the output of the child operator. This From 85383d29ede19dd73949fe57cadb73ec94b29334 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 27 Nov 2018 04:51:32 +0000 Subject: [PATCH 2146/2461] [SPARK-25860][SPARK-26107][FOLLOW-UP] Rule ReplaceNullWithFalseInPredicate ## What changes were proposed in this pull request? Based on https://github.com/apache/spark/pull/22857 and https://github.com/apache/spark/pull/23079, this PR did a few updates - Limit the data types of NULL to Boolean. - Limit the input data type of replaceNullWithFalse to Boolean; throw an exception in the testing mode. - Create a new file for the rule ReplaceNullWithFalseInPredicate - Update the description of this rule. ## How was this patch tested? Added a test case Closes #23139 from gatorsmile/followupSpark-25860. Authored-by: gatorsmile Signed-off-by: DB Tsai --- .../ReplaceNullWithFalseInPredicate.scala | 110 ++++++++++++++++++ .../sql/catalyst/optimizer/expressions.scala | 66 ----------- ...ReplaceNullWithFalseInPredicateSuite.scala | 11 +- 3 files changed, 119 insertions(+), 68 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala new file mode 100644 index 0000000000000..72a60f692ac78 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, Expression, If} +import org.apache.spark.sql.catalyst.expressions.{LambdaFunction, Literal, MapFilter, Or} +import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.BooleanType +import org.apache.spark.util.Utils + + +/** + * A rule that replaces `Literal(null, BooleanType)` with `FalseLiteral`, if possible, in the search + * condition of the WHERE/HAVING/ON(JOIN) clauses, which contain an implicit Boolean operator + * "(search condition) = TRUE". The replacement is only valid when `Literal(null, BooleanType)` is + * semantically equivalent to `FalseLiteral` when evaluating the whole search condition. + * + * Please note that FALSE and NULL are not exchangeable in most cases, when the search condition + * contains NOT and NULL-tolerant expressions. Thus, the rule is very conservative and applicable + * in very limited cases. + * + * For example, `Filter(Literal(null, BooleanType))` is equal to `Filter(FalseLiteral)`. + * + * Another example containing branches is `Filter(If(cond, FalseLiteral, Literal(null, _)))`; + * this can be optimized to `Filter(If(cond, FalseLiteral, FalseLiteral))`, and eventually + * `Filter(FalseLiteral)`. + * + * Moreover, this rule also transforms predicates in all [[If]] expressions as well as branch + * conditions in all [[CaseWhen]] expressions, even if they are not part of the search conditions. + * + * For example, `Project(If(And(cond, Literal(null)), Literal(1), Literal(2)))` can be simplified + * into `Project(Literal(2))`. + */ +object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) + case j @ Join(_, _, _, Some(cond)) => j.copy(condition = Some(replaceNullWithFalse(cond))) + case p: LogicalPlan => p transformExpressions { + case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred)) + case cw @ CaseWhen(branches, _) => + val newBranches = branches.map { case (cond, value) => + replaceNullWithFalse(cond) -> value + } + cw.copy(branches = newBranches) + case af @ ArrayFilter(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + af.copy(function = newLambda) + case ae @ ArrayExists(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + ae.copy(function = newLambda) + case mf @ MapFilter(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + mf.copy(function = newLambda) + } + } + + /** + * Recursively traverse the Boolean-type expression to replace + * `Literal(null, BooleanType)` with `FalseLiteral`, if possible. + * + * Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit + * an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or + * `Literal(null, BooleanType)`. + */ + private def replaceNullWithFalse(e: Expression): Expression = e match { + case Literal(null, BooleanType) => + FalseLiteral + case And(left, right) => + And(replaceNullWithFalse(left), replaceNullWithFalse(right)) + case Or(left, right) => + Or(replaceNullWithFalse(left), replaceNullWithFalse(right)) + case cw: CaseWhen if cw.dataType == BooleanType => + val newBranches = cw.branches.map { case (cond, value) => + replaceNullWithFalse(cond) -> replaceNullWithFalse(value) + } + val newElseValue = cw.elseValue.map(replaceNullWithFalse) + CaseWhen(newBranches, newElseValue) + case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType => + If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal)) + case e if e.dataType == BooleanType => + e + case e => + val message = "Expected a Boolean type expression in replaceNullWithFalse, " + + s"but got the type `${e.dataType.catalogString}` in `${e.sql}`." + if (Utils.isTesting) { + throw new IllegalArgumentException(message) + } else { + logWarning(message) + e + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 354efd883f814..468a950fb1087 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -736,69 +736,3 @@ object CombineConcats extends Rule[LogicalPlan] { flattenConcats(concat) } } - -/** - * A rule that replaces `Literal(null, _)` with `FalseLiteral` for further optimizations. - * - * This rule applies to conditions in [[Filter]] and [[Join]]. Moreover, it transforms predicates - * in all [[If]] expressions as well as branch conditions in all [[CaseWhen]] expressions. - * - * For example, `Filter(Literal(null, _))` is equal to `Filter(FalseLiteral)`. - * - * Another example containing branches is `Filter(If(cond, FalseLiteral, Literal(null, _)))`; - * this can be optimized to `Filter(If(cond, FalseLiteral, FalseLiteral))`, and eventually - * `Filter(FalseLiteral)`. - * - * As this rule is not limited to conditions in [[Filter]] and [[Join]], arbitrary plans can - * benefit from it. For example, `Project(If(And(cond, Literal(null)), Literal(1), Literal(2)))` - * can be simplified into `Project(Literal(2))`. - * - * As a result, many unnecessary computations can be removed in the query optimization phase. - */ -object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) - case j @ Join(_, _, _, Some(cond)) => j.copy(condition = Some(replaceNullWithFalse(cond))) - case p: LogicalPlan => p transformExpressions { - case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred)) - case cw @ CaseWhen(branches, _) => - val newBranches = branches.map { case (cond, value) => - replaceNullWithFalse(cond) -> value - } - cw.copy(branches = newBranches) - case af @ ArrayFilter(_, lf @ LambdaFunction(func, _, _)) => - val newLambda = lf.copy(function = replaceNullWithFalse(func)) - af.copy(function = newLambda) - case ae @ ArrayExists(_, lf @ LambdaFunction(func, _, _)) => - val newLambda = lf.copy(function = replaceNullWithFalse(func)) - ae.copy(function = newLambda) - case mf @ MapFilter(_, lf @ LambdaFunction(func, _, _)) => - val newLambda = lf.copy(function = replaceNullWithFalse(func)) - mf.copy(function = newLambda) - } - } - - /** - * Recursively replaces `Literal(null, _)` with `FalseLiteral`. - * - * Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit - * an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or `Literal(null, _)`. - */ - private def replaceNullWithFalse(e: Expression): Expression = e match { - case cw: CaseWhen if cw.dataType == BooleanType => - val newBranches = cw.branches.map { case (cond, value) => - replaceNullWithFalse(cond) -> replaceNullWithFalse(value) - } - val newElseValue = cw.elseValue.map(replaceNullWithFalse) - CaseWhen(newBranches, newElseValue) - case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType => - If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal)) - case And(left, right) => - And(replaceNullWithFalse(left), replaceNullWithFalse(right)) - case Or(left, right) => - Or(replaceNullWithFalse(left), replaceNullWithFalse(right)) - case Literal(null, _) => FalseLiteral - case _ => e - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index 3a9e6cae0fd87..ee0d04da3e46c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -44,8 +44,15 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { private val anotherTestRelation = LocalRelation('d.int) test("replace null inside filter and join conditions") { - testFilter(originalCond = Literal(null), expectedCond = FalseLiteral) - testJoin(originalCond = Literal(null), expectedCond = FalseLiteral) + testFilter(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) + testJoin(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) + } + + test("Not expected type - replaceNullWithFalse") { + val e = intercept[IllegalArgumentException] { + testFilter(originalCond = Literal(null, IntegerType), expectedCond = FalseLiteral) + }.getMessage + assert(e.contains("but got the type `int` in `CAST(NULL AS INT)")) } test("replace null in branches of If") { From 6a064ba8f271d5f9d04acd41d0eea50a5b0f5018 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 26 Nov 2018 22:35:52 -0800 Subject: [PATCH 2147/2461] [SPARK-26141] Enable custom metrics implementation in shuffle write ## What changes were proposed in this pull request? This is the write side counterpart to https://github.com/apache/spark/pull/23105 ## How was this patch tested? No behavior change expected, as it is a straightforward refactoring. Updated all existing test cases. Closes #23106 from rxin/SPARK-26141. Authored-by: Reynold Xin Signed-off-by: Reynold Xin --- .../sort/BypassMergeSortShuffleWriter.java | 11 +++++------ .../shuffle/sort/ShuffleExternalSorter.java | 18 ++++++++++++------ .../shuffle/sort/UnsafeShuffleWriter.java | 9 +++++---- .../storage/TimeTrackingOutputStream.java | 7 ++++--- .../spark/executor/ShuffleWriteMetrics.scala | 13 +++++++------ .../spark/scheduler/ShuffleMapTask.scala | 3 ++- .../apache/spark/shuffle/ShuffleManager.scala | 6 +++++- .../shuffle/sort/SortShuffleManager.scala | 10 ++++++---- .../apache/spark/storage/BlockManager.scala | 7 +++---- .../spark/storage/DiskBlockObjectWriter.scala | 4 ++-- .../spark/util/collection/ExternalSorter.scala | 4 ++-- .../shuffle/sort/UnsafeShuffleWriterSuite.java | 6 ++++-- .../scala/org/apache/spark/ShuffleSuite.scala | 12 ++++++++---- .../BypassMergeSortShuffleWriterSuite.scala | 16 ++++++++-------- project/MimaExcludes.scala | 7 ++++++- 15 files changed, 79 insertions(+), 54 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index b020a6d99247b..fda33cd8293d5 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -37,12 +37,11 @@ import org.apache.spark.Partitioner; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; -import org.apache.spark.TaskContext; -import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.*; @@ -79,7 +78,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private final int numPartitions; private final BlockManager blockManager; private final Partitioner partitioner; - private final ShuffleWriteMetrics writeMetrics; + private final ShuffleWriteMetricsReporter writeMetrics; private final int shuffleId; private final int mapId; private final Serializer serializer; @@ -103,8 +102,8 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { IndexShuffleBlockResolver shuffleBlockResolver, BypassMergeSortShuffleHandle handle, int mapId, - TaskContext taskContext, - SparkConf conf) { + SparkConf conf, + ShuffleWriteMetricsReporter writeMetrics) { // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); @@ -114,7 +113,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { this.shuffleId = dep.shuffleId(); this.partitioner = dep.partitioner(); this.numPartitions = partitioner.numPartitions(); - this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics(); + this.writeMetrics = writeMetrics; this.serializer = dep.serializer(); this.shuffleBlockResolver = shuffleBlockResolver; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 1c0d664afb138..6ee9d5f0eec3b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -38,6 +38,7 @@ import org.apache.spark.memory.TooLargePageException; import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.FileSegment; @@ -75,7 +76,7 @@ final class ShuffleExternalSorter extends MemoryConsumer { private final TaskMemoryManager taskMemoryManager; private final BlockManager blockManager; private final TaskContext taskContext; - private final ShuffleWriteMetrics writeMetrics; + private final ShuffleWriteMetricsReporter writeMetrics; /** * Force this sorter to spill when there are this many elements in memory. @@ -113,7 +114,7 @@ final class ShuffleExternalSorter extends MemoryConsumer { int initialSize, int numPartitions, SparkConf conf, - ShuffleWriteMetrics writeMetrics) { + ShuffleWriteMetricsReporter writeMetrics) { super(memoryManager, (int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, memoryManager.pageSizeBytes()), memoryManager.getTungstenMemoryMode()); @@ -144,7 +145,7 @@ final class ShuffleExternalSorter extends MemoryConsumer { */ private void writeSortedFile(boolean isLastFile) { - final ShuffleWriteMetrics writeMetricsToUse; + final ShuffleWriteMetricsReporter writeMetricsToUse; if (isLastFile) { // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes. @@ -241,9 +242,14 @@ private void writeSortedFile(boolean isLastFile) { // // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`. // Consistent with ExternalSorter, we do not count this IO towards shuffle write time. - // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this. - writeMetrics.incRecordsWritten(writeMetricsToUse.recordsWritten()); - taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.bytesWritten()); + // SPARK-3577 tracks the spill time separately. + + // This is guaranteed to be a ShuffleWriteMetrics based on the if check in the beginning + // of this method. + writeMetrics.incRecordsWritten( + ((ShuffleWriteMetrics)writeMetricsToUse).recordsWritten()); + taskContext.taskMetrics().incDiskBytesSpilled( + ((ShuffleWriteMetrics)writeMetricsToUse).bytesWritten()); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 4839d04522f10..4b0c74341551e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -37,7 +37,6 @@ import org.apache.spark.*; import org.apache.spark.annotation.Private; -import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; import org.apache.spark.io.NioBufferedFileInputStream; @@ -47,6 +46,7 @@ import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.serializer.SerializationStream; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.IndexShuffleBlockResolver; @@ -73,7 +73,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final TaskMemoryManager memoryManager; private final SerializerInstance serializer; private final Partitioner partitioner; - private final ShuffleWriteMetrics writeMetrics; + private final ShuffleWriteMetricsReporter writeMetrics; private final int shuffleId; private final int mapId; private final TaskContext taskContext; @@ -122,7 +122,8 @@ public UnsafeShuffleWriter( SerializedShuffleHandle handle, int mapId, TaskContext taskContext, - SparkConf sparkConf) throws IOException { + SparkConf sparkConf, + ShuffleWriteMetricsReporter writeMetrics) throws IOException { final int numPartitions = handle.dependency().partitioner().numPartitions(); if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) { throw new IllegalArgumentException( @@ -138,7 +139,7 @@ public UnsafeShuffleWriter( this.shuffleId = dep.shuffleId(); this.serializer = dep.serializer().newInstance(); this.partitioner = dep.partitioner(); - this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics(); + this.writeMetrics = writeMetrics; this.taskContext = taskContext; this.sparkConf = sparkConf; this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); diff --git a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java index 5d0555a8c28e1..fcba3b73445c9 100644 --- a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java +++ b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java @@ -21,7 +21,7 @@ import java.io.OutputStream; import org.apache.spark.annotation.Private; -import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; /** * Intercepts write calls and tracks total time spent writing in order to update shuffle write @@ -30,10 +30,11 @@ @Private public final class TimeTrackingOutputStream extends OutputStream { - private final ShuffleWriteMetrics writeMetrics; + private final ShuffleWriteMetricsReporter writeMetrics; private final OutputStream outputStream; - public TimeTrackingOutputStream(ShuffleWriteMetrics writeMetrics, OutputStream outputStream) { + public TimeTrackingOutputStream( + ShuffleWriteMetricsReporter writeMetrics, OutputStream outputStream) { this.writeMetrics = writeMetrics; this.outputStream = outputStream; } diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala index 0c9da657c2b60..d0b0e7da079c9 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala @@ -18,6 +18,7 @@ package org.apache.spark.executor import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter import org.apache.spark.util.LongAccumulator @@ -27,7 +28,7 @@ import org.apache.spark.util.LongAccumulator * Operations are not thread-safe. */ @DeveloperApi -class ShuffleWriteMetrics private[spark] () extends Serializable { +class ShuffleWriteMetrics private[spark] () extends ShuffleWriteMetricsReporter with Serializable { private[executor] val _bytesWritten = new LongAccumulator private[executor] val _recordsWritten = new LongAccumulator private[executor] val _writeTime = new LongAccumulator @@ -47,13 +48,13 @@ class ShuffleWriteMetrics private[spark] () extends Serializable { */ def writeTime: Long = _writeTime.sum - private[spark] def incBytesWritten(v: Long): Unit = _bytesWritten.add(v) - private[spark] def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v) - private[spark] def incWriteTime(v: Long): Unit = _writeTime.add(v) - private[spark] def decBytesWritten(v: Long): Unit = { + private[spark] override def incBytesWritten(v: Long): Unit = _bytesWritten.add(v) + private[spark] override def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v) + private[spark] override def incWriteTime(v: Long): Unit = _writeTime.add(v) + private[spark] override def decBytesWritten(v: Long): Unit = { _bytesWritten.setValue(bytesWritten - v) } - private[spark] def decRecordsWritten(v: Long): Unit = { + private[spark] override def decRecordsWritten(v: Long): Unit = { _recordsWritten.setValue(recordsWritten - v) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index f2cd65fd523ab..5412717d61988 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -95,7 +95,8 @@ private[spark] class ShuffleMapTask( var writer: ShuffleWriter[Any, Any] = null try { val manager = SparkEnv.get.shuffleManager - writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) + writer = manager.getWriter[Any, Any]( + dep.shuffleHandle, partitionId, context, context.taskMetrics().shuffleWriteMetrics) writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) writer.stop(success = true).get } catch { diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index df601cbdb2050..18a743fbfa6fc 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -38,7 +38,11 @@ private[spark] trait ShuffleManager { dependency: ShuffleDependency[K, V, C]): ShuffleHandle /** Get a writer for a given partition. Called on executors by map tasks. */ - def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V] + def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Int, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] /** * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 4f8be198e4a72..b51a843a31c31 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -125,7 +125,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager override def getWriter[K, V]( handle: ShuffleHandle, mapId: Int, - context: TaskContext): ShuffleWriter[K, V] = { + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { numMapsForShuffle.putIfAbsent( handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps) val env = SparkEnv.get @@ -138,15 +139,16 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager unsafeShuffleHandle, mapId, context, - env.conf) + env.conf, + metrics) case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => new BypassMergeSortShuffleWriter( env.blockManager, shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], bypassMergeSortHandle, mapId, - context, - env.conf) + env.conf, + metrics) case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => new SortShuffleWriter(shuffleBlockResolver, other, mapId, context) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index edae2f95fce33..1b617297e0a30 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -33,10 +33,9 @@ import scala.util.Random import scala.util.control.NonFatal import com.codahale.metrics.{MetricRegistry, MetricSet} -import com.google.common.io.CountingOutputStream import org.apache.spark._ -import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} +import org.apache.spark.executor.DataReadMethod import org.apache.spark.internal.{config, Logging} import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.metrics.source.Source @@ -50,7 +49,7 @@ import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.serializer.{SerializerInstance, SerializerManager} -import org.apache.spark.shuffle.ShuffleManager +import org.apache.spark.shuffle.{ShuffleManager, ShuffleWriteMetricsReporter} import org.apache.spark.storage.memory._ import org.apache.spark.unsafe.Platform import org.apache.spark.util._ @@ -932,7 +931,7 @@ private[spark] class BlockManager( file: File, serializerInstance: SerializerInstance, bufferSize: Int, - writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { + writeMetrics: ShuffleWriteMetricsReporter): DiskBlockObjectWriter = { val syncWrites = conf.getBoolean("spark.shuffle.sync", false) new DiskBlockObjectWriter(file, serializerManager, serializerInstance, bufferSize, syncWrites, writeMetrics, blockId) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index a024c83d8d8b7..17390f9c60e79 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -20,9 +20,9 @@ package org.apache.spark.storage import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStream} import java.nio.channels.FileChannel -import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter import org.apache.spark.util.Utils /** @@ -43,7 +43,7 @@ private[spark] class DiskBlockObjectWriter( syncWrites: Boolean, // These write metrics concurrently shared with other active DiskBlockObjectWriters who // are themselves performing writes. All updates must be relative. - writeMetrics: ShuffleWriteMetrics, + writeMetrics: ShuffleWriteMetricsReporter, val blockId: BlockId = null) extends OutputStream with Logging { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index b159200d79222..eac3db01158d0 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -793,8 +793,8 @@ private[spark] class ExternalSorter[K, V, C]( def nextPartition(): Int = cur._1._1 } - logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " + - s" it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") + logInfo(s"Task ${TaskContext.get().taskAttemptId} force spilling in-memory map to disk " + + s"and it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") val spillFile = spillMemoryIteratorToDisk(inMemoryIterator) forceSpillFiles += spillFile val spillReader = new SpillReader(spillFile) diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index a07d0e84ea854..30ad3f5575545 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -162,7 +162,8 @@ private UnsafeShuffleWriter createWriter( new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, - conf + conf, + taskContext.taskMetrics().shuffleWriteMetrics() ); } @@ -521,7 +522,8 @@ public void testPeakMemoryUsed() throws Exception { new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, - conf); + conf, + taskContext.taskMetrics().shuffleWriteMetrics()); // Peak memory should be monotonically increasing. More specifically, every time // we allocate a new page it should increase by exactly the size of the page. diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 419a26b857ea2..35f728cd57fe2 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -362,15 +362,19 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC mapTrackerMaster.registerShuffle(0, 1) // first attempt -- its successful - val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)) + val context1 = + new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem) + val writer1 = manager.getWriter[Int, Int]( + shuffleHandle, 0, context1, context1.taskMetrics.shuffleWriteMetrics) val data1 = (1 to 10).map { x => x -> x} // second attempt -- also successful. We'll write out different data, // just to simulate the fact that the records may get written differently // depending on what gets spilled, what gets combined, etc. - val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)) + val context2 = + new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem) + val writer2 = manager.getWriter[Int, Int]( + shuffleHandle, 0, context2, context2.taskMetrics.shuffleWriteMetrics) val data2 = (11 to 20).map { x => x -> x} // interleave writes of both attempts -- we want to test that both attempts can occur diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 85ccb33471048..4467c3241a947 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -136,8 +136,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockResolver, shuffleHandle, 0, // MapId - taskContext, - conf + conf, + taskContext.taskMetrics().shuffleWriteMetrics ) writer.write(Iterator.empty) writer.stop( /* success = */ true) @@ -160,8 +160,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockResolver, shuffleHandle, 0, // MapId - taskContext, - conf + conf, + taskContext.taskMetrics().shuffleWriteMetrics ) writer.write(records) writer.stop( /* success = */ true) @@ -195,8 +195,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockResolver, shuffleHandle, 0, // MapId - taskContext, - conf + conf, + taskContext.taskMetrics().shuffleWriteMetrics ) intercept[SparkException] { @@ -217,8 +217,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockResolver, shuffleHandle, 0, // MapId - taskContext, - conf + conf, + taskContext.taskMetrics().shuffleWriteMetrics ) intercept[SparkException] { writer.write((0 until 100000).iterator.map(i => { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 333adb0c84025..3fabec0f60125 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -226,7 +226,12 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.DataSourceWriter"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.writer.DataWriterFactory.createWriter"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter") + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter"), + + // [SPARK-26141] Enable custom metrics implementation in shuffle write + // Following are Java private classes + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.UnsafeShuffleWriter.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.TimeTrackingOutputStream.this") ) // Exclude rules for 2.4.x From 65244b1d790699b6a3a29f2fa111d35f9809111a Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Tue, 27 Nov 2018 20:10:34 +0800 Subject: [PATCH 2148/2461] [SPARK-23356][SQL][TEST] add new test cases for a + 1,a + b and Rand in SetOperationSuite ## What changes were proposed in this pull request? The purpose of this PR is supplement new test cases for a + 1,a + b and Rand in SetOperationSuite. It comes from the comment of closed PR:#20541, thanks. ## How was this patch tested? add new test cases Closes #23138 from heary-cao/UnionPushTestCases. Authored-by: caoxuewen Signed-off-by: Wenchen Fan --- .../optimizer/SetOperationSuite.scala | 29 ++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index da3923f8d6477..17e00c9a3ead2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, And, GreaterThan, GreaterThanOrEqual, If, Literal, ReplicateRows} +import org.apache.spark.sql.catalyst.expressions.{And, GreaterThan, GreaterThanOrEqual, If, Literal, Rand, ReplicateRows} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -196,4 +196,31 @@ class SetOperationSuite extends PlanTest { )) comparePlans(expectedPlan, rewrittenPlan) } + + test("SPARK-23356 union: expressions with literal in project list are pushed down") { + val unionQuery = testUnion.select(('a + 1).as("aa")) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Union(testRelation.select(('a + 1).as("aa")) :: + testRelation2.select(('d + 1).as("aa")) :: + testRelation3.select(('g + 1).as("aa")) :: Nil).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("SPARK-23356 union: expressions in project list are pushed down") { + val unionQuery = testUnion.select(('a + 'b).as("ab")) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Union(testRelation.select(('a + 'b).as("ab")) :: + testRelation2.select(('d + 'e).as("ab")) :: + testRelation3.select(('g + 'h).as("ab")) :: Nil).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("SPARK-23356 union: no pushdown for non-deterministic expression") { + val unionQuery = testUnion.select('a, Rand(10).as("rnd")) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = unionQuery.analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } } From 2d89d109e19d1e84c4ada3c9d5d48cfcf3d997ea Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 27 Nov 2018 09:09:16 -0800 Subject: [PATCH 2149/2461] [SPARK-26025][K8S] Speed up docker image build on dev repo. The "build context" for a docker image - basically the whole contents of the current directory where "docker" is invoked - can be huge in a dev build, easily breaking a couple of gigs. Doing that copy 3 times during the build of docker images severely slows down the process. This patch creates a smaller build context - basically mimicking what the make-distribution.sh script does, so that when building the docker images, only the necessary bits are in the current directory. For PySpark and R that is optimized further, since those images are built based on the previously built Spark main image. In my current local clone, the dir size is about 2G, but with this script the "context" sent to docker is about 250M for the main image, 1M for the pyspark image and 8M for the R image. That speeds up the image builds considerably. I also snuck in a fix to the k8s integration test dependencies in the sbt build, so that the examples are properly built (without having to do it manually). Closes #23019 from vanzin/SPARK-26025. Authored-by: Marcelo Vanzin Signed-off-by: Marcelo Vanzin --- bin/docker-image-tool.sh | 122 ++++++++++++------ project/SparkBuild.scala | 3 +- .../src/main/dockerfiles/spark/Dockerfile | 14 +- 3 files changed, 91 insertions(+), 48 deletions(-) diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index e51201a77cb5d..9f735f1148da4 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -29,6 +29,20 @@ if [ -z "${SPARK_HOME}" ]; then fi . "${SPARK_HOME}/bin/load-spark-env.sh" +CTX_DIR="$SPARK_HOME/target/tmp/docker" + +function is_dev_build { + [ ! -f "$SPARK_HOME/RELEASE" ] +} + +function cleanup_ctx_dir { + if is_dev_build; then + rm -rf "$CTX_DIR" + fi +} + +trap cleanup_ctx_dir EXIT + function image_ref { local image="$1" local add_repo="${2:-1}" @@ -53,80 +67,114 @@ function docker_push { fi } +# Create a smaller build context for docker in dev builds to make the build faster. Docker +# uploads all of the current directory to the daemon, and it can get pretty big with dev +# builds that contain test log files and other artifacts. +# +# Three build contexts are created, one for each image: base, pyspark, and sparkr. For them +# to have the desired effect, the docker command needs to be executed inside the appropriate +# context directory. +# +# Note: docker does not support symlinks in the build context. +function create_dev_build_context {( + set -e + local BASE_CTX="$CTX_DIR/base" + mkdir -p "$BASE_CTX/kubernetes" + cp -r "resource-managers/kubernetes/docker/src/main/dockerfiles" \ + "$BASE_CTX/kubernetes/dockerfiles" + + cp -r "assembly/target/scala-$SPARK_SCALA_VERSION/jars" "$BASE_CTX/jars" + cp -r "resource-managers/kubernetes/integration-tests/tests" \ + "$BASE_CTX/kubernetes/tests" + + mkdir "$BASE_CTX/examples" + cp -r "examples/src" "$BASE_CTX/examples/src" + # Copy just needed examples jars instead of everything. + mkdir "$BASE_CTX/examples/jars" + for i in examples/target/scala-$SPARK_SCALA_VERSION/jars/*; do + if [ ! -f "$BASE_CTX/jars/$(basename $i)" ]; then + cp $i "$BASE_CTX/examples/jars" + fi + done + + for other in bin sbin data; do + cp -r "$other" "$BASE_CTX/$other" + done + + local PYSPARK_CTX="$CTX_DIR/pyspark" + mkdir -p "$PYSPARK_CTX/kubernetes" + cp -r "resource-managers/kubernetes/docker/src/main/dockerfiles" \ + "$PYSPARK_CTX/kubernetes/dockerfiles" + mkdir "$PYSPARK_CTX/python" + cp -r "python/lib" "$PYSPARK_CTX/python/lib" + + local R_CTX="$CTX_DIR/sparkr" + mkdir -p "$R_CTX/kubernetes" + cp -r "resource-managers/kubernetes/docker/src/main/dockerfiles" \ + "$R_CTX/kubernetes/dockerfiles" + cp -r "R" "$R_CTX/R" +)} + +function img_ctx_dir { + if is_dev_build; then + echo "$CTX_DIR/$1" + else + echo "$SPARK_HOME" + fi +} + function build { local BUILD_ARGS - local IMG_PATH - local JARS - - if [ ! -f "$SPARK_HOME/RELEASE" ]; then - # Set image build arguments accordingly if this is a source repo and not a distribution archive. - # - # Note that this will copy all of the example jars directory into the image, and that will - # contain a lot of duplicated jars with the main Spark directory. In a proper distribution, - # the examples directory is cleaned up before generating the distribution tarball, so this - # issue does not occur. - IMG_PATH=resource-managers/kubernetes/docker/src/main/dockerfiles - JARS=assembly/target/scala-$SPARK_SCALA_VERSION/jars - BUILD_ARGS=( - ${BUILD_PARAMS} - --build-arg - img_path=$IMG_PATH - --build-arg - spark_jars=$JARS - --build-arg - example_jars=examples/target/scala-$SPARK_SCALA_VERSION/jars - --build-arg - k8s_tests=resource-managers/kubernetes/integration-tests/tests - ) - else - # Not passed as arguments to docker, but used to validate the Spark directory. - IMG_PATH="kubernetes/dockerfiles" - JARS=jars - BUILD_ARGS=(${BUILD_PARAMS}) + local SPARK_ROOT="$SPARK_HOME" + + if is_dev_build; then + create_dev_build_context || error "Failed to create docker build context." + SPARK_ROOT="$CTX_DIR/base" fi # Verify that the Docker image content directory is present - if [ ! -d "$IMG_PATH" ]; then + if [ ! -d "$SPARK_ROOT/kubernetes/dockerfiles" ]; then error "Cannot find docker image. This script must be run from a runnable distribution of Apache Spark." fi # Verify that Spark has actually been built/is a runnable distribution # i.e. the Spark JARs that the Docker files will place into the image are present - local TOTAL_JARS=$(ls $JARS/spark-* | wc -l) + local TOTAL_JARS=$(ls $SPARK_ROOT/jars/spark-* | wc -l) TOTAL_JARS=$(( $TOTAL_JARS )) if [ "${TOTAL_JARS}" -eq 0 ]; then error "Cannot find Spark JARs. This script assumes that Apache Spark has first been built locally or this is a runnable distribution." fi + local BUILD_ARGS=(${BUILD_PARAMS}) local BINDING_BUILD_ARGS=( ${BUILD_PARAMS} --build-arg base_img=$(image_ref spark) ) - local BASEDOCKERFILE=${BASEDOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} + local BASEDOCKERFILE=${BASEDOCKERFILE:-"kubernetes/dockerfiles/spark/Dockerfile"} local PYDOCKERFILE=${PYDOCKERFILE:-false} local RDOCKERFILE=${RDOCKERFILE:-false} - docker build $NOCACHEARG "${BUILD_ARGS[@]}" \ + (cd $(img_ctx_dir base) && docker build $NOCACHEARG "${BUILD_ARGS[@]}" \ -t $(image_ref spark) \ - -f "$BASEDOCKERFILE" . + -f "$BASEDOCKERFILE" .) if [ $? -ne 0 ]; then error "Failed to build Spark JVM Docker image, please refer to Docker build output for details." fi if [ "${PYDOCKERFILE}" != "false" ]; then - docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ + (cd $(img_ctx_dir pyspark) && docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ -t $(image_ref spark-py) \ - -f "$PYDOCKERFILE" . + -f "$PYDOCKERFILE" .) if [ $? -ne 0 ]; then error "Failed to build PySpark Docker image, please refer to Docker build output for details." fi fi if [ "${RDOCKERFILE}" != "false" ]; then - docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ + (cd $(img_ctx_dir sparkr) && docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ -t $(image_ref spark-r) \ - -f "$RDOCKERFILE" . + -f "$RDOCKERFILE" .) if [ $? -ne 0 ]; then error "Failed to build SparkR Docker image, please refer to Docker build output for details." fi diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 08e22fab65165..bb834bc483f1f 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -516,7 +516,8 @@ object KubernetesIntegrationTests { s"-Dspark.kubernetes.test.unpackSparkDir=$sparkHome" ), // Force packaging before building images, so that the latest code is tested. - dockerBuild := dockerBuild.dependsOn(packageBin in Compile in assembly).value + dockerBuild := dockerBuild.dependsOn(packageBin in Compile in assembly) + .dependsOn(packageBin in Compile in examples).value ) } diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index 5f469c30a96fa..89b20e1446229 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -17,11 +17,6 @@ FROM openjdk:8-alpine -ARG spark_jars=jars -ARG example_jars=examples/jars -ARG img_path=kubernetes/dockerfiles -ARG k8s_tests=kubernetes/tests - # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. # If this docker file is being used in the context of building your images from a Spark @@ -41,13 +36,12 @@ RUN set -ex && \ echo "auth required pam_wheel.so use_uid" >> /etc/pam.d/su && \ chgrp root /etc/passwd && chmod ug+rw /etc/passwd -COPY ${spark_jars} /opt/spark/jars +COPY jars /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin -COPY ${img_path}/spark/entrypoint.sh /opt/ -COPY ${example_jars} /opt/spark/examples/jars -COPY examples/src /opt/spark/examples/src -COPY ${k8s_tests} /opt/spark/tests +COPY kubernetes/dockerfiles/spark/entrypoint.sh /opt/ +COPY examples /opt/spark/examples +COPY kubernetes/tests /opt/spark/tests COPY data /opt/spark/data ENV SPARK_HOME /opt/spark From 8c6871828e3eb9fdb3bc665441a1aaf60b86b1e7 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Wed, 28 Nov 2018 13:37:11 +0800 Subject: [PATCH 2150/2461] [SPARK-26159] Codegen for LocalTableScanExec and RDDScanExec ## What changes were proposed in this pull request? Implement codegen for `LocalTableScanExec` and `ExistingRDDExec`. Refactor to share code between `LocalTableScanExec`, `ExistingRDDExec`, `InputAdapter` and `RowDataSourceScanExec`. The difference in `doProduce` between these four was that `ExistingRDDExec` and `RowDataSourceScanExec` triggered adding an `UnsafeProjection`, while `InputAdapter` and `LocalTableScanExec` did not. In the new trait `InputRDDCodegen` I added a flag `createUnsafeProjection` which the operators set accordingly. Note: `LocalTableScanExec` explicitly creates its input as `UnsafeRows`, so it was obvious why it doesn't need an `UnsafeProjection`. But if an `InputAdapter` may take input that is `InternalRows` but not `UnsafeRows`, then I think it doesn't need an unsafe projection just because any other operator that is its parent would do that. That assumes that that any parent operator would always result in some `UnsafeProjection` being eventually added, and hence the output of the `WholeStageCodegen` unit would be `UnsafeRows`. If these assumptions hold, I think `createUnsafeProjection` could be set to `(parent == null)`. Note: Do not codegen `LocalTableScanExec` when it's the only operator. `LocalTableScanExec` has optimized driver-only `executeCollect` and `executeTake` code paths that are used to return `Command` results without starting Spark Jobs. They can no longer be used if the `LocalTableScanExec` gets optimized. ## How was this patch tested? Covered and used in existing tests. Closes #23127 from juliuszsompolski/SPARK-26159. Authored-by: Juliusz Sompolski Signed-off-by: Wenchen Fan --- python/pyspark/sql/dataframe.py | 2 +- .../sql/execution/DataSourceScanExec.scala | 28 +------ .../spark/sql/execution/ExistingRDD.scala | 7 +- .../sql/execution/LocalTableScanExec.scala | 10 ++- .../sql/execution/WholeStageCodegenExec.scala | 78 ++++++++++++++----- .../sql-tests/results/operators.sql.out | 12 +-- 6 files changed, 86 insertions(+), 51 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ca15b36699166..b8833a39078ba 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -257,7 +257,7 @@ def explain(self, extended=False): >>> df.explain() == Physical Plan == - Scan ExistingRDD[age#0,name#1] + *(1) Scan ExistingRDD[age#0,name#1] >>> df.explain(True) == Parsed Logical Plan == diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 77e381ef6e6b4..4faa27c2c1e23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -84,7 +84,7 @@ case class RowDataSourceScanExec( rdd: RDD[InternalRow], @transient relation: BaseRelation, override val tableIdentifier: Option[TableIdentifier]) - extends DataSourceScanExec { + extends DataSourceScanExec with InputRDDCodegen { def output: Seq[Attribute] = requiredColumnsIndex.map(fullOutput) @@ -104,30 +104,10 @@ case class RowDataSourceScanExec( } } - override def inputRDDs(): Seq[RDD[InternalRow]] = { - rdd :: Nil - } + // Input can be InternalRow, has to be turned into UnsafeRows. + override protected val createUnsafeProjection: Boolean = true - override protected def doProduce(ctx: CodegenContext): String = { - val numOutputRows = metricTerm(ctx, "numOutputRows") - // PhysicalRDD always just has one input - val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];") - val exprRows = output.zipWithIndex.map{ case (a, i) => - BoundReference(i, a.dataType, a.nullable) - } - val row = ctx.freshName("row") - ctx.INPUT_ROW = row - ctx.currentVars = null - val columnsRowInput = exprRows.map(_.genCode(ctx)) - s""" - |while ($input.hasNext()) { - | InternalRow $row = (InternalRow) $input.next(); - | $numOutputRows.add(1); - | ${consume(ctx, columnsRowInput).trim} - | if (shouldStop()) return; - |} - """.stripMargin - } + override def inputRDD: RDD[InternalRow] = rdd override val metadata: Map[String, String] = { val markedFilters = for (filter <- filters) yield { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 9f67d556af362..e214bfd050410 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -175,7 +175,7 @@ case class RDDScanExec( rdd: RDD[InternalRow], name: String, override val outputPartitioning: Partitioning = UnknownPartitioning(0), - override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode { + override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode with InputRDDCodegen { private def rddName: String = Option(rdd.name).map(n => s" $n").getOrElse("") @@ -199,4 +199,9 @@ case class RDDScanExec( override def simpleString: String = { s"$nodeName${truncatedString(output, "[", ",", "]")}" } + + // Input can be InternalRow, has to be turned into UnsafeRows. + override protected val createUnsafeProjection: Boolean = true + + override def inputRDD: RDD[InternalRow] = rdd } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index 448eb703eacde..31640db3722ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics */ case class LocalTableScanExec( output: Seq[Attribute], - @transient rows: Seq[InternalRow]) extends LeafExecNode { + @transient rows: Seq[InternalRow]) extends LeafExecNode with InputRDDCodegen { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -76,4 +76,12 @@ case class LocalTableScanExec( longMetric("numOutputRows").add(taken.size) taken } + + // Input is already UnsafeRows. + override protected val createUnsafeProjection: Boolean = false + + // Do not codegen when there is no parent - to support the fast driver-local collect/take paths. + override def supportCodegen: Boolean = (parent != null) + + override def inputRDD: RDD[InternalRow] = rdd } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 29bcbcae366c5..fbda0d87a175f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -350,6 +350,15 @@ trait CodegenSupport extends SparkPlan { */ def needStopCheck: Boolean = parent.needStopCheck + /** + * Helper default should stop check code. + */ + def shouldStopCheckCode: String = if (needStopCheck) { + "if (shouldStop()) return;" + } else { + "// shouldStop check is eliminated" + } + /** * A sequence of checks which evaluate to true if the downstream Limit operators have not received * enough records and reached the limit. If current node is a data producing node, it can leverage @@ -406,6 +415,53 @@ trait BlockingOperatorWithCodegen extends CodegenSupport { override def limitNotReachedChecks: Seq[String] = Nil } +/** + * Leaf codegen node reading from a single RDD. + */ +trait InputRDDCodegen extends CodegenSupport { + + def inputRDD: RDD[InternalRow] + + // If the input can be InternalRows, an UnsafeProjection needs to be created. + protected val createUnsafeProjection: Boolean + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + inputRDD :: Nil + } + + override def doProduce(ctx: CodegenContext): String = { + // Inline mutable state since an InputRDDCodegen is used once in a task for WholeStageCodegen + val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];", + forceInline = true) + val row = ctx.freshName("row") + + val outputVars = if (createUnsafeProjection) { + // creating the vars will make the parent consume add an unsafe projection. + ctx.INPUT_ROW = row + ctx.currentVars = null + output.zipWithIndex.map { case (a, i) => + BoundReference(i, a.dataType, a.nullable).genCode(ctx) + } + } else { + null + } + + val updateNumOutputRowsMetrics = if (metrics.contains("numOutputRows")) { + val numOutputRows = metricTerm(ctx, "numOutputRows") + s"$numOutputRows.add(1);" + } else { + "" + } + s""" + | while ($limitNotReachedCond $input.hasNext()) { + | InternalRow $row = (InternalRow) $input.next(); + | ${updateNumOutputRowsMetrics} + | ${consume(ctx, outputVars, if (createUnsafeProjection) null else row).trim} + | ${shouldStopCheckCode} + | } + """.stripMargin + } +} /** * InputAdapter is used to hide a SparkPlan from a subtree that supports codegen. @@ -413,7 +469,7 @@ trait BlockingOperatorWithCodegen extends CodegenSupport { * This is the leaf node of a tree with WholeStageCodegen that is used to generate code * that consumes an RDD iterator of InternalRow. */ -case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupport { +case class InputAdapter(child: SparkPlan) extends UnaryExecNode with InputRDDCodegen { override def output: Seq[Attribute] = child.output @@ -429,24 +485,10 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp child.doExecuteBroadcast() } - override def inputRDDs(): Seq[RDD[InternalRow]] = { - child.execute() :: Nil - } + override def inputRDD: RDD[InternalRow] = child.execute() - override def doProduce(ctx: CodegenContext): String = { - // Right now, InputAdapter is only used when there is one input RDD. - // Inline mutable state since an InputAdapter is used once in a task for WholeStageCodegen - val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];", - forceInline = true) - val row = ctx.freshName("row") - s""" - | while ($limitNotReachedCond $input.hasNext()) { - | InternalRow $row = (InternalRow) $input.next(); - | ${consume(ctx, null, row).trim} - | if (shouldStop()) return; - | } - """.stripMargin - } + // InputAdapter does not need UnsafeProjection. + protected val createUnsafeProjection: Boolean = false override def generateTreeString( depth: Int, diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index fd1d0db9e3f78..570b281353f3d 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -201,7 +201,7 @@ struct -- !query 24 output == Physical Plan == *Project [null AS (CAST(concat(a, CAST(1 AS STRING)) AS DOUBLE) + CAST(2 AS DOUBLE))#x] -+- Scan OneRowRelation[] ++- *Scan OneRowRelation[] -- !query 25 @@ -211,7 +211,7 @@ struct -- !query 25 output == Physical Plan == *Project [-1b AS concat(CAST((1 - 2) AS STRING), b)#x] -+- Scan OneRowRelation[] ++- *Scan OneRowRelation[] -- !query 26 @@ -221,7 +221,7 @@ struct -- !query 26 output == Physical Plan == *Project [11b AS concat(CAST(((2 * 4) + 3) AS STRING), b)#x] -+- Scan OneRowRelation[] ++- *Scan OneRowRelation[] -- !query 27 @@ -231,7 +231,7 @@ struct -- !query 27 output == Physical Plan == *Project [4a2.0 AS concat(concat(CAST((3 + 1) AS STRING), a), CAST((CAST(4 AS DOUBLE) / CAST(2 AS DOUBLE)) AS STRING))#x] -+- Scan OneRowRelation[] ++- *Scan OneRowRelation[] -- !query 28 @@ -241,7 +241,7 @@ struct -- !query 28 output == Physical Plan == *Project [true AS ((1 = 1) OR (concat(a, b) = ab))#x] -+- Scan OneRowRelation[] ++- *Scan OneRowRelation[] -- !query 29 @@ -251,7 +251,7 @@ struct -- !query 29 output == Physical Plan == *Project [false AS ((concat(a, c) = ac) AND (2 = 3))#x] -+- Scan OneRowRelation[] ++- *Scan OneRowRelation[] -- !query 30 From 09a91d98bdecb86ecad4647b7ef5fb3f69bdc671 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 28 Nov 2018 16:21:42 +0800 Subject: [PATCH 2151/2461] [SPARK-26021][SQL][FOLLOWUP] add test for special floating point values ## What changes were proposed in this pull request? a followup of https://github.com/apache/spark/pull/23043 . Add a test to show the minor behavior change introduced by #23043 , and add migration guide. ## How was this patch tested? a new test Closes #23141 from cloud-fan/follow. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../spark/unsafe/PlatformUtilSuite.java | 12 +++++--- docs/sql-migration-guide-upgrade.md | 6 ++-- .../catalyst/expressions/UnsafeArrayData.java | 6 ---- .../spark/sql/DatasetPrimitiveSuite.scala | 29 +++++++++++++++++++ .../org/apache/spark/sql/QueryTest.scala | 7 +++++ 5 files changed, 48 insertions(+), 12 deletions(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index ab34324eb54cc..2474081dad5c9 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -165,10 +165,14 @@ public void writeMinusZeroIsReplacedWithZero() { byte[] floatBytes = new byte[Float.BYTES]; Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, -0.0d); Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f); - double doubleFromPlatform = Platform.getDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET); - float floatFromPlatform = Platform.getFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET); - Assert.assertEquals(Double.doubleToLongBits(0.0d), Double.doubleToLongBits(doubleFromPlatform)); - Assert.assertEquals(Float.floatToIntBits(0.0f), Float.floatToIntBits(floatFromPlatform)); + byte[] doubleBytes2 = new byte[Double.BYTES]; + byte[] floatBytes2 = new byte[Float.BYTES]; + Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, 0.0d); + Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, 0.0f); + + // Make sure the bytes we write from 0.0 and -0.0 are same. + Assert.assertArrayEquals(doubleBytes, doubleBytes2); + Assert.assertArrayEquals(floatBytes, floatBytes2); } } diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 68cb8f5a0d18c..25cd541190919 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -17,14 +17,16 @@ displayTitle: Spark SQL Upgrading Guide - Since Spark 3.0, the `from_json` functions supports two modes - `PERMISSIVE` and `FAILFAST`. The modes can be set via the `mode` option. The default mode became `PERMISSIVE`. In previous versions, behavior of `from_json` did not conform to either `PERMISSIVE` nor `FAILFAST`, especially in processing of malformed JSON records. For example, the JSON string `{"a" 1}` with the schema `a INT` is converted to `null` by previous versions but Spark 3.0 converts it to `Row(null)`. - - In Spark version 2.4 and earlier, the `from_json` function produces `null`s for JSON strings and JSON datasource skips the same independetly of its mode if there is no valid root JSON token in its input (` ` for example). Since Spark 3.0, such input is treated as a bad record and handled according to specified mode. For example, in the `PERMISSIVE` mode the ` ` input is converted to `Row(null, null)` if specified schema is `key STRING, value INT`. + - In Spark version 2.4 and earlier, the `from_json` function produces `null`s for JSON strings and JSON datasource skips the same independetly of its mode if there is no valid root JSON token in its input (` ` for example). Since Spark 3.0, such input is treated as a bad record and handled according to specified mode. For example, in the `PERMISSIVE` mode the ` ` input is converted to `Row(null, null)` if specified schema is `key STRING, value INT`. - The `ADD JAR` command previously returned a result set with the single value 0. It now returns an empty result set. - In Spark version 2.4 and earlier, users can create map values with map type key via built-in function like `CreateMap`, `MapFromArrays`, etc. Since Spark 3.0, it's not allowed to create map values with map type key with these built-in functions. Users can still read map values with map type key from data source or Java/Scala collections, though they are not very useful. - + - In Spark version 2.4 and earlier, `Dataset.groupByKey` results to a grouped dataset with key attribute wrongly named as "value", if the key is non-struct type, e.g. int, string, array, etc. This is counterintuitive and makes the schema of aggregation queries weird. For example, the schema of `ds.groupByKey(...).count()` is `(value, count)`. Since Spark 3.0, we name the grouping attribute to "key". The old behaviour is preserved under a newly added configuration `spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue` with a default value of `false`. + - In Spark version 2.4 and earlier, float/double -0.0 is semantically equal to 0.0, but users can still distinguish them via `Dataset.show`, `Dataset.collect` etc. Since Spark 3.0, float/double -0.0 is replaced by 0.0 internally, and users can't distinguish them any more. + ## Upgrading From Spark SQL 2.3 to 2.4 - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below. diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 9002abdcfd474..d5f679fe23d48 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -334,17 +334,11 @@ public void setLong(int ordinal, long value) { } public void setFloat(int ordinal, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } assertIndexIsValid(ordinal); Platform.putFloat(baseObject, getElementOffset(ordinal, 4), value); } public void setDouble(int ordinal, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } assertIndexIsValid(ordinal); Platform.putDouble(baseObject, getElementOffset(ordinal, 8), value); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 96a6792f52f3e..0ded5d8ce1e28 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -393,4 +393,33 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { val ds = spark.createDataset(data) checkDataset(ds, data: _*) } + + test("special floating point values") { + import org.scalatest.exceptions.TestFailedException + + // Spark treats -0.0 as 0.0 + intercept[TestFailedException] { + checkDataset(Seq(-0.0d).toDS(), -0.0d) + } + intercept[TestFailedException] { + checkDataset(Seq(-0.0f).toDS(), -0.0f) + } + intercept[TestFailedException] { + checkDataset(Seq(Tuple1(-0.0)).toDS(), Tuple1(-0.0)) + } + + val floats = Seq[Float](-0.0f, 0.0f, Float.NaN).toDS() + checkDataset(floats, 0.0f, 0.0f, Float.NaN) + + val doubles = Seq[Double](-0.0d, 0.0d, Double.NaN).toDS() + checkDataset(doubles, 0.0, 0.0, Double.NaN) + + checkDataset(Seq(Tuple1(Float.NaN)).toDS(), Tuple1(Float.NaN)) + checkDataset(Seq(Tuple1(-0.0f)).toDS(), Tuple1(0.0f)) + checkDataset(Seq(Tuple1(Double.NaN)).toDS(), Tuple1(Double.NaN)) + checkDataset(Seq(Tuple1(-0.0)).toDS(), Tuple1(0.0)) + + val complex = Map(Array(Seq(Tuple1(Double.NaN))) -> Map(Tuple2(Float.NaN, null))) + checkDataset(Seq(complex).toDS(), complex) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 8ba67239fb907..a547676c5ed5c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -132,6 +132,13 @@ abstract class QueryTest extends PlanTest { a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)} case (a: Iterable[_], b: Iterable[_]) => a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)} + case (a: Product, b: Product) => + compare(a.productIterator.toSeq, b.productIterator.toSeq) + // 0.0 == -0.0, turn float/double to binary before comparison, to distinguish 0.0 and -0.0. + case (a: Double, b: Double) => + java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b) + case (a: Float, b: Float) => + java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b) case (a, b) => a == b } From 93112e693082f3fba24cebaf9a98dcf5c1eb84af Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Wed, 28 Nov 2018 20:18:13 +0800 Subject: [PATCH 2152/2461] [SPARK-26142][SQL] Implement shuffle read metrics in SQL ## What changes were proposed in this pull request? Implement `SQLShuffleMetricsReporter` on the sql side as the customized ShuffleMetricsReporter, which extended the `TempShuffleReadMetrics` and update SQLMetrics, in this way shuffle metrics can be reported in the SQL UI. ## How was this patch tested? Add UT in SQLMetricsSuite. Manual test locally, before: ![image](https://user-images.githubusercontent.com/4833765/48960517-30f97880-efa8-11e8-982c-92d05938fd1d.png) after: ![image](https://user-images.githubusercontent.com/4833765/48960587-b54bfb80-efa8-11e8-8e95-7a3c8c74cc5c.png) Closes #23128 from xuanyuanking/SPARK-26142. Lead-authored-by: Yuanjian Li Co-authored-by: liyuanjian Signed-off-by: Wenchen Fan --- .../spark/sql/execution/ShuffledRowRDD.scala | 9 ++- .../exchange/ShuffleExchangeExec.scala | 5 +- .../apache/spark/sql/execution/limit.scala | 10 ++- .../sql/execution/metric/SQLMetrics.scala | 20 ++++++ .../metric/SQLShuffleMetricsReporter.scala | 67 +++++++++++++++++++ .../execution/UnsafeRowSerializerSuite.scala | 5 +- .../execution/metric/SQLMetricsSuite.scala | 21 ++++-- 7 files changed, 126 insertions(+), 11 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala index 542266bc1ae07..9b05faaed0459 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -22,6 +22,7 @@ import java.util.Arrays import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleMetricsReporter} /** * The [[Partition]] used by [[ShuffledRowRDD]]. A post-shuffle partition @@ -112,6 +113,7 @@ class CoalescedPartitioner(val parent: Partitioner, val partitionStartIndices: A */ class ShuffledRowRDD( var dependency: ShuffleDependency[Int, InternalRow, InternalRow], + metrics: Map[String, SQLMetric], specifiedPartitionStartIndices: Option[Array[Int]] = None) extends RDD[InternalRow](dependency.rdd.context, Nil) { @@ -154,7 +156,10 @@ class ShuffledRowRDD( override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { val shuffledRowPartition = split.asInstanceOf[ShuffledRowRDDPartition] - val metrics = context.taskMetrics().createTempShuffleReadMetrics() + val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() + // `SQLShuffleMetricsReporter` will update its own metrics for SQL exchange operator, + // as well as the `tempMetrics` for basic shuffle metrics. + val sqlMetricsReporter = new SQLShuffleMetricsReporter(tempMetrics, metrics) // The range of pre-shuffle partitions that we are fetching at here is // [startPreShufflePartitionIndex, endPreShufflePartitionIndex - 1]. val reader = @@ -163,7 +168,7 @@ class ShuffledRowRDD( shuffledRowPartition.startPreShufflePartitionIndex, shuffledRowPartition.endPreShufflePartitionIndex, context, - metrics) + sqlMetricsReporter) reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index d6742ab3e0f31..8938d93da90eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -48,7 +48,8 @@ case class ShuffleExchangeExec( // e.g. it can be null on the Executor side override lazy val metrics = Map( - "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size")) + "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size") + ) ++ SQLMetrics.getShuffleReadMetrics(sparkContext) override def nodeName: String = { val extraInfo = coordinator match { @@ -108,7 +109,7 @@ case class ShuffleExchangeExec( assert(newPartitioning.isInstanceOf[HashPartitioning]) newPartitioning = UnknownPartitioning(indices.length) } - new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices) + new ShuffledRowRDD(shuffleDependency, metrics, specifiedPartitionStartIndices) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 90dafcf535914..ea845da8438fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.metric.SQLMetrics /** * Take the first `limit` elements and collect them to a single partition. @@ -37,11 +38,13 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode override def outputPartitioning: Partitioning = SinglePartition override def executeCollect(): Array[InternalRow] = child.executeTake(limit) private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) + override lazy val metrics = SQLMetrics.getShuffleReadMetrics(sparkContext) protected override def doExecute(): RDD[InternalRow] = { val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit)) val shuffled = new ShuffledRowRDD( ShuffleExchangeExec.prepareShuffleDependency( - locallyLimited, child.output, SinglePartition, serializer)) + locallyLimited, child.output, SinglePartition, serializer), + metrics) shuffled.mapPartitionsInternal(_.take(limit)) } } @@ -151,6 +154,8 @@ case class TakeOrderedAndProjectExec( private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) + override lazy val metrics = SQLMetrics.getShuffleReadMetrics(sparkContext) + protected override def doExecute(): RDD[InternalRow] = { val ord = new LazilyGeneratedOrdering(sortOrder, child.output) val localTopK: RDD[InternalRow] = { @@ -160,7 +165,8 @@ case class TakeOrderedAndProjectExec( } val shuffled = new ShuffledRowRDD( ShuffleExchangeExec.prepareShuffleDependency( - localTopK, child.output, SinglePartition, serializer)) + localTopK, child.output, SinglePartition, serializer), + metrics) shuffled.mapPartitions { iter => val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord) if (projectList != child.output) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index cbf707f4a9cfd..0b5ee3a5e0577 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -82,6 +82,14 @@ object SQLMetrics { private val baseForAvgMetric: Int = 10 + val REMOTE_BLOCKS_FETCHED = "remoteBlocksFetched" + val LOCAL_BLOCKS_FETCHED = "localBlocksFetched" + val REMOTE_BYTES_READ = "remoteBytesRead" + val REMOTE_BYTES_READ_TO_DISK = "remoteBytesReadToDisk" + val LOCAL_BYTES_READ = "localBytesRead" + val FETCH_WAIT_TIME = "fetchWaitTime" + val RECORDS_READ = "recordsRead" + /** * Converts a double value to long value by multiplying a base integer, so we can store it in * `SQLMetrics`. It only works for average metrics. When showing the metrics on UI, we restore @@ -194,4 +202,16 @@ object SQLMetrics { SparkListenerDriverAccumUpdates(executionId.toLong, metrics.map(m => m.id -> m.value))) } } + + /** + * Create all shuffle read relative metrics and return the Map. + */ + def getShuffleReadMetrics(sc: SparkContext): Map[String, SQLMetric] = Map( + REMOTE_BLOCKS_FETCHED -> createMetric(sc, "remote blocks fetched"), + LOCAL_BLOCKS_FETCHED -> createMetric(sc, "local blocks fetched"), + REMOTE_BYTES_READ -> createSizeMetric(sc, "remote bytes read"), + REMOTE_BYTES_READ_TO_DISK -> createSizeMetric(sc, "remote bytes read to disk"), + LOCAL_BYTES_READ -> createSizeMetric(sc, "local bytes read"), + FETCH_WAIT_TIME -> createTimingMetric(sc, "fetch wait time"), + RECORDS_READ -> createMetric(sc, "records read")) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala new file mode 100644 index 0000000000000..542141ea4b4e6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.metric + +import org.apache.spark.executor.TempShuffleReadMetrics + +/** + * A shuffle metrics reporter for SQL exchange operators. + * @param tempMetrics [[TempShuffleReadMetrics]] created in TaskContext. + * @param metrics All metrics in current SparkPlan. This param should not empty and + * contains all shuffle metrics defined in [[SQLMetrics.getShuffleReadMetrics]]. + */ +private[spark] class SQLShuffleMetricsReporter( + tempMetrics: TempShuffleReadMetrics, + metrics: Map[String, SQLMetric]) extends TempShuffleReadMetrics { + private[this] val _remoteBlocksFetched = metrics(SQLMetrics.REMOTE_BLOCKS_FETCHED) + private[this] val _localBlocksFetched = metrics(SQLMetrics.LOCAL_BLOCKS_FETCHED) + private[this] val _remoteBytesRead = metrics(SQLMetrics.REMOTE_BYTES_READ) + private[this] val _remoteBytesReadToDisk = metrics(SQLMetrics.REMOTE_BYTES_READ_TO_DISK) + private[this] val _localBytesRead = metrics(SQLMetrics.LOCAL_BYTES_READ) + private[this] val _fetchWaitTime = metrics(SQLMetrics.FETCH_WAIT_TIME) + private[this] val _recordsRead = metrics(SQLMetrics.RECORDS_READ) + + override def incRemoteBlocksFetched(v: Long): Unit = { + _remoteBlocksFetched.add(v) + tempMetrics.incRemoteBlocksFetched(v) + } + override def incLocalBlocksFetched(v: Long): Unit = { + _localBlocksFetched.add(v) + tempMetrics.incLocalBlocksFetched(v) + } + override def incRemoteBytesRead(v: Long): Unit = { + _remoteBytesRead.add(v) + tempMetrics.incRemoteBytesRead(v) + } + override def incRemoteBytesReadToDisk(v: Long): Unit = { + _remoteBytesReadToDisk.add(v) + tempMetrics.incRemoteBytesReadToDisk(v) + } + override def incLocalBytesRead(v: Long): Unit = { + _localBytesRead.add(v) + tempMetrics.incLocalBytesRead(v) + } + override def incFetchWaitTime(v: Long): Unit = { + _fetchWaitTime.add(v) + tempMetrics.incFetchWaitTime(v) + } + override def incRecordsRead(v: Long): Unit = { + _recordsRead.add(v) + tempMetrics.incRecordsRead(v) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index d305ce3e698ae..96b3aa5ee75b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{LocalSparkSession, Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types._ import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.collection.ExternalSorter @@ -137,7 +138,9 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession { rowsRDD, new PartitionIdPassthrough(2), new UnsafeRowSerializer(2)) - val shuffled = new ShuffledRowRDD(dependency) + val shuffled = new ShuffledRowRDD( + dependency, + SQLMetrics.getShuffleReadMetrics(spark.sparkContext)) shuffled.count() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index b955c157a620e..0f1d08b6af5d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -94,8 +94,13 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"), Map("number of output rows" -> 1L, "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")) + val shuffleExpected1 = Map( + "records read" -> 2L, + "local blocks fetched" -> 2L, + "remote blocks fetched" -> 0L) testSparkPlanMetrics(df, 1, Map( 2L -> (("HashAggregate", expected1(0))), + 1L -> (("Exchange", shuffleExpected1)), 0L -> (("HashAggregate", expected1(1)))) ) @@ -106,8 +111,13 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"), Map("number of output rows" -> 3L, "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")) + val shuffleExpected2 = Map( + "records read" -> 4L, + "local blocks fetched" -> 4L, + "remote blocks fetched" -> 0L) testSparkPlanMetrics(df2, 1, Map( 2L -> (("HashAggregate", expected2(0))), + 1L -> (("Exchange", shuffleExpected2)), 0L -> (("HashAggregate", expected2(1)))) ) } @@ -191,7 +201,11 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared testSparkPlanMetrics(df, 1, Map( 0L -> (("SortMergeJoin", Map( // It's 4 because we only read 3 rows in the first partition and 1 row in the second one - "number of output rows" -> 4L)))) + "number of output rows" -> 4L))), + 2L -> (("Exchange", Map( + "records read" -> 4L, + "local blocks fetched" -> 2L, + "remote blocks fetched" -> 0L)))) ) } } @@ -208,7 +222,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared "SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df, 1, Map( 0L -> (("SortMergeJoin", Map( - // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + // It's 8 because we read 6 rows in the left and 2 row in the right one "number of output rows" -> 8L)))) ) @@ -216,7 +230,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared "SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df2, 1, Map( 0L -> (("SortMergeJoin", Map( - // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + // It's 8 because we read 6 rows in the left and 2 row in the right one "number of output rows" -> 8L)))) ) } @@ -287,7 +301,6 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared // Assume the execution plan is // ... -> ShuffledHashJoin(nodeId = 1) -> Project(nodeId = 0) val df = df1.join(df2, "key") - val metrics = getSparkPlanMetrics(df, 1, Set(1L)) testSparkPlanMetrics(df, 1, Map( 1L -> (("ShuffledHashJoin", Map( "number of output rows" -> 2L, From 438f8fd675d8f819373b6643dea3a77d954b6822 Mon Sep 17 00:00:00 2001 From: Sergey Zhemzhitsky Date: Wed, 28 Nov 2018 20:22:24 +0800 Subject: [PATCH 2153/2461] [SPARK-26114][CORE] ExternalSorter's readingIterator field leak ## What changes were proposed in this pull request? This pull request fixes [SPARK-26114](https://issues.apache.org/jira/browse/SPARK-26114) issue that occurs when trying to reduce the number of partitions by means of coalesce without shuffling after shuffle-based transformations. The leak occurs because of not cleaning up `ExternalSorter`'s `readingIterator` field as it's done for its `map` and `buffer` fields. Additionally there are changes to the `CompletionIterator` to prevent capturing its `sub`-iterator and holding it even after the completion iterator completes. It is necessary because in some cases, e.g. in case of standard scala's `flatMap` iterator (which is used is `CoalescedRDD`'s `compute` method) the next value of the main iterator is assigned to `flatMap`'s `cur` field only after it is available. For DAGs where ShuffledRDD is a parent of CoalescedRDD it means that the data should be fetched from the map-side of the shuffle, but the process of fetching this data consumes quite a lot of memory in addition to the memory already consumed by the iterator held by `flatMap`'s `cur` field (until it is reassigned). For the following data ```scala import org.apache.hadoop.io._ import org.apache.hadoop.io.compress._ import org.apache.commons.lang._ import org.apache.spark._ // generate 100M records of sample data sc.makeRDD(1 to 1000, 1000) .flatMap(item => (1 to 100000) .map(i => new Text(RandomStringUtils.randomAlphanumeric(3).toLowerCase) -> new Text(RandomStringUtils.randomAlphanumeric(1024)))) .saveAsSequenceFile("/tmp/random-strings", Some(classOf[GzipCodec])) ``` and the following job ```scala import org.apache.hadoop.io._ import org.apache.spark._ import org.apache.spark.storage._ val rdd = sc.sequenceFile("/tmp/random-strings", classOf[Text], classOf[Text]) rdd .map(item => item._1.toString -> item._2.toString) .repartitionAndSortWithinPartitions(new HashPartitioner(1000)) .coalesce(10,false) .count ``` ... executed like the following ```bash spark-shell \ --num-executors=5 \ --executor-cores=2 \ --master=yarn \ --deploy-mode=client \ --conf spark.executor.memoryOverhead=512 \ --conf spark.executor.memory=1g \ --conf spark.dynamicAllocation.enabled=false \ --conf spark.executor.extraJavaOptions='-XX:+HeapDumpOnOutOfMemoryError -XX:HeapDumpPath=/tmp -Dio.netty.noUnsafe=true' ``` ... executors are always failing with OutOfMemoryErrors. The main issue is multiple leaks of ExternalSorter references. For example, in case of 2 tasks per executor it is expected to be 2 simultaneous instances of ExternalSorter per executor but heap dump generated on OutOfMemoryError shows that there are more ones. ![run1-noparams-dominator-tree-externalsorter](https://user-images.githubusercontent.com/1523889/48703665-782ce580-ec05-11e8-95a9-d6c94e8285ab.png) P.S. This PR does not cover cases with CoGroupedRDDs which use ExternalAppendOnlyMap internally, which itself can lead to OutOfMemoryErrors in many places. ## How was this patch tested? - Existing unit tests - New unit tests - Job executions on the live environment Here is the screenshot before applying this patch ![run3-noparams-failure-ui-5x2-repartition-and-sort](https://user-images.githubusercontent.com/1523889/48700395-f769eb80-ebfc-11e8-831b-e94c757d416c.png) Here is the screenshot after applying this patch ![run3-noparams-success-ui-5x2-repartition-and-sort](https://user-images.githubusercontent.com/1523889/48700610-7a8b4180-ebfd-11e8-9761-baaf38a58e66.png) And in case of reducing the number of executors even more the job is still stable ![run3-noparams-success-ui-2x2-repartition-and-sort](https://user-images.githubusercontent.com/1523889/48700619-82e37c80-ebfd-11e8-98ed-a38e1f1f1fd9.png) Closes #23083 from szhem/SPARK-26114-externalsorter-leak. Authored-by: Sergey Zhemzhitsky Signed-off-by: Wenchen Fan --- .../spark/util/CompletionIterator.scala | 7 ++++-- .../util/collection/ExternalSorter.scala | 3 ++- .../spark/util/CompletionIteratorSuite.scala | 22 +++++++++++++++++++ 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala index 21acaa95c5645..f4d6c7a28d2e4 100644 --- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala @@ -25,11 +25,14 @@ private[spark] abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterator[A] { private[this] var completed = false - def next(): A = sub.next() + private[this] var iter = sub + def next(): A = iter.next() def hasNext: Boolean = { - val r = sub.hasNext + val r = iter.hasNext if (!r && !completed) { completed = true + // reassign to release resources of highly resource consuming iterators early + iter = Iterator.empty.asInstanceOf[I] completion() } r diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index eac3db01158d0..46279e79d78db 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -727,9 +727,10 @@ private[spark] class ExternalSorter[K, V, C]( spills.clear() forceSpillFiles.foreach(s => s.file.delete()) forceSpillFiles.clear() - if (map != null || buffer != null) { + if (map != null || buffer != null || readingIterator != null) { map = null // So that the memory can be garbage-collected buffer = null // So that the memory can be garbage-collected + readingIterator = null // So that the memory can be garbage-collected releaseMemory() } } diff --git a/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala index 688fcd9f9aaba..29421f7aa9e36 100644 --- a/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.util +import java.lang.ref.PhantomReference +import java.lang.ref.ReferenceQueue + import org.apache.spark.SparkFunSuite class CompletionIteratorSuite extends SparkFunSuite { @@ -44,4 +47,23 @@ class CompletionIteratorSuite extends SparkFunSuite { assert(!completionIter.hasNext) assert(numTimesCompleted === 1) } + test("reference to sub iterator should not be available after completion") { + var sub = Iterator(1, 2, 3) + + val refQueue = new ReferenceQueue[Iterator[Int]] + val ref = new PhantomReference[Iterator[Int]](sub, refQueue) + + val iter = CompletionIterator[Int, Iterator[Int]](sub, {}) + sub = null + iter.toArray + + for (_ <- 1 to 100 if !ref.isEnqueued) { + System.gc() + if (!ref.isEnqueued) { + Thread.sleep(10) + } + } + assert(ref.isEnqueued) + assert(refQueue.poll() === ref) + } } From affe80958d366f399466a9dba8e03da7f3b7b9bf Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 28 Nov 2018 20:38:42 +0800 Subject: [PATCH 2154/2461] [SPARK-26147][SQL] only pull out unevaluable python udf from join condition ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/22326 made a mistake that, not all python UDFs are unevaluable in join condition. Only python UDFs that refer to attributes from both join side are unevaluable. This PR fixes this mistake. ## How was this patch tested? a new test Closes #23153 from cloud-fan/join. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- python/pyspark/sql/tests/test_udf.py | 12 ++ .../spark/sql/catalyst/optimizer/joins.scala | 22 ++-- ...PullOutPythonUDFInJoinConditionSuite.scala | 120 ++++++++++++------ 3 files changed, 106 insertions(+), 48 deletions(-) diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index d2dfb52f54475..ed298f724d551 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -209,6 +209,18 @@ def test_udf_in_join_condition(self): with self.sql_conf({"spark.sql.crossJoin.enabled": True}): self.assertEqual(df.collect(), [Row(a=1, b=1)]) + def test_udf_in_left_outer_join_condition(self): + # regression test for SPARK-26147 + from pyspark.sql.functions import udf, col + left = self.spark.createDataFrame([Row(a=1)]) + right = self.spark.createDataFrame([Row(b=1)]) + f = udf(lambda a: str(a), StringType()) + # The join condition can't be pushed down, as it refers to attributes from both sides. + # The Python UDF only refer to attributes from one side, so it's evaluable. + df = left.join(right, f("a") == col("b").cast("string"), how="left_outer") + with self.sql_conf({"spark.sql.crossJoin.enabled": True}): + self.assertEqual(df.collect(), [Row(a=1, b=1)]) + def test_udf_in_left_semi_join_condition(self): # regression test for SPARK-25314 from pyspark.sql.functions import udf diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 7149edee0173e..6ebb194d71c2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -155,19 +155,20 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { } /** - * PythonUDF in join condition can not be evaluated, this rule will detect the PythonUDF - * and pull them out from join condition. For python udf accessing attributes from only one side, - * they are pushed down by operation push down rules. If not (e.g. user disables filter push - * down rules), we need to pull them out in this rule too. + * PythonUDF in join condition can't be evaluated if it refers to attributes from both join sides. + * See `ExtractPythonUDFs` for details. This rule will detect un-evaluable PythonUDF and pull them + * out from join condition. */ object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateHelper { - def hasPythonUDF(expression: Expression): Boolean = { - expression.collectFirst { case udf: PythonUDF => udf }.isDefined + + private def hasUnevaluablePythonUDF(expr: Expression, j: Join): Boolean = { + expr.find { e => + PythonUDF.isScalarPythonUDF(e) && !canEvaluate(e, j.left) && !canEvaluate(e, j.right) + }.isDefined } override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case j @ Join(_, _, joinType, condition) - if condition.isDefined && hasPythonUDF(condition.get) => + case j @ Join(_, _, joinType, Some(cond)) if hasUnevaluablePythonUDF(cond, j) => if (!joinType.isInstanceOf[InnerLike] && joinType != LeftSemi) { // The current strategy only support InnerLike and LeftSemi join because for other type, // it breaks SQL semantic if we run the join condition as a filter after join. If we pass @@ -179,10 +180,9 @@ object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateH } // If condition expression contains python udf, it will be moved out from // the new join conditions. - val (udf, rest) = - splitConjunctivePredicates(condition.get).partition(hasPythonUDF) + val (udf, rest) = splitConjunctivePredicates(cond).partition(hasUnevaluablePythonUDF(_, j)) val newCondition = if (rest.isEmpty) { - logWarning(s"The join condition:$condition of the join plan contains PythonUDF only," + + logWarning(s"The join condition:$cond of the join plan contains PythonUDF only," + s" it will be moved out and the join plan will be turned to cross join.") None } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala index d3867f2b6bd0e..3f1c91df7f2e9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.scalatest.Matchers._ - import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -28,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.internal.SQLConf._ -import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.{BooleanType, IntegerType} class PullOutPythonUDFInJoinConditionSuite extends PlanTest { @@ -40,13 +38,29 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest { CheckCartesianProducts) :: Nil } - val testRelationLeft = LocalRelation('a.int, 'b.int) - val testRelationRight = LocalRelation('c.int, 'd.int) + val attrA = 'a.int + val attrB = 'b.int + val attrC = 'c.int + val attrD = 'd.int + + val testRelationLeft = LocalRelation(attrA, attrB) + val testRelationRight = LocalRelation(attrC, attrD) + + // This join condition refers to attributes from 2 tables, but the PythonUDF inside it only + // refer to attributes from one side. + val evaluableJoinCond = { + val pythonUDF = PythonUDF("evaluable", null, + IntegerType, + Seq(attrA), + PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = true) + pythonUDF === attrC + } - // Dummy python UDF for testing. Unable to execute. - val pythonUDF = PythonUDF("pythonUDF", null, + // This join condition is a PythonUDF which refers to attributes from 2 tables. + val unevaluableJoinCond = PythonUDF("unevaluable", null, BooleanType, - Seq.empty, + Seq(attrA, attrC), PythonEvalType.SQL_BATCHED_UDF, udfDeterministic = true) @@ -66,62 +80,76 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest { } } - test("inner join condition with python udf only") { - val query = testRelationLeft.join( + test("inner join condition with python udf") { + val query1 = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = Some(pythonUDF)) - val expected = testRelationLeft.join( + condition = Some(unevaluableJoinCond)) + val expected1 = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = None).where(pythonUDF).analyze - comparePlanWithCrossJoinEnable(query, expected) + condition = None).where(unevaluableJoinCond).analyze + comparePlanWithCrossJoinEnable(query1, expected1) + + // evaluable PythonUDF will not be touched + val query2 = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(evaluableJoinCond)) + comparePlans(Optimize.execute(query2), query2) } - test("left semi join condition with python udf only") { - val query = testRelationLeft.join( + test("left semi join condition with python udf") { + val query1 = testRelationLeft.join( testRelationRight, joinType = LeftSemi, - condition = Some(pythonUDF)) - val expected = testRelationLeft.join( + condition = Some(unevaluableJoinCond)) + val expected1 = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = None).where(pythonUDF).select('a, 'b).analyze - comparePlanWithCrossJoinEnable(query, expected) + condition = None).where(unevaluableJoinCond).select('a, 'b).analyze + comparePlanWithCrossJoinEnable(query1, expected1) + + // evaluable PythonUDF will not be touched + val query2 = testRelationLeft.join( + testRelationRight, + joinType = LeftSemi, + condition = Some(evaluableJoinCond)) + comparePlans(Optimize.execute(query2), query2) } - test("python udf and common condition") { + test("unevaluable python udf and common condition") { val query = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = Some(pythonUDF && 'a.attr === 'c.attr)) + condition = Some(unevaluableJoinCond && 'a.attr === 'c.attr)) val expected = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = Some('a.attr === 'c.attr)).where(pythonUDF).analyze + condition = Some('a.attr === 'c.attr)).where(unevaluableJoinCond).analyze val optimized = Optimize.execute(query.analyze) comparePlans(optimized, expected) } - test("python udf or common condition") { + test("unevaluable python udf or common condition") { val query = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = Some(pythonUDF || 'a.attr === 'c.attr)) + condition = Some(unevaluableJoinCond || 'a.attr === 'c.attr)) val expected = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = None).where(pythonUDF || 'a.attr === 'c.attr).analyze + condition = None).where(unevaluableJoinCond || 'a.attr === 'c.attr).analyze comparePlanWithCrossJoinEnable(query, expected) } - test("pull out whole complex condition with multiple python udf") { + test("pull out whole complex condition with multiple unevaluable python udf") { val pythonUDF1 = PythonUDF("pythonUDF1", null, BooleanType, - Seq.empty, + Seq(attrA, attrC), PythonEvalType.SQL_BATCHED_UDF, udfDeterministic = true) - val condition = (pythonUDF || 'a.attr === 'c.attr) && pythonUDF1 + val condition = (unevaluableJoinCond || 'a.attr === 'c.attr) && pythonUDF1 val query = testRelationLeft.join( testRelationRight, @@ -134,13 +162,13 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest { comparePlanWithCrossJoinEnable(query, expected) } - test("partial pull out complex condition with multiple python udf") { + test("partial pull out complex condition with multiple unevaluable python udf") { val pythonUDF1 = PythonUDF("pythonUDF1", null, BooleanType, - Seq.empty, + Seq(attrA, attrC), PythonEvalType.SQL_BATCHED_UDF, udfDeterministic = true) - val condition = (pythonUDF || pythonUDF1) && 'a.attr === 'c.attr + val condition = (unevaluableJoinCond || pythonUDF1) && 'a.attr === 'c.attr val query = testRelationLeft.join( testRelationRight, @@ -149,23 +177,41 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest { val expected = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = Some('a.attr === 'c.attr)).where(pythonUDF || pythonUDF1).analyze + condition = Some('a.attr === 'c.attr)).where(unevaluableJoinCond || pythonUDF1).analyze + val optimized = Optimize.execute(query.analyze) + comparePlans(optimized, expected) + } + + test("pull out unevaluable python udf when it's mixed with evaluable one") { + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(evaluableJoinCond && unevaluableJoinCond)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(evaluableJoinCond)).where(unevaluableJoinCond).analyze val optimized = Optimize.execute(query.analyze) comparePlans(optimized, expected) } test("throw an exception for not support join type") { for (joinType <- unsupportedJoinTypes) { - val thrownException = the [AnalysisException] thrownBy { + val e = intercept[AnalysisException] { val query = testRelationLeft.join( testRelationRight, joinType, - condition = Some(pythonUDF)) + condition = Some(unevaluableJoinCond)) Optimize.execute(query.analyze) } - assert(thrownException.message.contentEquals( + assert(e.message.contentEquals( s"Using PythonUDF in join condition of join type $joinType is not supported.")) + + val query2 = testRelationLeft.join( + testRelationRight, + joinType, + condition = Some(evaluableJoinCond)) + comparePlans(Optimize.execute(query2), query2) } } } - From ce61bac1d84f8577b180400e44bd9bf22292e0b6 Mon Sep 17 00:00:00 2001 From: Mark Pavey Date: Wed, 28 Nov 2018 07:19:47 -0800 Subject: [PATCH 2155/2461] =?UTF-8?q?[SPARK-26137][CORE]=20Use=20Java=20sy?= =?UTF-8?q?stem=20property=20"file.separator"=20inste=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … of hard coded "/" in DependencyUtils ## What changes were proposed in this pull request? Use Java system property "file.separator" instead of hard coded "/" in DependencyUtils. ## How was this patch tested? Manual test: Submit Spark application via REST API that reads data from Elasticsearch using spark-elasticsearch library. Without fix application fails with error: 18/11/22 10:36:20 ERROR Version: Multiple ES-Hadoop versions detected in the classpath; please use only one jar:file:/C:/<...>/spark-2.4.0-bin-hadoop2.6/work/driver-20181122103610-0001/myApp-assembly-1.0.jar jar:file:/C:/<...>/myApp-assembly-1.0.jar 18/11/22 10:36:20 ERROR Main: Application [MyApp] failed: java.lang.Error: Multiple ES-Hadoop versions detected in the classpath; please use only one jar:file:/C:/<...>/spark-2.4.0-bin-hadoop2.6/work/driver-20181122103610-0001/myApp-assembly-1.0.jar jar:file:/C:/<...>/myApp-assembly-1.0.jar at org.elasticsearch.hadoop.util.Version.(Version.java:73) at org.elasticsearch.hadoop.rest.RestService.findPartitions(RestService.java:214) at org.elasticsearch.spark.rdd.AbstractEsRDD.esPartitions$lzycompute(AbstractEsRDD.scala:73) at org.elasticsearch.spark.rdd.AbstractEsRDD.esPartitions(AbstractEsRDD.scala:72) at org.elasticsearch.spark.rdd.AbstractEsRDD.getPartitions(AbstractEsRDD.scala:44) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:253) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:251) at scala.Option.getOrElse(Option.scala:121) at org.apache.spark.rdd.RDD.partitions(RDD.scala:251) at org.apache.spark.rdd.MapPartitionsRDD.getPartitions(MapPartitionsRDD.scala:49) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:253) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:251) at scala.Option.getOrElse(Option.scala:121) at org.apache.spark.rdd.RDD.partitions(RDD.scala:251) at org.apache.spark.SparkContext.runJob(SparkContext.scala:2126) at org.apache.spark.rdd.RDD$$anonfun$collect$1.apply(RDD.scala:945) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112) at org.apache.spark.rdd.RDD.withScope(RDD.scala:363) at org.apache.spark.rdd.RDD.collect(RDD.scala:944) ... at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at org.apache.spark.deploy.worker.DriverWrapper$.main(DriverWrapper.scala:65) at org.apache.spark.deploy.worker.DriverWrapper.main(DriverWrapper.scala) With fix application runs successfully. Closes #23102 from markpavey/JIRA_SPARK-26137_DependencyUtilsFileSeparatorFix. Authored-by: Mark Pavey Signed-off-by: Sean Owen --- .../apache/spark/deploy/DependencyUtils.scala | 3 ++- .../spark/deploy/SparkSubmitSuite.scala | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala index 178bdcfccb603..5a17a6b6e169c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala @@ -61,11 +61,12 @@ private[deploy] object DependencyUtils extends Logging { hadoopConf: Configuration, secMgr: SecurityManager): String = { val targetDir = Utils.createTempDir() + val userJarName = userJar.split(File.separatorChar).last Option(jars) .map { resolveGlobPaths(_, hadoopConf) .split(",") - .filterNot(_.contains(userJar.split("/").last)) + .filterNot(_.contains(userJarName)) .mkString(",") } .filterNot(_ == "") diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 652c36ffa6e71..c093789244bfe 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -962,6 +962,25 @@ class SparkSubmitSuite } } + test("remove copies of application jar from classpath") { + val fs = File.separator + val sparkConf = new SparkConf(false) + val hadoopConf = new Configuration() + val secMgr = new SecurityManager(sparkConf) + + val appJarName = "myApp.jar" + val jar1Name = "myJar1.jar" + val jar2Name = "myJar2.jar" + val userJar = s"file:/path${fs}to${fs}app${fs}jar$fs$appJarName" + val jars = s"file:/$jar1Name,file:/$appJarName,file:/$jar2Name" + + val resolvedJars = DependencyUtils + .resolveAndDownloadJars(jars, userJar, sparkConf, hadoopConf, secMgr) + + assert(!resolvedJars.contains(appJarName)) + assert(resolvedJars.contains(jar1Name) && resolvedJars.contains(jar2Name)) + } + test("Avoid re-upload remote resources in yarn client mode") { val hadoopConf = new Configuration() updateConfWithFakeS3Fs(hadoopConf) From 87bd9c75df6b67bef903751269a4fd381f9140d9 Mon Sep 17 00:00:00 2001 From: Brandon Krieger Date: Wed, 28 Nov 2018 07:22:48 -0800 Subject: [PATCH 2156/2461] [SPARK-25998][CORE] Change TorrentBroadcast to hold weak reference of broadcast object ## What changes were proposed in this pull request? This PR changes the broadcast object in TorrentBroadcast from a strong reference to a weak reference. This allows it to be garbage collected even if the Dataset is held in memory. This is ok, because the broadcast object can always be re-read. ## How was this patch tested? Tested in Spark shell by taking a heap dump, full repro steps listed in https://issues.apache.org/jira/browse/SPARK-25998. Closes #22995 from bkrieger/bk/torrent-broadcast-weak. Authored-by: Brandon Krieger Signed-off-by: Sean Owen --- .../spark/broadcast/TorrentBroadcast.scala | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index cbd49e070f2eb..26ead57316e18 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -18,6 +18,7 @@ package org.apache.spark.broadcast import java.io._ +import java.lang.ref.SoftReference import java.nio.ByteBuffer import java.util.zip.Adler32 @@ -61,9 +62,11 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) * Value of the broadcast object on executors. This is reconstructed by [[readBroadcastBlock]], * which builds this value by reading blocks from the driver and/or other executors. * - * On the driver, if the value is required, it is read lazily from the block manager. + * On the driver, if the value is required, it is read lazily from the block manager. We hold + * a soft reference so that it can be garbage collected if required, as we can always reconstruct + * in the future. */ - @transient private lazy val _value: T = readBroadcastBlock() + @transient private var _value: SoftReference[T] = _ /** The compression codec to use, or None if compression is disabled */ @transient private var compressionCodec: Option[CompressionCodec] = _ @@ -92,8 +95,15 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) /** The checksum for all the blocks. */ private var checksums: Array[Int] = _ - override protected def getValue() = { - _value + override protected def getValue() = synchronized { + val memoized: T = if (_value == null) null.asInstanceOf[T] else _value.get + if (memoized != null) { + memoized + } else { + val newlyRead = readBroadcastBlock() + _value = new SoftReference[T](newlyRead) + newlyRead + } } private def calcChecksum(block: ByteBuffer): Int = { @@ -205,8 +215,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } private def readBroadcastBlock(): T = Utils.tryOrIOException { - TorrentBroadcast.synchronized { - val broadcastCache = SparkEnv.get.broadcastManager.cachedValues + val broadcastCache = SparkEnv.get.broadcastManager.cachedValues + broadcastCache.synchronized { Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse { setConf(SparkEnv.get.conf) From 9fde3deab87c8f9c6d8dd147f5d52d243ff4b7ad Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Wed, 28 Nov 2018 07:33:34 -0800 Subject: [PATCH 2157/2461] [SPARK-25989][ML] OneVsRestModel handle empty outputCols incorrectly ## What changes were proposed in this pull request? ignore empty output columns ## How was this patch tested? added tests Closes #22991 from zhengruifeng/ovrm_empty_outcol. Authored-by: zhengruifeng Signed-off-by: Sean Owen --- .../spark/ml/classification/OneVsRest.scala | 35 ++++++++++++------- .../ml/classification/OneVsRestSuite.scala | 26 ++++++++++++++ 2 files changed, 48 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 1835a91775e0a..2f42a5922054e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -37,7 +37,7 @@ import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol} import org.apache.spark.ml.util._ import org.apache.spark.ml.util.Instrumentation.instrumented -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -169,6 +169,12 @@ final class OneVsRestModel private[ml] ( // Check schema transformSchema(dataset.schema, logging = true) + if (getPredictionCol == "" && getRawPredictionCol == "") { + logWarning(s"$uid: OneVsRestModel.transform() was called as NOOP" + + " since no output columns were set.") + return dataset.toDF + } + // determine the input columns: these need to be passed through val origCols = dataset.schema.map(f => col(f.name)) @@ -209,6 +215,9 @@ final class OneVsRestModel private[ml] ( newDataset.unpersist() } + var predictionColNames = Seq.empty[String] + var predictionColumns = Seq.empty[Column] + if (getRawPredictionCol != "") { val numClass = models.length @@ -219,24 +228,24 @@ final class OneVsRestModel private[ml] ( Vectors.dense(predArray) } - // output the index of the classifier with highest confidence as prediction - val labelUDF = udf { (rawPredictions: Vector) => rawPredictions.argmax.toDouble } + predictionColNames = predictionColNames :+ getRawPredictionCol + predictionColumns = predictionColumns :+ rawPredictionUDF(col(accColName)) + } - // output confidence as raw prediction, label and label metadata as prediction - aggregatedDataset - .withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName))) - .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata) - .drop(accColName) - } else { + if (getPredictionCol != "") { // output the index of the classifier with highest confidence as prediction val labelUDF = udf { (predictions: Map[Int, Double]) => predictions.maxBy(_._2)._1.toDouble } - // output label and label metadata as prediction - aggregatedDataset - .withColumn(getPredictionCol, labelUDF(col(accColName)), labelMetadata) - .drop(accColName) + + predictionColNames = predictionColNames :+ getPredictionCol + predictionColumns = predictionColumns :+ labelUDF(col(accColName)) + .as(getPredictionCol, labelMetadata) } + + aggregatedDataset + .withColumns(predictionColNames, predictionColumns) + .drop(accColName) } @Since("1.4.1") diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 519ec1720eb98..b6e8c927403ad 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -290,6 +290,32 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { checkModelData(ovaModel, newOvaModel) } + test("should ignore empty output cols") { + val lr = new LogisticRegression().setMaxIter(1) + val ovr = new OneVsRest().setClassifier(lr) + val ovrModel = ovr.fit(dataset) + + val output1 = ovrModel.setPredictionCol("").setRawPredictionCol("") + .transform(dataset) + assert(output1.schema.fieldNames.toSet === + Set("label", "features")) + + val output2 = ovrModel.setPredictionCol("prediction").setRawPredictionCol("") + .transform(dataset) + assert(output2.schema.fieldNames.toSet === + Set("label", "features", "prediction")) + + val output3 = ovrModel.setPredictionCol("").setRawPredictionCol("rawPrediction") + .transform(dataset) + assert(output3.schema.fieldNames.toSet === + Set("label", "features", "rawPrediction")) + + val output4 = ovrModel.setPredictionCol("prediction").setRawPredictionCol("rawPrediction") + .transform(dataset) + assert(output4.schema.fieldNames.toSet === + Set("label", "features", "prediction", "rawPrediction")) + } + test("should support all NumericType labels and not support other types") { val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(1)) MLTestingUtils.checkNumericTypes[OneVsRestModel, OneVsRest]( From fa0d4bf69929c5acd676d602e758a969713d19d8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 28 Nov 2018 23:42:13 +0800 Subject: [PATCH 2158/2461] [SPARK-25829][SQL] remove duplicated map keys with last wins policy ## What changes were proposed in this pull request? Currently duplicated map keys are not handled consistently. For example, map look up respects the duplicated key appears first, `Dataset.collect` only keeps the duplicated key appears last, `MapKeys` returns duplicated keys, etc. This PR proposes to remove duplicated map keys with last wins policy, to follow Java/Scala and Presto. It only applies to built-in functions, as users can create map with duplicated map keys via private APIs anyway. updated functions: `CreateMap`, `MapFromArrays`, `MapFromEntries`, `StringToMap`, `MapConcat`, `TransformKeys`. For other places: 1. data source v1 doesn't have this problem, as users need to provide a java/scala map, which can't have duplicated keys. 2. data source v2 may have this problem. I've added a note to `ArrayBasedMapData` to ask the caller to take care of duplicated keys. In the future we should enforce it in the stable data APIs for data source v2. 3. UDF doesn't have this problem, as users need to provide a java/scala map. Same as data source v1. 4. file format. I checked all of them and only parquet does not enforce it. For backward compatibility reasons I change nothing but leave a note saying that the behavior will be undefined if users write map with duplicated keys to parquet files. Maybe we can add a config and fail by default if parquet files have map with duplicated keys. This can be done in followup. ## How was this patch tested? updated tests and new tests Closes #23124 from cloud-fan/map. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- docs/sql-migration-guide-upgrade.md | 2 + .../spark/sql/avro/AvroDeserializer.scala | 4 +- python/pyspark/sql/functions.py | 10 +- .../catalyst/expressions/UnsafeMapData.java | 3 + .../sql/catalyst/CatalystTypeConverters.scala | 6 - .../spark/sql/catalyst/InternalRow.scala | 48 ++-- .../catalyst/expressions/BoundAttribute.scala | 8 +- .../expressions/collectionOperations.scala | 242 ++++-------------- .../expressions/complexTypeCreator.scala | 106 +++----- .../expressions/higherOrderFunctions.scala | 8 +- .../expressions/objects/objects.scala | 8 +- .../sql/catalyst/json/JacksonParser.scala | 2 + .../catalyst/util/ArrayBasedMapBuilder.scala | 120 +++++++++ .../sql/catalyst/util/ArrayBasedMapData.scala | 15 ++ .../spark/sql/catalyst/util/ArrayData.scala | 18 +- .../CollectionExpressionsSuite.scala | 87 +++---- .../expressions/ComplexTypeSuite.scala | 20 +- .../HigherOrderFunctionsSuite.scala | 37 +-- .../util/ArrayBasedMapBuilderSuite.scala | 105 ++++++++ .../datasources/orc/OrcDeserializer.scala | 2 + .../parquet/ParquetRowConverter.scala | 6 +- .../spark/sql/DataFrameFunctionsSuite.scala | 6 +- 22 files changed, 444 insertions(+), 419 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 25cd541190919..55838e773e4b1 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -27,6 +27,8 @@ displayTitle: Spark SQL Upgrading Guide - In Spark version 2.4 and earlier, float/double -0.0 is semantically equal to 0.0, but users can still distinguish them via `Dataset.show`, `Dataset.collect` etc. Since Spark 3.0, float/double -0.0 is replaced by 0.0 internally, and users can't distinguish them any more. + - In Spark version 2.4 and earlier, users can create a map with duplicated keys via built-in functions like `CreateMap`, `StringToMap`, etc. The behavior of map with duplicated keys is undefined, e.g. map look up respects the duplicated key appears first, `Dataset.collect` only keeps the duplicated key appears last, `MapKeys` returns duplicated keys, etc. Since Spark 3.0, these built-in functions will remove duplicated map keys with last wins policy. Users may still read map values with duplicated keys from data sources which do not enforce it (e.g. Parquet), the behavior will be udefined. + ## Upgrading From Spark SQL 2.3 to 2.4 - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below. diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index 272e7d5b388d9..4e2224b058a0a 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.avro -import java.math.{BigDecimal} +import java.math.BigDecimal import java.nio.ByteBuffer import scala.collection.JavaConverters._ @@ -218,6 +218,8 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) { i += 1 } + // The Avro map will never have null or duplicated map keys, it's safe to create a + // ArrayBasedMapData directly here. updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) case (UNION, _) => diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 286ef219a69e9..f98e550e39da8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2656,11 +2656,11 @@ def map_concat(*cols): >>> from pyspark.sql.functions import map_concat >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c', 1, 'd') as map2") >>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False) - +--------------------------------+ - |map3 | - +--------------------------------+ - |[1 -> a, 2 -> b, 3 -> c, 1 -> d]| - +--------------------------------+ + +------------------------+ + |map3 | + +------------------------+ + |[1 -> d, 2 -> b, 3 -> c]| + +------------------------+ """ sc = SparkContext._active_spark_context if len(cols) == 1 and isinstance(cols[0], (list, set)): diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java index f17441dfccb6d..a0833a6df8bbd 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java @@ -28,6 +28,9 @@ * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData, with extra 8 bytes at head * to indicate the number of bytes of the unsafe key array. * [unsafe key array numBytes] [unsafe key array] [unsafe value array] + * + * Note that, user is responsible to guarantee that the key array does not have duplicated + * elements, otherwise the behavior is undefined. */ // TODO: Use a more efficient format which doesn't depend on unsafe array. public final class UnsafeMapData extends MapData { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 6f5fbdd79e668..93df73ab1eaf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -431,12 +431,6 @@ object CatalystTypeConverters { map, (key: Any) => convertToCatalyst(key), (value: Any) => convertToCatalyst(value)) - case (keys: Array[_], values: Array[_]) => - // case for mapdata with duplicate keys - new ArrayBasedMapData( - new GenericArrayData(keys.map(convertToCatalyst)), - new GenericArrayData(values.map(convertToCatalyst)) - ) case other => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 274d75e680f03..e49c10be6be4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -125,22 +125,36 @@ object InternalRow { * actually takes a `SpecializedGetters` input because it can be generalized to other classes * that implements `SpecializedGetters` (e.g., `ArrayData`) too. */ - def getAccessor(dataType: DataType): (SpecializedGetters, Int) => Any = dataType match { - case BooleanType => (input, ordinal) => input.getBoolean(ordinal) - case ByteType => (input, ordinal) => input.getByte(ordinal) - case ShortType => (input, ordinal) => input.getShort(ordinal) - case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal) - case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal) - case FloatType => (input, ordinal) => input.getFloat(ordinal) - case DoubleType => (input, ordinal) => input.getDouble(ordinal) - case StringType => (input, ordinal) => input.getUTF8String(ordinal) - case BinaryType => (input, ordinal) => input.getBinary(ordinal) - case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal) - case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale) - case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size) - case _: ArrayType => (input, ordinal) => input.getArray(ordinal) - case _: MapType => (input, ordinal) => input.getMap(ordinal) - case u: UserDefinedType[_] => getAccessor(u.sqlType) - case _ => (input, ordinal) => input.get(ordinal, dataType) + def getAccessor(dt: DataType, nullable: Boolean = true): (SpecializedGetters, Int) => Any = { + val getValueNullSafe: (SpecializedGetters, Int) => Any = dt match { + case BooleanType => (input, ordinal) => input.getBoolean(ordinal) + case ByteType => (input, ordinal) => input.getByte(ordinal) + case ShortType => (input, ordinal) => input.getShort(ordinal) + case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal) + case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal) + case FloatType => (input, ordinal) => input.getFloat(ordinal) + case DoubleType => (input, ordinal) => input.getDouble(ordinal) + case StringType => (input, ordinal) => input.getUTF8String(ordinal) + case BinaryType => (input, ordinal) => input.getBinary(ordinal) + case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal) + case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale) + case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size) + case _: ArrayType => (input, ordinal) => input.getArray(ordinal) + case _: MapType => (input, ordinal) => input.getMap(ordinal) + case u: UserDefinedType[_] => getAccessor(u.sqlType, nullable) + case _ => (input, ordinal) => input.get(ordinal, dt) + } + + if (nullable) { + (getter, index) => { + if (getter.isNullAt(index)) { + null + } else { + getValueNullSafe(getter, index) + } + } + } else { + getValueNullSafe + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 77582e10f9ff2..ea8c369ee49ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -34,15 +34,11 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]" - private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType) + private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType, nullable) // Use special getter for primitive types (for UnsafeRow) override def eval(input: InternalRow): Any = { - if (nullable && input.isNullAt(ordinal)) { - null - } else { - accessor(input, ordinal) - } + accessor(input, ordinal) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 43116743e9952..fa8e38acd522d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -546,33 +546,25 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres override def nullable: Boolean = children.exists(_.nullable) + private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) + override def eval(input: InternalRow): Any = { - val maps = children.map(_.eval(input)) + val maps = children.map(_.eval(input).asInstanceOf[MapData]) if (maps.contains(null)) { return null } - val keyArrayDatas = maps.map(_.asInstanceOf[MapData].keyArray()) - val valueArrayDatas = maps.map(_.asInstanceOf[MapData].valueArray()) - val numElements = keyArrayDatas.foldLeft(0L)((sum, ad) => sum + ad.numElements()) + val numElements = maps.foldLeft(0L)((sum, ad) => sum + ad.numElements()) if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " + s"elements due to exceeding the map size limit " + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") } - val finalKeyArray = new Array[AnyRef](numElements.toInt) - val finalValueArray = new Array[AnyRef](numElements.toInt) - var position = 0 - for (i <- keyArrayDatas.indices) { - val keyArray = keyArrayDatas(i).toObjectArray(dataType.keyType) - val valueArray = valueArrayDatas(i).toObjectArray(dataType.valueType) - Array.copy(keyArray, 0, finalKeyArray, position, keyArray.length) - Array.copy(valueArray, 0, finalValueArray, position, valueArray.length) - position += keyArray.length - } - new ArrayBasedMapData(new GenericArrayData(finalKeyArray), - new GenericArrayData(finalValueArray)) + for (map <- maps) { + mapBuilder.putAll(map.keyArray(), map.valueArray()) + } + mapBuilder.build() } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -581,16 +573,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres val valueType = dataType.valueType val argsName = ctx.freshName("args") val hasNullName = ctx.freshName("hasNull") - val mapDataClass = classOf[MapData].getName - val arrayBasedMapDataClass = classOf[ArrayBasedMapData].getName - val arrayDataClass = classOf[ArrayData].getName - - val init = - s""" - |$mapDataClass[] $argsName = new $mapDataClass[${mapCodes.size}]; - |boolean ${ev.isNull}, $hasNullName = false; - |$mapDataClass ${ev.value} = null; - """.stripMargin + val builderTerm = ctx.addReferenceObj("mapBuilder", mapBuilder) val assignments = mapCodes.zip(children.map(_.nullable)).zipWithIndex.map { case ((m, true), i) => @@ -613,10 +596,10 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres """.stripMargin } - val codes = ctx.splitExpressionsWithCurrentInputs( + val prepareMaps = ctx.splitExpressionsWithCurrentInputs( expressions = assignments, funcName = "getMapConcatInputs", - extraArguments = (s"$mapDataClass[]", argsName) :: ("boolean", hasNullName) :: Nil, + extraArguments = (s"MapData[]", argsName) :: ("boolean", hasNullName) :: Nil, returnType = "boolean", makeSplitFunction = body => s""" @@ -646,34 +629,34 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres val mapMerge = s""" - |${ev.isNull} = $hasNullName; - |if (!${ev.isNull}) { - | $arrayDataClass[] $keyArgsName = new $arrayDataClass[${mapCodes.size}]; - | $arrayDataClass[] $valArgsName = new $arrayDataClass[${mapCodes.size}]; - | long $numElementsName = 0; - | for (int $idxName = 0; $idxName < $argsName.length; $idxName++) { - | $keyArgsName[$idxName] = $argsName[$idxName].keyArray(); - | $valArgsName[$idxName] = $argsName[$idxName].valueArray(); - | $numElementsName += $argsName[$idxName].numElements(); - | } - | if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw new RuntimeException("Unsuccessful attempt to concat maps with " + - | $numElementsName + " elements due to exceeding the map size limit " + - | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); - | } - | $arrayDataClass $finKeysName = $keyConcat($keyArgsName, - | (int) $numElementsName); - | $arrayDataClass $finValsName = $valueConcat($valArgsName, - | (int) $numElementsName); - | ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName); + |ArrayData[] $keyArgsName = new ArrayData[${mapCodes.size}]; + |ArrayData[] $valArgsName = new ArrayData[${mapCodes.size}]; + |long $numElementsName = 0; + |for (int $idxName = 0; $idxName < $argsName.length; $idxName++) { + | $keyArgsName[$idxName] = $argsName[$idxName].keyArray(); + | $valArgsName[$idxName] = $argsName[$idxName].valueArray(); + | $numElementsName += $argsName[$idxName].numElements(); |} + |if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful attempt to concat maps with " + + | $numElementsName + " elements due to exceeding the map size limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); + |} + |ArrayData $finKeysName = $keyConcat($keyArgsName, (int) $numElementsName); + |ArrayData $finValsName = $valueConcat($valArgsName, (int) $numElementsName); + |${ev.value} = $builderTerm.from($finKeysName, $finValsName); """.stripMargin ev.copy( code = code""" - |$init - |$codes - |$mapMerge + |MapData[] $argsName = new MapData[${mapCodes.size}]; + |boolean $hasNullName = false; + |$prepareMaps + |boolean ${ev.isNull} = $hasNullName; + |MapData ${ev.value} = null; + |if (!$hasNullName) { + | $mapMerge + |} """.stripMargin) } @@ -751,171 +734,44 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { s"${child.dataType.catalogString} type. $prettyName accepts only arrays of pair structs.") } + private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) + override protected def nullSafeEval(input: Any): Any = { - val arrayData = input.asInstanceOf[ArrayData] - val numEntries = arrayData.numElements() + val entries = input.asInstanceOf[ArrayData] + val numEntries = entries.numElements() var i = 0 - if(nullEntries) { + if (nullEntries) { while (i < numEntries) { - if (arrayData.isNullAt(i)) return null + if (entries.isNullAt(i)) return null i += 1 } } - val keyArray = new Array[AnyRef](numEntries) - val valueArray = new Array[AnyRef](numEntries) + i = 0 while (i < numEntries) { - val entry = arrayData.getStruct(i, 2) - val key = entry.get(0, dataType.keyType) - if (key == null) { - throw new RuntimeException("The first field from a struct (key) can't be null.") - } - keyArray.update(i, key) - val value = entry.get(1, dataType.valueType) - valueArray.update(i, value) + mapBuilder.put(entries.getStruct(i, 2)) i += 1 } - ArrayBasedMapData(keyArray, valueArray) + mapBuilder.build() } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => { val numEntries = ctx.freshName("numEntries") - val isKeyPrimitive = CodeGenerator.isPrimitiveType(dataType.keyType) - val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) - val code = if (isKeyPrimitive && isValuePrimitive) { - genCodeForPrimitiveElements(ctx, c, ev.value, numEntries) - } else { - genCodeForAnyElements(ctx, c, ev.value, numEntries) - } + val builderTerm = ctx.addReferenceObj("mapBuilder", mapBuilder) + val i = ctx.freshName("idx") ctx.nullArrayElementsSaveExec(nullEntries, ev.isNull, c) { s""" |final int $numEntries = $c.numElements(); - |$code + |for (int $i = 0; $i < $numEntries; $i++) { + | $builderTerm.put($c.getStruct($i, 2)); + |} + |${ev.value} = $builderTerm.build(); """.stripMargin } }) } - private def genCodeForAssignmentLoop( - ctx: CodegenContext, - childVariable: String, - mapData: String, - numEntries: String, - keyAssignment: (String, String) => String, - valueAssignment: (String, String) => String): String = { - val entry = ctx.freshName("entry") - val i = ctx.freshName("idx") - - val nullKeyCheck = if (dataTypeDetails.get._2) { - s""" - |if ($entry.isNullAt(0)) { - | throw new RuntimeException("The first field from a struct (key) can't be null."); - |} - """.stripMargin - } else { - "" - } - - s""" - |for (int $i = 0; $i < $numEntries; $i++) { - | InternalRow $entry = $childVariable.getStruct($i, 2); - | $nullKeyCheck - | ${keyAssignment(CodeGenerator.getValue(entry, dataType.keyType, "0"), i)} - | ${valueAssignment(entry, i)} - |} - """.stripMargin - } - - private def genCodeForPrimitiveElements( - ctx: CodegenContext, - childVariable: String, - mapData: String, - numEntries: String): String = { - val byteArraySize = ctx.freshName("byteArraySize") - val keySectionSize = ctx.freshName("keySectionSize") - val valueSectionSize = ctx.freshName("valueSectionSize") - val data = ctx.freshName("byteArray") - val unsafeMapData = ctx.freshName("unsafeMapData") - val keyArrayData = ctx.freshName("keyArrayData") - val valueArrayData = ctx.freshName("valueArrayData") - - val baseOffset = Platform.BYTE_ARRAY_OFFSET - val keySize = dataType.keyType.defaultSize - val valueSize = dataType.valueType.defaultSize - val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $keySize)" - val vByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $valueSize)" - - val keyAssignment = (key: String, idx: String) => - CodeGenerator.setArrayElement(keyArrayData, dataType.keyType, idx, key) - val valueAssignment = (entry: String, idx: String) => - CodeGenerator.createArrayAssignment( - valueArrayData, dataType.valueType, entry, idx, "1", dataType.valueContainsNull) - val assignmentLoop = genCodeForAssignmentLoop( - ctx, - childVariable, - mapData, - numEntries, - keyAssignment, - valueAssignment - ) - - s""" - |final long $keySectionSize = $kByteSize; - |final long $valueSectionSize = $vByteSize; - |final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize; - |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | ${genCodeForAnyElements(ctx, childVariable, mapData, numEntries)} - |} else { - | final byte[] $data = new byte[(int)$byteArraySize]; - | UnsafeMapData $unsafeMapData = new UnsafeMapData(); - | Platform.putLong($data, $baseOffset, $keySectionSize); - | Platform.putLong($data, ${baseOffset + 8}, $numEntries); - | Platform.putLong($data, ${baseOffset + 8} + $keySectionSize, $numEntries); - | $unsafeMapData.pointTo($data, $baseOffset, (int)$byteArraySize); - | ArrayData $keyArrayData = $unsafeMapData.keyArray(); - | ArrayData $valueArrayData = $unsafeMapData.valueArray(); - | $assignmentLoop - | $mapData = $unsafeMapData; - |} - """.stripMargin - } - - private def genCodeForAnyElements( - ctx: CodegenContext, - childVariable: String, - mapData: String, - numEntries: String): String = { - val keys = ctx.freshName("keys") - val values = ctx.freshName("values") - val mapDataClass = classOf[ArrayBasedMapData].getName() - - val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) - val valueAssignment = (entry: String, idx: String) => { - val value = CodeGenerator.getValue(entry, dataType.valueType, "1") - if (dataType.valueContainsNull && isValuePrimitive) { - s"$values[$idx] = $entry.isNullAt(1) ? null : (Object)$value;" - } else { - s"$values[$idx] = $value;" - } - } - val keyAssignment = (key: String, idx: String) => s"$keys[$idx] = $key;" - val assignmentLoop = genCodeForAssignmentLoop( - ctx, - childVariable, - mapData, - numEntries, - keyAssignment, - valueAssignment) - - s""" - |final Object[] $keys = new Object[$numEntries]; - |final Object[] $values = new Object[$numEntries]; - |$assignmentLoop - |$mapData = $mapDataClass.apply($keys, $values); - """.stripMargin - } - override def prettyName: String = "map_from_entries" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 6b77996789f1a..4e722c9237a90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -24,8 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.Platform -import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String /** @@ -62,7 +60,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val et = dataType.elementType val (allocation, assigns, arrayData) = - GenArrayData.genCodeToCreateArrayData(ctx, et, children, false, "createArray") + GenArrayData.genCodeToCreateArrayData(ctx, et, children, "createArray") ev.copy( code = code"${allocation}${assigns}", value = JavaCode.variable(arrayData, dataType), @@ -79,7 +77,6 @@ private [sql] object GenArrayData { * @param ctx a [[CodegenContext]] * @param elementType data type of underlying array elements * @param elementsExpr concatenated set of [[Expression]] for each element of an underlying array - * @param isMapKey if true, throw an exception when the element is null * @param functionName string to include in the error message * @return (array allocation, concatenated assignments to each array elements, arrayData name) */ @@ -87,7 +84,6 @@ private [sql] object GenArrayData { ctx: CodegenContext, elementType: DataType, elementsExpr: Seq[Expression], - isMapKey: Boolean, functionName: String): (String, String, String) = { val arrayDataName = ctx.freshName("arrayData") val numElements = s"${elementsExpr.length}L" @@ -103,15 +99,9 @@ private [sql] object GenArrayData { val assignment = if (!expr.nullable) { setArrayElement } else { - val isNullAssignment = if (!isMapKey) { - s"$arrayDataName.setNullAt($i);" - } else { - "throw new RuntimeException(\"Cannot use null as map key!\");" - } - s""" |if (${eval.isNull}) { - | $isNullAssignment + | $arrayDataName.setNullAt($i); |} else { | $setArrayElement |} @@ -165,7 +155,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression { } } - override def dataType: MapType = { + override lazy val dataType: MapType = { MapType( keyType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(keys.map(_.dataType)) .getOrElse(StringType), @@ -176,32 +166,33 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def nullable: Boolean = false + private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) + override def eval(input: InternalRow): Any = { - val keyArray = keys.map(_.eval(input)).toArray - if (keyArray.contains(null)) { - throw new RuntimeException("Cannot use null as map key!") + var i = 0 + while (i < keys.length) { + mapBuilder.put(keys(i).eval(input), values(i).eval(input)) + i += 1 } - val valueArray = values.map(_.eval(input)).toArray - new ArrayBasedMapData(new GenericArrayData(keyArray), new GenericArrayData(valueArray)) + mapBuilder.build() } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val mapClass = classOf[ArrayBasedMapData].getName val MapType(keyDt, valueDt, _) = dataType val (allocationKeyData, assignKeys, keyArrayData) = - GenArrayData.genCodeToCreateArrayData(ctx, keyDt, keys, true, "createMap") + GenArrayData.genCodeToCreateArrayData(ctx, keyDt, keys, "createMap") val (allocationValueData, assignValues, valueArrayData) = - GenArrayData.genCodeToCreateArrayData(ctx, valueDt, values, false, "createMap") + GenArrayData.genCodeToCreateArrayData(ctx, valueDt, values, "createMap") + val builderTerm = ctx.addReferenceObj("mapBuilder", mapBuilder) val code = code""" - final boolean ${ev.isNull} = false; $allocationKeyData $assignKeys $allocationValueData $assignValues - final MapData ${ev.value} = new $mapClass($keyArrayData, $valueArrayData); + final MapData ${ev.value} = $builderTerm.from($keyArrayData, $valueArrayData); """ - ev.copy(code = code) + ev.copy(code = code, isNull = FalseLiteral) } override def prettyName: String = "map" @@ -234,53 +225,25 @@ case class MapFromArrays(left: Expression, right: Expression) } } - override def dataType: DataType = { + override def dataType: MapType = { MapType( keyType = left.dataType.asInstanceOf[ArrayType].elementType, valueType = right.dataType.asInstanceOf[ArrayType].elementType, valueContainsNull = right.dataType.asInstanceOf[ArrayType].containsNull) } + private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) + override def nullSafeEval(keyArray: Any, valueArray: Any): Any = { val keyArrayData = keyArray.asInstanceOf[ArrayData] val valueArrayData = valueArray.asInstanceOf[ArrayData] - if (keyArrayData.numElements != valueArrayData.numElements) { - throw new RuntimeException("The given two arrays should have the same length") - } - val leftArrayType = left.dataType.asInstanceOf[ArrayType] - if (leftArrayType.containsNull) { - var i = 0 - while (i < keyArrayData.numElements) { - if (keyArrayData.isNullAt(i)) { - throw new RuntimeException("Cannot use null as map key!") - } - i += 1 - } - } - new ArrayBasedMapData(keyArrayData.copy(), valueArrayData.copy()) + mapBuilder.from(keyArrayData.copy(), valueArrayData.copy()) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (keyArrayData, valueArrayData) => { - val arrayBasedMapData = classOf[ArrayBasedMapData].getName - val leftArrayType = left.dataType.asInstanceOf[ArrayType] - val keyArrayElemNullCheck = if (!leftArrayType.containsNull) "" else { - val i = ctx.freshName("i") - s""" - |for (int $i = 0; $i < $keyArrayData.numElements(); $i++) { - | if ($keyArrayData.isNullAt($i)) { - | throw new RuntimeException("Cannot use null as map key!"); - | } - |} - """.stripMargin - } - s""" - |if ($keyArrayData.numElements() != $valueArrayData.numElements()) { - | throw new RuntimeException("The given two arrays should have the same length"); - |} - |$keyArrayElemNullCheck - |${ev.value} = new $arrayBasedMapData($keyArrayData.copy(), $valueArrayData.copy()); - """.stripMargin + val builderTerm = ctx.addReferenceObj("mapBuilder", mapBuilder) + s"${ev.value} = $builderTerm.from($keyArrayData.copy(), $valueArrayData.copy());" }) } @@ -488,28 +451,25 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E } } + private lazy val mapBuilder = new ArrayBasedMapBuilder(StringType, StringType) + override def nullSafeEval( inputString: Any, stringDelimiter: Any, keyValueDelimiter: Any): Any = { val keyValues = inputString.asInstanceOf[UTF8String].split(stringDelimiter.asInstanceOf[UTF8String], -1) - - val iterator = new Iterator[(UTF8String, UTF8String)] { - var index = 0 - val keyValueDelimiterUTF8String = keyValueDelimiter.asInstanceOf[UTF8String] - - override def hasNext: Boolean = { - keyValues.length > index - } - - override def next(): (UTF8String, UTF8String) = { - val keyValueArray = keyValues(index).split(keyValueDelimiterUTF8String, 2) - index += 1 - (keyValueArray(0), if (keyValueArray.length < 2) null else keyValueArray(1)) - } + val keyValueDelimiterUTF8String = keyValueDelimiter.asInstanceOf[UTF8String] + + var i = 0 + while (i < keyValues.length) { + val keyValueArray = keyValues(i).split(keyValueDelimiterUTF8String, 2) + val key = keyValueArray(0) + val value = if (keyValueArray.length < 2) null else keyValueArray(1) + mapBuilder.put(key, value) + i += 1 } - ArrayBasedMapData(iterator, keyValues.size, identity, identity) + mapBuilder.build() } override def prettyName: String = "str_to_map" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 8b31021866220..a8639d29f964d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -512,7 +512,7 @@ case class TransformKeys( @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType - override def dataType: DataType = MapType(function.dataType, valueType, valueContainsNull) + override def dataType: MapType = MapType(function.dataType, valueType, valueContainsNull) override def checkInputDataTypes(): TypeCheckResult = { TypeUtils.checkForMapKeyType(function.dataType) @@ -525,6 +525,7 @@ case class TransformKeys( @transient lazy val LambdaFunction( _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function + private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { val map = argumentValue.asInstanceOf[MapData] @@ -534,13 +535,10 @@ case class TransformKeys( keyVar.value.set(map.keyArray().get(i, keyVar.dataType)) valueVar.value.set(map.valueArray().get(i, valueVar.dataType)) val result = functionForEval.eval(inputRow) - if (result == null) { - throw new RuntimeException("Cannot use null as map key!") - } resultKeys.update(i, result) i += 1 } - new ArrayBasedMapData(resultKeys, map.valueArray()) + mapBuilder.from(resultKeys, map.valueArray()) } override def prettyName: String = "transform_keys" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 59c897b6a53ce..8182730feb4b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -587,17 +587,13 @@ case class LambdaVariable( dataType: DataType, nullable: Boolean = true) extends LeafExpression with NonSQLExpression { - private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType) + private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType, nullable) // Interpreted execution of `LambdaVariable` always get the 0-index element from input row. override def eval(input: InternalRow): Any = { assert(input.numFields == 1, "The input row of interpreted LambdaVariable should have only 1 field.") - if (nullable && input.isNullAt(0)) { - null - } else { - accessor(input, 0) - } + accessor(input, 0) } override def genCode(ctx: CodegenContext): ExprCode = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 773ff5a7a4013..92517aac053b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -367,6 +367,8 @@ class JacksonParser( values += fieldConverter.apply(parser) } + // The JSON map will never have null or duplicated map keys, it's safe to create a + // ArrayBasedMapData directly here. ArrayBasedMapData(keys.toArray, values.toArray) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala new file mode 100644 index 0000000000000..e7cd61655dc9a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ + +/** + * A builder of [[ArrayBasedMapData]], which fails if a null map key is detected, and removes + * duplicated map keys w.r.t. the last wins policy. + */ +class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Serializable { + assert(!keyType.existsRecursively(_.isInstanceOf[MapType]), "key of map cannot be/contain map") + assert(keyType != NullType, "map key cannot be null type.") + + private lazy val keyToIndex = keyType match { + // Binary type data is `byte[]`, which can't use `==` to check equality. + case _: AtomicType | _: CalendarIntervalType if !keyType.isInstanceOf[BinaryType] => + new java.util.HashMap[Any, Int]() + case _ => + // for complex types, use interpreted ordering to be able to compare unsafe data with safe + // data, e.g. UnsafeRow vs GenericInternalRow. + new java.util.TreeMap[Any, Int](TypeUtils.getInterpretedOrdering(keyType)) + } + + // TODO: specialize it + private lazy val keys = mutable.ArrayBuffer.empty[Any] + private lazy val values = mutable.ArrayBuffer.empty[Any] + + private lazy val keyGetter = InternalRow.getAccessor(keyType) + private lazy val valueGetter = InternalRow.getAccessor(valueType) + + def put(key: Any, value: Any): Unit = { + if (key == null) { + throw new RuntimeException("Cannot use null as map key.") + } + + val index = keyToIndex.getOrDefault(key, -1) + if (index == -1) { + keyToIndex.put(key, values.length) + keys.append(key) + values.append(value) + } else { + // Overwrite the previous value, as the policy is last wins. + values(index) = value + } + } + + // write a 2-field row, the first field is key and the second field is value. + def put(entry: InternalRow): Unit = { + if (entry.isNullAt(0)) { + throw new RuntimeException("Cannot use null as map key.") + } + put(keyGetter(entry, 0), valueGetter(entry, 1)) + } + + def putAll(keyArray: ArrayData, valueArray: ArrayData): Unit = { + if (keyArray.numElements() != valueArray.numElements()) { + throw new RuntimeException( + "The key array and value array of MapData must have the same length.") + } + + var i = 0 + while (i < keyArray.numElements()) { + put(keyGetter(keyArray, i), valueGetter(valueArray, i)) + i += 1 + } + } + + private def reset(): Unit = { + keyToIndex.clear() + keys.clear() + values.clear() + } + + /** + * Builds the result [[ArrayBasedMapData]] and reset this builder to free up the resources. The + * builder becomes fresh afterward and is ready to take input and build another map. + */ + def build(): ArrayBasedMapData = { + val map = new ArrayBasedMapData( + new GenericArrayData(keys.toArray), new GenericArrayData(values.toArray)) + reset() + map + } + + /** + * Builds a [[ArrayBasedMapData]] from the given key and value array and reset this builder. The + * builder becomes fresh afterward and is ready to take input and build another map. + */ + def from(keyArray: ArrayData, valueArray: ArrayData): ArrayBasedMapData = { + assert(keyToIndex.isEmpty, "'from' can only be called with a fresh ArrayBasedMapBuilder.") + putAll(keyArray, valueArray) + if (keyToIndex.size == keyArray.numElements()) { + // If there is no duplicated map keys, creates the MapData with the input key and value array, + // as they might already in unsafe format and are more efficient. + reset() + new ArrayBasedMapData(keyArray, valueArray) + } else { + build() + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala index 91b3139443696..0989af26b8c12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala @@ -19,6 +19,12 @@ package org.apache.spark.sql.catalyst.util import java.util.{Map => JavaMap} +/** + * A simple `MapData` implementation which is backed by 2 arrays. + * + * Note that, user is responsible to guarantee that the key array does not have duplicated + * elements, otherwise the behavior is undefined. + */ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) extends MapData { require(keyArray.numElements() == valueArray.numElements()) @@ -83,6 +89,9 @@ object ArrayBasedMapData { * Creates a [[ArrayBasedMapData]] by applying the given converters over * each (key -> value) pair from the given iterator * + * Note that, user is responsible to guarantee that the key array does not have duplicated + * elements, otherwise the behavior is undefined. + * * @param iterator Input iterator * @param size Number of elements * @param keyConverter This function is applied over all the keys extracted from the @@ -108,6 +117,12 @@ object ArrayBasedMapData { ArrayBasedMapData(keys, values) } + /** + * Creates a [[ArrayBasedMapData]] from a key and value array. + * + * Note that, user is responsible to guarantee that the key array does not have duplicated + * elements, otherwise the behavior is undefined. + */ def apply(keys: Array[_], values: Array[_]): ArrayBasedMapData = { new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala index 4da8ce05fe8a3..ebbf241088f80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala @@ -172,11 +172,7 @@ abstract class ArrayData extends SpecializedGetters with Serializable { val values = new Array[T](size) var i = 0 while (i < size) { - if (isNullAt(i)) { - values(i) = null.asInstanceOf[T] - } else { - values(i) = accessor(this, i).asInstanceOf[T] - } + values(i) = accessor(this, i).asInstanceOf[T] i += 1 } values @@ -187,11 +183,7 @@ abstract class ArrayData extends SpecializedGetters with Serializable { val accessor = InternalRow.getAccessor(elementType) var i = 0 while (i < size) { - if (isNullAt(i)) { - f(i, null) - } else { - f(i, accessor(this, i)) - } + f(i, accessor(this, i)) i += 1 } } @@ -208,11 +200,7 @@ class ArrayDataIndexedSeq[T](arrayData: ArrayData, dataType: DataType) extends I override def apply(idx: Int): T = if (0 <= idx && idx < arrayData.numElements()) { - if (arrayData.isNullAt(idx)) { - null.asInstanceOf[T] - } else { - accessor(arrayData, idx).asInstanceOf[T] - } + accessor(arrayData, idx).asInstanceOf[T] } else { throw new IndexOutOfBoundsException( s"Index $idx must be between 0 and the length of the ArrayData.") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index d2edb2f24688d..bed8547dbc83d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -114,13 +114,13 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val m1 = Literal.create(create_map("c" -> "3", "a" -> "4"), MapType(StringType, StringType, valueContainsNull = false)) val m2 = Literal.create(create_map("d" -> "4", "e" -> "5"), MapType(StringType, StringType)) - val m3 = Literal.create(create_map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) + val m3 = Literal.create(create_map("f" -> "1", "g" -> "2"), MapType(StringType, StringType)) val m4 = Literal.create(create_map("a" -> null, "c" -> "3"), MapType(StringType, StringType)) val m5 = Literal.create(create_map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType)) - val m6 = Literal.create(create_map("a" -> null, "c" -> 3), MapType(StringType, IntegerType)) + val m6 = Literal.create(create_map("c" -> null, "d" -> 3), MapType(StringType, IntegerType)) val m7 = Literal.create(create_map(List(1, 2) -> 1, List(3, 4) -> 2), MapType(ArrayType(IntegerType), IntegerType)) - val m8 = Literal.create(create_map(List(5, 6) -> 3, List(1, 2) -> 4), + val m8 = Literal.create(create_map(List(5, 6) -> 3, List(7, 8) -> 4), MapType(ArrayType(IntegerType), IntegerType)) val m9 = Literal.create(create_map(1 -> "1", 2 -> "2"), MapType(IntegerType, StringType, valueContainsNull = false)) @@ -134,57 +134,33 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper MapType(IntegerType, IntegerType, valueContainsNull = true)) val mNull = Literal.create(null, MapType(StringType, StringType)) - // overlapping maps - checkEvaluation(MapConcat(Seq(m0, m1)), - ( - Array("a", "b", "c", "a"), // keys - Array("1", "2", "3", "4") // values - ) - ) + // overlapping maps should remove duplicated map keys w.r.t. last win policy. + checkEvaluation(MapConcat(Seq(m0, m1)), create_map("a" -> "4", "b" -> "2", "c" -> "3")) // maps with no overlap checkEvaluation(MapConcat(Seq(m0, m2)), create_map("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5")) // 3 maps - checkEvaluation(MapConcat(Seq(m0, m1, m2)), - ( - Array("a", "b", "c", "a", "d", "e"), // keys - Array("1", "2", "3", "4", "4", "5") // values - ) - ) + checkEvaluation(MapConcat(Seq(m0, m2, m3)), + create_map("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5", "f" -> "1", "g" -> "2")) // null reference values - checkEvaluation(MapConcat(Seq(m3, m4)), - ( - Array("a", "b", "a", "c"), // keys - Array("1", "2", null, "3") // values - ) - ) + checkEvaluation(MapConcat(Seq(m2, m4)), + create_map("d" -> "4", "e" -> "5", "a" -> null, "c" -> "3")) // null primitive values checkEvaluation(MapConcat(Seq(m5, m6)), - ( - Array("a", "b", "a", "c"), // keys - Array(1, 2, null, 3) // values - ) - ) + create_map("a" -> 1, "b" -> 2, "c" -> null, "d" -> 3)) // keys that are primitive checkEvaluation(MapConcat(Seq(m9, m10)), - ( - Array(1, 2, 3, 4), // keys - Array("1", "2", "3", "4") // values - ) - ) + create_map(1 -> "1", 2 -> "2", 3 -> "3", 4 -> "4")) - // keys that are arrays, with overlap + // keys that are arrays checkEvaluation(MapConcat(Seq(m7, m8)), - ( - Array(List(1, 2), List(3, 4), List(5, 6), List(1, 2)), // keys - Array(1, 2, 3, 4) // values - ) - ) + create_map(List(1, 2) -> 1, List(3, 4) -> 2, List(5, 6) -> 3, List(7, 8) -> 4)) + // both keys and value are primitive and valueContainsNull = false checkEvaluation(MapConcat(Seq(m11, m12)), create_map(1 -> 2, 3 -> 4, 5 -> 6)) @@ -205,15 +181,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapConcat(Seq.empty), Map.empty) // force split expressions for input in generated code - val expectedKeys = Array.fill(65)(Seq("a", "b")).flatten ++ Array("d", "e") - val expectedValues = Array.fill(65)(Seq("1", "2")).flatten ++ Array("4", "5") - checkEvaluation(MapConcat( - Seq( - m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, - m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, - m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m2 - )), - (expectedKeys, expectedValues)) + val expectedKeys = (1 to 65).map(_.toString) + val expectedValues = (1 to 65).map(_.toString) + checkEvaluation( + MapConcat( + expectedKeys.zip(expectedValues).map { + case (k, v) => Literal.create(create_map(k -> v), MapType(StringType, StringType)) + }), + create_map(expectedKeys.zip(expectedValues): _*)) // argument checking assert(MapConcat(Seq(m0, m1)).checkInputDataTypes().isSuccess) @@ -248,7 +223,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayType(IntegerType, containsNull = true), ArrayType(StringType, containsNull = true), valueContainsNull = true)) - checkEvaluation(mapConcat, Map( + checkEvaluation(mapConcat, create_map( Seq(1, 2) -> Seq("a", "b"), Seq(3, 4, null) -> Seq("c", "d", null), Seq(6) -> null)) @@ -282,7 +257,9 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val ai1 = Literal.create(Seq(row(1, null), row(2, 20), row(3, null)), aiType) val ai2 = Literal.create(Seq.empty, aiType) val ai3 = Literal.create(null, aiType) + // The map key is duplicated val ai4 = Literal.create(Seq(row(1, 10), row(1, 20)), aiType) + // The map key is null val ai5 = Literal.create(Seq(row(1, 10), row(null, 20)), aiType) val ai6 = Literal.create(Seq(null, row(2, 20), null), aiType) @@ -290,10 +267,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapFromEntries(ai1), create_map(1 -> null, 2 -> 20, 3 -> null)) checkEvaluation(MapFromEntries(ai2), Map.empty) checkEvaluation(MapFromEntries(ai3), null) - checkEvaluation(MapKeys(MapFromEntries(ai4)), Seq(1, 1)) + // Duplicated map keys will be removed w.r.t. the last wins policy. + checkEvaluation(MapFromEntries(ai4), create_map(1 -> 20)) + // Map key can't be null checkExceptionInExpression[RuntimeException]( MapFromEntries(ai5), - "The first field from a struct (key) can't be null.") + "Cannot use null as map key") checkEvaluation(MapFromEntries(ai6), null) // Non-primitive-type keys and values @@ -310,13 +289,13 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapFromEntries(as1), create_map("a" -> null, "b" -> "bb", "c" -> null)) checkEvaluation(MapFromEntries(as2), Map.empty) checkEvaluation(MapFromEntries(as3), null) - checkEvaluation(MapKeys(MapFromEntries(as4)), Seq("a", "a")) - checkEvaluation(MapFromEntries(as6), null) - + // Duplicated map keys will be removed w.r.t. the last wins policy. + checkEvaluation(MapFromEntries(as4), create_map("a" -> "bb")) // Map key can't be null checkExceptionInExpression[RuntimeException]( MapFromEntries(as5), - "The first field from a struct (key) can't be null.") + "Cannot use null as map key") + checkEvaluation(MapFromEntries(as6), null) // map key can't be map val structOfMap = row(create_map(1 -> 1), 1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index d95f42e04e37c..dc60464815043 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -183,6 +183,11 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), "Cannot use null as map key") + // Duplicated map keys will be removed w.r.t. the last wins policy. + checkEvaluation( + CreateMap(Seq(Literal(1), Literal(2), Literal(1), Literal(3))), + create_map(1 -> 3)) + // ArrayType map key and value val map = CreateMap(Seq( Literal.create(intSeq, ArrayType(IntegerType, containsNull = false)), @@ -243,12 +248,18 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { MapFromArrays(intWithNullArray, strArray), "Cannot use null as map key") + // Duplicated map keys will be removed w.r.t. the last wins policy. + checkEvaluation( + MapFromArrays( + Literal.create(Seq(1, 1), ArrayType(IntegerType)), + Literal.create(Seq(2, 3), ArrayType(IntegerType))), + create_map(1 -> 3)) + // map key can't be map val arrayOfMap = Seq(create_map(1 -> "a", 2 -> "b")) val map = MapFromArrays( Literal.create(arrayOfMap, ArrayType(MapType(IntegerType, StringType))), - Literal.create(Seq(1), ArrayType(IntegerType)) - ) + Literal.create(Seq(1), ArrayType(IntegerType))) map.checkInputDataTypes() match { case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key") case TypeCheckResult.TypeCheckFailure(msg) => @@ -356,6 +367,11 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val m5 = Map("a" -> null) checkEvaluation(new StringToMap(s5), m5) + // Duplicated map keys will be removed w.r.t. the last wins policy. + checkEvaluation( + new StringToMap(Literal("a:1,b:2,a:3")), + create_map("a" -> "3", "b" -> "2")) + // arguments checking assert(new StringToMap(Literal("a:1,b:2,c:3")).checkInputDataTypes().isSuccess) assert(new StringToMap(Literal(null)).checkInputDataTypes().isFailure) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 66bf18af95799..03fb75e330c66 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -330,8 +330,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation( transformKeys(transformKeys(ai0, plusOne), plusValue), create_map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4)) - checkEvaluation(transformKeys(ai0, modKey), - ArrayBasedMapData(Array(1, 2, 0, 1), Array(1, 2, 3, 4))) + // Duplicated map keys will be removed w.r.t. the last wins policy. + checkEvaluation(transformKeys(ai0, modKey), create_map(1 -> 4, 2 -> 2, 0 -> 3)) checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int]) checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int]) checkEvaluation( @@ -467,16 +467,13 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper .bind(validateBinding) } - val mii0 = Literal.create(Map(1 -> 10, 2 -> 20, 3 -> 30), + val mii0 = Literal.create(create_map(1 -> 10, 2 -> 20, 3 -> 30), MapType(IntegerType, IntegerType, valueContainsNull = false)) - val mii1 = Literal.create(Map(1 -> -1, 2 -> -2, 4 -> -4), + val mii1 = Literal.create(create_map(1 -> -1, 2 -> -2, 4 -> -4), MapType(IntegerType, IntegerType, valueContainsNull = false)) - val mii2 = Literal.create(Map(1 -> null, 2 -> -2, 3 -> null), + val mii2 = Literal.create(create_map(1 -> null, 2 -> -2, 3 -> null), MapType(IntegerType, IntegerType, valueContainsNull = true)) val mii3 = Literal.create(Map(), MapType(IntegerType, IntegerType, valueContainsNull = false)) - val mii4 = MapFromArrays( - Literal.create(Seq(2, 2), ArrayType(IntegerType, false)), - Literal.create(Seq(20, 200), ArrayType(IntegerType, false))) val miin = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) val multiplyKeyWithValues: (Expression, Expression, Expression) => Expression = { @@ -492,12 +489,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation( map_zip_with(mii0, mii3, multiplyKeyWithValues), Map(1 -> null, 2 -> null, 3 -> null)) - checkEvaluation( - map_zip_with(mii0, mii4, multiplyKeyWithValues), - Map(1 -> null, 2 -> 800, 3 -> null)) - checkEvaluation( - map_zip_with(mii4, mii0, multiplyKeyWithValues), - Map(2 -> 800, 1 -> null, 3 -> null)) checkEvaluation( map_zip_with(mii0, miin, multiplyKeyWithValues), null) @@ -511,9 +502,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val mss2 = Literal.create(Map("c" -> null, "b" -> "t", "a" -> null), MapType(StringType, StringType, valueContainsNull = true)) val mss3 = Literal.create(Map(), MapType(StringType, StringType, valueContainsNull = false)) - val mss4 = MapFromArrays( - Literal.create(Seq("a", "a"), ArrayType(StringType, false)), - Literal.create(Seq("a", "n"), ArrayType(StringType, false))) val mssn = Literal.create(null, MapType(StringType, StringType, valueContainsNull = false)) val concat: (Expression, Expression, Expression) => Expression = { @@ -529,12 +517,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation( map_zip_with(mss0, mss3, concat), Map("a" -> null, "b" -> null, "d" -> null)) - checkEvaluation( - map_zip_with(mss0, mss4, concat), - Map("a" -> "axa", "b" -> null, "d" -> null)) - checkEvaluation( - map_zip_with(mss4, mss0, concat), - Map("a" -> "aax", "b" -> null, "d" -> null)) checkEvaluation( map_zip_with(mss0, mssn, concat), null) @@ -550,9 +532,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val mbb2 = Literal.create(Map(b(1, 3) -> null, b(1, 2) -> b(2), b(2, 1) -> null), MapType(BinaryType, BinaryType, valueContainsNull = true)) val mbb3 = Literal.create(Map(), MapType(BinaryType, BinaryType, valueContainsNull = false)) - val mbb4 = MapFromArrays( - Literal.create(Seq(b(2, 1), b(2, 1)), ArrayType(BinaryType, false)), - Literal.create(Seq(b(1), b(9)), ArrayType(BinaryType, false))) val mbbn = Literal.create(null, MapType(BinaryType, BinaryType, valueContainsNull = false)) checkEvaluation( @@ -564,12 +543,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation( map_zip_with(mbb0, mbb3, concat), Map(b(1, 2) -> null, b(2, 1) -> null, b(1, 3) -> null)) - checkEvaluation( - map_zip_with(mbb0, mbb4, concat), - Map(b(1, 2) -> null, b(2, 1) -> b(2, 1, 5, 1), b(1, 3) -> null)) - checkEvaluation( - map_zip_with(mbb4, mbb0, concat), - Map(b(2, 1) -> b(2, 1, 1, 5), b(1, 2) -> null, b(1, 3) -> null)) checkEvaluation( map_zip_with(mbb0, mbbn, concat), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala new file mode 100644 index 0000000000000..8509bce177129 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow} +import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType, StructType} +import org.apache.spark.unsafe.Platform + +class ArrayBasedMapBuilderSuite extends SparkFunSuite { + + test("basic") { + val builder = new ArrayBasedMapBuilder(IntegerType, IntegerType) + builder.put(1, 1) + builder.put(InternalRow(2, 2)) + builder.putAll(new GenericArrayData(Seq(3)), new GenericArrayData(Seq(3))) + val map = builder.build() + assert(map.numElements() == 3) + assert(ArrayBasedMapData.toScalaMap(map) == Map(1 -> 1, 2 -> 2, 3 -> 3)) + } + + test("fail with null key") { + val builder = new ArrayBasedMapBuilder(IntegerType, IntegerType) + builder.put(1, null) // null value is OK + val e = intercept[RuntimeException](builder.put(null, 1)) + assert(e.getMessage.contains("Cannot use null as map key")) + } + + test("remove duplicated keys with last wins policy") { + val builder = new ArrayBasedMapBuilder(IntegerType, IntegerType) + builder.put(1, 1) + builder.put(2, 2) + builder.put(1, 2) + val map = builder.build() + assert(map.numElements() == 2) + assert(ArrayBasedMapData.toScalaMap(map) == Map(1 -> 2, 2 -> 2)) + } + + test("binary type key") { + val builder = new ArrayBasedMapBuilder(BinaryType, IntegerType) + builder.put(Array(1.toByte), 1) + builder.put(Array(2.toByte), 2) + builder.put(Array(1.toByte), 3) + val map = builder.build() + assert(map.numElements() == 2) + val entries = ArrayBasedMapData.toScalaMap(map).iterator.toSeq + assert(entries(0)._1.asInstanceOf[Array[Byte]].toSeq == Seq(1)) + assert(entries(0)._2 == 3) + assert(entries(1)._1.asInstanceOf[Array[Byte]].toSeq == Seq(2)) + assert(entries(1)._2 == 2) + } + + test("struct type key") { + val builder = new ArrayBasedMapBuilder(new StructType().add("i", "int"), IntegerType) + builder.put(InternalRow(1), 1) + builder.put(InternalRow(2), 2) + val unsafeRow = { + val row = new UnsafeRow(1) + val bytes = new Array[Byte](16) + row.pointTo(bytes, 16) + row.setInt(0, 1) + row + } + builder.put(unsafeRow, 3) + val map = builder.build() + assert(map.numElements() == 2) + assert(ArrayBasedMapData.toScalaMap(map) == Map(InternalRow(1) -> 3, InternalRow(2) -> 2)) + } + + test("array type key") { + val builder = new ArrayBasedMapBuilder(ArrayType(IntegerType), IntegerType) + builder.put(new GenericArrayData(Seq(1, 1)), 1) + builder.put(new GenericArrayData(Seq(2, 2)), 2) + val unsafeArray = { + val array = new UnsafeArrayData() + val bytes = new Array[Byte](24) + Platform.putLong(bytes, Platform.BYTE_ARRAY_OFFSET, 2) + array.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET, 24) + array.setInt(0, 1) + array.setInt(1, 1) + array + } + builder.put(unsafeArray, 3) + val map = builder.build() + assert(map.numElements() == 2) + assert(ArrayBasedMapData.toScalaMap(map) == + Map(new GenericArrayData(Seq(1, 1)) -> 3, new GenericArrayData(Seq(2, 2)) -> 2)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala index 4ecc54bd2fd96..ee16b3ab07f5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -179,6 +179,8 @@ class OrcDeserializer( i += 1 } + // The ORC map will never have null or duplicated map keys, it's safe to create a + // ArrayBasedMapData directly here. updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) case udt: UserDefinedType[_] => newWriter(udt.sqlType, updater) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 1199725941842..004a96d134132 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -558,8 +558,12 @@ private[parquet] class ParquetRowConverter( override def getConverter(fieldIndex: Int): Converter = keyValueConverter - override def end(): Unit = + override def end(): Unit = { + // The parquet map may contains null or duplicated map keys. When it happens, the behavior is + // undefined. + // TODO (SPARK-26174): disallow it with a config. updater.set(ArrayBasedMapData(currentKeys.toArray, currentValues.toArray)) + } // NOTE: We can't reuse the mutable Map here and must instantiate a new `Map` for the next // value. `Row.copy()` only copies row cells, it doesn't do deep copy to objects stored in row diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 666ba35d7a8f3..e6d1a038a5918 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -89,13 +89,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { val msg1 = intercept[Exception] { df5.select(map_from_arrays($"k", $"v")).collect }.getMessage - assert(msg1.contains("Cannot use null as map key!")) + assert(msg1.contains("Cannot use null as map key")) val df6 = Seq((Seq(1, 2), Seq("a"))).toDF("k", "v") val msg2 = intercept[Exception] { df6.select(map_from_arrays($"k", $"v")).collect }.getMessage - assert(msg2.contains("The given two arrays should have the same length")) + assert(msg2.contains("The key array and value array of MapData must have the same length")) } test("struct with column name") { @@ -2588,7 +2588,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { val ex3 = intercept[Exception] { dfExample1.selectExpr("transform_keys(i, (k, v) -> v)").show() } - assert(ex3.getMessage.contains("Cannot use null as map key!")) + assert(ex3.getMessage.contains("Cannot use null as map key")) val ex4 = intercept[AnalysisException] { dfExample2.selectExpr("transform_keys(j, (k, v) -> k + 1)") From 8bfea86b1c8a65ce73711af02d9e4140659a926d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 29 Nov 2018 01:54:06 +0000 Subject: [PATCH 2159/2461] [SPARK-26133][ML] Remove deprecated OneHotEncoder and rename OneHotEncoderEstimator to OneHotEncoder ## What changes were proposed in this pull request? We have deprecated `OneHotEncoder` at Spark 2.3.0 and introduced `OneHotEncoderEstimator`. At 3.0.0, we remove deprecated `OneHotEncoder` and rename `OneHotEncoderEstimator` to `OneHotEncoder`. TODO: According to ML migration guide, we need to keep `OneHotEncoderEstimator` as an alias after renaming. This is not done at this patch in order to facilitate review. ## How was this patch tested? Existing tests. Closes #23100 from viirya/remove_one_hot_encoder. Authored-by: Liang-Chi Hsieh Signed-off-by: DB Tsai --- docs/ml-features.md | 24 +- docs/ml-guide.md | 6 + ...ple.java => JavaOneHotEncoderExample.java} | 8 +- ...r_example.py => onehot_encoder_example.py} | 8 +- ...ample.scala => OneHotEncoderExample.scala} | 8 +- .../spark/ml/feature/OneHotEncoder.scala | 530 +++++++++++++++--- .../ml/feature/OneHotEncoderEstimator.scala | 528 ----------------- .../apache/spark/ml/feature/RFormula.scala | 4 +- .../feature/OneHotEncoderEstimatorSuite.scala | 422 -------------- .../spark/ml/feature/OneHotEncoderSuite.scala | 411 +++++++++++--- project/MimaExcludes.scala | 12 + python/pyspark/ml/feature.py | 102 +--- 12 files changed, 841 insertions(+), 1222 deletions(-) rename examples/src/main/java/org/apache/spark/examples/ml/{JavaOneHotEncoderEstimatorExample.java => JavaOneHotEncoderExample.java} (91%) rename examples/src/main/python/ml/{onehot_encoder_estimator_example.py => onehot_encoder_example.py} (83%) rename examples/src/main/scala/org/apache/spark/examples/ml/{OneHotEncoderEstimatorExample.scala => OneHotEncoderExample.scala} (89%) delete mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala delete mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index 882b895a9d154..83a211ce02e67 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -779,43 +779,37 @@ for more details on the API. -## OneHotEncoder (Deprecated since 2.3.0) - -Because this existing `OneHotEncoder` is a stateless transformer, it is not usable on new data where the number of categories may differ from the training data. In order to fix this, a new `OneHotEncoderEstimator` was created that produces an `OneHotEncoderModel` when fitting. For more detail, please see [SPARK-13030](https://issues.apache.org/jira/browse/SPARK-13030). - -`OneHotEncoder` has been deprecated in 2.3.0 and will be removed in 3.0.0. Please use [OneHotEncoderEstimator](ml-features.html#onehotencoderestimator) instead. - -## OneHotEncoderEstimator +## OneHotEncoder [One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a categorical feature, represented as a label index, to a binary vector with at most a single one-value indicating the presence of a specific feature value from among the set of all feature values. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features. For string type input data, it is common to encode categorical features using [StringIndexer](ml-features.html#stringindexer) first. -`OneHotEncoderEstimator` can transform multiple columns, returning an one-hot-encoded output vector column for each input column. It is common to merge these vectors into a single feature vector using [VectorAssembler](ml-features.html#vectorassembler). +`OneHotEncoder` can transform multiple columns, returning an one-hot-encoded output vector column for each input column. It is common to merge these vectors into a single feature vector using [VectorAssembler](ml-features.html#vectorassembler). -`OneHotEncoderEstimator` supports the `handleInvalid` parameter to choose how to handle invalid input during transforming data. Available options include 'keep' (any invalid inputs are assigned to an extra categorical index) and 'error' (throw an error). +`OneHotEncoder` supports the `handleInvalid` parameter to choose how to handle invalid input during transforming data. Available options include 'keep' (any invalid inputs are assigned to an extra categorical index) and 'error' (throw an error). **Examples**
      -Refer to the [OneHotEncoderEstimator Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoderEstimator) for more details on the API. +Refer to the [OneHotEncoder Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoder) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala %} +{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala %}
      -Refer to the [OneHotEncoderEstimator Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoderEstimator.html) +Refer to the [OneHotEncoder Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoder.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java %} +{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java %}
      -Refer to the [OneHotEncoderEstimator Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoderEstimator) for more details on the API. +Refer to the [OneHotEncoder Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoder) for more details on the API. -{% include_example python/ml/onehot_encoder_estimator_example.py %} +{% include_example python/ml/onehot_encoder_example.py %}
      diff --git a/docs/ml-guide.md b/docs/ml-guide.md index aea07be34cb86..57d4e1fe9d33a 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -104,6 +104,12 @@ MLlib is under active development. The APIs marked `Experimental`/`DeveloperApi` may change in future releases, and the migration guide below will explain all changes between releases. +## From 2.4 to 3.0 + +### Breaking changes + +* `OneHotEncoder` which is deprecated in 2.3, is removed in 3.0 and `OneHotEncoderEstimator` is now renamed to `OneHotEncoder`. + ## From 2.2 to 2.3 ### Breaking changes diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java similarity index 91% rename from examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java rename to examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java index 6f93cff94b725..4b49bebf7ccfe 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java @@ -23,7 +23,7 @@ import java.util.Arrays; import java.util.List; -import org.apache.spark.ml.feature.OneHotEncoderEstimator; +import org.apache.spark.ml.feature.OneHotEncoder; import org.apache.spark.ml.feature.OneHotEncoderModel; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; @@ -34,11 +34,11 @@ import org.apache.spark.sql.types.StructType; // $example off$ -public class JavaOneHotEncoderEstimatorExample { +public class JavaOneHotEncoderExample { public static void main(String[] args) { SparkSession spark = SparkSession .builder() - .appName("JavaOneHotEncoderEstimatorExample") + .appName("JavaOneHotEncoderExample") .getOrCreate(); // Note: categorical features are usually first encoded with StringIndexer @@ -59,7 +59,7 @@ public static void main(String[] args) { Dataset df = spark.createDataFrame(data, schema); - OneHotEncoderEstimator encoder = new OneHotEncoderEstimator() + OneHotEncoder encoder = new OneHotEncoder() .setInputCols(new String[] {"categoryIndex1", "categoryIndex2"}) .setOutputCols(new String[] {"categoryVec1", "categoryVec2"}); diff --git a/examples/src/main/python/ml/onehot_encoder_estimator_example.py b/examples/src/main/python/ml/onehot_encoder_example.py similarity index 83% rename from examples/src/main/python/ml/onehot_encoder_estimator_example.py rename to examples/src/main/python/ml/onehot_encoder_example.py index 2723e681cea7c..73775b79e36cb 100644 --- a/examples/src/main/python/ml/onehot_encoder_estimator_example.py +++ b/examples/src/main/python/ml/onehot_encoder_example.py @@ -18,14 +18,14 @@ from __future__ import print_function # $example on$ -from pyspark.ml.feature import OneHotEncoderEstimator +from pyspark.ml.feature import OneHotEncoder # $example off$ from pyspark.sql import SparkSession if __name__ == "__main__": spark = SparkSession\ .builder\ - .appName("OneHotEncoderEstimatorExample")\ + .appName("OneHotEncoderExample")\ .getOrCreate() # Note: categorical features are usually first encoded with StringIndexer @@ -39,8 +39,8 @@ (2.0, 0.0) ], ["categoryIndex1", "categoryIndex2"]) - encoder = OneHotEncoderEstimator(inputCols=["categoryIndex1", "categoryIndex2"], - outputCols=["categoryVec1", "categoryVec2"]) + encoder = OneHotEncoder(inputCols=["categoryIndex1", "categoryIndex2"], + outputCols=["categoryVec1", "categoryVec2"]) model = encoder.fit(df) encoded = model.transform(df) encoded.show() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala similarity index 89% rename from examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala rename to examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala index 45d816808ed8e..742f3cdeea35c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala @@ -19,15 +19,15 @@ package org.apache.spark.examples.ml // $example on$ -import org.apache.spark.ml.feature.OneHotEncoderEstimator +import org.apache.spark.ml.feature.OneHotEncoder // $example off$ import org.apache.spark.sql.SparkSession -object OneHotEncoderEstimatorExample { +object OneHotEncoderExample { def main(args: Array[String]): Unit = { val spark = SparkSession .builder - .appName("OneHotEncoderEstimatorExample") + .appName("OneHotEncoderExample") .getOrCreate() // Note: categorical features are usually first encoded with StringIndexer @@ -41,7 +41,7 @@ object OneHotEncoderEstimatorExample { (2.0, 0.0) )).toDF("categoryIndex1", "categoryIndex2") - val encoder = new OneHotEncoderEstimator() + val encoder = new OneHotEncoder() .setInputCols(Array("categoryIndex1", "categoryIndex2")) .setOutputCols(Array("categoryVec1", "categoryVec2")) val model = encoder.fit(df) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 27e4869a020b7..ec9792cbbda8f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -17,126 +17,512 @@ package org.apache.spark.ml.feature +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkException import org.apache.spark.annotation.Since -import org.apache.spark.ml.Transformer +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCols, HasOutputCols} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} -import org.apache.spark.sql.functions.{col, udf} -import org.apache.spark.sql.types.{DoubleType, NumericType, StructType} +import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.functions.{col, lit, udf} +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} + +/** Private trait for params and common methods for OneHotEncoder and OneHotEncoderModel */ +private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid + with HasInputCols with HasOutputCols { + + /** + * Param for how to handle invalid data during transform(). + * Options are 'keep' (invalid data presented as an extra categorical feature) or + * 'error' (throw an error). + * Note that this Param is only used during transform; during fitting, invalid data + * will result in an error. + * Default: "error" + * @group param + */ + @Since("2.3.0") + override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", + "How to handle invalid data during transform(). " + + "Options are 'keep' (invalid data presented as an extra categorical feature) " + + "or error (throw an error). Note that this Param is only used during transform; " + + "during fitting, invalid data will result in an error.", + ParamValidators.inArray(OneHotEncoder.supportedHandleInvalids)) + + setDefault(handleInvalid, OneHotEncoder.ERROR_INVALID) + + /** + * Whether to drop the last category in the encoded vector (default: true) + * @group param + */ + @Since("2.3.0") + final val dropLast: BooleanParam = + new BooleanParam(this, "dropLast", "whether to drop the last category") + setDefault(dropLast -> true) + + /** @group getParam */ + @Since("2.3.0") + def getDropLast: Boolean = $(dropLast) + + protected def validateAndTransformSchema( + schema: StructType, + dropLast: Boolean, + keepInvalid: Boolean): StructType = { + val inputColNames = $(inputCols) + val outputColNames = $(outputCols) + + require(inputColNames.length == outputColNames.length, + s"The number of input columns ${inputColNames.length} must be the same as the number of " + + s"output columns ${outputColNames.length}.") + + // Input columns must be NumericType. + inputColNames.foreach(SchemaUtils.checkNumericType(schema, _)) + + // Prepares output columns with proper attributes by examining input columns. + val inputFields = $(inputCols).map(schema(_)) + + val outputFields = inputFields.zip(outputColNames).map { case (inputField, outputColName) => + OneHotEncoderCommon.transformOutputColumnSchema( + inputField, outputColName, dropLast, keepInvalid) + } + outputFields.foldLeft(schema) { case (newSchema, outputField) => + SchemaUtils.appendColumn(newSchema, outputField) + } + } +} /** * A one-hot encoder that maps a column of category indices to a column of binary vectors, with * at most a single one-value per row that indicates the input category index. * For example with 5 categories, an input value of 2.0 would map to an output vector of * `[0.0, 0.0, 1.0, 0.0]`. - * The last category is not included by default (configurable via `OneHotEncoder!.dropLast` + * The last category is not included by default (configurable via `dropLast`), * because it makes the vector entries sum up to one, and hence linearly dependent. * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. * * @note This is different from scikit-learn's OneHotEncoder, which keeps all categories. * The output vectors are sparse. * + * When `handleInvalid` is configured to 'keep', an extra "category" indicating invalid values is + * added as last category. So when `dropLast` is true, invalid values are encoded as all-zeros + * vector. + * + * @note When encoding multi-column by using `inputCols` and `outputCols` params, input/output cols + * come in pairs, specified by the order in the arrays, and each pair is treated independently. + * * @see `StringIndexer` for converting categorical values into category indices - * @deprecated `OneHotEncoderEstimator` will be renamed `OneHotEncoder` and this `OneHotEncoder` - * will be removed in 3.0.0. */ -@Since("1.4.0") -@deprecated("`OneHotEncoderEstimator` will be renamed `OneHotEncoder` and this `OneHotEncoder`" + - " will be removed in 3.0.0.", "2.3.0") -class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Transformer - with HasInputCol with HasOutputCol with DefaultParamsWritable { +@Since("3.0.0") +class OneHotEncoder @Since("3.0.0") (@Since("3.0.0") override val uid: String) + extends Estimator[OneHotEncoderModel] with OneHotEncoderBase with DefaultParamsWritable { - @Since("1.4.0") - def this() = this(Identifiable.randomUID("oneHot")) + @Since("3.0.0") + def this() = this(Identifiable.randomUID("oneHotEncoder")) - /** - * Whether to drop the last category in the encoded vector (default: true) - * @group param - */ - @Since("1.4.0") - final val dropLast: BooleanParam = - new BooleanParam(this, "dropLast", "whether to drop the last category") - setDefault(dropLast -> true) + /** @group setParam */ + @Since("3.0.0") + def setInputCols(values: Array[String]): this.type = set(inputCols, values) - /** @group getParam */ - @Since("2.0.0") - def getDropLast: Boolean = $(dropLast) + /** @group setParam */ + @Since("3.0.0") + def setOutputCols(values: Array[String]): this.type = set(outputCols, values) /** @group setParam */ - @Since("1.4.0") + @Since("3.0.0") def setDropLast(value: Boolean): this.type = set(dropLast, value) /** @group setParam */ - @Since("1.4.0") - def setInputCol(value: String): this.type = set(inputCol, value) + @Since("3.0.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + + @Since("3.0.0") + override def transformSchema(schema: StructType): StructType = { + val keepInvalid = $(handleInvalid) == OneHotEncoder.KEEP_INVALID + validateAndTransformSchema(schema, dropLast = $(dropLast), + keepInvalid = keepInvalid) + } + + @Since("3.0.0") + override def fit(dataset: Dataset[_]): OneHotEncoderModel = { + transformSchema(dataset.schema) + + // Compute the plain number of categories without `handleInvalid` and + // `dropLast` taken into account. + val transformedSchema = validateAndTransformSchema(dataset.schema, dropLast = false, + keepInvalid = false) + val categorySizes = new Array[Int]($(outputCols).length) + + val columnToScanIndices = $(outputCols).zipWithIndex.flatMap { case (outputColName, idx) => + val numOfAttrs = AttributeGroup.fromStructField( + transformedSchema(outputColName)).size + if (numOfAttrs < 0) { + Some(idx) + } else { + categorySizes(idx) = numOfAttrs + None + } + } + + // Some input columns don't have attributes or their attributes don't have necessary info. + // We need to scan the data to get the number of values for each column. + if (columnToScanIndices.length > 0) { + val inputColNames = columnToScanIndices.map($(inputCols)(_)) + val outputColNames = columnToScanIndices.map($(outputCols)(_)) + + // When fitting data, we want the plain number of categories without `handleInvalid` and + // `dropLast` taken into account. + val attrGroups = OneHotEncoderCommon.getOutputAttrGroupFromData( + dataset, inputColNames, outputColNames, dropLast = false) + attrGroups.zip(columnToScanIndices).foreach { case (attrGroup, idx) => + categorySizes(idx) = attrGroup.size + } + } + + val model = new OneHotEncoderModel(uid, categorySizes).setParent(this) + copyValues(model) + } + + @Since("3.0.0") + override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra) +} + +@Since("3.0.0") +object OneHotEncoder extends DefaultParamsReadable[OneHotEncoder] { + + private[feature] val KEEP_INVALID: String = "keep" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val supportedHandleInvalids: Array[String] = Array(KEEP_INVALID, ERROR_INVALID) + + @Since("3.0.0") + override def load(path: String): OneHotEncoder = super.load(path) +} + +/** + * @param categorySizes Original number of categories for each feature being encoded. + * The array contains one value for each input column, in order. + */ +@Since("3.0.0") +class OneHotEncoderModel private[ml] ( + @Since("3.0.0") override val uid: String, + @Since("3.0.0") val categorySizes: Array[Int]) + extends Model[OneHotEncoderModel] with OneHotEncoderBase with MLWritable { + + import OneHotEncoderModel._ + + // Returns the category size for each index with `dropLast` and `handleInvalid` + // taken into account. + private def getConfigedCategorySizes: Array[Int] = { + val dropLast = getDropLast + val keepInvalid = getHandleInvalid == OneHotEncoder.KEEP_INVALID + + if (!dropLast && keepInvalid) { + // When `handleInvalid` is "keep", an extra category is added as last category + // for invalid data. + categorySizes.map(_ + 1) + } else if (dropLast && !keepInvalid) { + // When `dropLast` is true, the last category is removed. + categorySizes.map(_ - 1) + } else { + // When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid + // data is removed. Thus, it is the same as the plain number of categories. + categorySizes + } + } + + private def encoder: UserDefinedFunction = { + val keepInvalid = getHandleInvalid == OneHotEncoder.KEEP_INVALID + val configedSizes = getConfigedCategorySizes + val localCategorySizes = categorySizes + + // The udf performed on input data. The first parameter is the input value. The second + // parameter is the index in inputCols of the column being encoded. + udf { (label: Double, colIdx: Int) => + val origCategorySize = localCategorySizes(colIdx) + // idx: index in vector of the single 1-valued element + val idx = if (label >= 0 && label < origCategorySize) { + label + } else { + if (keepInvalid) { + origCategorySize + } else { + if (label < 0) { + throw new SparkException(s"Negative value: $label. Input can't be negative. " + + s"To handle invalid values, set Param handleInvalid to " + + s"${OneHotEncoder.KEEP_INVALID}") + } else { + throw new SparkException(s"Unseen value: $label. To handle unseen values, " + + s"set Param handleInvalid to ${OneHotEncoder.KEEP_INVALID}.") + } + } + } + + val size = configedSizes(colIdx) + if (idx < size) { + Vectors.sparse(size, Array(idx.toInt), Array(1.0)) + } else { + Vectors.sparse(size, Array.empty[Int], Array.empty[Double]) + } + } + } /** @group setParam */ - @Since("1.4.0") - def setOutputCol(value: String): this.type = set(outputCol, value) + @Since("3.0.0") + def setInputCols(values: Array[String]): this.type = set(inputCols, values) - @Since("1.4.0") + /** @group setParam */ + @Since("3.0.0") + def setOutputCols(values: Array[String]): this.type = set(outputCols, values) + + /** @group setParam */ + @Since("3.0.0") + def setDropLast(value: Boolean): this.type = set(dropLast, value) + + /** @group setParam */ + @Since("3.0.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + + @Since("3.0.0") override def transformSchema(schema: StructType): StructType = { - val inputColName = $(inputCol) - val outputColName = $(outputCol) - val inputFields = schema.fields + val inputColNames = $(inputCols) + + require(inputColNames.length == categorySizes.length, + s"The number of input columns ${inputColNames.length} must be the same as the number of " + + s"features ${categorySizes.length} during fitting.") + + val keepInvalid = $(handleInvalid) == OneHotEncoder.KEEP_INVALID + val transformedSchema = validateAndTransformSchema(schema, dropLast = $(dropLast), + keepInvalid = keepInvalid) + verifyNumOfValues(transformedSchema) + } - require(schema(inputColName).dataType.isInstanceOf[NumericType], - s"Input column must be of type ${NumericType.simpleString} but got " + - schema(inputColName).dataType.catalogString) - require(!inputFields.exists(_.name == outputColName), - s"Output column $outputColName already exists.") + /** + * If the metadata of input columns also specifies the number of categories, we need to + * compare with expected category number with `handleInvalid` and `dropLast` taken into + * account. Mismatched numbers will cause exception. + */ + private def verifyNumOfValues(schema: StructType): StructType = { + val configedSizes = getConfigedCategorySizes + $(outputCols).zipWithIndex.foreach { case (outputColName, idx) => + val inputColName = $(inputCols)(idx) + val attrGroup = AttributeGroup.fromStructField(schema(outputColName)) - val outputField = OneHotEncoderCommon.transformOutputColumnSchema( - schema(inputColName), outputColName, $(dropLast)) - val outputFields = inputFields :+ outputField - StructType(outputFields) + // If the input metadata specifies number of category for output column, + // comparing with expected category number with `handleInvalid` and + // `dropLast` taken into account. + if (attrGroup.attributes.nonEmpty) { + val numCategories = configedSizes(idx) + require(attrGroup.size == numCategories, "OneHotEncoderModel expected " + + s"$numCategories categorical values for input column $inputColName, " + + s"but the input column had metadata specifying ${attrGroup.size} values.") + } + } + schema } - @Since("2.0.0") + @Since("3.0.0") override def transform(dataset: Dataset[_]): DataFrame = { - // schema transformation - val inputColName = $(inputCol) - val outputColName = $(outputCol) + val transformedSchema = transformSchema(dataset.schema, logging = true) + val keepInvalid = $(handleInvalid) == OneHotEncoder.KEEP_INVALID - val outputAttrGroupFromSchema = AttributeGroup.fromStructField( - transformSchema(dataset.schema)(outputColName)) + val encodedColumns = $(inputCols).indices.map { idx => + val inputColName = $(inputCols)(idx) + val outputColName = $(outputCols)(idx) - val outputAttrGroup = if (outputAttrGroupFromSchema.size < 0) { - OneHotEncoderCommon.getOutputAttrGroupFromData( - dataset, Seq(inputColName), Seq(outputColName), $(dropLast))(0) - } else { - outputAttrGroupFromSchema + val outputAttrGroupFromSchema = + AttributeGroup.fromStructField(transformedSchema(outputColName)) + + val metadata = if (outputAttrGroupFromSchema.size < 0) { + OneHotEncoderCommon.createAttrGroupForAttrNames(outputColName, + categorySizes(idx), $(dropLast), keepInvalid).toMetadata() + } else { + outputAttrGroupFromSchema.toMetadata() + } + + encoder(col(inputColName).cast(DoubleType), lit(idx)) + .as(outputColName, metadata) + } + dataset.withColumns($(outputCols), encodedColumns) + } + + @Since("3.0.0") + override def copy(extra: ParamMap): OneHotEncoderModel = { + val copied = new OneHotEncoderModel(uid, categorySizes) + copyValues(copied, extra).setParent(parent) + } + + @Since("3.0.0") + override def write: MLWriter = new OneHotEncoderModelWriter(this) +} + +@Since("3.0.0") +object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] { + + private[OneHotEncoderModel] + class OneHotEncoderModelWriter(instance: OneHotEncoderModel) extends MLWriter { + + private case class Data(categorySizes: Array[Int]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.categorySizes) + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class OneHotEncoderModelReader extends MLReader[OneHotEncoderModel] { + + private val className = classOf[OneHotEncoderModel].getName + + override def load(path: String): OneHotEncoderModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sparkSession.read.parquet(dataPath) + .select("categorySizes") + .head() + val categorySizes = data.getAs[Seq[Int]](0).toArray + val model = new OneHotEncoderModel(metadata.uid, categorySizes) + metadata.getAndSetParams(model) + model + } + } + + @Since("3.0.0") + override def read: MLReader[OneHotEncoderModel] = new OneHotEncoderModelReader + + @Since("3.0.0") + override def load(path: String): OneHotEncoderModel = super.load(path) +} + +/** + * Provides some helper methods used by `OneHotEncoder`. + */ +private[feature] object OneHotEncoderCommon { + + private def genOutputAttrNames(inputCol: StructField): Option[Array[String]] = { + val inputAttr = Attribute.fromStructField(inputCol) + inputAttr match { + case nominal: NominalAttribute => + if (nominal.values.isDefined) { + nominal.values + } else if (nominal.numValues.isDefined) { + nominal.numValues.map(n => Array.tabulate(n)(_.toString)) + } else { + None + } + case binary: BinaryAttribute => + if (binary.values.isDefined) { + binary.values + } else { + Some(Array.tabulate(2)(_.toString)) + } + case _: NumericAttribute => + throw new RuntimeException( + s"The input column ${inputCol.name} cannot be continuous-value.") + case _ => + None // optimistic about unknown attributes } + } - val metadata = outputAttrGroup.toMetadata() + /** Creates an `AttributeGroup` filled by the `BinaryAttribute` named as required. */ + private def genOutputAttrGroup( + outputAttrNames: Option[Array[String]], + outputColName: String): AttributeGroup = { + outputAttrNames.map { attrNames => + val attrs: Array[Attribute] = attrNames.map { name => + BinaryAttribute.defaultAttr.withName(name) + } + new AttributeGroup(outputColName, attrs) + }.getOrElse{ + new AttributeGroup(outputColName) + } + } - // data transformation - val size = outputAttrGroup.size - val oneValue = Array(1.0) - val emptyValues = Array.empty[Double] - val emptyIndices = Array.empty[Int] - val encode = udf { label: Double => - if (label < size) { - Vectors.sparse(size, Array(label.toInt), oneValue) + /** + * Prepares the `StructField` with proper metadata for `OneHotEncoder`'s output column. + */ + def transformOutputColumnSchema( + inputCol: StructField, + outputColName: String, + dropLast: Boolean, + keepInvalid: Boolean = false): StructField = { + val outputAttrNames = genOutputAttrNames(inputCol) + val filteredOutputAttrNames = outputAttrNames.map { names => + if (dropLast && !keepInvalid) { + require(names.length > 1, + s"The input column ${inputCol.name} should have at least two distinct values.") + names.dropRight(1) + } else if (!dropLast && keepInvalid) { + names ++ Seq("invalidValues") } else { - Vectors.sparse(size, emptyIndices, emptyValues) + names } } - dataset.select(col("*"), encode(col(inputColName).cast(DoubleType)).as(outputColName, metadata)) + genOutputAttrGroup(filteredOutputAttrNames, outputColName).toStructField() } - @Since("1.4.1") - override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra) -} + /** + * This method is called when we want to generate `AttributeGroup` from actual data for + * one-hot encoder. + */ + def getOutputAttrGroupFromData( + dataset: Dataset[_], + inputColNames: Seq[String], + outputColNames: Seq[String], + dropLast: Boolean): Seq[AttributeGroup] = { + // The RDD approach has advantage of early-stop if any values are invalid. It seems that + // DataFrame ops don't have equivalent functions. + val columns = inputColNames.map { inputColName => + col(inputColName).cast(DoubleType) + } + val numOfColumns = columns.length -@Since("1.6.0") -object OneHotEncoder extends DefaultParamsReadable[OneHotEncoder] { + val numAttrsArray = dataset.select(columns: _*).rdd.map { row => + (0 until numOfColumns).map(idx => row.getDouble(idx)).toArray + }.treeAggregate(new Array[Double](numOfColumns))( + (maxValues, curValues) => { + (0 until numOfColumns).foreach { idx => + val x = curValues(idx) + assert(x <= Int.MaxValue, + s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x.") + assert(x >= 0.0 && x == x.toInt, + s"Values from column ${inputColNames(idx)} must be indices, but got $x.") + maxValues(idx) = math.max(maxValues(idx), x) + } + maxValues + }, + (m0, m1) => { + (0 until numOfColumns).foreach { idx => + m0(idx) = math.max(m0(idx), m1(idx)) + } + m0 + } + ).map(_.toInt + 1) - @Since("1.6.0") - override def load(path: String): OneHotEncoder = super.load(path) + outputColNames.zip(numAttrsArray).map { case (outputColName, numAttrs) => + createAttrGroupForAttrNames(outputColName, numAttrs, dropLast, keepInvalid = false) + } + } + + /** Creates an `AttributeGroup` with the required number of `BinaryAttribute`. */ + def createAttrGroupForAttrNames( + outputColName: String, + numAttrs: Int, + dropLast: Boolean, + keepInvalid: Boolean): AttributeGroup = { + val outputAttrNames = Array.tabulate(numAttrs)(_.toString) + val filtered = if (dropLast && !keepInvalid) { + outputAttrNames.dropRight(1) + } else if (!dropLast && keepInvalid) { + outputAttrNames ++ Seq("invalidValues") + } else { + outputAttrNames + } + genOutputAttrGroup(Some(filtered), outputColName) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala deleted file mode 100644 index 4a44f3186538d..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ /dev/null @@ -1,528 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.feature - -import org.apache.hadoop.fs.Path - -import org.apache.spark.SparkException -import org.apache.spark.annotation.Since -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.linalg.Vectors -import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCols, HasOutputCols} -import org.apache.spark.ml.util._ -import org.apache.spark.sql.{DataFrame, Dataset} -import org.apache.spark.sql.expressions.UserDefinedFunction -import org.apache.spark.sql.functions.{col, lit, udf} -import org.apache.spark.sql.types.{DoubleType, StructField, StructType} - -/** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */ -private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid - with HasInputCols with HasOutputCols { - - /** - * Param for how to handle invalid data during transform(). - * Options are 'keep' (invalid data presented as an extra categorical feature) or - * 'error' (throw an error). - * Note that this Param is only used during transform; during fitting, invalid data - * will result in an error. - * Default: "error" - * @group param - */ - @Since("2.3.0") - override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", - "How to handle invalid data during transform(). " + - "Options are 'keep' (invalid data presented as an extra categorical feature) " + - "or error (throw an error). Note that this Param is only used during transform; " + - "during fitting, invalid data will result in an error.", - ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids)) - - setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID) - - /** - * Whether to drop the last category in the encoded vector (default: true) - * @group param - */ - @Since("2.3.0") - final val dropLast: BooleanParam = - new BooleanParam(this, "dropLast", "whether to drop the last category") - setDefault(dropLast -> true) - - /** @group getParam */ - @Since("2.3.0") - def getDropLast: Boolean = $(dropLast) - - protected def validateAndTransformSchema( - schema: StructType, - dropLast: Boolean, - keepInvalid: Boolean): StructType = { - val inputColNames = $(inputCols) - val outputColNames = $(outputCols) - - require(inputColNames.length == outputColNames.length, - s"The number of input columns ${inputColNames.length} must be the same as the number of " + - s"output columns ${outputColNames.length}.") - - // Input columns must be NumericType. - inputColNames.foreach(SchemaUtils.checkNumericType(schema, _)) - - // Prepares output columns with proper attributes by examining input columns. - val inputFields = $(inputCols).map(schema(_)) - - val outputFields = inputFields.zip(outputColNames).map { case (inputField, outputColName) => - OneHotEncoderCommon.transformOutputColumnSchema( - inputField, outputColName, dropLast, keepInvalid) - } - outputFields.foldLeft(schema) { case (newSchema, outputField) => - SchemaUtils.appendColumn(newSchema, outputField) - } - } -} - -/** - * A one-hot encoder that maps a column of category indices to a column of binary vectors, with - * at most a single one-value per row that indicates the input category index. - * For example with 5 categories, an input value of 2.0 would map to an output vector of - * `[0.0, 0.0, 1.0, 0.0]`. - * The last category is not included by default (configurable via `dropLast`), - * because it makes the vector entries sum up to one, and hence linearly dependent. - * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. - * - * @note This is different from scikit-learn's OneHotEncoder, which keeps all categories. - * The output vectors are sparse. - * - * When `handleInvalid` is configured to 'keep', an extra "category" indicating invalid values is - * added as last category. So when `dropLast` is true, invalid values are encoded as all-zeros - * vector. - * - * @note When encoding multi-column by using `inputCols` and `outputCols` params, input/output cols - * come in pairs, specified by the order in the arrays, and each pair is treated independently. - * - * @see `StringIndexer` for converting categorical values into category indices - */ -@Since("2.3.0") -class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: String) - extends Estimator[OneHotEncoderModel] with OneHotEncoderBase with DefaultParamsWritable { - - @Since("2.3.0") - def this() = this(Identifiable.randomUID("oneHotEncoder")) - - /** @group setParam */ - @Since("2.3.0") - def setInputCols(values: Array[String]): this.type = set(inputCols, values) - - /** @group setParam */ - @Since("2.3.0") - def setOutputCols(values: Array[String]): this.type = set(outputCols, values) - - /** @group setParam */ - @Since("2.3.0") - def setDropLast(value: Boolean): this.type = set(dropLast, value) - - /** @group setParam */ - @Since("2.3.0") - def setHandleInvalid(value: String): this.type = set(handleInvalid, value) - - @Since("2.3.0") - override def transformSchema(schema: StructType): StructType = { - val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID - validateAndTransformSchema(schema, dropLast = $(dropLast), - keepInvalid = keepInvalid) - } - - @Since("2.3.0") - override def fit(dataset: Dataset[_]): OneHotEncoderModel = { - transformSchema(dataset.schema) - - // Compute the plain number of categories without `handleInvalid` and - // `dropLast` taken into account. - val transformedSchema = validateAndTransformSchema(dataset.schema, dropLast = false, - keepInvalid = false) - val categorySizes = new Array[Int]($(outputCols).length) - - val columnToScanIndices = $(outputCols).zipWithIndex.flatMap { case (outputColName, idx) => - val numOfAttrs = AttributeGroup.fromStructField( - transformedSchema(outputColName)).size - if (numOfAttrs < 0) { - Some(idx) - } else { - categorySizes(idx) = numOfAttrs - None - } - } - - // Some input columns don't have attributes or their attributes don't have necessary info. - // We need to scan the data to get the number of values for each column. - if (columnToScanIndices.length > 0) { - val inputColNames = columnToScanIndices.map($(inputCols)(_)) - val outputColNames = columnToScanIndices.map($(outputCols)(_)) - - // When fitting data, we want the plain number of categories without `handleInvalid` and - // `dropLast` taken into account. - val attrGroups = OneHotEncoderCommon.getOutputAttrGroupFromData( - dataset, inputColNames, outputColNames, dropLast = false) - attrGroups.zip(columnToScanIndices).foreach { case (attrGroup, idx) => - categorySizes(idx) = attrGroup.size - } - } - - val model = new OneHotEncoderModel(uid, categorySizes).setParent(this) - copyValues(model) - } - - @Since("2.3.0") - override def copy(extra: ParamMap): OneHotEncoderEstimator = defaultCopy(extra) -} - -@Since("2.3.0") -object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimator] { - - private[feature] val KEEP_INVALID: String = "keep" - private[feature] val ERROR_INVALID: String = "error" - private[feature] val supportedHandleInvalids: Array[String] = Array(KEEP_INVALID, ERROR_INVALID) - - @Since("2.3.0") - override def load(path: String): OneHotEncoderEstimator = super.load(path) -} - -/** - * @param categorySizes Original number of categories for each feature being encoded. - * The array contains one value for each input column, in order. - */ -@Since("2.3.0") -class OneHotEncoderModel private[ml] ( - @Since("2.3.0") override val uid: String, - @Since("2.3.0") val categorySizes: Array[Int]) - extends Model[OneHotEncoderModel] with OneHotEncoderBase with MLWritable { - - import OneHotEncoderModel._ - - // Returns the category size for each index with `dropLast` and `handleInvalid` - // taken into account. - private def getConfigedCategorySizes: Array[Int] = { - val dropLast = getDropLast - val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID - - if (!dropLast && keepInvalid) { - // When `handleInvalid` is "keep", an extra category is added as last category - // for invalid data. - categorySizes.map(_ + 1) - } else if (dropLast && !keepInvalid) { - // When `dropLast` is true, the last category is removed. - categorySizes.map(_ - 1) - } else { - // When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid - // data is removed. Thus, it is the same as the plain number of categories. - categorySizes - } - } - - private def encoder: UserDefinedFunction = { - val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID - val configedSizes = getConfigedCategorySizes - val localCategorySizes = categorySizes - - // The udf performed on input data. The first parameter is the input value. The second - // parameter is the index in inputCols of the column being encoded. - udf { (label: Double, colIdx: Int) => - val origCategorySize = localCategorySizes(colIdx) - // idx: index in vector of the single 1-valued element - val idx = if (label >= 0 && label < origCategorySize) { - label - } else { - if (keepInvalid) { - origCategorySize - } else { - if (label < 0) { - throw new SparkException(s"Negative value: $label. Input can't be negative. " + - s"To handle invalid values, set Param handleInvalid to " + - s"${OneHotEncoderEstimator.KEEP_INVALID}") - } else { - throw new SparkException(s"Unseen value: $label. To handle unseen values, " + - s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.") - } - } - } - - val size = configedSizes(colIdx) - if (idx < size) { - Vectors.sparse(size, Array(idx.toInt), Array(1.0)) - } else { - Vectors.sparse(size, Array.empty[Int], Array.empty[Double]) - } - } - } - - /** @group setParam */ - @Since("2.3.0") - def setInputCols(values: Array[String]): this.type = set(inputCols, values) - - /** @group setParam */ - @Since("2.3.0") - def setOutputCols(values: Array[String]): this.type = set(outputCols, values) - - /** @group setParam */ - @Since("2.3.0") - def setDropLast(value: Boolean): this.type = set(dropLast, value) - - /** @group setParam */ - @Since("2.3.0") - def setHandleInvalid(value: String): this.type = set(handleInvalid, value) - - @Since("2.3.0") - override def transformSchema(schema: StructType): StructType = { - val inputColNames = $(inputCols) - - require(inputColNames.length == categorySizes.length, - s"The number of input columns ${inputColNames.length} must be the same as the number of " + - s"features ${categorySizes.length} during fitting.") - - val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID - val transformedSchema = validateAndTransformSchema(schema, dropLast = $(dropLast), - keepInvalid = keepInvalid) - verifyNumOfValues(transformedSchema) - } - - /** - * If the metadata of input columns also specifies the number of categories, we need to - * compare with expected category number with `handleInvalid` and `dropLast` taken into - * account. Mismatched numbers will cause exception. - */ - private def verifyNumOfValues(schema: StructType): StructType = { - val configedSizes = getConfigedCategorySizes - $(outputCols).zipWithIndex.foreach { case (outputColName, idx) => - val inputColName = $(inputCols)(idx) - val attrGroup = AttributeGroup.fromStructField(schema(outputColName)) - - // If the input metadata specifies number of category for output column, - // comparing with expected category number with `handleInvalid` and - // `dropLast` taken into account. - if (attrGroup.attributes.nonEmpty) { - val numCategories = configedSizes(idx) - require(attrGroup.size == numCategories, "OneHotEncoderModel expected " + - s"$numCategories categorical values for input column $inputColName, " + - s"but the input column had metadata specifying ${attrGroup.size} values.") - } - } - schema - } - - @Since("2.3.0") - override def transform(dataset: Dataset[_]): DataFrame = { - val transformedSchema = transformSchema(dataset.schema, logging = true) - val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID - - val encodedColumns = $(inputCols).indices.map { idx => - val inputColName = $(inputCols)(idx) - val outputColName = $(outputCols)(idx) - - val outputAttrGroupFromSchema = - AttributeGroup.fromStructField(transformedSchema(outputColName)) - - val metadata = if (outputAttrGroupFromSchema.size < 0) { - OneHotEncoderCommon.createAttrGroupForAttrNames(outputColName, - categorySizes(idx), $(dropLast), keepInvalid).toMetadata() - } else { - outputAttrGroupFromSchema.toMetadata() - } - - encoder(col(inputColName).cast(DoubleType), lit(idx)) - .as(outputColName, metadata) - } - dataset.withColumns($(outputCols), encodedColumns) - } - - @Since("2.3.0") - override def copy(extra: ParamMap): OneHotEncoderModel = { - val copied = new OneHotEncoderModel(uid, categorySizes) - copyValues(copied, extra).setParent(parent) - } - - @Since("2.3.0") - override def write: MLWriter = new OneHotEncoderModelWriter(this) -} - -@Since("2.3.0") -object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] { - - private[OneHotEncoderModel] - class OneHotEncoderModelWriter(instance: OneHotEncoderModel) extends MLWriter { - - private case class Data(categorySizes: Array[Int]) - - override protected def saveImpl(path: String): Unit = { - DefaultParamsWriter.saveMetadata(instance, path, sc) - val data = Data(instance.categorySizes) - val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) - } - } - - private class OneHotEncoderModelReader extends MLReader[OneHotEncoderModel] { - - private val className = classOf[OneHotEncoderModel].getName - - override def load(path: String): OneHotEncoderModel = { - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) - val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath) - .select("categorySizes") - .head() - val categorySizes = data.getAs[Seq[Int]](0).toArray - val model = new OneHotEncoderModel(metadata.uid, categorySizes) - metadata.getAndSetParams(model) - model - } - } - - @Since("2.3.0") - override def read: MLReader[OneHotEncoderModel] = new OneHotEncoderModelReader - - @Since("2.3.0") - override def load(path: String): OneHotEncoderModel = super.load(path) -} - -/** - * Provides some helper methods used by both `OneHotEncoder` and `OneHotEncoderEstimator`. - */ -private[feature] object OneHotEncoderCommon { - - private def genOutputAttrNames(inputCol: StructField): Option[Array[String]] = { - val inputAttr = Attribute.fromStructField(inputCol) - inputAttr match { - case nominal: NominalAttribute => - if (nominal.values.isDefined) { - nominal.values - } else if (nominal.numValues.isDefined) { - nominal.numValues.map(n => Array.tabulate(n)(_.toString)) - } else { - None - } - case binary: BinaryAttribute => - if (binary.values.isDefined) { - binary.values - } else { - Some(Array.tabulate(2)(_.toString)) - } - case _: NumericAttribute => - throw new RuntimeException( - s"The input column ${inputCol.name} cannot be continuous-value.") - case _ => - None // optimistic about unknown attributes - } - } - - /** Creates an `AttributeGroup` filled by the `BinaryAttribute` named as required. */ - private def genOutputAttrGroup( - outputAttrNames: Option[Array[String]], - outputColName: String): AttributeGroup = { - outputAttrNames.map { attrNames => - val attrs: Array[Attribute] = attrNames.map { name => - BinaryAttribute.defaultAttr.withName(name) - } - new AttributeGroup(outputColName, attrs) - }.getOrElse{ - new AttributeGroup(outputColName) - } - } - - /** - * Prepares the `StructField` with proper metadata for `OneHotEncoder`'s output column. - */ - def transformOutputColumnSchema( - inputCol: StructField, - outputColName: String, - dropLast: Boolean, - keepInvalid: Boolean = false): StructField = { - val outputAttrNames = genOutputAttrNames(inputCol) - val filteredOutputAttrNames = outputAttrNames.map { names => - if (dropLast && !keepInvalid) { - require(names.length > 1, - s"The input column ${inputCol.name} should have at least two distinct values.") - names.dropRight(1) - } else if (!dropLast && keepInvalid) { - names ++ Seq("invalidValues") - } else { - names - } - } - - genOutputAttrGroup(filteredOutputAttrNames, outputColName).toStructField() - } - - /** - * This method is called when we want to generate `AttributeGroup` from actual data for - * one-hot encoder. - */ - def getOutputAttrGroupFromData( - dataset: Dataset[_], - inputColNames: Seq[String], - outputColNames: Seq[String], - dropLast: Boolean): Seq[AttributeGroup] = { - // The RDD approach has advantage of early-stop if any values are invalid. It seems that - // DataFrame ops don't have equivalent functions. - val columns = inputColNames.map { inputColName => - col(inputColName).cast(DoubleType) - } - val numOfColumns = columns.length - - val numAttrsArray = dataset.select(columns: _*).rdd.map { row => - (0 until numOfColumns).map(idx => row.getDouble(idx)).toArray - }.treeAggregate(new Array[Double](numOfColumns))( - (maxValues, curValues) => { - (0 until numOfColumns).foreach { idx => - val x = curValues(idx) - assert(x <= Int.MaxValue, - s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x.") - assert(x >= 0.0 && x == x.toInt, - s"Values from column ${inputColNames(idx)} must be indices, but got $x.") - maxValues(idx) = math.max(maxValues(idx), x) - } - maxValues - }, - (m0, m1) => { - (0 until numOfColumns).foreach { idx => - m0(idx) = math.max(m0(idx), m1(idx)) - } - m0 - } - ).map(_.toInt + 1) - - outputColNames.zip(numAttrsArray).map { case (outputColName, numAttrs) => - createAttrGroupForAttrNames(outputColName, numAttrs, dropLast, keepInvalid = false) - } - } - - /** Creates an `AttributeGroup` with the required number of `BinaryAttribute`. */ - def createAttrGroupForAttrNames( - outputColName: String, - numAttrs: Int, - dropLast: Boolean, - keepInvalid: Boolean): AttributeGroup = { - val outputAttrNames = Array.tabulate(numAttrs)(_.toString) - val filtered = if (dropLast && !keepInvalid) { - outputAttrNames.dropRight(1) - } else if (!dropLast && keepInvalid) { - outputAttrNames ++ Seq("invalidValues") - } else { - outputAttrNames - } - genOutputAttrGroup(Some(filtered), outputColName) - } -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 346e1823f00b8..d7eb13772aa64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -246,7 +246,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) // Formula w/o intercept, one of the categories in the first category feature is // being used as reference category, we will not drop any category for that feature. if (!hasIntercept && !keepReferenceCategory) { - encoderStages += new OneHotEncoderEstimator(uid) + encoderStages += new OneHotEncoder(uid) .setInputCols(Array(indexed(term))) .setOutputCols(Array(encodedCol)) .setDropLast(false) @@ -269,7 +269,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) if (oneHotEncodeColumns.nonEmpty) { val (inputCols, outputCols) = oneHotEncodeColumns.toArray.unzip - encoderStages += new OneHotEncoderEstimator(uid) + encoderStages += new OneHotEncoder(uid) .setInputCols(inputCols) .setOutputCols(outputCols) .setDropLast(true) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala deleted file mode 100644 index d549e13262273..0000000000000 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala +++ /dev/null @@ -1,422 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.feature - -import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} -import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} -import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} -import org.apache.spark.sql.{Encoder, Row} -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.functions.col -import org.apache.spark.sql.types._ - -class OneHotEncoderEstimatorSuite extends MLTest with DefaultReadWriteTest { - - import testImplicits._ - - test("params") { - ParamsSuite.checkParams(new OneHotEncoderEstimator) - } - - test("OneHotEncoderEstimator dropLast = false") { - val data = Seq( - Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), - Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))), - Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))), - Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), - Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), - Row(2.0, Vectors.sparse(3, Seq((2, 1.0))))) - - val schema = StructType(Array( - StructField("input", DoubleType), - StructField("expected", new VectorUDT))) - - val df = spark.createDataFrame(sc.parallelize(data), schema) - - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array("input")) - .setOutputCols(Array("output")) - assert(encoder.getDropLast === true) - encoder.setDropLast(false) - assert(encoder.getDropLast === false) - val model = encoder.fit(df) - testTransformer[(Double, Vector)](df, model, "output", "expected") { - case Row(output: Vector, expected: Vector) => - assert(output === expected) - } - } - - test("OneHotEncoderEstimator dropLast = true") { - val data = Seq( - Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))), - Row(1.0, Vectors.sparse(2, Seq((1, 1.0)))), - Row(2.0, Vectors.sparse(2, Seq())), - Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))), - Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))), - Row(2.0, Vectors.sparse(2, Seq()))) - - val schema = StructType(Array( - StructField("input", DoubleType), - StructField("expected", new VectorUDT))) - - val df = spark.createDataFrame(sc.parallelize(data), schema) - - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array("input")) - .setOutputCols(Array("output")) - - val model = encoder.fit(df) - testTransformer[(Double, Vector)](df, model, "output", "expected") { - case Row(output: Vector, expected: Vector) => - assert(output === expected) - } - } - - test("input column with ML attribute") { - val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") - val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size") - .select(col("size").as("size", attr.toMetadata())) - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array("size")) - .setOutputCols(Array("encoded")) - val model = encoder.fit(df) - testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows => - val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) - assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) - } - } - - test("input column without ML attribute") { - val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("index") - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array("index")) - .setOutputCols(Array("encoded")) - val model = encoder.fit(df) - testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows => - val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) - assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) - } - } - - test("read/write") { - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array("index")) - .setOutputCols(Array("encoded")) - testDefaultReadWrite(encoder) - } - - test("OneHotEncoderModel read/write") { - val instance = new OneHotEncoderModel("myOneHotEncoderModel", Array(1, 2, 3)) - val newInstance = testDefaultReadWrite(instance) - assert(newInstance.categorySizes === instance.categorySizes) - } - - test("OneHotEncoderEstimator with varying types") { - val data = Seq( - Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), - Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))), - Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))), - Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), - Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), - Row(2.0, Vectors.sparse(3, Seq((2, 1.0))))) - - val schema = StructType(Array( - StructField("input", DoubleType), - StructField("expected", new VectorUDT))) - - val df = spark.createDataFrame(sc.parallelize(data), schema) - - class NumericTypeWithEncoder[A](val numericType: NumericType) - (implicit val encoder: Encoder[(A, Vector)]) - - val types = Seq( - new NumericTypeWithEncoder[Short](ShortType), - new NumericTypeWithEncoder[Long](LongType), - new NumericTypeWithEncoder[Int](IntegerType), - new NumericTypeWithEncoder[Float](FloatType), - new NumericTypeWithEncoder[Byte](ByteType), - new NumericTypeWithEncoder[Double](DoubleType), - new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder())) - - for (t <- types) { - val dfWithTypes = df.select(col("input").cast(t.numericType), col("expected")) - val estimator = new OneHotEncoderEstimator() - .setInputCols(Array("input")) - .setOutputCols(Array("output")) - .setDropLast(false) - - val model = estimator.fit(dfWithTypes) - testTransformer(dfWithTypes, model, "output", "expected") { - case Row(output: Vector, expected: Vector) => - assert(output === expected) - }(t.encoder) - } - } - - test("OneHotEncoderEstimator: encoding multiple columns and dropLast = false") { - val data = Seq( - Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0)))), - Row(1.0, Vectors.sparse(3, Seq((1, 1.0))), 3.0, Vectors.sparse(4, Seq((3, 1.0)))), - Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))), - Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 1.0, Vectors.sparse(4, Seq((1, 1.0)))), - Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))), - Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0))))) - - val schema = StructType(Array( - StructField("input1", DoubleType), - StructField("expected1", new VectorUDT), - StructField("input2", DoubleType), - StructField("expected2", new VectorUDT))) - - val df = spark.createDataFrame(sc.parallelize(data), schema) - - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array("input1", "input2")) - .setOutputCols(Array("output1", "output2")) - assert(encoder.getDropLast === true) - encoder.setDropLast(false) - assert(encoder.getDropLast === false) - - val model = encoder.fit(df) - testTransformer[(Double, Vector, Double, Vector)]( - df, - model, - "output1", - "output2", - "expected1", - "expected2") { - case Row(output1: Vector, output2: Vector, expected1: Vector, expected2: Vector) => - assert(output1 === expected1) - assert(output2 === expected2) - } - } - - test("OneHotEncoderEstimator: encoding multiple columns and dropLast = true") { - val data = Seq( - Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 2.0, Vectors.sparse(3, Seq((2, 1.0)))), - Row(1.0, Vectors.sparse(2, Seq((1, 1.0))), 3.0, Vectors.sparse(3, Seq())), - Row(2.0, Vectors.sparse(2, Seq()), 0.0, Vectors.sparse(3, Seq((0, 1.0)))), - Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 1.0, Vectors.sparse(3, Seq((1, 1.0)))), - Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 0.0, Vectors.sparse(3, Seq((0, 1.0)))), - Row(2.0, Vectors.sparse(2, Seq()), 2.0, Vectors.sparse(3, Seq((2, 1.0))))) - - val schema = StructType(Array( - StructField("input1", DoubleType), - StructField("expected1", new VectorUDT), - StructField("input2", DoubleType), - StructField("expected2", new VectorUDT))) - - val df = spark.createDataFrame(sc.parallelize(data), schema) - - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array("input1", "input2")) - .setOutputCols(Array("output1", "output2")) - - val model = encoder.fit(df) - testTransformer[(Double, Vector, Double, Vector)]( - df, - model, - "output1", - "output2", - "expected1", - "expected2") { - case Row(output1: Vector, output2: Vector, expected1: Vector, expected2: Vector) => - assert(output1 === expected1) - assert(output2 === expected2) - } - } - - test("Throw error on invalid values") { - val trainingData = Seq((0, 0), (1, 1), (2, 2)) - val trainingDF = trainingData.toDF("id", "a") - val testData = Seq((0, 0), (1, 2), (1, 3)) - val testDF = testData.toDF("id", "a") - - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array("a")) - .setOutputCols(Array("encoded")) - - val model = encoder.fit(trainingDF) - testTransformerByInterceptingException[(Int, Int)]( - testDF, - model, - expectedMessagePart = "Unseen value: 3.0. To handle unseen values", - firstResultCol = "encoded") - - } - - test("Can't transform on negative input") { - val trainingDF = Seq((0, 0), (1, 1), (2, 2)).toDF("a", "b") - val testDF = Seq((0, 0), (-1, 2), (1, 3)).toDF("a", "b") - - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array("a")) - .setOutputCols(Array("encoded")) - - val model = encoder.fit(trainingDF) - testTransformerByInterceptingException[(Int, Int)]( - testDF, - model, - expectedMessagePart = "Negative value: -1.0. Input can't be negative", - firstResultCol = "encoded") - } - - test("Keep on invalid values: dropLast = false") { - val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input") - - val testData = Seq( - Row(0.0, Vectors.sparse(4, Seq((0, 1.0)))), - Row(1.0, Vectors.sparse(4, Seq((1, 1.0)))), - Row(3.0, Vectors.sparse(4, Seq((3, 1.0))))) - - val schema = StructType(Array( - StructField("input", DoubleType), - StructField("expected", new VectorUDT))) - - val testDF = spark.createDataFrame(sc.parallelize(testData), schema) - - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array("input")) - .setOutputCols(Array("output")) - .setHandleInvalid("keep") - .setDropLast(false) - - val model = encoder.fit(trainingDF) - testTransformer[(Double, Vector)](testDF, model, "output", "expected") { - case Row(output: Vector, expected: Vector) => - assert(output === expected) - } - } - - test("Keep on invalid values: dropLast = true") { - val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input") - - val testData = Seq( - Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), - Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))), - Row(3.0, Vectors.sparse(3, Seq()))) - - val schema = StructType(Array( - StructField("input", DoubleType), - StructField("expected", new VectorUDT))) - - val testDF = spark.createDataFrame(sc.parallelize(testData), schema) - - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array("input")) - .setOutputCols(Array("output")) - .setHandleInvalid("keep") - .setDropLast(true) - - val model = encoder.fit(trainingDF) - testTransformer[(Double, Vector)](testDF, model, "output", "expected") { - case Row(output: Vector, expected: Vector) => - assert(output === expected) - } - } - - test("OneHotEncoderModel changes dropLast") { - val data = Seq( - Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))), - Row(1.0, Vectors.sparse(3, Seq((1, 1.0))), Vectors.sparse(2, Seq((1, 1.0)))), - Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), Vectors.sparse(2, Seq())), - Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))), - Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))), - Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), Vectors.sparse(2, Seq()))) - - val schema = StructType(Array( - StructField("input", DoubleType), - StructField("expected1", new VectorUDT), - StructField("expected2", new VectorUDT))) - - val df = spark.createDataFrame(sc.parallelize(data), schema) - - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array("input")) - .setOutputCols(Array("output")) - - val model = encoder.fit(df) - - model.setDropLast(false) - testTransformer[(Double, Vector, Vector)](df, model, "output", "expected1") { - case Row(output: Vector, expected1: Vector) => - assert(output === expected1) - } - - model.setDropLast(true) - testTransformer[(Double, Vector, Vector)](df, model, "output", "expected2") { - case Row(output: Vector, expected2: Vector) => - assert(output === expected2) - } - } - - test("OneHotEncoderModel changes handleInvalid") { - val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input") - - val testData = Seq( - Row(0.0, Vectors.sparse(4, Seq((0, 1.0)))), - Row(1.0, Vectors.sparse(4, Seq((1, 1.0)))), - Row(3.0, Vectors.sparse(4, Seq((3, 1.0))))) - - val schema = StructType(Array( - StructField("input", DoubleType), - StructField("expected", new VectorUDT))) - - val testDF = spark.createDataFrame(sc.parallelize(testData), schema) - - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array("input")) - .setOutputCols(Array("output")) - - val model = encoder.fit(trainingDF) - model.setHandleInvalid("error") - - testTransformerByInterceptingException[(Double, Vector)]( - testDF, - model, - expectedMessagePart = "Unseen value: 3.0. To handle unseen values", - firstResultCol = "output") - - model.setHandleInvalid("keep") - testTransformerByGlobalCheckFunc[(Double, Vector)](testDF, model, "output") { _ => } - } - - test("Transforming on mismatched attributes") { - val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") - val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size") - .select(col("size").as("size", attr.toMetadata())) - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array("size")) - .setOutputCols(Array("encoded")) - val model = encoder.fit(df) - - val testAttr = NominalAttribute.defaultAttr.withValues("tiny", "small", "medium", "large") - val testDF = Seq(0.0, 1.0, 2.0, 3.0).map(Tuple1.apply).toDF("size") - .select(col("size").as("size", testAttr.toMetadata())) - testTransformerByInterceptingException[(Double)]( - testDF, - model, - expectedMessagePart = "OneHotEncoderModel expected 2 categorical values", - firstResultCol = "encoded") - } -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 41b32b2ffa096..d92313f4ce038 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -18,72 +18,71 @@ package org.apache.spark.ml.feature import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} -import org.apache.spark.ml.linalg.Vector -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} -import org.apache.spark.sql.{DataFrame, Encoder, Row} +import org.apache.spark.sql.{Encoder, Row} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ -class OneHotEncoderSuite - extends MLTest with DefaultReadWriteTest { +class OneHotEncoderSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ - def stringIndexed(): DataFrame = { - val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) - val df = data.toDF("id", "label") - val indexer = new StringIndexer() - .setInputCol("label") - .setOutputCol("labelIndex") - .fit(df) - indexer.transform(df) - } - test("params") { ParamsSuite.checkParams(new OneHotEncoder) } test("OneHotEncoder dropLast = false") { - val transformed = stringIndexed() + val data = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + val encoder = new OneHotEncoder() - .setInputCol("labelIndex") - .setOutputCol("labelVec") + .setInputCols(Array("input")) + .setOutputCols(Array("output")) assert(encoder.getDropLast === true) encoder.setDropLast(false) assert(encoder.getDropLast === false) - val expected = Seq( - (0, Vectors.sparse(3, Seq((0, 1.0)))), - (1, Vectors.sparse(3, Seq((2, 1.0)))), - (2, Vectors.sparse(3, Seq((1, 1.0)))), - (3, Vectors.sparse(3, Seq((0, 1.0)))), - (4, Vectors.sparse(3, Seq((0, 1.0)))), - (5, Vectors.sparse(3, Seq((1, 1.0))))).toDF("id", "expected") - - val withExpected = transformed.join(expected, "id") - testTransformer[(Int, String, Double, Vector)](withExpected, encoder, "labelVec", "expected") { + val model = encoder.fit(df) + testTransformer[(Double, Vector)](df, model, "output", "expected") { case Row(output: Vector, expected: Vector) => assert(output === expected) } } test("OneHotEncoder dropLast = true") { - val transformed = stringIndexed() + val data = Seq( + Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + Row(2.0, Vectors.sparse(2, Seq())), + Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(2, Seq()))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + val encoder = new OneHotEncoder() - .setInputCol("labelIndex") - .setOutputCol("labelVec") - val expected = Seq( - (0, Vectors.sparse(2, Seq((0, 1.0)))), - (1, Vectors.sparse(2, Seq())), - (2, Vectors.sparse(2, Seq((1, 1.0)))), - (3, Vectors.sparse(2, Seq((0, 1.0)))), - (4, Vectors.sparse(2, Seq((0, 1.0)))), - (5, Vectors.sparse(2, Seq((1, 1.0))))).toDF("id", "expected") - - val withExpected = transformed.join(expected, "id") - testTransformer[(Int, String, Double, Vector)](withExpected, encoder, "labelVec", "expected") { + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + + val model = encoder.fit(df) + testTransformer[(Double, Vector)](df, model, "output", "expected") { case Row(output: Vector, expected: Vector) => assert(output === expected) } @@ -94,52 +93,61 @@ class OneHotEncoderSuite val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size") .select(col("size").as("size", attr.toMetadata())) val encoder = new OneHotEncoder() - .setInputCol("size") - .setOutputCol("encoded") - testTransformerByGlobalCheckFunc[(Double)](df, encoder, "encoded") { rows => - val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) - assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + .setInputCols(Array("size")) + .setOutputCols(Array("encoded")) + val model = encoder.fit(df) + testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows => + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) } } - test("input column without ML attribute") { val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("index") val encoder = new OneHotEncoder() - .setInputCol("index") - .setOutputCol("encoded") - val rows = encoder.transform(df).select("encoded").collect() - val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) - assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) + .setInputCols(Array("index")) + .setOutputCols(Array("encoded")) + val model = encoder.fit(df) + testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows => + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) + } } test("read/write") { - val t = new OneHotEncoder() - .setInputCol("myInputCol") - .setOutputCol("myOutputCol") - .setDropLast(false) - testDefaultReadWrite(t) + val encoder = new OneHotEncoder() + .setInputCols(Array("index")) + .setOutputCols(Array("encoded")) + testDefaultReadWrite(encoder) + } + + test("OneHotEncoderModel read/write") { + val instance = new OneHotEncoderModel("myOneHotEncoderModel", Array(1, 2, 3)) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.categorySizes === instance.categorySizes) } test("OneHotEncoder with varying types") { - val df = stringIndexed() - val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") - val expected = Seq( - (0, Vectors.sparse(3, Seq((0, 1.0)))), - (1, Vectors.sparse(3, Seq((2, 1.0)))), - (2, Vectors.sparse(3, Seq((1, 1.0)))), - (3, Vectors.sparse(3, Seq((0, 1.0)))), - (4, Vectors.sparse(3, Seq((0, 1.0)))), - (5, Vectors.sparse(3, Seq((1, 1.0))))).toDF("id", "expected") + val data = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))))) - val withExpected = df.join(expected, "id") + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) class NumericTypeWithEncoder[A](val numericType: NumericType) - (implicit val encoder: Encoder[(A, Vector)]) + (implicit val encoder: Encoder[(A, Vector)]) val types = Seq( new NumericTypeWithEncoder[Short](ShortType), @@ -151,17 +159,264 @@ class OneHotEncoderSuite new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder())) for (t <- types) { - val dfWithTypes = withExpected.select(col("labelIndex") - .cast(t.numericType).as("labelIndex", attr.toMetadata()), col("expected")) - val encoder = new OneHotEncoder() - .setInputCol("labelIndex") - .setOutputCol("labelVec") + val dfWithTypes = df.select(col("input").cast(t.numericType), col("expected")) + val estimator = new OneHotEncoder() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) .setDropLast(false) - testTransformer(dfWithTypes, encoder, "labelVec", "expected") { + val model = estimator.fit(dfWithTypes) + testTransformer(dfWithTypes, model, "output", "expected") { case Row(output: Vector, expected: Vector) => assert(output === expected) }(t.encoder) } } + + test("OneHotEncoder: encoding multiple columns and dropLast = false") { + val data = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0))), 3.0, Vectors.sparse(4, Seq((3, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 1.0, Vectors.sparse(4, Seq((1, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0))))) + + val schema = StructType(Array( + StructField("input1", DoubleType), + StructField("expected1", new VectorUDT), + StructField("input2", DoubleType), + StructField("expected2", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val encoder = new OneHotEncoder() + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("output1", "output2")) + assert(encoder.getDropLast === true) + encoder.setDropLast(false) + assert(encoder.getDropLast === false) + + val model = encoder.fit(df) + testTransformer[(Double, Vector, Double, Vector)]( + df, + model, + "output1", + "output2", + "expected1", + "expected2") { + case Row(output1: Vector, output2: Vector, expected1: Vector, expected2: Vector) => + assert(output1 === expected1) + assert(output2 === expected2) + } + } + + test("OneHotEncoder: encoding multiple columns and dropLast = true") { + val data = Seq( + Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 2.0, Vectors.sparse(3, Seq((2, 1.0)))), + Row(1.0, Vectors.sparse(2, Seq((1, 1.0))), 3.0, Vectors.sparse(3, Seq())), + Row(2.0, Vectors.sparse(2, Seq()), 0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 1.0, Vectors.sparse(3, Seq((1, 1.0)))), + Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(2, Seq()), 2.0, Vectors.sparse(3, Seq((2, 1.0))))) + + val schema = StructType(Array( + StructField("input1", DoubleType), + StructField("expected1", new VectorUDT), + StructField("input2", DoubleType), + StructField("expected2", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val encoder = new OneHotEncoder() + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("output1", "output2")) + + val model = encoder.fit(df) + testTransformer[(Double, Vector, Double, Vector)]( + df, + model, + "output1", + "output2", + "expected1", + "expected2") { + case Row(output1: Vector, output2: Vector, expected1: Vector, expected2: Vector) => + assert(output1 === expected1) + assert(output2 === expected2) + } + } + + test("Throw error on invalid values") { + val trainingData = Seq((0, 0), (1, 1), (2, 2)) + val trainingDF = trainingData.toDF("id", "a") + val testData = Seq((0, 0), (1, 2), (1, 3)) + val testDF = testData.toDF("id", "a") + + val encoder = new OneHotEncoder() + .setInputCols(Array("a")) + .setOutputCols(Array("encoded")) + + val model = encoder.fit(trainingDF) + testTransformerByInterceptingException[(Int, Int)]( + testDF, + model, + expectedMessagePart = "Unseen value: 3.0. To handle unseen values", + firstResultCol = "encoded") + + } + + test("Can't transform on negative input") { + val trainingDF = Seq((0, 0), (1, 1), (2, 2)).toDF("a", "b") + val testDF = Seq((0, 0), (-1, 2), (1, 3)).toDF("a", "b") + + val encoder = new OneHotEncoder() + .setInputCols(Array("a")) + .setOutputCols(Array("encoded")) + + val model = encoder.fit(trainingDF) + testTransformerByInterceptingException[(Int, Int)]( + testDF, + model, + expectedMessagePart = "Negative value: -1.0. Input can't be negative", + firstResultCol = "encoded") + } + + test("Keep on invalid values: dropLast = false") { + val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input") + + val testData = Seq( + Row(0.0, Vectors.sparse(4, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(4, Seq((1, 1.0)))), + Row(3.0, Vectors.sparse(4, Seq((3, 1.0))))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val testDF = spark.createDataFrame(sc.parallelize(testData), schema) + + val encoder = new OneHotEncoder() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + .setHandleInvalid("keep") + .setDropLast(false) + + val model = encoder.fit(trainingDF) + testTransformer[(Double, Vector)](testDF, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + } + } + + test("Keep on invalid values: dropLast = true") { + val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input") + + val testData = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))), + Row(3.0, Vectors.sparse(3, Seq()))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val testDF = spark.createDataFrame(sc.parallelize(testData), schema) + + val encoder = new OneHotEncoder() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + .setHandleInvalid("keep") + .setDropLast(true) + + val model = encoder.fit(trainingDF) + testTransformer[(Double, Vector)](testDF, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + } + } + + test("OneHotEncoderModel changes dropLast") { + val data = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0))), Vectors.sparse(2, Seq((1, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), Vectors.sparse(2, Seq())), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), Vectors.sparse(2, Seq()))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected1", new VectorUDT), + StructField("expected2", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val encoder = new OneHotEncoder() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + + val model = encoder.fit(df) + + model.setDropLast(false) + testTransformer[(Double, Vector, Vector)](df, model, "output", "expected1") { + case Row(output: Vector, expected1: Vector) => + assert(output === expected1) + } + + model.setDropLast(true) + testTransformer[(Double, Vector, Vector)](df, model, "output", "expected2") { + case Row(output: Vector, expected2: Vector) => + assert(output === expected2) + } + } + + test("OneHotEncoderModel changes handleInvalid") { + val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input") + + val testData = Seq( + Row(0.0, Vectors.sparse(4, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(4, Seq((1, 1.0)))), + Row(3.0, Vectors.sparse(4, Seq((3, 1.0))))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val testDF = spark.createDataFrame(sc.parallelize(testData), schema) + + val encoder = new OneHotEncoder() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + + val model = encoder.fit(trainingDF) + model.setHandleInvalid("error") + + testTransformerByInterceptingException[(Double, Vector)]( + testDF, + model, + expectedMessagePart = "Unseen value: 3.0. To handle unseen values", + firstResultCol = "output") + + model.setHandleInvalid("keep") + testTransformerByGlobalCheckFunc[(Double, Vector)](testDF, model, "output") { _ => } + } + + test("Transforming on mismatched attributes") { + val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") + val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size") + .select(col("size").as("size", attr.toMetadata())) + val encoder = new OneHotEncoder() + .setInputCols(Array("size")) + .setOutputCols(Array("encoded")) + val model = encoder.fit(df) + + val testAttr = NominalAttribute.defaultAttr.withValues("tiny", "small", "medium", "large") + val testDF = Seq(0.0, 1.0, 2.0, 3.0).map(Tuple1.apply).toDF("size") + .select(col("size").as("size", testAttr.toMetadata())) + testTransformerByInterceptingException[(Double)]( + testDF, + model, + expectedMessagePart = "OneHotEncoderModel expected 2 categorical values", + firstResultCol = "encoded") + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 3fabec0f60125..5e97d826370f7 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -228,6 +228,18 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.writer.DataWriterFactory.createWriter"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter"), + // [SPARK-26133][ML] Remove deprecated OneHotEncoder and rename OneHotEncoderEstimator to OneHotEncoder + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.feature.OneHotEncoderEstimator"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.OneHotEncoder"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.OneHotEncoder.transform"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.OneHotEncoder.getInputCol"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.OneHotEncoder.getOutputCol"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.OneHotEncoder.inputCol"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.OneHotEncoder.setInputCol"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.OneHotEncoder.setOutputCol"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.OneHotEncoder.outputCol"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.feature.OneHotEncoderEstimator$"), + // [SPARK-26141] Enable custom metrics implementation in shuffle write // Following are Java private classes ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.UnsafeShuffleWriter.this"), diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 3d23700242594..6cc80e181e5e0 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -44,8 +44,7 @@ 'MinMaxScaler', 'MinMaxScalerModel', 'NGram', 'Normalizer', - 'OneHotEncoder', - 'OneHotEncoderEstimator', 'OneHotEncoderModel', + 'OneHotEncoder', 'OneHotEncoderModel', 'PCA', 'PCAModel', 'PolynomialExpansion', 'QuantileDiscretizer', @@ -1642,91 +1641,8 @@ def getP(self): @inherit_doc -class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): - """ - A one-hot encoder that maps a column of category indices to a - column of binary vectors, with at most a single one-value per row - that indicates the input category index. - For example with 5 categories, an input value of 2.0 would map to - an output vector of `[0.0, 0.0, 1.0, 0.0]`. - The last category is not included by default (configurable via - :py:attr:`dropLast`) because it makes the vector entries sum up to - one, and hence linearly dependent. - So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. - - .. note:: This is different from scikit-learn's OneHotEncoder, - which keeps all categories. The output vectors are sparse. - - .. note:: Deprecated in 2.3.0. :py:class:`OneHotEncoderEstimator` will be renamed to - :py:class:`OneHotEncoder` and this :py:class:`OneHotEncoder` will be removed in 3.0.0. - - .. seealso:: - - :py:class:`StringIndexer` for converting categorical values into - category indices - - >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") - >>> model = stringIndexer.fit(stringIndDf) - >>> td = model.transform(stringIndDf) - >>> encoder = OneHotEncoder(inputCol="indexed", outputCol="features") - >>> encoder.transform(td).head().features - SparseVector(2, {0: 1.0}) - >>> encoder.setParams(outputCol="freqs").transform(td).head().freqs - SparseVector(2, {0: 1.0}) - >>> params = {encoder.dropLast: False, encoder.outputCol: "test"} - >>> encoder.transform(td, params).head().test - SparseVector(3, {0: 1.0}) - >>> onehotEncoderPath = temp_path + "/onehot-encoder" - >>> encoder.save(onehotEncoderPath) - >>> loadedEncoder = OneHotEncoder.load(onehotEncoderPath) - >>> loadedEncoder.getDropLast() == encoder.getDropLast() - True - - .. versionadded:: 1.4.0 - """ - - dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category", - typeConverter=TypeConverters.toBoolean) - - @keyword_only - def __init__(self, dropLast=True, inputCol=None, outputCol=None): - """ - __init__(self, dropLast=True, inputCol=None, outputCol=None) - """ - super(OneHotEncoder, self).__init__() - self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid) - self._setDefault(dropLast=True) - kwargs = self._input_kwargs - self.setParams(**kwargs) - - @keyword_only - @since("1.4.0") - def setParams(self, dropLast=True, inputCol=None, outputCol=None): - """ - setParams(self, dropLast=True, inputCol=None, outputCol=None) - Sets params for this OneHotEncoder. - """ - kwargs = self._input_kwargs - return self._set(**kwargs) - - @since("1.4.0") - def setDropLast(self, value): - """ - Sets the value of :py:attr:`dropLast`. - """ - return self._set(dropLast=value) - - @since("1.4.0") - def getDropLast(self): - """ - Gets the value of dropLast or its default value. - """ - return self.getOrDefault(self.dropLast) - - -@inherit_doc -class OneHotEncoderEstimator(JavaEstimator, HasInputCols, HasOutputCols, HasHandleInvalid, - JavaMLReadable, JavaMLWritable): +class OneHotEncoder(JavaEstimator, HasInputCols, HasOutputCols, HasHandleInvalid, + JavaMLReadable, JavaMLWritable): """ A one-hot encoder that maps a column of category indices to a column of binary vectors, with at most a single one-value per row that indicates the input category index. @@ -1751,13 +1667,13 @@ class OneHotEncoderEstimator(JavaEstimator, HasInputCols, HasOutputCols, HasHand >>> from pyspark.ml.linalg import Vectors >>> df = spark.createDataFrame([(0.0,), (1.0,), (2.0,)], ["input"]) - >>> ohe = OneHotEncoderEstimator(inputCols=["input"], outputCols=["output"]) + >>> ohe = OneHotEncoder(inputCols=["input"], outputCols=["output"]) >>> model = ohe.fit(df) >>> model.transform(df).head().output SparseVector(2, {0: 1.0}) >>> ohePath = temp_path + "/oheEstimator" >>> ohe.save(ohePath) - >>> loadedOHE = OneHotEncoderEstimator.load(ohePath) + >>> loadedOHE = OneHotEncoder.load(ohePath) >>> loadedOHE.getInputCols() == ohe.getInputCols() True >>> modelPath = temp_path + "/ohe-model" @@ -1784,9 +1700,9 @@ def __init__(self, inputCols=None, outputCols=None, handleInvalid="error", dropL """ __init__(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True) """ - super(OneHotEncoderEstimator, self).__init__() + super(OneHotEncoder, self).__init__() self._java_obj = self._new_java_obj( - "org.apache.spark.ml.feature.OneHotEncoderEstimator", self.uid) + "org.apache.spark.ml.feature.OneHotEncoder", self.uid) self._setDefault(handleInvalid="error", dropLast=True) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -1796,7 +1712,7 @@ def __init__(self, inputCols=None, outputCols=None, handleInvalid="error", dropL def setParams(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True): """ setParams(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True) - Sets params for this OneHotEncoderEstimator. + Sets params for this OneHotEncoder. """ kwargs = self._input_kwargs return self._set(**kwargs) @@ -1821,7 +1737,7 @@ def _create_model(self, java_model): class OneHotEncoderModel(JavaModel, JavaMLReadable, JavaMLWritable): """ - Model fitted by :py:class:`OneHotEncoderEstimator`. + Model fitted by :py:class:`OneHotEncoder`. .. versionadded:: 2.3.0 """ From 7a83d71403edf7d24fa5efc0ef913f3ce76d88b8 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 29 Nov 2018 22:15:12 +0800 Subject: [PATCH 2160/2461] [SPARK-26163][SQL] Parsing decimals from JSON using locale ## What changes were proposed in this pull request? In the PR, I propose using of the locale option to parse (and infer) decimals from JSON input. After the changes, `JacksonParser` converts input string to `BigDecimal` and to Spark's Decimal by using `java.text.DecimalFormat`. New behaviour can be switched off via SQL config `spark.sql.legacy.decimalParsing.enabled`. ## How was this patch tested? Added 2 tests to `JsonExpressionsSuite` for the `en-US`, `ko-KR`, `ru-RU`, `de-DE` locales: - Inferring decimal type using locale from JSON field values - Converting JSON field values to specified decimal type using the locales. Closes #23132 from MaxGekk/json-decimal-parsing-locale. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: Wenchen Fan --- .../sql/catalyst/expressions/ExprUtils.scala | 21 +++++ .../expressions/jsonExpressions.scala | 7 +- .../sql/catalyst/json/JacksonParser.scala | 6 ++ .../sql/catalyst/json/JsonInferSchema.scala | 89 +++++++++++-------- .../expressions/JsonExpressionsSuite.scala | 42 ++++++++- .../datasources/json/JsonDataSource.scala | 4 +- .../datasources/json/JsonSuite.scala | 15 ++-- 7 files changed, 132 insertions(+), 52 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index 89e9071324eff..3f3d6b2b63a06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.catalyst.expressions +import java.text.{DecimalFormat, DecimalFormatSymbols, ParsePosition} +import java.util.Locale + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.ArrayBasedMapData import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType} @@ -83,4 +86,22 @@ object ExprUtils { } } } + + def getDecimalParser(locale: Locale): String => java.math.BigDecimal = { + if (locale == Locale.US) { // Special handling the default locale for backward compatibility + (s: String) => new java.math.BigDecimal(s.replaceAll(",", "")) + } else { + val decimalFormat = new DecimalFormat("", new DecimalFormatSymbols(locale)) + decimalFormat.setParseBigDecimal(true) + (s: String) => { + val pos = new ParsePosition(0) + val result = decimalFormat.parse(s, pos).asInstanceOf[java.math.BigDecimal] + if (pos.getIndex() != s.length() || pos.getErrorIndex() != -1) { + throw new IllegalArgumentException("Cannot parse any decimal"); + } else { + result + } + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 47304d835fdf8..e0cab537ce1c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -23,12 +23,10 @@ import scala.util.parsing.combinator.RegexParsers import com.fasterxml.jackson.core._ -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.json.JsonInferSchema.inferField import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -775,6 +773,9 @@ case class SchemaOfJson( factory } + @transient + private lazy val jsonInferSchema = new JsonInferSchema(jsonOptions) + @transient private lazy val json = child.eval().asInstanceOf[UTF8String] @@ -787,7 +788,7 @@ case class SchemaOfJson( override def eval(v: InternalRow): Any = { val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser => parser.nextToken() - inferField(parser, jsonOptions) + jsonInferSchema.inferField(parser) } UTF8String.fromString(dt.catalogString) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 92517aac053b2..2357595906b11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -135,6 +136,8 @@ class JacksonParser( } } + private val decimalParser = ExprUtils.getDecimalParser(options.locale) + /** * Create a converter which converts the JSON documents held by the `JsonParser` * to a value according to a desired schema. @@ -261,6 +264,9 @@ class JacksonParser( (parser: JsonParser) => parseJsonToken[Decimal](parser, dataType) { case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) => Decimal(parser.getDecimalValue, dt.precision, dt.scale) + case VALUE_STRING if parser.getTextLength >= 1 => + val bigDecimal = decimalParser(parser.getText) + Decimal(bigDecimal, dt.precision, dt.scale) } case st: StructType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index 9999a005106f9..263e05de32075 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -19,18 +19,23 @@ package org.apache.spark.sql.catalyst.json import java.util.Comparator +import scala.util.control.Exception.allCatch + import com.fasterxml.jackson.core._ import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion +import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -private[sql] object JsonInferSchema { +private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable { + + private val decimalParser = ExprUtils.getDecimalParser(options.locale) /** * Infer the type of a collection of json records in three stages: @@ -40,21 +45,20 @@ private[sql] object JsonInferSchema { */ def infer[T]( json: RDD[T], - configOptions: JSONOptions, createParser: (JsonFactory, T) => JsonParser): StructType = { - val parseMode = configOptions.parseMode - val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord + val parseMode = options.parseMode + val columnNameOfCorruptRecord = options.columnNameOfCorruptRecord // In each RDD partition, perform schema inference on each row and merge afterwards. - val typeMerger = compatibleRootType(columnNameOfCorruptRecord, parseMode) + val typeMerger = JsonInferSchema.compatibleRootType(columnNameOfCorruptRecord, parseMode) val mergedTypesFromPartitions = json.mapPartitions { iter => val factory = new JsonFactory() - configOptions.setJacksonOptions(factory) + options.setJacksonOptions(factory) iter.flatMap { row => try { Utils.tryWithResource(createParser(factory, row)) { parser => parser.nextToken() - Some(inferField(parser, configOptions)) + Some(inferField(parser)) } } catch { case e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match { @@ -82,7 +86,7 @@ private[sql] object JsonInferSchema { } json.sparkContext.runJob(mergedTypesFromPartitions, foldPartition, mergeResult) - canonicalizeType(rootType, configOptions) match { + canonicalizeType(rootType, options) match { case Some(st: StructType) => st case _ => // canonicalizeType erases all empty structs, including the only one we want to keep @@ -90,34 +94,17 @@ private[sql] object JsonInferSchema { } } - private[this] val structFieldComparator = new Comparator[StructField] { - override def compare(o1: StructField, o2: StructField): Int = { - o1.name.compareTo(o2.name) - } - } - - private def isSorted(arr: Array[StructField]): Boolean = { - var i: Int = 0 - while (i < arr.length - 1) { - if (structFieldComparator.compare(arr(i), arr(i + 1)) > 0) { - return false - } - i += 1 - } - true - } - /** * Infer the type of a json document from the parser's token stream */ - def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { + def inferField(parser: JsonParser): DataType = { import com.fasterxml.jackson.core.JsonToken._ parser.getCurrentToken match { case null | VALUE_NULL => NullType case FIELD_NAME => parser.nextToken() - inferField(parser, configOptions) + inferField(parser) case VALUE_STRING if parser.getTextLength < 1 => // Zero length strings and nulls have special handling to deal @@ -128,18 +115,25 @@ private[sql] object JsonInferSchema { // record fields' types have been combined. NullType + case VALUE_STRING if options.prefersDecimal => + val decimalTry = allCatch opt { + val bigDecimal = decimalParser(parser.getText) + DecimalType(bigDecimal.precision, bigDecimal.scale) + } + decimalTry.getOrElse(StringType) case VALUE_STRING => StringType + case START_OBJECT => val builder = Array.newBuilder[StructField] while (nextUntil(parser, END_OBJECT)) { builder += StructField( parser.getCurrentName, - inferField(parser, configOptions), + inferField(parser), nullable = true) } val fields: Array[StructField] = builder.result() // Note: other code relies on this sorting for correctness, so don't remove it! - java.util.Arrays.sort(fields, structFieldComparator) + java.util.Arrays.sort(fields, JsonInferSchema.structFieldComparator) StructType(fields) case START_ARRAY => @@ -148,15 +142,15 @@ private[sql] object JsonInferSchema { // the type as we pass through all JSON objects. var elementType: DataType = NullType while (nextUntil(parser, END_ARRAY)) { - elementType = compatibleType( - elementType, inferField(parser, configOptions)) + elementType = JsonInferSchema.compatibleType( + elementType, inferField(parser)) } ArrayType(elementType) - case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if configOptions.primitivesAsString => StringType + case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if options.primitivesAsString => StringType - case (VALUE_TRUE | VALUE_FALSE) if configOptions.primitivesAsString => StringType + case (VALUE_TRUE | VALUE_FALSE) if options.primitivesAsString => StringType case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT => import JsonParser.NumberType._ @@ -172,7 +166,7 @@ private[sql] object JsonInferSchema { } else { DoubleType } - case FLOAT | DOUBLE if configOptions.prefersDecimal => + case FLOAT | DOUBLE if options.prefersDecimal => val v = parser.getDecimalValue if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) { DecimalType(Math.max(v.precision(), v.scale()), v.scale()) @@ -217,12 +211,31 @@ private[sql] object JsonInferSchema { case other => Some(other) } +} + +object JsonInferSchema { + val structFieldComparator = new Comparator[StructField] { + override def compare(o1: StructField, o2: StructField): Int = { + o1.name.compareTo(o2.name) + } + } + + def isSorted(arr: Array[StructField]): Boolean = { + var i: Int = 0 + while (i < arr.length - 1) { + if (structFieldComparator.compare(arr(i), arr(i + 1)) > 0) { + return false + } + i += 1 + } + true + } - private def withCorruptField( + def withCorruptField( struct: StructType, other: DataType, columnNameOfCorruptRecords: String, - parseMode: ParseMode) = parseMode match { + parseMode: ParseMode): StructType = parseMode match { case PermissiveMode => // If we see any other data type at the root level, we get records that cannot be // parsed. So, we use the struct as the data type and add the corrupt field to the schema. @@ -230,7 +243,7 @@ private[sql] object JsonInferSchema { // If this given struct does not have a column used for corrupt records, // add this field. val newFields: Array[StructField] = - StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields + StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields // Note: other code relies on this sorting for correctness, so don't remove it! java.util.Arrays.sort(newFields, structFieldComparator) StructType(newFields) @@ -253,7 +266,7 @@ private[sql] object JsonInferSchema { /** * Remove top-level ArrayType wrappers and merge the remaining schemas */ - private def compatibleRootType( + def compatibleRootType( columnNameOfCorruptRecords: String, parseMode: ParseMode): (DataType, DataType) => DataType = { // Since we support array of json objects at the top level, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 9b89a27c23770..5d60cefc13896 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.text.SimpleDateFormat +import java.text.{DecimalFormat, DecimalFormatSymbols, SimpleDateFormat} import java.util.{Calendar, Locale} import org.scalatest.exceptions.TestFailedException @@ -765,4 +765,44 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with timeZoneId = gmtId), expectedErrMsg = "The field for corrupt records must be string type and nullable") } + + def decimalInput(langTag: String): (Decimal, String) = { + val decimalVal = new java.math.BigDecimal("1000.001") + val decimalType = new DecimalType(10, 5) + val expected = Decimal(decimalVal, decimalType.precision, decimalType.scale) + val decimalFormat = new DecimalFormat("", + new DecimalFormatSymbols(Locale.forLanguageTag(langTag))) + val input = s"""{"d": "${decimalFormat.format(expected.toBigDecimal)}"}""" + + (expected, input) + } + + test("parse decimals using locale") { + def checkDecimalParsing(langTag: String): Unit = { + val schema = new StructType().add("d", DecimalType(10, 5)) + val options = Map("locale" -> langTag) + val (expected, input) = decimalInput(langTag) + + checkEvaluation( + JsonToStructs(schema, options, Literal.create(input), gmtId), + InternalRow(expected)) + } + + Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(checkDecimalParsing) + } + + test("inferring the decimal type using locale") { + def checkDecimalInfer(langTag: String, expectedType: String): Unit = { + val options = Map("locale" -> langTag, "prefersDecimal" -> "true") + val (_, input) = decimalInput(langTag) + + checkEvaluation( + SchemaOfJson(Literal.create(input), options), + expectedType) + } + + Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach { + checkDecimalInfer(_, """struct""") + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index c7608e2e881ff..456f08a2a2ee7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -107,7 +107,7 @@ object TextInputJsonDataSource extends JsonDataSource { }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow)) SQLExecution.withSQLConfPropagated(json.sparkSession) { - JsonInferSchema.infer(rdd, parsedOptions, rowParser) + new JsonInferSchema(parsedOptions).infer(rdd, rowParser) } } @@ -166,7 +166,7 @@ object MultiLineJsonDataSource extends JsonDataSource { .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) SQLExecution.withSQLConfPropagated(sparkSession) { - JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) + new JsonInferSchema(parsedOptions).infer[PortableDataStream](sampled, parser) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 9ea9189cdf7f4..ee31077e12ef3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -31,8 +31,7 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.{SparkException, TestUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{functions => F, _} -import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions} -import org.apache.spark.sql.catalyst.json.JsonInferSchema.compatibleType +import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.ExternalRDD import org.apache.spark.sql.execution.datasources.DataSource @@ -118,10 +117,10 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Get compatible type") { def checkDataType(t1: DataType, t2: DataType, expected: DataType) { - var actual = compatibleType(t1, t2) + var actual = JsonInferSchema.compatibleType(t1, t2) assert(actual == expected, s"Expected $expected as the most general data type for $t1 and $t2, found $actual") - actual = compatibleType(t2, t1) + actual = JsonInferSchema.compatibleType(t2, t1) assert(actual == expected, s"Expected $expected as the most general data type for $t1 and $t2, found $actual") } @@ -1373,9 +1372,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-6245 JsonInferSchema.infer on empty RDD") { // This is really a test that it doesn't throw an exception - val emptySchema = JsonInferSchema.infer( + val options = new JSONOptions(Map.empty[String, String], "GMT") + val emptySchema = new JsonInferSchema(options).infer( empty.rdd, - new JSONOptions(Map.empty[String, String], "GMT"), CreateJacksonParser.string) assert(StructType(Seq()) === emptySchema) } @@ -1400,9 +1399,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-8093 Erase empty structs") { - val emptySchema = JsonInferSchema.infer( + val options = new JSONOptions(Map.empty[String, String], "GMT") + val emptySchema = new JsonInferSchema(options).infer( emptyRecords.rdd, - new JSONOptions(Map.empty[String, String], "GMT"), CreateJacksonParser.string) assert(StructType(Seq()) === emptySchema) } From b9b68a6dc7d0f735163e980392ea957f2d589923 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 29 Nov 2018 22:37:02 +0800 Subject: [PATCH 2161/2461] [SPARK-26211][SQL] Fix InSet for binary, and struct and array with null. ## What changes were proposed in this pull request? Currently `InSet` doesn't work properly for binary type, or struct and array type with null value in the set. Because, as for binary type, the `HashSet` doesn't work properly for `Array[Byte]`, and as for struct and array type with null value in the set, the `ordering` will throw a `NPE`. ## How was this patch tested? Added a few tests. Closes #23176 from ueshin/issues/SPARK-26211/inset. Authored-by: Takuya UESHIN Signed-off-by: Wenchen Fan --- .../sql/catalyst/expressions/predicates.scala | 33 ++++++------ .../catalyst/expressions/PredicateSuite.scala | 50 ++++++++++++++++++- 2 files changed, 63 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 16e0bc3aaf35b..01ecb99025eaa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -367,31 +367,26 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } @transient lazy val set: Set[Any] = child.dataType match { - case _: AtomicType => hset + case t: AtomicType if !t.isInstanceOf[BinaryType] => hset case _: NullType => hset case _ => // for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows - TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ hset + TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ (hset - null) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val setTerm = ctx.addReferenceObj("set", set) - val childGen = child.genCode(ctx) - val setIsNull = if (hasNull) { - s"${ev.isNull} = !${ev.value};" - } else { - "" - } - ev.copy(code = - code""" - |${childGen.code} - |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; - |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false; - |if (!${ev.isNull}) { - | ${ev.value} = $setTerm.contains(${childGen.value}); - | $setIsNull - |} - """.stripMargin) + nullSafeCodeGen(ctx, ev, c => { + val setTerm = ctx.addReferenceObj("set", set) + val setIsNull = if (hasNull) { + s"${ev.isNull} = !${ev.value};" + } else { + "" + } + s""" + |${ev.value} = $setTerm.contains($c); + |$setIsNull + """.stripMargin + }) } override def sql: String = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index ac76b17ef4761..3b60d1d88b3c6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -268,7 +268,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(InSet(nl, nS), null) val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType, - LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) + LongType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) primitiveTypes.foreach { t => val dataGen = RandomDataGenerator.forType(t, nullable = true).get val inputData = Seq.fill(10) { @@ -293,6 +293,54 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("INSET: binary") { + val hS = HashSet[Any]() + Array(1.toByte, 2.toByte) + Array(3.toByte) + val nS = HashSet[Any]() + Array(1.toByte, 2.toByte) + Array(3.toByte) + null + val onetwo = Literal(Array(1.toByte, 2.toByte)) + val three = Literal(Array(3.toByte)) + val threefour = Literal(Array(3.toByte, 4.toByte)) + val nl = Literal(null, onetwo.dataType) + checkEvaluation(InSet(onetwo, hS), true) + checkEvaluation(InSet(three, hS), true) + checkEvaluation(InSet(three, nS), true) + checkEvaluation(InSet(threefour, hS), false) + checkEvaluation(InSet(threefour, nS), null) + checkEvaluation(InSet(nl, hS), null) + checkEvaluation(InSet(nl, nS), null) + } + + test("INSET: struct") { + val hS = HashSet[Any]() + Literal.create((1, "a")).value + Literal.create((2, "b")).value + val nS = HashSet[Any]() + Literal.create((1, "a")).value + Literal.create((2, "b")).value + null + val oneA = Literal.create((1, "a")) + val twoB = Literal.create((2, "b")) + val twoC = Literal.create((2, "c")) + val nl = Literal(null, oneA.dataType) + checkEvaluation(InSet(oneA, hS), true) + checkEvaluation(InSet(twoB, hS), true) + checkEvaluation(InSet(twoB, nS), true) + checkEvaluation(InSet(twoC, hS), false) + checkEvaluation(InSet(twoC, nS), null) + checkEvaluation(InSet(nl, hS), null) + checkEvaluation(InSet(nl, nS), null) + } + + test("INSET: array") { + val hS = HashSet[Any]() + Literal.create(Seq(1, 2)).value + Literal.create(Seq(3)).value + val nS = HashSet[Any]() + Literal.create(Seq(1, 2)).value + Literal.create(Seq(3)).value + null + val onetwo = Literal.create(Seq(1, 2)) + val three = Literal.create(Seq(3)) + val threefour = Literal.create(Seq(3, 4)) + val nl = Literal(null, onetwo.dataType) + checkEvaluation(InSet(onetwo, hS), true) + checkEvaluation(InSet(three, hS), true) + checkEvaluation(InSet(three, nS), true) + checkEvaluation(InSet(threefour, hS), false) + checkEvaluation(InSet(threefour, nS), null) + checkEvaluation(InSet(nl, hS), null) + checkEvaluation(InSet(nl, nS), null) + } + private case class MyStruct(a: Long, b: String) private case class MyStruct2(a: MyStruct, b: Array[Int]) private val udt = new ExamplePointUDT From 06a87711b8a3a71c32897003cd9c6203e1c0c42e Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 29 Nov 2018 08:48:12 -0600 Subject: [PATCH 2162/2461] [SPARK-26024][FOLLOWUP][MINOR] Follow-up to remove extra blank lines in R function descriptions ## What changes were proposed in this pull request? Follow-up to remove extra blank lines in R function descriptions ## How was this patch tested? N/A Closes #23167 from srowen/SPARK-26024.2. Authored-by: Sean Owen Signed-off-by: Sean Owen --- R/pkg/R/DataFrame.R | 2 -- 1 file changed, 2 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index ad9cd845f696c..745bb3e15932b 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -766,7 +766,6 @@ setMethod("repartition", #' \item{2.} {Return a new SparkDataFrame range partitioned by the given column(s), #' using \code{spark.sql.shuffle.partitions} as number of partitions.} #'} -#' #' At least one partition-by expression must be specified. #' When no explicit sort order is specified, "ascending nulls first" is assumed. #' @@ -828,7 +827,6 @@ setMethod("repartitionByRange", #' toJSON #' #' Converts a SparkDataFrame into a SparkDataFrame of JSON string. -#' #' Each row is turned into a JSON document with columns as different fields. #' The returned SparkDataFrame has a single character column with the name \code{value} #' From e3ea93ab6c8e434faa360af831947e682206d50d Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Thu, 29 Nov 2018 08:53:12 -0600 Subject: [PATCH 2163/2461] [MINOR][ML] add missing params to Instr ## What changes were proposed in this pull request? add following param to instr: GBTC: validationTol GBTR: validationTol, validationIndicatorCol colnames in LiR, LinearSVC, etc ## How was this patch tested? existing tests Closes #23122 from zhengruifeng/instr_append_missing_params. Authored-by: zhengruifeng Signed-off-by: Sean Owen --- .../spark/ml/classification/DecisionTreeClassifier.scala | 3 ++- .../org/apache/spark/ml/classification/GBTClassifier.scala | 2 +- .../scala/org/apache/spark/ml/classification/LinearSVC.scala | 4 ++-- .../apache/spark/ml/classification/LogisticRegression.scala | 5 +++-- .../ml/classification/MultilayerPerceptronClassifier.scala | 4 ++-- .../scala/org/apache/spark/ml/classification/OneVsRest.scala | 3 ++- .../scala/org/apache/spark/ml/regression/GBTRegressor.scala | 3 ++- 7 files changed, 14 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index bcf89766b0873..d9292a5476767 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -115,7 +115,8 @@ class DecisionTreeClassifier @Since("1.4.0") ( val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) val strategy = getOldStrategy(categoricalFeatures, numClasses) - instr.logParams(this, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, + instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol, + probabilityCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, cacheNodeIds, checkpointInterval, impurity, seed) val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 09a9df6d15ece..abe2d1febfdf8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -194,7 +194,7 @@ class GBTClassifier @Since("1.4.0") ( instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy, - validationIndicatorCol) + validationIndicatorCol, validationTol) instr.logNumClasses(numClasses) val (baseLearners, learnerWeights) = if (withValidation) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 1b5c02fc9a576..ff801abef9a94 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -173,8 +173,8 @@ class LinearSVC @Since("2.2.0") ( instr.logPipelineStage(this) instr.logDataset(dataset) - instr.logParams(this, regParam, maxIter, fitIntercept, tol, standardization, threshold, - aggregationDepth) + instr.logParams(this, labelCol, weightCol, featuresCol, predictionCol, rawPredictionCol, + regParam, maxIter, fitIntercept, tol, standardization, threshold, aggregationDepth) val (summarizer, labelSummarizer) = { val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer), diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 6f0804f0c8e4a..27a7db0b2f5d4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -503,8 +503,9 @@ class LogisticRegression @Since("1.2.0") ( instr.logPipelineStage(this) instr.logDataset(dataset) - instr.logParams(this, regParam, elasticNetParam, standardization, threshold, - maxIter, tol, fitIntercept) + instr.logParams(this, labelCol, weightCol, featuresCol, predictionCol, rawPredictionCol, + probabilityCol, regParam, elasticNetParam, standardization, threshold, maxIter, tol, + fitIntercept) val (summarizer, labelSummarizer) = { val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer), diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 4feddce1d9f2d..47b8a8df637b9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -205,8 +205,8 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( dataset: Dataset[_]): MultilayerPerceptronClassificationModel = instrumented { instr => instr.logPipelineStage(this) instr.logDataset(dataset) - instr.logParams(this, labelCol, featuresCol, predictionCol, layers, maxIter, tol, - blockSize, solver, stepSize, seed) + instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol, layers, maxIter, + tol, blockSize, solver, stepSize, seed) val myLayers = $(layers) val labels = myLayers.last diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 2f42a5922054e..e1fceb1fc96a4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -377,7 +377,8 @@ final class OneVsRest @Since("1.4.0") ( instr.logPipelineStage(this) instr.logDataset(dataset) - instr.logParams(this, labelCol, featuresCol, predictionCol, parallelism, rawPredictionCol) + instr.logParams(this, labelCol, weightCol, featuresCol, predictionCol, + rawPredictionCol, parallelism) instr.logNamedValue("classifier", $(classifier).getClass.getCanonicalName) // determine number of classes either from metadata if provided, or via computation. diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 9b386ef5eed8f..9a5b7d59e9aef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -171,7 +171,8 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) instr.logDataset(dataset) instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, - seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) + seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy, + validationIndicatorCol, validationTol) val (baseLearners, learnerWeights) = if (withValidation) { GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy, From 9a09e91a3e880b7a07b11a957fb6766578f5a1af Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Thu, 29 Nov 2018 08:54:31 -0600 Subject: [PATCH 2164/2461] [SPARK-26177] Automated formatting for Scala code ## What changes were proposed in this pull request? Add a maven plugin and wrapper script to use scalafmt to format files that differ from git master. Intention is for contributors to be able to use this to automate fixing code style, not to include it in build pipeline yet. If this PR is accepted, I'd make a different PR to update the code style section of https://spark.apache.org/contributing.html to mention the script ## How was this patch tested? Manually tested by modifying a few files and running ./dev/scalafmt then checking that ./dev/scalastyle still passed. Closes #23148 from koeninger/scalafmt. Authored-by: cody koeninger Signed-off-by: Sean Owen --- dev/.scalafmt.conf | 24 ++++++++++++++++++++++++ dev/scalafmt | 23 +++++++++++++++++++++++ pom.xml | 21 +++++++++++++++++++++ 3 files changed, 68 insertions(+) create mode 100644 dev/.scalafmt.conf create mode 100755 dev/scalafmt diff --git a/dev/.scalafmt.conf b/dev/.scalafmt.conf new file mode 100644 index 0000000000000..def67e0269822 --- /dev/null +++ b/dev/.scalafmt.conf @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +align = none +align.openParenDefnSite = false +align.openParenCallSite = false +align.tokens = [] +docstrings = JavaDoc +maxColumn = 98 + diff --git a/dev/scalafmt b/dev/scalafmt new file mode 100755 index 0000000000000..76f688a2f5b88 --- /dev/null +++ b/dev/scalafmt @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# by default, format only files that differ from git master +params="${@:---diff}" + +./build/mvn mvn-scalafmt_2.12:format -Dscalafmt.skip=false -Dscalafmt.parameters="$params" diff --git a/pom.xml b/pom.xml index 93075e9b06a68..3ca2f739ce0ea 100644 --- a/pom.xml +++ b/pom.xml @@ -156,6 +156,9 @@ 3.2.2 2.12.7 2.12 + --diff --test + + true 1.9.13 2.9.6 1.1.7.1 @@ -2600,6 +2603,24 @@ + + org.antipathy + mvn-scalafmt_2.12 + 0.9_1.5.1 + + ${scalafmt.parameters} + ${scalafmt.skip} + dev/.scalafmt.conf + + + + validate + + format + + + + From 31c4fab3fb0343edf971de9070a319c6b3094647 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 29 Nov 2018 10:31:31 -0600 Subject: [PATCH 2165/2461] [SPARK-26081][SQL] Prevent empty files for empty partitions in Text datasources ## What changes were proposed in this pull request? In the PR, I propose to postpone creation of `OutputStream`/`Univocity`/`JacksonGenerator` till the first row should be written. This prevents creation of empty files for empty partitions. So, no need to open and to read such files back while loading data from the location. ## How was this patch tested? Added tests for Text, JSON and CSV datasource where empty dataset is written but should not produce any files. Closes #23052 from MaxGekk/text-empty-files. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: Sean Owen --- .../datasources/csv/CSVFileFormat.scala | 20 ++++++++++++------- .../datasources/json/JsonFileFormat.scala | 19 +++++++++--------- .../datasources/text/TextFileFormat.scala | 16 +++++++++++---- .../execution/datasources/csv/CSVSuite.scala | 9 +++++++++ .../datasources/json/JsonSuite.scala | 13 ++++++++++-- .../datasources/text/TextSuite.scala | 9 +++++++++ 6 files changed, 64 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index ff1911d69a6b5..4c5a1d327023c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -169,13 +169,19 @@ private[csv] class CsvOutputWriter( context: TaskAttemptContext, params: CSVOptions) extends OutputWriter with Logging { - private val charset = Charset.forName(params.charset) - - private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path), charset) - - private val gen = new UnivocityGenerator(dataSchema, writer, params) + private var univocityGenerator: Option[UnivocityGenerator] = None + + override def write(row: InternalRow): Unit = { + val gen = univocityGenerator.getOrElse { + val charset = Charset.forName(params.charset) + val os = CodecStreams.createOutputStreamWriter(context, new Path(path), charset) + val newGen = new UnivocityGenerator(dataSchema, os, params) + univocityGenerator = Some(newGen) + newGen + } - override def write(row: InternalRow): Unit = gen.write(row) + gen.write(row) + } - override def close(): Unit = gen.close() + override def close(): Unit = univocityGenerator.map(_.close()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 610f0d1619fc9..3042133ee43aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -175,19 +175,20 @@ private[json] class JsonOutputWriter( " which can be read back by Spark only if multiLine is enabled.") } - private val writer = CodecStreams.createOutputStreamWriter( - context, new Path(path), encoding) - - // create the Generator without separator inserted between 2 records - private[this] val gen = new JacksonGenerator(dataSchema, writer, options) + private var jacksonGenerator: Option[JacksonGenerator] = None override def write(row: InternalRow): Unit = { + val gen = jacksonGenerator.getOrElse { + val os = CodecStreams.createOutputStreamWriter(context, new Path(path), encoding) + // create the Generator without separator inserted between 2 records + val newGen = new JacksonGenerator(dataSchema, os, options) + jacksonGenerator = Some(newGen) + newGen + } + gen.write(row) gen.writeLineEnding() } - override def close(): Unit = { - gen.close() - writer.close() - } + override def close(): Unit = jacksonGenerator.map(_.close()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index 268297148b522..01948ab25d63c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.text +import java.io.OutputStream + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} @@ -148,17 +150,23 @@ class TextOutputWriter( context: TaskAttemptContext) extends OutputWriter { - private val writer = CodecStreams.createOutputStream(context, new Path(path)) + private var outputStream: Option[OutputStream] = None override def write(row: InternalRow): Unit = { + val os = outputStream.getOrElse{ + val newStream = CodecStreams.createOutputStream(context, new Path(path)) + outputStream = Some(newStream) + newStream + } + if (!row.isNullAt(0)) { val utf8string = row.getUTF8String(0) - utf8string.writeTo(writer) + utf8string.writeTo(os) } - writer.write(lineSeparator) + os.write(lineSeparator) } override def close(): Unit = { - writer.close() + outputStream.map(_.close()) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index c275d63d32cc8..e14e8d49db5c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1986,4 +1986,13 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te }.getMessage assert(errMsg2.contains("'lineSep' can contain only 1 character")) } + + test("do not produce empty files for empty partitions") { + withTempPath { dir => + val path = dir.getCanonicalPath + spark.emptyDataset[String].write.csv(path) + val files = new File(path).listFiles() + assert(!files.exists(_.getName.endsWith("csv"))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index ee31077e12ef3..ee5176e23e34d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1897,7 +1897,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .text(path) val jsonDF = spark.read.option("multiLine", true).option("mode", "PERMISSIVE").json(path) - assert(jsonDF.count() === corruptRecordCount + 1) // null row for empty file + assert(jsonDF.count() === corruptRecordCount) assert(jsonDF.schema === new StructType() .add("_corrupt_record", StringType) .add("dummy", StringType)) @@ -1910,7 +1910,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { F.count($"dummy").as("valid"), F.count($"_corrupt_record").as("corrupt"), F.count("*").as("count")) - checkAnswer(counts, Row(1, 5, 7)) // null row for empty file + checkAnswer(counts, Row(1, 4, 6)) // null row for empty file } } @@ -2555,4 +2555,13 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { emptyString(StringType, "") emptyString(BinaryType, "".getBytes(StandardCharsets.UTF_8)) } + + test("do not produce empty files for empty partitions") { + withTempPath { dir => + val path = dir.getCanonicalPath + spark.emptyDataset[String].write.json(path) + val files = new File(path).listFiles() + assert(!files.exists(_.getName.endsWith("json"))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index 0e7f3afa9c3ab..a86d5ee37f3db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -233,4 +233,13 @@ class TextSuite extends QueryTest with SharedSQLContext { assert(data(3) == Row("\"doh\"")) assert(data.length == 4) } + + test("do not produce empty files for empty partitions") { + withTempPath { dir => + val path = dir.getCanonicalPath + spark.emptyDataset[String].write.text(path) + val files = new File(path).listFiles() + assert(!files.exists(_.getName.endsWith("txt"))) + } + } } From de4228152771390b0c5ba15254e9c5b832095366 Mon Sep 17 00:00:00 2001 From: Keiji Yoshida Date: Thu, 29 Nov 2018 10:39:00 -0600 Subject: [PATCH 2166/2461] [MINOR][DOCS][WIP] Fix Typos ## What changes were proposed in this pull request? Fix Typos. ## How was this patch tested? NA Closes #23145 from kjmrknsn/docUpdate. Authored-by: Keiji Yoshida Signed-off-by: Sean Owen --- docs/index.md | 4 +-- docs/rdd-programming-guide.md | 8 +++--- docs/running-on-mesos.md | 2 +- docs/sql-data-sources-avro.md | 6 ++-- docs/sql-data-sources-hive-tables.md | 2 +- docs/sql-data-sources-jdbc.md | 2 +- docs/sql-data-sources-load-save-functions.md | 2 +- docs/sql-getting-started.md | 2 +- docs/sql-programming-guide.md | 2 +- docs/sql-pyspark-pandas-with-arrow.md | 2 +- docs/sql-reference.md | 6 ++-- docs/streaming-programming-guide.md | 2 +- .../structured-streaming-programming-guide.md | 28 +++++++++---------- 13 files changed, 34 insertions(+), 34 deletions(-) diff --git a/docs/index.md b/docs/index.md index bd287e3f8d83f..8864239eb1643 100644 --- a/docs/index.md +++ b/docs/index.md @@ -66,8 +66,8 @@ Example applications are also provided in Python. For example, ./bin/spark-submit examples/src/main/python/pi.py 10 -Spark also provides an experimental [R API](sparkr.html) since 1.4 (only DataFrames APIs included). -To run Spark interactively in a R interpreter, use `bin/sparkR`: +Spark also provides an [R API](sparkr.html) since 1.4 (only DataFrames APIs included). +To run Spark interactively in an R interpreter, use `bin/sparkR`: ./bin/sparkR --master local[2] diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index 9a07d6ca24b65..2d1ddae5780de 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -332,7 +332,7 @@ One important parameter for parallel collections is the number of *partitions* t Spark can create distributed datasets from any storage source supported by Hadoop, including your local file system, HDFS, Cassandra, HBase, [Amazon S3](http://wiki.apache.org/hadoop/AmazonS3), etc. Spark supports text files, [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), and any other Hadoop [InputFormat](http://hadoop.apache.org/docs/stable/api/org/apache/hadoop/mapred/InputFormat.html). -Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes an URI for the file (either a local path on the machine, or a `hdfs://`, `s3a://`, etc URI) and reads it as a collection of lines. Here is an example invocation: +Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes a URI for the file (either a local path on the machine, or a `hdfs://`, `s3a://`, etc URI) and reads it as a collection of lines. Here is an example invocation: {% highlight scala %} scala> val distFile = sc.textFile("data.txt") @@ -365,7 +365,7 @@ Apart from text files, Spark's Scala API also supports several other data format Spark can create distributed datasets from any storage source supported by Hadoop, including your local file system, HDFS, Cassandra, HBase, [Amazon S3](http://wiki.apache.org/hadoop/AmazonS3), etc. Spark supports text files, [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), and any other Hadoop [InputFormat](http://hadoop.apache.org/docs/stable/api/org/apache/hadoop/mapred/InputFormat.html). -Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes an URI for the file (either a local path on the machine, or a `hdfs://`, `s3a://`, etc URI) and reads it as a collection of lines. Here is an example invocation: +Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes a URI for the file (either a local path on the machine, or a `hdfs://`, `s3a://`, etc URI) and reads it as a collection of lines. Here is an example invocation: {% highlight java %} JavaRDD distFile = sc.textFile("data.txt"); @@ -397,7 +397,7 @@ Apart from text files, Spark's Java API also supports several other data formats PySpark can create distributed datasets from any storage source supported by Hadoop, including your local file system, HDFS, Cassandra, HBase, [Amazon S3](http://wiki.apache.org/hadoop/AmazonS3), etc. Spark supports text files, [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), and any other Hadoop [InputFormat](http://hadoop.apache.org/docs/stable/api/org/apache/hadoop/mapred/InputFormat.html). -Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes an URI for the file (either a local path on the machine, or a `hdfs://`, `s3a://`, etc URI) and reads it as a collection of lines. Here is an example invocation: +Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes a URI for the file (either a local path on the machine, or a `hdfs://`, `s3a://`, etc URI) and reads it as a collection of lines. Here is an example invocation: {% highlight python %} >>> distFile = sc.textFile("data.txt") @@ -1122,7 +1122,7 @@ costly operation. #### Background -To understand what happens during the shuffle we can consider the example of the +To understand what happens during the shuffle, we can consider the example of the [`reduceByKey`](#ReduceByLink) operation. The `reduceByKey` operation generates a new RDD where all values for a single key are combined into a tuple - the key and the result of executing a reduce function against all values associated with that key. The challenge is that not all values for a diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 2502cd4ca86f4..b3ba4b255b71a 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -687,7 +687,7 @@ See the [configuration page](configuration.html) for information on Spark config 0 Set the maximum number GPU resources to acquire for this job. Note that executors will still launch when no GPU resources are found - since this configuration is just a upper limit and not a guaranteed amount. + since this configuration is just an upper limit and not a guaranteed amount. diff --git a/docs/sql-data-sources-avro.md b/docs/sql-data-sources-avro.md index bfe641d1c6d1d..b403a66fad79a 100644 --- a/docs/sql-data-sources-avro.md +++ b/docs/sql-data-sources-avro.md @@ -66,9 +66,9 @@ write.df(select(df, "name", "favorite_color"), "namesAndFavColors.avro", "avro") ## to_avro() and from_avro() The Avro package provides function `to_avro` to encode a column as binary in Avro format, and `from_avro()` to decode Avro binary data into a column. Both functions transform one column to -another column, and the input/output SQL data type can be complex type or primitive type. +another column, and the input/output SQL data type can be a complex type or a primitive type. -Using Avro record as columns are useful when reading from or writing to a streaming source like Kafka. Each +Using Avro record as columns is useful when reading from or writing to a streaming source like Kafka. Each Kafka key-value record will be augmented with some metadata, such as the ingestion timestamp into Kafka, the offset in Kafka, etc. * If the "value" field that contains your data is in Avro, you could use `from_avro()` to extract your data, enrich it, clean it, and then push it downstream to Kafka again or write it out to a file. * `to_avro()` can be used to turn structs into Avro records. This method is particularly useful when you would like to re-encode multiple columns into a single one when writing data out to Kafka. @@ -151,7 +151,7 @@ Data source options of Avro can be set via: avroSchema None - Optional Avro schema provided by an user in JSON format. The date type and naming of record fields + Optional Avro schema provided by a user in JSON format. The date type and naming of record fields should match the input Avro data or Catalyst data, otherwise the read/write action will fail. read and write diff --git a/docs/sql-data-sources-hive-tables.md b/docs/sql-data-sources-hive-tables.md index 28e1a39626666..3b39a32d43240 100644 --- a/docs/sql-data-sources-hive-tables.md +++ b/docs/sql-data-sources-hive-tables.md @@ -74,7 +74,7 @@ creating table, you can create a table using storage handler at Hive side, and u inputFormat, outputFormat These 2 options specify the name of a corresponding `InputFormat` and `OutputFormat` class as a string literal, - e.g. `org.apache.hadoop.hive.ql.io.orc.OrcInputFormat`. These 2 options must be appeared in pair, and you can not + e.g. `org.apache.hadoop.hive.ql.io.orc.OrcInputFormat`. These 2 options must be appeared in a pair, and you can not specify them if you already specified the `fileFormat` option. diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md index 057e8217241aa..9a5d0fc7d424c 100644 --- a/docs/sql-data-sources-jdbc.md +++ b/docs/sql-data-sources-jdbc.md @@ -55,7 +55,7 @@ the following case-insensitive options: as a subquery in the FROM clause. Spark will also assign an alias to the subquery clause. As an example, spark will issue a query of the following form to the JDBC Source.

      SELECT <columns> FROM (<user_specified_query>) spark_gen_alias

      - Below are couple of restrictions while using this option.
      + Below are a couple of restrictions while using this option.
      1. It is not allowed to specify `dbtable` and `query` options at the same time.
      2. It is not allowed to specify `query` and `partitionColumn` options at the same time. When specifying diff --git a/docs/sql-data-sources-load-save-functions.md b/docs/sql-data-sources-load-save-functions.md index e4c7b1766f918..4386caedb38b3 100644 --- a/docs/sql-data-sources-load-save-functions.md +++ b/docs/sql-data-sources-load-save-functions.md @@ -324,4 +324,4 @@ CLUSTERED BY(name) SORTED BY (favorite_numbers) INTO 42 BUCKETS; `partitionBy` creates a directory structure as described in the [Partition Discovery](sql-data-sources-parquet.html#partition-discovery) section. Thus, it has limited applicability to columns with high cardinality. In contrast `bucketBy` distributes -data across a fixed number of buckets and can be used when a number of unique values is unbounded. +data across a fixed number of buckets and can be used when the number of unique values is unbounded. diff --git a/docs/sql-getting-started.md b/docs/sql-getting-started.md index 88512205894ab..0c3f0fb20610f 100644 --- a/docs/sql-getting-started.md +++ b/docs/sql-getting-started.md @@ -99,7 +99,7 @@ Here we include some basic examples of structured data processing using Datasets
        {% include_example untyped_ops scala/org/apache/spark/examples/sql/SparkSQLExample.scala %} -For a complete list of the types of operations that can be performed on a Dataset refer to the [API Documentation](api/scala/index.html#org.apache.spark.sql.Dataset). +For a complete list of the types of operations that can be performed on a Dataset, refer to the [API Documentation](api/scala/index.html#org.apache.spark.sql.Dataset). In addition to simple column references and expressions, Datasets also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/scala/index.html#org.apache.spark.sql.functions$).
        diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index eca8915dfa975..9c85a15827bbe 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -7,7 +7,7 @@ title: Spark SQL and DataFrames Spark SQL is a Spark module for structured data processing. Unlike the basic Spark RDD API, the interfaces provided by Spark SQL provide Spark with more information about the structure of both the data and the computation being performed. Internally, Spark SQL uses this extra information to perform extra optimizations. There are several ways to -interact with Spark SQL including SQL and the Dataset API. When computing a result +interact with Spark SQL including SQL and the Dataset API. When computing a result, the same execution engine is used, independent of which API/language you are using to express the computation. This unification means that developers can easily switch back and forth between different APIs based on which provides the most natural way to express a given transformation. diff --git a/docs/sql-pyspark-pandas-with-arrow.md b/docs/sql-pyspark-pandas-with-arrow.md index d04b955f9bf8b..d18ca0beb0fc6 100644 --- a/docs/sql-pyspark-pandas-with-arrow.md +++ b/docs/sql-pyspark-pandas-with-arrow.md @@ -129,7 +129,7 @@ For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/p Currently, all Spark SQL data types are supported by Arrow-based conversion except `MapType`, `ArrayType` of `TimestampType`, and nested `StructType`. `BinaryType` is supported only when -installed PyArrow is equal to or higher then 0.10.0. +installed PyArrow is equal to or higher than 0.10.0. ### Setting Arrow Batch Size diff --git a/docs/sql-reference.md b/docs/sql-reference.md index 9e4239b6bad23..88d0596f3876e 100644 --- a/docs/sql-reference.md +++ b/docs/sql-reference.md @@ -38,15 +38,15 @@ Spark SQL and DataFrames support the following data types: elements with the type of `elementType`. `containsNull` is used to indicate if elements in a `ArrayType` value can have `null` values. - `MapType(keyType, valueType, valueContainsNull)`: - Represents values comprising a set of key-value pairs. The data type of keys are - described by `keyType` and the data type of values are described by `valueType`. + Represents values comprising a set of key-value pairs. The data type of keys is + described by `keyType` and the data type of values is described by `valueType`. For a `MapType` value, keys are not allowed to have `null` values. `valueContainsNull` is used to indicate if values of a `MapType` value can have `null` values. - `StructType(fields)`: Represents values with the structure described by a sequence of `StructField`s (`fields`). * `StructField(name, dataType, nullable)`: Represents a field in a `StructType`. The name of a field is indicated by `name`. The data type of a field is indicated - by `dataType`. `nullable` is used to indicate if values of this fields can have + by `dataType`. `nullable` is used to indicate if values of these fields can have `null` values.
        diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 70bee5032a24d..94c61205bd53b 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -733,7 +733,7 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea Python API As of Spark {{site.SPARK_VERSION_SHORT}}, out of these sources, Kafka and Kinesis are available in the Python API. -This category of sources require interfacing with external non-Spark libraries, some of them with +This category of sources requires interfacing with external non-Spark libraries, some of them with complex dependencies (e.g., Kafka). Hence, to minimize issues related to version conflicts of dependencies, the functionality to create DStreams from these sources has been moved to separate libraries that can be [linked](#linking) to explicitly when necessary. diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 8cea98c2cc52b..32d61dcdb4599 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1493,7 +1493,7 @@ Additional details on supported joins: ### Streaming Deduplication You can deduplicate records in data streams using a unique identifier in the events. This is exactly same as deduplication on static using a unique identifier column. The query will store the necessary amount of data from previous records such that it can filter duplicate records. Similar to aggregations, you can use deduplication with or without watermarking. -- *With watermark* - If there is a upper bound on how late a duplicate record may arrive, then you can define a watermark on a event time column and deduplicate using both the guid and the event time columns. The query will use the watermark to remove old state data from past records that are not expected to get any duplicates any more. This bounds the amount of the state the query has to maintain. +- *With watermark* - If there is an upper bound on how late a duplicate record may arrive, then you can define a watermark on an event time column and deduplicate using both the guid and the event time columns. The query will use the watermark to remove old state data from past records that are not expected to get any duplicates any more. This bounds the amount of the state the query has to maintain. - *Without watermark* - Since there are no bounds on when a duplicate record may arrive, the query stores the data from all the past records as state. @@ -1577,7 +1577,7 @@ event time seen in each input stream, calculates watermarks based on the corresp and chooses a single global watermark with them to be used for stateful operations. By default, the minimum is chosen as the global watermark because it ensures that no data is accidentally dropped as too late if one of the streams falls behind the others -(for example, one of the streams stop receiving data due to upstream failures). In other words, +(for example, one of the streams stops receiving data due to upstream failures). In other words, the global watermark will safely move at the pace of the slowest stream and the query output will be delayed accordingly. @@ -1598,7 +1598,7 @@ Some of them are as follows. - Multiple streaming aggregations (i.e. a chain of aggregations on a streaming DF) are not yet supported on streaming Datasets. -- Limit and take first N rows are not supported on streaming Datasets. +- Limit and take the first N rows are not supported on streaming Datasets. - Distinct operations on streaming Datasets are not supported. @@ -1634,7 +1634,7 @@ returned through `Dataset.writeStream()`. You will have to specify one or more o - *Query name:* Optionally, specify a unique name of the query for identification. -- *Trigger interval:* Optionally, specify the trigger interval. If it is not specified, the system will check for availability of new data as soon as the previous processing has completed. If a trigger time is missed because the previous processing has not completed, then the system will trigger processing immediately. +- *Trigger interval:* Optionally, specify the trigger interval. If it is not specified, the system will check for availability of new data as soon as the previous processing has been completed. If a trigger time is missed because the previous processing has not been completed, then the system will trigger processing immediately. - *Checkpoint location:* For some output sinks where the end-to-end fault-tolerance can be guaranteed, specify the location where the system will write all the checkpoint information. This should be a directory in an HDFS-compatible fault-tolerant file system. The semantics of checkpointing is discussed in more detail in the next section. @@ -2106,7 +2106,7 @@ With `foreachBatch`, you can do the following. ###### Foreach If `foreachBatch` is not an option (for example, corresponding batch data writer does not exist, or -continuous processing mode), then you can express you custom writer logic using `foreach`. +continuous processing mode), then you can express your custom writer logic using `foreach`. Specifically, you can express the data writing logic by dividing it into three methods: `open`, `process`, and `close`. Since Spark 2.4, `foreach` is available in Scala, Java and Python. @@ -2236,8 +2236,8 @@ When the streaming query is started, Spark calls the function or the object’s in the continuous mode, then this guarantee does not hold and therefore should not be used for deduplication. #### Triggers -The trigger settings of a streaming query defines the timing of streaming data processing, whether -the query is going to executed as micro-batch query with a fixed batch interval or as a continuous processing query. +The trigger settings of a streaming query define the timing of streaming data processing, whether +the query is going to be executed as micro-batch query with a fixed batch interval or as a continuous processing query. Here are the different kinds of triggers that are supported. @@ -2960,7 +2960,7 @@ the effect of the change is not well-defined. For all of them: - Addition/deletion/modification of rate limits is allowed: `spark.readStream.format("kafka").option("subscribe", "topic")` to `spark.readStream.format("kafka").option("subscribe", "topic").option("maxOffsetsPerTrigger", ...)` - - Changes to subscribed topics/files is generally not allowed as the results are unpredictable: `spark.readStream.format("kafka").option("subscribe", "topic")` to `spark.readStream.format("kafka").option("subscribe", "newTopic")` + - Changes to subscribed topics/files are generally not allowed as the results are unpredictable: `spark.readStream.format("kafka").option("subscribe", "topic")` to `spark.readStream.format("kafka").option("subscribe", "newTopic")` - *Changes in the type of output sink*: Changes between a few specific combinations of sinks are allowed. This needs to be verified on a case-by-case basis. Here are a few examples. @@ -2974,17 +2974,17 @@ the effect of the change is not well-defined. For all of them: - *Changes in the parameters of output sink*: Whether this is allowed and whether the semantics of the change are well-defined depends on the sink and the query. Here are a few examples. - - Changes to output directory of a file sink is not allowed: `sdf.writeStream.format("parquet").option("path", "/somePath")` to `sdf.writeStream.format("parquet").option("path", "/anotherPath")` + - Changes to output directory of a file sink are not allowed: `sdf.writeStream.format("parquet").option("path", "/somePath")` to `sdf.writeStream.format("parquet").option("path", "/anotherPath")` - - Changes to output topic is allowed: `sdf.writeStream.format("kafka").option("topic", "someTopic")` to `sdf.writeStream.format("kafka").option("topic", "anotherTopic")` + - Changes to output topic are allowed: `sdf.writeStream.format("kafka").option("topic", "someTopic")` to `sdf.writeStream.format("kafka").option("topic", "anotherTopic")` - - Changes to the user-defined foreach sink (that is, the `ForeachWriter` code) is allowed, but the semantics of the change depends on the code. + - Changes to the user-defined foreach sink (that is, the `ForeachWriter` code) are allowed, but the semantics of the change depends on the code. - *Changes in projection / filter / map-like operations**: Some cases are allowed. For example: - Addition / deletion of filters is allowed: `sdf.selectExpr("a")` to `sdf.where(...).selectExpr("a").filter(...)`. - - Changes in projections with same output schema is allowed: `sdf.selectExpr("stringColumn AS json").writeStream` to `sdf.selectExpr("anotherStringColumn AS json").writeStream` + - Changes in projections with same output schema are allowed: `sdf.selectExpr("stringColumn AS json").writeStream` to `sdf.selectExpr("anotherStringColumn AS json").writeStream` - Changes in projections with different output schema are conditionally allowed: `sdf.selectExpr("a").writeStream` to `sdf.selectExpr("b").writeStream` is allowed only if the output sink allows the schema change from `"a"` to `"b"`. @@ -3000,7 +3000,7 @@ the effect of the change is not well-defined. For all of them: - *Streaming deduplication*: For example, `sdf.dropDuplicates("a")`. Any change in number or type of grouping keys or aggregates is not allowed. - *Stream-stream join*: For example, `sdf1.join(sdf2, ...)` (i.e. both inputs are generated with `sparkSession.readStream`). Changes - in the schema or equi-joining columns are not allowed. Changes in join type (outer or inner) not allowed. Other changes in the join condition are ill-defined. + in the schema or equi-joining columns are not allowed. Changes in join type (outer or inner) are not allowed. Other changes in the join condition are ill-defined. - *Arbitrary stateful operation*: For example, `sdf.groupByKey(...).mapGroupsWithState(...)` or `sdf.groupByKey(...).flatMapGroupsWithState(...)`. Any change to the schema of the user-defined state and the type of timeout is not allowed. @@ -3083,7 +3083,7 @@ spark \ -A checkpoint interval of 1 second means that the continuous processing engine will records the progress of the query every second. The resulting checkpoints are in a format compatible with the micro-batch engine, hence any query can be restarted with any trigger. For example, a supported query started with the micro-batch mode can be restarted in continuous mode, and vice versa. Note that any time you switch to continuous mode, you will get at-least-once fault-tolerance guarantees. +A checkpoint interval of 1 second means that the continuous processing engine will record the progress of the query every second. The resulting checkpoints are in a format compatible with the micro-batch engine, hence any query can be restarted with any trigger. For example, a supported query started with the micro-batch mode can be restarted in continuous mode, and vice versa. Note that any time you switch to continuous mode, you will get at-least-once fault-tolerance guarantees. ## Supported Queries {:.no_toc} From 24e78b7f163acf6129d934633ae6d3e6d568656a Mon Sep 17 00:00:00 2001 From: Shahid Date: Thu, 29 Nov 2018 09:48:18 -0800 Subject: [PATCH 2167/2461] [SPARK-26186][SPARK-26184][CORE] Last updated time is not getting updated for the Inprogress application ## What changes were proposed in this pull request? When the 'spark.history.fs.inProgressOptimization.enabled' is true, inProgress application's last updated time is not getting updated in the History UI. Also, during the cleaning time, InProgress application is getting removed from the listing, even if the last updated time is within the cleaning threshold time. In this PR, if the fastInprogressOptimization enabled, we update the `lastUpdateTime` of the application as last scan time. This will update the `lastUpdateTime` in the historyUI and also while cleaning, it won't remove if the updateTime is within the cleaning interval ## How was this patch tested? Added UT, attached screen shot. Before patch: ![screenshot from 2018-11-27 23-22-38](https://user-images.githubusercontent.com/23054875/49101600-9b5a3380-f29c-11e8-8efc-3fb594e4279a.png) ![screenshot from 2018-11-27 23-20-11](https://user-images.githubusercontent.com/23054875/49101601-9c8b6080-f29c-11e8-928e-643a8c8f4477.png) After Patch: ![screenshot from 2018-11-27 23-37-10](https://user-images.githubusercontent.com/23054875/49101911-669aac00-f29d-11e8-8181-663e4a08ab0e.png) ![screenshot from 2018-11-27 23-39-04](https://user-images.githubusercontent.com/23054875/49102010-a5306680-f29d-11e8-947a-e8a2a09a785a.png) Closes #23158 from shahidki31/HistoryLastUpdateTime. Authored-by: Shahid Signed-off-by: Marcelo Vanzin --- .../deploy/history/FsHistoryProvider.scala | 22 +++++++++++ .../history/FsHistoryProviderSuite.scala | 39 +++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 4f4a00b7d831d..da6e5f03aabb5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -461,6 +461,28 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) if (info.appId.isDefined && fastInProgressParsing) { // When fast in-progress parsing is on, we don't need to re-parse when the // size changes, but we do need to invalidate any existing UIs. + // Also, we need to update the `lastUpdated time` to display the updated time in + // the HistoryUI and to avoid cleaning the inprogress app while running. + val appInfo = listing.read(classOf[ApplicationInfoWrapper], info.appId.get) + + val attemptList = appInfo.attempts.map { attempt => + if (attempt.info.attemptId == info.attemptId) { + new AttemptInfoWrapper( + attempt.info.copy(lastUpdated = new Date(newLastScanTime)), + attempt.logPath, + attempt.fileSize, + attempt.adminAcls, + attempt.viewAcls, + attempt.adminAclsGroups, + attempt.viewAclsGroups) + } else { + attempt + } + } + + val updatedAppInfo = new ApplicationInfoWrapper(appInfo.info, attemptList) + listing.write(updatedAppInfo) + invalidateUI(info.appId.get, info.attemptId) false } else { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index b0ced46c9c7b8..527c654a7cd68 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -334,6 +334,45 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc assert(!log2.exists()) } + test("should not clean inprogress application with lastUpdated time less than maxTime") { + val firstFileModifiedTime = TimeUnit.DAYS.toMillis(1) + val secondFileModifiedTime = TimeUnit.DAYS.toMillis(6) + val maxAge = TimeUnit.DAYS.toMillis(7) + val clock = new ManualClock(0) + val provider = new FsHistoryProvider( + createTestConf().set(MAX_LOG_AGE_S, maxAge / 1000), clock) + val log = newLogFile("inProgressApp1", None, inProgress = true) + writeFile(log, true, None, + SparkListenerApplicationStart( + "inProgressApp1", Some("inProgressApp1"), 3L, "test", Some("attempt1")) + ) + clock.setTime(firstFileModifiedTime) + log.setLastModified(clock.getTimeMillis()) + provider.checkForLogs() + writeFile(log, true, None, + SparkListenerApplicationStart( + "inProgressApp1", Some("inProgressApp1"), 3L, "test", Some("attempt1")), + SparkListenerJobStart(0, 1L, Nil, null) + ) + + clock.setTime(secondFileModifiedTime) + log.setLastModified(clock.getTimeMillis()) + provider.checkForLogs() + clock.setTime(TimeUnit.DAYS.toMillis(10)) + writeFile(log, true, None, + SparkListenerApplicationStart( + "inProgressApp1", Some("inProgressApp1"), 3L, "test", Some("attempt1")), + SparkListenerJobStart(0, 1L, Nil, null), + SparkListenerJobEnd(0, 1L, JobSucceeded) + ) + log.setLastModified(clock.getTimeMillis()) + provider.checkForLogs() + // This should not trigger any cleanup + updateAndCheck(provider) { list => + list.size should be(1) + } + } + test("log cleaner for inProgress files") { val firstFileModifiedTime = TimeUnit.SECONDS.toMillis(10) val secondFileModifiedTime = TimeUnit.SECONDS.toMillis(20) From 1144df3b5dc8280ae4a07678cb439f0b44cb17b0 Mon Sep 17 00:00:00 2001 From: Rob Vesse Date: Thu, 29 Nov 2018 09:59:38 -0800 Subject: [PATCH 2168/2461] [SPARK-26015][K8S] Set a default UID for Spark on K8S Images Adds USER directives to the Dockerfiles which is configurable via build argument (`spark_uid`) for easy customisation. A `-u` flag is added to `bin/docker-image-tool.sh` to make it easy to customise this e.g. ``` > bin/docker-image-tool.sh -r rvesse -t uid -u 185 build > bin/docker-image-tool.sh -r rvesse -t uid push ``` If no UID is explicitly specified it defaults to `185` - this is per skonto's suggestion to align with the OpenShift standard reserved UID for Java apps ( https://lists.openshift.redhat.com/openshift-archives/users/2016-March/msg00283.html) Notes: - We have to make the `WORKDIR` writable by the root group or otherwise jobs will fail with `AccessDeniedException` To Do: - [x] Debug and resolve issue with client mode test - [x] Consider whether to always propagate `SPARK_USER_NAME` to environment of driver and executor pods so `entrypoint.sh` can insert that into `/etc/passwd` entry - [x] Rebase once PR #23013 is merged and update documentation accordingly Built the Docker images with the new Dockerfiles that include the `USER` directives. Ran the Spark on K8S integration tests against the new images. All pass except client mode which I am currently debugging further. Also manually dropped myself into the resulting container images via `docker run` and checked `id -u` output to see that UID is as expected. Tried customising the UID from the default via the new `-u` argument to `docker-image-tool.sh` and again checked the resulting image for the correct runtime UID. cc felixcheung skonto vanzin Closes #23017 from rvesse/SPARK-26015. Authored-by: Rob Vesse Signed-off-by: Marcelo Vanzin --- bin/docker-image-tool.sh | 16 +++++++++++++--- docs/running-on-kubernetes.md | 5 +++-- .../docker/src/main/dockerfiles/spark/Dockerfile | 6 ++++++ .../main/dockerfiles/spark/bindings/R/Dockerfile | 9 +++++++++ .../dockerfiles/spark/bindings/python/Dockerfile | 9 +++++++++ .../src/main/dockerfiles/spark/entrypoint.sh | 2 +- .../integrationtest/ClientModeTestsSuite.scala | 3 ++- 7 files changed, 43 insertions(+), 7 deletions(-) diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index 9f735f1148da4..fbf9c9e448fd1 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -146,6 +146,12 @@ function build { fi local BUILD_ARGS=(${BUILD_PARAMS}) + + # If a custom SPARK_UID was set add it to build arguments + if [ -n "$SPARK_UID" ]; then + BUILD_ARGS+=(--build-arg spark_uid=$SPARK_UID) + fi + local BINDING_BUILD_ARGS=( ${BUILD_PARAMS} --build-arg @@ -207,8 +213,10 @@ Options: -t tag Tag to apply to the built image, or to identify the image to be pushed. -m Use minikube's Docker daemon. -n Build docker image with --no-cache - -b arg Build arg to build or push the image. For multiple build args, this option needs to - be used separately for each build arg. + -u uid UID to use in the USER directive to set the user the main Spark process runs as inside the + resulting container + -b arg Build arg to build or push the image. For multiple build args, this option needs to + be used separately for each build arg. Using minikube when building images will do so directly into minikube's Docker daemon. There is no need to push the images into minikube in that case, they'll be automatically @@ -243,7 +251,8 @@ PYDOCKERFILE= RDOCKERFILE= NOCACHEARG= BUILD_PARAMS= -while getopts f:p:R:mr:t:nb: option +SPARK_UID= +while getopts f:p:R:mr:t:nb:u: option do case "${option}" in @@ -263,6 +272,7 @@ do fi eval $(minikube docker-env) ;; + u) SPARK_UID=${OPTARG};; esac done diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 2c01e1e7155ef..5639253d52f54 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -19,9 +19,9 @@ Please see [Spark Security](security.html) and the specific advice below before ## User Identity -Images built from the project provided Dockerfiles do not contain any [`USER`](https://docs.docker.com/engine/reference/builder/#user) directives. This means that the resulting images will be running the Spark processes as `root` inside the container. On unsecured clusters this may provide an attack vector for privilege escalation and container breakout. Therefore security conscious deployments should consider providing custom images with `USER` directives specifying an unprivileged UID and GID. +Images built from the project provided Dockerfiles contain a default [`USER`](https://docs.docker.com/engine/reference/builder/#user) directive with a default UID of `185`. This means that the resulting images will be running the Spark processes as this UID inside the container. Security conscious deployments should consider providing custom images with `USER` directives specifying their desired unprivileged UID and GID. The resulting UID should include the root group in its supplementary groups in order to be able to run the Spark executables. Users building their own images with the provided `docker-image-tool.sh` script can use the `-u ` option to specify the desired UID. -Alternatively the [Pod Template](#pod-template) feature can be used to add a [Security Context](https://kubernetes.io/docs/tasks/configure-pod-container/security-context/#volumes-and-file-systems) with a `runAsUser` to the pods that Spark submits. Please bear in mind that this requires cooperation from your users and as such may not be a suitable solution for shared environments. Cluster administrators should use [Pod Security Policies](https://kubernetes.io/docs/concepts/policy/pod-security-policy/#users-and-groups) if they wish to limit the users that pods may run as. +Alternatively the [Pod Template](#pod-template) feature can be used to add a [Security Context](https://kubernetes.io/docs/tasks/configure-pod-container/security-context/#volumes-and-file-systems) with a `runAsUser` to the pods that Spark submits. This can be used to override the `USER` directives in the images themselves. Please bear in mind that this requires cooperation from your users and as such may not be a suitable solution for shared environments. Cluster administrators should use [Pod Security Policies](https://kubernetes.io/docs/concepts/policy/pod-security-policy/#users-and-groups) if they wish to limit the users that pods may run as. ## Volume Mounts @@ -87,6 +87,7 @@ Example usage is: $ ./bin/docker-image-tool.sh -r -t my-tag build $ ./bin/docker-image-tool.sh -r -t my-tag push ``` +This will build using the projects provided default `Dockerfiles`. To see more options available for customising the behaviour of this tool, including providing custom `Dockerfiles`, please run with the `-h` flag. By default `bin/docker-image-tool.sh` builds docker image for running JVM jobs. You need to opt-in to build additional language binding docker images. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index 89b20e1446229..0843040324707 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -17,6 +17,8 @@ FROM openjdk:8-alpine +ARG spark_uid=185 + # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. # If this docker file is being used in the context of building your images from a Spark @@ -47,5 +49,9 @@ COPY data /opt/spark/data ENV SPARK_HOME /opt/spark WORKDIR /opt/spark/work-dir +RUN chmod g+w /opt/spark/work-dir ENTRYPOINT [ "/opt/entrypoint.sh" ] + +# Specify the User that the actual main process will run as +USER ${spark_uid} diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile index 9f67422efeb3c..9ded57c655104 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile @@ -16,8 +16,14 @@ # ARG base_img +ARG spark_uid=185 + FROM $base_img WORKDIR / + +# Reset to root to run installation tasks +USER 0 + RUN mkdir ${SPARK_HOME}/R RUN apk add --no-cache R R-dev @@ -27,3 +33,6 @@ ENV R_HOME /usr/lib/R WORKDIR /opt/spark/work-dir ENTRYPOINT [ "/opt/entrypoint.sh" ] + +# Specify the User that the actual main process will run as +USER ${spark_uid} diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile index 69b6efa6149a0..de1a0617b1cc5 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile @@ -16,8 +16,14 @@ # ARG base_img +ARG spark_uid=185 + FROM $base_img WORKDIR / + +# Reset to root to run installation tasks +USER 0 + RUN mkdir ${SPARK_HOME}/python # TODO: Investigate running both pip and pip3 via virtualenvs RUN apk add --no-cache python && \ @@ -37,3 +43,6 @@ ENV PYTHONPATH ${SPARK_HOME}/python/lib/pyspark.zip:${SPARK_HOME}/python/lib/py4 WORKDIR /opt/spark/work-dir ENTRYPOINT [ "/opt/entrypoint.sh" ] + +# Specify the User that the actual main process will run as +USER ${spark_uid} diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 2b2a4e4cf6bcc..2d770075a0748 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -30,7 +30,7 @@ set -e # If there is no passwd entry for the container UID, attempt to create one if [ -z "$uidentry" ] ; then if [ -w /etc/passwd ] ; then - echo "$myuid:x:$myuid:$mygid:anonymous uid:$SPARK_HOME:/bin/false" >> /etc/passwd + echo "$myuid:x:$myuid:$mygid:${SPARK_USER_NAME:-anonymous uid}:$SPARK_HOME:/bin/false" >> /etc/passwd else echo "Container ENTRYPOINT failed to add passwd entry for anonymous UID" fi diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala index c8bd584516ea5..2720cdf74ca8f 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ClientModeTestsSuite.scala @@ -62,11 +62,12 @@ private[spark] trait ClientModeTestsSuite { k8sSuite: KubernetesSuite => .endMetadata() .withNewSpec() .withServiceAccountName(kubernetesTestComponents.serviceAccountName) + .withRestartPolicy("Never") .addNewContainer() .withName("spark-example") .withImage(image) .withImagePullPolicy("IfNotPresent") - .withCommand("/opt/spark/bin/run-example") + .addToArgs("/opt/spark/bin/run-example") .addToArgs("--master", s"k8s://https://kubernetes.default.svc") .addToArgs("--deploy-mode", "client") .addToArgs("--conf", s"spark.kubernetes.container.image=$image") From 9fdc7a840daa64d1302d12027fd84ea9894110a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E4=BA=AE?= Date: Thu, 29 Nov 2018 13:08:53 -0600 Subject: [PATCH 2169/2461] [SPARK-26158][MLLIB] fix covariance accuracy problem for DenseVector MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Enhance accuracy of the covariance logic in RowMatrix for function computeCovariance ## How was this patch tested? Unit test Accuracy test Closes #23126 from KyleLi1985/master. Authored-by: 李亮 Signed-off-by: Sean Owen --- .../mllib/linalg/distributed/RowMatrix.scala | 97 ++++++++++++++----- .../apache/spark/ml/feature/JavaPCASuite.java | 3 +- .../linalg/distributed/RowMatrixSuite.scala | 14 +++ 3 files changed, 90 insertions(+), 24 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index c12b751bfb8e4..ff02e5dd3c253 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -128,6 +128,77 @@ class RowMatrix @Since("1.0.0") ( RowMatrix.triuToFull(n, GU.data) } + private def computeDenseVectorCovariance(mean: Vector, n: Int, m: Long): Matrix = { + + val bc = rows.context.broadcast(mean) + + // Computes n*(n+1)/2, avoiding overflow in the multiplication. + // This succeeds when n <= 65535, which is checked above + val nt = if (n % 2 == 0) ((n / 2) * (n + 1)) else (n * ((n + 1) / 2)) + + val MU = rows.treeAggregate(new BDV[Double](nt))( + seqOp = (U, v) => { + + val n = v.size + val na = Array.ofDim[Double](n) + val means = bc.value + + val ta = v.toArray + for (index <- 0 until n) { + na(index) = ta(index) - means(index) + } + + BLAS.spr(1.0, new DenseVector(na), U.data) + U + }, combOp = (U1, U2) => U1 += U2) + + bc.destroy() + + val M = RowMatrix.triuToFull(n, MU.data).asBreeze + + var i = 0 + var j = 0 + val m1 = m - 1.0 + while (i < n) { + j = i + while (j < n) { + val Mij = M(i, j) / m1 + M(i, j) = Mij + M(j, i) = Mij + j += 1 + } + i += 1 + } + + Matrices.fromBreeze(M) + } + + private def computeSparseVectorCovariance(mean: Vector, n: Int, m: Long): Matrix = { + + // We use the formula Cov(X, Y) = E[X * Y] - E[X] E[Y], which is not accurate if E[X * Y] is + // large but Cov(X, Y) is small, but it is good for sparse computation. + // TODO: find a fast and stable way for sparse data. + val G = computeGramianMatrix().asBreeze + + var i = 0 + var j = 0 + val m1 = m - 1.0 + var alpha = 0.0 + while (i < n) { + alpha = m / m1 * mean(i) + j = i + while (j < n) { + val Gij = G(i, j) / m1 - alpha * mean(j) + G(i, j) = Gij + G(j, i) = Gij + j += 1 + } + i += 1 + } + + Matrices.fromBreeze(G) + } + private def checkNumColumns(cols: Int): Unit = { if (cols > 65535) { throw new IllegalArgumentException(s"Argument with more than 65535 cols: $cols") @@ -337,29 +408,11 @@ class RowMatrix @Since("1.0.0") ( " Cannot compute the covariance of a RowMatrix with <= 1 row.") val mean = summary.mean - // We use the formula Cov(X, Y) = E[X * Y] - E[X] E[Y], which is not accurate if E[X * Y] is - // large but Cov(X, Y) is small, but it is good for sparse computation. - // TODO: find a fast and stable way for sparse data. - - val G = computeGramianMatrix().asBreeze - - var i = 0 - var j = 0 - val m1 = m - 1.0 - var alpha = 0.0 - while (i < n) { - alpha = m / m1 * mean(i) - j = i - while (j < n) { - val Gij = G(i, j) / m1 - alpha * mean(j) - G(i, j) = Gij - G(j, i) = Gij - j += 1 - } - i += 1 + if (rows.first().isInstanceOf[DenseVector]) { + computeDenseVectorCovariance(mean, n, m) + } else { + computeSparseVectorCovariance(mean, n, m) } - - Matrices.fromBreeze(G) } /** diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java index 683ceffeaed0e..2e177edf2a5c3 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java @@ -28,7 +28,6 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.linalg.Vector; import org.apache.spark.ml.linalg.Vectors; -import org.apache.spark.mllib.linalg.DenseVector; import org.apache.spark.mllib.linalg.Matrix; import org.apache.spark.mllib.linalg.distributed.RowMatrix; import org.apache.spark.sql.Dataset; @@ -67,7 +66,7 @@ public void testPCA() { JavaRDD dataRDD = jsc.parallelize(points, 2); RowMatrix mat = new RowMatrix(dataRDD.map( - (Vector vector) -> (org.apache.spark.mllib.linalg.Vector) new DenseVector(vector.toArray()) + (Vector vector) -> org.apache.spark.mllib.linalg.Vectors.fromML(vector) ).rdd()); Matrix pc = mat.computePrincipalComponents(3); diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 7c9e14f8cee70..a4ca4f0a80faa 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -266,6 +266,20 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("dense vector covariance accuracy (SPARK-26158)") { + val denseData = Seq( + Vectors.dense(100000.000004, 199999.999999), + Vectors.dense(100000.000012, 200000.000002), + Vectors.dense(99999.9999931, 200000.000003), + Vectors.dense(99999.9999977, 200000.000001) + ) + val denseMat = new RowMatrix(sc.parallelize(denseData, 2)) + + val result = denseMat.computeCovariance() + val expected = breeze.linalg.cov(denseMat.toBreeze()) + assert(closeToZero(abs(expected) - abs(result.asBreeze.asInstanceOf[BDM[Double]]))) + } + test("compute covariance") { for (mat <- Seq(denseMat, sparseMat)) { val result = mat.computeCovariance() From cb368f2c2964797d7313d3a4151e2352ff7847a9 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Thu, 29 Nov 2018 12:09:30 -0800 Subject: [PATCH 2170/2461] [SPARK-26142] followup: Move sql shuffle read metrics relatives to SQLShuffleMetricsReporter ## What changes were proposed in this pull request? Follow up for https://github.com/apache/spark/pull/23128, move sql read metrics relatives to `SQLShuffleMetricsReporter`, in order to put sql shuffle read metrics relatives closer and avoid possible problem about forgetting update SQLShuffleMetricsReporter while new metrics added by others. ## How was this patch tested? Existing tests. Closes #23175 from xuanyuanking/SPARK-26142-follow. Authored-by: Yuanjian Li Signed-off-by: Reynold Xin --- .../exchange/ShuffleExchangeExec.scala | 4 +- .../apache/spark/sql/execution/limit.scala | 6 +-- .../sql/execution/metric/SQLMetrics.scala | 20 -------- .../metric/SQLShuffleMetricsReporter.scala | 50 +++++++++++++++---- .../execution/UnsafeRowSerializerSuite.scala | 4 +- 5 files changed, 47 insertions(+), 37 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 8938d93da90eb..c9ca395bceaa4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Uns import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.metric.{SQLMetrics, SQLShuffleMetricsReporter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.util.MutablePair @@ -49,7 +49,7 @@ case class ShuffleExchangeExec( override lazy val metrics = Map( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size") - ) ++ SQLMetrics.getShuffleReadMetrics(sparkContext) + ) ++ SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext) override def nodeName: String = { val extraInfo = coordinator match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index ea845da8438fe..e9ab7cd138d99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec -import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.metric.SQLShuffleMetricsReporter /** * Take the first `limit` elements and collect them to a single partition. @@ -38,7 +38,7 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode override def outputPartitioning: Partitioning = SinglePartition override def executeCollect(): Array[InternalRow] = child.executeTake(limit) private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) - override lazy val metrics = SQLMetrics.getShuffleReadMetrics(sparkContext) + override lazy val metrics = SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext) protected override def doExecute(): RDD[InternalRow] = { val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit)) val shuffled = new ShuffledRowRDD( @@ -154,7 +154,7 @@ case class TakeOrderedAndProjectExec( private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) - override lazy val metrics = SQLMetrics.getShuffleReadMetrics(sparkContext) + override lazy val metrics = SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext) protected override def doExecute(): RDD[InternalRow] = { val ord = new LazilyGeneratedOrdering(sortOrder, child.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 0b5ee3a5e0577..cbf707f4a9cfd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -82,14 +82,6 @@ object SQLMetrics { private val baseForAvgMetric: Int = 10 - val REMOTE_BLOCKS_FETCHED = "remoteBlocksFetched" - val LOCAL_BLOCKS_FETCHED = "localBlocksFetched" - val REMOTE_BYTES_READ = "remoteBytesRead" - val REMOTE_BYTES_READ_TO_DISK = "remoteBytesReadToDisk" - val LOCAL_BYTES_READ = "localBytesRead" - val FETCH_WAIT_TIME = "fetchWaitTime" - val RECORDS_READ = "recordsRead" - /** * Converts a double value to long value by multiplying a base integer, so we can store it in * `SQLMetrics`. It only works for average metrics. When showing the metrics on UI, we restore @@ -202,16 +194,4 @@ object SQLMetrics { SparkListenerDriverAccumUpdates(executionId.toLong, metrics.map(m => m.id -> m.value))) } } - - /** - * Create all shuffle read relative metrics and return the Map. - */ - def getShuffleReadMetrics(sc: SparkContext): Map[String, SQLMetric] = Map( - REMOTE_BLOCKS_FETCHED -> createMetric(sc, "remote blocks fetched"), - LOCAL_BLOCKS_FETCHED -> createMetric(sc, "local blocks fetched"), - REMOTE_BYTES_READ -> createSizeMetric(sc, "remote bytes read"), - REMOTE_BYTES_READ_TO_DISK -> createSizeMetric(sc, "remote bytes read to disk"), - LOCAL_BYTES_READ -> createSizeMetric(sc, "local bytes read"), - FETCH_WAIT_TIME -> createTimingMetric(sc, "fetch wait time"), - RECORDS_READ -> createMetric(sc, "records read")) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala index 542141ea4b4e6..780f0d7622294 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala @@ -17,24 +17,32 @@ package org.apache.spark.sql.execution.metric +import org.apache.spark.SparkContext import org.apache.spark.executor.TempShuffleReadMetrics /** * A shuffle metrics reporter for SQL exchange operators. * @param tempMetrics [[TempShuffleReadMetrics]] created in TaskContext. * @param metrics All metrics in current SparkPlan. This param should not empty and - * contains all shuffle metrics defined in [[SQLMetrics.getShuffleReadMetrics]]. + * contains all shuffle metrics defined in createShuffleReadMetrics. */ private[spark] class SQLShuffleMetricsReporter( - tempMetrics: TempShuffleReadMetrics, - metrics: Map[String, SQLMetric]) extends TempShuffleReadMetrics { - private[this] val _remoteBlocksFetched = metrics(SQLMetrics.REMOTE_BLOCKS_FETCHED) - private[this] val _localBlocksFetched = metrics(SQLMetrics.LOCAL_BLOCKS_FETCHED) - private[this] val _remoteBytesRead = metrics(SQLMetrics.REMOTE_BYTES_READ) - private[this] val _remoteBytesReadToDisk = metrics(SQLMetrics.REMOTE_BYTES_READ_TO_DISK) - private[this] val _localBytesRead = metrics(SQLMetrics.LOCAL_BYTES_READ) - private[this] val _fetchWaitTime = metrics(SQLMetrics.FETCH_WAIT_TIME) - private[this] val _recordsRead = metrics(SQLMetrics.RECORDS_READ) + tempMetrics: TempShuffleReadMetrics, + metrics: Map[String, SQLMetric]) extends TempShuffleReadMetrics { + private[this] val _remoteBlocksFetched = + metrics(SQLShuffleMetricsReporter.REMOTE_BLOCKS_FETCHED) + private[this] val _localBlocksFetched = + metrics(SQLShuffleMetricsReporter.LOCAL_BLOCKS_FETCHED) + private[this] val _remoteBytesRead = + metrics(SQLShuffleMetricsReporter.REMOTE_BYTES_READ) + private[this] val _remoteBytesReadToDisk = + metrics(SQLShuffleMetricsReporter.REMOTE_BYTES_READ_TO_DISK) + private[this] val _localBytesRead = + metrics(SQLShuffleMetricsReporter.LOCAL_BYTES_READ) + private[this] val _fetchWaitTime = + metrics(SQLShuffleMetricsReporter.FETCH_WAIT_TIME) + private[this] val _recordsRead = + metrics(SQLShuffleMetricsReporter.RECORDS_READ) override def incRemoteBlocksFetched(v: Long): Unit = { _remoteBlocksFetched.add(v) @@ -65,3 +73,25 @@ private[spark] class SQLShuffleMetricsReporter( tempMetrics.incRecordsRead(v) } } + +private[spark] object SQLShuffleMetricsReporter { + val REMOTE_BLOCKS_FETCHED = "remoteBlocksFetched" + val LOCAL_BLOCKS_FETCHED = "localBlocksFetched" + val REMOTE_BYTES_READ = "remoteBytesRead" + val REMOTE_BYTES_READ_TO_DISK = "remoteBytesReadToDisk" + val LOCAL_BYTES_READ = "localBytesRead" + val FETCH_WAIT_TIME = "fetchWaitTime" + val RECORDS_READ = "recordsRead" + + /** + * Create all shuffle read relative metrics and return the Map. + */ + def createShuffleReadMetrics(sc: SparkContext): Map[String, SQLMetric] = Map( + REMOTE_BLOCKS_FETCHED -> SQLMetrics.createMetric(sc, "remote blocks fetched"), + LOCAL_BLOCKS_FETCHED -> SQLMetrics.createMetric(sc, "local blocks fetched"), + REMOTE_BYTES_READ -> SQLMetrics.createSizeMetric(sc, "remote bytes read"), + REMOTE_BYTES_READ_TO_DISK -> SQLMetrics.createSizeMetric(sc, "remote bytes read to disk"), + LOCAL_BYTES_READ -> SQLMetrics.createSizeMetric(sc, "local bytes read"), + FETCH_WAIT_TIME -> SQLMetrics.createTimingMetric(sc, "fetch wait time"), + RECORDS_READ -> SQLMetrics.createMetric(sc, "records read")) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 96b3aa5ee75b8..1ad5713ab8ae6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{LocalSparkSession, Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.metric.SQLShuffleMetricsReporter import org.apache.spark.sql.types._ import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.collection.ExternalSorter @@ -140,7 +140,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession { new UnsafeRowSerializer(2)) val shuffled = new ShuffledRowRDD( dependency, - SQLMetrics.getShuffleReadMetrics(spark.sparkContext)) + SQLShuffleMetricsReporter.createShuffleReadMetrics(spark.sparkContext)) shuffled.count() } } From 59741887e272be92ebd6e61783f99f7d8fc05456 Mon Sep 17 00:00:00 2001 From: Wing Yew Poon Date: Thu, 29 Nov 2018 14:56:34 -0600 Subject: [PATCH 2171/2461] [SPARK-25905][CORE] When getting a remote block, avoid forcing a conversion to a ChunkedByteBuffer ## What changes were proposed in this pull request? In `BlockManager`, `getRemoteValues` gets a `ChunkedByteBuffer` (by calling `getRemoteBytes`) and creates an `InputStream` from it. `getRemoteBytes`, in turn, gets a `ManagedBuffer` and converts it to a `ChunkedByteBuffer`. Instead, expose a `getRemoteManagedBuffer` method so `getRemoteValues` can just get this `ManagedBuffer` and use its `InputStream`. When reading a remote cache block from disk, this reduces heap memory usage significantly. Retain `getRemoteBytes` for other callers. ## How was this patch tested? Imran Rashid wrote an application (https://github.com/squito/spark_2gb_test/blob/master/src/main/scala/com/cloudera/sparktest/LargeBlocks.scala), that among other things, tests reading remote cache blocks. I ran this application, using 2500MB blocks, to test reading a cache block on disk. Without this change, with `--executor-memory 5g`, the test fails with `java.lang.OutOfMemoryError: Java heap space`. With the change, the test passes with `--executor-memory 2g`. I also ran the unit tests in core. In particular, `DistributedSuite` has a set of tests that exercise the `getRemoteValues` code path. `BlockManagerSuite` has several tests that call `getRemoteBytes`; I left these unchanged, so `getRemoteBytes` still gets exercised. Closes #23058 from wypoon/SPARK-25905. Authored-by: Wing Yew Poon Signed-off-by: Imran Rashid --- .../apache/spark/storage/BlockManager.scala | 43 ++++++++++++------- .../spark/util/io/ChunkedByteBuffer.scala | 2 - .../org/apache/spark/DistributedSuite.scala | 2 +- 3 files changed, 28 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 1b617297e0a30..1dfbc6effb346 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -692,9 +692,9 @@ private[spark] class BlockManager( */ private def getRemoteValues[T: ClassTag](blockId: BlockId): Option[BlockResult] = { val ct = implicitly[ClassTag[T]] - getRemoteBytes(blockId).map { data => + getRemoteManagedBuffer(blockId).map { data => val values = - serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true))(ct) + serializerManager.dataDeserializeStream(blockId, data.createInputStream())(ct) new BlockResult(values, DataReadMethod.Network, data.size) } } @@ -717,13 +717,9 @@ private[spark] class BlockManager( } /** - * Get block from remote block managers as serialized bytes. + * Get block from remote block managers as a ManagedBuffer. */ - def getRemoteBytes(blockId: BlockId): Option[ChunkedByteBuffer] = { - // TODO SPARK-25905 if we change this method to return the ManagedBuffer, then getRemoteValues - // could just use the inputStream on the temp file, rather than reading the file into memory. - // Until then, replication can cause the process to use too much memory and get killed - // even though we've read the data to disk. + private def getRemoteManagedBuffer(blockId: BlockId): Option[ManagedBuffer] = { logDebug(s"Getting remote block $blockId") require(blockId != null, "BlockId is null") var runningFailureCount = 0 @@ -788,14 +784,13 @@ private[spark] class BlockManager( } if (data != null) { - // SPARK-24307 undocumented "escape-hatch" in case there are any issues in converting to - // ChunkedByteBuffer, to go back to old code-path. Can be removed post Spark 2.4 if - // new path is stable. - if (remoteReadNioBufferConversion) { - return Some(new ChunkedByteBuffer(data.nioByteBuffer())) - } else { - return Some(ChunkedByteBuffer.fromManagedBuffer(data)) - } + // If the ManagedBuffer is a BlockManagerManagedBuffer, the disposal of the + // byte buffers backing it may need to be handled after reading the bytes. + // In this case, since we just fetched the bytes remotely, we do not have + // a BlockManagerManagedBuffer. The assert here is to ensure that this holds + // true (or the disposal is handled). + assert(!data.isInstanceOf[BlockManagerManagedBuffer]) + return Some(data) } logDebug(s"The value of block $blockId is null") } @@ -803,6 +798,22 @@ private[spark] class BlockManager( None } + /** + * Get block from remote block managers as serialized bytes. + */ + def getRemoteBytes(blockId: BlockId): Option[ChunkedByteBuffer] = { + getRemoteManagedBuffer(blockId).map { data => + // SPARK-24307 undocumented "escape-hatch" in case there are any issues in converting to + // ChunkedByteBuffer, to go back to old code-path. Can be removed post Spark 2.4 if + // new path is stable. + if (remoteReadNioBufferConversion) { + new ChunkedByteBuffer(data.nioByteBuffer()) + } else { + ChunkedByteBuffer.fromManagedBuffer(data) + } + } + } + /** * Get a block from the block manager (either local or remote). * diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 128d6ff8cd746..2c3730de08b5b 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -172,8 +172,6 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { private[spark] object ChunkedByteBuffer { - - // TODO SPARK-25905 eliminate this method if we switch BlockManager to getting InputStreams def fromManagedBuffer(data: ManagedBuffer): ChunkedByteBuffer = { data match { case f: FileSegmentManagedBuffer => diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 629a323042ff2..4083b20c23594 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -195,7 +195,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream())(data.elementClassTag).toList assert(deserialized === (1 to 100).toList) } - // This will exercise the getRemoteBytes / getRemoteValues code paths: + // This will exercise the getRemoteValues code path: assert(blockIds.flatMap(id => blockManager.get[Int](id).get.data).toSet === (1 to 1000).toSet) } From f97326bcdba532eabf25d4899b13709e9af2bfea Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Fri, 30 Nov 2018 08:27:55 +0800 Subject: [PATCH 2172/2461] [SPARK-25977][SQL] Parsing decimals from CSV using locale ## What changes were proposed in this pull request? In the PR, I propose using of the locale option to parse decimals from CSV input. After the changes, `UnivocityParser` converts input string to `BigDecimal` and to Spark's Decimal by using `java.text.DecimalFormat`. ## How was this patch tested? Added a test for the `en-US`, `ko-KR`, `ru-RU`, `de-DE` locales. Closes #22979 from MaxGekk/decimal-parsing-locale. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: hyukjinkwon --- .../spark/sql/catalyst/csv/CSVExprUtils.scala | 4 + .../sql/catalyst/csv/CSVInferSchema.scala | 72 ++++----- .../sql/catalyst/csv/UnivocityParser.scala | 8 +- .../catalyst/expressions/csvExpressions.scala | 5 +- .../catalyst/csv/CSVInferSchemaSuite.scala | 147 ++++++++++++------ .../catalyst/csv/UnivocityParserSuite.scala | 22 ++- .../datasources/csv/CSVDataSource.scala | 4 +- 7 files changed, 168 insertions(+), 94 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala index bbe27831f01df..6c982a1de9a48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.catalyst.csv +import java.math.BigDecimal +import java.text.{DecimalFormat, DecimalFormatSymbols, ParsePosition} +import java.util.Locale + object CSVExprUtils { /** * Filter ignorable rows for CSV iterator (lines empty and starting with `comment`). diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index 799e9994451b2..94cb4b114e6b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -17,16 +17,19 @@ package org.apache.spark.sql.catalyst.csv -import java.math.BigDecimal - import scala.util.control.Exception.allCatch import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion +import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -object CSVInferSchema { +class CSVInferSchema(options: CSVOptions) extends Serializable { + + private val decimalParser = { + ExprUtils.getDecimalParser(options.locale) + } /** * Similar to the JSON schema inference @@ -36,14 +39,13 @@ object CSVInferSchema { */ def infer( tokenRDD: RDD[Array[String]], - header: Array[String], - options: CSVOptions): StructType = { + header: Array[String]): StructType = { val fields = if (options.inferSchemaFlag) { val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) val rootTypes: Array[DataType] = - tokenRDD.aggregate(startType)(inferRowType(options), mergeRowTypes) + tokenRDD.aggregate(startType)(inferRowType, mergeRowTypes) - toStructFields(rootTypes, header, options) + toStructFields(rootTypes, header) } else { // By default fields are assumed to be StringType header.map(fieldName => StructField(fieldName, StringType, nullable = true)) @@ -54,8 +56,7 @@ object CSVInferSchema { def toStructFields( fieldTypes: Array[DataType], - header: Array[String], - options: CSVOptions): Array[StructField] = { + header: Array[String]): Array[StructField] = { header.zip(fieldTypes).map { case (thisHeader, rootType) => val dType = rootType match { case _: NullType => StringType @@ -65,11 +66,10 @@ object CSVInferSchema { } } - def inferRowType(options: CSVOptions) - (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = { + def inferRowType(rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = { var i = 0 while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing. - rowSoFar(i) = inferField(rowSoFar(i), next(i), options) + rowSoFar(i) = inferField(rowSoFar(i), next(i)) i+=1 } rowSoFar @@ -85,20 +85,20 @@ object CSVInferSchema { * Infer type of string field. Given known type Double, and a string "1", there is no * point checking if it is an Int, as the final type must be Double or higher. */ - def inferField(typeSoFar: DataType, field: String, options: CSVOptions): DataType = { + def inferField(typeSoFar: DataType, field: String): DataType = { if (field == null || field.isEmpty || field == options.nullValue) { typeSoFar } else { typeSoFar match { - case NullType => tryParseInteger(field, options) - case IntegerType => tryParseInteger(field, options) - case LongType => tryParseLong(field, options) + case NullType => tryParseInteger(field) + case IntegerType => tryParseInteger(field) + case LongType => tryParseLong(field) case _: DecimalType => // DecimalTypes have different precisions and scales, so we try to find the common type. - compatibleType(typeSoFar, tryParseDecimal(field, options)).getOrElse(StringType) - case DoubleType => tryParseDouble(field, options) - case TimestampType => tryParseTimestamp(field, options) - case BooleanType => tryParseBoolean(field, options) + compatibleType(typeSoFar, tryParseDecimal(field)).getOrElse(StringType) + case DoubleType => tryParseDouble(field) + case TimestampType => tryParseTimestamp(field) + case BooleanType => tryParseBoolean(field) case StringType => StringType case other: DataType => throw new UnsupportedOperationException(s"Unexpected data type $other") @@ -106,30 +106,30 @@ object CSVInferSchema { } } - private def isInfOrNan(field: String, options: CSVOptions): Boolean = { + private def isInfOrNan(field: String): Boolean = { field == options.nanValue || field == options.negativeInf || field == options.positiveInf } - private def tryParseInteger(field: String, options: CSVOptions): DataType = { + private def tryParseInteger(field: String): DataType = { if ((allCatch opt field.toInt).isDefined) { IntegerType } else { - tryParseLong(field, options) + tryParseLong(field) } } - private def tryParseLong(field: String, options: CSVOptions): DataType = { + private def tryParseLong(field: String): DataType = { if ((allCatch opt field.toLong).isDefined) { LongType } else { - tryParseDecimal(field, options) + tryParseDecimal(field) } } - private def tryParseDecimal(field: String, options: CSVOptions): DataType = { + private def tryParseDecimal(field: String): DataType = { val decimalTry = allCatch opt { - // `BigDecimal` conversion can fail when the `field` is not a form of number. - val bigDecimal = new BigDecimal(field) + // The conversion can fail when the `field` is not a form of number. + val bigDecimal = decimalParser(field) // Because many other formats do not support decimal, it reduces the cases for // decimals by disallowing values having scale (eg. `1.1`). if (bigDecimal.scale <= 0) { @@ -138,21 +138,21 @@ object CSVInferSchema { // 2. scale is bigger than precision. DecimalType(bigDecimal.precision, bigDecimal.scale) } else { - tryParseDouble(field, options) + tryParseDouble(field) } } - decimalTry.getOrElse(tryParseDouble(field, options)) + decimalTry.getOrElse(tryParseDouble(field)) } - private def tryParseDouble(field: String, options: CSVOptions): DataType = { - if ((allCatch opt field.toDouble).isDefined || isInfOrNan(field, options)) { + private def tryParseDouble(field: String): DataType = { + if ((allCatch opt field.toDouble).isDefined || isInfOrNan(field)) { DoubleType } else { - tryParseTimestamp(field, options) + tryParseTimestamp(field) } } - private def tryParseTimestamp(field: String, options: CSVOptions): DataType = { + private def tryParseTimestamp(field: String): DataType = { // This case infers a custom `dataFormat` is set. if ((allCatch opt options.timestampFormat.parse(field)).isDefined) { TimestampType @@ -160,11 +160,11 @@ object CSVInferSchema { // We keep this for backwards compatibility. TimestampType } else { - tryParseBoolean(field, options) + tryParseBoolean(field) } } - private def tryParseBoolean(field: String, options: CSVOptions): DataType = { + private def tryParseBoolean(field: String): DataType = { if ((allCatch opt field.toBoolean).isDefined) { BooleanType } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index ed196935e357f..85e129224c913 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.csv import java.io.InputStream -import java.math.BigDecimal import scala.util.Try import scala.util.control.NonFatal @@ -27,7 +26,7 @@ import com.univocity.parsers.csv.CsvParser import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericInternalRow} import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils, FailureSafeParser} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -104,6 +103,8 @@ class UnivocityParser( requiredSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray } + private val decimalParser = ExprUtils.getDecimalParser(options.locale) + /** * Create a converter which converts the string value to a value according to a desired type. * Currently, we do not support complex types (`ArrayType`, `MapType`, `StructType`). @@ -149,8 +150,7 @@ class UnivocityParser( case dt: DecimalType => (d: String) => nullSafeDatum(d, name, nullable, options) { datum => - val value = new BigDecimal(datum.replaceAll(",", "")) - Decimal(value, dt.precision, dt.scale) + Decimal(decimalParser(datum), dt.precision, dt.scale) } case _: TimestampType => (d: String) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index 1e4e1c663c90e..83b0299bac440 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -180,8 +180,9 @@ case class SchemaOfCsv( val header = row.zipWithIndex.map { case (_, index) => s"_c$index" } val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) - val fieldTypes = CSVInferSchema.inferRowType(parsedOptions)(startType, row) - val st = StructType(CSVInferSchema.toStructFields(fieldTypes, header, parsedOptions)) + val inferSchema = new CSVInferSchema(parsedOptions) + val fieldTypes = inferSchema.inferRowType(startType, row) + val st = StructType(inferSchema.toStructFields(fieldTypes, header)) UTF8String.fromString(st.catalogString) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala index 651846d2ebcb5..1a020e67a75b0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala @@ -17,126 +17,175 @@ package org.apache.spark.sql.catalyst.csv +import java.text.{DecimalFormat, DecimalFormatSymbols} +import java.util.Locale + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -class CSVInferSchemaSuite extends SparkFunSuite { +class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper { test("String fields types are inferred correctly from null types") { val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assert(CSVInferSchema.inferField(NullType, "", options) == NullType) - assert(CSVInferSchema.inferField(NullType, null, options) == NullType) - assert(CSVInferSchema.inferField(NullType, "100000000000", options) == LongType) - assert(CSVInferSchema.inferField(NullType, "60", options) == IntegerType) - assert(CSVInferSchema.inferField(NullType, "3.5", options) == DoubleType) - assert(CSVInferSchema.inferField(NullType, "test", options) == StringType) - assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00", options) == TimestampType) - assert(CSVInferSchema.inferField(NullType, "True", options) == BooleanType) - assert(CSVInferSchema.inferField(NullType, "FAlSE", options) == BooleanType) + val inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(NullType, "") == NullType) + assert(inferSchema.inferField(NullType, null) == NullType) + assert(inferSchema.inferField(NullType, "100000000000") == LongType) + assert(inferSchema.inferField(NullType, "60") == IntegerType) + assert(inferSchema.inferField(NullType, "3.5") == DoubleType) + assert(inferSchema.inferField(NullType, "test") == StringType) + assert(inferSchema.inferField(NullType, "2015-08-20 15:57:00") == TimestampType) + assert(inferSchema.inferField(NullType, "True") == BooleanType) + assert(inferSchema.inferField(NullType, "FAlSE") == BooleanType) val textValueOne = Long.MaxValue.toString + "0" val decimalValueOne = new java.math.BigDecimal(textValueOne) val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale) - assert(CSVInferSchema.inferField(NullType, textValueOne, options) == expectedTypeOne) + assert(inferSchema.inferField(NullType, textValueOne) == expectedTypeOne) } test("String fields types are inferred correctly from other types") { val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assert(CSVInferSchema.inferField(LongType, "1.0", options) == DoubleType) - assert(CSVInferSchema.inferField(LongType, "test", options) == StringType) - assert(CSVInferSchema.inferField(IntegerType, "1.0", options) == DoubleType) - assert(CSVInferSchema.inferField(DoubleType, null, options) == DoubleType) - assert(CSVInferSchema.inferField(DoubleType, "test", options) == StringType) - assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00", options) == TimestampType) - assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00", options) == TimestampType) - assert(CSVInferSchema.inferField(LongType, "True", options) == BooleanType) - assert(CSVInferSchema.inferField(IntegerType, "FALSE", options) == BooleanType) - assert(CSVInferSchema.inferField(TimestampType, "FALSE", options) == BooleanType) + val inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(LongType, "1.0") == DoubleType) + assert(inferSchema.inferField(LongType, "test") == StringType) + assert(inferSchema.inferField(IntegerType, "1.0") == DoubleType) + assert(inferSchema.inferField(DoubleType, null) == DoubleType) + assert(inferSchema.inferField(DoubleType, "test") == StringType) + assert(inferSchema.inferField(LongType, "2015-08-20 14:57:00") == TimestampType) + assert(inferSchema.inferField(DoubleType, "2015-08-20 15:57:00") == TimestampType) + assert(inferSchema.inferField(LongType, "True") == BooleanType) + assert(inferSchema.inferField(IntegerType, "FALSE") == BooleanType) + assert(inferSchema.inferField(TimestampType, "FALSE") == BooleanType) val textValueOne = Long.MaxValue.toString + "0" val decimalValueOne = new java.math.BigDecimal(textValueOne) val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale) - assert(CSVInferSchema.inferField(IntegerType, textValueOne, options) == expectedTypeOne) + assert(inferSchema.inferField(IntegerType, textValueOne) == expectedTypeOne) } test("Timestamp field types are inferred correctly via custom data format") { var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm"), false, "GMT") - assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) + var inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(TimestampType, "2015-08") == TimestampType) + options = new CSVOptions(Map("timestampFormat" -> "yyyy"), false, "GMT") - assert(CSVInferSchema.inferField(TimestampType, "2015", options) == TimestampType) + inferSchema = new CSVInferSchema(options) + assert(inferSchema.inferField(TimestampType, "2015") == TimestampType) } test("Timestamp field types are inferred correctly from other types") { val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14", options) == StringType) - assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10", options) == StringType) - assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00", options) == StringType) + val inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(IntegerType, "2015-08-20 14") == StringType) + assert(inferSchema.inferField(DoubleType, "2015-08-20 14:10") == StringType) + assert(inferSchema.inferField(LongType, "2015-08 14:49:00") == StringType) } test("Boolean fields types are inferred correctly from other types") { val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assert(CSVInferSchema.inferField(LongType, "Fale", options) == StringType) - assert(CSVInferSchema.inferField(DoubleType, "TRUEe", options) == StringType) + val inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(LongType, "Fale") == StringType) + assert(inferSchema.inferField(DoubleType, "TRUEe") == StringType) } test("Type arrays are merged to highest common type") { + val options = new CSVOptions(Map.empty[String, String], false, "GMT") + val inferSchema = new CSVInferSchema(options) + assert( - CSVInferSchema.mergeRowTypes(Array(StringType), + inferSchema.mergeRowTypes(Array(StringType), Array(DoubleType)).deep == Array(StringType).deep) assert( - CSVInferSchema.mergeRowTypes(Array(IntegerType), + inferSchema.mergeRowTypes(Array(IntegerType), Array(LongType)).deep == Array(LongType).deep) assert( - CSVInferSchema.mergeRowTypes(Array(DoubleType), + inferSchema.mergeRowTypes(Array(DoubleType), Array(LongType)).deep == Array(DoubleType).deep) } test("Null fields are handled properly when a nullValue is specified") { var options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT") - assert(CSVInferSchema.inferField(NullType, "null", options) == NullType) - assert(CSVInferSchema.inferField(StringType, "null", options) == StringType) - assert(CSVInferSchema.inferField(LongType, "null", options) == LongType) + var inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(NullType, "null") == NullType) + assert(inferSchema.inferField(StringType, "null") == StringType) + assert(inferSchema.inferField(LongType, "null") == LongType) options = new CSVOptions(Map("nullValue" -> "\\N"), false, "GMT") - assert(CSVInferSchema.inferField(IntegerType, "\\N", options) == IntegerType) - assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType) - assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType) - assert(CSVInferSchema.inferField(BooleanType, "\\N", options) == BooleanType) - assert(CSVInferSchema.inferField(DecimalType(1, 1), "\\N", options) == DecimalType(1, 1)) + inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(IntegerType, "\\N") == IntegerType) + assert(inferSchema.inferField(DoubleType, "\\N") == DoubleType) + assert(inferSchema.inferField(TimestampType, "\\N") == TimestampType) + assert(inferSchema.inferField(BooleanType, "\\N") == BooleanType) + assert(inferSchema.inferField(DecimalType(1, 1), "\\N") == DecimalType(1, 1)) } test("Merging Nulltypes should yield Nulltype.") { - val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType)) + val options = new CSVOptions(Map.empty[String, String], false, "GMT") + val inferSchema = new CSVInferSchema(options) + + val mergedNullTypes = inferSchema.mergeRowTypes(Array(NullType), Array(NullType)) assert(mergedNullTypes.deep == Array(NullType).deep) } test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm"), false, "GMT") - assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) + val inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(TimestampType, "2015-08") == TimestampType) } test("SPARK-18877: `inferField` on DecimalType should find a common type with `typeSoFar`") { val options = new CSVOptions(Map.empty[String, String], false, "GMT") + val inferSchema = new CSVInferSchema(options) // 9.03E+12 is Decimal(3, -10) and 1.19E+11 is Decimal(3, -9). - assert(CSVInferSchema.inferField(DecimalType(3, -10), "1.19E+11", options) == + assert(inferSchema.inferField(DecimalType(3, -10), "1.19E11") == DecimalType(4, -9)) // BigDecimal("12345678901234567890.01234567890123456789") is precision 40 and scale 20. val value = "12345678901234567890.01234567890123456789" - assert(CSVInferSchema.inferField(DecimalType(3, -10), value, options) == DoubleType) + assert(inferSchema.inferField(DecimalType(3, -10), value) == DoubleType) // Seq(s"${Long.MaxValue}1", "2015-12-01 00:00:00") should be StringType - assert(CSVInferSchema.inferField(NullType, s"${Long.MaxValue}1", options) == DecimalType(20, 0)) - assert(CSVInferSchema.inferField(DecimalType(20, 0), "2015-12-01 00:00:00", options) + assert(inferSchema.inferField(NullType, s"${Long.MaxValue}1") == DecimalType(20, 0)) + assert(inferSchema.inferField(DecimalType(20, 0), "2015-12-01 00:00:00") == StringType) } test("DoubleType should be inferred when user defined nan/inf are provided") { val options = new CSVOptions(Map("nanValue" -> "nan", "negativeInf" -> "-inf", "positiveInf" -> "inf"), false, "GMT") - assert(CSVInferSchema.inferField(NullType, "nan", options) == DoubleType) - assert(CSVInferSchema.inferField(NullType, "inf", options) == DoubleType) - assert(CSVInferSchema.inferField(NullType, "-inf", options) == DoubleType) + val inferSchema = new CSVInferSchema(options) + + assert(inferSchema.inferField(NullType, "nan") == DoubleType) + assert(inferSchema.inferField(NullType, "inf") == DoubleType) + assert(inferSchema.inferField(NullType, "-inf") == DoubleType) + } + + test("inferring the decimal type using locale") { + def checkDecimalInfer(langTag: String, expectedType: DataType): Unit = { + val options = new CSVOptions( + parameters = Map("locale" -> langTag, "inferSchema" -> "true", "sep" -> "|"), + columnPruning = false, + defaultTimeZoneId = "GMT") + val inferSchema = new CSVInferSchema(options) + + val df = new DecimalFormat("", new DecimalFormatSymbols(Locale.forLanguageTag(langTag))) + val input = df.format(Decimal(1000001).toBigDecimal) + + assert(inferSchema.inferField(NullType, input) == expectedType) + } + + Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(checkDecimalInfer(_, DecimalType(7, 0))) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala index e4e7dc2e8c0e6..7212402ef5cff 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala @@ -18,13 +18,17 @@ package org.apache.spark.sql.catalyst.csv import java.math.BigDecimal +import java.text.{DecimalFormat, DecimalFormatSymbols} +import java.util.Locale import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -class UnivocityParserSuite extends SparkFunSuite { +class UnivocityParserSuite extends SparkFunSuite with SQLHelper { private val parser = new UnivocityParser( StructType(Seq.empty), new CSVOptions(Map.empty[String, String], false, "GMT")) @@ -196,4 +200,20 @@ class UnivocityParserSuite extends SparkFunSuite { assert(doubleVal2 == Double.PositiveInfinity) } + test("parse decimals using locale") { + def checkDecimalParsing(langTag: String): Unit = { + val decimalVal = new BigDecimal("1000.001") + val decimalType = new DecimalType(10, 5) + val expected = Decimal(decimalVal, decimalType.precision, decimalType.scale) + val df = new DecimalFormat("", new DecimalFormatSymbols(Locale.forLanguageTag(langTag))) + val input = df.format(expected.toBigDecimal) + + val options = new CSVOptions(Map("locale" -> langTag), false, "GMT") + val parser = new UnivocityParser(new StructType().add("d", decimalType), options) + + assert(parser.makeConverter("_1", decimalType, options = options).apply(input) === expected) + } + + Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(checkDecimalParsing) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index b35b8851918b1..b46dfb94c133e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -135,7 +135,7 @@ object TextInputCSVDataSource extends CSVDataSource { val parser = new CsvParser(parsedOptions.asParserSettings) linesWithoutHeader.map(parser.parseLine) } - CSVInferSchema.infer(tokenRDD, header, parsedOptions) + new CSVInferSchema(parsedOptions).infer(tokenRDD, header) case _ => // If the first line could not be read, just return the empty schema. StructType(Nil) @@ -208,7 +208,7 @@ object MultiLineCSVDataSource extends CSVDataSource { encoding = parsedOptions.charset) } val sampled = CSVUtils.sample(tokenRDD, parsedOptions) - CSVInferSchema.infer(sampled, header, parsedOptions) + new CSVInferSchema(parsedOptions).infer(sampled, header) case None => // If the first row could not be read, just return the empty schema. StructType(Nil) From 0166c7373eee2654c49c210927e4e290d103f24f Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Thu, 29 Nov 2018 18:00:47 -0800 Subject: [PATCH 2173/2461] [SPARK-25501][SS] Add kafka delegation token support. ## What changes were proposed in this pull request? It adds kafka delegation token support for structured streaming. Please see the relevant [SPIP](https://docs.google.com/document/d/1ouRayzaJf_N5VQtGhVq9FURXVmRpXzEEWYHob0ne3NY/edit?usp=sharing) What this PR contains: * Configuration parameters for the feature * Delegation token fetching from broker * Usage of token through dynamic JAAS configuration * Minor refactoring in the existing code What this PR doesn't contain: * Documentation changes because design can change ## How was this patch tested? Existing tests + added small amount of additional unit tests. Because it's an external service integration mainly tested on cluster. * 4 node cluster * Kafka broker version 1.1.0 * Topic with 4 partitions * security.protocol = SASL_SSL * sasl.mechanism = SCRAM-SHA-256 An example of obtaining a token: ``` 18/10/01 01:07:49 INFO kafka010.TokenUtil: TOKENID HMAC OWNER RENEWERS ISSUEDATE EXPIRYDATE MAXDATE 18/10/01 01:07:49 INFO kafka010.TokenUtil: D1-v__Q5T_uHx55rW16Jwg [hidden] User:user [] 2018-10-01T01:07 2018-10-02T01:07 2018-10-08T01:07 18/10/01 01:07:49 INFO security.KafkaDelegationTokenProvider: Get token from Kafka: Kind: KAFKA_DELEGATION_TOKEN, Service: kafka.server.delegation.token, Ident: 44 31 2d 76 5f 5f 51 35 54 5f 75 48 78 35 35 72 57 31 36 4a 77 67 ``` An example token usage: ``` 18/10/01 01:08:07 INFO kafka010.KafkaSecurityHelper: Scram JAAS params: org.apache.kafka.common.security.scram.ScramLoginModule required tokenauth=true serviceName="kafka" username="D1-v__Q5T_uHx55rW16Jwg" password="[hidden]"; 18/10/01 01:08:07 INFO kafka010.KafkaSourceProvider: Delegation token detected, using it for login. ``` Closes #22598 from gaborgsomogyi/SPARK-25501. Authored-by: Gabor Somogyi Signed-off-by: Marcelo Vanzin --- core/pom.xml | 13 + .../HadoopDelegationTokenManager.scala | 3 +- .../KafkaDelegationTokenProvider.scala | 61 +++++ .../deploy/security/KafkaTokenUtil.scala | 202 +++++++++++++++ .../apache/spark/internal/config/Kafka.scala | 82 ++++++ .../HadoopDelegationTokenManagerSuite.scala | 5 +- .../deploy/security/KafkaTokenUtilSuite.scala | 239 ++++++++++++++++++ external/kafka-0-10-sql/pom.xml | 2 - .../sql/kafka010/KafkaSecurityHelper.scala | 56 ++++ .../sql/kafka010/KafkaSourceProvider.scala | 82 +++--- .../kafka010/KafkaStreamingWriteSupport.scala | 22 +- .../kafka010/KafkaContinuousSinkSuite.scala | 4 +- .../kafka010/KafkaSecurityHelperSuite.scala | 100 ++++++++ external/kafka-0-10/pom.xml | 2 - pom.xml | 2 + 15 files changed, 825 insertions(+), 50 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/deploy/security/KafkaDelegationTokenProvider.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/security/KafkaTokenUtil.scala create mode 100644 core/src/main/scala/org/apache/spark/internal/config/Kafka.scala create mode 100644 core/src/test/scala/org/apache/spark/deploy/security/KafkaTokenUtilSuite.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelper.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelperSuite.scala diff --git a/core/pom.xml b/core/pom.xml index 36d93212ba9f9..49b1a54e32598 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -408,6 +408,19 @@ provided + + + org.apache.kafka + kafka-clients + ${kafka.version} + provided + + target/scala-${scala.binary.version}/classes diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index 1169b2878e993..126a6ab801369 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -274,7 +274,8 @@ private[spark] class HadoopDelegationTokenManager( new HadoopFSDelegationTokenProvider( () => HadoopDelegationTokenManager.this.fileSystemsToAccess())) ++ safeCreateProvider(new HiveDelegationTokenProvider) ++ - safeCreateProvider(new HBaseDelegationTokenProvider) + safeCreateProvider(new HBaseDelegationTokenProvider) ++ + safeCreateProvider(new KafkaDelegationTokenProvider) // Filter out providers for which spark.security.credentials.{service}.enabled is false. providers diff --git a/core/src/main/scala/org/apache/spark/deploy/security/KafkaDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/KafkaDelegationTokenProvider.scala new file mode 100644 index 0000000000000..45995be630cc5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/security/KafkaDelegationTokenProvider.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.security + +import scala.language.existentials +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.security.Credentials +import org.apache.kafka.common.security.auth.SecurityProtocol.{SASL_PLAINTEXT, SASL_SSL, SSL} + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ + +private[security] class KafkaDelegationTokenProvider + extends HadoopDelegationTokenProvider with Logging { + + override def serviceName: String = "kafka" + + override def obtainDelegationTokens( + hadoopConf: Configuration, + sparkConf: SparkConf, + creds: Credentials): Option[Long] = { + try { + logDebug("Attempting to fetch Kafka security token.") + val (token, nextRenewalDate) = KafkaTokenUtil.obtainToken(sparkConf) + creds.addToken(token.getService, token) + return Some(nextRenewalDate) + } catch { + case NonFatal(e) => + logInfo(s"Failed to get token from service $serviceName", e) + } + None + } + + override def delegationTokensRequired( + sparkConf: SparkConf, + hadoopConf: Configuration): Boolean = { + val protocol = sparkConf.get(Kafka.SECURITY_PROTOCOL) + sparkConf.contains(Kafka.BOOTSTRAP_SERVERS) && + (protocol == SASL_SSL.name || + protocol == SSL.name || + protocol == SASL_PLAINTEXT.name) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/security/KafkaTokenUtil.scala b/core/src/main/scala/org/apache/spark/deploy/security/KafkaTokenUtil.scala new file mode 100644 index 0000000000000..c890cee59ffe0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/security/KafkaTokenUtil.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.security + +import java.{ util => ju } +import java.text.SimpleDateFormat + +import scala.util.control.NonFatal + +import org.apache.hadoop.io.Text +import org.apache.hadoop.security.token.{Token, TokenIdentifier} +import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdentifier +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.clients.admin.{AdminClient, CreateDelegationTokenOptions} +import org.apache.kafka.common.config.SaslConfigs +import org.apache.kafka.common.security.JaasContext +import org.apache.kafka.common.security.auth.SecurityProtocol.{SASL_PLAINTEXT, SASL_SSL, SSL} +import org.apache.kafka.common.security.token.delegation.DelegationToken + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ + +private[spark] object KafkaTokenUtil extends Logging { + val TOKEN_KIND = new Text("KAFKA_DELEGATION_TOKEN") + val TOKEN_SERVICE = new Text("kafka.server.delegation.token") + + private[spark] class KafkaDelegationTokenIdentifier extends AbstractDelegationTokenIdentifier { + override def getKind: Text = TOKEN_KIND + } + + private[security] def obtainToken(sparkConf: SparkConf): (Token[_ <: TokenIdentifier], Long) = { + val adminClient = AdminClient.create(createAdminClientProperties(sparkConf)) + val createDelegationTokenOptions = new CreateDelegationTokenOptions() + val createResult = adminClient.createDelegationToken(createDelegationTokenOptions) + val token = createResult.delegationToken().get() + printToken(token) + + (new Token[KafkaDelegationTokenIdentifier]( + token.tokenInfo.tokenId.getBytes, + token.hmacAsBase64String.getBytes, + TOKEN_KIND, + TOKEN_SERVICE + ), token.tokenInfo.expiryTimestamp) + } + + private[security] def createAdminClientProperties(sparkConf: SparkConf): ju.Properties = { + val adminClientProperties = new ju.Properties + + val bootstrapServers = sparkConf.get(Kafka.BOOTSTRAP_SERVERS) + require(bootstrapServers.nonEmpty, s"Tried to obtain kafka delegation token but bootstrap " + + "servers not configured.") + adminClientProperties.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers.get) + + val protocol = sparkConf.get(Kafka.SECURITY_PROTOCOL) + adminClientProperties.put(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, protocol) + protocol match { + case SASL_SSL.name => + setTrustStoreProperties(sparkConf, adminClientProperties) + + case SSL.name => + setTrustStoreProperties(sparkConf, adminClientProperties) + setKeyStoreProperties(sparkConf, adminClientProperties) + logWarning("Obtaining kafka delegation token with SSL protocol. Please " + + "configure 2-way authentication on the broker side.") + + case SASL_PLAINTEXT.name => + logWarning("Obtaining kafka delegation token through plain communication channel. Please " + + "consider the security impact.") + } + + // There are multiple possibilities to log in and applied in the following order: + // - JVM global security provided -> try to log in with JVM global security configuration + // which can be configured for example with 'java.security.auth.login.config'. + // For this no additional parameter needed. + // - Keytab is provided -> try to log in with kerberos module and keytab using kafka's dynamic + // JAAS configuration. + // - Keytab not provided -> try to log in with kerberos module and ticket cache using kafka's + // dynamic JAAS configuration. + // Kafka client is unable to use subject from JVM which already logged in + // to kdc (see KAFKA-7677) + if (isGlobalJaasConfigurationProvided) { + logDebug("JVM global security configuration detected, using it for login.") + } else { + adminClientProperties.put(SaslConfigs.SASL_MECHANISM, SaslConfigs.GSSAPI_MECHANISM) + if (sparkConf.contains(KEYTAB)) { + logDebug("Keytab detected, using it for login.") + val jaasParams = getKeytabJaasParams(sparkConf) + adminClientProperties.put(SaslConfigs.SASL_JAAS_CONFIG, jaasParams) + } else { + logDebug("Using ticket cache for login.") + val jaasParams = getTicketCacheJaasParams(sparkConf) + adminClientProperties.put(SaslConfigs.SASL_JAAS_CONFIG, jaasParams) + } + } + + adminClientProperties + } + + def isGlobalJaasConfigurationProvided: Boolean = { + try { + JaasContext.loadClientContext(ju.Collections.emptyMap[String, Object]()) + true + } catch { + case NonFatal(_) => false + } + } + + private def setTrustStoreProperties(sparkConf: SparkConf, properties: ju.Properties): Unit = { + sparkConf.get(Kafka.TRUSTSTORE_LOCATION).foreach { truststoreLocation => + properties.put("ssl.truststore.location", truststoreLocation) + } + sparkConf.get(Kafka.TRUSTSTORE_PASSWORD).foreach { truststorePassword => + properties.put("ssl.truststore.password", truststorePassword) + } + } + + private def setKeyStoreProperties(sparkConf: SparkConf, properties: ju.Properties): Unit = { + sparkConf.get(Kafka.KEYSTORE_LOCATION).foreach { keystoreLocation => + properties.put("ssl.keystore.location", keystoreLocation) + } + sparkConf.get(Kafka.KEYSTORE_PASSWORD).foreach { keystorePassword => + properties.put("ssl.keystore.password", keystorePassword) + } + sparkConf.get(Kafka.KEY_PASSWORD).foreach { keyPassword => + properties.put("ssl.key.password", keyPassword) + } + } + + private[security] def getKeytabJaasParams(sparkConf: SparkConf): String = { + val serviceName = sparkConf.get(Kafka.KERBEROS_SERVICE_NAME) + require(serviceName.nonEmpty, "Kerberos service name must be defined") + + val params = + s""" + |${getKrb5LoginModuleName} required + | useKeyTab=true + | serviceName="${serviceName.get}" + | keyTab="${sparkConf.get(KEYTAB).get}" + | principal="${sparkConf.get(PRINCIPAL).get}"; + """.stripMargin.replace("\n", "") + logDebug(s"Krb keytab JAAS params: $params") + params + } + + def getTicketCacheJaasParams(sparkConf: SparkConf): String = { + val serviceName = sparkConf.get(Kafka.KERBEROS_SERVICE_NAME) + require(serviceName.nonEmpty, "Kerberos service name must be defined") + + val params = + s""" + |${getKrb5LoginModuleName} required + | useTicketCache=true + | serviceName="${serviceName.get}"; + """.stripMargin.replace("\n", "") + logDebug(s"Krb ticket cache JAAS params: $params") + params + } + + /** + * Krb5LoginModule package vary in different JVMs. + * Please see Hadoop UserGroupInformation for further details. + */ + private def getKrb5LoginModuleName(): String = { + if (System.getProperty("java.vendor").contains("IBM")) { + "com.ibm.security.auth.module.Krb5LoginModule" + } else { + "com.sun.security.auth.module.Krb5LoginModule" + } + } + + private def printToken(token: DelegationToken): Unit = { + if (log.isDebugEnabled) { + val dateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm") + logDebug("%-15s %-30s %-15s %-25s %-15s %-15s %-15s".format( + "TOKENID", "HMAC", "OWNER", "RENEWERS", "ISSUEDATE", "EXPIRYDATE", "MAXDATE")) + val tokenInfo = token.tokenInfo + logDebug("%-15s [hidden] %-15s %-25s %-15s %-15s %-15s".format( + tokenInfo.tokenId, + tokenInfo.owner, + tokenInfo.renewersAsString, + dateFormat.format(tokenInfo.issueTimestamp), + dateFormat.format(tokenInfo.expiryTimestamp), + dateFormat.format(tokenInfo.maxTimestamp))) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/config/Kafka.scala b/core/src/main/scala/org/apache/spark/internal/config/Kafka.scala new file mode 100644 index 0000000000000..85d74c27142ad --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/config/Kafka.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.config + +private[spark] object Kafka { + + val BOOTSTRAP_SERVERS = + ConfigBuilder("spark.kafka.bootstrap.servers") + .doc("A list of coma separated host/port pairs to use for establishing the initial " + + "connection to the Kafka cluster. For further details please see kafka documentation. " + + "Only used to obtain delegation token.") + .stringConf + .createOptional + + val SECURITY_PROTOCOL = + ConfigBuilder("spark.kafka.security.protocol") + .doc("Protocol used to communicate with brokers. For further details please see kafka " + + "documentation. Only used to obtain delegation token.") + .stringConf + .createWithDefault("SASL_SSL") + + val KERBEROS_SERVICE_NAME = + ConfigBuilder("spark.kafka.sasl.kerberos.service.name") + .doc("The Kerberos principal name that Kafka runs as. This can be defined either in " + + "Kafka's JAAS config or in Kafka's config. For further details please see kafka " + + "documentation. Only used to obtain delegation token.") + .stringConf + .createOptional + + val TRUSTSTORE_LOCATION = + ConfigBuilder("spark.kafka.ssl.truststore.location") + .doc("The location of the trust store file. For further details please see kafka " + + "documentation. Only used to obtain delegation token.") + .stringConf + .createOptional + + val TRUSTSTORE_PASSWORD = + ConfigBuilder("spark.kafka.ssl.truststore.password") + .doc("The store password for the trust store file. This is optional for client and only " + + "needed if ssl.truststore.location is configured. For further details please see kafka " + + "documentation. Only used to obtain delegation token.") + .stringConf + .createOptional + + val KEYSTORE_LOCATION = + ConfigBuilder("spark.kafka.ssl.keystore.location") + .doc("The location of the key store file. This is optional for client and can be used for " + + "two-way authentication for client. For further details please see kafka documentation. " + + "Only used to obtain delegation token.") + .stringConf + .createOptional + + val KEYSTORE_PASSWORD = + ConfigBuilder("spark.kafka.ssl.keystore.password") + .doc("The store password for the key store file. This is optional for client and only " + + "needed if ssl.keystore.location is configured. For further details please see kafka " + + "documentation. Only used to obtain delegation token.") + .stringConf + .createOptional + + val KEY_PASSWORD = + ConfigBuilder("spark.kafka.ssl.key.password") + .doc("The password of the private key in the key store file. This is optional for client. " + + "For further details please see kafka documentation. Only used to obtain delegation token.") + .stringConf + .createOptional +} diff --git a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala index e0e630e3be63b..def9e626a2df2 100644 --- a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.deploy.security import org.apache.commons.io.IOUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.security.Credentials import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.util.Utils @@ -33,6 +31,7 @@ class HadoopDelegationTokenManagerSuite extends SparkFunSuite { assert(manager.isProviderLoaded("hadoopfs")) assert(manager.isProviderLoaded("hbase")) assert(manager.isProviderLoaded("hive")) + assert(manager.isProviderLoaded("kafka")) } test("disable hive credential provider") { @@ -41,6 +40,7 @@ class HadoopDelegationTokenManagerSuite extends SparkFunSuite { assert(manager.isProviderLoaded("hadoopfs")) assert(manager.isProviderLoaded("hbase")) assert(!manager.isProviderLoaded("hive")) + assert(manager.isProviderLoaded("kafka")) } test("using deprecated configurations") { @@ -51,6 +51,7 @@ class HadoopDelegationTokenManagerSuite extends SparkFunSuite { assert(!manager.isProviderLoaded("hadoopfs")) assert(manager.isProviderLoaded("hbase")) assert(!manager.isProviderLoaded("hive")) + assert(manager.isProviderLoaded("kafka")) } test("SPARK-23209: obtain tokens when Hive classes are not available") { diff --git a/core/src/test/scala/org/apache/spark/deploy/security/KafkaTokenUtilSuite.scala b/core/src/test/scala/org/apache/spark/deploy/security/KafkaTokenUtilSuite.scala new file mode 100644 index 0000000000000..682bebde916fa --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/security/KafkaTokenUtilSuite.scala @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.security + +import java.{ util => ju } +import javax.security.auth.login.{AppConfigurationEntry, Configuration} + +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.common.config.SaslConfigs +import org.apache.kafka.common.security.auth.SecurityProtocol.{SASL_PLAINTEXT, SASL_SSL, SSL} +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config._ + +class KafkaTokenUtilSuite extends SparkFunSuite with BeforeAndAfterEach { + private val bootStrapServers = "127.0.0.1:0" + private val trustStoreLocation = "/path/to/trustStore" + private val trustStorePassword = "trustStoreSecret" + private val keyStoreLocation = "/path/to/keyStore" + private val keyStorePassword = "keyStoreSecret" + private val keyPassword = "keySecret" + private val keytab = "/path/to/keytab" + private val kerberosServiceName = "kafka" + private val principal = "user@domain.com" + + private var sparkConf: SparkConf = null + + private class KafkaJaasConfiguration extends Configuration { + val entry = + new AppConfigurationEntry( + "DummyModule", + AppConfigurationEntry.LoginModuleControlFlag.REQUIRED, + ju.Collections.emptyMap[String, Object]() + ) + + override def getAppConfigurationEntry(name: String): Array[AppConfigurationEntry] = { + if (name.equals("KafkaClient")) { + Array(entry) + } else { + null + } + } + } + + override def beforeEach(): Unit = { + super.beforeEach() + sparkConf = new SparkConf() + } + + override def afterEach(): Unit = { + try { + resetGlobalConfig() + } finally { + super.afterEach() + } + } + + private def setGlobalKafkaClientConfig(): Unit = { + Configuration.setConfiguration(new KafkaJaasConfiguration) + } + + private def resetGlobalConfig(): Unit = { + Configuration.setConfiguration(null) + } + + test("createAdminClientProperties without bootstrap servers should throw exception") { + val thrown = intercept[IllegalArgumentException] { + KafkaTokenUtil.createAdminClientProperties(sparkConf) + } + assert(thrown.getMessage contains + "Tried to obtain kafka delegation token but bootstrap servers not configured.") + } + + test("createAdminClientProperties with SASL_PLAINTEXT protocol should not include " + + "keystore and truststore config") { + sparkConf.set(Kafka.BOOTSTRAP_SERVERS, bootStrapServers) + sparkConf.set(Kafka.SECURITY_PROTOCOL, SASL_PLAINTEXT.name) + sparkConf.set(Kafka.TRUSTSTORE_LOCATION, trustStoreLocation) + sparkConf.set(Kafka.TRUSTSTORE_PASSWORD, trustStoreLocation) + sparkConf.set(Kafka.KEYSTORE_LOCATION, keyStoreLocation) + sparkConf.set(Kafka.KEYSTORE_PASSWORD, keyStorePassword) + sparkConf.set(Kafka.KEY_PASSWORD, keyPassword) + sparkConf.set(Kafka.KERBEROS_SERVICE_NAME, kerberosServiceName) + + val adminClientProperties = KafkaTokenUtil.createAdminClientProperties(sparkConf) + + assert(adminClientProperties.get(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG) + === bootStrapServers) + assert(adminClientProperties.get(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG) + === SASL_PLAINTEXT.name) + assert(!adminClientProperties.containsKey("ssl.truststore.location")) + assert(!adminClientProperties.containsKey("ssl.truststore.password")) + assert(!adminClientProperties.containsKey("ssl.keystore.location")) + assert(!adminClientProperties.containsKey("ssl.keystore.password")) + assert(!adminClientProperties.containsKey("ssl.key.password")) + } + + test("createAdminClientProperties with SASL_SSL protocol should include truststore config") { + sparkConf.set(Kafka.BOOTSTRAP_SERVERS, bootStrapServers) + sparkConf.set(Kafka.SECURITY_PROTOCOL, SASL_SSL.name) + sparkConf.set(Kafka.TRUSTSTORE_LOCATION, trustStoreLocation) + sparkConf.set(Kafka.TRUSTSTORE_PASSWORD, trustStorePassword) + sparkConf.set(Kafka.KEYSTORE_LOCATION, keyStoreLocation) + sparkConf.set(Kafka.KEYSTORE_PASSWORD, keyStorePassword) + sparkConf.set(Kafka.KEY_PASSWORD, keyPassword) + sparkConf.set(Kafka.KERBEROS_SERVICE_NAME, kerberosServiceName) + + val adminClientProperties = KafkaTokenUtil.createAdminClientProperties(sparkConf) + + assert(adminClientProperties.get(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG) + === bootStrapServers) + assert(adminClientProperties.get(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG) + === SASL_SSL.name) + assert(adminClientProperties.get("ssl.truststore.location") === trustStoreLocation) + assert(adminClientProperties.get("ssl.truststore.password") === trustStorePassword) + assert(!adminClientProperties.containsKey("ssl.keystore.location")) + assert(!adminClientProperties.containsKey("ssl.keystore.password")) + assert(!adminClientProperties.containsKey("ssl.key.password")) + } + + test("createAdminClientProperties with SSL protocol should include keystore and truststore " + + "config") { + sparkConf.set(Kafka.BOOTSTRAP_SERVERS, bootStrapServers) + sparkConf.set(Kafka.SECURITY_PROTOCOL, SSL.name) + sparkConf.set(Kafka.TRUSTSTORE_LOCATION, trustStoreLocation) + sparkConf.set(Kafka.TRUSTSTORE_PASSWORD, trustStorePassword) + sparkConf.set(Kafka.KEYSTORE_LOCATION, keyStoreLocation) + sparkConf.set(Kafka.KEYSTORE_PASSWORD, keyStorePassword) + sparkConf.set(Kafka.KEY_PASSWORD, keyPassword) + sparkConf.set(Kafka.KERBEROS_SERVICE_NAME, kerberosServiceName) + + val adminClientProperties = KafkaTokenUtil.createAdminClientProperties(sparkConf) + + assert(adminClientProperties.get(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG) + === bootStrapServers) + assert(adminClientProperties.get(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG) + === SSL.name) + assert(adminClientProperties.get("ssl.truststore.location") === trustStoreLocation) + assert(adminClientProperties.get("ssl.truststore.password") === trustStorePassword) + assert(adminClientProperties.get("ssl.keystore.location") === keyStoreLocation) + assert(adminClientProperties.get("ssl.keystore.password") === keyStorePassword) + assert(adminClientProperties.get("ssl.key.password") === keyPassword) + } + + test("createAdminClientProperties with global config should not set dynamic jaas config") { + sparkConf.set(Kafka.BOOTSTRAP_SERVERS, bootStrapServers) + sparkConf.set(Kafka.SECURITY_PROTOCOL, SASL_SSL.name) + setGlobalKafkaClientConfig() + + val adminClientProperties = KafkaTokenUtil.createAdminClientProperties(sparkConf) + + assert(adminClientProperties.get(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG) + === bootStrapServers) + assert(adminClientProperties.get(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG) + === SASL_SSL.name) + assert(!adminClientProperties.containsKey(SaslConfigs.SASL_MECHANISM)) + assert(!adminClientProperties.containsKey(SaslConfigs.SASL_JAAS_CONFIG)) + } + + test("createAdminClientProperties with keytab should set keytab dynamic jaas config") { + sparkConf.set(Kafka.BOOTSTRAP_SERVERS, bootStrapServers) + sparkConf.set(Kafka.SECURITY_PROTOCOL, SASL_SSL.name) + sparkConf.set(KEYTAB, keytab) + sparkConf.set(Kafka.KERBEROS_SERVICE_NAME, kerberosServiceName) + sparkConf.set(PRINCIPAL, principal) + + val adminClientProperties = KafkaTokenUtil.createAdminClientProperties(sparkConf) + + assert(adminClientProperties.get(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG) + === bootStrapServers) + assert(adminClientProperties.get(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG) + === SASL_SSL.name) + assert(adminClientProperties.containsKey(SaslConfigs.SASL_MECHANISM)) + val saslJaasConfig = adminClientProperties.getProperty(SaslConfigs.SASL_JAAS_CONFIG) + assert(saslJaasConfig.contains("Krb5LoginModule required")) + assert(saslJaasConfig.contains("useKeyTab=true")) + } + + test("createAdminClientProperties without keytab should set ticket cache dynamic jaas config") { + sparkConf.set(Kafka.BOOTSTRAP_SERVERS, bootStrapServers) + sparkConf.set(Kafka.SECURITY_PROTOCOL, SASL_SSL.name) + sparkConf.set(Kafka.KERBEROS_SERVICE_NAME, kerberosServiceName) + + val adminClientProperties = KafkaTokenUtil.createAdminClientProperties(sparkConf) + + assert(adminClientProperties.get(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG) + === bootStrapServers) + assert(adminClientProperties.get(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG) + === SASL_SSL.name) + assert(adminClientProperties.containsKey(SaslConfigs.SASL_MECHANISM)) + val saslJaasConfig = adminClientProperties.getProperty(SaslConfigs.SASL_JAAS_CONFIG) + assert(saslJaasConfig.contains("Krb5LoginModule required")) + assert(saslJaasConfig.contains("useTicketCache=true")) + } + + test("isGlobalJaasConfigurationProvided without global config should return false") { + assert(!KafkaTokenUtil.isGlobalJaasConfigurationProvided) + } + + test("isGlobalJaasConfigurationProvided with global config should return false") { + setGlobalKafkaClientConfig() + + assert(KafkaTokenUtil.isGlobalJaasConfigurationProvided) + } + + test("getKeytabJaasParams with keytab no service should throw exception") { + sparkConf.set(KEYTAB, keytab) + + val thrown = intercept[IllegalArgumentException] { + KafkaTokenUtil.getKeytabJaasParams(sparkConf) + } + + assert(thrown.getMessage contains "Kerberos service name must be defined") + } + + test("getTicketCacheJaasParams without service should throw exception") { + val thrown = intercept[IllegalArgumentException] { + KafkaTokenUtil.getTicketCacheJaasParams(sparkConf) + } + + assert(thrown.getMessage contains "Kerberos service name must be defined") + } +} diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 1af407167597b..de8731c4b774b 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -29,8 +29,6 @@ spark-sql-kafka-0-10_2.12 sql-kafka-0-10 - - 2.1.0 jar Kafka 0.10+ Source for Structured Streaming diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelper.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelper.scala new file mode 100644 index 0000000000000..74d5ef9c05f14 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelper.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.security.token.{Token, TokenIdentifier} +import org.apache.kafka.common.security.scram.ScramLoginModule + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.security.KafkaTokenUtil +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ + +private[kafka010] object KafkaSecurityHelper extends Logging { + def isTokenAvailable(): Boolean = { + UserGroupInformation.getCurrentUser().getCredentials.getToken( + KafkaTokenUtil.TOKEN_SERVICE) != null + } + + def getTokenJaasParams(sparkConf: SparkConf): String = { + val token = UserGroupInformation.getCurrentUser().getCredentials.getToken( + KafkaTokenUtil.TOKEN_SERVICE) + val serviceName = sparkConf.get(Kafka.KERBEROS_SERVICE_NAME) + require(serviceName.isDefined, "Kerberos service name must be defined") + val username = new String(token.getIdentifier) + val password = new String(token.getPassword) + + val loginModuleName = classOf[ScramLoginModule].getName + val params = + s""" + |$loginModuleName required + | tokenauth=true + | serviceName="${serviceName.get}" + | username="$username" + | password="$password"; + """.stripMargin.replace("\n", "") + logDebug(s"Scram JAAS params: ${params.replaceAll("password=\".*\"", "password=\"[hidden]\"")}") + + params + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index f770f0c2a04c2..0ac330435e5c5 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -18,16 +18,19 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import java.util.{Locale, Optional, UUID} +import java.util.{Locale, UUID} import scala.collection.JavaConverters._ import org.apache.kafka.clients.consumer.ConsumerConfig import org.apache.kafka.clients.producer.ProducerConfig +import org.apache.kafka.common.config.SaslConfigs import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} +import org.apache.spark.SparkEnv +import org.apache.spark.deploy.security.KafkaTokenUtil import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ import org.apache.spark.sql.sources.v2._ @@ -80,12 +83,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister val uniqueGroupId = streamingUniqueGroupId(parameters, metadataPath) val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - val specifiedKafkaParams = - parameters - .keySet - .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } - .toMap + val specifiedKafkaParams = convertToSpecifiedParams(parameters) val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) @@ -122,12 +120,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister val uniqueGroupId = streamingUniqueGroupId(parameters, metadataPath) val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - val specifiedKafkaParams = - parameters - .keySet - .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } - .toMap + val specifiedKafkaParams = convertToSpecifiedParams(parameters) val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) @@ -198,12 +191,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister parameters: Map[String, String]): BaseRelation = { validateBatchOptions(parameters) val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - val specifiedKafkaParams = - parameters - .keySet - .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } - .toMap + val specifiedKafkaParams = convertToSpecifiedParams(parameters) val startingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit( caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit) @@ -230,8 +218,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister outputMode: OutputMode): Sink = { val defaultTopic = parameters.get(TOPIC_OPTION_KEY).map(_.trim) val specifiedKafkaParams = kafkaParamsForProducer(parameters) - new KafkaSink(sqlContext, - new ju.HashMap[String, Object](specifiedKafkaParams.asJava), defaultTopic) + new KafkaSink(sqlContext, specifiedKafkaParams, defaultTopic) } override def createRelation( @@ -248,8 +235,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } val topic = parameters.get(TOPIC_OPTION_KEY).map(_.trim) val specifiedKafkaParams = kafkaParamsForProducer(parameters) - KafkaWriter.write(outerSQLContext.sparkSession, data.queryExecution, - new ju.HashMap[String, Object](specifiedKafkaParams.asJava), topic) + KafkaWriter.write(outerSQLContext.sparkSession, data.queryExecution, specifiedKafkaParams, + topic) /* This method is suppose to return a relation that reads the data that was written. * We cannot support this for Kafka. Therefore, in order to make things consistent, @@ -274,13 +261,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister options: DataSourceOptions): StreamingWriteSupport = { import scala.collection.JavaConverters._ - val spark = SparkSession.getActiveSession.get val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim) // We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable. val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) - KafkaWriter.validateQuery( - schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic) + KafkaWriter.validateQuery(schema.toAttributes, producerParams, topic) new KafkaStreamingWriteSupport(topic, producerParams, schema) } @@ -481,6 +466,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { + private val serClassName = classOf[ByteArraySerializer].getName private val deserClassName = classOf[ByteArrayDeserializer].getName def getKafkaOffsetRangeLimit( @@ -515,6 +501,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { // If buffer config is not set, set it to reasonable value to work around // buffer issues (see KAFKA-3135) .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) + .setTokenJaasConfigIfNeeded() .build() def kafkaParamsForExecutors( @@ -536,6 +523,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { // If buffer config is not set, set it to reasonable value to work around // buffer issues (see KAFKA-3135) .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) + .setTokenJaasConfigIfNeeded() .build() /** @@ -568,11 +556,32 @@ private[kafka010] object KafkaSourceProvider extends Logging { this } + def setTokenJaasConfigIfNeeded(): ConfigUpdater = { + // There are multiple possibilities to log in and applied in the following order: + // - JVM global security provided -> try to log in with JVM global security configuration + // which can be configured for example with 'java.security.auth.login.config'. + // For this no additional parameter needed. + // - Token is provided -> try to log in with scram module using kafka's dynamic JAAS + // configuration. + if (KafkaTokenUtil.isGlobalJaasConfigurationProvided) { + logDebug("JVM global security configuration detected, using it for login.") + } else if (KafkaSecurityHelper.isTokenAvailable()) { + logDebug("Delegation token detected, using it for login.") + val jaasParams = KafkaSecurityHelper.getTokenJaasParams(SparkEnv.get.conf) + val mechanism = kafkaParams + .getOrElse(SaslConfigs.SASL_MECHANISM, SaslConfigs.DEFAULT_SASL_MECHANISM) + require(mechanism.startsWith("SCRAM"), + "Delegation token works only with SCRAM mechanism.") + set(SaslConfigs.SASL_JAAS_CONFIG, jaasParams) + } + this + } + def build(): ju.Map[String, Object] = map } private[kafka010] def kafkaParamsForProducer( - parameters: Map[String, String]): Map[String, String] = { + parameters: Map[String, String]): ju.Map[String, Object] = { val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { throw new IllegalArgumentException( @@ -580,17 +589,26 @@ private[kafka010] object KafkaSourceProvider extends Logging { + "are serialized with ByteArraySerializer.") } - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) - { + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) { throw new IllegalArgumentException( s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " + "value are serialized with ByteArraySerializer.") } + + val specifiedKafkaParams = convertToSpecifiedParams(parameters) + + ConfigUpdater("executor", specifiedKafkaParams) + .set(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, serClassName) + .set(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, serClassName) + .setTokenJaasConfigIfNeeded() + .build() + } + + private def convertToSpecifiedParams(parameters: Map[String, String]): Map[String, String] = { parameters .keySet .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) .map { k => k.drop(6).toString -> parameters(k) } - .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, - ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + .toMap } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala index 927c56d9ce829..0d831c3884609 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.kafka010 -import scala.collection.JavaConverters._ +import java.{util => ju} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute @@ -41,10 +41,12 @@ case object KafkaWriterCommitMessage extends WriterCommitMessage * @param schema The schema of the input data. */ class KafkaStreamingWriteSupport( - topic: Option[String], producerParams: Map[String, String], schema: StructType) + topic: Option[String], + producerParams: ju.Map[String, Object], + schema: StructType) extends StreamingWriteSupport { - validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) + validateQuery(schema.toAttributes, producerParams, topic) override def createStreamingWriterFactory(): KafkaStreamWriterFactory = KafkaStreamWriterFactory(topic, producerParams, schema) @@ -62,7 +64,9 @@ class KafkaStreamingWriteSupport( * @param schema The schema of the input data. */ case class KafkaStreamWriterFactory( - topic: Option[String], producerParams: Map[String, String], schema: StructType) + topic: Option[String], + producerParams: ju.Map[String, Object], + schema: StructType) extends StreamingDataWriterFactory { override def createWriter( @@ -83,12 +87,12 @@ case class KafkaStreamWriterFactory( * @param inputSchema The attributes in the input data. */ class KafkaStreamDataWriter( - targetTopic: Option[String], producerParams: Map[String, String], inputSchema: Seq[Attribute]) + targetTopic: Option[String], + producerParams: ju.Map[String, Object], + inputSchema: Seq[Attribute]) extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] { - import scala.collection.JavaConverters._ - private lazy val producer = CachedKafkaProducer.getOrCreate( - new java.util.HashMap[String, Object](producerParams.asJava)) + private lazy val producer = CachedKafkaProducer.getOrCreate(producerParams) def write(row: InternalRow): Unit = { checkForErrors() @@ -112,7 +116,7 @@ class KafkaStreamDataWriter( if (producer != null) { producer.flush() checkForErrors() - CachedKafkaProducer.close(new java.util.HashMap[String, Object](producerParams.asJava)) + CachedKafkaProducer.close(producerParams) } } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala index 3f6fcf6b2e52c..b21037b1340ce 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -409,7 +409,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { */ val topic = newTopic() testUtils.createTopic(topic, 1) - val options = new java.util.HashMap[String, String] + val options = new java.util.HashMap[String, Object] options.put("bootstrap.servers", testUtils.brokerAddress) options.put("buffer.memory", "16384") // min buffer size options.put("block.on.buffer.full", "true") @@ -417,7 +417,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) val inputSchema = Seq(AttributeReference("value", BinaryType)()) val data = new Array[Byte](15000) // large value - val writeTask = new KafkaStreamDataWriter(Some(topic), options.asScala.toMap, inputSchema) + val writeTask = new KafkaStreamDataWriter(Some(topic), options, inputSchema) try { val fieldTypes: Array[DataType] = Array(BinaryType) val converter = UnsafeProjection.create(fieldTypes) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelperSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelperSuite.scala new file mode 100644 index 0000000000000..772fe4614bad0 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelperSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.UUID + +import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.hadoop.security.token.Token +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.security.KafkaTokenUtil +import org.apache.spark.deploy.security.KafkaTokenUtil.KafkaDelegationTokenIdentifier +import org.apache.spark.internal.config._ + +class KafkaSecurityHelperSuite extends SparkFunSuite with BeforeAndAfterEach { + private val keytab = "/path/to/keytab" + private val kerberosServiceName = "kafka" + private val principal = "user@domain.com" + private val tokenId = "tokenId" + UUID.randomUUID().toString + private val tokenPassword = "tokenPassword" + UUID.randomUUID().toString + + private var sparkConf: SparkConf = null + + override def beforeEach(): Unit = { + super.beforeEach() + sparkConf = new SparkConf() + } + + override def afterEach(): Unit = { + try { + resetUGI + } finally { + super.afterEach() + } + } + + private def addTokenToUGI(): Unit = { + val token = new Token[KafkaDelegationTokenIdentifier]( + tokenId.getBytes, + tokenPassword.getBytes, + KafkaTokenUtil.TOKEN_KIND, + KafkaTokenUtil.TOKEN_SERVICE + ) + val creds = new Credentials() + creds.addToken(KafkaTokenUtil.TOKEN_SERVICE, token) + UserGroupInformation.getCurrentUser.addCredentials(creds) + } + + private def resetUGI: Unit = { + UserGroupInformation.setLoginUser(null) + } + + test("isTokenAvailable without token should return false") { + assert(!KafkaSecurityHelper.isTokenAvailable()) + } + + test("isTokenAvailable with token should return true") { + addTokenToUGI() + + assert(KafkaSecurityHelper.isTokenAvailable()) + } + + test("getTokenJaasParams with token no service should throw exception") { + addTokenToUGI() + + val thrown = intercept[IllegalArgumentException] { + KafkaSecurityHelper.getTokenJaasParams(sparkConf) + } + + assert(thrown.getMessage contains "Kerberos service name must be defined") + } + + test("getTokenJaasParams with token should return scram module") { + addTokenToUGI() + sparkConf.set(Kafka.KERBEROS_SERVICE_NAME, kerberosServiceName) + + val jaasParams = KafkaSecurityHelper.getTokenJaasParams(sparkConf) + + assert(jaasParams.contains("ScramLoginModule required")) + assert(jaasParams.contains("tokenauth=true")) + assert(jaasParams.contains(tokenId)) + assert(jaasParams.contains(tokenPassword)) + } +} diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index ea18b7e035915..333572e99b1c7 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -28,8 +28,6 @@ spark-streaming-kafka-0-10_2.12 streaming-kafka-0-10 - - 2.1.0 jar Spark Integration for Kafka 0.10 diff --git a/pom.xml b/pom.xml index 3ca2f739ce0ea..dfc3c540dc18e 100644 --- a/pom.xml +++ b/pom.xml @@ -128,6 +128,8 @@ 1.2.1.spark2 1.2.1 + + 2.1.0 10.12.1.1 1.10.0 1.5.3 From 66b2046462c0e93b2ca167728eba9f4d13a5a67c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 30 Nov 2018 10:29:30 +0800 Subject: [PATCH 2174/2461] [SPARK-25446][R] Add schema_of_json() and schema_of_csv() to R ## What changes were proposed in this pull request? This PR proposes to expose `schema_of_json` and `schema_of_csv` at R side. **`schema_of_json`**: ```r json <- '{"name":"Bob"}' df <- sql("SELECT * FROM range(1)") head(select(df, schema_of_json(json))) ``` ``` schema_of_json({"name":"Bob"}) 1 struct ``` **`schema_of_csv`**: ```r csv <- "Amsterdam,2018" df <- sql("SELECT * FROM range(1)") head(select(df, schema_of_csv(csv))) ``` ``` schema_of_csv(Amsterdam,2018) 1 struct<_c0:string,_c1:int> ``` ## How was this patch tested? Manually tested, unit tests added, documentation manually built and verified. Closes #22939 from HyukjinKwon/SPARK-25446. Authored-by: hyukjinkwon Signed-off-by: hyukjinkwon --- R/pkg/NAMESPACE | 2 + R/pkg/R/functions.R | 77 ++++++++++++++++++++++++--- R/pkg/R/generics.R | 8 +++ R/pkg/tests/fulltests/test_sparkSQL.R | 16 +++++- 4 files changed, 94 insertions(+), 9 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index cdeafdd90ce4a..1f8ba0bcf1cf5 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -351,6 +351,8 @@ exportMethods("%<=>%", "row_number", "rpad", "rtrim", + "schema_of_csv", + "schema_of_json", "second", "sha1", "sha2", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index f72645a257796..f568a931ae1fe 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -205,11 +205,18 @@ NULL #' also supported for the schema. #' \item \code{from_csv}: a DDL-formatted string #' } -#' @param ... additional argument(s). In \code{to_json}, \code{to_csv} and \code{from_json}, -#' this contains additional named properties to control how it is converted, accepts -#' the same options as the JSON/CSV data source. Additionally \code{to_json} supports -#' the "pretty" option which enables pretty JSON generation. In \code{arrays_zip}, -#' this contains additional Columns of arrays to be merged. +#' @param ... additional argument(s). +#' \itemize{ +#' \item \code{to_json}, \code{from_json} and \code{schema_of_json}: this contains +#' additional named properties to control how it is converted and accepts the +#' same options as the JSON data source. +#' \item \code{to_json}: it supports the "pretty" option which enables pretty +#' JSON generation. +#' \item \code{to_csv}, \code{from_csv} and \code{schema_of_csv}: this contains +#' additional named properties to control how it is converted and accepts the +#' same options as the CSV data source. +#' \item \code{arrays_zip}, this contains additional Columns of arrays to be merged. +#' } #' @name column_collection_functions #' @rdname column_collection_functions #' @family collection functions @@ -1771,12 +1778,16 @@ setMethod("to_date", #' df2 <- mutate(df2, people_json = to_json(df2$people)) #' #' # Converts a map into a JSON object -#' df2 <- sql("SELECT map('name', 'Bob')) as people") +#' df2 <- sql("SELECT map('name', 'Bob') as people") #' df2 <- mutate(df2, people_json = to_json(df2$people)) #' #' # Converts an array of maps into a JSON array #' df2 <- sql("SELECT array(map('name', 'Bob'), map('name', 'Alice')) as people") -#' df2 <- mutate(df2, people_json = to_json(df2$people))} +#' df2 <- mutate(df2, people_json = to_json(df2$people)) +#' +#' # Converts a map into a pretty JSON object +#' df2 <- sql("SELECT map('name', 'Bob') as people") +#' df2 <- mutate(df2, people_json = to_json(df2$people, pretty = TRUE))} #' @note to_json since 2.2.0 setMethod("to_json", signature(x = "Column"), function(x, ...) { @@ -2285,6 +2296,32 @@ setMethod("from_json", signature(x = "Column", schema = "characterOrstructType") column(jc) }) +#' @details +#' \code{schema_of_json}: Parses a JSON string and infers its schema in DDL format. +#' +#' @rdname column_collection_functions +#' @aliases schema_of_json schema_of_json,characterOrColumn-method +#' @examples +#' +#' \dontrun{ +#' json <- "{\"name\":\"Bob\"}" +#' df <- sql("SELECT * FROM range(1)") +#' head(select(df, schema_of_json(json)))} +#' @note schema_of_json since 3.0.0 +setMethod("schema_of_json", signature(x = "characterOrColumn"), + function(x, ...) { + if (class(x) == "character") { + col <- callJStatic("org.apache.spark.sql.functions", "lit", x) + } else { + col <- x@jc + } + options <- varargsToStrEnv(...) + jc <- callJStatic("org.apache.spark.sql.functions", + "schema_of_json", + col, options) + column(jc) + }) + #' @details #' \code{from_csv}: Parses a column containing a CSV string into a Column of \code{structType} #' with the specified \code{schema}. @@ -2315,6 +2352,32 @@ setMethod("from_csv", signature(x = "Column", schema = "characterOrColumn"), column(jc) }) +#' @details +#' \code{schema_of_csv}: Parses a CSV string and infers its schema in DDL format. +#' +#' @rdname column_collection_functions +#' @aliases schema_of_csv schema_of_csv,characterOrColumn-method +#' @examples +#' +#' \dontrun{ +#' csv <- "Amsterdam,2018" +#' df <- sql("SELECT * FROM range(1)") +#' head(select(df, schema_of_csv(csv)))} +#' @note schema_of_csv since 3.0.0 +setMethod("schema_of_csv", signature(x = "characterOrColumn"), + function(x, ...) { + if (class(x) == "character") { + col <- callJStatic("org.apache.spark.sql.functions", "lit", x) + } else { + col <- x@jc + } + options <- varargsToStrEnv(...) + jc <- callJStatic("org.apache.spark.sql.functions", + "schema_of_csv", + col, options) + column(jc) + }) + #' @details #' \code{from_utc_timestamp}: This is a common function for databases supporting TIMESTAMP WITHOUT #' TIMEZONE. This function takes a timestamp which is timezone-agnostic, and interprets it as a diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index b2ca6e62175e7..9d8c24c686c76 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1206,6 +1206,14 @@ setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") }) #' @name NULL setGeneric("rtrim", function(x, trimString) { standardGeneric("rtrim") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("schema_of_csv", function(x, ...) { standardGeneric("schema_of_csv") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("schema_of_json", function(x, ...) { standardGeneric("schema_of_json") }) + #' @rdname column_aggregate_functions #' @name NULL setGeneric("sd", function(x, na.rm = FALSE) { standardGeneric("sd") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 77a29c9ecad86..0d5118c127f2b 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1620,14 +1620,20 @@ test_that("column functions", { expect_equal(collect(select(df, bround(df$x, 0)))[[1]][1], 2) expect_equal(collect(select(df, bround(df$x, 0)))[[1]][2], 4) - # Test from_csv() + # Test from_csv(), schema_of_csv() df <- as.DataFrame(list(list("col" = "1"))) c <- collect(select(df, alias(from_csv(df$col, "a INT"), "csv"))) expect_equal(c[[1]][[1]]$a, 1) c <- collect(select(df, alias(from_csv(df$col, lit("a INT")), "csv"))) expect_equal(c[[1]][[1]]$a, 1) - # Test to_json(), from_json() + df <- as.DataFrame(list(list("col" = "1"))) + c <- collect(select(df, schema_of_csv("Amsterdam,2018"))) + expect_equal(c[[1]], "struct<_c0:string,_c1:int>") + c <- collect(select(df, schema_of_csv(lit("Amsterdam,2018")))) + expect_equal(c[[1]], "struct<_c0:string,_c1:int>") + + # Test to_json(), from_json(), schema_of_json() df <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") j <- collect(select(df, alias(to_json(df$people), "json"))) expect_equal(j[order(j$json), ][1], "[{\"name\":\"Bob\"},{\"name\":\"Alice\"}]") @@ -1654,6 +1660,12 @@ test_that("column functions", { expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 }))) } + df <- as.DataFrame(list(list("col" = "1"))) + c <- collect(select(df, schema_of_json('{"name":"Bob"}'))) + expect_equal(c[[1]], "struct") + c <- collect(select(df, schema_of_json(lit('{"name":"Bob"}')))) + expect_equal(c[[1]], "struct") + # Test to_json() supports arrays of primitive types and arrays df <- sql("SELECT array(19, 42, 70) as age") j <- collect(select(df, alias(to_json(df$age), "json"))) From 8edb64c1b9ee49d836e171a459dd93f524df92bf Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 30 Nov 2018 11:56:25 +0800 Subject: [PATCH 2175/2461] [SPARK-26060][SQL] Track SparkConf entries and make SET command reject such entries. ## What changes were proposed in this pull request? Currently the `SET` command works without any warnings even if the specified key is for `SparkConf` entries and it has no effect because the command does not update `SparkConf`, but the behavior might confuse users. We should track `SparkConf` entries and make the command reject for such entries. ## How was this patch tested? Added a test and existing tests. Closes #23031 from ueshin/issues/SPARK-26060/set_command. Authored-by: Takuya UESHIN Signed-off-by: Wenchen Fan --- docs/sql-migration-guide-upgrade.md | 2 ++ .../org/apache/spark/sql/internal/SQLConf.scala | 12 +++++++++++- .../scala/org/apache/spark/sql/RuntimeConfig.scala | 4 ++++ .../org/apache/spark/sql/RuntimeConfigSuite.scala | 10 ++++++++++ .../spark/sql/execution/command/DDLSuite.scala | 8 ++++++++ 5 files changed, 35 insertions(+), 1 deletion(-) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 55838e773e4b1..e48125a0972b5 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -29,6 +29,8 @@ displayTitle: Spark SQL Upgrading Guide - In Spark version 2.4 and earlier, users can create a map with duplicated keys via built-in functions like `CreateMap`, `StringToMap`, etc. The behavior of map with duplicated keys is undefined, e.g. map look up respects the duplicated key appears first, `Dataset.collect` only keeps the duplicated key appears last, `MapKeys` returns duplicated keys, etc. Since Spark 3.0, these built-in functions will remove duplicated map keys with last wins policy. Users may still read map values with duplicated keys from data sources which do not enforce it (e.g. Parquet), the behavior will be udefined. + - In Spark version 2.4 and earlier, the `SET` command works without any warnings even if the specified key is for `SparkConf` entries and it has no effect because the command does not update `SparkConf`, but the behavior might confuse users. Since 3.0, the command fails if a `SparkConf` key is used. You can disable such a check by setting `spark.sql.legacy.execution.setCommandRejectsSparkConfs` to `false`. + ## Upgrading From Spark SQL 2.3 to 2.4 - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7bcf21595ce5a..f1c845bc94507 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -45,7 +45,7 @@ import org.apache.spark.util.Utils object SQLConf { - private val sqlConfEntries = java.util.Collections.synchronizedMap( + private[sql] val sqlConfEntries = java.util.Collections.synchronizedMap( new java.util.HashMap[String, ConfigEntry[_]]()) val staticConfKeys: java.util.Set[String] = @@ -1610,6 +1610,14 @@ object SQLConf { """ "... N more fields" placeholder.""") .intConf .createWithDefault(25) + + val SET_COMMAND_REJECTS_SPARK_CONFS = + buildConf("spark.sql.legacy.execution.setCommandRejectsSparkConfs") + .internal() + .doc("If it is set to true, SET command will fail when the key is registered as " + + "a SparkConf entry.") + .booleanConf + .createWithDefault(true) } /** @@ -2030,6 +2038,8 @@ class SQLConf extends Serializable with Logging { def maxToStringFields: Int = getConf(SQLConf.MAX_TO_STRING_FIELDS) + def setCommandRejectsSparkConfs: Boolean = getConf(SQLConf.SET_COMMAND_REJECTS_SPARK_CONFS) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala index 5a554eff02e3d..d83a01ff9ea65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala @@ -153,5 +153,9 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { if (SQLConf.staticConfKeys.contains(key)) { throw new AnalysisException(s"Cannot modify the value of a static config: $key") } + if (sqlConf.setCommandRejectsSparkConfs && + ConfigEntry.findEntry(key) != null && !SQLConf.sqlConfEntries.containsKey(key)) { + throw new AnalysisException(s"Cannot modify the value of a Spark config: $key") + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala index cdcea09ad9758..6196757eb7010 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.config class RuntimeConfigSuite extends SparkFunSuite { @@ -68,4 +69,13 @@ class RuntimeConfigSuite extends SparkFunSuite { assert(!conf.isModifiable("")) assert(!conf.isModifiable("invalid config parameter")) } + + test("reject SparkConf entries") { + val conf = newConf() + + val ex = intercept[AnalysisException] { + conf.set(config.CPUS_PER_TASK.key, 4) + } + assert(ex.getMessage.contains("Spark config")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index f8d98dead2d42..9d32fb6d46962 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -24,6 +24,7 @@ import java.util.Locale import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterEach +import org.apache.spark.internal.config import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchPartitionException, NoSuchTableException, TempTableAlreadyExistsException} @@ -2715,4 +2716,11 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } } + + test("set command rejects SparkConf entries") { + val ex = intercept[AnalysisException] { + sql(s"SET ${config.CPUS_PER_TASK.key} = 4") + } + assert(ex.getMessage.contains("Spark config")) + } } From 9cfc3ee6253bed21924424ccaadea0287a6f15f4 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 30 Nov 2018 12:00:55 +0800 Subject: [PATCH 2176/2461] [SPARK-26188][SQL] FileIndex: don't infer data types of partition columns if user specifies schema ## What changes were proposed in this pull request? This PR is to fix a regression introduced in: https://github.com/apache/spark/pull/21004/files#r236998030 If user specifies schema, Spark don't need to infer data type for of partition columns, otherwise the data type might not match with the one user provided. E.g. for partition directory `p=4d`, after data type inference the column value will be `4.0`. See https://issues.apache.org/jira/browse/SPARK-26188 for more details. Note that user specified schema **might not cover all the data columns**: ``` val schema = new StructType() .add("id", StringType) .add("ex", ArrayType(StringType)) val df = spark.read .schema(schema) .format("parquet") .load(src.toString) assert(df.schema.toList === List( StructField("ex", ArrayType(StringType)), StructField("part", IntegerType), // inferred partitionColumn dataType StructField("id", StringType))) // used user provided partitionColumn dataType ``` For the missing columns in user specified schema, Spark still need to infer their data types if `partitionColumnTypeInferenceEnabled` is enabled. To implement the partially inference, refactor `PartitioningUtils.parsePartitions` and pass the user specified schema as parameter to cast partition values. ## How was this patch tested? Add unit test. Closes #23165 from gengliangwang/fixFileIndex. Authored-by: Gengliang Wang Signed-off-by: Wenchen Fan --- .../PartitioningAwareFileIndex.scala | 47 ++----------------- .../datasources/PartitioningUtils.scala | 39 ++++++++++++--- .../datasources/FileIndexSuite.scala | 16 +++++++ .../ParquetPartitionDiscoverySuite.scala | 22 +++++++-- 4 files changed, 72 insertions(+), 52 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index cc8af7b92c454..7b0e4dbcc25f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -126,33 +126,15 @@ abstract class PartitioningAwareFileIndex( val caseInsensitiveOptions = CaseInsensitiveMap(parameters) val timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone) - val inferredPartitionSpec = PartitioningUtils.parsePartitions( + + val caseSensitive = sparkSession.sqlContext.conf.caseSensitiveAnalysis + PartitioningUtils.parsePartitions( leafDirs, typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, basePaths = basePaths, + userSpecifiedSchema = userSpecifiedSchema, + caseSensitive = caseSensitive, timeZoneId = timeZoneId) - userSpecifiedSchema match { - case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => - val userPartitionSchema = - combineInferredAndUserSpecifiedPartitionSchema(inferredPartitionSpec) - - // we need to cast into the data type that user specified. - def castPartitionValuesToUserSchema(row: InternalRow) = { - InternalRow((0 until row.numFields).map { i => - val dt = inferredPartitionSpec.partitionColumns.fields(i).dataType - Cast( - Literal.create(row.get(i, dt), dt), - userPartitionSchema.fields(i).dataType, - Option(timeZoneId)).eval() - }: _*) - } - - PartitionSpec(userPartitionSchema, inferredPartitionSpec.partitions.map { part => - part.copy(values = castPartitionValuesToUserSchema(part.values)) - }) - case _ => - inferredPartitionSpec - } } private def prunePartitions( @@ -233,25 +215,6 @@ abstract class PartitioningAwareFileIndex( val name = path.getName !((name.startsWith("_") && !name.contains("=")) || name.startsWith(".")) } - - /** - * In the read path, only managed tables by Hive provide the partition columns properly when - * initializing this class. All other file based data sources will try to infer the partitioning, - * and then cast the inferred types to user specified dataTypes if the partition columns exist - * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510, or - * inconsistent data types as reported in SPARK-21463. - * @param spec A partition inference result - * @return The PartitionSchema resolved from inference and cast according to `userSpecifiedSchema` - */ - private def combineInferredAndUserSpecifiedPartitionSchema(spec: PartitionSpec): StructType = { - val equality = sparkSession.sessionState.conf.resolver - val resolved = spec.partitionColumns.map { partitionField => - // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred - userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse( - partitionField) - } - StructType(resolved) - } } object PartitioningAwareFileIndex { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 3183fd30e5e0d..9d2c9ba0c1a5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal} -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils @@ -94,18 +94,34 @@ object PartitioningUtils { paths: Seq[Path], typeInference: Boolean, basePaths: Set[Path], + userSpecifiedSchema: Option[StructType], + caseSensitive: Boolean, timeZoneId: String): PartitionSpec = { - parsePartitions(paths, typeInference, basePaths, DateTimeUtils.getTimeZone(timeZoneId)) + parsePartitions(paths, typeInference, basePaths, userSpecifiedSchema, + caseSensitive, DateTimeUtils.getTimeZone(timeZoneId)) } private[datasources] def parsePartitions( paths: Seq[Path], typeInference: Boolean, basePaths: Set[Path], + userSpecifiedSchema: Option[StructType], + caseSensitive: Boolean, timeZone: TimeZone): PartitionSpec = { + val userSpecifiedDataTypes = if (userSpecifiedSchema.isDefined) { + val nameToDataType = userSpecifiedSchema.get.fields.map(f => f.name -> f.dataType).toMap + if (!caseSensitive) { + CaseInsensitiveMap(nameToDataType) + } else { + nameToDataType + } + } else { + Map.empty[String, DataType] + } + // First, we need to parse every partition's path and see if we can find partition values. val (partitionValues, optDiscoveredBasePaths) = paths.map { path => - parsePartition(path, typeInference, basePaths, timeZone) + parsePartition(path, typeInference, basePaths, userSpecifiedDataTypes, timeZone) }.unzip // We create pairs of (path -> path's partition value) here @@ -147,7 +163,7 @@ object PartitioningUtils { columnNames.zip(literals).map { case (name, Literal(_, dataType)) => // We always assume partition columns are nullable since we've no idea whether null values // will be appended in the future. - StructField(name, dataType, nullable = true) + StructField(name, userSpecifiedDataTypes.getOrElse(name, dataType), nullable = true) } } @@ -185,6 +201,7 @@ object PartitioningUtils { path: Path, typeInference: Boolean, basePaths: Set[Path], + userSpecifiedDataTypes: Map[String, DataType], timeZone: TimeZone): (Option[PartitionValues], Option[Path]) = { val columns = ArrayBuffer.empty[(String, Literal)] // Old Hadoop versions don't have `Path.isRoot` @@ -206,7 +223,7 @@ object PartitioningUtils { // Let's say currentPath is a path of "/table/a=1/", currentPath.getName will give us a=1. // Once we get the string, we try to parse it and find the partition column and value. val maybeColumn = - parsePartitionColumn(currentPath.getName, typeInference, timeZone) + parsePartitionColumn(currentPath.getName, typeInference, userSpecifiedDataTypes, timeZone) maybeColumn.foreach(columns += _) // Now, we determine if we should stop. @@ -239,6 +256,7 @@ object PartitioningUtils { private def parsePartitionColumn( columnSpec: String, typeInference: Boolean, + userSpecifiedDataTypes: Map[String, DataType], timeZone: TimeZone): Option[(String, Literal)] = { val equalSignIndex = columnSpec.indexOf('=') if (equalSignIndex == -1) { @@ -250,7 +268,16 @@ object PartitioningUtils { val rawColumnValue = columnSpec.drop(equalSignIndex + 1) assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'") - val literal = inferPartitionColumnValue(rawColumnValue, typeInference, timeZone) + val literal = if (userSpecifiedDataTypes.contains(columnName)) { + // SPARK-26188: if user provides corresponding column schema, get the column value without + // inference, and then cast it as user specified data type. + val columnValue = inferPartitionColumnValue(rawColumnValue, false, timeZone) + val castedValue = + Cast(columnValue, userSpecifiedDataTypes(columnName), Option(timeZone.getID)).eval() + Literal.create(castedValue, userSpecifiedDataTypes(columnName)) + } else { + inferPartitionColumnValue(rawColumnValue, typeInference, timeZone) + } Some(columnName -> literal) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index 49e7af4a9896b..fdb0511f01a22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{StringType, StructField, StructType} import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator} class FileIndexSuite extends SharedSQLContext { @@ -49,6 +50,21 @@ class FileIndexSuite extends SharedSQLContext { } } + test("SPARK-26188: don't infer data types of partition columns if user specifies schema") { + withTempDir { dir => + val partitionDirectory = new File(dir, s"a=4d") + partitionDirectory.mkdir() + val file = new File(partitionDirectory, "text.txt") + stringToFile(file, "text") + val path = new Path(dir.getCanonicalPath) + val schema = StructType(Seq(StructField("a", StringType, false))) + val fileIndex = new InMemoryFileIndex(spark, Seq(path), Map.empty, Some(schema)) + val partitionValues = fileIndex.partitionSpec().partitions.map(_.values) + assert(partitionValues.length == 1 && partitionValues(0).numFields == 1 && + partitionValues(0).getString(0) == "4d") + } + } + test("InMemoryFileIndex: input paths are converted to qualified paths") { withTempDir { dir => val file = new File(dir, "text.txt") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 9966ed94a8392..f808ca458aaa7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -101,7 +101,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha "hdfs://host:9000/path/a=10.5/b=hello") var exception = intercept[AssertionError] { - parsePartitions(paths.map(new Path(_)), true, Set.empty[Path], timeZoneId) + parsePartitions(paths.map(new Path(_)), true, Set.empty[Path], None, true, timeZoneId) } assert(exception.getMessage().contains("Conflicting directory structures detected")) @@ -115,6 +115,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha paths.map(new Path(_)), true, Set(new Path("hdfs://host:9000/path/")), + None, + true, timeZoneId) // Valid @@ -128,6 +130,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha paths.map(new Path(_)), true, Set(new Path("hdfs://host:9000/path/something=true/table")), + None, + true, timeZoneId) // Valid @@ -141,6 +145,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha paths.map(new Path(_)), true, Set(new Path("hdfs://host:9000/path/table=true")), + None, + true, timeZoneId) // Invalid @@ -154,6 +160,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha paths.map(new Path(_)), true, Set(new Path("hdfs://host:9000/path/")), + None, + true, timeZoneId) } assert(exception.getMessage().contains("Conflicting directory structures detected")) @@ -174,6 +182,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha paths.map(new Path(_)), true, Set(new Path("hdfs://host:9000/tmp/tables/")), + None, + true, timeZoneId) } assert(exception.getMessage().contains("Conflicting directory structures detected")) @@ -181,13 +191,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha test("parse partition") { def check(path: String, expected: Option[PartitionValues]): Unit = { - val actual = parsePartition(new Path(path), true, Set.empty[Path], timeZone)._1 + val actual = parsePartition(new Path(path), true, Set.empty[Path], Map.empty, timeZone)._1 assert(expected === actual) } def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = { val message = intercept[T] { - parsePartition(new Path(path), true, Set.empty[Path], timeZone) + parsePartition(new Path(path), true, Set.empty[Path], Map.empty, timeZone) }.getMessage assert(message.contains(expected)) @@ -231,6 +241,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha path = new Path("file://path/a=10"), typeInference = true, basePaths = Set(new Path("file://path/a=10")), + Map.empty, timeZone = timeZone)._1 assert(partitionSpec1.isEmpty) @@ -240,6 +251,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha path = new Path("file://path/a=10"), typeInference = true, basePaths = Set(new Path("file://path")), + Map.empty, timeZone = timeZone)._1 assert(partitionSpec2 == @@ -258,6 +270,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha paths.map(new Path(_)), true, rootPaths, + None, + true, timeZoneId) assert(actualSpec.partitionColumns === spec.partitionColumns) assert(actualSpec.partitions.length === spec.partitions.length) @@ -370,7 +384,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha test("parse partitions with type inference disabled") { def check(paths: Seq[String], spec: PartitionSpec): Unit = { val actualSpec = - parsePartitions(paths.map(new Path(_)), false, Set.empty[Path], timeZoneId) + parsePartitions(paths.map(new Path(_)), false, Set.empty[Path], None, true, timeZoneId) assert(actualSpec === spec) } From 2b2c94a3ee89630047bcdd416a977e0d1cdb1926 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 30 Nov 2018 00:02:43 -0800 Subject: [PATCH 2177/2461] [SPARK-25528][SQL] data source v2 API refactor (batch read) ## What changes were proposed in this pull request? This is the first step of the data source v2 API refactor [proposal](https://docs.google.com/document/d/1uUmKCpWLdh9vHxP7AWJ9EgbwB_U6T3EJYNjhISGmiQg/edit?usp=sharing) It adds the new API for batch read, without removing the old APIs, as they are still needed for streaming sources. More concretely, it adds 1. `TableProvider`, works like an anonymous catalog 2. `Table`, represents a structured data set. 3. `ScanBuilder` and `Scan`, a logical represents of data source scan 4. `Batch`, a physical representation of data source batch scan. ## How was this patch tested? existing tests Closes #23086 from cloud-fan/refactor-batch. Authored-by: Wenchen Fan Signed-off-by: gatorsmile --- .../kafka010/KafkaContinuousSourceSuite.scala | 4 +- .../sql/kafka010/KafkaContinuousTest.scala | 4 +- project/MimaExcludes.scala | 48 ++-- .../sql/sources/v2/SupportsBatchRead.java | 33 +++ .../apache/spark/sql/sources/v2/Table.java | 59 ++++ .../spark/sql/sources/v2/TableProvider.java | 64 +++++ .../spark/sql/sources/v2/reader/Batch.java | 48 ++++ .../reader/OldSupportsReportPartitioning.java | 38 +++ .../reader/OldSupportsReportStatistics.java | 38 +++ .../spark/sql/sources/v2/reader/Scan.java | 68 +++++ .../sql/sources/v2/reader/ScanBuilder.java | 30 ++ .../sql/sources/v2/reader/ScanConfig.java | 4 +- .../sql/sources/v2/reader/Statistics.java | 2 +- .../v2/reader/SupportsPushDownFilters.java | 4 +- .../SupportsPushDownRequiredColumns.java | 4 +- .../v2/reader/SupportsReportPartitioning.java | 8 +- .../v2/reader/SupportsReportStatistics.java | 6 +- .../v2/reader/partitioning/Partitioning.java | 3 +- .../apache/spark/sql/DataFrameReader.scala | 36 ++- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../datasources/v2/DataSourceV2Relation.scala | 90 +++--- .../datasources/v2/DataSourceV2ScanExec.scala | 68 ++--- .../datasources/v2/DataSourceV2Strategy.scala | 34 +-- .../v2/DataSourceV2StreamingScanExec.scala | 120 ++++++++ .../streaming/ProgressReporter.scala | 4 +- .../continuous/ContinuousExecution.scala | 5 +- .../sources/v2/JavaAdvancedDataSourceV2.java | 116 ++++---- .../sources/v2/JavaColumnarDataSourceV2.java | 27 +- .../v2/JavaPartitionAwareDataSource.java | 29 +- .../v2/JavaSchemaRequiredDataSource.java | 36 ++- ...Support.java => JavaSimpleBatchTable.java} | 33 +-- .../sources/v2/JavaSimpleDataSourceV2.java | 19 +- .../sql/sources/v2/DataSourceV2Suite.scala | 260 ++++++++++-------- .../sources/v2/SimpleWritableDataSource.scala | 35 ++- .../continuous/ContinuousSuite.scala | 4 +- 35 files changed, 942 insertions(+), 441 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Batch.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/OldSupportsReportPartitioning.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/OldSupportsReportStatistics.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanBuilder.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StreamingScanExec.scala rename sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/{JavaSimpleReadSupport.java => JavaSimpleBatchTable.java} (78%) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index af510219a6f6f..9ba066a4cdc32 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.clients.producer.ProducerRecord import org.apache.spark.sql.Dataset -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2StreamingScanExec import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.streaming.Trigger @@ -208,7 +208,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.executedPlan.collectFirst { - case scan: DataSourceV2ScanExec + case scan: DataSourceV2StreamingScanExec if scan.readSupport.isInstanceOf[KafkaContinuousReadSupport] => scan.scanConfig.asInstanceOf[KafkaContinuousScanConfig] }.exists { config => diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index aa21f1271b817..5549e821be753 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.SparkContext import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2StreamingScanExec import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.streaming.Trigger @@ -47,7 +47,7 @@ trait KafkaContinuousTest extends KafkaSourceTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.executedPlan.collectFirst { - case scan: DataSourceV2ScanExec + case scan: DataSourceV2StreamingScanExec if scan.readSupport.isInstanceOf[KafkaContinuousReadSupport] => scan.scanConfig.asInstanceOf[KafkaContinuousScanConfig] }.exists(_.knownPartitions.size == newCount), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 5e97d826370f7..fcef424c330f1 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -197,37 +197,6 @@ object MimaExcludes { // [SPARK-23781][CORE] Merge token renewer functionality into HadoopDelegationTokenManager ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.nextCredentialRenewalTime"), - // Data Source V2 API changes - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.ContinuousReadSupport"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.ReadSupport"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.WriteSupport"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.StreamWriteSupport"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.MicroBatchReadSupport"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.DataSourceReader"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.sources.v2.reader.SupportsPushDownRequiredColumns"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder.build"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.InputPartition.createPartitionReader"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics.estimateStatistics"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.ReadSupport.fullSchema"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.ReadSupport.planInputPartitions"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning.outputPartitioning"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning.outputPartitioning"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.ReadSupport.fullSchema"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.ReadSupport.planInputPartitions"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.sources.v2.reader.SupportsPushDownFilters"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder.build"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.ContinuousInputPartition"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.InputPartitionReader"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.DataSourceWriter"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.writer.DataWriterFactory.createWriter"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter"), - // [SPARK-26133][ML] Remove deprecated OneHotEncoder and rename OneHotEncoderEstimator to OneHotEncoder ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.feature.OneHotEncoderEstimator"), ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.OneHotEncoder"), @@ -243,7 +212,22 @@ object MimaExcludes { // [SPARK-26141] Enable custom metrics implementation in shuffle write // Following are Java private classes ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.UnsafeShuffleWriter.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.TimeTrackingOutputStream.this") + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.TimeTrackingOutputStream.this"), + + // Data Source V2 API changes + (problem: Problem) => problem match { + case MissingClassProblem(cls) => + !cls.fullName.startsWith("org.apache.spark.sql.sources.v2") + case MissingTypesProblem(newCls, _) => + !newCls.fullName.startsWith("org.apache.spark.sql.sources.v2") + case InheritedNewAbstractMethodProblem(cls, _) => + !cls.fullName.startsWith("org.apache.spark.sql.sources.v2") + case DirectMissingMethodProblem(meth) => + !meth.owner.fullName.startsWith("org.apache.spark.sql.sources.v2") + case ReversedMissingMethodProblem(meth) => + !meth.owner.fullName.startsWith("org.apache.spark.sql.sources.v2") + case _ => true + } ) // Exclude rules for 2.4.x diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java new file mode 100644 index 0000000000000..0df89dbb608a4 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.sources.v2.reader.Scan; +import org.apache.spark.sql.sources.v2.reader.ScanBuilder; + +/** + * An empty mix-in interface for {@link Table}, to indicate this table supports batch scan. + *

        + * If a {@link Table} implements this interface, its {@link Table#newScanBuilder(DataSourceOptions)} + * must return a {@link ScanBuilder} that builds {@link Scan} with {@link Scan#toBatch()} + * implemented. + *

        + */ +@Evolving +public interface SupportsBatchRead extends Table { } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java new file mode 100644 index 0000000000000..0c65fe0f9e76a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.sources.v2.reader.Scan; +import org.apache.spark.sql.sources.v2.reader.ScanBuilder; +import org.apache.spark.sql.types.StructType; + +/** + * An interface representing a logical structured data set of a data source. For example, the + * implementation can be a directory on the file system, a topic of Kafka, or a table in the + * catalog, etc. + *

        + * This interface can mixin the following interfaces to support different operations: + *

        + *
          + *
        • {@link SupportsBatchRead}: this table can be read in batch queries.
        • + *
        + */ +@Evolving +public interface Table { + + /** + * A name to identify this table. Implementations should provide a meaningful name, like the + * database and table name from catalog, or the location of files for this table. + */ + String name(); + + /** + * Returns the schema of this table. + */ + StructType schema(); + + /** + * Returns a {@link ScanBuilder} which can be used to build a {@link Scan} later. Spark will call + * this method for each data scanning query. + *

        + * The builder can take some query specific information to do operators pushdown, and keep these + * information in the created {@link Scan}. + *

        + */ + ScanBuilder newScanBuilder(DataSourceOptions options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java new file mode 100644 index 0000000000000..855d5efe0c69f --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.sources.DataSourceRegister; +import org.apache.spark.sql.types.StructType; + +/** + * The base interface for v2 data sources which don't have a real catalog. Implementations must + * have a public, 0-arg constructor. + *

        + * The major responsibility of this interface is to return a {@link Table} for read/write. + *

        + */ +@Evolving +// TODO: do not extend `DataSourceV2`, after we finish the API refactor completely. +public interface TableProvider extends DataSourceV2 { + + /** + * Return a {@link Table} instance to do read/write with user-specified options. + * + * @param options the user-specified options that can identify a table, e.g. file path, Kafka + * topic name, etc. It's an immutable case-insensitive string-to-string map. + */ + Table getTable(DataSourceOptions options); + + /** + * Return a {@link Table} instance to do read/write with user-specified schema and options. + *

        + * By default this method throws {@link UnsupportedOperationException}, implementations should + * override this method to handle user-specified schema. + *

        + * @param options the user-specified options that can identify a table, e.g. file path, Kafka + * topic name, etc. It's an immutable case-insensitive string-to-string map. + * @param schema the user-specified schema. + * @throws UnsupportedOperationException + */ + default Table getTable(DataSourceOptions options, StructType schema) { + String name; + if (this instanceof DataSourceRegister) { + name = ((DataSourceRegister) this).shortName(); + } else { + name = this.getClass().getName(); + } + throw new UnsupportedOperationException( + name + " source does not support user-specified schema"); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Batch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Batch.java new file mode 100644 index 0000000000000..bcfa1983abb8b --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Batch.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.Evolving; + +/** + * A physical representation of a data source scan for batch queries. This interface is used to + * provide physical information, like how many partitions the scanned data has, and how to read + * records from the partitions. + */ +@Evolving +public interface Batch { + + /** + * Returns a list of {@link InputPartition input partitions}. Each {@link InputPartition} + * represents a data split that can be processed by one Spark task. The number of input + * partitions returned here is the same as the number of RDD partitions this scan outputs. + *

        + * If the {@link Scan} supports filter pushdown, this Batch is likely configured with a filter + * and is responsible for creating splits for that filter, which is not a full scan. + *

        + *

        + * This method will be called only once during a data source scan, to launch one Spark job. + *

        + */ + InputPartition[] planInputPartitions(); + + /** + * Returns a factory, which produces one {@link PartitionReader} for one {@link InputPartition}. + */ + PartitionReaderFactory createReaderFactory(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/OldSupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/OldSupportsReportPartitioning.java new file mode 100644 index 0000000000000..347a465905acc --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/OldSupportsReportPartitioning.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; + +/** + * A mix in interface for {@link BatchReadSupport}. Data sources can implement this interface to + * report data partitioning and try to avoid shuffle at Spark side. + * + * Note that, when a {@link ReadSupport} implementation creates exactly one {@link InputPartition}, + * Spark may avoid adding a shuffle even if the reader does not implement this interface. + */ +@Evolving +// TODO: remove it, after we finish the API refactor completely. +public interface OldSupportsReportPartitioning extends ReadSupport { + + /** + * Returns the output data partitioning that this reader guarantees. + */ + Partitioning outputPartitioning(ScanConfig config); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/OldSupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/OldSupportsReportStatistics.java new file mode 100644 index 0000000000000..0d3ec17107c13 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/OldSupportsReportStatistics.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.Evolving; + +/** + * A mix in interface for {@link BatchReadSupport}. Data sources can implement this interface to + * report statistics to Spark. + * + * As of Spark 2.4, statistics are reported to the optimizer before any operator is pushed to the + * data source. Implementations that return more accurate statistics based on pushed operators will + * not improve query performance until the planner can push operators before getting stats. + */ +@Evolving +// TODO: remove it, after we finish the API refactor completely. +public interface OldSupportsReportStatistics extends ReadSupport { + + /** + * Returns the estimated statistics of this data source scan. + */ + Statistics estimateStatistics(ScanConfig config); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java new file mode 100644 index 0000000000000..4d84fb19aa022 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.sources.v2.SupportsBatchRead; +import org.apache.spark.sql.sources.v2.Table; + +/** + * A logical representation of a data source scan. This interface is used to provide logical + * information, like what the actual read schema is. + *

        + * This logical representation is shared between batch scan, micro-batch streaming scan and + * continuous streaming scan. Data sources must implement the corresponding methods in this + * interface, to match what the table promises to support. For example, {@link #toBatch()} must be + * implemented, if the {@link Table} that creates this {@link Scan} implements + * {@link SupportsBatchRead}. + *

        + */ +@Evolving +public interface Scan { + + /** + * Returns the actual schema of this data source scan, which may be different from the physical + * schema of the underlying storage, as column pruning or other optimizations may happen. + */ + StructType readSchema(); + + /** + * A description string of this scan, which may includes information like: what filters are + * configured for this scan, what's the value of some important options like path, etc. The + * description doesn't need to include {@link #readSchema()}, as Spark already knows it. + *

        + * By default this returns the class name of the implementation. Please override it to provide a + * meaningful description. + *

        + */ + default String description() { + return this.getClass().toString(); + } + + /** + * Returns the physical representation of this scan for batch query. By default this method throws + * exception, data sources must overwrite this method to provide an implementation, if the + * {@link Table} that creates this scan implements {@link SupportsBatchRead}. + * + * @throws UnsupportedOperationException + */ + default Batch toBatch() { + throw new UnsupportedOperationException("Batch scans are not supported"); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanBuilder.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanBuilder.java new file mode 100644 index 0000000000000..d4bc1ff977132 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanBuilder.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.Evolving; + +/** + * An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ + * interfaces to do operator pushdown, and keep the operator pushdown result in the returned + * {@link Scan}. + */ +@Evolving +public interface ScanBuilder { + Scan build(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java index a69872a527746..c8cff68c2ef76 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java @@ -28,8 +28,8 @@ * For APIs that take a {@link ScanConfig} as input, like * {@link ReadSupport#planInputPartitions(ScanConfig)}, * {@link BatchReadSupport#createReaderFactory(ScanConfig)} and - * {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}, implementations mostly need to - * cast the input {@link ScanConfig} to the concrete {@link ScanConfig} class of the data source. + * {@link OldSupportsReportStatistics#estimateStatistics(ScanConfig)}, implementations mostly need + * to cast the input {@link ScanConfig} to the concrete {@link ScanConfig} class of the data source. */ @Evolving public interface ScanConfig { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java index 14776f37fed46..a0b194a41f585 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java @@ -23,7 +23,7 @@ /** * An interface to represent statistics for a data source, which is returned by - * {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}. + * {@link SupportsReportStatistics#estimateStatistics()}. */ @Evolving public interface Statistics { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index 3a89baa1b44c2..296d3e47e732b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -21,11 +21,11 @@ import org.apache.spark.sql.sources.Filter; /** - * A mix-in interface for {@link ScanConfigBuilder}. Data sources can implement this interface to + * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to * push down filters to the data source and reduce the size of the data to be read. */ @Evolving -public interface SupportsPushDownFilters extends ScanConfigBuilder { +public interface SupportsPushDownFilters extends ScanBuilder { /** * Pushes down filters, and returns filters that need to be evaluated after scanning. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java index 1934763224881..60e71c5dd008a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java @@ -21,12 +21,12 @@ import org.apache.spark.sql.types.StructType; /** - * A mix-in interface for {@link ScanConfigBuilder}. Data sources can implement this + * A mix-in interface for {@link ScanBuilder}. Data sources can implement this * interface to push down required columns to the data source and only read these columns during * scan to reduce the size of the data to be read. */ @Evolving -public interface SupportsPushDownRequiredColumns extends ScanConfigBuilder { +public interface SupportsPushDownRequiredColumns extends ScanBuilder { /** * Applies column pruning w.r.t. the given requiredSchema. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index 0335c7775c2af..ba175812a88d7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -21,17 +21,17 @@ import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; /** - * A mix in interface for {@link BatchReadSupport}. Data sources can implement this interface to + * A mix in interface for {@link Batch}. Data sources can implement this interface to * report data partitioning and try to avoid shuffle at Spark side. * - * Note that, when a {@link ReadSupport} implementation creates exactly one {@link InputPartition}, + * Note that, when a {@link Batch} implementation creates exactly one {@link InputPartition}, * Spark may avoid adding a shuffle even if the reader does not implement this interface. */ @Evolving -public interface SupportsReportPartitioning extends ReadSupport { +public interface SupportsReportPartitioning extends Batch { /** * Returns the output data partitioning that this reader guarantees. */ - Partitioning outputPartitioning(ScanConfig config); + Partitioning outputPartitioning(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java index 917372cdd25b3..d9f5fb64083ad 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -20,7 +20,7 @@ import org.apache.spark.annotation.Evolving; /** - * A mix in interface for {@link BatchReadSupport}. Data sources can implement this interface to + * A mix in interface for {@link Batch}. Data sources can implement this interface to * report statistics to Spark. * * As of Spark 2.4, statistics are reported to the optimizer before any operator is pushed to the @@ -28,10 +28,10 @@ * not improve query performance until the planner can push operators before getting stats. */ @Evolving -public interface SupportsReportStatistics extends ReadSupport { +public interface SupportsReportStatistics extends Batch { /** * Returns the estimated statistics of this data source scan. */ - Statistics estimateStatistics(ScanConfig config); + Statistics estimateStatistics(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java index c9a00262c1287..c7370eb3d38af 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java @@ -19,12 +19,11 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.reader.InputPartition; -import org.apache.spark.sql.sources.v2.reader.ScanConfig; import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning; /** * An interface to represent the output data partitioning for a data source, which is returned by - * {@link SupportsReportPartitioning#outputPartitioning(ScanConfig)}. Note that this should work + * {@link SupportsReportPartitioning#outputPartitioning()}. Note that this should work * like a snapshot. Once created, it should be deterministic and always report the same number of * partitions and the same "satisfy" result for a certain distribution. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index da88598eed061..661fe98d8c901 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.sources.v2.{BatchReadSupportProvider, DataSourceOptions, DataSourceV2} +import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -194,20 +194,26 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) - if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.getConstructor().newInstance().asInstanceOf[DataSourceV2] - if (ds.isInstanceOf[BatchReadSupportProvider]) { - val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - ds = ds, conf = sparkSession.sessionState.conf) - val pathsOption = { - val objectMapper = new ObjectMapper() - DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray) - } - Dataset.ofRows(sparkSession, DataSourceV2Relation.create( - ds, sessionOptions ++ extraOptions.toMap + pathsOption, - userSpecifiedSchema = userSpecifiedSchema)) - } else { - loadV1Source(paths: _*) + if (classOf[TableProvider].isAssignableFrom(cls)) { + val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider] + val sessionOptions = DataSourceV2Utils.extractSessionConfigs( + ds = provider, conf = sparkSession.sessionState.conf) + val pathsOption = { + val objectMapper = new ObjectMapper() + DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray) + } + val finalOptions = sessionOptions ++ extraOptions.toMap + pathsOption + val dsOptions = new DataSourceOptions(finalOptions.asJava) + val table = userSpecifiedSchema match { + case Some(schema) => provider.getTable(dsOptions, schema) + case _ => provider.getTable(dsOptions) + } + table match { + case s: SupportsBatchRead => + Dataset.ofRows(sparkSession, DataSourceV2Relation.create( + provider, s, finalOptions, userSpecifiedSchema = userSpecifiedSchema)) + + case _ => loadV1Source(paths: _*) } } else { loadV1Source(paths: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 5a807d3d4b93e..b9c4076994e96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -252,7 +252,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val options = sessionOptions ++ extraOptions if (mode == SaveMode.Append) { - val relation = DataSourceV2Relation.create(source, options) + val relation = DataSourceV2Relation.createRelationForWrite(source, options) runCommand(df.sparkSession, "save") { AppendData.byName(relation, df.logicalPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index f7e29593a6353..0a6b0afe6cfe5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -22,13 +22,13 @@ import java.util.UUID import scala.collection.JavaConverters._ import org.apache.spark.sql.{AnalysisException, SaveMode} -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{BatchReadSupportProvider, BatchWriteSupportProvider, DataSourceOptions, DataSourceV2} -import org.apache.spark.sql.sources.v2.reader.{BatchReadSupport, ReadSupport, ScanConfigBuilder, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.writer.BatchWriteSupport import org.apache.spark.sql.types.StructType @@ -40,32 +40,38 @@ import org.apache.spark.sql.types.StructType * @param userSpecifiedSchema The user-specified schema for this scan. */ case class DataSourceV2Relation( - source: DataSourceV2, - readSupport: BatchReadSupport, + // TODO: remove `source` when we finish API refactor for write. + source: TableProvider, + table: SupportsBatchRead, output: Seq[AttributeReference], options: Map[String, String], - tableIdent: Option[TableIdentifier] = None, userSpecifiedSchema: Option[StructType] = None) - extends LeafNode with MultiInstanceRelation with NamedRelation with DataSourceV2StringFormat { + extends LeafNode with MultiInstanceRelation with NamedRelation { import DataSourceV2Relation._ - override def name: String = { - tableIdent.map(_.unquotedString).getOrElse(s"${source.name}:unknown") - } - - override def pushedFilters: Seq[Expression] = Seq.empty + override def name: String = table.name() - override def simpleString: String = "RelationV2 " + metadataString + override def simpleString: String = { + s"RelationV2${truncatedString(output, "[", ", ", "]")} $name" + } def newWriteSupport(): BatchWriteSupport = source.createWriteSupport(options, schema) - override def computeStats(): Statistics = readSupport match { - case r: SupportsReportStatistics => - val statistics = r.estimateStatistics(readSupport.newScanConfigBuilder().build()) - Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) - case _ => - Statistics(sizeInBytes = conf.defaultSizeInBytes) + def newScanBuilder(): ScanBuilder = { + val dsOptions = new DataSourceOptions(options.asJava) + table.newScanBuilder(dsOptions) + } + + override def computeStats(): Statistics = { + val scan = newScanBuilder().build() + scan match { + case r: SupportsReportStatistics => + val statistics = r.estimateStatistics() + Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + case _ => + Statistics(sizeInBytes = conf.defaultSizeInBytes) + } } override def newInstance(): DataSourceV2Relation = { @@ -109,7 +115,7 @@ case class StreamingDataSourceV2Relation( } override def computeStats(): Statistics = readSupport match { - case r: SupportsReportStatistics => + case r: OldSupportsReportStatistics => val statistics = r.estimateStatistics(scanConfigBuilder.build()) Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) case _ => @@ -119,15 +125,6 @@ case class StreamingDataSourceV2Relation( object DataSourceV2Relation { private implicit class SourceHelpers(source: DataSourceV2) { - def asReadSupportProvider: BatchReadSupportProvider = { - source match { - case provider: BatchReadSupportProvider => - provider - case _ => - throw new AnalysisException(s"Data source is not readable: $name") - } - } - def asWriteSupportProvider: BatchWriteSupportProvider = { source match { case provider: BatchWriteSupportProvider => @@ -146,18 +143,6 @@ object DataSourceV2Relation { } } - def createReadSupport( - options: Map[String, String], - userSpecifiedSchema: Option[StructType]): BatchReadSupport = { - val v2Options = new DataSourceOptions(options.asJava) - userSpecifiedSchema match { - case Some(s) => - asReadSupportProvider.createBatchReadSupport(s, v2Options) - case _ => - asReadSupportProvider.createBatchReadSupport(v2Options) - } - } - def createWriteSupport( options: Map[String, String], schema: StructType): BatchWriteSupport = { @@ -170,20 +155,21 @@ object DataSourceV2Relation { } def create( - source: DataSourceV2, + provider: TableProvider, + table: SupportsBatchRead, options: Map[String, String], - tableIdent: Option[TableIdentifier] = None, userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { - val readSupport = source.createReadSupport(options, userSpecifiedSchema) - val output = readSupport.fullSchema().toAttributes - val ident = tableIdent.orElse(tableFromOptions(options)) - DataSourceV2Relation( - source, readSupport, output, options, ident, userSpecifiedSchema) + val output = table.schema().toAttributes + DataSourceV2Relation(provider, table, output, options, userSpecifiedSchema) } - private def tableFromOptions(options: Map[String, String]): Option[TableIdentifier] = { - options - .get(DataSourceOptions.TABLE_KEY) - .map(TableIdentifier(_, options.get(DataSourceOptions.DATABASE_KEY))) + // TODO: remove this when we finish API refactor for write. + def createRelationForWrite( + source: DataSourceV2, + options: Map[String, String]): DataSourceV2Relation = { + val provider = source.asInstanceOf[TableProvider] + val dsOptions = new DataSourceOptions(options.asJava) + val table = provider.getTable(dsOptions) + create(provider, table.asInstanceOf[SupportsBatchRead], options) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 25f86a66a8269..725bcc3af3ca5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -22,60 +22,47 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.plans.physical.SinglePartition +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} -import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReaderFactory, ContinuousReadSupport, MicroBatchReadSupport} /** - * Physical plan node for scanning data from a data source. + * Physical plan node for scanning a batch of data from a data source. */ case class DataSourceV2ScanExec( output: Seq[AttributeReference], - @transient source: DataSourceV2, - @transient options: Map[String, String], - @transient pushedFilters: Seq[Expression], - @transient readSupport: ReadSupport, - @transient scanConfig: ScanConfig) - extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan { + scanDesc: String, + @transient batch: Batch) + extends LeafExecNode with ColumnarBatchScan { - override def simpleString: String = "ScanV2 " + metadataString + override def simpleString: String = { + s"ScanV2${truncatedString(output, "[", ", ", "]")} $scanDesc" + } // TODO: unify the equal/hashCode implementation for all data source v2 query plans. override def equals(other: Any): Boolean = other match { - case other: DataSourceV2ScanExec => - output == other.output && readSupport.getClass == other.readSupport.getClass && - options == other.options + case other: DataSourceV2ScanExec => this.batch == other.batch case _ => false } - override def hashCode(): Int = { - Seq(output, source, options).hashCode() - } + override def hashCode(): Int = batch.hashCode() + + private lazy val partitions = batch.planInputPartitions() + + private lazy val readerFactory = batch.createReaderFactory() - override def outputPartitioning: physical.Partitioning = readSupport match { + override def outputPartitioning: physical.Partitioning = batch match { case _ if partitions.length == 1 => SinglePartition case s: SupportsReportPartitioning => new DataSourcePartitioning( - s.outputPartitioning(scanConfig), AttributeMap(output.map(a => a -> a.name))) + s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name))) case _ => super.outputPartitioning } - private lazy val partitions: Seq[InputPartition] = readSupport.planInputPartitions(scanConfig) - - private lazy val readerFactory = readSupport match { - case r: BatchReadSupport => r.createReaderFactory(scanConfig) - case r: MicroBatchReadSupport => r.createReaderFactory(scanConfig) - case r: ContinuousReadSupport => r.createContinuousReaderFactory(scanConfig) - case _ => throw new IllegalStateException("unknown read support: " + readSupport) - } - - // TODO: clean this up when we have dedicated scan plan for continuous streaming. - override val supportsBatch: Boolean = { + override def supportsBatch: Boolean = { require(partitions.forall(readerFactory.supportColumnarReads) || !partitions.exists(readerFactory.supportColumnarReads), "Cannot mix row-based and columnar input partitions.") @@ -83,25 +70,8 @@ case class DataSourceV2ScanExec( partitions.exists(readerFactory.supportColumnarReads) } - private lazy val inputRDD: RDD[InternalRow] = readSupport match { - case _: ContinuousReadSupport => - assert(!supportsBatch, - "continuous stream reader does not support columnar read yet.") - EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), - sparkContext.env) - .askSync[Unit](SetReaderPartitions(partitions.size)) - new ContinuousDataSourceRDD( - sparkContext, - sqlContext.conf.continuousStreamingExecutorQueueSize, - sqlContext.conf.continuousStreamingExecutorPollIntervalMs, - partitions, - schema, - readerFactory.asInstanceOf[ContinuousPartitionReaderFactory]) - - case _ => - new DataSourceRDD( - sparkContext, partitions, readerFactory.asInstanceOf[PartitionReaderFactory], supportsBatch) + private lazy val inputRDD: RDD[InternalRow] = { + new DataSourceRDD(sparkContext, partitions, readerFactory, supportsBatch) } override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 9a3109e7c199e..2e26fce880b68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -37,9 +37,9 @@ object DataSourceV2Strategy extends Strategy { * @return pushed filter and post-scan filters. */ private def pushFilters( - configBuilder: ScanConfigBuilder, + scanBuilder: ScanBuilder, filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { - configBuilder match { + scanBuilder match { case r: SupportsPushDownFilters => // A map from translated data source filters to original catalyst filter expressions. val translatedFilterToExpr = mutable.HashMap.empty[sources.Filter, Expression] @@ -76,18 +76,18 @@ object DataSourceV2Strategy extends Strategy { */ // TODO: nested column pruning. private def pruneColumns( - configBuilder: ScanConfigBuilder, + scanBuilder: ScanBuilder, relation: DataSourceV2Relation, - exprs: Seq[Expression]): (ScanConfig, Seq[AttributeReference]) = { - configBuilder match { + exprs: Seq[Expression]): (Scan, Seq[AttributeReference]) = { + scanBuilder match { case r: SupportsPushDownRequiredColumns => val requiredColumns = AttributeSet(exprs.flatMap(_.references)) val neededOutput = relation.output.filter(requiredColumns.contains) if (neededOutput != relation.output) { r.pruneColumns(neededOutput.toStructType) - val config = r.build() + val scan = r.build() val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap - config -> config.readSchema().toAttributes.map { + scan -> scan.readSchema().toAttributes.map { // We have to keep the attribute id during transformation. a => a.withExprId(nameToAttr(a.name).exprId) } @@ -95,19 +95,19 @@ object DataSourceV2Strategy extends Strategy { r.build() -> relation.output } - case _ => configBuilder.build() -> relation.output + case _ => scanBuilder.build() -> relation.output } } override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => - val configBuilder = relation.readSupport.newScanConfigBuilder() + val scanBuilder = relation.newScanBuilder() // `pushedFilters` will be pushed down and evaluated in the underlying data sources. // `postScanFilters` need to be evaluated after the scan. // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. - val (pushedFilters, postScanFilters) = pushFilters(configBuilder, filters) - val (config, output) = pruneColumns(configBuilder, relation, project ++ postScanFilters) + val (pushedFilters, postScanFilters) = pushFilters(scanBuilder, filters) + val (scan, output) = pruneColumns(scanBuilder, relation, project ++ postScanFilters) logInfo( s""" |Pushing operators to ${relation.source.getClass} @@ -116,16 +116,10 @@ object DataSourceV2Strategy extends Strategy { |Output: ${output.mkString(", ")} """.stripMargin) - val scan = DataSourceV2ScanExec( - output, - relation.source, - relation.options, - pushedFilters, - relation.readSupport, - config) + val plan = DataSourceV2ScanExec(output, scan.description(), scan.toBatch) val filterCondition = postScanFilters.reduceLeftOption(And) - val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) + val withFilter = filterCondition.map(FilterExec(_, plan)).getOrElse(plan) // always add the projection, which will produce unsafe rows required by some operators ProjectExec(project, withFilter) :: Nil @@ -135,7 +129,7 @@ object DataSourceV2Strategy extends Strategy { val scanConfig = r.scanConfigBuilder.build() // ensure there is a projection, which will produce unsafe rows required by some operators ProjectExec(r.output, - DataSourceV2ScanExec( + DataSourceV2StreamingScanExec( r.output, r.source, r.options, r.pushedFilters, r.readSupport, scanConfig)) :: Nil case WriteToDataSourceV2(writer, query) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StreamingScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StreamingScanExec.scala new file mode 100644 index 0000000000000..c872940909964 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StreamingScanExec.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.catalyst.plans.physical.SinglePartition +import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.DataSourceV2 +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReaderFactory, ContinuousReadSupport, MicroBatchReadSupport} + +/** + * Physical plan node for scanning data from a data source. + */ +// TODO: micro-batch should be handled by `DataSourceV2ScanExec`, after we finish the API refactor +// completely. +case class DataSourceV2StreamingScanExec( + output: Seq[AttributeReference], + @transient source: DataSourceV2, + @transient options: Map[String, String], + @transient pushedFilters: Seq[Expression], + @transient readSupport: ReadSupport, + @transient scanConfig: ScanConfig) + extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan { + + override def simpleString: String = "ScanV2 " + metadataString + + // TODO: unify the equal/hashCode implementation for all data source v2 query plans. + override def equals(other: Any): Boolean = other match { + case other: DataSourceV2StreamingScanExec => + output == other.output && readSupport.getClass == other.readSupport.getClass && + options == other.options + case _ => false + } + + override def hashCode(): Int = { + Seq(output, source, options).hashCode() + } + + override def outputPartitioning: physical.Partitioning = readSupport match { + case _ if partitions.length == 1 => + SinglePartition + + case s: OldSupportsReportPartitioning => + new DataSourcePartitioning( + s.outputPartitioning(scanConfig), AttributeMap(output.map(a => a -> a.name))) + + case _ => super.outputPartitioning + } + + private lazy val partitions: Seq[InputPartition] = readSupport.planInputPartitions(scanConfig) + + private lazy val readerFactory = readSupport match { + case r: MicroBatchReadSupport => r.createReaderFactory(scanConfig) + case r: ContinuousReadSupport => r.createContinuousReaderFactory(scanConfig) + case _ => throw new IllegalStateException("unknown read support: " + readSupport) + } + + override val supportsBatch: Boolean = { + require(partitions.forall(readerFactory.supportColumnarReads) || + !partitions.exists(readerFactory.supportColumnarReads), + "Cannot mix row-based and columnar input partitions.") + + partitions.exists(readerFactory.supportColumnarReads) + } + + private lazy val inputRDD: RDD[InternalRow] = readSupport match { + case _: ContinuousReadSupport => + assert(!supportsBatch, + "continuous stream reader does not support columnar read yet.") + EpochCoordinatorRef.get( + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + sparkContext.env) + .askSync[Unit](SetReaderPartitions(partitions.size)) + new ContinuousDataSourceRDD( + sparkContext, + sqlContext.conf.continuousStreamingExecutorQueueSize, + sqlContext.conf.continuousStreamingExecutorPollIntervalMs, + partitions, + schema, + readerFactory.asInstanceOf[ContinuousPartitionReaderFactory]) + + case _ => + new DataSourceRDD( + sparkContext, partitions, readerFactory.asInstanceOf[PartitionReaderFactory], supportsBatch) + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) + + override protected def doExecute(): RDD[InternalRow] = { + if (supportsBatch) { + WholeStageCodegenExec(this)(codegenStageId = 0).execute() + } else { + val numOutputRows = longMetric("numOutputRows") + inputRDD.map { r => + numOutputRows += 1 + r + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 392229bcb5f55..6a22f0cc8431a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2StreamingScanExec import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent @@ -256,7 +256,7 @@ trait ProgressReporter extends Logging { // (can happen with self-unions or self-joins). This means the source is scanned multiple // times in the query, we should count the numRows for each scan. val sourceToInputRowsTuples = lastExecution.executedPlan.collect { - case s: DataSourceV2ScanExec if s.readSupport.isInstanceOf[BaseStreamingSource] => + case s: DataSourceV2StreamingScanExec if s.readSupport.isInstanceOf[BaseStreamingSource] => val numRows = s.metrics.get("numOutputRows").map(_.value).getOrElse(0L) val source = s.readSupport.asInstanceOf[BaseStreamingSource] source -> numRows diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 1eab55122e84b..af23c5cd3d80a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Curre import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, StreamingDataSourceV2Relation} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2StreamingScanExec, StreamingDataSourceV2Relation} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2 import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, StreamingWriteSupportProvider} @@ -206,7 +206,8 @@ class ContinuousExecution( } val (readSupport, scanConfig) = lastExecution.executedPlan.collect { - case scan: DataSourceV2ScanExec if scan.readSupport.isInstanceOf[ContinuousReadSupport] => + case scan: DataSourceV2StreamingScanExec + if scan.readSupport.isInstanceOf[ContinuousReadSupport] => scan.readSupport.asInstanceOf[ContinuousReadSupport] -> scan.scanConfig }.head diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 5602310219a74..2612b6185fd4c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -24,62 +24,29 @@ import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.sources.GreaterThan; -import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.Table; +import org.apache.spark.sql.sources.v2.TableProvider; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; -public class JavaAdvancedDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { +public class JavaAdvancedDataSourceV2 implements TableProvider { - public class ReadSupport extends JavaSimpleReadSupport { - @Override - public ScanConfigBuilder newScanConfigBuilder() { - return new AdvancedScanConfigBuilder(); - } - - @Override - public InputPartition[] planInputPartitions(ScanConfig config) { - Filter[] filters = ((AdvancedScanConfigBuilder) config).filters; - List res = new ArrayList<>(); - - Integer lowerBound = null; - for (Filter filter : filters) { - if (filter instanceof GreaterThan) { - GreaterThan f = (GreaterThan) filter; - if ("i".equals(f.attribute()) && f.value() instanceof Integer) { - lowerBound = (Integer) f.value(); - break; - } - } - } - - if (lowerBound == null) { - res.add(new JavaRangeInputPartition(0, 5)); - res.add(new JavaRangeInputPartition(5, 10)); - } else if (lowerBound < 4) { - res.add(new JavaRangeInputPartition(lowerBound + 1, 5)); - res.add(new JavaRangeInputPartition(5, 10)); - } else if (lowerBound < 9) { - res.add(new JavaRangeInputPartition(lowerBound + 1, 10)); + @Override + public Table getTable(DataSourceOptions options) { + return new JavaSimpleBatchTable() { + @Override + public ScanBuilder newScanBuilder(DataSourceOptions options) { + return new AdvancedScanBuilder(); } - - return res.stream().toArray(InputPartition[]::new); - } - - @Override - public PartitionReaderFactory createReaderFactory(ScanConfig config) { - StructType requiredSchema = ((AdvancedScanConfigBuilder) config).requiredSchema; - return new AdvancedReaderFactory(requiredSchema); - } + }; } - public static class AdvancedScanConfigBuilder implements ScanConfigBuilder, ScanConfig, + static class AdvancedScanBuilder implements ScanBuilder, Scan, SupportsPushDownFilters, SupportsPushDownRequiredColumns { - // Exposed for testing. - public StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); - public Filter[] filters = new Filter[0]; + private StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); + private Filter[] filters = new Filter[0]; @Override public void pruneColumns(StructType requiredSchema) { @@ -121,9 +88,58 @@ public Filter[] pushedFilters() { } @Override - public ScanConfig build() { + public Scan build() { return this; } + + @Override + public Batch toBatch() { + return new AdvancedBatch(requiredSchema, filters); + } + } + + public static class AdvancedBatch implements Batch { + // Exposed for testing. + public StructType requiredSchema; + public Filter[] filters; + + AdvancedBatch(StructType requiredSchema, Filter[] filters) { + this.requiredSchema = requiredSchema; + this.filters = filters; + } + + @Override + public InputPartition[] planInputPartitions() { + List res = new ArrayList<>(); + + Integer lowerBound = null; + for (Filter filter : filters) { + if (filter instanceof GreaterThan) { + GreaterThan f = (GreaterThan) filter; + if ("i".equals(f.attribute()) && f.value() instanceof Integer) { + lowerBound = (Integer) f.value(); + break; + } + } + } + + if (lowerBound == null) { + res.add(new JavaRangeInputPartition(0, 5)); + res.add(new JavaRangeInputPartition(5, 10)); + } else if (lowerBound < 4) { + res.add(new JavaRangeInputPartition(lowerBound + 1, 5)); + res.add(new JavaRangeInputPartition(5, 10)); + } else if (lowerBound < 9) { + res.add(new JavaRangeInputPartition(lowerBound + 1, 10)); + } + + return res.stream().toArray(InputPartition[]::new); + } + + @Override + public PartitionReaderFactory createReaderFactory() { + return new AdvancedReaderFactory(requiredSchema); + } } static class AdvancedReaderFactory implements PartitionReaderFactory { @@ -165,10 +181,4 @@ public void close() throws IOException { }; } } - - - @Override - public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { - return new ReadSupport(); - } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java index 28a9330398310..d72ab5338aa8c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java @@ -21,21 +21,21 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; -import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.Table; +import org.apache.spark.sql.sources.v2.TableProvider; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarBatch; -public class JavaColumnarDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { +public class JavaColumnarDataSourceV2 implements TableProvider { - class ReadSupport extends JavaSimpleReadSupport { + class MyScanBuilder extends JavaSimpleScanBuilder { @Override - public InputPartition[] planInputPartitions(ScanConfig config) { + public InputPartition[] planInputPartitions() { InputPartition[] partitions = new InputPartition[2]; partitions[0] = new JavaRangeInputPartition(0, 50); partitions[1] = new JavaRangeInputPartition(50, 90); @@ -43,11 +43,21 @@ public InputPartition[] planInputPartitions(ScanConfig config) { } @Override - public PartitionReaderFactory createReaderFactory(ScanConfig config) { + public PartitionReaderFactory createReaderFactory() { return new ColumnarReaderFactory(); } } + @Override + public Table getTable(DataSourceOptions options) { + return new JavaSimpleBatchTable() { + @Override + public ScanBuilder newScanBuilder(DataSourceOptions options) { + return new MyScanBuilder(); + } + }; + } + static class ColumnarReaderFactory implements PartitionReaderFactory { private static final int BATCH_SIZE = 20; @@ -106,9 +116,4 @@ public void close() throws IOException { }; } } - - @Override - public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { - return new ReadSupport(); - } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index 18a11dde82198..a513bfb26ef1c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -22,18 +22,20 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.sources.v2.*; +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.Table; +import org.apache.spark.sql.sources.v2.TableProvider; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.sources.v2.reader.partitioning.ClusteredDistribution; import org.apache.spark.sql.sources.v2.reader.partitioning.Distribution; import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; -public class JavaPartitionAwareDataSource implements DataSourceV2, BatchReadSupportProvider { +public class JavaPartitionAwareDataSource implements TableProvider { - class ReadSupport extends JavaSimpleReadSupport implements SupportsReportPartitioning { + class MyScanBuilder extends JavaSimpleScanBuilder implements SupportsReportPartitioning { @Override - public InputPartition[] planInputPartitions(ScanConfig config) { + public InputPartition[] planInputPartitions() { InputPartition[] partitions = new InputPartition[2]; partitions[0] = new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}); partitions[1] = new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2}); @@ -41,16 +43,26 @@ public InputPartition[] planInputPartitions(ScanConfig config) { } @Override - public PartitionReaderFactory createReaderFactory(ScanConfig config) { + public PartitionReaderFactory createReaderFactory() { return new SpecificReaderFactory(); } @Override - public Partitioning outputPartitioning(ScanConfig config) { + public Partitioning outputPartitioning() { return new MyPartitioning(); } } + @Override + public Table getTable(DataSourceOptions options) { + return new JavaSimpleBatchTable() { + @Override + public ScanBuilder newScanBuilder(DataSourceOptions options) { + return new MyScanBuilder(); + } + }; + } + static class MyPartitioning implements Partitioning { @Override @@ -106,9 +118,4 @@ public void close() throws IOException { }; } } - - @Override - public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { - return new ReadSupport(); - } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index cc9ac04a0dad3..815d57ba94139 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -17,39 +17,51 @@ package test.org.apache.spark.sql.sources.v2; -import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.Table; +import org.apache.spark.sql.sources.v2.TableProvider; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; -public class JavaSchemaRequiredDataSource implements DataSourceV2, BatchReadSupportProvider { +public class JavaSchemaRequiredDataSource implements TableProvider { - class ReadSupport extends JavaSimpleReadSupport { - private final StructType schema; + class MyScanBuilder extends JavaSimpleScanBuilder { - ReadSupport(StructType schema) { + private StructType schema; + + MyScanBuilder(StructType schema) { this.schema = schema; } @Override - public StructType fullSchema() { + public StructType readSchema() { return schema; } @Override - public InputPartition[] planInputPartitions(ScanConfig config) { + public InputPartition[] planInputPartitions() { return new InputPartition[0]; } } @Override - public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { - throw new IllegalArgumentException("requires a user-supplied schema"); + public Table getTable(DataSourceOptions options, StructType schema) { + return new JavaSimpleBatchTable() { + + @Override + public StructType schema() { + return schema; + } + + @Override + public ScanBuilder newScanBuilder(DataSourceOptions options) { + return new MyScanBuilder(schema); + } + }; } @Override - public BatchReadSupport createBatchReadSupport(StructType schema, DataSourceOptions options) { - return new ReadSupport(schema); + public Table getTable(DataSourceOptions options) { + throw new IllegalArgumentException("requires a user-supplied schema"); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java similarity index 78% rename from sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java rename to sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java index ced51dde6997b..cb5954d5a6211 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java @@ -21,43 +21,44 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.sources.v2.SupportsBatchRead; +import org.apache.spark.sql.sources.v2.Table; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; -abstract class JavaSimpleReadSupport implements BatchReadSupport { +abstract class JavaSimpleBatchTable implements Table, SupportsBatchRead { @Override - public StructType fullSchema() { + public StructType schema() { return new StructType().add("i", "int").add("j", "int"); } @Override - public ScanConfigBuilder newScanConfigBuilder() { - return new JavaNoopScanConfigBuilder(fullSchema()); - } - - @Override - public PartitionReaderFactory createReaderFactory(ScanConfig config) { - return new JavaSimpleReaderFactory(); + public String name() { + return this.getClass().toString(); } } -class JavaNoopScanConfigBuilder implements ScanConfigBuilder, ScanConfig { - - private StructType schema; +abstract class JavaSimpleScanBuilder implements ScanBuilder, Scan, Batch { - JavaNoopScanConfigBuilder(StructType schema) { - this.schema = schema; + @Override + public Scan build() { + return this; } @Override - public ScanConfig build() { + public Batch toBatch() { return this; } @Override public StructType readSchema() { - return schema; + return new StructType().add("i", "int").add("j", "int"); + } + + @Override + public PartitionReaderFactory createReaderFactory() { + return new JavaSimpleReaderFactory(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 2cdbba84ec4a4..852c4546df885 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -17,17 +17,17 @@ package test.org.apache.spark.sql.sources.v2; -import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; -import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.Table; +import org.apache.spark.sql.sources.v2.TableProvider; import org.apache.spark.sql.sources.v2.reader.*; -public class JavaSimpleDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { +public class JavaSimpleDataSourceV2 implements TableProvider { - class ReadSupport extends JavaSimpleReadSupport { + class MyScanBuilder extends JavaSimpleScanBuilder { @Override - public InputPartition[] planInputPartitions(ScanConfig config) { + public InputPartition[] planInputPartitions() { InputPartition[] partitions = new InputPartition[2]; partitions[0] = new JavaRangeInputPartition(0, 5); partitions[1] = new JavaRangeInputPartition(5, 10); @@ -36,7 +36,12 @@ public InputPartition[] planInputPartitions(ScanConfig config) { } @Override - public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { - return new ReadSupport(); + public Table getTable(DataSourceOptions options) { + return new JavaSimpleBatchTable() { + @Override + public ScanBuilder newScanBuilder(DataSourceOptions options) { + return new MyScanBuilder(); + } + }; } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index e8f291af13baf..d282193d35d76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -38,18 +38,17 @@ import org.apache.spark.sql.vectorized.ColumnarBatch class DataSourceV2Suite extends QueryTest with SharedSQLContext { import testImplicits._ - private def getScanConfig(query: DataFrame): AdvancedScanConfigBuilder = { + private def getBatch(query: DataFrame): AdvancedBatch = { query.queryExecution.executedPlan.collect { case d: DataSourceV2ScanExec => - d.scanConfig.asInstanceOf[AdvancedScanConfigBuilder] + d.batch.asInstanceOf[AdvancedBatch] }.head } - private def getJavaScanConfig( - query: DataFrame): JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder = { + private def getJavaBatch(query: DataFrame): JavaAdvancedDataSourceV2.AdvancedBatch = { query.queryExecution.executedPlan.collect { case d: DataSourceV2ScanExec => - d.scanConfig.asInstanceOf[JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder] + d.batch.asInstanceOf[JavaAdvancedDataSourceV2.AdvancedBatch] }.head } @@ -73,51 +72,51 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val q1 = df.select('j) checkAnswer(q1, (0 until 10).map(i => Row(-i))) if (cls == classOf[AdvancedDataSourceV2]) { - val config = getScanConfig(q1) - assert(config.filters.isEmpty) - assert(config.requiredSchema.fieldNames === Seq("j")) + val batch = getBatch(q1) + assert(batch.filters.isEmpty) + assert(batch.requiredSchema.fieldNames === Seq("j")) } else { - val config = getJavaScanConfig(q1) - assert(config.filters.isEmpty) - assert(config.requiredSchema.fieldNames === Seq("j")) + val batch = getJavaBatch(q1) + assert(batch.filters.isEmpty) + assert(batch.requiredSchema.fieldNames === Seq("j")) } val q2 = df.filter('i > 3) checkAnswer(q2, (4 until 10).map(i => Row(i, -i))) if (cls == classOf[AdvancedDataSourceV2]) { - val config = getScanConfig(q2) - assert(config.filters.flatMap(_.references).toSet == Set("i")) - assert(config.requiredSchema.fieldNames === Seq("i", "j")) + val batch = getBatch(q2) + assert(batch.filters.flatMap(_.references).toSet == Set("i")) + assert(batch.requiredSchema.fieldNames === Seq("i", "j")) } else { - val config = getJavaScanConfig(q2) - assert(config.filters.flatMap(_.references).toSet == Set("i")) - assert(config.requiredSchema.fieldNames === Seq("i", "j")) + val batch = getJavaBatch(q2) + assert(batch.filters.flatMap(_.references).toSet == Set("i")) + assert(batch.requiredSchema.fieldNames === Seq("i", "j")) } val q3 = df.select('i).filter('i > 6) checkAnswer(q3, (7 until 10).map(i => Row(i))) if (cls == classOf[AdvancedDataSourceV2]) { - val config = getScanConfig(q3) - assert(config.filters.flatMap(_.references).toSet == Set("i")) - assert(config.requiredSchema.fieldNames === Seq("i")) + val batch = getBatch(q3) + assert(batch.filters.flatMap(_.references).toSet == Set("i")) + assert(batch.requiredSchema.fieldNames === Seq("i")) } else { - val config = getJavaScanConfig(q3) - assert(config.filters.flatMap(_.references).toSet == Set("i")) - assert(config.requiredSchema.fieldNames === Seq("i")) + val batch = getJavaBatch(q3) + assert(batch.filters.flatMap(_.references).toSet == Set("i")) + assert(batch.requiredSchema.fieldNames === Seq("i")) } val q4 = df.select('j).filter('j < -10) checkAnswer(q4, Nil) if (cls == classOf[AdvancedDataSourceV2]) { - val config = getScanConfig(q4) + val batch = getBatch(q4) // 'j < 10 is not supported by the testing data source. - assert(config.filters.isEmpty) - assert(config.requiredSchema.fieldNames === Seq("j")) + assert(batch.filters.isEmpty) + assert(batch.requiredSchema.fieldNames === Seq("j")) } else { - val config = getJavaScanConfig(q4) + val batch = getJavaBatch(q4) // 'j < 10 is not supported by the testing data source. - assert(config.filters.isEmpty) - assert(config.requiredSchema.fieldNames === Seq("j")) + assert(batch.filters.isEmpty) + assert(batch.requiredSchema.fieldNames === Seq("j")) } } } @@ -279,26 +278,26 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val q1 = df.select('i + 1) checkAnswer(q1, (1 until 11).map(i => Row(i))) - val config1 = getScanConfig(q1) - assert(config1.requiredSchema.fieldNames === Seq("i")) + val batch1 = getBatch(q1) + assert(batch1.requiredSchema.fieldNames === Seq("i")) val q2 = df.select(lit(1)) checkAnswer(q2, (0 until 10).map(i => Row(1))) - val config2 = getScanConfig(q2) - assert(config2.requiredSchema.isEmpty) + val batch2 = getBatch(q2) + assert(batch2.requiredSchema.isEmpty) // 'j === 1 can't be pushed down, but we should still be able do column pruning val q3 = df.filter('j === -1).select('j * 2) checkAnswer(q3, Row(-2)) - val config3 = getScanConfig(q3) - assert(config3.filters.isEmpty) - assert(config3.requiredSchema.fieldNames === Seq("j")) + val batch3 = getBatch(q3) + assert(batch3.filters.isEmpty) + assert(batch3.requiredSchema.fieldNames === Seq("j")) // column pruning should work with other operators. val q4 = df.sort('i).limit(1).select('i + 1) checkAnswer(q4, Row(1)) - val config4 = getScanConfig(q4) - assert(config4.requiredSchema.fieldNames === Seq("i")) + val batch4 = getBatch(q4) + assert(batch4.requiredSchema.fieldNames === Seq("i")) } test("SPARK-23315: get output from canonicalized data source v2 related plans") { @@ -374,10 +373,6 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { case class RangeInputPartition(start: Int, end: Int) extends InputPartition -case class NoopScanConfigBuilder(readSchema: StructType) extends ScanConfigBuilder with ScanConfig { - override def build(): ScanConfig = this -} - object SimpleReaderFactory extends PartitionReaderFactory { override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { val RangeInputPartition(start, end) = partition @@ -396,87 +391,68 @@ object SimpleReaderFactory extends PartitionReaderFactory { } } -abstract class SimpleReadSupport extends BatchReadSupport { - override def fullSchema(): StructType = new StructType().add("i", "int").add("j", "int") +abstract class SimpleBatchTable extends Table with SupportsBatchRead { - override def newScanConfigBuilder(): ScanConfigBuilder = { - NoopScanConfigBuilder(fullSchema()) - } + override def schema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - SimpleReaderFactory - } + override def name(): String = this.getClass.toString } +abstract class SimpleScanBuilder extends ScanBuilder + with Batch with Scan { + + override def build(): Scan = this + + override def toBatch: Batch = this -class SimpleSinglePartitionSource extends DataSourceV2 with BatchReadSupportProvider { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - class ReadSupport extends SimpleReadSupport { - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + override def createReaderFactory(): PartitionReaderFactory = SimpleReaderFactory +} + +class SimpleSinglePartitionSource extends TableProvider { + + class MyScanBuilder extends SimpleScanBuilder { + override def planInputPartitions(): Array[InputPartition] = { Array(RangeInputPartition(0, 5)) } } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { - new ReadSupport + override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable { + override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + new MyScanBuilder() + } } } // This class is used by pyspark tests. If this class is modified/moved, make sure pyspark // tests still pass. -class SimpleDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { +class SimpleDataSourceV2 extends TableProvider { - class ReadSupport extends SimpleReadSupport { - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + class MyScanBuilder extends SimpleScanBuilder { + override def planInputPartitions(): Array[InputPartition] = { Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10)) } } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { - new ReadSupport + override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable { + override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + new MyScanBuilder() + } } } -class AdvancedDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { - - class ReadSupport extends SimpleReadSupport { - override def newScanConfigBuilder(): ScanConfigBuilder = new AdvancedScanConfigBuilder() +class AdvancedDataSourceV2 extends TableProvider { - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val filters = config.asInstanceOf[AdvancedScanConfigBuilder].filters - - val lowerBound = filters.collectFirst { - case GreaterThan("i", v: Int) => v - } - - val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition] - - if (lowerBound.isEmpty) { - res.append(RangeInputPartition(0, 5)) - res.append(RangeInputPartition(5, 10)) - } else if (lowerBound.get < 4) { - res.append(RangeInputPartition(lowerBound.get + 1, 5)) - res.append(RangeInputPartition(5, 10)) - } else if (lowerBound.get < 9) { - res.append(RangeInputPartition(lowerBound.get + 1, 10)) - } - - res.toArray - } - - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - val requiredSchema = config.asInstanceOf[AdvancedScanConfigBuilder].requiredSchema - new AdvancedReaderFactory(requiredSchema) + override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable { + override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + new AdvancedScanBuilder() } } - - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { - new ReadSupport - } } -class AdvancedScanConfigBuilder extends ScanConfigBuilder with ScanConfig - with SupportsPushDownRequiredColumns with SupportsPushDownFilters { +class AdvancedScanBuilder extends ScanBuilder + with Scan with SupportsPushDownFilters with SupportsPushDownRequiredColumns { var requiredSchema = new StructType().add("i", "int").add("j", "int") var filters = Array.empty[Filter] @@ -498,10 +474,40 @@ class AdvancedScanConfigBuilder extends ScanConfigBuilder with ScanConfig override def pushedFilters(): Array[Filter] = filters - override def build(): ScanConfig = this + override def build(): Scan = this + + override def toBatch: Batch = new AdvancedBatch(filters, requiredSchema) +} + +class AdvancedBatch(val filters: Array[Filter], val requiredSchema: StructType) extends Batch { + + override def planInputPartitions(): Array[InputPartition] = { + val lowerBound = filters.collectFirst { + case GreaterThan("i", v: Int) => v + } + + val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition] + + if (lowerBound.isEmpty) { + res.append(RangeInputPartition(0, 5)) + res.append(RangeInputPartition(5, 10)) + } else if (lowerBound.get < 4) { + res.append(RangeInputPartition(lowerBound.get + 1, 5)) + res.append(RangeInputPartition(5, 10)) + } else if (lowerBound.get < 9) { + res.append(RangeInputPartition(lowerBound.get + 1, 10)) + } + + res.toArray + } + + override def createReaderFactory(): PartitionReaderFactory = { + new AdvancedReaderFactory(requiredSchema) + } } class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { val RangeInputPartition(start, end) = partition new PartitionReader[InternalRow] { @@ -526,39 +532,47 @@ class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderF } -class SchemaRequiredDataSource extends DataSourceV2 with BatchReadSupportProvider { +class SchemaRequiredDataSource extends TableProvider { - class ReadSupport(val schema: StructType) extends SimpleReadSupport { - override def fullSchema(): StructType = schema + class MyScanBuilder(schema: StructType) extends SimpleScanBuilder { + override def planInputPartitions(): Array[InputPartition] = Array.empty - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = - Array.empty + override def readSchema(): StructType = schema } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + override def getTable(options: DataSourceOptions): Table = { throw new IllegalArgumentException("requires a user-supplied schema") } - override def createBatchReadSupport( - schema: StructType, options: DataSourceOptions): BatchReadSupport = { - new ReadSupport(schema) + override def getTable(options: DataSourceOptions, schema: StructType): Table = { + val userGivenSchema = schema + new SimpleBatchTable { + override def schema(): StructType = userGivenSchema + + override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + new MyScanBuilder(userGivenSchema) + } + } } } -class ColumnarDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { +class ColumnarDataSourceV2 extends TableProvider { - class ReadSupport extends SimpleReadSupport { - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + class MyScanBuilder extends SimpleScanBuilder { + + override def planInputPartitions(): Array[InputPartition] = { Array(RangeInputPartition(0, 50), RangeInputPartition(50, 90)) } - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + override def createReaderFactory(): PartitionReaderFactory = { ColumnarReaderFactory } } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { - new ReadSupport + override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable { + override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + new MyScanBuilder() + } } } @@ -608,21 +622,29 @@ object ColumnarReaderFactory extends PartitionReaderFactory { } -class PartitionAwareDataSource extends DataSourceV2 with BatchReadSupportProvider { +class PartitionAwareDataSource extends TableProvider { + + class MyScanBuilder extends SimpleScanBuilder + with SupportsReportPartitioning{ - class ReadSupport extends SimpleReadSupport with SupportsReportPartitioning { - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + override def planInputPartitions(): Array[InputPartition] = { // Note that we don't have same value of column `a` across partitions. Array( SpecificInputPartition(Array(1, 1, 3), Array(4, 4, 6)), SpecificInputPartition(Array(2, 4, 4), Array(6, 2, 2))) } - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + override def createReaderFactory(): PartitionReaderFactory = { SpecificReaderFactory } - override def outputPartitioning(config: ScanConfig): Partitioning = new MyPartitioning + override def outputPartitioning(): Partitioning = new MyPartitioning + } + + override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable { + override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + new MyScanBuilder() + } } class MyPartitioning extends Partitioning { @@ -633,10 +655,6 @@ class PartitionAwareDataSource extends DataSourceV2 with BatchReadSupportProvide case _ => false } } - - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { - new ReadSupport - } } case class SpecificInputPartition(i: Array[Int], j: Array[Int]) extends InputPartition @@ -662,7 +680,7 @@ object SpecificReaderFactory extends PartitionReaderFactory { class SchemaReadAttemptException(m: String) extends RuntimeException(m) class SimpleWriteOnlyDataSource extends SimpleWritableDataSource { - override def fullSchema(): StructType = { + override def writeSchema(): StructType = { // This is a bit hacky since this source implements read support but throws // during schema retrieval. Might have to rewrite but it's done // such so for minimised changes. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index a7dfc2d1deacc..82bb4fa33a3ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration /** @@ -39,19 +39,16 @@ import org.apache.spark.util.SerializableConfiguration * Each job moves files from `target/_temporary/queryId/` to `target`. */ class SimpleWritableDataSource extends DataSourceV2 - with BatchReadSupportProvider + with TableProvider with BatchWriteSupportProvider with SessionConfigSupport { - protected def fullSchema(): StructType = new StructType().add("i", "long").add("j", "long") + protected def writeSchema(): StructType = new StructType().add("i", "long").add("j", "long") override def keyPrefix: String = "simpleWritableDataSource" - class ReadSupport(path: String, conf: Configuration) extends SimpleReadSupport { - - override def fullSchema(): StructType = SimpleWritableDataSource.this.fullSchema() - - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + class MyScanBuilder(path: String, conf: Configuration) extends SimpleScanBuilder { + override def planInputPartitions(): Array[InputPartition] = { val dataPath = new Path(path) val fs = dataPath.getFileSystem(conf) if (fs.exists(dataPath)) { @@ -66,10 +63,24 @@ class SimpleWritableDataSource extends DataSourceV2 } } - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + override def createReaderFactory(): PartitionReaderFactory = { val serializableConf = new SerializableConfiguration(conf) new CSVReaderFactory(serializableConf) } + + override def readSchema(): StructType = writeSchema + } + + override def getTable(options: DataSourceOptions): Table = { + val path = new Path(options.get("path").get()) + val conf = SparkContext.getActive.get.hadoopConfiguration + new SimpleBatchTable { + override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + new MyScanBuilder(path.toUri.toString, conf) + } + + override def schema(): StructType = writeSchema + } } class WritSupport(queryId: String, path: String, conf: Configuration) extends BatchWriteSupport { @@ -105,12 +116,6 @@ class SimpleWritableDataSource extends DataSourceV2 } } - override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { - val path = new Path(options.get("path").get()) - val conf = SparkContext.getActive.get.hadoopConfiguration - new ReadSupport(path.toUri.toString, conf) - } - override def createBatchWriteSupport( queryId: String, schema: StructType, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 93eae292acc2b..756092fc7ff5f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming.continuous import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql._ -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2StreamingScanExec import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream @@ -41,7 +41,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReadSupport, _) => r + case DataSourceV2StreamingScanExec(_, _, _, _, r: RateStreamContinuousReadSupport, _) => r }.get val deltaMs = numTriggers * 1000 + 300 From c3f27b2437497396913fdec96f085c3626ef4e59 Mon Sep 17 00:00:00 2001 From: Keiji Yoshida Date: Fri, 30 Nov 2018 09:03:46 -0600 Subject: [PATCH 2178/2461] [MINOR][DOCS] Fix typos ## What changes were proposed in this pull request? Fix Typos. This PR is the complete version of https://github.com/apache/spark/pull/23145. ## How was this patch tested? NA Closes #23185 from kjmrknsn/docUpdate. Authored-by: Keiji Yoshida Signed-off-by: Sean Owen --- docs/configuration.md | 2 +- docs/graphx-programming-guide.md | 4 ++-- docs/ml-datasource.md | 2 +- docs/ml-features.md | 8 ++++---- docs/ml-pipeline.md | 2 +- docs/mllib-linear-methods.md | 4 ++-- docs/security.md | 2 +- docs/sparkr.md | 2 +- 8 files changed, 13 insertions(+), 13 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 04210d855b110..8914bd0310f98 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -498,7 +498,7 @@ Apart from these, the following properties are also available, and may be useful
        diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index cb96fd773aa5a..ecedeaf958f19 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -522,7 +522,7 @@ val joinedGraph = graph.joinVertices(uniqueCosts, A key step in many graph analytics tasks is aggregating information about the neighborhood of each vertex. -For example, we might want to know the number of followers each user has or the average age of the +For example, we might want to know the number of followers each user has or the average age of the followers of each user. Many iterative graph algorithms (e.g., PageRank, Shortest Path, and connected components) repeatedly aggregate properties of neighboring vertices (e.g., current PageRank Value, shortest path to the source, and smallest reachable vertex id). @@ -700,7 +700,7 @@ a new value for the vertex property, and then send messages to neighboring verti super step. Unlike Pregel, messages are computed in parallel as a function of the edge triplet and the message computation has access to both the source and destination vertex attributes. Vertices that do not receive a message are skipped within a super -step. The Pregel operators terminates iteration and returns the final graph when there are no +step. The Pregel operator terminates iteration and returns the final graph when there are no messages remaining. > Note, unlike more standard Pregel implementations, vertices in GraphX can only send messages to diff --git a/docs/ml-datasource.md b/docs/ml-datasource.md index 15083326240ac..35afaef5ad7f0 100644 --- a/docs/ml-datasource.md +++ b/docs/ml-datasource.md @@ -5,7 +5,7 @@ displayTitle: Data sources --- In this section, we introduce how to use data source in ML to load data. -Beside some general data sources such as Parquet, CSV, JSON and JDBC, we also provide some specific data sources for ML. +Besides some general data sources such as Parquet, CSV, JSON and JDBC, we also provide some specific data sources for ML. **Table of Contents** diff --git a/docs/ml-features.md b/docs/ml-features.md index 83a211ce02e67..a140bc6e7a22f 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -359,7 +359,7 @@ Assume that we have the following DataFrame with columns `id` and `raw`: ~~~~ id | raw ----|---------- - 0 | [I, saw, the, red, baloon] + 0 | [I, saw, the, red, balloon] 1 | [Mary, had, a, little, lamb] ~~~~ @@ -369,7 +369,7 @@ column, we should get the following: ~~~~ id | raw | filtered ----|-----------------------------|-------------------- - 0 | [I, saw, the, red, baloon] | [saw, red, baloon] + 0 | [I, saw, the, red, balloon] | [saw, red, balloon] 1 | [Mary, had, a, little, lamb]|[Mary, little, lamb] ~~~~ @@ -1302,7 +1302,7 @@ need to know vector size, can use that column as an input. To use `VectorSizeHint` a user must set the `inputCol` and `size` parameters. Applying this transformer to a dataframe produces a new dataframe with updated metadata for `inputCol` specifying the vector size. Downstream operations on the resulting dataframe can get this size using the -meatadata. +metadata. `VectorSizeHint` can also take an optional `handleInvalid` parameter which controls its behaviour when the vector column contains nulls or vectors of the wrong size. By default @@ -1310,7 +1310,7 @@ behaviour when the vector column contains nulls or vectors of the wrong size. By also be set to "skip", indicating that rows containing invalid values should be filtered out from the resulting dataframe, or "optimistic", indicating that the column should not be checked for invalid values and all rows should be kept. Note that the use of "optimistic" can cause the -resulting dataframe to be in an inconsistent state, me:aning the metadata for the column +resulting dataframe to be in an inconsistent state, meaning the metadata for the column `VectorSizeHint` was applied to does not match the contents of that column. Users should take care to avoid this kind of inconsistent state. diff --git a/docs/ml-pipeline.md b/docs/ml-pipeline.md index 8c01ccb94c75f..0c9c998f63535 100644 --- a/docs/ml-pipeline.md +++ b/docs/ml-pipeline.md @@ -62,7 +62,7 @@ In addition to the types listed in the Spark SQL guide, `DataFrame` can use ML [ A `DataFrame` can be created either implicitly or explicitly from a regular `RDD`. See the code examples below and the [Spark SQL programming guide](sql-programming-guide.html) for examples. -Columns in a `DataFrame` are named. The code examples below use names such as "text," "features," and "label." +Columns in a `DataFrame` are named. The code examples below use names such as "text", "features", and "label". ## Pipeline components diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 73f6e206ca543..2879d884162ad 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -272,7 +272,7 @@ In `spark.mllib`, the first class $0$ is chosen as the "pivot" class. See Section 4.4 of [The Elements of Statistical Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for references. -Here is an +Here is a [detailed mathematical derivation](http://www.slideshare.net/dbtsai/2014-0620-mlor-36132297). For multiclass classification problems, the algorithm will output a multinomial logistic regression @@ -350,7 +350,7 @@ known as the [mean squared error](http://en.wikipedia.org/wiki/Mean_squared_erro
        -The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. +The following example demonstrates how to load training data, parse it as an RDD of LabeledPoint. The example then uses LinearRegressionWithSGD to build a simple linear model to predict label values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). diff --git a/docs/security.md b/docs/security.md index 02d581c6dad91..be4834660fb7a 100644 --- a/docs/security.md +++ b/docs/security.md @@ -337,7 +337,7 @@ Configuration for SSL is organized hierarchically. The user can configure the de which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the protocols without disabling the ability to configure each one individually. The following table -describes the the SSL configuration namespaces: +describes the SSL configuration namespaces:
        Reuse Python worker or not. If yes, it will use a fixed number of Python workers, does not need to fork() a Python process for every task. It will be very useful - if there is large broadcast, then the broadcast will not be needed to transferred + if there is a large broadcast, then the broadcast will not need to be transferred from JVM to Python worker for every task.
        diff --git a/docs/sparkr.md b/docs/sparkr.md index 5972435a0e409..0057f05de0ff3 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -296,7 +296,7 @@ head(agg(rollup(df, "cyl", "disp", "gear"), avg(df$mpg))) ### Operating on Columns -SparkR also provides a number of functions that can directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions. +SparkR also provides a number of functions that can be directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions.
        {% highlight r %} From 9b23be2e95fec756066ca0ed3188c3db2602b757 Mon Sep 17 00:00:00 2001 From: schintap Date: Fri, 30 Nov 2018 12:48:56 -0600 Subject: [PATCH 2179/2461] [SPARK-26201] Fix python broadcast with encryption MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Python with rpc and disk encryption enabled along with a python broadcast variable and just read the value back on the driver side the job failed with: Traceback (most recent call last): File "broadcast.py", line 37, in words_new.value File "/pyspark.zip/pyspark/broadcast.py", line 137, in value File "pyspark.zip/pyspark/broadcast.py", line 122, in load_from_path File "pyspark.zip/pyspark/broadcast.py", line 128, in load EOFError: Ran out of input To reproduce use configs: --conf spark.network.crypto.enabled=true --conf spark.io.encryption.enabled=true Code: words_new = sc.broadcast(["scala", "java", "hadoop", "spark", "akka"]) words_new.value print(words_new.value) ## How was this patch tested? words_new = sc.broadcast([“scala”, “java”, “hadoop”, “spark”, “akka”]) textFile = sc.textFile(“README.md”) wordCounts = textFile.flatMap(lambda line: line.split()).map(lambda word: (word + words_new.value[1], 1)).reduceByKey(lambda a, b: a+b) count = wordCounts.count() print(count) words_new.value print(words_new.value) Closes #23166 from redsanket/SPARK-26201. Authored-by: schintap Signed-off-by: Thomas Graves --- .../apache/spark/api/python/PythonRDD.scala | 29 ++++++++++++++++--- python/pyspark/broadcast.py | 21 ++++++++++---- python/pyspark/tests/test_broadcast.py | 15 ++++++++++ 3 files changed, 56 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 8b5a7a9aefea5..5ed5070558af7 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -660,6 +660,7 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial with Logging { private var encryptionServer: PythonServer[Unit] = null + private var decryptionServer: PythonServer[Unit] = null /** * Read data from disks, then copy it to `out` @@ -708,16 +709,36 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial override def handleConnection(sock: Socket): Unit = { val env = SparkEnv.get val in = sock.getInputStream() - val dir = new File(Utils.getLocalDir(env.conf)) - val file = File.createTempFile("broadcast", "", dir) - path = file.getAbsolutePath - val out = env.serializerManager.wrapForEncryption(new FileOutputStream(path)) + val abspath = new File(path).getAbsolutePath + val out = env.serializerManager.wrapForEncryption(new FileOutputStream(abspath)) DechunkedInputStream.dechunkAndCopyToOutput(in, out) } } Array(encryptionServer.port, encryptionServer.secret) } + def setupDecryptionServer(): Array[Any] = { + decryptionServer = new PythonServer[Unit]("broadcast-decrypt-server-for-driver") { + override def handleConnection(sock: Socket): Unit = { + val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream())) + Utils.tryWithSafeFinally { + val in = SparkEnv.get.serializerManager.wrapForEncryption(new FileInputStream(path)) + Utils.tryWithSafeFinally { + Utils.copyStream(in, out, false) + } { + in.close() + } + out.flush() + } { + JavaUtils.closeQuietly(out) + } + } + } + Array(decryptionServer.port, decryptionServer.secret) + } + + def waitTillBroadcastDataSent(): Unit = decryptionServer.getResult() + def waitTillDataReceived(): Unit = encryptionServer.getResult() } // scalastyle:on no.finalize diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 1c7f2a7418df0..29358b5740e51 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -77,11 +77,12 @@ def __init__(self, sc=None, value=None, pickle_registry=None, path=None, # we're on the driver. We want the pickled data to end up in a file (maybe encrypted) f = NamedTemporaryFile(delete=False, dir=sc._temp_dir) self._path = f.name - python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path) + self._sc = sc + self._python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path) if sc._encryption_enabled: # with encryption, we ask the jvm to do the encryption for us, we send it data # over a socket - port, auth_secret = python_broadcast.setupEncryptionServer() + port, auth_secret = self._python_broadcast.setupEncryptionServer() (encryption_sock_file, _) = local_connect_and_auth(port, auth_secret) broadcast_out = ChunkedStream(encryption_sock_file, 8192) else: @@ -89,12 +90,14 @@ def __init__(self, sc=None, value=None, pickle_registry=None, path=None, broadcast_out = f self.dump(value, broadcast_out) if sc._encryption_enabled: - python_broadcast.waitTillDataReceived() - self._jbroadcast = sc._jsc.broadcast(python_broadcast) + self._python_broadcast.waitTillDataReceived() + self._jbroadcast = sc._jsc.broadcast(self._python_broadcast) self._pickle_registry = pickle_registry else: # we're on an executor self._jbroadcast = None + self._sc = None + self._python_broadcast = None if sock_file is not None: # the jvm is doing decryption for us. Read the value # immediately from the sock_file @@ -134,7 +137,15 @@ def value(self): """ Return the broadcasted value """ if not hasattr(self, "_value") and self._path is not None: - self._value = self.load_from_path(self._path) + # we only need to decrypt it here when encryption is enabled and + # if its on the driver, since executor decryption is handled already + if self._sc is not None and self._sc._encryption_enabled: + port, auth_secret = self._python_broadcast.setupDecryptionServer() + (decrypted_sock_file, _) = local_connect_and_auth(port, auth_secret) + self._python_broadcast.waitTillBroadcastDataSent() + return self.load(decrypted_sock_file) + else: + self._value = self.load_from_path(self._path) return self._value def unpersist(self, blocking=False): diff --git a/python/pyspark/tests/test_broadcast.py b/python/pyspark/tests/test_broadcast.py index a98626e8f4bc9..11d31d24bb011 100644 --- a/python/pyspark/tests/test_broadcast.py +++ b/python/pyspark/tests/test_broadcast.py @@ -67,6 +67,21 @@ def test_broadcast_with_encryption(self): def test_broadcast_no_encryption(self): self._test_multiple_broadcasts() + def _test_broadcast_on_driver(self, *extra_confs): + conf = SparkConf() + for key, value in extra_confs: + conf.set(key, value) + conf.setMaster("local-cluster[2,1,1024]") + self.sc = SparkContext(conf=conf) + bs = self.sc.broadcast(value=5) + self.assertEqual(5, bs.value) + + def test_broadcast_value_driver_no_encryption(self): + self._test_broadcast_on_driver() + + def test_broadcast_value_driver_encryption(self): + self._test_broadcast_on_driver(("spark.io.encryption.enabled", "true")) + class BroadcastFrameProtocolTest(unittest.TestCase): From 36edbac1c8337a4719f90e4abd58d38738b2e1fb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 30 Nov 2018 14:23:18 -0800 Subject: [PATCH 2180/2461] [SPARK-26226][SQL] Update query tracker to report timeline for phases ## What changes were proposed in this pull request? This patch changes the query plan tracker added earlier to report phase timeline, rather than just a duration for each phase. This way, we can easily find time that's unaccounted for. ## How was this patch tested? Updated test cases to reflect that. Closes #23183 from rxin/SPARK-26226. Authored-by: Reynold Xin Signed-off-by: gatorsmile --- .../sql/catalyst/QueryPlanningTracker.scala | 45 +++++++++++++++---- .../catalyst/QueryPlanningTrackerSuite.scala | 18 +++++--- .../org/apache/spark/sql/SparkSession.scala | 2 +- .../spark/sql/execution/QueryExecution.scala | 8 ++-- .../QueryPlanningTrackerEndToEndSuite.scala | 15 +------ 5 files changed, 55 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala index 244081cd160b6..cd75407c7ee7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala @@ -41,6 +41,13 @@ object QueryPlanningTracker { val OPTIMIZATION = "optimization" val PLANNING = "planning" + /** + * Summary for a rule. + * @param totalTimeNs total amount of time, in nanosecs, spent in this rule. + * @param numInvocations number of times the rule has been invoked. + * @param numEffectiveInvocations number of times the rule has been invoked and + * resulted in a plan change. + */ class RuleSummary( var totalTimeNs: Long, var numInvocations: Long, var numEffectiveInvocations: Long) { @@ -51,6 +58,18 @@ object QueryPlanningTracker { } } + /** + * Summary of a phase, with start time and end time so we can construct a timeline. + */ + class PhaseSummary(val startTimeMs: Long, val endTimeMs: Long) { + + def durationMs: Long = endTimeMs - startTimeMs + + override def toString: String = { + s"PhaseSummary($startTimeMs, $endTimeMs)" + } + } + /** * A thread local variable to implicitly pass the tracker around. This assumes the query planner * is single-threaded, and avoids passing the same tracker context in every function call. @@ -79,15 +98,25 @@ class QueryPlanningTracker { // Use a Java HashMap for less overhead. private val rulesMap = new java.util.HashMap[String, RuleSummary] - // From a phase to time in ns. - private val phaseToTimeNs = new java.util.HashMap[String, Long] + // From a phase to its start time and end time, in ms. + private val phasesMap = new java.util.HashMap[String, PhaseSummary] - /** Measure the runtime of function f, and add it to the time for the specified phase. */ - def measureTime[T](phase: String)(f: => T): T = { - val startTime = System.nanoTime() + /** + * Measure the start and end time of a phase. Note that if this function is called multiple + * times for the same phase, the recorded start time will be the start time of the first call, + * and the recorded end time will be the end time of the last call. + */ + def measurePhase[T](phase: String)(f: => T): T = { + val startTime = System.currentTimeMillis() val ret = f - val timeTaken = System.nanoTime() - startTime - phaseToTimeNs.put(phase, phaseToTimeNs.getOrDefault(phase, 0) + timeTaken) + val endTime = System.currentTimeMillis + + if (phasesMap.containsKey(phase)) { + val oldSummary = phasesMap.get(phase) + phasesMap.put(phase, new PhaseSummary(oldSummary.startTimeMs, endTime)) + } else { + phasesMap.put(phase, new PhaseSummary(startTime, endTime)) + } ret } @@ -114,7 +143,7 @@ class QueryPlanningTracker { def rules: Map[String, RuleSummary] = rulesMap.asScala.toMap - def phases: Map[String, Long] = phaseToTimeNs.asScala.toMap + def phases: Map[String, PhaseSummary] = phasesMap.asScala.toMap /** * Returns the top k most expensive rules (as measured by time). If k is larger than the rules diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala index 120b284a77854..9593a720e4248 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala @@ -23,19 +23,23 @@ class QueryPlanningTrackerSuite extends SparkFunSuite { test("phases") { val t = new QueryPlanningTracker - t.measureTime("p1") { + t.measurePhase("p1") { Thread.sleep(1) } - assert(t.phases("p1") > 0) + assert(t.phases("p1").durationMs > 0) assert(!t.phases.contains("p2")) + } - val old = t.phases("p1") + test("multiple measurePhase call") { + val t = new QueryPlanningTracker + t.measurePhase("p1") { Thread.sleep(1) } + val s1 = t.phases("p1") + assert(s1.durationMs > 0) - t.measureTime("p1") { - Thread.sleep(1) - } - assert(t.phases("p1") > old) + t.measurePhase("p1") { Thread.sleep(1) } + val s2 = t.phases("p1") + assert(s2.durationMs > s1.durationMs) } test("rules") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 739c6b54b4cb3..26272c3906685 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -649,7 +649,7 @@ class SparkSession private( */ def sql(sqlText: String): DataFrame = { val tracker = new QueryPlanningTracker - val plan = tracker.measureTime(QueryPlanningTracker.PARSING) { + val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) { sessionState.sqlParser.parsePlan(sqlText) } Dataset.ofRows(self, plan, tracker) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index cfb5e43207b03..eef5a3f899f55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -60,7 +60,7 @@ class QueryExecution( } } - lazy val analyzed: LogicalPlan = tracker.measureTime(QueryPlanningTracker.ANALYSIS) { + lazy val analyzed: LogicalPlan = tracker.measurePhase(QueryPlanningTracker.ANALYSIS) { SparkSession.setActiveSession(sparkSession) sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker) } @@ -71,11 +71,11 @@ class QueryExecution( sparkSession.sharedState.cacheManager.useCachedData(analyzed) } - lazy val optimizedPlan: LogicalPlan = tracker.measureTime(QueryPlanningTracker.OPTIMIZATION) { + lazy val optimizedPlan: LogicalPlan = tracker.measurePhase(QueryPlanningTracker.OPTIMIZATION) { sparkSession.sessionState.optimizer.executeAndTrack(withCachedData, tracker) } - lazy val sparkPlan: SparkPlan = tracker.measureTime(QueryPlanningTracker.PLANNING) { + lazy val sparkPlan: SparkPlan = tracker.measurePhase(QueryPlanningTracker.PLANNING) { SparkSession.setActiveSession(sparkSession) // TODO: We use next(), i.e. take the first plan returned by the planner, here for now, // but we will implement to choose the best plan. @@ -84,7 +84,7 @@ class QueryExecution( // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = tracker.measureTime(QueryPlanningTracker.PLANNING) { + lazy val executedPlan: SparkPlan = tracker.measurePhase(QueryPlanningTracker.PLANNING) { prepareForExecution(sparkPlan) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala index 0af4c85400e9e..e42177c156ee9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala @@ -25,12 +25,7 @@ class QueryPlanningTrackerEndToEndSuite extends SharedSQLContext { val df = spark.range(1000).selectExpr("count(*)") df.collect() val tracker = df.queryExecution.tracker - - assert(tracker.phases.size == 3) - assert(tracker.phases("analysis") > 0) - assert(tracker.phases("optimization") > 0) - assert(tracker.phases("planning") > 0) - + assert(tracker.phases.keySet == Set("analysis", "optimization", "planning")) assert(tracker.rules.nonEmpty) } @@ -39,13 +34,7 @@ class QueryPlanningTrackerEndToEndSuite extends SharedSQLContext { df.collect() val tracker = df.queryExecution.tracker - - assert(tracker.phases.size == 4) - assert(tracker.phases("parsing") > 0) - assert(tracker.phases("analysis") > 0) - assert(tracker.phases("optimization") > 0) - assert(tracker.phases("planning") > 0) - + assert(tracker.phases.keySet == Set("parsing", "analysis", "optimization", "planning")) assert(tracker.rules.nonEmpty) } From 8856e9f6a3d5c019fcae45dbbdfa9128cd700e19 Mon Sep 17 00:00:00 2001 From: Shahid Date: Fri, 30 Nov 2018 15:20:05 -0800 Subject: [PATCH 2181/2461] [SPARK-26219][CORE] Executor summary should get updated for failure jobs in the history server UI The root cause of the problem is, whenever the taskEnd event comes after stageCompleted event, execSummary is updating only for live UI. we need to update for history UI too. To see the previous discussion, refer: PR for https://github.com/apache/spark/pull/23038, https://issues.apache.org/jira/browse/SPARK-26100. Added UT. Manually verified Test step to reproduce: ``` bin/spark-shell --master yarn --conf spark.executor.instances=3 sc.parallelize(1 to 10000, 10).map{ x => throw new RuntimeException("Bad executor")}.collect() ``` Open Executors page from the History UI Before patch: ![screenshot from 2018-11-29 22-13-34](https://user-images.githubusercontent.com/23054875/49246338-a21ead00-f43a-11e8-8214-f1020420be52.png) After patch: ![screenshot from 2018-11-30 00-54-49](https://user-images.githubusercontent.com/23054875/49246353-aa76e800-f43a-11e8-98ef-7faecaa7a50e.png) Closes #23181 from shahidki31/executorUpdate. Authored-by: Shahid Signed-off-by: Marcelo Vanzin --- .../spark/status/AppStatusListener.scala | 19 ++-- .../spark/status/AppStatusListenerSuite.scala | 92 +++++++++++-------- 2 files changed, 64 insertions(+), 47 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 8e845573a903d..bd3f58b6182c0 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -641,9 +641,14 @@ private[spark] class AppStatusListener( } } - // Force an update on live applications when the number of active tasks reaches 0. This is - // checked in some tests (e.g. SQLTestUtilsBase) so it needs to be reliably up to date. - conditionalLiveUpdate(exec, now, exec.activeTasks == 0) + // Force an update on both live and history applications when the number of active tasks + // reaches 0. This is checked in some tests (e.g. SQLTestUtilsBase) so it needs to be + // reliably up to date. + if (exec.activeTasks == 0) { + update(exec, now) + } else { + maybeUpdate(exec, now) + } } } @@ -1024,14 +1029,6 @@ private[spark] class AppStatusListener( } } - private def conditionalLiveUpdate(entity: LiveEntity, now: Long, condition: Boolean): Unit = { - if (condition) { - liveUpdate(entity, now) - } else { - maybeUpdate(entity, now) - } - } - private def cleanupExecutors(count: Long): Unit = { // Because the limit is on the number of *dead* executors, we need to calculate whether // there are actually enough dead executors to be deleted. diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 7860a0df4bb2d..61fec8c1d0e4e 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -1273,48 +1273,68 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(allJobs.head.numFailedStages == 1) } - test("SPARK-25451: total tasks in the executor summary should match total stage tasks") { - val testConf = conf.clone.set(LIVE_ENTITY_UPDATE_PERIOD, Long.MaxValue) + Seq(true, false).foreach { live => + test(s"Total tasks in the executor summary should match total stage tasks (live = $live)") { - val listener = new AppStatusListener(store, testConf, true) + val testConf = if (live) { + conf.clone().set(LIVE_ENTITY_UPDATE_PERIOD, Long.MaxValue) + } else { + conf.clone().set(LIVE_ENTITY_UPDATE_PERIOD, -1L) + } - val stage = new StageInfo(1, 0, "stage", 4, Nil, Nil, "details") - listener.onJobStart(SparkListenerJobStart(1, time, Seq(stage), null)) - listener.onStageSubmitted(SparkListenerStageSubmitted(stage, new Properties())) + val listener = new AppStatusListener(store, testConf, live) - val tasks = createTasks(4, Array("1", "2")) - tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(stage.stageId, stage.attemptNumber, task)) - } + listener.onExecutorAdded(createExecutorAddedEvent(1)) + listener.onExecutorAdded(createExecutorAddedEvent(2)) + val stage = new StageInfo(1, 0, "stage", 4, Nil, Nil, "details") + listener.onJobStart(SparkListenerJobStart(1, time, Seq(stage), null)) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage, new Properties())) - time += 1 - tasks(0).markFinished(TaskState.FINISHED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", - Success, tasks(0), null)) - time += 1 - tasks(1).markFinished(TaskState.FINISHED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", - Success, tasks(1), null)) + val tasks = createTasks(4, Array("1", "2")) + tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(stage.stageId, stage.attemptNumber, task)) + } - stage.failureReason = Some("Failed") - listener.onStageCompleted(SparkListenerStageCompleted(stage)) - time += 1 - listener.onJobEnd(SparkListenerJobEnd(1, time, JobFailed(new RuntimeException("Bad Executor")))) + time += 1 + tasks(0).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", + Success, tasks(0), null)) + time += 1 + tasks(1).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", + Success, tasks(1), null)) - time += 1 - tasks(2).markFinished(TaskState.FAILED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", - ExecutorLostFailure("1", true, Some("Lost executor")), tasks(2), null)) - time += 1 - tasks(3).markFinished(TaskState.FAILED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", - ExecutorLostFailure("2", true, Some("Lost executor")), tasks(3), null)) - - val esummary = store.view(classOf[ExecutorStageSummaryWrapper]).asScala.map(_.info) - esummary.foreach { execSummary => - assert(execSummary.failedTasks === 1) - assert(execSummary.succeededTasks === 1) - assert(execSummary.killedTasks === 0) + stage.failureReason = Some("Failed") + listener.onStageCompleted(SparkListenerStageCompleted(stage)) + time += 1 + listener.onJobEnd(SparkListenerJobEnd(1, time, JobFailed( + new RuntimeException("Bad Executor")))) + + time += 1 + tasks(2).markFinished(TaskState.FAILED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", + ExecutorLostFailure("1", true, Some("Lost executor")), tasks(2), null)) + time += 1 + tasks(3).markFinished(TaskState.FAILED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", + ExecutorLostFailure("2", true, Some("Lost executor")), tasks(3), null)) + + val esummary = store.view(classOf[ExecutorStageSummaryWrapper]).asScala.map(_.info) + esummary.foreach { execSummary => + assert(execSummary.failedTasks === 1) + assert(execSummary.succeededTasks === 1) + assert(execSummary.killedTasks === 0) + } + + val allExecutorSummary = store.view(classOf[ExecutorSummaryWrapper]).asScala.map(_.info) + assert(allExecutorSummary.size === 2) + allExecutorSummary.foreach { allExecSummary => + assert(allExecSummary.failedTasks === 1) + assert(allExecSummary.activeTasks === 0) + assert(allExecSummary.completedTasks === 1) + } + store.delete(classOf[ExecutorSummaryWrapper], "1") + store.delete(classOf[ExecutorSummaryWrapper], "2") } } From 6be272b75b4ae3149869e19df193675cc4117763 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 30 Nov 2018 16:23:37 -0800 Subject: [PATCH 2182/2461] [SPARK-25876][K8S] Simplify kubernetes configuration types. There are a few issues with the current configuration types used in the kubernetes backend: - they use type parameters for role-specific specialization, which makes type signatures really noisy throughout the code base. - they break encapsulation by forcing the code that creates the config object to remove the configuration from SparkConf before creating the k8s-specific wrapper. - they don't provide an easy way for tests to have default values for fields they do not use. This change fixes those problems by: - creating a base config type with role-specific specialization using inheritance - encapsulating the logic of parsing SparkConf into k8s-specific views inside the k8s config classes - providing some helper code for tests to easily override just the part of the configs they want. Most of the change relates to the above, especially cleaning up the tests. While doing that, I also made some smaller changes elsewhere: - removed unnecessary type parameters in KubernetesVolumeSpec - simplified the error detection logic in KubernetesVolumeUtils; all the call sites would just throw the first exception collected by that class, since they all called "get" on the "Try" object. Now the unnecessary wrapping is gone and the exception is just thrown where it occurs. - removed a lot of unnecessary mocking from tests. - changed the kerberos-related code so that less logic needs to live in the driver builder. In spirit it should be part of the upcoming work in this series of cleanups, but it made parts of this change simpler. Tested with existing unit tests and integration tests. Author: Marcelo Vanzin Closes #22959 from vanzin/SPARK-25876. --- .../org/apache/spark/deploy/k8s/Config.scala | 17 +- .../spark/deploy/k8s/KubernetesConf.scala | 302 ++++++++---------- .../deploy/k8s/KubernetesVolumeSpec.scala | 10 +- .../deploy/k8s/KubernetesVolumeUtils.scala | 53 +-- .../k8s/features/BasicDriverFeatureStep.scala | 24 +- .../features/BasicExecutorFeatureStep.scala | 29 +- .../features/DriverCommandFeatureStep.scala | 22 +- ...iverKubernetesCredentialsFeatureStep.scala | 6 +- .../features/DriverServiceFeatureStep.scala | 10 +- .../k8s/features/EnvSecretsFeatureStep.scala | 11 +- .../HadoopConfExecutorFeatureStep.scala | 14 +- .../HadoopSparkUserExecutorFeatureStep.scala | 17 +- .../KerberosConfDriverFeatureStep.scala | 113 ++++--- .../KerberosConfExecutorFeatureStep.scala | 21 +- .../k8s/features/LocalDirsFeatureStep.scala | 9 +- .../features/MountSecretsFeatureStep.scala | 13 +- .../features/MountVolumesFeatureStep.scala | 11 +- .../features/PodTemplateConfigMapStep.scala | 5 +- .../hadooputils/HadoopKerberosLogin.scala | 64 ---- ...bernetesHadoopDelegationTokenManager.scala | 37 --- .../submit/KubernetesClientApplication.scala | 61 +--- .../k8s/submit/KubernetesDriverBuilder.scala | 53 ++- .../k8s/KubernetesExecutorBuilder.scala | 36 +-- .../deploy/k8s/KubernetesConfSuite.scala | 71 ++-- .../spark/deploy/k8s/KubernetesTestConf.scala | 138 ++++++++ .../k8s/KubernetesVolumeUtilsSuite.scala | 30 +- .../BasicDriverFeatureStepSuite.scala | 127 ++------ .../BasicExecutorFeatureStepSuite.scala | 103 ++---- .../DriverCommandFeatureStepSuite.scala | 29 +- ...ubernetesCredentialsFeatureStepSuite.scala | 69 +--- .../DriverServiceFeatureStepSuite.scala | 193 ++++------- .../features/EnvSecretsFeatureStepSuite.scala | 32 +- .../features/LocalDirsFeatureStepSuite.scala | 46 +-- .../MountSecretsFeatureStepSuite.scala | 21 +- .../MountVolumesFeatureStepSuite.scala | 56 ++-- .../PodTemplateConfigMapStepSuite.scala | 28 +- .../spark/deploy/k8s/submit/ClientSuite.scala | 47 +-- .../submit/KubernetesDriverBuilderSuite.scala | 204 ++---------- .../k8s/ExecutorPodsAllocatorSuite.scala | 43 +-- .../k8s/KubernetesExecutorBuilderSuite.scala | 114 ++----- 40 files changed, 777 insertions(+), 1512 deletions(-) delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/HadoopKerberosLogin.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/security/KubernetesHadoopDelegationTokenManager.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 724acd231a6cb..1abf2901268f8 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -60,7 +60,8 @@ private[spark] object Config extends Logging { .doc("Comma separated list of the Kubernetes secrets used " + "to access private image registries.") .stringConf - .createOptional + .toSequence + .createWithDefault(Nil) val KUBERNETES_AUTH_DRIVER_CONF_PREFIX = "spark.kubernetes.authenticate.driver" @@ -112,16 +113,16 @@ private[spark] object Config extends Logging { .stringConf .createOptional - val KUBERNETES_EXECUTOR_POD_NAME_PREFIX = - ConfigBuilder("spark.kubernetes.executor.podNamePrefix") - .doc("Prefix to use in front of the executor pod names.") + // For testing only. + val KUBERNETES_DRIVER_POD_NAME_PREFIX = + ConfigBuilder("spark.kubernetes.driver.resourceNamePrefix") .internal() .stringConf - .createWithDefault("spark") + .createOptional - val KUBERNETES_PYSPARK_PY_FILES = - ConfigBuilder("spark.kubernetes.python.pyFiles") - .doc("The PyFiles that are distributed via client arguments") + val KUBERNETES_EXECUTOR_POD_NAME_PREFIX = + ConfigBuilder("spark.kubernetes.executor.podNamePrefix") + .doc("Prefix to use in front of the executor pod names.") .internal() .stringConf .createOptional diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index ebb81540bbbbe..a06c21b47f15e 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -16,93 +16,53 @@ */ package org.apache.spark.deploy.k8s -import scala.collection.mutable +import java.util.Locale import io.fabric8.kubernetes.api.model.{LocalObjectReference, LocalObjectReferenceBuilder, Pod} -import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.security.KubernetesHadoopDelegationTokenManager import org.apache.spark.deploy.k8s.submit._ -import org.apache.spark.deploy.k8s.submit.KubernetesClientApplication._ import org.apache.spark.internal.config.ConfigEntry import org.apache.spark.util.Utils - -private[spark] sealed trait KubernetesRoleSpecificConf - -/* - * Structure containing metadata for Kubernetes logic that builds a Spark driver. - */ -private[spark] case class KubernetesDriverSpecificConf( - mainAppResource: MainAppResource, - mainClass: String, - appName: String, - appArgs: Seq[String], - pyFiles: Seq[String] = Nil) extends KubernetesRoleSpecificConf { - - require(mainAppResource != null, "Main resource must be provided.") - -} - -/* - * Structure containing metadata for Kubernetes logic that builds a Spark executor. - */ -private[spark] case class KubernetesExecutorSpecificConf( - executorId: String, - driverPod: Option[Pod]) - extends KubernetesRoleSpecificConf - -/* - * Structure containing metadata for HADOOP_CONF_DIR customization - */ -private[spark] case class HadoopConfSpec( - hadoopConfDir: Option[String], - hadoopConfigMapName: Option[String]) - /** * Structure containing metadata for Kubernetes logic to build Spark pods. */ -private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( - sparkConf: SparkConf, - roleSpecificConf: T, - appResourceNamePrefix: String, - appId: String, - roleLabels: Map[String, String], - roleAnnotations: Map[String, String], - roleSecretNamesToMountPaths: Map[String, String], - roleSecretEnvNamesToKeyRefs: Map[String, String], - roleEnvs: Map[String, String], - roleVolumes: Iterable[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]], - hadoopConfSpec: Option[HadoopConfSpec]) { +private[spark] abstract class KubernetesConf(val sparkConf: SparkConf) { - def hadoopConfigMapName: String = s"$appResourceNamePrefix-hadoop-config" + val resourceNamePrefix: String + def labels: Map[String, String] + def environment: Map[String, String] + def annotations: Map[String, String] + def secretEnvNamesToKeyRefs: Map[String, String] + def secretNamesToMountPaths: Map[String, String] + def volumes: Seq[KubernetesVolumeSpec] - def krbConfigMapName: String = s"$appResourceNamePrefix-krb5-file" + def appName: String = get("spark.app.name", "spark") - def tokenManager(conf: SparkConf, hConf: Configuration): KubernetesHadoopDelegationTokenManager = - new KubernetesHadoopDelegationTokenManager(conf, hConf) + def hadoopConfigMapName: String = s"$resourceNamePrefix-hadoop-config" - def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE) + def krbConfigMapName: String = s"$resourceNamePrefix-krb5-file" - def imagePullPolicy(): String = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) + def namespace: String = get(KUBERNETES_NAMESPACE) - def imagePullSecrets(): Seq[LocalObjectReference] = { + def imagePullPolicy: String = get(CONTAINER_IMAGE_PULL_POLICY) + + def imagePullSecrets: Seq[LocalObjectReference] = { sparkConf .get(IMAGE_PULL_SECRETS) - .map(_.split(",")) - .getOrElse(Array.empty[String]) - .map(_.trim) .map { secret => new LocalObjectReferenceBuilder().withName(secret).build() } } - def nodeSelector(): Map[String, String] = + def nodeSelector: Map[String, String] = KubernetesUtils.parsePrefixedKeyValuePairs(sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) + def contains(config: ConfigEntry[_]): Boolean = sparkConf.contains(config) + def get[T](config: ConfigEntry[T]): T = sparkConf.get(config) def get(conf: String): String = sparkConf.get(conf) @@ -112,125 +72,139 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( def getOption(key: String): Option[String] = sparkConf.getOption(key) } +private[spark] class KubernetesDriverConf( + sparkConf: SparkConf, + val appId: String, + val mainAppResource: MainAppResource, + val mainClass: String, + val appArgs: Array[String], + val pyFiles: Seq[String]) + extends KubernetesConf(sparkConf) { + + override val resourceNamePrefix: String = { + val custom = if (Utils.isTesting) get(KUBERNETES_DRIVER_POD_NAME_PREFIX) else None + custom.getOrElse(KubernetesConf.getResourceNamePrefix(appName)) + } + + override def labels: Map[String, String] = { + val presetLabels = Map( + SPARK_APP_ID_LABEL -> appId, + SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) + val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_LABEL_PREFIX) + + presetLabels.keys.foreach { key => + require( + !driverCustomLabels.contains(key), + s"Label with key $key is not allowed as it is reserved for Spark bookkeeping operations.") + } + + driverCustomLabels ++ presetLabels + } + + override def environment: Map[String, String] = { + KubernetesUtils.parsePrefixedKeyValuePairs(sparkConf, KUBERNETES_DRIVER_ENV_PREFIX) + } + + override def annotations: Map[String, String] = { + KubernetesUtils.parsePrefixedKeyValuePairs(sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) + } + + override def secretNamesToMountPaths: Map[String, String] = { + KubernetesUtils.parsePrefixedKeyValuePairs(sparkConf, KUBERNETES_DRIVER_SECRETS_PREFIX) + } + + override def secretEnvNamesToKeyRefs: Map[String, String] = { + KubernetesUtils.parsePrefixedKeyValuePairs(sparkConf, KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX) + } + + override def volumes: Seq[KubernetesVolumeSpec] = { + KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, KUBERNETES_DRIVER_VOLUMES_PREFIX) + } +} + +private[spark] class KubernetesExecutorConf( + sparkConf: SparkConf, + val appId: String, + val executorId: String, + val driverPod: Option[Pod]) + extends KubernetesConf(sparkConf) { + + override val resourceNamePrefix: String = { + get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX).getOrElse( + KubernetesConf.getResourceNamePrefix(appName)) + } + + override def labels: Map[String, String] = { + val presetLabels = Map( + SPARK_EXECUTOR_ID_LABEL -> executorId, + SPARK_APP_ID_LABEL -> appId, + SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) + + val executorCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_LABEL_PREFIX) + + presetLabels.keys.foreach { key => + require( + !executorCustomLabels.contains(key), + s"Custom executor labels cannot contain $key as it is reserved for Spark.") + } + + executorCustomLabels ++ presetLabels + } + + override def environment: Map[String, String] = sparkConf.getExecutorEnv.toMap + + override def annotations: Map[String, String] = { + KubernetesUtils.parsePrefixedKeyValuePairs(sparkConf, KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) + } + + override def secretNamesToMountPaths: Map[String, String] = { + KubernetesUtils.parsePrefixedKeyValuePairs(sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) + } + + override def secretEnvNamesToKeyRefs: Map[String, String] = { + KubernetesUtils.parsePrefixedKeyValuePairs(sparkConf, KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX) + } + + override def volumes: Seq[KubernetesVolumeSpec] = { + KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, KUBERNETES_EXECUTOR_VOLUMES_PREFIX) + } + +} + private[spark] object KubernetesConf { def createDriverConf( sparkConf: SparkConf, - appName: String, - appResourceNamePrefix: String, appId: String, mainAppResource: MainAppResource, mainClass: String, appArgs: Array[String], - maybePyFiles: Option[String], - hadoopConfDir: Option[String]): KubernetesConf[KubernetesDriverSpecificConf] = { - val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_DRIVER_LABEL_PREFIX) - require(!driverCustomLabels.contains(SPARK_APP_ID_LABEL), "Label with key " + - s"$SPARK_APP_ID_LABEL is not allowed as it is reserved for Spark bookkeeping " + - "operations.") - require(!driverCustomLabels.contains(SPARK_ROLE_LABEL), "Label with key " + - s"$SPARK_ROLE_LABEL is not allowed as it is reserved for Spark bookkeeping " + - "operations.") - val driverLabels = driverCustomLabels ++ Map( - SPARK_APP_ID_LABEL -> appId, - SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) - val driverAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) - val driverSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_DRIVER_SECRETS_PREFIX) - val driverSecretEnvNamesToKeyRefs = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX) - val driverEnvs = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_DRIVER_ENV_PREFIX) - val driverVolumes = KubernetesVolumeUtils.parseVolumesWithPrefix( - sparkConf, KUBERNETES_DRIVER_VOLUMES_PREFIX).map(_.get) - // Also parse executor volumes in order to verify configuration - // before the driver pod is created - KubernetesVolumeUtils.parseVolumesWithPrefix( - sparkConf, KUBERNETES_EXECUTOR_VOLUMES_PREFIX).map(_.get) - - val hadoopConfigMapName = sparkConf.get(KUBERNETES_HADOOP_CONF_CONFIG_MAP) - KubernetesUtils.requireNandDefined( - hadoopConfDir, - hadoopConfigMapName, - "Do not specify both the `HADOOP_CONF_DIR` in your ENV and the ConfigMap " + - "as the creation of an additional ConfigMap, when one is already specified is extraneous" ) - val hadoopConfSpec = - if (hadoopConfDir.isDefined || hadoopConfigMapName.isDefined) { - Some(HadoopConfSpec(hadoopConfDir, hadoopConfigMapName)) - } else { - None - } - val pyFiles = maybePyFiles.map(Utils.stringToSeq).getOrElse(Nil) + maybePyFiles: Option[String]): KubernetesDriverConf = { + // Parse executor volumes in order to verify configuration before the driver pod is created. + KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, KUBERNETES_EXECUTOR_VOLUMES_PREFIX) - - KubernetesConf( - sparkConf.clone(), - KubernetesDriverSpecificConf(mainAppResource, mainClass, appName, appArgs, pyFiles), - appResourceNamePrefix, - appId, - driverLabels, - driverAnnotations, - driverSecretNamesToMountPaths, - driverSecretEnvNamesToKeyRefs, - driverEnvs, - driverVolumes, - hadoopConfSpec) + val pyFiles = maybePyFiles.map(Utils.stringToSeq).getOrElse(Nil) + new KubernetesDriverConf(sparkConf.clone(), appId, mainAppResource, mainClass, appArgs, + pyFiles) } def createExecutorConf( sparkConf: SparkConf, executorId: String, appId: String, - driverPod: Option[Pod]): KubernetesConf[KubernetesExecutorSpecificConf] = { - val executorCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_EXECUTOR_LABEL_PREFIX) - require( - !executorCustomLabels.contains(SPARK_APP_ID_LABEL), - s"Custom executor labels cannot contain $SPARK_APP_ID_LABEL as it is reserved for Spark.") - require( - !executorCustomLabels.contains(SPARK_EXECUTOR_ID_LABEL), - s"Custom executor labels cannot contain $SPARK_EXECUTOR_ID_LABEL as it is reserved for" + - " Spark.") - require( - !executorCustomLabels.contains(SPARK_ROLE_LABEL), - s"Custom executor labels cannot contain $SPARK_ROLE_LABEL as it is reserved for Spark.") - val executorLabels = Map( - SPARK_EXECUTOR_ID_LABEL -> executorId, - SPARK_APP_ID_LABEL -> appId, - SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ - executorCustomLabels - val executorAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) - val executorMountSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) - val executorEnvSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX) - val executorEnv = sparkConf.getExecutorEnv.toMap - val executorVolumes = KubernetesVolumeUtils.parseVolumesWithPrefix( - sparkConf, KUBERNETES_EXECUTOR_VOLUMES_PREFIX).map(_.get) - - // If no prefix is defined then we are in pure client mode - // (not the one used by cluster mode inside the container) - val appResourceNamePrefix = { - if (sparkConf.getOption(KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key).isEmpty) { - getResourceNamePrefix(getAppName(sparkConf)) - } else { - sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX) - } - } + driverPod: Option[Pod]): KubernetesExecutorConf = { + new KubernetesExecutorConf(sparkConf.clone(), appId, executorId, driverPod) + } - KubernetesConf( - sparkConf.clone(), - KubernetesExecutorSpecificConf(executorId, driverPod), - appResourceNamePrefix, - appId, - executorLabels, - executorAnnotations, - executorMountSecrets, - executorEnvSecrets, - executorEnv, - executorVolumes, - None) + def getResourceNamePrefix(appName: String): String = { + val launchTime = System.currentTimeMillis() + s"$appName-$launchTime" + .trim + .toLowerCase(Locale.ROOT) + .replaceAll("\\s+", "-") + .replaceAll("\\.", "-") + .replaceAll("[^a-z0-9\\-]", "") + .replaceAll("-+", "-") } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala index 1a214fad96618..0ebe8fd26015d 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala @@ -18,12 +18,10 @@ package org.apache.spark.deploy.k8s private[spark] sealed trait KubernetesVolumeSpecificConf -private[spark] case class KubernetesHostPathVolumeConf( - hostPath: String) +private[spark] case class KubernetesHostPathVolumeConf(hostPath: String) extends KubernetesVolumeSpecificConf -private[spark] case class KubernetesPVCVolumeConf( - claimName: String) +private[spark] case class KubernetesPVCVolumeConf(claimName: String) extends KubernetesVolumeSpecificConf private[spark] case class KubernetesEmptyDirVolumeConf( @@ -31,9 +29,9 @@ private[spark] case class KubernetesEmptyDirVolumeConf( sizeLimit: Option[String]) extends KubernetesVolumeSpecificConf -private[spark] case class KubernetesVolumeSpec[T <: KubernetesVolumeSpecificConf]( +private[spark] case class KubernetesVolumeSpec( volumeName: String, mountPath: String, mountSubPath: String, mountReadOnly: Boolean, - volumeConf: T) + volumeConf: KubernetesVolumeSpecificConf) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala index 155326469235b..c0c4f86f1a6a0 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala @@ -16,10 +16,6 @@ */ package org.apache.spark.deploy.k8s -import java.util.NoSuchElementException - -import scala.util.{Failure, Success, Try} - import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s.Config._ @@ -31,9 +27,7 @@ private[spark] object KubernetesVolumeUtils { * @param prefix the given property name prefix * @return a Map storing with volume name as key and spec as value */ - def parseVolumesWithPrefix( - sparkConf: SparkConf, - prefix: String): Iterable[Try[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]]] = { + def parseVolumesWithPrefix(sparkConf: SparkConf, prefix: String): Seq[KubernetesVolumeSpec] = { val properties = sparkConf.getAllWithPrefix(prefix).toMap getVolumeTypesAndNames(properties).map { case (volumeType, volumeName) => @@ -41,17 +35,13 @@ private[spark] object KubernetesVolumeUtils { val readOnlyKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_READONLY_KEY" val subPathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_SUBPATH_KEY" - for { - path <- properties.getTry(pathKey) - volumeConf <- parseVolumeSpecificConf(properties, volumeType, volumeName) - } yield KubernetesVolumeSpec( + KubernetesVolumeSpec( volumeName = volumeName, - mountPath = path, + mountPath = properties(pathKey), mountSubPath = properties.get(subPathKey).getOrElse(""), mountReadOnly = properties.get(readOnlyKey).exists(_.toBoolean), - volumeConf = volumeConf - ) - } + volumeConf = parseVolumeSpecificConf(properties, volumeType, volumeName)) + }.toSeq } /** @@ -61,9 +51,7 @@ private[spark] object KubernetesVolumeUtils { * @param properties flat mapping of property names to values * @return Set[(volumeType, volumeName)] */ - private def getVolumeTypesAndNames( - properties: Map[String, String] - ): Set[(String, String)] = { + private def getVolumeTypesAndNames(properties: Map[String, String]): Set[(String, String)] = { properties.keys.flatMap { k => k.split('.').toList match { case tpe :: name :: _ => Some((tpe, name)) @@ -73,40 +61,25 @@ private[spark] object KubernetesVolumeUtils { } private def parseVolumeSpecificConf( - options: Map[String, String], - volumeType: String, - volumeName: String): Try[KubernetesVolumeSpecificConf] = { + options: Map[String, String], + volumeType: String, + volumeName: String): KubernetesVolumeSpecificConf = { volumeType match { case KUBERNETES_VOLUMES_HOSTPATH_TYPE => val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_PATH_KEY" - for { - path <- options.getTry(pathKey) - } yield KubernetesHostPathVolumeConf(path) + KubernetesHostPathVolumeConf(options(pathKey)) case KUBERNETES_VOLUMES_PVC_TYPE => val claimNameKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY" - for { - claimName <- options.getTry(claimNameKey) - } yield KubernetesPVCVolumeConf(claimName) + KubernetesPVCVolumeConf(options(claimNameKey)) case KUBERNETES_VOLUMES_EMPTYDIR_TYPE => val mediumKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY" val sizeLimitKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY" - Success(KubernetesEmptyDirVolumeConf(options.get(mediumKey), options.get(sizeLimitKey))) + KubernetesEmptyDirVolumeConf(options.get(mediumKey), options.get(sizeLimitKey)) case _ => - Failure(new RuntimeException(s"Kubernetes Volume type `$volumeType` is not supported")) - } - } - - /** - * Convenience wrapper to accumulate key lookup errors - */ - implicit private class MapOps[A, B](m: Map[A, B]) { - def getTry(key: A): Try[B] = { - m - .get(key) - .fold[Try[B]](Failure(new NoSuchElementException(key.toString)))(Success(_)) + throw new IllegalArgumentException(s"Kubernetes Volume type `$volumeType` is not supported") } } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala index 5ddf73cb16a6f..d8cf3653d3226 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala @@ -30,13 +30,12 @@ import org.apache.spark.internal.config._ import org.apache.spark.ui.SparkUI import org.apache.spark.util.Utils -private[spark] class BasicDriverFeatureStep( - conf: KubernetesConf[KubernetesDriverSpecificConf]) +private[spark] class BasicDriverFeatureStep(conf: KubernetesDriverConf) extends KubernetesFeatureConfigStep { private val driverPodName = conf .get(KUBERNETES_DRIVER_POD_NAME) - .getOrElse(s"${conf.appResourceNamePrefix}-driver") + .getOrElse(s"${conf.resourceNamePrefix}-driver") private val driverContainerImage = conf .get(DRIVER_CONTAINER_IMAGE) @@ -52,8 +51,8 @@ private[spark] class BasicDriverFeatureStep( // The memory overhead factor to use. If the user has not set it, then use a different // value for non-JVM apps. This value is propagated to executors. private val overheadFactor = - if (conf.roleSpecificConf.mainAppResource.isInstanceOf[NonJVMResource]) { - if (conf.sparkConf.contains(MEMORY_OVERHEAD_FACTOR)) { + if (conf.mainAppResource.isInstanceOf[NonJVMResource]) { + if (conf.contains(MEMORY_OVERHEAD_FACTOR)) { conf.get(MEMORY_OVERHEAD_FACTOR) } else { NON_JVM_MEMORY_OVERHEAD_FACTOR @@ -68,8 +67,7 @@ private[spark] class BasicDriverFeatureStep( private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB override def configurePod(pod: SparkPod): SparkPod = { - val driverCustomEnvs = conf.roleEnvs - .toSeq + val driverCustomEnvs = conf.environment.toSeq .map { env => new EnvVarBuilder() .withName(env._1) @@ -96,7 +94,7 @@ private[spark] class BasicDriverFeatureStep( val driverContainer = new ContainerBuilder(pod.container) .withName(Option(pod.container.getName).getOrElse(DEFAULT_DRIVER_CONTAINER_NAME)) .withImage(driverContainerImage) - .withImagePullPolicy(conf.imagePullPolicy()) + .withImagePullPolicy(conf.imagePullPolicy) .addNewPort() .withName(DRIVER_PORT_NAME) .withContainerPort(driverPort) @@ -130,13 +128,13 @@ private[spark] class BasicDriverFeatureStep( val driverPod = new PodBuilder(pod.pod) .editOrNewMetadata() .withName(driverPodName) - .addToLabels(conf.roleLabels.asJava) - .addToAnnotations(conf.roleAnnotations.asJava) + .addToLabels(conf.labels.asJava) + .addToAnnotations(conf.annotations.asJava) .endMetadata() .editOrNewSpec() .withRestartPolicy("Never") - .addToNodeSelector(conf.nodeSelector().asJava) - .addToImagePullSecrets(conf.imagePullSecrets(): _*) + .addToNodeSelector(conf.nodeSelector.asJava) + .addToImagePullSecrets(conf.imagePullSecrets: _*) .endSpec() .build() @@ -147,7 +145,7 @@ private[spark] class BasicDriverFeatureStep( val additionalProps = mutable.Map( KUBERNETES_DRIVER_POD_NAME.key -> driverPodName, "spark.app.id" -> conf.appId, - KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> conf.appResourceNamePrefix, + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> conf.resourceNamePrefix, KUBERNETES_DRIVER_SUBMIT_CHECK.key -> "true", MEMORY_OVERHEAD_FACTOR.key -> overheadFactor.toString) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index 7f397e6e84fa5..8bf315248388f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -29,8 +29,7 @@ import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils -private[spark] class BasicExecutorFeatureStep( - kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]) +private[spark] class BasicExecutorFeatureStep(kubernetesConf: KubernetesExecutorConf) extends KubernetesFeatureConfigStep { // Consider moving some of these fields to KubernetesConf or KubernetesExecutorSpecificConf @@ -42,7 +41,7 @@ private[spark] class BasicExecutorFeatureStep( .sparkConf .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) - private val executorPodNamePrefix = kubernetesConf.appResourceNamePrefix + private val executorPodNamePrefix = kubernetesConf.resourceNamePrefix private val driverUrl = RpcEndpointAddress( kubernetesConf.get("spark.driver.host"), @@ -76,7 +75,7 @@ private[spark] class BasicExecutorFeatureStep( private val executorLimitCores = kubernetesConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) override def configurePod(pod: SparkPod): SparkPod = { - val name = s"$executorPodNamePrefix-exec-${kubernetesConf.roleSpecificConf.executorId}" + val name = s"$executorPodNamePrefix-exec-${kubernetesConf.executorId}" // hostname must be no longer than 63 characters, so take the last 63 characters of the pod // name as the hostname. This preserves uniqueness since the end of name contains @@ -98,7 +97,7 @@ private[spark] class BasicExecutorFeatureStep( .get(EXECUTOR_JAVA_OPTIONS) .map { opts => val subsOpts = Utils.substituteAppNExecIds(opts, kubernetesConf.appId, - kubernetesConf.roleSpecificConf.executorId) + kubernetesConf.executorId) val delimitedOpts = Utils.splitCommandString(subsOpts) delimitedOpts.zipWithIndex.map { case (opt, index) => @@ -112,8 +111,8 @@ private[spark] class BasicExecutorFeatureStep( (ENV_APPLICATION_ID, kubernetesConf.appId), // This is to set the SPARK_CONF_DIR to be /opt/spark/conf (ENV_SPARK_CONF_DIR, SPARK_CONF_DIR_INTERNAL), - (ENV_EXECUTOR_ID, kubernetesConf.roleSpecificConf.executorId)) ++ - kubernetesConf.roleEnvs) + (ENV_EXECUTOR_ID, kubernetesConf.executorId)) ++ + kubernetesConf.environment) .map(env => new EnvVarBuilder() .withName(env._1) .withValue(env._2) @@ -138,7 +137,7 @@ private[spark] class BasicExecutorFeatureStep( val executorContainer = new ContainerBuilder(pod.container) .withName(Option(pod.container.getName).getOrElse(DEFAULT_EXECUTOR_CONTAINER_NAME)) .withImage(executorContainerImage) - .withImagePullPolicy(kubernetesConf.imagePullPolicy()) + .withImagePullPolicy(kubernetesConf.imagePullPolicy) .editOrNewResources() .addToRequests("memory", executorMemoryQuantity) .addToLimits("memory", executorMemoryQuantity) @@ -158,27 +157,27 @@ private[spark] class BasicExecutorFeatureStep( .endResources() .build() }.getOrElse(executorContainer) - val driverPod = kubernetesConf.roleSpecificConf.driverPod - val ownerReference = driverPod.map(pod => + val ownerReference = kubernetesConf.driverPod.map { pod => new OwnerReferenceBuilder() .withController(true) .withApiVersion(pod.getApiVersion) .withKind(pod.getKind) .withName(pod.getMetadata.getName) .withUid(pod.getMetadata.getUid) - .build()) + .build() + } val executorPod = new PodBuilder(pod.pod) .editOrNewMetadata() .withName(name) - .addToLabels(kubernetesConf.roleLabels.asJava) - .addToAnnotations(kubernetesConf.roleAnnotations.asJava) + .addToLabels(kubernetesConf.labels.asJava) + .addToAnnotations(kubernetesConf.annotations.asJava) .addToOwnerReferences(ownerReference.toSeq: _*) .endMetadata() .editOrNewSpec() .withHostname(hostname) .withRestartPolicy("Never") - .addToNodeSelector(kubernetesConf.nodeSelector().asJava) - .addToImagePullSecrets(kubernetesConf.imagePullSecrets(): _*) + .addToNodeSelector(kubernetesConf.nodeSelector.asJava) + .addToImagePullSecrets(kubernetesConf.imagePullSecrets: _*) .endSpec() .build() diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStep.scala index 8b8f0d01d49f7..76b4ec98d494e 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStep.scala @@ -32,13 +32,11 @@ import org.apache.spark.util.Utils * Creates the driver command for running the user app, and propagates needed configuration so * executors can also find the app code. */ -private[spark] class DriverCommandFeatureStep(conf: KubernetesConf[KubernetesDriverSpecificConf]) +private[spark] class DriverCommandFeatureStep(conf: KubernetesDriverConf) extends KubernetesFeatureConfigStep { - private val driverConf = conf.roleSpecificConf - override def configurePod(pod: SparkPod): SparkPod = { - driverConf.mainAppResource match { + conf.mainAppResource match { case JavaMainAppResource(_) => configureForJava(pod) @@ -51,7 +49,7 @@ private[spark] class DriverCommandFeatureStep(conf: KubernetesConf[KubernetesDri } override def getAdditionalPodSystemProperties(): Map[String, String] = { - driverConf.mainAppResource match { + conf.mainAppResource match { case JavaMainAppResource(res) => res.map(additionalJavaProperties).getOrElse(Map.empty) @@ -71,10 +69,10 @@ private[spark] class DriverCommandFeatureStep(conf: KubernetesConf[KubernetesDri } private def configureForPython(pod: SparkPod, res: String): SparkPod = { - val maybePythonFiles = if (driverConf.pyFiles.nonEmpty) { + val maybePythonFiles = if (conf.pyFiles.nonEmpty) { // Delineation by ":" is to append the PySpark Files to the PYTHONPATH // of the respective PySpark pod - val resolved = KubernetesUtils.resolveFileUrisAndPath(driverConf.pyFiles) + val resolved = KubernetesUtils.resolveFileUrisAndPath(conf.pyFiles) Some(new EnvVarBuilder() .withName(ENV_PYSPARK_FILES) .withValue(resolved.mkString(":")) @@ -85,7 +83,7 @@ private[spark] class DriverCommandFeatureStep(conf: KubernetesConf[KubernetesDri val pythonEnvs = Seq(new EnvVarBuilder() .withName(ENV_PYSPARK_MAJOR_PYTHON_VERSION) - .withValue(conf.sparkConf.get(PYSPARK_MAJOR_PYTHON_VERSION)) + .withValue(conf.get(PYSPARK_MAJOR_PYTHON_VERSION)) .build()) ++ maybePythonFiles @@ -105,9 +103,9 @@ private[spark] class DriverCommandFeatureStep(conf: KubernetesConf[KubernetesDri new ContainerBuilder(pod.container) .addToArgs("driver") .addToArgs("--properties-file", SPARK_CONF_PATH) - .addToArgs("--class", driverConf.mainClass) + .addToArgs("--class", conf.mainClass) .addToArgs(resource) - .addToArgs(driverConf.appArgs: _*) + .addToArgs(conf.appArgs: _*) } private def additionalJavaProperties(resource: String): Map[String, String] = { @@ -116,7 +114,7 @@ private[spark] class DriverCommandFeatureStep(conf: KubernetesConf[KubernetesDri private def additionalPythonProperties(resource: String): Map[String, String] = { resourceType(APP_RESOURCE_TYPE_PYTHON) ++ - mergeFileList("spark.files", Seq(resource) ++ driverConf.pyFiles) + mergeFileList("spark.files", Seq(resource) ++ conf.pyFiles) } private def additionalRProperties(resource: String): Map[String, String] = { @@ -124,7 +122,7 @@ private[spark] class DriverCommandFeatureStep(conf: KubernetesConf[KubernetesDri } private def mergeFileList(key: String, filesToAdd: Seq[String]): Map[String, String] = { - val existing = Utils.stringToSeq(conf.sparkConf.get(key, "")) + val existing = Utils.stringToSeq(conf.get(key, "")) Map(key -> (existing ++ filesToAdd).distinct.mkString(",")) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala index ff5ad6673b309..795ca49a3c87b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala @@ -28,7 +28,7 @@ import org.apache.spark.deploy.k8s.{KubernetesConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -private[spark] class DriverKubernetesCredentialsFeatureStep(kubernetesConf: KubernetesConf[_]) +private[spark] class DriverKubernetesCredentialsFeatureStep(kubernetesConf: KubernetesConf) extends KubernetesFeatureConfigStep { // TODO clean up this class, and credentials in general. See also SparkKubernetesClientFactory. // We should use a struct to hold all creds-related fields. A lot of the code is very repetitive. @@ -66,7 +66,7 @@ private[spark] class DriverKubernetesCredentialsFeatureStep(kubernetesConf: Kube clientCertDataBase64.isDefined private val driverCredentialsSecretName = - s"${kubernetesConf.appResourceNamePrefix}-kubernetes-credentials" + s"${kubernetesConf.resourceNamePrefix}-kubernetes-credentials" override def configurePod(pod: SparkPod): SparkPod = { if (!shouldMountSecret) { @@ -122,7 +122,7 @@ private[spark] class DriverKubernetesCredentialsFeatureStep(kubernetesConf: Kube val redactedTokens = kubernetesConf.sparkConf.getAll .filter(_._1.endsWith(OAUTH_TOKEN_CONF_SUFFIX)) .toMap - .mapValues( _ => "") + .map { case (k, v) => (k, "") } redactedTokens ++ resolvedMountedCaCertFile.map { file => Map( diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala index f2d7bbd08f305..42305457f4fff 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala @@ -20,13 +20,13 @@ import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model.{HasMetadata, ServiceBuilder} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.{KubernetesDriverConf, SparkPod} import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging import org.apache.spark.util.{Clock, SystemClock} private[spark] class DriverServiceFeatureStep( - kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf], + kubernetesConf: KubernetesDriverConf, clock: Clock = new SystemClock) extends KubernetesFeatureConfigStep with Logging { import DriverServiceFeatureStep._ @@ -38,7 +38,7 @@ private[spark] class DriverServiceFeatureStep( s"$DRIVER_HOST_KEY is not supported in Kubernetes mode, as the driver's hostname will be " + "managed via a Kubernetes service.") - private val preferredServiceName = s"${kubernetesConf.appResourceNamePrefix}$DRIVER_SVC_POSTFIX" + private val preferredServiceName = s"${kubernetesConf.resourceNamePrefix}$DRIVER_SVC_POSTFIX" private val resolvedServiceName = if (preferredServiceName.length <= MAX_SERVICE_NAME_LENGTH) { preferredServiceName } else { @@ -58,7 +58,7 @@ private[spark] class DriverServiceFeatureStep( override def configurePod(pod: SparkPod): SparkPod = pod override def getAdditionalPodSystemProperties(): Map[String, String] = { - val driverHostname = s"$resolvedServiceName.${kubernetesConf.namespace()}.svc" + val driverHostname = s"$resolvedServiceName.${kubernetesConf.namespace}.svc" Map(DRIVER_HOST_KEY -> driverHostname, "spark.driver.port" -> driverPort.toString, org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key -> @@ -72,7 +72,7 @@ private[spark] class DriverServiceFeatureStep( .endMetadata() .withNewSpec() .withClusterIP("None") - .withSelector(kubernetesConf.roleLabels.asJava) + .withSelector(kubernetesConf.labels.asJava) .addNewPort() .withName(DRIVER_PORT_NAME) .withPort(driverPort) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala index 03ff7d48420ff..d78f04dcc40e6 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala @@ -20,14 +20,13 @@ import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, HasMetadata} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesRoleSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.{KubernetesConf, SparkPod} -private[spark] class EnvSecretsFeatureStep( - kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) +private[spark] class EnvSecretsFeatureStep(kubernetesConf: KubernetesConf) extends KubernetesFeatureConfigStep { override def configurePod(pod: SparkPod): SparkPod = { val addedEnvSecrets = kubernetesConf - .roleSecretEnvNamesToKeyRefs + .secretEnvNamesToKeyRefs .map{ case (envName, keyRef) => // Keyref parts val keyRefParts = keyRef.split(":") @@ -50,8 +49,4 @@ private[spark] class EnvSecretsFeatureStep( .build() SparkPod(pod.pod, containerWithEnvVars) } - - override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty - - override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala index fd09de2a918a1..bca66759d586e 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala @@ -16,9 +16,7 @@ */ package org.apache.spark.deploy.k8s.features -import io.fabric8.kubernetes.api.model.HasMetadata - -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, SparkPod} import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.features.hadooputils.HadoopBootstrapUtil import org.apache.spark.internal.Logging @@ -28,21 +26,15 @@ import org.apache.spark.internal.Logging * containing Hadoop config files mounted as volumes and an ENV variable * pointed to the mounted file directory. */ -private[spark] class HadoopConfExecutorFeatureStep( - kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]) +private[spark] class HadoopConfExecutorFeatureStep(conf: KubernetesExecutorConf) extends KubernetesFeatureConfigStep with Logging { override def configurePod(pod: SparkPod): SparkPod = { - val sparkConf = kubernetesConf.sparkConf - val hadoopConfDirCMapName = sparkConf.getOption(HADOOP_CONFIG_MAP_NAME) + val hadoopConfDirCMapName = conf.getOption(HADOOP_CONFIG_MAP_NAME) require(hadoopConfDirCMapName.isDefined, "Ensure that the env `HADOOP_CONF_DIR` is defined either in the client or " + " using pre-existing ConfigMaps") logInfo("HADOOP_CONF_DIR defined") HadoopBootstrapUtil.bootstrapHadoopConfDir(None, None, hadoopConfDirCMapName, pod) } - - override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty - - override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala index 5b6a6d5a7db45..e342110763196 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala @@ -16,28 +16,19 @@ */ package org.apache.spark.deploy.k8s.features -import io.fabric8.kubernetes.api.model.HasMetadata - -import org.apache.spark.deploy.k8s.{KubernetesConf, SparkPod} +import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, SparkPod} import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.KubernetesExecutorSpecificConf import org.apache.spark.deploy.k8s.features.hadooputils.HadoopBootstrapUtil -import org.apache.spark.internal.Logging /** * This step is responsible for setting ENV_SPARK_USER when HADOOP_FILES are detected * however, this step would not be run if Kerberos is enabled, as Kerberos sets SPARK_USER */ -private[spark] class HadoopSparkUserExecutorFeatureStep( - kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]) - extends KubernetesFeatureConfigStep with Logging { +private[spark] class HadoopSparkUserExecutorFeatureStep(conf: KubernetesExecutorConf) + extends KubernetesFeatureConfigStep { override def configurePod(pod: SparkPod): SparkPod = { - val sparkUserName = kubernetesConf.sparkConf.get(KERBEROS_SPARK_USER_NAME) + val sparkUserName = conf.get(KERBEROS_SPARK_USER_NAME) HadoopBootstrapUtil.bootstrapSparkUserPod(sparkUserName, pod) } - - override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty - - override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala index ce47933b7f700..c6d5a866fa7bc 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala @@ -16,40 +16,43 @@ */ package org.apache.spark.deploy.k8s.features -import io.fabric8.kubernetes.api.model.HasMetadata +import io.fabric8.kubernetes.api.model.{HasMetadata, Secret, SecretBuilder} +import org.apache.commons.codec.binary.Base64 +import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s.{KubernetesDriverConf, KubernetesUtils, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.KubernetesDriverSpecificConf import org.apache.spark.deploy.k8s.features.hadooputils._ -import org.apache.spark.internal.Logging +import org.apache.spark.deploy.security.HadoopDelegationTokenManager /** * Runs the necessary Hadoop-based logic based on Kerberos configs and the presence of the * HADOOP_CONF_DIR. This runs various bootstrap methods defined in HadoopBootstrapUtil. */ -private[spark] class KerberosConfDriverFeatureStep( - kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]) - extends KubernetesFeatureConfigStep with Logging { - - require(kubernetesConf.hadoopConfSpec.isDefined, - "Ensure that HADOOP_CONF_DIR is defined either via env or a pre-defined ConfigMap") - private val hadoopConfDirSpec = kubernetesConf.hadoopConfSpec.get - private val conf = kubernetesConf.sparkConf - private val principal = conf.get(org.apache.spark.internal.config.PRINCIPAL) - private val keytab = conf.get(org.apache.spark.internal.config.KEYTAB) - private val existingSecretName = conf.get(KUBERNETES_KERBEROS_DT_SECRET_NAME) - private val existingSecretItemKey = conf.get(KUBERNETES_KERBEROS_DT_SECRET_ITEM_KEY) - private val krb5File = conf.get(KUBERNETES_KERBEROS_KRB5_FILE) - private val krb5CMap = conf.get(KUBERNETES_KERBEROS_KRB5_CONFIG_MAP) - private val kubeTokenManager = kubernetesConf.tokenManager(conf, - SparkHadoopUtil.get.newConfiguration(conf)) +private[spark] class KerberosConfDriverFeatureStep(kubernetesConf: KubernetesDriverConf) + extends KubernetesFeatureConfigStep { + + private val hadoopConfDir = Option(kubernetesConf.sparkConf.getenv(ENV_HADOOP_CONF_DIR)) + private val hadoopConfigMapName = kubernetesConf.get(KUBERNETES_HADOOP_CONF_CONFIG_MAP) + KubernetesUtils.requireNandDefined( + hadoopConfDir, + hadoopConfigMapName, + "Do not specify both the `HADOOP_CONF_DIR` in your ENV and the ConfigMap " + + "as the creation of an additional ConfigMap, when one is already specified is extraneous") + + private val principal = kubernetesConf.get(org.apache.spark.internal.config.PRINCIPAL) + private val keytab = kubernetesConf.get(org.apache.spark.internal.config.KEYTAB) + private val existingSecretName = kubernetesConf.get(KUBERNETES_KERBEROS_DT_SECRET_NAME) + private val existingSecretItemKey = kubernetesConf.get(KUBERNETES_KERBEROS_DT_SECRET_ITEM_KEY) + private val krb5File = kubernetesConf.get(KUBERNETES_KERBEROS_KRB5_FILE) + private val krb5CMap = kubernetesConf.get(KUBERNETES_KERBEROS_KRB5_CONFIG_MAP) + private val hadoopConf = SparkHadoopUtil.get.newConfiguration(kubernetesConf.sparkConf) + private val tokenManager = new HadoopDelegationTokenManager(kubernetesConf.sparkConf, hadoopConf) private val isKerberosEnabled = - (hadoopConfDirSpec.hadoopConfDir.isDefined && kubeTokenManager.isSecurityEnabled) || - (hadoopConfDirSpec.hadoopConfigMapName.isDefined && - (krb5File.isDefined || krb5CMap.isDefined)) + (hadoopConfDir.isDefined && UserGroupInformation.isSecurityEnabled) || + (hadoopConfigMapName.isDefined && (krb5File.isDefined || krb5CMap.isDefined)) require(keytab.isEmpty || isKerberosEnabled, "You must enable Kerberos support if you are specifying a Kerberos Keytab") @@ -76,11 +79,11 @@ private[spark] class KerberosConfDriverFeatureStep( "If a secret storing a Kerberos Delegation Token is specified you must also" + " specify the item-key where the data is stored") - private val hadoopConfigurationFiles = hadoopConfDirSpec.hadoopConfDir.map { hConfDir => + private val hadoopConfigurationFiles = hadoopConfDir.map { hConfDir => HadoopBootstrapUtil.getHadoopConfFiles(hConfDir) } private val newHadoopConfigMapName = - if (hadoopConfDirSpec.hadoopConfigMapName.isEmpty) { + if (hadoopConfigMapName.isEmpty) { Some(kubernetesConf.hadoopConfigMapName) } else { None @@ -95,23 +98,24 @@ private[spark] class KerberosConfDriverFeatureStep( dtSecret = None, dtSecretName = secretName, dtSecretItemKey = secretItemKey, - jobUserName = kubeTokenManager.getCurrentUser.getShortUserName) + jobUserName = UserGroupInformation.getCurrentUser.getShortUserName) }).orElse( if (isKerberosEnabled) { - Some(HadoopKerberosLogin.buildSpec( - conf, - kubernetesConf.appResourceNamePrefix, - kubeTokenManager)) + Some(buildKerberosSpec()) } else { None } ) override def configurePod(pod: SparkPod): SparkPod = { + if (!isKerberosEnabled) { + return pod + } + val hadoopBasedSparkPod = HadoopBootstrapUtil.bootstrapHadoopConfDir( - hadoopConfDirSpec.hadoopConfDir, + hadoopConfDir, newHadoopConfigMapName, - hadoopConfDirSpec.hadoopConfigMapName, + hadoopConfigMapName, pod) kerberosConfSpec.map { hSpec => HadoopBootstrapUtil.bootstrapKerberosPod( @@ -124,11 +128,15 @@ private[spark] class KerberosConfDriverFeatureStep( hadoopBasedSparkPod) }.getOrElse( HadoopBootstrapUtil.bootstrapSparkUserPod( - kubeTokenManager.getCurrentUser.getShortUserName, + UserGroupInformation.getCurrentUser.getShortUserName, hadoopBasedSparkPod)) } override def getAdditionalPodSystemProperties(): Map[String, String] = { + if (!isKerberosEnabled) { + return Map.empty + } + val resolvedConfValues = kerberosConfSpec.map { hSpec => Map(KERBEROS_DT_SECRET_NAME -> hSpec.dtSecretName, KERBEROS_DT_SECRET_KEY -> hSpec.dtSecretItemKey, @@ -136,13 +144,16 @@ private[spark] class KerberosConfDriverFeatureStep( KRB5_CONFIG_MAP_NAME -> krb5CMap.getOrElse(kubernetesConf.krbConfigMapName)) }.getOrElse( Map(KERBEROS_SPARK_USER_NAME -> - kubeTokenManager.getCurrentUser.getShortUserName)) + UserGroupInformation.getCurrentUser.getShortUserName)) Map(HADOOP_CONFIG_MAP_NAME -> - hadoopConfDirSpec.hadoopConfigMapName.getOrElse( - kubernetesConf.hadoopConfigMapName)) ++ resolvedConfValues + hadoopConfigMapName.getOrElse(kubernetesConf.hadoopConfigMapName)) ++ resolvedConfValues } override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { + if (!isKerberosEnabled) { + return Seq.empty + } + val hadoopConfConfigMap = for { hName <- newHadoopConfigMapName hFiles <- hadoopConfigurationFiles @@ -162,4 +173,34 @@ private[spark] class KerberosConfDriverFeatureStep( krb5ConfigMap.toSeq ++ kerberosDTSecret.toSeq } + + private def buildKerberosSpec(): KerberosConfigSpec = { + // The JobUserUGI will be taken fom the Local Ticket Cache or via keytab+principal + // The login happens in the SparkSubmit so login logic is not necessary to include + val jobUserUGI = UserGroupInformation.getCurrentUser + val creds = jobUserUGI.getCredentials + tokenManager.obtainDelegationTokens(creds) + val tokenData = SparkHadoopUtil.get.serialize(creds) + require(tokenData.nonEmpty, "Did not obtain any delegation tokens") + val newSecretName = + s"${kubernetesConf.resourceNamePrefix}-$KERBEROS_DELEGEGATION_TOKEN_SECRET_NAME" + val secretDT = + new SecretBuilder() + .withNewMetadata() + .withName(newSecretName) + .endMetadata() + .addToData(KERBEROS_SECRET_KEY, Base64.encodeBase64String(tokenData)) + .build() + KerberosConfigSpec( + dtSecret = Some(secretDT), + dtSecretName = newSecretName, + dtSecretItemKey = KERBEROS_SECRET_KEY, + jobUserName = jobUserUGI.getShortUserName) + } + + private case class KerberosConfigSpec( + dtSecret: Option[Secret], + dtSecretName: String, + dtSecretItemKey: String, + jobUserName: String) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala index 06a88b6c229f7..32bb6a5d2bcbb 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala @@ -16,38 +16,29 @@ */ package org.apache.spark.deploy.k8s.features -import io.fabric8.kubernetes.api.model.HasMetadata - -import org.apache.spark.deploy.k8s.{KubernetesConf, SparkPod} +import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, SparkPod} import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.KubernetesExecutorSpecificConf import org.apache.spark.deploy.k8s.features.hadooputils.HadoopBootstrapUtil import org.apache.spark.internal.Logging /** * This step is responsible for mounting the DT secret for the executors */ -private[spark] class KerberosConfExecutorFeatureStep( - kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]) +private[spark] class KerberosConfExecutorFeatureStep(conf: KubernetesExecutorConf) extends KubernetesFeatureConfigStep with Logging { - private val sparkConf = kubernetesConf.sparkConf - private val maybeKrb5CMap = sparkConf.getOption(KRB5_CONFIG_MAP_NAME) + private val maybeKrb5CMap = conf.getOption(KRB5_CONFIG_MAP_NAME) require(maybeKrb5CMap.isDefined, "HADOOP_CONF_DIR ConfigMap not found") override def configurePod(pod: SparkPod): SparkPod = { logInfo(s"Mounting Resources for Kerberos") HadoopBootstrapUtil.bootstrapKerberosPod( - sparkConf.get(KERBEROS_DT_SECRET_NAME), - sparkConf.get(KERBEROS_DT_SECRET_KEY), - sparkConf.get(KERBEROS_SPARK_USER_NAME), + conf.get(KERBEROS_DT_SECRET_NAME), + conf.get(KERBEROS_DT_SECRET_KEY), + conf.get(KERBEROS_SPARK_USER_NAME), None, None, maybeKrb5CMap, pod) } - - override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty - - override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty[HasMetadata] } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala index be386e119d465..19ed2df5551db 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala @@ -16,16 +16,15 @@ */ package org.apache.spark.deploy.k8s.features -import java.nio.file.Paths import java.util.UUID import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, VolumeBuilder, VolumeMountBuilder} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.{KubernetesConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ private[spark] class LocalDirsFeatureStep( - conf: KubernetesConf[_ <: KubernetesRoleSpecificConf], + conf: KubernetesConf, defaultLocalDir: String = s"/var/data/spark-${UUID.randomUUID}") extends KubernetesFeatureConfigStep { @@ -73,8 +72,4 @@ private[spark] class LocalDirsFeatureStep( .build() SparkPod(podWithLocalDirVolumes, containerWithLocalDirVolumeMounts) } - - override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty - - override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala index 97fa9499b2edb..f4e1a3a326729 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala @@ -18,14 +18,13 @@ package org.apache.spark.deploy.k8s.features import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, VolumeBuilder, VolumeMountBuilder} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesRoleSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.{KubernetesConf, SparkPod} -private[spark] class MountSecretsFeatureStep( - kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) +private[spark] class MountSecretsFeatureStep(kubernetesConf: KubernetesConf) extends KubernetesFeatureConfigStep { override def configurePod(pod: SparkPod): SparkPod = { val addedVolumes = kubernetesConf - .roleSecretNamesToMountPaths + .secretNamesToMountPaths .keys .map(secretName => new VolumeBuilder() @@ -40,7 +39,7 @@ private[spark] class MountSecretsFeatureStep( .endSpec() .build() val addedVolumeMounts = kubernetesConf - .roleSecretNamesToMountPaths + .secretNamesToMountPaths .map { case (secretName, mountPath) => new VolumeMountBuilder() @@ -54,9 +53,5 @@ private[spark] class MountSecretsFeatureStep( SparkPod(podWithVolumes, containerWithMounts) } - override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty - - override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty - private def secretVolumeName(secretName: String): String = s"$secretName-volume" } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala index 1473a7d3ee7f6..8548e7057cdf0 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala @@ -20,12 +20,11 @@ import io.fabric8.kubernetes.api.model._ import org.apache.spark.deploy.k8s._ -private[spark] class MountVolumesFeatureStep( - kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) +private[spark] class MountVolumesFeatureStep(conf: KubernetesConf) extends KubernetesFeatureConfigStep { override def configurePod(pod: SparkPod): SparkPod = { - val (volumeMounts, volumes) = constructVolumes(kubernetesConf.roleVolumes).unzip + val (volumeMounts, volumes) = constructVolumes(conf.volumes).unzip val podWithVolumes = new PodBuilder(pod.pod) .editSpec() @@ -40,12 +39,8 @@ private[spark] class MountVolumesFeatureStep( SparkPod(podWithVolumes, containerWithVolumeMounts) } - override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty - - override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty - private def constructVolumes( - volumeSpecs: Iterable[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]] + volumeSpecs: Iterable[KubernetesVolumeSpec] ): Iterable[(VolumeMount, Volume)] = { volumeSpecs.map { spec => val volumeMount = new VolumeMountBuilder() diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala index 28e2d1726ae27..09dcf93a54f8e 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala @@ -22,12 +22,11 @@ import java.nio.charset.StandardCharsets import com.google.common.io.Files import io.fabric8.kubernetes.api.model.{ConfigMapBuilder, ContainerBuilder, HasMetadata, PodBuilder} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesRoleSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.{KubernetesConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -private[spark] class PodTemplateConfigMapStep( - conf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) +private[spark] class PodTemplateConfigMapStep(conf: KubernetesConf) extends KubernetesFeatureConfigStep { def configurePod(pod: SparkPod): SparkPod = { val podWithVolume = new PodBuilder(pod.pod) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/HadoopKerberosLogin.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/HadoopKerberosLogin.scala deleted file mode 100644 index 0022d8f242a72..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/HadoopKerberosLogin.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.features.hadooputils - -import io.fabric8.kubernetes.api.model.SecretBuilder -import org.apache.commons.codec.binary.Base64 - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.security.KubernetesHadoopDelegationTokenManager - -/** - * This logic does all the heavy lifting for Delegation Token creation. This step - * assumes that the job user has either specified a principal and keytab or ran - * $kinit before running spark-submit. By running UGI.getCurrentUser we are able - * to obtain the current user, either signed in via $kinit or keytab. With the - * Job User principal you then retrieve the delegation token from the NameNode - * and store values in DelegationToken. Lastly, the class puts the data into - * a secret. All this is defined in a KerberosConfigSpec. - */ -private[spark] object HadoopKerberosLogin { - def buildSpec( - submissionSparkConf: SparkConf, - kubernetesResourceNamePrefix: String, - tokenManager: KubernetesHadoopDelegationTokenManager): KerberosConfigSpec = { - // The JobUserUGI will be taken fom the Local Ticket Cache or via keytab+principal - // The login happens in the SparkSubmit so login logic is not necessary to include - val jobUserUGI = tokenManager.getCurrentUser - val originalCredentials = jobUserUGI.getCredentials - tokenManager.obtainDelegationTokens(originalCredentials) - - val tokenData = SparkHadoopUtil.get.serialize(originalCredentials) - - val initialTokenDataKeyName = KERBEROS_SECRET_KEY - val newSecretName = s"$kubernetesResourceNamePrefix-$KERBEROS_DELEGEGATION_TOKEN_SECRET_NAME" - val secretDT = - new SecretBuilder() - .withNewMetadata() - .withName(newSecretName) - .endMetadata() - .addToData(initialTokenDataKeyName, Base64.encodeBase64String(tokenData)) - .build() - KerberosConfigSpec( - dtSecret = Some(secretDT), - dtSecretName = newSecretName, - dtSecretItemKey = initialTokenDataKeyName, - jobUserName = jobUserUGI.getShortUserName) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/security/KubernetesHadoopDelegationTokenManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/security/KubernetesHadoopDelegationTokenManager.scala deleted file mode 100644 index 3e98d5811d83f..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/security/KubernetesHadoopDelegationTokenManager.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.k8s.security - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.security.UserGroupInformation - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.security.HadoopDelegationTokenManager - -/** - * Adds Kubernetes-specific functionality to HadoopDelegationTokenManager. - */ -private[spark] class KubernetesHadoopDelegationTokenManager( - _sparkConf: SparkConf, - _hadoopConf: Configuration) - extends HadoopDelegationTokenManager(_sparkConf, _hadoopConf) { - - def getCurrentUser: UserGroupInformation = UserGroupInformation.getCurrentUser - def isSecurityEnabled: Boolean = UserGroupInformation.isSecurityEnabled - -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index 543d6b16d6ae2..70a93c968795e 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -17,19 +17,17 @@ package org.apache.spark.deploy.k8s.submit import java.io.StringWriter -import java.util.{Collections, Locale, Properties, UUID} import java.util.{Collections, UUID} import java.util.Properties import io.fabric8.kubernetes.api.model._ import io.fabric8.kubernetes.client.KubernetesClient -import org.apache.hadoop.security.UserGroupInformation import scala.collection.mutable import scala.util.control.NonFatal import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkApplication -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkKubernetesClientFactory} +import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging @@ -47,8 +45,7 @@ private[spark] case class ClientArguments( mainAppResource: MainAppResource, mainClass: String, driverArgs: Array[String], - maybePyFiles: Option[String], - hadoopConfigDir: Option[String]) + maybePyFiles: Option[String]) private[spark] object ClientArguments { @@ -82,8 +79,7 @@ private[spark] object ClientArguments { mainAppResource, mainClass.get, driverArgs.toArray, - maybePyFiles, - sys.env.get(ENV_HADOOP_CONF_DIR)) + maybePyFiles) } } @@ -92,27 +88,24 @@ private[spark] object ClientArguments { * watcher that monitors and logs the application status. Waits for the application to terminate if * spark.kubernetes.submission.waitAppCompletion is true. * + * @param conf The kubernetes driver config. * @param builder Responsible for building the base driver pod based on a composition of * implemented features. - * @param kubernetesConf application configuration * @param kubernetesClient the client to talk to the Kubernetes API server * @param waitForAppCompletion a flag indicating whether the client should wait for the application * to complete - * @param appName the application name * @param watcher a watcher that monitors and logs the application status */ private[spark] class Client( + conf: KubernetesDriverConf, builder: KubernetesDriverBuilder, - kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf], kubernetesClient: KubernetesClient, waitForAppCompletion: Boolean, - appName: String, - watcher: LoggingPodStatusWatcher, - kubernetesResourceNamePrefix: String) extends Logging { + watcher: LoggingPodStatusWatcher) extends Logging { def run(): Unit = { - val resolvedDriverSpec = builder.buildFromFeatures(kubernetesConf) - val configMapName = s"$kubernetesResourceNamePrefix-driver-conf-map" + val resolvedDriverSpec = builder.buildFromFeatures(conf) + val configMapName = s"${conf.resourceNamePrefix}-driver-conf-map" val configMap = buildConfigMap(configMapName, resolvedDriverSpec.systemProperties) // The include of the ENV_VAR for "SPARK_CONF_DIR" is to allow for the // Spark command builder to pickup on the Java Options present in the ConfigMap @@ -155,11 +148,11 @@ private[spark] class Client( } if (waitForAppCompletion) { - logInfo(s"Waiting for application $appName to finish...") + logInfo(s"Waiting for application ${conf.appName} to finish...") watcher.awaitCompletion() - logInfo(s"Application $appName finished.") + logInfo(s"Application ${conf.appName} finished.") } else { - logInfo(s"Deployed Spark application $appName into Kubernetes.") + logInfo(s"Deployed Spark application ${conf.appName} into Kubernetes.") } } } @@ -216,19 +209,13 @@ private[spark] class KubernetesClientApplication extends SparkApplication { // a unique app ID (captured by spark.app.id) in the format below. val kubernetesAppId = s"spark-${UUID.randomUUID().toString.replaceAll("-", "")}" val waitForAppCompletion = sparkConf.get(WAIT_FOR_APP_COMPLETION) - val kubernetesResourceNamePrefix = KubernetesClientApplication.getResourceNamePrefix(appName) - sparkConf.set(KUBERNETES_PYSPARK_PY_FILES, clientArguments.maybePyFiles.getOrElse("")) val kubernetesConf = KubernetesConf.createDriverConf( sparkConf, - appName, - kubernetesResourceNamePrefix, kubernetesAppId, clientArguments.mainAppResource, clientArguments.mainClass, clientArguments.driverArgs, - clientArguments.maybePyFiles, - clientArguments.hadoopConfigDir) - val namespace = kubernetesConf.namespace() + clientArguments.maybePyFiles) // The master URL has been checked for validity already in SparkSubmit. // We just need to get rid of the "k8s://" prefix here. val master = KubernetesUtils.parseMasterUrl(sparkConf.get("spark.master")) @@ -238,36 +225,18 @@ private[spark] class KubernetesClientApplication extends SparkApplication { Utils.tryWithResource(SparkKubernetesClientFactory.createKubernetesClient( master, - Some(namespace), + Some(kubernetesConf.namespace), KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX, sparkConf, None, None)) { kubernetesClient => val client = new Client( - KubernetesDriverBuilder(kubernetesClient, kubernetesConf.sparkConf), kubernetesConf, + KubernetesDriverBuilder(kubernetesClient, kubernetesConf.sparkConf), kubernetesClient, waitForAppCompletion, - appName, - watcher, - kubernetesResourceNamePrefix) + watcher) client.run() } } } - -private[spark] object KubernetesClientApplication { - - def getAppName(conf: SparkConf): String = conf.getOption("spark.app.name").getOrElse("spark") - - def getResourceNamePrefix(appName: String): String = { - val launchTime = System.currentTimeMillis() - s"$appName-$launchTime" - .trim - .toLowerCase(Locale.ROOT) - .replaceAll("\\s+", "-") - .replaceAll("\\.", "-") - .replaceAll("[^a-z0-9\\-]", "") - .replaceAll("-+", "-") - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index 167fb402cd402..a5ad9729aee9a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -21,57 +21,46 @@ import java.io.File import io.fabric8.kubernetes.client.KubernetesClient import org.apache.spark.SparkConf -import org.apache.spark.deploy.k8s.{Config, KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.features._ private[spark] class KubernetesDriverBuilder( - provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep = + provideBasicStep: (KubernetesDriverConf => BasicDriverFeatureStep) = new BasicDriverFeatureStep(_), - provideCredentialsStep: (KubernetesConf[KubernetesDriverSpecificConf]) - => DriverKubernetesCredentialsFeatureStep = + provideCredentialsStep: (KubernetesDriverConf => DriverKubernetesCredentialsFeatureStep) = new DriverKubernetesCredentialsFeatureStep(_), - provideServiceStep: (KubernetesConf[KubernetesDriverSpecificConf]) => DriverServiceFeatureStep = + provideServiceStep: (KubernetesDriverConf => DriverServiceFeatureStep) = new DriverServiceFeatureStep(_), - provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] - => MountSecretsFeatureStep) = + provideSecretsStep: (KubernetesConf => MountSecretsFeatureStep) = new MountSecretsFeatureStep(_), - provideEnvSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] - => EnvSecretsFeatureStep) = + provideEnvSecretsStep: (KubernetesConf => EnvSecretsFeatureStep) = new EnvSecretsFeatureStep(_), - provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) - => LocalDirsFeatureStep = + provideLocalDirsStep: (KubernetesConf => LocalDirsFeatureStep) = new LocalDirsFeatureStep(_), - provideVolumesStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] - => MountVolumesFeatureStep) = + provideVolumesStep: (KubernetesConf => MountVolumesFeatureStep) = new MountVolumesFeatureStep(_), - provideDriverCommandStep: ( - KubernetesConf[KubernetesDriverSpecificConf] - => DriverCommandFeatureStep) = + provideDriverCommandStep: (KubernetesDriverConf => DriverCommandFeatureStep) = new DriverCommandFeatureStep(_), - provideHadoopGlobalStep: ( - KubernetesConf[KubernetesDriverSpecificConf] - => KerberosConfDriverFeatureStep) = - new KerberosConfDriverFeatureStep(_), - providePodTemplateConfigMapStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] - => PodTemplateConfigMapStep) = - new PodTemplateConfigMapStep(_), - provideInitialPod: () => SparkPod = () => SparkPod.initialPod()) { + provideHadoopGlobalStep: (KubernetesDriverConf => KerberosConfDriverFeatureStep) = + new KerberosConfDriverFeatureStep(_), + providePodTemplateConfigMapStep: (KubernetesConf => PodTemplateConfigMapStep) = + new PodTemplateConfigMapStep(_), + provideInitialPod: () => SparkPod = () => SparkPod.initialPod) { - def buildFromFeatures( - kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]): KubernetesDriverSpec = { + def buildFromFeatures(kubernetesConf: KubernetesDriverConf): KubernetesDriverSpec = { val baseFeatures = Seq( provideBasicStep(kubernetesConf), provideCredentialsStep(kubernetesConf), provideServiceStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) - val secretFeature = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + val secretFeature = if (kubernetesConf.secretNamesToMountPaths.nonEmpty) { Seq(provideSecretsStep(kubernetesConf)) } else Nil - val envSecretFeature = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + val envSecretFeature = if (kubernetesConf.secretEnvNamesToKeyRefs.nonEmpty) { Seq(provideEnvSecretsStep(kubernetesConf)) } else Nil - val volumesFeature = if (kubernetesConf.roleVolumes.nonEmpty) { + val volumesFeature = if (kubernetesConf.volumes.nonEmpty) { Seq(provideVolumesStep(kubernetesConf)) } else Nil val podTemplateFeature = if ( @@ -81,14 +70,12 @@ private[spark] class KubernetesDriverBuilder( val driverCommandStep = provideDriverCommandStep(kubernetesConf) - val maybeHadoopConfigStep = - kubernetesConf.hadoopConfSpec.map { _ => - provideHadoopGlobalStep(kubernetesConf)} + val hadoopConfigStep = Some(provideHadoopGlobalStep(kubernetesConf)) val allFeatures: Seq[KubernetesFeatureConfigStep] = baseFeatures ++ Seq(driverCommandStep) ++ secretFeature ++ envSecretFeature ++ volumesFeature ++ - maybeHadoopConfigStep.toSeq ++ podTemplateFeature + hadoopConfigStep ++ podTemplateFeature var spec = KubernetesDriverSpec( provideInitialPod(), diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index fc41a4770bce6..d24ff0d1e6600 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -26,50 +26,38 @@ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.features._ private[spark] class KubernetesExecutorBuilder( - provideBasicStep: (KubernetesConf [KubernetesExecutorSpecificConf]) - => BasicExecutorFeatureStep = + provideBasicStep: (KubernetesExecutorConf => BasicExecutorFeatureStep) = new BasicExecutorFeatureStep(_), - provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) - => MountSecretsFeatureStep = + provideSecretsStep: (KubernetesConf => MountSecretsFeatureStep) = new MountSecretsFeatureStep(_), - provideEnvSecretsStep: - (KubernetesConf[_ <: KubernetesRoleSpecificConf] => EnvSecretsFeatureStep) = + provideEnvSecretsStep: (KubernetesConf => EnvSecretsFeatureStep) = new EnvSecretsFeatureStep(_), - provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) - => LocalDirsFeatureStep = + provideLocalDirsStep: (KubernetesConf => LocalDirsFeatureStep) = new LocalDirsFeatureStep(_), - provideVolumesStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] - => MountVolumesFeatureStep) = + provideVolumesStep: (KubernetesConf => MountVolumesFeatureStep) = new MountVolumesFeatureStep(_), - provideHadoopConfStep: ( - KubernetesConf[KubernetesExecutorSpecificConf] - => HadoopConfExecutorFeatureStep) = + provideHadoopConfStep: (KubernetesExecutorConf => HadoopConfExecutorFeatureStep) = new HadoopConfExecutorFeatureStep(_), - provideKerberosConfStep: ( - KubernetesConf[KubernetesExecutorSpecificConf] - => KerberosConfExecutorFeatureStep) = + provideKerberosConfStep: (KubernetesExecutorConf => KerberosConfExecutorFeatureStep) = new KerberosConfExecutorFeatureStep(_), - provideHadoopSparkUserStep: ( - KubernetesConf[KubernetesExecutorSpecificConf] - => HadoopSparkUserExecutorFeatureStep) = + provideHadoopSparkUserStep: (KubernetesExecutorConf => HadoopSparkUserExecutorFeatureStep) = new HadoopSparkUserExecutorFeatureStep(_), provideInitialPod: () => SparkPod = () => SparkPod.initialPod()) { - def buildFromFeatures( - kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { + def buildFromFeatures(kubernetesConf: KubernetesExecutorConf): SparkPod = { val sparkConf = kubernetesConf.sparkConf val maybeHadoopConfigMap = sparkConf.getOption(HADOOP_CONFIG_MAP_NAME) val maybeDTSecretName = sparkConf.getOption(KERBEROS_DT_SECRET_NAME) val maybeDTDataItem = sparkConf.getOption(KERBEROS_DT_SECRET_KEY) val baseFeatures = Seq(provideBasicStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) - val secretFeature = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + val secretFeature = if (kubernetesConf.secretNamesToMountPaths.nonEmpty) { Seq(provideSecretsStep(kubernetesConf)) } else Nil - val secretEnvFeature = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + val secretEnvFeature = if (kubernetesConf.secretEnvNamesToKeyRefs.nonEmpty) { Seq(provideEnvSecretsStep(kubernetesConf)) } else Nil - val volumesFeature = if (kubernetesConf.roleVolumes.nonEmpty) { + val volumesFeature = if (kubernetesConf.volumes.nonEmpty) { Seq(provideVolumesStep(kubernetesConf)) } else Nil diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala index 41ca8d186c17b..f4d40b0b3590d 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala @@ -26,10 +26,6 @@ import org.apache.spark.deploy.k8s.submit._ class KubernetesConfSuite extends SparkFunSuite { - private val APP_NAME = "test-app" - private val RESOURCE_NAME_PREFIX = "prefix" - private val APP_ID = "test-id" - private val MAIN_CLASS = "test-class" private val APP_ARGS = Array("arg1", "arg2") private val CUSTOM_LABELS = Map( "customLabel1Key" -> "customLabel1Value", @@ -49,26 +45,6 @@ class KubernetesConfSuite extends SparkFunSuite { private val DRIVER_POD = new PodBuilder().build() private val EXECUTOR_ID = "executor-id" - test("Basic driver translated fields.") { - val sparkConf = new SparkConf(false) - val conf = KubernetesConf.createDriverConf( - sparkConf, - APP_NAME, - RESOURCE_NAME_PREFIX, - APP_ID, - mainAppResource = JavaMainAppResource(None), - MAIN_CLASS, - APP_ARGS, - maybePyFiles = None, - hadoopConfDir = None) - assert(conf.appId === APP_ID) - assert(conf.sparkConf.getAll.toMap === sparkConf.getAll.toMap) - assert(conf.appResourceNamePrefix === RESOURCE_NAME_PREFIX) - assert(conf.roleSpecificConf.appName === APP_NAME) - assert(conf.roleSpecificConf.mainClass === MAIN_CLASS) - assert(conf.roleSpecificConf.appArgs === APP_ARGS) - } - test("Resolve driver labels, annotations, secret mount paths, envs, and memory overhead") { val sparkConf = new SparkConf(false) .set(MEMORY_OVERHEAD_FACTOR, 0.3) @@ -90,22 +66,19 @@ class KubernetesConfSuite extends SparkFunSuite { val conf = KubernetesConf.createDriverConf( sparkConf, - APP_NAME, - RESOURCE_NAME_PREFIX, - APP_ID, - mainAppResource = JavaMainAppResource(None), - MAIN_CLASS, + KubernetesTestConf.APP_ID, + JavaMainAppResource(None), + KubernetesTestConf.MAIN_CLASS, APP_ARGS, - maybePyFiles = None, - hadoopConfDir = None) - assert(conf.roleLabels === Map( - SPARK_APP_ID_LABEL -> APP_ID, + None) + assert(conf.labels === Map( + SPARK_APP_ID_LABEL -> KubernetesTestConf.APP_ID, SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) ++ CUSTOM_LABELS) - assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) - assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) - assert(conf.roleSecretEnvNamesToKeyRefs === SECRET_ENV_VARS) - assert(conf.roleEnvs === CUSTOM_ENVS) + assert(conf.annotations === CUSTOM_ANNOTATIONS) + assert(conf.secretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + assert(conf.secretEnvNamesToKeyRefs === SECRET_ENV_VARS) + assert(conf.environment === CUSTOM_ENVS) assert(conf.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.3) } @@ -113,20 +86,20 @@ class KubernetesConfSuite extends SparkFunSuite { val conf = KubernetesConf.createExecutorConf( new SparkConf(false), EXECUTOR_ID, - APP_ID, + KubernetesTestConf.APP_ID, Some(DRIVER_POD)) - assert(conf.roleSpecificConf.executorId === EXECUTOR_ID) - assert(conf.roleSpecificConf.driverPod.get === DRIVER_POD) + assert(conf.executorId === EXECUTOR_ID) + assert(conf.driverPod.get === DRIVER_POD) } test("Image pull secrets.") { val conf = KubernetesConf.createExecutorConf( new SparkConf(false) - .set(IMAGE_PULL_SECRETS, "my-secret-1,my-secret-2 "), + .set(IMAGE_PULL_SECRETS, Seq("my-secret-1", "my-secret-2 ")), EXECUTOR_ID, - APP_ID, + KubernetesTestConf.APP_ID, Some(DRIVER_POD)) - assert(conf.imagePullSecrets() === + assert(conf.imagePullSecrets === Seq( new LocalObjectReferenceBuilder().withName("my-secret-1").build(), new LocalObjectReferenceBuilder().withName("my-secret-2").build())) @@ -150,14 +123,14 @@ class KubernetesConfSuite extends SparkFunSuite { val conf = KubernetesConf.createExecutorConf( sparkConf, EXECUTOR_ID, - APP_ID, + KubernetesTestConf.APP_ID, Some(DRIVER_POD)) - assert(conf.roleLabels === Map( + assert(conf.labels === Map( SPARK_EXECUTOR_ID_LABEL -> EXECUTOR_ID, - SPARK_APP_ID_LABEL -> APP_ID, + SPARK_APP_ID_LABEL -> KubernetesTestConf.APP_ID, SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ CUSTOM_LABELS) - assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) - assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) - assert(conf.roleSecretEnvNamesToKeyRefs === SECRET_ENV_VARS) + assert(conf.annotations === CUSTOM_ANNOTATIONS) + assert(conf.secretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + assert(conf.secretEnvNamesToKeyRefs === SECRET_ENV_VARS) } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala new file mode 100644 index 0000000000000..1d77a6d18152a --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.Pod + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.{JavaMainAppResource, MainAppResource} + +/** + * Builder methods for KubernetesConf that allow easy control over what to return for a few + * properties. For use with tests instead of having to mock specific properties. + */ +object KubernetesTestConf { + + val APP_ID = "appId" + val MAIN_CLASS = "mainClass" + val RESOURCE_PREFIX = "prefix" + val EXECUTOR_ID = "1" + + private val DEFAULT_CONF = new SparkConf(false) + + // scalastyle:off argcount + def createDriverConf( + sparkConf: SparkConf = DEFAULT_CONF, + appId: String = APP_ID, + mainAppResource: MainAppResource = JavaMainAppResource(None), + mainClass: String = MAIN_CLASS, + appArgs: Array[String] = Array.empty, + pyFiles: Seq[String] = Nil, + resourceNamePrefix: Option[String] = None, + labels: Map[String, String] = Map.empty, + environment: Map[String, String] = Map.empty, + annotations: Map[String, String] = Map.empty, + secretEnvNamesToKeyRefs: Map[String, String] = Map.empty, + secretNamesToMountPaths: Map[String, String] = Map.empty, + volumes: Seq[KubernetesVolumeSpec] = Seq.empty): KubernetesDriverConf = { + val conf = sparkConf.clone() + + resourceNamePrefix.foreach { prefix => + conf.set(KUBERNETES_DRIVER_POD_NAME_PREFIX, prefix) + } + setPrefixedConfigs(conf, KUBERNETES_DRIVER_LABEL_PREFIX, labels) + setPrefixedConfigs(conf, KUBERNETES_DRIVER_ENV_PREFIX, environment) + setPrefixedConfigs(conf, KUBERNETES_DRIVER_ANNOTATION_PREFIX, annotations) + setPrefixedConfigs(conf, KUBERNETES_DRIVER_SECRETS_PREFIX, secretNamesToMountPaths) + setPrefixedConfigs(conf, KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX, secretEnvNamesToKeyRefs) + setVolumeSpecs(conf, KUBERNETES_DRIVER_VOLUMES_PREFIX, volumes) + + new KubernetesDriverConf(conf, appId, mainAppResource, mainClass, appArgs, pyFiles) + } + // scalastyle:on argcount + + def createExecutorConf( + sparkConf: SparkConf = DEFAULT_CONF, + driverPod: Option[Pod] = None, + labels: Map[String, String] = Map.empty, + environment: Map[String, String] = Map.empty, + annotations: Map[String, String] = Map.empty, + secretEnvNamesToKeyRefs: Map[String, String] = Map.empty, + secretNamesToMountPaths: Map[String, String] = Map.empty, + volumes: Seq[KubernetesVolumeSpec] = Seq.empty): KubernetesExecutorConf = { + val conf = sparkConf.clone() + + setPrefixedConfigs(conf, KUBERNETES_EXECUTOR_LABEL_PREFIX, labels) + setPrefixedConfigs(conf, "spark.executorEnv.", environment) + setPrefixedConfigs(conf, KUBERNETES_EXECUTOR_ANNOTATION_PREFIX, annotations) + setPrefixedConfigs(conf, KUBERNETES_EXECUTOR_SECRETS_PREFIX, secretNamesToMountPaths) + setPrefixedConfigs(conf, KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX, secretEnvNamesToKeyRefs) + setVolumeSpecs(conf, KUBERNETES_EXECUTOR_VOLUMES_PREFIX, volumes) + + new KubernetesExecutorConf(conf, APP_ID, EXECUTOR_ID, driverPod) + } + + private def setPrefixedConfigs( + conf: SparkConf, + prefix: String, + values: Map[String, String]): Unit = { + values.foreach { case (k, v) => + conf.set(s"${prefix}$k", v) + } + } + + private def setVolumeSpecs( + conf: SparkConf, + prefix: String, + volumes: Seq[KubernetesVolumeSpec]): Unit = { + def key(vtype: String, vname: String, subkey: String): String = { + s"${prefix}$vtype.$vname.$subkey" + } + + volumes.foreach { case spec => + val (vtype, configs) = spec.volumeConf match { + case KubernetesHostPathVolumeConf(path) => + (KUBERNETES_VOLUMES_HOSTPATH_TYPE, + Map(KUBERNETES_VOLUMES_OPTIONS_PATH_KEY -> path)) + + case KubernetesPVCVolumeConf(claimName) => + (KUBERNETES_VOLUMES_PVC_TYPE, + Map(KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY -> claimName)) + + case KubernetesEmptyDirVolumeConf(medium, sizeLimit) => + val mconf = medium.map { m => (KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY, m) }.toMap + val lconf = sizeLimit.map { l => (KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY, l) }.toMap + (KUBERNETES_VOLUMES_EMPTYDIR_TYPE, mconf ++ lconf) + } + + conf.set(key(vtype, spec.volumeName, KUBERNETES_VOLUMES_MOUNT_PATH_KEY), spec.mountPath) + if (spec.mountSubPath.nonEmpty) { + conf.set(key(vtype, spec.volumeName, KUBERNETES_VOLUMES_MOUNT_SUBPATH_KEY), + spec.mountSubPath) + } + conf.set(key(vtype, spec.volumeName, KUBERNETES_VOLUMES_MOUNT_READONLY_KEY), + spec.mountReadOnly.toString) + configs.foreach { case (k, v) => + conf.set(key(vtype, spec.volumeName, k), v) + } + } + } + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala index de79a58a3a756..c0790898e0976 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala @@ -25,7 +25,7 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite { sparkConf.set("test.hostPath.volumeName.mount.readOnly", "true") sparkConf.set("test.hostPath.volumeName.options.path", "/hostPath") - val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head assert(volumeSpec.volumeName === "volumeName") assert(volumeSpec.mountPath === "/path") assert(volumeSpec.mountReadOnly === true) @@ -39,7 +39,7 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite { sparkConf.set("test.emptyDir.volumeName.mount.readOnly", "true") sparkConf.set("test.emptyDir.volumeName.mount.subPath", "subPath") - val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head assert(volumeSpec.volumeName === "volumeName") assert(volumeSpec.mountPath === "/path") assert(volumeSpec.mountSubPath === "subPath") @@ -51,7 +51,7 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite { sparkConf.set("test.persistentVolumeClaim.volumeName.mount.readOnly", "true") sparkConf.set("test.persistentVolumeClaim.volumeName.options.claimName", "claimeName") - val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head assert(volumeSpec.volumeName === "volumeName") assert(volumeSpec.mountPath === "/path") assert(volumeSpec.mountReadOnly === true) @@ -66,7 +66,7 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite { sparkConf.set("test.emptyDir.volumeName.options.medium", "medium") sparkConf.set("test.emptyDir.volumeName.options.sizeLimit", "5G") - val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head assert(volumeSpec.volumeName === "volumeName") assert(volumeSpec.mountPath === "/path") assert(volumeSpec.mountReadOnly === true) @@ -79,7 +79,7 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite { sparkConf.set("test.emptyDir.volumeName.mount.path", "/path") sparkConf.set("test.emptyDir.volumeName.mount.readOnly", "true") - val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head assert(volumeSpec.volumeName === "volumeName") assert(volumeSpec.mountPath === "/path") assert(volumeSpec.mountReadOnly === true) @@ -92,27 +92,29 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite { sparkConf.set("test.hostPath.volumeName.mount.path", "/path") sparkConf.set("test.hostPath.volumeName.options.path", "/hostPath") - val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head assert(volumeSpec.mountReadOnly === false) } - test("Gracefully fails on missing mount key") { + test("Fails on missing mount key") { val sparkConf = new SparkConf(false) sparkConf.set("test.emptyDir.volumeName.mnt.path", "/path") - val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head - assert(volumeSpec.isFailure === true) - assert(volumeSpec.failed.get.getMessage === "emptyDir.volumeName.mount.path") + val e = intercept[NoSuchElementException] { + KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.") + } + assert(e.getMessage.contains("emptyDir.volumeName.mount.path")) } - test("Gracefully fails on missing option key") { + test("Fails on missing option key") { val sparkConf = new SparkConf(false) sparkConf.set("test.hostPath.volumeName.mount.path", "/path") sparkConf.set("test.hostPath.volumeName.mount.readOnly", "true") sparkConf.set("test.hostPath.volumeName.options.pth", "/hostPath") - val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head - assert(volumeSpec.isFailure === true) - assert(volumeSpec.failed.get.getMessage === "hostPath.volumeName.options.path") + val e = intercept[NoSuchElementException] { + KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.") + } + assert(e.getMessage.contains("hostPath.volumeName.options.path")) } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala index 1e7dfbeffdb24..e4951bc1e69ed 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model.{ContainerPort, ContainerPortBuilder, LocalObjectReferenceBuilder} import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.{KubernetesTestConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit._ @@ -30,32 +30,17 @@ import org.apache.spark.ui.SparkUI class BasicDriverFeatureStepSuite extends SparkFunSuite { - private val APP_ID = "spark-app-id" - private val RESOURCE_NAME_PREFIX = "spark" private val DRIVER_LABELS = Map("labelkey" -> "labelvalue") private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" - private val APP_NAME = "spark-test" - private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" - private val PY_MAIN_CLASS = "org.apache.spark.deploy.PythonRunner" - private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") - private val CUSTOM_ANNOTATION_KEY = "customAnnotation" - private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" - private val DRIVER_ANNOTATIONS = Map(CUSTOM_ANNOTATION_KEY -> CUSTOM_ANNOTATION_VALUE) - private val DRIVER_CUSTOM_ENV1 = "customDriverEnv1" - private val DRIVER_CUSTOM_ENV2 = "customDriverEnv2" + private val DRIVER_ANNOTATIONS = Map("customAnnotation" -> "customAnnotationValue") private val DRIVER_ENVS = Map( - DRIVER_CUSTOM_ENV1 -> DRIVER_CUSTOM_ENV1, - DRIVER_CUSTOM_ENV2 -> DRIVER_CUSTOM_ENV2) + "customDriverEnv1" -> "customDriverEnv2", + "customDriverEnv2" -> "customDriverEnv2") private val TEST_IMAGE_PULL_SECRETS = Seq("my-secret-1", "my-secret-2") private val TEST_IMAGE_PULL_SECRET_OBJECTS = TEST_IMAGE_PULL_SECRETS.map { secret => new LocalObjectReferenceBuilder().withName(secret).build() } - private val emptyDriverSpecificConf = KubernetesDriverSpecificConf( - JavaMainAppResource(None), - APP_NAME, - MAIN_CLASS, - APP_ARGS) test("Check the pod respects all configurations from the user.") { val sparkConf = new SparkConf() @@ -65,19 +50,12 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { .set(DRIVER_MEMORY.key, "256M") .set(DRIVER_MEMORY_OVERHEAD, 200L) .set(CONTAINER_IMAGE, "spark-driver:latest") - .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) - val kubernetesConf = KubernetesConf( - sparkConf, - emptyDriverSpecificConf, - RESOURCE_NAME_PREFIX, - APP_ID, - DRIVER_LABELS, - DRIVER_ANNOTATIONS, - Map.empty, - Map.empty, - DRIVER_ENVS, - Nil, - hadoopConfSpec = None) + .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS) + val kubernetesConf = KubernetesTestConf.createDriverConf( + sparkConf = sparkConf, + labels = DRIVER_LABELS, + environment = DRIVER_ENVS, + annotations = DRIVER_ANNOTATIONS) val featureStep = new BasicDriverFeatureStep(kubernetesConf) val basePod = SparkPod.initialPod() @@ -99,10 +77,11 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { val envs = configuredPod.container .getEnv .asScala - .map(env => (env.getName, env.getValue)) + .map { env => (env.getName, env.getValue) } .toMap - assert(envs(DRIVER_CUSTOM_ENV1) === DRIVER_ENVS(DRIVER_CUSTOM_ENV1)) - assert(envs(DRIVER_CUSTOM_ENV2) === DRIVER_ENVS(DRIVER_CUSTOM_ENV2)) + DRIVER_ENVS.foreach { case (k, v) => + assert(envs(v) === v) + } assert(configuredPod.pod.getSpec().getImagePullSecrets.asScala === TEST_IMAGE_PULL_SECRET_OBJECTS) @@ -122,13 +101,15 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { val driverPodMetadata = configuredPod.pod.getMetadata assert(driverPodMetadata.getName === "spark-driver-pod") - assert(driverPodMetadata.getLabels.asScala === DRIVER_LABELS) + DRIVER_LABELS.foreach { case (k, v) => + assert(driverPodMetadata.getLabels.get(k) === v) + } assert(driverPodMetadata.getAnnotations.asScala === DRIVER_ANNOTATIONS) assert(configuredPod.pod.getSpec.getRestartPolicy === "Never") val expectedSparkConf = Map( KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", - "spark.app.id" -> APP_ID, - KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, + "spark.app.id" -> KubernetesTestConf.APP_ID, + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> kubernetesConf.resourceNamePrefix, "spark.kubernetes.submitInDriver" -> "true", MEMORY_OVERHEAD_FACTOR.key -> MEMORY_OVERHEAD_FACTOR.defaultValue.get.toString) assert(featureStep.getAdditionalPodSystemProperties() === expectedSparkConf) @@ -141,39 +122,10 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { val pythonSparkConf = new SparkConf() .set(DRIVER_MEMORY.key, "4g") .set(CONTAINER_IMAGE, "spark-driver-py:latest") - val javaKubernetesConf = KubernetesConf( - javaSparkConf, - KubernetesDriverSpecificConf( - JavaMainAppResource(None), - APP_NAME, - PY_MAIN_CLASS, - APP_ARGS), - RESOURCE_NAME_PREFIX, - APP_ID, - DRIVER_LABELS, - DRIVER_ANNOTATIONS, - Map.empty, - Map.empty, - DRIVER_ENVS, - Nil, - hadoopConfSpec = None) - - val pythonKubernetesConf = KubernetesConf( - pythonSparkConf, - KubernetesDriverSpecificConf( - PythonMainAppResource(""), - APP_NAME, - PY_MAIN_CLASS, - APP_ARGS), - RESOURCE_NAME_PREFIX, - APP_ID, - DRIVER_LABELS, - DRIVER_ANNOTATIONS, - Map.empty, - Map.empty, - DRIVER_ENVS, - Nil, - hadoopConfSpec = None) + val javaKubernetesConf = KubernetesTestConf.createDriverConf(sparkConf = javaSparkConf) + val pythonKubernetesConf = KubernetesTestConf.createDriverConf( + sparkConf = pythonSparkConf, + mainAppResource = PythonMainAppResource("")) val javaFeatureStep = new BasicDriverFeatureStep(javaKubernetesConf) val pythonFeatureStep = new BasicDriverFeatureStep(pythonKubernetesConf) val basePod = SparkPod.initialPod() @@ -191,25 +143,14 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { .setJars(allJars) .set("spark.files", allFiles.mkString(",")) .set(CONTAINER_IMAGE, "spark-driver:latest") - val kubernetesConf = KubernetesConf( - sparkConf, - emptyDriverSpecificConf, - RESOURCE_NAME_PREFIX, - APP_ID, - DRIVER_LABELS, - DRIVER_ANNOTATIONS, - Map.empty, - Map.empty, - DRIVER_ENVS, - Nil, - hadoopConfSpec = None) + val kubernetesConf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf) val step = new BasicDriverFeatureStep(kubernetesConf) val additionalProperties = step.getAdditionalPodSystemProperties() val expectedSparkConf = Map( KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", - "spark.app.id" -> APP_ID, - KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, + "spark.app.id" -> KubernetesTestConf.APP_ID, + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> kubernetesConf.resourceNamePrefix, "spark.kubernetes.submitInDriver" -> "true", "spark.jars" -> "/opt/spark/jar1.jar,hdfs:///opt/spark/jar2.jar", "spark.files" -> "https://localhost:9000/file1.txt,/opt/spark/file2.txt", @@ -234,19 +175,9 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { .set(CONTAINER_IMAGE, "spark-driver:latest") .set(DRIVER_MEMORY.key, s"${driverMem.toInt}m") factor.foreach { value => sparkConf.set(MEMORY_OVERHEAD_FACTOR, value) } - val driverConf = emptyDriverSpecificConf.copy(mainAppResource = resource) - val conf = KubernetesConf( - sparkConf, - driverConf, - RESOURCE_NAME_PREFIX, - APP_ID, - DRIVER_LABELS, - DRIVER_ANNOTATIONS, - Map.empty, - Map.empty, - DRIVER_ENVS, - Nil, - hadoopConfSpec = None) + val conf = KubernetesTestConf.createDriverConf( + sparkConf = sparkConf, + mainAppResource = resource) val step = new BasicDriverFeatureStep(conf) val pod = step.configurePod(SparkPod.initialPod()) val mem = pod.container.getResources.getRequests.get("memory").getAmount() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index e9a16aab6ccc2..d6003c977937c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -19,20 +19,18 @@ package org.apache.spark.deploy.k8s.features import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model._ -import org.mockito.MockitoAnnotations -import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} +import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, KubernetesTestConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.config._ import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -class BasicExecutorFeatureStepSuite - extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { +class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { - private val APP_ID = "app-id" private val DRIVER_HOSTNAME = "localhost" private val DRIVER_PORT = 7098 private val DRIVER_ADDRESS = RpcEndpointAddress( @@ -45,7 +43,6 @@ class BasicExecutorFeatureStepSuite private val RESOURCE_NAME_PREFIX = "base" private val EXECUTOR_IMAGE = "executor-image" private val LABELS = Map("label1key" -> "label1value") - private val ANNOTATIONS = Map("annotation1key" -> "annotation1value") private val TEST_IMAGE_PULL_SECRETS = Seq("my-1secret-1", "my-secret-2") private val TEST_IMAGE_PULL_SECRET_OBJECTS = TEST_IMAGE_PULL_SECRETS.map { secret => @@ -66,37 +63,35 @@ class BasicExecutorFeatureStepSuite private var baseConf: SparkConf = _ before { - MockitoAnnotations.initMocks(this) baseConf = new SparkConf() .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME) .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, RESOURCE_NAME_PREFIX) .set(CONTAINER_IMAGE, EXECUTOR_IMAGE) .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) - .set("spark.driver.host", DRIVER_HOSTNAME) + .set(DRIVER_HOST_ADDRESS, DRIVER_HOSTNAME) .set("spark.driver.port", DRIVER_PORT.toString) - .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) + .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS) .set("spark.kubernetes.resource.type", "java") } + private def newExecutorConf( + environment: Map[String, String] = Map.empty): KubernetesExecutorConf = { + KubernetesTestConf.createExecutorConf( + sparkConf = baseConf, + driverPod = Some(DRIVER_POD), + labels = LABELS, + environment = environment) + } + test("basic executor pod has reasonable defaults") { - val step = new BasicExecutorFeatureStep( - KubernetesConf( - baseConf, - KubernetesExecutorSpecificConf("1", Some(DRIVER_POD)), - RESOURCE_NAME_PREFIX, - APP_ID, - LABELS, - ANNOTATIONS, - Map.empty, - Map.empty, - Map.empty, - Nil, - hadoopConfSpec = None)) + val step = new BasicExecutorFeatureStep(newExecutorConf()) val executor = step.configurePod(SparkPod.initialPod()) // The executor pod name and default labels. assert(executor.pod.getMetadata.getName === s"$RESOURCE_NAME_PREFIX-exec-1") - assert(executor.pod.getMetadata.getLabels.asScala === LABELS) + LABELS.foreach { case (k, v) => + assert(executor.pod.getMetadata.getLabels.get(k) === v) + } assert(executor.pod.getSpec.getImagePullSecrets.asScala === TEST_IMAGE_PULL_SECRET_OBJECTS) // There is exactly 1 container with no volume mounts and default memory limits. @@ -116,43 +111,18 @@ class BasicExecutorFeatureStepSuite } test("executor pod hostnames get truncated to 63 characters") { - val conf = baseConf.clone() val longPodNamePrefix = "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple" - val step = new BasicExecutorFeatureStep( - KubernetesConf( - conf, - KubernetesExecutorSpecificConf("1", Some(DRIVER_POD)), - longPodNamePrefix, - APP_ID, - LABELS, - ANNOTATIONS, - Map.empty, - Map.empty, - Map.empty, - Nil, - hadoopConfSpec = None)) + baseConf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, longPodNamePrefix) + val step = new BasicExecutorFeatureStep(newExecutorConf()) assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63) } test("classpath and extra java options get translated into environment variables") { - val conf = baseConf.clone() - conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") - conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") - - val step = new BasicExecutorFeatureStep( - KubernetesConf( - conf, - KubernetesExecutorSpecificConf("1", Some(DRIVER_POD)), - RESOURCE_NAME_PREFIX, - APP_ID, - LABELS, - ANNOTATIONS, - Map.empty, - Map.empty, - Map("qux" -> "quux"), - Nil, - hadoopConfSpec = None)) + baseConf.set(EXECUTOR_JAVA_OPTIONS, "foo=bar") + baseConf.set(EXECUTOR_CLASS_PATH, "bar=baz") + val kconf = newExecutorConf(environment = Map("qux" -> "quux")) + val step = new BasicExecutorFeatureStep(kconf) val executor = step.configurePod(SparkPod.initialPod()) checkEnv(executor, @@ -163,23 +133,10 @@ class BasicExecutorFeatureStepSuite } test("test executor pyspark memory") { - val conf = baseConf.clone() - conf.set("spark.kubernetes.resource.type", "python") - conf.set(org.apache.spark.internal.config.PYSPARK_EXECUTOR_MEMORY, 42L) - - val step = new BasicExecutorFeatureStep( - KubernetesConf( - conf, - KubernetesExecutorSpecificConf("1", Some(DRIVER_POD)), - RESOURCE_NAME_PREFIX, - APP_ID, - LABELS, - ANNOTATIONS, - Map.empty, - Map.empty, - Map.empty, - Nil, - hadoopConfSpec = None)) + baseConf.set("spark.kubernetes.resource.type", "python") + baseConf.set(PYSPARK_EXECUTOR_MEMORY, 42L) + + val step = new BasicExecutorFeatureStep(newExecutorConf()) val executor = step.configurePod(SparkPod.initialPod()) // This is checking that basic executor + executorMemory = 1408 + 42 = 1450 assert(executor.container.getResources.getRequests.get("memory").getAmount === "1450Mi") @@ -199,7 +156,7 @@ class BasicExecutorFeatureStepSuite ENV_DRIVER_URL -> DRIVER_ADDRESS.toString, ENV_EXECUTOR_CORES -> "1", ENV_EXECUTOR_MEMORY -> "1g", - ENV_APPLICATION_ID -> APP_ID, + ENV_APPLICATION_ID -> KubernetesTestConf.APP_ID, ENV_SPARK_CONF_DIR -> SPARK_CONF_DIR_INTERNAL, ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStepSuite.scala index 30672952aaf6f..f74ac928028c7 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStepSuite.scala @@ -27,8 +27,6 @@ import org.apache.spark.util.Utils class DriverCommandFeatureStepSuite extends SparkFunSuite { - private val MAIN_CLASS = "mainClass" - test("java resource") { val mainResource = "local:///main.jar" val spec = applyFeatureStep( @@ -37,7 +35,7 @@ class DriverCommandFeatureStepSuite extends SparkFunSuite { assert(spec.pod.container.getArgs.asScala === List( "driver", "--properties-file", SPARK_CONF_PATH, - "--class", MAIN_CLASS, + "--class", KubernetesTestConf.MAIN_CLASS, "spark-internal", "5", "7")) val jars = Utils.stringToSeq(spec.systemProperties("spark.jars")) @@ -55,7 +53,7 @@ class DriverCommandFeatureStepSuite extends SparkFunSuite { assert(spec.pod.container.getArgs.asScala === List( "driver", "--properties-file", SPARK_CONF_PATH, - "--class", MAIN_CLASS, + "--class", KubernetesTestConf.MAIN_CLASS, "/main.py")) val envs = spec.pod.container.getEnv.asScala .map { env => (env.getName, env.getValue) } @@ -86,7 +84,7 @@ class DriverCommandFeatureStepSuite extends SparkFunSuite { assert(spec.pod.container.getArgs.asScala === List( "driver", "--properties-file", SPARK_CONF_PATH, - "--class", MAIN_CLASS, + "--class", KubernetesTestConf.MAIN_CLASS, "/main.py", "5", "7", "9")) val envs = spec.pod.container.getEnv.asScala @@ -112,7 +110,7 @@ class DriverCommandFeatureStepSuite extends SparkFunSuite { assert(spec.pod.container.getArgs.asScala === List( "driver", "--properties-file", SPARK_CONF_PATH, - "--class", MAIN_CLASS, + "--class", KubernetesTestConf.MAIN_CLASS, "/main.R", "5", "7", "9")) } @@ -121,20 +119,11 @@ class DriverCommandFeatureStepSuite extends SparkFunSuite { conf: SparkConf = new SparkConf(false), appArgs: Array[String] = Array(), pyFiles: Seq[String] = Nil): KubernetesDriverSpec = { - val driverConf = new KubernetesDriverSpecificConf( - resource, MAIN_CLASS, "appName", appArgs, pyFiles = pyFiles) - val kubernetesConf = KubernetesConf( - conf, - driverConf, - "resource-prefix", - "appId", - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - hadoopConfSpec = None) + val kubernetesConf = KubernetesTestConf.createDriverConf( + sparkConf = conf, + mainAppResource = resource, + appArgs = appArgs, + pyFiles = pyFiles) val step = new DriverCommandFeatureStep(kubernetesConf) val pod = step.configurePod(SparkPod.initialPod()) val props = step.getAdditionalPodSystemProperties() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala index 36c6616a87b0a..7d8e9296a6cb5 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala @@ -18,51 +18,25 @@ package org.apache.spark.deploy.k8s.features import java.io.File +import scala.collection.JavaConverters._ + import com.google.common.base.Charsets import com.google.common.io.{BaseEncoding, Files} -import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, Secret} -import org.mockito.{Mock, MockitoAnnotations} -import org.scalatest.BeforeAndAfter -import scala.collection.JavaConverters._ +import io.fabric8.kubernetes.api.model.Secret import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.{KubernetesTestConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.util.Utils -class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { +class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite { - private val KUBERNETES_RESOURCE_NAME_PREFIX = "spark" - private val APP_ID = "k8s-app" - private var credentialsTempDirectory: File = _ + private val credentialsTempDirectory = Utils.createTempDir() private val BASE_DRIVER_POD = SparkPod.initialPod() - @Mock - private var driverSpecificConf: KubernetesDriverSpecificConf = _ - - before { - MockitoAnnotations.initMocks(this) - credentialsTempDirectory = Utils.createTempDir() - } - - after { - credentialsTempDirectory.delete() - } - test("Don't set any credentials") { - val kubernetesConf = KubernetesConf( - new SparkConf(false), - driverSpecificConf, - KUBERNETES_RESOURCE_NAME_PREFIX, - APP_ID, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - hadoopConfSpec = None) + val kubernetesConf = KubernetesTestConf.createDriverConf() val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) assert(kubernetesCredentialsStep.getAdditionalPodSystemProperties().isEmpty) @@ -83,19 +57,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef .set( s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", "/mnt/secrets/my-ca.pem") - val kubernetesConf = KubernetesConf( - submissionSparkConf, - driverSpecificConf, - KUBERNETES_RESOURCE_NAME_PREFIX, - APP_ID, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - hadoopConfSpec = None) - + val kubernetesConf = KubernetesTestConf.createDriverConf(sparkConf = submissionSparkConf) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) assert(kubernetesCredentialsStep.getAdditionalKubernetesResources().isEmpty) @@ -122,18 +84,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef .set( s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", caCertFile.getAbsolutePath) - val kubernetesConf = KubernetesConf( - submissionSparkConf, - driverSpecificConf, - KUBERNETES_RESOURCE_NAME_PREFIX, - APP_ID, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - hadoopConfSpec = None) + val kubernetesConf = KubernetesTestConf.createDriverConf(sparkConf = submissionSparkConf) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() val expectedSparkConf = Map( @@ -153,7 +104,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef .head .asInstanceOf[Secret] assert(credentialsSecret.getMetadata.getName === - s"$KUBERNETES_RESOURCE_NAME_PREFIX-kubernetes-credentials") + s"${kubernetesConf.resourceNamePrefix}-kubernetes-credentials") val decodedSecretData = credentialsSecret.getData.asScala.map { data => (data._1, new String(BaseEncoding.base64().decode(data._2), Charsets.UTF_8)) } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala index 3c46667c3042e..045278939dfff 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala @@ -16,24 +16,19 @@ */ package org.apache.spark.deploy.k8s.features -import io.fabric8.kubernetes.api.model.Service -import org.mockito.{Mock, MockitoAnnotations} -import org.mockito.Mockito.when -import org.scalatest.BeforeAndAfter import scala.collection.JavaConverters._ +import io.fabric8.kubernetes.api.model.Service + import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.{KubernetesTestConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit.JavaMainAppResource -import org.apache.spark.util.Clock - -class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { +import org.apache.spark.internal.config._ +import org.apache.spark.util.ManualClock - private val SHORT_RESOURCE_NAME_PREFIX = - "a" * (DriverServiceFeatureStep.MAX_SERVICE_NAME_LENGTH - - DriverServiceFeatureStep.DRIVER_SVC_POSTFIX.length) +class DriverServiceFeatureStepSuite extends SparkFunSuite { private val LONG_RESOURCE_NAME_PREFIX = "a" * (DriverServiceFeatureStep.MAX_SERVICE_NAME_LENGTH - @@ -42,34 +37,14 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { "label1key" -> "label1value", "label2key" -> "label2value") - @Mock - private var clock: Clock = _ - - private var sparkConf: SparkConf = _ - - before { - MockitoAnnotations.initMocks(this) - sparkConf = new SparkConf(false) - } - test("Headless service has a port for the driver RPC and the block manager.") { - sparkConf = sparkConf + val sparkConf = new SparkConf(false) .set("spark.driver.port", "9000") - .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080) - val configurationStep = new DriverServiceFeatureStep( - KubernetesConf( - sparkConf, - KubernetesDriverSpecificConf( - JavaMainAppResource(None), "main", "app", Seq.empty), - SHORT_RESOURCE_NAME_PREFIX, - "app-id", - DRIVER_LABELS, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - hadoopConfSpec = None)) + .set(DRIVER_BLOCK_MANAGER_PORT, 8080) + val kconf = KubernetesTestConf.createDriverConf( + sparkConf = sparkConf, + labels = DRIVER_LABELS) + val configurationStep = new DriverServiceFeatureStep(kconf) assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod()) assert(configurationStep.getAdditionalKubernetesResources().size === 1) assert(configurationStep.getAdditionalKubernetesResources().head.isInstanceOf[Service]) @@ -80,50 +55,28 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { verifyService( 9000, 8080, - s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}", + s"${kconf.resourceNamePrefix}${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}", driverService) } test("Hostname and ports are set according to the service name.") { - val configurationStep = new DriverServiceFeatureStep( - KubernetesConf( - sparkConf - .set("spark.driver.port", "9000") - .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080) - .set(KUBERNETES_NAMESPACE, "my-namespace"), - KubernetesDriverSpecificConf( - JavaMainAppResource(None), "main", "app", Seq.empty), - SHORT_RESOURCE_NAME_PREFIX, - "app-id", - DRIVER_LABELS, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - hadoopConfSpec = None)) - val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + - DriverServiceFeatureStep.DRIVER_SVC_POSTFIX + val sparkConf = new SparkConf(false) + .set("spark.driver.port", "9000") + .set(DRIVER_BLOCK_MANAGER_PORT, 8080) + .set(KUBERNETES_NAMESPACE, "my-namespace") + val kconf = KubernetesTestConf.createDriverConf( + sparkConf = sparkConf, + labels = DRIVER_LABELS) + val configurationStep = new DriverServiceFeatureStep(kconf) + val expectedServiceName = kconf.resourceNamePrefix + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX val expectedHostName = s"$expectedServiceName.my-namespace.svc" val additionalProps = configurationStep.getAdditionalPodSystemProperties() verifySparkConfHostNames(additionalProps, expectedHostName) } test("Ports should resolve to defaults in SparkConf and in the service.") { - val configurationStep = new DriverServiceFeatureStep( - KubernetesConf( - sparkConf, - KubernetesDriverSpecificConf( - JavaMainAppResource(None), "main", "app", Seq.empty), - SHORT_RESOURCE_NAME_PREFIX, - "app-id", - DRIVER_LABELS, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - hadoopConfSpec = None)) + val kconf = KubernetesTestConf.createDriverConf(labels = DRIVER_LABELS) + val configurationStep = new DriverServiceFeatureStep(kconf) val resolvedService = configurationStep .getAdditionalKubernetesResources() .head @@ -131,30 +84,23 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { verifyService( DEFAULT_DRIVER_PORT, DEFAULT_BLOCKMANAGER_PORT, - s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}", + s"${kconf.resourceNamePrefix}${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}", resolvedService) val additionalProps = configurationStep.getAdditionalPodSystemProperties() assert(additionalProps("spark.driver.port") === DEFAULT_DRIVER_PORT.toString) - assert(additionalProps(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key) - === DEFAULT_BLOCKMANAGER_PORT.toString) + assert(additionalProps(DRIVER_BLOCK_MANAGER_PORT.key) === DEFAULT_BLOCKMANAGER_PORT.toString) } test("Long prefixes should switch to using a generated name.") { - when(clock.getTimeMillis()).thenReturn(10000) + val clock = new ManualClock() + clock.setTime(10000) + val sparkConf = new SparkConf(false) + .set(KUBERNETES_NAMESPACE, "my-namespace") val configurationStep = new DriverServiceFeatureStep( - KubernetesConf( - sparkConf.set(KUBERNETES_NAMESPACE, "my-namespace"), - KubernetesDriverSpecificConf( - JavaMainAppResource(None), "main", "app", Seq.empty), - LONG_RESOURCE_NAME_PREFIX, - "app-id", - DRIVER_LABELS, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - hadoopConfSpec = None), + KubernetesTestConf.createDriverConf( + sparkConf = sparkConf, + resourceNamePrefix = Some(LONG_RESOURCE_NAME_PREFIX), + labels = DRIVER_LABELS), clock) val driverService = configurationStep .getAdditionalKubernetesResources() @@ -168,56 +114,27 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { } test("Disallow bind address and driver host to be set explicitly.") { - try { - new DriverServiceFeatureStep( - KubernetesConf( - sparkConf.set(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS, "host"), - KubernetesDriverSpecificConf( - JavaMainAppResource(None), "main", "app", Seq.empty), - LONG_RESOURCE_NAME_PREFIX, - "app-id", - DRIVER_LABELS, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - hadoopConfSpec = None), - clock) - fail("The driver bind address should not be allowed.") - } catch { - case e: Throwable => - assert(e.getMessage === - s"requirement failed: ${DriverServiceFeatureStep.DRIVER_BIND_ADDRESS_KEY} is" + - " not supported in Kubernetes mode, as the driver's bind address is managed" + - " and set to the driver pod's IP address.") + val sparkConf = new SparkConf(false) + .set(DRIVER_BIND_ADDRESS, "host") + .set("spark.app.name", LONG_RESOURCE_NAME_PREFIX) + val e1 = intercept[IllegalArgumentException] { + new DriverServiceFeatureStep(KubernetesTestConf.createDriverConf(sparkConf = sparkConf)) } - sparkConf.remove(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS) - sparkConf.set(org.apache.spark.internal.config.DRIVER_HOST_ADDRESS, "host") - try { - new DriverServiceFeatureStep( - KubernetesConf( - sparkConf, - KubernetesDriverSpecificConf( - JavaMainAppResource(None), "main", "app", Seq.empty), - LONG_RESOURCE_NAME_PREFIX, - "app-id", - DRIVER_LABELS, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - hadoopConfSpec = None), - clock) - fail("The driver host address should not be allowed.") - } catch { - case e: Throwable => - assert(e.getMessage === - s"requirement failed: ${DriverServiceFeatureStep.DRIVER_HOST_KEY} is" + - " not supported in Kubernetes mode, as the driver's hostname will be managed via" + - " a Kubernetes service.") + assert(e1.getMessage === + s"requirement failed: ${DriverServiceFeatureStep.DRIVER_BIND_ADDRESS_KEY} is" + + " not supported in Kubernetes mode, as the driver's bind address is managed" + + " and set to the driver pod's IP address.") + + sparkConf.remove(DRIVER_BIND_ADDRESS) + sparkConf.set(DRIVER_HOST_ADDRESS, "host") + + val e2 = intercept[IllegalArgumentException] { + new DriverServiceFeatureStep(KubernetesTestConf.createDriverConf(sparkConf = sparkConf)) } + assert(e2.getMessage === + s"requirement failed: ${DriverServiceFeatureStep.DRIVER_HOST_KEY} is" + + " not supported in Kubernetes mode, as the driver's hostname will be managed via" + + " a Kubernetes service.") } private def verifyService( @@ -227,7 +144,9 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { service: Service): Unit = { assert(service.getMetadata.getName === expectedServiceName) assert(service.getSpec.getClusterIP === "None") - assert(service.getSpec.getSelector.asScala === DRIVER_LABELS) + DRIVER_LABELS.foreach { case (k, v) => + assert(service.getSpec.getSelector.get(k) === v) + } assert(service.getSpec.getPorts.size() === 2) val driverServicePorts = service.getSpec.getPorts.asScala assert(driverServicePorts.head.getName === DRIVER_PORT_NAME) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala index 3d253079c3ce7..0455526111067 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala @@ -16,12 +16,12 @@ */ package org.apache.spark.deploy.k8s.features -import io.fabric8.kubernetes.api.model.PodBuilder +import scala.collection.JavaConverters._ -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.SparkFunSuite import org.apache.spark.deploy.k8s._ -class EnvSecretsFeatureStepSuite extends SparkFunSuite{ +class EnvSecretsFeatureStepSuite extends SparkFunSuite { private val KEY_REF_NAME_FOO = "foo" private val KEY_REF_NAME_BAR = "bar" private val KEY_REF_KEY_FOO = "key_foo" @@ -34,28 +34,14 @@ class EnvSecretsFeatureStepSuite extends SparkFunSuite{ val envVarsToKeys = Map( ENV_NAME_BAR -> s"${KEY_REF_NAME_BAR}:${KEY_REF_KEY_BAR}", ENV_NAME_FOO -> s"${KEY_REF_NAME_FOO}:${KEY_REF_KEY_FOO}") - val sparkConf = new SparkConf(false) - val kubernetesConf = KubernetesConf( - sparkConf, - KubernetesExecutorSpecificConf("1", Some(new PodBuilder().build())), - "resource-name-prefix", - "app-id", - Map.empty, - Map.empty, - Map.empty, - envVarsToKeys, - Map.empty, - Nil, - hadoopConfSpec = None) + val kubernetesConf = KubernetesTestConf.createDriverConf( + secretEnvNamesToKeyRefs = envVarsToKeys) val step = new EnvSecretsFeatureStep(kubernetesConf) - val driverContainerWithEnvSecrets = step.configurePod(baseDriverPod).container - - val expectedVars = - Seq(s"${ENV_NAME_BAR}", s"${ENV_NAME_FOO}") - - expectedVars.foreach { envName => - assert(KubernetesFeaturesTestUtils.containerHasEnvVar(driverContainerWithEnvSecrets, envName)) + val container = step.configurePod(baseDriverPod).container + val containerEnvKeys = container.getEnv.asScala.map { v => v.getName }.toSet + envVarsToKeys.keys.foreach { envName => + assert(containerEnvKeys.contains(envName)) } } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala index 894d824999aac..8f34ce5c6b94f 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala @@ -17,45 +17,19 @@ package org.apache.spark.deploy.k8s.features import io.fabric8.kubernetes.api.model.{EnvVarBuilder, VolumeBuilder, VolumeMountBuilder} -import org.mockito.Mockito -import org.scalatest._ -import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.{KubernetesTestConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.submit.JavaMainAppResource +import org.apache.spark.util.SparkConfWithEnv -class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { +class LocalDirsFeatureStepSuite extends SparkFunSuite { private val defaultLocalDir = "/var/data/default-local-dir" - private var sparkConf: SparkConf = _ - private var kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf] = _ - - before { - val realSparkConf = new SparkConf(false) - sparkConf = Mockito.spy(realSparkConf) - kubernetesConf = KubernetesConf( - sparkConf, - KubernetesDriverSpecificConf( - JavaMainAppResource(None), - "app-name", - "main", - Seq.empty), - "resource", - "app-id", - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - hadoopConfSpec = None) - } test("Resolve to default local dir if neither env nor configuration are set") { - Mockito.doReturn(null).when(sparkConf).get("spark.local.dir") - Mockito.doReturn(null).when(sparkConf).getenv("SPARK_LOCAL_DIRS") - val stepUnderTest = new LocalDirsFeatureStep(kubernetesConf, defaultLocalDir) + val stepUnderTest = new LocalDirsFeatureStep(KubernetesTestConf.createDriverConf(), + defaultLocalDir) val configuredPod = stepUnderTest.configurePod(SparkPod.initialPod()) assert(configuredPod.pod.getSpec.getVolumes.size === 1) assert(configuredPod.pod.getSpec.getVolumes.get(0) === @@ -79,8 +53,9 @@ class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { } test("Use configured local dirs split on comma if provided.") { - Mockito.doReturn("/var/data/my-local-dir-1,/var/data/my-local-dir-2") - .when(sparkConf).getenv("SPARK_LOCAL_DIRS") + val sparkConf = new SparkConfWithEnv(Map( + "SPARK_LOCAL_DIRS" -> "/var/data/my-local-dir-1,/var/data/my-local-dir-2")) + val kubernetesConf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf) val stepUnderTest = new LocalDirsFeatureStep(kubernetesConf, defaultLocalDir) val configuredPod = stepUnderTest.configurePod(SparkPod.initialPod()) assert(configuredPod.pod.getSpec.getVolumes.size === 2) @@ -116,9 +91,8 @@ class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { } test("Use tmpfs to back default local dir") { - Mockito.doReturn(null).when(sparkConf).get("spark.local.dir") - Mockito.doReturn(null).when(sparkConf).getenv("SPARK_LOCAL_DIRS") - Mockito.doReturn(true).when(sparkConf).get(KUBERNETES_LOCAL_DIRS_TMPFS) + val sparkConf = new SparkConf(false).set(KUBERNETES_LOCAL_DIRS_TMPFS, true) + val kubernetesConf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf) val stepUnderTest = new LocalDirsFeatureStep(kubernetesConf, defaultLocalDir) val configuredPod = stepUnderTest.configurePod(SparkPod.initialPod()) assert(configuredPod.pod.getSpec.getVolumes.size === 1) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala index 1555f6a9c6527..22f6d26c4d0d3 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala @@ -16,10 +16,8 @@ */ package org.apache.spark.deploy.k8s.features -import io.fabric8.kubernetes.api.model.PodBuilder - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SecretVolumeUtils, SparkPod} +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.{KubernetesTestConf, SecretVolumeUtils, SparkPod} class MountSecretsFeatureStepSuite extends SparkFunSuite { @@ -32,19 +30,8 @@ class MountSecretsFeatureStepSuite extends SparkFunSuite { val secretNamesToMountPaths = Map( SECRET_FOO -> SECRET_MOUNT_PATH, SECRET_BAR -> SECRET_MOUNT_PATH) - val sparkConf = new SparkConf(false) - val kubernetesConf = KubernetesConf( - sparkConf, - KubernetesExecutorSpecificConf("1", Some(new PodBuilder().build())), - "resource-name-prefix", - "app-id", - Map.empty, - Map.empty, - secretNamesToMountPaths, - Map.empty, - Map.empty, - Nil, - hadoopConfSpec = None) + val kubernetesConf = KubernetesTestConf.createExecutorConf( + secretNamesToMountPaths = secretNamesToMountPaths) val step = new MountSecretsFeatureStep(kubernetesConf) val driverPodWithSecretsMounted = step.configurePod(baseDriverPod).pod diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala index aadbf16897f46..e6f1dd640e3ea 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala @@ -16,29 +16,12 @@ */ package org.apache.spark.deploy.k8s.features +import scala.collection.JavaConverters._ + import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s._ -import org.apache.spark.deploy.k8s.submit.JavaMainAppResource class MountVolumesFeatureStepSuite extends SparkFunSuite { - private val sparkConf = new SparkConf(false) - private val emptyKubernetesConf = KubernetesConf( - sparkConf = sparkConf, - roleSpecificConf = KubernetesDriverSpecificConf( - JavaMainAppResource(None), - "app-name", - "main", - Seq.empty), - appResourceNamePrefix = "resource", - appId = "app-id", - roleLabels = Map.empty, - roleAnnotations = Map.empty, - roleSecretNamesToMountPaths = Map.empty, - roleSecretEnvNamesToKeyRefs = Map.empty, - roleEnvs = Map.empty, - roleVolumes = Nil, - hadoopConfSpec = None) - test("Mounts hostPath volumes") { val volumeConf = KubernetesVolumeSpec( "testVolume", @@ -47,7 +30,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { false, KubernetesHostPathVolumeConf("/hostPath/tmp") ) - val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val kubernetesConf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeConf)) val step = new MountVolumesFeatureStep(kubernetesConf) val configuredPod = step.configurePod(SparkPod.initialPod()) @@ -67,7 +50,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { true, KubernetesPVCVolumeConf("pvcClaim") ) - val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val kubernetesConf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeConf)) val step = new MountVolumesFeatureStep(kubernetesConf) val configuredPod = step.configurePod(SparkPod.initialPod()) @@ -89,7 +72,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { false, KubernetesEmptyDirVolumeConf(Some("Memory"), Some("6G")) ) - val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val kubernetesConf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeConf)) val step = new MountVolumesFeatureStep(kubernetesConf) val configuredPod = step.configurePod(SparkPod.initialPod()) @@ -111,7 +94,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { false, KubernetesEmptyDirVolumeConf(None, None) ) - val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val kubernetesConf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeConf)) val step = new MountVolumesFeatureStep(kubernetesConf) val configuredPod = step.configurePod(SparkPod.initialPod()) @@ -140,8 +123,8 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { true, KubernetesPVCVolumeConf("pvcClaim") ) - val volumesConf = hpVolumeConf :: pvcVolumeConf :: Nil - val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumesConf) + val kubernetesConf = KubernetesTestConf.createDriverConf( + volumes = Seq(hpVolumeConf, pvcVolumeConf)) val step = new MountVolumesFeatureStep(kubernetesConf) val configuredPod = step.configurePod(SparkPod.initialPod()) @@ -157,7 +140,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { false, KubernetesEmptyDirVolumeConf(None, None) ) - val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val kubernetesConf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeConf)) val step = new MountVolumesFeatureStep(kubernetesConf) val configuredPod = step.configurePod(SparkPod.initialPod()) @@ -176,7 +159,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { true, KubernetesPVCVolumeConf("pvcClaim") ) - val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val kubernetesConf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeConf)) val step = new MountVolumesFeatureStep(kubernetesConf) val configuredPod = step.configurePod(SparkPod.initialPod()) @@ -206,19 +189,18 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { true, KubernetesEmptyDirVolumeConf(None, None) ) - val kubernetesConf = emptyKubernetesConf.copy( - roleVolumes = emptyDirSpec :: pvcSpec :: Nil) + val kubernetesConf = KubernetesTestConf.createDriverConf(volumes = Seq(emptyDirSpec, pvcSpec)) val step = new MountVolumesFeatureStep(kubernetesConf) val configuredPod = step.configurePod(SparkPod.initialPod()) assert(configuredPod.pod.getSpec.getVolumes.size() === 2) - val mounts = configuredPod.container.getVolumeMounts - assert(mounts.size() === 2) - assert(mounts.get(0).getName === "testEmptyDir") - assert(mounts.get(0).getMountPath === "/tmp/foo") - assert(mounts.get(0).getSubPath === "foo") - assert(mounts.get(1).getName === "testPVC") - assert(mounts.get(1).getMountPath === "/tmp/bar") - assert(mounts.get(1).getSubPath === "bar") + val mounts = configuredPod.container.getVolumeMounts.asScala.sortBy(_.getName()) + assert(mounts.size === 2) + assert(mounts(0).getName === "testEmptyDir") + assert(mounts(0).getMountPath === "/tmp/foo") + assert(mounts(0).getSubPath === "foo") + assert(mounts(1).getName === "testPVC") + assert(mounts(1).getMountPath === "/tmp/bar") + assert(mounts(1).getSubPath === "bar") } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala index 370948c9502e4..7295b82ca4799 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala @@ -20,40 +20,22 @@ import java.io.{File, PrintWriter} import java.nio.file.Files import io.fabric8.kubernetes.api.model.ConfigMap -import org.mockito.Mockito import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s._ -import org.apache.spark.deploy.k8s.submit.JavaMainAppResource class PodTemplateConfigMapStepSuite extends SparkFunSuite with BeforeAndAfter { - private var sparkConf: SparkConf = _ - private var kubernetesConf : KubernetesConf[_ <: KubernetesRoleSpecificConf] = _ + private var kubernetesConf : KubernetesConf = _ private var templateFile: File = _ before { - sparkConf = Mockito.mock(classOf[SparkConf]) - kubernetesConf = KubernetesConf( - sparkConf, - KubernetesDriverSpecificConf( - JavaMainAppResource(None), - "app-name", - "main", - Seq.empty), - "resource", - "app-id", - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - Option.empty) templateFile = Files.createTempFile("pod-template", "yml").toFile templateFile.deleteOnExit() - Mockito.doReturn(Option(templateFile.getAbsolutePath)).when(sparkConf) - .get(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE) + + val sparkConf = new SparkConf(false) + .set(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE, templateFile.getAbsolutePath) + kubernetesConf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf) } test("Mounts executor template volume if config specified") { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index 08f28758ef485..e9c05fef6f5db 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -24,8 +24,8 @@ import org.mockito.Mockito.{doReturn, verify, when} import org.scalatest.BeforeAndAfter import org.scalatest.mockito.MockitoSugar._ -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.Fabric8Aliases._ @@ -37,10 +37,6 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { private val KUBERNETES_RESOURCE_PREFIX = "resource-example" private val POD_NAME = "driver" private val CONTAINER_NAME = "container" - private val APP_ID = "app-id" - private val APP_NAME = "app" - private val MAIN_CLASS = "main" - private val APP_ARGS = Seq("arg1", "arg2") private val RESOLVED_JAVA_OPTIONS = Map( "conf1key" -> "conf1value", "conf2key" -> "conf2value") @@ -122,28 +118,15 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { @Mock private var resourceList: RESOURCE_LIST = _ - private var kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf] = _ - - private var sparkConf: SparkConf = _ + private var kconf: KubernetesDriverConf = _ private var createdPodArgumentCaptor: ArgumentCaptor[Pod] = _ private var createdResourcesArgumentCaptor: ArgumentCaptor[HasMetadata] = _ before { MockitoAnnotations.initMocks(this) - sparkConf = new SparkConf(false) - kubernetesConf = KubernetesConf[KubernetesDriverSpecificConf]( - sparkConf, - KubernetesDriverSpecificConf(JavaMainAppResource(None), MAIN_CLASS, APP_NAME, APP_ARGS), - KUBERNETES_RESOURCE_PREFIX, - APP_ID, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - hadoopConfSpec = None) - when(driverBuilder.buildFromFeatures(kubernetesConf)).thenReturn(BUILT_KUBERNETES_SPEC) + kconf = KubernetesTestConf.createDriverConf( + resourceNamePrefix = Some(KUBERNETES_RESOURCE_PREFIX)) + when(driverBuilder.buildFromFeatures(kconf)).thenReturn(BUILT_KUBERNETES_SPEC) when(kubernetesClient.pods()).thenReturn(podOperations) when(podOperations.withName(POD_NAME)).thenReturn(namedPods) @@ -158,26 +141,22 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { test("The client should configure the pod using the builder.") { val submissionClient = new Client( + kconf, driverBuilder, - kubernetesConf, kubernetesClient, false, - "spark", - loggingPodStatusWatcher, - KUBERNETES_RESOURCE_PREFIX) + loggingPodStatusWatcher) submissionClient.run() verify(podOperations).create(FULL_EXPECTED_POD) } test("The client should create Kubernetes resources") { val submissionClient = new Client( + kconf, driverBuilder, - kubernetesConf, kubernetesClient, false, - "spark", - loggingPodStatusWatcher, - KUBERNETES_RESOURCE_PREFIX) + loggingPodStatusWatcher) submissionClient.run() val otherCreatedResources = createdResourcesArgumentCaptor.getAllValues assert(otherCreatedResources.size === 2) @@ -197,13 +176,11 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { test("Waiting for app completion should stall on the watcher") { val submissionClient = new Client( + kconf, driverBuilder, - kubernetesConf, kubernetesClient, true, - "spark", - loggingPodStatusWatcher, - KUBERNETES_RESOURCE_PREFIX) + loggingPodStatusWatcher) submissionClient.run() verify(loggingPodStatusWatcher).awaitCompletion() } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index 3708864592d75..7e7dc4763c2e7 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -83,48 +83,21 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { _ => templateVolumeStep) test("Apply fundamental steps all the time.") { - val conf = KubernetesConf( - new SparkConf(false), - KubernetesDriverSpecificConf( - JavaMainAppResource(Some("example.jar")), - "test-app", - "main", - Seq.empty), - "prefix", - "appId", - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - hadoopConfSpec = None) + val conf = KubernetesTestConf.createDriverConf() validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, CREDENTIALS_STEP_TYPE, SERVICE_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, - DRIVER_CMD_STEP_TYPE) + DRIVER_CMD_STEP_TYPE, + HADOOP_GLOBAL_STEP_TYPE) } test("Apply secrets step if secrets are present.") { - val conf = KubernetesConf( - new SparkConf(false), - KubernetesDriverSpecificConf( - JavaMainAppResource(None), - "test-app", - "main", - Seq.empty), - "prefix", - "appId", - Map.empty, - Map.empty, - Map("secret" -> "secretMountPath"), - Map("EnvName" -> "SecretName:secretKey"), - Map.empty, - Nil, - hadoopConfSpec = None) + val conf = KubernetesTestConf.createDriverConf( + secretEnvNamesToKeyRefs = Map("EnvName" -> "SecretName:secretKey"), + secretNamesToMountPaths = Map("secret" -> "secretMountPath")) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, @@ -133,7 +106,8 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { LOCAL_DIRS_STEP_TYPE, SECRETS_STEP_TYPE, ENV_SECRETS_STEP_TYPE, - DRIVER_CMD_STEP_TYPE) + DRIVER_CMD_STEP_TYPE, + HADOOP_GLOBAL_STEP_TYPE) } test("Apply volumes step if mounts are present.") { @@ -143,22 +117,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { "", false, KubernetesHostPathVolumeConf("/path")) - val conf = KubernetesConf( - new SparkConf(false), - KubernetesDriverSpecificConf( - JavaMainAppResource(None), - "test-app", - "main", - Seq.empty), - "prefix", - "appId", - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - volumeSpec :: Nil, - hadoopConfSpec = None) + val conf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeSpec)) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, @@ -166,7 +125,8 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { SERVICE_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, MOUNT_VOLUMES_STEP_TYPE, - DRIVER_CMD_STEP_TYPE) + DRIVER_CMD_STEP_TYPE, + HADOOP_GLOBAL_STEP_TYPE) } test("Apply volumes step if a mount subpath is present.") { @@ -176,22 +136,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { "foo", false, KubernetesHostPathVolumeConf("/path")) - val conf = KubernetesConf( - new SparkConf(false), - KubernetesDriverSpecificConf( - JavaMainAppResource(None), - "test-app", - "main", - Seq.empty), - "prefix", - "appId", - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - volumeSpec :: Nil, - hadoopConfSpec = None) + val conf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeSpec)) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, @@ -199,89 +144,14 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { SERVICE_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, MOUNT_VOLUMES_STEP_TYPE, - DRIVER_CMD_STEP_TYPE) - } - - test("Apply template volume step if executor template is present.") { - val sparkConf = spy(new SparkConf(false)) - doReturn(Option("filename")).when(sparkConf) - .get(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE) - val conf = KubernetesConf( - sparkConf, - KubernetesDriverSpecificConf( - JavaMainAppResource(Some("example.jar")), - "test-app", - "main", - Seq.empty), - "prefix", - "appId", - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - Option.empty) - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf), - BASIC_STEP_TYPE, - CREDENTIALS_STEP_TYPE, - SERVICE_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, - DRIVER_CMD_STEP_TYPE, - TEMPLATE_VOLUME_STEP_TYPE) - } - - test("Apply HadoopSteps if HADOOP_CONF_DIR is defined.") { - val conf = KubernetesConf( - new SparkConf(false), - KubernetesDriverSpecificConf( - JavaMainAppResource(None), - "test-app", - "main", - Seq.empty), - "prefix", - "appId", - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - hadoopConfSpec = Some( - HadoopConfSpec( - Some("/var/hadoop-conf"), - None))) - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf), - BASIC_STEP_TYPE, - CREDENTIALS_STEP_TYPE, - SERVICE_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, DRIVER_CMD_STEP_TYPE, HADOOP_GLOBAL_STEP_TYPE) } - test("Apply HadoopSteps if HADOOP_CONF ConfigMap is defined.") { - val conf = KubernetesConf( - new SparkConf(false), - KubernetesDriverSpecificConf( - JavaMainAppResource(None), - "test-app", - "main", - Seq.empty), - "prefix", - "appId", - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - hadoopConfSpec = Some( - HadoopConfSpec( - None, - Some("pre-defined-configMapName")))) + test("Apply template volume step if executor template is present.") { + val sparkConf = new SparkConf(false) + .set(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE, "filename") + val conf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, @@ -289,12 +159,16 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { SERVICE_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, DRIVER_CMD_STEP_TYPE, - HADOOP_GLOBAL_STEP_TYPE) + HADOOP_GLOBAL_STEP_TYPE, + TEMPLATE_VOLUME_STEP_TYPE) } private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) : Unit = { - assert(resolvedSpec.systemProperties.size === stepTypes.size) + val addedProperties = resolvedSpec.systemProperties + .filter { case (k, _) => !k.startsWith("spark.") } + .toMap + assert(addedProperties.keys.toSet === stepTypes.toSet) stepTypes.foreach { stepType => assert(resolvedSpec.pod.pod.getMetadata.getLabels.get(stepType) === stepType) assert(resolvedSpec.driverKubernetesResources.containsSlice( @@ -314,22 +188,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { val sparkConf = new SparkConf(false) .set(CONTAINER_IMAGE, "spark-driver:latest") .set(KUBERNETES_DRIVER_PODTEMPLATE_FILE, "template-file.yaml") - val kubernetesConf = new KubernetesConf( - sparkConf, - KubernetesDriverSpecificConf( - JavaMainAppResource(Some("example.jar")), - "test-app", - "main", - Seq.empty), - "prefix", - "appId", - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - Option.empty) + val kubernetesConf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf) val driverSpec = KubernetesDriverBuilder .apply(kubernetesClient, sparkConf) .buildFromFeatures(kubernetesConf) @@ -346,22 +205,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { val sparkConf = new SparkConf(false) .set(CONTAINER_IMAGE, "spark-driver:latest") .set(KUBERNETES_DRIVER_PODTEMPLATE_FILE, "template-file.yaml") - val kubernetesConf = new KubernetesConf( - sparkConf, - KubernetesDriverSpecificConf( - JavaMainAppResource(Some("example.jar")), - "test-app", - "main", - Seq.empty), - "prefix", - "appId", - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - Option.empty) + val kubernetesConf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf) val exception = intercept[SparkException] { KubernetesDriverBuilder .apply(kubernetesClient, sparkConf) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala index 2f984e5d89808..ddf9f67a0727d 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala @@ -27,7 +27,7 @@ import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, KubernetesTestConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.Fabric8Aliases._ @@ -79,7 +79,7 @@ class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter { when(kubernetesClient.pods()).thenReturn(podOperations) when(podOperations.withName(driverPodName)).thenReturn(driverPodOperations) when(driverPodOperations.get).thenReturn(driverPod) - when(executorBuilder.buildFromFeatures(kubernetesConfWithCorrectFields())) + when(executorBuilder.buildFromFeatures(any(classOf[KubernetesExecutorConf]))) .thenAnswer(executorPodAnswer()) snapshotsStore = new DeterministicExecutorPodsSnapshotsStore() waitForExecutorPodsClock = new ManualClock(0L) @@ -147,44 +147,9 @@ class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter { private def executorPodAnswer(): Answer[SparkPod] = { new Answer[SparkPod] { override def answer(invocation: InvocationOnMock): SparkPod = { - val k8sConf = invocation.getArgumentAt( - 0, classOf[KubernetesConf[KubernetesExecutorSpecificConf]]) - executorPodWithId(k8sConf.roleSpecificConf.executorId.toInt) + val k8sConf = invocation.getArgumentAt(0, classOf[KubernetesExecutorConf]) + executorPodWithId(k8sConf.executorId.toInt) } } } - - private def kubernetesConfWithCorrectFields(): KubernetesConf[KubernetesExecutorSpecificConf] = - Matchers.argThat(new ArgumentMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { - override def matches(argument: scala.Any): Boolean = { - if (!argument.isInstanceOf[KubernetesConf[_]]) { - false - } else { - val k8sConf = argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] - val executorSpecificConf = k8sConf.roleSpecificConf - // TODO: HADOOP_CONF_DIR - val expectedK8sConf = KubernetesConf.createExecutorConf( - conf, - executorSpecificConf.executorId, - TEST_SPARK_APP_ID, - Some(driverPod)) - - // Set prefixes to a common string since KUBERNETES_EXECUTOR_POD_NAME_PREFIX - // has not be set for the tests and thus KubernetesConf will use a random - // string for the prefix, based on the app name, and this comparison here will fail. - val k8sConfCopy = k8sConf - .copy(appResourceNamePrefix = "") - .copy(sparkConf = conf) - val expectedK8sConfCopy = expectedK8sConf - .copy(appResourceNamePrefix = "") - .copy(sparkConf = conf) - - k8sConf.sparkConf.getAll.toMap == conf.getAll.toMap && - // Since KubernetesConf.createExecutorConf clones the SparkConf object, force - // deep equality comparison for the SparkConf object and use object equality - // comparison on all other fields. - k8sConfCopy == expectedK8sConfCopy - } - } - }) } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index a59f6d072023e..b6a75b15af85a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.scheduler.cluster.k8s +import scala.collection.JavaConverters._ + import io.fabric8.kubernetes.api.model.{Config => _, _} import io.fabric8.kubernetes.client.KubernetesClient import org.mockito.Mockito.{mock, never, verify} @@ -25,6 +27,7 @@ import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.features._ import org.apache.spark.deploy.k8s.submit.PodBuilderSuiteUtils +import org.apache.spark.util.SparkConfWithEnv class KubernetesExecutorBuilderSuite extends SparkFunSuite { private val BASIC_STEP_TYPE = "basic" @@ -64,37 +67,15 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { _ => hadoopSparkUser) test("Basic steps are consistently applied.") { - val conf = KubernetesConf( - new SparkConf(false), - KubernetesExecutorSpecificConf( - "executor-id", Some(new PodBuilder().build())), - "prefix", - "appId", - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - None) + val conf = KubernetesTestConf.createExecutorConf() validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE) } test("Apply secrets step if secrets are present.") { - val conf = KubernetesConf( - new SparkConf(false), - KubernetesExecutorSpecificConf( - "executor-id", Some(new PodBuilder().build())), - "prefix", - "appId", - Map.empty, - Map.empty, - Map("secret" -> "secretMountPath"), - Map("secret-name" -> "secret-key"), - Map.empty, - Nil, - None) + val conf = KubernetesTestConf.createExecutorConf( + secretEnvNamesToKeyRefs = Map("secret-name" -> "secret-key"), + secretNamesToMountPaths = Map("secret" -> "secretMountPath")) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, @@ -110,19 +91,8 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { "", false, KubernetesHostPathVolumeConf("/checkpoint")) - val conf = KubernetesConf( - new SparkConf(false), - KubernetesExecutorSpecificConf( - "executor-id", Some(new PodBuilder().build())), - "prefix", - "appId", - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - volumeSpec :: Nil, - None) + val conf = KubernetesTestConf.createExecutorConf( + volumes = Seq(volumeSpec)) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, @@ -132,25 +102,10 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { test("Apply basicHadoop step if HADOOP_CONF_DIR is defined") { // HADOOP_DELEGATION_TOKEN - val HADOOP_CREDS_PREFIX = "spark.security.credentials." - val HADOOPFS_PROVIDER = s"$HADOOP_CREDS_PREFIX.hadoopfs.enabled" - val conf = KubernetesConf( - new SparkConf(false) + val conf = KubernetesTestConf.createExecutorConf( + sparkConf = new SparkConfWithEnv(Map("HADOOP_CONF_DIR" -> "/var/hadoop-conf")) .set(HADOOP_CONFIG_MAP_NAME, "hadoop-conf-map-name") - .set(KRB5_CONFIG_MAP_NAME, "krb5-conf-map-name") - .set(KERBEROS_SPARK_USER_NAME, "spark-user") - .set(HADOOPFS_PROVIDER, "true"), - KubernetesExecutorSpecificConf( - "executor-id", Some(new PodBuilder().build())), - "prefix", - "appId", - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - Some(HadoopConfSpec(Some("/var/hadoop-conf"), None))) + .set(KRB5_CONFIG_MAP_NAME, "krb5-conf-map-name")) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, @@ -160,24 +115,13 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { } test("Apply kerberos step if DT secrets created") { - val conf = KubernetesConf( - new SparkConf(false) + val conf = KubernetesTestConf.createExecutorConf( + sparkConf = new SparkConf(false) .set(HADOOP_CONFIG_MAP_NAME, "hadoop-conf-map-name") .set(KRB5_CONFIG_MAP_NAME, "krb5-conf-map-name") .set(KERBEROS_SPARK_USER_NAME, "spark-user") .set(KERBEROS_DT_SECRET_NAME, "dt-secret") - .set(KERBEROS_DT_SECRET_KEY, "dt-key"), - KubernetesExecutorSpecificConf( - "executor-id", Some(new PodBuilder().build())), - "prefix", - "appId", - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - Some(HadoopConfSpec(None, Some("pre-defined-onfigMapName")))) + .set(KERBEROS_DT_SECRET_KEY, "dt-key" )) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, @@ -187,10 +131,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { } private def validateStepTypesApplied(resolvedPod: SparkPod, stepTypes: String*): Unit = { - assert(resolvedPod.pod.getMetadata.getLabels.size === stepTypes.size) - stepTypes.foreach { stepType => - assert(resolvedPod.pod.getMetadata.getLabels.get(stepType) === stepType) - } + assert(resolvedPod.pod.getMetadata.getLabels.asScala.keys.toSet === stepTypes.toSet) } test("Starts with empty executor pod if template is not specified") { @@ -205,25 +146,14 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { .set("spark.driver.host", "https://driver.host.com") .set(Config.CONTAINER_IMAGE, "spark-executor:latest") .set(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE, "template-file.yaml") - val kubernetesConf = KubernetesConf( - sparkConf, - KubernetesExecutorSpecificConf( - "executor-id", Some(new PodBuilder() - .withNewMetadata() + val kubernetesConf = KubernetesTestConf.createExecutorConf( + sparkConf = sparkConf, + driverPod = Some(new PodBuilder() + .withNewMetadata() .withName("driver") .endMetadata() - .build())), - "prefix", - "appId", - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Map.empty, - Nil, - Option.empty) - val sparkPod = KubernetesExecutorBuilder - .apply(kubernetesClient, sparkConf) + .build())) + val sparkPod = KubernetesExecutorBuilder(kubernetesClient, sparkConf) .buildFromFeatures(kubernetesConf) PodBuilderSuiteUtils.verifyPodWithSupportedFeatures(sparkPod) } From 28d33744076abd8bf7955eefcbdeef4849a99c40 Mon Sep 17 00:00:00 2001 From: DylanGuedes Date: Sat, 1 Dec 2018 10:37:03 +0800 Subject: [PATCH 2183/2461] [SPARK-23647][PYTHON][SQL] Adds more types for hint in pyspark Signed-off-by: DylanGuedes ## What changes were proposed in this pull request? Addition of float, int and list hints for `pyspark.sql` Hint. ## How was this patch tested? I did manual tests following the same principles used in the Scala version, and also added unit tests. Closes #20788 from DylanGuedes/jira-21030. Authored-by: DylanGuedes Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/dataframe.py | 6 ++++-- python/pyspark/sql/tests/test_dataframe.py | 13 +++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b8833a39078ba..1b1092c409be0 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -485,10 +485,12 @@ def hint(self, name, *parameters): if not isinstance(name, str): raise TypeError("name should be provided as str, got {0}".format(type(name))) + allowed_types = (basestring, list, float, int) for p in parameters: - if not isinstance(p, str): + if not isinstance(p, allowed_types): raise TypeError( - "all parameters should be str, got {0} of type {1}".format(p, type(p))) + "all parameters should be in {0}, got {1} of type {2}".format( + allowed_types, p, type(p))) jdf = self._jdf.hint(name, self._jseq(parameters)) return DataFrame(jdf, self.sql_ctx) diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 908d400e00092..65edf593c300e 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -375,6 +375,19 @@ def test_generic_hints(self): plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan() self.assertEqual(1, plan.toString().count("BroadcastHashJoin")) + # add tests for SPARK-23647 (test more types for hint) + def test_extended_hint_types(self): + from pyspark.sql import DataFrame + + df = self.spark.range(10e10).toDF("id") + such_a_nice_list = ["itworks1", "itworks2", "itworks3"] + hinted_df = df.hint("my awesome hint", 1.2345, "what", such_a_nice_list) + logical_plan = hinted_df._jdf.queryExecution().logical() + + self.assertEqual(1, logical_plan.toString().count("1.2345")) + self.assertEqual(1, logical_plan.toString().count("what")) + self.assertEqual(3, logical_plan.toString().count("itworks")) + def test_sample(self): self.assertRaisesRegexp( TypeError, From 2f6e88fecb455a02c4c08c41290e2f338e979543 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 30 Nov 2018 23:14:05 -0800 Subject: [PATCH 2184/2461] [SPARK-26189][R] Fix unionAll doc in SparkR ## What changes were proposed in this pull request? Fix unionAll doc in SparkR ## How was this patch tested? Manually ran test Author: Huaxin Gao Closes #23161 from huaxingao/spark-26189. --- R/pkg/R/DataFrame.R | 20 ++++++++++++++++---- R/pkg/R/generics.R | 2 +- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 745bb3e15932b..24ed449f2a7d1 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2730,13 +2730,25 @@ setMethod("union", dataFrame(unioned) }) -#' Return a new SparkDataFrame containing the union of rows +#' Return a new SparkDataFrame containing the union of rows. #' -#' This is an alias for `union`. +#' This is an alias for \code{union}. #' -#' @rdname union -#' @name unionAll +#' @param x a SparkDataFrame. +#' @param y a SparkDataFrame. +#' @return A SparkDataFrame containing the result of the unionAll operation. +#' @family SparkDataFrame functions #' @aliases unionAll,SparkDataFrame,SparkDataFrame-method +#' @rdname unionAll +#' @name unionAll +#' @seealso \link{union} +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df1 <- read.json(path) +#' df2 <- read.json(path2) +#' unionAllDF <- unionAll(df1, df2) +#' } #' @note unionAll since 1.4.0 setMethod("unionAll", signature(x = "SparkDataFrame", y = "SparkDataFrame"), diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 9d8c24c686c76..eed76465221c6 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -631,7 +631,7 @@ setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) #' @rdname union setGeneric("union", function(x, y) { standardGeneric("union") }) -#' @rdname union +#' @rdname unionAll setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) #' @rdname unionByName From 327ac83f5cf33c84775a95442862bea56d8a0005 Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Sat, 1 Dec 2018 16:34:11 +0800 Subject: [PATCH 2185/2461] [SPARK-26180][CORE][TEST] Reuse withTempDir function to the SparkCore test case ## What changes were proposed in this pull request? Currently, the common `withTempDir` function is used in Spark SQL test cases. To handle `val dir = Utils. createTempDir()` and `Utils. deleteRecursively (dir)`. Unfortunately, the `withTempDir` function cannot be used in the Spark Core test case. This PR Sharing `withTempDir` function in Spark Sql and SparkCore to clean up SparkCore test cases. thanks. ## How was this patch tested? N / A Closes #23151 from heary-cao/withCreateTempDir. Authored-by: caoxuewen Signed-off-by: Hyukjin Kwon --- .../org/apache/spark/CheckpointSuite.scala | 5 +- .../apache/spark/ContextCleanerSuite.scala | 97 ++-- .../scala/org/apache/spark/FileSuite.scala | 19 +- .../org/apache/spark/SparkContextSuite.scala | 315 +++++----- .../org/apache/spark/SparkFunSuite.scala | 12 +- .../api/python/PythonBroadcastSuite.scala | 5 +- .../spark/deploy/SparkSubmitSuite.scala | 539 +++++++++--------- .../history/FsHistoryProviderSuite.scala | 89 +-- .../history/HistoryServerArgumentsSuite.scala | 8 +- .../master/PersistenceEngineSuite.scala | 5 +- .../input/WholeTextFileInputFormatSuite.scala | 6 +- .../WholeTextFileRecordReaderSuite.scala | 62 +- .../spark/rdd/PairRDDFunctionsSuite.scala | 5 +- .../org/apache/spark/rpc/RpcEnvSuite.scala | 113 ++-- .../spark/scheduler/DAGSchedulerSuite.scala | 22 +- ...putCommitCoordinatorIntegrationSuite.scala | 5 +- .../serializer/KryoSerializerSuite.scala | 42 +- .../apache/spark/storage/DiskStoreSuite.scala | 1 - .../util/PeriodicRDDCheckpointerSuite.scala | 52 +- .../org/apache/spark/util/UtilsSuite.scala | 222 ++++---- .../apache/spark/sql/test/SQLTestUtils.scala | 100 ++-- .../spark/sql/hive/client/VersionsSuite.scala | 9 - .../spark/streaming/TestSuiteBase.scala | 12 - 23 files changed, 858 insertions(+), 887 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 48408ccc8f81b..6d9e47cfd00fc 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -586,8 +586,7 @@ object CheckpointSuite { class CheckpointCompressionSuite extends SparkFunSuite with LocalSparkContext { test("checkpoint compression") { - val checkpointDir = Utils.createTempDir() - try { + withTempDir { checkpointDir => val conf = new SparkConf() .set("spark.checkpoint.compress", "true") .set("spark.ui.enabled", "false") @@ -616,8 +615,6 @@ class CheckpointCompressionSuite extends SparkFunSuite with LocalSparkContext { // Verify that the compressed content can be read back assert(rdd.collect().toSeq === (1 to 20)) - } finally { - Utils.deleteRecursively(checkpointDir) } } } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 6724af952505f..1fcc975ab39a9 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -207,54 +207,55 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { } test("automatically cleanup normal checkpoint") { - val checkpointDir = Utils.createTempDir() - checkpointDir.delete() - var rdd = newPairRDD() - sc.setCheckpointDir(checkpointDir.toString) - rdd.checkpoint() - rdd.cache() - rdd.collect() - var rddId = rdd.id - - // Confirm the checkpoint directory exists - assert(ReliableRDDCheckpointData.checkpointPath(sc, rddId).isDefined) - val path = ReliableRDDCheckpointData.checkpointPath(sc, rddId).get - val fs = path.getFileSystem(sc.hadoopConfiguration) - assert(fs.exists(path)) - - // the checkpoint is not cleaned by default (without the configuration set) - var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Seq(rddId)) - rdd = null // Make RDD out of scope, ok if collected earlier - runGC() - postGCTester.assertCleanup() - assert(!fs.exists(ReliableRDDCheckpointData.checkpointPath(sc, rddId).get)) - - // Verify that checkpoints are NOT cleaned up if the config is not enabled - sc.stop() - val conf = new SparkConf() - .setMaster("local[2]") - .setAppName("cleanupCheckpoint") - .set("spark.cleaner.referenceTracking.cleanCheckpoints", "false") - sc = new SparkContext(conf) - rdd = newPairRDD() - sc.setCheckpointDir(checkpointDir.toString) - rdd.checkpoint() - rdd.cache() - rdd.collect() - rddId = rdd.id - - // Confirm the checkpoint directory exists - assert(fs.exists(ReliableRDDCheckpointData.checkpointPath(sc, rddId).get)) - - // Reference rdd to defeat any early collection by the JVM - rdd.count() - - // Test that GC causes checkpoint data cleanup after dereferencing the RDD - postGCTester = new CleanerTester(sc, Seq(rddId)) - rdd = null // Make RDD out of scope - runGC() - postGCTester.assertCleanup() - assert(fs.exists(ReliableRDDCheckpointData.checkpointPath(sc, rddId).get)) + withTempDir { checkpointDir => + checkpointDir.delete() + var rdd = newPairRDD() + sc.setCheckpointDir(checkpointDir.toString) + rdd.checkpoint() + rdd.cache() + rdd.collect() + var rddId = rdd.id + + // Confirm the checkpoint directory exists + assert(ReliableRDDCheckpointData.checkpointPath(sc, rddId).isDefined) + val path = ReliableRDDCheckpointData.checkpointPath(sc, rddId).get + val fs = path.getFileSystem(sc.hadoopConfiguration) + assert(fs.exists(path)) + + // the checkpoint is not cleaned by default (without the configuration set) + var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Seq(rddId)) + rdd = null // Make RDD out of scope, ok if collected earlier + runGC() + postGCTester.assertCleanup() + assert(!fs.exists(ReliableRDDCheckpointData.checkpointPath(sc, rddId).get)) + + // Verify that checkpoints are NOT cleaned up if the config is not enabled + sc.stop() + val conf = new SparkConf() + .setMaster("local[2]") + .setAppName("cleanupCheckpoint") + .set("spark.cleaner.referenceTracking.cleanCheckpoints", "false") + sc = new SparkContext(conf) + rdd = newPairRDD() + sc.setCheckpointDir(checkpointDir.toString) + rdd.checkpoint() + rdd.cache() + rdd.collect() + rddId = rdd.id + + // Confirm the checkpoint directory exists + assert(fs.exists(ReliableRDDCheckpointData.checkpointPath(sc, rddId).get)) + + // Reference rdd to defeat any early collection by the JVM + rdd.count() + + // Test that GC causes checkpoint data cleanup after dereferencing the RDD + postGCTester = new CleanerTester(sc, Seq(rddId)) + rdd = null // Make RDD out of scope + runGC() + postGCTester.assertCleanup() + assert(fs.exists(ReliableRDDCheckpointData.checkpointPath(sc, rddId).get)) + } } test("automatically clean up local checkpoint") { diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index df04a5ea1d99e..983a7917e8aab 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -306,17 +306,18 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { .set("spark.files.openCostInBytes", "0") .set("spark.default.parallelism", "1")) - val tempDir = Utils.createTempDir() - val tempDirPath = tempDir.getAbsolutePath + withTempDir { tempDir => + val tempDirPath = tempDir.getAbsolutePath - for (i <- 0 until 8) { - val tempFile = new File(tempDir, s"part-0000$i") - Files.write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1", tempFile, - StandardCharsets.UTF_8) - } + for (i <- 0 until 8) { + val tempFile = new File(tempDir, s"part-0000$i") + Files.write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1", tempFile, + StandardCharsets.UTF_8) + } - for (p <- Seq(1, 2, 8)) { - assert(sc.binaryFiles(tempDirPath, minPartitions = p).getNumPartitions === p) + for (p <- Seq(1, 2, 8)) { + assert(sc.binaryFiles(tempDirPath, minPartitions = p).getNumPartitions === p) + } } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 79192f3f3c92c..ec4c7efb5835a 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -116,56 +116,57 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } test("basic case for addFile and listFiles") { - val dir = Utils.createTempDir() - - val file1 = File.createTempFile("someprefix1", "somesuffix1", dir) - val absolutePath1 = file1.getAbsolutePath - - val file2 = File.createTempFile("someprefix2", "somesuffix2", dir) - val relativePath = file2.getParent + "/../" + file2.getParentFile.getName + "/" + file2.getName - val absolutePath2 = file2.getAbsolutePath - - try { - Files.write("somewords1", file1, StandardCharsets.UTF_8) - Files.write("somewords2", file2, StandardCharsets.UTF_8) - val length1 = file1.length() - val length2 = file2.length() - - sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) - sc.addFile(file1.getAbsolutePath) - sc.addFile(relativePath) - sc.parallelize(Array(1), 1).map(x => { - val gotten1 = new File(SparkFiles.get(file1.getName)) - val gotten2 = new File(SparkFiles.get(file2.getName)) - if (!gotten1.exists()) { - throw new SparkException("file doesn't exist : " + absolutePath1) - } - if (!gotten2.exists()) { - throw new SparkException("file doesn't exist : " + absolutePath2) - } + withTempDir { dir => + val file1 = File.createTempFile("someprefix1", "somesuffix1", dir) + val absolutePath1 = file1.getAbsolutePath + + val file2 = File.createTempFile("someprefix2", "somesuffix2", dir) + val relativePath = file2.getParent + "/../" + file2.getParentFile.getName + + "/" + file2.getName + val absolutePath2 = file2.getAbsolutePath + + try { + Files.write("somewords1", file1, StandardCharsets.UTF_8) + Files.write("somewords2", file2, StandardCharsets.UTF_8) + val length1 = file1.length() + val length2 = file2.length() + + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + sc.addFile(file1.getAbsolutePath) + sc.addFile(relativePath) + sc.parallelize(Array(1), 1).map(x => { + val gotten1 = new File(SparkFiles.get(file1.getName)) + val gotten2 = new File(SparkFiles.get(file2.getName)) + if (!gotten1.exists()) { + throw new SparkException("file doesn't exist : " + absolutePath1) + } + if (!gotten2.exists()) { + throw new SparkException("file doesn't exist : " + absolutePath2) + } - if (length1 != gotten1.length()) { - throw new SparkException( - s"file has different length $length1 than added file ${gotten1.length()} : " + - absolutePath1) - } - if (length2 != gotten2.length()) { - throw new SparkException( - s"file has different length $length2 than added file ${gotten2.length()} : " + - absolutePath2) - } + if (length1 != gotten1.length()) { + throw new SparkException( + s"file has different length $length1 than added file ${gotten1.length()} : " + + absolutePath1) + } + if (length2 != gotten2.length()) { + throw new SparkException( + s"file has different length $length2 than added file ${gotten2.length()} : " + + absolutePath2) + } - if (absolutePath1 == gotten1.getAbsolutePath) { - throw new SparkException("file should have been copied :" + absolutePath1) - } - if (absolutePath2 == gotten2.getAbsolutePath) { - throw new SparkException("file should have been copied : " + absolutePath2) - } - x - }).count() - assert(sc.listFiles().filter(_.contains("somesuffix1")).size == 1) - } finally { - sc.stop() + if (absolutePath1 == gotten1.getAbsolutePath) { + throw new SparkException("file should have been copied :" + absolutePath1) + } + if (absolutePath2 == gotten2.getAbsolutePath) { + throw new SparkException("file should have been copied : " + absolutePath2) + } + x + }).count() + assert(sc.listFiles().filter(_.contains("somesuffix1")).size == 1) + } finally { + sc.stop() + } } } @@ -202,51 +203,51 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } test("addFile recursive works") { - val pluto = Utils.createTempDir() - val neptune = Utils.createTempDir(pluto.getAbsolutePath) - val saturn = Utils.createTempDir(neptune.getAbsolutePath) - val alien1 = File.createTempFile("alien", "1", neptune) - val alien2 = File.createTempFile("alien", "2", saturn) - - try { - sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) - sc.addFile(neptune.getAbsolutePath, true) - sc.parallelize(Array(1), 1).map(x => { - val sep = File.separator - if (!new File(SparkFiles.get(neptune.getName + sep + alien1.getName)).exists()) { - throw new SparkException("can't access file under root added directory") - } - if (!new File(SparkFiles.get(neptune.getName + sep + saturn.getName + sep + alien2.getName)) - .exists()) { - throw new SparkException("can't access file in nested directory") - } - if (new File(SparkFiles.get(pluto.getName + sep + neptune.getName + sep + alien1.getName)) - .exists()) { - throw new SparkException("file exists that shouldn't") - } - x - }).count() - } finally { - sc.stop() + withTempDir { pluto => + val neptune = Utils.createTempDir(pluto.getAbsolutePath) + val saturn = Utils.createTempDir(neptune.getAbsolutePath) + val alien1 = File.createTempFile("alien", "1", neptune) + val alien2 = File.createTempFile("alien", "2", saturn) + + try { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + sc.addFile(neptune.getAbsolutePath, true) + sc.parallelize(Array(1), 1).map(x => { + val sep = File.separator + if (!new File(SparkFiles.get(neptune.getName + sep + alien1.getName)).exists()) { + throw new SparkException("can't access file under root added directory") + } + if (!new File(SparkFiles.get( + neptune.getName + sep + saturn.getName + sep + alien2.getName)).exists()) { + throw new SparkException("can't access file in nested directory") + } + if (new File(SparkFiles.get( + pluto.getName + sep + neptune.getName + sep + alien1.getName)).exists()) { + throw new SparkException("file exists that shouldn't") + } + x + }).count() + } finally { + sc.stop() + } } } test("addFile recursive can't add directories by default") { - val dir = Utils.createTempDir() - - try { - sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) - intercept[SparkException] { - sc.addFile(dir.getAbsolutePath) + withTempDir { dir => + try { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + intercept[SparkException] { + sc.addFile(dir.getAbsolutePath) + } + } finally { + sc.stop() } - } finally { - sc.stop() } } test("cannot call addFile with different paths that have the same filename") { - val dir = Utils.createTempDir() - try { + withTempDir { dir => val subdir1 = new File(dir, "subdir1") val subdir2 = new File(dir, "subdir2") assert(subdir1.mkdir()) @@ -267,8 +268,6 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu sc.addFile(file2.getAbsolutePath) } assert(getAddedFileContents() === "old") - } finally { - Utils.deleteRecursively(dir) } } @@ -296,30 +295,33 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } test("add jar with invalid path") { - val tmpDir = Utils.createTempDir() - val tmpJar = File.createTempFile("test", ".jar", tmpDir) + withTempDir { tmpDir => + val tmpJar = File.createTempFile("test", ".jar", tmpDir) - sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) - sc.addJar(tmpJar.getAbsolutePath) + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + sc.addJar(tmpJar.getAbsolutePath) - // Invalid jar path will only print the error log, will not add to file server. - sc.addJar("dummy.jar") - sc.addJar("") - sc.addJar(tmpDir.getAbsolutePath) + // Invalid jar path will only print the error log, will not add to file server. + sc.addJar("dummy.jar") + sc.addJar("") + sc.addJar(tmpDir.getAbsolutePath) - assert(sc.listJars().size == 1) - assert(sc.listJars().head.contains(tmpJar.getName)) + assert(sc.listJars().size == 1) + assert(sc.listJars().head.contains(tmpJar.getName)) + } } test("SPARK-22585 addJar argument without scheme is interpreted literally without url decoding") { - val tmpDir = new File(Utils.createTempDir(), "host%3A443") - tmpDir.mkdirs() - val tmpJar = File.createTempFile("t%2F", ".jar", tmpDir) + withTempDir { dir => + val tmpDir = new File(dir, "host%3A443") + tmpDir.mkdirs() + val tmpJar = File.createTempFile("t%2F", ".jar", tmpDir) - sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") - sc.addJar(tmpJar.getAbsolutePath) - assert(sc.listJars().size === 1) + sc.addJar(tmpJar.getAbsolutePath) + assert(sc.listJars().size === 1) + } } test("Cancelling job group should not cause SparkContext to shutdown (SPARK-6414)") { @@ -340,60 +342,61 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu test("Comma separated paths for newAPIHadoopFile/wholeTextFiles/binaryFiles (SPARK-7155)") { // Regression test for SPARK-7155 // dir1 and dir2 are used for wholeTextFiles and binaryFiles - val dir1 = Utils.createTempDir() - val dir2 = Utils.createTempDir() - - val dirpath1 = dir1.getAbsolutePath - val dirpath2 = dir2.getAbsolutePath - - // file1 and file2 are placed inside dir1, they are also used for - // textFile, hadoopFile, and newAPIHadoopFile - // file3, file4 and file5 are placed inside dir2, they are used for - // textFile, hadoopFile, and newAPIHadoopFile as well - val file1 = new File(dir1, "part-00000") - val file2 = new File(dir1, "part-00001") - val file3 = new File(dir2, "part-00000") - val file4 = new File(dir2, "part-00001") - val file5 = new File(dir2, "part-00002") - - val filepath1 = file1.getAbsolutePath - val filepath2 = file2.getAbsolutePath - val filepath3 = file3.getAbsolutePath - val filepath4 = file4.getAbsolutePath - val filepath5 = file5.getAbsolutePath - - - try { - // Create 5 text files. - Files.write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1", file1, - StandardCharsets.UTF_8) - Files.write("someline1 in file2\nsomeline2 in file2", file2, StandardCharsets.UTF_8) - Files.write("someline1 in file3", file3, StandardCharsets.UTF_8) - Files.write("someline1 in file4\nsomeline2 in file4", file4, StandardCharsets.UTF_8) - Files.write("someline1 in file2\nsomeline2 in file5", file5, StandardCharsets.UTF_8) - - sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) - - // Test textFile, hadoopFile, and newAPIHadoopFile for file1 and file2 - assert(sc.textFile(filepath1 + "," + filepath2).count() == 5L) - assert(sc.hadoopFile(filepath1 + "," + filepath2, - classOf[TextInputFormat], classOf[LongWritable], classOf[Text]).count() == 5L) - assert(sc.newAPIHadoopFile(filepath1 + "," + filepath2, - classOf[NewTextInputFormat], classOf[LongWritable], classOf[Text]).count() == 5L) - - // Test textFile, hadoopFile, and newAPIHadoopFile for file3, file4, and file5 - assert(sc.textFile(filepath3 + "," + filepath4 + "," + filepath5).count() == 5L) - assert(sc.hadoopFile(filepath3 + "," + filepath4 + "," + filepath5, - classOf[TextInputFormat], classOf[LongWritable], classOf[Text]).count() == 5L) - assert(sc.newAPIHadoopFile(filepath3 + "," + filepath4 + "," + filepath5, - classOf[NewTextInputFormat], classOf[LongWritable], classOf[Text]).count() == 5L) - - // Test wholeTextFiles, and binaryFiles for dir1 and dir2 - assert(sc.wholeTextFiles(dirpath1 + "," + dirpath2).count() == 5L) - assert(sc.binaryFiles(dirpath1 + "," + dirpath2).count() == 5L) - - } finally { - sc.stop() + withTempDir { dir1 => + withTempDir { dir2 => + val dirpath1 = dir1.getAbsolutePath + val dirpath2 = dir2.getAbsolutePath + + // file1 and file2 are placed inside dir1, they are also used for + // textFile, hadoopFile, and newAPIHadoopFile + // file3, file4 and file5 are placed inside dir2, they are used for + // textFile, hadoopFile, and newAPIHadoopFile as well + val file1 = new File(dir1, "part-00000") + val file2 = new File(dir1, "part-00001") + val file3 = new File(dir2, "part-00000") + val file4 = new File(dir2, "part-00001") + val file5 = new File(dir2, "part-00002") + + val filepath1 = file1.getAbsolutePath + val filepath2 = file2.getAbsolutePath + val filepath3 = file3.getAbsolutePath + val filepath4 = file4.getAbsolutePath + val filepath5 = file5.getAbsolutePath + + + try { + // Create 5 text files. + Files.write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1", file1, + StandardCharsets.UTF_8) + Files.write("someline1 in file2\nsomeline2 in file2", file2, StandardCharsets.UTF_8) + Files.write("someline1 in file3", file3, StandardCharsets.UTF_8) + Files.write("someline1 in file4\nsomeline2 in file4", file4, StandardCharsets.UTF_8) + Files.write("someline1 in file2\nsomeline2 in file5", file5, StandardCharsets.UTF_8) + + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + + // Test textFile, hadoopFile, and newAPIHadoopFile for file1 and file2 + assert(sc.textFile(filepath1 + "," + filepath2).count() == 5L) + assert(sc.hadoopFile(filepath1 + "," + filepath2, + classOf[TextInputFormat], classOf[LongWritable], classOf[Text]).count() == 5L) + assert(sc.newAPIHadoopFile(filepath1 + "," + filepath2, + classOf[NewTextInputFormat], classOf[LongWritable], classOf[Text]).count() == 5L) + + // Test textFile, hadoopFile, and newAPIHadoopFile for file3, file4, and file5 + assert(sc.textFile(filepath3 + "," + filepath4 + "," + filepath5).count() == 5L) + assert(sc.hadoopFile(filepath3 + "," + filepath4 + "," + filepath5, + classOf[TextInputFormat], classOf[LongWritable], classOf[Text]).count() == 5L) + assert(sc.newAPIHadoopFile(filepath3 + "," + filepath4 + "," + filepath5, + classOf[NewTextInputFormat], classOf[LongWritable], classOf[Text]).count() == 5L) + + // Test wholeTextFiles, and binaryFiles for dir1 and dir2 + assert(sc.wholeTextFiles(dirpath1 + "," + dirpath2).count() == 5L) + assert(sc.binaryFiles(dirpath1 + "," + dirpath2).count() == 5L) + + } finally { + sc.stop() + } + } } } diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 31289026b0027..dad24d7c01b8b 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -23,7 +23,7 @@ import java.io.File import org.scalatest.{BeforeAndAfterAll, FunSuite, Outcome} import org.apache.spark.internal.Logging -import org.apache.spark.util.AccumulatorContext +import org.apache.spark.util.{AccumulatorContext, Utils} /** * Base abstract class for all unit tests in Spark for handling common functionality. @@ -106,4 +106,14 @@ abstract class SparkFunSuite } } + /** + * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` + * returns. + */ + protected def withTempDir(f: File => Unit): Unit = { + val dir = Utils.createTempDir() + try f(dir) finally { + Utils.deleteRecursively(dir) + } + } } diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala index b38a3667abee1..7407a656dbfc8 100644 --- a/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala @@ -31,7 +31,6 @@ import org.apache.spark.util.Utils // a PythonBroadcast: class PythonBroadcastSuite extends SparkFunSuite with Matchers with SharedSparkContext { test("PythonBroadcast can be serialized with Kryo (SPARK-4882)") { - val tempDir = Utils.createTempDir() val broadcastedString = "Hello, world!" def assertBroadcastIsValid(broadcast: PythonBroadcast): Unit = { val source = Source.fromFile(broadcast.path) @@ -39,7 +38,7 @@ class PythonBroadcastSuite extends SparkFunSuite with Matchers with SharedSparkC source.close() contents should be (broadcastedString) } - try { + withTempDir { tempDir => val broadcastDataFile: File = { val file = new File(tempDir, "broadcastData") val printWriter = new PrintWriter(file) @@ -53,8 +52,6 @@ class PythonBroadcastSuite extends SparkFunSuite with Matchers with SharedSparkC val deserializedBroadcast = Utils.clone[PythonBroadcast](broadcast, new KryoSerializer(conf).newInstance()) assertBroadcastIsValid(deserializedBroadcast) - } finally { - Utils.deleteRecursively(tempDir) } } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index c093789244bfe..a8973d1b60f89 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -494,13 +494,11 @@ class SparkSubmitSuite } test("launch simple application with spark-submit with redaction") { - val testDir = Utils.createTempDir() - testDir.deleteOnExit() - val testDirPath = new Path(testDir.getAbsolutePath()) val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val fileSystem = Utils.getHadoopFileSystem("/", SparkHadoopUtil.get.newConfiguration(new SparkConf())) - try { + withTempDir { testDir => + val testDirPath = new Path(testDir.getAbsolutePath()) val args = Seq( "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", @@ -519,8 +517,6 @@ class SparkSubmitSuite Source.fromInputStream(logData).getLines().foreach { line => assert(!line.contains("secret_password")) } - } finally { - Utils.deleteRecursively(testDir) } } @@ -614,108 +610,112 @@ class SparkSubmitSuite assert(new File(rScriptDir).exists) // compile a small jar containing a class that will be called from R code. - val tempDir = Utils.createTempDir() - val srcDir = new File(tempDir, "sparkrtest") - srcDir.mkdirs() - val excSource = new JavaSourceFromString(new File(srcDir, "DummyClass").toURI.getPath, - """package sparkrtest; + withTempDir { tempDir => + val srcDir = new File(tempDir, "sparkrtest") + srcDir.mkdirs() + val excSource = new JavaSourceFromString(new File(srcDir, "DummyClass").toURI.getPath, + """package sparkrtest; | |public class DummyClass implements java.io.Serializable { | public static String helloWorld(String arg) { return "Hello " + arg; } | public static int addStuff(int arg1, int arg2) { return arg1 + arg2; } |} - """.stripMargin) - val excFile = TestUtils.createCompiledClass("DummyClass", srcDir, excSource, Seq.empty) - val jarFile = new File(tempDir, "sparkRTestJar-%s.jar".format(System.currentTimeMillis())) - val jarURL = TestUtils.createJar(Seq(excFile), jarFile, directoryPrefix = Some("sparkrtest")) + """. + stripMargin) + val excFile = TestUtils.createCompiledClass("DummyClass", srcDir, excSource, Seq.empty) + val jarFile = new File(tempDir, "sparkRTestJar-%s.jar".format(System.currentTimeMillis())) + val jarURL = TestUtils.createJar(Seq(excFile), jarFile, directoryPrefix = Some("sparkrtest")) - val args = Seq( - "--name", "testApp", - "--master", "local", - "--jars", jarURL.toString, - "--verbose", - "--conf", "spark.ui.enabled=false", - rScriptDir) - runSparkSubmit(args) + val args = Seq( + "--name", "testApp", + "--master", "local", + "--jars", jarURL.toString, + "--verbose", + "--conf", "spark.ui.enabled=false", + rScriptDir) + runSparkSubmit(args) + } } test("resolves command line argument paths correctly") { - val dir = Utils.createTempDir() - val archive = Paths.get(dir.toPath.toString, "single.zip") - Files.createFile(archive) - val jars = "/jar1,/jar2" - val files = "local:/file1,file2" - val archives = s"file:/archive1,${dir.toPath.toAbsolutePath.toString}/*.zip#archive3" - val pyFiles = "py-file1,py-file2" - - // Test jars and files - val clArgs = Seq( - "--master", "local", - "--class", "org.SomeClass", - "--jars", jars, - "--files", files, - "thejar.jar") - val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) - appArgs.jars should be (Utils.resolveURIs(jars)) - appArgs.files should be (Utils.resolveURIs(files)) - conf.get("spark.jars") should be (Utils.resolveURIs(jars + ",thejar.jar")) - conf.get("spark.files") should be (Utils.resolveURIs(files)) - - // Test files and archives (Yarn) - val clArgs2 = Seq( - "--master", "yarn", - "--class", "org.SomeClass", - "--files", files, - "--archives", archives, - "thejar.jar" - ) - val appArgs2 = new SparkSubmitArguments(clArgs2) - val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) - appArgs2.files should be (Utils.resolveURIs(files)) - appArgs2.archives should fullyMatch regex ("file:/archive1,file:.*#archive3") - conf2.get("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) - conf2.get("spark.yarn.dist.archives") should fullyMatch regex - ("file:/archive1,file:.*#archive3") - - // Test python files - val clArgs3 = Seq( - "--master", "local", - "--py-files", pyFiles, - "--conf", "spark.pyspark.driver.python=python3.4", - "--conf", "spark.pyspark.python=python3.5", - "mister.py" - ) - val appArgs3 = new SparkSubmitArguments(clArgs3) - val (_, _, conf3, _) = submit.prepareSubmitEnvironment(appArgs3) - appArgs3.pyFiles should be (Utils.resolveURIs(pyFiles)) - conf3.get("spark.submit.pyFiles") should be ( - PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) - conf3.get(PYSPARK_DRIVER_PYTHON.key) should be ("python3.4") - conf3.get(PYSPARK_PYTHON.key) should be ("python3.5") + withTempDir { dir => + val archive = Paths.get(dir.toPath.toString, "single.zip") + Files.createFile(archive) + val jars = "/jar1,/jar2" + val files = "local:/file1,file2" + val archives = s"file:/archive1,${dir.toPath.toAbsolutePath.toString}/*.zip#archive3" + val pyFiles = "py-file1,py-file2" + + // Test jars and files + val clArgs = Seq( + "--master", "local", + "--class", "org.SomeClass", + "--jars", jars, + "--files", files, + "thejar.jar") + val appArgs = new SparkSubmitArguments(clArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) + appArgs.jars should be(Utils.resolveURIs(jars)) + appArgs.files should be(Utils.resolveURIs(files)) + conf.get("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar")) + conf.get("spark.files") should be(Utils.resolveURIs(files)) + + // Test files and archives (Yarn) + val clArgs2 = Seq( + "--master", "yarn", + "--class", "org.SomeClass", + "--files", files, + "--archives", archives, + "thejar.jar" + ) + val appArgs2 = new SparkSubmitArguments(clArgs2) + val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) + appArgs2.files should be(Utils.resolveURIs(files)) + appArgs2.archives should fullyMatch regex ("file:/archive1,file:.*#archive3") + conf2.get("spark.yarn.dist.files") should be(Utils.resolveURIs(files)) + conf2.get("spark.yarn.dist.archives") should fullyMatch regex + ("file:/archive1,file:.*#archive3") + + // Test python files + val clArgs3 = Seq( + "--master", "local", + "--py-files", pyFiles, + "--conf", "spark.pyspark.driver.python=python3.4", + "--conf", "spark.pyspark.python=python3.5", + "mister.py" + ) + val appArgs3 = new SparkSubmitArguments(clArgs3) + val (_, _, conf3, _) = submit.prepareSubmitEnvironment(appArgs3) + appArgs3.pyFiles should be(Utils.resolveURIs(pyFiles)) + conf3.get("spark.submit.pyFiles") should be( + PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) + conf3.get(PYSPARK_DRIVER_PYTHON.key) should be("python3.4") + conf3.get(PYSPARK_PYTHON.key) should be("python3.5") + } } test("ambiguous archive mapping results in error message") { - val dir = Utils.createTempDir() - val archive1 = Paths.get(dir.toPath.toString, "first.zip") - val archive2 = Paths.get(dir.toPath.toString, "second.zip") - Files.createFile(archive1) - Files.createFile(archive2) - val jars = "/jar1,/jar2" - val files = "local:/file1,file2" - val archives = s"file:/archive1,${dir.toPath.toAbsolutePath.toString}/*.zip#archive3" - val pyFiles = "py-file1,py-file2" - - // Test files and archives (Yarn) - val clArgs2 = Seq( - "--master", "yarn", - "--class", "org.SomeClass", - "--files", files, - "--archives", archives, - "thejar.jar" - ) + withTempDir { dir => + val archive1 = Paths.get(dir.toPath.toString, "first.zip") + val archive2 = Paths.get(dir.toPath.toString, "second.zip") + Files.createFile(archive1) + Files.createFile(archive2) + val jars = "/jar1,/jar2" + val files = "local:/file1,file2" + val archives = s"file:/archive1,${dir.toPath.toAbsolutePath.toString}/*.zip#archive3" + val pyFiles = "py-file1,py-file2" + + // Test files and archives (Yarn) + val clArgs2 = Seq( + "--master", "yarn", + "--class", "org.SomeClass", + "--files", files, + "--archives", archives, + "thejar.jar" + ) - testPrematureExit(clArgs2.toArray, "resolves ambiguously to multiple files") + testPrematureExit(clArgs2.toArray, "resolves ambiguously to multiple files") + } } test("resolves config paths correctly") { @@ -724,77 +724,77 @@ class SparkSubmitSuite val archives = "file:/archive1,archive2" // spark.yarn.dist.archives val pyFiles = "py-file1,py-file2" // spark.submit.pyFiles - val tmpDir = Utils.createTempDir() - - // Test jars and files - val f1 = File.createTempFile("test-submit-jars-files", "", tmpDir) - val writer1 = new PrintWriter(f1) - writer1.println("spark.jars " + jars) - writer1.println("spark.files " + files) - writer1.close() - val clArgs = Seq( - "--master", "local", - "--class", "org.SomeClass", - "--properties-file", f1.getPath, - "thejar.jar" - ) - val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) - conf.get("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar")) - conf.get("spark.files") should be(Utils.resolveURIs(files)) - - // Test files and archives (Yarn) - val f2 = File.createTempFile("test-submit-files-archives", "", tmpDir) - val writer2 = new PrintWriter(f2) - writer2.println("spark.yarn.dist.files " + files) - writer2.println("spark.yarn.dist.archives " + archives) - writer2.close() - val clArgs2 = Seq( - "--master", "yarn", - "--class", "org.SomeClass", - "--properties-file", f2.getPath, - "thejar.jar" - ) - val appArgs2 = new SparkSubmitArguments(clArgs2) - val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) - conf2.get("spark.yarn.dist.files") should be(Utils.resolveURIs(files)) - conf2.get("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives)) - - // Test python files - val f3 = File.createTempFile("test-submit-python-files", "", tmpDir) - val writer3 = new PrintWriter(f3) - writer3.println("spark.submit.pyFiles " + pyFiles) - writer3.close() - val clArgs3 = Seq( - "--master", "local", - "--properties-file", f3.getPath, - "mister.py" - ) - val appArgs3 = new SparkSubmitArguments(clArgs3) - val (_, _, conf3, _) = submit.prepareSubmitEnvironment(appArgs3) - conf3.get("spark.submit.pyFiles") should be( - PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) - - // Test remote python files - val hadoopConf = new Configuration() - updateConfWithFakeS3Fs(hadoopConf) - val f4 = File.createTempFile("test-submit-remote-python-files", "", tmpDir) - val pyFile1 = File.createTempFile("file1", ".py", tmpDir) - val pyFile2 = File.createTempFile("file2", ".py", tmpDir) - val writer4 = new PrintWriter(f4) - val remotePyFiles = s"s3a://${pyFile1.getAbsolutePath},s3a://${pyFile2.getAbsolutePath}" - writer4.println("spark.submit.pyFiles " + remotePyFiles) - writer4.close() - val clArgs4 = Seq( - "--master", "yarn", - "--deploy-mode", "cluster", - "--properties-file", f4.getPath, - "hdfs:///tmp/mister.py" - ) - val appArgs4 = new SparkSubmitArguments(clArgs4) - val (_, _, conf4, _) = submit.prepareSubmitEnvironment(appArgs4, conf = Some(hadoopConf)) - // Should not format python path for yarn cluster mode - conf4.get("spark.submit.pyFiles") should be(Utils.resolveURIs(remotePyFiles)) + withTempDir { tmpDir => + // Test jars and files + val f1 = File.createTempFile("test-submit-jars-files", "", tmpDir) + val writer1 = new PrintWriter(f1) + writer1.println("spark.jars " + jars) + writer1.println("spark.files " + files) + writer1.close() + val clArgs = Seq( + "--master", "local", + "--class", "org.SomeClass", + "--properties-file", f1.getPath, + "thejar.jar" + ) + val appArgs = new SparkSubmitArguments(clArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) + conf.get("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar")) + conf.get("spark.files") should be(Utils.resolveURIs(files)) + + // Test files and archives (Yarn) + val f2 = File.createTempFile("test-submit-files-archives", "", tmpDir) + val writer2 = new PrintWriter(f2) + writer2.println("spark.yarn.dist.files " + files) + writer2.println("spark.yarn.dist.archives " + archives) + writer2.close() + val clArgs2 = Seq( + "--master", "yarn", + "--class", "org.SomeClass", + "--properties-file", f2.getPath, + "thejar.jar" + ) + val appArgs2 = new SparkSubmitArguments(clArgs2) + val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) + conf2.get("spark.yarn.dist.files") should be(Utils.resolveURIs(files)) + conf2.get("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives)) + + // Test python files + val f3 = File.createTempFile("test-submit-python-files", "", tmpDir) + val writer3 = new PrintWriter(f3) + writer3.println("spark.submit.pyFiles " + pyFiles) + writer3.close() + val clArgs3 = Seq( + "--master", "local", + "--properties-file", f3.getPath, + "mister.py" + ) + val appArgs3 = new SparkSubmitArguments(clArgs3) + val (_, _, conf3, _) = submit.prepareSubmitEnvironment(appArgs3) + conf3.get("spark.submit.pyFiles") should be( + PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) + + // Test remote python files + val hadoopConf = new Configuration() + updateConfWithFakeS3Fs(hadoopConf) + val f4 = File.createTempFile("test-submit-remote-python-files", "", tmpDir) + val pyFile1 = File.createTempFile("file1", ".py", tmpDir) + val pyFile2 = File.createTempFile("file2", ".py", tmpDir) + val writer4 = new PrintWriter(f4) + val remotePyFiles = s"s3a://${pyFile1.getAbsolutePath},s3a://${pyFile2.getAbsolutePath}" + writer4.println("spark.submit.pyFiles " + remotePyFiles) + writer4.close() + val clArgs4 = Seq( + "--master", "yarn", + "--deploy-mode", "cluster", + "--properties-file", f4.getPath, + "hdfs:///tmp/mister.py" + ) + val appArgs4 = new SparkSubmitArguments(clArgs4) + val (_, _, conf4, _) = submit.prepareSubmitEnvironment(appArgs4, conf = Some(hadoopConf)) + // Should not format python path for yarn cluster mode + conf4.get("spark.submit.pyFiles") should be(Utils.resolveURIs(remotePyFiles)) + } } test("user classpath first in driver") { @@ -828,46 +828,50 @@ class SparkSubmitSuite } test("support glob path") { - val tmpJarDir = Utils.createTempDir() - val jar1 = TestUtils.createJarWithFiles(Map("test.resource" -> "1"), tmpJarDir) - val jar2 = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpJarDir) - - val tmpFileDir = Utils.createTempDir() - val file1 = File.createTempFile("tmpFile1", "", tmpFileDir) - val file2 = File.createTempFile("tmpFile2", "", tmpFileDir) - - val tmpPyFileDir = Utils.createTempDir() - val pyFile1 = File.createTempFile("tmpPy1", ".py", tmpPyFileDir) - val pyFile2 = File.createTempFile("tmpPy2", ".egg", tmpPyFileDir) - - val tmpArchiveDir = Utils.createTempDir() - val archive1 = File.createTempFile("archive1", ".zip", tmpArchiveDir) - val archive2 = File.createTempFile("archive2", ".zip", tmpArchiveDir) - - val tempPyFile = File.createTempFile("tmpApp", ".py") - tempPyFile.deleteOnExit() - - val args = Seq( - "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), - "--name", "testApp", - "--master", "yarn", - "--deploy-mode", "client", - "--jars", s"${tmpJarDir.getAbsolutePath}/*.jar", - "--files", s"${tmpFileDir.getAbsolutePath}/tmpFile*", - "--py-files", s"${tmpPyFileDir.getAbsolutePath}/tmpPy*", - "--archives", s"${tmpArchiveDir.getAbsolutePath}/*.zip", - tempPyFile.toURI().toString()) - - val appArgs = new SparkSubmitArguments(args) - val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) - conf.get("spark.yarn.dist.jars").split(",").toSet should be - (Set(jar1.toURI.toString, jar2.toURI.toString)) - conf.get("spark.yarn.dist.files").split(",").toSet should be - (Set(file1.toURI.toString, file2.toURI.toString)) - conf.get("spark.yarn.dist.pyFiles").split(",").toSet should be - (Set(pyFile1.getAbsolutePath, pyFile2.getAbsolutePath)) - conf.get("spark.yarn.dist.archives").split(",").toSet should be - (Set(archive1.toURI.toString, archive2.toURI.toString)) + withTempDir { tmpJarDir => + withTempDir { tmpFileDir => + withTempDir { tmpPyFileDir => + withTempDir { tmpArchiveDir => + val jar1 = TestUtils.createJarWithFiles(Map("test.resource" -> "1"), tmpJarDir) + val jar2 = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpJarDir) + + val file1 = File.createTempFile("tmpFile1", "", tmpFileDir) + val file2 = File.createTempFile("tmpFile2", "", tmpFileDir) + + val pyFile1 = File.createTempFile("tmpPy1", ".py", tmpPyFileDir) + val pyFile2 = File.createTempFile("tmpPy2", ".egg", tmpPyFileDir) + + val archive1 = File.createTempFile("archive1", ".zip", tmpArchiveDir) + val archive2 = File.createTempFile("archive2", ".zip", tmpArchiveDir) + + val tempPyFile = File.createTempFile("tmpApp", ".py") + tempPyFile.deleteOnExit() + + val args = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), + "--name", "testApp", + "--master", "yarn", + "--deploy-mode", "client", + "--jars", s"${tmpJarDir.getAbsolutePath}/*.jar", + "--files", s"${tmpFileDir.getAbsolutePath}/tmpFile*", + "--py-files", s"${tmpPyFileDir.getAbsolutePath}/tmpPy*", + "--archives", s"${tmpArchiveDir.getAbsolutePath}/*.zip", + tempPyFile.toURI().toString()) + + val appArgs = new SparkSubmitArguments(args) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) + conf.get("spark.yarn.dist.jars").split(",").toSet should be + (Set(jar1.toURI.toString, jar2.toURI.toString)) + conf.get("spark.yarn.dist.files").split(",").toSet should be + (Set(file1.toURI.toString, file2.toURI.toString)) + conf.get("spark.yarn.dist.pyFiles").split(",").toSet should be + (Set(pyFile1.getAbsolutePath, pyFile2.getAbsolutePath)) + conf.get("spark.yarn.dist.archives").split(",").toSet should be + (Set(archive1.toURI.toString, archive2.toURI.toString)) + } + } + } + } } // scalastyle:on println @@ -985,37 +989,38 @@ class SparkSubmitSuite val hadoopConf = new Configuration() updateConfWithFakeS3Fs(hadoopConf) - val tmpDir = Utils.createTempDir() - val file = File.createTempFile("tmpFile", "", tmpDir) - val pyFile = File.createTempFile("tmpPy", ".egg", tmpDir) - val mainResource = File.createTempFile("tmpPy", ".py", tmpDir) - val tmpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir) - val tmpJarPath = s"s3a://${new File(tmpJar.toURI).getAbsolutePath}" + withTempDir { tmpDir => + val file = File.createTempFile("tmpFile", "", tmpDir) + val pyFile = File.createTempFile("tmpPy", ".egg", tmpDir) + val mainResource = File.createTempFile("tmpPy", ".py", tmpDir) + val tmpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir) + val tmpJarPath = s"s3a://${new File(tmpJar.toURI).getAbsolutePath}" - val args = Seq( - "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), - "--name", "testApp", - "--master", "yarn", - "--deploy-mode", "client", - "--jars", tmpJarPath, - "--files", s"s3a://${file.getAbsolutePath}", - "--py-files", s"s3a://${pyFile.getAbsolutePath}", - s"s3a://$mainResource" + val args = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), + "--name", "testApp", + "--master", "yarn", + "--deploy-mode", "client", + "--jars", tmpJarPath, + "--files", s"s3a://${file.getAbsolutePath}", + "--py-files", s"s3a://${pyFile.getAbsolutePath}", + s"s3a://$mainResource" ) - val appArgs = new SparkSubmitArguments(args) - val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs, conf = Some(hadoopConf)) + val appArgs = new SparkSubmitArguments(args) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs, conf = Some(hadoopConf)) - // All the resources should still be remote paths, so that YARN client will not upload again. - conf.get("spark.yarn.dist.jars") should be (tmpJarPath) - conf.get("spark.yarn.dist.files") should be (s"s3a://${file.getAbsolutePath}") - conf.get("spark.yarn.dist.pyFiles") should be (s"s3a://${pyFile.getAbsolutePath}") + // All the resources should still be remote paths, so that YARN client will not upload again. + conf.get("spark.yarn.dist.jars") should be(tmpJarPath) + conf.get("spark.yarn.dist.files") should be(s"s3a://${file.getAbsolutePath}") + conf.get("spark.yarn.dist.pyFiles") should be(s"s3a://${pyFile.getAbsolutePath}") - // Local repl jars should be a local path. - conf.get("spark.repl.local.jars") should (startWith("file:")) + // Local repl jars should be a local path. + conf.get("spark.repl.local.jars") should (startWith("file:")) - // local py files should not be a URI format. - conf.get("spark.submit.pyFiles") should (startWith("/")) + // local py files should not be a URI format. + conf.get("spark.submit.pyFiles") should (startWith("/")) + } } test("download remote resource if it is not supported by yarn service") { @@ -1095,18 +1100,13 @@ class SparkSubmitSuite } private def forConfDir(defaults: Map[String, String]) (f: String => Unit) = { - val tmpDir = Utils.createTempDir() - - val defaultsConf = new File(tmpDir.getAbsolutePath, "spark-defaults.conf") - val writer = new OutputStreamWriter(new FileOutputStream(defaultsConf), StandardCharsets.UTF_8) - for ((key, value) <- defaults) writer.write(s"$key $value\n") - - writer.close() - - try { + withTempDir { tmpDir => + val defaultsConf = new File(tmpDir.getAbsolutePath, "spark-defaults.conf") + val writer = + new OutputStreamWriter(new FileOutputStream(defaultsConf), StandardCharsets.UTF_8) + for ((key, value) <- defaults) writer.write(s"$key $value\n") + writer.close() f(tmpDir.getAbsolutePath) - } finally { - Utils.deleteRecursively(tmpDir) } } @@ -1134,39 +1134,40 @@ class SparkSubmitSuite val hadoopConf = new Configuration() updateConfWithFakeS3Fs(hadoopConf) - val tmpDir = Utils.createTempDir() - val pyFile = File.createTempFile("tmpPy", ".egg", tmpDir) + withTempDir { tmpDir => + val pyFile = File.createTempFile("tmpPy", ".egg", tmpDir) - val args = Seq( - "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), - "--name", "testApp", - "--master", "yarn", - "--deploy-mode", "client", - "--py-files", s"s3a://${pyFile.getAbsolutePath}", - "spark-internal" - ) + val args = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), + "--name", "testApp", + "--master", "yarn", + "--deploy-mode", "client", + "--py-files", s"s3a://${pyFile.getAbsolutePath}", + "spark-internal" + ) - val appArgs = new SparkSubmitArguments(args) - val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs, conf = Some(hadoopConf)) + val appArgs = new SparkSubmitArguments(args) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs, conf = Some(hadoopConf)) - conf.get(PY_FILES.key) should be (s"s3a://${pyFile.getAbsolutePath}") - conf.get("spark.submit.pyFiles") should (startWith("/")) + conf.get(PY_FILES.key) should be(s"s3a://${pyFile.getAbsolutePath}") + conf.get("spark.submit.pyFiles") should (startWith("/")) - // Verify "spark.submit.pyFiles" - val args1 = Seq( - "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), - "--name", "testApp", - "--master", "yarn", - "--deploy-mode", "client", - "--conf", s"spark.submit.pyFiles=s3a://${pyFile.getAbsolutePath}", - "spark-internal" - ) + // Verify "spark.submit.pyFiles" + val args1 = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), + "--name", "testApp", + "--master", "yarn", + "--deploy-mode", "client", + "--conf", s"spark.submit.pyFiles=s3a://${pyFile.getAbsolutePath}", + "spark-internal" + ) - val appArgs1 = new SparkSubmitArguments(args1) - val (_, _, conf1, _) = submit.prepareSubmitEnvironment(appArgs1, conf = Some(hadoopConf)) + val appArgs1 = new SparkSubmitArguments(args1) + val (_, _, conf1, _) = submit.prepareSubmitEnvironment(appArgs1, conf = Some(hadoopConf)) - conf1.get(PY_FILES.key) should be (s"s3a://${pyFile.getAbsolutePath}") - conf1.get("spark.submit.pyFiles") should (startWith("/")) + conf1.get(PY_FILES.key) should be(s"s3a://${pyFile.getAbsolutePath}") + conf1.get("spark.submit.pyFiles") should (startWith("/")) + } } test("handles natural line delimiters in --properties-file and --conf uniformly") { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 527c654a7cd68..c1ae27aa940f6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -767,53 +767,54 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } test("clean up stale app information") { - val storeDir = Utils.createTempDir() - val conf = createTestConf().set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) - val clock = new ManualClock() - val provider = spy(new FsHistoryProvider(conf, clock)) - val appId = "new1" - - // Write logs for two app attempts. - clock.advance(1) - val attempt1 = newLogFile(appId, Some("1"), inProgress = false) - writeFile(attempt1, true, None, - SparkListenerApplicationStart(appId, Some(appId), 1L, "test", Some("1")), - SparkListenerJobStart(0, 1L, Nil, null), - SparkListenerApplicationEnd(5L) + withTempDir { storeDir => + val conf = createTestConf().set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) + val clock = new ManualClock() + val provider = spy(new FsHistoryProvider(conf, clock)) + val appId = "new1" + + // Write logs for two app attempts. + clock.advance(1) + val attempt1 = newLogFile(appId, Some("1"), inProgress = false) + writeFile(attempt1, true, None, + SparkListenerApplicationStart(appId, Some(appId), 1L, "test", Some("1")), + SparkListenerJobStart(0, 1L, Nil, null), + SparkListenerApplicationEnd(5L) ) - val attempt2 = newLogFile(appId, Some("2"), inProgress = false) - writeFile(attempt2, true, None, - SparkListenerApplicationStart(appId, Some(appId), 1L, "test", Some("2")), - SparkListenerJobStart(0, 1L, Nil, null), - SparkListenerApplicationEnd(5L) + val attempt2 = newLogFile(appId, Some("2"), inProgress = false) + writeFile(attempt2, true, None, + SparkListenerApplicationStart(appId, Some(appId), 1L, "test", Some("2")), + SparkListenerJobStart(0, 1L, Nil, null), + SparkListenerApplicationEnd(5L) ) - updateAndCheck(provider) { list => - assert(list.size === 1) - assert(list(0).id === appId) - assert(list(0).attempts.size === 2) - } - - // Load the app's UI. - val ui = provider.getAppUI(appId, Some("1")) - assert(ui.isDefined) - - // Delete the underlying log file for attempt 1 and rescan. The UI should go away, but since - // attempt 2 still exists, listing data should be there. - clock.advance(1) - attempt1.delete() - updateAndCheck(provider) { list => - assert(list.size === 1) - assert(list(0).id === appId) - assert(list(0).attempts.size === 1) - } - assert(!ui.get.valid) - assert(provider.getAppUI(appId, None) === None) + updateAndCheck(provider) { list => + assert(list.size === 1) + assert(list(0).id === appId) + assert(list(0).attempts.size === 2) + } - // Delete the second attempt's log file. Now everything should go away. - clock.advance(1) - attempt2.delete() - updateAndCheck(provider) { list => - assert(list.isEmpty) + // Load the app's UI. + val ui = provider.getAppUI(appId, Some("1")) + assert(ui.isDefined) + + // Delete the underlying log file for attempt 1 and rescan. The UI should go away, but since + // attempt 2 still exists, listing data should be there. + clock.advance(1) + attempt1.delete() + updateAndCheck(provider) { list => + assert(list.size === 1) + assert(list(0).id === appId) + assert(list(0).attempts.size === 1) + } + assert(!ui.get.valid) + assert(provider.getAppUI(appId, None) === None) + + // Delete the second attempt's log file. Now everything should go away. + clock.advance(1) + attempt2.delete() + updateAndCheck(provider) { list => + assert(list.isEmpty) + } } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala index 37954826af90c..e89733a144cfa 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala @@ -41,18 +41,14 @@ class HistoryServerArgumentsSuite extends SparkFunSuite { } test("Properties File Arguments Parsing --properties-file") { - val tmpDir = Utils.createTempDir() - val outFile = File.createTempFile("test-load-spark-properties", "test", tmpDir) - try { + withTempDir { tmpDir => + val outFile = File.createTempFile("test-load-spark-properties", "test", tmpDir) Files.write("spark.test.CustomPropertyA blah\n" + "spark.test.CustomPropertyB notblah\n", outFile, UTF_8) val argStrings = Array("--properties-file", outFile.getAbsolutePath) val hsa = new HistoryServerArguments(conf, argStrings) assert(conf.get("spark.test.CustomPropertyA") === "blah") assert(conf.get("spark.test.CustomPropertyB") === "notblah") - } finally { - Utils.deleteRecursively(tmpDir) } } - } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala index 62fe0eaedfd27..30278655dbe0d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala @@ -31,14 +31,11 @@ import org.apache.spark.util.Utils class PersistenceEngineSuite extends SparkFunSuite { test("FileSystemPersistenceEngine") { - val dir = Utils.createTempDir() - try { + withTempDir { dir => val conf = new SparkConf() testPersistenceEngine(conf, serializer => new FileSystemPersistenceEngine(dir.getAbsolutePath, serializer) ) - } finally { - Utils.deleteRecursively(dir) } } diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala index 817dc082b7d38..576ca1613f75e 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala @@ -59,9 +59,7 @@ class WholeTextFileInputFormatSuite extends SparkFunSuite with BeforeAndAfterAll test("for small files minimum split size per node and per rack should be less than or equal to " + "maximum split size.") { - var dir : File = null; - try { - dir = Utils.createTempDir() + withTempDir { dir => logInfo(s"Local disk address is ${dir.toString}.") // Set the minsize per node and rack to be larger than the size of the input file. @@ -75,8 +73,6 @@ class WholeTextFileInputFormatSuite extends SparkFunSuite with BeforeAndAfterAll } // ensure spark job runs successfully without exceptions from the CombineFileInputFormat assert(sc.wholeTextFiles(dir.toString).count == 3) - } finally { - Utils.deleteRecursively(dir) } } } diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index ddf73d6370631..47552916adb22 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -89,52 +89,50 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAl * 3) Does the contents be the same. */ test("Correctness of WholeTextFileRecordReader.") { - val dir = Utils.createTempDir() - logInfo(s"Local disk address is ${dir.toString}.") + withTempDir { dir => + logInfo(s"Local disk address is ${dir.toString}.") - WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => - createNativeFile(dir, filename, contents, false) - } + WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => + createNativeFile(dir, filename, contents, false) + } - val res = sc.wholeTextFiles(dir.toString, 3).collect() + val res = sc.wholeTextFiles(dir.toString, 3).collect() - assert(res.size === WholeTextFileRecordReaderSuite.fileNames.size, - "Number of files read out does not fit with the actual value.") + assert(res.size === WholeTextFileRecordReaderSuite.fileNames.size, + "Number of files read out does not fit with the actual value.") - for ((filename, contents) <- res) { - val shortName = filename.split('/').last - assert(WholeTextFileRecordReaderSuite.fileNames.contains(shortName), - s"Missing file name $filename.") - assert(contents === new Text(WholeTextFileRecordReaderSuite.files(shortName)).toString, - s"file $filename contents can not match.") + for ((filename, contents) <- res) { + val shortName = filename.split('/').last + assert(WholeTextFileRecordReaderSuite.fileNames.contains(shortName), + s"Missing file name $filename.") + assert(contents === new Text(WholeTextFileRecordReaderSuite.files(shortName)).toString, + s"file $filename contents can not match.") + } } - - Utils.deleteRecursively(dir) } test("Correctness of WholeTextFileRecordReader with GzipCodec.") { - val dir = Utils.createTempDir() - logInfo(s"Local disk address is ${dir.toString}.") + withTempDir { dir => + logInfo(s"Local disk address is ${dir.toString}.") - WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => - createNativeFile(dir, filename, contents, true) - } + WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => + createNativeFile(dir, filename, contents, true) + } - val res = sc.wholeTextFiles(dir.toString, 3).collect() + val res = sc.wholeTextFiles(dir.toString, 3).collect() - assert(res.size === WholeTextFileRecordReaderSuite.fileNames.size, - "Number of files read out does not fit with the actual value.") + assert(res.size === WholeTextFileRecordReaderSuite.fileNames.size, + "Number of files read out does not fit with the actual value.") - for ((filename, contents) <- res) { - val shortName = filename.split('/').last.split('.')(0) + for ((filename, contents) <- res) { + val shortName = filename.split('/').last.split('.')(0) - assert(WholeTextFileRecordReaderSuite.fileNames.contains(shortName), - s"Missing file name $filename.") - assert(contents === new Text(WholeTextFileRecordReaderSuite.files(shortName)).toString, - s"file $filename contents can not match.") + assert(WholeTextFileRecordReaderSuite.fileNames.contains(shortName), + s"Missing file name $filename.") + assert(contents === new Text(WholeTextFileRecordReaderSuite.files(shortName)).toString, + s"file $filename contents can not match.") + } } - - Utils.deleteRecursively(dir) } } diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 0ec359d1c94f3..945b09441ea9a 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -470,15 +470,12 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } test("zero-partition RDD") { - val emptyDir = Utils.createTempDir() - try { + withTempDir { emptyDir => val file = sc.textFile(emptyDir.getAbsolutePath) assert(file.partitions.isEmpty) assert(file.collect().toList === Nil) // Test that a shuffle on the file works, because this used to be a bug assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) - } finally { - Utils.deleteRecursively(emptyDir) } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index a799b1cfb0765..5cb2b561d6bce 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -822,63 +822,66 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } test("file server") { - val conf = new SparkConf() - val tempDir = Utils.createTempDir() - val file = new File(tempDir, "file") - Files.write(UUID.randomUUID().toString(), file, UTF_8) - val fileWithSpecialChars = new File(tempDir, "file name") - Files.write(UUID.randomUUID().toString(), fileWithSpecialChars, UTF_8) - val empty = new File(tempDir, "empty") - Files.write("", empty, UTF_8); - val jar = new File(tempDir, "jar") - Files.write(UUID.randomUUID().toString(), jar, UTF_8) - - val dir1 = new File(tempDir, "dir1") - assert(dir1.mkdir()) - val subFile1 = new File(dir1, "file1") - Files.write(UUID.randomUUID().toString(), subFile1, UTF_8) - - val dir2 = new File(tempDir, "dir2") - assert(dir2.mkdir()) - val subFile2 = new File(dir2, "file2") - Files.write(UUID.randomUUID().toString(), subFile2, UTF_8) - - val fileUri = env.fileServer.addFile(file) - val fileWithSpecialCharsUri = env.fileServer.addFile(fileWithSpecialChars) - val emptyUri = env.fileServer.addFile(empty) - val jarUri = env.fileServer.addJar(jar) - val dir1Uri = env.fileServer.addDirectory("/dir1", dir1) - val dir2Uri = env.fileServer.addDirectory("/dir2", dir2) - - // Try registering directories with invalid names. - Seq("/files", "/jars").foreach { uri => - intercept[IllegalArgumentException] { - env.fileServer.addDirectory(uri, dir1) - } - } + withTempDir { tempDir => + withTempDir { destDir => + val conf = new SparkConf() + + val file = new File(tempDir, "file") + Files.write(UUID.randomUUID().toString(), file, UTF_8) + val fileWithSpecialChars = new File(tempDir, "file name") + Files.write(UUID.randomUUID().toString(), fileWithSpecialChars, UTF_8) + val empty = new File(tempDir, "empty") + Files.write("", empty, UTF_8); + val jar = new File(tempDir, "jar") + Files.write(UUID.randomUUID().toString(), jar, UTF_8) + + val dir1 = new File(tempDir, "dir1") + assert(dir1.mkdir()) + val subFile1 = new File(dir1, "file1") + Files.write(UUID.randomUUID().toString(), subFile1, UTF_8) + + val dir2 = new File(tempDir, "dir2") + assert(dir2.mkdir()) + val subFile2 = new File(dir2, "file2") + Files.write(UUID.randomUUID().toString(), subFile2, UTF_8) + + val fileUri = env.fileServer.addFile(file) + val fileWithSpecialCharsUri = env.fileServer.addFile(fileWithSpecialChars) + val emptyUri = env.fileServer.addFile(empty) + val jarUri = env.fileServer.addJar(jar) + val dir1Uri = env.fileServer.addDirectory("/dir1", dir1) + val dir2Uri = env.fileServer.addDirectory("/dir2", dir2) + + // Try registering directories with invalid names. + Seq("/files", "/jars").foreach { uri => + intercept[IllegalArgumentException] { + env.fileServer.addDirectory(uri, dir1) + } + } - val destDir = Utils.createTempDir() - val sm = new SecurityManager(conf) - val hc = SparkHadoopUtil.get.conf - - val files = Seq( - (file, fileUri), - (fileWithSpecialChars, fileWithSpecialCharsUri), - (empty, emptyUri), - (jar, jarUri), - (subFile1, dir1Uri + "/file1"), - (subFile2, dir2Uri + "/file2")) - files.foreach { case (f, uri) => - val destFile = new File(destDir, f.getName()) - Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) - assert(Files.equal(f, destFile)) - } + val sm = new SecurityManager(conf) + val hc = SparkHadoopUtil.get.conf + + val files = Seq( + (file, fileUri), + (fileWithSpecialChars, fileWithSpecialCharsUri), + (empty, emptyUri), + (jar, jarUri), + (subFile1, dir1Uri + "/file1"), + (subFile2, dir2Uri + "/file2")) + files.foreach { case (f, uri) => + val destFile = new File(destDir, f.getName()) + Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) + assert(Files.equal(f, destFile)) + } - // Try to download files that do not exist. - Seq("files", "jars", "dir1").foreach { root => - intercept[Exception] { - val uri = env.address.toSparkURL + s"/$root/doesNotExist" - Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) + // Try to download files that do not exist. + Seq("files", "jars", "dir1").foreach { root => + intercept[Exception] { + val uri = env.address.toSparkURL + s"/$root/doesNotExist" + Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) + } + } } } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 5f4ffa151d19b..ed6a3d93b312f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -2831,18 +2831,22 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } test("SPARK-23207: reliable checkpoint can avoid rollback (checkpointed before)") { - sc.setCheckpointDir(Utils.createTempDir().getCanonicalPath) - val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true) - shuffleMapRdd.checkpoint() - shuffleMapRdd.doCheckpoint() - assertResultStageNotRollbacked(shuffleMapRdd) + withTempDir { dir => + sc.setCheckpointDir(dir.getCanonicalPath) + val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true) + shuffleMapRdd.checkpoint() + shuffleMapRdd.doCheckpoint() + assertResultStageNotRollbacked(shuffleMapRdd) + } } test("SPARK-23207: reliable checkpoint fail to rollback (checkpointing now)") { - sc.setCheckpointDir(Utils.createTempDir().getCanonicalPath) - val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true) - shuffleMapRdd.checkpoint() - assertResultStageFailToRollback(shuffleMapRdd) + withTempDir { dir => + sc.setCheckpointDir(dir.getCanonicalPath) + val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true) + shuffleMapRdd.checkpoint() + assertResultStageFailToRollback(shuffleMapRdd) + } } /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala index d6ff5bb33055c..848f702935536 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala @@ -49,11 +49,8 @@ class OutputCommitCoordinatorIntegrationSuite test("exception thrown in OutputCommitter.commitTask()") { // Regression test for SPARK-10381 failAfter(Span(60, Seconds)) { - val tempDir = Utils.createTempDir() - try { + withTempDir { tempDir => sc.parallelize(1 to 4, 2).map(_.toString).saveAsTextFile(tempDir.getAbsolutePath + "/out") - } finally { - Utils.deleteRecursively(tempDir) } } } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index a7eed4b6a8b88..467e49026a029 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -369,27 +369,27 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { } test("SPARK-12222: deserialize RoaringBitmap throw Buffer underflow exception") { - val dir = Utils.createTempDir() - val tmpfile = dir.toString + "/RoaringBitmap" - val outStream = new FileOutputStream(tmpfile) - val output = new KryoOutput(outStream) - val bitmap = new RoaringBitmap - bitmap.add(1) - bitmap.add(3) - bitmap.add(5) - // Ignore Kryo because it doesn't use writeObject - bitmap.serialize(new KryoOutputObjectOutputBridge(null, output)) - output.flush() - output.close() - - val inStream = new FileInputStream(tmpfile) - val input = new KryoInput(inStream) - val ret = new RoaringBitmap - // Ignore Kryo because it doesn't use readObject - ret.deserialize(new KryoInputObjectInputBridge(null, input)) - input.close() - assert(ret == bitmap) - Utils.deleteRecursively(dir) + withTempDir { dir => + val tmpfile = dir.toString + "/RoaringBitmap" + val outStream = new FileOutputStream(tmpfile) + val output = new KryoOutput(outStream) + val bitmap = new RoaringBitmap + bitmap.add(1) + bitmap.add(3) + bitmap.add(5) + // Ignore Kryo because it doesn't use writeObject + bitmap.serialize(new KryoOutputObjectOutputBridge(null, output)) + output.flush() + output.close() + + val inStream = new FileInputStream(tmpfile) + val input = new KryoInput(inStream) + val ret = new RoaringBitmap + // Ignore Kryo because it doesn't use readObject + ret.deserialize(new KryoInputObjectInputBridge(null, input)) + input.close() + assert(ret == bitmap) + } } test("KryoOutputObjectOutputBridge.writeObject and KryoInputObjectInputBridge.readObject") { diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala index eec961a491101..959cf58fa0536 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala @@ -132,7 +132,6 @@ class DiskStoreSuite extends SparkFunSuite { } test("block data encryption") { - val testDir = Utils.createTempDir() val testData = new Array[Byte](128 * 1024) new Random().nextBytes(testData) diff --git a/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala b/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala index f9e1b791c86ea..e48f0014fbbd6 100644 --- a/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala @@ -50,34 +50,34 @@ class PeriodicRDDCheckpointerSuite extends SparkFunSuite with SharedSparkContext } test("Checkpointing") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - val checkpointInterval = 2 - var rddsToCheck = Seq.empty[RDDToCheck] - sc.setCheckpointDir(path) - val rdd1 = createRDD(sc) - val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, rdd1.sparkContext) - checkpointer.update(rdd1) - rdd1.count() - rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1) - checkCheckpoint(rddsToCheck, 1, checkpointInterval) - - var iteration = 2 - while (iteration < 9) { - val rdd = createRDD(sc) - checkpointer.update(rdd) - rdd.count() - rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration) - checkCheckpoint(rddsToCheck, iteration, checkpointInterval) - iteration += 1 - } + withTempDir { tempDir => + val path = tempDir.toURI.toString + val checkpointInterval = 2 + var rddsToCheck = Seq.empty[RDDToCheck] + sc.setCheckpointDir(path) + val rdd1 = createRDD(sc) + val checkpointer = + new PeriodicRDDCheckpointer[Double](checkpointInterval, rdd1.sparkContext) + checkpointer.update(rdd1) + rdd1.count() + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1) + checkCheckpoint(rddsToCheck, 1, checkpointInterval) + + var iteration = 2 + while (iteration < 9) { + val rdd = createRDD(sc) + checkpointer.update(rdd) + rdd.count() + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration) + checkCheckpoint(rddsToCheck, iteration, checkpointInterval) + iteration += 1 + } - checkpointer.deleteAllCheckpoints() - rddsToCheck.foreach { rdd => - confirmCheckpointRemoved(rdd.rdd) + checkpointer.deleteAllCheckpoints() + rddsToCheck.foreach { rdd => + confirmCheckpointRemoved(rdd.rdd) + } } - - Utils.deleteRecursively(tempDir) } } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index f5e912b50d1ab..901a724da8a1b 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -295,31 +295,30 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { private val workerConf = new SparkConf() def testOffsetBytes(isCompressed: Boolean): Unit = { - val tmpDir2 = Utils.createTempDir() - val suffix = getSuffix(isCompressed) - val f1Path = tmpDir2 + "/f1" + suffix - writeLogFile(f1Path, "1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(StandardCharsets.UTF_8)) - val f1Length = Utils.getFileLength(new File(f1Path), workerConf) + withTempDir { tmpDir2 => + val suffix = getSuffix(isCompressed) + val f1Path = tmpDir2 + "/f1" + suffix + writeLogFile(f1Path, "1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(StandardCharsets.UTF_8)) + val f1Length = Utils.getFileLength(new File(f1Path), workerConf) - // Read first few bytes - assert(Utils.offsetBytes(f1Path, f1Length, 0, 5) === "1\n2\n3") + // Read first few bytes + assert(Utils.offsetBytes(f1Path, f1Length, 0, 5) === "1\n2\n3") - // Read some middle bytes - assert(Utils.offsetBytes(f1Path, f1Length, 4, 11) === "3\n4\n5\n6") + // Read some middle bytes + assert(Utils.offsetBytes(f1Path, f1Length, 4, 11) === "3\n4\n5\n6") - // Read last few bytes - assert(Utils.offsetBytes(f1Path, f1Length, 12, 18) === "7\n8\n9\n") + // Read last few bytes + assert(Utils.offsetBytes(f1Path, f1Length, 12, 18) === "7\n8\n9\n") - // Read some nonexistent bytes in the beginning - assert(Utils.offsetBytes(f1Path, f1Length, -5, 5) === "1\n2\n3") + // Read some nonexistent bytes in the beginning + assert(Utils.offsetBytes(f1Path, f1Length, -5, 5) === "1\n2\n3") - // Read some nonexistent bytes at the end - assert(Utils.offsetBytes(f1Path, f1Length, 12, 22) === "7\n8\n9\n") + // Read some nonexistent bytes at the end + assert(Utils.offsetBytes(f1Path, f1Length, 12, 22) === "7\n8\n9\n") - // Read some nonexistent bytes on both ends - assert(Utils.offsetBytes(f1Path, f1Length, -3, 25) === "1\n2\n3\n4\n5\n6\n7\n8\n9\n") - - Utils.deleteRecursively(tmpDir2) + // Read some nonexistent bytes on both ends + assert(Utils.offsetBytes(f1Path, f1Length, -3, 25) === "1\n2\n3\n4\n5\n6\n7\n8\n9\n") + } } test("reading offset bytes of a file") { @@ -331,41 +330,41 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { } def testOffsetBytesMultipleFiles(isCompressed: Boolean): Unit = { - val tmpDir = Utils.createTempDir() - val suffix = getSuffix(isCompressed) - val files = (1 to 3).map(i => new File(tmpDir, i.toString + suffix)) :+ new File(tmpDir, "4") - writeLogFile(files(0).getAbsolutePath, "0123456789".getBytes(StandardCharsets.UTF_8)) - writeLogFile(files(1).getAbsolutePath, "abcdefghij".getBytes(StandardCharsets.UTF_8)) - writeLogFile(files(2).getAbsolutePath, "ABCDEFGHIJ".getBytes(StandardCharsets.UTF_8)) - writeLogFile(files(3).getAbsolutePath, "9876543210".getBytes(StandardCharsets.UTF_8)) - val fileLengths = files.map(Utils.getFileLength(_, workerConf)) - - // Read first few bytes in the 1st file - assert(Utils.offsetBytes(files, fileLengths, 0, 5) === "01234") + withTempDir { tmpDir => + val suffix = getSuffix(isCompressed) + val files = (1 to 3).map(i => + new File(tmpDir, i.toString + suffix)) :+ new File(tmpDir, "4") + writeLogFile(files(0).getAbsolutePath, "0123456789".getBytes(StandardCharsets.UTF_8)) + writeLogFile(files(1).getAbsolutePath, "abcdefghij".getBytes(StandardCharsets.UTF_8)) + writeLogFile(files(2).getAbsolutePath, "ABCDEFGHIJ".getBytes(StandardCharsets.UTF_8)) + writeLogFile(files(3).getAbsolutePath, "9876543210".getBytes(StandardCharsets.UTF_8)) + val fileLengths = files.map(Utils.getFileLength(_, workerConf)) - // Read bytes within the 1st file - assert(Utils.offsetBytes(files, fileLengths, 5, 8) === "567") + // Read first few bytes in the 1st file + assert(Utils.offsetBytes(files, fileLengths, 0, 5) === "01234") - // Read bytes across 1st and 2nd file - assert(Utils.offsetBytes(files, fileLengths, 8, 18) === "89abcdefgh") + // Read bytes within the 1st file + assert(Utils.offsetBytes(files, fileLengths, 5, 8) === "567") - // Read bytes across 1st, 2nd and 3rd file - assert(Utils.offsetBytes(files, fileLengths, 5, 24) === "56789abcdefghijABCD") + // Read bytes across 1st and 2nd file + assert(Utils.offsetBytes(files, fileLengths, 8, 18) === "89abcdefgh") - // Read bytes across 3rd and 4th file - assert(Utils.offsetBytes(files, fileLengths, 25, 35) === "FGHIJ98765") + // Read bytes across 1st, 2nd and 3rd file + assert(Utils.offsetBytes(files, fileLengths, 5, 24) === "56789abcdefghijABCD") - // Read some nonexistent bytes in the beginning - assert(Utils.offsetBytes(files, fileLengths, -5, 18) === "0123456789abcdefgh") + // Read bytes across 3rd and 4th file + assert(Utils.offsetBytes(files, fileLengths, 25, 35) === "FGHIJ98765") - // Read some nonexistent bytes at the end - assert(Utils.offsetBytes(files, fileLengths, 18, 45) === "ijABCDEFGHIJ9876543210") + // Read some nonexistent bytes in the beginning + assert(Utils.offsetBytes(files, fileLengths, -5, 18) === "0123456789abcdefgh") - // Read some nonexistent bytes on both ends - assert(Utils.offsetBytes(files, fileLengths, -5, 45) === - "0123456789abcdefghijABCDEFGHIJ9876543210") + // Read some nonexistent bytes at the end + assert(Utils.offsetBytes(files, fileLengths, 18, 45) === "ijABCDEFGHIJ9876543210") - Utils.deleteRecursively(tmpDir) + // Read some nonexistent bytes on both ends + assert(Utils.offsetBytes(files, fileLengths, -5, 45) === + "0123456789abcdefghijABCDEFGHIJ9876543210") + } } test("reading offset bytes across multiple files") { @@ -427,27 +426,28 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { test("doesDirectoryContainFilesNewerThan") { // create some temporary directories and files - val parent: File = Utils.createTempDir() - // The parent directory has two child directories - val child1: File = Utils.createTempDir(parent.getCanonicalPath) - val child2: File = Utils.createTempDir(parent.getCanonicalPath) - val child3: File = Utils.createTempDir(child1.getCanonicalPath) - // set the last modified time of child1 to 30 secs old - child1.setLastModified(System.currentTimeMillis() - (1000 * 30)) - - // although child1 is old, child2 is still new so return true - assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) - - child2.setLastModified(System.currentTimeMillis - (1000 * 30)) - assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) - - parent.setLastModified(System.currentTimeMillis - (1000 * 30)) - // although parent and its immediate children are new, child3 is still old - // we expect a full recursive search for new files. - assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) - - child3.setLastModified(System.currentTimeMillis - (1000 * 30)) - assert(!Utils.doesDirectoryContainAnyNewFiles(parent, 5)) + withTempDir { parent => + // The parent directory has two child directories + val child1: File = Utils.createTempDir(parent.getCanonicalPath) + val child2: File = Utils.createTempDir(parent.getCanonicalPath) + val child3: File = Utils.createTempDir(child1.getCanonicalPath) + // set the last modified time of child1 to 30 secs old + child1.setLastModified(System.currentTimeMillis() - (1000 * 30)) + + // although child1 is old, child2 is still new so return true + assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) + + child2.setLastModified(System.currentTimeMillis - (1000 * 30)) + assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) + + parent.setLastModified(System.currentTimeMillis - (1000 * 30)) + // although parent and its immediate children are new, child3 is still old + // we expect a full recursive search for new files. + assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) + + child3.setLastModified(System.currentTimeMillis - (1000 * 30)) + assert(!Utils.doesDirectoryContainAnyNewFiles(parent, 5)) + } } test("resolveURI") { @@ -608,9 +608,8 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { } test("loading properties from file") { - val tmpDir = Utils.createTempDir() - val outFile = File.createTempFile("test-load-spark-properties", "test", tmpDir) - try { + withTempDir { tmpDir => + val outFile = File.createTempFile("test-load-spark-properties", "test", tmpDir) System.setProperty("spark.test.fileNameLoadB", "2") Files.write("spark.test.fileNameLoadA true\n" + "spark.test.fileNameLoadB 1\n", outFile, StandardCharsets.UTF_8) @@ -621,8 +620,6 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { val sparkConf = new SparkConf assert(sparkConf.getBoolean("spark.test.fileNameLoadA", false) === true) assert(sparkConf.getInt("spark.test.fileNameLoadB", 1) === 2) - } finally { - Utils.deleteRecursively(tmpDir) } } @@ -638,52 +635,53 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { } test("fetch hcfs dir") { - val tempDir = Utils.createTempDir() - val sourceDir = new File(tempDir, "source-dir") - sourceDir.mkdir() - val innerSourceDir = Utils.createTempDir(root = sourceDir.getPath) - val sourceFile = File.createTempFile("someprefix", "somesuffix", innerSourceDir) - val targetDir = new File(tempDir, "target-dir") - Files.write("some text", sourceFile, StandardCharsets.UTF_8) - - val path = - if (Utils.isWindows) { - new Path("file:/" + sourceDir.getAbsolutePath.replace("\\", "/")) - } else { - new Path("file://" + sourceDir.getAbsolutePath) - } - val conf = new Configuration() - val fs = Utils.getHadoopFileSystem(path.toString, conf) + withTempDir { tempDir => + val sourceDir = new File(tempDir, "source-dir") + sourceDir.mkdir() + val innerSourceDir = Utils.createTempDir(root = sourceDir.getPath) + val sourceFile = File.createTempFile("someprefix", "somesuffix", innerSourceDir) + val targetDir = new File(tempDir, "target-dir") + Files.write("some text", sourceFile, StandardCharsets.UTF_8) + + val path = + if (Utils.isWindows) { + new Path("file:/" + sourceDir.getAbsolutePath.replace("\\", "/")) + } else { + new Path("file://" + sourceDir.getAbsolutePath) + } + val conf = new Configuration() + val fs = Utils.getHadoopFileSystem(path.toString, conf) - assert(!targetDir.isDirectory()) - Utils.fetchHcfsFile(path, targetDir, fs, new SparkConf(), conf, false) - assert(targetDir.isDirectory()) + assert(!targetDir.isDirectory()) + Utils.fetchHcfsFile(path, targetDir, fs, new SparkConf(), conf, false) + assert(targetDir.isDirectory()) - // Copy again to make sure it doesn't error if the dir already exists. - Utils.fetchHcfsFile(path, targetDir, fs, new SparkConf(), conf, false) + // Copy again to make sure it doesn't error if the dir already exists. + Utils.fetchHcfsFile(path, targetDir, fs, new SparkConf(), conf, false) - val destDir = new File(targetDir, sourceDir.getName()) - assert(destDir.isDirectory()) + val destDir = new File(targetDir, sourceDir.getName()) + assert(destDir.isDirectory()) - val destInnerDir = new File(destDir, innerSourceDir.getName) - assert(destInnerDir.isDirectory()) + val destInnerDir = new File(destDir, innerSourceDir.getName) + assert(destInnerDir.isDirectory()) - val destInnerFile = new File(destInnerDir, sourceFile.getName) - assert(destInnerFile.isFile()) + val destInnerFile = new File(destInnerDir, sourceFile.getName) + assert(destInnerFile.isFile()) - val filePath = - if (Utils.isWindows) { - new Path("file:/" + sourceFile.getAbsolutePath.replace("\\", "/")) - } else { - new Path("file://" + sourceFile.getAbsolutePath) - } - val testFileDir = new File(tempDir, "test-filename") - val testFileName = "testFName" - val testFilefs = Utils.getHadoopFileSystem(filePath.toString, conf) - Utils.fetchHcfsFile(filePath, testFileDir, testFilefs, new SparkConf(), - conf, false, Some(testFileName)) - val newFileName = new File(testFileDir, testFileName) - assert(newFileName.isFile()) + val filePath = + if (Utils.isWindows) { + new Path("file:/" + sourceFile.getAbsolutePath.replace("\\", "/")) + } else { + new Path("file://" + sourceFile.getAbsolutePath) + } + val testFileDir = new File(tempDir, "test-filename") + val testFileName = "testFName" + val testFilefs = Utils.getHadoopFileSystem(filePath.toString, conf) + Utils.fetchHcfsFile(filePath, testFileDir, testFilefs, new SparkConf(), + conf, false, Some(testFileName)) + val newFileName = new File(testFileDir, testFileName) + assert(newFileName.isFile()) + } } test("shutdown hook manager") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 23419493e5368..85963ec4ca699 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -66,6 +66,17 @@ private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with } } + /** + * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` + * returns. + */ + protected override def withTempDir(f: File => Unit): Unit = { + super.withTempDir { dir => + f(dir) + waitForTasksToFinish() + } + } + /** * A helper function for turning off/on codegen. */ @@ -143,43 +154,6 @@ private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with test(name) { runOnThread() } } } -} - -/** - * Helper trait that can be extended by all external SQL test suites. - * - * This allows subclasses to plugin a custom `SQLContext`. - * To use implicit methods, import `testImplicits._` instead of through the `SQLContext`. - * - * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is - * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. - */ -private[sql] trait SQLTestUtilsBase - extends Eventually - with BeforeAndAfterAll - with SQLTestData - with PlanTestBase { self: Suite => - - protected def sparkContext = spark.sparkContext - - // Shorthand for running a query using our SQLContext - protected lazy val sql = spark.sql _ - - /** - * A helper object for importing SQL implicits. - * - * Note that the alternative of importing `spark.implicits._` is not possible here. - * This is because we create the `SQLContext` immediately before the first test is run, - * but the implicits import is needed in the constructor. - */ - protected object testImplicits extends SQLImplicits { - protected override def _sqlContext: SQLContext = self.spark.sqlContext - } - - protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - SparkSession.setActiveSession(spark) - super.withSQLConf(pairs: _*)(f) - } /** * Copy file in jar's resource to a temp file, then pass it to `f`. @@ -206,21 +180,6 @@ private[sql] trait SQLTestUtilsBase } } - /** - * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` - * returns. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withTempDir(f: File => Unit): Unit = { - val dir = Utils.createTempDir().getCanonicalFile - try f(dir) finally { - // wait for all tasks to finish before deleting files - waitForTasksToFinish() - Utils.deleteRecursively(dir) - } - } - /** * Creates the specified number of temporary directories, which is then passed to `f` and will be * deleted after `f` returns. @@ -233,6 +192,43 @@ private[sql] trait SQLTestUtilsBase files.foreach(Utils.deleteRecursively) } } +} + +/** + * Helper trait that can be extended by all external SQL test suites. + * + * This allows subclasses to plugin a custom `SQLContext`. + * To use implicit methods, import `testImplicits._` instead of through the `SQLContext`. + * + * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is + * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. + */ +private[sql] trait SQLTestUtilsBase + extends Eventually + with BeforeAndAfterAll + with SQLTestData + with PlanTestBase { self: Suite => + + protected def sparkContext = spark.sparkContext + + // Shorthand for running a query using our SQLContext + protected lazy val sql = spark.sql _ + + /** + * A helper object for importing SQL implicits. + * + * Note that the alternative of importing `spark.implicits._` is not possible here. + * This is because we create the `SQLContext` immediately before the first test is run, + * but the implicits import is needed in the constructor. + */ + protected object testImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self.spark.sqlContext + } + + protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + SparkSession.setActiveSession(spark) + super.withSQLConf(pairs: _*)(f) + } /** * Drops functions after calling `f`. A function is represented by (functionName, isTemporary). diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index dc96ec416afd8..218bd18e5dc99 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -55,15 +55,6 @@ class VersionsSuite extends SparkFunSuite with Logging { import HiveClientBuilder.buildClient - /** - * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` - * returns. - */ - protected def withTempDir(f: File => Unit): Unit = { - val dir = Utils.createTempDir().getCanonicalFile - try f(dir) finally Utils.deleteRecursively(dir) - } - /** * Drops table `tableName` after calling `f`. */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index ada494eb897f3..6a0f523e4b49b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -557,16 +557,4 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { verifyOutput[W](output.toSeq, expectedOutput, useSet) } } - - /** - * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` - * returns. - * (originally from `SqlTestUtils`.) - * @todo Probably this method should be moved to a more general place - */ - protected def withTempDir(f: File => Unit): Unit = { - val dir = Utils.createTempDir().getCanonicalFile - try f(dir) finally Utils.deleteRecursively(dir) - } - } From 1abfbda7eb9bd855d70ba64fc137ecc101e1d8b0 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 1 Dec 2018 07:06:18 -0600 Subject: [PATCH 2186/2461] [SPARK-26212][BUILD][TEST-MAVEN] Upgrade maven version to 3.6.0 ## What changes were proposed in this pull request? This PR updates maven version from 3.5.4 to 3.6.0. The release note of the 3.6.0 is [here](https://maven.apache.org/docs/3.6.0/release-notes.html). From [the release note of the 3.6.0](https://maven.apache.org/docs/3.6.0/release-notes.html), the followings are new features: 1. There had been issues related to the project discoverytime which has been increased in previous version which influenced some of our users. 1. The output in the reactor summary has been improved. 1. There was an issue related to the classpath ordering. ## How was this patch tested? Existing tests Closes #23177 from kiszk/SPARK-26212. Authored-by: Kazuaki Ishizaki Signed-off-by: Sean Owen --- dev/appveyor-install-dependencies.ps1 | 2 +- docs/building-spark.md | 2 +- pom.xml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/appveyor-install-dependencies.ps1 b/dev/appveyor-install-dependencies.ps1 index cc68ffb90d875..7c7bdd623477a 100644 --- a/dev/appveyor-install-dependencies.ps1 +++ b/dev/appveyor-install-dependencies.ps1 @@ -81,7 +81,7 @@ if (!(Test-Path $tools)) { # ========================== Maven Push-Location $tools -$mavenVer = "3.5.4" +$mavenVer = "3.6.0" Start-FileDownload "https://archive.apache.org/dist/maven/maven-3/$mavenVer/binaries/apache-maven-$mavenVer-bin.zip" "maven.zip" # extract diff --git a/docs/building-spark.md b/docs/building-spark.md index dfcd53c48e85c..55695f35931c6 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -12,7 +12,7 @@ redirect_from: "building-with-maven.html" ## Apache Maven The Maven-based build is the build of reference for Apache Spark. -Building Spark using Maven requires Maven 3.5.4 and Java 8. +Building Spark using Maven requires Maven 3.6.0 and Java 8. Note that support for Java 7 was removed as of Spark 2.2.0. ### Setting up Maven's Memory Usage diff --git a/pom.xml b/pom.xml index dfc3c540dc18e..61321a1450708 100644 --- a/pom.xml +++ b/pom.xml @@ -114,7 +114,7 @@ 1.8 ${java.version} ${java.version} - 3.5.4 + 3.6.0 spark 1.7.16 1.2.17 From 60e4239a1e3506d342099981b6e3b3b8431a203e Mon Sep 17 00:00:00 2001 From: liuxian Date: Sat, 1 Dec 2018 07:11:31 -0600 Subject: [PATCH 2187/2461] [MINOR][DOC] Correct some document description errors ## What changes were proposed in this pull request? Correct some document description errors. ## How was this patch tested? N/A Closes #23162 from 10110346/docerror. Authored-by: liuxian Signed-off-by: Sean Owen --- .../org/apache/spark/internal/config/package.scala | 10 +++++----- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 9cc48f6375003..646b3881a79b0 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -281,7 +281,7 @@ package object config { private[spark] val LISTENER_BUS_EVENT_QUEUE_CAPACITY = ConfigBuilder("spark.scheduler.listenerbus.eventqueue.capacity") .intConf - .checkValue(_ > 0, "The capacity of listener bus event queue must not be negative") + .checkValue(_ > 0, "The capacity of listener bus event queue must be positive") .createWithDefault(10000) private[spark] val LISTENER_BUS_METRICS_MAX_LISTENER_CLASSES_TIMED = @@ -430,8 +430,8 @@ package object config { .doc("The chunk size in bytes during writing out the bytes of ChunkedByteBuffer.") .bytesConf(ByteUnit.BYTE) .checkValue(_ <= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH, - "The chunk size during writing out the bytes of" + - " ChunkedByteBuffer should not larger than Int.MaxValue - 15.") + "The chunk size during writing out the bytes of ChunkedByteBuffer should" + + s" be less than or equal to ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") .createWithDefault(64 * 1024 * 1024) private[spark] val CHECKPOINT_COMPRESS = @@ -503,7 +503,7 @@ package object config { "made in creating intermediate shuffle files.") .bytesConf(ByteUnit.KiB) .checkValue(v => v > 0 && v <= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH / 1024, - s"The file buffer size must be greater than 0 and less than" + + s"The file buffer size must be positive and less than or equal to" + s" ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH / 1024}.") .createWithDefaultString("32k") @@ -513,7 +513,7 @@ package object config { "is written in unsafe shuffle writer. In KiB unless otherwise specified.") .bytesConf(ByteUnit.KiB) .checkValue(v => v > 0 && v <= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH / 1024, - s"The buffer size must be greater than 0 and less than" + + s"The buffer size must be positive and less than or equal to" + s" ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH / 1024}.") .createWithDefaultString("32k") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f1c845bc94507..c4f00d723c252 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -326,7 +326,7 @@ object SQLConf { "factor as the estimated data size, in case the data is compressed in the file and lead to" + " a heavily underestimated result.") .doubleConf - .checkValue(_ > 0, "the value of fileDataSizeFactor must be larger than 0") + .checkValue(_ > 0, "the value of fileDataSizeFactor must be greater than 0") .createWithDefault(1.0) val PARQUET_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.parquet.mergeSchema") @@ -673,7 +673,7 @@ object SQLConf { val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets") .doc("The maximum number of buckets allowed. Defaults to 100000") .intConf - .checkValue(_ > 0, "the value of spark.sql.sources.bucketing.maxBuckets must be larger than 0") + .checkValue(_ > 0, "the value of spark.sql.sources.bucketing.maxBuckets must be greater than 0") .createWithDefault(100000) val CROSS_JOINS_ENABLED = buildConf("spark.sql.crossJoin.enabled") @@ -1154,7 +1154,7 @@ object SQLConf { .internal() .doc("The number of bins when generating histograms.") .intConf - .checkValue(num => num > 1, "The number of bins must be larger than 1.") + .checkValue(num => num > 1, "The number of bins must be greater than 1.") .createWithDefault(254) val PERCENTILE_ACCURACY = From 55c96858107739dd768abea1dff88bd970e47e9f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 1 Dec 2018 16:22:38 -0800 Subject: [PATCH 2188/2461] [SPARK-26226][SQL] Track optimization phase for streaming queries ## What changes were proposed in this pull request? In an earlier PR, we missed measuring the optimization phase time for streaming queries. This patch adds it. ## How was this patch tested? Given this is a debugging feature, and it is very convoluted to add tests to verify the phase is set properly, I am not introducing a streaming specific test. Closes #23193 from rxin/SPARK-26226-1. Authored-by: Reynold Xin Signed-off-by: gatorsmile --- .../spark/sql/execution/streaming/IncrementalExecution.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index fad287e28877d..a73e88c19ba9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession, Strategy} +import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, ExpressionWithRandomSeed} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, HashPartitioning, SinglePartition} @@ -73,7 +74,8 @@ class IncrementalExecution( * Walk the optimized logical plan and replace CurrentBatchTimestamp * with the desired literal */ - override lazy val optimizedPlan: LogicalPlan = { + override + lazy val optimizedPlan: LogicalPlan = tracker.measurePhase(QueryPlanningTracker.OPTIMIZATION) { sparkSession.sessionState.optimizer.execute(withCachedData) transformAllExpressions { case ts @ CurrentBatchTimestamp(timestamp, _, _) => logInfo(s"Current batch timestamp = $timestamp") From cbb9bb96d292d6e738f2f33637fb1c9715b167ac Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 1 Dec 2018 16:24:06 -0800 Subject: [PATCH 2189/2461] [SPARK-26241][SQL] Add queryId to IncrementalExecution ## What changes were proposed in this pull request? This is a small change for better debugging: to pass query uuid in IncrementalExecution, when we look at the QueryExecution in isolation to trace back the query. ## How was this patch tested? N/A - just add some field for better debugging. Closes #23192 from rxin/SPARK-26241. Authored-by: Reynold Xin Signed-off-by: gatorsmile --- .../scala/org/apache/spark/sql/execution/command/commands.scala | 2 +- .../spark/sql/execution/streaming/IncrementalExecution.scala | 1 + .../spark/sql/execution/streaming/MicroBatchExecution.scala | 1 + .../execution/streaming/continuous/ContinuousExecution.scala | 1 + 4 files changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index ab40936eb3cc9..754a3316ffb7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -154,7 +154,7 @@ case class ExplainCommand( // output mode does not matter since there is no `Sink`. new IncrementalExecution( sparkSession, logicalPlan, OutputMode.Append(), "", - UUID.randomUUID, 0, OffsetSeqMetadata(0, 0)) + UUID.randomUUID, UUID.randomUUID, 0, OffsetSeqMetadata(0, 0)) } else { sparkSession.sessionState.executePlan(logicalPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index a73e88c19ba9e..af52af0d1d7e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -42,6 +42,7 @@ class IncrementalExecution( logicalPlan: LogicalPlan, val outputMode: OutputMode, val checkpointLocation: String, + val queryId: UUID, val runId: UUID, val currentBatchId: Long, val offsetSeqMetadata: OffsetSeqMetadata) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 5defca391a355..64e09edf27f58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -521,6 +521,7 @@ class MicroBatchExecution( triggerLogicalPlan, outputMode, checkpointFile("state"), + id, runId, currentBatchId, offsetSeqMetadata) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index af23c5cd3d80a..4d42428fd189e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -199,6 +199,7 @@ class ContinuousExecution( withSink, outputMode, checkpointFile("state"), + id, runId, currentBatchId, offsetSeqMetadata) From 17fdca7c1bab94e6e54b25807344b06a78780cf6 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sun, 2 Dec 2018 10:22:22 +0800 Subject: [PATCH 2190/2461] [SPARK-26211][SQL][TEST][FOLLOW-UP] Combine test cases for `In` and `InSet`. ## What changes were proposed in this pull request? This is a follow pr of #23176. `In` and `InSet` are semantically equal, so the tests for `In` should pass with `InSet`, and vice versa. This combines those test cases. ## How was this patch tested? The combined tests and existing tests. Closes #23187 from ueshin/issues/SPARK-26211/in_inset_tests. Authored-by: Takuya UESHIN Signed-off-by: Wenchen Fan --- .../catalyst/expressions/PredicateSuite.scala | 160 ++++++++---------- 1 file changed, 66 insertions(+), 94 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 3b60d1d88b3c6..0f63717f9daf2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -124,34 +124,43 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { (null, false, null) :: (null, null, null) :: Nil) - test("basic IN predicate test") { - checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1), + private def checkInAndInSet(in: In, expected: Any): Unit = { + // expecting all in.list are Literal or NonFoldableLiteral. + checkEvaluation(in, expected) + checkEvaluation(InSet(in.value, HashSet() ++ in.list.map(_.eval())), expected) + } + + test("basic IN/INSET predicate test") { + checkInAndInSet(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1), Literal(2))), null) - checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), + checkInAndInSet(In(NonFoldableLiteral.create(null, IntegerType), Seq(NonFoldableLiteral.create(null, IntegerType))), null) - checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null) - checkEvaluation(In(Literal(1), Seq.empty), false) - checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null) - checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), + checkInAndInSet(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null) + checkInAndInSet(In(Literal(1), Seq.empty), false) + checkInAndInSet(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null) + checkInAndInSet(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), true) - checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), + checkInAndInSet(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), null) - checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) - checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) - checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) + checkInAndInSet(In(Literal(1), Seq(Literal(1), Literal(2))), true) + checkInAndInSet(In(Literal(2), Seq(Literal(1), Literal(2))), true) + checkInAndInSet(In(Literal(3), Seq(Literal(1), Literal(2))), false) + checkEvaluation( And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))), true) + checkEvaluation( + And(InSet(Literal(1), HashSet(1, 2)), InSet(Literal(2), Set(1, 2))), + true) val ns = NonFoldableLiteral.create(null, StringType) - checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null) - checkEvaluation(In(ns, Seq(ns)), null) - checkEvaluation(In(Literal("a"), Seq(ns)), null) - checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"), ns)), true) - checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true) - checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false) - + checkInAndInSet(In(ns, Seq(Literal("1"), Literal("2"))), null) + checkInAndInSet(In(ns, Seq(ns)), null) + checkInAndInSet(In(Literal("a"), Seq(ns)), null) + checkInAndInSet(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"), ns)), true) + checkInAndInSet(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true) + checkInAndInSet(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false) } test("IN with different types") { @@ -187,11 +196,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } else { false } - checkEvaluation(In(input(0), input.slice(1, 10)), expected) + checkInAndInSet(In(input(0), input.slice(1, 10)), expected) } val atomicTypes = DataTypeTestUtils.atomicTypes.filter { t => - RandomDataGenerator.forType(t).isDefined && !t.isInstanceOf[DecimalType] + RandomDataGenerator.forType(t).isDefined && + !t.isInstanceOf[DecimalType] && !t.isInstanceOf[BinaryType] } ++ Seq(DecimalType.USER_DEFAULT) val atomicArrayTypes = atomicTypes.map(ArrayType(_, containsNull = true)) @@ -252,93 +262,55 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { assert(ctx.inlinedMutableStates.isEmpty) } - test("INSET") { - val hS = HashSet[Any]() + 1 + 2 - val nS = HashSet[Any]() + 1 + 2 + null - val one = Literal(1) - val two = Literal(2) - val three = Literal(3) - val nl = Literal(null) - checkEvaluation(InSet(one, hS), true) - checkEvaluation(InSet(two, hS), true) - checkEvaluation(InSet(two, nS), true) - checkEvaluation(InSet(three, hS), false) - checkEvaluation(InSet(three, nS), null) - checkEvaluation(InSet(nl, hS), null) - checkEvaluation(InSet(nl, nS), null) - - val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType, - LongType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) - primitiveTypes.foreach { t => - val dataGen = RandomDataGenerator.forType(t, nullable = true).get - val inputData = Seq.fill(10) { - val value = dataGen.apply() - value match { - case d: Double if d.isNaN => 0.0d - case f: Float if f.isNaN => 0.0f - case _ => value - } - } - val input = inputData.map(Literal(_)) - val expected = if (inputData(0) == null) { - null - } else if (inputData.slice(1, 10).contains(inputData(0))) { - true - } else if (inputData.slice(1, 10).contains(null)) { - null - } else { - false - } - checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet), expected) - } - } - - test("INSET: binary") { - val hS = HashSet[Any]() + Array(1.toByte, 2.toByte) + Array(3.toByte) - val nS = HashSet[Any]() + Array(1.toByte, 2.toByte) + Array(3.toByte) + null + test("IN/INSET: binary") { val onetwo = Literal(Array(1.toByte, 2.toByte)) val three = Literal(Array(3.toByte)) val threefour = Literal(Array(3.toByte, 4.toByte)) - val nl = Literal(null, onetwo.dataType) - checkEvaluation(InSet(onetwo, hS), true) - checkEvaluation(InSet(three, hS), true) - checkEvaluation(InSet(three, nS), true) - checkEvaluation(InSet(threefour, hS), false) - checkEvaluation(InSet(threefour, nS), null) - checkEvaluation(InSet(nl, hS), null) - checkEvaluation(InSet(nl, nS), null) + val nl = NonFoldableLiteral.create(null, onetwo.dataType) + val hS = Seq(Literal(Array(1.toByte, 2.toByte)), Literal(Array(3.toByte))) + val nS = Seq(Literal(Array(1.toByte, 2.toByte)), Literal(Array(3.toByte)), + NonFoldableLiteral.create(null, onetwo.dataType)) + checkInAndInSet(In(onetwo, hS), true) + checkInAndInSet(In(three, hS), true) + checkInAndInSet(In(three, nS), true) + checkInAndInSet(In(threefour, hS), false) + checkInAndInSet(In(threefour, nS), null) + checkInAndInSet(In(nl, hS), null) + checkInAndInSet(In(nl, nS), null) } - test("INSET: struct") { - val hS = HashSet[Any]() + Literal.create((1, "a")).value + Literal.create((2, "b")).value - val nS = HashSet[Any]() + Literal.create((1, "a")).value + Literal.create((2, "b")).value + null + test("IN/INSET: struct") { val oneA = Literal.create((1, "a")) val twoB = Literal.create((2, "b")) val twoC = Literal.create((2, "c")) - val nl = Literal(null, oneA.dataType) - checkEvaluation(InSet(oneA, hS), true) - checkEvaluation(InSet(twoB, hS), true) - checkEvaluation(InSet(twoB, nS), true) - checkEvaluation(InSet(twoC, hS), false) - checkEvaluation(InSet(twoC, nS), null) - checkEvaluation(InSet(nl, hS), null) - checkEvaluation(InSet(nl, nS), null) + val nl = NonFoldableLiteral.create(null, oneA.dataType) + val hS = Seq(Literal.create((1, "a")), Literal.create((2, "b"))) + val nS = Seq(Literal.create((1, "a")), Literal.create((2, "b")), + NonFoldableLiteral.create(null, oneA.dataType)) + checkInAndInSet(In(oneA, hS), true) + checkInAndInSet(In(twoB, hS), true) + checkInAndInSet(In(twoB, nS), true) + checkInAndInSet(In(twoC, hS), false) + checkInAndInSet(In(twoC, nS), null) + checkInAndInSet(In(nl, hS), null) + checkInAndInSet(In(nl, nS), null) } - test("INSET: array") { - val hS = HashSet[Any]() + Literal.create(Seq(1, 2)).value + Literal.create(Seq(3)).value - val nS = HashSet[Any]() + Literal.create(Seq(1, 2)).value + Literal.create(Seq(3)).value + null + test("IN/INSET: array") { val onetwo = Literal.create(Seq(1, 2)) val three = Literal.create(Seq(3)) val threefour = Literal.create(Seq(3, 4)) - val nl = Literal(null, onetwo.dataType) - checkEvaluation(InSet(onetwo, hS), true) - checkEvaluation(InSet(three, hS), true) - checkEvaluation(InSet(three, nS), true) - checkEvaluation(InSet(threefour, hS), false) - checkEvaluation(InSet(threefour, nS), null) - checkEvaluation(InSet(nl, hS), null) - checkEvaluation(InSet(nl, nS), null) + val nl = NonFoldableLiteral.create(null, onetwo.dataType) + val hS = Seq(Literal.create(Seq(1, 2)), Literal.create(Seq(3))) + val nS = Seq(Literal.create(Seq(1, 2)), Literal.create(Seq(3)), + NonFoldableLiteral.create(null, onetwo.dataType)) + checkInAndInSet(In(onetwo, hS), true) + checkInAndInSet(In(three, hS), true) + checkInAndInSet(In(three, nS), true) + checkInAndInSet(In(threefour, hS), false) + checkInAndInSet(In(threefour, nS), null) + checkInAndInSet(In(nl, hS), null) + checkInAndInSet(In(nl, nS), null) } private case class MyStruct(a: Long, b: String) From 3e46e3ccd58d0a2d445dff58a52ab1966ce133e8 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 2 Dec 2018 10:29:25 +0800 Subject: [PATCH 2191/2461] [SPARK-26161][SQL] Ignore empty files in load ## What changes were proposed in this pull request? In the PR, I propose filtering out all empty files inside of `FileSourceScanExec` and exclude them from file splits. It should reduce overhead of opening and reading files without any data, and as consequence datasources will not produce empty partitions for such files. ## How was this patch tested? Added a test which creates an empty and non-empty files. If empty files are ignored in load, Text datasource in the `wholetext` mode must create only one partition for non-empty file. Closes #23130 from MaxGekk/ignore-empty-files. Authored-by: Maxim Gekk Signed-off-by: Wenchen Fan --- .../spark/sql/execution/DataSourceScanExec.scala | 4 ++-- .../sql/execution/datasources/json/JsonSuite.scala | 3 +-- .../apache/spark/sql/sources/SaveLoadSuite.scala | 13 +++++++++++++ 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 4faa27c2c1e23..b29d5c76c5f3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -368,7 +368,7 @@ case class FileSourceScanExec( logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") val filesGroupedToBuckets = selectedPartitions.flatMap { p => - p.files.map { f => + p.files.filter(_.getLen > 0).map { f => val hosts = getBlockHosts(getBlockLocations(f), 0, f.getLen) PartitionedFile(p.values, f.getPath.toUri.toString, 0, f.getLen, hosts) } @@ -418,7 +418,7 @@ case class FileSourceScanExec( s"open cost is considered as scanning $openCostInBytes bytes.") val splitFiles = selectedPartitions.flatMap { partition => - partition.files.flatMap { file => + partition.files.filter(_.getLen > 0).flatMap { file => val blockLocations = getBlockLocations(file) if (fsRelation.fileFormat.isSplitable( fsRelation.sparkSession, fsRelation.options, file.getPath)) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index ee5176e23e34d..9d23161c1f24e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1842,7 +1842,6 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val path = dir.getCanonicalPath primitiveFieldAndType .toDF("value") - .repartition(1) .write .text(path) @@ -1910,7 +1909,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { F.count($"dummy").as("valid"), F.count($"_corrupt_record").as("corrupt"), F.count("*").as("count")) - checkAnswer(counts, Row(1, 4, 6)) // null row for empty file + checkAnswer(counts, Row(1, 4, 6)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 12779b46bfe8c..048e4b80c72aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.sources import java.io.File +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Paths} import org.scalatest.BeforeAndAfter @@ -142,4 +144,15 @@ class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndA assert(e.contains(s"Partition column `$unknown` not found in schema $schemaCatalog")) } } + + test("skip empty files in non bucketed read") { + withTempDir { dir => + val path = dir.getCanonicalPath + Files.write(Paths.get(path, "empty"), Array.empty[Byte]) + Files.write(Paths.get(path, "notEmpty"), "a".getBytes(StandardCharsets.UTF_8)) + val readback = spark.read.option("wholetext", true).text(path) + + assert(readback.rdd.getNumPartitions === 1) + } + } } From 39617cb2c0c433494e9f17fcd4e49c6300a9c4b0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 2 Dec 2018 10:46:17 +0800 Subject: [PATCH 2192/2461] [SPARK-26216][SQL] Do not use case class as public API (UserDefinedFunction) ## What changes were proposed in this pull request? It's a bad idea to use case class as public API, as it has a very wide surface. For example, the `copy` method, its fields, the companion object, etc. For a particular case, `UserDefinedFunction`. It has a private constructor, and I believe we only want users to access a few methods:`apply`, `nullable`, `asNonNullable`, etc. However, all its fields, and `copy` method, and the companion object are public unexpectedly. As a result, we made many tricks to work around the binary compatibility issues. This PR proposes to only make interfaces public, and hide implementations behind with a private class. Now `UserDefinedFunction` is a pure trait, and the concrete implementation is `SparkUserDefinedFunction`, which is private. Changing class to interface is not binary compatible(but source compatible), so 3.0 is a good chance to do it. This is the first PR to go with this direction. If it's accepted, I'll create a umbrella JIRA and fix all the public case classes. ## How was this patch tested? existing tests. Closes #23178 from cloud-fan/udf. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- docs/sql-migration-guide-upgrade.md | 2 + project/MimaExcludes.scala | 6 +- .../sql/expressions/UserDefinedFunction.scala | 119 ++++++++---------- .../org/apache/spark/sql/functions.scala | 2 +- 4 files changed, 63 insertions(+), 66 deletions(-) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index e48125a0972b5..787f4bcbbea82 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -31,6 +31,8 @@ displayTitle: Spark SQL Upgrading Guide - In Spark version 2.4 and earlier, the `SET` command works without any warnings even if the specified key is for `SparkConf` entries and it has no effect because the command does not update `SparkConf`, but the behavior might confuse users. Since 3.0, the command fails if a `SparkConf` key is used. You can disable such a check by setting `spark.sql.legacy.execution.setCommandRejectsSparkConfs` to `false`. + - Spark applications which are built with Spark version 2.4 and prior, and call methods of `UserDefinedFunction`, need to be re-compiled with Spark 3.0, as they are not binary compatible with Spark 3.0. + ## Upgrading From Spark SQL 2.3 to 2.4 - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below. diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index fcef424c330f1..1c83cf5860c58 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -227,7 +227,11 @@ object MimaExcludes { case ReversedMissingMethodProblem(meth) => !meth.owner.fullName.startsWith("org.apache.spark.sql.sources.v2") case _ => true - } + }, + + // [SPARK-26216][SQL] Do not use case class as public API (UserDefinedFunction) + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.UserDefinedFunction$"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.sql.expressions.UserDefinedFunction") ) // Exclude rules for 2.4.x diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 58a942afe28c3..f88e0e0f299de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -38,25 +38,14 @@ import org.apache.spark.sql.types.DataType * @since 1.3.0 */ @Stable -case class UserDefinedFunction protected[sql] ( - f: AnyRef, - dataType: DataType, - inputTypes: Option[Seq[DataType]]) { - - private var _nameOption: Option[String] = None - private var _nullable: Boolean = true - private var _deterministic: Boolean = true - - // This is a `var` instead of in the constructor for backward compatibility of this case class. - // TODO: revisit this case class in Spark 3.0, and narrow down the public surface. - private[sql] var nullableTypes: Option[Seq[Boolean]] = None +sealed trait UserDefinedFunction { /** * Returns true when the UDF can return a nullable value. * * @since 2.3.0 */ - def nullable: Boolean = _nullable + def nullable: Boolean /** * Returns true iff the UDF is deterministic, i.e. the UDF produces the same output given the same @@ -64,7 +53,7 @@ case class UserDefinedFunction protected[sql] ( * * @since 2.3.0 */ - def deterministic: Boolean = _deterministic + def deterministic: Boolean /** * Returns an expression that invokes the UDF, using the given arguments. @@ -72,80 +61,83 @@ case class UserDefinedFunction protected[sql] ( * @since 1.3.0 */ @scala.annotation.varargs - def apply(exprs: Column*): Column = { + def apply(exprs: Column*): Column + + /** + * Updates UserDefinedFunction with a given name. + * + * @since 2.3.0 + */ + def withName(name: String): UserDefinedFunction + + /** + * Updates UserDefinedFunction to non-nullable. + * + * @since 2.3.0 + */ + def asNonNullable(): UserDefinedFunction + + /** + * Updates UserDefinedFunction to nondeterministic. + * + * @since 2.3.0 + */ + def asNondeterministic(): UserDefinedFunction +} + +private[sql] case class SparkUserDefinedFunction( + f: AnyRef, + dataType: DataType, + inputTypes: Option[Seq[DataType]], + nullableTypes: Option[Seq[Boolean]], + name: Option[String] = None, + nullable: Boolean = true, + deterministic: Boolean = true) extends UserDefinedFunction { + + @scala.annotation.varargs + override def apply(exprs: Column*): Column = { // TODO: make sure this class is only instantiated through `SparkUserDefinedFunction.create()` // and `nullableTypes` is always set. - if (nullableTypes.isEmpty) { - nullableTypes = Some(ScalaReflection.getParameterTypeNullability(f)) - } if (inputTypes.isDefined) { assert(inputTypes.get.length == nullableTypes.get.length) } + val inputsNullSafe = nullableTypes.getOrElse { + ScalaReflection.getParameterTypeNullability(f) + } + Column(ScalaUDF( f, dataType, exprs.map(_.expr), - nullableTypes.get, + inputsNullSafe, inputTypes.getOrElse(Nil), - udfName = _nameOption, - nullable = _nullable, - udfDeterministic = _deterministic)) - } - - private def copyAll(): UserDefinedFunction = { - val udf = copy() - udf._nameOption = _nameOption - udf._nullable = _nullable - udf._deterministic = _deterministic - udf.nullableTypes = nullableTypes - udf + udfName = name, + nullable = nullable, + udfDeterministic = deterministic)) } - /** - * Updates UserDefinedFunction with a given name. - * - * @since 2.3.0 - */ - def withName(name: String): UserDefinedFunction = { - val udf = copyAll() - udf._nameOption = Option(name) - udf + override def withName(name: String): UserDefinedFunction = { + copy(name = Option(name)) } - /** - * Updates UserDefinedFunction to non-nullable. - * - * @since 2.3.0 - */ - def asNonNullable(): UserDefinedFunction = { + override def asNonNullable(): UserDefinedFunction = { if (!nullable) { this } else { - val udf = copyAll() - udf._nullable = false - udf + copy(nullable = false) } } - /** - * Updates UserDefinedFunction to nondeterministic. - * - * @since 2.3.0 - */ - def asNondeterministic(): UserDefinedFunction = { - if (!_deterministic) { + override def asNondeterministic(): UserDefinedFunction = { + if (!deterministic) { this } else { - val udf = copyAll() - udf._deterministic = false - udf + copy(deterministic = false) } } } -// We have to use a name different than `UserDefinedFunction` here, to avoid breaking the binary -// compatibility of the auto-generate UserDefinedFunction object. private[sql] object SparkUserDefinedFunction { def create( @@ -157,8 +149,7 @@ private[sql] object SparkUserDefinedFunction { } else { Some(inputSchemas.map(_.get.dataType)) } - val udf = new UserDefinedFunction(f, dataType, inputTypes) - udf.nullableTypes = Some(inputSchemas.map(_.map(_.nullable).getOrElse(true))) - udf + val nullableTypes = Some(inputSchemas.map(_.map(_.nullable).getOrElse(true))) + SparkUserDefinedFunction(f, dataType, inputTypes, nullableTypes) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index efa8f8526387f..33186f778d868 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4259,7 +4259,7 @@ object functions { def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = { // TODO: should call SparkUserDefinedFunction.create() instead but inputSchemas is currently // unavailable. We may need to create type-safe overloaded versions of udf() methods. - new UserDefinedFunction(f, dataType, inputTypes = None) + SparkUserDefinedFunction(f, dataType, inputTypes = None, nullableTypes = None) } /** From 031bd80e4f943d64a1171856978022fac320de5d Mon Sep 17 00:00:00 2001 From: lichaoqun Date: Sun, 2 Dec 2018 10:55:17 +0800 Subject: [PATCH 2193/2461] [SPARK-26195][SQL] Correct exception messages in some classes ## What changes were proposed in this pull request? UnsupportedOperationException messages are not the same with method name.This PR correct these messages. ## How was this patch tested? NA Closes #23154 from lcqzte10192193/wid-lcq-1127. Authored-by: lichaoqun Signed-off-by: Wenchen Fan --- .../org/apache/spark/sql/catalyst/analysis/unresolved.scala | 4 ++-- .../apache/spark/sql/catalyst/expressions/Expression.scala | 2 +- .../apache/spark/sql/catalyst/expressions/generators.scala | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 36cad3cf74785..d44b42134f868 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -204,10 +204,10 @@ case class UnresolvedGenerator(name: FunctionIdentifier, children: Seq[Expressio throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = - throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") + throw new UnsupportedOperationException(s"Cannot generate code for expression: $this") override def terminate(): TraversableOnce[InternalRow] = - throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") + throw new UnsupportedOperationException(s"Cannot terminate expression: $this") } case class UnresolvedFunction( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 2ecec61adb0ac..c89c2272be752 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -288,7 +288,7 @@ trait Unevaluable extends Expression { throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = - throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") + throw new UnsupportedOperationException(s"Cannot generate code for expression: $this") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index d6e67b9ac3d10..9c74fdf6c9a14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -258,7 +258,7 @@ case class GeneratorOuter(child: Generator) extends UnaryExpression with Generat throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = - throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") + throw new UnsupportedOperationException(s"Cannot generate code for expression: $this") override def elementSchema: StructType = child.elementSchema From c7d95ccedf593edf9fda9ecaf8d0b4dda451440d Mon Sep 17 00:00:00 2001 From: Koert Kuipers Date: Sun, 2 Dec 2018 17:38:25 +0800 Subject: [PATCH 2194/2461] [SPARK-26208][SQL] add headers to empty csv files when header=true ## What changes were proposed in this pull request? Add headers to empty csv files when header=true, because otherwise these files are invalid when reading. ## How was this patch tested? Added test for roundtrip of empty dataframe to csv file with headers and back in CSVSuite Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #23173 from koertkuipers/feat-empty-csv-with-header. Authored-by: Koert Kuipers Signed-off-by: Hyukjin Kwon --- .../sql/catalyst/csv/UnivocityGenerator.scala | 9 ++++---- .../datasources/csv/CSVFileFormat.scala | 22 ++++++++++++------- .../execution/datasources/csv/CSVSuite.scala | 13 +++++++++++ 3 files changed, 31 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala index 1218f9242afeb..2ab376c0ac208 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala @@ -32,7 +32,6 @@ class UnivocityGenerator( private val writerSettings = options.asWriterSettings writerSettings.setHeaders(schema.fieldNames: _*) private val gen = new CsvWriter(writer, writerSettings) - private var printHeader = options.headerFlag // A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`. // When the value is null, this converter should not be called. @@ -72,15 +71,15 @@ class UnivocityGenerator( values } + def writeHeaders(): Unit = { + gen.writeHeaders() + } + /** * Writes a single InternalRow to CSV using Univocity. */ def write(row: InternalRow): Unit = { - if (printHeader) { - gen.writeHeaders() - } gen.writeRow(convertRow(row): _*) - printHeader = false } def writeToString(row: InternalRow): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 4c5a1d327023c..f7d8a9e1042d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -171,15 +171,21 @@ private[csv] class CsvOutputWriter( private var univocityGenerator: Option[UnivocityGenerator] = None - override def write(row: InternalRow): Unit = { - val gen = univocityGenerator.getOrElse { - val charset = Charset.forName(params.charset) - val os = CodecStreams.createOutputStreamWriter(context, new Path(path), charset) - val newGen = new UnivocityGenerator(dataSchema, os, params) - univocityGenerator = Some(newGen) - newGen - } + if (params.headerFlag) { + val gen = getGen() + gen.writeHeaders() + } + private def getGen(): UnivocityGenerator = univocityGenerator.getOrElse { + val charset = Charset.forName(params.charset) + val os = CodecStreams.createOutputStreamWriter(context, new Path(path), charset) + val newGen = new UnivocityGenerator(dataSchema, os, params) + univocityGenerator = Some(newGen) + newGen + } + + override def write(row: InternalRow): Unit = { + val gen = getGen() gen.write(row) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index e14e8d49db5c5..bc950f2418d33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1987,6 +1987,19 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te assert(errMsg2.contains("'lineSep' can contain only 1 character")) } + test("SPARK-26208: write and read empty data to csv file with headers") { + withTempPath { path => + val df1 = spark.range(10).repartition(2).filter(_ < 0).map(_.toString).toDF + // we have 2 partitions but they are both empty and will be filtered out upon writing + // thanks to SPARK-23271 one new empty partition will be inserted + df1.write.format("csv").option("header", true).save(path.getAbsolutePath) + val df2 = spark.read.format("csv").option("header", true).option("inferSchema", false) + .load(path.getAbsolutePath) + assert(df1.schema === df2.schema) + checkAnswer(df1, df2) + } + } + test("do not produce empty files for empty partitions") { withTempPath { dir => val path = dir.getCanonicalPath From 9cda9a892d03f60a76cd5d9b4546e72c50962c85 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 2 Dec 2018 17:41:08 +0800 Subject: [PATCH 2195/2461] [SPARK-26080][PYTHON] Skips Python resource limit on Windows in Python worker ## What changes were proposed in this pull request? `resource` package is a Unix specific package. See https://docs.python.org/2/library/resource.html and https://docs.python.org/3/library/resource.html. Note that we document Windows support: > Spark runs on both Windows and UNIX-like systems (e.g. Linux, Mac OS). This should be backported into branch-2.4 to restore Windows support in Spark 2.4.1. ## How was this patch tested? Manually mocking the changed logics. Closes #23055 from HyukjinKwon/SPARK-26080. Lead-authored-by: hyukjinkwon Co-authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- docs/configuration.md | 2 ++ python/pyspark/worker.py | 19 ++++++++++++------- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 8914bd0310f98..9abbb3f634900 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -190,6 +190,8 @@ of the most common options to set are: and it is up to the application to avoid exceeding the overhead memory space shared with other non-JVM processes. When PySpark is run in YARN or Kubernetes, this memory is added to executor resource requests. + + NOTE: Python memory usage may not be limited on platforms that do not support resource limiting, such as Windows.
        diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 8c59f1f999f18..953b468e96519 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -22,7 +22,12 @@ import os import sys import time -import resource +# 'resource' is a Unix specific module. +has_resource_module = True +try: + import resource +except ImportError: + has_resource_module = False import socket import traceback @@ -268,9 +273,9 @@ def main(infile, outfile): # set up memory limits memory_limit_mb = int(os.environ.get('PYSPARK_EXECUTOR_MEMORY_MB', "-1")) - total_memory = resource.RLIMIT_AS - try: - if memory_limit_mb > 0: + if memory_limit_mb > 0 and has_resource_module: + total_memory = resource.RLIMIT_AS + try: (soft_limit, hard_limit) = resource.getrlimit(total_memory) msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit) print(msg, file=sys.stderr) @@ -283,9 +288,9 @@ def main(infile, outfile): print(msg, file=sys.stderr) resource.setrlimit(total_memory, (new_limit, new_limit)) - except (resource.error, OSError, ValueError) as e: - # not all systems support resource limits, so warn instead of failing - print("WARN: Failed to set memory limit: {0}\n".format(e), file=sys.stderr) + except (resource.error, OSError, ValueError) as e: + # not all systems support resource limits, so warn instead of failing + print("WARN: Failed to set memory limit: {0}\n".format(e), file=sys.stderr) # initialize global state taskContext = None From 676bbb2446af1f281b8f76a5428b7ba75b7588b3 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 2 Dec 2018 08:52:01 -0600 Subject: [PATCH 2196/2461] [SPARK-26198][SQL] Fix Metadata serialize null values throw NPE ## What changes were proposed in this pull request? How to reproduce this issue: ```scala scala> val meta = new org.apache.spark.sql.types.MetadataBuilder().putNull("key").build().json java.lang.NullPointerException at org.apache.spark.sql.types.Metadata$.org$apache$spark$sql$types$Metadata$$toJsonValue(Metadata.scala:196) at org.apache.spark.sql.types.Metadata$$anonfun$1.apply(Metadata.scala:180) ``` This pr fix `NullPointerException` when `Metadata` serialize `null` values. ## How was this patch tested? unit tests Closes #23164 from wangyum/SPARK-26198. Authored-by: Yuming Wang Signed-off-by: Sean Owen --- .../src/main/scala/org/apache/spark/sql/types/Metadata.scala | 2 ++ .../scala/org/apache/spark/sql/types/MetadataSuite.scala | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 4979aced145c9..b6a859b75c37f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -190,6 +190,8 @@ object Metadata { JBool(x) case x: String => JString(x) + case null => + JNull case x: Metadata => toJsonValue(x.map) case other => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/MetadataSuite.scala index 210e65708170f..b4aeac562d2b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/MetadataSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/MetadataSuite.scala @@ -26,6 +26,7 @@ class MetadataSuite extends SparkFunSuite { assert(meta.## !== 0) assert(meta.getString("key") === "value") assert(meta.contains("key")) + assert(meta === Metadata.fromJson(meta.json)) intercept[NoSuchElementException](meta.getString("no_such_key")) intercept[ClassCastException](meta.getBoolean("key")) } @@ -36,6 +37,7 @@ class MetadataSuite extends SparkFunSuite { assert(meta.## !== 0) assert(meta.getLong("key") === 12) assert(meta.contains("key")) + assert(meta === Metadata.fromJson(meta.json)) intercept[NoSuchElementException](meta.getLong("no_such_key")) intercept[ClassCastException](meta.getBoolean("key")) } @@ -46,6 +48,7 @@ class MetadataSuite extends SparkFunSuite { assert(meta.## !== 0) assert(meta.getDouble("key") === 12) assert(meta.contains("key")) + assert(meta === Metadata.fromJson(meta.json)) intercept[NoSuchElementException](meta.getDouble("no_such_key")) intercept[ClassCastException](meta.getBoolean("key")) } @@ -56,6 +59,7 @@ class MetadataSuite extends SparkFunSuite { assert(meta.## !== 0) assert(meta.getBoolean("key") === true) assert(meta.contains("key")) + assert(meta === Metadata.fromJson(meta.json)) intercept[NoSuchElementException](meta.getBoolean("no_such_key")) intercept[ClassCastException](meta.getString("key")) } @@ -69,6 +73,7 @@ class MetadataSuite extends SparkFunSuite { assert(meta.getLong("key") === 0) assert(meta.getBoolean("key") === false) assert(meta.contains("key")) + assert(meta === Metadata.fromJson(meta.json)) intercept[NoSuchElementException](meta.getLong("no_such_key")) } } From bfa3d32f7719cd4bfb2c161fe4a6bd3eea148158 Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Mon, 3 Dec 2018 16:18:22 +0800 Subject: [PATCH 2197/2461] [SPARK-26117][FOLLOW-UP][SQL] throw SparkOutOfMemoryError intead of SparkException in UnsafeHashedRelation ## What changes were proposed in this pull request? When build hash Map with one row of data and run out of memory, we should throw a SparkOutOfMemoryError exception, which is more accurate than SparkException. this PR fix it. ## How was this patch tested? N / A Closes #23190 from heary-cao/throwUnsafeHashedRelation. Authored-by: caoxuewen Signed-off-by: Wenchen Fan --- .../apache/spark/sql/execution/joins/HashedRelation.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 86eb47a70f1ad..e8c01d46a84c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -24,7 +24,7 @@ import com.esotericsoftware.kryo.io.{Input, Output} import org.apache.spark.{SparkConf, SparkEnv, SparkException} import org.apache.spark.internal.config.MEMORY_OFFHEAP_ENABLED -import org.apache.spark.memory.{MemoryConsumer, StaticMemoryManager, TaskMemoryManager} +import org.apache.spark.memory._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode @@ -316,7 +316,9 @@ private[joins] object UnsafeHashedRelation { row.getBaseObject, row.getBaseOffset, row.getSizeInBytes) if (!success) { binaryMap.free() - throw new SparkException("There is no enough memory to build hash map") + // scalastyle:off throwerror + throw new SparkOutOfMemoryError("There is no enough memory to build hash map") + // scalastyle:on throwerror } } } From 11e5f1bcd49eec8ab4225d6e68a051b5c6a21cb2 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 3 Dec 2018 18:25:38 +0800 Subject: [PATCH 2198/2461] [SPARK-26151][SQL] Return partial results for bad CSV records ## What changes were proposed in this pull request? In the PR, I propose to change behaviour of `UnivocityParser` and `FailureSafeParser`, and return all fields that were parsed and converted to expected types successfully instead of just returning a row with all `null`s for a bad input in the `PERMISSIVE` mode. For example, for CSV line `0,2013-111-11 12:13:14` and DDL schema `a int, b timestamp`, new result is `Row(0, null)`. ## How was this patch tested? It was checked by existing tests from `CsvSuite` and `CsvFunctionsSuite`. Closes #23120 from MaxGekk/failuresafe-partial-result. Authored-by: Maxim Gekk Signed-off-by: Wenchen Fan --- .../sql/catalyst/csv/UnivocityParser.scala | 27 ++++++++++--------- .../sql/catalyst/util/FailureSafeParser.scala | 21 ++++++--------- .../apache/spark/sql/CsvFunctionsSuite.scala | 2 +- .../execution/datasources/csv/CSVSuite.scala | 6 ++--- 4 files changed, 27 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index 85e129224c913..8fff4b0781b1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -243,21 +243,24 @@ class UnivocityParser( () => getPartialResult(), new RuntimeException("Malformed CSV record")) } else { - try { - // When the length of the returned tokens is identical to the length of the parsed schema, - // we just need to convert the tokens that correspond to the required columns. - var i = 0 - while (i < requiredSchema.length) { + // When the length of the returned tokens is identical to the length of the parsed schema, + // we just need to convert the tokens that correspond to the required columns. + var badRecordException: Option[Throwable] = None + var i = 0 + while (i < requiredSchema.length) { + try { row(i) = valueConverters(i).apply(getToken(tokens, i)) - i += 1 + } catch { + case NonFatal(e) => + badRecordException = badRecordException.orElse(Some(e)) } + i += 1 + } + + if (badRecordException.isEmpty) { row - } catch { - case NonFatal(e) => - // For corrupted records with the number of tokens same as the schema, - // CSV reader doesn't support partial results. All fields other than the field - // configured by `columnNameOfCorruptRecord` are set to `null`. - throw BadRecordException(() => getCurrentInput, () => None, e) + } else { + throw BadRecordException(() => getCurrentInput, () => Some(row), badRecordException.get) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala index 76745b11c84c9..4baf052bfe564 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala @@ -33,26 +33,21 @@ class FailureSafeParser[IN]( private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord) private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord)) private val resultRow = new GenericInternalRow(schema.length) - private val nullResult = new GenericInternalRow(schema.length) // This function takes 2 parameters: an optional partial result, and the bad record. If the given // schema doesn't contain a field for corrupted record, we just return the partial result or a // row with all fields null. If the given schema contains a field for corrupted record, we will // set the bad record to this field, and set other fields according to the partial result or null. private val toResultRow: (Option[InternalRow], () => UTF8String) => InternalRow = { - if (corruptFieldIndex.isDefined) { - (row, badRecord) => { - var i = 0 - while (i < actualSchema.length) { - val from = actualSchema(i) - resultRow(schema.fieldIndex(from.name)) = row.map(_.get(i, from.dataType)).orNull - i += 1 - } - resultRow(corruptFieldIndex.get) = badRecord() - resultRow + (row, badRecord) => { + var i = 0 + while (i < actualSchema.length) { + val from = actualSchema(i) + resultRow(schema.fieldIndex(from.name)) = row.map(_.get(i, from.dataType)).orNull + i += 1 } - } else { - (row, _) => row.getOrElse(nullResult) + corruptFieldIndex.foreach(index => resultRow(index) = badRecord()) + resultRow } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index 1c359ce1d2014..537d13b1bc8dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -60,7 +60,7 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext { "mode" -> "Permissive", "columnNameOfCorruptRecord" -> columnNameOfCorruptRecord))) checkAnswer(df2, Seq( - Row(Row(null, null, "0,2013-111-11 12:13:14")), + Row(Row(0, null, "0,2013-111-11 12:13:14")), Row(Row(1, java.sql.Date.valueOf("1983-08-04"), null)))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index bc950f2418d33..c9273193b6425 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1116,7 +1116,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te .schema(schema) .csv(testFile(valueMalformedFile)) checkAnswer(df1, - Row(null, null) :: + Row(0, null) :: Row(1, java.sql.Date.valueOf("1983-08-04")) :: Nil) @@ -1131,7 +1131,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te .schema(schemaWithCorrField1) .csv(testFile(valueMalformedFile)) checkAnswer(df2, - Row(null, null, "0,2013-111-11 12:13:14") :: + Row(0, null, "0,2013-111-11 12:13:14") :: Row(1, java.sql.Date.valueOf("1983-08-04"), null) :: Nil) @@ -1148,7 +1148,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te .schema(schemaWithCorrField2) .csv(testFile(valueMalformedFile)) checkAnswer(df3, - Row(null, "0,2013-111-11 12:13:14", null) :: + Row(0, "0,2013-111-11 12:13:14", null) :: Row(1, null, java.sql.Date.valueOf("1983-08-04")) :: Nil) From b569ba53f4b650c03bd11def7c7f7589ceff61eb Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 3 Dec 2018 19:53:45 +0800 Subject: [PATCH 2199/2461] [SPARK-26230][SQL] FileIndex: if case sensitive, validate partitions with original column names ## What changes were proposed in this pull request? Partition column name is required to be unique under the same directory. The following paths are invalid partitioned directory: ``` hdfs://host:9000/path/a=1 hdfs://host:9000/path/b=2 ``` If case sensitive, the following paths should be invalid too: ``` hdfs://host:9000/path/a=1 hdfs://host:9000/path/A=2 ``` Since column 'a' and 'A' are different, and it is wrong to use either one as the column name in partition schema. Also, there is a `TODO` comment in the code. Currently the Spark doesn't validate such case when `CASE_SENSITIVE` enabled. This PR is to resolve the problem. ## How was this patch tested? Add unit test Closes #23186 from gengliangwang/SPARK-26230. Authored-by: Gengliang Wang Signed-off-by: Wenchen Fan --- .../datasources/PartitioningUtils.scala | 14 +++++--- .../datasources/FileIndexSuite.scala | 32 ++++++++++++++++++- 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 9d2c9ba0c1a5b..d66cb09bda0cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -155,7 +155,8 @@ object PartitioningUtils { "root directory of the table. If there are multiple root directories, " + "please load them separately and then union them.") - val resolvedPartitionValues = resolvePartitions(pathsWithPartitionValues, timeZone) + val resolvedPartitionValues = + resolvePartitions(pathsWithPartitionValues, caseSensitive, timeZone) // Creates the StructType which represents the partition columns. val fields = { @@ -345,15 +346,18 @@ object PartitioningUtils { */ def resolvePartitions( pathsWithPartitionValues: Seq[(Path, PartitionValues)], + caseSensitive: Boolean, timeZone: TimeZone): Seq[PartitionValues] = { if (pathsWithPartitionValues.isEmpty) { Seq.empty } else { - // TODO: Selective case sensitivity. - val distinctPartColNames = - pathsWithPartitionValues.map(_._2.columnNames.map(_.toLowerCase())).distinct + val partColNames = if (caseSensitive) { + pathsWithPartitionValues.map(_._2.columnNames) + } else { + pathsWithPartitionValues.map(_._2.columnNames.map(_.toLowerCase())) + } assert( - distinctPartColNames.size == 1, + partColNames.distinct.size == 1, listConflictingPartitionColumns(pathsWithPartitionValues)) // Resolves possible type conflicts for each column diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index fdb0511f01a22..ec552f7ddf47a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -52,7 +52,7 @@ class FileIndexSuite extends SharedSQLContext { test("SPARK-26188: don't infer data types of partition columns if user specifies schema") { withTempDir { dir => - val partitionDirectory = new File(dir, s"a=4d") + val partitionDirectory = new File(dir, "a=4d") partitionDirectory.mkdir() val file = new File(partitionDirectory, "text.txt") stringToFile(file, "text") @@ -65,6 +65,36 @@ class FileIndexSuite extends SharedSQLContext { } } + test("SPARK-26230: if case sensitive, validate partitions with original column names") { + withTempDir { dir => + val partitionDirectory = new File(dir, "a=1") + partitionDirectory.mkdir() + val file = new File(partitionDirectory, "text.txt") + stringToFile(file, "text") + val partitionDirectory2 = new File(dir, "A=2") + partitionDirectory2.mkdir() + val file2 = new File(partitionDirectory2, "text.txt") + stringToFile(file2, "text") + val path = new Path(dir.getCanonicalPath) + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val fileIndex = new InMemoryFileIndex(spark, Seq(path), Map.empty, None) + val partitionValues = fileIndex.partitionSpec().partitions.map(_.values) + assert(partitionValues.length == 2) + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val msg = intercept[AssertionError] { + val fileIndex = new InMemoryFileIndex(spark, Seq(path), Map.empty, None) + fileIndex.partitionSpec() + }.getMessage + assert(msg.contains("Conflicting partition column names detected")) + assert("Partition column name list #[0-1]: A".r.findFirstIn(msg).isDefined) + assert("Partition column name list #[0-1]: a".r.findFirstIn(msg).isDefined) + } + } + } + test("InMemoryFileIndex: input paths are converted to qualified paths") { withTempDir { dir => val file = new File(dir, "text.txt") From eebb940edb0c89eef88e1deb0fc0ae1c7a24bc3d Mon Sep 17 00:00:00 2001 From: pgandhi Date: Mon, 3 Dec 2018 07:53:21 -0600 Subject: [PATCH 2200/2461] [SPARK-26253][WEBUI] Task Summary Metrics Table on Stage Page shows empty table when no data is present Task Summary Metrics Table on Stage Page shows empty table when no data is present instead of showing a message. ## What changes were proposed in this pull request? Added a custom message to show on the task summary metrics table as well as executor summary table when no data is present. ## How was this patch tested? **Before:** ![49335550-29277d00-f615-11e8-8e62-a953e76bcebf](https://user-images.githubusercontent.com/22228190/49361520-425a2780-f702-11e8-8df4-08862ab6ceb8.png) **After:** screen shot 2018-12-03 at 1 56 09 pm Closes #23205 from pgandhi999/SPARK-26253. Authored-by: pgandhi Signed-off-by: Sean Owen --- .../resources/org/apache/spark/ui/static/stagepage.js | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/stagepage.js b/core/src/main/resources/org/apache/spark/ui/static/stagepage.js index 4c83ec7e95ab1..564467487e84e 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/stagepage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/stagepage.js @@ -221,7 +221,10 @@ function createDataTableForTaskSummaryMetricsTable(taskSummaryMetricsTable) { "searching": false, "order": [[0, "asc"]], "bSort": false, - "bAutoWidth": false + "bAutoWidth": false, + "oLanguage": { + "sEmptyTable": "No tasks have reported metrics yet" + } }; taskSummaryMetricsDataTable = $(taskMetricsTable).DataTable(taskConf); } @@ -426,7 +429,10 @@ $(document).ready(function () { } ], "order": [[0, "asc"]], - "bAutoWidth": false + "bAutoWidth": false, + "oLanguage": { + "sEmptyTable": "No data to show yet" + } } var executorSummaryTableSelector = $("#summary-executor-table").DataTable(executorSummaryConf); From 8534d753ecb21ea64ffbaefb5eaca38ba0464c6d Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 3 Dec 2018 23:54:26 +0800 Subject: [PATCH 2201/2461] [SPARK-26181][SQL] the `hasMinMaxStats` method of `ColumnStatsMap` is not correct ## What changes were proposed in this pull request? For now the `hasMinMaxStats` will return the same as `hasCountStats`, which is obviously not as expected. ## How was this patch tested? Existing tests. Closes #23152 from adrian-wang/minmaxstats. Authored-by: Daoyuan Wang Signed-off-by: Wenchen Fan --- .../statsEstimation/FilterEstimation.scala | 14 +++++++--- .../FilterEstimationSuite.scala | 27 +++++++++++++++++++ .../sql/hive/execution/SQLQuerySuite.scala | 14 ++++++++++ 3 files changed, 52 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 5a3eeefaedb18..2c5beef43f52a 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -670,6 +670,14 @@ case class FilterEstimation(plan: Filter) extends Logging { logDebug("[CBO] No range comparison statistics for String/Binary type " + attrLeft) return None case _ => + if (!colStatsMap.hasMinMaxStats(attrLeft)) { + logDebug("[CBO] No min/max statistics for " + attrLeft) + return None + } + if (!colStatsMap.hasMinMaxStats(attrRight)) { + logDebug("[CBO] No min/max statistics for " + attrRight) + return None + } } val colStatLeft = colStatsMap(attrLeft) @@ -879,13 +887,13 @@ case class ColumnStatsMap(originalMap: AttributeMap[ColumnStat]) { } def hasCountStats(a: Attribute): Boolean = - get(a).map(_.hasCountStats).getOrElse(false) + get(a).exists(_.hasCountStats) def hasDistinctCount(a: Attribute): Boolean = - get(a).map(_.distinctCount.isDefined).getOrElse(false) + get(a).exists(_.distinctCount.isDefined) def hasMinMaxStats(a: Attribute): Boolean = - get(a).map(_.hasCountStats).getOrElse(false) + get(a).exists(_.hasMinMaxStats) /** * Gets column stat for the given attribute. Prefer the column stat in updatedMap than that in diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 47bfa62569583..b0a47e7835129 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.LeftOuter import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{ColumnStatsMap, FilterEstimation} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -821,6 +822,32 @@ class FilterEstimationSuite extends StatsEstimationTestBase { expectedRowCount = 3) } + test("ColumnStatsMap tests") { + val attrNoDistinct = AttributeReference("att_without_distinct", IntegerType)() + val attrNoCount = AttributeReference("att_without_count", BooleanType)() + val attrNoMinMax = AttributeReference("att_without_min_max", DateType)() + val colStatNoDistinct = ColumnStat(distinctCount = None, min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) + val colStatNoCount = ColumnStat(distinctCount = Some(2), min = Some(false), max = Some(true), + nullCount = None, avgLen = Some(1), maxLen = Some(1)) + val colStatNoMinMax = ColumnStat(distinctCount = Some(1), min = None, max = None, + nullCount = Some(1), avgLen = None, maxLen = None) + val columnStatsMap = ColumnStatsMap(AttributeMap(Seq( + attrNoDistinct -> colStatNoDistinct, + attrNoCount -> colStatNoCount, + attrNoMinMax -> colStatNoMinMax + ))) + assert(!columnStatsMap.hasDistinctCount(attrNoDistinct)) + assert(columnStatsMap.hasDistinctCount(attrNoCount)) + assert(columnStatsMap.hasDistinctCount(attrNoMinMax)) + assert(!columnStatsMap.hasCountStats(attrNoDistinct)) + assert(!columnStatsMap.hasCountStats(attrNoCount)) + assert(columnStatsMap.hasCountStats(attrNoMinMax)) + assert(columnStatsMap.hasMinMaxStats(attrNoDistinct)) + assert(columnStatsMap.hasMinMaxStats(attrNoCount)) + assert(!columnStatsMap.hasMinMaxStats(attrNoMinMax)) + } + private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = { StatsTestPlan( outputList = outList, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index dfcde8cc0d39f..fab2a27cdef17 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2276,4 +2276,18 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + + test("SPARK-26181 hasMinMaxStats method of ColumnStatsMap is not correct") { + withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { + withTable("all_null") { + sql("create table all_null (attr1 int, attr2 int)") + sql("insert into all_null values (null, null)") + sql("analyze table all_null compute statistics for columns attr1, attr2") + // check if the stats can be calculated without Cast exception. + sql("select * from all_null where attr1 < 1").queryExecution.stringWithStats + sql("select * from all_null where attr1 < attr2").queryExecution.stringWithStats + } + } + } + } From 6e4e70fe7bc3e103b8538748511261bb43cf3548 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 3 Dec 2018 10:02:15 -0600 Subject: [PATCH 2202/2461] [SPARK-26235][CORE] Change log level for ClassNotFoundException/NoClassDefFoundError in SparkSubmit to Error ## What changes were proposed in this pull request? In my local setup, I set log4j root category as ERROR (https://stackoverflow.com/questions/27781187/how-to-stop-info-messages-displaying-on-spark-console , first item show up if we google search "set spark log level".) When I run such command ``` spark-submit --class foo bar.jar ``` Nothing shows up, and the script exits. After quick investigation, I think the log level for ClassNotFoundException/NoClassDefFoundError in SparkSubmit should be ERROR instead of WARN. Since the whole process exit because of the exception/error. Before https://github.com/apache/spark/pull/20925, the message is not controlled by `log4j.rootCategory`. ## How was this patch tested? Manual check. Closes #23189 from gengliangwang/changeLogLevel. Authored-by: Gengliang Wang Signed-off-by: Sean Owen --- .../main/scala/org/apache/spark/deploy/SparkSubmit.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 324f6f8894d34..d4055cb6c5853 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -813,14 +813,14 @@ private[spark] class SparkSubmit extends Logging { mainClass = Utils.classForName(childMainClass) } catch { case e: ClassNotFoundException => - logWarning(s"Failed to load $childMainClass.", e) + logError(s"Failed to load class $childMainClass.") if (childMainClass.contains("thriftserver")) { logInfo(s"Failed to load main class $childMainClass.") logInfo("You need to build Spark with -Phive and -Phive-thriftserver.") } throw new SparkUserAppException(CLASS_NOT_FOUND_EXIT_STATUS) case e: NoClassDefFoundError => - logWarning(s"Failed to load $childMainClass: ${e.getMessage()}") + logError(s"Failed to load $childMainClass: ${e.getMessage()}") if (e.getMessage.contains("org/apache/hadoop/hive")) { logInfo(s"Failed to load hive class.") logInfo("You need to build Spark with -Phive and -Phive-thriftserver.") @@ -915,6 +915,8 @@ object SparkSubmit extends CommandLineUtils with Logging { override protected def logInfo(msg: => String): Unit = self.logInfo(msg) override protected def logWarning(msg: => String): Unit = self.logWarning(msg) + + override protected def logError(msg: => String): Unit = self.logError(msg) } } @@ -922,6 +924,8 @@ object SparkSubmit extends CommandLineUtils with Logging { override protected def logWarning(msg: => String): Unit = printMessage(s"Warning: $msg") + override protected def logError(msg: => String): Unit = printMessage(s"Error: $msg") + override def doSubmit(args: Array[String]): Unit = { try { super.doSubmit(args) From 5e5b9f2ee0b4d8470197b404906fbd245c28f8ac Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Mon, 3 Dec 2018 10:03:51 -0600 Subject: [PATCH 2203/2461] [SPARK-26177] Config change followup to [] Automated formatting for Scala code Let's keep this open for a while to see if other configuration tweaks are suggested ## What changes were proposed in this pull request? Formatting configuration changes following up https://github.com/apache/spark/pull/23148 ## How was this patch tested? ./dev/scalafmt Closes #23182 from koeninger/scalafmt-config. Authored-by: cody koeninger Signed-off-by: Sean Owen --- dev/.scalafmt.conf | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dev/.scalafmt.conf b/dev/.scalafmt.conf index def67e0269822..9a8813e3b3eed 100644 --- a/dev/.scalafmt.conf +++ b/dev/.scalafmt.conf @@ -19,6 +19,10 @@ align = none align.openParenDefnSite = false align.openParenCallSite = false align.tokens = [] +optIn = { + configStyleArguments = false +} +danglingParentheses = false docstrings = JavaDoc maxColumn = 98 From 04046e5432acb1132fa567f2230723bc1a92a482 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 4 Dec 2018 00:05:15 +0800 Subject: [PATCH 2204/2461] [SPARK-25498][SQL] InterpretedMutableProjection should handle UnsafeRow ## What changes were proposed in this pull request? Since `AggregationIterator` uses `MutableProjection` for `UnsafeRow`, `InterpretedMutableProjection` needs to handle `UnsafeRow` as buffer internally for fixed-length types only. ## How was this patch tested? Run 'SQLQueryTestSuite' with the interpreted mode. Closes #22512 from maropu/InterpreterTest. Authored-by: Takeshi Yamamuro Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/InternalRow.scala | 22 +++++ .../InterpretedMutableProjection.scala | 23 +++++- .../expressions/ExpressionEvalHelper.scala | 11 +++ .../expressions/MutableProjectionSuite.scala | 81 +++++++++++++++++++ .../expressions/UnsafeRowConverterSuite.scala | 15 +--- .../sql-tests/inputs/change-column.sql | 1 + .../test/resources/sql-tests/inputs/udaf.sql | 3 + .../sql-tests/results/change-column.sql.out | 10 ++- .../resources/sql-tests/results/udaf.sql.out | 18 ++++- .../apache/spark/sql/SQLQueryTestSuite.scala | 27 ++++++- 10 files changed, 192 insertions(+), 19 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index e49c10be6be4e..bdab407688a65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -157,4 +157,26 @@ object InternalRow { getValueNullSafe } } + + /** + * Returns a writer for an `InternalRow` with given data type. + */ + def getWriter(ordinal: Int, dt: DataType): (InternalRow, Any) => Unit = dt match { + case BooleanType => (input, v) => input.setBoolean(ordinal, v.asInstanceOf[Boolean]) + case ByteType => (input, v) => input.setByte(ordinal, v.asInstanceOf[Byte]) + case ShortType => (input, v) => input.setShort(ordinal, v.asInstanceOf[Short]) + case IntegerType | DateType => (input, v) => input.setInt(ordinal, v.asInstanceOf[Int]) + case LongType | TimestampType => (input, v) => input.setLong(ordinal, v.asInstanceOf[Long]) + case FloatType => (input, v) => input.setFloat(ordinal, v.asInstanceOf[Float]) + case DoubleType => (input, v) => input.setDouble(ordinal, v.asInstanceOf[Double]) + case DecimalType.Fixed(precision, _) => + (input, v) => input.setDecimal(ordinal, v.asInstanceOf[Decimal], precision) + case udt: UserDefinedType[_] => getWriter(ordinal, udt.sqlType) + case NullType => (input, _) => input.setNullAt(ordinal) + case StringType => (input, v) => input.update(ordinal, v.asInstanceOf[UTF8String].copy()) + case _: StructType => (input, v) => input.update(ordinal, v.asInstanceOf[InternalRow].copy()) + case _: ArrayType => (input, v) => input.update(ordinal, v.asInstanceOf[ArrayData].copy()) + case _: MapType => (input, v) => input.update(ordinal, v.asInstanceOf[MapData].copy()) + case _ => (input, v) => input.update(ordinal, v) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala index 0654108cea281..122a564da61be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala @@ -49,10 +49,31 @@ class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mutable def currentValue: InternalRow = mutableRow override def target(row: InternalRow): MutableProjection = { + // If `mutableRow` is `UnsafeRow`, `MutableProjection` accepts fixed-length types only + require(!row.isInstanceOf[UnsafeRow] || + validExprs.forall { case (e, _) => UnsafeRow.isFixedLength(e.dataType) }, + "MutableProjection cannot use UnsafeRow for output data types: " + + validExprs.map(_._1.dataType).filterNot(UnsafeRow.isFixedLength) + .map(_.catalogString).mkString(", ")) mutableRow = row this } + private[this] val fieldWriters: Array[Any => Unit] = validExprs.map { case (e, i) => + val writer = InternalRow.getWriter(i, e.dataType) + if (!e.nullable) { + (v: Any) => writer(mutableRow, v) + } else { + (v: Any) => { + if (v == null) { + mutableRow.setNullAt(i) + } else { + writer(mutableRow, v) + } + } + } + }.toArray + override def apply(input: InternalRow): InternalRow = { var i = 0 while (i < validExprs.length) { @@ -64,7 +85,7 @@ class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mutable i = 0 while (i < validExprs.length) { val (_, ordinal) = validExprs(i) - mutableRow(ordinal) = buffer(ordinal) + fieldWriters(i)(buffer(ordinal)) i += 1 } mutableRow diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index eb33325d0b31a..a7282e1b1cadc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -456,4 +456,15 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa diff < eps * math.min(absX, absY) } } + + def testBothCodegenAndInterpreted(name: String)(f: => Unit): Unit = { + val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN) + for (fallbackMode <- modes) { + test(s"$name with $fallbackMode") { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) { + f + } + } + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala new file mode 100644 index 0000000000000..2db1c3b98819c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + +class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { + + val fixedLengthTypes = Array[DataType]( + BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, + DateType, TimestampType) + + val variableLengthTypes = Array( + StringType, DecimalType.defaultConcreteType, CalendarIntervalType, BinaryType, + ArrayType(StringType), MapType(IntegerType, StringType), + StructType.fromDDL("a INT, b STRING"), ObjectType(classOf[java.lang.Integer])) + + def createMutableProjection(dataTypes: Array[DataType]): MutableProjection = { + MutableProjection.create(dataTypes.zipWithIndex.map(x => BoundReference(x._2, x._1, true))) + } + + testBothCodegenAndInterpreted("fixed-length types") { + val inputRow = InternalRow.fromSeq(Seq(true, 3.toByte, 15.toShort, -83, 129L, 1.0f, 5.0, 1, 2L)) + val proj = createMutableProjection(fixedLengthTypes) + assert(proj(inputRow) === inputRow) + } + + testBothCodegenAndInterpreted("unsafe buffer") { + val inputRow = InternalRow.fromSeq(Seq(false, 1.toByte, 9.toShort, -18, 53L, 3.2f, 7.8, 4, 9L)) + val numBytes = UnsafeRow.calculateBitSetWidthInBytes(fixedLengthTypes.length) + val unsafeBuffer = UnsafeRow.createFromByteArray(numBytes, fixedLengthTypes.length) + val proj = createMutableProjection(fixedLengthTypes) + val projUnsafeRow = proj.target(unsafeBuffer)(inputRow) + assert(FromUnsafeProjection.apply(fixedLengthTypes)(projUnsafeRow) === inputRow) + } + + testBothCodegenAndInterpreted("variable-length types") { + val proj = createMutableProjection(variableLengthTypes) + val scalaValues = Seq("abc", BigDecimal(10), CalendarInterval.fromString("interval 1 day"), + Array[Byte](1, 2), Array("123", "456"), Map(1 -> "a", 2 -> "b"), Row(1, "a"), + new java.lang.Integer(5)) + val inputRow = InternalRow.fromSeq(scalaValues.zip(variableLengthTypes).map { + case (v, dataType) => CatalystTypeConverters.createToCatalystConverter(dataType)(v) + }) + val projRow = proj(inputRow) + variableLengthTypes.zipWithIndex.foreach { case (dataType, index) => + val toScala = CatalystTypeConverters.createToScalaConverter(dataType) + assert(toScala(projRow.get(index, dataType)) === toScala(inputRow.get(index, dataType))) + } + } + + test("unsupported types for unsafe buffer") { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString) { + val proj = createMutableProjection(Array(StringType)) + val errMsg = intercept[IllegalArgumentException] { + proj.target(new UnsafeRow(1)) + }.getMessage + assert(errMsg.contains("MutableProjection cannot use UnsafeRow for output data types:")) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 5a646d9a850ac..268372b5d0504 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -26,26 +26,15 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, LongType, _} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String -class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestBase { +class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestBase + with ExpressionEvalHelper { private def roundedSize(size: Int) = ByteArrayMethods.roundNumberOfBytesToNearestWord(size) - private def testBothCodegenAndInterpreted(name: String)(f: => Unit): Unit = { - val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN) - for (fallbackMode <- modes) { - test(s"$name with $fallbackMode") { - withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) { - f - } - } - } - } - testBothCodegenAndInterpreted("basic conversion with only primitive types") { val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) diff --git a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql index 2909024e4c9f7..6f5ac221ce79c 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql @@ -54,3 +54,4 @@ ALTER TABLE partition_table CHANGE COLUMN c c INT COMMENT 'this is column C'; -- DROP TEST TABLE DROP TABLE test_change; DROP TABLE partition_table; +DROP VIEW global_temp.global_temp_view; diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql index 2183ba23afc38..58613a1325dfa 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/udaf.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql @@ -11,3 +11,6 @@ SELECT default.myDoubleAvg(int_col1, 3) as my_avg from t1; CREATE FUNCTION udaf1 AS 'test.non.existent.udaf'; SELECT default.udaf1(int_col1) as udaf1 from t1; + +DROP FUNCTION myDoubleAvg; +DROP FUNCTION udaf1; diff --git a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out index ff1ecbcc44c23..114617873af47 100644 --- a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 33 +-- Number of queries: 34 -- !query 0 @@ -313,3 +313,11 @@ DROP TABLE partition_table struct<> -- !query 32 output + + +-- !query 33 +DROP VIEW global_temp.global_temp_view +-- !query 33 schema +struct<> +-- !query 33 output + diff --git a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out index 87824ab81cdf7..f4455bb717578 100644 --- a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 6 +-- Number of queries: 8 -- !query 0 @@ -52,3 +52,19 @@ struct<> -- !query 5 output org.apache.spark.sql.AnalysisException Can not load class 'test.non.existent.udaf' when registering the function 'default.udaf1', please make sure it is on the classpath; line 1 pos 7 + + +-- !query 6 +DROP FUNCTION myDoubleAvg +-- !query 6 schema +struct<> +-- !query 6 output + + + +-- !query 7 +DROP FUNCTION udaf1 +-- !query 7 schema +struct<> +-- !query 7 output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 6ca3ac596e5f4..fd180ce2380a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -22,11 +22,13 @@ import java.util.{Locale, TimeZone} import scala.util.control.NonFatal +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile} import org.apache.spark.sql.execution.command.{DescribeColumnCommand, DescribeTableCommand} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType @@ -140,6 +142,12 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { val input = fileToString(new File(testCase.inputFile)) val (comments, code) = input.split("\n").partition(_.startsWith("--")) + + // Runs all the tests on both codegen-only and interpreter modes + val codegenConfigSets = Array(CODEGEN_ONLY, NO_CODEGEN).map { + case codegenFactoryMode => + Array(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenFactoryMode.toString) + } val configSets = { val configLines = comments.filter(_.startsWith("--SET")).map(_.substring(5)) val configs = configLines.map(_.split(",").map { confAndValue => @@ -148,12 +156,25 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { }) // When we are regenerating the golden files we don't need to run all the configs as they // all need to return the same result - if (regenerateGoldenFiles && configs.nonEmpty) { - configs.take(1) + if (regenerateGoldenFiles) { + if (configs.nonEmpty) { + configs.take(1) + } else { + Array.empty[Array[(String, String)]] + } } else { - configs + if (configs.nonEmpty) { + codegenConfigSets.flatMap { codegenConfig => + configs.map { config => + config ++ codegenConfig + } + } + } else { + codegenConfigSets + } } } + // List of SQL queries to run // note: this is not a robust way to split queries using semicolon, but works for now. val queries = code.mkString("\n").split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq From 0c2935b01def8a5f631851999d9c2d57b63763e6 Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Mon, 3 Dec 2018 09:02:47 -0800 Subject: [PATCH 2205/2461] [SPARK-25515][K8S] Adds a config option to keep executor pods for debugging ## What changes were proposed in this pull request? Keeps K8s executor resources present if case of failure or normal termination. Introduces a new boolean config option: `spark.kubernetes.deleteExecutors`, with default value set to true. The idea is to update Spark K8s backend structures but leave the resources around. The assumption is that since entries are not removed from the `removedExecutorsCache` we are immune to updates that refer to the the executor resources previously removed. The only delete operation not touched is the one in the `doKillExecutors` method. Reason is right now we dont support [blacklisting](https://issues.apache.org/jira/browse/SPARK-23485) and dynamic allocation with Spark on K8s. In both cases in the future we might want to handle these scenarios although its more complicated. More tests can be added if approach is approved. ## How was this patch tested? Manually by running a Spark job and verifying pods are not deleted. Closes #23136 from skonto/keep_pods. Authored-by: Stavros Kontopoulos Signed-off-by: Yinan Li --- docs/running-on-kubernetes.md | 7 +++++++ .../org/apache/spark/deploy/k8s/Config.scala | 7 +++++++ .../cluster/k8s/ExecutorPodsAllocator.scala | 15 ++++++++++----- .../k8s/ExecutorPodsLifecycleManager.scala | 8 ++++++-- .../k8s/KubernetesClusterSchedulerBackend.scala | 15 ++++++++++----- .../k8s/ExecutorPodsLifecycleManagerSuite.scala | 14 +++++++++++++- 6 files changed, 53 insertions(+), 13 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 5639253d52f54..3172b1bca8f05 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -944,6 +944,13 @@ specific to Spark on Kubernetes. spark.kubernetes.executor.podTemplateFile=/path/to/executor-pod-template.yaml` + + + + +
        spark.kubernetes.executor.deleteOnTerminationtrue + Specify whether executor pods should be deleted in case of failure or normal termination. +
        #### Pod template properties diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 1abf2901268f8..e8bf16df190e8 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -282,6 +282,13 @@ private[spark] object Config extends Logging { val KUBERNETES_NODE_SELECTOR_PREFIX = "spark.kubernetes.node.selector." + val KUBERNETES_DELETE_EXECUTORS = + ConfigBuilder("spark.kubernetes.executor.deleteOnTermination") + .doc("If set to false then executor pods will not be deleted in case " + + "of failure or normal termination.") + .booleanConf + .createWithDefault(true) + val KUBERNETES_DRIVER_LABEL_PREFIX = "spark.kubernetes.driver.label." val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation." val KUBERNETES_DRIVER_SECRETS_PREFIX = "spark.kubernetes.driver.secrets." diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala index 77bb9c3fcc9f4..ef4cbdf162c6c 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala @@ -51,6 +51,8 @@ private[spark] class ExecutorPodsAllocator( private val kubernetesDriverPodName = conf .get(KUBERNETES_DRIVER_POD_NAME) + private val shouldDeleteExecutors = conf.get(KUBERNETES_DELETE_EXECUTORS) + private val driverPod = kubernetesDriverPodName .map(name => Option(kubernetesClient.pods() .withName(name) @@ -86,11 +88,14 @@ private[spark] class ExecutorPodsAllocator( s" cluster after $podCreationTimeout milliseconds despite the fact that a" + " previous allocation attempt tried to create it. The executor may have been" + " deleted but the application missed the deletion event.") - Utils.tryLogNonFatalError { - kubernetesClient - .pods() - .withLabel(SPARK_EXECUTOR_ID_LABEL, execId.toString) - .delete() + + if (shouldDeleteExecutors) { + Utils.tryLogNonFatalError { + kubernetesClient + .pods() + .withLabel(SPARK_EXECUTOR_ID_LABEL, execId.toString) + .delete() + } } newlyCreatedExecutors -= execId } else { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala index 77a1d6cfae3bd..95e1ba8362a02 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala @@ -30,7 +30,7 @@ import org.apache.spark.scheduler.ExecutorExited import org.apache.spark.util.Utils private[spark] class ExecutorPodsLifecycleManager( - conf: SparkConf, + val conf: SparkConf, kubernetesClient: KubernetesClient, snapshotsStore: ExecutorPodsSnapshotsStore, // Use a best-effort to track which executors have been removed already. It's not generally @@ -43,6 +43,8 @@ private[spark] class ExecutorPodsLifecycleManager( private val eventProcessingInterval = conf.get(KUBERNETES_EXECUTOR_EVENT_PROCESSING_INTERVAL) + private lazy val shouldDeleteExecutors = conf.get(KUBERNETES_DELETE_EXECUTORS) + def start(schedulerBackend: KubernetesClusterSchedulerBackend): Unit = { snapshotsStore.addSubscriber(eventProcessingInterval) { onNewSnapshots(schedulerBackend, _) @@ -112,7 +114,9 @@ private[spark] class ExecutorPodsLifecycleManager( schedulerBackend: KubernetesClusterSchedulerBackend, execIdsRemovedInRound: mutable.Set[Long]): Unit = { removeExecutorFromSpark(schedulerBackend, podState, execId) - removeExecutorFromK8s(podState.pod) + if (shouldDeleteExecutors) { + removeExecutorFromK8s(podState.pod) + } execIdsRemovedInRound += execId } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index fa6dc2c479bbf..6356b58645806 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -21,6 +21,7 @@ import java.util.concurrent.ExecutorService import io.fabric8.kubernetes.client.KubernetesClient import scala.concurrent.{ExecutionContext, Future} +import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.rpc.{RpcAddress, RpcEnv} import org.apache.spark.scheduler.{ExecutorLossReason, TaskSchedulerImpl} @@ -51,6 +52,8 @@ private[spark] class KubernetesClusterSchedulerBackend( private val initialExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf) + private val shouldDeleteExecutors = conf.get(KUBERNETES_DELETE_EXECUTORS) + // Allow removeExecutor to be accessible by ExecutorPodsLifecycleEventHandler private[k8s] def doRemoveExecutor(executorId: String, reason: ExecutorLossReason): Unit = { removeExecutor(executorId, reason) @@ -82,11 +85,13 @@ private[spark] class KubernetesClusterSchedulerBackend( pollEvents.stop() } - Utils.tryLogNonFatalError { - kubernetesClient.pods() - .withLabel(SPARK_APP_ID_LABEL, applicationId()) - .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) - .delete() + if (shouldDeleteExecutors) { + Utils.tryLogNonFatalError { + kubernetesClient.pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId()) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .delete() + } } Utils.tryLogNonFatalError { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala index 3995b2afe7c45..7411f8f9d69e9 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala @@ -22,7 +22,7 @@ import io.fabric8.kubernetes.client.KubernetesClient import io.fabric8.kubernetes.client.dsl.PodResource import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Matchers.any -import org.mockito.Mockito.{mock, times, verify, when} +import org.mockito.Mockito.{mock, never, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfter @@ -30,6 +30,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config import org.apache.spark.deploy.k8s.Fabric8Aliases._ import org.apache.spark.deploy.k8s.KubernetesUtils._ import org.apache.spark.scheduler.ExecutorExited @@ -100,6 +101,17 @@ class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfte verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason) } + test("Keep executor pods in k8s if configured.") { + val failedPod = failedExecutorWithoutDeletion(1) + eventHandlerUnderTest.conf.set(Config.KUBERNETES_DELETE_EXECUTORS, false) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + val msg = exitReasonMessage(1, failedPod) + val expectedLossReason = ExecutorExited(1, exitCausedByApp = true, msg) + verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason) + verify(podOperations, never()).delete() + } + private def exitReasonMessage(failedExecutorId: Int, failedPod: Pod): String = { val reason = Option(failedPod.getStatus.getReason) val message = Option(failedPod.getStatus.getMessage) From 187bb7d008872e812aaa6590c89121bfa50e97d3 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 3 Dec 2018 13:54:09 -0800 Subject: [PATCH 2206/2461] [SPARK-25957][FOLLOWUP] Build python docker image in sbt build too. docker-image-tool.sh requires explicit argument to create the python image now; do that from the sbt integration tests target too. Closes #23172 from vanzin/SPARK-25957.followup. Authored-by: Marcelo Vanzin Signed-off-by: Marcelo Vanzin --- project/SparkBuild.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index bb834bc483f1f..a0946a9ad6656 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -494,7 +494,13 @@ object KubernetesIntegrationTests { dockerBuild := { if (shouldBuildImage) { val dockerTool = s"$sparkHome/bin/docker-image-tool.sh" - val cmd = Seq(dockerTool, "-m", "-t", imageTag.value, "build") + val bindingsDir = s"$sparkHome/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings" + val cmd = Seq(dockerTool, "-m", + "-t", imageTag.value, + "-p", s"$bindingsDir/python/Dockerfile", + "-R", s"$bindingsDir/R/Dockerfile", + "build" + ) val ec = Process(cmd).! if (ec != 0) { throw new IllegalStateException(s"Process '${cmd.mkString(" ")}' exited with $ec.") From 518a3d10c87bb6d7d442eba7265fc026aa54473e Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Mon, 3 Dec 2018 14:03:10 -0800 Subject: [PATCH 2207/2461] [SPARK-26033][SPARK-26034][PYTHON][FOLLOW-UP] Small cleanup and deduplication in ml/mllib tests ## What changes were proposed in this pull request? This PR is a small follow up that puts some logic and functions into smaller scope and make it localized, and deduplicate. ## How was this patch tested? Manually tested. Jenkins tests as well. Closes #23200 from HyukjinKwon/followup-SPARK-26034-SPARK-26033. Authored-by: Hyukjin Kwon Signed-off-by: Bryan Cutler --- python/pyspark/ml/tests/test_linalg.py | 44 +++++++------ python/pyspark/mllib/tests/test_algorithms.py | 8 +-- python/pyspark/mllib/tests/test_linalg.py | 62 ++++++++----------- python/pyspark/testing/mllibutils.py | 5 -- 4 files changed, 51 insertions(+), 68 deletions(-) diff --git a/python/pyspark/ml/tests/test_linalg.py b/python/pyspark/ml/tests/test_linalg.py index 71cad5d7f5ad7..995bc35e4ca80 100644 --- a/python/pyspark/ml/tests/test_linalg.py +++ b/python/pyspark/ml/tests/test_linalg.py @@ -20,25 +20,17 @@ from numpy import arange, array, array_equal, inf, ones, tile, zeros +from pyspark.serializers import PickleSerializer from pyspark.ml.linalg import DenseMatrix, DenseVector, MatrixUDT, SparseMatrix, SparseVector, \ Vector, VectorUDT, Vectors -from pyspark.testing.mllibutils import make_serializer, MLlibTestCase +from pyspark.testing.mllibutils import MLlibTestCase from pyspark.sql import Row -ser = make_serializer() - - -def _squared_distance(a, b): - if isinstance(a, Vector): - return a.squared_distance(b) - else: - return b.squared_distance(a) - - class VectorTests(MLlibTestCase): def _test_serialize(self, v): + ser = PickleSerializer() self.assertEqual(v, ser.loads(ser.dumps(v))) jvec = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(v))) nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvec))) @@ -77,24 +69,30 @@ def test_dot(self): self.assertEqual(7.0, sv.dot(arr)) def test_squared_distance(self): + def squared_distance(a, b): + if isinstance(a, Vector): + return a.squared_distance(b) + else: + return b.squared_distance(a) + sv = SparseVector(4, {1: 1, 3: 2}) dv = DenseVector(array([1., 2., 3., 4.])) lst = DenseVector([4, 3, 2, 1]) lst1 = [4, 3, 2, 1] arr = pyarray.array('d', [0, 2, 1, 3]) narr = array([0, 2, 1, 3]) - self.assertEqual(15.0, _squared_distance(sv, dv)) - self.assertEqual(25.0, _squared_distance(sv, lst)) - self.assertEqual(20.0, _squared_distance(dv, lst)) - self.assertEqual(15.0, _squared_distance(dv, sv)) - self.assertEqual(25.0, _squared_distance(lst, sv)) - self.assertEqual(20.0, _squared_distance(lst, dv)) - self.assertEqual(0.0, _squared_distance(sv, sv)) - self.assertEqual(0.0, _squared_distance(dv, dv)) - self.assertEqual(0.0, _squared_distance(lst, lst)) - self.assertEqual(25.0, _squared_distance(sv, lst1)) - self.assertEqual(3.0, _squared_distance(sv, arr)) - self.assertEqual(3.0, _squared_distance(sv, narr)) + self.assertEqual(15.0, squared_distance(sv, dv)) + self.assertEqual(25.0, squared_distance(sv, lst)) + self.assertEqual(20.0, squared_distance(dv, lst)) + self.assertEqual(15.0, squared_distance(dv, sv)) + self.assertEqual(25.0, squared_distance(lst, sv)) + self.assertEqual(20.0, squared_distance(lst, dv)) + self.assertEqual(0.0, squared_distance(sv, sv)) + self.assertEqual(0.0, squared_distance(dv, dv)) + self.assertEqual(0.0, squared_distance(lst, lst)) + self.assertEqual(25.0, squared_distance(sv, lst1)) + self.assertEqual(3.0, squared_distance(sv, arr)) + self.assertEqual(3.0, squared_distance(sv, narr)) def test_hash(self): v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) diff --git a/python/pyspark/mllib/tests/test_algorithms.py b/python/pyspark/mllib/tests/test_algorithms.py index cc3b64b1cb284..21a2d64087bc1 100644 --- a/python/pyspark/mllib/tests/test_algorithms.py +++ b/python/pyspark/mllib/tests/test_algorithms.py @@ -26,10 +26,8 @@ from pyspark.mllib.fpm import FPGrowth from pyspark.mllib.recommendation import Rating from pyspark.mllib.regression import LabeledPoint -from pyspark.testing.mllibutils import make_serializer, MLlibTestCase - - -ser = make_serializer() +from pyspark.serializers import PickleSerializer +from pyspark.testing.mllibutils import MLlibTestCase class ListTests(MLlibTestCase): @@ -265,6 +263,7 @@ def test_regression(self): class ALSTests(MLlibTestCase): def test_als_ratings_serialize(self): + ser = PickleSerializer() r = Rating(7, 1123, 3.14) jr = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(r))) nr = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jr))) @@ -273,6 +272,7 @@ def test_als_ratings_serialize(self): self.assertAlmostEqual(r.rating, nr.rating, 2) def test_als_ratings_id_long_error(self): + ser = PickleSerializer() r = Rating(1205640308657491975, 50233468418, 1.0) # rating user id exceeds max int value, should fail when pickled self.assertRaises(Py4JJavaError, self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads, diff --git a/python/pyspark/mllib/tests/test_linalg.py b/python/pyspark/mllib/tests/test_linalg.py index d0ebd9bc3db79..f26e28d1744de 100644 --- a/python/pyspark/mllib/tests/test_linalg.py +++ b/python/pyspark/mllib/tests/test_linalg.py @@ -22,33 +22,18 @@ from numpy import array, array_equal, zeros, arange, tile, ones, inf import pyspark.ml.linalg as newlinalg +from pyspark.serializers import PickleSerializer from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector, \ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT from pyspark.mllib.regression import LabeledPoint -from pyspark.testing.mllibutils import make_serializer, MLlibTestCase - -_have_scipy = False -try: - import scipy.sparse - _have_scipy = True -except: - # No SciPy, but that's okay, we'll skip those tests - pass - - -ser = make_serializer() - - -def _squared_distance(a, b): - if isinstance(a, Vector): - return a.squared_distance(b) - else: - return b.squared_distance(a) +from pyspark.testing.mllibutils import MLlibTestCase +from pyspark.testing.utils import have_scipy class VectorTests(MLlibTestCase): def _test_serialize(self, v): + ser = PickleSerializer() self.assertEqual(v, ser.loads(ser.dumps(v))) jvec = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(v))) nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvec))) @@ -87,24 +72,30 @@ def test_dot(self): self.assertEqual(7.0, sv.dot(arr)) def test_squared_distance(self): + def squared_distance(a, b): + if isinstance(a, Vector): + return a.squared_distance(b) + else: + return b.squared_distance(a) + sv = SparseVector(4, {1: 1, 3: 2}) dv = DenseVector(array([1., 2., 3., 4.])) lst = DenseVector([4, 3, 2, 1]) lst1 = [4, 3, 2, 1] arr = pyarray.array('d', [0, 2, 1, 3]) narr = array([0, 2, 1, 3]) - self.assertEqual(15.0, _squared_distance(sv, dv)) - self.assertEqual(25.0, _squared_distance(sv, lst)) - self.assertEqual(20.0, _squared_distance(dv, lst)) - self.assertEqual(15.0, _squared_distance(dv, sv)) - self.assertEqual(25.0, _squared_distance(lst, sv)) - self.assertEqual(20.0, _squared_distance(lst, dv)) - self.assertEqual(0.0, _squared_distance(sv, sv)) - self.assertEqual(0.0, _squared_distance(dv, dv)) - self.assertEqual(0.0, _squared_distance(lst, lst)) - self.assertEqual(25.0, _squared_distance(sv, lst1)) - self.assertEqual(3.0, _squared_distance(sv, arr)) - self.assertEqual(3.0, _squared_distance(sv, narr)) + self.assertEqual(15.0, squared_distance(sv, dv)) + self.assertEqual(25.0, squared_distance(sv, lst)) + self.assertEqual(20.0, squared_distance(dv, lst)) + self.assertEqual(15.0, squared_distance(dv, sv)) + self.assertEqual(25.0, squared_distance(lst, sv)) + self.assertEqual(20.0, squared_distance(lst, dv)) + self.assertEqual(0.0, squared_distance(sv, sv)) + self.assertEqual(0.0, squared_distance(dv, dv)) + self.assertEqual(0.0, squared_distance(lst, lst)) + self.assertEqual(25.0, squared_distance(sv, lst1)) + self.assertEqual(3.0, squared_distance(sv, arr)) + self.assertEqual(3.0, squared_distance(sv, narr)) def test_hash(self): v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) @@ -466,7 +457,7 @@ def test_infer_schema(self): raise ValueError("Expected a matrix but got type %r" % type(m)) -@unittest.skipIf(not _have_scipy, "SciPy not installed") +@unittest.skipIf(not have_scipy, "SciPy not installed") class SciPyTests(MLlibTestCase): """ @@ -476,6 +467,8 @@ class SciPyTests(MLlibTestCase): def test_serialize(self): from scipy.sparse import lil_matrix + + ser = PickleSerializer() lil = lil_matrix((4, 1)) lil[1, 0] = 1 lil[3, 0] = 2 @@ -621,13 +614,10 @@ def test_regression(self): if __name__ == "__main__": from pyspark.mllib.tests.test_linalg import * - if not _have_scipy: - print("NOTE: Skipping SciPy tests as it does not seem to be installed") + try: import xmlrunner testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') except ImportError: testRunner = None unittest.main(testRunner=testRunner, verbosity=2) - if not _have_scipy: - print("NOTE: SciPy tests were skipped as it does not seem to be installed") diff --git a/python/pyspark/testing/mllibutils.py b/python/pyspark/testing/mllibutils.py index 25f1bba8d37ac..c09fb50482e49 100644 --- a/python/pyspark/testing/mllibutils.py +++ b/python/pyspark/testing/mllibutils.py @@ -18,14 +18,9 @@ import unittest from pyspark import SparkContext -from pyspark.serializers import PickleSerializer from pyspark.sql import SparkSession -def make_serializer(): - return PickleSerializer() - - class MLlibTestCase(unittest.TestCase): def setUp(self): self.sc = SparkContext('local[4]', "MLlib tests") From a24e1a126c55fc06f5867c0e5e5b0ee71201e018 Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Mon, 3 Dec 2018 14:57:18 -0800 Subject: [PATCH 2208/2461] [SPARK-26256][K8S] Fix labels for pod deletion ## What changes were proposed in this pull request? Adds proper labels when deleting executor pods. ## How was this patch tested? Manually with tests. Closes #23209 from skonto/fix-deletion-labels. Authored-by: Stavros Kontopoulos Signed-off-by: Marcelo Vanzin --- .../scheduler/cluster/k8s/ExecutorPodsAllocator.scala | 2 ++ .../cluster/k8s/ExecutorPodsAllocatorSuite.scala | 10 +++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala index ef4cbdf162c6c..2f0f949566d6a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala @@ -93,6 +93,8 @@ private[spark] class ExecutorPodsAllocator( Utils.tryLogNonFatalError { kubernetesClient .pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) .withLabel(SPARK_EXECUTOR_ID_LABEL, execId.toString) .delete() } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala index ddf9f67a0727d..303e24b8f4977 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala @@ -138,7 +138,15 @@ class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter { snapshotsStore.notifySubscribers() snapshotsStore.replaceSnapshot(Seq.empty[Pod]) waitForExecutorPodsClock.setTime(podCreationTimeout + 1) - when(podOperations.withLabel(SPARK_EXECUTOR_ID_LABEL, "1")).thenReturn(labeledPods) + when(podOperations + .withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)) + .thenReturn(podOperations) + when(podOperations + withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)) + .thenReturn(podOperations) + when(podOperations + .withLabel(SPARK_EXECUTOR_ID_LABEL, "1")) + .thenReturn(labeledPods) snapshotsStore.notifySubscribers() verify(labeledPods).delete() verify(podOperations).create(podWithAttachedContainerForId(2)) From 0889fbaf959e25ebb79e691692a02a93962727d0 Mon Sep 17 00:00:00 2001 From: Qi Shao Date: Mon, 3 Dec 2018 15:36:41 -0800 Subject: [PATCH 2209/2461] [SPARK-26083][K8S] Add Copy pyspark into corresponding dir cmd in pyspark Dockerfile When I try to run `./bin/pyspark` cmd in a pod in Kubernetes(image built without change from pyspark Dockerfile), I'm getting an error: ``` $SPARK_HOME/bin/pyspark --deploy-mode client --master k8s://https://$KUBERNETES_SERVICE_HOST:$KUBERNETES_SERVICE_PORT_HTTPS ... Python 2.7.15 (default, Aug 22 2018, 13:24:18) [GCC 6.4.0] on linux2 Type "help", "copyright", "credits" or "license" for more information. Could not open PYTHONSTARTUP IOError: [Errno 2] No such file or directory: '/opt/spark/python/pyspark/shell.py' ``` This is because `pyspark` folder doesn't exist under `/opt/spark/python/` ## What changes were proposed in this pull request? Added `COPY python/pyspark ${SPARK_HOME}/python/pyspark` to pyspark Dockerfile to resolve issue above. ## How was this patch tested? Google Kubernetes Engine Closes #23037 from AzureQ/master. Authored-by: Qi Shao Signed-off-by: Marcelo Vanzin --- bin/docker-image-tool.sh | 1 + .../docker/src/main/dockerfiles/spark/bindings/python/Dockerfile | 1 + 2 files changed, 2 insertions(+) diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index fbf9c9e448fd1..4f66137eb1c7a 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -107,6 +107,7 @@ function create_dev_build_context {( "$PYSPARK_CTX/kubernetes/dockerfiles" mkdir "$PYSPARK_CTX/python" cp -r "python/lib" "$PYSPARK_CTX/python/lib" + cp -r "python/pyspark" "$PYSPARK_CTX/python/pyspark" local R_CTX="$CTX_DIR/sparkr" mkdir -p "$R_CTX/kubernetes" diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile index de1a0617b1cc5..36b91eb9a3aac 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile @@ -38,6 +38,7 @@ RUN apk add --no-cache python && \ # Removed the .cache to save space rm -r /root/.cache +COPY python/pyspark ${SPARK_HOME}/python/pyspark COPY python/lib ${SPARK_HOME}/python/lib ENV PYTHONPATH ${SPARK_HOME}/python/lib/pyspark.zip:${SPARK_HOME}/python/lib/py4j-*.zip From f7af4a1965b1052d3c77505ab1b660a294757bed Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 4 Dec 2018 12:14:38 +0800 Subject: [PATCH 2210/2461] [SPARK-25498][SQL][FOLLOW-UP] Return an empty config set when regenerating the golden files ## What changes were proposed in this pull request? This pr is to return an empty config set when regenerating the golden files in `SQLQueryTestSuite`. This is the follow-up of #22512. ## How was this patch tested? N/A Closes #23212 from maropu/SPARK-25498-FOLLOWUP. Authored-by: Takeshi Yamamuro Signed-off-by: Wenchen Fan --- .../scala/org/apache/spark/sql/SQLQueryTestSuite.scala | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index fd180ce2380a3..cf4585bf7ac6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -154,14 +154,10 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { val (conf, value) = confAndValue.span(_ != '=') conf.trim -> value.substring(1).trim }) - // When we are regenerating the golden files we don't need to run all the configs as they + // When we are regenerating the golden files, we don't need to set any config as they // all need to return the same result if (regenerateGoldenFiles) { - if (configs.nonEmpty) { - configs.take(1) - } else { - Array.empty[Array[(String, String)]] - } + Array.empty[Array[(String, String)]] } else { if (configs.nonEmpty) { codegenConfigSets.flatMap { codegenConfig => From b4dea313c45042e4094d14ebdeb8ad27be4cc695 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 3 Dec 2018 23:00:02 -0800 Subject: [PATCH 2211/2461] [SPARK-25573] Combine resolveExpression and resolve in the Analyzer ## What changes were proposed in this pull request? Currently in the Analyzer, we have two methods 1) Resolve 2)ResolveExpressions that are called at different code paths to resolve attributes, column ordinal and extract value expressions. ~~In this PR, we combine the two into one method to make sure, there is only one method that is tasked with resolving the attributes.~~ Update the description of the methods and use better names to make it easier to know when to make use of one method vs the other. ## How was this patch tested? Existing tests. Closes #22899 from dilipbiswal/SPARK-25573-final. Authored-by: Dilip Biswal Signed-off-by: gatorsmile --- .../sql/catalyst/analysis/Analyzer.scala | 97 +++++++++++++------ 1 file changed, 66 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b977fa07db5c4..777053168a056 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -883,21 +883,38 @@ class Analyzer( } } - private def resolve(e: Expression, q: LogicalPlan): Expression = e match { - case f: LambdaFunction if !f.bound => f - case u @ UnresolvedAttribute(nameParts) => - // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = - withPosition(u) { - q.resolveChildren(nameParts, resolver) - .orElse(resolveLiteralFunction(nameParts, u, q)) - .getOrElse(u) - } - logDebug(s"Resolving $u to $result") - result - case UnresolvedExtractValue(child, fieldExpr) if child.resolved => - ExtractValue(child, fieldExpr, resolver) - case _ => e.mapChildren(resolve(_, q)) + /** + * Resolves the attribute and extract value expressions(s) by traversing the + * input expression in top down manner. The traversal is done in top-down manner as + * we need to skip over unbound lamda function expression. The lamda expressions are + * resolved in a different rule [[ResolveLambdaVariables]] + * + * Example : + * SELECT transform(array(1, 2, 3), (x, i) -> x + i)" + * + * In the case above, x and i are resolved as lamda variables in [[ResolveLambdaVariables]] + * + * Note : In this routine, the unresolved attributes are resolved from the input plan's + * children attributes. + */ + private def resolveExpressionTopDown(e: Expression, q: LogicalPlan): Expression = { + if (e.resolved) return e + e match { + case f: LambdaFunction if !f.bound => f + case u @ UnresolvedAttribute(nameParts) => + // Leave unchanged if resolution fails. Hopefully will be resolved next round. + val result = + withPosition(u) { + q.resolveChildren(nameParts, resolver) + .orElse(resolveLiteralFunction(nameParts, u, q)) + .getOrElse(u) + } + logDebug(s"Resolving $u to $result") + result + case UnresolvedExtractValue(child, fieldExpr) if child.resolved => + ExtractValue(child, fieldExpr, resolver) + case _ => e.mapChildren(resolveExpressionTopDown(_, q)) + } } def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { @@ -936,7 +953,7 @@ class Analyzer( // we still have chance to resolve it based on its descendants case s @ Sort(ordering, global, child) if child.resolved && !s.resolved => val newOrdering = - ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder]) + ordering.map(order => resolveExpressionBottomUp(order, child).asInstanceOf[SortOrder]) Sort(newOrdering, global, child) // A special case for Generate, because the output of Generate should not be resolved by @@ -944,7 +961,7 @@ class Analyzer( case g @ Generate(generator, _, _, _, _, _) if generator.resolved => g case g @ Generate(generator, join, outer, qualifier, output, child) => - val newG = resolveExpression(generator, child, throws = true) + val newG = resolveExpressionBottomUp(generator, child, throws = true) if (newG.fastEquals(generator)) { g } else { @@ -959,11 +976,11 @@ class Analyzer( // `AppendColumns`, because `AppendColumns`'s serializer might produce conflict attribute // names leading to ambiguous references exception. case a @ Aggregate(groupingExprs, aggExprs, appendColumns: AppendColumns) => - a.mapExpressions(resolve(_, appendColumns)) + a.mapExpressions(resolveExpressionTopDown(_, appendColumns)) case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") - q.mapExpressions(resolve(_, q)) + q.mapExpressions(resolveExpressionTopDown(_, q)) } def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = { @@ -1060,7 +1077,22 @@ class Analyzer( func.map(wrapper) } - protected[sql] def resolveExpression( + /** + * Resolves the attribute, column value and extract value expressions(s) by traversing the + * input expression in bottom-up manner. In order to resolve the nested complex type fields + * correctly, this function makes use of `throws` parameter to control when to raise an + * AnalysisException. + * + * Example : + * SELECT a.b FROM t ORDER BY b[0].d + * + * In the above example, in b needs to be resolved before d can be resolved. Given we are + * doing a bottom up traversal, it will first attempt to resolve d and fail as b has not + * been resolved yet. If `throws` is false, this function will handle the exception by + * returning the original attribute. In this case `d` will be resolved in subsequent passes + * after `b` is resolved. + */ + protected[sql] def resolveExpressionBottomUp( expr: Expression, plan: LogicalPlan, throws: Boolean = false): Expression = { @@ -1073,11 +1105,14 @@ class Analyzer( expr transformUp { case GetColumnByOrdinal(ordinal, _) => plan.output(ordinal) case u @ UnresolvedAttribute(nameParts) => - withPosition(u) { - plan.resolve(nameParts, resolver) - .orElse(resolveLiteralFunction(nameParts, u, plan)) - .getOrElse(u) - } + val result = + withPosition(u) { + plan.resolve(nameParts, resolver) + .orElse(resolveLiteralFunction(nameParts, u, plan)) + .getOrElse(u) + } + logDebug(s"Resolving $u to $result") + result case UnresolvedExtractValue(child, fieldName) if child.resolved => ExtractValue(child, fieldName, resolver) } @@ -1223,7 +1258,7 @@ class Analyzer( plan match { case p: Project => // Resolving expressions against current plan. - val maybeResolvedExprs = exprs.map(resolveExpression(_, p)) + val maybeResolvedExprs = exprs.map(resolveExpressionBottomUp(_, p)) // Recursively resolving expressions on the child of current plan. val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child) // If some attributes used by expressions are resolvable only on the rewritten child @@ -1232,7 +1267,7 @@ class Analyzer( (newExprs, Project(p.projectList ++ missingAttrs, newChild)) case a @ Aggregate(groupExprs, aggExprs, child) => - val maybeResolvedExprs = exprs.map(resolveExpression(_, a)) + val maybeResolvedExprs = exprs.map(resolveExpressionBottomUp(_, a)) val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child) val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet) if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) { @@ -1244,20 +1279,20 @@ class Analyzer( } case g: Generate => - val maybeResolvedExprs = exprs.map(resolveExpression(_, g)) + val maybeResolvedExprs = exprs.map(resolveExpressionBottomUp(_, g)) val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, g.child) (newExprs, g.copy(unrequiredChildIndex = Nil, child = newChild)) // For `Distinct` and `SubqueryAlias`, we can't recursively resolve and add attributes // via its children. case u: UnaryNode if !u.isInstanceOf[Distinct] && !u.isInstanceOf[SubqueryAlias] => - val maybeResolvedExprs = exprs.map(resolveExpression(_, u)) + val maybeResolvedExprs = exprs.map(resolveExpressionBottomUp(_, u)) val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, u.child) (newExprs, u.withNewChildren(Seq(newChild))) // For other operators, we can't recursively resolve and add attributes via its children. case other => - (exprs.map(resolveExpression(_, other)), other) + (exprs.map(resolveExpressionBottomUp(_, other)), other) } } } @@ -2387,7 +2422,7 @@ class Analyzer( } validateTopLevelTupleFields(deserializer, inputs) - val resolved = resolveExpression( + val resolved = resolveExpressionBottomUp( deserializer, LocalRelation(inputs), throws = true) val result = resolved transformDown { case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved => From 26128484228089c74517cd15cef0bb4166a4186f Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 4 Dec 2018 20:20:29 +0800 Subject: [PATCH 2212/2461] [SPARK-25374][SQL] SafeProjection supports fallback to an interpreted mode ## What changes were proposed in this pull request? In SPARK-23711, we have implemented the expression fallback logic to an interpreted mode. So, this pr fixed code to support the same fallback mode in `SafeProjection` based on `CodeGeneratorWithInterpretedFallback`. ## How was this patch tested? Add tests in `CodeGeneratorWithInterpretedFallbackSuite` and `UnsafeRowConverterSuite`. Closes #22468 from maropu/SPARK-25374-3. Authored-by: Takeshi Yamamuro Signed-off-by: Wenchen Fan --- .../catalyst/encoders/ExpressionEncoder.scala | 2 +- .../InterpretedSafeProjection.scala | 125 ++++++++++++++++++ .../sql/catalyst/expressions/Projection.scala | 34 +++-- .../expressions/CodeGenerationSuite.scala | 2 +- ...eneratorWithInterpretedFallbackSuite.scala | 15 +++ .../expressions/ExpressionEvalHelper.scala | 4 +- .../expressions/MutableProjectionSuite.scala | 2 +- .../expressions/UnsafeRowConverterSuite.scala | 89 ++++++++++++- .../DeclarativeAggregateEvaluator.scala | 11 +- .../codegen/GeneratedProjectionSuite.scala | 8 +- .../util/ArrayDataIndexedSeqSuite.scala | 4 +- .../org/apache/spark/sql/types/TestUDT.scala | 61 +++++++++ .../spark/sql/FileBasedDataSourceSuite.scala | 4 +- .../spark/sql/UserDefinedTypeSuite.scala | 105 +++++---------- .../datasources/json/JsonSuite.scala | 4 +- .../datasources/orc/OrcQuerySuite.scala | 4 +- .../execution/AggregationQuerySuite.scala | 2 +- .../execution/ObjectHashAggregateSuite.scala | 4 +- .../sql/sources/HadoopFsRelationTest.scala | 2 +- 19 files changed, 371 insertions(+), 111 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/types/TestUDT.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 589e215c55e44..fbf0bd68b9584 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -302,7 +302,7 @@ case class ExpressionEncoder[T]( private lazy val inputRow = new GenericInternalRow(1) @transient - private lazy val constructProjection = GenerateSafeProjection.generate(deserializer :: Nil) + private lazy val constructProjection = SafeProjection.create(deserializer :: Nil) /** * Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala new file mode 100644 index 0000000000000..70789dac1d87a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.types._ + + +/** + * An interpreted version of a safe projection. + * + * @param expressions that produces the resulting fields. These expressions must be bound + * to a schema. + */ +class InterpretedSafeProjection(expressions: Seq[Expression]) extends Projection { + + private[this] val mutableRow = new SpecificInternalRow(expressions.map(_.dataType)) + + private[this] val exprsWithWriters = expressions.zipWithIndex.filter { + case (NoOp, _) => false + case _ => true + }.map { case (e, i) => + val converter = generateSafeValueConverter(e.dataType) + val writer = InternalRow.getWriter(i, e.dataType) + val f = if (!e.nullable) { + (v: Any) => writer(mutableRow, converter(v)) + } else { + (v: Any) => { + if (v == null) { + mutableRow.setNullAt(i) + } else { + writer(mutableRow, converter(v)) + } + } + } + (e, f) + } + + private def generateSafeValueConverter(dt: DataType): Any => Any = dt match { + case ArrayType(elemType, _) => + val elementConverter = generateSafeValueConverter(elemType) + v => { + val arrayValue = v.asInstanceOf[ArrayData] + val result = new Array[Any](arrayValue.numElements()) + arrayValue.foreach(elemType, (i, e) => { + result(i) = elementConverter(e) + }) + new GenericArrayData(result) + } + + case st: StructType => + val fieldTypes = st.fields.map(_.dataType) + val fieldConverters = fieldTypes.map(generateSafeValueConverter) + v => { + val row = v.asInstanceOf[InternalRow] + val ar = new Array[Any](row.numFields) + var idx = 0 + while (idx < row.numFields) { + ar(idx) = fieldConverters(idx)(row.get(idx, fieldTypes(idx))) + idx += 1 + } + new GenericInternalRow(ar) + } + + case MapType(keyType, valueType, _) => + lazy val keyConverter = generateSafeValueConverter(keyType) + lazy val valueConverter = generateSafeValueConverter(valueType) + v => { + val mapValue = v.asInstanceOf[MapData] + val keys = mapValue.keyArray().toArray[Any](keyType) + val values = mapValue.valueArray().toArray[Any](valueType) + val convertedKeys = keys.map(keyConverter) + val convertedValues = values.map(valueConverter) + ArrayBasedMapData(convertedKeys, convertedValues) + } + + case udt: UserDefinedType[_] => + generateSafeValueConverter(udt.sqlType) + + case _ => identity + } + + override def apply(row: InternalRow): InternalRow = { + var i = 0 + while (i < exprsWithWriters.length) { + val (expr, writer) = exprsWithWriters(i) + writer(expr.eval(row)) + i += 1 + } + mutableRow + } +} + +/** + * Helper functions for creating an [[InterpretedSafeProjection]]. + */ +object InterpretedSafeProjection { + + /** + * Returns an [[SafeProjection]] for given sequence of bound Expressions. + */ + def createProjection(exprs: Seq[Expression]): Projection = { + // We need to make sure that we do not reuse stateful expressions. + val cleanedExpressions = exprs.map(_.transform { + case s: Stateful => s.freshCopy() + }) + new InterpretedSafeProjection(cleanedExpressions) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 792646cf9f10c..b48f7ba655b2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -169,26 +169,40 @@ object UnsafeProjection /** * A projection that could turn UnsafeRow into GenericInternalRow */ -object FromUnsafeProjection { +object SafeProjection extends CodeGeneratorWithInterpretedFallback[Seq[Expression], Projection] { + + override protected def createCodeGeneratedObject(in: Seq[Expression]): Projection = { + GenerateSafeProjection.generate(in) + } + + override protected def createInterpretedObject(in: Seq[Expression]): Projection = { + InterpretedSafeProjection.createProjection(in) + } /** - * Returns a Projection for given StructType. + * Returns a SafeProjection for given StructType. */ - def apply(schema: StructType): Projection = { - apply(schema.fields.map(_.dataType)) + def create(schema: StructType): Projection = create(schema.fields.map(_.dataType)) + + /** + * Returns a SafeProjection for given Array of DataTypes. + */ + def create(fields: Array[DataType]): Projection = { + createObject(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))) } /** - * Returns an UnsafeProjection for given Array of DataTypes. + * Returns a SafeProjection for given sequence of Expressions (bounded). */ - def apply(fields: Seq[DataType]): Projection = { - create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))) + def create(exprs: Seq[Expression]): Projection = { + createObject(exprs) } /** - * Returns a Projection for given sequence of Expressions (bounded). + * Returns a SafeProjection for given sequence of Expressions, which will be bound to + * `inputSchema`. */ - private def create(exprs: Seq[Expression]): Projection = { - GenerateSafeProjection.generate(exprs) + def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { + create(toBoundExprs(exprs, inputSchema)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 7843003a4aac3..7e6fe5b4e2069 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -251,7 +251,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { UTF8String.fromString("c")) assert(unsafeRow.getStruct(3, 1).getStruct(0, 2).getInt(1) === 3) - val fromUnsafe = FromUnsafeProjection(schema) + val fromUnsafe = SafeProjection.create(schema) val internalRow2 = fromUnsafe(unsafeRow) assert(internalRow === internalRow2) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala index 6ea3b05ff9c1e..da5bddb0c09fd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala @@ -106,4 +106,19 @@ class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanT assert(proj(input).toSeq(StructType.fromDDL("c0 int, c1 int")) === expected) } } + + test("SPARK-25374 Correctly handles NoOp in SafeProjection") { + val exprs = Seq(Add(BoundReference(0, IntegerType, nullable = true), Literal.create(1)), NoOp) + val input = InternalRow.fromSeq(1 :: 1 :: Nil) + val expected = 2 :: null :: Nil + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) { + val proj = SafeProjection.createObject(exprs) + assert(proj(input).toSeq(StructType.fromDDL("c0 int, c1 int")) === expected) + } + + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> noCodegen) { + val proj = SafeProjection.createObject(exprs) + assert(proj(input).toSeq(StructType.fromDDL("c0 int, c1 int")) === expected) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a7282e1b1cadc..b4fd170467d81 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -321,8 +321,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) plan.initialize(0) - actual = FromUnsafeProjection(expression.dataType :: Nil)( - plan(inputRow)).get(0, expression.dataType) + val ref = new BoundReference(0, expression.dataType, nullable = true) + actual = GenerateSafeProjection.generate(ref :: Nil)(plan(inputRow)).get(0, expression.dataType) assert(checkResult(actual, expected, expression)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala index 2db1c3b98819c..0d594eb10962e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala @@ -51,7 +51,7 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { val unsafeBuffer = UnsafeRow.createFromByteArray(numBytes, fixedLengthTypes.length) val proj = createMutableProjection(fixedLengthTypes) val projUnsafeRow = proj.target(unsafeBuffer)(inputRow) - assert(FromUnsafeProjection.apply(fixedLengthTypes)(projUnsafeRow) === inputRow) + assert(SafeProjection.create(fixedLengthTypes)(projUnsafeRow) === inputRow) } testBothCodegenAndInterpreted("variable-length types") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 268372b5d0504..ecb8047459b0c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types.{IntegerType, LongType, _} import org.apache.spark.unsafe.array.ByteArrayMethods -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestBase with ExpressionEvalHelper { @@ -535,4 +535,91 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + roundedSize(field1.getSizeInBytes) + roundedSize(field2.getSizeInBytes)) } + + testBothCodegenAndInterpreted("SPARK-25374 converts back into safe representation") { + def convertBackToInternalRow(inputRow: InternalRow, fields: Array[DataType]): InternalRow = { + val unsafeProj = UnsafeProjection.create(fields) + val unsafeRow = unsafeProj(inputRow) + val safeProj = SafeProjection.create(fields) + safeProj(unsafeRow) + } + + // Simple tests + val inputRow = InternalRow.fromSeq(Seq( + false, 3.toByte, 15.toShort, -83, 129L, 1.0f, 8.0, UTF8String.fromString("test"), + Decimal(255), CalendarInterval.fromString("interval 1 day"), Array[Byte](1, 2) + )) + val fields1 = Array( + BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, + DoubleType, StringType, DecimalType.defaultConcreteType, CalendarIntervalType, + BinaryType) + + assert(convertBackToInternalRow(inputRow, fields1) === inputRow) + + // Array tests + val arrayRow = InternalRow.fromSeq(Seq( + createArray(1, 2, 3), + createArray( + createArray(Seq("a", "b", "c").map(UTF8String.fromString): _*), + createArray(Seq("d").map(UTF8String.fromString): _*)) + )) + val fields2 = Array[DataType]( + ArrayType(IntegerType), + ArrayType(ArrayType(StringType))) + + assert(convertBackToInternalRow(arrayRow, fields2) === arrayRow) + + // Struct tests + val structRow = InternalRow.fromSeq(Seq( + InternalRow.fromSeq(Seq[Any](1, 4.0)), + InternalRow.fromSeq(Seq( + UTF8String.fromString("test"), + InternalRow.fromSeq(Seq( + 1, + createArray(Seq("2", "3").map(UTF8String.fromString): _*) + )) + )) + )) + val fields3 = Array[DataType]( + StructType( + StructField("c0", IntegerType) :: + StructField("c1", DoubleType) :: + Nil), + StructType( + StructField("c2", StringType) :: + StructField("c3", StructType( + StructField("c4", IntegerType) :: + StructField("c5", ArrayType(StringType)) :: + Nil)) :: + Nil)) + + assert(convertBackToInternalRow(structRow, fields3) === structRow) + + // Map tests + val mapRow = InternalRow.fromSeq(Seq( + createMap(Seq("k1", "k2").map(UTF8String.fromString): _*)(1, 2), + createMap( + createMap(3, 5)(Seq("v1", "v2").map(UTF8String.fromString): _*), + createMap(7, 9)(Seq("v3", "v4").map(UTF8String.fromString): _*) + )( + createMap(Seq("k3", "k4").map(UTF8String.fromString): _*)(3.toShort, 4.toShort), + createMap(Seq("k5", "k6").map(UTF8String.fromString): _*)(5.toShort, 6.toShort) + ))) + val fields4 = Array[DataType]( + MapType(StringType, IntegerType), + MapType(MapType(IntegerType, StringType), MapType(StringType, ShortType))) + + val mapResultRow = convertBackToInternalRow(mapRow, fields4) + val mapExpectedRow = mapRow + checkResult(mapExpectedRow, mapResultRow, + exprDataType = StructType(fields4.zipWithIndex.map(f => StructField(s"c${f._2}", f._1))), + exprNullable = false) + + // UDT tests + val vector = new TestUDT.MyDenseVector(Array(1.0, 3.0, 5.0, 7.0, 9.0)) + val udt = new TestUDT.MyDenseVectorUDT() + val udtRow = InternalRow.fromSeq(Seq(udt.serialize(vector))) + val fields5 = Array[DataType](udt) + assert(convertBackToInternalRow(udtRow, fields5) === udtRow) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala index 614f24db0aafb..b0f55b3b5c443 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala @@ -17,25 +17,24 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow} -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection +import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow, SafeProjection} /** * Evaluator for a [[DeclarativeAggregate]]. */ case class DeclarativeAggregateEvaluator(function: DeclarativeAggregate, input: Seq[Attribute]) { - lazy val initializer = GenerateSafeProjection.generate(function.initialValues) + lazy val initializer = SafeProjection.create(function.initialValues) - lazy val updater = GenerateSafeProjection.generate( + lazy val updater = SafeProjection.create( function.updateExpressions, function.aggBufferAttributes ++ input) - lazy val merger = GenerateSafeProjection.generate( + lazy val merger = SafeProjection.create( function.mergeExpressions, function.aggBufferAttributes ++ function.inputAggBufferAttributes) - lazy val evaluator = GenerateSafeProjection.generate( + lazy val evaluator = SafeProjection.create( function.evaluateExpression :: Nil, function.aggBufferAttributes) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 2c45b3b0c73d1..4c9bcfe8f93a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -58,7 +58,7 @@ class GeneratedProjectionSuite extends SparkFunSuite { } // test generated SafeProjection - val safeProj = FromUnsafeProjection(nestedSchema) + val safeProj = SafeProjection.create(nestedSchema) val result = safeProj(unsafe) // Can't compare GenericInternalRow with JoinedRow directly (0 until N).foreach { i => @@ -109,7 +109,7 @@ class GeneratedProjectionSuite extends SparkFunSuite { } // test generated SafeProjection - val safeProj = FromUnsafeProjection(nestedSchema) + val safeProj = SafeProjection.create(nestedSchema) val result = safeProj(unsafe) // Can't compare GenericInternalRow with JoinedRow directly (0 until N).foreach { i => @@ -147,7 +147,7 @@ class GeneratedProjectionSuite extends SparkFunSuite { assert(unsafeRow.getArray(1).getBinary(1) === null) assert(java.util.Arrays.equals(unsafeRow.getArray(1).getBinary(2), Array[Byte](3, 4))) - val safeProj = FromUnsafeProjection(fields) + val safeProj = SafeProjection.create(fields) val row2 = safeProj(unsafeRow) assert(row2 === row) } @@ -233,7 +233,7 @@ class GeneratedProjectionSuite extends SparkFunSuite { val nestedSchema = StructType( Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema) - val safeProj = FromUnsafeProjection(nestedSchema) + val safeProj = SafeProjection.create(nestedSchema) val result = safeProj(nested) // test generated MutableProjection diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala index 6400898343ae7..da71e3a4d53e2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala @@ -22,7 +22,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder} -import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeArrayData, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{SafeProjection, UnsafeArrayData, UnsafeProjection} import org.apache.spark.sql.types._ class ArrayDataIndexedSeqSuite extends SparkFunSuite { @@ -77,7 +77,7 @@ class ArrayDataIndexedSeqSuite extends SparkFunSuite { val internalRow = rowConverter.toRow(row) val unsafeRowConverter = UnsafeProjection.create(schema) - val safeRowConverter = FromUnsafeProjection(schema) + val safeRowConverter = SafeProjection.create(schema) val unsafeRow = unsafeRowConverter(internalRow) val safeRow = safeRowConverter(unsafeRow) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/TestUDT.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/TestUDT.scala new file mode 100644 index 0000000000000..1be8ee9dfa92b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/TestUDT.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} + + +// Wrapped in an object to check Scala compatibility. See SPARK-13929 +object TestUDT { + + @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) + private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { + override def hashCode(): Int = java.util.Arrays.hashCode(data) + + override def equals(other: Any): Boolean = other match { + case v: MyDenseVector => java.util.Arrays.equals(this.data, v.data) + case _ => false + } + + override def toString: String = data.mkString("(", ", ", ")") + } + + private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { + + override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) + + override def serialize(features: MyDenseVector): ArrayData = { + new GenericArrayData(features.data.map(_.asInstanceOf[Any])) + } + + override def deserialize(datum: Any): MyDenseVector = { + datum match { + case data: ArrayData => + new MyDenseVector(data.toDoubleArray()) + } + } + + override def userClass: Class[MyDenseVector] = classOf[MyDenseVector] + + private[spark] override def asNullable: MyDenseVectorUDT = this + + override def hashCode(): Int = getClass.hashCode() + + override def equals(other: Any): Boolean = other.isInstanceOf[MyDenseVectorUDT] + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 64b42c32b8b1b..54299e9808bf1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -312,13 +312,13 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo assert(msg.contains("CSV data source does not support array data type")) msg = intercept[AnalysisException] { - Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors") + Seq((1, new TestUDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors") .write.mode("overwrite").csv(csvDir) }.getMessage assert(msg.contains("CSV data source does not support array data type")) msg = intercept[AnalysisException] { - val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil) + val schema = StructType(StructField("a", new TestUDT.MyDenseVectorUDT(), true) :: Nil) spark.range(1).write.mode("overwrite").csv(csvDir) spark.read.schema(schema).csv(csvDir).collect() }.getMessage diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index cf956316057eb..6628d36ffc702 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -20,56 +20,14 @@ package org.apache.spark.sql import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Cast, ExpressionEvalHelper, GenericInternalRow, Literal} -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -private[sql] case class MyLabeledPoint(label: Double, features: UDT.MyDenseVector) { +private[sql] case class MyLabeledPoint(label: Double, features: TestUDT.MyDenseVector) { def getLabel: Double = label - def getFeatures: UDT.MyDenseVector = features -} - -// Wrapped in an object to check Scala compatibility. See SPARK-13929 -object UDT { - - @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) - private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { - override def hashCode(): Int = java.util.Arrays.hashCode(data) - - override def equals(other: Any): Boolean = other match { - case v: MyDenseVector => java.util.Arrays.equals(this.data, v.data) - case _ => false - } - - override def toString: String = data.mkString("(", ", ", ")") - } - - private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { - - override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) - - override def serialize(features: MyDenseVector): ArrayData = { - new GenericArrayData(features.data.map(_.asInstanceOf[Any])) - } - - override def deserialize(datum: Any): MyDenseVector = { - datum match { - case data: ArrayData => - new MyDenseVector(data.toDoubleArray()) - } - } - - override def userClass: Class[MyDenseVector] = classOf[MyDenseVector] - - private[spark] override def asNullable: MyDenseVectorUDT = this - - override def hashCode(): Int = getClass.hashCode() - - override def equals(other: Any): Boolean = other.isInstanceOf[MyDenseVectorUDT] - } - + def getFeatures: TestUDT.MyDenseVector = features } // object and classes to test SPARK-19311 @@ -148,12 +106,12 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT import testImplicits._ private lazy val pointsRDD = Seq( - MyLabeledPoint(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))), - MyLabeledPoint(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))).toDF() + MyLabeledPoint(1.0, new TestUDT.MyDenseVector(Array(0.1, 1.0))), + MyLabeledPoint(0.0, new TestUDT.MyDenseVector(Array(0.2, 2.0)))).toDF() private lazy val pointsRDD2 = Seq( - MyLabeledPoint(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))), - MyLabeledPoint(0.0, new UDT.MyDenseVector(Array(0.3, 3.0)))).toDF() + MyLabeledPoint(1.0, new TestUDT.MyDenseVector(Array(0.1, 1.0))), + MyLabeledPoint(0.0, new TestUDT.MyDenseVector(Array(0.3, 3.0)))).toDF() test("register user type: MyDenseVector for MyLabeledPoint") { val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v } @@ -162,16 +120,17 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT assert(labelsArrays.contains(1.0)) assert(labelsArrays.contains(0.0)) - val features: RDD[UDT.MyDenseVector] = - pointsRDD.select('features).rdd.map { case Row(v: UDT.MyDenseVector) => v } - val featuresArrays: Array[UDT.MyDenseVector] = features.collect() + val features: RDD[TestUDT.MyDenseVector] = + pointsRDD.select('features).rdd.map { case Row(v: TestUDT.MyDenseVector) => v } + val featuresArrays: Array[TestUDT.MyDenseVector] = features.collect() assert(featuresArrays.size === 2) - assert(featuresArrays.contains(new UDT.MyDenseVector(Array(0.1, 1.0)))) - assert(featuresArrays.contains(new UDT.MyDenseVector(Array(0.2, 2.0)))) + assert(featuresArrays.contains(new TestUDT.MyDenseVector(Array(0.1, 1.0)))) + assert(featuresArrays.contains(new TestUDT.MyDenseVector(Array(0.2, 2.0)))) } test("UDTs and UDFs") { - spark.udf.register("testType", (d: UDT.MyDenseVector) => d.isInstanceOf[UDT.MyDenseVector]) + spark.udf.register("testType", + (d: TestUDT.MyDenseVector) => d.isInstanceOf[TestUDT.MyDenseVector]) pointsRDD.createOrReplaceTempView("points") checkAnswer( sql("SELECT testType(features) from points"), @@ -185,8 +144,8 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT checkAnswer( spark.read.parquet(path), Seq( - Row(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))), - Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0))))) + Row(1.0, new TestUDT.MyDenseVector(Array(0.1, 1.0))), + Row(0.0, new TestUDT.MyDenseVector(Array(0.2, 2.0))))) } } @@ -197,17 +156,17 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT checkAnswer( spark.read.parquet(path), Seq( - Row(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))), - Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0))))) + Row(1.0, new TestUDT.MyDenseVector(Array(0.1, 1.0))), + Row(0.0, new TestUDT.MyDenseVector(Array(0.2, 2.0))))) } } // Tests to make sure that all operators correctly convert types on the way out. test("Local UDTs") { - val vec = new UDT.MyDenseVector(Array(0.1, 1.0)) + val vec = new TestUDT.MyDenseVector(Array(0.1, 1.0)) val df = Seq((1, vec)).toDF("int", "vec") - assert(vec === df.collect()(0).getAs[UDT.MyDenseVector](1)) - assert(vec === df.take(1)(0).getAs[UDT.MyDenseVector](1)) + assert(vec === df.collect()(0).getAs[TestUDT.MyDenseVector](1)) + assert(vec === df.take(1)(0).getAs[TestUDT.MyDenseVector](1)) checkAnswer(df.limit(1).groupBy('int).agg(first('vec)), Row(1, vec)) checkAnswer(df.orderBy('int).limit(1).groupBy('int).agg(first('vec)), Row(1, vec)) } @@ -219,14 +178,14 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT ) val schema = StructType(Seq( StructField("id", IntegerType, false), - StructField("vec", new UDT.MyDenseVectorUDT, false) + StructField("vec", new TestUDT.MyDenseVectorUDT, false) )) val jsonRDD = spark.read.schema(schema).json(data.toDS()) checkAnswer( jsonRDD, - Row(1, new UDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) :: - Row(2, new UDT.MyDenseVector(Array(2.25, 4.5, 8.75))) :: + Row(1, new TestUDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) :: + Row(2, new TestUDT.MyDenseVector(Array(2.25, 4.5, 8.75))) :: Nil ) } @@ -239,25 +198,25 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT val schema = StructType(Seq( StructField("id", IntegerType, false), - StructField("vec", new UDT.MyDenseVectorUDT, false) + StructField("vec", new TestUDT.MyDenseVectorUDT, false) )) val jsonDataset = spark.read.schema(schema).json(data.toDS()) - .as[(Int, UDT.MyDenseVector)] + .as[(Int, TestUDT.MyDenseVector)] checkDataset( jsonDataset, - (1, new UDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))), - (2, new UDT.MyDenseVector(Array(2.25, 4.5, 8.75))) + (1, new TestUDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))), + (2, new TestUDT.MyDenseVector(Array(2.25, 4.5, 8.75))) ) } test("SPARK-10472 UserDefinedType.typeName") { assert(IntegerType.typeName === "integer") - assert(new UDT.MyDenseVectorUDT().typeName === "mydensevector") + assert(new TestUDT.MyDenseVectorUDT().typeName === "mydensevector") } test("Catalyst type converter null handling for UDTs") { - val udt = new UDT.MyDenseVectorUDT() + val udt = new TestUDT.MyDenseVectorUDT() val toScalaConverter = CatalystTypeConverters.createToScalaConverter(udt) assert(toScalaConverter(null) === null) @@ -303,12 +262,12 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT test("except on UDT") { checkAnswer( pointsRDD.except(pointsRDD2), - Seq(Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0))))) + Seq(Row(0.0, new TestUDT.MyDenseVector(Array(0.2, 2.0))))) } test("SPARK-23054 Cast UserDefinedType to string") { - val udt = new UDT.MyDenseVectorUDT() - val vector = new UDT.MyDenseVector(Array(1.0, 3.0, 5.0, 7.0, 9.0)) + val udt = new TestUDT.MyDenseVectorUDT() + val vector = new TestUDT.MyDenseVector(Array(1.0, 3.0, 5.0, 7.0, 9.0)) val data = udt.serialize(vector) val ret = Cast(Literal(data, udt), StringType, None) checkEvaluation(ret, "(1.0, 3.0, 5.0, 7.0, 9.0)") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 9d23161c1f24e..dff37ca2d40f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1463,7 +1463,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), DateType, TimestampType, ArrayType(IntegerType), MapType(StringType, LongType), struct, - new UDT.MyDenseVectorUDT()) + new TestUDT.MyDenseVectorUDT()) val fields = dataTypes.zipWithIndex.map { case (dataType, index) => StructField(s"col$index", dataType, nullable = true) } @@ -1487,7 +1487,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Seq(2, 3, 4), Map("a string" -> 2000L), Row(4.75.toFloat, Seq(false, true)), - new UDT.MyDenseVector(Array(0.25, 2.25, 4.25))) + new TestUDT.MyDenseVector(Array(0.25, 2.25, 4.25))) val data = Row.fromSeq(Seq("Spark " + spark.sparkContext.version) ++ constantValues) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index 998b7b31dcd6a..918dbcdfa1cc5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, RecordReaderIterator} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.types.{IntegerType, StructType, TestUDT} import org.apache.spark.util.Utils case class AllDataTypesWithNonPrimitiveType( @@ -103,7 +103,7 @@ abstract class OrcQueryTest extends OrcTest { test("Read/write UserDefinedType") { withTempPath { path => - val data = Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))) + val data = Seq((1, new TestUDT.MyDenseVector(Array(0.25, 2.25, 4.25)))) val udtDF = data.toDF("id", "vectors") udtDF.write.orc(path.getAbsolutePath) val readBack = spark.read.schema(udtDF.schema).orc(path.getAbsolutePath) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index c65bf7c14c7a5..cfae2d82e273d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -884,7 +884,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), DateType, TimestampType, ArrayType(IntegerType), MapType(StringType, LongType), struct, - new UDT.MyDenseVectorUDT()) + new TestUDT.MyDenseVectorUDT()) // Right now, we will use SortAggregate to handle UDAFs. // UnsafeRow.mutableFieldTypes.asScala.toSeq will trigger SortAggregate to use // UnsafeRow as the aggregation buffer. While, dataTypes will trigger diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala index c9309197791bd..2391106cfb253 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala @@ -124,7 +124,7 @@ class ObjectHashAggregateSuite .add("f2", ArrayType(BooleanType), nullable = true), // UDT - new UDT.MyDenseVectorUDT(), + new TestUDT.MyDenseVectorUDT(), // Others StringType, @@ -259,7 +259,7 @@ class ObjectHashAggregateSuite StringType, BinaryType, NullType, BooleanType ) - val udt = new UDT.MyDenseVectorUDT() + val udt = new TestUDT.MyDenseVectorUDT() val fixedLengthTypes = builtinNumericTypes ++ Seq(BooleanType, NullType) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index 6bd59fde550de..6075f2c8877d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -115,7 +115,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes new StructType() .add("f1", FloatType, nullable = true) .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true), - new UDT.MyDenseVectorUDT() + new TestUDT.MyDenseVectorUDT() ).filter(supportsDataType) test(s"test all data types") { From 93f5592aa8c1254a93524fda81cf0e418c22cb2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BD=AD=E7=81=BF00244106?= <00244106@zte.intra> Date: Tue, 4 Dec 2018 22:08:16 +0900 Subject: [PATCH 2213/2461] [MINOR][SQL] Combine the same codes in test cases MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In the DDLSuit, there are four test cases have the same codes , writing a function can combine the same code. ## How was this patch tested? existing tests. Closes #23194 from CarolinePeng/Update_temp. Authored-by: 彭灿00244106 <00244106@zte.intra> Signed-off-by: Takeshi Yamamuro --- .../sql/execution/command/DDLSuite.scala | 40 ++++++++----------- 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 9d32fb6d46962..052a5e757c445 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -377,41 +377,41 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } - test("CTAS a managed table with the existing empty directory") { - val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab1"))) + private def withEmptyDirInTablePath(dirName: String)(f : File => Unit): Unit = { + val tableLoc = + new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier(dirName))) try { tableLoc.mkdir() + f(tableLoc) + } finally { + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) + } + } + + + test("CTAS a managed table with the existing empty directory") { + withEmptyDirInTablePath("tab1") { tableLoc => withTable("tab1") { sql(s"CREATE TABLE tab1 USING ${dataSource} AS SELECT 1, 'a'") checkAnswer(spark.table("tab1"), Row(1, "a")) } - } finally { - waitForTasksToFinish() - Utils.deleteRecursively(tableLoc) } } test("create a managed table with the existing empty directory") { - val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab1"))) - try { - tableLoc.mkdir() + withEmptyDirInTablePath("tab1") { tableLoc => withTable("tab1") { sql(s"CREATE TABLE tab1 (col1 int, col2 string) USING ${dataSource}") sql("INSERT INTO tab1 VALUES (1, 'a')") checkAnswer(spark.table("tab1"), Row(1, "a")) } - } finally { - waitForTasksToFinish() - Utils.deleteRecursively(tableLoc) } } test("create a managed table with the existing non-empty directory") { withTable("tab1") { - val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab1"))) - try { - // create an empty hidden file - tableLoc.mkdir() + withEmptyDirInTablePath("tab1") { tableLoc => val hiddenGarbageFile = new File(tableLoc.getCanonicalPath, ".garbage") hiddenGarbageFile.createNewFile() val exMsg = "Can not create the managed table('`tab1`'). The associated location" @@ -439,28 +439,20 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { }.getMessage assert(ex.contains(exMsgWithDefaultDB)) } - } finally { - waitForTasksToFinish() - Utils.deleteRecursively(tableLoc) } } } test("rename a managed table with existing empty directory") { - val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab2"))) - try { + withEmptyDirInTablePath("tab2") { tableLoc => withTable("tab1") { sql(s"CREATE TABLE tab1 USING $dataSource AS SELECT 1, 'a'") - tableLoc.mkdir() val ex = intercept[AnalysisException] { sql("ALTER TABLE tab1 RENAME TO tab2") }.getMessage val expectedMsg = "Can not rename the managed table('`tab1`'). The associated location" assert(ex.contains(expectedMsg)) } - } finally { - waitForTasksToFinish() - Utils.deleteRecursively(tableLoc) } } From 06a3b6aafa510ede2f1376b29a46f99447286c67 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 4 Dec 2018 07:57:58 -0600 Subject: [PATCH 2214/2461] [SPARK-24423][FOLLOW-UP][SQL] Fix error example ## What changes were proposed in this pull request? ![image](https://user-images.githubusercontent.com/5399861/49172173-42ad9800-f37b-11e8-8135-7adc323357ae.png) It will throw: ``` requirement failed: When reading JDBC data sources, users need to specify all or none for the following options: 'partitionColumn', 'lowerBound', 'upperBound', and 'numPartitions' ``` and ``` User-defined partition column subq.c1 not found in the JDBC relation ... ``` This PR fix this error example. ## How was this patch tested? manual tests Closes #23170 from wangyum/SPARK-24499. Authored-by: Yuming Wang Signed-off-by: Sean Owen --- docs/sql-data-sources-jdbc.md | 6 +++--- .../sql/execution/datasources/jdbc/JDBCOptions.scala | 10 +++++++--- .../scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 10 +++++++--- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md index 9a5d0fc7d424c..a2b14620be12e 100644 --- a/docs/sql-data-sources-jdbc.md +++ b/docs/sql-data-sources-jdbc.md @@ -64,9 +64,9 @@ the following case-insensitive options: Example:
        spark.read.format("jdbc")
        -    .option("dbtable", "(select c1, c2 from t1) as subq")
        -    .option("partitionColumn", "subq.c1"
        -    .load() + .option("url", jdbcUrl)
        + .option("query", "select c1, c2 from t1")
        + .load()
      diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 7dfbb9d8b5c05..b4469cb538fa6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -137,9 +137,13 @@ class JDBCOptions( |the partition columns using the supplied subquery alias to resolve any ambiguity. |Example : |spark.read.format("jdbc") - | .option("dbtable", "(select c1, c2 from t1) as subq") - | .option("partitionColumn", "subq.c1" - | .load() + | .option("url", jdbcUrl) + | .option("dbtable", "(select c1, c2 from t1) as subq") + | .option("partitionColumn", "c1") + | .option("lowerBound", "1") + | .option("upperBound", "100") + | .option("numPartitions", "3") + | .load() """.stripMargin ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 7fa0e7fc162ca..71e83767964a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1348,9 +1348,13 @@ class JDBCSuite extends QueryTest |the partition columns using the supplied subquery alias to resolve any ambiguity. |Example : |spark.read.format("jdbc") - | .option("dbtable", "(select c1, c2 from t1) as subq") - | .option("partitionColumn", "subq.c1" - | .load() + | .option("url", jdbcUrl) + | .option("dbtable", "(select c1, c2 from t1) as subq") + | .option("partitionColumn", "c1") + | .option("lowerBound", "1") + | .option("upperBound", "100") + | .option("numPartitions", "3") + | .load() """.stripMargin val e5 = intercept[RuntimeException] { sql( From f982ca07e80074bdc1e3b742c5e21cf368e4ede2 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 4 Dec 2018 08:36:33 -0600 Subject: [PATCH 2215/2461] [SPARK-26178][SQL] Use java.time API for parsing timestamps and dates from CSV ## What changes were proposed in this pull request? In the PR, I propose to use **java.time API** for parsing timestamps and dates from CSV content with microseconds precision. The SQL config `spark.sql.legacy.timeParser.enabled` allow to switch back to previous behaviour with using `java.text.SimpleDateFormat`/`FastDateFormat` for parsing/generating timestamps/dates. ## How was this patch tested? It was tested by `UnivocityParserSuite`, `CsvExpressionsSuite`, `CsvFunctionsSuite` and `CsvSuite`. Closes #23150 from MaxGekk/time-parser. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: Sean Owen --- docs/sql-migration-guide-upgrade.md | 2 + .../sql/catalyst/csv/CSVInferSchema.scala | 15 +- .../spark/sql/catalyst/csv/CSVOptions.scala | 10 +- .../sql/catalyst/csv/UnivocityGenerator.scala | 14 +- .../sql/catalyst/csv/UnivocityParser.scala | 38 ++-- .../sql/catalyst/util/DateTimeFormatter.scala | 179 ++++++++++++++++++ .../sql/catalyst/util/DateTimeUtils.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 9 + .../catalyst/csv/CSVInferSchemaSuite.scala | 7 +- .../catalyst/csv/UnivocityParserSuite.scala | 113 ++++++----- .../sql/catalyst/util/DateTimeTestUtils.scala | 5 +- .../sql/util/DateTimeFormatterSuite.scala | 103 ++++++++++ .../apache/spark/sql/CsvFunctionsSuite.scala | 2 +- .../execution/datasources/csv/CSVSuite.scala | 66 ++++--- 14 files changed, 431 insertions(+), 134 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatter.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimeFormatterSuite.scala diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 787f4bcbbea82..fee0e6df7177c 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -33,6 +33,8 @@ displayTitle: Spark SQL Upgrading Guide - Spark applications which are built with Spark version 2.4 and prior, and call methods of `UserDefinedFunction`, need to be re-compiled with Spark 3.0, as they are not binary compatible with Spark 3.0. + - Since Spark 3.0, CSV datasource uses java.time API for parsing and generating CSV content. New formatting implementation supports date/timestamp patterns conformed to ISO 8601. To switch back to the implementation used in Spark 2.4 and earlier, set `spark.sql.legacy.timeParser.enabled` to `true`. + ## Upgrading From Spark SQL 2.3 to 2.4 - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index 94cb4b114e6b6..345dc4d41993e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -22,10 +22,16 @@ import scala.util.control.Exception.allCatch import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.expressions.ExprUtils -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.DateTimeFormatter import org.apache.spark.sql.types._ -class CSVInferSchema(options: CSVOptions) extends Serializable { +class CSVInferSchema(val options: CSVOptions) extends Serializable { + + @transient + private lazy val timeParser = DateTimeFormatter( + options.timestampFormat, + options.timeZone, + options.locale) private val decimalParser = { ExprUtils.getDecimalParser(options.locale) @@ -154,10 +160,7 @@ class CSVInferSchema(options: CSVOptions) extends Serializable { private def tryParseTimestamp(field: String): DataType = { // This case infers a custom `dataFormat` is set. - if ((allCatch opt options.timestampFormat.parse(field)).isDefined) { - TimestampType - } else if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) { - // We keep this for backwards compatibility. + if ((allCatch opt timeParser.parse(field)).isDefined) { TimestampType } else { tryParseBoolean(field) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala index 94bdb72d675d4..90c96d1f55c91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala @@ -21,7 +21,6 @@ import java.nio.charset.StandardCharsets import java.util.{Locale, TimeZone} import com.univocity.parsers.csv.{CsvParserSettings, CsvWriterSettings, UnescapedQuoteHandling} -import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util._ @@ -146,13 +145,10 @@ class CSVOptions( // A language tag in IETF BCP 47 format val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US) - // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. - val dateFormat: FastDateFormat = - FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), locale) + val dateFormat: String = parameters.getOrElse("dateFormat", "yyyy-MM-dd") - val timestampFormat: FastDateFormat = - FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, locale) + val timestampFormat: String = + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX") val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala index 2ab376c0ac208..af09cd6c8449b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala @@ -22,7 +22,7 @@ import java.io.Writer import com.univocity.parsers.csv.CsvWriter import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeFormatter} import org.apache.spark.sql.types._ class UnivocityGenerator( @@ -41,14 +41,18 @@ class UnivocityGenerator( private val valueConverters: Array[ValueConverter] = schema.map(_.dataType).map(makeConverter).toArray + private val timeFormatter = DateTimeFormatter( + options.timestampFormat, + options.timeZone, + options.locale) + private val dateFormatter = DateFormatter(options.dateFormat, options.timeZone, options.locale) + private def makeConverter(dataType: DataType): ValueConverter = dataType match { case DateType => - (row: InternalRow, ordinal: Int) => - options.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal))) + (row: InternalRow, ordinal: Int) => dateFormatter.format(row.getInt(ordinal)) case TimestampType => - (row: InternalRow, ordinal: Int) => - options.timestampFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal))) + (row: InternalRow, ordinal: Int) => timeFormatter.format(row.getLong(ordinal)) case udt: UserDefinedType[_] => makeConverter(udt.sqlType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index 8fff4b0781b1d..0f375e036029c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.csv import java.io.InputStream -import scala.util.Try import scala.util.control.NonFatal import com.univocity.parsers.csv.CsvParser @@ -27,7 +26,7 @@ import com.univocity.parsers.csv.CsvParser import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericInternalRow} -import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils, FailureSafeParser} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -75,6 +74,12 @@ class UnivocityParser( private val row = new GenericInternalRow(requiredSchema.length) + private val timeFormatter = DateTimeFormatter( + options.timestampFormat, + options.timeZone, + options.locale) + private val dateFormatter = DateFormatter(options.dateFormat, options.timeZone, options.locale) + // Retrieve the raw record string. private def getCurrentInput: UTF8String = { UTF8String.fromString(tokenizer.getContext.currentParsedContent().stripLineEnd) @@ -100,7 +105,7 @@ class UnivocityParser( // // output row - ["A", 2] private val valueConverters: Array[ValueConverter] = { - requiredSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray + requiredSchema.map(f => makeConverter(f.name, f.dataType, f.nullable)).toArray } private val decimalParser = ExprUtils.getDecimalParser(options.locale) @@ -115,8 +120,7 @@ class UnivocityParser( def makeConverter( name: String, dataType: DataType, - nullable: Boolean = true, - options: CSVOptions): ValueConverter = dataType match { + nullable: Boolean = true): ValueConverter = dataType match { case _: ByteType => (d: String) => nullSafeDatum(d, name, nullable, options)(_.toByte) @@ -154,34 +158,16 @@ class UnivocityParser( } case _: TimestampType => (d: String) => - nullSafeDatum(d, name, nullable, options) { datum => - // This one will lose microseconds parts. - // See https://issues.apache.org/jira/browse/SPARK-10681. - Try(options.timestampFormat.parse(datum).getTime * 1000L) - .getOrElse { - // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards - // compatibility. - DateTimeUtils.stringToTime(datum).getTime * 1000L - } - } + nullSafeDatum(d, name, nullable, options)(timeFormatter.parse) case _: DateType => (d: String) => - nullSafeDatum(d, name, nullable, options) { datum => - // This one will lose microseconds parts. - // See https://issues.apache.org/jira/browse/SPARK-10681.x - Try(DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime)) - .getOrElse { - // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards - // compatibility. - DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime) - } - } + nullSafeDatum(d, name, nullable, options)(dateFormatter.parse) case _: StringType => (d: String) => nullSafeDatum(d, name, nullable, options)(UTF8String.fromString) case udt: UserDefinedType[_] => (datum: String) => - makeConverter(name, udt.sqlType, nullable, options) + makeConverter(name, udt.sqlType, nullable) // We don't actually hit this exception though, we keep it for understandability case _ => throw new RuntimeException(s"Unsupported type: ${dataType.typeName}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatter.scala new file mode 100644 index 0000000000000..ad1f4131de2f6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatter.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import java.time._ +import java.time.format.DateTimeFormatterBuilder +import java.time.temporal.{ChronoField, TemporalQueries} +import java.util.{Locale, TimeZone} + +import scala.util.Try + +import org.apache.commons.lang3.time.FastDateFormat + +import org.apache.spark.sql.internal.SQLConf + +sealed trait DateTimeFormatter { + def parse(s: String): Long // returns microseconds since epoch + def format(us: Long): String +} + +class Iso8601DateTimeFormatter( + pattern: String, + timeZone: TimeZone, + locale: Locale) extends DateTimeFormatter { + val formatter = new DateTimeFormatterBuilder() + .appendPattern(pattern) + .parseDefaulting(ChronoField.YEAR_OF_ERA, 1970) + .parseDefaulting(ChronoField.MONTH_OF_YEAR, 1) + .parseDefaulting(ChronoField.DAY_OF_MONTH, 1) + .parseDefaulting(ChronoField.HOUR_OF_DAY, 0) + .parseDefaulting(ChronoField.MINUTE_OF_HOUR, 0) + .parseDefaulting(ChronoField.SECOND_OF_MINUTE, 0) + .toFormatter(locale) + + def toInstant(s: String): Instant = { + val temporalAccessor = formatter.parse(s) + if (temporalAccessor.query(TemporalQueries.offset()) == null) { + val localDateTime = LocalDateTime.from(temporalAccessor) + val zonedDateTime = ZonedDateTime.of(localDateTime, timeZone.toZoneId) + Instant.from(zonedDateTime) + } else { + Instant.from(temporalAccessor) + } + } + + private def instantToMicros(instant: Instant): Long = { + val sec = Math.multiplyExact(instant.getEpochSecond, DateTimeUtils.MICROS_PER_SECOND) + val result = Math.addExact(sec, instant.getNano / DateTimeUtils.NANOS_PER_MICROS) + result + } + + def parse(s: String): Long = instantToMicros(toInstant(s)) + + def format(us: Long): String = { + val secs = Math.floorDiv(us, DateTimeUtils.MICROS_PER_SECOND) + val mos = Math.floorMod(us, DateTimeUtils.MICROS_PER_SECOND) + val instant = Instant.ofEpochSecond(secs, mos * DateTimeUtils.NANOS_PER_MICROS) + + formatter.withZone(timeZone.toZoneId).format(instant) + } +} + +class LegacyDateTimeFormatter( + pattern: String, + timeZone: TimeZone, + locale: Locale) extends DateTimeFormatter { + val format = FastDateFormat.getInstance(pattern, timeZone, locale) + + protected def toMillis(s: String): Long = format.parse(s).getTime + + def parse(s: String): Long = toMillis(s) * DateTimeUtils.MICROS_PER_MILLIS + + def format(us: Long): String = { + format.format(DateTimeUtils.toJavaTimestamp(us)) + } +} + +class LegacyFallbackDateTimeFormatter( + pattern: String, + timeZone: TimeZone, + locale: Locale) extends LegacyDateTimeFormatter(pattern, timeZone, locale) { + override def toMillis(s: String): Long = { + Try {super.toMillis(s)}.getOrElse(DateTimeUtils.stringToTime(s).getTime) + } +} + +object DateTimeFormatter { + def apply(format: String, timeZone: TimeZone, locale: Locale): DateTimeFormatter = { + if (SQLConf.get.legacyTimeParserEnabled) { + new LegacyFallbackDateTimeFormatter(format, timeZone, locale) + } else { + new Iso8601DateTimeFormatter(format, timeZone, locale) + } + } +} + +sealed trait DateFormatter { + def parse(s: String): Int // returns days since epoch + def format(days: Int): String +} + +class Iso8601DateFormatter( + pattern: String, + timeZone: TimeZone, + locale: Locale) extends DateFormatter { + + val dateTimeFormatter = new Iso8601DateTimeFormatter(pattern, timeZone, locale) + + override def parse(s: String): Int = { + val seconds = dateTimeFormatter.toInstant(s).getEpochSecond + val days = Math.floorDiv(seconds, DateTimeUtils.SECONDS_PER_DAY) + + days.toInt + } + + override def format(days: Int): String = { + val instant = Instant.ofEpochSecond(days * DateTimeUtils.SECONDS_PER_DAY) + dateTimeFormatter.formatter.withZone(timeZone.toZoneId).format(instant) + } +} + +class LegacyDateFormatter( + pattern: String, + timeZone: TimeZone, + locale: Locale) extends DateFormatter { + val format = FastDateFormat.getInstance(pattern, timeZone, locale) + + def parse(s: String): Int = { + val milliseconds = format.parse(s).getTime + DateTimeUtils.millisToDays(milliseconds) + } + + def format(days: Int): String = { + val date = DateTimeUtils.toJavaDate(days) + format.format(date) + } +} + +class LegacyFallbackDateFormatter( + pattern: String, + timeZone: TimeZone, + locale: Locale) extends LegacyDateFormatter(pattern, timeZone, locale) { + override def parse(s: String): Int = { + Try(super.parse(s)).orElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + Try(DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(s).getTime)) + }.getOrElse { + // In Spark 1.5.0, we store the data as number of days since epoch in string. + // So, we just convert it to Int. + s.toInt + } + } +} + +object DateFormatter { + def apply(format: String, timeZone: TimeZone, locale: Locale): DateFormatter = { + if (SQLConf.get.legacyTimeParserEnabled) { + new LegacyFallbackDateFormatter(format, timeZone, locale) + } else { + new Iso8601DateFormatter(format, timeZone, locale) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 5ae75dc939303..c6dfdbf2505ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -50,7 +50,7 @@ object DateTimeUtils { final val MILLIS_PER_SECOND = 1000L final val NANOS_PER_SECOND = MICROS_PER_SECOND * 1000L final val MICROS_PER_DAY = MICROS_PER_SECOND * SECONDS_PER_DAY - + final val NANOS_PER_MICROS = 1000L final val MILLIS_PER_DAY = SECONDS_PER_DAY * 1000L // number of days in 400 years diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c4f00d723c252..451b051f8407e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1618,6 +1618,13 @@ object SQLConf { "a SparkConf entry.") .booleanConf .createWithDefault(true) + + val LEGACY_TIME_PARSER_ENABLED = buildConf("spark.sql.legacy.timeParser.enabled") + .doc("When set to true, java.text.SimpleDateFormat is used for formatting and parsing " + + " dates/timestamps in a locale-sensitive manner. When set to false, classes from " + + "java.time.* packages are used for the same purpose.") + .booleanConf + .createWithDefault(false) } /** @@ -2040,6 +2047,8 @@ class SQLConf extends Serializable with Logging { def setCommandRejectsSparkConfs: Boolean = getConf(SQLConf.SET_COMMAND_REJECTS_SPARK_CONFS) + def legacyTimeParserEnabled: Boolean = getConf(SQLConf.LEGACY_TIME_PARSER_ENABLED) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala index 1a020e67a75b0..c2b525ad1a9f8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala @@ -22,13 +22,12 @@ import java.util.Locale import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.plans.SQLHelper -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper { +class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper { test("String fields types are inferred correctly from null types") { - val options = new CSVOptions(Map.empty[String, String], false, "GMT") + val options = new CSVOptions(Map("timestampFormat" -> "yyyy-MM-dd HH:mm:ss"), false, "GMT") val inferSchema = new CSVInferSchema(options) assert(inferSchema.inferField(NullType, "") == NullType) @@ -48,7 +47,7 @@ class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper { } test("String fields types are inferred correctly from other types") { - val options = new CSVOptions(Map.empty[String, String], false, "GMT") + val options = new CSVOptions(Map("timestampFormat" -> "yyyy-MM-dd HH:mm:ss"), false, "GMT") val inferSchema = new CSVInferSchema(options) assert(inferSchema.inferField(LongType, "1.0") == DoubleType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala index 7212402ef5cff..2d0b0d3033a9c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala @@ -19,20 +19,17 @@ package org.apache.spark.sql.catalyst.csv import java.math.BigDecimal import java.text.{DecimalFormat, DecimalFormatSymbols} -import java.util.Locale +import java.util.{Locale, TimeZone} + +import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class UnivocityParserSuite extends SparkFunSuite with SQLHelper { - private val parser = new UnivocityParser( - StructType(Seq.empty), - new CSVOptions(Map.empty[String, String], false, "GMT")) - private def assertNull(v: Any) = assert(v == null) test("Can parse decimal type values") { @@ -43,7 +40,8 @@ class UnivocityParserSuite extends SparkFunSuite with SQLHelper { stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) => val decimalValue = new BigDecimal(decimalVal.toString) val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assert(parser.makeConverter("_1", decimalType, options = options).apply(strVal) === + val parser = new UnivocityParser(StructType(Seq.empty), options) + assert(parser.makeConverter("_1", decimalType).apply(strVal) === Decimal(decimalValue, decimalType.precision, decimalType.scale)) } } @@ -56,22 +54,23 @@ class UnivocityParserSuite extends SparkFunSuite with SQLHelper { types.foreach { t => // Tests that a custom nullValue. val nullValueOptions = new CSVOptions(Map("nullValue" -> "-"), false, "GMT") - val converter = - parser.makeConverter("_1", t, nullable = true, options = nullValueOptions) + var parser = new UnivocityParser(StructType(Seq.empty), nullValueOptions) + val converter = parser.makeConverter("_1", t, nullable = true) assertNull(converter.apply("-")) assertNull(converter.apply(null)) // Tests that the default nullValue is empty string. val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assertNull(parser.makeConverter("_1", t, nullable = true, options = options).apply("")) + parser = new UnivocityParser(StructType(Seq.empty), options) + assertNull(parser.makeConverter("_1", t, nullable = true).apply("")) } // Not nullable field with nullValue option. types.foreach { t => // Casts a null to not nullable field should throw an exception. val options = new CSVOptions(Map("nullValue" -> "-"), false, "GMT") - val converter = - parser.makeConverter("_1", t, nullable = false, options = options) + val parser = new UnivocityParser(StructType(Seq.empty), options) + val converter = parser.makeConverter("_1", t, nullable = false) var message = intercept[RuntimeException] { converter.apply("-") }.getMessage @@ -86,62 +85,74 @@ class UnivocityParserSuite extends SparkFunSuite with SQLHelper { // null. Seq(true, false).foreach { b => val options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT") - val converter = - parser.makeConverter("_1", StringType, nullable = b, options = options) + val parser = new UnivocityParser(StructType(Seq.empty), options) + val converter = parser.makeConverter("_1", StringType, nullable = b) assert(converter.apply("") == UTF8String.fromString("")) } } test("Throws exception for empty string with non null type") { - val options = new CSVOptions(Map.empty[String, String], false, "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") + val parser = new UnivocityParser(StructType(Seq.empty), options) val exception = intercept[RuntimeException]{ - parser.makeConverter("_1", IntegerType, nullable = false, options = options).apply("") + parser.makeConverter("_1", IntegerType, nullable = false).apply("") } assert(exception.getMessage.contains("null value found but field _1 is not nullable.")) } test("Types are cast correctly") { val options = new CSVOptions(Map.empty[String, String], false, "GMT") - assert(parser.makeConverter("_1", ByteType, options = options).apply("10") == 10) - assert(parser.makeConverter("_1", ShortType, options = options).apply("10") == 10) - assert(parser.makeConverter("_1", IntegerType, options = options).apply("10") == 10) - assert(parser.makeConverter("_1", LongType, options = options).apply("10") == 10) - assert(parser.makeConverter("_1", FloatType, options = options).apply("1.00") == 1.0) - assert(parser.makeConverter("_1", DoubleType, options = options).apply("1.00") == 1.0) - assert(parser.makeConverter("_1", BooleanType, options = options).apply("true") == true) - - val timestampsOptions = + var parser = new UnivocityParser(StructType(Seq.empty), options) + assert(parser.makeConverter("_1", ByteType).apply("10") == 10) + assert(parser.makeConverter("_1", ShortType).apply("10") == 10) + assert(parser.makeConverter("_1", IntegerType).apply("10") == 10) + assert(parser.makeConverter("_1", LongType).apply("10") == 10) + assert(parser.makeConverter("_1", FloatType).apply("1.00") == 1.0) + assert(parser.makeConverter("_1", DoubleType).apply("1.00") == 1.0) + assert(parser.makeConverter("_1", BooleanType).apply("true") == true) + + var timestampsOptions = new CSVOptions(Map("timestampFormat" -> "dd/MM/yyyy hh:mm"), false, "GMT") + parser = new UnivocityParser(StructType(Seq.empty), timestampsOptions) val customTimestamp = "31/01/2015 00:00" - val expectedTime = timestampsOptions.timestampFormat.parse(customTimestamp).getTime - val castedTimestamp = - parser.makeConverter("_1", TimestampType, nullable = true, options = timestampsOptions) + var format = FastDateFormat.getInstance( + timestampsOptions.timestampFormat, timestampsOptions.timeZone, timestampsOptions.locale) + val expectedTime = format.parse(customTimestamp).getTime + val castedTimestamp = parser.makeConverter("_1", TimestampType, nullable = true) .apply(customTimestamp) assert(castedTimestamp == expectedTime * 1000L) val customDate = "31/01/2015" val dateOptions = new CSVOptions(Map("dateFormat" -> "dd/MM/yyyy"), false, "GMT") - val expectedDate = dateOptions.dateFormat.parse(customDate).getTime - val castedDate = - parser.makeConverter("_1", DateType, nullable = true, options = dateOptions) - .apply(customTimestamp) - assert(castedDate == DateTimeUtils.millisToDays(expectedDate)) + parser = new UnivocityParser(StructType(Seq.empty), dateOptions) + format = FastDateFormat.getInstance( + dateOptions.dateFormat, dateOptions.timeZone, dateOptions.locale) + val expectedDate = format.parse(customDate).getTime + val castedDate = parser.makeConverter("_1", DateType, nullable = true) + .apply(customDate) + assert(castedDate == DateTimeUtils.millisToDays(expectedDate, TimeZone.getTimeZone("GMT"))) val timestamp = "2015-01-01 00:00:00" - assert(parser.makeConverter("_1", TimestampType, options = options).apply(timestamp) == - DateTimeUtils.stringToTime(timestamp).getTime * 1000L) - assert(parser.makeConverter("_1", DateType, options = options).apply("2015-01-01") == - DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime)) + timestampsOptions = new CSVOptions(Map( + "timestampFormat" -> "yyyy-MM-dd HH:mm:ss", + "dateFormat" -> "yyyy-MM-dd"), false, "UTC") + parser = new UnivocityParser(StructType(Seq.empty), timestampsOptions) + val expected = 1420070400 * DateTimeUtils.MICROS_PER_SECOND + assert(parser.makeConverter("_1", TimestampType).apply(timestamp) == + expected) + assert(parser.makeConverter("_1", DateType).apply("2015-01-01") == + expected / DateTimeUtils.MICROS_PER_DAY) } test("Throws exception for casting an invalid string to Float and Double Types") { val options = new CSVOptions(Map.empty[String, String], false, "GMT") + val parser = new UnivocityParser(StructType(Seq.empty), options) val types = Seq(DoubleType, FloatType) val input = Seq("10u000", "abc", "1 2/3") types.foreach { dt => input.foreach { v => val message = intercept[NumberFormatException] { - parser.makeConverter("_1", dt, options = options).apply(v) + parser.makeConverter("_1", dt).apply(v) }.getMessage assert(message.contains(v)) } @@ -150,9 +161,9 @@ class UnivocityParserSuite extends SparkFunSuite with SQLHelper { test("Float NaN values are parsed correctly") { val options = new CSVOptions(Map("nanValue" -> "nn"), false, "GMT") + val parser = new UnivocityParser(StructType(Seq.empty), options) val floatVal: Float = parser.makeConverter( - "_1", FloatType, nullable = true, options = options - ).apply("nn").asInstanceOf[Float] + "_1", FloatType, nullable = true).apply("nn").asInstanceOf[Float] // Java implements the IEEE-754 floating point standard which guarantees that any comparison // against NaN will return false (except != which returns true) @@ -161,41 +172,41 @@ class UnivocityParserSuite extends SparkFunSuite with SQLHelper { test("Double NaN values are parsed correctly") { val options = new CSVOptions(Map("nanValue" -> "-"), false, "GMT") + val parser = new UnivocityParser(StructType(Seq.empty), options) val doubleVal: Double = parser.makeConverter( - "_1", DoubleType, nullable = true, options = options - ).apply("-").asInstanceOf[Double] + "_1", DoubleType, nullable = true).apply("-").asInstanceOf[Double] assert(doubleVal.isNaN) } test("Float infinite values can be parsed") { val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT") + var parser = new UnivocityParser(StructType(Seq.empty), negativeInfOptions) val floatVal1 = parser.makeConverter( - "_1", FloatType, nullable = true, options = negativeInfOptions - ).apply("max").asInstanceOf[Float] + "_1", FloatType, nullable = true).apply("max").asInstanceOf[Float] assert(floatVal1 == Float.NegativeInfinity) val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT") + parser = new UnivocityParser(StructType(Seq.empty), positiveInfOptions) val floatVal2 = parser.makeConverter( - "_1", FloatType, nullable = true, options = positiveInfOptions - ).apply("max").asInstanceOf[Float] + "_1", FloatType, nullable = true).apply("max").asInstanceOf[Float] assert(floatVal2 == Float.PositiveInfinity) } test("Double infinite values can be parsed") { val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT") + var parser = new UnivocityParser(StructType(Seq.empty), negativeInfOptions) val doubleVal1 = parser.makeConverter( - "_1", DoubleType, nullable = true, options = negativeInfOptions - ).apply("max").asInstanceOf[Double] + "_1", DoubleType, nullable = true).apply("max").asInstanceOf[Double] assert(doubleVal1 == Double.NegativeInfinity) val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT") + parser = new UnivocityParser(StructType(Seq.empty), positiveInfOptions) val doubleVal2 = parser.makeConverter( - "_1", DoubleType, nullable = true, options = positiveInfOptions - ).apply("max").asInstanceOf[Double] + "_1", DoubleType, nullable = true).apply("max").asInstanceOf[Double] assert(doubleVal2 == Double.PositiveInfinity) } @@ -211,7 +222,7 @@ class UnivocityParserSuite extends SparkFunSuite with SQLHelper { val options = new CSVOptions(Map("locale" -> langTag), false, "GMT") val parser = new UnivocityParser(new StructType().add("d", decimalType), options) - assert(parser.makeConverter("_1", decimalType, options = options).apply(input) === expected) + assert(parser.makeConverter("_1", decimalType).apply(input) === expected) } Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(checkDecimalParsing) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala index dfa0fe93a2f9c..66d8d28988f89 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala @@ -26,7 +26,7 @@ object DateTimeTestUtils { val ALL_TIMEZONES: Seq[TimeZone] = TimeZone.getAvailableIDs.toSeq.map(TimeZone.getTimeZone) - val outstandingTimezones: Seq[TimeZone] = Seq( + val outstandingTimezonesIds: Seq[String] = Seq( "UTC", "PST", "CET", @@ -34,7 +34,8 @@ object DateTimeTestUtils { "America/Los_Angeles", "Antarctica/Vostok", "Asia/Hong_Kong", - "Europe/Amsterdam").map(TimeZone.getTimeZone) + "Europe/Amsterdam") + val outstandingTimezones: Seq[TimeZone] = outstandingTimezonesIds.map(TimeZone.getTimeZone) def withDefaultTimeZone[T](newDefaultTimeZone: TimeZone)(block: => T): T = { val originalDefaultTimeZone = TimeZone.getDefault diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimeFormatterSuite.scala new file mode 100644 index 0000000000000..02d4ee0490604 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimeFormatterSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import java.util.{Locale, TimeZone} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeFormatter, DateTimeTestUtils} + +class DateTimeFormatterSuite extends SparkFunSuite { + test("parsing dates using time zones") { + val localDate = "2018-12-02" + val expectedDays = Map( + "UTC" -> 17867, + "PST" -> 17867, + "CET" -> 17866, + "Africa/Dakar" -> 17867, + "America/Los_Angeles" -> 17867, + "Antarctica/Vostok" -> 17866, + "Asia/Hong_Kong" -> 17866, + "Europe/Amsterdam" -> 17866) + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + val formatter = DateFormatter("yyyy-MM-dd", TimeZone.getTimeZone(timeZone), Locale.US) + val daysSinceEpoch = formatter.parse(localDate) + assert(daysSinceEpoch === expectedDays(timeZone)) + } + } + + test("parsing timestamps using time zones") { + val localDate = "2018-12-02T10:11:12.001234" + val expectedMicros = Map( + "UTC" -> 1543745472001234L, + "PST" -> 1543774272001234L, + "CET" -> 1543741872001234L, + "Africa/Dakar" -> 1543745472001234L, + "America/Los_Angeles" -> 1543774272001234L, + "Antarctica/Vostok" -> 1543723872001234L, + "Asia/Hong_Kong" -> 1543716672001234L, + "Europe/Amsterdam" -> 1543741872001234L) + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + val formatter = DateTimeFormatter( + "yyyy-MM-dd'T'HH:mm:ss.SSSSSS", + TimeZone.getTimeZone(timeZone), + Locale.US) + val microsSinceEpoch = formatter.parse(localDate) + assert(microsSinceEpoch === expectedMicros(timeZone)) + } + } + + test("format dates using time zones") { + val daysSinceEpoch = 17867 + val expectedDate = Map( + "UTC" -> "2018-12-02", + "PST" -> "2018-12-01", + "CET" -> "2018-12-02", + "Africa/Dakar" -> "2018-12-02", + "America/Los_Angeles" -> "2018-12-01", + "Antarctica/Vostok" -> "2018-12-02", + "Asia/Hong_Kong" -> "2018-12-02", + "Europe/Amsterdam" -> "2018-12-02") + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + val formatter = DateFormatter("yyyy-MM-dd", TimeZone.getTimeZone(timeZone), Locale.US) + val date = formatter.format(daysSinceEpoch) + assert(date === expectedDate(timeZone)) + } + } + + test("format timestamps using time zones") { + val microsSinceEpoch = 1543745472001234L + val expectedTimestamp = Map( + "UTC" -> "2018-12-02T10:11:12.001234", + "PST" -> "2018-12-02T02:11:12.001234", + "CET" -> "2018-12-02T11:11:12.001234", + "Africa/Dakar" -> "2018-12-02T10:11:12.001234", + "America/Los_Angeles" -> "2018-12-02T02:11:12.001234", + "Antarctica/Vostok" -> "2018-12-02T16:11:12.001234", + "Asia/Hong_Kong" -> "2018-12-02T18:11:12.001234", + "Europe/Amsterdam" -> "2018-12-02T11:11:12.001234") + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + val formatter = DateTimeFormatter( + "yyyy-MM-dd'T'HH:mm:ss.SSSSSS", + TimeZone.getTimeZone(timeZone), + Locale.US) + val timestamp = formatter.format(microsSinceEpoch) + assert(timestamp === expectedTimestamp(timeZone)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index 537d13b1bc8dd..6b67fccf86b9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -53,7 +53,7 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext { test("checking the columnNameOfCorruptRecord option") { val columnNameOfCorruptRecord = "_unparsed" val df = Seq("0,2013-111-11 12:13:14", "1,1983-08-04").toDS() - val schema = new StructType().add("a", IntegerType).add("b", TimestampType) + val schema = new StructType().add("a", IntegerType).add("b", DateType) val schemaWithCorrField1 = schema.add(columnNameOfCorruptRecord, StringType) val df2 = df .select(from_csv($"value", schemaWithCorrField1, Map( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index c9273193b6425..3b977d74053e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -586,6 +586,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te val results = spark.read .format("csv") .options(Map("comment" -> "~", "header" -> "false", "inferSchema" -> "true")) + .option("timestampFormat", "yyyy-MM-dd HH:mm:ss") .load(testFile(commentsFile)) .collect() @@ -622,10 +623,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te val options = Map( "header" -> "true", "inferSchema" -> "false", - "dateFormat" -> "dd/MM/yyyy hh:mm") + "dateFormat" -> "dd/MM/yyyy HH:mm") val results = spark.read .format("csv") .options(options) + .option("timeZone", "UTC") .schema(customSchema) .load(testFile(datesFile)) .select("date") @@ -893,36 +895,38 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } test("Write dates correctly in ISO8601 format by default") { - withTempDir { dir => - val customSchema = new StructType(Array(StructField("date", DateType, true))) - val iso8601datesPath = s"${dir.getCanonicalPath}/iso8601dates.csv" - val dates = spark.read - .format("csv") - .schema(customSchema) - .option("header", "true") - .option("inferSchema", "false") - .option("dateFormat", "dd/MM/yyyy HH:mm") - .load(testFile(datesFile)) - dates.write - .format("csv") - .option("header", "true") - .save(iso8601datesPath) + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { + withTempDir { dir => + val customSchema = new StructType(Array(StructField("date", DateType, true))) + val iso8601datesPath = s"${dir.getCanonicalPath}/iso8601dates.csv" + val dates = spark.read + .format("csv") + .schema(customSchema) + .option("header", "true") + .option("inferSchema", "false") + .option("dateFormat", "dd/MM/yyyy HH:mm") + .load(testFile(datesFile)) + dates.write + .format("csv") + .option("header", "true") + .save(iso8601datesPath) - // This will load back the dates as string. - val stringSchema = StructType(StructField("date", StringType, true) :: Nil) - val iso8601dates = spark.read - .format("csv") - .schema(stringSchema) - .option("header", "true") - .load(iso8601datesPath) + // This will load back the dates as string. + val stringSchema = StructType(StructField("date", StringType, true) :: Nil) + val iso8601dates = spark.read + .format("csv") + .schema(stringSchema) + .option("header", "true") + .load(iso8601datesPath) + + val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd", Locale.US) + val expectedDates = dates.collect().map { r => + // This should be ISO8601 formatted string. + Row(iso8501.format(r.toSeq.head)) + } - val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd", Locale.US) - val expectedDates = dates.collect().map { r => - // This should be ISO8601 formatted string. - Row(iso8501.format(r.toSeq.head)) + checkAnswer(iso8601dates, expectedDates) } - - checkAnswer(iso8601dates, expectedDates) } } @@ -1107,7 +1111,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te test("SPARK-18699 put malformed records in a `columnNameOfCorruptRecord` field") { Seq(false, true).foreach { multiLine => - val schema = new StructType().add("a", IntegerType).add("b", TimestampType) + val schema = new StructType().add("a", IntegerType).add("b", DateType) // We use `PERMISSIVE` mode by default if invalid string is given. val df1 = spark .read @@ -1139,7 +1143,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te val schemaWithCorrField2 = new StructType() .add("a", IntegerType) .add(columnNameOfCorruptRecord, StringType) - .add("b", TimestampType) + .add("b", DateType) val df3 = spark .read .option("mode", "permissive") @@ -1325,7 +1329,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te val columnNameOfCorruptRecord = "_corrupt_record" val schema = new StructType() .add("a", IntegerType) - .add("b", TimestampType) + .add("b", DateType) .add(columnNameOfCorruptRecord, StringType) // negative cases val msg = intercept[AnalysisException] { From 556d83e0d87a8f899f29544eb5ca4999a84c96c1 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 4 Dec 2018 10:33:27 -0800 Subject: [PATCH 2216/2461] [SPARK-26233][SQL] CheckOverflow when encoding a decimal value ## What changes were proposed in this pull request? When we encode a Decimal from external source we don't check for overflow. That method is useful not only in order to enforce that we can represent the correct value in the specified range, but it also changes the underlying data to the right precision/scale. Since in our code generation we assume that a decimal has exactly the same precision and scale of its data type, missing to enforce it can lead to corrupted output/results when there are subsequent transformations. ## How was this patch tested? added UT Closes #23210 from mgaido91/SPARK-26233. Authored-by: Marco Gaido Signed-off-by: Dongjoon Hyun --- .../apache/spark/sql/catalyst/encoders/RowEncoder.scala | 4 ++-- .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 9 +++++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index d905f8f9858e8..8ca3d356f3bdc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -106,12 +106,12 @@ object RowEncoder { returnNullable = false) case d: DecimalType => - StaticInvoke( + CheckOverflow(StaticInvoke( Decimal.getClass, d, "fromDecimal", inputObject :: Nil, - returnNullable = false) + returnNullable = false), d) case StringType => StaticInvoke( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 0f900833d2cfe..525c7cef39563 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1647,6 +1647,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkDataset(ds, data: _*) checkAnswer(ds.select("x"), Seq(Row(1), Row(2))) } + + test("SPARK-26233: serializer should enforce decimal precision and scale") { + val s = StructType(Seq(StructField("a", StringType), StructField("b", DecimalType(38, 8)))) + val encoder = RowEncoder(s) + implicit val uEnc = encoder + val df = spark.range(2).map(l => Row(l.toString, BigDecimal.valueOf(l + 0.1111))) + checkAnswer(df.groupBy(col("a")).agg(first(col("b"))), + Seq(Row("0", BigDecimal.valueOf(0.1111)), Row("1", BigDecimal.valueOf(1.1111)))) + } } case class TestDataUnion(x: Int, y: Int, z: Int) From 35f9163adf5c067229afbe57ed60d5dd5f2422c8 Mon Sep 17 00:00:00 2001 From: Shahid Date: Tue, 4 Dec 2018 11:00:58 -0800 Subject: [PATCH 2217/2461] [SPARK-26119][CORE][WEBUI] Task summary table should contain only successful tasks' metrics ## What changes were proposed in this pull request? Task summary table in the stage page currently displays the summary of all the tasks. However, we should display the task summary of only successful tasks, to follow the behavior of previous versions of spark. ## How was this patch tested? Added UT. attached screenshot Before patch: ![screenshot from 2018-11-20 00-36-18](https://user-images.githubusercontent.com/23054875/48729339-62e3a580-ec5d-11e8-81f0-0d191a234ffe.png) ![screenshot from 2018-11-20 01-18-37](https://user-images.githubusercontent.com/23054875/48731112-41d18380-ec62-11e8-8c31-1ffbfa04e746.png) Closes #23088 from shahidki31/summaryMetrics. Authored-by: Shahid Signed-off-by: Marcelo Vanzin --- .../apache/spark/status/AppStatusStore.scala | 73 +++++++++++++------ .../spark/status/AppStatusStoreSuite.scala | 33 ++++++++- 2 files changed, 81 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 5c0ed4d5d8f4c..b35781cb36e81 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -148,11 +148,20 @@ private[spark] class AppStatusStore( // cheaper for disk stores (avoids deserialization). val count = { Utils.tryWithResource( - store.view(classOf[TaskDataWrapper]) - .parent(stageKey) - .index(TaskIndexNames.EXEC_RUN_TIME) - .first(0L) - .closeableIterator() + if (store.isInstanceOf[InMemoryStore]) { + store.view(classOf[TaskDataWrapper]) + .parent(stageKey) + .index(TaskIndexNames.STATUS) + .first("SUCCESS") + .last("SUCCESS") + .closeableIterator() + } else { + store.view(classOf[TaskDataWrapper]) + .parent(stageKey) + .index(TaskIndexNames.EXEC_RUN_TIME) + .first(0L) + .closeableIterator() + } ) { it => var _count = 0L while (it.hasNext()) { @@ -221,30 +230,50 @@ private[spark] class AppStatusStore( // stabilize once the stage finishes. It's also slow, especially with disk stores. val indices = quantiles.map { q => math.min((q * count).toLong, count - 1) } + // TODO: Summary metrics needs to display all the successful tasks' metrics (SPARK-26119). + // For InMemory case, it is efficient to find using the following code. But for diskStore case + // we need an efficient solution to avoid deserialization time overhead. For that, we need to + // rework on the way indexing works, so that we can index by specific metrics for successful + // and failed tasks differently (would be tricky). Also would require changing the disk store + // version (to invalidate old stores). def scanTasks(index: String)(fn: TaskDataWrapper => Long): IndexedSeq[Double] = { - Utils.tryWithResource( - store.view(classOf[TaskDataWrapper]) + if (store.isInstanceOf[InMemoryStore]) { + val quantileTasks = store.view(classOf[TaskDataWrapper]) .parent(stageKey) .index(index) .first(0L) - .closeableIterator() - ) { it => - var last = Double.NaN - var currentIdx = -1L - indices.map { idx => - if (idx == currentIdx) { - last - } else { - val diff = idx - currentIdx - currentIdx = idx - if (it.skip(diff - 1)) { - last = fn(it.next()).toDouble + .asScala + .filter { _.status == "SUCCESS"} // Filter "SUCCESS" tasks + .toIndexedSeq + + indices.map { index => + fn(quantileTasks(index.toInt)).toDouble + }.toIndexedSeq + } else { + Utils.tryWithResource( + store.view(classOf[TaskDataWrapper]) + .parent(stageKey) + .index(index) + .first(0L) + .closeableIterator() + ) { it => + var last = Double.NaN + var currentIdx = -1L + indices.map { idx => + if (idx == currentIdx) { last } else { - Double.NaN + val diff = idx - currentIdx + currentIdx = idx + if (it.skip(diff - 1)) { + last = fn(it.next()).toDouble + last + } else { + Double.NaN + } } - } - }.toIndexedSeq + }.toIndexedSeq + } } } diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala index 92f90f3d96ddf..75a658161d3ff 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala @@ -77,6 +77,34 @@ class AppStatusStoreSuite extends SparkFunSuite { assert(store.count(classOf[CachedQuantile]) === 2) } + test("only successfull task have taskSummary") { + val store = new InMemoryStore() + (0 until 5).foreach { i => store.write(newTaskData(i, status = "FAILED")) } + val appStore = new AppStatusStore(store).taskSummary(stageId, attemptId, uiQuantiles) + assert(appStore.size === 0) + } + + test("summary should contain task metrics of only successfull tasks") { + val store = new InMemoryStore() + + for (i <- 0 to 5) { + if (i % 2 == 1) { + store.write(newTaskData(i, status = "FAILED")) + } else { + store.write(newTaskData(i)) + } + } + + val summary = new AppStatusStore(store).taskSummary(stageId, attemptId, uiQuantiles).get + + val values = Array(0.0, 2.0, 4.0) + + val dist = new Distribution(values, 0, values.length).getQuantiles(uiQuantiles.sorted) + dist.zip(summary.executorRunTime).foreach { case (expected, actual) => + assert(expected === actual) + } + } + private def compareQuantiles(count: Int, quantiles: Array[Double]): Unit = { val store = new InMemoryStore() val values = (0 until count).map { i => @@ -93,12 +121,11 @@ class AppStatusStoreSuite extends SparkFunSuite { } } - private def newTaskData(i: Int): TaskDataWrapper = { + private def newTaskData(i: Int, status: String = "SUCCESS"): TaskDataWrapper = { new TaskDataWrapper( - i, i, i, i, i, i, i.toString, i.toString, i.toString, i.toString, false, Nil, None, + i, i, i, i, i, i, i.toString, i.toString, status, i.toString, false, Nil, None, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, stageId, attemptId) } - } From 180f969c97a66b4c265e5fad8272665a00572f1a Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 4 Dec 2018 14:35:04 -0800 Subject: [PATCH 2218/2461] [SPARK-26094][CORE][STREAMING] createNonEcFile creates parent dirs. ## What changes were proposed in this pull request? We explicitly avoid files with hdfs erasure coding for the streaming WAL and for event logs, as hdfs EC does not support all relevant apis. However, the new builder api used has different semantics -- it does not create parent dirs, and it does not resolve relative paths. This updates createNonEcFile to have similar semantics to the old api. ## How was this patch tested? Ran tests with the WAL pointed at a non-existent dir, which failed before this change. Manually tested the new function with a relative path as well. Unit tests via jenkins. Closes #23092 from squito/SPARK-26094. Authored-by: Imran Rashid Signed-off-by: Marcelo Vanzin --- .../scala/org/apache/spark/deploy/SparkHadoopUtil.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 7bb2a419107d6..937199273dab9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -466,7 +466,13 @@ private[spark] object SparkHadoopUtil { try { // Use reflection as this uses apis only avialable in hadoop 3 val builderMethod = fs.getClass().getMethod("createFile", classOf[Path]) - val builder = builderMethod.invoke(fs, path) + // the builder api does not resolve relative paths, nor does it create parent dirs, while + // the old api does. + if (!fs.mkdirs(path.getParent())) { + throw new IOException(s"Failed to create parents of $path") + } + val qualifiedPath = fs.makeQualified(path) + val builder = builderMethod.invoke(fs, qualifiedPath) val builderCls = builder.getClass() // this may throw a NoSuchMethodException if the path is not on hdfs val replicateMethod = builderCls.getMethod("replicate") From 7143e9d7220bd98ceb82c5c5f045108a8a664ec1 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 5 Dec 2018 09:12:24 +0800 Subject: [PATCH 2219/2461] [SPARK-25829][SQL][FOLLOWUP] Refactor MapConcat in order to check properly the limit size ## What changes were proposed in this pull request? The PR starts from the [comment](https://github.com/apache/spark/pull/23124#discussion_r236112390) in the main one and it aims at: - simplifying the code for `MapConcat`; - be more precise in checking the limit size. ## How was this patch tested? existing tests Closes #23217 from mgaido91/SPARK-25829_followup. Authored-by: Marco Gaido Signed-off-by: Wenchen Fan --- .../expressions/collectionOperations.scala | 77 +------------------ .../catalyst/util/ArrayBasedMapBuilder.scala | 10 +++ 2 files changed, 12 insertions(+), 75 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index fa8e38acd522d..67f6739b1e18f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -554,13 +554,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres return null } - val numElements = maps.foldLeft(0L)((sum, ad) => sum + ad.numElements()) - if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " + - s"elements due to exceeding the map size limit " + - s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") - } - for (map <- maps) { mapBuilder.putAll(map.keyArray(), map.valueArray()) } @@ -569,8 +562,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val mapCodes = children.map(_.genCode(ctx)) - val keyType = dataType.keyType - val valueType = dataType.valueType val argsName = ctx.freshName("args") val hasNullName = ctx.freshName("hasNull") val builderTerm = ctx.addReferenceObj("mapBuilder", mapBuilder) @@ -610,41 +601,12 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres ) val idxName = ctx.freshName("idx") - val numElementsName = ctx.freshName("numElems") - val finKeysName = ctx.freshName("finalKeys") - val finValsName = ctx.freshName("finalValues") - - val keyConcat = genCodeForArrays(ctx, keyType, false) - - val valueConcat = - if (valueType.sameType(keyType) && - !(CodeGenerator.isPrimitiveType(valueType) && dataType.valueContainsNull)) { - keyConcat - } else { - genCodeForArrays(ctx, valueType, dataType.valueContainsNull) - } - - val keyArgsName = ctx.freshName("keyArgs") - val valArgsName = ctx.freshName("valArgs") - val mapMerge = s""" - |ArrayData[] $keyArgsName = new ArrayData[${mapCodes.size}]; - |ArrayData[] $valArgsName = new ArrayData[${mapCodes.size}]; - |long $numElementsName = 0; |for (int $idxName = 0; $idxName < $argsName.length; $idxName++) { - | $keyArgsName[$idxName] = $argsName[$idxName].keyArray(); - | $valArgsName[$idxName] = $argsName[$idxName].valueArray(); - | $numElementsName += $argsName[$idxName].numElements(); + | $builderTerm.putAll($argsName[$idxName].keyArray(), $argsName[$idxName].valueArray()); |} - |if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw new RuntimeException("Unsuccessful attempt to concat maps with " + - | $numElementsName + " elements due to exceeding the map size limit " + - | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); - |} - |ArrayData $finKeysName = $keyConcat($keyArgsName, (int) $numElementsName); - |ArrayData $finValsName = $valueConcat($valArgsName, (int) $numElementsName); - |${ev.value} = $builderTerm.from($finKeysName, $finValsName); + |${ev.value} = $builderTerm.build(); """.stripMargin ev.copy( @@ -660,41 +622,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres """.stripMargin) } - private def genCodeForArrays( - ctx: CodegenContext, - elementType: DataType, - checkForNull: Boolean): String = { - val counter = ctx.freshName("counter") - val arrayData = ctx.freshName("arrayData") - val argsName = ctx.freshName("args") - val numElemName = ctx.freshName("numElements") - val y = ctx.freshName("y") - val z = ctx.freshName("z") - - val allocation = CodeGenerator.createArrayData( - arrayData, elementType, numElemName, s" $prettyName failed.") - val assignment = CodeGenerator.createArrayAssignment( - arrayData, elementType, s"$argsName[$y]", counter, z, checkForNull) - - val concat = ctx.freshName("concat") - val concatDef = - s""" - |private ArrayData $concat(ArrayData[] $argsName, int $numElemName) { - | $allocation - | int $counter = 0; - | for (int $y = 0; $y < ${children.length}; $y++) { - | for (int $z = 0; $z < $argsName[$y].numElements(); $z++) { - | $assignment - | $counter++; - | } - | } - | return $arrayData; - |} - """.stripMargin - - ctx.addNewFunction(concat, concatDef) - } - override def prettyName: String = "map_concat" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala index e7cd61655dc9a..98934368205ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.array.ByteArrayMethods /** * A builder of [[ArrayBasedMapData]], which fails if a null map key is detected, and removes @@ -54,6 +55,10 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria val index = keyToIndex.getOrDefault(key, -1) if (index == -1) { + if (size >= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful attempt to build maps with $size elements " + + s"due to exceeding the map size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + } keyToIndex.put(key, values.length) keys.append(key) values.append(value) @@ -117,4 +122,9 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria build() } } + + /** + * Returns the current size of the map which is going to be produced by the current builder. + */ + def size: Int = keys.size } From 7e3eb3cd209d83394ca2b2cec79b26b1bbe9d7ea Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 5 Dec 2018 15:22:08 +0800 Subject: [PATCH 2220/2461] [SPARK-26252][PYTHON] Add support to run specific unittests and/or doctests in python/run-tests script ## What changes were proposed in this pull request? This PR proposes add a developer option, `--testnames`, to our testing script to allow run specific set of unittests and doctests. **1. Run unittests in the class** ```bash ./run-tests --testnames 'pyspark.sql.tests.test_arrow ArrowTests' ``` ``` Running PySpark tests. Output is in /.../spark/python/unit-tests.log Will test against the following Python executables: ['python2.7', 'pypy'] Will test the following Python tests: ['pyspark.sql.tests.test_arrow ArrowTests'] Starting test(python2.7): pyspark.sql.tests.test_arrow ArrowTests Starting test(pypy): pyspark.sql.tests.test_arrow ArrowTests Finished test(python2.7): pyspark.sql.tests.test_arrow ArrowTests (14s) Finished test(pypy): pyspark.sql.tests.test_arrow ArrowTests (14s) ... 22 tests were skipped Tests passed in 14 seconds Skipped tests in pyspark.sql.tests.test_arrow ArrowTests with pypy: test_createDataFrame_column_name_encoding (pyspark.sql.tests.test_arrow.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.test_arrow.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' test_createDataFrame_fallback_disabled (pyspark.sql.tests.test_arrow.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' test_createDataFrame_fallback_enabled (pyspark.sql.tests.test_arrow.ArrowTests) ... skipped ... ``` **2. Run single unittest in the class.** ```bash ./run-tests --testnames 'pyspark.sql.tests.test_arrow ArrowTests.test_null_conversion' ``` ``` Running PySpark tests. Output is in /.../spark/python/unit-tests.log Will test against the following Python executables: ['python2.7', 'pypy'] Will test the following Python tests: ['pyspark.sql.tests.test_arrow ArrowTests.test_null_conversion'] Starting test(pypy): pyspark.sql.tests.test_arrow ArrowTests.test_null_conversion Starting test(python2.7): pyspark.sql.tests.test_arrow ArrowTests.test_null_conversion Finished test(pypy): pyspark.sql.tests.test_arrow ArrowTests.test_null_conversion (0s) ... 1 tests were skipped Finished test(python2.7): pyspark.sql.tests.test_arrow ArrowTests.test_null_conversion (8s) Tests passed in 8 seconds Skipped tests in pyspark.sql.tests.test_arrow ArrowTests.test_null_conversion with pypy: test_null_conversion (pyspark.sql.tests.test_arrow.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' ``` **3. Run doctests in single PySpark module.** ```bash ./run-tests --testnames pyspark.sql.dataframe ``` ``` Running PySpark tests. Output is in /.../spark/python/unit-tests.log Will test against the following Python executables: ['python2.7', 'pypy'] Will test the following Python tests: ['pyspark.sql.dataframe'] Starting test(pypy): pyspark.sql.dataframe Starting test(python2.7): pyspark.sql.dataframe Finished test(python2.7): pyspark.sql.dataframe (47s) Finished test(pypy): pyspark.sql.dataframe (48s) Tests passed in 48 seconds ``` Of course, you can mix them: ```bash ./run-tests --testnames 'pyspark.sql.tests.test_arrow ArrowTests,pyspark.sql.dataframe' ``` ``` Running PySpark tests. Output is in /.../spark/python/unit-tests.log Will test against the following Python executables: ['python2.7', 'pypy'] Will test the following Python tests: ['pyspark.sql.tests.test_arrow ArrowTests', 'pyspark.sql.dataframe'] Starting test(pypy): pyspark.sql.dataframe Starting test(pypy): pyspark.sql.tests.test_arrow ArrowTests Starting test(python2.7): pyspark.sql.dataframe Starting test(python2.7): pyspark.sql.tests.test_arrow ArrowTests Finished test(pypy): pyspark.sql.tests.test_arrow ArrowTests (0s) ... 22 tests were skipped Finished test(python2.7): pyspark.sql.tests.test_arrow ArrowTests (18s) Finished test(python2.7): pyspark.sql.dataframe (50s) Finished test(pypy): pyspark.sql.dataframe (52s) Tests passed in 52 seconds Skipped tests in pyspark.sql.tests.test_arrow ArrowTests with pypy: test_createDataFrame_column_name_encoding (pyspark.sql.tests.test_arrow.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.test_arrow.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' test_createDataFrame_fallback_disabled (pyspark.sql.tests.test_arrow.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' ``` and also you can use all other options (except `--modules`, which will be ignored) ```bash ./run-tests --testnames 'pyspark.sql.tests.test_arrow ArrowTests.test_null_conversion' --python-executables=python ``` ``` Running PySpark tests. Output is in /.../spark/python/unit-tests.log Will test against the following Python executables: ['python'] Will test the following Python tests: ['pyspark.sql.tests.test_arrow ArrowTests.test_null_conversion'] Starting test(python): pyspark.sql.tests.test_arrow ArrowTests.test_null_conversion Finished test(python): pyspark.sql.tests.test_arrow ArrowTests.test_null_conversion (12s) Tests passed in 12 seconds ``` See help below: ```bash ./run-tests --help ``` ``` Usage: run-tests [options] Options: ... Developer Options: --testnames=TESTNAMES A comma-separated list of specific modules, classes and functions of doctest or unittest to test. For example, 'pyspark.sql.foo' to run the module as unittests or doctests, 'pyspark.sql.tests FooTests' to run the specific class of unittests, 'pyspark.sql.tests FooTests.test_foo' to run the specific unittest in the class. '--modules' option is ignored if they are given. ``` I intentionally grouped it as a developer option to be more conservative. ## How was this patch tested? Manually tested. Negative tests were also done. ```bash ./run-tests --testnames 'pyspark.sql.tests.test_arrow ArrowTests.test_null_conversion1' --python-executables=python ``` ``` ... AttributeError: type object 'ArrowTests' has no attribute 'test_null_conversion1' ... ``` ```bash ./run-tests --testnames 'pyspark.sql.tests.test_arrow ArrowT' --python-executables=python ``` ``` ... AttributeError: 'module' object has no attribute 'ArrowT' ... ``` ```bash ./run-tests --testnames 'pyspark.sql.tests.test_ar' --python-executables=python ``` ``` ... /.../python2.7: No module named pyspark.sql.tests.test_ar ``` Closes #23203 from HyukjinKwon/SPARK-26252. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/run-tests-with-coverage | 2 - python/run-tests.py | 68 +++++++++++++++++++++++----------- 2 files changed, 46 insertions(+), 24 deletions(-) diff --git a/python/run-tests-with-coverage b/python/run-tests-with-coverage index 6d74b563e9140..457821037d43c 100755 --- a/python/run-tests-with-coverage +++ b/python/run-tests-with-coverage @@ -50,8 +50,6 @@ export SPARK_CONF_DIR="$COVERAGE_DIR/conf" # This environment variable enables the coverage. export COVERAGE_PROCESS_START="$FWDIR/.coveragerc" -# If you'd like to run a specific unittest class, you could do such as -# SPARK_TESTING=1 ../bin/pyspark pyspark.sql.tests VectorizedUDFTests ./run-tests "$@" # Don't run coverage for the coverage command itself diff --git a/python/run-tests.py b/python/run-tests.py index 01a6e81264dd6..e45268c13769a 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -19,7 +19,7 @@ from __future__ import print_function import logging -from optparse import OptionParser +from optparse import OptionParser, OptionGroup import os import re import shutil @@ -99,7 +99,7 @@ def run_individual_python_test(target_dir, test_name, pyspark_python): try: per_test_output = tempfile.TemporaryFile() retcode = subprocess.Popen( - [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], + [os.path.join(SPARK_HOME, "bin/pyspark")] + test_name.split(), stderr=per_test_output, stdout=per_test_output, env=env).wait() shutil.rmtree(tmp_dir, ignore_errors=True) except: @@ -190,6 +190,20 @@ def parse_opts(): help="Enable additional debug logging" ) + group = OptionGroup(parser, "Developer Options") + group.add_option( + "--testnames", type="string", + default=None, + help=( + "A comma-separated list of specific modules, classes and functions of doctest " + "or unittest to test. " + "For example, 'pyspark.sql.foo' to run the module as unittests or doctests, " + "'pyspark.sql.tests FooTests' to run the specific class of unittests, " + "'pyspark.sql.tests FooTests.test_foo' to run the specific unittest in the class. " + "'--modules' option is ignored if they are given.") + ) + parser.add_option_group(group) + (opts, args) = parser.parse_args() if args: parser.error("Unsupported arguments: %s" % ' '.join(args)) @@ -213,25 +227,31 @@ def _check_coverage(python_exec): def main(): opts = parse_opts() - if (opts.verbose): + if opts.verbose: log_level = logging.DEBUG else: log_level = logging.INFO + should_test_modules = opts.testnames is None logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s") LOGGER.info("Running PySpark tests. Output is in %s", LOG_FILE) if os.path.exists(LOG_FILE): os.remove(LOG_FILE) python_execs = opts.python_executables.split(',') - modules_to_test = [] - for module_name in opts.modules.split(','): - if module_name in python_modules: - modules_to_test.append(python_modules[module_name]) - else: - print("Error: unrecognized module '%s'. Supported modules: %s" % - (module_name, ", ".join(python_modules))) - sys.exit(-1) LOGGER.info("Will test against the following Python executables: %s", python_execs) - LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test]) + + if should_test_modules: + modules_to_test = [] + for module_name in opts.modules.split(','): + if module_name in python_modules: + modules_to_test.append(python_modules[module_name]) + else: + print("Error: unrecognized module '%s'. Supported modules: %s" % + (module_name, ", ".join(python_modules))) + sys.exit(-1) + LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test]) + else: + testnames_to_test = opts.testnames.split(',') + LOGGER.info("Will test the following Python tests: %s", testnames_to_test) task_queue = Queue.PriorityQueue() for python_exec in python_execs: @@ -246,16 +266,20 @@ def main(): LOGGER.debug("%s python_implementation is %s", python_exec, python_implementation) LOGGER.debug("%s version is: %s", python_exec, subprocess_check_output( [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip()) - for module in modules_to_test: - if python_implementation not in module.blacklisted_python_implementations: - for test_goal in module.python_test_goals: - heavy_tests = ['pyspark.streaming.tests', 'pyspark.mllib.tests', - 'pyspark.tests', 'pyspark.sql.tests', 'pyspark.ml.tests'] - if any(map(lambda prefix: test_goal.startswith(prefix), heavy_tests)): - priority = 0 - else: - priority = 100 - task_queue.put((priority, (python_exec, test_goal))) + if should_test_modules: + for module in modules_to_test: + if python_implementation not in module.blacklisted_python_implementations: + for test_goal in module.python_test_goals: + heavy_tests = ['pyspark.streaming.tests', 'pyspark.mllib.tests', + 'pyspark.tests', 'pyspark.sql.tests', 'pyspark.ml.tests'] + if any(map(lambda prefix: test_goal.startswith(prefix), heavy_tests)): + priority = 0 + else: + priority = 100 + task_queue.put((priority, (python_exec, test_goal))) + else: + for test_goal in testnames_to_test: + task_queue.put((0, (python_exec, test_goal))) # Create the target directory before starting tasks to avoid races. target_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'target')) From 169d9ad8f1b6006c8db0edbdfffc20dc73c78610 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 5 Dec 2018 19:30:25 +0800 Subject: [PATCH 2221/2461] [SPARK-26133][ML][FOLLOWUP] Fix doc for OneHotEncoder ## What changes were proposed in this pull request? This fixes doc of renamed OneHotEncoder in PySpark. ## How was this patch tested? N/A Closes #23230 from viirya/remove_one_hot_encoder_followup. Authored-by: Liang-Chi Hsieh Signed-off-by: Hyukjin Kwon --- python/pyspark/ml/feature.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 6cc80e181e5e0..c9507c20918e3 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1648,22 +1648,22 @@ class OneHotEncoder(JavaEstimator, HasInputCols, HasOutputCols, HasHandleInvalid at most a single one-value per row that indicates the input category index. For example with 5 categories, an input value of 2.0 would map to an output vector of `[0.0, 0.0, 1.0, 0.0]`. - The last category is not included by default (configurable via `dropLast`), + The last category is not included by default (configurable via :py:attr:`dropLast`), because it makes the vector entries sum up to one, and hence linearly dependent. So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. - Note: This is different from scikit-learn's OneHotEncoder, which keeps all categories. - The output vectors are sparse. + .. note:: This is different from scikit-learn's OneHotEncoder, which keeps all categories. + The output vectors are sparse. - When `handleInvalid` is configured to 'keep', an extra "category" indicating invalid values is - added as last category. So when `dropLast` is true, invalid values are encoded as all-zeros - vector. + When :py:attr:`handleInvalid` is configured to 'keep', an extra "category" indicating invalid + values is added as last category. So when :py:attr:`dropLast` is true, invalid values are + encoded as all-zeros vector. - Note: When encoding multi-column by using `inputCols` and `outputCols` params, input/output - cols come in pairs, specified by the order in the arrays, and each pair is treated - independently. + .. note:: When encoding multi-column by using :py:attr:`inputCols` and + :py:attr:`outputCols` params, input/output cols come in pairs, specified by the order in + the arrays, and each pair is treated independently. - See `StringIndexer` for converting categorical values into category indices + .. seealso:: :py:class:`StringIndexer` for converting categorical values into category indices >>> from pyspark.ml.linalg import Vectors >>> df = spark.createDataFrame([(0.0,), (1.0,), (2.0,)], ["input"]) @@ -1671,7 +1671,7 @@ class OneHotEncoder(JavaEstimator, HasInputCols, HasOutputCols, HasHandleInvalid >>> model = ohe.fit(df) >>> model.transform(df).head().output SparseVector(2, {0: 1.0}) - >>> ohePath = temp_path + "/oheEstimator" + >>> ohePath = temp_path + "/ohe" >>> ohe.save(ohePath) >>> loadedOHE = OneHotEncoder.load(ohePath) >>> loadedOHE.getInputCols() == ohe.getInputCols() From 7bb1dab8a006531d612e21d888c7fc6911990017 Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Wed, 5 Dec 2018 23:10:48 +0800 Subject: [PATCH 2222/2461] [SPARK-26271][FOLLOW-UP][SQL] remove unuse object SparkPlan ## What changes were proposed in this pull request? this code come from PR: https://github.com/apache/spark/pull/11190, but this code has never been used, only since PR: https://github.com/apache/spark/pull/14548, Let's continue fix it. thanks. ## How was this patch tested? N / A Closes #23227 from heary-cao/unuseSparkPlan. Authored-by: caoxuewen Signed-off-by: Wenchen Fan --- .../scala/org/apache/spark/sql/execution/SparkPlan.scala | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 9d9b020309d9f..a89ccca99d059 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -423,11 +423,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } } -object SparkPlan { - private[execution] val subqueryExecutionContext = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("subquery", 16)) -} - trait LeafExecNode extends SparkPlan { override final def children: Seq[SparkPlan] = Nil override def producedAttributes: AttributeSet = outputSet From dd518a196c2d40ae48034b8b0950d1c8045c02ed Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 5 Dec 2018 23:43:03 +0800 Subject: [PATCH 2223/2461] [SPARK-26151][SQL][FOLLOWUP] Return partial results for bad CSV records ## What changes were proposed in this pull request? Updated SQL migration guide according to changes in https://github.com/apache/spark/pull/23120 Closes #23235 from MaxGekk/failuresafe-partial-result-followup. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: Wenchen Fan --- docs/sql-migration-guide-upgrade.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index fee0e6df7177c..ed2ff139bcc33 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -35,6 +35,8 @@ displayTitle: Spark SQL Upgrading Guide - Since Spark 3.0, CSV datasource uses java.time API for parsing and generating CSV content. New formatting implementation supports date/timestamp patterns conformed to ISO 8601. To switch back to the implementation used in Spark 2.4 and earlier, set `spark.sql.legacy.timeParser.enabled` to `true`. + - In Spark version 2.4 and earlier, CSV datasource converts a malformed CSV string to a row with all `null`s in the PERMISSIVE mode. Since Spark 3.0, returned row can contain non-`null` fields if some of CSV column values were parsed and converted to desired types successfully. + ## Upgrading From Spark SQL 2.3 to 2.4 - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below. From ab76900fedc05df7080c9b6c81d65a3f260c1c26 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 6 Dec 2018 09:14:46 +0800 Subject: [PATCH 2224/2461] [SPARK-26275][PYTHON][ML] Increases timeout for StreamingLogisticRegressionWithSGDTests.test_training_and_prediction test ## What changes were proposed in this pull request? Looks this test is flaky https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/99704/console https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/99569/console https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/99644/console https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/99548/console https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/99454/console https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/99609/console ``` ====================================================================== FAIL: test_training_and_prediction (pyspark.mllib.tests.test_streaming_algorithms.StreamingLogisticRegressionWithSGDTests) Test that the model improves on toy data with no. of batches ---------------------------------------------------------------------- Traceback (most recent call last): File "/home/jenkins/workspace/SparkPullRequestBuilder/python/pyspark/mllib/tests/test_streaming_algorithms.py", line 367, in test_training_and_prediction self._eventually(condition) File "/home/jenkins/workspace/SparkPullRequestBuilder/python/pyspark/mllib/tests/test_streaming_algorithms.py", line 78, in _eventually % (timeout, lastValue)) AssertionError: Test failed due to timeout after 30 sec, with last condition returning: Latest errors: 0.67, 0.71, 0.78, 0.7, 0.75, 0.74, 0.73, 0.69, 0.62, 0.71, 0.69, 0.75, 0.72, 0.77, 0.71, 0.74 ---------------------------------------------------------------------- Ran 13 tests in 185.051s FAILED (failures=1, skipped=1) ``` This looks happening after increasing the parallelism in Jenkins to speed up at https://github.com/apache/spark/pull/23111. I am able to reproduce this manually when the resource usage is heavy (with manual decrease of timeout). ## How was this patch tested? Manually tested by ``` cd python ./run-tests --testnames 'pyspark.mllib.tests.test_streaming_algorithms StreamingLogisticRegressionWithSGDTests.test_training_and_prediction' --python-executables=python ``` Closes #23236 from HyukjinKwon/SPARK-26275. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/mllib/tests/test_streaming_algorithms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/mllib/tests/test_streaming_algorithms.py b/python/pyspark/mllib/tests/test_streaming_algorithms.py index 4bc8904acd31c..bf2ad2d267bb2 100644 --- a/python/pyspark/mllib/tests/test_streaming_algorithms.py +++ b/python/pyspark/mllib/tests/test_streaming_algorithms.py @@ -364,7 +364,7 @@ def condition(): return True return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) - self._eventually(condition) + self._eventually(condition, timeout=60.0) class StreamingLinearRegressionWithTests(MLLibStreamingTestCase): From ecaa495b1fe532c36e952ccac42f4715809476af Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 6 Dec 2018 10:07:28 -0800 Subject: [PATCH 2225/2461] [SPARK-25274][PYTHON][SQL] In toPandas with Arrow send un-ordered record batches to improve performance ## What changes were proposed in this pull request? When executing `toPandas` with Arrow enabled, partitions that arrive in the JVM out-of-order must be buffered before they can be send to Python. This causes an excess of memory to be used in the driver JVM and increases the time it takes to complete because data must sit in the JVM waiting for preceding partitions to come in. This change sends un-ordered partitions to Python as soon as they arrive in the JVM, followed by a list of partition indices so that Python can assemble the data in the correct order. This way, data is not buffered at the JVM and there is no waiting on particular partitions so performance will be increased. Followup to #21546 ## How was this patch tested? Added new test with a large number of batches per partition, and test that forces a small delay in the first partition. These test that partitions are collected out-of-order and then are are put in the correct order in Python. ## Performance Tests - toPandas Tests run on a 4 node standalone cluster with 32 cores total, 14.04.1-Ubuntu and OpenJDK 8 measured wall clock time to execute `toPandas()` and took the average best time of 5 runs/5 loops each. Test code ```python df = spark.range(1 << 25, numPartitions=32).toDF("id").withColumn("x1", rand()).withColumn("x2", rand()).withColumn("x3", rand()).withColumn("x4", rand()) for i in range(5): start = time.time() _ = df.toPandas() elapsed = time.time() - start ``` Spark config ``` spark.driver.memory 5g spark.executor.memory 5g spark.driver.maxResultSize 2g spark.sql.execution.arrow.enabled true ``` Current Master w/ Arrow stream | This PR ---------------------|------------ 5.16207 | 4.342533 5.133671 | 4.399408 5.147513 | 4.468471 5.105243 | 4.36524 5.018685 | 4.373791 Avg Master | Avg This PR ------------------|-------------- 5.1134364 | 4.3898886 Speedup of **1.164821449** Closes #22275 from BryanCutler/arrow-toPandas-oo-batches-SPARK-25274. Authored-by: Bryan Cutler Signed-off-by: Bryan Cutler --- python/pyspark/serializers.py | 33 ++++++++++++++ python/pyspark/sql/dataframe.py | 11 ++++- python/pyspark/sql/tests/test_arrow.py | 28 ++++++++++++ .../scala/org/apache/spark/sql/Dataset.scala | 45 ++++++++++--------- 4 files changed, 95 insertions(+), 22 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index ff9a612b77f61..f3ebd3767a0a1 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -185,6 +185,39 @@ def loads(self, obj): raise NotImplementedError +class ArrowCollectSerializer(Serializer): + """ + Deserialize a stream of batches followed by batch order information. Used in + DataFrame._collectAsArrow() after invoking Dataset.collectAsArrowToPython() in the JVM. + """ + + def __init__(self): + self.serializer = ArrowStreamSerializer() + + def dump_stream(self, iterator, stream): + return self.serializer.dump_stream(iterator, stream) + + def load_stream(self, stream): + """ + Load a stream of un-ordered Arrow RecordBatches, where the last iteration yields + a list of indices that can be used to put the RecordBatches in the correct order. + """ + # load the batches + for batch in self.serializer.load_stream(stream): + yield batch + + # load the batch order indices + num = read_int(stream) + batch_order = [] + for i in xrange(num): + index = read_int(stream) + batch_order.append(index) + yield batch_order + + def __repr__(self): + return "ArrowCollectSerializer(%s)" % self.serializer + + class ArrowStreamSerializer(Serializer): """ Serializes Arrow record batches as a stream. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1b1092c409be0..a1056d0b787e3 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -29,7 +29,7 @@ from pyspark import copy_func, since, _NoValue from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, PickleSerializer, \ +from pyspark.serializers import ArrowCollectSerializer, BatchedSerializer, PickleSerializer, \ UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync @@ -2168,7 +2168,14 @@ def _collectAsArrow(self): """ with SCCallSiteSync(self._sc) as css: sock_info = self._jdf.collectAsArrowToPython() - return list(_load_from_socket(sock_info, ArrowStreamSerializer())) + + # Collect list of un-ordered batches where last element is a list of correct order indices + results = list(_load_from_socket(sock_info, ArrowCollectSerializer())) + batches = results[:-1] + batch_order = results[-1] + + # Re-order the batch list using the correct order + return [batches[i] for i in batch_order] ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 6e75e82d58009..21fe5000df5d9 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -381,6 +381,34 @@ def test_timestamp_dst(self): self.assertPandasEqual(pdf, df_from_python.toPandas()) self.assertPandasEqual(pdf, df_from_pandas.toPandas()) + def test_toPandas_batch_order(self): + + def delay_first_part(partition_index, iterator): + if partition_index == 0: + time.sleep(0.1) + return iterator + + # Collects Arrow RecordBatches out of order in driver JVM then re-orders in Python + def run_test(num_records, num_parts, max_records, use_delay=False): + df = self.spark.range(num_records, numPartitions=num_parts).toDF("a") + if use_delay: + df = df.rdd.mapPartitionsWithIndex(delay_first_part).toDF() + with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": max_records}): + pdf, pdf_arrow = self._toPandas_arrow_toggle(df) + self.assertPandasEqual(pdf, pdf_arrow) + + cases = [ + (1024, 512, 2), # Use large num partitions for more likely collecting out of order + (64, 8, 2, True), # Use delay in first partition to force collecting out of order + (64, 64, 1), # Test single batch per partition + (64, 1, 64), # Test single partition, single batch + (64, 1, 8), # Test single partition, multiple batches + (30, 7, 2), # Test different sized partitions + ] + + for case in cases: + run_test(*case) + class EncryptionArrowTests(ArrowTests): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b10d66dfb1aef..a664c7338badb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql -import java.io.CharArrayWriter +import java.io.{CharArrayWriter, DataOutputStream} import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.language.implicitConversions import scala.util.control.NonFatal @@ -3200,34 +3201,38 @@ class Dataset[T] private[sql]( val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone withAction("collectAsArrowToPython", queryExecution) { plan => - PythonRDD.serveToStream("serve-Arrow") { out => + PythonRDD.serveToStream("serve-Arrow") { outputStream => + val out = new DataOutputStream(outputStream) val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) val arrowBatchRdd = toArrowBatchRdd(plan) val numPartitions = arrowBatchRdd.partitions.length - // Store collection results for worst case of 1 to N-1 partitions - val results = new Array[Array[Array[Byte]]](numPartitions - 1) - var lastIndex = -1 // index of last partition written + // Batches ordered by (index of partition, batch index in that partition) tuple + val batchOrder = new ArrayBuffer[(Int, Int)]() + var partitionCount = 0 - // Handler to eagerly write partitions to Python in order + // Handler to eagerly write batches to Python as they arrive, un-ordered def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { - // If result is from next partition in order - if (index - 1 == lastIndex) { + if (arrowBatches.nonEmpty) { + // Write all batches (can be more than 1) in the partition, store the batch order tuple batchWriter.writeBatches(arrowBatches.iterator) - lastIndex += 1 - // Write stored partitions that come next in order - while (lastIndex < results.length && results(lastIndex) != null) { - batchWriter.writeBatches(results(lastIndex).iterator) - results(lastIndex) = null - lastIndex += 1 + arrowBatches.indices.foreach { + partition_batch_index => batchOrder.append((index, partition_batch_index)) } - // After last batch, end the stream - if (lastIndex == results.length) { - batchWriter.end() + } + partitionCount += 1 + + // After last batch, end the stream and write batch order indices + if (partitionCount == numPartitions) { + batchWriter.end() + out.writeInt(batchOrder.length) + // Sort by (index of partition, batch index in that partition) tuple to get the + // overall_batch_index from 0 to N-1 batches, which can be used to put the + // transferred batches in the correct order + batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overall_batch_index) => + out.writeInt(overall_batch_index) } - } else { - // Store partitions received out of order - results(index - 1) = arrowBatches + out.flush() } } From b14a26ee5764aa98472bc69ab1dec408b89bc78a Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Thu, 6 Dec 2018 10:59:20 -0800 Subject: [PATCH 2226/2461] [SPARK-26236][SS] Add kafka delegation token support documentation. ## What changes were proposed in this pull request? Kafka delegation token support implemented in [PR#22598](https://github.com/apache/spark/pull/22598) but that didn't contain documentation because of rapid changes. Because it has been merged in this PR I've documented it. ## How was this patch tested? jekyll build + manual html check Closes #23195 from gaborgsomogyi/SPARK-26236. Authored-by: Gabor Somogyi Signed-off-by: Marcelo Vanzin --- .../structured-streaming-kafka-integration.md | 216 +++++++++++++++++- 1 file changed, 206 insertions(+), 10 deletions(-) diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index a549ce2a6a05f..7040f8da2c614 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -66,8 +66,8 @@ Dataset df = spark .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribe", "topic1") - .load() -df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .load(); +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); // Subscribe to multiple topics Dataset df = spark @@ -75,8 +75,8 @@ Dataset df = spark .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribe", "topic1,topic2") - .load() -df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .load(); +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); // Subscribe to a pattern Dataset df = spark @@ -84,8 +84,8 @@ Dataset df = spark .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribePattern", "topic.*") - .load() -df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .load(); +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); {% endhighlight %} @@ -479,7 +479,7 @@ StreamingQuery ds = df .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("topic", "topic1") - .start() + .start(); // Write key-value data from a DataFrame to Kafka using a topic specified in the data StreamingQuery ds = df @@ -487,7 +487,7 @@ StreamingQuery ds = df .writeStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .start() + .start(); {% endhighlight %} @@ -547,14 +547,14 @@ df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("topic", "topic1") - .save() + .save(); // Write key-value data from a DataFrame to Kafka using a topic specified in the data df.selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") .write() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .save() + .save(); {% endhighlight %} @@ -624,3 +624,199 @@ For experimenting on `spark-shell`, you can also use `--packages` to add `spark- See [Application Submission Guide](submitting-applications.html) for more details about submitting applications with external dependencies. + +## Security + +Kafka 0.9.0.0 introduced several features that increases security in a cluster. For detailed +description about these possibilities, see [Kafka security docs](http://kafka.apache.org/documentation.html#security). + +It's worth noting that security is optional and turned off by default. + +Spark supports the following ways to authenticate against Kafka cluster: +- **Delegation token (introduced in Kafka broker 1.1.0)** +- **JAAS login configuration** + +### Delegation token + +This way the application can be configured via Spark parameters and may not need JAAS login +configuration (Spark can use Kafka's dynamic JAAS configuration feature). For further information +about delegation tokens, see [Kafka delegation token docs](http://kafka.apache.org/documentation/#security_delegation_token). + +The process is initiated by Spark's Kafka delegation token provider. When `spark.kafka.bootstrap.servers`, +Spark considers the following log in options, in order of preference: +- **JAAS login configuration** +- **Keytab file**, such as, + + ./bin/spark-submit \ + --keytab \ + --principal \ + --conf spark.kafka.bootstrap.servers= \ + ... + +- **Kerberos credential cache**, such as, + + ./bin/spark-submit \ + --conf spark.kafka.bootstrap.servers= \ + ... + +The Kafka delegation token provider can be turned off by setting `spark.security.credentials.kafka.enabled` to `false` (default: `true`). + +Spark can be configured to use the following authentication protocols to obtain token (it must match with +Kafka broker configuration): +- **SASL SSL (default)** +- **SSL** +- **SASL PLAINTEXT (for testing)** + +After obtaining delegation token successfully, Spark distributes it across nodes and renews it accordingly. +Delegation token uses `SCRAM` login module for authentication and because of that the appropriate +`sasl.mechanism` has to be configured on source/sink (it must match with Kafka broker configuration): + +
      +
      +{% highlight scala %} + +// Setting on Kafka Source for Streaming Queries +val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("kafka.sasl.mechanism", "SCRAM-SHA-512") + .option("subscribe", "topic1") + .load() +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + +// Setting on Kafka Source for Batch Queries +val df = spark + .read + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("kafka.sasl.mechanism", "SCRAM-SHA-512") + .option("subscribe", "topic1") + .load() +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + +// Setting on Kafka Sink for Streaming Queries +val ds = df + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("kafka.sasl.mechanism", "SCRAM-SHA-512") + .option("topic", "topic1") + .start() + +// Setting on Kafka Sink for Batch Queries +val ds = df + .selectExpr("topic1", "CAST(key AS STRING)", "CAST(value AS STRING)") + .write + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("kafka.sasl.mechanism", "SCRAM-SHA-512") + .save() + +{% endhighlight %} +
      +
      +{% highlight java %} + +// Setting on Kafka Source for Streaming Queries +Dataset df = spark + .readStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("kafka.sasl.mechanism", "SCRAM-SHA-512") + .option("subscribe", "topic1") + .load(); +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); + +// Setting on Kafka Source for Batch Queries +Dataset df = spark + .read() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("kafka.sasl.mechanism", "SCRAM-SHA-512") + .option("subscribe", "topic1") + .load(); +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); + +// Setting on Kafka Sink for Streaming Queries +StreamingQuery ds = df + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("kafka.sasl.mechanism", "SCRAM-SHA-512") + .option("topic", "topic1") + .start(); + +// Setting on Kafka Sink for Batch Queries +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .write() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("kafka.sasl.mechanism", "SCRAM-SHA-512") + .option("topic", "topic1") + .save(); + +{% endhighlight %} +
      +
      +{% highlight python %} + +// Setting on Kafka Source for Streaming Queries +df = spark \ + .readStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("kafka.sasl.mechanism", "SCRAM-SHA-512") \ + .option("subscribe", "topic1") \ + .load() +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + +// Setting on Kafka Source for Batch Queries +df = spark \ + .read \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("kafka.sasl.mechanism", "SCRAM-SHA-512") \ + .option("subscribe", "topic1") \ + .load() +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + +// Setting on Kafka Sink for Streaming Queries +ds = df \ + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") \ + .writeStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("kafka.sasl.mechanism", "SCRAM-SHA-512") \ + .option("topic", "topic1") \ + .start() + +// Setting on Kafka Sink for Batch Queries +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") \ + .write \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("kafka.sasl.mechanism", "SCRAM-SHA-512") \ + .option("topic", "topic1") \ + .save() + +{% endhighlight %} +
      +
      + +When delegation token is available on an executor it can be overridden with JAAS login configuration. + +### JAAS login configuration + +JAAS login configuration must placed on all nodes where Spark tries to access Kafka cluster. +This provides the possibility to apply any custom authentication logic with a higher cost to maintain. +This can be done several ways. One possibility is to provide additional JVM parameters, such as, + + ./bin/spark-submit \ + --driver-java-options "-Djava.security.auth.login.config=/path/to/custom_jaas.conf" \ + --conf spark.executor.extraJavaOptions=-Djava.security.auth.login.config=/path/to/custom_jaas.conf \ + ... From dbd90e54408d593e02a3dd1e659fcf9a7b940535 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 6 Dec 2018 14:17:13 -0800 Subject: [PATCH 2227/2461] [SPARK-26194][K8S] Auto generate auth secret for k8s apps. This change modifies the logic in the SecurityManager to do two things: - generate unique app secrets also when k8s is being used - only store the secret in the user's UGI on YARN The latter is needed so that k8s won't unnecessarily create k8s secrets for the UGI credentials when only the auth token is stored there. On the k8s side, the secret is propagated to executors using an environment variable instead. This ensures it works in both client and cluster mode. Security doc was updated to mention the feature and clarify that proper access control in k8s should be enabled for it to be secure. Author: Marcelo Vanzin Closes #23174 from vanzin/SPARK-26194. --- .../org/apache/spark/SecurityManager.scala | 21 +++- .../apache/spark/SecurityManagerSuite.scala | 57 +++++++---- docs/security.md | 34 ++++--- .../features/BasicExecutorFeatureStep.scala | 96 +++++++++++-------- .../cluster/k8s/ExecutorPodsAllocator.scala | 5 +- .../k8s/KubernetesClusterManager.scala | 3 +- .../KubernetesClusterSchedulerBackend.scala | 5 +- .../k8s/KubernetesExecutorBuilder.scala | 13 ++- .../BasicExecutorFeatureStepSuite.scala | 46 ++++++--- .../k8s/ExecutorPodsAllocatorSuite.scala | 9 +- ...bernetesClusterSchedulerBackendSuite.scala | 9 +- .../k8s/KubernetesExecutorBuilderSuite.scala | 18 ++-- .../k8s/integrationtest/KubernetesSuite.scala | 2 + 13 files changed, 205 insertions(+), 113 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 3cfafeb951105..96e4b53b24181 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -348,15 +348,23 @@ private[spark] class SecurityManager( */ def initializeAuth(): Unit = { import SparkMasterRegex._ + val k8sRegex = "k8s.*".r if (!sparkConf.get(NETWORK_AUTH_ENABLED)) { return } + // TODO: this really should be abstracted somewhere else. val master = sparkConf.get(SparkLauncher.SPARK_MASTER, "") - master match { + val storeInUgi = master match { case "yarn" | "local" | LOCAL_N_REGEX(_) | LOCAL_N_FAILURES_REGEX(_, _) => - // Secret generation allowed here + true + + case k8sRegex() => + // Don't propagate the secret through the user's credentials in kubernetes. That conflicts + // with the way k8s handles propagation of delegation tokens. + false + case _ => require(sparkConf.contains(SPARK_AUTH_SECRET_CONF), s"A secret key must be specified via the $SPARK_AUTH_SECRET_CONF config.") @@ -364,9 +372,12 @@ private[spark] class SecurityManager( } secretKey = Utils.createSecret(sparkConf) - val creds = new Credentials() - creds.addSecretKey(SECRET_LOOKUP_KEY, secretKey.getBytes(UTF_8)) - UserGroupInformation.getCurrentUser().addCredentials(creds) + + if (storeInUgi) { + val creds = new Credentials() + creds.addSecretKey(SECRET_LOOKUP_KEY, secretKey.getBytes(UTF_8)) + UserGroupInformation.getCurrentUser().addCredentials(creds) + } } // Default SecurityManager only has a single secret key, so ignore appId. diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index e357299770a2e..eec8004fc94f4 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -395,15 +395,23 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { assert(keyFromEnv === new SecurityManager(conf2).getSecretKey()) } - test("secret key generation") { - Seq( - ("yarn", true), - ("local", true), - ("local[*]", true), - ("local[1, 2]", true), - ("local-cluster[2, 1, 1024]", false), - ("invalid", false) - ).foreach { case (master, shouldGenerateSecret) => + // How is the secret expected to be generated and stored. + object SecretTestType extends Enumeration { + val MANUAL, AUTO, UGI = Value + } + + import SecretTestType._ + + Seq( + ("yarn", UGI), + ("local", UGI), + ("local[*]", UGI), + ("local[1, 2]", UGI), + ("k8s://127.0.0.1", AUTO), + ("local-cluster[2, 1, 1024]", MANUAL), + ("invalid", MANUAL) + ).foreach { case (master, secretType) => + test(s"secret key generation: master '$master'") { val conf = new SparkConf() .set(NETWORK_AUTH_ENABLED, true) .set(SparkLauncher.SPARK_MASTER, master) @@ -412,19 +420,26 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { UserGroupInformation.createUserForTesting("authTest", Array()).doAs( new PrivilegedExceptionAction[Unit]() { override def run(): Unit = { - if (shouldGenerateSecret) { - mgr.initializeAuth() - val creds = UserGroupInformation.getCurrentUser().getCredentials() - val secret = creds.getSecretKey(SecurityManager.SECRET_LOOKUP_KEY) - assert(secret != null) - assert(new String(secret, UTF_8) === mgr.getSecretKey()) - } else { - intercept[IllegalArgumentException] { + secretType match { + case UGI => + mgr.initializeAuth() + val creds = UserGroupInformation.getCurrentUser().getCredentials() + val secret = creds.getSecretKey(SecurityManager.SECRET_LOOKUP_KEY) + assert(secret != null) + assert(new String(secret, UTF_8) === mgr.getSecretKey()) + + case AUTO => mgr.initializeAuth() - } - intercept[IllegalArgumentException] { - mgr.getSecretKey() - } + val creds = UserGroupInformation.getCurrentUser().getCredentials() + assert(creds.getSecretKey(SecurityManager.SECRET_LOOKUP_KEY) === null) + + case MANUAL => + intercept[IllegalArgumentException] { + mgr.initializeAuth() + } + intercept[IllegalArgumentException] { + mgr.getSecretKey() + } } } } diff --git a/docs/security.md b/docs/security.md index be4834660fb7a..2a4f3c074c1e5 100644 --- a/docs/security.md +++ b/docs/security.md @@ -26,21 +26,29 @@ not documented, Spark does not support. Spark currently supports authentication for RPC channels using a shared secret. Authentication can be turned on by setting the `spark.authenticate` configuration parameter. -The exact mechanism used to generate and distribute the shared secret is deployment-specific. +The exact mechanism used to generate and distribute the shared secret is deployment-specific. Unless +specified below, the secret must be defined by setting the `spark.authenticate.secret` config +option. The same secret is shared by all Spark applications and daemons in that case, which limits +the security of these deployments, especially on multi-tenant clusters. -For Spark on [YARN](running-on-yarn.html) and local deployments, Spark will automatically handle -generating and distributing the shared secret. Each application will use a unique shared secret. In +The REST Submission Server and the MesosClusterDispatcher do not support authentication. You should +ensure that all network access to the REST API & MesosClusterDispatcher (port 6066 and 7077 +respectively by default) are restricted to hosts that are trusted to submit jobs. + +### YARN + +For Spark on [YARN](running-on-yarn.html), Spark will automatically handle generating and +distributing the shared secret. Each application will use a unique shared secret. In the case of YARN, this feature relies on YARN RPC encryption being enabled for the distribution of secrets to be secure. -For other resource managers, `spark.authenticate.secret` must be configured on each of the nodes. -This secret will be shared by all the daemons and applications, so this deployment configuration is -not as secure as the above, especially when considering multi-tenant clusters. In this -configuration, a user with the secret can effectively impersonate any other user. +### Kubernetes -The Rest Submission Server and the MesosClusterDispatcher do not support authentication. You should -ensure that all network access to the REST API & MesosClusterDispatcher (port 6066 and 7077 -respectively by default) are restricted to hosts that are trusted to submit jobs. +On Kubernetes, Spark will also automatically generate an authentication secret unique to each +application. The secret is propagated to executor pods using environment variables. This means +that any user that can list pods in the namespace where the Spark application is running can +also see their authentication secret. Access control rules should be properly set up by the +Kubernetes admin to ensure that Spark authentication is secure. @@ -738,10 +746,10 @@ tokens for supported will be created. ## Secure Interaction with Kubernetes When talking to Hadoop-based services behind Kerberos, it was noted that Spark needs to obtain delegation tokens -so that non-local processes can authenticate. These delegation tokens in Kubernetes are stored in Secrets that are -shared by the Driver and its Executors. As such, there are three ways of submitting a Kerberos job: +so that non-local processes can authenticate. These delegation tokens in Kubernetes are stored in Secrets that are +shared by the Driver and its Executors. As such, there are three ways of submitting a Kerberos job: -In all cases you must define the environment variable: `HADOOP_CONF_DIR` or +In all cases you must define the environment variable: `HADOOP_CONF_DIR` or `spark.kubernetes.hadoop.configMapName.` It also important to note that the KDC needs to be visible from inside the containers. diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index 8bf315248388f..939aa88b07973 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -20,7 +20,7 @@ import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model._ -import org.apache.spark.SparkException +import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ @@ -29,11 +29,12 @@ import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils -private[spark] class BasicExecutorFeatureStep(kubernetesConf: KubernetesExecutorConf) +private[spark] class BasicExecutorFeatureStep( + kubernetesConf: KubernetesExecutorConf, + secMgr: SecurityManager) extends KubernetesFeatureConfigStep { // Consider moving some of these fields to KubernetesConf or KubernetesExecutorSpecificConf - private val executorExtraClasspath = kubernetesConf.get(EXECUTOR_CLASS_PATH) private val executorContainerImage = kubernetesConf .get(EXECUTOR_CONTAINER_IMAGE) .getOrElse(throw new SparkException("Must specify the executor container image")) @@ -87,44 +88,61 @@ private[spark] class BasicExecutorFeatureStep(kubernetesConf: KubernetesExecutor val executorCpuQuantity = new QuantityBuilder(false) .withAmount(executorCoresRequest) .build() - val executorExtraClasspathEnv = executorExtraClasspath.map { cp => - new EnvVarBuilder() - .withName(ENV_CLASSPATH) - .withValue(cp) - .build() - } - val executorExtraJavaOptionsEnv = kubernetesConf - .get(EXECUTOR_JAVA_OPTIONS) - .map { opts => - val subsOpts = Utils.substituteAppNExecIds(opts, kubernetesConf.appId, - kubernetesConf.executorId) - val delimitedOpts = Utils.splitCommandString(subsOpts) - delimitedOpts.zipWithIndex.map { - case (opt, index) => - new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() + + val executorEnv: Seq[EnvVar] = { + (Seq( + (ENV_DRIVER_URL, driverUrl), + (ENV_EXECUTOR_CORES, executorCores.toString), + (ENV_EXECUTOR_MEMORY, executorMemoryString), + (ENV_APPLICATION_ID, kubernetesConf.appId), + // This is to set the SPARK_CONF_DIR to be /opt/spark/conf + (ENV_SPARK_CONF_DIR, SPARK_CONF_DIR_INTERNAL), + (ENV_EXECUTOR_ID, kubernetesConf.executorId) + ) ++ kubernetesConf.environment).map { case (k, v) => + new EnvVarBuilder() + .withName(k) + .withValue(v) + .build() } - }.getOrElse(Seq.empty[EnvVar]) - val executorEnv = (Seq( - (ENV_DRIVER_URL, driverUrl), - (ENV_EXECUTOR_CORES, executorCores.toString), - (ENV_EXECUTOR_MEMORY, executorMemoryString), - (ENV_APPLICATION_ID, kubernetesConf.appId), - // This is to set the SPARK_CONF_DIR to be /opt/spark/conf - (ENV_SPARK_CONF_DIR, SPARK_CONF_DIR_INTERNAL), - (ENV_EXECUTOR_ID, kubernetesConf.executorId)) ++ - kubernetesConf.environment) - .map(env => new EnvVarBuilder() - .withName(env._1) - .withValue(env._2) - .build() - ) ++ Seq( - new EnvVarBuilder() - .withName(ENV_EXECUTOR_POD_IP) - .withValueFrom(new EnvVarSourceBuilder() - .withNewFieldRef("v1", "status.podIP") + } ++ { + Seq(new EnvVarBuilder() + .withName(ENV_EXECUTOR_POD_IP) + .withValueFrom(new EnvVarSourceBuilder() + .withNewFieldRef("v1", "status.podIP") + .build()) .build()) - .build() - ) ++ executorExtraJavaOptionsEnv ++ executorExtraClasspathEnv.toSeq + } ++ { + Option(secMgr.getSecretKey()).map { authSecret => + new EnvVarBuilder() + .withName(SecurityManager.ENV_AUTH_SECRET) + .withValue(authSecret) + .build() + } + } ++ { + kubernetesConf.get(EXECUTOR_CLASS_PATH).map { cp => + new EnvVarBuilder() + .withName(ENV_CLASSPATH) + .withValue(cp) + .build() + } + } ++ { + val userOpts = kubernetesConf.get(EXECUTOR_JAVA_OPTIONS).toSeq.flatMap { opts => + val subsOpts = Utils.substituteAppNExecIds(opts, kubernetesConf.appId, + kubernetesConf.executorId) + Utils.splitCommandString(subsOpts) + } + + val sparkOpts = Utils.sparkJavaOpts(kubernetesConf.sparkConf, + SparkConf.isExecutorStartupConf) + + (userOpts ++ sparkOpts).zipWithIndex.map { case (opt, index) => + new EnvVarBuilder() + .withName(s"$ENV_JAVA_OPT_PREFIX$index") + .withValue(opt) + .build() + } + } + val requiredPorts = Seq( (BLOCK_MANAGER_PORT_NAME, blockManagerPort)) .map { case (name, port) => diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala index 2f0f949566d6a..ac42554b1334b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala @@ -22,7 +22,7 @@ import io.fabric8.kubernetes.api.model.PodBuilder import io.fabric8.kubernetes.client.KubernetesClient import scala.collection.mutable -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.KubernetesConf @@ -31,6 +31,7 @@ import org.apache.spark.util.{Clock, Utils} private[spark] class ExecutorPodsAllocator( conf: SparkConf, + secMgr: SecurityManager, executorBuilder: KubernetesExecutorBuilder, kubernetesClient: KubernetesClient, snapshotsStore: ExecutorPodsSnapshotsStore, @@ -135,7 +136,7 @@ private[spark] class ExecutorPodsAllocator( newExecutorId.toString, applicationId, driverPod) - val executorPod = executorBuilder.buildFromFeatures(executorConf) + val executorPod = executorBuilder.buildFromFeatures(executorConf, secMgr) val podWithAttachedContainer = new PodBuilder(executorPod.pod) .editOrNewSpec() .addToContainers(executorPod.container) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index ce10f766334ff..b31fbb420ed6d 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -94,6 +94,7 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit val executorPodsAllocator = new ExecutorPodsAllocator( sc.conf, + sc.env.securityManager, KubernetesExecutorBuilder(kubernetesClient, sc.conf), kubernetesClient, snapshotsStore, @@ -110,7 +111,7 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit new KubernetesClusterSchedulerBackend( scheduler.asInstanceOf[TaskSchedulerImpl], - sc.env.rpcEnv, + sc, kubernetesClient, requestExecutorsService, snapshotsStore, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index 6356b58645806..68f6f2e46e316 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -21,6 +21,7 @@ import java.util.concurrent.ExecutorService import io.fabric8.kubernetes.client.KubernetesClient import scala.concurrent.{ExecutionContext, Future} +import org.apache.spark.SparkContext import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.rpc.{RpcAddress, RpcEnv} @@ -30,7 +31,7 @@ import org.apache.spark.util.{ThreadUtils, Utils} private[spark] class KubernetesClusterSchedulerBackend( scheduler: TaskSchedulerImpl, - rpcEnv: RpcEnv, + sc: SparkContext, kubernetesClient: KubernetesClient, requestExecutorsService: ExecutorService, snapshotsStore: ExecutorPodsSnapshotsStore, @@ -38,7 +39,7 @@ private[spark] class KubernetesClusterSchedulerBackend( lifecycleEventHandler: ExecutorPodsLifecycleManager, watchEvents: ExecutorPodsWatchSnapshotSource, pollEvents: ExecutorPodsPollingSnapshotSource) - extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) { private implicit val requestExecutorContext = ExecutionContext.fromExecutorService( requestExecutorsService) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index d24ff0d1e6600..ba273cad6a8e5 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -20,14 +20,14 @@ import java.io.File import io.fabric8.kubernetes.client.KubernetesClient -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.features._ private[spark] class KubernetesExecutorBuilder( - provideBasicStep: (KubernetesExecutorConf => BasicExecutorFeatureStep) = - new BasicExecutorFeatureStep(_), + provideBasicStep: (KubernetesExecutorConf, SecurityManager) => BasicExecutorFeatureStep = + new BasicExecutorFeatureStep(_, _), provideSecretsStep: (KubernetesConf => MountSecretsFeatureStep) = new MountSecretsFeatureStep(_), provideEnvSecretsStep: (KubernetesConf => EnvSecretsFeatureStep) = @@ -44,13 +44,16 @@ private[spark] class KubernetesExecutorBuilder( new HadoopSparkUserExecutorFeatureStep(_), provideInitialPod: () => SparkPod = () => SparkPod.initialPod()) { - def buildFromFeatures(kubernetesConf: KubernetesExecutorConf): SparkPod = { + def buildFromFeatures( + kubernetesConf: KubernetesExecutorConf, + secMgr: SecurityManager): SparkPod = { val sparkConf = kubernetesConf.sparkConf val maybeHadoopConfigMap = sparkConf.getOption(HADOOP_CONFIG_MAP_NAME) val maybeDTSecretName = sparkConf.getOption(KERBEROS_DT_SECRET_NAME) val maybeDTDataItem = sparkConf.getOption(KERBEROS_DT_SECRET_KEY) - val baseFeatures = Seq(provideBasicStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) + val baseFeatures = Seq(provideBasicStep(kubernetesConf, secMgr), + provideLocalDirsStep(kubernetesConf)) val secretFeature = if (kubernetesConf.secretNamesToMountPaths.nonEmpty) { Seq(provideSecretsStep(kubernetesConf)) } else Nil diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index d6003c977937c..6aa862643c788 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -21,13 +21,14 @@ import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model._ import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, KubernetesTestConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.config._ import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.util.Utils class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { @@ -63,7 +64,7 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { private var baseConf: SparkConf = _ before { - baseConf = new SparkConf() + baseConf = new SparkConf(false) .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME) .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, RESOURCE_NAME_PREFIX) .set(CONTAINER_IMAGE, EXECUTOR_IMAGE) @@ -84,7 +85,7 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { } test("basic executor pod has reasonable defaults") { - val step = new BasicExecutorFeatureStep(newExecutorConf()) + val step = new BasicExecutorFeatureStep(newExecutorConf(), new SecurityManager(baseConf)) val executor = step.configurePod(SparkPod.initialPod()) // The executor pod name and default labels. @@ -106,7 +107,7 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { assert(executor.pod.getSpec.getNodeSelector.isEmpty) assert(executor.pod.getSpec.getVolumes.isEmpty) - checkEnv(executor, Map()) + checkEnv(executor, baseConf, Map()) checkOwnerReferences(executor.pod, DRIVER_POD_UID) } @@ -114,7 +115,7 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { val longPodNamePrefix = "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple" baseConf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, longPodNamePrefix) - val step = new BasicExecutorFeatureStep(newExecutorConf()) + val step = new BasicExecutorFeatureStep(newExecutorConf(), new SecurityManager(baseConf)) assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63) } @@ -122,10 +123,10 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { baseConf.set(EXECUTOR_JAVA_OPTIONS, "foo=bar") baseConf.set(EXECUTOR_CLASS_PATH, "bar=baz") val kconf = newExecutorConf(environment = Map("qux" -> "quux")) - val step = new BasicExecutorFeatureStep(kconf) + val step = new BasicExecutorFeatureStep(kconf, new SecurityManager(baseConf)) val executor = step.configurePod(SparkPod.initialPod()) - checkEnv(executor, + checkEnv(executor, baseConf, Map("SPARK_JAVA_OPT_0" -> "foo=bar", ENV_CLASSPATH -> "bar=baz", "qux" -> "quux")) @@ -136,12 +137,27 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { baseConf.set("spark.kubernetes.resource.type", "python") baseConf.set(PYSPARK_EXECUTOR_MEMORY, 42L) - val step = new BasicExecutorFeatureStep(newExecutorConf()) + val step = new BasicExecutorFeatureStep(newExecutorConf(), new SecurityManager(baseConf)) val executor = step.configurePod(SparkPod.initialPod()) // This is checking that basic executor + executorMemory = 1408 + 42 = 1450 assert(executor.container.getResources.getRequests.get("memory").getAmount === "1450Mi") } + test("auth secret propagation") { + val conf = baseConf.clone() + .set(NETWORK_AUTH_ENABLED, true) + .set("spark.master", "k8s://127.0.0.1") + + val secMgr = new SecurityManager(conf) + secMgr.initializeAuth() + + val step = new BasicExecutorFeatureStep(KubernetesTestConf.createExecutorConf(sparkConf = conf), + secMgr) + + val executor = step.configurePod(SparkPod.initialPod()) + checkEnv(executor, conf, Map(SecurityManager.ENV_AUTH_SECRET -> secMgr.getSecretKey())) + } + // There is always exactly one controller reference, and it points to the driver pod. private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { assert(executor.getMetadata.getOwnerReferences.size() === 1) @@ -150,7 +166,10 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { } // Check that the expected environment variables are present. - private def checkEnv(executorPod: SparkPod, additionalEnvVars: Map[String, String]): Unit = { + private def checkEnv( + executorPod: SparkPod, + conf: SparkConf, + additionalEnvVars: Map[String, String]): Unit = { val defaultEnvs = Map( ENV_EXECUTOR_ID -> "1", ENV_DRIVER_URL -> DRIVER_ADDRESS.toString, @@ -160,10 +179,15 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { ENV_SPARK_CONF_DIR -> SPARK_CONF_DIR_INTERNAL, ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars - assert(executorPod.container.getEnv.size() === defaultEnvs.size) + val extraJavaOptsStart = additionalEnvVars.keys.count(_.startsWith(ENV_JAVA_OPT_PREFIX)) + val extraJavaOpts = Utils.sparkJavaOpts(conf, SparkConf.isExecutorStartupConf) + val extraJavaOptsEnvs = extraJavaOpts.zipWithIndex.map { case (opt, ind) => + s"$ENV_JAVA_OPT_PREFIX${ind + extraJavaOptsStart}" -> opt + }.toMap + val mapEnvs = executorPod.container.getEnv.asScala.map { x => (x.getName, x.getValue) }.toMap - assert(defaultEnvs === mapEnvs) + assert((defaultEnvs ++ extraJavaOptsEnvs) === mapEnvs) } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala index 303e24b8f4977..d4fa31af3d5ce 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala @@ -20,13 +20,13 @@ import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder} import io.fabric8.kubernetes.client.KubernetesClient import io.fabric8.kubernetes.client.dsl.PodResource import org.mockito.{ArgumentMatcher, Matchers, Mock, MockitoAnnotations} -import org.mockito.Matchers.any +import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.{never, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, KubernetesTestConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ @@ -52,6 +52,7 @@ class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter { private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) private val podAllocationDelay = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) private val podCreationTimeout = math.max(podAllocationDelay * 5, 60000L) + private val secMgr = new SecurityManager(conf) private var waitForExecutorPodsClock: ManualClock = _ @@ -79,12 +80,12 @@ class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter { when(kubernetesClient.pods()).thenReturn(podOperations) when(podOperations.withName(driverPodName)).thenReturn(driverPodOperations) when(driverPodOperations.get).thenReturn(driverPod) - when(executorBuilder.buildFromFeatures(any(classOf[KubernetesExecutorConf]))) + when(executorBuilder.buildFromFeatures(any(classOf[KubernetesExecutorConf]), meq(secMgr))) .thenAnswer(executorPodAnswer()) snapshotsStore = new DeterministicExecutorPodsSnapshotsStore() waitForExecutorPodsClock = new ManualClock(0L) podsAllocatorUnderTest = new ExecutorPodsAllocator( - conf, executorBuilder, kubernetesClient, snapshotsStore, waitForExecutorPodsClock) + conf, secMgr, executorBuilder, kubernetesClient, snapshotsStore, waitForExecutorPodsClock) podsAllocatorUnderTest.start(TEST_SPARK_APP_ID) } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index 52e7a12dbaf06..75232f7b98b04 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -23,7 +23,7 @@ import org.mockito.Matchers.{eq => mockitoEq} import org.mockito.Mockito.{never, verify, when} import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.Fabric8Aliases._ import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} @@ -41,6 +41,9 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn @Mock private var sc: SparkContext = _ + @Mock + private var env: SparkEnv = _ + @Mock private var rpcEnv: RpcEnv = _ @@ -81,6 +84,8 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn MockitoAnnotations.initMocks(this) when(taskScheduler.sc).thenReturn(sc) when(sc.conf).thenReturn(sparkConf) + when(sc.env).thenReturn(env) + when(env.rpcEnv).thenReturn(rpcEnv) driverEndpoint = ArgumentCaptor.forClass(classOf[RpcEndpoint]) when(rpcEnv.setupEndpoint( mockitoEq(CoarseGrainedSchedulerBackend.ENDPOINT_NAME), driverEndpoint.capture())) @@ -88,7 +93,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn when(kubernetesClient.pods()).thenReturn(podOperations) schedulerBackendUnderTest = new KubernetesClusterSchedulerBackend( taskScheduler, - rpcEnv, + sc, kubernetesClient, requestExecutorsService, eventQueue, diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index b6a75b15af85a..ef521fd801e97 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -22,7 +22,7 @@ import io.fabric8.kubernetes.api.model.{Config => _, _} import io.fabric8.kubernetes.client.KubernetesClient import org.mockito.Mockito.{mock, never, verify} -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.features._ @@ -39,6 +39,8 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { private val KERBEROS_CONF_STEP_TYPE = "kerberos-step" private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes" + private val secMgr = new SecurityManager(new SparkConf(false)) + private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep]) private val mountSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( @@ -57,7 +59,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { MOUNT_VOLUMES_STEP_TYPE, classOf[MountVolumesFeatureStep]) private val builderUnderTest = new KubernetesExecutorBuilder( - _ => basicFeatureStep, + (_, _) => basicFeatureStep, _ => mountSecretsStep, _ => envSecretsStep, _ => localDirsStep, @@ -69,7 +71,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { test("Basic steps are consistently applied.") { val conf = KubernetesTestConf.createExecutorConf() validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE) + builderUnderTest.buildFromFeatures(conf, secMgr), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE) } test("Apply secrets step if secrets are present.") { @@ -77,7 +79,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { secretEnvNamesToKeyRefs = Map("secret-name" -> "secret-key"), secretNamesToMountPaths = Map("secret" -> "secretMountPath")) validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf), + builderUnderTest.buildFromFeatures(conf, secMgr), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, SECRETS_STEP_TYPE, @@ -94,7 +96,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { val conf = KubernetesTestConf.createExecutorConf( volumes = Seq(volumeSpec)) validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf), + builderUnderTest.buildFromFeatures(conf, secMgr), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, MOUNT_VOLUMES_STEP_TYPE) @@ -107,7 +109,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { .set(HADOOP_CONFIG_MAP_NAME, "hadoop-conf-map-name") .set(KRB5_CONFIG_MAP_NAME, "krb5-conf-map-name")) validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf), + builderUnderTest.buildFromFeatures(conf, secMgr), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, HADOOP_CONF_STEP_TYPE, @@ -123,7 +125,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { .set(KERBEROS_DT_SECRET_NAME, "dt-secret") .set(KERBEROS_DT_SECRET_KEY, "dt-key" )) validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf), + builderUnderTest.buildFromFeatures(conf, secMgr), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, HADOOP_CONF_STEP_TYPE, @@ -154,7 +156,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { .endMetadata() .build())) val sparkPod = KubernetesExecutorBuilder(kubernetesClient, sparkConf) - .buildFromFeatures(kubernetesConf) + .buildFromFeatures(kubernetesConf, secMgr) PodBuilderSuiteUtils.verifyPodWithSupportedFeatures(sparkPod) } } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index b746a01eb5294..f8f4b4177f3bd 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -36,6 +36,7 @@ import org.apache.spark.{SPARK_VERSION, SparkFunSuite} import org.apache.spark.deploy.k8s.integrationtest.TestConstants._ import org.apache.spark.deploy.k8s.integrationtest.backend.{IntegrationTestBackend, IntegrationTestBackendFactory} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ class KubernetesSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter with BasicTestsSuite with SecretsTestsSuite @@ -138,6 +139,7 @@ class KubernetesSuite extends SparkFunSuite .set("spark.kubernetes.driver.pod.name", driverPodName) .set("spark.kubernetes.driver.label.spark-app-locator", appLocator) .set("spark.kubernetes.executor.label.spark-app-locator", appLocator) + .set(NETWORK_AUTH_ENABLED.key, "true") if (!kubernetesTestComponents.hasUserSpecifiedNamespace) { kubernetesTestComponents.createNamespace() } From bfc5569a53510bc75c15384084ff89b418592875 Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Fri, 7 Dec 2018 09:57:35 +0800 Subject: [PATCH 2228/2461] [SPARK-26289][CORE] cleanup enablePerfMetrics parameter from BytesToBytesMap ## What changes were proposed in this pull request? `enablePerfMetrics `was originally designed in `BytesToBytesMap `to control `getNumHashCollisions getTimeSpentResizingNs getAverageProbesPerLookup`. However, as the Spark version gradual progress. this parameter is only used for `getAverageProbesPerLookup ` and always given to true when using `BytesToBytesMap`. it is also dangerous to determine whether `getAverageProbesPerLookup `opens and throws an `IllegalStateException `exception. So this pr will be remove `enablePerfMetrics `parameter from `BytesToBytesMap`. thanks. ## How was this patch tested? the existed test cases. Closes #23244 from heary-cao/enablePerfMetrics. Authored-by: caoxuewen Signed-off-by: Wenchen Fan --- .../spark/unsafe/map/BytesToBytesMap.java | 33 ++++--------------- .../map/AbstractBytesToBytesMapSuite.java | 4 +-- .../UnsafeFixedWidthAggregationMap.java | 2 +- .../sql/execution/joins/HashedRelation.scala | 6 ++-- 4 files changed, 12 insertions(+), 33 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index a4e88598f7607..405e529464152 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -159,11 +159,9 @@ public final class BytesToBytesMap extends MemoryConsumer { */ private final Location loc; - private final boolean enablePerfMetrics; + private long numProbes = 0L; - private long numProbes = 0; - - private long numKeyLookups = 0; + private long numKeyLookups = 0L; private long peakMemoryUsedBytes = 0L; @@ -180,8 +178,7 @@ public BytesToBytesMap( SerializerManager serializerManager, int initialCapacity, double loadFactor, - long pageSizeBytes, - boolean enablePerfMetrics) { + long pageSizeBytes) { super(taskMemoryManager, pageSizeBytes, taskMemoryManager.getTungstenMemoryMode()); this.taskMemoryManager = taskMemoryManager; this.blockManager = blockManager; @@ -189,7 +186,6 @@ public BytesToBytesMap( this.loadFactor = loadFactor; this.loc = new Location(); this.pageSizeBytes = pageSizeBytes; - this.enablePerfMetrics = enablePerfMetrics; if (initialCapacity <= 0) { throw new IllegalArgumentException("Initial capacity must be greater than 0"); } @@ -209,14 +205,6 @@ public BytesToBytesMap( TaskMemoryManager taskMemoryManager, int initialCapacity, long pageSizeBytes) { - this(taskMemoryManager, initialCapacity, pageSizeBytes, false); - } - - public BytesToBytesMap( - TaskMemoryManager taskMemoryManager, - int initialCapacity, - long pageSizeBytes, - boolean enablePerfMetrics) { this( taskMemoryManager, SparkEnv.get() != null ? SparkEnv.get().blockManager() : null, @@ -224,8 +212,7 @@ public BytesToBytesMap( initialCapacity, // In order to re-use the longArray for sorting, the load factor cannot be larger than 0.5. 0.5, - pageSizeBytes, - enablePerfMetrics); + pageSizeBytes); } /** @@ -462,15 +449,12 @@ public Location lookup(Object keyBase, long keyOffset, int keyLength, int hash) public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc, int hash) { assert(longArray != null); - if (enablePerfMetrics) { - numKeyLookups++; - } + numKeyLookups++; + int pos = hash & mask; int step = 1; while (true) { - if (enablePerfMetrics) { - numProbes++; - } + numProbes++; if (longArray.get(pos * 2) == 0) { // This is a new key. loc.with(pos, hash, false); @@ -860,9 +844,6 @@ public long getPeakMemoryUsedBytes() { * Returns the average number of probes per key lookup. */ public double getAverageProbesPerLookup() { - if (!enablePerfMetrics) { - throw new IllegalStateException(); - } return (1.0 * numProbes) / numKeyLookups; } diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 53a233f698c7a..aa29232e73e13 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -530,7 +530,7 @@ public void failureToGrow() { @Test public void spillInIterator() throws IOException { BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, blockManager, serializerManager, 1, 0.75, 1024, false); + taskMemoryManager, blockManager, serializerManager, 1, 0.75, 1024); try { int i; for (i = 0; i < 1024; i++) { @@ -569,7 +569,7 @@ public void spillInIterator() throws IOException { @Test public void multipleValuesForSameKey() { BytesToBytesMap map = - new BytesToBytesMap(taskMemoryManager, blockManager, serializerManager, 1, 0.5, 1024, false); + new BytesToBytesMap(taskMemoryManager, blockManager, serializerManager, 1, 0.5, 1024); try { int i; for (i = 0; i < 1024; i++) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index c8cf44b51df77..7e76a651ba2cb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -98,7 +98,7 @@ public UnsafeFixedWidthAggregationMap( this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); this.groupingKeySchema = groupingKeySchema; this.map = new BytesToBytesMap( - taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes, true); + taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes); // Initialize the buffer for aggregation value final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index e8c01d46a84c0..b1ff6e83acc24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -248,8 +248,7 @@ private[joins] class UnsafeHashedRelation( binaryMap = new BytesToBytesMap( taskMemoryManager, (nKeys * 1.5 + 1).toInt, // reduce hash collision - pageSizeBytes, - true) + pageSizeBytes) var i = 0 var keyBuffer = new Array[Byte](1024) @@ -299,8 +298,7 @@ private[joins] object UnsafeHashedRelation { taskMemoryManager, // Only 70% of the slots can be used before growing, more capacity help to reduce collision (sizeEstimate * 1.5 + 1).toInt, - pageSizeBytes, - true) + pageSizeBytes) // Create a mapping of buildKeys -> rows val keyGenerator = UnsafeProjection.create(key) From 5a140b7844936cf2b65f08853b8cfd8c499d4f13 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 7 Dec 2018 11:13:14 +0800 Subject: [PATCH 2229/2461] [SPARK-26263][SQL] Validate partition values with user provided schema MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Currently if user provides data schema, partition column values are converted as per it. But if the conversion failed, e.g. converting string to int, the column value is null. This PR proposes to throw exception in such case, instead of converting into null value silently: 1. These null partition column values doesn't make sense to users in most cases. It is better to show the conversion failure, and then users can adjust the schema or ETL jobs to fix it. 2. There are always exceptions on such conversion failure for non-partition data columns. Partition columns should have the same behavior. We can reproduce the case above as following: ``` /tmp/testDir ├── p=bar └── p=foo ``` If we run: ``` val schema = StructType(Seq(StructField("p", IntegerType, false))) spark.read.schema(schema).csv("/tmp/testDir/").show() ``` We will get: ``` +----+ | p| +----+ |null| |null| +----+ ``` ## How was this patch tested? Unit test Closes #23215 from gengliangwang/SPARK-26263. Authored-by: Gengliang Wang Signed-off-by: Wenchen Fan --- docs/sql-migration-guide-upgrade.md | 2 ++ .../apache/spark/sql/internal/SQLConf.scala | 12 ++++++++ .../PartitioningAwareFileIndex.scala | 4 +-- .../datasources/PartitioningUtils.scala | 30 +++++++++++++------ .../datasources/FileIndexSuite.scala | 27 ++++++++++++++++- .../ParquetPartitionDiscoverySuite.scala | 18 ++++++++--- 6 files changed, 77 insertions(+), 16 deletions(-) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index ed2ff139bcc33..3638b0873aa4d 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -29,6 +29,8 @@ displayTitle: Spark SQL Upgrading Guide - In Spark version 2.4 and earlier, users can create a map with duplicated keys via built-in functions like `CreateMap`, `StringToMap`, etc. The behavior of map with duplicated keys is undefined, e.g. map look up respects the duplicated key appears first, `Dataset.collect` only keeps the duplicated key appears last, `MapKeys` returns duplicated keys, etc. Since Spark 3.0, these built-in functions will remove duplicated map keys with last wins policy. Users may still read map values with duplicated keys from data sources which do not enforce it (e.g. Parquet), the behavior will be udefined. + - In Spark version 2.4 and earlier, partition column value is converted as null if it can't be casted to corresponding user provided schema. Since 3.0, partition column value is validated with user provided schema. An exception is thrown if the validation fails. You can disable such validation by setting `spark.sql.sources.validatePartitionColumns` to `false`. + - In Spark version 2.4 and earlier, the `SET` command works without any warnings even if the specified key is for `SparkConf` entries and it has no effect because the command does not update `SparkConf`, but the behavior might confuse users. Since 3.0, the command fails if a `SparkConf` key is used. You can disable such a check by setting `spark.sql.legacy.execution.setCommandRejectsSparkConfs` to `false`. - Spark applications which are built with Spark version 2.4 and prior, and call methods of `UserDefinedFunction`, need to be re-compiled with Spark 3.0, as they are not binary compatible with Spark 3.0. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 451b051f8407e..6857b8de79758 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1396,6 +1396,16 @@ object SQLConf { .booleanConf .createWithDefault(false) + val VALIDATE_PARTITION_COLUMNS = + buildConf("spark.sql.sources.validatePartitionColumns") + .internal() + .doc("When this option is set to true, partition column values will be validated with " + + "user-specified schema. If the validation fails, a runtime exception is thrown." + + "When this option is set to false, the partition column value will be converted to null " + + "if it can not be casted to corresponding user-specified schema.") + .booleanConf + .createWithDefault(true) + val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE = buildConf("spark.sql.streaming.continuous.executorQueueSize") .internal() @@ -2014,6 +2024,8 @@ class SQLConf extends Serializable with Logging { def allowCreatingManagedTableUsingNonemptyLocation: Boolean = getConf(ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION) + def validatePartitionColumns: Boolean = getConf(VALIDATE_PARTITION_COLUMNS) + def partitionOverwriteMode: PartitionOverwriteMode.Value = PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index 7b0e4dbcc25f4..b2e4155e6f49e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -127,13 +127,13 @@ abstract class PartitioningAwareFileIndex( val timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone) - val caseSensitive = sparkSession.sqlContext.conf.caseSensitiveAnalysis PartitioningUtils.parsePartitions( leafDirs, typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, basePaths = basePaths, userSpecifiedSchema = userSpecifiedSchema, - caseSensitive = caseSensitive, + caseSensitive = sparkSession.sqlContext.conf.caseSensitiveAnalysis, + validatePartitionColumns = sparkSession.sqlContext.conf.validatePartitionColumns, timeZoneId = timeZoneId) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index d66cb09bda0cc..6458b65466fb5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -26,12 +26,13 @@ import scala.util.Try import org.apache.hadoop.fs.Path -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal} import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils @@ -96,9 +97,10 @@ object PartitioningUtils { basePaths: Set[Path], userSpecifiedSchema: Option[StructType], caseSensitive: Boolean, + validatePartitionColumns: Boolean, timeZoneId: String): PartitionSpec = { - parsePartitions(paths, typeInference, basePaths, userSpecifiedSchema, - caseSensitive, DateTimeUtils.getTimeZone(timeZoneId)) + parsePartitions(paths, typeInference, basePaths, userSpecifiedSchema, caseSensitive, + validatePartitionColumns, DateTimeUtils.getTimeZone(timeZoneId)) } private[datasources] def parsePartitions( @@ -107,6 +109,7 @@ object PartitioningUtils { basePaths: Set[Path], userSpecifiedSchema: Option[StructType], caseSensitive: Boolean, + validatePartitionColumns: Boolean, timeZone: TimeZone): PartitionSpec = { val userSpecifiedDataTypes = if (userSpecifiedSchema.isDefined) { val nameToDataType = userSpecifiedSchema.get.fields.map(f => f.name -> f.dataType).toMap @@ -121,7 +124,8 @@ object PartitioningUtils { // First, we need to parse every partition's path and see if we can find partition values. val (partitionValues, optDiscoveredBasePaths) = paths.map { path => - parsePartition(path, typeInference, basePaths, userSpecifiedDataTypes, timeZone) + parsePartition(path, typeInference, basePaths, userSpecifiedDataTypes, + validatePartitionColumns, timeZone) }.unzip // We create pairs of (path -> path's partition value) here @@ -203,6 +207,7 @@ object PartitioningUtils { typeInference: Boolean, basePaths: Set[Path], userSpecifiedDataTypes: Map[String, DataType], + validatePartitionColumns: Boolean, timeZone: TimeZone): (Option[PartitionValues], Option[Path]) = { val columns = ArrayBuffer.empty[(String, Literal)] // Old Hadoop versions don't have `Path.isRoot` @@ -224,7 +229,8 @@ object PartitioningUtils { // Let's say currentPath is a path of "/table/a=1/", currentPath.getName will give us a=1. // Once we get the string, we try to parse it and find the partition column and value. val maybeColumn = - parsePartitionColumn(currentPath.getName, typeInference, userSpecifiedDataTypes, timeZone) + parsePartitionColumn(currentPath.getName, typeInference, userSpecifiedDataTypes, + validatePartitionColumns, timeZone) maybeColumn.foreach(columns += _) // Now, we determine if we should stop. @@ -258,6 +264,7 @@ object PartitioningUtils { columnSpec: String, typeInference: Boolean, userSpecifiedDataTypes: Map[String, DataType], + validatePartitionColumns: Boolean, timeZone: TimeZone): Option[(String, Literal)] = { val equalSignIndex = columnSpec.indexOf('=') if (equalSignIndex == -1) { @@ -272,10 +279,15 @@ object PartitioningUtils { val literal = if (userSpecifiedDataTypes.contains(columnName)) { // SPARK-26188: if user provides corresponding column schema, get the column value without // inference, and then cast it as user specified data type. - val columnValue = inferPartitionColumnValue(rawColumnValue, false, timeZone) - val castedValue = - Cast(columnValue, userSpecifiedDataTypes(columnName), Option(timeZone.getID)).eval() - Literal.create(castedValue, userSpecifiedDataTypes(columnName)) + val dataType = userSpecifiedDataTypes(columnName) + val columnValueLiteral = inferPartitionColumnValue(rawColumnValue, false, timeZone) + val columnValue = columnValueLiteral.eval() + val castedValue = Cast(columnValueLiteral, dataType, Option(timeZone.getID)).eval() + if (validatePartitionColumns && columnValue != null && castedValue == null) { + throw new RuntimeException(s"Failed to cast value `$columnValue` to `$dataType` " + + s"for partition column `$columnName`") + } + Literal.create(castedValue, dataType) } else { inferPartitionColumnValue(rawColumnValue, typeInference, timeZone) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index ec552f7ddf47a..6bd0a2591fc1f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{StringType, StructField, StructType} +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator} class FileIndexSuite extends SharedSQLContext { @@ -95,6 +95,31 @@ class FileIndexSuite extends SharedSQLContext { } } + test("SPARK-26263: Throw exception when partition value can't be casted to user-specified type") { + withTempDir { dir => + val partitionDirectory = new File(dir, "a=foo") + partitionDirectory.mkdir() + val file = new File(partitionDirectory, "text.txt") + stringToFile(file, "text") + val path = new Path(dir.getCanonicalPath) + val schema = StructType(Seq(StructField("a", IntegerType, false))) + withSQLConf(SQLConf.VALIDATE_PARTITION_COLUMNS.key -> "true") { + val fileIndex = new InMemoryFileIndex(spark, Seq(path), Map.empty, Some(schema)) + val msg = intercept[RuntimeException] { + fileIndex.partitionSpec() + }.getMessage + assert(msg == "Failed to cast value `foo` to `IntegerType` for partition column `a`") + } + + withSQLConf(SQLConf.VALIDATE_PARTITION_COLUMNS.key -> "false") { + val fileIndex = new InMemoryFileIndex(spark, Seq(path), Map.empty, Some(schema)) + val partitionValues = fileIndex.partitionSpec().partitions.map(_.values) + assert(partitionValues.length == 1 && partitionValues(0).numFields == 1 && + partitionValues(0).isNullAt(0)) + } + } + } + test("InMemoryFileIndex: input paths are converted to qualified paths") { withTempDir { dir => val file = new File(dir, "text.txt") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index f808ca458aaa7..88067358667c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -101,7 +101,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha "hdfs://host:9000/path/a=10.5/b=hello") var exception = intercept[AssertionError] { - parsePartitions(paths.map(new Path(_)), true, Set.empty[Path], None, true, timeZoneId) + parsePartitions(paths.map(new Path(_)), true, Set.empty[Path], None, true, true, timeZoneId) } assert(exception.getMessage().contains("Conflicting directory structures detected")) @@ -117,6 +117,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha Set(new Path("hdfs://host:9000/path/")), None, true, + true, timeZoneId) // Valid @@ -132,6 +133,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha Set(new Path("hdfs://host:9000/path/something=true/table")), None, true, + true, timeZoneId) // Valid @@ -147,6 +149,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha Set(new Path("hdfs://host:9000/path/table=true")), None, true, + true, timeZoneId) // Invalid @@ -162,6 +165,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha Set(new Path("hdfs://host:9000/path/")), None, true, + true, timeZoneId) } assert(exception.getMessage().contains("Conflicting directory structures detected")) @@ -184,6 +188,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha Set(new Path("hdfs://host:9000/tmp/tables/")), None, true, + true, timeZoneId) } assert(exception.getMessage().contains("Conflicting directory structures detected")) @@ -191,13 +196,14 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha test("parse partition") { def check(path: String, expected: Option[PartitionValues]): Unit = { - val actual = parsePartition(new Path(path), true, Set.empty[Path], Map.empty, timeZone)._1 + val actual = parsePartition(new Path(path), true, Set.empty[Path], + Map.empty, true, timeZone)._1 assert(expected === actual) } def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = { val message = intercept[T] { - parsePartition(new Path(path), true, Set.empty[Path], Map.empty, timeZone) + parsePartition(new Path(path), true, Set.empty[Path], Map.empty, true, timeZone) }.getMessage assert(message.contains(expected)) @@ -242,6 +248,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha typeInference = true, basePaths = Set(new Path("file://path/a=10")), Map.empty, + true, timeZone = timeZone)._1 assert(partitionSpec1.isEmpty) @@ -252,6 +259,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha typeInference = true, basePaths = Set(new Path("file://path")), Map.empty, + true, timeZone = timeZone)._1 assert(partitionSpec2 == @@ -272,6 +280,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha rootPaths, None, true, + true, timeZoneId) assert(actualSpec.partitionColumns === spec.partitionColumns) assert(actualSpec.partitions.length === spec.partitions.length) @@ -384,7 +393,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha test("parse partitions with type inference disabled") { def check(paths: Seq[String], spec: PartitionSpec): Unit = { val actualSpec = - parsePartitions(paths.map(new Path(_)), false, Set.empty[Path], None, true, timeZoneId) + parsePartitions(paths.map(new Path(_)), false, Set.empty[Path], None, + true, true, timeZoneId) assert(actualSpec === spec) } From 477226520358f0cc47d5ea255ad84d3c13f6d77d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 6 Dec 2018 20:50:57 -0800 Subject: [PATCH 2230/2461] [SPARK-26298][BUILD] Upgrade Janino to 3.0.11 ## What changes were proposed in this pull request? This PR aims to upgrade Janino compiler to the latest version 3.0.11. The followings are the changes from the [release note](http://janino-compiler.github.io/janino/changelog.html). - Script with many "helper" variables. - Java 9+ compatibility - Compilation Error Messages Generated by JDK. - Added experimental support for the "StackMapFrame" attribute; not active yet. - Make Unparser more flexible. - Fixed NPEs in various "toString()" methods. - Optimize static method invocation with rvalue target expression. - Added all missing "ClassFile.getConstant*Info()" methods, removing the necessity for many type casts. ## How was this patch tested? Pass the Jenkins with the existing tests. Closes #23250 from dongjoon-hyun/SPARK-26298. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- dev/deps/spark-deps-hadoop-2.7 | 4 ++-- dev/deps/spark-deps-hadoop-3.1 | 4 ++-- pom.xml | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index ec7c304c9e36b..d250d5205586e 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -34,7 +34,7 @@ commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar -commons-compiler-3.0.10.jar +commons-compiler-3.0.11.jar commons-compress-1.8.1.jar commons-configuration-1.6.jar commons-crypto-1.0.0.jar @@ -98,7 +98,7 @@ jackson-module-jaxb-annotations-2.9.6.jar jackson-module-paranamer-2.9.6.jar jackson-module-scala_2.12-2.9.6.jar jackson-xc-1.9.13.jar -janino-3.0.10.jar +janino-3.0.11.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar javax.inject-1.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 811febf22940d..347503ace557a 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -31,7 +31,7 @@ commons-beanutils-1.9.3.jar commons-cli-1.2.jar commons-codec-1.10.jar commons-collections-3.2.2.jar -commons-compiler-3.0.10.jar +commons-compiler-3.0.11.jar commons-compress-1.8.1.jar commons-configuration2-2.1.1.jar commons-crypto-1.0.0.jar @@ -97,7 +97,7 @@ jackson-mapper-asl-1.9.13.jar jackson-module-jaxb-annotations-2.9.6.jar jackson-module-paranamer-2.9.6.jar jackson-module-scala_2.12-2.9.6.jar -janino-3.0.10.jar +janino-3.0.11.jar javassist-3.18.1-GA.jar javax.annotation-api-1.2.jar javax.inject-1.jar diff --git a/pom.xml b/pom.xml index 61321a1450708..15e5c18290405 100644 --- a/pom.xml +++ b/pom.xml @@ -173,7 +173,7 @@ 3.8.1 3.2.10 - 3.0.10 + 3.0.11 2.22.2 2.9.3 3.5.2 From 1ab3d3e474ce2e36d58aea8ad09fb61f0c73e5c5 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 7 Dec 2018 07:55:54 -0800 Subject: [PATCH 2231/2461] [SPARK-26060][SQL][FOLLOW-UP] Rename the config name. ## What changes were proposed in this pull request? This is a follow-up of #23031 to rename the config name to `spark.sql.legacy.setCommandRejectsSparkCoreConfs`. ## How was this patch tested? Existing tests. Closes #23245 from ueshin/issues/SPARK-26060/rename_config. Authored-by: Takuya UESHIN Signed-off-by: Dongjoon Hyun --- docs/sql-migration-guide-upgrade.md | 2 +- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 7 ++++--- .../main/scala/org/apache/spark/sql/RuntimeConfig.scala | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 3638b0873aa4d..67c30fb941ecd 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -31,7 +31,7 @@ displayTitle: Spark SQL Upgrading Guide - In Spark version 2.4 and earlier, partition column value is converted as null if it can't be casted to corresponding user provided schema. Since 3.0, partition column value is validated with user provided schema. An exception is thrown if the validation fails. You can disable such validation by setting `spark.sql.sources.validatePartitionColumns` to `false`. - - In Spark version 2.4 and earlier, the `SET` command works without any warnings even if the specified key is for `SparkConf` entries and it has no effect because the command does not update `SparkConf`, but the behavior might confuse users. Since 3.0, the command fails if a `SparkConf` key is used. You can disable such a check by setting `spark.sql.legacy.execution.setCommandRejectsSparkConfs` to `false`. + - In Spark version 2.4 and earlier, the `SET` command works without any warnings even if the specified key is for `SparkConf` entries and it has no effect because the command does not update `SparkConf`, but the behavior might confuse users. Since 3.0, the command fails if a `SparkConf` key is used. You can disable such a check by setting `spark.sql.legacy.setCommandRejectsSparkCoreConfs` to `false`. - Spark applications which are built with Spark version 2.4 and prior, and call methods of `UserDefinedFunction`, need to be re-compiled with Spark 3.0, as they are not binary compatible with Spark 3.0. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6857b8de79758..86e068bf632bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1621,8 +1621,8 @@ object SQLConf { .intConf .createWithDefault(25) - val SET_COMMAND_REJECTS_SPARK_CONFS = - buildConf("spark.sql.legacy.execution.setCommandRejectsSparkConfs") + val SET_COMMAND_REJECTS_SPARK_CORE_CONFS = + buildConf("spark.sql.legacy.setCommandRejectsSparkCoreConfs") .internal() .doc("If it is set to true, SET command will fail when the key is registered as " + "a SparkConf entry.") @@ -2057,7 +2057,8 @@ class SQLConf extends Serializable with Logging { def maxToStringFields: Int = getConf(SQLConf.MAX_TO_STRING_FIELDS) - def setCommandRejectsSparkConfs: Boolean = getConf(SQLConf.SET_COMMAND_REJECTS_SPARK_CONFS) + def setCommandRejectsSparkCoreConfs: Boolean = + getConf(SQLConf.SET_COMMAND_REJECTS_SPARK_CORE_CONFS) def legacyTimeParserEnabled: Boolean = getConf(SQLConf.LEGACY_TIME_PARSER_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala index d83a01ff9ea65..0f5aab7f47d0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala @@ -153,7 +153,7 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { if (SQLConf.staticConfKeys.contains(key)) { throw new AnalysisException(s"Cannot modify the value of a static config: $key") } - if (sqlConf.setCommandRejectsSparkConfs && + if (sqlConf.setCommandRejectsSparkCoreConfs && ConfigEntry.findEntry(key) != null && !SQLConf.sqlConfEntries.containsKey(key)) { throw new AnalysisException(s"Cannot modify the value of a Spark config: $key") } From 543577a1e8c0904048b73008fa7c4cee33f69894 Mon Sep 17 00:00:00 2001 From: Sahil Takiar Date: Fri, 7 Dec 2018 10:33:42 -0800 Subject: [PATCH 2232/2461] [SPARK-24243][CORE] Expose exceptions from InProcessAppHandle Adds a new method to SparkAppHandle called getError which returns the exception (if present) that caused the underlying Spark app to fail. New tests added to SparkLauncherSuite for the new method. Closes #21849 Closes #23221 from vanzin/SPARK-24243. Signed-off-by: Marcelo Vanzin --- .../spark/launcher/SparkLauncherSuite.java | 102 ++++++++++++++++-- .../spark/launcher/ChildProcAppHandle.java | 20 +++- .../spark/launcher/InProcessAppHandle.java | 13 +++ .../spark/launcher/OutputRedirector.java | 25 +++++ .../apache/spark/launcher/SparkAppHandle.java | 8 ++ project/MimaExcludes.scala | 3 + 6 files changed, 159 insertions(+), 12 deletions(-) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 6a1a38c1a54f4..773c390175b6d 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -41,6 +41,8 @@ public class SparkLauncherSuite extends BaseSuite { private static final NamedThreadFactory TF = new NamedThreadFactory("SparkLauncherSuite-%d"); + private static final String EXCEPTION_MESSAGE = "dummy-exception"; + private static final RuntimeException DUMMY_EXCEPTION = new RuntimeException(EXCEPTION_MESSAGE); private final SparkLauncher launcher = new SparkLauncher(); @@ -130,17 +132,8 @@ public void testInProcessLauncher() throws Exception { try { inProcessLauncherTestImpl(); } finally { - Properties p = new Properties(); - for (Map.Entry e : properties.entrySet()) { - p.put(e.getKey(), e.getValue()); - } - System.setProperties(p); - // Here DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet. - // Wait for a reasonable amount of time to avoid creating two active SparkContext in JVM. - // See SPARK-23019 and SparkContext.stop() for details. - eventually(Duration.ofSeconds(5), Duration.ofMillis(10), () -> { - assertTrue("SparkContext is still alive.", SparkContext$.MODULE$.getActive().isEmpty()); - }); + restoreSystemProperties(properties); + waitForSparkContextShutdown(); } } @@ -227,6 +220,82 @@ public void testInProcessLauncherDoesNotKillJvm() throws Exception { assertEquals(SparkAppHandle.State.LOST, handle.getState()); } + @Test + public void testInProcessLauncherGetError() throws Exception { + // Because this test runs SparkLauncher in process and in client mode, it pollutes the system + // properties, and that can cause test failures down the test pipeline. So restore the original + // system properties after this test runs. + Map properties = new HashMap<>(System.getProperties()); + + SparkAppHandle handle = null; + try { + handle = new InProcessLauncher() + .setMaster("local") + .setAppResource(SparkLauncher.NO_RESOURCE) + .setMainClass(ErrorInProcessTestApp.class.getName()) + .addAppArgs("hello") + .startApplication(); + + final SparkAppHandle _handle = handle; + eventually(Duration.ofSeconds(60), Duration.ofMillis(1000), () -> { + assertEquals(SparkAppHandle.State.FAILED, _handle.getState()); + }); + + assertNotNull(handle.getError()); + assertTrue(handle.getError().isPresent()); + assertSame(handle.getError().get(), DUMMY_EXCEPTION); + } finally { + if (handle != null) { + handle.kill(); + } + restoreSystemProperties(properties); + waitForSparkContextShutdown(); + } + } + + @Test + public void testSparkLauncherGetError() throws Exception { + SparkAppHandle handle = null; + try { + handle = new SparkLauncher() + .setMaster("local") + .setAppResource(SparkLauncher.NO_RESOURCE) + .setMainClass(ErrorInProcessTestApp.class.getName()) + .addAppArgs("hello") + .startApplication(); + + final SparkAppHandle _handle = handle; + eventually(Duration.ofSeconds(60), Duration.ofMillis(1000), () -> { + assertEquals(SparkAppHandle.State.FAILED, _handle.getState()); + }); + + assertNotNull(handle.getError()); + assertTrue(handle.getError().isPresent()); + assertTrue(handle.getError().get().getMessage().contains(EXCEPTION_MESSAGE)); + } finally { + if (handle != null) { + handle.kill(); + } + } + } + + private void restoreSystemProperties(Map properties) { + Properties p = new Properties(); + for (Map.Entry e : properties.entrySet()) { + p.put(e.getKey(), e.getValue()); + } + System.setProperties(p); + } + + private void waitForSparkContextShutdown() throws Exception { + // Here DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet. + // Wait for a reasonable amount of time to avoid creating two active SparkContext in JVM. + // See SPARK-23019 and SparkContext.stop() for details. + eventually(Duration.ofSeconds(5), Duration.ofMillis(10), () -> { + assertTrue("SparkContext is still alive.", SparkContext$.MODULE$.getActive().isEmpty()); + }); + } + public static class SparkLauncherTestApp { public static void main(String[] args) throws Exception { @@ -264,4 +333,15 @@ public static void main(String[] args) throws Exception { } + /** + * Similar to {@link InProcessTestApp} except it throws an exception + */ + public static class ErrorInProcessTestApp { + + public static void main(String[] args) { + assertNotEquals(0, args.length); + assertEquals(args[0], "hello"); + throw DUMMY_EXCEPTION; + } + } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index 5609f8492f4f4..7dfcf0e66734a 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -18,6 +18,7 @@ package org.apache.spark.launcher; import java.io.InputStream; +import java.util.Optional; import java.util.logging.Level; import java.util.logging.Logger; @@ -29,7 +30,7 @@ class ChildProcAppHandle extends AbstractAppHandle { private static final Logger LOG = Logger.getLogger(ChildProcAppHandle.class.getName()); private volatile Process childProc; - private OutputRedirector redirector; + private volatile OutputRedirector redirector; ChildProcAppHandle(LauncherServer server) { super(server); @@ -46,6 +47,23 @@ public synchronized void disconnect() { } } + /** + * Parses the logs of {@code spark-submit} and returns the last exception thrown. + *

      + * Since {@link SparkLauncher} runs {@code spark-submit} in a sub-process, it's difficult to + * accurately retrieve the full {@link Throwable} from the {@code spark-submit} process. + * This method parses the logs of the sub-process and provides a best-effort attempt at + * returning the last exception thrown by the {@code spark-submit} process. Only the exception + * message is parsed, the associated stacktrace is meaningless. + * + * @return an {@link Optional} containing a {@link RuntimeException} with the parsed + * exception, otherwise returns a {@link Optional#EMPTY} + */ + @Override + public Optional getError() { + return redirector != null ? Optional.ofNullable(redirector.getError()) : Optional.empty(); + } + @Override public synchronized void kill() { if (!isDisposed()) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index 15fbca0facef2..ba09050c756d2 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -17,7 +17,9 @@ package org.apache.spark.launcher; +import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.util.Optional; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; import java.util.logging.Logger; @@ -31,6 +33,8 @@ class InProcessAppHandle extends AbstractAppHandle { // Avoid really long thread names. private static final int MAX_APP_NAME_LEN = 16; + private volatile Throwable error; + private Thread app; InProcessAppHandle(LauncherServer server) { @@ -51,6 +55,11 @@ public synchronized void kill() { } } + @Override + public Optional getError() { + return Optional.ofNullable(error); + } + synchronized void start(String appName, Method main, String[] args) { CommandBuilderUtils.checkState(app == null, "Handle already started."); @@ -62,7 +71,11 @@ synchronized void start(String appName, Method main, String[] args) { try { main.invoke(null, (Object) args); } catch (Throwable t) { + if (t instanceof InvocationTargetException) { + t = t.getCause(); + } LOG.log(Level.WARNING, "Application failed with exception.", t); + error = t; setState(State.FAILED); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java b/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java index 6f4b0bb38e031..0f097f8313925 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java +++ b/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java @@ -37,6 +37,7 @@ class OutputRedirector { private final ChildProcAppHandle callback; private volatile boolean active; + private volatile Throwable error; OutputRedirector(InputStream in, String loggerName, ThreadFactory tf) { this(in, loggerName, tf, null); @@ -61,6 +62,10 @@ private void redirect() { while ((line = reader.readLine()) != null) { if (active) { sink.info(line.replaceFirst("\\s*$", "")); + if ((containsIgnoreCase(line, "Error") || containsIgnoreCase(line, "Exception")) && + !line.contains("at ")) { + error = new RuntimeException(line); + } } } } catch (IOException e) { @@ -85,4 +90,24 @@ boolean isAlive() { return thread.isAlive(); } + Throwable getError() { + return error; + } + + /** + * Copied from Apache Commons Lang {@code StringUtils#containsIgnoreCase(String, String)} + */ + private static boolean containsIgnoreCase(String str, String searchStr) { + if (str == null || searchStr == null) { + return false; + } + int len = searchStr.length(); + int max = str.length() - len; + for (int i = 0; i <= max; i++) { + if (str.regionMatches(true, i, searchStr, 0, len)) { + return true; + } + } + return false; + } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java index cefb4d1a95fb6..afec270e2b11c 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java @@ -17,6 +17,8 @@ package org.apache.spark.launcher; +import java.util.Optional; + /** * A handle to a running Spark application. *

      @@ -100,6 +102,12 @@ public boolean isFinal() { */ void disconnect(); + /** + * If the application failed due to an error, return the underlying error. If the app + * succeeded, this method returns an empty {@link Optional}. + */ + Optional getError(); + /** * Listener for updates to a handle's state. The callbacks do not receive information about * what exactly has changed, just that an update has occurred. diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 1c83cf5860c58..4eeebb805070a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,9 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( + // [SPARK-24243][CORE] Expose exceptions from InProcessAppHandle + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.launcher.SparkAppHandle.getError"), + // [SPARK-25867] Remove KMeans computeCost ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.KMeansModel.computeCost"), From 9b7679a97ed2622081390301af8735e62055492d Mon Sep 17 00:00:00 2001 From: 10087686 Date: Fri, 7 Dec 2018 14:11:25 -0600 Subject: [PATCH 2233/2461] [SPARK-26294][CORE] Delete Unnecessary If statement ## What changes were proposed in this pull request? Delete unnecessary If statement, because it Impossible execution when records less than or equal to zero.it is only execution when records begin zero. ................... if (inMemSorter == null || inMemSorter.numRecords() <= 0) { return 0L; } .................... if (inMemSorter.numRecords() > 0) { ..................... } ## How was this patch tested? Existing tests (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #23247 from wangjiaochun/inMemSorter. Authored-by: 10087686 Signed-off-by: Sean Owen --- .../unsafe/sort/UnsafeExternalSorter.java | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 5056652a2420b..af5a934b7da62 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -213,14 +213,12 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { spillWriters.size() > 1 ? " times" : " time"); ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); - // We only write out contents of the inMemSorter if it is not empty. - if (inMemSorter.numRecords() > 0) { - final UnsafeSorterSpillWriter spillWriter = - new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, - inMemSorter.numRecords()); - spillWriters.add(spillWriter); - spillIterator(inMemSorter.getSortedIterator(), spillWriter); - } + + final UnsafeSorterSpillWriter spillWriter = + new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, + inMemSorter.numRecords()); + spillWriters.add(spillWriter); + spillIterator(inMemSorter.getSortedIterator(), spillWriter); final long spillSize = freeMemory(); // Note that this is more-or-less going to be a multiple of the page size, so wasted space in From bd00f10773f3a08e50ecd06c7e70d9b38094a252 Mon Sep 17 00:00:00 2001 From: dima-asana <42555784+dima-asana@users.noreply.github.com> Date: Fri, 7 Dec 2018 14:14:43 -0600 Subject: [PATCH 2234/2461] [MINOR][SQL][DOC] Correct parquet nullability documentation ## What changes were proposed in this pull request? Parquet files appear to have nullability info when being written, not being read. ## How was this patch tested? Some test code: (running spark 2.3, but the relevant code in DataSource looks identical on master) case class NullTest(bo: Boolean, opbol: Option[Boolean]) val testDf = spark.createDataFrame(Seq(NullTest(true, Some(false)))) defined class NullTest testDf: org.apache.spark.sql.DataFrame = [bo: boolean, opbol: boolean] testDf.write.parquet("s3://asana-stats/tmp_dima/parquet_check_schema") spark.read.parquet("s3://asana-stats/tmp_dima/parquet_check_schema/part-00000-b1bf4a19-d9fe-4ece-a2b4-9bbceb490857-c000.snappy.parquet4").printSchema() root |-- bo: boolean (nullable = true) |-- opbol: boolean (nullable = true) Meanwhile, the parquet file formed does have nullable info: []batchprod-report000:/tmp/dimakamalov-batch$ aws s3 ls s3://asana-stats/tmp_dima/parquet_check_schema/ 2018-10-17 21:03:52 0 _SUCCESS 2018-10-17 21:03:50 504 part-00000-b1bf4a19-d9fe-4ece-a2b4-9bbceb490857-c000.snappy.parquet []batchprod-report000:/tmp/dimakamalov-batch$ aws s3 cp s3://asana-stats/tmp_dima/parquet_check_schema/part-00000-b1bf4a19-d9fe-4ece-a2b4-9bbceb490857-c000.snappy.parquet . download: s3://asana-stats/tmp_dima/parquet_check_schema/part-00000-b1bf4a19-d9fe-4ece-a2b4-9bbceb490857-c000.snappy.parquet to ./part-00000-b1bf4a19-d9fe-4ece-a2b4-9bbceb490857-c000.snappy.parquet []batchprod-report000:/tmp/dimakamalov-batch$ java -jar parquet-tools-1.8.2.jar schema part-00000-b1bf4a19-d9fe-4ece-a2b4-9bbceb490857-c000.snappy.parquet message spark_schema { required boolean bo; optional boolean opbol; } Closes #22759 from dima-asana/dima-asana-nullable-parquet-doc. Authored-by: dima-asana <42555784+dima-asana@users.noreply.github.com> Signed-off-by: Sean Owen --- docs/sql-data-sources-parquet.md | 2 +- .../sql/test/DataFrameReaderWriterSuite.scala | 44 +++++++++++++++++-- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/docs/sql-data-sources-parquet.md b/docs/sql-data-sources-parquet.md index 4fed3eaf83e5d..dcd2936518465 100644 --- a/docs/sql-data-sources-parquet.md +++ b/docs/sql-data-sources-parquet.md @@ -9,7 +9,7 @@ displayTitle: Parquet Files [Parquet](http://parquet.io) is a columnar format that is supported by many other data processing systems. Spark SQL provides support for both reading and writing Parquet files that automatically preserves the schema -of the original data. When writing Parquet files, all columns are automatically converted to be nullable for +of the original data. When reading Parquet files, all columns are automatically converted to be nullable for compatibility reasons. ### Loading Data Programmatically diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 237872585e11d..e45ab19aadbfa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -23,6 +23,13 @@ import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.JavaConverters._ +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetFileReader +import org.apache.parquet.hadoop.util.HadoopInputFile +import org.apache.parquet.schema.PrimitiveType +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.apache.parquet.schema.Type.Repetition import org.scalatest.BeforeAndAfter import org.apache.spark.SparkContext @@ -31,6 +38,7 @@ import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -522,11 +530,12 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be Seq("json", "orc", "parquet", "csv").foreach { format => val schema = StructType( StructField("cl1", IntegerType, nullable = false).withComment("test") :: - StructField("cl2", IntegerType, nullable = true) :: - StructField("cl3", IntegerType, nullable = true) :: Nil) + StructField("cl2", IntegerType, nullable = true) :: + StructField("cl3", IntegerType, nullable = true) :: Nil) val row = Row(3, null, 4) val df = spark.createDataFrame(sparkContext.parallelize(row :: Nil), schema) + // if we write and then read, the read will enforce schema to be nullable val tableName = "tab" withTable(tableName) { df.write.format(format).mode("overwrite").saveAsTable(tableName) @@ -536,12 +545,41 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be Row("cl1", "test") :: Nil) // Verify the schema val expectedFields = schema.fields.map(f => f.copy(nullable = true)) - assert(spark.table(tableName).schema == schema.copy(fields = expectedFields)) + assert(spark.table(tableName).schema === schema.copy(fields = expectedFields)) } } } } + test("parquet - column nullability -- write only") { + val schema = StructType( + StructField("cl1", IntegerType, nullable = false) :: + StructField("cl2", IntegerType, nullable = true) :: Nil) + val row = Row(3, 4) + val df = spark.createDataFrame(sparkContext.parallelize(row :: Nil), schema) + + withTempPath { dir => + val path = dir.getAbsolutePath + df.write.mode("overwrite").parquet(path) + val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0) + + val hadoopInputFile = HadoopInputFile.fromPath(new Path(file), new Configuration()) + val f = ParquetFileReader.open(hadoopInputFile) + val parquetSchema = f.getFileMetaData.getSchema.getColumns.asScala + .map(_.getPrimitiveType) + f.close() + + // the write keeps nullable info from the schema + val expectedParquetSchema = Seq( + new PrimitiveType(Repetition.REQUIRED, PrimitiveTypeName.INT32, "cl1"), + new PrimitiveType(Repetition.OPTIONAL, PrimitiveTypeName.INT32, "cl2") + ) + + assert (expectedParquetSchema === parquetSchema) + } + + } + test("SPARK-17230: write out results of decimal calculation") { val df = spark.range(99, 101) .selectExpr("id", "cast(id as long) * cast('1.0' as decimal(38, 18)) as num") From 3b8ae23735f5b29db95516662190f606edc51fd7 Mon Sep 17 00:00:00 2001 From: Shahid Date: Fri, 7 Dec 2018 14:31:35 -0600 Subject: [PATCH 2235/2461] [SPARK-26196][SPARK-26281][WEBUI] Total tasks title in the stage page is incorrect when there are failed or killed tasks and update duration metrics ## What changes were proposed in this pull request? This PR fixes 3 issues 1) Total tasks message in the tasks table is incorrect, when there are failed or killed tasks 2) Sorting of the "Duration" column is not correct 3) Duration in the aggregated tasks summary table and the tasks table and not matching. Total tasks = numCompleteTasks + numActiveTasks + numKilledTasks + numFailedTasks; Corrected the duration metrics in the tasks table as executorRunTime based on the PR https://github.com/apache/spark/pull/23081 ## How was this patch tested? test step: 1) ``` bin/spark-shell scala > sc.parallelize(1 to 100, 10).map{ x => throw new RuntimeException("Bad executor")}.collect() ``` ![screenshot from 2018-11-28 07-26-00](https://user-images.githubusercontent.com/23054875/49123523-e2691880-f2de-11e8-9c16-60d1865e6e77.png) After patch: ![screenshot from 2018-11-28 07-24-31](https://user-images.githubusercontent.com/23054875/49123525-e432dc00-f2de-11e8-89ca-4a53e19c9c18.png) 2) Duration metrics: Before patch: ![screenshot from 2018-12-06 03-25-14](https://user-images.githubusercontent.com/23054875/49546591-9e8d9900-f906-11e8-8a0b-157742c47655.png) After patch: ![screenshot from 2018-12-06 03-23-14](https://user-images.githubusercontent.com/23054875/49546589-9cc3d580-f906-11e8-827f-52ef8ffdeaec.png) Closes #23160 from shahidki31/totalTasks. Authored-by: Shahid Signed-off-by: Sean Owen --- .../resources/org/apache/spark/ui/static/stagepage.js | 9 +++++---- .../org/apache/spark/status/api/v1/StagesResource.scala | 1 - 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/stagepage.js b/core/src/main/resources/org/apache/spark/ui/static/stagepage.js index 564467487e84e..08de2b0fee034 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/stagepage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/stagepage.js @@ -616,7 +616,8 @@ $(document).ready(function () { $("#accumulator-table").DataTable(accumulatorConf); // building tasks table that uses server side functionality - var totalTasksToShow = responseBody.numCompleteTasks + responseBody.numActiveTasks; + var totalTasksToShow = responseBody.numCompleteTasks + responseBody.numActiveTasks + + responseBody.numKilledTasks + responseBody.numFailedTasks; var taskTable = "#active-tasks-table"; var taskConf = { "serverSide": true, @@ -667,8 +668,8 @@ $(document).ready(function () { {data : "launchTime", name: "Launch Time", render: formatDate}, { data : function (row, type) { - if (row.duration) { - return type === 'display' ? formatDuration(row.duration) : row.duration; + if (row.taskMetrics && row.taskMetrics.executorRunTime) { + return type === 'display' ? formatDuration(row.taskMetrics.executorRunTime) : row.taskMetrics.executorRunTime; } else { return ""; } @@ -927,7 +928,7 @@ $(document).ready(function () { // title number and toggle list $("#summaryMetricsTitle").html("Summary Metrics for " + "" + responseBody.numCompleteTasks + " Completed Tasks" + ""); - $("#tasksTitle").html("Task (" + totalTasksToShow + ")"); + $("#tasksTitle").html("Tasks (" + totalTasksToShow + ")"); // hide or show the accumulate update table if (accumulatorTable.length == 0) { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala index f81892734c2de..9d1d66a0e15a4 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala @@ -210,7 +210,6 @@ private[v1] class StagesResource extends BaseAppResource { (containsValue(f.taskId) || containsValue(f.index) || containsValue(f.attempt) || containsValue(f.launchTime) || containsValue(f.resultFetchStart.getOrElse(defaultOptionString)) - || containsValue(f.duration.getOrElse(defaultOptionString)) || containsValue(f.executorId) || containsValue(f.host) || containsValue(f.status) || containsValue(f.taskLocality) || containsValue(f.speculative) || containsValue(f.errorMessage.getOrElse(defaultOptionString)) From 20278e719e28fc5d7a8069e0498a8df143ecee90 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 7 Dec 2018 13:53:35 -0800 Subject: [PATCH 2236/2461] [SPARK-24333][ML][PYTHON] Add fit with validation set to spark.ml GBT: Python API ## What changes were proposed in this pull request? Add validationIndicatorCol and validationTol to GBT Python. ## How was this patch tested? Add test in doctest to test the new API. Closes #21465 from huaxingao/spark-24333. Authored-by: Huaxin Gao Signed-off-by: Bryan Cutler --- python/pyspark/ml/classification.py | 81 ++++++++------ .../ml/param/_shared_params_code_gen.py | 5 +- python/pyspark/ml/param/shared.py | 71 ++++++++----- python/pyspark/ml/regression.py | 100 +++++++++++++----- 4 files changed, 169 insertions(+), 88 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index ce028512357f2..6ddfce95a3d4d 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -23,7 +23,7 @@ from pyspark.ml import Estimator, Model from pyspark.ml.param.shared import * from pyspark.ml.regression import DecisionTreeModel, DecisionTreeRegressionModel, \ - RandomForestParams, TreeEnsembleModel, TreeEnsembleParams + GBTParams, HasVarianceImpurity, RandomForestParams, TreeEnsembleModel, TreeEnsembleParams from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams from pyspark.ml.wrapper import JavaWrapper @@ -895,15 +895,6 @@ def getImpurity(self): return self.getOrDefault(self.impurity) -class GBTParams(TreeEnsembleParams): - """ - Private class to track supported GBT params. - - .. versionadded:: 1.4.0 - """ - supportedLossTypes = ["logistic"] - - @inherit_doc class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams, @@ -1174,9 +1165,31 @@ def trees(self): return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))] +class GBTClassifierParams(GBTParams, HasVarianceImpurity): + """ + Private class to track supported GBTClassifier params. + + .. versionadded:: 3.0.0 + """ + + supportedLossTypes = ["logistic"] + + lossType = Param(Params._dummy(), "lossType", + "Loss function which GBT tries to minimize (case-insensitive). " + + "Supported options: " + ", ".join(supportedLossTypes), + typeConverter=TypeConverters.toString) + + @since("1.4.0") + def getLossType(self): + """ + Gets the value of lossType or its default value. + """ + return self.getOrDefault(self.lossType) + + @inherit_doc -class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable, +class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, + GBTClassifierParams, HasCheckpointInterval, HasSeed, JavaMLWritable, JavaMLReadable): """ `Gradient-Boosted Trees (GBTs) `_ @@ -1242,32 +1255,28 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol [0.25..., 0.23..., 0.21..., 0.19..., 0.18...] >>> model.numClasses 2 + >>> gbt = gbt.setValidationIndicatorCol("validationIndicator") + >>> gbt.getValidationIndicatorCol() + 'validationIndicator' + >>> gbt.getValidationTol() + 0.01 .. versionadded:: 1.4.0 """ - lossType = Param(Params._dummy(), "lossType", - "Loss function which GBT tries to minimize (case-insensitive). " + - "Supported options: " + ", ".join(GBTParams.supportedLossTypes), - typeConverter=TypeConverters.toString) - - stepSize = Param(Params._dummy(), "stepSize", - "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " + - "the contribution of each estimator.", - typeConverter=TypeConverters.toFloat) - @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", - maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, - featureSubsetStrategy="all"): + maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, impurity="variance", + featureSubsetStrategy="all", validationTol=0.01, validationIndicatorCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \ - featureSubsetStrategy="all") + impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \ + validationIndicatorCol=None) """ super(GBTClassifier, self).__init__() self._java_obj = self._new_java_obj( @@ -1275,7 +1284,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0, - featureSubsetStrategy="all") + impurity="variance", featureSubsetStrategy="all", validationTol=0.01) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -1285,13 +1294,15 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, - featureSubsetStrategy="all"): + impurity="variance", featureSubsetStrategy="all", validationTol=0.01, + validationIndicatorCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \ - featureSubsetStrategy="all") + impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \ + validationIndicatorCol=None) Sets params for Gradient Boosted Tree Classification. """ kwargs = self._input_kwargs @@ -1307,13 +1318,6 @@ def setLossType(self, value): """ return self._set(lossType=value) - @since("1.4.0") - def getLossType(self): - """ - Gets the value of lossType or its default value. - """ - return self.getOrDefault(self.lossType) - @since("2.4.0") def setFeatureSubsetStrategy(self, value): """ @@ -1321,6 +1325,13 @@ def setFeatureSubsetStrategy(self, value): """ return self._set(featureSubsetStrategy=value) + @since("3.0.0") + def setValidationIndicatorCol(self, value): + """ + Sets the value of :py:attr:`validationIndicatorCol`. + """ + return self._set(validationIndicatorCol=value) + class GBTClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index e45ba840b412b..1b0c8c5d28b78 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -164,7 +164,10 @@ def get$Name(self): "False", "TypeConverters.toBoolean"), ("loss", "the loss function to be optimized.", None, "TypeConverters.toString"), ("distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.", - "'euclidean'", "TypeConverters.toString")] + "'euclidean'", "TypeConverters.toString"), + ("validationIndicatorCol", "name of the column that indicates whether each row is for " + + "training or for validation. False indicates training; true indicates validation.", + None, "TypeConverters.toString")] code = [] for name, doc, defaultValueStr, typeConverter in shared: diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 618f5bf0a8103..6405b9fce7efb 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -702,6 +702,53 @@ def getLoss(self): return self.getOrDefault(self.loss) +class HasDistanceMeasure(Params): + """ + Mixin for param distanceMeasure: the distance measure. Supported options: 'euclidean' and 'cosine'. + """ + + distanceMeasure = Param(Params._dummy(), "distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.", typeConverter=TypeConverters.toString) + + def __init__(self): + super(HasDistanceMeasure, self).__init__() + self._setDefault(distanceMeasure='euclidean') + + def setDistanceMeasure(self, value): + """ + Sets the value of :py:attr:`distanceMeasure`. + """ + return self._set(distanceMeasure=value) + + def getDistanceMeasure(self): + """ + Gets the value of distanceMeasure or its default value. + """ + return self.getOrDefault(self.distanceMeasure) + + +class HasValidationIndicatorCol(Params): + """ + Mixin for param validationIndicatorCol: name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation. + """ + + validationIndicatorCol = Param(Params._dummy(), "validationIndicatorCol", "name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.", typeConverter=TypeConverters.toString) + + def __init__(self): + super(HasValidationIndicatorCol, self).__init__() + + def setValidationIndicatorCol(self, value): + """ + Sets the value of :py:attr:`validationIndicatorCol`. + """ + return self._set(validationIndicatorCol=value) + + def getValidationIndicatorCol(self): + """ + Gets the value of validationIndicatorCol or its default value. + """ + return self.getOrDefault(self.validationIndicatorCol) + + class DecisionTreeParams(Params): """ Mixin for Decision Tree parameters. @@ -790,27 +837,3 @@ def getCacheNodeIds(self): """ return self.getOrDefault(self.cacheNodeIds) - -class HasDistanceMeasure(Params): - """ - Mixin for param distanceMeasure: the distance measure. Supported options: 'euclidean' and 'cosine'. - """ - - distanceMeasure = Param(Params._dummy(), "distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.", typeConverter=TypeConverters.toString) - - def __init__(self): - super(HasDistanceMeasure, self).__init__() - self._setDefault(distanceMeasure='euclidean') - - def setDistanceMeasure(self, value): - """ - Sets the value of :py:attr:`distanceMeasure`. - """ - return self._set(distanceMeasure=value) - - def getDistanceMeasure(self): - """ - Gets the value of distanceMeasure or its default value. - """ - return self.getOrDefault(self.distanceMeasure) - diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 98f4361351847..78cb4a6703554 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -650,19 +650,20 @@ def getFeatureSubsetStrategy(self): return self.getOrDefault(self.featureSubsetStrategy) -class TreeRegressorParams(Params): +class HasVarianceImpurity(Params): """ Private class to track supported impurity measures. """ supportedImpurities = ["variance"] + impurity = Param(Params._dummy(), "impurity", "Criterion used for information gain calculation (case-insensitive). " + "Supported options: " + ", ".join(supportedImpurities), typeConverter=TypeConverters.toString) def __init__(self): - super(TreeRegressorParams, self).__init__() + super(HasVarianceImpurity, self).__init__() @since("1.4.0") def setImpurity(self, value): @@ -679,6 +680,10 @@ def getImpurity(self): return self.getOrDefault(self.impurity) +class TreeRegressorParams(HasVarianceImpurity): + pass + + class RandomForestParams(TreeEnsembleParams): """ Private class to track supported random forest parameters. @@ -705,12 +710,52 @@ def getNumTrees(self): return self.getOrDefault(self.numTrees) -class GBTParams(TreeEnsembleParams): +class GBTParams(TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndicatorCol): """ Private class to track supported GBT params. """ + + stepSize = Param(Params._dummy(), "stepSize", + "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " + + "the contribution of each estimator.", + typeConverter=TypeConverters.toFloat) + + validationTol = Param(Params._dummy(), "validationTol", + "Threshold for stopping early when fit with validation is used. " + + "If the error rate on the validation input changes by less than the " + + "validationTol, then learning will stop early (before `maxIter`). " + + "This parameter is ignored when fit without validation is used.", + typeConverter=TypeConverters.toFloat) + + @since("3.0.0") + def getValidationTol(self): + """ + Gets the value of validationTol or its default value. + """ + return self.getOrDefault(self.validationTol) + + +class GBTRegressorParams(GBTParams, TreeRegressorParams): + """ + Private class to track supported GBTRegressor params. + + .. versionadded:: 3.0.0 + """ + supportedLossTypes = ["squared", "absolute"] + lossType = Param(Params._dummy(), "lossType", + "Loss function which GBT tries to minimize (case-insensitive). " + + "Supported options: " + ", ".join(supportedLossTypes), + typeConverter=TypeConverters.toString) + + @since("1.4.0") + def getLossType(self): + """ + Gets the value of lossType or its default value. + """ + return self.getOrDefault(self.lossType) + @inherit_doc class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, @@ -1030,9 +1075,9 @@ def featureImportances(self): @inherit_doc -class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable, - JavaMLReadable, TreeRegressorParams): +class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, + GBTRegressorParams, HasCheckpointInterval, HasSeed, JavaMLWritable, + JavaMLReadable): """ `Gradient-Boosted Trees (GBTs) `_ learning algorithm for regression. @@ -1079,39 +1124,36 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, ... ["label", "features"]) >>> model.evaluateEachIteration(validation, "squared") [0.0, 0.0, 0.0, 0.0, 0.0] + >>> gbt = gbt.setValidationIndicatorCol("validationIndicator") + >>> gbt.getValidationIndicatorCol() + 'validationIndicator' + >>> gbt.getValidationTol() + 0.01 .. versionadded:: 1.4.0 """ - lossType = Param(Params._dummy(), "lossType", - "Loss function which GBT tries to minimize (case-insensitive). " + - "Supported options: " + ", ".join(GBTParams.supportedLossTypes), - typeConverter=TypeConverters.toString) - - stepSize = Param(Params._dummy(), "stepSize", - "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " + - "the contribution of each estimator.", - typeConverter=TypeConverters.toFloat) - @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, - impurity="variance", featureSubsetStrategy="all"): + impurity="variance", featureSubsetStrategy="all", validationTol=0.01, + validationIndicatorCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \ - impurity="variance", featureSubsetStrategy="all") + impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \ + validationIndicatorCol=None) """ super(GBTRegressor, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, - impurity="variance", featureSubsetStrategy="all") + impurity="variance", featureSubsetStrategy="all", validationTol=0.01) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -1121,13 +1163,15 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, - impuriy="variance", featureSubsetStrategy="all"): + impuriy="variance", featureSubsetStrategy="all", validationTol=0.01, + validationIndicatorCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \ - impurity="variance", featureSubsetStrategy="all") + impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \ + validationIndicatorCol=None) Sets params for Gradient Boosted Tree Regression. """ kwargs = self._input_kwargs @@ -1143,13 +1187,6 @@ def setLossType(self, value): """ return self._set(lossType=value) - @since("1.4.0") - def getLossType(self): - """ - Gets the value of lossType or its default value. - """ - return self.getOrDefault(self.lossType) - @since("2.4.0") def setFeatureSubsetStrategy(self, value): """ @@ -1157,6 +1194,13 @@ def setFeatureSubsetStrategy(self, value): """ return self._set(featureSubsetStrategy=value) + @since("3.0.0") + def setValidationIndicatorCol(self, value): + """ + Sets the value of :py:attr:`validationIndicatorCol`. + """ + return self._set(validationIndicatorCol=value) + class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): """ From 9b1f6c8bab5401258c653d4e2efb50e97c6d282f Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Fri, 7 Dec 2018 13:58:02 -0800 Subject: [PATCH 2237/2461] [SPARK-26304][SS] Add default value to spark.kafka.sasl.kerberos.service.name parameter ## What changes were proposed in this pull request? spark.kafka.sasl.kerberos.service.name is an optional parameter but most of the time value `kafka` has to be set. As I've written in the jira the following reasoning is behind: * Kafka's configuration guide suggest the same value: https://kafka.apache.org/documentation/#security_sasl_kerberos_brokerconfig * It would be easier for spark users by providing less configuration * Other streaming engines are doing the same In this PR I've changed the parameter from optional to `WithDefault` and set `kafka` as default value. ## How was this patch tested? Available unit tests + on cluster. Closes #23254 from gaborgsomogyi/SPARK-26304. Authored-by: Gabor Somogyi Signed-off-by: Marcelo Vanzin --- .../deploy/security/KafkaTokenUtil.scala | 7 ++---- .../apache/spark/internal/config/Kafka.scala | 2 +- .../deploy/security/KafkaTokenUtilSuite.scala | 24 ------------------- .../sql/kafka010/KafkaSecurityHelper.scala | 5 +--- .../kafka010/KafkaSecurityHelperSuite.scala | 15 ------------ 5 files changed, 4 insertions(+), 49 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/security/KafkaTokenUtil.scala b/core/src/main/scala/org/apache/spark/deploy/security/KafkaTokenUtil.scala index c890cee59ffe0..aec0f72feb3c1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/KafkaTokenUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/KafkaTokenUtil.scala @@ -143,14 +143,11 @@ private[spark] object KafkaTokenUtil extends Logging { } private[security] def getKeytabJaasParams(sparkConf: SparkConf): String = { - val serviceName = sparkConf.get(Kafka.KERBEROS_SERVICE_NAME) - require(serviceName.nonEmpty, "Kerberos service name must be defined") - val params = s""" |${getKrb5LoginModuleName} required | useKeyTab=true - | serviceName="${serviceName.get}" + | serviceName="${sparkConf.get(Kafka.KERBEROS_SERVICE_NAME)}" | keyTab="${sparkConf.get(KEYTAB).get}" | principal="${sparkConf.get(PRINCIPAL).get}"; """.stripMargin.replace("\n", "") @@ -166,7 +163,7 @@ private[spark] object KafkaTokenUtil extends Logging { s""" |${getKrb5LoginModuleName} required | useTicketCache=true - | serviceName="${serviceName.get}"; + | serviceName="${sparkConf.get(Kafka.KERBEROS_SERVICE_NAME)}"; """.stripMargin.replace("\n", "") logDebug(s"Krb ticket cache JAAS params: $params") params diff --git a/core/src/main/scala/org/apache/spark/internal/config/Kafka.scala b/core/src/main/scala/org/apache/spark/internal/config/Kafka.scala index 85d74c27142ad..064fc93cb8ed8 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/Kafka.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/Kafka.scala @@ -40,7 +40,7 @@ private[spark] object Kafka { "Kafka's JAAS config or in Kafka's config. For further details please see kafka " + "documentation. Only used to obtain delegation token.") .stringConf - .createOptional + .createWithDefault("kafka") val TRUSTSTORE_LOCATION = ConfigBuilder("spark.kafka.ssl.truststore.location") diff --git a/core/src/test/scala/org/apache/spark/deploy/security/KafkaTokenUtilSuite.scala b/core/src/test/scala/org/apache/spark/deploy/security/KafkaTokenUtilSuite.scala index 682bebde916fa..18aa537b3a51d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/security/KafkaTokenUtilSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/security/KafkaTokenUtilSuite.scala @@ -36,7 +36,6 @@ class KafkaTokenUtilSuite extends SparkFunSuite with BeforeAndAfterEach { private val keyStorePassword = "keyStoreSecret" private val keyPassword = "keySecret" private val keytab = "/path/to/keytab" - private val kerberosServiceName = "kafka" private val principal = "user@domain.com" private var sparkConf: SparkConf = null @@ -96,7 +95,6 @@ class KafkaTokenUtilSuite extends SparkFunSuite with BeforeAndAfterEach { sparkConf.set(Kafka.KEYSTORE_LOCATION, keyStoreLocation) sparkConf.set(Kafka.KEYSTORE_PASSWORD, keyStorePassword) sparkConf.set(Kafka.KEY_PASSWORD, keyPassword) - sparkConf.set(Kafka.KERBEROS_SERVICE_NAME, kerberosServiceName) val adminClientProperties = KafkaTokenUtil.createAdminClientProperties(sparkConf) @@ -119,7 +117,6 @@ class KafkaTokenUtilSuite extends SparkFunSuite with BeforeAndAfterEach { sparkConf.set(Kafka.KEYSTORE_LOCATION, keyStoreLocation) sparkConf.set(Kafka.KEYSTORE_PASSWORD, keyStorePassword) sparkConf.set(Kafka.KEY_PASSWORD, keyPassword) - sparkConf.set(Kafka.KERBEROS_SERVICE_NAME, kerberosServiceName) val adminClientProperties = KafkaTokenUtil.createAdminClientProperties(sparkConf) @@ -143,7 +140,6 @@ class KafkaTokenUtilSuite extends SparkFunSuite with BeforeAndAfterEach { sparkConf.set(Kafka.KEYSTORE_LOCATION, keyStoreLocation) sparkConf.set(Kafka.KEYSTORE_PASSWORD, keyStorePassword) sparkConf.set(Kafka.KEY_PASSWORD, keyPassword) - sparkConf.set(Kafka.KERBEROS_SERVICE_NAME, kerberosServiceName) val adminClientProperties = KafkaTokenUtil.createAdminClientProperties(sparkConf) @@ -177,7 +173,6 @@ class KafkaTokenUtilSuite extends SparkFunSuite with BeforeAndAfterEach { sparkConf.set(Kafka.BOOTSTRAP_SERVERS, bootStrapServers) sparkConf.set(Kafka.SECURITY_PROTOCOL, SASL_SSL.name) sparkConf.set(KEYTAB, keytab) - sparkConf.set(Kafka.KERBEROS_SERVICE_NAME, kerberosServiceName) sparkConf.set(PRINCIPAL, principal) val adminClientProperties = KafkaTokenUtil.createAdminClientProperties(sparkConf) @@ -195,7 +190,6 @@ class KafkaTokenUtilSuite extends SparkFunSuite with BeforeAndAfterEach { test("createAdminClientProperties without keytab should set ticket cache dynamic jaas config") { sparkConf.set(Kafka.BOOTSTRAP_SERVERS, bootStrapServers) sparkConf.set(Kafka.SECURITY_PROTOCOL, SASL_SSL.name) - sparkConf.set(Kafka.KERBEROS_SERVICE_NAME, kerberosServiceName) val adminClientProperties = KafkaTokenUtil.createAdminClientProperties(sparkConf) @@ -218,22 +212,4 @@ class KafkaTokenUtilSuite extends SparkFunSuite with BeforeAndAfterEach { assert(KafkaTokenUtil.isGlobalJaasConfigurationProvided) } - - test("getKeytabJaasParams with keytab no service should throw exception") { - sparkConf.set(KEYTAB, keytab) - - val thrown = intercept[IllegalArgumentException] { - KafkaTokenUtil.getKeytabJaasParams(sparkConf) - } - - assert(thrown.getMessage contains "Kerberos service name must be defined") - } - - test("getTicketCacheJaasParams without service should throw exception") { - val thrown = intercept[IllegalArgumentException] { - KafkaTokenUtil.getTicketCacheJaasParams(sparkConf) - } - - assert(thrown.getMessage contains "Kerberos service name must be defined") - } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelper.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelper.scala index 74d5ef9c05f14..7215295b10091 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelper.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelper.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.kafka010 import org.apache.hadoop.security.UserGroupInformation -import org.apache.hadoop.security.token.{Token, TokenIdentifier} import org.apache.kafka.common.security.scram.ScramLoginModule import org.apache.spark.SparkConf @@ -35,8 +34,6 @@ private[kafka010] object KafkaSecurityHelper extends Logging { def getTokenJaasParams(sparkConf: SparkConf): String = { val token = UserGroupInformation.getCurrentUser().getCredentials.getToken( KafkaTokenUtil.TOKEN_SERVICE) - val serviceName = sparkConf.get(Kafka.KERBEROS_SERVICE_NAME) - require(serviceName.isDefined, "Kerberos service name must be defined") val username = new String(token.getIdentifier) val password = new String(token.getPassword) @@ -45,7 +42,7 @@ private[kafka010] object KafkaSecurityHelper extends Logging { s""" |$loginModuleName required | tokenauth=true - | serviceName="${serviceName.get}" + | serviceName="${sparkConf.get(Kafka.KERBEROS_SERVICE_NAME)}" | username="$username" | password="$password"; """.stripMargin.replace("\n", "") diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelperSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelperSuite.scala index 772fe4614bad0..fd9dee390d185 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelperSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelperSuite.scala @@ -26,12 +26,8 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.security.KafkaTokenUtil import org.apache.spark.deploy.security.KafkaTokenUtil.KafkaDelegationTokenIdentifier -import org.apache.spark.internal.config._ class KafkaSecurityHelperSuite extends SparkFunSuite with BeforeAndAfterEach { - private val keytab = "/path/to/keytab" - private val kerberosServiceName = "kafka" - private val principal = "user@domain.com" private val tokenId = "tokenId" + UUID.randomUUID().toString private val tokenPassword = "tokenPassword" + UUID.randomUUID().toString @@ -76,19 +72,8 @@ class KafkaSecurityHelperSuite extends SparkFunSuite with BeforeAndAfterEach { assert(KafkaSecurityHelper.isTokenAvailable()) } - test("getTokenJaasParams with token no service should throw exception") { - addTokenToUGI() - - val thrown = intercept[IllegalArgumentException] { - KafkaSecurityHelper.getTokenJaasParams(sparkConf) - } - - assert(thrown.getMessage contains "Kerberos service name must be defined") - } - test("getTokenJaasParams with token should return scram module") { addTokenToUGI() - sparkConf.set(Kafka.KERBEROS_SERVICE_NAME, kerberosServiceName) val jaasParams = KafkaSecurityHelper.getTokenJaasParams(sparkConf) From 2ea9792fdeb07be19d63e7625cfc483e062a1d9c Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 8 Dec 2018 05:59:53 -0600 Subject: [PATCH 2238/2461] [SPARK-26266][BUILD] Update to Scala 2.12.8 ## What changes were proposed in this pull request? Update to Scala 2.12.8 ## How was this patch tested? Existing tests. Closes #23218 from srowen/SPARK-26266. Authored-by: Sean Owen Signed-off-by: Sean Owen --- dev/deps/spark-deps-hadoop-2.7 | 6 +++--- dev/deps/spark-deps-hadoop-3.1 | 6 +++--- docs/_config.yml | 2 +- pom.xml | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index d250d5205586e..71423af0789c6 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -171,10 +171,10 @@ parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar py4j-0.10.8.1.jar pyrolite-4.13.jar -scala-compiler-2.12.7.jar -scala-library-2.12.7.jar +scala-compiler-2.12.8.jar +scala-library-2.12.8.jar scala-parser-combinators_2.12-1.1.0.jar -scala-reflect-2.12.7.jar +scala-reflect-2.12.8.jar scala-xml_2.12-1.0.5.jar shapeless_2.12-2.3.2.jar slf4j-api-1.7.16.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 347503ace557a..93eafef045330 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -189,10 +189,10 @@ protobuf-java-2.5.0.jar py4j-0.10.8.1.jar pyrolite-4.13.jar re2j-1.1.jar -scala-compiler-2.12.7.jar -scala-library-2.12.7.jar +scala-compiler-2.12.8.jar +scala-library-2.12.8.jar scala-parser-combinators_2.12-1.1.0.jar -scala-reflect-2.12.7.jar +scala-reflect-2.12.8.jar scala-xml_2.12-1.0.5.jar shapeless_2.12-2.3.2.jar slf4j-api-1.7.16.jar diff --git a/docs/_config.yml b/docs/_config.yml index 649d18bf72b57..146c90fcff6e5 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -17,7 +17,7 @@ include: SPARK_VERSION: 3.0.0-SNAPSHOT SPARK_VERSION_SHORT: 3.0.0 SCALA_BINARY_VERSION: "2.12" -SCALA_VERSION: "2.12.7" +SCALA_VERSION: "2.12.8" MESOS_VERSION: 1.0.0 SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK SPARK_GITHUB_URL: https://github.com/apache/spark diff --git a/pom.xml b/pom.xml index 15e5c18290405..310d7de955125 100644 --- a/pom.xml +++ b/pom.xml @@ -156,7 +156,7 @@ 3.4.1 3.2.2 - 2.12.7 + 2.12.8 2.12 --diff --test From 678e1aca6901944c119d2ec56169d4e69fce66de Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sat, 8 Dec 2018 22:23:50 +0800 Subject: [PATCH 2239/2461] [SPARK-24207][R] follow-up PR for SPARK-24207 to fix code style problems ## What changes were proposed in this pull request? follow-up PR for SPARK-24207 to fix code style problems Closes #23256 from huaxingao/spark-24207-cnt. Authored-by: Huaxin Gao Signed-off-by: Hyukjin Kwon --- R/pkg/R/mllib_fpm.R | 7 +++-- R/pkg/tests/fulltests/test_mllib_fpm.R | 29 ++++++++++--------- R/pkg/vignettes/sparkr-vignettes.Rmd | 7 +++-- examples/src/main/r/ml/prefixSpan.R | 9 +++--- .../spark/examples/ml/FPGrowthExample.scala | 3 -- .../spark/examples/ml/PrefixSpanExample.scala | 3 -- 6 files changed, 28 insertions(+), 30 deletions(-) diff --git a/R/pkg/R/mllib_fpm.R b/R/pkg/R/mllib_fpm.R index ac37580c6b373..c248e9ec9be94 100644 --- a/R/pkg/R/mllib_fpm.R +++ b/R/pkg/R/mllib_fpm.R @@ -190,9 +190,10 @@ setMethod("write.ml", signature(object = "FPGrowthModel", path = "character"), #' @examples #' \dontrun{ #' df <- createDataFrame(list(list(list(list(1L, 2L), list(3L))), -#' list(list(list(1L), list(3L, 2L), list(1L, 2L))), -#' list(list(list(1L, 2L), list(5L))), -#' list(list(list(6L)))), schema = c("sequence")) +#' list(list(list(1L), list(3L, 2L), list(1L, 2L))), +#' list(list(list(1L, 2L), list(5L))), +#' list(list(list(6L)))), +#' schema = c("sequence")) #' frequency <- spark.findFrequentSequentialPatterns(df, minSupport = 0.5, maxPatternLength = 5L, #' maxLocalProjDBSize = 32000000L) #' showDF(frequency) diff --git a/R/pkg/tests/fulltests/test_mllib_fpm.R b/R/pkg/tests/fulltests/test_mllib_fpm.R index daf9ff97a8216..bc1e17538d41a 100644 --- a/R/pkg/tests/fulltests/test_mllib_fpm.R +++ b/R/pkg/tests/fulltests/test_mllib_fpm.R @@ -84,19 +84,20 @@ test_that("spark.fpGrowth", { }) test_that("spark.prefixSpan", { - df <- createDataFrame(list(list(list(list(1L, 2L), list(3L))), - list(list(list(1L), list(3L, 2L), list(1L, 2L))), - list(list(list(1L, 2L), list(5L))), - list(list(list(6L)))), schema = c("sequence")) - result1 <- spark.findFrequentSequentialPatterns(df, minSupport = 0.5, maxPatternLength = 5L, - maxLocalProjDBSize = 32000000L) - - expected_result <- createDataFrame(list(list(list(list(1L)), 3L), - list(list(list(3L)), 2L), - list(list(list(2L)), 3L), - list(list(list(1L, 2L)), 3L), - list(list(list(1L), list(3L)), 2L)), - schema = c("sequence", "freq")) - }) + df <- createDataFrame(list(list(list(list(1L, 2L), list(3L))), + list(list(list(1L), list(3L, 2L), list(1L, 2L))), + list(list(list(1L, 2L), list(5L))), + list(list(list(6L)))), + schema = c("sequence")) + result <- spark.findFrequentSequentialPatterns(df, minSupport = 0.5, maxPatternLength = 5L, + maxLocalProjDBSize = 32000000L) + + expected_result <- createDataFrame(list(list(list(list(1L)), 3L), list(list(list(3L)), 2L), + list(list(list(2L)), 3L), list(list(list(1L, 2L)), 3L), + list(list(list(1L), list(3L)), 2L)), + schema = c("sequence", "freq")) + + expect_equivalent(expected_result, result) +}) sparkR.session.stop() diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index f80b45b4f36a8..1c6a03c4b9bc3 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -1019,9 +1019,10 @@ head(predict(fpm, df)) ```{r} df <- createDataFrame(list(list(list(list(1L, 2L), list(3L))), - list(list(list(1L), list(3L, 2L), list(1L, 2L))), - list(list(list(1L, 2L), list(5L))), - list(list(list(6L)))), schema = c("sequence")) + list(list(list(1L), list(3L, 2L), list(1L, 2L))), + list(list(list(1L, 2L), list(5L))), + list(list(list(6L)))), + schema = c("sequence")) head(spark.findFrequentSequentialPatterns(df, minSupport = 0.5, maxPatternLength = 5L)) ``` diff --git a/examples/src/main/r/ml/prefixSpan.R b/examples/src/main/r/ml/prefixSpan.R index 9b70573ffb787..02908aeb02968 100644 --- a/examples/src/main/r/ml/prefixSpan.R +++ b/examples/src/main/r/ml/prefixSpan.R @@ -28,9 +28,10 @@ sparkR.session(appName = "SparkR-ML-prefixSpan-example") # Load training data df <- createDataFrame(list(list(list(list(1L, 2L), list(3L))), - list(list(list(1L), list(3L, 2L), list(1L, 2L))), - list(list(list(1L, 2L), list(5L))), - list(list(list(6L)))), schema = c("sequence")) + list(list(list(1L), list(3L, 2L), list(1L, 2L))), + list(list(list(1L, 2L), list(5L))), + list(list(list(6L)))), + schema = c("sequence")) # Finding frequent sequential patterns frequency <- spark.findFrequentSequentialPatterns(df, minSupport = 0.5, maxPatternLength = 5L, @@ -39,4 +40,4 @@ showDF(frequency) # $example off$ -sparkR.session.stop() \ No newline at end of file +sparkR.session.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/FPGrowthExample.scala index 59110d70de550..bece0d96c030f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/FPGrowthExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/FPGrowthExample.scala @@ -17,8 +17,6 @@ package org.apache.spark.examples.ml -// scalastyle:off println - // $example on$ import org.apache.spark.ml.fpm.FPGrowth // $example off$ @@ -64,4 +62,3 @@ object FPGrowthExample { spark.stop() } } -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PrefixSpanExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PrefixSpanExample.scala index 0a2d31097a024..b4e0811c506be 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/PrefixSpanExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PrefixSpanExample.scala @@ -17,8 +17,6 @@ package org.apache.spark.examples.ml -// scalastyle:off println - // $example on$ import org.apache.spark.ml.fpm.PrefixSpan // $example off$ @@ -59,4 +57,3 @@ object PrefixSpanExample { spark.stop() } } -// scalastyle:on println From bdf32847b1ffcb3aa4d0bef058f86e65656e99fb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 8 Dec 2018 11:18:09 -0800 Subject: [PATCH 2240/2461] [SPARK-26021][SQL][FOLLOWUP] only deal with NaN and -0.0 in UnsafeWriter ## What changes were proposed in this pull request? A followup of https://github.com/apache/spark/pull/23043 There are 4 places we need to deal with NaN and -0.0: 1. comparison expressions. `-0.0` and `0.0` should be treated as same. Different NaNs should be treated as same. 2. Join keys. `-0.0` and `0.0` should be treated as same. Different NaNs should be treated as same. 3. grouping keys. `-0.0` and `0.0` should be assigned to the same group. Different NaNs should be assigned to the same group. 4. window partition keys. `-0.0` and `0.0` should be treated as same. Different NaNs should be treated as same. The case 1 is OK. Our comparison already handles NaN and -0.0, and for struct/array/map, we will recursively compare the fields/elements. Case 2, 3 and 4 are problematic, as they compare `UnsafeRow` binary directly, and different NaNs have different binary representation, and the same thing happens for -0.0 and 0.0. To fix it, a simple solution is: normalize float/double when building unsafe data (`UnsafeRow`, `UnsafeArrayData`, `UnsafeMapData`). Then we don't need to worry about it anymore. Following this direction, this PR moves the handling of NaN and -0.0 from `Platform` to `UnsafeWriter`, so that places like `UnsafeRow.setFloat` will not handle them, which reduces the perf overhead. It's also easier to add comments explaining why we do it in `UnsafeWriter`. ## How was this patch tested? existing tests Closes #23239 from cloud-fan/minor. Authored-by: Wenchen Fan Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/unsafe/Platform.java | 10 ------ .../spark/unsafe/PlatformUtilSuite.java | 18 ---------- .../expressions/codegen/UnsafeWriter.java | 35 +++++++++++++++++++ .../codegen/UnsafeRowWriterSuite.scala | 20 +++++++++++ .../apache/spark/sql/DataFrameJoinSuite.scala | 12 +++++++ .../sql/DataFrameWindowFunctionsSuite.scala | 14 ++++++++ 6 files changed, 81 insertions(+), 28 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 4563efcfcf474..076b693f81c88 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -174,11 +174,6 @@ public static float getFloat(Object object, long offset) { } public static void putFloat(Object object, long offset, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } else if (value == -0.0f) { - value = 0.0f; - } _UNSAFE.putFloat(object, offset, value); } @@ -187,11 +182,6 @@ public static double getDouble(Object object, long offset) { } public static void putDouble(Object object, long offset, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } else if (value == -0.0d) { - value = 0.0d; - } _UNSAFE.putDouble(object, offset, value); } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 2474081dad5c9..3ad9ac7b4de9c 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -157,22 +157,4 @@ public void heapMemoryReuse() { Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7); Assert.assertEquals(obj3, onheap4.getBaseObject()); } - - @Test - // SPARK-26021 - public void writeMinusZeroIsReplacedWithZero() { - byte[] doubleBytes = new byte[Double.BYTES]; - byte[] floatBytes = new byte[Float.BYTES]; - Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, -0.0d); - Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f); - - byte[] doubleBytes2 = new byte[Double.BYTES]; - byte[] floatBytes2 = new byte[Float.BYTES]; - Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, 0.0d); - Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, 0.0f); - - // Make sure the bytes we write from 0.0 and -0.0 are same. - Assert.assertArrayEquals(doubleBytes, doubleBytes2); - Assert.assertArrayEquals(floatBytes, floatBytes2); - } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index 95263a0da95a8..7553ab8cf7000 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -198,11 +198,46 @@ protected final void writeLong(long offset, long value) { Platform.putLong(getBuffer(), offset, value); } + // We need to take care of NaN and -0.0 in several places: + // 1. When compare values, different NaNs should be treated as same, `-0.0` and `0.0` should be + // treated as same. + // 2. In GROUP BY, different NaNs should belong to the same group, -0.0 and 0.0 should belong + // to the same group. + // 3. As join keys, different NaNs should be treated as same, `-0.0` and `0.0` should be + // treated as same. + // 4. As window partition keys, different NaNs should be treated as same, `-0.0` and `0.0` + // should be treated as same. + // + // Case 1 is fine, as we handle NaN and -0.0 well during comparison. For complex types, we + // recursively compare the fields/elements, so it's also fine. + // + // Case 2, 3 and 4 are problematic, as they compare `UnsafeRow` binary directly, and different + // NaNs have different binary representation, and the same thing happens for -0.0 and 0.0. + // + // Here we normalize NaN and -0.0, so that `UnsafeProjection` will normalize them when writing + // float/double columns and nested fields to `UnsafeRow`. + // + // Note that, we must do this for all the `UnsafeProjection`s, not only the ones that extract + // join/grouping/window partition keys. `UnsafeProjection` copies unsafe data directly for complex + // types, so nested float/double may not be normalized. We need to make sure that all the unsafe + // data(`UnsafeRow`, `UnsafeArrayData`, `UnsafeMapData`) will have flat/double normalized during + // creation. protected final void writeFloat(long offset, float value) { + if (Float.isNaN(value)) { + value = Float.NaN; + } else if (value == -0.0f) { + value = 0.0f; + } Platform.putFloat(getBuffer(), offset, value); } + // See comments for `writeFloat`. protected final void writeDouble(long offset, double value) { + if (Double.isNaN(value)) { + value = Double.NaN; + } else if (value == -0.0d) { + value = 0.0d; + } Platform.putDouble(getBuffer(), offset, value); } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala index fb651b76fc16d..22e1fa6dfed4f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala @@ -50,4 +50,24 @@ class UnsafeRowWriterSuite extends SparkFunSuite { assert(res1 == res2) } + test("SPARK-26021: normalize float/double NaN and -0.0") { + val unsafeRowWriter1 = new UnsafeRowWriter(4) + unsafeRowWriter1.resetRowWriter() + unsafeRowWriter1.write(0, Float.NaN) + unsafeRowWriter1.write(1, Double.NaN) + unsafeRowWriter1.write(2, 0.0f) + unsafeRowWriter1.write(3, 0.0) + val res1 = unsafeRowWriter1.getRow + + val unsafeRowWriter2 = new UnsafeRowWriter(4) + unsafeRowWriter2.resetRowWriter() + unsafeRowWriter2.write(0, 0.0f/0.0f) + unsafeRowWriter2.write(1, 0.0/0.0) + unsafeRowWriter2.write(2, -0.0f) + unsafeRowWriter2.write(3, -0.0) + val res2 = unsafeRowWriter2.getRow + + // The two rows should be the equal + assert(res1 == res2) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index e6b30f9956daf..c9f41ab1c0179 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -295,4 +295,16 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan } } + + test("NaN and -0.0 in join keys") { + val df1 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d") + val df2 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d") + val joined = df1.join(df2, Seq("f", "d")) + checkAnswer(joined, Seq( + Row(Float.NaN, Double.NaN), + Row(0.0f, 0.0), + Row(0.0f, 0.0), + Row(0.0f, 0.0), + Row(0.0f, 0.0))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 78277d7dcf757..9a5d5a9966ab7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -681,4 +681,18 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Row("S2", "P2", 300, 300, 500))) } + + test("NaN and -0.0 in window partition keys") { + val df = Seq( + (Float.NaN, Double.NaN, 1), + (0.0f/0.0f, 0.0/0.0, 1), + (0.0f, 0.0, 1), + (-0.0f, -0.0, 1)).toDF("f", "d", "i") + val result = df.select($"f", count("i").over(Window.partitionBy("f", "d"))) + checkAnswer(result, Seq( + Row(Float.NaN, 2), + Row(Float.NaN, 2), + Row(0.0f, 2), + Row(0.0f, 2))) + } } From 55276d3a26474e7479941db3e9c065d86344885f Mon Sep 17 00:00:00 2001 From: seancxmao Date: Sat, 8 Dec 2018 17:53:12 -0800 Subject: [PATCH 2241/2461] [SPARK-25132][SQL][FOLLOWUP][DOC] Add migration doc for case-insensitive field resolution when reading from Parquet ## What changes were proposed in this pull request? #22148 introduces a behavior change. According to discussion at #22184, this PR updates migration guide when upgrade from Spark 2.3 to 2.4. ## How was this patch tested? N/A Closes #23238 from seancxmao/SPARK-25132-doc-2.4. Authored-by: seancxmao Signed-off-by: Dongjoon Hyun --- docs/sql-migration-guide-upgrade.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 67c30fb941ecd..f6458a9b2730b 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -145,6 +145,8 @@ displayTitle: Spark SQL Upgrading Guide - In Spark version 2.3 and earlier, HAVING without GROUP BY is treated as WHERE. This means, `SELECT 1 FROM range(10) HAVING true` is executed as `SELECT 1 FROM range(10) WHERE true` and returns 10 rows. This violates SQL standard, and has been fixed in Spark 2.4. Since Spark 2.4, HAVING without GROUP BY is treated as a global aggregate, which means `SELECT 1 FROM range(10) HAVING true` will return only one row. To restore the previous behavior, set `spark.sql.legacy.parser.havingWithoutGroupByAsWhere` to `true`. + - In version 2.3 and earlier, when reading from a Parquet data source table, Spark always returns null for any column whose column names in Hive metastore schema and Parquet schema are in different letter cases, no matter whether `spark.sql.caseSensitive` is set to `true` or `false`. Since 2.4, when `spark.sql.caseSensitive` is set to `false`, Spark does case insensitive column name resolution between Hive metastore schema and Parquet schema, so even column names are in different letter cases, Spark returns corresponding column values. An exception is thrown if there is ambiguity, i.e. more than one Parquet column is matched. This change also applies to Parquet Hive tables when `spark.sql.hive.convertMetastoreParquet` is set to `true`. + ## Upgrading From Spark SQL 2.3.0 to 2.3.1 and above - As of version 2.3.1 Arrow functionality, including `pandas_udf` and `toPandas()`/`createDataFrame()` with `spark.sql.execution.arrow.enabled` set to `True`, has been marked as experimental. These are still evolving and not currently recommended for use in production. From 877f82cb30bc4edef770b36e1e394a887ab535c6 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Sun, 9 Dec 2018 10:49:15 +0800 Subject: [PATCH 2242/2461] [SPARK-26193][SQL] Implement shuffle write metrics in SQL ## What changes were proposed in this pull request? 1. Implement `SQLShuffleWriteMetricsReporter` on the SQL side as the customized `ShuffleWriteMetricsReporter`. 2. Add shuffle write metrics to `ShuffleExchangeExec`, and use these metrics to create corresponding `SQLShuffleWriteMetricsReporter` in shuffle dependency. 3. Rework on `ShuffleMapTask` to add new class named `ShuffleWriteProcessor` which control shuffle write process, we use sql shuffle write metrics by customizing a ShuffleWriteProcessor on SQL side. ## How was this patch tested? Add UT in SQLMetricsSuite. Manually test locally, update screen shot to document attached in JIRA. Closes #23207 from xuanyuanking/SPARK-26193. Authored-by: Yuanjian Li Signed-off-by: Wenchen Fan --- .../scala/org/apache/spark/Dependency.scala | 6 +- .../spark/scheduler/ShuffleMapTask.scala | 20 +---- .../spark/shuffle/ShuffleWriteProcessor.scala | 74 +++++++++++++++++++ project/MimaExcludes.scala | 3 + .../exchange/ShuffleExchangeExec.scala | 38 ++++++++-- .../apache/spark/sql/execution/limit.scala | 30 ++++++-- .../sql/execution/metric/SQLMetrics.scala | 12 +++ .../metric/SQLShuffleMetricsReporter.scala | 55 ++++++++++++++ .../execution/metric/SQLMetricsSuite.scala | 36 +++++++-- 9 files changed, 234 insertions(+), 40 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 9ea6d2fa2fd95..fb051a8c0db8e 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.ShuffleHandle +import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor} /** * :: DeveloperApi :: @@ -65,6 +65,7 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] { * @param keyOrdering key ordering for RDD's shuffles * @param aggregator map/reduce-side aggregator for RDD's shuffle * @param mapSideCombine whether to perform partial aggregation (also known as map-side combine) + * @param shuffleWriterProcessor the processor to control the write behavior in ShuffleMapTask */ @DeveloperApi class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( @@ -73,7 +74,8 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val serializer: Serializer = SparkEnv.get.serializer, val keyOrdering: Option[Ordering[K]] = None, val aggregator: Option[Aggregator[K, V, C]] = None, - val mapSideCombine: Boolean = false) + val mapSideCombine: Boolean = false, + val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor) extends Dependency[Product2[K, V]] { if (mapSideCombine) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 5412717d61988..2a8d1dd995e27 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -92,25 +92,7 @@ private[spark] class ShuffleMapTask( threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime } else 0L - var writer: ShuffleWriter[Any, Any] = null - try { - val manager = SparkEnv.get.shuffleManager - writer = manager.getWriter[Any, Any]( - dep.shuffleHandle, partitionId, context, context.taskMetrics().shuffleWriteMetrics) - writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) - writer.stop(success = true).get - } catch { - case e: Exception => - try { - if (writer != null) { - writer.stop(success = false) - } - } catch { - case e: Exception => - log.debug("Could not stop writer", e) - } - throw e - } + dep.shuffleWriterProcessor.writeProcess(rdd, dep, partitionId, context, partition) } override def preferredLocations: Seq[TaskLocation] = preferredLocs diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala new file mode 100644 index 0000000000000..f5213157a9a85 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import org.apache.spark.{Partition, ShuffleDependency, SparkEnv, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.scheduler.MapStatus + +/** + * The interface for customizing shuffle write process. The driver create a ShuffleWriteProcessor + * and put it into [[ShuffleDependency]], and executors use it in each ShuffleMapTask. + */ +private[spark] class ShuffleWriteProcessor extends Serializable with Logging { + + /** + * Create a [[ShuffleWriteMetricsReporter]] from the task context. As the reporter is a + * per-row operator, here need a careful consideration on performance. + */ + protected def createMetricsReporter(context: TaskContext): ShuffleWriteMetricsReporter = { + context.taskMetrics().shuffleWriteMetrics + } + + /** + * The write process for particular partition, it controls the life circle of [[ShuffleWriter]] + * get from [[ShuffleManager]] and triggers rdd compute, finally return the [[MapStatus]] for + * this task. + */ + def writeProcess( + rdd: RDD[_], + dep: ShuffleDependency[_, _, _], + partitionId: Int, + context: TaskContext, + partition: Partition): MapStatus = { + var writer: ShuffleWriter[Any, Any] = null + try { + val manager = SparkEnv.get.shuffleManager + writer = manager.getWriter[Any, Any]( + dep.shuffleHandle, + partitionId, + context, + createMetricsReporter(context)) + writer.write( + rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) + writer.stop(success = true).get + } catch { + case e: Exception => + try { + if (writer != null) { + writer.stop(success = false) + } + } catch { + case e: Exception => + log.debug("Could not stop writer", e) + } + throw e + } + } +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 4eeebb805070a..b3252d70a80c8 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -217,6 +217,9 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.UnsafeShuffleWriter.this"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.TimeTrackingOutputStream.this"), + // [SPARK-26139] Implement shuffle write metrics in SQL + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ShuffleDependency.this"), + // Data Source V2 API changes (problem: Problem) => problem match { case MissingClassProblem(cls) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index c9ca395bceaa4..0c2020572e721 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -23,6 +23,7 @@ import java.util.function.Supplier import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProcessor} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ @@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Uns import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.metric.{SQLMetrics, SQLShuffleMetricsReporter} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleMetricsReporter, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.util.MutablePair @@ -46,10 +47,13 @@ case class ShuffleExchangeExec( // NOTE: coordinator can be null after serialization/deserialization, // e.g. it can be null on the Executor side - + private lazy val writeMetrics = + SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) + private lazy val readMetrics = + SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext) override lazy val metrics = Map( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size") - ) ++ SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext) + ) ++ readMetrics ++ writeMetrics override def nodeName: String = { val extraInfo = coordinator match { @@ -90,7 +94,11 @@ case class ShuffleExchangeExec( private[exchange] def prepareShuffleDependency() : ShuffleDependency[Int, InternalRow, InternalRow] = { ShuffleExchangeExec.prepareShuffleDependency( - child.execute(), child.output, newPartitioning, serializer) + child.execute(), + child.output, + newPartitioning, + serializer, + writeMetrics) } /** @@ -109,7 +117,7 @@ case class ShuffleExchangeExec( assert(newPartitioning.isInstanceOf[HashPartitioning]) newPartitioning = UnknownPartitioning(indices.length) } - new ShuffledRowRDD(shuffleDependency, metrics, specifiedPartitionStartIndices) + new ShuffledRowRDD(shuffleDependency, readMetrics, specifiedPartitionStartIndices) } /** @@ -204,7 +212,9 @@ object ShuffleExchangeExec { rdd: RDD[InternalRow], outputAttributes: Seq[Attribute], newPartitioning: Partitioning, - serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = { + serializer: Serializer, + writeMetrics: Map[String, SQLMetric]) + : ShuffleDependency[Int, InternalRow, InternalRow] = { val part: Partitioner = newPartitioning match { case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions) case HashPartitioning(_, n) => @@ -333,8 +343,22 @@ object ShuffleExchangeExec { new ShuffleDependency[Int, InternalRow, InternalRow]( rddWithPartitionIds, new PartitionIdPassthrough(part.numPartitions), - serializer) + serializer, + shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics)) dependency } + + /** + * Create a customized [[ShuffleWriteProcessor]] for SQL which wrap the default metrics reporter + * with [[SQLShuffleWriteMetricsReporter]] as new reporter for [[ShuffleWriteProcessor]]. + */ + def createShuffleWriteProcessor(metrics: Map[String, SQLMetric]): ShuffleWriteProcessor = { + new ShuffleWriteProcessor { + override protected def createMetricsReporter( + context: TaskContext): ShuffleWriteMetricsReporter = { + new SQLShuffleWriteMetricsReporter(context.taskMetrics().shuffleWriteMetrics, metrics) + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index e9ab7cd138d99..1f2fdde538645 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec -import org.apache.spark.sql.execution.metric.SQLShuffleMetricsReporter +import org.apache.spark.sql.execution.metric.{SQLShuffleMetricsReporter, SQLShuffleWriteMetricsReporter} /** * Take the first `limit` elements and collect them to a single partition. @@ -38,13 +38,21 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode override def outputPartitioning: Partitioning = SinglePartition override def executeCollect(): Array[InternalRow] = child.executeTake(limit) private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) - override lazy val metrics = SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext) + private lazy val writeMetrics = + SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) + private lazy val readMetrics = + SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext) + override lazy val metrics = readMetrics ++ writeMetrics protected override def doExecute(): RDD[InternalRow] = { val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit)) val shuffled = new ShuffledRowRDD( ShuffleExchangeExec.prepareShuffleDependency( - locallyLimited, child.output, SinglePartition, serializer), - metrics) + locallyLimited, + child.output, + SinglePartition, + serializer, + writeMetrics), + readMetrics) shuffled.mapPartitionsInternal(_.take(limit)) } } @@ -154,7 +162,11 @@ case class TakeOrderedAndProjectExec( private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) - override lazy val metrics = SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext) + private lazy val writeMetrics = + SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) + private lazy val readMetrics = + SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext) + override lazy val metrics = readMetrics ++ writeMetrics protected override def doExecute(): RDD[InternalRow] = { val ord = new LazilyGeneratedOrdering(sortOrder, child.output) @@ -165,8 +177,12 @@ case class TakeOrderedAndProjectExec( } val shuffled = new ShuffledRowRDD( ShuffleExchangeExec.prepareShuffleDependency( - localTopK, child.output, SinglePartition, serializer), - metrics) + localTopK, + child.output, + SinglePartition, + serializer, + writeMetrics), + readMetrics) shuffled.mapPartitions { iter => val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord) if (projectList != child.output) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index cbf707f4a9cfd..19809b07508d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.metric import java.text.NumberFormat import java.util.Locale +import scala.concurrent.duration._ + import org.apache.spark.SparkContext import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates @@ -78,6 +80,7 @@ object SQLMetrics { private val SUM_METRIC = "sum" private val SIZE_METRIC = "size" private val TIMING_METRIC = "timing" + private val NS_TIMING_METRIC = "nsTiming" private val AVERAGE_METRIC = "average" private val baseForAvgMetric: Int = 10 @@ -121,6 +124,13 @@ object SQLMetrics { acc } + def createNanoTimingMetric(sc: SparkContext, name: String): SQLMetric = { + // Same with createTimingMetric, just normalize the unit of time to millisecond. + val acc = new SQLMetric(NS_TIMING_METRIC, -1) + acc.register(sc, name = Some(s"$name total (min, med, max)"), countFailedValues = false) + acc + } + /** * Create a metric to report the average information (including min, med, max) like * avg hash probe. As average metrics are double values, this kind of metrics should be @@ -163,6 +173,8 @@ object SQLMetrics { Utils.bytesToString } else if (metricsType == TIMING_METRIC) { Utils.msDurationToString + } else if (metricsType == NS_TIMING_METRIC) { + duration => Utils.msDurationToString(duration.nanos.toMillis) } else { throw new IllegalStateException("unexpected metrics type: " + metricsType) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala index 780f0d7622294..ff7941e3b3e8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.metric import org.apache.spark.SparkContext import org.apache.spark.executor.TempShuffleReadMetrics +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter /** * A shuffle metrics reporter for SQL exchange operators. @@ -95,3 +96,57 @@ private[spark] object SQLShuffleMetricsReporter { FETCH_WAIT_TIME -> SQLMetrics.createTimingMetric(sc, "fetch wait time"), RECORDS_READ -> SQLMetrics.createMetric(sc, "records read")) } + +/** + * A shuffle write metrics reporter for SQL exchange operators. + * @param metricsReporter Other reporter need to be updated in this SQLShuffleWriteMetricsReporter. + * @param metrics Shuffle write metrics in current SparkPlan. + */ +private[spark] class SQLShuffleWriteMetricsReporter( + metricsReporter: ShuffleWriteMetricsReporter, + metrics: Map[String, SQLMetric]) extends ShuffleWriteMetricsReporter { + private[this] val _bytesWritten = + metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_BYTES_WRITTEN) + private[this] val _recordsWritten = + metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN) + private[this] val _writeTime = + metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_WRITE_TIME) + + override private[spark] def incBytesWritten(v: Long): Unit = { + metricsReporter.incBytesWritten(v) + _bytesWritten.add(v) + } + override private[spark] def decRecordsWritten(v: Long): Unit = { + metricsReporter.decBytesWritten(v) + _recordsWritten.set(_recordsWritten.value - v) + } + override private[spark] def incRecordsWritten(v: Long): Unit = { + metricsReporter.incRecordsWritten(v) + _recordsWritten.add(v) + } + override private[spark] def incWriteTime(v: Long): Unit = { + metricsReporter.incWriteTime(v) + _writeTime.add(v) + } + override private[spark] def decBytesWritten(v: Long): Unit = { + metricsReporter.decBytesWritten(v) + _bytesWritten.set(_bytesWritten.value - v) + } +} + +private[spark] object SQLShuffleWriteMetricsReporter { + val SHUFFLE_BYTES_WRITTEN = "shuffleBytesWritten" + val SHUFFLE_RECORDS_WRITTEN = "shuffleRecordsWritten" + val SHUFFLE_WRITE_TIME = "shuffleWriteTime" + + /** + * Create all shuffle write relative metrics and return the Map. + */ + def createShuffleWriteMetrics(sc: SparkContext): Map[String, SQLMetric] = Map( + SHUFFLE_BYTES_WRITTEN -> + SQLMetrics.createSizeMetric(sc, "shuffle bytes written"), + SHUFFLE_RECORDS_WRITTEN -> + SQLMetrics.createMetric(sc, "shuffle records written"), + SHUFFLE_WRITE_TIME -> + SQLMetrics.createNanoTimingMetric(sc, "shuffle write time")) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 0f1d08b6af5d5..2251607e76af8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -97,7 +97,8 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared val shuffleExpected1 = Map( "records read" -> 2L, "local blocks fetched" -> 2L, - "remote blocks fetched" -> 0L) + "remote blocks fetched" -> 0L, + "shuffle records written" -> 2L) testSparkPlanMetrics(df, 1, Map( 2L -> (("HashAggregate", expected1(0))), 1L -> (("Exchange", shuffleExpected1)), @@ -114,7 +115,8 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared val shuffleExpected2 = Map( "records read" -> 4L, "local blocks fetched" -> 4L, - "remote blocks fetched" -> 0L) + "remote blocks fetched" -> 0L, + "shuffle records written" -> 4L) testSparkPlanMetrics(df2, 1, Map( 2L -> (("HashAggregate", expected2(0))), 1L -> (("Exchange", shuffleExpected2)), @@ -170,6 +172,11 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared val df = testData2.groupBy().agg(collect_set('a)) // 2 partitions testSparkPlanMetrics(df, 1, Map( 2L -> (("ObjectHashAggregate", Map("number of output rows" -> 2L))), + 1L -> (("Exchange", Map( + "shuffle records written" -> 2L, + "records read" -> 2L, + "local blocks fetched" -> 2L, + "remote blocks fetched" -> 0L))), 0L -> (("ObjectHashAggregate", Map("number of output rows" -> 1L)))) ) @@ -177,6 +184,11 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared val df2 = testData2.groupBy('a).agg(collect_set('a)) testSparkPlanMetrics(df2, 1, Map( 2L -> (("ObjectHashAggregate", Map("number of output rows" -> 4L))), + 1L -> (("Exchange", Map( + "shuffle records written" -> 4L, + "records read" -> 4L, + "local blocks fetched" -> 4L, + "remote blocks fetched" -> 0L))), 0L -> (("ObjectHashAggregate", Map("number of output rows" -> 3L)))) ) } @@ -205,7 +217,8 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared 2L -> (("Exchange", Map( "records read" -> 4L, "local blocks fetched" -> 2L, - "remote blocks fetched" -> 0L)))) + "remote blocks fetched" -> 0L, + "shuffle records written" -> 2L)))) ) } } @@ -299,12 +312,25 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df2 = (1 to 10).map(i => (i, i.toString)).toSeq.toDF("key", "value") // Assume the execution plan is - // ... -> ShuffledHashJoin(nodeId = 1) -> Project(nodeId = 0) + // Project(nodeId = 0) + // +- ShuffledHashJoin(nodeId = 1) + // :- Exchange(nodeId = 2) + // : +- Project(nodeId = 3) + // : +- LocalTableScan(nodeId = 4) + // +- Exchange(nodeId = 5) + // +- Project(nodeId = 6) + // +- LocalTableScan(nodeId = 7) val df = df1.join(df2, "key") testSparkPlanMetrics(df, 1, Map( 1L -> (("ShuffledHashJoin", Map( "number of output rows" -> 2L, - "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")))) + "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"))), + 2L -> (("Exchange", Map( + "shuffle records written" -> 2L, + "records read" -> 2L))), + 5L -> (("Exchange", Map( + "shuffle records written" -> 10L, + "records read" -> 10L)))) ) } } From ec506bd30c2ca324c12c9ec811764081c2eb8c42 Mon Sep 17 00:00:00 2001 From: Shahid Date: Sun, 9 Dec 2018 11:44:16 -0600 Subject: [PATCH 2243/2461] [SPARK-26283][CORE] Enable reading from open frames of zstd, when reading zstd compressed eventLog ## What changes were proposed in this pull request? Root cause: Prior to Spark2.4, When we enable zst for eventLog compression, for inprogress application, It always throws exception in the Application UI, when we open from the history server. But after 2.4 it will display the UI information based on the completed frames in the zstd compressed eventLog. But doesn't read incomplete frames for inprogress application. In this PR, we have added 'setContinous(true)' for reading input stream from eventLog, so that it can read from open frames also. (By default 'isContinous=false' for zstd inputStream and when we try to read an open frame, it throws truncated error) ## How was this patch tested? Test steps: 1) Add the configurations in the spark-defaults.conf (i) spark.eventLog.compress true (ii) spark.io.compression.codec zstd 2) Restart history server 3) bin/spark-shell 4) sc.parallelize(1 to 1000, 1000).count 5) Open app UI from the history server UI **Before fix** ![screenshot from 2018-12-06 00-01-38](https://user-images.githubusercontent.com/23054875/49537340-bfe28b00-f8ee-11e8-9fca-6d42fdc89e1a.png) **After fix:** ![screenshot from 2018-12-06 00-34-39](https://user-images.githubusercontent.com/23054875/49537353-ca9d2000-f8ee-11e8-803d-645897b9153b.png) Closes #23241 from shahidki31/zstdEventLog. Authored-by: Shahid Signed-off-by: Sean Owen --- .../scala/org/apache/spark/io/CompressionCodec.scala | 12 ++++++++++++ .../spark/scheduler/EventLoggingListener.scala | 2 +- .../apache/spark/scheduler/ReplayListenerBus.scala | 2 -- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 0664c5ac752c1..c4f4b18769d2b 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -43,6 +43,10 @@ trait CompressionCodec { def compressedOutputStream(s: OutputStream): OutputStream def compressedInputStream(s: InputStream): InputStream + + private[spark] def compressedContinuousInputStream(s: InputStream): InputStream = { + compressedInputStream(s) + } } private[spark] object CompressionCodec { @@ -197,4 +201,12 @@ class ZStdCompressionCodec(conf: SparkConf) extends CompressionCodec { // avoid overhead excessive of JNI call while trying to uncompress small amount of data. new BufferedInputStream(new ZstdInputStream(s), bufferSize) } + + override def compressedContinuousInputStream(s: InputStream): InputStream = { + // SPARK-26283: Enable reading from open frames of zstd (for eg: zstd compressed eventLog + // Reading). By default `isContinuous` is false, and when we try to read from open frames, + // `compressedInputStream` method above throws truncated error exception. This method set + // `isContinuous` true to allow reading from open frames. + new BufferedInputStream(new ZstdInputStream(s).setContinuous(true), bufferSize) + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 5f697fe99258d..069a91f1a8fc8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -402,7 +402,7 @@ private[spark] object EventLoggingListener extends Logging { val codec = codecName(log).map { c => codecMap.getOrElseUpdate(c, CompressionCodec.createCodec(new SparkConf, c)) } - codec.map(_.compressedInputStream(in)).getOrElse(in) + codec.map(_.compressedContinuousInputStream(in)).getOrElse(in) } catch { case e: Throwable => in.close() diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index 4c6b0c1227b18..226c23733c870 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -118,8 +118,6 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { case e: HaltReplayException => // Just stop replay. case _: EOFException if maybeTruncated => - case _: IOException if maybeTruncated => - logWarning(s"Failed to read Spark event log: $sourceName") case ioe: IOException => throw ioe case e: Exception => From 403c8d5a60b2712561044e320d6f9233ed3172bf Mon Sep 17 00:00:00 2001 From: 10087686 Date: Sun, 9 Dec 2018 22:44:41 -0800 Subject: [PATCH 2244/2461] [SPARK-26287][CORE] Don't need to create an empty spill file when memory has no records ## What changes were proposed in this pull request? If there are no records in memory, then we don't need to create an empty temp spill file. ## How was this patch tested? Existing tests (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #23225 from wangjiaochun/ShufflSorter. Authored-by: 10087686 Signed-off-by: Dongjoon Hyun --- .../spark/shuffle/sort/ShuffleExternalSorter.java | 13 +++++++++---- .../shuffle/sort/UnsafeShuffleWriterSuite.java | 1 + 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 6ee9d5f0eec3b..dc43215373e11 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -145,6 +145,15 @@ final class ShuffleExternalSorter extends MemoryConsumer { */ private void writeSortedFile(boolean isLastFile) { + // This call performs the actual sort. + final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords = + inMemSorter.getSortedIterator(); + + // If there are no sorted records, so we don't need to create an empty spill file. + if (!sortedRecords.hasNext()) { + return; + } + final ShuffleWriteMetricsReporter writeMetricsToUse; if (isLastFile) { @@ -157,10 +166,6 @@ private void writeSortedFile(boolean isLastFile) { writeMetricsToUse = new ShuffleWriteMetrics(); } - // This call performs the actual sort. - final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords = - inMemSorter.getSortedIterator(); - // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to // be an API to directly transfer bytes from managed memory to the disk writer, we buffer // data through a byte array. This array does not need to be large enough to hold a single diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 30ad3f5575545..aa5082f1ac7ff 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -235,6 +235,7 @@ public void writeEmptyIterator() throws Exception { final Option mapStatus = writer.stop(true); assertTrue(mapStatus.isDefined()); assertTrue(mergedOutputFile.exists()); + assertEquals(0, spillFilesCreated.size()); assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile); assertEquals(0, taskMetrics.shuffleWriteMetrics().recordsWritten()); assertEquals(0, taskMetrics.shuffleWriteMetrics().bytesWritten()); From 3bc83de3cce86a06c275c86b547a99afd781761f Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 10 Dec 2018 14:57:20 +0800 Subject: [PATCH 2245/2461] [SPARK-26307][SQL] Fix CTAS when INSERT a partitioned table using Hive serde ## What changes were proposed in this pull request? This is a Spark 2.3 regression introduced in https://github.com/apache/spark/pull/20521. We should add the partition info for InsertIntoHiveTable in CreateHiveTableAsSelectCommand. Otherwise, we will hit the following error by running the newly added test case: ``` [info] - CTAS: INSERT a partitioned table using Hive serde *** FAILED *** (829 milliseconds) [info] org.apache.spark.SparkException: Requested partitioning does not match the tab1 table: [info] Requested partitions: [info] Table partitions: part [info] at org.apache.spark.sql.hive.execution.InsertIntoHiveTable.processInsert(InsertIntoHiveTable.scala:179) [info] at org.apache.spark.sql.hive.execution.InsertIntoHiveTable.run(InsertIntoHiveTable.scala:107) ``` ## How was this patch tested? Added a test case. Closes #23255 from gatorsmile/fixCTAS. Authored-by: gatorsmile Signed-off-by: Wenchen Fan --- .../execution/CreateHiveTableAsSelectCommand.scala | 4 +++- .../scala/org/apache/spark/sql/hive/InsertSuite.scala | 11 +++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 630bea5161f19..fd1e931ee0c7a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -57,9 +57,11 @@ case class CreateHiveTableAsSelectCommand( return Seq.empty } + // For CTAS, there is no static partition values to insert. + val partition = tableDesc.partitionColumnNames.map(_ -> None).toMap InsertIntoHiveTable( tableDesc, - Map.empty, + partition, query, overwrite = false, ifPartitionNotExists = false, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala index 5879748d05b2b..510de3a7eab57 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala @@ -752,6 +752,17 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter } } + test("SPARK-26307: CTAS - INSERT a partitioned table using Hive serde") { + withTable("tab1") { + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + val df = Seq(("a", 100)).toDF("part", "id") + df.write.format("hive").partitionBy("part").mode("overwrite").saveAsTable("tab1") + df.write.format("hive").partitionBy("part").mode("append").saveAsTable("tab1") + } + } + } + + Seq("LOCAL", "").foreach { local => Seq(true, false).foreach { caseSensitivity => Seq("orc", "parquet").foreach { format => From c8ac6ae84c8a3cc7ed787155a84cbeb56c78a048 Mon Sep 17 00:00:00 2001 From: Darcy Shen Date: Mon, 10 Dec 2018 22:26:28 +0800 Subject: [PATCH 2246/2461] [SPARK-26319][SQL][TEST] Add appendReadColumns Unit Test for HiveShimSuite ## What changes were proposed in this pull request? Add appendReadColumns Unit Test for HiveShimSuite. ## How was this patch tested? ``` $ build/sbt > project hive > testOnly *HiveShimSuite ``` Closes #23268 from sadhen/refactor/hiveshim. Authored-by: Darcy Shen Signed-off-by: Hyukjin Kwon --- .../apache/spark/sql/hive/HiveShimSuite.scala | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShimSuite.scala diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShimSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShimSuite.scala new file mode 100644 index 0000000000000..a716f739b5c20 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShimSuite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.hive + +import scala.collection.JavaConverters._ +import scala.language.implicitConversions + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.serde2.ColumnProjectionUtils + +import org.apache.spark.SparkFunSuite + +class HiveShimSuite extends SparkFunSuite { + + test("appendReadColumns") { + val conf = new Configuration + val ids = Seq(1, 2, 3).map(Int.box) + val names = Seq("a", "b", "c") + val moreIds = Seq(4, 5).map(Int.box) + val moreNames = Seq("d", "e") + + // test when READ_COLUMN_NAMES_CONF_STR is empty + HiveShim.appendReadColumns(conf, ids, names) + assert(names.asJava === ColumnProjectionUtils.getReadColumnNames(conf)) + + // test when READ_COLUMN_NAMES_CONF_STR is non-empty + HiveShim.appendReadColumns(conf, moreIds, moreNames) + assert((names ++ moreNames).asJava === ColumnProjectionUtils.getReadColumnNames(conf)) + } +} From 42e8c381b15dd48c2f00c088c897ebdd25405aef Mon Sep 17 00:00:00 2001 From: 10087686 Date: Mon, 10 Dec 2018 22:28:26 +0800 Subject: [PATCH 2247/2461] [SPARK-26286][TEST] Add MAXIMUM_PAGE_SIZE_BYTES exception bound unit test ## What changes were proposed in this pull request? Add MAXIMUM_PAGE_SIZE_BYTES Exception test ## How was this patch tested? Existing tests (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #23226 from wangjiaochun/BytesToBytesMapSuite. Authored-by: 10087686 Signed-off-by: Hyukjin Kwon --- .../unsafe/map/AbstractBytesToBytesMapSuite.java | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index aa29232e73e13..a11cd535b5471 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -622,6 +622,17 @@ public void initialCapacityBoundsChecking() { } catch (IllegalArgumentException e) { // expected exception } + + try { + new BytesToBytesMap( + taskMemoryManager, + 1, + TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES + 1); + Assert.fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + // expected exception + } + } @Test From 9794923272c26ee5ba760a57718a368c33d09f04 Mon Sep 17 00:00:00 2001 From: liuxian Date: Mon, 10 Dec 2018 22:37:17 +0800 Subject: [PATCH 2248/2461] [MINOR][DOC] Update the condition description of serialized shuffle ## What changes were proposed in this pull request? `1. The shuffle dependency specifies no aggregation or output ordering.` If the shuffle dependency specifies aggregation, but it only aggregates at the reduce-side, serialized shuffle can still be used. `3. The shuffle produces fewer than 16777216 output partitions.` If the number of output partitions is 16777216 , we can use serialized shuffle. We can see this mothod: `canUseSerializedShuffle` ## How was this patch tested? N/A Closes #23228 from 10110346/SerializedShuffle_doc. Authored-by: liuxian Signed-off-by: Wenchen Fan --- .../org/apache/spark/shuffle/sort/SortShuffleManager.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index b51a843a31c31..b59fa8e8a3ccd 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -33,10 +33,10 @@ import org.apache.spark.shuffle._ * Sort-based shuffle has two different write paths for producing its map output files: * * - Serialized sorting: used when all three of the following conditions hold: - * 1. The shuffle dependency specifies no aggregation or output ordering. + * 1. The shuffle dependency specifies no map-side combine. * 2. The shuffle serializer supports relocation of serialized values (this is currently * supported by KryoSerializer and Spark SQL's custom serializers). - * 3. The shuffle produces fewer than 16777216 output partitions. + * 3. The shuffle produces fewer than or equal to 16777216 output partitions. * - Deserialized sorting: used to handle all other cases. * * ----------------------- From 0bf6c77141e40cc636351c5e77194bb75144bb12 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 10 Dec 2018 10:27:04 -0600 Subject: [PATCH 2249/2461] This tests pushing to gitbox From b1a724b468d5c1c4aee2a22ffc6d8edac537c3d6 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 10 Dec 2018 11:11:06 -0600 Subject: [PATCH 2250/2461] This tests pushing directly to github From 90c77ea3132d0b7a12c316bd42fb8d0f59bee253 Mon Sep 17 00:00:00 2001 From: Reza Safi Date: Mon, 10 Dec 2018 11:14:11 -0600 Subject: [PATCH 2251/2461] [SPARK-24958][CORE] Add memory from procfs to executor metrics. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This adds the entire memory used by spark’s executor (as measured by procfs) to the executor metrics. The memory usage is collected from the entire process tree under the executor. The metrics are subdivided into memory used by java, by python, and by other processes, to aid users in diagnosing the source of high memory usage. The additional metrics are sent to the driver in heartbeats, using the mechanism introduced by SPARK-23429. This also slightly extends that approach to allow one ExecutorMetricType to collect multiple metrics. Added unit tests and also tested on a live cluster. Closes #22612 from rezasafi/ptreememory2. Authored-by: Reza Safi Signed-off-by: Imran Rashid --- .../scala/org/apache/spark/Heartbeater.scala | 11 +- .../spark/executor/ExecutorMetrics.scala | 23 +- .../spark/executor/ProcfsMetricsGetter.scala | 228 ++++++++++++++++++ .../spark/internal/config/package.scala | 5 + .../spark/metrics/ExecutorMetricType.scala | 74 +++++- .../org/apache/spark/status/api/v1/api.scala | 6 +- .../org/apache/spark/util/JsonProtocol.scala | 16 +- .../application_list_json_expectation.json | 15 ++ .../completed_app_list_json_expectation.json | 15 ++ ...ith_executor_metrics_json_expectation.json | 40 ++- ...process_tree_metrics_json_expectation.json | 98 ++++++++ .../limit_app_list_json_expectation.json | 30 +-- .../minDate_app_list_json_expectation.json | 15 ++ .../minEndDate_app_list_json_expectation.json | 15 ++ .../test/resources/ProcfsMetrics/22763/stat | 1 + .../test/resources/ProcfsMetrics/26109/stat | 1 + .../application_1538416563558_0014 | 190 +++++++++++++++ .../deploy/history/HistoryServerSuite.scala | 3 + .../executor/ProcfsMetricsGetterSuite.scala | 41 ++++ .../scheduler/EventLoggingListenerSuite.scala | 85 ++++--- .../spark/status/AppStatusListenerSuite.scala | 74 ++++-- .../apache/spark/util/JsonProtocolSuite.scala | 46 ++-- dev/.rat-excludes | 2 + 23 files changed, 901 insertions(+), 133 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/executor/ProcfsMetricsGetter.scala create mode 100644 core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_process_tree_metrics_json_expectation.json create mode 100644 core/src/test/resources/ProcfsMetrics/22763/stat create mode 100644 core/src/test/resources/ProcfsMetrics/26109/stat create mode 100644 core/src/test/resources/spark-events/application_1538416563558_0014 create mode 100644 core/src/test/scala/org/apache/spark/executor/ProcfsMetricsGetterSuite.scala diff --git a/core/src/main/scala/org/apache/spark/Heartbeater.scala b/core/src/main/scala/org/apache/spark/Heartbeater.scala index 84091eef04306..1012755e068d1 100644 --- a/core/src/main/scala/org/apache/spark/Heartbeater.scala +++ b/core/src/main/scala/org/apache/spark/Heartbeater.scala @@ -61,10 +61,17 @@ private[spark] class Heartbeater( /** * Get the current executor level metrics. These are returned as an array, with the index - * determined by ExecutorMetricType.values + * determined by ExecutorMetricType.metricToOffset */ def getCurrentMetrics(): ExecutorMetrics = { - val metrics = ExecutorMetricType.values.map(_.getMetricValue(memoryManager)).toArray + + val metrics = new Array[Long](ExecutorMetricType.numMetrics) + var offset = 0 + ExecutorMetricType.metricGetters.foreach { metric => + val newMetrics = metric.getMetricValues(memoryManager) + Array.copy(newMetrics, 0, metrics, offset, newMetrics.size) + offset += newMetrics.length + } new ExecutorMetrics(metrics) } } diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorMetrics.scala index 1befd27de1cba..f19ac813fde34 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorMetrics.scala @@ -27,17 +27,15 @@ import org.apache.spark.metrics.ExecutorMetricType */ @DeveloperApi class ExecutorMetrics private[spark] extends Serializable { - - // Metrics are indexed by ExecutorMetricType.values - private val metrics = new Array[Long](ExecutorMetricType.values.length) - + // Metrics are indexed by ExecutorMetricType.metricToOffset + private val metrics = new Array[Long](ExecutorMetricType.numMetrics) // the first element is initialized to -1, indicating that the values for the array // haven't been set yet. metrics(0) = -1 - /** Returns the value for the specified metricType. */ - def getMetricValue(metricType: ExecutorMetricType): Long = { - metrics(ExecutorMetricType.metricIdxMap(metricType)) + /** Returns the value for the specified metric. */ + def getMetricValue(metricName: String): Long = { + metrics(ExecutorMetricType.metricToOffset(metricName)) } /** Returns true if the values for the metrics have been set, false otherwise. */ @@ -49,14 +47,14 @@ class ExecutorMetrics private[spark] extends Serializable { } /** - * Constructor: create the ExecutorMetrics with the values specified. + * Constructor: create the ExecutorMetrics with using a given map. * * @param executorMetrics map of executor metric name to value */ private[spark] def this(executorMetrics: Map[String, Long]) { this() - (0 until ExecutorMetricType.values.length).foreach { idx => - metrics(idx) = executorMetrics.getOrElse(ExecutorMetricType.values(idx).name, 0L) + ExecutorMetricType.metricToOffset.foreach { case(name, idx) => + metrics(idx) = executorMetrics.getOrElse(name, 0L) } } @@ -69,9 +67,8 @@ class ExecutorMetrics private[spark] extends Serializable { */ private[spark] def compareAndUpdatePeakValues(executorMetrics: ExecutorMetrics): Boolean = { var updated = false - - (0 until ExecutorMetricType.values.length).foreach { idx => - if (executorMetrics.metrics(idx) > metrics(idx)) { + (0 until ExecutorMetricType.numMetrics).foreach { idx => + if (executorMetrics.metrics(idx) > metrics(idx)) { updated = true metrics(idx) = executorMetrics.metrics(idx) } diff --git a/core/src/main/scala/org/apache/spark/executor/ProcfsMetricsGetter.scala b/core/src/main/scala/org/apache/spark/executor/ProcfsMetricsGetter.scala new file mode 100644 index 0000000000000..af67f41e94af1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/ProcfsMetricsGetter.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import java.io._ +import java.nio.charset.Charset +import java.nio.file.{Files, Paths} +import java.util.Locale + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.Try + +import org.apache.spark.{SparkEnv, SparkException} +import org.apache.spark.internal.{config, Logging} +import org.apache.spark.util.Utils + + +private[spark] case class ProcfsMetrics( + jvmVmemTotal: Long, + jvmRSSTotal: Long, + pythonVmemTotal: Long, + pythonRSSTotal: Long, + otherVmemTotal: Long, + otherRSSTotal: Long) + +// Some of the ideas here are taken from the ProcfsBasedProcessTree class in hadoop +// project. +private[spark] class ProcfsMetricsGetter(procfsDir: String = "/proc/") extends Logging { + private val procfsStatFile = "stat" + private val testing = sys.env.contains("SPARK_TESTING") || sys.props.contains("spark.testing") + private val pageSize = computePageSize() + private var isAvailable: Boolean = isProcfsAvailable + private val pid = computePid() + + private lazy val isProcfsAvailable: Boolean = { + if (testing) { + true + } + else { + val procDirExists = Try(Files.exists(Paths.get(procfsDir))).recover { + case ioe: IOException => + logWarning("Exception checking for procfs dir", ioe) + false + } + val shouldLogStageExecutorMetrics = + SparkEnv.get.conf.get(config.EVENT_LOG_STAGE_EXECUTOR_METRICS) + val shouldLogStageExecutorProcessTreeMetrics = + SparkEnv.get.conf.get(config.EVENT_LOG_PROCESS_TREE_METRICS) + procDirExists.get && shouldLogStageExecutorProcessTreeMetrics && shouldLogStageExecutorMetrics + } + } + + private def computePid(): Int = { + if (!isAvailable || testing) { + return -1; + } + try { + // This can be simplified in java9: + // https://docs.oracle.com/javase/9/docs/api/java/lang/ProcessHandle.html + val cmd = Array("bash", "-c", "echo $PPID") + val out = Utils.executeAndGetOutput(cmd) + Integer.parseInt(out.split("\n")(0)) + } + catch { + case e: SparkException => + logWarning("Exception when trying to compute process tree." + + " As a result reporting of ProcessTree metrics is stopped", e) + isAvailable = false + -1 + } + } + + private def computePageSize(): Long = { + if (testing) { + return 4096; + } + try { + val cmd = Array("getconf", "PAGESIZE") + val out = Utils.executeAndGetOutput(cmd) + Integer.parseInt(out.split("\n")(0)) + } catch { + case e: Exception => + logWarning("Exception when trying to compute pagesize, as a" + + " result reporting of ProcessTree metrics is stopped") + isAvailable = false + 0 + } + } + + private def computeProcessTree(): Set[Int] = { + if (!isAvailable || testing) { + return Set() + } + var ptree: Set[Int] = Set() + ptree += pid + val queue = mutable.Queue.empty[Int] + queue += pid + while ( !queue.isEmpty ) { + val p = queue.dequeue() + val c = getChildPids(p) + if (!c.isEmpty) { + queue ++= c + ptree ++= c.toSet + } + } + ptree + } + + private def getChildPids(pid: Int): ArrayBuffer[Int] = { + try { + val builder = new ProcessBuilder("pgrep", "-P", pid.toString) + val process = builder.start() + val childPidsInInt = mutable.ArrayBuffer.empty[Int] + def appendChildPid(s: String): Unit = { + if (s != "") { + logTrace("Found a child pid:" + s) + childPidsInInt += Integer.parseInt(s) + } + } + val stdoutThread = Utils.processStreamByLine("read stdout for pgrep", + process.getInputStream, appendChildPid) + val errorStringBuilder = new StringBuilder() + val stdErrThread = Utils.processStreamByLine( + "stderr for pgrep", + process.getErrorStream, + line => errorStringBuilder.append(line)) + val exitCode = process.waitFor() + stdoutThread.join() + stdErrThread.join() + val errorString = errorStringBuilder.toString() + // pgrep will have exit code of 1 if there are more than one child process + // and it will have a exit code of 2 if there is no child process + if (exitCode != 0 && exitCode > 2) { + val cmd = builder.command().toArray.mkString(" ") + logWarning(s"Process $cmd exited with code $exitCode and stderr: $errorString") + throw new SparkException(s"Process $cmd exited with code $exitCode") + } + childPidsInInt + } catch { + case e: Exception => + logWarning("Exception when trying to compute process tree." + + " As a result reporting of ProcessTree metrics is stopped.", e) + isAvailable = false + mutable.ArrayBuffer.empty[Int] + } + } + + def addProcfsMetricsFromOneProcess( + allMetrics: ProcfsMetrics, + pid: Int): ProcfsMetrics = { + + // The computation of RSS and Vmem are based on proc(5): + // http://man7.org/linux/man-pages/man5/proc.5.html + try { + val pidDir = new File(procfsDir, pid.toString) + def openReader(): BufferedReader = { + val f = new File(new File(procfsDir, pid.toString), procfsStatFile) + new BufferedReader(new InputStreamReader(new FileInputStream(f), Charset.forName("UTF-8"))) + } + Utils.tryWithResource(openReader) { in => + val procInfo = in.readLine + val procInfoSplit = procInfo.split(" ") + val vmem = procInfoSplit(22).toLong + val rssMem = procInfoSplit(23).toLong * pageSize + if (procInfoSplit(1).toLowerCase(Locale.US).contains("java")) { + allMetrics.copy( + jvmVmemTotal = allMetrics.jvmVmemTotal + vmem, + jvmRSSTotal = allMetrics.jvmRSSTotal + (rssMem) + ) + } + else if (procInfoSplit(1).toLowerCase(Locale.US).contains("python")) { + allMetrics.copy( + pythonVmemTotal = allMetrics.pythonVmemTotal + vmem, + pythonRSSTotal = allMetrics.pythonRSSTotal + (rssMem) + ) + } + else { + allMetrics.copy( + otherVmemTotal = allMetrics.otherVmemTotal + vmem, + otherRSSTotal = allMetrics.otherRSSTotal + (rssMem) + ) + } + } + } catch { + case f: IOException => + logWarning("There was a problem with reading" + + " the stat file of the process. ", f) + ProcfsMetrics(0, 0, 0, 0, 0, 0) + } + } + + private[spark] def computeAllMetrics(): ProcfsMetrics = { + if (!isAvailable) { + return ProcfsMetrics(0, 0, 0, 0, 0, 0) + } + val pids = computeProcessTree + var allMetrics = ProcfsMetrics(0, 0, 0, 0, 0, 0) + for (p <- pids) { + allMetrics = addProcfsMetricsFromOneProcess(allMetrics, p) + // if we had an error getting any of the metrics, we don't want to report partial metrics, as + // that would be misleading. + if (!isAvailable) { + return ProcfsMetrics(0, 0, 0, 0, 0, 0) + } + } + allMetrics + } +} + +private[spark] object ProcfsMetricsGetter { + final val pTreeInfo = new ProcfsMetricsGetter +} diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 646b3881a79b0..85bb557abef5d 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -93,6 +93,11 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val EVENT_LOG_PROCESS_TREE_METRICS = + ConfigBuilder("spark.eventLog.logStageExecutorProcessTreeMetrics.enabled") + .booleanConf + .createWithDefault(false) + private[spark] val EVENT_LOG_OVERWRITE = ConfigBuilder("spark.eventLog.overwrite").booleanConf.createWithDefault(false) diff --git a/core/src/main/scala/org/apache/spark/metrics/ExecutorMetricType.scala b/core/src/main/scala/org/apache/spark/metrics/ExecutorMetricType.scala index cd10dad25e87b..704b36d3118b7 100644 --- a/core/src/main/scala/org/apache/spark/metrics/ExecutorMetricType.scala +++ b/core/src/main/scala/org/apache/spark/metrics/ExecutorMetricType.scala @@ -19,25 +19,43 @@ package org.apache.spark.metrics import java.lang.management.{BufferPoolMXBean, ManagementFactory} import javax.management.ObjectName +import scala.collection.mutable + +import org.apache.spark.executor.ProcfsMetricsGetter import org.apache.spark.memory.MemoryManager /** * Executor metric types for executor-level metrics stored in ExecutorMetrics. */ sealed trait ExecutorMetricType { + private[spark] def getMetricValues(memoryManager: MemoryManager): Array[Long] + private[spark] def names: Seq[String] +} + +sealed trait SingleValueExecutorMetricType extends ExecutorMetricType { + override private[spark] def names = { + Seq(getClass().getName(). + stripSuffix("$").split("""\.""").last) + } + + override private[spark] def getMetricValues(memoryManager: MemoryManager): Array[Long] = { + val metrics = new Array[Long](1) + metrics(0) = getMetricValue(memoryManager) + metrics + } + private[spark] def getMetricValue(memoryManager: MemoryManager): Long - private[spark] val name = getClass().getName().stripSuffix("$").split("""\.""").last } private[spark] abstract class MemoryManagerExecutorMetricType( - f: MemoryManager => Long) extends ExecutorMetricType { + f: MemoryManager => Long) extends SingleValueExecutorMetricType { override private[spark] def getMetricValue(memoryManager: MemoryManager): Long = { f(memoryManager) } } private[spark] abstract class MBeanExecutorMetricType(mBeanName: String) - extends ExecutorMetricType { + extends SingleValueExecutorMetricType { private val bean = ManagementFactory.newPlatformMXBeanProxy( ManagementFactory.getPlatformMBeanServer, new ObjectName(mBeanName).toString, classOf[BufferPoolMXBean]) @@ -47,18 +65,40 @@ private[spark] abstract class MBeanExecutorMetricType(mBeanName: String) } } -case object JVMHeapMemory extends ExecutorMetricType { +case object JVMHeapMemory extends SingleValueExecutorMetricType { override private[spark] def getMetricValue(memoryManager: MemoryManager): Long = { ManagementFactory.getMemoryMXBean.getHeapMemoryUsage().getUsed() } } -case object JVMOffHeapMemory extends ExecutorMetricType { +case object JVMOffHeapMemory extends SingleValueExecutorMetricType { override private[spark] def getMetricValue(memoryManager: MemoryManager): Long = { ManagementFactory.getMemoryMXBean.getNonHeapMemoryUsage().getUsed() } } +case object ProcessTreeMetrics extends ExecutorMetricType { + override val names = Seq( + "ProcessTreeJVMVMemory", + "ProcessTreeJVMRSSMemory", + "ProcessTreePythonVMemory", + "ProcessTreePythonRSSMemory", + "ProcessTreeOtherVMemory", + "ProcessTreeOtherRSSMemory") + + override private[spark] def getMetricValues(memoryManager: MemoryManager): Array[Long] = { + val allMetrics = ProcfsMetricsGetter.pTreeInfo.computeAllMetrics() + val processTreeMetrics = new Array[Long](names.length) + processTreeMetrics(0) = allMetrics.jvmVmemTotal + processTreeMetrics(1) = allMetrics.jvmRSSTotal + processTreeMetrics(2) = allMetrics.pythonVmemTotal + processTreeMetrics(3) = allMetrics.pythonRSSTotal + processTreeMetrics(4) = allMetrics.otherVmemTotal + processTreeMetrics(5) = allMetrics.otherRSSTotal + processTreeMetrics + } +} + case object OnHeapExecutionMemory extends MemoryManagerExecutorMetricType( _.onHeapExecutionMemoryUsed) @@ -84,8 +124,9 @@ case object MappedPoolMemory extends MBeanExecutorMetricType( "java.nio:type=BufferPool,name=mapped") private[spark] object ExecutorMetricType { - // List of all executor metric types - val values = IndexedSeq( + + // List of all executor metric getters + val metricGetters = IndexedSeq( JVMHeapMemory, JVMOffHeapMemory, OnHeapExecutionMemory, @@ -95,10 +136,21 @@ private[spark] object ExecutorMetricType { OnHeapUnifiedMemory, OffHeapUnifiedMemory, DirectPoolMemory, - MappedPoolMemory + MappedPoolMemory, + ProcessTreeMetrics ) - // Map of executor metric type to its index in values. - val metricIdxMap = - Map[ExecutorMetricType, Int](ExecutorMetricType.values.zipWithIndex: _*) + + val (metricToOffset, numMetrics) = { + var numberOfMetrics = 0 + val definedMetricsAndOffset = mutable.LinkedHashMap.empty[String, Int] + metricGetters.foreach { m => + var metricInSet = 0 + (0 until m.names.length).foreach { idx => + definedMetricsAndOffset += (m.names(idx) -> (idx + numberOfMetrics)) + } + numberOfMetrics += m.names.length + } + (definedMetricsAndOffset, numberOfMetrics) + } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index aa21da2b66ab2..c7d3cd37db6f9 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -133,9 +133,9 @@ private[spark] class ExecutorMetricsJsonSerializer jsonGenerator: JsonGenerator, serializerProvider: SerializerProvider): Unit = { metrics.foreach { m: ExecutorMetrics => - val metricsMap = ExecutorMetricType.values.map { metricType => - metricType.name -> m.getMetricValue(metricType) - }.toMap + val metricsMap = ExecutorMetricType.metricToOffset.map { case (metric, _) => + metric -> m.getMetricValue(metric) + } jsonGenerator.writeObject(metricsMap) } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 0cd8612b8fd1c..348291fe5e7ac 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -394,10 +394,10 @@ private[spark] object JsonProtocol { /** Convert executor metrics to JSON. */ def executorMetricsToJson(executorMetrics: ExecutorMetrics): JValue = { - val metrics = ExecutorMetricType.values.map{ metricType => - JField(metricType.name, executorMetrics.getMetricValue(metricType)) - } - JObject(metrics: _*) + val metrics = ExecutorMetricType.metricToOffset.map { case (m, _) => + JField(m, executorMetrics.getMetricValue(m)) + } + JObject(metrics.toSeq: _*) } def taskEndReasonToJson(taskEndReason: TaskEndReason): JValue = { @@ -611,10 +611,10 @@ private[spark] object JsonProtocol { /** Extract the executor metrics from JSON. */ def executorMetricsFromJson(json: JValue): ExecutorMetrics = { val metrics = - ExecutorMetricType.values.map { metric => - metric.name -> jsonOption(json \ metric.name).map(_.extract[Long]).getOrElse(0L) - }.toMap - new ExecutorMetrics(metrics) + ExecutorMetricType.metricToOffset.map { case (metric, _) => + metric -> jsonOption(json \ metric).map(_.extract[Long]).getOrElse(0L) + } + new ExecutorMetrics(metrics.toMap) } def taskEndFromJson(json: JValue): SparkListenerTaskEnd = { diff --git a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json index eea6f595efd2a..0f0ccf9858a38 100644 --- a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json @@ -1,4 +1,19 @@ [ { + "id" : "application_1538416563558_0014", + "name" : "PythonBisectingKMeansExample", + "attempts" : [ { + "startTime" : "2018-10-02T00:42:39.580GMT", + "endTime" : "2018-10-02T00:44:02.338GMT", + "lastUpdated" : "", + "duration" : 82758, + "sparkUser" : "root", + "completed" : true, + "appSparkVersion" : "2.5.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1538440959580, + "endTimeEpoch" : 1538441042338 + } ] +}, { "id" : "application_1506645932520_24630151", "name" : "Spark shell", "attempts" : [ { diff --git a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json index 7bc7f31be097b..e136a35a1e3a9 100644 --- a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json @@ -1,4 +1,19 @@ [ { + "id" : "application_1538416563558_0014", + "name" : "PythonBisectingKMeansExample", + "attempts" : [ { + "startTime" : "2018-10-02T00:42:39.580GMT", + "endTime" : "2018-10-02T00:44:02.338GMT", + "lastUpdated" : "", + "duration" : 82758, + "sparkUser" : "root", + "completed" : true, + "appSparkVersion" : "2.5.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1538440959580, + "endTimeEpoch" : 1538441042338 + } ] +}, { "id" : "application_1506645932520_24630151", "name" : "Spark shell", "attempts" : [ { diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_metrics_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_metrics_json_expectation.json index 9bf2086cc8e72..75674778dd1f6 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_metrics_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_metrics_json_expectation.json @@ -37,7 +37,13 @@ "DirectPoolMemory" : 397602, "MappedPoolMemory" : 0, "JVMHeapMemory" : 629553808, - "OffHeapStorageMemory" : 0 + "OffHeapStorageMemory" : 0, + "ProcessTreeJVMVMemory": 0, + "ProcessTreeJVMRSSMemory": 0, + "ProcessTreePythonVMemory": 0, + "ProcessTreePythonRSSMemory": 0, + "ProcessTreeOtherVMemory": 0, + "ProcessTreeOtherRSSMemory": 0 } }, { "id" : "7", @@ -177,7 +183,13 @@ "DirectPoolMemory" : 126261, "MappedPoolMemory" : 0, "JVMHeapMemory" : 518613056, - "OffHeapStorageMemory" : 0 + "OffHeapStorageMemory" : 0, + "ProcessTreeJVMVMemory": 0, + "ProcessTreeJVMRSSMemory": 0, + "ProcessTreePythonVMemory": 0, + "ProcessTreePythonRSSMemory": 0, + "ProcessTreeOtherVMemory": 0, + "ProcessTreeOtherRSSMemory": 0 } }, { "id" : "3", @@ -221,7 +233,13 @@ "DirectPoolMemory" : 87796, "MappedPoolMemory" : 0, "JVMHeapMemory" : 726805712, - "OffHeapStorageMemory" : 0 + "OffHeapStorageMemory" : 0, + "ProcessTreeJVMVMemory": 0, + "ProcessTreeJVMRSSMemory": 0, + "ProcessTreePythonVMemory": 0, + "ProcessTreePythonRSSMemory": 0, + "ProcessTreeOtherVMemory": 0, + "ProcessTreeOtherRSSMemory": 0 } }, { "id" : "2", @@ -265,7 +283,13 @@ "DirectPoolMemory" : 87796, "MappedPoolMemory" : 0, "JVMHeapMemory" : 595946552, - "OffHeapStorageMemory" : 0 + "OffHeapStorageMemory" : 0, + "ProcessTreeJVMVMemory": 0, + "ProcessTreeJVMRSSMemory": 0, + "ProcessTreePythonVMemory": 0, + "ProcessTreePythonRSSMemory": 0, + "ProcessTreeOtherVMemory": 0, + "ProcessTreeOtherRSSMemory": 0 } }, { "id" : "1", @@ -309,6 +333,12 @@ "DirectPoolMemory" : 98230, "MappedPoolMemory" : 0, "JVMHeapMemory" : 755008624, - "OffHeapStorageMemory" : 0 + "OffHeapStorageMemory" : 0, + "ProcessTreeJVMVMemory": 0, + "ProcessTreeJVMRSSMemory": 0, + "ProcessTreePythonVMemory": 0, + "ProcessTreePythonRSSMemory": 0, + "ProcessTreeOtherVMemory": 0, + "ProcessTreeOtherRSSMemory": 0 } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_process_tree_metrics_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_process_tree_metrics_json_expectation.json new file mode 100644 index 0000000000000..69efefe736dd4 --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_with_executor_process_tree_metrics_json_expectation.json @@ -0,0 +1,98 @@ +[ { + "id" : "driver", + "hostPort" : "rezamemory-1.gce.something.com:43959", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 0, + "maxTasks" : 0, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 0, + "totalTasks" : 0, + "totalDuration" : 0, + "totalGCTime" : 0, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : false, + "maxMemory" : 384093388, + "addTime" : "2018-10-02T00:42:47.690GMT", + "executorLogs" : { }, + "memoryMetrics" : { + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 384093388, + "totalOffHeapStorageMemory" : 0 + }, + "blacklistedInStages" : [ ], + "peakMemoryMetrics" : { + "OnHeapStorageMemory" : 554933, + "JVMOffHeapMemory" : 104976128, + "OffHeapExecutionMemory" : 0, + "OnHeapUnifiedMemory" : 554933, + "OnHeapExecutionMemory" : 0, + "OffHeapUnifiedMemory" : 0, + "DirectPoolMemory" : 228407, + "MappedPoolMemory" : 0, + "JVMHeapMemory" : 350990264, + "OffHeapStorageMemory" : 0, + "ProcessTreeJVMVMemory" : 5067235328, + "ProcessTreeJVMRSSMemory" : 710475776, + "ProcessTreePythonVMemory" : 408375296, + "ProcessTreePythonRSSMemory" : 40284160, + "ProcessTreeOtherVMemory" : 0, + "ProcessTreeOtherRSSMemory" : 0 + } +}, { + "id" : "9", + "hostPort" : "rezamemory-2.gce.something.com:40797", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 1, + "maxTasks" : 1, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 2, + "totalTasks" : 2, + "totalDuration" : 6191, + "totalGCTime" : 288, + "totalInputBytes" : 108, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : false, + "maxMemory" : 384093388, + "addTime" : "2018-10-02T00:43:56.142GMT", + "executorLogs" : { + "stdout" : "http://rezamemory-2.gce.something.com:8042/node/containerlogs/container_1538416563558_0014_01_000010/root/stdout?start=-4096", + "stderr" : "http://rezamemory-2.gce.something.com:8042/node/containerlogs/container_1538416563558_0014_01_000010/root/stderr?start=-4096" + }, + "memoryMetrics" : { + "usedOnHeapStorageMemory" : 0, + "usedOffHeapStorageMemory" : 0, + "totalOnHeapStorageMemory" : 384093388, + "totalOffHeapStorageMemory" : 0 + }, + "blacklistedInStages" : [ ], + "peakMemoryMetrics" : { + "OnHeapStorageMemory" : 1088805, + "JVMOffHeapMemory" : 59006656, + "OffHeapExecutionMemory" : 0, + "OnHeapUnifiedMemory" : 1088805, + "OnHeapExecutionMemory" : 0, + "OffHeapUnifiedMemory" : 0, + "DirectPoolMemory" : 20181, + "MappedPoolMemory" : 0, + "JVMHeapMemory" : 193766856, + "OffHeapStorageMemory" : 0, + "ProcessTreeJVMVMemory" : 3016261632, + "ProcessTreeJVMRSSMemory" : 405860352, + "ProcessTreePythonVMemory" : 625926144, + "ProcessTreePythonRSSMemory" : 69013504, + "ProcessTreeOtherVMemory" : 0, + "ProcessTreeOtherRSSMemory" : 0 + } +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json index 9e1e65a358815..0ef9377dcb08b 100644 --- a/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json @@ -1,4 +1,19 @@ [ { + "id" : "application_1538416563558_0014", + "name" : "PythonBisectingKMeansExample", + "attempts" : [ { + "startTime" : "2018-10-02T00:42:39.580GMT", + "endTime" : "2018-10-02T00:44:02.338GMT", + "lastUpdated" : "", + "duration" : 82758, + "sparkUser" : "root", + "completed" : true, + "appSparkVersion" : "2.5.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1538440959580, + "endTimeEpoch" : 1538441042338 + } ] +}, { "id" : "application_1506645932520_24630151", "name" : "Spark shell", "attempts" : [ { @@ -28,19 +43,4 @@ "startTimeEpoch" : 1516300235119, "endTimeEpoch" : 1516300707938 } ] -}, { - "id" : "app-20180109111548-0000", - "name" : "Spark shell", - "attempts" : [ { - "startTime" : "2018-01-09T10:15:42.372GMT", - "endTime" : "2018-01-09T10:24:37.606GMT", - "lastUpdated" : "", - "duration" : 535234, - "sparkUser" : "attilapiros", - "completed" : true, - "appSparkVersion" : "2.3.0-SNAPSHOT", - "lastUpdatedEpoch" : 0, - "startTimeEpoch" : 1515492942372, - "endTimeEpoch" : 1515493477606 - } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json index 28c6bf1b3e01e..ea9dc1b97afc8 100644 --- a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json @@ -1,4 +1,19 @@ [ { + "id" : "application_1538416563558_0014", + "name" : "PythonBisectingKMeansExample", + "attempts" : [ { + "startTime" : "2018-10-02T00:42:39.580GMT", + "endTime" : "2018-10-02T00:44:02.338GMT", + "lastUpdated" : "", + "duration" : 82758, + "sparkUser" : "root", + "completed" : true, + "appSparkVersion" : "2.5.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1538440959580, + "endTimeEpoch" : 1538441042338 + } ] +}, { "id" : "application_1506645932520_24630151", "name" : "Spark shell", "attempts" : [ { diff --git a/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json index f547b79f47e1a..2a77071a9ffd9 100644 --- a/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json @@ -1,4 +1,19 @@ [ { + "id" : "application_1538416563558_0014", + "name" : "PythonBisectingKMeansExample", + "attempts" : [ { + "startTime" : "2018-10-02T00:42:39.580GMT", + "endTime" : "2018-10-02T00:44:02.338GMT", + "lastUpdated" : "", + "duration" : 82758, + "sparkUser" : "root", + "completed" : true, + "appSparkVersion" : "2.5.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1538440959580, + "endTimeEpoch" : 1538441042338 + } ] +}, { "id" : "application_1506645932520_24630151", "name" : "Spark shell", "attempts" : [ { diff --git a/core/src/test/resources/ProcfsMetrics/22763/stat b/core/src/test/resources/ProcfsMetrics/22763/stat new file mode 100644 index 0000000000000..cea4b713d0ee4 --- /dev/null +++ b/core/src/test/resources/ProcfsMetrics/22763/stat @@ -0,0 +1 @@ +22763 (python2.7) S 22756 22756 7051 0 -1 1077944384 449 0 0 0 4 3 0 0 20 0 3 0 117445 360595456 1912 18446744073709551615 4194304 4196756 140726192435536 140726192432528 140707465485051 0 0 16781312 2 18446744073709551615 0 0 17 1 0 0 0 0 0 6294976 6295604 38744064 140726192440006 140726192440119 140726192440119 140726192443369 0 \ No newline at end of file diff --git a/core/src/test/resources/ProcfsMetrics/26109/stat b/core/src/test/resources/ProcfsMetrics/26109/stat new file mode 100644 index 0000000000000..ae46bfabd047e --- /dev/null +++ b/core/src/test/resources/ProcfsMetrics/26109/stat @@ -0,0 +1 @@ +26109 (java) S 1 26107 5788 0 -1 1077944320 75354 0 0 0 572 52 0 0 20 0 34 0 4355257 4769947648 64114 18446744073709551615 4194304 4196468 140737190381776 140737190364320 139976994791319 0 0 0 16800975 18446744073709551615 0 0 17 2 0 0 0 0 0 6293624 6294260 11276288 140737190385424 140737190414250 140737190414250 140737190416335 0 diff --git a/core/src/test/resources/spark-events/application_1538416563558_0014 b/core/src/test/resources/spark-events/application_1538416563558_0014 new file mode 100644 index 0000000000000..000288dbc4541 --- /dev/null +++ b/core/src/test/resources/spark-events/application_1538416563558_0014 @@ -0,0 +1,190 @@ +{"Event":"SparkListenerLogStart","Spark Version":"2.5.0-SNAPSHOT"} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"driver","Host":"rezamemory-1.gce.something.com","Port":43959},"Maximum Memory":384093388,"Timestamp":1538440967690,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":0} +{"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/usr/java/jdk1.8.0_121/jre","Java Version":"1.8.0_121 (Oracle Corporation)","Scala Version":"version 2.11.12"},"Spark Properties":{"spark.serializer":"org.apache.spark.serializer.KryoSerializer","spark.yarn.jars":"local:/opt/some/path/lib/spark2/jars/*","spark.driver.host":"rezamemory-1.gce.something.com","spark.serializer.objectStreamReset":"100","spark.eventLog.enabled":"true","spark.executor.heartbeatInterval":"100ms","spark.hadoop.mapreduce.application.classpath":"","spark.driver.port":"35918","spark.shuffle.service.enabled":"true","spark.rdd.compress":"True","spark.driver.extraLibraryPath":"/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/lib/hadoop/lib/native","spark.executorEnv.PYTHONPATH":"/opt/some/path/lib/spark2/python/lib/py4j-0.10.7-src.zip/opt/some/path/lib/spark2/python/lib/pyspark.zip","spark.yarn.historyServer.address":"http://rezamemory-1.gce.something.com:18089","spark.app.name":"PythonBisectingKMeansExample","spark.ui.killEnabled":"true","spark.sql.hive.metastore.jars":"${env:HADOOP_COMMON_HOME}/../hive/lib/*:${env:HADOOP_COMMON_HOME}/client/*","spark.dynamicAllocation.schedulerBacklogTimeout":"1","spark.yarn.am.extraLibraryPath":"/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/lib/hadoop/lib/native","spark.scheduler.mode":"FIFO","spark.eventLog.logStageExecutorMetrics.enabled":"true","spark.yarn.config.gatewayPath":"/opt/cloudera/parcels","spark.executor.id":"driver","spark.yarn.config.replacementPath":"{{HADOOP_COMMON_HOME}}/../../..","spark.eventLog.logStageExecutorProcessTreeMetrics.enabled":"true","spark.submit.deployMode":"client","spark.shuffle.service.port":"7337","spark.master":"yarn","spark.authenticate":"false","spark.ui.filters":"org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter","spark.executor.extraLibraryPath":"/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/lib/hadoop/lib/native","spark.eventLog.dir":"hdfs://rezamemory-1.gce.something.com:8020/user/spark/spark2ApplicationHistory","spark.dynamicAllocation.enabled":"true","spark.sql.catalogImplementation":"hive","spark.hadoop.yarn.application.classpath":"","spark.driver.appUIAddress":"http://rezamemory-1.gce.something.com:4040","spark.yarn.isPython":"true","spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.param.PROXY_HOSTS":"rezamemory-1.gce.something.com","spark.dynamicAllocation.minExecutors":"0","spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.param.PROXY_URI_BASES":"http://rezamemory-1.gce.something.com:8088/proxy/application_1538416563558_0014","spark.dynamicAllocation.executorIdleTimeout":"60","spark.app.id":"application_1538416563558_0014","spark.sql.hive.metastore.version":"1.1.0"},"System Properties":{"java.io.tmpdir":"/tmp","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle Corporation","java.vm.specification.version":"1.8","user.home":"/root","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","sun.arch.data.model":"64","sun.boot.library.path":"/usr/java/jdk1.8.0_121/jre/lib/amd64","user.dir":"/","java.library.path":":/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/lib/hadoop/lib/native:/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/lib/hadoop/lib/native:/usr/java/packages/lib/amd64:/usr/lib64:/lib64:/lib:/usr/lib","sun.cpu.isalist":"","os.arch":"amd64","java.vm.version":"25.121-b13","jetty.git.hash":"unknown","java.endorsed.dirs":"/usr/java/jdk1.8.0_121/jre/lib/endorsed","java.runtime.version":"1.8.0_121-b13","java.vm.info":"mixed mode","java.ext.dirs":"/usr/java/jdk1.8.0_121/jre/lib/ext:/usr/java/packages/lib/ext","java.runtime.name":"Java(TM) SE Runtime Environment","file.separator":"/","java.class.version":"52.0","java.specification.name":"Java Platform API Specification","sun.boot.class.path":"/usr/java/jdk1.8.0_121/jre/lib/resources.jar:/usr/java/jdk1.8.0_121/jre/lib/rt.jar:/usr/java/jdk1.8.0_121/jre/lib/sunrsasign.jar:/usr/java/jdk1.8.0_121/jre/lib/jsse.jar:/usr/java/jdk1.8.0_121/jre/lib/jce.jar:/usr/java/jdk1.8.0_121/jre/lib/charsets.jar:/usr/java/jdk1.8.0_121/jre/lib/jfr.jar:/usr/java/jdk1.8.0_121/jre/classes","file.encoding":"UTF-8","user.timezone":"America/Los_Angeles","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"3.10.0-693.5.2.el7.x86_64","sun.os.patch.level":"unknown","java.vm.specification.vendor":"Oracle Corporation","user.country":"US","sun.jnu.encoding":"UTF-8","user.language":"en","java.vendor.url":"http://java.oracle.com/","java.awt.printerjob":"sun.print.PSPrinterJob","java.awt.graphicsenv":"sun.awt.X11GraphicsEnvironment","awt.toolkit":"sun.awt.X11.XToolkit","os.name":"Linux","java.vm.vendor":"Oracle Corporation","java.vendor.url.bug":"http://bugreport.sun.com/bugreport/","user.name":"root","java.vm.name":"Java HotSpot(TM) 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --conf spark.executor.heartbeatInterval=100ms --conf spark.eventLog.logStageExecutorProcessTreeMetrics.enabled=true --conf spark.eventLog.logStageExecutorMetrics.enabled=true ./opt/some/path/lib/spark2/examples/src/main/python/mllib/bisecting_k_means_example.py","java.home":"/usr/java/jdk1.8.0_121/jre","java.version":"1.8.0_121","sun.io.unicode.encoding":"UnicodeLittle"},"Classpath Entries":{"/opt/some/path/lib/spark2/jars/apacheds-kerberos-codec-2.0.0-M15.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/netty-3.10.5.Final.jar":"System Classpath","/opt/some/path/lib/spark2/jars/validation-api-1.1.0.Final.jar":"System Classpath","/opt/some/path/lib/spark2/jars/hadoop-annotations-2.7.3.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-azure-datalake-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jackson-jaxrs-1.9.13.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jaxb-impl-2.2.3-1.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jasper-compiler-5.5.23.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/logredactor-1.0.3.jar":"System Classpath","/opt/some/path/lib/spark2/jars/spark-streaming_2.11-2.5.0-SNAPSHOT.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jersey-common-2.22.2.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jersey-container-servlet-core-2.22.2.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/commons-collections-3.2.2.jar":"System Classpath","/opt/some/path/lib/spark2/jars/guice-servlet-3.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/parquet-hadoop-1.5.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/RoaringBitmap-0.5.11.jar":"System Classpath","/opt/some/path/lib/spark2/jars/parquet-hadoop-1.10.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/parquet-jackson-1.10.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jersey-server-2.22.2.jar":"System Classpath","/opt/some/path/lib/spark2/jars/hk2-api-2.4.0-b34.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jtransforms-2.4.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/aircompressor-0.10.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/commons-el-1.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/avro-mapred-1.8.2-hadoop2.jar":"System Classpath","/opt/some/path/lib/spark2/jars/minlog-1.3.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/commons-daemon-1.0.13.jar":"System Classpath","/opt/some/path/lib/spark2/jars/kryo-shaded-4.0.2.jar":"System Classpath","/opt/some/path/lib/spark2/jars/commons-crypto-1.0.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/spark-mllib-local_2.11-2.5.0-SNAPSHOT.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-openstack-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/commons-lang3-3.5.jar":"System Classpath","/opt/some/path/lib/spark2/jars/univocity-parsers-2.7.3.jar":"System Classpath","/opt/some/path/lib/spark2/jars/javassist-3.18.1-GA.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-yarn-api-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-yarn-registry-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/activation-1.1.jar":"System Classpath","/opt/some/path/lib/spark2/jars/objenesis-2.5.1.jar":"System Classpath","/opt/some/path/lib/spark2/jars/aopalliance-1.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jackson-xc-1.8.8.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-yarn-applications-unmanaged-am-launcher-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/spark-repl_2.11-2.5.0-SNAPSHOT.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hue-plugins-3.9.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/commons-digester-1.8.jar":"System Classpath","/opt/some/path/lib/spark2/jars/json4s-ast_2.11-3.5.3.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-yarn-server-resourcemanager-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/commons-math3-3.1.1.jar":"System Classpath","/opt/some/path/lib/spark2/jars/activation-1.1.1.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-yarn-server-tests-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/snappy-java-1.0.4.1.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/microsoft-windowsazure-storage-sdk-0.6.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/slf4j-log4j12-1.7.16.jar":"System Classpath","/opt/some/path/lib/spark2/kafka-0.9/metrics-core-2.2.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jersey-json-1.9.jar":"System Classpath","/opt/some/path/lib/spark2/jars/spark-kvstore_2.11-2.5.0-SNAPSHOT.jar":"System Classpath","/opt/some/path/lib/spark2/jars/parquet-common-1.10.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/protobuf-java-2.5.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jersey-guice-1.9.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-archives-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-archive-logs-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jetty-6.1.26.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-mapreduce-examples-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/snappy-java-1.1.7.1.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-yarn-server-web-proxy-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/commons-httpclient-3.1.jar":"System Classpath","/opt/some/path/lib/spark2/jars/orc-mapreduce-1.5.2-nohive.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jsch-0.1.42.jar":"System Classpath","/opt/some/path/lib/spark2/jars/metrics-jvm-3.1.5.jar":"System Classpath","/opt/some/path/lib/spark2/jars/javax.annotation-api-1.2.jar":"System Classpath","/opt/some/path/lib/spark2/jars/pyrolite-4.13.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/parquet-jackson-1.5.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/json4s-jackson_2.11-3.5.3.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jersey-client-2.22.2.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jline-2.11.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-hdfs-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/parquet-scala_2.10-1.5.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/kafka-0.9/spark-streaming-kafka-0-8_2.11-2.2.0.cloudera1.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jsp-api-2.1.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jaxb-api-2.2.2.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-yarn-server-nodemanager-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/commons-logging-1.1.3.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jackson-core-asl-1.9.13.jar":"System Classpath","/opt/some/path/lib/spark2/jars/commons-compiler-3.0.10.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/parquet-generator-1.5.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/parquet-format-2.4.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jackson-mapper-asl-1.9.13.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-mapreduce-client-core-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/curator-framework-2.7.1.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jersey-server-1.9.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-common-2.6.0-cdh5.12.0-tests.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jetty-6.1.26.cloudera.4.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/mockito-all-1.8.5.jar":"System Classpath","/opt/some/path/lib/spark2/jars/aopalliance-repackaged-2.4.0-b34.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jackson-core-2.2.3.jar":"System Classpath","/opt/some/path/lib/spark2/jars/leveldbjni-all-1.8.jar":"System Classpath","/opt/some/path/lib/spark2/jars/osgi-resource-locator-1.0.1.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jsp-api-2.1.jar":"System Classpath","/opt/some/path/lib/spark2/jars/spark-unsafe_2.11-2.5.0-SNAPSHOT.jar":"System Classpath","/opt/some/path/lib/spark2/jars/oro-2.0.8.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-mapreduce-client-hs-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/hadoop-common-2.7.3.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jackson-databind-2.6.7.1.jar":"System Classpath","/opt/some/path/lib/spark2/jars/commons-codec-1.10.jar":"System Classpath","/opt/some/path/lib/spark2/jars/xmlenc-0.52.jar":"System Classpath","/opt/some/path/lib/spark2/jars/opencsv-2.3.jar":"System Classpath","/opt/some/path/lib/spark2/jars/xbean-asm6-shaded-4.8.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/javax.inject-1.jar":"System Classpath","/opt/some/path/lib/spark2/jars/parquet-encoding-1.10.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-mapreduce-client-common-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/scala-library-2.11.12.jar":"System Classpath","/opt/some/path/lib/spark2/jars/json4s-scalap_2.11-3.5.3.jar":"System Classpath","/opt/some/path/lib/spark2/jars/log4j-1.2.17.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jaxb-api-2.2.2.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/lib/hadoop/LICENSE.txt":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-common-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/netty-3.9.9.Final.jar":"System Classpath","/opt/some/path/lib/spark2/jars/json4s-core_2.11-3.5.3.jar":"System Classpath","/opt/some/path/lib/spark2/jars/hadoop-yarn-api-2.7.3.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/httpcore-4.2.5.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jettison-1.1.jar":"System Classpath","/opt/some/path/lib/spark2/jars/zookeeper-3.4.6.jar":"System Classpath","/opt/some/path/lib/spark2/jars/metrics-core-3.1.5.jar":"System Classpath","/opt/some/path/lib/spark2/jars/hadoop-auth-2.7.3.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jersey-core-1.9.jar":"System Classpath","/opt/some/path/lib/spark2/jars/spark-network-shuffle_2.11-2.5.0-SNAPSHOT.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/commons-beanutils-core-1.8.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/hk2-utils-2.4.0-b34.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/commons-beanutils-1.9.2.jar":"System Classpath","/opt/some/path/lib/spark2/jars/chill_2.11-0.9.3.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jackson-core-2.6.7.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/apacheds-i18n-2.0.0-M15.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jul-to-slf4j-1.7.16.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/paranamer-2.3.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jersey-container-servlet-2.22.2.jar":"System Classpath","/opt/some/path/lib/spark2/jars/janino-3.0.10.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jetty-util-6.1.26.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-yarn-common-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/commons-beanutils-core-1.8.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/py4j-0.10.7.jar":"System Classpath","/opt/some/path/lib/spark2/jars/ivy-2.4.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/apacheds-kerberos-codec-2.0.0-M15.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/commons-lang-2.6.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/parquet-format-2.1.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-yarn-client-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/stream-2.7.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-hdfs-2.6.0-cdh5.12.0-tests.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/xml-apis-1.3.04.jar":"System Classpath","/opt/some/path/lib/spark2/kafka-0.9/kafka_2.11-0.9.0-kafka-2.0.2.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/metrics-core-3.0.2.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-yarn-server-applicationhistoryservice-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/conf/":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/guice-servlet-3.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/slf4j-api-1.7.5.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/commons-configuration-1.6.jar":"System Classpath","/opt/some/path/lib/spark2/jars/xz-1.5.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/parquet-tools-1.5.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/hadoop-yarn-server-common-2.7.3.jar":"System Classpath","/opt/some/path/lib/spark2/jars/arrow-format-0.10.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/okio-1.4.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/compress-lzf-1.0.3.jar":"System Classpath","/opt/some/path/lib/spark2/jars/hadoop-mapreduce-client-jobclient-2.7.3.jar":"System Classpath","/opt/some/path/lib/spark2/jars/hppc-0.7.2.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/stax-api-1.0-2.jar":"System Classpath","/opt/some/path/lib/spark2/jars/spark-yarn_2.11-2.5.0-SNAPSHOT.jar":"System Classpath","/opt/some/path/lib/spark2/jars/api-util-1.0.0-M20.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-yarn-applications-distributedshell-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/joda-time-2.9.3.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-sls-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jets3t-0.9.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/curator-recipes-2.7.1.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/leveldbjni-all-1.8.jar":"System Classpath","/opt/some/path/lib/spark2/jars/guice-3.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-streaming-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/guava-14.0.1.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hamcrest-core-1.3.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/aws-java-sdk-bundle-1.11.134.jar":"System Classpath","/opt/some/path/lib/spark2/jars/hadoop-client-2.7.3.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-mapreduce-client-hs-plugins-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-gridmix-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/api-util-1.0.0-M20.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/xz-1.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/parquet-pig-1.5.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jersey-guava-2.22.2.jar":"System Classpath","/opt/some/path/lib/spark2/jars/scala-compiler-2.11.12.jar":"System Classpath","/opt/some/path/lib/spark2/jars/spark-sql_2.11-2.5.0-SNAPSHOT.jar":"System Classpath","/opt/some/path/lib/spark2/jars/hadoop-mapreduce-client-app-2.7.3.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/java-xmlbuilder-0.4.jar":"System Classpath","/opt/some/path/lib/spark2/jars/slf4j-api-1.7.16.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/parquet-hadoop-bundle-1.5.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/hadoop-mapreduce-client-shuffle-2.7.3.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/parquet-pig-bundle-1.5.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/commons-digester-1.8.jar":"System Classpath","/opt/some/path/lib/spark2/jars/metrics-json-3.1.5.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/commons-codec-1.4.jar":"System Classpath","/opt/some/path/lib/spark2/jars/commons-beanutils-1.7.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/spark-catalyst_2.11-2.5.0-SNAPSHOT.jar":"System Classpath","/opt/some/path/lib/spark2/jars/scala-xml_2.11-1.0.5.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/parquet-common-1.5.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/scala-parser-combinators_2.11-1.1.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jetty-util-6.1.26.cloudera.4.jar":"System Classpath","/opt/some/path/lib/spark2/jars/httpclient-4.5.6.jar":"System Classpath","/opt/some/path/lib/spark2/jars/antlr4-runtime-4.7.jar":"System Classpath","/opt/some/path/lib/spark2/jars/commons-lang-2.6.jar":"System Classpath","/opt/some/path/lib/spark2/jars/spark-mllib_2.11-2.5.0-SNAPSHOT.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jersey-media-jaxb-2.22.2.jar":"System Classpath","/opt/some/path/lib/spark2/jars/api-asn1-api-1.0.0-M20.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-mapreduce-client-app-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/machinist_2.11-0.6.1.jar":"System Classpath","/opt/some/path/lib/spark2/jars/spark-core_2.11-2.5.0-SNAPSHOT.jar":"System Classpath","/opt/some/path/lib/spark2/jars/spire_2.11-0.13.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jackson-xc-1.9.13.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/parquet-thrift-1.5.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/htrace-core-3.1.0-incubating.jar":"System Classpath","/opt/some/path/lib/spark2/jars/macro-compat_2.11-1.1.1.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-annotations-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/commons-io-2.4.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jackson-annotations-2.2.3.jar":"System Classpath","/opt/some/path/lib/spark2/jars/orc-core-1.5.2-nohive.jar":"System Classpath","/opt/some/path/lib/spark2/jars/commons-net-3.1.jar":"System Classpath","/opt/some/path/lib/spark2/jars/arrow-memory-0.10.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/shapeless_2.11-2.3.2.jar":"System Classpath","/opt/some/path/lib/spark2/jars/spark-graphx_2.11-2.5.0-SNAPSHOT.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jackson-core-asl-1.8.8.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/okhttp-2.4.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/parquet-format-2.1.0-cdh5.12.0-sources.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/htrace-core4-4.0.1-incubating.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-datajoin-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jackson-module-paranamer-2.7.9.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-aws-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/scala-reflect-2.11.12.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/commons-net-3.1.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jackson-databind-2.2.3.jar":"System Classpath","/opt/some/path/lib/spark2/jars/parquet-column-1.10.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/xmlenc-0.52.jar":"System Classpath","/opt/some/path/lib/spark2/kafka-0.9/kafka-clients-0.9.0-kafka-2.0.2.jar":"System Classpath","/opt/some/path/lib/spark2/jars/commons-io-2.4.jar":"System Classpath","/opt/some/path/lib/spark2/jars/lz4-java-1.4.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/core-1.1.2.jar":"System Classpath","/opt/some/path/lib/spark2/jars/arrow-vector-0.10.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/javax.ws.rs-api-2.0.1.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-azure-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/parquet-format-2.1.0-cdh5.12.0-javadoc.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-nfs-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/hadoop-yarn-client-2.7.3.jar":"System Classpath","/opt/some/path/lib/spark2/jars/breeze_2.11-0.13.2.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-yarn-server-common-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/httpcore-4.4.10.jar":"System Classpath","/opt/some/path/lib/spark2/jars/javax.servlet-api-3.1.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/javax.inject-2.4.0-b34.jar":"System Classpath","/opt/some/path/lib/spark2/jars/hadoop-mapreduce-client-core-2.7.3.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/servlet-api-2.5.jar":"System Classpath","/opt/some/path/lib/spark2/jars/hadoop-yarn-common-2.7.3.jar":"System Classpath","/opt/some/path/lib/spark2/jars/commons-math3-3.4.1.jar":"System Classpath","/opt/some/path/lib/spark2/jars/javax.inject-1.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jackson-jaxrs-1.8.8.jar":"System Classpath","/opt/some/path/lib/spark2/jars/curator-recipes-2.7.1.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/spark-1.6.0-cdh5.12.0-yarn-shuffle.jar":"System Classpath","/opt/some/path/lib/spark2/jars/breeze-macros_2.11-0.13.2.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/zookeeper-3.4.5-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/httpclient-4.2.5.jar":"System Classpath","/opt/some/path/lib/spark2/jars/metrics-graphite-3.1.5.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jcl-over-slf4j-1.7.16.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/commons-compress-1.4.1.jar":"System Classpath","/opt/some/path/lib/spark2/jars/spark-sketch_2.11-2.5.0-SNAPSHOT.jar":"System Classpath","/opt/some/path/lib/spark2/jars/spark-network-common_2.11-2.5.0-SNAPSHOT.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/gson-2.2.4.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/parquet-cascading-1.5.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-auth-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/orc-shims-1.5.2.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/log4j-1.2.17.jar":"System Classpath","/opt/some/path/lib/spark2/jars/stax-api-1.0-2.jar":"System Classpath","/opt/some/path/lib/spark2/kafka-0.9/zkclient-0.7.jar":"System Classpath","/opt/some/path/lib/spark2/jars/paranamer-2.8.jar":"System Classpath","/opt/some/path/lib/spark2/jars/apacheds-i18n-2.0.0-M15.jar":"System Classpath","/opt/some/path/lib/spark2/jars/gson-2.2.4.jar":"System Classpath","/opt/some/path/lib/spark2/jars/spark-tags_2.11-2.5.0-SNAPSHOT.jar":"System Classpath","/opt/some/path/lib/spark2/jars/commons-configuration-1.6.jar":"System Classpath","/opt/some/path/lib/spark2/jars/hadoop-hdfs-2.7.3.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/guice-3.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jsr305-1.3.9.jar":"System Classpath","/opt/some/path/lib/spark2/jars/curator-client-2.7.1.jar":"System Classpath","/opt/some/path/lib/spark2/conf/yarn-conf/":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/api-asn1-api-1.0.0-M20.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/azure-data-lake-store-sdk-2.1.4.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-distcp-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/junit-4.11.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-extras-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/xercesImpl-2.9.1.jar":"System Classpath","/opt/some/path/lib/spark2/jars/hk2-locator-2.4.0-b34.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jasper-runtime-5.5.23.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/curator-client-2.7.1.jar":"System Classpath","/opt/some/path/lib/spark2/jars/avro-1.8.2.jar":"System Classpath","/opt/some/path/lib/spark2/jars/commons-compress-1.8.1.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-mapreduce-client-jobclient-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jsr305-3.0.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/commons-collections-3.2.2.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/guava-11.0.2.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/asm-3.2.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/avro-1.7.6-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/commons-httpclient-3.1.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jersey-client-1.9.jar":"System Classpath","/opt/some/path/lib/spark2/jars/hadoop-yarn-server-web-proxy-2.7.3.jar":"System Classpath","/opt/some/path/lib/spark2/jars/zstd-jni-1.3.2-2.jar":"System Classpath","/opt/some/path/lib/spark2/jars/commons-cli-1.2.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/parquet-scrooge_2.10-1.5.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/commons-cli-1.2.jar":"System Classpath","/opt/some/path/lib/spark2/jars/spire-macros_2.11-0.13.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-ant-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/spark-launcher_2.11-2.5.0-SNAPSHOT.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-mapreduce-client-nativetask-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/xercesImpl-2.9.1.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jackson-module-scala_2.11-2.6.7.1.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-hdfs-nfs-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/parquet-encoding-1.5.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-mapreduce-client-shuffle-2.6.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/parquet-avro-1.5.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/flatbuffers-1.2.0-3f79e055.jar":"System Classpath","/opt/some/path/lib/spark2/jars/protobuf-java-2.5.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/parquet-test-hadoop2-1.5.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-mapreduce-client-jobclient-2.6.0-cdh5.12.0-tests.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/aopalliance-1.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/parquet-column-1.5.0-cdh5.12.0.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/slf4j-log4j12-1.7.5.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/parquet-protobuf-1.5.0-cdh5.12.0.jar":"System Classpath","/opt/some/path/lib/spark2/jars/avro-ipc-1.8.2.jar":"System Classpath","/opt/some/path/lib/spark2/jars/arpack_combined_all-0.1.jar":"System Classpath","/opt/some/path/lib/spark2/jars/netty-all-4.1.17.Final.jar":"System Classpath","/opt/some/path/lib/spark2/jars/chill-java-0.9.3.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/lib/hadoop/NOTICE.txt":"System Classpath","/opt/some/path/lib/spark2/jars/hadoop-mapreduce-client-common-2.7.3.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/jackson-mapper-asl-1.8.8.jar":"System Classpath","/opt/some/path/lib/spark2/jars/jackson-annotations-2.6.7.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/curator-framework-2.7.1.jar":"System Classpath","/opt/cloudera/parcels/CDH-5.12.0-1.cdh5.12.0.p0.29/jars/hadoop-rumen-2.6.0-cdh5.12.0.jar":"System Classpath"}} +{"Event":"SparkListenerApplicationStart","App Name":"PythonBisectingKMeansExample","App ID":"application_1538416563558_0014","Timestamp":1538440959580,"User":"root"} +{"Event":"SparkListenerJobStart","Job ID":0,"Submission Time":1538440969009,"Stage Infos":[{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"first at BisectingKMeans.scala:163","Number of Tasks":1,"RDD Info":[{"RDD ID":4,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"2\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:163","Parent IDs":[3],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":3,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.first(RDD.scala:1377)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:163)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Accumulables":[]}],"Stage IDs":[0],"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"first\"}"}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"first at BisectingKMeans.scala:163","Number of Tasks":1,"RDD Info":[{"RDD ID":4,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"2\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:163","Parent IDs":[3],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":3,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.first(RDD.scala:1377)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:163)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538440969044,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"first\"}"}} +{"Event":"SparkListenerExecutorAdded","Timestamp":1538440973727,"Executor ID":"1","Executor Info":{"Host":"rezamemory-2.gce.something.com","Total Cores":1,"Log Urls":{"stdout":"http://rezamemory-2.gce.something.com:8042/node/containerlogs/container_1538416563558_0014_01_000002/root/stdout?start=-4096","stderr":"http://rezamemory-2.gce.something.com:8042/node/containerlogs/container_1538416563558_0014_01_000002/root/stderr?start=-4096"}}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1538440973735,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Port":46411},"Maximum Memory":384093388,"Timestamp":1538440973890,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":0} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1538440973735,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538440977628,"Failed":false,"Killed":false,"Accumulables":[{"ID":23,"Name":"internal.metrics.input.recordsRead","Update":4,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":22,"Name":"internal.metrics.input.bytesRead","Update":72,"Value":72,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.jvmGCTime","Update":208,"Value":208,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.resultSize","Update":1448,"Value":1448,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.executorCpuTime","Update":1105071149,"Value":1105071149,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorRunTime","Update":2307,"Value":2307,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorDeserializeCpuTime","Update":651096062,"Value":651096062,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeTime","Update":1322,"Value":1322,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1322,"Executor Deserialize CPU Time":651096062,"Executor Run Time":2307,"Executor CPU Time":1105071149,"Result Size":1448,"JVM GC Time":208,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":72,"Records Read":4},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"driver","Stage ID":0,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":256071440,"JVMOffHeapMemory":92211424,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":333371,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":333371,"OffHeapUnifiedMemory":0,"DirectPoolMemory":134726,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":4926242816,"ProcessTreeJVMRSSMemory":525656064,"ProcessTreePythonVMemory":408375296,"ProcessTreePythonRSSMemory":40284160,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"1","Stage ID":0,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":182536928,"JVMOffHeapMemory":58263224,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":1086483,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":1086483,"OffHeapUnifiedMemory":0,"DirectPoolMemory":20304,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":3009855488,"ProcessTreeJVMRSSMemory":404488192,"ProcessTreePythonVMemory":626200576,"ProcessTreePythonRSSMemory":69218304,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"first at BisectingKMeans.scala:163","Number of Tasks":1,"RDD Info":[{"RDD ID":4,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"2\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:163","Parent IDs":[3],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":3,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.first(RDD.scala:1377)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:163)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538440969044,"Completion Time":1538440977644,"Accumulables":[{"ID":23,"Name":"internal.metrics.input.recordsRead","Value":4,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorDeserializeCpuTime","Value":651096062,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.resultSize","Value":1448,"Internal":true,"Count Failed Values":true},{"ID":22,"Name":"internal.metrics.input.bytesRead","Value":72,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.executorCpuTime","Value":1105071149,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.resultSerializationTime","Value":1,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeTime","Value":1322,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorRunTime","Value":2307,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.jvmGCTime","Value":208,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerJobEnd","Job ID":0,"Completion Time":1538440977650,"Job Result":{"Result":"JobSucceeded"}} +{"Event":"SparkListenerJobStart","Job ID":1,"Submission Time":1538440977784,"Stage Infos":[{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"map at BisectingKMeans.scala:170","Number of Tasks":2,"RDD Info":[{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:170","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:169","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":6,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"10\",\"name\":\"zip\"}","Callsite":"zip at BisectingKMeans.scala:169","Parent IDs":[3,5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":3,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"9\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:168","Parent IDs":[3],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:170)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Accumulables":[]},{"Stage ID":2,"Stage Attempt ID":0,"Stage Name":"collect at BisectingKMeans.scala:304","Number of Tasks":2,"RDD Info":[{"RDD ID":10,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"14\",\"name\":\"mapValues\"}","Callsite":"mapValues at BisectingKMeans.scala:303","Parent IDs":[9],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":9,"Name":"ShuffledRDD","Scope":"{\"id\":\"13\",\"name\":\"aggregateByKey\"}","Callsite":"aggregateByKey at BisectingKMeans.scala:300","Parent IDs":[8],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[1],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:944)\norg.apache.spark.mllib.clustering.BisectingKMeans$.org$apache$spark$mllib$clustering$BisectingKMeans$$summarize(BisectingKMeans.scala:304)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:171)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Accumulables":[]}],"Stage IDs":[1,2],"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"15\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"map at BisectingKMeans.scala:170","Number of Tasks":2,"RDD Info":[{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:170","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:169","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":6,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"10\",\"name\":\"zip\"}","Callsite":"zip at BisectingKMeans.scala:169","Parent IDs":[3,5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":3,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"9\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:168","Parent IDs":[3],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:170)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538440977793,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"15\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":1,"Index":0,"Attempt":0,"Launch Time":1538440977816,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":2,"Index":1,"Attempt":0,"Launch Time":1538440978659,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":1,"Index":0,"Attempt":0,"Launch Time":1538440977816,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538440978683,"Failed":false,"Killed":false,"Accumulables":[{"ID":48,"Name":"internal.metrics.input.recordsRead","Update":8,"Value":8,"Internal":true,"Count Failed Values":true},{"ID":47,"Name":"internal.metrics.input.bytesRead","Update":72,"Value":72,"Internal":true,"Count Failed Values":true},{"ID":46,"Name":"internal.metrics.shuffle.write.writeTime","Update":13535058,"Value":13535058,"Internal":true,"Count Failed Values":true},{"ID":45,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":44,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":178,"Value":178,"Internal":true,"Count Failed Values":true},{"ID":35,"Name":"internal.metrics.peakExecutionMemory","Update":1088,"Value":1088,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":30,"Name":"internal.metrics.resultSize","Update":1662,"Value":1662,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.executorCpuTime","Update":202227536,"Value":202227536,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorRunTime","Update":705,"Value":705,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorDeserializeCpuTime","Update":65694833,"Value":65694833,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeTime","Update":119,"Value":119,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":119,"Executor Deserialize CPU Time":65694833,"Executor Run Time":705,"Executor CPU Time":202227536,"Result Size":1662,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":178,"Shuffle Write Time":13535058,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":72,"Records Read":8},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":2,"Index":1,"Attempt":0,"Launch Time":1538440978659,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538440978820,"Failed":false,"Killed":false,"Accumulables":[{"ID":48,"Name":"internal.metrics.input.recordsRead","Update":4,"Value":12,"Internal":true,"Count Failed Values":true},{"ID":47,"Name":"internal.metrics.input.bytesRead","Update":72,"Value":144,"Internal":true,"Count Failed Values":true},{"ID":46,"Name":"internal.metrics.shuffle.write.writeTime","Update":289555,"Value":13824613,"Internal":true,"Count Failed Values":true},{"ID":45,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":44,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":178,"Value":356,"Internal":true,"Count Failed Values":true},{"ID":35,"Name":"internal.metrics.peakExecutionMemory","Update":1088,"Value":2176,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":30,"Name":"internal.metrics.resultSize","Update":1662,"Value":3324,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.executorCpuTime","Update":36560031,"Value":238787567,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorRunTime","Update":120,"Value":825,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorDeserializeCpuTime","Update":7042587,"Value":72737420,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeTime","Update":8,"Value":127,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":8,"Executor Deserialize CPU Time":7042587,"Executor Run Time":120,"Executor CPU Time":36560031,"Result Size":1662,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":178,"Shuffle Write Time":289555,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":72,"Records Read":4},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"driver","Stage ID":1,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":292935952,"JVMOffHeapMemory":95141200,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":351534,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":351534,"OffHeapUnifiedMemory":0,"DirectPoolMemory":135031,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":4929392640,"ProcessTreeJVMRSSMemory":539996160,"ProcessTreePythonVMemory":408375296,"ProcessTreePythonRSSMemory":40284160,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"1","Stage ID":1,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":215586960,"JVMOffHeapMemory":60718904,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":1492038,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":1492038,"OffHeapUnifiedMemory":0,"DirectPoolMemory":20637,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":3014057984,"ProcessTreeJVMRSSMemory":422723584,"ProcessTreePythonVMemory":958914560,"ProcessTreePythonRSSMemory":106622976,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"map at BisectingKMeans.scala:170","Number of Tasks":2,"RDD Info":[{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:170","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:169","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":6,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"10\",\"name\":\"zip\"}","Callsite":"zip at BisectingKMeans.scala:169","Parent IDs":[3,5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":3,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"9\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:168","Parent IDs":[3],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:170)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538440977793,"Completion Time":1538440978821,"Accumulables":[{"ID":26,"Name":"internal.metrics.executorDeserializeTime","Value":127,"Internal":true,"Count Failed Values":true},{"ID":35,"Name":"internal.metrics.peakExecutionMemory","Value":2176,"Internal":true,"Count Failed Values":true},{"ID":44,"Name":"internal.metrics.shuffle.write.bytesWritten","Value":356,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.executorCpuTime","Value":238787567,"Internal":true,"Count Failed Values":true},{"ID":47,"Name":"internal.metrics.input.bytesRead","Value":144,"Internal":true,"Count Failed Values":true},{"ID":46,"Name":"internal.metrics.shuffle.write.writeTime","Value":13824613,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorRunTime","Value":825,"Internal":true,"Count Failed Values":true},{"ID":45,"Name":"internal.metrics.shuffle.write.recordsWritten","Value":2,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorDeserializeCpuTime","Value":72737420,"Internal":true,"Count Failed Values":true},{"ID":48,"Name":"internal.metrics.input.recordsRead","Value":12,"Internal":true,"Count Failed Values":true},{"ID":30,"Name":"internal.metrics.resultSize","Value":3324,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":2,"Stage Attempt ID":0,"Stage Name":"collect at BisectingKMeans.scala:304","Number of Tasks":2,"RDD Info":[{"RDD ID":10,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"14\",\"name\":\"mapValues\"}","Callsite":"mapValues at BisectingKMeans.scala:303","Parent IDs":[9],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":9,"Name":"ShuffledRDD","Scope":"{\"id\":\"13\",\"name\":\"aggregateByKey\"}","Callsite":"aggregateByKey at BisectingKMeans.scala:300","Parent IDs":[8],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[1],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:944)\norg.apache.spark.mllib.clustering.BisectingKMeans$.org$apache$spark$mllib$clustering$BisectingKMeans$$summarize(BisectingKMeans.scala:304)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:171)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538440978830,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"15\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":2,"Stage Attempt ID":0,"Task Info":{"Task ID":3,"Index":1,"Attempt":0,"Launch Time":1538440978844,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":2,"Stage Attempt ID":0,"Task Info":{"Task ID":4,"Index":0,"Attempt":0,"Launch Time":1538440979033,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":2,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":3,"Index":1,"Attempt":0,"Launch Time":1538440978844,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538440979050,"Failed":false,"Killed":false,"Accumulables":[{"ID":68,"Name":"internal.metrics.shuffle.read.recordsRead","Update":2,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":67,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":66,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":356,"Value":356,"Internal":true,"Count Failed Values":true},{"ID":65,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":64,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":63,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":2,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":62,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":60,"Name":"internal.metrics.peakExecutionMemory","Update":992,"Value":992,"Internal":true,"Count Failed Values":true},{"ID":59,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":58,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":55,"Name":"internal.metrics.resultSize","Update":1828,"Value":1828,"Internal":true,"Count Failed Values":true},{"ID":54,"Name":"internal.metrics.executorCpuTime","Update":88389028,"Value":88389028,"Internal":true,"Count Failed Values":true},{"ID":53,"Name":"internal.metrics.executorRunTime","Update":122,"Value":122,"Internal":true,"Count Failed Values":true},{"ID":52,"Name":"internal.metrics.executorDeserializeCpuTime","Update":27126551,"Value":27126551,"Internal":true,"Count Failed Values":true},{"ID":51,"Name":"internal.metrics.executorDeserializeTime","Update":45,"Value":45,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":45,"Executor Deserialize CPU Time":27126551,"Executor Run Time":122,"Executor CPU Time":88389028,"Result Size":1828,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":2,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":356,"Total Records Read":2},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":2,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":4,"Index":0,"Attempt":0,"Launch Time":1538440979033,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538440979084,"Failed":false,"Killed":false,"Accumulables":[{"ID":68,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":67,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":66,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":356,"Internal":true,"Count Failed Values":true},{"ID":65,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":64,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":63,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":62,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":60,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":992,"Internal":true,"Count Failed Values":true},{"ID":59,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":58,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":55,"Name":"internal.metrics.resultSize","Update":1706,"Value":3534,"Internal":true,"Count Failed Values":true},{"ID":54,"Name":"internal.metrics.executorCpuTime","Update":15055355,"Value":103444383,"Internal":true,"Count Failed Values":true},{"ID":53,"Name":"internal.metrics.executorRunTime","Update":26,"Value":148,"Internal":true,"Count Failed Values":true},{"ID":52,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4722422,"Value":31848973,"Internal":true,"Count Failed Values":true},{"ID":51,"Name":"internal.metrics.executorDeserializeTime","Update":5,"Value":50,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":5,"Executor Deserialize CPU Time":4722422,"Executor Run Time":26,"Executor CPU Time":15055355,"Result Size":1706,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"driver","Stage ID":2,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":303792496,"JVMOffHeapMemory":95545824,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":371127,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":371127,"OffHeapUnifiedMemory":0,"DirectPoolMemory":135031,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":4931497984,"ProcessTreeJVMRSSMemory":549777408,"ProcessTreePythonVMemory":408375296,"ProcessTreePythonRSSMemory":40284160,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"1","Stage ID":2,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":227393200,"JVMOffHeapMemory":61799392,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":463135,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":463135,"OffHeapUnifiedMemory":0,"DirectPoolMemory":20637,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":3016163328,"ProcessTreeJVMRSSMemory":436539392,"ProcessTreePythonVMemory":958914560,"ProcessTreePythonRSSMemory":106622976,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":2,"Stage Attempt ID":0,"Stage Name":"collect at BisectingKMeans.scala:304","Number of Tasks":2,"RDD Info":[{"RDD ID":10,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"14\",\"name\":\"mapValues\"}","Callsite":"mapValues at BisectingKMeans.scala:303","Parent IDs":[9],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":9,"Name":"ShuffledRDD","Scope":"{\"id\":\"13\",\"name\":\"aggregateByKey\"}","Callsite":"aggregateByKey at BisectingKMeans.scala:300","Parent IDs":[8],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[1],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:944)\norg.apache.spark.mllib.clustering.BisectingKMeans$.org$apache$spark$mllib$clustering$BisectingKMeans$$summarize(BisectingKMeans.scala:304)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:171)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538440978830,"Completion Time":1538440979086,"Accumulables":[{"ID":68,"Name":"internal.metrics.shuffle.read.recordsRead","Value":2,"Internal":true,"Count Failed Values":true},{"ID":59,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":53,"Name":"internal.metrics.executorRunTime","Value":148,"Internal":true,"Count Failed Values":true},{"ID":62,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Value":0,"Internal":true,"Count Failed Values":true},{"ID":65,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Value":0,"Internal":true,"Count Failed Values":true},{"ID":55,"Name":"internal.metrics.resultSize","Value":3534,"Internal":true,"Count Failed Values":true},{"ID":64,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Value":0,"Internal":true,"Count Failed Values":true},{"ID":67,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Value":0,"Internal":true,"Count Failed Values":true},{"ID":58,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":52,"Name":"internal.metrics.executorDeserializeCpuTime","Value":31848973,"Internal":true,"Count Failed Values":true},{"ID":60,"Name":"internal.metrics.peakExecutionMemory","Value":992,"Internal":true,"Count Failed Values":true},{"ID":54,"Name":"internal.metrics.executorCpuTime","Value":103444383,"Internal":true,"Count Failed Values":true},{"ID":63,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Value":2,"Internal":true,"Count Failed Values":true},{"ID":66,"Name":"internal.metrics.shuffle.read.localBytesRead","Value":356,"Internal":true,"Count Failed Values":true},{"ID":51,"Name":"internal.metrics.executorDeserializeTime","Value":50,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerJobEnd","Job ID":1,"Completion Time":1538440979087,"Job Result":{"Result":"JobSucceeded"}} +{"Event":"SparkListenerJobStart","Job ID":2,"Submission Time":1538440979161,"Stage Infos":[{"Stage ID":3,"Stage Attempt ID":0,"Stage Name":"filter at BisectingKMeans.scala:213","Number of Tasks":2,"RDD Info":[{"RDD ID":12,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"25\",\"name\":\"filter\"}","Callsite":"filter at BisectingKMeans.scala:213","Parent IDs":[11],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:170","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:169","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":11,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"24\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:372","Parent IDs":[8],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":6,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"10\",\"name\":\"zip\"}","Callsite":"zip at BisectingKMeans.scala:169","Parent IDs":[3,5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":3,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"9\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:168","Parent IDs":[3],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.filter(RDD.scala:387)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:213)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Accumulables":[]},{"Stage ID":4,"Stage Attempt ID":0,"Stage Name":"collect at BisectingKMeans.scala:304","Number of Tasks":2,"RDD Info":[{"RDD ID":14,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"27\",\"name\":\"mapValues\"}","Callsite":"mapValues at BisectingKMeans.scala:303","Parent IDs":[13],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":13,"Name":"ShuffledRDD","Scope":"{\"id\":\"26\",\"name\":\"aggregateByKey\"}","Callsite":"aggregateByKey at BisectingKMeans.scala:300","Parent IDs":[12],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[3],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:944)\norg.apache.spark.mllib.clustering.BisectingKMeans$.org$apache$spark$mllib$clustering$BisectingKMeans$$summarize(BisectingKMeans.scala:304)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:216)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Accumulables":[]}],"Stage IDs":[3,4],"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"28\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":3,"Stage Attempt ID":0,"Stage Name":"filter at BisectingKMeans.scala:213","Number of Tasks":2,"RDD Info":[{"RDD ID":12,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"25\",\"name\":\"filter\"}","Callsite":"filter at BisectingKMeans.scala:213","Parent IDs":[11],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:170","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:169","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":11,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"24\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:372","Parent IDs":[8],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":6,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"10\",\"name\":\"zip\"}","Callsite":"zip at BisectingKMeans.scala:169","Parent IDs":[3,5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":3,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"9\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:168","Parent IDs":[3],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.filter(RDD.scala:387)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:213)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538440979163,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"28\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":3,"Stage Attempt ID":0,"Task Info":{"Task ID":5,"Index":0,"Attempt":0,"Launch Time":1538440979184,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":3,"Stage Attempt ID":0,"Task Info":{"Task ID":6,"Index":1,"Attempt":0,"Launch Time":1538440979344,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":3,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":5,"Index":0,"Attempt":0,"Launch Time":1538440979184,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538440979348,"Failed":false,"Killed":false,"Accumulables":[{"ID":98,"Name":"internal.metrics.input.recordsRead","Update":8,"Value":8,"Internal":true,"Count Failed Values":true},{"ID":97,"Name":"internal.metrics.input.bytesRead","Update":72,"Value":72,"Internal":true,"Count Failed Values":true},{"ID":96,"Name":"internal.metrics.shuffle.write.writeTime","Update":259310,"Value":259310,"Internal":true,"Count Failed Values":true},{"ID":95,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":2,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":94,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":355,"Value":355,"Internal":true,"Count Failed Values":true},{"ID":85,"Name":"internal.metrics.peakExecutionMemory","Update":1264,"Value":1264,"Internal":true,"Count Failed Values":true},{"ID":84,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":83,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":80,"Name":"internal.metrics.resultSize","Update":1662,"Value":1662,"Internal":true,"Count Failed Values":true},{"ID":79,"Name":"internal.metrics.executorCpuTime","Update":40081727,"Value":40081727,"Internal":true,"Count Failed Values":true},{"ID":78,"Name":"internal.metrics.executorRunTime","Update":98,"Value":98,"Internal":true,"Count Failed Values":true},{"ID":77,"Name":"internal.metrics.executorDeserializeCpuTime","Update":24271689,"Value":24271689,"Internal":true,"Count Failed Values":true},{"ID":76,"Name":"internal.metrics.executorDeserializeTime","Update":39,"Value":39,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":39,"Executor Deserialize CPU Time":24271689,"Executor Run Time":98,"Executor CPU Time":40081727,"Result Size":1662,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":355,"Shuffle Write Time":259310,"Shuffle Records Written":2},"Input Metrics":{"Bytes Read":72,"Records Read":8},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":3,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":6,"Index":1,"Attempt":0,"Launch Time":1538440979344,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538440979441,"Failed":false,"Killed":false,"Accumulables":[{"ID":98,"Name":"internal.metrics.input.recordsRead","Update":4,"Value":12,"Internal":true,"Count Failed Values":true},{"ID":97,"Name":"internal.metrics.input.bytesRead","Update":36,"Value":108,"Internal":true,"Count Failed Values":true},{"ID":96,"Name":"internal.metrics.shuffle.write.writeTime","Update":221381,"Value":480691,"Internal":true,"Count Failed Values":true},{"ID":95,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":94,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":178,"Value":533,"Internal":true,"Count Failed Values":true},{"ID":85,"Name":"internal.metrics.peakExecutionMemory","Update":1088,"Value":2352,"Internal":true,"Count Failed Values":true},{"ID":84,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":83,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":80,"Name":"internal.metrics.resultSize","Update":1662,"Value":3324,"Internal":true,"Count Failed Values":true},{"ID":79,"Name":"internal.metrics.executorCpuTime","Update":23089017,"Value":63170744,"Internal":true,"Count Failed Values":true},{"ID":78,"Name":"internal.metrics.executorRunTime","Update":74,"Value":172,"Internal":true,"Count Failed Values":true},{"ID":77,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3471167,"Value":27742856,"Internal":true,"Count Failed Values":true},{"ID":76,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":43,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":3471167,"Executor Run Time":74,"Executor CPU Time":23089017,"Result Size":1662,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":178,"Shuffle Write Time":221381,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":36,"Records Read":4},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"driver","Stage ID":3,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":318926040,"JVMOffHeapMemory":96521592,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":391718,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":391718,"OffHeapUnifiedMemory":0,"DirectPoolMemory":135031,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":4932550656,"ProcessTreeJVMRSSMemory":569753600,"ProcessTreePythonVMemory":408375296,"ProcessTreePythonRSSMemory":40284160,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"1","Stage ID":3,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":236711480,"JVMOffHeapMemory":62683008,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":483726,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":483726,"OffHeapUnifiedMemory":0,"DirectPoolMemory":20922,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":3019313152,"ProcessTreeJVMRSSMemory":445640704,"ProcessTreePythonVMemory":958914560,"ProcessTreePythonRSSMemory":106622976,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":3,"Stage Attempt ID":0,"Stage Name":"filter at BisectingKMeans.scala:213","Number of Tasks":2,"RDD Info":[{"RDD ID":12,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"25\",\"name\":\"filter\"}","Callsite":"filter at BisectingKMeans.scala:213","Parent IDs":[11],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:170","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:169","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":11,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"24\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:372","Parent IDs":[8],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":6,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"10\",\"name\":\"zip\"}","Callsite":"zip at BisectingKMeans.scala:169","Parent IDs":[3,5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":3,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"9\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:168","Parent IDs":[3],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.filter(RDD.scala:387)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:213)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538440979163,"Completion Time":1538440979444,"Accumulables":[{"ID":83,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":95,"Name":"internal.metrics.shuffle.write.recordsWritten","Value":3,"Internal":true,"Count Failed Values":true},{"ID":77,"Name":"internal.metrics.executorDeserializeCpuTime","Value":27742856,"Internal":true,"Count Failed Values":true},{"ID":80,"Name":"internal.metrics.resultSize","Value":3324,"Internal":true,"Count Failed Values":true},{"ID":98,"Name":"internal.metrics.input.recordsRead","Value":12,"Internal":true,"Count Failed Values":true},{"ID":85,"Name":"internal.metrics.peakExecutionMemory","Value":2352,"Internal":true,"Count Failed Values":true},{"ID":94,"Name":"internal.metrics.shuffle.write.bytesWritten","Value":533,"Internal":true,"Count Failed Values":true},{"ID":76,"Name":"internal.metrics.executorDeserializeTime","Value":43,"Internal":true,"Count Failed Values":true},{"ID":79,"Name":"internal.metrics.executorCpuTime","Value":63170744,"Internal":true,"Count Failed Values":true},{"ID":97,"Name":"internal.metrics.input.bytesRead","Value":108,"Internal":true,"Count Failed Values":true},{"ID":96,"Name":"internal.metrics.shuffle.write.writeTime","Value":480691,"Internal":true,"Count Failed Values":true},{"ID":78,"Name":"internal.metrics.executorRunTime","Value":172,"Internal":true,"Count Failed Values":true},{"ID":84,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":4,"Stage Attempt ID":0,"Stage Name":"collect at BisectingKMeans.scala:304","Number of Tasks":2,"RDD Info":[{"RDD ID":14,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"27\",\"name\":\"mapValues\"}","Callsite":"mapValues at BisectingKMeans.scala:303","Parent IDs":[13],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":13,"Name":"ShuffledRDD","Scope":"{\"id\":\"26\",\"name\":\"aggregateByKey\"}","Callsite":"aggregateByKey at BisectingKMeans.scala:300","Parent IDs":[12],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[3],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:944)\norg.apache.spark.mllib.clustering.BisectingKMeans$.org$apache$spark$mllib$clustering$BisectingKMeans$$summarize(BisectingKMeans.scala:304)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:216)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538440979446,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"28\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":4,"Stage Attempt ID":0,"Task Info":{"Task ID":7,"Index":0,"Attempt":0,"Launch Time":1538440979462,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":4,"Stage Attempt ID":0,"Task Info":{"Task ID":8,"Index":1,"Attempt":0,"Launch Time":1538440979527,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":4,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":7,"Index":0,"Attempt":0,"Launch Time":1538440979462,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538440979528,"Failed":false,"Killed":false,"Accumulables":[{"ID":118,"Name":"internal.metrics.shuffle.read.recordsRead","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":117,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":116,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":178,"Value":178,"Internal":true,"Count Failed Values":true},{"ID":115,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":114,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":113,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":112,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":110,"Name":"internal.metrics.peakExecutionMemory","Update":800,"Value":800,"Internal":true,"Count Failed Values":true},{"ID":109,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":108,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":105,"Name":"internal.metrics.resultSize","Update":1828,"Value":1828,"Internal":true,"Count Failed Values":true},{"ID":104,"Name":"internal.metrics.executorCpuTime","Update":17714408,"Value":17714408,"Internal":true,"Count Failed Values":true},{"ID":103,"Name":"internal.metrics.executorRunTime","Update":30,"Value":30,"Internal":true,"Count Failed Values":true},{"ID":102,"Name":"internal.metrics.executorDeserializeCpuTime","Update":12579502,"Value":12579502,"Internal":true,"Count Failed Values":true},{"ID":101,"Name":"internal.metrics.executorDeserializeTime","Update":22,"Value":22,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":22,"Executor Deserialize CPU Time":12579502,"Executor Run Time":30,"Executor CPU Time":17714408,"Result Size":1828,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":1,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":178,"Total Records Read":1},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":4,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":8,"Index":1,"Attempt":0,"Launch Time":1538440979527,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538440979572,"Failed":false,"Killed":false,"Accumulables":[{"ID":118,"Name":"internal.metrics.shuffle.read.recordsRead","Update":2,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":117,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":116,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":355,"Value":533,"Internal":true,"Count Failed Values":true},{"ID":115,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":114,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":113,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":2,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":112,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":110,"Name":"internal.metrics.peakExecutionMemory","Update":992,"Value":1792,"Internal":true,"Count Failed Values":true},{"ID":109,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":108,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":105,"Name":"internal.metrics.resultSize","Update":1828,"Value":3656,"Internal":true,"Count Failed Values":true},{"ID":104,"Name":"internal.metrics.executorCpuTime","Update":16462125,"Value":34176533,"Internal":true,"Count Failed Values":true},{"ID":103,"Name":"internal.metrics.executorRunTime","Update":16,"Value":46,"Internal":true,"Count Failed Values":true},{"ID":102,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3190663,"Value":15770165,"Internal":true,"Count Failed Values":true},{"ID":101,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":26,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":3190663,"Executor Run Time":16,"Executor CPU Time":16462125,"Result Size":1828,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":2,"Fetch Wait Time":1,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":355,"Total Records Read":2},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"driver","Stage ID":4,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":329919832,"JVMOffHeapMemory":96756344,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":413740,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":413740,"OffHeapUnifiedMemory":0,"DirectPoolMemory":135031,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":4935208960,"ProcessTreeJVMRSSMemory":585252864,"ProcessTreePythonVMemory":408375296,"ProcessTreePythonRSSMemory":40284160,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"1","Stage ID":4,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":242876648,"JVMOffHeapMemory":62975784,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":505748,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":505748,"OffHeapUnifiedMemory":0,"DirectPoolMemory":20922,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":3019313152,"ProcessTreeJVMRSSMemory":451244032,"ProcessTreePythonVMemory":958914560,"ProcessTreePythonRSSMemory":106622976,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":4,"Stage Attempt ID":0,"Stage Name":"collect at BisectingKMeans.scala:304","Number of Tasks":2,"RDD Info":[{"RDD ID":14,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"27\",\"name\":\"mapValues\"}","Callsite":"mapValues at BisectingKMeans.scala:303","Parent IDs":[13],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":13,"Name":"ShuffledRDD","Scope":"{\"id\":\"26\",\"name\":\"aggregateByKey\"}","Callsite":"aggregateByKey at BisectingKMeans.scala:300","Parent IDs":[12],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[3],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:944)\norg.apache.spark.mllib.clustering.BisectingKMeans$.org$apache$spark$mllib$clustering$BisectingKMeans$$summarize(BisectingKMeans.scala:304)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:216)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538440979446,"Completion Time":1538440979573,"Accumulables":[{"ID":101,"Name":"internal.metrics.executorDeserializeTime","Value":26,"Internal":true,"Count Failed Values":true},{"ID":110,"Name":"internal.metrics.peakExecutionMemory","Value":1792,"Internal":true,"Count Failed Values":true},{"ID":104,"Name":"internal.metrics.executorCpuTime","Value":34176533,"Internal":true,"Count Failed Values":true},{"ID":113,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Value":3,"Internal":true,"Count Failed Values":true},{"ID":116,"Name":"internal.metrics.shuffle.read.localBytesRead","Value":533,"Internal":true,"Count Failed Values":true},{"ID":115,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Value":0,"Internal":true,"Count Failed Values":true},{"ID":118,"Name":"internal.metrics.shuffle.read.recordsRead","Value":3,"Internal":true,"Count Failed Values":true},{"ID":109,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":103,"Name":"internal.metrics.executorRunTime","Value":46,"Internal":true,"Count Failed Values":true},{"ID":112,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Value":0,"Internal":true,"Count Failed Values":true},{"ID":105,"Name":"internal.metrics.resultSize","Value":3656,"Internal":true,"Count Failed Values":true},{"ID":114,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Value":0,"Internal":true,"Count Failed Values":true},{"ID":117,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Value":1,"Internal":true,"Count Failed Values":true},{"ID":108,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":102,"Name":"internal.metrics.executorDeserializeCpuTime","Value":15770165,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerJobEnd","Job ID":2,"Completion Time":1538440979573,"Job Result":{"Result":"JobSucceeded"}} +{"Event":"SparkListenerJobStart","Job ID":3,"Submission Time":1538440979609,"Stage Infos":[{"Stage ID":5,"Stage Attempt ID":0,"Stage Name":"filter at BisectingKMeans.scala:213","Number of Tasks":2,"RDD Info":[{"RDD ID":16,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"35\",\"name\":\"filter\"}","Callsite":"filter at BisectingKMeans.scala:213","Parent IDs":[15],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:170","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:169","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":15,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"34\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:372","Parent IDs":[8],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":6,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"10\",\"name\":\"zip\"}","Callsite":"zip at BisectingKMeans.scala:169","Parent IDs":[3,5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":3,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"9\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:168","Parent IDs":[3],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.filter(RDD.scala:387)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:213)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Accumulables":[]},{"Stage ID":6,"Stage Attempt ID":0,"Stage Name":"collect at BisectingKMeans.scala:304","Number of Tasks":2,"RDD Info":[{"RDD ID":18,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"37\",\"name\":\"mapValues\"}","Callsite":"mapValues at BisectingKMeans.scala:303","Parent IDs":[17],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":17,"Name":"ShuffledRDD","Scope":"{\"id\":\"36\",\"name\":\"aggregateByKey\"}","Callsite":"aggregateByKey at BisectingKMeans.scala:300","Parent IDs":[16],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[5],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:944)\norg.apache.spark.mllib.clustering.BisectingKMeans$.org$apache$spark$mllib$clustering$BisectingKMeans$$summarize(BisectingKMeans.scala:304)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:216)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Accumulables":[]}],"Stage IDs":[5,6],"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"38\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":5,"Stage Attempt ID":0,"Stage Name":"filter at BisectingKMeans.scala:213","Number of Tasks":2,"RDD Info":[{"RDD ID":16,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"35\",\"name\":\"filter\"}","Callsite":"filter at BisectingKMeans.scala:213","Parent IDs":[15],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:170","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:169","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":15,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"34\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:372","Parent IDs":[8],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":6,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"10\",\"name\":\"zip\"}","Callsite":"zip at BisectingKMeans.scala:169","Parent IDs":[3,5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":3,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"9\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:168","Parent IDs":[3],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.filter(RDD.scala:387)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:213)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538440979619,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"38\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":5,"Stage Attempt ID":0,"Task Info":{"Task ID":9,"Index":0,"Attempt":0,"Launch Time":1538440979638,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":5,"Stage Attempt ID":0,"Task Info":{"Task ID":10,"Index":1,"Attempt":0,"Launch Time":1538440979754,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":5,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":9,"Index":0,"Attempt":0,"Launch Time":1538440979638,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538440979756,"Failed":false,"Killed":false,"Accumulables":[{"ID":148,"Name":"internal.metrics.input.recordsRead","Update":8,"Value":8,"Internal":true,"Count Failed Values":true},{"ID":147,"Name":"internal.metrics.input.bytesRead","Update":72,"Value":72,"Internal":true,"Count Failed Values":true},{"ID":146,"Name":"internal.metrics.shuffle.write.writeTime","Update":272852,"Value":272852,"Internal":true,"Count Failed Values":true},{"ID":145,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":2,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":144,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":355,"Value":355,"Internal":true,"Count Failed Values":true},{"ID":135,"Name":"internal.metrics.peakExecutionMemory","Update":1264,"Value":1264,"Internal":true,"Count Failed Values":true},{"ID":134,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":133,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":130,"Name":"internal.metrics.resultSize","Update":1662,"Value":1662,"Internal":true,"Count Failed Values":true},{"ID":129,"Name":"internal.metrics.executorCpuTime","Update":23042622,"Value":23042622,"Internal":true,"Count Failed Values":true},{"ID":128,"Name":"internal.metrics.executorRunTime","Update":76,"Value":76,"Internal":true,"Count Failed Values":true},{"ID":127,"Name":"internal.metrics.executorDeserializeCpuTime","Update":13112180,"Value":13112180,"Internal":true,"Count Failed Values":true},{"ID":126,"Name":"internal.metrics.executorDeserializeTime","Update":28,"Value":28,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":28,"Executor Deserialize CPU Time":13112180,"Executor Run Time":76,"Executor CPU Time":23042622,"Result Size":1662,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":355,"Shuffle Write Time":272852,"Shuffle Records Written":2},"Input Metrics":{"Bytes Read":72,"Records Read":8},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":5,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":10,"Index":1,"Attempt":0,"Launch Time":1538440979754,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538440979851,"Failed":false,"Killed":false,"Accumulables":[{"ID":148,"Name":"internal.metrics.input.recordsRead","Update":4,"Value":12,"Internal":true,"Count Failed Values":true},{"ID":147,"Name":"internal.metrics.input.bytesRead","Update":36,"Value":108,"Internal":true,"Count Failed Values":true},{"ID":146,"Name":"internal.metrics.shuffle.write.writeTime","Update":229882,"Value":502734,"Internal":true,"Count Failed Values":true},{"ID":145,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":144,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":178,"Value":533,"Internal":true,"Count Failed Values":true},{"ID":135,"Name":"internal.metrics.peakExecutionMemory","Update":1088,"Value":2352,"Internal":true,"Count Failed Values":true},{"ID":134,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":133,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":130,"Name":"internal.metrics.resultSize","Update":1662,"Value":3324,"Internal":true,"Count Failed Values":true},{"ID":129,"Name":"internal.metrics.executorCpuTime","Update":22093052,"Value":45135674,"Internal":true,"Count Failed Values":true},{"ID":128,"Name":"internal.metrics.executorRunTime","Update":81,"Value":157,"Internal":true,"Count Failed Values":true},{"ID":127,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3862579,"Value":16974759,"Internal":true,"Count Failed Values":true},{"ID":126,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":32,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":3862579,"Executor Run Time":81,"Executor CPU Time":22093052,"Result Size":1662,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":178,"Shuffle Write Time":229882,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":36,"Records Read":4},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"driver","Stage ID":5,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":341682304,"JVMOffHeapMemory":97514672,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":434309,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":434309,"OffHeapUnifiedMemory":0,"DirectPoolMemory":135031,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":4935254016,"ProcessTreeJVMRSSMemory":597999616,"ProcessTreePythonVMemory":408375296,"ProcessTreePythonRSSMemory":40284160,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"1","Stage ID":5,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":252029672,"JVMOffHeapMemory":63463032,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":526317,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":526317,"OffHeapUnifiedMemory":0,"DirectPoolMemory":21041,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":3020365824,"ProcessTreeJVMRSSMemory":458960896,"ProcessTreePythonVMemory":958914560,"ProcessTreePythonRSSMemory":106622976,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":5,"Stage Attempt ID":0,"Stage Name":"filter at BisectingKMeans.scala:213","Number of Tasks":2,"RDD Info":[{"RDD ID":16,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"35\",\"name\":\"filter\"}","Callsite":"filter at BisectingKMeans.scala:213","Parent IDs":[15],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:170","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:169","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":15,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"34\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:372","Parent IDs":[8],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":6,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"10\",\"name\":\"zip\"}","Callsite":"zip at BisectingKMeans.scala:169","Parent IDs":[3,5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":3,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"9\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:168","Parent IDs":[3],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.filter(RDD.scala:387)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:213)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538440979619,"Completion Time":1538440979852,"Accumulables":[{"ID":146,"Name":"internal.metrics.shuffle.write.writeTime","Value":502734,"Internal":true,"Count Failed Values":true},{"ID":128,"Name":"internal.metrics.executorRunTime","Value":157,"Internal":true,"Count Failed Values":true},{"ID":134,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":133,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":127,"Name":"internal.metrics.executorDeserializeCpuTime","Value":16974759,"Internal":true,"Count Failed Values":true},{"ID":145,"Name":"internal.metrics.shuffle.write.recordsWritten","Value":3,"Internal":true,"Count Failed Values":true},{"ID":130,"Name":"internal.metrics.resultSize","Value":3324,"Internal":true,"Count Failed Values":true},{"ID":148,"Name":"internal.metrics.input.recordsRead","Value":12,"Internal":true,"Count Failed Values":true},{"ID":129,"Name":"internal.metrics.executorCpuTime","Value":45135674,"Internal":true,"Count Failed Values":true},{"ID":147,"Name":"internal.metrics.input.bytesRead","Value":108,"Internal":true,"Count Failed Values":true},{"ID":126,"Name":"internal.metrics.executorDeserializeTime","Value":32,"Internal":true,"Count Failed Values":true},{"ID":135,"Name":"internal.metrics.peakExecutionMemory","Value":2352,"Internal":true,"Count Failed Values":true},{"ID":144,"Name":"internal.metrics.shuffle.write.bytesWritten","Value":533,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":6,"Stage Attempt ID":0,"Stage Name":"collect at BisectingKMeans.scala:304","Number of Tasks":2,"RDD Info":[{"RDD ID":18,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"37\",\"name\":\"mapValues\"}","Callsite":"mapValues at BisectingKMeans.scala:303","Parent IDs":[17],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":17,"Name":"ShuffledRDD","Scope":"{\"id\":\"36\",\"name\":\"aggregateByKey\"}","Callsite":"aggregateByKey at BisectingKMeans.scala:300","Parent IDs":[16],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[5],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:944)\norg.apache.spark.mllib.clustering.BisectingKMeans$.org$apache$spark$mllib$clustering$BisectingKMeans$$summarize(BisectingKMeans.scala:304)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:216)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538440979854,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"38\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":6,"Stage Attempt ID":0,"Task Info":{"Task ID":11,"Index":0,"Attempt":0,"Launch Time":1538440979869,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":6,"Stage Attempt ID":0,"Task Info":{"Task ID":12,"Index":1,"Attempt":0,"Launch Time":1538440979920,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":6,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":11,"Index":0,"Attempt":0,"Launch Time":1538440979869,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538440979921,"Failed":false,"Killed":false,"Accumulables":[{"ID":168,"Name":"internal.metrics.shuffle.read.recordsRead","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":167,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":166,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":178,"Value":178,"Internal":true,"Count Failed Values":true},{"ID":165,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":164,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":163,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":162,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":160,"Name":"internal.metrics.peakExecutionMemory","Update":800,"Value":800,"Internal":true,"Count Failed Values":true},{"ID":159,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":158,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":155,"Name":"internal.metrics.resultSize","Update":1828,"Value":1828,"Internal":true,"Count Failed Values":true},{"ID":154,"Name":"internal.metrics.executorCpuTime","Update":15546330,"Value":15546330,"Internal":true,"Count Failed Values":true},{"ID":153,"Name":"internal.metrics.executorRunTime","Update":19,"Value":19,"Internal":true,"Count Failed Values":true},{"ID":152,"Name":"internal.metrics.executorDeserializeCpuTime","Update":11263754,"Value":11263754,"Internal":true,"Count Failed Values":true},{"ID":151,"Name":"internal.metrics.executorDeserializeTime","Update":22,"Value":22,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":22,"Executor Deserialize CPU Time":11263754,"Executor Run Time":19,"Executor CPU Time":15546330,"Result Size":1828,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":1,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":178,"Total Records Read":1},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":6,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":12,"Index":1,"Attempt":0,"Launch Time":1538440979920,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538440979972,"Failed":false,"Killed":false,"Accumulables":[{"ID":168,"Name":"internal.metrics.shuffle.read.recordsRead","Update":2,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":167,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":166,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":355,"Value":533,"Internal":true,"Count Failed Values":true},{"ID":165,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":164,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":163,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":2,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":162,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":160,"Name":"internal.metrics.peakExecutionMemory","Update":992,"Value":1792,"Internal":true,"Count Failed Values":true},{"ID":159,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":158,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":157,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":155,"Name":"internal.metrics.resultSize","Update":1871,"Value":3699,"Internal":true,"Count Failed Values":true},{"ID":154,"Name":"internal.metrics.executorCpuTime","Update":15089701,"Value":30636031,"Internal":true,"Count Failed Values":true},{"ID":153,"Name":"internal.metrics.executorRunTime","Update":27,"Value":46,"Internal":true,"Count Failed Values":true},{"ID":152,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3045280,"Value":14309034,"Internal":true,"Count Failed Values":true},{"ID":151,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":25,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":3045280,"Executor Run Time":27,"Executor CPU Time":15089701,"Result Size":1871,"JVM GC Time":0,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":2,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":355,"Total Records Read":2},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"driver","Stage ID":6,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":350990264,"JVMOffHeapMemory":97710440,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":456312,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":456312,"OffHeapUnifiedMemory":0,"DirectPoolMemory":135031,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":4932550656,"ProcessTreeJVMRSSMemory":604299264,"ProcessTreePythonVMemory":408375296,"ProcessTreePythonRSSMemory":40284160,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":6,"Stage Attempt ID":0,"Stage Name":"collect at BisectingKMeans.scala:304","Number of Tasks":2,"RDD Info":[{"RDD ID":18,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"37\",\"name\":\"mapValues\"}","Callsite":"mapValues at BisectingKMeans.scala:303","Parent IDs":[17],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":17,"Name":"ShuffledRDD","Scope":"{\"id\":\"36\",\"name\":\"aggregateByKey\"}","Callsite":"aggregateByKey at BisectingKMeans.scala:300","Parent IDs":[16],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[5],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:944)\norg.apache.spark.mllib.clustering.BisectingKMeans$.org$apache$spark$mllib$clustering$BisectingKMeans$$summarize(BisectingKMeans.scala:304)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:216)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538440979854,"Completion Time":1538440979973,"Accumulables":[{"ID":155,"Name":"internal.metrics.resultSize","Value":3699,"Internal":true,"Count Failed Values":true},{"ID":164,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Value":0,"Internal":true,"Count Failed Values":true},{"ID":167,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Value":0,"Internal":true,"Count Failed Values":true},{"ID":158,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":166,"Name":"internal.metrics.shuffle.read.localBytesRead","Value":533,"Internal":true,"Count Failed Values":true},{"ID":151,"Name":"internal.metrics.executorDeserializeTime","Value":25,"Internal":true,"Count Failed Values":true},{"ID":160,"Name":"internal.metrics.peakExecutionMemory","Value":1792,"Internal":true,"Count Failed Values":true},{"ID":163,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Value":3,"Internal":true,"Count Failed Values":true},{"ID":154,"Name":"internal.metrics.executorCpuTime","Value":30636031,"Internal":true,"Count Failed Values":true},{"ID":157,"Name":"internal.metrics.resultSerializationTime","Value":1,"Internal":true,"Count Failed Values":true},{"ID":165,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Value":0,"Internal":true,"Count Failed Values":true},{"ID":168,"Name":"internal.metrics.shuffle.read.recordsRead","Value":3,"Internal":true,"Count Failed Values":true},{"ID":159,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":153,"Name":"internal.metrics.executorRunTime","Value":46,"Internal":true,"Count Failed Values":true},{"ID":162,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Value":0,"Internal":true,"Count Failed Values":true},{"ID":152,"Name":"internal.metrics.executorDeserializeCpuTime","Value":14309034,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerJobEnd","Job ID":3,"Completion Time":1538440979974,"Job Result":{"Result":"JobSucceeded"}} +{"Event":"SparkListenerJobStart","Job ID":4,"Submission Time":1538440980008,"Stage Infos":[{"Stage ID":7,"Stage Attempt ID":0,"Stage Name":"filter at BisectingKMeans.scala:213","Number of Tasks":2,"RDD Info":[{"RDD ID":20,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"45\",\"name\":\"filter\"}","Callsite":"filter at BisectingKMeans.scala:213","Parent IDs":[19],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:170","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":19,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"44\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:372","Parent IDs":[8],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:169","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":6,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"10\",\"name\":\"zip\"}","Callsite":"zip at BisectingKMeans.scala:169","Parent IDs":[3,5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":3,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"9\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:168","Parent IDs":[3],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.filter(RDD.scala:387)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:213)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Accumulables":[]},{"Stage ID":8,"Stage Attempt ID":0,"Stage Name":"collect at BisectingKMeans.scala:304","Number of Tasks":2,"RDD Info":[{"RDD ID":22,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"47\",\"name\":\"mapValues\"}","Callsite":"mapValues at BisectingKMeans.scala:303","Parent IDs":[21],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":21,"Name":"ShuffledRDD","Scope":"{\"id\":\"46\",\"name\":\"aggregateByKey\"}","Callsite":"aggregateByKey at BisectingKMeans.scala:300","Parent IDs":[20],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[7],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:944)\norg.apache.spark.mllib.clustering.BisectingKMeans$.org$apache$spark$mllib$clustering$BisectingKMeans$$summarize(BisectingKMeans.scala:304)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:216)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Accumulables":[]}],"Stage IDs":[7,8],"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"48\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":7,"Stage Attempt ID":0,"Stage Name":"filter at BisectingKMeans.scala:213","Number of Tasks":2,"RDD Info":[{"RDD ID":20,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"45\",\"name\":\"filter\"}","Callsite":"filter at BisectingKMeans.scala:213","Parent IDs":[19],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:170","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":19,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"44\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:372","Parent IDs":[8],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:169","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":6,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"10\",\"name\":\"zip\"}","Callsite":"zip at BisectingKMeans.scala:169","Parent IDs":[3,5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":3,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"9\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:168","Parent IDs":[3],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.filter(RDD.scala:387)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:213)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538440980015,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"48\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":7,"Stage Attempt ID":0,"Task Info":{"Task ID":13,"Index":0,"Attempt":0,"Launch Time":1538440980049,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerBlockManagerRemoved","Block Manager ID":{"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Port":46411},"Timestamp":1538440980522} +{"Event":"SparkListenerExecutorRemoved","Timestamp":1538440980759,"Executor ID":"1","Removed Reason":"Container marked as failed: container_1538416563558_0014_01_000002 on host: rezamemory-2.gce.something.com. Exit status: 56. Diagnostics: Exception from container-launch.\nContainer id: container_1538416563558_0014_01_000002\nExit code: 56\nStack trace: ExitCodeException exitCode=56: \n\tat org.apache.hadoop.util.Shell.runCommand(Shell.java:601)\n\tat org.apache.hadoop.util.Shell.run(Shell.java:504)\n\tat org.apache.hadoop.util.Shell$ShellCommandExecutor.execute(Shell.java:786)\n\tat org.apache.hadoop.yarn.server.nodemanager.DefaultContainerExecutor.launchContainer(DefaultContainerExecutor.java:213)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:302)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:82)\n\tat java.util.concurrent.FutureTask.run(FutureTask.java:266)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n\n\nContainer exited with a non-zero exit code 56\n"} +{"Event":"SparkListenerTaskEnd","Stage ID":7,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExecutorLostFailure","Executor ID":"1","Exit Caused By App":true,"Loss Reason":"Container marked as failed: container_1538416563558_0014_01_000002 on host: rezamemory-2.gce.something.com. Exit status: 56. Diagnostics: Exception from container-launch.\nContainer id: container_1538416563558_0014_01_000002\nExit code: 56\nStack trace: ExitCodeException exitCode=56: \n\tat org.apache.hadoop.util.Shell.runCommand(Shell.java:601)\n\tat org.apache.hadoop.util.Shell.run(Shell.java:504)\n\tat org.apache.hadoop.util.Shell$ShellCommandExecutor.execute(Shell.java:786)\n\tat org.apache.hadoop.yarn.server.nodemanager.DefaultContainerExecutor.launchContainer(DefaultContainerExecutor.java:213)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:302)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:82)\n\tat java.util.concurrent.FutureTask.run(FutureTask.java:266)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n\n\nContainer exited with a non-zero exit code 56\n"},"Task Info":{"Task ID":13,"Index":0,"Attempt":0,"Launch Time":1538440980049,"Executor ID":"1","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538440980757,"Failed":true,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerExecutorAdded","Timestamp":1538440986317,"Executor ID":"2","Executor Info":{"Host":"rezamemory-2.gce.something.com","Total Cores":1,"Log Urls":{"stdout":"http://rezamemory-2.gce.something.com:8042/node/containerlogs/container_1538416563558_0014_01_000003/root/stdout?start=-4096","stderr":"http://rezamemory-2.gce.something.com:8042/node/containerlogs/container_1538416563558_0014_01_000003/root/stderr?start=-4096"}}} +{"Event":"SparkListenerTaskStart","Stage ID":7,"Stage Attempt ID":0,"Task Info":{"Task ID":14,"Index":0,"Attempt":1,"Launch Time":1538440986317,"Executor ID":"2","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"2","Host":"rezamemory-2.gce.something.com","Port":39119},"Maximum Memory":384093388,"Timestamp":1538440986696,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1538440988793,"Executor ID":"3","Executor Info":{"Host":"rezamemory-2.gce.something.com","Total Cores":1,"Log Urls":{"stdout":"http://rezamemory-2.gce.something.com:8042/node/containerlogs/container_1538416563558_0014_01_000004/root/stdout?start=-4096","stderr":"http://rezamemory-2.gce.something.com:8042/node/containerlogs/container_1538416563558_0014_01_000004/root/stderr?start=-4096"}}} +{"Event":"SparkListenerTaskStart","Stage ID":7,"Stage Attempt ID":0,"Task Info":{"Task ID":15,"Index":1,"Attempt":0,"Launch Time":1538440988793,"Executor ID":"3","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"3","Host":"rezamemory-2.gce.something.com","Port":40911},"Maximum Memory":384093388,"Timestamp":1538440989162,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":0} +{"Event":"SparkListenerBlockManagerRemoved","Block Manager ID":{"Executor ID":"2","Host":"rezamemory-2.gce.something.com","Port":39119},"Timestamp":1538440993798} +{"Event":"SparkListenerTaskEnd","Stage ID":7,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExecutorLostFailure","Executor ID":"2","Exit Caused By App":true,"Loss Reason":"Container marked as failed: container_1538416563558_0014_01_000003 on host: rezamemory-2.gce.something.com. Exit status: 56. Diagnostics: Exception from container-launch.\nContainer id: container_1538416563558_0014_01_000003\nExit code: 56\nStack trace: ExitCodeException exitCode=56: \n\tat org.apache.hadoop.util.Shell.runCommand(Shell.java:601)\n\tat org.apache.hadoop.util.Shell.run(Shell.java:504)\n\tat org.apache.hadoop.util.Shell$ShellCommandExecutor.execute(Shell.java:786)\n\tat org.apache.hadoop.yarn.server.nodemanager.DefaultContainerExecutor.launchContainer(DefaultContainerExecutor.java:213)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:302)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:82)\n\tat java.util.concurrent.FutureTask.run(FutureTask.java:266)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n\n\nContainer exited with a non-zero exit code 56\n"},"Task Info":{"Task ID":14,"Index":0,"Attempt":1,"Launch Time":1538440986317,"Executor ID":"2","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538440994010,"Failed":true,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerExecutorRemoved","Timestamp":1538440994012,"Executor ID":"2","Removed Reason":"Container marked as failed: container_1538416563558_0014_01_000003 on host: rezamemory-2.gce.something.com. Exit status: 56. Diagnostics: Exception from container-launch.\nContainer id: container_1538416563558_0014_01_000003\nExit code: 56\nStack trace: ExitCodeException exitCode=56: \n\tat org.apache.hadoop.util.Shell.runCommand(Shell.java:601)\n\tat org.apache.hadoop.util.Shell.run(Shell.java:504)\n\tat org.apache.hadoop.util.Shell$ShellCommandExecutor.execute(Shell.java:786)\n\tat org.apache.hadoop.yarn.server.nodemanager.DefaultContainerExecutor.launchContainer(DefaultContainerExecutor.java:213)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:302)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:82)\n\tat java.util.concurrent.FutureTask.run(FutureTask.java:266)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n\n\nContainer exited with a non-zero exit code 56\n"} +{"Event":"SparkListenerTaskStart","Stage ID":7,"Stage Attempt ID":0,"Task Info":{"Task ID":16,"Index":0,"Attempt":2,"Launch Time":1538440995449,"Executor ID":"3","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":7,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":15,"Index":1,"Attempt":0,"Launch Time":1538440988793,"Executor ID":"3","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538440995450,"Failed":false,"Killed":false,"Accumulables":[{"ID":198,"Name":"internal.metrics.input.recordsRead","Update":4,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":197,"Name":"internal.metrics.input.bytesRead","Update":72,"Value":72,"Internal":true,"Count Failed Values":true},{"ID":196,"Name":"internal.metrics.shuffle.write.writeTime","Update":10065137,"Value":10065137,"Internal":true,"Count Failed Values":true},{"ID":195,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":194,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":178,"Value":178,"Internal":true,"Count Failed Values":true},{"ID":185,"Name":"internal.metrics.peakExecutionMemory","Update":1088,"Value":1088,"Internal":true,"Count Failed Values":true},{"ID":184,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":183,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":181,"Name":"internal.metrics.jvmGCTime","Update":360,"Value":360,"Internal":true,"Count Failed Values":true},{"ID":180,"Name":"internal.metrics.resultSize","Update":1705,"Value":1705,"Internal":true,"Count Failed Values":true},{"ID":179,"Name":"internal.metrics.executorCpuTime","Update":1406669099,"Value":1406669099,"Internal":true,"Count Failed Values":true},{"ID":178,"Name":"internal.metrics.executorRunTime","Update":4128,"Value":4128,"Internal":true,"Count Failed Values":true},{"ID":177,"Name":"internal.metrics.executorDeserializeCpuTime","Update":726605764,"Value":726605764,"Internal":true,"Count Failed Values":true},{"ID":176,"Name":"internal.metrics.executorDeserializeTime","Update":1995,"Value":1995,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1995,"Executor Deserialize CPU Time":726605764,"Executor Run Time":4128,"Executor CPU Time":1406669099,"Result Size":1705,"JVM GC Time":360,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":178,"Shuffle Write Time":10065137,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":72,"Records Read":4},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":7,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":16,"Index":0,"Attempt":2,"Launch Time":1538440995449,"Executor ID":"3","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538440995696,"Failed":false,"Killed":false,"Accumulables":[{"ID":198,"Name":"internal.metrics.input.recordsRead","Update":8,"Value":12,"Internal":true,"Count Failed Values":true},{"ID":197,"Name":"internal.metrics.input.bytesRead","Update":72,"Value":144,"Internal":true,"Count Failed Values":true},{"ID":196,"Name":"internal.metrics.shuffle.write.writeTime","Update":293846,"Value":10358983,"Internal":true,"Count Failed Values":true},{"ID":195,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":2,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":194,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":355,"Value":533,"Internal":true,"Count Failed Values":true},{"ID":185,"Name":"internal.metrics.peakExecutionMemory","Update":1264,"Value":2352,"Internal":true,"Count Failed Values":true},{"ID":184,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":183,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":180,"Name":"internal.metrics.resultSize","Update":1662,"Value":3367,"Internal":true,"Count Failed Values":true},{"ID":179,"Name":"internal.metrics.executorCpuTime","Update":91844758,"Value":1498513857,"Internal":true,"Count Failed Values":true},{"ID":178,"Name":"internal.metrics.executorRunTime","Update":220,"Value":4348,"Internal":true,"Count Failed Values":true},{"ID":177,"Name":"internal.metrics.executorDeserializeCpuTime","Update":8316162,"Value":734921926,"Internal":true,"Count Failed Values":true},{"ID":176,"Name":"internal.metrics.executorDeserializeTime","Update":9,"Value":2004,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":9,"Executor Deserialize CPU Time":8316162,"Executor Run Time":220,"Executor CPU Time":91844758,"Result Size":1662,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":355,"Shuffle Write Time":293846,"Shuffle Records Written":2},"Input Metrics":{"Bytes Read":72,"Records Read":8},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"2","Stage ID":7,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":201931120,"JVMOffHeapMemory":58230320,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":1094710,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":1094710,"OffHeapUnifiedMemory":0,"DirectPoolMemory":45633,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":3023769600,"ProcessTreeJVMRSSMemory":410324992,"ProcessTreePythonVMemory":285470720,"ProcessTreePythonRSSMemory":30171136,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"driver","Stage ID":7,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":195471784,"JVMOffHeapMemory":100867584,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":476885,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":476885,"OffHeapUnifiedMemory":0,"DirectPoolMemory":171571,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":4971368448,"ProcessTreeJVMRSSMemory":663375872,"ProcessTreePythonVMemory":408375296,"ProcessTreePythonRSSMemory":40284160,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"1","Stage ID":7,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":258718872,"JVMOffHeapMemory":63737056,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":548320,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":548320,"OffHeapUnifiedMemory":0,"DirectPoolMemory":21084,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":3021418496,"ProcessTreeJVMRSSMemory":466001920,"ProcessTreePythonVMemory":958914560,"ProcessTreePythonRSSMemory":106622976,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"3","Stage ID":7,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":223684056,"JVMOffHeapMemory":60665000,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":1482102,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":1482102,"OffHeapUnifiedMemory":0,"DirectPoolMemory":20318,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":3015626752,"ProcessTreeJVMRSSMemory":404672512,"ProcessTreePythonVMemory":958963712,"ProcessTreePythonRSSMemory":106639360,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":7,"Stage Attempt ID":0,"Stage Name":"filter at BisectingKMeans.scala:213","Number of Tasks":2,"RDD Info":[{"RDD ID":20,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"45\",\"name\":\"filter\"}","Callsite":"filter at BisectingKMeans.scala:213","Parent IDs":[19],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:170","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":19,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"44\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:372","Parent IDs":[8],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:169","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":6,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"10\",\"name\":\"zip\"}","Callsite":"zip at BisectingKMeans.scala:169","Parent IDs":[3,5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":3,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"9\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:168","Parent IDs":[3],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.filter(RDD.scala:387)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:213)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538440980015,"Completion Time":1538440995697,"Accumulables":[{"ID":176,"Name":"internal.metrics.executorDeserializeTime","Value":2004,"Internal":true,"Count Failed Values":true},{"ID":185,"Name":"internal.metrics.peakExecutionMemory","Value":2352,"Internal":true,"Count Failed Values":true},{"ID":194,"Name":"internal.metrics.shuffle.write.bytesWritten","Value":533,"Internal":true,"Count Failed Values":true},{"ID":184,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":178,"Name":"internal.metrics.executorRunTime","Value":4348,"Internal":true,"Count Failed Values":true},{"ID":196,"Name":"internal.metrics.shuffle.write.writeTime","Value":10358983,"Internal":true,"Count Failed Values":true},{"ID":181,"Name":"internal.metrics.jvmGCTime","Value":360,"Internal":true,"Count Failed Values":true},{"ID":180,"Name":"internal.metrics.resultSize","Value":3367,"Internal":true,"Count Failed Values":true},{"ID":198,"Name":"internal.metrics.input.recordsRead","Value":12,"Internal":true,"Count Failed Values":true},{"ID":183,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":177,"Name":"internal.metrics.executorDeserializeCpuTime","Value":734921926,"Internal":true,"Count Failed Values":true},{"ID":195,"Name":"internal.metrics.shuffle.write.recordsWritten","Value":3,"Internal":true,"Count Failed Values":true},{"ID":179,"Name":"internal.metrics.executorCpuTime","Value":1498513857,"Internal":true,"Count Failed Values":true},{"ID":197,"Name":"internal.metrics.input.bytesRead","Value":144,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":8,"Stage Attempt ID":0,"Stage Name":"collect at BisectingKMeans.scala:304","Number of Tasks":2,"RDD Info":[{"RDD ID":22,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"47\",\"name\":\"mapValues\"}","Callsite":"mapValues at BisectingKMeans.scala:303","Parent IDs":[21],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":21,"Name":"ShuffledRDD","Scope":"{\"id\":\"46\",\"name\":\"aggregateByKey\"}","Callsite":"aggregateByKey at BisectingKMeans.scala:300","Parent IDs":[20],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[7],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:944)\norg.apache.spark.mllib.clustering.BisectingKMeans$.org$apache$spark$mllib$clustering$BisectingKMeans$$summarize(BisectingKMeans.scala:304)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:216)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538440995698,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"48\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":8,"Stage Attempt ID":0,"Task Info":{"Task ID":17,"Index":0,"Attempt":0,"Launch Time":1538440995710,"Executor ID":"3","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerBlockManagerRemoved","Block Manager ID":{"Executor ID":"3","Host":"rezamemory-2.gce.something.com","Port":40911},"Timestamp":1538440996257} +{"Event":"SparkListenerTaskEnd","Stage ID":8,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"ExecutorLostFailure","Executor ID":"3","Exit Caused By App":true,"Loss Reason":"Container marked as failed: container_1538416563558_0014_01_000004 on host: rezamemory-2.gce.something.com. Exit status: 56. Diagnostics: Exception from container-launch.\nContainer id: container_1538416563558_0014_01_000004\nExit code: 56\nStack trace: ExitCodeException exitCode=56: \n\tat org.apache.hadoop.util.Shell.runCommand(Shell.java:601)\n\tat org.apache.hadoop.util.Shell.run(Shell.java:504)\n\tat org.apache.hadoop.util.Shell$ShellCommandExecutor.execute(Shell.java:786)\n\tat org.apache.hadoop.yarn.server.nodemanager.DefaultContainerExecutor.launchContainer(DefaultContainerExecutor.java:213)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:302)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:82)\n\tat java.util.concurrent.FutureTask.run(FutureTask.java:266)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n\n\nContainer exited with a non-zero exit code 56\n"},"Task Info":{"Task ID":17,"Index":0,"Attempt":0,"Launch Time":1538440995710,"Executor ID":"3","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538440996467,"Failed":true,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerExecutorRemoved","Timestamp":1538440996468,"Executor ID":"3","Removed Reason":"Container marked as failed: container_1538416563558_0014_01_000004 on host: rezamemory-2.gce.something.com. Exit status: 56. Diagnostics: Exception from container-launch.\nContainer id: container_1538416563558_0014_01_000004\nExit code: 56\nStack trace: ExitCodeException exitCode=56: \n\tat org.apache.hadoop.util.Shell.runCommand(Shell.java:601)\n\tat org.apache.hadoop.util.Shell.run(Shell.java:504)\n\tat org.apache.hadoop.util.Shell$ShellCommandExecutor.execute(Shell.java:786)\n\tat org.apache.hadoop.yarn.server.nodemanager.DefaultContainerExecutor.launchContainer(DefaultContainerExecutor.java:213)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:302)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:82)\n\tat java.util.concurrent.FutureTask.run(FutureTask.java:266)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n\n\nContainer exited with a non-zero exit code 56\n"} +{"Event":"SparkListenerExecutorAdded","Timestamp":1538441002826,"Executor ID":"4","Executor Info":{"Host":"rezamemory-2.gce.something.com","Total Cores":1,"Log Urls":{"stdout":"http://rezamemory-2.gce.something.com:8042/node/containerlogs/container_1538416563558_0014_01_000005/root/stdout?start=-4096","stderr":"http://rezamemory-2.gce.something.com:8042/node/containerlogs/container_1538416563558_0014_01_000005/root/stderr?start=-4096"}}} +{"Event":"SparkListenerTaskStart","Stage ID":8,"Stage Attempt ID":0,"Task Info":{"Task ID":18,"Index":0,"Attempt":1,"Launch Time":1538441002828,"Executor ID":"4","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerExecutorAdded","Timestamp":1538441003031,"Executor ID":"5","Executor Info":{"Host":"rezamemory-2.gce.something.com","Total Cores":1,"Log Urls":{"stdout":"http://rezamemory-2.gce.something.com:8042/node/containerlogs/container_1538416563558_0014_01_000006/root/stdout?start=-4096","stderr":"http://rezamemory-2.gce.something.com:8042/node/containerlogs/container_1538416563558_0014_01_000006/root/stderr?start=-4096"}}} +{"Event":"SparkListenerTaskStart","Stage ID":8,"Stage Attempt ID":0,"Task Info":{"Task ID":19,"Index":1,"Attempt":0,"Launch Time":1538441003032,"Executor ID":"5","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"4","Host":"rezamemory-2.gce.something.com","Port":39248},"Maximum Memory":384093388,"Timestamp":1538441003132,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":0} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"5","Host":"rezamemory-2.gce.something.com","Port":43165},"Maximum Memory":384093388,"Timestamp":1538441003383,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":0} +{"Event":"SparkListenerTaskEnd","Stage ID":8,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":18,"Index":0,"Attempt":1,"Launch Time":1538441002828,"Executor ID":"4","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538441006147,"Failed":false,"Killed":false,"Accumulables":[{"ID":218,"Name":"internal.metrics.shuffle.read.recordsRead","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":217,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":216,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":215,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":214,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":178,"Value":178,"Internal":true,"Count Failed Values":true},{"ID":213,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":212,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":210,"Name":"internal.metrics.peakExecutionMemory","Update":800,"Value":800,"Internal":true,"Count Failed Values":true},{"ID":209,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":208,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":207,"Name":"internal.metrics.resultSerializationTime","Update":2,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":206,"Name":"internal.metrics.jvmGCTime","Update":350,"Value":350,"Internal":true,"Count Failed Values":true},{"ID":205,"Name":"internal.metrics.resultSize","Update":1914,"Value":1914,"Internal":true,"Count Failed Values":true},{"ID":204,"Name":"internal.metrics.executorCpuTime","Update":219243972,"Value":219243972,"Internal":true,"Count Failed Values":true},{"ID":203,"Name":"internal.metrics.executorRunTime","Update":893,"Value":893,"Internal":true,"Count Failed Values":true},{"ID":202,"Name":"internal.metrics.executorDeserializeCpuTime","Update":717217987,"Value":717217987,"Internal":true,"Count Failed Values":true},{"ID":201,"Name":"internal.metrics.executorDeserializeTime","Update":1972,"Value":1972,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1972,"Executor Deserialize CPU Time":717217987,"Executor Run Time":893,"Executor CPU Time":219243972,"Result Size":1914,"JVM GC Time":350,"Result Serialization Time":2,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":1,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":178,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":1},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":8,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":19,"Index":1,"Attempt":0,"Launch Time":1538441003032,"Executor ID":"5","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538441006584,"Failed":false,"Killed":false,"Accumulables":[{"ID":218,"Name":"internal.metrics.shuffle.read.recordsRead","Update":2,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":217,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":216,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":215,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":214,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":355,"Value":533,"Internal":true,"Count Failed Values":true},{"ID":213,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":212,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":2,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":210,"Name":"internal.metrics.peakExecutionMemory","Update":992,"Value":1792,"Internal":true,"Count Failed Values":true},{"ID":209,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":208,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":207,"Name":"internal.metrics.resultSerializationTime","Update":10,"Value":12,"Internal":true,"Count Failed Values":true},{"ID":206,"Name":"internal.metrics.jvmGCTime","Update":270,"Value":620,"Internal":true,"Count Failed Values":true},{"ID":205,"Name":"internal.metrics.resultSize","Update":1914,"Value":3828,"Internal":true,"Count Failed Values":true},{"ID":204,"Name":"internal.metrics.executorCpuTime","Update":210863492,"Value":430107464,"Internal":true,"Count Failed Values":true},{"ID":203,"Name":"internal.metrics.executorRunTime","Update":412,"Value":1305,"Internal":true,"Count Failed Values":true},{"ID":202,"Name":"internal.metrics.executorDeserializeCpuTime","Update":727356712,"Value":1444574699,"Internal":true,"Count Failed Values":true},{"ID":201,"Name":"internal.metrics.executorDeserializeTime","Update":2604,"Value":4576,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2604,"Executor Deserialize CPU Time":727356712,"Executor Run Time":412,"Executor CPU Time":210863492,"Result Size":1914,"JVM GC Time":270,"Result Serialization Time":10,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":2,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":355,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":2},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"driver","Stage ID":8,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":263995432,"JVMOffHeapMemory":101978136,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":498888,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":498888,"OffHeapUnifiedMemory":0,"DirectPoolMemory":191656,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":5008089088,"ProcessTreeJVMRSSMemory":663732224,"ProcessTreePythonVMemory":408375296,"ProcessTreePythonRSSMemory":40284160,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"5","Stage ID":8,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":150497592,"JVMOffHeapMemory":45958576,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":22003,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":22003,"OffHeapUnifiedMemory":0,"DirectPoolMemory":3446,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":2984218624,"ProcessTreeJVMRSSMemory":325042176,"ProcessTreePythonVMemory":0,"ProcessTreePythonRSSMemory":0,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"4","Stage ID":8,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":181352744,"JVMOffHeapMemory":47061200,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":22003,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":22003,"OffHeapUnifiedMemory":0,"DirectPoolMemory":11272,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":3013332992,"ProcessTreeJVMRSSMemory":416645120,"ProcessTreePythonVMemory":0,"ProcessTreePythonRSSMemory":0,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"3","Stage ID":8,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":226223752,"JVMOffHeapMemory":60840424,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":433558,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":433558,"OffHeapUnifiedMemory":0,"DirectPoolMemory":20318,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":3016937472,"ProcessTreeJVMRSSMemory":406044672,"ProcessTreePythonVMemory":958963712,"ProcessTreePythonRSSMemory":106639360,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":8,"Stage Attempt ID":0,"Stage Name":"collect at BisectingKMeans.scala:304","Number of Tasks":2,"RDD Info":[{"RDD ID":22,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"47\",\"name\":\"mapValues\"}","Callsite":"mapValues at BisectingKMeans.scala:303","Parent IDs":[21],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":21,"Name":"ShuffledRDD","Scope":"{\"id\":\"46\",\"name\":\"aggregateByKey\"}","Callsite":"aggregateByKey at BisectingKMeans.scala:300","Parent IDs":[20],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[7],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:944)\norg.apache.spark.mllib.clustering.BisectingKMeans$.org$apache$spark$mllib$clustering$BisectingKMeans$$summarize(BisectingKMeans.scala:304)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:216)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538440995698,"Completion Time":1538441006585,"Accumulables":[{"ID":218,"Name":"internal.metrics.shuffle.read.recordsRead","Value":3,"Internal":true,"Count Failed Values":true},{"ID":209,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":202,"Name":"internal.metrics.executorDeserializeCpuTime","Value":1444574699,"Internal":true,"Count Failed Values":true},{"ID":205,"Name":"internal.metrics.resultSize","Value":3828,"Internal":true,"Count Failed Values":true},{"ID":214,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Value":533,"Internal":true,"Count Failed Values":true},{"ID":217,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Value":0,"Internal":true,"Count Failed Values":true},{"ID":208,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":216,"Name":"internal.metrics.shuffle.read.localBytesRead","Value":0,"Internal":true,"Count Failed Values":true},{"ID":207,"Name":"internal.metrics.resultSerializationTime","Value":12,"Internal":true,"Count Failed Values":true},{"ID":210,"Name":"internal.metrics.peakExecutionMemory","Value":1792,"Internal":true,"Count Failed Values":true},{"ID":201,"Name":"internal.metrics.executorDeserializeTime","Value":4576,"Internal":true,"Count Failed Values":true},{"ID":213,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Value":0,"Internal":true,"Count Failed Values":true},{"ID":204,"Name":"internal.metrics.executorCpuTime","Value":430107464,"Internal":true,"Count Failed Values":true},{"ID":203,"Name":"internal.metrics.executorRunTime","Value":1305,"Internal":true,"Count Failed Values":true},{"ID":212,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Value":3,"Internal":true,"Count Failed Values":true},{"ID":215,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Value":0,"Internal":true,"Count Failed Values":true},{"ID":206,"Name":"internal.metrics.jvmGCTime","Value":620,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerJobEnd","Job ID":4,"Completion Time":1538441006585,"Job Result":{"Result":"JobSucceeded"}} +{"Event":"SparkListenerJobStart","Job ID":5,"Submission Time":1538441006610,"Stage Infos":[{"Stage ID":9,"Stage Attempt ID":0,"Stage Name":"filter at BisectingKMeans.scala:213","Number of Tasks":2,"RDD Info":[{"RDD ID":24,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"55\",\"name\":\"filter\"}","Callsite":"filter at BisectingKMeans.scala:213","Parent IDs":[23],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":23,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"54\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:372","Parent IDs":[8],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:170","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:169","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":6,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"10\",\"name\":\"zip\"}","Callsite":"zip at BisectingKMeans.scala:169","Parent IDs":[3,5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":3,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"9\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:168","Parent IDs":[3],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.filter(RDD.scala:387)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:213)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Accumulables":[]},{"Stage ID":10,"Stage Attempt ID":0,"Stage Name":"collect at BisectingKMeans.scala:304","Number of Tasks":2,"RDD Info":[{"RDD ID":26,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"57\",\"name\":\"mapValues\"}","Callsite":"mapValues at BisectingKMeans.scala:303","Parent IDs":[25],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":25,"Name":"ShuffledRDD","Scope":"{\"id\":\"56\",\"name\":\"aggregateByKey\"}","Callsite":"aggregateByKey at BisectingKMeans.scala:300","Parent IDs":[24],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[9],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:944)\norg.apache.spark.mllib.clustering.BisectingKMeans$.org$apache$spark$mllib$clustering$BisectingKMeans$$summarize(BisectingKMeans.scala:304)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:216)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Accumulables":[]}],"Stage IDs":[9,10],"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"58\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":9,"Stage Attempt ID":0,"Stage Name":"filter at BisectingKMeans.scala:213","Number of Tasks":2,"RDD Info":[{"RDD ID":24,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"55\",\"name\":\"filter\"}","Callsite":"filter at BisectingKMeans.scala:213","Parent IDs":[23],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":23,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"54\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:372","Parent IDs":[8],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:170","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:169","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":6,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"10\",\"name\":\"zip\"}","Callsite":"zip at BisectingKMeans.scala:169","Parent IDs":[3,5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":3,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"9\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:168","Parent IDs":[3],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.filter(RDD.scala:387)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:213)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538441006612,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"58\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":9,"Stage Attempt ID":0,"Task Info":{"Task ID":20,"Index":0,"Attempt":0,"Launch Time":1538441006622,"Executor ID":"4","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":9,"Stage Attempt ID":0,"Task Info":{"Task ID":21,"Index":1,"Attempt":0,"Launch Time":1538441006623,"Executor ID":"5","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerBlockManagerRemoved","Block Manager ID":{"Executor ID":"4","Host":"rezamemory-2.gce.something.com","Port":39248},"Timestamp":1538441010070} +{"Event":"SparkListenerBlockManagerRemoved","Block Manager ID":{"Executor ID":"5","Host":"rezamemory-2.gce.something.com","Port":43165},"Timestamp":1538441010233} +{"Event":"SparkListenerTaskEnd","Stage ID":9,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExecutorLostFailure","Executor ID":"4","Exit Caused By App":true,"Loss Reason":"Container marked as failed: container_1538416563558_0014_01_000005 on host: rezamemory-2.gce.something.com. Exit status: 56. Diagnostics: Exception from container-launch.\nContainer id: container_1538416563558_0014_01_000005\nExit code: 56\nStack trace: ExitCodeException exitCode=56: \n\tat org.apache.hadoop.util.Shell.runCommand(Shell.java:601)\n\tat org.apache.hadoop.util.Shell.run(Shell.java:504)\n\tat org.apache.hadoop.util.Shell$ShellCommandExecutor.execute(Shell.java:786)\n\tat org.apache.hadoop.yarn.server.nodemanager.DefaultContainerExecutor.launchContainer(DefaultContainerExecutor.java:213)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:302)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:82)\n\tat java.util.concurrent.FutureTask.run(FutureTask.java:266)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n\n\nContainer exited with a non-zero exit code 56\n"},"Task Info":{"Task ID":20,"Index":0,"Attempt":0,"Launch Time":1538441006622,"Executor ID":"4","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538441010280,"Failed":true,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerExecutorRemoved","Timestamp":1538441010281,"Executor ID":"4","Removed Reason":"Container marked as failed: container_1538416563558_0014_01_000005 on host: rezamemory-2.gce.something.com. Exit status: 56. Diagnostics: Exception from container-launch.\nContainer id: container_1538416563558_0014_01_000005\nExit code: 56\nStack trace: ExitCodeException exitCode=56: \n\tat org.apache.hadoop.util.Shell.runCommand(Shell.java:601)\n\tat org.apache.hadoop.util.Shell.run(Shell.java:504)\n\tat org.apache.hadoop.util.Shell$ShellCommandExecutor.execute(Shell.java:786)\n\tat org.apache.hadoop.yarn.server.nodemanager.DefaultContainerExecutor.launchContainer(DefaultContainerExecutor.java:213)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:302)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:82)\n\tat java.util.concurrent.FutureTask.run(FutureTask.java:266)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n\n\nContainer exited with a non-zero exit code 56\n"} +{"Event":"SparkListenerTaskEnd","Stage ID":9,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExecutorLostFailure","Executor ID":"5","Exit Caused By App":true,"Loss Reason":"Container marked as failed: container_1538416563558_0014_01_000006 on host: rezamemory-2.gce.something.com. Exit status: 56. Diagnostics: Exception from container-launch.\nContainer id: container_1538416563558_0014_01_000006\nExit code: 56\nStack trace: ExitCodeException exitCode=56: \n\tat org.apache.hadoop.util.Shell.runCommand(Shell.java:601)\n\tat org.apache.hadoop.util.Shell.run(Shell.java:504)\n\tat org.apache.hadoop.util.Shell$ShellCommandExecutor.execute(Shell.java:786)\n\tat org.apache.hadoop.yarn.server.nodemanager.DefaultContainerExecutor.launchContainer(DefaultContainerExecutor.java:213)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:302)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:82)\n\tat java.util.concurrent.FutureTask.run(FutureTask.java:266)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n\n\nContainer exited with a non-zero exit code 56\n"},"Task Info":{"Task ID":21,"Index":1,"Attempt":0,"Launch Time":1538441006623,"Executor ID":"5","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538441010484,"Failed":true,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerExecutorRemoved","Timestamp":1538441010485,"Executor ID":"5","Removed Reason":"Container marked as failed: container_1538416563558_0014_01_000006 on host: rezamemory-2.gce.something.com. Exit status: 56. Diagnostics: Exception from container-launch.\nContainer id: container_1538416563558_0014_01_000006\nExit code: 56\nStack trace: ExitCodeException exitCode=56: \n\tat org.apache.hadoop.util.Shell.runCommand(Shell.java:601)\n\tat org.apache.hadoop.util.Shell.run(Shell.java:504)\n\tat org.apache.hadoop.util.Shell$ShellCommandExecutor.execute(Shell.java:786)\n\tat org.apache.hadoop.yarn.server.nodemanager.DefaultContainerExecutor.launchContainer(DefaultContainerExecutor.java:213)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:302)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:82)\n\tat java.util.concurrent.FutureTask.run(FutureTask.java:266)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n\n\nContainer exited with a non-zero exit code 56\n"} +{"Event":"SparkListenerExecutorAdded","Timestamp":1538441015443,"Executor ID":"6","Executor Info":{"Host":"rezamemory-3.gce.something.com","Total Cores":1,"Log Urls":{"stdout":"http://rezamemory-3.gce.something.com:8042/node/containerlogs/container_1538416563558_0014_01_000007/root/stdout?start=-4096","stderr":"http://rezamemory-3.gce.something.com:8042/node/containerlogs/container_1538416563558_0014_01_000007/root/stderr?start=-4096"}}} +{"Event":"SparkListenerTaskStart","Stage ID":9,"Stage Attempt ID":0,"Task Info":{"Task ID":22,"Index":1,"Attempt":1,"Launch Time":1538441015444,"Executor ID":"6","Host":"rezamemory-3.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"6","Host":"rezamemory-3.gce.something.com","Port":45593},"Maximum Memory":384093388,"Timestamp":1538441015852,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1538441020314,"Executor ID":"7","Executor Info":{"Host":"rezamemory-3.gce.something.com","Total Cores":1,"Log Urls":{"stdout":"http://rezamemory-3.gce.something.com:8042/node/containerlogs/container_1538416563558_0014_01_000008/root/stdout?start=-4096","stderr":"http://rezamemory-3.gce.something.com:8042/node/containerlogs/container_1538416563558_0014_01_000008/root/stderr?start=-4096"}}} +{"Event":"SparkListenerTaskStart","Stage ID":9,"Stage Attempt ID":0,"Task Info":{"Task ID":23,"Index":0,"Attempt":1,"Launch Time":1538441020315,"Executor ID":"7","Host":"rezamemory-3.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"7","Host":"rezamemory-3.gce.something.com","Port":40992},"Maximum Memory":384093388,"Timestamp":1538441020602,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":0} +{"Event":"SparkListenerBlockManagerRemoved","Block Manager ID":{"Executor ID":"6","Host":"rezamemory-3.gce.something.com","Port":45593},"Timestamp":1538441022942} +{"Event":"SparkListenerTaskEnd","Stage ID":9,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExecutorLostFailure","Executor ID":"6","Exit Caused By App":true,"Loss Reason":"Container marked as failed: container_1538416563558_0014_01_000007 on host: rezamemory-3.gce.something.com. Exit status: 56. Diagnostics: Exception from container-launch.\nContainer id: container_1538416563558_0014_01_000007\nExit code: 56\nStack trace: ExitCodeException exitCode=56: \n\tat org.apache.hadoop.util.Shell.runCommand(Shell.java:601)\n\tat org.apache.hadoop.util.Shell.run(Shell.java:504)\n\tat org.apache.hadoop.util.Shell$ShellCommandExecutor.execute(Shell.java:786)\n\tat org.apache.hadoop.yarn.server.nodemanager.DefaultContainerExecutor.launchContainer(DefaultContainerExecutor.java:213)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:302)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:82)\n\tat java.util.concurrent.FutureTask.run(FutureTask.java:266)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n\n\nContainer exited with a non-zero exit code 56\n"},"Task Info":{"Task ID":22,"Index":1,"Attempt":1,"Launch Time":1538441015444,"Executor ID":"6","Host":"rezamemory-3.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538441023152,"Failed":true,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerExecutorRemoved","Timestamp":1538441023153,"Executor ID":"6","Removed Reason":"Container marked as failed: container_1538416563558_0014_01_000007 on host: rezamemory-3.gce.something.com. Exit status: 56. Diagnostics: Exception from container-launch.\nContainer id: container_1538416563558_0014_01_000007\nExit code: 56\nStack trace: ExitCodeException exitCode=56: \n\tat org.apache.hadoop.util.Shell.runCommand(Shell.java:601)\n\tat org.apache.hadoop.util.Shell.run(Shell.java:504)\n\tat org.apache.hadoop.util.Shell$ShellCommandExecutor.execute(Shell.java:786)\n\tat org.apache.hadoop.yarn.server.nodemanager.DefaultContainerExecutor.launchContainer(DefaultContainerExecutor.java:213)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:302)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:82)\n\tat java.util.concurrent.FutureTask.run(FutureTask.java:266)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n\n\nContainer exited with a non-zero exit code 56\n"} +{"Event":"SparkListenerTaskStart","Stage ID":9,"Stage Attempt ID":0,"Task Info":{"Task ID":24,"Index":1,"Attempt":2,"Launch Time":1538441025899,"Executor ID":"7","Host":"rezamemory-3.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":9,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":23,"Index":0,"Attempt":1,"Launch Time":1538441020315,"Executor ID":"7","Host":"rezamemory-3.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538441025900,"Failed":false,"Killed":false,"Accumulables":[{"ID":248,"Name":"internal.metrics.input.recordsRead","Update":8,"Value":8,"Internal":true,"Count Failed Values":true},{"ID":247,"Name":"internal.metrics.input.bytesRead","Update":72,"Value":72,"Internal":true,"Count Failed Values":true},{"ID":246,"Name":"internal.metrics.shuffle.write.writeTime","Update":3971129,"Value":3971129,"Internal":true,"Count Failed Values":true},{"ID":245,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":2,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":244,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":355,"Value":355,"Internal":true,"Count Failed Values":true},{"ID":235,"Name":"internal.metrics.peakExecutionMemory","Update":1264,"Value":1264,"Internal":true,"Count Failed Values":true},{"ID":234,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":233,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":231,"Name":"internal.metrics.jvmGCTime","Update":244,"Value":244,"Internal":true,"Count Failed Values":true},{"ID":230,"Name":"internal.metrics.resultSize","Update":1705,"Value":1705,"Internal":true,"Count Failed Values":true},{"ID":229,"Name":"internal.metrics.executorCpuTime","Update":1268816374,"Value":1268816374,"Internal":true,"Count Failed Values":true},{"ID":228,"Name":"internal.metrics.executorRunTime","Update":2978,"Value":2978,"Internal":true,"Count Failed Values":true},{"ID":227,"Name":"internal.metrics.executorDeserializeCpuTime","Update":714859741,"Value":714859741,"Internal":true,"Count Failed Values":true},{"ID":226,"Name":"internal.metrics.executorDeserializeTime","Update":2106,"Value":2106,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2106,"Executor Deserialize CPU Time":714859741,"Executor Run Time":2978,"Executor CPU Time":1268816374,"Result Size":1705,"JVM GC Time":244,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":355,"Shuffle Write Time":3971129,"Shuffle Records Written":2},"Input Metrics":{"Bytes Read":72,"Records Read":8},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":9,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":24,"Index":1,"Attempt":2,"Launch Time":1538441025899,"Executor ID":"7","Host":"rezamemory-3.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538441026136,"Failed":false,"Killed":false,"Accumulables":[{"ID":248,"Name":"internal.metrics.input.recordsRead","Update":4,"Value":12,"Internal":true,"Count Failed Values":true},{"ID":247,"Name":"internal.metrics.input.bytesRead","Update":72,"Value":144,"Internal":true,"Count Failed Values":true},{"ID":246,"Name":"internal.metrics.shuffle.write.writeTime","Update":265841,"Value":4236970,"Internal":true,"Count Failed Values":true},{"ID":245,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":244,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":178,"Value":533,"Internal":true,"Count Failed Values":true},{"ID":235,"Name":"internal.metrics.peakExecutionMemory","Update":1088,"Value":2352,"Internal":true,"Count Failed Values":true},{"ID":234,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":233,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":232,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":230,"Name":"internal.metrics.resultSize","Update":1705,"Value":3410,"Internal":true,"Count Failed Values":true},{"ID":229,"Name":"internal.metrics.executorCpuTime","Update":88980290,"Value":1357796664,"Internal":true,"Count Failed Values":true},{"ID":228,"Name":"internal.metrics.executorRunTime","Update":201,"Value":3179,"Internal":true,"Count Failed Values":true},{"ID":227,"Name":"internal.metrics.executorDeserializeCpuTime","Update":8550572,"Value":723410313,"Internal":true,"Count Failed Values":true},{"ID":226,"Name":"internal.metrics.executorDeserializeTime","Update":13,"Value":2119,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":13,"Executor Deserialize CPU Time":8550572,"Executor Run Time":201,"Executor CPU Time":88980290,"Result Size":1705,"JVM GC Time":0,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":178,"Shuffle Write Time":265841,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":72,"Records Read":4},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"driver","Stage ID":9,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":332727504,"JVMOffHeapMemory":103237664,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":519462,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":519462,"OffHeapUnifiedMemory":0,"DirectPoolMemory":228406,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":5011247104,"ProcessTreeJVMRSSMemory":658915328,"ProcessTreePythonVMemory":408375296,"ProcessTreePythonRSSMemory":40284160,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"5","Stage ID":9,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":184519808,"JVMOffHeapMemory":58341088,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":1116714,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":1116714,"OffHeapUnifiedMemory":0,"DirectPoolMemory":20420,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":2998673408,"ProcessTreeJVMRSSMemory":378527744,"ProcessTreePythonVMemory":0,"ProcessTreePythonRSSMemory":0,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"7","Stage ID":9,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":218694008,"JVMOffHeapMemory":60757008,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":1482103,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":1482103,"OffHeapUnifiedMemory":0,"DirectPoolMemory":20668,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":3020120064,"ProcessTreeJVMRSSMemory":423698432,"ProcessTreePythonVMemory":958894080,"ProcessTreePythonRSSMemory":106696704,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"4","Stage ID":9,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":220189424,"JVMOffHeapMemory":59534504,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":1116714,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":1116714,"OffHeapUnifiedMemory":0,"DirectPoolMemory":27895,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":3024392192,"ProcessTreeJVMRSSMemory":431939584,"ProcessTreePythonVMemory":283738112,"ProcessTreePythonRSSMemory":27226112,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"6","Stage ID":9,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":208356192,"JVMOffHeapMemory":58297728,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":1094711,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":1094711,"OffHeapUnifiedMemory":0,"DirectPoolMemory":27296,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":3027820544,"ProcessTreeJVMRSSMemory":439750656,"ProcessTreePythonVMemory":286220288,"ProcessTreePythonRSSMemory":30846976,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":9,"Stage Attempt ID":0,"Stage Name":"filter at BisectingKMeans.scala:213","Number of Tasks":2,"RDD Info":[{"RDD ID":24,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"55\",\"name\":\"filter\"}","Callsite":"filter at BisectingKMeans.scala:213","Parent IDs":[23],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":23,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"54\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:372","Parent IDs":[8],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:170","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:169","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":6,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"10\",\"name\":\"zip\"}","Callsite":"zip at BisectingKMeans.scala:169","Parent IDs":[3,5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":3,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"9\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:168","Parent IDs":[3],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.filter(RDD.scala:387)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:213)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538441006612,"Completion Time":1538441026137,"Accumulables":[{"ID":227,"Name":"internal.metrics.executorDeserializeCpuTime","Value":723410313,"Internal":true,"Count Failed Values":true},{"ID":245,"Name":"internal.metrics.shuffle.write.recordsWritten","Value":3,"Internal":true,"Count Failed Values":true},{"ID":226,"Name":"internal.metrics.executorDeserializeTime","Value":2119,"Internal":true,"Count Failed Values":true},{"ID":235,"Name":"internal.metrics.peakExecutionMemory","Value":2352,"Internal":true,"Count Failed Values":true},{"ID":244,"Name":"internal.metrics.shuffle.write.bytesWritten","Value":533,"Internal":true,"Count Failed Values":true},{"ID":229,"Name":"internal.metrics.executorCpuTime","Value":1357796664,"Internal":true,"Count Failed Values":true},{"ID":247,"Name":"internal.metrics.input.bytesRead","Value":144,"Internal":true,"Count Failed Values":true},{"ID":232,"Name":"internal.metrics.resultSerializationTime","Value":1,"Internal":true,"Count Failed Values":true},{"ID":234,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":228,"Name":"internal.metrics.executorRunTime","Value":3179,"Internal":true,"Count Failed Values":true},{"ID":246,"Name":"internal.metrics.shuffle.write.writeTime","Value":4236970,"Internal":true,"Count Failed Values":true},{"ID":231,"Name":"internal.metrics.jvmGCTime","Value":244,"Internal":true,"Count Failed Values":true},{"ID":230,"Name":"internal.metrics.resultSize","Value":3410,"Internal":true,"Count Failed Values":true},{"ID":248,"Name":"internal.metrics.input.recordsRead","Value":12,"Internal":true,"Count Failed Values":true},{"ID":233,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":10,"Stage Attempt ID":0,"Stage Name":"collect at BisectingKMeans.scala:304","Number of Tasks":2,"RDD Info":[{"RDD ID":26,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"57\",\"name\":\"mapValues\"}","Callsite":"mapValues at BisectingKMeans.scala:303","Parent IDs":[25],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":25,"Name":"ShuffledRDD","Scope":"{\"id\":\"56\",\"name\":\"aggregateByKey\"}","Callsite":"aggregateByKey at BisectingKMeans.scala:300","Parent IDs":[24],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[9],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:944)\norg.apache.spark.mllib.clustering.BisectingKMeans$.org$apache$spark$mllib$clustering$BisectingKMeans$$summarize(BisectingKMeans.scala:304)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:216)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538441026138,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"58\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":10,"Stage Attempt ID":0,"Task Info":{"Task ID":25,"Index":0,"Attempt":0,"Launch Time":1538441026147,"Executor ID":"7","Host":"rezamemory-3.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":10,"Stage Attempt ID":0,"Task Info":{"Task ID":26,"Index":1,"Attempt":0,"Launch Time":1538441026309,"Executor ID":"7","Host":"rezamemory-3.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":10,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":25,"Index":0,"Attempt":0,"Launch Time":1538441026147,"Executor ID":"7","Host":"rezamemory-3.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538441026311,"Failed":false,"Killed":false,"Accumulables":[{"ID":268,"Name":"internal.metrics.shuffle.read.recordsRead","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":267,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":266,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":178,"Value":178,"Internal":true,"Count Failed Values":true},{"ID":265,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":264,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":263,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":262,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":260,"Name":"internal.metrics.peakExecutionMemory","Update":800,"Value":800,"Internal":true,"Count Failed Values":true},{"ID":259,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":258,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":255,"Name":"internal.metrics.resultSize","Update":1828,"Value":1828,"Internal":true,"Count Failed Values":true},{"ID":254,"Name":"internal.metrics.executorCpuTime","Update":80311930,"Value":80311930,"Internal":true,"Count Failed Values":true},{"ID":253,"Name":"internal.metrics.executorRunTime","Update":89,"Value":89,"Internal":true,"Count Failed Values":true},{"ID":252,"Name":"internal.metrics.executorDeserializeCpuTime","Update":29610969,"Value":29610969,"Internal":true,"Count Failed Values":true},{"ID":251,"Name":"internal.metrics.executorDeserializeTime","Update":62,"Value":62,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":62,"Executor Deserialize CPU Time":29610969,"Executor Run Time":89,"Executor CPU Time":80311930,"Result Size":1828,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":1,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":178,"Total Records Read":1},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":10,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":26,"Index":1,"Attempt":0,"Launch Time":1538441026309,"Executor ID":"7","Host":"rezamemory-3.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538441026375,"Failed":false,"Killed":false,"Accumulables":[{"ID":268,"Name":"internal.metrics.shuffle.read.recordsRead","Update":2,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":267,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":266,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":355,"Value":533,"Internal":true,"Count Failed Values":true},{"ID":265,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":264,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":263,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":2,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":262,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":260,"Name":"internal.metrics.peakExecutionMemory","Update":992,"Value":1792,"Internal":true,"Count Failed Values":true},{"ID":259,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":258,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":255,"Name":"internal.metrics.resultSize","Update":1828,"Value":3656,"Internal":true,"Count Failed Values":true},{"ID":254,"Name":"internal.metrics.executorCpuTime","Update":18625831,"Value":98937761,"Internal":true,"Count Failed Values":true},{"ID":253,"Name":"internal.metrics.executorRunTime","Update":38,"Value":127,"Internal":true,"Count Failed Values":true},{"ID":252,"Name":"internal.metrics.executorDeserializeCpuTime","Update":6238101,"Value":35849070,"Internal":true,"Count Failed Values":true},{"ID":251,"Name":"internal.metrics.executorDeserializeTime","Update":6,"Value":68,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":6,"Executor Deserialize CPU Time":6238101,"Executor Run Time":38,"Executor CPU Time":18625831,"Result Size":1828,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":2,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":355,"Total Records Read":2},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"driver","Stage ID":10,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":341644736,"JVMOffHeapMemory":103378144,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":541469,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":541469,"OffHeapUnifiedMemory":0,"DirectPoolMemory":228406,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":5011247104,"ProcessTreeJVMRSSMemory":658989056,"ProcessTreePythonVMemory":408375296,"ProcessTreePythonRSSMemory":40284160,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"7","Stage ID":10,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":228132872,"JVMOffHeapMemory":61634808,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":455614,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":455614,"OffHeapUnifiedMemory":0,"DirectPoolMemory":20669,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":3021172736,"ProcessTreeJVMRSSMemory":436867072,"ProcessTreePythonVMemory":958894080,"ProcessTreePythonRSSMemory":106696704,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":10,"Stage Attempt ID":0,"Stage Name":"collect at BisectingKMeans.scala:304","Number of Tasks":2,"RDD Info":[{"RDD ID":26,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"57\",\"name\":\"mapValues\"}","Callsite":"mapValues at BisectingKMeans.scala:303","Parent IDs":[25],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":25,"Name":"ShuffledRDD","Scope":"{\"id\":\"56\",\"name\":\"aggregateByKey\"}","Callsite":"aggregateByKey at BisectingKMeans.scala:300","Parent IDs":[24],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[9],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:944)\norg.apache.spark.mllib.clustering.BisectingKMeans$.org$apache$spark$mllib$clustering$BisectingKMeans$$summarize(BisectingKMeans.scala:304)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:216)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538441026138,"Completion Time":1538441026376,"Accumulables":[{"ID":254,"Name":"internal.metrics.executorCpuTime","Value":98937761,"Internal":true,"Count Failed Values":true},{"ID":262,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Value":0,"Internal":true,"Count Failed Values":true},{"ID":253,"Name":"internal.metrics.executorRunTime","Value":127,"Internal":true,"Count Failed Values":true},{"ID":265,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Value":0,"Internal":true,"Count Failed Values":true},{"ID":259,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":268,"Name":"internal.metrics.shuffle.read.recordsRead","Value":3,"Internal":true,"Count Failed Values":true},{"ID":267,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Value":0,"Internal":true,"Count Failed Values":true},{"ID":258,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":252,"Name":"internal.metrics.executorDeserializeCpuTime","Value":35849070,"Internal":true,"Count Failed Values":true},{"ID":255,"Name":"internal.metrics.resultSize","Value":3656,"Internal":true,"Count Failed Values":true},{"ID":264,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Value":0,"Internal":true,"Count Failed Values":true},{"ID":263,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Value":3,"Internal":true,"Count Failed Values":true},{"ID":266,"Name":"internal.metrics.shuffle.read.localBytesRead","Value":533,"Internal":true,"Count Failed Values":true},{"ID":260,"Name":"internal.metrics.peakExecutionMemory","Value":1792,"Internal":true,"Count Failed Values":true},{"ID":251,"Name":"internal.metrics.executorDeserializeTime","Value":68,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerJobEnd","Job ID":5,"Completion Time":1538441026376,"Job Result":{"Result":"JobSucceeded"}} +{"Event":"SparkListenerJobStart","Job ID":6,"Submission Time":1538441026404,"Stage Infos":[{"Stage ID":12,"Stage Attempt ID":0,"Stage Name":"collect at BisectingKMeans.scala:304","Number of Tasks":2,"RDD Info":[{"RDD ID":30,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"67\",\"name\":\"mapValues\"}","Callsite":"mapValues at BisectingKMeans.scala:303","Parent IDs":[29],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":29,"Name":"ShuffledRDD","Scope":"{\"id\":\"66\",\"name\":\"aggregateByKey\"}","Callsite":"aggregateByKey at BisectingKMeans.scala:300","Parent IDs":[28],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[11],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:944)\norg.apache.spark.mllib.clustering.BisectingKMeans$.org$apache$spark$mllib$clustering$BisectingKMeans$$summarize(BisectingKMeans.scala:304)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:216)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Accumulables":[]},{"Stage ID":11,"Stage Attempt ID":0,"Stage Name":"filter at BisectingKMeans.scala:213","Number of Tasks":2,"RDD Info":[{"RDD ID":28,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"65\",\"name\":\"filter\"}","Callsite":"filter at BisectingKMeans.scala:213","Parent IDs":[27],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":27,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"64\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:372","Parent IDs":[8],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:170","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:169","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":6,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"10\",\"name\":\"zip\"}","Callsite":"zip at BisectingKMeans.scala:169","Parent IDs":[3,5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":3,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"9\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:168","Parent IDs":[3],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.filter(RDD.scala:387)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:213)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Accumulables":[]}],"Stage IDs":[12,11],"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"68\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":11,"Stage Attempt ID":0,"Stage Name":"filter at BisectingKMeans.scala:213","Number of Tasks":2,"RDD Info":[{"RDD ID":28,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"65\",\"name\":\"filter\"}","Callsite":"filter at BisectingKMeans.scala:213","Parent IDs":[27],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":27,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"64\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:372","Parent IDs":[8],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:170","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:169","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":6,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"10\",\"name\":\"zip\"}","Callsite":"zip at BisectingKMeans.scala:169","Parent IDs":[3,5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":3,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"9\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:168","Parent IDs":[3],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.filter(RDD.scala:387)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:213)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538441026408,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"68\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":11,"Stage Attempt ID":0,"Task Info":{"Task ID":27,"Index":0,"Attempt":0,"Launch Time":1538441026450,"Executor ID":"7","Host":"rezamemory-3.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":11,"Stage Attempt ID":0,"Task Info":{"Task ID":28,"Index":1,"Attempt":0,"Launch Time":1538441026585,"Executor ID":"7","Host":"rezamemory-3.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":11,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":27,"Index":0,"Attempt":0,"Launch Time":1538441026450,"Executor ID":"7","Host":"rezamemory-3.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538441026586,"Failed":false,"Killed":false,"Accumulables":[{"ID":298,"Name":"internal.metrics.input.recordsRead","Update":8,"Value":8,"Internal":true,"Count Failed Values":true},{"ID":297,"Name":"internal.metrics.input.bytesRead","Update":72,"Value":72,"Internal":true,"Count Failed Values":true},{"ID":296,"Name":"internal.metrics.shuffle.write.writeTime","Update":278446,"Value":278446,"Internal":true,"Count Failed Values":true},{"ID":295,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":2,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":294,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":355,"Value":355,"Internal":true,"Count Failed Values":true},{"ID":285,"Name":"internal.metrics.peakExecutionMemory","Update":1264,"Value":1264,"Internal":true,"Count Failed Values":true},{"ID":284,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":283,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":280,"Name":"internal.metrics.resultSize","Update":1662,"Value":1662,"Internal":true,"Count Failed Values":true},{"ID":279,"Name":"internal.metrics.executorCpuTime","Update":23317154,"Value":23317154,"Internal":true,"Count Failed Values":true},{"ID":278,"Name":"internal.metrics.executorRunTime","Update":69,"Value":69,"Internal":true,"Count Failed Values":true},{"ID":277,"Name":"internal.metrics.executorDeserializeCpuTime","Update":17832528,"Value":17832528,"Internal":true,"Count Failed Values":true},{"ID":276,"Name":"internal.metrics.executorDeserializeTime","Update":53,"Value":53,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":53,"Executor Deserialize CPU Time":17832528,"Executor Run Time":69,"Executor CPU Time":23317154,"Result Size":1662,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":355,"Shuffle Write Time":278446,"Shuffle Records Written":2},"Input Metrics":{"Bytes Read":72,"Records Read":8},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":11,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":28,"Index":1,"Attempt":0,"Launch Time":1538441026585,"Executor ID":"7","Host":"rezamemory-3.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538441026700,"Failed":false,"Killed":false,"Accumulables":[{"ID":298,"Name":"internal.metrics.input.recordsRead","Update":4,"Value":12,"Internal":true,"Count Failed Values":true},{"ID":297,"Name":"internal.metrics.input.bytesRead","Update":36,"Value":108,"Internal":true,"Count Failed Values":true},{"ID":296,"Name":"internal.metrics.shuffle.write.writeTime","Update":215244,"Value":493690,"Internal":true,"Count Failed Values":true},{"ID":295,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":294,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":178,"Value":533,"Internal":true,"Count Failed Values":true},{"ID":285,"Name":"internal.metrics.peakExecutionMemory","Update":1088,"Value":2352,"Internal":true,"Count Failed Values":true},{"ID":284,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":283,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":280,"Name":"internal.metrics.resultSize","Update":1662,"Value":3324,"Internal":true,"Count Failed Values":true},{"ID":279,"Name":"internal.metrics.executorCpuTime","Update":23292541,"Value":46609695,"Internal":true,"Count Failed Values":true},{"ID":278,"Name":"internal.metrics.executorRunTime","Update":94,"Value":163,"Internal":true,"Count Failed Values":true},{"ID":277,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4400590,"Value":22233118,"Internal":true,"Count Failed Values":true},{"ID":276,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":57,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":4400590,"Executor Run Time":94,"Executor CPU Time":23292541,"Result Size":1662,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":178,"Shuffle Write Time":215244,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":36,"Records Read":4},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"driver","Stage ID":11,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":198912952,"JVMOffHeapMemory":104016864,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":554933,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":554933,"OffHeapUnifiedMemory":0,"DirectPoolMemory":228407,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":5040721920,"ProcessTreeJVMRSSMemory":705302528,"ProcessTreePythonVMemory":408375296,"ProcessTreePythonRSSMemory":40284160,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"7","Stage ID":11,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":249428840,"JVMOffHeapMemory":62917480,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":455614,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":455614,"OffHeapUnifiedMemory":0,"DirectPoolMemory":20911,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":3035901952,"ProcessTreeJVMRSSMemory":447041536,"ProcessTreePythonVMemory":958894080,"ProcessTreePythonRSSMemory":106696704,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":11,"Stage Attempt ID":0,"Stage Name":"filter at BisectingKMeans.scala:213","Number of Tasks":2,"RDD Info":[{"RDD ID":28,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"65\",\"name\":\"filter\"}","Callsite":"filter at BisectingKMeans.scala:213","Parent IDs":[27],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":27,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"64\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:372","Parent IDs":[8],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":8,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"12\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:170","Parent IDs":[7],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":7,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"11\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:169","Parent IDs":[6],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":2,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":6,"Name":"ZippedPartitionsRDD2","Scope":"{\"id\":\"10\",\"name\":\"zip\"}","Callsite":"zip at BisectingKMeans.scala:169","Parent IDs":[3,5],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":3,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[2],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":5,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"9\",\"name\":\"map\"}","Callsite":"map at BisectingKMeans.scala:168","Parent IDs":[3],"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.filter(RDD.scala:387)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:213)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538441026408,"Completion Time":1538441026701,"Accumulables":[{"ID":295,"Name":"internal.metrics.shuffle.write.recordsWritten","Value":3,"Internal":true,"Count Failed Values":true},{"ID":298,"Name":"internal.metrics.input.recordsRead","Value":12,"Internal":true,"Count Failed Values":true},{"ID":280,"Name":"internal.metrics.resultSize","Value":3324,"Internal":true,"Count Failed Values":true},{"ID":283,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":277,"Name":"internal.metrics.executorDeserializeCpuTime","Value":22233118,"Internal":true,"Count Failed Values":true},{"ID":294,"Name":"internal.metrics.shuffle.write.bytesWritten","Value":533,"Internal":true,"Count Failed Values":true},{"ID":276,"Name":"internal.metrics.executorDeserializeTime","Value":57,"Internal":true,"Count Failed Values":true},{"ID":285,"Name":"internal.metrics.peakExecutionMemory","Value":2352,"Internal":true,"Count Failed Values":true},{"ID":279,"Name":"internal.metrics.executorCpuTime","Value":46609695,"Internal":true,"Count Failed Values":true},{"ID":297,"Name":"internal.metrics.input.bytesRead","Value":108,"Internal":true,"Count Failed Values":true},{"ID":284,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":278,"Name":"internal.metrics.executorRunTime","Value":163,"Internal":true,"Count Failed Values":true},{"ID":296,"Name":"internal.metrics.shuffle.write.writeTime","Value":493690,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":12,"Stage Attempt ID":0,"Stage Name":"collect at BisectingKMeans.scala:304","Number of Tasks":2,"RDD Info":[{"RDD ID":30,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"67\",\"name\":\"mapValues\"}","Callsite":"mapValues at BisectingKMeans.scala:303","Parent IDs":[29],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":29,"Name":"ShuffledRDD","Scope":"{\"id\":\"66\",\"name\":\"aggregateByKey\"}","Callsite":"aggregateByKey at BisectingKMeans.scala:300","Parent IDs":[28],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[11],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:944)\norg.apache.spark.mllib.clustering.BisectingKMeans$.org$apache$spark$mllib$clustering$BisectingKMeans$$summarize(BisectingKMeans.scala:304)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:216)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538441026702,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"68\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":12,"Stage Attempt ID":0,"Task Info":{"Task ID":29,"Index":0,"Attempt":0,"Launch Time":1538441026714,"Executor ID":"7","Host":"rezamemory-3.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":12,"Stage Attempt ID":0,"Task Info":{"Task ID":30,"Index":1,"Attempt":0,"Launch Time":1538441026794,"Executor ID":"7","Host":"rezamemory-3.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":12,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":29,"Index":0,"Attempt":0,"Launch Time":1538441026714,"Executor ID":"7","Host":"rezamemory-3.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538441026795,"Failed":false,"Killed":false,"Accumulables":[{"ID":318,"Name":"internal.metrics.shuffle.read.recordsRead","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":317,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":316,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":178,"Value":178,"Internal":true,"Count Failed Values":true},{"ID":315,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":314,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":313,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":312,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":310,"Name":"internal.metrics.peakExecutionMemory","Update":800,"Value":800,"Internal":true,"Count Failed Values":true},{"ID":309,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":308,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":307,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":305,"Name":"internal.metrics.resultSize","Update":1871,"Value":1871,"Internal":true,"Count Failed Values":true},{"ID":304,"Name":"internal.metrics.executorCpuTime","Update":16951615,"Value":16951615,"Internal":true,"Count Failed Values":true},{"ID":303,"Name":"internal.metrics.executorRunTime","Update":28,"Value":28,"Internal":true,"Count Failed Values":true},{"ID":302,"Name":"internal.metrics.executorDeserializeCpuTime","Update":12613041,"Value":12613041,"Internal":true,"Count Failed Values":true},{"ID":301,"Name":"internal.metrics.executorDeserializeTime","Update":31,"Value":31,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":31,"Executor Deserialize CPU Time":12613041,"Executor Run Time":28,"Executor CPU Time":16951615,"Result Size":1871,"JVM GC Time":0,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":1,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":178,"Total Records Read":1},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":12,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":30,"Index":1,"Attempt":0,"Launch Time":1538441026794,"Executor ID":"7","Host":"rezamemory-3.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538441026839,"Failed":false,"Killed":false,"Accumulables":[{"ID":318,"Name":"internal.metrics.shuffle.read.recordsRead","Update":2,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":317,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":316,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":355,"Value":533,"Internal":true,"Count Failed Values":true},{"ID":315,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":314,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":313,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":2,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":312,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":310,"Name":"internal.metrics.peakExecutionMemory","Update":992,"Value":1792,"Internal":true,"Count Failed Values":true},{"ID":309,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":308,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":307,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":305,"Name":"internal.metrics.resultSize","Update":1871,"Value":3742,"Internal":true,"Count Failed Values":true},{"ID":304,"Name":"internal.metrics.executorCpuTime","Update":17828037,"Value":34779652,"Internal":true,"Count Failed Values":true},{"ID":303,"Name":"internal.metrics.executorRunTime","Update":24,"Value":52,"Internal":true,"Count Failed Values":true},{"ID":302,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3879530,"Value":16492571,"Internal":true,"Count Failed Values":true},{"ID":301,"Name":"internal.metrics.executorDeserializeTime","Update":5,"Value":36,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":5,"Executor Deserialize CPU Time":3879530,"Executor Run Time":24,"Executor CPU Time":17828037,"Result Size":1871,"JVM GC Time":0,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":2,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":355,"Total Records Read":2},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"driver","Stage ID":12,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":204287872,"JVMOffHeapMemory":104055736,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":519458,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":519458,"OffHeapUnifiedMemory":0,"DirectPoolMemory":228407,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":5047037952,"ProcessTreeJVMRSSMemory":708661248,"ProcessTreePythonVMemory":408375296,"ProcessTreePythonRSSMemory":40284160,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"7","Stage ID":12,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":252161344,"JVMOffHeapMemory":63019944,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":441078,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":441078,"OffHeapUnifiedMemory":0,"DirectPoolMemory":20911,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":3038007296,"ProcessTreeJVMRSSMemory":451837952,"ProcessTreePythonVMemory":958894080,"ProcessTreePythonRSSMemory":106696704,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":12,"Stage Attempt ID":0,"Stage Name":"collect at BisectingKMeans.scala:304","Number of Tasks":2,"RDD Info":[{"RDD ID":30,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"67\",\"name\":\"mapValues\"}","Callsite":"mapValues at BisectingKMeans.scala:303","Parent IDs":[29],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":29,"Name":"ShuffledRDD","Scope":"{\"id\":\"66\",\"name\":\"aggregateByKey\"}","Callsite":"aggregateByKey at BisectingKMeans.scala:300","Parent IDs":[28],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[11],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:944)\norg.apache.spark.mllib.clustering.BisectingKMeans$.org$apache$spark$mllib$clustering$BisectingKMeans$$summarize(BisectingKMeans.scala:304)\norg.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$run$1.apply$mcVI$sp(BisectingKMeans.scala:216)\nscala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:210)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:255)\norg.apache.spark.mllib.clustering.BisectingKMeans.run(BisectingKMeans.scala:261)\norg.apache.spark.mllib.api.python.PythonMLLibAPI.trainBisectingKMeans(PythonMLLibAPI.scala:135)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538441026702,"Completion Time":1538441026840,"Accumulables":[{"ID":304,"Name":"internal.metrics.executorCpuTime","Value":34779652,"Internal":true,"Count Failed Values":true},{"ID":313,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Value":3,"Internal":true,"Count Failed Values":true},{"ID":307,"Name":"internal.metrics.resultSerializationTime","Value":2,"Internal":true,"Count Failed Values":true},{"ID":316,"Name":"internal.metrics.shuffle.read.localBytesRead","Value":533,"Internal":true,"Count Failed Values":true},{"ID":301,"Name":"internal.metrics.executorDeserializeTime","Value":36,"Internal":true,"Count Failed Values":true},{"ID":310,"Name":"internal.metrics.peakExecutionMemory","Value":1792,"Internal":true,"Count Failed Values":true},{"ID":318,"Name":"internal.metrics.shuffle.read.recordsRead","Value":3,"Internal":true,"Count Failed Values":true},{"ID":309,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":303,"Name":"internal.metrics.executorRunTime","Value":52,"Internal":true,"Count Failed Values":true},{"ID":312,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Value":0,"Internal":true,"Count Failed Values":true},{"ID":315,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Value":0,"Internal":true,"Count Failed Values":true},{"ID":317,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Value":0,"Internal":true,"Count Failed Values":true},{"ID":308,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":302,"Name":"internal.metrics.executorDeserializeCpuTime","Value":16492571,"Internal":true,"Count Failed Values":true},{"ID":314,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Value":0,"Internal":true,"Count Failed Values":true},{"ID":305,"Name":"internal.metrics.resultSize","Value":3742,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerJobEnd","Job ID":6,"Completion Time":1538441026840,"Job Result":{"Result":"JobSucceeded"}} +{"Event":"SparkListenerUnpersistRDD","RDD ID":32} +{"Event":"SparkListenerUnpersistRDD","RDD ID":5} +{"Event":"SparkListenerJobStart","Job ID":7,"Submission Time":1538441026935,"Stage Infos":[{"Stage ID":13,"Stage Attempt ID":0,"Stage Name":"sum at BisectingKMeansModel.scala:101","Number of Tasks":2,"RDD Info":[{"RDD ID":36,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"78\",\"name\":\"map\"}","Callsite":"map at BisectingKMeansModel.scala:101","Parent IDs":[35],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":35,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"77\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[34],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":34,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.DoubleRDDFunctions.sum(DoubleRDDFunctions.scala:34)\norg.apache.spark.mllib.clustering.BisectingKMeansModel.computeCost(BisectingKMeansModel.scala:101)\norg.apache.spark.mllib.clustering.BisectingKMeansModel.computeCost(BisectingKMeansModel.scala:108)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Accumulables":[]}],"Stage IDs":[13],"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"79\",\"name\":\"sum\"}"}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":13,"Stage Attempt ID":0,"Stage Name":"sum at BisectingKMeansModel.scala:101","Number of Tasks":2,"RDD Info":[{"RDD ID":36,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"78\",\"name\":\"map\"}","Callsite":"map at BisectingKMeansModel.scala:101","Parent IDs":[35],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":35,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"77\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[34],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":34,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.DoubleRDDFunctions.sum(DoubleRDDFunctions.scala:34)\norg.apache.spark.mllib.clustering.BisectingKMeansModel.computeCost(BisectingKMeansModel.scala:101)\norg.apache.spark.mllib.clustering.BisectingKMeansModel.computeCost(BisectingKMeansModel.scala:108)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538441026936,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"79\",\"name\":\"sum\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":13,"Stage Attempt ID":0,"Task Info":{"Task ID":31,"Index":0,"Attempt":0,"Launch Time":1538441026947,"Executor ID":"7","Host":"rezamemory-3.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerBlockManagerRemoved","Block Manager ID":{"Executor ID":"7","Host":"rezamemory-3.gce.something.com","Port":40992},"Timestamp":1538441027285} +{"Event":"SparkListenerTaskEnd","Stage ID":13,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"ExecutorLostFailure","Executor ID":"7","Exit Caused By App":true,"Loss Reason":"Container marked as failed: container_1538416563558_0014_01_000008 on host: rezamemory-3.gce.something.com. Exit status: 56. Diagnostics: Exception from container-launch.\nContainer id: container_1538416563558_0014_01_000008\nExit code: 56\nStack trace: ExitCodeException exitCode=56: \n\tat org.apache.hadoop.util.Shell.runCommand(Shell.java:601)\n\tat org.apache.hadoop.util.Shell.run(Shell.java:504)\n\tat org.apache.hadoop.util.Shell$ShellCommandExecutor.execute(Shell.java:786)\n\tat org.apache.hadoop.yarn.server.nodemanager.DefaultContainerExecutor.launchContainer(DefaultContainerExecutor.java:213)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:302)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:82)\n\tat java.util.concurrent.FutureTask.run(FutureTask.java:266)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n\n\nContainer exited with a non-zero exit code 56\n"},"Task Info":{"Task ID":31,"Index":0,"Attempt":0,"Launch Time":1538441026947,"Executor ID":"7","Host":"rezamemory-3.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538441027494,"Failed":true,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerExecutorRemoved","Timestamp":1538441027495,"Executor ID":"7","Removed Reason":"Container marked as failed: container_1538416563558_0014_01_000008 on host: rezamemory-3.gce.something.com. Exit status: 56. Diagnostics: Exception from container-launch.\nContainer id: container_1538416563558_0014_01_000008\nExit code: 56\nStack trace: ExitCodeException exitCode=56: \n\tat org.apache.hadoop.util.Shell.runCommand(Shell.java:601)\n\tat org.apache.hadoop.util.Shell.run(Shell.java:504)\n\tat org.apache.hadoop.util.Shell$ShellCommandExecutor.execute(Shell.java:786)\n\tat org.apache.hadoop.yarn.server.nodemanager.DefaultContainerExecutor.launchContainer(DefaultContainerExecutor.java:213)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:302)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:82)\n\tat java.util.concurrent.FutureTask.run(FutureTask.java:266)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n\n\nContainer exited with a non-zero exit code 56\n"} +{"Event":"SparkListenerExecutorAdded","Timestamp":1538441032740,"Executor ID":"8","Executor Info":{"Host":"rezamemory-2.gce.something.com","Total Cores":1,"Log Urls":{"stdout":"http://rezamemory-2.gce.something.com:8042/node/containerlogs/container_1538416563558_0014_01_000009/root/stdout?start=-4096","stderr":"http://rezamemory-2.gce.something.com:8042/node/containerlogs/container_1538416563558_0014_01_000009/root/stderr?start=-4096"}}} +{"Event":"SparkListenerTaskStart","Stage ID":13,"Stage Attempt ID":0,"Task Info":{"Task ID":32,"Index":0,"Attempt":1,"Launch Time":1538441032741,"Executor ID":"8","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"8","Host":"rezamemory-2.gce.something.com","Port":41485},"Maximum Memory":384093388,"Timestamp":1538441033142,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1538441036142,"Executor ID":"9","Executor Info":{"Host":"rezamemory-2.gce.something.com","Total Cores":1,"Log Urls":{"stdout":"http://rezamemory-2.gce.something.com:8042/node/containerlogs/container_1538416563558_0014_01_000010/root/stdout?start=-4096","stderr":"http://rezamemory-2.gce.something.com:8042/node/containerlogs/container_1538416563558_0014_01_000010/root/stderr?start=-4096"}}} +{"Event":"SparkListenerTaskStart","Stage ID":13,"Stage Attempt ID":0,"Task Info":{"Task ID":33,"Index":1,"Attempt":0,"Launch Time":1538441036144,"Executor ID":"9","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"9","Host":"rezamemory-2.gce.something.com","Port":40797},"Maximum Memory":384093388,"Timestamp":1538441036560,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":0} +{"Event":"SparkListenerBlockManagerRemoved","Block Manager ID":{"Executor ID":"8","Host":"rezamemory-2.gce.something.com","Port":41485},"Timestamp":1538441040323} +{"Event":"SparkListenerTaskEnd","Stage ID":13,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"ExecutorLostFailure","Executor ID":"8","Exit Caused By App":true,"Loss Reason":"Container marked as failed: container_1538416563558_0014_01_000009 on host: rezamemory-2.gce.something.com. Exit status: 56. Diagnostics: Exception from container-launch.\nContainer id: container_1538416563558_0014_01_000009\nExit code: 56\nStack trace: ExitCodeException exitCode=56: \n\tat org.apache.hadoop.util.Shell.runCommand(Shell.java:601)\n\tat org.apache.hadoop.util.Shell.run(Shell.java:504)\n\tat org.apache.hadoop.util.Shell$ShellCommandExecutor.execute(Shell.java:786)\n\tat org.apache.hadoop.yarn.server.nodemanager.DefaultContainerExecutor.launchContainer(DefaultContainerExecutor.java:213)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:302)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:82)\n\tat java.util.concurrent.FutureTask.run(FutureTask.java:266)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n\n\nContainer exited with a non-zero exit code 56\n"},"Task Info":{"Task ID":32,"Index":0,"Attempt":1,"Launch Time":1538441032741,"Executor ID":"8","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538441040533,"Failed":true,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerExecutorRemoved","Timestamp":1538441040534,"Executor ID":"8","Removed Reason":"Container marked as failed: container_1538416563558_0014_01_000009 on host: rezamemory-2.gce.something.com. Exit status: 56. Diagnostics: Exception from container-launch.\nContainer id: container_1538416563558_0014_01_000009\nExit code: 56\nStack trace: ExitCodeException exitCode=56: \n\tat org.apache.hadoop.util.Shell.runCommand(Shell.java:601)\n\tat org.apache.hadoop.util.Shell.run(Shell.java:504)\n\tat org.apache.hadoop.util.Shell$ShellCommandExecutor.execute(Shell.java:786)\n\tat org.apache.hadoop.yarn.server.nodemanager.DefaultContainerExecutor.launchContainer(DefaultContainerExecutor.java:213)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:302)\n\tat org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:82)\n\tat java.util.concurrent.FutureTask.run(FutureTask.java:266)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n\n\nContainer exited with a non-zero exit code 56\n"} +{"Event":"SparkListenerTaskStart","Stage ID":13,"Stage Attempt ID":0,"Task Info":{"Task ID":34,"Index":0,"Attempt":2,"Launch Time":1538441042184,"Executor ID":"9","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":13,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":33,"Index":1,"Attempt":0,"Launch Time":1538441036144,"Executor ID":"9","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538441042185,"Failed":false,"Killed":false,"Accumulables":[{"ID":348,"Name":"internal.metrics.input.recordsRead","Update":2,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":347,"Name":"internal.metrics.input.bytesRead","Update":36,"Value":36,"Internal":true,"Count Failed Values":true},{"ID":334,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":333,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":332,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":331,"Name":"internal.metrics.jvmGCTime","Update":288,"Value":288,"Internal":true,"Count Failed Values":true},{"ID":330,"Name":"internal.metrics.resultSize","Update":1539,"Value":1539,"Internal":true,"Count Failed Values":true},{"ID":329,"Name":"internal.metrics.executorCpuTime","Update":1278640624,"Value":1278640624,"Internal":true,"Count Failed Values":true},{"ID":328,"Name":"internal.metrics.executorRunTime","Update":2796,"Value":2796,"Internal":true,"Count Failed Values":true},{"ID":327,"Name":"internal.metrics.executorDeserializeCpuTime","Update":720112530,"Value":720112530,"Internal":true,"Count Failed Values":true},{"ID":326,"Name":"internal.metrics.executorDeserializeTime","Update":2587,"Value":2587,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2587,"Executor Deserialize CPU Time":720112530,"Executor Run Time":2796,"Executor CPU Time":1278640624,"Result Size":1539,"JVM GC Time":288,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":36,"Records Read":2},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":13,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":34,"Index":0,"Attempt":2,"Launch Time":1538441042184,"Executor ID":"9","Host":"rezamemory-2.gce.something.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1538441042334,"Failed":false,"Killed":false,"Accumulables":[{"ID":348,"Name":"internal.metrics.input.recordsRead","Update":4,"Value":6,"Internal":true,"Count Failed Values":true},{"ID":347,"Name":"internal.metrics.input.bytesRead","Update":72,"Value":108,"Internal":true,"Count Failed Values":true},{"ID":334,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":333,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":330,"Name":"internal.metrics.resultSize","Update":1453,"Value":2992,"Internal":true,"Count Failed Values":true},{"ID":329,"Name":"internal.metrics.executorCpuTime","Update":69678739,"Value":1348319363,"Internal":true,"Count Failed Values":true},{"ID":328,"Name":"internal.metrics.executorRunTime","Update":118,"Value":2914,"Internal":true,"Count Failed Values":true},{"ID":327,"Name":"internal.metrics.executorDeserializeCpuTime","Update":6252896,"Value":726365426,"Internal":true,"Count Failed Values":true},{"ID":326,"Name":"internal.metrics.executorDeserializeTime","Update":6,"Value":2593,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":6,"Executor Deserialize CPU Time":6252896,"Executor Run Time":118,"Executor CPU Time":69678739,"Result Size":1453,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":72,"Records Read":4},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"driver","Stage ID":13,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":266240264,"JVMOffHeapMemory":104976128,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":534126,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":534126,"OffHeapUnifiedMemory":0,"DirectPoolMemory":228407,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":5067235328,"ProcessTreeJVMRSSMemory":710475776,"ProcessTreePythonVMemory":408375296,"ProcessTreePythonRSSMemory":40284160,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"8","Stage ID":13,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":197860072,"JVMOffHeapMemory":57762424,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":1088805,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":1088805,"OffHeapUnifiedMemory":0,"DirectPoolMemory":25453,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":3028791296,"ProcessTreeJVMRSSMemory":430297088,"ProcessTreePythonVMemory":286212096,"ProcessTreePythonRSSMemory":30441472,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageExecutorMetrics","Executor ID":"9","Stage ID":13,"Stage Attempt ID":0,"Executor Metrics":{"JVMHeapMemory":193766856,"JVMOffHeapMemory":59006656,"OnHeapExecutionMemory":0,"OffHeapExecutionMemory":0,"OnHeapStorageMemory":1088805,"OffHeapStorageMemory":0,"OnHeapUnifiedMemory":1088805,"OffHeapUnifiedMemory":0,"DirectPoolMemory":20181,"MappedPoolMemory":0,"ProcessTreeJVMVMemory":3016261632,"ProcessTreeJVMRSSMemory":405860352,"ProcessTreePythonVMemory":625926144,"ProcessTreePythonRSSMemory":69013504,"ProcessTreeOtherVMemory":0,"ProcessTreeOtherRSSMemory":0}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":13,"Stage Attempt ID":0,"Stage Name":"sum at BisectingKMeansModel.scala:101","Number of Tasks":2,"RDD Info":[{"RDD ID":36,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"78\",\"name\":\"map\"}","Callsite":"map at BisectingKMeansModel.scala:101","Parent IDs":[35],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":35,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"77\",\"name\":\"mapPartitions\"}","Callsite":"mapPartitions at PythonMLLibAPI.scala:1346","Parent IDs":[34],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":34,"Name":"PythonRDD","Callsite":"RDD at PythonRDD.scala:53","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":1,"Name":"data/mllib/kmeans_data.txt","Scope":"{\"id\":\"0\",\"name\":\"textFile\"}","Callsite":"textFile at NativeMethodAccessorImpl.java:0","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":2,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.DoubleRDDFunctions.sum(DoubleRDDFunctions.scala:34)\norg.apache.spark.mllib.clustering.BisectingKMeansModel.computeCost(BisectingKMeansModel.scala:101)\norg.apache.spark.mllib.clustering.BisectingKMeansModel.computeCost(BisectingKMeansModel.scala:108)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\npy4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\npy4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\npy4j.Gateway.invoke(Gateway.java:282)\npy4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\npy4j.commands.CallCommand.execute(CallCommand.java:79)\npy4j.GatewayConnection.run(GatewayConnection.java:238)\njava.lang.Thread.run(Thread.java:745)","Submission Time":1538441026936,"Completion Time":1538441042335,"Accumulables":[{"ID":331,"Name":"internal.metrics.jvmGCTime","Value":288,"Internal":true,"Count Failed Values":true},{"ID":334,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":328,"Name":"internal.metrics.executorRunTime","Value":2914,"Internal":true,"Count Failed Values":true},{"ID":327,"Name":"internal.metrics.executorDeserializeCpuTime","Value":726365426,"Internal":true,"Count Failed Values":true},{"ID":348,"Name":"internal.metrics.input.recordsRead","Value":6,"Internal":true,"Count Failed Values":true},{"ID":330,"Name":"internal.metrics.resultSize","Value":2992,"Internal":true,"Count Failed Values":true},{"ID":333,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":332,"Name":"internal.metrics.resultSerializationTime","Value":1,"Internal":true,"Count Failed Values":true},{"ID":326,"Name":"internal.metrics.executorDeserializeTime","Value":2593,"Internal":true,"Count Failed Values":true},{"ID":347,"Name":"internal.metrics.input.bytesRead","Value":108,"Internal":true,"Count Failed Values":true},{"ID":329,"Name":"internal.metrics.executorCpuTime","Value":1348319363,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerJobEnd","Job ID":7,"Completion Time":1538441042335,"Job Result":{"Result":"JobSucceeded"}} +{"Event":"SparkListenerApplicationEnd","Timestamp":1538441042338} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 7c9f8aba17f3c..2a2d013bacbda 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -83,6 +83,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers .set("spark.testing", "true") .set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) .set("spark.eventLog.logStageExecutorMetrics.enabled", "true") + .set("spark.eventLog.logStageExecutorProcessTreeMetrics.enabled", "true") conf.setAll(extraConf) provider = new FsHistoryProvider(conf) provider.checkForLogs() @@ -131,6 +132,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers "executor list json" -> "applications/local-1422981780767/executors", "executor list with executor metrics json" -> "applications/application_1506645932520_24630151/executors", + "executor list with executor process tree metrics json" -> + "applications/application_1538416563558_0014/executors", "stage list json" -> "applications/local-1422981780767/stages", "complete stage list json" -> "applications/local-1422981780767/stages?status=complete", "failed stage list json" -> "applications/local-1422981780767/stages?status=failed", diff --git a/core/src/test/scala/org/apache/spark/executor/ProcfsMetricsGetterSuite.scala b/core/src/test/scala/org/apache/spark/executor/ProcfsMetricsGetterSuite.scala new file mode 100644 index 0000000000000..9ed1497db5e1d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/executor/ProcfsMetricsGetterSuite.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import org.apache.spark.SparkFunSuite + + +class ProcfsMetricsGetterSuite extends SparkFunSuite { + + val p = new ProcfsMetricsGetter(getTestResourcePath("ProcfsMetrics")) + + test("testGetProcessInfo") { + var r = ProcfsMetrics(0, 0, 0, 0, 0, 0) + r = p.addProcfsMetricsFromOneProcess(r, 26109) + assert(r.jvmVmemTotal == 4769947648L) + assert(r.jvmRSSTotal == 262610944) + assert(r.pythonVmemTotal == 0) + assert(r.pythonRSSTotal == 0) + + r = p.addProcfsMetricsFromOneProcess(r, 22763) + assert(r.pythonVmemTotal == 360595456) + assert(r.pythonRSSTotal == 7831552) + assert(r.jvmVmemTotal == 4769947648L) + assert(r.jvmRSSTotal == 262610944) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index cecd6996df7bd..0c04a93646d7c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -282,53 +282,67 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit // receive 3 metric updates from each executor with just stage 0 running, // with different peak updates for each executor createExecutorMetricsUpdateEvent(1, - new ExecutorMetrics(Array(4000L, 50L, 20L, 0L, 40L, 0L, 60L, 0L, 70L, 20L))), + new ExecutorMetrics(Array(4000L, 50L, 20L, 0L, 40L, 0L, 60L, 0L, 70L, 20L, 7500L, 3500L, + 6500L, 2500L, 5500L, 1500L))), createExecutorMetricsUpdateEvent(2, - new ExecutorMetrics(Array(1500L, 50L, 20L, 0L, 0L, 0L, 20L, 0L, 70L, 0L))), - // exec 1: new stage 0 peaks for metrics at indexes: 2, 4, 6 + new ExecutorMetrics(Array(1500L, 50L, 20L, 0L, 0L, 0L, 20L, 0L, 70L, 0L, 8500L, 3500L, + 7500L, 2500L, 6500L, 1500L))), + // exec 1: new stage 0 peaks for metrics at indexes: 2, 4, 6 createExecutorMetricsUpdateEvent(1, - new ExecutorMetrics(Array(4000L, 50L, 50L, 0L, 50L, 0L, 100L, 0L, 70L, 20L))), + new ExecutorMetrics(Array(4000L, 50L, 50L, 0L, 50L, 0L, 100L, 0L, 70L, 20L, 8000L, 4000L, + 7000L, 3000L, 6000L, 2000L))), // exec 2: new stage 0 peaks for metrics at indexes: 0, 4, 6 createExecutorMetricsUpdateEvent(2, - new ExecutorMetrics(Array(2000L, 50L, 10L, 0L, 10L, 0L, 30L, 0L, 70L, 0L))), + new ExecutorMetrics(Array(2000L, 50L, 10L, 0L, 10L, 0L, 30L, 0L, 70L, 0L, 9000L, 4000L, + 8000L, 3000L, 7000L, 2000L))), // exec 1: new stage 0 peaks for metrics at indexes: 5, 7 createExecutorMetricsUpdateEvent(1, - new ExecutorMetrics(Array(2000L, 40L, 50L, 0L, 40L, 10L, 90L, 10L, 50L, 0L))), + new ExecutorMetrics(Array(2000L, 40L, 50L, 0L, 40L, 10L, 90L, 10L, 50L, 0L, 8000L, 3500L, + 7000L, 2500L, 6000L, 1500L))), // exec 2: new stage 0 peaks for metrics at indexes: 0, 5, 6, 7, 8 createExecutorMetricsUpdateEvent(2, - new ExecutorMetrics(Array(3500L, 50L, 15L, 0L, 10L, 10L, 35L, 10L, 80L, 0L))), + new ExecutorMetrics(Array(3500L, 50L, 15L, 0L, 10L, 10L, 35L, 10L, 80L, 0L, 8500L, 3500L, + 7500L, 2500L, 6500L, 1500L))), // now start stage 1, one more metric update for each executor, and new // peaks for some stage 1 metrics (as listed), initialize stage 1 peaks createStageSubmittedEvent(1), // exec 1: new stage 0 peaks for metrics at indexes: 0, 3, 7; initialize stage 1 peaks createExecutorMetricsUpdateEvent(1, - new ExecutorMetrics(Array(5000L, 30L, 50L, 20L, 30L, 10L, 80L, 30L, 50L, 0L))), - // exec 2: new stage 0 peaks for metrics at indexes: 0, 1, 2, 3, 6, 7, 9; + new ExecutorMetrics(Array(5000L, 30L, 50L, 20L, 30L, 10L, 80L, 30L, 50L, + 0L, 5000L, 3000L, 4000L, 2000L, 3000L, 1000L))), + // exec 2: new stage 0 peaks for metrics at indexes: 0, 1, 3, 6, 7, 9; // initialize stage 1 peaks createExecutorMetricsUpdateEvent(2, - new ExecutorMetrics(Array(7000L, 70L, 50L, 20L, 0L, 10L, 50L, 30L, 10L, 40L))), + new ExecutorMetrics(Array(7000L, 70L, 50L, 20L, 0L, 10L, 50L, 30L, 10L, + 40L, 8000L, 4000L, 7000L, 3000L, 6000L, 2000L))), // complete stage 0, and 3 more updates for each executor with just // stage 1 running createStageCompletedEvent(0), // exec 1: new stage 1 peaks for metrics at indexes: 0, 1, 3 createExecutorMetricsUpdateEvent(1, - new ExecutorMetrics(Array(6000L, 70L, 20L, 30L, 10L, 0L, 30L, 30L, 30L, 0L))), - // enew ExecutorMetrics(xec 2: new stage 1 peaks for metrics at indexes: 3, 4, 7, 8 + new ExecutorMetrics(Array(6000L, 70L, 20L, 30L, 10L, 0L, 30L, 30L, 30L, 0L, 5000L, 3000L, + 4000L, 2000L, 3000L, 1000L))), + // exec 2: new stage 1 peaks for metrics at indexes: 3, 4, 7, 8 createExecutorMetricsUpdateEvent(2, - new ExecutorMetrics(Array(5500L, 30L, 20L, 40L, 10L, 0L, 30L, 40L, 40L, 20L))), + new ExecutorMetrics(Array(5500L, 30L, 20L, 40L, 10L, 0L, 30L, 40L, 40L, + 20L, 8000L, 5000L, 7000L, 4000L, 6000L, 3000L, 5000L, 2000L))), // exec 1: new stage 1 peaks for metrics at indexes: 0, 4, 5, 7 createExecutorMetricsUpdateEvent(1, - new ExecutorMetrics(Array(7000L, 70L, 5L, 25L, 60L, 30L, 65L, 55L, 30L, 0L))), + new ExecutorMetrics(Array(7000L, 70L, 5L, 25L, 60L, 30L, 65L, 55L, 30L, 0L, 3000L, 2500L, + 2000L, 1500L, 1000L, 500L))), // exec 2: new stage 1 peak for metrics at index: 7 createExecutorMetricsUpdateEvent(2, - new ExecutorMetrics(Array(5500L, 40L, 25L, 30L, 10L, 30L, 35L, 60L, 0L, 20L))), + new ExecutorMetrics(Array(5500L, 40L, 25L, 30L, 10L, 30L, 35L, 60L, 0L, + 20L, 7000L, 3000L, 6000L, 2000L, 5000L, 1000L))), // exec 1: no new stage 1 peaks createExecutorMetricsUpdateEvent(1, - new ExecutorMetrics(Array(5500L, 70L, 15L, 20L, 55L, 20L, 70L, 40L, 20L, 0L))), + new ExecutorMetrics(Array(5500L, 70L, 15L, 20L, 55L, 20L, 70L, 40L, 20L, + 0L, 4000L, 2500L, 3000L, 1500L, 2000L, 500L))), createExecutorRemovedEvent(1), // exec 2: new stage 1 peak for metrics at index: 6 createExecutorMetricsUpdateEvent(2, - new ExecutorMetrics(Array(4000L, 20L, 25L, 30L, 10L, 30L, 35L, 60L, 0L, 0L))), + new ExecutorMetrics(Array(4000L, 20L, 25L, 30L, 10L, 30L, 35L, 60L, 0L, 0L, 7000L, + 4000L, 6000L, 3000L, 5000L, 2000L))), createStageCompletedEvent(1), SparkListenerApplicationEnd(1000L)) @@ -342,20 +356,23 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit // expected StageExecutorMetrics, for the given stage id and executor id val expectedMetricsEvents: Map[(Int, String), SparkListenerStageExecutorMetrics] = - Map( - ((0, "1"), - new SparkListenerStageExecutorMetrics("1", 0, 0, - new ExecutorMetrics(Array(5000L, 50L, 50L, 20L, 50L, 10L, 100L, 30L, 70L, 20L)))), - ((0, "2"), - new SparkListenerStageExecutorMetrics("2", 0, 0, - new ExecutorMetrics(Array(7000L, 70L, 50L, 20L, 10L, 10L, 50L, 30L, 80L, 40L)))), - ((1, "1"), - new SparkListenerStageExecutorMetrics("1", 1, 0, - new ExecutorMetrics(Array(7000L, 70L, 50L, 30L, 60L, 30L, 80L, 55L, 50L, 0L)))), - ((1, "2"), - new SparkListenerStageExecutorMetrics("2", 1, 0, - new ExecutorMetrics(Array(7000L, 70L, 50L, 40L, 10L, 30L, 50L, 60L, 40L, 40L))))) - + Map( + ((0, "1"), + new SparkListenerStageExecutorMetrics("1", 0, 0, + new ExecutorMetrics(Array(5000L, 50L, 50L, 20L, 50L, 10L, 100L, 30L, + 70L, 20L, 8000L, 4000L, 7000L, 3000L, 6000L, 2000L)))), + ((0, "2"), + new SparkListenerStageExecutorMetrics("2", 0, 0, + new ExecutorMetrics(Array(7000L, 70L, 50L, 20L, 10L, 10L, 50L, 30L, + 80L, 40L, 9000L, 4000L, 8000L, 3000L, 7000L, 2000L)))), + ((1, "1"), + new SparkListenerStageExecutorMetrics("1", 1, 0, + new ExecutorMetrics(Array(7000L, 70L, 50L, 30L, 60L, 30L, 80L, 55L, + 50L, 0L, 5000L, 3000L, 4000L, 2000L, 3000L, 1000L)))), + ((1, "2"), + new SparkListenerStageExecutorMetrics("2", 1, 0, + new ExecutorMetrics(Array(7000L, 70L, 50L, 40L, 10L, 30L, 50L, 60L, + 40L, 40L, 8000L, 5000L, 7000L, 4000L, 6000L, 3000L))))) // Verify the log file contains the expected events. // Posted events should be logged, except for ExecutorMetricsUpdate events -- these // are consolidated, and the peak values for each stage are logged at stage end. @@ -456,9 +473,9 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit assert(executorMetrics.execId === expectedMetrics.execId) assert(executorMetrics.stageId === expectedMetrics.stageId) assert(executorMetrics.stageAttemptId === expectedMetrics.stageAttemptId) - ExecutorMetricType.values.foreach { metricType => - assert(executorMetrics.executorMetrics.getMetricValue(metricType) === - expectedMetrics.executorMetrics.getMetricValue(metricType)) + ExecutorMetricType.metricToOffset.foreach { metric => + assert(executorMetrics.executorMetrics.getMetricValue(metric._1) === + expectedMetrics.executorMetrics.getMetricValue(metric._1)) } case None => assert(false) diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 61fec8c1d0e4e..71eeb0480245d 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -1367,58 +1367,74 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // receive 3 metric updates from each executor with just stage 0 running, // with different peak updates for each executor listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(1, - Array(4000L, 50L, 20L, 0L, 40L, 0L, 60L, 0L, 70L, 20L))) + Array(4000L, 50L, 20L, 0L, 40L, 0L, 60L, 0L, 70L, 20L, 7500L, 3500L, + 6500L, 2500L, 5500L, 1500L))) listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(2, - Array(1500L, 50L, 20L, 0L, 0L, 0L, 20L, 0L, 70L, 0L))) + Array(1500L, 50L, 20L, 0L, 0L, 0L, 20L, 0L, 70L, 0L, 8500L, 3500L, + 7500L, 2500L, 6500L, 1500L))) // exec 1: new stage 0 peaks for metrics at indexes: 2, 4, 6 listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(1, - Array(4000L, 50L, 50L, 0L, 50L, 0L, 100L, 0L, 70L, 20L))) + Array(4000L, 50L, 50L, 0L, 50L, 0L, 100L, 0L, 70L, 20L, 8000L, 4000L, + 7000L, 3000L, 6000L, 2000L))) // exec 2: new stage 0 peaks for metrics at indexes: 0, 4, 6 listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(2, - Array(2000L, 50L, 10L, 0L, 10L, 0L, 30L, 0L, 70L, 0L))) + Array(2000L, 50L, 10L, 0L, 10L, 0L, 30L, 0L, 70L, 0L, 9000L, 4000L, + 8000L, 3000L, 7000L, 2000L))) // exec 1: new stage 0 peaks for metrics at indexes: 5, 7 listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(1, - Array(2000L, 40L, 50L, 0L, 40L, 10L, 90L, 10L, 50L, 0L))) + Array(2000L, 40L, 50L, 0L, 40L, 10L, 90L, 10L, 50L, 0L, 8000L, 3500L, + 7000L, 2500L, 6000L, 1500L))) // exec 2: new stage 0 peaks for metrics at indexes: 0, 5, 6, 7, 8 listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(2, - Array(3500L, 50L, 15L, 0L, 10L, 10L, 35L, 10L, 80L, 0L))) + Array(3500L, 50L, 15L, 0L, 10L, 10L, 35L, 10L, 80L, 0L, 8500L, 3500L, + 7500L, 2500L, 6500L, 1500L))) // now start stage 1, one more metric update for each executor, and new // peaks for some stage 1 metrics (as listed), initialize stage 1 peaks listener.onStageSubmitted(createStageSubmittedEvent(1)) // exec 1: new stage 0 peaks for metrics at indexes: 0, 3, 7 listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(1, - Array(5000L, 30L, 50L, 20L, 30L, 10L, 80L, 30L, 50L, 0L))) + Array(5000L, 30L, 50L, 20L, 30L, 10L, 80L, 30L, 50L, 0L, 5000L, 3000L, + 4000L, 2000L, 3000L, 1000L))) // exec 2: new stage 0 peaks for metrics at indexes: 0, 1, 2, 3, 6, 7, 9 listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(2, - Array(7000L, 80L, 50L, 20L, 0L, 10L, 50L, 30L, 10L, 40L))) + Array(7000L, 80L, 50L, 20L, 0L, 10L, 50L, 30L, 10L, 40L, 8000L, 4000L, + 7000L, 3000L, 6000L, 2000L))) // complete stage 0, and 3 more updates for each executor with just // stage 1 running listener.onStageCompleted(createStageCompletedEvent(0)) // exec 1: new stage 1 peaks for metrics at indexes: 0, 1, 3 listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(1, - Array(6000L, 70L, 20L, 30L, 10L, 0L, 30L, 30L, 30L, 0L))) + Array(6000L, 70L, 20L, 30L, 10L, 0L, 30L, 30L, 30L, 0L, 5000L, 3000L, + 4000L, 2000L, 3000L, 1000L))) // exec 2: new stage 1 peaks for metrics at indexes: 3, 4, 7, 8 listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(2, - Array(5500L, 30L, 20L, 40L, 10L, 0L, 30L, 40L, 40L, 20L))) + Array(5500L, 30L, 20L, 40L, 10L, 0L, 30L, 40L, 40L, 20L, 8000L, 5000L, + 7000L, 4000L, 6000L, 3000L))) // exec 1: new stage 1 peaks for metrics at indexes: 0, 4, 5, 7 listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(1, - Array(7000L, 70L, 5L, 25L, 60L, 30L, 65L, 55L, 30L, 0L))) + Array(7000L, 70L, 5L, 25L, 60L, 30L, 65L, 55L, 30L, 0L, 3000L, 2500L, 2000L, + 1500L, 1000L, 500L))) // exec 2: new stage 1 peak for metrics at index: 7 listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(2, - Array(5500L, 40L, 25L, 30L, 10L, 30L, 35L, 60L, 0L, 20L))) + Array(5500L, 40L, 25L, 30L, 10L, 30L, 35L, 60L, 0L, 20L, 7000L, 3000L, + 6000L, 2000L, 5000L, 1000L))) // exec 1: no new stage 1 peaks listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(1, - Array(5500L, 70L, 15L, 20L, 55L, 20L, 70L, 40L, 20L, 0L))) + Array(5500L, 70L, 15L, 20L, 55L, 20L, 70L, 40L, 20L, 0L, 4000L, 2500L, + 3000L, 1500, 2000L, 500L))) listener.onExecutorRemoved(createExecutorRemovedEvent(1)) // exec 2: new stage 1 peak for metrics at index: 6 listener.onExecutorMetricsUpdate(createExecutorMetricsUpdateEvent(2, - Array(4000L, 20L, 25L, 30L, 10L, 30L, 35L, 60L, 0L, 0L))) + Array(4000L, 20L, 25L, 30L, 10L, 30L, 35L, 60L, 0L, 0L, 7000L, 4000L, 6000L, + 3000L, 5000L, 2000L))) listener.onStageCompleted(createStageCompletedEvent(1)) // expected peak values for each executor val expectedValues = Map( - "1" -> new ExecutorMetrics(Array(7000L, 70L, 50L, 30L, 60L, 30L, 100L, 55L, 70L, 20L)), - "2" -> new ExecutorMetrics(Array(7000L, 80L, 50L, 40L, 10L, 30L, 50L, 60L, 80L, 40L))) + "1" -> new ExecutorMetrics(Array(7000L, 70L, 50L, 30L, 60L, 30L, 100L, 55L, + 70L, 20L, 8000L, 4000L, 7000L, 3000L, 6000L, 2000L)), + "2" -> new ExecutorMetrics(Array(7000L, 80L, 50L, 40L, 10L, 30L, 50L, 60L, + 80L, 40L, 9000L, 5000L, 8000L, 4000L, 7000L, 3000L))) // check that the stored peak values match the expected values expectedValues.foreach { case (id, metrics) => @@ -1426,8 +1442,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(exec.info.id === id) exec.info.peakMemoryMetrics match { case Some(actual) => - ExecutorMetricType.values.foreach { metricType => - assert(actual.getMetricValue(metricType) === metrics.getMetricValue(metricType)) + ExecutorMetricType.metricToOffset.foreach { metric => + assert(actual.getMetricValue(metric._1) === metrics.getMetricValue(metric._1)) } case _ => assert(false) @@ -1446,23 +1462,29 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { listener.onStageSubmitted(createStageSubmittedEvent(0)) listener.onStageSubmitted(createStageSubmittedEvent(1)) listener.onStageExecutorMetrics(SparkListenerStageExecutorMetrics("1", 0, 0, - new ExecutorMetrics(Array(5000L, 50L, 50L, 20L, 50L, 10L, 100L, 30L, 70L, 20L)))) + new ExecutorMetrics(Array(5000L, 50L, 50L, 20L, 50L, 10L, 100L, 30L, + 70L, 20L, 8000L, 4000L, 7000L, 3000L, 6000L, 2000L)))) listener.onStageExecutorMetrics(SparkListenerStageExecutorMetrics("2", 0, 0, - new ExecutorMetrics(Array(7000L, 70L, 50L, 20L, 10L, 10L, 50L, 30L, 80L, 40L)))) + new ExecutorMetrics(Array(7000L, 70L, 50L, 20L, 10L, 10L, 50L, 30L, 80L, 40L, 9000L, + 4000L, 8000L, 3000L, 7000L, 2000L)))) listener.onStageCompleted(createStageCompletedEvent(0)) // executor 1 is removed before stage 1 has finished, the stage executor metrics // are logged afterwards and should still be used to update the executor metrics. listener.onExecutorRemoved(createExecutorRemovedEvent(1)) listener.onStageExecutorMetrics(SparkListenerStageExecutorMetrics("1", 1, 0, - new ExecutorMetrics(Array(7000L, 70L, 50L, 30L, 60L, 30L, 80L, 55L, 50L, 0L)))) + new ExecutorMetrics(Array(7000L, 70L, 50L, 30L, 60L, 30L, 80L, 55L, 50L, 0L, 5000L, 3000L, + 4000L, 2000L, 3000L, 1000L)))) listener.onStageExecutorMetrics(SparkListenerStageExecutorMetrics("2", 1, 0, - new ExecutorMetrics(Array(7000L, 80L, 50L, 40L, 10L, 30L, 50L, 60L, 40L, 40L)))) + new ExecutorMetrics(Array(7000L, 80L, 50L, 40L, 10L, 30L, 50L, 60L, 40L, 40L, 8000L, 5000L, + 7000L, 4000L, 6000L, 3000L)))) listener.onStageCompleted(createStageCompletedEvent(1)) // expected peak values for each executor val expectedValues = Map( - "1" -> new ExecutorMetrics(Array(7000L, 70L, 50L, 30L, 60L, 30L, 100L, 55L, 70L, 20L)), - "2" -> new ExecutorMetrics(Array(7000L, 80L, 50L, 40L, 10L, 30L, 50L, 60L, 80L, 40L))) + "1" -> new ExecutorMetrics(Array(7000L, 70L, 50L, 30L, 60L, 30L, 100L, 55L, + 70L, 20L, 8000L, 4000L, 7000L, 3000L, 6000L, 2000L)), + "2" -> new ExecutorMetrics(Array(7000L, 80L, 50L, 40L, 10L, 30L, 50L, 60L, + 80L, 40L, 9000L, 5000L, 8000L, 4000L, 7000L, 3000L))) // check that the stored peak values match the expected values for ((id, metrics) <- expectedValues) { @@ -1470,8 +1492,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(exec.info.id === id) exec.info.peakMemoryMetrics match { case Some(actual) => - ExecutorMetricType.values.foreach { metricType => - assert(actual.getMetricValue(metricType) === metrics.getMetricValue(metricType)) + ExecutorMetricType.metricToOffset.foreach { metric => + assert(actual.getMetricValue(metric._1) === metrics.getMetricValue(metric._1)) } case _ => assert(false) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 1e0d2af9a4711..303ca7cb8801a 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -96,7 +96,8 @@ class JsonProtocolSuite extends SparkFunSuite { .accumulators().map(AccumulatorSuite.makeInfo) .zipWithIndex.map { case (a, i) => a.copy(id = i) } val executorUpdates = new ExecutorMetrics( - Array(543L, 123456L, 12345L, 1234L, 123L, 12L, 432L, 321L, 654L, 765L)) + Array(543L, 123456L, 12345L, 1234L, 123L, 12L, 432L, + 321L, 654L, 765L, 256912L, 123456L, 123456L, 61728L, 30364L, 15182L)) SparkListenerExecutorMetricsUpdate("exec3", Seq((1L, 2, 3, accumUpdates)), Some(executorUpdates)) } @@ -105,8 +106,8 @@ class JsonProtocolSuite extends SparkFunSuite { "In your multitude...", 300), RDDBlockId(0, 0), StorageLevel.MEMORY_ONLY, 100L, 0L)) val stageExecutorMetrics = SparkListenerStageExecutorMetrics("1", 2, 3, - new ExecutorMetrics(Array(543L, 123456L, 12345L, 1234L, 123L, 12L, 432L, 321L, 654L, 765L))) - + new ExecutorMetrics(Array(543L, 123456L, 12345L, 1234L, 123L, 12L, 432L, + 321L, 654L, 765L, 256912L, 123456L, 123456L, 61728L, 30364L, 15182L))) testEvent(stageSubmitted, stageSubmittedJsonString) testEvent(stageCompleted, stageCompletedJsonString) testEvent(taskStart, taskStartJsonString) @@ -440,14 +441,14 @@ class JsonProtocolSuite extends SparkFunSuite { test("executorMetricsFromJson backward compatibility: handle missing metrics") { // any missing metrics should be set to 0 - val executorMetrics = new ExecutorMetrics( - Array(12L, 23L, 45L, 67L, 78L, 89L, 90L, 123L, 456L, 789L)) + val executorMetrics = new ExecutorMetrics(Array(12L, 23L, 45L, 67L, 78L, 89L, + 90L, 123L, 456L, 789L, 40L, 20L, 20L, 10L, 20L, 10L)) val oldExecutorMetricsJson = JsonProtocol.executorMetricsToJson(executorMetrics) .removeField( _._1 == "MappedPoolMemory") - val expectedExecutorMetrics = new ExecutorMetrics( - Array(12L, 23L, 45L, 67L, 78L, 89L, 90L, 123L, 456L, 0L)) - assertEquals(expectedExecutorMetrics, + val exepectedExecutorMetrics = new ExecutorMetrics(Array(12L, 23L, 45L, 67L, + 78L, 89L, 90L, 123L, 456L, 0L, 40L, 20L, 20L, 10L, 20L, 10L)) + assertEquals(exepectedExecutorMetrics, JsonProtocol.executorMetricsFromJson(oldExecutorMetricsJson)) } @@ -753,9 +754,9 @@ private[spark] object JsonProtocolSuite extends Assertions { assertStackTraceElementEquals) } - private def assertEquals(metrics1: ExecutorMetrics, metrics2: ExecutorMetrics) { - ExecutorMetricType.values.foreach { metricType => - assert(metrics1.getMetricValue(metricType) === metrics2.getMetricValue(metricType)) + private def assertEquals(metrics1: ExecutorMetrics, metrics2: ExecutorMetrics): Unit = { + ExecutorMetricType.metricToOffset.foreach { metric => + assert(metrics1.getMetricValue(metric._1) === metrics2.getMetricValue(metric._1)) } } @@ -872,13 +873,14 @@ private[spark] object JsonProtocolSuite extends Assertions { if (includeTaskMetrics) { Seq((1L, 1, 1, Seq(makeAccumulableInfo(1, false, false, None), makeAccumulableInfo(2, false, false, None)))) - } else { + } else { Seq() } val executorMetricsUpdate = if (includeExecutorMetrics) { - Some(new ExecutorMetrics(Array(123456L, 543L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L))) - } else { + Some(new ExecutorMetrics(Array(123456L, 543L, 0L, 0L, 0L, 0L, 0L, + 0L, 0L, 0L, 256912L, 123456L, 123456L, 61728L, 30364L, 15182L))) + } else { None } SparkListenerExecutorMetricsUpdate(execId, taskMetrics, executorMetricsUpdate) @@ -2082,7 +2084,13 @@ private[spark] object JsonProtocolSuite extends Assertions { | "OnHeapUnifiedMemory" : 432, | "OffHeapUnifiedMemory" : 321, | "DirectPoolMemory" : 654, - | "MappedPoolMemory" : 765 + | "MappedPoolMemory" : 765, + | "ProcessTreeJVMVMemory": 256912, + | "ProcessTreeJVMRSSMemory": 123456, + | "ProcessTreePythonVMemory": 123456, + | "ProcessTreePythonRSSMemory": 61728, + | "ProcessTreeOtherVMemory": 30364, + | "ProcessTreeOtherRSSMemory": 15182 | } | |} @@ -2105,7 +2113,13 @@ private[spark] object JsonProtocolSuite extends Assertions { | "OnHeapUnifiedMemory" : 432, | "OffHeapUnifiedMemory" : 321, | "DirectPoolMemory" : 654, - | "MappedPoolMemory" : 765 + | "MappedPoolMemory" : 765, + | "ProcessTreeJVMVMemory": 256912, + | "ProcessTreeJVMRSSMemory": 123456, + | "ProcessTreePythonVMemory": 123456, + | "ProcessTreePythonRSSMemory": 61728, + | "ProcessTreeOtherVMemory": 30364, + | "ProcessTreeOtherRSSMemory": 15182 | } |} """.stripMargin diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 777950016801d..8239cbc3a381c 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -82,6 +82,8 @@ app-20161115172038-0000 app-20161116163331-0000 application_1516285256255_0012 application_1506645932520_24630151 +application_1538416563558_0014 +stat local-1422981759269 local-1422981780767 local-1425081759269 From 0a37da68e1cbca0e0120beab916309898022ba85 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 10 Dec 2018 12:04:44 -0800 Subject: [PATCH 2252/2461] [SPARK-26317][BUILD] Upgrade SBT to 0.13.18 ## What changes were proposed in this pull request? SBT 0.13.14 ~ 1.1.1 has a bug on accessing `java.util.Base64.getDecoder` with JDK9+. It's fixed at 1.1.2 and backported to [0.13.18 (released on Nov 28th)](https://github.com/sbt/sbt/releases/tag/v0.13.18). This PR aims to update SBT. ## How was this patch tested? Pass the Jenkins with the building and existing tests. Closes #23270 from dongjoon-hyun/SPARK-26317. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- project/build.properties | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/build.properties b/project/build.properties index d03985d980ec8..23aa187fb35a7 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=0.13.17 +sbt.version=0.13.18 From 82c1ac48a37bcc929db86515bffd602c381415be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9F=A9=E7=94=B0=E7=94=B000222924?= Date: Mon, 10 Dec 2018 18:27:01 -0600 Subject: [PATCH 2253/2461] =?UTF-8?q?[SPARK-25696]=20The=20storage=20memor?= =?UTF-8?q?y=20displayed=20on=20spark=20Application=20UI=20is=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … incorrect. ## What changes were proposed in this pull request? In the reported heartbeat information, the unit of the memory data is bytes, which is converted by the formatBytes() function in the utils.js file before being displayed in the interface. The cardinality of the unit conversion in the formatBytes function is 1000, which should be 1024. Change the cardinality of the unit conversion in the formatBytes function to 1024. ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22683 from httfighter/SPARK-25696. Lead-authored-by: 韩田田00222924 Co-authored-by: han.tiantian@zte.com.cn Signed-off-by: Sean Owen --- R/pkg/R/context.R | 2 +- R/pkg/R/mllib_tree.R | 6 +-- .../org/apache/spark/ui/static/utils.js | 4 +- .../scala/org/apache/spark/SparkContext.scala | 4 +- .../spark/serializer/KryoSerializer.scala | 4 +- .../scala/org/apache/spark/util/Utils.scala | 46 +++++++++---------- .../apache/spark/MapOutputTrackerSuite.scala | 2 +- .../serializer/KryoSerializerSuite.scala | 2 +- .../apache/spark/storage/DiskStoreSuite.scala | 2 +- .../org/apache/spark/util/UtilsSuite.scala | 18 ++++---- docs/configuration.md | 6 +-- docs/hardware-provisioning.md | 4 +- docs/mllib-decision-tree.md | 2 +- docs/running-on-mesos.md | 2 +- docs/spark-standalone.md | 4 +- docs/streaming-kinesis-integration.md | 2 +- docs/tuning.md | 10 ++-- .../linalg/distributed/BlockMatrix.scala | 2 +- .../optimization/GradientDescentSuite.scala | 2 +- python/pyspark/rdd.py | 2 +- python/pyspark/shuffle.py | 4 +- .../spark/deploy/yarn/YarnAllocator.scala | 2 +- .../expressions/NullExpressionsSuite.scala | 4 +- .../catalyst/expressions/OrderingSuite.scala | 2 +- .../exchange/ExchangeCoordinator.scala | 14 +++--- .../execution/python/WindowInPandasExec.scala | 2 +- .../sql/execution/window/WindowExec.scala | 2 +- .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- .../spark/sql/StatisticsCollectionSuite.scala | 12 ++--- 29 files changed, 85 insertions(+), 85 deletions(-) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index e99136723f65b..0207f249f9aa0 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -87,7 +87,7 @@ objectFile <- function(sc, path, minPartitions = NULL) { #' in the list are split into \code{numSlices} slices and distributed to nodes #' in the cluster. #' -#' If size of serialized slices is larger than spark.r.maxAllocationLimit or (200MB), the function +#' If size of serialized slices is larger than spark.r.maxAllocationLimit or (200MiB), the function #' will write it to disk and send the file name to JVM. Also to make sure each slice is not #' larger than that limit, number of slices may be increased. #' diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 0e60842dd44c8..9844061cfd074 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -157,7 +157,7 @@ print.summary.decisionTree <- function(x) { #' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). #' Note: this setting will be ignored if the checkpoint directory is not #' set. -#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. +#' @param maxMemoryInMB Maximum memory in MiB allocated to histogram aggregation. #' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the @@ -382,7 +382,7 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). #' Note: this setting will be ignored if the checkpoint directory is not #' set. -#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. +#' @param maxMemoryInMB Maximum memory in MiB allocated to histogram aggregation. #' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the @@ -588,7 +588,7 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path #' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). #' Note: this setting will be ignored if the checkpoint directory is not #' set. -#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. +#' @param maxMemoryInMB Maximum memory in MiB allocated to histogram aggregation. #' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the diff --git a/core/src/main/resources/org/apache/spark/ui/static/utils.js b/core/src/main/resources/org/apache/spark/ui/static/utils.js index deeafad4eb5f5..22985e31a7808 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/utils.js +++ b/core/src/main/resources/org/apache/spark/ui/static/utils.js @@ -40,9 +40,9 @@ function formatDuration(milliseconds) { function formatBytes(bytes, type) { if (type !== 'display') return bytes; if (bytes == 0) return '0.0 B'; - var k = 1000; + var k = 1024; var dm = 1; - var sizes = ['B', 'KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB']; + var sizes = ['B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB', 'ZiB', 'YiB']; var i = Math.floor(Math.log(bytes) / Math.log(k)); return parseFloat((bytes / Math.pow(k, i)).toFixed(dm)) + ' ' + sizes[i]; } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 845a3d5f6d6f9..696dafda6d1ec 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1043,7 +1043,7 @@ class SparkContext(config: SparkConf) extends Logging { // See SPARK-11227 for details. FileSystem.getLocal(hadoopConfiguration) - // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it. + // A Hadoop configuration can be about 10 KiB, which is pretty big, so broadcast it. val confBroadcast = broadcast(new SerializableConfiguration(hadoopConfiguration)) val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path) new HadoopRDD( @@ -2723,7 +2723,7 @@ object SparkContext extends Logging { val memoryPerSlaveInt = memoryPerSlave.toInt if (sc.executorMemory > memoryPerSlaveInt) { throw new SparkException( - "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format( + "Asked to launch cluster with %d MiB RAM / worker but requested %d MiB/worker".format( memoryPerSlaveInt, sc.executorMemory)) } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 1e1c27c477877..72ca0fbe667e3 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -62,14 +62,14 @@ class KryoSerializer(conf: SparkConf) if (bufferSizeKb >= ByteUnit.GiB.toKiB(2)) { throw new IllegalArgumentException("spark.kryoserializer.buffer must be less than " + - s"2048 mb, got: + ${ByteUnit.KiB.toMiB(bufferSizeKb)} mb.") + s"2048 MiB, got: + ${ByteUnit.KiB.toMiB(bufferSizeKb)} MiB.") } private val bufferSize = ByteUnit.KiB.toBytes(bufferSizeKb).toInt val maxBufferSizeMb = conf.getSizeAsMb("spark.kryoserializer.buffer.max", "64m").toInt if (maxBufferSizeMb >= ByteUnit.GiB.toMiB(2)) { throw new IllegalArgumentException("spark.kryoserializer.buffer.max must be less than " + - s"2048 mb, got: + $maxBufferSizeMb mb.") + s"2048 MiB, got: + $maxBufferSizeMb MiB.") } private val maxBufferSize = ByteUnit.MiB.toBytes(maxBufferSizeMb).toInt diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 227c9e734f0af..b4ea1ee950217 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1092,41 +1092,41 @@ private[spark] object Utils extends Logging { * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of mebibytes. */ def memoryStringToMb(str: String): Int = { - // Convert to bytes, rather than directly to MB, because when no units are specified the unit + // Convert to bytes, rather than directly to MiB, because when no units are specified the unit // is assumed to be bytes (JavaUtils.byteStringAsBytes(str) / 1024 / 1024).toInt } /** - * Convert a quantity in bytes to a human-readable string such as "4.0 MB". + * Convert a quantity in bytes to a human-readable string such as "4.0 MiB". */ def bytesToString(size: Long): String = bytesToString(BigInt(size)) def bytesToString(size: BigInt): String = { - val EB = 1L << 60 - val PB = 1L << 50 - val TB = 1L << 40 - val GB = 1L << 30 - val MB = 1L << 20 - val KB = 1L << 10 - - if (size >= BigInt(1L << 11) * EB) { + val EiB = 1L << 60 + val PiB = 1L << 50 + val TiB = 1L << 40 + val GiB = 1L << 30 + val MiB = 1L << 20 + val KiB = 1L << 10 + + if (size >= BigInt(1L << 11) * EiB) { // The number is too large, show it in scientific notation. BigDecimal(size, new MathContext(3, RoundingMode.HALF_UP)).toString() + " B" } else { val (value, unit) = { - if (size >= 2 * EB) { - (BigDecimal(size) / EB, "EB") - } else if (size >= 2 * PB) { - (BigDecimal(size) / PB, "PB") - } else if (size >= 2 * TB) { - (BigDecimal(size) / TB, "TB") - } else if (size >= 2 * GB) { - (BigDecimal(size) / GB, "GB") - } else if (size >= 2 * MB) { - (BigDecimal(size) / MB, "MB") - } else if (size >= 2 * KB) { - (BigDecimal(size) / KB, "KB") + if (size >= 2 * EiB) { + (BigDecimal(size) / EiB, "EiB") + } else if (size >= 2 * PiB) { + (BigDecimal(size) / PiB, "PiB") + } else if (size >= 2 * TiB) { + (BigDecimal(size) / TiB, "TiB") + } else if (size >= 2 * GiB) { + (BigDecimal(size) / GiB, "GiB") + } else if (size >= 2 * MiB) { + (BigDecimal(size) / MiB, "MiB") + } else if (size >= 2 * KiB) { + (BigDecimal(size) / KiB, "KiB") } else { (BigDecimal(size), "B") } @@ -1157,7 +1157,7 @@ private[spark] object Utils extends Logging { } /** - * Convert a quantity in megabytes to a human-readable string such as "4.0 MB". + * Convert a quantity in megabytes to a human-readable string such as "4.0 MiB". */ def megabytesToString(megabytes: Long): String = { bytesToString(megabytes * 1024L * 1024L) diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 21f481d477242..3e1a3d4f73069 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -244,7 +244,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val newConf = new SparkConf newConf.set("spark.rpc.message.maxSize", "1") newConf.set("spark.rpc.askTimeout", "1") // Fail fast - newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "10240") // 10 KB << 1MB framesize + newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "10240") // 10 KiB << 1MiB framesize // needs TorrentBroadcast so need a SparkContext withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) { sc => diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 467e49026a029..8af53274d9b2f 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -75,7 +75,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val thrown3 = intercept[IllegalArgumentException](newKryoInstance(conf, "2g", "3g")) assert(thrown3.getMessage.contains(kryoBufferProperty)) assert(!thrown3.getMessage.contains(kryoBufferMaxProperty)) - // test configuration with mb is supported properly + // test configuration with MiB is supported properly newKryoInstance(conf, "8m", "9m") } diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala index 959cf58fa0536..6f60b08088cd1 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala @@ -128,7 +128,7 @@ class DiskStoreSuite extends SparkFunSuite { assert(e.getMessage === s"requirement failed: can't create a byte buffer of size ${blockData.size}" + - " since it exceeds 10.0 KB.") + " since it exceeds 10.0 KiB.") } test("block data encryption") { diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 901a724da8a1b..b2ff1cce3eb0b 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -133,7 +133,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(Utils.byteStringAsBytes("1p") === ByteUnit.PiB.toBytes(1)) // Overflow handling, 1073741824p exceeds Long.MAX_VALUE if converted straight to Bytes - // This demonstrates that we can have e.g 1024^3 PB without overflowing. + // This demonstrates that we can have e.g 1024^3 PiB without overflowing. assert(Utils.byteStringAsGb("1073741824p") === ByteUnit.PiB.toGiB(1073741824)) assert(Utils.byteStringAsMb("1073741824p") === ByteUnit.PiB.toMiB(1073741824)) @@ -149,7 +149,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { // Test overflow exception intercept[IllegalArgumentException] { - // This value exceeds Long.MAX when converted to TB + // This value exceeds Long.MAX when converted to TiB ByteUnit.PiB.toTiB(9223372036854775807L) } @@ -189,13 +189,13 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { test("bytesToString") { assert(Utils.bytesToString(10) === "10.0 B") assert(Utils.bytesToString(1500) === "1500.0 B") - assert(Utils.bytesToString(2000000) === "1953.1 KB") - assert(Utils.bytesToString(2097152) === "2.0 MB") - assert(Utils.bytesToString(2306867) === "2.2 MB") - assert(Utils.bytesToString(5368709120L) === "5.0 GB") - assert(Utils.bytesToString(5L * (1L << 40)) === "5.0 TB") - assert(Utils.bytesToString(5L * (1L << 50)) === "5.0 PB") - assert(Utils.bytesToString(5L * (1L << 60)) === "5.0 EB") + assert(Utils.bytesToString(2000000) === "1953.1 KiB") + assert(Utils.bytesToString(2097152) === "2.0 MiB") + assert(Utils.bytesToString(2306867) === "2.2 MiB") + assert(Utils.bytesToString(5368709120L) === "5.0 GiB") + assert(Utils.bytesToString(5L * (1L << 40)) === "5.0 TiB") + assert(Utils.bytesToString(5L * (1L << 50)) === "5.0 PiB") + assert(Utils.bytesToString(5L * (1L << 60)) === "5.0 EiB") assert(Utils.bytesToString(BigInt(1L << 11) * (1L << 60)) === "2.36E+21 B") } diff --git a/docs/configuration.md b/docs/configuration.md index 9abbb3f634900..ff9b802617f08 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1384,14 +1384,14 @@ Apart from these, the following properties are also available, and may be useful

      - + - + diff --git a/docs/hardware-provisioning.md b/docs/hardware-provisioning.md index 896f9302ef300..29876a51b2804 100644 --- a/docs/hardware-provisioning.md +++ b/docs/hardware-provisioning.md @@ -37,7 +37,7 @@ use the same disks as HDFS. # Memory -In general, Spark can run well with anywhere from **8 GB to hundreds of gigabytes** of memory per +In general, Spark can run well with anywhere from **8 GiB to hundreds of gigabytes** of memory per machine. In all cases, we recommend allocating only at most 75% of the memory for Spark; leave the rest for the operating system and buffer cache. @@ -47,7 +47,7 @@ Storage tab of Spark's monitoring UI (`http://:4040`) to see its si Note that memory usage is greatly affected by storage level and serialization format -- see the [tuning guide](tuning.html) for tips on how to reduce it. -Finally, note that the Java VM does not always behave well with more than 200 GB of RAM. If you +Finally, note that the Java VM does not always behave well with more than 200 GiB of RAM. If you purchase machines with more RAM than this, you can run _multiple worker JVMs per node_. In Spark's [standalone mode](spark-standalone.html), you can set the number of workers per node with the `SPARK_WORKER_INSTANCES` variable in `conf/spark-env.sh`, and the number of cores diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index ec13b81f85557..281755f4cea8f 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -149,7 +149,7 @@ These parameters may be tuned. Be careful to validate on held-out test data whe * Note that the `maxBins` parameter must be at least the maximum number of categories `$M$` for any categorical feature. * **`maxMemoryInMB`**: Amount of memory to be used for collecting sufficient statistics. - * The default value is conservatively chosen to be 256 MB to allow the decision algorithm to work in most scenarios. Increasing `maxMemoryInMB` can lead to faster training (if the memory is available) by allowing fewer passes over the data. However, there may be decreasing returns as `maxMemoryInMB` grows since the amount of communication on each iteration can be proportional to `maxMemoryInMB`. + * The default value is conservatively chosen to be 256 MiB to allow the decision algorithm to work in most scenarios. Increasing `maxMemoryInMB` can lead to faster training (if the memory is available) by allowing fewer passes over the data. However, there may be decreasing returns as `maxMemoryInMB` grows since the amount of communication on each iteration can be proportional to `maxMemoryInMB`. * *Implementation details*: For faster processing, the decision tree algorithm collects statistics about groups of nodes to split (rather than 1 node at a time). The number of nodes which can be handled in one group is determined by the memory requirements (which vary per features). The `maxMemoryInMB` parameter specifies the memory limit in terms of megabytes which each worker can use for these statistics. * **`subsamplingRate`**: Fraction of the training data used for learning the decision tree. This parameter is most relevant for training ensembles of trees (using [`RandomForest`](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest$) and [`GradientBoostedTrees`](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees)), where it can be useful to subsample the original data. For training a single decision tree, this parameter is less useful since the number of training instances is generally not the main constraint. diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index b3ba4b255b71a..968d668e2c93a 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -437,7 +437,7 @@ See the [configuration page](configuration.html) for information on Spark config diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 49ef2e1ce2a1b..672a4d0f3199a 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -60,7 +60,7 @@ Finally, the following configuration options can be passed to the master and wor - + @@ -128,7 +128,7 @@ You can optionally configure the cluster further by setting environment variable - + diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index 6a52e8a7b0ebd..4a1812bbb40a2 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -248,5 +248,5 @@ de-aggregate records during consumption. - `InitialPositionInStream.TRIM_HORIZON` may lead to duplicate processing of records where the impact is dependent on checkpoint frequency and processing idempotency. #### Kinesis retry configuration - - `spark.streaming.kinesis.retry.waitTime` : Wait time between Kinesis retries as a duration string. When reading from Amazon Kinesis, users may hit `ProvisionedThroughputExceededException`'s, when consuming faster than 5 transactions/second or, exceeding the maximum read rate of 2 MB/second. This configuration can be tweaked to increase the sleep between fetches when a fetch fails to reduce these exceptions. Default is "100ms". + - `spark.streaming.kinesis.retry.waitTime` : Wait time between Kinesis retries as a duration string. When reading from Amazon Kinesis, users may hit `ProvisionedThroughputExceededException`'s, when consuming faster than 5 transactions/second or, exceeding the maximum read rate of 2 MiB/second. This configuration can be tweaked to increase the sleep between fetches when a fetch fails to reduce these exceptions. Default is "100ms". - `spark.streaming.kinesis.retry.maxAttempts` : Max number of retries for Kinesis fetches. This config can also be used to tackle the Kinesis `ProvisionedThroughputExceededException`'s in scenarios mentioned above. It can be increased to have more number of retries for Kinesis reads. Default is 3. diff --git a/docs/tuning.md b/docs/tuning.md index cd0f9cd081369..43acacb98cbf9 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -115,7 +115,7 @@ variety of workloads without requiring user expertise of how memory is divided i Although there are two relevant configurations, the typical user should not need to adjust them as the default values are applicable to most workloads: -* `spark.memory.fraction` expresses the size of `M` as a fraction of the (JVM heap space - 300MB) +* `spark.memory.fraction` expresses the size of `M` as a fraction of the (JVM heap space - 300MiB) (default 0.6). The rest of the space (40%) is reserved for user data structures, internal metadata in Spark, and safeguarding against OOM errors in the case of sparse and unusually large records. @@ -147,7 +147,7 @@ pointer-based data structures and wrapper objects. There are several ways to do Java standard library. 2. Avoid nested structures with a lot of small objects and pointers when possible. 3. Consider using numeric IDs or enumeration objects instead of strings for keys. -4. If you have less than 32 GB of RAM, set the JVM flag `-XX:+UseCompressedOops` to make pointers be +4. If you have less than 32 GiB of RAM, set the JVM flag `-XX:+UseCompressedOops` to make pointers be four bytes instead of eight. You can add these options in [`spark-env.sh`](configuration.html#environment-variables). @@ -224,8 +224,8 @@ temporary objects created during task execution. Some steps which may be useful * As an example, if your task is reading data from HDFS, the amount of memory used by the task can be estimated using the size of the data block read from HDFS. Note that the size of a decompressed block is often 2 or 3 times the - size of the block. So if we wish to have 3 or 4 tasks' worth of working space, and the HDFS block size is 128 MB, - we can estimate size of Eden to be `4*3*128MB`. + size of the block. So if we wish to have 3 or 4 tasks' worth of working space, and the HDFS block size is 128 MiB, + we can estimate size of Eden to be `4*3*128MiB`. * Monitor how the frequency and time taken by garbage collection changes with the new settings. @@ -267,7 +267,7 @@ available in `SparkContext` can greatly reduce the size of each serialized task, of launching a job over a cluster. If your tasks use any large object from the driver program inside of them (e.g. a static lookup table), consider turning it into a broadcast variable. Spark prints the serialized size of each task on the master, so you can look at that to -decide whether your tasks are too large; in general tasks larger than about 20 KB are probably +decide whether your tasks are too large; in general tasks larger than about 20 KiB are probably worth optimizing. ## Data Locality diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index e58860fea97d0..e32d615af2a47 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -322,7 +322,7 @@ class BlockMatrix @Since("1.3.0") ( val m = numRows().toInt val n = numCols().toInt val mem = m * n / 125000 - if (mem > 500) logWarning(s"Storing this matrix will require $mem MB of memory!") + if (mem > 500) logWarning(s"Storing this matrix will require $mem MiB of memory!") val localBlocks = blocks.collect() val values = new Array[Double](m * n) localBlocks.foreach { case ((blockRowIndex, blockColIndex), submat) => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index 37eb794b0c5c9..6250b0363ee3b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -190,7 +190,7 @@ class GradientDescentClusterSuite extends SparkFunSuite with LocalClusterSparkCo iter.map(i => (1.0, Vectors.dense(Array.fill(n)(random.nextDouble())))) }.cache() // If we serialize data directly in the task closure, the size of the serialized task would be - // greater than 1MB and hence Spark would throw an error. + // greater than 1MiB and hence Spark would throw an error. val (weights, loss) = GradientDescent.runMiniBatchSGD( points, new LogisticGradient, diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 8bd6897df925f..b6e17cab44e9c 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -127,7 +127,7 @@ def __new__(cls, mean, confidence, low, high): def _parse_memory(s): """ Parse a memory string in the format supported by Java (e.g. 1g, 200m) and - return the value in MB + return the value in MiB >>> _parse_memory("256m") 256 diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index bd0ac0039ffe1..5d2d63850e9b2 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -37,7 +37,7 @@ process = None def get_used_memory(): - """ Return the used memory in MB """ + """ Return the used memory in MiB """ global process if process is None or process._pid != os.getpid(): process = psutil.Process(os.getpid()) @@ -50,7 +50,7 @@ def get_used_memory(): except ImportError: def get_used_memory(): - """ Return the used memory in MB """ + """ Return the used memory in MiB """ if platform.system() == 'Linux': for line in open('/proc/self/status'): if line.startswith('VmRSS:'): diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 9497530805c1a..d37d0d66d8ae2 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -127,7 +127,7 @@ private[yarn] class YarnAllocator( private var numUnexpectedContainerRelease = 0L private val containerIdToExecutorId = new HashMap[ContainerId, String] - // Executor memory in MB. + // Executor memory in MiB. protected val executorMemory = sparkConf.get(EXECUTOR_MEMORY).toInt // Additional memory overhead. protected val memoryOverhead: Int = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index 8818d0135b297..b7ce367230810 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -160,7 +160,7 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow) } - test("Coalesce should not throw 64kb exception") { + test("Coalesce should not throw 64KiB exception") { val inputs = (1 to 2500).map(x => Literal(s"x_$x")) checkEvaluation(Coalesce(inputs), "x_1") } @@ -171,7 +171,7 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { assert(ctx.inlinedMutableStates.size == 1) } - test("AtLeastNNonNulls should not throw 64kb exception") { + test("AtLeastNNonNulls should not throw 64KiB exception") { val inputs = (1 to 4000).map(x => Literal(s"x_$x")) checkEvaluation(AtLeastNNonNulls(1, inputs), true) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala index d0604b8eb7675..94e251d90bcfa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala @@ -128,7 +128,7 @@ class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper { } } - test("SPARK-16845: GeneratedClass$SpecificOrdering grows beyond 64 KB") { + test("SPARK-16845: GeneratedClass$SpecificOrdering grows beyond 64 KiB") { val sortOrder = Literal("abc").asc // this is passing prior to SPARK-16845, and it should also be passing after SPARK-16845 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala index f5d93ee5fa914..e4ec76f0b9a1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala @@ -73,14 +73,14 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan} * greater than the target size. * * For example, we have two stages with the following pre-shuffle partition size statistics: - * stage 1: [100 MB, 20 MB, 100 MB, 10MB, 30 MB] - * stage 2: [10 MB, 10 MB, 70 MB, 5 MB, 5 MB] - * assuming the target input size is 128 MB, we will have four post-shuffle partitions, + * stage 1: [100 MiB, 20 MiB, 100 MiB, 10MiB, 30 MiB] + * stage 2: [10 MiB, 10 MiB, 70 MiB, 5 MiB, 5 MiB] + * assuming the target input size is 128 MiB, we will have four post-shuffle partitions, * which are: - * - post-shuffle partition 0: pre-shuffle partition 0 (size 110 MB) - * - post-shuffle partition 1: pre-shuffle partition 1 (size 30 MB) - * - post-shuffle partition 2: pre-shuffle partition 2 (size 170 MB) - * - post-shuffle partition 3: pre-shuffle partition 3 and 4 (size 50 MB) + * - post-shuffle partition 0: pre-shuffle partition 0 (size 110 MiB) + * - post-shuffle partition 1: pre-shuffle partition 1 (size 30 MiB) + * - post-shuffle partition 2: pre-shuffle partition 2 (size 170 MiB) + * - post-shuffle partition 3: pre-shuffle partition 3 and 4 (size 50 MiB) */ class ExchangeCoordinator( advisoryTargetPostShuffleInputSize: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index 27bed1137e5b3..82973307feef3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -44,7 +44,7 @@ case class WindowInPandasExec( override def requiredChildDistribution: Seq[Distribution] = { if (partitionSpec.isEmpty) { - // Only show warning when the number of bytes is larger than 100 MB? + // Only show warning when the number of bytes is larger than 100 MiB? logWarning("No Partition Defined for Window operation! Moving all data to a single " + "partition, this can cause serious performance degradation.") AllTuples :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index fede0f3e92d67..729b8bdb3dae8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -90,7 +90,7 @@ case class WindowExec( override def requiredChildDistribution: Seq[Distribution] = { if (partitionSpec.isEmpty) { - // Only show warning when the number of bytes is larger than 100 MB? + // Only show warning when the number of bytes is larger than 100 MiB? logWarning("No Partition Defined for Window operation! Moving all data to a single " + "partition, this can cause serious performance degradation.") AllTuples :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index fc3faa08d55f4..b51c51e663503 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1904,7 +1904,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val e = intercept[SparkException] { df.filter(filter).count() }.getMessage - assert(e.contains("grows beyond 64 KB")) + assert(e.contains("grows beyond 64 KiB")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index cb562d65b6147..02dc32d5f90ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -227,12 +227,12 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared BigInt(0) -> (("0.0 B", "0")), BigInt(100) -> (("100.0 B", "100")), BigInt(2047) -> (("2047.0 B", "2.05E+3")), - BigInt(2048) -> (("2.0 KB", "2.05E+3")), - BigInt(3333333) -> (("3.2 MB", "3.33E+6")), - BigInt(4444444444L) -> (("4.1 GB", "4.44E+9")), - BigInt(5555555555555L) -> (("5.1 TB", "5.56E+12")), - BigInt(6666666666666666L) -> (("5.9 PB", "6.67E+15")), - BigInt(1L << 10 ) * (1L << 60) -> (("1024.0 EB", "1.18E+21")), + BigInt(2048) -> (("2.0 KiB", "2.05E+3")), + BigInt(3333333) -> (("3.2 MiB", "3.33E+6")), + BigInt(4444444444L) -> (("4.1 GiB", "4.44E+9")), + BigInt(5555555555555L) -> (("5.1 TiB", "5.56E+12")), + BigInt(6666666666666666L) -> (("5.9 PiB", "6.67E+15")), + BigInt(1L << 10 ) * (1L << 60) -> (("1024.0 EiB", "1.18E+21")), BigInt(1L << 11) * (1L << 60) -> (("2.36E+21 B", "2.36E+21")) ) numbers.foreach { case (input, (expectedSize, expectedRows)) => From 05cf81e6de3d61ddb0af81cd179665693f23351f Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Mon, 10 Dec 2018 18:28:13 -0600 Subject: [PATCH 2254/2461] [SPARK-19827][R] spark.ml R API for PIC ## What changes were proposed in this pull request? Add PowerIterationCluster (PIC) in R ## How was this patch tested? Add test case Closes #23072 from huaxingao/spark-19827. Authored-by: Huaxin Gao Signed-off-by: Sean Owen --- R/pkg/NAMESPACE | 3 +- R/pkg/R/generics.R | 4 ++ R/pkg/R/mllib_clustering.R | 62 +++++++++++++++++++ R/pkg/tests/fulltests/test_mllib_clustering.R | 13 ++++ R/pkg/vignettes/sparkr-vignettes.Rmd | 14 +++++ docs/ml-clustering.md | 41 ++++++++++++ docs/sparkr.md | 1 + .../src/main/r/ml/powerIterationClustering.R | 38 ++++++++++++ .../clustering/PowerIterationClustering.scala | 4 +- .../r/PowerIterationClusteringWrapper.scala | 39 ++++++++++++ python/pyspark/ml/clustering.py | 4 +- 11 files changed, 218 insertions(+), 5 deletions(-) create mode 100644 examples/src/main/r/ml/powerIterationClustering.R create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/PowerIterationClusteringWrapper.scala diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 1f8ba0bcf1cf5..cfad20db16c75 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -67,7 +67,8 @@ exportMethods("glm", "spark.fpGrowth", "spark.freqItemsets", "spark.associationRules", - "spark.findFrequentSequentialPatterns") + "spark.findFrequentSequentialPatterns", + "spark.assignClusters") # Job group lifecycle management methods export("setJobGroup", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index eed76465221c6..09d817127edd6 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1479,6 +1479,10 @@ setGeneric("spark.associationRules", function(object) { standardGeneric("spark.a setGeneric("spark.findFrequentSequentialPatterns", function(data, ...) { standardGeneric("spark.findFrequentSequentialPatterns") }) +#' @rdname spark.powerIterationClustering +setGeneric("spark.assignClusters", + function(data, ...) { standardGeneric("spark.assignClusters") }) + #' @param object a fitted ML model object. #' @param path the directory where the model is saved. #' @param ... additional argument(s) passed to the method. diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R index 900be685824da..7d9dcebfe70d3 100644 --- a/R/pkg/R/mllib_clustering.R +++ b/R/pkg/R/mllib_clustering.R @@ -41,6 +41,12 @@ setClass("KMeansModel", representation(jobj = "jobj")) #' @note LDAModel since 2.1.0 setClass("LDAModel", representation(jobj = "jobj")) +#' S4 class that represents a PowerIterationClustering +#' +#' @param jobj a Java object reference to the backing Scala PowerIterationClustering +#' @note PowerIterationClustering since 3.0.0 +setClass("PowerIterationClustering", slots = list(jobj = "jobj")) + #' Bisecting K-Means Clustering Model #' #' Fits a bisecting k-means clustering model against a SparkDataFrame. @@ -610,3 +616,59 @@ setMethod("write.ml", signature(object = "LDAModel", path = "character"), function(object, path, overwrite = FALSE) { write_internal(object, path, overwrite) }) + +#' PowerIterationClustering +#' +#' A scalable graph clustering algorithm. Users can call \code{spark.assignClusters} to +#' return a cluster assignment for each input vertex. +#' +# Run the PIC algorithm and returns a cluster assignment for each input vertex. +#' @param data a SparkDataFrame. +#' @param k the number of clusters to create. +#' @param initMode the initialization algorithm. +#' @param maxIter the maximum number of iterations. +#' @param sourceCol the name of the input column for source vertex IDs. +#' @param destinationCol the name of the input column for destination vertex IDs +#' @param weightCol weight column name. If this is not set or \code{NULL}, +#' we treat all instance weights as 1.0. +#' @param ... additional argument(s) passed to the method. +#' @return A dataset that contains columns of vertex id and the corresponding cluster for the id. +#' The schema of it will be: +#' \code{id: Long} +#' \code{cluster: Int} +#' @rdname spark.powerIterationClustering +#' @aliases assignClusters,PowerIterationClustering-method,SparkDataFrame-method +#' @examples +#' \dontrun{ +#' df <- createDataFrame(list(list(0L, 1L, 1.0), list(0L, 2L, 1.0), +#' list(1L, 2L, 1.0), list(3L, 4L, 1.0), +#' list(4L, 0L, 0.1)), +#' schema = c("src", "dst", "weight")) +#' clusters <- spark.assignClusters(df, initMode="degree", weightCol="weight") +#' showDF(clusters) +#' } +#' @note spark.assignClusters(SparkDataFrame) since 3.0.0 +setMethod("spark.assignClusters", + signature(data = "SparkDataFrame"), + function(data, k = 2L, initMode = c("random", "degree"), maxIter = 20L, + sourceCol = "src", destinationCol = "dst", weightCol = NULL) { + if (!is.numeric(k) || k < 1) { + stop("k should be a number with value >= 1.") + } + if (!is.integer(maxIter) || maxIter <= 0) { + stop("maxIter should be a number with value > 0.") + } + initMode <- match.arg(initMode) + if (!is.null(weightCol) && weightCol == "") { + weightCol <- NULL + } else if (!is.null(weightCol)) { + weightCol <- as.character(weightCol) + } + jobj <- callJStatic("org.apache.spark.ml.r.PowerIterationClusteringWrapper", + "getPowerIterationClustering", + as.integer(k), initMode, + as.integer(maxIter), as.character(sourceCol), + as.character(destinationCol), weightCol) + object <- new("PowerIterationClustering", jobj = jobj) + dataFrame(callJMethod(object@jobj, "assignClusters", data@sdf)) + }) diff --git a/R/pkg/tests/fulltests/test_mllib_clustering.R b/R/pkg/tests/fulltests/test_mllib_clustering.R index 4110e13da4948..b78a476f1d058 100644 --- a/R/pkg/tests/fulltests/test_mllib_clustering.R +++ b/R/pkg/tests/fulltests/test_mllib_clustering.R @@ -319,4 +319,17 @@ test_that("spark.posterior and spark.perplexity", { expect_equal(length(local.posterior), sum(unlist(local.posterior))) }) +test_that("spark.assignClusters", { + df <- createDataFrame(list(list(0L, 1L, 1.0), list(0L, 2L, 1.0), + list(1L, 2L, 1.0), list(3L, 4L, 1.0), + list(4L, 0L, 0.1)), + schema = c("src", "dst", "weight")) + clusters <- spark.assignClusters(df, initMode = "degree", weightCol = "weight") + expected_result <- createDataFrame(list(list(4L, 1L), list(0L, 0L), + list(1L, 0L), list(3L, 1L), + list(2L, 0L)), + schema = c("id", "cluster")) + expect_equivalent(expected_result, clusters) +}) + sparkR.session.stop() diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 1c6a03c4b9bc3..cbe8c61725c88 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -549,6 +549,8 @@ SparkR supports the following machine learning models and algorithms. * Latent Dirichlet Allocation (LDA) +* Power Iteration Clustering (PIC) + #### Collaborative Filtering * Alternating Least Squares (ALS) @@ -982,6 +984,18 @@ predicted <- predict(model, df) head(predicted) ``` +#### Power Iteration Clustering + +Power Iteration Clustering (PIC) is a scalable graph clustering algorithm. `spark.assignClusters` method runs the PIC algorithm and returns a cluster assignment for each input vertex. + +```{r} +df <- createDataFrame(list(list(0L, 1L, 1.0), list(0L, 2L, 1.0), + list(1L, 2L, 1.0), list(3L, 4L, 1.0), + list(4L, 0L, 0.1)), + schema = c("src", "dst", "weight")) +head(spark.assignClusters(df, initMode = "degree", weightCol = "weight")) +``` + #### FP-growth `spark.fpGrowth` executes FP-growth algorithm to mine frequent itemsets on a `SparkDataFrame`. `itemsCol` should be an array of values. diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md index 1186fb73d0faf..65f265256200b 100644 --- a/docs/ml-clustering.md +++ b/docs/ml-clustering.md @@ -265,3 +265,44 @@ Refer to the [R API docs](api/R/spark.gaussianMixture.html) for more details. + +## Power Iteration Clustering (PIC) + +Power Iteration Clustering (PIC) is a scalable graph clustering algorithm +developed by [Lin and Cohen](http://www.cs.cmu.edu/~frank/papers/icml2010-pic-final.pdf). +From the abstract: PIC finds a very low-dimensional embedding of a dataset +using truncated power iteration on a normalized pair-wise similarity matrix of the data. + +`spark.ml`'s PowerIterationClustering implementation takes the following parameters: + +* `k`: the number of clusters to create +* `initMode`: param for the initialization algorithm +* `maxIter`: param for maximum number of iterations +* `srcCol`: param for the name of the input column for source vertex IDs +* `dstCol`: name of the input column for destination vertex IDs +* `weightCol`: Param for weight column name + +**Examples** + +
      + +
      +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.PowerIterationClustering) for more details. + +{% include_example scala/org/apache/spark/examples/ml/PowerIterationClusteringExample.scala %} +
      + +
      +Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/PowerIterationClustering.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaPowerIterationClusteringExample.java %} +
      + +
      + +Refer to the [R API docs](api/R/spark.powerIterationClustering.html) for more details. + +{% include_example r/ml/powerIterationClustering.R %} +
      + +
      diff --git a/docs/sparkr.md b/docs/sparkr.md index 0057f05de0ff3..dbb61241007ff 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -544,6 +544,7 @@ SparkR supports the following machine learning algorithms currently: * [`spark.gaussianMixture`](api/R/spark.gaussianMixture.html): [`Gaussian Mixture Model (GMM)`](ml-clustering.html#gaussian-mixture-model-gmm) * [`spark.kmeans`](api/R/spark.kmeans.html): [`K-Means`](ml-clustering.html#k-means) * [`spark.lda`](api/R/spark.lda.html): [`Latent Dirichlet Allocation (LDA)`](ml-clustering.html#latent-dirichlet-allocation-lda) +* [`spark.powerIterationClustering (PIC)`](api/R/spark.powerIterationClustering.html): [`Power Iteration Clustering (PIC)`](ml-clustering.html#power-iteration-clustering-pic) #### Collaborative Filtering diff --git a/examples/src/main/r/ml/powerIterationClustering.R b/examples/src/main/r/ml/powerIterationClustering.R new file mode 100644 index 0000000000000..ba43037106d14 --- /dev/null +++ b/examples/src/main/r/ml/powerIterationClustering.R @@ -0,0 +1,38 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# To run this example use +# ./bin/spark-submit examples/src/main/r/ml/powerIterationClustering.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-ML-powerIterationCLustering-example") + +# $example on$ +df <- createDataFrame(list(list(0L, 1L, 1.0), list(0L, 2L, 1.0), + list(1L, 2L, 1.0), list(3L, 4L, 1.0), + list(4L, 0L, 0.1)), + schema = c("src", "dst", "weight")) +# assign clusters +clusters <- spark.assignClusters(df, k=2L, maxIter=20L, initMode="degree", weightCol="weight") + +showDF(arrange(clusters, clusters$id)) +# $example off$ + +sparkR.session.stop() diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala index 1b9a3499947d9..d9a330f67e8dc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala @@ -97,8 +97,8 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has /** * :: Experimental :: * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by - * Lin and Cohen. From the abstract: - * PIC finds a very low-dimensional embedding of a dataset using truncated power + * Lin and Cohen. From + * the abstract: PIC finds a very low-dimensional embedding of a dataset using truncated power * iteration on a normalized pair-wise similarity matrix of the data. * * This class is not yet an Estimator/Transformer, use `assignClusters` method to run the diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/PowerIterationClusteringWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/PowerIterationClusteringWrapper.scala new file mode 100644 index 0000000000000..b5dfad0224ed8 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/PowerIterationClusteringWrapper.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +import org.apache.spark.ml.clustering.PowerIterationClustering + +private[r] object PowerIterationClusteringWrapper { + def getPowerIterationClustering( + k: Int, + initMode: String, + maxIter: Int, + srcCol: String, + dstCol: String, + weightCol: String): PowerIterationClustering = { + val pic = new PowerIterationClustering() + .setK(k) + .setInitMode(initMode) + .setMaxIter(maxIter) + .setSrcCol(srcCol) + .setDstCol(dstCol) + if (weightCol != null) pic.setWeightCol(weightCol) + pic + } +} diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index d0b507ec5dad4..d8a6dfb7d3a71 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -1193,8 +1193,8 @@ class PowerIterationClustering(HasMaxIter, HasWeightCol, JavaParams, JavaMLReada .. note:: Experimental Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by - `Lin and Cohen `_. From the abstract: - PIC finds a very low-dimensional embedding of a dataset using truncated power + `Lin and Cohen `_. From the + abstract: PIC finds a very low-dimensional embedding of a dataset using truncated power iteration on a normalized pair-wise similarity matrix of the data. This class is not yet an Estimator/Transformer, use :py:func:`assignClusters` method From cbe92305cd4f80725a251cf74353e8985d28306c Mon Sep 17 00:00:00 2001 From: 10129659 Date: Tue, 11 Dec 2018 09:50:21 +0800 Subject: [PATCH 2255/2461] [SPARK-26312][SQL] Replace RDDConversions.rowToRowRdd with RowEncoder to improve its conversion performance ## What changes were proposed in this pull request? `RDDConversions` would get disproportionately slower as the number of columns in the query increased, for the type of `converters` before is `scala.collection.immutable.::` which is a subtype of list. This PR removing `RDDConversions` and using `RowEncoder` to convert the Row to InternalRow. The test of `PrunedScanSuite` for 2000 columns and 20k rows takes 409 seconds before this PR, and 361 seconds after. ## How was this patch tested? Test case of `PrunedScanSuite` Closes #23262 from eatoncys/toarray. Authored-by: 10129659 Signed-off-by: Hyukjin Kwon --- .../spark/sql/execution/ExistingRDD.scala | 44 +------------------ .../datasources/DataSourceStrategy.scala | 7 ++- 2 files changed, 7 insertions(+), 44 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index e214bfd050410..49fb288fdea6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -18,54 +18,14 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Encoder, Row, SparkSession} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.{Encoder, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.DataType - -object RDDConversions { - def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = { - data.mapPartitions { iterator => - val numColumns = outputTypes.length - val mutableRow = new GenericInternalRow(numColumns) - val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter) - iterator.map { r => - var i = 0 - while (i < numColumns) { - mutableRow(i) = converters(i)(r.productElement(i)) - i += 1 - } - - mutableRow - } - } - } - - /** - * Convert the objects inside Row into the types Catalyst expected. - */ - def rowToRowRdd(data: RDD[Row], outputTypes: Seq[DataType]): RDD[InternalRow] = { - data.mapPartitions { iterator => - val numColumns = outputTypes.length - val mutableRow = new GenericInternalRow(numColumns) - val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter) - iterator.map { r => - var i = 0 - while (i < numColumns) { - mutableRow(i) = converters(i)(r(i)) - i += 1 - } - - mutableRow - } - } - } -} object ExternalRDD { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index c6000442fae76..b304e2da6e1cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -29,11 +29,11 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, Quali import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoTable, LogicalPlan, Project} -import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ @@ -416,7 +416,10 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with output: Seq[Attribute], rdd: RDD[Row]): RDD[InternalRow] = { if (relation.relation.needConversion) { - execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType)) + val converters = RowEncoder(StructType.fromAttributes(output)) + rdd.mapPartitions { iterator => + iterator.map(converters.toRow) + } } else { rdd.asInstanceOf[RDD[InternalRow]] } From 7d5f6e8c493b96898ba01edede1522121fe945fc Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 11 Dec 2018 14:16:51 +0800 Subject: [PATCH 2256/2461] [SPARK-26293][SQL] Cast exception when having python udf in subquery ## What changes were proposed in this pull request? This is a regression introduced by https://github.com/apache/spark/pull/22104 at Spark 2.4.0. When we have Python UDF in subquery, we will hit an exception ``` Caused by: java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.AttributeReference cannot be cast to org.apache.spark.sql.catalyst.expressions.PythonUDF at scala.collection.immutable.Stream.map(Stream.scala:414) at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$2(EvalPythonExec.scala:98) at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:815) ... ``` https://github.com/apache/spark/pull/22104 turned `ExtractPythonUDFs` from a physical rule to optimizer rule. However, there is a difference between a physical rule and optimizer rule. A physical rule always runs once, an optimizer rule may be applied twice on a query tree even the rule is located in a batch that only runs once. For a subquery, the `OptimizeSubqueries` rule will execute the entire optimizer on the query plan inside subquery. Later on subquery will be turned to joins, and the optimizer rules will be applied to it again. Unfortunately, the `ExtractPythonUDFs` rule is not idempotent. When it's applied twice on a query plan inside subquery, it will produce a malformed plan. It extracts Python UDF from Python exec plans. This PR proposes 2 changes to be double safe: 1. `ExtractPythonUDFs` should skip python exec plans, to make the rule idempotent 2. `ExtractPythonUDFs` should skip subquery ## How was this patch tested? a new test. Closes #23248 from cloud-fan/python. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- python/pyspark/sql/tests/test_udf.py | 52 +++++++------------ .../python/ArrowEvalPythonExec.scala | 8 ++- .../python/BatchEvalPythonExec.scala | 8 ++- .../execution/python/ExtractPythonUDFs.scala | 18 +++++-- 4 files changed, 46 insertions(+), 40 deletions(-) diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index ed298f724d551..12cf8c7de1dad 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -23,7 +23,7 @@ from pyspark import SparkContext from pyspark.sql import SparkSession, Column, Row -from pyspark.sql.functions import UserDefinedFunction +from pyspark.sql.functions import UserDefinedFunction, udf from pyspark.sql.types import * from pyspark.sql.utils import AnalysisException from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message @@ -102,7 +102,6 @@ def test_udf_registration_return_type_not_none(self): def test_nondeterministic_udf(self): # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations - from pyspark.sql.functions import udf import random udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() self.assertEqual(udf_random_col.deterministic, False) @@ -113,7 +112,6 @@ def test_nondeterministic_udf(self): def test_nondeterministic_udf2(self): import random - from pyspark.sql.functions import udf random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic() self.assertEqual(random_udf.deterministic, False) random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf) @@ -132,7 +130,6 @@ def test_nondeterministic_udf2(self): def test_nondeterministic_udf3(self): # regression test for SPARK-23233 - from pyspark.sql.functions import udf f = udf(lambda x: x) # Here we cache the JVM UDF instance. self.spark.range(1).select(f("id")) @@ -144,7 +141,7 @@ def test_nondeterministic_udf3(self): self.assertFalse(deterministic) def test_nondeterministic_udf_in_aggregate(self): - from pyspark.sql.functions import udf, sum + from pyspark.sql.functions import sum import random udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic() df = self.spark.range(10) @@ -181,7 +178,6 @@ def test_multiple_udfs(self): self.assertEqual(tuple(row), (6, 5)) def test_udf_in_filter_on_top_of_outer_join(self): - from pyspark.sql.functions import udf left = self.spark.createDataFrame([Row(a=1)]) right = self.spark.createDataFrame([Row(a=1)]) df = left.join(right, on='a', how='left_outer') @@ -190,7 +186,6 @@ def test_udf_in_filter_on_top_of_outer_join(self): def test_udf_in_filter_on_top_of_join(self): # regression test for SPARK-18589 - from pyspark.sql.functions import udf left = self.spark.createDataFrame([Row(a=1)]) right = self.spark.createDataFrame([Row(b=1)]) f = udf(lambda a, b: a == b, BooleanType()) @@ -199,7 +194,6 @@ def test_udf_in_filter_on_top_of_join(self): def test_udf_in_join_condition(self): # regression test for SPARK-25314 - from pyspark.sql.functions import udf left = self.spark.createDataFrame([Row(a=1)]) right = self.spark.createDataFrame([Row(b=1)]) f = udf(lambda a, b: a == b, BooleanType()) @@ -211,7 +205,7 @@ def test_udf_in_join_condition(self): def test_udf_in_left_outer_join_condition(self): # regression test for SPARK-26147 - from pyspark.sql.functions import udf, col + from pyspark.sql.functions import col left = self.spark.createDataFrame([Row(a=1)]) right = self.spark.createDataFrame([Row(b=1)]) f = udf(lambda a: str(a), StringType()) @@ -223,7 +217,6 @@ def test_udf_in_left_outer_join_condition(self): def test_udf_in_left_semi_join_condition(self): # regression test for SPARK-25314 - from pyspark.sql.functions import udf left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)]) f = udf(lambda a, b: a == b, BooleanType()) @@ -236,7 +229,6 @@ def test_udf_in_left_semi_join_condition(self): def test_udf_and_common_filter_in_join_condition(self): # regression test for SPARK-25314 # test the complex scenario with both udf and common filter - from pyspark.sql.functions import udf left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) f = udf(lambda a, b: a == b, BooleanType()) @@ -247,7 +239,6 @@ def test_udf_and_common_filter_in_join_condition(self): def test_udf_and_common_filter_in_left_semi_join_condition(self): # regression test for SPARK-25314 # test the complex scenario with both udf and common filter - from pyspark.sql.functions import udf left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) f = udf(lambda a, b: a == b, BooleanType()) @@ -258,7 +249,6 @@ def test_udf_and_common_filter_in_left_semi_join_condition(self): def test_udf_not_supported_in_join_condition(self): # regression test for SPARK-25314 # test python udf is not supported in join type besides left_semi and inner join. - from pyspark.sql.functions import udf left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) f = udf(lambda a, b: a == b, BooleanType()) @@ -301,7 +291,7 @@ def test_broadcast_in_udf(self): def test_udf_with_filter_function(self): df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) - from pyspark.sql.functions import udf, col + from pyspark.sql.functions import col from pyspark.sql.types import BooleanType my_filter = udf(lambda a: a < 2, BooleanType()) @@ -310,7 +300,7 @@ def test_udf_with_filter_function(self): def test_udf_with_aggregate_function(self): df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) - from pyspark.sql.functions import udf, col, sum + from pyspark.sql.functions import col, sum from pyspark.sql.types import BooleanType my_filter = udf(lambda a: a == 1, BooleanType()) @@ -326,7 +316,7 @@ def test_udf_with_aggregate_function(self): self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)]) def test_udf_in_generate(self): - from pyspark.sql.functions import udf, explode + from pyspark.sql.functions import explode df = self.spark.range(5) f = udf(lambda x: list(range(x)), ArrayType(LongType())) row = df.select(explode(f(*df))).groupBy().sum().first() @@ -353,7 +343,6 @@ def test_udf_in_generate(self): self.assertEqual(res[3][1], 1) def test_udf_with_order_by_and_limit(self): - from pyspark.sql.functions import udf my_copy = udf(lambda x: x, IntegerType()) df = self.spark.range(10).orderBy("id") res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1) @@ -394,14 +383,14 @@ def test_non_existed_udaf(self): lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf")) def test_udf_with_input_file_name(self): - from pyspark.sql.functions import udf, input_file_name + from pyspark.sql.functions import input_file_name sourceFile = udf(lambda path: path, StringType()) filePath = "python/test_support/sql/people1.json" row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first() self.assertTrue(row[0].find("people1.json") != -1) def test_udf_with_input_file_name_for_hadooprdd(self): - from pyspark.sql.functions import udf, input_file_name + from pyspark.sql.functions import input_file_name def filename(path): return path @@ -427,9 +416,6 @@ def test_udf_defers_judf_initialization(self): # This is separate of UDFInitializationTests # to avoid context initialization # when udf is called - - from pyspark.sql.functions import UserDefinedFunction - f = UserDefinedFunction(lambda x: x, StringType()) self.assertIsNone( @@ -445,8 +431,6 @@ def test_udf_defers_judf_initialization(self): ) def test_udf_with_string_return_type(self): - from pyspark.sql.functions import UserDefinedFunction - add_one = UserDefinedFunction(lambda x: x + 1, "integer") make_pair = UserDefinedFunction(lambda x: (-x, x), "struct") make_array = UserDefinedFunction( @@ -460,13 +444,11 @@ def test_udf_with_string_return_type(self): self.assertTupleEqual(expected, actual) def test_udf_shouldnt_accept_noncallable_object(self): - from pyspark.sql.functions import UserDefinedFunction - non_callable = None self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType()) def test_udf_with_decorator(self): - from pyspark.sql.functions import lit, udf + from pyspark.sql.functions import lit from pyspark.sql.types import IntegerType, DoubleType @udf(IntegerType()) @@ -523,7 +505,6 @@ def as_double(x): ) def test_udf_wrapper(self): - from pyspark.sql.functions import udf from pyspark.sql.types import IntegerType def f(x): @@ -569,7 +550,7 @@ def test_nonparam_udf_with_aggregate(self): # SPARK-24721 @unittest.skipIf(not test_compiled, test_not_compiled_message) def test_datasource_with_udf(self): - from pyspark.sql.functions import udf, lit, col + from pyspark.sql.functions import lit, col path = tempfile.mkdtemp() shutil.rmtree(path) @@ -609,8 +590,6 @@ def test_datasource_with_udf(self): # SPARK-25591 def test_same_accumulator_in_udfs(self): - from pyspark.sql.functions import udf - data_schema = StructType([StructField("a", IntegerType(), True), StructField("b", IntegerType(), True)]) data = self.spark.createDataFrame([[1, 2]], schema=data_schema) @@ -632,6 +611,15 @@ def second_udf(x): data.collect() self.assertEqual(test_accum.value, 101) + # SPARK-26293 + def test_udf_in_subquery(self): + f = udf(lambda x: x, "long") + with self.tempView("v"): + self.spark.range(1).filter(f("id") >= 0).createTempView("v") + sql = self.spark.sql + result = sql("select i from values(0L) as data(i) where i in (select id from v)") + self.assertEqual(result.collect(), [Row(i=0)]) + class UDFInitializationTests(unittest.TestCase): def tearDown(self): @@ -642,8 +630,6 @@ def tearDown(self): SparkContext._active_spark_context.stop() def test_udf_init_shouldnt_initialize_context(self): - from pyspark.sql.functions import UserDefinedFunction - UserDefinedFunction(lambda x: x, StringType()) self.assertIsNone( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 2b87796dc6833..a5203daea9cd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -60,8 +60,12 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int) /** * A logical plan that evaluates a [[PythonUDF]]. */ -case class ArrowEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan) - extends UnaryNode +case class ArrowEvalPython( + udfs: Seq[PythonUDF], + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length)) +} /** * A physical plan that evaluates a [[PythonUDF]]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index b08b7e60e130b..d3736d24e5019 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -32,8 +32,12 @@ import org.apache.spark.sql.types.{StructField, StructType} /** * A logical plan that evaluates a [[PythonUDF]] */ -case class BatchEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan) - extends UnaryNode +case class BatchEvalPython( + udfs: Seq[PythonUDF], + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length)) +} /** * A physical plan that evaluates a [[PythonUDF]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 90b5325919e96..380c31baa6213 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -24,7 +24,7 @@ import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule @@ -131,8 +131,20 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper { expressions.flatMap(collectEvaluableUDFs) } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case plan: LogicalPlan => extract(plan) + def apply(plan: LogicalPlan): LogicalPlan = plan match { + // SPARK-26293: A subquery will be rewritten into join later, and will go through this rule + // eventually. Here we skip subquery, as Python UDF only needs to be extracted once. + case _: Subquery => plan + + case _ => plan transformUp { + // A safe guard. `ExtractPythonUDFs` only runs once, so we will not hit `BatchEvalPython` and + // `ArrowEvalPython` in the input plan. However if we hit them, we must skip them, as we can't + // extract Python UDFs from them. + case p: BatchEvalPython => p + case p: ArrowEvalPython => p + + case plan: LogicalPlan => extract(plan) + } } /** From 4e1d859c19d3bfdfcb8acf915a97c68633b9ca95 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 11 Dec 2018 16:06:57 +0800 Subject: [PATCH 2257/2461] [SPARK-26303][SQL] Return partial results for bad JSON records ## What changes were proposed in this pull request? In the PR, I propose to return partial results from JSON datasource and JSON functions in the PERMISSIVE mode if some of JSON fields are parsed and converted to desired types successfully. The changes are made only for `StructType`. Whole bad JSON records are placed into the corrupt column specified by the `columnNameOfCorruptRecord` option or SQL config. Partial results are not returned for malformed JSON input. ## How was this patch tested? Added new UT which checks converting JSON strings with one invalid and one valid field at the end of the string. Closes #23253 from MaxGekk/json-bad-record. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: Hyukjin Kwon --- docs/sql-migration-guide-upgrade.md | 4 +++- python/pyspark/sql/readwriter.py | 4 ++-- python/pyspark/sql/streaming.py | 4 ++-- .../sql/catalyst/json/JacksonParser.scala | 24 +++++++++++++++---- .../catalyst/util/BadRecordException.scala | 10 ++++++++ .../apache/spark/sql/DataFrameReader.scala | 16 ++++++------- .../sql/streaming/DataStreamReader.scala | 16 ++++++------- .../datasources/json/JsonSuite.scala | 15 +++++++++++- 8 files changed, 67 insertions(+), 26 deletions(-) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index f6458a9b2730b..8834e8991d8c3 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -37,7 +37,9 @@ displayTitle: Spark SQL Upgrading Guide - Since Spark 3.0, CSV datasource uses java.time API for parsing and generating CSV content. New formatting implementation supports date/timestamp patterns conformed to ISO 8601. To switch back to the implementation used in Spark 2.4 and earlier, set `spark.sql.legacy.timeParser.enabled` to `true`. - - In Spark version 2.4 and earlier, CSV datasource converts a malformed CSV string to a row with all `null`s in the PERMISSIVE mode. Since Spark 3.0, returned row can contain non-`null` fields if some of CSV column values were parsed and converted to desired types successfully. + - In Spark version 2.4 and earlier, CSV datasource converts a malformed CSV string to a row with all `null`s in the PERMISSIVE mode. Since Spark 3.0, the returned row can contain non-`null` fields if some of CSV column values were parsed and converted to desired types successfully. + + - In Spark version 2.4 and earlier, JSON datasource and JSON functions like `from_json` convert a bad JSON record to a row with all `null`s in the PERMISSIVE mode when specified schema is `StructType`. Since Spark 3.0, the returned row can contain non-`null` fields if some of JSON column values were parsed and converted to desired types successfully. ## Upgrading From Spark SQL 2.3 to 2.4 diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 1d2dd4d808930..7b10512a43294 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -211,7 +211,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, set, it uses the default value, ``PERMISSIVE``. * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ - into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + into a field configured by ``columnNameOfCorruptRecord``, and sets malformed \ fields to ``null``. To keep corrupt records, an user can set a string type \ field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ schema does not have the field, it drops corrupt records during parsing. \ @@ -424,7 +424,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non set, it uses the default value, ``PERMISSIVE``. * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ - into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + into a field configured by ``columnNameOfCorruptRecord``, and sets malformed \ fields to ``null``. To keep corrupt records, an user can set a string type \ field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ schema does not have the field, it drops corrupt records during parsing. \ diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index d92b0d5677e25..fc23b9d99c34a 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -441,7 +441,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, set, it uses the default value, ``PERMISSIVE``. * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ - into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + into a field configured by ``columnNameOfCorruptRecord``, and sets malformed \ fields to ``null``. To keep corrupt records, an user can set a string type \ field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ schema does not have the field, it drops corrupt records during parsing. \ @@ -648,7 +648,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non set, it uses the default value, ``PERMISSIVE``. * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ - into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + into a field configured by ``columnNameOfCorruptRecord``, and sets malformed \ fields to ``null``. To keep corrupt records, an user can set a string type \ field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ schema does not have the field, it drops corrupt records during parsing. \ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 2357595906b11..7e3bd4df51bb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -22,6 +22,7 @@ import java.nio.charset.MalformedInputException import scala.collection.mutable.ArrayBuffer import scala.util.Try +import scala.util.control.NonFatal import com.fasterxml.jackson.core._ @@ -29,7 +30,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -347,17 +347,28 @@ class JacksonParser( schema: StructType, fieldConverters: Array[ValueConverter]): InternalRow = { val row = new GenericInternalRow(schema.length) + var badRecordException: Option[Throwable] = None + while (nextUntil(parser, JsonToken.END_OBJECT)) { schema.getFieldIndex(parser.getCurrentName) match { case Some(index) => - row.update(index, fieldConverters(index).apply(parser)) - + try { + row.update(index, fieldConverters(index).apply(parser)) + } catch { + case NonFatal(e) => + badRecordException = badRecordException.orElse(Some(e)) + parser.skipChildren() + } case None => parser.skipChildren() } } - row + if (badRecordException.isEmpty) { + row + } else { + throw PartialResultException(row, badRecordException.get) + } } /** @@ -428,6 +439,11 @@ class JacksonParser( val wrappedCharException = new CharConversionException(msg) wrappedCharException.initCause(e) throw BadRecordException(() => recordLiteral(record), () => None, wrappedCharException) + case PartialResultException(row, cause) => + throw BadRecordException( + record = () => recordLiteral(record), + partialResult = () => Some(row), + cause) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala index 985f0dc1cd60e..d719a33929fcc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala @@ -20,6 +20,16 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.unsafe.types.UTF8String +/** + * Exception thrown when the underlying parser returns a partial result of parsing. + * @param partialResult the partial result of parsing a bad record. + * @param cause the actual exception about why the parser cannot return full result. + */ +case class PartialResultException( + partialResult: InternalRow, + cause: Throwable) + extends Exception(cause) + /** * Exception thrown when the underlying parser meet a bad record and can't parse it. * @param record a function to return the record that cause the parser to fail diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 661fe98d8c901..9751528654ffb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -362,7 +362,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * during parsing. *
        *
      • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a - * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To + * field configured by `columnNameOfCorruptRecord`, and sets malformed fields to `null`. To * keep corrupt records, an user can set a string type field named * `columnNameOfCorruptRecord` in an user-defined schema. If a schema does not have the * field, it drops corrupt records during parsing. When inferring a schema, it implicitly @@ -598,13 +598,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * during parsing. It supports the following case-insensitive modes. *
          *
        • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a - * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To keep - * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` - * in an user-defined schema. If a schema does not have the field, it drops corrupt records - * during parsing. A record with less/more tokens than schema is not a corrupted record to - * CSV. When it meets a record having fewer tokens than the length of the schema, sets - * `null` to extra fields. When the record has more tokens than the length of the schema, - * it drops extra tokens.
        • + * field configured by `columnNameOfCorruptRecord`, and sets malformed fields to `null`. + * To keep corrupt records, an user can set a string type field named + * `columnNameOfCorruptRecord` in an user-defined schema. If a schema does not have + * the field, it drops corrupt records during parsing. A record with less/more tokens + * than schema is not a corrupted record to CSV. When it meets a record having fewer + * tokens than the length of the schema, sets `null` to extra fields. When the record + * has more tokens than the length of the schema, it drops extra tokens. *
        • `DROPMALFORMED` : ignores the whole corrupted records.
        • *
        • `FAILFAST` : throws an exception when it meets corrupted records.
        • *
        diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index c8e3e1c191044..914fa90ae7e14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -273,7 +273,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * during parsing. *
          *
        • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a - * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To + * field configured by `columnNameOfCorruptRecord`, and sets malformed fields to `null`. To * keep corrupt records, an user can set a string type field named * `columnNameOfCorruptRecord` in an user-defined schema. If a schema does not have the * field, it drops corrupt records during parsing. When inferring a schema, it implicitly @@ -360,13 +360,13 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * during parsing. It supports the following case-insensitive modes. *
            *
          • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a - * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To keep - * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` - * in an user-defined schema. If a schema does not have the field, it drops corrupt records - * during parsing. A record with less/more tokens than schema is not a corrupted record to - * CSV. When it meets a record having fewer tokens than the length of the schema, sets - * `null` to extra fields. When the record has more tokens than the length of the schema, - * it drops extra tokens.
          • + * field configured by `columnNameOfCorruptRecord`, and sets malformed fields to `null`. + * To keep corrupt records, an user can set a string type field named + * `columnNameOfCorruptRecord` in an user-defined schema. If a schema does not have + * the field, it drops corrupt records during parsing. A record with less/more tokens + * than schema is not a corrupted record to CSV. When it meets a record having fewer + * tokens than the length of the schema, sets `null` to extra fields. When the record + * has more tokens than the length of the schema, it drops extra tokens. *
          • `DROPMALFORMED` : ignores the whole corrupted records.
          • *
          • `FAILFAST` : throws an exception when it meets corrupted records.
          • *
          diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index dff37ca2d40f0..3330de3584ebb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -248,7 +248,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { checkAnswer( sql("select nullstr, headers.Host from jsonTable"), - Seq(Row("", "1.abc.com"), Row("", null), Row(null, null), Row(null, null)) + Seq(Row("", "1.abc.com"), Row("", null), Row("", null), Row(null, null)) ) } @@ -2563,4 +2563,17 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(!files.exists(_.getName.endsWith("json"))) } } + + test("return partial result for bad records") { + val schema = "a double, b array, c string, _corrupt_record string" + val badRecords = Seq( + """{"a":"-","b":[0, 1, 2],"c":"abc"}""", + """{"a":0.1,"b":{},"c":"def"}""").toDS() + val df = spark.read.schema(schema).json(badRecords) + + checkAnswer( + df, + Row(null, Array(0, 1, 2), "abc", """{"a":"-","b":[0, 1, 2],"c":"abc"}""") :: + Row(0.1, null, "def", """{"a":0.1,"b":{},"c":"def"}""") :: Nil) + } } From bd7df6b1e129741136d09a3d29f9ffcc32ce1de3 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Tue, 11 Dec 2018 18:47:21 +0800 Subject: [PATCH 2258/2461] [SPARK-26327][SQL] Bug fix for `FileSourceScanExec` metrics update and name changing ## What changes were proposed in this pull request? As the description in [SPARK-26327](https://issues.apache.org/jira/browse/SPARK-26327), `postDriverMetricUpdates` was called on wrong place cause this bug, fix this by split the initializing of `selectedPartitions` and metrics updating logic. Add the updating logic in `inputRDD` initializing which can take effect in both code generation node and normal node. Also rename `metadataTime` to `fileListingTime` for clearer meaning. ## How was this patch tested? New test case in `SQLMetricsSuite`. Manual test: | | Before | After | |---------|:--------:|:-------:| | CodeGen |![image](https://user-images.githubusercontent.com/4833765/49741753-13c7e800-fcd2-11e8-97a8-8057b657aa3c.png)|![image](https://user-images.githubusercontent.com/4833765/49741774-1f1b1380-fcd2-11e8-98d9-78b950f4e43a.png)| | Normal |![image](https://user-images.githubusercontent.com/4833765/49741836-378b2e00-fcd2-11e8-80c3-ab462a6a3184.png)|![image](https://user-images.githubusercontent.com/4833765/49741860-4a056780-fcd2-11e8-9ef1-863de217f183.png)| Closes #23277 from xuanyuanking/SPARK-26327. Authored-by: Yuanjian Li Signed-off-by: Wenchen Fan --- .../sql/execution/DataSourceScanExec.scala | 28 +++++++++++++------ .../execution/metric/SQLMetricsSuite.scala | 15 ++++++++++ 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index b29d5c76c5f3a..c0fa4e777b49c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -167,19 +167,14 @@ case class FileSourceScanExec( partitionSchema = relation.partitionSchema, relation.sparkSession.sessionState.conf) + private var fileListingTime = 0L + @transient private lazy val selectedPartitions: Seq[PartitionDirectory] = { val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) val startTime = System.nanoTime() val ret = relation.location.listFiles(partitionFilters, dataFilters) val timeTakenMs = ((System.nanoTime() - startTime) + optimizerMetadataTimeNs) / 1000 / 1000 - - metrics("numFiles").add(ret.map(_.files.size.toLong).sum) - metrics("metadataTime").add(timeTakenMs) - - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, - metrics("numFiles") :: metrics("metadataTime") :: Nil) - + fileListingTime = timeTakenMs ret } @@ -291,6 +286,8 @@ case class FileSourceScanExec( } private lazy val inputRDD: RDD[InternalRow] = { + // Update metrics for taking effect in both code generation node and normal node. + updateDriverMetrics() val readFile: (PartitionedFile) => Iterator[InternalRow] = relation.fileFormat.buildReaderWithPartitionValues( sparkSession = relation.sparkSession, @@ -316,7 +313,7 @@ case class FileSourceScanExec( override lazy val metrics = Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of files"), - "metadataTime" -> SQLMetrics.createMetric(sparkContext, "metadata time (ms)"), + "fileListingTime" -> SQLMetrics.createMetric(sparkContext, "file listing time (ms)"), "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) protected override def doExecute(): RDD[InternalRow] = { @@ -507,6 +504,19 @@ case class FileSourceScanExec( } } + /** + * Send the updated metrics to driver, while this function calling, selectedPartitions has + * been initialized. See SPARK-26327 for more detail. + */ + private def updateDriverMetrics() = { + metrics("numFiles").add(selectedPartitions.map(_.files.size.toLong).sum) + metrics("fileListingTime").add(fileListingTime) + + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, + metrics("numFiles") :: metrics("fileListingTime") :: Nil) + } + override def doCanonicalize(): FileSourceScanExec = { FileSourceScanExec( relation, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 2251607e76af8..4a80638f68858 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -636,4 +636,19 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared assert(filters.head.metrics("numOutputRows").value == 1) } } + + test("SPARK-26327: FileSourceScanExec metrics") { + withTable("testDataForScan") { + spark.range(10).selectExpr("id", "id % 3 as p") + .write.partitionBy("p").saveAsTable("testDataForScan") + // The execution plan only has 1 FileScan node. + val df = spark.sql( + "SELECT * FROM testDataForScan WHERE p = 1") + testSparkPlanMetrics(df, 1, Map( + 0L -> (("Scan parquet default.testdataforscan", Map( + "number of output rows" -> 3L, + "number of files" -> 2L)))) + ) + } + } } From a3bbca98d7d120f22727a55cdc448608e6bb9fad Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 11 Dec 2018 21:08:39 +0800 Subject: [PATCH 2259/2461] [SPARK-26265][CORE] Fix deadlock in BytesToBytesMap.MapIterator when locking both BytesToBytesMap.MapIterator and TaskMemoryManager ## What changes were proposed in this pull request? In `BytesToBytesMap.MapIterator.advanceToNextPage`, We will first lock this `MapIterator` and then `TaskMemoryManager` when going to free a memory page by calling `freePage`. At the same time, it is possibly that another memory consumer first locks `TaskMemoryManager` and then this `MapIterator` when it acquires memory and causes spilling on this `MapIterator`. So it ends with the `MapIterator` object holds lock to the `MapIterator` object and waits for lock on `TaskMemoryManager`, and the other consumer holds lock to `TaskMemoryManager` and waits for lock on the `MapIterator` object. To avoid deadlock here, this patch proposes to keep reference to the page to free and free it after releasing the lock of `MapIterator`. ## How was this patch tested? Added test and manually test by running the test 100 times to make sure there is no deadlock. Closes #23272 from viirya/SPARK-26265. Authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- .../spark/unsafe/map/BytesToBytesMap.java | 86 ++++++++++--------- .../spark/memory/TestMemoryConsumer.java | 4 +- .../map/AbstractBytesToBytesMapSuite.java | 47 ++++++++++ 3 files changed, 96 insertions(+), 41 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 405e529464152..fbba002f1f80f 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -255,11 +255,18 @@ private MapIterator(int numRecords, Location loc, boolean destructive) { } private void advanceToNextPage() { + // SPARK-26265: We will first lock this `MapIterator` and then `TaskMemoryManager` when going + // to free a memory page by calling `freePage`. At the same time, it is possibly that another + // memory consumer first locks `TaskMemoryManager` and then this `MapIterator` when it + // acquires memory and causes spilling on this `MapIterator`. To avoid deadlock here, we keep + // reference to the page to free and free it after releasing the lock of `MapIterator`. + MemoryBlock pageToFree = null; + synchronized (this) { int nextIdx = dataPages.indexOf(currentPage) + 1; if (destructive && currentPage != null) { dataPages.remove(currentPage); - freePage(currentPage); + pageToFree = currentPage; nextIdx --; } if (dataPages.size() > nextIdx) { @@ -283,6 +290,9 @@ private void advanceToNextPage() { } } } + if (pageToFree != null) { + freePage(pageToFree); + } } @Override @@ -329,52 +339,50 @@ public Location next() { } } - public long spill(long numBytes) throws IOException { - synchronized (this) { - if (!destructive || dataPages.size() == 1) { - return 0L; - } + public synchronized long spill(long numBytes) throws IOException { + if (!destructive || dataPages.size() == 1) { + return 0L; + } - updatePeakMemoryUsed(); + updatePeakMemoryUsed(); - // TODO: use existing ShuffleWriteMetrics - ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); + // TODO: use existing ShuffleWriteMetrics + ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); - long released = 0L; - while (dataPages.size() > 0) { - MemoryBlock block = dataPages.getLast(); - // The currentPage is used, cannot be released - if (block == currentPage) { - break; - } + long released = 0L; + while (dataPages.size() > 0) { + MemoryBlock block = dataPages.getLast(); + // The currentPage is used, cannot be released + if (block == currentPage) { + break; + } - Object base = block.getBaseObject(); - long offset = block.getBaseOffset(); - int numRecords = UnsafeAlignedOffset.getSize(base, offset); - int uaoSize = UnsafeAlignedOffset.getUaoSize(); - offset += uaoSize; - final UnsafeSorterSpillWriter writer = - new UnsafeSorterSpillWriter(blockManager, 32 * 1024, writeMetrics, numRecords); - while (numRecords > 0) { - int length = UnsafeAlignedOffset.getSize(base, offset); - writer.write(base, offset + uaoSize, length, 0); - offset += uaoSize + length + 8; - numRecords--; - } - writer.close(); - spillWriters.add(writer); + Object base = block.getBaseObject(); + long offset = block.getBaseOffset(); + int numRecords = UnsafeAlignedOffset.getSize(base, offset); + int uaoSize = UnsafeAlignedOffset.getUaoSize(); + offset += uaoSize; + final UnsafeSorterSpillWriter writer = + new UnsafeSorterSpillWriter(blockManager, 32 * 1024, writeMetrics, numRecords); + while (numRecords > 0) { + int length = UnsafeAlignedOffset.getSize(base, offset); + writer.write(base, offset + uaoSize, length, 0); + offset += uaoSize + length + 8; + numRecords--; + } + writer.close(); + spillWriters.add(writer); - dataPages.removeLast(); - released += block.size(); - freePage(block); + dataPages.removeLast(); + released += block.size(); + freePage(block); - if (released >= numBytes) { - break; - } + if (released >= numBytes) { + break; } - - return released; } + + return released; } @Override diff --git a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java index 0bbaea6b834b8..6aa577d1bf797 100644 --- a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java +++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java @@ -38,12 +38,12 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { return used; } - void use(long size) { + public void use(long size) { long got = taskMemoryManager.acquireExecutionMemory(size, this); used += got; } - void free(long size) { + public void free(long size) { used -= size; taskMemoryManager.releaseExecutionMemory(size, this); } diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index a11cd535b5471..e5fbafc23d957 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -33,6 +33,8 @@ import org.apache.spark.SparkConf; import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.memory.TestMemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.network.util.JavaUtils; @@ -678,4 +680,49 @@ public void testPeakMemoryUsed() { } } + @Test + public void avoidDeadlock() throws InterruptedException { + memoryManager.limit(PAGE_SIZE_BYTES); + MemoryMode mode = useOffHeapMemoryAllocator() ? MemoryMode.OFF_HEAP: MemoryMode.ON_HEAP; + TestMemoryConsumer c1 = new TestMemoryConsumer(taskMemoryManager, mode); + BytesToBytesMap map = + new BytesToBytesMap(taskMemoryManager, blockManager, serializerManager, 1, 0.5, 1024); + + Thread thread = new Thread(() -> { + int i = 0; + long used = 0; + while (i < 10) { + c1.use(10000000); + used += 10000000; + i++; + } + c1.free(used); + }); + + try { + int i; + for (i = 0; i < 1024; i++) { + final long[] arr = new long[]{i}; + final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8); + loc.append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8); + } + + // Starts to require memory at another memory consumer. + thread.start(); + + BytesToBytesMap.MapIterator iter = map.destructiveIterator(); + for (i = 0; i < 1024; i++) { + iter.next(); + } + assertFalse(iter.hasNext()); + } finally { + map.free(); + thread.join(); + for (File spillFile : spillFilesCreated) { + assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up", + spillFile.exists()); + } + } + } + } From 5c67a9a7fa29836fc825504bbcc3c3fc820009c6 Mon Sep 17 00:00:00 2001 From: jiake Date: Tue, 11 Dec 2018 21:23:27 +0800 Subject: [PATCH 2260/2461] [SPARK-26316][SPARK-21052] Revert hash join metrics in that causes performance degradation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? The wrong implementation in the hash join metrics in [spark 21052](https://issues.apache.org/jira/browse/SPARK-21052) caused significant performance degradation in TPC-DS. And the result is [here](https://docs.google.com/spreadsheets/d/18a5BdOlmm8euTaRodyeWum9yu92mbWWu6JbhGXtr7yE/edit#gid=0) in TPC-DS 1TB scale. So we currently partial revert 21052. **Cluster info:**   | Master Node | Worker Nodes -- | -- | -- Node | 1x | 4x Processor | Intel(R) Xeon(R) Platinum 8170 CPU 2.10GHz | Intel(R) Xeon(R) Platinum 8180 CPU 2.50GHz Memory | 192 GB | 384 GB Storage Main | 8 x 960G SSD | 8 x 960G SSD Network | 10Gbe |   Role | CM Management NameNodeSecondary NameNodeResource ManagerHive Metastore Server | DataNodeNodeManager OS Version | CentOS 7.2 | CentOS 7.2 Hadoop | Apache Hadoop 2.7.5 | Apache Hadoop 2.7.5 Hive | Apache Hive 2.2.0 |   Spark | Apache Spark 2.1.0  & Apache Spark2.3.0 |   JDK  version | 1.8.0_112 | 1.8.0_112 **Related parameters setting:** Component | Parameter | Value -- | -- | -- Yarn Resource Manager | yarn.scheduler.maximum-allocation-mb | 120GB   | yarn.scheduler.minimum-allocation-mb | 1GB   | yarn.scheduler.maximum-allocation-vcores | 121   | Yarn.resourcemanager.scheduler.class | Fair Scheduler Yarn Node Manager | yarn.nodemanager.resource.memory-mb | 120GB   | yarn.nodemanager.resource.cpu-vcores | 121 Spark | spark.executor.memory | 110GB   | spark.executor.cores | 50 ## How was this patch tested? N/A Closes #23269 from JkSelf/partial-revert-21052. Authored-by: jiake Signed-off-by: Wenchen Fan --- .../joins/BroadcastHashJoinExec.scala | 28 +----- .../spark/sql/execution/joins/HashJoin.scala | 8 +- .../sql/execution/joins/HashedRelation.scala | 35 ------- .../joins/ShuffledHashJoinExec.scala | 6 +- .../execution/metric/SQLMetricsSuite.scala | 94 +------------------ 5 files changed, 6 insertions(+), 165 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index a6f3ea47c8492..fd4a7897c7ad1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Dist import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.{BooleanType, LongType} -import org.apache.spark.util.TaskCompletionListener /** * Performs an inner hash join of two child relations. When the output RDD of this operator is @@ -48,8 +47,7 @@ case class BroadcastHashJoinExec( extends BinaryExecNode with HashJoin with CodegenSupport { override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe")) + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) override def requiredChildDistribution: Seq[Distribution] = { val mode = HashedRelationBroadcastMode(buildKeys) @@ -63,13 +61,12 @@ case class BroadcastHashJoinExec( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - val avgHashProbe = longMetric("avgHashProbe") val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() streamedPlan.execute().mapPartitions { streamedIter => val hashed = broadcastRelation.value.asReadOnlyCopy() TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize) - join(streamedIter, hashed, numOutputRows, avgHashProbe) + join(streamedIter, hashed, numOutputRows) } } @@ -111,23 +108,6 @@ case class BroadcastHashJoinExec( } } - /** - * Returns the codes used to add a task completion listener to update avg hash probe - * at the end of the task. - */ - private def genTaskListener(avgHashProbe: String, relationTerm: String): String = { - val listenerClass = classOf[TaskCompletionListener].getName - val taskContextClass = classOf[TaskContext].getName - s""" - | $taskContextClass$$.MODULE$$.get().addTaskCompletionListener(new $listenerClass() { - | @Override - | public void onTaskCompletion($taskContextClass context) { - | $avgHashProbe.set($relationTerm.getAverageProbesPerLookup()); - | } - | }); - """.stripMargin - } - /** * Returns a tuple of Broadcast of HashedRelation and the variable name for it. */ @@ -137,15 +117,11 @@ case class BroadcastHashJoinExec( val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) val clsName = broadcastRelation.value.getClass.getName - // At the end of the task, we update the avg hash probe. - val avgHashProbe = metricTerm(ctx, "avgHashProbe") - // Inline mutable state since not many join operations in a task val relationTerm = ctx.addMutableState(clsName, "relation", v => s""" | $v = (($clsName) $broadcast.value()).asReadOnlyCopy(); | incPeakExecutionMemory($v.estimatedSize()); - | ${genTaskListener(avgHashProbe, v)} """.stripMargin, forceInline = true) (broadcastRelation, relationTerm) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index dab873bf9b9a0..1aef5f6864263 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -194,8 +193,7 @@ trait HashJoin { protected def join( streamedIter: Iterator[InternalRow], hashed: HashedRelation, - numOutputRows: SQLMetric, - avgHashProbe: SQLMetric): Iterator[InternalRow] = { + numOutputRows: SQLMetric): Iterator[InternalRow] = { val joinedIter = joinType match { case _: InnerLike => @@ -213,10 +211,6 @@ trait HashJoin { s"BroadcastHashJoin should not take $x as the JoinType") } - // At the end of the task, we update the avg hash probe. - TaskContext.get().addTaskCompletionListener[Unit](_ => - avgHashProbe.set(hashed.getAverageProbesPerLookup)) - val resultProj = createResultProjection joinedIter.map { r => numOutputRows += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index b1ff6e83acc24..7c21062c4cec3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -80,11 +80,6 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation { * Release any used resources. */ def close(): Unit - - /** - * Returns the average number of probes per key lookup. - */ - def getAverageProbesPerLookup: Double } private[execution] object HashedRelation { @@ -279,8 +274,6 @@ private[joins] class UnsafeHashedRelation( override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException { read(() => in.readInt(), () => in.readLong(), in.readBytes) } - - override def getAverageProbesPerLookup: Double = binaryMap.getAverageProbesPerLookup } private[joins] object UnsafeHashedRelation { @@ -395,10 +388,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap // The number of unique keys. private var numKeys = 0L - // Tracking average number of probes per key lookup. - private var numKeyLookups = 0L - private var numProbes = 0L - // needed by serializer def this() = { this( @@ -483,8 +472,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap */ def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = { if (isDense) { - numKeyLookups += 1 - numProbes += 1 if (key >= minKey && key <= maxKey) { val value = array((key - minKey).toInt) if (value > 0) { @@ -493,14 +480,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap } } else { var pos = firstSlot(key) - numKeyLookups += 1 - numProbes += 1 while (array(pos + 1) != 0) { if (array(pos) == key) { return getRow(array(pos + 1), resultRow) } pos = nextSlot(pos) - numProbes += 1 } } null @@ -528,8 +512,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap */ def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { if (isDense) { - numKeyLookups += 1 - numProbes += 1 if (key >= minKey && key <= maxKey) { val value = array((key - minKey).toInt) if (value > 0) { @@ -538,14 +520,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap } } else { var pos = firstSlot(key) - numKeyLookups += 1 - numProbes += 1 while (array(pos + 1) != 0) { if (array(pos) == key) { return valueIter(array(pos + 1), resultRow) } pos = nextSlot(pos) - numProbes += 1 } } null @@ -585,11 +564,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap private def updateIndex(key: Long, address: Long): Unit = { var pos = firstSlot(key) assert(numKeys < array.length / 2) - numKeyLookups += 1 - numProbes += 1 while (array(pos) != key && array(pos + 1) != 0) { pos = nextSlot(pos) - numProbes += 1 } if (array(pos + 1) == 0) { // this is the first value for this key, put the address in array. @@ -721,8 +697,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap writeLong(maxKey) writeLong(numKeys) writeLong(numValues) - writeLong(numKeyLookups) - writeLong(numProbes) writeLong(array.length) writeLongArray(writeBuffer, array, array.length) @@ -764,8 +738,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap maxKey = readLong() numKeys = readLong() numValues = readLong() - numKeyLookups = readLong() - numProbes = readLong() val length = readLong().toInt mask = length - 2 @@ -783,11 +755,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap override def read(kryo: Kryo, in: Input): Unit = { read(() => in.readBoolean(), () => in.readLong(), in.readBytes) } - - /** - * Returns the average number of probes per key lookup. - */ - def getAverageProbesPerLookup: Double = numProbes.toDouble / numKeyLookups } private[joins] class LongHashedRelation( @@ -839,8 +806,6 @@ private[joins] class LongHashedRelation( resultRow = new UnsafeRow(nFields) map = in.readObject().asInstanceOf[LongToUnsafeRowMap] } - - override def getAverageProbesPerLookup: Double = map.getAverageProbesPerLookup } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 2b59ed6e4d16b..524804d61e599 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -42,8 +42,7 @@ case class ShuffledHashJoinExec( override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"), - "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"), - "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe")) + "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map")) override def requiredChildDistribution: Seq[Distribution] = HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil @@ -63,10 +62,9 @@ case class ShuffledHashJoinExec( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - val avgHashProbe = longMetric("avgHashProbe") streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) => val hashed = buildHashedRelation(buildIter) - join(streamIter, hashed, numOutputRows, avgHashProbe) + join(streamIter, hashed, numOutputRows) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 4a80638f68858..f6495496a58e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -261,50 +261,6 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared ) } - test("BroadcastHashJoin metrics: track avg probe") { - // The executed plan looks like: - // Project [a#210, b#211, b#221] - // +- BroadcastHashJoin [a#210], [a#220], Inner, BuildRight - // :- Project [_1#207 AS a#210, _2#208 AS b#211] - // : +- Filter isnotnull(_1#207) - // : +- LocalTableScan [_1#207, _2#208] - // +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, binary, true])) - // +- Project [_1#217 AS a#220, _2#218 AS b#221] - // +- Filter isnotnull(_1#217) - // +- LocalTableScan [_1#217, _2#218] - // - // Assume the execution plan with node id is - // WholeStageCodegen disabled: - // Project(nodeId = 0) - // BroadcastHashJoin(nodeId = 1) - // ...(ignored) - // - // WholeStageCodegen enabled: - // WholeStageCodegen(nodeId = 0) - // Project(nodeId = 1) - // BroadcastHashJoin(nodeId = 2) - // Project(nodeId = 3) - // Filter(nodeId = 4) - // ...(ignored) - Seq(true, false).foreach { enableWholeStage => - val df1 = generateRandomBytesDF() - val df2 = generateRandomBytesDF() - val df = df1.join(broadcast(df2), "a") - val nodeIds = if (enableWholeStage) { - Set(2L) - } else { - Set(1L) - } - val metrics = getSparkPlanMetrics(df, 2, nodeIds, enableWholeStage).get - nodeIds.foreach { nodeId => - val probes = metrics(nodeId)._2("avg hash probe (min, med, max)") - probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe => - assert(probe.toDouble > 1.0) - } - } - } - } - test("ShuffledHashJoin metrics") { withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "40", "spark.sql.shuffle.partitions" -> "2", @@ -323,8 +279,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared val df = df1.join(df2, "key") testSparkPlanMetrics(df, 1, Map( 1L -> (("ShuffledHashJoin", Map( - "number of output rows" -> 2L, - "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"))), + "number of output rows" -> 2L))), 2L -> (("Exchange", Map( "shuffle records written" -> 2L, "records read" -> 2L))), @@ -335,53 +290,6 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared } } - test("ShuffledHashJoin metrics: track avg probe") { - // The executed plan looks like: - // Project [a#308, b#309, b#319] - // +- ShuffledHashJoin [a#308], [a#318], Inner, BuildRight - // :- Exchange hashpartitioning(a#308, 2) - // : +- Project [_1#305 AS a#308, _2#306 AS b#309] - // : +- Filter isnotnull(_1#305) - // : +- LocalTableScan [_1#305, _2#306] - // +- Exchange hashpartitioning(a#318, 2) - // +- Project [_1#315 AS a#318, _2#316 AS b#319] - // +- Filter isnotnull(_1#315) - // +- LocalTableScan [_1#315, _2#316] - // - // Assume the execution plan with node id is - // WholeStageCodegen disabled: - // Project(nodeId = 0) - // ShuffledHashJoin(nodeId = 1) - // ...(ignored) - // - // WholeStageCodegen enabled: - // WholeStageCodegen(nodeId = 0) - // Project(nodeId = 1) - // ShuffledHashJoin(nodeId = 2) - // ...(ignored) - withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "5000000", - "spark.sql.shuffle.partitions" -> "2", - "spark.sql.join.preferSortMergeJoin" -> "false") { - Seq(true, false).foreach { enableWholeStage => - val df1 = generateRandomBytesDF(65535 * 5) - val df2 = generateRandomBytesDF(65535) - val df = df1.join(df2, "a") - val nodeIds = if (enableWholeStage) { - Set(2L) - } else { - Set(1L) - } - val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get - nodeIds.foreach { nodeId => - val probes = metrics(nodeId)._2("avg hash probe (min, med, max)") - probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe => - assert(probe.toDouble > 1.0) - } - } - } - } - } - test("BroadcastHashJoin(outer) metrics") { val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") From d811369ce23186cbb3208ad665e15408e13fea87 Mon Sep 17 00:00:00 2001 From: liuxian Date: Tue, 11 Dec 2018 09:12:17 -0800 Subject: [PATCH 2261/2461] [SPARK-26300][SS] Remove a redundant `checkForStreaming` call ## What changes were proposed in this pull request? If `checkForContinuous` is called ( `checkForStreaming` is called in `checkForContinuous` ), the `checkForStreaming` mothod will be called twice in `createQuery` , this is not necessary, and the `checkForStreaming` method has a lot of statements, so it's better to remove one of them. ## How was this patch tested? Existing unit tests in `StreamingQueryManagerSuite` and `ContinuousAggregationSuite` Closes #23251 from 10110346/isUnsupportedOperationCheckEnabled. Authored-by: liuxian Signed-off-by: Dongjoon Hyun --- .../spark/sql/streaming/StreamingQueryManager.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index d9fe1a992a093..881cd96cc9dc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -246,9 +246,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo val analyzedPlan = df.queryExecution.analyzed df.queryExecution.assertAnalyzed() - if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { - UnsupportedOperationChecker.checkForStreaming(analyzedPlan, outputMode) - } + val operationCheckEnabled = sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled if (sparkSession.sessionState.conf.adaptiveExecutionEnabled) { logWarning(s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} " + @@ -257,7 +255,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo (sink, trigger) match { case (v2Sink: StreamingWriteSupportProvider, trigger: ContinuousTrigger) => - if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { + if (operationCheckEnabled) { UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) } new StreamingQueryWrapper(new ContinuousExecution( @@ -272,6 +270,9 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo extraOptions, deleteCheckpointOnStop)) case _ => + if (operationCheckEnabled) { + UnsupportedOperationChecker.checkForStreaming(analyzedPlan, outputMode) + } new StreamingQueryWrapper(new MicroBatchExecution( sparkSession, userSpecifiedName.orNull, From 57d6fbfa8c803ce1791e7be36aba0219a1fcaa63 Mon Sep 17 00:00:00 2001 From: mcheah Date: Tue, 11 Dec 2018 13:49:52 -0800 Subject: [PATCH 2262/2461] [SPARK-26239] File-based secret key loading for SASL. This proposes an alternative way to load secret keys into a Spark application that is running on Kubernetes. Instead of automatically generating the secret, the secret key can reside in a file that is shared between both the driver and executor containers. Unit tests. Closes #23252 from mccheah/auth-secret-with-file. Authored-by: mcheah Signed-off-by: Marcelo Vanzin --- .../org/apache/spark/SecurityManager.scala | 34 +++++++++- .../scala/org/apache/spark/SparkEnv.scala | 4 +- .../spark/internal/config/package.scala | 31 +++++++++ .../apache/spark/SecurityManagerSuite.scala | 66 ++++++++++++++++++- docs/security.md | 44 +++++++++++++ .../features/BasicExecutorFeatureStep.scala | 16 +++-- .../BasicExecutorFeatureStepSuite.scala | 23 +++++++ 7 files changed, 205 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 96e4b53b24181..15783c952c231 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -17,8 +17,11 @@ package org.apache.spark +import java.io.File import java.net.{Authenticator, PasswordAuthentication} import java.nio.charset.StandardCharsets.UTF_8 +import java.nio.file.Files +import java.util.Base64 import org.apache.hadoop.io.Text import org.apache.hadoop.security.{Credentials, UserGroupInformation} @@ -43,7 +46,8 @@ import org.apache.spark.util.Utils */ private[spark] class SecurityManager( sparkConf: SparkConf, - val ioEncryptionKey: Option[Array[Byte]] = None) + val ioEncryptionKey: Option[Array[Byte]] = None, + authSecretFileConf: ConfigEntry[Option[String]] = AUTH_SECRET_FILE) extends Logging with SecretKeyHolder { import SecurityManager._ @@ -328,6 +332,7 @@ private[spark] class SecurityManager( .orElse(Option(secretKey)) .orElse(Option(sparkConf.getenv(ENV_AUTH_SECRET))) .orElse(sparkConf.getOption(SPARK_AUTH_SECRET_CONF)) + .orElse(secretKeyFromFile()) .getOrElse { throw new IllegalArgumentException( s"A secret key must be specified via the $SPARK_AUTH_SECRET_CONF config") @@ -348,7 +353,6 @@ private[spark] class SecurityManager( */ def initializeAuth(): Unit = { import SparkMasterRegex._ - val k8sRegex = "k8s.*".r if (!sparkConf.get(NETWORK_AUTH_ENABLED)) { return @@ -371,7 +375,14 @@ private[spark] class SecurityManager( return } - secretKey = Utils.createSecret(sparkConf) + if (sparkConf.get(AUTH_SECRET_FILE_DRIVER).isDefined != + sparkConf.get(AUTH_SECRET_FILE_EXECUTOR).isDefined) { + throw new IllegalArgumentException( + "Invalid secret configuration: Secret files must be specified for both the driver and the" + + " executors, not only one or the other.") + } + + secretKey = secretKeyFromFile().getOrElse(Utils.createSecret(sparkConf)) if (storeInUgi) { val creds = new Credentials() @@ -380,6 +391,22 @@ private[spark] class SecurityManager( } } + private def secretKeyFromFile(): Option[String] = { + sparkConf.get(authSecretFileConf).flatMap { secretFilePath => + sparkConf.getOption(SparkLauncher.SPARK_MASTER).map { + case k8sRegex() => + val secretFile = new File(secretFilePath) + require(secretFile.isFile, s"No file found containing the secret key at $secretFilePath.") + val base64Key = Base64.getEncoder.encodeToString(Files.readAllBytes(secretFile.toPath)) + require(!base64Key.isEmpty, s"Secret key from file located at $secretFilePath is empty.") + base64Key + case _ => + throw new IllegalArgumentException( + "Secret keys provided via files is only allowed in Kubernetes mode.") + } + } + } + // Default SecurityManager only has a single secret key, so ignore appId. override def getSaslUser(appId: String): String = getSaslUser() override def getSecretKey(appId: String): String = getSecretKey() @@ -387,6 +414,7 @@ private[spark] class SecurityManager( private[spark] object SecurityManager { + val k8sRegex = "k8s.*".r val SPARK_AUTH_CONF = NETWORK_AUTH_ENABLED.key val SPARK_AUTH_SECRET_CONF = "spark.authenticate.secret" // This is used to set auth secret to an executor's env variable. It should have the same diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 66038eeaea54f..de0c8579d9acc 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -232,8 +232,8 @@ object SparkEnv extends Logging { if (isDriver) { assert(listenerBus != null, "Attempted to create driver SparkEnv with null listener bus!") } - - val securityManager = new SecurityManager(conf, ioEncryptionKey) + val authSecretFileConf = if (isDriver) AUTH_SECRET_FILE_DRIVER else AUTH_SECRET_FILE_EXECUTOR + val securityManager = new SecurityManager(conf, ioEncryptionKey, authSecretFileConf) if (isDriver) { securityManager.initializeAuth() } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 85bb557abef5d..f1c1c034df49a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -424,6 +424,37 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val AUTH_SECRET_FILE = + ConfigBuilder("spark.authenticate.secret.file") + .doc("Path to a file that contains the authentication secret to use. The secret key is " + + "loaded from this path on both the driver and the executors if overrides are not set for " + + "either entity (see below). File-based secret keys are only allowed when using " + + "Kubernetes.") + .stringConf + .createOptional + + private[spark] val AUTH_SECRET_FILE_DRIVER = + ConfigBuilder("spark.authenticate.secret.driver.file") + .doc("Path to a file that contains the authentication secret to use. Loaded by the " + + "driver. In Kubernetes client mode it is often useful to set a different secret " + + "path for the driver vs. the executors, since the driver may not be running in " + + "a pod unlike the executors. If this is set, an accompanying secret file must " + + "be specified for the executors. The fallback configuration allows the same path to be " + + "used for both the driver and the executors when running in cluster mode. File-based " + + "secret keys are only allowed when using Kubernetes.") + .fallbackConf(AUTH_SECRET_FILE) + + private[spark] val AUTH_SECRET_FILE_EXECUTOR = + ConfigBuilder("spark.authenticate.secret.executor.file") + .doc("Path to a file that contains the authentication secret to use. Loaded by the " + + "executors only. In Kubernetes client mode it is often useful to set a different " + + "secret path for the driver vs. the executors, since the driver may not be running " + + "in a pod unlike the executors. If this is set, an accompanying secret file must be " + + "specified for the executors. The fallback configuration allows the same path to be " + + "used for both the driver and the executors when running in cluster mode. File-based " + + "secret keys are only allowed when using Kubernetes.") + .fallbackConf(AUTH_SECRET_FILE) + private[spark] val NETWORK_ENCRYPTION_ENABLED = ConfigBuilder("spark.network.crypto.enabled") .booleanConf diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index eec8004fc94f4..e9061f4e7beb8 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark import java.io.File import java.nio.charset.StandardCharsets.UTF_8 +import java.nio.file.Files import java.security.PrivilegedExceptionAction +import java.util.Base64 import org.apache.hadoop.security.UserGroupInformation @@ -395,9 +397,54 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { assert(keyFromEnv === new SecurityManager(conf2).getSecretKey()) } + test("use executor-specific secret file configuration.") { + val secretFileFromDriver = createTempSecretFile("driver-secret") + val secretFileFromExecutor = createTempSecretFile("executor-secret") + val conf = new SparkConf() + .setMaster("k8s://127.0.0.1") + .set(AUTH_SECRET_FILE_DRIVER, Some(secretFileFromDriver.getAbsolutePath)) + .set(AUTH_SECRET_FILE_EXECUTOR, Some(secretFileFromExecutor.getAbsolutePath)) + .set(SecurityManager.SPARK_AUTH_CONF, "true") + val mgr = new SecurityManager(conf, authSecretFileConf = AUTH_SECRET_FILE_EXECUTOR) + assert(encodeFileAsBase64(secretFileFromExecutor) === mgr.getSecretKey()) + } + + test("secret file must be defined in both driver and executor") { + val conf1 = new SparkConf() + .set(AUTH_SECRET_FILE_DRIVER, Some("/tmp/driver-secret.txt")) + .set(SecurityManager.SPARK_AUTH_CONF, "true") + val mgr1 = new SecurityManager(conf1) + intercept[IllegalArgumentException] { + mgr1.initializeAuth() + } + + val conf2 = new SparkConf() + .set(AUTH_SECRET_FILE_EXECUTOR, Some("/tmp/executor-secret.txt")) + .set(SecurityManager.SPARK_AUTH_CONF, "true") + val mgr2 = new SecurityManager(conf2) + intercept[IllegalArgumentException] { + mgr2.initializeAuth() + } + } + + Seq("yarn", "local", "local[*]", "local[1,2]", "mesos://localhost:8080").foreach { master => + test(s"master $master cannot use file mounted secrets") { + val conf = new SparkConf() + .set(AUTH_SECRET_FILE, "/tmp/secret.txt") + .set(SecurityManager.SPARK_AUTH_CONF, "true") + .setMaster(master) + intercept[IllegalArgumentException] { + new SecurityManager(conf).getSecretKey() + } + intercept[IllegalArgumentException] { + new SecurityManager(conf).initializeAuth() + } + } + } + // How is the secret expected to be generated and stored. object SecretTestType extends Enumeration { - val MANUAL, AUTO, UGI = Value + val MANUAL, AUTO, UGI, FILE = Value } import SecretTestType._ @@ -408,6 +455,7 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { ("local[*]", UGI), ("local[1, 2]", UGI), ("k8s://127.0.0.1", AUTO), + ("k8s://127.0.1.1", FILE), ("local-cluster[2, 1, 1024]", MANUAL), ("invalid", MANUAL) ).foreach { case (master, secretType) => @@ -440,6 +488,12 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { intercept[IllegalArgumentException] { mgr.getSecretKey() } + + case FILE => + val secretFile = createTempSecretFile() + conf.set(AUTH_SECRET_FILE, secretFile.getAbsolutePath) + mgr.initializeAuth() + assert(encodeFileAsBase64(secretFile) === mgr.getSecretKey()) } } } @@ -447,5 +501,15 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { } } + private def encodeFileAsBase64(secretFile: File) = { + Base64.getEncoder.encodeToString(Files.readAllBytes(secretFile.toPath)) + } + + private def createTempSecretFile(contents: String = "test-secret"): File = { + val secretDir = Utils.createTempDir("temp-secrets") + val secretFile = new File(secretDir, "temp-secret.txt") + Files.write(secretFile.toPath, contents.getBytes(UTF_8)) + secretFile + } } diff --git a/docs/security.md b/docs/security.md index 2a4f3c074c1e5..8416ed91356aa 100644 --- a/docs/security.md +++ b/docs/security.md @@ -66,6 +66,50 @@ Kubernetes admin to ensure that Spark authentication is secure.
      Property NameDefaultMeaning
      spark.files.maxPartitionBytes134217728 (128 MB)134217728 (128 MiB) The maximum number of bytes to pack into a single partition when reading files.
      spark.files.openCostInBytes4194304 (4 MB)4194304 (4 MiB) The estimated cost to open a file, measured by the number of bytes could be scanned at the same time. This is used when putting multiple files into a partition. It is better to overestimate, @@ -1445,7 +1445,7 @@ Apart from these, the following properties are also available, and may be useful spark.rpc.message.maxSize 128 - Maximum message size (in MB) to allow in "control plane" communication; generally only applies to map + Maximum message size (in MiB) to allow in "control plane" communication; generally only applies to map output size information sent between executors and the driver. Increase this if you are running jobs with many thousands of map and reduce tasks and see messages about the RPC message size. spark.mesos.executor.memoryOverhead executor memory * 0.10, with minimum of 384 - The amount of additional memory, specified in MB, to be allocated per executor. By default, + The amount of additional memory, specified in MiB, to be allocated per executor. By default, the overhead will be larger of either 384 or 10% of spark.executor.memory. If set, the final overhead will be this value.
      -m MEM, --memory MEMTotal amount of memory to allow Spark applications to use on the machine, in a format like 1000M or 2G (default: your machine's total RAM minus 1 GB); only on workerTotal amount of memory to allow Spark applications to use on the machine, in a format like 1000M or 2G (default: your machine's total RAM minus 1 GiB); only on worker
      -d DIR, --work-dir DIR
      SPARK_WORKER_MEMORYTotal amount of memory to allow Spark applications to use on the machine, e.g. 1000m, 2g (default: total memory minus 1 GB); note that each application's individual memory is configured using its spark.executor.memory property.Total amount of memory to allow Spark applications to use on the machine, e.g. 1000m, 2g (default: total memory minus 1 GiB); note that each application's individual memory is configured using its spark.executor.memory property.
      SPARK_WORKER_PORT
      +Alternatively, one can mount authentication secrets using files and Kubernetes secrets that +the user mounts into their pods. + + + + + + + + + + + + + + + + + + +
      Property NameDefaultMeaning
      spark.authenticate.secret.fileNone + Path pointing to the secret key to use for securing connections. Ensure that the + contents of the file have been securely generated. This file is loaded on both the driver + and the executors unless other settings override this (see below). +
      spark.authenticate.secret.driver.fileThe value of spark.authenticate.secret.file + When specified, overrides the location that the Spark driver reads to load the secret. + Useful when in client mode, when the location of the secret file may differ in the pod versus + the node the driver is running in. When this is specified, + spark.authenticate.secret.executor.file must be specified so that the driver + and the executors can both use files to load the secret key. Ensure that the contents of the file + on the driver is identical to the contents of the file on the executors. +
      spark.authenticate.secret.executor.fileThe value of spark.authenticate.secret.file + When specified, overrides the location that the Spark executors read to load the secret. + Useful in client mode, when the location of the secret file may differ in the pod versus + the node the driver is running in. When this is specified, + spark.authenticate.secret.driver.file must be specified so that the driver + and the executors can both use files to load the secret key. Ensure that the contents of the file + on the driver is identical to the contents of the file on the executors. +
      + +Note that when using files, Spark will not mount these files into the containers for you. It is up +you to ensure that the secret files are deployed securely into your containers and that the driver's +secret file agrees with the executors' secret file. + ## Encryption Spark supports AES-based encryption for RPC connections. For encryption to be enabled, RPC diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index 939aa88b07973..4bcf4c9446aa3 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -24,7 +24,7 @@ import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD, PYSPARK_EXECUTOR_MEMORY} +import org.apache.spark.internal.config._ import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils @@ -112,12 +112,14 @@ private[spark] class BasicExecutorFeatureStep( .build()) .build()) } ++ { - Option(secMgr.getSecretKey()).map { authSecret => - new EnvVarBuilder() - .withName(SecurityManager.ENV_AUTH_SECRET) - .withValue(authSecret) - .build() - } + if (kubernetesConf.get(AUTH_SECRET_FILE_EXECUTOR).isEmpty) { + Option(secMgr.getSecretKey()).map { authSecret => + new EnvVarBuilder() + .withName(SecurityManager.ENV_AUTH_SECRET) + .withValue(authSecret) + .build() + } + } else None } ++ { kubernetesConf.get(EXECUTOR_CLASS_PATH).map { cp => new EnvVarBuilder() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index 6aa862643c788..05989d9be7ad5 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -16,6 +16,10 @@ */ package org.apache.spark.deploy.k8s.features +import java.io.File +import java.nio.charset.StandardCharsets +import java.nio.file.Files + import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model._ @@ -158,6 +162,25 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { checkEnv(executor, conf, Map(SecurityManager.ENV_AUTH_SECRET -> secMgr.getSecretKey())) } + test("Auth secret shouldn't propagate if files are loaded.") { + val secretDir = Utils.createTempDir("temp-secret") + val secretFile = new File(secretDir, "secret-file.txt") + Files.write(secretFile.toPath, "some-secret".getBytes(StandardCharsets.UTF_8)) + val conf = baseConf.clone() + .set(NETWORK_AUTH_ENABLED, true) + .set(AUTH_SECRET_FILE, secretFile.getAbsolutePath) + .set("spark.master", "k8s://127.0.0.1") + val secMgr = new SecurityManager(conf) + secMgr.initializeAuth() + + val step = new BasicExecutorFeatureStep(KubernetesTestConf.createExecutorConf(sparkConf = conf), + secMgr) + + val executor = step.configurePod(SparkPod.initialPod()) + assert(!KubernetesFeaturesTestUtils.containerHasEnvVar( + executor.container, SecurityManager.ENV_AUTH_SECRET)) + } + // There is always exactly one controller reference, and it points to the driver pod. private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { assert(executor.getMetadata.getOwnerReferences.size() === 1) From bd8da3799dd160771ebb3ea55b7678b644248425 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Wed, 12 Dec 2018 10:03:50 +0800 Subject: [PATCH 2263/2461] [SPARK-26193][SQL][FOLLOW UP] Read metrics rename and display text changes ## What changes were proposed in this pull request? Follow up pr for #23207, include following changes: - Rename `SQLShuffleMetricsReporter` to `SQLShuffleReadMetricsReporter` to make it match with write side naming. - Display text changes for read side for naming consistent. - Rename function in `ShuffleWriteProcessor`. - Delete `private[spark]` in execution package. ## How was this patch tested? Existing tests. Closes #23286 from xuanyuanking/SPARK-26193-follow. Authored-by: Yuanjian Li Signed-off-by: Wenchen Fan --- .../spark/scheduler/ShuffleMapTask.scala | 2 +- .../spark/shuffle/ShuffleWriteProcessor.scala | 2 +- .../spark/sql/execution/ShuffledRowRDD.scala | 6 ++-- .../exchange/ShuffleExchangeExec.scala | 4 +-- .../apache/spark/sql/execution/limit.scala | 6 ++-- .../metric/SQLShuffleMetricsReporter.scala | 36 +++++++++---------- .../execution/UnsafeRowSerializerSuite.scala | 4 +-- .../execution/metric/SQLMetricsSuite.scala | 20 +++++------ 8 files changed, 40 insertions(+), 40 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 2a8d1dd995e27..35664ff515d4b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -92,7 +92,7 @@ private[spark] class ShuffleMapTask( threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime } else 0L - dep.shuffleWriterProcessor.writeProcess(rdd, dep, partitionId, context, partition) + dep.shuffleWriterProcessor.write(rdd, dep, partitionId, context, partition) } override def preferredLocations: Seq[TaskLocation] = preferredLocs diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala index f5213157a9a85..5b0c7e9f2b0b4 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala @@ -41,7 +41,7 @@ private[spark] class ShuffleWriteProcessor extends Serializable with Logging { * get from [[ShuffleManager]] and triggers rdd compute, finally return the [[MapStatus]] for * this task. */ - def writeProcess( + def write( rdd: RDD[_], dep: ShuffleDependency[_, _, _], partitionId: Int, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala index 9b05faaed0459..079ff25fcb67e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -22,7 +22,7 @@ import java.util.Arrays import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleMetricsReporter} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsReporter} /** * The [[Partition]] used by [[ShuffledRowRDD]]. A post-shuffle partition @@ -157,9 +157,9 @@ class ShuffledRowRDD( override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { val shuffledRowPartition = split.asInstanceOf[ShuffledRowRDDPartition] val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() - // `SQLShuffleMetricsReporter` will update its own metrics for SQL exchange operator, + // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator, // as well as the `tempMetrics` for basic shuffle metrics. - val sqlMetricsReporter = new SQLShuffleMetricsReporter(tempMetrics, metrics) + val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics) // The range of pre-shuffle partitions that we are fetching at here is // [startPreShufflePartitionIndex, endPreShufflePartitionIndex - 1]. val reader = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 0c2020572e721..da7b0c6f43fbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Uns import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleMetricsReporter, SQLShuffleWriteMetricsReporter} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.util.MutablePair @@ -50,7 +50,7 @@ case class ShuffleExchangeExec( private lazy val writeMetrics = SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) private lazy val readMetrics = - SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext) + SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) override lazy val metrics = Map( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size") ) ++ readMetrics ++ writeMetrics diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 1f2fdde538645..bfaf080292bce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec -import org.apache.spark.sql.execution.metric.{SQLShuffleMetricsReporter, SQLShuffleWriteMetricsReporter} +import org.apache.spark.sql.execution.metric.{SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} /** * Take the first `limit` elements and collect them to a single partition. @@ -41,7 +41,7 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode private lazy val writeMetrics = SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) private lazy val readMetrics = - SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext) + SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) override lazy val metrics = readMetrics ++ writeMetrics protected override def doExecute(): RDD[InternalRow] = { val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit)) @@ -165,7 +165,7 @@ case class TakeOrderedAndProjectExec( private lazy val writeMetrics = SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) private lazy val readMetrics = - SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext) + SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) override lazy val metrics = readMetrics ++ writeMetrics protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala index ff7941e3b3e8d..2c0ea80495abb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala @@ -27,23 +27,23 @@ import org.apache.spark.shuffle.ShuffleWriteMetricsReporter * @param metrics All metrics in current SparkPlan. This param should not empty and * contains all shuffle metrics defined in createShuffleReadMetrics. */ -private[spark] class SQLShuffleMetricsReporter( +class SQLShuffleReadMetricsReporter( tempMetrics: TempShuffleReadMetrics, metrics: Map[String, SQLMetric]) extends TempShuffleReadMetrics { private[this] val _remoteBlocksFetched = - metrics(SQLShuffleMetricsReporter.REMOTE_BLOCKS_FETCHED) + metrics(SQLShuffleReadMetricsReporter.REMOTE_BLOCKS_FETCHED) private[this] val _localBlocksFetched = - metrics(SQLShuffleMetricsReporter.LOCAL_BLOCKS_FETCHED) + metrics(SQLShuffleReadMetricsReporter.LOCAL_BLOCKS_FETCHED) private[this] val _remoteBytesRead = - metrics(SQLShuffleMetricsReporter.REMOTE_BYTES_READ) + metrics(SQLShuffleReadMetricsReporter.REMOTE_BYTES_READ) private[this] val _remoteBytesReadToDisk = - metrics(SQLShuffleMetricsReporter.REMOTE_BYTES_READ_TO_DISK) + metrics(SQLShuffleReadMetricsReporter.REMOTE_BYTES_READ_TO_DISK) private[this] val _localBytesRead = - metrics(SQLShuffleMetricsReporter.LOCAL_BYTES_READ) + metrics(SQLShuffleReadMetricsReporter.LOCAL_BYTES_READ) private[this] val _fetchWaitTime = - metrics(SQLShuffleMetricsReporter.FETCH_WAIT_TIME) + metrics(SQLShuffleReadMetricsReporter.FETCH_WAIT_TIME) private[this] val _recordsRead = - metrics(SQLShuffleMetricsReporter.RECORDS_READ) + metrics(SQLShuffleReadMetricsReporter.RECORDS_READ) override def incRemoteBlocksFetched(v: Long): Unit = { _remoteBlocksFetched.add(v) @@ -75,7 +75,7 @@ private[spark] class SQLShuffleMetricsReporter( } } -private[spark] object SQLShuffleMetricsReporter { +object SQLShuffleReadMetricsReporter { val REMOTE_BLOCKS_FETCHED = "remoteBlocksFetched" val LOCAL_BLOCKS_FETCHED = "localBlocksFetched" val REMOTE_BYTES_READ = "remoteBytesRead" @@ -88,8 +88,8 @@ private[spark] object SQLShuffleMetricsReporter { * Create all shuffle read relative metrics and return the Map. */ def createShuffleReadMetrics(sc: SparkContext): Map[String, SQLMetric] = Map( - REMOTE_BLOCKS_FETCHED -> SQLMetrics.createMetric(sc, "remote blocks fetched"), - LOCAL_BLOCKS_FETCHED -> SQLMetrics.createMetric(sc, "local blocks fetched"), + REMOTE_BLOCKS_FETCHED -> SQLMetrics.createMetric(sc, "remote blocks read"), + LOCAL_BLOCKS_FETCHED -> SQLMetrics.createMetric(sc, "local blocks read"), REMOTE_BYTES_READ -> SQLMetrics.createSizeMetric(sc, "remote bytes read"), REMOTE_BYTES_READ_TO_DISK -> SQLMetrics.createSizeMetric(sc, "remote bytes read to disk"), LOCAL_BYTES_READ -> SQLMetrics.createSizeMetric(sc, "local bytes read"), @@ -102,7 +102,7 @@ private[spark] object SQLShuffleMetricsReporter { * @param metricsReporter Other reporter need to be updated in this SQLShuffleWriteMetricsReporter. * @param metrics Shuffle write metrics in current SparkPlan. */ -private[spark] class SQLShuffleWriteMetricsReporter( +class SQLShuffleWriteMetricsReporter( metricsReporter: ShuffleWriteMetricsReporter, metrics: Map[String, SQLMetric]) extends ShuffleWriteMetricsReporter { private[this] val _bytesWritten = @@ -112,29 +112,29 @@ private[spark] class SQLShuffleWriteMetricsReporter( private[this] val _writeTime = metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_WRITE_TIME) - override private[spark] def incBytesWritten(v: Long): Unit = { + override def incBytesWritten(v: Long): Unit = { metricsReporter.incBytesWritten(v) _bytesWritten.add(v) } - override private[spark] def decRecordsWritten(v: Long): Unit = { + override def decRecordsWritten(v: Long): Unit = { metricsReporter.decBytesWritten(v) _recordsWritten.set(_recordsWritten.value - v) } - override private[spark] def incRecordsWritten(v: Long): Unit = { + override def incRecordsWritten(v: Long): Unit = { metricsReporter.incRecordsWritten(v) _recordsWritten.add(v) } - override private[spark] def incWriteTime(v: Long): Unit = { + override def incWriteTime(v: Long): Unit = { metricsReporter.incWriteTime(v) _writeTime.add(v) } - override private[spark] def decBytesWritten(v: Long): Unit = { + override def decBytesWritten(v: Long): Unit = { metricsReporter.decBytesWritten(v) _bytesWritten.set(_bytesWritten.value - v) } } -private[spark] object SQLShuffleWriteMetricsReporter { +object SQLShuffleWriteMetricsReporter { val SHUFFLE_BYTES_WRITTEN = "shuffleBytesWritten" val SHUFFLE_RECORDS_WRITTEN = "shuffleRecordsWritten" val SHUFFLE_WRITE_TIME = "shuffleWriteTime" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 1ad5713ab8ae6..ca8692290edb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{LocalSparkSession, Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.execution.metric.SQLShuffleMetricsReporter +import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter import org.apache.spark.sql.types._ import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.collection.ExternalSorter @@ -140,7 +140,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession { new UnsafeRowSerializer(2)) val shuffled = new ShuffledRowRDD( dependency, - SQLShuffleMetricsReporter.createShuffleReadMetrics(spark.sparkContext)) + SQLShuffleReadMetricsReporter.createShuffleReadMetrics(spark.sparkContext)) shuffled.count() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index f6495496a58e1..47265df4831df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -96,8 +96,8 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")) val shuffleExpected1 = Map( "records read" -> 2L, - "local blocks fetched" -> 2L, - "remote blocks fetched" -> 0L, + "local blocks read" -> 2L, + "remote blocks read" -> 0L, "shuffle records written" -> 2L) testSparkPlanMetrics(df, 1, Map( 2L -> (("HashAggregate", expected1(0))), @@ -114,8 +114,8 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")) val shuffleExpected2 = Map( "records read" -> 4L, - "local blocks fetched" -> 4L, - "remote blocks fetched" -> 0L, + "local blocks read" -> 4L, + "remote blocks read" -> 0L, "shuffle records written" -> 4L) testSparkPlanMetrics(df2, 1, Map( 2L -> (("HashAggregate", expected2(0))), @@ -175,8 +175,8 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared 1L -> (("Exchange", Map( "shuffle records written" -> 2L, "records read" -> 2L, - "local blocks fetched" -> 2L, - "remote blocks fetched" -> 0L))), + "local blocks read" -> 2L, + "remote blocks read" -> 0L))), 0L -> (("ObjectHashAggregate", Map("number of output rows" -> 1L)))) ) @@ -187,8 +187,8 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared 1L -> (("Exchange", Map( "shuffle records written" -> 4L, "records read" -> 4L, - "local blocks fetched" -> 4L, - "remote blocks fetched" -> 0L))), + "local blocks read" -> 4L, + "remote blocks read" -> 0L))), 0L -> (("ObjectHashAggregate", Map("number of output rows" -> 3L)))) ) } @@ -216,8 +216,8 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared "number of output rows" -> 4L))), 2L -> (("Exchange", Map( "records read" -> 4L, - "local blocks fetched" -> 2L, - "remote blocks fetched" -> 0L, + "local blocks read" -> 2L, + "remote blocks read" -> 0L, "shuffle records written" -> 2L)))) ) } From 79e36e2c2ac01458b5baa3f3ee310fddd29e9c35 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 12 Dec 2018 09:03:13 -0600 Subject: [PATCH 2264/2461] [SPARK-19827][R][FOLLOWUP] spark.ml R API for PIC ## What changes were proposed in this pull request? Follow up style fixes to PIC in R; see #23072 ## How was this patch tested? Existing tests. Closes #23292 from srowen/SPARK-19827.2. Authored-by: Sean Owen Signed-off-by: Sean Owen --- R/pkg/R/mllib_clustering.R | 15 ++++++--------- R/pkg/R/mllib_fpm.R | 4 ++-- examples/src/main/r/ml/powerIterationClustering.R | 3 ++- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R index 7d9dcebfe70d3..9b32b71d34fef 100644 --- a/R/pkg/R/mllib_clustering.R +++ b/R/pkg/R/mllib_clustering.R @@ -621,11 +621,10 @@ setMethod("write.ml", signature(object = "LDAModel", path = "character"), #' #' A scalable graph clustering algorithm. Users can call \code{spark.assignClusters} to #' return a cluster assignment for each input vertex. -#' -# Run the PIC algorithm and returns a cluster assignment for each input vertex. +#' Run the PIC algorithm and returns a cluster assignment for each input vertex. #' @param data a SparkDataFrame. #' @param k the number of clusters to create. -#' @param initMode the initialization algorithm. +#' @param initMode the initialization algorithm; "random" or "degree" #' @param maxIter the maximum number of iterations. #' @param sourceCol the name of the input column for source vertex IDs. #' @param destinationCol the name of the input column for destination vertex IDs @@ -633,18 +632,16 @@ setMethod("write.ml", signature(object = "LDAModel", path = "character"), #' we treat all instance weights as 1.0. #' @param ... additional argument(s) passed to the method. #' @return A dataset that contains columns of vertex id and the corresponding cluster for the id. -#' The schema of it will be: -#' \code{id: Long} -#' \code{cluster: Int} +#' The schema of it will be: \code{id: integer}, \code{cluster: integer} #' @rdname spark.powerIterationClustering -#' @aliases assignClusters,PowerIterationClustering-method,SparkDataFrame-method +#' @aliases spark.assignClusters,SparkDataFrame-method #' @examples #' \dontrun{ #' df <- createDataFrame(list(list(0L, 1L, 1.0), list(0L, 2L, 1.0), #' list(1L, 2L, 1.0), list(3L, 4L, 1.0), #' list(4L, 0L, 0.1)), #' schema = c("src", "dst", "weight")) -#' clusters <- spark.assignClusters(df, initMode="degree", weightCol="weight") +#' clusters <- spark.assignClusters(df, initMode = "degree", weightCol = "weight") #' showDF(clusters) #' } #' @note spark.assignClusters(SparkDataFrame) since 3.0.0 @@ -652,7 +649,7 @@ setMethod("spark.assignClusters", signature(data = "SparkDataFrame"), function(data, k = 2L, initMode = c("random", "degree"), maxIter = 20L, sourceCol = "src", destinationCol = "dst", weightCol = NULL) { - if (!is.numeric(k) || k < 1) { + if (!is.integer(k) || k < 1) { stop("k should be a number with value >= 1.") } if (!is.integer(maxIter) || maxIter <= 0) { diff --git a/R/pkg/R/mllib_fpm.R b/R/pkg/R/mllib_fpm.R index c248e9ec9be94..0cc7a16c302dc 100644 --- a/R/pkg/R/mllib_fpm.R +++ b/R/pkg/R/mllib_fpm.R @@ -183,8 +183,8 @@ setMethod("write.ml", signature(object = "FPGrowthModel", path = "character"), #' @return A complete set of frequent sequential patterns in the input sequences of itemsets. #' The returned \code{SparkDataFrame} contains columns of sequence and corresponding #' frequency. The schema of it will be: -#' \code{sequence: ArrayType(ArrayType(T))} (T is the item type) -#' \code{freq: Long} +#' \code{sequence: ArrayType(ArrayType(T))}, \code{freq: integer} +#' where T is the item type #' @rdname spark.prefixSpan #' @aliases findFrequentSequentialPatterns,PrefixSpan,SparkDataFrame-method #' @examples diff --git a/examples/src/main/r/ml/powerIterationClustering.R b/examples/src/main/r/ml/powerIterationClustering.R index ba43037106d14..3530d88e50509 100644 --- a/examples/src/main/r/ml/powerIterationClustering.R +++ b/examples/src/main/r/ml/powerIterationClustering.R @@ -30,7 +30,8 @@ df <- createDataFrame(list(list(0L, 1L, 1.0), list(0L, 2L, 1.0), list(4L, 0L, 0.1)), schema = c("src", "dst", "weight")) # assign clusters -clusters <- spark.assignClusters(df, k=2L, maxIter=20L, initMode="degree", weightCol="weight") +clusters <- spark.assignClusters(df, k = 2L, maxIter = 20L, + initMode = "degree", weightCol = "weight") showDF(arrange(clusters, clusters$id)) # $example off$ From 570b8f3d45ad8d6649ed633251a8194d910f1ab5 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Wed, 12 Dec 2018 10:06:41 -0600 Subject: [PATCH 2265/2461] [SPARK-24102][ML][MLLIB] ML Evaluators should use weight column - added weight column for regression evaluator ## What changes were proposed in this pull request? The evaluators BinaryClassificationEvaluator, RegressionEvaluator, and MulticlassClassificationEvaluator and the corresponding metrics classes BinaryClassificationMetrics, RegressionMetrics and MulticlassMetrics should use sample weight data. I've closed the PR: https://github.com/apache/spark/pull/16557 as recommended in favor of creating three pull requests, one for each of the evaluators (binary/regression/multiclass) to make it easier to review/update. The updates to the regression metrics were based on (and updated with new changes based on comments): https://issues.apache.org/jira/browse/SPARK-11520 ("RegressionMetrics should support instance weights") but the pull request was closed as the changes were never checked in. ## How was this patch tested? I added tests to the metrics class. Closes #17085 from imatiach-msft/ilmat/regression-evaluate. Authored-by: Ilya Matiach Signed-off-by: Sean Owen --- .../ml/evaluation/RegressionEvaluator.scala | 19 ++++--- .../mllib/evaluation/RegressionMetrics.scala | 30 ++++++----- .../stat/MultivariateOnlineSummarizer.scala | 25 ++++++---- .../stat/MultivariateStatisticalSummary.scala | 6 +++ .../evaluation/RegressionMetricsSuite.scala | 50 +++++++++++++++++++ project/MimaExcludes.scala | 5 +- 6 files changed, 106 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index 031cd0d635bf4..616569bb55e4c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.sql.{Dataset, Row} @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.{DoubleType, FloatType} @Since("1.4.0") @Experimental final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { + extends Evaluator with HasPredictionCol with HasLabelCol + with HasWeightCol with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("regEval")) @@ -69,6 +70,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui @Since("1.4.0") def setLabelCol(value: String): this.type = set(labelCol, value) + /** @group setParam */ + @Since("3.0.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(metricName -> "rmse") @Since("2.0.0") @@ -77,11 +82,13 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType)) SchemaUtils.checkNumericType(schema, $(labelCol)) - val predictionAndLabels = dataset - .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType)) + val predictionAndLabelsWithWeights = dataset + .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType), + if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))) .rdd - .map { case Row(prediction: Double, label: Double) => (prediction, label) } - val metrics = new RegressionMetrics(predictionAndLabels) + .map { case Row(prediction: Double, label: Double, weight: Double) => + (prediction, label, weight) } + val metrics = new RegressionMetrics(predictionAndLabelsWithWeights) val metric = $(metricName) match { case "rmse" => metrics.rootMeanSquaredError case "mse" => metrics.meanSquaredError diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 020676cac5a64..525047973ad5c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -27,17 +27,18 @@ import org.apache.spark.sql.DataFrame /** * Evaluator for regression. * - * @param predictionAndObservations an RDD of (prediction, observation) pairs + * @param predAndObsWithOptWeight an RDD of either (prediction, observation, weight) + * or (prediction, observation) pairs * @param throughOrigin True if the regression is through the origin. For example, in linear * regression, it will be true without fitting intercept. */ @Since("1.2.0") class RegressionMetrics @Since("2.0.0") ( - predictionAndObservations: RDD[(Double, Double)], throughOrigin: Boolean) + predAndObsWithOptWeight: RDD[_ <: Product], throughOrigin: Boolean) extends Logging { @Since("1.2.0") - def this(predictionAndObservations: RDD[(Double, Double)]) = + def this(predictionAndObservations: RDD[_ <: Product]) = this(predictionAndObservations, false) /** @@ -52,10 +53,13 @@ class RegressionMetrics @Since("2.0.0") ( * Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors. */ private lazy val summary: MultivariateStatisticalSummary = { - val summary: MultivariateStatisticalSummary = predictionAndObservations.map { - case (prediction, observation) => Vectors.dense(observation, observation - prediction) + val summary: MultivariateStatisticalSummary = predAndObsWithOptWeight.map { + case (prediction: Double, observation: Double, weight: Double) => + (Vectors.dense(observation, observation - prediction), weight) + case (prediction: Double, observation: Double) => + (Vectors.dense(observation, observation - prediction), 1.0) }.treeAggregate(new MultivariateOnlineSummarizer())( - (summary, v) => summary.add(v), + (summary, sample) => summary.add(sample._1, sample._2), (sum1, sum2) => sum1.merge(sum2) ) summary @@ -63,11 +67,13 @@ class RegressionMetrics @Since("2.0.0") ( private lazy val SSy = math.pow(summary.normL2(0), 2) private lazy val SSerr = math.pow(summary.normL2(1), 2) - private lazy val SStot = summary.variance(0) * (summary.count - 1) + private lazy val SStot = summary.variance(0) * (summary.weightSum - 1) private lazy val SSreg = { val yMean = summary.mean(0) - predictionAndObservations.map { - case (prediction, _) => math.pow(prediction - yMean, 2) + predAndObsWithOptWeight.map { + case (prediction: Double, _: Double, weight: Double) => + math.pow(prediction - yMean, 2) * weight + case (prediction: Double, _: Double) => math.pow(prediction - yMean, 2) }.sum() } @@ -79,7 +85,7 @@ class RegressionMetrics @Since("2.0.0") ( */ @Since("1.2.0") def explainedVariance: Double = { - SSreg / summary.count + SSreg / summary.weightSum } /** @@ -88,7 +94,7 @@ class RegressionMetrics @Since("2.0.0") ( */ @Since("1.2.0") def meanAbsoluteError: Double = { - summary.normL1(1) / summary.count + summary.normL1(1) / summary.weightSum } /** @@ -97,7 +103,7 @@ class RegressionMetrics @Since("2.0.0") ( */ @Since("1.2.0") def meanSquaredError: Double = { - SSerr / summary.count + SSerr / summary.weightSum } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 0554b6d8ff5b5..6d510e1633d67 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -52,7 +52,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S private var totalCnt: Long = 0 private var totalWeightSum: Double = 0.0 private var weightSquareSum: Double = 0.0 - private var weightSum: Array[Double] = _ + private var currWeightSum: Array[Double] = _ private var nnz: Array[Long] = _ private var currMax: Array[Double] = _ private var currMin: Array[Double] = _ @@ -78,7 +78,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currM2n = Array.ofDim[Double](n) currM2 = Array.ofDim[Double](n) currL1 = Array.ofDim[Double](n) - weightSum = Array.ofDim[Double](n) + currWeightSum = Array.ofDim[Double](n) nnz = Array.ofDim[Long](n) currMax = Array.fill[Double](n)(Double.MinValue) currMin = Array.fill[Double](n)(Double.MaxValue) @@ -91,7 +91,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val localCurrM2n = currM2n val localCurrM2 = currM2 val localCurrL1 = currL1 - val localWeightSum = weightSum + val localWeightSum = currWeightSum val localNumNonzeros = nnz val localCurrMax = currMax val localCurrMin = currMin @@ -139,8 +139,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S weightSquareSum += other.weightSquareSum var i = 0 while (i < n) { - val thisNnz = weightSum(i) - val otherNnz = other.weightSum(i) + val thisNnz = currWeightSum(i) + val otherNnz = other.currWeightSum(i) val totalNnz = thisNnz + otherNnz val totalCnnz = nnz(i) + other.nnz(i) if (totalNnz != 0.0) { @@ -157,7 +157,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currMax(i) = math.max(currMax(i), other.currMax(i)) currMin(i) = math.min(currMin(i), other.currMin(i)) } - weightSum(i) = totalNnz + currWeightSum(i) = totalNnz nnz(i) = totalCnnz i += 1 } @@ -170,7 +170,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S this.totalCnt = other.totalCnt this.totalWeightSum = other.totalWeightSum this.weightSquareSum = other.weightSquareSum - this.weightSum = other.weightSum.clone() + this.currWeightSum = other.currWeightSum.clone() this.nnz = other.nnz.clone() this.currMax = other.currMax.clone() this.currMin = other.currMin.clone() @@ -189,7 +189,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val realMean = Array.ofDim[Double](n) var i = 0 while (i < n) { - realMean(i) = currMean(i) * (weightSum(i) / totalWeightSum) + realMean(i) = currMean(i) * (currWeightSum(i) / totalWeightSum) i += 1 } Vectors.dense(realMean) @@ -214,8 +214,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val len = currM2n.length while (i < len) { // We prevent variance from negative value caused by numerical error. - realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) * - (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0) + realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * currWeightSum(i) * + (totalWeightSum - currWeightSum(i)) / totalWeightSum) / denominator, 0.0) i += 1 } } @@ -229,6 +229,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S @Since("1.1.0") override def count: Long = totalCnt + /** + * Sum of weights. + */ + override def weightSum: Double = totalWeightSum + /** * Number of nonzero elements in each dimension. * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala index 39a16fb743d64..a4381032f8c0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala @@ -44,6 +44,12 @@ trait MultivariateStatisticalSummary { @Since("1.0.0") def count: Long + /** + * Sum of weights. + */ + @Since("3.0.0") + def weightSum: Double + /** * Number of nonzero elements (including explicitly presented zero values) in each column. */ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala index f1d517383643d..23809777f7d3a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -133,4 +133,54 @@ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { "root mean squared error mismatch") assert(metrics.r2 ~== 1.0 absTol eps, "r2 score mismatch") } + + test("regression metrics with same (1.0) weight samples") { + val predictionAndObservationWithWeight = sc.parallelize( + Seq((2.25, 3.0, 1.0), (-0.25, -0.5, 1.0), (1.75, 2.0, 1.0), (7.75, 7.0, 1.0)), 2) + val metrics = new RegressionMetrics(predictionAndObservationWithWeight, false) + assert(metrics.explainedVariance ~== 8.79687 absTol eps, + "explained variance regression score mismatch") + assert(metrics.meanAbsoluteError ~== 0.5 absTol eps, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.3125 absTol eps, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.55901 absTol eps, + "root mean squared error mismatch") + assert(metrics.r2 ~== 0.95717 absTol eps, "r2 score mismatch") + } + + /** + * The following values are hand calculated using the formula: + * [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]] + * preds = c(2.25, -0.25, 1.75, 7.75) + * obs = c(3.0, -0.5, 2.0, 7.0) + * weights = c(0.1, 0.2, 0.15, 0.05) + * count = 4 + * + * Weighted metrics can be calculated with MultivariateStatisticalSummary. + * (observations, observations - predictions) + * mean (1.7, 0.05) + * variance (7.3, 0.3) + * numNonZeros (0.5, 0.5) + * max (7.0, 0.75) + * min (-0.5, -0.75) + * normL2 (2.0, 0.32596) + * normL1 (1.05, 0.2) + * + * explainedVariance: sum(pow((preds - 1.7),2)*weight) / weightedCount = 5.2425 + * meanAbsoluteError: normL1(1) / weightedCount = 0.4 + * meanSquaredError: pow(normL2(1),2) / weightedCount = 0.2125 + * rootMeanSquaredError: sqrt(meanSquaredError) = 0.46098 + * r2: 1 - pow(normL2(1),2) / (variance(0) * (weightedCount - 1)) = 1.02910 + */ + test("regression metrics with weighted samples") { + val predictionAndObservationWithWeight = sc.parallelize( + Seq((2.25, 3.0, 0.1), (-0.25, -0.5, 0.2), (1.75, 2.0, 0.15), (7.75, 7.0, 0.05)), 2) + val metrics = new RegressionMetrics(predictionAndObservationWithWeight, false) + assert(metrics.explainedVariance ~== 5.2425 absTol eps, + "explained variance regression score mismatch") + assert(metrics.meanAbsoluteError ~== 0.4 absTol eps, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.2125 absTol eps, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.46098 absTol eps, + "root mean squared error mismatch") + assert(metrics.r2 ~== 1.02910 absTol eps, "r2 score mismatch") + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b3252d70a80c8..883913332ca1e 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -531,7 +531,10 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseColMajor"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseMatrix"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseMatrix"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes") + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes"), + + // [SPARK-18693] Added weightSum to trait MultivariateStatisticalSummary + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.stat.MultivariateStatisticalSummary.weightSum") ) ++ Seq( // [SPARK-17019] Expose on-heap and off-heap memory usage in various places ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.copy"), From a63e7b2a212bab94d080b00cf1c5f397800a276a Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 12 Dec 2018 12:01:21 -0800 Subject: [PATCH 2266/2461] [SPARK-25877][K8S] Move all feature logic to feature classes. This change makes the driver and executor builders a lot simpler by encapsulating almost all feature logic into the respective feature classes. The only logic that remains is the creation of the initial pod, which needs to happen before anything else so is better to be left in the builder class. Most feature classes already behave fine when the config has nothing they should handle, but a few minor tweaks had to be added. Unit tests were also updated or added to account for these. The builder suites were simplified a lot and just test the remaining pod-related code in the builders themselves. Author: Marcelo Vanzin Closes #23220 from vanzin/SPARK-25877. --- .../HadoopConfExecutorFeatureStep.scala | 10 +- .../HadoopSparkUserExecutorFeatureStep.scala | 5 +- .../KerberosConfExecutorFeatureStep.scala | 26 +-- .../features/PodTemplateConfigMapStep.scala | 82 +++++--- .../submit/KubernetesClientApplication.scala | 4 +- .../k8s/submit/KubernetesDriverBuilder.scala | 99 +++------ .../cluster/k8s/ExecutorPodsAllocator.scala | 3 +- .../k8s/KubernetesClusterManager.scala | 2 +- .../k8s/KubernetesExecutorBuilder.scala | 100 +++------ .../spark/deploy/k8s/PodBuilderSuite.scala | 177 ++++++++++++++++ .../PodTemplateConfigMapStepSuite.scala | 25 ++- .../spark/deploy/k8s/submit/ClientSuite.scala | 2 +- .../submit/KubernetesDriverBuilderSuite.scala | 194 +----------------- .../k8s/submit/PodBuilderSuiteUtils.scala | 142 ------------- .../k8s/ExecutorPodsAllocatorSuite.scala | 4 +- .../k8s/KubernetesExecutorBuilderSuite.scala | 144 +------------ 16 files changed, 343 insertions(+), 676 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/PodBuilderSuiteUtils.scala diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala index bca66759d586e..da332881ae1a2 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala @@ -31,10 +31,10 @@ private[spark] class HadoopConfExecutorFeatureStep(conf: KubernetesExecutorConf) override def configurePod(pod: SparkPod): SparkPod = { val hadoopConfDirCMapName = conf.getOption(HADOOP_CONFIG_MAP_NAME) - require(hadoopConfDirCMapName.isDefined, - "Ensure that the env `HADOOP_CONF_DIR` is defined either in the client or " + - " using pre-existing ConfigMaps") - logInfo("HADOOP_CONF_DIR defined") - HadoopBootstrapUtil.bootstrapHadoopConfDir(None, None, hadoopConfDirCMapName, pod) + if (hadoopConfDirCMapName.isDefined) { + HadoopBootstrapUtil.bootstrapHadoopConfDir(None, None, hadoopConfDirCMapName, pod) + } else { + pod + } } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala index e342110763196..c038e75491ca5 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala @@ -28,7 +28,8 @@ private[spark] class HadoopSparkUserExecutorFeatureStep(conf: KubernetesExecutor extends KubernetesFeatureConfigStep { override def configurePod(pod: SparkPod): SparkPod = { - val sparkUserName = conf.get(KERBEROS_SPARK_USER_NAME) - HadoopBootstrapUtil.bootstrapSparkUserPod(sparkUserName, pod) + conf.getOption(KERBEROS_SPARK_USER_NAME).map { user => + HadoopBootstrapUtil.bootstrapSparkUserPod(user, pod) + }.getOrElse(pod) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala index 32bb6a5d2bcbb..907271b1cb483 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala @@ -27,18 +27,20 @@ import org.apache.spark.internal.Logging private[spark] class KerberosConfExecutorFeatureStep(conf: KubernetesExecutorConf) extends KubernetesFeatureConfigStep with Logging { - private val maybeKrb5CMap = conf.getOption(KRB5_CONFIG_MAP_NAME) - require(maybeKrb5CMap.isDefined, "HADOOP_CONF_DIR ConfigMap not found") - override def configurePod(pod: SparkPod): SparkPod = { - logInfo(s"Mounting Resources for Kerberos") - HadoopBootstrapUtil.bootstrapKerberosPod( - conf.get(KERBEROS_DT_SECRET_NAME), - conf.get(KERBEROS_DT_SECRET_KEY), - conf.get(KERBEROS_SPARK_USER_NAME), - None, - None, - maybeKrb5CMap, - pod) + val maybeKrb5CMap = conf.getOption(KRB5_CONFIG_MAP_NAME) + if (maybeKrb5CMap.isDefined) { + logInfo(s"Mounting Resources for Kerberos") + HadoopBootstrapUtil.bootstrapKerberosPod( + conf.get(KERBEROS_DT_SECRET_NAME), + conf.get(KERBEROS_DT_SECRET_KEY), + conf.get(KERBEROS_SPARK_USER_NAME), + None, + None, + maybeKrb5CMap, + pod) + } else { + pod + } } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala index 09dcf93a54f8e..7f41ca43589b6 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala @@ -28,44 +28,60 @@ import org.apache.spark.deploy.k8s.Constants._ private[spark] class PodTemplateConfigMapStep(conf: KubernetesConf) extends KubernetesFeatureConfigStep { + + private val hasTemplate = conf.contains(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE) + def configurePod(pod: SparkPod): SparkPod = { - val podWithVolume = new PodBuilder(pod.pod) - .editSpec() - .addNewVolume() - .withName(POD_TEMPLATE_VOLUME) - .withNewConfigMap() - .withName(POD_TEMPLATE_CONFIGMAP) - .addNewItem() - .withKey(POD_TEMPLATE_KEY) - .withPath(EXECUTOR_POD_SPEC_TEMPLATE_FILE_NAME) - .endItem() - .endConfigMap() - .endVolume() - .endSpec() - .build() + if (hasTemplate) { + val podWithVolume = new PodBuilder(pod.pod) + .editSpec() + .addNewVolume() + .withName(POD_TEMPLATE_VOLUME) + .withNewConfigMap() + .withName(POD_TEMPLATE_CONFIGMAP) + .addNewItem() + .withKey(POD_TEMPLATE_KEY) + .withPath(EXECUTOR_POD_SPEC_TEMPLATE_FILE_NAME) + .endItem() + .endConfigMap() + .endVolume() + .endSpec() + .build() - val containerWithVolume = new ContainerBuilder(pod.container) - .addNewVolumeMount() - .withName(POD_TEMPLATE_VOLUME) - .withMountPath(EXECUTOR_POD_SPEC_TEMPLATE_MOUNTPATH) - .endVolumeMount() - .build() - SparkPod(podWithVolume, containerWithVolume) + val containerWithVolume = new ContainerBuilder(pod.container) + .addNewVolumeMount() + .withName(POD_TEMPLATE_VOLUME) + .withMountPath(EXECUTOR_POD_SPEC_TEMPLATE_MOUNTPATH) + .endVolumeMount() + .build() + SparkPod(podWithVolume, containerWithVolume) + } else { + pod + } } - override def getAdditionalPodSystemProperties(): Map[String, String] = Map[String, String]( - KUBERNETES_EXECUTOR_PODTEMPLATE_FILE.key -> - (EXECUTOR_POD_SPEC_TEMPLATE_MOUNTPATH + "/" + EXECUTOR_POD_SPEC_TEMPLATE_FILE_NAME)) + override def getAdditionalPodSystemProperties(): Map[String, String] = { + if (hasTemplate) { + Map[String, String]( + KUBERNETES_EXECUTOR_PODTEMPLATE_FILE.key -> + (EXECUTOR_POD_SPEC_TEMPLATE_MOUNTPATH + "/" + EXECUTOR_POD_SPEC_TEMPLATE_FILE_NAME)) + } else { + Map.empty + } + } override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { - require(conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).isDefined) - val podTemplateFile = conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).get - val podTemplateString = Files.toString(new File(podTemplateFile), StandardCharsets.UTF_8) - Seq(new ConfigMapBuilder() - .withNewMetadata() - .withName(POD_TEMPLATE_CONFIGMAP) - .endMetadata() - .addToData(POD_TEMPLATE_KEY, podTemplateString) - .build()) + if (hasTemplate) { + val podTemplateFile = conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).get + val podTemplateString = Files.toString(new File(podTemplateFile), StandardCharsets.UTF_8) + Seq(new ConfigMapBuilder() + .withNewMetadata() + .withName(POD_TEMPLATE_CONFIGMAP) + .endMetadata() + .addToData(POD_TEMPLATE_KEY, podTemplateString) + .build()) + } else { + Nil + } } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index 70a93c968795e..3888778bf84ca 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -104,7 +104,7 @@ private[spark] class Client( watcher: LoggingPodStatusWatcher) extends Logging { def run(): Unit = { - val resolvedDriverSpec = builder.buildFromFeatures(conf) + val resolvedDriverSpec = builder.buildFromFeatures(conf, kubernetesClient) val configMapName = s"${conf.resourceNamePrefix}-driver-conf-map" val configMap = buildConfigMap(configMapName, resolvedDriverSpec.systemProperties) // The include of the ENV_VAR for "SPARK_CONF_DIR" is to allow for the @@ -232,7 +232,7 @@ private[spark] class KubernetesClientApplication extends SparkApplication { None)) { kubernetesClient => val client = new Client( kubernetesConf, - KubernetesDriverBuilder(kubernetesClient, kubernetesConf.sparkConf), + new KubernetesDriverBuilder(), kubernetesClient, waitForAppCompletion, watcher) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index a5ad9729aee9a..d2c0ced9fa2f4 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -20,90 +20,49 @@ import java.io.File import io.fabric8.kubernetes.client.KubernetesClient -import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.features._ -private[spark] class KubernetesDriverBuilder( - provideBasicStep: (KubernetesDriverConf => BasicDriverFeatureStep) = - new BasicDriverFeatureStep(_), - provideCredentialsStep: (KubernetesDriverConf => DriverKubernetesCredentialsFeatureStep) = - new DriverKubernetesCredentialsFeatureStep(_), - provideServiceStep: (KubernetesDriverConf => DriverServiceFeatureStep) = - new DriverServiceFeatureStep(_), - provideSecretsStep: (KubernetesConf => MountSecretsFeatureStep) = - new MountSecretsFeatureStep(_), - provideEnvSecretsStep: (KubernetesConf => EnvSecretsFeatureStep) = - new EnvSecretsFeatureStep(_), - provideLocalDirsStep: (KubernetesConf => LocalDirsFeatureStep) = - new LocalDirsFeatureStep(_), - provideVolumesStep: (KubernetesConf => MountVolumesFeatureStep) = - new MountVolumesFeatureStep(_), - provideDriverCommandStep: (KubernetesDriverConf => DriverCommandFeatureStep) = - new DriverCommandFeatureStep(_), - provideHadoopGlobalStep: (KubernetesDriverConf => KerberosConfDriverFeatureStep) = - new KerberosConfDriverFeatureStep(_), - providePodTemplateConfigMapStep: (KubernetesConf => PodTemplateConfigMapStep) = - new PodTemplateConfigMapStep(_), - provideInitialPod: () => SparkPod = () => SparkPod.initialPod) { +private[spark] class KubernetesDriverBuilder { - def buildFromFeatures(kubernetesConf: KubernetesDriverConf): KubernetesDriverSpec = { - val baseFeatures = Seq( - provideBasicStep(kubernetesConf), - provideCredentialsStep(kubernetesConf), - provideServiceStep(kubernetesConf), - provideLocalDirsStep(kubernetesConf)) - - val secretFeature = if (kubernetesConf.secretNamesToMountPaths.nonEmpty) { - Seq(provideSecretsStep(kubernetesConf)) - } else Nil - val envSecretFeature = if (kubernetesConf.secretEnvNamesToKeyRefs.nonEmpty) { - Seq(provideEnvSecretsStep(kubernetesConf)) - } else Nil - val volumesFeature = if (kubernetesConf.volumes.nonEmpty) { - Seq(provideVolumesStep(kubernetesConf)) - } else Nil - val podTemplateFeature = if ( - kubernetesConf.get(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).isDefined) { - Seq(providePodTemplateConfigMapStep(kubernetesConf)) - } else Nil - - val driverCommandStep = provideDriverCommandStep(kubernetesConf) - - val hadoopConfigStep = Some(provideHadoopGlobalStep(kubernetesConf)) + def buildFromFeatures( + conf: KubernetesDriverConf, + client: KubernetesClient): KubernetesDriverSpec = { + val initialPod = conf.get(Config.KUBERNETES_DRIVER_PODTEMPLATE_FILE) + .map { file => + KubernetesUtils.loadPodFromTemplate( + client, + new File(file), + conf.get(Config.KUBERNETES_DRIVER_PODTEMPLATE_CONTAINER_NAME)) + } + .getOrElse(SparkPod.initialPod()) - val allFeatures: Seq[KubernetesFeatureConfigStep] = - baseFeatures ++ Seq(driverCommandStep) ++ - secretFeature ++ envSecretFeature ++ volumesFeature ++ - hadoopConfigStep ++ podTemplateFeature + val features = Seq( + new BasicDriverFeatureStep(conf), + new DriverKubernetesCredentialsFeatureStep(conf), + new DriverServiceFeatureStep(conf), + new MountSecretsFeatureStep(conf), + new EnvSecretsFeatureStep(conf), + new LocalDirsFeatureStep(conf), + new MountVolumesFeatureStep(conf), + new DriverCommandFeatureStep(conf), + new KerberosConfDriverFeatureStep(conf), + new PodTemplateConfigMapStep(conf)) - var spec = KubernetesDriverSpec( - provideInitialPod(), + val spec = KubernetesDriverSpec( + initialPod, driverKubernetesResources = Seq.empty, - kubernetesConf.sparkConf.getAll.toMap) - for (feature <- allFeatures) { + conf.sparkConf.getAll.toMap) + + features.foldLeft(spec) { case (spec, feature) => val configuredPod = feature.configurePod(spec.pod) val addedSystemProperties = feature.getAdditionalPodSystemProperties() val addedResources = feature.getAdditionalKubernetesResources() - spec = KubernetesDriverSpec( + KubernetesDriverSpec( configuredPod, spec.driverKubernetesResources ++ addedResources, spec.systemProperties ++ addedSystemProperties) } - spec } -} -private[spark] object KubernetesDriverBuilder { - def apply(kubernetesClient: KubernetesClient, conf: SparkConf): KubernetesDriverBuilder = { - conf.get(Config.KUBERNETES_DRIVER_PODTEMPLATE_FILE) - .map(new File(_)) - .map(file => new KubernetesDriverBuilder(provideInitialPod = () => - KubernetesUtils.loadPodFromTemplate( - kubernetesClient, - file, - conf.get(Config.KUBERNETES_DRIVER_PODTEMPLATE_CONTAINER_NAME)) - )) - .getOrElse(new KubernetesDriverBuilder()) - } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala index ac42554b1334b..da3edfeca9b1f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala @@ -136,7 +136,8 @@ private[spark] class ExecutorPodsAllocator( newExecutorId.toString, applicationId, driverPod) - val executorPod = executorBuilder.buildFromFeatures(executorConf, secMgr) + val executorPod = executorBuilder.buildFromFeatures(executorConf, secMgr, + kubernetesClient) val podWithAttachedContainer = new PodBuilder(executorPod.pod) .editOrNewSpec() .addToContainers(executorPod.container) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index b31fbb420ed6d..809bdf8ca8c27 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -95,7 +95,7 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit val executorPodsAllocator = new ExecutorPodsAllocator( sc.conf, sc.env.securityManager, - KubernetesExecutorBuilder(kubernetesClient, sc.conf), + new KubernetesExecutorBuilder(), kubernetesClient, snapshotsStore, new SystemClock()) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index ba273cad6a8e5..0b74966fe8685 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -20,86 +20,36 @@ import java.io.File import io.fabric8.kubernetes.client.KubernetesClient -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.SecurityManager import org.apache.spark.deploy.k8s._ -import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.features._ -private[spark] class KubernetesExecutorBuilder( - provideBasicStep: (KubernetesExecutorConf, SecurityManager) => BasicExecutorFeatureStep = - new BasicExecutorFeatureStep(_, _), - provideSecretsStep: (KubernetesConf => MountSecretsFeatureStep) = - new MountSecretsFeatureStep(_), - provideEnvSecretsStep: (KubernetesConf => EnvSecretsFeatureStep) = - new EnvSecretsFeatureStep(_), - provideLocalDirsStep: (KubernetesConf => LocalDirsFeatureStep) = - new LocalDirsFeatureStep(_), - provideVolumesStep: (KubernetesConf => MountVolumesFeatureStep) = - new MountVolumesFeatureStep(_), - provideHadoopConfStep: (KubernetesExecutorConf => HadoopConfExecutorFeatureStep) = - new HadoopConfExecutorFeatureStep(_), - provideKerberosConfStep: (KubernetesExecutorConf => KerberosConfExecutorFeatureStep) = - new KerberosConfExecutorFeatureStep(_), - provideHadoopSparkUserStep: (KubernetesExecutorConf => HadoopSparkUserExecutorFeatureStep) = - new HadoopSparkUserExecutorFeatureStep(_), - provideInitialPod: () => SparkPod = () => SparkPod.initialPod()) { +private[spark] class KubernetesExecutorBuilder { def buildFromFeatures( - kubernetesConf: KubernetesExecutorConf, - secMgr: SecurityManager): SparkPod = { - val sparkConf = kubernetesConf.sparkConf - val maybeHadoopConfigMap = sparkConf.getOption(HADOOP_CONFIG_MAP_NAME) - val maybeDTSecretName = sparkConf.getOption(KERBEROS_DT_SECRET_NAME) - val maybeDTDataItem = sparkConf.getOption(KERBEROS_DT_SECRET_KEY) - - val baseFeatures = Seq(provideBasicStep(kubernetesConf, secMgr), - provideLocalDirsStep(kubernetesConf)) - val secretFeature = if (kubernetesConf.secretNamesToMountPaths.nonEmpty) { - Seq(provideSecretsStep(kubernetesConf)) - } else Nil - val secretEnvFeature = if (kubernetesConf.secretEnvNamesToKeyRefs.nonEmpty) { - Seq(provideEnvSecretsStep(kubernetesConf)) - } else Nil - val volumesFeature = if (kubernetesConf.volumes.nonEmpty) { - Seq(provideVolumesStep(kubernetesConf)) - } else Nil - - val maybeHadoopConfFeatureSteps = maybeHadoopConfigMap.map { _ => - val maybeKerberosStep = - if (maybeDTSecretName.isDefined && maybeDTDataItem.isDefined) { - provideKerberosConfStep(kubernetesConf) - } else { - provideHadoopSparkUserStep(kubernetesConf) - } - Seq(provideHadoopConfStep(kubernetesConf)) :+ - maybeKerberosStep - }.getOrElse(Seq.empty[KubernetesFeatureConfigStep]) - - val allFeatures: Seq[KubernetesFeatureConfigStep] = - baseFeatures ++ - secretFeature ++ - secretEnvFeature ++ - volumesFeature ++ - maybeHadoopConfFeatureSteps - - var executorPod = provideInitialPod() - for (feature <- allFeatures) { - executorPod = feature.configurePod(executorPod) - } - executorPod + conf: KubernetesExecutorConf, + secMgr: SecurityManager, + client: KubernetesClient): SparkPod = { + val initialPod = conf.get(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE) + .map { file => + KubernetesUtils.loadPodFromTemplate( + client, + new File(file), + conf.get(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_CONTAINER_NAME)) + } + .getOrElse(SparkPod.initialPod()) + + val features = Seq( + new BasicExecutorFeatureStep(conf, secMgr), + new MountSecretsFeatureStep(conf), + new EnvSecretsFeatureStep(conf), + new LocalDirsFeatureStep(conf), + new MountVolumesFeatureStep(conf), + new HadoopConfExecutorFeatureStep(conf), + new KerberosConfExecutorFeatureStep(conf), + new HadoopSparkUserExecutorFeatureStep(conf)) + + features.foldLeft(initialPod) { case (pod, feature) => feature.configurePod(pod) } } -} -private[spark] object KubernetesExecutorBuilder { - def apply(kubernetesClient: KubernetesClient, conf: SparkConf): KubernetesExecutorBuilder = { - conf.get(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE) - .map(new File(_)) - .map(file => new KubernetesExecutorBuilder(provideInitialPod = () => - KubernetesUtils.loadPodFromTemplate( - kubernetesClient, - file, - conf.get(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_CONTAINER_NAME)) - )) - .getOrElse(new KubernetesExecutorBuilder()) - } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala new file mode 100644 index 0000000000000..7dde0c1377168 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import java.io.File + +import io.fabric8.kubernetes.api.model.{Config => _, _} +import io.fabric8.kubernetes.client.KubernetesClient +import io.fabric8.kubernetes.client.dsl.{MixedOperation, PodResource} +import org.mockito.Matchers.any +import org.mockito.Mockito.{mock, never, verify, when} +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.deploy.k8s._ +import org.apache.spark.internal.config.ConfigEntry + +abstract class PodBuilderSuite extends SparkFunSuite { + + protected def templateFileConf: ConfigEntry[_] + + protected def buildPod(sparkConf: SparkConf, client: KubernetesClient): SparkPod + + private val baseConf = new SparkConf(false) + .set(Config.CONTAINER_IMAGE, "spark-executor:latest") + + test("use empty initial pod if template is not specified") { + val client = mock(classOf[KubernetesClient]) + buildPod(baseConf.clone(), client) + verify(client, never()).pods() + } + + test("load pod template if specified") { + val client = mockKubernetesClient() + val sparkConf = baseConf.clone().set(templateFileConf.key, "template-file.yaml") + val pod = buildPod(sparkConf, client) + verifyPod(pod) + } + + test("complain about misconfigured pod template") { + val client = mockKubernetesClient( + new PodBuilder() + .withNewMetadata() + .addToLabels("test-label-key", "test-label-value") + .endMetadata() + .build()) + val sparkConf = baseConf.clone().set(templateFileConf.key, "template-file.yaml") + val exception = intercept[SparkException] { + buildPod(sparkConf, client) + } + assert(exception.getMessage.contains("Could not load pod from template file.")) + } + + private def mockKubernetesClient(pod: Pod = podWithSupportedFeatures()): KubernetesClient = { + val kubernetesClient = mock(classOf[KubernetesClient]) + val pods = + mock(classOf[MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]]]) + val podResource = mock(classOf[PodResource[Pod, DoneablePod]]) + when(kubernetesClient.pods()).thenReturn(pods) + when(pods.load(any(classOf[File]))).thenReturn(podResource) + when(podResource.get()).thenReturn(pod) + kubernetesClient + } + + private def verifyPod(pod: SparkPod): Unit = { + val metadata = pod.pod.getMetadata + assert(metadata.getLabels.containsKey("test-label-key")) + assert(metadata.getAnnotations.containsKey("test-annotation-key")) + assert(metadata.getNamespace === "namespace") + assert(metadata.getOwnerReferences.asScala.exists(_.getName == "owner-reference")) + val spec = pod.pod.getSpec + assert(!spec.getContainers.asScala.exists(_.getName == "executor-container")) + assert(spec.getDnsPolicy === "dns-policy") + assert(spec.getHostAliases.asScala.exists(_.getHostnames.asScala.exists(_ == "hostname"))) + assert(spec.getImagePullSecrets.asScala.exists(_.getName == "local-reference")) + assert(spec.getInitContainers.asScala.exists(_.getName == "init-container")) + assert(spec.getNodeName == "node-name") + assert(spec.getNodeSelector.get("node-selector-key") === "node-selector-value") + assert(spec.getSchedulerName === "scheduler") + assert(spec.getSecurityContext.getRunAsUser === 1000L) + assert(spec.getServiceAccount === "service-account") + assert(spec.getSubdomain === "subdomain") + assert(spec.getTolerations.asScala.exists(_.getKey == "toleration-key")) + assert(spec.getVolumes.asScala.exists(_.getName == "test-volume")) + val container = pod.container + assert(container.getName === "executor-container") + assert(container.getArgs.contains("arg")) + assert(container.getCommand.equals(List("command").asJava)) + assert(container.getEnv.asScala.exists(_.getName == "env-key")) + assert(container.getResources.getLimits.get("gpu") === + new QuantityBuilder().withAmount("1").build()) + assert(container.getSecurityContext.getRunAsNonRoot) + assert(container.getStdin) + assert(container.getTerminationMessagePath === "termination-message-path") + assert(container.getTerminationMessagePolicy === "termination-message-policy") + assert(pod.container.getVolumeMounts.asScala.exists(_.getName == "test-volume")) + } + + private def podWithSupportedFeatures(): Pod = { + new PodBuilder() + .withNewMetadata() + .addToLabels("test-label-key", "test-label-value") + .addToAnnotations("test-annotation-key", "test-annotation-value") + .withNamespace("namespace") + .addNewOwnerReference() + .withController(true) + .withName("owner-reference") + .endOwnerReference() + .endMetadata() + .withNewSpec() + .withDnsPolicy("dns-policy") + .withHostAliases(new HostAliasBuilder().withHostnames("hostname").build()) + .withImagePullSecrets( + new LocalObjectReferenceBuilder().withName("local-reference").build()) + .withInitContainers(new ContainerBuilder().withName("init-container").build()) + .withNodeName("node-name") + .withNodeSelector(Map("node-selector-key" -> "node-selector-value").asJava) + .withSchedulerName("scheduler") + .withNewSecurityContext() + .withRunAsUser(1000L) + .endSecurityContext() + .withServiceAccount("service-account") + .withSubdomain("subdomain") + .withTolerations(new TolerationBuilder() + .withKey("toleration-key") + .withOperator("Equal") + .withEffect("NoSchedule") + .build()) + .addNewVolume() + .withNewHostPath() + .withPath("/test") + .endHostPath() + .withName("test-volume") + .endVolume() + .addNewContainer() + .withArgs("arg") + .withCommand("command") + .addNewEnv() + .withName("env-key") + .withValue("env-value") + .endEnv() + .withImagePullPolicy("Always") + .withName("executor-container") + .withNewResources() + .withLimits(Map("gpu" -> new QuantityBuilder().withAmount("1").build()).asJava) + .endResources() + .withNewSecurityContext() + .withRunAsNonRoot(true) + .endSecurityContext() + .withStdin(true) + .withTerminationMessagePath("termination-message-path") + .withTerminationMessagePolicy("termination-message-policy") + .addToVolumeMounts( + new VolumeMountBuilder() + .withName("test-volume") + .withMountPath("/test") + .build()) + .endContainer() + .endSpec() + .build() + } + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala index 7295b82ca4799..5e7388dc8e672 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala @@ -20,25 +20,32 @@ import java.io.{File, PrintWriter} import java.nio.file.Files import io.fabric8.kubernetes.api.model.ConfigMap -import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s._ -class PodTemplateConfigMapStepSuite extends SparkFunSuite with BeforeAndAfter { - private var kubernetesConf : KubernetesConf = _ - private var templateFile: File = _ +class PodTemplateConfigMapStepSuite extends SparkFunSuite { - before { - templateFile = Files.createTempFile("pod-template", "yml").toFile + test("Do nothing when executor template is not specified") { + val conf = KubernetesTestConf.createDriverConf() + val step = new PodTemplateConfigMapStep(conf) + + val initialPod = SparkPod.initialPod() + val configuredPod = step.configurePod(initialPod) + assert(configuredPod === initialPod) + + assert(step.getAdditionalKubernetesResources().isEmpty) + assert(step.getAdditionalPodSystemProperties().isEmpty) + } + + test("Mounts executor template volume if config specified") { + val templateFile = Files.createTempFile("pod-template", "yml").toFile templateFile.deleteOnExit() val sparkConf = new SparkConf(false) .set(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE, templateFile.getAbsolutePath) - kubernetesConf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf) - } + val kubernetesConf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf) - test("Mounts executor template volume if config specified") { val writer = new PrintWriter(templateFile) writer.write("pod-template-contents") writer.close() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index e9c05fef6f5db..1bb926cbca23d 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -126,7 +126,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { MockitoAnnotations.initMocks(this) kconf = KubernetesTestConf.createDriverConf( resourceNamePrefix = Some(KUBERNETES_RESOURCE_PREFIX)) - when(driverBuilder.buildFromFeatures(kconf)).thenReturn(BUILT_KUBERNETES_SPEC) + when(driverBuilder.buildFromFeatures(kconf, kubernetesClient)).thenReturn(BUILT_KUBERNETES_SPEC) when(kubernetesClient.pods()).thenReturn(podOperations) when(podOperations.withName(POD_NAME)).thenReturn(namedPods) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index 7e7dc4763c2e7..6518c91a1a1fd 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -16,201 +16,21 @@ */ package org.apache.spark.deploy.k8s.submit -import io.fabric8.kubernetes.api.model.PodBuilder import io.fabric8.kubernetes.client.KubernetesClient -import org.mockito.Mockito._ -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s._ -import org.apache.spark.deploy.k8s.Config.{CONTAINER_IMAGE, KUBERNETES_DRIVER_PODTEMPLATE_FILE, KUBERNETES_EXECUTOR_PODTEMPLATE_FILE} -import org.apache.spark.deploy.k8s.features._ +import org.apache.spark.internal.config.ConfigEntry -class KubernetesDriverBuilderSuite extends SparkFunSuite { +class KubernetesDriverBuilderSuite extends PodBuilderSuite { - private val BASIC_STEP_TYPE = "basic" - private val CREDENTIALS_STEP_TYPE = "credentials" - private val SERVICE_STEP_TYPE = "service" - private val LOCAL_DIRS_STEP_TYPE = "local-dirs" - private val SECRETS_STEP_TYPE = "mount-secrets" - private val DRIVER_CMD_STEP_TYPE = "driver-command" - private val ENV_SECRETS_STEP_TYPE = "env-secrets" - private val HADOOP_GLOBAL_STEP_TYPE = "hadoop-global" - private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes" - private val TEMPLATE_VOLUME_STEP_TYPE = "template-volume" - - private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - BASIC_STEP_TYPE, classOf[BasicDriverFeatureStep]) - - private val credentialsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - CREDENTIALS_STEP_TYPE, classOf[DriverKubernetesCredentialsFeatureStep]) - - private val serviceStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - SERVICE_STEP_TYPE, classOf[DriverServiceFeatureStep]) - - private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep]) - - private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) - - private val driverCommandStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - DRIVER_CMD_STEP_TYPE, classOf[DriverCommandFeatureStep]) - - private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) - - private val hadoopGlobalStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - HADOOP_GLOBAL_STEP_TYPE, classOf[KerberosConfDriverFeatureStep]) - - private val mountVolumesStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - MOUNT_VOLUMES_STEP_TYPE, classOf[MountVolumesFeatureStep]) - - private val templateVolumeStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - TEMPLATE_VOLUME_STEP_TYPE, classOf[PodTemplateConfigMapStep] - ) - - private val builderUnderTest: KubernetesDriverBuilder = - new KubernetesDriverBuilder( - _ => basicFeatureStep, - _ => credentialsStep, - _ => serviceStep, - _ => secretsStep, - _ => envSecretsStep, - _ => localDirsStep, - _ => mountVolumesStep, - _ => driverCommandStep, - _ => hadoopGlobalStep, - _ => templateVolumeStep) - - test("Apply fundamental steps all the time.") { - val conf = KubernetesTestConf.createDriverConf() - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf), - BASIC_STEP_TYPE, - CREDENTIALS_STEP_TYPE, - SERVICE_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, - DRIVER_CMD_STEP_TYPE, - HADOOP_GLOBAL_STEP_TYPE) + override protected def templateFileConf: ConfigEntry[_] = { + Config.KUBERNETES_DRIVER_PODTEMPLATE_FILE } - test("Apply secrets step if secrets are present.") { - val conf = KubernetesTestConf.createDriverConf( - secretEnvNamesToKeyRefs = Map("EnvName" -> "SecretName:secretKey"), - secretNamesToMountPaths = Map("secret" -> "secretMountPath")) - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf), - BASIC_STEP_TYPE, - CREDENTIALS_STEP_TYPE, - SERVICE_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, - SECRETS_STEP_TYPE, - ENV_SECRETS_STEP_TYPE, - DRIVER_CMD_STEP_TYPE, - HADOOP_GLOBAL_STEP_TYPE) - } - - test("Apply volumes step if mounts are present.") { - val volumeSpec = KubernetesVolumeSpec( - "volume", - "/tmp", - "", - false, - KubernetesHostPathVolumeConf("/path")) - val conf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeSpec)) - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf), - BASIC_STEP_TYPE, - CREDENTIALS_STEP_TYPE, - SERVICE_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, - MOUNT_VOLUMES_STEP_TYPE, - DRIVER_CMD_STEP_TYPE, - HADOOP_GLOBAL_STEP_TYPE) - } - - test("Apply volumes step if a mount subpath is present.") { - val volumeSpec = KubernetesVolumeSpec( - "volume", - "/tmp", - "foo", - false, - KubernetesHostPathVolumeConf("/path")) - val conf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeSpec)) - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf), - BASIC_STEP_TYPE, - CREDENTIALS_STEP_TYPE, - SERVICE_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, - MOUNT_VOLUMES_STEP_TYPE, - DRIVER_CMD_STEP_TYPE, - HADOOP_GLOBAL_STEP_TYPE) - } - - test("Apply template volume step if executor template is present.") { - val sparkConf = new SparkConf(false) - .set(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE, "filename") + override protected def buildPod(sparkConf: SparkConf, client: KubernetesClient): SparkPod = { val conf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf) - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf), - BASIC_STEP_TYPE, - CREDENTIALS_STEP_TYPE, - SERVICE_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, - DRIVER_CMD_STEP_TYPE, - HADOOP_GLOBAL_STEP_TYPE, - TEMPLATE_VOLUME_STEP_TYPE) - } - - private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) - : Unit = { - val addedProperties = resolvedSpec.systemProperties - .filter { case (k, _) => !k.startsWith("spark.") } - .toMap - assert(addedProperties.keys.toSet === stepTypes.toSet) - stepTypes.foreach { stepType => - assert(resolvedSpec.pod.pod.getMetadata.getLabels.get(stepType) === stepType) - assert(resolvedSpec.driverKubernetesResources.containsSlice( - KubernetesFeaturesTestUtils.getSecretsForStepType(stepType))) - assert(resolvedSpec.systemProperties(stepType) === stepType) - } - } - - test("Start with empty pod if template is not specified") { - val kubernetesClient = mock(classOf[KubernetesClient]) - val driverBuilder = KubernetesDriverBuilder.apply(kubernetesClient, new SparkConf()) - verify(kubernetesClient, never()).pods() + new KubernetesDriverBuilder().buildFromFeatures(conf, client).pod } - test("Starts with template if specified") { - val kubernetesClient = PodBuilderSuiteUtils.loadingMockKubernetesClient() - val sparkConf = new SparkConf(false) - .set(CONTAINER_IMAGE, "spark-driver:latest") - .set(KUBERNETES_DRIVER_PODTEMPLATE_FILE, "template-file.yaml") - val kubernetesConf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf) - val driverSpec = KubernetesDriverBuilder - .apply(kubernetesClient, sparkConf) - .buildFromFeatures(kubernetesConf) - PodBuilderSuiteUtils.verifyPodWithSupportedFeatures(driverSpec.pod) - } - - test("Throws on misconfigured pod template") { - val kubernetesClient = PodBuilderSuiteUtils.loadingMockKubernetesClient( - new PodBuilder() - .withNewMetadata() - .addToLabels("test-label-key", "test-label-value") - .endMetadata() - .build()) - val sparkConf = new SparkConf(false) - .set(CONTAINER_IMAGE, "spark-driver:latest") - .set(KUBERNETES_DRIVER_PODTEMPLATE_FILE, "template-file.yaml") - val kubernetesConf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf) - val exception = intercept[SparkException] { - KubernetesDriverBuilder - .apply(kubernetesClient, sparkConf) - .buildFromFeatures(kubernetesConf) - } - assert(exception.getMessage.contains("Could not load pod from template file.")) - } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/PodBuilderSuiteUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/PodBuilderSuiteUtils.scala deleted file mode 100644 index c92e9e6e3b6b3..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/PodBuilderSuiteUtils.scala +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit - -import java.io.File - -import io.fabric8.kubernetes.api.model._ -import io.fabric8.kubernetes.client.KubernetesClient -import io.fabric8.kubernetes.client.dsl.{MixedOperation, PodResource} -import org.mockito.Matchers.any -import org.mockito.Mockito.{mock, when} -import org.scalatest.FlatSpec -import scala.collection.JavaConverters._ - -import org.apache.spark.deploy.k8s.SparkPod - -object PodBuilderSuiteUtils extends FlatSpec { - - def loadingMockKubernetesClient(pod: Pod = podWithSupportedFeatures()): KubernetesClient = { - val kubernetesClient = mock(classOf[KubernetesClient]) - val pods = - mock(classOf[MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]]]) - val podResource = mock(classOf[PodResource[Pod, DoneablePod]]) - when(kubernetesClient.pods()).thenReturn(pods) - when(pods.load(any(classOf[File]))).thenReturn(podResource) - when(podResource.get()).thenReturn(pod) - kubernetesClient - } - - def verifyPodWithSupportedFeatures(pod: SparkPod): Unit = { - val metadata = pod.pod.getMetadata - assert(metadata.getLabels.containsKey("test-label-key")) - assert(metadata.getAnnotations.containsKey("test-annotation-key")) - assert(metadata.getNamespace === "namespace") - assert(metadata.getOwnerReferences.asScala.exists(_.getName == "owner-reference")) - val spec = pod.pod.getSpec - assert(!spec.getContainers.asScala.exists(_.getName == "executor-container")) - assert(spec.getDnsPolicy === "dns-policy") - assert(spec.getHostAliases.asScala.exists(_.getHostnames.asScala.exists(_ == "hostname"))) - assert(spec.getImagePullSecrets.asScala.exists(_.getName == "local-reference")) - assert(spec.getInitContainers.asScala.exists(_.getName == "init-container")) - assert(spec.getNodeName == "node-name") - assert(spec.getNodeSelector.get("node-selector-key") === "node-selector-value") - assert(spec.getSchedulerName === "scheduler") - assert(spec.getSecurityContext.getRunAsUser === 1000L) - assert(spec.getServiceAccount === "service-account") - assert(spec.getSubdomain === "subdomain") - assert(spec.getTolerations.asScala.exists(_.getKey == "toleration-key")) - assert(spec.getVolumes.asScala.exists(_.getName == "test-volume")) - val container = pod.container - assert(container.getName === "executor-container") - assert(container.getArgs.contains("arg")) - assert(container.getCommand.equals(List("command").asJava)) - assert(container.getEnv.asScala.exists(_.getName == "env-key")) - assert(container.getResources.getLimits.get("gpu") === - new QuantityBuilder().withAmount("1").build()) - assert(container.getSecurityContext.getRunAsNonRoot) - assert(container.getStdin) - assert(container.getTerminationMessagePath === "termination-message-path") - assert(container.getTerminationMessagePolicy === "termination-message-policy") - assert(pod.container.getVolumeMounts.asScala.exists(_.getName == "test-volume")) - - } - - - def podWithSupportedFeatures(): Pod = new PodBuilder() - .withNewMetadata() - .addToLabels("test-label-key", "test-label-value") - .addToAnnotations("test-annotation-key", "test-annotation-value") - .withNamespace("namespace") - .addNewOwnerReference() - .withController(true) - .withName("owner-reference") - .endOwnerReference() - .endMetadata() - .withNewSpec() - .withDnsPolicy("dns-policy") - .withHostAliases(new HostAliasBuilder().withHostnames("hostname").build()) - .withImagePullSecrets( - new LocalObjectReferenceBuilder().withName("local-reference").build()) - .withInitContainers(new ContainerBuilder().withName("init-container").build()) - .withNodeName("node-name") - .withNodeSelector(Map("node-selector-key" -> "node-selector-value").asJava) - .withSchedulerName("scheduler") - .withNewSecurityContext() - .withRunAsUser(1000L) - .endSecurityContext() - .withServiceAccount("service-account") - .withSubdomain("subdomain") - .withTolerations(new TolerationBuilder() - .withKey("toleration-key") - .withOperator("Equal") - .withEffect("NoSchedule") - .build()) - .addNewVolume() - .withNewHostPath() - .withPath("/test") - .endHostPath() - .withName("test-volume") - .endVolume() - .addNewContainer() - .withArgs("arg") - .withCommand("command") - .addNewEnv() - .withName("env-key") - .withValue("env-value") - .endEnv() - .withImagePullPolicy("Always") - .withName("executor-container") - .withNewResources() - .withLimits(Map("gpu" -> new QuantityBuilder().withAmount("1").build()).asJava) - .endResources() - .withNewSecurityContext() - .withRunAsNonRoot(true) - .endSecurityContext() - .withStdin(true) - .withTerminationMessagePath("termination-message-path") - .withTerminationMessagePolicy("termination-message-policy") - .addToVolumeMounts( - new VolumeMountBuilder() - .withName("test-volume") - .withMountPath("/test") - .build()) - .endContainer() - .endSpec() - .build() - -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala index d4fa31af3d5ce..278a3821a6f3d 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala @@ -80,8 +80,8 @@ class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter { when(kubernetesClient.pods()).thenReturn(podOperations) when(podOperations.withName(driverPodName)).thenReturn(driverPodOperations) when(driverPodOperations.get).thenReturn(driverPod) - when(executorBuilder.buildFromFeatures(any(classOf[KubernetesExecutorConf]), meq(secMgr))) - .thenAnswer(executorPodAnswer()) + when(executorBuilder.buildFromFeatures(any(classOf[KubernetesExecutorConf]), meq(secMgr), + meq(kubernetesClient))).thenAnswer(executorPodAnswer()) snapshotsStore = new DeterministicExecutorPodsSnapshotsStore() waitForExecutorPodsClock = new ManualClock(0L) podsAllocatorUnderTest = new ExecutorPodsAllocator( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index ef521fd801e97..bd716174a8271 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -16,147 +16,23 @@ */ package org.apache.spark.scheduler.cluster.k8s -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.{Config => _, _} import io.fabric8.kubernetes.client.KubernetesClient -import org.mockito.Mockito.{mock, never, verify} -import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.k8s._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.features._ -import org.apache.spark.deploy.k8s.submit.PodBuilderSuiteUtils -import org.apache.spark.util.SparkConfWithEnv - -class KubernetesExecutorBuilderSuite extends SparkFunSuite { - private val BASIC_STEP_TYPE = "basic" - private val SECRETS_STEP_TYPE = "mount-secrets" - private val ENV_SECRETS_STEP_TYPE = "env-secrets" - private val LOCAL_DIRS_STEP_TYPE = "local-dirs" - private val HADOOP_CONF_STEP_TYPE = "hadoop-conf-step" - private val HADOOP_SPARK_USER_STEP_TYPE = "hadoop-spark-user" - private val KERBEROS_CONF_STEP_TYPE = "kerberos-step" - private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes" - - private val secMgr = new SecurityManager(new SparkConf(false)) - - private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep]) - private val mountSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) - private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) - private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep]) - private val hadoopConfStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - HADOOP_CONF_STEP_TYPE, classOf[HadoopConfExecutorFeatureStep]) - private val hadoopSparkUser = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - HADOOP_SPARK_USER_STEP_TYPE, classOf[HadoopSparkUserExecutorFeatureStep]) - private val kerberosConf = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - KERBEROS_CONF_STEP_TYPE, classOf[KerberosConfExecutorFeatureStep]) - private val mountVolumesStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( - MOUNT_VOLUMES_STEP_TYPE, classOf[MountVolumesFeatureStep]) +import org.apache.spark.internal.config.ConfigEntry - private val builderUnderTest = new KubernetesExecutorBuilder( - (_, _) => basicFeatureStep, - _ => mountSecretsStep, - _ => envSecretsStep, - _ => localDirsStep, - _ => mountVolumesStep, - _ => hadoopConfStep, - _ => kerberosConf, - _ => hadoopSparkUser) - - test("Basic steps are consistently applied.") { - val conf = KubernetesTestConf.createExecutorConf() - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf, secMgr), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE) - } - - test("Apply secrets step if secrets are present.") { - val conf = KubernetesTestConf.createExecutorConf( - secretEnvNamesToKeyRefs = Map("secret-name" -> "secret-key"), - secretNamesToMountPaths = Map("secret" -> "secretMountPath")) - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf, secMgr), - BASIC_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, - SECRETS_STEP_TYPE, - ENV_SECRETS_STEP_TYPE) - } +class KubernetesExecutorBuilderSuite extends PodBuilderSuite { - test("Apply volumes step if mounts are present.") { - val volumeSpec = KubernetesVolumeSpec( - "volume", - "/tmp", - "", - false, - KubernetesHostPathVolumeConf("/checkpoint")) - val conf = KubernetesTestConf.createExecutorConf( - volumes = Seq(volumeSpec)) - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf, secMgr), - BASIC_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, - MOUNT_VOLUMES_STEP_TYPE) + override protected def templateFileConf: ConfigEntry[_] = { + Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE } - test("Apply basicHadoop step if HADOOP_CONF_DIR is defined") { - // HADOOP_DELEGATION_TOKEN - val conf = KubernetesTestConf.createExecutorConf( - sparkConf = new SparkConfWithEnv(Map("HADOOP_CONF_DIR" -> "/var/hadoop-conf")) - .set(HADOOP_CONFIG_MAP_NAME, "hadoop-conf-map-name") - .set(KRB5_CONFIG_MAP_NAME, "krb5-conf-map-name")) - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf, secMgr), - BASIC_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, - HADOOP_CONF_STEP_TYPE, - HADOOP_SPARK_USER_STEP_TYPE) + override protected def buildPod(sparkConf: SparkConf, client: KubernetesClient): SparkPod = { + sparkConf.set("spark.driver.host", "https://driver.host.com") + val conf = KubernetesTestConf.createExecutorConf(sparkConf = sparkConf) + val secMgr = new SecurityManager(sparkConf) + new KubernetesExecutorBuilder().buildFromFeatures(conf, secMgr, client) } - test("Apply kerberos step if DT secrets created") { - val conf = KubernetesTestConf.createExecutorConf( - sparkConf = new SparkConf(false) - .set(HADOOP_CONFIG_MAP_NAME, "hadoop-conf-map-name") - .set(KRB5_CONFIG_MAP_NAME, "krb5-conf-map-name") - .set(KERBEROS_SPARK_USER_NAME, "spark-user") - .set(KERBEROS_DT_SECRET_NAME, "dt-secret") - .set(KERBEROS_DT_SECRET_KEY, "dt-key" )) - validateStepTypesApplied( - builderUnderTest.buildFromFeatures(conf, secMgr), - BASIC_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE, - HADOOP_CONF_STEP_TYPE, - KERBEROS_CONF_STEP_TYPE) - } - - private def validateStepTypesApplied(resolvedPod: SparkPod, stepTypes: String*): Unit = { - assert(resolvedPod.pod.getMetadata.getLabels.asScala.keys.toSet === stepTypes.toSet) - } - - test("Starts with empty executor pod if template is not specified") { - val kubernetesClient = mock(classOf[KubernetesClient]) - val executorBuilder = KubernetesExecutorBuilder.apply(kubernetesClient, new SparkConf()) - verify(kubernetesClient, never()).pods() - } - - test("Starts with executor template if specified") { - val kubernetesClient = PodBuilderSuiteUtils.loadingMockKubernetesClient() - val sparkConf = new SparkConf(false) - .set("spark.driver.host", "https://driver.host.com") - .set(Config.CONTAINER_IMAGE, "spark-executor:latest") - .set(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE, "template-file.yaml") - val kubernetesConf = KubernetesTestConf.createExecutorConf( - sparkConf = sparkConf, - driverPod = Some(new PodBuilder() - .withNewMetadata() - .withName("driver") - .endMetadata() - .build())) - val sparkPod = KubernetesExecutorBuilder(kubernetesClient, sparkConf) - .buildFromFeatures(kubernetesConf, secMgr) - PodBuilderSuiteUtils.verifyPodWithSupportedFeatures(sparkPod) - } } From 2920438c43ade38e62442b0ba8937b716a05f7ad Mon Sep 17 00:00:00 2001 From: Luca Canali Date: Wed, 12 Dec 2018 16:18:22 -0800 Subject: [PATCH 2267/2461] [SPARK-25277][YARN] YARN applicationMaster metrics should not register static metrics ## What changes were proposed in this pull request? YARN applicationMaster metrics registration introduced in SPARK-24594 causes further registration of static metrics (Codegenerator and HiveExternalCatalog) and of JVM metrics, which I believe do not belong in this context. This looks like an unintended side effect of using the start method of [[MetricsSystem]]. A possible solution proposed here, is to introduce startNoRegisterSources to avoid these additional registrations of static sources and of JVM sources in the case of YARN applicationMaster metrics (this could be useful for other metrics that may be added in the future). ## How was this patch tested? Manually tested on a YARN cluster, Closes #22279 from LucaCanali/YarnMetricsRemoveExtraSourceRegistration. Lead-authored-by: Luca Canali Co-authored-by: LucaCanali Signed-off-by: Marcelo Vanzin --- .../scala/org/apache/spark/metrics/MetricsSystem.scala | 8 +++++--- .../org/apache/spark/deploy/yarn/ApplicationMaster.scala | 3 ++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index bb7b434e9a113..301317a79dfcf 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -94,11 +94,13 @@ private[spark] class MetricsSystem private ( metricsConfig.initialize() - def start() { + def start(registerStaticSources: Boolean = true) { require(!running, "Attempting to start a MetricsSystem that is already running") running = true - StaticSources.allSources.foreach(registerSource) - registerSources() + if (registerStaticSources) { + StaticSources.allSources.foreach(registerSource) + registerSources() + } registerSinks() sinks.foreach(_.start) } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index c1f3211bcab29..e46c4f970c4a3 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -449,7 +449,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends val ms = MetricsSystem.createMetricsSystem("applicationMaster", sparkConf, securityMgr) val prefix = _sparkConf.get(YARN_METRICS_NAMESPACE).getOrElse(appId) ms.registerSource(new ApplicationMasterSource(prefix, allocator)) - ms.start() + // do not register static sources in this case as per SPARK-25277 + ms.start(false) metricsSystem = Some(ms) reporterThread = launchReporterThread() } From 6daa78309460e338dd688cf6cdbd46a12666f72e Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Wed, 12 Dec 2018 16:45:50 -0800 Subject: [PATCH 2268/2461] [SPARK-26322][SS] Add spark.kafka.sasl.token.mechanism to ease delegation token configuration. ## What changes were proposed in this pull request? When Kafka delegation token obtained, SCRAM `sasl.mechanism` has to be configured for authentication. This can be configured on the related source/sink which is inconvenient from user perspective. Such granularity is not required and this configuration can be implemented with one central parameter. In this PR `spark.kafka.sasl.token.mechanism` added to configure this centrally (default: `SCRAM-SHA-512`). ## How was this patch tested? Existing unit tests + on cluster. Closes #23274 from gaborgsomogyi/SPARK-26322. Authored-by: Gabor Somogyi Signed-off-by: Marcelo Vanzin --- .../apache/spark/internal/config/Kafka.scala | 9 ++ .../structured-streaming-kafka-integration.md | 144 +----------------- .../sql/kafka010/KafkaSourceProvider.scala | 15 +- 3 files changed, 21 insertions(+), 147 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/Kafka.scala b/core/src/main/scala/org/apache/spark/internal/config/Kafka.scala index 064fc93cb8ed8..e91ddd3e9741a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/Kafka.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/Kafka.scala @@ -79,4 +79,13 @@ private[spark] object Kafka { "For further details please see kafka documentation. Only used to obtain delegation token.") .stringConf .createOptional + + val TOKEN_SASL_MECHANISM = + ConfigBuilder("spark.kafka.sasl.token.mechanism") + .doc("SASL mechanism used for client connections with delegation token. Because SCRAM " + + "login module used for authentication a compatible mechanism has to be set here. " + + "For further details please see kafka documentation (sasl.mechanism). Only used to " + + "authenticate against Kafka broker with delegation token.") + .stringConf + .createWithDefault("SCRAM-SHA-512") } diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index 7040f8da2c614..3d64ec4cb55f7 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -642,9 +642,9 @@ This way the application can be configured via Spark parameters and may not need configuration (Spark can use Kafka's dynamic JAAS configuration feature). For further information about delegation tokens, see [Kafka delegation token docs](http://kafka.apache.org/documentation/#security_delegation_token). -The process is initiated by Spark's Kafka delegation token provider. When `spark.kafka.bootstrap.servers`, +The process is initiated by Spark's Kafka delegation token provider. When `spark.kafka.bootstrap.servers` is set, Spark considers the following log in options, in order of preference: -- **JAAS login configuration** +- **JAAS login configuration**, please see example below. - **Keytab file**, such as, ./bin/spark-submit \ @@ -669,144 +669,8 @@ Kafka broker configuration): After obtaining delegation token successfully, Spark distributes it across nodes and renews it accordingly. Delegation token uses `SCRAM` login module for authentication and because of that the appropriate -`sasl.mechanism` has to be configured on source/sink (it must match with Kafka broker configuration): - -
      -
      -{% highlight scala %} - -// Setting on Kafka Source for Streaming Queries -val df = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") - .option("subscribe", "topic1") - .load() -df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - -// Setting on Kafka Source for Batch Queries -val df = spark - .read - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") - .option("subscribe", "topic1") - .load() -df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - -// Setting on Kafka Sink for Streaming Queries -val ds = df - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .writeStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") - .option("topic", "topic1") - .start() - -// Setting on Kafka Sink for Batch Queries -val ds = df - .selectExpr("topic1", "CAST(key AS STRING)", "CAST(value AS STRING)") - .write - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") - .save() - -{% endhighlight %} -
      -
      -{% highlight java %} - -// Setting on Kafka Source for Streaming Queries -Dataset df = spark - .readStream() - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") - .option("subscribe", "topic1") - .load(); -df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); - -// Setting on Kafka Source for Batch Queries -Dataset df = spark - .read() - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") - .option("subscribe", "topic1") - .load(); -df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); - -// Setting on Kafka Sink for Streaming Queries -StreamingQuery ds = df - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .writeStream() - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") - .option("topic", "topic1") - .start(); - -// Setting on Kafka Sink for Batch Queries -df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .write() - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") - .option("topic", "topic1") - .save(); - -{% endhighlight %} -
      -
      -{% highlight python %} - -// Setting on Kafka Source for Streaming Queries -df = spark \ - .readStream \ - .format("kafka") \ - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") \ - .option("subscribe", "topic1") \ - .load() -df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - -// Setting on Kafka Source for Batch Queries -df = spark \ - .read \ - .format("kafka") \ - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") \ - .option("subscribe", "topic1") \ - .load() -df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - -// Setting on Kafka Sink for Streaming Queries -ds = df \ - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") \ - .writeStream \ - .format("kafka") \ - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") \ - .option("topic", "topic1") \ - .start() - -// Setting on Kafka Sink for Batch Queries -df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") \ - .write \ - .format("kafka") \ - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ - .option("kafka.sasl.mechanism", "SCRAM-SHA-512") \ - .option("topic", "topic1") \ - .save() - -{% endhighlight %} -
      -
      +`spark.kafka.sasl.token.mechanism` (default: `SCRAM-SHA-512`) has to be configured. Also, this parameter +must match with Kafka broker configuration. When delegation token is available on an executor it can be overridden with JAAS login configuration. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 0ac330435e5c5..6a0c2088ac3d1 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -30,6 +30,7 @@ import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySe import org.apache.spark.SparkEnv import org.apache.spark.deploy.security.KafkaTokenUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ @@ -501,7 +502,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { // If buffer config is not set, set it to reasonable value to work around // buffer issues (see KAFKA-3135) .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) - .setTokenJaasConfigIfNeeded() + .setAuthenticationConfigIfNeeded() .build() def kafkaParamsForExecutors( @@ -523,7 +524,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { // If buffer config is not set, set it to reasonable value to work around // buffer issues (see KAFKA-3135) .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) - .setTokenJaasConfigIfNeeded() + .setAuthenticationConfigIfNeeded() .build() /** @@ -556,7 +557,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { this } - def setTokenJaasConfigIfNeeded(): ConfigUpdater = { + def setAuthenticationConfigIfNeeded(): ConfigUpdater = { // There are multiple possibilities to log in and applied in the following order: // - JVM global security provided -> try to log in with JVM global security configuration // which can be configured for example with 'java.security.auth.login.config'. @@ -568,11 +569,11 @@ private[kafka010] object KafkaSourceProvider extends Logging { } else if (KafkaSecurityHelper.isTokenAvailable()) { logDebug("Delegation token detected, using it for login.") val jaasParams = KafkaSecurityHelper.getTokenJaasParams(SparkEnv.get.conf) - val mechanism = kafkaParams - .getOrElse(SaslConfigs.SASL_MECHANISM, SaslConfigs.DEFAULT_SASL_MECHANISM) + set(SaslConfigs.SASL_JAAS_CONFIG, jaasParams) + val mechanism = SparkEnv.get.conf.get(Kafka.TOKEN_SASL_MECHANISM) require(mechanism.startsWith("SCRAM"), "Delegation token works only with SCRAM mechanism.") - set(SaslConfigs.SASL_JAAS_CONFIG, jaasParams) + set(SaslConfigs.SASL_MECHANISM, mechanism) } this } @@ -600,7 +601,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { ConfigUpdater("executor", specifiedKafkaParams) .set(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, serClassName) .set(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, serClassName) - .setTokenJaasConfigIfNeeded() + .setAuthenticationConfigIfNeeded() .build() } From 05b68d5cc92e46bd701cd01b4179cd13397eaf90 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 13 Dec 2018 11:13:15 +0800 Subject: [PATCH 2269/2461] [SPARK-26297][SQL] improve the doc of Distribution/Partitioning ## What changes were proposed in this pull request? Some documents of `Distribution/Partitioning` are stale and misleading, this PR fixes them: 1. `Distribution` never have intra-partition requirement 2. `OrderedDistribution` does not require tuples that share the same value being colocated in the same partition. 3. `RangePartitioning` can provide a weaker guarantee for a prefix of its `ordering` expressions. ## How was this patch tested? comment-only PR. Closes #23249 from cloud-fan/doc. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../plans/physical/partitioning.scala | 54 ++++++++++++------- 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index cc1a5e835d9cd..17e1cb416fc8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -22,13 +22,11 @@ import org.apache.spark.sql.types.{DataType, IntegerType} /** * Specifies how tuples that share common expressions will be distributed when a query is executed - * in parallel on many machines. Distribution can be used to refer to two distinct physical - * properties: - * - Inter-node partitioning of data: In this case the distribution describes how tuples are - * partitioned across physical machines in a cluster. Knowing this property allows some - * operators (e.g., Aggregate) to perform partition local operations instead of global ones. - * - Intra-partition ordering of data: In this case the distribution describes guarantees made - * about how tuples are distributed within a single partition. + * in parallel on many machines. + * + * Distribution here refers to inter-node partitioning of data. That is, it describes how tuples + * are partitioned across physical machines in a cluster. Knowing this property allows some + * operators (e.g., Aggregate) to perform partition local operations instead of global ones. */ sealed trait Distribution { /** @@ -70,9 +68,7 @@ case object AllTuples extends Distribution { /** * Represents data where tuples that share the same values for the `clustering` - * [[Expression Expressions]] will be co-located. Based on the context, this - * can mean such tuples are either co-located in the same partition or they will be contiguous - * within a single partition. + * [[Expression Expressions]] will be co-located in the same partition. */ case class ClusteredDistribution( clustering: Seq[Expression], @@ -118,10 +114,12 @@ case class HashClusteredDistribution( /** * Represents data where tuples have been ordered according to the `ordering` - * [[Expression Expressions]]. This is a strictly stronger guarantee than - * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the - * same value for the ordering expressions are contiguous and will never be split across - * partitions. + * [[Expression Expressions]]. Its requirement is defined as the following: + * - Given any 2 adjacent partitions, all the rows of the second partition must be larger than or + * equal to any row in the first partition, according to the `ordering` expressions. + * + * In other words, this distribution requires the rows to be ordered across partitions, but not + * necessarily within a partition. */ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { require( @@ -241,12 +239,12 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) /** * Represents a partitioning where rows are split across partitions based on some total ordering of - * the expressions specified in `ordering`. When data is partitioned in this manner the following - * two conditions are guaranteed to hold: - * - All row where the expressions in `ordering` evaluate to the same values will be in the same - * partition. - * - Each partition will have a `min` and `max` row, relative to the given ordering. All rows - * that are in between `min` and `max` in this `ordering` will reside in this partition. + * the expressions specified in `ordering`. When data is partitioned in this manner, it guarantees: + * Given any 2 adjacent partitions, all the rows of the second partition must be larger than any row + * in the first partition, according to the `ordering` expressions. + * + * This is a strictly stronger guarantee than what `OrderedDistribution(ordering)` requires, as + * there is no overlap between partitions. * * This class extends expression primarily so that transformations over expression will descend * into its child. @@ -262,6 +260,22 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) super.satisfies0(required) || { required match { case OrderedDistribution(requiredOrdering) => + // If `ordering` is a prefix of `requiredOrdering`: + // Let's say `ordering` is [a, b] and `requiredOrdering` is [a, b, c]. According to the + // RangePartitioning definition, any [a, b] in a previous partition must be smaller + // than any [a, b] in the following partition. This also means any [a, b, c] in a + // previous partition must be smaller than any [a, b, c] in the following partition. + // Thus `RangePartitioning(a, b)` satisfies `OrderedDistribution(a, b, c)`. + // + // If `requiredOrdering` is a prefix of `ordering`: + // Let's say `ordering` is [a, b, c] and `requiredOrdering` is [a, b]. According to the + // RangePartitioning definition, any [a, b, c] in a previous partition must be smaller + // than any [a, b, c] in the following partition. If there is a [a1, b1] from a previous + // partition which is larger than a [a2, b2] from the following partition, then there + // must be a [a1, b1 c1] larger than [a2, b2, c2], which violates RangePartitioning + // definition. So it's guaranteed that, any [a, b] in a previous partition must not be + // greater(i.e. smaller or equal to) than any [a, b] in the following partition. Thus + // `RangePartitioning(a, b, c)` satisfies `OrderedDistribution(a, b)`. val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) case ClusteredDistribution(requiredClustering, _) => From 3238e3d1c0d9be5c43a72705e18afbbb4c512e15 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 13 Dec 2018 12:50:15 +0800 Subject: [PATCH 2270/2461] [SPARK-26348][SQL][TEST] make sure expression is resolved during test ## What changes were proposed in this pull request? cleanup some tests to make sure expression is resolved during test. ## How was this patch tested? test-only PR Closes #23297 from cloud-fan/test. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/dsl/package.scala | 2 +- .../expressions/ExpressionEvalHelper.scala | 10 ++++---- .../expressions/JsonExpressionsSuite.scala | 11 ++++----- .../catalyst/expressions/PredicateSuite.scala | 23 ++++++------------- .../expressions/StringExpressionsSuite.scala | 7 ++---- 5 files changed, 20 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 176ea823b1fcd..151481c80ee96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -136,7 +136,7 @@ package object dsl { implicit def longToLiteral(l: Long): Literal = Literal(l) implicit def floatToLiteral(f: Float): Literal = Literal(f) implicit def doubleToLiteral(d: Double): Literal = Literal(d) - implicit def stringToLiteral(s: String): Literal = Literal(s) + implicit def stringToLiteral(s: String): Literal = Literal.create(s, StringType) implicit def dateToLiteral(d: Date): Literal = Literal(d) implicit def bigDecimalToLiteral(d: BigDecimal): Literal = Literal(d.underlying()) implicit def bigDecimalToLiteral(d: java.math.BigDecimal): Literal = Literal(d) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index b4fd170467d81..1c91adab71375 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -28,7 +28,7 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer} +import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer import org.apache.spark.sql.catalyst.plans.PlanTestBase @@ -70,7 +70,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa private def prepareEvaluation(expression: Expression): Expression = { val serializer = new JavaSerializer(new SparkConf()).newInstance val resolver = ResolveTimeZone(new SQLConf) - resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) + val expr = resolver.resolveTimeZones(expression) + assert(expr.resolved) + serializer.deserialize(serializer.serialize(expr)) } protected def checkEvaluation( @@ -296,9 +298,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa expected: Any, inputRow: InternalRow = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation()) - // We should analyze the plan first, otherwise we possibly optimize an unresolved plan. - val analyzedPlan = SimpleAnalyzer.execute(plan) - val optimizedPlan = SimpleTestOptimizer.execute(analyzedPlan) + val optimizedPlan = SimpleTestOptimizer.execute(plan) checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 5d60cefc13896..238e6e34b4ae5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf @@ -694,11 +694,10 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val mapType2 = MapType(IntegerType, CalendarIntervalType) val schema2 = StructType(StructField("a", mapType2) :: Nil) val struct2 = Literal.create(null, schema2) - intercept[TreeNodeException[_]] { - checkEvaluation( - StructsToJson(Map.empty, struct2, gmtId), - null - ) + StructsToJson(Map.empty, struct2, gmtId).checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("Unable to convert column a of type calendarinterval to JSON")) + case _ => fail("from_json should not work on interval map value type.") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 0f63717f9daf2..3541afcd2144d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -24,6 +24,7 @@ import scala.collection.immutable.HashSet import org.apache.spark.SparkFunSuite import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} @@ -231,22 +232,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { testWithRandomDataGeneration(structType, nullable) } - // Map types: not supported - for ( - keyType <- atomicTypes; - valueType <- atomicTypes; - nullable <- Seq(true, false)) { - val mapType = MapType(keyType, valueType) - val e = intercept[Exception] { - testWithRandomDataGeneration(mapType, nullable) - } - if (e.getMessage.contains("Code generation of")) { - // If the `value` expression is null, `eval` will be short-circuited. - // Codegen version evaluation will be run then. - assert(e.getMessage.contains("cannot generate equality code for un-comparable type")) - } else { - assert(e.getMessage.contains("Exception evaluating")) - } + // In doesn't support map type and will fail the analyzer. + val map = Literal.create(create_map(1 -> 1), MapType(IntegerType, IntegerType)) + In(map, Seq(map)).checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("function in does not support ordering on type map")) + case _ => fail("In should not work on map type") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index aa334e040d5fc..e95f2dff231b9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -744,16 +744,14 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("ParseUrl") { def checkParseUrl(expected: String, urlStr: String, partToExtract: String): Unit = { - checkEvaluation( - ParseUrl(Seq(Literal(urlStr), Literal(partToExtract))), expected) + checkEvaluation(ParseUrl(Seq(urlStr, partToExtract)), expected) } def checkParseUrlWithKey( expected: String, urlStr: String, partToExtract: String, key: String): Unit = { - checkEvaluation( - ParseUrl(Seq(Literal(urlStr), Literal(partToExtract), Literal(key))), expected) + checkEvaluation(ParseUrl(Seq(urlStr, partToExtract, key)), expected) } checkParseUrl("spark.apache.org", "http://spark.apache.org/path?query=1", "HOST") @@ -798,7 +796,6 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Sentences(nullString, nullString, nullString), null) checkEvaluation(Sentences(nullString, nullString), null) checkEvaluation(Sentences(nullString), null) - checkEvaluation(Sentences(Literal.create(null, NullType)), null) checkEvaluation(Sentences("", nullString, nullString), Seq.empty) checkEvaluation(Sentences("", nullString), Seq.empty) checkEvaluation(Sentences(""), Seq.empty) From 8edae94fa7ec1a1cc2c69e0924da0da85d4aac83 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 13 Dec 2018 13:14:59 +0800 Subject: [PATCH 2271/2461] [SPARK-26355][PYSPARK] Add a workaround for PyArrow 0.11. ## What changes were proposed in this pull request? In PyArrow 0.11, there is a API breaking change. - [ARROW-1949](https://issues.apache.org/jira/browse/ARROW-1949) - [Python/C++] Add option to Array.from_pandas and pyarrow.array to perform unsafe casts. This causes test failures in `ScalarPandasUDFTests.test_vectorized_udf_null_(byte|short|int|long)`: ``` File "/Users/ueshin/workspace/apache-spark/spark/python/pyspark/worker.py", line 377, in main process() File "/Users/ueshin/workspace/apache-spark/spark/python/pyspark/worker.py", line 372, in process serializer.dump_stream(func(split_index, iterator), outfile) File "/Users/ueshin/workspace/apache-spark/spark/python/pyspark/serializers.py", line 317, in dump_stream batch = _create_batch(series, self._timezone) File "/Users/ueshin/workspace/apache-spark/spark/python/pyspark/serializers.py", line 286, in _create_batch arrs = [create_array(s, t) for s, t in series] File "/Users/ueshin/workspace/apache-spark/spark/python/pyspark/serializers.py", line 284, in create_array return pa.Array.from_pandas(s, mask=mask, type=t) File "pyarrow/array.pxi", line 474, in pyarrow.lib.Array.from_pandas return array(obj, mask=mask, type=type, safe=safe, from_pandas=True, File "pyarrow/array.pxi", line 169, in pyarrow.lib.array return _ndarray_to_array(values, mask, type, from_pandas, safe, File "pyarrow/array.pxi", line 69, in pyarrow.lib._ndarray_to_array check_status(NdarrayToArrow(pool, values, mask, from_pandas, File "pyarrow/error.pxi", line 81, in pyarrow.lib.check_status raise ArrowInvalid(message) ArrowInvalid: Floating point value truncated ``` We should add a workaround to support PyArrow 0.11. ## How was this patch tested? In my local environment. Closes #23305 from ueshin/issues/SPARK-26355/pyarrow_0.11. Authored-by: Takuya UESHIN Signed-off-by: Hyukjin Kwon --- python/pyspark/serializers.py | 5 ++++- .../pyspark/sql/tests/test_pandas_udf_grouped_map.py | 11 +++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index f3ebd3767a0a1..fd4695210fb7c 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -281,7 +281,10 @@ def create_array(s, t): # TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0. return pa.Array.from_pandas(s.apply( lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t) - return pa.Array.from_pandas(s, mask=mask, type=t) + elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"): + # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0. + return pa.Array.from_pandas(s, mask=mask, type=t) + return pa.Array.from_pandas(s, mask=mask, type=t, safe=False) arrs = [create_array(s, t) for s, t in series] return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py index bfecc071386e9..a12c608dff9dd 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py @@ -468,8 +468,15 @@ def invalid_positional_types(pdf): with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, "KeyError: 'id'"): grouped_df.apply(column_name_typo).collect() - with self.assertRaisesRegexp(Exception, "No cast implemented"): - grouped_df.apply(invalid_positional_types).collect() + from distutils.version import LooseVersion + import pyarrow as pa + if LooseVersion(pa.__version__) < LooseVersion("0.11.0"): + # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0. + with self.assertRaisesRegexp(Exception, "No cast implemented"): + grouped_df.apply(invalid_positional_types).collect() + else: + with self.assertRaisesRegexp(Exception, "an integer is required"): + grouped_df.apply(invalid_positional_types).collect() def test_positional_assignment_conf(self): import pandas as pd From 19b63c560d92a0e30b2dfc523bcce4f1c9daf851 Mon Sep 17 00:00:00 2001 From: Qi Shao Date: Thu, 13 Dec 2018 20:05:49 +0800 Subject: [PATCH 2272/2461] [MINOR][R] Fix indents of sparkR welcome message to be consistent with pyspark and spark-shell ## What changes were proposed in this pull request? 1. Removed empty space at the beginning of welcome message lines of sparkR to be consistent with welcome message of `pyspark` and `spark-shell` 2. Setting indent of logo message lines to 3 to be consistent with welcome message of `pyspark` and `spark-shell` Output of `pyspark`: ``` Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /__ / .__/\_,_/_/ /_/\_\ version 2.4.0 /_/ Using Python version 3.6.6 (default, Jun 28 2018 11:07:29) SparkSession available as 'spark'. ``` Output of `spark-shell`: ``` Spark session available as 'spark'. Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /___/ .__/\_,_/_/ /_/\_\ version 2.4.0 /_/ Using Scala version 2.11.12 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_161) Type in expressions to have them evaluated. Type :help for more information. ``` ## How was this patch tested? Before: Output of `sparkR`: ``` Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /___/ .__/\_,_/_/ /_/\_\ version 2.4.0 /_/ SparkSession available as 'spark'. ``` After: ``` Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /___/ .__/\_,_/_/ /_/\_\ version 2.4.0 /_/ SparkSession available as 'spark'. ``` Closes #23293 from AzureQ/master. Authored-by: Qi Shao Signed-off-by: Hyukjin Kwon --- R/pkg/inst/profile/shell.R | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R index 32eb3671b5941..e4e0d032997de 100644 --- a/R/pkg/inst/profile/shell.R +++ b/R/pkg/inst/profile/shell.R @@ -33,19 +33,19 @@ sc <- SparkR:::callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", spark) assign("sc", sc, envir = .GlobalEnv) sparkVer <- SparkR:::callJMethod(sc, "version") - cat("\n Welcome to") + cat("\nWelcome to") cat("\n") - cat(" ____ __", "\n") - cat(" / __/__ ___ _____/ /__", "\n") - cat(" _\\ \\/ _ \\/ _ `/ __/ '_/", "\n") - cat(" /___/ .__/\\_,_/_/ /_/\\_\\") + cat(" ____ __", "\n") + cat(" / __/__ ___ _____/ /__", "\n") + cat(" _\\ \\/ _ \\/ _ `/ __/ '_/", "\n") + cat(" /___/ .__/\\_,_/_/ /_/\\_\\") if (nchar(sparkVer) == 0) { cat("\n") } else { - cat(" version ", sparkVer, "\n") + cat(" version", sparkVer, "\n") } - cat(" /_/", "\n") + cat(" /_/", "\n") cat("\n") - cat("\n SparkSession available as 'spark'.\n") + cat("\nSparkSession available as 'spark'.\n") } From f3726092169406979849b3cb5afeb52be106fd68 Mon Sep 17 00:00:00 2001 From: seancxmao Date: Thu, 13 Dec 2018 07:40:13 -0600 Subject: [PATCH 2273/2461] [MINOR][DOC] Fix comments of ConvertToLocalRelation rule ## What changes were proposed in this pull request? There are some comments issues left when `ConvertToLocalRelation` rule was added (see #22205/[SPARK-25212](https://issues.apache.org/jira/browse/SPARK-25212)). This PR fixes those comments issues. ## How was this patch tested? N/A Closes #23273 from seancxmao/ConvertToLocalRelation-doc. Authored-by: seancxmao Signed-off-by: Sean Owen --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 8d251eeab8484..f615757a837a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -131,11 +131,11 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) // since the other rules might make two separate Unions operators adjacent. Batch("Union", Once, CombineUnions) :: - // run this once earlier. this might simplify the plan and reduce cost of optimizer. - // for example, a query such as Filter(LocalRelation) would go through all the heavy + // Run this once earlier. This might simplify the plan and reduce cost of optimizer. + // For example, a query such as Filter(LocalRelation) would go through all the heavy // optimizer rules that are triggered when there is a filter - // (e.g. InferFiltersFromConstraints). if we run this batch earlier, the query becomes just - // LocalRelation and does not trigger many rules + // (e.g. InferFiltersFromConstraints). If we run this batch earlier, the query becomes just + // LocalRelation and does not trigger many rules. Batch("LocalRelation early", fixedPoint, ConvertToLocalRelation, PropagateEmptyRelation) :: @@ -1370,10 +1370,8 @@ object DecimalAggregates extends Rule[LogicalPlan] { } /** - * Converts local operations (i.e. ones that don't require data exchange) on LocalRelation to - * another LocalRelation. - * - * This is relatively simple as it currently handles only 2 single case: Project and Limit. + * Converts local operations (i.e. ones that don't require data exchange) on `LocalRelation` to + * another `LocalRelation`. */ object ConvertToLocalRelation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { From f69998ace6b2e0665e45839c3e98c77a18be442a Mon Sep 17 00:00:00 2001 From: lichaoqun Date: Thu, 13 Dec 2018 07:42:17 -0600 Subject: [PATCH 2274/2461] [MINOR][DOC] update the condition description of BypassMergeSortShuffle MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? These three condition descriptions should be updated, follow #23228 :
    • no Ordering is specified,
    • no Aggregator is specified, and
    • the number of partitions is less than spark.shuffle.sort.bypassMergeThreshold.
    • 1、If the shuffle dependency specifies aggregation, but it only aggregates at the reduce-side, BypassMergeSortShuffle can still be used. 2、If the number of output partitions is spark.shuffle.sort.bypassMergeThreshold(eg.200), we can use BypassMergeSortShuffle. ## How was this patch tested? N/A Closes #23281 from lcqzte10192193/wid-lcq-1211. Authored-by: lichaoqun Signed-off-by: Sean Owen --- .../spark/shuffle/sort/BypassMergeSortShuffleWriter.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index fda33cd8293d5..997bc9e3f0435 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -58,9 +58,8 @@ * simultaneously opens separate serializers and file streams for all partitions. As a result, * {@link SortShuffleManager} only selects this write path when *
        - *
      • no Ordering is specified,
      • - *
      • no Aggregator is specified, and
      • - *
      • the number of partitions is less than + *
      • no map-side combine is specified, and
      • + *
      • the number of partitions is less than or equal to * spark.shuffle.sort.bypassMergeThreshold.
      • *
      * From 29b3eb6fedd8f90495046da598eacc4ac00944c3 Mon Sep 17 00:00:00 2001 From: "n.fraison" Date: Thu, 13 Dec 2018 08:34:47 -0600 Subject: [PATCH 2275/2461] [SPARK-26340][CORE] Ensure cores per executor is greater than cpu per task Currently this check is only performed for dynamic allocation use case in ExecutorAllocationManager. ## What changes were proposed in this pull request? Checks that cpu per task is lower than number of cores per executor otherwise throw an exception ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #23290 from ashangit/master. Authored-by: n.fraison Signed-off-by: Sean Owen --- core/src/main/scala/org/apache/spark/SparkConf.scala | 9 +++++++++ .../src/test/scala/org/apache/spark/SparkConfSuite.scala | 7 +++++++ 2 files changed, 16 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 21c5cbc04d813..8d135d3e083d7 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -605,6 +605,15 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria } } + if (contains("spark.executor.cores") && contains("spark.task.cpus")) { + val executorCores = getInt("spark.executor.cores", 1) + val taskCpus = getInt("spark.task.cpus", 1) + + if (executorCores < taskCpus) { + throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.") + } + } + val encryptionEnabled = get(NETWORK_ENCRYPTION_ENABLED) || get(SASL_ENCRYPTION_ENABLED) require(!encryptionEnabled || get(NETWORK_AUTH_ENABLED), s"${NETWORK_AUTH_ENABLED.key} must be enabled when enabling encryption.") diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index df274d949bae3..7cb03deae1391 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -138,6 +138,13 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst assert(sc.appName === "My other app") } + test("creating SparkContext with cpus per tasks bigger than cores per executors") { + val conf = new SparkConf(false) + .set("spark.executor.cores", "1") + .set("spark.task.cpus", "2") + intercept[SparkException] { sc = new SparkContext(conf) } + } + test("nested property names") { // This wasn't supported by some external conf parsing libraries System.setProperty("spark.test.a", "a") From 6c1f7ba8f627a69cac74f11400066dd9871d9102 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 13 Dec 2018 23:03:26 +0800 Subject: [PATCH 2276/2461] [SPARK-26313][SQL] move `newScanBuilder` from Table to read related mix-in traits ## What changes were proposed in this pull request? As discussed in https://github.com/apache/spark/pull/23208/files#r239684490 , we should put `newScanBuilder` in read related mix-in traits like `SupportsBatchRead`, to support write-only table. In the `Append` operator, we should skip schema validation if not necessary. In the future we would introduce a capability API, so that data source can tell Spark that it doesn't want to do validation. ## How was this patch tested? existing tests. Closes #23266 from cloud-fan/ds-read. Authored-by: Wenchen Fan Signed-off-by: Hyukjin Kwon --- .../sql/sources/v2/SupportsBatchRead.java | 8 ++--- .../spark/sql/sources/v2/SupportsRead.java | 35 +++++++++++++++++++ .../apache/spark/sql/sources/v2/Table.java | 15 ++------ 3 files changed, 41 insertions(+), 17 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java index 0df89dbb608a4..6c5a95d2a75b7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java @@ -24,10 +24,10 @@ /** * An empty mix-in interface for {@link Table}, to indicate this table supports batch scan. *

      - * If a {@link Table} implements this interface, its {@link Table#newScanBuilder(DataSourceOptions)} - * must return a {@link ScanBuilder} that builds {@link Scan} with {@link Scan#toBatch()} - * implemented. + * If a {@link Table} implements this interface, the + * {@link SupportsRead#newScanBuilder(DataSourceOptions)} must return a {@link ScanBuilder} that + * builds {@link Scan} with {@link Scan#toBatch()} implemented. *

      */ @Evolving -public interface SupportsBatchRead extends Table { } +public interface SupportsBatchRead extends SupportsRead { } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java new file mode 100644 index 0000000000000..e22738d20d507 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.sql.sources.v2.reader.Scan; +import org.apache.spark.sql.sources.v2.reader.ScanBuilder; + +/** + * An internal base interface of mix-in interfaces for readable {@link Table}. This adds + * {@link #newScanBuilder(DataSourceOptions)} that is used to create a scan for batch, micro-batch, + * or continuous processing. + */ +interface SupportsRead extends Table { + + /** + * Returns a {@link ScanBuilder} which can be used to build a {@link Scan}. Spark will call this + * method to configure each scan. + */ + ScanBuilder newScanBuilder(DataSourceOptions options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java index 0c65fe0f9e76a..08664859b8de2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java @@ -18,8 +18,6 @@ package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.v2.reader.Scan; -import org.apache.spark.sql.sources.v2.reader.ScanBuilder; import org.apache.spark.sql.types.StructType; /** @@ -43,17 +41,8 @@ public interface Table { String name(); /** - * Returns the schema of this table. + * Returns the schema of this table. If the table is not readable and doesn't have a schema, an + * empty schema can be returned here. */ StructType schema(); - - /** - * Returns a {@link ScanBuilder} which can be used to build a {@link Scan} later. Spark will call - * this method for each data scanning query. - *

      - * The builder can take some query specific information to do operators pushdown, and keep these - * information in the created {@link Scan}. - *

      - */ - ScanBuilder newScanBuilder(DataSourceOptions options); } From 524d1be6d2920674eb871b5f0f25e7496a374090 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 13 Dec 2018 09:07:33 -0800 Subject: [PATCH 2277/2461] [SPARK-26098][WEBUI] Show associated SQL query in Job page ## What changes were proposed in this pull request? For jobs associated to SQL queries, it would be easier to understand the context to showing the SQL query in Job detail page. Before code change, it is hard to tell what the job is about from the job page: ![image](https://user-images.githubusercontent.com/1097932/48659359-96baa180-ea8a-11e8-8419-a0a87c3f30fc.png) After code change: ![image](https://user-images.githubusercontent.com/1097932/48659390-26f8e680-ea8b-11e8-8fdd-3b58909ea364.png) After navigating to the associated SQL detail page, We can see the whole context : ![image](https://user-images.githubusercontent.com/1097932/48659463-9fac7280-ea8c-11e8-9dfe-244e849f72a5.png) **For Jobs don't have associated SQL query, the text won't be shown.** ## How was this patch tested? Manual test Closes #23068 from gengliangwang/addSQLID. Authored-by: Gengliang Wang Signed-off-by: gatorsmile --- .../apache/spark/status/AppStatusListener.scala | 7 ++++++- .../org/apache/spark/status/AppStatusStore.scala | 7 +++++++ .../scala/org/apache/spark/status/LiveEntity.scala | 5 +++-- .../scala/org/apache/spark/status/storeTypes.scala | 3 ++- .../scala/org/apache/spark/ui/jobs/JobPage.scala | 14 +++++++++++++- 5 files changed, 31 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index bd3f58b6182c0..262ff6547faa5 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -70,6 +70,8 @@ private[spark] class AppStatusListener( private val liveTasks = new HashMap[Long, LiveTask]() private val liveRDDs = new HashMap[Int, LiveRDD]() private val pools = new HashMap[String, SchedulerPool]() + + private val SQL_EXECUTION_ID_KEY = "spark.sql.execution.id" // Keep the active executor count as a separate variable to avoid having to do synchronization // around liveExecutors. @volatile private var activeExecutorCount = 0 @@ -318,6 +320,8 @@ private[spark] class AppStatusListener( val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") val jobGroup = Option(event.properties) .flatMap { p => Option(p.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) } + val sqlExecutionId = Option(event.properties) + .flatMap(p => Option(p.getProperty(SQL_EXECUTION_ID_KEY)).map(_.toLong)) val job = new LiveJob( event.jobId, @@ -325,7 +329,8 @@ private[spark] class AppStatusListener( if (event.time > 0) Some(new Date(event.time)) else None, event.stageIds, jobGroup, - numTasks) + numTasks, + sqlExecutionId) liveJobs.put(event.jobId, job) liveUpdate(job, now) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index b35781cb36e81..312bcccb1cca1 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -56,6 +56,13 @@ private[spark] class AppStatusStore( store.read(classOf[JobDataWrapper], jobId).info } + // Returns job data and associated SQL execution ID of certain Job ID. + // If there is no related SQL execution, the SQL execution ID part will be None. + def jobWithAssociatedSql(jobId: Int): (v1.JobData, Option[Long]) = { + val data = store.read(classOf[JobDataWrapper], jobId) + (data.info, data.sqlExecutionId) + } + def executorList(activeOnly: Boolean): Seq[v1.ExecutorSummary] = { val base = store.view(classOf[ExecutorSummaryWrapper]) val filtered = if (activeOnly) { diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 47e45a66ecccb..7f7b83a54d794 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -64,7 +64,8 @@ private class LiveJob( val submissionTime: Option[Date], val stageIds: Seq[Int], jobGroup: Option[String], - numTasks: Int) extends LiveEntity { + numTasks: Int, + sqlExecutionId: Option[Long]) extends LiveEntity { var activeTasks = 0 var completedTasks = 0 @@ -108,7 +109,7 @@ private class LiveJob( skippedStages.size, failedStages, killedSummary) - new JobDataWrapper(info, skippedStages) + new JobDataWrapper(info, skippedStages, sqlExecutionId) } } diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index ef19e86f3135f..eea47b3b17098 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -68,7 +68,8 @@ private[spark] class ExecutorSummaryWrapper(val info: ExecutorSummary) { */ private[spark] class JobDataWrapper( val info: JobData, - val skippedStages: Set[Int]) { + val skippedStages: Set[Int], + val sqlExecutionId: Option[Long]) { @JsonIgnore @KVIndex private def id: Int = info.jobId diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 55444a2c0c9ab..b58a6ca447edf 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -189,7 +189,7 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") val jobId = parameterId.toInt - val jobData = store.asOption(store.job(jobId)).getOrElse { + val (jobData, sqlExecutionId) = store.asOption(store.jobWithAssociatedSql(jobId)).getOrElse { val content =

      No information to display for job {jobId}

      @@ -197,6 +197,7 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP return UIUtils.headerSparkPage( request, s"Details for Job $jobId", content, parent) } + val isComplete = jobData.status != JobExecutionStatus.RUNNING val stages = jobData.stageIds.map { stageId => // This could be empty if the listener hasn't received information about the @@ -278,6 +279,17 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP Status: {jobData.status} + { + if (sqlExecutionId.isDefined) { +
    • + Associated SQL Query: + {{sqlExecutionId.get}} +
    • + } + } { if (jobData.jobGroup.isDefined) {
    • From 362e472831e0609f88fdeb01d8e14badc812b0f4 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Thu, 13 Dec 2018 16:12:55 -0800 Subject: [PATCH 2278/2461] [SPARK-23886][SS] Update query status for ContinuousExecution ## What changes were proposed in this pull request? Added query status updates to ContinuousExecution. ## How was this patch tested? Existing unit tests + added ContinuousQueryStatusAndProgressSuite. Closes #23095 from gaborgsomogyi/SPARK-23886. Authored-by: Gabor Somogyi Signed-off-by: Marcelo Vanzin --- .../streaming/MicroBatchExecution.scala | 6 ++ .../streaming/ProgressReporter.scala | 1 - .../continuous/ContinuousExecution.scala | 6 ++ .../sql/streaming/StreamingQueryStatus.scala | 6 +- ...ontinuousQueryStatusAndProgressSuite.scala | 55 +++++++++++++++++++ 5 files changed, 71 insertions(+), 3 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 64e09edf27f58..03beefeca269b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -147,6 +147,12 @@ class MicroBatchExecution( logInfo(s"Query $prettyIdString was stopped") } + /** Begins recording statistics about query progress for a given trigger. */ + override protected def startTrigger(): Unit = { + super.startTrigger() + currentStatus = currentStatus.copy(isTriggerActive = true) + } + /** * Repeatedly attempts to run batches as data arrives. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 6a22f0cc8431a..39ab702ee083c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -114,7 +114,6 @@ trait ProgressReporter extends Logging { logDebug("Starting Trigger Calculation") lastTriggerStartTimestamp = currentTriggerStartTimestamp currentTriggerStartTimestamp = triggerClock.getTimeMillis() - currentStatus = currentStatus.copy(isTriggerActive = true) currentTriggerStartOffsets = null currentTriggerEndOffsets = null currentDurationsMs.clear() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 4d42428fd189e..f0859aaaa3041 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -118,6 +118,8 @@ class ContinuousExecution( // For at least once, we can just ignore those reports and risk duplicates. commitLog.getLatest() match { case Some((latestEpochId, _)) => + updateStatusMessage("Starting new streaming query " + + s"and getting offsets from latest epoch $latestEpochId") val nextOffsets = offsetLog.get(latestEpochId).getOrElse { throw new IllegalStateException( s"Batch $latestEpochId was committed without end epoch offsets!") @@ -129,6 +131,7 @@ class ContinuousExecution( nextOffsets case None => // We are starting this stream for the first time. Offsets are all None. + updateStatusMessage("Starting new streaming query") logInfo(s"Starting new streaming query.") currentBatchId = 0 OffsetSeq.fill(continuousSources.map(_ => null): _*) @@ -263,6 +266,7 @@ class ContinuousExecution( epochUpdateThread.setDaemon(true) epochUpdateThread.start() + updateStatusMessage("Running") reportTimeTaken("runContinuous") { SQLExecution.withNewExecutionId( sparkSessionForQuery, lastExecution) { @@ -322,6 +326,8 @@ class ContinuousExecution( * before this is called. */ def commit(epoch: Long): Unit = { + updateStatusMessage(s"Committing epoch $epoch") + assert(continuousSources.length == 1, "only one continuous source supported currently") assert(offsetLog.get(epoch).isDefined, s"offset for epoch $epoch not reported before commit") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala index 9dc62b7aac891..6ca9aacab7247 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala @@ -28,9 +28,11 @@ import org.apache.spark.annotation.Evolving * Reports information about the instantaneous status of a streaming query. * * @param message A human readable description of what the stream is currently doing. - * @param isDataAvailable True when there is new data to be processed. + * @param isDataAvailable True when there is new data to be processed. Doesn't apply + * to ContinuousExecution where it is always false. * @param isTriggerActive True when the trigger is actively firing, false when waiting for the - * next trigger time. + * next trigger time. Doesn't apply to ContinuousExecution where it is + * always false. * * @since 2.1.0 */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala new file mode 100644 index 0000000000000..10bea7f090571 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.continuous + +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream +import org.apache.spark.sql.streaming.Trigger + +class ContinuousQueryStatusAndProgressSuite extends ContinuousSuiteBase { + test("StreamingQueryStatus - ContinuousExecution isDataAvailable and isTriggerActive " + + "should be false") { + import testImplicits._ + + val input = ContinuousMemoryStream[Int] + + def assertStatus(stream: StreamExecution): Unit = { + assert(stream.status.isDataAvailable === false) + assert(stream.status.isTriggerActive === false) + } + + val trigger = Trigger.Continuous(100) + testStream(input.toDF(), useV2Sink = true)( + StartStream(trigger), + Execute(assertStatus), + AddData(input, 0, 1, 2), + Execute(assertStatus), + CheckAnswer(0, 1, 2), + Execute(assertStatus), + StopStream, + Execute(assertStatus), + AddData(input, 3, 4, 5), + Execute(assertStatus), + StartStream(trigger), + Execute(assertStatus), + CheckAnswer(0, 1, 2, 3, 4, 5), + Execute(assertStatus), + StopStream, + Execute(assertStatus)) + } +} From 160e583a17235318c06b95992941a772ff782fae Mon Sep 17 00:00:00 2001 From: Li Jin Date: Fri, 14 Dec 2018 10:45:24 +0800 Subject: [PATCH 2279/2461] [SPARK-26364][PYTHON][TESTING] Clean up imports in test_pandas_udf* ## What changes were proposed in this pull request? Clean up unconditional import statements and move them to the top. Conditional imports (pandas, numpy, pyarrow) are left as-is. ## How was this patch tested? Exising tests. Closes #23314 from icexelloss/clean-up-test-imports. Authored-by: Li Jin Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/tests/test_pandas_udf.py | 16 +--- .../sql/tests/test_pandas_udf_grouped_agg.py | 39 +--------- .../sql/tests/test_pandas_udf_grouped_map.py | 40 +++------- .../sql/tests/test_pandas_udf_scalar.py | 75 +++++-------------- .../sql/tests/test_pandas_udf_window.py | 29 +------ 5 files changed, 36 insertions(+), 163 deletions(-) diff --git a/python/pyspark/sql/tests/test_pandas_udf.py b/python/pyspark/sql/tests/test_pandas_udf.py index c4b5478a7e893..d4d9679649ee9 100644 --- a/python/pyspark/sql/tests/test_pandas_udf.py +++ b/python/pyspark/sql/tests/test_pandas_udf.py @@ -17,12 +17,16 @@ import unittest +from pyspark.sql.functions import udf, pandas_udf, PandasUDFType from pyspark.sql.types import * from pyspark.sql.utils import ParseException +from pyspark.rdd import PythonEvalType from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message from pyspark.testing.utils import QuietTest +from py4j.protocol import Py4JJavaError + @unittest.skipIf( not have_pandas or not have_pyarrow, @@ -30,9 +34,6 @@ class PandasUDFTests(ReusedSQLTestCase): def test_pandas_udf_basic(self): - from pyspark.rdd import PythonEvalType - from pyspark.sql.functions import pandas_udf, PandasUDFType - udf = pandas_udf(lambda x: x, DoubleType()) self.assertEqual(udf.returnType, DoubleType()) self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) @@ -65,10 +66,6 @@ def test_pandas_udf_basic(self): self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) def test_pandas_udf_decorator(self): - from pyspark.rdd import PythonEvalType - from pyspark.sql.functions import pandas_udf, PandasUDFType - from pyspark.sql.types import StructType, StructField, DoubleType - @pandas_udf(DoubleType()) def foo(x): return x @@ -114,8 +111,6 @@ def foo(x): self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) def test_udf_wrong_arg(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - with QuietTest(self.sc): with self.assertRaises(ParseException): @pandas_udf('blah') @@ -151,9 +146,6 @@ def foo(k, v, w): return k def test_stopiteration_in_udf(self): - from pyspark.sql.functions import udf, pandas_udf, PandasUDFType - from py4j.protocol import Py4JJavaError - def foo(x): raise StopIteration() diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py index 5383704434c85..18264ead2fd08 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py +++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py @@ -17,6 +17,9 @@ import unittest +from pyspark.rdd import PythonEvalType +from pyspark.sql.functions import array, explode, col, lit, mean, sum, \ + udf, pandas_udf, PandasUDFType from pyspark.sql.types import * from pyspark.sql.utils import AnalysisException from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ @@ -31,7 +34,6 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase): @property def data(self): - from pyspark.sql.functions import array, explode, col, lit return self.spark.range(10).toDF('id') \ .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ .withColumn("v", explode(col('vs'))) \ @@ -40,8 +42,6 @@ def data(self): @property def python_plus_one(self): - from pyspark.sql.functions import udf - @udf('double') def plus_one(v): assert isinstance(v, (int, float)) @@ -51,7 +51,6 @@ def plus_one(v): @property def pandas_scalar_plus_two(self): import pandas as pd - from pyspark.sql.functions import pandas_udf, PandasUDFType @pandas_udf('double', PandasUDFType.SCALAR) def plus_two(v): @@ -61,8 +60,6 @@ def plus_two(v): @property def pandas_agg_mean_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - @pandas_udf('double', PandasUDFType.GROUPED_AGG) def avg(v): return v.mean() @@ -70,8 +67,6 @@ def avg(v): @property def pandas_agg_sum_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - @pandas_udf('double', PandasUDFType.GROUPED_AGG) def sum(v): return v.sum() @@ -80,7 +75,6 @@ def sum(v): @property def pandas_agg_weighted_mean_udf(self): import numpy as np - from pyspark.sql.functions import pandas_udf, PandasUDFType @pandas_udf('double', PandasUDFType.GROUPED_AGG) def weighted_mean(v, w): @@ -88,8 +82,6 @@ def weighted_mean(v, w): return weighted_mean def test_manual(self): - from pyspark.sql.functions import pandas_udf, array - df = self.data sum_udf = self.pandas_agg_sum_udf mean_udf = self.pandas_agg_mean_udf @@ -118,8 +110,6 @@ def test_manual(self): self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) def test_basic(self): - from pyspark.sql.functions import col, lit, mean - df = self.data weighted_mean_udf = self.pandas_agg_weighted_mean_udf @@ -150,9 +140,6 @@ def test_basic(self): self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) def test_unsupported_types(self): - from pyspark.sql.types import DoubleType, MapType - from pyspark.sql.functions import pandas_udf, PandasUDFType - with QuietTest(self.sc): with self.assertRaisesRegexp(NotImplementedError, 'not supported'): pandas_udf( @@ -173,8 +160,6 @@ def mean_and_std_udf(v): return {v.mean(): v.std()} def test_alias(self): - from pyspark.sql.functions import mean - df = self.data mean_udf = self.pandas_agg_mean_udf @@ -187,8 +172,6 @@ def test_mixed_sql(self): """ Test mixing group aggregate pandas UDF with sql expression. """ - from pyspark.sql.functions import sum - df = self.data sum_udf = self.pandas_agg_sum_udf @@ -225,8 +208,6 @@ def test_mixed_udfs(self): """ Test mixing group aggregate pandas UDF with python UDF and scalar pandas UDF. """ - from pyspark.sql.functions import sum - df = self.data plus_one = self.python_plus_one plus_two = self.pandas_scalar_plus_two @@ -292,8 +273,6 @@ def test_multiple_udfs(self): """ Test multiple group aggregate pandas UDFs in one agg function. """ - from pyspark.sql.functions import sum, mean - df = self.data mean_udf = self.pandas_agg_mean_udf sum_udf = self.pandas_agg_sum_udf @@ -315,8 +294,6 @@ def test_multiple_udfs(self): self.assertPandasEqual(expected1, result1) def test_complex_groupby(self): - from pyspark.sql.functions import sum - df = self.data sum_udf = self.pandas_agg_sum_udf plus_one = self.python_plus_one @@ -359,8 +336,6 @@ def test_complex_groupby(self): self.assertPandasEqual(expected7.toPandas(), result7.toPandas()) def test_complex_expressions(self): - from pyspark.sql.functions import col, sum - df = self.data plus_one = self.python_plus_one plus_two = self.pandas_scalar_plus_two @@ -434,7 +409,6 @@ def test_complex_expressions(self): self.assertPandasEqual(expected3, result3) def test_retain_group_columns(self): - from pyspark.sql.functions import sum with self.sql_conf({"spark.sql.retainGroupColumns": False}): df = self.data sum_udf = self.pandas_agg_sum_udf @@ -444,8 +418,6 @@ def test_retain_group_columns(self): self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) def test_array_type(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - df = self.data array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array', PandasUDFType.GROUPED_AGG) @@ -453,8 +425,6 @@ def test_array_type(self): self.assertEquals(result1.first()['v2'], [1.0, 2.0]) def test_invalid_args(self): - from pyspark.sql.functions import mean - df = self.data plus_one = self.python_plus_one mean_udf = self.pandas_agg_mean_udf @@ -478,9 +448,6 @@ def test_invalid_args(self): df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() def test_register_vectorized_udf_basic(self): - from pyspark.sql.functions import pandas_udf - from pyspark.rdd import PythonEvalType - sum_pandas_udf = pandas_udf( lambda v: v.sum(), "integer", PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py index a12c608dff9dd..80e70349b78d3 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py @@ -18,7 +18,12 @@ import datetime import unittest +from collections import OrderedDict +from decimal import Decimal +from distutils.version import LooseVersion + from pyspark.sql import Row +from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType from pyspark.sql.types import * from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message @@ -32,16 +37,12 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase): @property def data(self): - from pyspark.sql.functions import array, explode, col, lit return self.spark.range(10).toDF('id') \ .withColumn("vs", array([lit(i) for i in range(20, 30)])) \ .withColumn("v", explode(col('vs'))).drop('vs') def test_supported_types(self): - from decimal import Decimal - from distutils.version import LooseVersion import pyarrow as pa - from pyspark.sql.functions import pandas_udf, PandasUDFType values = [ 1, 2, 3, @@ -131,8 +132,6 @@ def test_supported_types(self): self.assertPandasEqual(expected3, result3) def test_array_type_correct(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col - df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id") output_schema = StructType( @@ -151,8 +150,6 @@ def test_array_type_correct(self): self.assertPandasEqual(expected, result) def test_register_grouped_map_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP) with QuietTest(self.sc): with self.assertRaisesRegexp( @@ -161,7 +158,6 @@ def test_register_grouped_map_udf(self): self.spark.catalog.registerFunction("foo_udf", foo_udf) def test_decorator(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data @pandas_udf( @@ -176,7 +172,6 @@ def foo(pdf): self.assertPandasEqual(expected, result) def test_coerce(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data foo = pandas_udf( @@ -191,7 +186,6 @@ def test_coerce(self): self.assertPandasEqual(expected, result) def test_complex_groupby(self): - from pyspark.sql.functions import pandas_udf, col, PandasUDFType df = self.data @pandas_udf( @@ -210,7 +204,6 @@ def normalize(pdf): self.assertPandasEqual(expected, result) def test_empty_groupby(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data @pandas_udf( @@ -229,7 +222,6 @@ def normalize(pdf): self.assertPandasEqual(expected, result) def test_datatype_string(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data foo_udf = pandas_udf( @@ -243,8 +235,6 @@ def test_datatype_string(self): self.assertPandasEqual(expected, result) def test_wrong_return_type(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - with QuietTest(self.sc): with self.assertRaisesRegexp( NotImplementedError, @@ -255,7 +245,6 @@ def test_wrong_return_type(self): PandasUDFType.GROUPED_MAP) def test_wrong_args(self): - from pyspark.sql.functions import udf, pandas_udf, sum, PandasUDFType df = self.data with QuietTest(self.sc): @@ -277,9 +266,7 @@ def test_wrong_args(self): pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR)) def test_unsupported_types(self): - from distutils.version import LooseVersion import pyarrow as pa - from pyspark.sql.functions import pandas_udf, PandasUDFType common_err_msg = 'Invalid returnType.*grouped map Pandas UDF.*' unsupported_types = [ @@ -300,7 +287,6 @@ def test_unsupported_types(self): # Regression test for SPARK-23314 def test_timestamp_dst(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am dt = [datetime.datetime(2015, 11, 1, 0, 30), datetime.datetime(2015, 11, 1, 1, 30), @@ -311,12 +297,12 @@ def test_timestamp_dst(self): self.assertPandasEqual(df.toPandas(), result.toPandas()) def test_udf_with_key(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType + import numpy as np + df = self.data pdf = df.toPandas() def foo1(key, pdf): - import numpy as np assert type(key) == tuple assert type(key[0]) == np.int64 @@ -326,7 +312,6 @@ def foo1(key, pdf): v4=pdf.v * pdf.id.mean()) def foo2(key, pdf): - import numpy as np assert type(key) == tuple assert type(key[0]) == np.int64 assert type(key[1]) == np.int32 @@ -385,9 +370,7 @@ def foo3(key, pdf): self.assertPandasEqual(expected4, result4) def test_column_order(self): - from collections import OrderedDict import pandas as pd - from pyspark.sql.functions import pandas_udf, PandasUDFType # Helper function to set column names from a list def rename_pdf(pdf, names): @@ -468,7 +451,6 @@ def invalid_positional_types(pdf): with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, "KeyError: 'id'"): grouped_df.apply(column_name_typo).collect() - from distutils.version import LooseVersion import pyarrow as pa if LooseVersion(pa.__version__) < LooseVersion("0.11.0"): # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0. @@ -480,7 +462,6 @@ def invalid_positional_types(pdf): def test_positional_assignment_conf(self): import pandas as pd - from pyspark.sql.functions import pandas_udf, PandasUDFType with self.sql_conf({ "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}): @@ -496,9 +477,7 @@ def foo(_): self.assertEqual(r.b, 1) def test_self_join_with_pandas(self): - import pyspark.sql.functions as F - - @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP) + @pandas_udf('key long, col string', PandasUDFType.GROUPED_MAP) def dummy_pandas_udf(df): return df[['key', 'col']] @@ -508,12 +487,11 @@ def dummy_pandas_udf(df): # this was throwing an AnalysisException before SPARK-24208 res = df_with_pandas.alias('temp0').join(df_with_pandas.alias('temp1'), - F.col('temp0.key') == F.col('temp1.key')) + col('temp0.key') == col('temp1.key')) self.assertEquals(res.count(), 5) def test_mixed_scalar_udfs_followed_by_grouby_apply(self): import pandas as pd - from pyspark.sql.functions import udf, pandas_udf, PandasUDFType df = self.spark.range(0, 10).toDF('v1') df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \ diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index 2f585a3725988..6a6865a9fb16d 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -16,12 +16,20 @@ # import datetime import os +import random import shutil import sys import tempfile import time import unittest +from datetime import date, datetime +from decimal import Decimal +from distutils.version import LooseVersion + +from pyspark.rdd import PythonEvalType +from pyspark.sql import Column +from pyspark.sql.functions import array, col, expr, lit, sum, udf, pandas_udf from pyspark.sql.types import Row from pyspark.sql.types import * from pyspark.sql.utils import AnalysisException @@ -59,18 +67,16 @@ def tearDownClass(cls): @property def nondeterministic_vectorized_udf(self): - from pyspark.sql.functions import pandas_udf + import pandas as pd + import numpy as np @pandas_udf('double') def random_udf(v): - import pandas as pd - import numpy as np return pd.Series(np.random.random(len(v))) random_udf = random_udf.asNondeterministic() return random_udf def test_pandas_udf_tokenize(self): - from pyspark.sql.functions import pandas_udf tokenize = pandas_udf(lambda s: s.apply(lambda str: str.split(' ')), ArrayType(StringType())) self.assertEqual(tokenize.returnType, ArrayType(StringType())) @@ -79,7 +85,6 @@ def test_pandas_udf_tokenize(self): self.assertEqual([Row(hi=[u'hi', u'boo']), Row(hi=[u'bye', u'boo'])], result.collect()) def test_pandas_udf_nested_arrays(self): - from pyspark.sql.functions import pandas_udf tokenize = pandas_udf(lambda s: s.apply(lambda str: [str.split(' ')]), ArrayType(ArrayType(StringType()))) self.assertEqual(tokenize.returnType, ArrayType(ArrayType(StringType()))) @@ -88,7 +93,6 @@ def test_pandas_udf_nested_arrays(self): self.assertEqual([Row(hi=[[u'hi', u'boo']]), Row(hi=[[u'bye', u'boo']])], result.collect()) def test_vectorized_udf_basic(self): - from pyspark.sql.functions import pandas_udf, col, array df = self.spark.range(10).select( col('id').cast('string').alias('str'), col('id').cast('int').alias('int'), @@ -114,9 +118,6 @@ def test_vectorized_udf_basic(self): self.assertEquals(df.collect(), res.collect()) def test_register_nondeterministic_vectorized_udf_basic(self): - from pyspark.sql.functions import pandas_udf - from pyspark.rdd import PythonEvalType - import random random_pandas_udf = pandas_udf( lambda x: random.randint(6, 6) + x, IntegerType()).asNondeterministic() self.assertEqual(random_pandas_udf.deterministic, False) @@ -129,7 +130,6 @@ def test_register_nondeterministic_vectorized_udf_basic(self): self.assertEqual(row[0], 7) def test_vectorized_udf_null_boolean(self): - from pyspark.sql.functions import pandas_udf, col data = [(True,), (True,), (None,), (False,)] schema = StructType().add("bool", BooleanType()) df = self.spark.createDataFrame(data, schema) @@ -138,7 +138,6 @@ def test_vectorized_udf_null_boolean(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_null_byte(self): - from pyspark.sql.functions import pandas_udf, col data = [(None,), (2,), (3,), (4,)] schema = StructType().add("byte", ByteType()) df = self.spark.createDataFrame(data, schema) @@ -147,7 +146,6 @@ def test_vectorized_udf_null_byte(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_null_short(self): - from pyspark.sql.functions import pandas_udf, col data = [(None,), (2,), (3,), (4,)] schema = StructType().add("short", ShortType()) df = self.spark.createDataFrame(data, schema) @@ -156,7 +154,6 @@ def test_vectorized_udf_null_short(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_null_int(self): - from pyspark.sql.functions import pandas_udf, col data = [(None,), (2,), (3,), (4,)] schema = StructType().add("int", IntegerType()) df = self.spark.createDataFrame(data, schema) @@ -165,7 +162,6 @@ def test_vectorized_udf_null_int(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_null_long(self): - from pyspark.sql.functions import pandas_udf, col data = [(None,), (2,), (3,), (4,)] schema = StructType().add("long", LongType()) df = self.spark.createDataFrame(data, schema) @@ -174,7 +170,6 @@ def test_vectorized_udf_null_long(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_null_float(self): - from pyspark.sql.functions import pandas_udf, col data = [(3.0,), (5.0,), (-1.0,), (None,)] schema = StructType().add("float", FloatType()) df = self.spark.createDataFrame(data, schema) @@ -183,7 +178,6 @@ def test_vectorized_udf_null_float(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_null_double(self): - from pyspark.sql.functions import pandas_udf, col data = [(3.0,), (5.0,), (-1.0,), (None,)] schema = StructType().add("double", DoubleType()) df = self.spark.createDataFrame(data, schema) @@ -192,8 +186,6 @@ def test_vectorized_udf_null_double(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_null_decimal(self): - from decimal import Decimal - from pyspark.sql.functions import pandas_udf, col data = [(Decimal(3.0),), (Decimal(5.0),), (Decimal(-1.0),), (None,)] schema = StructType().add("decimal", DecimalType(38, 18)) df = self.spark.createDataFrame(data, schema) @@ -202,7 +194,6 @@ def test_vectorized_udf_null_decimal(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_null_string(self): - from pyspark.sql.functions import pandas_udf, col data = [("foo",), (None,), ("bar",), ("bar",)] schema = StructType().add("str", StringType()) df = self.spark.createDataFrame(data, schema) @@ -211,7 +202,6 @@ def test_vectorized_udf_null_string(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_string_in_udf(self): - from pyspark.sql.functions import pandas_udf, col import pandas as pd df = self.spark.range(10) str_f = pandas_udf(lambda x: pd.Series(map(str, x)), StringType()) @@ -220,7 +210,6 @@ def test_vectorized_udf_string_in_udf(self): self.assertEquals(expected.collect(), actual.collect()) def test_vectorized_udf_datatype_string(self): - from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10).select( col('id').cast('string').alias('str'), col('id').cast('int').alias('int'), @@ -244,9 +233,8 @@ def test_vectorized_udf_datatype_string(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_null_binary(self): - from distutils.version import LooseVersion import pyarrow as pa - from pyspark.sql.functions import pandas_udf, col + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): with QuietTest(self.sc): with self.assertRaisesRegexp( @@ -262,7 +250,6 @@ def test_vectorized_udf_null_binary(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_array_type(self): - from pyspark.sql.functions import pandas_udf, col data = [([1, 2],), ([3, 4],)] array_schema = StructType([StructField("array", ArrayType(IntegerType()))]) df = self.spark.createDataFrame(data, schema=array_schema) @@ -271,7 +258,6 @@ def test_vectorized_udf_array_type(self): self.assertEquals(df.collect(), result.collect()) def test_vectorized_udf_null_array(self): - from pyspark.sql.functions import pandas_udf, col data = [([1, 2],), (None,), (None,), ([3, 4],), (None,)] array_schema = StructType([StructField("array", ArrayType(IntegerType()))]) df = self.spark.createDataFrame(data, schema=array_schema) @@ -280,7 +266,6 @@ def test_vectorized_udf_null_array(self): self.assertEquals(df.collect(), result.collect()) def test_vectorized_udf_complex(self): - from pyspark.sql.functions import pandas_udf, col, expr df = self.spark.range(10).select( col('id').cast('int').alias('a'), col('id').cast('int').alias('b'), @@ -293,7 +278,6 @@ def test_vectorized_udf_complex(self): self.assertEquals(expected.collect(), res.collect()) def test_vectorized_udf_exception(self): - from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10) raise_exception = pandas_udf(lambda x: x * (1 / 0), LongType()) with QuietTest(self.sc): @@ -301,8 +285,8 @@ def test_vectorized_udf_exception(self): df.select(raise_exception(col('id'))).collect() def test_vectorized_udf_invalid_length(self): - from pyspark.sql.functions import pandas_udf, col import pandas as pd + df = self.spark.range(10) raise_exception = pandas_udf(lambda _: pd.Series(1), LongType()) with QuietTest(self.sc): @@ -312,7 +296,6 @@ def test_vectorized_udf_invalid_length(self): df.select(raise_exception(col('id'))).collect() def test_vectorized_udf_chained(self): - from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10) f = pandas_udf(lambda x: x + 1, LongType()) g = pandas_udf(lambda x: x - 1, LongType()) @@ -320,7 +303,6 @@ def test_vectorized_udf_chained(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_wrong_return_type(self): - from pyspark.sql.functions import pandas_udf with QuietTest(self.sc): with self.assertRaisesRegexp( NotImplementedError, @@ -328,7 +310,6 @@ def test_vectorized_udf_wrong_return_type(self): pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) def test_vectorized_udf_return_scalar(self): - from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10) f = pandas_udf(lambda x: 1.0, DoubleType()) with QuietTest(self.sc): @@ -336,7 +317,6 @@ def test_vectorized_udf_return_scalar(self): df.select(f(col('id'))).collect() def test_vectorized_udf_decorator(self): - from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10) @pandas_udf(returnType=LongType()) @@ -346,21 +326,18 @@ def identity(x): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_empty_partition(self): - from pyspark.sql.functions import pandas_udf, col df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2)) f = pandas_udf(lambda x: x, LongType()) res = df.select(f(col('id'))) self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_varargs(self): - from pyspark.sql.functions import pandas_udf, col df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2)) f = pandas_udf(lambda *v: v[0], LongType()) res = df.select(f(col('id'))) self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_unsupported_types(self): - from pyspark.sql.functions import pandas_udf with QuietTest(self.sc): with self.assertRaisesRegexp( NotImplementedError, @@ -368,8 +345,6 @@ def test_vectorized_udf_unsupported_types(self): pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) def test_vectorized_udf_dates(self): - from pyspark.sql.functions import pandas_udf, col - from datetime import date schema = StructType().add("idx", LongType()).add("date", DateType()) data = [(0, date(1969, 1, 1),), (1, date(2012, 2, 2),), @@ -405,8 +380,6 @@ def check_data(idx, date, date_copy): self.assertIsNone(result[i][3]) # "check_data" col def test_vectorized_udf_timestamps(self): - from pyspark.sql.functions import pandas_udf, col - from datetime import datetime schema = StructType([ StructField("idx", LongType(), True), StructField("timestamp", TimestampType(), True)]) @@ -447,8 +420,8 @@ def check_data(idx, timestamp, timestamp_copy): self.assertIsNone(result[i][3]) # "check_data" col def test_vectorized_udf_return_timestamp_tz(self): - from pyspark.sql.functions import pandas_udf, col import pandas as pd + df = self.spark.range(10) @pandas_udf(returnType=TimestampType()) @@ -465,8 +438,8 @@ def gen_timestamps(id): self.assertEquals(expected, ts) def test_vectorized_udf_check_config(self): - from pyspark.sql.functions import pandas_udf, col import pandas as pd + with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}): df = self.spark.range(10, numPartitions=1) @@ -479,9 +452,8 @@ def check_records_per_batch(x): self.assertTrue(r <= 3) def test_vectorized_udf_timestamps_respect_session_timezone(self): - from pyspark.sql.functions import pandas_udf, col - from datetime import datetime import pandas as pd + schema = StructType([ StructField("idx", LongType(), True), StructField("timestamp", TimestampType(), True)]) @@ -519,8 +491,6 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self): def test_nondeterministic_vectorized_udf(self): # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations - from pyspark.sql.functions import pandas_udf, col - @pandas_udf('double') def plus_ten(v): return v + 10 @@ -533,8 +503,6 @@ def plus_ten(v): self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10)) def test_nondeterministic_vectorized_udf_in_aggregate(self): - from pyspark.sql.functions import sum - df = self.spark.range(10) random_udf = self.nondeterministic_vectorized_udf @@ -545,8 +513,6 @@ def test_nondeterministic_vectorized_udf_in_aggregate(self): df.agg(sum(random_udf(df.id))).collect() def test_register_vectorized_udf_basic(self): - from pyspark.rdd import PythonEvalType - from pyspark.sql.functions import pandas_udf, col, expr df = self.spark.range(10).select( col('id').cast('int').alias('a'), col('id').cast('int').alias('b')) @@ -563,11 +529,10 @@ def test_register_vectorized_udf_basic(self): # Regression test for SPARK-23314 def test_timestamp_dst(self): - from pyspark.sql.functions import pandas_udf # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am - dt = [datetime.datetime(2015, 11, 1, 0, 30), - datetime.datetime(2015, 11, 1, 1, 30), - datetime.datetime(2015, 11, 1, 2, 30)] + dt = [datetime(2015, 11, 1, 0, 30), + datetime(2015, 11, 1, 1, 30), + datetime(2015, 11, 1, 2, 30)] df = self.spark.createDataFrame(dt, 'timestamp').toDF('time') foo_udf = pandas_udf(lambda x: x, 'timestamp') result = df.withColumn('time', foo_udf(df.time)) @@ -593,7 +558,6 @@ def test_type_annotation(self): def test_mixed_udf(self): import pandas as pd - from pyspark.sql.functions import col, udf, pandas_udf df = self.spark.range(0, 1).toDF('v') @@ -696,8 +660,6 @@ def f4(x): def test_mixed_udf_and_sql(self): import pandas as pd - from pyspark.sql import Column - from pyspark.sql.functions import udf, pandas_udf df = self.spark.range(0, 1).toDF('v') @@ -758,7 +720,6 @@ def test_datasource_with_udf(self): # This needs to a separate test because Arrow dependency is optional import pandas as pd import numpy as np - from pyspark.sql.functions import pandas_udf, lit, col path = tempfile.mkdtemp() shutil.rmtree(path) diff --git a/python/pyspark/sql/tests/test_pandas_udf_window.py b/python/pyspark/sql/tests/test_pandas_udf_window.py index f0e6d2696df62..0a7a19c1c0814 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_window.py +++ b/python/pyspark/sql/tests/test_pandas_udf_window.py @@ -18,6 +18,8 @@ import unittest from pyspark.sql.utils import AnalysisException +from pyspark.sql.functions import array, explode, col, lit, mean, min, max, rank, \ + udf, pandas_udf, PandasUDFType from pyspark.sql.window import Window from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message @@ -30,7 +32,6 @@ class WindowPandasUDFTests(ReusedSQLTestCase): @property def data(self): - from pyspark.sql.functions import array, explode, col, lit return self.spark.range(10).toDF('id') \ .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ .withColumn("v", explode(col('vs'))) \ @@ -39,18 +40,14 @@ def data(self): @property def python_plus_one(self): - from pyspark.sql.functions import udf return udf(lambda v: v + 1, 'double') @property def pandas_scalar_time_two(self): - from pyspark.sql.functions import pandas_udf return pandas_udf(lambda v: v * 2, 'double') @property def pandas_agg_mean_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - @pandas_udf('double', PandasUDFType.GROUPED_AGG) def avg(v): return v.mean() @@ -58,8 +55,6 @@ def avg(v): @property def pandas_agg_max_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - @pandas_udf('double', PandasUDFType.GROUPED_AGG) def max(v): return v.max() @@ -67,8 +62,6 @@ def max(v): @property def pandas_agg_min_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - @pandas_udf('double', PandasUDFType.GROUPED_AGG) def min(v): return v.min() @@ -88,8 +81,6 @@ def unpartitioned_window(self): return Window.partitionBy() def test_simple(self): - from pyspark.sql.functions import mean - df = self.data w = self.unbounded_window @@ -105,8 +96,6 @@ def test_simple(self): self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) def test_multiple_udfs(self): - from pyspark.sql.functions import max, min, mean - df = self.data w = self.unbounded_window @@ -121,8 +110,6 @@ def test_multiple_udfs(self): self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) def test_replace_existing(self): - from pyspark.sql.functions import mean - df = self.data w = self.unbounded_window @@ -132,8 +119,6 @@ def test_replace_existing(self): self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) def test_mixed_sql(self): - from pyspark.sql.functions import mean - df = self.data w = self.unbounded_window mean_udf = self.pandas_agg_mean_udf @@ -144,8 +129,6 @@ def test_mixed_sql(self): self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) def test_mixed_udf(self): - from pyspark.sql.functions import mean - df = self.data w = self.unbounded_window @@ -171,8 +154,6 @@ def test_mixed_udf(self): self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) def test_without_partitionBy(self): - from pyspark.sql.functions import mean - df = self.data w = self.unpartitioned_window mean_udf = self.pandas_agg_mean_udf @@ -187,8 +168,6 @@ def test_without_partitionBy(self): self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) def test_mixed_sql_and_udf(self): - from pyspark.sql.functions import max, min, rank, col - df = self.data w = self.unbounded_window ow = self.ordered_window @@ -221,8 +200,6 @@ def test_mixed_sql_and_udf(self): self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) def test_array_type(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - df = self.data w = self.unbounded_window @@ -231,8 +208,6 @@ def test_array_type(self): self.assertEquals(result1.first()['v2'], [1.0, 2.0]) def test_invalid_args(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - df = self.data w = self.unbounded_window ow = self.ordered_window From 9c481c7a6b8019c569a08b2645bf9c19ff84a9e5 Mon Sep 17 00:00:00 2001 From: jasonwayne Date: Fri, 14 Dec 2018 10:47:58 +0800 Subject: [PATCH 2280/2461] [SPARK-26360] remove redundant validateQuery call ## What changes were proposed in this pull request? remove a redundant `KafkaWriter.validateQuery` call in `KafkaSourceProvider ` ## How was this patch tested? Just removing duplicate codes, so I just build and run unit tests. Closes #23309 from JasonWayne/SPARK-26360. Authored-by: jasonwayne Signed-off-by: Hyukjin Kwon --- .../org/apache/spark/sql/kafka010/KafkaSourceProvider.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 6a0c2088ac3d1..4b8b5c0019b44 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -266,8 +266,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister // We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable. val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) - KafkaWriter.validateQuery(schema.toAttributes, producerParams, topic) - new KafkaStreamingWriteSupport(topic, producerParams, schema) } From 93139afb072d14870fb4eab01cb11df28eb0f8dd Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 14 Dec 2018 10:50:48 +0800 Subject: [PATCH 2281/2461] [SPARK-26337][SQL][TEST] Add benchmark for LongToUnsafeRowMap ## What changes were proposed in this pull request? Regarding the performance issue of SPARK-26155, it reports the issue on TPC-DS. I think it is better to add a benchmark for `LongToUnsafeRowMap` which is the root cause of performance regression. It can be easier to show performance difference between different metric implementations in `LongToUnsafeRowMap`. ## How was this patch tested? Manually run added benchmark. Closes #23284 from viirya/SPARK-26337. Authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- ...HashedRelationMetricsBenchmark-results.txt | 11 +++ .../HashedRelationMetricsBenchmark.scala | 84 +++++++++++++++++++ 2 files changed, 95 insertions(+) create mode 100644 sql/core/benchmarks/HashedRelationMetricsBenchmark-results.txt create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HashedRelationMetricsBenchmark.scala diff --git a/sql/core/benchmarks/HashedRelationMetricsBenchmark-results.txt b/sql/core/benchmarks/HashedRelationMetricsBenchmark-results.txt new file mode 100644 index 0000000000000..338244ad542f4 --- /dev/null +++ b/sql/core/benchmarks/HashedRelationMetricsBenchmark-results.txt @@ -0,0 +1,11 @@ +================================================================================================ +LongToUnsafeRowMap metrics +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.13.6 +Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz +LongToUnsafeRowMap metrics: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +LongToUnsafeRowMap 234 / 315 2.1 467.3 1.0X + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HashedRelationMetricsBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HashedRelationMetricsBenchmark.scala new file mode 100644 index 0000000000000..bdf753debe62a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HashedRelationMetricsBenchmark.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.SparkConf +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.internal.config.MEMORY_OFFHEAP_ENABLED +import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection} +import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap +import org.apache.spark.sql.types.LongType + +/** + * Benchmark to measure metrics performance at HashedRelation. + * To run this benchmark: + * {{{ + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/HashedRelationMetricsBenchmark-results.txt". + * }}} + */ +object HashedRelationMetricsBenchmark extends SqlBasedBenchmark { + + def benchmarkLongToUnsafeRowMapMetrics(numRows: Int): Unit = { + runBenchmark("LongToUnsafeRowMap metrics") { + val benchmark = new Benchmark("LongToUnsafeRowMap metrics", numRows, output = output) + benchmark.addCase("LongToUnsafeRowMap") { iter => + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false))) + + val keys = Range.Long(0, numRows, 1) + val map = new LongToUnsafeRowMap(taskMemoryManager, 1) + keys.foreach { k => + map.append(k, unsafeProj(InternalRow(k))) + } + map.optimize() + + val threads = (0 to 100).map { _ => + val thread = new Thread { + override def run: Unit = { + val row = unsafeProj(InternalRow(0L)).copy() + keys.foreach { k => + assert(map.getValue(k, row) eq row) + assert(row.getLong(0) == k) + } + } + } + thread.start() + thread + } + threads.map(_.join()) + map.free() + } + benchmark.run() + } + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + benchmarkLongToUnsafeRowMapMetrics(500000) + } +} From 2d8838dccde6d77b4ff1a15fdd6a0d4da2fda8c7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 13 Dec 2018 20:55:12 -0800 Subject: [PATCH 2282/2461] [SPARK-26368][SQL] Make it clear that getOrInferFileFormatSchema doesn't create InMemoryFileIndex ## What changes were proposed in this pull request? I was looking at the code and it was a bit difficult to see the life cycle of InMemoryFileIndex passed into getOrInferFileFormatSchema, because once it is passed in, and another time it was created in getOrInferFileFormatSchema. It'd be easier to understand the life cycle if we move the creation of it out. ## How was this patch tested? This is a simple code move and should be covered by existing tests. Closes #23317 from rxin/SPARK-26368. Authored-by: Reynold Xin Signed-off-by: gatorsmile --- .../execution/datasources/DataSource.scala | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 795a6d0b6b040..fefff68c4ba8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -122,21 +122,14 @@ case class DataSource( * be any further inference in any triggers. * * @param format the file format object for this DataSource - * @param fileIndex optional [[InMemoryFileIndex]] for getting partition schema and file list + * @param getFileIndex [[InMemoryFileIndex]] for getting partition schema and file list * @return A pair of the data schema (excluding partition columns) and the schema of the partition * columns. */ private def getOrInferFileFormatSchema( format: FileFormat, - fileIndex: Option[InMemoryFileIndex] = None): (StructType, StructType) = { - // The operations below are expensive therefore try not to do them if we don't need to, e.g., - // in streaming mode, we have already inferred and registered partition columns, we will - // never have to materialize the lazy val below - lazy val tempFileIndex = fileIndex.getOrElse { - val globbedPaths = - checkAndGlobPathIfNecessary(checkEmptyGlobPath = false, checkFilesExist = false) - createInMemoryFileIndex(globbedPaths) - } + getFileIndex: () => InMemoryFileIndex): (StructType, StructType) = { + lazy val tempFileIndex = getFileIndex() val partitionSchema = if (partitionColumns.isEmpty) { // Try to infer partitioning, because no DataSource in the read path provides the partitioning @@ -236,7 +229,15 @@ case class DataSource( "you may be able to create a static DataFrame on that directory with " + "'spark.read.load(directory)' and infer schema from it.") } - val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format) + + val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format, () => { + // The operations below are expensive therefore try not to do them if we don't need to, + // e.g., in streaming mode, we have already inferred and registered partition columns, + // we will never have to materialize the lazy val below + val globbedPaths = + checkAndGlobPathIfNecessary(checkEmptyGlobPath = false, checkFilesExist = false) + createInMemoryFileIndex(globbedPaths) + }) SourceInfo( s"FileSource[$path]", StructType(dataSchema ++ partitionSchema), @@ -370,7 +371,7 @@ case class DataSource( } else { val index = createInMemoryFileIndex(globbedPaths) val (resultDataSchema, resultPartitionSchema) = - getOrInferFileFormatSchema(format, Some(index)) + getOrInferFileFormatSchema(format, () => index) (index, resultDataSchema, resultPartitionSchema) } From 3dda58af2b7f42beab736d856bf17b4d35c8866c Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sat, 15 Dec 2018 00:23:28 +0800 Subject: [PATCH 2283/2461] [SPARK-26370][SQL] Fix resolution of higher-order function for the same identifier. ## What changes were proposed in this pull request? When using a higher-order function with the same variable name as the existing columns in `Filter` or something which uses `Analyzer.resolveExpressionBottomUp` during the resolution, e.g.,: ```scala val df = Seq( (Seq(1, 9, 8, 7), 1, 2), (Seq(5, 9, 7), 2, 2), (Seq.empty, 3, 2), (null, 4, 2) ).toDF("i", "x", "d") checkAnswer(df.filter("exists(i, x -> x % d == 0)"), Seq(Row(Seq(1, 9, 8, 7), 1, 2))) checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"), Seq(Row(1))) ``` the following exception happens: ``` java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.BoundReference cannot be cast to org.apache.spark.sql.catalyst.expressions.NamedExpression at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:237) at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62) at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49) at scala.collection.TraversableLike.map(TraversableLike.scala:237) at scala.collection.TraversableLike.map$(TraversableLike.scala:230) at scala.collection.AbstractTraversable.map(Traversable.scala:108) at org.apache.spark.sql.catalyst.expressions.HigherOrderFunction.$anonfun$functionsForEval$1(higherOrderFunctions.scala:147) at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:237) at scala.collection.immutable.List.foreach(List.scala:392) at scala.collection.TraversableLike.map(TraversableLike.scala:237) at scala.collection.TraversableLike.map$(TraversableLike.scala:230) at scala.collection.immutable.List.map(List.scala:298) at org.apache.spark.sql.catalyst.expressions.HigherOrderFunction.functionsForEval(higherOrderFunctions.scala:145) at org.apache.spark.sql.catalyst.expressions.HigherOrderFunction.functionsForEval$(higherOrderFunctions.scala:145) at org.apache.spark.sql.catalyst.expressions.ArrayExists.functionsForEval$lzycompute(higherOrderFunctions.scala:369) at org.apache.spark.sql.catalyst.expressions.ArrayExists.functionsForEval(higherOrderFunctions.scala:369) at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.functionForEval(higherOrderFunctions.scala:176) at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.functionForEval$(higherOrderFunctions.scala:176) at org.apache.spark.sql.catalyst.expressions.ArrayExists.functionForEval(higherOrderFunctions.scala:369) at org.apache.spark.sql.catalyst.expressions.ArrayExists.nullSafeEval(higherOrderFunctions.scala:387) at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.eval(higherOrderFunctions.scala:190) at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.eval$(higherOrderFunctions.scala:185) at org.apache.spark.sql.catalyst.expressions.ArrayExists.eval(higherOrderFunctions.scala:369) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificPredicate.eval(Unknown Source) at org.apache.spark.sql.execution.FilterExec.$anonfun$doExecute$3(basicPhysicalOperators.scala:216) at org.apache.spark.sql.execution.FilterExec.$anonfun$doExecute$3$adapted(basicPhysicalOperators.scala:215) ... ``` because the `UnresolvedAttribute`s in `LambdaFunction` are unexpectedly resolved by the rule. This pr modified to use a placeholder `UnresolvedNamedLambdaVariable` to prevent unexpected resolution. ## How was this patch tested? Added a test and modified some tests. Closes #23320 from ueshin/issues/SPARK-26370/hof_resolution. Authored-by: Takuya UESHIN Signed-off-by: Wenchen Fan --- .../analysis/higherOrderFunctions.scala | 5 ++-- .../expressions/higherOrderFunctions.scala | 26 +++++++++++++++++-- .../sql/catalyst/parser/AstBuilder.scala | 7 +++-- .../ResolveLambdaVariablesSuite.scala | 10 ++++--- ...ReplaceNullWithFalseInPredicateSuite.scala | 14 +++++----- .../parser/ExpressionParserSuite.scala | 6 +++-- .../typeCoercion/native/mapZipWith.sql.out | 4 +-- .../spark/sql/DataFrameFunctionsSuite.scala | 20 ++++++++++++++ 8 files changed, 72 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala index a8a7bbd9f9cd0..1cd7f412bb678 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala @@ -150,13 +150,14 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] { val lambdaMap = l.arguments.map(v => canonicalizer(v.name) -> v).toMap l.mapChildren(resolve(_, parentLambdaMap ++ lambdaMap)) - case u @ UnresolvedAttribute(name +: nestedFields) => + case u @ UnresolvedNamedLambdaVariable(name +: nestedFields) => parentLambdaMap.get(canonicalizer(name)) match { case Some(lambda) => nestedFields.foldLeft(lambda: Expression) { (expr, fieldName) => ExtractValue(expr, Literal(fieldName), conf.resolver) } - case None => u + case None => + UnresolvedAttribute(u.nameParts) } case _ => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index a8639d29f964d..7141b6e996389 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -22,12 +22,34 @@ import java.util.concurrent.atomic.AtomicReference import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedException} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.array.ByteArrayMethods +/** + * A placeholder of lambda variables to prevent unexpected resolution of [[LambdaFunction]]. + */ +case class UnresolvedNamedLambdaVariable(nameParts: Seq[String]) + extends LeafExpression with NamedExpression with Unevaluable { + + override def name: String = + nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") + + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override def qualifier: Seq[String] = throw new UnresolvedException(this, "qualifier") + override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") + override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance") + override lazy val resolved = false + + override def toString: String = s"lambda '$name" + + override def sql: String = name +} + /** * A named lambda variable. */ @@ -79,7 +101,7 @@ case class LambdaFunction( object LambdaFunction { val identity: LambdaFunction = { - val id = UnresolvedAttribute.quoted("id") + val id = UnresolvedNamedLambdaVariable(Seq("id")) LambdaFunction(id, Seq(id)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 672bffcfc0cad..8959f78b656d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1338,9 +1338,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging */ override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) { val arguments = ctx.IDENTIFIER().asScala.map { name => - UnresolvedAttribute.quoted(name.getText) + UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(name.getText).nameParts) } - LambdaFunction(expression(ctx.expression), arguments) + val function = expression(ctx.expression).transformUp { + case a: UnresolvedAttribute => UnresolvedNamedLambdaVariable(a.nameParts) + } + LambdaFunction(function, arguments) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala index c4171c75ecd03..a5847ba7c522d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala @@ -49,19 +49,21 @@ class ResolveLambdaVariablesSuite extends PlanTest { comparePlans(Analyzer.execute(plan(e1)), plan(e2)) } + private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name)) + test("resolution - no op") { checkExpression(key, key) } test("resolution - simple") { - val in = ArrayTransform(values1, LambdaFunction('x.attr + 1, 'x.attr :: Nil)) + val in = ArrayTransform(values1, LambdaFunction(lv('x) + 1, lv('x) :: Nil)) val out = ArrayTransform(values1, LambdaFunction(lvInt + 1, lvInt :: Nil)) checkExpression(in, out) } test("resolution - nested") { val in = ArrayTransform(values2, LambdaFunction( - ArrayTransform('x.attr, LambdaFunction('x.attr + 1, 'x.attr :: Nil)), 'x.attr :: Nil)) + ArrayTransform(lv('x), LambdaFunction(lv('x) + 1, lv('x) :: Nil)), lv('x) :: Nil)) val out = ArrayTransform(values2, LambdaFunction( ArrayTransform(lvArray, LambdaFunction(lvInt + 1, lvInt :: Nil)), lvArray :: Nil)) checkExpression(in, out) @@ -75,14 +77,14 @@ class ResolveLambdaVariablesSuite extends PlanTest { test("fail - name collisions") { val p = plan(ArrayTransform(values1, - LambdaFunction('x.attr + 'X.attr, 'x.attr :: 'X.attr :: Nil))) + LambdaFunction(lv('x) + lv('X), lv('x) :: lv('X) :: Nil))) val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage assert(msg.contains("arguments should not have names that are semantically the same")) } test("fail - lambda arguments") { val p = plan(ArrayTransform(values1, - LambdaFunction('x.attr + 'y.attr + 'z.attr, 'x.attr :: 'y.attr :: 'z.attr :: Nil))) + LambdaFunction(lv('x) + lv('y) + lv('z), lv('x) :: lv('y) :: lv('z) :: Nil))) val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage assert(msg.contains("does not match the number of arguments expected")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index ee0d04da3e46c..748075bfd6a68 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or} +import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} @@ -306,22 +306,24 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testProjection(originalExpr = column, expectedExpr = column) } + private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name)) + test("replace nulls in lambda function of ArrayFilter") { - testHigherOrderFunc('a, ArrayFilter, Seq('e)) + testHigherOrderFunc('a, ArrayFilter, Seq(lv('e))) } test("replace nulls in lambda function of ArrayExists") { - testHigherOrderFunc('a, ArrayExists, Seq('e)) + testHigherOrderFunc('a, ArrayExists, Seq(lv('e))) } test("replace nulls in lambda function of MapFilter") { - testHigherOrderFunc('m, MapFilter, Seq('k, 'v)) + testHigherOrderFunc('m, MapFilter, Seq(lv('k), lv('v))) } test("inability to replace nulls in arbitrary higher-order function") { val lambdaFunc = LambdaFunction( - function = If('e > 0, Literal(null, BooleanType), TrueLiteral), - arguments = Seq[NamedExpression]('e)) + function = If(lv('e) > 0, Literal(null, BooleanType), TrueLiteral), + arguments = Seq[NamedExpression](lv('e))) val column = ArrayTransform('a, lambdaFunc) testProjection(originalExpr = column, expectedExpr = column) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index b4df22c5b29fa..8bcc69d580d83 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -246,9 +246,11 @@ class ExpressionParserSuite extends PlanTest { intercept("foo(a x)", "extraneous input 'x'") } + private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name)) + test("lambda functions") { - assertEqual("x -> x + 1", LambdaFunction('x + 1, Seq('x.attr))) - assertEqual("(x, y) -> x + y", LambdaFunction('x + 'y, Seq('x.attr, 'y.attr))) + assertEqual("x -> x + 1", LambdaFunction(lv('x) + 1, Seq(lv('x)))) + assertEqual("(x, y) -> x + y", LambdaFunction(lv('x) + lv('y), Seq(lv('x), lv('y)))) } test("window function expressions") { diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out index 35740094ba53e..86a578ca013df 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out @@ -85,7 +85,7 @@ FROM various_maps struct<> -- !query 5 output org.apache.spark.sql.AnalysisException -cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7 +cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7 -- !query 6 @@ -113,7 +113,7 @@ FROM various_maps struct<> -- !query 8 output org.apache.spark.sql.AnalysisException -cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7 +cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7 -- !query 9 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index e6d1a038a5918..b7fc9570af919 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2908,6 +2908,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex.getMessage.contains("Cannot use null as map key")) } + + test("SPARK-26370: Fix resolution of higher-order function for the same identifier") { + val df = Seq( + (Seq(1, 9, 8, 7), 1, 2), + (Seq(5, 9, 7), 2, 2), + (Seq.empty, 3, 2), + (null, 4, 2) + ).toDF("i", "x", "d") + + checkAnswer(df.selectExpr("x", "exists(i, x -> x % d == 0)"), + Seq( + Row(1, true), + Row(2, false), + Row(3, false), + Row(4, null))) + checkAnswer(df.filter("exists(i, x -> x % d == 0)"), + Seq(Row(Seq(1, 9, 8, 7), 1, 2))) + checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"), + Seq(Row(1))) + } } object DataFrameFunctionsSuite { From d25e443eec6efc9172eade6ac11be7b3ff04759d Mon Sep 17 00:00:00 2001 From: CarolinPeng <00244106@zte.intra> Date: Fri, 14 Dec 2018 14:23:21 -0600 Subject: [PATCH 2284/2461] [MINOR][SQL] Some errors in the notes. ## What changes were proposed in this pull request? When using ordinals to access linked list, the time cost is O(n). ## How was this patch tested? Existing tests. Closes #23280 from CarolinePeng/update_Two. Authored-by: CarolinPeng <00244106@zte.intra> Signed-off-by: Sean Owen --- .../org/apache/spark/sql/catalyst/expressions/package.scala | 2 +- .../apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 0083ee64653e9..bf18e8bcb52df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -101,7 +101,7 @@ package object expressions { StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) } - // It's possible that `attrs` is a linked list, which can lead to bad O(n^2) loops when + // It's possible that `attrs` is a linked list, which can lead to bad O(n) loops when // accessing attributes by their ordinals. To avoid this performance penalty, convert the input // to an array. @transient private lazy val attrsArray = attrs.toArray diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index a520eba001af1..3ad2ee6923615 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -93,7 +93,7 @@ abstract class LogicalPlan /** * Optionally resolves the given strings to a [[NamedExpression]] using the input from all child * nodes of this LogicalPlan. The attribute is expressed as - * as string in the following form: `[scope].AttributeName.[nested].[fields]...`. + * string in the following form: `[scope].AttributeName.[nested].[fields]...`. */ def resolveChildren( nameParts: Seq[String], From 1b604c1fd0b9ef17b394818fbd6c546bc01cdd8c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 15 Dec 2018 13:52:07 +0800 Subject: [PATCH 2285/2461] [SPARK-26265][CORE][FOLLOWUP] Put freePage into a finally block ## What changes were proposed in this pull request? Based on the [comment](https://github.com/apache/spark/pull/23272#discussion_r240735509), it seems to be better to put `freePage` into a `finally` block. This patch as a follow-up to do so. ## How was this patch tested? Existing tests. Closes #23294 from viirya/SPARK-26265-followup. Authored-by: Liang-Chi Hsieh Signed-off-by: Hyukjin Kwon --- .../spark/unsafe/map/BytesToBytesMap.java | 57 ++++++++++--------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index fbba002f1f80f..7df8aafb2b674 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -262,36 +262,39 @@ private void advanceToNextPage() { // reference to the page to free and free it after releasing the lock of `MapIterator`. MemoryBlock pageToFree = null; - synchronized (this) { - int nextIdx = dataPages.indexOf(currentPage) + 1; - if (destructive && currentPage != null) { - dataPages.remove(currentPage); - pageToFree = currentPage; - nextIdx --; - } - if (dataPages.size() > nextIdx) { - currentPage = dataPages.get(nextIdx); - pageBaseObject = currentPage.getBaseObject(); - offsetInPage = currentPage.getBaseOffset(); - recordsInPage = UnsafeAlignedOffset.getSize(pageBaseObject, offsetInPage); - offsetInPage += UnsafeAlignedOffset.getUaoSize(); - } else { - currentPage = null; - if (reader != null) { - handleFailedDelete(); + try { + synchronized (this) { + int nextIdx = dataPages.indexOf(currentPage) + 1; + if (destructive && currentPage != null) { + dataPages.remove(currentPage); + pageToFree = currentPage; + nextIdx--; } - try { - Closeables.close(reader, /* swallowIOException = */ false); - reader = spillWriters.getFirst().getReader(serializerManager); - recordsInPage = -1; - } catch (IOException e) { - // Scala iterator does not handle exception - Platform.throwException(e); + if (dataPages.size() > nextIdx) { + currentPage = dataPages.get(nextIdx); + pageBaseObject = currentPage.getBaseObject(); + offsetInPage = currentPage.getBaseOffset(); + recordsInPage = UnsafeAlignedOffset.getSize(pageBaseObject, offsetInPage); + offsetInPage += UnsafeAlignedOffset.getUaoSize(); + } else { + currentPage = null; + if (reader != null) { + handleFailedDelete(); + } + try { + Closeables.close(reader, /* swallowIOException = */ false); + reader = spillWriters.getFirst().getReader(serializerManager); + recordsInPage = -1; + } catch (IOException e) { + // Scala iterator does not handle exception + Platform.throwException(e); + } } } - } - if (pageToFree != null) { - freePage(pageToFree); + } finally { + if (pageToFree != null) { + freePage(pageToFree); + } } } From 9ccae0c9e7d1a0a704e8cd7574ba508419e05e30 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Sat, 15 Dec 2018 13:55:24 +0800 Subject: [PATCH 2286/2461] [SPARK-26362][CORE] Remove 'spark.driver.allowMultipleContexts' to disallow multiple creation of SparkContexts ## What changes were proposed in this pull request? Multiple SparkContexts are discouraged and it has been warning for last 4 years, see SPARK-4180. It could cause arbitrary and mysterious error cases, see SPARK-2243. Honestly, I didn't even know Spark still allows it, which looks never officially supported, see SPARK-2243. I believe It should be good timing now to remove this configuration. ## How was this patch tested? Each doc was manually checked and manually tested: ``` $ ./bin/spark-shell --conf=spark.driver.allowMultipleContexts=true ... scala> new SparkContext() org.apache.spark.SparkException: Only one SparkContext should be running in this JVM (see SPARK-2243).The currently running SparkContext was created at: org.apache.spark.sql.SparkSession$Builder.getOrCreate(SparkSession.scala:939) ... org.apache.spark.SparkContext$.$anonfun$assertNoOtherContextIsRunning$2(SparkContext.scala:2435) at scala.Option.foreach(Option.scala:274) at org.apache.spark.SparkContext$.assertNoOtherContextIsRunning(SparkContext.scala:2432) at org.apache.spark.SparkContext$.markPartiallyConstructed(SparkContext.scala:2509) at org.apache.spark.SparkContext.(SparkContext.scala:80) at org.apache.spark.SparkContext.(SparkContext.scala:112) ... 49 elided ``` Closes #23311 from HyukjinKwon/SPARK-26362. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../scala/org/apache/spark/SparkContext.scala | 65 +++++++------------ .../spark/api/java/JavaSparkContext.scala | 4 +- .../org/apache/spark/SparkContextSuite.scala | 19 +----- .../ExternalClusterManagerSuite.scala | 3 +- docs/rdd-programming-guide.md | 2 +- project/MimaExcludes.scala | 4 ++ python/pyspark/context.py | 3 + .../execution/ExchangeCoordinatorSuite.scala | 1 - 8 files changed, 34 insertions(+), 67 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 696dafda6d1ec..09cc346db0ed2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -64,9 +64,8 @@ import org.apache.spark.util.logging.DriverLogger * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark * cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster. * - * Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before - * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details. - * + * @note Only one `SparkContext` should be active per JVM. You must `stop()` the + * active `SparkContext` before creating a new one. * @param config a Spark Config object describing the application configuration. Any settings in * this config overrides the default configs as well as system properties. */ @@ -75,14 +74,10 @@ class SparkContext(config: SparkConf) extends Logging { // The call site where this SparkContext was constructed. private val creationSite: CallSite = Utils.getCallSite() - // If true, log warnings instead of throwing exceptions when multiple SparkContexts are active - private val allowMultipleContexts: Boolean = - config.getBoolean("spark.driver.allowMultipleContexts", false) - // In order to prevent multiple SparkContexts from being active at the same time, mark this // context as having started construction. // NOTE: this must be placed at the beginning of the SparkContext constructor. - SparkContext.markPartiallyConstructed(this, allowMultipleContexts) + SparkContext.markPartiallyConstructed(this) val startTime = System.currentTimeMillis() @@ -2392,7 +2387,7 @@ class SparkContext(config: SparkConf) extends Logging { // In order to prevent multiple SparkContexts from being active at the same time, mark this // context as having finished construction. // NOTE: this must be placed at the end of the SparkContext constructor. - SparkContext.setActiveContext(this, allowMultipleContexts) + SparkContext.setActiveContext(this) } /** @@ -2409,18 +2404,18 @@ object SparkContext extends Logging { private val SPARK_CONTEXT_CONSTRUCTOR_LOCK = new Object() /** - * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `null`. + * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `null`. * - * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK. + * Access to this field is guarded by `SPARK_CONTEXT_CONSTRUCTOR_LOCK`. */ private val activeContext: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null) /** - * Points to a partially-constructed SparkContext if some thread is in the SparkContext + * Points to a partially-constructed SparkContext if another thread is in the SparkContext * constructor, or `None` if no SparkContext is being constructed. * - * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK + * Access to this field is guarded by `SPARK_CONTEXT_CONSTRUCTOR_LOCK`. */ private var contextBeingConstructed: Option[SparkContext] = None @@ -2428,24 +2423,16 @@ object SparkContext extends Logging { * Called to ensure that no other SparkContext is running in this JVM. * * Throws an exception if a running context is detected and logs a warning if another thread is - * constructing a SparkContext. This warning is necessary because the current locking scheme + * constructing a SparkContext. This warning is necessary because the current locking scheme * prevents us from reliably distinguishing between cases where another context is being * constructed and cases where another constructor threw an exception. */ - private def assertNoOtherContextIsRunning( - sc: SparkContext, - allowMultipleContexts: Boolean): Unit = { + private def assertNoOtherContextIsRunning(sc: SparkContext): Unit = { SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { Option(activeContext.get()).filter(_ ne sc).foreach { ctx => - val errMsg = "Only one SparkContext may be running in this JVM (see SPARK-2243)." + - " To ignore this error, set spark.driver.allowMultipleContexts = true. " + + val errMsg = "Only one SparkContext should be running in this JVM (see SPARK-2243)." + s"The currently running SparkContext was created at:\n${ctx.creationSite.longForm}" - val exception = new SparkException(errMsg) - if (allowMultipleContexts) { - logWarning("Multiple running SparkContexts detected in the same JVM!", exception) - } else { - throw exception - } + throw new SparkException(errMsg) } contextBeingConstructed.filter(_ ne sc).foreach { otherContext => @@ -2454,7 +2441,7 @@ object SparkContext extends Logging { val otherContextCreationSite = Option(otherContext.creationSite).map(_.longForm).getOrElse("unknown location") val warnMsg = "Another SparkContext is being constructed (or threw an exception in its" + - " constructor). This may indicate an error, since only one SparkContext may be" + + " constructor). This may indicate an error, since only one SparkContext should be" + " running in this JVM (see SPARK-2243)." + s" The other SparkContext was created at:\n$otherContextCreationSite" logWarning(warnMsg) @@ -2467,8 +2454,6 @@ object SparkContext extends Logging { * singleton object. Because we can only have one active SparkContext per JVM, * this is useful when applications may wish to share a SparkContext. * - * @note This function cannot be used to create multiple SparkContext instances - * even if multiple contexts are allowed. * @param config `SparkConfig` that will be used for initialisation of the `SparkContext` * @return current `SparkContext` (or a new one if it wasn't created before the function call) */ @@ -2477,7 +2462,7 @@ object SparkContext extends Logging { // from assertNoOtherContextIsRunning within setActiveContext SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { if (activeContext.get() == null) { - setActiveContext(new SparkContext(config), allowMultipleContexts = false) + setActiveContext(new SparkContext(config)) } else { if (config.getAll.nonEmpty) { logWarning("Using an existing SparkContext; some configuration may not take effect.") @@ -2494,14 +2479,12 @@ object SparkContext extends Logging { * * This method allows not passing a SparkConf (useful if just retrieving). * - * @note This function cannot be used to create multiple SparkContext instances - * even if multiple contexts are allowed. * @return current `SparkContext` (or a new one if wasn't created before the function call) */ def getOrCreate(): SparkContext = { SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { if (activeContext.get() == null) { - setActiveContext(new SparkContext(), allowMultipleContexts = false) + setActiveContext(new SparkContext()) } activeContext.get() } @@ -2516,16 +2499,14 @@ object SparkContext extends Logging { /** * Called at the beginning of the SparkContext constructor to ensure that no SparkContext is - * running. Throws an exception if a running context is detected and logs a warning if another - * thread is constructing a SparkContext. This warning is necessary because the current locking + * running. Throws an exception if a running context is detected and logs a warning if another + * thread is constructing a SparkContext. This warning is necessary because the current locking * scheme prevents us from reliably distinguishing between cases where another context is being * constructed and cases where another constructor threw an exception. */ - private[spark] def markPartiallyConstructed( - sc: SparkContext, - allowMultipleContexts: Boolean): Unit = { + private[spark] def markPartiallyConstructed(sc: SparkContext): Unit = { SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { - assertNoOtherContextIsRunning(sc, allowMultipleContexts) + assertNoOtherContextIsRunning(sc) contextBeingConstructed = Some(sc) } } @@ -2534,18 +2515,16 @@ object SparkContext extends Logging { * Called at the end of the SparkContext constructor to ensure that no other SparkContext has * raced with this constructor and started. */ - private[spark] def setActiveContext( - sc: SparkContext, - allowMultipleContexts: Boolean): Unit = { + private[spark] def setActiveContext(sc: SparkContext): Unit = { SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { - assertNoOtherContextIsRunning(sc, allowMultipleContexts) + assertNoOtherContextIsRunning(sc) contextBeingConstructed = None activeContext.set(sc) } } /** - * Clears the active SparkContext metadata. This is called by `SparkContext#stop()`. It's + * Clears the active SparkContext metadata. This is called by `SparkContext#stop()`. It's * also called in unit tests to prevent a flood of warnings from test suites that don't / can't * properly clean up their SparkContexts. */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 03f259d73e975..2f74d09b3a2bc 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -40,8 +40,8 @@ import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD} * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns * [[org.apache.spark.api.java.JavaRDD]]s and works with Java collections instead of Scala ones. * - * Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before - * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details. + * @note Only one `SparkContext` should be active per JVM. You must `stop()` the + * active `SparkContext` before creating a new one. */ class JavaSparkContext(val sc: SparkContext) extends Closeable { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index ec4c7efb5835a..66de2f2ac86a4 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -44,7 +44,6 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu test("Only one SparkContext may be active at a time") { // Regression test for SPARK-4180 val conf = new SparkConf().setAppName("test").setMaster("local") - .set("spark.driver.allowMultipleContexts", "false") sc = new SparkContext(conf) val envBefore = SparkEnv.get // A SparkContext is already running, so we shouldn't be able to create a second one @@ -58,7 +57,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } test("Can still construct a new SparkContext after failing to construct a previous one") { - val conf = new SparkConf().set("spark.driver.allowMultipleContexts", "false") + val conf = new SparkConf() // This is an invalid configuration (no app name or master URL) intercept[SparkException] { new SparkContext(conf) @@ -67,18 +66,6 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu sc = new SparkContext(conf.setMaster("local").setAppName("test")) } - test("Check for multiple SparkContexts can be disabled via undocumented debug option") { - var secondSparkContext: SparkContext = null - try { - val conf = new SparkConf().setAppName("test").setMaster("local") - .set("spark.driver.allowMultipleContexts", "true") - sc = new SparkContext(conf) - secondSparkContext = new SparkContext(conf) - } finally { - Option(secondSparkContext).foreach(_.stop()) - } - } - test("Test getOrCreate") { var sc2: SparkContext = null SparkContext.clearActiveContext() @@ -92,10 +79,6 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu assert(sc === sc2) assert(sc eq sc2) - // Try creating second context to confirm that it's still possible, if desired - sc2 = new SparkContext(new SparkConf().setAppName("test3").setMaster("local") - .set("spark.driver.allowMultipleContexts", "true")) - sc2.stop() } diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index 0621c98d41184..30d0966691a3c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -25,8 +25,7 @@ import org.apache.spark.util.AccumulatorV2 class ExternalClusterManagerSuite extends SparkFunSuite with LocalSparkContext { test("launch of backend and scheduler") { - val conf = new SparkConf().setMaster("myclusterManager"). - setAppName("testcm").set("spark.driver.allowMultipleContexts", "true") + val conf = new SparkConf().setMaster("myclusterManager").setAppName("testcm") sc = new SparkContext(conf) // check if the scheduler components are created and initialized sc.schedulerBackend match { diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index 2d1ddae5780de..308a8ea653909 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -138,7 +138,7 @@ The first thing a Spark program must do is to create a [SparkContext](api/scala/ how to access a cluster. To create a `SparkContext` you first need to build a [SparkConf](api/scala/index.html#org.apache.spark.SparkConf) object that contains information about your application. -Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before creating a new one. +Only one SparkContext should be active per JVM. You must `stop()` the active SparkContext before creating a new one. {% highlight scala %} val conf = new SparkConf().setAppName(appName).setMaster(master) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 883913332ca1e..7bb70a29195d6 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -220,6 +220,10 @@ object MimaExcludes { // [SPARK-26139] Implement shuffle write metrics in SQL ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ShuffleDependency.this"), + // [SPARK-26362][CORE] Remove 'spark.driver.allowMultipleContexts' to disallow multiple creation of SparkContexts + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.setActiveContext"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.markPartiallyConstructed"), + // Data Source V2 API changes (problem: Problem) => problem match { case MissingClassProblem(cls) => diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 1180bf91baa5a..6137ed25a0dd9 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -63,6 +63,9 @@ class SparkContext(object): Main entry point for Spark functionality. A SparkContext represents the connection to a Spark cluster, and can be used to create L{RDD} and broadcast variables on that cluster. + + .. note:: Only one :class:`SparkContext` should be active per JVM. You must `stop()` + the active :class:`SparkContext` before creating a new one. """ _gateway = None diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 6ad025f37e440..4a439940beb74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -263,7 +263,6 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { .setMaster("local[*]") .setAppName("test") .set("spark.ui.enabled", "false") - .set("spark.driver.allowMultipleContexts", "true") .set(SQLConf.SHUFFLE_PARTITIONS.key, "5") .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true") .set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1") From 860f4497f2a59b21d455ec8bfad9ae15d2fd4d2e Mon Sep 17 00:00:00 2001 From: Jing Chen He Date: Sat, 15 Dec 2018 08:41:16 -0600 Subject: [PATCH 2287/2461] [SPARK-26315][PYSPARK] auto cast threshold from Integer to Float in approxSimilarityJoin of BucketedRandomProjectionLSHModel ## What changes were proposed in this pull request? If the input parameter 'threshold' to the function approxSimilarityJoin is not a float, we would get an exception. The fix is to convert the 'threshold' into a float before calling the java implementation method. ## How was this patch tested? Added a new test case. Without this fix, the test will throw an exception as reported in the JIRA. With the fix, the test passes. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #23313 from jerryjch/SPARK-26315. Authored-by: Jing Chen He Signed-off-by: Sean Owen --- python/pyspark/ml/feature.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index c9507c20918e3..08ae58246adb6 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -192,6 +192,7 @@ def approxSimilarityJoin(self, datasetA, datasetB, threshold, distCol="distCol") "datasetA" and "datasetB", and a column "distCol" is added to show the distance between each pair. """ + threshold = TypeConverters.toFloat(threshold) return self._call_java("approxSimilarityJoin", datasetA, datasetB, threshold, distCol) @@ -239,6 +240,16 @@ class BucketedRandomProjectionLSH(JavaEstimator, LSHParams, HasInputCol, HasOutp | 3| 6| 2.23606797749979| +---+---+-----------------+ ... + >>> model.approxSimilarityJoin(df, df2, 3, distCol="EuclideanDistance").select( + ... col("datasetA.id").alias("idA"), + ... col("datasetB.id").alias("idB"), + ... col("EuclideanDistance")).show() + +---+---+-----------------+ + |idA|idB|EuclideanDistance| + +---+---+-----------------+ + | 3| 6| 2.23606797749979| + +---+---+-----------------+ + ... >>> brpPath = temp_path + "/brp" >>> brp.save(brpPath) >>> brp2 = BucketedRandomProjectionLSH.load(brpPath) From 8a27952cdbf492939d9bda59e2f516f574581636 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 16 Dec 2018 09:32:13 +0800 Subject: [PATCH 2288/2461] [SPARK-26243][SQL] Use java.time API for parsing timestamps and dates from JSON ## What changes were proposed in this pull request? In the PR, I propose to switch on **java.time API** for parsing timestamps and dates from JSON inputs with microseconds precision. The SQL config `spark.sql.legacy.timeParser.enabled` allow to switch back to previous behavior with using `java.text.SimpleDateFormat`/`FastDateFormat` for parsing/generating timestamps/dates. ## How was this patch tested? It was tested by `JsonExpressionsSuite`, `JsonFunctionsSuite` and `JsonSuite`. Closes #23196 from MaxGekk/json-time-parser. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: Wenchen Fan --- docs/sql-migration-guide-upgrade.md | 2 +- .../sql/catalyst/csv/CSVInferSchema.scala | 6 +- .../sql/catalyst/csv/UnivocityGenerator.scala | 8 +- .../sql/catalyst/csv/UnivocityParser.scala | 6 +- .../spark/sql/catalyst/json/JSONOptions.scala | 10 +- .../sql/catalyst/json/JacksonGenerator.scala | 14 +- .../sql/catalyst/json/JacksonParser.scala | 35 +-- ...rmatter.scala => TimestampFormatter.scala} | 93 ++++---- .../sql/util/DateTimeFormatterSuite.scala | 103 --------- .../util/DateTimestampFormatterSuite.scala | 174 +++++++++++++++ .../datasources/json/JsonSuite.scala | 201 ++++++++++-------- .../sql/sources/HadoopFsRelationTest.scala | 105 ++++----- 12 files changed, 422 insertions(+), 335 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/{DateTimeFormatter.scala => TimestampFormatter.scala} (63%) delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimeFormatterSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimestampFormatterSuite.scala diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 8834e8991d8c3..115fc6516fb4c 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -35,7 +35,7 @@ displayTitle: Spark SQL Upgrading Guide - Spark applications which are built with Spark version 2.4 and prior, and call methods of `UserDefinedFunction`, need to be re-compiled with Spark 3.0, as they are not binary compatible with Spark 3.0. - - Since Spark 3.0, CSV datasource uses java.time API for parsing and generating CSV content. New formatting implementation supports date/timestamp patterns conformed to ISO 8601. To switch back to the implementation used in Spark 2.4 and earlier, set `spark.sql.legacy.timeParser.enabled` to `true`. + - Since Spark 3.0, CSV/JSON datasources use java.time API for parsing and generating CSV/JSON content. In Spark version 2.4 and earlier, java.text.SimpleDateFormat is used for the same purpuse with fallbacks to the parsing mechanisms of Spark 2.0 and 1.x. For example, `2018-12-08 10:39:21.123` with the pattern `yyyy-MM-dd'T'HH:mm:ss.SSS` cannot be parsed since Spark 3.0 because the timestamp does not match to the pattern but it can be parsed by earlier Spark versions due to a fallback to `Timestamp.valueOf`. To parse the same timestamp since Spark 3.0, the pattern should be `yyyy-MM-dd HH:mm:ss.SSS`. To switch back to the implementation used in Spark 2.4 and earlier, set `spark.sql.legacy.timeParser.enabled` to `true`. - In Spark version 2.4 and earlier, CSV datasource converts a malformed CSV string to a row with all `null`s in the PERMISSIVE mode. Since Spark 3.0, the returned row can contain non-`null` fields if some of CSV column values were parsed and converted to desired types successfully. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index 345dc4d41993e..35ade136cc607 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -22,13 +22,13 @@ import scala.util.control.Exception.allCatch import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.expressions.ExprUtils -import org.apache.spark.sql.catalyst.util.DateTimeFormatter +import org.apache.spark.sql.catalyst.util.TimestampFormatter import org.apache.spark.sql.types._ class CSVInferSchema(val options: CSVOptions) extends Serializable { @transient - private lazy val timeParser = DateTimeFormatter( + private lazy val timestampParser = TimestampFormatter( options.timestampFormat, options.timeZone, options.locale) @@ -160,7 +160,7 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable { private def tryParseTimestamp(field: String): DataType = { // This case infers a custom `dataFormat` is set. - if ((allCatch opt timeParser.parse(field)).isDefined) { + if ((allCatch opt timestampParser.parse(field)).isDefined) { TimestampType } else { tryParseBoolean(field) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala index af09cd6c8449b..f012d96138f37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala @@ -22,7 +22,7 @@ import java.io.Writer import com.univocity.parsers.csv.CsvWriter import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeFormatter} +import org.apache.spark.sql.catalyst.util.{DateFormatter, TimestampFormatter} import org.apache.spark.sql.types._ class UnivocityGenerator( @@ -41,18 +41,18 @@ class UnivocityGenerator( private val valueConverters: Array[ValueConverter] = schema.map(_.dataType).map(makeConverter).toArray - private val timeFormatter = DateTimeFormatter( + private val timestampFormatter = TimestampFormatter( options.timestampFormat, options.timeZone, options.locale) - private val dateFormatter = DateFormatter(options.dateFormat, options.timeZone, options.locale) + private val dateFormatter = DateFormatter(options.dateFormat, options.locale) private def makeConverter(dataType: DataType): ValueConverter = dataType match { case DateType => (row: InternalRow, ordinal: Int) => dateFormatter.format(row.getInt(ordinal)) case TimestampType => - (row: InternalRow, ordinal: Int) => timeFormatter.format(row.getLong(ordinal)) + (row: InternalRow, ordinal: Int) => timestampFormatter.format(row.getLong(ordinal)) case udt: UserDefinedType[_] => makeConverter(udt.sqlType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index 0f375e036029c..ed089120055e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -74,11 +74,11 @@ class UnivocityParser( private val row = new GenericInternalRow(requiredSchema.length) - private val timeFormatter = DateTimeFormatter( + private val timestampFormatter = TimestampFormatter( options.timestampFormat, options.timeZone, options.locale) - private val dateFormatter = DateFormatter(options.dateFormat, options.timeZone, options.locale) + private val dateFormatter = DateFormatter(options.dateFormat, options.locale) // Retrieve the raw record string. private def getCurrentInput: UTF8String = { @@ -158,7 +158,7 @@ class UnivocityParser( } case _: TimestampType => (d: String) => - nullSafeDatum(d, name, nullable, options)(timeFormatter.parse) + nullSafeDatum(d, name, nullable, options)(timestampFormatter.parse) case _: DateType => (d: String) => nullSafeDatum(d, name, nullable, options)(dateFormatter.parse) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index e10b8a327c01a..eaff3fa7bec25 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -21,7 +21,6 @@ import java.nio.charset.{Charset, StandardCharsets} import java.util.{Locale, TimeZone} import com.fasterxml.jackson.core.{JsonFactory, JsonParser} -import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util._ @@ -82,13 +81,10 @@ private[sql] class JSONOptions( val timeZone: TimeZone = DateTimeUtils.getTimeZone( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) - // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. - val dateFormat: FastDateFormat = - FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), locale) + val dateFormat: String = parameters.getOrElse("dateFormat", "yyyy-MM-dd") - val timestampFormat: FastDateFormat = - FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, locale) + val timestampFormat: String = + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX") val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index d02a2be8ddad6..951f5190cd504 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -23,7 +23,7 @@ import com.fasterxml.jackson.core._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters -import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, MapData} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ /** @@ -77,6 +77,12 @@ private[sql] class JacksonGenerator( private val lineSeparator: String = options.lineSeparatorInWrite + private val timestampFormatter = TimestampFormatter( + options.timestampFormat, + options.timeZone, + options.locale) + private val dateFormatter = DateFormatter(options.dateFormat, options.locale) + private def makeWriter(dataType: DataType): ValueWriter = dataType match { case NullType => (row: SpecializedGetters, ordinal: Int) => @@ -116,14 +122,12 @@ private[sql] class JacksonGenerator( case TimestampType => (row: SpecializedGetters, ordinal: Int) => - val timestampString = - options.timestampFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal))) + val timestampString = timestampFormatter.format(row.getLong(ordinal)) gen.writeString(timestampString) case DateType => (row: SpecializedGetters, ordinal: Int) => - val dateString = - options.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal))) + val dateString = dateFormatter.format(row.getInt(ordinal)) gen.writeString(dateString) case BinaryType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 7e3bd4df51bb7..3f245e1400fa1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -55,6 +55,12 @@ class JacksonParser( private val factory = new JsonFactory() options.setJacksonOptions(factory) + private val timestampFormatter = TimestampFormatter( + options.timestampFormat, + options.timeZone, + options.locale) + private val dateFormatter = DateFormatter(options.dateFormat, options.locale) + /** * Create a converter which converts the JSON documents held by the `JsonParser` * to a value according to a desired schema. This is a wrapper for the method @@ -218,17 +224,7 @@ class JacksonParser( case TimestampType => (parser: JsonParser) => parseJsonToken[java.lang.Long](parser, dataType) { case VALUE_STRING if parser.getTextLength >= 1 => - val stringValue = parser.getText - // This one will lose microseconds parts. - // See https://issues.apache.org/jira/browse/SPARK-10681. - Long.box { - Try(options.timestampFormat.parse(stringValue).getTime * 1000L) - .getOrElse { - // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards - // compatibility. - DateTimeUtils.stringToTime(stringValue).getTime * 1000L - } - } + timestampFormatter.parse(parser.getText) case VALUE_NUMBER_INT => parser.getLongValue * 1000000L @@ -237,22 +233,7 @@ class JacksonParser( case DateType => (parser: JsonParser) => parseJsonToken[java.lang.Integer](parser, dataType) { case VALUE_STRING if parser.getTextLength >= 1 => - val stringValue = parser.getText - // This one will lose microseconds parts. - // See https://issues.apache.org/jira/browse/SPARK-10681.x - Int.box { - Try(DateTimeUtils.millisToDays(options.dateFormat.parse(stringValue).getTime)) - .orElse { - // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards - // compatibility. - Try(DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(stringValue).getTime)) - } - .getOrElse { - // In Spark 1.5.0, we store the data as number of days since epoch in string. - // So, we just convert it to Int. - stringValue.toInt - } - } + dateFormatter.parse(parser.getText) } case BinaryType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala similarity index 63% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatter.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala index ad1f4131de2f6..2b8d22dde9267 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.util import java.time._ import java.time.format.DateTimeFormatterBuilder -import java.time.temporal.{ChronoField, TemporalQueries} +import java.time.temporal.{ChronoField, TemporalAccessor, TemporalQueries} import java.util.{Locale, TimeZone} import scala.util.Try @@ -28,31 +28,44 @@ import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.sql.internal.SQLConf -sealed trait DateTimeFormatter { +sealed trait TimestampFormatter { def parse(s: String): Long // returns microseconds since epoch def format(us: Long): String } -class Iso8601DateTimeFormatter( +trait FormatterUtils { + protected def zoneId: ZoneId + protected def buildFormatter( + pattern: String, + locale: Locale): java.time.format.DateTimeFormatter = { + new DateTimeFormatterBuilder() + .appendPattern(pattern) + .parseDefaulting(ChronoField.YEAR_OF_ERA, 1970) + .parseDefaulting(ChronoField.MONTH_OF_YEAR, 1) + .parseDefaulting(ChronoField.DAY_OF_MONTH, 1) + .parseDefaulting(ChronoField.HOUR_OF_DAY, 0) + .parseDefaulting(ChronoField.MINUTE_OF_HOUR, 0) + .parseDefaulting(ChronoField.SECOND_OF_MINUTE, 0) + .toFormatter(locale) + } + protected def toInstantWithZoneId(temporalAccessor: TemporalAccessor): java.time.Instant = { + val localDateTime = LocalDateTime.from(temporalAccessor) + val zonedDateTime = ZonedDateTime.of(localDateTime, zoneId) + Instant.from(zonedDateTime) + } +} + +class Iso8601TimestampFormatter( pattern: String, timeZone: TimeZone, - locale: Locale) extends DateTimeFormatter { - val formatter = new DateTimeFormatterBuilder() - .appendPattern(pattern) - .parseDefaulting(ChronoField.YEAR_OF_ERA, 1970) - .parseDefaulting(ChronoField.MONTH_OF_YEAR, 1) - .parseDefaulting(ChronoField.DAY_OF_MONTH, 1) - .parseDefaulting(ChronoField.HOUR_OF_DAY, 0) - .parseDefaulting(ChronoField.MINUTE_OF_HOUR, 0) - .parseDefaulting(ChronoField.SECOND_OF_MINUTE, 0) - .toFormatter(locale) + locale: Locale) extends TimestampFormatter with FormatterUtils { + val zoneId = timeZone.toZoneId + val formatter = buildFormatter(pattern, locale) def toInstant(s: String): Instant = { val temporalAccessor = formatter.parse(s) if (temporalAccessor.query(TemporalQueries.offset()) == null) { - val localDateTime = LocalDateTime.from(temporalAccessor) - val zonedDateTime = ZonedDateTime.of(localDateTime, timeZone.toZoneId) - Instant.from(zonedDateTime) + toInstantWithZoneId(temporalAccessor) } else { Instant.from(temporalAccessor) } @@ -75,10 +88,10 @@ class Iso8601DateTimeFormatter( } } -class LegacyDateTimeFormatter( +class LegacyTimestampFormatter( pattern: String, timeZone: TimeZone, - locale: Locale) extends DateTimeFormatter { + locale: Locale) extends TimestampFormatter { val format = FastDateFormat.getInstance(pattern, timeZone, locale) protected def toMillis(s: String): Long = format.parse(s).getTime @@ -90,21 +103,21 @@ class LegacyDateTimeFormatter( } } -class LegacyFallbackDateTimeFormatter( +class LegacyFallbackTimestampFormatter( pattern: String, timeZone: TimeZone, - locale: Locale) extends LegacyDateTimeFormatter(pattern, timeZone, locale) { + locale: Locale) extends LegacyTimestampFormatter(pattern, timeZone, locale) { override def toMillis(s: String): Long = { Try {super.toMillis(s)}.getOrElse(DateTimeUtils.stringToTime(s).getTime) } } -object DateTimeFormatter { - def apply(format: String, timeZone: TimeZone, locale: Locale): DateTimeFormatter = { +object TimestampFormatter { + def apply(format: String, timeZone: TimeZone, locale: Locale): TimestampFormatter = { if (SQLConf.get.legacyTimeParserEnabled) { - new LegacyFallbackDateTimeFormatter(format, timeZone, locale) + new LegacyFallbackTimestampFormatter(format, timeZone, locale) } else { - new Iso8601DateTimeFormatter(format, timeZone, locale) + new Iso8601TimestampFormatter(format, timeZone, locale) } } } @@ -116,13 +129,19 @@ sealed trait DateFormatter { class Iso8601DateFormatter( pattern: String, - timeZone: TimeZone, - locale: Locale) extends DateFormatter { + locale: Locale) extends DateFormatter with FormatterUtils { + + val zoneId = ZoneId.of("UTC") + + val formatter = buildFormatter(pattern, locale) - val dateTimeFormatter = new Iso8601DateTimeFormatter(pattern, timeZone, locale) + def toInstant(s: String): Instant = { + val temporalAccessor = formatter.parse(s) + toInstantWithZoneId(temporalAccessor) + } override def parse(s: String): Int = { - val seconds = dateTimeFormatter.toInstant(s).getEpochSecond + val seconds = toInstant(s).getEpochSecond val days = Math.floorDiv(seconds, DateTimeUtils.SECONDS_PER_DAY) days.toInt @@ -130,15 +149,12 @@ class Iso8601DateFormatter( override def format(days: Int): String = { val instant = Instant.ofEpochSecond(days * DateTimeUtils.SECONDS_PER_DAY) - dateTimeFormatter.formatter.withZone(timeZone.toZoneId).format(instant) + formatter.withZone(zoneId).format(instant) } } -class LegacyDateFormatter( - pattern: String, - timeZone: TimeZone, - locale: Locale) extends DateFormatter { - val format = FastDateFormat.getInstance(pattern, timeZone, locale) +class LegacyDateFormatter(pattern: String, locale: Locale) extends DateFormatter { + val format = FastDateFormat.getInstance(pattern, locale) def parse(s: String): Int = { val milliseconds = format.parse(s).getTime @@ -153,8 +169,7 @@ class LegacyDateFormatter( class LegacyFallbackDateFormatter( pattern: String, - timeZone: TimeZone, - locale: Locale) extends LegacyDateFormatter(pattern, timeZone, locale) { + locale: Locale) extends LegacyDateFormatter(pattern, locale) { override def parse(s: String): Int = { Try(super.parse(s)).orElse { // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards @@ -169,11 +184,11 @@ class LegacyFallbackDateFormatter( } object DateFormatter { - def apply(format: String, timeZone: TimeZone, locale: Locale): DateFormatter = { + def apply(format: String, locale: Locale): DateFormatter = { if (SQLConf.get.legacyTimeParserEnabled) { - new LegacyFallbackDateFormatter(format, timeZone, locale) + new LegacyFallbackDateFormatter(format, locale) } else { - new Iso8601DateFormatter(format, timeZone, locale) + new Iso8601DateFormatter(format, locale) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimeFormatterSuite.scala deleted file mode 100644 index 02d4ee0490604..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimeFormatterSuite.scala +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.util - -import java.util.{Locale, TimeZone} - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeFormatter, DateTimeTestUtils} - -class DateTimeFormatterSuite extends SparkFunSuite { - test("parsing dates using time zones") { - val localDate = "2018-12-02" - val expectedDays = Map( - "UTC" -> 17867, - "PST" -> 17867, - "CET" -> 17866, - "Africa/Dakar" -> 17867, - "America/Los_Angeles" -> 17867, - "Antarctica/Vostok" -> 17866, - "Asia/Hong_Kong" -> 17866, - "Europe/Amsterdam" -> 17866) - DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => - val formatter = DateFormatter("yyyy-MM-dd", TimeZone.getTimeZone(timeZone), Locale.US) - val daysSinceEpoch = formatter.parse(localDate) - assert(daysSinceEpoch === expectedDays(timeZone)) - } - } - - test("parsing timestamps using time zones") { - val localDate = "2018-12-02T10:11:12.001234" - val expectedMicros = Map( - "UTC" -> 1543745472001234L, - "PST" -> 1543774272001234L, - "CET" -> 1543741872001234L, - "Africa/Dakar" -> 1543745472001234L, - "America/Los_Angeles" -> 1543774272001234L, - "Antarctica/Vostok" -> 1543723872001234L, - "Asia/Hong_Kong" -> 1543716672001234L, - "Europe/Amsterdam" -> 1543741872001234L) - DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => - val formatter = DateTimeFormatter( - "yyyy-MM-dd'T'HH:mm:ss.SSSSSS", - TimeZone.getTimeZone(timeZone), - Locale.US) - val microsSinceEpoch = formatter.parse(localDate) - assert(microsSinceEpoch === expectedMicros(timeZone)) - } - } - - test("format dates using time zones") { - val daysSinceEpoch = 17867 - val expectedDate = Map( - "UTC" -> "2018-12-02", - "PST" -> "2018-12-01", - "CET" -> "2018-12-02", - "Africa/Dakar" -> "2018-12-02", - "America/Los_Angeles" -> "2018-12-01", - "Antarctica/Vostok" -> "2018-12-02", - "Asia/Hong_Kong" -> "2018-12-02", - "Europe/Amsterdam" -> "2018-12-02") - DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => - val formatter = DateFormatter("yyyy-MM-dd", TimeZone.getTimeZone(timeZone), Locale.US) - val date = formatter.format(daysSinceEpoch) - assert(date === expectedDate(timeZone)) - } - } - - test("format timestamps using time zones") { - val microsSinceEpoch = 1543745472001234L - val expectedTimestamp = Map( - "UTC" -> "2018-12-02T10:11:12.001234", - "PST" -> "2018-12-02T02:11:12.001234", - "CET" -> "2018-12-02T11:11:12.001234", - "Africa/Dakar" -> "2018-12-02T10:11:12.001234", - "America/Los_Angeles" -> "2018-12-02T02:11:12.001234", - "Antarctica/Vostok" -> "2018-12-02T16:11:12.001234", - "Asia/Hong_Kong" -> "2018-12-02T18:11:12.001234", - "Europe/Amsterdam" -> "2018-12-02T11:11:12.001234") - DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => - val formatter = DateTimeFormatter( - "yyyy-MM-dd'T'HH:mm:ss.SSSSSS", - TimeZone.getTimeZone(timeZone), - Locale.US) - val timestamp = formatter.format(microsSinceEpoch) - assert(timestamp === expectedTimestamp(timeZone)) - } - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimestampFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimestampFormatterSuite.scala new file mode 100644 index 0000000000000..43e348c7eebf4 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimestampFormatterSuite.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import java.util.{Locale, TimeZone} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf + +class DateTimestampFormatterSuite extends SparkFunSuite with SQLHelper { + test("parsing dates") { + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { + val formatter = DateFormatter("yyyy-MM-dd", Locale.US) + val daysSinceEpoch = formatter.parse("2018-12-02") + assert(daysSinceEpoch === 17867) + } + } + } + + test("parsing timestamps using time zones") { + val localDate = "2018-12-02T10:11:12.001234" + val expectedMicros = Map( + "UTC" -> 1543745472001234L, + "PST" -> 1543774272001234L, + "CET" -> 1543741872001234L, + "Africa/Dakar" -> 1543745472001234L, + "America/Los_Angeles" -> 1543774272001234L, + "Antarctica/Vostok" -> 1543723872001234L, + "Asia/Hong_Kong" -> 1543716672001234L, + "Europe/Amsterdam" -> 1543741872001234L) + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + val formatter = TimestampFormatter( + "yyyy-MM-dd'T'HH:mm:ss.SSSSSS", + TimeZone.getTimeZone(timeZone), + Locale.US) + val microsSinceEpoch = formatter.parse(localDate) + assert(microsSinceEpoch === expectedMicros(timeZone)) + } + } + + test("format dates") { + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { + val formatter = DateFormatter("yyyy-MM-dd", Locale.US) + val date = formatter.format(17867) + assert(date === "2018-12-02") + } + } + } + + test("format timestamps using time zones") { + val microsSinceEpoch = 1543745472001234L + val expectedTimestamp = Map( + "UTC" -> "2018-12-02T10:11:12.001234", + "PST" -> "2018-12-02T02:11:12.001234", + "CET" -> "2018-12-02T11:11:12.001234", + "Africa/Dakar" -> "2018-12-02T10:11:12.001234", + "America/Los_Angeles" -> "2018-12-02T02:11:12.001234", + "Antarctica/Vostok" -> "2018-12-02T16:11:12.001234", + "Asia/Hong_Kong" -> "2018-12-02T18:11:12.001234", + "Europe/Amsterdam" -> "2018-12-02T11:11:12.001234") + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + val formatter = TimestampFormatter( + "yyyy-MM-dd'T'HH:mm:ss.SSSSSS", + TimeZone.getTimeZone(timeZone), + Locale.US) + val timestamp = formatter.format(microsSinceEpoch) + assert(timestamp === expectedTimestamp(timeZone)) + } + } + + test("roundtrip timestamp -> micros -> timestamp using timezones") { + Seq( + -58710115316212000L, + -18926315945345679L, + -9463427405253013L, + -244000001L, + 0L, + 99628200102030L, + 1543749753123456L, + 2177456523456789L, + 11858049903010203L).foreach { micros => + DateTimeTestUtils.outstandingTimezones.foreach { timeZone => + val formatter = TimestampFormatter("yyyy-MM-dd'T'HH:mm:ss.SSSSSS", timeZone, Locale.US) + val timestamp = formatter.format(micros) + val parsed = formatter.parse(timestamp) + assert(micros === parsed) + } + } + } + + test("roundtrip micros -> timestamp -> micros using timezones") { + Seq( + "0109-07-20T18:38:03.788000", + "1370-04-01T10:00:54.654321", + "1670-02-11T14:09:54.746987", + "1969-12-31T23:55:55.999999", + "1970-01-01T00:00:00.000000", + "1973-02-27T02:30:00.102030", + "2018-12-02T11:22:33.123456", + "2039-01-01T01:02:03.456789", + "2345-10-07T22:45:03.010203").foreach { timestamp => + DateTimeTestUtils.outstandingTimezones.foreach { timeZone => + val formatter = TimestampFormatter("yyyy-MM-dd'T'HH:mm:ss.SSSSSS", timeZone, Locale.US) + val micros = formatter.parse(timestamp) + val formatted = formatter.format(micros) + assert(timestamp === formatted) + } + } + } + + test("roundtrip date -> days -> date") { + Seq( + "0050-01-01", + "0953-02-02", + "1423-03-08", + "1969-12-31", + "1972-08-25", + "1975-09-26", + "2018-12-12", + "2038-01-01", + "5010-11-17").foreach { date => + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { + val formatter = DateFormatter("yyyy-MM-dd", Locale.US) + val days = formatter.parse(date) + val formatted = formatter.format(days) + assert(date === formatted) + } + } + } + } + + test("roundtrip days -> date -> days") { + Seq( + -701265, + -371419, + -199722, + -1, + 0, + 967, + 2094, + 17877, + 24837, + 1110657).foreach { days => + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { + val formatter = DateFormatter("yyyy-MM-dd", Locale.US) + val date = formatter.format(days) + val parsed = formatter.parse(date) + assert(days === parsed) + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 3330de3584ebb..786335b42e3cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -57,14 +57,17 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } val factory = new JsonFactory() - def enforceCorrectType(value: Any, dataType: DataType): Any = { + def enforceCorrectType( + value: Any, + dataType: DataType, + options: Map[String, String] = Map.empty): Any = { val writer = new StringWriter() Utils.tryWithResource(factory.createGenerator(writer)) { generator => generator.writeObject(value) generator.flush() } - val dummyOption = new JSONOptions(Map.empty[String, String], "GMT") + val dummyOption = new JSONOptions(options, SQLConf.get.sessionLocalTimeZone) val dummySchema = StructType(Seq.empty) val parser = new JacksonParser(dummySchema, dummyOption, allowArrayAsStructs = true) @@ -96,19 +99,27 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber.toLong * 1000L)), enforceCorrectType(intNumber.toLong, TimestampType)) val strTime = "2014-09-30 12:34:56" - checkTypePromotion(DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(strTime)), - enforceCorrectType(strTime, TimestampType)) + checkTypePromotion( + expected = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(strTime)), + enforceCorrectType(strTime, TimestampType, + Map("timestampFormat" -> "yyyy-MM-dd HH:mm:ss"))) val strDate = "2014-10-15" checkTypePromotion( DateTimeUtils.fromJavaDate(Date.valueOf(strDate)), enforceCorrectType(strDate, DateType)) val ISO8601Time1 = "1970-01-01T01:00:01.0Z" - val ISO8601Time2 = "1970-01-01T02:00:01-01:00" checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(3601000)), - enforceCorrectType(ISO8601Time1, TimestampType)) + enforceCorrectType( + ISO8601Time1, + TimestampType, + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss.SX"))) + val ISO8601Time2 = "1970-01-01T02:00:01-01:00" checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(10801000)), - enforceCorrectType(ISO8601Time2, TimestampType)) + enforceCorrectType( + ISO8601Time2, + TimestampType, + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ssXXX"))) val ISO8601Date = "1970-01-01" checkTypePromotion(DateTimeUtils.millisToDays(32400000), @@ -1440,103 +1451,105 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("backward compatibility") { - // This test we make sure our JSON support can read JSON data generated by previous version - // of Spark generated through toJSON method and JSON data source. - // The data is generated by the following program. - // Here are a few notes: - // - Spark 1.5.0 cannot save timestamp data. So, we manually added timestamp field (col13) - // in the JSON object. - // - For Spark before 1.5.1, we do not generate UDTs. So, we manually added the UDT value to - // JSON objects generated by those Spark versions (col17). - // - If the type is NullType, we do not write data out. - - // Create the schema. - val struct = - StructType( - StructField("f1", FloatType, true) :: - StructField("f2", ArrayType(BooleanType), true) :: Nil) + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> "true") { + // This test we make sure our JSON support can read JSON data generated by previous version + // of Spark generated through toJSON method and JSON data source. + // The data is generated by the following program. + // Here are a few notes: + // - Spark 1.5.0 cannot save timestamp data. So, we manually added timestamp field (col13) + // in the JSON object. + // - For Spark before 1.5.1, we do not generate UDTs. So, we manually added the UDT value to + // JSON objects generated by those Spark versions (col17). + // - If the type is NullType, we do not write data out. + + // Create the schema. + val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) - val dataTypes = - Seq( - StringType, BinaryType, NullType, BooleanType, - ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), - DateType, TimestampType, - ArrayType(IntegerType), MapType(StringType, LongType), struct, - new TestUDT.MyDenseVectorUDT()) - val fields = dataTypes.zipWithIndex.map { case (dataType, index) => - StructField(s"col$index", dataType, nullable = true) - } - val schema = StructType(fields) + val dataTypes = + Seq( + StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct, + new TestUDT.MyDenseVectorUDT()) + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, nullable = true) + } + val schema = StructType(fields) - val constantValues = - Seq( - "a string in binary".getBytes(StandardCharsets.UTF_8), - null, - true, - 1.toByte, - 2.toShort, - 3, - Long.MaxValue, - 0.25.toFloat, - 0.75, - new java.math.BigDecimal(s"1234.23456"), - new java.math.BigDecimal(s"1.23456"), - java.sql.Date.valueOf("2015-01-01"), - java.sql.Timestamp.valueOf("2015-01-01 23:50:59.123"), - Seq(2, 3, 4), - Map("a string" -> 2000L), - Row(4.75.toFloat, Seq(false, true)), - new TestUDT.MyDenseVector(Array(0.25, 2.25, 4.25))) - val data = - Row.fromSeq(Seq("Spark " + spark.sparkContext.version) ++ constantValues) :: Nil + val constantValues = + Seq( + "a string in binary".getBytes(StandardCharsets.UTF_8), + null, + true, + 1.toByte, + 2.toShort, + 3, + Long.MaxValue, + 0.25.toFloat, + 0.75, + new java.math.BigDecimal(s"1234.23456"), + new java.math.BigDecimal(s"1.23456"), + java.sql.Date.valueOf("2015-01-01"), + java.sql.Timestamp.valueOf("2015-01-01 23:50:59.123"), + Seq(2, 3, 4), + Map("a string" -> 2000L), + Row(4.75.toFloat, Seq(false, true)), + new TestUDT.MyDenseVector(Array(0.25, 2.25, 4.25))) + val data = + Row.fromSeq(Seq("Spark " + spark.sparkContext.version) ++ constantValues) :: Nil - // Data generated by previous versions. - // scalastyle:off - val existingJSONData = + // Data generated by previous versions. + // scalastyle:off + val existingJSONData = """{"col0":"Spark 1.2.2","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: - """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: - """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: - """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: - """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: - """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: - """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"16436","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: Nil - // scalastyle:on - - // Generate data for the current version. - val df = spark.createDataFrame(spark.sparkContext.parallelize(data, 1), schema) - withTempPath { path => - df.write.format("json").mode("overwrite").save(path.getCanonicalPath) + """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"16436","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: Nil + // scalastyle:on + + // Generate data for the current version. + val df = spark.createDataFrame(spark.sparkContext.parallelize(data, 1), schema) + withTempPath { path => + df.write.format("json").mode("overwrite").save(path.getCanonicalPath) - // df.toJSON will convert internal rows to external rows first and then generate - // JSON objects. While, df.write.format("json") will write internal rows directly. - val allJSON = + // df.toJSON will convert internal rows to external rows first and then generate + // JSON objects. While, df.write.format("json") will write internal rows directly. + val allJSON = existingJSONData ++ df.toJSON.collect() ++ sparkContext.textFile(path.getCanonicalPath).collect() - Utils.deleteRecursively(path) - sparkContext.parallelize(allJSON, 1).saveAsTextFile(path.getCanonicalPath) - - // Read data back with the schema specified. - val col0Values = - Seq( - "Spark 1.2.2", - "Spark 1.3.1", - "Spark 1.3.1", - "Spark 1.4.1", - "Spark 1.4.1", - "Spark 1.5.0", - "Spark 1.5.0", - "Spark " + spark.sparkContext.version, - "Spark " + spark.sparkContext.version) - val expectedResult = col0Values.map { v => - Row.fromSeq(Seq(v) ++ constantValues) + Utils.deleteRecursively(path) + sparkContext.parallelize(allJSON, 1).saveAsTextFile(path.getCanonicalPath) + + // Read data back with the schema specified. + val col0Values = + Seq( + "Spark 1.2.2", + "Spark 1.3.1", + "Spark 1.3.1", + "Spark 1.4.1", + "Spark 1.4.1", + "Spark 1.5.0", + "Spark 1.5.0", + "Spark " + spark.sparkContext.version, + "Spark " + spark.sparkContext.version) + val expectedResult = col0Values.map { v => + Row.fromSeq(Seq(v) ++ constantValues) + } + checkAnswer( + spark.read.format("json").schema(schema).load(path.getCanonicalPath), + expectedResult + ) } - checkAnswer( - spark.read.format("json").schema(schema).load(path.getCanonicalPath), - expectedResult - ) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index 6075f2c8877d6..f0f62b608785d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import java.io.File +import java.util.TimeZone import scala.util.Random @@ -125,56 +126,62 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } else { Seq(false) } - for (dataType <- supportedDataTypes) { - for (parquetDictionaryEncodingEnabled <- parquetDictionaryEncodingEnabledConfs) { - val extraMessage = if (isParquetDataSource) { - s" with parquet.enable.dictionary = $parquetDictionaryEncodingEnabled" - } else { - "" - } - logInfo(s"Testing $dataType data type$extraMessage") - - val extraOptions = Map[String, String]( - "parquet.enable.dictionary" -> parquetDictionaryEncodingEnabled.toString - ) - - withTempPath { file => - val path = file.getCanonicalPath - - val dataGenerator = RandomDataGenerator.forType( - dataType = dataType, - nullable = true, - new Random(System.nanoTime()) - ).getOrElse { - fail(s"Failed to create data generator for schema $dataType") + // TODO: Support new parser too, see SPARK-26374. + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> "true") { + for (dataType <- supportedDataTypes) { + for (parquetDictionaryEncodingEnabled <- parquetDictionaryEncodingEnabledConfs) { + val extraMessage = if (isParquetDataSource) { + s" with parquet.enable.dictionary = $parquetDictionaryEncodingEnabled" + } else { + "" + } + logInfo(s"Testing $dataType data type$extraMessage") + + val extraOptions = Map[String, String]( + "parquet.enable.dictionary" -> parquetDictionaryEncodingEnabled.toString + ) + + withTempPath { file => + val path = file.getCanonicalPath + + val seed = System.nanoTime() + withClue(s"Random data generated with the seed: ${seed}") { + val dataGenerator = RandomDataGenerator.forType( + dataType = dataType, + nullable = true, + new Random(seed) + ).getOrElse { + fail(s"Failed to create data generator for schema $dataType") + } + + // Create a DF for the schema with random data. The index field is used to sort the + // DataFrame. This is a workaround for SPARK-10591. + val schema = new StructType() + .add("index", IntegerType, nullable = false) + .add("col", dataType, nullable = true) + val rdd = + spark.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator()))) + val df = spark.createDataFrame(rdd, schema).orderBy("index").coalesce(1) + + df.write + .mode("overwrite") + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .options(extraOptions) + .save(path) + + val loadedDF = spark + .read + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .schema(df.schema) + .options(extraOptions) + .load(path) + .orderBy("index") + + checkAnswer(loadedDF, df) + } } - - // Create a DF for the schema with random data. The index field is used to sort the - // DataFrame. This is a workaround for SPARK-10591. - val schema = new StructType() - .add("index", IntegerType, nullable = false) - .add("col", dataType, nullable = true) - val rdd = - spark.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator()))) - val df = spark.createDataFrame(rdd, schema).orderBy("index").coalesce(1) - - df.write - .mode("overwrite") - .format(dataSourceName) - .option("dataSchema", df.schema.json) - .options(extraOptions) - .save(path) - - val loadedDF = spark - .read - .format(dataSourceName) - .option("dataSchema", df.schema.json) - .schema(df.schema) - .options(extraOptions) - .load(path) - .orderBy("index") - - checkAnswer(loadedDF, df) } } } From cd815ae6c5ce3edb8aec3add942549f76a20e586 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 16 Dec 2018 10:57:11 +0800 Subject: [PATCH 2289/2461] [SPARK-26078][SQL] Dedup self-join attributes on IN subqueries ## What changes were proposed in this pull request? When there is a self-join as result of a IN subquery, the join condition may be invalid, resulting in trivially true predicates and return wrong results. The PR deduplicates the subquery output in order to avoid the issue. ## How was this patch tested? added UT Closes #23057 from mgaido91/SPARK-26078. Authored-by: Marco Gaido Signed-off-by: Wenchen Fan --- .../sql/catalyst/optimizer/subquery.scala | 99 ++++++++++++------- .../org/apache/spark/sql/SubquerySuite.scala | 37 +++++++ 2 files changed, 98 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index e9b7a8b76e683..34840c6c977a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -43,31 +43,53 @@ import org.apache.spark.sql.types._ * condition. */ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { - private def dedupJoin(joinPlan: LogicalPlan): LogicalPlan = joinPlan match { + + private def buildJoin( + outerPlan: LogicalPlan, + subplan: LogicalPlan, + joinType: JoinType, + condition: Option[Expression]): Join = { + // Deduplicate conflicting attributes if any. + val dedupSubplan = dedupSubqueryOnSelfJoin(outerPlan, subplan, None, condition) + Join(outerPlan, dedupSubplan, joinType, condition) + } + + private def dedupSubqueryOnSelfJoin( + outerPlan: LogicalPlan, + subplan: LogicalPlan, + valuesOpt: Option[Seq[Expression]], + condition: Option[Expression] = None): LogicalPlan = { // SPARK-21835: It is possibly that the two sides of the join have conflicting attributes, // the produced join then becomes unresolved and break structural integrity. We should - // de-duplicate conflicting attributes. We don't use transformation here because we only - // care about the most top join converted from correlated predicate subquery. - case j @ Join(left, right, joinType @ (LeftSemi | LeftAnti | ExistenceJoin(_)), joinCond) => - val duplicates = right.outputSet.intersect(left.outputSet) - if (duplicates.nonEmpty) { - val aliasMap = AttributeMap(duplicates.map { dup => - dup -> Alias(dup, dup.toString)() - }.toSeq) - val aliasedExpressions = right.output.map { ref => - aliasMap.getOrElse(ref, ref) - } - val newRight = Project(aliasedExpressions, right) - val newJoinCond = joinCond.map { condExpr => - condExpr transform { - case a: Attribute => aliasMap.getOrElse(a, a).toAttribute + // de-duplicate conflicting attributes. + // SPARK-26078: it may also happen that the subquery has conflicting attributes with the outer + // values. In this case, the resulting join would contain trivially true conditions (eg. + // id#3 = id#3) which cannot be de-duplicated after. In this method, if there are conflicting + // attributes in the join condition, the subquery's conflicting attributes are changed using + // a projection which aliases them and resolves the problem. + val outerReferences = valuesOpt.map(values => + AttributeSet.fromAttributeSets(values.map(_.references))).getOrElse(AttributeSet.empty) + val outerRefs = outerPlan.outputSet ++ outerReferences + val duplicates = outerRefs.intersect(subplan.outputSet) + if (duplicates.nonEmpty) { + condition.foreach { e => + val conflictingAttrs = e.references.intersect(duplicates) + if (conflictingAttrs.nonEmpty) { + throw new AnalysisException("Found conflicting attributes " + + s"${conflictingAttrs.mkString(",")} in the condition joining outer plan:\n " + + s"$outerPlan\nand subplan:\n $subplan") } - } - Join(left, newRight, joinType, newJoinCond) - } else { - j } - case _ => joinPlan + val rewrites = AttributeMap(duplicates.map { dup => + dup -> Alias(dup, dup.toString)() + }.toSeq) + val aliasedExpressions = subplan.output.map { ref => + rewrites.getOrElse(ref, ref) + } + Project(aliasedExpressions, subplan) + } else { + subplan + } } def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -85,17 +107,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { withSubquery.foldLeft(newFilter) { case (p, Exists(sub, conditions, _)) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) - // Deduplicate conflicting attributes if any. - dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond)) + buildJoin(outerPlan, sub, LeftSemi, joinCond) case (p, Not(Exists(sub, conditions, _))) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) - // Deduplicate conflicting attributes if any. - dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond)) + buildJoin(outerPlan, sub, LeftAnti, joinCond) case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) => - val inConditions = values.zip(sub.output).map(EqualTo.tupled) - val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) // Deduplicate conflicting attributes if any. - dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond)) + val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values)) + val inConditions = values.zip(newSub.output).map(EqualTo.tupled) + val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) + Join(outerPlan, newSub, LeftSemi, joinCond) case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. A NULL in one of the conditions is regarded as a positive @@ -103,7 +124,10 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Note that will almost certainly be planned as a Broadcast Nested Loop join. // Use EXISTS if performance matters to you. - val inConditions = values.zip(sub.output).map(EqualTo.tupled) + + // Deduplicate conflicting attributes if any. + val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values)) + val inConditions = values.zip(newSub.output).map(EqualTo.tupled) val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p) // Expand the NOT IN expression with the NULL-aware semantic // to its full form. That is from: @@ -118,8 +142,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // will have the final conditions in the LEFT ANTI as // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 1 val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And) - // Deduplicate conflicting attributes if any. - dedupJoin(Join(outerPlan, sub, LeftAnti, Option(finalJoinCond))) + Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond)) case (p, predicate) => val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p) Project(p.output, Filter(newCond.get, inputPlan)) @@ -140,16 +163,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { e transformUp { case Exists(sub, conditions, _) => val exists = AttributeReference("exists", BooleanType, nullable = false)() - // Deduplicate conflicting attributes if any. - newPlan = dedupJoin( - Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))) + newPlan = + buildJoin(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)) exists case InSubquery(values, ListQuery(sub, conditions, _, _)) => val exists = AttributeReference("exists", BooleanType, nullable = false)() - val inConditions = values.zip(sub.output).map(EqualTo.tupled) - val newConditions = (inConditions ++ conditions).reduceLeftOption(And) // Deduplicate conflicting attributes if any. - newPlan = dedupJoin(Join(newPlan, sub, ExistenceJoin(exists), newConditions)) + val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values)) + val inConditions = values.zip(newSub.output).map(EqualTo.tupled) + val newConditions = (inConditions ++ conditions).reduceLeftOption(And) + newPlan = Join(newPlan, newSub, ExistenceJoin(exists), newConditions) exists } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 5088821ad7361..c95c52f1d3a9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ class SubquerySuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -1280,4 +1281,40 @@ class SubquerySuite extends QueryTest with SharedSQLContext { assert(subqueries.length == 1) } } + + test("SPARK-26078: deduplicate fake self joins for IN subqueries") { + withTempView("a", "b") { + Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("a") + Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("b") + + val df1 = spark.sql( + """ + |SELECT id,num,source FROM ( + | SELECT id, num, 'a' as source FROM a + | UNION ALL + | SELECT id, num, 'b' as source FROM b + |) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2) + """.stripMargin) + checkAnswer(df1, Seq(Row("a", 2, "a"), Row("a", 2, "b"))) + val df2 = spark.sql( + """ + |SELECT id,num,source FROM ( + | SELECT id, num, 'a' as source FROM a + | UNION ALL + | SELECT id, num, 'b' as source FROM b + |) AS c WHERE c.id NOT IN (SELECT id FROM b WHERE num = 2) + """.stripMargin) + checkAnswer(df2, Seq(Row("b", 1, "a"), Row("b", 1, "b"))) + val df3 = spark.sql( + """ + |SELECT id,num,source FROM ( + | SELECT id, num, 'a' as source FROM a + | UNION ALL + | SELECT id, num, 'b' as source FROM b + |) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2) OR + |c.id IN (SELECT id FROM b WHERE num = 3) + """.stripMargin) + checkAnswer(df3, Seq(Row("a", 2, "a"), Row("a", 2, "b"))) + } + } } From e3e33d8794da5f3597b8d706b734af5025360939 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Sun, 16 Dec 2018 11:02:00 +0800 Subject: [PATCH 2290/2461] [SPARK-26372][SQL] Don't reuse value from previous row when parsing bad CSV input field ## What changes were proposed in this pull request? CSV parsing accidentally uses the previous good value for a bad input field. See example in Jira. This PR ensures that the associated column is set to null when an input field cannot be converted. ## How was this patch tested? Added new test. Ran all SQL unit tests (testOnly org.apache.spark.sql.*). Ran pyspark tests for pyspark-sql Closes #23323 from bersprockets/csv-bad-field. Authored-by: Bruce Robbins Signed-off-by: Hyukjin Kwon --- .../sql/catalyst/csv/UnivocityParser.scala | 1 + .../resources/test-data/bad_after_good.csv | 2 ++ .../execution/datasources/csv/CSVSuite.scala | 19 +++++++++++++++++++ 3 files changed, 22 insertions(+) create mode 100644 sql/core/src/test/resources/test-data/bad_after_good.csv diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index ed089120055e2..82a5b3c302b18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -239,6 +239,7 @@ class UnivocityParser( } catch { case NonFatal(e) => badRecordException = badRecordException.orElse(Some(e)) + row.setNullAt(i) } i += 1 } diff --git a/sql/core/src/test/resources/test-data/bad_after_good.csv b/sql/core/src/test/resources/test-data/bad_after_good.csv new file mode 100644 index 0000000000000..4621a7d23714d --- /dev/null +++ b/sql/core/src/test/resources/test-data/bad_after_good.csv @@ -0,0 +1,2 @@ +"good record",1999-08-01 +"bad record",1999-088-01 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 3b977d74053e6..d9e5d7af19671 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -63,6 +63,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te private val datesFile = "test-data/dates.csv" private val unescapedQuotesFile = "test-data/unescaped-quotes.csv" private val valueMalformedFile = "test-data/value-malformed.csv" + private val badAfterGoodFile = "test-data/bad_after_good.csv" /** Verifies data and schema. */ private def verifyCars( @@ -2012,4 +2013,22 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te assert(!files.exists(_.getName.endsWith("csv"))) } } + + test("Do not reuse last good value for bad input field") { + val schema = StructType( + StructField("col1", StringType) :: + StructField("col2", DateType) :: + Nil + ) + val rows = spark.read + .schema(schema) + .format("csv") + .load(testFile(badAfterGoodFile)) + + val expectedRows = Seq( + Row("good record", java.sql.Date.valueOf("1999-08-01")), + Row("bad record", null)) + + checkAnswer(rows, expectedRows) + } } From 5217f7b2263c7aaeadf60ef602776bb3777269cd Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 17 Dec 2018 08:24:51 +0800 Subject: [PATCH 2291/2461] [SPARK-26248][SQL] Infer date type from CSV ## What changes were proposed in this pull request? The `CSVInferSchema` class is extended to support inferring of `DateType` from CSV input. The attempt to infer `DateType` is performed after inferring `TimestampType`. ## How was this patch tested? Added new test for inferring date types from CSV . It was also tested by existing suites like `CSVInferSchemaSuite`, `CsvExpressionsSuite`, `CsvFunctionsSuite` and `CsvSuite`. Closes #23202 from MaxGekk/csv-date-inferring. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: Wenchen Fan --- .../sql/catalyst/csv/CSVInferSchema.scala | 20 +++++++++++++++---- .../catalyst/csv/CSVInferSchemaSuite.scala | 18 +++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index 35ade136cc607..11f3740d99a72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -22,16 +22,20 @@ import scala.util.control.Exception.allCatch import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.expressions.ExprUtils -import org.apache.spark.sql.catalyst.util.TimestampFormatter +import org.apache.spark.sql.catalyst.util.{DateFormatter, TimestampFormatter} import org.apache.spark.sql.types._ class CSVInferSchema(val options: CSVOptions) extends Serializable { @transient - private lazy val timestampParser = TimestampFormatter( + private lazy val timestampFormatter = TimestampFormatter( options.timestampFormat, options.timeZone, options.locale) + @transient + private lazy val dateFormatter = DateFormatter( + options.dateFormat, + options.locale) private val decimalParser = { ExprUtils.getDecimalParser(options.locale) @@ -104,6 +108,7 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable { compatibleType(typeSoFar, tryParseDecimal(field)).getOrElse(StringType) case DoubleType => tryParseDouble(field) case TimestampType => tryParseTimestamp(field) + case DateType => tryParseDate(field) case BooleanType => tryParseBoolean(field) case StringType => StringType case other: DataType => @@ -159,9 +164,16 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable { } private def tryParseTimestamp(field: String): DataType = { - // This case infers a custom `dataFormat` is set. - if ((allCatch opt timestampParser.parse(field)).isDefined) { + if ((allCatch opt timestampFormatter.parse(field)).isDefined) { TimestampType + } else { + tryParseDate(field) + } + } + + private def tryParseDate(field: String): DataType = { + if ((allCatch opt dateFormatter.parse(field)).isDefined) { + DateType } else { tryParseBoolean(field) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala index c2b525ad1a9f8..84b2e616a4426 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala @@ -187,4 +187,22 @@ class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper { Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(checkDecimalInfer(_, DecimalType(7, 0))) } + + test("inferring date type") { + var options = new CSVOptions(Map("dateFormat" -> "yyyy/MM/dd"), false, "GMT") + var inferSchema = new CSVInferSchema(options) + assert(inferSchema.inferField(NullType, "2018/12/02") == DateType) + + options = new CSVOptions(Map("dateFormat" -> "MMM yyyy"), false, "GMT") + inferSchema = new CSVInferSchema(options) + assert(inferSchema.inferField(NullType, "Dec 2018") == DateType) + + options = new CSVOptions( + Map("dateFormat" -> "yyyy-MM-dd", "timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"), + columnPruning = false, + defaultTimeZoneId = "GMT") + inferSchema = new CSVInferSchema(options) + assert(inferSchema.inferField(NullType, "2018-12-03T11:00:00") == TimestampType) + assert(inferSchema.inferField(NullType, "2018-12-03") == DateType) + } } From e408e05322ac4e31de4d9bc58687c86882e3944a Mon Sep 17 00:00:00 2001 From: Keiji Yoshida Date: Sun, 16 Dec 2018 17:11:58 -0800 Subject: [PATCH 2292/2461] [MINOR][DOCS] Fix the "not found: value Row" error on the "programmatic_schema" example ## What changes were proposed in this pull request? Print `import org.apache.spark.sql.Row` of `SparkSQLExample.scala` on the `programmatic_schema` example to fix the `not found: value Row` error on it. ``` scala> val rowRDD = peopleRDD.map(_.split(",")).map(attributes => Row(attributes(0), attributes(1).trim)) :28: error: not found: value Row val rowRDD = peopleRDD.map(_.split(",")).map(attributes => Row(attributes(0), attributes(1).trim)) ``` ## How was this patch tested? NA Closes #23326 from kjmrknsn/fix-sql-getting-started. Authored-by: Keiji Yoshida Signed-off-by: Dongjoon Hyun --- .../scala/org/apache/spark/examples/sql/SparkSQLExample.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala index 958361a6684c5..678cbc64aff1f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala @@ -16,7 +16,9 @@ */ package org.apache.spark.examples.sql +// $example on:programmatic_schema$ import org.apache.spark.sql.Row +// $example off:programmatic_schema$ // $example on:init_session$ import org.apache.spark.sql.SparkSession // $example off:init_session$ From db1c5b1839598eada81e4709ab4d25e799bb1810 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Mon, 17 Dec 2018 11:53:14 +0800 Subject: [PATCH 2293/2461] Revert "[SPARK-26248][SQL] Infer date type from CSV" This reverts commit 5217f7b2263c7aaeadf60ef602776bb3777269cd. --- .../sql/catalyst/csv/CSVInferSchema.scala | 20 ++++--------------- .../catalyst/csv/CSVInferSchemaSuite.scala | 18 ----------------- 2 files changed, 4 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index 11f3740d99a72..35ade136cc607 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -22,20 +22,16 @@ import scala.util.control.Exception.allCatch import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.expressions.ExprUtils -import org.apache.spark.sql.catalyst.util.{DateFormatter, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.TimestampFormatter import org.apache.spark.sql.types._ class CSVInferSchema(val options: CSVOptions) extends Serializable { @transient - private lazy val timestampFormatter = TimestampFormatter( + private lazy val timestampParser = TimestampFormatter( options.timestampFormat, options.timeZone, options.locale) - @transient - private lazy val dateFormatter = DateFormatter( - options.dateFormat, - options.locale) private val decimalParser = { ExprUtils.getDecimalParser(options.locale) @@ -108,7 +104,6 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable { compatibleType(typeSoFar, tryParseDecimal(field)).getOrElse(StringType) case DoubleType => tryParseDouble(field) case TimestampType => tryParseTimestamp(field) - case DateType => tryParseDate(field) case BooleanType => tryParseBoolean(field) case StringType => StringType case other: DataType => @@ -164,16 +159,9 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable { } private def tryParseTimestamp(field: String): DataType = { - if ((allCatch opt timestampFormatter.parse(field)).isDefined) { + // This case infers a custom `dataFormat` is set. + if ((allCatch opt timestampParser.parse(field)).isDefined) { TimestampType - } else { - tryParseDate(field) - } - } - - private def tryParseDate(field: String): DataType = { - if ((allCatch opt dateFormatter.parse(field)).isDefined) { - DateType } else { tryParseBoolean(field) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala index 84b2e616a4426..c2b525ad1a9f8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala @@ -187,22 +187,4 @@ class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper { Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(checkDecimalInfer(_, DecimalType(7, 0))) } - - test("inferring date type") { - var options = new CSVOptions(Map("dateFormat" -> "yyyy/MM/dd"), false, "GMT") - var inferSchema = new CSVInferSchema(options) - assert(inferSchema.inferField(NullType, "2018/12/02") == DateType) - - options = new CSVOptions(Map("dateFormat" -> "MMM yyyy"), false, "GMT") - inferSchema = new CSVInferSchema(options) - assert(inferSchema.inferField(NullType, "Dec 2018") == DateType) - - options = new CSVOptions( - Map("dateFormat" -> "yyyy-MM-dd", "timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"), - columnPruning = false, - defaultTimeZoneId = "GMT") - inferSchema = new CSVInferSchema(options) - assert(inferSchema.inferField(NullType, "2018-12-03T11:00:00") == TimestampType) - assert(inferSchema.inferField(NullType, "2018-12-03") == DateType) - } } From 56448c662398f4c5319a337e6601450270a6a27c Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Mon, 17 Dec 2018 13:41:20 +0800 Subject: [PATCH 2294/2461] [SPARK-26352][SQL] join reorder should not change the order of output attributes ## What changes were proposed in this pull request? The optimizer rule `org.apache.spark.sql.catalyst.optimizer.ReorderJoin` performs join reordering on inner joins. This was introduced from SPARK-12032 (https://github.com/apache/spark/pull/10073) in 2015-12. After it had reordered the joins, though, it didn't check whether or not the output attribute order is still the same as before. Thus, it's possible to have a mismatch between the reordered output attributes order vs the schema that a DataFrame thinks it has. The same problem exists in the CBO version of join reordering (`CostBasedJoinReorder`) too. This can be demonstrated with the example: ```scala spark.sql("create table table_a (x int, y int) using parquet") spark.sql("create table table_b (i int, j int) using parquet") spark.sql("create table table_c (a int, b int) using parquet") val df = spark.sql(""" with df1 as (select * from table_a cross join table_b) select * from df1 join table_c on a = x and b = i """) ``` here's what the DataFrame thinks: ``` scala> df.printSchema root |-- x: integer (nullable = true) |-- y: integer (nullable = true) |-- i: integer (nullable = true) |-- j: integer (nullable = true) |-- a: integer (nullable = true) |-- b: integer (nullable = true) ``` here's what the optimized plan thinks, after join reordering: ``` scala> df.queryExecution.optimizedPlan.output.foreach(a => println(s"|-- ${a.name}: ${a.dataType.typeName}")) |-- x: integer |-- y: integer |-- a: integer |-- b: integer |-- i: integer |-- j: integer ``` If we exclude the `ReorderJoin` rule (using Spark 2.4's optimizer rule exclusion feature), it's back to normal: ``` scala> spark.conf.set("spark.sql.optimizer.excludedRules", "org.apache.spark.sql.catalyst.optimizer.ReorderJoin") scala> val df = spark.sql("with df1 as (select * from table_a cross join table_b) select * from df1 join table_c on a = x and b = i") df: org.apache.spark.sql.DataFrame = [x: int, y: int ... 4 more fields] scala> df.queryExecution.optimizedPlan.output.foreach(a => println(s"|-- ${a.name}: ${a.dataType.typeName}")) |-- x: integer |-- y: integer |-- i: integer |-- j: integer |-- a: integer |-- b: integer ``` Note that this output attribute ordering problem leads to data corruption, and can manifest itself in various symptoms: * Silently corrupting data, if the reordered columns happen to either have matching types or have sufficiently-compatible types (e.g. all fixed length primitive types are considered as "sufficiently compatible" in an `UnsafeRow`), then only the resulting data is going to be wrong but it might not trigger any alarms immediately. Or * Weird Java-level exceptions like `java.lang.NegativeArraySizeException`, or even SIGSEGVs. ## How was this patch tested? Added new unit test in `JoinReorderSuite` and new end-to-end test in `JoinSuite`. Also made `JoinReorderSuite` and `StarJoinReorderSuite` assert more strongly on maintaining output attribute order. Closes #23303 from rednaxelafx/fix-join-reorder. Authored-by: Kris Mok Signed-off-by: Wenchen Fan --- .../optimizer/CostBasedJoinReorder.scala | 10 +++++ .../spark/sql/catalyst/optimizer/joins.scala | 12 +++++- .../optimizer/JoinOptimizationSuite.scala | 3 ++ .../catalyst/optimizer/JoinReorderSuite.scala | 38 +++++++++++++++++-- .../StarJoinCostBasedReorderSuite.scala | 21 +++++++++- .../optimizer/StarJoinReorderSuite.scala | 28 ++++++++++++-- .../org/apache/spark/sql/JoinSuite.scala | 14 +++++++ 7 files changed, 116 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index 064ca68b7a628..01634a9d852c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -48,6 +48,7 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { if projectList.forall(_.isInstanceOf[Attribute]) => reorder(p, p.output) } + // After reordering is finished, convert OrderedJoin back to Join result transformDown { case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond) @@ -175,11 +176,20 @@ object JoinReorderDP extends PredicateHelper with Logging { assert(topOutputSet == p.outputSet) // Keep the same order of final output attributes. p.copy(projectList = output) + case finalPlan if !sameOutput(finalPlan, output) => + Project(output, finalPlan) case finalPlan => finalPlan } } + private def sameOutput(plan: LogicalPlan, expectedOutput: Seq[Attribute]): Boolean = { + val thisOutput = plan.output + thisOutput.length == expectedOutput.length && thisOutput.zip(expectedOutput).forall { + case (a1, a2) => a1.semanticEquals(a2) + } + } + /** Find all possible plans at the next level, based on existing levels. */ private def searchLevel( existingLevels: Seq[JoinPlanMap], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 6ebb194d71c2e..0b6471289a471 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -86,9 +86,9 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case ExtractFiltersAndInnerJoins(input, conditions) + case p @ ExtractFiltersAndInnerJoins(input, conditions) if input.size > 2 && conditions.nonEmpty => - if (SQLConf.get.starSchemaDetection && !SQLConf.get.cboEnabled) { + val reordered = if (SQLConf.get.starSchemaDetection && !SQLConf.get.cboEnabled) { val starJoinPlan = StarSchemaDetection.reorderStarJoins(input, conditions) if (starJoinPlan.nonEmpty) { val rest = input.filterNot(starJoinPlan.contains(_)) @@ -99,6 +99,14 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { } else { createOrderedJoin(input, conditions) } + + if (p.sameOutput(reordered)) { + reordered + } else { + // Reordering the joins have changed the order of the columns. + // Inject a projection to make sure we restore to the expected ordering. + Project(p.output, reordered) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index ccd9d8dd4d213..e9438b2eee550 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -102,16 +102,19 @@ class JoinOptimizationSuite extends PlanTest { x.join(y).join(z).where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)), x.join(z, condition = Some("x.b".attr === "z.b".attr)) .join(y, condition = Some("y.d".attr === "z.a".attr)) + .select(Seq("x.a", "x.b", "x.c", "y.d", "z.a", "z.b", "z.c").map(_.attr): _*) ), ( x.join(y, Cross).join(z, Cross) .where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)), x.join(z, Cross, Some("x.b".attr === "z.b".attr)) .join(y, Cross, Some("y.d".attr === "z.a".attr)) + .select(Seq("x.a", "x.b", "x.c", "y.d", "z.a", "z.b", "z.c").map(_.attr): _*) ), ( x.join(y, Inner).join(z, Cross).where("x.b".attr === "z.a".attr), x.join(z, Cross, Some("x.b".attr === "z.a".attr)).join(y, Inner) + .select(Seq("x.a", "x.b", "x.c", "y.d", "z.a", "z.b", "z.c").map(_.attr): _*) ) ) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index 565b0a10154a8..c94a8b9e318f6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} -import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.{Cross, Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} import org.apache.spark.sql.internal.SQLConf.{CBO_ENABLED, JOIN_REORDER_ENABLED} @@ -124,7 +124,8 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { // the original order (t1 J t2) J t3. val bestPlan = t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) - .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .select(outputsOf(t1, t2, t3): _*) assertEqualPlans(originalPlan, bestPlan) } @@ -139,7 +140,9 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { val bestPlan = t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .select(outputsOf(t1, t2, t3): _*) // this is redundant but we'll take it for now .join(t4) + .select(outputsOf(t1, t2, t4, t3): _*) assertEqualPlans(originalPlan, bestPlan) } @@ -202,6 +205,7 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) .join(t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2"))) + .select(outputsOf(t1, t4, t2, t3): _*) assertEqualPlans(originalPlan, bestPlan) } @@ -219,6 +223,23 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { } } + test("SPARK-26352: join reordering should not change the order of attributes") { + // This test case does not rely on CBO. + // It's similar to the test case above, but catches a reordering bug that the one above doesn't + val tab1 = LocalRelation('x.int, 'y.int) + val tab2 = LocalRelation('i.int, 'j.int) + val tab3 = LocalRelation('a.int, 'b.int) + val original = + tab1.join(tab2, Cross) + .join(tab3, Inner, Some('a === 'x && 'b === 'i)) + val expected = + tab1.join(tab3, Inner, Some('a === 'x)) + .join(tab2, Cross, Some('b === 'i)) + .select(outputsOf(tab1, tab2, tab3): _*) + + assertEqualPlans(original, expected) + } + test("reorder recursively") { // Original order: // Join @@ -266,8 +287,17 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { private def assertEqualPlans( originalPlan: LogicalPlan, groundTruthBestPlan: LogicalPlan): Unit = { - val optimized = Optimize.execute(originalPlan.analyze) + val analyzed = originalPlan.analyze + val optimized = Optimize.execute(analyzed) val expected = groundTruthBestPlan.analyze + + assert(analyzed.sameOutput(expected)) // if this fails, the expected plan itself is incorrect + assert(analyzed.sameOutput(optimized)) + compareJoinOrder(optimized, expected) } + + private def outputsOf(plans: LogicalPlan*): Seq[Attribute] = { + plans.map(_.output).reduce(_ ++ _) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala index d4d23ad69b2c2..baae934e1e4fe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala @@ -218,6 +218,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) .join(t2, Inner, Some(nameToAttr("f1_c2") === nameToAttr("t2_c1"))) .join(t1, Inner, Some(nameToAttr("f1_c1") === nameToAttr("t1_c1"))) + .select(outputsOf(f1, t1, t2, d1, d2): _*) assertEqualPlans(query, expected) } @@ -256,6 +257,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas .join(t3.join(t2, Inner, Some(nameToAttr("t2_c2") === nameToAttr("t3_c1"))), Inner, Some(nameToAttr("d1_c2") === nameToAttr("t2_c1"))) .join(t1, Inner, Some(nameToAttr("t1_c1") === nameToAttr("f1_c1"))) + .select(outputsOf(d1, t1, t2, f1, d2, t3): _*) assertEqualPlans(query, expected) } @@ -297,6 +299,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas Some(nameToAttr("t3_c1") === nameToAttr("t4_c1"))) .join(t1.join(t2, Inner, Some(nameToAttr("t1_c1") === nameToAttr("t2_c1"))), Inner, Some(nameToAttr("t1_c2") === nameToAttr("t4_c2"))) + .select(outputsOf(d1, t1, t2, t3, t4, f1, d2): _*) assertEqualPlans(query, expected) } @@ -347,6 +350,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas Some(nameToAttr("d3_c2") === nameToAttr("t1_c1"))) .join(t5.join(t6, Inner, Some(nameToAttr("t5_c2") === nameToAttr("t6_c2"))), Inner, Some(nameToAttr("d2_c2") === nameToAttr("t5_c1"))) + .select(outputsOf(d1, t3, t4, f1, d2, t5, t6, d3, t1, t2): _*) assertEqualPlans(query, expected) } @@ -375,6 +379,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas f1.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk"))) .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + .select(outputsOf(d1, d2, f1, d3): _*) assertEqualPlans(query, expected) } @@ -400,13 +405,27 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas f1.join(t3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("t3_c1"))) .join(t2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("t2_c1"))) .join(t1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("t1_c1"))) + .select(outputsOf(t1, f1, t2, t3): _*) assertEqualPlans(query, expected) } private def assertEqualPlans( plan1: LogicalPlan, plan2: LogicalPlan): Unit = { - val optimized = Optimize.execute(plan1.analyze) + val analyzed = plan1.analyze + val optimized = Optimize.execute(analyzed) val expected = plan2.analyze + + assert(equivalentOutput(analyzed, expected)) // if this fails, the expected itself is incorrect + assert(equivalentOutput(analyzed, optimized)) + compareJoinOrder(optimized, expected) } + + private def outputsOf(plans: LogicalPlan*): Seq[Attribute] = { + plans.map(_.output).reduce(_ ++ _) + } + + private def equivalentOutput(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { + normalizeExprIds(plan1).output == normalizeExprIds(plan2).output + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala index 4e0883e91e84a..9dc653b9d6c44 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala @@ -182,6 +182,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + .select(outputsOf(d1, d2, f1, d3, s3): _*) assertEqualPlans(query, expected) } @@ -220,6 +221,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) .join(d2, Inner, Some(nameToAttr("f1_fk2") < nameToAttr("d2_pk1"))) .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + .select(outputsOf(d1, f1, d2, s3, d3): _*) assertEqualPlans(query, expected) } @@ -255,7 +257,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(d3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) .join(s3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("s3_c2"))) - + .select(outputsOf(d1, f1, d2, s3, d3): _*) assertEqualPlans(query, expected) } @@ -292,6 +294,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_c2"))) .join(s3, Inner, Some(nameToAttr("d3_fk1") < nameToAttr("s3_pk1"))) + .select(outputsOf(d1, f1, d2, s3, d3): _*) assertEqualPlans(query, expected) } @@ -395,6 +398,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(d2.where(nameToAttr("d2_c2") === 2), Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) .join(s3, Inner, Some(nameToAttr("f11_fk1") === nameToAttr("s3_pk1"))) + .select(outputsOf(d1, f11, f1, d2, s3): _*) assertEqualPlans(query, equivQuery) } @@ -430,6 +434,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(d2.where(nameToAttr("d2_c2") === 2), Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_c4"))) .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + .select(outputsOf(d1, d3, f1, d2, s3): _*) assertEqualPlans(query, expected) } @@ -465,6 +470,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(d2.where(nameToAttr("d2_c2") === 2), Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + .select(outputsOf(d1, d3, f1, d2, s3): _*) assertEqualPlans(query, expected) } @@ -499,6 +505,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(d2.where(nameToAttr("d2_c2") === 2), Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + .select(outputsOf(d1, d3, f1, d2, s3): _*) assertEqualPlans(query, expected) } @@ -532,6 +539,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(d3, Inner, Some(nameToAttr("f1_fk3") < nameToAttr("d3_pk1"))) .join(d2, Inner, Some(nameToAttr("f1_fk2") < nameToAttr("d2_pk1"))) .join(s3, Inner, Some(nameToAttr("d3_fk1") < nameToAttr("s3_pk1"))) + .select(outputsOf(d1, d3, f1, d2, s3): _*) assertEqualPlans(query, expected) } @@ -565,13 +573,27 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + .select(outputsOf(d1, d3, f1, d2, s3): _*) assertEqualPlans(query, expected) } - private def assertEqualPlans( plan1: LogicalPlan, plan2: LogicalPlan): Unit = { - val optimized = Optimize.execute(plan1.analyze) + private def assertEqualPlans(plan1: LogicalPlan, plan2: LogicalPlan): Unit = { + val analyzed = plan1.analyze + val optimized = Optimize.execute(analyzed) val expected = plan2.analyze + + assert(equivalentOutput(analyzed, expected)) // if this fails, the expected itself is incorrect + assert(equivalentOutput(analyzed, optimized)) + compareJoinOrder(optimized, expected) } + + private def outputsOf(plans: LogicalPlan*): Seq[Attribute] = { + plans.map(_.output).reduce(_ ++ _) + } + + private def equivalentOutput(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { + normalizeExprIds(plan1).output == normalizeExprIds(plan2).output + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index aa2162c9d2cda..91445c8d96d85 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -895,4 +895,18 @@ class JoinSuite extends QueryTest with SharedSQLContext { checkAnswer(res, Row(0, 0, 0)) } } + + test("SPARK-26352: join reordering should not change the order of columns") { + withTable("tab1", "tab2", "tab3") { + spark.sql("select 1 as x, 100 as y").write.saveAsTable("tab1") + spark.sql("select 42 as i, 200 as j").write.saveAsTable("tab2") + spark.sql("select 1 as a, 42 as b").write.saveAsTable("tab3") + + val df = spark.sql(""" + with tmp as (select * from tab1 cross join tab2) + select * from tmp join tab3 on a = x and b = i + """) + checkAnswer(df, Row(1, 100, 42, 200, 1, 42)) + } + } } From 5960a8297ca06a4c62f39a8821dba4ba172f2bfc Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 16 Dec 2018 23:40:06 -0800 Subject: [PATCH 2295/2461] [SPARK-26327][SQL][FOLLOW-UP] Refactor the code and restore the metrics name ## What changes were proposed in this pull request? - The original comment about `updateDriverMetrics` is not right. - Refactor the code to ensure `selectedPartitions ` has been set before sending the driver-side metrics. - Restore the original name, which is more general and extendable. ## How was this patch tested? The existing tests. Closes #23328 from gatorsmile/followupSpark-26142. Authored-by: gatorsmile Signed-off-by: gatorsmile --- .../sql/execution/DataSourceScanExec.scala | 39 +++++++++---------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index c0fa4e777b49c..322ffffca564b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, HashMap} import org.apache.commons.lang3.StringUtils import org.apache.hadoop.fs.{BlockLocation, FileStatus, LocatedFileStatus, Path} @@ -167,14 +167,26 @@ case class FileSourceScanExec( partitionSchema = relation.partitionSchema, relation.sparkSession.sessionState.conf) - private var fileListingTime = 0L + val driverMetrics: HashMap[String, Long] = HashMap.empty + + /** + * Send the driver-side metrics. Before calling this function, selectedPartitions has + * been initialized. See SPARK-26327 for more details. + */ + private def sendDriverMetrics(): Unit = { + driverMetrics.foreach(e => metrics(e._1).add(e._2)) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, + metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq) + } @transient private lazy val selectedPartitions: Seq[PartitionDirectory] = { val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) val startTime = System.nanoTime() val ret = relation.location.listFiles(partitionFilters, dataFilters) + driverMetrics("numFiles") = ret.map(_.files.size.toLong).sum val timeTakenMs = ((System.nanoTime() - startTime) + optimizerMetadataTimeNs) / 1000 / 1000 - fileListingTime = timeTakenMs + driverMetrics("metadataTime") = timeTakenMs ret } @@ -286,8 +298,6 @@ case class FileSourceScanExec( } private lazy val inputRDD: RDD[InternalRow] = { - // Update metrics for taking effect in both code generation node and normal node. - updateDriverMetrics() val readFile: (PartitionedFile) => Iterator[InternalRow] = relation.fileFormat.buildReaderWithPartitionValues( sparkSession = relation.sparkSession, @@ -298,12 +308,14 @@ case class FileSourceScanExec( options = relation.options, hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) - relation.bucketSpec match { + val readRDD = relation.bucketSpec match { case Some(bucketing) if relation.sparkSession.sessionState.conf.bucketingEnabled => createBucketedReadRDD(bucketing, readFile, selectedPartitions, relation) case _ => createNonBucketedReadRDD(readFile, selectedPartitions, relation) } + sendDriverMetrics() + readRDD } override def inputRDDs(): Seq[RDD[InternalRow]] = { @@ -313,7 +325,7 @@ case class FileSourceScanExec( override lazy val metrics = Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of files"), - "fileListingTime" -> SQLMetrics.createMetric(sparkContext, "file listing time (ms)"), + "metadataTime" -> SQLMetrics.createMetric(sparkContext, "metadata time"), "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) protected override def doExecute(): RDD[InternalRow] = { @@ -504,19 +516,6 @@ case class FileSourceScanExec( } } - /** - * Send the updated metrics to driver, while this function calling, selectedPartitions has - * been initialized. See SPARK-26327 for more detail. - */ - private def updateDriverMetrics() = { - metrics("numFiles").add(selectedPartitions.map(_.files.size.toLong).sum) - metrics("fileListingTime").add(fileListingTime) - - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, - metrics("numFiles") :: metrics("fileListingTime") :: Nil) - } - override def doCanonicalize(): FileSourceScanExec = { FileSourceScanExec( relation, From f6888f7c944daff3d7c88b37e883673866eb148e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 17 Dec 2018 00:13:51 -0800 Subject: [PATCH 2296/2461] [SPARK-20636] Add the rule TransposeWindow to the optimization batch ## What changes were proposed in this pull request? This PR is a follow-up of the PR https://github.com/apache/spark/pull/17899. It is to add the rule TransposeWindow the optimizer batch. ## How was this patch tested? The existing tests. Closes #23222 from gatorsmile/followupSPARK-20636. Authored-by: gatorsmile Signed-off-by: gatorsmile --- .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../sql/DataFrameWindowFunctionsSuite.scala | 38 +++++++++++++------ 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f615757a837a1..3eb6bca6ec976 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -73,6 +73,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) CombineLimits, CombineUnions, // Constant folding and strength reduction + TransposeWindow, NullPropagation, ConstantPropagation, FoldablePropagation, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 9a5d5a9966ab7..9277dc6859247 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql import org.scalatest.Matchers.the import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} +import org.apache.spark.sql.catalyst.optimizer.TransposeWindow +import org.apache.spark.sql.execution.exchange.Exchange import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -668,18 +670,30 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { ("S2", "P2", 300) ).toDF("sno", "pno", "qty") - val w1 = Window.partitionBy("sno") - val w2 = Window.partitionBy("sno", "pno") - - checkAnswer( - df.select($"sno", $"pno", $"qty", sum($"qty").over(w2).alias("sum_qty_2")) - .select($"sno", $"pno", $"qty", col("sum_qty_2"), sum("qty").over(w1).alias("sum_qty_1")), - Seq( - Row("S1", "P1", 100, 800, 800), - Row("S1", "P1", 700, 800, 800), - Row("S2", "P1", 200, 200, 500), - Row("S2", "P2", 300, 300, 500))) - + Seq(true, false).foreach { transposeWindowEnabled => + val excludedRules = if (transposeWindowEnabled) "" else TransposeWindow.ruleName + withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> excludedRules) { + val w1 = Window.partitionBy("sno") + val w2 = Window.partitionBy("sno", "pno") + + val select = df.select($"sno", $"pno", $"qty", sum($"qty").over(w2).alias("sum_qty_2")) + .select($"sno", $"pno", $"qty", col("sum_qty_2"), sum("qty").over(w1).alias("sum_qty_1")) + + val expectedNumExchanges = if (transposeWindowEnabled) 1 else 2 + val actualNumExchanges = select.queryExecution.executedPlan.collect { + case e: Exchange => e + }.length + assert(actualNumExchanges == expectedNumExchanges) + + checkAnswer( + select, + Seq( + Row("S1", "P1", 100, 800, 800), + Row("S1", "P1", 700, 800, 800), + Row("S2", "P1", 200, 200, 500), + Row("S2", "P2", 300, 300, 500))) + } + } } test("NaN and -0.0 in window partition keys") { From 12640d674b0af6716023fad30fe12cee728bfe34 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 17 Dec 2018 21:47:38 +0800 Subject: [PATCH 2297/2461] [SPARK-26243][SQL][FOLLOWUP] fix code style issues in TimestampFormatter.scala ## What changes were proposed in this pull request? 1. rename `FormatterUtils` to `DateTimeFormatterHelper`, and move it to a separated file 2. move `DateFormatter` and its implementation to a separated file 3. mark some methods as private 4. add `override` to some methods ## How was this patch tested? existing tests Closes #23329 from cloud-fan/minor. Authored-by: Wenchen Fan Signed-off-by: Hyukjin Kwon --- .../sql/catalyst/util/DateFormatter.scala | 96 +++++++++++++++ .../util/DateTimeFormatterHelper.scala | 44 +++++++ .../catalyst/util/TimestampFormatter.scala | 115 ++---------------- .../spark/sql/util/DateFormatterSuite.scala | 92 ++++++++++++++ ...te.scala => TimestampFormatterSuite.scala} | 73 +---------- 5 files changed, 246 insertions(+), 174 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala rename sql/catalyst/src/test/scala/org/apache/spark/sql/util/{DateTimestampFormatterSuite.scala => TimestampFormatterSuite.scala} (66%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala new file mode 100644 index 0000000000000..9e8d51cc65f03 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import java.time.{Instant, ZoneId} +import java.util.Locale + +import scala.util.Try + +import org.apache.commons.lang3.time.FastDateFormat + +import org.apache.spark.sql.internal.SQLConf + +sealed trait DateFormatter { + def parse(s: String): Int // returns days since epoch + def format(days: Int): String +} + +class Iso8601DateFormatter( + pattern: String, + locale: Locale) extends DateFormatter with DateTimeFormatterHelper { + + private val formatter = buildFormatter(pattern, locale) + private val UTC = ZoneId.of("UTC") + + private def toInstant(s: String): Instant = { + val temporalAccessor = formatter.parse(s) + toInstantWithZoneId(temporalAccessor, UTC) + } + + override def parse(s: String): Int = { + val seconds = toInstant(s).getEpochSecond + val days = Math.floorDiv(seconds, DateTimeUtils.SECONDS_PER_DAY) + days.toInt + } + + override def format(days: Int): String = { + val instant = Instant.ofEpochSecond(days * DateTimeUtils.SECONDS_PER_DAY) + formatter.withZone(UTC).format(instant) + } +} + +class LegacyDateFormatter(pattern: String, locale: Locale) extends DateFormatter { + private val format = FastDateFormat.getInstance(pattern, locale) + + override def parse(s: String): Int = { + val milliseconds = format.parse(s).getTime + DateTimeUtils.millisToDays(milliseconds) + } + + override def format(days: Int): String = { + val date = DateTimeUtils.toJavaDate(days) + format.format(date) + } +} + +class LegacyFallbackDateFormatter( + pattern: String, + locale: Locale) extends LegacyDateFormatter(pattern, locale) { + override def parse(s: String): Int = { + Try(super.parse(s)).orElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + Try(DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(s).getTime)) + }.getOrElse { + // In Spark 1.5.0, we store the data as number of days since epoch in string. + // So, we just convert it to Int. + s.toInt + } + } +} + +object DateFormatter { + def apply(format: String, locale: Locale): DateFormatter = { + if (SQLConf.get.legacyTimeParserEnabled) { + new LegacyFallbackDateFormatter(format, locale) + } else { + new Iso8601DateFormatter(format, locale) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala new file mode 100644 index 0000000000000..b85101d38d9e6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import java.time.{Instant, LocalDateTime, ZonedDateTime, ZoneId} +import java.time.format.{DateTimeFormatter, DateTimeFormatterBuilder} +import java.time.temporal.{ChronoField, TemporalAccessor} +import java.util.Locale + +trait DateTimeFormatterHelper { + + protected def buildFormatter(pattern: String, locale: Locale): DateTimeFormatter = { + new DateTimeFormatterBuilder() + .appendPattern(pattern) + .parseDefaulting(ChronoField.YEAR_OF_ERA, 1970) + .parseDefaulting(ChronoField.MONTH_OF_YEAR, 1) + .parseDefaulting(ChronoField.DAY_OF_MONTH, 1) + .parseDefaulting(ChronoField.HOUR_OF_DAY, 0) + .parseDefaulting(ChronoField.MINUTE_OF_HOUR, 0) + .parseDefaulting(ChronoField.SECOND_OF_MINUTE, 0) + .toFormatter(locale) + } + + protected def toInstantWithZoneId(temporalAccessor: TemporalAccessor, zoneId: ZoneId): Instant = { + val localDateTime = LocalDateTime.from(temporalAccessor) + val zonedDateTime = ZonedDateTime.of(localDateTime, zoneId) + Instant.from(zonedDateTime) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala index 2b8d22dde9267..eb1303303463d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst.util import java.time._ -import java.time.format.DateTimeFormatterBuilder -import java.time.temporal.{ChronoField, TemporalAccessor, TemporalQueries} +import java.time.temporal.TemporalQueries import java.util.{Locale, TimeZone} import scala.util.Try @@ -33,39 +32,16 @@ sealed trait TimestampFormatter { def format(us: Long): String } -trait FormatterUtils { - protected def zoneId: ZoneId - protected def buildFormatter( - pattern: String, - locale: Locale): java.time.format.DateTimeFormatter = { - new DateTimeFormatterBuilder() - .appendPattern(pattern) - .parseDefaulting(ChronoField.YEAR_OF_ERA, 1970) - .parseDefaulting(ChronoField.MONTH_OF_YEAR, 1) - .parseDefaulting(ChronoField.DAY_OF_MONTH, 1) - .parseDefaulting(ChronoField.HOUR_OF_DAY, 0) - .parseDefaulting(ChronoField.MINUTE_OF_HOUR, 0) - .parseDefaulting(ChronoField.SECOND_OF_MINUTE, 0) - .toFormatter(locale) - } - protected def toInstantWithZoneId(temporalAccessor: TemporalAccessor): java.time.Instant = { - val localDateTime = LocalDateTime.from(temporalAccessor) - val zonedDateTime = ZonedDateTime.of(localDateTime, zoneId) - Instant.from(zonedDateTime) - } -} - class Iso8601TimestampFormatter( pattern: String, timeZone: TimeZone, - locale: Locale) extends TimestampFormatter with FormatterUtils { - val zoneId = timeZone.toZoneId - val formatter = buildFormatter(pattern, locale) + locale: Locale) extends TimestampFormatter with DateTimeFormatterHelper { + private val formatter = buildFormatter(pattern, locale) - def toInstant(s: String): Instant = { + private def toInstant(s: String): Instant = { val temporalAccessor = formatter.parse(s) if (temporalAccessor.query(TemporalQueries.offset()) == null) { - toInstantWithZoneId(temporalAccessor) + toInstantWithZoneId(temporalAccessor, timeZone.toZoneId) } else { Instant.from(temporalAccessor) } @@ -77,9 +53,9 @@ class Iso8601TimestampFormatter( result } - def parse(s: String): Long = instantToMicros(toInstant(s)) + override def parse(s: String): Long = instantToMicros(toInstant(s)) - def format(us: Long): String = { + override def format(us: Long): String = { val secs = Math.floorDiv(us, DateTimeUtils.MICROS_PER_SECOND) val mos = Math.floorMod(us, DateTimeUtils.MICROS_PER_SECOND) val instant = Instant.ofEpochSecond(secs, mos * DateTimeUtils.NANOS_PER_MICROS) @@ -92,13 +68,13 @@ class LegacyTimestampFormatter( pattern: String, timeZone: TimeZone, locale: Locale) extends TimestampFormatter { - val format = FastDateFormat.getInstance(pattern, timeZone, locale) + private val format = FastDateFormat.getInstance(pattern, timeZone, locale) protected def toMillis(s: String): Long = format.parse(s).getTime - def parse(s: String): Long = toMillis(s) * DateTimeUtils.MICROS_PER_MILLIS + override def parse(s: String): Long = toMillis(s) * DateTimeUtils.MICROS_PER_MILLIS - def format(us: Long): String = { + override def format(us: Long): String = { format.format(DateTimeUtils.toJavaTimestamp(us)) } } @@ -121,74 +97,3 @@ object TimestampFormatter { } } } - -sealed trait DateFormatter { - def parse(s: String): Int // returns days since epoch - def format(days: Int): String -} - -class Iso8601DateFormatter( - pattern: String, - locale: Locale) extends DateFormatter with FormatterUtils { - - val zoneId = ZoneId.of("UTC") - - val formatter = buildFormatter(pattern, locale) - - def toInstant(s: String): Instant = { - val temporalAccessor = formatter.parse(s) - toInstantWithZoneId(temporalAccessor) - } - - override def parse(s: String): Int = { - val seconds = toInstant(s).getEpochSecond - val days = Math.floorDiv(seconds, DateTimeUtils.SECONDS_PER_DAY) - - days.toInt - } - - override def format(days: Int): String = { - val instant = Instant.ofEpochSecond(days * DateTimeUtils.SECONDS_PER_DAY) - formatter.withZone(zoneId).format(instant) - } -} - -class LegacyDateFormatter(pattern: String, locale: Locale) extends DateFormatter { - val format = FastDateFormat.getInstance(pattern, locale) - - def parse(s: String): Int = { - val milliseconds = format.parse(s).getTime - DateTimeUtils.millisToDays(milliseconds) - } - - def format(days: Int): String = { - val date = DateTimeUtils.toJavaDate(days) - format.format(date) - } -} - -class LegacyFallbackDateFormatter( - pattern: String, - locale: Locale) extends LegacyDateFormatter(pattern, locale) { - override def parse(s: String): Int = { - Try(super.parse(s)).orElse { - // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards - // compatibility. - Try(DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(s).getTime)) - }.getOrElse { - // In Spark 1.5.0, we store the data as number of days since epoch in string. - // So, we just convert it to Int. - s.toInt - } - } -} - -object DateFormatter { - def apply(format: String, locale: Locale): DateFormatter = { - if (SQLConf.get.legacyTimeParserEnabled) { - new LegacyFallbackDateFormatter(format, locale) - } else { - new Iso8601DateFormatter(format, locale) - } - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala new file mode 100644 index 0000000000000..019615b81101c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import java.util.Locale + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf + +class DateFormatterSuite extends SparkFunSuite with SQLHelper { + test("parsing dates") { + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { + val formatter = DateFormatter("yyyy-MM-dd", Locale.US) + val daysSinceEpoch = formatter.parse("2018-12-02") + assert(daysSinceEpoch === 17867) + } + } + } + + test("format dates") { + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { + val formatter = DateFormatter("yyyy-MM-dd", Locale.US) + val date = formatter.format(17867) + assert(date === "2018-12-02") + } + } + } + + test("roundtrip date -> days -> date") { + Seq( + "0050-01-01", + "0953-02-02", + "1423-03-08", + "1969-12-31", + "1972-08-25", + "1975-09-26", + "2018-12-12", + "2038-01-01", + "5010-11-17").foreach { date => + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { + val formatter = DateFormatter("yyyy-MM-dd", Locale.US) + val days = formatter.parse(date) + val formatted = formatter.format(days) + assert(date === formatted) + } + } + } + } + + test("roundtrip days -> date -> days") { + Seq( + -701265, + -371419, + -199722, + -1, + 0, + 967, + 2094, + 17877, + 24837, + 1110657).foreach { days => + DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { + val formatter = DateFormatter("yyyy-MM-dd", Locale.US) + val date = formatter.format(days) + val parsed = formatter.parse(date) + assert(days === parsed) + } + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimestampFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala similarity index 66% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimestampFormatterSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala index 43e348c7eebf4..c110ffa01f733 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimestampFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala @@ -21,19 +21,9 @@ import java.util.{Locale, TimeZone} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.plans.SQLHelper -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, TimestampFormatter} -class DateTimestampFormatterSuite extends SparkFunSuite with SQLHelper { - test("parsing dates") { - DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => - withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { - val formatter = DateFormatter("yyyy-MM-dd", Locale.US) - val daysSinceEpoch = formatter.parse("2018-12-02") - assert(daysSinceEpoch === 17867) - } - } - } +class TimestampFormatterSuite extends SparkFunSuite with SQLHelper { test("parsing timestamps using time zones") { val localDate = "2018-12-02T10:11:12.001234" @@ -56,16 +46,6 @@ class DateTimestampFormatterSuite extends SparkFunSuite with SQLHelper { } } - test("format dates") { - DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => - withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { - val formatter = DateFormatter("yyyy-MM-dd", Locale.US) - val date = formatter.format(17867) - assert(date === "2018-12-02") - } - } - } - test("format timestamps using time zones") { val microsSinceEpoch = 1543745472001234L val expectedTimestamp = Map( @@ -87,7 +67,7 @@ class DateTimestampFormatterSuite extends SparkFunSuite with SQLHelper { } } - test("roundtrip timestamp -> micros -> timestamp using timezones") { + test("roundtrip micros -> timestamp -> micros using timezones") { Seq( -58710115316212000L, -18926315945345679L, @@ -107,7 +87,7 @@ class DateTimestampFormatterSuite extends SparkFunSuite with SQLHelper { } } - test("roundtrip micros -> timestamp -> micros using timezones") { + test("roundtrip timestamp -> micros -> timestamp using timezones") { Seq( "0109-07-20T18:38:03.788000", "1370-04-01T10:00:54.654321", @@ -126,49 +106,4 @@ class DateTimestampFormatterSuite extends SparkFunSuite with SQLHelper { } } } - - test("roundtrip date -> days -> date") { - Seq( - "0050-01-01", - "0953-02-02", - "1423-03-08", - "1969-12-31", - "1972-08-25", - "1975-09-26", - "2018-12-12", - "2038-01-01", - "5010-11-17").foreach { date => - DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => - withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { - val formatter = DateFormatter("yyyy-MM-dd", Locale.US) - val days = formatter.parse(date) - val formatted = formatter.format(days) - assert(date === formatted) - } - } - } - } - - test("roundtrip days -> date -> days") { - Seq( - -701265, - -371419, - -199722, - -1, - 0, - 967, - 2094, - 17877, - 24837, - 1110657).foreach { days => - DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone => - withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) { - val formatter = DateFormatter("yyyy-MM-dd", Locale.US) - val date = formatter.format(days) - val parsed = formatter.parse(date) - assert(days === parsed) - } - } - } - } } From c04ad17ccf14a07ffdb2bf637124492a341075f2 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Mon, 17 Dec 2018 09:28:23 -0600 Subject: [PATCH 2298/2461] [SPARK-20351][ML] Add trait hasTrainingSummary to replace the duplicate code ## What changes were proposed in this pull request? Add a trait HasTrainingSummary to avoid code duplicate related to training summary. Currently all the training summary use the similar pattern which can be generalized, ``` private[ml] final var trainingSummary: Option[T] = None def hasSummary: Boolean = trainingSummary.isDefined def summary: T = trainingSummary.getOrElse... private[ml] def setSummary(summary: Option[T]): ... ``` Classes with the trait need to override `setSummry`. And for Java compatibility, they will also have to override `summary` method, otherwise the java code will regard all the summary class as Object due to a known issue with Scala. ## How was this patch tested? existing Java and Scala unit tests Closes #17654 from hhbyyh/hassummary. Authored-by: Yuhao Yang Signed-off-by: Sean Owen --- .../classification/LogisticRegression.scala | 24 ++------- .../spark/ml/clustering/BisectingKMeans.scala | 25 ++------- .../spark/ml/clustering/GaussianMixture.scala | 24 ++------- .../apache/spark/ml/clustering/KMeans.scala | 23 ++------ .../GeneralizedLinearRegression.scala | 22 ++------ .../ml/regression/LinearRegression.scala | 21 ++------ .../spark/ml/util/HasTrainingSummary.scala | 52 +++++++++++++++++++ 7 files changed, 78 insertions(+), 113 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 27a7db0b2f5d4..f2a5c11a34867 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -934,8 +934,8 @@ class LogisticRegressionModel private[spark] ( @Since("2.1.0") val interceptVector: Vector, @Since("1.3.0") override val numClasses: Int, private val isMultinomial: Boolean) - extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] - with LogisticRegressionParams with MLWritable { + extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] with MLWritable + with LogisticRegressionParams with HasTrainingSummary[LogisticRegressionTrainingSummary] { require(coefficientMatrix.numRows == interceptVector.size, s"Dimension mismatch! Expected " + s"coefficientMatrix.numRows == interceptVector.size, but ${coefficientMatrix.numRows} != " + @@ -1018,20 +1018,16 @@ class LogisticRegressionModel private[spark] ( @Since("1.6.0") override val numFeatures: Int = coefficientMatrix.numCols - private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None - /** * Gets summary of model on training set. An exception is thrown - * if `trainingSummary == None`. + * if `hasSummary` is false. */ @Since("1.5.0") - def summary: LogisticRegressionTrainingSummary = trainingSummary.getOrElse { - throw new SparkException("No training summary available for this LogisticRegressionModel") - } + override def summary: LogisticRegressionTrainingSummary = super.summary /** * Gets summary of model on training set. An exception is thrown - * if `trainingSummary == None` or it is a multiclass model. + * if `hasSummary` is false or it is a multiclass model. */ @Since("2.3.0") def binarySummary: BinaryLogisticRegressionTrainingSummary = summary match { @@ -1062,16 +1058,6 @@ class LogisticRegressionModel private[spark] ( (model, model.getProbabilityCol, model.getPredictionCol) } - private[classification] - def setSummary(summary: Option[LogisticRegressionTrainingSummary]): this.type = { - this.trainingSummary = summary - this - } - - /** Indicates whether a training summary exists for this model instance. */ - @Since("1.5.0") - def hasSummary: Boolean = trainingSummary.isDefined - /** * Evaluates the model on a test dataset. * diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 1a94aefa3f563..49e9f51368131 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -87,8 +87,9 @@ private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter @Since("2.0.0") class BisectingKMeansModel private[ml] ( @Since("2.0.0") override val uid: String, - private val parentModel: MLlibBisectingKMeansModel - ) extends Model[BisectingKMeansModel] with BisectingKMeansParams with MLWritable { + private val parentModel: MLlibBisectingKMeansModel) + extends Model[BisectingKMeansModel] with BisectingKMeansParams with MLWritable + with HasTrainingSummary[BisectingKMeansSummary] { @Since("2.0.0") override def copy(extra: ParamMap): BisectingKMeansModel = { @@ -143,28 +144,12 @@ class BisectingKMeansModel private[ml] ( @Since("2.0.0") override def write: MLWriter = new BisectingKMeansModel.BisectingKMeansModelWriter(this) - private var trainingSummary: Option[BisectingKMeansSummary] = None - - private[clustering] def setSummary(summary: Option[BisectingKMeansSummary]): this.type = { - this.trainingSummary = summary - this - } - - /** - * Return true if there exists summary of model. - */ - @Since("2.1.0") - def hasSummary: Boolean = trainingSummary.nonEmpty - /** * Gets summary of model on training set. An exception is - * thrown if `trainingSummary == None`. + * thrown if `hasSummary` is false. */ @Since("2.1.0") - def summary: BisectingKMeansSummary = trainingSummary.getOrElse { - throw new SparkException( - s"No training summary available for the ${this.getClass.getSimpleName}") - } + override def summary: BisectingKMeansSummary = super.summary } object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] { diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 88abc1605d69f..bb10b3228b93f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -85,7 +85,8 @@ class GaussianMixtureModel private[ml] ( @Since("2.0.0") override val uid: String, @Since("2.0.0") val weights: Array[Double], @Since("2.0.0") val gaussians: Array[MultivariateGaussian]) - extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable { + extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable + with HasTrainingSummary[GaussianMixtureSummary] { /** @group setParam */ @Since("2.1.0") @@ -160,28 +161,13 @@ class GaussianMixtureModel private[ml] ( @Since("2.0.0") override def write: MLWriter = new GaussianMixtureModel.GaussianMixtureModelWriter(this) - private var trainingSummary: Option[GaussianMixtureSummary] = None - - private[clustering] def setSummary(summary: Option[GaussianMixtureSummary]): this.type = { - this.trainingSummary = summary - this - } - - /** - * Return true if there exists summary of model. - */ - @Since("2.0.0") - def hasSummary: Boolean = trainingSummary.nonEmpty - /** * Gets summary of model on training set. An exception is - * thrown if `trainingSummary == None`. + * thrown if `hasSummary` is false. */ @Since("2.0.0") - def summary: GaussianMixtureSummary = trainingSummary.getOrElse { - throw new RuntimeException( - s"No training summary available for the ${this.getClass.getSimpleName}") - } + override def summary: GaussianMixtureSummary = super.summary + } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 2eed84d51782a..319747d4a1930 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -107,7 +107,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe class KMeansModel private[ml] ( @Since("1.5.0") override val uid: String, private[clustering] val parentModel: MLlibKMeansModel) - extends Model[KMeansModel] with KMeansParams with GeneralMLWritable { + extends Model[KMeansModel] with KMeansParams with GeneralMLWritable + with HasTrainingSummary[KMeansSummary] { @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { @@ -153,28 +154,12 @@ class KMeansModel private[ml] ( @Since("1.6.0") override def write: GeneralMLWriter = new GeneralMLWriter(this) - private var trainingSummary: Option[KMeansSummary] = None - - private[clustering] def setSummary(summary: Option[KMeansSummary]): this.type = { - this.trainingSummary = summary - this - } - - /** - * Return true if there exists summary of model. - */ - @Since("2.0.0") - def hasSummary: Boolean = trainingSummary.nonEmpty - /** * Gets summary of model on training set. An exception is - * thrown if `trainingSummary == None`. + * thrown if `hasSummary` is false. */ @Since("2.0.0") - def summary: KMeansSummary = trainingSummary.getOrElse { - throw new SparkException( - s"No training summary available for the ${this.getClass.getSimpleName}") - } + override def summary: KMeansSummary = super.summary } /** Helper class for storing model data */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index abb60ea205751..885b13bf8dac3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -1001,7 +1001,8 @@ class GeneralizedLinearRegressionModel private[ml] ( @Since("2.0.0") val coefficients: Vector, @Since("2.0.0") val intercept: Double) extends RegressionModel[Vector, GeneralizedLinearRegressionModel] - with GeneralizedLinearRegressionBase with MLWritable { + with GeneralizedLinearRegressionBase with MLWritable + with HasTrainingSummary[GeneralizedLinearRegressionTrainingSummary] { /** * Sets the link prediction (linear predictor) column name. @@ -1054,29 +1055,12 @@ class GeneralizedLinearRegressionModel private[ml] ( output.toDF() } - private var trainingSummary: Option[GeneralizedLinearRegressionTrainingSummary] = None - /** * Gets R-like summary of model on training set. An exception is * thrown if there is no summary available. */ @Since("2.0.0") - def summary: GeneralizedLinearRegressionTrainingSummary = trainingSummary.getOrElse { - throw new SparkException( - "No training summary available for this GeneralizedLinearRegressionModel") - } - - /** - * Indicates if [[summary]] is available. - */ - @Since("2.0.0") - def hasSummary: Boolean = trainingSummary.nonEmpty - - private[regression] - def setSummary(summary: Option[GeneralizedLinearRegressionTrainingSummary]): this.type = { - this.trainingSummary = summary - this - } + override def summary: GeneralizedLinearRegressionTrainingSummary = super.summary /** * Evaluate the model on the given dataset, returning a summary of the results. diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index ce6c12cc368dd..197828762d160 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -647,33 +647,20 @@ class LinearRegressionModel private[ml] ( @Since("1.3.0") val intercept: Double, @Since("2.3.0") val scale: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams with GeneralMLWritable { + with LinearRegressionParams with GeneralMLWritable + with HasTrainingSummary[LinearRegressionTrainingSummary] { private[ml] def this(uid: String, coefficients: Vector, intercept: Double) = this(uid, coefficients, intercept, 1.0) - private var trainingSummary: Option[LinearRegressionTrainingSummary] = None - override val numFeatures: Int = coefficients.size /** * Gets summary (e.g. residuals, mse, r-squared ) of model on training set. An exception is - * thrown if `trainingSummary == None`. + * thrown if `hasSummary` is false. */ @Since("1.5.0") - def summary: LinearRegressionTrainingSummary = trainingSummary.getOrElse { - throw new SparkException("No training summary available for this LinearRegressionModel") - } - - private[regression] - def setSummary(summary: Option[LinearRegressionTrainingSummary]): this.type = { - this.trainingSummary = summary - this - } - - /** Indicates whether a training summary exists for this model instance. */ - @Since("1.5.0") - def hasSummary: Boolean = trainingSummary.isDefined + override def summary: LinearRegressionTrainingSummary = super.summary /** * Evaluates the model on a test dataset. diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala b/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala new file mode 100644 index 0000000000000..edb0208144e10 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.SparkException +import org.apache.spark.annotation.Since + + +/** + * Trait for models that provides Training summary. + * + * @tparam T Summary instance type + */ +@Since("3.0.0") +private[ml] trait HasTrainingSummary[T] { + + private[ml] final var trainingSummary: Option[T] = None + + /** Indicates whether a training summary exists for this model instance. */ + @Since("3.0.0") + def hasSummary: Boolean = trainingSummary.isDefined + + /** + * Gets summary of model on training set. An exception is + * thrown if if `hasSummary` is false. + */ + @Since("3.0.0") + def summary: T = trainingSummary.getOrElse { + throw new SparkException( + s"No training summary available for this ${this.getClass.getSimpleName}") + } + + private[ml] def setSummary(summary: Option[T]): this.type = { + this.trainingSummary = summary + this + } +} From 6d45e6ea1507943f6ee833af8ad7969294b0356a Mon Sep 17 00:00:00 2001 From: chakravarthi Date: Mon, 17 Dec 2018 09:46:50 -0800 Subject: [PATCH 2299/2461] [SPARK-26255][YARN] Apply user provided UI filters to SQL tab in yarn mode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? User specified filters are not applied to SQL tab in yarn mode, as it is overridden by the yarn AmIp filter. So we need to append user provided filters (spark.ui.filters) with yarn filter. ## How was this patch tested? 【Test step】: 1) Launch spark sql with authentication filter as below: 2) spark-sql --master yarn --conf spark.ui.filters=org.apache.hadoop.security.authentication.server.AuthenticationFilter --conf spark.org.apache.hadoop.security.authentication.server.AuthenticationFilter.params="type=simple" 3) Go to Yarn application list UI link 4) Launch the application master for the Spark-SQL app ID and access all the tabs by appending tab name. 5) It will display an error for all tabs including SQL tab.(before able to access SQL tab,as Authentication filter is not applied for SQL tab) 6) Also can be verified with info logs,that Authentication filter applied to SQL tab.(before it is not applied). I have attached the behaviour below in following order.. 1) Command used 2) Before fix (logs and UI) 3) After fix (logs and UI) **1) COMMAND USED**: launching spark-sql with authentication filter. ![image](https://user-images.githubusercontent.com/45845595/49947295-e7e97400-ff16-11e8-8c9a-10659487ddee.png) **2) BEFORE FIX:** **UI result:** able to access SQL tab. ![image](https://user-images.githubusercontent.com/45845595/49948398-62b38e80-ff19-11e8-95dc-e74f9e3c2ba7.png) **logs**: authentication filter not applied to SQL tab. ![image](https://user-images.githubusercontent.com/45845595/49947343-ff286180-ff16-11e8-9de0-3f8db140bc32.png) **3) AFTER FIX:** **UI result**: Not able to access SQL tab. ![image](https://user-images.githubusercontent.com/45845595/49947360-0d767d80-ff17-11e8-9e9e-a95311949164.png) **in logs**: Both yarn filter and Authentication filter applied to SQL tab. ![image](https://user-images.githubusercontent.com/45845595/49947377-1a936c80-ff17-11e8-9f44-700eb3dc0ded.png) Closes #23312 from chakravarthiT/SPARK-26255_ui. Authored-by: chakravarthi Signed-off-by: Marcelo Vanzin --- .../apache/spark/scheduler/cluster/YarnSchedulerBackend.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 67c36aac49266..1289d4be79ea4 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -168,8 +168,10 @@ private[spark] abstract class YarnSchedulerBackend( filterName != null && filterName.nonEmpty && filterParams != null && filterParams.nonEmpty if (hasFilter) { + // SPARK-26255: Append user provided filters(spark.ui.filters) with yarn filter. + val allFilters = filterName + "," + conf.get("spark.ui.filters", "") logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase") - conf.set("spark.ui.filters", filterName) + conf.set("spark.ui.filters", allFilters) filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) } scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) } } From 5a116e669cb196f59ab3f8d06477f675cd0400f9 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Mon, 17 Dec 2018 10:07:35 -0800 Subject: [PATCH 2300/2461] [SPARK-26371][SS] Increase kafka ConfigUpdater test coverage. ## What changes were proposed in this pull request? As Kafka delegation token added logic into ConfigUpdater it would be good to test it. This PR contains the following changes: * ConfigUpdater extracted to a separate file and renamed to KafkaConfigUpdater * mockito-core dependency added to kafka-0-10-sql * Unit tests added ## How was this patch tested? Existing + new unit tests + on cluster. Closes #23321 from gaborgsomogyi/SPARK-26371. Authored-by: Gabor Somogyi Signed-off-by: Dongjoon Hyun --- external/kafka-0-10-sql/pom.xml | 5 + .../sql/kafka010/KafkaConfigUpdater.scala | 74 ++++++++++++ .../sql/kafka010/KafkaSourceProvider.scala | 52 +------- .../kafka010/KafkaConfigUpdaterSuite.scala | 113 ++++++++++++++++++ .../kafka010/KafkaDelegationTokenTest.scala | 90 ++++++++++++++ .../kafka010/KafkaSecurityHelperSuite.scala | 46 +------ 6 files changed, 287 insertions(+), 93 deletions(-) create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaConfigUpdater.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaConfigUpdaterSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDelegationTokenTest.scala diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index de8731c4b774b..1c77906f43b17 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -106,6 +106,11 @@ ${jetty.version} test + + org.mockito + mockito-core + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaConfigUpdater.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaConfigUpdater.scala new file mode 100644 index 0000000000000..bc1b8019f6a63 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaConfigUpdater.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import scala.collection.JavaConverters._ + +import org.apache.kafka.common.config.SaslConfigs + +import org.apache.spark.SparkEnv +import org.apache.spark.deploy.security.KafkaTokenUtil +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.Kafka + +/** + * Class to conveniently update Kafka config params, while logging the changes + */ +private[kafka010] case class KafkaConfigUpdater(module: String, kafkaParams: Map[String, String]) + extends Logging { + private val map = new ju.HashMap[String, Object](kafkaParams.asJava) + + def set(key: String, value: Object): this.type = { + map.put(key, value) + logDebug(s"$module: Set $key to $value, earlier value: ${kafkaParams.getOrElse(key, "")}") + this + } + + def setIfUnset(key: String, value: Object): this.type = { + if (!map.containsKey(key)) { + map.put(key, value) + logDebug(s"$module: Set $key to $value") + } + this + } + + def setAuthenticationConfigIfNeeded(): this.type = { + // There are multiple possibilities to log in and applied in the following order: + // - JVM global security provided -> try to log in with JVM global security configuration + // which can be configured for example with 'java.security.auth.login.config'. + // For this no additional parameter needed. + // - Token is provided -> try to log in with scram module using kafka's dynamic JAAS + // configuration. + if (KafkaTokenUtil.isGlobalJaasConfigurationProvided) { + logDebug("JVM global security configuration detected, using it for login.") + } else if (KafkaSecurityHelper.isTokenAvailable()) { + logDebug("Delegation token detected, using it for login.") + val jaasParams = KafkaSecurityHelper.getTokenJaasParams(SparkEnv.get.conf) + set(SaslConfigs.SASL_JAAS_CONFIG, jaasParams) + val mechanism = SparkEnv.get.conf.get(Kafka.TOKEN_SASL_MECHANISM) + require(mechanism.startsWith("SCRAM"), + "Delegation token works only with SCRAM mechanism.") + set(SaslConfigs.SASL_MECHANISM, mechanism) + } + this + } + + def build(): ju.Map[String, Object] = map +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 4b8b5c0019b44..5774ee7a1c945 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -24,13 +24,9 @@ import scala.collection.JavaConverters._ import org.apache.kafka.clients.consumer.ConsumerConfig import org.apache.kafka.clients.producer.ProducerConfig -import org.apache.kafka.common.config.SaslConfigs import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} -import org.apache.spark.SparkEnv -import org.apache.spark.deploy.security.KafkaTokenUtil import org.apache.spark.internal.Logging -import org.apache.spark.internal.config._ import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ @@ -483,7 +479,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { } def kafkaParamsForDriver(specifiedKafkaParams: Map[String, String]): ju.Map[String, Object] = - ConfigUpdater("source", specifiedKafkaParams) + KafkaConfigUpdater("source", specifiedKafkaParams) .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) @@ -506,7 +502,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { def kafkaParamsForExecutors( specifiedKafkaParams: Map[String, String], uniqueGroupId: String): ju.Map[String, Object] = - ConfigUpdater("executor", specifiedKafkaParams) + KafkaConfigUpdater("executor", specifiedKafkaParams) .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) @@ -537,48 +533,6 @@ private[kafka010] object KafkaSourceProvider extends Logging { s"${groupIdPrefix}-${UUID.randomUUID}-${metadataPath.hashCode}" } - /** Class to conveniently update Kafka config params, while logging the changes */ - private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) { - private val map = new ju.HashMap[String, Object](kafkaParams.asJava) - - def set(key: String, value: Object): this.type = { - map.put(key, value) - logDebug(s"$module: Set $key to $value, earlier value: ${kafkaParams.getOrElse(key, "")}") - this - } - - def setIfUnset(key: String, value: Object): ConfigUpdater = { - if (!map.containsKey(key)) { - map.put(key, value) - logDebug(s"$module: Set $key to $value") - } - this - } - - def setAuthenticationConfigIfNeeded(): ConfigUpdater = { - // There are multiple possibilities to log in and applied in the following order: - // - JVM global security provided -> try to log in with JVM global security configuration - // which can be configured for example with 'java.security.auth.login.config'. - // For this no additional parameter needed. - // - Token is provided -> try to log in with scram module using kafka's dynamic JAAS - // configuration. - if (KafkaTokenUtil.isGlobalJaasConfigurationProvided) { - logDebug("JVM global security configuration detected, using it for login.") - } else if (KafkaSecurityHelper.isTokenAvailable()) { - logDebug("Delegation token detected, using it for login.") - val jaasParams = KafkaSecurityHelper.getTokenJaasParams(SparkEnv.get.conf) - set(SaslConfigs.SASL_JAAS_CONFIG, jaasParams) - val mechanism = SparkEnv.get.conf.get(Kafka.TOKEN_SASL_MECHANISM) - require(mechanism.startsWith("SCRAM"), - "Delegation token works only with SCRAM mechanism.") - set(SaslConfigs.SASL_MECHANISM, mechanism) - } - this - } - - def build(): ju.Map[String, Object] = map - } - private[kafka010] def kafkaParamsForProducer( parameters: Map[String, String]): ju.Map[String, Object] = { val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } @@ -596,7 +550,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { val specifiedKafkaParams = convertToSpecifiedParams(parameters) - ConfigUpdater("executor", specifiedKafkaParams) + KafkaConfigUpdater("executor", specifiedKafkaParams) .set(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, serClassName) .set(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, serClassName) .setAuthenticationConfigIfNeeded() diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaConfigUpdaterSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaConfigUpdaterSuite.scala new file mode 100644 index 0000000000000..25ccca3cb9846 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaConfigUpdaterSuite.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.common.config.SaslConfigs + +import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.config._ + +class KafkaConfigUpdaterSuite extends SparkFunSuite with KafkaDelegationTokenTest { + private val testModule = "testModule" + private val testKey = "testKey" + private val testValue = "testValue" + private val otherTestValue = "otherTestValue" + + test("set should always set value") { + val params = Map.empty[String, String] + + val updatedParams = KafkaConfigUpdater(testModule, params) + .set(testKey, testValue) + .build() + + assert(updatedParams.size() === 1) + assert(updatedParams.get(testKey) === testValue) + } + + test("setIfUnset without existing key should set value") { + val params = Map.empty[String, String] + + val updatedParams = KafkaConfigUpdater(testModule, params) + .setIfUnset(testKey, testValue) + .build() + + assert(updatedParams.size() === 1) + assert(updatedParams.get(testKey) === testValue) + } + + test("setIfUnset with existing key should not set value") { + val params = Map[String, String](testKey -> testValue) + + val updatedParams = KafkaConfigUpdater(testModule, params) + .setIfUnset(testKey, otherTestValue) + .build() + + assert(updatedParams.size() === 1) + assert(updatedParams.get(testKey) === testValue) + } + + test("setAuthenticationConfigIfNeeded with global security should not set values") { + val params = Map.empty[String, String] + setGlobalKafkaClientConfig() + + val updatedParams = KafkaConfigUpdater(testModule, params) + .setAuthenticationConfigIfNeeded() + .build() + + assert(updatedParams.size() === 0) + } + + test("setAuthenticationConfigIfNeeded with token should set values") { + val params = Map.empty[String, String] + setSparkEnv(Map.empty) + addTokenToUGI() + + val updatedParams = KafkaConfigUpdater(testModule, params) + .setAuthenticationConfigIfNeeded() + .build() + + assert(updatedParams.size() === 2) + assert(updatedParams.containsKey(SaslConfigs.SASL_JAAS_CONFIG)) + assert(updatedParams.get(SaslConfigs.SASL_MECHANISM) === + Kafka.TOKEN_SASL_MECHANISM.defaultValueString) + } + + test("setAuthenticationConfigIfNeeded with token and invalid mechanism should throw exception") { + val params = Map.empty[String, String] + setSparkEnv(Map[String, String](Kafka.TOKEN_SASL_MECHANISM.key -> "INVALID")) + addTokenToUGI() + + val e = intercept[IllegalArgumentException] { + KafkaConfigUpdater(testModule, params) + .setAuthenticationConfigIfNeeded() + .build() + } + + assert(e.getMessage.contains("Delegation token works only with SCRAM mechanism.")) + } + + test("setAuthenticationConfigIfNeeded without security should not set values") { + val params = Map.empty[String, String] + + val updatedParams = KafkaConfigUpdater(testModule, params) + .setAuthenticationConfigIfNeeded() + .build() + + assert(updatedParams.size() === 0) + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDelegationTokenTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDelegationTokenTest.scala new file mode 100644 index 0000000000000..1899c65c721bb --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDelegationTokenTest.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} +import javax.security.auth.login.{AppConfigurationEntry, Configuration} + +import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.hadoop.security.token.Token +import org.mockito.Mockito.{doReturn, mock} +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite} +import org.apache.spark.deploy.security.KafkaTokenUtil +import org.apache.spark.deploy.security.KafkaTokenUtil.KafkaDelegationTokenIdentifier + +/** + * This is a trait which provides functionalities for Kafka delegation token related test suites. + */ +trait KafkaDelegationTokenTest extends BeforeAndAfterEach { + self: SparkFunSuite => + + protected val tokenId = "tokenId" + ju.UUID.randomUUID().toString + protected val tokenPassword = "tokenPassword" + ju.UUID.randomUUID().toString + + private class KafkaJaasConfiguration extends Configuration { + val entry = + new AppConfigurationEntry( + "DummyModule", + AppConfigurationEntry.LoginModuleControlFlag.REQUIRED, + ju.Collections.emptyMap[String, Object]() + ) + + override def getAppConfigurationEntry(name: String): Array[AppConfigurationEntry] = { + if (name.equals("KafkaClient")) { + Array(entry) + } else { + null + } + } + } + + override def afterEach(): Unit = { + try { + Configuration.setConfiguration(null) + UserGroupInformation.setLoginUser(null) + SparkEnv.set(null) + } finally { + super.afterEach() + } + } + + protected def setGlobalKafkaClientConfig(): Unit = { + Configuration.setConfiguration(new KafkaJaasConfiguration) + } + + protected def addTokenToUGI(): Unit = { + val token = new Token[KafkaDelegationTokenIdentifier]( + tokenId.getBytes, + tokenPassword.getBytes, + KafkaTokenUtil.TOKEN_KIND, + KafkaTokenUtil.TOKEN_SERVICE + ) + val creds = new Credentials() + creds.addToken(KafkaTokenUtil.TOKEN_SERVICE, token) + UserGroupInformation.getCurrentUser.addCredentials(creds) + } + + protected def setSparkEnv(settings: Traversable[(String, String)]): Unit = { + val conf = new SparkConf().setAll(settings) + val env = mock(classOf[SparkEnv]) + doReturn(conf).when(env).conf + SparkEnv.set(env) + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelperSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelperSuite.scala index fd9dee390d185..d908bbfc2c5f4 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelperSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSecurityHelperSuite.scala @@ -17,51 +17,9 @@ package org.apache.spark.sql.kafka010 -import java.util.UUID - -import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.apache.hadoop.security.token.Token -import org.scalatest.BeforeAndAfterEach - import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.security.KafkaTokenUtil -import org.apache.spark.deploy.security.KafkaTokenUtil.KafkaDelegationTokenIdentifier - -class KafkaSecurityHelperSuite extends SparkFunSuite with BeforeAndAfterEach { - private val tokenId = "tokenId" + UUID.randomUUID().toString - private val tokenPassword = "tokenPassword" + UUID.randomUUID().toString - - private var sparkConf: SparkConf = null - - override def beforeEach(): Unit = { - super.beforeEach() - sparkConf = new SparkConf() - } - - override def afterEach(): Unit = { - try { - resetUGI - } finally { - super.afterEach() - } - } - - private def addTokenToUGI(): Unit = { - val token = new Token[KafkaDelegationTokenIdentifier]( - tokenId.getBytes, - tokenPassword.getBytes, - KafkaTokenUtil.TOKEN_KIND, - KafkaTokenUtil.TOKEN_SERVICE - ) - val creds = new Credentials() - creds.addToken(KafkaTokenUtil.TOKEN_SERVICE, token) - UserGroupInformation.getCurrentUser.addCredentials(creds) - } - - private def resetUGI: Unit = { - UserGroupInformation.setLoginUser(null) - } +class KafkaSecurityHelperSuite extends SparkFunSuite with KafkaDelegationTokenTest { test("isTokenAvailable without token should return false") { assert(!KafkaSecurityHelper.isTokenAvailable()) } @@ -75,7 +33,7 @@ class KafkaSecurityHelperSuite extends SparkFunSuite with BeforeAndAfterEach { test("getTokenJaasParams with token should return scram module") { addTokenToUGI() - val jaasParams = KafkaSecurityHelper.getTokenJaasParams(sparkConf) + val jaasParams = KafkaSecurityHelper.getTokenJaasParams(new SparkConf()) assert(jaasParams.contains("ScramLoginModule required")) assert(jaasParams.contains("tokenauth=true")) From 81d377d772a527d9ae3311be0480e6403769e919 Mon Sep 17 00:00:00 2001 From: Vaclav Kosar Date: Mon, 17 Dec 2018 11:50:24 -0800 Subject: [PATCH 2301/2461] [SPARK-24933][SS] Report numOutputRows in SinkProgress ## What changes were proposed in this pull request? SinkProgress should report similar properties like SourceProgress as long as they are available for given Sink. Count of written rows is metric availble for all Sinks. Since relevant progress information is with respect to commited rows, ideal object to carry this info is WriterCommitMessage. For brevity the implementation will focus only on Sinks with API V2 and on Micro Batch mode. Implemention for Continuous mode will be provided at later date. ### Before ``` {"description":"org.apache.spark.sql.kafka010.KafkaSourceProvider3c0bd317"} ``` ### After ``` {"description":"org.apache.spark.sql.kafka010.KafkaSourceProvider3c0bd317","numOutputRows":5000} ``` ### This PR is related to: - https://issues.apache.org/jira/browse/SPARK-24647 - https://issues.apache.org/jira/browse/SPARK-21313 ## How was this patch tested? Existing and new unit tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #21919 from vackosar/feature/SPARK-24933-numOutputRows. Lead-authored-by: Vaclav Kosar Co-authored-by: Kosar, Vaclav: Functions Transformation Signed-off-by: gatorsmile --- .../spark/sql/kafka010/KafkaSinkSuite.scala | 21 +++++++++++++ .../v2/WriteToDataSourceV2Exec.scala | 30 +++++++++++++++---- .../streaming/MicroBatchExecution.scala | 11 +++++-- .../streaming/ProgressReporter.scala | 7 +++-- .../execution/streaming/StreamExecution.scala | 4 +++ .../apache/spark/sql/streaming/progress.scala | 21 +++++++++++-- ...StreamingQueryStatusAndProgressSuite.scala | 10 ++++--- 7 files changed, 88 insertions(+), 16 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index d46c4139011da..07d2b8a5dc420 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -232,6 +232,27 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext with KafkaTest { } } + test("streaming - sink progress is produced") { + /* ensure sink progress is correctly produced. */ + val input = MemoryStream[String] + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Update()))() + + try { + input.addData("1", "2", "3") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + assert(writer.lastProgress.sink.numOutputRows == 3L) + } finally { + writer.stop() + } + } test("streaming - write data with bad schema") { val input = MemoryStream[String] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 9a1fe1e0a328b..d7e20eed4cbc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{LongAccumulator, Utils} /** * Deprecated logical plan for writing data into data source v2. This is being replaced by more @@ -47,6 +47,8 @@ case class WriteToDataSourceV2(writeSupport: BatchWriteSupport, query: LogicalPl case class WriteToDataSourceV2Exec(writeSupport: BatchWriteSupport, query: SparkPlan) extends UnaryExecNode { + var commitProgress: Option[StreamWriterCommitProgress] = None + override def child: SparkPlan = query override def output: Seq[Attribute] = Nil @@ -55,6 +57,7 @@ case class WriteToDataSourceV2Exec(writeSupport: BatchWriteSupport, query: Spark val useCommitCoordinator = writeSupport.useCommitCoordinator val rdd = query.execute() val messages = new Array[WriterCommitMessage](rdd.partitions.length) + val totalNumRowsAccumulator = new LongAccumulator() logInfo(s"Start processing data source write support: $writeSupport. " + s"The input RDD has ${messages.length} partitions.") @@ -65,15 +68,18 @@ case class WriteToDataSourceV2Exec(writeSupport: BatchWriteSupport, query: Spark (context: TaskContext, iter: Iterator[InternalRow]) => DataWritingSparkTask.run(writerFactory, context, iter, useCommitCoordinator), rdd.partitions.indices, - (index, message: WriterCommitMessage) => { - messages(index) = message - writeSupport.onDataWriterCommit(message) + (index, result: DataWritingSparkTaskResult) => { + val commitMessage = result.writerCommitMessage + messages(index) = commitMessage + totalNumRowsAccumulator.add(result.numRows) + writeSupport.onDataWriterCommit(commitMessage) } ) logInfo(s"Data source write support $writeSupport is committing.") writeSupport.commit(messages) logInfo(s"Data source write support $writeSupport committed.") + commitProgress = Some(StreamWriterCommitProgress(totalNumRowsAccumulator.value)) } catch { case cause: Throwable => logError(s"Data source write support $writeSupport is aborting.") @@ -102,7 +108,7 @@ object DataWritingSparkTask extends Logging { writerFactory: DataWriterFactory, context: TaskContext, iter: Iterator[InternalRow], - useCommitCoordinator: Boolean): WriterCommitMessage = { + useCommitCoordinator: Boolean): DataWritingSparkTaskResult = { val stageId = context.stageId() val stageAttempt = context.stageAttemptNumber() val partId = context.partitionId() @@ -110,9 +116,12 @@ object DataWritingSparkTask extends Logging { val attemptId = context.attemptNumber() val dataWriter = writerFactory.createWriter(partId, taskId) + var count = 0L // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { while (iter.hasNext) { + // Count is here. + count += 1 dataWriter.write(iter.next()) } @@ -139,7 +148,7 @@ object DataWritingSparkTask extends Logging { logInfo(s"Committed partition $partId (task $taskId, attempt $attemptId" + s"stage $stageId.$stageAttempt)") - msg + DataWritingSparkTaskResult(count, msg) })(catchBlock = { // If there is an error, abort this writer @@ -151,3 +160,12 @@ object DataWritingSparkTask extends Logging { }) } } + +private[v2] case class DataWritingSparkTaskResult( + numRows: Long, + writerCommitMessage: WriterCommitMessage) + +/** + * Sink progress information collected after commit. + */ +private[sql] case class StreamWriterCommitProgress(numOutputRows: Long) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 03beefeca269b..8ad436a4ff57d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, StreamWriterCommitProgress, WriteToDataSourceV2, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWritSupport, RateControlMicroBatchReadSupport} import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} @@ -246,6 +246,7 @@ class MicroBatchExecution( * DONE */ private def populateStartOffsets(sparkSessionToRunBatches: SparkSession): Unit = { + sinkCommitProgress = None offsetLog.getLatest() match { case Some((latestBatchId, nextOffsets)) => /* First assume that we are re-executing the latest known batch @@ -537,7 +538,8 @@ class MicroBatchExecution( val nextBatch = new Dataset(sparkSessionToRunBatch, lastExecution, RowEncoder(lastExecution.analyzed.schema)) - reportTimeTaken("addBatch") { + val batchSinkProgress: Option[StreamWriterCommitProgress] = + reportTimeTaken("addBatch") { SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { sink match { case s: Sink => s.addBatch(currentBatchId, nextBatch) @@ -545,10 +547,15 @@ class MicroBatchExecution( // This doesn't accumulate any data - it just forces execution of the microbatch writer. nextBatch.collect() } + lastExecution.executedPlan match { + case w: WriteToDataSourceV2Exec => w.commitProgress + case _ => None + } } } withProgressLocked { + sinkCommitProgress = batchSinkProgress watermarkTracker.updateWatermark(lastExecution.executedPlan) commitLog.add(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark)) committedOffsets ++= availableOffsets diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 39ab702ee083c..d1f3f74c5e731 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2StreamingScanExec +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2StreamingScanExec, StreamWriterCommitProgress} import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent @@ -56,6 +56,7 @@ trait ProgressReporter extends Logging { protected def logicalPlan: LogicalPlan protected def lastExecution: QueryExecution protected def newData: Map[BaseStreamingSource, LogicalPlan] + protected def sinkCommitProgress: Option[StreamWriterCommitProgress] protected def sources: Seq[BaseStreamingSource] protected def sink: BaseStreamingSink protected def offsetSeqMetadata: OffsetSeqMetadata @@ -167,7 +168,9 @@ trait ProgressReporter extends Logging { ) } - val sinkProgress = new SinkProgress(sink.toString) + val sinkProgress = SinkProgress( + sink.toString, + sinkCommitProgress.map(_.numOutputRows)) val newProgress = new StreamingQueryProgress( id = id, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 89b4f40c9c0b9..83824f40ab90b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.StreamingExplainCommand +import org.apache.spark.sql.execution.datasources.v2.StreamWriterCommitProgress import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ import org.apache.spark.util.{Clock, UninterruptibleThread, Utils} @@ -114,6 +115,9 @@ abstract class StreamExecution( @volatile var availableOffsets = new StreamProgress + @volatile + var sinkCommitProgress: Option[StreamWriterCommitProgress] = None + /** The current batchId or -1 if execution has not yet been initialized. */ protected var currentBatchId: Long = -1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala index 3cd6700efef5f..0b3945cbd1323 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala @@ -30,6 +30,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.Evolving +import org.apache.spark.sql.streaming.SinkProgress.DEFAULT_NUM_OUTPUT_ROWS /** * Information about updates made to stateful operators in a [[StreamingQuery]] during a trigger. @@ -207,11 +208,19 @@ class SourceProgress protected[sql]( * during a trigger. See [[StreamingQueryProgress]] for more information. * * @param description Description of the source corresponding to this status. + * @param numOutputRows Number of rows written to the sink or -1 for Continuous Mode (temporarily) + * or Sink V1 (until decommissioned). * @since 2.1.0 */ @Evolving class SinkProgress protected[sql]( - val description: String) extends Serializable { + val description: String, + val numOutputRows: Long) extends Serializable { + + /** SinkProgress without custom metrics. */ + protected[sql] def this(description: String) { + this(description, DEFAULT_NUM_OUTPUT_ROWS) + } /** The compact JSON representation of this progress. */ def json: String = compact(render(jsonValue)) @@ -222,6 +231,14 @@ class SinkProgress protected[sql]( override def toString: String = prettyJson private[sql] def jsonValue: JValue = { - ("description" -> JString(description)) + ("description" -> JString(description)) ~ + ("numOutputRows" -> JInt(numOutputRows)) } } + +private[sql] object SinkProgress { + val DEFAULT_NUM_OUTPUT_ROWS: Long = -1L + + def apply(description: String, numOutputRows: Option[Long]): SinkProgress = + new SinkProgress(description, numOutputRows.getOrElse(DEFAULT_NUM_OUTPUT_ROWS)) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala index 7bef687e7e43b..2f460b044b237 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala @@ -73,7 +73,8 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { | "inputRowsPerSecond" : 10.0 | } ], | "sink" : { - | "description" : "sink" + | "description" : "sink", + | "numOutputRows" : -1 | } |} """.stripMargin.trim) @@ -105,7 +106,8 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { | "numInputRows" : 678 | } ], | "sink" : { - | "description" : "sink" + | "description" : "sink", + | "numOutputRows" : -1 | } |} """.stripMargin.trim) @@ -250,7 +252,7 @@ object StreamingQueryStatusAndProgressSuite { processedRowsPerSecond = Double.PositiveInfinity // should not be present in the json ) ), - sink = new SinkProgress("sink") + sink = SinkProgress("sink", None) ) val testProgress2 = new StreamingQueryProgress( @@ -274,7 +276,7 @@ object StreamingQueryStatusAndProgressSuite { processedRowsPerSecond = Double.NegativeInfinity // should not be present in the json ) ), - sink = new SinkProgress("sink") + sink = SinkProgress("sink", None) ) val testStatus = new StreamingQueryStatus("active", true, false) From 114d0de14c441f06d98ab1bcf6c8375c58ecd9ab Mon Sep 17 00:00:00 2001 From: suxingfate Date: Mon, 17 Dec 2018 13:36:57 -0800 Subject: [PATCH 2302/2461] [SPARK-25922][K8] Spark Driver/Executor "spark-app-selector" label mismatch ## What changes were proposed in this pull request? In K8S Cluster mode, the algorithm to generate spark-app-selector/spark.app.id of spark driver is different with spark executor. This patch makes sure spark driver and executor to use the same spark-app-selector/spark.app.id if spark.app.id is set, otherwise it will use superclass applicationId. In K8S Client mode, spark-app-selector/spark.app.id for executors will use superclass applicationId. ## How was this patch tested? Manually run." Closes #23322 from suxingfate/SPARK-25922. Lead-authored-by: suxingfate Co-authored-by: xinglwang Signed-off-by: Yinan Li --- .../KubernetesClusterSchedulerBackend.scala | 28 ++++++++++++++----- ...bernetesClusterSchedulerBackendSuite.scala | 14 +++++----- 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index 68f6f2e46e316..03f5da2bb0bce 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -18,9 +18,10 @@ package org.apache.spark.scheduler.cluster.k8s import java.util.concurrent.ExecutorService -import io.fabric8.kubernetes.client.KubernetesClient import scala.concurrent.{ExecutionContext, Future} +import io.fabric8.kubernetes.client.KubernetesClient + import org.apache.spark.SparkContext import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ @@ -39,10 +40,10 @@ private[spark] class KubernetesClusterSchedulerBackend( lifecycleEventHandler: ExecutorPodsLifecycleManager, watchEvents: ExecutorPodsWatchSnapshotSource, pollEvents: ExecutorPodsPollingSnapshotSource) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) { + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) { - private implicit val requestExecutorContext = ExecutionContext.fromExecutorService( - requestExecutorsService) + private implicit val requestExecutorContext = + ExecutionContext.fromExecutorService(requestExecutorsService) protected override val minRegisteredRatio = if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { @@ -60,6 +61,17 @@ private[spark] class KubernetesClusterSchedulerBackend( removeExecutor(executorId, reason) } + /** + * Get an application ID associated with the job. + * This returns the string value of spark.app.id if set, otherwise + * the locally-generated ID from the superclass. + * + * @return The application ID + */ + override def applicationId(): String = { + conf.getOption("spark.app.id").map(_.toString).getOrElse(super.applicationId) + } + override def start(): Unit = { super.start() if (!Utils.isDynamicAllocationEnabled(conf)) { @@ -88,7 +100,8 @@ private[spark] class KubernetesClusterSchedulerBackend( if (shouldDeleteExecutors) { Utils.tryLogNonFatalError { - kubernetesClient.pods() + kubernetesClient + .pods() .withLabel(SPARK_APP_ID_LABEL, applicationId()) .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) .delete() @@ -120,7 +133,8 @@ private[spark] class KubernetesClusterSchedulerBackend( } override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] { - kubernetesClient.pods() + kubernetesClient + .pods() .withLabel(SPARK_APP_ID_LABEL, applicationId()) .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) .withLabelIn(SPARK_EXECUTOR_ID_LABEL, executorIds: _*) @@ -133,7 +147,7 @@ private[spark] class KubernetesClusterSchedulerBackend( } private class KubernetesDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) - extends DriverEndpoint(rpcEnv, sparkProperties) { + extends DriverEndpoint(rpcEnv, sparkProperties) { override def onDisconnected(rpcAddress: RpcAddress): Unit = { // Don't do anything besides disabling the executor - allow the Kubernetes API events to diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index 75232f7b98b04..6e182bed459f8 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -37,6 +37,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private val requestExecutorsService = new DeterministicScheduler() private val sparkConf = new SparkConf(false) .set("spark.executor.instances", "3") + .set("spark.app.id", TEST_SPARK_APP_ID) @Mock private var sc: SparkContext = _ @@ -87,8 +88,10 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn when(sc.env).thenReturn(env) when(env.rpcEnv).thenReturn(rpcEnv) driverEndpoint = ArgumentCaptor.forClass(classOf[RpcEndpoint]) - when(rpcEnv.setupEndpoint( - mockitoEq(CoarseGrainedSchedulerBackend.ENDPOINT_NAME), driverEndpoint.capture())) + when( + rpcEnv.setupEndpoint( + mockitoEq(CoarseGrainedSchedulerBackend.ENDPOINT_NAME), + driverEndpoint.capture())) .thenReturn(driverEndpointRef) when(kubernetesClient.pods()).thenReturn(podOperations) schedulerBackendUnderTest = new KubernetesClusterSchedulerBackend( @@ -100,9 +103,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn podAllocator, lifecycleEventHandler, watchEvents, - pollEvents) { - override def applicationId(): String = TEST_SPARK_APP_ID - } + pollEvents) } test("Start all components") { @@ -127,8 +128,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn test("Remove executor") { schedulerBackendUnderTest.start() - schedulerBackendUnderTest.doRemoveExecutor( - "1", ExecutorKilled) + schedulerBackendUnderTest.doRemoveExecutor("1", ExecutorKilled) verify(driverEndpointRef).send(RemoveExecutor("1", ExecutorKilled)) } From 86100df54ba8413bebd6ca243b55a6007bc7a2de Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 18 Dec 2018 09:15:21 +0800 Subject: [PATCH 2303/2461] [SPARK-24561][SQL][PYTHON] User-defined window aggregation functions with Pandas UDF (bounded window) ## What changes were proposed in this pull request? This PR implements a new feature - window aggregation Pandas UDF for bounded window. #### Doc: https://docs.google.com/document/d/14EjeY5z4-NC27-SmIP9CsMPCANeTcvxN44a7SIJtZPc/edit#heading=h.c87w44wcj3wj #### Example: ``` from pyspark.sql.functions import pandas_udf, PandasUDFType from pyspark.sql.window import Window df = spark.range(0, 10, 2).toDF('v') w1 = Window.partitionBy().orderBy('v').rangeBetween(-2, 4) w2 = Window.partitionBy().orderBy('v').rowsBetween(-2, 2) pandas_udf('double', PandasUDFType.GROUPED_AGG) def avg(v): return v.mean() df.withColumn('v_mean', avg(df['v']).over(w1)).show() # +---+------+ # | v|v_mean| # +---+------+ # | 0| 1.0| # | 2| 2.0| # | 4| 4.0| # | 6| 6.0| # | 8| 7.0| # +---+------+ df.withColumn('v_mean', avg(df['v']).over(w2)).show() # +---+------+ # | v|v_mean| # +---+------+ # | 0| 2.0| # | 2| 3.0| # | 4| 4.0| # | 6| 5.0| # | 8| 6.0| # +---+------+ ``` #### High level changes: This PR modifies the existing WindowInPandasExec physical node to deal with unbounded (growing, shrinking and sliding) windows. * `WindowInPandasExec` now share the same base class as `WindowExec` and share utility functions. See `WindowExecBase` * `WindowFunctionFrame` now has two new functions `currentLowerBound` and `currentUpperBound` - to return the lower and upper window bound for the current output row. It is also modified to allow `AggregateProcessor` == null. Null aggregator processor is used for `WindowInPandasExec` where we don't have an aggregator and only uses lower and upper bound functions from `WindowFunctionFrame` * The biggest change is in `WindowInPandasExec`, where it is modified to take `currentLowerBound` and `currentUpperBound` and write those values together with the input data to the python process for rolling window aggregation. See `WindowInPandasExec` for more details. #### Discussion In benchmarking, I found numpy variant of the rolling window UDF is much faster than the pandas version: Spark SQL window function: 20s Pandas variant: ~80s Numpy variant: 10s Numpy variant with numba: 4s Allowing numpy variant of the vectorized UDFs is something I want to discuss because of the performance improvement, but doesn't have to be in this PR. ## How was this patch tested? New tests Closes #22305 from icexelloss/SPARK-24561-bounded-window-udf. Authored-by: Li Jin Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/functions.py | 21 +- .../sql/tests/test_pandas_udf_window.py | 157 ++++++++- python/pyspark/worker.py | 57 ++- .../sql/catalyst/analysis/CheckAnalysis.scala | 5 - .../execution/python/WindowInPandasExec.scala | 329 +++++++++++++++--- .../sql/execution/window/WindowExec.scala | 189 +--------- .../sql/execution/window/WindowExecBase.scala | 230 ++++++++++++ .../window/WindowFunctionFrame.scala | 108 ++++-- 8 files changed, 792 insertions(+), 304 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f98e550e39da8..d188de39e21c7 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2982,8 +2982,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 2| 6.0| +---+-----------+ - This example shows using grouped aggregated UDFs as window functions. Note that only - unbounded window frame is supported at the moment: + This example shows using grouped aggregated UDFs as window functions. >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> from pyspark.sql import Window @@ -2993,20 +2992,24 @@ def pandas_udf(f=None, returnType=None, functionType=None): >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP ... def mean_udf(v): ... return v.mean() - >>> w = Window \\ - ... .partitionBy('id') \\ - ... .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + >>> w = (Window.partitionBy('id') + ... .orderBy('v') + ... .rowsBetween(-1, 0)) >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP +---+----+------+ | id| v|mean_v| +---+----+------+ - | 1| 1.0| 1.5| + | 1| 1.0| 1.0| | 1| 2.0| 1.5| - | 2| 3.0| 6.0| - | 2| 5.0| 6.0| - | 2|10.0| 6.0| + | 2| 3.0| 3.0| + | 2| 5.0| 4.0| + | 2|10.0| 7.5| +---+----+------+ + .. note:: For performance reasons, the input series to window functions are not copied. + Therefore, mutating the input series is not allowed and will cause incorrect results. + For the same reason, users should also not rely on the index of the input series. + .. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window` .. note:: The user-defined functions are considered deterministic by default. Due to diff --git a/python/pyspark/sql/tests/test_pandas_udf_window.py b/python/pyspark/sql/tests/test_pandas_udf_window.py index 0a7a19c1c0814..3ba98e76468b3 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_window.py +++ b/python/pyspark/sql/tests/test_pandas_udf_window.py @@ -46,6 +46,15 @@ def python_plus_one(self): def pandas_scalar_time_two(self): return pandas_udf(lambda v: v * 2, 'double') + @property + def pandas_agg_count_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('long', PandasUDFType.GROUPED_AGG) + def count(v): + return len(v) + return count + @property def pandas_agg_mean_udf(self): @pandas_udf('double', PandasUDFType.GROUPED_AGG) @@ -70,7 +79,7 @@ def min(v): @property def unbounded_window(self): return Window.partitionBy('id') \ - .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing).orderBy('v') @property def ordered_window(self): @@ -80,6 +89,32 @@ def ordered_window(self): def unpartitioned_window(self): return Window.partitionBy() + @property + def sliding_row_window(self): + return Window.partitionBy('id').orderBy('v').rowsBetween(-2, 1) + + @property + def sliding_range_window(self): + return Window.partitionBy('id').orderBy('v').rangeBetween(-2, 4) + + @property + def growing_row_window(self): + return Window.partitionBy('id').orderBy('v').rowsBetween(Window.unboundedPreceding, 3) + + @property + def growing_range_window(self): + return Window.partitionBy('id').orderBy('v') \ + .rangeBetween(Window.unboundedPreceding, 4) + + @property + def shrinking_row_window(self): + return Window.partitionBy('id').orderBy('v').rowsBetween(-2, Window.unboundedFollowing) + + @property + def shrinking_range_window(self): + return Window.partitionBy('id').orderBy('v') \ + .rangeBetween(-3, Window.unboundedFollowing) + def test_simple(self): df = self.data w = self.unbounded_window @@ -100,12 +135,12 @@ def test_multiple_udfs(self): w = self.unbounded_window result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \ - .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \ - .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w)) + .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \ + .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w)) expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) \ - .withColumn('max_v', max(df['v']).over(w)) \ - .withColumn('min_w', min(df['w']).over(w)) + .withColumn('max_v', max(df['v']).over(w)) \ + .withColumn('min_w', min(df['w']).over(w)) self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) @@ -183,16 +218,16 @@ def test_mixed_sql_and_udf(self): # Test chaining sql aggregate function and udf result3 = df.withColumn('max_v', max_udf(df['v']).over(w)) \ - .withColumn('min_v', min(df['v']).over(w)) \ - .withColumn('v_diff', col('max_v') - col('min_v')) \ - .drop('max_v', 'min_v') + .withColumn('min_v', min(df['v']).over(w)) \ + .withColumn('v_diff', col('max_v') - col('min_v')) \ + .drop('max_v', 'min_v') expected3 = expected1 # Test mixing sql window function and udf result4 = df.withColumn('max_v', max_udf(df['v']).over(w)) \ - .withColumn('rank', rank().over(ow)) + .withColumn('rank', rank().over(ow)) expected4 = df.withColumn('max_v', max(df['v']).over(w)) \ - .withColumn('rank', rank().over(ow)) + .withColumn('rank', rank().over(ow)) self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) @@ -210,8 +245,6 @@ def test_array_type(self): def test_invalid_args(self): df = self.data w = self.unbounded_window - ow = self.ordered_window - mean_udf = self.pandas_agg_mean_udf with QuietTest(self.sc): with self.assertRaisesRegexp( @@ -220,11 +253,101 @@ def test_invalid_args(self): foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP) df.withColumn('v2', foo_udf(df['v']).over(w)) - with QuietTest(self.sc): - with self.assertRaisesRegexp( - AnalysisException, - '.*Only unbounded window frame is supported.*'): - df.withColumn('mean_v', mean_udf(df['v']).over(ow)) + def test_bounded_simple(self): + from pyspark.sql.functions import mean, max, min, count + + df = self.data + w1 = self.sliding_row_window + w2 = self.shrinking_range_window + + plus_one = self.python_plus_one + count_udf = self.pandas_agg_count_udf + mean_udf = self.pandas_agg_mean_udf + max_udf = self.pandas_agg_max_udf + min_udf = self.pandas_agg_min_udf + + result1 = df.withColumn('mean_v', mean_udf(plus_one(df['v'])).over(w1)) \ + .withColumn('count_v', count_udf(df['v']).over(w2)) \ + .withColumn('max_v', max_udf(df['v']).over(w2)) \ + .withColumn('min_v', min_udf(df['v']).over(w1)) + + expected1 = df.withColumn('mean_v', mean(plus_one(df['v'])).over(w1)) \ + .withColumn('count_v', count(df['v']).over(w2)) \ + .withColumn('max_v', max(df['v']).over(w2)) \ + .withColumn('min_v', min(df['v']).over(w1)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_growing_window(self): + from pyspark.sql.functions import mean + + df = self.data + w1 = self.growing_row_window + w2 = self.growing_range_window + + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('m1', mean_udf(df['v']).over(w1)) \ + .withColumn('m2', mean_udf(df['v']).over(w2)) + + expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \ + .withColumn('m2', mean(df['v']).over(w2)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_sliding_window(self): + from pyspark.sql.functions import mean + + df = self.data + w1 = self.sliding_row_window + w2 = self.sliding_range_window + + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('m1', mean_udf(df['v']).over(w1)) \ + .withColumn('m2', mean_udf(df['v']).over(w2)) + + expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \ + .withColumn('m2', mean(df['v']).over(w2)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_shrinking_window(self): + from pyspark.sql.functions import mean + + df = self.data + w1 = self.shrinking_row_window + w2 = self.shrinking_range_window + + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('m1', mean_udf(df['v']).over(w1)) \ + .withColumn('m2', mean_udf(df['v']).over(w2)) + + expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \ + .withColumn('m2', mean(df['v']).over(w2)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_bounded_mixed(self): + from pyspark.sql.functions import mean, max + + df = self.data + w1 = self.sliding_row_window + w2 = self.unbounded_window + + mean_udf = self.pandas_agg_mean_udf + max_udf = self.pandas_agg_max_udf + + result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w1)) \ + .withColumn('max_v', max_udf(df['v']).over(w2)) \ + .withColumn('mean_unbounded_v', mean_udf(df['v']).over(w1)) + + expected1 = df.withColumn('mean_v', mean(df['v']).over(w1)) \ + .withColumn('max_v', max(df['v']).over(w2)) \ + .withColumn('mean_unbounded_v', mean(df['v']).over(w1)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) if __name__ == "__main__": diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 953b468e96519..bf007b0c62d8d 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -145,7 +145,18 @@ def wrapped(*series): return lambda *a: (wrapped(*a), arrow_return_type) -def wrap_window_agg_pandas_udf(f, return_type): +def wrap_window_agg_pandas_udf(f, return_type, runner_conf, udf_index): + window_bound_types_str = runner_conf.get('pandas_window_bound_types') + window_bound_type = [t.strip().lower() for t in window_bound_types_str.split(',')][udf_index] + if window_bound_type == 'bounded': + return wrap_bounded_window_agg_pandas_udf(f, return_type) + elif window_bound_type == 'unbounded': + return wrap_unbounded_window_agg_pandas_udf(f, return_type) + else: + raise RuntimeError("Invalid window bound type: {} ".format(window_bound_type)) + + +def wrap_unbounded_window_agg_pandas_udf(f, return_type): # This is similar to grouped_agg_pandas_udf, the only difference # is that window_agg_pandas_udf needs to repeat the return value # to match window length, where grouped_agg_pandas_udf just returns @@ -160,7 +171,41 @@ def wrapped(*series): return lambda *a: (wrapped(*a), arrow_return_type) -def read_single_udf(pickleSer, infile, eval_type, runner_conf): +def wrap_bounded_window_agg_pandas_udf(f, return_type): + arrow_return_type = to_arrow_type(return_type) + + def wrapped(begin_index, end_index, *series): + import pandas as pd + result = [] + + # Index operation is faster on np.ndarray, + # So we turn the index series into np array + # here for performance + begin_array = begin_index.values + end_array = end_index.values + + for i in range(len(begin_array)): + # Note: Create a slice from a series for each window is + # actually pretty expensive. However, there + # is no easy way to reduce cost here. + # Note: s.iloc[i : j] is about 30% faster than s[i: j], with + # the caveat that the created slices shares the same + # memory with s. Therefore, user are not allowed to + # change the value of input series inside the window + # function. It is rare that user needs to modify the + # input series in the window function, and therefore, + # it is be a reasonable restriction. + # Note: Calling reset_index on the slices will increase the cost + # of creating slices by about 100%. Therefore, for performance + # reasons we don't do it here. + series_slices = [s.iloc[begin_array[i]: end_array[i]] for s in series] + result.append(f(*series_slices)) + return pd.Series(result) + + return lambda *a: (wrapped(*a), arrow_return_type) + + +def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] row_func = None @@ -184,7 +229,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf): elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF: - return arg_offsets, wrap_window_agg_pandas_udf(func, return_type) + return arg_offsets, wrap_window_agg_pandas_udf(func, return_type, runner_conf, udf_index) elif eval_type == PythonEvalType.SQL_BATCHED_UDF: return arg_offsets, wrap_udf(func, return_type) else: @@ -226,7 +271,8 @@ def read_udfs(pickleSer, infile, eval_type): # See FlatMapGroupsInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes - arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf) + arg_offsets, udf = read_single_udf( + pickleSer, infile, eval_type, runner_conf, udf_index=0) udfs['f'] = udf split_offset = arg_offsets[0] + 1 arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]] @@ -238,7 +284,8 @@ def read_udfs(pickleSer, infile, eval_type): # In the special case of a single UDF this will return a single result rather # than a tuple of results; this is the format that the JVM side expects. for i in range(num_udfs): - arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf) + arg_offsets, udf = read_single_udf( + pickleSer, infile, eval_type, runner_conf, udf_index=i) udfs['f%d' % i] = udf args = ["a[%d]" % o for o in arg_offsets] call_udf.append("f%d(%s)" % (i, ", ".join(args))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 6a91d556b2f3e..88d41e8824405 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -134,11 +134,6 @@ trait CheckAnalysis extends PredicateHelper { failAnalysis("An offset window function can only be evaluated in an ordered " + s"row-based window frame with a single offset: $w") - case _ @ WindowExpression(_: PythonUDF, - WindowSpecDefinition(_, _, frame: SpecifiedWindowFrame)) - if !frame.isUnbounded => - failAnalysis("Only unbounded window frame is supported with Pandas UDFs.") - case w @ WindowExpression(e, s) => // Only allow window functions with an aggregate expression or an offset window // function or a Pandas window UDF. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index 82973307feef3..1ce1215bfdd62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -27,17 +27,64 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan} import org.apache.spark.sql.execution.arrow.ArrowUtils -import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.sql.execution.window._ +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +/** + * This class calculates and outputs windowed aggregates over the rows in a single partition. + * + * This is similar to [[WindowExec]]. The main difference is that this node does not compute + * any window aggregation values. Instead, it computes the lower and upper bound for each window + * (i.e. window bounds) and pass the data and indices to Python worker to do the actual window + * aggregation. + * + * It currently materializes all data associated with the same partition key and passes them to + * Python worker. This is not strictly necessary for sliding windows and can be improved (by + * possibly slicing data into overlapping chunks and stitching them together). + * + * This class groups window expressions by their window boundaries so that window expressions + * with the same window boundaries can share the same window bounds. The window bounds are + * prepended to the data passed to the python worker. + * + * For example, if we have: + * avg(v) over specifiedwindowframe(RowFrame, -5, 5), + * avg(v) over specifiedwindowframe(RowFrame, UnboundedPreceding, UnboundedFollowing), + * avg(v) over specifiedwindowframe(RowFrame, -3, 3), + * max(v) over specifiedwindowframe(RowFrame, -3, 3) + * + * The python input will look like: + * (lower_bound_w1, upper_bound_w1, lower_bound_w3, upper_bound_w3, v) + * + * where w1 is specifiedwindowframe(RowFrame, -5, 5) + * w2 is specifiedwindowframe(RowFrame, UnboundedPreceding, UnboundedFollowing) + * w3 is specifiedwindowframe(RowFrame, -3, 3) + * + * Note that w2 doesn't have bound indices in the python input because it's unbounded window + * so it's bound indices will always be the same. + * + * Bounded window and Unbounded window are evaluated differently in Python worker: + * (1) Bounded window takes the window bound indices in addition to the input columns. + * Unbounded window takes only input columns. + * (2) Bounded window evaluates the udf once per input row. + * Unbounded window evaluates the udf once per window partition. + * This is controlled by Python runner conf "pandas_window_bound_types" + * + * The logic to compute window bounds is delegated to [[WindowFunctionFrame]] and shared with + * [[WindowExec]] + * + * Note this doesn't support partial aggregation and all aggregation is computed from the entire + * window. + */ case class WindowInPandasExec( windowExpression: Seq[NamedExpression], partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], - child: SparkPlan) extends UnaryExecNode { + child: SparkPlan) + extends WindowExecBase(windowExpression, partitionSpec, orderSpec, child) { override def output: Seq[Attribute] = child.output ++ windowExpression.map(_.toAttribute) @@ -60,6 +107,26 @@ case class WindowInPandasExec( override def outputPartitioning: Partitioning = child.outputPartitioning + /** + * Helper functions and data structures for window bounds + * + * It contains: + * (1) Total number of window bound indices in the python input row + * (2) Function from frame index to its lower bound column index in the python input row + * (3) Function from frame index to its upper bound column index in the python input row + * (4) Seq from frame index to its window bound type + */ + private type WindowBoundHelpers = (Int, Int => Int, Int => Int, Seq[WindowBoundType]) + + /** + * Enum for window bound types. Used only inside this class. + */ + private sealed case class WindowBoundType(value: String) + private object UnboundedWindow extends WindowBoundType("unbounded") + private object BoundedWindow extends WindowBoundType("bounded") + + private val windowBoundTypeConf = "pandas_window_bound_types" + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { udf.children match { case Seq(u: PythonUDF) => @@ -73,68 +140,150 @@ case class WindowInPandasExec( } /** - * Create the resulting projection. - * - * This method uses Code Generation. It can only be used on the executor side. - * - * @param expressions unbound ordered function expressions. - * @return the final resulting projection. + * See [[WindowBoundHelpers]] for details. */ - private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = { - val references = expressions.zipWithIndex.map { case (e, i) => - // Results of window expressions will be on the right side of child's output - BoundReference(child.output.size + i, e.dataType, e.nullable) + private def computeWindowBoundHelpers( + factories: Seq[InternalRow => WindowFunctionFrame] + ): WindowBoundHelpers = { + val functionFrames = factories.map(_(EmptyRow)) + + val windowBoundTypes = functionFrames.map { + case _: UnboundedWindowFunctionFrame => UnboundedWindow + case _: UnboundedFollowingWindowFunctionFrame | + _: SlidingWindowFunctionFrame | + _: UnboundedPrecedingWindowFunctionFrame => BoundedWindow + // It should be impossible to get other types of window function frame here + case frame => throw new RuntimeException(s"Unexpected window function frame $frame.") } - val unboundToRefMap = expressions.zip(references).toMap - val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) - UnsafeProjection.create( - child.output ++ patchedWindowExpression, - child.output) + + val requiredIndices = functionFrames.map { + case _: UnboundedWindowFunctionFrame => 0 + case _ => 2 + } + + val upperBoundIndices = requiredIndices.scan(0)(_ + _).tail + + val boundIndices = requiredIndices.zip(upperBoundIndices).map { case (num, upperBoundIndex) => + if (num == 0) { + // Sentinel values for unbounded window + (-1, -1) + } else { + (upperBoundIndex - 2, upperBoundIndex - 1) + } + } + + def lowerBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._1 + def upperBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._2 + + (requiredIndices.sum, lowerBoundIndex, upperBoundIndex, windowBoundTypes) } protected override def doExecute(): RDD[InternalRow] = { - val inputRDD = child.execute() + // Unwrap the expressions and factories from the map. + val expressionsWithFrameIndex = + windowFrameExpressionFactoryPairs.map(_._1).zipWithIndex.flatMap { + case (buffer, frameIndex) => buffer.map(expr => (expr, frameIndex)) + } + + val expressions = expressionsWithFrameIndex.map(_._1) + val expressionIndexToFrameIndex = + expressionsWithFrameIndex.map(_._2).zipWithIndex.map(_.swap).toMap + + val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray + // Helper functions + val (numBoundIndices, lowerBoundIndex, upperBoundIndex, frameWindowBoundTypes) = + computeWindowBoundHelpers(factories) + val isBounded = { frameIndex: Int => lowerBoundIndex(frameIndex) >= 0 } + val numFrames = factories.length + + val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold + val spillThreshold = conf.windowExecBufferSpillThreshold val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) // Extract window expressions and window functions - val expressions = windowExpression.flatMap(_.collect { case e: WindowExpression => e }) - - val udfExpressions = expressions.map(_.windowFunction.asInstanceOf[PythonUDF]) + val windowExpressions = expressions.flatMap(_.collect { case e: WindowExpression => e }) + val udfExpressions = windowExpressions.map(_.windowFunction.asInstanceOf[PythonUDF]) + // We shouldn't be chaining anything here. + // All chained python functions should only contain one function. val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip + require(pyFuncs.length == expressions.length) + + val udfWindowBoundTypes = pyFuncs.indices.map(i => + frameWindowBoundTypes(expressionIndexToFrameIndex(i))) + val pythonRunnerConf: Map[String, String] = (ArrowUtils.getPythonRunnerConfMap(conf) + + (windowBoundTypeConf -> udfWindowBoundTypes.map(_.value).mkString(","))) // Filter child output attributes down to only those that are UDF inputs. - // Also eliminate duplicate UDF inputs. - val allInputs = new ArrayBuffer[Expression] - val dataTypes = new ArrayBuffer[DataType] + // Also eliminate duplicate UDF inputs. This is similar to how other Python UDF node + // handles UDF inputs. + val dataInputs = new ArrayBuffer[Expression] + val dataInputTypes = new ArrayBuffer[DataType] val argOffsets = inputs.map { input => input.map { e => - if (allInputs.exists(_.semanticEquals(e))) { - allInputs.indexWhere(_.semanticEquals(e)) + if (dataInputs.exists(_.semanticEquals(e))) { + dataInputs.indexWhere(_.semanticEquals(e)) } else { - allInputs += e - dataTypes += e.dataType - allInputs.length - 1 + dataInputs += e + dataInputTypes += e.dataType + dataInputs.length - 1 } }.toArray }.toArray - // Schema of input rows to the python runner - val windowInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => - StructField(s"_$i", dt) - }) + // In addition to UDF inputs, we will prepend window bounds for each UDFs. + // For bounded windows, we prepend lower bound and upper bound. For unbounded windows, + // we no not add window bounds. (strictly speaking, we only need to lower or upper bound + // if the window is bounded only on one side, this can be improved in the future) - inputRDD.mapPartitionsInternal { iter => - val context = TaskContext.get() + // Setting window bounds for each window frames. Each window frame has different bounds so + // each has its own window bound columns. + val windowBoundsInput = factories.indices.flatMap { frameIndex => + if (isBounded(frameIndex)) { + Seq( + BoundReference(lowerBoundIndex(frameIndex), IntegerType, nullable = false), + BoundReference(upperBoundIndex(frameIndex), IntegerType, nullable = false) + ) + } else { + Seq.empty + } + } - val grouped = if (partitionSpec.isEmpty) { - // Use an empty unsafe row as a place holder for the grouping key - Iterator((new UnsafeRow(), iter)) + // Setting the window bounds argOffset for each UDF. For UDFs with bounded window, argOffset + // for the UDF is (lowerBoundOffet, upperBoundOffset, inputOffset1, inputOffset2, ...) + // For UDFs with unbounded window, argOffset is (inputOffset1, inputOffset2, ...) + pyFuncs.indices.foreach { exprIndex => + val frameIndex = expressionIndexToFrameIndex(exprIndex) + if (isBounded(frameIndex)) { + argOffsets(exprIndex) = + Array(lowerBoundIndex(frameIndex), upperBoundIndex(frameIndex)) ++ + argOffsets(exprIndex).map(_ + windowBoundsInput.length) } else { - GroupedIterator(iter, partitionSpec, child.output) + argOffsets(exprIndex) = argOffsets(exprIndex).map(_ + windowBoundsInput.length) } + } + + val allInputs = windowBoundsInput ++ dataInputs + val allInputTypes = allInputs.map(_.dataType) + + // Start processing. + child.execute().mapPartitions { iter => + val context = TaskContext.get() + + // Get all relevant projections. + val resultProj = createResultProjection(expressions) + val pythonInputProj = UnsafeProjection.create( + allInputs, + windowBoundsInput.map(ref => + AttributeReference(s"i_${ref.ordinal}", ref.dataType)()) ++ child.output + ) + val pythonInputSchema = StructType( + allInputTypes.zipWithIndex.map { case (dt, i) => + StructField(s"_$i", dt) + } + ) + val grouping = UnsafeProjection.create(partitionSpec, child.output) // The queue used to buffer input rows so we can drain it to // combine input with output from Python. @@ -144,11 +293,94 @@ case class WindowInPandasExec( queue.close() } - val inputProj = UnsafeProjection.create(allInputs, child.output) - val pythonInput = grouped.map { case (_, rows) => - rows.map { row => - queue.add(row.asInstanceOf[UnsafeRow]) - inputProj(row) + val stream = iter.map { row => + queue.add(row.asInstanceOf[UnsafeRow]) + row + } + + val pythonInput = new Iterator[Iterator[UnsafeRow]] { + + // Manage the stream and the grouping. + var nextRow: UnsafeRow = null + var nextGroup: UnsafeRow = null + var nextRowAvailable: Boolean = false + private[this] def fetchNextRow() { + nextRowAvailable = stream.hasNext + if (nextRowAvailable) { + nextRow = stream.next().asInstanceOf[UnsafeRow] + nextGroup = grouping(nextRow) + } else { + nextRow = null + nextGroup = null + } + } + fetchNextRow() + + // Manage the current partition. + val buffer: ExternalAppendOnlyUnsafeRowArray = + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) + var bufferIterator: Iterator[UnsafeRow] = _ + + val indexRow = new SpecificInternalRow(Array.fill(numBoundIndices)(IntegerType)) + + val frames = factories.map(_(indexRow)) + + private[this] def fetchNextPartition() { + // Collect all the rows in the current partition. + // Before we start to fetch new input rows, make a copy of nextGroup. + val currentGroup = nextGroup.copy() + + // clear last partition + buffer.clear() + + while (nextRowAvailable && nextGroup == currentGroup) { + buffer.add(nextRow) + fetchNextRow() + } + + // Setup the frames. + var i = 0 + while (i < numFrames) { + frames(i).prepare(buffer) + i += 1 + } + + // Setup iteration + rowIndex = 0 + bufferIterator = buffer.generateIterator() + } + + // Iteration + var rowIndex = 0 + + override final def hasNext: Boolean = + (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable + + override final def next(): Iterator[UnsafeRow] = { + // Load the next partition if we need to. + if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) { + fetchNextPartition() + } + + val join = new JoinedRow + + bufferIterator.zipWithIndex.map { + case (current, index) => + var frameIndex = 0 + while (frameIndex < numFrames) { + frames(frameIndex).write(index, current) + // If the window is unbounded we don't need to write out window bounds. + if (isBounded(frameIndex)) { + indexRow.setInt( + lowerBoundIndex(frameIndex), frames(frameIndex).currentLowerBound()) + indexRow.setInt( + upperBoundIndex(frameIndex), frames(frameIndex).currentUpperBound()) + } + frameIndex += 1 + } + + pythonInputProj(join(indexRow, current)) + } } } @@ -156,12 +388,11 @@ case class WindowInPandasExec( pyFuncs, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, argOffsets, - windowInputSchema, + pythonInputSchema, sessionLocalTimeZone, pythonRunnerConf).compute(pythonInput, context.partitionId(), context) val joined = new JoinedRow - val resultProj = createResultProjection(expressions) windowFunctionResult.flatMap(_.rowIterator.asScala).map { windowOutput => val leftRow = queue.remove() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 729b8bdb3dae8..89f6edda2ef57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -83,7 +83,7 @@ case class WindowExec( partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], child: SparkPlan) - extends UnaryExecNode { + extends WindowExecBase(windowExpression, partitionSpec, orderSpec, child) { override def output: Seq[Attribute] = child.output ++ windowExpression.map(_.toAttribute) @@ -104,193 +104,6 @@ case class WindowExec( override def outputPartitioning: Partitioning = child.outputPartitioning - /** - * Create a bound ordering object for a given frame type and offset. A bound ordering object is - * used to determine which input row lies within the frame boundaries of an output row. - * - * This method uses Code Generation. It can only be used on the executor side. - * - * @param frame to evaluate. This can either be a Row or Range frame. - * @param bound with respect to the row. - * @param timeZone the session local timezone for time related calculations. - * @return a bound ordering object. - */ - private[this] def createBoundOrdering( - frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = { - (frame, bound) match { - case (RowFrame, CurrentRow) => - RowBoundOrdering(0) - - case (RowFrame, IntegerLiteral(offset)) => - RowBoundOrdering(offset) - - case (RangeFrame, CurrentRow) => - val ordering = newOrdering(orderSpec, child.output) - RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection) - - case (RangeFrame, offset: Expression) if orderSpec.size == 1 => - // Use only the first order expression when the offset is non-null. - val sortExpr = orderSpec.head - val expr = sortExpr.child - - // Create the projection which returns the current 'value'. - val current = newMutableProjection(expr :: Nil, child.output) - - // Flip the sign of the offset when processing the order is descending - val boundOffset = sortExpr.direction match { - case Descending => UnaryMinus(offset) - case Ascending => offset - } - - // Create the projection which returns the current 'value' modified by adding the offset. - val boundExpr = (expr.dataType, boundOffset.dataType) match { - case (DateType, IntegerType) => DateAdd(expr, boundOffset) - case (TimestampType, CalendarIntervalType) => - TimeAdd(expr, boundOffset, Some(timeZone)) - case (a, b) if a== b => Add(expr, boundOffset) - } - val bound = newMutableProjection(boundExpr :: Nil, child.output) - - // Construct the ordering. This is used to compare the result of current value projection - // to the result of bound value projection. This is done manually because we want to use - // Code Generation (if it is enabled). - val boundSortExprs = sortExpr.copy(BoundReference(0, expr.dataType, expr.nullable)) :: Nil - val ordering = newOrdering(boundSortExprs, Nil) - RangeBoundOrdering(ordering, current, bound) - - case (RangeFrame, _) => - sys.error("Non-Zero range offsets are not supported for windows " + - "with multiple order expressions.") - } - } - - /** - * Collection containing an entry for each window frame to process. Each entry contains a frame's - * [[WindowExpression]]s and factory function for the WindowFrameFunction. - */ - private[this] lazy val windowFrameExpressionFactoryPairs = { - type FrameKey = (String, FrameType, Expression, Expression) - type ExpressionBuffer = mutable.Buffer[Expression] - val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)] - - // Add a function and its function to the map for a given frame. - def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = { - val key = (tpe, fr.frameType, fr.lower, fr.upper) - val (es, fns) = framedFunctions.getOrElseUpdate( - key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression])) - es += e - fns += fn - } - - // Collect all valid window functions and group them by their frame. - windowExpression.foreach { x => - x.foreach { - case e @ WindowExpression(function, spec) => - val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] - function match { - case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f) - case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) - case f: OffsetWindowFunction => collect("OFFSET", frame, e, f) - case f => sys.error(s"Unsupported window function: $f") - } - case _ => - } - } - - // Map the groups to a (unbound) expression and frame factory pair. - var numExpressions = 0 - val timeZone = conf.sessionLocalTimeZone - framedFunctions.toSeq.map { - case (key, (expressions, functionSeq)) => - val ordinal = numExpressions - val functions = functionSeq.toArray - - // Construct an aggregate processor if we need one. - def processor = AggregateProcessor( - functions, - ordinal, - child.output, - (expressions, schema) => - newMutableProjection(expressions, schema, subexpressionEliminationEnabled)) - - // Create the factory - val factory = key match { - // Offset Frame - case ("OFFSET", _, IntegerLiteral(offset), _) => - target: InternalRow => - new OffsetWindowFunctionFrame( - target, - ordinal, - // OFFSET frame functions are guaranteed be OffsetWindowFunctions. - functions.map(_.asInstanceOf[OffsetWindowFunction]), - child.output, - (expressions, schema) => - newMutableProjection(expressions, schema, subexpressionEliminationEnabled), - offset) - - // Entire Partition Frame. - case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing) => - target: InternalRow => { - new UnboundedWindowFunctionFrame(target, processor) - } - - // Growing Frame. - case ("AGGREGATE", frameType, UnboundedPreceding, upper) => - target: InternalRow => { - new UnboundedPrecedingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, upper, timeZone)) - } - - // Shrinking Frame. - case ("AGGREGATE", frameType, lower, UnboundedFollowing) => - target: InternalRow => { - new UnboundedFollowingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, lower, timeZone)) - } - - // Moving Frame. - case ("AGGREGATE", frameType, lower, upper) => - target: InternalRow => { - new SlidingWindowFunctionFrame( - target, - processor, - createBoundOrdering(frameType, lower, timeZone), - createBoundOrdering(frameType, upper, timeZone)) - } - } - - // Keep track of the number of expressions. This is a side-effect in a map... - numExpressions += expressions.size - - // Create the Frame Expression - Factory pair. - (expressions, factory) - } - } - - /** - * Create the resulting projection. - * - * This method uses Code Generation. It can only be used on the executor side. - * - * @param expressions unbound ordered function expressions. - * @return the final resulting projection. - */ - private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = { - val references = expressions.zipWithIndex.map{ case (e, i) => - // Results of window expressions will be on the right side of child's output - BoundReference(child.output.size + i, e.dataType, e.nullable) - } - val unboundToRefMap = expressions.zip(references).toMap - val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) - UnsafeProjection.create( - child.output ++ patchedWindowExpression, - child.output) - } - protected override def doExecute(): RDD[InternalRow] = { // Unwrap the expressions and factories from the map. val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala new file mode 100644 index 0000000000000..dcb86f48bdf32 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.window + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.types.{CalendarIntervalType, DateType, IntegerType, TimestampType} + +abstract class WindowExecBase( + windowExpression: Seq[NamedExpression], + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + child: SparkPlan) extends UnaryExecNode { + + /** + * Create the resulting projection. + * + * This method uses Code Generation. It can only be used on the executor side. + * + * @param expressions unbound ordered function expressions. + * @return the final resulting projection. + */ + protected def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = { + val references = expressions.zipWithIndex.map { case (e, i) => + // Results of window expressions will be on the right side of child's output + BoundReference(child.output.size + i, e.dataType, e.nullable) + } + val unboundToRefMap = expressions.zip(references).toMap + val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) + UnsafeProjection.create( + child.output ++ patchedWindowExpression, + child.output) + } + + /** + * Create a bound ordering object for a given frame type and offset. A bound ordering object is + * used to determine which input row lies within the frame boundaries of an output row. + * + * This method uses Code Generation. It can only be used on the executor side. + * + * @param frame to evaluate. This can either be a Row or Range frame. + * @param bound with respect to the row. + * @param timeZone the session local timezone for time related calculations. + * @return a bound ordering object. + */ + private def createBoundOrdering( + frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = { + (frame, bound) match { + case (RowFrame, CurrentRow) => + RowBoundOrdering(0) + + case (RowFrame, IntegerLiteral(offset)) => + RowBoundOrdering(offset) + + case (RangeFrame, CurrentRow) => + val ordering = newOrdering(orderSpec, child.output) + RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection) + + case (RangeFrame, offset: Expression) if orderSpec.size == 1 => + // Use only the first order expression when the offset is non-null. + val sortExpr = orderSpec.head + val expr = sortExpr.child + + // Create the projection which returns the current 'value'. + val current = newMutableProjection(expr :: Nil, child.output) + + // Flip the sign of the offset when processing the order is descending + val boundOffset = sortExpr.direction match { + case Descending => UnaryMinus(offset) + case Ascending => offset + } + + // Create the projection which returns the current 'value' modified by adding the offset. + val boundExpr = (expr.dataType, boundOffset.dataType) match { + case (DateType, IntegerType) => DateAdd(expr, boundOffset) + case (TimestampType, CalendarIntervalType) => + TimeAdd(expr, boundOffset, Some(timeZone)) + case (a, b) if a == b => Add(expr, boundOffset) + } + val bound = newMutableProjection(boundExpr :: Nil, child.output) + + // Construct the ordering. This is used to compare the result of current value projection + // to the result of bound value projection. This is done manually because we want to use + // Code Generation (if it is enabled). + val boundSortExprs = sortExpr.copy(BoundReference(0, expr.dataType, expr.nullable)) :: Nil + val ordering = newOrdering(boundSortExprs, Nil) + RangeBoundOrdering(ordering, current, bound) + + case (RangeFrame, _) => + sys.error("Non-Zero range offsets are not supported for windows " + + "with multiple order expressions.") + } + } + + /** + * Collection containing an entry for each window frame to process. Each entry contains a frame's + * [[WindowExpression]]s and factory function for the WindowFrameFunction. + */ + protected lazy val windowFrameExpressionFactoryPairs = { + type FrameKey = (String, FrameType, Expression, Expression) + type ExpressionBuffer = mutable.Buffer[Expression] + val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)] + + // Add a function and its function to the map for a given frame. + def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = { + val key = (tpe, fr.frameType, fr.lower, fr.upper) + val (es, fns) = framedFunctions.getOrElseUpdate( + key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression])) + es += e + fns += fn + } + + // Collect all valid window functions and group them by their frame. + windowExpression.foreach { x => + x.foreach { + case e @ WindowExpression(function, spec) => + val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] + function match { + case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f) + case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) + case f: OffsetWindowFunction => collect("OFFSET", frame, e, f) + case f: PythonUDF => collect("AGGREGATE", frame, e, f) + case f => sys.error(s"Unsupported window function: $f") + } + case _ => + } + } + + // Map the groups to a (unbound) expression and frame factory pair. + var numExpressions = 0 + val timeZone = conf.sessionLocalTimeZone + framedFunctions.toSeq.map { + case (key, (expressions, functionSeq)) => + val ordinal = numExpressions + val functions = functionSeq.toArray + + // Construct an aggregate processor if we need one. + // Currently we don't allow mixing of Pandas UDF and SQL aggregation functions + // in a single Window physical node. Therefore, we can assume no SQL aggregation + // functions if Pandas UDF exists. In the future, we might mix Pandas UDF and SQL + // aggregation function in a single physical node. + def processor = if (functions.exists(_.isInstanceOf[PythonUDF])) { + null + } else { + AggregateProcessor( + functions, + ordinal, + child.output, + (expressions, schema) => + newMutableProjection(expressions, schema, subexpressionEliminationEnabled)) + } + + // Create the factory + val factory = key match { + // Offset Frame + case ("OFFSET", _, IntegerLiteral(offset), _) => + target: InternalRow => + new OffsetWindowFunctionFrame( + target, + ordinal, + // OFFSET frame functions are guaranteed be OffsetWindowFunctions. + functions.map(_.asInstanceOf[OffsetWindowFunction]), + child.output, + (expressions, schema) => + newMutableProjection(expressions, schema, subexpressionEliminationEnabled), + offset) + + // Entire Partition Frame. + case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing) => + target: InternalRow => { + new UnboundedWindowFunctionFrame(target, processor) + } + + // Growing Frame. + case ("AGGREGATE", frameType, UnboundedPreceding, upper) => + target: InternalRow => { + new UnboundedPrecedingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, upper, timeZone)) + } + + // Shrinking Frame. + case ("AGGREGATE", frameType, lower, UnboundedFollowing) => + target: InternalRow => { + new UnboundedFollowingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, lower, timeZone)) + } + + // Moving Frame. + case ("AGGREGATE", frameType, lower, upper) => + target: InternalRow => { + new SlidingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, lower, timeZone), + createBoundOrdering(frameType, upper, timeZone)) + } + } + + // Keep track of the number of expressions. This is a side-effect in a map... + numExpressions += expressions.size + + // Create the Frame Expression - Factory pair. + (expressions, factory) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala index 156002ef58fbe..a5601899ea2de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray * Before use a frame must be prepared by passing it all the rows in the current partition. After * preparation the update method can be called to fill the output rows. */ -private[window] abstract class WindowFunctionFrame { +abstract class WindowFunctionFrame { /** * Prepare the frame for calculating the results for a partition. * @@ -42,6 +42,20 @@ private[window] abstract class WindowFunctionFrame { * Write the current results to the target row. */ def write(index: Int, current: InternalRow): Unit + + /** + * The current lower window bound in the row array (inclusive). + * + * This should be called after the current row is updated via [[write]] + */ + def currentLowerBound(): Int + + /** + * The current row index of the upper window bound in the row array (exclusive) + * + * This should be called after the current row is updated via [[write]] + */ + def currentUpperBound(): Int } object WindowFunctionFrame { @@ -62,7 +76,7 @@ object WindowFunctionFrame { * @param newMutableProjection function used to create the projection. * @param offset by which rows get moved within a partition. */ -private[window] final class OffsetWindowFunctionFrame( +final class OffsetWindowFunctionFrame( target: InternalRow, ordinal: Int, expressions: Array[OffsetWindowFunction], @@ -137,6 +151,10 @@ private[window] final class OffsetWindowFunctionFrame( } inputIndex += 1 } + + override def currentLowerBound(): Int = throw new UnsupportedOperationException() + + override def currentUpperBound(): Int = throw new UnsupportedOperationException() } /** @@ -148,7 +166,7 @@ private[window] final class OffsetWindowFunctionFrame( * @param lbound comparator used to identify the lower bound of an output row. * @param ubound comparator used to identify the upper bound of an output row. */ -private[window] final class SlidingWindowFunctionFrame( +final class SlidingWindowFunctionFrame( target: InternalRow, processor: AggregateProcessor, lbound: BoundOrdering, @@ -170,24 +188,24 @@ private[window] final class SlidingWindowFunctionFrame( private[this] val buffer = new util.ArrayDeque[InternalRow]() /** - * Index of the first input row with a value greater than the upper bound of the current - * output row. + * Index of the first input row with a value equal to or greater than the lower bound of the + * current output row. */ - private[this] var inputHighIndex = 0 + private[this] var lowerBound = 0 /** - * Index of the first input row with a value equal to or greater than the lower bound of the - * current output row. + * Index of the first input row with a value greater than the upper bound of the current + * output row. */ - private[this] var inputLowIndex = 0 + private[this] var upperBound = 0 /** Prepare the frame for calculating a new partition. Reset all variables. */ override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { input = rows inputIterator = input.generateIterator() nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) - inputHighIndex = 0 - inputLowIndex = 0 + lowerBound = 0 + upperBound = 0 buffer.clear() } @@ -197,27 +215,27 @@ private[window] final class SlidingWindowFunctionFrame( // Drop all rows from the buffer for which the input row value is smaller than // the output row lower bound. - while (!buffer.isEmpty && lbound.compare(buffer.peek(), inputLowIndex, current, index) < 0) { + while (!buffer.isEmpty && lbound.compare(buffer.peek(), lowerBound, current, index) < 0) { buffer.remove() - inputLowIndex += 1 + lowerBound += 1 bufferUpdated = true } // Add all rows to the buffer for which the input row value is equal to or less than // the output row upper bound. - while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) { - if (lbound.compare(nextRow, inputLowIndex, current, index) < 0) { - inputLowIndex += 1 + while (nextRow != null && ubound.compare(nextRow, upperBound, current, index) <= 0) { + if (lbound.compare(nextRow, lowerBound, current, index) < 0) { + lowerBound += 1 } else { buffer.add(nextRow.copy()) bufferUpdated = true } nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) - inputHighIndex += 1 + upperBound += 1 } // Only recalculate and update when the buffer changes. - if (bufferUpdated) { + if (processor != null && bufferUpdated) { processor.initialize(input.length) val iter = buffer.iterator() while (iter.hasNext) { @@ -226,6 +244,10 @@ private[window] final class SlidingWindowFunctionFrame( processor.evaluate(target) } } + + override def currentLowerBound(): Int = lowerBound + + override def currentUpperBound(): Int = upperBound } /** @@ -239,27 +261,39 @@ private[window] final class SlidingWindowFunctionFrame( * @param target to write results to. * @param processor to calculate the row values with. */ -private[window] final class UnboundedWindowFunctionFrame( +final class UnboundedWindowFunctionFrame( target: InternalRow, processor: AggregateProcessor) extends WindowFunctionFrame { + val lowerBound: Int = 0 + var upperBound: Int = 0 + /** Prepare the frame for calculating a new partition. Process all rows eagerly. */ override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { - processor.initialize(rows.length) - - val iterator = rows.generateIterator() - while (iterator.hasNext) { - processor.update(iterator.next()) + if (processor != null) { + processor.initialize(rows.length) + val iterator = rows.generateIterator() + while (iterator.hasNext) { + processor.update(iterator.next()) + } } + + upperBound = rows.length } /** Write the frame columns for the current row to the given target row. */ override def write(index: Int, current: InternalRow): Unit = { // Unfortunately we cannot assume that evaluation is deterministic. So we need to re-evaluate // for each row. - processor.evaluate(target) + if (processor != null) { + processor.evaluate(target) + } } + + override def currentLowerBound(): Int = lowerBound + + override def currentUpperBound(): Int = upperBound } /** @@ -276,7 +310,7 @@ private[window] final class UnboundedWindowFunctionFrame( * @param processor to calculate the row values with. * @param ubound comparator used to identify the upper bound of an output row. */ -private[window] final class UnboundedPrecedingWindowFunctionFrame( +final class UnboundedPrecedingWindowFunctionFrame( target: InternalRow, processor: AggregateProcessor, ubound: BoundOrdering) @@ -308,7 +342,9 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame( nextRow = inputIterator.next() } - processor.initialize(input.length) + if (processor != null) { + processor.initialize(input.length) + } } /** Write the frame columns for the current row to the given target row. */ @@ -318,17 +354,23 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame( // Add all rows to the aggregates for which the input row value is equal to or less than // the output row upper bound. while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) { - processor.update(nextRow) + if (processor != null) { + processor.update(nextRow) + } nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) inputIndex += 1 bufferUpdated = true } // Only recalculate and update when the buffer changes. - if (bufferUpdated) { + if (processor != null && bufferUpdated) { processor.evaluate(target) } } + + override def currentLowerBound(): Int = 0 + + override def currentUpperBound(): Int = inputIndex } /** @@ -347,7 +389,7 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame( * @param processor to calculate the row values with. * @param lbound comparator used to identify the lower bound of an output row. */ -private[window] final class UnboundedFollowingWindowFunctionFrame( +final class UnboundedFollowingWindowFunctionFrame( target: InternalRow, processor: AggregateProcessor, lbound: BoundOrdering) @@ -384,7 +426,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame( } // Only recalculate and update when the buffer changes. - if (bufferUpdated) { + if (processor != null && bufferUpdated) { processor.initialize(input.length) if (nextRow != null) { processor.update(nextRow) @@ -395,4 +437,8 @@ private[window] final class UnboundedFollowingWindowFunctionFrame( processor.evaluate(target) } } + + override def currentLowerBound(): Int = inputIndex + + override def currentUpperBound(): Int = input.length } From d72571e51d8b41e2287750759e120547afeeb7d7 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 18 Dec 2018 13:50:55 +0800 Subject: [PATCH 2304/2461] [SPARK-26246][SQL] Inferring TimestampType from JSON ## What changes were proposed in this pull request? The `JsonInferSchema` class is extended to support `TimestampType` inferring from string fields in JSON input: - If the `prefersDecimal` option is set to `true`, it tries to infer decimal type from the string field. - If decimal type inference fails or `prefersDecimal` is disabled, `JsonInferSchema` tries to infer `TimestampType`. - If timestamp type inference fails, `StringType` is returned as the inferred type. ## How was this patch tested? Added new test suite - `JsonInferSchemaSuite` to check date and timestamp types inferring from JSON using `JsonInferSchema` directly. A few tests were added `JsonSuite` to check type merging and roundtrip tests. This changes was tested by `JsonSuite`, `JsonExpressionsSuite` and `JsonFunctionsSuite` as well. Closes #23201 from MaxGekk/json-infer-time. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: Hyukjin Kwon --- .../sql/catalyst/json/JsonInferSchema.scala | 22 +++- .../catalyst/json/JsonInferSchemaSuite.scala | 102 ++++++++++++++++++ .../datasources/json/JsonSuite.scala | 52 +++++++++ 3 files changed, 171 insertions(+), 5 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index 263e05de32075..d1bc00c08c1c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -28,7 +28,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil -import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -37,6 +37,12 @@ private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable { private val decimalParser = ExprUtils.getDecimalParser(options.locale) + @transient + private lazy val timestampFormatter = TimestampFormatter( + options.timestampFormat, + options.timeZone, + options.locale) + /** * Infer the type of a collection of json records in three stages: * 1. Infer the type of each record @@ -115,13 +121,19 @@ private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable { // record fields' types have been combined. NullType - case VALUE_STRING if options.prefersDecimal => + case VALUE_STRING => + val field = parser.getText val decimalTry = allCatch opt { - val bigDecimal = decimalParser(parser.getText) + val bigDecimal = decimalParser(field) DecimalType(bigDecimal.precision, bigDecimal.scale) } - decimalTry.getOrElse(StringType) - case VALUE_STRING => StringType + if (options.prefersDecimal && decimalTry.isDefined) { + decimalTry.get + } else if ((allCatch opt timestampFormatter.parse(field)).isDefined) { + TimestampType + } else { + StringType + } case START_OBJECT => val builder = Array.newBuilder[StructField] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala new file mode 100644 index 0000000000000..9307f9b47b807 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.json + +import com.fasterxml.jackson.core.JsonFactory + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +class JsonInferSchemaSuite extends SparkFunSuite with SQLHelper { + + def checkType(options: Map[String, String], json: String, dt: DataType): Unit = { + val jsonOptions = new JSONOptions(options, "UTC", "") + val inferSchema = new JsonInferSchema(jsonOptions) + val factory = new JsonFactory() + jsonOptions.setJacksonOptions(factory) + val parser = CreateJacksonParser.string(factory, json) + parser.nextToken() + val expectedType = StructType(Seq(StructField("a", dt, true))) + + assert(inferSchema.inferField(parser) === expectedType) + } + + def checkTimestampType(pattern: String, json: String): Unit = { + checkType(Map("timestampFormat" -> pattern), json, TimestampType) + } + + test("inferring timestamp type") { + Seq(true, false).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + checkTimestampType("yyyy", """{"a": "2018"}""") + checkTimestampType("yyyy=MM", """{"a": "2018=12"}""") + checkTimestampType("yyyy MM dd", """{"a": "2018 12 02"}""") + checkTimestampType( + "yyyy-MM-dd'T'HH:mm:ss.SSS", + """{"a": "2018-12-02T21:04:00.123"}""") + checkTimestampType( + "yyyy-MM-dd'T'HH:mm:ss.SSSSSSXXX", + """{"a": "2018-12-02T21:04:00.123567+01:00"}""") + } + } + } + + test("prefer decimals over timestamps") { + Seq(true, false).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + checkType( + options = Map( + "prefersDecimal" -> "true", + "timestampFormat" -> "yyyyMMdd.HHmmssSSS" + ), + json = """{"a": "20181202.210400123"}""", + dt = DecimalType(17, 9) + ) + } + } + } + + test("skip decimal type inferring") { + Seq(true, false).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + checkType( + options = Map( + "prefersDecimal" -> "false", + "timestampFormat" -> "yyyyMMdd.HHmmssSSS" + ), + json = """{"a": "20181202.210400123"}""", + dt = TimestampType + ) + } + } + } + + test("fallback to string type") { + Seq(true, false).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + checkType( + options = Map("timestampFormat" -> "yyyy,MM,dd.HHmmssSSS"), + json = """{"a": "20181202.210400123"}""", + dt = StringType + ) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 786335b42e3cb..8f575a371c98e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.StructType.fromDDL import org.apache.spark.util.Utils class TestFileFilter extends PathFilter { @@ -2589,4 +2590,55 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(null, Array(0, 1, 2), "abc", """{"a":"-","b":[0, 1, 2],"c":"abc"}""") :: Row(0.1, null, "def", """{"a":0.1,"b":{},"c":"def"}""") :: Nil) } + + test("inferring timestamp type") { + Seq(true, false).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + def schemaOf(jsons: String*): StructType = spark.read.json(jsons.toDS).schema + + assert(schemaOf( + """{"a":"2018-12-17T10:11:12.123-01:00"}""", + """{"a":"2018-12-16T22:23:24.123-02:00"}""") === fromDDL("a timestamp")) + + assert(schemaOf("""{"a":"2018-12-17T10:11:12.123-01:00"}""", """{"a":1}""") + === fromDDL("a string")) + assert(schemaOf("""{"a":"2018-12-17T10:11:12.123-01:00"}""", """{"a":"123"}""") + === fromDDL("a string")) + + assert(schemaOf("""{"a":"2018-12-17T10:11:12.123-01:00"}""", """{"a":null}""") + === fromDDL("a timestamp")) + assert(schemaOf("""{"a":null}""", """{"a":"2018-12-17T10:11:12.123-01:00"}""") + === fromDDL("a timestamp")) + } + } + } + + test("roundtrip for timestamp type inferring") { + Seq(true, false).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + val customSchema = new StructType().add("date", TimestampType) + withTempDir { dir => + val timestampsWithFormatPath = s"${dir.getCanonicalPath}/timestampsWithFormat.json" + val timestampsWithFormat = spark.read + .option("timestampFormat", "dd/MM/yyyy HH:mm") + .json(datesRecords) + assert(timestampsWithFormat.schema === customSchema) + + timestampsWithFormat.write + .format("json") + .option("timestampFormat", "yyyy-MM-dd HH:mm:ss") + .option(DateTimeUtils.TIMEZONE_OPTION, "UTC") + .save(timestampsWithFormatPath) + + val readBack = spark.read + .option("timestampFormat", "yyyy-MM-dd HH:mm:ss") + .option(DateTimeUtils.TIMEZONE_OPTION, "UTC") + .json(timestampsWithFormatPath) + + assert(readBack.schema === customSchema) + checkAnswer(readBack, timestampsWithFormat) + } + } + } + } } From 218341c5db62bf5363c4a16440fa742970f1e919 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 18 Dec 2018 20:52:02 +0800 Subject: [PATCH 2305/2461] [SPARK-26081][SQL][FOLLOW-UP] Use foreach instead of misuse of map (for Unit) ## What changes were proposed in this pull request? This PR proposes to use foreach instead of misuse of map (for Unit). This could cause some weird errors potentially and it's not a good practice anyway. See also SPARK-16694 ## How was this patch tested? N/A Closes #23341 from HyukjinKwon/followup-SPARK-26081. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../spark/sql/execution/datasources/csv/CSVFileFormat.scala | 2 +- .../spark/sql/execution/datasources/json/JsonFileFormat.scala | 2 +- .../spark/sql/execution/datasources/text/TextFileFormat.scala | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index f7d8a9e1042d5..f4f139d180058 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -189,5 +189,5 @@ private[csv] class CsvOutputWriter( gen.write(row) } - override def close(): Unit = univocityGenerator.map(_.close()) + override def close(): Unit = univocityGenerator.foreach(_.close()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 3042133ee43aa..40f55e7068010 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -190,5 +190,5 @@ private[json] class JsonOutputWriter( gen.writeLineEnding() } - override def close(): Unit = jacksonGenerator.map(_.close()) + override def close(): Unit = jacksonGenerator.foreach(_.close()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index 01948ab25d63c..0607f7b3c0d4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -153,7 +153,7 @@ class TextOutputWriter( private var outputStream: Option[OutputStream] = None override def write(row: InternalRow): Unit = { - val os = outputStream.getOrElse{ + val os = outputStream.getOrElse { val newStream = CodecStreams.createOutputStream(context, new Path(path)) outputStream = Some(newStream) newStream @@ -167,6 +167,6 @@ class TextOutputWriter( } override def close(): Unit = { - outputStream.map(_.close()) + outputStream.foreach(_.close()) } } From 4d693ac904d89b3afeba107eb0480120daf78174 Mon Sep 17 00:00:00 2001 From: Stan Zhai Date: Tue, 18 Dec 2018 07:02:09 -0600 Subject: [PATCH 2306/2461] [SPARK-24680][DEPLOY] Support spark.executorEnv.JAVA_HOME in Standalone mode ## What changes were proposed in this pull request? spark.executorEnv.JAVA_HOME does not take effect when a Worker starting an Executor process in Standalone mode. This PR fixed this. ## How was this patch tested? Manual tests. Closes #21663 from stanzhai/fix-executor-env-java-home. Lead-authored-by: Stan Zhai Co-authored-by: Stan Zhai Signed-off-by: Sean Owen --- .../spark/launcher/AbstractCommandBuilder.java | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index ce24400f557cd..56edceb17bfb8 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -91,14 +91,18 @@ abstract List buildCommand(Map env) */ List buildJavaCommand(String extraClassPath) throws IOException { List cmd = new ArrayList<>(); - String envJavaHome; - if (javaHome != null) { - cmd.add(join(File.separator, javaHome, "bin", "java")); - } else if ((envJavaHome = System.getenv("JAVA_HOME")) != null) { - cmd.add(join(File.separator, envJavaHome, "bin", "java")); - } else { - cmd.add(join(File.separator, System.getProperty("java.home"), "bin", "java")); + String[] candidateJavaHomes = new String[] { + javaHome, + childEnv.get("JAVA_HOME"), + System.getenv("JAVA_HOME"), + System.getProperty("java.home") + }; + for (String javaHome : candidateJavaHomes) { + if (javaHome != null) { + cmd.add(join(File.separator, javaHome, "bin", "java")); + break; + } } // Load extra JAVA_OPTS from conf/java-opts, if it exists. From 3c0bb6bc45e64fd82052d7857f2a06c34f0c1793 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 19 Dec 2018 00:01:53 +0800 Subject: [PATCH 2307/2461] [SPARK-26384][SQL] Propagate SQL configs for CSV schema inferring ## What changes were proposed in this pull request? Currently, SQL configs are not propagated to executors while schema inferring in CSV datasource. For example, changing of `spark.sql.legacy.timeParser.enabled` does not impact on inferring timestamp types. In the PR, I propose to fix the issue by wrapping schema inferring action using `SQLExecution.withSQLConfPropagated`. ## How was this patch tested? Added logging to `TimestampFormatter`: ```patch -object TimestampFormatter { +object TimestampFormatter extends Logging { def apply(format: String, timeZone: TimeZone, locale: Locale): TimestampFormatter = { if (SQLConf.get.legacyTimeParserEnabled) { + logError("LegacyFallbackTimestampFormatter is being used") new LegacyFallbackTimestampFormatter(format, timeZone, locale) } else { + logError("Iso8601TimestampFormatter is being used") new Iso8601TimestampFormatter(format, timeZone, locale) } } ``` and run the command in `spark-shell`: ```shell $ ./bin/spark-shell --conf spark.sql.legacy.timeParser.enabled=true ``` ```scala scala> Seq("2010|10|10").toDF.repartition(1).write.mode("overwrite").text("/tmp/foo") scala> spark.read.option("inferSchema", "true").option("header", "false").option("timestampFormat", "yyyy|MM|dd").csv("/tmp/foo").printSchema() 18/12/18 10:47:27 ERROR TimestampFormatter: LegacyFallbackTimestampFormatter is being used root |-- _c0: timestamp (nullable = true) ``` Closes #23345 from MaxGekk/csv-schema-infer-propagate-configs. Authored-by: Maxim Gekk Signed-off-by: Hyukjin Kwon --- .../sql/execution/datasources/csv/CSVDataSource.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index b46dfb94c133e..375cec597166c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -35,6 +35,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVInferSchema, CSVOptions, UnivocityParser} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -135,7 +136,9 @@ object TextInputCSVDataSource extends CSVDataSource { val parser = new CsvParser(parsedOptions.asParserSettings) linesWithoutHeader.map(parser.parseLine) } - new CSVInferSchema(parsedOptions).infer(tokenRDD, header) + SQLExecution.withSQLConfPropagated(csv.sparkSession) { + new CSVInferSchema(parsedOptions).infer(tokenRDD, header) + } case _ => // If the first line could not be read, just return the empty schema. StructType(Nil) @@ -208,7 +211,9 @@ object MultiLineCSVDataSource extends CSVDataSource { encoding = parsedOptions.charset) } val sampled = CSVUtils.sample(tokenRDD, parsedOptions) - new CSVInferSchema(parsedOptions).infer(sampled, header) + SQLExecution.withSQLConfPropagated(sparkSession) { + new CSVInferSchema(parsedOptions).infer(sampled, header) + } case None => // If the first row could not be read, just return the empty schema. StructType(Nil) From befca983d2da4f7828aa7a7cd7345d17c4f291dd Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 18 Dec 2018 10:09:56 -0800 Subject: [PATCH 2308/2461] [SPARK-26382][CORE] prefix comparator should handle -0.0 ## What changes were proposed in this pull request? This is kind of a followup of https://github.com/apache/spark/pull/23239 The `UnsafeProject` will normalize special float/double values(NaN and -0.0), so the sorter doesn't have to handle it. However, for consistency and future-proof, this PR proposes to normalize `-0.0` in the prefix comparator, so that it's same with the normal ordering. Note that prefix comparator handles NaN as well. This is not a bug fix, but a safe guard. ## How was this patch tested? existing tests Closes #23334 from cloud-fan/sort. Authored-by: Wenchen Fan Signed-off-by: Dongjoon Hyun --- .../unsafe/sort/PrefixComparators.java | 2 ++ .../unsafe/sort/PrefixComparatorsSuite.scala | 17 +++++++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index 0910db22af004..bef1bdadb27aa 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -69,6 +69,8 @@ public static final class DoublePrefixComparator { * details see http://stereopsis.com/radix.html. */ public static long computePrefix(double value) { + // normalize -0.0 to 0.0, as they should be equal + value = value == -0.0 ? 0.0 : value; // Java's doubleToLongBits already canonicalizes all NaN values to the smallest possible // positive NaN, so there's nothing special we need to do for NaNs. long bits = Double.doubleToLongBits(value); diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index 73546ef1b7a60..38cb37c524594 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -125,6 +125,7 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { val nan2Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan2) assert(nan1Prefix === nan2Prefix) val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue) + // NaN is greater than the max double value. assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1) } @@ -134,22 +135,34 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { assert(java.lang.Double.doubleToRawLongBits(negativeNan) < 0) val prefix = PrefixComparators.DoublePrefixComparator.computePrefix(negativeNan) val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue) + // -NaN is greater than the max double value. assert(PrefixComparators.DOUBLE.compare(prefix, doubleMaxPrefix) === 1) } test("double prefix comparator handles other special values properly") { - val nullValue = 0L + // See `SortPrefix.nullValue` for how we deal with nulls for float/double type + val smallestNullPrefix = 0L + val largestNullPrefix = -1L val nan = PrefixComparators.DoublePrefixComparator.computePrefix(Double.NaN) val posInf = PrefixComparators.DoublePrefixComparator.computePrefix(Double.PositiveInfinity) val negInf = PrefixComparators.DoublePrefixComparator.computePrefix(Double.NegativeInfinity) val minValue = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MinValue) val maxValue = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue) val zero = PrefixComparators.DoublePrefixComparator.computePrefix(0.0) + val minusZero = PrefixComparators.DoublePrefixComparator.computePrefix(-0.0) + + // null is greater than everything including NaN, when we need to treat it as the largest value. + assert(PrefixComparators.DOUBLE.compare(largestNullPrefix, nan) === 1) + // NaN is greater than the positive infinity. assert(PrefixComparators.DOUBLE.compare(nan, posInf) === 1) assert(PrefixComparators.DOUBLE.compare(posInf, maxValue) === 1) assert(PrefixComparators.DOUBLE.compare(maxValue, zero) === 1) assert(PrefixComparators.DOUBLE.compare(zero, minValue) === 1) assert(PrefixComparators.DOUBLE.compare(minValue, negInf) === 1) - assert(PrefixComparators.DOUBLE.compare(negInf, nullValue) === 1) + // null is smaller than everything including negative infinity, when we need to treat it as + // the smallest value. + assert(PrefixComparators.DOUBLE.compare(negInf, smallestNullPrefix) === 1) + // 0.0 should be equal to -0.0. + assert(PrefixComparators.DOUBLE.compare(zero, minusZero) === 0) } } From 428eb2ad0ad8a141427120b13de3287962258c2d Mon Sep 17 00:00:00 2001 From: Jackey Lee Date: Tue, 18 Dec 2018 12:15:36 -0600 Subject: [PATCH 2309/2461] [SPARK-26394][CORE] Fix annotation error for Utils.timeStringAsMs ## What changes were proposed in this pull request? Change microseconds to milliseconds in annotation of Utils.timeStringAsMs. Closes #23346 from stczwd/stczwd. Authored-by: Jackey Lee Signed-off-by: Sean Owen --- core/src/main/scala/org/apache/spark/util/Utils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index b4ea1ee950217..143abd3bbea8e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1037,7 +1037,7 @@ private[spark] object Utils extends Logging { } /** - * Convert a time parameter such as (50s, 100ms, or 250us) to microseconds for internal use. If + * Convert a time parameter such as (50s, 100ms, or 250us) to milliseconds for internal use. If * no suffix is provided, the passed number is assumed to be in ms. */ def timeStringAsMs(str: String): Long = { From 4b3fe3a9ccc8a4a8eb0d037d19cb07a8a288e37a Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 18 Dec 2018 13:30:09 -0800 Subject: [PATCH 2310/2461] [SPARK-25815][K8S] Support kerberos in client mode, keytab-based token renewal. This change hooks up the k8s backed to the updated HadoopDelegationTokenManager, so that delegation tokens are also available in client mode, and keytab-based token renewal is enabled. The change re-works the k8s feature steps related to kerberos so that the driver does all the credential management and provides all the needed information to executors - so nothing needs to be added to executor pods. This also makes cluster mode behave a lot more similarly to client mode, since no driver-related config steps are run in the latter case. The main two things that don't need to happen in executors anymore are: - adding the Hadoop config to the executor pods: this is not needed since the Spark driver will serialize the Hadoop config and send it to executors when running tasks. - mounting the kerberos config file in the executor pods: this is not needed once you remove the above. The Hadoop conf sent by the driver with the tasks is already resolved (i.e. has all the kerberos names properly defined), so executors do not need access to the kerberos realm information anymore. The change also avoids creating delegation tokens unnecessarily. This means that they'll only be created if a secret with tokens was not provided, and if a keytab is not provided. In either of those cases, the driver code will handle delegation tokens: in cluster mode by creating a secret and stashing them, in client mode by using existing mechanisms to send DTs to executors. One last feature: the change also allows defining a keytab with a "local:" URI. This is supported in client mode (although that's the same as not saying "local:"), and in k8s cluster mode. This allows the keytab to be mounted onto the image from a pre-existing secret, for example. Finally, the new code always sets SPARK_USER in the driver and executor pods. This is in line with how other resource managers behave: the submitting user reflects which user will access Hadoop services in the app. (With kerberos, that's overridden by the logged in user.) That user is unrelated to the OS user the app is running as inside the containers. Tested: - client and cluster mode with kinit - cluster mode with keytab - cluster mode with local: keytab - YARN cluster with keytab (to make sure it isn't broken) Closes #22911 from vanzin/SPARK-25815. Authored-by: Marcelo Vanzin Signed-off-by: Marcelo Vanzin --- .../org/apache/spark/deploy/SparkSubmit.scala | 29 +- .../HadoopDelegationTokenManager.scala | 8 +- .../scala/org/apache/spark/util/Utils.scala | 8 + .../apache/spark/deploy/k8s/Constants.scala | 9 +- .../spark/deploy/k8s/KubernetesConf.scala | 4 - .../apache/spark/deploy/k8s/SparkPod.scala | 25 +- .../k8s/features/BasicDriverFeatureStep.scala | 4 + .../features/BasicExecutorFeatureStep.scala | 4 + .../HadoopConfDriverFeatureStep.scala | 124 +++++++ .../HadoopConfExecutorFeatureStep.scala | 40 --- .../HadoopSparkUserExecutorFeatureStep.scala | 35 -- .../KerberosConfDriverFeatureStep.scala | 315 ++++++++++-------- .../KerberosConfExecutorFeatureStep.scala | 46 --- .../hadooputils/HadoopBootstrapUtil.scala | 283 ---------------- .../hadooputils/KerberosConfigSpec.scala | 33 -- .../k8s/submit/KubernetesDriverBuilder.scala | 1 + .../KubernetesClusterSchedulerBackend.scala | 7 +- .../k8s/KubernetesExecutorBuilder.scala | 5 +- .../BasicDriverFeatureStepSuite.scala | 3 +- .../BasicExecutorFeatureStepSuite.scala | 9 +- .../HadoopConfDriverFeatureStepSuite.scala | 71 ++++ .../KerberosConfDriverFeatureStepSuite.scala | 171 ++++++++++ .../KubernetesFeaturesTestUtils.scala | 6 + .../org/apache/spark/deploy/yarn/Client.scala | 24 +- .../spark/deploy/yarn/ClientSuite.scala | 6 +- 25 files changed, 649 insertions(+), 621 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/HadoopBootstrapUtil.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/KerberosConfigSpec.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStepSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index d4055cb6c5853..763bd0a70a035 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy import java.io._ import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException} -import java.net.URL +import java.net.{URI, URL} import java.security.PrivilegedExceptionAction import java.text.ParseException import java.util.UUID @@ -334,19 +334,20 @@ private[spark] class SparkSubmit extends Logging { val hadoopConf = conf.getOrElse(SparkHadoopUtil.newConfiguration(sparkConf)) val targetDir = Utils.createTempDir() - // assure a keytab is available from any place in a JVM - if (clusterManager == YARN || clusterManager == LOCAL || isMesosClient || isKubernetesCluster) { - if (args.principal != null) { - if (args.keytab != null) { - require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist") - // Add keytab and principal configurations in sysProps to make them available - // for later use; e.g. in spark sql, the isolated class loader used to talk - // to HiveMetastore will use these settings. They will be set as Java system - // properties and then loaded by SparkConf - sparkConf.set(KEYTAB, args.keytab) - sparkConf.set(PRINCIPAL, args.principal) - UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) - } + // Kerberos is not supported in standalone mode, and keytab support is not yet available + // in Mesos cluster mode. + if (clusterManager != STANDALONE + && !isMesosCluster + && args.principal != null + && args.keytab != null) { + // If client mode, make sure the keytab is just a local path. + if (deployMode == CLIENT && Utils.isLocalUri(args.keytab)) { + args.keytab = new URI(args.keytab).getPath() + } + + if (!Utils.isLocalUri(args.keytab)) { + require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist") + UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index 126a6ab801369..f7e3ddecee093 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.security import java.io.File +import java.net.URI import java.security.PrivilegedExceptionAction import java.util.concurrent.{ScheduledExecutorService, TimeUnit} import java.util.concurrent.atomic.AtomicReference @@ -71,11 +72,13 @@ private[spark] class HadoopDelegationTokenManager( private val providerEnabledConfig = "spark.security.credentials.%s.enabled" private val principal = sparkConf.get(PRINCIPAL).orNull - private val keytab = sparkConf.get(KEYTAB).orNull + + // The keytab can be a local: URI for cluster mode, so translate it to a regular path. If it is + // needed later on, the code will check that it exists. + private val keytab = sparkConf.get(KEYTAB).map { uri => new URI(uri).getPath() }.orNull require((principal == null) == (keytab == null), "Both principal and keytab must be defined, or neither.") - require(keytab == null || new File(keytab).isFile(), s"Cannot find keytab at $keytab.") private val delegationTokenProviders = loadProviders() logDebug("Using the following builtin delegation token providers: " + @@ -264,6 +267,7 @@ private[spark] class HadoopDelegationTokenManager( private def doLogin(): UserGroupInformation = { logInfo(s"Attempting to login to KDC using principal: $principal") + require(new File(keytab).isFile(), s"Cannot find keytab at $keytab.") val ugi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) logInfo("Successfully logged into KDC.") ugi diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 143abd3bbea8e..f322e92c6c8cb 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -92,6 +92,9 @@ private[spark] object Utils extends Logging { private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @volatile private var localRootDirs: Array[String] = null + /** Scheme used for files that are locally available on worker nodes in the cluster. */ + val LOCAL_SCHEME = "local" + /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -2829,6 +2832,11 @@ private[spark] object Utils extends Logging { def isClientMode(conf: SparkConf): Boolean = { "client".equals(conf.get(SparkLauncher.DEPLOY_MODE, "client")) } + + /** Returns whether the URI is a "local:" URI. */ + def isLocalUri(uri: String): Boolean = { + uri.startsWith(s"$LOCAL_SCHEME:") + } } private[util] object CallerContext extends Logging { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 85917b88e912a..76041e7de5182 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -87,25 +87,22 @@ private[spark] object Constants { val NON_JVM_MEMORY_OVERHEAD_FACTOR = 0.4d // Hadoop Configuration - val HADOOP_FILE_VOLUME = "hadoop-properties" + val HADOOP_CONF_VOLUME = "hadoop-properties" val KRB_FILE_VOLUME = "krb5-file" val HADOOP_CONF_DIR_PATH = "/opt/hadoop/conf" val KRB_FILE_DIR_PATH = "/etc" val ENV_HADOOP_CONF_DIR = "HADOOP_CONF_DIR" val HADOOP_CONFIG_MAP_NAME = "spark.kubernetes.executor.hadoopConfigMapName" - val KRB5_CONFIG_MAP_NAME = - "spark.kubernetes.executor.krb5ConfigMapName" // Kerberos Configuration - val KERBEROS_DELEGEGATION_TOKEN_SECRET_NAME = "delegation-tokens" val KERBEROS_DT_SECRET_NAME = "spark.kubernetes.kerberos.dt-secret-name" val KERBEROS_DT_SECRET_KEY = "spark.kubernetes.kerberos.dt-secret-key" - val KERBEROS_SPARK_USER_NAME = - "spark.kubernetes.kerberos.spark-user-name" val KERBEROS_SECRET_KEY = "hadoop-tokens" + val KERBEROS_KEYTAB_VOLUME = "kerberos-keytab" + val KERBEROS_KEYTAB_MOUNT_POINT = "/mnt/secrets/kerberos-keytab" // Hadoop credentials secrets for the Spark app. val SPARK_APP_HADOOP_CREDENTIALS_BASE_DIR = "/mnt/secrets/hadoop-credentials" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index a06c21b47f15e..6febad981af56 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -42,10 +42,6 @@ private[spark] abstract class KubernetesConf(val sparkConf: SparkConf) { def appName: String = get("spark.app.name", "spark") - def hadoopConfigMapName: String = s"$resourceNamePrefix-hadoop-config" - - def krbConfigMapName: String = s"$resourceNamePrefix-krb5-file" - def namespace: String = get(KUBERNETES_NAMESPACE) def imagePullPolicy: String = get(CONTAINER_IMAGE_PULL_POLICY) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala index 345dd117fd35f..fd1196368a7ff 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala @@ -18,7 +18,30 @@ package org.apache.spark.deploy.k8s import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBuilder} -private[spark] case class SparkPod(pod: Pod, container: Container) +private[spark] case class SparkPod(pod: Pod, container: Container) { + + /** + * Convenience method to apply a series of chained transformations to a pod. + * + * Use it like: + * + * original.modify { case pod => + * // update pod and return new one + * }.modify { case pod => + * // more changes that create a new pod + * }.modify { + * case pod if someCondition => // new pod + * } + * + * This makes it cleaner to apply multiple transformations, avoiding having to create + * a bunch of awkwardly-named local variables. Since the argument is a partial function, + * it can do matching without needing to exhaust all the possibilities. If the function + * is not applied, then the original pod will be kept. + */ + def transform(fn: PartialFunction[SparkPod, SparkPod]): SparkPod = fn.lift(this).getOrElse(this) + +} + private[spark] object SparkPod { def initialPod(): SparkPod = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala index d8cf3653d3226..8362c14fb289d 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala @@ -110,6 +110,10 @@ private[spark] class BasicDriverFeatureStep(conf: KubernetesDriverConf) .withContainerPort(driverUIPort) .withProtocol("TCP") .endPort() + .addNewEnv() + .withName(ENV_SPARK_USER) + .withValue(Utils.getCurrentUserName()) + .endEnv() .addAllToEnv(driverCustomEnvs.asJava) .addNewEnv() .withName(ENV_DRIVER_BIND_ADDRESS) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index 4bcf4c9446aa3..c8bf7cdb4224f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -163,6 +163,10 @@ private[spark] class BasicExecutorFeatureStep( .addToLimits("memory", executorMemoryQuantity) .addToRequests("cpu", executorCpuQuantity) .endResources() + .addNewEnv() + .withName(ENV_SPARK_USER) + .withValue(Utils.getCurrentUserName()) + .endEnv() .addAllToEnv(executorEnv.asJava) .withPorts(requiredPorts.asJava) .addToArgs("executor") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStep.scala new file mode 100644 index 0000000000000..d602ed5481e65 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStep.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import java.io.File +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ + +import com.google.common.io.Files +import io.fabric8.kubernetes.api.model._ + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +/** + * Mounts the Hadoop configuration - either a pre-defined config map, or a local configuration + * directory - on the driver pod. + */ +private[spark] class HadoopConfDriverFeatureStep(conf: KubernetesConf) + extends KubernetesFeatureConfigStep { + + private val confDir = Option(conf.sparkConf.getenv(ENV_HADOOP_CONF_DIR)) + private val existingConfMap = conf.get(KUBERNETES_HADOOP_CONF_CONFIG_MAP) + + KubernetesUtils.requireNandDefined( + confDir, + existingConfMap, + "Do not specify both the `HADOOP_CONF_DIR` in your ENV and the ConfigMap " + + "as the creation of an additional ConfigMap, when one is already specified is extraneous") + + private lazy val confFiles: Seq[File] = { + val dir = new File(confDir.get) + if (dir.isDirectory) { + dir.listFiles.filter(_.isFile).toSeq + } else { + Nil + } + } + + private def newConfigMapName: String = s"${conf.resourceNamePrefix}-hadoop-config" + + private def hasHadoopConf: Boolean = confDir.isDefined || existingConfMap.isDefined + + override def configurePod(original: SparkPod): SparkPod = { + original.transform { case pod if hasHadoopConf => + val confVolume = if (confDir.isDefined) { + val keyPaths = confFiles.map { file => + new KeyToPathBuilder() + .withKey(file.getName()) + .withPath(file.getName()) + .build() + } + new VolumeBuilder() + .withName(HADOOP_CONF_VOLUME) + .withNewConfigMap() + .withName(newConfigMapName) + .withItems(keyPaths.asJava) + .endConfigMap() + .build() + } else { + new VolumeBuilder() + .withName(HADOOP_CONF_VOLUME) + .withNewConfigMap() + .withName(existingConfMap.get) + .endConfigMap() + .build() + } + + val podWithConf = new PodBuilder(pod.pod) + .editSpec() + .addNewVolumeLike(confVolume) + .endVolume() + .endSpec() + .build() + + val containerWithMount = new ContainerBuilder(pod.container) + .addNewVolumeMount() + .withName(HADOOP_CONF_VOLUME) + .withMountPath(HADOOP_CONF_DIR_PATH) + .endVolumeMount() + .addNewEnv() + .withName(ENV_HADOOP_CONF_DIR) + .withValue(HADOOP_CONF_DIR_PATH) + .endEnv() + .build() + + SparkPod(podWithConf, containerWithMount) + } + } + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { + if (confDir.isDefined) { + val fileMap = confFiles.map { file => + (file.getName(), Files.toString(file, StandardCharsets.UTF_8)) + }.toMap.asJava + + Seq(new ConfigMapBuilder() + .withNewMetadata() + .withName(newConfigMapName) + .endMetadata() + .addToData(fileMap) + .build()) + } else { + Nil + } + } + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala deleted file mode 100644 index da332881ae1a2..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.features - -import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, SparkPod} -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.features.hadooputils.HadoopBootstrapUtil -import org.apache.spark.internal.Logging - -/** - * This step is responsible for bootstraping the container with ConfigMaps - * containing Hadoop config files mounted as volumes and an ENV variable - * pointed to the mounted file directory. - */ -private[spark] class HadoopConfExecutorFeatureStep(conf: KubernetesExecutorConf) - extends KubernetesFeatureConfigStep with Logging { - - override def configurePod(pod: SparkPod): SparkPod = { - val hadoopConfDirCMapName = conf.getOption(HADOOP_CONFIG_MAP_NAME) - if (hadoopConfDirCMapName.isDefined) { - HadoopBootstrapUtil.bootstrapHadoopConfDir(None, None, hadoopConfDirCMapName, pod) - } else { - pod - } - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala deleted file mode 100644 index c038e75491ca5..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.features - -import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, SparkPod} -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.features.hadooputils.HadoopBootstrapUtil - -/** - * This step is responsible for setting ENV_SPARK_USER when HADOOP_FILES are detected - * however, this step would not be run if Kerberos is enabled, as Kerberos sets SPARK_USER - */ -private[spark] class HadoopSparkUserExecutorFeatureStep(conf: KubernetesExecutorConf) - extends KubernetesFeatureConfigStep { - - override def configurePod(pod: SparkPod): SparkPod = { - conf.getOption(KERBEROS_SPARK_USER_NAME).map { user => - HadoopBootstrapUtil.bootstrapSparkUserPod(user, pod) - }.getOrElse(pod) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala index c6d5a866fa7bc..721d7e97b21f8 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala @@ -16,31 +16,40 @@ */ package org.apache.spark.deploy.k8s.features -import io.fabric8.kubernetes.api.model.{HasMetadata, Secret, SecretBuilder} +import java.io.File +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ + +import com.google.common.io.Files +import io.fabric8.kubernetes.api.model._ import org.apache.commons.codec.binary.Base64 -import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.k8s.{KubernetesDriverConf, KubernetesUtils, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.features.hadooputils._ import org.apache.spark.deploy.security.HadoopDelegationTokenManager +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.util.Utils /** - * Runs the necessary Hadoop-based logic based on Kerberos configs and the presence of the - * HADOOP_CONF_DIR. This runs various bootstrap methods defined in HadoopBootstrapUtil. + * Provide kerberos / service credentials to the Spark driver. + * + * There are three use cases, in order of precedence: + * + * - keytab: if a kerberos keytab is defined, it is provided to the driver, and the driver will + * manage the kerberos login and the creation of delegation tokens. + * - existing tokens: if a secret containing delegation tokens is provided, it will be mounted + * on the driver pod, and the driver will handle distribution of those tokens to executors. + * - tgt only: if Hadoop security is enabled, the local TGT will be used to create delegation + * tokens which will be provided to the driver. The driver will handle distribution of the + * tokens to executors. */ private[spark] class KerberosConfDriverFeatureStep(kubernetesConf: KubernetesDriverConf) - extends KubernetesFeatureConfigStep { - - private val hadoopConfDir = Option(kubernetesConf.sparkConf.getenv(ENV_HADOOP_CONF_DIR)) - private val hadoopConfigMapName = kubernetesConf.get(KUBERNETES_HADOOP_CONF_CONFIG_MAP) - KubernetesUtils.requireNandDefined( - hadoopConfDir, - hadoopConfigMapName, - "Do not specify both the `HADOOP_CONF_DIR` in your ENV and the ConfigMap " + - "as the creation of an additional ConfigMap, when one is already specified is extraneous") + extends KubernetesFeatureConfigStep with Logging { private val principal = kubernetesConf.get(org.apache.spark.internal.config.PRINCIPAL) private val keytab = kubernetesConf.get(org.apache.spark.internal.config.KEYTAB) @@ -49,15 +58,6 @@ private[spark] class KerberosConfDriverFeatureStep(kubernetesConf: KubernetesDri private val krb5File = kubernetesConf.get(KUBERNETES_KERBEROS_KRB5_FILE) private val krb5CMap = kubernetesConf.get(KUBERNETES_KERBEROS_KRB5_CONFIG_MAP) private val hadoopConf = SparkHadoopUtil.get.newConfiguration(kubernetesConf.sparkConf) - private val tokenManager = new HadoopDelegationTokenManager(kubernetesConf.sparkConf, hadoopConf) - private val isKerberosEnabled = - (hadoopConfDir.isDefined && UserGroupInformation.isSecurityEnabled) || - (hadoopConfigMapName.isDefined && (krb5File.isDefined || krb5CMap.isDefined)) - require(keytab.isEmpty || isKerberosEnabled, - "You must enable Kerberos support if you are specifying a Kerberos Keytab") - - require(existingSecretName.isEmpty || isKerberosEnabled, - "You must enable Kerberos support if you are specifying a Kerberos Secret") KubernetesUtils.requireNandDefined( krb5File, @@ -79,128 +79,183 @@ private[spark] class KerberosConfDriverFeatureStep(kubernetesConf: KubernetesDri "If a secret storing a Kerberos Delegation Token is specified you must also" + " specify the item-key where the data is stored") - private val hadoopConfigurationFiles = hadoopConfDir.map { hConfDir => - HadoopBootstrapUtil.getHadoopConfFiles(hConfDir) + if (!hasKerberosConf) { + logInfo("You have not specified a krb5.conf file locally or via a ConfigMap. " + + "Make sure that you have the krb5.conf locally on the driver image.") } - private val newHadoopConfigMapName = - if (hadoopConfigMapName.isEmpty) { - Some(kubernetesConf.hadoopConfigMapName) - } else { - None - } - // Either use pre-existing secret or login to create new Secret with DT stored within - private val kerberosConfSpec: Option[KerberosConfigSpec] = (for { - secretName <- existingSecretName - secretItemKey <- existingSecretItemKey - } yield { - KerberosConfigSpec( - dtSecret = None, - dtSecretName = secretName, - dtSecretItemKey = secretItemKey, - jobUserName = UserGroupInformation.getCurrentUser.getShortUserName) - }).orElse( - if (isKerberosEnabled) { - Some(buildKerberosSpec()) + // Create delegation tokens if needed. This is a lazy val so that it's not populated + // unnecessarily. But it needs to be accessible to different methods in this class, + // since it's not clear based solely on available configuration options that delegation + // tokens are needed when other credentials are not available. + private lazy val delegationTokens: Array[Byte] = { + if (keytab.isEmpty && existingSecretName.isEmpty) { + val tokenManager = new HadoopDelegationTokenManager(kubernetesConf.sparkConf, + SparkHadoopUtil.get.newConfiguration(kubernetesConf.sparkConf)) + val creds = UserGroupInformation.getCurrentUser().getCredentials() + tokenManager.obtainDelegationTokens(creds) + // If no tokens and no secrets are stored in the credentials, make sure nothing is returned, + // to avoid creating an unnecessary secret. + if (creds.numberOfTokens() > 0 || creds.numberOfSecretKeys() > 0) { + SparkHadoopUtil.get.serialize(creds) + } else { + null + } } else { - None + null } - ) + } - override def configurePod(pod: SparkPod): SparkPod = { - if (!isKerberosEnabled) { - return pod - } + private def needKeytabUpload: Boolean = keytab.exists(!Utils.isLocalUri(_)) - val hadoopBasedSparkPod = HadoopBootstrapUtil.bootstrapHadoopConfDir( - hadoopConfDir, - newHadoopConfigMapName, - hadoopConfigMapName, - pod) - kerberosConfSpec.map { hSpec => - HadoopBootstrapUtil.bootstrapKerberosPod( - hSpec.dtSecretName, - hSpec.dtSecretItemKey, - hSpec.jobUserName, - krb5File, - Some(kubernetesConf.krbConfigMapName), - krb5CMap, - hadoopBasedSparkPod) - }.getOrElse( - HadoopBootstrapUtil.bootstrapSparkUserPod( - UserGroupInformation.getCurrentUser.getShortUserName, - hadoopBasedSparkPod)) - } + private def dtSecretName: String = s"${kubernetesConf.resourceNamePrefix}-delegation-tokens" - override def getAdditionalPodSystemProperties(): Map[String, String] = { - if (!isKerberosEnabled) { - return Map.empty - } + private def ktSecretName: String = s"${kubernetesConf.resourceNamePrefix}-kerberos-keytab" - val resolvedConfValues = kerberosConfSpec.map { hSpec => - Map(KERBEROS_DT_SECRET_NAME -> hSpec.dtSecretName, - KERBEROS_DT_SECRET_KEY -> hSpec.dtSecretItemKey, - KERBEROS_SPARK_USER_NAME -> hSpec.jobUserName, - KRB5_CONFIG_MAP_NAME -> krb5CMap.getOrElse(kubernetesConf.krbConfigMapName)) - }.getOrElse( - Map(KERBEROS_SPARK_USER_NAME -> - UserGroupInformation.getCurrentUser.getShortUserName)) - Map(HADOOP_CONFIG_MAP_NAME -> - hadoopConfigMapName.getOrElse(kubernetesConf.hadoopConfigMapName)) ++ resolvedConfValues - } + private def hasKerberosConf: Boolean = krb5CMap.isDefined | krb5File.isDefined - override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { - if (!isKerberosEnabled) { - return Seq.empty - } + private def newConfigMapName: String = s"${kubernetesConf.resourceNamePrefix}-krb5-file" - val hadoopConfConfigMap = for { - hName <- newHadoopConfigMapName - hFiles <- hadoopConfigurationFiles - } yield { - HadoopBootstrapUtil.buildHadoopConfigMap(hName, hFiles) - } + override def configurePod(original: SparkPod): SparkPod = { + original.transform { case pod if hasKerberosConf => + val configMapVolume = if (krb5CMap.isDefined) { + new VolumeBuilder() + .withName(KRB_FILE_VOLUME) + .withNewConfigMap() + .withName(krb5CMap.get) + .endConfigMap() + .build() + } else { + val krb5Conf = new File(krb5File.get) + new VolumeBuilder() + .withName(KRB_FILE_VOLUME) + .withNewConfigMap() + .withName(newConfigMapName) + .withItems(new KeyToPathBuilder() + .withKey(krb5Conf.getName()) + .withPath(krb5Conf.getName()) + .build()) + .endConfigMap() + .build() + } - val krb5ConfigMap = krb5File.map { fileLocation => - HadoopBootstrapUtil.buildkrb5ConfigMap( - kubernetesConf.krbConfigMapName, - fileLocation) - } + val podWithVolume = new PodBuilder(pod.pod) + .editSpec() + .addNewVolumeLike(configMapVolume) + .endVolume() + .endSpec() + .build() + + val containerWithMount = new ContainerBuilder(pod.container) + .addNewVolumeMount() + .withName(KRB_FILE_VOLUME) + .withMountPath(KRB_FILE_DIR_PATH + "/krb5.conf") + .withSubPath("krb5.conf") + .endVolumeMount() + .build() + + SparkPod(podWithVolume, containerWithMount) + }.transform { + case pod if needKeytabUpload => + // If keytab is defined and is a submission-local file (not local: URI), then create a + // secret for it. The keytab data will be stored in this secret below. + val podWitKeytab = new PodBuilder(pod.pod) + .editOrNewSpec() + .addNewVolume() + .withName(KERBEROS_KEYTAB_VOLUME) + .withNewSecret() + .withSecretName(ktSecretName) + .endSecret() + .endVolume() + .endSpec() + .build() + + val containerWithKeytab = new ContainerBuilder(pod.container) + .addNewVolumeMount() + .withName(KERBEROS_KEYTAB_VOLUME) + .withMountPath(KERBEROS_KEYTAB_MOUNT_POINT) + .endVolumeMount() + .build() + + SparkPod(podWitKeytab, containerWithKeytab) + + case pod if existingSecretName.isDefined | delegationTokens != null => + val secretName = existingSecretName.getOrElse(dtSecretName) + val itemKey = existingSecretItemKey.getOrElse(KERBEROS_SECRET_KEY) + + val podWithTokens = new PodBuilder(pod.pod) + .editOrNewSpec() + .addNewVolume() + .withName(SPARK_APP_HADOOP_SECRET_VOLUME_NAME) + .withNewSecret() + .withSecretName(secretName) + .endSecret() + .endVolume() + .endSpec() + .build() - val kerberosDTSecret = kerberosConfSpec.flatMap(_.dtSecret) + val containerWithTokens = new ContainerBuilder(pod.container) + .addNewVolumeMount() + .withName(SPARK_APP_HADOOP_SECRET_VOLUME_NAME) + .withMountPath(SPARK_APP_HADOOP_CREDENTIALS_BASE_DIR) + .endVolumeMount() + .addNewEnv() + .withName(ENV_HADOOP_TOKEN_FILE_LOCATION) + .withValue(s"$SPARK_APP_HADOOP_CREDENTIALS_BASE_DIR/$itemKey") + .endEnv() + .build() - hadoopConfConfigMap.toSeq ++ - krb5ConfigMap.toSeq ++ - kerberosDTSecret.toSeq + SparkPod(podWithTokens, containerWithTokens) + } } - private def buildKerberosSpec(): KerberosConfigSpec = { - // The JobUserUGI will be taken fom the Local Ticket Cache or via keytab+principal - // The login happens in the SparkSubmit so login logic is not necessary to include - val jobUserUGI = UserGroupInformation.getCurrentUser - val creds = jobUserUGI.getCredentials - tokenManager.obtainDelegationTokens(creds) - val tokenData = SparkHadoopUtil.get.serialize(creds) - require(tokenData.nonEmpty, "Did not obtain any delegation tokens") - val newSecretName = - s"${kubernetesConf.resourceNamePrefix}-$KERBEROS_DELEGEGATION_TOKEN_SECRET_NAME" - val secretDT = - new SecretBuilder() - .withNewMetadata() - .withName(newSecretName) - .endMetadata() - .addToData(KERBEROS_SECRET_KEY, Base64.encodeBase64String(tokenData)) - .build() - KerberosConfigSpec( - dtSecret = Some(secretDT), - dtSecretName = newSecretName, - dtSecretItemKey = KERBEROS_SECRET_KEY, - jobUserName = jobUserUGI.getShortUserName) + override def getAdditionalPodSystemProperties(): Map[String, String] = { + // If a submission-local keytab is provided, update the Spark config so that it knows the + // path of the keytab in the driver container. + if (needKeytabUpload) { + val ktName = new File(keytab.get).getName() + Map(KEYTAB.key -> s"$KERBEROS_KEYTAB_MOUNT_POINT/$ktName") + } else { + Map.empty + } } - private case class KerberosConfigSpec( - dtSecret: Option[Secret], - dtSecretName: String, - dtSecretItemKey: String, - jobUserName: String) + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { + Seq[HasMetadata]() ++ { + krb5File.map { path => + val file = new File(path) + new ConfigMapBuilder() + .withNewMetadata() + .withName(newConfigMapName) + .endMetadata() + .addToData( + Map(file.getName() -> Files.toString(file, StandardCharsets.UTF_8)).asJava) + .build() + } + } ++ { + // If a submission-local keytab is provided, stash it in a secret. + if (needKeytabUpload) { + val kt = new File(keytab.get) + Seq(new SecretBuilder() + .withNewMetadata() + .withName(ktSecretName) + .endMetadata() + .addToData(kt.getName(), Base64.encodeBase64String(Files.toByteArray(kt))) + .build()) + } else { + Nil + } + } ++ { + if (delegationTokens != null) { + Seq(new SecretBuilder() + .withNewMetadata() + .withName(dtSecretName) + .endMetadata() + .addToData(KERBEROS_SECRET_KEY, Base64.encodeBase64String(delegationTokens)) + .build()) + } else { + Nil + } + } + } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala deleted file mode 100644 index 907271b1cb483..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.features - -import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, SparkPod} -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.features.hadooputils.HadoopBootstrapUtil -import org.apache.spark.internal.Logging - -/** - * This step is responsible for mounting the DT secret for the executors - */ -private[spark] class KerberosConfExecutorFeatureStep(conf: KubernetesExecutorConf) - extends KubernetesFeatureConfigStep with Logging { - - override def configurePod(pod: SparkPod): SparkPod = { - val maybeKrb5CMap = conf.getOption(KRB5_CONFIG_MAP_NAME) - if (maybeKrb5CMap.isDefined) { - logInfo(s"Mounting Resources for Kerberos") - HadoopBootstrapUtil.bootstrapKerberosPod( - conf.get(KERBEROS_DT_SECRET_NAME), - conf.get(KERBEROS_DT_SECRET_KEY), - conf.get(KERBEROS_SPARK_USER_NAME), - None, - None, - maybeKrb5CMap, - pod) - } else { - pod - } - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/HadoopBootstrapUtil.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/HadoopBootstrapUtil.scala deleted file mode 100644 index 5bee766caf2be..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/HadoopBootstrapUtil.scala +++ /dev/null @@ -1,283 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.features.hadooputils - -import java.io.File -import java.nio.charset.StandardCharsets - -import scala.collection.JavaConverters._ - -import com.google.common.io.Files -import io.fabric8.kubernetes.api.model._ - -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.SparkPod -import org.apache.spark.internal.Logging - -private[spark] object HadoopBootstrapUtil extends Logging { - - /** - * Mounting the DT secret for both the Driver and the executors - * - * @param dtSecretName Name of the secret that stores the Delegation Token - * @param dtSecretItemKey Name of the Item Key storing the Delegation Token - * @param userName Name of the SparkUser to set SPARK_USER - * @param fileLocation Optional Location of the krb5 file - * @param newKrb5ConfName Optional location of the ConfigMap for Krb5 - * @param existingKrb5ConfName Optional name of ConfigMap for Krb5 - * @param pod Input pod to be appended to - * @return a modified SparkPod - */ - def bootstrapKerberosPod( - dtSecretName: String, - dtSecretItemKey: String, - userName: String, - fileLocation: Option[String], - newKrb5ConfName: Option[String], - existingKrb5ConfName: Option[String], - pod: SparkPod): SparkPod = { - - val preConfigMapVolume = existingKrb5ConfName.map { kconf => - new VolumeBuilder() - .withName(KRB_FILE_VOLUME) - .withNewConfigMap() - .withName(kconf) - .endConfigMap() - .build() - } - - val createConfigMapVolume = for { - fLocation <- fileLocation - krb5ConfName <- newKrb5ConfName - } yield { - val krb5File = new File(fLocation) - val fileStringPath = krb5File.toPath.getFileName.toString - new VolumeBuilder() - .withName(KRB_FILE_VOLUME) - .withNewConfigMap() - .withName(krb5ConfName) - .withItems(new KeyToPathBuilder() - .withKey(fileStringPath) - .withPath(fileStringPath) - .build()) - .endConfigMap() - .build() - } - - // Breaking up Volume creation for clarity - val configMapVolume = preConfigMapVolume.orElse(createConfigMapVolume) - if (configMapVolume.isEmpty) { - logInfo("You have not specified a krb5.conf file locally or via a ConfigMap. " + - "Make sure that you have the krb5.conf locally on the Driver and Executor images") - } - - val kerberizedPodWithDTSecret = new PodBuilder(pod.pod) - .editOrNewSpec() - .addNewVolume() - .withName(SPARK_APP_HADOOP_SECRET_VOLUME_NAME) - .withNewSecret() - .withSecretName(dtSecretName) - .endSecret() - .endVolume() - .endSpec() - .build() - - // Optionally add the krb5.conf ConfigMap - val kerberizedPod = configMapVolume.map { cmVolume => - new PodBuilder(kerberizedPodWithDTSecret) - .editSpec() - .addNewVolumeLike(cmVolume) - .endVolume() - .endSpec() - .build() - }.getOrElse(kerberizedPodWithDTSecret) - - val kerberizedContainerWithMounts = new ContainerBuilder(pod.container) - .addNewVolumeMount() - .withName(SPARK_APP_HADOOP_SECRET_VOLUME_NAME) - .withMountPath(SPARK_APP_HADOOP_CREDENTIALS_BASE_DIR) - .endVolumeMount() - .addNewEnv() - .withName(ENV_HADOOP_TOKEN_FILE_LOCATION) - .withValue(s"$SPARK_APP_HADOOP_CREDENTIALS_BASE_DIR/$dtSecretItemKey") - .endEnv() - .addNewEnv() - .withName(ENV_SPARK_USER) - .withValue(userName) - .endEnv() - .build() - - // Optionally add the krb5.conf Volume Mount - val kerberizedContainer = - if (configMapVolume.isDefined) { - new ContainerBuilder(kerberizedContainerWithMounts) - .addNewVolumeMount() - .withName(KRB_FILE_VOLUME) - .withMountPath(KRB_FILE_DIR_PATH + "/krb5.conf") - .withSubPath("krb5.conf") - .endVolumeMount() - .build() - } else { - kerberizedContainerWithMounts - } - - SparkPod(kerberizedPod, kerberizedContainer) - } - - /** - * setting ENV_SPARK_USER when HADOOP_FILES are detected - * - * @param sparkUserName Name of the SPARK_USER - * @param pod Input pod to be appended to - * @return a modified SparkPod - */ - def bootstrapSparkUserPod(sparkUserName: String, pod: SparkPod): SparkPod = { - val envModifiedContainer = new ContainerBuilder(pod.container) - .addNewEnv() - .withName(ENV_SPARK_USER) - .withValue(sparkUserName) - .endEnv() - .build() - SparkPod(pod.pod, envModifiedContainer) - } - - /** - * Grabbing files in the HADOOP_CONF_DIR - * - * @param path location of HADOOP_CONF_DIR - * @return a list of File object - */ - def getHadoopConfFiles(path: String): Seq[File] = { - val dir = new File(path) - if (dir.isDirectory) { - dir.listFiles.filter(_.isFile).toSeq - } else { - Seq.empty[File] - } - } - - /** - * Bootstraping the container with ConfigMaps that store - * Hadoop configuration files - * - * @param hadoopConfDir directory location of HADOOP_CONF_DIR env - * @param newHadoopConfigMapName name of the new configMap for HADOOP_CONF_DIR - * @param existingHadoopConfigMapName name of the pre-defined configMap for HADOOP_CONF_DIR - * @param pod Input pod to be appended to - * @return a modified SparkPod - */ - def bootstrapHadoopConfDir( - hadoopConfDir: Option[String], - newHadoopConfigMapName: Option[String], - existingHadoopConfigMapName: Option[String], - pod: SparkPod): SparkPod = { - val preConfigMapVolume = existingHadoopConfigMapName.map { hConf => - new VolumeBuilder() - .withName(HADOOP_FILE_VOLUME) - .withNewConfigMap() - .withName(hConf) - .endConfigMap() - .build() } - - val createConfigMapVolume = for { - dirLocation <- hadoopConfDir - hConfName <- newHadoopConfigMapName - } yield { - val hadoopConfigFiles = getHadoopConfFiles(dirLocation) - val keyPaths = hadoopConfigFiles.map { file => - val fileStringPath = file.toPath.getFileName.toString - new KeyToPathBuilder() - .withKey(fileStringPath) - .withPath(fileStringPath) - .build() - } - new VolumeBuilder() - .withName(HADOOP_FILE_VOLUME) - .withNewConfigMap() - .withName(hConfName) - .withItems(keyPaths.asJava) - .endConfigMap() - .build() - } - - // Breaking up Volume Creation for clarity - val configMapVolume = preConfigMapVolume.getOrElse(createConfigMapVolume.get) - - val hadoopSupportedPod = new PodBuilder(pod.pod) - .editSpec() - .addNewVolumeLike(configMapVolume) - .endVolume() - .endSpec() - .build() - - val hadoopSupportedContainer = new ContainerBuilder(pod.container) - .addNewVolumeMount() - .withName(HADOOP_FILE_VOLUME) - .withMountPath(HADOOP_CONF_DIR_PATH) - .endVolumeMount() - .addNewEnv() - .withName(ENV_HADOOP_CONF_DIR) - .withValue(HADOOP_CONF_DIR_PATH) - .endEnv() - .build() - SparkPod(hadoopSupportedPod, hadoopSupportedContainer) - } - - /** - * Builds ConfigMap given the file location of the - * krb5.conf file - * - * @param configMapName name of configMap for krb5 - * @param fileLocation location of krb5 file - * @return a ConfigMap - */ - def buildkrb5ConfigMap( - configMapName: String, - fileLocation: String): ConfigMap = { - val file = new File(fileLocation) - new ConfigMapBuilder() - .withNewMetadata() - .withName(configMapName) - .endMetadata() - .addToData(Map(file.toPath.getFileName.toString -> - Files.toString(file, StandardCharsets.UTF_8)).asJava) - .build() - } - - /** - * Builds ConfigMap given the ConfigMap name - * and a list of Hadoop Conf files - * - * @param hadoopConfigMapName name of hadoopConfigMap - * @param hadoopConfFiles list of hadoopFiles - * @return a ConfigMap - */ - def buildHadoopConfigMap( - hadoopConfigMapName: String, - hadoopConfFiles: Seq[File]): ConfigMap = { - new ConfigMapBuilder() - .withNewMetadata() - .withName(hadoopConfigMapName) - .endMetadata() - .addToData(hadoopConfFiles.map { file => - (file.toPath.getFileName.toString, - Files.toString(file, StandardCharsets.UTF_8)) - }.toMap.asJava) - .build() - } - -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/KerberosConfigSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/KerberosConfigSpec.scala deleted file mode 100644 index 7f7ef216cf485..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/KerberosConfigSpec.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.features.hadooputils - -import io.fabric8.kubernetes.api.model.Secret - -/** - * Represents a given configuration of the Kerberos Configuration logic - *

      - * - The secret containing a DT, either previously specified or built on the fly - * - The name of the secret where the DT will be stored - * - The data item-key on the secret which correlates with where the current DT data is stored - * - The Job User's username - */ -private[spark] case class KerberosConfigSpec( - dtSecret: Option[Secret], - dtSecretName: String, - dtSecretItemKey: String, - jobUserName: String) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index d2c0ced9fa2f4..57e4060bc85b9 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -46,6 +46,7 @@ private[spark] class KubernetesDriverBuilder { new LocalDirsFeatureStep(conf), new MountVolumesFeatureStep(conf), new DriverCommandFeatureStep(conf), + new HadoopConfDriverFeatureStep(conf), new KerberosConfDriverFeatureStep(conf), new PodTemplateConfigMapStep(conf)) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index 03f5da2bb0bce..cd298971e02a7 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -25,6 +25,7 @@ import io.fabric8.kubernetes.client.KubernetesClient import org.apache.spark.SparkContext import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.rpc.{RpcAddress, RpcEnv} import org.apache.spark.scheduler.{ExecutorLossReason, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils} @@ -143,7 +144,11 @@ private[spark] class KubernetesClusterSchedulerBackend( } override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { - new KubernetesDriverEndpoint(rpcEnv, properties) + new KubernetesDriverEndpoint(sc.env.rpcEnv, properties) + } + + override protected def createTokenManager(): Option[HadoopDelegationTokenManager] = { + Some(new HadoopDelegationTokenManager(conf, sc.hadoopConfiguration)) } private class KubernetesDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index 0b74966fe8685..48aa2c56d4d69 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -44,10 +44,7 @@ private[spark] class KubernetesExecutorBuilder { new MountSecretsFeatureStep(conf), new EnvSecretsFeatureStep(conf), new LocalDirsFeatureStep(conf), - new MountVolumesFeatureStep(conf), - new HadoopConfExecutorFeatureStep(conf), - new KerberosConfExecutorFeatureStep(conf), - new HadoopSparkUserExecutorFeatureStep(conf)) + new MountVolumesFeatureStep(conf)) features.foldLeft(initialPod) { case (pod, feature) => feature.configurePod(pod) } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala index e4951bc1e69ed..5ceb9d6d6fcd0 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit._ import org.apache.spark.internal.config._ import org.apache.spark.ui.SparkUI +import org.apache.spark.util.Utils class BasicDriverFeatureStepSuite extends SparkFunSuite { @@ -73,7 +74,6 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { val foundPortNames = configuredPod.container.getPorts.asScala.toSet assert(expectedPortNames === foundPortNames) - assert(configuredPod.container.getEnv.size === 3) val envs = configuredPod.container .getEnv .asScala @@ -82,6 +82,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { DRIVER_ENVS.foreach { case (k, v) => assert(envs(v) === v) } + assert(envs(ENV_SPARK_USER) === Utils.getCurrentUserName()) assert(configuredPod.pod.getSpec().getImagePullSecrets.asScala === TEST_IMAGE_PULL_SECRET_OBJECTS) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index 05989d9be7ad5..c2efab01e4248 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -200,7 +200,8 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { ENV_EXECUTOR_MEMORY -> "1g", ENV_APPLICATION_ID -> KubernetesTestConf.APP_ID, ENV_SPARK_CONF_DIR -> SPARK_CONF_DIR_INTERNAL, - ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars + ENV_EXECUTOR_POD_IP -> null, + ENV_SPARK_USER -> Utils.getCurrentUserName()) val extraJavaOptsStart = additionalEnvVars.keys.count(_.startsWith(ENV_JAVA_OPT_PREFIX)) val extraJavaOpts = Utils.sparkJavaOpts(conf, SparkConf.isExecutorStartupConf) @@ -208,9 +209,11 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { s"$ENV_JAVA_OPT_PREFIX${ind + extraJavaOptsStart}" -> opt }.toMap - val mapEnvs = executorPod.container.getEnv.asScala.map { + val containerEnvs = executorPod.container.getEnv.asScala.map { x => (x.getName, x.getValue) }.toMap - assert((defaultEnvs ++ extraJavaOptsEnvs) === mapEnvs) + + val expectedEnvs = defaultEnvs ++ additionalEnvVars ++ extraJavaOptsEnvs + assert(containerEnvs === expectedEnvs) } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStepSuite.scala new file mode 100644 index 0000000000000..e1c01dbdc7358 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStepSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import java.io.File +import java.nio.charset.StandardCharsets.UTF_8 + +import scala.collection.JavaConverters._ + +import com.google.common.io.Files +import io.fabric8.kubernetes.api.model.ConfigMap + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.JavaMainAppResource +import org.apache.spark.util.{SparkConfWithEnv, Utils} + +class HadoopConfDriverFeatureStepSuite extends SparkFunSuite { + + import KubernetesFeaturesTestUtils._ + import SecretVolumeUtils._ + + test("mount hadoop config map if defined") { + val sparkConf = new SparkConf(false) + .set(Config.KUBERNETES_HADOOP_CONF_CONFIG_MAP, "testConfigMap") + val conf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf) + val step = new HadoopConfDriverFeatureStep(conf) + checkPod(step.configurePod(SparkPod.initialPod())) + assert(step.getAdditionalKubernetesResources().isEmpty) + } + + test("create hadoop config map if config dir is defined") { + val confDir = Utils.createTempDir() + val confFiles = Set("core-site.xml", "hdfs-site.xml") + + confFiles.foreach { f => + Files.write("some data", new File(confDir, f), UTF_8) + } + + val sparkConf = new SparkConfWithEnv(Map(ENV_HADOOP_CONF_DIR -> confDir.getAbsolutePath())) + val conf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf) + + val step = new HadoopConfDriverFeatureStep(conf) + checkPod(step.configurePod(SparkPod.initialPod())) + + val hadoopConfMap = filter[ConfigMap](step.getAdditionalKubernetesResources()).head + assert(hadoopConfMap.getData().keySet().asScala === confFiles) + } + + private def checkPod(pod: SparkPod): Unit = { + assert(podHasVolume(pod.pod, HADOOP_CONF_VOLUME)) + assert(containerHasVolume(pod.container, HADOOP_CONF_VOLUME, HADOOP_CONF_DIR_PATH)) + assert(containerHasEnvVar(pod.container, ENV_HADOOP_CONF_DIR)) + } + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStepSuite.scala new file mode 100644 index 0000000000000..41ca3a94ce7a7 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStepSuite.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import java.io.File +import java.nio.charset.StandardCharsets.UTF_8 +import java.security.PrivilegedExceptionAction + +import scala.collection.JavaConverters._ + +import com.google.common.io.Files +import io.fabric8.kubernetes.api.model.{ConfigMap, Secret} +import org.apache.commons.codec.binary.Base64 +import org.apache.hadoop.io.Text +import org.apache.hadoop.security.{Credentials, UserGroupInformation} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.JavaMainAppResource +import org.apache.spark.internal.config._ +import org.apache.spark.util.Utils + +class KerberosConfDriverFeatureStepSuite extends SparkFunSuite { + + import KubernetesFeaturesTestUtils._ + import SecretVolumeUtils._ + + private val tmpDir = Utils.createTempDir() + + test("mount krb5 config map if defined") { + val configMap = "testConfigMap" + val step = createStep( + new SparkConf(false).set(KUBERNETES_KERBEROS_KRB5_CONFIG_MAP, configMap)) + + checkPodForKrbConf(step.configurePod(SparkPod.initialPod()), configMap) + assert(step.getAdditionalPodSystemProperties().isEmpty) + assert(filter[ConfigMap](step.getAdditionalKubernetesResources()).isEmpty) + } + + test("create krb5.conf config map if local config provided") { + val krbConf = File.createTempFile("krb5", ".conf", tmpDir) + Files.write("some data", krbConf, UTF_8) + + val sparkConf = new SparkConf(false) + .set(KUBERNETES_KERBEROS_KRB5_FILE, krbConf.getAbsolutePath()) + val step = createStep(sparkConf) + + val confMap = filter[ConfigMap](step.getAdditionalKubernetesResources()).head + assert(confMap.getData().keySet().asScala === Set(krbConf.getName())) + + checkPodForKrbConf(step.configurePod(SparkPod.initialPod()), confMap.getMetadata().getName()) + assert(step.getAdditionalPodSystemProperties().isEmpty) + } + + test("create keytab secret if client keytab file used") { + val keytab = File.createTempFile("keytab", ".bin", tmpDir) + Files.write("some data", keytab, UTF_8) + + val sparkConf = new SparkConf(false) + .set(KEYTAB, keytab.getAbsolutePath()) + .set(PRINCIPAL, "alice") + val step = createStep(sparkConf) + + val pod = step.configurePod(SparkPod.initialPod()) + assert(podHasVolume(pod.pod, KERBEROS_KEYTAB_VOLUME)) + assert(containerHasVolume(pod.container, KERBEROS_KEYTAB_VOLUME, KERBEROS_KEYTAB_MOUNT_POINT)) + + assert(step.getAdditionalPodSystemProperties().keys === Set(KEYTAB.key)) + + val secret = filter[Secret](step.getAdditionalKubernetesResources()).head + assert(secret.getData().keySet().asScala === Set(keytab.getName())) + } + + test("do nothing if container-local keytab used") { + val sparkConf = new SparkConf(false) + .set(KEYTAB, "local:/my.keytab") + .set(PRINCIPAL, "alice") + val step = createStep(sparkConf) + + val initial = SparkPod.initialPod() + assert(step.configurePod(initial) === initial) + assert(step.getAdditionalPodSystemProperties().isEmpty) + assert(step.getAdditionalKubernetesResources().isEmpty) + } + + test("mount delegation tokens if provided") { + val dtSecret = "tokenSecret" + val sparkConf = new SparkConf(false) + .set(KUBERNETES_KERBEROS_DT_SECRET_NAME, dtSecret) + .set(KUBERNETES_KERBEROS_DT_SECRET_ITEM_KEY, "dtokens") + val step = createStep(sparkConf) + + checkPodForTokens(step.configurePod(SparkPod.initialPod()), dtSecret) + assert(step.getAdditionalPodSystemProperties().isEmpty) + assert(step.getAdditionalKubernetesResources().isEmpty) + } + + test("create delegation tokens if needed") { + // Since HadoopDelegationTokenManager does not create any tokens without proper configs and + // services, start with a test user that already has some tokens that will just be piped + // through to the driver. + val testUser = UserGroupInformation.createUserForTesting("k8s", Array()) + testUser.doAs(new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + val creds = testUser.getCredentials() + creds.addSecretKey(new Text("K8S_TEST_KEY"), Array[Byte](0x4, 0x2)) + testUser.addCredentials(creds) + + val tokens = SparkHadoopUtil.get.serialize(creds) + + val step = createStep(new SparkConf(false)) + + val dtSecret = filter[Secret](step.getAdditionalKubernetesResources()).head + assert(dtSecret.getData().get(KERBEROS_SECRET_KEY) === Base64.encodeBase64String(tokens)) + + checkPodForTokens(step.configurePod(SparkPod.initialPod()), + dtSecret.getMetadata().getName()) + + assert(step.getAdditionalPodSystemProperties().isEmpty) + } + }) + } + + test("do nothing if no config and no tokens") { + val step = createStep(new SparkConf(false)) + val initial = SparkPod.initialPod() + assert(step.configurePod(initial) === initial) + assert(step.getAdditionalPodSystemProperties().isEmpty) + assert(step.getAdditionalKubernetesResources().isEmpty) + } + + private def checkPodForKrbConf(pod: SparkPod, confMapName: String): Unit = { + val podVolume = pod.pod.getSpec().getVolumes().asScala.find(_.getName() == KRB_FILE_VOLUME) + assert(podVolume.isDefined) + assert(containerHasVolume(pod.container, KRB_FILE_VOLUME, KRB_FILE_DIR_PATH + "/krb5.conf")) + assert(podVolume.get.getConfigMap().getName() === confMapName) + } + + private def checkPodForTokens(pod: SparkPod, dtSecretName: String): Unit = { + val podVolume = pod.pod.getSpec().getVolumes().asScala + .find(_.getName() == SPARK_APP_HADOOP_SECRET_VOLUME_NAME) + assert(podVolume.isDefined) + assert(containerHasVolume(pod.container, SPARK_APP_HADOOP_SECRET_VOLUME_NAME, + SPARK_APP_HADOOP_CREDENTIALS_BASE_DIR)) + assert(containerHasEnvVar(pod.container, ENV_HADOOP_TOKEN_FILE_LOCATION)) + assert(podVolume.get.getSecret().getSecretName() === dtSecretName) + } + + private def createStep(conf: SparkConf): KerberosConfDriverFeatureStep = { + val kconf = KubernetesTestConf.createDriverConf(sparkConf = conf) + new KerberosConfDriverFeatureStep(kconf) + } + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala index f90380e30e52a..076b681be2397 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy.k8s.features import scala.collection.JavaConverters._ +import scala.reflect.ClassTag import io.fabric8.kubernetes.api.model.{Container, HasMetadata, PodBuilder, SecretBuilder} import org.mockito.Matchers @@ -63,4 +64,9 @@ object KubernetesFeaturesTestUtils { def containerHasEnvVar(container: Container, envVarName: String): Boolean = { container.getEnv.asScala.exists(envVar => envVar.getName == envVarName) } + + def filter[T: ClassTag](list: Seq[HasMetadata]): Seq[T] = { + val desired = implicitly[ClassTag[T]].runtimeClass + list.filter(_.getClass() == desired).map(_.asInstanceOf[T]).toSeq + } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 6240f7b68d2c8..184fb6a8ad13e 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -116,6 +116,8 @@ private[spark] class Client( } } + require(keytab == null || !Utils.isLocalUri(keytab), "Keytab should reference a local file.") + private val launcherBackend = new LauncherBackend() { override protected def conf: SparkConf = sparkConf @@ -472,7 +474,7 @@ private[spark] class Client( appMasterOnly: Boolean = false): (Boolean, String) = { val trimmedPath = path.trim() val localURI = Utils.resolveURI(trimmedPath) - if (localURI.getScheme != LOCAL_SCHEME) { + if (localURI.getScheme != Utils.LOCAL_SCHEME) { if (addDistributedUri(localURI)) { val localPath = getQualifiedLocalPath(localURI, hadoopConf) val linkname = targetDir.map(_ + "/").getOrElse("") + @@ -515,7 +517,7 @@ private[spark] class Client( val sparkArchive = sparkConf.get(SPARK_ARCHIVE) if (sparkArchive.isDefined) { val archive = sparkArchive.get - require(!isLocalUri(archive), s"${SPARK_ARCHIVE.key} cannot be a local URI.") + require(!Utils.isLocalUri(archive), s"${SPARK_ARCHIVE.key} cannot be a local URI.") distribute(Utils.resolveURI(archive).toString, resType = LocalResourceType.ARCHIVE, destName = Some(LOCALIZED_LIB_DIR)) @@ -525,7 +527,7 @@ private[spark] class Client( // Break the list of jars to upload, and resolve globs. val localJars = new ArrayBuffer[String]() jars.foreach { jar => - if (!isLocalUri(jar)) { + if (!Utils.isLocalUri(jar)) { val path = getQualifiedLocalPath(Utils.resolveURI(jar), hadoopConf) val pathFs = FileSystem.get(path.toUri(), hadoopConf) pathFs.globStatus(path).filter(_.isFile()).foreach { entry => @@ -814,7 +816,7 @@ private[spark] class Client( } (pySparkArchives ++ pyArchives).foreach { path => val uri = Utils.resolveURI(path) - if (uri.getScheme != LOCAL_SCHEME) { + if (uri.getScheme != Utils.LOCAL_SCHEME) { pythonPath += buildPath(Environment.PWD.$$(), new Path(uri).getName()) } else { pythonPath += uri.getPath() @@ -1183,9 +1185,6 @@ private object Client extends Logging { // Alias for the user jar val APP_JAR_NAME: String = "__app__.jar" - // URI scheme that identifies local resources - val LOCAL_SCHEME = "local" - // Staging directory for any temporary jars or files val SPARK_STAGING: String = ".sparkStaging" @@ -1307,7 +1306,7 @@ private object Client extends Logging { addClasspathEntry(buildPath(Environment.PWD.$$(), LOCALIZED_LIB_DIR, "*"), env) if (sparkConf.get(SPARK_ARCHIVE).isEmpty) { sparkConf.get(SPARK_JARS).foreach { jars => - jars.filter(isLocalUri).foreach { jar => + jars.filter(Utils.isLocalUri).foreach { jar => val uri = new URI(jar) addClasspathEntry(getClusterPath(sparkConf, uri.getPath()), env) } @@ -1340,7 +1339,7 @@ private object Client extends Logging { private def getMainJarUri(mainJar: Option[String]): Option[URI] = { mainJar.flatMap { path => val uri = Utils.resolveURI(path) - if (uri.getScheme == LOCAL_SCHEME) Some(uri) else None + if (uri.getScheme == Utils.LOCAL_SCHEME) Some(uri) else None }.orElse(Some(new URI(APP_JAR_NAME))) } @@ -1368,7 +1367,7 @@ private object Client extends Logging { uri: URI, fileName: String, env: HashMap[String, String]): Unit = { - if (uri != null && uri.getScheme == LOCAL_SCHEME) { + if (uri != null && uri.getScheme == Utils.LOCAL_SCHEME) { addClasspathEntry(getClusterPath(conf, uri.getPath), env) } else if (fileName != null) { addClasspathEntry(buildPath(Environment.PWD.$$(), fileName), env) @@ -1489,11 +1488,6 @@ private object Client extends Logging { components.mkString(Path.SEPARATOR) } - /** Returns whether the URI is a "local:" URI. */ - def isLocalUri(uri: String): Boolean = { - uri.startsWith(s"$LOCAL_SCHEME:") - } - def createAppReport(report: ApplicationReport): YarnAppReport = { val diags = report.getDiagnostics() val diagsOpt = if (diags != null && diags.nonEmpty) Some(diags) else None diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index b3286e8fd824e..a6f57fcdb2461 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -100,7 +100,7 @@ class ClientSuite extends SparkFunSuite with Matchers { val cp = env("CLASSPATH").split(":|;|") s"$SPARK,$USER,$ADDED".split(",").foreach({ entry => val uri = new URI(entry) - if (LOCAL_SCHEME.equals(uri.getScheme())) { + if (Utils.LOCAL_SCHEME.equals(uri.getScheme())) { cp should contain (uri.getPath()) } else { cp should not contain (uri.getPath()) @@ -136,7 +136,7 @@ class ClientSuite extends SparkFunSuite with Matchers { val expected = ADDED.split(",") .map(p => { val uri = new URI(p) - if (LOCAL_SCHEME == uri.getScheme()) { + if (Utils.LOCAL_SCHEME == uri.getScheme()) { p } else { Option(uri.getFragment()).getOrElse(new File(p).getName()) @@ -249,7 +249,7 @@ class ClientSuite extends SparkFunSuite with Matchers { any(classOf[MutableHashMap[URI, Path]]), anyBoolean(), any()) classpath(client) should contain (buildPath(PWD, LOCALIZED_LIB_DIR, "*")) - sparkConf.set(SPARK_ARCHIVE, LOCAL_SCHEME + ":" + archive.getPath()) + sparkConf.set(SPARK_ARCHIVE, Utils.LOCAL_SCHEME + ":" + archive.getPath()) intercept[IllegalArgumentException] { client.prepareLocalResources(new Path(temp.getAbsolutePath()), Nil) } From 834b8609793525a5a486013732d8c98e1c6e6504 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 18 Dec 2018 23:21:52 -0800 Subject: [PATCH 2311/2461] [SPARK-26366][SQL] ReplaceExceptWithFilter should consider NULL as False ## What changes were proposed in this pull request? In `ReplaceExceptWithFilter` we do not consider properly the case in which the condition returns NULL. Indeed, in that case, since negating NULL still returns NULL, so it is not true the assumption that negating the condition returns all the rows which didn't satisfy it, rows returning NULL may not be returned. This happens when constraints inferred by `InferFiltersFromConstraints` are not enough, as it happens with `OR` conditions. The rule had also problems with non-deterministic conditions: in such a scenario, this rule would change the probability of the output. The PR fixes these problem by: - returning False for the condition when it is Null (in this way we do return all the rows which didn't satisfy it); - avoiding any transformation when the condition is non-deterministic. ## How was this patch tested? added UTs Closes #23315 from mgaido91/SPARK-26366. Authored-by: Marco Gaido Signed-off-by: gatorsmile --- .../optimizer/ReplaceExceptWithFilter.scala | 32 ++++++++------ .../optimizer/ReplaceOperatorSuite.scala | 44 ++++++++++++++----- .../org/apache/spark/sql/DatasetSuite.scala | 11 +++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 38 ++++++++++++++++ 4 files changed, 101 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala index efd3944eba7f5..4996d24dfd298 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala @@ -36,7 +36,8 @@ import org.apache.spark.sql.catalyst.rules.Rule * Note: * Before flipping the filter condition of the right node, we should: * 1. Combine all it's [[Filter]]. - * 2. Apply InferFiltersFromConstraints rule (to take into account of NULL values in the condition). + * 2. Update the attribute references to the left node; + * 3. Add a Coalesce(condition, False) (to take into account of NULL values in the condition). */ object ReplaceExceptWithFilter extends Rule[LogicalPlan] { @@ -47,23 +48,28 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] { plan.transform { case e @ Except(left, right, false) if isEligible(left, right) => - val newCondition = transformCondition(left, skipProject(right)) - newCondition.map { c => - Distinct(Filter(Not(c), left)) - }.getOrElse { + val filterCondition = combineFilters(skipProject(right)).asInstanceOf[Filter].condition + if (filterCondition.deterministic) { + transformCondition(left, filterCondition).map { c => + Distinct(Filter(Not(c), left)) + }.getOrElse { + e + } + } else { e } } } - private def transformCondition(left: LogicalPlan, right: LogicalPlan): Option[Expression] = { - val filterCondition = - InferFiltersFromConstraints(combineFilters(right)).asInstanceOf[Filter].condition - - val attributeNameMap: Map[String, Attribute] = left.output.map(x => (x.name, x)).toMap - - if (filterCondition.references.forall(r => attributeNameMap.contains(r.name))) { - Some(filterCondition.transform { case a: AttributeReference => attributeNameMap(a.name) }) + private def transformCondition(plan: LogicalPlan, condition: Expression): Option[Expression] = { + val attributeNameMap: Map[String, Attribute] = plan.output.map(x => (x.name, x)).toMap + if (condition.references.forall(r => attributeNameMap.contains(r.name))) { + val rewrittenCondition = condition.transform { + case a: AttributeReference => attributeNameMap(a.name) + } + // We need to consider as False when the condition is NULL, otherwise we do not return those + // rows containing NULL which are instead filtered in the Except right plan + Some(Coalesce(Seq(rewrittenCondition, Literal.FalseLiteral))) } else { None } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index 3b1b2d588ef67..c8e15c7da763e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, Not} +import org.apache.spark.sql.catalyst.expressions.{Alias, Coalesce, If, Literal, Not} import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.BooleanType class ReplaceOperatorSuite extends PlanTest { @@ -65,8 +66,7 @@ class ReplaceOperatorSuite extends PlanTest { val correctAnswer = Aggregate(table1.output, table1.output, - Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && - (attributeA >= 2 && attributeB < 1)), + Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))), Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze comparePlans(optimized, correctAnswer) @@ -84,8 +84,8 @@ class ReplaceOperatorSuite extends PlanTest { val correctAnswer = Aggregate(table1.output, table1.output, - Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && - (attributeA >= 2 && attributeB < 1)), table1)).analyze + Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))), + table1)).analyze comparePlans(optimized, correctAnswer) } @@ -104,8 +104,7 @@ class ReplaceOperatorSuite extends PlanTest { val correctAnswer = Aggregate(table1.output, table1.output, - Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && - (attributeA >= 2 && attributeB < 1)), + Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))), Project(Seq(attributeA, attributeB), table1))).analyze comparePlans(optimized, correctAnswer) @@ -125,8 +124,7 @@ class ReplaceOperatorSuite extends PlanTest { val correctAnswer = Aggregate(table1.output, table1.output, - Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && - (attributeA >= 2 && attributeB < 1)), + Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))), Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze comparePlans(optimized, correctAnswer) @@ -146,8 +144,7 @@ class ReplaceOperatorSuite extends PlanTest { val correctAnswer = Aggregate(table1.output, table1.output, - Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && - (attributeA === 1 && attributeB === 2)), + Filter(Not(Coalesce(Seq(attributeA === 1 && attributeB === 2, Literal.FalseLiteral))), Project(Seq(attributeA, attributeB), Filter(attributeB < 1, Filter(attributeA >= 2, table1))))).analyze @@ -229,4 +226,29 @@ class ReplaceOperatorSuite extends PlanTest { comparePlans(optimized, query) } + + test("SPARK-26366: ReplaceExceptWithFilter should handle properly NULL") { + val basePlan = LocalRelation(Seq('a.int, 'b.int)) + val otherPlan = basePlan.where('a.in(1, 2) || 'b.in()) + val except = Except(basePlan, otherPlan, false) + val result = OptimizeIn(Optimize.execute(except.analyze)) + val correctAnswer = Aggregate(basePlan.output, basePlan.output, + Filter(!Coalesce(Seq( + 'a.in(1, 2) || If('b.isNotNull, Literal.FalseLiteral, Literal(null, BooleanType)), + Literal.FalseLiteral)), + basePlan)).analyze + comparePlans(result, correctAnswer) + } + + test("SPARK-26366: ReplaceExceptWithFilter should not transform non-detrministic") { + val basePlan = LocalRelation(Seq('a.int, 'b.int)) + val otherPlan = basePlan.where('a > rand(1L)) + val except = Except(basePlan, otherPlan, false) + val result = Optimize.execute(except.analyze) + val condition = basePlan.output.zip(otherPlan.output).map { case (a1, a2) => + a1 <=> a2 }.reduce( _ && _) + val correctAnswer = Aggregate(basePlan.output, otherPlan.output, + Join(basePlan, otherPlan, LeftAnti, Option(condition))).analyze + comparePlans(result, correctAnswer) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 525c7cef39563..c90b15814a534 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1656,6 +1656,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer(df.groupBy(col("a")).agg(first(col("b"))), Seq(Row("0", BigDecimal.valueOf(0.1111)), Row("1", BigDecimal.valueOf(1.1111)))) } + + test("SPARK-26366: return nulls which are not filtered in except") { + val inputDF = sqlContext.createDataFrame( + sparkContext.parallelize(Seq(Row("0", "a"), Row("1", null))), + StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", StringType, nullable = true)))) + + val exceptDF = inputDF.filter(col("a").isin("0") or col("b") > "c") + checkAnswer(inputDF.except(exceptDF), Seq(Row("1", null))) + } } case class TestDataUnion(x: Int, y: Int, z: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 4cc8a45391996..37a8815350a53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2899,6 +2899,44 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-26366: verify ReplaceExceptWithFilter") { + Seq(true, false).foreach { enabled => + withSQLConf(SQLConf.REPLACE_EXCEPT_WITH_FILTER.key -> enabled.toString) { + val df = spark.createDataFrame( + sparkContext.parallelize(Seq(Row(0, 3, 5), + Row(0, 3, null), + Row(null, 3, 5), + Row(0, null, 5), + Row(0, null, null), + Row(null, null, 5), + Row(null, 3, null), + Row(null, null, null))), + StructType(Seq(StructField("c1", IntegerType), + StructField("c2", IntegerType), + StructField("c3", IntegerType)))) + val where = "c2 >= 3 OR c1 >= 0" + val whereNullSafe = + """ + |(c2 IS NOT NULL AND c2 >= 3) + |OR (c1 IS NOT NULL AND c1 >= 0) + """.stripMargin + + val df_a = df.filter(where) + val df_b = df.filter(whereNullSafe) + checkAnswer(df.except(df_a), df.except(df_b)) + + val whereWithIn = "c2 >= 3 OR c1 in (2)" + val whereWithInNullSafe = + """ + |(c2 IS NOT NULL AND c2 >= 3) + """.stripMargin + val dfIn_a = df.filter(whereWithIn) + val dfIn_b = df.filter(whereWithInNullSafe) + checkAnswer(df.except(dfIn_a), df.except(dfIn_b)) + } + } + } } case class Foo(bar: Option[String]) From 08f74ada3656af401099aa79471ef8a1155a3f07 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 19 Dec 2018 09:41:30 -0800 Subject: [PATCH 2312/2461] [SPARK-26390][SQL] ColumnPruning rule should only do column pruning ## What changes were proposed in this pull request? This is a small clean up. By design catalyst rules should be orthogonal: each rule should have its own responsibility. However, the `ColumnPruning` rule does not only do column pruning, but also remove no-op project and window. This PR updates the `RemoveRedundantProject` rule to remove no-op window as well, and clean up the `ColumnPruning` rule to only do column pruning. ## How was this patch tested? existing tests Closes #23343 from cloud-fan/column-pruning. Authored-by: Wenchen Fan Signed-off-by: Dongjoon Hyun --- .../sql/catalyst/optimizer/Optimizer.scala | 23 +++++++++---------- .../optimizer/ColumnPruningSuite.scala | 7 +++--- .../optimizer/CombiningLimitsSuite.scala | 5 ++-- .../optimizer/JoinOptimizationSuite.scala | 1 + .../RemoveRedundantAliasAndProjectSuite.scala | 2 +- .../optimizer/RewriteSubquerySuite.scala | 2 +- .../optimizer/TransposeWindowSuite.scala | 2 +- 7 files changed, 21 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3eb6bca6ec976..44d5543114902 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -93,7 +93,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) RewriteCorrelatedScalarSubquery, EliminateSerialization, RemoveRedundantAliases, - RemoveRedundantProject, + RemoveNoopOperators, SimplifyExtractValueOps, CombineConcats) ++ extendedOperatorOptimizationRules @@ -177,7 +177,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) RewritePredicateSubquery, ColumnPruning, CollapseProject, - RemoveRedundantProject) :+ + RemoveNoopOperators) :+ Batch("UpdateAttributeReferences", Once, UpdateNullabilityInAttributeReferences) } @@ -403,11 +403,15 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { } /** - * Remove projections from the query plan that do not make any modifications. + * Remove no-op operators from the query plan that do not make any modifications. */ -object RemoveRedundantProject extends Rule[LogicalPlan] { +object RemoveNoopOperators extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case p @ Project(_, child) if p.output == child.output => child + // Eliminate no-op Projects + case p @ Project(_, child) if child.sameOutput(p) => child + + // Eliminate no-op Window + case w: Window if w.windowExpressions.isEmpty => w.child } } @@ -602,17 +606,12 @@ object ColumnPruning extends Rule[LogicalPlan] { p.copy(child = w.copy( windowExpressions = w.windowExpressions.filter(p.references.contains))) - // Eliminate no-op Window - case w: Window if w.windowExpressions.isEmpty => w.child - - // Eliminate no-op Projects - case p @ Project(_, child) if child.sameOutput(p) => child - // Can't prune the columns on LeafNode case p @ Project(_, _: LeafNode) => p // for all other logical plans that inherits the output from it's children - case p @ Project(_, child) => + // Project over project is handled by the first case, skip it here. + case p @ Project(_, child) if !child.isInstanceOf[Project] => val required = child.references ++ p.references if (!child.inputSet.subsetOf(required)) { val newChildren = child.children.map(c => prunedChild(c, required)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 8d7c9bf220bc2..57195d5fda7c5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -34,6 +34,7 @@ class ColumnPruningSuite extends PlanTest { val batches = Batch("Column pruning", FixedPoint(100), PushDownPredicate, ColumnPruning, + RemoveNoopOperators, CollapseProject) :: Nil } @@ -340,10 +341,8 @@ class ColumnPruningSuite extends PlanTest { test("Column pruning on Union") { val input1 = LocalRelation('a.int, 'b.string, 'c.double) val input2 = LocalRelation('c.int, 'd.string, 'e.double) - val query = Project('b :: Nil, - Union(input1 :: input2 :: Nil)).analyze - val expected = Project('b :: Nil, - Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil)).analyze + val query = Project('b :: Nil, Union(input1 :: input2 :: Nil)).analyze + val expected = Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil).analyze comparePlans(Optimize.execute(query), expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index ef4b848924f06..b190dd5a7c220 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -27,8 +27,9 @@ class CombiningLimitsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = - Batch("Filter Pushdown", FixedPoint(100), - ColumnPruning) :: + Batch("Column Pruning", FixedPoint(100), + ColumnPruning, + RemoveNoopOperators) :: Batch("Combine Limit", FixedPoint(10), CombineLimits) :: Batch("Constant Folding", FixedPoint(10), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index e9438b2eee550..6fe5e619d03ad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -39,6 +39,7 @@ class JoinOptimizationSuite extends PlanTest { ReorderJoin, PushPredicateThroughJoin, ColumnPruning, + RemoveNoopOperators, CollapseProject) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala index 1973b5abb462d..3802dbf5d6e06 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala @@ -33,7 +33,7 @@ class RemoveRedundantAliasAndProjectSuite extends PlanTest with PredicateHelper FixedPoint(50), PushProjectionThroughUnion, RemoveRedundantAliases, - RemoveRedundantProject) :: Nil + RemoveNoopOperators) :: Nil } test("all expressions in project list are aliased child output") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala index 6b3739c372c3a..f00d22e6e96a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala @@ -34,7 +34,7 @@ class RewriteSubquerySuite extends PlanTest { RewritePredicateSubquery, ColumnPruning, CollapseProject, - RemoveRedundantProject) :: Nil + RemoveNoopOperators) :: Nil } test("Column pruning after rewriting predicate subquery") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala index 58b3d1c98f3cd..4acd57832d2f6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor class TransposeWindowSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = - Batch("CollapseProject", FixedPoint(100), CollapseProject, RemoveRedundantProject) :: + Batch("CollapseProject", FixedPoint(100), CollapseProject, RemoveNoopOperators) :: Batch("FlipWindow", Once, CollapseWindow, TransposeWindow) :: Nil } From 61c443acd23c74ebcb20fd32e5e0ed6c1722b5dc Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 20 Dec 2018 10:41:45 +0800 Subject: [PATCH 2313/2461] [SPARK-26262][SQL] Runs SQLQueryTestSuite on mixed config sets: WHOLESTAGE_CODEGEN_ENABLED and CODEGEN_FACTORY_MODE ## What changes were proposed in this pull request? For better test coverage, this pr proposed to use the 4 mixed config sets of `WHOLESTAGE_CODEGEN_ENABLED` and `CODEGEN_FACTORY_MODE` when running `SQLQueryTestSuite`: 1. WHOLESTAGE_CODEGEN_ENABLED=true, CODEGEN_FACTORY_MODE=CODEGEN_ONLY 2. WHOLESTAGE_CODEGEN_ENABLED=false, CODEGEN_FACTORY_MODE=CODEGEN_ONLY 3. WHOLESTAGE_CODEGEN_ENABLED=true, CODEGEN_FACTORY_MODE=NO_CODEGEN 4. WHOLESTAGE_CODEGEN_ENABLED=false, CODEGEN_FACTORY_MODE=NO_CODEGEN This pr also moved some existing tests into `ExplainSuite` because explain output results are different between codegen and interpreter modes. ## How was this patch tested? Existing tests. Closes #23213 from maropu/InterpreterModeTest. Authored-by: Takeshi Yamamuro Signed-off-by: Wenchen Fan --- .../resources/sql-tests/inputs/group-by.sql | 5 - .../sql-tests/inputs/inline-table.sql | 3 - .../resources/sql-tests/inputs/operators.sql | 21 -- .../inputs/sql-compatibility-functions.sql | 5 - .../sql-tests/inputs/string-functions.sql | 27 --- .../inputs/table-valued-functions.sql | 6 - .../sql-tests/results/group-by.sql.out | 30 +-- .../sql-tests/results/inline-table.sql.out | 32 +-- .../sql-tests/results/operators.sql.out | 204 +++++++----------- .../sql-compatibility-functions.sql.out | 61 ++---- .../results/string-functions.sql.out | 131 +++-------- .../results/table-valued-functions.sql.out | 41 +--- .../org/apache/spark/sql/ExplainSuite.scala | 133 +++++++++++- .../apache/spark/sql/SQLQueryTestSuite.scala | 51 ++--- 14 files changed, 281 insertions(+), 469 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index ec263ea70bd4a..7e81ff1aba37b 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -141,8 +141,3 @@ SELECT every("true"); SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; SELECT k, v, some(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; SELECT k, v, any(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; - --- simple explain of queries having every/some/any agregates. Optimized --- plan should show the rewritten aggregate expression. -EXPLAIN EXTENDED SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k; - diff --git a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql index 41d316444ed6b..b3ec956cd178e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql @@ -49,6 +49,3 @@ select * from values ("one", count(1)), ("two", 2) as data(a, b); -- string to timestamp select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991-12-06 01:00:00.0'), timestamp('1991-12-06 12:00:00.0'))) as data(a, b); - --- cross-join inline tables -EXPLAIN EXTENDED SELECT * FROM VALUES ('one', 1), ('three', null) CROSS JOIN VALUES ('one', 1), ('three', null); diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index 37f9cd44da7f2..ba14789d48db6 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -29,27 +29,6 @@ select 2 * 5; select 5 % 3; select pmod(-7, 3); --- check operator precedence. --- We follow Oracle operator precedence in the table below that lists the levels of precedence --- among SQL operators from high to low: ------------------------------------------------------------------------------------------- --- Operator Operation ------------------------------------------------------------------------------------------- --- +, - identity, negation --- *, / multiplication, division --- +, -, || addition, subtraction, concatenation --- =, !=, <, >, <=, >=, IS NULL, LIKE, BETWEEN, IN comparison --- NOT exponentiation, logical negation --- AND conjunction --- OR disjunction ------------------------------------------------------------------------------------------- -explain select 'a' || 1 + 2; -explain select 1 - 2 || 'b'; -explain select 2 * 4 + 3 || 'b'; -explain select 3 + 1 || 'a' || 4 / 2; -explain select 1 == 1 OR 'a' || 'b' == 'ab'; -explain select 'a' || 'c' == 'ac' AND 2 == 3; - -- math functions select cot(1); select cot(null); diff --git a/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql index f1461032065ad..1ae49c8bfc76a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql @@ -12,11 +12,6 @@ SELECT nullif(1, 2.1d), nullif(1, 1.0d); SELECT nvl(1, 2.1d), nvl(null, 2.1d); SELECT nvl2(null, 1, 2.1d), nvl2('n', 1, 2.1d); --- explain for these functions; use range to avoid constant folding -explain extended -select ifnull(id, 'x'), nullif(id, 'x'), nvl(id, 'x'), nvl2(id, 'x', 'y') -from range(2); - -- SPARK-16730 cast alias functions for Hive compatibility SELECT boolean(1), tinyint(1), smallint(1), int(1), bigint(1); SELECT float(1), double(1), decimal(1); diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index 2effb43183d75..fbc231627e36f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -5,10 +5,6 @@ select format_string(); -- A pipe operator for string concatenation select 'a' || 'b' || 'c'; --- Check if catalyst combine nested `Concat`s -EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col -FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10)); - -- replace function select replace('abc', 'b', '123'); select replace('abc', 'b'); @@ -25,29 +21,6 @@ select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a'); select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null); select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a'); --- turn off concatBinaryAsString -set spark.sql.function.concatBinaryAsString=false; - --- Check if catalyst combine nested `Concat`s if concatBinaryAsString=false -EXPLAIN SELECT ((col1 || col2) || (col3 || col4)) col -FROM ( - SELECT - string(id) col1, - string(id + 1) col2, - encode(string(id + 2), 'utf-8') col3, - encode(string(id + 3), 'utf-8') col4 - FROM range(10) -); - -EXPLAIN SELECT (col1 || (col3 || col4)) col -FROM ( - SELECT - string(id) col1, - encode(string(id + 2), 'utf-8') col3, - encode(string(id + 3), 'utf-8') col4 - FROM range(10) -); - -- split function SELECT split('aa1cc2ee3', '[1-9]+'); SELECT split('aa1cc2ee3', '[1-9]+', 2); diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql index 72cd8ca9d8722..6f14c8ca87821 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql @@ -21,9 +21,3 @@ select * from range(1, null); -- range call with a mixed-case function name select * from RaNgE(2); - --- Explain -EXPLAIN select * from RaNgE(2); - --- cross-join table valued functions -EXPLAIN EXTENDED SELECT * FROM range(3) CROSS JOIN range(3); diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 9a8d025331b67..daf47c4d0a39a 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 47 +-- Number of queries: 46 -- !query 0 @@ -459,31 +459,3 @@ struct --- !query 46 output -== Parsed Logical Plan == -'Aggregate ['k], ['k, unresolvedalias('every('v), None), unresolvedalias('some('v), None), unresolvedalias('any('v), None)] -+- 'UnresolvedRelation `test_agg` - -== Analyzed Logical Plan == -k: int, every(v): boolean, some(v): boolean, any(v): boolean -Aggregate [k#x], [k#x, every(v#x) AS every(v)#x, some(v#x) AS some(v)#x, any(v#x) AS any(v)#x] -+- SubqueryAlias `test_agg` - +- Project [k#x, v#x] - +- SubqueryAlias `test_agg` - +- LocalRelation [k#x, v#x] - -== Optimized Logical Plan == -Aggregate [k#x], [k#x, min(v#x) AS every(v)#x, max(v#x) AS some(v)#x, max(v#x) AS any(v)#x] -+- LocalRelation [k#x, v#x] - -== Physical Plan == -*HashAggregate(keys=[k#x], functions=[min(v#x), max(v#x)], output=[k#x, every(v)#x, some(v)#x, any(v)#x]) -+- Exchange hashpartitioning(k#x, 200) - +- *HashAggregate(keys=[k#x], functions=[partial_min(v#x), partial_max(v#x)], output=[k#x, min#x, max#x]) - +- LocalTableScan [k#x, v#x] diff --git a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out index c065ce5012929..4e80f0bda5513 100644 --- a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 18 +-- Number of queries: 17 -- !query 0 @@ -151,33 +151,3 @@ select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991- struct> -- !query 16 output 1991-12-06 00:00:00 [1991-12-06 01:00:00.0,1991-12-06 12:00:00.0] - - --- !query 17 -EXPLAIN EXTENDED SELECT * FROM VALUES ('one', 1), ('three', null) CROSS JOIN VALUES ('one', 1), ('three', null) --- !query 17 schema -struct --- !query 17 output -== Parsed Logical Plan == -'Project [*] -+- 'Join Cross - :- 'UnresolvedInlineTable [col1, col2], [List(one, 1), List(three, null)] - +- 'UnresolvedInlineTable [col1, col2], [List(one, 1), List(three, null)] - -== Analyzed Logical Plan == -col1: string, col2: int, col1: string, col2: int -Project [col1#x, col2#x, col1#x, col2#x] -+- Join Cross - :- LocalRelation [col1#x, col2#x] - +- LocalRelation [col1#x, col2#x] - -== Optimized Logical Plan == -Join Cross -:- LocalRelation [col1#x, col2#x] -+- LocalRelation [col1#x, col2#x] - -== Physical Plan == -BroadcastNestedLoopJoin BuildRight, Cross -:- LocalTableScan [col1#x, col2#x] -+- BroadcastExchange IdentityBroadcastMode - +- LocalTableScan [col1#x, col2#x] diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 570b281353f3d..e0cbd575bc346 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 55 +-- Number of queries: 49 -- !query 0 @@ -195,260 +195,200 @@ struct -- !query 24 -explain select 'a' || 1 + 2 +select cot(1) -- !query 24 schema -struct +struct -- !query 24 output -== Physical Plan == -*Project [null AS (CAST(concat(a, CAST(1 AS STRING)) AS DOUBLE) + CAST(2 AS DOUBLE))#x] -+- *Scan OneRowRelation[] +0.6420926159343306 -- !query 25 -explain select 1 - 2 || 'b' +select cot(null) -- !query 25 schema -struct +struct -- !query 25 output -== Physical Plan == -*Project [-1b AS concat(CAST((1 - 2) AS STRING), b)#x] -+- *Scan OneRowRelation[] +NULL -- !query 26 -explain select 2 * 4 + 3 || 'b' +select cot(0) -- !query 26 schema -struct +struct -- !query 26 output -== Physical Plan == -*Project [11b AS concat(CAST(((2 * 4) + 3) AS STRING), b)#x] -+- *Scan OneRowRelation[] +Infinity -- !query 27 -explain select 3 + 1 || 'a' || 4 / 2 +select cot(-1) -- !query 27 schema -struct +struct -- !query 27 output -== Physical Plan == -*Project [4a2.0 AS concat(concat(CAST((3 + 1) AS STRING), a), CAST((CAST(4 AS DOUBLE) / CAST(2 AS DOUBLE)) AS STRING))#x] -+- *Scan OneRowRelation[] +-0.6420926159343306 -- !query 28 -explain select 1 == 1 OR 'a' || 'b' == 'ab' +select ceiling(0) -- !query 28 schema -struct +struct -- !query 28 output -== Physical Plan == -*Project [true AS ((1 = 1) OR (concat(a, b) = ab))#x] -+- *Scan OneRowRelation[] +0 -- !query 29 -explain select 'a' || 'c' == 'ac' AND 2 == 3 +select ceiling(1) -- !query 29 schema -struct +struct -- !query 29 output -== Physical Plan == -*Project [false AS ((concat(a, c) = ac) AND (2 = 3))#x] -+- *Scan OneRowRelation[] +1 -- !query 30 -select cot(1) +select ceil(1234567890123456) -- !query 30 schema -struct +struct -- !query 30 output -0.6420926159343306 +1234567890123456 -- !query 31 -select cot(null) +select ceiling(1234567890123456) -- !query 31 schema -struct +struct -- !query 31 output -NULL +1234567890123456 -- !query 32 -select cot(0) +select ceil(0.01) -- !query 32 schema -struct +struct -- !query 32 output -Infinity +1 -- !query 33 -select cot(-1) +select ceiling(-0.10) -- !query 33 schema -struct +struct -- !query 33 output --0.6420926159343306 +0 -- !query 34 -select ceiling(0) +select floor(0) -- !query 34 schema -struct +struct -- !query 34 output 0 -- !query 35 -select ceiling(1) +select floor(1) -- !query 35 schema -struct +struct -- !query 35 output 1 -- !query 36 -select ceil(1234567890123456) +select floor(1234567890123456) -- !query 36 schema -struct +struct -- !query 36 output 1234567890123456 -- !query 37 -select ceiling(1234567890123456) --- !query 37 schema -struct --- !query 37 output -1234567890123456 - - --- !query 38 -select ceil(0.01) --- !query 38 schema -struct --- !query 38 output -1 - - --- !query 39 -select ceiling(-0.10) --- !query 39 schema -struct --- !query 39 output -0 - - --- !query 40 -select floor(0) --- !query 40 schema -struct --- !query 40 output -0 - - --- !query 41 -select floor(1) --- !query 41 schema -struct --- !query 41 output -1 - - --- !query 42 -select floor(1234567890123456) --- !query 42 schema -struct --- !query 42 output -1234567890123456 - - --- !query 43 select floor(0.01) --- !query 43 schema +-- !query 37 schema struct --- !query 43 output +-- !query 37 output 0 --- !query 44 +-- !query 38 select floor(-0.10) --- !query 44 schema +-- !query 38 schema struct --- !query 44 output +-- !query 38 output -1 --- !query 45 +-- !query 39 select 1 > 0.00001 --- !query 45 schema +-- !query 39 schema struct<(CAST(1 AS BIGINT) > 0):boolean> --- !query 45 output +-- !query 39 output true --- !query 46 +-- !query 40 select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, null) --- !query 46 schema +-- !query 40 schema struct<(7 % 2):int,(7 % 0):int,(0 % 2):int,(7 % CAST(NULL AS INT)):int,(CAST(NULL AS INT) % 2):int,(CAST(NULL AS DOUBLE) % CAST(NULL AS DOUBLE)):double> --- !query 46 output +-- !query 40 output 1 NULL 0 NULL NULL NULL --- !query 47 +-- !query 41 select BIT_LENGTH('abc') --- !query 47 schema +-- !query 41 schema struct --- !query 47 output +-- !query 41 output 24 --- !query 48 +-- !query 42 select CHAR_LENGTH('abc') --- !query 48 schema +-- !query 42 schema struct --- !query 48 output +-- !query 42 output 3 --- !query 49 +-- !query 43 select CHARACTER_LENGTH('abc') --- !query 49 schema +-- !query 43 schema struct --- !query 49 output +-- !query 43 output 3 --- !query 50 +-- !query 44 select OCTET_LENGTH('abc') --- !query 50 schema +-- !query 44 schema struct --- !query 50 output +-- !query 44 output 3 --- !query 51 +-- !query 45 select abs(-3.13), abs('-2.19') --- !query 51 schema +-- !query 45 schema struct --- !query 51 output +-- !query 45 output 3.13 2.19 --- !query 52 +-- !query 46 select positive('-1.11'), positive(-1.11), negative('-1.11'), negative(-1.11) --- !query 52 schema +-- !query 46 schema struct<(+ CAST(-1.11 AS DOUBLE)):double,(+ -1.11):decimal(3,2),(- CAST(-1.11 AS DOUBLE)):double,(- -1.11):decimal(3,2)> --- !query 52 output +-- !query 46 output -1.11 -1.11 1.11 1.11 --- !query 53 +-- !query 47 select pmod(-7, 2), pmod(0, 2), pmod(7, 0), pmod(7, null), pmod(null, 2), pmod(null, null) --- !query 53 schema +-- !query 47 schema struct --- !query 53 output +-- !query 47 output 1 0 NULL NULL NULL NULL --- !query 54 +-- !query 48 select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint), cast(0 as smallint)) --- !query 54 schema +-- !query 48 schema struct --- !query 54 output +-- !query 48 output NULL NULL diff --git a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out index e035505f15d28..69a8e958000db 100644 --- a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 15 +-- Number of queries: 14 -- !query 0 @@ -67,74 +67,49 @@ struct -- !query 8 -explain extended -select ifnull(id, 'x'), nullif(id, 'x'), nvl(id, 'x'), nvl2(id, 'x', 'y') -from range(2) --- !query 8 schema -struct --- !query 8 output -== Parsed Logical Plan == -'Project [unresolvedalias('ifnull('id, x), None), unresolvedalias('nullif('id, x), None), unresolvedalias('nvl('id, x), None), unresolvedalias('nvl2('id, x, y), None)] -+- 'UnresolvedTableValuedFunction range, [2] - -== Analyzed Logical Plan == -ifnull(`id`, 'x'): string, nullif(`id`, 'x'): bigint, nvl(`id`, 'x'): string, nvl2(`id`, 'x', 'y'): string -Project [ifnull(id#xL, x) AS ifnull(`id`, 'x')#x, nullif(id#xL, x) AS nullif(`id`, 'x')#xL, nvl(id#xL, x) AS nvl(`id`, 'x')#x, nvl2(id#xL, x, y) AS nvl2(`id`, 'x', 'y')#x] -+- Range (0, 2, step=1, splits=None) - -== Optimized Logical Plan == -Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, id#xL AS nullif(`id`, 'x')#xL, coalesce(cast(id#xL as string), x) AS nvl(`id`, 'x')#x, x AS nvl2(`id`, 'x', 'y')#x] -+- Range (0, 2, step=1, splits=None) - -== Physical Plan == -*Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, id#xL AS nullif(`id`, 'x')#xL, coalesce(cast(id#xL as string), x) AS nvl(`id`, 'x')#x, x AS nvl2(`id`, 'x', 'y')#x] -+- *Range (0, 2, step=1, splits=2) - - --- !query 9 SELECT boolean(1), tinyint(1), smallint(1), int(1), bigint(1) --- !query 9 schema +-- !query 8 schema struct --- !query 9 output +-- !query 8 output true 1 1 1 1 --- !query 10 +-- !query 9 SELECT float(1), double(1), decimal(1) --- !query 10 schema +-- !query 9 schema struct --- !query 10 output +-- !query 9 output 1.0 1.0 1 --- !query 11 +-- !query 10 SELECT date("2014-04-04"), timestamp(date("2014-04-04")) --- !query 11 schema +-- !query 10 schema struct --- !query 11 output +-- !query 10 output 2014-04-04 2014-04-04 00:00:00 --- !query 12 +-- !query 11 SELECT string(1, 2) --- !query 12 schema +-- !query 11 schema struct<> --- !query 12 output +-- !query 11 output org.apache.spark.sql.AnalysisException Function string accepts only one argument; line 1 pos 7 --- !query 13 +-- !query 12 CREATE TEMPORARY VIEW tempView1 AS VALUES (1, NAMED_STRUCT('col1', 'gamma', 'col2', 'delta')) AS T(id, st) --- !query 13 schema +-- !query 12 schema struct<> --- !query 13 output +-- !query 12 output --- !query 14 +-- !query 13 SELECT nvl(st.col1, "value"), count(*) FROM from tempView1 GROUP BY nvl(st.col1, "value") --- !query 14 schema +-- !query 13 schema struct --- !query 14 output +-- !query 13 output gamma 1 diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index e8f2e0a81455a..25d93b2063146 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 17 +-- Number of queries: 13 -- !query 0 @@ -29,151 +29,80 @@ abc -- !query 3 -EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col -FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10)) --- !query 3 schema -struct --- !query 3 output -== Parsed Logical Plan == -'Project [concat(concat(concat('col1, 'col2), 'col3), 'col4) AS col#x] -+- 'SubqueryAlias `__auto_generated_subquery_name` - +- 'Project ['id AS col1#x, 'id AS col2#x, 'id AS col3#x, 'id AS col4#x] - +- 'UnresolvedTableValuedFunction range, [10] - -== Analyzed Logical Plan == -col: string -Project [concat(concat(concat(cast(col1#xL as string), cast(col2#xL as string)), cast(col3#xL as string)), cast(col4#xL as string)) AS col#x] -+- SubqueryAlias `__auto_generated_subquery_name` - +- Project [id#xL AS col1#xL, id#xL AS col2#xL, id#xL AS col3#xL, id#xL AS col4#xL] - +- Range (0, 10, step=1, splits=None) - -== Optimized Logical Plan == -Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as string), cast(id#xL as string)) AS col#x] -+- Range (0, 10, step=1, splits=None) - -== Physical Plan == -*Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as string), cast(id#xL as string)) AS col#x] -+- *Range (0, 10, step=1, splits=2) - - --- !query 4 select replace('abc', 'b', '123') --- !query 4 schema +-- !query 3 schema struct --- !query 4 output +-- !query 3 output a123c --- !query 5 +-- !query 4 select replace('abc', 'b') --- !query 5 schema +-- !query 4 schema struct --- !query 5 output +-- !query 4 output ac --- !query 6 +-- !query 5 select length(uuid()), (uuid() <> uuid()) --- !query 6 schema +-- !query 5 schema struct --- !query 6 output +-- !query 5 output 36 true --- !query 7 +-- !query 6 select position('bar' in 'foobarbar'), position(null, 'foobarbar'), position('aaads', null) --- !query 7 schema +-- !query 6 schema struct --- !query 7 output +-- !query 6 output 4 NULL NULL --- !query 8 +-- !query 7 select left("abcd", 2), left("abcd", 5), left("abcd", '2'), left("abcd", null) --- !query 8 schema +-- !query 7 schema struct --- !query 8 output +-- !query 7 output ab abcd ab NULL --- !query 9 +-- !query 8 select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a') --- !query 9 schema +-- !query 8 schema struct --- !query 9 output +-- !query 8 output NULL NULL --- !query 10 +-- !query 9 select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null) --- !query 10 schema +-- !query 9 schema struct --- !query 10 output +-- !query 9 output cd abcd cd NULL --- !query 11 +-- !query 10 select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a') --- !query 11 schema +-- !query 10 schema struct --- !query 11 output +-- !query 10 output NULL NULL --- !query 12 -set spark.sql.function.concatBinaryAsString=false --- !query 12 schema -struct --- !query 12 output -spark.sql.function.concatBinaryAsString false - - --- !query 13 -EXPLAIN SELECT ((col1 || col2) || (col3 || col4)) col -FROM ( - SELECT - string(id) col1, - string(id + 1) col2, - encode(string(id + 2), 'utf-8') col3, - encode(string(id + 3), 'utf-8') col4 - FROM range(10) -) --- !query 13 schema -struct --- !query 13 output -== Physical Plan == -*Project [concat(cast(id#xL as string), cast((id#xL + 1) as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x] -+- *Range (0, 10, step=1, splits=2) - - --- !query 14 -EXPLAIN SELECT (col1 || (col3 || col4)) col -FROM ( - SELECT - string(id) col1, - encode(string(id + 2), 'utf-8') col3, - encode(string(id + 3), 'utf-8') col4 - FROM range(10) -) --- !query 14 schema -struct --- !query 14 output -== Physical Plan == -*Project [concat(cast(id#xL as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x] -+- *Range (0, 10, step=1, splits=2) - - --- !query 15 +-- !query 11 SELECT split('aa1cc2ee3', '[1-9]+') --- !query 15 schema +-- !query 11 schema struct> --- !query 15 output +-- !query 11 output ["aa","cc","ee",""] --- !query 16 +-- !query 12 SELECT split('aa1cc2ee3', '[1-9]+', 2) --- !query 16 schema +-- !query 12 schema struct> --- !query 16 output +-- !query 12 output ["aa","cc2ee3"] diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out index 94af9181225d6..fdbea0ee90720 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 8 -- !query 0 @@ -99,42 +99,3 @@ struct -- !query 7 output 0 1 - - --- !query 8 -EXPLAIN select * from RaNgE(2) --- !query 8 schema -struct --- !query 8 output -== Physical Plan == -*Range (0, 2, step=1, splits=2) - - --- !query 9 -EXPLAIN EXTENDED SELECT * FROM range(3) CROSS JOIN range(3) --- !query 9 schema -struct --- !query 9 output -== Parsed Logical Plan == -'Project [*] -+- 'Join Cross - :- 'UnresolvedTableValuedFunction range, [3] - +- 'UnresolvedTableValuedFunction range, [3] - -== Analyzed Logical Plan == -id: bigint, id: bigint -Project [id#xL, id#xL] -+- Join Cross - :- Range (0, 3, step=1, splits=None) - +- Range (0, 3, step=1, splits=None) - -== Optimized Logical Plan == -Join Cross -:- Range (0, 3, step=1, splits=None) -+- Range (0, 3, step=1, splits=None) - -== Physical Plan == -BroadcastNestedLoopJoin BuildRight, Cross -:- *Range (0, 3, step=1, splits=2) -+- BroadcastExchange IdentityBroadcastMode - +- *Range (0, 3, step=1, splits=2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala index 56d300e30a58e..ce475922eb5e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType @@ -29,10 +30,11 @@ class ExplainSuite extends QueryTest with SharedSQLContext { private def checkKeywordsExistsInExplain(df: DataFrame, keywords: String*): Unit = { val output = new java.io.ByteArrayOutputStream() Console.withOut(output) { - df.explain(extended = false) + df.explain(extended = true) } + val normalizedOutput = output.toString.replaceAll("#\\d+", "#x") for (key <- keywords) { - assert(output.toString.contains(key)) + assert(normalizedOutput.contains(key)) } } @@ -53,6 +55,133 @@ class ExplainSuite extends QueryTest with SharedSQLContext { checkKeywordsExistsInExplain(df, keywords = "InMemoryRelation", "StorageLevel(disk, memory, deserialized, 1 replicas)") } + + test("optimized plan should show the rewritten aggregate expression") { + withTempView("test_agg") { + sql( + """ + |CREATE TEMPORARY VIEW test_agg AS SELECT * FROM VALUES + | (1, true), (1, false), + | (2, true), + | (3, false), (3, null), + | (4, null), (4, null), + | (5, null), (5, true), (5, false) AS test_agg(k, v) + """.stripMargin) + + // simple explain of queries having every/some/any aggregates. Optimized + // plan should show the rewritten aggregate expression. + val df = sql("SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k") + checkKeywordsExistsInExplain(df, + "Aggregate [k#x], [k#x, min(v#x) AS every(v)#x, max(v#x) AS some(v)#x, " + + "max(v#x) AS any(v)#x]") + } + } + + test("explain inline tables cross-joins") { + val df = sql( + """ + |SELECT * FROM VALUES ('one', 1), ('three', null) + | CROSS JOIN VALUES ('one', 1), ('three', null) + """.stripMargin) + checkKeywordsExistsInExplain(df, + "Join Cross", + ":- LocalRelation [col1#x, col2#x]", + "+- LocalRelation [col1#x, col2#x]") + } + + test("explain table valued functions") { + checkKeywordsExistsInExplain(sql("select * from RaNgE(2)"), "Range (0, 2, step=1, splits=None)") + checkKeywordsExistsInExplain(sql("SELECT * FROM range(3) CROSS JOIN range(3)"), + "Join Cross", + ":- Range (0, 3, step=1, splits=None)", + "+- Range (0, 3, step=1, splits=None)") + } + + test("explain string functions") { + // Check if catalyst combine nested `Concat`s + val df1 = sql( + """ + |SELECT (col1 || col2 || col3 || col4) col + | FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10)) + """.stripMargin) + checkKeywordsExistsInExplain(df1, + "Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as string)" + + ", cast(id#xL as string)) AS col#x]") + + // Check if catalyst combine nested `Concat`s if concatBinaryAsString=false + withSQLConf(SQLConf.CONCAT_BINARY_AS_STRING.key -> "false") { + val df2 = sql( + """ + |SELECT ((col1 || col2) || (col3 || col4)) col + |FROM ( + | SELECT + | string(id) col1, + | string(id + 1) col2, + | encode(string(id + 2), 'utf-8') col3, + | encode(string(id + 3), 'utf-8') col4 + | FROM range(10) + |) + """.stripMargin) + checkKeywordsExistsInExplain(df2, + "Project [concat(cast(id#xL as string), cast((id#xL + 1) as string), " + + "cast(encode(cast((id#xL + 2) as string), utf-8) as string), " + + "cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x]") + + val df3 = sql( + """ + |SELECT (col1 || (col3 || col4)) col + |FROM ( + | SELECT + | string(id) col1, + | encode(string(id + 2), 'utf-8') col3, + | encode(string(id + 3), 'utf-8') col4 + | FROM range(10) + |) + """.stripMargin) + checkKeywordsExistsInExplain(df3, + "Project [concat(cast(id#xL as string), " + + "cast(encode(cast((id#xL + 2) as string), utf-8) as string), " + + "cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x]") + } + } + + test("check operator precedence") { + // We follow Oracle operator precedence in the table below that lists the levels + // of precedence among SQL operators from high to low: + // --------------------------------------------------------------------------------------- + // Operator Operation + // --------------------------------------------------------------------------------------- + // +, - identity, negation + // *, / multiplication, division + // +, -, || addition, subtraction, concatenation + // =, !=, <, >, <=, >=, IS NULL, LIKE, BETWEEN, IN comparison + // NOT exponentiation, logical negation + // AND conjunction + // OR disjunction + // --------------------------------------------------------------------------------------- + checkKeywordsExistsInExplain(sql("select 'a' || 1 + 2"), + "Project [null AS (CAST(concat(a, CAST(1 AS STRING)) AS DOUBLE) + CAST(2 AS DOUBLE))#x]") + checkKeywordsExistsInExplain(sql("select 1 - 2 || 'b'"), + "Project [-1b AS concat(CAST((1 - 2) AS STRING), b)#x]") + checkKeywordsExistsInExplain(sql("select 2 * 4 + 3 || 'b'"), + "Project [11b AS concat(CAST(((2 * 4) + 3) AS STRING), b)#x]") + checkKeywordsExistsInExplain(sql("select 3 + 1 || 'a' || 4 / 2"), + "Project [4a2.0 AS concat(concat(CAST((3 + 1) AS STRING), a), " + + "CAST((CAST(4 AS DOUBLE) / CAST(2 AS DOUBLE)) AS STRING))#x]") + checkKeywordsExistsInExplain(sql("select 1 == 1 OR 'a' || 'b' == 'ab'"), + "Project [true AS ((1 = 1) OR (concat(a, b) = ab))#x]") + checkKeywordsExistsInExplain(sql("select 'a' || 'c' == 'ac' AND 2 == 3"), + "Project [false AS ((concat(a, c) = ac) AND (2 = 3))#x]") + } + + test("explain for these functions; use range to avoid constant folding") { + val df = sql("select ifnull(id, 'x'), nullif(id, 'x'), nvl(id, 'x'), nvl2(id, 'x', 'y') " + + "from range(2)") + checkKeywordsExistsInExplain(df, + "Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, " + + "id#xL AS nullif(`id`, 'x')#xL, coalesce(cast(id#xL as string), x) AS nvl(`id`, 'x')#x, " + + "x AS nvl2(`id`, 'x', 'y')#x]") + } } case class ExplainSingleData(id: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index cf4585bf7ac6c..b2515226d9a14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -137,28 +137,39 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { } } + // For better test coverage, runs the tests on mixed config sets: WHOLESTAGE_CODEGEN_ENABLED + // and CODEGEN_FACTORY_MODE. + private lazy val codegenConfigSets = Array( + ("true", "CODEGEN_ONLY"), + ("false", "CODEGEN_ONLY"), + ("false", "NO_CODEGEN") + ).map { case (wholeStageCodegenEnabled, codegenFactoryMode) => + Array(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholeStageCodegenEnabled, + SQLConf.CODEGEN_FACTORY_MODE.key -> codegenFactoryMode) + } + /** Run a test case. */ private def runTest(testCase: TestCase): Unit = { val input = fileToString(new File(testCase.inputFile)) val (comments, code) = input.split("\n").partition(_.startsWith("--")) - // Runs all the tests on both codegen-only and interpreter modes - val codegenConfigSets = Array(CODEGEN_ONLY, NO_CODEGEN).map { - case codegenFactoryMode => - Array(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenFactoryMode.toString) - } - val configSets = { - val configLines = comments.filter(_.startsWith("--SET")).map(_.substring(5)) - val configs = configLines.map(_.split(",").map { confAndValue => - val (conf, value) = confAndValue.span(_ != '=') - conf.trim -> value.substring(1).trim - }) - // When we are regenerating the golden files, we don't need to set any config as they - // all need to return the same result - if (regenerateGoldenFiles) { - Array.empty[Array[(String, String)]] - } else { + // List of SQL queries to run + // note: this is not a robust way to split queries using semicolon, but works for now. + val queries = code.mkString("\n").split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq + + // When we are regenerating the golden files, we don't need to set any config as they + // all need to return the same result + if (regenerateGoldenFiles) { + runQueries(queries, testCase.resultFile, None) + } else { + val configSets = { + val configLines = comments.filter(_.startsWith("--SET")).map(_.substring(5)) + val configs = configLines.map(_.split(",").map { confAndValue => + val (conf, value) = confAndValue.span(_ != '=') + conf.trim -> value.substring(1).trim + }) + if (configs.nonEmpty) { codegenConfigSets.flatMap { codegenConfig => configs.map { config => @@ -169,15 +180,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { codegenConfigSets } } - } - // List of SQL queries to run - // note: this is not a robust way to split queries using semicolon, but works for now. - val queries = code.mkString("\n").split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq - - if (configSets.isEmpty) { - runQueries(queries, testCase.resultFile, None) - } else { configSets.foreach { configSet => try { runQueries(queries, testCase.resultFile, Some(configSet)) From 5ad03607d1487e7ab3e3b6d00eef9c4028ed4975 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 20 Dec 2018 10:47:24 +0800 Subject: [PATCH 2314/2461] [SPARK-25271][SQL] Hive ctas commands should use data source if it is convertible ## What changes were proposed in this pull request? In Spark 2.3.0 and previous versions, Hive CTAS command will convert to use data source to write data into the table when the table is convertible. This behavior is controlled by the configs like HiveUtils.CONVERT_METASTORE_ORC and HiveUtils.CONVERT_METASTORE_PARQUET. In 2.3.1, we drop this optimization by mistake in the PR [SPARK-22977](https://github.com/apache/spark/pull/20521/files#r217254430). Since that Hive CTAS command only uses Hive Serde to write data. This patch adds this optimization back to Hive CTAS command. This patch adds OptimizedCreateHiveTableAsSelectCommand which uses data source to write data. ## How was this patch tested? Added test. Closes #22514 from viirya/SPARK-25271-2. Authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- .../spark/sql/execution/command/ddl.scala | 8 ++ .../datasources/DataSourceStrategy.scala | 12 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 43 +++++- .../spark/sql/hive/HiveStrategies.scala | 62 +++----- .../org/apache/spark/sql/hive/HiveUtils.scala | 8 ++ .../CreateHiveTableAsSelectCommand.scala | 134 +++++++++++++----- .../spark/sql/hive/HiveParquetSuite.scala | 14 ++ .../sql/hive/execution/SQLQuerySuite.scala | 40 ++++++ 8 files changed, 230 insertions(+), 91 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index e1faecedd20ed..096481f68275d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -820,6 +820,14 @@ object DDLUtils { table.provider.isDefined && table.provider.get.toLowerCase(Locale.ROOT) != HIVE_PROVIDER } + def readHiveTable(table: CatalogTable): HiveTableRelation = { + HiveTableRelation( + table, + // Hive table columns are always nullable. + table.dataSchema.asNullable.toAttributes, + table.partitionSchema.asNullable.toAttributes) + } + /** * Throws a standard error for actions that require partitionProvider = hive. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index b304e2da6e1cf..b5cf8c9515bfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -244,27 +244,19 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] }) } - private def readHiveTable(table: CatalogTable): LogicalPlan = { - HiveTableRelation( - table, - // Hive table columns are always nullable. - table.dataSchema.asNullable.toAttributes, - table.partitionSchema.asNullable.toAttributes) - } - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case i @ InsertIntoTable(UnresolvedCatalogRelation(tableMeta), _, _, _, _) if DDLUtils.isDatasourceTable(tableMeta) => i.copy(table = readDataSourceTable(tableMeta)) case i @ InsertIntoTable(UnresolvedCatalogRelation(tableMeta), _, _, _, _) => - i.copy(table = readHiveTable(tableMeta)) + i.copy(table = DDLUtils.readHiveTable(tableMeta)) case UnresolvedCatalogRelation(tableMeta) if DDLUtils.isDatasourceTable(tableMeta) => readDataSourceTable(tableMeta) case UnresolvedCatalogRelation(tableMeta) => - readHiveTable(tableMeta) + DDLUtils.readHiveTable(tableMeta) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 5823548a8063c..03f4b8d83e353 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import java.util.Locale + import scala.util.control.NonFatal import com.google.common.util.concurrent.Striped @@ -29,6 +31,8 @@ import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.HiveCaseSensitiveInferenceMode._ import org.apache.spark.sql.types._ @@ -113,7 +117,44 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } } - def convertToLogicalRelation( + // Return true for Apache ORC and Hive ORC-related configuration names. + // Note that Spark doesn't support configurations like `hive.merge.orcfile.stripe.level`. + private def isOrcProperty(key: String) = + key.startsWith("orc.") || key.contains(".orc.") + + private def isParquetProperty(key: String) = + key.startsWith("parquet.") || key.contains(".parquet.") + + def convert(relation: HiveTableRelation): LogicalRelation = { + val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + + // Consider table and storage properties. For properties existing in both sides, storage + // properties will supersede table properties. + if (serde.contains("parquet")) { + val options = relation.tableMeta.properties.filterKeys(isParquetProperty) ++ + relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> + SQLConf.get.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) + convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet") + } else { + val options = relation.tableMeta.properties.filterKeys(isOrcProperty) ++ + relation.tableMeta.storage.properties + if (SQLConf.get.getConf(SQLConf.ORC_IMPLEMENTATION) == "native") { + convertToLogicalRelation( + relation, + options, + classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat], + "orc") + } else { + convertToLogicalRelation( + relation, + options, + classOf[org.apache.spark.sql.hive.orc.OrcFileFormat], + "orc") + } + } + } + + private def convertToLogicalRelation( relation: HiveTableRelation, options: Map[String, String], fileFormatClass: Class[_ <: FileFormat], diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 07ee105404311..8a5ab188a949f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -31,8 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoTab import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils} -import org.apache.spark.sql.execution.datasources.{CreateTable, LogicalRelation} -import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions} +import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} @@ -181,49 +180,17 @@ case class RelationConversions( conf: SQLConf, sessionCatalog: HiveSessionCatalog) extends Rule[LogicalPlan] { private def isConvertible(relation: HiveTableRelation): Boolean = { - val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) - serde.contains("parquet") && conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET) || - serde.contains("orc") && conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) + isConvertible(relation.tableMeta) } - // Return true for Apache ORC and Hive ORC-related configuration names. - // Note that Spark doesn't support configurations like `hive.merge.orcfile.stripe.level`. - private def isOrcProperty(key: String) = - key.startsWith("orc.") || key.contains(".orc.") - - private def isParquetProperty(key: String) = - key.startsWith("parquet.") || key.contains(".parquet.") - - private def convert(relation: HiveTableRelation): LogicalRelation = { - val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) - - // Consider table and storage properties. For properties existing in both sides, storage - // properties will supersede table properties. - if (serde.contains("parquet")) { - val options = relation.tableMeta.properties.filterKeys(isParquetProperty) ++ - relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> - conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) - sessionCatalog.metastoreCatalog - .convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet") - } else { - val options = relation.tableMeta.properties.filterKeys(isOrcProperty) ++ - relation.tableMeta.storage.properties - if (conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "native") { - sessionCatalog.metastoreCatalog.convertToLogicalRelation( - relation, - options, - classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat], - "orc") - } else { - sessionCatalog.metastoreCatalog.convertToLogicalRelation( - relation, - options, - classOf[org.apache.spark.sql.hive.orc.OrcFileFormat], - "orc") - } - } + private def isConvertible(tableMeta: CatalogTable): Boolean = { + val serde = tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + serde.contains("parquet") && SQLConf.get.getConf(HiveUtils.CONVERT_METASTORE_PARQUET) || + serde.contains("orc") && SQLConf.get.getConf(HiveUtils.CONVERT_METASTORE_ORC) } + private val metastoreCatalog = sessionCatalog.metastoreCatalog + override def apply(plan: LogicalPlan): LogicalPlan = { plan resolveOperators { // Write path @@ -231,12 +198,21 @@ case class RelationConversions( // Inserting into partitioned table is not supported in Parquet/Orc data source (yet). if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && !r.isPartitioned && isConvertible(r) => - InsertIntoTable(convert(r), partition, query, overwrite, ifPartitionNotExists) + InsertIntoTable(metastoreCatalog.convert(r), partition, + query, overwrite, ifPartitionNotExists) // Read path case relation: HiveTableRelation if DDLUtils.isHiveTable(relation.tableMeta) && isConvertible(relation) => - convert(relation) + metastoreCatalog.convert(relation) + + // CTAS + case CreateTable(tableDesc, mode, Some(query)) + if DDLUtils.isHiveTable(tableDesc) && tableDesc.partitionColumnNames.isEmpty && + isConvertible(tableDesc) && SQLConf.get.getConf(HiveUtils.CONVERT_METASTORE_CTAS) => + DDLUtils.checkDataColNames(tableDesc) + OptimizedCreateHiveTableAsSelectCommand( + tableDesc, query, query.output.map(_.name), mode) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 66067704195dd..b60d4c71f5941 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -110,6 +110,14 @@ private[spark] object HiveUtils extends Logging { .booleanConf .createWithDefault(true) + val CONVERT_METASTORE_CTAS = buildConf("spark.sql.hive.convertMetastoreCtas") + .doc("When set to true, Spark will try to use built-in data source writer " + + "instead of Hive serde in CTAS. This flag is effective only if " + + "`spark.sql.hive.convertMetastoreParquet` or `spark.sql.hive.convertMetastoreOrc` is " + + "enabled respectively for Parquet and ORC formats") + .booleanConf + .createWithDefault(true) + val HIVE_METASTORE_SHARED_PREFIXES = buildConf("spark.sql.hive.metastore.sharedPrefixes") .doc("A comma separated list of class prefixes that should be loaded using the classloader " + "that is shared between Spark SQL and a specific version of Hive. An example of classes " + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index fd1e931ee0c7a..608f21e726259 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -20,32 +20,26 @@ package org.apache.spark.sql.hive.execution import scala.util.control.NonFatal import org.apache.spark.sql.{AnalysisException, Row, SaveMode, SparkSession} -import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SessionCatalog} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.command.DataWritingCommand +import org.apache.spark.sql.execution.command.{DataWritingCommand, DDLUtils} +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InsertIntoHadoopFsRelationCommand, LogicalRelation} +import org.apache.spark.sql.hive.HiveSessionCatalog +trait CreateHiveTableAsSelectBase extends DataWritingCommand { + val tableDesc: CatalogTable + val query: LogicalPlan + val outputColumnNames: Seq[String] + val mode: SaveMode -/** - * Create table and insert the query result into it. - * - * @param tableDesc the Table Describe, which may contain serde, storage handler etc. - * @param query the query whose result will be insert into the new relation - * @param mode SaveMode - */ -case class CreateHiveTableAsSelectCommand( - tableDesc: CatalogTable, - query: LogicalPlan, - outputColumnNames: Seq[String], - mode: SaveMode) - extends DataWritingCommand { - - private val tableIdentifier = tableDesc.identifier + protected val tableIdentifier = tableDesc.identifier override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - if (catalog.tableExists(tableIdentifier)) { + val tableExists = catalog.tableExists(tableIdentifier) + + if (tableExists) { assert(mode != SaveMode.Overwrite, s"Expect the table $tableIdentifier has been dropped when the save mode is Overwrite") @@ -57,15 +51,8 @@ case class CreateHiveTableAsSelectCommand( return Seq.empty } - // For CTAS, there is no static partition values to insert. - val partition = tableDesc.partitionColumnNames.map(_ -> None).toMap - InsertIntoHiveTable( - tableDesc, - partition, - query, - overwrite = false, - ifPartitionNotExists = false, - outputColumnNames = outputColumnNames).run(sparkSession, child) + val command = getWritingCommand(catalog, tableDesc, tableExists = true) + command.run(sparkSession, child) } else { // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in case of failure occurs while data @@ -77,15 +64,8 @@ case class CreateHiveTableAsSelectCommand( try { // Read back the metadata of the table which was created just now. val createdTableMeta = catalog.getTableMetadata(tableDesc.identifier) - // For CTAS, there is no static partition values to insert. - val partition = createdTableMeta.partitionColumnNames.map(_ -> None).toMap - InsertIntoHiveTable( - createdTableMeta, - partition, - query, - overwrite = true, - ifPartitionNotExists = false, - outputColumnNames = outputColumnNames).run(sparkSession, child) + val command = getWritingCommand(catalog, createdTableMeta, tableExists = false) + command.run(sparkSession, child) } catch { case NonFatal(e) => // drop the created table. @@ -97,9 +77,89 @@ case class CreateHiveTableAsSelectCommand( Seq.empty[Row] } + // Returns `DataWritingCommand` which actually writes data into the table. + def getWritingCommand( + catalog: SessionCatalog, + tableDesc: CatalogTable, + tableExists: Boolean): DataWritingCommand + override def argString: String = { s"[Database:${tableDesc.database}, " + s"TableName: ${tableDesc.identifier.table}, " + s"InsertIntoHiveTable]" } } + +/** + * Create table and insert the query result into it. + * + * @param tableDesc the table description, which may contain serde, storage handler etc. + * @param query the query whose result will be insert into the new relation + * @param mode SaveMode + */ +case class CreateHiveTableAsSelectCommand( + tableDesc: CatalogTable, + query: LogicalPlan, + outputColumnNames: Seq[String], + mode: SaveMode) + extends CreateHiveTableAsSelectBase { + + override def getWritingCommand( + catalog: SessionCatalog, + tableDesc: CatalogTable, + tableExists: Boolean): DataWritingCommand = { + // For CTAS, there is no static partition values to insert. + val partition = tableDesc.partitionColumnNames.map(_ -> None).toMap + InsertIntoHiveTable( + tableDesc, + partition, + query, + overwrite = if (tableExists) false else true, + ifPartitionNotExists = false, + outputColumnNames = outputColumnNames) + } +} + +/** + * Create table and insert the query result into it. This creates Hive table but inserts + * the query result into it by using data source. + * + * @param tableDesc the table description, which may contain serde, storage handler etc. + * @param query the query whose result will be insert into the new relation + * @param mode SaveMode + */ +case class OptimizedCreateHiveTableAsSelectCommand( + tableDesc: CatalogTable, + query: LogicalPlan, + outputColumnNames: Seq[String], + mode: SaveMode) + extends CreateHiveTableAsSelectBase { + + override def getWritingCommand( + catalog: SessionCatalog, + tableDesc: CatalogTable, + tableExists: Boolean): DataWritingCommand = { + val metastoreCatalog = catalog.asInstanceOf[HiveSessionCatalog].metastoreCatalog + val hiveTable = DDLUtils.readHiveTable(tableDesc) + + val hadoopRelation = metastoreCatalog.convert(hiveTable) match { + case LogicalRelation(t: HadoopFsRelation, _, _, _) => t + case _ => throw new AnalysisException(s"$tableIdentifier should be converted to " + + "HadoopFsRelation.") + } + + InsertIntoHadoopFsRelationCommand( + hadoopRelation.location.rootPaths.head, + Map.empty, // We don't support to convert partitioned table. + false, + Seq.empty, // We don't support to convert partitioned table. + hadoopRelation.bucketSpec, + hadoopRelation.fileFormat, + hadoopRelation.options, + query, + if (tableExists) mode else SaveMode.Overwrite, + Some(tableDesc), + Some(hadoopRelation.location), + query.output.map(_.name)) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index e5c9df05d5674..470c6a342b4dd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -92,4 +92,18 @@ class HiveParquetSuite extends QueryTest with ParquetTest with TestHiveSingleton } } } + + test("SPARK-25271: write empty map into hive parquet table") { + import testImplicits._ + + Seq(Map(1 -> "a"), Map.empty[Int, String]).toDF("m").createOrReplaceTempView("p") + withTempView("p") { + val targetTable = "targetTable" + withTable(targetTable) { + sql(s"CREATE TABLE $targetTable STORED AS PARQUET AS SELECT m FROM p") + checkAnswer(sql(s"SELECT m FROM $targetTable"), + Row(Map(1 -> "a")) :: Row(Map.empty[Int, String]) :: Nil) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index fab2a27cdef17..6acf44606cbbe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2276,6 +2276,46 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("SPARK-25271: Hive ctas commands should use data source if it is convertible") { + withTempView("p") { + Seq(1, 2, 3).toDF("id").createOrReplaceTempView("p") + + Seq("orc", "parquet").foreach { format => + Seq(true, false).foreach { isConverted => + withSQLConf( + HiveUtils.CONVERT_METASTORE_ORC.key -> s"$isConverted", + HiveUtils.CONVERT_METASTORE_PARQUET.key -> s"$isConverted") { + Seq(true, false).foreach { isConvertedCtas => + withSQLConf(HiveUtils.CONVERT_METASTORE_CTAS.key -> s"$isConvertedCtas") { + + val targetTable = "targetTable" + withTable(targetTable) { + val df = sql(s"CREATE TABLE $targetTable STORED AS $format AS SELECT id FROM p") + checkAnswer(sql(s"SELECT id FROM $targetTable"), + Row(1) :: Row(2) :: Row(3) :: Nil) + + val ctasDSCommand = df.queryExecution.analyzed.collect { + case _: OptimizedCreateHiveTableAsSelectCommand => true + }.headOption + val ctasCommand = df.queryExecution.analyzed.collect { + case _: CreateHiveTableAsSelectCommand => true + }.headOption + + if (isConverted && isConvertedCtas) { + assert(ctasDSCommand.nonEmpty) + assert(ctasCommand.isEmpty) + } else { + assert(ctasDSCommand.isEmpty) + assert(ctasCommand.nonEmpty) + } + } + } + } + } + } + } + } + } test("SPARK-26181 hasMinMaxStats method of ColumnStatsMap is not correct") { withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { From 04d8e3a33c6bb08b2891ca52613cd5ccd24a69dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E4=BA=AE?= Date: Thu, 20 Dec 2018 13:22:12 +0800 Subject: [PATCH 2315/2461] [SPARK-26318][SQL] Deprecate Row.merge MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Deprecate Row.merge ## How was this patch tested? N/A Closes #23271 from KyleLi1985/master. Authored-by: 李亮 Signed-off-by: Hyukjin Kwon --- sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index e12bf9616e2de..4f5af9ac80b10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -57,6 +57,7 @@ object Row { /** * Merge multiple rows into a single row, one after another. */ + @deprecated("This method is deprecated and will be removed in future versions.", "3.0.0") def merge(rows: Row*): Row = { // TODO: Improve the performance of this if used in performance critical part. new GenericRow(rows.flatMap(_.toSeq).toArray) From 98c0ca78610ccf62784081353584717c62285485 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 20 Dec 2018 14:17:44 +0800 Subject: [PATCH 2316/2461] [SPARK-26308][SQL] Avoid cast of decimals for ScalaUDF ## What changes were proposed in this pull request? Currently, when we infer the schema for scala/java decimals, we return as data type the `SYSTEM_DEFAULT` implementation, ie. the decimal type with precision 38 and scale 18. But this is not right, as we know nothing about the right precision and scale and these values can be not enough to store the data. This problem arises in particular with UDF, where we cast all the input of type `DecimalType` to a `DecimalType(38, 18)`: in case this is not enough, null is returned as input for the UDF. The PR defines a custom handling for casting to the expected data types for ScalaUDF: the decimal precision and scale is picked from the input, so no casting to different and maybe wrong percision and scale happens. ## How was this patch tested? added UTs Closes #23308 from mgaido91/SPARK-26308. Authored-by: Marco Gaido Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/TypeCoercion.scala | 31 ++++++++++++++++++ .../sql/catalyst/expressions/ScalaUDF.scala | 2 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 32 ++++++++++++++++++- 3 files changed, 63 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 133fa119b7aa6..1706b3eece6d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -879,6 +879,37 @@ object TypeCoercion { } } e.withNewChildren(children) + + case udf: ScalaUDF if udf.inputTypes.nonEmpty => + val children = udf.children.zip(udf.inputTypes).map { case (in, expected) => + implicitCast(in, udfInputToCastType(in.dataType, expected)).getOrElse(in) + } + udf.withNewChildren(children) + } + + private def udfInputToCastType(input: DataType, expectedType: DataType): DataType = { + (input, expectedType) match { + // SPARK-26308: avoid casting to an arbitrary precision and scale for decimals. Please note + // that precision and scale cannot be inferred properly for a ScalaUDF because, when it is + // created, it is not bound to any column. So here the precision and scale of the input + // column is used. + case (in: DecimalType, _: DecimalType) => in + case (ArrayType(dtIn, _), ArrayType(dtExp, nullableExp)) => + ArrayType(udfInputToCastType(dtIn, dtExp), nullableExp) + case (MapType(keyDtIn, valueDtIn, _), MapType(keyDtExp, valueDtExp, nullableExp)) => + MapType(udfInputToCastType(keyDtIn, keyDtExp), + udfInputToCastType(valueDtIn, valueDtExp), + nullableExp) + case (StructType(fieldsIn), StructType(fieldsExp)) => + val fieldTypes = + fieldsIn.map(_.dataType).zip(fieldsExp.map(_.dataType)).map { case (dtIn, dtExp) => + udfInputToCastType(dtIn, dtExp) + } + StructType(fieldsExp.zip(fieldTypes).map { case (field, newDt) => + field.copy(dataType = newDt) + }) + case (_, other) => other + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index fae90caebf96c..a23aaa3a0b3ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -52,7 +52,7 @@ case class ScalaUDF( udfName: Option[String] = None, nullable: Boolean = true, udfDeterministic: Boolean = true) - extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression { + extends Expression with NonSQLExpression with UserDefinedExpression { // The constructor for SPARK 2.1 and 2.2 def this( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 20dcefa7e3cad..a26d306cff6b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.math.BigDecimal + import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.QueryExecution @@ -26,7 +28,7 @@ import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationComm import org.apache.spark.sql.functions.{lit, udf} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ -import org.apache.spark.sql.types.{DataTypes, DoubleType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.util.QueryExecutionListener @@ -420,4 +422,32 @@ class UDFSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Seq(Row("null1x"), Row(null), Row("N3null"))) } } + + test("SPARK-26308: udf with decimal") { + val df1 = spark.createDataFrame( + sparkContext.parallelize(Seq(Row(new BigDecimal("2011000000000002456556")))), + StructType(Seq(StructField("col1", DecimalType(30, 0))))) + val udf1 = org.apache.spark.sql.functions.udf((value: BigDecimal) => { + if (value == null) null else value.toBigInteger.toString + }) + checkAnswer(df1.select(udf1(df1.col("col1"))), Seq(Row("2011000000000002456556"))) + } + + test("SPARK-26308: udf with complex types of decimal") { + val df1 = spark.createDataFrame( + sparkContext.parallelize(Seq(Row(Array(new BigDecimal("2011000000000002456556"))))), + StructType(Seq(StructField("col1", ArrayType(DecimalType(30, 0)))))) + val udf1 = org.apache.spark.sql.functions.udf((arr: Seq[BigDecimal]) => { + arr.map(value => if (value == null) null else value.toBigInteger.toString) + }) + checkAnswer(df1.select(udf1($"col1")), Seq(Row(Array("2011000000000002456556")))) + + val df2 = spark.createDataFrame( + sparkContext.parallelize(Seq(Row(Map("a" -> new BigDecimal("2011000000000002456556"))))), + StructType(Seq(StructField("col1", MapType(StringType, DecimalType(30, 0)))))) + val udf2 = org.apache.spark.sql.functions.udf((map: Map[String, BigDecimal]) => { + map.mapValues(value => if (value == null) null else value.toBigInteger.toString) + }) + checkAnswer(df2.select(udf2($"col1")), Seq(Row(Map("a" -> "2011000000000002456556")))) + } } From 7c8f4756c34a0b00931c2987c827a18d989e6c08 Mon Sep 17 00:00:00 2001 From: zhoukang Date: Thu, 20 Dec 2018 08:26:25 -0600 Subject: [PATCH 2317/2461] [SPARK-24687][CORE] Avoid job hanging when generate task binary causes fatal error ## What changes were proposed in this pull request? When NoClassDefFoundError thrown,it will cause job hang. `Exception in thread "dag-scheduler-event-loop" java.lang.NoClassDefFoundError: Lcom/xxx/data/recommend/aggregator/queue/QueueName; at java.lang.Class.getDeclaredFields0(Native Method) at java.lang.Class.privateGetDeclaredFields(Class.java:2436) at java.lang.Class.getDeclaredField(Class.java:1946) at java.io.ObjectStreamClass.getDeclaredSUID(ObjectStreamClass.java:1659) at java.io.ObjectStreamClass.access$700(ObjectStreamClass.java:72) at java.io.ObjectStreamClass$2.run(ObjectStreamClass.java:480) at java.io.ObjectStreamClass$2.run(ObjectStreamClass.java:468) at java.security.AccessController.doPrivileged(Native Method) at java.io.ObjectStreamClass.(ObjectStreamClass.java:468) at java.io.ObjectStreamClass.lookup(ObjectStreamClass.java:365) at java.io.ObjectOutputStream.writeClass(ObjectOutputStream.java:1212) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1119) at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1547) at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1508) at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1431) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177) at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1547) at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1508) at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1431) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177) at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1547) at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1508) at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1431) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177) at java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1377) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1173) at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1547) at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1508) at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1431) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177) at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1547) at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1508) at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1431) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177) at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1547) at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1508) at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1431) at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177) at java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1377)` It is caused by NoClassDefFoundError will not catch up during task seriazation. `var taskBinary: Broadcast[Array[Byte]] = null try { // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep). // For ResultTask, serialize and broadcast (rdd, func). val taskBinaryBytes: Array[Byte] = stage match { case stage: ShuffleMapStage => JavaUtils.bufferToArray( closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef)) case stage: ResultStage => JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef)) } taskBinary = sc.broadcast(taskBinaryBytes) } catch { // In the case of a failure during serialization, abort the stage. case e: NotSerializableException => abortStage(stage, "Task not serializable: " + e.toString, Some(e)) runningStages -= stage // Abort execution return case NonFatal(e) => abortStage(stage, s"Task serialization failed: $e\n${Utils.exceptionString(e)}", Some(e)) runningStages -= stage return }` image below shows that stage 33 blocked and never be scheduled. 2018-06-28 4 28 42 2018-06-28 4 28 49 ## How was this patch tested? UT Closes #21664 from caneGuy/zhoukang/fix-noclassdeferror. Authored-by: zhoukang Signed-off-by: Sean Owen --- .../main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 06966e77db81e..6f4c326442e1e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1170,9 +1170,11 @@ private[spark] class DAGScheduler( // Abort execution return - case NonFatal(e) => + case e: Throwable => abortStage(stage, s"Task serialization failed: $e\n${Utils.exceptionString(e)}", Some(e)) runningStages -= stage + + // Abort execution return } From a888d202ab719ccb13b328aa445341d1d6b06881 Mon Sep 17 00:00:00 2001 From: Jorge Machado Date: Thu, 20 Dec 2018 08:29:51 -0600 Subject: [PATCH 2318/2461] [SPARK-26324][DOCS] Add Spark docs for Running in Mesos with SSL ## What changes were proposed in this pull request? Added docs for running spark jobs with Mesos on SSL Closes #23342 from jomach/master. Lead-authored-by: Jorge Machado Co-authored-by: Jorge Machado Co-authored-by: Jorge Machado Co-authored-by: Jorge Machado Signed-off-by: Sean Owen --- docs/running-on-mesos.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 968d668e2c93a..a07773c1c71e1 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -108,6 +108,19 @@ Please note that if you specify multiple ways to obtain the credentials then the An equivalent order applies for the secret. Essentially we prefer the configuration to be specified directly rather than indirectly by files, and we prefer that configuration settings are used over environment variables. +### Deploy to a Mesos running on Secure Sockets + +If you want to deploy a Spark Application into a Mesos cluster that is running in a secure mode there are some environment variables that need to be set. + +- `LIBPROCESS_SSL_ENABLED=true` enables SSL communication +- `LIBPROCESS_SSL_VERIFY_CERT=false` verifies the ssl certificate +- `LIBPROCESS_SSL_KEY_FILE=pathToKeyFile.key` path to key +- `LIBPROCESS_SSL_CERT_FILE=pathToCRTFile.crt` the certificate file to be used + +All options can be found at http://mesos.apache.org/documentation/latest/ssl/ + +Then submit happens as described in Client mode or Cluster mode below + ## Uploading Spark Package When Mesos runs a task on a Mesos slave for the first time, that slave must have a Spark binary From 6692bacf3e74e7a17d8e676e8a06ab198f85d328 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 20 Dec 2018 10:05:56 -0800 Subject: [PATCH 2319/2461] [SPARK-26409][SQL][TESTS] SQLConf should be serializable in test sessions ## What changes were proposed in this pull request? `SQLConf` is supposed to be serializable. However, currently it is not serializable in `WithTestConf`. `WithTestConf` uses the method `overrideConfs` in closure, while the classes which implements it (`TestHiveSessionStateBuilder` and `TestSQLSessionStateBuilder`) are not serializable. This PR is to use a local variable to fix it. ## How was this patch tested? Add unit test. Closes #23352 from gengliangwang/serializableSQLConf. Authored-by: Gengliang Wang Signed-off-by: gatorsmile --- .../apache/spark/sql/internal/BaseSessionStateBuilder.scala | 3 ++- .../test/scala/org/apache/spark/sql/SerializationSuite.scala | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index ac07e1f6bb4f8..319c2649592fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -309,13 +309,14 @@ private[sql] trait WithTestConf { self: BaseSessionStateBuilder => def overrideConfs: Map[String, String] override protected lazy val conf: SQLConf = { + val overrideConfigurations = overrideConfs val conf = parentState.map(_.conf.clone()).getOrElse { new SQLConf { clear() override def clear(): Unit = { super.clear() // Make sure we start with the default test configs even after clear - overrideConfs.foreach { case (key, value) => setConfString(key, value) } + overrideConfigurations.foreach { case (key, value) => setConfString(key, value) } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index cd6b2647e0be6..1a1c956aed3d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -27,4 +27,9 @@ class SerializationSuite extends SparkFunSuite with SharedSQLContext { val spark = SparkSession.builder.getOrCreate() new JavaSerializer(new SparkConf()).newInstance().serialize(spark.sqlContext) } + + test("[SPARK-26409] SQLConf should be serializable") { + val spark = SparkSession.builder.getOrCreate() + new JavaSerializer(new SparkConf()).newInstance().serialize(spark.sessionState.conf) + } } From 3d6b44d9ea92dc1eabb8f211176861e51240bf93 Mon Sep 17 00:00:00 2001 From: Ngone51 Date: Thu, 20 Dec 2018 10:25:52 -0800 Subject: [PATCH 2320/2461] [SPARK-26392][YARN] Cancel pending allocate requests by taking locality preference into account ## What changes were proposed in this pull request? Right now, we cancel pending allocate requests by its sending order. I thing we can take locality preference into account when do this to perfom least impact on task locality preference. ## How was this patch tested? N.A. Closes #23344 from Ngone51/dev-cancel-pending-allocate-requests-by-taking-locality-preference-into-account. Authored-by: Ngone51 Signed-off-by: Marcelo Vanzin --- .../spark/deploy/yarn/YarnAllocator.scala | 29 ++++++++----------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index d37d0d66d8ae2..54b1ec266113f 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -294,6 +294,15 @@ private[yarn] class YarnAllocator( s"pending: $numPendingAllocate, running: ${runningExecutors.size}, " + s"executorsStarting: ${numExecutorsStarting.get}") + // Split the pending container request into three groups: locality matched list, locality + // unmatched list and non-locality list. Take the locality matched container request into + // consideration of container placement, treat as allocated containers. + // For locality unmatched and locality free container requests, cancel these container + // requests, since required locality preference has been changed, recalculating using + // container placement strategy. + val (localRequests, staleRequests, anyHostRequests) = splitPendingAllocationsByLocality( + hostToLocalTaskCounts, pendingAllocate) + if (missing > 0) { if (log.isInfoEnabled()) { var requestContainerMessage = s"Will request $missing executor container(s), each with " + @@ -306,15 +315,6 @@ private[yarn] class YarnAllocator( logInfo(requestContainerMessage) } - // Split the pending container request into three groups: locality matched list, locality - // unmatched list and non-locality list. Take the locality matched container request into - // consideration of container placement, treat as allocated containers. - // For locality unmatched and locality free container requests, cancel these container - // requests, since required locality preference has been changed, recalculating using - // container placement strategy. - val (localRequests, staleRequests, anyHostRequests) = splitPendingAllocationsByLocality( - hostToLocalTaskCounts, pendingAllocate) - // cancel "stale" requests for locations that are no longer needed staleRequests.foreach { stale => amClient.removeContainerRequest(stale) @@ -374,14 +374,9 @@ private[yarn] class YarnAllocator( val numToCancel = math.min(numPendingAllocate, -missing) logInfo(s"Canceling requests for $numToCancel executor container(s) to have a new desired " + s"total $targetNumExecutors executors.") - - val matchingRequests = amClient.getMatchingRequests(RM_REQUEST_PRIORITY, ANY_HOST, resource) - if (!matchingRequests.isEmpty) { - matchingRequests.iterator().next().asScala - .take(numToCancel).foreach(amClient.removeContainerRequest) - } else { - logWarning("Expected to find pending requests, but found none.") - } + // cancel pending allocate requests by taking locality preference into account + val cancelRequests = (staleRequests ++ anyHostRequests ++ localRequests).take(numToCancel) + cancelRequests.foreach(amClient.removeContainerRequest) } } From aa0d4ca8bab08a467645080a5b8a28bf6dd8a042 Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Thu, 20 Dec 2018 11:22:49 -0800 Subject: [PATCH 2321/2461] [SPARK-25970][ML] Add Instrumentation to PrefixSpan ## What changes were proposed in this pull request? Add Instrumentation to PrefixSpan ## How was this patch tested? existing tests Closes #22971 from zhengruifeng/log_PrefixSpan. Authored-by: zhengruifeng Signed-off-by: Xiangrui Meng --- .../src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala index 2a3413553a6af..b0006a8d4a58e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.fpm import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param._ import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan} import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.col @@ -135,7 +136,10 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params * - `freq: Long` */ @Since("2.4.0") - def findFrequentSequentialPatterns(dataset: Dataset[_]): DataFrame = { + def findFrequentSequentialPatterns(dataset: Dataset[_]): DataFrame = instrumented { instr => + instr.logDataset(dataset) + instr.logParams(this, params: _*) + val sequenceColParam = $(sequenceCol) val inputType = dataset.schema(sequenceColParam).dataType require(inputType.isInstanceOf[ArrayType] && From 98ecda3e8ef9db5e21b5b9605df09d1653094b9c Mon Sep 17 00:00:00 2001 From: liuxian Date: Fri, 21 Dec 2018 13:01:14 +0800 Subject: [PATCH 2322/2461] [MINOR][SQL] Locality does not need to be implemented ## What changes were proposed in this pull request? `HadoopFileWholeTextReader` and `HadoopFileLinesReader` will be eventually called in `FileSourceScanExec`. In fact, locality has been implemented in `FileScanRDD`, even if we implement it in `HadoopFileWholeTextReader ` and `HadoopFileLinesReader`, it would be useless. So I think these `TODO` can be removed. ## How was this patch tested? N/A Closes #23339 from 10110346/noneededtodo. Authored-by: liuxian Signed-off-by: Wenchen Fan --- .../spark/sql/execution/datasources/HadoopFileLinesReader.scala | 2 +- .../sql/execution/datasources/HadoopFileWholeTextReader.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala index 00a78f7343c59..57082b40e1132 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala @@ -51,7 +51,7 @@ class HadoopFileLinesReader( new Path(new URI(file.filePath)), file.start, file.length, - // TODO: Implement Locality + // The locality is decided by `getPreferredLocations` in `FileScanRDD`. Array.empty) val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileWholeTextReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileWholeTextReader.scala index c61a89e6e8c3f..f5724f7c5955d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileWholeTextReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileWholeTextReader.scala @@ -40,7 +40,7 @@ class HadoopFileWholeTextReader(file: PartitionedFile, conf: Configuration) Array(new Path(new URI(file.filePath))), Array(file.start), Array(file.length), - // TODO: Implement Locality + // The locality is decided by `getPreferredLocations` in `FileScanRDD`. Array.empty[String]) val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) From 305e9b5ad22b428501fd42d3730d73d2e09ad4c5 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Fri, 21 Dec 2018 16:09:30 +0800 Subject: [PATCH 2323/2461] [SPARK-26422][R] Support to disable Hive support in SparkR even for Hadoop versions unsupported by Hive fork ## What changes were proposed in this pull request? Currently, even if I explicitly disable Hive support in SparkR session as below: ```r sparkSession <- sparkR.session("local[4]", "SparkR", Sys.getenv("SPARK_HOME"), enableHiveSupport = FALSE) ``` produces when the Hadoop version is not supported by our Hive fork: ``` java.lang.reflect.InvocationTargetException ... Caused by: java.lang.IllegalArgumentException: Unrecognized Hadoop major version number: 3.1.1.3.1.0.0-78 at org.apache.hadoop.hive.shims.ShimLoader.getMajorVersion(ShimLoader.java:174) at org.apache.hadoop.hive.shims.ShimLoader.loadShims(ShimLoader.java:139) at org.apache.hadoop.hive.shims.ShimLoader.getHadoopShims(ShimLoader.java:100) at org.apache.hadoop.hive.conf.HiveConf$ConfVars.(HiveConf.java:368) ... 43 more Error in handleErrors(returnStatus, conn) : java.lang.ExceptionInInitializerError at org.apache.hadoop.hive.conf.HiveConf.(HiveConf.java:105) at java.lang.Class.forName0(Native Method) at java.lang.Class.forName(Class.java:348) at org.apache.spark.util.Utils$.classForName(Utils.scala:193) at org.apache.spark.sql.SparkSession$.hiveClassesArePresent(SparkSession.scala:1116) at org.apache.spark.sql.api.r.SQLUtils$.getOrCreateSparkSession(SQLUtils.scala:52) at org.apache.spark.sql.api.r.SQLUtils.getOrCreateSparkSession(SQLUtils.scala) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) ``` The root cause is that: ``` SparkSession.hiveClassesArePresent ``` check if the class is loadable or not to check if that's in classpath but `org.apache.hadoop.hive.conf.HiveConf` has a check for Hadoop version as static logic which is executed right away. This throws an `IllegalArgumentException` and that's not caught: https://github.com/apache/spark/blob/36edbac1c8337a4719f90e4abd58d38738b2e1fb/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala#L1113-L1121 So, currently, if users have a Hive built-in Spark with unsupported Hadoop version by our fork (namely 3+), there's no way to use SparkR even though it could work. This PR just propose to change the order of bool comparison so that we can don't execute `SparkSession.hiveClassesArePresent` when: 1. `enableHiveSupport` is explicitly disabled 2. `spark.sql.catalogImplementation` is `in-memory` so that we **only** check `SparkSession.hiveClassesArePresent` when Hive support is explicitly enabled by short circuiting. ## How was this patch tested? It's difficult to write a test since we don't run tests against Hadoop 3 yet. See https://github.com/apache/spark/pull/21588. Manually tested. Closes #23356 from HyukjinKwon/SPARK-26422. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../scala/org/apache/spark/sql/api/r/SQLUtils.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index becb05cf72aba..e98cab8b56d13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -49,9 +49,17 @@ private[sql] object SQLUtils extends Logging { sparkConfigMap: JMap[Object, Object], enableHiveSupport: Boolean): SparkSession = { val spark = - if (SparkSession.hiveClassesArePresent && enableHiveSupport && + if (enableHiveSupport && jsc.sc.conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == - "hive") { + "hive" && + // Note that the order of conditions here are on purpose. + // `SparkSession.hiveClassesArePresent` checks if Hive's `HiveConf` is loadable or not; + // however, `HiveConf` itself has some static logic to check if Hadoop version is + // supported or not, which throws an `IllegalArgumentException` if unsupported. + // If this is checked first, there's no way to disable Hive support in the case above. + // So, we intentionally check if Hive classes are loadable or not only when + // Hive support is explicitly enabled by short-circuiting. See also SPARK-26422. + SparkSession.hiveClassesArePresent) { SparkSession.builder().sparkContext(withHiveExternalCatalog(jsc.sc)).getOrCreate() } else { if (enableHiveSupport) { From 8e76d6621aaddb8b73443b14ea2c6eebe9089893 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 21 Dec 2018 10:41:25 -0800 Subject: [PATCH 2324/2461] [SPARK-26267][SS] Retry when detecting incorrect offsets from Kafka ## What changes were proposed in this pull request? Due to [KAFKA-7703](https://issues.apache.org/jira/browse/KAFKA-7703), Kafka may return an earliest offset when we are request a latest offset. This will cause Spark to reprocess data. As per suggestion in KAFKA-7703, we put a position call between poll and seekToEnd to block the fetch request triggered by `poll` before calling `seekToEnd`. In addition, to avoid other unknown issues, we also use the previous known offsets to audit the latest offsets returned by Kafka. If we find some incorrect offsets (a latest offset is less than an offset in `knownOffsets`), we will retry at most `maxOffsetFetchAttempts` times. ## How was this patch tested? Jenkins Closes #23324 from zsxwing/SPARK-26267. Authored-by: Shixiong Zhu Signed-off-by: Shixiong Zhu --- .../kafka010/KafkaContinuousReadSupport.scala | 4 +- .../kafka010/KafkaMicroBatchReadSupport.scala | 19 ++++- .../kafka010/KafkaOffsetRangeCalculator.scala | 2 + .../sql/kafka010/KafkaOffsetReader.scala | 80 +++++++++++++++++-- .../spark/sql/kafka010/KafkaSource.scala | 5 +- .../kafka010/KafkaMicroBatchSourceSuite.scala | 48 +++++++++++ 6 files changed, 145 insertions(+), 13 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala index 1753a28fba2fb..02dfb9ca2b95a 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReadSupport.scala @@ -60,7 +60,7 @@ class KafkaContinuousReadSupport( override def initialOffset(): Offset = { val offsets = initialOffsets match { case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) - case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) + case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets(None)) case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) } logInfo(s"Initial offsets: $offsets") @@ -107,7 +107,7 @@ class KafkaContinuousReadSupport( override def needsReconfiguration(config: ScanConfig): Boolean = { val knownPartitions = config.asInstanceOf[KafkaContinuousScanConfig].knownPartitions - offsetReader.fetchLatestOffsets().keySet != knownPartitions + offsetReader.fetchLatestOffsets(None).keySet != knownPartitions } override def toString(): String = s"KafkaSource[$offsetReader]" diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala index bb4de674c3c72..b4f042e93a5da 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReadSupport.scala @@ -84,7 +84,7 @@ private[kafka010] class KafkaMicroBatchReadSupport( override def latestOffset(start: Offset): Offset = { val startPartitionOffsets = start.asInstanceOf[KafkaSourceOffset].partitionToOffsets - val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets() + val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets(Some(startPartitionOffsets)) endPartitionOffsets = KafkaSourceOffset(maxOffsetsPerTrigger.map { maxOffsets => rateLimit(maxOffsets, startPartitionOffsets, latestPartitionOffsets) }.getOrElse { @@ -133,10 +133,21 @@ private[kafka010] class KafkaMicroBatchReadSupport( }.toSeq logDebug("TopicPartitions: " + topicPartitions.mkString(", ")) + val fromOffsets = startPartitionOffsets ++ newPartitionInitialOffsets + val untilOffsets = endPartitionOffsets + untilOffsets.foreach { case (tp, untilOffset) => + fromOffsets.get(tp).foreach { fromOffset => + if (untilOffset < fromOffset) { + reportDataLoss(s"Partition $tp's offset was changed from " + + s"$fromOffset to $untilOffset, some data may have been missed") + } + } + } + // Calculate offset ranges val offsetRanges = rangeCalculator.getRanges( - fromOffsets = startPartitionOffsets ++ newPartitionInitialOffsets, - untilOffsets = endPartitionOffsets, + fromOffsets = fromOffsets, + untilOffsets = untilOffsets, executorLocations = getSortedExecutorList()) // Reuse Kafka consumers only when all the offset ranges have distinct TopicPartitions, @@ -186,7 +197,7 @@ private[kafka010] class KafkaMicroBatchReadSupport( case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaOffsetReader.fetchEarliestOffsets()) case LatestOffsetRangeLimit => - KafkaSourceOffset(kafkaOffsetReader.fetchLatestOffsets()) + KafkaSourceOffset(kafkaOffsetReader.fetchLatestOffsets(None)) case SpecificOffsetRangeLimit(p) => kafkaOffsetReader.fetchSpecificOffsets(p, reportDataLoss) } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala index fb209c724afba..6008794924052 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala @@ -37,6 +37,8 @@ private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int * the read tasks of the skewed partitions to multiple Spark tasks. * The number of Spark tasks will be *approximately* `numPartitions`. It can be less or more * depending on rounding errors or Kafka partitions that didn't receive any new data. + * + * Empty ranges (`KafkaOffsetRange.size <= 0`) will be dropped. */ def getRanges( fromOffsets: PartitionOffsetMap, diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 82066697cb95a..fc443d22bf5a2 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -21,6 +21,7 @@ import java.{util => ju} import java.util.concurrent.{Executors, ThreadFactory} import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.util.control.NonFatal @@ -137,6 +138,12 @@ private[kafka010] class KafkaOffsetReader( // Poll to get the latest assigned partitions consumer.poll(0) val partitions = consumer.assignment() + + // Call `position` to wait until the potential offset request triggered by `poll(0)` is + // done. This is a workaround for KAFKA-7703, which an async `seekToBeginning` triggered by + // `poll(0)` may reset offsets that should have been set by another request. + partitions.asScala.map(p => p -> consumer.position(p)).foreach(_ => {}) + consumer.pause(partitions) assert(partitions.asScala == partitionOffsets.keySet, "If startingOffsets contains specific offsets, you must specify all TopicPartitions.\n" + @@ -192,19 +199,82 @@ private[kafka010] class KafkaOffsetReader( /** * Fetch the latest offsets for the topic partitions that are indicated * in the [[ConsumerStrategy]]. + * + * Kafka may return earliest offsets when we are requesting latest offsets if `poll` is called + * right before `seekToEnd` (KAFKA-7703). As a workaround, we will call `position` right after + * `poll` to wait until the potential offset request triggered by `poll(0)` is done. + * + * In addition, to avoid other unknown issues, we also use the given `knownOffsets` to audit the + * latest offsets returned by Kafka. If we find some incorrect offsets (a latest offset is less + * than an offset in `knownOffsets`), we will retry at most `maxOffsetFetchAttempts` times. When + * a topic is recreated, the latest offsets may be less than offsets in `knownOffsets`. We cannot + * distinguish this with KAFKA-7703, so we just return whatever we get from Kafka after retrying. */ - def fetchLatestOffsets(): Map[TopicPartition, Long] = runUninterruptibly { + def fetchLatestOffsets( + knownOffsets: Option[PartitionOffsetMap]): PartitionOffsetMap = runUninterruptibly { withRetriesWithoutInterrupt { // Poll to get the latest assigned partitions consumer.poll(0) val partitions = consumer.assignment() + + // Call `position` to wait until the potential offset request triggered by `poll(0)` is + // done. This is a workaround for KAFKA-7703, which an async `seekToBeginning` triggered by + // `poll(0)` may reset offsets that should have been set by another request. + partitions.asScala.map(p => p -> consumer.position(p)).foreach(_ => {}) + consumer.pause(partitions) logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the end.") - consumer.seekToEnd(partitions) - val partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap - logDebug(s"Got latest offsets for partition : $partitionOffsets") - partitionOffsets + if (knownOffsets.isEmpty) { + consumer.seekToEnd(partitions) + partitions.asScala.map(p => p -> consumer.position(p)).toMap + } else { + var partitionOffsets: PartitionOffsetMap = Map.empty + + /** + * Compare `knownOffsets` and `partitionOffsets`. Returns all partitions that have incorrect + * latest offset (offset in `knownOffsets` is great than the one in `partitionOffsets`). + */ + def findIncorrectOffsets(): Seq[(TopicPartition, Long, Long)] = { + var incorrectOffsets = ArrayBuffer[(TopicPartition, Long, Long)]() + partitionOffsets.foreach { case (tp, offset) => + knownOffsets.foreach(_.get(tp).foreach { knownOffset => + if (knownOffset > offset) { + val incorrectOffset = (tp, knownOffset, offset) + incorrectOffsets += incorrectOffset + } + }) + } + incorrectOffsets + } + + // Retry to fetch latest offsets when detecting incorrect offsets. We don't use + // `withRetriesWithoutInterrupt` to retry because: + // + // - `withRetriesWithoutInterrupt` will reset the consumer for each attempt but a fresh + // consumer has a much bigger chance to hit KAFKA-7703. + // - Avoid calling `consumer.poll(0)` which may cause KAFKA-7703. + var incorrectOffsets: Seq[(TopicPartition, Long, Long)] = Nil + var attempt = 0 + do { + consumer.seekToEnd(partitions) + partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap + attempt += 1 + + incorrectOffsets = findIncorrectOffsets() + if (incorrectOffsets.nonEmpty) { + logWarning("Found incorrect offsets in some partitions " + + s"(partition, previous offset, fetched offset): $incorrectOffsets") + if (attempt < maxOffsetFetchAttempts) { + logWarning("Retrying to fetch latest offsets because of incorrect offsets") + Thread.sleep(offsetFetchAttemptIntervalMs) + } + } + } while (incorrectOffsets.nonEmpty && attempt < maxOffsetFetchAttempts) + + logDebug(s"Got latest offsets for partition : $partitionOffsets") + partitionOffsets + } } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 66ec7e0cd084a..d65b3cea632c4 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -130,7 +130,7 @@ private[kafka010] class KafkaSource( metadataLog.get(0).getOrElse { val offsets = startingOffsets match { case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchEarliestOffsets()) - case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets()) + case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets(None)) case SpecificOffsetRangeLimit(p) => kafkaReader.fetchSpecificOffsets(p, reportDataLoss) } metadataLog.add(0, offsets) @@ -148,7 +148,8 @@ private[kafka010] class KafkaSource( // Make sure initialPartitionOffsets is initialized initialPartitionOffsets - val latest = kafkaReader.fetchLatestOffsets() + val latest = kafkaReader.fetchLatestOffsets( + currentPartitionOffsets.orElse(Some(initialPartitionOffsets))) val offsets = maxOffsetsPerTrigger match { case None => latest diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 5ee76990b54f4..61cbb3285a4f0 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -329,6 +329,54 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { ) } + test("subscribe topic by pattern with topic recreation between batches") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-good" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 1) + testUtils.sendMessages(topic, Array("1", "3")) + testUtils.createTopic(topic2, partitions = 1) + testUtils.sendMessages(topic2, Array("2", "4")) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("kafka.default.api.timeout.ms", "3000") + .option("startingOffsets", "earliest") + .option("subscribePattern", s"$topicPrefix-.*") + + val ds = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + .map(kv => kv._2.toInt) + + testStream(ds)( + StartStream(), + AssertOnQuery { q => + q.processAllAvailable() + true + }, + CheckAnswer(1, 2, 3, 4), + // Restart the stream in this test to make the test stable. When recreating a topic when a + // consumer is alive, it may not be able to see the recreated topic even if a fresh consumer + // has seen it. + StopStream, + // Recreate `topic2` and wait until it's available + WithOffsetSync(new TopicPartition(topic2, 0), expectedOffset = 1) { () => + testUtils.deleteTopic(topic2) + testUtils.createTopic(topic2) + testUtils.sendMessages(topic2, Array("6")) + }, + StartStream(), + ExpectFailure[IllegalStateException](e => { + // The offset of `topic2` should be changed from 2 to 1 + assert(e.getMessage.contains("was changed from 2 to 1")) + }) + ) + } + test("ensure that initial offset are written with an extra byte in the beginning (SPARK-19517)") { withTempDir { metadataPath => val topic = "kafka-initial-offset-current" From d6a5f859848bbd237e19075dd26e1547fb3af417 Mon Sep 17 00:00:00 2001 From: wuyi Date: Fri, 21 Dec 2018 13:21:58 -0600 Subject: [PATCH 2325/2461] [SPARK-26269][YARN] Yarnallocator should have same blacklist behaviour with yarn to maxmize use of cluster resource ## What changes were proposed in this pull request? As I mentioned in jira [SPARK-26269](https://issues.apache.org/jira/browse/SPARK-26269), in order to maxmize the use of cluster resource, this pr try to make `YarnAllocator` have the same blacklist behaviour with YARN. ## How was this patch tested? Added. Closes #23223 from Ngone51/dev-YarnAllocator-should-have-same-blacklist-behaviour-with-YARN. Lead-authored-by: wuyi Co-authored-by: Ngone51 Signed-off-by: Thomas Graves --- .../spark/deploy/yarn/YarnAllocator.scala | 32 ++++++-- .../yarn/YarnAllocatorBlacklistTracker.scala | 4 +- .../YarnAllocatorBlacklistTrackerSuite.scala | 2 +- .../deploy/yarn/YarnAllocatorSuite.scala | 75 ++++++++++++++++++- 4 files changed, 101 insertions(+), 12 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 54b1ec266113f..a3feca5dfd229 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -607,13 +607,23 @@ private[yarn] class YarnAllocator( val message = "Container killed by YARN for exceeding physical memory limits. " + s"$diag Consider boosting ${EXECUTOR_MEMORY_OVERHEAD.key}." (true, message) - case _ => - // all the failures which not covered above, like: - // disk failure, kill by app master or resource manager, ... - allocatorBlacklistTracker.handleResourceAllocationFailure(hostOpt) - (true, "Container marked as failed: " + containerId + onHostStr + - ". Exit status: " + completedContainer.getExitStatus + - ". Diagnostics: " + completedContainer.getDiagnostics) + case other_exit_status => + // SPARK-26269: follow YARN's blacklisting behaviour(see https://github + // .com/apache/hadoop/blob/228156cfd1b474988bc4fedfbf7edddc87db41e3/had + // oop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/java/org/ap + // ache/hadoop/yarn/util/Apps.java#L273 for details) + if (NOT_APP_AND_SYSTEM_FAULT_EXIT_STATUS.contains(other_exit_status)) { + (false, s"Container marked as failed: $containerId$onHostStr" + + s". Exit status: ${completedContainer.getExitStatus}" + + s". Diagnostics: ${completedContainer.getDiagnostics}.") + } else { + // completed container from a bad node + allocatorBlacklistTracker.handleResourceAllocationFailure(hostOpt) + (true, s"Container from a bad node: $containerId$onHostStr" + + s". Exit status: ${completedContainer.getExitStatus}" + + s". Diagnostics: ${completedContainer.getDiagnostics}.") + } + } if (exitCausedByApp) { @@ -739,4 +749,12 @@ private object YarnAllocator { val MEM_REGEX = "[0-9.]+ [KMG]B" val VMEM_EXCEEDED_EXIT_CODE = -103 val PMEM_EXCEEDED_EXIT_CODE = -104 + + val NOT_APP_AND_SYSTEM_FAULT_EXIT_STATUS = Set( + ContainerExitStatus.KILLED_BY_RESOURCEMANAGER, + ContainerExitStatus.KILLED_BY_APPMASTER, + ContainerExitStatus.KILLED_AFTER_APP_COMPLETION, + ContainerExitStatus.ABORTED, + ContainerExitStatus.DISKS_FAILED + ) } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala index ceac7cda5f8be..268976b629507 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala @@ -120,7 +120,9 @@ private[spark] class YarnAllocatorBlacklistTracker( if (removals.nonEmpty) { logInfo(s"removing nodes from YARN application master's blacklist: $removals") } - amClient.updateBlacklist(additions.asJava, removals.asJava) + if (additions.nonEmpty || removals.nonEmpty) { + amClient.updateBlacklist(additions.asJava, removals.asJava) + } currentBlacklistedYarnNodes = nodesToBlacklist } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala index aeac68e6ed330..201910731e934 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala @@ -87,7 +87,7 @@ class YarnAllocatorBlacklistTrackerSuite extends SparkFunSuite with Matchers // expired blacklisted nodes (simulating a resource request) yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host1", "host2")) // no change is communicated to YARN regarding the blacklisting - verify(amClientMock).updateBlacklist(Collections.emptyList(), Collections.emptyList()) + verify(amClientMock, times(0)).updateBlacklist(Collections.emptyList(), Collections.emptyList()) } test("combining scheduler and allocation blacklist") { diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index b61e7df4420ef..53a538dc1de29 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.deploy.yarn +import java.util.Collections + import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration @@ -114,13 +116,29 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter clock) } - def createContainer(host: String, resource: Resource = containerResource): Container = { - val containerId = ContainerId.newContainerId(appAttemptId, containerNum) + def createContainer( + host: String, + containerNumber: Int = containerNum, + resource: Resource = containerResource): Container = { + val containerId: ContainerId = ContainerId.newContainerId(appAttemptId, containerNum) containerNum += 1 val nodeId = NodeId.newInstance(host, 1000) Container.newInstance(containerId, nodeId, "", resource, RM_REQUEST_PRIORITY, null) } + def createContainers(hosts: Seq[String], containerIds: Seq[Int]): Seq[Container] = { + hosts.zip(containerIds).map{case (host, id) => createContainer(host, id)} + } + + def createContainerStatus( + containerId: ContainerId, + exitStatus: Int, + containerState: ContainerState = ContainerState.COMPLETE, + diagnostics: String = "diagnostics"): ContainerStatus = { + ContainerStatus.newInstance(containerId, containerState, diagnostics, exitStatus) + } + + test("single container allocated") { // request a single container and receive it val handler = createAllocator(1) @@ -148,7 +166,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter Map(YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + "gpu" -> "2G")) handler.updateResourceRequests() - val container = createContainer("host1", handler.resource) + val container = createContainer("host1", resource = handler.resource) handler.handleAllocatedContainers(Array(container)) // get amount of memory and vcores from resource, so effectively skipping their validation @@ -417,4 +435,55 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter clock.advance(50 * 1000L) handler.getNumExecutorsFailed should be (0) } + + test("SPARK-26269: YarnAllocator should have same blacklist behaviour with YARN") { + val rmClientSpy = spy(rmClient) + val maxExecutors = 11 + + val handler = createAllocator( + maxExecutors, + rmClientSpy, + Map( + "spark.yarn.blacklist.executor.launch.blacklisting.enabled" -> "true", + "spark.blacklist.application.maxFailedExecutorsPerNode" -> "0")) + handler.updateResourceRequests() + + val hosts = (0 until maxExecutors).map(i => s"host$i") + val ids = 0 to maxExecutors + val containers = createContainers(hosts, ids) + + val nonBlacklistedStatuses = Seq( + ContainerExitStatus.SUCCESS, + ContainerExitStatus.PREEMPTED, + ContainerExitStatus.KILLED_EXCEEDED_VMEM, + ContainerExitStatus.KILLED_EXCEEDED_PMEM, + ContainerExitStatus.KILLED_BY_RESOURCEMANAGER, + ContainerExitStatus.KILLED_BY_APPMASTER, + ContainerExitStatus.KILLED_AFTER_APP_COMPLETION, + ContainerExitStatus.ABORTED, + ContainerExitStatus.DISKS_FAILED) + + val nonBlacklistedContainerStatuses = nonBlacklistedStatuses.zipWithIndex.map { + case (exitStatus, idx) => createContainerStatus(containers(idx).getId, exitStatus) + } + + val BLACKLISTED_EXIT_CODE = 1 + val blacklistedStatuses = Seq(ContainerExitStatus.INVALID, BLACKLISTED_EXIT_CODE) + + val blacklistedContainerStatuses = blacklistedStatuses.zip(9 until maxExecutors).map { + case (exitStatus, idx) => createContainerStatus(containers(idx).getId, exitStatus) + } + + handler.handleAllocatedContainers(containers.slice(0, 9)) + handler.processCompletedContainers(nonBlacklistedContainerStatuses) + verify(rmClientSpy, never()) + .updateBlacklist(hosts.slice(0, 9).asJava, Collections.emptyList()) + + handler.handleAllocatedContainers(containers.slice(9, 11)) + handler.processCompletedContainers(blacklistedContainerStatuses) + verify(rmClientSpy) + .updateBlacklist(hosts.slice(9, 10).asJava, Collections.emptyList()) + verify(rmClientSpy) + .updateBlacklist(hosts.slice(10, 11).asJava, Collections.emptyList()) + } } From 8dd29fe36b781d115213b1d6a8446ad04e9239bb Mon Sep 17 00:00:00 2001 From: pgandhi Date: Fri, 21 Dec 2018 11:28:22 -0800 Subject: [PATCH 2326/2461] [SPARK-25642][YARN] Adding two new metrics to record the number of registered connections as well as the number of active connections to YARN Shuffle Service Recently, the ability to expose the metrics for YARN Shuffle Service was added as part of [SPARK-18364](https://github.com/apache/spark/pull/22485). We need to add some metrics to be able to determine the number of active connections as well as open connections to the external shuffle service to benchmark network and connection issues on large cluster environments. Added two more shuffle server metrics for Spark Yarn shuffle service: numRegisteredConnections which indicate the number of registered connections to the shuffle service and numActiveConnections which indicate the number of active connections to the shuffle service at any given point in time. If these metrics are outputted to a file, we get something like this: 1533674653489 default.shuffleService: Hostname=server1.abc.com, openBlockRequestLatencyMillis_count=729, openBlockRequestLatencyMillis_rate15=0.7110833548897356, openBlockRequestLatencyMillis_rate5=1.657808981793011, openBlockRequestLatencyMillis_rate1=2.2404486061620474, openBlockRequestLatencyMillis_rateMean=0.9242558551196706, numRegisteredConnections=35, blockTransferRateBytes_count=2635880512, blockTransferRateBytes_rate15=2578547.6094160094, blockTransferRateBytes_rate5=6048721.726302424, blockTransferRateBytes_rate1=8548922.518223226, blockTransferRateBytes_rateMean=3341878.633637769, registeredExecutorsSize=5, registerExecutorRequestLatencyMillis_count=5, registerExecutorRequestLatencyMillis_rate15=0.0027973949328659836, registerExecutorRequestLatencyMillis_rate5=0.0021278007987206426, registerExecutorRequestLatencyMillis_rate1=2.8270296777387467E-6, registerExecutorRequestLatencyMillis_rateMean=0.006339206380043053, numActiveConnections=35 Closes #22498 from pgandhi999/SPARK-18364. Authored-by: pgandhi Signed-off-by: Marcelo Vanzin --- .../spark/network/TransportContext.java | 9 ++++++- .../server/TransportChannelHandler.java | 18 +++++++++++++- .../spark/network/server/TransportServer.java | 5 ++++ .../shuffle/ExternalShuffleBlockHandler.java | 24 +++++++++++++++++-- .../network/yarn/YarnShuffleService.java | 21 +++++++++------- .../yarn/YarnShuffleServiceMetrics.java | 5 ++++ .../spark/deploy/ExternalShuffleService.scala | 2 ++ .../yarn/YarnShuffleServiceMetricsSuite.scala | 3 ++- 8 files changed, 73 insertions(+), 14 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java index 480b52652de53..1a3f3f2a6f249 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.List; +import com.codahale.metrics.Counter; import io.netty.channel.Channel; import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoopGroup; @@ -66,6 +67,8 @@ public class TransportContext { private final RpcHandler rpcHandler; private final boolean closeIdleConnections; private final boolean isClientOnly; + // Number of registered connections to the shuffle service + private Counter registeredConnections = new Counter(); /** * Force to create MessageEncoder and MessageDecoder so that we can make sure they will be created @@ -221,7 +224,7 @@ private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, rpcHandler, conf.maxChunksBeingTransferred()); return new TransportChannelHandler(client, responseHandler, requestHandler, - conf.connectionTimeoutMs(), closeIdleConnections); + conf.connectionTimeoutMs(), closeIdleConnections, this); } /** @@ -234,4 +237,8 @@ private ChunkFetchRequestHandler createChunkFetchHandler(TransportChannelHandler } public TransportConf getConf() { return conf; } + + public Counter getRegisteredConnections() { + return registeredConnections; + } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index c824a7b0d4740..ca81099c4d5cb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -21,6 +21,7 @@ import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.timeout.IdleState; import io.netty.handler.timeout.IdleStateEvent; +import org.apache.spark.network.TransportContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -57,18 +58,21 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler allMetrics; // Time latency for open block request in ms private final Timer openBlockRequestLatencyMillis = new Timer(); @@ -181,14 +183,20 @@ private class ShuffleMetrics implements MetricSet { private final Timer registerExecutorRequestLatencyMillis = new Timer(); // Block transfer rate in byte per second private final Meter blockTransferRateBytes = new Meter(); + // Number of active connections to the shuffle service + private Counter activeConnections = new Counter(); + // Number of registered connections to the shuffle service + private Counter registeredConnections = new Counter(); - private ShuffleMetrics() { + public ShuffleMetrics() { allMetrics = new HashMap<>(); allMetrics.put("openBlockRequestLatencyMillis", openBlockRequestLatencyMillis); allMetrics.put("registerExecutorRequestLatencyMillis", registerExecutorRequestLatencyMillis); allMetrics.put("blockTransferRateBytes", blockTransferRateBytes); allMetrics.put("registeredExecutorsSize", (Gauge) () -> blockManager.getRegisteredExecutorsSize()); + allMetrics.put("numActiveConnections", activeConnections); + allMetrics.put("numRegisteredConnections", registeredConnections); } @Override @@ -244,4 +252,16 @@ public ManagedBuffer next() { } } + @Override + public void channelActive(TransportClient client) { + metrics.activeConnections.inc(); + super.channelActive(client); + } + + @Override + public void channelInactive(TransportClient client) { + metrics.activeConnections.dec(); + super.channelInactive(client); + } + } diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index 72ae1a1295236..7e8d3b2bc3ba4 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -170,15 +170,6 @@ protected void serviceInit(Configuration conf) throws Exception { TransportConf transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf)); blockHandler = new ExternalShuffleBlockHandler(transportConf, registeredExecutorFile); - // register metrics on the block handler into the Node Manager's metrics system. - YarnShuffleServiceMetrics serviceMetrics = - new YarnShuffleServiceMetrics(blockHandler.getAllMetrics()); - - MetricsSystemImpl metricsSystem = (MetricsSystemImpl) DefaultMetricsSystem.instance(); - metricsSystem.register( - "sparkShuffleService", "Metrics on the Spark Shuffle Service", serviceMetrics); - logger.info("Registered metrics with Hadoop's DefaultMetricsSystem"); - // If authentication is enabled, set up the shuffle server to use a // special RPC handler that filters out unauthenticated fetch requests List bootstraps = Lists.newArrayList(); @@ -199,6 +190,18 @@ protected void serviceInit(Configuration conf) throws Exception { port = shuffleServer.getPort(); boundPort = port; String authEnabledString = authEnabled ? "enabled" : "not enabled"; + + // register metrics on the block handler into the Node Manager's metrics system. + blockHandler.getAllMetrics().getMetrics().put("numRegisteredConnections", + shuffleServer.getRegisteredConnections()); + YarnShuffleServiceMetrics serviceMetrics = + new YarnShuffleServiceMetrics(blockHandler.getAllMetrics()); + + MetricsSystemImpl metricsSystem = (MetricsSystemImpl) DefaultMetricsSystem.instance(); + metricsSystem.register( + "sparkShuffleService", "Metrics on the Spark Shuffle Service", serviceMetrics); + logger.info("Registered metrics with Hadoop's DefaultMetricsSystem"); + logger.info("Started YARN shuffle service for Spark on port {}. " + "Authentication is {}. Registered executor file is {}", port, authEnabledString, registeredExecutorFile); diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleServiceMetrics.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleServiceMetrics.java index 3e4d479b862b3..501237407e9b2 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleServiceMetrics.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleServiceMetrics.java @@ -107,6 +107,11 @@ public static void collectMetric( throw new IllegalStateException( "Not supported class type of metric[" + name + "] for value " + gaugeValue); } + } else if (metric instanceof Counter) { + Counter c = (Counter) metric; + long counterValue = c.getCount(); + metricsRecordBuilder.addGauge(new ShuffleServiceMetricsInfo(name, "Number of " + + "connections to shuffle service " + name), counterValue); } } diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index f6b3c37f0fe72..03e3abb3ce569 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -84,6 +84,8 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana server = transportContext.createServer(port, bootstraps.asJava) shuffleServiceSource.registerMetricSet(server.getAllMetrics) + blockHandler.getAllMetrics.getMetrics.put("numRegisteredConnections", + server.getRegisteredConnections) shuffleServiceSource.registerMetricSet(blockHandler.getAllMetrics) masterMetricsSystem.registerSource(shuffleServiceSource) masterMetricsSystem.start() diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala index 40b92282a3b8f..952fd0b70bb7b 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala @@ -38,7 +38,8 @@ class YarnShuffleServiceMetricsSuite extends SparkFunSuite with Matchers { test("metrics named as expected") { val allMetrics = Set( "openBlockRequestLatencyMillis", "registerExecutorRequestLatencyMillis", - "blockTransferRateBytes", "registeredExecutorsSize") + "blockTransferRateBytes", "registeredExecutorsSize", "numActiveConnections", + "numRegisteredConnections") metrics.getMetrics.keySet().asScala should be (allMetrics) } From bba506f8f454c7a8fa82e93a1728e02428fe0d35 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 22 Dec 2018 10:16:27 +0800 Subject: [PATCH 2327/2461] [SPARK-26216][SQL][FOLLOWUP] use abstract class instead of trait for UserDefinedFunction ## What changes were proposed in this pull request? A followup of https://github.com/apache/spark/pull/23178 , to keep binary compability by using abstract class. ## How was this patch tested? Manual test. I created a simple app with Spark 2.4 ``` object TryUDF { def main(args: Array[String]): Unit = { val spark = SparkSession.builder().appName("test").master("local[*]").getOrCreate() import spark.implicits._ val f1 = udf((i: Int) => i + 1) println(f1.deterministic) spark.range(10).select(f1.asNonNullable().apply($"id")).show() spark.stop() } } ``` When I run it with current master, it fails with ``` java.lang.IncompatibleClassChangeError: Found interface org.apache.spark.sql.expressions.UserDefinedFunction, but class was expected ``` When I run it with this PR, it works Closes #23351 from cloud-fan/minor. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- docs/sql-migration-guide-upgrade.md | 2 -- project/MimaExcludes.scala | 28 ++++++++++++++++++- .../sql/expressions/UserDefinedFunction.scala | 2 +- 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 115fc6516fb4c..1bd3b5ad0e1aa 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -33,8 +33,6 @@ displayTitle: Spark SQL Upgrading Guide - In Spark version 2.4 and earlier, the `SET` command works without any warnings even if the specified key is for `SparkConf` entries and it has no effect because the command does not update `SparkConf`, but the behavior might confuse users. Since 3.0, the command fails if a `SparkConf` key is used. You can disable such a check by setting `spark.sql.legacy.setCommandRejectsSparkCoreConfs` to `false`. - - Spark applications which are built with Spark version 2.4 and prior, and call methods of `UserDefinedFunction`, need to be re-compiled with Spark 3.0, as they are not binary compatible with Spark 3.0. - - Since Spark 3.0, CSV/JSON datasources use java.time API for parsing and generating CSV/JSON content. In Spark version 2.4 and earlier, java.text.SimpleDateFormat is used for the same purpuse with fallbacks to the parsing mechanisms of Spark 2.0 and 1.x. For example, `2018-12-08 10:39:21.123` with the pattern `yyyy-MM-dd'T'HH:mm:ss.SSS` cannot be parsed since Spark 3.0 because the timestamp does not match to the pattern but it can be parsed by earlier Spark versions due to a fallback to `Timestamp.valueOf`. To parse the same timestamp since Spark 3.0, the pattern should be `yyyy-MM-dd HH:mm:ss.SSS`. To switch back to the implementation used in Spark 2.4 and earlier, set `spark.sql.legacy.timeParser.enabled` to `true`. - In Spark version 2.4 and earlier, CSV datasource converts a malformed CSV string to a row with all `null`s in the PERMISSIVE mode. Since Spark 3.0, the returned row can contain non-`null` fields if some of CSV column values were parsed and converted to desired types successfully. diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7bb70a29195d6..89fc53ce3972f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -241,7 +241,33 @@ object MimaExcludes { // [SPARK-26216][SQL] Do not use case class as public API (UserDefinedFunction) ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.UserDefinedFunction$"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.sql.expressions.UserDefinedFunction") + ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.sql.expressions.UserDefinedFunction"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.inputTypes"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullableTypes_="), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.dataType"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.f"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.this"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNonNullable"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNonNullable"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullable"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullable"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNondeterministic"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNondeterministic"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.deterministic"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.deterministic"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.withName"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.withName"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productElement"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productArity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$2"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.canEqual"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$1"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productIterator"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productPrefix"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$3") ) // Exclude rules for 2.4.x diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index f88e0e0f299de..901472d8e0360 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.types.DataType * @since 1.3.0 */ @Stable -sealed trait UserDefinedFunction { +sealed abstract class UserDefinedFunction { /** * Returns true when the UDF can return a nullable value. From 81addaa6b7b6f16e477f8dbb26a5d5e9541131b0 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 22 Dec 2018 00:41:21 -0800 Subject: [PATCH 2328/2461] [SPARK-26427][BUILD] Upgrade Apache ORC to 1.5.4 ## What changes were proposed in this pull request? This PR aims to update Apache ORC dependency to the latest version 1.5.4 released at Dec. 20. ([Release Notes](https://issues.apache.org/jira/secure/ReleaseNote.jspa?projectId=12318320&version=12344187])) ``` [ORC-237] OrcFile.mergeFiles Specified block size is less than configured minimum value [ORC-409] Changes for extending MemoryManagerImpl [ORC-410] Fix a locale-dependent test in TestCsvReader [ORC-416] Avoid opening data reader when there is no stripe [ORC-417] Use dynamic Apache Maven mirror link [ORC-419] Ensure to call `close` at RecordReaderImpl constructor exception [ORC-432] openjdk 8 has a bug that prevents surefire from working [ORC-435] Ability to read stripes that are greater than 2GB [ORC-437] Make acid schema checks case insensitive [ORC-411] Update build to work with Java 10. [ORC-418] Fix broken docker build script ``` ## How was this patch tested? Build and pass Jenkins. Closes #23364 from dongjoon-hyun/SPARK-26427. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- dev/deps/spark-deps-hadoop-2.7 | 6 +++--- dev/deps/spark-deps-hadoop-3.1 | 6 +++--- pom.xml | 6 +++++- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 71423af0789c6..1af29fcaff2aa 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -155,9 +155,9 @@ objenesis-2.5.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.5.3-nohive.jar -orc-mapreduce-1.5.3-nohive.jar -orc-shims-1.5.3.jar +orc-core-1.5.4-nohive.jar +orc-mapreduce-1.5.4-nohive.jar +orc-shims-1.5.4.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 93eafef045330..05f180b17a588 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -172,9 +172,9 @@ okhttp-2.7.5.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.5.3-nohive.jar -orc-mapreduce-1.5.3-nohive.jar -orc-shims-1.5.3.jar +orc-core-1.5.4-nohive.jar +orc-mapreduce-1.5.4-nohive.jar +orc-shims-1.5.4.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/pom.xml b/pom.xml index 310d7de955125..de9421419edc2 100644 --- a/pom.xml +++ b/pom.xml @@ -132,7 +132,7 @@ 2.1.0 10.12.1.1 1.10.0 - 1.5.3 + 1.5.4 nohive 1.6.0 9.4.12.v20180830 @@ -1740,6 +1740,10 @@ ${orc.classifier} ${orc.deps.scope} + + javax.xml.bind + jaxb-api + org.apache.hadoop hadoop-common From ceff0c8450a4f2e31ec52dfc4d101f67c67853c5 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 22 Dec 2018 00:43:59 -0800 Subject: [PATCH 2329/2461] [SPARK-26428][SS][TEST] Minimize deprecated `ProcessingTime` usage ## What changes were proposed in this pull request? Use of `ProcessingTime` class was deprecated in favor of `Trigger.ProcessingTime` in Spark 2.2. And, [SPARK-21464](https://issues.apache.org/jira/browse/SPARK-21464) minimized it at 2.2.1. Recently, it grows again in test suites. This PR aims to clean up newly introduced deprecation warnings for Spark 3.0. ## How was this patch tested? Pass the Jenkins with existing tests and manually check the warnings. Closes #23367 from dongjoon-hyun/SPARK-26428. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../kafka010/KafkaMicroBatchSourceSuite.scala | 16 ++++++++-------- .../sql/streaming/FileStreamSourceSuite.scala | 2 +- .../apache/spark/sql/streaming/StreamSuite.scala | 4 ++-- .../streaming/StreamingQueryListenerSuite.scala | 6 +++--- .../sql/streaming/StreamingQuerySuite.scala | 4 ++-- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 61cbb3285a4f0..d4eb526540053 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} +import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.SharedSQLContext @@ -236,7 +236,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { } testStream(mapped)( - StartStream(ProcessingTime(100), clock), + StartStream(Trigger.ProcessingTime(100), clock), waitUntilBatchProcessed, // 1 from smallest, 1 from middle, 8 from biggest CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), @@ -247,7 +247,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 ), StopStream, - StartStream(ProcessingTime(100), clock), + StartStream(Trigger.ProcessingTime(100), clock), waitUntilBatchProcessed, // smallest now empty, 1 more from middle, 9 more from biggest CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, @@ -282,7 +282,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { val mapped = kafka.map(kv => kv._2.toInt + 1) testStream(mapped)( - StartStream(trigger = ProcessingTime(1)), + StartStream(trigger = Trigger.ProcessingTime(1)), makeSureGetOffsetCalled, AddKafkaData(Set(topic), 1, 2, 3), CheckAnswer(2, 3, 4), @@ -605,7 +605,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { } testStream(kafka)( - StartStream(ProcessingTime(100), clock), + StartStream(Trigger.ProcessingTime(100), clock), waitUntilBatchProcessed, // 5 from smaller topic, 5 from bigger one CheckLastBatch((0 to 4) ++ (100 to 104): _*), @@ -618,7 +618,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { // smaller topic empty, 5 from bigger one CheckLastBatch(110 to 114: _*), StopStream, - StartStream(ProcessingTime(100), clock), + StartStream(Trigger.ProcessingTime(100), clock), waitUntilBatchProcessed, // smallest now empty, 5 from bigger one CheckLastBatch(115 to 119: _*), @@ -727,7 +727,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { // The message values are the same as their offsets to make the test easy to follow testUtils.withTranscationalProducer { producer => testStream(mapped)( - StartStream(ProcessingTime(100), clock), + StartStream(Trigger.ProcessingTime(100), clock), waitUntilBatchProcessed, CheckAnswer(), WithOffsetSync(topicPartition, expectedOffset = 5) { () => @@ -850,7 +850,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { // The message values are the same as their offsets to make the test easy to follow testUtils.withTranscationalProducer { producer => testStream(mapped)( - StartStream(ProcessingTime(100), clock), + StartStream(Trigger.ProcessingTime(100), clock), waitUntilBatchProcessed, CheckNewAnswer(), WithOffsetSync(topicPartition, expectedOffset = 5) { () => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index d4bd9c7987f2d..de664cafed3b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -1360,7 +1360,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { options = srcOptions) val clock = new StreamManualClock() testStream(fileStream)( - StartStream(trigger = ProcessingTime(10), triggerClock = clock), + StartStream(trigger = Trigger.ProcessingTime(10), triggerClock = clock), AssertOnQuery { _ => // Block until the first batch finishes. eventually(timeout(streamingTimeout)) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index f55ddb5419d20..55fdcee83f114 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -312,7 +312,7 @@ class StreamSuite extends StreamTest { val inputData = MemoryStream[Int] testStream(inputData.toDS())( - StartStream(ProcessingTime("10 seconds"), new StreamManualClock), + StartStream(Trigger.ProcessingTime("10 seconds"), new StreamManualClock), /* -- batch 0 ----------------------- */ // Add some data in batch 0 @@ -353,7 +353,7 @@ class StreamSuite extends StreamTest { /* Stop then restart the Stream */ StopStream, - StartStream(ProcessingTime("10 seconds"), new StreamManualClock(60 * 1000)), + StartStream(Trigger.ProcessingTime("10 seconds"), new StreamManualClock(60 * 1000)), /* -- batch 1 no rerun ----------------- */ // batch 1 would not re-run because the latest batch id logged in commit log is 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index fe77a1b4469c5..d00f2e3bf4d1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -82,7 +82,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { testStream(df, OutputMode.Append)( // Start event generated when query started - StartStream(ProcessingTime(100), triggerClock = clock), + StartStream(Trigger.ProcessingTime(100), triggerClock = clock), AssertOnQuery { query => assert(listener.startEvent !== null) assert(listener.startEvent.id === query.id) @@ -124,7 +124,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { }, // Termination event generated with exception message when stopped with error - StartStream(ProcessingTime(100), triggerClock = clock), + StartStream(Trigger.ProcessingTime(100), triggerClock = clock), AssertStreamExecThreadToWaitForClock(), AddData(inputData, 0), AdvanceManualClock(100), // process bad data @@ -306,7 +306,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { } val clock = new StreamManualClock() val actions = mutable.ArrayBuffer[StreamAction]() - actions += StartStream(trigger = ProcessingTime(10), triggerClock = clock) + actions += StartStream(trigger = Trigger.ProcessingTime(10), triggerClock = clock) for (_ <- 1 to 100) { actions += AdvanceManualClock(10) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index c170641372d61..29b816486a1fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -257,7 +257,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi var lastProgressBeforeStop: StreamingQueryProgress = null testStream(mapped, OutputMode.Complete)( - StartStream(ProcessingTime(1000), triggerClock = clock), + StartStream(Trigger.ProcessingTime(1000), triggerClock = clock), AssertStreamExecThreadIsWaitingForTime(1000), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === false), @@ -370,7 +370,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.status.message === "Stopped"), // Test status and progress after query terminated with error - StartStream(ProcessingTime(1000), triggerClock = clock), + StartStream(Trigger.ProcessingTime(1000), triggerClock = clock), AdvanceManualClock(1000), // ensure initial trigger completes before AddData AddData(inputData, 0), AdvanceManualClock(1000), // allow another trigger From c7bfb4cf832d5b20527df6e19855b6a7436988a9 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 22 Dec 2018 00:46:36 -0800 Subject: [PATCH 2330/2461] [SPARK-26430][BUILD][TEST-MAVEN] Upgrade Surefire plugin to 3.0.0-M2 ## What changes were proposed in this pull request? This PR aims to upgrade Maven Surefile plugin for JDK11 support. 3.0.0-M2 is [released Dec. 9th.](https://issues.apache.org/jira/projects/SUREFIRE/versions/12344396) ``` [SUREFIRE-1568] Versions 2.21 and higher doesn't work with junit-platform for Java 9 module [SUREFIRE-1605] NoClassDefFoundError (RunNotifier) with JDK 11 [SUREFIRE-1600] Surefire Project using surefire:2.12.4 is not fully able to work with JDK 10+ on internal build system. Therefore surefire-shadefire should go with Surefire:3.0.0-M2. [SUREFIRE-1593] 3.0.0-M1 produces invalid code sources on Windows [SUREFIRE-1602] Surefire fails loading class ForkedBooter when using a sub-directory pom file and a local maven repo [SUREFIRE-1606] maven-shared-utils must not be on provider's classpath [SUREFIRE-1531] Option to switch-off Java 9 modules [SUREFIRE-1590] Deploy multiple versions of Report XSD [SUREFIRE-1591] Java 1.7 feature Diamonds replaced Generics [SUREFIRE-1594] Java 1.7 feature try-catch - multiple exceptions in one catch [SUREFIRE-1595] Java 1.7 feature System.lineSeparator() [SUREFIRE-1597] ModularClasspathForkConfiguration with debug logs (args file and its path on file system) [SUREFIRE-1596] Unnecessary check JAVA_RECENT == JAVA_1_7 in unit tests [SUREFIRE-1598] Fixed typo in assertion statement in integration test Surefire855AllowFailsafeUseArtifactFileIT [SUREFIRE-1607] Roadmap on Project Site ``` ## How was this patch tested? Pass the Jenkins. Closes #23370 from dongjoon-hyun/SPARK-26430. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index de9421419edc2..321de209a56a1 100644 --- a/pom.xml +++ b/pom.xml @@ -2103,7 +2103,7 @@ org.apache.maven.plugins maven-surefire-plugin - 3.0.0-M1 + 3.0.0-M2 From 0a02d5c36fc5035abcfb930e1a229d65c6cf683f Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Sat, 22 Dec 2018 09:03:02 -0600 Subject: [PATCH 2331/2461] =?UTF-8?q?[SPARK-26285][CORE]=20accumulator=20m?= =?UTF-8?q?etrics=20sources=20for=20LongAccumulator=20and=20Doub=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …leAccumulator ## What changes were proposed in this pull request? This PR implements metric sources for LongAccumulator and DoubleAccumulator, such that a user can register these accumulators easily and have their values be reported by the driver's metric namespace. ## How was this patch tested? Unit tests, and manual tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #23242 from abellina/SPARK-26285_accumulator_source. Lead-authored-by: Alessandro Bellina Co-authored-by: Alessandro Bellina Co-authored-by: Alessandro Bellina Signed-off-by: Thomas Graves --- .../metrics/source/AccumulatorSource.scala | 89 ++++++++++++++++++ .../source/AccumulatorSourceSuite.scala | 91 +++++++++++++++++++ .../examples/AccumulatorMetricsTest.scala | 77 ++++++++++++++++ 3 files changed, 257 insertions(+) create mode 100644 core/src/main/scala/org/apache/spark/metrics/source/AccumulatorSource.scala create mode 100644 core/src/test/scala/org/apache/spark/metrics/source/AccumulatorSourceSuite.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/AccumulatorMetricsTest.scala diff --git a/core/src/main/scala/org/apache/spark/metrics/source/AccumulatorSource.scala b/core/src/main/scala/org/apache/spark/metrics/source/AccumulatorSource.scala new file mode 100644 index 0000000000000..45a4d224d45fe --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/source/AccumulatorSource.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics.source + +import com.codahale.metrics.{Gauge, MetricRegistry} + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.Experimental +import org.apache.spark.util.{AccumulatorV2, DoubleAccumulator, LongAccumulator} + +/** + * AccumulatorSource is a Spark metric Source that reports the current value + * of the accumulator as a gauge. + * + * It is restricted to the LongAccumulator and the DoubleAccumulator, as those + * are the current built-in numerical accumulators with Spark, and excludes + * the CollectionAccumulator, as that is a List of values (hard to report, + * to a metrics system) + */ +private[spark] class AccumulatorSource extends Source { + private val registry = new MetricRegistry + protected def register[T](accumulators: Map[String, AccumulatorV2[_, T]]): Unit = { + accumulators.foreach { + case (name, accumulator) => + val gauge = new Gauge[T] { + override def getValue: T = accumulator.value + } + registry.register(MetricRegistry.name(name), gauge) + } + } + + override def sourceName: String = "AccumulatorSource" + override def metricRegistry: MetricRegistry = registry +} + +@Experimental +class LongAccumulatorSource extends AccumulatorSource + +@Experimental +class DoubleAccumulatorSource extends AccumulatorSource + +/** + * :: Experimental :: + * Metrics source specifically for LongAccumulators. Accumulators + * are only valid on the driver side, so these metrics are reported + * only by the driver. + * Register LongAccumulators using: + * LongAccumulatorSource.register(sc, {"name" -> longAccumulator}) + */ +@Experimental +object LongAccumulatorSource { + def register(sc: SparkContext, accumulators: Map[String, LongAccumulator]): Unit = { + val source = new LongAccumulatorSource + source.register(accumulators) + sc.env.metricsSystem.registerSource(source) + } +} + +/** + * :: Experimental :: + * Metrics source specifically for DoubleAccumulators. Accumulators + * are only valid on the driver side, so these metrics are reported + * only by the driver. + * Register DoubleAccumulators using: + * DoubleAccumulatorSource.register(sc, {"name" -> doubleAccumulator}) + */ +@Experimental +object DoubleAccumulatorSource { + def register(sc: SparkContext, accumulators: Map[String, DoubleAccumulator]): Unit = { + val source = new DoubleAccumulatorSource + source.register(accumulators) + sc.env.metricsSystem.registerSource(source) + } +} diff --git a/core/src/test/scala/org/apache/spark/metrics/source/AccumulatorSourceSuite.scala b/core/src/test/scala/org/apache/spark/metrics/source/AccumulatorSourceSuite.scala new file mode 100644 index 0000000000000..6a6c07cb068cc --- /dev/null +++ b/core/src/test/scala/org/apache/spark/metrics/source/AccumulatorSourceSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics.source + +import com.codahale.metrics.MetricRegistry +import org.mockito.ArgumentCaptor +import org.mockito.Mockito.{mock, never, spy, times, verify, when} + +import org.apache.spark.{SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.util.{DoubleAccumulator, LongAccumulator} + +class AccumulatorSourceSuite extends SparkFunSuite { + test("that that accumulators register against the metric system's register") { + val acc1 = new LongAccumulator() + val acc2 = new LongAccumulator() + val mockContext = mock(classOf[SparkContext]) + val mockEnvironment = mock(classOf[SparkEnv]) + val mockMetricSystem = mock(classOf[MetricsSystem]) + when(mockEnvironment.metricsSystem) thenReturn (mockMetricSystem) + when(mockContext.env) thenReturn (mockEnvironment) + val accs = Map("my-accumulator-1" -> acc1, + "my-accumulator-2" -> acc2) + LongAccumulatorSource.register(mockContext, accs) + val captor = new ArgumentCaptor[AccumulatorSource]() + verify(mockMetricSystem, times(1)).registerSource(captor.capture()) + val source = captor.getValue() + val gauges = source.metricRegistry.getGauges() + assert (gauges.size == 2) + assert (gauges.firstKey == "my-accumulator-1") + assert (gauges.lastKey == "my-accumulator-2") + } + + test("the accumulators value property is checked when the gauge's value is requested") { + val acc1 = new LongAccumulator() + acc1.add(123) + val acc2 = new LongAccumulator() + acc2.add(456) + val mockContext = mock(classOf[SparkContext]) + val mockEnvironment = mock(classOf[SparkEnv]) + val mockMetricSystem = mock(classOf[MetricsSystem]) + when(mockEnvironment.metricsSystem) thenReturn (mockMetricSystem) + when(mockContext.env) thenReturn (mockEnvironment) + val accs = Map("my-accumulator-1" -> acc1, + "my-accumulator-2" -> acc2) + LongAccumulatorSource.register(mockContext, accs) + val captor = new ArgumentCaptor[AccumulatorSource]() + verify(mockMetricSystem, times(1)).registerSource(captor.capture()) + val source = captor.getValue() + val gauges = source.metricRegistry.getGauges() + assert(gauges.get("my-accumulator-1").getValue() == 123) + assert(gauges.get("my-accumulator-2").getValue() == 456) + } + + test("the double accumulators value propety is checked when the gauge's value is requested") { + val acc1 = new DoubleAccumulator() + acc1.add(123.123) + val acc2 = new DoubleAccumulator() + acc2.add(456.456) + val mockContext = mock(classOf[SparkContext]) + val mockEnvironment = mock(classOf[SparkEnv]) + val mockMetricSystem = mock(classOf[MetricsSystem]) + when(mockEnvironment.metricsSystem) thenReturn (mockMetricSystem) + when(mockContext.env) thenReturn (mockEnvironment) + val accs = Map( + "my-accumulator-1" -> acc1, + "my-accumulator-2" -> acc2) + DoubleAccumulatorSource.register(mockContext, accs) + val captor = new ArgumentCaptor[AccumulatorSource]() + verify(mockMetricSystem, times(1)).registerSource(captor.capture()) + val source = captor.getValue() + val gauges = source.metricRegistry.getGauges() + assert(gauges.get("my-accumulator-1").getValue() == 123.123) + assert(gauges.get("my-accumulator-2").getValue() == 456.456) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/AccumulatorMetricsTest.scala b/examples/src/main/scala/org/apache/spark/examples/AccumulatorMetricsTest.scala new file mode 100644 index 0000000000000..5d9a9a73f12ec --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/AccumulatorMetricsTest.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples + +import org.apache.spark.metrics.source.{DoubleAccumulatorSource, LongAccumulatorSource} +import org.apache.spark.sql.SparkSession + +/** + * Usage: AccumulatorMetricsTest [numElem] + * + * This example shows how to register accumulators against the accumulator source. + * A simple RDD is created, and during the map, the accumulators are incremented. + * + * The only argument, numElem, sets the number elements in the collection to parallize. + * + * The result is output to stdout in the driver with the values of the accumulators. + * For the long accumulator, it should equal numElem the double accumulator should be + * roughly 1.1 x numElem (within double precision.) This example also sets up a + * ConsoleSink (metrics) instance, and so registered codahale metrics (like the + * accumulator source) are reported to stdout as well. + */ +object AccumulatorMetricsTest { + def main(args: Array[String]) { + + val spark = SparkSession + .builder() + .config("spark.metrics.conf.*.sink.console.class", + "org.apache.spark.metrics.sink.ConsoleSink") + .getOrCreate() + + val sc = spark.sparkContext + + val acc = sc.longAccumulator("my-long-metric") + // register the accumulator, the metric system will report as + // [spark.metrics.namespace].[execId|driver].AccumulatorSource.my-long-metric + LongAccumulatorSource.register(sc, List(("my-long-metric" -> acc)).toMap) + + val acc2 = sc.doubleAccumulator("my-double-metric") + // register the accumulator, the metric system will report as + // [spark.metrics.namespace].[execId|driver].AccumulatorSource.my-double-metric + DoubleAccumulatorSource.register(sc, List(("my-double-metric" -> acc2)).toMap) + + val num = if (args.length > 0) args(0).toInt else 1000000 + + val startTime = System.nanoTime + + val accumulatorTest = sc.parallelize(1 to num).foreach(_ => { + acc.add(1) + acc2.add(1.1) + }) + + // Print a footer with test time and accumulator values + println("Test took %.0f milliseconds".format((System.nanoTime - startTime) / 1E6)) + println("Accumulator values:") + println("*** Long accumulator (my-long-metric): " + acc.value) + println("*** Double accumulator (my-double-metric): " + acc2.value) + + spark.stop() + } +} +// scalastyle:on println From 90a810352e94c0b74c19324301e51e8f5bbe98dd Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Sat, 22 Dec 2018 10:32:32 -0600 Subject: [PATCH 2332/2461] [SPARK-25245][DOCS][SS] Explain regarding limiting modification on "spark.sql.shuffle.partitions" for structured streaming ## What changes were proposed in this pull request? This patch adds explanation of `why "spark.sql.shuffle.partitions" keeps unchanged in structured streaming`, which couple of users already wondered and some of them even thought it as a bug. This patch would help other end users to know about such behavior before they find by theirselves and being wondered. ## How was this patch tested? No need to test because this is a simple addition on guide doc with markdown editor. Closes #22238 from HeartSaVioR/SPARK-25245. Lead-authored-by: Jungtaek Lim Co-authored-by: Jungtaek Lim (HeartSaVioR) Signed-off-by: Sean Owen --- docs/structured-streaming-programming-guide.md | 10 ++++++++++ .../scala/org/apache/spark/sql/internal/SQLConf.scala | 8 ++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 32d61dcdb4599..e76b53dbb4dc3 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -3113,6 +3113,16 @@ See [Input Sources](#input-sources) and [Output Sinks](#output-sinks) sections f # Additional Information +**Notes** + +- Several configurations are not modifiable after the query has run. To change them, discard the checkpoint and start a new query. These configurations include: + - `spark.sql.shuffle.partitions` + - This is due to the physical partitioning of state: state is partitioned via applying hash function to key, hence the number of partitions for state should be unchanged. + - If you want to run fewer tasks for stateful operations, `coalesce` would help with avoiding unnecessary repartitioning. + - After `coalesce`, the number of (reduced) tasks will be kept unless another shuffle happens. + - `spark.sql.streaming.stateStore.providerClass`: To read the previous state of the query properly, the class of state store provider should be unchanged. + - `spark.sql.streaming.multipleWatermarkPolicy`: Modification of this would lead inconsistent watermark value when query contains multiple watermarks, hence the policy should be unchanged. + **Further Reading** - See and run the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 86e068bf632bd..fe445e0019353 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -263,7 +263,9 @@ object SQLConf { .createWithDefault(true) val SHUFFLE_PARTITIONS = buildConf("spark.sql.shuffle.partitions") - .doc("The default number of partitions to use when shuffling data for joins or aggregations.") + .doc("The default number of partitions to use when shuffling data for joins or aggregations. " + + "Note: For structured streaming, this configuration cannot be changed between query " + + "restarts from the same checkpoint location.") .intConf .createWithDefault(200) @@ -882,7 +884,9 @@ object SQLConf { .internal() .doc( "The class used to manage state data in stateful streaming queries. This class must " + - "be a subclass of StateStoreProvider, and must have a zero-arg constructor.") + "be a subclass of StateStoreProvider, and must have a zero-arg constructor. " + + "Note: For structured streaming, this configuration cannot be changed between query " + + "restarts from the same checkpoint location.") .stringConf .createWithDefault( "org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider") From a5a24d92bdf6e6a8e33bdc8833bedba033576b4c Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Sat, 22 Dec 2018 10:35:14 -0800 Subject: [PATCH 2333/2461] [SPARK-26402][SQL] Accessing nested fields with different cases in case insensitive mode ## What changes were proposed in this pull request? GetStructField with different optional names should be semantically equal. We will use this as building block to compare the nested fields used in the plans to be optimized by catalyst optimizer. This PR also fixes a bug below that accessing nested fields with different cases in case insensitive mode will result `AnalysisException`. ``` sql("create table t (s struct) using json") sql("select s.I from t group by s.i") ``` which is currently failing ``` org.apache.spark.sql.AnalysisException: expression 'default.t.`s`' is neither present in the group by, nor is it an aggregate function ``` as cloud-fan pointed out. ## How was this patch tested? New tests are added. Closes #23353 from dbtsai/nestedEqual. Lead-authored-by: DB Tsai Co-authored-by: DB Tsai Signed-off-by: Dongjoon Hyun --- .../catalyst/expressions/Canonicalize.scala | 4 ++- .../expressions/CanonicalizeSuite.scala | 29 ++++++++++++++++++ .../BinaryComparisonSimplificationSuite.scala | 30 +++++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 19 ++++++++++++ 4 files changed, 81 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index fe6db8b344d3d..4d218b936b3a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -26,6 +26,7 @@ package org.apache.spark.sql.catalyst.expressions * * The following rules are applied: * - Names and nullability hints for [[org.apache.spark.sql.types.DataType]]s are stripped. + * - Names for [[GetStructField]] are stripped. * - Commutative and associative operations ([[Add]] and [[Multiply]]) have their children ordered * by `hashCode`. * - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`. @@ -37,10 +38,11 @@ object Canonicalize { expressionReorder(ignoreNamesTypes(e)) } - /** Remove names and nullability from types. */ + /** Remove names and nullability from types, and names from `GetStructField`. */ private[expressions] def ignoreNamesTypes(e: Expression): Expression = e match { case a: AttributeReference => AttributeReference("none", a.dataType.asNullable)(exprId = a.exprId) + case GetStructField(child, ordinal, Some(_)) => GetStructField(child, ordinal, None) case _ => e } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala index 28e6940f3cca3..9802a6e5891b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.logical.Range +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class CanonicalizeSuite extends SparkFunSuite { @@ -50,4 +51,32 @@ class CanonicalizeSuite extends SparkFunSuite { assert(range.where(arrays1).sameResult(range.where(arrays2))) assert(!range.where(arrays1).sameResult(range.where(arrays3))) } + + test("SPARK-26402: accessing nested fields with different cases in case insensitive mode") { + val expId = NamedExpression.newExprId + val qualifier = Seq.empty[String] + val structType = StructType( + StructField("a", StructType(StructField("b", IntegerType, false) :: Nil), false) :: Nil) + + // GetStructField with different names are semantically equal + val fieldA1 = GetStructField( + AttributeReference("data1", structType, false)(expId, qualifier), + 0, Some("a1")) + val fieldA2 = GetStructField( + AttributeReference("data2", structType, false)(expId, qualifier), + 0, Some("a2")) + assert(fieldA1.semanticEquals(fieldA2)) + + val fieldB1 = GetStructField( + GetStructField( + AttributeReference("data1", structType, false)(expId, qualifier), + 0, Some("a1")), + 0, Some("b1")) + val fieldB2 = GetStructField( + GetStructField( + AttributeReference("data2", structType, false)(expId, qualifier), + 0, Some("a2")), + 0, Some("b2")) + assert(fieldB1.semanticEquals(fieldB2)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala index a313681eeb8f0..5794691a365a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper { @@ -92,4 +93,33 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper val correctAnswer = nonNullableRelation.analyze comparePlans(actual, correctAnswer) } + + test("SPARK-26402: accessing nested fields with different cases in case insensitive mode") { + val expId = NamedExpression.newExprId + val qualifier = Seq.empty[String] + val structType = StructType( + StructField("a", StructType(StructField("b", IntegerType, false) :: Nil), false) :: Nil) + + val fieldA1 = GetStructField( + GetStructField( + AttributeReference("data1", structType, false)(expId, qualifier), + 0, Some("a1")), + 0, Some("b1")) + val fieldA2 = GetStructField( + GetStructField( + AttributeReference("data2", structType, false)(expId, qualifier), + 0, Some("a2")), + 0, Some("b2")) + + // GetStructField with different names are semantically equal; thus, `EqualTo(fieldA1, fieldA2)` + // will be optimized to `TrueLiteral` by `SimplifyBinaryComparison`. + val originalQuery = nonNullableRelation + .where(EqualTo(fieldA1, fieldA2)) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = nonNullableRelation.analyze + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 37a8815350a53..656da9fa01806 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2937,6 +2937,25 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-26402: accessing nested fields with different cases in case insensitive mode") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val msg = intercept[AnalysisException] { + withTable("t") { + sql("create table t (s struct) using json") + checkAnswer(sql("select s.I from t group by s.i"), Nil) + } + }.message + assert(msg.contains("No such struct field I in i")) + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + withTable("t") { + sql("create table t (s struct) using json") + checkAnswer(sql("select s.I from t group by s.i"), Nil) + } + } + } } case class Foo(bar: Option[String]) From 1008ab0801c192e8f261001eaaf58a6c9f6e747a Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 24 Dec 2018 10:47:47 +0800 Subject: [PATCH 2334/2461] [SPARK-26178][SPARK-26243][SQL][FOLLOWUP] Replacing SimpleDateFormat by DateTimeFormatter in comments ## What changes were proposed in this pull request? The PRs #23150 and #23196 switched JSON and CSV datasources on new formatter for dates/timestamps which is based on `DateTimeFormatter`. In this PR, I replaced `SimpleDateFormat` by `DateTimeFormatter` to reflect the changes. Closes #23374 from MaxGekk/java-time-docs. Authored-by: Maxim Gekk Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/readwriter.py | 28 +++++++++++-------- python/pyspark/sql/streaming.py | 14 ++++++---- .../apache/spark/sql/DataFrameReader.scala | 12 ++++---- .../apache/spark/sql/DataFrameWriter.scala | 12 ++++---- .../sql/streaming/DataStreamReader.scala | 12 ++++---- 5 files changed, 42 insertions(+), 36 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 7b10512a43294..3da052391a95b 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -226,11 +226,12 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, it uses the value specified in ``spark.sql.columnNameOfCorruptRecord``. :param dateFormat: sets the string that indicates a date format. Custom date formats - follow the formats at ``java.text.SimpleDateFormat``. This + follow the formats at ``java.time.format.DateTimeFormatter``. This applies to date type. If None is set, it uses the default value, ``yyyy-MM-dd``. - :param timestampFormat: sets the string that indicates a timestamp format. Custom date - formats follow the formats at ``java.text.SimpleDateFormat``. + :param timestampFormat: sets the string that indicates a timestamp format. + Custom date formats follow the formats at + ``java.time.format.DateTimeFormatter``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param multiLine: parse one record, which may span multiple lines, per file. If None is @@ -406,11 +407,12 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param negativeInf: sets the string representation of a negative infinity value. If None is set, it uses the default value, ``Inf``. :param dateFormat: sets the string that indicates a date format. Custom date formats - follow the formats at ``java.text.SimpleDateFormat``. This + follow the formats at ``java.time.format.DateTimeFormatter``. This applies to date type. If None is set, it uses the default value, ``yyyy-MM-dd``. - :param timestampFormat: sets the string that indicates a timestamp format. Custom date - formats follow the formats at ``java.text.SimpleDateFormat``. + :param timestampFormat: sets the string that indicates a timestamp format. + Custom date formats follow the formats at + ``java.time.format.DateTimeFormatter``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param maxColumns: defines a hard limit of how many columns a record can have. If None is @@ -803,11 +805,12 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm known case-insensitive shorten names (none, bzip2, gzip, lz4, snappy and deflate). :param dateFormat: sets the string that indicates a date format. Custom date formats - follow the formats at ``java.text.SimpleDateFormat``. This + follow the formats at ``java.time.format.DateTimeFormatter``. This applies to date type. If None is set, it uses the default value, ``yyyy-MM-dd``. - :param timestampFormat: sets the string that indicates a timestamp format. Custom date - formats follow the formats at ``java.text.SimpleDateFormat``. + :param timestampFormat: sets the string that indicates a timestamp format. + Custom date formats follow the formats at + ``java.time.format.DateTimeFormatter``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param encoding: specifies encoding (charset) of saved json files. If None is set, @@ -904,11 +907,12 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No :param nullValue: sets the string representation of a null value. If None is set, it uses the default value, empty string. :param dateFormat: sets the string that indicates a date format. Custom date formats - follow the formats at ``java.text.SimpleDateFormat``. This + follow the formats at ``java.time.format.DateTimeFormatter``. This applies to date type. If None is set, it uses the default value, ``yyyy-MM-dd``. - :param timestampFormat: sets the string that indicates a timestamp format. Custom date - formats follow the formats at ``java.text.SimpleDateFormat``. + :param timestampFormat: sets the string that indicates a timestamp format. + Custom date formats follow the formats at + ``java.time.format.DateTimeFormatter``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param ignoreLeadingWhiteSpace: a flag indicating whether or not leading whitespaces from diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index fc23b9d99c34a..b981fdc4edc77 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -456,11 +456,12 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, it uses the value specified in ``spark.sql.columnNameOfCorruptRecord``. :param dateFormat: sets the string that indicates a date format. Custom date formats - follow the formats at ``java.text.SimpleDateFormat``. This + follow the formats at ``java.time.format.DateTimeFormatter``. This applies to date type. If None is set, it uses the default value, ``yyyy-MM-dd``. - :param timestampFormat: sets the string that indicates a timestamp format. Custom date - formats follow the formats at ``java.text.SimpleDateFormat``. + :param timestampFormat: sets the string that indicates a timestamp format. + Custom date formats follow the formats at + ``java.time.format.DateTimeFormatter``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param multiLine: parse one record, which may span multiple lines, per file. If None is @@ -630,11 +631,12 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param negativeInf: sets the string representation of a negative infinity value. If None is set, it uses the default value, ``Inf``. :param dateFormat: sets the string that indicates a date format. Custom date formats - follow the formats at ``java.text.SimpleDateFormat``. This + follow the formats at ``java.time.format.DateTimeFormatter``. This applies to date type. If None is set, it uses the default value, ``yyyy-MM-dd``. - :param timestampFormat: sets the string that indicates a timestamp format. Custom date - formats follow the formats at ``java.text.SimpleDateFormat``. + :param timestampFormat: sets the string that indicates a timestamp format. + Custom date formats follow the formats at + ``java.time.format.DateTimeFormatter``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param maxColumns: defines a hard limit of how many columns a record can have. If None is diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 9751528654ffb..ce8e4c8f5b82b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -375,11 +375,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.

    • *
    • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. - * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to - * date type.
    • + * Custom date formats follow the formats at `java.time.format.DateTimeFormatter`. + * This applies to date type. *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at - * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • + * `java.time.format.DateTimeFormatter`. This applies to timestamp type. *
    • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
    • *
    • `encoding` (by default it is not set): allows to forcibly set one of standard basic @@ -585,11 +585,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
    • `negativeInf` (default `-Inf`): sets the string representation of a negative infinity * value.
    • *
    • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. - * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to - * date type.
    • + * Custom date formats follow the formats at `java.time.format.DateTimeFormatter`. + * This applies to date type. *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at - * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • + * `java.time.format.DateTimeFormatter`. This applies to timestamp type. *
    • `maxColumns` (default `20480`): defines a hard limit of how many columns * a record can have.
    • *
    • `maxCharsPerColumn` (default `-1`): defines the maximum number of characters allowed diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index b9c4076994e96..981b3a8fd4ac1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -530,11 +530,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`).
    • *
    • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. - * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to - * date type.
    • + * Custom date formats follow the formats at `java.time.format.DateTimeFormatter`. + * This applies to date type. *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at - * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • + * `java.time.format.DateTimeFormatter`. This applies to timestamp type. *
    • `encoding` (by default it is not set): specifies encoding (charset) of saved json * files. If it is not set, the UTF-8 charset will be used.
    • *
    • `lineSep` (default `\n`): defines the line separator that should be used for writing.
    • @@ -649,11 +649,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`). *
    • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. - * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to - * date type.
    • + * Custom date formats follow the formats at `java.time.format.DateTimeFormatter`. + * This applies to date type. *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at - * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • + * `java.time.format.DateTimeFormatter`. This applies to timestamp type. *
    • `ignoreLeadingWhiteSpace` (default `true`): a flag indicating whether or not leading * whitespaces from values being written should be skipped.
    • *
    • `ignoreTrailingWhiteSpace` (default `true`): a flag indicating defines whether or not diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 914fa90ae7e14..98589da9552cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -286,11 +286,11 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
    • *
    • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. - * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to - * date type.
    • + * Custom date formats follow the formats at `java.time.format.DateTimeFormatter`. + * This applies to date type. *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at - * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • + * `java.time.format.DateTimeFormatter`. This applies to timestamp type. *
    • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
    • *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator @@ -347,11 +347,11 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
    • `negativeInf` (default `-Inf`): sets the string representation of a negative infinity * value.
    • *
    • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. - * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to - * date type.
    • + * Custom date formats follow the formats at `java.time.format.DateTimeFormatter`. + * This applies to date type. *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at - * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • + * `java.time.format.DateTimeFormatter`. This applies to timestamp type. *
    • `maxColumns` (default `20480`): defines a hard limit of how many columns * a record can have.
    • *
    • `maxCharsPerColumn` (default `-1`): defines the maximum number of characters allowed From 0523f5e378e69f406104fabaf3ebe913de976bdb Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 23 Dec 2018 21:09:44 -0800 Subject: [PATCH 2335/2461] [SPARK-14023][CORE][SQL] Don't reference 'field' in StructField errors for clarity in exceptions ## What changes were proposed in this pull request? Variation of https://github.com/apache/spark/pull/20500 I cheated by not referencing fields or columns at all as this exception propagates in contexts where both would be applicable. ## How was this patch tested? Existing tests Closes #23373 from srowen/SPARK-14023.2. Authored-by: Sean Owen Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/sql/types/StructType.scala | 17 +++++++---------- .../spark/sql/types/StructTypeSuite.scala | 8 ++++---- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 6e8bbde7787a6..e01d7c59cac52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -28,7 +28,6 @@ import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, truncatedString} -import org.apache.spark.util.Utils /** * A [[StructType]] object can be constructed by @@ -57,7 +56,7 @@ import org.apache.spark.util.Utils * * // If this struct does not have a field called "d", it throws an exception. * struct("d") - * // java.lang.IllegalArgumentException: Field "d" does not exist. + * // java.lang.IllegalArgumentException: d does not exist. * // ... * * // Extract multiple StructFields. Field names are provided in a set. @@ -69,7 +68,7 @@ import org.apache.spark.util.Utils * // Any names without matching fields will throw an exception. * // For the case shown below, an exception is thrown due to "d". * struct(Set("b", "c", "d")) - * // java.lang.IllegalArgumentException: Field "d" does not exist. + * // java.lang.IllegalArgumentException: d does not exist. * // ... * }}} * @@ -272,22 +271,21 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru def apply(name: String): StructField = { nameToField.getOrElse(name, throw new IllegalArgumentException( - s"""Field "$name" does not exist. - |Available fields: ${fieldNames.mkString(", ")}""".stripMargin)) + s"$name does not exist. Available: ${fieldNames.mkString(", ")}")) } /** * Returns a [[StructType]] containing [[StructField]]s of the given names, preserving the * original order of fields. * - * @throws IllegalArgumentException if a field cannot be found for any of the given names + * @throws IllegalArgumentException if at least one given field name does not exist */ def apply(names: Set[String]): StructType = { val nonExistFields = names -- fieldNamesSet if (nonExistFields.nonEmpty) { throw new IllegalArgumentException( - s"""Nonexistent field(s): ${nonExistFields.mkString(", ")}. - |Available fields: ${fieldNames.mkString(", ")}""".stripMargin) + s"${nonExistFields.mkString(", ")} do(es) not exist. " + + s"Available: ${fieldNames.mkString(", ")}") } // Preserve the original order of fields. StructType(fields.filter(f => names.contains(f.name))) @@ -301,8 +299,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru def fieldIndex(name: String): Int = { nameToIndex.getOrElse(name, throw new IllegalArgumentException( - s"""Field "$name" does not exist. - |Available fields: ${fieldNames.mkString(", ")}""".stripMargin)) + s"$name does not exist. Available: ${fieldNames.mkString(", ")}")) } private[sql] def getFieldIndex(name: String): Option[Int] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala index 53a78c94aa6fb..b4ce26be24de2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -22,21 +22,21 @@ import org.apache.spark.sql.types.StructType.fromDDL class StructTypeSuite extends SparkFunSuite { - val s = StructType.fromDDL("a INT, b STRING") + private val s = StructType.fromDDL("a INT, b STRING") test("lookup a single missing field should output existing fields") { val e = intercept[IllegalArgumentException](s("c")).getMessage - assert(e.contains("Available fields: a, b")) + assert(e.contains("Available: a, b")) } test("lookup a set of missing fields should output existing fields") { val e = intercept[IllegalArgumentException](s(Set("a", "c"))).getMessage - assert(e.contains("Available fields: a, b")) + assert(e.contains("Available: a, b")) } test("lookup fieldIndex for missing field should output existing fields") { val e = intercept[IllegalArgumentException](s.fieldIndex("c")).getMessage - assert(e.contains("Available fields: a, b")) + assert(e.contains("Available: a, b")) } test("SPARK-24849: toDDL - simple struct") { From 827383a97c11a61661440ff86ce0c3382a2a23b2 Mon Sep 17 00:00:00 2001 From: wangyanlin01 Date: Tue, 25 Dec 2018 15:53:42 +0800 Subject: [PATCH 2336/2461] [SPARK-26426][SQL] fix ExpresionInfo assert error in windows operation system. ## What changes were proposed in this pull request? fix ExpresionInfo assert error in windows operation system, when running unit tests. ## How was this patch tested? unit tests Closes #23363 from yanlin-Lynn/unit-test-windows. Authored-by: wangyanlin01 Signed-off-by: Hyukjin Kwon --- .../apache/spark/sql/catalyst/expressions/ExpressionInfo.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java index ab13ac9cc5483..d5a1b77c0ec81 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java @@ -79,7 +79,7 @@ public ExpressionInfo( assert name != null; assert arguments != null; assert examples != null; - assert examples.isEmpty() || examples.startsWith("\n Examples:"); + assert examples.isEmpty() || examples.startsWith(System.lineSeparator() + " Examples:"); assert note != null; assert since != null; From 7c7fccfeb5bc079fede41eb64f57ab6b1b4b9018 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 27 Dec 2018 11:09:50 +0800 Subject: [PATCH 2337/2461] [SPARK-26424][SQL] Use java.time API in date/timestamp expressions ## What changes were proposed in this pull request? In the PR, I propose to switch the `DateFormatClass`, `ToUnixTimestamp`, `FromUnixTime`, `UnixTime` on java.time API for parsing/formatting dates and timestamps. The API has been already implemented by the `Timestamp`/`DateFormatter` classes. One of benefit is those classes support parsing timestamps with microsecond precision. Old behaviour can be switched on via SQL config: `spark.sql.legacy.timeParser.enabled` (`false` by default). ## How was this patch tested? It was tested by existing test suites - `DateFunctionsSuite`, `DateExpressionsSuite`, `JsonSuite`, `CsvSuite`, `SQLQueryTestSuite` as well as PySpark tests. Closes #23358 from MaxGekk/new-time-cast. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: Wenchen Fan --- R/pkg/R/functions.R | 8 +- docs/sql-migration-guide-upgrade.md | 1 + python/pyspark/sql/functions.py | 6 +- .../sql/catalyst/csv/CSVInferSchema.scala | 3 +- .../expressions/datetimeExpressions.scala | 82 +++++++++++-------- .../sql/catalyst/json/JsonInferSchema.scala | 3 +- .../sql/catalyst/util/DateFormatter.scala | 8 +- .../util/DateTimeFormatterHelper.scala | 21 +++-- .../sql/catalyst/util/DateTimeUtils.scala | 10 --- .../catalyst/util/TimestampFormatter.scala | 22 ++++- .../catalyst/csv/UnivocityParserSuite.scala | 2 +- .../spark/sql/util/DateFormatterSuite.scala | 7 ++ .../sql/util/TimestampFormatterSuite.scala | 12 +++ .../org/apache/spark/sql/functions.scala | 10 +-- .../apache/spark/sql/DateFunctionsSuite.scala | 2 +- 15 files changed, 122 insertions(+), 75 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index f568a931ae1fe..5b3cc0940d9c3 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1723,7 +1723,7 @@ setMethod("radians", #' @details #' \code{to_date}: Converts the column into a DateType. You may optionally specify #' a format according to the rules in: -#' \url{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}. +#' \url{https://docs.oracle.com/javase/8/docs/api/java/time/format/DateTimeFormatter.html}. #' If the string cannot be parsed according to the specified format (or default), #' the value of the column will be null. #' By default, it follows casting rules to a DateType if the format is omitted @@ -1819,7 +1819,7 @@ setMethod("to_csv", signature(x = "Column"), #' @details #' \code{to_timestamp}: Converts the column into a TimestampType. You may optionally specify #' a format according to the rules in: -#' \url{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}. +#' \url{https://docs.oracle.com/javase/8/docs/api/java/time/format/DateTimeFormatter.html}. #' If the string cannot be parsed according to the specified format (or default), #' the value of the column will be null. #' By default, it follows casting rules to a TimestampType if the format is omitted @@ -2240,7 +2240,7 @@ setMethod("n", signature(x = "Column"), #' \code{date_format}: Converts a date/timestamp/string to a value of string in the format #' specified by the date format given by the second argument. A pattern could be for instance #' \code{dd.MM.yyyy} and could return a string like '18.03.1993'. All -#' pattern letters of \code{java.text.SimpleDateFormat} can be used. +#' pattern letters of \code{java.time.format.DateTimeFormatter} can be used. #' Note: Use when ever possible specialized functions like \code{year}. These benefit from a #' specialized implementation. #' @@ -2666,7 +2666,7 @@ setMethod("format_string", signature(format = "character", x = "Column"), #' \code{from_unixtime}: Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) #' to a string representing the timestamp of that moment in the current system time zone in the JVM #' in the given format. -#' See \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ +#' See \href{https://docs.oracle.com/javase/8/docs/api/java/time/format/DateTimeFormatter.html}{ #' Customizing Formats} for available options. #' #' @rdname column_datetime_functions diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 1bd3b5ad0e1aa..c4d2157de8b60 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -39,6 +39,7 @@ displayTitle: Spark SQL Upgrading Guide - In Spark version 2.4 and earlier, JSON datasource and JSON functions like `from_json` convert a bad JSON record to a row with all `null`s in the PERMISSIVE mode when specified schema is `StructType`. Since Spark 3.0, the returned row can contain non-`null` fields if some of JSON column values were parsed and converted to desired types successfully. + - Since Spark 3.0, the `unix_timestamp`, `date_format`, `to_unix_timestamp`, `from_unixtime`, `to_date`, `to_timestamp` functions use java.time API for parsing and formatting dates/timestamps from/to strings by using ISO chronology (https://docs.oracle.com/javase/8/docs/api/java/time/chrono/IsoChronology.html) based on Proleptic Gregorian calendar. In Spark version 2.4 and earlier, java.text.SimpleDateFormat and java.util.GregorianCalendar (hybrid calendar that supports both the Julian and Gregorian calendar systems, see https://docs.oracle.com/javase/7/docs/api/java/util/GregorianCalendar.html) is used for the same purpuse. New implementation supports pattern formats as described here https://docs.oracle.com/javase/8/docs/api/java/time/format/DateTimeFormatter.html and performs strict checking of its input. For example, the `2015-07-22 10:00:00` timestamp cannot be parse if pattern is `yyyy-MM-dd` because the parser does not consume whole input. Another example is the `31/01/2015 00:00` input cannot be parsed by the `dd/MM/yyyy hh:mm` pattern because `hh` supposes hours in the range `1-12`. To switch back to the implementation used in Spark 2.4 and earlier, set `spark.sql.legacy.timeParser.enabled` to `true`. ## Upgrading From Spark SQL 2.3 to 2.4 - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below. diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d188de39e21c7..d2a771e9bb8ea 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -874,7 +874,7 @@ def date_format(date, format): format given by the second argument. A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All - pattern letters of the Java class `java.text.SimpleDateFormat` can be used. + pattern letters of the Java class `java.time.format.DateTimeFormatter` can be used. .. note:: Use when ever possible specialized functions like `year`. These benefit from a specialized implementation. @@ -1094,7 +1094,7 @@ def to_date(col, format=None): """Converts a :class:`Column` of :class:`pyspark.sql.types.StringType` or :class:`pyspark.sql.types.TimestampType` into :class:`pyspark.sql.types.DateType` using the optionally specified format. Specify formats according to - `SimpleDateFormats `_. + `DateTimeFormatter `_. # noqa By default, it follows casting rules to :class:`pyspark.sql.types.DateType` if the format is omitted (equivalent to ``col.cast("date")``). @@ -1119,7 +1119,7 @@ def to_timestamp(col, format=None): """Converts a :class:`Column` of :class:`pyspark.sql.types.StringType` or :class:`pyspark.sql.types.TimestampType` into :class:`pyspark.sql.types.DateType` using the optionally specified format. Specify formats according to - `SimpleDateFormats `_. + `DateTimeFormatter `_. # noqa By default, it follows casting rules to :class:`pyspark.sql.types.TimestampType` if the format is omitted (equivalent to ``col.cast("timestamp")``). diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index 35ade136cc607..4dd41042856d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -27,8 +27,7 @@ import org.apache.spark.sql.types._ class CSVInferSchema(val options: CSVOptions) extends Serializable { - @transient - private lazy val timestampParser = TimestampFormatter( + private val timestampParser = TimestampFormatter( options.timestampFormat, options.timeZone, options.locale) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 45e17ae235a94..73af0a3c5c2ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.Timestamp -import java.text.DateFormat -import java.util.{Calendar, TimeZone} +import java.util.{Calendar, Locale, TimeZone} import scala.util.control.NonFatal @@ -28,7 +27,8 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -562,16 +562,17 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti copy(timeZoneId = Option(timeZoneId)) override protected def nullSafeEval(timestamp: Any, format: Any): Any = { - val df = DateTimeUtils.newDateFormat(format.toString, timeZone) - UTF8String.fromString(df.format(new java.util.Date(timestamp.asInstanceOf[Long] / 1000))) + val df = TimestampFormatter(format.toString, timeZone, Locale.US) + UTF8String.fromString(df.format(timestamp.asInstanceOf[Long])) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val tf = TimestampFormatter.getClass.getName.stripSuffix("$") val tz = ctx.addReferenceObj("timeZone", timeZone) + val locale = ctx.addReferenceObj("locale", Locale.US) defineCodeGen(ctx, ev, (timestamp, format) => { - s"""UTF8String.fromString($dtu.newDateFormat($format.toString(), $tz) - .format(new java.util.Date($timestamp / 1000)))""" + s"""UTF8String.fromString($tf.apply($format.toString(), $tz, $locale) + .format($timestamp))""" }) } @@ -612,9 +613,10 @@ case class ToUnixTimestamp( } /** - * Converts time string with given pattern. - * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) - * to Unix time stamp (in seconds), returns null if fail. + * Converts time string with given pattern to Unix time stamp (in seconds), returns null if fail. + * See [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html] + * if SQL config spark.sql.legacy.timeParser.enabled is set to true otherwise + * [https://docs.oracle.com/javase/8/docs/api/java/time/format/DateTimeFormatter.html]. * Note that hive Language Manual says it returns 0 if fail, but in fact it returns null. * If the second parameter is missing, use "yyyy-MM-dd HH:mm:ss". * If no parameters provided, the first parameter will be current_timestamp. @@ -663,9 +665,9 @@ abstract class UnixTime override def nullable: Boolean = true private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] - private lazy val formatter: DateFormat = + private lazy val formatter: TimestampFormatter = try { - DateTimeUtils.newDateFormat(constFormat.toString, timeZone) + TimestampFormatter(constFormat.toString, timeZone, Locale.US) } catch { case NonFatal(_) => null } @@ -677,16 +679,16 @@ abstract class UnixTime } else { left.dataType match { case DateType => - DateTimeUtils.daysToMillis(t.asInstanceOf[Int], timeZone) / 1000L + DateTimeUtils.daysToMillis(t.asInstanceOf[Int], timeZone) / MILLIS_PER_SECOND case TimestampType => - t.asInstanceOf[Long] / 1000000L + t.asInstanceOf[Long] / MICROS_PER_SECOND case StringType if right.foldable => if (constFormat == null || formatter == null) { null } else { try { formatter.parse( - t.asInstanceOf[UTF8String].toString).getTime / 1000L + t.asInstanceOf[UTF8String].toString) / MICROS_PER_SECOND } catch { case NonFatal(_) => null } @@ -698,8 +700,8 @@ abstract class UnixTime } else { val formatString = f.asInstanceOf[UTF8String].toString try { - DateTimeUtils.newDateFormat(formatString, timeZone).parse( - t.asInstanceOf[UTF8String].toString).getTime / 1000L + TimestampFormatter(formatString, timeZone, Locale.US).parse( + t.asInstanceOf[UTF8String].toString) / MICROS_PER_SECOND } catch { case NonFatal(_) => null } @@ -712,7 +714,7 @@ abstract class UnixTime val javaType = CodeGenerator.javaType(dataType) left.dataType match { case StringType if right.foldable => - val df = classOf[DateFormat].getName + val df = classOf[TimestampFormatter].getName if (formatter == null) { ExprCode.forNullValue(dataType) } else { @@ -724,24 +726,35 @@ abstract class UnixTime $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { try { - ${ev.value} = $formatterName.parse(${eval1.value}.toString()).getTime() / 1000L; + ${ev.value} = $formatterName.parse(${eval1.value}.toString()) / 1000000L; + } catch (java.lang.IllegalArgumentException e) { + ${ev.isNull} = true; } catch (java.text.ParseException e) { ${ev.isNull} = true; + } catch (java.time.format.DateTimeParseException e) { + ${ev.isNull} = true; + } catch (java.time.DateTimeException e) { + ${ev.isNull} = true; } }""") } case StringType => val tz = ctx.addReferenceObj("timeZone", timeZone) - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val locale = ctx.addReferenceObj("locale", Locale.US) + val dtu = TimestampFormatter.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (string, format) => { s""" try { - ${ev.value} = $dtu.newDateFormat($format.toString(), $tz) - .parse($string.toString()).getTime() / 1000L; + ${ev.value} = $dtu.apply($format.toString(), $tz, $locale) + .parse($string.toString()) / 1000000L; } catch (java.lang.IllegalArgumentException e) { ${ev.isNull} = true; } catch (java.text.ParseException e) { ${ev.isNull} = true; + } catch (java.time.format.DateTimeParseException e) { + ${ev.isNull} = true; + } catch (java.time.DateTimeException e) { + ${ev.isNull} = true; } """ }) @@ -806,9 +819,9 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ copy(timeZoneId = Option(timeZoneId)) private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] - private lazy val formatter: DateFormat = + private lazy val formatter: TimestampFormatter = try { - DateTimeUtils.newDateFormat(constFormat.toString, timeZone) + TimestampFormatter(constFormat.toString, timeZone, Locale.US) } catch { case NonFatal(_) => null } @@ -823,8 +836,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ null } else { try { - UTF8String.fromString(formatter.format( - new java.util.Date(time.asInstanceOf[Long] * 1000L))) + UTF8String.fromString(formatter.format(time.asInstanceOf[Long] * MICROS_PER_SECOND)) } catch { case NonFatal(_) => null } @@ -835,8 +847,8 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ null } else { try { - UTF8String.fromString(DateTimeUtils.newDateFormat(f.toString, timeZone) - .format(new java.util.Date(time.asInstanceOf[Long] * 1000L))) + UTF8String.fromString(TimestampFormatter(f.toString, timeZone, Locale.US) + .format(time.asInstanceOf[Long] * MICROS_PER_SECOND)) } catch { case NonFatal(_) => null } @@ -846,7 +858,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val df = classOf[DateFormat].getName + val df = classOf[TimestampFormatter].getName if (format.foldable) { if (formatter == null) { ExprCode.forNullValue(StringType) @@ -859,8 +871,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { try { - ${ev.value} = UTF8String.fromString($formatterName.format( - new java.util.Date(${t.value} * 1000L))); + ${ev.value} = UTF8String.fromString($formatterName.format(${t.value} * 1000000L)); } catch (java.lang.IllegalArgumentException e) { ${ev.isNull} = true; } @@ -868,12 +879,13 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ } } else { val tz = ctx.addReferenceObj("timeZone", timeZone) - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val locale = ctx.addReferenceObj("locale", Locale.US) + val tf = TimestampFormatter.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (seconds, f) => { s""" try { - ${ev.value} = UTF8String.fromString($dtu.newDateFormat($f.toString(), $tz).format( - new java.util.Date($seconds * 1000L))); + ${ev.value} = UTF8String.fromString($tf.apply($f.toString(), $tz, $locale). + format($seconds * 1000000L)); } catch (java.lang.IllegalArgumentException e) { ${ev.isNull} = true; }""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index d1bc00c08c1c6..3203e626ea400 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -37,8 +37,7 @@ private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable { private val decimalParser = ExprUtils.getDecimalParser(options.locale) - @transient - private lazy val timestampFormatter = TimestampFormatter( + private val timestampFormatter = TimestampFormatter( options.timestampFormat, options.timeZone, options.locale) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala index 9e8d51cc65f03..b4c99674fc1cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala @@ -26,7 +26,7 @@ import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.sql.internal.SQLConf -sealed trait DateFormatter { +sealed trait DateFormatter extends Serializable { def parse(s: String): Int // returns days since epoch def format(days: Int): String } @@ -35,7 +35,8 @@ class Iso8601DateFormatter( pattern: String, locale: Locale) extends DateFormatter with DateTimeFormatterHelper { - private val formatter = buildFormatter(pattern, locale) + @transient + private lazy val formatter = buildFormatter(pattern, locale) private val UTC = ZoneId.of("UTC") private def toInstant(s: String): Instant = { @@ -56,7 +57,8 @@ class Iso8601DateFormatter( } class LegacyDateFormatter(pattern: String, locale: Locale) extends DateFormatter { - private val format = FastDateFormat.getInstance(pattern, locale) + @transient + private lazy val format = FastDateFormat.getInstance(pattern, locale) override def parse(s: String): Int = { val milliseconds = format.parse(s).getTime diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala index b85101d38d9e6..91cc57e0bb019 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala @@ -17,27 +17,36 @@ package org.apache.spark.sql.catalyst.util -import java.time.{Instant, LocalDateTime, ZonedDateTime, ZoneId} -import java.time.format.{DateTimeFormatter, DateTimeFormatterBuilder} -import java.time.temporal.{ChronoField, TemporalAccessor} +import java.time._ +import java.time.chrono.IsoChronology +import java.time.format.{DateTimeFormatter, DateTimeFormatterBuilder, ResolverStyle} +import java.time.temporal.{ChronoField, TemporalAccessor, TemporalQueries} import java.util.Locale trait DateTimeFormatterHelper { protected def buildFormatter(pattern: String, locale: Locale): DateTimeFormatter = { new DateTimeFormatterBuilder() + .parseCaseInsensitive() .appendPattern(pattern) - .parseDefaulting(ChronoField.YEAR_OF_ERA, 1970) + .parseDefaulting(ChronoField.ERA, 1) .parseDefaulting(ChronoField.MONTH_OF_YEAR, 1) .parseDefaulting(ChronoField.DAY_OF_MONTH, 1) - .parseDefaulting(ChronoField.HOUR_OF_DAY, 0) .parseDefaulting(ChronoField.MINUTE_OF_HOUR, 0) .parseDefaulting(ChronoField.SECOND_OF_MINUTE, 0) .toFormatter(locale) + .withChronology(IsoChronology.INSTANCE) + .withResolverStyle(ResolverStyle.STRICT) } protected def toInstantWithZoneId(temporalAccessor: TemporalAccessor, zoneId: ZoneId): Instant = { - val localDateTime = LocalDateTime.from(temporalAccessor) + val localTime = if (temporalAccessor.query(TemporalQueries.localTime) == null) { + LocalTime.ofNanoOfDay(0) + } else { + LocalTime.from(temporalAccessor) + } + val localDate = LocalDate.from(temporalAccessor) + val localDateTime = LocalDateTime.of(localDate, localTime) val zonedDateTime = ZonedDateTime.of(localDateTime, zoneId) Instant.from(zonedDateTime) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index c6dfdbf2505ba..3e5e1fbc2b368 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -111,16 +111,6 @@ object DateTimeUtils { computedTimeZones.computeIfAbsent(timeZoneId, computeTimeZone) } - def newDateFormat(formatString: String, timeZone: TimeZone): DateFormat = { - val sdf = new SimpleDateFormat(formatString, Locale.US) - sdf.setTimeZone(timeZone) - // Enable strict parsing, if the input date/format is invalid, it will throw an exception. - // e.g. to parse invalid date '2016-13-12', or '2016-01-12' with invalid format 'yyyy-aa-dd', - // an exception will be throwed. - sdf.setLenient(false) - sdf - } - // we should use the exact day as Int, for example, (year, month, day) -> day def millisToDays(millisUtc: Long): SQLDate = { millisToDays(millisUtc, defaultTimeZone()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala index eb1303303463d..b67b2d7cc3c51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.catalyst.util +import java.text.ParseException import java.time._ +import java.time.format.DateTimeParseException import java.time.temporal.TemporalQueries import java.util.{Locale, TimeZone} @@ -27,7 +29,19 @@ import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.sql.internal.SQLConf -sealed trait TimestampFormatter { +sealed trait TimestampFormatter extends Serializable { + /** + * Parses a timestamp in a string and converts it to microseconds. + * + * @param s - string with timestamp to parse + * @return microseconds since epoch. + * @throws ParseException can be thrown by legacy parser + * @throws DateTimeParseException can be thrown by new parser + * @throws DateTimeException unable to obtain local date or time + */ + @throws(classOf[ParseException]) + @throws(classOf[DateTimeParseException]) + @throws(classOf[DateTimeException]) def parse(s: String): Long // returns microseconds since epoch def format(us: Long): String } @@ -36,7 +50,8 @@ class Iso8601TimestampFormatter( pattern: String, timeZone: TimeZone, locale: Locale) extends TimestampFormatter with DateTimeFormatterHelper { - private val formatter = buildFormatter(pattern, locale) + @transient + private lazy val formatter = buildFormatter(pattern, locale) private def toInstant(s: String): Instant = { val temporalAccessor = formatter.parse(s) @@ -68,7 +83,8 @@ class LegacyTimestampFormatter( pattern: String, timeZone: TimeZone, locale: Locale) extends TimestampFormatter { - private val format = FastDateFormat.getInstance(pattern, timeZone, locale) + @transient + private lazy val format = FastDateFormat.getInstance(pattern, timeZone, locale) protected def toMillis(s: String): Long = format.parse(s).getTime diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala index 2d0b0d3033a9c..4ae61bc61255c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala @@ -112,7 +112,7 @@ class UnivocityParserSuite extends SparkFunSuite with SQLHelper { assert(parser.makeConverter("_1", BooleanType).apply("true") == true) var timestampsOptions = - new CSVOptions(Map("timestampFormat" -> "dd/MM/yyyy hh:mm"), false, "GMT") + new CSVOptions(Map("timestampFormat" -> "dd/MM/yyyy HH:mm"), false, "GMT") parser = new UnivocityParser(StructType(Seq.empty), timestampsOptions) val customTimestamp = "31/01/2015 00:00" var format = FastDateFormat.getInstance( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala index 019615b81101c..2dc55e0e1f633 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.util +import java.time.LocalDate import java.util.Locale import org.apache.spark.SparkFunSuite @@ -89,4 +90,10 @@ class DateFormatterSuite extends SparkFunSuite with SQLHelper { } } } + + test("parsing date without explicit day") { + val formatter = DateFormatter("yyyy MMM", Locale.US) + val daysSinceEpoch = formatter.parse("2018 Dec") + assert(daysSinceEpoch === LocalDate.of(2018, 12, 1).toEpochDay) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala index c110ffa01f733..edccbb2a7f5db 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.util +import java.time.{LocalDateTime, ZoneOffset} import java.util.{Locale, TimeZone} +import java.util.concurrent.TimeUnit import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.plans.SQLHelper @@ -106,4 +108,14 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper { } } } + + test(" case insensitive parsing of am and pm") { + val formatter = TimestampFormatter( + "yyyy MMM dd hh:mm:ss a", + TimeZone.getTimeZone("UTC"), + Locale.US) + val micros = formatter.parse("2009 Mar 20 11:30:01 am") + assert(micros === TimeUnit.SECONDS.toMicros( + LocalDateTime.of(2009, 3, 20, 11, 30, 1).toEpochSecond(ZoneOffset.UTC))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 33186f778d868..645452553e6a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2578,7 +2578,7 @@ object functions { * Converts a date/timestamp/string to a value of string in the format specified by the date * format given by the second argument. * - * See [[java.text.SimpleDateFormat]] for valid date and time format patterns + * See [[java.time.format.DateTimeFormatter]] for valid date and time format patterns * * @param dateExpr A date, timestamp or string. If a string, the data must be in a format that * can be cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` @@ -2811,7 +2811,7 @@ object functions { * representing the timestamp of that moment in the current system time zone in the given * format. * - * See [[java.text.SimpleDateFormat]] for valid date and time format patterns + * See [[java.time.format.DateTimeFormatter]] for valid date and time format patterns * * @param ut A number of a type that is castable to a long, such as string or integer. Can be * negative for timestamps before the unix epoch @@ -2855,7 +2855,7 @@ object functions { /** * Converts time string with given pattern to Unix timestamp (in seconds). * - * See [[java.text.SimpleDateFormat]] for valid date and time format patterns + * See [[java.time.format.DateTimeFormatter]] for valid date and time format patterns * * @param s A date, timestamp or string. If a string, the data must be in a format that can be * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` @@ -2883,7 +2883,7 @@ object functions { /** * Converts time string with the given pattern to timestamp. * - * See [[java.text.SimpleDateFormat]] for valid date and time format patterns + * See [[java.time.format.DateTimeFormatter]] for valid date and time format patterns * * @param s A date, timestamp or string. If a string, the data must be in a format that can be * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` @@ -2908,7 +2908,7 @@ object functions { /** * Converts the column into a `DateType` with a specified format * - * See [[java.text.SimpleDateFormat]] for valid date and time format patterns + * See [[java.time.format.DateTimeFormatter]] for valid date and time format patterns * * @param e A date, timestamp or string. If a string, the data must be in a format that can be * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index c4ec7150c4075..62bb72dd6ea25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -405,7 +405,7 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(Date.valueOf("2014-12-31")))) checkAnswer( df.select(to_date(col("s"), "yyyy-MM-dd")), - Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")), Row(null))) + Seq(Row(null), Row(Date.valueOf("2014-12-31")), Row(null))) // now switch format checkAnswer( From f89cdec8b9a9fcc95ba7458869b4ba9d038560f9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 27 Dec 2018 16:03:14 +0800 Subject: [PATCH 2338/2461] [SPARK-26435][SQL] Support creating partitioned table using Hive CTAS by specifying partition column names ## What changes were proposed in this pull request? Spark SQL doesn't support creating partitioned table using Hive CTAS in SQL syntax. However it is supported by using DataFrameWriter API. ```scala val df = Seq(("a", 1)).toDF("part", "id") df.write.format("hive").partitionBy("part").saveAsTable("t") ``` Hive begins to support this syntax in newer version: https://issues.apache.org/jira/browse/HIVE-20241: ``` CREATE TABLE t PARTITIONED BY (part) AS SELECT 1 as id, "a" as part ``` This patch adds this support to SQL syntax. ## How was this patch tested? Added tests. Closes #23376 from viirya/hive-ctas-partitioned-table. Authored-by: Liang-Chi Hsieh Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/parser/SqlBase.g4 | 3 +- .../spark/sql/execution/SparkSqlParser.scala | 33 ++++++++----- .../sql/hive/execution/HiveDDLSuite.scala | 48 ++++++++++++++++++- .../sql/hive/execution/SQLQuerySuite.scala | 4 +- 4 files changed, 71 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 5e732edb17baa..b39681d886c5c 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -88,7 +88,8 @@ statement (AS? query)? #createTable | createTableHeader ('(' columns=colTypeList ')')? ((COMMENT comment=STRING) | - (PARTITIONED BY '(' partitionColumns=colTypeList ')') | + (PARTITIONED BY '(' partitionColumns=colTypeList ')' | + PARTITIONED BY partitionColumnNames=identifierList) | bucketSpec | skewSpec | rowFormat | diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 364efea52830e..8deb55b00a9d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -1196,33 +1196,40 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { selectQuery match { case Some(q) => - // Hive does not allow to use a CTAS statement to create a partitioned table. - if (tableDesc.partitionColumnNames.nonEmpty) { - val errorMessage = "A Create Table As Select (CTAS) statement is not allowed to " + - "create a partitioned table using Hive's file formats. " + - "Please use the syntax of \"CREATE TABLE tableName USING dataSource " + - "OPTIONS (...) PARTITIONED BY ...\" to create a partitioned table through a " + - "CTAS statement." - operationNotAllowed(errorMessage, ctx) - } - // Don't allow explicit specification of schema for CTAS. - if (schema.nonEmpty) { + if (dataCols.nonEmpty) { operationNotAllowed( "Schema may not be specified in a Create Table As Select (CTAS) statement", ctx) } + // When creating partitioned table with CTAS statement, we can't specify data type for the + // partition columns. + if (partitionCols.nonEmpty) { + val errorMessage = "Create Partitioned Table As Select cannot specify data type for " + + "the partition columns of the target table." + operationNotAllowed(errorMessage, ctx) + } + + // Hive CTAS supports dynamic partition by specifying partition column names. + val partitionColumnNames = + Option(ctx.partitionColumnNames) + .map(visitIdentifierList(_).toArray) + .getOrElse(Array.empty[String]) + + val tableDescWithPartitionColNames = + tableDesc.copy(partitionColumnNames = partitionColumnNames) + val hasStorageProperties = (ctx.createFileFormat.size != 0) || (ctx.rowFormat.size != 0) if (conf.convertCTAS && !hasStorageProperties) { // At here, both rowStorage.serdeProperties and fileStorage.serdeProperties // are empty Maps. - val newTableDesc = tableDesc.copy( + val newTableDesc = tableDescWithPartitionColNames.copy( storage = CatalogStorageFormat.empty.copy(locationUri = locUri), provider = Some(conf.defaultDataSourceName)) CreateTable(newTableDesc, mode, Some(q)) } else { - CreateTable(tableDesc, mode, Some(q)) + CreateTable(tableDescWithPartitionColNames, mode, Some(q)) } case None => CreateTable(tableDesc, mode, None) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index fd38944a5dd2e..6abdc4054cb0c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.net.URI -import java.util.Date import scala.language.existentials @@ -33,6 +32,7 @@ import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils} import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.HiveExternalCatalog @@ -2370,4 +2370,50 @@ class HiveDDLSuite )) } } + + test("Hive CTAS can't create partitioned table by specifying schema") { + val err1 = intercept[ParseException] { + spark.sql( + s""" + |CREATE TABLE t (a int) + |PARTITIONED BY (b string) + |STORED AS parquet + |AS SELECT 1 as a, "a" as b + """.stripMargin) + }.getMessage + assert(err1.contains("Schema may not be specified in a Create Table As Select " + + "(CTAS) statement")) + + val err2 = intercept[ParseException] { + spark.sql( + s""" + |CREATE TABLE t + |PARTITIONED BY (b string) + |STORED AS parquet + |AS SELECT 1 as a, "a" as b + """.stripMargin) + }.getMessage + assert(err2.contains("Create Partitioned Table As Select cannot specify data type for " + + "the partition columns of the target table")) + } + + test("Hive CTAS with dynamic partition") { + Seq("orc", "parquet").foreach { format => + withTable("t") { + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + spark.sql( + s""" + |CREATE TABLE t + |PARTITIONED BY (b) + |STORED AS $format + |AS SELECT 1 as a, "a" as b + """.stripMargin) + checkAnswer(spark.table("t"), Row(1, "a")) + + assert(spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + .partitionColumnNames === Seq("b")) + } + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 6acf44606cbbe..70efad103d13e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -692,8 +692,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { |AS SELECT key, value FROM mytable1 """.stripMargin) }.getMessage - assert(e.contains("A Create Table As Select (CTAS) statement is not allowed to " + - "create a partitioned table using Hive's file formats")) + assert(e.contains("Create Partitioned Table As Select cannot specify data type for " + + "the partition columns of the target table")) } } } From a1c1dd3484a4dcd7c38fe256e69dbaaaf10d1a92 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 27 Dec 2018 11:13:16 +0100 Subject: [PATCH 2339/2461] [SPARK-26191][SQL] Control truncation of Spark plans via maxFields parameter ## What changes were proposed in this pull request? In the PR, I propose to add `maxFields` parameter to all functions involved in creation of textual representation of spark plans such as `simpleString` and `verboseString`. New parameter restricts number of fields converted to truncated strings. Any elements beyond the limit will be dropped and replaced by a `"... N more fields"` placeholder. The threshold is bumped up to `Int.MaxValue` for `toFile()`. ## How was this patch tested? Added a test to `QueryExecutionSuite` which checks `maxFields` impacts on number of truncated fields in `LocalRelation`. Closes #23159 from MaxGekk/to-file-max-fields. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: Herman van Hovell --- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 8 ++- .../sql/catalyst/analysis/TypeCoercion.scala | 4 +- .../catalyst/encoders/ExpressionEncoder.scala | 8 ++- .../sql/catalyst/expressions/Expression.scala | 6 +- .../expressions/codegen/javaCode.scala | 2 +- .../sql/catalyst/expressions/generators.scala | 3 +- .../expressions/higherOrderFunctions.scala | 4 +- .../spark/sql/catalyst/expressions/misc.scala | 5 +- .../expressions/namedExpressions.scala | 4 +- .../spark/sql/catalyst/plans/QueryPlan.scala | 4 +- .../catalyst/plans/logical/LogicalPlan.scala | 4 +- .../plans/logical/basicLogicalOperators.scala | 8 +-- .../spark/sql/catalyst/trees/TreeNode.scala | 56 +++++++++++-------- .../spark/sql/catalyst/util/package.scala | 10 ++-- .../apache/spark/sql/types/StructType.scala | 6 +- .../aggregate/PercentileSuite.scala | 2 +- .../org/apache/spark/sql/util/UtilSuite.scala | 2 +- .../sql/execution/DataSourceScanExec.scala | 12 ++-- .../spark/sql/execution/ExistingRDD.scala | 6 +- .../spark/sql/execution/QueryExecution.scala | 21 ++++--- .../spark/sql/execution/SparkPlanInfo.scala | 6 +- .../sql/execution/WholeStageCodegenExec.scala | 28 ++++++++-- .../aggregate/HashAggregateExec.scala | 12 ++-- .../aggregate/ObjectHashAggregateExec.scala | 12 ++-- .../aggregate/SortAggregateExec.scala | 12 ++-- .../execution/basicPhysicalOperators.scala | 4 +- .../execution/columnar/InMemoryRelation.scala | 4 +- .../datasources/LogicalRelation.scala | 4 +- .../SaveIntoDataSourceCommand.scala | 2 +- .../spark/sql/execution/datasources/ddl.scala | 2 +- .../datasources/v2/DataSourceV2Relation.scala | 8 ++- .../datasources/v2/DataSourceV2ScanExec.scala | 4 +- .../v2/DataSourceV2StreamingScanExec.scala | 2 +- .../v2/DataSourceV2StringFormat.scala | 6 +- .../spark/sql/execution/debug/package.scala | 3 +- .../apache/spark/sql/execution/limit.scala | 6 +- .../streaming/MicroBatchExecution.scala | 6 +- .../continuous/ContinuousExecution.scala | 7 ++- .../sql/execution/streaming/memory.scala | 5 +- .../apache/spark/sql/execution/subquery.scala | 2 +- .../sql/execution/QueryExecutionSuite.scala | 13 +++++ .../CreateHiveTableAsSelectCommand.scala | 2 +- 43 files changed, 203 insertions(+), 126 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 777053168a056..198645d875c47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -979,7 +979,7 @@ class Analyzer( a.mapExpressions(resolveExpressionTopDown(_, appendColumns)) case q: LogicalPlan => - logTrace(s"Attempting to resolve ${q.simpleString}") + logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}") q.mapExpressions(resolveExpressionTopDown(_, q)) } @@ -1777,7 +1777,7 @@ class Analyzer( case p if p.expressions.exists(hasGenerator) => throw new AnalysisException("Generators are not supported outside the SELECT clause, but " + - "got: " + p.simpleString) + "got: " + p.simpleString(SQLConf.get.maxToStringFields)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 88d41e8824405..c28a97839fe49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -303,7 +304,7 @@ trait CheckAnalysis extends PredicateHelper { val missingAttributes = o.missingInput.mkString(",") val input = o.inputSet.mkString(",") val msgForMissingAttributes = s"Resolved attribute(s) $missingAttributes missing " + - s"from $input in operator ${operator.simpleString}." + s"from $input in operator ${operator.simpleString(SQLConf.get.maxToStringFields)}." val resolver = plan.conf.resolver val attrsWithSameName = o.missingInput.filter { missing => @@ -368,7 +369,7 @@ trait CheckAnalysis extends PredicateHelper { s"""nondeterministic expressions are only allowed in |Project, Filter, Aggregate or Window, found: | ${o.expressions.map(_.sql).mkString(",")} - |in operator ${operator.simpleString} + |in operator ${operator.simpleString(SQLConf.get.maxToStringFields)} """.stripMargin) case _: UnresolvedHint => @@ -380,7 +381,8 @@ trait CheckAnalysis extends PredicateHelper { } extendedCheckRules.foreach(_(plan)) plan.foreachUp { - case o if !o.resolved => failAnalysis(s"unresolved operator ${o.simpleString}") + case o if !o.resolved => + failAnalysis(s"unresolved operator ${o.simpleString(SQLConf.get.maxToStringFields)}") case _ => } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 1706b3eece6d7..b19aa50ba2156 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -1069,8 +1069,8 @@ trait TypeCoercionRule extends Rule[LogicalPlan] with Logging { // Leave the same if the dataTypes match. case Some(newType) if a.dataType == newType.dataType => a case Some(newType) => - logDebug( - s"Promoting $a from ${a.dataType} to ${newType.dataType} in ${q.simpleString}") + logDebug(s"Promoting $a from ${a.dataType} to ${newType.dataType} in " + + s" ${q.simpleString(SQLConf.get.maxToStringFields)}") newType } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index fbf0bd68b9584..da5c1fd0feb01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, Invoke, NewInstance} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ObjectType, StringType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -323,8 +324,8 @@ case class ExpressionEncoder[T]( extractProjection(inputRow) } catch { case e: Exception => - throw new RuntimeException( - s"Error while encoding: $e\n${serializer.map(_.simpleString).mkString("\n")}", e) + throw new RuntimeException(s"Error while encoding: $e\n" + + s"${serializer.map(_.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")}", e) } /** @@ -336,7 +337,8 @@ case class ExpressionEncoder[T]( constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T] } catch { case e: Exception => - throw new RuntimeException(s"Error while decoding: $e\n${deserializer.simpleString}", e) + throw new RuntimeException(s"Error while decoding: $e\n" + + s"${deserializer.simpleString(SQLConf.get.maxToStringFields)}", e) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c89c2272be752..d5d119543da77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -259,12 +259,12 @@ abstract class Expression extends TreeNode[Expression] { // Marks this as final, Expression.verboseString should never be called, and thus shouldn't be // overridden by concrete classes. - final override def verboseString: String = simpleString + final override def verboseString(maxFields: Int): String = simpleString(maxFields) - override def simpleString: String = toString + override def simpleString(maxFields: Int): String = toString override def toString: String = prettyName + truncatedString( - flatArguments.toSeq, "(", ", ", ")") + flatArguments.toSeq, "(", ", ", ")", SQLConf.get.maxToStringFields) /** * Returns SQL representation of this expression. For expressions extending [[NonSQLExpression]], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 17d4a0dc4e884..17fff64a1b7df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -197,7 +197,7 @@ trait Block extends TreeNode[Block] with JavaCode { case _ => code"$this\n$other" } - override def verboseString: String = toString + override def verboseString(maxFields: Int): String = toString } object Block { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 9c74fdf6c9a14..6b6da1c8b4142 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -101,7 +102,7 @@ case class UserDefinedGenerator( inputRow = new InterpretedProjection(children) convertToScala = { val inputSchema = StructType(children.map { e => - StructField(e.simpleString, e.dataType, nullable = true) + StructField(e.simpleString(SQLConf.get.maxToStringFields), e.dataType, nullable = true) }) CatalystTypeConverters.createToScalaConverter(inputSchema) }.asInstanceOf[InternalRow => Row] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 7141b6e996389..e6cc11d1ad280 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -76,7 +76,9 @@ case class NamedLambdaVariable( override def toString: String = s"lambda $name#${exprId.id}$typeSuffix" - override def simpleString: String = s"lambda $name#${exprId.id}: ${dataType.simpleString}" + override def simpleString(maxFields: Int): String = { + s"lambda $name#${exprId.id}: ${dataType.simpleString(maxFields)}" + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 0cdeda9b10516..1f1decc45a3f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -40,7 +41,7 @@ case class PrintToStderr(child: Expression) extends UnaryExpression { input } - private val outputPrefix = s"Result of ${child.simpleString} is " + private val outputPrefix = s"Result of ${child.simpleString(SQLConf.get.maxToStringFields)} is " override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val outputPrefixField = ctx.addReferenceObj("outputPrefix", outputPrefix) @@ -72,7 +73,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa override def prettyName: String = "assert_true" - private val errMsg = s"'${child.simpleString}' is not true!" + private val errMsg = s"'${child.simpleString(SQLConf.get.maxToStringFields)}' is not true!" override def eval(input: InternalRow) : Any = { val v = child.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 02b48f9e30f2d..131459bf27bc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -330,7 +330,9 @@ case class AttributeReference( // Since the expression id is not in the first constructor it is missing from the default // tree string. - override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}" + override def simpleString(maxFields: Int): String = { + s"$name#${exprId.id}: ${dataType.simpleString(maxFields)}" + } override def sql: String = { val qualifierPrefix = if (qualifier.nonEmpty) qualifier.mkString(".") + "." else "" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index ca0cea6ba7de3..125181fb213f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -172,9 +172,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT */ protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else "" - override def simpleString: String = statePrefix + super.simpleString + override def simpleString(maxFields: Int): String = statePrefix + super.simpleString(maxFields) - override def verboseString: String = simpleString + override def verboseString(maxFields: Int): String = simpleString(maxFields) /** * All the subqueries of current plan. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 3ad2ee6923615..51e0f4b4c84dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -36,8 +36,8 @@ abstract class LogicalPlan /** Returns true if this subtree has data from a streaming data source. */ def isStreaming: Boolean = children.exists(_.isStreaming == true) - override def verboseStringWithSuffix: String = { - super.verboseString + statsCache.map(", " + _.toString).getOrElse("") + override def verboseStringWithSuffix(maxFields: Int): String = { + super.verboseString(maxFields) + statsCache.map(", " + _.toString).getOrElse("") } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index a26ec4eed8648..d8b3a4af4f7bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -468,7 +468,7 @@ case class View( override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance())) - override def simpleString: String = { + override def simpleString(maxFields: Int): String = { s"View (${desc.identifier}, ${output.mkString("[", ",", "]")})" } } @@ -484,8 +484,8 @@ case class View( case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)]) extends UnaryNode { override def output: Seq[Attribute] = child.output - override def simpleString: String = { - val cteAliases = truncatedString(cteRelations.map(_._1), "[", ", ", "]") + override def simpleString(maxFields: Int): String = { + val cteAliases = truncatedString(cteRelations.map(_._1), "[", ", ", "]", maxFields) s"CTE $cteAliases" } @@ -557,7 +557,7 @@ case class Range( override def newInstance(): Range = copy(output = output.map(_.newInstance())) - override def simpleString: String = { + override def simpleString(maxFields: Int): String = { s"Range ($start, $end, step=$step, splits=$numSlices)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 2e9f9f53e94ac..21e59bbd283e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -433,17 +434,17 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { private lazy val allChildren: Set[TreeNode[_]] = (children ++ innerChildren).toSet[TreeNode[_]] /** Returns a string representing the arguments to this node, minus any children */ - def argString: String = stringArgs.flatMap { + def argString(maxFields: Int): String = stringArgs.flatMap { case tn: TreeNode[_] if allChildren.contains(tn) => Nil case Some(tn: TreeNode[_]) if allChildren.contains(tn) => Nil - case Some(tn: TreeNode[_]) => tn.simpleString :: Nil - case tn: TreeNode[_] => tn.simpleString :: Nil + case Some(tn: TreeNode[_]) => tn.simpleString(maxFields) :: Nil + case tn: TreeNode[_] => tn.simpleString(maxFields) :: Nil case seq: Seq[Any] if seq.toSet.subsetOf(allChildren.asInstanceOf[Set[Any]]) => Nil case iter: Iterable[_] if iter.isEmpty => Nil - case seq: Seq[_] => truncatedString(seq, "[", ", ", "]") :: Nil - case set: Set[_] => truncatedString(set.toSeq, "{", ", ", "}") :: Nil + case seq: Seq[_] => truncatedString(seq, "[", ", ", "]", maxFields) :: Nil + case set: Set[_] => truncatedString(set.toSeq, "{", ", ", "}", maxFields) :: Nil case array: Array[_] if array.isEmpty => Nil - case array: Array[_] => truncatedString(array, "[", ", ", "]") :: Nil + case array: Array[_] => truncatedString(array, "[", ", ", "]", maxFields) :: Nil case null => Nil case None => Nil case Some(null) => Nil @@ -456,24 +457,33 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case other => other :: Nil }.mkString(", ") - /** ONE line description of this node. */ - def simpleString: String = s"$nodeName $argString".trim + /** + * ONE line description of this node. + * @param maxFields Maximum number of fields that will be converted to strings. + * Any elements beyond the limit will be dropped. + */ + def simpleString(maxFields: Int): String = { + s"$nodeName ${argString(maxFields)}".trim + } /** ONE line description of this node with more information */ - def verboseString: String + def verboseString(maxFields: Int): String /** ONE line description of this node with some suffix information */ - def verboseStringWithSuffix: String = verboseString + def verboseStringWithSuffix(maxFields: Int): String = verboseString(maxFields) override def toString: String = treeString /** Returns a string representation of the nodes in this tree */ def treeString: String = treeString(verbose = true) - def treeString(verbose: Boolean, addSuffix: Boolean = false): String = { + def treeString( + verbose: Boolean, + addSuffix: Boolean = false, + maxFields: Int = SQLConf.get.maxToStringFields): String = { val writer = new StringBuilderWriter() try { - treeString(writer, verbose, addSuffix) + treeString(writer, verbose, addSuffix, maxFields) writer.toString } finally { writer.close() @@ -483,8 +493,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { def treeString( writer: Writer, verbose: Boolean, - addSuffix: Boolean): Unit = { - generateTreeString(0, Nil, writer, verbose, "", addSuffix) + addSuffix: Boolean, + maxFields: Int): Unit = { + generateTreeString(0, Nil, writer, verbose, "", addSuffix, maxFields) } /** @@ -550,7 +561,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { writer: Writer, verbose: Boolean, prefix: String = "", - addSuffix: Boolean = false): Unit = { + addSuffix: Boolean = false, + maxFields: Int): Unit = { if (depth > 0) { lastChildren.init.foreach { isLast => @@ -560,9 +572,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } val str = if (verbose) { - if (addSuffix) verboseStringWithSuffix else verboseString + if (addSuffix) verboseStringWithSuffix(maxFields) else verboseString(maxFields) } else { - simpleString + simpleString(maxFields) } writer.write(prefix) writer.write(str) @@ -571,17 +583,17 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { if (innerChildren.nonEmpty) { innerChildren.init.foreach(_.generateTreeString( depth + 2, lastChildren :+ children.isEmpty :+ false, writer, verbose, - addSuffix = addSuffix)) + addSuffix = addSuffix, maxFields = maxFields)) innerChildren.last.generateTreeString( depth + 2, lastChildren :+ children.isEmpty :+ true, writer, verbose, - addSuffix = addSuffix) + addSuffix = addSuffix, maxFields = maxFields) } if (children.nonEmpty) { children.init.foreach(_.generateTreeString( - depth + 1, lastChildren :+ false, writer, verbose, prefix, addSuffix)) + depth + 1, lastChildren :+ false, writer, verbose, prefix, addSuffix, maxFields)) children.last.generateTreeString( - depth + 1, lastChildren :+ true, writer, verbose, prefix, addSuffix) + depth + 1, lastChildren :+ true, writer, verbose, prefix, addSuffix, maxFields) } } @@ -664,7 +676,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { t.forall(_.isInstanceOf[Partitioning]) || t.forall(_.isInstanceOf[DataType]) => JArray(t.map(parseToJson).toList) case t: Seq[_] if t.length > 0 && t.head.isInstanceOf[String] => - JString(truncatedString(t, "[", ", ", "]")) + JString(truncatedString(t, "[", ", ", "]", SQLConf.get.maxToStringFields)) case t: Seq[_] => JNull case m: Map[_, _] => JNull // if it's a scala object, we can simply keep the full class path. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 277584b20dcd2..7f5860e12cfd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -184,14 +184,14 @@ package object util extends Logging { start: String, sep: String, end: String, - maxNumFields: Int = SQLConf.get.maxToStringFields): String = { - if (seq.length > maxNumFields) { + maxFields: Int): String = { + if (seq.length > maxFields) { if (truncationWarningPrinted.compareAndSet(false, true)) { logWarning( "Truncated the string representation of a plan since it was too large. This " + s"behavior can be adjusted by setting '${SQLConf.MAX_TO_STRING_FIELDS.key}'.") } - val numFields = math.max(0, maxNumFields - 1) + val numFields = math.max(0, maxFields - 1) seq.take(numFields).mkString( start, sep, sep + "... " + (seq.length - numFields) + " more fields" + end) } else { @@ -200,7 +200,9 @@ package object util extends Logging { } /** Shorthand for calling truncatedString() without start or end strings. */ - def truncatedString[T](seq: Seq[T], sep: String): String = truncatedString(seq, "", sep, "") + def truncatedString[T](seq: Seq[T], sep: String, maxFields: Int): String = { + truncatedString(seq, "", sep, "", maxFields) + } /* FIX ME implicit class debugLogging(a: Any) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index e01d7c59cac52..d563276a5711d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -28,6 +28,7 @@ import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, truncatedString} +import org.apache.spark.sql.internal.SQLConf /** * A [[StructType]] object can be constructed by @@ -343,7 +344,10 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru override def simpleString: String = { val fieldTypes = fields.view.map(field => s"${field.name}:${field.dataType.simpleString}") - truncatedString(fieldTypes, "struct<", ",", ">") + truncatedString( + fieldTypes, + "struct<", ",", ">", + SQLConf.get.maxToStringFields) } override def catalogString: String = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala index 63c7b42978025..0e0c8e167a0a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala @@ -215,7 +215,7 @@ class PercentileSuite extends SparkFunSuite { val percentile2 = new Percentile(child, percentage) assertEqual(percentile2.checkInputDataTypes(), TypeCheckFailure(s"Percentage(s) must be between 0.0 and 1.0, " + - s"but got ${percentage.simpleString}")) + s"but got ${percentage.simpleString(100)}")) } val nonFoldablePercentage = Seq(NonFoldableLiteral(0.5), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/UtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/UtilSuite.scala index 9c162026942f6..d95de71e897a2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/UtilSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/UtilSuite.scala @@ -26,6 +26,6 @@ class UtilSuite extends SparkFunSuite { assert(truncatedString(Seq(1, 2), "[", ", ", "]", 2) == "[1, 2]") assert(truncatedString(Seq(1, 2, 3), "[", ", ", "]", 2) == "[1, ... 2 more fields]") assert(truncatedString(Seq(1, 2, 3), "[", ", ", "]", -5) == "[, ... 3 more fields]") - assert(truncatedString(Seq(1, 2, 3), ", ") == "1, 2, 3") + assert(truncatedString(Seq(1, 2, 3), ", ", 10) == "1, 2, 3") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 322ffffca564b..1d7dd73706c48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -52,19 +52,19 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport { // Metadata that describes more details of this scan. protected def metadata: Map[String, String] - override def simpleString: String = { + override def simpleString(maxFields: Int): String = { val metadataEntries = metadata.toSeq.sorted.map { case (key, value) => key + ": " + StringUtils.abbreviate(redact(value), 100) } - val metadataStr = truncatedString(metadataEntries, " ", ", ", "") - s"$nodeNamePrefix$nodeName${truncatedString(output, "[", ",", "]")}$metadataStr" + val metadataStr = truncatedString(metadataEntries, " ", ", ", "", maxFields) + s"$nodeNamePrefix$nodeName${truncatedString(output, "[", ",", "]", maxFields)}$metadataStr" } - override def verboseString: String = redact(super.verboseString) + override def verboseString(maxFields: Int): String = redact(super.verboseString(maxFields)) - override def treeString(verbose: Boolean, addSuffix: Boolean): String = { - redact(super.treeString(verbose, addSuffix)) + override def treeString(verbose: Boolean, addSuffix: Boolean, maxFields: Int): String = { + redact(super.treeString(verbose, addSuffix, maxFields)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 49fb288fdea6a..981ecae80a724 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -79,7 +79,7 @@ case class ExternalRDDScanExec[T]( } } - override def simpleString: String = { + override def simpleString(maxFields: Int): String = { s"$nodeName${output.mkString("[", ",", "]")}" } } @@ -156,8 +156,8 @@ case class RDDScanExec( } } - override def simpleString: String = { - s"$nodeName${truncatedString(output, "[", ",", "]")}" + override def simpleString(maxFields: Int): String = { + s"$nodeName${truncatedString(output, "[", ",", "]", maxFields)}" } // Input can be InternalRow, has to be turned into UnsafeRows. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index eef5a3f899f55..9b8d2e830867d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _} import org.apache.spark.util.Utils @@ -208,27 +209,27 @@ class QueryExecution( } } - private def writePlans(writer: Writer): Unit = { + private def writePlans(writer: Writer, maxFields: Int): Unit = { val (verbose, addSuffix) = (true, false) writer.write("== Parsed Logical Plan ==\n") - writeOrError(writer)(logical.treeString(_, verbose, addSuffix)) + writeOrError(writer)(logical.treeString(_, verbose, addSuffix, maxFields)) writer.write("\n== Analyzed Logical Plan ==\n") val analyzedOutput = stringOrError(truncatedString( - analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ")) + analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ", maxFields)) writer.write(analyzedOutput) writer.write("\n") - writeOrError(writer)(analyzed.treeString(_, verbose, addSuffix)) + writeOrError(writer)(analyzed.treeString(_, verbose, addSuffix, maxFields)) writer.write("\n== Optimized Logical Plan ==\n") - writeOrError(writer)(optimizedPlan.treeString(_, verbose, addSuffix)) + writeOrError(writer)(optimizedPlan.treeString(_, verbose, addSuffix, maxFields)) writer.write("\n== Physical Plan ==\n") - writeOrError(writer)(executedPlan.treeString(_, verbose, addSuffix)) + writeOrError(writer)(executedPlan.treeString(_, verbose, addSuffix, maxFields)) } override def toString: String = withRedaction { val writer = new StringBuilderWriter() try { - writePlans(writer) + writePlans(writer, SQLConf.get.maxToStringFields) writer.toString } finally { writer.close() @@ -280,14 +281,16 @@ class QueryExecution( /** * Dumps debug information about query execution into the specified file. + * + * @param maxFields maximim number of fields converted to string representation. */ - def toFile(path: String): Unit = { + def toFile(path: String, maxFields: Int = Int.MaxValue): Unit = { val filePath = new Path(path) val fs = filePath.getFileSystem(sparkSession.sessionState.newHadoopConf()) val writer = new BufferedWriter(new OutputStreamWriter(fs.create(filePath))) try { - writePlans(writer) + writePlans(writer, maxFields) writer.write("\n== Whole Stage Codegen ==\n") org.apache.spark.sql.execution.debug.writeCodegen(writer, executedPlan) } finally { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index 59ffd16381116..f554ff0aa775f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.metric.SQLMetricInfo +import org.apache.spark.sql.internal.SQLConf /** * :: DeveloperApi :: @@ -62,7 +63,10 @@ private[execution] object SparkPlanInfo { case fileScan: FileSourceScanExec => fileScan.metadata case _ => Map[String, String]() } - new SparkPlanInfo(plan.nodeName, plan.simpleString, children.map(fromSparkPlan), + new SparkPlanInfo( + plan.nodeName, + plan.simpleString(SQLConf.get.maxToStringFields), + children.map(fromSparkPlan), metadata, metrics) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index fbda0d87a175f..f4927dedabe56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -87,7 +87,7 @@ trait CodegenSupport extends SparkPlan { this.parent = parent ctx.freshNamePrefix = variablePrefix s""" - |${ctx.registerComment(s"PRODUCE: ${this.simpleString}")} + |${ctx.registerComment(s"PRODUCE: ${this.simpleString(SQLConf.get.maxToStringFields)}")} |${doProduce(ctx)} """.stripMargin } @@ -188,7 +188,7 @@ trait CodegenSupport extends SparkPlan { parent.doConsume(ctx, inputVars, rowVar) } s""" - |${ctx.registerComment(s"CONSUME: ${parent.simpleString}")} + |${ctx.registerComment(s"CONSUME: ${parent.simpleString(SQLConf.get.maxToStringFields)}")} |$evaluated |$consumeFunc """.stripMargin @@ -496,8 +496,16 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with InputRDDCod writer: Writer, verbose: Boolean, prefix: String = "", - addSuffix: Boolean = false): Unit = { - child.generateTreeString(depth, lastChildren, writer, verbose, prefix = "", addSuffix = false) + addSuffix: Boolean = false, + maxFields: Int): Unit = { + child.generateTreeString( + depth, + lastChildren, + writer, + verbose, + prefix = "", + addSuffix = false, + maxFields) } override def needCopyResult: Boolean = false @@ -772,8 +780,16 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) writer: Writer, verbose: Boolean, prefix: String = "", - addSuffix: Boolean = false): Unit = { - child.generateTreeString(depth, lastChildren, writer, verbose, s"*($codegenStageId) ", false) + addSuffix: Boolean = false, + maxFields: Int): Unit = { + child.generateTreeString( + depth, + lastChildren, + writer, + verbose, + s"*($codegenStageId) ", + false, + maxFields) } override def needStopCheck: Boolean = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 4827f838fc514..2355d305c38e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -922,18 +922,18 @@ case class HashAggregateExec( """ } - override def verboseString: String = toString(verbose = true) + override def verboseString(maxFields: Int): String = toString(verbose = true, maxFields) - override def simpleString: String = toString(verbose = false) + override def simpleString(maxFields: Int): String = toString(verbose = false, maxFields) - private def toString(verbose: Boolean): String = { + private def toString(verbose: Boolean, maxFields: Int): String = { val allAggregateExpressions = aggregateExpressions testFallbackStartsAt match { case None => - val keyString = truncatedString(groupingExpressions, "[", ", ", "]") - val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]") - val outputString = truncatedString(output, "[", ", ", "]") + val keyString = truncatedString(groupingExpressions, "[", ", ", "]", maxFields) + val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]", maxFields) + val outputString = truncatedString(output, "[", ", ", "]", maxFields) if (verbose) { s"HashAggregate(keys=$keyString, functions=$functionString, output=$outputString)" } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 7145bb03028d9..bd52c6321647a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -137,15 +137,15 @@ case class ObjectHashAggregateExec( } } - override def verboseString: String = toString(verbose = true) + override def verboseString(maxFields: Int): String = toString(verbose = true, maxFields) - override def simpleString: String = toString(verbose = false) + override def simpleString(maxFields: Int): String = toString(verbose = false, maxFields) - private def toString(verbose: Boolean): String = { + private def toString(verbose: Boolean, maxFields: Int): String = { val allAggregateExpressions = aggregateExpressions - val keyString = truncatedString(groupingExpressions, "[", ", ", "]") - val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]") - val outputString = truncatedString(output, "[", ", ", "]") + val keyString = truncatedString(groupingExpressions, "[", ", ", "]", maxFields) + val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]", maxFields) + val outputString = truncatedString(output, "[", ", ", "]", maxFields) if (verbose) { s"ObjectHashAggregate(keys=$keyString, functions=$functionString, output=$outputString)" } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index d732b905dcdd5..7ab6ecc08a7bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -107,16 +107,16 @@ case class SortAggregateExec( } } - override def simpleString: String = toString(verbose = false) + override def simpleString(maxFields: Int): String = toString(verbose = false, maxFields) - override def verboseString: String = toString(verbose = true) + override def verboseString(maxFields: Int): String = toString(verbose = true, maxFields) - private def toString(verbose: Boolean): String = { + private def toString(verbose: Boolean, maxFields: Int): String = { val allAggregateExpressions = aggregateExpressions - val keyString = truncatedString(groupingExpressions, "[", ", ", "]") - val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]") - val outputString = truncatedString(output, "[", ", ", "]") + val keyString = truncatedString(groupingExpressions, "[", ", ", "]", maxFields) + val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]", maxFields) + val outputString = truncatedString(output, "[", ", ", "]", maxFields) if (verbose) { s"SortAggregate(key=$keyString, functions=$functionString, output=$outputString)" } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 09effe087e195..2570b36b3166d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -586,7 +586,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) } } - override def simpleString: String = s"Range ($start, $end, step=$step, splits=$numSlices)" + override def simpleString(maxFields: Int): String = { + s"Range ($start, $end, step=$step, splits=$numSlices)" + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 73eb65f84489c..4109d9994dd8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -209,6 +209,6 @@ case class InMemoryRelation( override protected def otherCopyArgs: Seq[AnyRef] = Seq(statsOfPlanToCache) - override def simpleString: String = - s"InMemoryRelation [${truncatedString(output, ", ")}], ${cacheBuilder.storageLevel}" + override def simpleString(maxFields: Int): String = + s"InMemoryRelation [${truncatedString(output, ", ", maxFields)}], ${cacheBuilder.storageLevel}" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 1023572d19e2e..db3604fe92cc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -63,7 +63,9 @@ case class LogicalRelation( case _ => // Do nothing. } - override def simpleString: String = s"Relation[${truncatedString(output, ",")}] $relation" + override def simpleString(maxFields: Int): String = { + s"Relation[${truncatedString(output, ",", maxFields)}] $relation" + } } object LogicalRelation { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala index 00b1b5dedb593..f29e7869fb27c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala @@ -48,7 +48,7 @@ case class SaveIntoDataSourceCommand( Seq.empty[Row] } - override def simpleString: String = { + override def simpleString(maxFields: Int): String = { val redacted = SQLConf.get.redactOptions(options) s"SaveIntoDataSourceCommand ${dataSource}, ${redacted}, ${mode}" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index fdc5e85f3c2ea..042320edea4f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -68,7 +68,7 @@ case class CreateTempViewUsing( s"Temporary view '$tableIdent' should not have specified a database") } - override def argString: String = { + override def argString(maxFields: Int): String = { s"[tableIdent:$tableIdent " + userSpecifiedSchema.map(_ + " ").getOrElse("") + s"replace:$replace " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 0a6b0afe6cfe5..7bf2b8bff3732 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -52,8 +52,8 @@ case class DataSourceV2Relation( override def name: String = table.name() - override def simpleString: String = { - s"RelationV2${truncatedString(output, "[", ", ", "]")} $name" + override def simpleString(maxFields: Int): String = { + s"RelationV2${truncatedString(output, "[", ", ", "]", maxFields)} $name" } def newWriteSupport(): BatchWriteSupport = source.createWriteSupport(options, schema) @@ -96,7 +96,9 @@ case class StreamingDataSourceV2Relation( override def isStreaming: Boolean = true - override def simpleString: String = "Streaming RelationV2 " + metadataString + override def simpleString(maxFields: Int): String = { + "Streaming RelationV2 " + metadataString(maxFields) + } override def pushedFilters: Seq[Expression] = Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 725bcc3af3ca5..53e4e77c65e26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -35,8 +35,8 @@ case class DataSourceV2ScanExec( @transient batch: Batch) extends LeafExecNode with ColumnarBatchScan { - override def simpleString: String = { - s"ScanV2${truncatedString(output, "[", ", ", "]")} $scanDesc" + override def simpleString(maxFields: Int): String = { + s"ScanV2${truncatedString(output, "[", ", ", "]", maxFields)} $scanDesc" } // TODO: unify the equal/hashCode implementation for all data source v2 query plans. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StreamingScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StreamingScanExec.scala index c872940909964..be75fe4f596dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StreamingScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StreamingScanExec.scala @@ -42,7 +42,7 @@ case class DataSourceV2StreamingScanExec( @transient scanConfig: ScanConfig) extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan { - override def simpleString: String = "ScanV2 " + metadataString + override def simpleString(maxFields: Int): String = "ScanV2 " + metadataString(maxFields) // TODO: unify the equal/hashCode implementation for all data source v2 query plans. override def equals(other: Any): Boolean = other match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala index e829f621b4ea3..f11703c8a2773 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala @@ -59,7 +59,7 @@ trait DataSourceV2StringFormat { case _ => Utils.getSimpleName(source.getClass) } - def metadataString: String = { + def metadataString(maxFields: Int): String = { val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] if (pushedFilters.nonEmpty) { @@ -73,12 +73,12 @@ trait DataSourceV2StringFormat { }.mkString("[", ",", "]") } - val outputStr = truncatedString(output, "[", ", ", "]") + val outputStr = truncatedString(output, "[", ", ", "]", maxFields) val entriesStr = if (entries.nonEmpty) { truncatedString(entries.map { case (key, value) => key + ": " + StringUtils.abbreviate(value, 100) - }, " (", ", ", ")") + }, " (", ", ", ")", maxFields) } else { "" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 3511cefa7c292..ae8197f617a28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, Codegen import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.execution.streaming.{StreamExecution, StreamingQueryWrapper} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQuery import org.apache.spark.util.{AccumulatorV2, LongAccumulator} @@ -216,7 +217,7 @@ package object debug { val columnStats: Array[ColumnMetrics] = Array.fill(child.output.size)(new ColumnMetrics()) def dumpStats(): Unit = { - debugPrint(s"== ${child.simpleString} ==") + debugPrint(s"== ${child.simpleString(SQLConf.get.maxToStringFields)} ==") debugPrint(s"Tuples output: ${tupleCount.value}") child.output.zip(columnStats).foreach { case (attr, metric) => // This is called on driver. All accumulator updates have a fixed value. So it's safe to use diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index bfaf080292bce..56973af8fd648 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -198,9 +198,9 @@ case class TakeOrderedAndProjectExec( override def outputPartitioning: Partitioning = SinglePartition - override def simpleString: String = { - val orderByString = truncatedString(sortOrder, "[", ",", "]") - val outputString = truncatedString(output, "[", ",", "]") + override def simpleString(maxFields: Int): String = { + val orderByString = truncatedString(sortOrder, "[", ",", "]", maxFields) + val outputString = truncatedString(output, "[", ",", "]", maxFields) s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 8ad436a4ff57d..38ecb0dd12daa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, StreamWriterCommitProgress, WriteToDataSourceV2, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWritSupport, RateControlMicroBatchReadSupport} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} @@ -482,9 +483,10 @@ class MicroBatchExecution( val newBatchesPlan = logicalPlan transform { case StreamingExecutionRelation(source, output) => newData.get(source).map { dataPlan => + val maxFields = SQLConf.get.maxToStringFields assert(output.size == dataPlan.output.size, - s"Invalid batch: ${truncatedString(output, ",")} != " + - s"${truncatedString(dataPlan.output, ",")}") + s"Invalid batch: ${truncatedString(output, ",", maxFields)} != " + + s"${truncatedString(dataPlan.output, ",", maxFields)}") val aliases = output.zip(dataPlan.output).map { case (to, from) => Alias(from, to.name)(exprId = to.exprId, explicitMetadata = Some(from.metadata)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index f0859aaaa3041..89033b70f1431 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2StreamingScanExec, StreamingDataSourceV2Relation} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2 import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, StreamingWriteSupportProvider} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} @@ -166,10 +167,10 @@ class ContinuousExecution( val readSupport = continuousSources(insertedSourceId) insertedSourceId += 1 val newOutput = readSupport.fullSchema().toAttributes - + val maxFields = SQLConf.get.maxToStringFields assert(output.size == newOutput.size, - s"Invalid reader: ${truncatedString(output, ",")} != " + - s"${truncatedString(newOutput, ",")}") + s"Invalid reader: ${truncatedString(output, ",", maxFields)} != " + + s"${truncatedString(newOutput, ",", maxFields)}") replacements ++= output.zip(newOutput) val loggedOffset = offsets.offsets(0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index daee089f3871d..13b75ae4a4339 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Stati import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode @@ -117,7 +118,9 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } } - override def toString: String = s"MemoryStream[${truncatedString(output, ",")}]" + override def toString: String = { + s"MemoryStream[${truncatedString(output, ",", SQLConf.get.maxToStringFields)}]" + } override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 310ebcdf67686..e180d2228c3b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -51,7 +51,7 @@ case class ScalarSubquery( override def dataType: DataType = plan.schema.fields.head.dataType override def children: Seq[Expression] = Nil override def nullable: Boolean = true - override def toString: String = plan.simpleString + override def toString: String = plan.simpleString(SQLConf.get.maxToStringFields) override def withNewPlan(query: SubqueryExec): ScalarSubquery = copy(plan = query) override def semanticEquals(other: Expression): Boolean = other match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index 0c47a2040f171..3cc97c995702a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -106,6 +106,19 @@ class QueryExecutionSuite extends SharedSQLContext { } } + test("check maximum fields restriction") { + withTempDir { dir => + val path = dir.getCanonicalPath + "/plans.txt" + val ds = spark.createDataset(Seq(QueryExecutionTestRecord( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26))) + ds.queryExecution.debug.toFile(path) + val localRelations = Source.fromFile(path).getLines().filter(_.contains("LocalRelation")) + + assert(!localRelations.exists(_.contains("more fields"))) + } + } + test("toString() exception/error handling") { spark.experimental.extraStrategies = Seq( new SparkStrategy { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 608f21e726259..7249eacfbf9a6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -83,7 +83,7 @@ trait CreateHiveTableAsSelectBase extends DataWritingCommand { tableDesc: CatalogTable, tableExists: Boolean): DataWritingCommand - override def argString: String = { + override def argString(maxFields: Int): String = { s"[Database:${tableDesc.database}, " + s"TableName: ${tableDesc.identifier.table}, " + s"InsertIntoHiveTable]" From add287f397d41c1725464dff89d4a555ffc9db04 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Thu, 27 Dec 2018 22:26:37 +0800 Subject: [PATCH 2340/2461] [SPARK-25892][SQL] Change AttributeReference.withMetadata's return type to AttributeReference ## What changes were proposed in this pull request? Currently the `AttributeReference.withMetadata` method have return type `Attribute`, the rest of with methods in the `AttributeReference` return type are `AttributeReference`, as the [spark-25892](https://issues.apache.org/jira/browse/SPARK-25892?jql=project%20%3D%20SPARK%20AND%20component%20in%20(ML%2C%20PySpark%2C%20SQL)) mentioned. This PR will change `AttributeReference.withMetadata` method's return type from `Attribute` to `AttributeReference`. ## How was this patch tested? Run all `sql/test,` `catalyst/test` and `org.apache.spark.sql.execution.streaming.*` Closes #22918 from kevinyu98/spark-25892. Authored-by: Kevin Yu Signed-off-by: Hyukjin Kwon --- .../spark/sql/catalyst/expressions/namedExpressions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 131459bf27bc8..7ebb171f34ba2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -311,7 +311,7 @@ case class AttributeReference( } } - override def withMetadata(newMetadata: Metadata): Attribute = { + override def withMetadata(newMetadata: Metadata): AttributeReference = { AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier) } From 68496c1af310aadfb1b226cb05be510252769d43 Mon Sep 17 00:00:00 2001 From: deepyaman Date: Fri, 28 Dec 2018 00:02:41 +0800 Subject: [PATCH 2341/2461] [SPARK-26451][SQL] Change lead/lag argument name from count to offset ## What changes were proposed in this pull request? Change aligns argument name with that in Scala version and documentation. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #23357 from deepyaman/patch-1. Authored-by: deepyaman Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/functions.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d2a771e9bb8ea..3c33e2bed92d9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -798,7 +798,7 @@ def factorial(col): # --------------- Window functions ------------------------ @since(1.4) -def lag(col, count=1, default=None): +def lag(col, offset=1, default=None): """ Window function: returns the value that is `offset` rows before the current row, and `defaultValue` if there is less than `offset` rows before the current row. For example, @@ -807,15 +807,15 @@ def lag(col, count=1, default=None): This is equivalent to the LAG function in SQL. :param col: name of column or expression - :param count: number of row to extend + :param offset: number of row to extend :param default: default value """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.lag(_to_java_column(col), count, default)) + return Column(sc._jvm.functions.lag(_to_java_column(col), offset, default)) @since(1.4) -def lead(col, count=1, default=None): +def lead(col, offset=1, default=None): """ Window function: returns the value that is `offset` rows after the current row, and `defaultValue` if there is less than `offset` rows after the current row. For example, @@ -824,11 +824,11 @@ def lead(col, count=1, default=None): This is equivalent to the LEAD function in SQL. :param col: name of column or expression - :param count: number of row to extend + :param offset: number of row to extend :param default: default value """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.lead(_to_java_column(col), count, default)) + return Column(sc._jvm.functions.lead(_to_java_column(col), offset, default)) @since(1.4) From f2adb610680f9addec51bf470acf64f3849073e9 Mon Sep 17 00:00:00 2001 From: wuqingxin Date: Fri, 28 Dec 2018 00:15:57 -0800 Subject: [PATCH 2342/2461] [SPARK-26446][CORE] Add cachedExecutorIdleTimeout docs at ExecutorAllocationManager ## What changes were proposed in this pull request? Add docs to describe how remove policy act while considering the property `spark.dynamicAllocation.cachedExecutorIdleTimeout` in ExecutorAllocationManager ## How was this patch tested? comment-only PR. Closes #23386 from TopGunViper/SPARK-26446. Authored-by: wuqingxin Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/ExecutorAllocationManager.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index c3e5b96a55884..3f0b71bbe17f1 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -57,7 +57,8 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} * a long time to ramp up under heavy workloads. * * The remove policy is simpler: If an executor has been idle for K seconds, meaning it has not - * been scheduled to run any tasks, then it is removed. + * been scheduled to run any tasks, then it is removed. Note that an executor caching any data + * blocks will be removed if it has been idle for more than L seconds. * * There is no retry logic in either case because we make the assumption that the cluster manager * will eventually fulfill all requests it receives asynchronously. @@ -81,7 +82,12 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} * This is used only after the initial backlog timeout is exceeded * * spark.dynamicAllocation.executorIdleTimeout (K) - - * If an executor has been idle for this duration, remove it + * If an executor without caching any data blocks has been idle for this duration, remove it + * + * spark.dynamicAllocation.cachedExecutorIdleTimeout (L) - + * If an executor with caching data blocks has been idle for more than this duration, + * the executor will be removed + * */ private[spark] class ExecutorAllocationManager( client: ExecutorAllocationClient, From 5bef4fedfe1916320223b1245bacb58f151cee66 Mon Sep 17 00:00:00 2001 From: seancxmao Date: Fri, 28 Dec 2018 07:40:59 -0600 Subject: [PATCH 2343/2461] [SPARK-26444][WEBUI] Stage color doesn't change with it's status ## What changes were proposed in this pull request? On job page, in event timeline section, stage color doesn't change according to its status. Below are some screenshots. ACTIVE: active COMPLETE: complete FAILED: failed This PR lets stage color change with it's status. The main idea is to make css style class name match the corresponding stage status. ## How was this patch tested? Manually tested locally. ``` // active/complete stage sc.parallelize(1 to 3, 3).map { n => Thread.sleep(10* 1000); n }.count // failed stage sc.parallelize(1 to 3, 3).map { n => Thread.sleep(10* 1000); throw new Exception() }.count ``` Note we need to clear browser cache to let new `timeline-view.css` take effect. Below are screenshots after this PR. ACTIVE: active-after COMPLETE: complete-after FAILED: failed-after Closes #23385 from seancxmao/timeline-stage-color. Authored-by: seancxmao Signed-off-by: Sean Owen --- .../org/apache/spark/ui/static/timeline-view.css | 8 ++++---- .../src/main/scala/org/apache/spark/ui/jobs/JobPage.scala | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css index 3bf3e8bfa1f31..10bceae2fbdda 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css @@ -98,12 +98,12 @@ rect.getting-result-time-proportion { cursor: pointer; } -.vis-timeline .vis-item.stage.succeeded { +.vis-timeline .vis-item.stage.complete { background-color: #A0DFFF; border-color: #3EC0FF; } -.vis-timeline .vis-item.stage.succeeded.vis-selected { +.vis-timeline .vis-item.stage.complete.vis-selected { background-color: #A0DFFF; border-color: #3EC0FF; z-index: auto; @@ -130,12 +130,12 @@ rect.getting-result-time-proportion { stroke: #FF4D6D; } -.vis-timeline .vis-item.stage.running { +.vis-timeline .vis-item.stage.active { background-color: #A2FCC0; border-color: #36F572; } -.vis-timeline .vis-item.stage.running.vis-selected { +.vis-timeline .vis-item.stage.active.vis-selected { background-color: #A2FCC0; border-color: #36F572; z-index: auto; diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index b58a6ca447edf..cd82439223b07 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -62,7 +62,7 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP val stageId = stage.stageId val attemptId = stage.attemptId val name = stage.name - val status = stage.status.toString + val status = stage.status.toString.toLowerCase(Locale.ROOT) val submissionTime = stage.submissionTime.get.getTime() val completionTime = stage.completionTime.map(_.getTime()) .getOrElse(System.currentTimeMillis()) From e0054b88a1624ec5196dc206997db065731099ac Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 28 Dec 2018 11:29:06 -0800 Subject: [PATCH 2344/2461] [SPARK-26424][SQL][FOLLOWUP] Fix DateFormatClass/UnixTime codegen ## What changes were proposed in this pull request? This PR fixes the codegen bug introduced by #23358 . - https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-maven-hadoop-2.7-ubuntu-scala-2.11/158/ ``` Line 44, Column 93: A method named "apply" is not declared in any enclosing class nor any supertype, nor through a static import ``` ## How was this patch tested? Manual. `DateExpressionsSuite` should be passed with Scala-2.11. Closes #23394 from dongjoon-hyun/SPARK-26424. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../sql/catalyst/expressions/datetimeExpressions.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 73af0a3c5c2ee..8fc0112c02577 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -571,7 +571,7 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti val tz = ctx.addReferenceObj("timeZone", timeZone) val locale = ctx.addReferenceObj("locale", Locale.US) defineCodeGen(ctx, ev, (timestamp, format) => { - s"""UTF8String.fromString($tf.apply($format.toString(), $tz, $locale) + s"""UTF8String.fromString($tf$$.MODULE$$.apply($format.toString(), $tz, $locale) .format($timestamp))""" }) } @@ -741,11 +741,11 @@ abstract class UnixTime case StringType => val tz = ctx.addReferenceObj("timeZone", timeZone) val locale = ctx.addReferenceObj("locale", Locale.US) - val dtu = TimestampFormatter.getClass.getName.stripSuffix("$") + val tf = TimestampFormatter.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (string, format) => { s""" try { - ${ev.value} = $dtu.apply($format.toString(), $tz, $locale) + ${ev.value} = $tf$$.MODULE$$.apply($format.toString(), $tz, $locale) .parse($string.toString()) / 1000000L; } catch (java.lang.IllegalArgumentException e) { ${ev.isNull} = true; From e63243df8aca9f44255879e931e0c372beef9fc2 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Sat, 29 Dec 2018 12:11:45 -0800 Subject: [PATCH 2345/2461] [SPARK-26496][SS][TEST] Avoid to use Random.nextString in StreamingInnerJoinSuite MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Similar with https://github.com/apache/spark/pull/21446. Looks random string is not quite safe as a directory name. ```scala scala> val prefix = Random.nextString(10); val dir = new File("/tmp", "del_" + prefix + "-" + UUID.randomUUID.toString); dir.mkdirs() prefix: String = 窽텘⒘駖ⵚ駢⡞Ρ닋੎ dir: java.io.File = /tmp/del_窽텘⒘駖ⵚ駢⡞Ρ닋੎-a3f99855-c429-47a0-a108-47bca6905745 res40: Boolean = false // nope, didn't like this one ``` ## How was this patch tested? Unit test was added, and manually. Closes #23405 from HyukjinKwon/SPARK-26496. Authored-by: Hyukjin Kwon Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/sql/streaming/StreamingJoinSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index c5cc8df4356a8..42fe9f34ee3ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -350,7 +350,7 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with withTempDir { tempDir => val queryId = UUID.randomUUID val opId = 0 - val path = Utils.createDirectory(tempDir.getAbsolutePath, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir.getAbsolutePath, Random.nextFloat.toString).toString val stateInfo = StatefulOperatorStateInfo(path, queryId, opId, 0L, 5) implicit val sqlContext = spark.sqlContext From e6d3e7d0d8c80adaa51b43d76f1cc83bb9a010b9 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sat, 29 Dec 2018 17:33:43 -0800 Subject: [PATCH 2346/2461] [SPARK-26443][CORE] Use ConfigEntry for hardcoded configs for history category. ## What changes were proposed in this pull request? This pr makes hardcoded "spark.history" configs to use `ConfigEntry` and put them in `History` config object. ## How was this patch tested? Existing tests. Closes #23384 from ueshin/issues/SPARK-26443/hardcoded_history_configs. Authored-by: Takuya UESHIN Signed-off-by: Dongjoon Hyun --- .../scala/org/apache/spark/SparkConf.scala | 4 +- .../deploy/history/FsHistoryProvider.scala | 21 ++++----- .../spark/deploy/history/HistoryServer.scala | 16 ++++--- .../history/HistoryServerArguments.scala | 2 +- .../spark/internal/config/History.scala | 46 ++++++++++++++++++- .../org/apache/spark/SparkConfSuite.scala | 2 +- .../history/FsHistoryProviderSuite.scala | 23 +++++----- .../history/HistoryServerArgumentsSuite.scala | 9 ++-- .../deploy/history/HistoryServerSuite.scala | 13 +++--- 9 files changed, 89 insertions(+), 47 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 8d135d3e083d7..0b47da12b5b42 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -682,11 +682,11 @@ private[spark] object SparkConf extends Logging { private val configsWithAlternatives = Map[String, Seq[AlternateConfig]]( "spark.executor.userClassPathFirst" -> Seq( AlternateConfig("spark.files.userClassPathFirst", "1.3")), - "spark.history.fs.update.interval" -> Seq( + UPDATE_INTERVAL_S.key -> Seq( AlternateConfig("spark.history.fs.update.interval.seconds", "1.4"), AlternateConfig("spark.history.fs.updateInterval", "1.3"), AlternateConfig("spark.history.updateInterval", "1.3")), - "spark.history.fs.cleaner.interval" -> Seq( + CLEANER_INTERVAL_S.key -> Seq( AlternateConfig("spark.history.fs.cleaner.interval.seconds", "1.4")), MAX_LOG_AGE_S.key -> Seq( AlternateConfig("spark.history.fs.cleaner.maxAge.seconds", "1.4")), diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index da6e5f03aabb5..709a380dfb636 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -42,7 +42,7 @@ import org.fusesource.leveldbjni.internal.NativeDB import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.DRIVER_LOG_DFS_DIR +import org.apache.spark.internal.config.{DRIVER_LOG_DFS_DIR, History} import org.apache.spark.internal.config.History._ import org.apache.spark.internal.config.Status._ import org.apache.spark.io.CompressionCodec @@ -91,24 +91,22 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) import FsHistoryProvider._ // Interval between safemode checks. - private val SAFEMODE_CHECK_INTERVAL_S = conf.getTimeAsSeconds( - "spark.history.fs.safemodeCheck.interval", "5s") + private val SAFEMODE_CHECK_INTERVAL_S = conf.get(History.SAFEMODE_CHECK_INTERVAL_S) // Interval between each check for event log updates - private val UPDATE_INTERVAL_S = conf.getTimeAsSeconds("spark.history.fs.update.interval", "10s") + private val UPDATE_INTERVAL_S = conf.get(History.UPDATE_INTERVAL_S) // Interval between each cleaner checks for event logs to delete - private val CLEAN_INTERVAL_S = conf.get(CLEANER_INTERVAL_S) + private val CLEAN_INTERVAL_S = conf.get(History.CLEANER_INTERVAL_S) // Number of threads used to replay event logs. - private val NUM_PROCESSING_THREADS = conf.getInt(SPARK_HISTORY_FS_NUM_REPLAY_THREADS, - Math.ceil(Runtime.getRuntime.availableProcessors() / 4f).toInt) + private val NUM_PROCESSING_THREADS = conf.get(History.NUM_REPLAY_THREADS) - private val logDir = conf.get(EVENT_LOG_DIR) + private val logDir = conf.get(History.HISTORY_LOG_DIR) - private val HISTORY_UI_ACLS_ENABLE = conf.getBoolean("spark.history.ui.acls.enable", false) - private val HISTORY_UI_ADMIN_ACLS = conf.get("spark.history.ui.admin.acls", "") - private val HISTORY_UI_ADMIN_ACLS_GROUPS = conf.get("spark.history.ui.admin.acls.groups", "") + private val HISTORY_UI_ACLS_ENABLE = conf.get(History.UI_ACLS_ENABLE) + private val HISTORY_UI_ADMIN_ACLS = conf.get(History.UI_ADMIN_ACLS) + private val HISTORY_UI_ADMIN_ACLS_GROUPS = conf.get(History.UI_ADMIN_ACLS_GROUPS) logInfo(s"History server ui acls " + (if (HISTORY_UI_ACLS_ENABLE) "enabled" else "disabled") + "; users with admin permissions: " + HISTORY_UI_ADMIN_ACLS.toString + "; groups with admin permissions" + HISTORY_UI_ADMIN_ACLS_GROUPS.toString) @@ -1089,7 +1087,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } private[history] object FsHistoryProvider { - private val SPARK_HISTORY_FS_NUM_REPLAY_THREADS = "spark.history.fs.numReplayThreads" private val APPL_START_EVENT_PREFIX = "{\"Event\":\"SparkListenerApplicationStart\"" diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 5856c7057b745..b9303388638fd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -30,7 +30,7 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ -import org.apache.spark.internal.config.History.HISTORY_SERVER_UI_PORT +import org.apache.spark.internal.config.History import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, UIRoot} import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ @@ -56,7 +56,7 @@ class HistoryServer( with Logging with UIRoot with ApplicationCacheOperations { // How many applications to retain - private val retainedApplications = conf.getInt("spark.history.retainedApplications", 50) + private val retainedApplications = conf.get(History.RETAINED_APPLICATIONS) // How many applications the summary ui displays private[history] val maxApplications = conf.get(HISTORY_UI_MAX_APPS); @@ -273,14 +273,14 @@ object HistoryServer extends Logging { initSecurity() val securityManager = createSecurityManager(conf) - val providerName = conf.getOption("spark.history.provider") + val providerName = conf.get(History.PROVIDER) .getOrElse(classOf[FsHistoryProvider].getName()) val provider = Utils.classForName(providerName) .getConstructor(classOf[SparkConf]) .newInstance(conf) .asInstanceOf[ApplicationHistoryProvider] - val port = conf.get(HISTORY_SERVER_UI_PORT) + val port = conf.get(History.HISTORY_SERVER_UI_PORT) val server = new HistoryServer(conf, provider, securityManager, port) server.bind() @@ -319,10 +319,12 @@ object HistoryServer extends Logging { // from a keytab file so that we can access HDFS beyond the kerberos ticket expiration. // As long as it is using Hadoop rpc (hdfs://), a relogin will automatically // occur from the keytab. - if (conf.getBoolean("spark.history.kerberos.enabled", false)) { + if (conf.get(History.KERBEROS_ENABLED)) { // if you have enabled kerberos the following 2 params must be set - val principalName = conf.get("spark.history.kerberos.principal") - val keytabFilename = conf.get("spark.history.kerberos.keytab") + val principalName = conf.get(History.KERBEROS_PRINCIPAL) + .getOrElse(throw new NoSuchElementException(History.KERBEROS_PRINCIPAL.key)) + val keytabFilename = conf.get(History.KERBEROS_KEYTAB) + .getOrElse(throw new NoSuchElementException(History.KERBEROS_KEYTAB.key)) SparkHadoopUtil.get.loginUserFromKeytab(principalName, keytabFilename) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index 49f00cb10179e..dec89769c030b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -79,7 +79,7 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin | | spark.history.fs.logDirectory Directory where app logs are stored | (default: file:/tmp/spark-events) - | spark.history.fs.updateInterval How often to reload log data from storage + | spark.history.fs.update.interval How often to reload log data from storage | (in seconds, default: 10) |""".stripMargin) // scalastyle:on println diff --git a/core/src/main/scala/org/apache/spark/internal/config/History.scala b/core/src/main/scala/org/apache/spark/internal/config/History.scala index b7d8061d26d21..f984dd385344b 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/History.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/History.scala @@ -25,10 +25,18 @@ private[spark] object History { val DEFAULT_LOG_DIR = "file:/tmp/spark-events" - val EVENT_LOG_DIR = ConfigBuilder("spark.history.fs.logDirectory") + val HISTORY_LOG_DIR = ConfigBuilder("spark.history.fs.logDirectory") .stringConf .createWithDefault(DEFAULT_LOG_DIR) + val SAFEMODE_CHECK_INTERVAL_S = ConfigBuilder("spark.history.fs.safemodeCheck.interval") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("5s") + + val UPDATE_INTERVAL_S = ConfigBuilder("spark.history.fs.update.interval") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("10s") + val CLEANER_ENABLED = ConfigBuilder("spark.history.fs.cleaner.enabled") .booleanConf .createWithDefault(false) @@ -79,4 +87,40 @@ private[spark] object History { val MAX_DRIVER_LOG_AGE_S = ConfigBuilder("spark.history.fs.driverlog.cleaner.maxAge") .fallbackConf(MAX_LOG_AGE_S) + + val UI_ACLS_ENABLE = ConfigBuilder("spark.history.ui.acls.enable") + .booleanConf + .createWithDefault(false) + + val UI_ADMIN_ACLS = ConfigBuilder("spark.history.ui.admin.acls") + .stringConf + .createWithDefault("") + + val UI_ADMIN_ACLS_GROUPS = ConfigBuilder("spark.history.ui.admin.acls.groups") + .stringConf + .createWithDefault("") + + val NUM_REPLAY_THREADS = ConfigBuilder("spark.history.fs.numReplayThreads") + .intConf + .createWithDefaultFunction(() => Math.ceil(Runtime.getRuntime.availableProcessors() / 4f).toInt) + + val RETAINED_APPLICATIONS = ConfigBuilder("spark.history.retainedApplications") + .intConf + .createWithDefault(50) + + val PROVIDER = ConfigBuilder("spark.history.provider") + .stringConf + .createOptional + + val KERBEROS_ENABLED = ConfigBuilder("spark.history.kerberos.enabled") + .booleanConf + .createWithDefault(false) + + val KERBEROS_PRINCIPAL = ConfigBuilder("spark.history.kerberos.principal") + .stringConf + .createOptional + + val KERBEROS_KEYTAB = ConfigBuilder("spark.history.kerberos.keytab") + .stringConf + .createOptional } diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 7cb03deae1391..e14a5dcb5ef84 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -232,7 +232,7 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst test("deprecated configs") { val conf = new SparkConf() - val newName = "spark.history.fs.update.interval" + val newName = UPDATE_INTERVAL_S.key assert(!conf.contains(newName)) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index c1ae27aa940f6..6d2e329094ae2 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -294,7 +294,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val maxAge = TimeUnit.SECONDS.toMillis(10) val clock = new ManualClock(maxAge / 2) val provider = new FsHistoryProvider( - createTestConf().set("spark.history.fs.cleaner.maxAge", s"${maxAge}ms"), clock) + createTestConf().set(MAX_LOG_AGE_S.key, s"${maxAge}ms"), clock) val log1 = newLogFile("app1", Some("attempt1"), inProgress = false) writeFile(log1, true, None, @@ -379,7 +379,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val maxAge = TimeUnit.SECONDS.toMillis(40) val clock = new ManualClock(0) val provider = new FsHistoryProvider( - createTestConf().set("spark.history.fs.cleaner.maxAge", s"${maxAge}ms"), clock) + createTestConf().set(MAX_LOG_AGE_S.key, s"${maxAge}ms"), clock) val log1 = newLogFile("inProgressApp1", None, inProgress = true) writeFile(log1, true, None, @@ -462,8 +462,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val maxAge = TimeUnit.SECONDS.toSeconds(40) val clock = new ManualClock(0) val testConf = new SparkConf() - testConf.set("spark.history.fs.logDirectory", - Utils.createTempDir(namePrefix = "eventLog").getAbsolutePath()) + testConf.set(HISTORY_LOG_DIR, Utils.createTempDir(namePrefix = "eventLog").getAbsolutePath()) testConf.set(DRIVER_LOG_DFS_DIR, testDir.getAbsolutePath()) testConf.set(DRIVER_LOG_CLEANER_ENABLED, true) testConf.set(DRIVER_LOG_CLEANER_INTERVAL, maxAge / 4) @@ -645,9 +644,9 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc // Test both history ui admin acls and application acls are configured. val conf1 = createTestConf() - .set("spark.history.ui.acls.enable", "true") - .set("spark.history.ui.admin.acls", "user1,user2") - .set("spark.history.ui.admin.acls.groups", "group1") + .set(UI_ACLS_ENABLE, true) + .set(UI_ADMIN_ACLS, "user1,user2") + .set(UI_ADMIN_ACLS_GROUPS, "group1") .set("spark.user.groups.mapping", classOf[TestGroupsMappingProvider].getName) createAndCheck(conf1, ("spark.admin.acls", "user"), ("spark.admin.acls.groups", "group")) { @@ -667,9 +666,9 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc // Test only history ui admin acls are configured. val conf2 = createTestConf() - .set("spark.history.ui.acls.enable", "true") - .set("spark.history.ui.admin.acls", "user1,user2") - .set("spark.history.ui.admin.acls.groups", "group1") + .set(UI_ACLS_ENABLE, true) + .set(UI_ADMIN_ACLS, "user1,user2") + .set(UI_ADMIN_ACLS_GROUPS, "group1") .set("spark.user.groups.mapping", classOf[TestGroupsMappingProvider].getName) createAndCheck(conf2) { securityManager => // Test whether user has permission to access UI. @@ -687,7 +686,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc // Test neither history ui admin acls nor application acls are configured. val conf3 = createTestConf() - .set("spark.history.ui.acls.enable", "true") + .set(UI_ACLS_ENABLE, true) .set("spark.user.groups.mapping", classOf[TestGroupsMappingProvider].getName) createAndCheck(conf3) { securityManager => // Test whether user has permission to access UI. @@ -1036,7 +1035,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc private def createTestConf(inMemory: Boolean = false): SparkConf = { val conf = new SparkConf() - .set(EVENT_LOG_DIR, testDir.getAbsolutePath()) + .set(HISTORY_LOG_DIR, testDir.getAbsolutePath()) .set(FAST_IN_PROGRESS_PARSING, true) if (!inMemory) { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala index e89733a144cfa..6b479873f69f2 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala @@ -22,21 +22,22 @@ import java.nio.charset.StandardCharsets._ import com.google.common.io.Files import org.apache.spark._ +import org.apache.spark.internal.config.History._ import org.apache.spark.util.Utils class HistoryServerArgumentsSuite extends SparkFunSuite { private val logDir = new File("src/test/resources/spark-events") private val conf = new SparkConf() - .set("spark.history.fs.logDirectory", logDir.getAbsolutePath) - .set("spark.history.fs.updateInterval", "1") + .set(HISTORY_LOG_DIR, logDir.getAbsolutePath) + .set(UPDATE_INTERVAL_S, 1L) .set("spark.testing", "true") test("No Arguments Parsing") { val argStrings = Array.empty[String] val hsa = new HistoryServerArguments(conf, argStrings) - assert(conf.get("spark.history.fs.logDirectory") === logDir.getAbsolutePath) - assert(conf.get("spark.history.fs.updateInterval") === "1") + assert(conf.get(HISTORY_LOG_DIR) === logDir.getAbsolutePath) + assert(conf.get(UPDATE_INTERVAL_S) === 1L) assert(conf.get("spark.testing") === "true") } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 2a2d013bacbda..a9dee67ae9383 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -78,8 +78,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers Utils.deleteRecursively(storeDir) assert(storeDir.mkdir()) val conf = new SparkConf() - .set("spark.history.fs.logDirectory", logDir) - .set("spark.history.fs.update.interval", "0") + .set(HISTORY_LOG_DIR, logDir) + .set(UPDATE_INTERVAL_S.key, "0") .set("spark.testing", "true") .set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) .set("spark.eventLog.logStageExecutorMetrics.enabled", "true") @@ -416,11 +416,10 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers // allowed refresh rate (1Hz) stop() val myConf = new SparkConf() - .set("spark.history.fs.logDirectory", logDir.getAbsolutePath) + .set(HISTORY_LOG_DIR, logDir.getAbsolutePath) .set("spark.eventLog.dir", logDir.getAbsolutePath) - .set("spark.history.fs.update.interval", "1s") + .set(UPDATE_INTERVAL_S.key, "1s") .set("spark.eventLog.enabled", "true") - .set("spark.history.cache.window", "250ms") .set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) .remove("spark.testing") val provider = new FsHistoryProvider(myConf) @@ -613,8 +612,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers stop() init( "spark.ui.filters" -> classOf[FakeAuthFilter].getName(), - "spark.history.ui.acls.enable" -> "true", - "spark.history.ui.admin.acls" -> admin) + UI_ACLS_ENABLE.key -> "true", + UI_ADMIN_ACLS.key -> admin) val tests = Seq( (owner, HttpServletResponse.SC_OK), From 240817b7aea14d12f7764e17ab11073d14e8e6aa Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sat, 29 Dec 2018 21:47:49 -0600 Subject: [PATCH 2347/2461] [SPARK-26363][WEBUI] Avoid duplicated KV store lookups in method `taskList` ## What changes were proposed in this pull request? In the method `taskList`(since https://github.com/apache/spark/pull/21688), the executor log value is queried in KV store for every task(method `constructTaskData`). This PR propose to use a hashmap for reducing duplicated KV store lookups in the method. ![image](https://user-images.githubusercontent.com/1097932/49946230-841c7680-ff29-11e8-8b83-d8f7553bfe5e.png) ## How was this patch tested? Manual check Closes #23310 from gengliangwang/removeExecutorLog. Authored-by: Gengliang Wang Signed-off-by: Sean Owen --- .../apache/spark/status/AppStatusStore.scala | 52 ++++++++++--------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 312bcccb1cca1..0487f2f07c097 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -20,6 +20,7 @@ package org.apache.spark.status import java.util.{List => JList} import scala.collection.JavaConverters._ +import scala.collection.mutable.HashMap import org.apache.spark.{JobExecutionStatus, SparkConf} import org.apache.spark.status.api.v1 @@ -386,10 +387,9 @@ private[spark] class AppStatusStore( def taskList(stageId: Int, stageAttemptId: Int, maxTasks: Int): Seq[v1.TaskData] = { val stageKey = Array(stageId, stageAttemptId) - store.view(classOf[TaskDataWrapper]).index("stage").first(stageKey).last(stageKey).reverse() - .max(maxTasks).asScala.map { taskDataWrapper => - constructTaskData(taskDataWrapper) - }.toSeq.reverse + val taskDataWrapperIter = store.view(classOf[TaskDataWrapper]).index("stage") + .first(stageKey).last(stageKey).reverse().max(maxTasks).asScala + constructTaskDataList(taskDataWrapperIter).reverse } def taskList( @@ -428,9 +428,8 @@ private[spark] class AppStatusStore( } val ordered = if (ascending) indexed else indexed.reverse() - ordered.skip(offset).max(length).asScala.map { taskDataWrapper => - constructTaskData(taskDataWrapper) - }.toSeq + val taskDataWrapperIter = ordered.skip(offset).max(length).asScala + constructTaskDataList(taskDataWrapperIter) } def executorSummary(stageId: Int, attemptId: Int): Map[String, v1.ExecutorStageSummary] = { @@ -536,24 +535,29 @@ private[spark] class AppStatusStore( store.close() } - def constructTaskData(taskDataWrapper: TaskDataWrapper) : v1.TaskData = { - val taskDataOld: v1.TaskData = taskDataWrapper.toApi - val executorLogs: Option[Map[String, String]] = try { - Some(executorSummary(taskDataOld.executorId).executorLogs) - } catch { - case e: NoSuchElementException => e.getMessage - None - } - new v1.TaskData(taskDataOld.taskId, taskDataOld.index, - taskDataOld.attempt, taskDataOld.launchTime, taskDataOld.resultFetchStart, - taskDataOld.duration, taskDataOld.executorId, taskDataOld.host, taskDataOld.status, - taskDataOld.taskLocality, taskDataOld.speculative, taskDataOld.accumulatorUpdates, - taskDataOld.errorMessage, taskDataOld.taskMetrics, - executorLogs.getOrElse(Map[String, String]()), - AppStatusUtils.schedulerDelay(taskDataOld), - AppStatusUtils.gettingResultTime(taskDataOld)) + def constructTaskDataList(taskDataWrapperIter: Iterable[TaskDataWrapper]): Seq[v1.TaskData] = { + val executorIdToLogs = new HashMap[String, Map[String, String]]() + taskDataWrapperIter.map { taskDataWrapper => + val taskDataOld: v1.TaskData = taskDataWrapper.toApi + val executorLogs = executorIdToLogs.getOrElseUpdate(taskDataOld.executorId, { + try { + executorSummary(taskDataOld.executorId).executorLogs + } catch { + case e: NoSuchElementException => + Map.empty + } + }) + + new v1.TaskData(taskDataOld.taskId, taskDataOld.index, + taskDataOld.attempt, taskDataOld.launchTime, taskDataOld.resultFetchStart, + taskDataOld.duration, taskDataOld.executorId, taskDataOld.host, taskDataOld.status, + taskDataOld.taskLocality, taskDataOld.speculative, taskDataOld.accumulatorUpdates, + taskDataOld.errorMessage, taskDataOld.taskMetrics, + executorLogs, + AppStatusUtils.schedulerDelay(taskDataOld), + AppStatusUtils.gettingResultTime(taskDataOld)) + }.toSeq } - } private[spark] object AppStatusStore { From 0996b7c95a79fb018169ed1da7a8e3e482260838 Mon Sep 17 00:00:00 2001 From: seancxmao Date: Mon, 31 Dec 2018 08:24:18 -0600 Subject: [PATCH 2348/2461] [SPARK-23375][SQL][FOLLOWUP][TEST] Test Sort metrics while Sort is missing ## What changes were proposed in this pull request? #20560/[SPARK-23375](https://issues.apache.org/jira/browse/SPARK-23375) introduced an optimizer rule to eliminate redundant Sort. For a test case named "Sort metrics" in `SQLMetricsSuite`, because range is already sorted, sort is removed by the `RemoveRedundantSorts`, which makes this test case meaningless. This PR modifies the query for testing Sort metrics and checks Sort exists in the plan. ## How was this patch tested? Modify the existing test case. Closes #23258 from seancxmao/sort-metrics. Authored-by: seancxmao Signed-off-by: Sean Owen --- .../execution/metric/SQLMetricsSuite.scala | 18 ++++++-- .../metric/SQLMetricsTestUtils.scala | 43 ++++++++++++++++--- 2 files changed, 51 insertions(+), 10 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 47265df4831df..7368a6c9e1d64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -194,10 +194,20 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared } test("Sort metrics") { - // Assume the execution plan is - // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1)) - val ds = spark.range(10).sort('id) - testSparkPlanMetrics(ds.toDF(), 2, Map.empty) + // Assume the execution plan with node id is + // Sort(nodeId = 0) + // Exchange(nodeId = 1) + // Project(nodeId = 2) + // LocalTableScan(nodeId = 3) + // Because of SPARK-25267, ConvertToLocalRelation is disabled in the test cases of sql/core, + // so Project here is not collapsed into LocalTableScan. + val df = Seq(1, 3, 2).toDF("id").sort('id) + testSparkPlanMetricsWithPredicates(df, 2, Map( + 0L -> (("Sort", Map( + "sort time total (min, med, max)" -> {_.toString.matches(timingMetricPattern)}, + "peak memory total (min, med, max)" -> {_.toString.matches(sizeMetricPattern)}, + "spill size total (min, med, max)" -> {_.toString.matches(sizeMetricPattern)}))) + )) } test("SortMergeJoin metrics") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index dcc540fc4f109..2d245d2ba1e35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -40,6 +40,18 @@ trait SQLMetricsTestUtils extends SQLTestUtils { protected def statusStore: SQLAppStatusStore = spark.sharedState.statusStore + // Pattern of size SQLMetric value, e.g. "\n96.2 MiB (32.1 MiB, 32.1 MiB, 32.1 MiB)" + protected val sizeMetricPattern = { + val bytes = "([0-9]+(\\.[0-9]+)?) (EiB|PiB|TiB|GiB|MiB|KiB|B)" + s"\\n$bytes \\($bytes, $bytes, $bytes\\)" + } + + // Pattern of timing SQLMetric value, e.g. "\n2.0 ms (1.0 ms, 1.0 ms, 1.0 ms)" + protected val timingMetricPattern = { + val duration = "([0-9]+(\\.[0-9]+)?) (ms|s|m|h)" + s"\\n$duration \\($duration, $duration, $duration\\)" + } + /** * Get execution metrics for the SQL execution and verify metrics values. * @@ -185,15 +197,34 @@ trait SQLMetricsTestUtils extends SQLTestUtils { df: DataFrame, expectedNumOfJobs: Int, expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { - val optActualMetrics = getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetrics.keySet) + val expectedMetricsPredicates = expectedMetrics.mapValues { case (nodeName, nodeMetrics) => + (nodeName, nodeMetrics.mapValues(expectedMetricValue => + (actualMetricValue: Any) => expectedMetricValue.toString === actualMetricValue)) + } + testSparkPlanMetricsWithPredicates(df, expectedNumOfJobs, expectedMetricsPredicates) + } + + /** + * Call `df.collect()` and verify if the collected metrics satisfy the specified predicates. + * @param df `DataFrame` to run + * @param expectedNumOfJobs number of jobs that will run + * @param expectedMetricsPredicates the expected metrics predicates. The format is + * `nodeId -> (operatorName, metric name -> metric predicate)`. + */ + protected def testSparkPlanMetricsWithPredicates( + df: DataFrame, + expectedNumOfJobs: Int, + expectedMetricsPredicates: Map[Long, (String, Map[String, Any => Boolean])]): Unit = { + val optActualMetrics = + getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetricsPredicates.keySet) optActualMetrics.foreach { actualMetrics => - assert(expectedMetrics.keySet === actualMetrics.keySet) - for (nodeId <- expectedMetrics.keySet) { - val (expectedNodeName, expectedMetricsMap) = expectedMetrics(nodeId) + assert(expectedMetricsPredicates.keySet === actualMetrics.keySet) + for ((nodeId, (expectedNodeName, expectedMetricsPredicatesMap)) + <- expectedMetricsPredicates) { val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId) assert(expectedNodeName === actualNodeName) - for (metricName <- expectedMetricsMap.keySet) { - assert(expectedMetricsMap(metricName).toString === actualMetricsMap(metricName)) + for ((metricName, metricPredicate) <- expectedMetricsPredicatesMap) { + assert(metricPredicate(actualMetricsMap(metricName))) } } } From 89c92ccc2046d068aea23ae7973a97c58cfdc966 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 31 Dec 2018 16:39:46 +0100 Subject: [PATCH 2349/2461] [SPARK-26504][SQL] Rope-wise dumping of Spark plans ## What changes were proposed in this pull request? Proposed new class `StringConcat` for converting a sequence of strings to string with one memory allocation in the `toString` method. `StringConcat` replaces `StringBuilderWriter` in methods of dumping of Spark plans and codegen to strings. All `Writer` arguments are replaced by `String => Unit` in methods related to Spark plans stringification. ## How was this patch tested? It was tested by existing suites `QueryExecutionSuite`, `DebuggingSuite` as well as new tests for `StringConcat` in `StringUtilsSuite`. Closes #23406 from MaxGekk/rope-plan. Authored-by: Maxim Gekk Signed-off-by: Herman van Hovell --- .../spark/sql/catalyst/plans/QueryPlan.scala | 17 ++++ .../spark/sql/catalyst/trees/TreeNode.scala | 38 ++++----- .../spark/sql/catalyst/util/StringUtils.scala | 32 +++++++ .../sql/catalyst/util/StringUtilsSuite.scala | 13 +++ .../spark/sql/execution/QueryExecution.scala | 85 +++++++++---------- .../sql/execution/WholeStageCodegenExec.scala | 8 +- .../spark/sql/execution/debug/package.scala | 27 +++--- 7 files changed, 133 insertions(+), 87 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 125181fb213f8..8f5444ed8a5a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode} import org.apache.spark.sql.internal.SQLConf @@ -301,4 +302,20 @@ object QueryPlan extends PredicateHelper { Nil } } + + /** + * Converts the query plan to string and appends it via provided function. + */ + def append[T <: QueryPlan[T]]( + plan: => QueryPlan[T], + append: String => Unit, + verbose: Boolean, + addSuffix: Boolean, + maxFields: Int = SQLConf.get.maxToStringFields): Unit = { + try { + plan.treeString(append, verbose, addSuffix, maxFields) + } catch { + case e: AnalysisException => append(e.toString) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 21e59bbd283e4..570a019b2af77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.catalyst.trees -import java.io.Writer import java.util.UUID import scala.collection.Map import scala.reflect.ClassTag -import org.apache.commons.io.output.StringBuilderWriter import org.apache.commons.lang3.ClassUtils import org.json4s.JsonAST._ import org.json4s.JsonDSL._ @@ -37,6 +35,7 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} +import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -481,21 +480,18 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { verbose: Boolean, addSuffix: Boolean = false, maxFields: Int = SQLConf.get.maxToStringFields): String = { - val writer = new StringBuilderWriter() - try { - treeString(writer, verbose, addSuffix, maxFields) - writer.toString - } finally { - writer.close() - } + val concat = new StringConcat() + + treeString(concat.append, verbose, addSuffix, maxFields) + concat.toString } def treeString( - writer: Writer, + append: String => Unit, verbose: Boolean, addSuffix: Boolean, maxFields: Int): Unit = { - generateTreeString(0, Nil, writer, verbose, "", addSuffix, maxFields) + generateTreeString(0, Nil, append, verbose, "", addSuffix, maxFields) } /** @@ -558,7 +554,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { def generateTreeString( depth: Int, lastChildren: Seq[Boolean], - writer: Writer, + append: String => Unit, verbose: Boolean, prefix: String = "", addSuffix: Boolean = false, @@ -566,9 +562,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { if (depth > 0) { lastChildren.init.foreach { isLast => - writer.write(if (isLast) " " else ": ") + append(if (isLast) " " else ": ") } - writer.write(if (lastChildren.last) "+- " else ":- ") + append(if (lastChildren.last) "+- " else ":- ") } val str = if (verbose) { @@ -576,24 +572,24 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } else { simpleString(maxFields) } - writer.write(prefix) - writer.write(str) - writer.write("\n") + append(prefix) + append(str) + append("\n") if (innerChildren.nonEmpty) { innerChildren.init.foreach(_.generateTreeString( - depth + 2, lastChildren :+ children.isEmpty :+ false, writer, verbose, + depth + 2, lastChildren :+ children.isEmpty :+ false, append, verbose, addSuffix = addSuffix, maxFields = maxFields)) innerChildren.last.generateTreeString( - depth + 2, lastChildren :+ children.isEmpty :+ true, writer, verbose, + depth + 2, lastChildren :+ children.isEmpty :+ true, append, verbose, addSuffix = addSuffix, maxFields = maxFields) } if (children.nonEmpty) { children.init.foreach(_.generateTreeString( - depth + 1, lastChildren :+ false, writer, verbose, prefix, addSuffix, maxFields)) + depth + 1, lastChildren :+ false, append, verbose, prefix, addSuffix, maxFields)) children.last.generateTreeString( - depth + 1, lastChildren :+ true, writer, verbose, prefix, addSuffix, maxFields) + depth + 1, lastChildren :+ true, append, verbose, prefix, addSuffix, maxFields) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index bc861a805ce61..643b83b1741ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.util import java.util.regex.{Pattern, PatternSyntaxException} +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.AnalysisException import org.apache.spark.unsafe.types.UTF8String @@ -87,4 +89,34 @@ object StringUtils { } funcNames.toSeq } + + /** + * Concatenation of sequence of strings to final string with cheap append method + * and one memory allocation for the final string. + */ + class StringConcat { + private val strings = new ArrayBuffer[String] + private var length: Int = 0 + + /** + * Appends a string and accumulates its length to allocate a string buffer for all + * appended strings once in the toString method. + */ + def append(s: String): Unit = { + if (s != null) { + strings.append(s) + length += s.length + } + } + + /** + * The method allocates memory for all appended strings, writes them to the memory and + * returns concatenated string. + */ + override def toString: String = { + val result = new java.lang.StringBuilder(length) + strings.foreach(result.append) + result.toString + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala index 78fee5135c3ae..616ec12032dbd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala @@ -43,4 +43,17 @@ class StringUtilsSuite extends SparkFunSuite { assert(filterPattern(names, " a. ") === Seq("a1", "a2")) assert(filterPattern(names, " d* ") === Nil) } + + test("string concatenation") { + def concat(seq: String*): String = { + seq.foldLeft(new StringConcat())((acc, s) => {acc.append(s); acc}).toString + } + + assert(new StringConcat().toString == "") + assert(concat("") == "") + assert(concat(null) == "") + assert(concat("a") == "a") + assert(concat("1", "2") == "12") + assert(concat("abc", "\n", "123") == "abc\n123") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 9b8d2e830867d..7fccbf65d8525 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -17,20 +17,21 @@ package org.apache.spark.sql.execution -import java.io.{BufferedWriter, OutputStreamWriter, Writer} +import java.io.{BufferedWriter, OutputStreamWriter} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} -import org.apache.commons.io.output.StringBuilderWriter import org.apache.hadoop.fs.Path import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} @@ -108,10 +109,6 @@ class QueryExecution( ReuseExchange(sparkSession.sessionState.conf), ReuseSubquery(sparkSession.sessionState.conf)) - protected def stringOrError[A](f: => A): String = - try f.toString catch { case e: AnalysisException => e.toString } - - /** * Returns the result as a hive compatible sequence of strings. This is used in tests and * `SparkSQLDriver` for CLI applications. @@ -197,55 +194,53 @@ class QueryExecution( } def simpleString: String = withRedaction { - s"""== Physical Plan == - |${stringOrError(executedPlan.treeString(verbose = false))} - """.stripMargin.trim - } - - private def writeOrError(writer: Writer)(f: Writer => Unit): Unit = { - try f(writer) - catch { - case e: AnalysisException => writer.write(e.toString) - } + val concat = new StringConcat() + concat.append("== Physical Plan ==\n") + QueryPlan.append(executedPlan, concat.append, verbose = false, addSuffix = false) + concat.append("\n") + concat.toString } - private def writePlans(writer: Writer, maxFields: Int): Unit = { + private def writePlans(append: String => Unit, maxFields: Int): Unit = { val (verbose, addSuffix) = (true, false) - - writer.write("== Parsed Logical Plan ==\n") - writeOrError(writer)(logical.treeString(_, verbose, addSuffix, maxFields)) - writer.write("\n== Analyzed Logical Plan ==\n") - val analyzedOutput = stringOrError(truncatedString( - analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ", maxFields)) - writer.write(analyzedOutput) - writer.write("\n") - writeOrError(writer)(analyzed.treeString(_, verbose, addSuffix, maxFields)) - writer.write("\n== Optimized Logical Plan ==\n") - writeOrError(writer)(optimizedPlan.treeString(_, verbose, addSuffix, maxFields)) - writer.write("\n== Physical Plan ==\n") - writeOrError(writer)(executedPlan.treeString(_, verbose, addSuffix, maxFields)) + append("== Parsed Logical Plan ==\n") + QueryPlan.append(logical, append, verbose, addSuffix, maxFields) + append("\n== Analyzed Logical Plan ==\n") + val analyzedOutput = try { + truncatedString( + analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ", maxFields) + } catch { + case e: AnalysisException => e.toString + } + append(analyzedOutput) + append("\n") + QueryPlan.append(analyzed, append, verbose, addSuffix, maxFields) + append("\n== Optimized Logical Plan ==\n") + QueryPlan.append(optimizedPlan, append, verbose, addSuffix, maxFields) + append("\n== Physical Plan ==\n") + QueryPlan.append(executedPlan, append, verbose, addSuffix, maxFields) } override def toString: String = withRedaction { - val writer = new StringBuilderWriter() - try { - writePlans(writer, SQLConf.get.maxToStringFields) - writer.toString - } finally { - writer.close() - } + val concat = new StringConcat() + writePlans(concat.append, SQLConf.get.maxToStringFields) + concat.toString } def stringWithStats: String = withRedaction { + val concat = new StringConcat() + val maxFields = SQLConf.get.maxToStringFields + // trigger to compute stats for logical plans optimizedPlan.stats // only show optimized logical plan and physical plan - s"""== Optimized Logical Plan == - |${stringOrError(optimizedPlan.treeString(verbose = true, addSuffix = true))} - |== Physical Plan == - |${stringOrError(executedPlan.treeString(verbose = true))} - """.stripMargin.trim + concat.append("== Optimized Logical Plan ==\n") + QueryPlan.append(optimizedPlan, concat.append, verbose = true, addSuffix = true, maxFields) + concat.append("\n== Physical Plan ==\n") + QueryPlan.append(executedPlan, concat.append, verbose = true, addSuffix = false, maxFields) + concat.append("\n") + concat.toString } /** @@ -282,7 +277,7 @@ class QueryExecution( /** * Dumps debug information about query execution into the specified file. * - * @param maxFields maximim number of fields converted to string representation. + * @param maxFields maximum number of fields converted to string representation. */ def toFile(path: String, maxFields: Int = Int.MaxValue): Unit = { val filePath = new Path(path) @@ -290,9 +285,9 @@ class QueryExecution( val writer = new BufferedWriter(new OutputStreamWriter(fs.create(filePath))) try { - writePlans(writer, maxFields) + writePlans(writer.write, maxFields) writer.write("\n== Whole Stage Codegen ==\n") - org.apache.spark.sql.execution.debug.writeCodegen(writer, executedPlan) + org.apache.spark.sql.execution.debug.writeCodegen(writer.write, executedPlan) } finally { writer.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index f4927dedabe56..3b0a99669ccd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -493,7 +493,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with InputRDDCod override def generateTreeString( depth: Int, lastChildren: Seq[Boolean], - writer: Writer, + append: String => Unit, verbose: Boolean, prefix: String = "", addSuffix: Boolean = false, @@ -501,7 +501,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with InputRDDCod child.generateTreeString( depth, lastChildren, - writer, + append, verbose, prefix = "", addSuffix = false, @@ -777,7 +777,7 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) override def generateTreeString( depth: Int, lastChildren: Seq[Boolean], - writer: Writer, + append: String => Unit, verbose: Boolean, prefix: String = "", addSuffix: Boolean = false, @@ -785,7 +785,7 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) child.generateTreeString( depth, lastChildren, - writer, + append, verbose, s"*($codegenStageId) ", false, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index ae8197f617a28..53b74c7c85594 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -17,13 +17,10 @@ package org.apache.spark.sql.execution -import java.io.Writer import java.util.Collections import scala.collection.JavaConverters._ -import org.apache.commons.io.output.StringBuilderWriter - import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ @@ -32,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.trees.TreeNodeRef +import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat import org.apache.spark.sql.execution.streaming.{StreamExecution, StreamingQueryWrapper} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQuery @@ -73,24 +71,19 @@ package object debug { * @return single String containing all WholeStageCodegen subtrees and corresponding codegen */ def codegenString(plan: SparkPlan): String = { - val writer = new StringBuilderWriter() - - try { - writeCodegen(writer, plan) - writer.toString - } finally { - writer.close() - } + val concat = new StringConcat() + writeCodegen(concat.append, plan) + concat.toString } - def writeCodegen(writer: Writer, plan: SparkPlan): Unit = { + def writeCodegen(append: String => Unit, plan: SparkPlan): Unit = { val codegenSeq = codegenStringSeq(plan) - writer.write(s"Found ${codegenSeq.size} WholeStageCodegen subtrees.\n") + append(s"Found ${codegenSeq.size} WholeStageCodegen subtrees.\n") for (((subtree, code), i) <- codegenSeq.zipWithIndex) { - writer.write(s"== Subtree ${i + 1} / ${codegenSeq.size} ==\n") - writer.write(subtree) - writer.write("\nGenerated code:\n") - writer.write(s"${code}\n") + append(s"== Subtree ${i + 1} / ${codegenSeq.size} ==\n") + append(subtree) + append("\nGenerated code:\n") + append(s"${code}\n") } } From c0b9db120d4c2ad0b5b99b9152549e94ef8f5a2d Mon Sep 17 00:00:00 2001 From: Hirobe Keiichi Date: Mon, 31 Dec 2018 10:15:14 -0600 Subject: [PATCH 2350/2461] [SPARK-26339][SQL] Throws better exception when reading files that start with underscore ## What changes were proposed in this pull request? As the description in SPARK-26339, spark.read behavior is very confusing when reading files that start with underscore, fix this by throwing exception which message is "Path does not exist". ## How was this patch tested? manual tests. Both of codes below throws exception which message is "Path does not exist". ``` spark.read.csv("/home/forcia/work/spark/_test.csv") spark.read.schema("test STRING, number INT").csv("/home/forcia/work/spark/_test.csv") ``` Closes #23288 from KeiichiHirobe/SPARK-26339. Authored-by: Hirobe Keiichi Signed-off-by: Sean Owen --- .../execution/datasources/DataSource.scala | 17 +++++++++++++++- .../src/test/resources/test-data/_cars.csv | 7 +++++++ .../execution/datasources/csv/CSVSuite.scala | 20 +++++++++++++++++++ 3 files changed, 43 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/test/resources/test-data/_cars.csv diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index fefff68c4ba8b..517e04317d94e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -543,7 +543,7 @@ case class DataSource( checkFilesExist: Boolean): Seq[Path] = { val allPaths = caseInsensitiveOptions.get("path") ++ paths val hadoopConf = sparkSession.sessionState.newHadoopConf() - allPaths.flatMap { path => + val allGlobPath = allPaths.flatMap { path => val hdfsPath = new Path(path) val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) @@ -560,6 +560,21 @@ case class DataSource( } globPath }.toSeq + + val (filteredOut, filteredIn) = allGlobPath.partition { path => + InMemoryFileIndex.shouldFilterOut(path.getName) + } + if (filteredOut.nonEmpty) { + if (filteredIn.isEmpty) { + throw new AnalysisException( + s"All paths were ignored:\n${filteredOut.mkString("\n ")}") + } else { + logDebug( + s"Some paths were ignored:\n${filteredOut.mkString("\n ")}") + } + } + + allGlobPath } } diff --git a/sql/core/src/test/resources/test-data/_cars.csv b/sql/core/src/test/resources/test-data/_cars.csv new file mode 100644 index 0000000000000..40ded573ade5c --- /dev/null +++ b/sql/core/src/test/resources/test-data/_cars.csv @@ -0,0 +1,7 @@ + +year,make,model,comment,blank +"2012","Tesla","S","No comment", + +1997,Ford,E350,"Go get one now they are going fast", +2015,Chevy,Volt + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index d9e5d7af19671..fb1bedfaa32c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -53,6 +53,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te private val carsEmptyValueFile = "test-data/cars-empty-value.csv" private val carsBlankColName = "test-data/cars-blank-column-name.csv" private val carsCrlf = "test-data/cars-crlf.csv" + private val carsFilteredOutFile = "test-data/_cars.csv" private val emptyFile = "test-data/empty.csv" private val commentsFile = "test-data/comments.csv" private val disableCommentsFile = "test-data/disable_comments.csv" @@ -346,6 +347,25 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te assert(result.schema.fieldNames.size === 1) } + test("SPARK-26339 Not throw an exception if some of specified paths are filtered in") { + val cars = spark + .read + .option("header", "false") + .csv(testFile(carsFile), testFile(carsFilteredOutFile)) + + verifyCars(cars, withHeader = false, checkTypes = false) + } + + test("SPARK-26339 Throw an exception only if all of the specified paths are filtered out") { + val e = intercept[AnalysisException] { + val cars = spark + .read + .option("header", "false") + .csv(testFile(carsFilteredOutFile)) + }.getMessage + assert(e.contains("All paths were ignored:")) + } + test("DDL test with empty file") { withView("carsTable") { spark.sql( From c0368363f8a81dd739c6c90fb2849b2a3ab4d8e4 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 31 Dec 2018 17:46:06 +0100 Subject: [PATCH 2351/2461] [SPARK-26495][SQL] Simplify the SelectedField extractor. ## What changes were proposed in this pull request? The current `SelectedField` extractor is somewhat complicated and it seems to be handling cases that should be handled automatically: - `GetArrayItem(child: GetStructFieldObject())` - `GetArrayStructFields(child: GetArrayStructFields())` - `GetMap(value: GetStructFieldObject())` This PR removes those cases and simplifies the extractor by passing down the data type instead of a field. ## How was this patch tested? Existing tests. Closes #23397 from hvanhovell/SPARK-26495. Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../spark/sql/execution/SelectedField.scala | 103 +++++++----------- 1 file changed, 41 insertions(+), 62 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala index 0e7c593f9fb67..68f797a856a18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -51,8 +52,6 @@ import org.apache.spark.sql.types._ * type appropriate to the complex type extractor. In our example, the name of the child expression * is "name" and its data type is a [[org.apache.spark.sql.types.StructType]] with a single string * field named "first". - * - * @param expr the top-level complex type extractor */ private[execution] object SelectedField { def unapply(expr: Expression): Option[StructField] = { @@ -64,71 +63,51 @@ private[execution] object SelectedField { selectField(unaliased, None) } - private def selectField(expr: Expression, fieldOpt: Option[StructField]): Option[StructField] = { + /** + * Convert an expression into the parts of the schema (the field) it accesses. + */ + private def selectField(expr: Expression, dataTypeOpt: Option[DataType]): Option[StructField] = { expr match { - // No children. Returns a StructField with the attribute name or None if fieldOpt is None. - case AttributeReference(name, dataType, nullable, metadata) => - fieldOpt.map(field => - StructField(name, wrapStructType(dataType, field), nullable, metadata)) - // Handles case "expr0.field[n]", where "expr0" is of struct type and "expr0.field" is of - // array type. - case GetArrayItem(x @ GetStructFieldObject(child, field @ StructField(name, - dataType, nullable, metadata)), _) => - val childField = fieldOpt.map(field => StructField(name, - wrapStructType(dataType, field), nullable, metadata)).getOrElse(field) - selectField(child, Some(childField)) - // Handles case "expr0.field[n]", where "expr0.field" is of array type. - case GetArrayItem(child, _) => - selectField(child, fieldOpt) - // Handles case "expr0.field.subfield", where "expr0" and "expr0.field" are of array type. - case GetArrayStructFields(child: GetArrayStructFields, - field @ StructField(name, dataType, nullable, metadata), _, _, _) => - val childField = fieldOpt.map(field => StructField(name, - wrapStructType(dataType, field), - nullable, metadata)).orElse(Some(field)) - selectField(child, childField) - // Handles case "expr0.field", where "expr0" is of array type. - case GetArrayStructFields(child, - field @ StructField(name, dataType, nullable, metadata), _, _, _) => - val childField = - fieldOpt.map(field => StructField(name, - wrapStructType(dataType, field), - nullable, metadata)).orElse(Some(field)) - selectField(child, childField) - // Handles case "expr0.field[key]", where "expr0" is of struct type and "expr0.field" is of - // map type. - case GetMapValue(x @ GetStructFieldObject(child, field @ StructField(name, - dataType, - nullable, metadata)), _) => - val childField = fieldOpt.map(field => StructField(name, - wrapStructType(dataType, field), - nullable, metadata)).orElse(Some(field)) - selectField(child, childField) - // Handles case "expr0.field[key]", where "expr0.field" is of map type. + case a: Attribute => + dataTypeOpt.map { dt => + StructField(a.name, dt, a.nullable) + } + case c: GetStructField => + val field = c.childSchema(c.ordinal) + val newField = field.copy(dataType = dataTypeOpt.getOrElse(field.dataType)) + selectField(c.child, Option(struct(newField))) + case GetArrayStructFields(child, field, _, _, containsNull) => + val newFieldDataType = dataTypeOpt match { + case None => + // GetArrayStructFields is the top level extractor. This means its result is + // not pruned and we need to use the element type of the array its producing. + field.dataType + case Some(ArrayType(dataType, _)) => + // GetArrayStructFields is part of a chain of extractors and its result is pruned + // by a parent expression. In this case need to use the parent element type. + dataType + case Some(x) => + // This should not happen. + throw new AnalysisException(s"DataType '$x' is not supported by GetArrayStructFields.") + } + val newField = StructField(field.name, newFieldDataType, field.nullable) + selectField(child, Option(ArrayType(struct(newField), containsNull))) case GetMapValue(child, _) => - selectField(child, fieldOpt) - // Handles case "expr0.field", where expr0 is of struct type. - case GetStructFieldObject(child, - field @ StructField(name, dataType, nullable, metadata)) => - val childField = fieldOpt.map(field => StructField(name, - wrapStructType(dataType, field), - nullable, metadata)).orElse(Some(field)) - selectField(child, childField) + // GetMapValue does not select a field from a struct (i.e. prune the struct) so it can't be + // the top-level extractor. However it can be part of an extractor chain. + val MapType(keyType, _, valueContainsNull) = child.dataType + val opt = dataTypeOpt.map(dt => MapType(keyType, dt, valueContainsNull)) + selectField(child, opt) + case GetArrayItem(child, _) => + // GetArrayItem does not select a field from a struct (i.e. prune the struct) so it can't be + // the top-level extractor. However it can be part of an extractor chain. + val ArrayType(_, containsNull) = child.dataType + val opt = dataTypeOpt.map(dt => ArrayType(dt, containsNull)) + selectField(child, opt) case _ => None } } - // Constructs a composition of complex types with a StructType(Array(field)) at its core. Returns - // a StructType for a StructType, an ArrayType for an ArrayType and a MapType for a MapType. - private def wrapStructType(dataType: DataType, field: StructField): DataType = { - dataType match { - case _: StructType => - StructType(Array(field)) - case ArrayType(elementType, containsNull) => - ArrayType(wrapStructType(elementType, field), containsNull) - case MapType(keyType, valueType, valueContainsNull) => - MapType(keyType, wrapStructType(valueType, field), valueContainsNull) - } - } + private def struct(field: StructField): StructType = StructType(Array(field)) } From b1a9b5eff59f64c370cd7388761effdf2152a108 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 31 Dec 2018 13:35:02 -0800 Subject: [PATCH 2352/2461] [SPARK-26470][CORE] Use ConfigEntry for hardcoded configs for eventLog category ## What changes were proposed in this pull request? The PR makes hardcoded `spark.eventLog` configs to use `ConfigEntry` and put them in the `config` package. ## How was this patch tested? existing tests Closes #23395 from mgaido91/SPARK-26470. Authored-by: Marco Gaido Signed-off-by: Dongjoon Hyun --- .../scala/org/apache/spark/SparkContext.scala | 7 +++---- .../apache/spark/internal/config/package.scala | 9 +++++++++ .../spark/deploy/history/HistoryServerSuite.scala | 9 +++++---- .../scheduler/EventLoggingListenerSuite.scala | 15 ++++++++------- 4 files changed, 25 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 09cc346db0ed2..3475859c3ed69 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -230,7 +230,7 @@ class SparkContext(config: SparkConf) extends Logging { def deployMode: String = _conf.getOption("spark.submit.deployMode").getOrElse("client") def appName: String = _conf.get("spark.app.name") - private[spark] def isEventLogEnabled: Boolean = _conf.getBoolean("spark.eventLog.enabled", false) + private[spark] def isEventLogEnabled: Boolean = _conf.get(EVENT_LOG_ENABLED) private[spark] def eventLogDir: Option[URI] = _eventLogDir private[spark] def eventLogCodec: Option[String] = _eventLogCodec @@ -396,15 +396,14 @@ class SparkContext(config: SparkConf) extends Logging { _eventLogDir = if (isEventLogEnabled) { - val unresolvedDir = conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR) - .stripSuffix("/") + val unresolvedDir = conf.get(EVENT_LOG_DIR).stripSuffix("/") Some(Utils.resolveURI(unresolvedDir)) } else { None } _eventLogCodec = { - val compress = _conf.getBoolean("spark.eventLog.compress", false) + val compress = _conf.get(EVENT_LOG_COMPRESS) if (compress && isEventLogEnabled) { Some(CompressionCodec.getCodecName(_conf)).map(CompressionCodec.getShortName) } else { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index f1c1c034df49a..d8e9c099028f5 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -21,6 +21,7 @@ import java.util.concurrent.TimeUnit import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.ByteUnit +import org.apache.spark.scheduler.EventLoggingListener import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.Utils @@ -62,6 +63,14 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val EVENT_LOG_ENABLED = ConfigBuilder("spark.eventLog.enabled") + .booleanConf + .createWithDefault(false) + + private[spark] val EVENT_LOG_DIR = ConfigBuilder("spark.eventLog.dir") + .stringConf + .createWithDefault(EventLoggingListener.DEFAULT_LOG_DIR) + private[spark] val EVENT_LOG_COMPRESS = ConfigBuilder("spark.eventLog.compress") .booleanConf diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index a9dee67ae9383..96458c55b5f55 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -45,6 +45,7 @@ import org.scalatest.mockito.MockitoSugar import org.scalatest.selenium.WebBrowser import org.apache.spark._ +import org.apache.spark.internal.config._ import org.apache.spark.internal.config.History._ import org.apache.spark.status.api.v1.ApplicationInfo import org.apache.spark.status.api.v1.JobData @@ -82,8 +83,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers .set(UPDATE_INTERVAL_S.key, "0") .set("spark.testing", "true") .set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) - .set("spark.eventLog.logStageExecutorMetrics.enabled", "true") - .set("spark.eventLog.logStageExecutorProcessTreeMetrics.enabled", "true") + .set(EVENT_LOG_STAGE_EXECUTOR_METRICS, true) + .set(EVENT_LOG_PROCESS_TREE_METRICS, true) conf.setAll(extraConf) provider = new FsHistoryProvider(conf) provider.checkForLogs() @@ -417,9 +418,9 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers stop() val myConf = new SparkConf() .set(HISTORY_LOG_DIR, logDir.getAbsolutePath) - .set("spark.eventLog.dir", logDir.getAbsolutePath) + .set(EVENT_LOG_DIR, logDir.getAbsolutePath) .set(UPDATE_INTERVAL_S.key, "1s") - .set("spark.eventLog.enabled", "true") + .set(EVENT_LOG_ENABLED, true) .set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) .remove("spark.testing") val provider = new FsHistoryProvider(myConf) diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 0c04a93646d7c..04987e6ef79ee 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.io._ import org.apache.spark.metrics.{ExecutorMetricType, MetricsSystem} import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -122,7 +123,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit // Expected IOException, since we haven't enabled log overwrite. intercept[IOException] { testEventLogging() } // Try again, but enable overwriting. - testEventLogging(extraConf = Map("spark.eventLog.overwrite" -> "true")) + testEventLogging(extraConf = Map(EVENT_LOG_OVERWRITE.key -> "true")) } test("Event log name") { @@ -526,15 +527,15 @@ object EventLoggingListenerSuite { /** Get a SparkConf with event logging enabled. */ def getLoggingConf(logDir: Path, compressionCodec: Option[String] = None): SparkConf = { val conf = new SparkConf - conf.set("spark.eventLog.enabled", "true") - conf.set("spark.eventLog.logBlockUpdates.enabled", "true") - conf.set("spark.eventLog.testing", "true") - conf.set("spark.eventLog.dir", logDir.toString) + conf.set(EVENT_LOG_ENABLED, true) + conf.set(EVENT_LOG_BLOCK_UPDATES, true) + conf.set(EVENT_LOG_TESTING, true) + conf.set(EVENT_LOG_DIR, logDir.toString) compressionCodec.foreach { codec => - conf.set("spark.eventLog.compress", "true") + conf.set(EVENT_LOG_COMPRESS, true) conf.set("spark.io.compression.codec", codec) } - conf.set("spark.eventLog.logStageExecutorMetrics.enabled", "true") + conf.set(EVENT_LOG_STAGE_EXECUTOR_METRICS, true) conf } From 993736154b6a46ffd7c3218173a2653a3842bba0 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Tue, 1 Jan 2019 09:14:23 +0800 Subject: [PATCH 2353/2461] [MINOR] Fix inconsistency log level among delegation token providers ## What changes were proposed in this pull request? There's some inconsistency for log level while logging error messages in delegation token providers. (DEBUG, INFO, WARNING) Given that failing to obtain token would often crash the query, I guess it would be nice to set higher log level for error log messages. ## How was this patch tested? The patch just changed the log level. Closes #23418 from HeartSaVioR/FIX-inconsistency-log-level-between-delegation-token-providers. Authored-by: Jungtaek Lim (HeartSaVioR) Signed-off-by: Hyukjin Kwon --- .../HBaseDelegationTokenProvider.scala | 4 +- .../HadoopFSDelegationTokenProvider.scala | 45 +++++++++++-------- .../HiveDelegationTokenProvider.scala | 4 +- .../KafkaDelegationTokenProvider.scala | 2 +- 4 files changed, 31 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala index 5dcde4ec3a8a4..6ef68351bc9b2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala @@ -50,7 +50,7 @@ private[security] class HBaseDelegationTokenProvider creds.addToken(token.getService, token) } catch { case NonFatal(e) => - logDebug(s"Failed to get token from service $serviceName", e) + logWarning(s"Failed to get token from service $serviceName", e) } None @@ -71,7 +71,7 @@ private[security] class HBaseDelegationTokenProvider confCreate.invoke(null, conf).asInstanceOf[Configuration] } catch { case NonFatal(e) => - logDebug("Fail to invoke HBaseConfiguration", e) + logWarning("Fail to invoke HBaseConfiguration", e) conf } } diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala index 767b5521e8d7b..00200f807d224 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy.security import scala.collection.JavaConverters._ import scala.util.Try +import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FileSystem @@ -44,28 +45,34 @@ private[deploy] class HadoopFSDelegationTokenProvider(fileSystems: () => Set[Fil hadoopConf: Configuration, sparkConf: SparkConf, creds: Credentials): Option[Long] = { - val fsToGetTokens = fileSystems() - val fetchCreds = fetchDelegationTokens(getTokenRenewer(hadoopConf), fsToGetTokens, creds) + try { + val fsToGetTokens = fileSystems() + val fetchCreds = fetchDelegationTokens(getTokenRenewer(hadoopConf), fsToGetTokens, creds) - // Get the token renewal interval if it is not set. It will only be called once. - if (tokenRenewalInterval == null) { - tokenRenewalInterval = getTokenRenewalInterval(hadoopConf, sparkConf, fsToGetTokens) - } + // Get the token renewal interval if it is not set. It will only be called once. + if (tokenRenewalInterval == null) { + tokenRenewalInterval = getTokenRenewalInterval(hadoopConf, sparkConf, fsToGetTokens) + } - // Get the time of next renewal. - val nextRenewalDate = tokenRenewalInterval.flatMap { interval => - val nextRenewalDates = fetchCreds.getAllTokens.asScala - .filter(_.decodeIdentifier().isInstanceOf[AbstractDelegationTokenIdentifier]) - .map { token => - val identifier = token - .decodeIdentifier() - .asInstanceOf[AbstractDelegationTokenIdentifier] - identifier.getIssueDate + interval - } - if (nextRenewalDates.isEmpty) None else Some(nextRenewalDates.min) - } + // Get the time of next renewal. + val nextRenewalDate = tokenRenewalInterval.flatMap { interval => + val nextRenewalDates = fetchCreds.getAllTokens.asScala + .filter(_.decodeIdentifier().isInstanceOf[AbstractDelegationTokenIdentifier]) + .map { token => + val identifier = token + .decodeIdentifier() + .asInstanceOf[AbstractDelegationTokenIdentifier] + identifier.getIssueDate + interval + } + if (nextRenewalDates.isEmpty) None else Some(nextRenewalDates.min) + } - nextRenewalDate + nextRenewalDate + } catch { + case NonFatal(e) => + logWarning(s"Failed to get token from service $serviceName", e) + None + } } override def delegationTokensRequired( diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala index 7249eb85ac7c7..90f7051381571 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala @@ -49,7 +49,7 @@ private[spark] class HiveDelegationTokenProvider new HiveConf(hadoopConf, classOf[HiveConf]) } catch { case NonFatal(e) => - logDebug("Fail to create Hive Configuration", e) + logWarning("Fail to create Hive Configuration", e) hadoopConf case e: NoClassDefFoundError => logWarning(classNotFoundErrorStr) @@ -104,7 +104,7 @@ private[spark] class HiveDelegationTokenProvider None } catch { case NonFatal(e) => - logDebug(s"Failed to get token from service $serviceName", e) + logWarning(s"Failed to get token from service $serviceName", e) None case e: NoClassDefFoundError => logWarning(classNotFoundErrorStr) diff --git a/core/src/main/scala/org/apache/spark/deploy/security/KafkaDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/KafkaDelegationTokenProvider.scala index 45995be630cc5..f67cb26259fee 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/KafkaDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/KafkaDelegationTokenProvider.scala @@ -44,7 +44,7 @@ private[security] class KafkaDelegationTokenProvider return Some(nextRenewalDate) } catch { case NonFatal(e) => - logInfo(s"Failed to get token from service $serviceName", e) + logWarning(s"Failed to get token from service $serviceName", e) } None } From f7455618ce6de8d2e70f10722dc112fcc6ee3cee Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 1 Jan 2019 09:29:28 +0800 Subject: [PATCH 2354/2461] Revert "[SPARK-26339][SQL] Throws better exception when reading files that start with underscore" This reverts commit c0b9db120d4c2ad0b5b99b9152549e94ef8f5a2d. --- .../execution/datasources/DataSource.scala | 17 +--------------- .../src/test/resources/test-data/_cars.csv | 7 ------- .../execution/datasources/csv/CSVSuite.scala | 20 ------------------- 3 files changed, 1 insertion(+), 43 deletions(-) delete mode 100644 sql/core/src/test/resources/test-data/_cars.csv diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 517e04317d94e..fefff68c4ba8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -543,7 +543,7 @@ case class DataSource( checkFilesExist: Boolean): Seq[Path] = { val allPaths = caseInsensitiveOptions.get("path") ++ paths val hadoopConf = sparkSession.sessionState.newHadoopConf() - val allGlobPath = allPaths.flatMap { path => + allPaths.flatMap { path => val hdfsPath = new Path(path) val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) @@ -560,21 +560,6 @@ case class DataSource( } globPath }.toSeq - - val (filteredOut, filteredIn) = allGlobPath.partition { path => - InMemoryFileIndex.shouldFilterOut(path.getName) - } - if (filteredOut.nonEmpty) { - if (filteredIn.isEmpty) { - throw new AnalysisException( - s"All paths were ignored:\n${filteredOut.mkString("\n ")}") - } else { - logDebug( - s"Some paths were ignored:\n${filteredOut.mkString("\n ")}") - } - } - - allGlobPath } } diff --git a/sql/core/src/test/resources/test-data/_cars.csv b/sql/core/src/test/resources/test-data/_cars.csv deleted file mode 100644 index 40ded573ade5c..0000000000000 --- a/sql/core/src/test/resources/test-data/_cars.csv +++ /dev/null @@ -1,7 +0,0 @@ - -year,make,model,comment,blank -"2012","Tesla","S","No comment", - -1997,Ford,E350,"Go get one now they are going fast", -2015,Chevy,Volt - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index fb1bedfaa32c3..d9e5d7af19671 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -53,7 +53,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te private val carsEmptyValueFile = "test-data/cars-empty-value.csv" private val carsBlankColName = "test-data/cars-blank-column-name.csv" private val carsCrlf = "test-data/cars-crlf.csv" - private val carsFilteredOutFile = "test-data/_cars.csv" private val emptyFile = "test-data/empty.csv" private val commentsFile = "test-data/comments.csv" private val disableCommentsFile = "test-data/disable_comments.csv" @@ -347,25 +346,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te assert(result.schema.fieldNames.size === 1) } - test("SPARK-26339 Not throw an exception if some of specified paths are filtered in") { - val cars = spark - .read - .option("header", "false") - .csv(testFile(carsFile), testFile(carsFilteredOutFile)) - - verifyCars(cars, withHeader = false, checkTypes = false) - } - - test("SPARK-26339 Throw an exception only if all of the specified paths are filtered out") { - val e = intercept[AnalysisException] { - val cars = spark - .read - .option("header", "false") - .csv(testFile(carsFilteredOutFile)) - }.getMessage - assert(e.contains("All paths were ignored:")) - } - test("DDL test with empty file") { withView("carsTable") { spark.sql( From 5f0ddd2d6e2fdebf549207bbc4b13ca709eee3c4 Mon Sep 17 00:00:00 2001 From: Thomas D'Silva Date: Tue, 1 Jan 2019 14:11:14 +0800 Subject: [PATCH 2355/2461] [SPARK-26499][SQL] JdbcUtils.makeGetter does not handle ByteType MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …Type ## What changes were proposed in this pull request? Modifed JdbcUtils.makeGetter to handle ByteType. ## How was this patch tested? Added a new test to JDBCSuite that maps ```TINYINT``` to ```ByteType```. Closes #23400 from twdsilva/tiny_int_support. Authored-by: Thomas D'Silva Signed-off-by: Hyukjin Kwon --- .../datasources/jdbc/JdbcUtils.scala | 4 +++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 25 +++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index edea549748b47..922bef284c98e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -438,6 +438,10 @@ object JdbcUtils extends Logging { (rs: ResultSet, row: InternalRow, pos: Int) => row.setShort(pos, rs.getShort(pos + 1)) + case ByteType => + (rs: ResultSet, row: InternalRow, pos: Int) => + row.update(pos, rs.getByte(pos + 1)) + case StringType => (rs: ResultSet, row: InternalRow, pos: Int) => // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 71e83767964a0..e4641631e607d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -56,6 +56,20 @@ class JDBCSuite extends QueryTest Some(StringType) } + val testH2DialectTinyInt = new JdbcDialect { + override def canHandle(url: String): Boolean = url.startsWith("jdbc:h2") + override def getCatalystType( + sqlType: Int, + typeName: String, + size: Int, + md: MetadataBuilder): Option[DataType] = { + sqlType match { + case java.sql.Types.TINYINT => Some(ByteType) + case _ => None + } + } + } + before { Utils.classForName("org.h2.Driver") // Extra properties that will be specified for our database. We need these to test @@ -693,6 +707,17 @@ class JDBCSuite extends QueryTest JdbcDialects.unregisterDialect(testH2Dialect) } + test("Map TINYINT to ByteType via JdbcDialects") { + JdbcDialects.registerDialect(testH2DialectTinyInt) + val df = spark.read.jdbc(urlWithUserAndPass, "test.inttypes", new Properties()) + val rows = df.collect() + assert(rows.length === 2) + assert(rows(0).get(2).isInstanceOf[Byte]) + assert(rows(0).getByte(2) === 3) + assert(rows(1).isNullAt(2)) + JdbcDialects.unregisterDialect(testH2DialectTinyInt) + } + test("Default jdbc dialect registration") { assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") == MySQLDialect) assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == PostgresDialect) From 2bf4d97118c20812e26a7ea59826ee470ab42f7c Mon Sep 17 00:00:00 2001 From: zhoukang Date: Tue, 1 Jan 2019 09:13:13 -0600 Subject: [PATCH 2356/2461] [SPARK-24544][SQL] Print actual failure cause when look up function failed ## What changes were proposed in this pull request? When we operate as below: ` 0: jdbc:hive2://xxx/> create function funnel_analysis as 'com.xxx.hive.extend.udf.UapFunnelAnalysis'; ` ` 0: jdbc:hive2://xxx/> select funnel_analysis(1,",",1,''); Error: org.apache.spark.sql.AnalysisException: Undefined function: 'funnel_analysis'. This function is neither a registered temporary function nor a permanent function registered in the database 'xxx'.; line 1 pos 7 (state=,code=0) ` ` 0: jdbc:hive2://xxx/> describe function funnel_analysis; +-----------------------------------------------------------+--+ | function_desc | +-----------------------------------------------------------+--+ | Function: xxx.funnel_analysis | | Class: com.xxx.hive.extend.udf.UapFunnelAnalysis | | Usage: N/A. | +-----------------------------------------------------------+--+ ` We can see describe funtion will get right information,but when we actually use this funtion,we will get an undefined exception. Which is really misleading,the real cause is below: ` No handler for Hive UDF 'com.xxx.xxx.hive.extend.udf.UapFunnelAnalysis': java.lang.IllegalStateException: Should not be called directly; at org.apache.hadoop.hive.ql.udf.generic.GenericUDTF.initialize(GenericUDTF.java:72) at org.apache.spark.sql.hive.HiveGenericUDTF.outputInspector$lzycompute(hiveUDFs.scala:204) at org.apache.spark.sql.hive.HiveGenericUDTF.outputInspector(hiveUDFs.scala:204) at org.apache.spark.sql.hive.HiveGenericUDTF.elementSchema$lzycompute(hiveUDFs.scala:212) at org.apache.spark.sql.hive.HiveGenericUDTF.elementSchema(hiveUDFs.scala:212) ` This patch print the actual failure for quick debugging. ## How was this patch tested? UT Closes #21790 from caneGuy/zhoukang/print-warning1. Authored-by: zhoukang Signed-off-by: Sean Owen --- .../catalyst/analysis/NoSuchItemException.scala | 4 ++-- .../sql/catalyst/catalog/SessionCatalog.scala | 5 +++-- .../sql/catalyst/catalog/SessionCatalogSuite.scala | 14 ++++++++++++++ .../apache/spark/sql/hive/HiveSessionCatalog.scala | 9 ++++++--- 4 files changed, 25 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala index f5aae60431c15..8bf6f69f3b17a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala @@ -40,10 +40,10 @@ class NoSuchPartitionException( class NoSuchPermanentFunctionException(db: String, func: String) extends AnalysisException(s"Function '$func' not found in database '$db'") -class NoSuchFunctionException(db: String, func: String) +class NoSuchFunctionException(db: String, func: String, cause: Option[Throwable] = None) extends AnalysisException( s"Undefined function: '$func'. This function is neither a registered temporary function nor " + - s"a permanent function registered in the database '$db'.") + s"a permanent function registered in the database '$db'.", cause = cause) class NoSuchPartitionsException(db: String, table: String, specs: Seq[TablePartitionSpec]) extends AnalysisException( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index b6771ec4dffe9..1dbe946503e51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1222,9 +1222,10 @@ class SessionCatalog( databaseExists(db) && externalCatalog.functionExists(db, name.funcName) } - protected def failFunctionLookup(name: FunctionIdentifier): Nothing = { + protected[sql] def failFunctionLookup( + name: FunctionIdentifier, cause: Option[Throwable] = None): Nothing = { throw new NoSuchFunctionException( - db = name.database.getOrElse(getCurrentDatabase), func = name.funcName) + db = name.database.getOrElse(getCurrentDatabase), func = name.funcName, cause) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 19e8c0334689c..92f87ea796e87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -1448,4 +1448,18 @@ abstract class SessionCatalogSuite extends AnalysisTest { } } } + + test("SPARK-24544: test print actual failure cause when look up function failed") { + withBasicCatalog { catalog => + val cause = intercept[NoSuchFunctionException] { + catalog.failFunctionLookup(FunctionIdentifier("failureFunc"), + Some(new Exception("Actual error"))) + } + + // fullStackTrace will be printed, but `cause.getMessage` has been + // override in `AnalysisException`,so here we get the root cause + // exception message for check. + assert(cause.cause.get.getMessage.contains("Actual error")) + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 405c0c8bfe660..7560805bb3b09 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DecimalType, DoubleType} +import org.apache.spark.util.Utils private[sql] class HiveSessionCatalog( @@ -141,8 +142,10 @@ private[sql] class HiveSessionCatalog( // let's try to load it as a Hive's built-in function. // Hive is case insensitive. val functionName = funcName.unquotedString.toLowerCase(Locale.ROOT) + logWarning("Encountered a failure during looking up function:" + + s" ${Utils.exceptionString(error)}") if (!hiveFunctions.contains(functionName)) { - failFunctionLookup(funcName) + failFunctionLookup(funcName, Some(error)) } // TODO: Remove this fallback path once we implement the list of fallback functions @@ -150,12 +153,12 @@ private[sql] class HiveSessionCatalog( val functionInfo = { try { Option(HiveFunctionRegistry.getFunctionInfo(functionName)).getOrElse( - failFunctionLookup(funcName)) + failFunctionLookup(funcName, Some(error))) } catch { // If HiveFunctionRegistry.getFunctionInfo throws an exception, // we are failing to load a Hive builtin function, which means that // the given function is not a Hive builtin function. - case NonFatal(e) => failFunctionLookup(funcName) + case NonFatal(e) => failFunctionLookup(funcName, Some(e)) } } val className = functionInfo.getFunctionClass.getName From 001d3095385626e329b3853364a4feeb811aac5a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 1 Jan 2019 09:18:58 -0600 Subject: [PATCH 2357/2461] [SPARK-25765][ML] Add training cost to BisectingKMeans summary ## What changes were proposed in this pull request? The PR adds the `trainingCost` value to the `BisectingKMeansSummary`, in order to expose the information retrievable by running `computeCost` on the training dataset. This fills the gap with `KMeans` implementation. ## How was this patch tested? improved UTs Closes #22764 from mgaido91/SPARK-25765. Authored-by: Marco Gaido Signed-off-by: Sean Owen --- .../spark/ml/clustering/BisectingKMeans.scala | 13 +++- .../mllib/clustering/BisectingKMeans.scala | 3 +- .../clustering/BisectingKMeansModel.scala | 59 ++++++++++++++++--- .../ml/clustering/BisectingKMeansSuite.scala | 2 + .../clustering/BisectingKMeansSuite.scala | 1 + project/MimaExcludes.scala | 3 + python/pyspark/ml/clustering.py | 12 +++- 7 files changed, 82 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 49e9f51368131..d846f17e7f549 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -264,7 +264,12 @@ class BisectingKMeans @Since("2.0.0") ( val parentModel = bkm.run(rdd, Some(instr)) val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this)) val summary = new BisectingKMeansSummary( - model.transform(dataset), $(predictionCol), $(featuresCol), $(k), $(maxIter)) + model.transform(dataset), + $(predictionCol), + $(featuresCol), + $(k), + $(maxIter), + parentModel.trainingCost) instr.logNamedValue("clusterSizes", summary.clusterSizes) instr.logNumFeatures(model.clusterCenters.head.size) model.setSummary(Some(summary)) @@ -294,6 +299,8 @@ object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] { * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. * @param numIter Number of iterations. + * @param trainingCost Sum of the cost to the nearest centroid for all points in the training + * dataset. This is equivalent to sklearn's inertia. */ @Since("2.1.0") @Experimental @@ -302,4 +309,6 @@ class BisectingKMeansSummary private[clustering] ( predictionCol: String, featuresCol: String, k: Int, - numIter: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) + numIter: Int, + @Since("3.0.0") val trainingCost: Double) + extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala index 80ab8eb9bc8b0..696dff0f319a5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -242,7 +242,8 @@ class BisectingKMeans private ( norms.unpersist(false) val clusters = activeClusters ++ inactiveClusters val root = buildTree(clusters, dMeasure) - new BisectingKMeansModel(root, this.distanceMeasure) + val totalCost = root.leafNodes.map(_.cost).sum + new BisectingKMeansModel(root, this.distanceMeasure, totalCost) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala index 4c5794fbffc8e..b54b8917e060a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala @@ -41,11 +41,12 @@ import org.apache.spark.sql.{Row, SparkSession} @Since("1.6.0") class BisectingKMeansModel private[clustering] ( private[clustering] val root: ClusteringTreeNode, - @Since("2.4.0") val distanceMeasure: String + @Since("2.4.0") val distanceMeasure: String, + @Since("3.0.0") val trainingCost: Double ) extends Serializable with Saveable with Logging { @Since("1.6.0") - def this(root: ClusteringTreeNode) = this(root, DistanceMeasure.EUCLIDEAN) + def this(root: ClusteringTreeNode) = this(root, DistanceMeasure.EUCLIDEAN, 0.0) private val distanceMeasureInstance: DistanceMeasure = DistanceMeasure.decodeFromString(distanceMeasure) @@ -109,10 +110,10 @@ class BisectingKMeansModel private[clustering] ( @Since("2.0.0") override def save(sc: SparkContext, path: String): Unit = { - BisectingKMeansModel.SaveLoadV2_0.save(sc, this, path) + BisectingKMeansModel.SaveLoadV3_0.save(sc, this, path) } - override protected def formatVersion: String = "2.0" + override protected def formatVersion: String = "3.0" } @Since("2.0.0") @@ -128,11 +129,15 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { case (SaveLoadV2_0.thisClassName, SaveLoadV2_0.thisFormatVersion) => val model = SaveLoadV2_0.load(sc, path) model + case (SaveLoadV3_0.thisClassName, SaveLoadV3_0.thisFormatVersion) => + val model = SaveLoadV3_0.load(sc, path) + model case _ => throw new Exception( s"BisectingKMeansModel.load did not recognize model with (className, format version):" + s"($loadedClassName, $formatVersion). Supported:\n" + s" (${SaveLoadV1_0.thisClassName}, ${SaveLoadV1_0.thisClassName}\n" + - s" (${SaveLoadV2_0.thisClassName}, ${SaveLoadV2_0.thisClassName})") + s" (${SaveLoadV2_0.thisClassName}, ${SaveLoadV2_0.thisClassName})\n" + + s" (${SaveLoadV3_0.thisClassName}, ${SaveLoadV3_0.thisClassName})") } } @@ -195,7 +200,8 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { val data = rows.select("index", "size", "center", "norm", "cost", "height", "children") val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap val rootNode = buildTree(rootId, nodes) - new BisectingKMeansModel(rootNode, DistanceMeasure.EUCLIDEAN) + val totalCost = rootNode.leafNodes.map(_.cost).sum + new BisectingKMeansModel(rootNode, DistanceMeasure.EUCLIDEAN, totalCost) } } @@ -231,7 +237,46 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { val data = rows.select("index", "size", "center", "norm", "cost", "height", "children") val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap val rootNode = buildTree(rootId, nodes) - new BisectingKMeansModel(rootNode, distanceMeasure) + val totalCost = rootNode.leafNodes.map(_.cost).sum + new BisectingKMeansModel(rootNode, distanceMeasure, totalCost) + } + } + + private[clustering] object SaveLoadV3_0 { + private[clustering] val thisFormatVersion = "3.0" + + private[clustering] + val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel" + + def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = { + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) + ~ ("rootId" -> model.root.index) ~ ("distanceMeasure" -> model.distanceMeasure) + ~ ("trainingCost" -> model.trainingCost))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val data = getNodes(model.root).map(node => Data(node.index, node.size, + node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height, + node.children.map(_.index))) + spark.createDataFrame(data).write.parquet(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): BisectingKMeansModel = { + implicit val formats: DefaultFormats = DefaultFormats + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val rootId = (metadata \ "rootId").extract[Int] + val distanceMeasure = (metadata \ "distanceMeasure").extract[String] + val trainingCost = (metadata \ "trainingCost").extract[Double] + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val rows = spark.read.parquet(Loader.dataPath(path)) + Loader.checkSchema[Data](rows.schema) + val data = rows.select("index", "size", "center", "norm", "cost", "height", "children") + val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap + val rootNode = buildTree(rootId, nodes) + new BisectingKMeansModel(rootNode, distanceMeasure, trainingCost) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 1b7780e171e77..461f8b8d211d3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -134,6 +134,8 @@ class BisectingKMeansSuite extends MLTest with DefaultReadWriteTest { assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) assert(summary.numIter == 20) + assert(summary.trainingCost < 0.1) + assert(model.computeCost(dataset) == summary.trainingCost) model.setSummary(None) assert(!model.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala index 4a4d8b5c89de8..10d5f325d68e9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala @@ -194,6 +194,7 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.k === sameModel.k) assert(model.distanceMeasure === sameModel.distanceMeasure) model.clusterCenters.zip(sameModel.clusterCenters).foreach(c => c._1 === c._2) + assert(model.trainingCost == sameModel.trainingCost) } finally { Utils.deleteRecursively(tempDir) } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 89fc53ce3972f..cf8d9f3c24d07 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,9 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( + // [SPARK-25765][ML] Add training cost to BisectingKMeans summary + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.BisectingKMeansModel.this"), + // [SPARK-24243][CORE] Expose exceptions from InProcessAppHandle ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.launcher.SparkAppHandle.getError"), diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index d8a6dfb7d3a71..5a776aec14252 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -587,6 +587,8 @@ class BisectingKMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPred 2 >>> summary.clusterSizes [2, 2] + >>> summary.trainingCost + 2.000... >>> transformed = model.transform(df).select("features", "prediction") >>> rows = transformed.collect() >>> rows[0].prediction == rows[1].prediction @@ -700,7 +702,15 @@ class BisectingKMeansSummary(ClusteringSummary): .. versionadded:: 2.1.0 """ - pass + + @property + @since("3.0.0") + def trainingCost(self): + """ + Sum of squared distances to the nearest centroid for all points in the training dataset. + This is equivalent to sklearn's inertia. + """ + return self._call_java("trainingCost") @inherit_doc From 5da55873fa330f4ab21fb05a0a7dbf45bbeb5a54 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 2 Jan 2019 07:59:32 +0800 Subject: [PATCH 2358/2461] [SPARK-26374][TEST][SQL] Enable TimestampFormatter in HadoopFsRelationTest ## What changes were proposed in this pull request? Default timestamp pattern defined in `JSONOptions` doesn't allow saving/loading timestamps with time zones of seconds precision. Because of that, the round trip test failed for timestamps before 1582. In the PR, I propose to extend zone offset section from `XXX` to `XXXXX` which should allow to save/load zone offsets like `-07:52:48`. ## How was this patch tested? It was tested by `JsonHadoopFsRelationSuite` and `TimestampFormatterSuite`. Closes #23417 from MaxGekk/hadoopfsrelationtest-new-formatter. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: Hyukjin Kwon --- .../sql/util/TimestampFormatterSuite.scala | 32 ++++++++++--------- .../sql/sources/HadoopFsRelationTest.scala | 6 ++-- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala index edccbb2a7f5db..2ce3eacc30cc0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala @@ -70,21 +70,23 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper { } test("roundtrip micros -> timestamp -> micros using timezones") { - Seq( - -58710115316212000L, - -18926315945345679L, - -9463427405253013L, - -244000001L, - 0L, - 99628200102030L, - 1543749753123456L, - 2177456523456789L, - 11858049903010203L).foreach { micros => - DateTimeTestUtils.outstandingTimezones.foreach { timeZone => - val formatter = TimestampFormatter("yyyy-MM-dd'T'HH:mm:ss.SSSSSS", timeZone, Locale.US) - val timestamp = formatter.format(micros) - val parsed = formatter.parse(timestamp) - assert(micros === parsed) + Seq("yyyy-MM-dd'T'HH:mm:ss.SSSSSS", "yyyy-MM-dd'T'HH:mm:ss.SSSSSSXXXXX").foreach { pattern => + Seq( + -58710115316212000L, + -18926315945345679L, + -9463427405253013L, + -244000001L, + 0L, + 99628200102030L, + 1543749753123456L, + 2177456523456789L, + 11858049903010203L).foreach { micros => + DateTimeTestUtils.outstandingTimezones.foreach { timeZone => + val formatter = TimestampFormatter(pattern, timeZone, Locale.US) + val timestamp = formatter.format(micros) + val parsed = formatter.parse(timestamp) + assert(micros === parsed) + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index f0f62b608785d..57b896612bfe0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -126,8 +126,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } else { Seq(false) } - // TODO: Support new parser too, see SPARK-26374. - withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> "true") { + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> "false") { for (dataType <- supportedDataTypes) { for (parquetDictionaryEncodingEnabled <- parquetDictionaryEncodingEnabledConfs) { val extraMessage = if (isParquetDataSource) { @@ -138,7 +137,8 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes logInfo(s"Testing $dataType data type$extraMessage") val extraOptions = Map[String, String]( - "parquet.enable.dictionary" -> parquetDictionaryEncodingEnabled.toString + "parquet.enable.dictionary" -> parquetDictionaryEncodingEnabled.toString, + "timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss.SSSXXXXX" ) withTempPath { file => From 39a0493387d66a1e5c04f568804ebc83c2a5f644 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 2 Jan 2019 08:01:34 +0800 Subject: [PATCH 2359/2461] [SPARK-26227][R] from_[csv|json] should accept schema_of_[csv|json] in R API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? **1. Document `from_csv(..., schema_of_csv(...))` support:** ```R csv <- "Amsterdam,2018" df <- sql(paste0("SELECT '", csv, "' as csv")) head(select(df, from_csv(df$csv, schema_of_csv(csv)))) ``` ``` from_csv(csv) 1 Amsterdam, 2018 ``` **2. Allow `from_json(..., schema_of_json(...))`** Before: ```R df2 <- sql("SELECT named_struct('name', 'Bob') as people") df2 <- mutate(df2, people_json = to_json(df2$people)) head(select(df2, from_json(df2$people_json, schema_of_json(head(df2)$people_json)))) ``` ``` Error in (function (classes, fdef, mtable) : unable to find an inherited method for function ‘from_json’ for signature ‘"Column", "Column"’ ``` After: ```R df2 <- sql("SELECT named_struct('name', 'Bob') as people") df2 <- mutate(df2, people_json = to_json(df2$people)) head(select(df2, from_json(df2$people_json, schema_of_json(head(df2)$people_json)))) ``` ``` from_json(people_json) 1 Bob ``` **3. (While I'm here) Allow `structType` as schema for `from_csv` support to match with `from_json`.** Before: ```R csv <- "Amsterdam,2018" df <- sql(paste0("SELECT '", csv, "' as csv")) head(select(df, from_csv(df$csv, structType("city STRING, year INT")))) ``` ``` Error in (function (classes, fdef, mtable) : unable to find an inherited method for function ‘from_csv’ for signature ‘"Column", "structType"’ ``` After: ```R csv <- "Amsterdam,2018" df <- sql(paste0("SELECT '", csv, "' as csv")) head(select(df, from_csv(df$csv, structType("city STRING, year INT")))) ``` ``` from_csv(csv) 1 Amsterdam, 2018 ``` ## How was this patch tested? Manually tested and unittests were added. Closes #23184 from HyukjinKwon/SPARK-26227-1. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- R/pkg/R/functions.R | 60 ++++++++++++------- R/pkg/tests/fulltests/test_sparkSQL.R | 16 ++++- .../org/apache/spark/sql/api/r/SQLUtils.scala | 6 +- 3 files changed, 59 insertions(+), 23 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 5b3cc0940d9c3..58fc4104b0f08 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -202,8 +202,9 @@ NULL #' \itemize{ #' \item \code{from_json}: a structType object to use as the schema to use #' when parsing the JSON string. Since Spark 2.3, the DDL-formatted string is -#' also supported for the schema. -#' \item \code{from_csv}: a DDL-formatted string +#' also supported for the schema. Since Spark 3.0, \code{schema_of_json} or +#' the DDL-formatted string literal can also be accepted. +#' \item \code{from_csv}: a structType object, DDL-formatted string or \code{schema_of_csv} #' } #' @param ... additional argument(s). #' \itemize{ @@ -2254,6 +2255,8 @@ setMethod("date_format", signature(y = "Column", x = "character"), column(jc) }) +setClassUnion("characterOrstructTypeOrColumn", c("character", "structType", "Column")) + #' @details #' \code{from_json}: Parses a column containing a JSON string into a Column of \code{structType} #' with the specified \code{schema} or array of \code{structType} if \code{as.json.array} is set @@ -2261,7 +2264,7 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' #' @rdname column_collection_functions #' @param as.json.array indicating if input string is JSON array of objects or a single object. -#' @aliases from_json from_json,Column,characterOrstructType-method +#' @aliases from_json from_json,Column,characterOrstructTypeOrColumn-method #' @examples #' #' \dontrun{ @@ -2269,25 +2272,37 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' df2 <- mutate(df2, d2 = to_json(df2$d, dateFormat = 'dd/MM/yyyy')) #' schema <- structType(structField("date", "string")) #' head(select(df2, from_json(df2$d2, schema, dateFormat = 'dd/MM/yyyy'))) - #' df2 <- sql("SELECT named_struct('name', 'Bob') as people") #' df2 <- mutate(df2, people_json = to_json(df2$people)) #' schema <- structType(structField("name", "string")) #' head(select(df2, from_json(df2$people_json, schema))) -#' head(select(df2, from_json(df2$people_json, "name STRING")))} +#' head(select(df2, from_json(df2$people_json, "name STRING"))) +#' head(select(df2, from_json(df2$people_json, schema_of_json(head(df2)$people_json))))} #' @note from_json since 2.2.0 -setMethod("from_json", signature(x = "Column", schema = "characterOrstructType"), +setMethod("from_json", signature(x = "Column", schema = "characterOrstructTypeOrColumn"), function(x, schema, as.json.array = FALSE, ...) { if (is.character(schema)) { - schema <- structType(schema) + jschema <- structType(schema)$jobj + } else if (class(schema) == "structType") { + jschema <- schema$jobj + } else { + jschema <- schema@jc } if (as.json.array) { - jschema <- callJStatic("org.apache.spark.sql.types.DataTypes", - "createArrayType", - schema$jobj) - } else { - jschema <- schema$jobj + # This case is R-specifically different. Unlike Scala and Python side, + # R side has 'as.json.array' option to indicate if the schema should be + # treated as struct or element type of array in order to make it more + # R-friendly. + if (class(schema) == "Column") { + jschema <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "createArrayType", + jschema) + } else { + jschema <- callJStatic("org.apache.spark.sql.types.DataTypes", + "createArrayType", + jschema) + } } options <- varargsToStrEnv(...) jc <- callJStatic("org.apache.spark.sql.functions", @@ -2328,22 +2343,27 @@ setMethod("schema_of_json", signature(x = "characterOrColumn"), #' If the string is unparseable, the Column will contain the value NA. #' #' @rdname column_collection_functions -#' @aliases from_csv from_csv,Column,character-method +#' @aliases from_csv from_csv,Column,characterOrstructTypeOrColumn-method #' @examples #' #' \dontrun{ -#' df <- sql("SELECT 'Amsterdam,2018' as csv") +#' csv <- "Amsterdam,2018" +#' df <- sql(paste0("SELECT '", csv, "' as csv")) #' schema <- "city STRING, year INT" -#' head(select(df, from_csv(df$csv, schema)))} +#' head(select(df, from_csv(df$csv, schema))) +#' head(select(df, from_csv(df$csv, structType(schema)))) +#' head(select(df, from_csv(df$csv, schema_of_csv(csv))))} #' @note from_csv since 3.0.0 -setMethod("from_csv", signature(x = "Column", schema = "characterOrColumn"), +setMethod("from_csv", signature(x = "Column", schema = "characterOrstructTypeOrColumn"), function(x, schema, ...) { - if (class(schema) == "Column") { - jschema <- schema@jc - } else if (is.character(schema)) { + if (class(schema) == "structType") { + schema <- callJMethod(schema$jobj, "toDDL") + } + + if (is.character(schema)) { jschema <- callJStatic("org.apache.spark.sql.functions", "lit", schema) } else { - stop("schema argument should be a column or character") + jschema <- schema@jc } options <- varargsToStrEnv(...) jc <- callJStatic("org.apache.spark.sql.functions", diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 0d5118c127f2b..a1805f57b1dcf 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1626,6 +1626,12 @@ test_that("column functions", { expect_equal(c[[1]][[1]]$a, 1) c <- collect(select(df, alias(from_csv(df$col, lit("a INT")), "csv"))) expect_equal(c[[1]][[1]]$a, 1) + c <- collect(select(df, alias(from_csv(df$col, structType("a INT")), "csv"))) + expect_equal(c[[1]][[1]]$a, 1) + c <- collect(select(df, alias(from_csv(df$col, schema_of_csv("1")), "csv"))) + expect_equal(c[[1]][[1]]$`_c0`, 1) + c <- collect(select(df, alias(from_csv(df$col, schema_of_csv(lit("1"))), "csv"))) + expect_equal(c[[1]][[1]]$`_c0`, 1) df <- as.DataFrame(list(list("col" = "1"))) c <- collect(select(df, schema_of_csv("Amsterdam,2018"))) @@ -1651,7 +1657,9 @@ test_that("column functions", { expect_equal(j[order(j$json), ][1], "{\"age\":16,\"height\":176.5}") df <- as.DataFrame(j) schemas <- list(structType(structField("age", "integer"), structField("height", "double")), - "age INT, height DOUBLE") + "age INT, height DOUBLE", + schema_of_json("{\"age\":16,\"height\":176.5}"), + schema_of_json(lit("{\"age\":16,\"height\":176.5}"))) for (schema in schemas) { s <- collect(select(df, alias(from_json(df$json, schema), "structcol"))) expect_equal(ncol(s), 1) @@ -1691,7 +1699,11 @@ test_that("column functions", { # check if array type in string is correctly supported. jsonArr <- "[{\"name\":\"Bob\"}, {\"name\":\"Alice\"}]" df <- as.DataFrame(list(list("people" = jsonArr))) - for (schema in list(structType(structField("name", "string")), "name STRING")) { + schemas <- list(structType(structField("name", "string")), + "name STRING", + schema_of_json("{\"name\":\"Alice\"}"), + schema_of_json(lit("{\"name\":\"Bob\"}"))) + for (schema in schemas) { arr <- collect(select(df, alias(from_json(df$people, schema, as.json.array = TRUE), "arrcol"))) expect_equal(ncol(arr), 1) expect_equal(nrow(arr), 1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index e98cab8b56d13..f5d8d4ea0a4da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -30,7 +30,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema +import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericRowWithSchema} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.execution.command.ShowTablesCommand import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION @@ -233,4 +233,8 @@ private[sql] object SQLUtils extends Logging { } sparkSession.sessionState.catalog.listTables(db).map(_.table).toArray } + + def createArrayType(column: Column): ArrayType = { + new ArrayType(ExprUtils.evalTypeExpr(column.expr), true) + } } From d371180c01bf68ed4e5f88df836c7f2fb27a46d3 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 2 Jan 2019 08:04:36 +0800 Subject: [PATCH 2360/2461] [MINOR][R] Deduplicate RStudio setup documentation ## What changes were proposed in this pull request? This PR targets to deduplicate RStudio setup for SparkR. ## How was this patch tested? N/A Closes #23421 from HyukjinKwon/minor-doc. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- R/README.md | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/R/README.md b/R/README.md index d77a1ecffc99c..e238a0efe4b5e 100644 --- a/R/README.md +++ b/R/README.md @@ -39,15 +39,7 @@ To set other options like driver memory, executor memory etc. you can pass in th #### Using SparkR from RStudio -If you wish to use SparkR from RStudio or other R frontends you will need to set some environment variables which point SparkR to your Spark installation. For example -```R -# Set this to where Spark is installed -Sys.setenv(SPARK_HOME="/Users/username/spark") -# This line loads SparkR from the installed directory -.libPaths(c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib"), .libPaths())) -library(SparkR) -sparkR.session() -``` +If you wish to use SparkR from RStudio, please refer [SparkR documentation](https://spark.apache.org/docs/latest/sparkr.html#starting-up-from-rstudio). #### Making changes to SparkR From 79b05481a2cff7a0aa34146c72068cc6e41e2241 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 1 Jan 2019 22:37:28 -0600 Subject: [PATCH 2361/2461] [SPARK-26508][CORE][SQL] Address warning messages in Java reported at lgtm.com ## What changes were proposed in this pull request? This PR addresses warning messages in Java files reported at [lgtm.com](https://lgtm.com). [lgtm.com](https://lgtm.com) provides automated code review of Java/Python/JavaScript files for OSS projects. [Here](https://lgtm.com/projects/g/apache/spark/alerts/?mode=list&severity=warning) are warning messages regarding Apache Spark project. This PR addresses the following warnings: - Result of multiplication cast to wider type - Implicit narrowing conversion in compound assignment - Boxed variable is never null - Useless null check NOTE: `Potential input resource leak` looks false positive for now. ## How was this patch tested? Existing UTs Closes #23420 from kiszk/SPARK-26508. Authored-by: Kazuaki Ishizaki Signed-off-by: Sean Owen --- .../java/org/apache/spark/network/util/ByteUnit.java | 12 ++++++------ .../org/apache/spark/network/util/TransportConf.java | 6 +++--- .../org/apache/spark/unsafe/map/BytesToBytesMap.java | 4 ++-- .../main/java/org/apache/spark/examples/JavaTC.java | 2 +- .../org/apache/spark/examples/ml/JavaALSExample.java | 2 +- .../examples/mllib/JavaCorrelationsExample.java | 2 +- .../mllib/JavaRandomForestClassificationExample.java | 8 ++++---- .../org/apache/spark/launcher/LauncherServer.java | 7 +++---- .../expressions/codegen/UnsafeArrayWriter.java | 2 +- .../expressions/codegen/UnsafeRowWriter.java | 2 +- .../sql/catalyst/expressions/xml/UDFXPathUtil.java | 2 +- 11 files changed, 24 insertions(+), 25 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java index 984575acaf511..6f7925c26094d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java @@ -18,11 +18,11 @@ public enum ByteUnit { BYTE(1), - KiB(1024L), - MiB((long) Math.pow(1024L, 2L)), - GiB((long) Math.pow(1024L, 3L)), - TiB((long) Math.pow(1024L, 4L)), - PiB((long) Math.pow(1024L, 5L)); + KiB(1L << 10), + MiB(1L << 20), + GiB(1L << 30), + TiB(1L << 40), + PiB(1L << 50); ByteUnit(long multiplier) { this.multiplier = multiplier; @@ -50,7 +50,7 @@ public long convertTo(long d, ByteUnit u) { } } - public double toBytes(long d) { + public long toBytes(long d) { if (d < 0) { throw new IllegalArgumentException("Negative size value. Size must be positive: " + d); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 43a6bc7dc3d06..201628b04fbef 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -309,8 +309,8 @@ public int chunkFetchHandlerThreads() { } int chunkFetchHandlerThreadsPercent = conf.getInt("spark.shuffle.server.chunkFetchHandlerThreadsPercent", 100); - return (int)Math.ceil( - (this.serverThreads() > 0 ? this.serverThreads() : 2 * NettyRuntime.availableProcessors()) * - chunkFetchHandlerThreadsPercent/(double)100); + int threads = + this.serverThreads() > 0 ? this.serverThreads() : 2 * NettyRuntime.availableProcessors(); + return (int) Math.ceil(threads * (chunkFetchHandlerThreadsPercent / 100.0)); } } diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 7df8aafb2b674..2ff98a69ee1f4 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -712,7 +712,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff final long recordOffset = offset; UnsafeAlignedOffset.putSize(base, offset, klen + vlen + uaoSize); UnsafeAlignedOffset.putSize(base, offset + uaoSize, klen); - offset += (2 * uaoSize); + offset += (2L * uaoSize); Platform.copyMemory(kbase, koff, base, offset, klen); offset += klen; Platform.copyMemory(vbase, voff, base, offset, vlen); @@ -780,7 +780,7 @@ private void allocate(int capacity) { assert (capacity >= 0); capacity = Math.max((int) Math.min(MAX_CAPACITY, ByteArrayMethods.nextPowerOf2(capacity)), 64); assert (capacity <= MAX_CAPACITY); - longArray = allocateArray(capacity * 2); + longArray = allocateArray(capacity * 2L); longArray.zeroOut(); this.growthThreshold = (int) (capacity * loadFactor); diff --git a/examples/src/main/java/org/apache/spark/examples/JavaTC.java b/examples/src/main/java/org/apache/spark/examples/JavaTC.java index c9ca9c9b3a412..7e8df69e7e8da 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaTC.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaTC.java @@ -71,7 +71,7 @@ public static void main(String[] args) { JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext()); - Integer slices = (args.length > 0) ? Integer.parseInt(args[0]): 2; + int slices = (args.length > 0) ? Integer.parseInt(args[0]): 2; JavaPairRDD tc = jsc.parallelizePairs(generateGraph(), slices).cache(); // Linear transitive closure: each round grows paths by one edge, diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java index 27052be87b82e..b8d2c9f6a6584 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java @@ -111,7 +111,7 @@ public static void main(String[] args) { .setMetricName("rmse") .setLabelCol("rating") .setPredictionCol("prediction"); - Double rmse = evaluator.evaluate(predictions); + double rmse = evaluator.evaluate(predictions); System.out.println("Root-mean-square error = " + rmse); // Generate top 10 movie recommendations for each user diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaCorrelationsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaCorrelationsExample.java index c0fa0b3cac1e9..9bd858b598905 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaCorrelationsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaCorrelationsExample.java @@ -46,7 +46,7 @@ public static void main(String[] args) { // compute the correlation using Pearson's method. Enter "spearman" for Spearman's method. // If a method is not specified, Pearson's method will be used by default. - Double correlation = Statistics.corr(seriesX.srdd(), seriesY.srdd(), "pearson"); + double correlation = Statistics.corr(seriesX.srdd(), seriesY.srdd(), "pearson"); System.out.println("Correlation is: " + correlation); // note that each Vector is a row and not a column diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestClassificationExample.java index 6998ce2156c25..0707db8d3e839 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestClassificationExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestClassificationExample.java @@ -48,14 +48,14 @@ public static void main(String[] args) { // Train a RandomForest model. // Empty categoricalFeaturesInfo indicates all features are continuous. - Integer numClasses = 2; + int numClasses = 2; Map categoricalFeaturesInfo = new HashMap<>(); Integer numTrees = 3; // Use more in practice. String featureSubsetStrategy = "auto"; // Let the algorithm choose. String impurity = "gini"; - Integer maxDepth = 5; - Integer maxBins = 32; - Integer seed = 12345; + int maxDepth = 5; + int maxBins = 32; + int seed = 12345; RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index 607879fd02ea9..3ff77878f68a8 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -318,9 +318,9 @@ protected void handle(Message msg) throws IOException { throw new IllegalArgumentException("Received Hello for unknown client."); } } else { + String msgClassName = msg != null ? msg.getClass().getName() : "no message"; if (handle == null) { - throw new IllegalArgumentException("Expected hello, got: " + - msg != null ? msg.getClass().getName() : null); + throw new IllegalArgumentException("Expected hello, got: " + msgClassName); } if (msg instanceof SetAppId) { SetAppId set = (SetAppId) msg; @@ -328,8 +328,7 @@ protected void handle(Message msg) throws IOException { } else if (msg instanceof SetState) { handle.setState(((SetState)msg).state); } else { - throw new IllegalArgumentException("Invalid message: " + - msg != null ? msg.getClass().getName() : null); + throw new IllegalArgumentException("Invalid message: " + msgClassName); } } } catch (Exception e) { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index a78dd970d23e4..997eecd839d85 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -74,7 +74,7 @@ public void initialize(int numElements) { } private long getElementOffset(int ordinal) { - return startingOffset + headerInBytes + ordinal * elementSize; + return startingOffset + headerInBytes + ordinal * (long) elementSize; } private void setNullBit(int ordinal) { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 3960d6d520476..d2298aa263646 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -132,7 +132,7 @@ public void setNull8Bytes(int ordinal) { } public long getFieldOffset(int ordinal) { - return startingOffset + nullBitsSize + 8 * ordinal; + return startingOffset + nullBitsSize + 8L * ordinal; } public void write(int ordinal, boolean value) { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java index 023ec139652c5..e9f18229b54c2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java @@ -180,7 +180,7 @@ public long skip(long ns) throws IOException { return 0; } // Bound skip by beginning and end of the source - long n = Math.min(length - next, ns); + int n = (int) Math.min(length - next, ns); n = Math.max(-next, n); next += n; return n; From 4bdfda92a1c570d7a1142ee30eb41e37661bc240 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 2 Jan 2019 11:23:53 -0600 Subject: [PATCH 2362/2461] [SPARK-26507][CORE] Fix core tests for Java 11 ## What changes were proposed in this pull request? This should make tests in core modules pass for Java 11. ## How was this patch tested? Existing tests, with modifications. Closes #23419 from srowen/Java11. Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../main/scala/org/apache/spark/util/Utils.scala | 10 +++++++--- .../metrics/source/AccumulatorSourceSuite.scala | 9 ++++----- .../apache/spark/util/JsonProtocolSuite.scala | 16 +++++++++++----- .../scala/org/apache/spark/util/UtilsSuite.scala | 16 ---------------- .../launcher/SparkSubmitCommandBuilderSuite.java | 2 +- pom.xml | 2 ++ 6 files changed, 25 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index f322e92c6c8cb..22f074cf98971 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2740,13 +2740,17 @@ private[spark] object Utils extends Logging { /** * Safer than Class obj's getSimpleName which may throw Malformed class name error in scala. - * This method mimicks scalatest's getSimpleNameOfAnObjectsClass. + * This method mimics scalatest's getSimpleNameOfAnObjectsClass. */ def getSimpleName(cls: Class[_]): String = { try { - return cls.getSimpleName + cls.getSimpleName } catch { - case err: InternalError => return stripDollars(stripPackages(cls.getName)) + // TODO: the value returned here isn't even quite right; it returns simple names + // like UtilsSuite$MalformedClassObject$MalformedClass instead of MalformedClass + // The exact value may not matter much as it's used in log statements + case _: InternalError => + stripDollars(stripPackages(cls.getName)) } } diff --git a/core/src/test/scala/org/apache/spark/metrics/source/AccumulatorSourceSuite.scala b/core/src/test/scala/org/apache/spark/metrics/source/AccumulatorSourceSuite.scala index 6a6c07cb068cc..45e6e0b4913ed 100644 --- a/core/src/test/scala/org/apache/spark/metrics/source/AccumulatorSourceSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/source/AccumulatorSourceSuite.scala @@ -17,9 +17,8 @@ package org.apache.spark.metrics.source -import com.codahale.metrics.MetricRegistry import org.mockito.ArgumentCaptor -import org.mockito.Mockito.{mock, never, spy, times, verify, when} +import org.mockito.Mockito.{mock, times, verify, when} import org.apache.spark.{SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.metrics.MetricsSystem @@ -37,7 +36,7 @@ class AccumulatorSourceSuite extends SparkFunSuite { val accs = Map("my-accumulator-1" -> acc1, "my-accumulator-2" -> acc2) LongAccumulatorSource.register(mockContext, accs) - val captor = new ArgumentCaptor[AccumulatorSource]() + val captor = ArgumentCaptor.forClass(classOf[AccumulatorSource]) verify(mockMetricSystem, times(1)).registerSource(captor.capture()) val source = captor.getValue() val gauges = source.metricRegistry.getGauges() @@ -59,7 +58,7 @@ class AccumulatorSourceSuite extends SparkFunSuite { val accs = Map("my-accumulator-1" -> acc1, "my-accumulator-2" -> acc2) LongAccumulatorSource.register(mockContext, accs) - val captor = new ArgumentCaptor[AccumulatorSource]() + val captor = ArgumentCaptor.forClass(classOf[AccumulatorSource]) verify(mockMetricSystem, times(1)).registerSource(captor.capture()) val source = captor.getValue() val gauges = source.metricRegistry.getGauges() @@ -81,7 +80,7 @@ class AccumulatorSourceSuite extends SparkFunSuite { "my-accumulator-1" -> acc1, "my-accumulator-2" -> acc2) DoubleAccumulatorSource.register(mockContext, accs) - val captor = new ArgumentCaptor[AccumulatorSource]() + val captor = ArgumentCaptor.forClass(classOf[AccumulatorSource]) verify(mockMetricSystem, times(1)).registerSource(captor.capture()) val source = captor.getValue() val gauges = source.metricRegistry.getGauges() diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 303ca7cb8801a..b88f25726fc41 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -761,13 +761,13 @@ private[spark] object JsonProtocolSuite extends Assertions { } private def assertJsonStringEquals(expected: String, actual: String, metadata: String) { - val expectedJson = pretty(parse(expected)) - val actualJson = pretty(parse(actual)) + val expectedJson = parse(expected) + val actualJson = parse(actual) if (expectedJson != actualJson) { // scalastyle:off // This prints something useful if the JSON strings don't match - println("=== EXPECTED ===\n" + expectedJson + "\n") - println("=== ACTUAL ===\n" + actualJson + "\n") + println(s"=== EXPECTED ===\n${pretty(expectedJson)}\n") + println(s"=== ACTUAL ===\n${pretty(actualJson)}\n") // scalastyle:on throw new TestFailedException(s"$metadata JSON did not equal", 1) } @@ -807,7 +807,13 @@ private[spark] object JsonProtocolSuite extends Assertions { } private def assertStackTraceElementEquals(ste1: StackTraceElement, ste2: StackTraceElement) { - assert(ste1 === ste2) + // This mimics the equals() method from Java 8 and earlier. Java 9 adds checks for + // class loader and module, which will cause them to be not equal, when we don't + // care about those + assert(ste1.getClassName === ste2.getClassName) + assert(ste1.getMethodName === ste2.getMethodName) + assert(ste1.getLineNumber === ste2.getLineNumber) + assert(ste1.getFileName === ste2.getFileName) } /** ----------------------------------- * diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index b2ff1cce3eb0b..d3f94fbe05d72 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -1156,22 +1156,6 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { } } - object MalformedClassObject { - class MalformedClass - } - - test("Safe getSimpleName") { - // getSimpleName on class of MalformedClass will result in error: Malformed class name - // Utils.getSimpleName works - val err = intercept[java.lang.InternalError] { - classOf[MalformedClassObject.MalformedClass].getSimpleName - } - assert(err.getMessage === "Malformed class name") - - assert(Utils.getSimpleName(classOf[MalformedClassObject.MalformedClass]) === - "UtilsSuite$MalformedClassObject$MalformedClass") - } - test("stringHalfWidth") { // scalastyle:off nonascii assert(Utils.stringHalfWidth(null) == 0) diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index b343094b2e7b8..e694e9066f12e 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -158,7 +158,7 @@ public void testPySparkLauncher() throws Exception { Map env = new HashMap<>(); List cmd = buildCommand(sparkSubmitArgs, env); - assertEquals("python", cmd.get(cmd.size() - 1)); + assertTrue(Arrays.asList("python", "python2", "python3").contains(cmd.get(cmd.size() - 1))); assertEquals( String.format("\"%s\" \"foo\" \"%s\" \"bar\" \"%s\"", parser.MASTER, parser.DEPLOY_MODE, SparkSubmitCommandBuilder.PYSPARK_SHELL_RESOURCE), diff --git a/pom.xml b/pom.xml index 321de209a56a1..a433659cd2002 100644 --- a/pom.xml +++ b/pom.xml @@ -2060,6 +2060,8 @@ ${scala.version} + true + true incremental true From d40654861b1a051d4cfc4ec0d8e80f817e6b8e8b Mon Sep 17 00:00:00 2001 From: seancxmao Date: Wed, 2 Jan 2019 15:45:14 -0600 Subject: [PATCH 2363/2461] [SPARK-26277][SQL][TEST] WholeStageCodegen metrics should be tested with whole-stage codegen enabled ## What changes were proposed in this pull request? In `org.apache.spark.sql.execution.metric.SQLMetricsSuite`, there's a test case named "WholeStageCodegen metrics". However, it is executed with whole-stage codegen disabled. This PR fixes this by enable whole-stage codegen for this test case. ## How was this patch tested? Tested locally using exiting test cases. Closes #23224 from seancxmao/codegen-metrics. Authored-by: seancxmao Signed-off-by: Sean Owen --- .../spark/sql/execution/metric/SQLMetricsSuite.scala | 11 ++++++++--- .../sql/execution/metric/SQLMetricsTestUtils.scala | 7 +++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 7368a6c9e1d64..6174ec4c8908c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -77,11 +77,16 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared } test("WholeStageCodegen metrics") { - // Assume the execution plan is - // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Filter(nodeId = 1)) + // Assume the execution plan with node id is + // WholeStageCodegen(nodeId = 0) + // Filter(nodeId = 1) + // Range(nodeId = 2) // TODO: update metrics in generated operators val ds = spark.range(10).filter('id < 5) - testSparkPlanMetrics(ds.toDF(), 1, Map.empty) + testSparkPlanMetricsWithPredicates(ds.toDF(), 1, Map( + 0L -> (("WholeStageCodegen", Map( + "duration total (min, med, max)" -> {_.toString.matches(timingMetricPattern)}))) + ), true) } test("Aggregate metrics") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 2d245d2ba1e35..0e13f7dd55bae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -144,6 +144,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils { * @param df `DataFrame` to run * @param expectedNumOfJobs number of jobs that will run * @param expectedNodeIds the node ids of the metrics to collect from execution data. + * @param enableWholeStage enable whole-stage code generation or not. */ protected def getSparkPlanMetrics( df: DataFrame, @@ -210,13 +211,15 @@ trait SQLMetricsTestUtils extends SQLTestUtils { * @param expectedNumOfJobs number of jobs that will run * @param expectedMetricsPredicates the expected metrics predicates. The format is * `nodeId -> (operatorName, metric name -> metric predicate)`. + * @param enableWholeStage enable whole-stage code generation or not. */ protected def testSparkPlanMetricsWithPredicates( df: DataFrame, expectedNumOfJobs: Int, - expectedMetricsPredicates: Map[Long, (String, Map[String, Any => Boolean])]): Unit = { + expectedMetricsPredicates: Map[Long, (String, Map[String, Any => Boolean])], + enableWholeStage: Boolean = false): Unit = { val optActualMetrics = - getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetricsPredicates.keySet) + getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetricsPredicates.keySet, enableWholeStage) optActualMetrics.foreach { actualMetrics => assert(expectedMetricsPredicates.keySet === actualMetrics.keySet) for ((nodeId, (expectedNodeName, expectedMetricsPredicatesMap)) From 8be4d24a27a1e9995a53d4efb3a13a47813d1f77 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 2 Jan 2019 16:57:10 -0800 Subject: [PATCH 2364/2461] [SPARK-26023][SQL][FOLLOWUP] Dumping truncated plans and generated code to a file ## What changes were proposed in this pull request? `DataSourceScanExec` overrides "wrong" `treeString` method without `append`. In the PR, I propose to make `treeString`s **final** to prevent such mistakes in the future. And removed the `treeString` and `verboseString` since they both use `simpleString` with reduction. ## How was this patch tested? It was tested by `DataSourceScanExecRedactionSuite` Closes #23431 from MaxGekk/datasource-scan-exec-followup. Authored-by: Maxim Gekk Signed-off-by: gatorsmile --- .../org/apache/spark/sql/catalyst/trees/TreeNode.scala | 4 ++-- .../apache/spark/sql/execution/DataSourceScanExec.scala | 9 ++------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 570a019b2af77..d214ebb309031 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -474,9 +474,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { override def toString: String = treeString /** Returns a string representation of the nodes in this tree */ - def treeString: String = treeString(verbose = true) + final def treeString: String = treeString(verbose = true) - def treeString( + final def treeString( verbose: Boolean, addSuffix: Boolean = false, maxFields: Int = SQLConf.get.maxToStringFields): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 1d7dd73706c48..8b84eda361038 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -58,13 +58,8 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport { key + ": " + StringUtils.abbreviate(redact(value), 100) } val metadataStr = truncatedString(metadataEntries, " ", ", ", "", maxFields) - s"$nodeNamePrefix$nodeName${truncatedString(output, "[", ",", "]", maxFields)}$metadataStr" - } - - override def verboseString(maxFields: Int): String = redact(super.verboseString(maxFields)) - - override def treeString(verbose: Boolean, addSuffix: Boolean, maxFields: Int): String = { - redact(super.treeString(verbose, addSuffix, maxFields)) + redact( + s"$nodeNamePrefix$nodeName${truncatedString(output, "[", ",", "]", maxFields)}$metadataStr") } /** From 56967b7e288ac54e705b14a21516df5402d4c9d9 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 3 Jan 2019 11:01:54 +0800 Subject: [PATCH 2365/2461] [SPARK-26403][SQL] Support pivoting using array column for `pivot(column)` API ## What changes were proposed in this pull request? This PR fixes `pivot(Column)` can accepts `collection.mutable.WrappedArray`. Note that we return `collection.mutable.WrappedArray` from `ArrayType`, and `Literal.apply` doesn't support this. We can unwrap the array and use it for type dispatch. ```scala val df = Seq( (2, Seq.empty[String]), (2, Seq("a", "x")), (3, Seq.empty[String]), (3, Seq("a", "x"))).toDF("x", "s") df.groupBy("x").pivot("s").count().show() ``` Before: ``` Unsupported literal type class scala.collection.mutable.WrappedArray$ofRef WrappedArray() java.lang.RuntimeException: Unsupported literal type class scala.collection.mutable.WrappedArray$ofRef WrappedArray() at org.apache.spark.sql.catalyst.expressions.Literal$.apply(literals.scala:80) at org.apache.spark.sql.RelationalGroupedDataset.$anonfun$pivot$2(RelationalGroupedDataset.scala:427) at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:237) at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36) at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33) at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:39) at scala.collection.TraversableLike.map(TraversableLike.scala:237) at scala.collection.TraversableLike.map$(TraversableLike.scala:230) at scala.collection.AbstractTraversable.map(Traversable.scala:108) at org.apache.spark.sql.RelationalGroupedDataset.pivot(RelationalGroupedDataset.scala:425) at org.apache.spark.sql.RelationalGroupedDataset.pivot(RelationalGroupedDataset.scala:406) at org.apache.spark.sql.RelationalGroupedDataset.pivot(RelationalGroupedDataset.scala:317) at org.apache.spark.sql.DataFramePivotSuite.$anonfun$new$1(DataFramePivotSuite.scala:341) at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23) ``` After: ``` +---+---+------+ | x| []|[a, x]| +---+---+------+ | 3| 1| 1| | 2| 1| 1| +---+---+------+ ``` ## How was this patch tested? Manually tested and unittests were added. Closes #23349 from HyukjinKwon/SPARK-26403. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../spark/sql/catalyst/expressions/literals.scala | 1 + .../catalyst/expressions/LiteralExpressionSuite.scala | 2 ++ .../org/apache/spark/sql/DataFramePivotSuite.scala | 11 +++++++++++ 3 files changed, 14 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 34d252886ffb0..48beffa18a551 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -67,6 +67,7 @@ object Literal { case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType) case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType) case a: Array[Byte] => Literal(a, BinaryType) + case a: collection.mutable.WrappedArray[_] => apply(a.array) case a: Array[_] => val elementType = componentTypeToDataType(a.getClass.getComponentType()) val dataType = ArrayType(elementType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index 3ea6bfac9ddca..133aaa449ea44 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -179,6 +179,8 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkArrayLiteral(Array("a", "b", "c")) checkArrayLiteral(Array(1.0, 4.0)) checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR)) + val arr = collection.mutable.WrappedArray.make(Array(1.0, 4.0)) + checkEvaluation(Literal(arr), toCatalyst(arr)) } test("seq") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index b52ca58c07d27..8c2c11be9b6fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -333,4 +333,15 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { } assert(exception.getMessage.contains("Unsupported literal type")) } + + test("SPARK-26403: pivoting by array column") { + val df = Seq( + (2, Seq.empty[String]), + (2, Seq("a", "x")), + (3, Seq.empty[String]), + (3, Seq("a", "x"))).toDF("x", "s") + val expected = Seq((3, 1, 1), (2, 1, 1)).toDF + val actual = df.groupBy("x").pivot("s").count() + checkAnswer(actual, expected) + } } From 2a30deb85ae4e42c5cbc936383dd5c3970f4a74f Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 3 Jan 2019 11:27:40 +0100 Subject: [PATCH 2366/2461] [SPARK-26502][SQL] Move hiveResultString() from QueryExecution to HiveResult ## What changes were proposed in this pull request? In the PR, I propose to move `hiveResultString()` out of `QueryExecution` and put it to a separate object. Closes #23409 from MaxGekk/hive-result-string. Lead-authored-by: Maxim Gekk Co-authored-by: Maxim Gekk Signed-off-by: Herman van Hovell --- .../spark/sql/execution/HiveResult.scala | 116 ++++++++++++++++++ .../spark/sql/execution/QueryExecution.scala | 91 +------------- .../apache/spark/sql/SQLQueryTestSuite.scala | 5 +- .../hive/thriftserver/SparkSQLDriver.scala | 3 +- .../apache/spark/sql/hive/test/TestHive.scala | 2 +- .../hive/execution/HiveComparisonTest.scala | 4 +- 6 files changed, 126 insertions(+), 95 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala new file mode 100644 index 0000000000000..22d3ca958a210 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.nio.charset.StandardCharsets +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +/** + * Runs a query returning the result in Hive compatible form. + */ +object HiveResult { + /** + * Returns the result as a hive compatible sequence of strings. This is used in tests and + * `SparkSQLDriver` for CLI applications. + */ + def hiveResultString(executedPlan: SparkPlan): Seq[String] = executedPlan match { + case ExecutedCommandExec(desc: DescribeTableCommand) => + // If it is a describe command for a Hive table, we want to have the output format + // be similar with Hive. + executedPlan.executeCollectPublic().map { + case Row(name: String, dataType: String, comment) => + Seq(name, dataType, + Option(comment.asInstanceOf[String]).getOrElse("")) + .map(s => String.format(s"%-20s", s)) + .mkString("\t") + } + // SHOW TABLES in Hive only output table names, while ours output database, table name, isTemp. + case command @ ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended => + command.executeCollect().map(_.getString(1)) + case other => + val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq + // We need the types so we can output struct field names + val types = executedPlan.output.map(_.dataType) + // Reformat to match hive tab delimited output. + result.map(_.zip(types).map(toHiveString)).map(_.mkString("\t")) + } + + /** Formats a datum (based on the given data type) and returns the string representation. */ + private def toHiveString(a: (Any, DataType)): String = { + val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, + BooleanType, ByteType, ShortType, DateType, TimestampType, BinaryType) + val timeZone = DateTimeUtils.getTimeZone(SQLConf.get.sessionLocalTimeZone) + + def formatDecimal(d: java.math.BigDecimal): String = { + if (d.compareTo(java.math.BigDecimal.ZERO) == 0) { + java.math.BigDecimal.ZERO.toPlainString + } else { + d.stripTrailingZeros().toPlainString + } + } + + /** Hive outputs fields of structs slightly differently than top level attributes. */ + def toHiveStructString(a: (Any, DataType)): String = a match { + case (struct: Row, StructType(fields)) => + struct.toSeq.zip(fields).map { + case (v, t) => s""""${t.name}":${toHiveStructString((v, t.dataType))}""" + }.mkString("{", ",", "}") + case (seq: Seq[_], ArrayType(typ, _)) => + seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") + case (map: Map[_, _], MapType(kType, vType, _)) => + map.map { + case (key, value) => + toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) + }.toSeq.sorted.mkString("{", ",", "}") + case (null, _) => "null" + case (s: String, StringType) => "\"" + s + "\"" + case (decimal, DecimalType()) => decimal.toString + case (interval, CalendarIntervalType) => interval.toString + case (other, tpe) if primitiveTypes contains tpe => other.toString + } + + a match { + case (struct: Row, StructType(fields)) => + struct.toSeq.zip(fields).map { + case (v, t) => s""""${t.name}":${toHiveStructString((v, t.dataType))}""" + }.mkString("{", ",", "}") + case (seq: Seq[_], ArrayType(typ, _)) => + seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") + case (map: Map[_, _], MapType(kType, vType, _)) => + map.map { + case (key, value) => + toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) + }.toSeq.sorted.mkString("{", ",", "}") + case (null, _) => "NULL" + case (d: Date, DateType) => + DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d)) + case (t: Timestamp, TimestampType) => + DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(t), timeZone) + case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8) + case (decimal: java.math.BigDecimal, DecimalType()) => formatDecimal(decimal) + case (interval, CalendarIntervalType) => interval.toString + case (other, tpe) if primitiveTypes.contains(tpe) => other.toString + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 7fccbf65d8525..72499aa936a56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -18,25 +18,20 @@ package org.apache.spark.sql.execution import java.io.{BufferedWriter, OutputStreamWriter} -import java.nio.charset.StandardCharsets -import java.sql.{Date, Timestamp} import org.apache.hadoop.fs.Path import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat import org.apache.spark.sql.catalyst.util.truncatedString -import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _} import org.apache.spark.util.Utils /** @@ -109,90 +104,6 @@ class QueryExecution( ReuseExchange(sparkSession.sessionState.conf), ReuseSubquery(sparkSession.sessionState.conf)) - /** - * Returns the result as a hive compatible sequence of strings. This is used in tests and - * `SparkSQLDriver` for CLI applications. - */ - def hiveResultString(): Seq[String] = executedPlan match { - case ExecutedCommandExec(desc: DescribeTableCommand) => - // If it is a describe command for a Hive table, we want to have the output format - // be similar with Hive. - desc.run(sparkSession).map { - case Row(name: String, dataType: String, comment) => - Seq(name, dataType, - Option(comment.asInstanceOf[String]).getOrElse("")) - .map(s => String.format(s"%-20s", s)) - .mkString("\t") - } - // SHOW TABLES in Hive only output table names, while ours output database, table name, isTemp. - case command @ ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended => - command.executeCollect().map(_.getString(1)) - case other => - val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq - // We need the types so we can output struct field names - val types = analyzed.output.map(_.dataType) - // Reformat to match hive tab delimited output. - result.map(_.zip(types).map(toHiveString)).map(_.mkString("\t")) - } - - /** Formats a datum (based on the given data type) and returns the string representation. */ - private def toHiveString(a: (Any, DataType)): String = { - val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, - BooleanType, ByteType, ShortType, DateType, TimestampType, BinaryType) - - def formatDecimal(d: java.math.BigDecimal): String = { - if (d.compareTo(java.math.BigDecimal.ZERO) == 0) { - java.math.BigDecimal.ZERO.toPlainString - } else { - d.stripTrailingZeros().toPlainString - } - } - - /** Hive outputs fields of structs slightly differently than top level attributes. */ - def toHiveStructString(a: (Any, DataType)): String = a match { - case (struct: Row, StructType(fields)) => - struct.toSeq.zip(fields).map { - case (v, t) => s""""${t.name}":${toHiveStructString((v, t.dataType))}""" - }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ, _)) => - seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_, _], MapType(kType, vType, _)) => - map.map { - case (key, value) => - toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) - }.toSeq.sorted.mkString("{", ",", "}") - case (null, _) => "null" - case (s: String, StringType) => "\"" + s + "\"" - case (decimal, DecimalType()) => decimal.toString - case (interval, CalendarIntervalType) => interval.toString - case (other, tpe) if primitiveTypes contains tpe => other.toString - } - - a match { - case (struct: Row, StructType(fields)) => - struct.toSeq.zip(fields).map { - case (v, t) => s""""${t.name}":${toHiveStructString((v, t.dataType))}""" - }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ, _)) => - seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_, _], MapType(kType, vType, _)) => - map.map { - case (key, value) => - toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) - }.toSeq.sorted.mkString("{", ",", "}") - case (null, _) => "NULL" - case (d: Date, DateType) => - DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d)) - case (t: Timestamp, TimestampType) => - DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(t), - DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone)) - case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8) - case (decimal: java.math.BigDecimal, DecimalType()) => formatDecimal(decimal) - case (interval, CalendarIntervalType) => interval.toString - case (other, tpe) if primitiveTypes.contains(tpe) => other.toString - } - } - def simpleString: String = withRedaction { val concat = new StringConcat() concat.append("== Physical Plan ==\n") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index b2515226d9a14..24b312348bd67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -22,11 +22,11 @@ import java.util.{Locale, TimeZone} import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile} +import org.apache.spark.sql.execution.HiveResult.hiveResultString import org.apache.spark.sql.execution.command.{DescribeColumnCommand, DescribeTableCommand} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -287,7 +287,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { val schema = df.schema val notIncludedMsg = "[not included in comparison]" // Get answer, but also get rid of the #1234 expression ids that show up in explain plans - val answer = df.queryExecution.hiveResultString().map(_.replaceAll("#\\d+", "#x") + val answer = hiveResultString(df.queryExecution.executedPlan) + .map(_.replaceAll("#\\d+", "#x") .replaceAll("Location.*/sql/core/", s"Location ${notIncludedMsg}sql/core/") .replaceAll("Created By.*", s"Created By $notIncludedMsg") .replaceAll("Created Time.*", s"Created Time $notIncludedMsg") diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 6775902173444..960fdd11db15d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -29,6 +29,7 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SQLContext} import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} +import org.apache.spark.sql.execution.HiveResult.hiveResultString private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlContext) @@ -61,7 +62,7 @@ private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlCont context.sparkContext.setJobDescription(command) val execution = context.sessionState.executePlan(context.sql(command).logicalPlan) hiveResponse = SQLExecution.withNewExecutionId(context.sparkSession, execution) { - execution.hiveResultString() + hiveResultString(execution.executedPlan) } tableSchema = getResultSetSchema(execution) new CommandProcessorResponse(0) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 3508affda241a..4c2bc62b9faf8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -297,7 +297,7 @@ private[hive] class TestHiveSparkSession( protected[hive] implicit class SqlCmd(sql: String) { def cmd: () => Unit = { - () => new TestHiveQueryExecution(sql).hiveResultString(): Unit + () => new TestHiveQueryExecution(sql).executedPlan.executeCollect(): Unit } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 272e6f51f5002..66426824573c6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.Dataset import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.HiveResult.hiveResultString import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution} @@ -345,7 +346,8 @@ abstract class HiveComparisonTest val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => val query = new TestHiveQueryExecution(queryString.replace("../../data", testDataPath)) def getResult(): Seq[String] = { - SQLExecution.withNewExecutionId(query.sparkSession, query)(query.hiveResultString()) + SQLExecution.withNewExecutionId( + query.sparkSession, query)(hiveResultString(query.executedPlan)) } try { (query, prepareAnswer(query, getResult())) } catch { case e: Throwable => From 88b074f3f06ddd236d63e8bf31edebe1d3e94fe4 Mon Sep 17 00:00:00 2001 From: Liupengcheng Date: Thu, 3 Jan 2019 10:26:14 -0600 Subject: [PATCH 2367/2461] [SPARK-26501][CORE][TEST] Fix unexpected overriden of exitFn in SparkSubmitSuite ## What changes were proposed in this pull request? The overriden of SparkSubmit's exitFn at some previous tests in SparkSubmitSuite may cause the following tests pass even they failed when they were run separately. This PR is to fix this problem. ## How was this patch tested? unittest Closes #23404 from liupc/Fix-SparkSubmitSuite-exitFn. Authored-by: Liupengcheng Signed-off-by: Sean Owen --- .../spark/deploy/SparkSubmitSuite.scala | 40 ++++++++++--------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index a8973d1b60f89..2a7a55cbb9039 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -72,27 +72,31 @@ trait TestPrematureExit { mainObject.printStream = printStream @volatile var exitedCleanly = false + val original = mainObject.exitFn mainObject.exitFn = (_) => exitedCleanly = true - - @volatile var exception: Exception = null - val thread = new Thread { - override def run() = try { - mainObject.main(input) - } catch { - // Capture the exception to check whether the exception contains searchString or not - case e: Exception => exception = e + try { + @volatile var exception: Exception = null + val thread = new Thread { + override def run() = try { + mainObject.main(input) + } catch { + // Capture the exception to check whether the exception contains searchString or not + case e: Exception => exception = e + } } - } - thread.start() - thread.join() - if (exitedCleanly) { - val joined = printStream.lineBuffer.mkString("\n") - assert(joined.contains(searchString)) - } else { - assert(exception != null) - if (!exception.getMessage.contains(searchString)) { - throw exception + thread.start() + thread.join() + if (exitedCleanly) { + val joined = printStream.lineBuffer.mkString("\n") + assert(joined.contains(searchString)) + } else { + assert(exception != null) + if (!exception.getMessage.contains(searchString)) { + throw exception + } } + } finally { + mainObject.exitFn = original } } } From 40711eef168716c44b873359e17822fe6b3387f4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 3 Jan 2019 10:30:47 -0600 Subject: [PATCH 2368/2461] [SPARK-26517][SQL][TEST] Avoid duplicate test in ParquetSchemaPruningSuite ## What changes were proposed in this pull request? `testExactCaseQueryPruning` and `testMixedCaseQueryPruning` don't need to set up `PARQUET_VECTORIZED_READER_ENABLED` config. Because `withMixedCaseData` will run against both Spark vectorized reader and Parquet-mr reader. ## How was this patch tested? Existing test. Closes #23427 from viirya/fix-parquet-schema-pruning-test. Authored-by: Liang-Chi Hsieh Signed-off-by: Sean Owen --- .../parquet/ParquetSchemaPruningSuite.scala | 23 ++++--------------- 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala index 434c4414edeba..9a02529a25507 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala @@ -312,15 +312,8 @@ class ParquetSchemaPruningSuite // schema's column and field names. N.B. this implies that `testThunk` should pass using either a // case-sensitive or case-insensitive query parser private def testExactCaseQueryPruning(testName: String)(testThunk: => Unit) { - test(s"Spark vectorized reader - case-sensitive parser - mixed-case schema - $testName") { - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true", - SQLConf.CASE_SENSITIVE.key -> "true") { - withMixedCaseData(testThunk) - } - } - test(s"Parquet-mr reader - case-sensitive parser - mixed-case schema - $testName") { - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false", - SQLConf.CASE_SENSITIVE.key -> "true") { + test(s"Case-sensitive parser - mixed-case schema - $testName") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { withMixedCaseData(testThunk) } } @@ -330,20 +323,14 @@ class ParquetSchemaPruningSuite // Tests schema pruning for a query whose column and field names may differ in case from the table // schema's column and field names private def testMixedCaseQueryPruning(testName: String)(testThunk: => Unit) { - test(s"Spark vectorized reader - case-insensitive parser - mixed-case schema - $testName") { - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true", - SQLConf.CASE_SENSITIVE.key -> "false") { - withMixedCaseData(testThunk) - } - } - test(s"Parquet-mr reader - case-insensitive parser - mixed-case schema - $testName") { - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false", - SQLConf.CASE_SENSITIVE.key -> "false") { + test(s"Case-insensitive parser - mixed-case schema - $testName") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { withMixedCaseData(testThunk) } } } + // Tests given test function with Spark vectorized reader and Parquet-mr reader. private def withMixedCaseData(testThunk: => Unit) { withParquetTable(mixedCaseData, "mixedcase") { testThunk From e2dbafdbc5e50fcf2554bf51939ce0cd363d8806 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 4 Jan 2019 00:37:03 +0800 Subject: [PATCH 2369/2461] [SPARK-26447][SQL] Allow OrcColumnarBatchReader to return less partition columns ## What changes were proposed in this pull request? Currently OrcColumnarBatchReader returns all the partition column values in the batch read. In data source V2, we can improve it by returning the required partition column values only. This PR is part of https://github.com/apache/spark/pull/23383 . As cloud-fan suggested, create a new PR to make review easier. Also, this PR doesn't improve `OrcFileFormat`, since in the method `buildReaderWithPartitionValues`, the `requiredSchema` filter out all the partition columns, so we can't know which partition column is required. ## How was this patch tested? Unit test Closes #23387 from gengliangwang/refactorOrcColumnarBatch. Lead-authored-by: Gengliang Wang Co-authored-by: Gengliang Wang Co-authored-by: Dongjoon Hyun Signed-off-by: Wenchen Fan --- .../orc/OrcColumnarBatchReader.java | 93 ++++++++++--------- .../datasources/orc/OrcFileFormat.scala | 10 +- .../orc/OrcColumnarBatchReaderSuite.scala | 80 ++++++++++++++++ 3 files changed, 136 insertions(+), 47 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index a0d9578a377b1..7dc90df05a8fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.stream.IntStream; +import com.google.common.annotations.VisibleForTesting; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.RecordReader; @@ -58,9 +59,14 @@ public class OrcColumnarBatchReader extends RecordReader { /** * The column IDs of the physical ORC file schema which are required by this reader. - * -1 means this required column doesn't exist in the ORC file. + * -1 means this required column is partition column, or it doesn't exist in the ORC file. + * Ideally partition column should never appear in the physical file, and should only appear + * in the directory name. However, Spark allows partition columns inside physical file, + * but Spark will discard the values from the file, and use the partition value got from + * directory name. The column order will be reserved though. */ - private int[] requestedColIds; + @VisibleForTesting + public int[] requestedDataColIds; // Record reader from ORC row batch. private org.apache.orc.RecordReader recordReader; @@ -68,7 +74,8 @@ public class OrcColumnarBatchReader extends RecordReader { private StructField[] requiredFields; // The result columnar batch for vectorized execution by whole-stage codegen. - private ColumnarBatch columnarBatch; + @VisibleForTesting + public ColumnarBatch columnarBatch; // Writable column vectors of the result columnar batch. private WritableColumnVector[] columnVectors; @@ -143,25 +150,33 @@ public void initialize( /** * Initialize columnar batch by setting required schema and partition information. * With this information, this creates ColumnarBatch with the full schema. + * + * @param orcSchema Schema from ORC file reader. + * @param requiredFields All the fields that are required to return, including partition fields. + * @param requestedDataColIds Requested column ids from orcSchema. -1 if not existed. + * @param requestedPartitionColIds Requested column ids from partition schema. -1 if not existed. + * @param partitionValues Values of partition columns. */ public void initBatch( TypeDescription orcSchema, - int[] requestedColIds, StructField[] requiredFields, - StructType partitionSchema, + int[] requestedDataColIds, + int[] requestedPartitionColIds, InternalRow partitionValues) { batch = orcSchema.createRowBatch(capacity); assert(!batch.selectedInUse); // `selectedInUse` should be initialized with `false`. - + assert(requiredFields.length == requestedDataColIds.length); + assert(requiredFields.length == requestedPartitionColIds.length); + // If a required column is also partition column, use partition value and don't read from file. + for (int i = 0; i < requiredFields.length; i++) { + if (requestedPartitionColIds[i] != -1) { + requestedDataColIds[i] = -1; + } + } this.requiredFields = requiredFields; - this.requestedColIds = requestedColIds; - assert(requiredFields.length == requestedColIds.length); + this.requestedDataColIds = requestedDataColIds; StructType resultSchema = new StructType(requiredFields); - for (StructField f : partitionSchema.fields()) { - resultSchema = resultSchema.add(f); - } - if (copyToSpark) { if (MEMORY_MODE == MemoryMode.OFF_HEAP) { columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); @@ -169,22 +184,18 @@ public void initBatch( columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); } - // Initialize the missing columns once. + // Initialize the partition columns and missing columns once. for (int i = 0; i < requiredFields.length; i++) { - if (requestedColIds[i] == -1) { + if (requestedPartitionColIds[i] != -1) { + ColumnVectorUtils.populate(columnVectors[i], + partitionValues, requestedPartitionColIds[i]); + columnVectors[i].setIsConstant(); + } else if (requestedDataColIds[i] == -1) { columnVectors[i].putNulls(0, capacity); columnVectors[i].setIsConstant(); } } - if (partitionValues.numFields() > 0) { - int partitionIdx = requiredFields.length; - for (int i = 0; i < partitionValues.numFields(); i++) { - ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i); - columnVectors[i + partitionIdx].setIsConstant(); - } - } - columnarBatch = new ColumnarBatch(columnVectors); } else { // Just wrap the ORC column vector instead of copying it to Spark column vector. @@ -192,26 +203,22 @@ public void initBatch( for (int i = 0; i < requiredFields.length; i++) { DataType dt = requiredFields[i].dataType(); - int colId = requestedColIds[i]; - // Initialize the missing columns once. - if (colId == -1) { - OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); - missingCol.putNulls(0, capacity); - missingCol.setIsConstant(); - orcVectorWrappers[i] = missingCol; - } else { - orcVectorWrappers[i] = new OrcColumnVector(dt, batch.cols[colId]); - } - } - - if (partitionValues.numFields() > 0) { - int partitionIdx = requiredFields.length; - for (int i = 0; i < partitionValues.numFields(); i++) { - DataType dt = partitionSchema.fields()[i].dataType(); + if (requestedPartitionColIds[i] != -1) { OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt); - ColumnVectorUtils.populate(partitionCol, partitionValues, i); + ColumnVectorUtils.populate(partitionCol, partitionValues, requestedPartitionColIds[i]); partitionCol.setIsConstant(); - orcVectorWrappers[partitionIdx + i] = partitionCol; + orcVectorWrappers[i] = partitionCol; + } else { + int colId = requestedDataColIds[i]; + // Initialize the missing columns once. + if (colId == -1) { + OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); + missingCol.putNulls(0, capacity); + missingCol.setIsConstant(); + orcVectorWrappers[i] = missingCol; + } else { + orcVectorWrappers[i] = new OrcColumnVector(dt, batch.cols[colId]); + } } } @@ -233,7 +240,7 @@ private boolean nextBatch() throws IOException { if (!copyToSpark) { for (int i = 0; i < requiredFields.length; i++) { - if (requestedColIds[i] != -1) { + if (requestedDataColIds[i] != -1) { ((OrcColumnVector) orcVectorWrappers[i]).setBatchSize(batchSize); } } @@ -248,8 +255,8 @@ private boolean nextBatch() throws IOException { StructField field = requiredFields[i]; WritableColumnVector toColumn = columnVectors[i]; - if (requestedColIds[i] >= 0) { - ColumnVector fromColumn = batch.cols[requestedColIds[i]]; + if (requestedDataColIds[i] >= 0) { + ColumnVector fromColumn = batch.cols[requestedDataColIds[i]]; if (fromColumn.isRepeating) { putRepeatingValues(batchSize, field, fromColumn, toColumn); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 4574f8247af54..cd10ad21cd820 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -206,13 +206,15 @@ class OrcFileFormat // after opening a file. val iter = new RecordReaderIterator(batchReader) Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) - + val requestedDataColIds = requestedColIds ++ Array.fill(partitionSchema.length)(-1) + val requestedPartitionColIds = + Array.fill(requiredSchema.length)(-1) ++ Range(0, partitionSchema.length) batchReader.initialize(fileSplit, taskAttemptContext) batchReader.initBatch( reader.getSchema, - requestedColIds, - requiredSchema.fields, - partitionSchema, + resultSchema.fields, + requestedDataColIds, + requestedPartitionColIds, file.partitionValues) iter.asInstanceOf[Iterator[InternalRow]] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala new file mode 100644 index 0000000000000..52abeb20e7f25 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc + +import org.apache.orc.TypeDescription + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.vectorized.{OnHeapColumnVector, WritableColumnVector} +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String.fromString + +class OrcColumnarBatchReaderSuite extends QueryTest with SQLTestUtils with SharedSQLContext { + private val dataSchema = StructType.fromDDL("col1 int, col2 int") + private val partitionSchema = StructType.fromDDL("p1 string, p2 string") + private val partitionValues = InternalRow(fromString("partValue1"), fromString("partValue2")) + private val orcFileSchemaList = Seq( + "struct", "struct", + "struct", "struct") + orcFileSchemaList.foreach { case schema => + val orcFileSchema = TypeDescription.fromString(schema) + + val isConstant = classOf[WritableColumnVector].getDeclaredField("isConstant") + isConstant.setAccessible(true) + + def getReader( + requestedDataColIds: Array[Int], + requestedPartitionColIds: Array[Int], + resultFields: Array[StructField]): OrcColumnarBatchReader = { + val reader = new OrcColumnarBatchReader(false, false, 4096) + reader.initBatch( + orcFileSchema, + resultFields, + requestedDataColIds, + requestedPartitionColIds, + partitionValues) + reader + } + + test(s"all partitions are requested: $schema") { + val requestedDataColIds = Array(0, 1, 0, 0) + val requestedPartitionColIds = Array(-1, -1, 0, 1) + val reader = getReader(requestedDataColIds, requestedPartitionColIds, + dataSchema.fields ++ partitionSchema.fields) + assert(reader.requestedDataColIds === Array(0, 1, -1, -1)) + } + + test(s"initBatch should initialize requested partition columns only: $schema") { + val requestedDataColIds = Array(0, -1) // only `col1` is requested, `col2` doesn't exist + val requestedPartitionColIds = Array(-1, 0) // only `p1` is requested + val reader = getReader(requestedDataColIds, requestedPartitionColIds, + Array(dataSchema.fields(0), partitionSchema.fields(0))) + val batch = reader.columnarBatch + assert(batch.numCols() === 2) + + assert(batch.column(0).isInstanceOf[OrcColumnVector]) + assert(batch.column(1).isInstanceOf[OnHeapColumnVector]) + + val p1 = batch.column(1).asInstanceOf[OnHeapColumnVector] + assert(isConstant.get(p1).asInstanceOf[Boolean]) // Partition column is constant. + assert(p1.getUTF8String(0) === partitionValues.getUTF8String(0)) + } + } +} From 05372d188aeaeff5e8de8866ec6e7b932bafa70f Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Thu, 3 Jan 2019 14:30:27 -0800 Subject: [PATCH 2370/2461] [SPARK-26489][CORE] Use ConfigEntry for hardcoded configs for python/r categories ## What changes were proposed in this pull request? The PR makes hardcoded configs below to use ConfigEntry. * spark.pyspark * spark.python * spark.r This patch doesn't change configs which are not relevant to SparkConf (e.g. system properties, python source code) ## How was this patch tested? Existing tests. Closes #23428 from HeartSaVioR/SPARK-26489. Authored-by: Jungtaek Lim (HeartSaVioR) Signed-off-by: Marcelo Vanzin --- .../spark/api/python/PythonRunner.scala | 6 +-- .../api/python/PythonWorkerFactory.scala | 15 +++--- .../org/apache/spark/api/r/RBackend.scala | 10 ++-- .../apache/spark/api/r/RBackendHandler.scala | 7 ++- .../org/apache/spark/api/r/RRunner.scala | 8 ++-- .../org/apache/spark/deploy/RRunner.scala | 9 ++-- .../apache/spark/internal/config/Python.scala | 47 +++++++++++++++++++ .../config/R.scala} | 26 ++++++---- .../spark/internal/config/package.scala | 4 -- .../features/BasicExecutorFeatureStep.scala | 1 + .../BasicExecutorFeatureStepSuite.scala | 1 + .../org/apache/spark/deploy/yarn/Client.scala | 1 + .../spark/deploy/yarn/YarnAllocator.scala | 1 + 13 files changed, 96 insertions(+), 40 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/internal/config/Python.scala rename core/src/main/scala/org/apache/spark/{api/r/SparkRDefaults.scala => internal/config/R.scala} (56%) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index f73e95eac8f79..6b748c825d293 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -27,7 +27,7 @@ import scala.collection.JavaConverters._ import org.apache.spark._ import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.PYSPARK_EXECUTOR_MEMORY +import org.apache.spark.internal.config.Python._ import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util._ @@ -71,7 +71,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( private val conf = SparkEnv.get.conf private val bufferSize = conf.getInt("spark.buffer.size", 65536) - private val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true) + private val reuseWorker = conf.get(PYTHON_WORKER_REUSE) // each python worker gets an equal part of the allocation. the worker pool will grow to the // number of concurrent tasks, which is determined by the number of cores in this executor. private val memoryMb = conf.get(PYSPARK_EXECUTOR_MEMORY) @@ -496,7 +496,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( extends Thread(s"Worker Monitor for $pythonExec") { /** How long to wait before killing the python worker if a task cannot be interrupted. */ - private val taskKillTimeout = env.conf.getTimeAsMs("spark.python.task.killTimeout", "2s") + private val taskKillTimeout = env.conf.get(PYTHON_TASK_KILL_TIMEOUT) setDaemon(true) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 1f2f503a28d49..09e219fef5a1e 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -28,6 +28,7 @@ import scala.collection.mutable import org.apache.spark._ import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.Python._ import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util.{RedirectThread, Utils} @@ -41,7 +42,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String // currently only works on UNIX-based systems now because it uses signals for child management, // so we can also fall back to launching workers, pyspark/worker.py (by default) directly. private val useDaemon = { - val useDaemonEnabled = SparkEnv.get.conf.getBoolean("spark.python.use.daemon", true) + val useDaemonEnabled = SparkEnv.get.conf.get(PYTHON_USE_DAEMON) // This flag is ignored on Windows as it's unable to fork. !System.getProperty("os.name").startsWith("Windows") && useDaemonEnabled @@ -53,21 +54,21 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String // This configuration indicates the module to run the daemon to execute its Python workers. private val daemonModule = - SparkEnv.get.conf.getOption("spark.python.daemon.module").map { value => + SparkEnv.get.conf.get(PYTHON_DAEMON_MODULE).map { value => logInfo( - s"Python daemon module in PySpark is set to [$value] in 'spark.python.daemon.module', " + + s"Python daemon module in PySpark is set to [$value] in '${PYTHON_DAEMON_MODULE.key}', " + "using this to start the daemon up. Note that this configuration only has an effect when " + - "'spark.python.use.daemon' is enabled and the platform is not Windows.") + s"'${PYTHON_USE_DAEMON.key}' is enabled and the platform is not Windows.") value }.getOrElse("pyspark.daemon") // This configuration indicates the module to run each Python worker. private val workerModule = - SparkEnv.get.conf.getOption("spark.python.worker.module").map { value => + SparkEnv.get.conf.get(PYTHON_WORKER_MODULE).map { value => logInfo( - s"Python worker module in PySpark is set to [$value] in 'spark.python.worker.module', " + + s"Python worker module in PySpark is set to [$value] in '${PYTHON_WORKER_MODULE.key}', " + "using this to start the worker up. Note that this configuration only has an effect when " + - "'spark.python.use.daemon' is disabled or the platform is Windows.") + s"'${PYTHON_USE_DAEMON.key}' is disabled or the platform is Windows.") value }.getOrElse("pyspark.worker") diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 50c8fdf5316d6..36b4132088b58 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -32,6 +32,7 @@ import io.netty.handler.timeout.ReadTimeoutHandler import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.R._ /** * Netty-based backend server that is used to communicate between R and Java. @@ -47,10 +48,8 @@ private[spark] class RBackend { def init(): (Int, RAuthHelper) = { val conf = new SparkConf() - val backendConnectionTimeout = conf.getInt( - "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) - bossGroup = new NioEventLoopGroup( - conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS)) + val backendConnectionTimeout = conf.get(R_BACKEND_CONNECTION_TIMEOUT) + bossGroup = new NioEventLoopGroup(conf.get(R_NUM_BACKEND_THREADS)) val workerGroup = bossGroup val handler = new RBackendHandler(this) val authHelper = new RAuthHelper(conf) @@ -126,8 +125,7 @@ private[spark] object RBackend extends Logging { // Connection timeout is set by socket client. To make it configurable we will pass the // timeout value to client inside the temp file val conf = new SparkConf() - val backendConnectionTimeout = conf.getInt( - "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) + val backendConnectionTimeout = conf.get(R_BACKEND_CONNECTION_TIMEOUT) // tell the R process via temporary file val path = args(0) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 18fc595301f46..7b74efa41044f 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -29,6 +29,7 @@ import io.netty.handler.timeout.ReadTimeoutException import org.apache.spark.SparkConf import org.apache.spark.api.r.SerDe._ import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.R._ import org.apache.spark.util.{ThreadUtils, Utils} /** @@ -98,10 +99,8 @@ private[r] class RBackendHandler(server: RBackend) } } val conf = new SparkConf() - val heartBeatInterval = conf.getInt( - "spark.r.heartBeatInterval", SparkRDefaults.DEFAULT_HEARTBEAT_INTERVAL) - val backendConnectionTimeout = conf.getInt( - "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) + val heartBeatInterval = conf.get(R_HEARTBEAT_INTERVAL) + val backendConnectionTimeout = conf.get(R_BACKEND_CONNECTION_TIMEOUT) val interval = Math.min(heartBeatInterval, backendConnectionTimeout - 1) execService.scheduleAtFixedRate(pingRunner, interval, interval, TimeUnit.SECONDS) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index e7fdc3963945a..3fdea04cdf7a7 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -27,6 +27,7 @@ import scala.util.Try import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.R._ import org.apache.spark.util.Utils /** @@ -340,11 +341,10 @@ private[r] object RRunner { // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command", // but kept here for backward compatibility. val sparkConf = SparkEnv.get.conf - var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript") - rCommand = sparkConf.get("spark.r.command", rCommand) + var rCommand = sparkConf.get(SPARKR_COMMAND) + rCommand = sparkConf.get(R_COMMAND).orElse(Some(rCommand)).get - val rConnectionTimeout = sparkConf.getInt( - "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) + val rConnectionTimeout = sparkConf.get(R_BACKEND_CONNECTION_TIMEOUT) val rOptions = "--vanilla" val rLibDir = RUtils.sparkRPackagePath(isDriver = false) val rExecScript = rLibDir(0) + "/SparkR/worker/" + script diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index e86b362639e57..6284e6a6448f8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -25,7 +25,8 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.spark.{SparkException, SparkUserAppException} -import org.apache.spark.api.r.{RBackend, RUtils, SparkRDefaults} +import org.apache.spark.api.r.{RBackend, RUtils} +import org.apache.spark.internal.config.R._ import org.apache.spark.util.RedirectThread /** @@ -43,8 +44,8 @@ object RRunner { val rCommand = { // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command", // but kept here for backward compatibility. - var cmd = sys.props.getOrElse("spark.sparkr.r.command", "Rscript") - cmd = sys.props.getOrElse("spark.r.command", cmd) + var cmd = sys.props.getOrElse(SPARKR_COMMAND.key, SPARKR_COMMAND.defaultValue.get) + cmd = sys.props.getOrElse(R_COMMAND.key, cmd) if (sys.props.getOrElse("spark.submit.deployMode", "client") == "client") { cmd = sys.props.getOrElse("spark.r.driver.command", cmd) } @@ -53,7 +54,7 @@ object RRunner { // Connection timeout set by R process on its connection to RBackend in seconds. val backendConnectionTimeout = sys.props.getOrElse( - "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT.toString) + R_BACKEND_CONNECTION_TIMEOUT.key, R_BACKEND_CONNECTION_TIMEOUT.defaultValue.get.toString) // Check if the file path exists. // If not, change directory to current working directory for YARN cluster mode diff --git a/core/src/main/scala/org/apache/spark/internal/config/Python.scala b/core/src/main/scala/org/apache/spark/internal/config/Python.scala new file mode 100644 index 0000000000000..26a0598f49411 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/config/Python.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.internal.config + +import java.util.concurrent.TimeUnit + +import org.apache.spark.network.util.ByteUnit + +private[spark] object Python { + val PYTHON_WORKER_REUSE = ConfigBuilder("spark.python.worker.reuse") + .booleanConf + .createWithDefault(true) + + val PYTHON_TASK_KILL_TIMEOUT = ConfigBuilder("spark.python.task.killTimeout") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("2s") + + val PYTHON_USE_DAEMON = ConfigBuilder("spark.python.use.daemon") + .booleanConf + .createWithDefault(true) + + val PYTHON_DAEMON_MODULE = ConfigBuilder("spark.python.daemon.module") + .stringConf + .createOptional + + val PYTHON_WORKER_MODULE = ConfigBuilder("spark.python.worker.module") + .stringConf + .createOptional + + val PYSPARK_EXECUTOR_MEMORY = ConfigBuilder("spark.executor.pyspark.memory") + .bytesConf(ByteUnit.MiB) + .createOptional +} diff --git a/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala b/core/src/main/scala/org/apache/spark/internal/config/R.scala similarity index 56% rename from core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala rename to core/src/main/scala/org/apache/spark/internal/config/R.scala index af67cbbce4e51..26e06a5231c42 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/R.scala @@ -14,17 +14,27 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package org.apache.spark.internal.config -package org.apache.spark.api.r +private[spark] object R { -private[spark] object SparkRDefaults { + val R_BACKEND_CONNECTION_TIMEOUT = ConfigBuilder("spark.r.backendConnectionTimeout") + .intConf + .createWithDefault(6000) - // Default value for spark.r.backendConnectionTimeout config - val DEFAULT_CONNECTION_TIMEOUT: Int = 6000 + val R_NUM_BACKEND_THREADS = ConfigBuilder("spark.r.numRBackendThreads") + .intConf + .createWithDefault(2) - // Default value for spark.r.heartBeatInterval config - val DEFAULT_HEARTBEAT_INTERVAL: Int = 100 + val R_HEARTBEAT_INTERVAL = ConfigBuilder("spark.r.heartBeatInterval") + .intConf + .createWithDefault(100) - // Default value for spark.r.numRBackendThreads config - val DEFAULT_NUM_RBACKEND_THREADS = 2 + val SPARKR_COMMAND = ConfigBuilder("spark.sparkr.r.command") + .stringConf + .createWithDefault("Rscript") + + val R_COMMAND = ConfigBuilder("spark.r.command") + .stringConf + .createOptional } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index d8e9c099028f5..da8060459477f 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -166,10 +166,6 @@ package object config { .checkValue(_ >= 0, "The off-heap memory size must not be negative") .createWithDefault(0) - private[spark] val PYSPARK_EXECUTOR_MEMORY = ConfigBuilder("spark.executor.pyspark.memory") - .bytesConf(ByteUnit.MiB) - .createOptional - private[spark] val IS_PYTHON_APP = ConfigBuilder("spark.yarn.isPython").internal() .booleanConf.createWithDefault(false) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index c8bf7cdb4224f..dd73a5e52281c 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -25,6 +25,7 @@ import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.Python._ import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index c2efab01e4248..e28c650a571ed 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, KubernetesTestConf, import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.Python._ import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 184fb6a8ad13e..44a60b835f12f 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -53,6 +53,7 @@ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.Python._ import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils} import org.apache.spark.util.{CallerContext, Utils} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index a3feca5dfd229..8c6eff9915136 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -36,6 +36,7 @@ import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.Python._ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor From f65dc9593ee4b84343fea04fdcace14096788be8 Mon Sep 17 00:00:00 2001 From: "Liu,Linhong" Date: Fri, 4 Jan 2019 10:51:33 +0800 Subject: [PATCH 2371/2461] [SPARK-26526][SQL][TEST] Fix invalid test case about non-deterministic expression ## What changes were proposed in this pull request? Test case in SPARK-10316 is used to make sure non-deterministic `Filter` won't be pushed through `Project` But in current code base this test case can't cover this purpose. Change LogicalRDD to HadoopFsRelation can fix this issue. ## How was this patch tested? Modified test pass. Closes #23440 from LinhongLiu/fix-test. Authored-by: Liu,Linhong Signed-off-by: Wenchen Fan --- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b51c51e663503..3082e0bb97dfb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1398,11 +1398,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") { - val input = spark.read.json((1 to 10).map(i => s"""{"id": $i}""").toDS()) + withTempDir { dir => + (1 to 10).toDF("id").write.mode(SaveMode.Overwrite).json(dir.getCanonicalPath) + val input = spark.read.json(dir.getCanonicalPath) - val df = input.select($"id", rand(0).as('r)) - df.as("a").join(df.filter($"r" < 0.5).as("b"), $"a.id" === $"b.id").collect().foreach { row => - assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001) + val df = input.select($"id", rand(0).as('r)) + df.as("a").join(df.filter($"r" < 0.5).as("b"), $"a.id" === $"b.id").collect().foreach { row => + assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001) + } } } From 27e42c1de502da80fa3e22bb69de47fb00158174 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 3 Jan 2019 20:01:19 -0800 Subject: [PATCH 2372/2461] [MINOR][NETWORK][TEST] Fix TransportFrameDecoderSuite to use ByteBuf instead of ByteBuffer ## What changes were proposed in this pull request? `fireChannelRead` expects `io.netty.buffer.ByteBuf`.I checked that this is the only place which misuse `java.nio.ByteBuffer` in `network` module. ## How was this patch tested? Pass the Jenkins with the existing tests. Closes #23442 from dongjoon-hyun/SPARK-NETWORK-COMMON. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../apache/spark/network/util/TransportFrameDecoderSuite.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java index b53e41303751c..7d40387c5f1af 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.network.util; -import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; import java.util.Random; @@ -69,7 +68,7 @@ public void testInterception() throws Exception { decoder.channelRead(ctx, len); decoder.channelRead(ctx, dataBuf); verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class)); - verify(ctx).fireChannelRead(any(ByteBuffer.class)); + verify(ctx).fireChannelRead(any(ByteBuf.class)); assertEquals(0, len.refCnt()); assertEquals(0, dataBuf.refCnt()); } finally { From 4419e1daca6c5de373d5f3f13c417b791d768c96 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 4 Jan 2019 22:12:35 +0800 Subject: [PATCH 2373/2461] [SPARK-26445][CORE] Use ConfigEntry for hardcoded configs for driver/executor categories. ## What changes were proposed in this pull request? The PR makes hardcoded spark.driver, spark.executor, and spark.cores.max configs to use `ConfigEntry`. Note that some config keys are from `SparkLauncher` instead of defining in the config package object because the string is already defined in it and it does not depend on core module. ## How was this patch tested? Existing tests. Closes #23415 from ueshin/issues/SPARK-26445/hardcoded_driver_executor_configs. Authored-by: Takuya UESHIN Signed-off-by: Hyukjin Kwon --- .../spark/ExecutorAllocationManager.scala | 4 +- .../scala/org/apache/spark/SparkConf.scala | 25 ++++----- .../scala/org/apache/spark/SparkContext.scala | 8 +-- .../scala/org/apache/spark/SparkEnv.scala | 8 +-- .../spark/api/python/PythonRunner.scala | 4 +- .../org/apache/spark/deploy/Client.scala | 8 +-- .../spark/deploy/FaultToleranceTest.scala | 6 +- .../org/apache/spark/deploy/SparkSubmit.scala | 24 ++++---- .../spark/deploy/SparkSubmitArguments.scala | 20 +++---- .../deploy/rest/StandaloneRestServer.scala | 13 +++-- .../rest/SubmitRestProtocolRequest.scala | 11 ++-- .../spark/deploy/worker/DriverWrapper.scala | 6 +- .../spark/internal/config/package.scala | 55 ++++++++++++++++++- .../spark/memory/StaticMemoryManager.scala | 9 +-- .../spark/memory/UnifiedMemoryManager.scala | 9 +-- .../apache/spark/metrics/MetricsSystem.scala | 2 +- .../spark/scheduler/TaskSetManager.scala | 2 +- .../cluster/StandaloneSchedulerBackend.scala | 16 +++--- .../local/LocalSchedulerBackend.scala | 4 +- .../org/apache/spark/util/RpcUtils.scala | 5 +- .../scala/org/apache/spark/util/Utils.scala | 2 +- .../spark/util/logging/FileAppender.scala | 10 ++-- .../util/logging/RollingFileAppender.scala | 19 ++----- .../ExecutorAllocationManagerSuite.scala | 2 +- .../org/apache/spark/SparkConfSuite.scala | 2 +- .../StandaloneDynamicAllocationSuite.scala | 6 +- .../memory/UnifiedMemoryManagerSuite.scala | 2 +- .../spark/scheduler/TaskSetManagerSuite.scala | 2 +- .../BlockManagerReplicationSuite.scala | 4 +- .../spark/storage/BlockManagerSuite.scala | 2 +- .../apache/spark/util/FileAppenderSuite.scala | 13 ++--- .../k8s/features/BasicDriverFeatureStep.scala | 4 +- .../features/BasicExecutorFeatureStep.scala | 6 +- .../features/DriverServiceFeatureStep.scala | 15 +++-- .../BasicDriverFeatureStepSuite.scala | 2 +- .../BasicExecutorFeatureStepSuite.scala | 16 +++--- .../DriverServiceFeatureStepSuite.scala | 6 +- .../apache/spark/deploy/mesos/config.scala | 3 + .../deploy/rest/mesos/MesosRestServer.scala | 13 +++-- .../cluster/mesos/MesosClusterScheduler.scala | 10 ++-- .../MesosCoarseGrainedSchedulerBackend.scala | 20 +++---- .../MesosFineGrainedSchedulerBackend.scala | 12 ++-- .../spark/deploy/yarn/ApplicationMaster.scala | 6 +- .../org/apache/spark/deploy/yarn/config.scala | 10 +--- .../cluster/YarnClientSchedulerBackend.scala | 8 +-- .../yarn/ResourceRequestHelperSuite.scala | 2 +- 46 files changed, 236 insertions(+), 200 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 3f0b71bbe17f1..d966582295b37 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -127,7 +127,7 @@ private[spark] class ExecutorAllocationManager( // allocation is only supported for YARN and the default number of cores per executor in YARN is // 1, but it might need to be attained differently for different cluster managers private val tasksPerExecutorForFullParallelism = - conf.getInt("spark.executor.cores", 1) / conf.getInt("spark.task.cpus", 1) + conf.get(EXECUTOR_CORES) / conf.getInt("spark.task.cpus", 1) private val executorAllocationRatio = conf.get(DYN_ALLOCATION_EXECUTOR_ALLOCATION_RATIO) @@ -223,7 +223,7 @@ private[spark] class ExecutorAllocationManager( "shuffle service. You may enable this through spark.shuffle.service.enabled.") } if (tasksPerExecutorForFullParallelism == 0) { - throw new SparkException("spark.executor.cores must not be < spark.task.cpus.") + throw new SparkException(s"${EXECUTOR_CORES.key} must not be < spark.task.cpus.") } if (executorAllocationRatio > 1.0 || executorAllocationRatio <= 0.0) { diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 0b47da12b5b42..681e4378a4dd5 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -503,12 +503,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria logWarning(msg) } - val executorOptsKey = "spark.executor.extraJavaOptions" - val executorClasspathKey = "spark.executor.extraClassPath" - val driverOptsKey = "spark.driver.extraJavaOptions" - val driverClassPathKey = "spark.driver.extraClassPath" - val driverLibraryPathKey = "spark.driver.extraLibraryPath" - val sparkExecutorInstances = "spark.executor.instances" + val executorOptsKey = EXECUTOR_JAVA_OPTIONS.key // Used by Yarn in 1.1 and before sys.props.get("spark.driver.libraryPath").foreach { value => @@ -517,7 +512,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria |spark.driver.libraryPath was detected (set to '$value'). |This is deprecated in Spark 1.2+. | - |Please instead use: $driverLibraryPathKey + |Please instead use: ${DRIVER_LIBRARY_PATH.key} """.stripMargin logWarning(warning) } @@ -594,9 +589,9 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria } } - if (contains("spark.cores.max") && contains("spark.executor.cores")) { - val totalCores = getInt("spark.cores.max", 1) - val executorCores = getInt("spark.executor.cores", 1) + if (contains(CORES_MAX) && contains(EXECUTOR_CORES)) { + val totalCores = getInt(CORES_MAX.key, 1) + val executorCores = get(EXECUTOR_CORES) val leftCores = totalCores % executorCores if (leftCores != 0) { logWarning(s"Total executor cores: ${totalCores} is not " + @@ -605,12 +600,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria } } - if (contains("spark.executor.cores") && contains("spark.task.cpus")) { - val executorCores = getInt("spark.executor.cores", 1) + if (contains(EXECUTOR_CORES) && contains("spark.task.cpus")) { + val executorCores = get(EXECUTOR_CORES) val taskCpus = getInt("spark.task.cpus", 1) if (executorCores < taskCpus) { - throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.") + throw new SparkException(s"${EXECUTOR_CORES.key} must not be less than spark.task.cpus.") } } @@ -680,7 +675,7 @@ private[spark] object SparkConf extends Logging { * TODO: consolidate it with `ConfigBuilder.withAlternative`. */ private val configsWithAlternatives = Map[String, Seq[AlternateConfig]]( - "spark.executor.userClassPathFirst" -> Seq( + EXECUTOR_USER_CLASS_PATH_FIRST.key -> Seq( AlternateConfig("spark.files.userClassPathFirst", "1.3")), UPDATE_INTERVAL_S.key -> Seq( AlternateConfig("spark.history.fs.update.interval.seconds", "1.4"), @@ -703,7 +698,7 @@ private[spark] object SparkConf extends Logging { AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")), "spark.shuffle.file.buffer" -> Seq( AlternateConfig("spark.shuffle.file.buffer.kb", "1.4")), - "spark.executor.logs.rolling.maxSize" -> Seq( + EXECUTOR_LOGS_ROLLING_MAX_SIZE.key -> Seq( AlternateConfig("spark.executor.logs.rolling.size.maxBytes", "1.4")), "spark.io.compression.snappy.blockSize" -> Seq( AlternateConfig("spark.io.compression.snappy.block.size", "1.4")), diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3475859c3ed69..89be9de083075 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -386,9 +386,9 @@ class SparkContext(config: SparkConf) extends Logging { // Set Spark driver host and port system properties. This explicitly sets the configuration // instead of relying on the default value of the config constant. _conf.set(DRIVER_HOST_ADDRESS, _conf.get(DRIVER_HOST_ADDRESS)) - _conf.setIfMissing("spark.driver.port", "0") + _conf.setIfMissing(DRIVER_PORT, 0) - _conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER) + _conf.set(EXECUTOR_ID, SparkContext.DRIVER_IDENTIFIER) _jars = Utils.getUserJars(_conf) _files = _conf.getOption("spark.files").map(_.split(",")).map(_.filter(_.nonEmpty)) @@ -461,7 +461,7 @@ class SparkContext(config: SparkConf) extends Logging { files.foreach(addFile) } - _executorMemory = _conf.getOption("spark.executor.memory") + _executorMemory = _conf.getOption(EXECUTOR_MEMORY.key) .orElse(Option(System.getenv("SPARK_EXECUTOR_MEMORY"))) .orElse(Option(System.getenv("SPARK_MEM")) .map(warnSparkMem)) @@ -2639,7 +2639,7 @@ object SparkContext extends Logging { case SparkMasterRegex.LOCAL_N_FAILURES_REGEX(threads, _) => convertToInt(threads) case "yarn" => if (conf != null && conf.getOption("spark.submit.deployMode").contains("cluster")) { - conf.getInt("spark.driver.cores", 0) + conf.getInt(DRIVER_CORES.key, 0) } else { 0 } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index de0c8579d9acc..9222781fa0833 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -163,10 +163,10 @@ object SparkEnv extends Logging { mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { assert(conf.contains(DRIVER_HOST_ADDRESS), s"${DRIVER_HOST_ADDRESS.key} is not set on the driver!") - assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!") + assert(conf.contains(DRIVER_PORT), s"${DRIVER_PORT.key} is not set on the driver!") val bindAddress = conf.get(DRIVER_BIND_ADDRESS) val advertiseAddress = conf.get(DRIVER_HOST_ADDRESS) - val port = conf.get("spark.driver.port").toInt + val port = conf.get(DRIVER_PORT) val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) { Some(CryptoStreamUtils.createKey(conf)) } else { @@ -251,7 +251,7 @@ object SparkEnv extends Logging { // Figure out which port RpcEnv actually bound to in case the original port is 0 or occupied. if (isDriver) { - conf.set("spark.driver.port", rpcEnv.address.port.toString) + conf.set(DRIVER_PORT, rpcEnv.address.port) } // Create an instance of the class with the given name, possibly initializing it with our conf @@ -359,7 +359,7 @@ object SparkEnv extends Logging { // We need to set the executor ID before the MetricsSystem is created because sources and // sinks specified in the metrics configuration file will want to incorporate this executor's // ID into the metrics they report. - conf.set("spark.executor.id", executorId) + conf.set(EXECUTOR_ID, executorId) val ms = MetricsSystem.createMetricsSystem("executor", conf, securityManager) ms.start() ms diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 6b748c825d293..5168e9330965d 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -27,6 +27,7 @@ import scala.collection.JavaConverters._ import org.apache.spark._ import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.EXECUTOR_CORES import org.apache.spark.internal.config.Python._ import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util._ @@ -74,8 +75,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( private val reuseWorker = conf.get(PYTHON_WORKER_REUSE) // each python worker gets an equal part of the allocation. the worker pool will grow to the // number of concurrent tasks, which is determined by the number of cores in this executor. - private val memoryMb = conf.get(PYSPARK_EXECUTOR_MEMORY) - .map(_ / conf.getInt("spark.executor.cores", 1)) + private val memoryMb = conf.get(PYSPARK_EXECUTOR_MEMORY).map(_ / conf.get(EXECUTOR_CORES)) // All the Python functions should have the same exec, version and envvars. protected val envVars = funcs.head.funcs.head.envVars diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index d5145094ec079..d94b174d8d868 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -27,7 +27,7 @@ import org.apache.log4j.Logger import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.util.{SparkExitCode, ThreadUtils, Utils} @@ -68,17 +68,17 @@ private class ClientEndpoint( // people call `addJar` assuming the jar is in the same directory. val mainClass = "org.apache.spark.deploy.worker.DriverWrapper" - val classPathConf = "spark.driver.extraClassPath" + val classPathConf = config.DRIVER_CLASS_PATH.key val classPathEntries = sys.props.get(classPathConf).toSeq.flatMap { cp => cp.split(java.io.File.pathSeparator) } - val libraryPathConf = "spark.driver.extraLibraryPath" + val libraryPathConf = config.DRIVER_LIBRARY_PATH.key val libraryPathEntries = sys.props.get(libraryPathConf).toSeq.flatMap { cp => cp.split(java.io.File.pathSeparator) } - val extraJavaOptsConf = "spark.driver.extraJavaOptions" + val extraJavaOptsConf = config.DRIVER_JAVA_OPTIONS.key val extraJavaOpts = sys.props.get(extraJavaOptsConf) .map(Utils.splitCommandString).getOrElse(Seq.empty) val sparkJavaOpts = Utils.sparkJavaOpts(conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index c6307da61c7eb..0679bdf7c7075 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -34,7 +34,7 @@ import org.json4s.jackson.JsonMethods import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.deploy.master.RecoveryState -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.util.{ThreadUtils, Utils} /** @@ -77,7 +77,7 @@ private object FaultToleranceTest extends App with Logging { private val containerSparkHome = "/opt/spark" private val dockerMountDir = "%s:%s".format(sparkHome, containerSparkHome) - System.setProperty("spark.driver.host", "172.17.42.1") // default docker host ip + System.setProperty(config.DRIVER_HOST_ADDRESS.key, "172.17.42.1") // default docker host ip private def afterEach() { if (sc != null) { @@ -216,7 +216,7 @@ private object FaultToleranceTest extends App with Logging { if (sc != null) { sc.stop() } // Counter-hack: Because of a hack in SparkEnv#create() that changes this // property, we need to reset it. - System.setProperty("spark.driver.port", "0") + System.setProperty(config.DRIVER_PORT.key, "0") sc = new SparkContext(getMasterUrls(masters), "fault-tolerance", containerSparkHome) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 763bd0a70a035..a4c65aeaae3f6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -514,13 +514,13 @@ private[spark] class SparkSubmit extends Logging { OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.app.name"), OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, confKey = "spark.jars.ivy"), OptionAssigner(args.driverMemory, ALL_CLUSTER_MGRS, CLIENT, - confKey = "spark.driver.memory"), + confKey = DRIVER_MEMORY.key), OptionAssigner(args.driverExtraClassPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, - confKey = "spark.driver.extraClassPath"), + confKey = DRIVER_CLASS_PATH.key), OptionAssigner(args.driverExtraJavaOptions, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, - confKey = "spark.driver.extraJavaOptions"), + confKey = DRIVER_JAVA_OPTIONS.key), OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, - confKey = "spark.driver.extraLibraryPath"), + confKey = DRIVER_LIBRARY_PATH.key), OptionAssigner(args.principal, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = PRINCIPAL.key), OptionAssigner(args.keytab, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, @@ -537,7 +537,7 @@ private[spark] class SparkSubmit extends Logging { // Yarn only OptionAssigner(args.queue, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.queue"), OptionAssigner(args.numExecutors, YARN, ALL_DEPLOY_MODES, - confKey = "spark.executor.instances"), + confKey = EXECUTOR_INSTANCES.key), OptionAssigner(args.pyFiles, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.pyFiles"), OptionAssigner(args.jars, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.jars"), OptionAssigner(args.files, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.files"), @@ -545,22 +545,22 @@ private[spark] class SparkSubmit extends Logging { // Other options OptionAssigner(args.executorCores, STANDALONE | YARN | KUBERNETES, ALL_DEPLOY_MODES, - confKey = "spark.executor.cores"), + confKey = EXECUTOR_CORES.key), OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN | KUBERNETES, ALL_DEPLOY_MODES, - confKey = "spark.executor.memory"), + confKey = EXECUTOR_MEMORY.key), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, - confKey = "spark.cores.max"), + confKey = CORES_MAX.key), OptionAssigner(args.files, LOCAL | STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, confKey = "spark.files"), OptionAssigner(args.jars, LOCAL, CLIENT, confKey = "spark.jars"), OptionAssigner(args.jars, STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, confKey = "spark.jars"), OptionAssigner(args.driverMemory, STANDALONE | MESOS | YARN | KUBERNETES, CLUSTER, - confKey = "spark.driver.memory"), + confKey = DRIVER_MEMORY.key), OptionAssigner(args.driverCores, STANDALONE | MESOS | YARN | KUBERNETES, CLUSTER, - confKey = "spark.driver.cores"), + confKey = DRIVER_CORES.key), OptionAssigner(args.supervise.toString, STANDALONE | MESOS, CLUSTER, - confKey = "spark.driver.supervise"), + confKey = DRIVER_SUPERVISE.key), OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, confKey = "spark.jars.ivy"), // An internal option used only for spark-shell to add user jars to repl's classloader, @@ -727,7 +727,7 @@ private[spark] class SparkSubmit extends Logging { // Ignore invalid spark.driver.host in cluster modes. if (deployMode == CLUSTER) { - sparkConf.remove("spark.driver.host") + sparkConf.remove(DRIVER_HOST_ADDRESS) } // Resolve paths in certain spark properties diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 4cf08a7980f55..34facd5a58c40 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -31,7 +31,7 @@ import scala.util.Try import org.apache.spark.{SparkException, SparkUserAppException} import org.apache.spark.deploy.SparkSubmitAction._ -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.launcher.SparkSubmitArgumentsParser import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.Utils @@ -155,31 +155,31 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S .orElse(env.get("MASTER")) .orNull driverExtraClassPath = Option(driverExtraClassPath) - .orElse(sparkProperties.get("spark.driver.extraClassPath")) + .orElse(sparkProperties.get(config.DRIVER_CLASS_PATH.key)) .orNull driverExtraJavaOptions = Option(driverExtraJavaOptions) - .orElse(sparkProperties.get("spark.driver.extraJavaOptions")) + .orElse(sparkProperties.get(config.DRIVER_JAVA_OPTIONS.key)) .orNull driverExtraLibraryPath = Option(driverExtraLibraryPath) - .orElse(sparkProperties.get("spark.driver.extraLibraryPath")) + .orElse(sparkProperties.get(config.DRIVER_LIBRARY_PATH.key)) .orNull driverMemory = Option(driverMemory) - .orElse(sparkProperties.get("spark.driver.memory")) + .orElse(sparkProperties.get(config.DRIVER_MEMORY.key)) .orElse(env.get("SPARK_DRIVER_MEMORY")) .orNull driverCores = Option(driverCores) - .orElse(sparkProperties.get("spark.driver.cores")) + .orElse(sparkProperties.get(config.DRIVER_CORES.key)) .orNull executorMemory = Option(executorMemory) - .orElse(sparkProperties.get("spark.executor.memory")) + .orElse(sparkProperties.get(config.EXECUTOR_MEMORY.key)) .orElse(env.get("SPARK_EXECUTOR_MEMORY")) .orNull executorCores = Option(executorCores) - .orElse(sparkProperties.get("spark.executor.cores")) + .orElse(sparkProperties.get(config.EXECUTOR_CORES.key)) .orElse(env.get("SPARK_EXECUTOR_CORES")) .orNull totalExecutorCores = Option(totalExecutorCores) - .orElse(sparkProperties.get("spark.cores.max")) + .orElse(sparkProperties.get(config.CORES_MAX.key)) .orNull name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull @@ -197,7 +197,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S .orElse(env.get("DEPLOY_MODE")) .orNull numExecutors = Option(numExecutors) - .getOrElse(sparkProperties.get("spark.executor.instances").orNull) + .getOrElse(sparkProperties.get(config.EXECUTOR_INSTANCES.key).orNull) queue = Option(queue).orElse(sparkProperties.get("spark.yarn.queue")).orNull keytab = Option(keytab) .orElse(sparkProperties.get("spark.kerberos.keytab")) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index afa1a5fbba792..c75e684df2264 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -23,6 +23,7 @@ import javax.servlet.http.HttpServletResponse import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} import org.apache.spark.deploy.ClientArguments._ +import org.apache.spark.internal.config import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils @@ -132,12 +133,12 @@ private[rest] class StandaloneSubmitRequestServlet( // Optional fields val sparkProperties = request.sparkProperties - val driverMemory = sparkProperties.get("spark.driver.memory") - val driverCores = sparkProperties.get("spark.driver.cores") - val driverExtraJavaOptions = sparkProperties.get("spark.driver.extraJavaOptions") - val driverExtraClassPath = sparkProperties.get("spark.driver.extraClassPath") - val driverExtraLibraryPath = sparkProperties.get("spark.driver.extraLibraryPath") - val superviseDriver = sparkProperties.get("spark.driver.supervise") + val driverMemory = sparkProperties.get(config.DRIVER_MEMORY.key) + val driverCores = sparkProperties.get(config.DRIVER_CORES.key) + val driverExtraJavaOptions = sparkProperties.get(config.DRIVER_JAVA_OPTIONS.key) + val driverExtraClassPath = sparkProperties.get(config.DRIVER_CLASS_PATH.key) + val driverExtraLibraryPath = sparkProperties.get(config.DRIVER_LIBRARY_PATH.key) + val superviseDriver = sparkProperties.get(config.DRIVER_SUPERVISE.key) // The semantics of "spark.master" and the masterUrl are different. While the // property "spark.master" could contain all registered masters, masterUrl // contains only the active master. To make sure a Spark driver can recover diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala index 86ddf954ca128..7f462148c71a1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy.rest import scala.util.Try +import org.apache.spark.internal.config import org.apache.spark.util.Utils /** @@ -49,11 +50,11 @@ private[rest] class CreateSubmissionRequest extends SubmitRestProtocolRequest { assertFieldIsSet(appArgs, "appArgs") assertFieldIsSet(environmentVariables, "environmentVariables") assertPropertyIsSet("spark.app.name") - assertPropertyIsBoolean("spark.driver.supervise") - assertPropertyIsNumeric("spark.driver.cores") - assertPropertyIsNumeric("spark.cores.max") - assertPropertyIsMemory("spark.driver.memory") - assertPropertyIsMemory("spark.executor.memory") + assertPropertyIsBoolean(config.DRIVER_SUPERVISE.key) + assertPropertyIsNumeric(config.DRIVER_CORES.key) + assertPropertyIsNumeric(config.CORES_MAX.key) + assertPropertyIsMemory(config.DRIVER_MEMORY.key) + assertPropertyIsMemory(config.EXECUTOR_MEMORY.key) } private def assertPropertyIsSet(key: String): Unit = diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index 8d6a2b80ef5f2..1e8ad0b6af6a6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -23,7 +23,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.{DependencyUtils, SparkHadoopUtil, SparkSubmit} -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.rpc.RpcEnv import org.apache.spark.util._ @@ -43,7 +43,7 @@ object DriverWrapper extends Logging { case workerUrl :: userJar :: mainClass :: extraArgs => val conf = new SparkConf() val host: String = Utils.localHostName() - val port: Int = sys.props.getOrElse("spark.driver.port", "0").toInt + val port: Int = sys.props.getOrElse(config.DRIVER_PORT.key, "0").toInt val rpcEnv = RpcEnv.create("Driver", host, port, conf, new SecurityManager(conf)) logInfo(s"Driver address: ${rpcEnv.address}") rpcEnv.setupEndpoint("workerWatcher", new WorkerWatcher(rpcEnv, workerUrl)) @@ -51,7 +51,7 @@ object DriverWrapper extends Logging { val currentLoader = Thread.currentThread.getContextClassLoader val userJarUrl = new File(userJar).toURI().toURL() val loader = - if (sys.props.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) { + if (sys.props.getOrElse(config.DRIVER_USER_CLASS_PATH_FIRST.key, "false").toBoolean) { new ChildFirstURLClassLoader(Array(userJarUrl), currentLoader) } else { new MutableURLClassLoader(Array(userJarUrl), currentLoader) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index da8060459477f..8caaa73b02273 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -39,7 +39,12 @@ package object config { private[spark] val DRIVER_USER_CLASS_PATH_FIRST = ConfigBuilder("spark.driver.userClassPathFirst").booleanConf.createWithDefault(false) - private[spark] val DRIVER_MEMORY = ConfigBuilder("spark.driver.memory") + private[spark] val DRIVER_CORES = ConfigBuilder("spark.driver.cores") + .doc("Number of cores to use for the driver process, only in cluster mode.") + .intConf + .createWithDefault(1) + + private[spark] val DRIVER_MEMORY = ConfigBuilder(SparkLauncher.DRIVER_MEMORY) .doc("Amount of memory to use for the driver process, in MiB unless otherwise specified.") .bytesConf(ByteUnit.MiB) .createWithDefaultString("1g") @@ -113,6 +118,9 @@ package object config { private[spark] val EVENT_LOG_CALLSITE_LONG_FORM = ConfigBuilder("spark.eventLog.longForm.enabled").booleanConf.createWithDefault(false) + private[spark] val EXECUTOR_ID = + ConfigBuilder("spark.executor.id").stringConf.createOptional + private[spark] val EXECUTOR_CLASS_PATH = ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_CLASSPATH).stringConf.createOptional @@ -139,7 +147,11 @@ package object config { private[spark] val EXECUTOR_USER_CLASS_PATH_FIRST = ConfigBuilder("spark.executor.userClassPathFirst").booleanConf.createWithDefault(false) - private[spark] val EXECUTOR_MEMORY = ConfigBuilder("spark.executor.memory") + private[spark] val EXECUTOR_CORES = ConfigBuilder(SparkLauncher.EXECUTOR_CORES) + .intConf + .createWithDefault(1) + + private[spark] val EXECUTOR_MEMORY = ConfigBuilder(SparkLauncher.EXECUTOR_MEMORY) .doc("Amount of memory to use per executor process, in MiB unless otherwise specified.") .bytesConf(ByteUnit.MiB) .createWithDefaultString("1g") @@ -150,6 +162,15 @@ package object config { .bytesConf(ByteUnit.MiB) .createOptional + private[spark] val CORES_MAX = ConfigBuilder("spark.cores.max") + .doc("When running on a standalone deploy cluster or a Mesos cluster in coarse-grained " + + "sharing mode, the maximum amount of CPU cores to request for the application from across " + + "the cluster (not from each machine). If not set, the default will be " + + "`spark.deploy.defaultCores` on Spark's standalone cluster manager, or infinite " + + "(all available cores) on Mesos.") + .intConf + .createOptional + private[spark] val MEMORY_OFFHEAP_ENABLED = ConfigBuilder("spark.memory.offHeap.enabled") .doc("If true, Spark will attempt to use off-heap memory for certain operations. " + "If off-heap memory use is enabled, then spark.memory.offHeap.size must be positive.") @@ -347,6 +368,17 @@ package object config { .stringConf .createWithDefault(Utils.localCanonicalHostName()) + private[spark] val DRIVER_PORT = ConfigBuilder("spark.driver.port") + .doc("Port of driver endpoints.") + .intConf + .createWithDefault(0) + + private[spark] val DRIVER_SUPERVISE = ConfigBuilder("spark.driver.supervise") + .doc("If true, restarts the driver automatically if it fails with a non-zero exit status. " + + "Only has effect in Spark standalone mode or Mesos cluster deploy mode.") + .booleanConf + .createWithDefault(false) + private[spark] val DRIVER_BIND_ADDRESS = ConfigBuilder("spark.driver.bindAddress") .doc("Address where to bind network listen sockets on the driver.") .fallbackConf(DRIVER_HOST_ADDRESS) @@ -729,4 +761,23 @@ package object config { .stringConf .toSequence .createWithDefault(Nil) + + private[spark] val EXECUTOR_LOGS_ROLLING_STRATEGY = + ConfigBuilder("spark.executor.logs.rolling.strategy").stringConf.createWithDefault("") + + private[spark] val EXECUTOR_LOGS_ROLLING_TIME_INTERVAL = + ConfigBuilder("spark.executor.logs.rolling.time.interval").stringConf.createWithDefault("daily") + + private[spark] val EXECUTOR_LOGS_ROLLING_MAX_SIZE = + ConfigBuilder("spark.executor.logs.rolling.maxSize") + .stringConf + .createWithDefault((1024 * 1024).toString) + + private[spark] val EXECUTOR_LOGS_ROLLING_MAX_RETAINED_FILES = + ConfigBuilder("spark.executor.logs.rolling.maxRetainedFiles").intConf.createWithDefault(-1) + + private[spark] val EXECUTOR_LOGS_ROLLING_ENABLE_COMPRESSION = + ConfigBuilder("spark.executor.logs.rolling.enableCompression") + .booleanConf + .createWithDefault(false) } diff --git a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala index a6f7db0600e60..8286087042741 100644 --- a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala @@ -18,6 +18,7 @@ package org.apache.spark.memory import org.apache.spark.SparkConf +import org.apache.spark.internal.config import org.apache.spark.storage.BlockId /** @@ -127,14 +128,14 @@ private[spark] object StaticMemoryManager { if (systemMaxMemory < MIN_MEMORY_BYTES) { throw new IllegalArgumentException(s"System memory $systemMaxMemory must " + s"be at least $MIN_MEMORY_BYTES. Please increase heap size using the --driver-memory " + - s"option or spark.driver.memory in Spark configuration.") + s"option or ${config.DRIVER_MEMORY.key} in Spark configuration.") } - if (conf.contains("spark.executor.memory")) { - val executorMemory = conf.getSizeAsBytes("spark.executor.memory") + if (conf.contains(config.EXECUTOR_MEMORY)) { + val executorMemory = conf.getSizeAsBytes(config.EXECUTOR_MEMORY.key) if (executorMemory < MIN_MEMORY_BYTES) { throw new IllegalArgumentException(s"Executor memory $executorMemory must be at least " + s"$MIN_MEMORY_BYTES. Please increase executor memory using the " + - s"--executor-memory option or spark.executor.memory in Spark configuration.") + s"--executor-memory option or ${config.EXECUTOR_MEMORY.key} in Spark configuration.") } } val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 78edd2c4d7faa..9260fd3a6fb34 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -18,6 +18,7 @@ package org.apache.spark.memory import org.apache.spark.SparkConf +import org.apache.spark.internal.config import org.apache.spark.storage.BlockId /** @@ -216,15 +217,15 @@ object UnifiedMemoryManager { if (systemMemory < minSystemMemory) { throw new IllegalArgumentException(s"System memory $systemMemory must " + s"be at least $minSystemMemory. Please increase heap size using the --driver-memory " + - s"option or spark.driver.memory in Spark configuration.") + s"option or ${config.DRIVER_MEMORY.key} in Spark configuration.") } // SPARK-12759 Check executor memory to fail fast if memory is insufficient - if (conf.contains("spark.executor.memory")) { - val executorMemory = conf.getSizeAsBytes("spark.executor.memory") + if (conf.contains(config.EXECUTOR_MEMORY)) { + val executorMemory = conf.getSizeAsBytes(config.EXECUTOR_MEMORY.key) if (executorMemory < minSystemMemory) { throw new IllegalArgumentException(s"Executor memory $executorMemory must be at least " + s"$minSystemMemory. Please increase executor memory using the " + - s"--executor-memory option or spark.executor.memory in Spark configuration.") + s"--executor-memory option or ${config.EXECUTOR_MEMORY.key} in Spark configuration.") } } val usableMemory = systemMemory - reservedMemory diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 301317a79dfcf..b1e311ada4599 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -130,7 +130,7 @@ private[spark] class MetricsSystem private ( private[spark] def buildRegistryName(source: Source): String = { val metricsNamespace = conf.get(METRICS_NAMESPACE).orElse(conf.getOption("spark.app.id")) - val executorId = conf.getOption("spark.executor.id") + val executorId = conf.get(EXECUTOR_ID) val defaultName = MetricRegistry.name(source.sourceName) if (instance == "driver" || instance == "executor") { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 6bf60dd8e9dfa..41f032ccf82bf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -717,7 +717,7 @@ private[spark] class TaskSetManager( calculatedTasks += 1 if (maxResultSize > 0 && totalResultSize > maxResultSize) { val msg = s"Total size of serialized results of ${calculatedTasks} tasks " + - s"(${Utils.bytesToString(totalResultSize)}) is bigger than spark.driver.maxResultSize " + + s"(${Utils.bytesToString(totalResultSize)}) is bigger than ${config.MAX_RESULT_SIZE.key} " + s"(${Utils.bytesToString(maxResultSize)})" logError(msg) abort(msg) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index f73a58ff5d48c..adef20d3077d8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -25,7 +25,7 @@ import scala.concurrent.Future import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{StandaloneAppClient, StandaloneAppClientListener} -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler._ @@ -54,7 +54,7 @@ private[spark] class StandaloneSchedulerBackend( private val registrationBarrier = new Semaphore(0) - private val maxCores = conf.getOption("spark.cores.max").map(_.toInt) + private val maxCores = conf.get(config.CORES_MAX) private val totalExpectedCores = maxCores.getOrElse(0) override def start() { @@ -69,8 +69,8 @@ private[spark] class StandaloneSchedulerBackend( // The endpoint for executors to talk to us val driverUrl = RpcEndpointAddress( - sc.conf.get("spark.driver.host"), - sc.conf.get("spark.driver.port").toInt, + sc.conf.get(config.DRIVER_HOST_ADDRESS), + sc.conf.get(config.DRIVER_PORT), CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString val args = Seq( "--driver-url", driverUrl, @@ -79,11 +79,11 @@ private[spark] class StandaloneSchedulerBackend( "--cores", "{{CORES}}", "--app-id", "{{APP_ID}}", "--worker-url", "{{WORKER_URL}}") - val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions") + val extraJavaOpts = sc.conf.get(config.EXECUTOR_JAVA_OPTIONS) .map(Utils.splitCommandString).getOrElse(Seq.empty) - val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath") + val classPathEntries = sc.conf.get(config.EXECUTOR_CLASS_PATH) .map(_.split(java.io.File.pathSeparator).toSeq).getOrElse(Nil) - val libraryPathEntries = sc.conf.getOption("spark.executor.extraLibraryPath") + val libraryPathEntries = sc.conf.get(config.EXECUTOR_LIBRARY_PATH) .map(_.split(java.io.File.pathSeparator).toSeq).getOrElse(Nil) // When testing, expose the parent class path to the child. This is processed by @@ -102,7 +102,7 @@ private[spark] class StandaloneSchedulerBackend( val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts) val webUrl = sc.ui.map(_.webUrl).getOrElse("") - val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt) + val coresPerExecutor = conf.getOption(config.EXECUTOR_CORES.key).map(_.toInt) // If we're using dynamic allocation, set our initial executor limit to 0 for now. // ExecutorAllocationManager will send the real initial limit to the Master later. val initialExecutorLimit = diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala index 0de57fbd5600c..6ff8bf29b006a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala @@ -24,7 +24,7 @@ import java.nio.ByteBuffer import org.apache.spark.{SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler._ @@ -116,7 +116,7 @@ private[spark] class LocalSchedulerBackend( * @param conf Spark configuration. */ def getUserClasspath(conf: SparkConf): Seq[URL] = { - val userClassPathStr = conf.getOption("spark.executor.extraClassPath") + val userClassPathStr = conf.get(config.EXECUTOR_CLASS_PATH) userClassPathStr.map(_.split(File.pathSeparator)).toSeq.flatten.map(new File(_).toURI.toURL) } diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala index e5cccf39f9455..902e48fed3916 100644 --- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import org.apache.spark.SparkConf +import org.apache.spark.internal.config import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, RpcTimeout} private[spark] object RpcUtils { @@ -26,8 +27,8 @@ private[spark] object RpcUtils { * Retrieve a `RpcEndpointRef` which is located in the driver via its name. */ def makeDriverRef(name: String, conf: SparkConf, rpcEnv: RpcEnv): RpcEndpointRef = { - val driverHost: String = conf.get("spark.driver.host", "localhost") - val driverPort: Int = conf.getInt("spark.driver.port", 7077) + val driverHost: String = conf.get(config.DRIVER_HOST_ADDRESS.key, "localhost") + val driverPort: Int = conf.getInt(config.DRIVER_PORT.key, 7077) Utils.checkHost(driverHost) rpcEnv.setupEndpointRef(RpcAddress(driverHost, driverPort), name) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 22f074cf98971..3527fee68939d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2231,7 +2231,7 @@ private[spark] object Utils extends Logging { s"${e.getMessage}: Service$serviceString failed after " + s"$maxRetries retries (on a random free port)! " + s"Consider explicitly setting the appropriate binding address for " + - s"the service$serviceString (for example spark.driver.bindAddress " + + s"the service$serviceString (for example ${DRIVER_BIND_ADDRESS.key} " + s"for SparkDriver) to the correct binding address." } else { s"${e.getMessage}: Service$serviceString failed after " + diff --git a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala index 2f9ad4c8cc3e1..3188e0bd2b70d 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala @@ -20,7 +20,7 @@ package org.apache.spark.util.logging import java.io.{File, FileOutputStream, InputStream, IOException} import org.apache.spark.SparkConf -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.util.{IntParam, Utils} /** @@ -115,11 +115,9 @@ private[spark] object FileAppender extends Logging { /** Create the right appender based on Spark configuration */ def apply(inputStream: InputStream, file: File, conf: SparkConf): FileAppender = { - import RollingFileAppender._ - - val rollingStrategy = conf.get(STRATEGY_PROPERTY, STRATEGY_DEFAULT) - val rollingSizeBytes = conf.get(SIZE_PROPERTY, STRATEGY_DEFAULT) - val rollingInterval = conf.get(INTERVAL_PROPERTY, INTERVAL_DEFAULT) + val rollingStrategy = conf.get(config.EXECUTOR_LOGS_ROLLING_STRATEGY) + val rollingSizeBytes = conf.get(config.EXECUTOR_LOGS_ROLLING_MAX_SIZE) + val rollingInterval = conf.get(config.EXECUTOR_LOGS_ROLLING_TIME_INTERVAL) def createTimeBasedAppender(): FileAppender = { val validatedParams: Option[(Long, String)] = rollingInterval match { diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala index 5d8cec8447b53..59439b68792e5 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala @@ -24,6 +24,7 @@ import com.google.common.io.Files import org.apache.commons.io.IOUtils import org.apache.spark.SparkConf +import org.apache.spark.internal.config /** * Continuously appends data from input stream into the given file, and rolls @@ -44,10 +45,8 @@ private[spark] class RollingFileAppender( bufferSize: Int = RollingFileAppender.DEFAULT_BUFFER_SIZE ) extends FileAppender(inputStream, activeFile, bufferSize) { - import RollingFileAppender._ - - private val maxRetainedFiles = conf.getInt(RETAINED_FILES_PROPERTY, -1) - private val enableCompression = conf.getBoolean(ENABLE_COMPRESSION, false) + private val maxRetainedFiles = conf.get(config.EXECUTOR_LOGS_ROLLING_MAX_RETAINED_FILES) + private val enableCompression = conf.get(config.EXECUTOR_LOGS_ROLLING_ENABLE_COMPRESSION) /** Stop the appender */ override def stop() { @@ -82,7 +81,7 @@ private[spark] class RollingFileAppender( // Roll the log file and compress if enableCompression is true. private def rotateFile(activeFile: File, rolloverFile: File): Unit = { if (enableCompression) { - val gzFile = new File(rolloverFile.getAbsolutePath + GZIP_LOG_SUFFIX) + val gzFile = new File(rolloverFile.getAbsolutePath + RollingFileAppender.GZIP_LOG_SUFFIX) var gzOutputStream: GZIPOutputStream = null var inputStream: InputStream = null try { @@ -103,7 +102,7 @@ private[spark] class RollingFileAppender( // Check if the rollover file already exists. private def rolloverFileExist(file: File): Boolean = { - file.exists || new File(file.getAbsolutePath + GZIP_LOG_SUFFIX).exists + file.exists || new File(file.getAbsolutePath + RollingFileAppender.GZIP_LOG_SUFFIX).exists } /** Move the active log file to a new rollover file */ @@ -164,15 +163,7 @@ private[spark] class RollingFileAppender( * names of configurations that configure rolling file appenders. */ private[spark] object RollingFileAppender { - val STRATEGY_PROPERTY = "spark.executor.logs.rolling.strategy" - val STRATEGY_DEFAULT = "" - val INTERVAL_PROPERTY = "spark.executor.logs.rolling.time.interval" - val INTERVAL_DEFAULT = "daily" - val SIZE_PROPERTY = "spark.executor.logs.rolling.maxSize" - val SIZE_DEFAULT = (1024 * 1024).toString - val RETAINED_FILES_PROPERTY = "spark.executor.logs.rolling.maxRetainedFiles" val DEFAULT_BUFFER_SIZE = 8192 - val ENABLE_COMPRESSION = "spark.executor.logs.rolling.enableCompression" val GZIP_LOG_SUFFIX = ".gz" diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 5c718cb654ce8..d0389235cb724 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -155,7 +155,7 @@ class ExecutorAllocationManagerSuite .set("spark.dynamicAllocation.maxExecutors", "15") .set("spark.dynamicAllocation.minExecutors", "3") .set("spark.dynamicAllocation.executorAllocationRatio", divisor.toString) - .set("spark.executor.cores", cores.toString) + .set(config.EXECUTOR_CORES, cores) val sc = new SparkContext(conf) contexts += sc var manager = sc.executorAllocationManager.get diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index e14a5dcb5ef84..9a6abbdb0a46f 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -140,7 +140,7 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst test("creating SparkContext with cpus per tasks bigger than cores per executors") { val conf = new SparkConf(false) - .set("spark.executor.cores", "1") + .set(EXECUTOR_CORES, 1) .set("spark.task.cpus", "2") intercept[SparkException] { sc = new SparkContext(conf) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index a1d2a1283db14..8567dd1f08233 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -243,7 +243,7 @@ class StandaloneDynamicAllocationSuite } test("dynamic allocation with cores per executor") { - sc = new SparkContext(appConf.set("spark.executor.cores", "2")) + sc = new SparkContext(appConf.set(config.EXECUTOR_CORES, 2)) val appId = sc.applicationId eventually(timeout(10.seconds), interval(10.millis)) { val apps = getApplications() @@ -296,7 +296,7 @@ class StandaloneDynamicAllocationSuite test("dynamic allocation with cores per executor AND max cores") { sc = new SparkContext(appConf - .set("spark.executor.cores", "2") + .set(config.EXECUTOR_CORES, 2) .set("spark.cores.max", "8")) val appId = sc.applicationId eventually(timeout(10.seconds), interval(10.millis)) { @@ -526,7 +526,7 @@ class StandaloneDynamicAllocationSuite new SparkConf() .setMaster(masterRpcEnv.address.toSparkURL) .setAppName("test") - .set("spark.executor.memory", "256m") + .set(config.EXECUTOR_MEMORY.key, "256m") } /** Make a master to which our application will send executor requests. */ diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index d56cfc183d921..5ce3453b682fe 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -248,7 +248,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes val mm = UnifiedMemoryManager(conf, numCores = 1) // Try using an executor memory that's too small - val conf2 = conf.clone().set("spark.executor.memory", (reservedMemory / 2).toString) + val conf2 = conf.clone().set(EXECUTOR_MEMORY.key, (reservedMemory / 2).toString) val exception = intercept[IllegalArgumentException] { UnifiedMemoryManager(conf2, numCores = 1) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index d264adaef90a5..f73ff67837c6d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -655,7 +655,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg } test("abort the job if total size of results is too large") { - val conf = new SparkConf().set("spark.driver.maxResultSize", "2m") + val conf = new SparkConf().set(config.MAX_RESULT_SIZE.key, "2m") sc = new SparkContext("local", "test", conf) def genBytes(size: Int): (Int) => Array[Byte] = { (x: Int) => diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 3962bdc27d22c..19116cf22d2f8 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -31,7 +31,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.MEMORY_OFFHEAP_SIZE +import org.apache.spark.internal.config.{DRIVER_PORT, MEMORY_OFFHEAP_SIZE} import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService @@ -86,7 +86,7 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) conf.set("spark.authenticate", "false") - conf.set("spark.driver.port", rpcEnv.address.port.toString) + conf.set(DRIVER_PORT, rpcEnv.address.port) conf.set("spark.testing", "true") conf.set("spark.memory.fraction", "1") conf.set("spark.memory.storageFraction", "1") diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index cf00c1c3aad39..e866342e4472c 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -124,7 +124,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE .set("spark.storage.unrollMemoryThreshold", "512") rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) - conf.set("spark.driver.port", rpcEnv.address.port.toString) + conf.set(DRIVER_PORT, rpcEnv.address.port) // Mock SparkContext to reduce the memory usage of tests. It's fine since the only reason we // need to create a SparkContext is to initialize LiveListenerBus. diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 52cd5378bc715..242163931f7ac 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -34,7 +34,7 @@ import org.mockito.Mockito.{atLeast, mock, verify} import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.util.logging.{FileAppender, RollingFileAppender, SizeBasedRollingPolicy, TimeBasedRollingPolicy} class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { @@ -136,7 +136,7 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { // setup input stream and appender val testOutputStream = new PipedOutputStream() val testInputStream = new PipedInputStream(testOutputStream, 100 * 1000) - val conf = new SparkConf().set(RollingFileAppender.RETAINED_FILES_PROPERTY, "10") + val conf = new SparkConf().set(config.EXECUTOR_LOGS_ROLLING_MAX_RETAINED_FILES, 10) val appender = new RollingFileAppender(testInputStream, testFile, new SizeBasedRollingPolicy(1000, false), conf, 10) @@ -200,13 +200,12 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { appender.awaitTermination() } - import RollingFileAppender._ - def rollingStrategy(strategy: String): Seq[(String, String)] = - Seq(STRATEGY_PROPERTY -> strategy) - def rollingSize(size: String): Seq[(String, String)] = Seq(SIZE_PROPERTY -> size) + Seq(config.EXECUTOR_LOGS_ROLLING_STRATEGY.key -> strategy) + def rollingSize(size: String): Seq[(String, String)] = + Seq(config.EXECUTOR_LOGS_ROLLING_MAX_SIZE.key -> size) def rollingInterval(interval: String): Seq[(String, String)] = - Seq(INTERVAL_PROPERTY -> interval) + Seq(config.EXECUTOR_LOGS_ROLLING_TIME_INTERVAL.key -> interval) val msInDay = 24 * 60 * 60 * 1000L val msInHour = 60 * 60 * 1000L diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala index 8362c14fb289d..d52988df58d66 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala @@ -42,7 +42,7 @@ private[spark] class BasicDriverFeatureStep(conf: KubernetesDriverConf) .getOrElse(throw new SparkException("Must specify the driver container image")) // CPU settings - private val driverCpuCores = conf.get("spark.driver.cores", "1") + private val driverCpuCores = conf.get(DRIVER_CORES.key, "1") private val driverLimitCores = conf.get(KUBERNETES_DRIVER_LIMIT_CORES) // Memory settings @@ -85,7 +85,7 @@ private[spark] class BasicDriverFeatureStep(conf: KubernetesDriverConf) ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) } - val driverPort = conf.sparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT) + val driverPort = conf.sparkConf.getInt(DRIVER_PORT.key, DEFAULT_DRIVER_PORT) val driverBlockManagerPort = conf.sparkConf.getInt( DRIVER_BLOCK_MANAGER_PORT.key, DEFAULT_BLOCKMANAGER_PORT diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index dd73a5e52281c..6c3a6b39fa5cb 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -46,8 +46,8 @@ private[spark] class BasicExecutorFeatureStep( private val executorPodNamePrefix = kubernetesConf.resourceNamePrefix private val driverUrl = RpcEndpointAddress( - kubernetesConf.get("spark.driver.host"), - kubernetesConf.sparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT), + kubernetesConf.get(DRIVER_HOST_ADDRESS), + kubernetesConf.sparkConf.getInt(DRIVER_PORT.key, DEFAULT_DRIVER_PORT), CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString private val executorMemoryMiB = kubernetesConf.get(EXECUTOR_MEMORY) private val executorMemoryString = kubernetesConf.get( @@ -67,7 +67,7 @@ private[spark] class BasicExecutorFeatureStep( executorMemoryWithOverhead } - private val executorCores = kubernetesConf.sparkConf.getInt("spark.executor.cores", 1) + private val executorCores = kubernetesConf.sparkConf.get(EXECUTOR_CORES) private val executorCoresRequest = if (kubernetesConf.sparkConf.contains(KUBERNETES_EXECUTOR_REQUEST_CORES)) { kubernetesConf.get(KUBERNETES_EXECUTOR_REQUEST_CORES).get diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala index 42305457f4fff..15671179b18b3 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala @@ -22,7 +22,7 @@ import io.fabric8.kubernetes.api.model.{HasMetadata, ServiceBuilder} import org.apache.spark.deploy.k8s.{KubernetesDriverConf, SparkPod} import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.util.{Clock, SystemClock} private[spark] class DriverServiceFeatureStep( @@ -51,18 +51,17 @@ private[spark] class DriverServiceFeatureStep( } private val driverPort = kubernetesConf.sparkConf.getInt( - "spark.driver.port", DEFAULT_DRIVER_PORT) + config.DRIVER_PORT.key, DEFAULT_DRIVER_PORT) private val driverBlockManagerPort = kubernetesConf.sparkConf.getInt( - org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key, DEFAULT_BLOCKMANAGER_PORT) + config.DRIVER_BLOCK_MANAGER_PORT.key, DEFAULT_BLOCKMANAGER_PORT) override def configurePod(pod: SparkPod): SparkPod = pod override def getAdditionalPodSystemProperties(): Map[String, String] = { val driverHostname = s"$resolvedServiceName.${kubernetesConf.namespace}.svc" Map(DRIVER_HOST_KEY -> driverHostname, - "spark.driver.port" -> driverPort.toString, - org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key -> - driverBlockManagerPort.toString) + config.DRIVER_PORT.key -> driverPort.toString, + config.DRIVER_BLOCK_MANAGER_PORT.key -> driverBlockManagerPort.toString) } override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { @@ -90,8 +89,8 @@ private[spark] class DriverServiceFeatureStep( } private[spark] object DriverServiceFeatureStep { - val DRIVER_BIND_ADDRESS_KEY = org.apache.spark.internal.config.DRIVER_BIND_ADDRESS.key - val DRIVER_HOST_KEY = org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key + val DRIVER_BIND_ADDRESS_KEY = config.DRIVER_BIND_ADDRESS.key + val DRIVER_HOST_KEY = config.DRIVER_HOST_ADDRESS.key val DRIVER_SVC_POSTFIX = "-driver-svc" val MAX_SERVICE_NAME_LENGTH = 63 } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala index 5ceb9d6d6fcd0..27d59dd7f3e5b 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -46,7 +46,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { test("Check the pod respects all configurations from the user.") { val sparkConf = new SparkConf() .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod") - .set("spark.driver.cores", "2") + .set(DRIVER_CORES, 2) .set(KUBERNETES_DRIVER_LIMIT_CORES, "4") .set(DRIVER_MEMORY.key, "256M") .set(DRIVER_MEMORY_OVERHEAD, 200L) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index e28c650a571ed..36bfb7d41ec39 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, KubernetesTestConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.internal.config._ +import org.apache.spark.internal.config import org.apache.spark.internal.config.Python._ import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend @@ -74,8 +74,8 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, RESOURCE_NAME_PREFIX) .set(CONTAINER_IMAGE, EXECUTOR_IMAGE) .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) - .set(DRIVER_HOST_ADDRESS, DRIVER_HOSTNAME) - .set("spark.driver.port", DRIVER_PORT.toString) + .set(config.DRIVER_HOST_ADDRESS, DRIVER_HOSTNAME) + .set(config.DRIVER_PORT, DRIVER_PORT) .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS) .set("spark.kubernetes.resource.type", "java") } @@ -125,8 +125,8 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { } test("classpath and extra java options get translated into environment variables") { - baseConf.set(EXECUTOR_JAVA_OPTIONS, "foo=bar") - baseConf.set(EXECUTOR_CLASS_PATH, "bar=baz") + baseConf.set(config.EXECUTOR_JAVA_OPTIONS, "foo=bar") + baseConf.set(config.EXECUTOR_CLASS_PATH, "bar=baz") val kconf = newExecutorConf(environment = Map("qux" -> "quux")) val step = new BasicExecutorFeatureStep(kconf, new SecurityManager(baseConf)) val executor = step.configurePod(SparkPod.initialPod()) @@ -150,7 +150,7 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { test("auth secret propagation") { val conf = baseConf.clone() - .set(NETWORK_AUTH_ENABLED, true) + .set(config.NETWORK_AUTH_ENABLED, true) .set("spark.master", "k8s://127.0.0.1") val secMgr = new SecurityManager(conf) @@ -168,8 +168,8 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { val secretFile = new File(secretDir, "secret-file.txt") Files.write(secretFile.toPath, "some-secret".getBytes(StandardCharsets.UTF_8)) val conf = baseConf.clone() - .set(NETWORK_AUTH_ENABLED, true) - .set(AUTH_SECRET_FILE, secretFile.getAbsolutePath) + .set(config.NETWORK_AUTH_ENABLED, true) + .set(config.AUTH_SECRET_FILE, secretFile.getAbsolutePath) .set("spark.master", "k8s://127.0.0.1") val secMgr = new SecurityManager(conf) secMgr.initializeAuth() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala index 045278939dfff..822f1e32968c2 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala @@ -39,7 +39,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite { test("Headless service has a port for the driver RPC and the block manager.") { val sparkConf = new SparkConf(false) - .set("spark.driver.port", "9000") + .set(DRIVER_PORT, 9000) .set(DRIVER_BLOCK_MANAGER_PORT, 8080) val kconf = KubernetesTestConf.createDriverConf( sparkConf = sparkConf, @@ -61,7 +61,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite { test("Hostname and ports are set according to the service name.") { val sparkConf = new SparkConf(false) - .set("spark.driver.port", "9000") + .set(DRIVER_PORT, 9000) .set(DRIVER_BLOCK_MANAGER_PORT, 8080) .set(KUBERNETES_NAMESPACE, "my-namespace") val kconf = KubernetesTestConf.createDriverConf( @@ -87,7 +87,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite { s"${kconf.resourceNamePrefix}${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}", resolvedService) val additionalProps = configurationStep.getAdditionalPodSystemProperties() - assert(additionalProps("spark.driver.port") === DEFAULT_DRIVER_PORT.toString) + assert(additionalProps(DRIVER_PORT.key) === DEFAULT_DRIVER_PORT.toString) assert(additionalProps(DRIVER_BLOCK_MANAGER_PORT.key) === DEFAULT_BLOCKMANAGER_PORT.toString) } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala index d134847dc74d2..dd0b2bad1ecb2 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala @@ -129,4 +129,7 @@ package object config { "when launching drivers. Default is to accept all offers with sufficient resources.") .stringConf .createWithDefault("") + + private[spark] val EXECUTOR_URI = + ConfigBuilder("spark.executor.uri").stringConf.createOptional } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index 68f6921153d89..a4aba3e9c0d05 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -27,6 +27,7 @@ import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} import org.apache.spark.deploy.Command import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.deploy.rest._ +import org.apache.spark.internal.config import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler import org.apache.spark.util.Utils @@ -92,12 +93,12 @@ private[mesos] class MesosSubmitRequestServlet( // Optional fields val sparkProperties = request.sparkProperties - val driverExtraJavaOptions = sparkProperties.get("spark.driver.extraJavaOptions") - val driverExtraClassPath = sparkProperties.get("spark.driver.extraClassPath") - val driverExtraLibraryPath = sparkProperties.get("spark.driver.extraLibraryPath") - val superviseDriver = sparkProperties.get("spark.driver.supervise") - val driverMemory = sparkProperties.get("spark.driver.memory") - val driverCores = sparkProperties.get("spark.driver.cores") + val driverExtraJavaOptions = sparkProperties.get(config.DRIVER_JAVA_OPTIONS.key) + val driverExtraClassPath = sparkProperties.get(config.DRIVER_CLASS_PATH.key) + val driverExtraLibraryPath = sparkProperties.get(config.DRIVER_LIBRARY_PATH.key) + val superviseDriver = sparkProperties.get(config.DRIVER_SUPERVISE.key) + val driverMemory = sparkProperties.get(config.DRIVER_MEMORY.key) + val driverCores = sparkProperties.get(config.DRIVER_CORES.key) val name = request.sparkProperties.getOrElse("spark.app.name", mainClass) // Construct driver description diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index cb1bcba651be6..021b1ac84805e 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -32,6 +32,7 @@ import org.apache.mesos.Protos.TaskStatus.Reason import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} import org.apache.spark.deploy.mesos.{config, MesosDriverDescription} import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} +import org.apache.spark.internal.config.{CORES_MAX, EXECUTOR_LIBRARY_PATH, EXECUTOR_MEMORY} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.Utils @@ -365,8 +366,7 @@ private[spark] class MesosClusterScheduler( } private def getDriverExecutorURI(desc: MesosDriverDescription): Option[String] = { - desc.conf.getOption("spark.executor.uri") - .orElse(desc.command.environment.get("SPARK_EXECUTOR_URI")) + desc.conf.get(config.EXECUTOR_URI).orElse(desc.command.environment.get("SPARK_EXECUTOR_URI")) } private def getDriverFrameworkID(desc: MesosDriverDescription): String = { @@ -474,7 +474,7 @@ private[spark] class MesosClusterScheduler( } else if (executorUri.isDefined) { val folderBasename = executorUri.get.split('/').last.split('.').head - val entries = conf.getOption("spark.executor.extraLibraryPath") + val entries = conf.get(EXECUTOR_LIBRARY_PATH) .map(path => Seq(path) ++ desc.command.libraryPathEntries) .getOrElse(desc.command.libraryPathEntries) @@ -528,10 +528,10 @@ private[spark] class MesosClusterScheduler( options ++= Seq("--class", desc.command.mainClass) } - desc.conf.getOption("spark.executor.memory").foreach { v => + desc.conf.getOption(EXECUTOR_MEMORY.key).foreach { v => options ++= Seq("--executor-memory", v) } - desc.conf.getOption("spark.cores.max").foreach { v => + desc.conf.getOption(CORES_MAX.key).foreach { v => options ++= Seq("--total-executor-cores", v) } desc.conf.getOption("spark.submit.pyFiles").foreach { pyFiles => diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index f5866651dc90b..d0174516c2361 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -33,7 +33,6 @@ import org.apache.spark.{SecurityManager, SparkConf, SparkContext, SparkExceptio import org.apache.spark.deploy.mesos.config._ import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.config -import org.apache.spark.internal.config.EXECUTOR_HEARTBEAT_INTERVAL import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient @@ -63,9 +62,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // Blacklist a slave after this many failures private val MAX_SLAVE_FAILURES = 2 - private val maxCoresOption = conf.getOption("spark.cores.max").map(_.toInt) + private val maxCoresOption = conf.get(config.CORES_MAX) - private val executorCoresOption = conf.getOption("spark.executor.cores").map(_.toInt) + private val executorCoresOption = conf.getOption(config.EXECUTOR_CORES.key).map(_.toInt) private val minCoresPerExecutor = executorCoresOption.getOrElse(1) @@ -220,18 +219,18 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( def createCommand(offer: Offer, numCores: Int, taskId: String): CommandInfo = { val environment = Environment.newBuilder() - val extraClassPath = conf.getOption("spark.executor.extraClassPath") + val extraClassPath = conf.get(config.EXECUTOR_CLASS_PATH) extraClassPath.foreach { cp => environment.addVariables( Environment.Variable.newBuilder().setName("SPARK_EXECUTOR_CLASSPATH").setValue(cp).build()) } - val extraJavaOpts = conf.getOption("spark.executor.extraJavaOptions").map { + val extraJavaOpts = conf.get(config.EXECUTOR_JAVA_OPTIONS).map { Utils.substituteAppNExecIds(_, appId, taskId) }.getOrElse("") // Set the environment variable through a command prefix // to append to the existing value of the variable - val prefixEnv = conf.getOption("spark.executor.extraLibraryPath").map { p => + val prefixEnv = conf.get(config.EXECUTOR_LIBRARY_PATH).map { p => Utils.libraryPathEnvPrefix(Seq(p)) }.getOrElse("") @@ -261,8 +260,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( val command = CommandInfo.newBuilder() .setEnvironment(environment) - val uri = conf.getOption("spark.executor.uri") - .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) + val uri = conf.get(EXECUTOR_URI).orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) if (uri.isEmpty) { val executorSparkHome = conf.getOption("spark.mesos.executor.home") @@ -304,8 +302,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( "driverURL" } else { RpcEndpointAddress( - conf.get("spark.driver.host"), - conf.get("spark.driver.port").toInt, + conf.get(config.DRIVER_HOST_ADDRESS), + conf.get(config.DRIVER_PORT), CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString } } @@ -633,7 +631,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( externalShufflePort, sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", s"${sc.conf.getTimeAsSeconds("spark.network.timeout", "120s")}s"), - sc.conf.get(EXECUTOR_HEARTBEAT_INTERVAL)) + sc.conf.get(config.EXECUTOR_HEARTBEAT_INTERVAL)) slave.shuffleRegistered = true } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index 0bb6fe0fa4bdf..192f9407a1ba4 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -28,8 +28,9 @@ import org.apache.mesos.SchedulerDriver import org.apache.mesos.protobuf.ByteString import org.apache.spark.{SparkContext, SparkException, TaskState} -import org.apache.spark.deploy.mesos.config +import org.apache.spark.deploy.mesos.config.EXECUTOR_URI import org.apache.spark.executor.MesosExecutorBackend +import org.apache.spark.internal.config import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils @@ -107,15 +108,15 @@ private[spark] class MesosFineGrainedSchedulerBackend( throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") } val environment = Environment.newBuilder() - sc.conf.getOption("spark.executor.extraClassPath").foreach { cp => + sc.conf.get(config.EXECUTOR_CLASS_PATH).foreach { cp => environment.addVariables( Environment.Variable.newBuilder().setName("SPARK_EXECUTOR_CLASSPATH").setValue(cp).build()) } - val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").map { + val extraJavaOpts = sc.conf.get(config.EXECUTOR_JAVA_OPTIONS).map { Utils.substituteAppNExecIds(_, appId, execId) }.getOrElse("") - val prefixEnv = sc.conf.getOption("spark.executor.extraLibraryPath").map { p => + val prefixEnv = sc.conf.get(config.EXECUTOR_LIBRARY_PATH).map { p => Utils.libraryPathEnvPrefix(Seq(p)) }.getOrElse("") @@ -132,8 +133,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val uri = sc.conf.getOption("spark.executor.uri") - .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) + val uri = sc.conf.get(EXECUTOR_URI).orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) val executorBackendName = classOf[MesosExecutorBackend].getName if (uri.isEmpty) { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index e46c4f970c4a3..8dbdac168f701 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -470,8 +470,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends rpcEnv = sc.env.rpcEnv val userConf = sc.getConf - val host = userConf.get("spark.driver.host") - val port = userConf.get("spark.driver.port").toInt + val host = userConf.get(DRIVER_HOST_ADDRESS) + val port = userConf.get(DRIVER_PORT) registerAM(host, port, userConf, sc.ui.map(_.webUrl)) val driverRef = rpcEnv.setupEndpointRef( @@ -505,7 +505,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends amCores, true) // The client-mode AM doesn't listen for incoming connections, so report an invalid port. - registerAM(hostname, -1, sparkConf, sparkConf.getOption("spark.driver.appUIAddress")) + registerAM(hostname, -1, sparkConf, sparkConf.get(DRIVER_APP_UI_ADDRESS)) // The driver should be up and listening, so unlike cluster mode, just try to connect to it // with no waiting or retrying. diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index b257d8fdd3b1a..7e9cd409daf36 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -224,16 +224,12 @@ package object config { /* Driver configuration. */ - private[spark] val DRIVER_CORES = ConfigBuilder("spark.driver.cores") - .intConf - .createWithDefault(1) + private[spark] val DRIVER_APP_UI_ADDRESS = ConfigBuilder("spark.driver.appUIAddress") + .stringConf + .createOptional /* Executor configuration. */ - private[spark] val EXECUTOR_CORES = ConfigBuilder("spark.executor.cores") - .intConf - .createWithDefault(1) - private[spark] val EXECUTOR_NODE_LABEL_EXPRESSION = ConfigBuilder("spark.yarn.executor.nodeLabelExpression") .doc("Node label expression for executors.") diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 9397a1e3de9ac..167eef19ed856 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.yarn.api.records.YarnApplicationState import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnAppReport} import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.launcher.SparkAppHandle import org.apache.spark.scheduler.TaskSchedulerImpl @@ -42,10 +42,10 @@ private[spark] class YarnClientSchedulerBackend( * This waits until the application is running. */ override def start() { - val driverHost = conf.get("spark.driver.host") - val driverPort = conf.get("spark.driver.port") + val driverHost = conf.get(config.DRIVER_HOST_ADDRESS) + val driverPort = conf.get(config.DRIVER_PORT) val hostport = driverHost + ":" + driverPort - sc.ui.foreach { ui => conf.set("spark.driver.appUIAddress", ui.webUrl) } + sc.ui.foreach { ui => conf.set(DRIVER_APP_UI_ADDRESS, ui.webUrl) } val argsArrayBuf = new ArrayBuffer[String]() argsArrayBuf += ("--arg", hostport) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ResourceRequestHelperSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ResourceRequestHelperSuite.scala index 8032213602c95..9e3cc6ec01dfd 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ResourceRequestHelperSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ResourceRequestHelperSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.Matchers import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.deploy.yarn.ResourceRequestTestHelper.ResourceInformation import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.internal.config.{DRIVER_MEMORY, EXECUTOR_MEMORY} +import org.apache.spark.internal.config.{DRIVER_CORES, DRIVER_MEMORY, EXECUTOR_CORES, EXECUTOR_MEMORY} class ResourceRequestHelperSuite extends SparkFunSuite with Matchers { From 36440e64476610ec4037fb14f50cf7f06495e384 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 4 Jan 2019 15:35:23 -0600 Subject: [PATCH 2374/2461] [SPARK-26306][TEST][BUILD] More memory to de-flake SorterSuite ## What changes were proposed in this pull request? Increase test memory to avoid OOM in TimSort-related tests. ## How was this patch tested? Existing tests. Closes #23425 from srowen/SPARK-26306. Authored-by: Sean Owen Signed-off-by: Sean Owen --- pom.xml | 4 ++-- project/SparkBuild.scala | 2 +- resource-managers/kubernetes/integration-tests/pom.xml | 2 +- sql/hive/pom.xml | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pom.xml b/pom.xml index a433659cd2002..d0c525adf355c 100644 --- a/pom.xml +++ b/pom.xml @@ -2115,7 +2115,7 @@ **/*Suite.java ${project.build.directory}/surefire-reports - -ea -Xmx3g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize} + -ea -Xmx4g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize} - -da -Xmx3g -XX:ReservedCodeCacheSize=${CodeCacheSize} + -da -Xmx4g -XX:ReservedCodeCacheSize=${CodeCacheSize} From 89cebf4932ff966cc876ba8a9ecd9d9c034fb071 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 4 Jan 2019 15:37:09 -0600 Subject: [PATCH 2375/2461] [SPARK-24421][CORE][FOLLOWUP] Use normal direct ByteBuffer allocation if Cleaner can't be set ## What changes were proposed in this pull request? In Java 9+ we can't use sun.misc.Cleaner by default anymore, and this was largely handled in https://github.com/apache/spark/pull/22993 However I think the change there left a significant problem. If a DirectByteBuffer is allocated using the reflective hack in Platform, now, we by default can't set a Cleaner. But I believe this means the memory isn't freed promptly or possibly at all. If a Cleaner can't be set, I think we need to use normal APIs to allocate the direct ByteBuffer. According to comments in the code, the downside is simply that the normal APIs will check and impose limits on how much off-heap memory can be allocated. Per the original review on https://github.com/apache/spark/pull/22993 this much seems fine, as either way in this case the user would have to add a JVM setting (increase max, or allow the reflective access). ## How was this patch tested? Existing tests. This resolved an OutOfMemoryError in Java 11 from TimSort tests without increasing test heap size. (See https://github.com/apache/spark/pull/23419#issuecomment-450772125 ) This suggests there is a problem and that this resolves it. Closes #23424 from srowen/SPARK-24421.2. Authored-by: Sean Owen Signed-off-by: Sean Owen --- .../org/apache/spark/unsafe/Platform.java | 31 +++++++++++++------ 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 076b693f81c88..1adf7abfc8a68 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -209,22 +209,33 @@ public static long reallocateMemory(long address, long oldSize, long newSize) { } /** - * Uses internal JDK APIs to allocate a DirectByteBuffer while ignoring the JVM's - * MaxDirectMemorySize limit (the default limit is too low and we do not want to require users - * to increase it). + * Allocate a DirectByteBuffer, potentially bypassing the JVM's MaxDirectMemorySize limit. */ public static ByteBuffer allocateDirectBuffer(int size) { try { - long memory = allocateMemory(size); - ByteBuffer buffer = (ByteBuffer) DBB_CONSTRUCTOR.newInstance(memory, size); - if (CLEANER_CREATE_METHOD != null) { + if (CLEANER_CREATE_METHOD == null) { + // Can't set a Cleaner (see comments on field), so need to allocate via normal Java APIs try { - DBB_CLEANER_FIELD.set(buffer, - CLEANER_CREATE_METHOD.invoke(null, buffer, (Runnable) () -> freeMemory(memory))); - } catch (IllegalAccessException | InvocationTargetException e) { - throw new IllegalStateException(e); + return ByteBuffer.allocateDirect(size); + } catch (OutOfMemoryError oome) { + // checkstyle.off: RegexpSinglelineJava + throw new OutOfMemoryError("Failed to allocate direct buffer (" + oome.getMessage() + + "); try increasing -XX:MaxDirectMemorySize=... to, for example, your heap size"); + // checkstyle.on: RegexpSinglelineJava } } + // Otherwise, use internal JDK APIs to allocate a DirectByteBuffer while ignoring the JVM's + // MaxDirectMemorySize limit (the default limit is too low and we do not want to + // require users to increase it). + long memory = allocateMemory(size); + ByteBuffer buffer = (ByteBuffer) DBB_CONSTRUCTOR.newInstance(memory, size); + try { + DBB_CLEANER_FIELD.set(buffer, + CLEANER_CREATE_METHOD.invoke(null, buffer, (Runnable) () -> freeMemory(memory))); + } catch (IllegalAccessException | InvocationTargetException e) { + freeMemory(memory); + throw new IllegalStateException(e); + } return buffer; } catch (Exception e) { throwException(e); From bccb8602d7bc78894689e9b2e5fe685763d32d23 Mon Sep 17 00:00:00 2001 From: shane knapp Date: Fri, 4 Jan 2019 18:27:26 -0800 Subject: [PATCH 2376/2461] [SPARK-26537][BUILD] change git-wip-us to gitbox ## What changes were proposed in this pull request? due to apache recently moving from git-wip-us.apache.org to gitbox.apache.org, we need to update the packaging scripts to point to the new repo location. this will also need to be backported to 2.4, 2.3, 2.1, 2.0 and 1.6. ## How was this patch tested? the build system will test this. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #23454 from shaneknapp/update-apache-repo. Authored-by: shane knapp Signed-off-by: Dongjoon Hyun --- dev/create-release/release-tag.sh | 2 +- dev/create-release/release-util.sh | 4 ++-- pom.xml | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) mode change 100644 => 100755 dev/create-release/release-util.sh diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh index 628bc0504c9c8..010082d960a29 100755 --- a/dev/create-release/release-tag.sh +++ b/dev/create-release/release-tag.sh @@ -61,7 +61,7 @@ done init_java init_maven_sbt -ASF_SPARK_REPO="git-wip-us.apache.org/repos/asf/spark.git" +ASF_SPARK_REPO="gitbox.apache.org/repos/asf/spark.git" rm -rf spark git clone "https://$ASF_USERNAME:$ASF_PASSWORD@$ASF_SPARK_REPO" -b $GIT_BRANCH diff --git a/dev/create-release/release-util.sh b/dev/create-release/release-util.sh old mode 100644 new mode 100755 index 7426b0d6ca08d..c925de9be52d4 --- a/dev/create-release/release-util.sh +++ b/dev/create-release/release-util.sh @@ -19,8 +19,8 @@ DRY_RUN=${DRY_RUN:-0} GPG="gpg --no-tty --batch" -ASF_REPO="https://git-wip-us.apache.org/repos/asf/spark.git" -ASF_REPO_WEBUI="https://git-wip-us.apache.org/repos/asf?p=spark.git" +ASF_REPO="https://gitbox.apache.org/repos/asf/spark.git" +ASF_REPO_WEBUI="https://gitbox.apache.org/repos/asf?p=spark.git" function error { echo "$*" diff --git a/pom.xml b/pom.xml index d0c525adf355c..40b0e328c0359 100644 --- a/pom.xml +++ b/pom.xml @@ -39,7 +39,7 @@ scm:git:git@github.com:apache/spark.git - scm:git:https://git-wip-us.apache.org/repos/asf/spark.git + scm:git:https://gitbox.apache.org/repos/asf/spark.git scm:git:git@github.com:apache/spark.git HEAD From e15a319ccd1125584c09c38ca90b252324df6998 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 4 Jan 2019 19:23:38 -0800 Subject: [PATCH 2377/2461] [SPARK-26536][BUILD][TEST] Upgrade Mockito to 2.23.4 ## What changes were proposed in this pull request? This PR upgrades Mockito from 1.10.19 to 2.23.4. The following changes are required. - Replace `org.mockito.Matchers` with `org.mockito.ArgumentMatchers` - Replace `anyObject` with `any` - Replace `getArgumentAt` with `getArgument` and add type annotation. - Use `isNull` matcher in case of `null` is invoked. ```scala saslHandler.channelInactive(null); - verify(handler).channelInactive(any(TransportClient.class)); + verify(handler).channelInactive(isNull()); ``` - Make and use `doReturn` wrapper to avoid [SI-4775](https://issues.scala-lang.org/browse/SI-4775) ```scala private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*) ``` ## How was this patch tested? Pass the Jenkins with the existing tests. Closes #23452 from dongjoon-hyun/SPARK-26536. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../spark/network/sasl/SparkSaslSuite.java | 4 ++-- .../ExternalShuffleBlockHandlerSuite.java | 4 +++- .../shuffle/OneForOneBlockFetcherSuite.java | 8 ++++---- .../unsafe/map/AbstractBytesToBytesMapSuite.java | 4 ++-- .../spark/ExecutorAllocationManagerSuite.scala | 2 +- .../apache/spark/HeartbeatReceiverSuite.scala | 13 ++++++------- .../org/apache/spark/MapOutputTrackerSuite.scala | 2 +- .../StandaloneDynamicAllocationSuite.scala | 2 +- .../deploy/history/ApplicationCacheSuite.scala | 2 +- .../deploy/history/FsHistoryProviderSuite.scala | 4 ++-- .../history/HistoryServerDiskManagerSuite.scala | 6 ++++-- .../spark/deploy/worker/DriverRunnerTest.scala | 2 +- .../apache/spark/deploy/worker/WorkerSuite.scala | 2 +- .../apache/spark/executor/ExecutorSuite.scala | 2 +- .../apache/spark/memory/MemoryManagerSuite.scala | 2 +- .../scala/org/apache/spark/rpc/RpcEnvSuite.scala | 2 +- .../spark/rpc/netty/NettyRpcHandlerSuite.scala | 2 +- .../spark/scheduler/BlacklistTrackerSuite.scala | 2 +- .../apache/spark/scheduler/MapStatusSuite.scala | 3 ++- .../scheduler/OutputCommitCoordinatorSuite.scala | 16 +++++++++------- .../spark/scheduler/TaskContextSuite.scala | 2 +- .../spark/scheduler/TaskResultGetterSuite.scala | 2 +- .../spark/scheduler/TaskSchedulerImplSuite.scala | 10 +++++----- .../spark/scheduler/TaskSetBlacklistSuite.scala | 2 +- .../spark/scheduler/TaskSetManagerSuite.scala | 4 ++-- .../spark/security/CryptoStreamUtilsSuite.scala | 2 +- .../sort/BypassMergeSortShuffleWriterSuite.scala | 2 +- .../sort/IndexShuffleBlockResolverSuite.scala | 2 +- .../shuffle/sort/SortShuffleManagerSuite.scala | 4 +++- .../apache/spark/storage/BlockManagerSuite.scala | 2 +- .../storage/PartiallyUnrolledIteratorSuite.scala | 4 ++-- .../ShuffleBlockFetcherIteratorSuite.scala | 7 +++++-- .../sql/kafka010/KafkaDelegationTokenTest.scala | 4 +++- .../kinesis/KinesisCheckpointerSuite.scala | 2 +- .../streaming/kinesis/KinesisReceiverSuite.scala | 5 ++--- .../launcher/SparkSubmitOptionParserSuite.java | 6 +++++- .../org/apache/spark/ml/PipelineSuite.scala | 2 +- pom.xml | 2 +- .../spark/repl/ExecutorClassLoaderSuite.scala | 2 +- .../spark/deploy/k8s/PodBuilderSuite.scala | 2 +- .../features/KubernetesFeaturesTestUtils.scala | 8 ++++---- .../spark/deploy/k8s/submit/ClientSuite.scala | 4 +++- .../cluster/k8s/ExecutorPodsAllocatorSuite.scala | 4 ++-- .../k8s/ExecutorPodsLifecycleManagerSuite.scala | 4 ++-- .../KubernetesClusterSchedulerBackendSuite.scala | 2 +- .../mesos/MesosClusterSchedulerSuite.scala | 7 ++++--- ...MesosCoarseGrainedSchedulerBackendSuite.scala | 9 ++++----- .../MesosFineGrainedSchedulerBackendSuite.scala | 12 ++++++------ .../spark/scheduler/cluster/mesos/Utils.scala | 11 ++++++----- .../apache/spark/deploy/yarn/ClientSuite.scala | 5 +++-- .../yarn/YarnShuffleServiceMetricsSuite.scala | 8 ++++---- .../continuous/EpochCoordinatorSuite.scala | 4 ++-- .../test/DataStreamReaderWriterSuite.scala | 2 +- .../streaming/ReceivedBlockTrackerSuite.scala | 2 +- .../ExecutorAllocationManagerSuite.scala | 4 ++-- .../streaming/util/WriteAheadLogSuite.scala | 4 ++-- 56 files changed, 131 insertions(+), 111 deletions(-) diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 6f15718bd8705..59adf9704cbf6 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -347,10 +347,10 @@ public void testRpcHandlerDelegate() throws Exception { verify(handler).getStreamManager(); saslHandler.channelInactive(null); - verify(handler).channelInactive(any(TransportClient.class)); + verify(handler).channelInactive(isNull()); saslHandler.exceptionCaught(null, null); - verify(handler).exceptionCaught(any(Throwable.class), any(TransportClient.class)); + verify(handler).exceptionCaught(isNull(), isNull()); } @Test diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index 7846b71d5a8b1..4cc9a16e1449f 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -27,7 +27,7 @@ import org.mockito.ArgumentCaptor; import static org.junit.Assert.*; -import static org.mockito.Matchers.any; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.*; import org.apache.spark.network.buffer.ManagedBuffer; @@ -79,6 +79,8 @@ public void testRegisterExecutor() { @SuppressWarnings("unchecked") @Test public void testOpenShuffleBlocks() { + when(client.getClientId()).thenReturn("app0"); + RpcResponseCallback callback = mock(RpcResponseCallback.class); ManagedBuffer block0Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[3])); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index dc947a619bf02..95460637db89d 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -28,10 +28,10 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyInt; -import static org.mockito.Matchers.anyLong; -import static org.mockito.Matchers.eq; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index e5fbafc23d957..ecfebf8f8287e 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -50,8 +50,8 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.mockito.Answers.RETURNS_SMART_NULLS; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyInt; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.when; diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index d0389235cb724..38f5e8c9f0ac8 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import scala.collection.mutable -import org.mockito.Matchers.{any, eq => meq} +import org.mockito.ArgumentMatchers.{any, eq => meq} import org.mockito.Mockito.{mock, never, verify, when} import org.scalatest.{BeforeAndAfter, PrivateMethodTester} diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index de479db5fbc0f..a69e589743ef9 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -23,8 +23,7 @@ import scala.collection.mutable import scala.concurrent.Future import scala.concurrent.duration._ -import org.mockito.Matchers -import org.mockito.Matchers._ +import org.mockito.ArgumentMatchers.{any, eq => meq} import org.mockito.Mockito.{mock, spy, verify, when} import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} @@ -151,7 +150,7 @@ class HeartbeatReceiverSuite heartbeatReceiverClock.advance(executorTimeout) heartbeatReceiverRef.askSync[Boolean](ExpireDeadHosts) // Only the second executor should be expired as a dead host - verify(scheduler).executorLost(Matchers.eq(executorId2), any()) + verify(scheduler).executorLost(meq(executorId2), any()) val trackedExecutors = getTrackedExecutors assert(trackedExecutors.size === 1) assert(trackedExecutors.contains(executorId1)) @@ -223,10 +222,10 @@ class HeartbeatReceiverSuite assert(!response.reregisterBlockManager) // Additionally verify that the scheduler callback is called with the correct parameters verify(scheduler).executorHeartbeatReceived( - Matchers.eq(executorId), - Matchers.eq(Array(1L -> metrics.accumulators())), - Matchers.eq(blockManagerId), - Matchers.eq(executorUpdates)) + meq(executorId), + meq(Array(1L -> metrics.accumulators())), + meq(blockManagerId), + meq(executorUpdates)) } } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 3e1a3d4f73069..c088da8fbf3ba 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import scala.collection.mutable.ArrayBuffer -import org.mockito.Matchers.any +import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ import org.apache.spark.LocalSparkContext._ diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 8567dd1f08233..8c3c38dbc7ea0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy import scala.collection.mutable import scala.concurrent.duration._ -import org.mockito.Matchers.any +import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.{mock, verify, when} import org.scalatest.{BeforeAndAfterAll, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ diff --git a/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala index 44f9c566a380d..0402d949e9042 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import com.codahale.metrics.Counter import org.eclipse.jetty.servlet.ServletContextHandler -import org.mockito.Matchers._ +import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 6d2e329094ae2..7d6efd95fbabe 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -34,7 +34,7 @@ import org.apache.hadoop.hdfs.{DFSInputStream, DistributedFileSystem} import org.apache.hadoop.security.AccessControlException import org.json4s.jackson.JsonMethods._ import org.mockito.ArgumentMatcher -import org.mockito.Matchers.{any, argThat} +import org.mockito.ArgumentMatchers.{any, argThat} import org.mockito.Mockito.{doThrow, mock, spy, verify, when} import org.scalatest.BeforeAndAfter import org.scalatest.Matchers @@ -933,7 +933,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val mockedFs = spy(provider.fs) doThrow(new AccessControlException("Cannot read accessDenied file")).when(mockedFs).open( argThat(new ArgumentMatcher[Path]() { - override def matches(path: Any): Boolean = { + override def matches(path: Path): Boolean = { path.asInstanceOf[Path].getName.toLowerCase(Locale.ROOT) == "accessdenied" } })) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerDiskManagerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerDiskManagerSuite.scala index 341a1e2443df0..f78469e132490 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerDiskManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerDiskManagerSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.deploy.history import java.io.File import org.mockito.AdditionalAnswers -import org.mockito.Matchers.{any, anyBoolean, anyLong, eq => meq} -import org.mockito.Mockito._ +import org.mockito.ArgumentMatchers.{anyBoolean, anyLong, eq => meq} +import org.mockito.Mockito.{doAnswer, spy} import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} @@ -32,6 +32,8 @@ import org.apache.spark.util.kvstore.KVStore class HistoryServerDiskManagerSuite extends SparkFunSuite with BeforeAndAfter { + private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*) + private val MAX_USAGE = 3L private var testDir: File = _ diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala index 52956045d5985..1deac43897f90 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala @@ -21,7 +21,7 @@ import java.io.File import scala.concurrent.duration._ -import org.mockito.Matchers._ +import org.mockito.ArgumentMatchers.{any, anyInt} import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index e3fe2b696aa1f..e5e5b5e428c49 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -22,7 +22,7 @@ import java.util.function.Supplier import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.Matchers._ +import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 32a94e60484e3..a5fe2026c0f77 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -30,7 +30,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import org.mockito.ArgumentCaptor -import org.mockito.Matchers.{any, eq => meq} +import org.mockito.ArgumentMatchers.{any, eq => meq} import org.mockito.Mockito.{inOrder, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index 85eeb5055ae03..8b35f1dfddb08 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration -import org.mockito.Matchers.{any, anyLong} +import org.mockito.ArgumentMatchers.{any, anyLong} import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 5cb2b561d6bce..558b7fa49832b 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -29,7 +29,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.io.Files -import org.mockito.Matchers.any +import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.{mock, never, verify, when} import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually._ diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index a71d8726e7066..4bc001fe8f7c5 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -21,7 +21,7 @@ import java.net.InetSocketAddress import java.nio.ByteBuffer import io.netty.channel.Channel -import org.mockito.Matchers._ +import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ import org.apache.spark.SparkFunSuite diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index 96c8404327e24..aea4c5f96bbe6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import org.mockito.Matchers.any +import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.{never, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 2155a0f2b6c21..f41ffb7f2c0b4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -21,7 +21,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, import scala.util.Random -import org.mockito.Mockito._ +import org.mockito.Mockito.mock import org.roaringbitmap.RoaringBitmap import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} @@ -31,6 +31,7 @@ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.storage.BlockManagerId class MapStatusSuite extends SparkFunSuite { + private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*) test("compressSize") { assert(MapStatus.compressSize(0L) === 0) diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 158c9eb75f2b6..a560013dba963 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -26,8 +26,8 @@ import scala.language.postfixOps import org.apache.hadoop.mapred._ import org.apache.hadoop.mapreduce.TaskType -import org.mockito.Matchers -import org.mockito.Mockito._ +import org.mockito.ArgumentMatchers.{any, eq => meq} +import org.mockito.Mockito.{doAnswer, spy, times, verify} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfter @@ -71,6 +71,8 @@ import org.apache.spark.util.{ThreadUtils, Utils} */ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { + private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*) + var outputCommitCoordinator: OutputCommitCoordinator = null var tempDir: File = null var sc: SparkContext = null @@ -103,7 +105,7 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { invoke.callRealMethod() mockTaskScheduler.backend.reviveOffers() } - }).when(mockTaskScheduler).submitTasks(Matchers.any()) + }).when(mockTaskScheduler).submitTasks(any()) doAnswer(new Answer[TaskSetManager]() { override def answer(invoke: InvocationOnMock): TaskSetManager = { @@ -123,7 +125,7 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { } } } - }).when(mockTaskScheduler).createTaskSetManager(Matchers.any(), Matchers.any()) + }).when(mockTaskScheduler).createTaskSetManager(any(), any()) sc.taskScheduler = mockTaskScheduler val dagSchedulerWithMockTaskScheduler = new DAGScheduler(sc, mockTaskScheduler) @@ -154,7 +156,7 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Job should not complete if all commits are denied") { // Create a mock OutputCommitCoordinator that denies all attempts to commit doReturn(false).when(outputCommitCoordinator).handleAskPermissionToCommit( - Matchers.any(), Matchers.any(), Matchers.any(), Matchers.any()) + any(), any(), any(), any()) val rdd: RDD[Int] = sc.parallelize(Seq(1), 1) def resultHandler(x: Int, y: Unit): Unit = {} val futureAction: SimpleFutureAction[Unit] = sc.submitJob[Int, Unit, Unit](rdd, @@ -268,8 +270,8 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { assert(retriedStage.size === 1) assert(sc.dagScheduler.outputCommitCoordinator.isEmpty) verify(sc.env.outputCommitCoordinator, times(2)) - .stageStart(Matchers.eq(retriedStage.head), Matchers.any()) - verify(sc.env.outputCommitCoordinator).stageEnd(Matchers.eq(retriedStage.head)) + .stageStart(meq(retriedStage.head), any()) + verify(sc.env.outputCommitCoordinator).stageEnd(meq(retriedStage.head)) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index aa9c36c0aaacb..3bfc97b80184c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import java.util.Properties -import org.mockito.Matchers.any +import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index efb8b15cf6b4d..ea1439cfebca2 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -28,7 +28,7 @@ import scala.util.control.NonFatal import com.google.common.util.concurrent.MoreExecutors import org.mockito.ArgumentCaptor -import org.mockito.Matchers.{any, anyLong} +import org.mockito.ArgumentMatchers.{any, anyLong} import org.mockito.Mockito.{spy, times, verify} import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 29172b4664e32..9c555a923d625 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap import scala.concurrent.duration._ -import org.mockito.Matchers.{anyInt, anyObject, anyString, eq => meq} +import org.mockito.ArgumentMatchers.{any, anyInt, anyString, eq => meq} import org.mockito.Mockito.{atLeast, atMost, never, spy, times, verify, when} import org.scalatest.BeforeAndAfterEach import org.scalatest.concurrent.Eventually @@ -430,7 +430,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B verify(blacklist, never).updateBlacklistForSuccessfulTaskSet( stageId = meq(2), stageAttemptId = anyInt(), - failuresByExec = anyObject()) + failuresByExec = any()) } test("scheduled tasks obey node and executor blacklists") { @@ -504,7 +504,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B WorkerOffer("executor3", "host1", 2) )).flatten.size === 0) assert(tsm.isZombie) - verify(tsm).abort(anyString(), anyObject()) + verify(tsm).abort(anyString(), any()) } test("SPARK-22148 abort timer should kick in when task is completely blacklisted & no new " + @@ -1184,7 +1184,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(finalTsm.isZombie) // no taskset has completed all of its tasks, so no updates to the blacklist tracker yet - verify(blacklist, never).updateBlacklistForSuccessfulTaskSet(anyInt(), anyInt(), anyObject()) + verify(blacklist, never).updateBlacklistForSuccessfulTaskSet(anyInt(), anyInt(), any()) // finally, lets complete all the tasks. We simulate failures in attempt 1, but everything // else succeeds, to make sure we get the right updates to the blacklist in all cases. @@ -1202,7 +1202,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B // we update the blacklist for the stage attempts with all successful tasks. Even though // some tasksets had failures, we still consider them all successful from a blacklisting // perspective, as the failures weren't from a problem w/ the tasks themselves. - verify(blacklist).updateBlacklistForSuccessfulTaskSet(meq(0), meq(stageAttempt), anyObject()) + verify(blacklist).updateBlacklistForSuccessfulTaskSet(meq(0), meq(stageAttempt), any()) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala index 6e2709dbe1e8b..b3bc76687ce1b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.scheduler -import org.mockito.Matchers.isA +import org.mockito.ArgumentMatchers.isA import org.mockito.Mockito.{never, verify} import org.scalatest.BeforeAndAfterEach import org.scalatest.mockito.MockitoSugar diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index f73ff67837c6d..f9dfd2c456c52 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -22,7 +22,7 @@ import java.util.{Properties, Random} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.mockito.Matchers.{any, anyInt, anyString} +import org.mockito.ArgumentMatchers.{any, anyInt, anyString} import org.mockito.Mockito.{mock, never, spy, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer @@ -1319,7 +1319,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg when(taskSetManagerSpy.addPendingTask(anyInt())).thenAnswer( new Answer[Unit] { override def answer(invocationOnMock: InvocationOnMock): Unit = { - val task = invocationOnMock.getArgumentAt(0, classOf[Int]) + val task: Int = invocationOnMock.getArgument(0) assert(taskSetManager.taskSetBlacklistHelperOpt.get. isExecutorBlacklistedForTask(exec, task)) } diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala index 0d3611c80b8d0..e5d1bf4fde9e4 100644 --- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala @@ -24,7 +24,7 @@ import java.nio.file.Files import java.util.{Arrays, Random, UUID} import com.google.common.io.ByteStreams -import org.mockito.Matchers.any +import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ import org.apache.spark._ diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 4467c3241a947..7f956c26d0ff0 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -25,7 +25,7 @@ import scala.collection.mutable.ArrayBuffer import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.Matchers._ +import org.mockito.ArgumentMatchers.{any, anyInt} import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala index 4ce379b76b551..0154d0b6ef6f9 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -21,7 +21,7 @@ import java.io.{DataInputStream, File, FileInputStream, FileOutputStream} import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.Matchers._ +import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala index f29dac965c803..e5f3aab6a6a1a 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle.sort -import org.mockito.Mockito._ +import org.mockito.Mockito.{mock, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.Matchers @@ -31,6 +31,8 @@ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer} */ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { + private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*) + import SortShuffleManager.canUseSerializedShuffle private class RuntimeExceptionAnswer extends Answer[Object] { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index e866342e4472c..a7bb2a03360aa 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -27,7 +27,7 @@ import scala.language.{implicitConversions, postfixOps} import scala.reflect.ClassTag import org.apache.commons.lang3.RandomUtils -import org.mockito.{Matchers => mc} +import org.mockito.{ArgumentMatchers => mc} import org.mockito.Mockito.{mock, times, verify, when} import org.scalatest._ import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala index cbc903f17ad75..56860b2e55709 100644 --- a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import org.mockito.Matchers +import org.mockito.ArgumentMatchers.{eq => meq} import org.mockito.Mockito._ import org.scalatest.mockito.MockitoSugar @@ -45,7 +45,7 @@ class PartiallyUnrolledIteratorSuite extends SparkFunSuite with MockitoSugar { joinIterator.hasNext joinIterator.hasNext verify(memoryStore, times(1)) - .releaseUnrollMemoryForThisTask(Matchers.eq(ON_HEAP), Matchers.eq(unrollSize.toLong)) + .releaseUnrollMemoryForThisTask(meq(ON_HEAP), meq(unrollSize.toLong)) // Secondly, iterate over rest iterator (unrollSize until unrollSize + restSize).foreach { value => diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 01ee9ef0825f8..6b83243fe496c 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -24,8 +24,8 @@ import java.util.concurrent.Semaphore import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.Future -import org.mockito.Matchers.{any, eq => meq} -import org.mockito.Mockito._ +import org.mockito.ArgumentMatchers.{any, eq => meq} +import org.mockito.Mockito.{mock, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.PrivateMethodTester @@ -40,6 +40,9 @@ import org.apache.spark.util.Utils class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester { + + private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*) + // Some of the tests are quite tricky because we are testing the cleanup behavior // in the presence of faults. diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDelegationTokenTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDelegationTokenTest.scala index 1899c65c721bb..31247ab219082 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDelegationTokenTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDelegationTokenTest.scala @@ -22,7 +22,7 @@ import javax.security.auth.login.{AppConfigurationEntry, Configuration} import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.Token -import org.mockito.Mockito.{doReturn, mock} +import org.mockito.Mockito.mock import org.scalatest.BeforeAndAfterEach import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite} @@ -35,6 +35,8 @@ import org.apache.spark.deploy.security.KafkaTokenUtil.KafkaDelegationTokenIdent trait KafkaDelegationTokenTest extends BeforeAndAfterEach { self: SparkFunSuite => + private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*) + protected val tokenId = "tokenId" + ju.UUID.randomUUID().toString protected val tokenPassword = "tokenPassword" + ju.UUID.randomUUID().toString diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala index e26f4477d1d7d..bd31b7dc49a64 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala @@ -24,7 +24,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer -import org.mockito.Matchers._ +import org.mockito.ArgumentMatchers._ import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 2fadda271ea28..7531a9cc400d9 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -24,9 +24,8 @@ import com.amazonaws.services.kinesis.clientlibrary.exceptions._ import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason import com.amazonaws.services.kinesis.model.Record -import org.mockito.Matchers._ -import org.mockito.Matchers.{eq => meq} -import org.mockito.Mockito._ +import org.mockito.ArgumentMatchers.{anyListOf, anyString, eq => meq} +import org.mockito.Mockito.{never, times, verify, when} import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.mockito.MockitoSugar diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java index 9ff7aceb581f4..4e26cf6c109c8 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java @@ -23,6 +23,7 @@ import org.junit.Before; import org.junit.Test; +import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.Mockito.*; public class SparkSubmitOptionParserSuite extends BaseSuite { @@ -48,14 +49,17 @@ public void testAllOptions() { } } + int nullCount = 0; for (String[] switchNames : parser.switches) { int switchCount = 0; for (String name : switchNames) { parser.parse(Arrays.asList(name)); count++; + nullCount++; switchCount++; verify(parser, times(switchCount)).handle(eq(switchNames[0]), same(null)); - verify(parser, times(count)).handle(anyString(), any(String.class)); + verify(parser, times(nullCount)).handle(anyString(), isNull()); + verify(parser, times(count - nullCount)).handle(anyString(), any(String.class)); verify(parser, times(count)).handleExtraArgs(eq(Collections.emptyList())); } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 7848eae931a06..1183cb0617610 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path -import org.mockito.Matchers.{any, eq => meq} +import org.mockito.ArgumentMatchers.{any, eq => meq} import org.mockito.Mockito.when import org.scalatest.mockito.MockitoSugar.mock diff --git a/pom.xml b/pom.xml index 40b0e328c0359..245344a826935 100644 --- a/pom.xml +++ b/pom.xml @@ -764,7 +764,7 @@ org.mockito mockito-core - 1.10.19 + 2.23.4 test diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index ac528ecb829b0..e9ed01ff22338 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -30,7 +30,7 @@ import scala.io.Source import scala.language.implicitConversions import com.google.common.io.Files -import org.mockito.Matchers.anyString +import org.mockito.ArgumentMatchers.anyString import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala index 7dde0c1377168..707c823d69cf0 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala @@ -21,7 +21,7 @@ import java.io.File import io.fabric8.kubernetes.api.model.{Config => _, _} import io.fabric8.kubernetes.client.KubernetesClient import io.fabric8.kubernetes.client.dsl.{MixedOperation, PodResource} -import org.mockito.Matchers.any +import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.{mock, never, verify, when} import scala.collection.JavaConverters._ diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala index 076b681be2397..95de7d9059540 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala @@ -20,8 +20,8 @@ import scala.collection.JavaConverters._ import scala.reflect.ClassTag import io.fabric8.kubernetes.api.model.{Container, HasMetadata, PodBuilder, SecretBuilder} -import org.mockito.Matchers -import org.mockito.Mockito._ +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito.{mock, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer @@ -37,10 +37,10 @@ object KubernetesFeaturesTestUtils { when(mockStep.getAdditionalPodSystemProperties()) .thenReturn(Map(stepType -> stepType)) - when(mockStep.configurePod(Matchers.any(classOf[SparkPod]))) + when(mockStep.configurePod(any(classOf[SparkPod]))) .thenAnswer(new Answer[SparkPod]() { override def answer(invocation: InvocationOnMock): SparkPod = { - val originalPod = invocation.getArgumentAt(0, classOf[SparkPod]) + val originalPod: SparkPod = invocation.getArgument(0) val configuredPod = new PodBuilder(originalPod.pod) .editOrNewMetadata() .addToLabels(stepType, stepType) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index 1bb926cbca23d..aa421be6e8412 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -20,7 +20,7 @@ import io.fabric8.kubernetes.api.model._ import io.fabric8.kubernetes.client.{KubernetesClient, Watch} import io.fabric8.kubernetes.client.dsl.PodResource import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} -import org.mockito.Mockito.{doReturn, verify, when} +import org.mockito.Mockito.{verify, when} import org.scalatest.BeforeAndAfter import org.scalatest.mockito.MockitoSugar._ @@ -31,6 +31,8 @@ import org.apache.spark.deploy.k8s.Fabric8Aliases._ class ClientSuite extends SparkFunSuite with BeforeAndAfter { + private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*) + private val DRIVER_POD_UID = "pod-id" private val DRIVER_POD_API_VERSION = "v1" private val DRIVER_POD_KIND = "pod" diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala index 278a3821a6f3d..55d9adc212f92 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala @@ -20,7 +20,7 @@ import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder} import io.fabric8.kubernetes.client.KubernetesClient import io.fabric8.kubernetes.client.dsl.PodResource import org.mockito.{ArgumentMatcher, Matchers, Mock, MockitoAnnotations} -import org.mockito.Matchers.{any, eq => meq} +import org.mockito.ArgumentMatchers.{any, eq => meq} import org.mockito.Mockito.{never, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer @@ -156,7 +156,7 @@ class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter { private def executorPodAnswer(): Answer[SparkPod] = { new Answer[SparkPod] { override def answer(invocation: InvocationOnMock): SparkPod = { - val k8sConf = invocation.getArgumentAt(0, classOf[KubernetesExecutorConf]) + val k8sConf: KubernetesExecutorConf = invocation.getArgument(0) executorPodWithId(k8sConf.executorId.toInt) } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala index 7411f8f9d69e9..b20ed4799e325 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala @@ -21,7 +21,7 @@ import io.fabric8.kubernetes.api.model.{DoneablePod, Pod} import io.fabric8.kubernetes.client.KubernetesClient import io.fabric8.kubernetes.client.dsl.PodResource import org.mockito.{Mock, MockitoAnnotations} -import org.mockito.Matchers.any +import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.{mock, never, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer @@ -128,7 +128,7 @@ class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfte private def namedPodsAnswer(): Answer[PodResource[Pod, DoneablePod]] = { new Answer[PodResource[Pod, DoneablePod]] { override def answer(invocation: InvocationOnMock): PodResource[Pod, DoneablePod] = { - val podName = invocation.getArgumentAt(0, classOf[String]) + val podName: String = invocation.getArgument(0) namedExecutorPods.getOrElseUpdate( podName, mock(classOf[PodResource[Pod, DoneablePod]])) } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index 6e182bed459f8..8ed934d91dd7e 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler.cluster.k8s import io.fabric8.kubernetes.client.KubernetesClient import org.jmock.lib.concurrent.DeterministicScheduler import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} -import org.mockito.Matchers.{eq => mockitoEq} +import org.mockito.ArgumentMatchers.{eq => mockitoEq} import org.mockito.Mockito.{never, verify, when} import org.scalatest.BeforeAndAfter diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala index 082d4bcfdf83a..7adac1964e010 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -24,7 +24,8 @@ import scala.collection.JavaConverters._ import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.Protos.Value.{Scalar, Type} import org.apache.mesos.SchedulerDriver -import org.mockito.{ArgumentCaptor, Matchers} +import org.mockito.ArgumentCaptor +import org.mockito.ArgumentMatchers.{eq => meq} import org.mockito.Mockito._ import org.scalatest.mockito.MockitoSugar @@ -133,7 +134,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi when( driver.launchTasks( - Matchers.eq(Collections.singleton(offer.getId)), + meq(Collections.singleton(offer.getId)), capture.capture()) ).thenReturn(Status.valueOf(1)) @@ -156,7 +157,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi assert(mem.exists(_.getRole() == "*")) verify(driver, times(1)).launchTasks( - Matchers.eq(Collections.singleton(offer.getId)), + meq(Collections.singleton(offer.getId)), capture.capture() ) } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index da33d85d8fb2e..0cfaa0a0c9a60 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -24,9 +24,8 @@ import scala.concurrent.duration._ import org.apache.mesos.{Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos._ -import org.mockito.Matchers -import org.mockito.Matchers._ -import org.mockito.Mockito._ +import org.mockito.ArgumentMatchers.{any, anyInt, anyLong, anyString, eq => meq} +import org.mockito.Mockito.{times, verify, when} import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.ScalaFutures import org.scalatest.mockito.MockitoSugar @@ -697,9 +696,9 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite offerId: OfferID, filter: Boolean = false): Unit = { if (filter) { - verify(driver, times(1)).declineOffer(Matchers.eq(offerId), anyObject[Filters]) + verify(driver, times(1)).declineOffer(meq(offerId), any[Filters]()) } else { - verify(driver, times(1)).declineOffer(Matchers.eq(offerId)) + verify(driver, times(1)).declineOffer(meq(offerId)) } } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala index 1ead4b1ed7c7e..c9b7e6c439c4b 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala @@ -30,8 +30,8 @@ import scala.collection.mutable.ArrayBuffer import org.apache.mesos.{Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos._ import org.apache.mesos.Protos.Value.Scalar -import org.mockito.{ArgumentCaptor, Matchers} -import org.mockito.Matchers._ +import org.mockito.ArgumentCaptor +import org.mockito.ArgumentMatchers.{any, anyLong, eq => meq} import org.mockito.Mockito._ import org.scalatest.mockito.MockitoSugar @@ -264,7 +264,7 @@ class MesosFineGrainedSchedulerBackendSuite val capture = ArgumentCaptor.forClass(classOf[Collection[TaskInfo]]) when( driver.launchTasks( - Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + meq(Collections.singleton(mesosOffers.get(0).getId)), capture.capture(), any(classOf[Filters]) ) @@ -275,7 +275,7 @@ class MesosFineGrainedSchedulerBackendSuite backend.resourceOffers(driver, mesosOffers) verify(driver, times(1)).launchTasks( - Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + meq(Collections.singleton(mesosOffers.get(0).getId)), capture.capture(), any(classOf[Filters]) ) @@ -373,7 +373,7 @@ class MesosFineGrainedSchedulerBackendSuite val capture = ArgumentCaptor.forClass(classOf[Collection[TaskInfo]]) when( driver.launchTasks( - Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + meq(Collections.singleton(mesosOffers.get(0).getId)), capture.capture(), any(classOf[Filters]) ) @@ -382,7 +382,7 @@ class MesosFineGrainedSchedulerBackendSuite backend.resourceOffers(driver, mesosOffers) verify(driver, times(1)).launchTasks( - Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + meq(Collections.singleton(mesosOffers.get(0).getId)), capture.capture(), any(classOf[Filters]) ) diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala index c9f47471cd75e..65e595e3cf2bf 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala @@ -25,8 +25,9 @@ import org.apache.mesos.Protos._ import org.apache.mesos.Protos.Value.{Range => MesosRange, Ranges, Scalar} import org.apache.mesos.SchedulerDriver import org.apache.mesos.protobuf.ByteString -import org.mockito.{ArgumentCaptor, Matchers} -import org.mockito.Mockito._ +import org.mockito.ArgumentCaptor +import org.mockito.ArgumentMatchers.{any, eq => meq} +import org.mockito.Mockito.{times, verify} import org.apache.spark.deploy.mesos.config.MesosSecretConfig @@ -84,15 +85,15 @@ object Utils { def verifyTaskLaunched(driver: SchedulerDriver, offerId: String): List[TaskInfo] = { val captor = ArgumentCaptor.forClass(classOf[java.util.Collection[TaskInfo]]) verify(driver, times(1)).launchTasks( - Matchers.eq(Collections.singleton(createOfferId(offerId))), + meq(Collections.singleton(createOfferId(offerId))), captor.capture()) captor.getValue.asScala.toList } def verifyTaskNotLaunched(driver: SchedulerDriver, offerId: String): Unit = { verify(driver, times(0)).launchTasks( - Matchers.eq(Collections.singleton(createOfferId(offerId))), - Matchers.any(classOf[java.util.Collection[TaskInfo]])) + meq(Collections.singleton(createOfferId(offerId))), + any(classOf[java.util.Collection[TaskInfo]])) } def createOfferId(offerId: String): OfferID = { diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index a6f57fcdb2461..9acd99546c036 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -34,8 +34,8 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.YarnClientApplication import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.util.Records -import org.mockito.Matchers.{eq => meq, _} -import org.mockito.Mockito._ +import org.mockito.ArgumentMatchers.{any, anyBoolean, anyShort, eq => meq} +import org.mockito.Mockito.{spy, verify} import org.scalatest.Matchers import org.apache.spark.{SparkConf, SparkFunSuite, TestUtils} @@ -43,6 +43,7 @@ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.util.{SparkConfWithEnv, Utils} class ClientSuite extends SparkFunSuite with Matchers { + private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*) import Client._ diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala index 952fd0b70bb7b..f538cbc5b7657 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.network.yarn import scala.collection.JavaConverters._ import org.apache.hadoop.metrics2.MetricsRecordBuilder -import org.mockito.Matchers._ +import org.mockito.ArgumentMatchers.{any, anyDouble, anyInt, anyLong} import org.mockito.Mockito.{mock, times, verify, when} import org.scalatest.Matchers @@ -56,8 +56,8 @@ class YarnShuffleServiceMetricsSuite extends SparkFunSuite with Matchers { YarnShuffleServiceMetrics.collectMetric(builder, testname, metrics.getMetrics.get(testname)) - verify(builder).addCounter(anyObject(), anyLong()) - verify(builder, times(4)).addGauge(anyObject(), anyDouble()) + verify(builder).addCounter(any(), anyLong()) + verify(builder, times(4)).addGauge(any(), anyDouble()) } } @@ -69,6 +69,6 @@ class YarnShuffleServiceMetricsSuite extends SparkFunSuite with Matchers { metrics.getMetrics.get("registeredExecutorsSize")) // only one - verify(builder).addGauge(anyObject(), anyInt()) + verify(builder).addGauge(any(), anyInt()) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala index 3c973d8ebc704..e644c16ddfeab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.streaming.continuous +import org.mockito.ArgumentMatchers.{any, eq => eqTo} import org.mockito.InOrder -import org.mockito.Matchers.{any, eq => eqTo} -import org.mockito.Mockito._ +import org.mockito.Mockito.{inOrder, never, verify} import org.scalatest.BeforeAndAfterEach import org.scalatest.mockito.MockitoSugar diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 8212fb912ec57..4d3a54a048e8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -24,7 +24,7 @@ import java.util.concurrent.TimeUnit import scala.concurrent.duration._ import org.apache.hadoop.fs.Path -import org.mockito.Matchers.{any, eq => meq} +import org.mockito.ArgumentMatchers.{any, eq => meq} import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index fd7e00b1de25f..bdaef94949159 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -26,7 +26,7 @@ import scala.language.{implicitConversions, postfixOps} import scala.util.Random import org.apache.hadoop.conf.Configuration -import org.mockito.Matchers.any +import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.{doThrow, reset, spy} import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala index 8d81b582e4d30..7ec02c4782e42 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.streaming.scheduler -import org.mockito.Matchers.{eq => meq} -import org.mockito.Mockito._ +import org.mockito.ArgumentMatchers.{eq => meq} +import org.mockito.Mockito.{never, reset, times, verify, when} import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, PrivateMethodTester} import org.scalatest.concurrent.Eventually.{eventually, timeout} import org.scalatest.mockito.MockitoSugar diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 4a2549fc0a96d..c20380d8490df 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -31,8 +31,8 @@ import scala.language.{implicitConversions, postfixOps} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.mockito.ArgumentCaptor -import org.mockito.Matchers.{eq => meq, _} -import org.mockito.Mockito._ +import org.mockito.ArgumentMatchers.{any, anyLong, eq => meq} +import org.mockito.Mockito.{times, verify, when} import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach, PrivateMethodTester} import org.scalatest.concurrent.Eventually import org.scalatest.concurrent.Eventually._ From 5969b8a2edb913fe2a8e0d928010eb8f471c7b02 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 5 Jan 2019 00:55:17 -0800 Subject: [PATCH 2378/2461] [SPARK-26541][BUILD] Add `-Pdocker-integration-tests` to `dev/scalastyle` ## What changes were proposed in this pull request? This PR makes `scalastyle` to check `docker-integration-tests` module additionally and fixes one error. ## How was this patch tested? Pass the Jenkins with the updated Scalastyle. ``` ======================================================================== Running Scala style checks ======================================================================== Scalastyle checks passed. ``` Closes #23459 from dongjoon-hyun/SPARK-26541. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- dev/scalastyle | 1 + .../org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/dev/scalastyle b/dev/scalastyle index 2d6ee0da1d4c1..ff6dba5b536a8 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -28,6 +28,7 @@ ERRORS=$(echo -e "q\n" \ -Phive \ -Phive-thriftserver \ -Pspark-ganglia-lgpl \ + -Pdocker-integration-tests \ scalastyle test:scalastyle \ | awk '{if($1~/error/)print}' \ ) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 70d294d0ca650..79fdf9c2ba434 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -154,7 +154,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo // A value with fractions from DECIMAL(3, 2) is correct: assert(row.getDecimal(1).compareTo(BigDecimal.valueOf(1.23)) == 0) // A value > Int.MaxValue from DECIMAL(10) is correct: - assert(row.getDecimal(2).compareTo(BigDecimal.valueOf(9999999999l)) == 0) + assert(row.getDecimal(2).compareTo(BigDecimal.valueOf(9999999999L)) == 0) } From 1af1190beeb1ac15205a9bd06ca67e363de03221 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 5 Jan 2019 01:14:58 -0800 Subject: [PATCH 2379/2461] [SPARK-26078][SQL][FOLLOWUP] Remove useless import ## What changes were proposed in this pull request? While backporting the patch to 2.4/2.3, I realized that the patch introduces unneeded imports (probably leftovers from intermediate changes). This PR removes the useless import. ## How was this patch tested? NA Closes #23451 from mgaido91/SPARK-26078_FOLLOWUP. Authored-by: Marco Gaido Signed-off-by: Dongjoon Hyun --- sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index c95c52f1d3a9c..48c1676609132 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -22,7 +22,6 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types._ class SubquerySuite extends QueryTest with SharedSQLContext { import testImplicits._ From 980e6bcd1c016139c6918d788fb4806a60740fcf Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 5 Jan 2019 21:50:27 +0800 Subject: [PATCH 2380/2461] [SPARK-26246][SQL][FOLLOWUP] Inferring TimestampType from JSON ## What changes were proposed in this pull request? Added new JSON option `inferTimestamp` (`true` by default) to control inferring of `TimestampType` from string values. ## How was this patch tested? Add new UT to `JsonInferSchemaSuite`. Closes #23455 from MaxGekk/json-infer-time-followup. Authored-by: Maxim Gekk Signed-off-by: Hyukjin Kwon --- docs/sql-migration-guide-upgrade.md | 3 +++ .../org/apache/spark/sql/catalyst/json/JSONOptions.scala | 6 ++++++ .../apache/spark/sql/catalyst/json/JsonInferSchema.scala | 3 ++- .../spark/sql/catalyst/json/JsonInferSchemaSuite.scala | 6 ++++++ 4 files changed, 17 insertions(+), 1 deletion(-) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index c4d2157de8b60..7e6a0c097d242 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -40,6 +40,9 @@ displayTitle: Spark SQL Upgrading Guide - In Spark version 2.4 and earlier, JSON datasource and JSON functions like `from_json` convert a bad JSON record to a row with all `null`s in the PERMISSIVE mode when specified schema is `StructType`. Since Spark 3.0, the returned row can contain non-`null` fields if some of JSON column values were parsed and converted to desired types successfully. - Since Spark 3.0, the `unix_timestamp`, `date_format`, `to_unix_timestamp`, `from_unixtime`, `to_date`, `to_timestamp` functions use java.time API for parsing and formatting dates/timestamps from/to strings by using ISO chronology (https://docs.oracle.com/javase/8/docs/api/java/time/chrono/IsoChronology.html) based on Proleptic Gregorian calendar. In Spark version 2.4 and earlier, java.text.SimpleDateFormat and java.util.GregorianCalendar (hybrid calendar that supports both the Julian and Gregorian calendar systems, see https://docs.oracle.com/javase/7/docs/api/java/util/GregorianCalendar.html) is used for the same purpuse. New implementation supports pattern formats as described here https://docs.oracle.com/javase/8/docs/api/java/time/format/DateTimeFormatter.html and performs strict checking of its input. For example, the `2015-07-22 10:00:00` timestamp cannot be parse if pattern is `yyyy-MM-dd` because the parser does not consume whole input. Another example is the `31/01/2015 00:00` input cannot be parsed by the `dd/MM/yyyy hh:mm` pattern because `hh` supposes hours in the range `1-12`. To switch back to the implementation used in Spark 2.4 and earlier, set `spark.sql.legacy.timeParser.enabled` to `true`. + + - Since Spark 3.0, JSON datasource and JSON function `schema_of_json` infer TimestampType from string values if they matches to the pattern defined by the JSON option `timestampFormat`. Set JSON option `inferTimestamp` to `false` to disable such type inferring. + ## Upgrading From Spark SQL 2.3 to 2.4 - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index eaff3fa7bec25..1ec9d5093a789 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -117,6 +117,12 @@ private[sql] class JSONOptions( */ val pretty: Boolean = parameters.get("pretty").map(_.toBoolean).getOrElse(false) + /** + * Enables inferring of TimestampType from strings matched to the timestamp pattern + * defined by the timestampFormat option. + */ + val inferTimestamp: Boolean = parameters.get("inferTimestamp").map(_.toBoolean).getOrElse(true) + /** Sets config options on a Jackson [[JsonFactory]]. */ def setJacksonOptions(factory: JsonFactory): Unit = { factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index 3203e626ea400..0bf3f03cdb72d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -128,7 +128,8 @@ private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable { } if (options.prefersDecimal && decimalTry.isDefined) { decimalTry.get - } else if ((allCatch opt timestampFormatter.parse(field)).isDefined) { + } else if (options.inferTimestamp && + (allCatch opt timestampFormatter.parse(field)).isDefined) { TimestampType } else { StringType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala index 9307f9b47b807..9a6f4f5f9b0cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala @@ -99,4 +99,10 @@ class JsonInferSchemaSuite extends SparkFunSuite with SQLHelper { } } } + + test("disable timestamp inferring") { + val json = """{"a": "2019-01-04T21:11:10.123Z"}""" + checkType(Map("inferTimestamp" -> "true"), json, TimestampType) + checkType(Map("inferTimestamp" -> "false"), json, StringType) + } } From 0037bbb71725619590f5ecbc9a5a470c4889810f Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 5 Jan 2019 22:53:28 +0800 Subject: [PATCH 2381/2461] [MINOR][DOC] Fix typos in the SQL migration guide ## What changes were proposed in this pull request? Fixed a few typos in the migration guide. Closes #23465 from MaxGekk/fix-typos-migration-guide. Authored-by: Maxim Gekk Signed-off-by: Hyukjin Kwon --- docs/sql-migration-guide-upgrade.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 7e6a0c097d242..0fcdd420bcfe3 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -17,7 +17,7 @@ displayTitle: Spark SQL Upgrading Guide - Since Spark 3.0, the `from_json` functions supports two modes - `PERMISSIVE` and `FAILFAST`. The modes can be set via the `mode` option. The default mode became `PERMISSIVE`. In previous versions, behavior of `from_json` did not conform to either `PERMISSIVE` nor `FAILFAST`, especially in processing of malformed JSON records. For example, the JSON string `{"a" 1}` with the schema `a INT` is converted to `null` by previous versions but Spark 3.0 converts it to `Row(null)`. - - In Spark version 2.4 and earlier, the `from_json` function produces `null`s for JSON strings and JSON datasource skips the same independetly of its mode if there is no valid root JSON token in its input (` ` for example). Since Spark 3.0, such input is treated as a bad record and handled according to specified mode. For example, in the `PERMISSIVE` mode the ` ` input is converted to `Row(null, null)` if specified schema is `key STRING, value INT`. + - In Spark version 2.4 and earlier, the `from_json` function produces `null`s for JSON strings and JSON datasource skips the same independently of its mode if there is no valid root JSON token in its input (` ` for example). Since Spark 3.0, such input is treated as a bad record and handled according to specified mode. For example, in the `PERMISSIVE` mode the ` ` input is converted to `Row(null, null)` if specified schema is `key STRING, value INT`. - The `ADD JAR` command previously returned a result set with the single value 0. It now returns an empty result set. @@ -27,21 +27,21 @@ displayTitle: Spark SQL Upgrading Guide - In Spark version 2.4 and earlier, float/double -0.0 is semantically equal to 0.0, but users can still distinguish them via `Dataset.show`, `Dataset.collect` etc. Since Spark 3.0, float/double -0.0 is replaced by 0.0 internally, and users can't distinguish them any more. - - In Spark version 2.4 and earlier, users can create a map with duplicated keys via built-in functions like `CreateMap`, `StringToMap`, etc. The behavior of map with duplicated keys is undefined, e.g. map look up respects the duplicated key appears first, `Dataset.collect` only keeps the duplicated key appears last, `MapKeys` returns duplicated keys, etc. Since Spark 3.0, these built-in functions will remove duplicated map keys with last wins policy. Users may still read map values with duplicated keys from data sources which do not enforce it (e.g. Parquet), the behavior will be udefined. + - In Spark version 2.4 and earlier, users can create a map with duplicated keys via built-in functions like `CreateMap`, `StringToMap`, etc. The behavior of map with duplicated keys is undefined, e.g. map look up respects the duplicated key appears first, `Dataset.collect` only keeps the duplicated key appears last, `MapKeys` returns duplicated keys, etc. Since Spark 3.0, these built-in functions will remove duplicated map keys with last wins policy. Users may still read map values with duplicated keys from data sources which do not enforce it (e.g. Parquet), the behavior will be undefined. - In Spark version 2.4 and earlier, partition column value is converted as null if it can't be casted to corresponding user provided schema. Since 3.0, partition column value is validated with user provided schema. An exception is thrown if the validation fails. You can disable such validation by setting `spark.sql.sources.validatePartitionColumns` to `false`. - In Spark version 2.4 and earlier, the `SET` command works without any warnings even if the specified key is for `SparkConf` entries and it has no effect because the command does not update `SparkConf`, but the behavior might confuse users. Since 3.0, the command fails if a `SparkConf` key is used. You can disable such a check by setting `spark.sql.legacy.setCommandRejectsSparkCoreConfs` to `false`. - - Since Spark 3.0, CSV/JSON datasources use java.time API for parsing and generating CSV/JSON content. In Spark version 2.4 and earlier, java.text.SimpleDateFormat is used for the same purpuse with fallbacks to the parsing mechanisms of Spark 2.0 and 1.x. For example, `2018-12-08 10:39:21.123` with the pattern `yyyy-MM-dd'T'HH:mm:ss.SSS` cannot be parsed since Spark 3.0 because the timestamp does not match to the pattern but it can be parsed by earlier Spark versions due to a fallback to `Timestamp.valueOf`. To parse the same timestamp since Spark 3.0, the pattern should be `yyyy-MM-dd HH:mm:ss.SSS`. To switch back to the implementation used in Spark 2.4 and earlier, set `spark.sql.legacy.timeParser.enabled` to `true`. + - Since Spark 3.0, CSV/JSON datasources use java.time API for parsing and generating CSV/JSON content. In Spark version 2.4 and earlier, java.text.SimpleDateFormat is used for the same purpose with fallbacks to the parsing mechanisms of Spark 2.0 and 1.x. For example, `2018-12-08 10:39:21.123` with the pattern `yyyy-MM-dd'T'HH:mm:ss.SSS` cannot be parsed since Spark 3.0 because the timestamp does not match to the pattern but it can be parsed by earlier Spark versions due to a fallback to `Timestamp.valueOf`. To parse the same timestamp since Spark 3.0, the pattern should be `yyyy-MM-dd HH:mm:ss.SSS`. To switch back to the implementation used in Spark 2.4 and earlier, set `spark.sql.legacy.timeParser.enabled` to `true`. - In Spark version 2.4 and earlier, CSV datasource converts a malformed CSV string to a row with all `null`s in the PERMISSIVE mode. Since Spark 3.0, the returned row can contain non-`null` fields if some of CSV column values were parsed and converted to desired types successfully. - In Spark version 2.4 and earlier, JSON datasource and JSON functions like `from_json` convert a bad JSON record to a row with all `null`s in the PERMISSIVE mode when specified schema is `StructType`. Since Spark 3.0, the returned row can contain non-`null` fields if some of JSON column values were parsed and converted to desired types successfully. - - Since Spark 3.0, the `unix_timestamp`, `date_format`, `to_unix_timestamp`, `from_unixtime`, `to_date`, `to_timestamp` functions use java.time API for parsing and formatting dates/timestamps from/to strings by using ISO chronology (https://docs.oracle.com/javase/8/docs/api/java/time/chrono/IsoChronology.html) based on Proleptic Gregorian calendar. In Spark version 2.4 and earlier, java.text.SimpleDateFormat and java.util.GregorianCalendar (hybrid calendar that supports both the Julian and Gregorian calendar systems, see https://docs.oracle.com/javase/7/docs/api/java/util/GregorianCalendar.html) is used for the same purpuse. New implementation supports pattern formats as described here https://docs.oracle.com/javase/8/docs/api/java/time/format/DateTimeFormatter.html and performs strict checking of its input. For example, the `2015-07-22 10:00:00` timestamp cannot be parse if pattern is `yyyy-MM-dd` because the parser does not consume whole input. Another example is the `31/01/2015 00:00` input cannot be parsed by the `dd/MM/yyyy hh:mm` pattern because `hh` supposes hours in the range `1-12`. To switch back to the implementation used in Spark 2.4 and earlier, set `spark.sql.legacy.timeParser.enabled` to `true`. + - Since Spark 3.0, the `unix_timestamp`, `date_format`, `to_unix_timestamp`, `from_unixtime`, `to_date`, `to_timestamp` functions use java.time API for parsing and formatting dates/timestamps from/to strings by using ISO chronology (https://docs.oracle.com/javase/8/docs/api/java/time/chrono/IsoChronology.html) based on Proleptic Gregorian calendar. In Spark version 2.4 and earlier, java.text.SimpleDateFormat and java.util.GregorianCalendar (hybrid calendar that supports both the Julian and Gregorian calendar systems, see https://docs.oracle.com/javase/7/docs/api/java/util/GregorianCalendar.html) is used for the same purpose. New implementation supports pattern formats as described here https://docs.oracle.com/javase/8/docs/api/java/time/format/DateTimeFormatter.html and performs strict checking of its input. For example, the `2015-07-22 10:00:00` timestamp cannot be parse if pattern is `yyyy-MM-dd` because the parser does not consume whole input. Another example is the `31/01/2015 00:00` input cannot be parsed by the `dd/MM/yyyy hh:mm` pattern because `hh` supposes hours in the range `1-12`. To switch back to the implementation used in Spark 2.4 and earlier, set `spark.sql.legacy.timeParser.enabled` to `true`. - - Since Spark 3.0, JSON datasource and JSON function `schema_of_json` infer TimestampType from string values if they matches to the pattern defined by the JSON option `timestampFormat`. Set JSON option `inferTimestamp` to `false` to disable such type inferring. + - Since Spark 3.0, JSON datasource and JSON function `schema_of_json` infer TimestampType from string values if they match to the pattern defined by the JSON option `timestampFormat`. Set JSON option `inferTimestamp` to `false` to disable such type inferring. ## Upgrading From Spark SQL 2.3 to 2.4 From 4ab5b5b9185f60f671d90d94732d0d784afa5f84 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Sat, 5 Jan 2019 14:37:04 -0800 Subject: [PATCH 2382/2461] [SPARK-26545] Fix typo in EqualNullSafe's truth table comment ## What changes were proposed in this pull request? The truth table comment in EqualNullSafe incorrectly marked FALSE results as UNKNOWN. ## How was this patch tested? N/A Closes #23461 from rednaxelafx/fix-typo. Authored-by: Kris Mok Signed-off-by: gatorsmile --- .../apache/spark/sql/catalyst/expressions/predicates.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 01ecb99025eaa..37fe22f4556e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -653,9 +653,9 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp // +---------+---------+---------+---------+ // | <=> | TRUE | FALSE | UNKNOWN | // +---------+---------+---------+---------+ - // | TRUE | TRUE | FALSE | UNKNOWN | - // | FALSE | FALSE | TRUE | UNKNOWN | - // | UNKNOWN | UNKNOWN | UNKNOWN | TRUE | + // | TRUE | TRUE | FALSE | FALSE | + // | FALSE | FALSE | TRUE | FALSE | + // | UNKNOWN | FALSE | FALSE | TRUE | // +---------+---------+---------+---------+ override def eval(input: InternalRow): Any = { val input1 = left.eval(input) From a17851cb95687963936c4d4a7eed132ee2c10677 Mon Sep 17 00:00:00 2001 From: Dave DeCaprio Date: Sat, 5 Jan 2019 19:20:35 -0800 Subject: [PATCH 2383/2461] [SPARK-26548][SQL] Don't hold CacheManager write lock while computing executedPlan ## What changes were proposed in this pull request? Address SPARK-26548, in Spark 2.4.0, the CacheManager holds a write lock while computing the executedPlan for a cached logicalPlan. In some cases with very large query plans this can be an expensive operation, taking minutes to run. The entire cache is blocked during this time. This PR changes that so the writeLock is only obtained after the executedPlan is generated, this reduces the time the lock is held to just the necessary time when the shared data structure is being updated. gatorsmile and cloud-fan - You can committed patches in this area before. This is a small incremental change. ## How was this patch tested? Has been tested on a live system where the blocking was causing major issues and it is working well. CacheManager has no explicit unit test but is used in many places internally as part of the SharedState. Closes #23469 from DaveDeCaprio/optimizer-unblocked. Lead-authored-by: Dave DeCaprio Co-authored-by: David DeCaprio Signed-off-by: gatorsmile --- .../org/apache/spark/sql/execution/CacheManager.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index c9929935fb8ac..728fde54fe69a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -88,7 +88,7 @@ class CacheManager extends Logging { def cacheQuery( query: Dataset[_], tableName: Option[String] = None, - storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { + storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = { val planToCache = query.logicalPlan if (lookupCachedData(planToCache).nonEmpty) { logWarning("Asked to cache already cached data.") @@ -100,7 +100,13 @@ class CacheManager extends Logging { sparkSession.sessionState.executePlan(planToCache).executedPlan, tableName, planToCache) - cachedData.add(CachedData(planToCache, inMemoryRelation)) + writeLock { + if (lookupCachedData(planToCache).nonEmpty) { + logWarning("Data has already been cached.") + } else { + cachedData.add(CachedData(planToCache, inMemoryRelation)) + } + } } } From 737f08949adecbae37bb92dfad71ae5f3a82cbee Mon Sep 17 00:00:00 2001 From: SongYadong Date: Sun, 6 Jan 2019 08:46:20 -0600 Subject: [PATCH 2384/2461] [SPARK-26527][CORE] Let acquireUnrollMemory fail fast if required space exceeds memory limit ## What changes were proposed in this pull request? When acquiring unroll memory from `StaticMemoryManager`, let it fail fast if required space exceeds memory limit, just like acquiring storage memory. I think this may reduce some computation and memory evicting costs especially when required space(`numBytes`) is very big. ## How was this patch tested? Existing unit tests. Closes #23426 from SongYadong/acquireUnrollMemory_fail_fast. Authored-by: SongYadong Signed-off-by: Sean Owen --- .../spark/memory/StaticMemoryManager.scala | 27 ++++++++++++------- .../spark/storage/MemoryStoreSuite.scala | 4 +-- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala index 8286087042741..0fd349dc51619 100644 --- a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala @@ -80,16 +80,23 @@ private[spark] class StaticMemoryManager( memoryMode: MemoryMode): Boolean = synchronized { require(memoryMode != MemoryMode.OFF_HEAP, "StaticMemoryManager does not support off-heap unroll memory") - val currentUnrollMemory = onHeapStorageMemoryPool.memoryStore.currentUnrollMemory - val freeMemory = onHeapStorageMemoryPool.memoryFree - // When unrolling, we will use all of the existing free memory, and, if necessary, - // some extra space freed from evicting cached blocks. We must place a cap on the - // amount of memory to be evicted by unrolling, however, otherwise unrolling one - // big block can blow away the entire cache. - val maxNumBytesToFree = math.max(0, maxUnrollMemory - currentUnrollMemory - freeMemory) - // Keep it within the range 0 <= X <= maxNumBytesToFree - val numBytesToFree = math.max(0, math.min(maxNumBytesToFree, numBytes - freeMemory)) - onHeapStorageMemoryPool.acquireMemory(blockId, numBytes, numBytesToFree) + if (numBytes > maxOnHeapStorageMemory) { + // Fail fast if the block simply won't fit + logInfo(s"Will not store $blockId as the required space ($numBytes bytes) exceeds our " + + s"memory limit ($maxOnHeapStorageMemory bytes)") + false + } else { + val currentUnrollMemory = onHeapStorageMemoryPool.memoryStore.currentUnrollMemory + val freeMemory = onHeapStorageMemoryPool.memoryFree + // When unrolling, we will use all of the existing free memory, and, if necessary, + // some extra space freed from evicting cached blocks. We must place a cap on the + // amount of memory to be evicted by unrolling, however, otherwise unrolling one + // big block can blow away the entire cache. + val maxNumBytesToFree = math.max(0, maxUnrollMemory - currentUnrollMemory - freeMemory) + // Keep it within the range 0 <= X <= maxNumBytesToFree + val numBytesToFree = math.max(0, math.min(maxNumBytesToFree, numBytes - freeMemory)) + onHeapStorageMemoryPool.acquireMemory(blockId, numBytes, numBytesToFree) + } } private[memory] diff --git a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala index 7274072e5049a..baff672f5fb8f 100644 --- a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala @@ -291,11 +291,11 @@ class MemoryStoreSuite blockInfoManager.removeBlock("b3") putIteratorAsBytes("b3", smallIterator, ClassTag.Any) - // Unroll huge block with not enough space. This should fail and kick out b2 in the process. + // Unroll huge block with not enough space. This should fail. val result4 = putIteratorAsBytes("b4", bigIterator, ClassTag.Any) assert(result4.isLeft) // unroll was unsuccessful assert(!memoryStore.contains("b1")) - assert(!memoryStore.contains("b2")) + assert(memoryStore.contains("b2")) assert(memoryStore.contains("b3")) assert(!memoryStore.contains("b4")) assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator From 9d8e9b394bbc065a72076585a21393f42ce86cd1 Mon Sep 17 00:00:00 2001 From: Hirobe Keiichi Date: Sun, 6 Jan 2019 08:52:09 -0600 Subject: [PATCH 2385/2461] [SPARK-26339][SQL] Throws better exception when reading files that start with underscore ## What changes were proposed in this pull request? My pull request #23288 was resolved and merged to master, but it turned out later that my change breaks another regression test. Because we cannot reopen pull request, I create a new pull request here. Commit 92934b4 is only change after pull request #23288. `CheckFileExist` was avoided at 239cfa4 after discussing #23288 (comment). But, that change turned out to be wrong because we should not check if argument checkFileExist is false. Test https://github.com/apache/spark/blob/27e42c1de502da80fa3e22bb69de47fb00158174/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala#L2555 failed when we avoided checkFileExist, but now successed after commit 92934b4 . ## How was this patch tested? Both of below tests were passed. ``` testOnly org.apache.spark.sql.execution.datasources.csv.CSVSuite testOnly org.apache.spark.sql.SQLQuerySuite ``` Closes #23446 from KeiichiHirobe/SPARK-26339. Authored-by: Hirobe Keiichi Signed-off-by: Sean Owen --- .../execution/datasources/DataSource.scala | 19 +++++++++++++++++- .../src/test/resources/test-data/_cars.csv | 7 +++++++ .../execution/datasources/csv/CSVSuite.scala | 20 +++++++++++++++++++ 3 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/test/resources/test-data/_cars.csv diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index fefff68c4ba8b..2a438a5cbf957 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -543,7 +543,7 @@ case class DataSource( checkFilesExist: Boolean): Seq[Path] = { val allPaths = caseInsensitiveOptions.get("path") ++ paths val hadoopConf = sparkSession.sessionState.newHadoopConf() - allPaths.flatMap { path => + val allGlobPath = allPaths.flatMap { path => val hdfsPath = new Path(path) val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) @@ -560,6 +560,23 @@ case class DataSource( } globPath }.toSeq + + if (checkFilesExist) { + val (filteredOut, filteredIn) = allGlobPath.partition { path => + InMemoryFileIndex.shouldFilterOut(path.getName) + } + if (filteredOut.nonEmpty) { + if (filteredIn.isEmpty) { + throw new AnalysisException( + s"All paths were ignored:\n${filteredOut.mkString("\n ")}") + } else { + logDebug( + s"Some paths were ignored:\n${filteredOut.mkString("\n ")}") + } + } + } + + allGlobPath } } diff --git a/sql/core/src/test/resources/test-data/_cars.csv b/sql/core/src/test/resources/test-data/_cars.csv new file mode 100644 index 0000000000000..40ded573ade5c --- /dev/null +++ b/sql/core/src/test/resources/test-data/_cars.csv @@ -0,0 +1,7 @@ + +year,make,model,comment,blank +"2012","Tesla","S","No comment", + +1997,Ford,E350,"Go get one now they are going fast", +2015,Chevy,Volt + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index d9e5d7af19671..fb1bedfaa32c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -53,6 +53,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te private val carsEmptyValueFile = "test-data/cars-empty-value.csv" private val carsBlankColName = "test-data/cars-blank-column-name.csv" private val carsCrlf = "test-data/cars-crlf.csv" + private val carsFilteredOutFile = "test-data/_cars.csv" private val emptyFile = "test-data/empty.csv" private val commentsFile = "test-data/comments.csv" private val disableCommentsFile = "test-data/disable_comments.csv" @@ -346,6 +347,25 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te assert(result.schema.fieldNames.size === 1) } + test("SPARK-26339 Not throw an exception if some of specified paths are filtered in") { + val cars = spark + .read + .option("header", "false") + .csv(testFile(carsFile), testFile(carsFilteredOutFile)) + + verifyCars(cars, withHeader = false, checkTypes = false) + } + + test("SPARK-26339 Throw an exception only if all of the specified paths are filtered out") { + val e = intercept[AnalysisException] { + val cars = spark + .read + .option("header", "false") + .csv(testFile(carsFilteredOutFile)) + }.getMessage + assert(e.contains("All paths were ignored:")) + } + test("DDL test with empty file") { withView("carsTable") { spark.sql( From b305d71625380f6fcd7b675d423222eca1840c2a Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 6 Jan 2019 17:36:06 -0800 Subject: [PATCH 2386/2461] [SPARK-26547][SQL] Remove duplicate toHiveString from HiveUtils ## What changes were proposed in this pull request? The `toHiveString()` and `toHiveStructString` methods were removed from `HiveUtils` because they have been already implemented in `HiveResult`. One related test was moved to `HiveResultSuite`. ## How was this patch tested? By tests from `hive-thriftserver`. Closes #23466 from MaxGekk/dedup-hive-result-string. Authored-by: Maxim Gekk Signed-off-by: Dongjoon Hyun --- .../spark/sql/execution/HiveResult.scala | 115 ++++++++++-------- .../spark/sql/execution/HiveResultSuite.scala | 30 +++++ .../SparkExecuteStatementOperation.scala | 4 +- .../org/apache/spark/sql/hive/HiveUtils.scala | 46 ------- .../spark/sql/hive/HiveUtilsSuite.scala | 11 +- 5 files changed, 96 insertions(+), 110 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala index 22d3ca958a210..c90b254a6d121 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -56,61 +56,70 @@ object HiveResult { result.map(_.zip(types).map(toHiveString)).map(_.mkString("\t")) } - /** Formats a datum (based on the given data type) and returns the string representation. */ - private def toHiveString(a: (Any, DataType)): String = { - val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, - BooleanType, ByteType, ShortType, DateType, TimestampType, BinaryType) - val timeZone = DateTimeUtils.getTimeZone(SQLConf.get.sessionLocalTimeZone) - - def formatDecimal(d: java.math.BigDecimal): String = { - if (d.compareTo(java.math.BigDecimal.ZERO) == 0) { - java.math.BigDecimal.ZERO.toPlainString - } else { - d.stripTrailingZeros().toPlainString - } + private def formatDecimal(d: java.math.BigDecimal): String = { + if (d.compareTo(java.math.BigDecimal.ZERO) == 0) { + java.math.BigDecimal.ZERO.toPlainString + } else { + d.stripTrailingZeros().toPlainString // Hive strips trailing zeros } + } - /** Hive outputs fields of structs slightly differently than top level attributes. */ - def toHiveStructString(a: (Any, DataType)): String = a match { - case (struct: Row, StructType(fields)) => - struct.toSeq.zip(fields).map { - case (v, t) => s""""${t.name}":${toHiveStructString((v, t.dataType))}""" - }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ, _)) => - seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_, _], MapType(kType, vType, _)) => - map.map { - case (key, value) => - toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) - }.toSeq.sorted.mkString("{", ",", "}") - case (null, _) => "null" - case (s: String, StringType) => "\"" + s + "\"" - case (decimal, DecimalType()) => decimal.toString - case (interval, CalendarIntervalType) => interval.toString - case (other, tpe) if primitiveTypes contains tpe => other.toString - } + private val primitiveTypes = Seq( + StringType, + IntegerType, + LongType, + DoubleType, + FloatType, + BooleanType, + ByteType, + ShortType, + DateType, + TimestampType, + BinaryType) - a match { - case (struct: Row, StructType(fields)) => - struct.toSeq.zip(fields).map { - case (v, t) => s""""${t.name}":${toHiveStructString((v, t.dataType))}""" - }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ, _)) => - seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_, _], MapType(kType, vType, _)) => - map.map { - case (key, value) => - toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) - }.toSeq.sorted.mkString("{", ",", "}") - case (null, _) => "NULL" - case (d: Date, DateType) => - DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d)) - case (t: Timestamp, TimestampType) => - DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(t), timeZone) - case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8) - case (decimal: java.math.BigDecimal, DecimalType()) => formatDecimal(decimal) - case (interval, CalendarIntervalType) => interval.toString - case (other, tpe) if primitiveTypes.contains(tpe) => other.toString - } + /** Hive outputs fields of structs slightly differently than top level attributes. */ + private def toHiveStructString(a: (Any, DataType)): String = a match { + case (struct: Row, StructType(fields)) => + struct.toSeq.zip(fields).map { + case (v, t) => s""""${t.name}":${toHiveStructString((v, t.dataType))}""" + }.mkString("{", ",", "}") + case (seq: Seq[_], ArrayType(typ, _)) => + seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") + case (map: Map[_, _], MapType(kType, vType, _)) => + map.map { + case (key, value) => + toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) + }.toSeq.sorted.mkString("{", ",", "}") + case (null, _) => "null" + case (s: String, StringType) => "\"" + s + "\"" + case (decimal, DecimalType()) => decimal.toString + case (interval, CalendarIntervalType) => interval.toString + case (other, tpe) if primitiveTypes contains tpe => other.toString + } + + /** Formats a datum (based on the given data type) and returns the string representation. */ + def toHiveString(a: (Any, DataType)): String = a match { + case (struct: Row, StructType(fields)) => + struct.toSeq.zip(fields).map { + case (v, t) => s""""${t.name}":${toHiveStructString((v, t.dataType))}""" + }.mkString("{", ",", "}") + case (seq: Seq[_], ArrayType(typ, _)) => + seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") + case (map: Map[_, _], MapType(kType, vType, _)) => + map.map { + case (key, value) => + toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) + }.toSeq.sorted.mkString("{", ",", "}") + case (null, _) => "NULL" + case (d: Date, DateType) => + DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d)) + case (t: Timestamp, TimestampType) => + val timeZone = DateTimeUtils.getTimeZone(SQLConf.get.sessionLocalTimeZone) + DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(t), timeZone) + case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8) + case (decimal: java.math.BigDecimal, DecimalType()) => formatDecimal(decimal) + case (interval, CalendarIntervalType) => interval.toString + case (other, _ : UserDefinedType[_]) => other.toString + case (other, tpe) if primitiveTypes.contains(tpe) => other.toString } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala new file mode 100644 index 0000000000000..4205b3f79a972 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT} + +class HiveResultSuite extends SparkFunSuite { + + test("toHiveString correctly handles UDTs") { + val point = new ExamplePoint(50.0, 50.0) + val tpe = new ExamplePointUDT() + assert(HiveResult.toHiveString((point, tpe)) === "(50.0, 50.0)") + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 3cfc81b8a9579..e68c6011c1393 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -34,8 +34,8 @@ import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLContext} +import org.apache.spark.sql.execution.HiveResult import org.apache.spark.sql.execution.command.SetCommand -import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.{Utils => SparkUtils} @@ -103,7 +103,7 @@ private[hive] class SparkExecuteStatementOperation( case BinaryType => to += from.getAs[Array[Byte]](ordinal) case _: ArrayType | _: StructType | _: MapType | _: UserDefinedType[_] => - val hiveString = HiveUtils.toHiveString((from.get(ordinal), dataTypes(ordinal))) + val hiveString = HiveResult.toHiveString((from.get(ordinal), dataTypes(ordinal))) to += hiveString } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index b60d4c71f5941..597eef129f63e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -444,52 +444,6 @@ private[spark] object HiveUtils extends Logging { propMap.toMap } - protected val primitiveTypes = - Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, - ShortType, DateType, TimestampType, BinaryType) - - protected[sql] def toHiveString(a: (Any, DataType)): String = a match { - case (struct: Row, StructType(fields)) => - struct.toSeq.zip(fields).map { - case (v, t) => s""""${t.name}":${toHiveStructString((v, t.dataType))}""" - }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ, _)) => - seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_, _], MapType(kType, vType, _)) => - map.map { - case (key, value) => - toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) - }.toSeq.sorted.mkString("{", ",", "}") - case (null, _) => "NULL" - case (d: Int, DateType) => new DateWritable(d).toString - case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString - case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8) - case (decimal: java.math.BigDecimal, DecimalType()) => - // Hive strips trailing zeros so use its toString - HiveDecimal.create(decimal).toString - case (other, _ : UserDefinedType[_]) => other.toString - case (other, tpe) if primitiveTypes contains tpe => other.toString - } - - /** Hive outputs fields of structs slightly differently than top level attributes. */ - protected def toHiveStructString(a: (Any, DataType)): String = a match { - case (struct: Row, StructType(fields)) => - struct.toSeq.zip(fields).map { - case (v, t) => s""""${t.name}":${toHiveStructString((v, t.dataType))}""" - }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ, _)) => - seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_, _], MapType(kType, vType, _)) => - map.map { - case (key, value) => - toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) - }.toSeq.sorted.mkString("{", ",", "}") - case (null, _) => "null" - case (s: String, StringType) => "\"" + s + "\"" - case (decimal, DecimalType()) => decimal.toString - case (other, tpe) if primitiveTypes contains tpe => other.toString - } - /** * Infers the schema for Hive serde tables and returns the CatalogTable with the inferred schema. * When the tables are data source tables or the schema already exists, returns the original diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala index f2b75e4b23f02..303dd70760a1b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.sql.hive -import java.net.URL - import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.execution.HiveResult import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SQLTestUtils} -import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader} +import org.apache.spark.util.ChildFirstURLClassLoader class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { @@ -62,10 +61,4 @@ class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton Thread.currentThread().setContextClassLoader(contextClassLoader) } } - - test("toHiveString correctly handles UDTs") { - val point = new ExamplePoint(50.0, 50.0) - val tpe = new ExamplePointUDT() - assert(HiveUtils.toHiveString((point, tpe)) === "(50.0, 50.0)") - } } From fe039faddf13c6a30f7aea69324aa4d4bb84c632 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 6 Jan 2019 19:59:31 -0800 Subject: [PATCH 2387/2461] [SPARK-26554][BUILD] Update `release-util.sh` to avoid GitBox fake 200 headers ## What changes were proposed in this pull request? Unlike the previous Apache Git repository, new GitBox repository returns a fake HTTP 200 header instead of `404 Not Found` header. This makes release scripts out of order. This PR aims to fix it to handle the html body message instead of the fake HTTP headers. This is a release blocker. ```bash $ curl -s --head --fail "https://gitbox.apache.org/repos/asf?p=spark.git;a=commit;h=v3.0.0" HTTP/1.1 200 OK Date: Sun, 06 Jan 2019 22:42:39 GMT Server: Apache/2.4.18 (Ubuntu) Vary: Accept-Encoding Access-Control-Allow-Origin: * Access-Control-Allow-Methods: POST, GET, OPTIONS Access-Control-Allow-Headers: X-PINGOTHER Access-Control-Max-Age: 1728000 Content-Type: text/html; charset=utf-8 ``` **BEFORE** ```bash $ ./do-release-docker.sh -d /tmp/test -n Branch [branch-2.4]: Current branch version is 2.4.1-SNAPSHOT. Release [2.4.1]: RC # [1]: v2.4.1-rc1 already exists. Continue anyway [y/n]? ``` **AFTER** ```bash $ ./do-release-docker.sh -d /tmp/test -n Branch [branch-2.4]: Current branch version is 2.4.1-SNAPSHOT. Release [2.4.1]: RC # [1]: This is a dry run. Please confirm the ref that will be built for testing. Ref [v2.4.1-rc1]: ``` ## How was this patch tested? Manual. Closes #23476 from dongjoon-hyun/SPARK-26554. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- dev/create-release/release-util.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dev/create-release/release-util.sh b/dev/create-release/release-util.sh index c925de9be52d4..9a340528b506d 100755 --- a/dev/create-release/release-util.sh +++ b/dev/create-release/release-util.sh @@ -73,7 +73,9 @@ function fcreate_secure { } function check_for_tag { - curl -s --head --fail "$ASF_REPO_WEBUI;a=commit;h=$1" >/dev/null + # Check HTML body messages instead of header status codes. Apache GitBox returns + # a header with `200 OK` status code for both existing and non-existing tag URLs + ! curl -s --fail "$ASF_REPO_WEBUI;a=commit;h=$1" | grep '404 Not Found' > /dev/null } function get_release_info { From 61133cb8a69e7814c3450e84ce9cc9226d7e8ad8 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 6 Jan 2019 21:00:10 -0800 Subject: [PATCH 2388/2461] [SPARK-26536][BUILD][FOLLOWUP][TEST-MAVEN] Make StreamingReadSupport public for maven testing ## What changes were proposed in this pull request? `StreamingReadSupport` is designed to be a `package` interface. Mockito seems to complain during `Maven` testing. This doesn't fail in `sbt` and IntelliJ. For mock-testing purpose, this PR makes it `public` interface and adds explicit comments like `public interface ReadSupport` ```scala EpochCoordinatorSuite: *** RUN ABORTED *** java.lang.IllegalAccessError: tried to access class org.apache.spark.sql.sources.v2.reader.streaming.StreamingReadSupport from class org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport$MockitoMock$58628338 at org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport$MockitoMock$58628338.(Unknown Source) at sun.reflect.GeneratedSerializationConstructorAccessor632.newInstance(Unknown Source) at java.lang.reflect.Constructor.newInstance(Constructor.java:423) at org.objenesis.instantiator.sun.SunReflectionFactoryInstantiator.newInstance(SunReflectionFactoryInstantiator.java:48) at org.objenesis.ObjenesisBase.newInstance(ObjenesisBase.java:73) at org.mockito.internal.creation.instance.ObjenesisInstantiator.newInstance(ObjenesisInstantiator.java:19) at org.mockito.internal.creation.bytebuddy.SubclassByteBuddyMockMaker.createMock(SubclassByteBuddyMockMaker.java:47) at org.mockito.internal.creation.bytebuddy.ByteBuddyMockMaker.createMock(ByteBuddyMockMaker.java:25) at org.mockito.internal.util.MockUtil.createMock(MockUtil.java:35) at org.mockito.internal.MockitoCore.mock(MockitoCore.java:69) ``` ## How was this patch tested? Pass the Jenkins with Maven build Closes #23463 from dongjoon-hyun/SPARK-26536-2. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../v2/reader/streaming/StreamingReadSupport.java | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java index 84872d1ebc26e..bd39fc858d3b8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/StreamingReadSupport.java @@ -17,14 +17,17 @@ package org.apache.spark.sql.sources.v2.reader.streaming; +import com.google.common.annotations.VisibleForTesting; + import org.apache.spark.sql.sources.v2.reader.ReadSupport; /** - * A base interface for streaming read support. This is package private and is invisible to data - * sources. Data sources should implement concrete streaming read support interfaces: - * {@link MicroBatchReadSupport} or {@link ContinuousReadSupport}. + * A base interface for streaming read support. Data sources should implement concrete streaming + * read support interfaces: {@link MicroBatchReadSupport} or {@link ContinuousReadSupport}. + * This is exposed for a testing purpose. */ -interface StreamingReadSupport extends ReadSupport { +@VisibleForTesting +public interface StreamingReadSupport extends ReadSupport { /** * Returns the initial offset for a streaming query to start reading from. Note that the From 468d25ec7419b4c55955ead877232aae5654260e Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 6 Jan 2019 22:45:18 -0800 Subject: [PATCH 2389/2461] [MINOR][BUILD] Fix script name in `release-tag.sh` usage message ## What changes were proposed in this pull request? This PR fixes the old script name in `release-tag.sh`. $ ./release-tag.sh --help | head -n1 usage: tag-release.sh ## How was this patch tested? Manual. $ ./release-tag.sh --help | head -n1 usage: release-tag.sh Closes #23477 from dongjoon-hyun/SPARK-RELEASE-TAG. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- dev/create-release/release-tag.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh index 010082d960a29..8024440759eb5 100755 --- a/dev/create-release/release-tag.sh +++ b/dev/create-release/release-tag.sh @@ -21,8 +21,9 @@ SELF=$(cd $(dirname $0) && pwd) . "$SELF/release-util.sh" function exit_with_usage { + local NAME=$(basename $0) cat << EOF -usage: tag-release.sh +usage: $NAME Tags a Spark release on a particular branch. Inputs are specified with the following environment variables: From a927c764c1eee066efc1c2c713dfee411de79245 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 7 Jan 2019 18:36:52 +0800 Subject: [PATCH 2390/2461] [SPARK-26559][ML][PYSPARK] ML image can't work with numpy versions prior to 1.9 ## What changes were proposed in this pull request? Due to [API change](https://github.com/numpy/numpy/pull/4257/files#diff-c39521d89f7e61d6c0c445d93b62f7dc) at 1.9, PySpark image doesn't work with numpy version prior to 1.9. When running image test with numpy version prior to 1.9, we can see error: ``` test_read_images (pyspark.ml.tests.test_image.ImageReaderTest) ... ERROR test_read_images_multiple_times (pyspark.ml.tests.test_image.ImageReaderTest2) ... ok ====================================================================== ERROR: test_read_images (pyspark.ml.tests.test_image.ImageReaderTest) ---------------------------------------------------------------------- Traceback (most recent call last): File "/Users/viirya/docker_tmp/repos/spark-1/python/pyspark/ml/tests/test_image.py", line 36, in test_read_images self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row) File "/Users/viirya/docker_tmp/repos/spark-1/python/pyspark/ml/image.py", line 193, in toImage data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes()) AttributeError: 'numpy.ndarray' object has no attribute 'tobytes' ---------------------------------------------------------------------- Ran 2 tests in 29.040s FAILED (errors=1) ``` ## How was this patch tested? Manually test with numpy version prior and after 1.9. Closes #23484 from viirya/fix-pyspark-image. Authored-by: Liang-Chi Hsieh Signed-off-by: Hyukjin Kwon --- python/pyspark/ml/image.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index edb90a3578546..a1aacea88e42e 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -28,6 +28,7 @@ import warnings import numpy as np +from distutils.version import LooseVersion from pyspark import SparkContext from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string @@ -190,7 +191,11 @@ def toImage(self, array, origin=""): # Running `bytearray(numpy.array([1]))` fails in specific Python versions # with a specific Numpy version, for example in Python 3.6.0 and NumPy 1.13.3. # Here, it avoids it by converting it to bytes. - data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes()) + if LooseVersion(np.__version__) >= LooseVersion('1.9'): + data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes()) + else: + # Numpy prior to 1.9 don't have `tobytes` method. + data = bytearray(array.astype(dtype=np.uint8).ravel()) # Creating new Row with _create_row(), because Row(name = value, ... ) # orders fields by name, which conflicts with expected schema order From 868e02533d76d45bff4200d07658105b6004cf46 Mon Sep 17 00:00:00 2001 From: ayudovin Date: Mon, 7 Jan 2019 08:58:33 -0600 Subject: [PATCH 2391/2461] [SPARK-26383][CORE] NPE when use DataFrameReader.jdbc with wrong URL ### What changes were proposed in this pull request? When passing wrong url to jdbc then It would throw IllegalArgumentException instead of NPE. ### How was this patch tested? Adding test case to Existing tests in JDBCSuite Closes #23464 from ayudovin/fixing-npe. Authored-by: ayudovin Signed-off-by: Sean Owen --- .../sql/execution/datasources/jdbc/JdbcUtils.scala | 7 ++++++- .../scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 13 +++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 922bef284c98e..86a27b5afc250 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -48,6 +48,7 @@ object JdbcUtils extends Logging { * Returns a factory for creating connections to the given JDBC URL. * * @param options - JDBC options that contains url, table and other information. + * @throws IllegalArgumentException if the driver could not open a JDBC connection. */ def createConnectionFactory(options: JDBCOptions): () => Connection = { val driverClass: String = options.driverClass @@ -60,7 +61,11 @@ object JdbcUtils extends Logging { throw new IllegalStateException( s"Did not find registered driver with class $driverClass") } - driver.connect(options.url, options.asConnectionProperties) + val connection: Connection = driver.connect(options.url, options.asConnectionProperties) + require(connection != null, + s"The driver could not open a JDBC connection. Check the URL: ${options.url}") + + connection } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index e4641631e607d..aefa5da94481b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1507,4 +1507,17 @@ class JDBCSuite extends QueryTest checkNotPushdown(sql("SELECT name, theid FROM predicateOption WHERE theid = 1")), Row("fred", 1) :: Nil) } + + test("SPARK-26383 throw IllegalArgumentException if wrong kind of driver to the given url") { + val e = intercept[IllegalArgumentException] { + val opts = Map( + "url" -> "jdbc:mysql://localhost/db", + "dbtable" -> "table", + "driver" -> "org.postgresql.Driver" + ) + spark.read.format("jdbc").options(opts).load + }.getMessage + assert(e.contains("The driver could not open a JDBC connection. " + + "Check the URL: jdbc:mysql://localhost/db")) + } } From 71183b283343a99c6fa99a41268dae412598067f Mon Sep 17 00:00:00 2001 From: Shahid Date: Mon, 7 Jan 2019 09:15:50 -0800 Subject: [PATCH 2392/2461] [SPARK-24489][ML] Check for invalid input type of weight data in ml.PowerIterationClustering ## What changes were proposed in this pull request? The test case will result the following failure. currently in ml.PIC, there is no check for the data type of weight column. ``` test("invalid input types for weight") { val invalidWeightData = spark.createDataFrame(Seq( (0L, 1L, "a"), (2L, 3L, "b") )).toDF("src", "dst", "weight") val pic = new PowerIterationClustering() .setWeightCol("weight") val result = pic.assignClusters(invalidWeightData) } ``` ``` Job aborted due to stage failure: Task 0 in stage 8077.0 failed 1 times, most recent failure: Lost task 0.0 in stage 8077.0 (TID 882, localhost, executor driver): scala.MatchError: [0,1,null] (of class org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema) at org.apache.spark.ml.clustering.PowerIterationClustering$$anonfun$3.apply(PowerIterationClustering.scala:178) at org.apache.spark.ml.clustering.PowerIterationClustering$$anonfun$3.apply(PowerIterationClustering.scala:178) at scala.collection.Iterator$$anon$11.next(Iterator.scala:409) at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:434) at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440) at scala.collection.Iterator$class.foreach(Iterator.scala:893) at scala.collection.AbstractIterator.foreach(Iterator.scala:1336) at org.apache.spark.graphx.EdgeRDD$$anonfun$1.apply(EdgeRDD.scala:107) at org.apache.spark.graphx.EdgeRDD$$anonfun$1.apply(EdgeRDD.scala:105) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndex$1$$anonfun$apply$26.apply(RDD.scala:847) ``` In this PR, added check types for weight column. ## How was this patch tested? UT added Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #21509 from shahidki31/testCasePic. Authored-by: Shahid Signed-off-by: Holden Karau --- .../ml/clustering/PowerIterationClustering.scala | 1 + .../PowerIterationClusteringSuite.scala | 15 +++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala index d9a330f67e8dc..149e99d2f195a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala @@ -166,6 +166,7 @@ class PowerIterationClustering private[clustering] ( val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) { lit(1.0) } else { + SchemaUtils.checkNumericType(dataset.schema, $(weightCol)) col($(weightCol)).cast(DoubleType) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala index 55b460f1a4524..0ba3ffabb75d2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala @@ -145,6 +145,21 @@ class PowerIterationClusteringSuite extends SparkFunSuite assert(msg.contains("Similarity must be nonnegative")) } + test("check for invalid input types of weight") { + val invalidWeightData = spark.createDataFrame(Seq( + (0L, 1L, "a"), + (2L, 3L, "b") + )).toDF("src", "dst", "weight") + + val msg = intercept[IllegalArgumentException] { + new PowerIterationClustering() + .setWeightCol("weight") + .assignClusters(invalidWeightData) + }.getMessage + assert(msg.contains("requirement failed: Column weight must be of type numeric" + + " but was actually of type string.")) + } + test("test default weight") { val dataWithoutWeight = data.sample(0.5, 1L).select('src, 'dst) From 669e8a155987995a1a5d49a96b88c05f39e41723 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 7 Jan 2019 14:40:08 -0600 Subject: [PATCH 2393/2461] [SPARK-25689][YARN] Make driver, not AM, manage delegation tokens. This change modifies the behavior of the delegation token code when running on YARN, so that the driver controls the renewal, in both client and cluster mode. For that, a few different things were changed: * The AM code only runs code that needs DTs when DTs are available. In a way, this restores the AM behavior to what it was pre-SPARK-23361, but keeping the fix added in that bug. Basically, all the AM code is run in a "UGI.doAs()" block; but code that needs to talk to HDFS (basically the distributed cache handling code) was delayed to the point where the driver is up and running, and thus when valid delegation tokens are available. * SparkSubmit / ApplicationMaster now handle user login, not the token manager. The previous AM code was relying on the token manager to keep the user logged in when keytabs are used. This required some odd APIs in the token manager and the AM so that the right UGI was exposed and used in the right places. After this change, the logged in user is handled separately from the token manager, so the API was cleaned up, and, as explained above, the whole AM runs under the logged in user, which also helps with simplifying some more code. * Distributed cache configs are sent separately to the AM. Because of the delayed initialization of the cached resources in the AM, it became easier to write the cache config to a separate properties file instead of bundling it with the rest of the Spark config. This also avoids having to modify the SparkConf to hide things from the UI. * Finally, the AM doesn't manage the token manager anymore. The above changes allow the token manager to be completely handled by the driver's scheduler backend code also in YARN mode (whether client or cluster), making it similar to other RMs. To maintain the fix added in SPARK-23361 also in client mode, the AM now sends an extra message to the driver on initialization to fetch delegation tokens; and although it might not really be needed, the driver also keeps the running AM updated when new tokens are created. Tested in a kerberized cluster with the same tests used to validate SPARK-23361, in both client and cluster mode. Also tested with a non-kerberized cluster. Closes #23338 from vanzin/SPARK-25689. Authored-by: Marcelo Vanzin Signed-off-by: Imran Rashid --- .../HadoopDelegationTokenManager.scala | 110 ++++++-------- .../HiveDelegationTokenProvider.scala | 16 ++- .../cluster/CoarseGrainedClusterMessage.scala | 3 + .../CoarseGrainedSchedulerBackend.scala | 40 ++++-- .../HadoopDelegationTokenManagerSuite.scala | 8 +- .../KerberosConfDriverFeatureStep.scala | 2 +- .../KubernetesClusterSchedulerBackend.scala | 7 +- .../MesosCoarseGrainedSchedulerBackend.scala | 7 +- .../spark/deploy/yarn/ApplicationMaster.scala | 135 +++++++++--------- .../yarn/ApplicationMasterArguments.scala | 5 + .../org/apache/spark/deploy/yarn/Client.scala | 100 +++++++------ .../spark/deploy/yarn/YarnRMClient.scala | 8 +- .../org/apache/spark/deploy/yarn/config.scala | 10 -- .../YARNHadoopDelegationTokenManager.scala | 7 +- .../cluster/YarnClientSchedulerBackend.scala | 6 + .../cluster/YarnSchedulerBackend.scala | 17 ++- ...ARNHadoopDelegationTokenManagerSuite.scala | 2 +- 17 files changed, 246 insertions(+), 237 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index f7e3ddecee093..d97857a39fc21 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -21,7 +21,6 @@ import java.io.File import java.net.URI import java.security.PrivilegedExceptionAction import java.util.concurrent.{ScheduledExecutorService, TimeUnit} -import java.util.concurrent.atomic.AtomicReference import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FileSystem @@ -39,32 +38,24 @@ import org.apache.spark.util.ThreadUtils /** * Manager for delegation tokens in a Spark application. * - * This manager has two modes of operation: - * - * 1. When configured with a principal and a keytab, it will make sure long-running apps can run - * without interruption while accessing secured services. It periodically logs in to the KDC with - * user-provided credentials, and contacts all the configured secure services to obtain delegation - * tokens to be distributed to the rest of the application. - * - * Because the Hadoop UGI API does not expose the TTL of the TGT, a configuration controls how often - * to check that a relogin is necessary. This is done reasonably often since the check is a no-op - * when the relogin is not yet needed. The check period can be overridden in the configuration. + * When configured with a principal and a keytab, this manager will make sure long-running apps can + * run without interruption while accessing secured services. It periodically logs in to the KDC + * with user-provided credentials, and contacts all the configured secure services to obtain + * delegation tokens to be distributed to the rest of the application. * * New delegation tokens are created once 75% of the renewal interval of the original tokens has - * elapsed. The new tokens are sent to the Spark driver endpoint once it's registered with the AM. - * The driver is tasked with distributing the tokens to other processes that might need them. + * elapsed. The new tokens are sent to the Spark driver endpoint. The driver is tasked with + * distributing the tokens to other processes that might need them. * - * 2. When operating without an explicit principal and keytab, token renewal will not be available. - * Starting the manager will distribute an initial set of delegation tokens to the provided Spark - * driver, but the app will not get new tokens when those expire. - * - * It can also be used just to create delegation tokens, by calling the `obtainDelegationTokens` - * method. This option does not require calling the `start` method, but leaves it up to the - * caller to distribute the tokens that were generated. + * This class can also be used just to create delegation tokens, by calling the + * `obtainDelegationTokens` method. This option does not require calling the `start` method nor + * providing a driver reference, but leaves it up to the caller to distribute the tokens that were + * generated. */ private[spark] class HadoopDelegationTokenManager( protected val sparkConf: SparkConf, - protected val hadoopConf: Configuration) extends Logging { + protected val hadoopConf: Configuration, + protected val schedulerRef: RpcEndpointRef) extends Logging { private val deprecatedProviderEnabledConfigs = List( "spark.yarn.security.tokens.%s.enabled", @@ -85,60 +76,44 @@ private[spark] class HadoopDelegationTokenManager( s"${delegationTokenProviders.keys.mkString(", ")}.") private var renewalExecutor: ScheduledExecutorService = _ - private val driverRef = new AtomicReference[RpcEndpointRef]() - - /** Set the endpoint used to send tokens to the driver. */ - def setDriverRef(ref: RpcEndpointRef): Unit = { - driverRef.set(ref) - } /** @return Whether delegation token renewal is enabled. */ def renewalEnabled: Boolean = principal != null /** - * Start the token renewer. Requires a principal and keytab. Upon start, the renewer will: + * Start the token renewer. Requires a principal and keytab. Upon start, the renewer will + * obtain delegation tokens for all configured services and send them to the driver, and + * set up tasks to periodically get fresh tokens as needed. * - * - log in the configured principal, and set up a task to keep that user's ticket renewed - * - obtain delegation tokens from all available providers - * - send the tokens to the driver, if it's already registered - * - schedule a periodic task to update the tokens when needed. + * This method requires that a keytab has been provided to Spark, and will try to keep the + * logged in user's TGT valid while this manager is active. * - * @return The newly logged in user. + * @return New set of delegation tokens created for the configured principal. */ - def start(): UserGroupInformation = { + def start(): Array[Byte] = { require(renewalEnabled, "Token renewal must be enabled to start the renewer.") + require(schedulerRef != null, "Token renewal requires a scheduler endpoint.") renewalExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor("Credential Renewal Thread") - val originalCreds = UserGroupInformation.getCurrentUser().getCredentials() - val ugi = doLogin() - - val tgtRenewalTask = new Runnable() { - override def run(): Unit = { - ugi.checkTGTAndReloginFromKeytab() + val ugi = UserGroupInformation.getCurrentUser() + if (ugi.isFromKeytab()) { + // In Hadoop 2.x, renewal of the keytab-based login seems to be automatic, but in Hadoop 3.x, + // it is configurable (see hadoop.kerberos.keytab.login.autorenewal.enabled, added in + // HADOOP-9567). This task will make sure that the user stays logged in regardless of that + // configuration's value. Note that checkTGTAndReloginFromKeytab() is a no-op if the TGT does + // not need to be renewed yet. + val tgtRenewalTask = new Runnable() { + override def run(): Unit = { + ugi.checkTGTAndReloginFromKeytab() + } } + val tgtRenewalPeriod = sparkConf.get(KERBEROS_RELOGIN_PERIOD) + renewalExecutor.scheduleAtFixedRate(tgtRenewalTask, tgtRenewalPeriod, tgtRenewalPeriod, + TimeUnit.SECONDS) } - val tgtRenewalPeriod = sparkConf.get(KERBEROS_RELOGIN_PERIOD) - renewalExecutor.scheduleAtFixedRate(tgtRenewalTask, tgtRenewalPeriod, tgtRenewalPeriod, - TimeUnit.SECONDS) - val creds = obtainTokensAndScheduleRenewal(ugi) - ugi.addCredentials(creds) - - val driver = driverRef.get() - if (driver != null) { - val tokens = SparkHadoopUtil.get.serialize(creds) - driver.send(UpdateDelegationTokens(tokens)) - } - - // Transfer the original user's tokens to the new user, since it may contain needed tokens - // (such as those user to connect to YARN). Explicitly avoid overwriting tokens that already - // exist in the current user's credentials, since those were freshly obtained above - // (see SPARK-23361). - val existing = ugi.getCredentials() - existing.mergeAll(originalCreds) - ugi.addCredentials(existing) - ugi + updateTokensTask() } def stop(): Unit = { @@ -218,27 +193,22 @@ private[spark] class HadoopDelegationTokenManager( * Periodic task to login to the KDC and create new delegation tokens. Re-schedules itself * to fetch the next set of tokens when needed. */ - private def updateTokensTask(): Unit = { + private def updateTokensTask(): Array[Byte] = { try { val freshUGI = doLogin() val creds = obtainTokensAndScheduleRenewal(freshUGI) val tokens = SparkHadoopUtil.get.serialize(creds) - val driver = driverRef.get() - if (driver != null) { - logInfo("Updating delegation tokens.") - driver.send(UpdateDelegationTokens(tokens)) - } else { - // This shouldn't really happen, since the driver should register way before tokens expire. - logWarning("Delegation tokens close to expiration but no driver has registered yet.") - SparkHadoopUtil.get.addDelegationTokens(tokens, sparkConf) - } + logInfo("Updating delegation tokens.") + schedulerRef.send(UpdateDelegationTokens(tokens)) + tokens } catch { case e: Exception => val delay = TimeUnit.SECONDS.toMillis(sparkConf.get(CREDENTIALS_RENEWAL_RETRY_WAIT)) logWarning(s"Failed to update tokens, will try again in ${UIUtils.formatDuration(delay)}!" + " If this happens too often tasks will fail.", e) scheduleRenewal(delay) + null } } diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala index 90f7051381571..4ca0136424fe1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala @@ -67,11 +67,17 @@ private[spark] class HiveDelegationTokenProvider // Other modes (such as client with or without keytab, or cluster mode with keytab) do not need // a delegation token, since there's a valid kerberos TGT for the right user available to the // driver, which is the only process that connects to the HMS. - val deployMode = sparkConf.get("spark.submit.deployMode", "client") - UserGroupInformation.isSecurityEnabled && + // + // Note that this means Hive tokens are not re-created periodically by the token manager. + // This is because HMS connections are only performed by the Spark driver, and the driver + // either has a TGT, in which case it does not need tokens, or it has a token created + // elsewhere, in which case it cannot create new ones. The check for an existing token avoids + // printing an exception to the logs in the latter case. + val currentToken = UserGroupInformation.getCurrentUser().getCredentials().getToken(tokenAlias) + currentToken == null && UserGroupInformation.isSecurityEnabled && hiveConf(hadoopConf).getTrimmed("hive.metastore.uris", "").nonEmpty && (SparkHadoopUtil.get.isProxyUser(UserGroupInformation.getCurrentUser()) || - (deployMode == "cluster" && !sparkConf.contains(KEYTAB))) + (!Utils.isClientMode(sparkConf) && !sparkConf.contains(KEYTAB))) } override def obtainDelegationTokens( @@ -98,7 +104,7 @@ private[spark] class HiveDelegationTokenProvider val hive2Token = new Token[DelegationTokenIdentifier]() hive2Token.decodeFromUrlString(tokenStr) logDebug(s"Get Token from hive metastore: ${hive2Token.toString}") - creds.addToken(new Text("hive.server2.delegation.token"), hive2Token) + creds.addToken(tokenAlias, hive2Token) } None @@ -134,4 +140,6 @@ private[spark] class HiveDelegationTokenProvider case e: UndeclaredThrowableException => throw Option(e.getCause()).getOrElse(e) } } + + private def tokenAlias: Text = new Text("hive.server2.delegation.token") } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index e8b7fc0ef100a..9e768c22c17e3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -104,6 +104,9 @@ private[spark] object CoarseGrainedClusterMessages { case class RegisterClusterManager(am: RpcEndpointRef) extends CoarseGrainedClusterMessage + // Used by YARN's client mode AM to retrieve the current set of delegation tokens. + object RetrieveDelegationTokens extends CoarseGrainedClusterMessage + // Request executors by specifying the new total number of executors desired // This includes executors already pending or running case class RequestExecutors( diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 329158a44d369..98ed2fffc0ac5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -162,11 +162,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } case UpdateDelegationTokens(newDelegationTokens) => - SparkHadoopUtil.get.addDelegationTokens(newDelegationTokens, conf) - delegationTokens.set(newDelegationTokens) - executorDataMap.values.foreach { ed => - ed.executorEndpoint.send(UpdateDelegationTokens(newDelegationTokens)) - } + updateDelegationTokens(newDelegationTokens) case RemoveExecutor(executorId, reason) => // We will remove the executor's state and cannot restore it. However, the connection @@ -404,17 +400,18 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp driverEndpoint = createDriverEndpointRef(properties) if (UserGroupInformation.isSecurityEnabled()) { - delegationTokenManager = createTokenManager() + delegationTokenManager = createTokenManager(driverEndpoint) delegationTokenManager.foreach { dtm => - dtm.setDriverRef(driverEndpoint) - val creds = if (dtm.renewalEnabled) { - dtm.start().getCredentials() + val tokens = if (dtm.renewalEnabled) { + dtm.start() } else { val creds = UserGroupInformation.getCurrentUser().getCredentials() dtm.obtainDelegationTokens(creds) - creds + SparkHadoopUtil.get.serialize(creds) + } + if (tokens != null) { + delegationTokens.set(tokens) } - delegationTokens.set(SparkHadoopUtil.get.serialize(creds)) } } } @@ -716,8 +713,27 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * Create the delegation token manager to be used for the application. This method is called * once during the start of the scheduler backend (so after the object has already been * fully constructed), only if security is enabled in the Hadoop configuration. + * + * @param schedulerRef RPC endpoint for the scheduler, where updated delegation tokens should be + * sent. */ - protected def createTokenManager(): Option[HadoopDelegationTokenManager] = None + protected def createTokenManager( + schedulerRef: RpcEndpointRef): Option[HadoopDelegationTokenManager] = None + + /** + * Called when a new set of delegation tokens is sent to the driver. Child classes can override + * this method but should always call this implementation, which handles token distribution to + * executors. + */ + protected def updateDelegationTokens(tokens: Array[Byte]): Unit = { + SparkHadoopUtil.get.addDelegationTokens(tokens, conf) + delegationTokens.set(tokens) + executorDataMap.values.foreach { ed => + ed.executorEndpoint.send(UpdateDelegationTokens(tokens)) + } + } + + protected def currentDelegationTokens: Array[Byte] = delegationTokens.get() } diff --git a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala index def9e626a2df2..af7d44b160fef 100644 --- a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala @@ -27,7 +27,7 @@ class HadoopDelegationTokenManagerSuite extends SparkFunSuite { private val hadoopConf = new Configuration() test("default configuration") { - val manager = new HadoopDelegationTokenManager(new SparkConf(false), hadoopConf) + val manager = new HadoopDelegationTokenManager(new SparkConf(false), hadoopConf, null) assert(manager.isProviderLoaded("hadoopfs")) assert(manager.isProviderLoaded("hbase")) assert(manager.isProviderLoaded("hive")) @@ -36,7 +36,7 @@ class HadoopDelegationTokenManagerSuite extends SparkFunSuite { test("disable hive credential provider") { val sparkConf = new SparkConf(false).set("spark.security.credentials.hive.enabled", "false") - val manager = new HadoopDelegationTokenManager(sparkConf, hadoopConf) + val manager = new HadoopDelegationTokenManager(sparkConf, hadoopConf, null) assert(manager.isProviderLoaded("hadoopfs")) assert(manager.isProviderLoaded("hbase")) assert(!manager.isProviderLoaded("hive")) @@ -47,7 +47,7 @@ class HadoopDelegationTokenManagerSuite extends SparkFunSuite { val sparkConf = new SparkConf(false) .set("spark.yarn.security.tokens.hadoopfs.enabled", "false") .set("spark.yarn.security.credentials.hive.enabled", "false") - val manager = new HadoopDelegationTokenManager(sparkConf, hadoopConf) + val manager = new HadoopDelegationTokenManager(sparkConf, hadoopConf, null) assert(!manager.isProviderLoaded("hadoopfs")) assert(manager.isProviderLoaded("hbase")) assert(!manager.isProviderLoaded("hive")) @@ -99,7 +99,7 @@ private object NoHiveTest { def runTest(): Unit = { try { - val manager = new HadoopDelegationTokenManager(new SparkConf(), new Configuration()) + val manager = new HadoopDelegationTokenManager(new SparkConf(), new Configuration(), null) require(!manager.isProviderLoaded("hive")) } catch { case e: Throwable => diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala index 721d7e97b21f8..a77e8d4dbcff2 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala @@ -91,7 +91,7 @@ private[spark] class KerberosConfDriverFeatureStep(kubernetesConf: KubernetesDri private lazy val delegationTokens: Array[Byte] = { if (keytab.isEmpty && existingSecretName.isEmpty) { val tokenManager = new HadoopDelegationTokenManager(kubernetesConf.sparkConf, - SparkHadoopUtil.get.newConfiguration(kubernetesConf.sparkConf)) + SparkHadoopUtil.get.newConfiguration(kubernetesConf.sparkConf), null) val creds = UserGroupInformation.getCurrentUser().getCredentials() tokenManager.obtainDelegationTokens(creds) // If no tokens and no secrets are stored in the credentials, make sure nothing is returned, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index cd298971e02a7..e285e202a1488 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -26,7 +26,7 @@ import org.apache.spark.SparkContext import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.security.HadoopDelegationTokenManager -import org.apache.spark.rpc.{RpcAddress, RpcEnv} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.{ExecutorLossReason, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils} import org.apache.spark.util.{ThreadUtils, Utils} @@ -147,8 +147,9 @@ private[spark] class KubernetesClusterSchedulerBackend( new KubernetesDriverEndpoint(sc.env.rpcEnv, properties) } - override protected def createTokenManager(): Option[HadoopDelegationTokenManager] = { - Some(new HadoopDelegationTokenManager(conf, sc.hadoopConfiguration)) + override protected def createTokenManager( + schedulerRef: RpcEndpointRef): Option[HadoopDelegationTokenManager] = { + Some(new HadoopDelegationTokenManager(conf, sc.hadoopConfiguration, schedulerRef)) } private class KubernetesDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index d0174516c2361..03cd2583b9b2f 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -36,7 +36,7 @@ import org.apache.spark.internal.config import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient -import org.apache.spark.rpc.RpcEndpointAddress +import org.apache.spark.rpc.{RpcEndpointAddress, RpcEndpointRef} import org.apache.spark.scheduler.{SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils @@ -772,8 +772,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } } - override protected def createTokenManager(): Option[HadoopDelegationTokenManager] = { - Some(new HadoopDelegationTokenManager(conf, sc.hadoopConfiguration)) + override protected def createTokenManager( + schedulerRef: RpcEndpointRef): Option[HadoopDelegationTokenManager] = { + Some(new HadoopDelegationTokenManager(conf, sc.hadoopConfiguration, schedulerRef)) } private def numExecutors(): Int = { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 8dbdac168f701..1ece7bdc979c7 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -30,6 +30,7 @@ import scala.util.control.NonFatal import org.apache.commons.lang3.{StringUtils => ComStrUtils} import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ @@ -41,7 +42,6 @@ import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.metrics.MetricsSystem @@ -58,6 +58,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be // optimal as more containers are available. Might need to handle this better. + private val appAttemptId = YarnSparkHadoopUtil.getContainerId.getApplicationAttemptId() private val isClusterMode = args.userClass != null private val sparkConf = new SparkConf() @@ -99,25 +100,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } } - private val tokenManager: Option[YARNHadoopDelegationTokenManager] = { - sparkConf.get(KEYTAB).map { _ => - new YARNHadoopDelegationTokenManager(sparkConf, yarnConf) - } - } - - private val ugi = tokenManager match { - case Some(tm) => - // Set the context class loader so that the token renewer has access to jars distributed - // by the user. - Utils.withContextClassLoader(userClassLoader) { - tm.start() - } - - case _ => - SparkHadoopUtil.get.createSparkUser() - } - - private val client = doAsUser { new YarnRMClient() } + private val client = new YarnRMClient() // Default to twice the number of executors (twice the maximum number of executors if dynamic // allocation is enabled), with a minimum of 3. @@ -174,11 +157,19 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends // In cluster mode, used to tell the AM when the user's SparkContext has been initialized. private val sparkContextPromise = Promise[SparkContext]() - // Load the list of localized files set by the client. This is used when launching executors, - // and is loaded here so that these configs don't pollute the Web UI's environment page in - // cluster mode. - private val localResources = doAsUser { + /** + * Load the list of localized files set by the client, used when launching executors. This should + * be called in a context where the needed credentials to access HDFS are available. + */ + private def prepareLocalResources(): Map[String, LocalResource] = { logInfo("Preparing Local resources") + val distCacheConf = new SparkConf(false) + if (args.distCacheConf != null) { + Utils.getPropertiesFromFile(args.distCacheConf).foreach { case (k, v) => + distCacheConf.set(k, v) + } + } + val resources = HashMap[String, LocalResource]() def setupDistributedCache( @@ -199,11 +190,11 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends resources(fileName) = amJarRsrc } - val distFiles = sparkConf.get(CACHED_FILES) - val fileSizes = sparkConf.get(CACHED_FILES_SIZES) - val timeStamps = sparkConf.get(CACHED_FILES_TIMESTAMPS) - val visibilities = sparkConf.get(CACHED_FILES_VISIBILITIES) - val resTypes = sparkConf.get(CACHED_FILES_TYPES) + val distFiles = distCacheConf.get(CACHED_FILES) + val fileSizes = distCacheConf.get(CACHED_FILES_SIZES) + val timeStamps = distCacheConf.get(CACHED_FILES_TIMESTAMPS) + val visibilities = distCacheConf.get(CACHED_FILES_VISIBILITIES) + val resTypes = distCacheConf.get(CACHED_FILES_TYPES) for (i <- 0 to distFiles.size - 1) { val resType = LocalResourceType.valueOf(resTypes(i)) @@ -212,7 +203,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } // Distribute the conf archive to executors. - sparkConf.get(CACHED_CONF_ARCHIVE).foreach { path => + distCacheConf.get(CACHED_CONF_ARCHIVE).foreach { path => val uri = new URI(path) val fs = FileSystem.get(uri, yarnConf) val status = fs.getFileStatus(new Path(uri)) @@ -225,33 +216,12 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends LocalResourceVisibility.PRIVATE.name()) } - // Clean up the configuration so it doesn't show up in the Web UI (since it's really noisy). - CACHE_CONFIGS.foreach { e => - sparkConf.remove(e) - sys.props.remove(e.key) - } - resources.toMap } - def getAttemptId(): ApplicationAttemptId = { - client.getAttemptId() - } - final def run(): Int = { - doAsUser { - runImpl() - } - exitCode - } - - private def runImpl(): Unit = { try { - val appAttemptId = client.getAttemptId() - - var attemptID: Option[String] = None - - if (isClusterMode) { + val attemptID = if (isClusterMode) { // Set the web ui port to be ephemeral for yarn so we don't conflict with // other spark processes running on the same box System.setProperty("spark.ui.port", "0") @@ -264,7 +234,9 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends // configuration will be checked in SparkContext to avoid misuse of yarn cluster mode. System.setProperty("spark.yarn.app.id", appAttemptId.getApplicationId().toString()) - attemptID = Option(appAttemptId.getAttemptId.toString) + Option(appAttemptId.getAttemptId.toString) + } else { + None } new CallerContext( @@ -277,7 +249,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends val priority = ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY - 1 ShutdownHookManager.addShutdownHook(priority) { () => val maxAppAttempts = client.getMaxRegAttempts(sparkConf, yarnConf) - val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts + val isLastAttempt = appAttemptId.getAttemptId() >= maxAppAttempts if (!finished) { // The default state of ApplicationMaster is failed if it is invoked by shut down hook. @@ -322,6 +294,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends logWarning("Exception during stopping of the metric system: ", e) } } + + exitCode } /** @@ -377,9 +351,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends logDebug("shutting down user thread") userClassThread.interrupt() } - if (!inShutdown) { - tokenManager.foreach(_.stop()) - } } } } @@ -405,8 +376,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends port: Int, _sparkConf: SparkConf, uiAddress: Option[String]): Unit = { - val appId = client.getAttemptId().getApplicationId().toString() - val attemptId = client.getAttemptId().getAttemptId().toString() + val appId = appAttemptId.getApplicationId().toString() + val attemptId = appAttemptId.getAttemptId().toString() val historyAddress = ApplicationMaster .getHistoryServerAddress(_sparkConf, yarnConf, appId, attemptId) @@ -415,9 +386,20 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } private def createAllocator(driverRef: RpcEndpointRef, _sparkConf: SparkConf): Unit = { - val appId = client.getAttemptId().getApplicationId().toString() + // In client mode, the AM may be restarting after delegation tokens have reached their TTL. So + // always contact the driver to get the current set of valid tokens, so that local resources can + // be initialized below. + if (!isClusterMode) { + val tokens = driverRef.askSync[Array[Byte]](RetrieveDelegationTokens) + if (tokens != null) { + SparkHadoopUtil.get.addDelegationTokens(tokens, _sparkConf) + } + } + + val appId = appAttemptId.getApplicationId().toString() val driverUrl = RpcEndpointAddress(driverRef.address.host, driverRef.address.port, CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString + val localResources = prepareLocalResources() // Before we initialize the allocator, let's log the information about how executors will // be run up front, to avoid printing this out for every single executor being launched. @@ -433,13 +415,12 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends allocator = client.createAllocator( yarnConf, _sparkConf, + appAttemptId, driverUrl, driverRef, securityMgr, localResources) - tokenManager.foreach(_.setDriverRef(driverRef)) - // Initialize the AM endpoint *after* the allocator has been initialized. This ensures // that when the driver sends an initial executor request (e.g. after an AM restart), // the allocator is ready to service requests. @@ -755,6 +736,9 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends case None => logWarning("Container allocator is not ready to find executor loss reasons yet.") } + + case UpdateDelegationTokens(tokens) => + SparkHadoopUtil.get.addDelegationTokens(tokens, sparkConf) } override def onDisconnected(remoteAddress: RpcAddress): Unit = { @@ -767,12 +751,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } } - private def doAsUser[T](fn: => T): T = { - ugi.doAs(new PrivilegedExceptionAction[T]() { - override def run: T = fn - }) - } - } object ApplicationMaster extends Logging { @@ -793,7 +771,24 @@ object ApplicationMaster extends Logging { SignalUtils.registerLogger(log) val amArgs = new ApplicationMasterArguments(args) master = new ApplicationMaster(amArgs) - System.exit(master.run()) + + val ugi = master.sparkConf.get(PRINCIPAL) match { + case Some(principal) => + val originalCreds = UserGroupInformation.getCurrentUser().getCredentials() + SparkHadoopUtil.get.loginUserFromKeytab(principal, master.sparkConf.get(KEYTAB).orNull) + val newUGI = UserGroupInformation.getCurrentUser() + // Transfer the original user's tokens to the new user, since it may contain needed tokens + // (such as those user to connect to YARN). + newUGI.addCredentials(originalCreds) + newUGI + + case _ => + SparkHadoopUtil.get.createSparkUser() + } + + ugi.doAs(new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = System.exit(master.run()) + }) } private[spark] def sparkContextInitialized(sc: SparkContext): Unit = { @@ -801,7 +796,7 @@ object ApplicationMaster extends Logging { } private[spark] def getAttemptId(): ApplicationAttemptId = { - master.getAttemptId + master.appAttemptId } private[spark] def getHistoryServerAddress( diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index cc76a7c8f13f5..c10206c847271 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -26,6 +26,7 @@ class ApplicationMasterArguments(val args: Array[String]) { var primaryRFile: String = null var userArgs: Seq[String] = Nil var propertiesFile: String = null + var distCacheConf: String = null parseArgs(args.toList) @@ -62,6 +63,10 @@ class ApplicationMasterArguments(val args: Array[String]) { propertiesFile = value args = tail + case ("--dist-cache-conf") :: value :: tail => + distCacheConf = value + args = tail + case _ => printUsageAndExit(1, args) } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 44a60b835f12f..9f09dc0317547 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -100,21 +100,19 @@ private[spark] class Client( } private val distCacheMgr = new ClientDistributedCacheManager() + private val cachedResourcesConf = new SparkConf(false) - private val principal = sparkConf.get(PRINCIPAL).orNull private val keytab = sparkConf.get(KEYTAB).orNull - private val loginFromKeytab = principal != null - private val amKeytabFileName: String = { + private val amKeytabFileName: Option[String] = if (keytab != null && isClusterMode) { + val principal = sparkConf.get(PRINCIPAL).orNull require((principal == null) == (keytab == null), "Both principal and keytab must be defined, or neither.") - if (loginFromKeytab) { - logInfo(s"Kerberos credentials: principal = $principal, keytab = $keytab") - // Generate a file name that can be used for the keytab file, that does not conflict - // with any user file. - new File(keytab).getName() + "-" + UUID.randomUUID().toString - } else { - null - } + logInfo(s"Kerberos credentials: principal = $principal, keytab = $keytab") + // Generate a file name that can be used for the keytab file, that does not conflict + // with any user file. + Some(new File(keytab).getName() + "-" + UUID.randomUUID().toString) + } else { + None } require(keytab == null || !Utils.isLocalUri(keytab), "Keytab should reference a local file.") @@ -220,16 +218,7 @@ private[spark] class Client( } } - if (isClusterMode && principal != null && keytab != null) { - val newUgi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) - newUgi.doAs(new PrivilegedExceptionAction[Unit] { - override def run(): Unit = { - cleanupStagingDirInternal() - } - }) - } else { - cleanupStagingDirInternal() - } + cleanupStagingDirInternal() } /** @@ -312,7 +301,7 @@ private[spark] class Client( */ private def setupSecurityToken(amContainer: ContainerLaunchContext): Unit = { val credentials = UserGroupInformation.getCurrentUser().getCredentials() - val credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf) + val credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf, null) credentialManager.obtainDelegationTokens(credentials) // When using a proxy user, copy the delegation tokens to the user's credentials. Avoid @@ -496,11 +485,11 @@ private[spark] class Client( // If we passed in a keytab, make sure we copy the keytab to the staging directory on // HDFS, and setup the relevant environment vars, so the AM can login again. - if (loginFromKeytab) { + amKeytabFileName.foreach { kt => logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" + " via the YARN Secure Distributed Cache.") val (_, localizedPath) = distribute(keytab, - destName = Some(amKeytabFileName), + destName = Some(kt), appMasterOnly = true) require(localizedPath != null, "Keytab file already distributed.") } @@ -636,7 +625,7 @@ private[spark] class Client( // Update the configuration with all the distributed files, minus the conf archive. The // conf archive will be handled by the AM differently so that we avoid having to send // this configuration by other means. See SPARK-14602 for one reason of why this is needed. - distCacheMgr.updateConfiguration(sparkConf) + distCacheMgr.updateConfiguration(cachedResourcesConf) // Upload the conf archive to HDFS manually, and record its location in the configuration. // This will allow the AM to know where the conf archive is in HDFS, so that it can be @@ -648,7 +637,7 @@ private[spark] class Client( // system. val remoteConfArchivePath = new Path(destDir, LOCALIZED_CONF_ARCHIVE) val remoteFs = FileSystem.get(remoteConfArchivePath.toUri(), hadoopConf) - sparkConf.set(CACHED_CONF_ARCHIVE, remoteConfArchivePath.toString()) + cachedResourcesConf.set(CACHED_CONF_ARCHIVE, remoteConfArchivePath.toString()) val localConfArchive = new Path(createConfArchive().toURI()) copyFileToRemote(destDir, localConfArchive, replication, symlinkCache, force = true, @@ -660,11 +649,6 @@ private[spark] class Client( remoteFs, hadoopConf, remoteConfArchivePath, localResources, LocalResourceType.ARCHIVE, LOCALIZED_CONF_DIR, statCache, appMasterOnly = false) - // Clear the cache-related entries from the configuration to avoid them polluting the - // UI's environment page. This works for client mode; for cluster mode, this is handled - // by the AM. - CACHE_CONFIGS.foreach(sparkConf.remove) - localResources } @@ -768,19 +752,25 @@ private[spark] class Client( hadoopConf.writeXml(confStream) confStream.closeEntry() - // Save Spark configuration to a file in the archive, but filter out the app's secret. - val props = new Properties() - sparkConf.getAll.foreach { case (k, v) => - props.setProperty(k, v) + // Save Spark configuration to a file in the archive. + val props = confToProperties(sparkConf) + + // If propagating the keytab to the AM, override the keytab name with the name of the + // distributed file. Otherwise remove princpal/keytab from the conf, so they're not seen + // by the AM at all. + amKeytabFileName match { + case Some(kt) => + props.setProperty(KEYTAB.key, kt) + case None => + props.remove(PRINCIPAL.key) + props.remove(KEYTAB.key) } - // Override spark.yarn.key to point to the location in distributed cache which will be used - // by AM. - Option(amKeytabFileName).foreach { k => props.setProperty(KEYTAB.key, k) } - confStream.putNextEntry(new ZipEntry(SPARK_CONF_FILE)) - val writer = new OutputStreamWriter(confStream, StandardCharsets.UTF_8) - props.store(writer, "Spark configuration.") - writer.flush() - confStream.closeEntry() + + writePropertiesToArchive(props, SPARK_CONF_FILE, confStream) + + // Write the distributed cache config to the archive. + writePropertiesToArchive(confToProperties(cachedResourcesConf), DIST_CACHE_CONF_FILE, + confStream) } finally { confStream.close() } @@ -984,7 +974,10 @@ private[spark] class Client( } val amArgs = Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ primaryRFile ++ userArgs ++ - Seq("--properties-file", buildPath(Environment.PWD.$$(), LOCALIZED_CONF_DIR, SPARK_CONF_FILE)) + Seq("--properties-file", + buildPath(Environment.PWD.$$(), LOCALIZED_CONF_DIR, SPARK_CONF_FILE)) ++ + Seq("--dist-cache-conf", + buildPath(Environment.PWD.$$(), LOCALIZED_CONF_DIR, DIST_CACHE_CONF_FILE)) // Command for the ApplicationMaster val commands = prefixEnv ++ @@ -1213,6 +1206,9 @@ private object Client extends Logging { // Name of the file in the conf archive containing Spark configuration. val SPARK_CONF_FILE = "__spark_conf__.properties" + // Name of the file in the conf archive containing the distributed cache info. + val DIST_CACHE_CONF_FILE = "__spark_dist_cache__.properties" + // Subdirectory where the user's python files (not archives) will be placed. val LOCALIZED_PYTHON_DIR = "__pyfiles__" @@ -1512,6 +1508,22 @@ private object Client extends Logging { } getClusterPath(conf, cmdPrefix) } + + def confToProperties(conf: SparkConf): Properties = { + val props = new Properties() + conf.getAll.foreach { case (k, v) => + props.setProperty(k, v) + } + props + } + + def writePropertiesToArchive(props: Properties, name: String, out: ZipOutputStream): Unit = { + out.putNextEntry(new ZipEntry(name)) + val writer = new OutputStreamWriter(out, StandardCharsets.UTF_8) + props.store(writer, "Spark configuration.") + writer.flush() + out.closeEntry() + } } private[spark] class YarnClusterApplication extends SparkApplication { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 05a7b1e1310c4..cf16edf16c034 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -76,12 +76,13 @@ private[spark] class YarnRMClient extends Logging { def createAllocator( conf: YarnConfiguration, sparkConf: SparkConf, + appAttemptId: ApplicationAttemptId, driverUrl: String, driverRef: RpcEndpointRef, securityMgr: SecurityManager, localResources: Map[String, LocalResource]): YarnAllocator = { require(registered, "Must register AM before creating allocator.") - new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr, + new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, appAttemptId, securityMgr, localResources, new SparkRackResolver()) } @@ -100,11 +101,6 @@ private[spark] class YarnRMClient extends Logging { } } - /** Returns the attempt ID. */ - def getAttemptId(): ApplicationAttemptId = { - YarnSparkHadoopUtil.getContainerId.getApplicationAttemptId() - } - /** Returns the configuration for the AmIpFilter to add to the Spark UI. */ def getAmIpFilterParams(conf: YarnConfiguration, proxyBase: String): Map[String, String] = { // Figure out which scheme Yarn is using. Note the method seems to have been added after 2.2, diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 7e9cd409daf36..6091cd496c037 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -321,16 +321,6 @@ package object config { .stringConf .createOptional - // The list of cache-related config entries. This is used by Client and the AM to clean - // up the environment so that these settings do not appear on the web UI. - private[yarn] val CACHE_CONFIGS = Seq( - CACHED_FILES, - CACHED_FILES_SIZES, - CACHED_FILES_TIMESTAMPS, - CACHED_FILES_VISIBILITIES, - CACHED_FILES_TYPES, - CACHED_CONF_ARCHIVE) - /* YARN allocator-level blacklisting related config entries. */ private[spark] val YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED = ConfigBuilder("spark.yarn.blacklist.executor.launch.blacklisting.enabled") diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala index 2d9a3f0c83fd2..bb40ea8015198 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala @@ -36,10 +36,11 @@ import org.apache.spark.util.Utils * [[ServiceCredentialProvider]] interface, as well as the builtin providers defined * in [[HadoopDelegationTokenManager]]. */ -private[yarn] class YARNHadoopDelegationTokenManager( +private[spark] class YARNHadoopDelegationTokenManager( _sparkConf: SparkConf, - _hadoopConf: Configuration) - extends HadoopDelegationTokenManager(_sparkConf, _hadoopConf) { + _hadoopConf: Configuration, + _schedulerRef: RpcEndpointRef) + extends HadoopDelegationTokenManager(_sparkConf, _hadoopConf, _schedulerRef) { private val credentialProviders = { ServiceLoader.load(classOf[ServiceCredentialProvider], Utils.getContextOrSparkClassLoader) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 167eef19ed856..934fba3e6ff35 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -27,6 +27,7 @@ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.{config, Logging} import org.apache.spark.launcher.SparkAppHandle import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ private[spark] class YarnClientSchedulerBackend( scheduler: TaskSchedulerImpl, @@ -166,4 +167,9 @@ private[spark] class YarnClientSchedulerBackend( logInfo("Stopped") } + override protected def updateDelegationTokens(tokens: Array[Byte]): Unit = { + super.updateDelegationTokens(tokens) + amEndpoint.foreach(_.send(UpdateDelegationTokens(tokens))) + } + } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 1289d4be79ea4..6357d4adbcd99 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -27,6 +27,8 @@ import scala.util.control.NonFatal import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} import org.apache.spark.SparkContext +import org.apache.spark.deploy.security.HadoopDelegationTokenManager +import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager import org.apache.spark.internal.Logging import org.apache.spark.rpc._ import org.apache.spark.scheduler._ @@ -55,6 +57,7 @@ private[spark] abstract class YarnSchedulerBackend( protected var totalExpectedExecutors = 0 private val yarnSchedulerEndpoint = new YarnSchedulerEndpoint(rpcEnv) + protected var amEndpoint: Option[RpcEndpointRef] = None private val yarnSchedulerEndpointRef = rpcEnv.setupEndpoint( YarnSchedulerBackend.ENDPOINT_NAME, yarnSchedulerEndpoint) @@ -191,6 +194,11 @@ private[spark] abstract class YarnSchedulerBackend( sc.executorAllocationManager.foreach(_.reset()) } + override protected def createTokenManager( + schedulerRef: RpcEndpointRef): Option[HadoopDelegationTokenManager] = { + Some(new YARNHadoopDelegationTokenManager(sc.conf, sc.hadoopConfiguration, schedulerRef)) + } + /** * Override the DriverEndpoint to add extra logic for the case when an executor is disconnected. * This endpoint communicates with the executors and queries the AM for an executor's exit @@ -226,7 +234,6 @@ private[spark] abstract class YarnSchedulerBackend( */ private class YarnSchedulerEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { - private var amEndpoint: Option[RpcEndpointRef] = None private[YarnSchedulerBackend] def handleExecutorDisconnectedFromDriver( executorId: String, @@ -266,11 +273,6 @@ private[spark] abstract class YarnSchedulerBackend( logWarning(s"Requesting driver to remove executor $executorId for reason $reason") driverEndpoint.send(r) } - - case u @ UpdateDelegationTokens(tokens) => - // Add the tokens to the current user and send a message to the scheduler so that it - // notifies all registered executors of the new tokens. - driverEndpoint.send(u) } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -304,6 +306,9 @@ private[spark] abstract class YarnSchedulerBackend( case RetrieveLastAllocatedExecutorId => context.reply(currentExecutorIdCounter) + + case RetrieveDelegationTokens => + context.reply(currentDelegationTokens) } override def onDisconnected(remoteAddress: RpcAddress): Unit = { diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala index 98315e4235741..f00453cb9c597 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala @@ -34,7 +34,7 @@ class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite { } test("Correctly loads credential providers") { - credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf) + credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf, null) assert(credentialManager.isProviderLoaded("yarn-test")) } } From 98be8953c75c026c1cb432cc8f66dd312feed0c6 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Mon, 7 Jan 2019 13:59:40 -0800 Subject: [PATCH 2394/2461] [SPARK-26065][SQL] Change query hint from a `LogicalPlan` to a field ## What changes were proposed in this pull request? The existing query hint implementation relies on a logical plan node `ResolvedHint` to store query hints in logical plans, and on `Statistics` in physical plans. Since `ResolvedHint` is not really a logical operator and can break the pattern matching for existing and future optimization rules, it is a issue to the Optimizer as the old `AnalysisBarrier` was to the Analyzer. Given the fact that all our query hints are either 1) a join hint, i.e., broadcast hint; or 2) a re-partition hint, which is indeed an operator, we only need to add a hint field on the Join plan and that will be a good enough solution for the current hint usage. This PR is to let `Join` node have a hint for its left sub-tree and another hint for its right sub-tree and each hint is a merged result of all the effective hints specified in the corresponding sub-tree. The "effectiveness" of a hint, i.e., whether that hint should be propagated to the `Join` node, is currently consistent with the hint propagation rules originally implemented in the `Statistics` approach. Note that the `ResolvedHint` node still has to live through the analysis stage because of the `Dataset` interface, but it will be got rid of and moved to the `Join` node in the "pre-optimization" stage. This PR also introduces a change in how hints work with join reordering. Before this PR, hints would stop join reordering. For example, in "a.join(b).join(c).hint("broadcast").join(d)", the broadcast hint would stop d from participating in the cost-based join reordering while still allowing reordering from under the hint node. After this PR, though, the broadcast hint will not interfere with join reordering at all, and after reordering if a relation associated with a hint stays unchanged or equivalent to the original relation, the hint will be retained, otherwise will be discarded. For example, the original plan is like "a.join(b).hint("broadcast").join(c).hint("broadcast").join(d)", thus the join order is "a JOIN b JOIN c JOIN d". So if after reordering the join order becomes "a JOIN b JOIN (c JOIN d)", the plan will be like "a.join(b).hint("broadcast").join(c.join(d))"; but if after reordering the join order becomes "a JOIN c JOIN b JOIN d", the plan will be like "a.join(c).join(b).hint("broadcast").join(d)". ## How was this patch tested? Added new tests. Closes #23036 from maryannxue/query-hint. Authored-by: maryannxue Signed-off-by: gatorsmile --- .../sql/catalyst/analysis/Analyzer.scala | 16 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 4 +- .../analysis/StreamingJoinHelper.scala | 2 +- .../UnsupportedOperationChecker.scala | 2 +- .../spark/sql/catalyst/dsl/package.scala | 2 +- .../optimizer/CostBasedJoinReorder.scala | 84 ++++++-- .../optimizer/EliminateResolvedHint.scala | 59 ++++++ .../sql/catalyst/optimizer/Optimizer.scala | 36 ++-- .../optimizer/PropagateEmptyRelation.scala | 2 +- .../ReplaceNullWithFalseInPredicate.scala | 2 +- .../sql/catalyst/optimizer/expressions.scala | 3 +- .../spark/sql/catalyst/optimizer/joins.scala | 27 ++- .../sql/catalyst/optimizer/subquery.scala | 14 +- .../sql/catalyst/parser/AstBuilder.scala | 4 +- .../sql/catalyst/planning/patterns.scala | 41 ++-- .../plans/logical/LogicalPlanVisitor.scala | 3 - .../catalyst/plans/logical/Statistics.scala | 7 +- .../plans/logical/basicLogicalOperators.scala | 14 +- .../sql/catalyst/plans/logical/hints.scala | 27 ++- .../statsEstimation/AggregateEstimation.scala | 3 +- .../BasicStatsPlanVisitor.scala | 2 - .../statsEstimation/JoinEstimation.scala | 2 +- .../SizeInBytesOnlyStatsPlanVisitor.scala | 22 +- .../analysis/AnalysisErrorSuite.scala | 8 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 4 +- .../catalyst/analysis/ResolveHintsSuite.scala | 2 +- .../optimizer/ColumnPruningSuite.scala | 6 +- .../optimizer/FilterPushdownSuite.scala | 14 -- .../optimizer/JoinOptimizationSuite.scala | 28 +-- .../catalyst/optimizer/JoinReorderSuite.scala | 83 +++++++- .../optimizer/ReplaceOperatorSuite.scala | 8 +- .../spark/sql/catalyst/plans/PlanTest.scala | 10 +- .../sql/catalyst/plans/SameResultSuite.scala | 16 +- .../BasicStatsEstimationSuite.scala | 18 -- .../FilterEstimationSuite.scala | 2 +- .../statsEstimation/JoinEstimationSuite.scala | 31 +-- .../scala/org/apache/spark/sql/Dataset.scala | 16 +- .../spark/sql/execution/SparkStrategies.scala | 52 ++--- .../execution/columnar/InMemoryRelation.scala | 9 +- .../apache/spark/sql/CachedTableSuite.scala | 21 ++ .../apache/spark/sql/DataFrameJoinSuite.scala | 10 +- .../org/apache/spark/sql/JoinHintSuite.scala | 193 ++++++++++++++++++ .../spark/sql/StatisticsCollectionSuite.scala | 3 +- .../execution/joins/BroadcastJoinSuite.scala | 14 +- .../execution/joins/ExistenceJoinSuite.scala | 11 +- .../sql/execution/joins/InnerJoinSuite.scala | 15 +- .../sql/execution/joins/OuterJoinSuite.scala | 11 +- 47 files changed, 680 insertions(+), 283 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 198645d875c47..2aa0f2117364c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -943,7 +943,7 @@ class Analyzer( failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF") // To resolve duplicate expression IDs for Join and Intersect - case j @ Join(left, right, _, _) if !j.duplicateResolved => + case j @ Join(left, right, _, _, _) if !j.duplicateResolved => j.copy(right = dedupRight(left, right)) case i @ Intersect(left, right, _) if !i.duplicateResolved => i.copy(right = dedupRight(left, right)) @@ -2249,13 +2249,14 @@ class Analyzer( */ object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { - case j @ Join(left, right, UsingJoin(joinType, usingCols), _) + case j @ Join(left, right, UsingJoin(joinType, usingCols), _, hint) if left.resolved && right.resolved && j.duplicateResolved => - commonNaturalJoinProcessing(left, right, joinType, usingCols, None) - case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural => + commonNaturalJoinProcessing(left, right, joinType, usingCols, None, hint) + case j @ Join(left, right, NaturalJoin(joinType), condition, hint) + if j.resolvedExceptNatural => // find common column names from both sides val joinNames = left.output.map(_.name).intersect(right.output.map(_.name)) - commonNaturalJoinProcessing(left, right, joinType, joinNames, condition) + commonNaturalJoinProcessing(left, right, joinType, joinNames, condition, hint) } } @@ -2360,7 +2361,8 @@ class Analyzer( right: LogicalPlan, joinType: JoinType, joinNames: Seq[String], - condition: Option[Expression]) = { + condition: Option[Expression], + hint: JoinHint) = { val leftKeys = joinNames.map { keyName => left.output.find(attr => resolver(attr.name, keyName)).getOrElse { throw new AnalysisException(s"USING column `$keyName` cannot be resolved on the left " + @@ -2401,7 +2403,7 @@ class Analyzer( sys.error("Unsupported natural join type " + joinType) } // use Project to trim unnecessary fields - Project(projectList, Join(left, right, joinType, newCondition)) + Project(projectList, Join(left, right, joinType, newCondition, hint)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index c28a97839fe49..18c40b370cb5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -172,7 +172,7 @@ trait CheckAnalysis extends PredicateHelper { failAnalysis("Null-aware predicate sub-queries cannot be used in nested " + s"conditions: $condition") - case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => + case j @ Join(_, _, _, Some(condition), _) if condition.dataType != BooleanType => failAnalysis( s"join condition '${condition.sql}' " + s"of type ${condition.dataType.catalogString} is not a boolean.") @@ -609,7 +609,7 @@ trait CheckAnalysis extends PredicateHelper { failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) // Join can host correlated expressions. - case j @ Join(left, right, joinType, _) => + case j @ Join(left, right, joinType, _, _) => joinType match { // Inner join, like Filter, can be anywhere. case _: InnerLike => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala index 7a0aa08289efa..76733dd6dac3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala @@ -41,7 +41,7 @@ object StreamingJoinHelper extends PredicateHelper with Logging { */ def isWatermarkInJoinKeys(plan: LogicalPlan): Boolean = { plan match { - case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _) => + case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _, _) => (leftKeys ++ rightKeys).exists { case a: AttributeReference => a.metadata.contains(EventTimeWatermark.delayKey) case _ => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index cff4cee09427f..41ba6d34b5499 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -229,7 +229,7 @@ object UnsupportedOperationChecker { throwError("dropDuplicates is not supported after aggregation on a " + "streaming DataFrame/Dataset") - case Join(left, right, joinType, condition) => + case Join(left, right, joinType, condition, _) => joinType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 151481c80ee96..846ee3b386527 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -325,7 +325,7 @@ package object dsl { otherPlan: LogicalPlan, joinType: JoinType = Inner, condition: Option[Expression] = None): LogicalPlan = - Join(logicalPlan, otherPlan, joinType, condition) + Join(logicalPlan, otherPlan, joinType, condition, JoinHint.NONE) def cogroup[Key: Encoder, Left: Encoder, Right: Encoder, Result: Encoder]( otherPlan: LogicalPlan, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index 01634a9d852c6..743d3ce944fe2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, Expression, PredicateHelper} import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike, JoinType} -import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, Join, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf @@ -31,6 +31,40 @@ import org.apache.spark.sql.internal.SQLConf * Cost-based join reorder. * We may have several join reorder algorithms in the future. This class is the entry of these * algorithms, and chooses which one to use. + * + * Note that join strategy hints, e.g. the broadcast hint, do not interfere with the reordering. + * Such hints will be applied on the equivalent counterparts (i.e., join between the same relations + * regardless of the join order) of the original nodes after reordering. + * For example, the plan before reordering is like: + * + * Join + * / \ + * Hint1 t4 + * / + * Join + * / \ + * Join t3 + * / \ + * Hint2 t2 + * / + * t1 + * + * The original join order as illustrated above is "((t1 JOIN t2) JOIN t3) JOIN t4", and after + * reordering, the new join order is "((t1 JOIN t3) JOIN t2) JOIN t4", so the new plan will be like: + * + * Join + * / \ + * Hint1 t4 + * / + * Join + * / \ + * Join t2 + * / \ + * t1 t3 + * + * "Hint1" is applied on "(t1 JOIN t3) JOIN t2" as it is equivalent to the original hinted node, + * "(t1 JOIN t2) JOIN t3"; while "Hint2" has disappeared from the new plan since there is no + * equivalent node to "t1 JOIN t2". */ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { @@ -40,24 +74,30 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { if (!conf.cboEnabled || !conf.joinReorderEnabled) { plan } else { + // Use a map to track the hints on the join items. + val hintMap = new mutable.HashMap[AttributeSet, HintInfo] val result = plan transformDown { // Start reordering with a joinable item, which is an InnerLike join with conditions. - case j @ Join(_, _, _: InnerLike, Some(cond)) => - reorder(j, j.output) - case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond))) + case j @ Join(_, _, _: InnerLike, Some(cond), _) => + reorder(j, j.output, hintMap) + case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), _)) if projectList.forall(_.isInstanceOf[Attribute]) => - reorder(p, p.output) + reorder(p, p.output, hintMap) } - - // After reordering is finished, convert OrderedJoin back to Join - result transformDown { - case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond) + // After reordering is finished, convert OrderedJoin back to Join. + result transform { + case OrderedJoin(left, right, jt, cond) => + val joinHint = JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet)) + Join(left, right, jt, cond, joinHint) } } } - private def reorder(plan: LogicalPlan, output: Seq[Attribute]): LogicalPlan = { - val (items, conditions) = extractInnerJoins(plan) + private def reorder( + plan: LogicalPlan, + output: Seq[Attribute], + hintMap: mutable.HashMap[AttributeSet, HintInfo]): LogicalPlan = { + val (items, conditions) = extractInnerJoins(plan, hintMap) val result = // Do reordering if the number of items is appropriate and join conditions exist. // We also need to check if costs of all items can be evaluated. @@ -75,27 +115,31 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { * Extracts items of consecutive inner joins and join conditions. * This method works for bushy trees and left/right deep trees. */ - private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = { + private def extractInnerJoins( + plan: LogicalPlan, + hintMap: mutable.HashMap[AttributeSet, HintInfo]): (Seq[LogicalPlan], Set[Expression]) = { plan match { - case Join(left, right, _: InnerLike, Some(cond)) => - val (leftPlans, leftConditions) = extractInnerJoins(left) - val (rightPlans, rightConditions) = extractInnerJoins(right) + case Join(left, right, _: InnerLike, Some(cond), hint) => + hint.leftHint.foreach(hintMap.put(left.outputSet, _)) + hint.rightHint.foreach(hintMap.put(right.outputSet, _)) + val (leftPlans, leftConditions) = extractInnerJoins(left, hintMap) + val (rightPlans, rightConditions) = extractInnerJoins(right, hintMap) (leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++ leftConditions ++ rightConditions) - case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond))) + case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _)) if projectList.forall(_.isInstanceOf[Attribute]) => - extractInnerJoins(j) + extractInnerJoins(j, hintMap) case _ => (Seq(plan), Set()) } } private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match { - case j @ Join(left, right, jt: InnerLike, Some(cond)) => + case j @ Join(left, right, jt: InnerLike, Some(cond), _) => val replacedLeft = replaceWithOrderedJoin(left) val replacedRight = replaceWithOrderedJoin(right) OrderedJoin(replacedLeft, replacedRight, jt, Some(cond)) - case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond))) => + case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _)) => p.copy(child = replaceWithOrderedJoin(j)) case _ => plan @@ -295,7 +339,7 @@ object JoinReorderDP extends PredicateHelper with Logging { } else { (otherPlan, onePlan) } - val newJoin = Join(left, right, Inner, joinConds.reduceOption(And)) + val newJoin = Join(left, right, Inner, joinConds.reduceOption(And), JoinHint.NONE) val collectedJoinConds = joinConds ++ oneJoinPlan.joinConds ++ otherJoinPlan.joinConds val remainingConds = conditions -- collectedJoinConds val neededAttr = AttributeSet(remainingConds.flatMap(_.references)) ++ topOutput diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala new file mode 100644 index 0000000000000..bbe4eee4b4326 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * Replaces [[ResolvedHint]] operators from the plan. Move the [[HintInfo]] to associated [[Join]] + * operators, otherwise remove it if no [[Join]] operator is matched. + */ +object EliminateResolvedHint extends Rule[LogicalPlan] { + // This is also called in the beginning of the optimization phase, and as a result + // is using transformUp rather than resolveOperators. + def apply(plan: LogicalPlan): LogicalPlan = { + val pulledUp = plan transformUp { + case j: Join => + val leftHint = mergeHints(collectHints(j.left)) + val rightHint = mergeHints(collectHints(j.right)) + j.copy(hint = JoinHint(leftHint, rightHint)) + } + pulledUp.transform { + case h: ResolvedHint => h.child + } + } + + private def mergeHints(hints: Seq[HintInfo]): Option[HintInfo] = { + hints.reduceOption((h1, h2) => HintInfo( + broadcast = h1.broadcast || h2.broadcast)) + } + + private def collectHints(plan: LogicalPlan): Seq[HintInfo] = { + plan match { + case h: ResolvedHint => collectHints(h.child) :+ h.hints + case u: UnaryNode => collectHints(u.child) + // TODO revisit this logic: + // except and intersect are semi/anti-joins which won't return more data then + // their left argument, so the broadcast hint should be propagated here + case i: Intersect => collectHints(i.left) + case e: Except => collectHints(e.left) + case _ => Seq.empty + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 44d5543114902..06f908281dd3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -115,6 +115,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) // However, because we also use the analyzer to canonicalized queries (for view definition), // we do not eliminate subqueries or compute current time in the analyzer. Batch("Finish Analysis", Once, + EliminateResolvedHint, EliminateSubqueryAliases, EliminateView, ReplaceExpressions, @@ -192,6 +193,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) */ def nonExcludableRules: Seq[String] = EliminateDistinct.ruleName :: + EliminateResolvedHint.ruleName :: EliminateSubqueryAliases.ruleName :: EliminateView.ruleName :: ReplaceExpressions.ruleName :: @@ -356,7 +358,7 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { // not allowed to use the same attributes. We use a blacklist to prevent us from creating a // situation in which this happens; the rule will only remove an alias if its child // attribute is not on the black list. - case Join(left, right, joinType, condition) => + case Join(left, right, joinType, condition, hint) => val newLeft = removeRedundantAliases(left, blacklist ++ right.outputSet) val newRight = removeRedundantAliases(right, blacklist ++ newLeft.outputSet) val mapping = AttributeMap( @@ -365,7 +367,7 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { val newCondition = condition.map(_.transform { case a: Attribute => mapping.getOrElse(a, a) }) - Join(newLeft, newRight, joinType, newCondition) + Join(newLeft, newRight, joinType, newCondition, hint) case _ => // Remove redundant aliases in the subtree(s). @@ -460,7 +462,7 @@ object LimitPushDown extends Rule[LogicalPlan] { // on both sides if it is applied multiple times. Therefore: // - If one side is already limited, stack another limit on top if the new limit is smaller. // The redundant limit will be collapsed by the CombineLimits rule. - case LocalLimit(exp, join @ Join(left, right, joinType, _)) => + case LocalLimit(exp, join @ Join(left, right, joinType, _, _)) => val newJoin = joinType match { case RightOuter => join.copy(right = maybePushLocalLimit(exp, right)) case LeftOuter => join.copy(left = maybePushLocalLimit(exp, left)) @@ -578,7 +580,7 @@ object ColumnPruning extends Rule[LogicalPlan] { p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices)) // Eliminate unneeded attributes from right side of a Left Existence Join. - case j @ Join(_, right, LeftExistence(_), _) => + case j @ Join(_, right, LeftExistence(_), _, _) => j.copy(right = prunedChild(right, j.references)) // all the columns will be used to compare, so we can't prune them @@ -792,7 +794,7 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] filter } - case join @ Join(left, right, joinType, conditionOpt) => + case join @ Join(left, right, joinType, conditionOpt, _) => joinType match { // For inner join, we can infer additional filters for both sides. LeftSemi is kind of an // inner join, it just drops the right side in the final output. @@ -919,7 +921,6 @@ object RemoveRedundantSorts extends Rule[LogicalPlan] { def canEliminateSort(plan: LogicalPlan): Boolean = plan match { case p: Project => p.projectList.forall(_.deterministic) case f: Filter => f.condition.deterministic - case _: ResolvedHint => true case _ => false } } @@ -1094,7 +1095,6 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // Note that some operators (e.g. project, aggregate, union) are being handled separately // (earlier in this rule). case _: AppendColumns => true - case _: ResolvedHint => true case _: Distinct => true case _: Generate => true case _: Pivot => true @@ -1179,7 +1179,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // push the where condition down into join filter - case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) => + case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition, hint)) => val (leftFilterConditions, rightFilterConditions, commonFilterCondition) = split(splitConjunctivePredicates(filterCondition), left, right) joinType match { @@ -1193,7 +1193,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { commonFilterCondition.partition(canEvaluateWithinJoin) val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And) - val join = Join(newLeft, newRight, joinType, newJoinCond) + val join = Join(newLeft, newRight, joinType, newJoinCond, hint) if (others.nonEmpty) { Filter(others.reduceLeft(And), join) } else { @@ -1205,7 +1205,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { val newRight = rightFilterConditions. reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val newJoinCond = joinCondition - val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond) + val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond, hint) (leftFilterConditions ++ commonFilterCondition). reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) @@ -1215,7 +1215,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) val newRight = right val newJoinCond = joinCondition - val newJoin = Join(newLeft, newRight, joinType, newJoinCond) + val newJoin = Join(newLeft, newRight, joinType, newJoinCond, hint) (rightFilterConditions ++ commonFilterCondition). reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) @@ -1225,7 +1225,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { } // push down the join filter into sub query scanning if applicable - case j @ Join(left, right, joinType, joinCondition) => + case j @ Join(left, right, joinType, joinCondition, hint) => val (leftJoinConditions, rightJoinConditions, commonJoinCondition) = split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right) @@ -1238,7 +1238,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val newJoinCond = commonJoinCondition.reduceLeftOption(And) - Join(newLeft, newRight, joinType, newJoinCond) + Join(newLeft, newRight, joinType, newJoinCond, hint) case RightOuter => // push down the left side only join filter for left side sub query val newLeft = leftJoinConditions. @@ -1246,7 +1246,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { val newRight = right val newJoinCond = (rightJoinConditions ++ commonJoinCondition).reduceLeftOption(And) - Join(newLeft, newRight, RightOuter, newJoinCond) + Join(newLeft, newRight, RightOuter, newJoinCond, hint) case LeftOuter | LeftAnti | ExistenceJoin(_) => // push down the right side only join filter for right sub query val newLeft = left @@ -1254,7 +1254,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And) - Join(newLeft, newRight, joinType, newJoinCond) + Join(newLeft, newRight, joinType, newJoinCond, hint) case FullOuter => j case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node") case UsingJoin(_, _) => sys.error("Untransformed Using join node") @@ -1310,7 +1310,7 @@ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper { if (SQLConf.get.crossJoinEnabled) { plan } else plan transform { - case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _) + case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _, _) if isCartesianProduct(j) => throw new AnalysisException( s"""Detected implicit cartesian product for ${j.joinType.sql} join between logical plans @@ -1449,7 +1449,7 @@ object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] { case Intersect(left, right, false) => assert(left.output.size == right.output.size) val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) } - Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And))) + Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And), JoinHint.NONE)) } } @@ -1470,7 +1470,7 @@ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] { case Except(left, right, false) => assert(left.output.size == right.output.size) val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) } - Distinct(Join(left, right, LeftAnti, joinCond.reduceLeftOption(And))) + Distinct(Join(left, right, LeftAnti, joinCond.reduceLeftOption(And), JoinHint.NONE)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index c3fdb924243df..b19e13870aa65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -56,7 +56,7 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper wit // Joins on empty LocalRelations generated from streaming sources are not eliminated // as stateful streaming joins need to perform other state management operations other than // just processing the input data. - case p @ Join(_, _, joinType, _) + case p @ Join(_, _, joinType, _, _) if !p.children.exists(_.isStreaming) => val isLeftEmpty = isEmptyLocalRelation(p.left) val isRightEmpty = isEmptyLocalRelation(p.right) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala index 72a60f692ac78..689915a985343 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -52,7 +52,7 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) - case j @ Join(_, _, _, Some(cond)) => j.copy(condition = Some(replaceNullWithFalse(cond))) + case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond))) case p: LogicalPlan => p transformExpressions { case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred)) case cw @ CaseWhen(branches, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 468a950fb1087..39709529c00d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -600,7 +600,7 @@ object FoldablePropagation extends Rule[LogicalPlan] { // propagating the foldable expressions. // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes // of outer join. - case j @ Join(left, right, joinType, _) if foldableMap.nonEmpty => + case j @ Join(left, right, joinType, _, _) if foldableMap.nonEmpty => val newJoin = j.transformExpressions(replaceFoldable) val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match { case _: InnerLike | LeftExistence(_) => Nil @@ -648,7 +648,6 @@ object FoldablePropagation extends Rule[LogicalPlan] { case _: Distinct => true case _: AppendColumns => true case _: AppendColumnsWithObject => true - case _: ResolvedHint => true case _: RepartitionByExpression => true case _: Repartition => true case _: Sort => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 0b6471289a471..82aefca8a1af6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -43,10 +43,13 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { * * @param input a list of LogicalPlans to inner join and the type of inner join. * @param conditions a list of condition for join. + * @param hintMap a map of relation output attribute sets to their corresponding hints. */ @tailrec - final def createOrderedJoin(input: Seq[(LogicalPlan, InnerLike)], conditions: Seq[Expression]) - : LogicalPlan = { + final def createOrderedJoin( + input: Seq[(LogicalPlan, InnerLike)], + conditions: Seq[Expression], + hintMap: Map[AttributeSet, HintInfo]): LogicalPlan = { assert(input.size >= 2) if (input.size == 2) { val (joinConditions, others) = conditions.partition(canEvaluateWithinJoin) @@ -55,7 +58,8 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { case (Inner, Inner) => Inner case (_, _) => Cross } - val join = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And)) + val join = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And), + JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet))) if (others.nonEmpty) { Filter(others.reduceLeft(And), join) } else { @@ -78,26 +82,27 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { val joinedRefs = left.outputSet ++ right.outputSet val (joinConditions, others) = conditions.partition( e => e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e)) - val joined = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And)) + val joined = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And), + JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet))) // should not have reference to same logical plan - createOrderedJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), others) + createOrderedJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), others, hintMap) } } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case p @ ExtractFiltersAndInnerJoins(input, conditions) + case p @ ExtractFiltersAndInnerJoins(input, conditions, hintMap) if input.size > 2 && conditions.nonEmpty => val reordered = if (SQLConf.get.starSchemaDetection && !SQLConf.get.cboEnabled) { val starJoinPlan = StarSchemaDetection.reorderStarJoins(input, conditions) if (starJoinPlan.nonEmpty) { val rest = input.filterNot(starJoinPlan.contains(_)) - createOrderedJoin(starJoinPlan ++ rest, conditions) + createOrderedJoin(starJoinPlan ++ rest, conditions, hintMap) } else { - createOrderedJoin(input, conditions) + createOrderedJoin(input, conditions, hintMap) } } else { - createOrderedJoin(input, conditions) + createOrderedJoin(input, conditions, hintMap) } if (p.sameOutput(reordered)) { @@ -156,7 +161,7 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter | FullOuter, _)) => + case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter | FullOuter, _, _)) => val newJoinType = buildNewJoinType(f, j) if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType)) } @@ -176,7 +181,7 @@ object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateH } override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case j @ Join(_, _, joinType, Some(cond)) if hasUnevaluablePythonUDF(cond, j) => + case j @ Join(_, _, joinType, Some(cond), _) if hasUnevaluablePythonUDF(cond, j) => if (!joinType.isInstanceOf[InnerLike] && joinType != LeftSemi) { // The current strategy only support InnerLike and LeftSemi join because for other type, // it breaks SQL semantic if we run the join condition as a filter after join. If we pass diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 34840c6c977a6..e78ed1c3c5d94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -51,7 +51,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { condition: Option[Expression]): Join = { // Deduplicate conflicting attributes if any. val dedupSubplan = dedupSubqueryOnSelfJoin(outerPlan, subplan, None, condition) - Join(outerPlan, dedupSubplan, joinType, condition) + Join(outerPlan, dedupSubplan, joinType, condition, JoinHint.NONE) } private def dedupSubqueryOnSelfJoin( @@ -116,7 +116,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values)) val inConditions = values.zip(newSub.output).map(EqualTo.tupled) val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) - Join(outerPlan, newSub, LeftSemi, joinCond) + Join(outerPlan, newSub, LeftSemi, joinCond, JoinHint.NONE) case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. A NULL in one of the conditions is regarded as a positive @@ -142,7 +142,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // will have the final conditions in the LEFT ANTI as // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 1 val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And) - Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond)) + Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond), JoinHint.NONE) case (p, predicate) => val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p) Project(p.output, Filter(newCond.get, inputPlan)) @@ -172,7 +172,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values)) val inConditions = values.zip(newSub.output).map(EqualTo.tupled) val newConditions = (inConditions ++ conditions).reduceLeftOption(And) - newPlan = Join(newPlan, newSub, ExistenceJoin(exists), newConditions) + newPlan = Join(newPlan, newSub, ExistenceJoin(exists), newConditions, JoinHint.NONE) exists } } @@ -450,7 +450,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { // CASE 1: Subquery guaranteed not to have the COUNT bug Project( currentChild.output :+ origOutput, - Join(currentChild, query, LeftOuter, conditions.reduceOption(And))) + Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) } else { // Subquery might have the COUNT bug. Add appropriate corrections. val (topPart, havingNode, aggNode) = splitSubquery(query) @@ -477,7 +477,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { aggValRef), origOutput.name)(exprId = origOutput.exprId), Join(currentChild, Project(query.output :+ alwaysTrueExpr, query), - LeftOuter, conditions.reduceOption(And))) + LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) } else { // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join. @@ -507,7 +507,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { currentChild.output :+ caseExpr, Join(currentChild, Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot), - LeftOuter, conditions.reduceOption(And))) + LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 8959f78b656d2..a27c6d3c3671c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -515,7 +515,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) { val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) => val right = plan(relation.relationPrimary) - val join = right.optionalMap(left)(Join(_, _, Inner, None)) + val join = right.optionalMap(left)(Join(_, _, Inner, None, JoinHint.NONE)) withJoinRelations(join, relation) } if (ctx.pivotClause() != null) { @@ -727,7 +727,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case None => (baseJoinType, None) } - Join(left, plan(join.right), joinType, condition) + Join(left, plan(join.right), joinType, condition, JoinHint.NONE) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 84be677e438a6..dfc3b2d22129d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.planning +import scala.collection.mutable + import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ @@ -98,12 +100,13 @@ object PhysicalOperation extends PredicateHelper { * value). */ object ExtractEquiJoinKeys extends Logging with PredicateHelper { - /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */ + /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild, joinHint) */ type ReturnType = - (JoinType, Seq[Expression], Seq[Expression], Option[Expression], LogicalPlan, LogicalPlan) + (JoinType, Seq[Expression], Seq[Expression], + Option[Expression], LogicalPlan, LogicalPlan, JoinHint) def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { - case join @ Join(left, right, joinType, condition) => + case join @ Join(left, right, joinType, condition, hint) => logDebug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. @@ -133,7 +136,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { if (joinKeys.nonEmpty) { val (leftKeys, rightKeys) = joinKeys.unzip logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys") - Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right)) + Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right, hint)) } else { None } @@ -164,25 +167,35 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { * was involved in an explicit cross join. Also returns the entire list of join conditions for * the left-deep tree. */ - def flattenJoin(plan: LogicalPlan, parentJoinType: InnerLike = Inner) + def flattenJoin( + plan: LogicalPlan, + hintMap: mutable.HashMap[AttributeSet, HintInfo], + parentJoinType: InnerLike = Inner) : (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match { - case Join(left, right, joinType: InnerLike, cond) => - val (plans, conditions) = flattenJoin(left, joinType) + case Join(left, right, joinType: InnerLike, cond, hint) => + val (plans, conditions) = flattenJoin(left, hintMap, joinType) + hint.leftHint.map(hintMap.put(left.outputSet, _)) + hint.rightHint.map(hintMap.put(right.outputSet, _)) (plans ++ Seq((right, joinType)), conditions ++ cond.toSeq.flatMap(splitConjunctivePredicates)) - case Filter(filterCondition, j @ Join(left, right, _: InnerLike, joinCondition)) => - val (plans, conditions) = flattenJoin(j) + case Filter(filterCondition, j @ Join(_, _, _: InnerLike, _, _)) => + val (plans, conditions) = flattenJoin(j, hintMap) (plans, conditions ++ splitConjunctivePredicates(filterCondition)) case _ => (Seq((plan, parentJoinType)), Seq.empty) } - def unapply(plan: LogicalPlan): Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])] + def unapply(plan: LogicalPlan) + : Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression], Map[AttributeSet, HintInfo])] = plan match { - case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _)) => - Some(flattenJoin(f)) - case j @ Join(_, _, joinType, _) => - Some(flattenJoin(j)) + case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _, _)) => + val hintMap = new mutable.HashMap[AttributeSet, HintInfo] + val flattened = flattenJoin(f, hintMap) + Some((flattened._1, flattened._2, hintMap.toMap)) + case j @ Join(_, _, joinType, _, _) => + val hintMap = new mutable.HashMap[AttributeSet, HintInfo] + val flattened = flattenJoin(j, hintMap) + Some((flattened._1, flattened._2, hintMap.toMap)) case _ => None } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala index 2c248d74869ce..18baced8f3d61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala @@ -37,7 +37,6 @@ trait LogicalPlanVisitor[T] { case p: Project => visitProject(p) case p: Repartition => visitRepartition(p) case p: RepartitionByExpression => visitRepartitionByExpr(p) - case p: ResolvedHint => visitHint(p) case p: Sample => visitSample(p) case p: ScriptTransformation => visitScriptTransform(p) case p: Union => visitUnion(p) @@ -61,8 +60,6 @@ trait LogicalPlanVisitor[T] { def visitGlobalLimit(p: GlobalLimit): T - def visitHint(p: ResolvedHint): T - def visitIntersect(p: Intersect): T def visitJoin(p: Join): T diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index b3a48860aa63b..5a388117a6c0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -52,13 +52,11 @@ import org.apache.spark.util.Utils * defaults to the product of children's `sizeInBytes`. * @param rowCount Estimated number of rows. * @param attributeStats Statistics for Attributes. - * @param hints Query hints. */ case class Statistics( sizeInBytes: BigInt, rowCount: Option[BigInt] = None, - attributeStats: AttributeMap[ColumnStat] = AttributeMap(Nil), - hints: HintInfo = HintInfo()) { + attributeStats: AttributeMap[ColumnStat] = AttributeMap(Nil)) { override def toString: String = "Statistics(" + simpleString + ")" @@ -70,8 +68,7 @@ case class Statistics( s"rowCount=${BigDecimal(rowCount.get, new MathContext(3, RoundingMode.HALF_UP)).toString()}" } else { "" - }, - s"hints=$hints" + } ).filter(_.nonEmpty).mkString(", ") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index d8b3a4af4f7bf..639d68f4ecd76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -288,7 +288,8 @@ case class Join( left: LogicalPlan, right: LogicalPlan, joinType: JoinType, - condition: Option[Expression]) + condition: Option[Expression], + hint: JoinHint) extends BinaryNode with PredicateHelper { override def output: Seq[Attribute] = { @@ -350,6 +351,17 @@ case class Join( case UsingJoin(_, _) => false case _ => resolvedExceptNatural } + + // Ignore hint for canonicalization + protected override def doCanonicalize(): LogicalPlan = + super.doCanonicalize().asInstanceOf[Join].copy(hint = JoinHint.NONE) + + // Do not include an empty join hint in string description + protected override def stringArgs: Iterator[Any] = super.stringArgs.filter { e => + (!e.isInstanceOf[JoinHint] + || e.asInstanceOf[JoinHint].leftHint.isDefined + || e.asInstanceOf[JoinHint].rightHint.isDefined) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala index cbb626590d1d7..b2ba725e9d44f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -35,6 +35,7 @@ case class UnresolvedHint(name: String, parameters: Seq[Any], child: LogicalPlan /** * A resolved hint node. The analyzer should convert all [[UnresolvedHint]] into [[ResolvedHint]]. + * This node will be eliminated before optimization starts. */ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo()) extends UnaryNode { @@ -44,11 +45,31 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo()) override def doCanonicalize(): LogicalPlan = child.canonicalized } +/** + * Hint that is associated with a [[Join]] node, with [[HintInfo]] on its left child and on its + * right child respectively. + */ +case class JoinHint(leftHint: Option[HintInfo], rightHint: Option[HintInfo]) { -case class HintInfo(broadcast: Boolean = false) { + override def toString: String = { + Seq( + leftHint.map("leftHint=" + _), + rightHint.map("rightHint=" + _)) + .filter(_.isDefined).map(_.get).mkString(", ") + } +} - /** Must be called when computing stats for a join operator to reset hints. */ - def resetForJoin(): HintInfo = copy(broadcast = false) +object JoinHint { + val NONE = JoinHint(None, None) +} + +/** + * The hint attributes to be applied on a specific node. + * + * @param broadcast If set to true, it indicates that the broadcast hash join is the preferred join + * strategy and the node with this hint is preferred to be the build side. + */ +case class HintInfo(broadcast: Boolean = false) { override def toString: String = { val hints = scala.collection.mutable.ArrayBuffer.empty[String] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala index 111c594a53e52..eb56ab43ea9d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala @@ -56,8 +56,7 @@ object AggregateEstimation { Some(Statistics( sizeInBytes = getOutputSize(agg.output, outputRows, outputAttrStats), rowCount = Some(outputRows), - attributeStats = outputAttrStats, - hints = childStats.hints)) + attributeStats = outputAttrStats)) } else { None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala index b6c16079d1984..b8c652dc8f12e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala @@ -47,8 +47,6 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { override def visitGlobalLimit(p: GlobalLimit): Statistics = fallback(p) - override def visitHint(p: ResolvedHint): Statistics = fallback(p) - override def visitIntersect(p: Intersect): Statistics = fallback(p) override def visitJoin(p: Join): Statistics = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 2543e38a92c0a..19a0d1279cc32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -56,7 +56,7 @@ case class JoinEstimation(join: Join) extends Logging { case _ if !rowCountsExist(join.left, join.right) => None - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) => + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _, _) => // 1. Compute join selectivity val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys) val (numInnerJoinedRows, keyStatsAfterJoin) = computeCardinalityAndStats(joinKeyPairs) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala index ee43f9126386b..da36db7ae1f5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -44,7 +44,7 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { } // Don't propagate rowCount and attributeStats, since they are not estimated here. - Statistics(sizeInBytes = sizeInBytes, hints = p.child.stats.hints) + Statistics(sizeInBytes = sizeInBytes) } /** @@ -60,8 +60,7 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { if (p.groupingExpressions.isEmpty) { Statistics( sizeInBytes = EstimationUtils.getOutputSize(p.output, outputRowCount = 1), - rowCount = Some(1), - hints = p.child.stats.hints) + rowCount = Some(1)) } else { visitUnaryNode(p) } @@ -87,19 +86,15 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { // Don't propagate column stats, because we don't know the distribution after limit Statistics( sizeInBytes = EstimationUtils.getOutputSize(p.output, rowCount, childStats.attributeStats), - rowCount = Some(rowCount), - hints = childStats.hints) + rowCount = Some(rowCount)) } - override def visitHint(p: ResolvedHint): Statistics = p.child.stats.copy(hints = p.hints) - override def visitIntersect(p: Intersect): Statistics = { val leftSize = p.left.stats.sizeInBytes val rightSize = p.right.stats.sizeInBytes val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize Statistics( - sizeInBytes = sizeInBytes, - hints = p.left.stats.hints.resetForJoin()) + sizeInBytes = sizeInBytes) } override def visitJoin(p: Join): Statistics = { @@ -108,10 +103,7 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { // LeftSemi and LeftAnti won't ever be bigger than left p.left.stats case _ => - // Make sure we don't propagate isBroadcastable in other joins, because - // they could explode the size. - val stats = default(p) - stats.copy(hints = stats.hints.resetForJoin()) + default(p) } } @@ -121,7 +113,7 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { if (limit == 0) { // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero // (product of children). - Statistics(sizeInBytes = 1, rowCount = Some(0), hints = childStats.hints) + Statistics(sizeInBytes = 1, rowCount = Some(0)) } else { // The output row count of LocalLimit should be the sum of row counts from each partition. // However, since the number of partitions is not available here, we just use statistics of @@ -147,7 +139,7 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { } val sampleRows = p.child.stats.rowCount.map(c => EstimationUtils.ceil(BigDecimal(c) * ratio)) // Don't propagate column stats, because we don't know the distribution after a sample operation - Statistics(sizeInBytes, sampleRows, hints = p.child.stats.hints) + Statistics(sizeInBytes, sampleRows) } override def visitScriptTransform(p: ScriptTransformation): Statistics = default(p) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 117e96175e92a..129ce3b1105ee 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -443,7 +443,7 @@ class AnalysisErrorSuite extends AnalysisTest { } test("error test for self-join") { - val join = Join(testRelation, testRelation, Cross, None) + val join = Join(testRelation, testRelation, Cross, None, JoinHint.NONE) val error = intercept[AnalysisException] { SimpleAnalyzer.checkAnalysis(join) } @@ -565,7 +565,8 @@ class AnalysisErrorSuite extends AnalysisTest { LocalRelation(b), Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)), LeftOuter, - Option(EqualTo(b, c)))), + Option(EqualTo(b, c)), + JoinHint.NONE)), LocalRelation(a)) assertAnalysisError(plan1, "Accessing outer query column is not allowed in" :: Nil) @@ -575,7 +576,8 @@ class AnalysisErrorSuite extends AnalysisTest { Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)), LocalRelation(b), RightOuter, - Option(EqualTo(b, c)))), + Option(EqualTo(b, c)), + JoinHint.NONE)), LocalRelation(a)) assertAnalysisError(plan2, "Accessing outer query column is not allowed in" :: Nil) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index da3ae72c3682a..982948483fa1c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -397,7 +397,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { Join( Project(Seq($"x.key"), SubqueryAlias("x", input)), Project(Seq($"y.key"), SubqueryAlias("y", input)), - Cross, None)) + Cross, None, JoinHint.NONE)) assertAnalysisSuccess(query) } @@ -578,7 +578,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { Seq(UnresolvedAttribute("a")), pythonUdf, output, project) val left = SubqueryAlias("temp0", flatMapGroupsInPandas) val right = SubqueryAlias("temp1", flatMapGroupsInPandas) - val join = Join(left, right, Inner, None) + val join = Join(left, right, Inner, None, JoinHint.NONE) assertAnalysisSuccess( Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index bd66ee5355f45..563e8adf87edc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -60,7 +60,7 @@ class ResolveHintsSuite extends AnalysisTest { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))), Join(ResolvedHint(testRelation, HintInfo(broadcast = true)), - ResolvedHint(testRelation2, HintInfo(broadcast = true)), Inner, None), + ResolvedHint(testRelation2, HintInfo(broadcast = true)), Inner, None, JoinHint.NONE), caseSensitive = false) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 57195d5fda7c5..0cd6e092e2036 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -353,15 +353,15 @@ class ColumnPruningSuite extends PlanTest { Project(Seq($"x.key", $"y.key"), Join( SubqueryAlias("x", input), - ResolvedHint(SubqueryAlias("y", input)), Inner, None)).analyze + SubqueryAlias("y", input), Inner, None, JoinHint.NONE)).analyze val optimized = Optimize.execute(query) val expected = Join( Project(Seq($"x.key"), SubqueryAlias("x", input)), - ResolvedHint(Project(Seq($"y.key"), SubqueryAlias("y", input))), - Inner, None).analyze + Project(Seq($"y.key"), SubqueryAlias("y", input)), + Inner, None, JoinHint.NONE).analyze comparePlans(optimized, expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 82a10254d846d..cf4e9fcea2c6d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -822,19 +821,6 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("broadcast hint") { - val originalQuery = ResolvedHint(testRelation) - .where('a === 2L && 'b + Rand(10).as("rnd") === 3) - - val optimized = Optimize.execute(originalQuery.analyze) - - val correctAnswer = ResolvedHint(testRelation.where('a === 2L)) - .where('b + Rand(10).as("rnd") === 3) - .analyze - - comparePlans(optimized, correctAnswer) - } - test("union") { val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index 6fe5e619d03ad..9093d7fecb0f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -65,7 +65,8 @@ class JoinOptimizationSuite extends PlanTest { def testExtractCheckCross (plan: LogicalPlan, expected: Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])]) { - assert(ExtractFiltersAndInnerJoins.unapply(plan) === expected) + assert( + ExtractFiltersAndInnerJoins.unapply(plan) === expected.map(e => (e._1, e._2, Map.empty))) } testExtract(x, None) @@ -124,29 +125,4 @@ class JoinOptimizationSuite extends PlanTest { comparePlans(optimized, queryAnswerPair._2.analyze) } } - - test("broadcasthint sets relation statistics to smallest value") { - val input = LocalRelation('key.int, 'value.string) - - val query = - Project(Seq($"x.key", $"y.key"), - Join( - SubqueryAlias("x", input), - ResolvedHint(SubqueryAlias("y", input)), Cross, None)).analyze - - val optimized = Optimize.execute(query) - - val expected = - Join( - Project(Seq($"x.key"), SubqueryAlias("x", input)), - ResolvedHint(Project(Seq($"y.key"), SubqueryAlias("y", input))), - Cross, None).analyze - - comparePlans(optimized, expected) - - val broadcastChildren = optimized.collect { - case Join(_, r, _, _) if r.stats.sizeInBytes == 1 => r - } - assert(broadcastChildren.size == 1) - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index c94a8b9e318f6..0dee846205868 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -31,6 +31,8 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { object Optimize extends RuleExecutor[LogicalPlan] { val batches = + Batch("Resolve Hints", Once, + EliminateResolvedHint) :: Batch("Operator Optimizations", FixedPoint(100), CombineFilters, PushDownPredicate, @@ -42,6 +44,12 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { CostBasedJoinReorder) :: Nil } + object ResolveHints extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Resolve Hints", Once, + EliminateResolvedHint) :: Nil + } + var originalConfCBOEnabled = false var originalConfJoinReorderEnabled = false @@ -284,12 +292,85 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { assertEqualPlans(originalPlan, bestPlan) } + test("hints preservation") { + // Apply hints if we find an equivalent node in the new plan, otherwise discard them. + val originalPlan = + t1.join(t2.hint("broadcast")).hint("broadcast").join(t4.join(t3).hint("broadcast")) + .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) && + (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))) + + val bestPlan = + t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .hint("broadcast") + .join( + t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))) + .hint("broadcast"), + Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2"))) + + assertEqualPlans(originalPlan, bestPlan) + + val originalPlan2 = + t1.join(t2).hint("broadcast").join(t3).hint("broadcast").join(t4.hint("broadcast")) + .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) && + (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))) + + val bestPlan2 = + t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .hint("broadcast") + .join( + t4.hint("broadcast") + .join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))), + Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2"))) + .select(outputsOf(t1, t2, t3, t4): _*) + + assertEqualPlans(originalPlan2, bestPlan2) + + val originalPlan3 = + t1.join(t4).hint("broadcast") + .join(t2.hint("broadcast")).hint("broadcast") + .join(t3.hint("broadcast")) + .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) && + (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))) + + val bestPlan3 = + t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .join( + t4.join(t3.hint("broadcast"), + Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))), + Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2"))) + .select(outputsOf(t1, t4, t2, t3): _*) + + assertEqualPlans(originalPlan3, bestPlan3) + + val originalPlan4 = + t2.hint("broadcast") + .join(t4).hint("broadcast") + .join(t3.hint("broadcast")).hint("broadcast") + .join(t1) + .where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) && + (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))) + + val bestPlan4 = + t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .join( + t4.join(t3.hint("broadcast"), + Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))), + Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2"))) + .select(outputsOf(t2, t4, t3, t1): _*) + + assertEqualPlans(originalPlan4, bestPlan4) + } + private def assertEqualPlans( originalPlan: LogicalPlan, groundTruthBestPlan: LogicalPlan): Unit = { val analyzed = originalPlan.analyze val optimized = Optimize.execute(analyzed) - val expected = groundTruthBestPlan.analyze + val expected = ResolveHints.execute(groundTruthBestPlan.analyze) assert(analyzed.sameOutput(expected)) // if this fails, the expected plan itself is incorrect assert(analyzed.sameOutput(optimized)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index c8e15c7da763e..6d1af12e68b23 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -48,7 +48,7 @@ class ReplaceOperatorSuite extends PlanTest { val correctAnswer = Aggregate(table1.output, table1.output, - Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd))).analyze + Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd), JoinHint.NONE)).analyze comparePlans(optimized, correctAnswer) } @@ -160,7 +160,7 @@ class ReplaceOperatorSuite extends PlanTest { val correctAnswer = Aggregate(table1.output, table1.output, - Join(table1, table2, LeftAnti, Option('a <=> 'c && 'b <=> 'd))).analyze + Join(table1, table2, LeftAnti, Option('a <=> 'c && 'b <=> 'd), JoinHint.NONE)).analyze comparePlans(optimized, correctAnswer) } @@ -175,7 +175,7 @@ class ReplaceOperatorSuite extends PlanTest { val correctAnswer = Aggregate(left.output, right.output, - Join(left, right, LeftAnti, Option($"left.a" <=> $"right.a"))).analyze + Join(left, right, LeftAnti, Option($"left.a" <=> $"right.a"), JoinHint.NONE)).analyze comparePlans(optimized, correctAnswer) } @@ -248,7 +248,7 @@ class ReplaceOperatorSuite extends PlanTest { val condition = basePlan.output.zip(otherPlan.output).map { case (a1, a2) => a1 <=> a2 }.reduce( _ && _) val correctAnswer = Aggregate(basePlan.output, otherPlan.output, - Join(basePlan, otherPlan, LeftAnti, Option(condition))).analyze + Join(basePlan, otherPlan, LeftAnti, Option(condition), JoinHint.NONE)).analyze comparePlans(result, correctAnswer) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 3081ff935f043..5394732f41f2d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -99,11 +99,11 @@ trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite => .reduce(And), child) case sample: Sample => sample.copy(seed = 0L) - case Join(left, right, joinType, condition) if condition.isDefined => + case Join(left, right, joinType, condition, hint) if condition.isDefined => val newCondition = splitConjunctivePredicates(condition.get).map(rewriteEqual).sortBy(_.hashCode()) .reduce(And) - Join(left, right, joinType, Some(newCondition)) + Join(left, right, joinType, Some(newCondition), hint) } } @@ -165,8 +165,10 @@ trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite => private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { (plan1, plan2) match { case (j1: Join, j2: Join) => - (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) || - (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)) + (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right) + && j1.hint.leftHint == j2.hint.leftHint && j1.hint.rightHint == j2.hint.rightHint) || + (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left) + && j1.hint.leftHint == j2.hint.rightHint && j1.hint.rightHint == j2.hint.leftHint) case (p1: Project, p2: Project) => p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child) case _ => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala index 7c8ed78a49116..fbaaf807af5d6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, ResolvedHint, Union} +import org.apache.spark.sql.catalyst.optimizer.EliminateResolvedHint +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util._ /** @@ -30,6 +32,10 @@ class SameResultSuite extends SparkFunSuite { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int) + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("EliminateResolvedHint", Once, EliminateResolvedHint) :: Nil + } + def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = { val aAnalyzed = a.analyze val bAnalyzed = b.analyze @@ -72,4 +78,12 @@ class SameResultSuite extends SparkFunSuite { val df2 = testRelation.join(testRelation) assertSameResult(df1, df2) } + + test("join hint") { + val df1 = testRelation.join(testRelation.hint("broadcast")) + val df2 = testRelation.join(testRelation) + val df1Optimized = Optimize.execute(df1.analyze) + val df2Optimized = Optimize.execute(df2.analyze) + assertSameResult(df1Optimized, df2Optimized) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 953094cb0dd52..16a5c2d3001a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -38,24 +38,6 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { // row count * (overhead + column size) size = Some(10 * (8 + 4))) - test("BroadcastHint estimation") { - val filter = Filter(Literal(true), plan) - val filterStatsCboOn = Statistics(sizeInBytes = 10 * (8 +4), - rowCount = Some(10), attributeStats = AttributeMap(Seq(attribute -> colStat))) - val filterStatsCboOff = Statistics(sizeInBytes = 10 * (8 +4)) - checkStats( - filter, - expectedStatsCboOn = filterStatsCboOn, - expectedStatsCboOff = filterStatsCboOff) - - val broadcastHint = ResolvedHint(filter, HintInfo(broadcast = true)) - checkStats( - broadcastHint, - expectedStatsCboOn = filterStatsCboOn.copy(hints = HintInfo(broadcast = true)), - expectedStatsCboOff = filterStatsCboOff.copy(hints = HintInfo(broadcast = true)) - ) - } - test("range") { val range = Range(1, 5, 1, None) val rangeStats = Statistics(sizeInBytes = 4 * 8) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index b0a47e7835129..1cf888519077a 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -528,7 +528,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { rowCount = 30, attributeStats = AttributeMap(Seq(attrIntLargerRange -> colStatIntLargerRange))) val nonLeafChild = Join(largerTable, smallerTable, LeftOuter, - Some(EqualTo(attrIntLargerRange, attrInt))) + Some(EqualTo(attrIntLargerRange, attrInt)), JoinHint.NONE) Seq(IsNull(attrIntLargerRange), IsNotNull(attrIntLargerRange)).foreach { predicate => validateEstimatedStats( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala index 12c0a7be21292..6c5a2b247fc23 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala @@ -79,8 +79,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { val c1 = generateJoinChild(col1, leftHistogram, expectedMin, expectedMax) val c2 = generateJoinChild(col2, rightHistogram, expectedMin, expectedMax) - val c1JoinC2 = Join(c1, c2, Inner, Some(EqualTo(col1, col2))) - val c2JoinC1 = Join(c2, c1, Inner, Some(EqualTo(col2, col1))) + val c1JoinC2 = Join(c1, c2, Inner, Some(EqualTo(col1, col2)), JoinHint.NONE) + val c2JoinC1 = Join(c2, c1, Inner, Some(EqualTo(col2, col1)), JoinHint.NONE) val expectedStatsAfterJoin = Statistics( sizeInBytes = expectedRows * (8 + 2 * 4), rowCount = Some(expectedRows), @@ -284,7 +284,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { test("cross join") { // table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5) // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4) - val join = Join(table1, table2, Cross, None) + val join = Join(table1, table2, Cross, None, JoinHint.NONE) val expectedStats = Statistics( sizeInBytes = 5 * 3 * (8 + 4 * 4), rowCount = Some(5 * 3), @@ -299,7 +299,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4) // key-5-9 and key-2-4 are disjoint val join = Join(table1, table2, Inner, - Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4")))) + Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))), JoinHint.NONE) val expectedStats = Statistics( sizeInBytes = 1, rowCount = Some(0), @@ -312,7 +312,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4) // key-5-9 and key-2-4 are disjoint val join = Join(table1, table2, LeftOuter, - Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4")))) + Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))), JoinHint.NONE) val expectedStats = Statistics( sizeInBytes = 5 * (8 + 4 * 4), rowCount = Some(5), @@ -328,7 +328,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4) // key-5-9 and key-2-4 are disjoint val join = Join(table1, table2, RightOuter, - Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4")))) + Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))), JoinHint.NONE) val expectedStats = Statistics( sizeInBytes = 3 * (8 + 4 * 4), rowCount = Some(3), @@ -344,7 +344,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4) // key-5-9 and key-2-4 are disjoint val join = Join(table1, table2, FullOuter, - Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4")))) + Some(EqualTo(nameToAttr("key-5-9"), nameToAttr("key-2-4"))), JoinHint.NONE) val expectedStats = Statistics( sizeInBytes = (5 + 3) * (8 + 4 * 4), rowCount = Some(5 + 3), @@ -361,7 +361,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5) // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4) val join = Join(table1, table2, Inner, - Some(EqualTo(nameToAttr("key-1-5"), nameToAttr("key-1-2")))) + Some(EqualTo(nameToAttr("key-1-5"), nameToAttr("key-1-2"))), JoinHint.NONE) // Update column stats for equi-join keys (key-1-5 and key-1-2). val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) @@ -383,7 +383,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3) val join = Join(table2, table3, Inner, Some( And(EqualTo(nameToAttr("key-1-2"), nameToAttr("key-1-2")), - EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))))) + EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3")))), JoinHint.NONE) // Update column stats for join keys. val joinedColStat1 = ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), @@ -404,7 +404,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4) // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3) val join = Join(table3, table2, LeftOuter, - Some(EqualTo(nameToAttr("key-2-3"), nameToAttr("key-2-4")))) + Some(EqualTo(nameToAttr("key-2-3"), nameToAttr("key-2-4"))), JoinHint.NONE) val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) @@ -422,7 +422,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4) // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3) val join = Join(table2, table3, RightOuter, - Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3")))) + Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))), JoinHint.NONE) val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) @@ -440,7 +440,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4) // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3) val join = Join(table2, table3, FullOuter, - Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3")))) + Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))), JoinHint.NONE) val expectedStats = Statistics( sizeInBytes = 3 * (8 + 4 * 4), @@ -456,7 +456,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3) Seq(LeftSemi, LeftAnti).foreach { jt => val join = Join(table2, table3, jt, - Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3")))) + Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))), JoinHint.NONE) // For now we just propagate the statistics from left side for left semi/anti join. val expectedStats = Statistics( sizeInBytes = 3 * (8 + 4 * 2), @@ -525,7 +525,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase { withClue(s"For data type ${key1.dataType}") { // All values in two tables are the same, so column stats after join are also the same. val join = Join(Project(Seq(key1), table1), Project(Seq(key2), table2), Inner, - Some(EqualTo(key1, key2))) + Some(EqualTo(key1, key2)), JoinHint.NONE) val expectedStats = Statistics( sizeInBytes = 1 * (8 + 2 * getColSize(key1, columnInfo1(key1))), rowCount = Some(1), @@ -543,7 +543,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { outputList = Seq(nullColumn), rowCount = 1, attributeStats = AttributeMap(Seq(nullColumn -> nullColStat))) - val join = Join(table1, nullTable, Inner, Some(EqualTo(nameToAttr("key-1-5"), nullColumn))) + val join = Join(table1, nullTable, Inner, + Some(EqualTo(nameToAttr("key-1-5"), nullColumn)), JoinHint.NONE) val expectedStats = Statistics( sizeInBytes = 1, rowCount = Some(0), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a664c7338badb..44cada086489a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -862,7 +862,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def join(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Inner, None) + Join(logicalPlan, right.logicalPlan, joinType = Inner, None, JoinHint.NONE) } /** @@ -940,7 +940,7 @@ class Dataset[T] private[sql]( // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sparkSession.sessionState.executePlan( - Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None)) + Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None, JoinHint.NONE)) .analyzed.asInstanceOf[Join] withPlan { @@ -948,7 +948,8 @@ class Dataset[T] private[sql]( joined.left, joined.right, UsingJoin(JoinType(joinType), usingColumns), - None) + None, + JoinHint.NONE) } } @@ -1001,7 +1002,7 @@ class Dataset[T] private[sql]( // Trigger analysis so in the case of self-join, the analyzer will clone the plan. // After the cloning, left and right side will have distinct expression ids. val plan = withPlan( - Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))) + Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr), JoinHint.NONE)) .queryExecution.analyzed.asInstanceOf[Join] // If auto self join alias is disabled, return the plan. @@ -1048,7 +1049,7 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ def crossJoin(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Cross, None) + Join(logicalPlan, right.logicalPlan, joinType = Cross, None, JoinHint.NONE) } /** @@ -1083,7 +1084,8 @@ class Dataset[T] private[sql]( this.logicalPlan, other.logicalPlan, JoinType(joinType), - Some(condition.expr))).analyzed.asInstanceOf[Join] + Some(condition.expr), + JoinHint.NONE)).analyzed.asInstanceOf[Join] if (joined.joinType == LeftSemi || joined.joinType == LeftAnti) { throw new AnalysisException("Invalid join type in joinWith: " + joined.joinType.sql) @@ -1135,7 +1137,7 @@ class Dataset[T] private[sql]( implicit val tuple2Encoder: Encoder[(T, U)] = ExpressionEncoder.tuple(this.exprEnc, other.exprEnc) - withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr))) + withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr), JoinHint.NONE)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index dbc6db62bd820..b7cc373b2df12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -208,17 +208,17 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - private def canBroadcastByHints(joinType: JoinType, left: LogicalPlan, right: LogicalPlan) - : Boolean = { - val buildLeft = canBuildLeft(joinType) && left.stats.hints.broadcast - val buildRight = canBuildRight(joinType) && right.stats.hints.broadcast + private def canBroadcastByHints( + joinType: JoinType, left: LogicalPlan, right: LogicalPlan, hint: JoinHint): Boolean = { + val buildLeft = canBuildLeft(joinType) && hint.leftHint.exists(_.broadcast) + val buildRight = canBuildRight(joinType) && hint.rightHint.exists(_.broadcast) buildLeft || buildRight } - private def broadcastSideByHints(joinType: JoinType, left: LogicalPlan, right: LogicalPlan) - : BuildSide = { - val buildLeft = canBuildLeft(joinType) && left.stats.hints.broadcast - val buildRight = canBuildRight(joinType) && right.stats.hints.broadcast + private def broadcastSideByHints( + joinType: JoinType, left: LogicalPlan, right: LogicalPlan, hint: JoinHint): BuildSide = { + val buildLeft = canBuildLeft(joinType) && hint.leftHint.exists(_.broadcast) + val buildRight = canBuildRight(joinType) && hint.rightHint.exists(_.broadcast) broadcastSide(buildLeft, buildRight, left, right) } @@ -241,14 +241,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // --- BroadcastHashJoin -------------------------------------------------------------------- // broadcast hints were specified - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) - if canBroadcastByHints(joinType, left, right) => - val buildSide = broadcastSideByHints(joinType, left, right) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) + if canBroadcastByHints(joinType, left, right, hint) => + val buildSide = broadcastSideByHints(joinType, left, right, hint) Seq(joins.BroadcastHashJoinExec( leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) // broadcast hints were not specified, so need to infer it from size and configuration. - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _) if canBroadcastBySizes(joinType, left, right) => val buildSide = broadcastSideBySizes(joinType, left, right) Seq(joins.BroadcastHashJoinExec( @@ -256,14 +256,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // --- ShuffledHashJoin --------------------------------------------------------------------- - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _) if !conf.preferSortMergeJoin && canBuildRight(joinType) && canBuildLocalHashMap(right) && muchSmaller(right, left) || !RowOrdering.isOrderable(leftKeys) => Seq(joins.ShuffledHashJoinExec( leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right))) - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _) if !conf.preferSortMergeJoin && canBuildLeft(joinType) && canBuildLocalHashMap(left) && muchSmaller(left, right) || !RowOrdering.isOrderable(leftKeys) => @@ -272,7 +272,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // --- SortMergeJoin ------------------------------------------------------------ - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _) if RowOrdering.isOrderable(leftKeys) => joins.SortMergeJoinExec( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil @@ -280,25 +280,25 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // --- Without joining keys ------------------------------------------------------------ // Pick BroadcastNestedLoopJoin if one side could be broadcast - case j @ logical.Join(left, right, joinType, condition) - if canBroadcastByHints(joinType, left, right) => - val buildSide = broadcastSideByHints(joinType, left, right) + case j @ logical.Join(left, right, joinType, condition, hint) + if canBroadcastByHints(joinType, left, right, hint) => + val buildSide = broadcastSideByHints(joinType, left, right, hint) joins.BroadcastNestedLoopJoinExec( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil - case j @ logical.Join(left, right, joinType, condition) + case j @ logical.Join(left, right, joinType, condition, _) if canBroadcastBySizes(joinType, left, right) => val buildSide = broadcastSideBySizes(joinType, left, right) joins.BroadcastNestedLoopJoinExec( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil // Pick CartesianProduct for InnerJoin - case logical.Join(left, right, _: InnerLike, condition) => + case logical.Join(left, right, _: InnerLike, condition, _) => joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil - case logical.Join(left, right, joinType, condition) => + case logical.Join(left, right, joinType, condition, hint) => val buildSide = broadcastSide( - left.stats.hints.broadcast, right.stats.hints.broadcast, left, right) + hint.leftHint.exists(_.broadcast), hint.rightHint.exists(_.broadcast), left, right) // This join could be very slow or OOM joins.BroadcastNestedLoopJoinExec( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil @@ -380,13 +380,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object StreamingJoinStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = { plan match { - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _) if left.isStreaming && right.isStreaming => new StreamingSymmetricHashJoinExec( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil - case Join(left, right, _, _) if left.isStreaming && right.isStreaming => + case Join(left, right, _, _, _) if left.isStreaming && right.isStreaming => throw new AnalysisException( "Stream-stream join without equality predicate is not supported", plan = Some(plan)) @@ -561,6 +561,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { throw new IllegalStateException( "logical except (all) operator should have been replaced by union, aggregate" + " and generate operators in the optimizer") + case logical.ResolvedHint(child, hints) => + throw new IllegalStateException( + "ResolvedHint operator should have been replaced by join hint in the optimizer") case logical.DeserializeToObject(deserializer, objAttr, child) => execution.DeserializeToObjectExec(deserializer, objAttr, planLater(child)) :: Nil @@ -632,7 +635,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil case r: LogicalRDD => RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil - case h: ResolvedHint => planLater(h.child) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 4109d9994dd8f..41f406d6c2993 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.storage.StorageLevel @@ -184,12 +184,7 @@ case class InMemoryRelation( override def computeStats(): Statistics = { if (cacheBuilder.sizeInBytesStats.value == 0L) { // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache. - // Note that we should drop the hint info here. We may cache a plan whose root node is a hint - // node. When we lookup the cache with a semantically same plan without hint info, the plan - // returned by cache lookup should not have hint info. If we lookup the cache with a - // semantically same plan with a different hint info, `CacheManager.useCachedData` will take - // care of it and retain the hint info in the lookup input plan. - statsOfPlanToCache.copy(hints = HintInfo()) + statsOfPlanToCache } else { Statistics(sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 6e805c4f3c39a..2141be4d680f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -27,9 +27,11 @@ import org.apache.spark.executor.DataReadMethod.DataReadMethod import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} @@ -925,4 +927,23 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext } } } + + test("Cache should respect the broadcast hint") { + val df = broadcast(spark.range(1000)).cache() + val df2 = spark.range(1000).cache() + df.count() + df2.count() + + // Test the broadcast hint. + val joinPlan = df.join(df2, "id").queryExecution.optimizedPlan + val hint = joinPlan.collect { + case Join(_, _, _, _, hint) => hint + } + assert(hint.size == 1) + assert(hint(0).leftHint.get.broadcast) + assert(hint(0).rightHint.isEmpty) + + // Clean-up + df.unpersist() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index c9f41ab1c0179..a4a3e2a62d1a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -198,7 +198,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { // outer -> left val outerJoin2Left = df.join(df2, $"a.int" === $"b.int", "outer").where($"a.int" >= 3) assert(outerJoin2Left.queryExecution.optimizedPlan.collect { - case j @ Join(_, _, LeftOuter, _) => j }.size === 1) + case j @ Join(_, _, LeftOuter, _, _) => j }.size === 1) checkAnswer( outerJoin2Left, Row(3, 4, "3", null, null, null) :: Nil) @@ -206,7 +206,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { // outer -> right val outerJoin2Right = df.join(df2, $"a.int" === $"b.int", "outer").where($"b.int" >= 3) assert(outerJoin2Right.queryExecution.optimizedPlan.collect { - case j @ Join(_, _, RightOuter, _) => j }.size === 1) + case j @ Join(_, _, RightOuter, _, _) => j }.size === 1) checkAnswer( outerJoin2Right, Row(null, null, null, 5, 6, "5") :: Nil) @@ -215,7 +215,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { val outerJoin2Inner = df.join(df2, $"a.int" === $"b.int", "outer"). where($"a.int" === 1 && $"b.int2" === 3) assert(outerJoin2Inner.queryExecution.optimizedPlan.collect { - case j @ Join(_, _, Inner, _) => j }.size === 1) + case j @ Join(_, _, Inner, _, _) => j }.size === 1) checkAnswer( outerJoin2Inner, Row(1, 2, "1", 1, 3, "1") :: Nil) @@ -223,7 +223,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { // right -> inner val rightJoin2Inner = df.join(df2, $"a.int" === $"b.int", "right").where($"a.int" > 0) assert(rightJoin2Inner.queryExecution.optimizedPlan.collect { - case j @ Join(_, _, Inner, _) => j }.size === 1) + case j @ Join(_, _, Inner, _, _) => j }.size === 1) checkAnswer( rightJoin2Inner, Row(1, 2, "1", 1, 3, "1") :: Nil) @@ -231,7 +231,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { // left -> inner val leftJoin2Inner = df.join(df2, $"a.int" === $"b.int", "left").where($"b.int2" > 0) assert(leftJoin2Inner.queryExecution.optimizedPlan.collect { - case j @ Join(_, _, Inner, _) => j }.size === 1) + case j @ Join(_, _, Inner, _, _) => j }.size === 1) checkAnswer( leftJoin2Inner, Row(1, 2, "1", 1, 3, "1") :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala new file mode 100644 index 0000000000000..3652895ff43d8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.test.SharedSQLContext + +class JoinHintSuite extends PlanTest with SharedSQLContext { + import testImplicits._ + + lazy val df = spark.range(10) + lazy val df1 = df.selectExpr("id as a1", "id as a2") + lazy val df2 = df.selectExpr("id as b1", "id as b2") + lazy val df3 = df.selectExpr("id as c1", "id as c2") + + def verifyJoinHint(df: DataFrame, expectedHints: Seq[JoinHint]): Unit = { + val optimized = df.queryExecution.optimizedPlan + val joinHints = optimized collect { + case Join(_, _, _, _, hint) => hint + case _: ResolvedHint => fail("ResolvedHint should not appear after optimize.") + } + assert(joinHints == expectedHints) + } + + test("single join") { + verifyJoinHint( + df.hint("broadcast").join(df, "id"), + JoinHint( + Some(HintInfo(broadcast = true)), + None) :: Nil + ) + verifyJoinHint( + df.join(df.hint("broadcast"), "id"), + JoinHint( + None, + Some(HintInfo(broadcast = true))) :: Nil + ) + } + + test("multiple joins") { + verifyJoinHint( + df1.join(df2.hint("broadcast").join(df3, 'b1 === 'c1).hint("broadcast"), 'a1 === 'c1), + JoinHint( + None, + Some(HintInfo(broadcast = true))) :: + JoinHint( + Some(HintInfo(broadcast = true)), + None) :: Nil + ) + verifyJoinHint( + df1.hint("broadcast").join(df2, 'a1 === 'b1).hint("broadcast").join(df3, 'a1 === 'c1), + JoinHint( + Some(HintInfo(broadcast = true)), + None) :: + JoinHint( + Some(HintInfo(broadcast = true)), + None) :: Nil + ) + } + + test("hint scope") { + withTempView("a", "b", "c") { + df1.createOrReplaceTempView("a") + df2.createOrReplaceTempView("b") + verifyJoinHint( + sql( + """ + |select /*+ broadcast(a, b)*/ * from ( + | select /*+ broadcast(b)*/ * from a join b on a.a1 = b.b1 + |) a join ( + | select /*+ broadcast(a)*/ * from a join b on a.a1 = b.b1 + |) b on a.a1 = b.b1 + """.stripMargin), + JoinHint( + Some(HintInfo(broadcast = true)), + Some(HintInfo(broadcast = true))) :: + JoinHint( + None, + Some(HintInfo(broadcast = true))) :: + JoinHint( + Some(HintInfo(broadcast = true)), + None) :: Nil + ) + } + } + + test("hint preserved after join reorder") { + withTempView("a", "b", "c") { + df1.createOrReplaceTempView("a") + df2.createOrReplaceTempView("b") + df3.createOrReplaceTempView("c") + verifyJoinHint( + sql("select /*+ broadcast(a, c)*/ * from a, b, c " + + "where a.a1 = b.b1 and b.b1 = c.c1"), + JoinHint( + None, + Some(HintInfo(broadcast = true))) :: + JoinHint( + Some(HintInfo(broadcast = true)), + None):: Nil + ) + verifyJoinHint( + sql("select /*+ broadcast(a, c)*/ * from a, c, b " + + "where a.a1 = b.b1 and b.b1 = c.c1"), + JoinHint( + None, + Some(HintInfo(broadcast = true))) :: + JoinHint( + Some(HintInfo(broadcast = true)), + None):: Nil + ) + verifyJoinHint( + sql("select /*+ broadcast(b, c)*/ * from a, c, b " + + "where a.a1 = b.b1 and b.b1 = c.c1"), + JoinHint( + None, + Some(HintInfo(broadcast = true))) :: + JoinHint( + None, + Some(HintInfo(broadcast = true))):: Nil + ) + + verifyJoinHint( + df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast") + .join(df3, 'b1 === 'c1 && 'a1 < 10), + JoinHint( + Some(HintInfo(broadcast = true)), + None) :: + JoinHint.NONE:: Nil + ) + + verifyJoinHint( + df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast") + .join(df3, 'b1 === 'c1 && 'a1 < 10) + .join(df, 'b1 === 'id), + JoinHint.NONE :: + JoinHint( + Some(HintInfo(broadcast = true)), + None) :: + JoinHint.NONE:: Nil + ) + } + } + + test("intersect/except") { + val dfSub = spark.range(2) + verifyJoinHint( + df.hint("broadcast").except(dfSub).join(df, "id"), + JoinHint( + Some(HintInfo(broadcast = true)), + None) :: + JoinHint.NONE :: Nil + ) + verifyJoinHint( + df.join(df.hint("broadcast").intersect(dfSub), "id"), + JoinHint( + None, + Some(HintInfo(broadcast = true))) :: + JoinHint.NONE :: Nil + ) + } + + test("hint merge") { + verifyJoinHint( + df.hint("broadcast").filter('id > 2).hint("broadcast").join(df, "id"), + JoinHint( + Some(HintInfo(broadcast = true)), + None) :: Nil + ) + verifyJoinHint( + df.join(df.hint("broadcast").limit(2).hint("broadcast"), "id"), + JoinHint( + None, + Some(HintInfo(broadcast = true))) :: Nil + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 02dc32d5f90ba..99842680cedfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -237,8 +237,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared ) numbers.foreach { case (input, (expectedSize, expectedRows)) => val stats = Statistics(sizeInBytes = input, rowCount = Some(input)) - val expectedString = s"sizeInBytes=$expectedSize, rowCount=$expectedRows," + - s" hints=none" + val expectedString = s"sizeInBytes=$expectedSize, rowCount=$expectedRows" assert(stats.simpleString == expectedString) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 42dd0024b2582..f238148e61c39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -203,7 +203,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } test("broadcast hint in SQL") { - import org.apache.spark.sql.catalyst.plans.logical.{ResolvedHint, Join} + import org.apache.spark.sql.catalyst.plans.logical.Join spark.range(10).createOrReplaceTempView("t") spark.range(10).createOrReplaceTempView("u") @@ -216,12 +216,12 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { val plan3 = sql(s"SELECT /*+ $name(v) */ * FROM t JOIN u ON t.id = u.id").queryExecution .optimizedPlan - assert(plan1.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) - assert(!plan1.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) - assert(!plan2.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) - assert(plan2.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) - assert(!plan3.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) - assert(!plan3.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) + assert(plan1.asInstanceOf[Join].hint.leftHint.get.broadcast) + assert(plan1.asInstanceOf[Join].hint.rightHint.isEmpty) + assert(plan2.asInstanceOf[Join].hint.leftHint.isEmpty) + assert(plan2.asInstanceOf[Join].hint.rightHint.get.broadcast) + assert(plan3.asInstanceOf[Join].hint.leftHint.isEmpty) + assert(plan3.asInstanceOf[Join].hint.rightHint.isEmpty) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index 22279a3a43eff..771a9730247af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan, SparkPlanTest} import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.internal.SQLConf @@ -85,7 +85,8 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { expectedAnswer: Seq[Row]): Unit = { def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { - val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, + Inner, Some(condition), JoinHint.NONE) ExtractEquiJoinKeys.unapply(join) } @@ -102,7 +103,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using ShuffledHashJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(left.sqlContext.sessionState.conf).apply( @@ -121,7 +122,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { } testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin") { _ => - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(left.sqlContext.sessionState.conf).apply( @@ -140,7 +141,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using SortMergeJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(left.sqlContext.sessionState.conf).apply( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index f5edd6bbd5e69..f99a278bb2427 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.internal.SQLConf @@ -80,7 +80,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { expectedAnswer: Seq[Product]): Unit = { def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { - val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition())) + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, + Inner, Some(condition()), JoinHint.NONE) ExtractEquiJoinKeys.unapply(join) } @@ -128,7 +129,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin (build=left)") { _ => - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeBroadcastHashJoin( @@ -140,7 +141,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin (build=right)") { _ => - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeBroadcastHashJoin( @@ -152,7 +153,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using ShuffledHashJoin (build=left)") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeShuffledHashJoin( @@ -164,7 +165,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using ShuffledHashJoin (build=right)") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeShuffledHashJoin( @@ -176,7 +177,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } testWithWholeStageCodegenOnAndOff(s"$testName using SortMergeJoin") { _ => - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeSortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 513248dae48be..1f04fcf6ca451 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.internal.SQLConf @@ -72,13 +72,14 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { expectedAnswer: Seq[Product]): Unit = { def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { - val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, + Inner, Some(condition), JoinHint.NONE) ExtractEquiJoinKeys.unapply(join) } if (joinType != FullOuter) { test(s"$testName using ShuffledHashJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { val buildSide = if (joinType == LeftOuter) BuildRight else BuildLeft checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => @@ -99,7 +100,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { case RightOuter => BuildLeft case _ => fail(s"Unsupported join type $joinType") } - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => BroadcastHashJoinExec( @@ -112,7 +113,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using SortMergeJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(spark.sessionState.conf).apply( From 1a641525e60039cc6b10816e946cb6f44b3e2696 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 7 Jan 2019 15:35:33 -0800 Subject: [PATCH 2395/2461] [SPARK-26491][CORE][TEST] Use ConfigEntry for hardcoded configs for test categories ## What changes were proposed in this pull request? The PR makes hardcoded `spark.test` and `spark.testing` configs to use `ConfigEntry` and put them in the config package. ## How was this patch tested? existing UTs Closes #23413 from mgaido91/SPARK-26491. Authored-by: Marco Gaido Signed-off-by: Marcelo Vanzin --- .../spark/ExecutorAllocationManager.scala | 4 +- .../scala/org/apache/spark/SparkContext.scala | 3 +- .../deploy/history/FsHistoryProvider.scala | 3 +- .../apache/spark/deploy/worker/Worker.scala | 4 +- .../spark/executor/ProcfsMetricsGetter.scala | 2 +- .../apache/spark/executor/TaskMetrics.scala | 3 +- .../apache/spark/internal/config/Tests.scala | 56 +++++++++++++++++++ .../spark/memory/StaticMemoryManager.scala | 5 +- .../spark/memory/UnifiedMemoryManager.scala | 7 ++- .../apache/spark/scheduler/DAGScheduler.scala | 3 +- .../cluster/StandaloneSchedulerBackend.scala | 3 +- .../org/apache/spark/util/SizeEstimator.scala | 5 +- .../scala/org/apache/spark/util/Utils.scala | 5 +- .../org/apache/spark/DistributedSuite.scala | 5 +- .../ExecutorAllocationManagerSuite.scala | 3 +- .../scala/org/apache/spark/ShuffleSuite.scala | 5 +- .../org/apache/spark/SparkFunSuite.scala | 3 +- .../history/HistoryServerArgumentsSuite.scala | 6 +- .../deploy/history/HistoryServerSuite.scala | 7 ++- .../memory/StaticMemoryManagerSuite.scala | 5 +- .../memory/UnifiedMemoryManagerSuite.scala | 33 +++++------ .../scheduler/BarrierTaskContextSuite.scala | 7 ++- .../scheduler/BlacklistIntegrationSuite.scala | 19 ++++--- .../sort/ShuffleExternalSorterSuite.scala | 5 +- .../BlockManagerReplicationSuite.scala | 9 +-- .../spark/storage/BlockManagerSuite.scala | 10 ++-- .../spark/storage/MemoryStoreSuite.scala | 1 - .../spark/util/SizeEstimatorSuite.scala | 5 +- .../ExternalAppendOnlyMapSuite.scala | 3 +- .../util/collection/ExternalSorterSuite.scala | 3 +- .../KubernetesTestComponents.scala | 3 +- .../MesosCoarseGrainedSchedulerBackend.scala | 3 +- .../spark/sql/execution/SQLExecution.scala | 4 +- .../apache/spark/sql/BenchmarkQueryTest.scala | 3 +- .../execution/UnsafeRowSerializerSuite.scala | 3 +- 35 files changed, 165 insertions(+), 83 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/internal/config/Tests.scala diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index d966582295b37..0807e653b41a9 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -27,6 +27,7 @@ import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.internal.{config, Logging} import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.Tests.TEST_SCHEDULE_INTERVAL import org.apache.spark.metrics.source.Source import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMaster @@ -157,7 +158,7 @@ private[spark] class ExecutorAllocationManager( // Polling loop interval (ms) private val intervalMillis: Long = if (Utils.isTesting) { - conf.getLong(TESTING_SCHEDULE_INTERVAL_KEY, 100) + conf.get(TEST_SCHEDULE_INTERVAL) } else { 100 } @@ -899,5 +900,4 @@ private[spark] class ExecutorAllocationManager( private object ExecutorAllocationManager { val NOT_SET = Long.MaxValue - val TESTING_SCHEDULE_INTERVAL_KEY = "spark.testing.dynamicAllocation.scheduleInterval" } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 89be9de083075..3a1e1b9310029 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -45,6 +45,7 @@ import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, WholeTextFileInputFormat} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.Tests._ import org.apache.spark.io.CompressionCodec import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ @@ -470,7 +471,7 @@ class SparkContext(config: SparkConf) extends Logging { // Convert java options to env vars as a work around // since we can't set env vars directly in sbt. - for { (envKey, propKey) <- Seq(("SPARK_TESTING", "spark.testing")) + for { (envKey, propKey) <- Seq(("SPARK_TESTING", IS_TESTING.key)) value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} { executorEnvs(envKey) = value } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 709a380dfb636..3c5648434fa66 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -45,6 +45,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{DRIVER_LOG_DFS_DIR, History} import org.apache.spark.internal.config.History._ import org.apache.spark.internal.config.Status._ +import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.io.CompressionCodec import org.apache.spark.scheduler._ import org.apache.spark.scheduler.ReplayListenerBus._ @@ -267,7 +268,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } // Disable the background thread during tests. - if (!conf.contains("spark.testing")) { + if (!conf.contains(IS_TESTING)) { // A task that periodically checks for event log updates on disk. logDebug(s"Scheduling update thread every $UPDATE_INTERVAL_S seconds") pool.scheduleWithFixedDelay( diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index d5ea2523c628b..467df26c47354 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -37,6 +37,7 @@ import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.internal.{config, Logging} +import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.metrics.MetricsSystem import org.apache.spark.rpc._ import org.apache.spark.util.{SparkUncaughtExceptionHandler, ThreadUtils, Utils} @@ -103,7 +104,6 @@ private[deploy] class Worker( private val CLEANUP_NON_SHUFFLE_FILES_ENABLED = conf.getBoolean("spark.storage.cleanupFilesAfterExecutorExit", true) - private val testing: Boolean = sys.props.contains("spark.testing") private var master: Option[RpcEndpointRef] = None /** @@ -127,7 +127,7 @@ private[deploy] class Worker( private var connected = false private val workerId = generateWorkerId() private val sparkHome = - if (testing) { + if (sys.props.contains(IS_TESTING.key)) { assert(sys.props.contains("spark.test.home"), "spark.test.home is not set!") new File(sys.props("spark.test.home")) } else { diff --git a/core/src/main/scala/org/apache/spark/executor/ProcfsMetricsGetter.scala b/core/src/main/scala/org/apache/spark/executor/ProcfsMetricsGetter.scala index af67f41e94af1..f354d603c2e3d 100644 --- a/core/src/main/scala/org/apache/spark/executor/ProcfsMetricsGetter.scala +++ b/core/src/main/scala/org/apache/spark/executor/ProcfsMetricsGetter.scala @@ -43,7 +43,7 @@ private[spark] case class ProcfsMetrics( // project. private[spark] class ProcfsMetricsGetter(procfsDir: String = "/proc/") extends Logging { private val procfsStatFile = "stat" - private val testing = sys.env.contains("SPARK_TESTING") || sys.props.contains("spark.testing") + private val testing = Utils.isTesting private val pageSize = computePageSize() private var isAvailable: Boolean = isProcfsAvailable private val pid = computePid() diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 85b2745a2aec4..ea79c7310349d 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -23,6 +23,7 @@ import scala.collection.mutable.{ArrayBuffer, LinkedHashMap} import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.storage.{BlockId, BlockStatus} import org.apache.spark.util._ @@ -202,7 +203,7 @@ class TaskMetrics private[spark] () extends Serializable { } // Only used for test - private[spark] val testAccum = sys.props.get("spark.testing").map(_ => new LongAccumulator) + private[spark] val testAccum = sys.props.get(IS_TESTING.key).map(_ => new LongAccumulator) import InternalAccumulator._ diff --git a/core/src/main/scala/org/apache/spark/internal/config/Tests.scala b/core/src/main/scala/org/apache/spark/internal/config/Tests.scala new file mode 100644 index 0000000000000..21660ab3a9512 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/config/Tests.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.config + +private[spark] object Tests { + + val TEST_USE_COMPRESSED_OOPS_KEY = "spark.test.useCompressedOops" + + val TEST_MEMORY = ConfigBuilder("spark.testing.memory") + .longConf + .createWithDefault(Runtime.getRuntime.maxMemory) + + val TEST_SCHEDULE_INTERVAL = + ConfigBuilder("spark.testing.dynamicAllocation.scheduleInterval") + .longConf + .createWithDefault(100) + + val IS_TESTING = ConfigBuilder("spark.testing") + .booleanConf + .createOptional + + val TEST_NO_STAGE_RETRY = ConfigBuilder("spark.test.noStageRetry") + .booleanConf + .createWithDefault(false) + + val TEST_RESERVED_MEMORY = ConfigBuilder("spark.testing.reservedMemory") + .longConf + .createOptional + + val TEST_N_HOSTS = ConfigBuilder("spark.testing.nHosts") + .intConf + .createWithDefault(5) + + val TEST_N_EXECUTORS_HOST = ConfigBuilder("spark.testing.nExecutorsPerHost") + .intConf + .createWithDefault(4) + + val TEST_N_CORES_EXECUTOR = ConfigBuilder("spark.testing.nCoresPerExecutor") + .intConf + .createWithDefault(2) +} diff --git a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala index 0fd349dc51619..7e052c02c9376 100644 --- a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala @@ -19,6 +19,7 @@ package org.apache.spark.memory import org.apache.spark.SparkConf import org.apache.spark.internal.config +import org.apache.spark.internal.config.Tests.TEST_MEMORY import org.apache.spark.storage.BlockId /** @@ -120,7 +121,7 @@ private[spark] object StaticMemoryManager { * Return the total amount of memory available for the storage region, in bytes. */ private def getMaxStorageMemory(conf: SparkConf): Long = { - val systemMaxMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory) + val systemMaxMemory = conf.get(TEST_MEMORY) val memoryFraction = conf.getDouble("spark.storage.memoryFraction", 0.6) val safetyFraction = conf.getDouble("spark.storage.safetyFraction", 0.9) (systemMaxMemory * memoryFraction * safetyFraction).toLong @@ -130,7 +131,7 @@ private[spark] object StaticMemoryManager { * Return the total amount of memory available for the execution region, in bytes. */ private def getMaxExecutionMemory(conf: SparkConf): Long = { - val systemMaxMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory) + val systemMaxMemory = conf.get(TEST_MEMORY) if (systemMaxMemory < MIN_MEMORY_BYTES) { throw new IllegalArgumentException(s"System memory $systemMaxMemory must " + diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 9260fd3a6fb34..7801bb87050f6 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -19,6 +19,7 @@ package org.apache.spark.memory import org.apache.spark.SparkConf import org.apache.spark.internal.config +import org.apache.spark.internal.config.Tests._ import org.apache.spark.storage.BlockId /** @@ -210,9 +211,9 @@ object UnifiedMemoryManager { * Return the total amount of memory shared between execution and storage, in bytes. */ private def getMaxMemory(conf: SparkConf): Long = { - val systemMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory) - val reservedMemory = conf.getLong("spark.testing.reservedMemory", - if (conf.contains("spark.testing")) 0 else RESERVED_SYSTEM_MEMORY_BYTES) + val systemMemory = conf.get(TEST_MEMORY) + val reservedMemory = conf.getLong(TEST_RESERVED_MEMORY.key, + if (conf.contains(IS_TESTING)) 0 else RESERVED_SYSTEM_MEMORY_BYTES) val minSystemMemory = (reservedMemory * 1.5).ceil.toLong if (systemMemory < minSystemMemory) { throw new IllegalArgumentException(s"System memory $systemMemory must " + diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 6f4c326442e1e..f6ade180ee25f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -38,6 +38,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.internal.Logging import org.apache.spark.internal.config +import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.{DeterministicLevel, RDD, RDDCheckpointData} @@ -186,7 +187,7 @@ private[spark] class DAGScheduler( private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */ - private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) + private val disallowStageRetryForTest = sc.getConf.get(TEST_NO_STAGE_RETRY) /** * Whether to unregister all the outputs on the host in condition that we receive a FetchFailure, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index adef20d3077d8..66080b6e6b4ff 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -26,6 +26,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{StandaloneAppClient, StandaloneAppClientListener} import org.apache.spark.internal.{config, Logging} +import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler._ @@ -90,7 +91,7 @@ private[spark] class StandaloneSchedulerBackend( // compute-classpath.{cmd,sh} and makes all needed jars available to child processes // when the assembly is built with the "*-provided" profiles enabled. val testingClassPath = - if (sys.props.contains("spark.testing")) { + if (sys.props.contains(IS_TESTING.key)) { sys.props("java.class.path").split(java.io.File.pathSeparator).toSeq } else { Nil diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 3bfdf95db84c6..e12b6b71578c1 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -28,6 +28,7 @@ import com.google.common.collect.MapMaker import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.Tests.TEST_USE_COMPRESSED_OOPS_KEY import org.apache.spark.util.collection.OpenHashSet /** @@ -126,8 +127,8 @@ object SizeEstimator extends Logging { private def getIsCompressedOops: Boolean = { // This is only used by tests to override the detection of compressed oops. The test // actually uses a system property instead of a SparkConf, so we'll stick with that. - if (System.getProperty("spark.test.useCompressedOops") != null) { - return System.getProperty("spark.test.useCompressedOops").toBoolean + if (System.getProperty(TEST_USE_COMPRESSED_OOPS_KEY) != null) { + return System.getProperty(TEST_USE_COMPRESSED_OOPS_KEY).toBoolean } // java.vm.info provides compressed ref info for IBM JDKs diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 3527fee68939d..16ef38142ad9f 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -60,6 +60,7 @@ import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.{config, Logging} import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} @@ -1847,7 +1848,7 @@ private[spark] object Utils extends Logging { * Indicates whether Spark is currently running unit tests. */ def isTesting: Boolean = { - sys.env.contains("SPARK_TESTING") || sys.props.contains("spark.testing") + sys.env.contains("SPARK_TESTING") || sys.props.contains(IS_TESTING.key) } /** @@ -2175,7 +2176,7 @@ private[spark] object Utils extends Logging { */ def portMaxRetries(conf: SparkConf): Int = { val maxRetries = conf.getOption("spark.port.maxRetries").map(_.toInt) - if (conf.contains("spark.testing")) { + if (conf.contains(IS_TESTING)) { // Set a higher number of retries for tests... maxRetries.getOrElse(100) } else { diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 4083b20c23594..21050e44414f5 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.{Millis, Span} import org.apache.spark.internal.config +import org.apache.spark.internal.config.Tests._ import org.apache.spark.security.EncryptionFunSuite import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.io.ChunkedByteBuffer @@ -217,7 +218,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val size = 10000 val conf = new SparkConf() .set("spark.storage.unrollMemoryThreshold", "1024") - .set("spark.testing.memory", (size / 2).toString) + .set(TEST_MEMORY, size.toLong / 2) sc = new SparkContext(clusterUrl, "test", conf) val data = sc.parallelize(1 to size, 2).persist(StorageLevel.MEMORY_ONLY) assert(data.count() === size) @@ -233,7 +234,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val numPartitions = 20 val conf = new SparkConf() .set("spark.storage.unrollMemoryThreshold", "1024") - .set("spark.testing.memory", size.toString) + .set(TEST_MEMORY, size.toLong) sc = new SparkContext(clusterUrl, "test", conf) val data = sc.parallelize(1 to size, numPartitions).persist(StorageLevel.MEMORY_ONLY) assert(data.count() === size) diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 38f5e8c9f0ac8..6b310b9cb67aa 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.config +import org.apache.spark.internal.config.Tests.TEST_SCHEDULE_INTERVAL import org.apache.spark.scheduler._ import org.apache.spark.scheduler.ExternalClusterManager import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -1166,7 +1167,7 @@ class ExecutorAllocationManagerSuite .set("spark.dynamicAllocation.testing", "true") // SPARK-22864: effectively disable the allocation schedule by setting the period to a // really long value. - .set(TESTING_SCHEDULE_INTERVAL_KEY, "10000") + .set(TEST_SCHEDULE_INTERVAL, 10000L) val sc = new SparkContext(conf) contexts += sc sc diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 35f728cd57fe2..ffa70425ea367 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -23,6 +23,7 @@ import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService import org.scalatest.Matchers import org.apache.spark.ShuffleSuite.NonJavaSerializableClass +import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD} import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListenerTaskEnd} @@ -37,7 +38,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // Ensure that the DAGScheduler doesn't retry stages whose fetches fail, so that we accurately // test that the shuffle works (rather than retrying until all blocks are local to one Executor). - conf.set("spark.test.noStageRetry", "true") + conf.set(TEST_NO_STAGE_RETRY, true) test("groupByKey without compression") { val myConf = conf.clone().set("spark.shuffle.compress", "false") @@ -269,7 +270,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC } test("[SPARK-4085] rerun map stage if reduce stage cannot find its local shuffle file") { - val myConf = conf.clone().set("spark.test.noStageRetry", "false") + val myConf = conf.clone().set(TEST_NO_STAGE_RETRY, false) sc = new SparkContext("local", "test", myConf) val rdd = sc.parallelize(1 to 10, 2).map((_, 1)).reduceByKey(_ + _) rdd.count() diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index dad24d7c01b8b..7d114b1b0c144 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -23,6 +23,7 @@ import java.io.File import org.scalatest.{BeforeAndAfterAll, FunSuite, Outcome} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.util.{AccumulatorContext, Utils} /** @@ -59,7 +60,7 @@ abstract class SparkFunSuite protected val enableAutoThreadAudit = true protected override def beforeAll(): Unit = { - System.setProperty("spark.testing", "true") + System.setProperty(IS_TESTING.key, "true") if (enableAutoThreadAudit) { doThreadPreAudit() } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala index 6b479873f69f2..5903ae71ec66e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala @@ -23,7 +23,7 @@ import com.google.common.io.Files import org.apache.spark._ import org.apache.spark.internal.config.History._ -import org.apache.spark.util.Utils +import org.apache.spark.internal.config.Tests._ class HistoryServerArgumentsSuite extends SparkFunSuite { @@ -31,14 +31,14 @@ class HistoryServerArgumentsSuite extends SparkFunSuite { private val conf = new SparkConf() .set(HISTORY_LOG_DIR, logDir.getAbsolutePath) .set(UPDATE_INTERVAL_S, 1L) - .set("spark.testing", "true") + .set(IS_TESTING, true) test("No Arguments Parsing") { val argStrings = Array.empty[String] val hsa = new HistoryServerArguments(conf, argStrings) assert(conf.get(HISTORY_LOG_DIR) === logDir.getAbsolutePath) assert(conf.get(UPDATE_INTERVAL_S) === 1L) - assert(conf.get("spark.testing") === "true") + assert(conf.get(IS_TESTING).getOrElse(false)) } test("Properties File Arguments Parsing --properties-file") { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 96458c55b5f55..bb7d3c52bc9c4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -47,6 +47,7 @@ import org.scalatest.selenium.WebBrowser import org.apache.spark._ import org.apache.spark.internal.config._ import org.apache.spark.internal.config.History._ +import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.status.api.v1.ApplicationInfo import org.apache.spark.status.api.v1.JobData import org.apache.spark.ui.SparkUI @@ -81,7 +82,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers val conf = new SparkConf() .set(HISTORY_LOG_DIR, logDir) .set(UPDATE_INTERVAL_S.key, "0") - .set("spark.testing", "true") + .set(IS_TESTING, true) .set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) .set(EVENT_LOG_STAGE_EXECUTOR_METRICS, true) .set(EVENT_LOG_PROCESS_TREE_METRICS, true) @@ -400,7 +401,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers */ test("security manager starts with spark.authenticate set") { val conf = new SparkConf() - .set("spark.testing", "true") + .set(IS_TESTING, true) .set(SecurityManager.SPARK_AUTH_CONF, "true") HistoryServer.createSecurityManager(conf) } @@ -422,7 +423,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers .set(UPDATE_INTERVAL_S.key, "1s") .set(EVENT_LOG_ENABLED, true) .set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) - .remove("spark.testing") + .remove(IS_TESTING) val provider = new FsHistoryProvider(myConf) val securityManager = HistoryServer.createSecurityManager(myConf) diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala index 0f32fe4059fbb..c3275add50f48 100644 --- a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala @@ -21,6 +21,7 @@ import org.mockito.Mockito.when import org.apache.spark.SparkConf import org.apache.spark.internal.config.MEMORY_OFFHEAP_SIZE +import org.apache.spark.internal.config.Tests.TEST_MEMORY import org.apache.spark.storage.TestBlockId import org.apache.spark.storage.memory.MemoryStore @@ -48,8 +49,8 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { new StaticMemoryManager( conf.clone .set("spark.memory.fraction", "1") - .set("spark.testing.memory", maxOnHeapExecutionMemory.toString) - .set(MEMORY_OFFHEAP_SIZE.key, maxOffHeapExecutionMemory.toString), + .set(TEST_MEMORY, maxOnHeapExecutionMemory) + .set(MEMORY_OFFHEAP_SIZE, maxOffHeapExecutionMemory), maxOnHeapExecutionMemory = maxOnHeapExecutionMemory, maxOnHeapStorageMemory = 0, numCores = 1) diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index 5ce3453b682fe..8556e920daebb 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.SparkConf import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.Tests._ import org.apache.spark.storage.TestBlockId import org.apache.spark.storage.memory.MemoryStore @@ -43,8 +44,8 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes maxOffHeapExecutionMemory: Long): UnifiedMemoryManager = { val conf = new SparkConf() .set("spark.memory.fraction", "1") - .set("spark.testing.memory", maxOnHeapExecutionMemory.toString) - .set(MEMORY_OFFHEAP_SIZE.key, maxOffHeapExecutionMemory.toString) + .set(TEST_MEMORY, maxOnHeapExecutionMemory) + .set(MEMORY_OFFHEAP_SIZE, maxOffHeapExecutionMemory) .set("spark.memory.storageFraction", storageFraction.toString) UnifiedMemoryManager(conf, numCores = 1) } @@ -218,19 +219,19 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes } test("small heap") { - val systemMemory = 1024 * 1024 - val reservedMemory = 300 * 1024 + val systemMemory = 1024L * 1024 + val reservedMemory = 300L * 1024 val memoryFraction = 0.8 val conf = new SparkConf() .set("spark.memory.fraction", memoryFraction.toString) - .set("spark.testing.memory", systemMemory.toString) - .set("spark.testing.reservedMemory", reservedMemory.toString) + .set(TEST_MEMORY, systemMemory) + .set(TEST_RESERVED_MEMORY, reservedMemory) val mm = UnifiedMemoryManager(conf, numCores = 1) val expectedMaxMemory = ((systemMemory - reservedMemory) * memoryFraction).toLong assert(mm.maxHeapMemory === expectedMaxMemory) // Try using a system memory that's too small - val conf2 = conf.clone().set("spark.testing.memory", (reservedMemory / 2).toString) + val conf2 = conf.clone().set(TEST_MEMORY, reservedMemory / 2) val exception = intercept[IllegalArgumentException] { UnifiedMemoryManager(conf2, numCores = 1) } @@ -238,13 +239,13 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes } test("insufficient executor memory") { - val systemMemory = 1024 * 1024 - val reservedMemory = 300 * 1024 + val systemMemory = 1024L * 1024 + val reservedMemory = 300L * 1024 val memoryFraction = 0.8 val conf = new SparkConf() .set("spark.memory.fraction", memoryFraction.toString) - .set("spark.testing.memory", systemMemory.toString) - .set("spark.testing.reservedMemory", reservedMemory.toString) + .set(TEST_MEMORY, systemMemory) + .set(TEST_RESERVED_MEMORY, reservedMemory) val mm = UnifiedMemoryManager(conf, numCores = 1) // Try using an executor memory that's too small @@ -259,7 +260,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes val conf = new SparkConf() .set("spark.memory.fraction", "1") .set("spark.memory.storageFraction", "0") - .set("spark.testing.memory", "1000") + .set(TEST_MEMORY, 1000L) val mm = UnifiedMemoryManager(conf, numCores = 2) val ms = makeMemoryStore(mm) val memoryMode = MemoryMode.ON_HEAP @@ -285,7 +286,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes val conf = new SparkConf() .set("spark.memory.fraction", "1") .set("spark.memory.storageFraction", "0") - .set("spark.testing.memory", "1000") + .set(TEST_MEMORY, 1000L) val mm = UnifiedMemoryManager(conf, numCores = 2) makeBadMemoryStore(mm) val memoryMode = MemoryMode.ON_HEAP @@ -306,9 +307,9 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes test("not enough free memory in the storage pool --OFF_HEAP") { val conf = new SparkConf() - .set(MEMORY_OFFHEAP_SIZE.key, "1000") - .set("spark.testing.memory", "1000") - .set(MEMORY_OFFHEAP_ENABLED.key, "true") + .set(MEMORY_OFFHEAP_SIZE, 1000L) + .set(TEST_MEMORY, 1000L) + .set(MEMORY_OFFHEAP_ENABLED, true) val taskAttemptId = 0L val mm = UnifiedMemoryManager(conf, numCores = 1) val ms = makeMemoryStore(mm) diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index 36dd620a56853..112fd31a060e6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import scala.util.Random import org.apache.spark._ +import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { @@ -76,7 +77,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { test("throw exception on barrier() call timeout") { val conf = new SparkConf() .set("spark.barrier.sync.timeout", "1") - .set("spark.test.noStageRetry", "true") + .set(TEST_NO_STAGE_RETRY, true) .setMaster("local-cluster[4, 1, 1024]") .setAppName("test-cluster") sc = new SparkContext(conf) @@ -101,7 +102,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { test("throw exception if barrier() call doesn't happen on every task") { val conf = new SparkConf() .set("spark.barrier.sync.timeout", "1") - .set("spark.test.noStageRetry", "true") + .set(TEST_NO_STAGE_RETRY, true) .setMaster("local-cluster[4, 1, 1024]") .setAppName("test-cluster") sc = new SparkContext(conf) @@ -124,7 +125,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { test("throw exception if the number of barrier() calls are not the same on every task") { val conf = new SparkConf() .set("spark.barrier.sync.timeout", "1") - .set("spark.test.noStageRetry", "true") + .set(TEST_NO_STAGE_RETRY, true) .setMaster("local-cluster[4, 1, 1024]") .setAppName("test-cluster") sc = new SparkContext(conf) diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala index 29bb8232f44f5..2215f7f366213 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala @@ -20,6 +20,7 @@ import scala.concurrent.duration._ import org.apache.spark._ import org.apache.spark.internal.config +import org.apache.spark.internal.config.Tests._ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorMockBackend]{ @@ -58,9 +59,9 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM extraConfs = Seq( config.BLACKLIST_ENABLED.key -> "true", config.MAX_TASK_FAILURES.key -> "4", - "spark.testing.nHosts" -> "2", - "spark.testing.nExecutorsPerHost" -> "5", - "spark.testing.nCoresPerExecutor" -> "10" + TEST_N_HOSTS.key -> "2", + TEST_N_EXECUTORS_HOST.key -> "5", + TEST_N_CORES_EXECUTOR.key -> "10" ) ) { // To reliably reproduce the failure that would occur without blacklisting, we have to use 1 @@ -102,9 +103,9 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM "SPARK-15865 Progress with fewer executors than maxTaskFailures", extraConfs = Seq( config.BLACKLIST_ENABLED.key -> "true", - "spark.testing.nHosts" -> "2", - "spark.testing.nExecutorsPerHost" -> "1", - "spark.testing.nCoresPerExecutor" -> "1", + TEST_N_HOSTS.key -> "2", + TEST_N_EXECUTORS_HOST.key -> "1", + TEST_N_CORES_EXECUTOR.key -> "1", "spark.scheduler.blacklist.unschedulableTaskSetTimeout" -> "0s" ) ) { @@ -129,9 +130,9 @@ class MultiExecutorMockBackend( conf: SparkConf, taskScheduler: TaskSchedulerImpl) extends MockBackend(conf, taskScheduler) { - val nHosts = conf.getInt("spark.testing.nHosts", 5) - val nExecutorsPerHost = conf.getInt("spark.testing.nExecutorsPerHost", 4) - val nCoresPerExecutor = conf.getInt("spark.testing.nCoresPerExecutor", 2) + val nHosts = conf.get(TEST_N_HOSTS) + val nExecutorsPerHost = conf.get(TEST_N_EXECUTORS_HOST) + val nCoresPerExecutor = conf.get(TEST_N_CORES_EXECUTOR) override val executorIdToExecutor: Map[String, ExecutorTaskStatus] = { (0 until nHosts).flatMap { hostIdx => diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala index b9f0e873375b0..43621cb85762c 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.mockito.MockitoSugar import org.apache.spark._ import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} +import org.apache.spark.internal.config.Tests._ import org.apache.spark.memory._ import org.apache.spark.unsafe.Platform @@ -33,8 +34,8 @@ class ShuffleExternalSorterSuite extends SparkFunSuite with LocalSparkContext wi val conf = new SparkConf() .setMaster("local[1]") .setAppName("ShuffleExternalSorterSuite") - .set("spark.testing", "true") - .set("spark.testing.memory", "1600") + .set(IS_TESTING, true) + .set(TEST_MEMORY, 1600L) .set("spark.memory.fraction", "1") sc = new SparkContext(conf) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 19116cf22d2f8..480e07fb9399a 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{DRIVER_PORT, MEMORY_OFFHEAP_SIZE} +import org.apache.spark.internal.config.Tests._ import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService @@ -69,8 +70,8 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite protected def makeBlockManager( maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { - conf.set("spark.testing.memory", maxMem.toString) - conf.set(MEMORY_OFFHEAP_SIZE.key, maxMem.toString) + conf.set(TEST_MEMORY, maxMem) + conf.set(MEMORY_OFFHEAP_SIZE, maxMem) val transfer = new NettyBlockTransferService(conf, securityMgr, "localhost", "localhost", 0, 1) val memManager = UnifiedMemoryManager(conf, numCores = 1) val serializerManager = new SerializerManager(serializer, conf) @@ -87,7 +88,7 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite conf.set("spark.authenticate", "false") conf.set(DRIVER_PORT, rpcEnv.address.port) - conf.set("spark.testing", "true") + conf.set(IS_TESTING, true) conf.set("spark.memory.fraction", "1") conf.set("spark.memory.storageFraction", "1") conf.set("spark.storage.unrollFraction", "0.4") @@ -233,7 +234,7 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite val failableTransfer = mock(classOf[BlockTransferService]) // this wont actually work when(failableTransfer.hostName).thenReturn("some-hostname") when(failableTransfer.port).thenReturn(1000) - conf.set("spark.testing.memory", "10000") + conf.set(TEST_MEMORY, 10000L) val memManager = UnifiedMemoryManager(conf, numCores = 1) val serializerManager = new SerializerManager(serializer, conf) val failableStore = new BlockManager("failable-store", rpcEnv, master, serializerManager, conf, diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index a7bb2a03360aa..bda81365b0792 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -37,6 +37,7 @@ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.DataReadMethod import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.Tests._ import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.{BlockDataManager, BlockTransferService, TransportContext} import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} @@ -89,8 +90,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE transferService: Option[BlockTransferService] = Option.empty, testConf: Option[SparkConf] = None): BlockManager = { val bmConf = testConf.map(_.setAll(conf.getAll)).getOrElse(conf) - bmConf.set("spark.testing.memory", maxMem.toString) - bmConf.set(MEMORY_OFFHEAP_SIZE.key, maxMem.toString) + bmConf.set(TEST_MEMORY, maxMem) + bmConf.set(MEMORY_OFFHEAP_SIZE, maxMem) val serializer = new KryoSerializer(bmConf) val encryptionKey = if (bmConf.get(IO_ENCRYPTION_ENABLED)) { Some(CryptoStreamUtils.createKey(bmConf)) @@ -115,11 +116,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE System.setProperty("os.arch", "amd64") conf = new SparkConf(false) .set("spark.app.id", "test") - .set("spark.testing", "true") + .set(IS_TESTING, true) .set("spark.memory.fraction", "1") .set("spark.memory.storageFraction", "1") .set("spark.kryoserializer.buffer", "1m") - .set("spark.test.useCompressedOops", "true") .set("spark.storage.unrollFraction", "0.4") .set("spark.storage.unrollMemoryThreshold", "512") @@ -901,7 +901,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("block store put failure") { // Use Java serializer so we can create an unserializable error. - conf.set("spark.testing.memory", "1200") + conf.set(TEST_MEMORY, 1200L) val transfer = new NettyBlockTransferService(conf, securityMgr, "localhost", "localhost", 0, 1) val memoryManager = UnifiedMemoryManager(conf, numCores = 1) val serializerManager = new SerializerManager(new JavaSerializer(conf), conf) diff --git a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala index baff672f5fb8f..b02af2bfe7acc 100644 --- a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala @@ -39,7 +39,6 @@ class MemoryStoreSuite with ResetSystemProperties { var conf: SparkConf = new SparkConf(false) - .set("spark.test.useCompressedOops", "true") .set("spark.storage.unrollFraction", "0.4") .set("spark.storage.unrollMemoryThreshold", "512") diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 63f9f82adf3e0..8bc62db81e4f9 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.config.Tests.TEST_USE_COMPRESSED_OOPS_KEY class DummyClass1 {} @@ -76,7 +77,7 @@ class SizeEstimatorSuite // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case super.beforeEach() System.setProperty("os.arch", "amd64") - System.setProperty("spark.test.useCompressedOops", "true") + System.setProperty(TEST_USE_COMPRESSED_OOPS_KEY, "true") } override def afterEach(): Unit = { @@ -192,7 +193,7 @@ class SizeEstimatorSuite // (Sun vs IBM). Use a DummyString class to make tests deterministic. test("64-bit arch with no compressed oops") { System.setProperty("os.arch", "amd64") - System.setProperty("spark.test.useCompressedOops", "false") + System.setProperty(TEST_USE_COMPRESSED_OOPS_KEY, "false") val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 35fba1a3b73c6..6211399005e1a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.concurrent.Eventually import org.apache.spark._ import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.Tests.TEST_MEMORY import org.apache.spark.io.CompressionCodec import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.util.CompletionIterator @@ -552,7 +553,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite val conf = createSparkConf(loadDefaults = false) .set("spark.shuffle.memoryFraction", "0.01") .set("spark.memory.useLegacyMode", "true") - .set("spark.testing.memory", "100000000") + .set(TEST_MEMORY, 100000000L) .set("spark.shuffle.sort.bypassMergeThreshold", "0") sc = new SparkContext("local", "test", conf) val N = 2e5.toInt diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 47173b89e91e2..aa400dd74e9ca 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -23,6 +23,7 @@ import scala.collection.mutable.ArrayBuffer import scala.util.Random import org.apache.spark._ +import org.apache.spark.internal.config.Tests.TEST_MEMORY import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.unsafe.array.LongArray @@ -639,7 +640,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { val conf = createSparkConf(loadDefaults = false, kryo = false) .set("spark.shuffle.memoryFraction", "0.01") .set("spark.memory.useLegacyMode", "true") - .set("spark.testing.memory", "100000000") + .set(TEST_MEMORY, 100000000L) .set("spark.shuffle.sort.bypassMergeThreshold", "0") sc = new SparkContext("local", "test", conf) val N = 2e5.toInt diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala index c0b435efb8c9c..cc89683949010 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala @@ -27,6 +27,7 @@ import org.scalatest.concurrent.Eventually import org.apache.spark.deploy.k8s.integrationtest.TestConstants._ import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.Tests.IS_TESTING private[spark] class KubernetesTestComponents(defaultClient: DefaultKubernetesClient) { @@ -67,7 +68,7 @@ private[spark] class KubernetesTestComponents(defaultClient: DefaultKubernetesCl .set("spark.executors.instances", "1") .set("spark.app.name", "spark-test-app") .set("spark.ui.enabled", "true") - .set("spark.testing", "false") + .set(IS_TESTING, false) .set("spark.kubernetes.submission.waitAppCompletion", "false") .set("spark.kubernetes.authenticate.driver.serviceAccountName", serviceAccountName) } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 03cd2583b9b2f..fb235350700f9 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -33,6 +33,7 @@ import org.apache.spark.{SecurityManager, SparkConf, SparkContext, SparkExceptio import org.apache.spark.deploy.mesos.config._ import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.config +import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient @@ -298,7 +299,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } protected def driverURL: String = { - if (conf.contains("spark.testing")) { + if (conf.contains(IS_TESTING)) { "driverURL" } else { RpcEndpointAddress( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index dda7cb55f5395..5b38fe5c46bbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicLong -import org.apache.spark.SparkContext +import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart} @@ -38,7 +38,7 @@ object SQLExecution { executionIdToQueryExecution.get(executionId) } - private val testing = sys.props.contains("spark.testing") + private val testing = sys.props.contains(IS_TESTING.key) private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = { val sc = sparkSession.sparkContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala index d95794d624033..c37d663941d8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterAll +import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodeGenerator} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} @@ -29,7 +30,7 @@ abstract class BenchmarkQueryTest extends QueryTest with SharedSQLContext with B // When Utils.isTesting is true, the RuleExecutor will issue an exception when hitting // the max iteration of analyzer/optimizer batches. - assert(Utils.isTesting, "spark.testing is not set to true") + assert(Utils.isTesting, s"${IS_TESTING.key} is not set to true") /** * Drop all the tables diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index ca8692290edb2..963e42517b441 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -21,6 +21,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File} import java.util.Properties import org.apache.spark._ +import org.apache.spark.internal.config.Tests.TEST_MEMORY import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rdd.RDD import org.apache.spark.sql.{LocalSparkSession, Row, SparkSession} @@ -99,7 +100,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession { val conf = new SparkConf() .set("spark.shuffle.spill.initialMemoryThreshold", "1") .set("spark.shuffle.sort.bypassMergeThreshold", "0") - .set("spark.testing.memory", "80000") + .set(TEST_MEMORY, 80000L) spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate() val outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "") outputFile.deleteOnExit() From 5102ccc4ab6e30caa5510131dee7098b4f3ad32e Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Mon, 7 Jan 2019 15:48:54 -0800 Subject: [PATCH 2396/2461] [SPARK-26339][SQL][FOLLOW-UP] Issue warning instead of throwing an exception for underscore files ## What changes were proposed in this pull request? The PR https://github.com/apache/spark/pull/23446 happened to introduce a behaviour change - empty dataframes can't be read anymore from underscore files. It looks controversial to allow or disallow this case so this PR targets to fix to issue warning instead of throwing an exception to be more conservative. **Before** ```scala scala> spark.read.schema("a int").parquet("_tmp*").show() org.apache.spark.sql.AnalysisException: All paths were ignored: file:/.../_tmp file:/.../_tmp1; at org.apache.spark.sql.execution.datasources.DataSource.checkAndGlobPathIfNecessary(DataSource.scala:570) at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:360) at org.apache.spark.sql.DataFrameReader.loadV1Source(DataFrameReader.scala:231) at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:219) at org.apache.spark.sql.DataFrameReader.parquet(DataFrameReader.scala:651) at org.apache.spark.sql.DataFrameReader.parquet(DataFrameReader.scala:635) ... 49 elided scala> spark.read.text("_tmp*").show() org.apache.spark.sql.AnalysisException: All paths were ignored: file:/.../_tmp file:/.../_tmp1; at org.apache.spark.sql.execution.datasources.DataSource.checkAndGlobPathIfNecessary(DataSource.scala:570) at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:360) at org.apache.spark.sql.DataFrameReader.loadV1Source(DataFrameReader.scala:231) at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:219) at org.apache.spark.sql.DataFrameReader.text(DataFrameReader.scala:723) at org.apache.spark.sql.DataFrameReader.text(DataFrameReader.scala:695) ... 49 elided ``` **After** ```scala scala> spark.read.schema("a int").parquet("_tmp*").show() 19/01/07 15:14:43 WARN DataSource: All paths were ignored: file:/.../_tmp file:/.../_tmp1 +---+ | a| +---+ +---+ scala> spark.read.text("_tmp*").show() 19/01/07 15:14:51 WARN DataSource: All paths were ignored: file:/.../_tmp file:/.../_tmp1 +-----+ |value| +-----+ +-----+ ``` ## How was this patch tested? Manually tested as above. Closes #23481 from HyukjinKwon/SPARK-26339. Authored-by: Hyukjin Kwon Signed-off-by: gatorsmile --- .../execution/datasources/DataSource.scala | 6 +++--- .../src/test/resources/test-data/_cars.csv | 7 ------- .../execution/datasources/csv/CSVSuite.scala | 20 ------------------- 3 files changed, 3 insertions(+), 30 deletions(-) delete mode 100644 sql/core/src/test/resources/test-data/_cars.csv diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 2a438a5cbf957..5dad784e45af5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -567,11 +567,11 @@ case class DataSource( } if (filteredOut.nonEmpty) { if (filteredIn.isEmpty) { - throw new AnalysisException( - s"All paths were ignored:\n${filteredOut.mkString("\n ")}") + logWarning( + s"All paths were ignored:\n ${filteredOut.mkString("\n ")}") } else { logDebug( - s"Some paths were ignored:\n${filteredOut.mkString("\n ")}") + s"Some paths were ignored:\n ${filteredOut.mkString("\n ")}") } } } diff --git a/sql/core/src/test/resources/test-data/_cars.csv b/sql/core/src/test/resources/test-data/_cars.csv deleted file mode 100644 index 40ded573ade5c..0000000000000 --- a/sql/core/src/test/resources/test-data/_cars.csv +++ /dev/null @@ -1,7 +0,0 @@ - -year,make,model,comment,blank -"2012","Tesla","S","No comment", - -1997,Ford,E350,"Go get one now they are going fast", -2015,Chevy,Volt - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index fb1bedfaa32c3..d9e5d7af19671 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -53,7 +53,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te private val carsEmptyValueFile = "test-data/cars-empty-value.csv" private val carsBlankColName = "test-data/cars-blank-column-name.csv" private val carsCrlf = "test-data/cars-crlf.csv" - private val carsFilteredOutFile = "test-data/_cars.csv" private val emptyFile = "test-data/empty.csv" private val commentsFile = "test-data/comments.csv" private val disableCommentsFile = "test-data/disable_comments.csv" @@ -347,25 +346,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te assert(result.schema.fieldNames.size === 1) } - test("SPARK-26339 Not throw an exception if some of specified paths are filtered in") { - val cars = spark - .read - .option("header", "false") - .csv(testFile(carsFile), testFile(carsFilteredOutFile)) - - verifyCars(cars, withHeader = false, checkTypes = false) - } - - test("SPARK-26339 Throw an exception only if all of the specified paths are filtered out") { - val e = intercept[AnalysisException] { - val cars = spark - .read - .option("header", "false") - .csv(testFile(carsFilteredOutFile)) - }.getMessage - assert(e.contains("All paths were ignored:")) - } - test("DDL test with empty file") { withView("carsTable") { spark.sql( From 5fb5a0292d9ced48860abe712a10cbb8e513b75a Mon Sep 17 00:00:00 2001 From: Adrian Tanase Date: Mon, 7 Jan 2019 19:03:38 -0600 Subject: [PATCH 2397/2461] [MINOR][K8S] add missing docs for podTemplateContainerName properties ## What changes were proposed in this pull request? Adding docs for an enhancement that came in late in this PR: #22146 Currently the docs state that we're going to use the first container in a pod template, which was the implementation for some time, until it was improved with 2 new properties. ## How was this patch tested? I tested that the properties work by combining pod templates with client-mode and a simple pod template. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #23155 from aditanase/k8s-readme. Authored-by: Adrian Tanase Signed-off-by: Sean Owen --- docs/running-on-kubernetes.md | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 3172b1bca8f05..3453ee912205f 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -229,8 +229,11 @@ pod template that will always be overwritten by Spark. Therefore, users of this the pod template file only lets Spark start with a template pod instead of an empty pod during the pod-building process. For details, see the [full list](#pod-template-properties) of pod template values that will be overwritten by spark. -Pod template files can also define multiple containers. In such cases, Spark will always assume that the first container in -the list will be the driver or executor container. +Pod template files can also define multiple containers. In such cases, you can use the spark properties +`spark.kubernetes.driver.podTemplateContainerName` and `spark.kubernetes.executor.podTemplateContainerName` +to indicate which container should be used as a basis for the driver or executor. +If not specified, or if the container name is not valid, Spark will assume that the first container in the list +will be the driver or executor container. ## Using Kubernetes Volumes @@ -932,16 +935,32 @@ specific to Spark on Kubernetes. spark.kubernetes.driver.podTemplateFile (none) - Specify the local file that contains the driver [pod template](#pod-template). For example - spark.kubernetes.driver.podTemplateFile=/path/to/driver-pod-template.yaml` + Specify the local file that contains the driver pod template. For example + spark.kubernetes.driver.podTemplateFile=/path/to/driver-pod-template.yaml + + + + spark.kubernetes.driver.podTemplateContainerName + (none) + + Specify the container name to be used as a basis for the driver in the given pod template. + For example spark.kubernetes.driver.podTemplateContainerName=spark-driver spark.kubernetes.executor.podTemplateFile (none) - Specify the local file that contains the executor [pod template](#pod-template). For example - spark.kubernetes.executor.podTemplateFile=/path/to/executor-pod-template.yaml` + Specify the local file that contains the executor pod template. For example + spark.kubernetes.executor.podTemplateFile=/path/to/executor-pod-template.yaml + + + + spark.kubernetes.executor.podTemplateContainerName + (none) + + Specify the container name to be used as a basis for the executor in the given pod template. + For example spark.kubernetes.executor.podTemplateContainerName=spark-executor From 6f35ede31cc72a81e3852b1ac7454589d1897bfc Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 7 Jan 2019 17:54:05 -0800 Subject: [PATCH 2398/2461] [SPARK-26554][BUILD][FOLLOWUP] Use GitHub instead of GitBox to check HEADER ## What changes were proposed in this pull request? This PR uses GitHub repository instead of GitBox because GitHub repo returns HTTP header status correctly. ## How was this patch tested? Manual. ``` $ ./do-release-docker.sh -d /tmp/test -n Branch [branch-2.4]: Current branch version is 2.4.1-SNAPSHOT. Release [2.4.1]: RC # [1]: This is a dry run. Please confirm the ref that will be built for testing. Ref [v2.4.1-rc1]: ``` Closes #23482 from dongjoon-hyun/SPARK-26554-2. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- dev/create-release/release-util.sh | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/dev/create-release/release-util.sh b/dev/create-release/release-util.sh index 9a340528b506d..5486c18e95bc8 100755 --- a/dev/create-release/release-util.sh +++ b/dev/create-release/release-util.sh @@ -21,6 +21,7 @@ DRY_RUN=${DRY_RUN:-0} GPG="gpg --no-tty --batch" ASF_REPO="https://gitbox.apache.org/repos/asf/spark.git" ASF_REPO_WEBUI="https://gitbox.apache.org/repos/asf?p=spark.git" +ASF_GITHUB_REPO="https://github.com/apache/spark" function error { echo "$*" @@ -73,9 +74,7 @@ function fcreate_secure { } function check_for_tag { - # Check HTML body messages instead of header status codes. Apache GitBox returns - # a header with `200 OK` status code for both existing and non-existing tag URLs - ! curl -s --fail "$ASF_REPO_WEBUI;a=commit;h=$1" | grep '404 Not Found' > /dev/null + curl -s --head --fail "$ASF_GITHUB_REPO/releases/tag/$1" > /dev/null } function get_release_info { From 29a7d2da44585d91a9e94bf88dc7b1f42a0e5674 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 7 Jan 2019 18:59:43 -0800 Subject: [PATCH 2399/2461] [SPARK-24196][SQL] Implement Spark's own GetSchemasOperation ## What changes were proposed in this pull request? This PR fix SQL Client tools can't show DBs by implementing Spark's own `GetSchemasOperation`. ## How was this patch tested? unit tests and manual tests ![image](https://user-images.githubusercontent.com/5399861/47782885-3dd5d400-dd3c-11e8-8586-59a8c15c7020.png) ![image](https://user-images.githubusercontent.com/5399861/47782899-4928ff80-dd3c-11e8-9d2d-ba9580ba4301.png) Closes #22903 from wangyum/SPARK-24196. Authored-by: Yuming Wang Signed-off-by: gatorsmile --- .../cli/operation/GetSchemasOperation.java | 2 +- .../SparkGetSchemasOperation.scala | 66 +++++++++++ .../server/SparkSQLOperationManager.scala | 17 ++- .../HiveThriftServer2Suites.scala | 16 +++ .../SparkMetadataOperationSuite.scala | 103 ++++++++++++++++++ 5 files changed, 201 insertions(+), 3 deletions(-) create mode 100644 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala create mode 100644 sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetSchemasOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetSchemasOperation.java index d6f6280f1c398..3516bc2ba242c 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetSchemasOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetSchemasOperation.java @@ -41,7 +41,7 @@ public class GetSchemasOperation extends MetadataOperation { .addStringColumn("TABLE_SCHEM", "Schema name.") .addStringColumn("TABLE_CATALOG", "Catalog name."); - private RowSet rowSet; + protected RowSet rowSet; protected GetSchemasOperation(HiveSession parentSession, String catalogName, String schemaName) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala new file mode 100644 index 0000000000000..d585049c28e33 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveOperationType +import org.apache.hive.service.cli._ +import org.apache.hive.service.cli.operation.GetSchemasOperation +import org.apache.hive.service.cli.operation.MetadataOperation.DEFAULT_HIVE_CATALOG +import org.apache.hive.service.cli.session.HiveSession + +import org.apache.spark.sql.SQLContext + +/** + * Spark's own GetSchemasOperation + * + * @param sqlContext SQLContext to use + * @param parentSession a HiveSession from SessionManager + * @param catalogName catalog name. null if not applicable. + * @param schemaName database name, null or a concrete database name + */ +private[hive] class SparkGetSchemasOperation( + sqlContext: SQLContext, + parentSession: HiveSession, + catalogName: String, + schemaName: String) + extends GetSchemasOperation(parentSession, catalogName, schemaName) { + + override def runInternal(): Unit = { + setState(OperationState.RUNNING) + // Always use the latest class loader provided by executionHive's state. + val executionHiveClassLoader = sqlContext.sharedState.jarClassLoader + Thread.currentThread().setContextClassLoader(executionHiveClassLoader) + + if (isAuthV2Enabled) { + val cmdStr = s"catalog : $catalogName, schemaPattern : $schemaName" + authorizeMetaGets(HiveOperationType.GET_TABLES, null, cmdStr) + } + + try { + val schemaPattern = convertSchemaPattern(schemaName) + sqlContext.sessionState.catalog.listDatabases(schemaPattern).foreach { dbName => + rowSet.addRow(Array[AnyRef](dbName, DEFAULT_HIVE_CATALOG)) + } + setState(OperationState.FINISHED) + } catch { + case e: HiveSQLException => + setState(OperationState.ERROR) + throw e + } + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index bf7c01f60fb5c..85b6c7134755b 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -21,13 +21,13 @@ import java.util.{Map => JMap} import java.util.concurrent.ConcurrentHashMap import org.apache.hive.service.cli._ -import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager} +import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, GetSchemasOperation, Operation, OperationManager} import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.internal.Logging import org.apache.spark.sql.SQLContext import org.apache.spark.sql.hive.HiveUtils -import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation} +import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation, SparkGetSchemasOperation} import org.apache.spark.sql.internal.SQLConf /** @@ -63,6 +63,19 @@ private[thriftserver] class SparkSQLOperationManager() operation } + override def newGetSchemasOperation( + parentSession: HiveSession, + catalogName: String, + schemaName: String): GetSchemasOperation = synchronized { + val sqlContext = sessionToContexts.get(parentSession.getSessionHandle) + require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" + + " initialized or had already closed.") + val operation = new SparkGetSchemasOperation(sqlContext, parentSession, catalogName, schemaName) + handleToOperation.put(operation.getHandle, operation) + logDebug(s"Created GetSchemasOperation with session=$parentSession.") + operation + } + def setConfMap(conf: SQLConf, confMap: java.util.Map[String, String]): Unit = { val iterator = confMap.entrySet().iterator() while (iterator.hasNext) { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 70eb28cdd0c64..f9509aed4aaab 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -818,6 +818,22 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { } } + def withDatabase(dbNames: String*)(fs: (Statement => Unit)*) { + val user = System.getProperty("user.name") + val connections = fs.map { _ => DriverManager.getConnection(jdbcUri, user, "") } + val statements = connections.map(_.createStatement()) + + try { + statements.zip(fs).foreach { case (s, f) => f(s) } + } finally { + dbNames.foreach { name => + statements(0).execute(s"DROP DATABASE IF EXISTS $name") + } + statements.foreach(_.close()) + connections.foreach(_.close()) + } + } + def withJdbcStatement(tableNames: String*)(f: Statement => Unit) { withMultipleConnectionJdbcStatement(tableNames: _*)(f) } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala new file mode 100644 index 0000000000000..9a997ae01df9d --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.util.Properties + +import org.apache.hive.jdbc.{HiveConnection, HiveQueryResultSet, Utils => JdbcUtils} +import org.apache.hive.service.auth.PlainSaslHelper +import org.apache.hive.service.cli.thrift._ +import org.apache.thrift.protocol.TBinaryProtocol +import org.apache.thrift.transport.TSocket + +class SparkMetadataOperationSuite extends HiveThriftJdbcTest { + + override def mode: ServerMode.Value = ServerMode.binary + + test("Spark's own GetSchemasOperation(SparkGetSchemasOperation)") { + def testGetSchemasOperation( + catalog: String, + schemaPattern: String)(f: HiveQueryResultSet => Unit): Unit = { + val rawTransport = new TSocket("localhost", serverPort) + val connection = new HiveConnection(s"jdbc:hive2://localhost:$serverPort", new Properties) + val user = System.getProperty("user.name") + val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport) + val client = new TCLIService.Client(new TBinaryProtocol(transport)) + transport.open() + var rs: HiveQueryResultSet = null + try { + val openResp = client.OpenSession(new TOpenSessionReq) + val sessHandle = openResp.getSessionHandle + val schemaReq = new TGetSchemasReq(sessHandle) + + if (catalog != null) { + schemaReq.setCatalogName(catalog) + } + + if (schemaPattern == null) { + schemaReq.setSchemaName("%") + } else { + schemaReq.setSchemaName(schemaPattern) + } + + val schemaResp = client.GetSchemas(schemaReq) + JdbcUtils.verifySuccess(schemaResp.getStatus) + + rs = new HiveQueryResultSet.Builder(connection) + .setClient(client) + .setSessionHandle(sessHandle) + .setStmtHandle(schemaResp.getOperationHandle) + .build() + f(rs) + } finally { + rs.close() + connection.close() + transport.close() + rawTransport.close() + } + } + + def checkResult(dbNames: Seq[String], rs: HiveQueryResultSet): Unit = { + if (dbNames.nonEmpty) { + for (i <- dbNames.indices) { + assert(rs.next()) + assert(rs.getString("TABLE_SCHEM") === dbNames(i)) + } + } else { + assert(!rs.next()) + } + } + + withDatabase("db1", "db2") { statement => + Seq("CREATE DATABASE db1", "CREATE DATABASE db2").foreach(statement.execute) + + testGetSchemasOperation(null, "%") { rs => + checkResult(Seq("db1", "db2"), rs) + } + testGetSchemasOperation(null, "db1") { rs => + checkResult(Seq("db1"), rs) + } + testGetSchemasOperation(null, "db_not_exist") { rs => + checkResult(Seq.empty, rs) + } + testGetSchemasOperation(null, "db*") { rs => + checkResult(Seq("db1", "db2"), rs) + } + } + } +} From 72a572ffd6e156243b13f9243ed296f6d77b4241 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 8 Jan 2019 22:44:33 +0800 Subject: [PATCH 2400/2461] [SPARK-26323][SQL] Scala UDF should still check input types even if some inputs are of type Any ## What changes were proposed in this pull request? For Scala UDF, when checking input nullability, we will skip inputs with type `Any`, and only check the inputs that provide nullability info. We should do the same for checking input types. ## How was this patch tested? new tests Closes #23275 from cloud-fan/udf. Authored-by: Wenchen Fan Signed-off-by: Hyukjin Kwon --- .../sql/catalyst/analysis/TypeCoercion.scala | 13 +- .../sql/catalyst/expressions/ScalaUDF.scala | 4 +- .../spark/sql/types/AbstractDataType.scala | 2 +- .../apache/spark/sql/UDFRegistration.scala | 216 ++++++++---------- .../sql/expressions/UserDefinedFunction.scala | 57 ++--- .../org/apache/spark/sql/functions.scala | 52 ++--- .../scala/org/apache/spark/sql/UDFSuite.scala | 15 ++ 7 files changed, 175 insertions(+), 184 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index b19aa50ba2156..13cc9b9c125e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -882,7 +882,18 @@ object TypeCoercion { case udf: ScalaUDF if udf.inputTypes.nonEmpty => val children = udf.children.zip(udf.inputTypes).map { case (in, expected) => - implicitCast(in, udfInputToCastType(in.dataType, expected)).getOrElse(in) + // Currently Scala UDF will only expect `AnyDataType` at top level, so this trick works. + // In the future we should create types like `AbstractArrayType`, so that Scala UDF can + // accept inputs of array type of arbitrary element type. + if (expected == AnyDataType) { + in + } else { + implicitCast( + in, + udfInputToCastType(in.dataType, expected.asInstanceOf[DataType]) + ).getOrElse(in) + } + } udf.withNewChildren(children) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index a23aaa3a0b3ef..fae1119c394b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{AbstractDataType, DataType} /** * User-defined function. @@ -48,7 +48,7 @@ case class ScalaUDF( dataType: DataType, children: Seq[Expression], inputsNullSafe: Seq[Boolean], - inputTypes: Seq[DataType] = Nil, + inputTypes: Seq[AbstractDataType] = Nil, udfName: Option[String] = None, nullable: Boolean = true, udfDeterministic: Boolean = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 5367ce2af8e9f..d2ef08873187e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -96,7 +96,7 @@ private[sql] object TypeCollection { /** * An `AbstractDataType` that matches any concrete data types. */ -protected[sql] object AnyDataType extends AbstractDataType { +protected[sql] object AnyDataType extends AbstractDataType with Serializable { // Note that since AnyDataType matches any concrete types, defaultConcreteType should never // be invoked. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 5a3f556c9c074..fe5d1afd8478a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -123,17 +123,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] | val inputSchemas: Seq[Option[ScalaReflection.Schema]] = $inputSchemas + | val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + | val finalUdf = if (nullable) udf else udf.asNonNullable() | def builder(e: Seq[Expression]) = if (e.length == $x) { - | ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - | if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - | Some(name), nullable, udfDeterministic = true) + | finalUdf.createScalaUDF(e) | } else { | throw new AnalysisException("Invalid number of arguments for function " + name + | ". Expected: $x; Found: " + e.length) | } | functionRegistry.createOrReplaceTempFunction(name, builder) - | val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - | if (nullable) udf else udf.asNonNullable() + | finalUdf |}""".stripMargin) } @@ -170,17 +169,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 0) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 0; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } /** @@ -191,17 +189,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 1) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 1; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } /** @@ -212,17 +209,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 2) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 2; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } /** @@ -233,17 +229,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 3) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 3; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } /** @@ -254,17 +249,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 4) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 4; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } /** @@ -275,17 +269,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 5) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 5; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } /** @@ -296,17 +289,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 6) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 6; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } /** @@ -317,17 +309,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 7) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 7; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } /** @@ -338,17 +329,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 8) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 8; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } /** @@ -359,17 +349,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 9) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 9; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } /** @@ -380,17 +369,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 10) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 10; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } /** @@ -401,17 +389,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 11) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 11; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } /** @@ -422,17 +409,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 12) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 12; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } /** @@ -443,17 +429,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 13) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 13; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } /** @@ -464,17 +449,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 14) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 14; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } /** @@ -485,17 +469,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 15) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 15; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } /** @@ -506,17 +489,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 16) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 16; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } /** @@ -527,17 +509,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 17) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 17; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } /** @@ -548,17 +529,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 18) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 18; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } /** @@ -569,17 +549,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 19) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 19; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } /** @@ -590,17 +569,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Try(ScalaReflection.schemaFor[A20]).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 20) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 20; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } /** @@ -611,17 +589,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Try(ScalaReflection.schemaFor[A20]).toOption :: Try(ScalaReflection.schemaFor[A21]).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 21) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 21; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } /** @@ -632,17 +609,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Try(ScalaReflection.schemaFor[A20]).toOption :: Try(ScalaReflection.schemaFor[A21]).toOption :: Try(ScalaReflection.schemaFor[A22]).toOption :: Nil + val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name) + val finalUdf = if (nullable) udf else udf.asNonNullable() def builder(e: Seq[Expression]) = if (e.length == 22) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), - if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), - Some(name), nullable, udfDeterministic = true) + finalUdf.createScalaUDF(e) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 22; Found: " + e.length) } functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name) - if (nullable) udf else udf.asNonNullable() + finalUdf } ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 901472d8e0360..1b2d6c7ffb529 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.Stable import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.expressions.ScalaUDF -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} +import org.apache.spark.sql.types.{AnyDataType, DataType} /** * A user-defined function. To create one, use the `udf` functions in `functions`. @@ -88,40 +88,47 @@ sealed abstract class UserDefinedFunction { private[sql] case class SparkUserDefinedFunction( f: AnyRef, dataType: DataType, - inputTypes: Option[Seq[DataType]], - nullableTypes: Option[Seq[Boolean]], + inputSchemas: Seq[Option[ScalaReflection.Schema]], name: Option[String] = None, nullable: Boolean = true, deterministic: Boolean = true) extends UserDefinedFunction { @scala.annotation.varargs override def apply(exprs: Column*): Column = { - // TODO: make sure this class is only instantiated through `SparkUserDefinedFunction.create()` - // and `nullableTypes` is always set. - if (inputTypes.isDefined) { - assert(inputTypes.get.length == nullableTypes.get.length) - } + Column(createScalaUDF(exprs.map(_.expr))) + } + + private[sql] def createScalaUDF(exprs: Seq[Expression]): ScalaUDF = { + // It's possible that some of the inputs don't have a specific type(e.g. `Any`), skip type + // check and null check for them. + val inputTypes = inputSchemas.map(_.map(_.dataType).getOrElse(AnyDataType)) - val inputsNullSafe = nullableTypes.getOrElse { + val inputsNullSafe = if (inputSchemas.isEmpty) { + // This is for backward compatibility of `functions.udf(AnyRef, DataType)`. We need to + // do reflection of the lambda function object and see if its arguments are nullable or not. + // This doesn't work for Scala 2.12 and we should consider removing this workaround, as Spark + // uses Scala 2.12 by default since 3.0. ScalaReflection.getParameterTypeNullability(f) + } else { + inputSchemas.map(_.map(_.nullable).getOrElse(true)) } - Column(ScalaUDF( + ScalaUDF( f, dataType, - exprs.map(_.expr), + exprs, inputsNullSafe, - inputTypes.getOrElse(Nil), + inputTypes, udfName = name, nullable = nullable, - udfDeterministic = deterministic)) + udfDeterministic = deterministic) } - override def withName(name: String): UserDefinedFunction = { + override def withName(name: String): SparkUserDefinedFunction = { copy(name = Option(name)) } - override def asNonNullable(): UserDefinedFunction = { + override def asNonNullable(): SparkUserDefinedFunction = { if (!nullable) { this } else { @@ -129,7 +136,7 @@ private[sql] case class SparkUserDefinedFunction( } } - override def asNondeterministic(): UserDefinedFunction = { + override def asNondeterministic(): SparkUserDefinedFunction = { if (!deterministic) { this } else { @@ -137,19 +144,3 @@ private[sql] case class SparkUserDefinedFunction( } } } - -private[sql] object SparkUserDefinedFunction { - - def create( - f: AnyRef, - dataType: DataType, - inputSchemas: Seq[Option[ScalaReflection.Schema]]): UserDefinedFunction = { - val inputTypes = if (inputSchemas.contains(None)) { - None - } else { - Some(inputSchemas.map(_.get.dataType)) - } - val nullableTypes = Some(inputSchemas.map(_.map(_.nullable).getOrElse(true))) - SparkUserDefinedFunction(f, dataType, inputTypes, nullableTypes) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 645452553e6a5..7572cf23cde8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3874,7 +3874,7 @@ object functions { |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] | val inputSchemas = $inputSchemas - | val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) + | val udf = SparkUserDefinedFunction(f, dataType, inputSchemas) | if (nullable) udf else udf.asNonNullable() |}""".stripMargin) } @@ -3897,7 +3897,7 @@ object functions { | */ |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = { | val func = f$anyCast.call($anyParams) - | SparkUserDefinedFunction.create($funcCall, returnType, inputSchemas = Seq.fill($i)(None)) + | SparkUserDefinedFunction($funcCall, returnType, inputSchemas = Seq.fill($i)(None)) |}""".stripMargin) } @@ -3919,7 +3919,7 @@ object functions { def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas = Nil - val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) + val udf = SparkUserDefinedFunction(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3935,7 +3935,7 @@ object functions { def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Nil - val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) + val udf = SparkUserDefinedFunction(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3951,7 +3951,7 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Nil - val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) + val udf = SparkUserDefinedFunction(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3967,7 +3967,7 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Nil - val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) + val udf = SparkUserDefinedFunction(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3983,7 +3983,7 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Nil - val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) + val udf = SparkUserDefinedFunction(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3999,7 +3999,7 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Nil - val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) + val udf = SparkUserDefinedFunction(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -4015,7 +4015,7 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Nil - val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) + val udf = SparkUserDefinedFunction(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -4031,7 +4031,7 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Nil - val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) + val udf = SparkUserDefinedFunction(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -4047,7 +4047,7 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Nil - val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) + val udf = SparkUserDefinedFunction(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -4063,7 +4063,7 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Nil - val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) + val udf = SparkUserDefinedFunction(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -4079,7 +4079,7 @@ object functions { def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A10])).toOption :: Nil - val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) + val udf = SparkUserDefinedFunction(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -4098,7 +4098,7 @@ object functions { */ def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF0[Any]].call() - SparkUserDefinedFunction.create(() => func, returnType, inputSchemas = Seq.fill(0)(None)) + SparkUserDefinedFunction(() => func, returnType, inputSchemas = Seq.fill(0)(None)) } /** @@ -4112,7 +4112,7 @@ object functions { */ def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(1)(None)) + SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(1)(None)) } /** @@ -4126,7 +4126,7 @@ object functions { */ def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(2)(None)) + SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(2)(None)) } /** @@ -4140,7 +4140,7 @@ object functions { */ def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(3)(None)) + SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(3)(None)) } /** @@ -4154,7 +4154,7 @@ object functions { */ def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(4)(None)) + SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(4)(None)) } /** @@ -4168,7 +4168,7 @@ object functions { */ def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(5)(None)) + SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(5)(None)) } /** @@ -4182,7 +4182,7 @@ object functions { */ def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(6)(None)) + SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(6)(None)) } /** @@ -4196,7 +4196,7 @@ object functions { */ def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(7)(None)) + SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(7)(None)) } /** @@ -4210,7 +4210,7 @@ object functions { */ def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(8)(None)) + SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(8)(None)) } /** @@ -4224,7 +4224,7 @@ object functions { */ def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(9)(None)) + SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(9)(None)) } /** @@ -4238,7 +4238,7 @@ object functions { */ def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(10)(None)) + SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(10)(None)) } // scalastyle:on parameter.number @@ -4257,9 +4257,7 @@ object functions { * @since 2.0.0 */ def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = { - // TODO: should call SparkUserDefinedFunction.create() instead but inputSchemas is currently - // unavailable. We may need to create type-safe overloaded versions of udf() methods. - SparkUserDefinedFunction(f, dataType, inputTypes = None, nullableTypes = None) + SparkUserDefinedFunction(f, dataType, inputSchemas = Nil) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index a26d306cff6b5..06b9343c37581 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -450,4 +450,19 @@ class UDFSuite extends QueryTest with SharedSQLContext { }) checkAnswer(df2.select(udf2($"col1")), Seq(Row(Map("a" -> "2011000000000002456556")))) } + + test("SPARK-26323 Verify input type check - with udf()") { + val f = udf((x: Long, y: Any) => x) + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j").select(f($"i", $"j")) + checkAnswer(df, Seq(Row(1L), Row(2L))) + } + + test("SPARK-26323 Verify input type check - with udf.register") { + withTable("t") { + Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.format("json").saveAsTable("t") + spark.udf.register("f", (x: Long, y: Any) => x) + val df = spark.sql("SELECT f(i, j) FROM t") + checkAnswer(df, Seq(Row(1L), Row(2L))) + } + } } From b7113822d5f9a984d30cc7fb3b8920fcf630a96a Mon Sep 17 00:00:00 2001 From: liuxian Date: Tue, 8 Jan 2019 10:45:23 -0600 Subject: [PATCH 2401/2461] [MINOR][WEBUI] Modify the name of the column named "shuffle spill" in the StagePage ## What changes were proposed in this pull request? ![default](https://user-images.githubusercontent.com/24688163/50752687-16463f00-128a-11e9-8ee3-4d156f7631f6.png) For this DAG, it has no shuffle operation, only sorting, and sorting leads to spill. ![default](https://user-images.githubusercontent.com/24688163/50752974-0f6bfc00-128b-11e9-9362-a0f440e02359.png) So I think the name of the column named "shuffle spill" is not all right in the StagePage ## How was this patch tested? Manual testing Closes #23483 from 10110346/shufflespillwebui. Authored-by: liuxian Signed-off-by: Sean Owen --- .../resources/org/apache/spark/ui/static/stagepage.js | 8 ++++---- .../org/apache/spark/ui/static/stagespage-template.html | 8 ++++---- .../main/scala/org/apache/spark/ui/jobs/StagePage.scala | 8 ++++---- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/stagepage.js b/core/src/main/resources/org/apache/spark/ui/static/stagepage.js index 08de2b0fee034..5b792ffc584d1 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/stagepage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/stagepage.js @@ -100,10 +100,10 @@ function getColumnNameForTaskMetricSummary(columnKey) { return "Scheduler Delay"; case "diskBytesSpilled": - return "Shuffle spill (disk)"; + return "Spill (disk)"; case "memoryBytesSpilled": - return "Shuffle spill (memory)"; + return "Spill (memory)"; case "shuffleReadMetrics": return "Shuffle Read Size / Records"; @@ -842,7 +842,7 @@ $(document).ready(function () { return ""; } }, - name: "Shuffle Spill (Memory)" + name: "Spill (Memory)" }, { data : function (row, type) { @@ -852,7 +852,7 @@ $(document).ready(function () { return ""; } }, - name: "Shuffle Spill (Disk)" + name: "Spill (Disk)" }, { data : function (row, type) { diff --git a/core/src/main/resources/org/apache/spark/ui/static/stagespage-template.html b/core/src/main/resources/org/apache/spark/ui/static/stagespage-template.html index 6f950c61b2d63..6b0435bb20281 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/stagespage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/stagespage-template.html @@ -59,8 +59,8 @@

      Aggregated Metrics by Executor

      Output Size / Records Shuffle Read Size / Records Shuffle Write Size / Records - Shuffle Spill (Memory) - Shuffle Spill (Disk) + Spill (Memory) + Spill (Disk) @@ -111,8 +111,8 @@

      Write Time Shuffle Write Size / Records Shuffle Read Size / Records - Shuffle Spill (Memory) - Shuffle Spill (Disk) + Spill (Memory) + Spill (Disk) Errors diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index a213b764abea7..3bca1d5743018 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -188,11 +188,11 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We }} {if (hasBytesSpilled(stageData)) {
    • - Shuffle Spill (Memory): + Spill (Memory): {Utils.bytesToString(stageData.memoryBytesSpilled)}
    • - Shuffle Spill (Disk): + Spill (Disk): {Utils.bytesToString(stageData.diskBytesSpilled)}
    • }} @@ -797,8 +797,8 @@ private[spark] object ApiHelper { val HEADER_SHUFFLE_REMOTE_READS = "Shuffle Remote Reads" val HEADER_SHUFFLE_WRITE_TIME = "Write Time" val HEADER_SHUFFLE_WRITE_SIZE = "Shuffle Write Size / Records" - val HEADER_MEM_SPILL = "Shuffle Spill (Memory)" - val HEADER_DISK_SPILL = "Shuffle Spill (Disk)" + val HEADER_MEM_SPILL = "Spill (Memory)" + val HEADER_DISK_SPILL = "Spill (Disk)" val HEADER_ERROR = "Errors" private[ui] val COLUMN_TO_INDEX = Map( From c101182b10cffd9314c44eefe4db53ba3d6553b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Wed, 9 Jan 2019 01:24:47 +0800 Subject: [PATCH 2402/2461] [SPARK-26002][SQL] Fix day of year calculation for Julian calendar days MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Fixing leap year calculations for date operators (year/month/dayOfYear) where the Julian calendars are used (before 1582-10-04). In a Julian calendar every years which are multiples of 4 are leap years (there is no extra exception for years multiples of 100). ## How was this patch tested? With a unit test ("SPARK-26002: correct day of year calculations for Julian calendar years") which focuses to these corner cases. Manually: ``` scala> sql("select year('1500-01-01')").show() +------------------------------+ |year(CAST(1500-01-01 AS DATE))| +------------------------------+ | 1500| +------------------------------+ scala> sql("select dayOfYear('1100-01-01')").show() +-----------------------------------+ |dayofyear(CAST(1100-01-01 AS DATE))| +-----------------------------------+ | 1| +-----------------------------------+ ``` Closes #23000 from attilapiros/julianOffByDays. Authored-by: “attilapiros” Signed-off-by: Wenchen Fan --- .../sql/catalyst/util/DateTimeUtils.scala | 56 +++++++++++++++---- .../catalyst/util/DateTimeUtilsSuite.scala | 30 ++++++++++ .../resources/sql-tests/inputs/datetime.sql | 2 + .../sql-tests/results/datetime.sql.out | 10 +++- 4 files changed, 86 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 3e5e1fbc2b368..e95117f95cdb8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -53,14 +53,30 @@ object DateTimeUtils { final val NANOS_PER_MICROS = 1000L final val MILLIS_PER_DAY = SECONDS_PER_DAY * 1000L - // number of days in 400 years + // number of days in 400 years by Gregorian calendar final val daysIn400Years: Int = 146097 + + // In the Julian calendar every year that is exactly divisible by 4 is a leap year without any + // exception. But in the Gregorian calendar every year that is exactly divisible by four + // is a leap year, except for years that are exactly divisible by 100, but these centurial years + // are leap years if they are exactly divisible by 400. + // So there are 3 extra days in the Julian calendar within a 400 years cycle compared to the + // Gregorian calendar. + final val extraLeapDaysIn400YearsJulian = 3 + + // number of days in 400 years by Julian calendar + final val daysIn400YearsInJulian: Int = daysIn400Years + extraLeapDaysIn400YearsJulian + // number of days between 1.1.1970 and 1.1.2001 final val to2001 = -11323 // this is year -17999, calculation: 50 * daysIn400Year final val YearZero = -17999 final val toYearZero = to2001 + 7304850 + + // days to year -17999 in Julian calendar + final val toYearZeroInJulian = toYearZero + 49 * extraLeapDaysIn400YearsJulian + final val TimeZoneGMT = TimeZone.getTimeZone("GMT") final val TimeZoneUTC = TimeZone.getTimeZone("UTC") final val MonthOf31Days = Set(1, 3, 5, 7, 8, 10, 12) @@ -575,20 +591,30 @@ object DateTimeUtils { * Return the number of days since the start of 400 year period. * The second year of a 400 year period (year 1) starts on day 365. */ - private[this] def yearBoundary(year: Int): Int = { - year * 365 + ((year / 4 ) - (year / 100) + (year / 400)) + private[this] def yearBoundary(year: Int, isGregorian: Boolean): Int = { + if (isGregorian) { + year * 365 + ((year / 4) - (year / 100) + (year / 400)) + } else { + year * 365 + (year / 4) + } } /** * Calculates the number of years for the given number of days. This depends * on a 400 year period. * @param days days since the beginning of the 400 year period + * @param isGregorian indicates whether leap years should be calculated according to Gregorian + * (or Julian) calendar * @return (number of year, days in year) */ - private[this] def numYears(days: Int): (Int, Int) = { + private[this] def numYears(days: Int, isGregorian: Boolean): (Int, Int) = { val year = days / 365 - val boundary = yearBoundary(year) - if (days > boundary) (year, days - boundary) else (year - 1, days - yearBoundary(year - 1)) + val boundary = yearBoundary(year, isGregorian) + if (days > boundary) { + (year, days - boundary) + } else { + (year - 1, days - yearBoundary(year - 1, isGregorian)) + } } /** @@ -599,18 +625,26 @@ object DateTimeUtils { * equals to the period 1.1.1601 until 31.12.2000. */ private[this] def getYearAndDayInYear(daysSince1970: SQLDate): (Int, Int) = { - // add the difference (in days) between 1.1.1970 and the artificial year 0 (-17999) - var daysSince1970Tmp = daysSince1970 // Since Julian calendar was replaced with the Gregorian calendar, // the 10 days after Oct. 4 were skipped. // (1582-10-04) -141428 days since 1970-01-01 if (daysSince1970 <= -141428) { - daysSince1970Tmp -= 10 + getYearAndDayInYear(daysSince1970 - 10, toYearZeroInJulian, daysIn400YearsInJulian, false) + } else { + getYearAndDayInYear(daysSince1970, toYearZero, daysIn400Years, true) } - val daysNormalized = daysSince1970Tmp + toYearZero + } + + private def getYearAndDayInYear( + daysSince1970: SQLDate, + toYearZero: SQLDate, + daysIn400Years: SQLDate, + isGregorian: Boolean): (Int, Int) = { + // add the difference (in days) between 1.1.1970 and the artificial year 0 (-17999) + val daysNormalized = daysSince1970 + toYearZero val numOfQuarterCenturies = daysNormalized / daysIn400Years val daysInThis400 = daysNormalized % daysIn400Years + 1 - val (years, dayInYear) = numYears(daysInThis400) + val (years, dayInYear) = numYears(daysInThis400, isGregorian) val year: Int = (2001 - 20000) + 400 * numOfQuarterCenturies + years (year, dayInYear) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 0182eeb171215..2cb6110e2c093 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -410,6 +410,36 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(getDayInYear(getInUTCDays(c.getTimeInMillis)) === 78) } + test("SPARK-26002: correct day of year calculations for Julian calendar years") { + val c = Calendar.getInstance() + c.set(Calendar.MILLISECOND, 0) + (1000 to 1600 by 100).foreach { year => + // January 1 is the 1st day of year. + c.set(year, 0, 1, 0, 0, 0) + assert(getYear(getInUTCDays(c.getTimeInMillis)) === year) + assert(getMonth(getInUTCDays(c.getTimeInMillis)) === 1) + assert(getDayInYear(getInUTCDays(c.getTimeInMillis)) === 1) + + // March 1 is the 61st day of the year as they are leap years. It is true for + // even the multiples of 100 as before 1582-10-4 the Julian calendar leap year calculation + // is used in which every multiples of 4 are leap years + c.set(year, 2, 1, 0, 0, 0) + assert(getDayInYear(getInUTCDays(c.getTimeInMillis)) === 61) + assert(getMonth(getInUTCDays(c.getTimeInMillis)) === 3) + + // testing leap day (February 29) in leap years + c.set(year, 1, 29, 0, 0, 0) + assert(getDayInYear(getInUTCDays(c.getTimeInMillis)) === 60) + + // For non-leap years: + c.set(year + 1, 2, 1, 0, 0, 0) + assert(getDayInYear(getInUTCDays(c.getTimeInMillis)) === 60) + } + + c.set(1582, 2, 1, 0, 0, 0) + assert(getDayInYear(getInUTCDays(c.getTimeInMillis)) === 60) + } + test("get year") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index 547c2bef02b24..8bd8bc2b94b8e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -27,3 +27,5 @@ select current_date = current_date(), current_timestamp = current_timestamp(), a select a, b from ttf2 order by a, current_date; select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15'); + +select year('1500-01-01'), month('1500-01-01'), dayOfYear('1500-01-01'); \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index 63aa00426ea32..2090633802e26 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 11 -- !query 0 @@ -89,3 +89,11 @@ select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), week struct -- !query 9 output 5 3 5 NULL 4 + + +-- !query 10 +select year('1500-01-01'), month('1500-01-01'), dayOfYear('1500-01-01') +-- !query 10 schema +struct +-- !query 10 output +1500 1 1 From 2783e4c45f55f4fc87748d1c4a454bfdf3024156 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 8 Jan 2019 11:25:33 -0600 Subject: [PATCH 2403/2461] [SPARK-24522][UI] Create filter to apply HTTP security checks consistently. Currently there is code scattered in a bunch of places to do different things related to HTTP security, such as access control, setting security-related headers, and filtering out bad content. This makes it really easy to miss these things when writing new UI code. This change creates a new filter that does all of those things, and makes sure that all servlet handlers that are attached to the UI get the new filter and any user-defined filters consistently. The extent of the actual features should be the same as before. The new filter is added at the end of the filter chain, because authentication is done by custom filters and thus needs to happen first. This means that custom filters see unfiltered HTTP requests - which is actually the current behavior anyway. As a side-effect of some of the code refactoring, handlers added after the initial set also get wrapped with a GzipHandler, which didn't happen before. Tested with added unit tests and in a history server with SPNEGO auth configured. Closes #23302 from vanzin/SPARK-24522. Authored-by: Marcelo Vanzin Signed-off-by: Imran Rashid --- .../spark/deploy/history/HistoryPage.scala | 5 +- .../spark/deploy/history/HistoryServer.scala | 8 +- .../deploy/master/ui/ApplicationPage.scala | 3 +- .../spark/deploy/master/ui/MasterPage.scala | 6 +- .../spark/deploy/worker/ui/LogPage.scala | 28 ++-- .../spark/deploy/worker/ui/WorkerWebUI.scala | 1 - .../spark/metrics/sink/MetricsServlet.scala | 2 +- .../spark/status/api/v1/SecurityFilter.scala | 36 ---- .../apache/spark/ui/HttpSecurityFilter.scala | 116 +++++++++++++ .../org/apache/spark/ui/JettyUtils.scala | 154 ++++++++--------- .../scala/org/apache/spark/ui/UIUtils.scala | 21 --- .../scala/org/apache/spark/ui/WebUI.scala | 15 +- .../ui/exec/ExecutorThreadDumpPage.scala | 4 +- .../apache/spark/ui/jobs/AllJobsPage.scala | 16 +- .../org/apache/spark/ui/jobs/JobPage.scala | 3 +- .../org/apache/spark/ui/jobs/JobsTab.scala | 4 +- .../org/apache/spark/ui/jobs/PoolPage.scala | 3 +- .../org/apache/spark/ui/jobs/StagePage.scala | 19 +-- .../org/apache/spark/ui/jobs/StageTable.scala | 15 +- .../org/apache/spark/ui/jobs/StagesTab.scala | 4 +- .../org/apache/spark/ui/storage/RDDPage.scala | 11 +- .../spark/ui/HttpSecurityFilterSuite.scala | 157 ++++++++++++++++++ .../scala/org/apache/spark/ui/UISuite.scala | 147 +++++++++++----- .../org/apache/spark/ui/UIUtilsSuite.scala | 39 ----- .../spark/deploy/mesos/ui/DriverPage.scala | 3 +- .../cluster/YarnSchedulerBackend.scala | 35 +++- .../cluster/YarnSchedulerBackendSuite.scala | 59 ++++++- .../sql/execution/ui/AllExecutionsPage.scala | 19 +-- .../sql/execution/ui/ExecutionPage.scala | 3 +- .../ui/ThriftServerSessionPage.scala | 3 +- .../apache/spark/streaming/ui/BatchPage.scala | 8 +- 31 files changed, 609 insertions(+), 338 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala create mode 100644 core/src/test/scala/org/apache/spark/ui/HttpSecurityFilterSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 00ca4efa4d266..7a8ab7fddd79f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -27,9 +27,8 @@ import org.apache.spark.ui.{UIUtils, WebUIPage} private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { - // stripXSS is called first to remove suspicious characters used in XSS attacks - val requestedIncomplete = - Option(UIUtils.stripXSS(request.getParameter("showIncomplete"))).getOrElse("false").toBoolean + val requestedIncomplete = Option(request.getParameter("showIncomplete")) + .getOrElse("false").toBoolean val displayApplications = parent.getApplicationList() .exists(isApplicationCompleted(_) != requestedIncomplete) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index b9303388638fd..ff2ea3b843ee3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -150,17 +150,15 @@ class HistoryServer( ui: SparkUI, completed: Boolean) { assert(serverInfo.isDefined, "HistoryServer must be bound before attaching SparkUIs") - handlers.synchronized { - ui.getHandlers.foreach(attachHandler) + ui.getHandlers.foreach { handler => + serverInfo.get.addHandler(handler, ui.securityManager) } } /** Detach a reconstructed UI from this server. Only valid after bind(). */ override def detachSparkUI(appId: String, attemptId: Option[String], ui: SparkUI): Unit = { assert(serverInfo.isDefined, "HistoryServer must be bound before detaching SparkUIs") - handlers.synchronized { - ui.getHandlers.foreach(detachHandler) - } + ui.getHandlers.foreach(detachHandler) provider.onUIDetached(appId, attemptId, ui) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index fad4e46dc035d..bcd7a7e4ccdb5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -33,8 +33,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") /** Executor details for a particular application */ def render(request: HttpServletRequest): Seq[Node] = { - // stripXSS is called first to remove suspicious characters used in XSS attacks - val appId = UIUtils.stripXSS(request.getParameter("appId")) + val appId = request.getParameter("appId") val state = master.askSync[MasterStateResponse](RequestMasterState) val app = state.activeApps.find(_.id == appId) .getOrElse(state.completedApps.find(_.id == appId).orNull) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index b8afe203fbfa2..6701465c023c7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -57,10 +57,8 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = { if (parent.killEnabled && parent.master.securityMgr.checkModifyPermissions(request.getRemoteUser)) { - // stripXSS is called first to remove suspicious characters used in XSS attacks - val killFlag = - Option(UIUtils.stripXSS(request.getParameter("terminate"))).getOrElse("false").toBoolean - val id = Option(UIUtils.stripXSS(request.getParameter("id"))) + val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean + val id = Option(request.getParameter("id")) if (id.isDefined && killFlag) { action(id.get) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 4fca9342c0378..4e720a759a1bc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -33,15 +33,13 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with private val supportedLogTypes = Set("stderr", "stdout") private val defaultBytes = 100 * 1024 - // stripXSS is called first to remove suspicious characters used in XSS attacks def renderLog(request: HttpServletRequest): String = { - val appId = Option(UIUtils.stripXSS(request.getParameter("appId"))) - val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId"))) - val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId"))) - val logType = UIUtils.stripXSS(request.getParameter("logType")) - val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong) - val byteLength = - Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt) + val appId = Option(request.getParameter("appId")) + val executorId = Option(request.getParameter("executorId")) + val driverId = Option(request.getParameter("driverId")) + val logType = request.getParameter("logType") + val offset = Option(request.getParameter("offset")).map(_.toLong) + val byteLength = Option(request.getParameter("byteLength")).map(_.toInt) .getOrElse(defaultBytes) val logDir = (appId, executorId, driverId) match { @@ -58,15 +56,13 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with pre + logText } - // stripXSS is called first to remove suspicious characters used in XSS attacks def render(request: HttpServletRequest): Seq[Node] = { - val appId = Option(UIUtils.stripXSS(request.getParameter("appId"))) - val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId"))) - val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId"))) - val logType = UIUtils.stripXSS(request.getParameter("logType")) - val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong) - val byteLength = - Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt) + val appId = Option(request.getParameter("appId")) + val executorId = Option(request.getParameter("executorId")) + val driverId = Option(request.getParameter("driverId")) + val logType = request.getParameter("logType") + val offset = Option(request.getParameter("offset")).map(_.toLong) + val byteLength = Option(request.getParameter("byteLength")).map(_.toInt) .getOrElse(defaultBytes) val (logDir, params, pageName) = (appId, executorId, driverId) match { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index ea67b7434a769..54886955b98fb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -50,7 +50,6 @@ class WorkerWebUI( addStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE) attachHandler(createServletHandler("/log", (request: HttpServletRequest) => logPage.renderLog(request), - worker.securityMgr, worker.conf)) } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 68b58b8490641..bea24ca7807e4 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -51,7 +51,7 @@ private[spark] class MetricsServlet( def getHandlers(conf: SparkConf): Array[ServletContextHandler] = { Array[ServletContextHandler]( createServletHandler(servletPath, - new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr, conf) + new ServletParams(request => getMetricsSnapshot(request), "text/json"), conf) ) } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala deleted file mode 100644 index 1cd37185d6601..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.status.api.v1 - -import javax.ws.rs.container.{ContainerRequestContext, ContainerRequestFilter} -import javax.ws.rs.core.Response -import javax.ws.rs.ext.Provider - -@Provider -private[v1] class SecurityFilter extends ContainerRequestFilter with ApiRequestContext { - override def filter(req: ContainerRequestContext): Unit = { - val user = httpRequest.getRemoteUser() - if (!uiRoot.securityManager.checkUIViewPermissions(user)) { - req.abortWith( - Response - .status(Response.Status.FORBIDDEN) - .entity(raw"""user "$user" is not authorized""") - .build() - ) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala b/core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala new file mode 100644 index 0000000000000..da84fdf8fe140 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import java.util.{Enumeration, Map => JMap} +import javax.servlet._ +import javax.servlet.http.{HttpServletRequest, HttpServletRequestWrapper, HttpServletResponse} + +import scala.collection.JavaConverters._ + +import org.apache.commons.lang3.StringEscapeUtils + +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.internal.config._ + +/** + * A servlet filter that implements HTTP security features. The following actions are taken + * for every request: + * + * - perform access control of authenticated requests. + * - check request data for disallowed content (e.g. things that could be used to create XSS + * attacks). + * - set response headers to prevent certain kinds of attacks. + * + * Request parameters are sanitized so that HTML content is escaped, and disallowed content is + * removed. + */ +private class HttpSecurityFilter( + conf: SparkConf, + securityMgr: SecurityManager) extends Filter { + + override def destroy(): Unit = { } + + override def init(config: FilterConfig): Unit = { } + + override def doFilter(req: ServletRequest, res: ServletResponse, chain: FilterChain): Unit = { + val hreq = req.asInstanceOf[HttpServletRequest] + val hres = res.asInstanceOf[HttpServletResponse] + hres.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + + if (!securityMgr.checkUIViewPermissions(hreq.getRemoteUser())) { + hres.sendError(HttpServletResponse.SC_FORBIDDEN, + "User is not authorized to access this page.") + return + } + + // SPARK-10589 avoid frame-related click-jacking vulnerability, using X-Frame-Options + // (see http://tools.ietf.org/html/rfc7034). By default allow framing only from the + // same origin, but allow framing for a specific named URI. + // Example: spark.ui.allowFramingFrom = https://example.com/ + val xFrameOptionsValue = conf.getOption("spark.ui.allowFramingFrom") + .map { uri => s"ALLOW-FROM $uri" } + .getOrElse("SAMEORIGIN") + + hres.setHeader("X-Frame-Options", xFrameOptionsValue) + hres.setHeader("X-XSS-Protection", conf.get(UI_X_XSS_PROTECTION)) + if (conf.get(UI_X_CONTENT_TYPE_OPTIONS)) { + hres.setHeader("X-Content-Type-Options", "nosniff") + } + if (hreq.getScheme() == "https") { + conf.get(UI_STRICT_TRANSPORT_SECURITY).foreach( + hres.setHeader("Strict-Transport-Security", _)) + } + + chain.doFilter(new XssSafeRequest(hreq), res) + } + +} + +private class XssSafeRequest(req: HttpServletRequest) extends HttpServletRequestWrapper(req) { + + private val NEWLINE_AND_SINGLE_QUOTE_REGEX = raw"(?i)(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)".r + + private val parameterMap: Map[String, Array[String]] = { + super.getParameterMap().asScala.map { case (name, values) => + stripXSS(name) -> values.map(stripXSS) + }.toMap + } + + override def getParameterMap(): JMap[String, Array[String]] = parameterMap.asJava + + override def getParameterNames(): Enumeration[String] = { + parameterMap.keys.iterator.asJavaEnumeration + } + + override def getParameterValues(name: String): Array[String] = parameterMap.get(name).orNull + + override def getParameter(name: String): String = { + parameterMap.get(name).flatMap(_.headOption).orNull + } + + private def stripXSS(str: String): String = { + if (str != null) { + // Remove new lines and single quotes, followed by escaping HTML version 4.0 + StringEscapeUtils.escapeHtml4(NEWLINE_AND_SINGLE_QUOTE_REGEX.replaceAllIn(str, "")) + } else { + null + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 316af9b79d286..08f5fb937da7e 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -18,6 +18,7 @@ package org.apache.spark.ui import java.net.{URI, URL} +import java.util.EnumSet import javax.servlet.DispatcherType import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} @@ -68,43 +69,16 @@ private[spark] object JettyUtils extends Logging { implicit def textResponderToServlet(responder: Responder[String]): ServletParams[String] = new ServletParams(responder, "text/plain") - def createServlet[T <: AnyRef]( + private def createServlet[T <: AnyRef]( servletParams: ServletParams[T], - securityMgr: SecurityManager, conf: SparkConf): HttpServlet = { - - // SPARK-10589 avoid frame-related click-jacking vulnerability, using X-Frame-Options - // (see http://tools.ietf.org/html/rfc7034). By default allow framing only from the - // same origin, but allow framing for a specific named URI. - // Example: spark.ui.allowFramingFrom = https://example.com/ - val allowFramingFrom = conf.getOption("spark.ui.allowFramingFrom") - val xFrameOptionsValue = - allowFramingFrom.map(uri => s"ALLOW-FROM $uri").getOrElse("SAMEORIGIN") - new HttpServlet { override def doGet(request: HttpServletRequest, response: HttpServletResponse) { try { - if (securityMgr.checkUIViewPermissions(request.getRemoteUser)) { - response.setContentType("%s;charset=utf-8".format(servletParams.contentType)) - response.setStatus(HttpServletResponse.SC_OK) - val result = servletParams.responder(request) - response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - response.setHeader("X-Frame-Options", xFrameOptionsValue) - response.setHeader("X-XSS-Protection", conf.get(UI_X_XSS_PROTECTION)) - if (conf.get(UI_X_CONTENT_TYPE_OPTIONS)) { - response.setHeader("X-Content-Type-Options", "nosniff") - } - if (request.getScheme == "https") { - conf.get(UI_STRICT_TRANSPORT_SECURITY).foreach( - response.setHeader("Strict-Transport-Security", _)) - } - response.getWriter.print(servletParams.extractFn(result)) - } else { - response.setStatus(HttpServletResponse.SC_FORBIDDEN) - response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - response.sendError(HttpServletResponse.SC_FORBIDDEN, - "User is not authorized to access this page.") - } + response.setContentType("%s;charset=utf-8".format(servletParams.contentType)) + response.setStatus(HttpServletResponse.SC_OK) + val result = servletParams.responder(request) + response.getWriter.print(servletParams.extractFn(result)) } catch { case e: IllegalArgumentException => response.sendError(HttpServletResponse.SC_BAD_REQUEST, e.getMessage) @@ -124,10 +98,9 @@ private[spark] object JettyUtils extends Logging { def createServletHandler[T <: AnyRef]( path: String, servletParams: ServletParams[T], - securityMgr: SecurityManager, conf: SparkConf, basePath: String = ""): ServletContextHandler = { - createServletHandler(path, createServlet(servletParams, securityMgr, conf), basePath) + createServletHandler(path, createServlet(servletParams, conf), basePath) } /** Create a context handler that responds to a request with the given path prefix */ @@ -257,36 +230,6 @@ private[spark] object JettyUtils extends Logging { contextHandler } - /** Add filters, if any, to the given list of ServletContextHandlers */ - def addFilters(handlers: Seq[ServletContextHandler], conf: SparkConf) { - val filters: Array[String] = conf.get("spark.ui.filters", "").split(',').map(_.trim()) - filters.foreach { - case filter : String => - if (!filter.isEmpty) { - logInfo(s"Adding filter $filter to ${handlers.map(_.getContextPath).mkString(", ")}.") - val holder : FilterHolder = new FilterHolder() - holder.setClassName(filter) - // Get any parameters for each filter - conf.get("spark." + filter + ".params", "").split(',').map(_.trim()).toSet.foreach { - param: String => - if (!param.isEmpty) { - val parts = param.split("=") - if (parts.length == 2) holder.setInitParameter(parts(0), parts(1)) - } - } - - val prefix = s"spark.$filter.param." - conf.getAll - .filter { case (k, v) => k.length() > prefix.length() && k.startsWith(prefix) } - .foreach { case (k, v) => holder.setInitParameter(k.substring(prefix.length()), v) } - - val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.ERROR, - DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.REQUEST) - handlers.foreach { case(handler) => handler.addFilter(holder, "/*", enumDispatcher) } - } - } - } - /** * Attempt to start a Jetty server bound to the supplied hostName:port using the given * context handlers. @@ -298,12 +241,9 @@ private[spark] object JettyUtils extends Logging { hostName: String, port: Int, sslOptions: SSLOptions, - handlers: Seq[ServletContextHandler], conf: SparkConf, serverName: String = ""): ServerInfo = { - addFilters(handlers, conf) - // Start the server first, with no connectors. val pool = new QueuedThreadPool if (serverName.nonEmpty) { @@ -398,16 +338,6 @@ private[spark] object JettyUtils extends Logging { } server.addConnector(httpConnector) - - // Add all the known handlers now that connectors are configured. - handlers.foreach { h => - h.setVirtualHosts(toVirtualHosts(SPARK_CONNECTOR_NAME)) - val gzipHandler = new GzipHandler() - gzipHandler.setHandler(h) - collection.addHandler(gzipHandler) - gzipHandler.start() - } - pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads)) ServerInfo(server, httpPort, securePort, conf, collection) } catch { @@ -489,6 +419,16 @@ private[spark] object JettyUtils extends Logging { } } + def addFilter( + handler: ServletContextHandler, + filter: String, + params: Map[String, String]): Unit = { + val holder = new FilterHolder() + holder.setClassName(filter) + params.foreach { case (k, v) => holder.setInitParameter(k, v) } + handler.addFilter(holder, "/*", EnumSet.allOf(classOf[DispatcherType])) + } + // Create a new URI from the arguments, handling IPv6 host encoding and default ports. private def createRedirectURI( scheme: String, server: String, port: Int, path: String, query: String) = { @@ -509,20 +449,37 @@ private[spark] case class ServerInfo( server: Server, boundPort: Int, securePort: Option[Int], - conf: SparkConf, - private val rootHandler: ContextHandlerCollection) { + private val conf: SparkConf, + private val rootHandler: ContextHandlerCollection) extends Logging { - def addHandler(handler: ServletContextHandler): Unit = { + def addHandler( + handler: ServletContextHandler, + securityMgr: SecurityManager): Unit = synchronized { handler.setVirtualHosts(JettyUtils.toVirtualHosts(JettyUtils.SPARK_CONNECTOR_NAME)) - JettyUtils.addFilters(Seq(handler), conf) - rootHandler.addHandler(handler) + addFilters(handler, securityMgr) + + val gzipHandler = new GzipHandler() + gzipHandler.setHandler(handler) + rootHandler.addHandler(gzipHandler) + if (!handler.isStarted()) { handler.start() } + gzipHandler.start() } - def removeHandler(handler: ContextHandler): Unit = { - rootHandler.removeHandler(handler) + def removeHandler(handler: ServletContextHandler): Unit = synchronized { + // Since addHandler() always adds a wrapping gzip handler, find the container handler + // and remove it. + rootHandler.getHandlers() + .find { h => + h.isInstanceOf[GzipHandler] && h.asInstanceOf[GzipHandler].getHandler() == handler + } + .foreach { h => + rootHandler.removeHandler(h) + h.stop() + } + if (handler.isStarted) { handler.stop() } @@ -537,4 +494,33 @@ private[spark] case class ServerInfo( threadPool.asInstanceOf[LifeCycle].stop } } + + /** + * Add filters, if any, to the given ServletContextHandlers. Always adds a filter at the end + * of the chain to perform security-related functions. + */ + private def addFilters(handler: ServletContextHandler, securityMgr: SecurityManager): Unit = { + conf.getOption("spark.ui.filters").toSeq.flatMap(Utils.stringToSeq).foreach { filter => + logInfo(s"Adding filter to ${handler.getContextPath()}: $filter") + val oldParams = conf.getOption(s"spark.$filter.params").toSeq + .flatMap(Utils.stringToSeq) + .flatMap { param => + val parts = param.split("=") + if (parts.length == 2) Some(parts(0) -> parts(1)) else None + } + .toMap + + val newParams = conf.getAllWithPrefix(s"spark.$filter.param.").toMap + + JettyUtils.addFilter(handler, filter, oldParams ++ newParams) + } + + // This filter must come after user-installed filters, since that's where authentication + // filters are installed. This means that custom filters will see the request before it's + // been validated by the security filter. + val securityFilter = new HttpSecurityFilter(conf, securityMgr) + val holder = new FilterHolder(securityFilter) + handler.addFilter(holder, "/*", EnumSet.allOf(classOf[DispatcherType])) + } + } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 60a929375baae..967435030bc4d 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -27,8 +27,6 @@ import scala.util.control.NonFatal import scala.xml._ import scala.xml.transform.{RewriteRule, RuleTransformer} -import org.apache.commons.lang3.StringEscapeUtils - import org.apache.spark.internal.Logging import org.apache.spark.ui.scope.RDDOperationGraph @@ -38,8 +36,6 @@ private[spark] object UIUtils extends Logging { val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped" val TABLE_CLASS_STRIPED_SORTABLE = TABLE_CLASS_STRIPED + " sortable" - private val NEWLINE_AND_SINGLE_QUOTE_REGEX = raw"(?i)(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)".r - // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. private val dateFormat = new ThreadLocal[SimpleDateFormat]() { override def initialValue(): SimpleDateFormat = @@ -552,23 +548,6 @@ private[spark] object UIUtils extends Logging { } } - /** - * Remove suspicious characters of user input to prevent Cross-Site scripting (XSS) attacks - * - * For more information about XSS testing: - * https://www.owasp.org/index.php/XSS_Filter_Evasion_Cheat_Sheet and - * https://www.owasp.org/index.php/Testing_for_Reflected_Cross_site_scripting_(OTG-INPVAL-001) - */ - def stripXSS(requestParameter: String): String = { - if (requestParameter == null) { - null - } else { - // Remove new lines and single quotes, followed by escaping HTML version 4.0 - StringEscapeUtils.escapeHtml4( - NEWLINE_AND_SINGLE_QUOTE_REGEX.replaceAllIn(requestParameter, "")) - } - } - def buildErrorResponse(status: Response.Status, msg: String): Response = { Response.status(status).entity(msg).`type`(MediaType.TEXT_PLAIN).build() } diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 2e43f17e6a8e3..ebf8655ce8c2f 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -58,7 +58,6 @@ private[spark] abstract class WebUI( def getBasePath: String = basePath def getTabs: Seq[WebUITab] = tabs def getHandlers: Seq[ServletContextHandler] = handlers - def getSecurityManager: SecurityManager = securityManager /** Attaches a tab to this UI, along with all of its attached pages. */ def attachTab(tab: WebUITab): Unit = { @@ -81,9 +80,9 @@ private[spark] abstract class WebUI( def attachPage(page: WebUIPage): Unit = { val pagePath = "/" + page.prefix val renderHandler = createServletHandler(pagePath, - (request: HttpServletRequest) => page.render(request), securityManager, conf, basePath) + (request: HttpServletRequest) => page.render(request), conf, basePath) val renderJsonHandler = createServletHandler(pagePath.stripSuffix("/") + "/json", - (request: HttpServletRequest) => page.renderJson(request), securityManager, conf, basePath) + (request: HttpServletRequest) => page.renderJson(request), conf, basePath) attachHandler(renderHandler) attachHandler(renderJsonHandler) val handlers = pageToHandlers.getOrElseUpdate(page, ArrayBuffer[ServletContextHandler]()) @@ -91,13 +90,13 @@ private[spark] abstract class WebUI( } /** Attaches a handler to this UI. */ - def attachHandler(handler: ServletContextHandler): Unit = { + def attachHandler(handler: ServletContextHandler): Unit = synchronized { handlers += handler - serverInfo.foreach(_.addHandler(handler)) + serverInfo.foreach(_.addHandler(handler, securityManager)) } /** Detaches a handler from this UI. */ - def detachHandler(handler: ServletContextHandler): Unit = { + def detachHandler(handler: ServletContextHandler): Unit = synchronized { handlers -= handler serverInfo.foreach(_.removeHandler(handler)) } @@ -129,7 +128,9 @@ private[spark] abstract class WebUI( assert(serverInfo.isEmpty, s"Attempted to bind $className more than once!") try { val host = Option(conf.getenv("SPARK_LOCAL_IP")).getOrElse("0.0.0.0") - serverInfo = Some(startJettyServer(host, port, sslOptions, handlers, conf, name)) + val server = startJettyServer(host, port, sslOptions, conf, name) + handlers.foreach(server.addHandler(_, securityManager)) + serverInfo = Some(server) logInfo(s"Bound $className to $host, and started at $webUrl") } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index f9713fb5b4a3c..a13037b5e24db 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -28,10 +28,8 @@ private[ui] class ExecutorThreadDumpPage( parent: SparkUITab, sc: Option[SparkContext]) extends WebUIPage("threadDump") { - // stripXSS is called first to remove suspicious characters used in XSS attacks def render(request: HttpServletRequest): Seq[Node] = { - val executorId = - Option(UIUtils.stripXSS(request.getParameter("executorId"))).map { executorId => + val executorId = Option(request.getParameter("executorId")).map { executorId => UIUtils.decodeURLParameter(executorId) }.getOrElse { throw new IllegalArgumentException(s"Missing executorId parameter") diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 2c22e0555fcb8..b35ea5b52549b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -205,21 +205,17 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We jobTag: String, jobs: Seq[v1.JobData], killEnabled: Boolean): Seq[Node] = { - // stripXSS is called to remove suspicious characters used in XSS attacks - val allParameters = request.getParameterMap.asScala.toMap.map { case (k, v) => - UIUtils.stripXSS(k) -> v.map(UIUtils.stripXSS).toSeq - } - val parameterOtherTable = allParameters.filterNot(_._1.startsWith(jobTag)) + val parameterOtherTable = request.getParameterMap().asScala + .filterNot(_._1.startsWith(jobTag)) .map(para => para._1 + "=" + para._2(0)) val someJobHasJobGroup = jobs.exists(_.jobGroup.isDefined) val jobIdTitle = if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id" - // stripXSS is called first to remove suspicious characters used in XSS attacks - val parameterJobPage = UIUtils.stripXSS(request.getParameter(jobTag + ".page")) - val parameterJobSortColumn = UIUtils.stripXSS(request.getParameter(jobTag + ".sort")) - val parameterJobSortDesc = UIUtils.stripXSS(request.getParameter(jobTag + ".desc")) - val parameterJobPageSize = UIUtils.stripXSS(request.getParameter(jobTag + ".pageSize")) + val parameterJobPage = request.getParameter(jobTag + ".page") + val parameterJobSortColumn = request.getParameter(jobTag + ".sort") + val parameterJobSortDesc = request.getParameter(jobTag + ".desc") + val parameterJobPageSize = request.getParameter(jobTag + ".pageSize") val jobPage = Option(parameterJobPage).map(_.toInt).getOrElse(1) val jobSortColumn = Option(parameterJobSortColumn).map { sortColumn => diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index cd82439223b07..46295e73e086b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -184,8 +184,7 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP } def render(request: HttpServletRequest): Seq[Node] = { - // stripXSS is called first to remove suspicious characters used in XSS attacks - val parameterId = UIUtils.stripXSS(request.getParameter("id")) + val parameterId = request.getParameter("id") require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") val jobId = parameterId.toInt diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index ff1b75e5c5065..37bb292bd5950 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -47,9 +47,7 @@ private[ui] class JobsTab(parent: SparkUI, store: AppStatusStore) def handleKillRequest(request: HttpServletRequest): Unit = { if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) { - // stripXSS is called first to remove suspicious characters used in XSS attacks - val jobId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt) - jobId.foreach { id => + Option(request.getParameter("id")).map(_.toInt).foreach { id => store.asOption(store.job(id)).foreach { job => if (job.status == JobExecutionStatus.RUNNING) { sc.foreach(_.cancelJob(id)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 22a40101e33df..6d2710385d9d1 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -29,8 +29,7 @@ import org.apache.spark.ui.{UIUtils, WebUIPage} private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { def render(request: HttpServletRequest): Seq[Node] = { - // stripXSS is called first to remove suspicious characters used in XSS attacks - val poolName = Option(UIUtils.stripXSS(request.getParameter("poolname"))).map { poolname => + val poolName = Option(request.getParameter("poolname")).map { poolname => UIUtils.decodeURLParameter(poolname) }.getOrElse { throw new IllegalArgumentException(s"Missing poolname parameter") diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 3bca1d5743018..8ec625da042f7 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -80,22 +80,19 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We } def render(request: HttpServletRequest): Seq[Node] = { - // stripXSS is called first to remove suspicious characters used in XSS attacks - val parameterId = UIUtils.stripXSS(request.getParameter("id")) + val parameterId = request.getParameter("id") require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") - val parameterAttempt = UIUtils.stripXSS(request.getParameter("attempt")) + val parameterAttempt = request.getParameter("attempt") require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter") - val parameterTaskPage = UIUtils.stripXSS(request.getParameter("task.page")) - val parameterTaskSortColumn = UIUtils.stripXSS(request.getParameter("task.sort")) - val parameterTaskSortDesc = UIUtils.stripXSS(request.getParameter("task.desc")) - val parameterTaskPageSize = UIUtils.stripXSS(request.getParameter("task.pageSize")) + val parameterTaskPage = request.getParameter("task.page") + val parameterTaskSortColumn = request.getParameter("task.sort") + val parameterTaskSortDesc = request.getParameter("task.desc") + val parameterTaskPageSize = request.getParameter("task.pageSize") - val eventTimelineParameterTaskPage = UIUtils.stripXSS( - request.getParameter("task.eventTimelinePageNumber")) - val eventTimelineParameterTaskPageSize = UIUtils.stripXSS( - request.getParameter("task.eventTimelinePageSize")) + val eventTimelineParameterTaskPage = request.getParameter("task.eventTimelinePageNumber") + val eventTimelineParameterTaskPageSize = request.getParameter("task.eventTimelinePageSize") var eventTimelineTaskPage = Option(eventTimelineParameterTaskPage).map(_.toInt).getOrElse(1) var eventTimelineTaskPageSize = Option( eventTimelineParameterTaskPageSize).map(_.toInt).getOrElse(100) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 766efc15e26ba..330b6422a13af 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -42,17 +42,14 @@ private[ui] class StageTableBase( isFairScheduler: Boolean, killEnabled: Boolean, isFailedStage: Boolean) { - // stripXSS is called to remove suspicious characters used in XSS attacks - val allParameters = request.getParameterMap.asScala.toMap.map { case (k, v) => - UIUtils.stripXSS(k) -> v.map(UIUtils.stripXSS).toSeq - } - val parameterOtherTable = allParameters.filterNot(_._1.startsWith(stageTag)) + val parameterOtherTable = request.getParameterMap().asScala + .filterNot(_._1.startsWith(stageTag)) .map(para => para._1 + "=" + para._2(0)) - val parameterStagePage = UIUtils.stripXSS(request.getParameter(stageTag + ".page")) - val parameterStageSortColumn = UIUtils.stripXSS(request.getParameter(stageTag + ".sort")) - val parameterStageSortDesc = UIUtils.stripXSS(request.getParameter(stageTag + ".desc")) - val parameterStagePageSize = UIUtils.stripXSS(request.getParameter(stageTag + ".pageSize")) + val parameterStagePage = request.getParameter(stageTag + ".page") + val parameterStageSortColumn = request.getParameter(stageTag + ".sort") + val parameterStageSortDesc = request.getParameter(stageTag + ".desc") + val parameterStagePageSize = request.getParameter(stageTag + ".pageSize") val stagePage = Option(parameterStagePage).map(_.toInt).getOrElse(1) val stageSortColumn = Option(parameterStageSortColumn).map { sortColumn => diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index 10b032084ce4f..e16c337ba1643 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -45,9 +45,7 @@ private[ui] class StagesTab(val parent: SparkUI, val store: AppStatusStore) def handleKillRequest(request: HttpServletRequest): Unit = { if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) { - // stripXSS is called first to remove suspicious characters used in XSS attacks - val stageId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt) - stageId.foreach { id => + Option(request.getParameter("id")).map(_.toInt).foreach { id => store.asOption(store.lastStageAttempt(id)).foreach { stage => val status = stage.status if (status == StageStatus.ACTIVE || status == StageStatus.PENDING) { diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 87da290c83057..dde441abe5903 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -31,14 +31,13 @@ import org.apache.spark.util.Utils private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends WebUIPage("rdd") { def render(request: HttpServletRequest): Seq[Node] = { - // stripXSS is called first to remove suspicious characters used in XSS attacks - val parameterId = UIUtils.stripXSS(request.getParameter("id")) + val parameterId = request.getParameter("id") require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") - val parameterBlockPage = UIUtils.stripXSS(request.getParameter("block.page")) - val parameterBlockSortColumn = UIUtils.stripXSS(request.getParameter("block.sort")) - val parameterBlockSortDesc = UIUtils.stripXSS(request.getParameter("block.desc")) - val parameterBlockPageSize = UIUtils.stripXSS(request.getParameter("block.pageSize")) + val parameterBlockPage = request.getParameter("block.page") + val parameterBlockSortColumn = request.getParameter("block.sort") + val parameterBlockSortDesc = request.getParameter("block.desc") + val parameterBlockPageSize = request.getParameter("block.pageSize") val blockPage = Option(parameterBlockPage).map(_.toInt).getOrElse(1) val blockSortColumn = Option(parameterBlockSortColumn).getOrElse("Block Name") diff --git a/core/src/test/scala/org/apache/spark/ui/HttpSecurityFilterSuite.scala b/core/src/test/scala/org/apache/spark/ui/HttpSecurityFilterSuite.scala new file mode 100644 index 0000000000000..f46cc293ed271 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/HttpSecurityFilterSuite.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import java.util.UUID +import javax.servlet.FilterChain +import javax.servlet.http.{HttpServletRequest, HttpServletResponse} + +import scala.collection.JavaConverters._ + +import org.mockito.ArgumentCaptor +import org.mockito.ArgumentMatchers.{any, eq => meq} +import org.mockito.Mockito.{mock, times, verify, when} + +import org.apache.spark._ +import org.apache.spark.internal.config._ + +class HttpSecurityFilterSuite extends SparkFunSuite { + + test("filter bad user input") { + val badValues = Map( + "encoded" -> "Encoding:base64%0d%0a%0d%0aPGh0bWw%2bjcmlwdD48L2h0bWw%2b", + "alert1" -> """>"'>